edcelbogs commited on
Commit
cc23aa1
·
verified ·
1 Parent(s): 186fa1b

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +183 -184
src/streamlit_app.py CHANGED
@@ -1,185 +1,184 @@
1
- import streamlit as st
2
- from inference_sdk import InferenceHTTPClient
3
- import requests
4
- from PIL import Image
5
- import tempfile
6
- import os
7
- import cv2
8
- import numpy as np
9
-
10
- # -----------------------------
11
- # PAGE CONFIG
12
- # -----------------------------
13
- st.set_page_config(page_title="Cassava Disease Detection", layout="centered")
14
-
15
- # -----------------------------
16
- # CONFIGURATION
17
- # -----------------------------
18
- ROBOFLOW_API_KEY = st.secrets["ROBOFLOW_API_KEY"]
19
- OPENROUTER_API_KEY = st.secrets["OPENROUTER_API_KEY"]
20
-
21
- MODEL_ID = "cassavadisease/1"
22
- ROBOFLOW_API_URL = "https://serverless.roboflow.com"
23
-
24
- # -----------------------------
25
- # INITIALIZE ROBOFLOW CLIENT
26
- # -----------------------------
27
- CLIENT = InferenceHTTPClient(
28
- api_url=ROBOFLOW_API_URL,
29
- api_key=ROBOFLOW_API_KEY
30
- )
31
-
32
- # -----------------------------
33
- # FUNCTION: AI EXPLANATION
34
- # -----------------------------
35
- def get_ai_explanation(disease_name):
36
-
37
- prompt = f"""
38
- Explain briefly the cassava disease: {disease_name}.
39
- Include:
40
- - Cause
41
- - Main Symptoms
42
- - Prevention
43
- - Treatment
44
- Keep answer short.
45
- """
46
-
47
- response = requests.post(
48
- "https://openrouter.ai/api/v1/chat/completions",
49
- headers={
50
- "Authorization": f"Bearer {OPENROUTER_API_KEY}",
51
- "Content-Type": "application/json",
52
- "HTTP-Referer": "http://localhost:8501",
53
- "X-Title": "Cassava Disease Detection App"
54
- },
55
- json={
56
- "model": "minimax/minimax-m2.5",
57
- "messages": [
58
- {"role": "user", "content": prompt}
59
- ],
60
- "max_tokens": 800,
61
- "temperature": 0.3
62
- }
63
- )
64
-
65
- if response.status_code != 200:
66
- return f"OpenRouter API Error:\n{response.text}"
67
-
68
- result = response.json()
69
-
70
- if "choices" not in result:
71
- return f"Unexpected API Response:\n{result}"
72
-
73
- return result["choices"][0]["message"]["content"]
74
-
75
-
76
- # -----------------------------
77
- # UI
78
- # -----------------------------
79
- st.title("Cassava Disease Detection Web App")
80
- st.write("Upload or capture a cassava leaf image for disease detection.")
81
-
82
- source = st.radio("Select Image Source:", ["Upload Image", "Use Camera"])
83
-
84
- image = None
85
-
86
- if source == "Upload Image":
87
- uploaded_file = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"])
88
- if uploaded_file:
89
- image = Image.open(uploaded_file)
90
-
91
- elif source == "Use Camera":
92
- camera_photo = st.camera_input("Take a picture of the cassava leaf")
93
- if camera_photo:
94
- image = Image.open(camera_photo)
95
-
96
- # -----------------------------
97
- # MAIN PROCESS
98
- # -----------------------------
99
- if image is not None:
100
-
101
- st.image(image, caption="Captured Image", use_container_width=True)
102
-
103
- # Save temp image
104
- with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
105
- image.save(tmp.name)
106
- temp_path = tmp.name
107
-
108
- # Roboflow inference
109
- with st.spinner("Analyzing image..."):
110
- result = CLIENT.infer(temp_path, model_id=MODEL_ID)
111
-
112
- os.remove(temp_path)
113
-
114
- predictions = result.get("predictions", [])
115
-
116
- if predictions:
117
-
118
- img_cv = np.array(image)
119
- img_cv = cv2.cvtColor(img_cv, cv2.COLOR_RGB2BGR)
120
-
121
- for pred in predictions:
122
-
123
- x = pred["x"]
124
- y = pred["y"]
125
- w = pred["width"]
126
- h = pred["height"]
127
- label = pred["class"]
128
- confidence = round(pred["confidence"] * 100, 2)
129
-
130
- # Convert center to corner format
131
- x1 = int(x - w / 2)
132
- y1 = int(y - h / 2)
133
- x2 = int(x + w / 2)
134
- y2 = int(y + h / 2)
135
-
136
- # Draw bounding box
137
- cv2.rectangle(img_cv, (x1, y1), (x2, y2), (0, 255, 0), 2)
138
-
139
- # Label background
140
- cv2.rectangle(img_cv, (x1, y1 - 30), (x1 + 250, y1), (0, 255, 0), -1)
141
-
142
- # Label text
143
- cv2.putText(
144
- img_cv,
145
- f"{label} ({confidence}%)",
146
- (x1 + 5, y1 - 10),
147
- cv2.FONT_HERSHEY_SIMPLEX,
148
- 0.6,
149
- (0, 0, 0),
150
- 2
151
- )
152
-
153
- # Convert back to RGB
154
- img_display = cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB)
155
- st.image(img_display, caption="Detected & Labeled Image", use_container_width=True)
156
-
157
- # Get highest confidence prediction
158
- top_prediction = max(predictions, key=lambda x: x["confidence"])
159
- disease_name = top_prediction["class"]
160
- confidence = round(top_prediction["confidence"] * 100, 2)
161
-
162
- st.success(f"Detected: **{disease_name}**")
163
- st.info(f"Confidence: {confidence}%")
164
-
165
- # AI Explanation
166
- with st.spinner("Generating disease explanation..."):
167
- explanation = get_ai_explanation(disease_name)
168
-
169
- st.markdown("## 📘 Disease Information")
170
- st.write(explanation)
171
-
172
- else:
173
- st.warning("No cassava leaf detected.")
174
-
175
-
176
- # -----------------------------
177
- # FOOTER
178
- # -----------------------------
179
- st.markdown("---")
180
- st.markdown(
181
- "<div style='text-align: center; font-size: 14px;'>"
182
- "Developed by <b>Edcel Bogay</b>"
183
- "</div>",
184
- unsafe_allow_html=True
185
  )
 
1
+ import streamlit as st
2
+ import requests
3
+ from PIL import Image
4
+ import tempfile
5
+ import os
6
+ import cv2
7
+ import numpy as np
8
+
9
+ # -----------------------------
10
+ # PAGE CONFIG
11
+ # -----------------------------
12
+ st.set_page_config(page_title="Cassava Disease Detection", layout="centered")
13
+
14
+ # -----------------------------
15
+ # CONFIGURATION
16
+ # -----------------------------
17
+ ROBOFLOW_API_KEY = st.secrets["ROBOFLOW_API_KEY"]
18
+ OPENROUTER_API_KEY = st.secrets["OPENROUTER_API_KEY"]
19
+
20
+ MODEL_ID = "cassavadisease/1"
21
+ ROBOFLOW_API_URL = "https://serverless.roboflow.com"
22
+
23
+ # -----------------------------
24
+ # INITIALIZE ROBOFLOW CLIENT
25
+ # -----------------------------
26
+ CLIENT = InferenceHTTPClient(
27
+ api_url=ROBOFLOW_API_URL,
28
+ api_key=ROBOFLOW_API_KEY
29
+ )
30
+
31
+ # -----------------------------
32
+ # FUNCTION: AI EXPLANATION
33
+ # -----------------------------
34
+ def get_ai_explanation(disease_name):
35
+
36
+ prompt = f"""
37
+ Explain briefly the cassava disease: {disease_name}.
38
+ Include:
39
+ - Cause
40
+ - Main Symptoms
41
+ - Prevention
42
+ - Treatment
43
+ Keep answer short.
44
+ """
45
+
46
+ response = requests.post(
47
+ "https://openrouter.ai/api/v1/chat/completions",
48
+ headers={
49
+ "Authorization": f"Bearer {OPENROUTER_API_KEY}",
50
+ "Content-Type": "application/json",
51
+ "HTTP-Referer": "http://localhost:8501",
52
+ "X-Title": "Cassava Disease Detection App"
53
+ },
54
+ json={
55
+ "model": "minimax/minimax-m2.5",
56
+ "messages": [
57
+ {"role": "user", "content": prompt}
58
+ ],
59
+ "max_tokens": 800,
60
+ "temperature": 0.3
61
+ }
62
+ )
63
+
64
+ if response.status_code != 200:
65
+ return f"OpenRouter API Error:\n{response.text}"
66
+
67
+ result = response.json()
68
+
69
+ if "choices" not in result:
70
+ return f"Unexpected API Response:\n{result}"
71
+
72
+ return result["choices"][0]["message"]["content"]
73
+
74
+
75
+ # -----------------------------
76
+ # UI
77
+ # -----------------------------
78
+ st.title("Cassava Disease Detection Web App")
79
+ st.write("Upload or capture a cassava leaf image for disease detection.")
80
+
81
+ source = st.radio("Select Image Source:", ["Upload Image", "Use Camera"])
82
+
83
+ image = None
84
+
85
+ if source == "Upload Image":
86
+ uploaded_file = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"])
87
+ if uploaded_file:
88
+ image = Image.open(uploaded_file)
89
+
90
+ elif source == "Use Camera":
91
+ camera_photo = st.camera_input("Take a picture of the cassava leaf")
92
+ if camera_photo:
93
+ image = Image.open(camera_photo)
94
+
95
+ # -----------------------------
96
+ # MAIN PROCESS
97
+ # -----------------------------
98
+ if image is not None:
99
+
100
+ st.image(image, caption="Captured Image", use_container_width=True)
101
+
102
+ # Save temp image
103
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
104
+ image.save(tmp.name)
105
+ temp_path = tmp.name
106
+
107
+ # Roboflow inference
108
+ with st.spinner("Analyzing image..."):
109
+ result = CLIENT.infer(temp_path, model_id=MODEL_ID)
110
+
111
+ os.remove(temp_path)
112
+
113
+ predictions = result.get("predictions", [])
114
+
115
+ if predictions:
116
+
117
+ img_cv = np.array(image)
118
+ img_cv = cv2.cvtColor(img_cv, cv2.COLOR_RGB2BGR)
119
+
120
+ for pred in predictions:
121
+
122
+ x = pred["x"]
123
+ y = pred["y"]
124
+ w = pred["width"]
125
+ h = pred["height"]
126
+ label = pred["class"]
127
+ confidence = round(pred["confidence"] * 100, 2)
128
+
129
+ # Convert center to corner format
130
+ x1 = int(x - w / 2)
131
+ y1 = int(y - h / 2)
132
+ x2 = int(x + w / 2)
133
+ y2 = int(y + h / 2)
134
+
135
+ # Draw bounding box
136
+ cv2.rectangle(img_cv, (x1, y1), (x2, y2), (0, 255, 0), 2)
137
+
138
+ # Label background
139
+ cv2.rectangle(img_cv, (x1, y1 - 30), (x1 + 250, y1), (0, 255, 0), -1)
140
+
141
+ # Label text
142
+ cv2.putText(
143
+ img_cv,
144
+ f"{label} ({confidence}%)",
145
+ (x1 + 5, y1 - 10),
146
+ cv2.FONT_HERSHEY_SIMPLEX,
147
+ 0.6,
148
+ (0, 0, 0),
149
+ 2
150
+ )
151
+
152
+ # Convert back to RGB
153
+ img_display = cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB)
154
+ st.image(img_display, caption="Detected & Labeled Image", use_container_width=True)
155
+
156
+ # Get highest confidence prediction
157
+ top_prediction = max(predictions, key=lambda x: x["confidence"])
158
+ disease_name = top_prediction["class"]
159
+ confidence = round(top_prediction["confidence"] * 100, 2)
160
+
161
+ st.success(f"Detected: **{disease_name}**")
162
+ st.info(f"Confidence: {confidence}%")
163
+
164
+ # AI Explanation
165
+ with st.spinner("Generating disease explanation..."):
166
+ explanation = get_ai_explanation(disease_name)
167
+
168
+ st.markdown("## 📘 Disease Information")
169
+ st.write(explanation)
170
+
171
+ else:
172
+ st.warning("No cassava leaf detected.")
173
+
174
+
175
+ # -----------------------------
176
+ # FOOTER
177
+ # -----------------------------
178
+ st.markdown("---")
179
+ st.markdown(
180
+ "<div style='text-align: center; font-size: 14px;'>"
181
+ "Developed by <b>Edcel Bogay</b>"
182
+ "</div>",
183
+ unsafe_allow_html=True
 
184
  )