Jamshid15 commited on
Commit
b912248
·
1 Parent(s): 83940c0

Add changes

Browse files
Files changed (2) hide show
  1. app.py +16 -7
  2. model.py +1 -4
app.py CHANGED
@@ -20,12 +20,21 @@ class_names = ['apple_pie','baby_back_ribs','baklava','beef_carpaccio','beef_tar
20
  effnetb4, effnetb4_transforms = create_effnetb4_model(num_classes=len(class_names))
21
 
22
  # Load the saved weights
23
- effnetb4.load_state_dict(
24
- torch.load(
25
- f="pretrained_efficientnet_b4_model.pth",
26
- map_location=torch.device("cpu") # Load the model to the CPU
27
- )
28
- )
 
 
 
 
 
 
 
 
 
29
 
30
  # Prediciton function
31
  def predict(img) -> Tuple[Dict, float]:
@@ -58,7 +67,7 @@ example_list = [["examples/" + example] for example in os.listdir("examples")]
58
  demo = gr.Interface(fn=predict, # maps inputs to outputs
59
  inputs=gr.Image(type="pil"),
60
  outputs=[gr.Label(num_top_classes=5, label='Predictions'),
61
- gr.Number(label="Prediction time (s}")],
62
  examples=example_list,
63
  title=title,
64
  description=description,
 
20
  effnetb4, effnetb4_transforms = create_effnetb4_model(num_classes=len(class_names))
21
 
22
  # Load the saved weights
23
+ checkpoint = torch.load("pretrained_efficientnet_b4_model.pth", map_location=torch.device("cpu"))
24
+ state_dict = checkpoint
25
+
26
+
27
+ # If the state_dict contains "module." due to DataParallel, we need to remove it
28
+ state_dict = checkpoint
29
+ if "module." in list(state_dict.keys())[0]:
30
+ from collections import OrderedDict
31
+ new_state_dict = OrderedDict()
32
+ for k, v in state_dict.items():
33
+ new_state_dict[k.replace("module.", "")] = v
34
+ state_dict = new_state_dict
35
+
36
+ # Load the state dict into the model
37
+ effnetb4.load_state_dict(state_dict)
38
 
39
  # Prediciton function
40
  def predict(img) -> Tuple[Dict, float]:
 
67
  demo = gr.Interface(fn=predict, # maps inputs to outputs
68
  inputs=gr.Image(type="pil"),
69
  outputs=[gr.Label(num_top_classes=5, label='Predictions'),
70
+ gr.Number(label="Prediction time (s)")],
71
  examples=example_list,
72
  title=title,
73
  description=description,
model.py CHANGED
@@ -8,11 +8,8 @@ def create_effnetb4_model(num_classes:int = 101):
8
  weights = torchvision.models.EfficientNet_B4_Weights.DEFAULT
9
  transforms = weights.transforms()
10
 
11
- model = torchvision.models.efficientnet_b4(pretrained=False)
12
 
13
  model.classifier[1] = nn.Linear(in_features=1792, out_features=num_classes)
14
 
15
- # Use DataParallel to match the saved state dict
16
- model = torch.nn.DataParallel(model)
17
-
18
  return model, transforms
 
8
  weights = torchvision.models.EfficientNet_B4_Weights.DEFAULT
9
  transforms = weights.transforms()
10
 
11
+ model = torchvision.models.efficientnet_b4(weights=weights)
12
 
13
  model.classifier[1] = nn.Linear(in_features=1792, out_features=num_classes)
14
 
 
 
 
15
  return model, transforms