"""Gradio front-end for a 25-class weed/crop image classifier.

Three Swin-MMIM models are loaded once at import time (via
``utils.load_all_models``). Each uploaded image is classified by
``utils.predict_image``, which receives all three models together with the
shared preprocessing transform and the label list below.
"""

import torch
import torch.nn as nn
import gradio as gr
from PIL import Image
from torchvision import transforms

from utils import MMIM, load_all_models, predict_image

# All 25 class labels, grouped by the model that covers them.
# NOTE(review): the order below (model1 -> model3 -> model2) must match the
# index order predict_image() uses when mapping outputs to labels — confirm
# against utils.
class_names = [
    # 1–9 (model1)
    "Chinee apple", "Lantana", "Negative", "Parkinsonia", "Parthenium",
    "Prickly acacia", "Rubber vine", "Siam weed", "Snake weed",
    # 10–13 (model3)
    "Broadleaf", "Grass", "Soil", "Soybean",
    # 14–25 (model2)
    "Black grass", "Charlock", "Cleavers", "Common Chickweed", "Common Wheat",
    "Fat Hen", "Loose Silky-bent", "Maize", "Scentless Mayweed",
    "Shepherds purse", "Small-flowered Cranesbill", "Sugar beet",
]

# Predictions whose confidence falls below this value are reported as uncertain
# instead of being shown as a label.
CONFIDENCE_THRESHOLD = 0.5

# Preprocessing for the Swin models: resize to 224x224, convert to a tensor,
# then map each of the 3 channels from [0, 1] to [-1, 1].
# NOTE(review): mean/std of 0.5 presumably match the training pipeline — confirm.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Load the three models once at import time so every request reuses them.
model1, model2, model3 = load_all_models()


def predict(image):
    """Classify one uploaded image and format the result as Markdown.

    Args:
        image: PIL image from the Gradio input, or ``None`` when the user
            submits without uploading anything.

    Returns:
        One of three Markdown strings:
        - the predicted label and its confidence;
        - a warning when the confidence is below ``CONFIDENCE_THRESHOLD``;
        - a prompt to upload an image when ``image`` is ``None``.
    """
    if image is None:
        # Gradio calls fn with None when Submit is pressed on an empty input;
        # presumably predict_image cannot handle that, so answer here instead.
        return "⚠️ Please upload an image first."
    if image.mode != "RGB":
        # transform's Normalize expects exactly 3 channels. Grayscale, RGBA and
        # palette images would not match that, so convert them to RGB.
        # This is a no-op for images that are already RGB.
        image = image.convert("RGB")

    label, confidence = predict_image(
        image, model1, model2, model3, transform, class_names
    )
    if confidence < CONFIDENCE_THRESHOLD:
        return "⚠️ Prediction uncertain – possibly unknown or low confidence"
    return f"🧠 Predicted: **{label}** (Confidence: {confidence:.2f})"


# Gradio interface: one image input in, Markdown text out.
app = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload a weed image"),
    outputs=gr.Markdown(),
    title=f"🌿 Weed Classifier ({len(class_names)} Classes - MMIM)",
    # The label set also contains crops, soil and a "Negative" class, so the
    # text says "classes" rather than "weed categories".
    description=(
        f"Upload an image to classify it into one of {len(class_names)} "
        "classes (weeds, crops, soil, or negative) using 3 Swin-MMIM models."
    ),
)

if __name__ == "__main__":
    app.launch()