"""Streamlit dashboard that runs a YOLOv5 model over rice-leaf images.

The user picks a YOLOv5 weight file from ``models/``, a device, a confidence
threshold and (optionally) a subset of classes. The app then runs detection on
either a bundled sample image or an uploaded one. Per-class detection counts
are shown alongside a short description and treatment note for known diseases.
"""
import glob
import os
import textwrap
import time  # noqa: F401
from collections import Counter

import cv2  # noqa: F401
import streamlit as st
import torch
import wget
from PIL import Image, ImageOps

# Human-readable description and treatment text per detected class name.
# Keys must match the model's class names to be displayed.
class_info = {
    'Brownspot': {
        'desc': """
            Brown spot of rice is caused by the fungus Bipolaris oryzae.
            The disease is characterized by small, dark brown spots that form
            on the leaves and sheaths of the plant. Brown spot is often
            associated with conditions of poor fertility, particularly low
            levels of nitrogen and silica.
        """,
        'treatment': """
            Improve nutrient management, particularly nitrogen and silica.
            Fungicides can also be used for treatment.
        """,
    },
    'Blast': {
        'desc': """
            Rice blast is one of the most destructive diseases of rice
            worldwide, caused by the fungus Pyricularia grisea. The disease
            can affect all parts of the plant during both vegetative and
            reproductive growth stages. Symptoms of rice blast include leaf
            lesions, rotting of nodes, neck blast, and rotting of the panicle
            base.
        """,
        'treatment': """
            Crop rotation, use of resistant varieties, and proper water
            management can help control rice blast. Fungicides may also be
            used.
        """,
    },
    'BacterialBlight': {
        'desc': """
            Bacterial blight of rice is caused by the bacterium Xanthomonas
            oryzae pv. oryzae. The disease is characterized by wilting of
            seedlings, leaf streaking, and leaf blight. Bacterial blight is
            often more severe on plants growing in fields with high nitrogen
            levels and in areas where rice is grown under flooded conditions.
        """,
        'treatment': """
            Use of resistant varieties, proper water management, and avoiding
            excessive nitrogen fertilization can help control bacterial
            blight. Copper-based sprays may also be used as a preventive
            measure.
        """,
    },
}

# Must be the first Streamlit call executed by the script.
st.set_page_config(layout="wide")

# Module-level state, (re)assigned by main() on every Streamlit rerun.
cfg_model_path = 'models/YOLOv5m.pt'
model = None
confidence = .25

MODEL_DIR = "models"
UPLOAD_DIR = "data/uploaded_data"
SAMPLE_IMAGE_GLOB = "data/sample_images/*"
MODEL_CHOICES = ["YOLOv5n", "YOLOv5s", "YOLOv5m", "YOLOv5l", "YOLOv5x"]


def _class_names(model_):
    """Return ``model_.names`` as a ``{class_id: class_name}`` dict.

    The code below needs ``.keys()``/``.values()``. ``names`` is treated as a
    dict when it is one; otherwise (presumably a list, depending on the
    YOLOv5 version — verify) it is enumerated into one.
    """
    names = model_.names
    return dict(names) if isinstance(names, dict) else dict(enumerate(names))


def _wrap(text, width=60):
    """Collapse all whitespace runs in ``text`` and wrap it to ``width`` columns."""
    return textwrap.fill(" ".join(text.split()), width=width)


def count_classes(results):
    """Count detections per class name for the first image in ``results``.

    Column 5 of each ``results.xyxy[0]`` row is read as the class id and
    mapped to a name through the global ``model.names``.

    Returns:
        dict mapping class name -> number of detections.
    """
    class_ids = results.xyxy[0][:, 5].int().tolist()
    return dict(Counter(model.names[i] for i in class_ids))


def infer_image(img, size=None):
    """Run the global model on ``img`` and return the annotated image and counts.

    Args:
        img: anything the loaded model accepts (this app passes a file path).
        size: optional inference size forwarded to the model; omitted if falsy.

    Returns:
        (PIL.Image with rendered detections, dict of class name -> count).
    """
    model.conf = confidence
    result = model(img, size=size) if size else model(img)
    class_counts = count_classes(result)
    result.render()  # draws boxes into result.ims in place
    image = Image.fromarray(result.ims[0])
    return image, class_counts


def _select_sample_image():
    """Let the user pick one of the bundled sample images; return its path or None."""
    img_paths = sorted(glob.glob(SAMPLE_IMAGE_GLOB))  # sorted: stable slider order
    if not img_paths:
        st.warning("No sample images found in data/sample_images.")
        return None
    if len(img_paths) == 1:
        # st.slider rejects min_value == max_value, so skip it for one image.
        return img_paths[0]
    img_slider = st.slider("Select a test image.", min_value=1,
                           max_value=len(img_paths), step=1)
    return img_paths[img_slider - 1]


def _save_uploaded_image():
    """Save a user-uploaded image, center-cropped/resized to 640x640; return its path or None."""
    img_bytes = st.sidebar.file_uploader("Upload an image", type=['png', 'jpeg', 'jpg'])
    if not img_bytes:
        return None
    os.makedirs(UPLOAD_DIR, exist_ok=True)
    ext = img_bytes.name.rsplit('.', 1)[-1]
    img_file = f"{UPLOAD_DIR}/upload.{ext}"
    with Image.open(img_bytes) as img:
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        fitted = ImageOps.fit(img, (640, 640), Image.LANCZOS)
    fitted.save(img_file)
    return img_file


def _show_class_details(class_counts):
    """Render per-class counts plus description/treatment for known classes."""
    st.markdown("#### Counts of detected classes:")
    for class_name, count in class_counts.items():
        st.markdown(f"**{class_name}:** *{count}*")
        info = class_info.get(class_name)
        if not info:
            continue  # no reference text for this class; avoid empty headers
        st.markdown(f"**Information about {class_name}:**\n{_wrap(info.get('desc', ''))}")
        st.markdown(f"**Treatment for {class_name}:**\n{_wrap(info.get('treatment', ''))}")


def image_input(data_src):
    """Select or upload an image, run detection, and render the results.

    Args:
        data_src: 'Sample data' to pick a bundled image; any other value
            shows the sidebar uploader.
    """
    if data_src == 'Sample data':
        img_file = _select_sample_image()
    else:
        img_file = _save_uploaded_image()
    if not img_file:
        return

    col1, col2 = st.columns(2)
    with col1:
        st.image(img_file, caption="Selected Image")
    with col2:
        img, class_counts = infer_image(img_file)
        st.image(img, caption="Model prediction")
        _show_class_details(class_counts)


@st.cache_resource
def load_model(path, device):
    """Load custom YOLOv5 weights from ``path`` via torch.hub and move them to ``device``.

    Cached by Streamlit per (path, device). ``force_reload=True`` makes
    torch.hub re-fetch the ultralytics/yolov5 repo whenever the cache misses.
    """
    model_ = torch.hub.load('ultralytics/yolov5', 'custom', path=path, force_reload=True)
    model_.to(device)
    print("model to ", device)
    return model_


@st.cache_resource
def download_model(url):
    """Download ``url`` into the models directory and return the saved file path."""
    # Ensure the directory exists so the download is placed inside it.
    os.makedirs(MODEL_DIR, exist_ok=True)
    return wget.download(url, out=MODEL_DIR)


def get_user_model():
    """Obtain a user-supplied ``.pt`` model from an upload or a URL.

    Returns:
        Path of the saved model file, or None if nothing usable was given.
    """
    model_src = st.sidebar.radio("Model source", ["file upload", "url"])
    if model_src == "file upload":
        model_bytes = st.sidebar.file_uploader("Upload a model file", type=['pt'])
        if not model_bytes:
            return None
        os.makedirs(MODEL_DIR, exist_ok=True)
        # basename(): the name is client-supplied; keep it inside MODEL_DIR.
        model_file = f"{MODEL_DIR}/uploaded_" + os.path.basename(model_bytes.name)
        with open(model_file, 'wb') as out:
            out.write(model_bytes.getvalue())  # independent of stream position
        return model_file

    url = st.sidebar.text_input("model url")
    if not url:
        return None
    model_file_ = download_model(url)
    return model_file_ if model_file_.split(".")[-1] == "pt" else None


def main():
    """Build the sidebar controls, load the chosen model and run image inference."""
    global model, confidence, cfg_model_path

    # custom CSS
    st.markdown(""" """, unsafe_allow_html=True)

    st.title("🌾 Rice Leaf Disease Detection Dashboard")
    st.sidebar.title("Settings")

    # Weight file is models/<choice>.pt, e.g. models/YOLOv5m.pt.
    model_src = st.sidebar.radio("Select yolov5 weight file", MODEL_CHOICES)
    cfg_model_path = f'models/{model_src}.pt'

    if not os.path.isfile(cfg_model_path):
        st.warning("Model file not available!!!, please added to the model folder.", icon="⚠️")
        return

    # The device choice is only enabled when CUDA is present; default is CPU.
    device_option = st.sidebar.radio("Select Device", ['cpu', 'cuda'],
                                     disabled=not torch.cuda.is_available(), index=0)
    model = load_model(cfg_model_path, device_option)

    confidence = st.sidebar.slider('Confidence', min_value=0.1, max_value=1.0, value=.45)

    names = _class_names(model)
    if st.sidebar.checkbox("Custom Classes"):
        model_names = list(names.values())
        assigned_class = st.sidebar.multiselect("Select Classes", model_names,
                                                default=[model_names[0]])
        # Map selected names back to their real class ids.
        model.classes = [cid for cid, name in names.items() if name in assigned_class]
    else:
        model.classes = list(names.keys())

    st.sidebar.markdown("---")

    data_src = st.sidebar.radio("Select input source: ", ['Sample data', 'Upload your own data'])
    image_input(data_src)


if __name__ == "__main__":
    try:
        main()
    except SystemExit:
        # Swallow SystemExit (e.g. from a sys.exit() call) so the script ends quietly.
        pass