Spaces:
Build error
Build error
| import glob | |
| import streamlit as st | |
| import wget | |
| from PIL import Image | |
| import torch | |
| import cv2 | |
| import os | |
| import time | |
| import textwrap | |
# Description and treatment advice for each disease class.  Keys are looked up
# by the predicted class name in image_input(), so they must match the model's
# class names exactly.
class_info = dict(
    Brownspot=dict(
        desc="""
Brown spot of rice is caused by the fungus Bipolaris oryzae. The disease is characterized by small,
dark brown spots that form on the leaves and sheaths of the plant. Brown spot is often associated
with conditions of poor fertility, particularly low levels of nitrogen and silica.
""",
        treatment="""
Improve nutrient management, particularly nitrogen and silica. Fungicides can also be used for treatment.
""",
    ),
    Blast=dict(
        desc="""
Rice blast is one of the most destructive diseases of rice worldwide, caused by the fungus Pyricularia grisea.
The disease can affect all parts of the plant during both vegetative and reproductive growth stages.
Symptoms of rice blast include leaf lesions, rotting of nodes, neck blast, and rotting of the panicle base.
""",
        treatment="""
Crop rotation, use of resistant varieties, and proper water management can help control rice blast.
Fungicides may also be used.
""",
    ),
    BacterialBlight=dict(
        desc="""
Bacterial blight of rice is caused by the bacterium Xanthomonas oryzae pv. oryzae.
The disease is characterized by wilting of seedlings, leaf streaking, and leaf blight.
Bacterial blight is often more severe on plants growing in fields with high nitrogen levels and in areas
where rice is grown under flooded conditions.
""",
        treatment="""
Use of resistant varieties, proper water management, and avoiding excessive nitrogen fertilization can help
control bacterial blight. Copper-based sprays may also be used as a preventive measure.
""",
    ),
)
# Page configuration; Streamlit requires set_page_config to be the first st.* call.
st.set_page_config(layout="wide")

# Module-level state shared by main(), infer_image() and count_classes().
model = None  # set by main() once the weights are loaded
cfg_model_path = 'models/YOLOv5m.pt'  # default weights path; main() overwrites it from the sidebar
confidence = 0.25  # detection threshold; main() overwrites it from the sidebar slider
| from collections import Counter | |
def count_classes(results):
    """Count detections per class name for the first image in *results*.

    Args:
        results: YOLOv5 hub detection results. ``results.xyxy[0]`` is indexed as
            a 2-D tensor whose column 5 holds the predicted class id.

    Returns:
        dict mapping class name (resolved through the module-level
        ``model.names``) to its number of detections, in first-seen order.
    """
    ids = results.xyxy[0][:, 5].int().tolist()
    tally = Counter(model.names[cid] for cid in ids)
    return dict(tally)
def infer_image(img, size=None):
    """Run the global model on *img* and return the annotated image plus class counts.

    Args:
        img: Image input accepted by the YOLOv5 hub model (here a file path).
        size: Optional inference size; when falsy the model's default is used.

    Returns:
        (PIL image built from ``result.ims[0]`` after ``render()``, dict of class counts)
    """
    # Apply the current sidebar threshold before every inference.
    model.conf = confidence
    if size:
        result = model(img, size=size)
    else:
        result = model(img)
    counts = count_classes(result)
    # render() is expected to draw the detections into result.ims (YOLOv5 Detections API).
    result.render()
    annotated = Image.fromarray(result.ims[0])
    return annotated, counts
| from PIL import ImageOps | |
def image_input(data_src):
    """Let the user choose an image, run detection on it and render the results.

    Args:
        data_src: ``'Sample data'`` picks from ``data/sample_images/*`` via a slider;
            any other value shows a sidebar uploader instead.

    Side effects:
        Renders Streamlit widgets. For uploads, writes a 640x640 fitted copy to
        ``data/uploaded_data/upload.<ext>``, creating the directory if needed.
    """
    if data_src == 'Sample data':
        img_file = _pick_sample_image()
    else:
        img_file = _save_uploaded_image()
    if not img_file:
        return

    col1, col2 = st.columns(2)
    with col1:
        st.image(img_file, caption="Selected Image")
    with col2:
        prediction, class_counts = infer_image(img_file)
        st.image(prediction, caption="Model prediction")

    st.markdown("#### Counts of detected classes:")
    for class_name, count in class_counts.items():
        st.markdown(f"**{class_name}:** *{count}*")
        _show_class_info(class_name)


def _pick_sample_image(pattern='data/sample_images/*'):
    """Return the sample image path chosen with a slider, or None when none exist."""
    # Sorted so a slider position maps to the same file on every rerun.
    img_paths = sorted(glob.glob(pattern))
    if not img_paths:
        # Without this guard the slider would get max_value=0 < min_value=1.
        st.warning(f"No sample images found matching '{pattern}'.")
        return None
    if len(img_paths) == 1:
        # A one-position slider is pointless; use the only image directly.
        return img_paths[0]
    position = st.slider("Select a test image.", min_value=1, max_value=len(img_paths), step=1)
    return img_paths[position - 1]


def _save_uploaded_image(out_dir="data/uploaded_data", size=(640, 640)):
    """Show the sidebar uploader; save a fitted copy of the upload and return its path (or None)."""
    img_bytes = st.sidebar.file_uploader("Upload an image", type=['png', 'jpeg', 'jpg'])
    if not img_bytes:
        return None
    # img.save() fails if the target directory does not exist yet.
    os.makedirs(out_dir, exist_ok=True)
    ext = os.path.splitext(img_bytes.name)[1].lower() or '.png'
    img_file = os.path.join(out_dir, "upload" + ext)
    img = Image.open(img_bytes)
    # Image.ANTIALIAS was removed in Pillow 10; Image.LANCZOS is the same filter.
    img = ImageOps.fit(img, size, Image.LANCZOS)
    img.save(img_file)
    return img_file


def _show_class_info(class_name, width=60):
    """Render description and treatment for *class_name* when class_info has an entry."""
    info = class_info.get(class_name)
    if not info:
        # Unknown class: skip headings that would have no content.
        return
    desc = textwrap.fill(info.get('desc', ''), width=width)
    treatment = textwrap.fill(info.get('treatment', ''), width=width)
    st.markdown(f"**Information about {class_name}:**\n{desc}")
    st.markdown(f"**Treatment for {class_name}:**\n{treatment}")
def load_model(path, device, force_reload=False):
    """Load YOLOv5 weights from *path* via torch.hub and move the model to *device*.

    Streamlit re-executes this script on every widget interaction, so the loaded
    model is memoised in ``st.session_state``. It is reused while the weights
    path, the file's modification time and the device are unchanged. Only one
    model is kept per session.

    Args:
        path: Path to a YOLOv5 ``.pt`` weights file.
        device: Torch device string, e.g. ``'cpu'`` or ``'cuda'``.
        force_reload: Forwarded to ``torch.hub.load`` (re-fetches the hub repo).
            When true, the session cache is also bypassed.

    Returns:
        The hub model, moved to *device*.
    """
    cache_key = "_yolov5_model_cache"
    mtime = os.path.getmtime(path) if os.path.isfile(path) else None
    wanted = (os.path.abspath(path), mtime, device)
    cached = st.session_state.get(cache_key)
    if not force_reload and cached is not None and cached[0] == wanted:
        return cached[1]

    model_ = torch.hub.load('ultralytics/yolov5', 'custom', path=path, force_reload=force_reload)
    model_.to(device)
    print("model to ", device)
    st.session_state[cache_key] = (wanted, model_)
    return model_
def download_model(url, out_dir="models"):
    """Download *url* into *out_dir* with wget and return the saved file path.

    The directory is created first. python-wget only treats ``out`` as a
    directory when it already exists; otherwise it uses ``out`` as the target
    filename, so a missing ``models/`` would yield a file named ``models``.
    """
    os.makedirs(out_dir, exist_ok=True)
    model_file = wget.download(url, out=out_dir)
    return model_file
def get_user_model():
    """Let the user supply custom YOLOv5 weights by file upload or URL.

    Returns:
        Path to a local ``.pt`` file, or None if nothing usable was provided.

    NOTE(security): ``.pt`` files are pickles, so loading one from an untrusted
    upload or URL can execute arbitrary code. Only expose this to trusted users.
    """
    model_src = st.sidebar.radio("Model source", ["file upload", "url"])

    if model_src == "file upload":
        model_bytes = st.sidebar.file_uploader("Upload a model file", type=['pt'])
        if not model_bytes:
            return None
        os.makedirs("models", exist_ok=True)
        # basename() strips any directory components from the client-supplied name.
        model_file = os.path.join("models", "uploaded_" + os.path.basename(model_bytes.name))
        with open(model_file, 'wb') as out:
            out.write(model_bytes.getvalue())
        return model_file

    url = st.sidebar.text_input("model url")
    if not url:
        return None
    # Streamlit reruns the script on every interaction. Remember finished
    # downloads so the same URL is not fetched again on each rerun.
    downloads = st.session_state.setdefault("_model_downloads", {})
    model_file_ = downloads.get(url)
    if model_file_ is None or not os.path.isfile(model_file_):
        model_file_ = download_model(url)
        downloads[url] = model_file_
    return model_file_ if model_file_.lower().endswith(".pt") else None
def main():
    """Build the dashboard: styling, sidebar settings, model loading, then image inference."""
    # These module-level names are read by infer_image() and count_classes().
    global model, confidence, cfg_model_path

    _inject_css()
    st.title("🌾 Rice Leaf Disease Detection Dashboard")
    st.sidebar.title("Settings")

    # Each weight choice maps to models/<choice>.pt.
    model_src = st.sidebar.radio("Select yolov5 weight file", ["YOLOv5n", "YOLOv5s", "YOLOv5m", "YOLOv5l", "YOLOv5x"])
    cfg_model_path = f'models/{model_src}.pt'

    if not os.path.isfile(cfg_model_path):
        st.warning(f"Model file '{cfg_model_path}' not available. Please add it to the models folder.", icon="⚠️")
        return

    # 'cuda' is only selectable when torch can see a GPU.
    device_option = st.sidebar.radio("Select Device", ['cpu', 'cuda'], disabled=not torch.cuda.is_available(), index=0)
    model = load_model(cfg_model_path, device_option)

    confidence = st.sidebar.slider('Confidence', min_value=0.1, max_value=1.0, value=.45)
    model.classes = _select_classes(model.names)
    st.sidebar.markdown("---")

    data_src = st.sidebar.radio("Select input source: ", ['Sample data', 'Upload your own data'])
    image_input(data_src)


def _inject_css():
    """Inject the dashboard's custom CSS."""
    # NOTE(review): .reportview-container / .sidebar .sidebar-content are class
    # names from older Streamlit releases and may no longer match; confirm
    # against the installed Streamlit version.
    st.markdown("""
<style>
.reportview-container {
    background: #262626;
}
.sidebar .sidebar-content {
    background: #262626;
}
h1 {
    color: #59d455;
}
.block-container {
    color: #f0f0f0;
}
</style>
""", unsafe_allow_html=True)


def _select_classes(names):
    """Return the class ids to detect, honouring the sidebar 'Custom Classes' filter.

    Args:
        names: Mapping of class id -> class name (the model's ``names``).
    """
    if not st.sidebar.checkbox("Custom Classes"):
        return list(names.keys())
    all_names = list(names.values())
    chosen = st.sidebar.multiselect("Select Classes", all_names, default=all_names[:1])
    # Map names back to their real ids instead of assuming ids are 0..n-1.
    return [cid for cid, name in names.items() if name in chosen]
if __name__ == "__main__":
    # Script entry point.
    try:
        main()
    except SystemExit:
        # Deliberately ignored so a sys.exit() raised during the run does not
        # propagate out of the app. NOTE(review): the source of such exits is
        # not visible here — confirm whether this guard is still needed.
        ...