Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -26,54 +26,54 @@ class MMIM(nn.Module):
|
|
| 26 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 27 |
model = MMIM(num_classes=36)
|
| 28 |
|
|
|
|
| 29 |
checkpoint = torch.load("MMIM_best.pth", map_location=device)
|
| 30 |
filtered_checkpoint = {
|
| 31 |
-
k: v for k, v in checkpoint.items()
|
| 32 |
-
if k in model.state_dict() and model.state_dict()[k].shape == v.shape
|
| 33 |
}
|
| 34 |
model.load_state_dict(filtered_checkpoint, strict=False)
|
|
|
|
| 35 |
model.to(device)
|
| 36 |
model.eval()
|
| 37 |
|
| 38 |
-
# ✅
|
| 39 |
class_names = [
|
| 40 |
-
"Chinee apple", #
|
| 41 |
-
"
|
| 42 |
-
"Negative", #
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
"
|
| 51 |
-
|
| 52 |
-
"
|
| 53 |
-
"
|
| 54 |
-
"
|
| 55 |
-
"
|
| 56 |
-
"
|
| 57 |
-
"
|
| 58 |
-
"
|
| 59 |
-
"
|
| 60 |
-
"
|
| 61 |
-
"
|
| 62 |
-
"
|
| 63 |
-
"
|
| 64 |
-
"
|
| 65 |
-
"
|
| 66 |
-
"
|
| 67 |
-
"
|
| 68 |
-
"
|
| 69 |
-
"
|
| 70 |
-
"
|
| 71 |
-
"
|
| 72 |
-
"
|
| 73 |
-
"
|
| 74 |
-
"
|
| 75 |
-
|
| 76 |
-
"Extra4" # 35 = class40
|
| 77 |
]
|
| 78 |
|
| 79 |
# 🔁 Image transform
|
|
@@ -109,4 +109,4 @@ interface = gr.Interface(
|
|
| 109 |
description="Upload a weed image to predict its class. If the model detects a non-weed image, it will return 'Negative'."
|
| 110 |
)
|
| 111 |
|
| 112 |
-
interface.launch()
|
|
|
|
| 26 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 27 |
model = MMIM(num_classes=36)
|
| 28 |
|
| 29 |
+
# 🧠 Load only matching weights from checkpoint (skip classifier mismatch)
|
| 30 |
checkpoint = torch.load("MMIM_best.pth", map_location=device)
|
| 31 |
filtered_checkpoint = {
|
| 32 |
+
k: v for k, v in checkpoint.items() if k in model.state_dict() and model.state_dict()[k].shape == v.shape
|
|
|
|
| 33 |
}
|
| 34 |
model.load_state_dict(filtered_checkpoint, strict=False)
|
| 35 |
+
|
| 36 |
model.to(device)
|
| 37 |
model.eval()
|
| 38 |
|
| 39 |
+
# ✅ class_names mapped according to confusion matrix order
|
| 40 |
class_names = [
|
| 41 |
+
"Chinee apple", # class1
|
| 42 |
+
"Black grass", # class2
|
| 43 |
+
"Negative", # class3
|
| 44 |
+
"Rubber vine", # class7
|
| 45 |
+
"Snake weed", # class9
|
| 46 |
+
"Black grass", # class10
|
| 47 |
+
"Charlock", # class11
|
| 48 |
+
"Cleavers", # class12
|
| 49 |
+
"Common Chickweed" # class13
|
| 50 |
+
"Common Wheat", # class14
|
| 51 |
+
"Fat Hen", # class15
|
| 52 |
+
"Loose Silky-bent", # class16
|
| 53 |
+
"Maize", # class17
|
| 54 |
+
"Scentless Mayweed", # class18
|
| 55 |
+
"Shepherds purse", # class19
|
| 56 |
+
"Small-flowered Cranesbill",# class20
|
| 57 |
+
"Sugar beet", # class21
|
| 58 |
+
"Carpetweeds", # class22
|
| 59 |
+
"Crabgrass", # class23
|
| 60 |
+
"Eclipta", # class24
|
| 61 |
+
"Goosegrass", # class25
|
| 62 |
+
"Morningglory", # class26
|
| 63 |
+
"Nutsedge", # class27
|
| 64 |
+
"PalmerAmaranth", # class28
|
| 65 |
+
"Pricky Sida", # class29
|
| 66 |
+
"Purslane", # class30
|
| 67 |
+
"Ragweed", # class31
|
| 68 |
+
"Sicklepod", # class32
|
| 69 |
+
"SpottedSpurge", # class33
|
| 70 |
+
"SpurredAnoda", # class34
|
| 71 |
+
"Swinecress", # class35
|
| 72 |
+
"Waterhemp", # class36
|
| 73 |
+
"Parkinsonia", # class4
|
| 74 |
+
"Parthenium", # class5
|
| 75 |
+
"Prickly acacia", # class6
|
| 76 |
+
"Siam weed", # class8
|
|
|
|
| 77 |
]
|
| 78 |
|
| 79 |
# 🔁 Image transform
|
|
|
|
| 109 |
description="Upload a weed image to predict its class. If the model detects a non-weed image, it will return 'Negative'."
|
| 110 |
)
|
| 111 |
|
| 112 |
+
interface.launch()
|