NagashreePai commited on
Commit
ecee537
·
verified ·
1 Parent(s): 1a86e5d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -0
app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import transforms
5
+ from torchvision.models import swin_t
6
+ from PIL import Image
7
+
8
+ # 🔧 Model definition
9
+ class MMIM(nn.Module):
10
+ def __init__(self, num_classes=9):
11
+ super(MMIM, self).__init__()
12
+ self.backbone = swin_t(weights='IMAGENET1K_V1')
13
+ self.backbone.head = nn.Identity()
14
+ self.classifier = nn.Sequential(
15
+ nn.Linear(768, 512),
16
+ nn.ReLU(),
17
+ nn.Dropout(0.3),
18
+ nn.Linear(512, num_classes)
19
+ )
20
+
21
+ def forward(self, x):
22
+ features = self.backbone(x)
23
+ return self.classifier(features)
24
+
25
+ # ✅ Load model
26
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
+ model = MMIM(num_classes=4)
28
+ model.load_state_dict(torch.load("MMIM_best3.pth", map_location=device))
29
+ model.to(device)
30
+ model.eval()
31
+
32
+ # ✅ Updated class names (match folder structure)
33
+ class_names = [
34
+ "Broadleaf", "Grass", "Soil", "Soybean",
35
+ ]
36
+
37
+ # 🔁 Image transform
38
+ transform = transforms.Compose([
39
+ transforms.Resize((224, 224)),
40
+ transforms.ToTensor()
41
+ ])
42
+
43
+ # 🔍 Prediction function with negative detection
44
+ def predict(img):
45
+ img = img.convert('RGB')
46
+ img_tensor = transform(img).unsqueeze(0).to(device)
47
+
48
+ with torch.no_grad():
49
+ outputs = model(img_tensor)
50
+ probs = torch.softmax(outputs, dim=1)
51
+ conf, pred = torch.max(probs, 1)
52
+
53
+ predicted_class = class_names[pred.item()]
54
+ confidence = conf.item() * 100
55
+
56
+ if predicted_class.lower() == "negative":
57
+ return f"⚠️ This image is predicted as Negative.\nConfidence: {confidence:.2f}%"
58
+
59
+ return f"✅ Predicted as a weed with class-{predicted_class}\nConfidence: {confidence:.2f}%"
60
+
61
+ # 🎨 Gradio Interface
62
+ interface = gr.Interface(
63
+ fn=predict,
64
+ inputs=gr.Image(type="pil"),
65
+ outputs="text",
66
+ title="Weed Image Classifier",
67
+ description="Upload a weed image to predict its class. If the model detects a non-weed image, it will return 'Negative'."
68
+ )
69
+
70
+ interface.launch()
71
+