File size: 3,712 Bytes
7689141
 
 
 
 
 
 
 
 
169d309
7689141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169d309
05f67c3
f23bb1e
05f67c3
 
f23bb1e
05f67c3
 
f23bb1e
7689141
 
 
f23bb1e
169d309
f23bb1e
ce22489
 
 
 
 
 
 
82b7788
ce22489
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169d309
7689141
 
 
 
 
 
 
a174a2c
7689141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3c39f03
7689141
 
 
 
 
 
 
8200b06
7689141
 
f23bb1e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import swin_t
from PIL import Image

# Model definition
class MMIM(nn.Module):
    """Swin-T backbone feeding a two-layer MLP classification head.

    Args:
        num_classes: Output width of the final linear layer (default 36).
    """

    def __init__(self, num_classes=36):
        super().__init__()

        # Build the Swin-T with the 'IMAGENET1K_V1' weights and swap its own
        # head for a pass-through, so the backbone output goes straight into
        # the classifier below (which expects 768 input features).
        swin = swin_t(weights='IMAGENET1K_V1')
        swin.head = nn.Identity()
        self.backbone = swin

        # Layer order is kept as-is so the state-dict keys
        # (classifier.0.*, classifier.3.*) stay the same.
        head_layers = [
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the classifier output for the input batch ``x``."""
        return self.classifier(self.backbone(x))

# Load model
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MMIM(num_classes=36)

# Load only the checkpoint entries whose name and shape match the model
# (e.g. a classifier head of a different size is skipped).
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
checkpoint = torch.load("MMIM_best.pth", map_location=device)
model_state = model.state_dict()  # built once instead of twice per checkpoint key
filtered_checkpoint = {
    k: v
    for k, v in checkpoint.items()
    if k in model_state and model_state[k].shape == v.shape
}

# Report what was dropped instead of skipping it silently: any parameter that
# is not loaded keeps its initial value, which changes the predictions.
skipped_keys = [k for k in checkpoint if k not in filtered_checkpoint]
if skipped_keys:
    print(f"Warning: skipped {len(skipped_keys)} checkpoint entries "
          f"(name or shape mismatch): {skipped_keys}")

load_result = model.load_state_dict(filtered_checkpoint, strict=False)
if load_result.missing_keys:
    print(f"Warning: {len(load_result.missing_keys)} model parameters were not "
          f"loaded from the checkpoint: {list(load_result.missing_keys)}")

model.to(device)
model.eval()

# class_names mapped according to confusion matrix order
# Display label for each model output index (position i <-> output index i).
# The trailing comments give the source class label of each entry; the entries
# follow the lexicographic order of those labels (class1, class14, ..., class2,
# class20, ...), so the order must not be changed. Labels class10-class13 do
# not appear in this list.
class_names = [
    "Chinee apple",               # class1
    "Black grass",                # class14
    "Charlock",                   # class15
    "Cleavers",                   # class16
    "Common Chickweed",           # class17
    "Common Wheat",               # class18
    "Fat Hen",                    # class19
    "Lantana",                    # class2
    "Loose Silky bent",           # class20
    "Maize",                      # class21
    "Scentless Mayweed",          # class22
    "Shepherds Purse",            # class23
    "Small-Flowered Cranesbill",  # class24
    "Sugar beet",                 # class25
    "Carpetweeds",                # class26
    "Crabgrass",                  # class27
    "Eclipta",                    # class28
    "Goosegrass",                 # class29
    "Negative",                   # class3
    "Morningglory",               # class30
    "Nutsedge",                   # class31
    "Palmer Amaranth",            # class32
    "Prickly Sida",               # class33
    "Purslane",                   # class34
    "Ragweed",                    # class35
    "Sicklepod",                  # class36
    "SpottedSpurge",              # class37
    "SpurredAnoda",               # class38
    "Swinecress",                 # class39
    "Parkinsonia",                # class4
    "Waterhemp",                  # class40
    "Parthenium",                 # class5
    "Prickly acacia",             # class6
    "Rubber vine",                # class7
    "Siam weed",                  # class8
    "Snake weed",                 # class9
]

# Image transform
# Preprocessing applied to every uploaded image: resize to 224x224, then
# convert the PIL image to a tensor.
# NOTE(review): there is no Normalize step here -- confirm this matches the
# preprocessing the checkpoint was trained with.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)

# Prediction function
def predict(img):
    """Classify an uploaded image and format the result as text.

    Args:
        img: PIL image from the Gradio image input, or ``None`` when the
            form is submitted without an image.

    Returns:
        A two-line message with the predicted class name and the softmax
        confidence in percent, a dedicated message when the predicted class
        is "Negative", or a prompt to upload an image when ``img`` is
        ``None``.
    """
    # Gradio passes None when nothing was uploaded; calling .convert() on it
    # would raise AttributeError.
    if img is None:
        return "Please upload an image first."

    # Force three channels (uploads may be grayscale or RGBA), then add the
    # batch dimension the model expects.
    img_tensor = transform(img.convert('RGB')).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(img_tensor)
        probs = torch.softmax(outputs, dim=1)
        conf, pred = torch.max(probs, 1)

    predicted_class = class_names[pred.item()]
    confidence = conf.item() * 100

    # The emoji are written as Unicode escapes so the messages survive any
    # re-encoding of this source file (they were previously garbled).
    if predicted_class.lower() == "negative":
        return (
            "\u26a0\ufe0f This image is predicted as Negative.\n"
            f"Confidence: {confidence:.2f}%"
        )

    return f"\u2705 Predicted class: {predicted_class}\nConfidence: {confidence:.2f}%"

# Gradio Interface
# Web UI: a single PIL image input, routed through predict(), with the
# result shown as plain text.
interface = gr.Interface(
    predict,
    gr.Image(type="pil"),
    "text",
    title="Weed Image Classifier",
    description="Upload a weed image to predict its class.",
)

# Start the Gradio server as soon as the script runs.
interface.launch()