NagashreePai committed on
Commit 84b609f · verified · 1 Parent(s): 3881eaa

Update utils.py

Files changed (1)
  1. utils.py +54 -86
utils.py CHANGED
@@ -82,93 +82,61 @@
  # top3 = sorted(probs_combined.items(), key=lambda x: x[1], reverse=True)[:3]
  # return {idx_to_class[k]: float(f"{v:.4f}") for k, v in top3}
 
-
-
  import torch
- import torch.nn.functional as F
- from torchvision import transforms
- from PIL import Image
- import gradio as gr
  import torch.nn as nn
- from torchvision.models import resnet18  # Example; change to your actual architecture
-
- # ✅ Define your 25 class names (index 0 → class 1)
- class_names = [
-     "Chinee apple", "Lantana", "Negative", "Parkinsonia", "Parthenium", "Prickly acacia",
-     "Rubber vine", "Siam weed", "Snake weed",  # 1-9 (Model 1)
-     "Broadleaf",                   # class 10 (Model 3)
-     "Grass",                       # class 11
-     "Soil",                        # class 12
-     "Soybean",                     # class 13
-     "Black grass",                 # class 14 (Model 2)
-     "Charlock",                    # class 15
-     "Cleavers",                    # class 16
-     "Common Chickweed",            # class 17
-     "Common Wheat",                # class 18
-     "Fat Hen",                     # class 19
-     "Loose Silky-bent",            # class 20
-     "Maize",                       # class 21
-     "Scentless Mayweed",           # class 22
-     "Shepherds purse",             # class 23
-     "Small-flowered Cranesbill",   # class 24
-     "Sugar beet"                   # 14-25 (Model 2)
- ]
-
- # ✅ Define transforms (adjust to match your model training)
- transform = transforms.Compose([
-     transforms.Resize((224, 224)),
-     transforms.ToTensor(),
-     transforms.Normalize(mean=[0.5, 0.5, 0.5],
-                          std=[0.5, 0.5, 0.5])
- ])
-
- # ✅ Load your models (use correct architecture and weights)
- def get_model(num_classes):
-     model = resnet18(pretrained=False)
-     model.fc = nn.Linear(model.fc.in_features, num_classes)
-     return model
-
- model1 = get_model(9)   # for class 1-9
- model3 = get_model(4)   # for class 10-13
- model2 = get_model(12)  # for class 14-25
-
- model1.load_state_dict(torch.load("MMIM_best1.pth", map_location='cpu'))
- model2.load_state_dict(torch.load("MMIM_best2.pth", map_location='cpu'))
- model3.load_state_dict(torch.load("MMIM_best3.pth", map_location='cpu'))
-
- model1.eval()
- model2.eval()
- model3.eval()
-
- # ✅ Inference function
- def predict(image):
-     image_tensor = transform(image).unsqueeze(0)
+ from torchvision.models.swin_transformer import swin_t, Swin_T_Weights
+ import torch.nn.functional as F
+
+ # ✅ Define MMIM architecture (same as used during training)
+ class MMIM(nn.Module):
+     def __init__(self, num_classes):
+         super(MMIM, self).__init__()
+         self.backbone = swin_t(weights=Swin_T_Weights.DEFAULT)
+         self.backbone.head = nn.Identity()
+         self.classifier = nn.Sequential(
+             nn.Linear(768, 512),
+             nn.ReLU(),
+             nn.Dropout(0.3),
+             nn.Linear(512, num_classes)
+         )
+
+     def forward(self, x):
+         features = self.backbone(x)
+         return self.classifier(features)
+
+ # ✅ Load all 3 models
+ def load_all_models():
+     model1 = MMIM(num_classes=9)   # class 1–9
+     model2 = MMIM(num_classes=12)  # class 14–25
+     model3 = MMIM(num_classes=4)   # class 10–13
+
+     model1.load_state_dict(torch.load("MMIM_best1.pth", map_location='cpu'))
+     model2.load_state_dict(torch.load("MMIM_best2.pth", map_location='cpu'))
+     model3.load_state_dict(torch.load("MMIM_best3.pth", map_location='cpu'))
+
+     model1.eval()
+     model2.eval()
+     model3.eval()
+
+     return model1, model2, model3
+
+ # ✅ Inference combining raw logits before softmax
+ def predict_image(image, model1, model2, model3, transform, class_names):
+     image_tensor = transform(image).unsqueeze(0)  # [1, 3, 224, 224]
 
      with torch.no_grad():
-         out1 = F.softmax(model1(image_tensor), dim=1)  # [1, 9]
-         out3 = F.softmax(model3(image_tensor), dim=1)  # [1, 4]
-         out2 = F.softmax(model2(image_tensor), dim=1)  # [1, 12]
-
-         # Combine into a 25-class vector
-         combined = torch.cat([out1, out3, out2], dim=1)  # shape: [1, 25]
-         pred_idx = combined.argmax(dim=1).item()
-         confidence = combined.max().item()
-
-         # Optional rejection
-         if confidence < 0.5:
-             return "Prediction uncertain or unknown class"
-
-         return f"Predicted: {class_names[pred_idx]} (Confidence: {confidence:.2f})"
-
- # ✅ Gradio app
- app = gr.Interface(
-     fn=predict,
-     inputs=gr.Image(type="pil"),
-     outputs="text",
-     title="Weed Classifier - 25 Class Combined (3 Models)",
-     description="Upload an image to classify weeds across 25 species using 3 separate models."
- )
-
- # ✅ Launch
- if __name__ == "__main__":
-     app.launch()
+         logit1 = model1(image_tensor)  # [1, 9]
+         logit3 = model3(image_tensor)  # [1, 4]
+         logit2 = model2(image_tensor)  # [1, 12]
+
+         # ✅ Combine logits (not softmax) → then apply softmax
+         combined_logits = torch.cat([logit1, logit3, logit2], dim=1)  # [1, 25]
+         combined_probs = F.softmax(combined_logits, dim=1)  # unified softmax
+
+         pred_idx = combined_probs.argmax(dim=1).item()
+         confidence = combined_probs[0, pred_idx].item()
+
+     return class_names[pred_idx], confidence
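The substantive change is in how the three heads are fused. The old version applied softmax per model and concatenated the results, so the 25 entries summed to 3 and confidences from different heads were not directly comparable; the new version concatenates raw logits and applies one softmax, producing a single distribution over all 25 classes. A minimal sketch of the difference, with toy logit values invented purely for illustration:

import torch
import torch.nn.functional as F

# Toy logits standing in for two heads (a 2-class and a 3-class model).
logits_a = torch.tensor([[2.0, 0.0]])
logits_b = torch.tensor([[1.0, 0.5, -1.0]])

# Old fusion: softmax per head, then concatenate.
# Each head sums to 1, so the combined vector sums to 2, not 1.
per_head = torch.cat([F.softmax(logits_a, dim=1),
                      F.softmax(logits_b, dim=1)], dim=1)

# New fusion: concatenate raw logits, then one softmax over everything.
unified = F.softmax(torch.cat([logits_a, logits_b], dim=1), dim=1)

print(per_head.sum().item())  # 2.0 -- not a single probability distribution
print(unified.sum().item())   # 1.0 -- one distribution over all 5 classes

One caveat: a single softmax over concatenated logits is only well behaved if the independently trained heads produce logits on comparable scales, which is an assumption here rather than a guarantee.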
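On the architecture side, torchvision's swin_t ends in a 768-feature classification head, so replacing backbone.head with nn.Identity() makes the backbone emit 768-dim embeddings, which matches the nn.Linear(768, 512) at the top of the classifier. A quick sanity-check sketch, instantiating MMIM as defined above (Swin_T_Weights.DEFAULT downloads pretrained weights on first use):

import torch

m = MMIM(num_classes=9)
x = torch.randn(1, 3, 224, 224)  # one dummy RGB image at the expected input size
with torch.no_grad():
    print(m(x).shape)            # torch.Size([1, 9])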
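predict_image now takes transform and class_names as parameters instead of reading module globals, and the Gradio wiring from the previous version is not part of this hunk. A sketch of how the new helpers could be hooked back up, assuming the same preprocessing transform and 25-entry class_names list as before are still defined earlier in utils.py; the gradio_predict wrapper is hypothetical, and the 0.5 rejection threshold is carried over from the old predict() for illustration:

import gradio as gr

# Assumed: transform and class_names are defined above, with names ordered
# to match the logit concatenation (classes 1-9, then 10-13, then 14-25).
model1, model2, model3 = load_all_models()

def gradio_predict(image):
    # Hypothetical wrapper; reuses the 0.5 rejection threshold from the
    # previous predict().
    label, confidence = predict_image(image, model1, model2, model3,
                                      transform, class_names)
    if confidence < 0.5:
        return "Prediction uncertain or unknown class"
    return f"Predicted: {label} (Confidence: {confidence:.2f})"

app = gr.Interface(
    fn=gradio_predict,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Weed Classifier - 25 Class Combined (3 Models)",
    description="Upload an image to classify weeds across 25 species using 3 separate models.",
)

if __name__ == "__main__":
    app.launch()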