"""SmartHeal: a Streamlit wound-care assistant.

Pipeline: YOLOv8 detection -> Keras segmentation -> area estimate ->
Hugging Face image classification -> MedGemma assessment/recommendations.
"""
import os

# Force Streamlit (and libraries that depend on ~/.config) to use a writable directory.
# These must be set before streamlit/matplotlib are imported so they pick up the paths.
os.environ["HOME"] = os.getcwd()  # or "/tmp"
os.environ["MPLCONFIGDIR"] = os.path.join(os.getcwd(), ".config")  # ensure matplotlib writes here

import io
import tempfile

import cv2
import numpy as np
import streamlit as st
import torch
from PIL import Image
from tensorflow.keras.models import load_model
from transformers import (
    AutoModelForImageTextToText,
    AutoProcessor,
    LlavaForConditionalGeneration,
    pipeline,
)
from ultralytics import YOLO

# --- Model locations (user-specified) ---
DETECTION_MODEL_PATH = "/home/ubuntu/upload/best(1).pt"
SEGMENTATION_MODEL_PATH = "/home/ubuntu/upload/segmentation_model.h5"
CLASSIFICATION_MODEL_ID = "Hemg/Wound-classification"
MEDGEMMA_MODEL_ID = "google/medgemma-4b-it"

# Input resolution expected by the segmentation model.
SEGMENTATION_INPUT_SIZE = (256, 256)

# --- Page Configuration (Best practice: call this first) ---
st.set_page_config(
    page_title="SmartHeal Wound Care Agent",
    page_icon="🩹",
    layout="wide",
    initial_sidebar_state="expanded",
)


# --- Model Loading (Cached for performance) ---
@st.cache_resource
def load_all_models():
    """Load all required models and pipelines into memory once.

    Returns:
        Tuple ``(detection_model, segmentation_model, classification_pipe,
        medgemma_model, medgemma_processor)``.

    Stops the Streamlit script (``st.stop()``) with an error message when the
    Hugging Face token is missing or any model fails to load.
    """
    # Get Hugging Face token from environment variable. Checked outside the
    # try-block so st.stop() is never routed through the generic handler below.
    hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")
    if not hf_token:
        st.error(
            "Fatal Error: Hugging Face token not found in environment variables "
            "(HF_TOKEN or HUGGING_FACE_HUB_TOKEN)."
        )
        st.stop()

    try:
        # YOLOv8 detection model
        detection_model = YOLO(DETECTION_MODEL_PATH)

        # Segmentation model (inference only, so no need to compile)
        segmentation_model = load_model(SEGMENTATION_MODEL_PATH, compile=False)

        # Classification pipeline; token covers private models / rate limits
        classification_pipe = pipeline(
            "image-classification", model=CLASSIFICATION_MODEL_ID, token=hf_token
        )

        # MedGemma is a Gemma-3 image-text-to-text model, not a Llava model,
        # so it is loaded through the generic image-text-to-text auto class.
        medgemma_processor = AutoProcessor.from_pretrained(MEDGEMMA_MODEL_ID, token=hf_token)
        medgemma_model = AutoModelForImageTextToText.from_pretrained(
            MEDGEMMA_MODEL_ID,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            token=hf_token,
        )
        # Fall back to CPU instead of crashing on hosts without a GPU.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        medgemma_model.to(device)
    except Exception as e:
        st.error(
            "Fatal Error: Could not load models. Please check model paths, dependencies, "
            f"and Hugging Face token. Details: {e}"
        )
        st.stop()

    return (
        detection_model,
        segmentation_model,
        classification_pipe,
        medgemma_model,
        medgemma_processor,
    )


# --- Agent Class ---
class WoundCareAgent:
    """An agentic class encapsulating the wound care analysis pipeline.

    Every stage appends a progress message to ``st.session_state.messages``,
    so that list must exist before any method is called.
    """

    def __init__(self, models, px_per_cm=38):
        """Store the loaded models.

        Args:
            models: Tuple returned by ``load_all_models()``.
            px_per_cm: Pixels per centimetre in the source image. The default is
                an example value and should be calibrated for real-world use.
        """
        (
            self.yolo_model,
            self.seg_model,
            self.classify_pipe,
            self.medgemma_model,
            self.medgemma_processor,
        ) = models
        self.px_per_cm = px_per_cm

    def detect_wound(self, image_cv):
        """Detect the wound region using YOLOv8.

        Args:
            image_cv: RGB uint8 array of shape (H, W, 3).

        Returns:
            ``(detected_region, (x1, y1, x2, y2))`` for the highest-confidence
            detection, or ``(None, None)`` if nothing usable was detected.
        """
        st.session_state.messages.append({"role": "assistant", "content": "Detecting wound..."})
        # Ultralytics interprets numpy input as OpenCV-style BGR; our array is RGB.
        results = self.yolo_model(cv2.cvtColor(image_cv, cv2.COLOR_RGB2BGR))
        boxes = results[0].boxes
        if len(boxes) == 0:
            return None, None

        # Choose the most confident detection explicitly instead of relying on order.
        best = int(np.argmax(boxes.conf.cpu().numpy()))
        x1, y1, x2, y2 = map(int, boxes.xyxy[best].cpu().numpy()[:4])

        # Clip to the image so a negative index can never wrap the slice around.
        height, width = image_cv.shape[:2]
        x1, x2 = max(0, x1), min(width, x2)
        y1, y2 = max(0, y1), min(height, y2)
        if x2 <= x1 or y2 <= y1:
            # Degenerate box: nothing to segment or classify.
            return None, None

        return image_cv[y1:y2, x1:x2], (x1, y1, x2, y2)

    def segment_wound(self, detected_region):
        """Segment the wound inside the detected region.

        Args:
            detected_region: RGB uint8 crop of the detected wound.

        Returns:
            Binary uint8 mask (1 = wound) with the same height and width as
            ``detected_region``.
        """
        st.session_state.messages.append({"role": "assistant", "content": "Segmenting wound area..."})
        region_h, region_w = detected_region.shape[:2]

        # The model works on a fixed-size, [0, 1]-scaled input.
        resized = cv2.resize(detected_region, SEGMENTATION_INPUT_SIZE) / 255.0
        input_tensor = np.expand_dims(resized, axis=0)
        pred_mask = self.seg_model.predict(input_tensor)[0]
        binary_mask = (pred_mask[:, :, 0] > 0.5).astype(np.uint8)

        # Map the mask back to the crop's native resolution. Otherwise the area
        # would always be measured on a 256x256 grid regardless of wound size.
        return cv2.resize(binary_mask, (region_w, region_h), interpolation=cv2.INTER_NEAREST)

    def estimate_area(self, mask):
        """Estimate the wound area in cm² from a binary mask.

        Args:
            mask: Binary mask whose pixels are source-image pixels.

        Returns:
            The area in cm², rounded to two decimals, using ``self.px_per_cm``.
        """
        st.session_state.messages.append({"role": "assistant", "content": "Estimating wound area..."})
        pixel_area = int(np.count_nonzero(mask > 0))
        area_cm2 = pixel_area / (self.px_per_cm ** 2)
        return round(area_cm2, 2)

    def classify_wound(self, detected_region):
        """Classify the wound type with the image-classification pipeline.

        Args:
            detected_region: RGB uint8 crop of the detected wound.

        Returns:
            The top predicted label, or ``"Unknown"`` if classification fails.
            A failure is shown as a warning and does not abort the analysis.
        """
        st.session_state.messages.append({"role": "assistant", "content": "Classifying wound type..."})
        try:
            # The pipeline accepts PIL images directly; no temporary file is needed.
            result = self.classify_pipe(Image.fromarray(detected_region))
            return result[0]["label"]
        except Exception as e:
            st.warning(f"Could not classify wound type: {e}")
            return "Unknown"

    def generate_recommendations(self, image, patient_info, analysis_results):
        """Generate an assessment and treatment plan with MedGemma.

        Args:
            image: PIL image or RGB numpy array of the wound.
            patient_info: dict with "age", "diabetic" and "infection" keys.
            analysis_results: dict with "wound_type" and "area_cm2" keys.

        Returns:
            The model's response text, with the prompt tokens stripped.
        """
        st.session_state.messages.append(
            {"role": "assistant", "content": "Generating expert recommendations with Med-Gemma..."}
        )

        # Prepare the image for MedGemma (ensure it's an RGB PIL image).
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        image = image.convert("RGB")

        prompt_text = f"""Patient Info:
- Age: {patient_info['age']}
- Diabetic: {patient_info['diabetic']}
- Wound Type: {analysis_results['wound_type']}
- Area: {analysis_results['area_cm2']} cm²
- Signs of infection: {patient_info['infection']}

Please act as a highly experienced wound care specialist. Provide a comprehensive wound assessment and a detailed treatment plan. Structure your response clearly with the following sections:
1. **Wound Assessment:** Describe the wound's characteristics, potential causes, and current state based on the image and provided data.
2. **Recommended Treatment Plan:** Outline a primary course of action, including general principles of wound management.
3. **Cleaning and Dressing Protocol:** Provide specific, step-by-step instructions for wound cleaning and appropriate dressing choices.
4. **Red Flags & When to Seek Professional Medical Attention:** List critical signs and symptoms that indicate complications or require immediate consultation with a doctor or wound care nurse.
5. **Follow-up Schedule:** Suggest a realistic timeline for wound reassessment and monitoring progress.

**Important Disclaimer:** This information is for educational purposes only and should not replace professional medical advice. Always consult a qualified healthcare provider for diagnosis and treatment.
"""

        # Multimodal chat messages in the processor's chat-template format.
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are a wound care expert."}],
            },
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt_text},
                    {"type": "image", "image": image},
                ],
            },
        ]

        # tokenize=True + return_dict=True yields input_ids, attention_mask AND
        # pixel_values. Without them the call returns a plain string. The
        # generation prompt makes the model answer as the assistant.
        inputs = self.medgemma_processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.medgemma_model.device, dtype=torch.bfloat16)
        prompt_len = inputs["input_ids"].shape[-1]

        with torch.inference_mode():
            output = self.medgemma_model.generate(
                **inputs, max_new_tokens=1000, do_sample=True, temperature=0.7
            )

        # Decode only the newly generated tokens (drop the echoed prompt).
        response = self.medgemma_processor.decode(output[0][prompt_len:], skip_special_tokens=True)
        return response.strip()

    def run_full_analysis(self, image, patient_info):
        """Execute the entire analysis pipeline.

        Args:
            image: PIL image uploaded by the user (any mode; converted to RGB).
            patient_info: dict with "age", "diabetic" and "infection" keys.

        Returns:
            None if a stage fails; the failure is reported with ``st.error``.
            On success, a dict with these keys:

            - "box"
            - "detected_region"
            - "mask"
            - "area_cm2"
            - "wound_type"
            - "recommendations"
            - "source_image": the RGB array the box coordinates refer to.

            Session state is left intact on failure, so the progress log stays
            visible and the UI keeps its expected keys.
        """
        st.session_state.messages.append({"role": "assistant", "content": "Starting analysis..."})
        rgb_image = image.convert("RGB")
        image_cv = np.array(rgb_image)

        try:
            detected_region, box = self.detect_wound(image_cv)
        except Exception as e:
            st.error(f"Error during wound detection: {e}")
            return None
        if detected_region is None:
            st.error("Agent Error: No wound could be detected in the image. Please try another image.")
            return None

        try:
            mask = self.segment_wound(detected_region)
            area_cm2 = self.estimate_area(mask)
        except Exception as e:
            st.error(f"Error during wound segmentation or area estimation: {e}")
            return None

        try:
            wound_type = self.classify_wound(detected_region)
        except Exception as e:
            st.error(f"Error during wound classification: {e}")
            return None

        analysis_results = {
            "box": box,
            "detected_region": detected_region,
            "mask": mask,
            "area_cm2": area_cm2,
            "wound_type": wound_type,
            "source_image": image_cv,
        }

        try:
            analysis_results["recommendations"] = self.generate_recommendations(
                rgb_image, patient_info, analysis_results
            )
        except Exception as e:
            st.error(f"Error during recommendation generation: {e}")
            return None

        st.session_state.messages.append(
            {"role": "assistant", "content": "Analysis complete. See results below."}
        )
        return analysis_results


# --- UI Layout ---
st.title("🩹 SmartHeal: The Agentic Wound Care Assistant")

# Initialize session state
if "analysis_results" not in st.session_state:
    st.session_state.analysis_results = None
if "messages" not in st.session_state:
    st.session_state.messages = []

# --- Sidebar for Inputs ---
with st.sidebar:
    st.header("📋 Patient & Image Input")
    uploaded_file = st.file_uploader("1. Upload a clear wound image", type=["jpg", "jpeg", "png"])

    with st.form("patient_form"):
        st.write("2. Enter Patient Details")
        age = st.number_input("Patient Age", min_value=1, max_value=120, value=50)
        diabetic = st.radio("Is the patient diabetic?", ["No", "Yes"], index=0)
        infection = st.radio(
            "Are there visible signs of infection (e.g., pus, redness, swelling)?",
            ["No", "Yes"],
            index=0,
        )
        col1, col2 = st.columns(2)
        with col1:
            submitted = st.form_submit_button("🚀 Analyze Wound", use_container_width=True)
        with col2:
            cleared = st.form_submit_button("❌ Clear", use_container_width=True)

    if cleared:
        st.session_state.analysis_results = None
        st.session_state.messages = []
        st.rerun()

# --- Main Content Area ---
if submitted and uploaded_file:
    # Load models (cached) and instantiate the agent
    agent = WoundCareAgent(load_all_models())
    patient_info = {"age": age, "diabetic": diabetic, "infection": infection}
    image = Image.open(uploaded_file)

    # Clear previous results and run a new analysis
    st.session_state.analysis_results = None
    st.session_state.messages = [
        {"role": "user", "content": "Analyzing the uploaded wound image..."}
    ]
    with st.spinner("The SmartHeal Agent is at work..."):
        new_results = agent.run_full_analysis(image, patient_info)
    if new_results is not None:
        # Remember the report name now, so later reruns do not depend on the
        # uploader still holding the file.
        new_results["file_stem"] = os.path.splitext(uploaded_file.name)[0]
    st.session_state.analysis_results = new_results
elif submitted:
    st.warning("Please upload a wound image before starting the analysis.")

# Display chat messages from the agent's process
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# Display final results in tabs if analysis is complete
if st.session_state.analysis_results:
    results = st.session_state.analysis_results
    # Use the exact RGB array the box coordinates refer to (no re-reading the upload).
    image_cv = results["source_image"]

    st.header("✅ Analysis Complete")
    tab1, tab2, tab3 = st.tabs(
        ["📝 **Expert Recommendations**", "🔬 **Vision Analysis**", "📄 **Download Report**"]
    )

    with tab1:
        st.markdown(results["recommendations"])

    with tab2:
        st.subheader("Wound Detection & Segmentation")
        x1, y1, x2, y2 = results["box"]
        overlay_image = image_cv.copy()

        # Ensure the mask matches the detected region's size before blending.
        mask_for_overlay = cv2.resize(
            results["mask"], (x2 - x1, y2 - y1), interpolation=cv2.INTER_NEAREST
        )
        region = overlay_image[y1:y2, x1:x2]
        colored_mask_region = np.zeros_like(region)
        colored_mask_region[mask_for_overlay > 0] = [255, 0, 0]  # Red (RGB) for wound
        overlay_image[y1:y2, x1:x2] = cv2.addWeighted(region, 0.7, colored_mask_region, 0.3, 0)

        # Draw bounding box
        cv2.rectangle(overlay_image, (x1, y1), (x2, y2), (0, 255, 0), 2)  # Green box

        st.image(overlay_image, caption="Detected Wound with Segmentation Overlay", use_column_width=True)
        st.metric(label="Estimated Wound Area", value=f"{results['area_cm2']} cm²")
        st.metric(label="Classified Wound Type", value=f"{results['wound_type']}")

    with tab3:
        st.subheader("Download Full Report")
        st.download_button(
            label="📥 Download as Text File",
            data=results["recommendations"],
            file_name=f"wound_report_{results['file_stem']}.txt",
            mime="text/plain",
        )
else:
    # Prompt the user to upload an image and fill in patient details.
    st.info("Please upload an image and patient details in the sidebar to start.")