NagashreePai commited on
Commit
a174a2c
Β·
verified Β·
1 Parent(s): 0e7c6a0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -8
app.py CHANGED
@@ -7,7 +7,7 @@ from PIL import Image
7
 
8
  # πŸ”§ Model definition
9
  class MMIM(nn.Module):
10
- def __init__(self, num_classes=9):
11
  super(MMIM, self).__init__()
12
  self.backbone = swin_t(weights='IMAGENET1K_V1')
13
  self.backbone.head = nn.Identity()
@@ -24,9 +24,9 @@ class MMIM(nn.Module):
24
 
25
  # βœ… Load model
26
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
- model = MMIM(num_classes=9)
28
 
29
- # 🧠 Load pretrained weights except mismatched classifier
30
  checkpoint = torch.load("MMIM_best.pth", map_location=device)
31
  filtered_checkpoint = {
32
  k: v for k, v in checkpoint.items() if k in model.state_dict() and model.state_dict()[k].shape == v.shape
@@ -36,12 +36,16 @@ model.load_state_dict(filtered_checkpoint, strict=False)
36
  model.to(device)
37
  model.eval()
38
 
39
- # βœ… Updated class names (match folder structure)
40
  class_names = [
41
  "Chinee apple", "Lantana", "Negative", "Parkinsonia", "Parthenium",
42
- "Prickly acacia", "Rubber vine", "Siam weed", "Snake weed", 'Black grass', 'Charlock', 'Cleavers', 'Common Chickweed', 'Common Wheat',
43
- 'Fat Hen', 'Loose Silky-bent', 'Maize', 'Scentless Mayweed',
44
- 'Shepherds purse', 'Small-flowered Cranesbill', 'Sugar beet','Carpetweeds','Crabgrass','Eclipta','Goosegrass','Morningglory','Nutsedge','PalmerAmaranth','Pricky Sida','Purslane','Ragweed','Sicklepod','SpottedSpurge','SpurredAnoda','Swinecress','Waterhemp'
 
 
 
 
45
  ]
46
 
47
  # πŸ” Image transform
@@ -50,7 +54,7 @@ transform = transforms.Compose([
50
  transforms.ToTensor()
51
  ])
52
 
53
- # πŸ” Prediction function with negative detection
54
  def predict(img):
55
  img = img.convert('RGB')
56
  img_tensor = transform(img).unsqueeze(0).to(device)
 
7
 
8
  # πŸ”§ Model definition
9
  class MMIM(nn.Module):
10
+ def __init__(self, num_classes=36):
11
  super(MMIM, self).__init__()
12
  self.backbone = swin_t(weights='IMAGENET1K_V1')
13
  self.backbone.head = nn.Identity()
 
24
 
25
  # βœ… Load model
26
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
+ model = MMIM(num_classes=36)
28
 
29
+ # 🧠 Load only matching weights from checkpoint (skip classifier mismatch)
30
  checkpoint = torch.load("MMIM_best.pth", map_location=device)
31
  filtered_checkpoint = {
32
  k: v for k, v in checkpoint.items() if k in model.state_dict() and model.state_dict()[k].shape == v.shape
 
36
  model.to(device)
37
  model.eval()
38
 
39
+ # βœ… Correct class names list (ordered by folder names)
40
  class_names = [
41
  "Chinee apple", "Lantana", "Negative", "Parkinsonia", "Parthenium",
42
+ "Prickly acacia", "Rubber vine", "Siam weed", "Snake weed",
43
+ "Black grass", "Charlock", "Cleavers", "Common Chickweed", "Common Wheat",
44
+ "Fat Hen", "Loose Silky-bent", "Maize", "Scentless Mayweed",
45
+ "Shepherds purse", "Small-flowered Cranesbill", "Sugar beet",
46
+ "Carpetweeds", "Crabgrass", "Eclipta", "Goosegrass", "Morningglory",
47
+ "Nutsedge", "PalmerAmaranth", "Pricky Sida", "Purslane", "Ragweed",
48
+ "Sicklepod", "SpottedSpurge", "SpurredAnoda", "Swinecress", "Waterhemp"
49
  ]
50
 
51
  # πŸ” Image transform
 
54
  transforms.ToTensor()
55
  ])
56
 
57
+ # πŸ” Prediction function
58
  def predict(img):
59
  img = img.convert('RGB')
60
  img_tensor = transform(img).unsqueeze(0).to(device)