# Smart-Heal-agent / src/streamlit_app.py
# Ani14 — commit 097c7f4 (verified): "Update src/streamlit_app.py"
import os
# Force Streamlit (and libraries that depend on ~/.config) to use a writable directory.
# NOTE: these assignments deliberately run BEFORE the remaining imports, presumably so
# libraries that resolve config/cache paths at import time already see them — keep this order.
# Redirecting HOME also changes anything else that expands "~" (os.path.expanduser reads
# HOME on POSIX), likely including model caches under ~/.cache — confirm this is intended.
os.environ["HOME"] = os.getcwd() # or "/tmp"
os.environ["MPLCONFIGDIR"] = os.path.join(os.getcwd(), ".config") # ensure matplotlib writes here
# Remaining imports (must come after the environment setup above).
import streamlit as st
import numpy as np
import cv2
import torch
import tempfile
from PIL import Image
from tensorflow.keras.models import load_model
from transformers import pipeline, AutoProcessor, LlavaForConditionalGeneration
import io
import os # NOTE(review): duplicate of the first-line import; harmless but redundant
from ultralytics import YOLO # Import YOLO from ultralytics
# --- Page Configuration ---
# Browser-tab title/icon and layout. Kept ahead of every other st.* call
# (best practice for st.set_page_config).
st.set_page_config(
    layout="wide",
    initial_sidebar_state="expanded",
    page_icon="🩹",
    page_title="SmartHeal Wound Care Agent",
)
# --- Model Loading (Cached for performance) ---
@st.cache_resource
def load_all_models():
    """Load every model and pipeline used by the agent.

    Decorated with ``st.cache_resource`` so the returned objects are reused
    across Streamlit reruns instead of being reloaded each time.

    Returns:
        Tuple ``(detection_model, segmentation_model, classification_pipe,
        medgemma_model, medgemma_processor)`` in the order unpacked by
        ``WoundCareAgent.__init__``.

    On failure an error is rendered and the script run is halted with
    ``st.stop()``.
    """
    # A Hugging Face token is passed to every Hub download below.
    hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")
    if not hf_token:
        st.error("Fatal Error: Hugging Face token not found in environment variables (HF_TOKEN or HUGGING_FACE_HUB_TOKEN).")
        st.stop()

    try:
        # google/medgemma-4b-it is a Gemma 3 checkpoint; the Auto class resolves
        # that architecture. LlavaForConditionalGeneration is a different
        # architecture and would not load this checkpoint's weights correctly.
        from transformers import AutoModelForImageTextToText

        # YOLOv8 detection model (user-specified local path).
        detection_model = YOLO("/home/ubuntu/upload/best(1).pt")
        # Keras segmentation model; compile=False since it is only used for predict().
        segmentation_model = load_model("/home/ubuntu/upload/segmentation_model.h5", compile=False)
        # Wound-type classifier from the Hub.
        classification_pipe = pipeline("image-classification", model="Hemg/Wound-classification", token=hf_token)

        # Med-Gemma vision-language model used for the written assessment.
        medgemma_model_id = "google/medgemma-4b-it"
        medgemma_processor = AutoProcessor.from_pretrained(medgemma_model_id, token=hf_token)
        # Fall back to CPU instead of crashing on hosts without a CUDA GPU.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        medgemma_model = AutoModelForImageTextToText.from_pretrained(
            medgemma_model_id,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            token=hf_token,
        ).to(device)

        return detection_model, segmentation_model, classification_pipe, medgemma_model, medgemma_processor
    except Exception as e:
        st.error(f"Fatal Error: Could not load models. Please check model paths, dependencies, and Hugging Face token. Details: {e}")
        st.stop()
# --- Agent Class ---
class WoundCareAgent:
    """An agentic class to encapsulate the wound care analysis pipeline.

    Wraps the models returned by ``load_all_models`` and records progress
    messages in ``st.session_state.messages`` as each step runs.
    """

    # Side length of the square input fed to the segmentation model.
    SEG_INPUT_SIZE = 256

    def __init__(self, models):
        """Unpack ``(yolo, segmentation, classifier pipe, medgemma model, medgemma processor)``."""
        (
            self.yolo_model,
            self.seg_model,
            self.classify_pipe,
            self.medgemma_model,
            self.medgemma_processor,
        ) = models
        # Pixels per centimetre of the photographed scene. Example value only:
        # must be calibrated (e.g. with a reference marker) for real-world use.
        self.px_per_cm = 38

    @staticmethod
    def _log(content):
        """Append an assistant progress message to the session chat log."""
        st.session_state.messages.append({"role": "assistant", "content": content})

    @staticmethod
    def _abort(message):
        """Show ``message`` as an error, reset this app's analysis state, return None."""
        st.error(message)
        # Reset only the keys this app owns. Clearing the whole session state
        # deleted ``messages``, which the UI reads right after we return.
        st.session_state.analysis_results = None
        st.session_state.messages = []
        return None

    def detect_wound(self, image_cv):
        """Detect the wound region with the YOLOv8 model.

        Args:
            image_cv: RGB image as an ``(H, W, 3)`` uint8 NumPy array.

        Returns:
            ``(detected_region, (x1, y1, x2, y2))`` for the highest-confidence
            detection, or ``(None, None)`` when nothing usable was detected.
        """
        self._log("Detecting wound...")
        # Ultralytics treats NumPy inputs as OpenCV-style BGR; our array is RGB.
        image_bgr = cv2.cvtColor(image_cv, cv2.COLOR_RGB2BGR)
        boxes = self.yolo_model(image_bgr)[0].boxes
        if len(boxes) == 0:
            return None, None
        best = int(boxes.conf.argmax())  # highest-confidence detection
        x1, y1, x2, y2 = map(int, boxes.xyxy[best].tolist())
        height, width = image_cv.shape[:2]
        x1, x2 = max(x1, 0), min(x2, width)
        y1, y2 = max(y1, 0), min(y2, height)
        if x2 <= x1 or y2 <= y1:
            # Degenerate box after integer rounding: nothing to crop.
            return None, None
        return image_cv[y1:y2, x1:x2], (x1, y1, x2, y2)

    def segment_wound(self, detected_region):
        """Segment the wound inside ``detected_region``.

        The crop is resized to the model's square input, and the predicted mask
        is resized back to the crop's own size so that pixel counts (and hence
        the area estimate) are in original-image pixels.

        Returns:
            ``(h, w)`` uint8 binary mask matching ``detected_region``'s height/width.
        """
        self._log("Segmenting wound area...")
        size = self.SEG_INPUT_SIZE
        resized = cv2.resize(detected_region, (size, size)) / 255.0
        pred_mask = self.seg_model.predict(np.expand_dims(resized, axis=0))[0]
        binary_mask = (pred_mask[:, :, 0] > 0.5).astype(np.uint8)
        height, width = detected_region.shape[:2]
        # Nearest-neighbour keeps the mask strictly binary.
        return cv2.resize(binary_mask, (width, height), interpolation=cv2.INTER_NEAREST)

    def estimate_area(self, mask):
        """Return the wound area in cm² (2 decimals) from a binary mask.

        ``mask`` must be at original-image pixel scale (as ``segment_wound``
        returns); the conversion uses the uncalibrated ``self.px_per_cm``.
        """
        self._log("Estimating wound area...")
        pixel_area = int(np.count_nonzero(mask))
        return round(pixel_area / self.px_per_cm ** 2, 2)

    def classify_wound(self, detected_region):
        """Return the top wound-type label, or ``"Unknown"`` if classification fails."""
        self._log("Classifying wound type...")
        try:
            # The image-classification pipeline accepts PIL images directly, so no
            # temporary file is needed (the old one leaked if the pipeline raised).
            result = self.classify_pipe(Image.fromarray(detected_region))
            return result[0]["label"]
        except Exception as e:
            st.warning(f"Could not classify wound type: {e}")
            return "Unknown"

    def generate_recommendations(self, image, patient_info, analysis_results):
        """Generate an assessment and treatment plan with Med-Gemma.

        Args:
            image: Wound photo as a PIL image or NumPy array.
            patient_info: Dict with ``"age"``, ``"diabetic"`` and ``"infection"``.
            analysis_results: Dict with at least ``"wound_type"`` and ``"area_cm2"``.

        Returns:
            Only the newly generated text (the prompt is not echoed back).
        """
        self._log("Generating expert recommendations with Med-Gemma...")
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        # Uploads may be RGBA/P/L; the vision processor expects 3-channel input.
        image = image.convert("RGB")

        prompt_text = (
            "Patient Info:\n"
            f"- Age: {patient_info['age']}\n"
            f"- Diabetic: {patient_info['diabetic']}\n"
            f"- Wound Type: {analysis_results['wound_type']}\n"
            f"- Area: {analysis_results['area_cm2']} cm²\n"
            f"- Signs of infection: {patient_info['infection']}\n"
            "Please act as a highly experienced wound care specialist. Provide a comprehensive wound assessment and a detailed treatment plan. Structure your response clearly with the following sections:\n"
            "1. **Wound Assessment:** Describe the wound's characteristics, potential causes, and current state based on the image and provided data.\n"
            "2. **Recommended Treatment Plan:** Outline a primary course of action, including general principles of wound management.\n"
            "3. **Cleaning and Dressing Protocol:** Provide specific, step-by-step instructions for wound cleaning and appropriate dressing choices.\n"
            "4. **Red Flags & When to Seek Professional Medical Attention:** List critical signs and symptoms that indicate complications or require immediate consultation with a doctor or wound care nurse.\n"
            "5. **Follow-up Schedule:** Suggest a realistic timeline for wound reassessment and monitoring progress.\n"
            "**Important Disclaimer:** This information is for educational purposes only and should not replace professional medical advice. Always consult a qualified healthcare provider for diagnosis and treatment.\n"
        )
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are a wound care expert."}],
            },
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt_text},
                    {"type": "image", "image": image},
                ],
            },
        ]
        # Without tokenize/return_dict the processor returns a plain prompt string
        # (no tensors, no pixel values), which cannot be fed to generate().
        inputs = self.medgemma_processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.medgemma_model.device, dtype=self.medgemma_model.dtype)
        prompt_len = inputs["input_ids"].shape[-1]

        with torch.inference_mode():
            output = self.medgemma_model.generate(
                **inputs, max_new_tokens=1000, do_sample=True, temperature=0.7
            )
        # Decode only the tokens generated after the prompt.
        return self.medgemma_processor.decode(
            output[0][prompt_len:], skip_special_tokens=True
        ).strip()

    def run_full_analysis(self, image, patient_info):
        """Run detection, segmentation, area estimation, classification and
        recommendation generation in sequence.

        Args:
            image: Uploaded wound photo as a PIL image.
            patient_info: Dict with ``"age"``, ``"diabetic"`` and ``"infection"``.

        Returns:
            Dict with ``box``, ``detected_region``, ``mask``, ``area_cm2``,
            ``wound_type`` and ``recommendations``; ``None`` if any step fails
            (an error is shown and the analysis state is reset).
        """
        self._log("Starting analysis...")
        image_cv = np.array(image.convert("RGB"))

        try:
            detected_region, box = self.detect_wound(image_cv)
        except Exception as e:
            return self._abort(f"Error during wound detection: {e}")
        if detected_region is None:
            return self._abort("Agent Error: No wound could be detected in the image. Please try another image.")

        try:
            mask = self.segment_wound(detected_region)
            area_cm2 = self.estimate_area(mask)
        except Exception as e:
            return self._abort(f"Error during wound segmentation or area estimation: {e}")

        try:
            wound_type = self.classify_wound(detected_region)
        except Exception as e:
            return self._abort(f"Error during wound classification: {e}")

        analysis_results = {
            "box": box,
            "detected_region": detected_region,
            "mask": mask,
            "area_cm2": area_cm2,
            "wound_type": wound_type,
        }
        try:
            analysis_results["recommendations"] = self.generate_recommendations(
                image, patient_info, analysis_results
            )
        except Exception as e:
            return self._abort(f"Error during recommendation generation: {e}")

        self._log("Analysis complete. See results below.")
        return analysis_results
# --- UI Layout ---
st.title("🩹 SmartHeal: The Agentic Wound Care Assistant")

# Initialize session state (persists across Streamlit reruns).
if "analysis_results" not in st.session_state:
    st.session_state.analysis_results = None
if "messages" not in st.session_state:
    st.session_state.messages = []

# --- Sidebar for Inputs ---
with st.sidebar:
    st.header("📋 Patient & Image Input")
    uploaded_file = st.file_uploader("1. Upload a clear wound image", type=["jpg", "jpeg", "png"])
    with st.form("patient_form"):
        st.write("2. Enter Patient Details")
        age = st.number_input("Patient Age", min_value=1, max_value=120, value=50)
        diabetic = st.radio("Is the patient diabetic?", ["No", "Yes"], index=0)
        infection = st.radio("Are there visible signs of infection (e.g., pus, redness, swelling)?", ["No", "Yes"], index=0)
        col1, col2 = st.columns(2)
        with col1:
            submitted = st.form_submit_button("🚀 Analyze Wound", use_container_width=True)
        with col2:
            cleared = st.form_submit_button("❌ Clear", use_container_width=True)

if cleared:
    st.session_state.analysis_results = None
    st.session_state.messages = []
    st.rerun()

# --- Main Content Area ---
if submitted and uploaded_file is not None:
    # Load models (cached) and instantiate the agent.
    agent = WoundCareAgent(load_all_models())
    patient_info = {"age": age, "diabetic": diabetic, "infection": infection}
    image = Image.open(uploaded_file)

    # Clear previous results and run a new analysis.
    st.session_state.analysis_results = None
    st.session_state.messages = []
    st.session_state.messages.append({"role": "user", "content": "Analyzing the uploaded wound image..."})
    with st.spinner("The SmartHeal Agent is at work..."):
        new_results = agent.run_full_analysis(image, patient_info)
    if new_results is not None:
        # Keep the exact image and file name that were analysed, so the results
        # render correctly even if the uploader changes or is emptied later.
        new_results["image_rgb"] = np.array(image.convert("RGB"))
        new_results["file_name"] = uploaded_file.name
    st.session_state.analysis_results = new_results
elif submitted:
    st.warning("Please upload a wound image before starting the analysis.")

# Display chat messages from the agent's process. ``.get`` guards against the
# key having been removed from session state by an error path.
for message in st.session_state.get("messages", []):
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# Display final results in tabs if an analysis has completed.
results = st.session_state.get("analysis_results")
if results:
    st.header("✅ Analysis Complete")
    tab1, tab2, tab3 = st.tabs(["📝 **Expert Recommendations**", "🔬 **Vision Analysis**", "📄 **Download Report**"])
    with tab1:
        st.markdown(results["recommendations"])
    with tab2:
        st.subheader("Wound Detection & Segmentation")
        x1, y1, x2, y2 = results["box"]
        overlay_image = results["image_rgb"].copy()
        region = overlay_image[y1:y2, x1:x2]  # view: writes go into overlay_image
        # Bring the mask to the box size (a no-op when it already matches).
        mask_for_overlay = cv2.resize(results["mask"], (x2 - x1, y2 - y1), interpolation=cv2.INTER_NEAREST)
        wound_pixels = mask_for_overlay > 0
        red_layer = np.zeros_like(region)
        red_layer[wound_pixels] = [255, 0, 0]  # red in RGB
        blended = cv2.addWeighted(region, 0.7, red_layer, 0.3, 0)
        # Tint only wound pixels so the rest of the box keeps its original colour.
        region[wound_pixels] = blended[wound_pixels]
        cv2.rectangle(overlay_image, (x1, y1), (x2, y2), (0, 255, 0), 2)  # green box
        st.image(overlay_image, caption="Detected Wound with Segmentation Overlay", use_column_width=True)
        st.metric(label="Estimated Wound Area", value=f"{results['area_cm2']} cm²")
        st.metric(label="Classified Wound Type", value=f"{results['wound_type']}")
    with tab3:
        st.subheader("Download Full Report")
        report_stem = os.path.splitext(results["file_name"])[0]
        st.download_button(
            label="📥 Download as Text File",
            data=results["recommendations"],
            file_name=f"wound_report_{report_stem}.txt",
            mime="text/plain",
        )
else:
    # Prompt the user to upload an image and fill in patient details.
    st.info("Please upload an image and patient details in the sidebar to start.")