Spaces:
Sleeping
Sleeping
File size: 6,713 Bytes
e2fc7b7 8a3a43a e2fc7b7 8a3a43a e2fc7b7 8a3a43a e2fc7b7 7a0972c e2fc7b7 8a3a43a e2fc7b7 8a3a43a e2fc7b7 8a3a43a e2fc7b7 8a3a43a e2fc7b7 8a3a43a e2fc7b7 8a3a43a e2fc7b7 8a3a43a e2fc7b7 8a3a43a e2fc7b7 8a3a43a e2fc7b7 8a3a43a e2fc7b7 8a3a43a e2fc7b7 894f899 e2fc7b7 8a3a43a e2fc7b7 098d2a3 e2fc7b7 8a3a43a e2fc7b7 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 | import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import swin_t
from PIL import Image
# Model definition
class MMIM(nn.Module):
    """Swin-T backbone followed by a two-layer MLP classification head.

    The backbone is torchvision's ``swin_t`` initialised from the
    ``IMAGENET1K_V1`` weights, with its own ``head`` swapped for an identity so
    it emits a 768-dim feature vector. The MLP head (768 -> 512 -> num_classes,
    ReLU + 30% dropout in between) turns those features into class logits.

    Args:
        num_classes: Number of output logits. Defaults to 36.
    """

    def __init__(self, num_classes=36):
        super().__init__()
        backbone = swin_t(weights='IMAGENET1K_V1')
        # Strip torchvision's ImageNet classifier; our own head replaces it.
        backbone.head = nn.Identity()
        self.backbone = backbone
        hidden_dim = 512
        self.classifier = nn.Sequential(
            nn.Linear(768, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x):
        """Return raw (unnormalised) class logits for the image batch ``x``.

        ``x`` is presumably (B, 3, 224, 224) as produced by the app's
        preprocessing transform — TODO confirm.
        """
        return self.classifier(self.backbone(x))
# Load model
import warnings

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MMIM(num_classes=36)
# NOTE(review): torch.load unpickles the file. That is acceptable for the
# checkpoint shipped with this app, but never point it at user-supplied files.
checkpoint = torch.load("MMIM_best.pth", map_location=device)
# state_dict() builds a new mapping on every call, so build it once instead of
# twice per checkpoint key.
model_state = model.state_dict()
# Keep only entries whose name and shape match this architecture, so a partially
# compatible checkpoint still loads.
filtered_checkpoint = {
    k: v
    for k, v in checkpoint.items()
    if k in model_state and model_state[k].shape == v.shape
}
skipped_keys = sorted(set(checkpoint) - set(filtered_checkpoint))
# strict=False: parameters absent from filtered_checkpoint keep their current
# values (ImageNet backbone weights / freshly initialised head layers).
load_result = model.load_state_dict(filtered_checkpoint, strict=False)
missing_keys = list(load_result.missing_keys)
if skipped_keys or missing_keys:
    # Previously silent: a mismatched classifier head would stay randomly
    # initialised and the app would serve meaningless predictions unnoticed.
    warnings.warn(
        f"MMIM_best.pth: skipped {len(skipped_keys)} incompatible checkpoint "
        f"key(s) {skipped_keys[:5]}; {len(missing_keys)} model key(s) not "
        f"loaded from the checkpoint {missing_keys[:5]}"
    )
model.to(device)
model.eval()
# Class names
# Position i in this list is the label reported for model output index i, so the
# order must match the label order the checkpoint was trained with. There are 36
# entries, matching MMIM(num_classes=36).
# The trailing comments name each entry's dataset folder. They run in
# lexicographic string order (class1, class14, ..., class2, class20, ...), which
# is presumably the order torchvision's ImageFolder assigned at training time —
# TODO confirm against the training script. Folders class10-class13 do not appear.
# Spellings such as "Lanthana" and "Palmer Amarnath" are kept as-is because the
# same strings are the keys of weed_info.
class_names = [
    "Chinee apple", # class1
    "Black grass", # class14
    "Charlock", # class15
    "Cleavers", # class16
    "Common Chickweed", # class17
    "Common Wheat", # class18
    "Fat Hen", # class19
    "Lanthana", # class2
    "Loose Silky bent", # class20
    "Maize", # class21
    "Scentless Mayweed", # class22
    "Shepherds Purse", # class23
    "Small-Flowered Cranesbill", # class24
    "Sugar beet", # class25
    "Carpetweeds", # class26
    "Crabgrass",# class27
    "Eclipta", # class28
    "Goosegrass", # class29
    "Negative", # class3
    "Morningglory", # class30
    "Nutsedge", # class31
    "Palmer Amarnath", # class32
    "Prickly Sida", # class33
    "Purslane", # class34
    "Ragweed", # class35
    "Sicklepod", # class36
    "SpottedSpurge", # class37
    "SpurredAnoda", # class38
    "Swinecress", # class39
    "Parkinsonia", # class4
    "Waterhemp", # class40
    "Parthenium", # class5
    "Prickly acacia", # class6
    "Rubber vine", # class7
    "Siam weed", # class8
    "Snake weed",  # folder comment missing; class9 would fit the ordering — confirm
]
# Weed info dictionary
# Short description for each label, looked up by predict() using the predicted
# class name. Keys must match the strings in class_names exactly; an unmatched
# name falls back to " No additional info available.".
# NOTE(review): every value starts with a space — probably a placeholder where
# an emoji was removed. It is kept because the text is emitted verbatim.
weed_info = {
    "Chinee apple": " Invasive shrub. Control by uprooting or herbicide treatment.",
    "Black grass": " Infests cereal crops. Remove before seed shedding.",
    "Charlock": " Common weed in oilseed crops. Responds to early herbicide.",
    "Cleavers": " Sticky climbing weed. Control before flowering.",
    "Common Chickweed": " Fast-spreading groundcover weed. Avoid soil disturbance.",
    "Common Wheat": " May appear as weed in rotation crops.",
    "Fat Hen": " Broadleaf weed. Competes heavily with crops.",
    "Lanthana": " Invasive ornamental plant, toxic to livestock.",
    "Loose Silky bent": " Grass weed affecting wheat fields.",
    "Maize": " Sometimes emerges as volunteer weed post-harvest.",
    "Scentless Mayweed": " Strong competitor in cereals. Shallow-rooted.",
    "Shepherds Purse": " Common weed in cool seasons. Heart-shaped pods.",
    "Small-Flowered Cranesbill": " Low-growing, thrives in dry areas.",
    "Sugar beet": " Appears as volunteer in crop fields.",
    "Carpetweeds": " Low mat-forming weed. Easy to remove manually.",
    "Crabgrass": " Summer annual grass. Thrives in disturbed soil.",
    "Eclipta": " Moisture-loving herbaceous weed.",
    "Goosegrass": " Mat-forming weed, tough to hand-pull.",
    "Negative": " No weed confidently detected. Please recheck input.",
    "Morningglory": " Climbing vine, chokes crops quickly.",
    "Nutsedge": " Grass-like weed with tubers. Hard to control.",
    "Palmer Amarnath": " Highly aggressive and herbicide-resistant.",
    "Prickly Sida": " Hairy, thorny stems. Requires early control.",
    "Purslane": " Succulent weed, common in warm climates.",
    "Ragweed": " Allergen-producing weed. Kill before flowering.",
    "Sicklepod": " Toxic to livestock. Control before pod set.",
    "SpottedSpurge": " Low-growing. Releases milky sap.",
    "SpurredAnoda": " Fast-growing summer annual. Common in cotton.",
    "Swinecress": " Strong odor. Grows in compacted soils.",
    "Parkinsonia": " Woody shrub. Mechanical removal advised.",
    "Waterhemp": " Fast-growing amaranth. Glyphosate-resistant strains exist.",
    "Parthenium": " Toxic and invasive. Avoid contact.",
    "Prickly acacia": " Thorny shrub. Displaces native plants.",
    "Rubber vine": " Woody climber. Toxic and invasive.",
    "Siam weed": " Highly invasive in tropical zones.",
    "Snake weed": " Woody perennial, toxic to livestock."
}
# Transform
# Preprocessing applied to every uploaded image: resize to 224x224 and convert
# the PIL image to a float tensor with values in [0, 1].
# NOTE(review): no transforms.Normalize(...) step is applied, even though the
# backbone starts from ImageNet weights. Confirm this matches the pipeline the
# checkpoint was trained with before changing it; altering it here alone would
# shift every prediction.
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocess_steps)
# Prediction function
def predict(img):
    """Classify a weed image and return a human-readable report.

    Args:
        img: PIL image from the Gradio ``Image(type="pil")`` input, or ``None``
            when the user submits without uploading an image.

    Returns:
        A multi-line string with the predicted label, the softmax confidence
        (as a percentage) and the matching description from ``weed_info``.
        Predictions below 60% confidence are flagged as uncertain, but the best
        guess is still named.
    """
    # Gradio passes None when nothing was uploaded; the original crashed on
    # ``None.convert``.
    if img is None:
        return " No image provided. Please upload an image."
    # Drop alpha/palette modes so the tensor always has 3 channels.
    img = img.convert('RGB')
    # unsqueeze(0) adds the batch dimension the model expects.
    img_tensor = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = model(img_tensor)
        probs = torch.softmax(outputs, dim=1)
        conf, pred = torch.max(probs, 1)
    predicted_class = class_names[pred.item()]
    confidence = conf.item() * 100
    if predicted_class.lower() == "negative":
        label = f" Predicted as: Negative\nConfidence: {confidence:.2f}%"
    elif confidence < 60:
        # Name the best guess. Otherwise the "Info:" section below would
        # describe a class the user was never told about.
        label = (
            f" Low confidence. Possibly Not a Weed\n"
            f"Best guess: {predicted_class}\n"
            f"Confidence: {confidence:.2f}%"
        )
    else:
        label = f" Predicted class: {predicted_class}\nConfidence: {confidence:.2f}%"
    info = weed_info.get(predicted_class, " No additional info available.")
    return f"{label}\n\n Info: {info}"
# App description
# Markdown passed to gr.Interface as its ``article`` text.
# NOTE(review): the last bullet ("Upload an image") is an instruction rather
# than a feature, unlike the others.
about_markdown = """
### Weed Classifier — Swin Transformer + MMIM
This tool predicts weed species from images using a Vision Transformer backbone trained with multi-masked image modeling.
- Shows confidence scores
- Flags uncertain or non-weed predictions
- Displays weed info after prediction
- Upload an image
> Tip: Use clear, focused weed images for better results.
"""
# Gradio Interface
# Wire predict() to a single PIL image input and a text output, then start the
# server. launch() runs at import time, so this file is the app entry point.
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload an Image"),
    outputs=gr.Textbox(label="Prediction"),
    title=" Weed Image Classifier",
    # Fixed the user-facing typo "Self- Spervised".
    description="A self-supervised learning model for weed image classification.",
    article=about_markdown,
)
interface.launch()
|