Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import torch | |
| import torch.nn.functional as F | |
| from torchvision import models, transforms | |
| from PIL import Image | |
| import os | |
| import json | |
| from disease_cure_dict import disease_cure_dict | |
# Configure the browser tab title and use the full page width for the layout.
st.set_page_config(page_title="🌱 Plant Disease Detection", layout="wide")
# -------------------------
# Load Model
# -------------------------
DEVICE = torch.device("cpu")


@st.cache_resource
def load_model(model_path="model/plant_disease_model.pth"):
    """Load the fine-tuned ResNet-18 classifier and its class labels.

    The checkpoint must be a dict containing ``"class_names"`` (the output
    labels, indexed by logit position) and ``"model_state"`` (a ``state_dict``
    for a ResNet-18 whose ``fc`` layer has ``len(class_names)`` outputs).

    The function is wrapped in ``st.cache_resource``. Streamlit re-executes this
    whole script on every widget interaction, so without the cache the model
    would be reloaded from disk on each rerun.

    Note: ``torch.load`` unpickles the file, so only point this at trusted
    checkpoints.

    Args:
        model_path: Path to the checkpoint file, resolved relative to the
            current working directory.

    Returns:
        A ``(model, class_names)`` tuple; the model is on ``DEVICE`` and in
        eval mode.
    """
    checkpoint = torch.load(model_path, map_location=DEVICE)
    class_names = checkpoint["class_names"]

    # Build the bare architecture without downloading ImageNet weights.
    # load_state_dict below is strict by default, so every parameter and
    # buffer is overwritten anyway. The pretrained download was wasted work
    # and made startup depend on network access.
    model = models.resnet18(weights=None)
    model.fc = torch.nn.Linear(model.fc.in_features, len(class_names))
    model.load_state_dict(checkpoint["model_state"])
    model.to(DEVICE)
    model.eval()
    return model, class_names


model, class_names = load_model()
# -------------------------
# Transform
# -------------------------
# Per-channel RGB mean and standard deviation passed to Normalize. These are
# the standard torchvision ImageNet statistics. They presumably match the
# preprocessing used when training the checkpoint; confirm against the
# training script.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Inference preprocessing:
#   1. resize the shorter side to 256 px,
#   2. take the central 224x224 crop,
#   3. convert to a float tensor,
#   4. normalize each channel.
predict_tf = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)
# -------------------------
# Predict Function
# -------------------------
def predict(image):
    """Classify one image and return the top-1 label with its probability.

    Relies on the module-level ``model``, ``class_names``, ``predict_tf`` and
    ``DEVICE``.

    Args:
        image: A PIL image accepted by ``predict_tf``.

    Returns:
        A ``(label, confidence)`` tuple containing:
            - the ``class_names`` entry with the highest softmax probability;
            - that probability as a Python float.
    """
    # predict_tf yields a single (C, H, W) tensor; prepend a batch axis.
    batch = predict_tf(image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logits = model(batch)
        probs = F.softmax(logits, dim=1)
        top_prob, top_idx = probs.max(dim=1)
    return class_names[top_idx.item()], float(top_prob.item())
# -------------------------
# Streamlit UI
# -------------------------
# NOTE(review): the emoji in the strings below were reconstructed from
# mis-encoded (mojibake) text; confirm they match the intended icons.
st.title("🌿 Plant Disease Detection & Cure Suggestion")
st.write("Upload a plant leaf image to detect the disease and find treatment suggestions.")
uploaded_file = st.file_uploader("Upload Leaf Image", type=["jpg", "jpeg", "png"])

if uploaded_file:
    # An allowed extension does not guarantee a readable image. PIL raises
    # OSError (UnidentifiedImageError is a subclass) for corrupt or non-image
    # files. Report it instead of showing a raw traceback, then end this
    # script run.
    try:
        image = Image.open(uploaded_file).convert("RGB")
    except OSError:
        st.error("Could not read the uploaded file as an image. Please upload a valid JPG or PNG file.")
        st.stop()

    st.image(image, caption="Uploaded Image", width=300)

    with st.spinner("🔍 Analyzing image… please wait"):
        # Prediction
        label, confidence = predict(image)

    st.subheader("📊 Prediction Result")
    st.write(f"**Disease:** {label}")
    st.write(f"**Confidence:** {confidence*100:.2f}%")

    # Cure: look up the raw predicted label; unknown labels get a fallback.
    cure = disease_cure_dict.get(label, "No cure information available.")
    st.subheader("💊 Suggested Cure")
    st.write(cure)