itserphan commited on
Commit
5541f6d
·
verified ·
1 Parent(s): bafbf08

updated the code to run the app.

Browse files
Files changed (1) hide show
  1. app.py +39 -44
app.py CHANGED
@@ -1,60 +1,55 @@
1
- ### 1. Imports and class names setup ###
2
  import gradio as gr
3
  import os
4
  import torch
5
-
6
  from model import create_effnetb2_model
7
  from timeit import default_timer as timer
8
  from typing import Tuple, Dict
9
 
10
- # Setting up the class names
11
  with open("class_names.txt", "r") as f:
12
- class_names = [food.strip() for food in f.readlines()]
 
13
 
14
- ### 2. Model and transforms preparation ###
15
- # Create model and transforms
16
- effnetb2, effnetb2_transforms = create_effnetb2_model(num_classes=101)
17
 
18
- # Load the saved Weights
19
  effnetb2.load_state_dict(
20
- torch.load(f="09_pretrained_effnetb2_feature_extractor_food101_20_percent.pth",
21
- map_location = torch.device("cpu"))
 
 
22
  )
23
 
24
- ### 3. Predict Function ###
25
  def predict(img) -> Tuple[Dict, float]:
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
- start_time = timer()
28
-
29
- img = effnetb2_transforms(img).unsqueeze(0) # Unsqueeze == Add batch dimension on 0th index
30
-
31
- effnetb2.eval()
32
- with torch.inference_mode():
33
- pred_probs = torch.softmax(effnetb2(img), dim = 1)
34
-
35
- pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
36
-
37
- end_time = timer()
38
- pred_time = round(end_time - start_time, 4)
39
-
40
- return pred_labels_and_probs, pred_time
41
-
42
- ### 4. Gradio app ###
43
- # Create title,, description and articcle
44
  title = "FoodVision Big 🍔👁"
45
- description = "An [EfficientNetB2 Feature Extractor](https://pytorch.org/vision/stable/models/efficientnet.html) computer vision model to classify [101 classes of food from the Food101 Dataset](https://github.com/mrdbourke/pytorch-deep-learning/blob/main/extras/food101_class_names.txt)"
46
- article = "Created at [09. PyTorch Model Deployment](https://www.learnpytorch.io/09_pytorch_model_deployment/)."
47
-
48
- # Create example list
49
- example_list = [["examples/" + example] for example in os.listdir("examples")]
50
- # Create the Gradio demo
51
- demo =gr.Interface(fn=predict,
52
- inputs=gr.Image(type="pil"),
53
- outputs=[gr.Label(num_top_classes=5, label=("Predicitions")),
54
- gr.Number(label="Prediction Time (s)")],
55
- examples=example_list,
56
- title=title,
57
- description=description,
58
- article=article)
59
- # Launch the demo!
60
- demo.launch()
 
 
 
 
 
 
1
  import gradio as gr
2
  import os
3
  import torch
 
4
  from model import create_effnetb2_model
5
  from timeit import default_timer as timer
6
  from typing import Tuple, Dict
7
 
 
8
  with open("class_names.txt", "r") as f:
9
+ class_names = [food.strip() for food in f.readlines()]
10
+
11
 
12
+ effnetb2, effnetb2_transforms = create_effnetb2_model(num_classes=len(class_names))
 
 
13
 
 
14
  effnetb2.load_state_dict(
15
+ torch.load(
16
+ f="09_pretrained_effnetb2_feature_extractor_food101_20_percent.pth",
17
+ map_location=torch.device("cpu")
18
+ )
19
  )
20
 
 
21
  def predict(img) -> Tuple[Dict, float]:
22
+ start_time = timer()
23
+
24
+ img = effnetb2_transforms(img).unsqueeze(0)
25
+
26
+ effnetb2.eval()
27
+ with torch.inference_mode():
28
+ pred_probs = torch.softmax(effnetb2(img), dim=1)
29
+
30
+ pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
31
+
32
+ pred_time = round(timer() - start_time, 4)
33
+ return pred_labels_and_probs, pred_time
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  title = "FoodVision Big 🍔👁"
36
+ description = "An EfficientNetB2 feature extractor to classify 101 classes of food from the Food101 dataset."
37
+ article = "Created during PyTorch Model Deployment study."
38
+
39
+ example_list = [["examples/" + example] for example in os.listdir("examples") if example.endswith(('.png', '.jpg', '.jpeg'))]
40
+
41
+ demo = gr.Interface(
42
+ fn=predict,
43
+ inputs=gr.Image(type="pil"),
44
+ outputs=[
45
+ gr.Label(num_top_classes=5, label="Predictions"),
46
+ gr.Number(label="Prediction Time (s)")
47
+ ],
48
+ examples=example_list,
49
+ title=title,
50
+ description=description,
51
+ article=article
52
+ )
53
+
54
+ if __name__ == "__main__":
55
+ demo.launch()