import os

# Force Streamlit (and libraries that depend on ~/.config) to use a writable directory.
# This must happen before Streamlit and matplotlib are imported.
os.environ["HOME"] = os.getcwd()  # or "/tmp"
os.environ["MPLCONFIGDIR"] = os.path.join(os.getcwd(), ".config")  # ensure matplotlib writes here

import streamlit as st
import numpy as np
import cv2
import torch
import tempfile
from PIL import Image
from tensorflow.keras.models import load_model
from transformers import pipeline, AutoProcessor, AutoModelForImageTextToText
from ultralytics import YOLO
# --- Page Configuration (best practice: call this first) ---
st.set_page_config(
    page_title="SmartHeal Wound Care Agent",
    page_icon="🩹",
    layout="wide",
    initial_sidebar_state="expanded",
)
# --- Model Loading (cached so it runs only once) ---
@st.cache_resource
def load_all_models():
    """Loads all required models and pipelines into memory once."""
    try:
        # Get the Hugging Face token from the environment
        hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")
        if not hf_token:
            st.error("Fatal Error: Hugging Face token not found in environment variables (HF_TOKEN or HUGGING_FACE_HUB_TOKEN).")
            st.stop()

        # YOLOv8 detection model (user-specified path)
        detection_model = YOLO("/home/ubuntu/upload/best(1).pt")

        # Segmentation model (user-specified path)
        segmentation_model = load_model("/home/ubuntu/upload/segmentation_model.h5", compile=False)

        # Classification pipeline (the token may be required for private models or rate limits)
        classification_pipe = pipeline("image-classification", model="Hemg/Wound-classification", token=hf_token)

        # MedGemma for analysis (a Gemma 3-based image-text-to-text model)
        medgemma_model_id = "google/medgemma-4b-it"
        medgemma_processor = AutoProcessor.from_pretrained(medgemma_model_id, token=hf_token)
        medgemma_model = AutoModelForImageTextToText.from_pretrained(
            medgemma_model_id,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            token=hf_token,
        )
        # Move the model to GPU if one is available
        device = "cuda" if torch.cuda.is_available() else "cpu"
        medgemma_model.to(device)

        return detection_model, segmentation_model, classification_pipe, medgemma_model, medgemma_processor
    except Exception as e:
        st.error(f"Fatal Error: Could not load models. Please check model paths, dependencies, and Hugging Face token. Details: {e}")
        st.stop()
# --- Agent Class ---
class WoundCareAgent:
    """An agentic class to encapsulate the wound care analysis pipeline."""

    def __init__(self, models):
        self.yolo_model, self.seg_model, self.classify_pipe, self.medgemma_model, self.medgemma_processor = models
        self.px_per_cm = 38  # example value; must be calibrated per camera setup for real-world use
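
    # Illustrative calibration helper (an assumption, not part of the original
    # pipeline): photograph a reference object of known size (e.g. a ruler) in
    # the same plane as the wound, measure its length in pixels, and divide by
    # its real length in cm to obtain the scale.
    @staticmethod
    def calibrate_px_per_cm(reference_px, reference_cm):
        """Return the pixel-per-cm scale from a reference measurement."""
        return reference_px / reference_cm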

    def detect_wound(self, image_cv):
        """Detects the wound region using YOLOv8."""
        st.session_state.messages.append({"role": "assistant", "content": "Detecting wound..."})
        results = self.yolo_model(image_cv)
        boxes = results[0].boxes.xyxy.cpu().numpy()  # (N, 4) array of xyxy boxes
        if len(boxes) == 0:
            return None, None
        # Take the first (highest-confidence) detection as the wound
        box = boxes[0]
        x1, y1, x2, y2 = map(int, box[:4])
        detected_region = image_cv[y1:y2, x1:x2]
        return detected_region, (x1, y1, x2, y2)
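
    # Alternative selection sketch (an assumption, not in the original app):
    # if several boxes are returned, the largest by area could be chosen
    # instead of the first:
    #   areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    #   box = boxes[int(np.argmax(areas))]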

    def segment_wound(self, detected_region):
        """Segments the wound from the detected region using the segmentation model."""
        st.session_state.messages.append({"role": "assistant", "content": "Segmenting wound area..."})
        # Resize and normalize to the segmentation model's expected input
        resized = cv2.resize(detected_region, (256, 256)) / 255.0
        input_tensor = np.expand_dims(resized, axis=0)
        pred_mask = self.seg_model.predict(input_tensor)[0]
        # Threshold the predicted probability map into a binary mask
        binary_mask = (pred_mask[:, :, 0] > 0.5).astype(np.uint8)
        return binary_mask
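
    # Optional cleanup sketch (an assumption, not in the original app): stray
    # speckles in the binary mask could be removed with a morphological
    # opening before the area is estimated:
    #   kernel = np.ones((3, 3), np.uint8)
    #   binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, kernel)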

    def estimate_area(self, mask):
        """Estimates the wound area (cm²) from the binary mask."""
        st.session_state.messages.append({"role": "assistant", "content": "Estimating wound area..."})
        # Note: the mask is at the 256x256 model resolution, so px_per_cm must
        # be calibrated at that same scale for the estimate to be meaningful.
        pixel_area = np.sum(mask > 0)
        area_cm2 = pixel_area / (self.px_per_cm ** 2)
        return round(area_cm2, 2)
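
    # Worked example: at 38 px/cm, a mask covering 14,440 pixels yields
    # 14440 / 38**2 = 14440 / 1444 = 10.0 cm².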

    def classify_wound(self, detected_region):
        """Classifies the wound type using the image-classification pipeline."""
        st.session_state.messages.append({"role": "assistant", "content": "Classifying wound type..."})
        try:
            # Convert the numpy array to a PIL Image for the pipeline
            pil_image = Image.fromarray(detected_region)
            # Save to a temporary file, since the pipeline accepts a file path or PIL Image
            with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
                pil_image.save(tmp.name)
                tmp_path = tmp.name
            result = self.classify_pipe(tmp_path)
            os.unlink(tmp_path)  # clean up the temporary file
            return result[0]["label"]
        except Exception as e:
            st.warning(f"Could not classify wound type: {e}")
            return "Unknown"

    def generate_recommendations(self, image, patient_info, analysis_results):
        """Generates a detailed assessment and treatment plan using MedGemma."""
        st.session_state.messages.append({"role": "assistant", "content": "Generating expert recommendations with MedGemma..."})
        # MedGemma needs a PIL Image
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)

        # Construct the prompt for MedGemma
        prompt_text = f"""Patient Info:
- Age: {patient_info['age']}
- Diabetic: {patient_info['diabetic']}
- Wound Type: {analysis_results['wound_type']}
- Area: {analysis_results['area_cm2']} cm²
- Signs of infection: {patient_info['infection']}

Please act as a highly experienced wound care specialist. Provide a comprehensive wound assessment and a detailed treatment plan. Structure your response clearly with the following sections:
1. **Wound Assessment:** Describe the wound's characteristics, potential causes, and current state based on the image and provided data.
2. **Recommended Treatment Plan:** Outline a primary course of action, including general principles of wound management.
3. **Cleaning and Dressing Protocol:** Provide specific, step-by-step instructions for wound cleaning and appropriate dressing choices.
4. **Red Flags & When to Seek Professional Medical Attention:** List critical signs and symptoms that indicate complications or require immediate consultation with a doctor or wound care nurse.
5. **Follow-up Schedule:** Suggest a realistic timeline for wound reassessment and monitoring progress.

**Important Disclaimer:** This information is for educational purposes only and should not replace professional medical advice. Always consult a qualified healthcare provider for diagnosis and treatment.
"""

        # MedGemma expects multimodal chat messages in this structure
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are a wound care expert."}],
            },
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt_text},
                    {"type": "image", "image": image},
                ],
            },
        ]

        # Tokenize the chat (text + image) with the processor's chat template
        inputs = self.medgemma_processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.medgemma_model.device, dtype=torch.bfloat16)
        input_len = inputs["input_ids"].shape[-1]

        # Generate, then decode only the newly generated tokens (everything
        # after the prompt), which yields just the assistant's reply
        with torch.inference_mode():
            output = self.medgemma_model.generate(**inputs, max_new_tokens=1000, do_sample=True, temperature=0.7)
        response = self.medgemma_processor.decode(output[0][input_len:], skip_special_tokens=True)
        return response.strip()
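
    # Decoding note (an aside, not from the original app): for more
    # reproducible output, sampling could be disabled, e.g.
    # self.medgemma_model.generate(**inputs, max_new_tokens=1000, do_sample=False).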

    def run_full_analysis(self, image, patient_info):
        """Executes the entire analysis pipeline."""
        st.session_state.messages.append({"role": "assistant", "content": "Starting analysis..."})
        image_cv = np.array(image.convert("RGB"))

        try:
            detected_region, box = self.detect_wound(image_cv)
            if detected_region is None:
                st.error("Agent Error: No wound could be detected in the image. Please try another image.")
                st.session_state.clear()
                return None
        except Exception as e:
            st.error(f"Error during wound detection: {e}")
            st.session_state.clear()
            return None

        try:
            mask_resized = self.segment_wound(detected_region)
            area_cm2 = self.estimate_area(mask_resized)
        except Exception as e:
            st.error(f"Error during wound segmentation or area estimation: {e}")
            st.session_state.clear()
            return None

        try:
            wound_type = self.classify_wound(detected_region)
        except Exception as e:
            st.error(f"Error during wound classification: {e}")
            st.session_state.clear()
            return None

        analysis_results = {
            "box": box,
            "detected_region": detected_region,
            "mask": mask_resized,
            "area_cm2": area_cm2,
            "wound_type": wound_type,
        }

        try:
            recommendations = self.generate_recommendations(image, patient_info, analysis_results)
            analysis_results["recommendations"] = recommendations
        except Exception as e:
            st.error(f"Error during recommendation generation: {e}")
            st.session_state.clear()
            return None

        st.session_state.messages.append({"role": "assistant", "content": "Analysis complete. See results below."})
        return analysis_results
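
# Programmatic usage sketch (hypothetical file name and values, shown for
# illustration; the Streamlit UI below is the actual entry point, and the
# agent's progress messages assume an active Streamlit session):
#   models = load_all_models()
#   agent = WoundCareAgent(models)
#   results = agent.run_full_analysis(
#       Image.open("wound.jpg"),
#       {"age": 50, "diabetic": "No", "infection": "No"},
#   )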

# --- UI Layout ---
st.title("🩹 SmartHeal: The Agentic Wound Care Assistant")

# Initialize session state
if "analysis_results" not in st.session_state:
    st.session_state.analysis_results = None
if "messages" not in st.session_state:
    st.session_state.messages = []

# --- Sidebar for Inputs ---
with st.sidebar:
    st.header("📋 Patient & Image Input")
    uploaded_file = st.file_uploader("1. Upload a clear wound image", type=["jpg", "jpeg", "png"])

    with st.form("patient_form"):
        st.write("2. Enter Patient Details")
        age = st.number_input("Patient Age", min_value=1, max_value=120, value=50)
        diabetic = st.radio("Is the patient diabetic?", ["No", "Yes"], index=0)
        infection = st.radio("Are there visible signs of infection (e.g., pus, redness, swelling)?", ["No", "Yes"], index=0)
        col1, col2 = st.columns(2)
        with col1:
            submitted = st.form_submit_button("🔍 Analyze Wound", use_container_width=True)
        with col2:
            cleared = st.form_submit_button("❌ Clear", use_container_width=True)

if cleared:
    st.session_state.analysis_results = None
    st.session_state.messages = []
    st.rerun()

# --- Main Content Area ---
if submitted and uploaded_file:
    # Load models and instantiate the agent
    models = load_all_models()
    agent = WoundCareAgent(models)

    # Collect patient info
    patient_info = {"age": age, "diabetic": diabetic, "infection": infection}

    # Open the uploaded image
    image = Image.open(uploaded_file)

    # Clear previous results and run a new analysis
    st.session_state.analysis_results = None
    st.session_state.messages = []
    st.session_state.messages.append({"role": "user", "content": "Analyzing the uploaded wound image..."})
    with st.spinner("The SmartHeal Agent is at work..."):
        st.session_state.analysis_results = agent.run_full_analysis(image, patient_info)

# Display chat messages from the agent's process
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# Display final results in tabs if analysis is complete
if st.session_state.analysis_results and uploaded_file:
    results = st.session_state.analysis_results
    image = Image.open(uploaded_file)  # re-open the image for display
    image_cv = np.array(image.convert("RGB"))

    st.header("✅ Analysis Complete")
    tab1, tab2, tab3 = st.tabs(["📋 **Expert Recommendations**", "🔬 **Vision Analysis**", "📄 **Download Report**"])

    with tab1:
        st.markdown(results["recommendations"])

    with tab2:
        st.subheader("Wound Detection & Segmentation")
        # Create the overlay
        x1, y1, x2, y2 = results["box"]
        overlay_image = image_cv.copy()
        # Resize the 256x256 mask back to the detected region's size
        mask_for_overlay = cv2.resize(results["mask"], (x2 - x1, y2 - y1), interpolation=cv2.INTER_NEAREST)
        # Build a colored mask for blending
        colored_mask_region = np.zeros_like(overlay_image[y1:y2, x1:x2])
        colored_mask_region[mask_for_overlay > 0] = [255, 0, 0]  # red for the wound (RGB)
        # Blend the detected region with the colored mask
        overlay_image[y1:y2, x1:x2] = cv2.addWeighted(overlay_image[y1:y2, x1:x2], 0.7, colored_mask_region, 0.3, 0)
        # Draw the bounding box
        cv2.rectangle(overlay_image, (x1, y1), (x2, y2), (0, 255, 0), 2)  # green box
        st.image(overlay_image, caption="Detected Wound with Segmentation Overlay", use_container_width=True)
        st.metric(label="Estimated Wound Area", value=f"{results['area_cm2']} cm²")
        st.metric(label="Classified Wound Type", value=f"{results['wound_type']}")

    with tab3:
        st.subheader("Download Full Report")
        st.download_button(
            label="📥 Download as Text File",
            data=results["recommendations"],
            file_name=f"wound_report_{uploaded_file.name.split('.')[0]}.txt",
            mime="text/plain",
        )
else:
    # Prompt the user to upload an image and fill in patient details
    st.info("Please upload an image and patient details in the sidebar to start.")