ma4389 commited on
Commit
0e9fcd9
·
verified ·
1 Parent(s): d8ac0a0

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +55 -0
  2. best_model.pth +3 -0
  3. reqruiements.txt +4 -0
app.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import models, transforms
4
+ from PIL import Image
5
+ import gradio as gr
6
+
7
+ # 🖥️ Device (CPU for Gradio unless you have GPU setup)
8
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
+
10
+ # 🔨 Rebuild your model
11
+ resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
12
+ in_features = resnet.fc.in_features
13
+ resnet.fc = nn.Sequential(
14
+ nn.Linear(in_features, 1024),
15
+ nn.ReLU(),
16
+ nn.Dropout(0.5),
17
+ nn.Linear(1024, 3) # 3 classes: dog, wild, cat
18
+ )
19
+ resnet = resnet.to(device)
20
+
21
+ # 📥 Load saved weights
22
+ resnet.load_state_dict(torch.load("model_model.pth", map_location=device))
23
+ resnet.eval()
24
+
25
+ # 🖼️ Validation transforms
26
+ val_transforms = transforms.Compose([
27
+ transforms.Lambda(lambda img: img.convert("RGB")), # 🧠 Force 3-channel
28
+ transforms.Resize((224, 224)),
29
+ transforms.ToTensor(),
30
+ transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
31
+ ])
32
+
33
+ # 🏷️ Class names
34
+ class_names = ["dog", "wild", "cat"]
35
+
36
+ # 🔮 Prediction function
37
+ def classify_image(img):
38
+ img = val_transforms(img).unsqueeze(0).to(device) # Add batch dim & send to device
39
+ with torch.no_grad():
40
+ outputs = resnet(img)
41
+ probs = torch.softmax(outputs, dim=1)
42
+ confidences = probs.squeeze().cpu().tolist()
43
+ predicted_class = class_names[torch.argmax(probs).item()]
44
+ return {class_names[i]: confidences[i] for i in range(len(class_names))}
45
+
46
+ # 🎨 Gradio Interface
47
+ iface = gr.Interface(
48
+ fn=classify_image,
49
+ inputs=gr.Image(type="pil"),
50
+ outputs=gr.Label(num_top_classes=3),
51
+ title="Dog/Wild/Cat Classifier 🐶🐯🐱",
52
+ description="Upload an image to classify it as Dog, Wild Animal, or Cat."
53
+ )
54
+
55
+ iface.launch()
best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:303c454f2b4f2b116458dda641005045de2774df886bdafabd50bd22d6d5b3ec
3
+ size 102756416
reqruiements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch>=2.0.0
2
+ torchvision>=0.15.0
3
+ gradio>=4.0.0
4
+ Pillow>=9.0.0