jacobbarkow commited on
Commit
015558b
·
1 Parent(s): 55286c8

adding full MVP features for diagnosis and treatment

Browse files
Files changed (3) hide show
  1. app.py +100 -8
  2. best_model.pth → model.pth +2 -2
  3. predictor.py +3 -3
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import streamlit as st
2
  import base64
 
3
  from predictor import predict
4
  import torch
5
  import numpy as np
@@ -27,20 +28,93 @@ def add_bg_from_local(image_file):
27
  def header_white_bg(text, fontsize = 40, bold = True):
28
  st.markdown(
29
  f"""
30
- <span style="background:rgba(255, 255, 255, 0.8); font-size:{fontsize}px; font-weight:{"bold" if bold else "normal"}">{text}</span>
31
  """,
32
  unsafe_allow_html=True
33
  )
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  def diagnose_health(file):
36
  prediction = predict(file)
37
- predicted_strings = []
38
- for p in prediction:
39
- predicted_string = f"{p['predicted']}, Probability: {float(p['probability']):.2f}"
40
- predicted_strings.append(predicted_string)
41
- return predicted_strings
42
 
43
  def app():
 
44
  add_bg_from_local('assets/background.png')
45
  header_white_bg(f'<span style="color:green">Plant</span><span style="color:orange">Dx</span><span style="color:green">: Diagnosis in a Snap!</span> ')
46
 
@@ -56,8 +130,26 @@ def app():
56
  if st.button("Get Diagnosis"):
57
  if uploaded_file is not None:
58
  # Diagnose plant health and display results
59
- result = diagnose_health(uploaded_file)
60
- st.success(f"Your plant is {result}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  else:
62
  st.warning("Please upload an image of your plant first")
63
 
 
1
  import streamlit as st
2
  import base64
3
+ import regex as re
4
  from predictor import predict
5
  import torch
6
  import numpy as np
 
28
  def header_white_bg(text, fontsize = 40, bold = True):
29
  st.markdown(
30
  f"""
31
+ <span style="background:rgba(255, 255, 255, 0.8); font-size:{fontsize}px; font-weight:{"bold" if bold else "normal"}; line-height: 2em">{text}</span>
32
  """,
33
  unsafe_allow_html=True
34
  )
35
 
36
+ def header_red_bg(text, fontsize = 40, bold = True):
37
+ st.markdown(
38
+ f"""
39
+ <span style="background:rgba(184, 35, 35, 0.5); font-size:{fontsize}px; font-weight:{"bold" if bold else "normal"}; line-height: 1.25">{text}</span>
40
+ """,
41
+ unsafe_allow_html=True
42
+ )
43
+
44
+ def header_green_bg(text, fontsize = 40, bold = True):
45
+ st.markdown(
46
+ f"""
47
+ <span style="background:rgba(26, 153, 58, 0.5); font-size:{fontsize}px; font-weight:{"bold" if bold else "normal"}; line-height: 1.25">{text}</span>
48
+ """,
49
+ unsafe_allow_html=True
50
+ )
51
+
52
+ def plant_treatment_message(predicted_string):
53
+ if predicted_string == "Apple___Apple_scab":
54
+ return "Remove the infected leaves and fruit and apply a fungicide to prevent it from spreading."
55
+ elif predicted_string == "Apple___Black_rot":
56
+ return "Remove the infected branches and fruit and apply a fungicide to prevent it from spreading."
57
+ elif predicted_string == "Apple___Cedar_apple_rust":
58
+ return "Remove the infected branches and apply a fungicide to prevent it from spreading."
59
+ elif predicted_string == "Cherry_(including_sour)___Powdery_mildew":
60
+ return "Remove the infected leaves and apply a fungicide to prevent it from spreading."
61
+ elif predicted_string == "Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot":
62
+ return "Remove the infected leaves and apply a fungicide to prevent it from spreading."
63
+ elif predicted_string == "Corn_(maize)___Common_rust_":
64
+ return "Remove the infected leaves and apply a fungicide to prevent it from spreading."
65
+ elif predicted_string == "Corn_(maize)___Northern_Leaf_Blight":
66
+ return "Remove the infected leaves and apply a fungicide to prevent it from spreading."
67
+ elif predicted_string == "Grape___Black_rot":
68
+ return "Remove the infected branches and fruit and apply a fungicide to prevent it from spreading."
69
+ elif predicted_string == "Grape___Esca_(Black_Measles)":
70
+ return "Remove the infected branches and apply a fungicide to prevent it from spreading."
71
+ elif predicted_string == "Grape___Leaf_blight_(Isariopsis_Leaf_Spot)":
72
+ return "Remove the infected leaves and apply a fungicide to prevent it from spreading."
73
+ elif predicted_string == "Orange___Haunglongbing_(Citrus_greening)":
74
+ return "Remove the infected branches and apply a pesticide to prevent it from spreading."
75
+ elif predicted_string == "Peach___Bacterial_spot":
76
+ return "Remove the infected leaves and apply a copper fungicide to prevent it from spreading."
77
+ elif predicted_string == "Squash___Powdery_mildew":
78
+ return "This is a fungal disease that can cause white powdery spots on leaves and fruit. Consider removing infected plant parts and treating with a fungicide."
79
+ elif predicted_string == "Strawberry___Leaf_scorch":
80
+ return "This can be caused by drought, sunburn, or fungal diseases. Make sure Consider removing infected plant parts and treating with a fungicide."
81
+ elif predicted_string == "Tomato___Bacterial_spot":
82
+ return "This is a bacterial disease that can cause spots on leaves and fruit. Consider removing infected plant parts and treating with a copper-based fungicide."
83
+ elif predicted_string == "Tomato___Early_blight":
84
+ return "This is a fungal disease that can cause dark spots on leaves and stems. Consider removing infected plant parts and treating with a fungicide."
85
+ elif predicted_string == "Tomato___Late_blight":
86
+ return "This is a fungal disease that can cause rapid decay of foliage and fruit. Consider removing infected plant parts and treating with a fungicide."
87
+ elif predicted_string == "Tomato___Leaf_Mold":
88
+ return "This is a fungal disease that can cause brown spots on leaves. Consider removing infected plant parts and treating with a fungicide."
89
+ elif predicted_string == "Tomato___Septoria_leaf_spot":
90
+ return "This is a fungal disease that can cause brown spots with a yellow halo on leaves. Consider removing infected plant parts and treating with a fungicide."
91
+ elif predicted_string == "Tomato___Spider_mites Two-spotted_spider_mite":
92
+ return "These are tiny pests that can cause yellow spots on leaves and webbing. Consider removing infected plant parts and treating with an insecticide."
93
+ elif predicted_string == "Tomato___Target_Spot":
94
+ return "This is a fungal disease that can cause circular spots with a bullseye pattern on leaves. Consider removing infected plant parts and treating with a fungicide."
95
+ elif predicted_string == "Tomato___Tomato_Yellow_Leaf_Curl_Virus":
96
+ return "This is a viral disease that can cause yellowing and curling of leaves. Consider treating with a fungicide."
97
+
98
+ def clean_prediction(prediction):
99
+ pattern = re.compile('(.*)___(.*)')
100
+ clean_predictions = []
101
+ for p in prediction:
102
+ r = pattern.search(p['predicted'])
103
+ plant = r.groups()[0].replace('_', ' ').lower()
104
+ diagnosis = r.groups()[1].replace('_', ' ').lower()
105
+ treatment = plant_treatment_message(p['predicted']) if diagnosis is not 'healthy' else None
106
+ clean_predictions.append([plant, diagnosis, "{0:.1f}%".format(float(p['probability']) * 100), treatment])
107
+
108
+ clean_predictions.sort(key=lambda x: x[2], reverse=True)
109
+ return clean_predictions
110
+
111
  def diagnose_health(file):
112
  prediction = predict(file)
113
+ clean_predictions = clean_prediction(prediction)
114
+ return clean_predictions
 
 
 
115
 
116
  def app():
117
+
118
  add_bg_from_local('assets/background.png')
119
  header_white_bg(f'<span style="color:green">Plant</span><span style="color:orange">Dx</span><span style="color:green">: Diagnosis in a Snap!</span> ')
120
 
 
130
  if st.button("Get Diagnosis"):
131
  if uploaded_file is not None:
132
  # Diagnose plant health and display results
133
+ results = diagnose_health(uploaded_file)
134
+
135
+ if results[0][1] == 'healthy':
136
+ header_green_bg(f"We believe this is a healthy {results[0][0]} plant with {results[0][2]} confidence. Keep up the good work with proper watering, sunlight, and nutrients.", fontsize=32, bold=False)
137
+ else:
138
+ header_red_bg(f"We believe this is an unhealthy {results[0][0]} plant with {results[0][1]}, with {results[0][2]} confidence. {results[0][3]}", fontsize=32, bold=False)
139
+
140
+ if len(results) > 1:
141
+ header_white_bg("Other potential diagnoses: ", fontsize=24)
142
+
143
+ for p in range(1, len(results)):
144
+ if results[p][1] == 'healthy':
145
+ header_white_bg(
146
+ f"A healthy {results[p][0]} plant, {results[p][2]} confidence.",
147
+ fontsize=20, bold=False)
148
+ else:
149
+ header_white_bg(
150
+ f"An unhealthy {results[p][0]} plant with {results[p][1]}, {results[p][2]} confidence. {results[p][3] if results [p][3] else ''}",
151
+ fontsize=20, bold=False)
152
+
153
  else:
154
  st.warning("Please upload an image of your plant first")
155
 
best_model.pth → model.pth RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:ad9f39b9178b952df9bf40b334b72cee838fd7b0821d6da40f00cd1e3fe637c3
3
- size 21452911
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b141df83b46c70fcf350d4ee34e9d9943c6347b11d9cf9df1619eea5d5d97470
3
+ size 21454657
predictor.py CHANGED
@@ -7,7 +7,7 @@ def predict(image_file):
7
 
8
  #load model with params
9
  model = models.efficientnet_b0(weights=None)
10
- model.load_state_dict(torch.load('best_model.pth', map_location=torch.device('cpu')), strict=False)
11
  device = torch.device('cpu')
12
 
13
  class_names = [
@@ -79,8 +79,8 @@ def predict(image_file):
79
  label_indx = indices[0][j]
80
  # print(label_indx)
81
  class_name = class_names[label_indx]
82
- tmp_dct["predicted"] = class_name
83
- tmp_dct["probability"] = probs[0][j]
84
  tmp_lst.append(tmp_dct)
85
 
86
  # print(f"Prediction {j+1}: label index: {indices[i][j]}, probability: {probs[i][j]:.4f}")
 
7
 
8
  #load model with params
9
  model = models.efficientnet_b0(weights=None)
10
+ model.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')), strict=False)
11
  device = torch.device('cpu')
12
 
13
  class_names = [
 
79
  label_indx = indices[0][j]
80
  # print(label_indx)
81
  class_name = class_names[label_indx]
82
+ tmp_dct["predicted"] = class_name
83
+ tmp_dct["probability"] = probs[0][j]
84
  tmp_lst.append(tmp_dct)
85
 
86
  # print(f"Prediction {j+1}: label index: {indices[i][j]}, probability: {probs[i][j]:.4f}")