File size: 3,766 Bytes
7689141
 
 
 
 
 
 
 
 
a174a2c
7689141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a174a2c
05f67c3
a174a2c
05f67c3
 
 
 
 
 
7689141
 
 
3c39f03
7689141
3c39f03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7689141
 
 
 
 
 
 
 
a174a2c
7689141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3c39f03
7689141
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import swin_t
from PIL import Image

# ๐Ÿ”ง Model definition
class MMIM(nn.Module):
    """Swin-T backbone with a small MLP classification head.

    The pretrained torchvision Swin-T head is replaced with Identity so the
    backbone emits 768-d features; a 768 -> 512 -> num_classes MLP (ReLU +
    dropout 0.3) maps those features to class logits.
    """

    def __init__(self, num_classes=36):
        super(MMIM, self).__init__()
        # ImageNet-pretrained Swin-T with its original head stripped off.
        self.backbone = swin_t(weights='IMAGENET1K_V1')
        self.backbone.head = nn.Identity()
        # Lightweight classification head over the 768-d backbone features.
        self.classifier = nn.Sequential(
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Extract backbone features for `x` and return raw class logits."""
        feats = self.backbone(x)
        logits = self.classifier(feats)
        return logits

# โœ… Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MMIM(num_classes=36)

# ๐Ÿง  Load only matching weights from checkpoint (skip classifier mismatch)
checkpoint = torch.load("MMIM_best.pth", map_location=device)
# Snapshot the model's state dict ONCE. The original comprehension called
# model.state_dict() twice for every checkpoint key, rebuilding the whole
# (large) state dict per key for no benefit.
model_state = model.state_dict()
filtered_checkpoint = {
    k: v for k, v in checkpoint.items()
    if k in model_state and model_state[k].shape == v.shape
}
# strict=False so keys dropped by the shape filter are silently skipped.
model.load_state_dict(filtered_checkpoint, strict=False)

model.to(device)
model.eval()

# โœ… class_names mapped according to confusion matrix order
# Index i in this list corresponds to logit i of the model's 36-way output
# (see `predict`, which indexes it with the argmax of the softmax).
# NOTE(review): the trailing "# classN" tags are non-sequential by design —
# they record the original dataset folder/class ids in confusion-matrix
# order; do not re-sort this list without retraining or remapping.
# NOTE(review): "Pricky Sida" looks like a typo for "Prickly Sida", but it
# is a runtime label shown to users — confirm before renaming.
class_names = [
    "Chinee apple",              # class1
    "Common Wheat",              # class14
    "Fat Hen",                   # class15
    "Loose Silky-bent",          # class16
    "Maize",                     # class17
    "Scentless Mayweed",         # class18
    "Shepherds purse",           # class19
    "Lantana",                   # class2
    "Small-flowered Cranesbill",# class20
    "Sugar beet",                # class21
    "Carpetweeds",               # class22
    "Crabgrass",                 # class23
    "Eclipta",                   # class24
    "Goosegrass",                # class25
    "Morningglory",              # class26
    "Nutsedge",                  # class27
    "PalmerAmaranth",            # class28
    "Pricky Sida",               # class29
    "Negative",                  # class3
    "Purslane",                  # class30
    "Ragweed",                   # class31
    "Sicklepod",                 # class32
    "SpottedSpurge",             # class33
    "SpurredAnoda",              # class34
    "Swinecress",                # class35
    "Waterhemp",                 # class36
    "Parkinsonia",               # class4
    "Parthenium",                # class5
    "Prickly acacia",            # class6
    "Rubber vine",               # class7
    "Siam weed",                 # class8
    "Snake weed",                # class9
    "Black grass",               # class10
    "Charlock",                  # class11
    "Cleavers",                  # class12
    "Common Chickweed"           # class13
]

# ๐Ÿ” Image transform
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# ๐Ÿ” Prediction function
def predict(img):
    img = img.convert('RGB')
    img_tensor = transform(img).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(img_tensor)
        probs = torch.softmax(outputs, dim=1)
        conf, pred = torch.max(probs, 1)

    predicted_class = class_names[pred.item()]
    confidence = conf.item() * 100

    if predicted_class.lower() == "negative":
        return f"โš ๏ธ This image is predicted as Negative.\nConfidence: {confidence:.2f}%"

    return f"โœ… Predicted class: {predicted_class}\nConfidence: {confidence:.2f}%"

# ๐ŸŽจ Gradio Interface: one image input, one text output, backed by predict().
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Weed Image Classifier",
    description=(
        "Upload a weed image to predict its class. If the model detects "
        "a non-weed image, it will return 'Negative'."
    ),
)

# Start the local Gradio server.
interface.launch()