NagashreePai commited on
Commit
7689141
·
verified ·
1 Parent(s): 653b330

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -0
app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import transforms
5
+ from torchvision.models import swin_t
6
+ from PIL import Image
7
+
8
+ # 🔧 Model definition
9
+ class MMIM(nn.Module):
10
+ def __init__(self, num_classes=9):
11
+ super(MMIM, self).__init__()
12
+ self.backbone = swin_t(weights='IMAGENET1K_V1')
13
+ self.backbone.head = nn.Identity()
14
+ self.classifier = nn.Sequential(
15
+ nn.Linear(768, 512),
16
+ nn.ReLU(),
17
+ nn.Dropout(0.3),
18
+ nn.Linear(512, num_classes)
19
+ )
20
+
21
+ def forward(self, x):
22
+ features = self.backbone(x)
23
+ return self.classifier(features)
24
+
25
+ # ✅ Load model
26
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
+ model = MMIM(num_classes=9)
28
+ model.load_state_dict(torch.load("MMIM_best.pth", map_location=device))
29
+ model.to(device)
30
+ model.eval()
31
+
32
+ # ✅ Updated class names (match folder structure)
33
+ class_names = [
34
+ "Chinee apple", "Lantana", "Negative", "Parkinsonia", "Parthenium",
35
+ "Prickly acacia", "Rubber vine", "Siam weed", "Snake weed"
36
+ ]
37
+
38
+ # 🔁 Image transform
39
+ transform = transforms.Compose([
40
+ transforms.Resize((224, 224)),
41
+ transforms.ToTensor()
42
+ ])
43
+
44
+ # 🔍 Prediction function with negative detection
45
+ def predict(img):
46
+ img = img.convert('RGB')
47
+ img_tensor = transform(img).unsqueeze(0).to(device)
48
+
49
+ with torch.no_grad():
50
+ outputs = model(img_tensor)
51
+ probs = torch.softmax(outputs, dim=1)
52
+ conf, pred = torch.max(probs, 1)
53
+
54
+ predicted_class = class_names[pred.item()]
55
+ confidence = conf.item() * 100
56
+
57
+ if predicted_class.lower() == "negative":
58
+ return f"⚠️ This image is predicted as Negative.\nConfidence: {confidence:.2f}%"
59
+
60
+ return f"✅ Predicted as a weed with class-{predicted_class}\nConfidence: {confidence:.2f}%"
61
+
62
+ # 🎨 Gradio Interface
63
+ interface = gr.Interface(
64
+ fn=predict,
65
+ inputs=gr.Image(type="pil"),
66
+ outputs="text",
67
+ title="Weed Image Classifier",
68
+ description="Upload a weed image to predict its class. If the model detects a non-weed image, it will return 'Negative'."
69
+ )
70
+
71
+ interface.launch()