sosohrabian commited on
Commit
00550ea
·
verified ·
1 Parent(s): b873d60

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -8
app.py CHANGED
@@ -7,42 +7,44 @@ import gradio as gr
7
 
8
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
9
 
10
- # Load class names (must match training order)
11
  with open("classes.json", "r", encoding="utf-8") as f:
12
  class_names = json.load(f)
13
 
14
- # Build EfficientNet-B0 and replace classifier
15
  model = models.efficientnet_b0(weights=None)
16
  num_ftrs = model.classifier[1].in_features
17
  model.classifier[1] = nn.Linear(num_ftrs, len(class_names))
18
 
19
- # Load trained weights
20
  state_dict = torch.load("best_efficientnetb0.pt", map_location=DEVICE)
21
  model.load_state_dict(state_dict)
22
  model.to(DEVICE)
23
  model.eval()
24
 
25
- # Same preprocessing as your val_transform
26
  val_transform = Compose([
27
  Resize((224, 224)),
28
  ToTensor(),
29
- Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
 
30
  ])
31
 
32
  @torch.no_grad()
33
  def predict(image):
34
- # image comes as PIL
35
  x = val_transform(image).unsqueeze(0).to(DEVICE)
36
  logits = model(x)
37
  probs = torch.softmax(logits, dim=1)[0].cpu().tolist()
38
- return {class_names[i]: float(probs[i]) for i in range(len(class_names))}
 
 
39
 
40
  demo = gr.Interface(
41
  fn=predict,
42
  inputs=gr.Image(type="pil", label="Upload a dog image"),
43
  outputs=gr.Label(num_top_classes=5, label="Prediction"),
44
  title="Dog Breed Classifier (EfficientNet-B0)",
45
- description="Upload a dog photo to get breed prediction + confidence.",
46
  )
47
 
48
  if __name__ == "__main__":
 
7
 
8
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
9
 
10
+ # 1) Load class names from your saved file
11
  with open("classes.json", "r", encoding="utf-8") as f:
12
  class_names = json.load(f)
13
 
14
+ # 2) Build the model architecture (no downloading on the server)
15
  model = models.efficientnet_b0(weights=None)
16
  num_ftrs = model.classifier[1].in_features
17
  model.classifier[1] = nn.Linear(num_ftrs, len(class_names))
18
 
19
+ # 3) Load your trained weights
20
  state_dict = torch.load("best_efficientnetb0.pt", map_location=DEVICE)
21
  model.load_state_dict(state_dict)
22
  model.to(DEVICE)
23
  model.eval()
24
 
25
+ # 4) Same preprocessing as validation/testing
26
  val_transform = Compose([
27
  Resize((224, 224)),
28
  ToTensor(),
29
+ Normalize(mean=[0.485, 0.456, 0.406],
30
+ std=[0.229, 0.224, 0.225]),
31
  ])
32
 
33
  @torch.no_grad()
34
  def predict(image):
 
35
  x = val_transform(image).unsqueeze(0).to(DEVICE)
36
  logits = model(x)
37
  probs = torch.softmax(logits, dim=1)[0].cpu().tolist()
38
+ # show top-5 nicely
39
+ conf = {class_names[i]: float(probs[i]) for i in range(len(class_names))}
40
+ return conf
41
 
42
  demo = gr.Interface(
43
  fn=predict,
44
  inputs=gr.Image(type="pil", label="Upload a dog image"),
45
  outputs=gr.Label(num_top_classes=5, label="Prediction"),
46
  title="Dog Breed Classifier (EfficientNet-B0)",
47
+ description="Upload a dog photo and the model predicts the breed with confidence."
48
  )
49
 
50
  if __name__ == "__main__":