asad9641 commited on
Commit
3898444
·
verified ·
1 Parent(s): a389ede

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -0
app.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ import torchvision.transforms as T
5
+ import torchvision.models as models
6
+ from PIL import Image
7
+ import json
8
+ import os
9
+
10
+ # -----------------------------
11
+ # Load Class Names
12
+ # -----------------------------
13
+ model_path = "model/model.pth"
14
+
15
+ checkpoint = torch.load(model_path, map_location="cpu")
16
+ class_names = checkpoint["class_names"]
17
+
18
+ # -----------------------------
19
+ # Load Model
20
+ # -----------------------------
21
+ model = models.resnet50(pretrained=False)
22
+ model.fc = nn.Linear(model.fc.in_features, len(class_names))
23
+ model.load_state_dict(checkpoint["model_state_dict"], strict=True)
24
+ model.eval()
25
+
26
+ # -----------------------------
27
+ # Image Preprocessing
28
+ # -----------------------------
29
+ transform = T.Compose([
30
+ T.Resize((224,224)),
31
+ T.ToTensor(),
32
+ T.Normalize([0.485, 0.456, 0.406],
33
+ [0.229, 0.224, 0.225])
34
+ ])
35
+
36
+ # -----------------------------
37
+ # Prediction Function
38
+ # -----------------------------
39
+ def predict(img):
40
+ img = transform(img).unsqueeze(0)
41
+ with torch.no_grad():
42
+ outputs = model(img)
43
+ probs = torch.softmax(outputs[0], dim=0)
44
+
45
+ # Return top 3 predictions
46
+ top3_probs, top3_idxs = torch.topk(probs, 3)
47
+ result = {class_names[i]: float(top3_probs[idx])
48
+ for idx, i in enumerate(top3_idxs)}
49
+
50
+ return result
51
+
52
+ # -----------------------------
53
+ # Gradio Interface
54
+ # -----------------------------
55
+ title = "🐾 Animal Classifier — ResNet50 Fine-Tuned"
56
+ description = """
57
+ Upload an image of an animal and the model will predict what species it is.
58
+ """
59
+
60
+ iface = gr.Interface(
61
+ fn=predict,
62
+ inputs=gr.Image(type="pil"),
63
+ outputs=gr.Label(num_top_classes=3),
64
+ title=title,
65
+ description=description,
66
+ )
67
+
68
+ iface.launch()