Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from torchvision import transforms | |
| from PIL import Image | |
| import gradio as gr | |
| from utils import MMIM, load_all_models, predict_image | |
# All 25 class labels. The per-group comments give the 1-based position range
# of each group and the model the original author associated with it.
class_names = [
    # 1–9 (model1)
    "Chinee apple",
    "Lantana",
    "Negative",
    "Parkinsonia",
    "Parthenium",
    "Prickly acacia",
    "Rubber vine",
    "Siam weed",
    "Snake weed",
    # 10–13 (model3)
    "Broadleaf",
    "Grass",
    "Soil",
    "Soybean",
    # 14–25 (model2)
    "Black grass",
    "Charlock",
    "Cleavers",
    "Common Chickweed",
    "Common Wheat",
    "Fat Hen",
    "Loose Silky-bent",
    "Maize",
    "Scentless Mayweed",
    "Shepherds purse",
    "Small-flowered Cranesbill",
    "Sugar beet",
]
# Image transform for Swin: resize to 224x224, convert to a tensor, then
# normalize each of the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
# Load the three models once at import time; predict() reads these
# module-level names on every call.
model1, model2, model3 = load_all_models()
# Predictions whose confidence falls below this value are reported as uncertain.
_MIN_CONFIDENCE = 0.5


def predict(image) -> str:
    """Classify an uploaded image and format the result as Markdown.

    Args:
        image: Image supplied by the Gradio input component (configured with
            ``type="pil"``), or ``None`` when the user submits without
            uploading anything.

    Returns:
        A Markdown string: a prompt to upload an image when ``image`` is
        ``None``, an "uncertain" warning when the confidence is below
        ``_MIN_CONFIDENCE``, otherwise the predicted label and confidence.
    """
    # Gradio passes None when no image was provided; answer with a message
    # instead of letting the model pipeline fail on a missing input.
    if image is None:
        return "⚠️ Please upload an image before requesting a prediction."

    label, confidence = predict_image(
        image, model1, model2, model3, transform, class_names
    )
    if confidence < _MIN_CONFIDENCE:
        return "⚠️ Prediction uncertain — possibly unknown or low confidence"
    return f"🧠 Predicted: **{label}** (Confidence: {confidence:.2f})"
# Gradio interface: a single image input whose result string from predict()
# is rendered as Markdown.
app = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload a weed image"),
    outputs=gr.Markdown(),
    title="🌿 Weed Classifier (25 Classes - MMIM)",
    description=(
        "Upload an image to classify it into one of 25 weed categories "
        "using 3 Swin-MMIM models."
    ),
)

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    app.launch()