Ani14 commited on
Commit
b450aae
Β·
verified Β·
1 Parent(s): 9a1a57e

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +308 -146
src/streamlit_app.py CHANGED
@@ -1,162 +1,324 @@
1
  import streamlit as st
2
  import numpy as np
3
- import torch
4
  import cv2
5
- import os
6
- from PIL import Image
7
  import tempfile
 
8
  from tensorflow.keras.models import load_model
9
  from transformers import pipeline, AutoProcessor, LlavaForConditionalGeneration
10
- from huggingface_hub import login
 
 
11
 
12
- # Streamlit UI setup
13
- st.set_page_config(page_title="SmartHeal Wound Agent", layout="wide")
14
- st.title("🩹 SmartHeal: Agentic Wound Care Assistant")
 
 
 
 
15
 
16
- # ---- Load all models securely with HF token + local weights ----
17
  @st.cache_resource
18
  def load_all_models():
19
- hf_token = os.environ.get("HUGGINGFACE_TOKEN")
20
- if not hf_token:
21
- raise ValueError("❌ HUGGINGFACE_TOKEN not found in environment")
22
-
23
- login(token=hf_token)
24
-
25
- # Detection model from local path
26
- yolo_path = "/app/best.pt"
27
- if not os.path.exists(yolo_path):
28
- raise FileNotFoundError(f"Detection model not found: {yolo_path}")
29
- detection_model = torch.hub.load("ultralytics/yolov5", "custom", path=yolo_path, force_reload=False)
30
-
31
- # Segmentation model from local path
32
- seg_path = "/app/segmentation model.h5"
33
- if not os.path.exists(seg_path):
34
- raise FileNotFoundError(f"Segmentation model not found: {seg_path}")
35
- segmentation_model = load_model(seg_path, compile=False)
36
-
37
- # Classification model from HuggingFace (private model allowed)
38
- classification_pipe = pipeline(
39
- "image-classification",
40
- model="Hemg/Wound-classification",
41
- use_auth_token=hf_token
42
- )
43
-
44
- # Med-Gemma (multimodal LLM)
45
- model_id = "google/medgemma-4b-it"
46
- processor = AutoProcessor.from_pretrained(model_id, token=hf_token)
47
- med_model = LlavaForConditionalGeneration.from_pretrained(
48
- model_id,
49
- torch_dtype=torch.bfloat16,
50
- token=hf_token,
51
- low_cpu_mem_usage=True
52
- ).to("cuda")
53
-
54
- return detection_model, segmentation_model, classification_pipe, med_model, processor
55
-
56
- # ---- Inference Helper Functions ----
57
- def detect_wound(yolo_model, image_cv):
58
- results = yolo_model(image_cv)
59
- boxes = results.xyxy[0].cpu().numpy()
60
- if len(boxes) == 0:
61
- return None, None
62
- x1, y1, x2, y2 = map(int, boxes[0][:4])
63
- return image_cv[y1:y2, x1:x2], (x1, y1, x2, y2)
64
-
65
- def segment_wound(seg_model, region):
66
- resized = cv2.resize(region, (256, 256)) / 255.0
67
- pred = seg_model.predict(np.expand_dims(resized, axis=0))[0]
68
- return (pred[:, :, 0] > 0.5).astype(np.uint8)
69
-
70
- def estimate_area(mask, px_per_cm=38):
71
- return round(np.sum(mask > 0) / (px_per_cm ** 2), 2)
72
-
73
- def classify_wound(pipeline, region):
74
- image = Image.fromarray(region)
75
- with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
76
- image.save(tmp.name)
77
- label = pipeline(tmp.name)[0]["label"]
78
- os.unlink(tmp.name)
79
- return label
80
-
81
- def generate_medgemma_response(image, processor, model, patient_info, area_cm2, wound_type):
82
- messages = [
83
- {
84
- "role": "system",
85
- "content": [{"type": "text", "text": "You are a wound care expert."}]
86
- },
87
- {
88
- "role": "user",
89
- "content": [
90
- {"type": "text", "text": f"""Patient Info:
91
- - Age: {patient_info['age']}
92
- - Diabetic: {patient_info['diabetic']}
93
- - Wound Type: {wound_type}
94
- - Area: {area_cm2} cmΒ²
95
- - Signs of infection: {patient_info['infection']}
96
-
97
- Please provide:
98
- 1. Wound assessment
99
- 2. Recommended treatment
100
- 3. Cleaning & dressing method
101
- 4. Red flags to monitor
102
- 5. Follow-up schedule"""},
103
- {"type": "image", "image": image}
104
- ]
105
- }
106
- ]
107
-
108
- input_ids = processor.apply_chat_template(messages, return_tensors="pt").to(model.device)
109
- output = model.generate(input_ids, max_new_tokens=1000)
110
- response = processor.decode(output[0], skip_special_tokens=True)
111
-
112
- return response.split("ASSISTANT:")[-1].strip()
113
-
114
- # ---- Load models and run UI ----
115
- with st.spinner("πŸ”„ Loading models..."):
116
- yolo_model, seg_model, classify_pipe, med_model, processor = load_all_models()
117
-
118
- uploaded_file = st.file_uploader("πŸ“€ Upload a clear wound image", type=["jpg", "jpeg", "png"])
119
- with st.form("patient_form"):
120
- age = st.number_input("Patient Age", min_value=1, max_value=120)
121
- diabetic = st.radio("Diabetic?", ["Yes", "No"])
122
- infection = st.radio("Visible infection?", ["Yes", "No"])
123
- submit = st.form_submit_button("πŸš€ Analyze")
124
-
125
- if uploaded_file and submit:
126
- image = Image.open(uploaded_file).convert("RGB")
127
- image_cv = np.array(image)
128
-
129
- st.image(image, caption="Uploaded Image", use_column_width=True)
130
- with st.spinner("🧠 Analyzing image..."):
131
- region, box = detect_wound(yolo_model, image_cv)
132
- if region is None:
133
- st.error("❌ No wound detected.")
134
  st.stop()
135
 
136
- mask = segment_wound(seg_model, region)
137
- area_cm2 = estimate_area(mask)
138
- wound_type = classify_wound(classify_pipe, region)
139
-
140
- response = generate_medgemma_response(
141
- image=image,
142
- processor=processor,
143
- model=med_model,
144
- patient_info={"age": age, "diabetic": diabetic, "infection": infection},
145
- area_cm2=area_cm2,
146
- wound_type=wound_type
 
 
 
 
 
 
 
147
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
- st.success("βœ… Analysis Complete")
150
- st.markdown("### πŸ“‹ Med-Gemma Expert Recommendation")
151
- st.info(response)
 
 
 
 
152
 
153
- # Optional metrics & image
154
- x1, y1, x2, y2 = box
155
- overlay = image_cv.copy()
156
- resized_mask = cv2.resize(mask, (x2 - x1, y2 - y1))
157
- overlay[y1:y2, x1:x2][resized_mask > 0] = [255, 0, 0]
158
- st.image(overlay, caption="Detection + Segmentation Overlay", use_column_width=True)
159
- st.metric("Wound Area", f"{area_cm2} cmΒ²")
160
- st.metric("Wound Type", wound_type)
161
 
162
- st.download_button("πŸ“₯ Download Report", response, file_name="wound_report.txt")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  import numpy as np
 
3
  import cv2
4
+ import torch
 
5
  import tempfile
6
+ from PIL import Image
7
  from tensorflow.keras.models import load_model
8
  from transformers import pipeline, AutoProcessor, LlavaForConditionalGeneration
9
+ import io
10
+ import os # Import the os module to access environment variables
11
+ from ultralytics import YOLO # Import YOLO from ultralytics
12
 
13
+ # --- Page Configuration (Best practice: call this first) ---
14
+ st.set_page_config(
15
+ page_title="SmartHeal Wound Care Agent",
16
+ page_icon="🩹",
17
+ layout="wide",
18
+ initial_sidebar_state="expanded"
19
+ )
20
 
21
+ # --- Model Loading (Cached for performance) ---
22
  @st.cache_resource
23
  def load_all_models():
24
+ """Loads all required models and pipelines into memory once."""
25
+ try:
26
+ # Get Hugging Face token from environment variable
27
+ hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")
28
+ if not hf_token:
29
+ st.error("Fatal Error: Hugging Face token not found in environment variables (HF_TOKEN or HUGGING_FACE_HUB_TOKEN).")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  st.stop()
31
 
32
+ # YOLOv8 detection model (using user's specified path)
33
+ detection_model = YOLO("/home/ubuntu/upload/best(1).pt") # Load YOLOv8 model
34
+
35
+ # Segmentation model (using user's specified path)
36
+ segmentation_model = load_model("/home/ubuntu/upload/segmentation_model.h5", compile=False)
37
+
38
+ # Classification model (using user's specified model ID)
39
+ # Some pipelines might require token for private models or rate limits
40
+ classification_pipe = pipeline("image-classification", model="Hemg/Wound-classification", token=hf_token)
41
+
42
+ # Med-Gemma for analysis (using user's specified model ID)
43
+ medgemma_model_id = "google/medgemma-4b-it"
44
+ medgemma_processor = AutoProcessor.from_pretrained(medgemma_model_id, token=hf_token)
45
+ medgemma_model = LlavaForConditionalGeneration.from_pretrained(
46
+ medgemma_model_id,
47
+ torch_dtype=torch.bfloat16,
48
+ low_cpu_mem_usage=True,
49
+ token=hf_token # Pass the token here
50
  )
51
+ medgemma_model.to("cuda") # Move model to GPU
52
+
53
+ return detection_model, segmentation_model, classification_pipe, medgemma_model, medgemma_processor
54
+ except Exception as e:
55
+ st.error(f"Fatal Error: Could not load models. Please check model paths, dependencies, and Hugging Face token. Details: {e}")
56
+ st.stop()
57
+
58
+
59
+ # --- Agent Class ---
60
+ class WoundCareAgent:
61
+ """An agentic class to encapsulate the wound care analysis pipeline."""
62
+ def __init__(self, models):
63
+ self.yolo_model, self.seg_model, self.classify_pipe, self.medgemma_model, self.medgemma_processor = models
64
+ self.px_per_cm = 38 # Example value, should be calibrated for real-world use
65
+
66
+ def detect_wound(self, image_cv):
67
+ """Detects the wound region using YOLOv8."""
68
+ st.session_state.messages.append({"role": "assistant", "content": "Detecting wound..."})
69
+ results = self.yolo_model(image_cv) # Use YOLOv8 model directly
70
+ boxes = results[0].boxes.xyxy.cpu().numpy() # Access boxes from YOLOv8 results
71
+ if len(boxes) == 0:
72
+ return None, None
73
+ # Assuming the largest bounding box is the wound (or the first detected)
74
+ box = boxes[0]
75
+ x1, y1, x2, y2 = map(int, box[:4])
76
+ detected_region = image_cv[y1:y2, x1:x2]
77
+ return detected_region, (x1, y1, x2, y2)
78
+
79
+ def segment_wound(self, detected_region):
80
+ """Segments the wound from the detected region using the provided segmentation model."""
81
+ st.session_state.messages.append({"role": "assistant", "content": "Segmenting wound area..."})
82
+ # Resize for segmentation model input
83
+ resized = cv2.resize(detected_region, (256, 256)) / 255.0
84
+ input_tensor = np.expand_dims(resized, axis=0)
85
+ pred_mask = self.seg_model.predict(input_tensor)[0]
86
+ binary_mask = (pred_mask[:, :, 0] > 0.5).astype(np.uint8)
87
+ return binary_mask
88
+
89
+ def estimate_area(self, mask):
90
+ """Estimates the area of the wound from the mask."""
91
+ st.session_state.messages.append({"role": "assistant", "content": "Estimating wound area..."})
92
+ pixel_area = np.sum(mask > 0)
93
+ area_cm2 = pixel_area / (self.px_per_cm ** 2)
94
+ return round(area_cm2, 2)
95
+
96
+ def classify_wound(self, detected_region):
97
+ """Classifies the type of the wound using the provided classification pipeline."""
98
+ st.session_state.messages.append({"role": "assistant", "content": "Classifying wound type..."})
99
+ try:
100
+ # Convert numpy array to PIL Image for the pipeline
101
+ pil_image = Image.fromarray(detected_region)
102
+ # Save to a temporary file for the pipeline, as it expects a file path or PIL Image
103
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
104
+ pil_image.save(tmp.name)
105
+ tmp_path = tmp.name
106
+ result = self.classify_pipe(tmp_path)
107
+ import os
108
+ os.unlink(tmp_path) # Clean up temporary file
109
+ return result[0]["label"]
110
+ except Exception as e:
111
+ st.warning(f"Could not classify wound type: {e}")
112
+ return "Unknown"
113
+
114
+ def generate_recommendations(self, image, patient_info, analysis_results):
115
+ """Generates a detailed assessment and treatment plan using Med-Gemma."""
116
+ st.session_state.messages.append({"role": "assistant", "content": "Generating expert recommendations with Med-Gemma..."})
117
+
118
+ # Prepare the image for Med-Gemma (ensure it's a PIL Image)
119
+ if not isinstance(image, Image.Image):
120
+ image = Image.fromarray(image)
121
+
122
+ # Construct the prompt for Med-Gemma
123
+ prompt_text = f"""Patient Info:
124
+ - Age: {patient_info["age"]}
125
+ - Diabetic: {patient_info["diabetic"]}
126
+ - Wound Type: {analysis_results["wound_type"]}
127
+ - Area: {analysis_results["area_cm2"]}
128
+ - Signs of infection: {patient_info["infection"]}
129
+
130
+ Please act as a highly experienced wound care specialist. Provide a comprehensive wound assessment and a detailed treatment plan. Structure your response clearly with the following sections:
131
+
132
+ 1. **Wound Assessment:** Describe the wound's characteristics, potential causes, and current state based on the image and provided data.
133
+ 2. **Recommended Treatment Plan:** Outline a primary course of action, including general principles of wound management.
134
+ 3. **Cleaning and Dressing Protocol:** Provide specific, step-by-step instructions for wound cleaning and appropriate dressing choices.
135
+ 4. **Red Flags & When to Seek Professional Medical Attention:** List critical signs and symptoms that indicate complications or require immediate consultation with a doctor or wound care nurse.
136
+ 5. **Follow-up Schedule:** Suggest a realistic timeline for wound reassessment and monitoring progress.
137
+
138
+ **Important Disclaimer:** This information is for educational purposes only and should not replace professional medical advice. Always consult a qualified healthcare provider for diagnosis and treatment.
139
+ """
140
+
141
+ # Med-Gemma expects messages in a specific format for multi-modal input
142
+ messages = [
143
+ {
144
+ "role": "system",
145
+ "content": [{"type": "text", "text": "You are a wound care expert."}]
146
+ },
147
+ {
148
+ "role": "user",
149
+ "content": [
150
+ {"type": "text", "text": prompt_text},
151
+ {"type": "image", "image": image}
152
+ ]
153
+ }
154
+ ]
155
+
156
+ # Convert messages to input_ids using the processor's chat template
157
+ input_ids = self.medgemma_processor.apply_chat_template(messages, return_tensors="pt").to(self.medgemma_model.device)
158
+
159
+ # Generate response
160
+ output = self.medgemma_model.generate(input_ids, max_new_tokens=1000, do_sample=True, temperature=0.7)
161
+ response = self.medgemma_processor.decode(output[0], skip_special_tokens=True)
162
+
163
+ # Extract only the assistant's response part
164
+ if "ASSISTANT:" in response:
165
+ assistant_response = response.split("ASSISTANT:", 1)[1].strip()
166
+ else:
167
+ assistant_response = response.strip()
168
+
169
+ return assistant_response
170
+
171
+ def run_full_analysis(self, image, patient_info):
172
+ """Executes the entire analysis pipeline."""
173
+ st.session_state.messages.append({"role": "assistant", "content": "Starting analysis..."})
174
+ image_cv = np.array(image.convert("RGB"))
175
+
176
+ try:
177
+ detected_region, box = self.detect_wound(image_cv)
178
+ if detected_region is None:
179
+ st.error("Agent Error: No wound could be detected in the image. Please try another image.")
180
+ st.session_state.clear()
181
+ return None
182
+ except Exception as e:
183
+ st.error(f"Error during wound detection: {e}")
184
+ st.session_state.clear()
185
+ return None
186
 
187
+ try:
188
+ mask_resized = self.segment_wound(detected_region)
189
+ area_cm2 = self.estimate_area(mask_resized)
190
+ except Exception as e:
191
+ st.error(f"Error during wound segmentation or area estimation: {e}")
192
+ st.session_state.clear()
193
+ return None
194
 
195
+ try:
196
+ wound_type = self.classify_wound(detected_region)
197
+ except Exception as e:
198
+ st.error(f"Error during wound classification: {e}")
199
+ st.session_state.clear()
200
+ return None
 
 
201
 
202
+ analysis_results = {
203
+ "box": box,
204
+ "detected_region": detected_region,
205
+ "mask": mask_resized,
206
+ "area_cm2": area_cm2,
207
+ "wound_type": wound_type,
208
+ }
209
+
210
+ try:
211
+ recommendations = self.generate_recommendations(image, patient_info, analysis_results)
212
+ analysis_results["recommendations"] = recommendations
213
+ except Exception as e:
214
+ st.error(f"Error during recommendation generation: {e}")
215
+ st.session_state.clear()
216
+ return None
217
+
218
+ st.session_state.messages.append({"role": "assistant", "content": "Analysis complete. See results below."})
219
+ return analysis_results
220
+
221
+
222
+ # --- UI Layout ---
223
+ st.title("🩹 SmartHeal: The Agentic Wound Care Assistant")
224
+
225
+ # Initialize session state
226
+ if "analysis_results" not in st.session_state:
227
+ st.session_state.analysis_results = None
228
+ if "messages" not in st.session_state:
229
+ st.session_state.messages = []
230
+
231
+ # --- Sidebar for Inputs ---
232
+ with st.sidebar:
233
+ st.header("πŸ“‹ Patient & Image Input")
234
+ uploaded_file = st.file_uploader("1. Upload a clear wound image", type=["jpg", "jpeg", "png"])
235
+
236
+ with st.form("patient_form"):
237
+ st.write("2. Enter Patient Details")
238
+ age = st.number_input("Patient Age", min_value=1, max_value=120, value=50)
239
+ diabetic = st.radio("Is the patient diabetic?", ["No", "Yes"], index=0)
240
+ infection = st.radio("Are there visible signs of infection (e.g., pus, redness, swelling)?", ["No", "Yes"], index=0)
241
+
242
+ col1, col2 = st.columns(2)
243
+ with col1:
244
+ submitted = st.form_submit_button("πŸš€ Analyze Wound", use_container_width=True)
245
+ with col2:
246
+ cleared = st.form_submit_button("❌ Clear", use_container_width=True)
247
+
248
+ if cleared:
249
+ st.session_state.analysis_results = None
250
+ st.session_state.messages = []
251
+ st.rerun()
252
+
253
+ # --- Main Content Area ---
254
+ if submitted and uploaded_file:
255
+ # Load models and instantiate agent
256
+ models = load_all_models()
257
+ agent = WoundCareAgent(models)
258
+
259
+ # Store patient info
260
+ patient_info = {"age": age, "diabetic": diabetic, "infection": infection}
261
+
262
+ # Open image
263
+ image = Image.open(uploaded_file)
264
+
265
+ # Clear previous results and run new analysis
266
+ st.session_state.analysis_results = None
267
+ st.session_state.messages = [] # Clear messages for new analysis
268
+ st.session_state.messages.append({"role": "user", "content": "Analyzing the uploaded wound image..."})
269
+
270
+ with st.spinner("The SmartHeal Agent is at work..."):
271
+ st.session_state.analysis_results = agent.run_full_analysis(image, patient_info)
272
+
273
+ # Display chat messages from the agent's process
274
+ for message in st.session_state.messages:
275
+ with st.chat_message(message["role"]):
276
+ st.markdown(message["content"])
277
+
278
+ # Display final results in tabs if analysis is complete
279
+ if st.session_state.analysis_results:
280
+ results = st.session_state.analysis_results
281
+ image = Image.open(uploaded_file) # Re-open image for display
282
+ image_cv = np.array(image.convert("RGB"))
283
+
284
+ st.header("βœ… Analysis Complete")
285
+ tab1, tab2, tab3 = st.tabs(["πŸ“ **Expert Recommendations**", "πŸ”¬ **Vision Analysis**", "πŸ“„ **Download Report**"])
286
+
287
+ with tab1:
288
+ st.markdown(results["recommendations"])
289
+
290
+ with tab2:
291
+ st.subheader("Wound Detection & Segmentation")
292
+
293
+ # Create overlay
294
+ x1, y1, x2, y2 = results["box"]
295
+ overlay_image = image_cv.copy()
296
+
297
+ # Ensure mask_resized is the correct size for the detected region
298
+ mask_for_overlay = cv2.resize(results["mask"], (x2 - x1, y2 - y1), interpolation=cv2.INTER_NEAREST)
299
+
300
+ # Create a colored mask for blending
301
+ colored_mask_region = np.zeros_like(overlay_image[y1:y2, x1:x2])
302
+ colored_mask_region[mask_for_overlay > 0] = [255, 0, 0] # Red color for wound
303
+
304
+ # Blend the original detected region with the colored mask
305
+ overlay_image[y1:y2, x1:x2] = cv2.addWeighted(overlay_image[y1:y2, x1:x2], 0.7, colored_mask_region, 0.3, 0)
306
+
307
+ # Draw bounding box
308
+ cv2.rectangle(overlay_image, (x1, y1), (x2, y2), (0, 255, 0), 2) # Green box
309
+
310
+ st.image(overlay_image, caption="Detected Wound with Segmentation Overlay", use_column_width=True)
311
+ st.metric(label="Estimated Wound Area", value=f"{results['area_cm2']} cmΒ²")
312
+ st.metric(label="Classified Wound Type", value=f"{results['wound_type']}")
313
+
314
+ with tab3:
315
+ st.subheader("Download Full Report")
316
+ st.download_button(
317
+ label="πŸ“₯ Download as Text File",
318
+ data=results["recommendations"],
319
+ file_name=f"wound_report_{uploaded_file.name.split('.')[0]}.txt",
320
+ mime="text/plain"
321
+ )
322
+ else:
323
+ # Show information prompting the user to upload an image and fill in patient details.
324
+ st.info("Please upload an image and patient details in the sidebar to start.")