NagashreePai commited on
Commit
f23bb1e
·
verified ·
1 Parent(s): 169d309

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -41
app.py CHANGED
@@ -26,54 +26,54 @@ class MMIM(nn.Module):
26
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
  model = MMIM(num_classes=36)
28
 
 
29
  checkpoint = torch.load("MMIM_best.pth", map_location=device)
30
  filtered_checkpoint = {
31
- k: v for k, v in checkpoint.items()
32
- if k in model.state_dict() and model.state_dict()[k].shape == v.shape
33
  }
34
  model.load_state_dict(filtered_checkpoint, strict=False)
 
35
  model.to(device)
36
  model.eval()
37
 
38
- # ✅ Final class_names (corrected index order, without class10–class13)
39
  class_names = [
40
- "Chinee apple", # 0 = class1
41
- "Lantana", # 1 = class2
42
- "Negative", # 2 = class3
43
- "Parkinsonia", # 3 = class4
44
- "Parthenium", # 4 = class5
45
- "Prickly acacia", # 5 = class6
46
- "Rubber vine", # 6 = class7
47
- "Siam weed", # 7 = class8
48
- "Snake weed", # 8 = class9
49
- # Skipping class10–class13
50
- "Common Wheat", # 9 = class14
51
- "Fat Hen", # 10 = class15
52
- "Loose Silky-bent", # 11 = class16
53
- "Maize", # 12 = class17
54
- "Scentless Mayweed", # 13 = class18
55
- "Shepherds purse", # 14 = class19
56
- "Small-flowered Cranesbill",# 15 = class20
57
- "Sugar beet", # 16 = class21
58
- "Carpetweeds", # 17 = class22
59
- "Crabgrass", # 18 = class23
60
- "Eclipta", # 19 = class24
61
- "Goosegrass", # 20 = class25
62
- "Morningglory", # 21 = class26
63
- "Nutsedge", # 22 = class27
64
- "PalmerAmaranth", # 23 = class28
65
- "Pricky Sida", # 24 = class29
66
- "Purslane", # 25 = class30
67
- "Ragweed", # 26 = class31
68
- "Sicklepod", # 27 = class32
69
- "SpottedSpurge", # 28 = class33
70
- "Spurred Anoda", # 29 = class34
71
- "Swinecress", # 30 = class35
72
- "Waterhemp", # 31 = class36
73
- "Extra1", # 32 = class37 (if any)
74
- "Extra2", # 33 = class38
75
- "Extra3", # 34 = class39
76
- "Extra4" # 35 = class40
77
  ]
78
 
79
  # 🔁 Image transform
@@ -109,4 +109,4 @@ interface = gr.Interface(
109
  description="Upload a weed image to predict its class. If the model detects a non-weed image, it will return 'Negative'."
110
  )
111
 
112
- interface.launch()
 
26
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
  model = MMIM(num_classes=36)
28
 
29
+ # 🧠 Load only matching weights from checkpoint (skip classifier mismatch)
30
  checkpoint = torch.load("MMIM_best.pth", map_location=device)
31
  filtered_checkpoint = {
32
+ k: v for k, v in checkpoint.items() if k in model.state_dict() and model.state_dict()[k].shape == v.shape
 
33
  }
34
  model.load_state_dict(filtered_checkpoint, strict=False)
35
+
36
  model.to(device)
37
  model.eval()
38
 
39
+ # ✅ class_names mapped according to confusion matrix order
40
  class_names = [
41
+ "Chinee apple", # class1
42
+ "Black grass", # class2
43
+ "Negative", # class3
44
+ "Rubber vine", # class7
45
+ "Snake weed", # class9
46
+ "Black grass", # class10
47
+ "Charlock", # class11
48
+ "Cleavers", # class12
49
+ "Common Chickweed" # class13
50
+ "Common Wheat", # class14
51
+ "Fat Hen", # class15
52
+ "Loose Silky-bent", # class16
53
+ "Maize", # class17
54
+ "Scentless Mayweed", # class18
55
+ "Shepherds purse", # class19
56
+ "Small-flowered Cranesbill",# class20
57
+ "Sugar beet", # class21
58
+ "Carpetweeds", # class22
59
+ "Crabgrass", # class23
60
+ "Eclipta", # class24
61
+ "Goosegrass", # class25
62
+ "Morningglory", # class26
63
+ "Nutsedge", # class27
64
+ "PalmerAmaranth", # class28
65
+ "Pricky Sida", # class29
66
+ "Purslane", # class30
67
+ "Ragweed", # class31
68
+ "Sicklepod", # class32
69
+ "SpottedSpurge", # class33
70
+ "SpurredAnoda", # class34
71
+ "Swinecress", # class35
72
+ "Waterhemp", # class36
73
+ "Parkinsonia", # class4
74
+ "Parthenium", # class5
75
+ "Prickly acacia", # class6
76
+ "Siam weed", # class8
 
77
  ]
78
 
79
  # 🔁 Image transform
 
109
  description="Upload a weed image to predict its class. If the model detects a non-weed image, it will return 'Negative'."
110
  )
111
 
112
+ interface.launch()