Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,52 +1,21 @@
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
-
import torch.nn.functional as F
|
| 4 |
-
from torchvision.models.swin_transformer import swin_t, Swin_T_Weights
|
| 5 |
from torchvision import transforms
|
| 6 |
from PIL import Image
|
| 7 |
import gradio as gr
|
|
|
|
| 8 |
|
| 9 |
-
# β
|
| 10 |
# The 25 class labels, in the order the three sub-models' output heads
# were trained: model1 covers 1-9, model3 covers 10-13, model2 covers 14-25.
class_names = [
    "Chinee apple", "Lantana", "Negative", "Parkinsonia", "Parthenium", "Prickly acacia",
    "Rubber vine", "Siam weed", "Snake weed",            # classes 1-9  (model1)
    "Broadleaf", "Grass", "Soil", "Soybean",             # classes 10-13 (model3)
    "Black grass", "Charlock", "Cleavers", "Common Chickweed", "Common Wheat", "Fat Hen",
    "Loose Silky-bent", "Maize", "Scentless Mayweed", "Shepherds purse",
    "Small-flowered Cranesbill", "Sugar beet"            # classes 14-25 (model2)
]
|
| 18 |
|
| 19 |
-
# β
|
| 20 |
-
class MMIM(nn.Module):
    """Swin-T backbone with a small MLP classification head.

    The pretrained torchvision Swin-T is used as a fixed feature extractor
    shape-wise: its own head is replaced by Identity so it emits 768-dim
    feature vectors, which the MLP maps to `num_classes` logits.
    """

    def __init__(self, num_classes):
        super(MMIM, self).__init__()
        # Pretrained Swin-T; disable its built-in classifier so the
        # backbone outputs raw 768-dim features.
        self.backbone = swin_t(weights=Swin_T_Weights.DEFAULT)
        self.backbone.head = nn.Identity()
        # Two-layer MLP head; dropout 0.3 regularizes the hidden layer.
        self.classifier = nn.Sequential(
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Return class logits for a batch of images `x`."""
        return self.classifier(self.backbone(x))
|
| 35 |
-
|
| 36 |
-
# Load 3 models trained on subsets of classes
|
| 37 |
-
# Build the three sub-models, each sized for its disjoint class subset.
model1 = MMIM(num_classes=9)    # classes 1-9
model3 = MMIM(num_classes=4)    # classes 10-13
model2 = MMIM(num_classes=12)   # classes 14-25

# Restore trained weights; map_location='cpu' keeps loading device-agnostic.
model1.load_state_dict(torch.load("MMIM_best1.pth", map_location='cpu'))
model3.load_state_dict(torch.load("MMIM_best3.pth", map_location='cpu'))
model2.load_state_dict(torch.load("MMIM_best2.pth", map_location='cpu'))

# Inference mode: disables the dropout layers in the classifier heads.
model1.eval()
model2.eval()
model3.eval()
|
| 48 |
-
|
| 49 |
-
# Image preprocessing for Swin
|
| 50 |
transform = transforms.Compose([
|
| 51 |
transforms.Resize((224, 224)),
|
| 52 |
transforms.ToTensor(),
|
|
@@ -54,34 +23,26 @@ transform = transforms.Compose([
|
|
| 54 |
std=[0.5, 0.5, 0.5])
|
| 55 |
])
|
| 56 |
|
| 57 |
-
# β
|
| 58 |
-
|
| 59 |
-
image_tensor = transform(image).unsqueeze(0) # shape: [1, 3, 224, 224]
|
| 60 |
-
|
| 61 |
-
with torch.no_grad():
|
| 62 |
-
out1 = F.softmax(model1(image_tensor), dim=1) # [1, 9]
|
| 63 |
-
out3 = F.softmax(model3(image_tensor), dim=1) # [1, 4]
|
| 64 |
-
out2 = F.softmax(model2(image_tensor), dim=1) # [1, 12]
|
| 65 |
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
confidence = combined[0, pred_idx].item()
|
| 70 |
|
| 71 |
if confidence < 0.5:
|
| 72 |
-
return "Prediction uncertain β possibly unknown or low confidence"
|
| 73 |
|
| 74 |
-
return f"π§ Predicted: **{
|
| 75 |
|
| 76 |
-
# β
Gradio
|
| 77 |
app = gr.Interface(
|
| 78 |
fn=predict,
|
| 79 |
inputs=gr.Image(type="pil", label="Upload a weed image"),
|
| 80 |
outputs=gr.Markdown(),
|
| 81 |
title="πΏ Weed Classifier (25 Classes - MMIM)",
|
| 82 |
-
description="Upload an image to classify it into
|
| 83 |
)
|
| 84 |
|
| 85 |
-
# β
Launch app
|
| 86 |
if __name__ == "__main__":
|
| 87 |
app.launch()
|
|
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
|
|
|
|
|
|
| 3 |
from torchvision import transforms
|
| 4 |
from PIL import Image
|
| 5 |
import gradio as gr
|
| 6 |
+
from utils import MMIM, load_all_models, predict_image
|
| 7 |
|
| 8 |
+
# All 25 class labels
|
| 9 |
# All 25 class labels; the ordering matches the concatenation of the
# three sub-models' outputs (model1: 1-9, model3: 10-13, model2: 14-25).
class_names = [
    "Chinee apple", "Lantana", "Negative", "Parkinsonia", "Parthenium", "Prickly acacia",
    "Rubber vine", "Siam weed", "Snake weed",            # 1-9  (model1)
    "Broadleaf", "Grass", "Soil", "Soybean",             # 10-13 (model3)
    "Black grass", "Charlock", "Cleavers", "Common Chickweed", "Common Wheat", "Fat Hen",
    "Loose Silky-bent", "Maize", "Scentless Mayweed", "Shepherds purse",
    "Small-flowered Cranesbill", "Sugar beet"            # 14-25 (model2)
]
|
| 17 |
|
| 18 |
+
# β
Image transform for Swin
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
transform = transforms.Compose([
|
| 20 |
transforms.Resize((224, 224)),
|
| 21 |
transforms.ToTensor(),
|
|
|
|
| 23 |
std=[0.5, 0.5, 0.5])
|
| 24 |
])
|
| 25 |
|
| 26 |
+
# Load models — the shared utils helper builds all three sub-models and
# restores their trained weights in one call.
model1, model2, model3 = load_all_models()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
+
# Gradio prediction function
|
| 30 |
+
def predict(image):
    """Gradio callback: classify one uploaded PIL image.

    Delegates the 3-model ensemble inference to utils.predict_image and
    formats the result as Markdown for the gr.Markdown output component.
    """
    label, confidence = predict_image(image, model1, model2, model3, transform, class_names)

    # Below-threshold ensemble confidence is surfaced as "uncertain"
    # instead of returning a possibly wrong class name.
    if confidence < 0.5:
        return "β οΈ Prediction uncertain β possibly unknown or low confidence"

    return f"π§ Predicted: **{label}** (Confidence: {confidence:.2f})"
|
| 37 |
|
| 38 |
+
# Gradio interface
|
| 39 |
# Gradio wiring: a single-image classifier UI — PIL image in, Markdown out.
app = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload a weed image"),
    outputs=gr.Markdown(),
    title="πΏ Weed Classifier (25 Classes - MMIM)",
    description="Upload an image to classify it into one of 25 weed categories using 3 Swin-MMIM models.",
)
|
| 46 |
|
|
|
|
| 47 |
# Launch the Gradio server only when executed as a script, not on import.
if __name__ == "__main__":
    app.launch()
|