Munzali commited on
Commit
4f0b287
·
verified ·
1 Parent(s): ed3483a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -46
app.py CHANGED
@@ -1,4 +1,5 @@
1
  ### 1. Imports and class names setup ###
 
2
  import gradio as gr
3
  import os
4
  import torch
@@ -12,89 +13,82 @@ class_names = ['bacterial', 'blast', 'brownspot', 'tungro']
12
 
13
  ### 2. Model and transforms preparation ###
14
 
15
- # Create EffNetB2 model
16
  mobilenet, manual_transforms = create_mobilenet_model(
17
- num_classes=4, # len(class_names) would also work
18
  )
19
 
20
- # Load saved weights
21
  mobilenet.load_state_dict(
22
  torch.load(
23
  f="mobilenet_5_epochs.pth",
24
- map_location=torch.device("cpu"), # load to CPU
25
  )
26
  )
27
 
28
  ### 3. Predict function ###
29
-
30
- # Create predict function
31
  def predict(img) -> Tuple[Dict, float]:
32
- """Transforms and performs a prediction on img and returns prediction and time taken.
33
- """
34
- # Start the timer
35
  start_time = timer()
36
 
37
- # Transform the target image and add a batch dimension
38
  img = manual_transforms(img).unsqueeze(0)
39
 
40
- # Put model into evaluation mode and turn on inference mode
41
  mobilenet.eval()
42
  with torch.inference_mode():
43
- # Pass the transformed image through the model and turn the prediction logits into prediction probabilities
44
  pred_probs = torch.softmax(mobilenet(img), dim=1)
45
 
46
- # Create a prediction label and prediction probability dictionary for each prediction class (this is the required format for Gradio's output parameter)
47
  pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
48
 
49
- # Calculate the prediction time
50
  pred_time = round(timer() - start_time, 5)
51
 
52
- # Return the prediction dictionary and prediction time
53
  return pred_labels_and_probs, pred_time
54
 
55
  ### 4. Gradio app ###
56
- # Gradio interface
57
 
58
- def app():
59
- with gr.Blocks():
60
- with gr.Row():
61
- with gr.Column():
62
- image = gr.Image(type="pil", label="Image")
63
- infer = gr.Button(value="Submit")
64
 
65
- with gr.Column():
66
- output_image = [gr.Label(num_top_classes=4, label="Predictions"), # what are the outputs?
67
- gr.Number(label="Prediction time (s)")], # our fn has two outputs, therefore we have two outputs
68
- with gr.Column():
69
- example_list = [["examples/" + example] for example in os.listdir("examples")]
70
- examples = gr.Examples(examples=example_list,
71
- inputs=gr.Image(type='pil') )
72
- infer.click(
73
- fn=predict,
74
- inputs=image,
75
- outputs=output_image,
76
-
77
- )
78
- #[gr.Textbox(label="greeting", lines=1)], gr.Image(image_mode="RGB")
79
- gradio_app = gr.Blocks()
80
-
81
- with gradio_app:
82
  gr.HTML(
83
  """
84
- <h1 style='text-align: center'>
85
- SIAMESE: Real-Time End-to-End Face Security System
86
- </h1>
87
- """)
 
 
88
  gr.HTML(
89
  """
90
  <h3 style='text-align: center'>
91
  Follow me for more!
92
- <a href='https://twitter.com/kadirnar_ai' target='_blank'>Twitter</a> | <a href='https://github.com/kadirnar' target='_blank'>Github</a> | <a href='https://www.linkedin.com/in/kadir-nar/' target='_blank'>Linkedin</a> | <a href='https://www.huggingface.co/kadirnar/' target='_blank'>HuggingFace</a>
 
 
 
93
  </h3>
94
- """)
 
 
95
  with gr.Row():
96
  with gr.Column():
97
- app()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  # gradio_app.launch(debug=True, share=True)
99
  # # Create title, description and article strings
100
  # title = "RICE DISEASES CLASSIFICATION"
 
1
  ### 1. Imports and class names setup ###
2
+ ### 1. Imports and class names setup ###
3
  import gradio as gr
4
  import os
5
  import torch
 
13
 
14
  ### 2. Model and transforms preparation ###
15
 
 
16
  mobilenet, manual_transforms = create_mobilenet_model(
17
+ num_classes=4
18
  )
19
 
 
20
  mobilenet.load_state_dict(
21
  torch.load(
22
  f="mobilenet_5_epochs.pth",
23
+ map_location=torch.device("cpu"),
24
  )
25
  )
26
 
27
  ### 3. Predict function ###
 
 
28
  def predict(img) -> Tuple[Dict, float]:
 
 
 
29
  start_time = timer()
30
 
 
31
  img = manual_transforms(img).unsqueeze(0)
32
 
 
33
  mobilenet.eval()
34
  with torch.inference_mode():
 
35
  pred_probs = torch.softmax(mobilenet(img), dim=1)
36
 
 
37
  pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
38
 
 
39
  pred_time = round(timer() - start_time, 5)
40
 
 
41
  return pred_labels_and_probs, pred_time
42
 
43
  ### 4. Gradio app ###
 
44
 
45
+ # Create a Blocks app (only one!)
46
+ with gr.Blocks() as gradio_app:
 
 
 
 
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  gr.HTML(
49
  """
50
+ <h1 style='text-align: center'>
51
+ Rice Disease Detection - MobileNet Model
52
+ </h1>
53
+ """
54
+ )
55
+
56
  gr.HTML(
57
  """
58
  <h3 style='text-align: center'>
59
  Follow me for more!
60
+ <a href='https://twitter.com/kadirnar_ai' target='_blank'>Twitter</a> |
61
+ <a href='https://github.com/kadirnar' target='_blank'>Github</a> |
62
+ <a href='https://www.linkedin.com/in/kadir-nar/' target='_blank'>Linkedin</a> |
63
+ <a href='https://www.huggingface.co/kadirnar/' target='_blank'>HuggingFace</a>
64
  </h3>
65
+ """
66
+ )
67
+
68
  with gr.Row():
69
  with gr.Column():
70
+ image = gr.Image(type="pil", label="Upload Image")
71
+ infer = gr.Button(value="Predict")
72
+
73
+ # Examples linked to the input component 'image'
74
+ example_list = [["examples/" + example] for example in os.listdir("examples")]
75
+ gr.Examples(
76
+ examples=example_list,
77
+ inputs=[image] # Pass the actual input component
78
+ )
79
+
80
+ with gr.Column():
81
+ label = gr.Label(num_top_classes=4, label="Predictions")
82
+ pred_time = gr.Number(label="Prediction Time (s)")
83
+
84
+ infer.click(
85
+ fn=predict,
86
+ inputs=[image],
87
+ outputs=[label, pred_time]
88
+ )
89
+
90
+ # Launch the app
91
+ gradio_app.launch()
92
  # gradio_app.launch(debug=True, share=True)
93
  # # Create title, description and article strings
94
  # title = "RICE DISEASES CLASSIFICATION"