NagashreePai commited on
Commit
5957fe9
Β·
verified Β·
1 Parent(s): b0ba491

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -1
app.py CHANGED
@@ -35,7 +35,7 @@ model.eval()
35
  print("[INFO] Model loaded successfully.")
36
 
37
  # βœ… Class names
38
- class_names = ["Broadleaf", "Grass", "Soil", "Soybean"]
39
 
40
  # πŸ” Transform
41
  transform = transforms.Compose([
@@ -73,3 +73,75 @@ interface = gr.Interface(
73
  )
74
 
75
  interface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  print("[INFO] Model loaded successfully.")
36
 
37
  # βœ… Class names
38
+ class_names = []
39
 
40
  # πŸ” Transform
41
  transform = transforms.Compose([
 
73
  )
74
 
75
  interface.launch()
76
+
77
+
78
+ import gradio as gr
79
+ import torch
80
+ import torch.nn as nn
81
+ from torchvision import transforms
82
+ from torchvision.models import swin_t
83
+ from PIL import Image
84
+
85
+ # πŸ”§ Model definition
86
+ class MMIM(nn.Module):
87
+ def __init__(self, num_classes=9):
88
+ super(MMIM, self).__init__()
89
+ self.backbone = swin_t(weights='IMAGENET1K_V1')
90
+ self.backbone.head = nn.Identity()
91
+ self.classifier = nn.Sequential(
92
+ nn.Linear(768, 512),
93
+ nn.ReLU(),
94
+ nn.Dropout(0.3),
95
+ nn.Linear(512, num_classes)
96
+ )
97
+
98
+ def forward(self, x):
99
+ features = self.backbone(x)
100
+ return self.classifier(features)
101
+
102
+ # βœ… Load model
103
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
104
+ model = MMIM(num_classes=4)
105
+ model.load_state_dict(torch.load("MMIM_best3.pth", map_location=device))
106
+ model.to(device)
107
+ model.eval()
108
+
109
+ # βœ… Updated class names (match folder structure)
110
+ class_names = [
111
+ "Broadleaf", "Grass", "Soil", "Soybean"
112
+ ]
113
+
114
+ # πŸ” Image transform
115
+ transform = transforms.Compose([
116
+ transforms.Resize((224, 224)),
117
+ transforms.ToTensor()
118
+ ])
119
+
120
+ # πŸ” Prediction function with negative detection
121
+ def predict(img):
122
+ img = img.convert('RGB')
123
+ img_tensor = transform(img).unsqueeze(0).to(device)
124
+
125
+ with torch.no_grad():
126
+ outputs = model(img_tensor)
127
+ probs = torch.softmax(outputs, dim=1)
128
+ conf, pred = torch.max(probs, 1)
129
+
130
+ predicted_class = class_names[pred.item()]
131
+ confidence = conf.item() * 100
132
+
133
+ if predicted_class.lower() == "negative":
134
+ return f"⚠️ This image is predicted as Negative.\nConfidence: {confidence:.2f}%"
135
+
136
+ return f"βœ… Predicted as a weed with class-{predicted_class}\nConfidence: {confidence:.2f}%"
137
+
138
+ # 🎨 Gradio Interface
139
+ interface = gr.Interface(
140
+ fn=predict,
141
+ inputs=gr.Image(type="pil"),
142
+ outputs="text",
143
+ title="Weed Image Classifier",
144
+ description="Upload a weed image to predict its class. If the model detects a non-weed image, it will return 'Negative'."
145
+ )
146
+
147
+ interface.launch()