File size: 14,686 Bytes
ce45ba9
 
 
 
 
293bc49
d451e55
0ab2b4c
b450aae
0ab2b4c
b450aae
d451e55
0ab2b4c
b450aae
 
 
d451e55
b450aae
 
 
 
 
 
 
d451e55
b450aae
d451e55
 
b450aae
 
 
 
 
 
0ab2b4c
 
b450aae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0ab2b4c
b450aae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0ab2b4c
b450aae
 
 
 
 
 
 
d451e55
b450aae
 
 
 
 
 
293bc49
b450aae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
import os

# Streamlit (and other libraries that write under ~/.config) need a writable
# home directory: point HOME at the current working directory (an alternative
# would be "/tmp") and keep matplotlib's config directory inside it too.
# Kept above the streamlit import so the values are in place when it loads.
os.environ.update(
    {
        "HOME": os.getcwd(),
        "MPLCONFIGDIR": os.path.join(os.getcwd(), ".config"),
    }
)
import streamlit as st
import numpy as np
import cv2
import torch
import tempfile
from PIL import Image
from tensorflow.keras.models import load_model
from transformers import pipeline, AutoProcessor, LlavaForConditionalGeneration
import io
import os  # Import the os module to access environment variables
from ultralytics import YOLO  # Import YOLO from ultralytics

# --- Page Configuration ---
# Issued before any other Streamlit command in this script, as set_page_config
# expects to be the first Streamlit call.
st.set_page_config(
    layout="wide",
    initial_sidebar_state="expanded",
    page_title="SmartHeal Wound Care Agent",
    page_icon="🩹",
)

# --- Model Loading (Cached for performance) ---
@st.cache_resource
def load_all_models():
    """Load every model the wound care agent needs, once per server process.

    ``st.cache_resource`` memoises the returned tuple, so Streamlit reruns reuse
    the already-loaded models instead of loading them again.

    The two local model files default to the paths this app was written
    against; they can be overridden with the ``SMARTHEAL_DETECTION_MODEL`` and
    ``SMARTHEAL_SEGMENTATION_MODEL`` environment variables.

    Returns:
        tuple: ``(detection_model, segmentation_model, classification_pipe,
        medgemma_model, medgemma_processor)`` -- the order in which
        ``WoundCareAgent.__init__`` unpacks it.

    If no Hugging Face token is configured, or any model fails to load, an
    error is rendered in the app and the script run is halted with
    ``st.stop()``.
    """
    # Get Hugging Face token from environment variable.
    hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")
    if not hf_token:
        st.error("Fatal Error: Hugging Face token not found in environment variables (HF_TOKEN or HUGGING_FACE_HUB_TOKEN).")
        st.stop()

    try:
        # MedGemma 4B-IT is a Gemma 3 image-text-to-text checkpoint; loading it
        # through the LLaVA class does not match its architecture, so use the
        # matching auto class. Imported here (inside the try) so an outdated
        # transformers install surfaces through the friendly error below.
        from transformers import AutoModelForImageTextToText

        # YOLOv8 detection model.
        detection_model = YOLO(
            os.getenv("SMARTHEAL_DETECTION_MODEL", "/home/ubuntu/upload/best(1).pt")
        )

        # Keras segmentation model; compile=False since it is only used via predict().
        segmentation_model = load_model(
            os.getenv("SMARTHEAL_SEGMENTATION_MODEL", "/home/ubuntu/upload/segmentation_model.h5"),
            compile=False,
        )

        # Classification pipeline (token passed for private models / rate limits).
        classification_pipe = pipeline("image-classification", model="Hemg/Wound-classification", token=hf_token)

        # Med-Gemma for the written assessment.
        medgemma_model_id = "google/medgemma-4b-it"
        medgemma_processor = AutoProcessor.from_pretrained(medgemma_model_id, token=hf_token)
        # Fall back to CPU instead of crashing on machines without CUDA.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        medgemma_model = AutoModelForImageTextToText.from_pretrained(
            medgemma_model_id,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            token=hf_token,
        ).to(device)
    except Exception as e:
        st.error(f"Fatal Error: Could not load models. Please check model paths, dependencies, and Hugging Face token. Details: {e}")
        st.stop()

    return detection_model, segmentation_model, classification_pipe, medgemma_model, medgemma_processor


# --- Agent Class ---
class WoundCareAgent:
    """Run the wound care analysis pipeline as a sequence of agent steps.

    Steps: YOLO detection -> segmentation of the detected crop -> area
    estimate -> wound-type classification -> Med-Gemma recommendations. Each
    step appends a progress message to ``st.session_state.messages`` so the
    page can replay the agent's work as a chat transcript.
    """

    def __init__(self, models, px_per_cm=38):
        """Store the models produced by ``load_all_models``.

        Args:
            models: 5-tuple ``(yolo_model, seg_model, classify_pipe,
                medgemma_model, medgemma_processor)``.
            px_per_cm: Image pixels per centimetre, used to convert a mask's
                pixel count into cm². The default of 38 is an example value
                and should be calibrated for real-world use.
        """
        (
            self.yolo_model,
            self.seg_model,
            self.classify_pipe,
            self.medgemma_model,
            self.medgemma_processor,
        ) = models
        self.px_per_cm = px_per_cm

    @staticmethod
    def _log(content):
        """Append an assistant progress message to the session transcript."""
        st.session_state.messages.append({"role": "assistant", "content": content})

    @staticmethod
    def _fail(message):
        """Show *message* as an error, reset the transcript, and return ``None``.

        Only the transcript is reset (as the sidebar "Clear" button does).
        Clearing the whole of ``st.session_state`` would delete the
        ``messages`` key, which the page reads right after analysis returns.
        """
        st.error(message)
        st.session_state.messages = []
        return None

    def detect_wound(self, image_cv):
        """Detect the wound with YOLOv8 and crop it out of the image.

        Args:
            image_cv: RGB image as an ``(H, W, 3)`` numpy array.

        Returns:
            ``(detected_region, (x1, y1, x2, y2))`` for the highest-confidence
            detection, with the box clipped to the image bounds so the crop
            and box always agree in size; ``(None, None)`` when nothing is
            detected or the clipped box is empty.
        """
        self._log("Detecting wound...")
        # Ultralytics interprets numpy input as BGR (OpenCV order) while this
        # array is RGB, so swap channels for inference only; cropping below
        # still uses the RGB array.
        results = self.yolo_model(cv2.cvtColor(image_cv, cv2.COLOR_RGB2BGR))
        boxes = results[0].boxes
        if boxes is None or len(boxes) == 0:
            return None, None

        xyxy = boxes.xyxy.cpu().numpy()
        confidences = boxes.conf.cpu().numpy()
        # Choose the most confident detection explicitly instead of relying on
        # the order in which the detector returns boxes.
        x1, y1, x2, y2 = map(int, xyxy[int(np.argmax(confidences))][:4])

        height, width = image_cv.shape[:2]
        x1, y1 = max(x1, 0), max(y1, 0)
        x2, y2 = min(x2, width), min(y2, height)
        if x2 <= x1 or y2 <= y1:
            return None, None
        return image_cv[y1:y2, x1:x2], (x1, y1, x2, y2)

    def segment_wound(self, detected_region):
        """Segment the wound inside the detected crop.

        The crop is resized to 256x256 and scaled to [0, 1] before prediction;
        the first output channel is thresholded at 0.5.

        Returns:
            A binary uint8 mask (1 = wound) at the segmentation model's output
            resolution (presumably 256x256 -- confirm against the model), NOT
            at crop resolution; resize it before measuring area in image pixels.
        """
        self._log("Segmenting wound area...")
        resized = cv2.resize(detected_region, (256, 256)) / 255.0
        pred_mask = self.seg_model.predict(np.expand_dims(resized, axis=0))[0]
        return (pred_mask[:, :, 0] > 0.5).astype(np.uint8)

    def estimate_area(self, mask):
        """Convert a binary mask's positive-pixel count to cm².

        ``px_per_cm`` is defined in original-image pixels, so *mask* must be
        at the detected crop's resolution. Returns the area rounded to two
        decimals.
        """
        self._log("Estimating wound area...")
        pixel_area = np.sum(mask > 0)
        return round(pixel_area / (self.px_per_cm ** 2), 2)

    def classify_wound(self, detected_region):
        """Classify the wound type of the crop; returns ``"Unknown"`` on failure.

        Returns:
            The label of the pipeline's first (top) prediction.
        """
        self._log("Classifying wound type...")
        try:
            # The image-classification pipeline accepts a PIL image directly,
            # so no temporary file (which previously leaked on errors) is needed.
            result = self.classify_pipe(Image.fromarray(detected_region))
            return result[0]["label"]
        except Exception as e:
            st.warning(f"Could not classify wound type: {e}")
            return "Unknown"

    def generate_recommendations(self, image, patient_info, analysis_results):
        """Ask Med-Gemma for a structured wound assessment and treatment plan.

        Args:
            image: The full wound image, as a PIL image or numpy array.
            patient_info: dict with ``"age"``, ``"diabetic"`` and ``"infection"``.
            analysis_results: dict with at least ``"wound_type"`` and
                ``"area_cm2"``.

        Returns:
            The model's generated reply as text, without the echoed prompt.
        """
        self._log("Generating expert recommendations with Med-Gemma...")

        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        # Uploaded PNGs can be RGBA or palette images; hand the processor RGB.
        image = image.convert("RGB")

        prompt_text = f"""Patient Info:
- Age: {patient_info['age']}
- Diabetic: {patient_info['diabetic']}
- Wound Type: {analysis_results['wound_type']}
- Area: {analysis_results['area_cm2']} cm²
- Signs of infection: {patient_info['infection']}

Please act as a highly experienced wound care specialist. Provide a comprehensive wound assessment and a detailed treatment plan. Structure your response clearly with the following sections:

1.  **Wound Assessment:** Describe the wound's characteristics, potential causes, and current state based on the image and provided data.
2.  **Recommended Treatment Plan:** Outline a primary course of action, including general principles of wound management.
3.  **Cleaning and Dressing Protocol:** Provide specific, step-by-step instructions for wound cleaning and appropriate dressing choices.
4.  **Red Flags & When to Seek Professional Medical Attention:** List critical signs and symptoms that indicate complications or require immediate consultation with a doctor or wound care nurse.
5.  **Follow-up Schedule:** Suggest a realistic timeline for wound reassessment and monitoring progress.

**Important Disclaimer:** This information is for educational purposes only and should not replace professional medical advice. Always consult a qualified healthcare provider for diagnosis and treatment.
"""

        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are a wound care expert."}]
            },
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt_text},
                    {"type": "image", "image": image}
                ]
            }
        ]

        # Without tokenize/return_dict the processor returns the rendered
        # template as a plain string; these flags yield a dict of tensors the
        # model can consume. add_generation_prompt opens the model's turn so it
        # answers instead of continuing the user message. Only floating-point
        # tensors are cast to the model dtype.
        inputs = self.medgemma_processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.medgemma_model.device, dtype=self.medgemma_model.dtype)
        prompt_length = inputs["input_ids"].shape[-1]

        with torch.inference_mode():
            output = self.medgemma_model.generate(**inputs, max_new_tokens=1000, do_sample=True, temperature=0.7)

        # Decode only the newly generated tokens so the prompt is not echoed.
        response = self.medgemma_processor.decode(output[0][prompt_length:], skip_special_tokens=True)
        return response.strip()

    def run_full_analysis(self, image, patient_info):
        """Run every pipeline step on *image*.

        Args:
            image: PIL image uploaded by the user.
            patient_info: dict with ``"age"``, ``"diabetic"`` and ``"infection"``.

        Returns:
            dict with ``"box"``, ``"detected_region"``, ``"mask"`` (at crop
            resolution), ``"area_cm2"``, ``"wound_type"`` and
            ``"recommendations"``; or ``None`` if any step fails, in which
            case an error is shown and the transcript is reset.
        """
        self._log("Starting analysis...")
        image_cv = np.array(image.convert("RGB"))

        try:
            detected_region, box = self.detect_wound(image_cv)
        except Exception as e:
            return self._fail(f"Error during wound detection: {e}")
        if detected_region is None:
            return self._fail("Agent Error: No wound could be detected in the image. Please try another image.")

        try:
            raw_mask = self.segment_wound(detected_region)
            # Bring the model-resolution mask back to the crop's size so the
            # pixel count is in original-image pixels, the unit px_per_cm uses.
            region_height, region_width = detected_region.shape[:2]
            mask_resized = cv2.resize(raw_mask, (region_width, region_height), interpolation=cv2.INTER_NEAREST)
            area_cm2 = self.estimate_area(mask_resized)
        except Exception as e:
            return self._fail(f"Error during wound segmentation or area estimation: {e}")

        # classify_wound already degrades to "Unknown" on pipeline errors; this
        # guard only catches failures outside that handling.
        try:
            wound_type = self.classify_wound(detected_region)
        except Exception as e:
            return self._fail(f"Error during wound classification: {e}")

        analysis_results = {
            "box": box,
            "detected_region": detected_region,
            "mask": mask_resized,
            "area_cm2": area_cm2,
            "wound_type": wound_type,
        }

        try:
            analysis_results["recommendations"] = self.generate_recommendations(image, patient_info, analysis_results)
        except Exception as e:
            return self._fail(f"Error during recommendation generation: {e}")

        self._log("Analysis complete. See results below.")
        return analysis_results


# --- UI Layout ---
st.title("🩹 SmartHeal: The Agentic Wound Care Assistant")

# Seed per-session state on the first script run; reruns keep existing values.
st.session_state.setdefault("analysis_results", None)
st.session_state.setdefault("messages", [])

# --- Sidebar for Inputs ---
# The image uploader lives outside the form; patient details and the two
# action buttons are submitted together as one form.
with st.sidebar:
    st.header("📋 Patient & Image Input")
    uploaded_file = st.file_uploader("1. Upload a clear wound image", type=["jpg", "jpeg", "png"])

    with st.form("patient_form"):
        st.write("2. Enter Patient Details")
        age = st.number_input("Patient Age", min_value=1, max_value=120, value=50)
        diabetic = st.radio("Is the patient diabetic?", ["No", "Yes"], index=0)
        infection = st.radio("Are there visible signs of infection (e.g., pus, redness, swelling)?", ["No", "Yes"], index=0)

        # Both buttons submit the same form; each returns True only on the run
        # in which it was pressed.
        col1, col2 = st.columns(2)
        with col1:
            submitted = st.form_submit_button("🚀 Analyze Wound", use_container_width=True)
        with col2:
            cleared = st.form_submit_button("❌ Clear", use_container_width=True)

# "Clear" drops the previous analysis and transcript, then reruns the script so
# the page redraws without them.
if cleared:
    st.session_state.analysis_results = None
    st.session_state.messages = []
    st.rerun()

# --- Main Content Area ---
# A new analysis starts only when the form is submitted with an image attached.
if submitted and uploaded_file:
    # load_all_models is cached, so only the first submission pays the load cost.
    models = load_all_models()
    agent = WoundCareAgent(models)

    patient_info = dict(age=age, diabetic=diabetic, infection=infection)
    image = Image.open(uploaded_file)

    # Forget the previous run and start a fresh transcript with the user's request.
    st.session_state.analysis_results = None
    st.session_state.messages = [
        {"role": "user", "content": "Analyzing the uploaded wound image..."}
    ]

    with st.spinner("The SmartHeal Agent is at work..."):
        st.session_state.analysis_results = agent.run_full_analysis(image, patient_info)

# Replay the agent's progress transcript as chat bubbles.
for entry in st.session_state.messages:
    speaker, text = entry["role"], entry["content"]
    with st.chat_message(speaker):
        st.markdown(text)

# Display final results in tabs once an analysis has completed. The overlay and
# the report filename are derived from the currently uploaded file, so results
# are only rendered while a file is present (Image.open(None) would crash).
if st.session_state.analysis_results and uploaded_file is not None:
    results = st.session_state.analysis_results
    image = Image.open(uploaded_file)  # Re-open image for display
    image_cv = np.array(image.convert("RGB"))

    st.header("✅ Analysis Complete")
    tab1, tab2, tab3 = st.tabs(["📝 **Expert Recommendations**", "🔬 **Vision Analysis**", "📄 **Download Report**"])

    with tab1:
        st.markdown(results["recommendations"])

    with tab2:
        st.subheader("Wound Detection & Segmentation")

        x1, y1, x2, y2 = results["box"]
        overlay_image = image_cv.copy()
        region = overlay_image[y1:y2, x1:x2]

        # Size the mask to the actual crop: slicing truncates at the image
        # edge, so the crop can be smaller than the box if the box overhangs.
        region_height, region_width = region.shape[:2]
        mask_for_overlay = cv2.resize(results["mask"], (region_width, region_height), interpolation=cv2.INTER_NEAREST)

        # Paint wound pixels red, then blend 70/30 with the original crop.
        colored_mask_region = np.zeros_like(region)
        colored_mask_region[mask_for_overlay > 0] = [255, 0, 0]
        overlay_image[y1:y2, x1:x2] = cv2.addWeighted(region, 0.7, colored_mask_region, 0.3, 0)

        # Draw the detection box in green.
        cv2.rectangle(overlay_image, (x1, y1), (x2, y2), (0, 255, 0), 2)

        st.image(overlay_image, caption="Detected Wound with Segmentation Overlay", use_column_width=True)
        st.metric(label="Estimated Wound Area", value=f"{results['area_cm2']} cm²")
        st.metric(label="Classified Wound Type", value=f"{results['wound_type']}")

    with tab3:
        st.subheader("Download Full Report")
        # splitext keeps multi-dot stems intact ("left.foot.jpg" -> "left.foot").
        report_stem = os.path.splitext(uploaded_file.name)[0]
        st.download_button(
            label="📥 Download as Text File",
            data=results["recommendations"],
            file_name=f"wound_report_{report_stem}.txt",
            mime="text/plain"
        )
else:
    # Show information prompting the user to upload an image and fill in patient details.
    st.info("Please upload an image and patient details in the sidebar to start.")