# Author: HilmiZr — "attempt 00" (commit 2d39fd4)
import glob
import streamlit as st
import wget
from PIL import Image
import torch
import cv2
import os
import time
import textwrap
# Per-class reference text rendered under each detected class in image_input().
# Keys are matched against the model's class names; a name missing here makes
# the .get(...) lookups fall back to an empty string.
# NOTE(review): confirm these keys match model.names of the bundled weights.
# The triple-quoted strings intentionally start/end at column 0 — indenting
# their closing lines would add whitespace to the runtime text.
class_info = {
'Brownspot': {
'desc': """
Brown spot of rice is caused by the fungus Bipolaris oryzae. The disease is characterized by small,
dark brown spots that form on the leaves and sheaths of the plant. Brown spot is often associated
with conditions of poor fertility, particularly low levels of nitrogen and silica.
""",
'treatment': """
Improve nutrient management, particularly nitrogen and silica. Fungicides can also be used for treatment.
"""
},
'Blast': {
'desc': """
Rice blast is one of the most destructive diseases of rice worldwide, caused by the fungus Pyricularia grisea.
The disease can affect all parts of the plant during both vegetative and reproductive growth stages.
Symptoms of rice blast include leaf lesions, rotting of nodes, neck blast, and rotting of the panicle base.
""",
'treatment': """
Crop rotation, use of resistant varieties, and proper water management can help control rice blast.
Fungicides may also be used.
"""
},
'BacterialBlight': {
'desc': """
Bacterial blight of rice is caused by the bacterium Xanthomonas oryzae pv. oryzae.
The disease is characterized by wilting of seedlings, leaf streaking, and leaf blight.
Bacterial blight is often more severe on plants growing in fields with high nitrogen levels and in areas
where rice is grown under flooded conditions.
""",
'treatment': """
Use of resistant varieties, proper water management, and avoiding excessive nitrogen fertilization can help
control bacterial blight. Copper-based sprays may also be used as a preventive measure.
"""
}
}
# Streamlit (at least in older releases) requires set_page_config to be the
# first Streamlit command executed by the script.
st.set_page_config(layout="wide")

# Module-level state; main() rebinds all three through a ``global`` statement.
model = None  # replaced by load_model(...) once a weight file is found
cfg_model_path = 'models/YOLOv5m.pt'  # default; main() overwrites it from the sidebar radio
confidence = 0.25  # detection threshold; main() overwrites it from the sidebar slider
from collections import Counter
def count_classes(results):
    """Count detections per class name for the first image of a YOLOv5 result.

    Args:
        results: A YOLOv5 ``Detections``-like object. ``results.xyxy[0]`` is
            read as a tensor whose column 5 holds the class id of each box.
            Class names are taken from ``results.names`` when present,
            otherwise from the module-level ``model.names``.

    Returns:
        dict mapping class name -> number of detections, in order of first
        appearance. Empty when nothing was detected.
    """
    # Use the names carried by the result itself so the counts always match
    # the model that produced them; fall back to the global model otherwise.
    names = getattr(results, "names", None)
    if names is None:
        names = model.names
    class_ids = results.xyxy[0][:, 5].int().tolist()
    return dict(Counter(names[i] for i in class_ids))
def infer_image(img, size=None):
    """Run the global model on *img* and return the annotated image and counts.

    Args:
        img: Anything the loaded YOLOv5 model accepts (this file passes a path).
        size: Optional inference size; forwarded only when truthy.

    Returns:
        (PIL.Image, dict): the first rendered image and the per-class counts
        from count_classes().
    """
    # Apply the current sidebar threshold before running inference.
    model.conf = confidence
    extra = {"size": size} if size else {}
    detections = model(img, **extra)
    # Count before rendering, exactly as the detections came out of the model.
    counts = count_classes(detections)
    # render() draws the boxes into detections.ims (YOLOv5 Detections API).
    detections.render()
    annotated = Image.fromarray(detections.ims[0])
    return annotated, counts
from PIL import ImageOps
def image_input(data_src):
    """Select an input image, run detection on it, and render the results.

    Args:
        data_src: ``'Sample data'`` picks one of the files in
            ``data/sample_images`` (via a slider when there are several).
            Any other value shows a sidebar uploader; the upload is cropped
            and resized to 640x640 with ImageOps.fit and saved under
            ``data/uploaded_data`` before inference.

    The selected image and the model's annotated prediction are shown side by
    side, followed by per-class counts and the description/treatment text
    from ``class_info``.
    """
    img_file = None
    if data_src == 'Sample data':
        # glob order is filesystem-dependent; sort so each slider position
        # maps to the same file on every rerun.
        img_path = sorted(glob.glob('data/sample_images/*'))
        if not img_path:
            st.warning("No sample images found in data/sample_images.")
        elif len(img_path) == 1:
            # A one-position slider is pointless, and some Streamlit versions
            # reject min_value == max_value.
            img_file = img_path[0]
        else:
            img_slider = st.slider("Select a test image.", min_value=1, max_value=len(img_path), step=1)
            img_file = img_path[img_slider - 1]
    else:
        img_bytes = st.sidebar.file_uploader("Upload an image", type=['png', 'jpeg', 'jpg'])
        if img_bytes:
            # Image.save() does not create missing directories.
            os.makedirs("data/uploaded_data", exist_ok=True)
            img_file = "data/uploaded_data/upload." + img_bytes.name.split('.')[-1]
            img = Image.open(img_bytes)
            # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
            img = ImageOps.fit(img, (640, 640), Image.LANCZOS)
            img.save(img_file)
    if img_file:
        col1, col2 = st.columns(2)
        with col1:
            st.image(img_file, caption="Selected Image")
        with col2:
            img, class_counts = infer_image(img_file)
            st.image(img, caption="Model prediction")
            st.markdown("#### Counts of detected classes:")
            for class_name, count in class_counts.items():
                info = class_info.get(class_name, {})
                st.markdown(f"**{class_name}:** *{count}*")
                st.markdown(f"**Information about {class_name}:**\n{textwrap.fill(info.get('desc', ''), width=60)}")
                st.markdown(f"**Treatment for {class_name}:**\n{textwrap.fill(info.get('treatment', ''), width=60)}")
@st.cache_resource
def load_model(path, device):
    """Load custom YOLOv5 weights through torch.hub and move them to *device*.

    Decorated with ``st.cache_resource``, so the load runs once per distinct
    (path, device) pair and later calls return the same cached object.

    Args:
        path: Filesystem path of the ``.pt`` weight file.
        device: Target device name, e.g. ``'cpu'`` or ``'cuda'``.
    """
    # force_reload=True asks torch.hub to fetch the repo again instead of
    # reusing a local copy.
    detector = torch.hub.load('ultralytics/yolov5', 'custom', path=path, force_reload=True)
    detector.to(device)
    print("model to ", device)
    return detector
@st.cache_resource
def download_model(url):
    """Download model weights from *url* into the ``models`` directory.

    Decorated with ``st.cache_resource``, so repeated calls with the same *url*
    reuse the first download.

    Returns:
        The path of the saved file, as returned by ``wget.download``.
    """
    # wget.download only saves *inside* `out` when it is an existing
    # directory; otherwise it writes the download to a file named "models".
    os.makedirs("models", exist_ok=True)
    return wget.download(url, out="models")
def get_user_model():
    """Let the user supply custom weights via sidebar file upload or URL.

    Returns:
        Path to a local ``.pt`` file, or None when nothing usable was given
        (no upload/URL yet, or the URL did not yield a ``.pt`` file).
    """
    model_src = st.sidebar.radio("Model source", ["file upload", "url"])
    model_file = None
    if model_src == "file upload":
        model_bytes = st.sidebar.file_uploader("Upload a model file", type=['pt'])
        if model_bytes:
            os.makedirs("models", exist_ok=True)
            # The file name comes from the client; keep only its base name so
            # it cannot point outside the models/ directory.
            model_file = "models/uploaded_" + os.path.basename(model_bytes.name)
            with open(model_file, 'wb') as out:
                out.write(model_bytes.read())
    else:
        url = st.sidebar.text_input("model url")
        if url:
            model_file_ = download_model(url)
            # endswith also rejects a bare file literally named "pt".
            if model_file_.endswith(".pt"):
                model_file = model_file_
    return model_file
def main():
    """Render the rice-leaf disease dashboard.

    Lets the user pick a YOLOv5 weight file from ``models/``, a device, a
    confidence threshold and optionally a subset of classes, then hands off to
    image_input() for the chosen input source. Rebinds the module-level
    ``model``, ``confidence`` and ``cfg_model_path`` globals.
    """
    global model, confidence, cfg_model_path

    # custom CSS
    # NOTE(review): these selectors look like older Streamlit DOM class names
    # and may have no effect on current releases — confirm. The CSS lines stay
    # at column 0 because Markdown treats 4-space-indented lines as code.
    st.markdown("""
<style>
.reportview-container {
background: #262626;
}
.sidebar .sidebar-content {
background: #262626;
}
h1 {
color: #59d455;
}
.block-container {
color: #f0f0f0;
}
</style>
""", unsafe_allow_html=True)
    st.title("🌾 Rice Leaf Disease Detection Dashboard")
    st.sidebar.title("Settings")

    # Each weight option maps to a same-named .pt file in models/.
    weight_options = ["YOLOv5n", "YOLOv5s", "YOLOv5m", "YOLOv5l", "YOLOv5x"]
    model_src = st.sidebar.radio("Select yolov5 weight file", weight_options)
    cfg_model_path = f"models/{model_src}.pt"

    if not os.path.isfile(cfg_model_path):
        st.warning(f"Model file '{cfg_model_path}' not found. Please add it to the models folder.", icon="⚠️")
        return

    # The CUDA option is only selectable when a GPU is available.
    device_option = st.sidebar.radio("Select Device", ['cpu', 'cuda'], disabled=not torch.cuda.is_available(), index=0)
    # load_model is st.cache_resource-decorated, so the conf/classes set below
    # mutate the shared cached model object.
    model = load_model(cfg_model_path, device_option)
    confidence = st.sidebar.slider('Confidence', min_value=0.1, max_value=1.0, value=.45)

    # model.names is expected to be {id: name}; normalize a list form too.
    names = model.names
    if not isinstance(names, dict):
        names = dict(enumerate(names))
    if st.sidebar.checkbox("Custom Classes"):
        model_names = list(names.values())
        assigned_class = st.sidebar.multiselect("Select Classes", model_names, default=model_names[:1])
        # Map selected names back to their real class ids, not list positions.
        name_to_id = {name: class_id for class_id, name in names.items()}
        model.classes = [name_to_id[name] for name in assigned_class]
    else:
        model.classes = list(names.keys())
    st.sidebar.markdown("---")

    # input src option
    data_src = st.sidebar.radio("Select input source: ", ['Sample data', 'Upload your own data'])
    image_input(data_src)
# Script entry point (Streamlit runs the file with __name__ == "__main__").
if __name__ == "__main__":
    try:
        main()
    except SystemExit:
        # Swallow SystemExit raised during the run so it does not surface as
        # an error. NOTE(review): confirm which caller raises it.
        pass