# Rasleen's picture
# Update app.py
# 1b4ad33 verified
import streamlit as st
import torch
import torch.nn.functional as F
from torchvision import models, transforms
from PIL import Image
import os
import json
from disease_cure_dict import disease_cure_dict
st.set_page_config(page_title="🌱 Plant Disease Detection", layout="wide")
# -------------------------
# Load Model
# -------------------------
# All inference runs on the CPU; load_model() also remaps the checkpoint's
# tensors onto this device (map_location=DEVICE).
DEVICE = torch.device("cpu")
@st.cache_resource
def load_model():
    """Load the fine-tuned ResNet-18 classifier and its class labels.

    Decorated with ``st.cache_resource`` so the checkpoint is read once and the
    same model object is reused across Streamlit reruns.

    Returns:
        tuple: ``(model, class_names)``. ``model`` is a ``resnet18`` whose final
        layer has ``len(class_names)`` outputs. It has been moved to ``DEVICE``
        and put in eval mode. ``class_names`` is the label list stored in the
        checkpoint; position ``i`` names output logit ``i``.

    Raises:
        KeyError: if the checkpoint lacks ``"class_names"`` or ``"model_state"``.
        RuntimeError: if the stored state dict does not match the architecture.
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load("model/plant_disease_model.pth", map_location=DEVICE)
    class_names = checkpoint["class_names"]
    num_classes = len(class_names)

    # Build the bare architecture. Do not fetch ImageNet weights: the strict
    # load_state_dict below overwrites every parameter and buffer, so the
    # pretrained download was pure startup cost (and needed network access).
    model = models.resnet18(weights=None)
    # Replace the 1000-way ImageNet head with one sized to this label set.
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    model.load_state_dict(checkpoint["model_state"])
    model.to(DEVICE)
    model.eval()  # inference mode for BatchNorm/Dropout layers
    return model, class_names
# Executed on every Streamlit rerun, but load_model() is wrapped in
# @st.cache_resource, so the checkpoint is only actually read once.
model, class_names = load_model()
# -------------------------
# Transform
# -------------------------
# Preprocessing for every uploaded leaf before it reaches the network:
# resize (short side to 256 px), crop the central 224x224 patch, convert to a
# tensor, then standardize each RGB channel with the torchvision ImageNet
# mean/std values.
predict_tf = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# -------------------------
# Predict Function
# -------------------------
def predict(image):
    """Classify one leaf image and return the most likely label.

    Args:
        image: Input accepted by ``predict_tf``. The UI passes a PIL image
            that has already been converted to RGB.

    Returns:
        tuple: ``(label, confidence)``. ``label`` is the entry of
        ``class_names`` with the highest softmax probability. ``confidence``
        is that probability as a Python float in ``[0, 1]``.
    """
    # The network expects a batch, so prepend a batch dimension of size 1.
    batch = predict_tf(image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():  # pure inference: skip autograd bookkeeping
        logits = model(batch)
        probs = F.softmax(logits, dim=1)
        conf, pred = torch.max(probs, dim=1)
    # Tensor.item() already yields a Python float/int; no extra cast needed.
    return class_names[pred.item()], conf.item()
# -------------------------
# Streamlit UI
# -------------------------
st.title("🌿 Plant Disease Detection & Cure Suggestion")
st.write("Upload a plant leaf image to detect the disease and find treatment suggestions.")
uploaded_file = st.file_uploader("Upload Leaf Image", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    try:
        # convert() forces PIL's lazy decode, so a corrupt, truncated or
        # mislabelled file fails here rather than inside the transform.
        image = Image.open(uploaded_file).convert("RGB")
    except OSError:  # also covers PIL.UnidentifiedImageError (an OSError subclass)
        st.error("Could not read the uploaded file as an image. Please upload a valid JPG or PNG.")
        st.stop()

    st.image(image, caption="Uploaded Image", width=300)

    with st.spinner("🔍 Analyzing image… please wait"):
        # Prediction
        label, confidence = predict(image)

    st.subheader("🔍 Prediction Result")
    st.write(f"**Disease:** {label}")
    st.write(f"**Confidence:** {confidence:.2%}")

    # Cure: fall back to a fixed message for labels missing from the dict.
    cure = disease_cure_dict.get(label, "No cure information available.")
    st.subheader("💊 Suggested Cure")
    st.write(cure)