NagashreePai commited on
Commit
169d309
·
verified ·
1 Parent(s): ae95bc5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -55
app.py CHANGED
@@ -3,13 +3,11 @@ import torch
3
  import torch.nn as nn
4
  from torchvision import transforms
5
  from torchvision.models import swin_t
6
- from torchvision.datasets import ImageFolder
7
  from PIL import Image
8
- import os
9
 
10
  # 🔧 Model definition
11
  class MMIM(nn.Module):
12
- def __init__(self, num_classes=40):
13
  super(MMIM, self).__init__()
14
  self.backbone = swin_t(weights='IMAGENET1K_V1')
15
  self.backbone.head = nn.Identity()
@@ -26,7 +24,7 @@ class MMIM(nn.Module):
26
 
27
  # ✅ Load model
28
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
- model = MMIM(num_classes=40)
30
 
31
  checkpoint = torch.load("MMIM_best.pth", map_location=device)
32
  filtered_checkpoint = {
@@ -37,57 +35,46 @@ model.load_state_dict(filtered_checkpoint, strict=False)
37
  model.to(device)
38
  model.eval()
39
 
40
- # ✅ Load actual class order from test folder
41
- test_dir = "test" # path to the test folder you used with ImageFolder
42
- folder_class_map = ImageFolder(test_dir).class_to_idx
43
- idx_to_folder = {v: k for k, v in folder_class_map.items()}
44
-
45
- # Map folder name like 'class15' → weed name
46
- label_translation = {
47
- 'class1': "Chinee apple",
48
- 'class2': "Lantana",
49
- 'class3': "Negative",
50
- 'class4': "Parkinsonia",
51
- 'class5': "Parthenium",
52
- 'class6': "Prickly acacia",
53
- 'class7': "Rubber vine",
54
- 'class8': "Siam weed",
55
- 'class9': "Snake weed",
56
- 'class10': "Black grass",
57
- 'class11': "Charlock",
58
- 'class12': "Cleavers",
59
- 'class13': "Common Chickweed",
60
- 'class14': "Common Wheat",
61
- 'class15': "Fat Hen",
62
- 'class16': "Loose Silky-bent",
63
- 'class17': "Maize",
64
- 'class18': "Scentless Mayweed",
65
- 'class19': "Shepherds purse",
66
- 'class20': "Small-flowered Cranesbill",
67
- 'class21': "Sugar beet",
68
- 'class22': "Carpetweeds",
69
- 'class23': "Crabgrass",
70
- 'class24': "Eclipta",
71
- 'class25': "Goosegrass",
72
- 'class26': "Morningglory",
73
- 'class27': "Nutsedge",
74
- 'class28': "PalmerAmaranth",
75
- 'class29': "Pricky Sida",
76
- 'class30': "Purslane",
77
- 'class31': "Ragweed",
78
- 'class32': "Sicklepod",
79
- 'class33': "SpottedSpurge",
80
- 'class34': "Spurred Anoda",
81
- 'class35': "Swinecress",
82
- 'class36': "Waterhemp",
83
- 'class37': "Extra1",
84
- 'class38': "Extra2",
85
- 'class39': "Extra3",
86
- 'class40': "Extra4"
87
- }
88
-
89
- # ✅ Final class_names list (aligned to model output indices)
90
- class_names = [label_translation[idx_to_folder[i]] for i in range(len(idx_to_folder))]
91
 
92
  # 🔁 Image transform
93
  transform = transforms.Compose([
 
3
  import torch.nn as nn
4
  from torchvision import transforms
5
  from torchvision.models import swin_t
 
6
  from PIL import Image
 
7
 
8
  # 🔧 Model definition
9
  class MMIM(nn.Module):
10
+ def __init__(self, num_classes=36):
11
  super(MMIM, self).__init__()
12
  self.backbone = swin_t(weights='IMAGENET1K_V1')
13
  self.backbone.head = nn.Identity()
 
24
 
25
  # ✅ Load model
26
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
+ model = MMIM(num_classes=36)
28
 
29
  checkpoint = torch.load("MMIM_best.pth", map_location=device)
30
  filtered_checkpoint = {
 
35
  model.to(device)
36
  model.eval()
37
 
38
+ # ✅ Final class_names (corrected index order, without class10–class13)
39
+ class_names = [
40
+ "Chinee apple", # 0 = class1
41
+ "Lantana", # 1 = class2
42
+ "Negative", # 2 = class3
43
+ "Parkinsonia", # 3 = class4
44
+ "Parthenium", # 4 = class5
45
+ "Prickly acacia", # 5 = class6
46
+ "Rubber vine", # 6 = class7
47
+ "Siam weed", # 7 = class8
48
+ "Snake weed", # 8 = class9
49
+ # Skipping class10–class13
50
+ "Common Wheat", # 9 = class14
51
+ "Fat Hen", # 10 = class15
52
+ "Loose Silky-bent", # 11 = class16
53
+ "Maize", # 12 = class17
54
+ "Scentless Mayweed", # 13 = class18
55
+ "Shepherds purse", # 14 = class19
56
+ "Small-flowered Cranesbill",# 15 = class20
57
+ "Sugar beet", # 16 = class21
58
+ "Carpetweeds", # 17 = class22
59
+ "Crabgrass", # 18 = class23
60
+ "Eclipta", # 19 = class24
61
+ "Goosegrass", # 20 = class25
62
+ "Morningglory", # 21 = class26
63
+ "Nutsedge", # 22 = class27
64
+ "PalmerAmaranth", # 23 = class28
65
+ "Pricky Sida", # 24 = class29
66
+ "Purslane", # 25 = class30
67
+ "Ragweed", # 26 = class31
68
+ "Sicklepod", # 27 = class32
69
+ "SpottedSpurge", # 28 = class33
70
+ "Spurred Anoda", # 29 = class34
71
+ "Swinecress", # 30 = class35
72
+ "Waterhemp", # 31 = class36
73
+ "Extra1", # 32 = class37 (if any)
74
+ "Extra2", # 33 = class38
75
+ "Extra3", # 34 = class39
76
+ "Extra4" # 35 = class40
77
+ ]
 
 
 
 
 
 
 
 
 
 
 
78
 
79
  # 🔁 Image transform
80
  transform = transforms.Compose([