Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -3,7 +3,9 @@ import torch
|
|
| 3 |
import torch.nn as nn
|
| 4 |
from torchvision import transforms
|
| 5 |
from torchvision.models import swin_t
|
|
|
|
| 6 |
from PIL import Image
|
|
|
|
| 7 |
|
| 8 |
# 🔧 Model definition
|
| 9 |
class MMIM(nn.Module):
|
|
@@ -26,54 +28,66 @@ class MMIM(nn.Module):
|
|
| 26 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 27 |
model = MMIM(num_classes=40)
|
| 28 |
|
| 29 |
-
# 🧠 Load checkpoint (ignore mismatched keys like classifier weights)
|
| 30 |
checkpoint = torch.load("MMIM_best.pth", map_location=device)
|
| 31 |
filtered_checkpoint = {
|
| 32 |
-
k: v for k, v in checkpoint.items()
|
|
|
|
| 33 |
}
|
| 34 |
model.load_state_dict(filtered_checkpoint, strict=False)
|
| 35 |
model.to(device)
|
| 36 |
model.eval()
|
| 37 |
|
| 38 |
-
# ✅
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
"
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
"
|
| 51 |
-
"
|
| 52 |
-
"
|
| 53 |
-
"
|
| 54 |
-
"
|
| 55 |
-
|
| 56 |
-
"
|
| 57 |
-
"
|
| 58 |
-
"
|
| 59 |
-
"
|
| 60 |
-
"
|
| 61 |
-
"
|
| 62 |
-
"
|
| 63 |
-
"
|
| 64 |
-
"
|
| 65 |
-
"
|
| 66 |
-
"
|
| 67 |
-
"
|
| 68 |
-
|
| 69 |
-
"
|
| 70 |
-
"
|
| 71 |
-
"
|
| 72 |
-
"
|
| 73 |
-
"
|
| 74 |
-
"
|
| 75 |
-
"
|
| 76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
|
| 78 |
# 🔁 Image transform
|
| 79 |
transform = transforms.Compose([
|
|
|
|
| 3 |
import torch.nn as nn
|
| 4 |
from torchvision import transforms
|
| 5 |
from torchvision.models import swin_t
|
| 6 |
+
from torchvision.datasets import ImageFolder
|
| 7 |
from PIL import Image
|
| 8 |
+
import os
|
| 9 |
|
| 10 |
# 🔧 Model definition
|
| 11 |
class MMIM(nn.Module):
|
|
|
|
| 28 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 29 |
model = MMIM(num_classes=40)
|
| 30 |
|
|
|
|
| 31 |
checkpoint = torch.load("MMIM_best.pth", map_location=device)
|
| 32 |
filtered_checkpoint = {
|
| 33 |
+
k: v for k, v in checkpoint.items()
|
| 34 |
+
if k in model.state_dict() and model.state_dict()[k].shape == v.shape
|
| 35 |
}
|
| 36 |
model.load_state_dict(filtered_checkpoint, strict=False)
|
| 37 |
model.to(device)
|
| 38 |
model.eval()
|
| 39 |
|
| 40 |
+
# ✅ Load actual class order from test folder
|
| 41 |
+
test_dir = "test" # path to the test folder you used with ImageFolder
|
| 42 |
+
folder_class_map = ImageFolder(test_dir).class_to_idx
|
| 43 |
+
idx_to_folder = {v: k for k, v in folder_class_map.items()}
|
| 44 |
+
|
| 45 |
+
# ✅ Map folder name like 'class15' → weed name
|
| 46 |
+
label_translation = {
|
| 47 |
+
'class1': "Chinee apple",
|
| 48 |
+
'class2': "Lantana",
|
| 49 |
+
'class3': "Negative",
|
| 50 |
+
'class4': "Parkinsonia",
|
| 51 |
+
'class5': "Parthenium",
|
| 52 |
+
'class6': "Prickly acacia",
|
| 53 |
+
'class7': "Rubber vine",
|
| 54 |
+
'class8': "Siam weed",
|
| 55 |
+
'class9': "Snake weed",
|
| 56 |
+
'class10': "Black grass",
|
| 57 |
+
'class11': "Charlock",
|
| 58 |
+
'class12': "Cleavers",
|
| 59 |
+
'class13': "Common Chickweed",
|
| 60 |
+
'class14': "Common Wheat",
|
| 61 |
+
'class15': "Fat Hen",
|
| 62 |
+
'class16': "Loose Silky-bent",
|
| 63 |
+
'class17': "Maize",
|
| 64 |
+
'class18': "Scentless Mayweed",
|
| 65 |
+
'class19': "Shepherds purse",
|
| 66 |
+
'class20': "Small-flowered Cranesbill",
|
| 67 |
+
'class21': "Sugar beet",
|
| 68 |
+
'class22': "Carpetweeds",
|
| 69 |
+
'class23': "Crabgrass",
|
| 70 |
+
'class24': "Eclipta",
|
| 71 |
+
'class25': "Goosegrass",
|
| 72 |
+
'class26': "Morningglory",
|
| 73 |
+
'class27': "Nutsedge",
|
| 74 |
+
'class28': "PalmerAmaranth",
|
| 75 |
+
'class29': "Pricky Sida",
|
| 76 |
+
'class30': "Purslane",
|
| 77 |
+
'class31': "Ragweed",
|
| 78 |
+
'class32': "Sicklepod",
|
| 79 |
+
'class33': "SpottedSpurge",
|
| 80 |
+
'class34': "Spurred Anoda",
|
| 81 |
+
'class35': "Swinecress",
|
| 82 |
+
'class36': "Waterhemp",
|
| 83 |
+
'class37': "Extra1",
|
| 84 |
+
'class38': "Extra2",
|
| 85 |
+
'class39': "Extra3",
|
| 86 |
+
'class40': "Extra4"
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
# ✅ Final class_names list (aligned to model output indices)
|
| 90 |
+
class_names = [label_translation[idx_to_folder[i]] for i in range(len(idx_to_folder))]
|
| 91 |
|
| 92 |
# 🔁 Image transform
|
| 93 |
transform = transforms.Compose([
|