NagashreePai commited on
Commit
3881eaa
Β·
verified Β·
1 Parent(s): 05f27b7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -54
app.py CHANGED
@@ -1,52 +1,21 @@
1
  import torch
2
  import torch.nn as nn
3
- import torch.nn.functional as F
4
- from torchvision.models.swin_transformer import swin_t, Swin_T_Weights
5
  from torchvision import transforms
6
  from PIL import Image
7
  import gradio as gr
 
8
 
9
- # βœ… Class labels (indexed 0 to 24 β†’ class1 to class25)
10
  class_names = [
11
  "Chinee apple", "Lantana", "Negative", "Parkinsonia", "Parthenium", "Prickly acacia",
12
- "Rubber vine", "Siam weed", "Snake weed", # Model 1 (1–9)
13
- "Broadleaf", "Grass", "Soil", "Soybean", # Model 3 (10–13)
14
  "Black grass", "Charlock", "Cleavers", "Common Chickweed", "Common Wheat", "Fat Hen",
15
  "Loose Silky-bent", "Maize", "Scentless Mayweed", "Shepherds purse",
16
- "Small-flowered Cranesbill", "Sugar beet" # Model 2 (14–25)
17
  ]
18
 
19
- # βœ… MMIM Model definition with Swin-T backbone
20
- class MMIM(nn.Module):
21
- def __init__(self, num_classes):
22
- super(MMIM, self).__init__()
23
- self.backbone = swin_t(weights=Swin_T_Weights.DEFAULT)
24
- self.backbone.head = nn.Identity()
25
- self.classifier = nn.Sequential(
26
- nn.Linear(768, 512),
27
- nn.ReLU(),
28
- nn.Dropout(0.3),
29
- nn.Linear(512, num_classes)
30
- )
31
-
32
- def forward(self, x):
33
- features = self.backbone(x)
34
- return self.classifier(features)
35
-
36
- # βœ… Load 3 models trained on subsets of classes
37
- model1 = MMIM(num_classes=9) # class 1–9
38
- model3 = MMIM(num_classes=4) # class 10–13
39
- model2 = MMIM(num_classes=12) # class 14–25
40
-
41
- model1.load_state_dict(torch.load("MMIM_best1.pth", map_location='cpu'))
42
- model3.load_state_dict(torch.load("MMIM_best3.pth", map_location='cpu'))
43
- model2.load_state_dict(torch.load("MMIM_best2.pth", map_location='cpu'))
44
-
45
- model1.eval()
46
- model2.eval()
47
- model3.eval()
48
-
49
- # βœ… Image preprocessing for Swin
50
  transform = transforms.Compose([
51
  transforms.Resize((224, 224)),
52
  transforms.ToTensor(),
@@ -54,34 +23,26 @@ transform = transforms.Compose([
54
  std=[0.5, 0.5, 0.5])
55
  ])
56
 
57
- # βœ… Prediction function
58
- def predict(image):
59
- image_tensor = transform(image).unsqueeze(0) # shape: [1, 3, 224, 224]
60
-
61
- with torch.no_grad():
62
- out1 = F.softmax(model1(image_tensor), dim=1) # [1, 9]
63
- out3 = F.softmax(model3(image_tensor), dim=1) # [1, 4]
64
- out2 = F.softmax(model2(image_tensor), dim=1) # [1, 12]
65
 
66
- # Combine predictions into one 25-class vector
67
- combined = torch.cat([out1, out3, out2], dim=1) # [1, 25]
68
- pred_idx = combined.argmax(dim=1).item()
69
- confidence = combined[0, pred_idx].item()
70
 
71
  if confidence < 0.5:
72
- return "Prediction uncertain – possibly unknown or low confidence"
73
 
74
- return f"🧠 Predicted: **{class_names[pred_idx]}** (Confidence: {confidence:.2f})"
75
 
76
- # βœ… Gradio UI
77
  app = gr.Interface(
78
  fn=predict,
79
  inputs=gr.Image(type="pil", label="Upload a weed image"),
80
  outputs=gr.Markdown(),
81
  title="🌿 Weed Classifier (25 Classes - MMIM)",
82
- description="Upload an image to classify it into 1 of 25 weed classes using 3 separate MMIM models."
83
  )
84
 
85
- # βœ… Launch app
86
  if __name__ == "__main__":
87
  app.launch()
 
1
  import torch
2
  import torch.nn as nn
 
 
3
  from torchvision import transforms
4
  from PIL import Image
5
  import gradio as gr
6
+ from utils import MMIM, load_all_models, predict_image
7
 
8
+ # βœ… All 25 class labels
9
  class_names = [
10
  "Chinee apple", "Lantana", "Negative", "Parkinsonia", "Parthenium", "Prickly acacia",
11
+ "Rubber vine", "Siam weed", "Snake weed", # 1–9 (model1)
12
+ "Broadleaf", "Grass", "Soil", "Soybean", # 10–13 (model3)
13
  "Black grass", "Charlock", "Cleavers", "Common Chickweed", "Common Wheat", "Fat Hen",
14
  "Loose Silky-bent", "Maize", "Scentless Mayweed", "Shepherds purse",
15
+ "Small-flowered Cranesbill", "Sugar beet" # 14–25 (model2)
16
  ]
17
 
18
+ # βœ… Image transform for Swin
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  transform = transforms.Compose([
20
  transforms.Resize((224, 224)),
21
  transforms.ToTensor(),
 
23
  std=[0.5, 0.5, 0.5])
24
  ])
25
 
26
+ # βœ… Load models
27
+ model1, model2, model3 = load_all_models()
 
 
 
 
 
 
28
 
29
+ # βœ… Gradio prediction function
30
+ def predict(image):
31
+ label, confidence = predict_image(image, model1, model2, model3, transform, class_names)
 
32
 
33
  if confidence < 0.5:
34
+ return "⚠️ Prediction uncertain – possibly unknown or low confidence"
35
 
36
+ return f"🧠 Predicted: **{label}** (Confidence: {confidence:.2f})"
37
 
38
+ # βœ… Gradio interface
39
  app = gr.Interface(
40
  fn=predict,
41
  inputs=gr.Image(type="pil", label="Upload a weed image"),
42
  outputs=gr.Markdown(),
43
  title="🌿 Weed Classifier (25 Classes - MMIM)",
44
+ description="Upload an image to classify it into one of 25 weed categories using 3 Swin-MMIM models."
45
  )
46
 
 
47
  if __name__ == "__main__":
48
  app.launch()