NagashreePai commited on
Commit
8d367db
·
verified ·
1 Parent(s): 3cc807c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -38
app.py CHANGED
@@ -7,7 +7,7 @@ from PIL import Image
7
 
8
  # 🔧 Model definition
9
  class MMIM(nn.Module):
10
- def __init__(self, num_classes=36):
11
  super(MMIM, self).__init__()
12
  self.backbone = swin_t(weights='IMAGENET1K_V1')
13
  self.backbone.head = nn.Identity()
@@ -24,56 +24,55 @@ class MMIM(nn.Module):
24
 
25
  # ✅ Load model
26
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
- model = MMIM(num_classes=36)
28
 
29
- # 🧠 Load only matching weights from checkpoint (skip classifier mismatch)
30
  checkpoint = torch.load("MMIM_best.pth", map_location=device)
31
  filtered_checkpoint = {
32
  k: v for k, v in checkpoint.items() if k in model.state_dict() and model.state_dict()[k].shape == v.shape
33
  }
34
  model.load_state_dict(filtered_checkpoint, strict=False)
35
-
36
  model.to(device)
37
  model.eval()
38
 
39
- # ✅ class_names mapped according to confusion matrix order
40
  class_names = [
41
  "Chinee apple", # class1
42
  "Lantana", # class2
43
  "Negative", # class3
44
- "Rubber vine", # class7
45
- "Snake weed", # class9
46
- "Black grass", # class10
47
- "Charlock", # class11
48
- "Cleavers", # class12
49
- "Common Chickweed" # class13
50
- "Common Wheat", # class14
51
- "Fat Hen", # class15
52
- "Loose Silky-bent", # class16
53
- "Maize", # class17
54
- "Scentless Mayweed", # class18
55
- "Shepherds purse", # class19
56
- "Small-flowered Cranesbill",# class20
57
- "Sugar beet", # class21
58
- "Carpetweeds", # class22
59
- "Crabgrass", # class23
60
- "Eclipta", # class24
61
- "Goosegrass", # class25
62
- "Morningglory", # class26
63
- "Nutsedge", # class27
64
- "PalmerAmaranth", # class28
65
- "Pricky Sida", # class29
66
- "Purslane", # class30
67
- "Ragweed", # class31
68
- "Sicklepod", # class32
69
- "SpottedSpurge", # class33
70
- "SpurredAnoda", # class34
71
- "Swinecress", # class35
72
- "Waterhemp", # class36
73
- "Parkinsonia", # class4
74
- "Parthenium", # class5
75
- "Prickly acacia", # class6
76
- "Siam weed", # class8
77
  ]
78
 
79
  # 🔁 Image transform
 
7
 
8
  # 🔧 Model definition
9
  class MMIM(nn.Module):
10
+ def __init__(self, num_classes=40):
11
  super(MMIM, self).__init__()
12
  self.backbone = swin_t(weights='IMAGENET1K_V1')
13
  self.backbone.head = nn.Identity()
 
24
 
25
  # ✅ Load model
26
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
+ model = MMIM(num_classes=40)
28
 
29
+ # 🧠 Load checkpoint (ignore mismatched keys like classifier weights)
30
  checkpoint = torch.load("MMIM_best.pth", map_location=device)
31
  filtered_checkpoint = {
32
  k: v for k, v in checkpoint.items() if k in model.state_dict() and model.state_dict()[k].shape == v.shape
33
  }
34
  model.load_state_dict(filtered_checkpoint, strict=False)
 
35
  model.to(device)
36
  model.eval()
37
 
38
+ # ✅ Corrected class_names mapping based on your confusion matrix and folder order
39
  class_names = [
40
  "Chinee apple", # class1
41
  "Lantana", # class2
42
  "Negative", # class3
43
+ "Parkinsonia", # class4
44
+ "Parthenium", # class5
45
+ "Prickly acacia", # class6
46
+ "Rubber vine", # class7
47
+ "Siam weed", # class8
48
+ "Snake weed", # class9
49
+ "Black grass", # class14
50
+ "Charlock", # class15
51
+ "Cleavers", # class16
52
+ "Common Chickweed", # class17
53
+ "Common Wheat", # class18
54
+ "Fat Hen", # class19
55
+ "Loose Silky-bent", # class20
56
+ "Maize", # class21
57
+ "Scentless Mayweed", # class22
58
+ "Shepherds purse", # class23
59
+ "Small-flowered Cranesbill",# class24
60
+ "Sugar beet", # class25
61
+ "Carpetweeds", # class26
62
+ "Crabgrass", # class27
63
+ "Eclipta", # class28
64
+ "Goosegrass", # class29
65
+ "Morningglory", # class30
66
+ "Nutsedge", # class31
67
+ "PalmerAmaranth", # class32
68
+ "Pricky Sida", # class33
69
+ "Purslane", # class34
70
+ "Ragweed", # class35
71
+ "Sicklepod", # class36
72
+ "SpottedSpurge", # class37
73
+ "Spurred Anoda", # class38
74
+ "Swinecress", # class39
75
+ "Waterhemp" # class40
76
  ]
77
 
78
  # 🔁 Image transform