Spaces:
Sleeping
Sleeping
File size: 14,686 Bytes
ce45ba9 293bc49 d451e55 0ab2b4c b450aae 0ab2b4c b450aae d451e55 0ab2b4c b450aae d451e55 b450aae d451e55 b450aae d451e55 b450aae 0ab2b4c b450aae 0ab2b4c b450aae 0ab2b4c b450aae d451e55 b450aae 293bc49 b450aae |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 |
import os
# Streamlit, matplotlib and other libraries that keep state under ~/.config
# need a writable home. Redirect both HOME and matplotlib's config directory to
# the current working directory (swap in "/tmp" if cwd is not writable).
_writable_root = os.getcwd()  # or "/tmp"
os.environ["HOME"] = _writable_root
os.environ["MPLCONFIGDIR"] = os.path.join(_writable_root, ".config")  # matplotlib writes its config/cache here
import streamlit as st
import numpy as np
import cv2
import torch
import tempfile
from PIL import Image
from tensorflow.keras.models import load_model
from transformers import pipeline, AutoProcessor, LlavaForConditionalGeneration
import io
import os # Import the os module to access environment variables
from ultralytics import YOLO # Import YOLO from ultralytics
# --- Page Configuration (Best practice: call this first) ---
# Configure the browser tab title/icon and the overall layout. The icon was
# previously stored as mis-decoded bytes ("π©Ή"); it is the 🩹 bandage emoji.
st.set_page_config(
    page_title="SmartHeal Wound Care Agent",
    page_icon="🩹",
    layout="wide",
    initial_sidebar_state="expanded",
)
# --- Model Loading (Cached for performance) ---
@st.cache_resource
def load_all_models():
    """Load every model and pipeline the agent needs, once per server process.

    ``st.cache_resource`` keeps the returned objects alive across reruns, so the
    expensive loads below happen only on the first call.

    Returns:
        tuple: ``(detection_model, segmentation_model, classification_pipe,
        medgemma_model, medgemma_processor)`` -- the order in which
        ``WoundCareAgent.__init__`` unpacks them.

    On a missing Hugging Face token or any loading failure, an error is shown
    and the script run is halted with ``st.stop()``.
    """
    try:
        # MedGemma (google/medgemma-4b-it) is a Gemma 3 image-text-to-text
        # checkpoint; LlavaForConditionalGeneration cannot load its weights
        # correctly, so use the Auto class that resolves the right architecture.
        # Imported here (inside the try) so an outdated transformers version is
        # reported through the same error path as other load failures.
        from transformers import AutoModelForImageTextToText

        # Get Hugging Face token from environment variable
        hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")
        if not hf_token:
            st.error("Fatal Error: Hugging Face token not found in environment variables (HF_TOKEN or HUGGING_FACE_HUB_TOKEN).")
            st.stop()

        # YOLOv8 detection model (using user's specified path)
        detection_model = YOLO("/home/ubuntu/upload/best(1).pt")

        # Segmentation model (using user's specified path)
        segmentation_model = load_model("/home/ubuntu/upload/segmentation_model.h5", compile=False)

        # Classification model (using user's specified model ID); the token
        # covers private models and rate limits.
        classification_pipe = pipeline("image-classification", model="Hemg/Wound-classification", token=hf_token)

        # Med-Gemma for analysis (using user's specified model ID)
        medgemma_model_id = "google/medgemma-4b-it"
        medgemma_processor = AutoProcessor.from_pretrained(medgemma_model_id, token=hf_token)
        medgemma_model = AutoModelForImageTextToText.from_pretrained(
            medgemma_model_id,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            token=hf_token,
        )
        # Use the GPU when present; a hard-coded "cuda" crashes on CPU-only hosts.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        medgemma_model.to(device)

        return detection_model, segmentation_model, classification_pipe, medgemma_model, medgemma_processor
    except Exception as e:
        st.error(f"Fatal Error: Could not load models. Please check model paths, dependencies, and Hugging Face token. Details: {e}")
        st.stop()
# --- Agent Class ---
class WoundCareAgent:
    """Run the wound-care analysis pipeline step by step.

    Pipeline: detect the wound (YOLO) -> segment it -> estimate its area ->
    classify its type -> ask Med-Gemma for an assessment. Each step appends a
    progress message to ``st.session_state.messages`` so the UI can replay it.
    """

    def __init__(self, models):
        """Store the models returned by ``load_all_models``.

        Args:
            models: Tuple ``(yolo_model, seg_model, classify_pipe,
                medgemma_model, medgemma_processor)`` in exactly that order.
        """
        (
            self.yolo_model,
            self.seg_model,
            self.classify_pipe,
            self.medgemma_model,
            self.medgemma_processor,
        ) = models
        # Scale used to convert mask pixel counts into cm^2.
        self.px_per_cm = 38  # Example value, should be calibrated for real-world use

    @staticmethod
    def _select_box(boxes_xyxy, confidences, image_shape):
        """Choose one detection box and clamp it to the image bounds.

        Args:
            boxes_xyxy: Array-like of N boxes as ``(x1, y1, x2, y2)`` pixels.
            confidences: Array-like of N scores; the highest-scoring box wins.
                If ``None`` (or its length does not match), the first box is used.
            image_shape: Shape of the image the boxes refer to; only the first
                two entries (height, width) are read.

        Returns:
            ``(x1, y1, x2, y2)`` as Python ints, or ``None`` when there are no
            boxes or the chosen box is empty after clamping.
        """
        boxes = np.asarray(boxes_xyxy).reshape(-1, 4)
        if len(boxes) == 0:
            return None
        if confidences is not None and len(confidences) == len(boxes):
            index = int(np.argmax(confidences))
        else:
            index = 0
        height, width = image_shape[:2]
        x1, y1, x2, y2 = map(int, boxes[index][:4])
        x1, x2 = max(0, min(x1, width)), max(0, min(x2, width))
        y1, y2 = max(0, min(y1, height)), max(0, min(y2, height))
        if x2 <= x1 or y2 <= y1:
            return None
        return x1, y1, x2, y2

    @staticmethod
    def _mask_area_cm2(mask, px_per_cm, region_shape=None):
        """Convert the foreground pixel count of ``mask`` into cm^2.

        Args:
            mask: 2-D array; pixels ``> 0`` count as wound.
            px_per_cm: Pixels per centimetre in the original image.
            region_shape: Optional ``(height, width)`` of the image region the
                mask was computed from. When given, the pixel count is rescaled
                from mask resolution to that region's resolution before
                converting, so resizing for the model does not distort the area.

        Returns:
            Area in cm^2 rounded to 2 decimals.
        """
        mask = np.asarray(mask)
        pixel_area = float(np.count_nonzero(mask > 0))
        if region_shape is not None:
            mask_h, mask_w = mask.shape[:2]
            region_h, region_w = region_shape[:2]
            if mask_h and mask_w:
                pixel_area *= (region_h * region_w) / (mask_h * mask_w)
        return round(pixel_area / (px_per_cm ** 2), 2)

    def detect_wound(self, image_cv):
        """Locate the wound with the YOLO detector.

        Args:
            image_cv: RGB ``uint8`` image array (``run_full_analysis`` builds it
                with ``np.array(image.convert("RGB"))``).

        Returns:
            ``(detected_region, (x1, y1, x2, y2))`` where ``detected_region`` is
            the RGB crop inside the box, or ``(None, None)`` when nothing usable
            was detected.
        """
        st.session_state.messages.append({"role": "assistant", "content": "Detecting wound..."})
        # Ultralytics interprets numpy inputs as BGR (OpenCV order); convert the
        # RGB array for inference. The crop below is still taken from RGB.
        results = self.yolo_model(cv2.cvtColor(image_cv, cv2.COLOR_RGB2BGR))
        boxes = results[0].boxes
        # Pick the most confident detection rather than relying on list order.
        box = self._select_box(boxes.xyxy.cpu().numpy(), boxes.conf.cpu().numpy(), image_cv.shape)
        if box is None:
            return None, None
        x1, y1, x2, y2 = box
        return image_cv[y1:y2, x1:x2], box

    def segment_wound(self, detected_region):
        """Segment the wound inside the detected crop.

        Args:
            detected_region: Image crop returned by ``detect_wound``.

        Returns:
            ``uint8`` binary mask at the model's 256x256 input resolution
            (1 = wound, 0 = background).
        """
        st.session_state.messages.append({"role": "assistant", "content": "Segmenting wound area..."})
        # The model is fed a 256x256 crop scaled to [0, 1].
        resized = cv2.resize(detected_region, (256, 256)).astype(np.float32) / 255.0
        input_tensor = np.expand_dims(resized, axis=0)  # add batch dimension
        pred_mask = self.seg_model.predict(input_tensor, verbose=0)[0]
        # NOTE(review): assumes a channel-last output whose first channel is the
        # wound probability -- confirm against the trained model.
        return (pred_mask[:, :, 0] > 0.5).astype(np.uint8)

    def estimate_area(self, mask, region_shape=None):
        """Estimate the wound area in cm^2 from a segmentation mask.

        Args:
            mask: Binary mask, e.g. the 256x256 output of ``segment_wound``.
            region_shape: Optional ``(height, width)`` of the crop the mask was
                computed from. Pass it so the pixel count is measured at the
                crop's real resolution instead of the resized mask's.

        Returns:
            Area in cm^2 (rounded to 2 decimals), based on ``self.px_per_cm``.
        """
        st.session_state.messages.append({"role": "assistant", "content": "Estimating wound area..."})
        return self._mask_area_cm2(mask, self.px_per_cm, region_shape)

    def classify_wound(self, detected_region):
        """Classify the wound type of the detected crop.

        Args:
            detected_region: RGB image crop returned by ``detect_wound``.

        Returns:
            The top label from the classification pipeline, or ``"Unknown"``
            (with a warning shown) if classification fails.
        """
        st.session_state.messages.append({"role": "assistant", "content": "Classifying wound type..."})
        try:
            # The image-classification pipeline accepts PIL images directly, so
            # no temporary file (which previously leaked on errors) is needed.
            result = self.classify_pipe(Image.fromarray(detected_region))
            return result[0]["label"]
        except Exception as e:
            st.warning(f"Could not classify wound type: {e}")
            return "Unknown"

    def generate_recommendations(self, image, patient_info, analysis_results):
        """Ask Med-Gemma for an assessment and treatment plan.

        Args:
            image: The wound photo (PIL image or array convertible by
                ``Image.fromarray``).
            patient_info: Dict with ``"age"``, ``"diabetic"`` and ``"infection"``.
            analysis_results: Dict with at least ``"wound_type"`` and ``"area_cm2"``.

        Returns:
            The model's newly generated text (the prompt is not included).
        """
        st.session_state.messages.append({"role": "assistant", "content": "Generating expert recommendations with Med-Gemma..."})
        # The processor needs a PIL image; normalise the mode (e.g. RGBA PNGs).
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        image = image.convert("RGB")

        prompt_text = f"""Patient Info:
- Age: {patient_info["age"]}
- Diabetic: {patient_info["diabetic"]}
- Wound Type: {analysis_results["wound_type"]}
- Area: {analysis_results["area_cm2"]}
- Signs of infection: {patient_info["infection"]}
Please act as a highly experienced wound care specialist. Provide a comprehensive wound assessment and a detailed treatment plan. Structure your response clearly with the following sections:
1. **Wound Assessment:** Describe the wound's characteristics, potential causes, and current state based on the image and provided data.
2. **Recommended Treatment Plan:** Outline a primary course of action, including general principles of wound management.
3. **Cleaning and Dressing Protocol:** Provide specific, step-by-step instructions for wound cleaning and appropriate dressing choices.
4. **Red Flags & When to Seek Professional Medical Attention:** List critical signs and symptoms that indicate complications or require immediate consultation with a doctor or wound care nurse.
5. **Follow-up Schedule:** Suggest a realistic timeline for wound reassessment and monitoring progress.
**Important Disclaimer:** This information is for educational purposes only and should not replace professional medical advice. Always consult a qualified healthcare provider for diagnosis and treatment.
"""

        # Multi-modal chat messages in the processor's chat-template format.
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are a wound care expert."}],
            },
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt_text},
                    {"type": "image", "image": image},
                ],
            },
        ]

        # tokenize/return_dict make the processor return input_ids,
        # attention_mask AND pixel_values; without them it returns a plain
        # string and the image never reaches the model.
        inputs = self.medgemma_processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.medgemma_model.device, dtype=self.medgemma_model.dtype)
        prompt_len = inputs["input_ids"].shape[-1]

        with torch.inference_mode():
            output = self.medgemma_model.generate(**inputs, max_new_tokens=1000, do_sample=True, temperature=0.7)

        # Decode only the tokens generated after the prompt.
        return self.medgemma_processor.decode(output[0][prompt_len:], skip_special_tokens=True).strip()

    @staticmethod
    def _abort(message):
        """Show ``message`` as an error and return ``None`` for the caller to propagate.

        Session state is deliberately left intact: clearing it removed the
        ``messages`` key that the UI iterates right after the analysis.
        """
        st.error(message)
        return None

    def run_full_analysis(self, image, patient_info):
        """Execute the entire analysis pipeline on one uploaded image.

        Args:
            image: PIL image of the wound.
            patient_info: Dict with ``"age"``, ``"diabetic"`` and ``"infection"``.

        Returns:
            Dict with ``box``, ``detected_region``, ``mask``, ``area_cm2``,
            ``wound_type`` and ``recommendations``; or ``None`` if any step
            failed (an error is displayed in that case).
        """
        st.session_state.messages.append({"role": "assistant", "content": "Starting analysis..."})
        image_cv = np.array(image.convert("RGB"))

        try:
            detected_region, box = self.detect_wound(image_cv)
        except Exception as e:
            return self._abort(f"Error during wound detection: {e}")
        if detected_region is None:
            return self._abort("Agent Error: No wound could be detected in the image. Please try another image.")

        try:
            mask_resized = self.segment_wound(detected_region)
            # Scale the area to the crop's real size, not the 256x256 mask.
            area_cm2 = self.estimate_area(mask_resized, detected_region.shape[:2])
        except Exception as e:
            return self._abort(f"Error during wound segmentation or area estimation: {e}")

        try:
            wound_type = self.classify_wound(detected_region)
        except Exception as e:
            return self._abort(f"Error during wound classification: {e}")

        analysis_results = {
            "box": box,
            "detected_region": detected_region,
            "mask": mask_resized,
            "area_cm2": area_cm2,
            "wound_type": wound_type,
        }

        try:
            analysis_results["recommendations"] = self.generate_recommendations(image, patient_info, analysis_results)
        except Exception as e:
            return self._abort(f"Error during recommendation generation: {e}")

        st.session_state.messages.append({"role": "assistant", "content": "Analysis complete. See results below."})
        return analysis_results
# --- UI Layout ---
# The title emoji was stored as mis-decoded bytes ("π©Ή"); it is 🩹.
st.title("🩹 SmartHeal: The Agentic Wound Care Assistant")

# Initialize session state keys that persist across Streamlit reruns.
if "analysis_results" not in st.session_state:
    st.session_state.analysis_results = None  # result dict of the last successful run
if "messages" not in st.session_state:
    st.session_state.messages = []  # progress chat log of the last run
# --- Sidebar for Inputs ---
# Collects the wound image plus patient details; the two form buttons decide
# whether this rerun starts an analysis or clears the previous one.
with st.sidebar:
    st.header("π Patient & Image Input")
    uploaded_file = st.file_uploader(
        "1. Upload a clear wound image",
        type=["jpg", "jpeg", "png"],
    )

    with st.form("patient_form"):
        st.write("2. Enter Patient Details")
        age = st.number_input("Patient Age", min_value=1, max_value=120, value=50)
        diabetic = st.radio(
            "Is the patient diabetic?",
            ["No", "Yes"],
            index=0,
        )
        infection = st.radio(
            "Are there visible signs of infection (e.g., pus, redness, swelling)?",
            ["No", "Yes"],
            index=0,
        )

        analyze_col, clear_col = st.columns(2)
        with analyze_col:
            submitted = st.form_submit_button("π Analyze Wound", use_container_width=True)
        with clear_col:
            cleared = st.form_submit_button("β Clear", use_container_width=True)

# "Clear" drops the previous run's output and restarts the script from the top.
if cleared:
    st.session_state.analysis_results = None
    st.session_state.messages = []
    st.rerun()
# --- Main Content Area ---
# Start a new analysis when the form is submitted with an image attached.
if submitted and uploaded_file:
    # Drop results/messages from any previous run before starting a new one,
    # so a failure below never leaves stale output from another image.
    st.session_state.analysis_results = None
    st.session_state.messages = [
        {"role": "user", "content": "Analyzing the uploaded wound image..."}
    ]

    patient_info = {"age": age, "diabetic": diabetic, "infection": infection}
    image = Image.open(uploaded_file)

    with st.spinner("The SmartHeal Agent is at work..."):
        # load_all_models is cached; only the first run pays the (slow) load,
        # and doing it under the spinner gives the user feedback meanwhile.
        agent = WoundCareAgent(load_all_models())
        st.session_state.analysis_results = agent.run_full_analysis(image, patient_info)
elif submitted:
    # Previously a submit without an image silently did nothing.
    st.warning("Please upload a wound image before starting the analysis.")
# Display chat messages from the agent's process
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# Display final results in tabs if analysis is complete. The uploaded file is
# needed to redraw the overlay and name the report, so results are shown only
# while it is still attached (the user may remove it after an analysis).
if st.session_state.analysis_results and uploaded_file is not None:
    results = st.session_state.analysis_results
    image = Image.open(uploaded_file)  # Re-open image for display
    image_cv = np.array(image.convert("RGB"))

    st.header("✅ Analysis Complete")
    tab1, tab2, tab3 = st.tabs(["π **Expert Recommendations**", "π¬ **Vision Analysis**", "π **Download Report**"])

    with tab1:
        st.markdown(results["recommendations"])

    with tab2:
        st.subheader("Wound Detection & Segmentation")
        # Blend the segmentation mask into the detected box region.
        x1, y1, x2, y2 = results["box"]
        overlay_image = image_cv.copy()
        region = overlay_image[y1:y2, x1:x2]
        # Resize the mask to the crop's actual size; cv2.resize takes (width, height).
        mask_for_overlay = cv2.resize(
            results["mask"],
            (region.shape[1], region.shape[0]),
            interpolation=cv2.INTER_NEAREST,
        )
        colored_mask_region = np.zeros_like(region)
        colored_mask_region[mask_for_overlay > 0] = [255, 0, 0]  # Red in RGB for wound pixels
        overlay_image[y1:y2, x1:x2] = cv2.addWeighted(region, 0.7, colored_mask_region, 0.3, 0)
        cv2.rectangle(overlay_image, (x1, y1), (x2, y2), (0, 255, 0), 2)  # Green box
        st.image(overlay_image, caption="Detected Wound with Segmentation Overlay", use_column_width=True)
        st.metric(label="Estimated Wound Area", value=f"{results['area_cm2']} cm²")
        st.metric(label="Classified Wound Type", value=f"{results['wound_type']}")

    with tab3:
        st.subheader("Download Full Report")
        # splitext keeps names like "left.foot.jpg" intact as "left.foot".
        report_stem = os.path.splitext(uploaded_file.name)[0]
        st.download_button(
            label="π₯ Download as Text File",
            data=results["recommendations"],
            file_name=f"wound_report_{report_stem}.txt",
            mime="text/plain",
        )
else:
    # Show information prompting the user to upload an image and fill in patient details.
    st.info("Please upload an image and patient details in the sidebar to start.")
|