NagashreePai committed on
Commit
4b4e2d5
·
verified ·
1 Parent(s): 0464d2a

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +170 -79
utils.py CHANGED
@@ -1,83 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
- import torch.nn as nn
3
  import torch.nn.functional as F
4
- from torchvision.models import swin_t
5
  from torchvision import transforms
6
  from PIL import Image
7
- import os
8
-
9
- # --- MMIM model class ---
10
- class MMIM(nn.Module):
11
- def __init__(self, num_classes):
12
- super(MMIM, self).__init__()
13
- self.backbone = swin_t(weights='IMAGENET1K_V1')
14
- self.backbone.head = nn.Identity()
15
- self.classifier = nn.Sequential(
16
- nn.Linear(768, 512),
17
- nn.ReLU(),
18
- nn.Dropout(0.3),
19
- nn.Linear(512, num_classes)
20
- )
21
-
22
- def forward(self, x):
23
- x = self.backbone(x)
24
- return self.classifier(x)
25
-
26
- # --- Load models with offsets ---
27
- def load_all_models():
28
- model_defs = [
29
- ("MMIM_best1.pth", 9),
30
- ("MMIM_best3.pth", 4),
31
- ("MMIM_best2.pth", 12)
32
- ]
33
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
34
-
35
- models = []
36
- offsets = []
37
- total_classes = 0
38
- for path, num_classes in model_defs:
39
- model = MMIM(num_classes)
40
- state_dict = torch.load(path, map_location=device)
41
- model.load_state_dict(state_dict)
42
- model.to(device)
43
- model.eval()
44
- models.append(model)
45
- offsets.append(total_classes)
46
- total_classes += num_classes
47
-
48
- # Generate dummy class names like class0, class1, ...
49
- idx_to_class = {i: f"class{i}" for i in range(total_classes)}
50
- return models, offsets, idx_to_class
51
-
52
- # --- Inference on one image ---
53
- def predict_image(image, models, offsets, idx_to_class):
54
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
55
- transform = transforms.Compose([
56
- transforms.Resize((224, 224)),
57
- transforms.ToTensor(),
58
- transforms.Normalize([0.5]*3, [0.5]*3)
59
- ])
60
- image_tensor = transform(image).unsqueeze(0).to(device)
61
-
62
- temperatures = [1.2, 1.0, 0.8] # Adjust for balancing confidence
63
-
64
- max_score = float('-inf')
65
- final_pred = -1
66
- probs_combined = {}
67
-
68
- for model, offset, temp in zip(models, offsets, temperatures):
69
- with torch.no_grad():
70
- logits = model(image_tensor) / temp
71
- probs = F.softmax(logits, dim=1).squeeze(0)
72
- top_score, top_class = torch.max(probs, dim=0)
73
- if top_score.item() > max_score:
74
- max_score = top_score.item()
75
- final_pred = top_class.item() + offset
76
-
77
- # Also collect probabilities for all classes
78
- for i, p in enumerate(probs):
79
- probs_combined[offset + i] = p.item()
80
-
81
- # Sort top 3
82
- top3 = sorted(probs_combined.items(), key=lambda x: x[1], reverse=True)[:3]
83
- return {idx_to_class[k]: float(f"{v:.4f}") for k, v in top3}
 
 
 
 
 
 
 
 
1
+ # import torch
2
+ # import torch.nn as nn
3
+ # import torch.nn.functional as F
4
+ # from torchvision.models import swin_t
5
+ # from torchvision import transforms
6
+ # from PIL import Image
7
+ # import os
8
+
9
+ # # --- MMIM model class ---
10
+ # class MMIM(nn.Module):
11
+ # def __init__(self, num_classes):
12
+ # super(MMIM, self).__init__()
13
+ # self.backbone = swin_t(weights='IMAGENET1K_V1')
14
+ # self.backbone.head = nn.Identity()
15
+ # self.classifier = nn.Sequential(
16
+ # nn.Linear(768, 512),
17
+ # nn.ReLU(),
18
+ # nn.Dropout(0.3),
19
+ # nn.Linear(512, num_classes)
20
+ # )
21
+
22
+ # def forward(self, x):
23
+ # x = self.backbone(x)
24
+ # return self.classifier(x)
25
+
26
+ # # --- Load models with offsets ---
27
+ # def load_all_models():
28
+ # model_defs = [
29
+ # ("MMIM_best1.pth", 9),
30
+ # ("MMIM_best3.pth", 4),
31
+ # ("MMIM_best2.pth", 12)
32
+ # ]
33
+ # device = 'cuda' if torch.cuda.is_available() else 'cpu'
34
+
35
+ # models = []
36
+ # offsets = []
37
+ # total_classes = 0
38
+ # for path, num_classes in model_defs:
39
+ # model = MMIM(num_classes)
40
+ # state_dict = torch.load(path, map_location=device)
41
+ # model.load_state_dict(state_dict)
42
+ # model.to(device)
43
+ # model.eval()
44
+ # models.append(model)
45
+ # offsets.append(total_classes)
46
+ # total_classes += num_classes
47
+
48
+ # # Generate dummy class names like class0, class1, ...
49
+ # idx_to_class = {i: f"class{i}" for i in range(total_classes)}
50
+ # return models, offsets, idx_to_class
51
+
52
+ # # --- Inference on one image ---
53
+ # def predict_image(image, models, offsets, idx_to_class):
54
+ # device = 'cuda' if torch.cuda.is_available() else 'cpu'
55
+ # transform = transforms.Compose([
56
+ # transforms.Resize((224, 224)),
57
+ # transforms.ToTensor(),
58
+ # transforms.Normalize([0.5]*3, [0.5]*3)
59
+ # ])
60
+ # image_tensor = transform(image).unsqueeze(0).to(device)
61
+
62
+ # temperatures = [1.2, 1.0, 0.8] # Adjust for balancing confidence
63
+
64
+ # max_score = float('-inf')
65
+ # final_pred = -1
66
+ # probs_combined = {}
67
+
68
+ # for model, offset, temp in zip(models, offsets, temperatures):
69
+ # with torch.no_grad():
70
+ # logits = model(image_tensor) / temp
71
+ # probs = F.softmax(logits, dim=1).squeeze(0)
72
+ # top_score, top_class = torch.max(probs, dim=0)
73
+ # if top_score.item() > max_score:
74
+ # max_score = top_score.item()
75
+ # final_pred = top_class.item() + offset
76
+
77
+ # # Also collect probabilities for all classes
78
+ # for i, p in enumerate(probs):
79
+ # probs_combined[offset + i] = p.item()
80
+
81
+ # # Sort top 3
82
+ # top3 = sorted(probs_combined.items(), key=lambda x: x[1], reverse=True)[:3]
83
+ # return {idx_to_class[k]: float(f"{v:.4f}") for k, v in top3}
84
+
85
+
86
+
87
  import torch
 
88
  import torch.nn.functional as F
 
89
  from torchvision import transforms
90
  from PIL import Image
91
+ import gradio as gr
92
+ import torch.nn as nn
93
+ from torchvision.models import resnet18 # Example; change to your actual architecture
94
+
95
# ✅ Combined 25-class label space (index 0 → class 1).
# Order must match the concatenation in predict(): model1 (9 classes),
# then model3 (4 classes), then model2 (12 classes).
# Fixed garbled first entry: "Capplehinee " → "Chinee apple" (DeepWeeds class name).
class_names = [
    "Chinee apple", "Lantana", "Negative", "Parkinsonia", "Parthenium", "Prickly acacia",
    "Rubber vine", "Siam weed", "Snake weed",  # 1-9 (Model 1)
    "Broadleaf",                   # class10 (Model 3)
    "Grass",                       # class11
    "Soil",                        # class12
    "Soybean",                     # class13
    "Black grass",                 # class14 (Model 2)
    "Charlock",                    # class15
    "Cleavers",                    # class16
    "Common Chickweed",            # class17
    "Common Wheat",                # class18
    "Fat Hen",                     # class19
    "Loose Silky-bent",            # class20
    "Maize",                       # class21
    "Scentless Mayweed",           # class22
    "Shepherds purse",             # class23
    "Small-flowered Cranesbill",   # class24
    "Sugar beet"                   # 14-25 (Model 2)
]
116
+
117
# Preprocessing pipeline applied to every uploaded image; should mirror the
# transforms used at training time (224x224, tensors scaled to [-1, 1]).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])
124
+
125
# ✅ Build one classifier head (shared architecture for all three models).
def get_model(num_classes):
    """Return a ResNet-18 whose final FC layer outputs `num_classes` logits.

    The backbone is created without ImageNet weights because the real
    parameters are loaded from the MMIM_best*.pth checkpoints right after.

    NOTE(review): the checkpoints are named MMIM_best*.pth; if they were
    saved from a different architecture than ResNet-18 (the commented-out
    code above used a Swin-T based MMIM model), load_state_dict will raise
    on mismatched keys -- confirm this matches how the weights were trained.
    """
    # `weights=None` replaces the deprecated `pretrained=False` flag
    # (deprecated since torchvision 0.13).
    model = resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model
130
+
131
# Instantiate one head per checkpoint; the class-index ranges must line up
# with the ordering of class_names and the concatenation order in predict().
model1 = get_model(9)   # for class 1-9
model3 = get_model(4)   # for class 10-13
model2 = get_model(12)  # for class 14-25

# Load checkpoint weights from the working directory, forced onto CPU.
# NOTE(review): torch.load without weights_only=True unpickles arbitrary
# objects -- only ship trusted checkpoints with this app.
# NOTE(review): per the commented-out code above, MMIM_best*.pth appear to
# come from a Swin-T based MMIM model; loading them into ResNet-18 heads
# would raise on mismatched keys -- confirm checkpoints match get_model().
model1.load_state_dict(torch.load("MMIM_best1.pth", map_location='cpu'))
model2.load_state_dict(torch.load("MMIM_best2.pth", map_location='cpu'))
model3.load_state_dict(torch.load("MMIM_best3.pth", map_location='cpu'))

# Inference mode: disables dropout and freezes batch-norm statistics.
model1.eval()
model2.eval()
model3.eval()
142
+
143
# ✅ Inference function
def predict(image):
    """Classify a PIL image over the combined 25-class label space.

    Runs the image through all three models and concatenates their softmax
    outputs (model1 -> classes 1-9, model3 -> 10-13, model2 -> 14-25),
    then reports the overall argmax as a human-readable string.

    NOTE(review): each per-model softmax sums to 1, so the concatenated
    vector sums to 3 -- `confidence` is the winning model's own softmax
    score, not a joint probability over all 25 classes; the 0.5 rejection
    threshold should be interpreted accordingly.
    """
    batch = transform(image).unsqueeze(0)

    with torch.no_grad():
        # Concatenate in the same order as class_names: [1, 9+4+12] = [1, 25]
        combined = torch.cat(
            [F.softmax(net(batch), dim=1) for net in (model1, model3, model2)],
            dim=1,
        )

    pred_idx = combined.argmax(dim=1).item()
    confidence = combined.max().item()

    # Optional rejection
    if confidence < 0.5:
        return "Prediction uncertain or unknown class"

    return f"Predicted: {class_names[pred_idx]} (Confidence: {confidence:.2f})"
162
+
163
# Gradio app
# Wire the classifier into a minimal web UI: one image upload in, a
# plain-text verdict (or rejection message) out.
app = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),  # deliver the upload as a PIL.Image, as transform expects
    outputs="text",
    title="Weed Classifier - 25 Class Combined (3 Models)",
    description="Upload an image to classify weeds across 25 species using 3 separate models."
)

# ✅ Launch
# Only start the server when executed as a script, not when imported.
if __name__ == "__main__":
    app.launch()