File size: 6,713 Bytes
e2fc7b7
 
 
 
 
 
 
8a3a43a
e2fc7b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8a3a43a
e2fc7b7
 
 
 
 
 
 
 
 
 
8a3a43a
e2fc7b7
7a0972c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e2fc7b7
 
8a3a43a
e2fc7b7
8a3a43a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e2fc7b7
 
8a3a43a
e2fc7b7
 
 
 
 
8a3a43a
e2fc7b7
 
 
 
 
 
 
 
 
 
 
 
 
8a3a43a
e2fc7b7
8a3a43a
e2fc7b7
8a3a43a
e2fc7b7
8a3a43a
 
e2fc7b7
8a3a43a
e2fc7b7
8a3a43a
e2fc7b7
 
8a3a43a
 
 
 
e2fc7b7
894f899
e2fc7b7
 
8a3a43a
e2fc7b7
 
098d2a3
e2fc7b7
8a3a43a
 
e2fc7b7
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import warnings

import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import swin_t
from PIL import Image

#  Model definition
class MMIM(nn.Module):
    """Swin-T backbone followed by a two-layer MLP classification head.

    The stock ``swin_t`` head is replaced with ``nn.Identity`` so that the
    backbone's output (the input its original head would have received) feeds
    directly into ``classifier``, which maps it to ``num_classes`` logits.

    Args:
        num_classes: Number of output logits.
        weights: Forwarded unchanged to ``torchvision.models.swin_t``. The
            default ``'IMAGENET1K_V1'`` loads ImageNet-pretrained weights;
            pass ``None`` to skip that when a full checkpoint is loaded later.
        dropout: Dropout probability between the two linear layers.
    """

    def __init__(self, num_classes=36, *, weights='IMAGENET1K_V1', dropout=0.3):
        super().__init__()
        self.backbone = swin_t(weights=weights)
        # Read the feature width from the stock head before discarding it,
        # instead of hard-coding 768, so the MLP input always matches.
        in_features = self.backbone.head.in_features
        self.backbone.head = nn.Identity()
        self.classifier = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        # NOTE(review): assumes x is a (batch, 3, H, W) float tensor like the
        # module-level `transform` produces — confirm for other callers.
        features = self.backbone(x)
        return self.classifier(features)

#  Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MMIM(num_classes=36)

# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source.
checkpoint = torch.load("MMIM_best.pth", map_location=device)

# Accept a checkpoint saved as a wrapper dict as well as a bare state_dict;
# otherwise the shape/name filter below would drop every entry and the model
# would silently run with untrained weights.
for _wrapper_key in ("state_dict", "model_state_dict"):
    if isinstance(checkpoint.get(_wrapper_key), dict):
        checkpoint = checkpoint[_wrapper_key]
        break

# Strip the "module." prefix added when a model is saved from
# nn.DataParallel / DistributedDataParallel.
checkpoint = {
    (k[len("module."):] if k.startswith("module.") else k): v
    for k, v in checkpoint.items()
}

# Build the reference state_dict once: state_dict() allocates a new dict on
# every call, so calling it inside the comprehension was quadratic.
_model_state = model.state_dict()

# Keep only entries whose name and shape match the model; the rest are
# reported below instead of being dropped silently.
filtered_checkpoint = {
    k: v
    for k, v in checkpoint.items()
    if k in _model_state and _model_state[k].shape == v.shape
}

_not_loaded = sorted(set(_model_state) - set(filtered_checkpoint))
_ignored = sorted(set(checkpoint) - set(filtered_checkpoint))
if _not_loaded or _ignored:
    warnings.warn(
        f"MMIM_best.pth: {len(_not_loaded)} model tensors were not loaded "
        f"(left at their initial values) and {len(_ignored)} checkpoint "
        f"entries were ignored. First not loaded: {_not_loaded[:5]}; "
        f"first ignored: {_ignored[:5]}"
    )

model.load_state_dict(filtered_checkpoint, strict=False)
model.to(device)
model.eval()  # inference mode: disables Dropout in the classifier head

#  Class names
# Index i of this list is the label for model output i, so its order must
# match the label indices used in training. The "classN" comments follow the
# string-sorted folder order (class1, class14, ..., class9), which is presumably
# how the training data loader assigned indices — confirm before reordering.
class_names = [
    "Chinee apple",               # class1
    "Black grass",                # class14
    "Charlock",                   # class15
    "Cleavers",                   # class16
    "Common Chickweed",           # class17
    "Common Wheat",               # class18
    "Fat Hen",                    # class19
    "Lantana",                    # class2
    "Loose Silky bent",           # class20
    "Maize",                      # class21
    "Scentless Mayweed",          # class22
    "Shepherds Purse",            # class23
    "Small-Flowered Cranesbill",  # class24
    "Sugar beet",                 # class25
    "Carpetweeds",                # class26
    "Crabgrass",                  # class27
    "Eclipta",                    # class28
    "Goosegrass",                 # class29
    "Negative",                   # class3
    "Morningglory",               # class30
    "Nutsedge",                   # class31
    "Palmer Amaranth",            # class32
    "Prickly Sida",               # class33
    "Purslane",                   # class34
    "Ragweed",                    # class35
    "Sicklepod",                  # class36
    "SpottedSpurge",              # class37
    "SpurredAnoda",               # class38
    "Swinecress",                 # class39
    "Parkinsonia",                # class4
    "Waterhemp",                  # class40
    "Parthenium",                 # class5
    "Prickly acacia",             # class6
    "Rubber vine",                # class7
    "Siam weed",                  # class8
    "Snake weed",                 # class9 (inferred from the sort order)
]

#  Weed info dictionary
# Short description shown after each prediction, keyed by the class name.
weed_info = {
    "Chinee apple": " Invasive shrub. Control by uprooting or herbicide treatment.",
    "Black grass": " Infests cereal crops. Remove before seed shedding.",
    "Charlock": " Common weed in oilseed crops. Responds to early herbicide.",
    "Cleavers": " Sticky climbing weed. Control before flowering.",
    "Common Chickweed": " Fast-spreading groundcover weed. Avoid soil disturbance.",
    "Common Wheat": " May appear as weed in rotation crops.",
    "Fat Hen": " Broadleaf weed. Competes heavily with crops.",
    "Lantana": " Invasive ornamental plant, toxic to livestock.",
    "Loose Silky bent": " Grass weed affecting wheat fields.",
    "Maize": " Sometimes emerges as volunteer weed post-harvest.",
    "Scentless Mayweed": " Strong competitor in cereals. Shallow-rooted.",
    "Shepherds Purse": " Common weed in cool seasons. Heart-shaped pods.",
    "Small-Flowered Cranesbill": " Low-growing, thrives in dry areas.",
    "Sugar beet": " Appears as volunteer in crop fields.",
    "Carpetweeds": " Low mat-forming weed. Easy to remove manually.",
    "Crabgrass": " Summer annual grass. Thrives in disturbed soil.",
    "Eclipta": " Moisture-loving herbaceous weed.",
    "Goosegrass": " Mat-forming weed, tough to hand-pull.",
    "Negative": " No weed confidently detected. Please recheck input.",
    "Morningglory": " Climbing vine, chokes crops quickly.",
    "Nutsedge": " Grass-like weed with tubers. Hard to control.",
    "Palmer Amaranth": " Highly aggressive and herbicide-resistant.",
    "Prickly Sida": " Hairy, thorny stems. Requires early control.",
    "Purslane": " Succulent weed, common in warm climates.",
    "Ragweed": " Allergen-producing weed. Kill before flowering.",
    "Sicklepod": " Toxic to livestock. Control before pod set.",
    "SpottedSpurge": " Low-growing. Releases milky sap.",
    "SpurredAnoda": " Fast-growing summer annual. Common in cotton.",
    "Swinecress": " Strong odor. Grows in compacted soils.",
    "Parkinsonia": " Woody shrub. Mechanical removal advised.",
    "Waterhemp": " Fast-growing amaranth. Glyphosate-resistant strains exist.",
    "Parthenium": " Toxic and invasive. Avoid contact.",
    "Prickly acacia": " Thorny shrub. Displaces native plants.",
    "Rubber vine": " Woody climber. Toxic and invasive.",
    "Siam weed": " Highly invasive in tropical zones.",
    "Snake weed": " Woody perennial, toxic to livestock.",
}

# Fail fast at startup if the two tables drift apart. Otherwise a renamed class
# would quietly fall back to a generic "no info" message at prediction time.
_classes_without_info = [name for name in class_names if name not in weed_info]
if _classes_without_info:
    raise ValueError(f"weed_info has no entry for: {_classes_without_info}")

#  Transform
# Preprocessing applied to every uploaded image before inference: resize to a
# fixed 224x224, then convert the PIL image to a tensor.
# NOTE(review): there is no transforms.Normalize step here — confirm this
# matches the preprocessing used when MMIM_best.pth was trained.
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocess_steps)

#  Prediction function
def predict(img):
    """Classify one uploaded image and return a human-readable summary.

    Args:
        img: The PIL image from the ``gr.Image(type="pil")`` input, or
            ``None`` when the form is submitted without an upload.

    Returns:
        A multi-line string containing the prediction label, the softmax
        confidence as a percentage, and the matching ``weed_info`` text.
    """
    # Gradio passes None when nothing was uploaded; without this guard
    # img.convert() raises AttributeError and the UI shows a generic error.
    if img is None:
        return "No image provided. Please upload an image to classify."

    # Force 3 channels so grayscale / RGBA / palette uploads are accepted.
    img = img.convert('RGB')
    # unsqueeze(0) adds a batch dimension of size 1.
    img_tensor = transform(img).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(img_tensor)
        probs = torch.softmax(outputs, dim=1)

    conf, pred = torch.max(probs, 1)
    predicted_class = class_names[pred.item()]
    confidence = conf.item() * 100

    if predicted_class.lower() == "negative":
        label = f" Predicted as: Negative\nConfidence: {confidence:.2f}%"
    elif confidence < 60:
        # Below 60% the top class is not named; it is reported as uncertain.
        label = f" Low confidence. Possibly Not a Weed\nConfidence: {confidence:.2f}%"
    else:
        label = f" Predicted class: {predicted_class}\nConfidence: {confidence:.2f}%"

    # NOTE(review): in the low-confidence branch this still shows the info for
    # predicted_class even though the label says "Possibly Not a Weed" —
    # confirm that is intended.
    info = weed_info.get(predicted_class, " No additional info available.")
    return f"{label}\n\n Info: {info}"

#  App description
# Markdown text passed to gr.Interface as its `article` argument.
# The trailing double spaces on several lines are Markdown hard line breaks;
# keep them when editing.
about_markdown = """
###  Weed Classifier — Swin Transformer + MMIM  
This tool predicts weed species from images using a Vision Transformer backbone trained with multi-masked image modeling.

-  Shows confidence scores  
-  Flags uncertain or non-weed predictions  
-  Displays weed info after prediction  
-  Upload an image

> Tip: Use clear, focused weed images for better results.
"""

#  Gradio Interface
# One image input and one text output; predict() produces the text.
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload an Image"),
    outputs=gr.Textbox(label="Prediction"),
    title=" Weed Image Classifier",
    description="A Self-Supervised Learning model for weed image classification.",
    article=about_markdown,
)

# Starts the web server as soon as this module runs.
interface.launch()