NagashreePai commited on
Commit
ae95bc5
·
verified ·
1 Parent(s): 8d367db

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -41
app.py CHANGED
@@ -3,7 +3,9 @@ import torch
3
  import torch.nn as nn
4
  from torchvision import transforms
5
  from torchvision.models import swin_t
 
6
  from PIL import Image
 
7
 
8
  # 🔧 Model definition
9
  class MMIM(nn.Module):
@@ -26,54 +28,66 @@ class MMIM(nn.Module):
26
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
  model = MMIM(num_classes=40)
28
 
29
- # 🧠 Load checkpoint (ignore mismatched keys like classifier weights)
30
  checkpoint = torch.load("MMIM_best.pth", map_location=device)
31
  filtered_checkpoint = {
32
- k: v for k, v in checkpoint.items() if k in model.state_dict() and model.state_dict()[k].shape == v.shape
 
33
  }
34
  model.load_state_dict(filtered_checkpoint, strict=False)
35
  model.to(device)
36
  model.eval()
37
 
38
- # ✅ Corrected class_names mapping based on your confusion matrix and folder order
39
- class_names = [
40
- "Chinee apple", # class1
41
- "Lantana", # class2
42
- "Negative", # class3
43
- "Parkinsonia", # class4
44
- "Parthenium", # class5
45
- "Prickly acacia", # class6
46
- "Rubber vine", # class7
47
- "Siam weed", # class8
48
- "Snake weed", # class9
49
- "Black grass", # class14
50
- "Charlock", # class15
51
- "Cleavers", # class16
52
- "Common Chickweed", # class17
53
- "Common Wheat", # class18
54
- "Fat Hen", # class19
55
- "Loose Silky-bent", # class20
56
- "Maize", # class21
57
- "Scentless Mayweed", # class22
58
- "Shepherds purse", # class23
59
- "Small-flowered Cranesbill",# class24
60
- "Sugar beet", # class25
61
- "Carpetweeds", # class26
62
- "Crabgrass", # class27
63
- "Eclipta", # class28
64
- "Goosegrass", # class29
65
- "Morningglory", # class30
66
- "Nutsedge", # class31
67
- "PalmerAmaranth", # class32
68
- "Pricky Sida", # class33
69
- "Purslane", # class34
70
- "Ragweed", # class35
71
- "Sicklepod", # class36
72
- "SpottedSpurge", # class37
73
- "Spurred Anoda", # class38
74
- "Swinecress", # class39
75
- "Waterhemp" # class40
76
- ]
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
  # 🔁 Image transform
79
  transform = transforms.Compose([
 
3
  import torch.nn as nn
4
  from torchvision import transforms
5
  from torchvision.models import swin_t
6
+ from torchvision.datasets import ImageFolder
7
  from PIL import Image
8
+ import os
9
 
10
  # 🔧 Model definition
11
  class MMIM(nn.Module):
 
28
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
  model = MMIM(num_classes=40)
30
 
 
31
  checkpoint = torch.load("MMIM_best.pth", map_location=device)
32
  filtered_checkpoint = {
33
+ k: v for k, v in checkpoint.items()
34
+ if k in model.state_dict() and model.state_dict()[k].shape == v.shape
35
  }
36
  model.load_state_dict(filtered_checkpoint, strict=False)
37
  model.to(device)
38
  model.eval()
39
 
40
+ # ✅ Load actual class order from test folder
41
+ test_dir = "test" # path to the test folder you used with ImageFolder
42
+ folder_class_map = ImageFolder(test_dir).class_to_idx
43
+ idx_to_folder = {v: k for k, v in folder_class_map.items()}
44
+
45
+ # ✅ Map folder name like 'class15' → weed name
46
+ label_translation = {
47
+ 'class1': "Chinee apple",
48
+ 'class2': "Lantana",
49
+ 'class3': "Negative",
50
+ 'class4': "Parkinsonia",
51
+ 'class5': "Parthenium",
52
+ 'class6': "Prickly acacia",
53
+ 'class7': "Rubber vine",
54
+ 'class8': "Siam weed",
55
+ 'class9': "Snake weed",
56
+ 'class10': "Black grass",
57
+ 'class11': "Charlock",
58
+ 'class12': "Cleavers",
59
+ 'class13': "Common Chickweed",
60
+ 'class14': "Common Wheat",
61
+ 'class15': "Fat Hen",
62
+ 'class16': "Loose Silky-bent",
63
+ 'class17': "Maize",
64
+ 'class18': "Scentless Mayweed",
65
+ 'class19': "Shepherds purse",
66
+ 'class20': "Small-flowered Cranesbill",
67
+ 'class21': "Sugar beet",
68
+ 'class22': "Carpetweeds",
69
+ 'class23': "Crabgrass",
70
+ 'class24': "Eclipta",
71
+ 'class25': "Goosegrass",
72
+ 'class26': "Morningglory",
73
+ 'class27': "Nutsedge",
74
+ 'class28': "PalmerAmaranth",
75
+ 'class29': "Pricky Sida",
76
+ 'class30': "Purslane",
77
+ 'class31': "Ragweed",
78
+ 'class32': "Sicklepod",
79
+ 'class33': "SpottedSpurge",
80
+ 'class34': "Spurred Anoda",
81
+ 'class35': "Swinecress",
82
+ 'class36': "Waterhemp",
83
+ 'class37': "Extra1",
84
+ 'class38': "Extra2",
85
+ 'class39': "Extra3",
86
+ 'class40': "Extra4"
87
+ }
88
+
89
+ # ✅ Final class_names list (aligned to model output indices)
90
+ class_names = [label_translation[idx_to_folder[i]] for i in range(len(idx_to_folder))]
91
 
92
  # 🔁 Image transform
93
  transform = transforms.Compose([