Jasur05 commited on
Commit
4757677
·
verified ·
1 Parent(s): 034050c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -7
app.py CHANGED
@@ -53,7 +53,6 @@ class_classes = [
53
  "End of no passing by vehicles over 3.5 metric tons"
54
  ]
55
 
56
- # 2. model and transfomrs prep
57
 
58
  effnetb2, effnetb2_transforms = create_effnetb2_model(43)
59
 
@@ -65,7 +64,7 @@ effnetb2_transforms_new = torchvision.transforms.Compose([
65
 
66
  effnetb2.load_state_dict(torch.load(f="effnetb2_traffic_sign_recognition.pth", map_location=torch.device("cpu")))
67
 
68
- # predict function
69
 
70
  def predict(
71
  img,
@@ -81,17 +80,17 @@ def predict(
81
  """
82
  start = timer()
83
 
84
- # 1. Pre-process
85
  img_t = transform(img).unsqueeze(0)
86
 
87
- # 2. Forward pass
88
  model.eval()
89
  with torch.inference_mode():
90
  logits = model(img_t)
91
- probs = torch.softmax(logits, dim=1).squeeze(0) # shape [43]
92
 
93
  # 3. Top-k
94
- top_probs, top_idxs = probs.topk(k) # tensors of length k
95
  pred_topk = {
96
  class_classes[int(idx)]: float(prob)
97
  for idx, prob in zip(top_idxs, top_probs)
@@ -101,7 +100,7 @@ def predict(
101
  return pred_topk, pred_time
102
 
103
 
104
- # 4. gradio app
105
 
106
  import gradio as gr
107
 
 
53
  "End of no passing by vehicles over 3.5 metric tons"
54
  ]
55
 
 
56
 
57
  effnetb2, effnetb2_transforms = create_effnetb2_model(43)
58
 
 
64
 
65
  effnetb2.load_state_dict(torch.load(f="effnetb2_traffic_sign_recognition.pth", map_location=torch.device("cpu")))
66
 
67
+
68
 
69
  def predict(
70
  img,
 
80
  """
81
  start = timer()
82
 
83
+
84
  img_t = transform(img).unsqueeze(0)
85
 
86
+
87
  model.eval()
88
  with torch.inference_mode():
89
  logits = model(img_t)
90
+ probs = torch.softmax(logits, dim=1).squeeze(0)
91
 
92
  # 3. Top-k
93
+ top_probs, top_idxs = probs.topk(k)
94
  pred_topk = {
95
  class_classes[int(idx)]: float(prob)
96
  for idx, prob in zip(top_idxs, top_probs)
 
100
  return pred_topk, pred_time
101
 
102
 
103
+
104
 
105
  import gradio as gr
106