Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -7,7 +7,7 @@ from PIL import Image
|
|
| 7 |
|
| 8 |
# π§ Model definition
|
| 9 |
class MMIM(nn.Module):
|
| 10 |
-
def __init__(self, num_classes=
|
| 11 |
super(MMIM, self).__init__()
|
| 12 |
self.backbone = swin_t(weights='IMAGENET1K_V1')
|
| 13 |
self.backbone.head = nn.Identity()
|
|
@@ -24,9 +24,9 @@ class MMIM(nn.Module):
|
|
| 24 |
|
| 25 |
# β
Load model
|
| 26 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 27 |
-
model = MMIM(num_classes=
|
| 28 |
|
| 29 |
-
# π§ Load
|
| 30 |
checkpoint = torch.load("MMIM_best.pth", map_location=device)
|
| 31 |
filtered_checkpoint = {
|
| 32 |
k: v for k, v in checkpoint.items() if k in model.state_dict() and model.state_dict()[k].shape == v.shape
|
|
@@ -36,12 +36,16 @@ model.load_state_dict(filtered_checkpoint, strict=False)
|
|
| 36 |
model.to(device)
|
| 37 |
model.eval()
|
| 38 |
|
| 39 |
-
# β
|
| 40 |
class_names = [
|
| 41 |
"Chinee apple", "Lantana", "Negative", "Parkinsonia", "Parthenium",
|
| 42 |
-
"Prickly acacia", "Rubber vine", "Siam weed", "Snake weed",
|
| 43 |
-
|
| 44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
]
|
| 46 |
|
| 47 |
# π Image transform
|
|
@@ -50,7 +54,7 @@ transform = transforms.Compose([
|
|
| 50 |
transforms.ToTensor()
|
| 51 |
])
|
| 52 |
|
| 53 |
-
# π Prediction function
|
| 54 |
def predict(img):
|
| 55 |
img = img.convert('RGB')
|
| 56 |
img_tensor = transform(img).unsqueeze(0).to(device)
|
|
|
|
| 7 |
|
| 8 |
# π§ Model definition
|
| 9 |
class MMIM(nn.Module):
|
| 10 |
+
def __init__(self, num_classes=36):
|
| 11 |
super(MMIM, self).__init__()
|
| 12 |
self.backbone = swin_t(weights='IMAGENET1K_V1')
|
| 13 |
self.backbone.head = nn.Identity()
|
|
|
|
| 24 |
|
| 25 |
# β
Load model
|
| 26 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 27 |
+
model = MMIM(num_classes=36)
|
| 28 |
|
| 29 |
+
# π§ Load only matching weights from checkpoint (skip classifier mismatch)
|
| 30 |
checkpoint = torch.load("MMIM_best.pth", map_location=device)
|
| 31 |
filtered_checkpoint = {
|
| 32 |
k: v for k, v in checkpoint.items() if k in model.state_dict() and model.state_dict()[k].shape == v.shape
|
|
|
|
| 36 |
model.to(device)
|
| 37 |
model.eval()
|
| 38 |
|
| 39 |
+
# β
Correct class names list (ordered by folder names)
|
| 40 |
class_names = [
|
| 41 |
"Chinee apple", "Lantana", "Negative", "Parkinsonia", "Parthenium",
|
| 42 |
+
"Prickly acacia", "Rubber vine", "Siam weed", "Snake weed",
|
| 43 |
+
"Black grass", "Charlock", "Cleavers", "Common Chickweed", "Common Wheat",
|
| 44 |
+
"Fat Hen", "Loose Silky-bent", "Maize", "Scentless Mayweed",
|
| 45 |
+
"Shepherds purse", "Small-flowered Cranesbill", "Sugar beet",
|
| 46 |
+
"Carpetweeds", "Crabgrass", "Eclipta", "Goosegrass", "Morningglory",
|
| 47 |
+
"Nutsedge", "PalmerAmaranth", "Pricky Sida", "Purslane", "Ragweed",
|
| 48 |
+
"Sicklepod", "SpottedSpurge", "SpurredAnoda", "Swinecress", "Waterhemp"
|
| 49 |
]
|
| 50 |
|
| 51 |
# π Image transform
|
|
|
|
| 54 |
transforms.ToTensor()
|
| 55 |
])
|
| 56 |
|
| 57 |
+
# π Prediction function
|
| 58 |
def predict(img):
|
| 59 |
img = img.convert('RGB')
|
| 60 |
img_tensor = transform(img).unsqueeze(0).to(device)
|