File size: 8,607 Bytes
0e2bcdc
ab3691c
 
0e2bcdc
2619f31
ab3691c
7dda7cf
831ed15
17d50ed
7dda7cf
831ed15
7ff1d87
17d50ed
831ed15
7dda7cf
 
 
 
 
17d50ed
7dda7cf
 
 
 
 
7ff1d87
 
 
17d50ed
 
 
7ff1d87
 
17d50ed
 
 
 
 
 
 
7ff1d87
17d50ed
7ff1d87
 
7dda7cf
17d50ed
bfe1160
 
 
7ff1d87
7dda7cf
7ff1d87
 
17d50ed
 
 
7dda7cf
 
 
 
17d50ed
bfe1160
7dda7cf
bfe1160
7dda7cf
 
17d50ed
 
7dda7cf
17d50ed
7dda7cf
17d50ed
 
7ff1d87
17d50ed
7dda7cf
7ff1d87
 
bfe1160
17d50ed
 
 
7dda7cf
 
 
17d50ed
 
7dda7cf
 
17d50ed
 
 
7dda7cf
 
 
af8bbd0
17d50ed
7dda7cf
17d50ed
bfe1160
17d50ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7dda7cf
 
17d50ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7dda7cf
17d50ed
 
37ac18e
7dda7cf
dd6e4e2
7dda7cf
17d50ed
 
 
 
bfe1160
 
 
17d50ed
7dda7cf
17d50ed
bfe1160
 
17d50ed
 
7ff1d87
 
bfe1160
7dda7cf
 
17d50ed
7ff1d87
17d50ed
 
 
 
 
 
 
 
7ff1d87
bfe1160
 
 
17d50ed
bfe1160
 
 
 
7ff1d87
17d50ed
bfe1160
0e2bcdc
17d50ed
97562fa
0e2bcdc
17d50ed
 
 
 
7dda7cf
0e2bcdc
bfe1160
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
from fastapi import FastAPI, File, UploadFile, HTTPException, Response
import cv2
import numpy as np
import io
from typing import Union

# --- Load Models ---
def load_models():
    """Attempt to load the segmentation and detection models.

    Each model is loaded independently and best-effort: any exception raised
    while importing its framework or reading its weights file is reported on
    stdout and that model is returned as ``None``, so the service can still
    start with one or both models missing.

    Returns:
        tuple: ``(segmentation_model, yolo_model)``.
    """

    def _keras_segmenter():
        # Import inside the try so a missing TensorFlow install is tolerated.
        model = None
        try:
            import tensorflow as tf
            # A valid Keras model file is expected at this relative path.
            model = tf.keras.models.load_model("segmentation_model.h5")
            print("✅ Segmentation model loaded.")
        except Exception as err:
            print(f"⚠️ Failed to load segmentation model: {err}")
        return model

    def _yolo_detector():
        # Import inside the try so a missing ultralytics install is tolerated.
        model = None
        try:
            from ultralytics import YOLO
            # A valid YOLO weights file is expected at this relative path.
            model = YOLO("best.pt")
            print("✅ YOLO model loaded.")
        except Exception as err:
            print(f"⚠️ Failed to load YOLO model: {err}")
        return model

    # Tuple elements are evaluated left to right: segmentation first, as before.
    return _keras_segmenter(), _yolo_detector()

# Models are loaded once, at import time. Either may be None if its framework
# or weights file is unavailable; segment_wound then falls back to KMeans and
# analyze uses the whole image as the region of interest.
segmentation_model, yolo_model = load_models()

# --- Configuration ---
# Scale factor (pixels per centimetre) used to convert pixel measurements to cm.
# Calibrate it by photographing a ruler: measure a known length (e.g., 5 cm)
# in pixels, then divide that pixel count by the length in cm.
# NOTE(review): one global scale assumes every photo is taken at the same
# distance/zoom as the calibration shot — confirm this holds for real inputs.
PIXELS_PER_CM = 50.0

app = FastAPI(
    title="Wound Analyzer",
    version="10.2", # Version with improved depth logic
    description="Analyzes wound images, keeping the original API contract."
)

# --- Image Processing ---
def preprocess_image(image: np.ndarray) -> np.ndarray:
    """Enhance contrast and reduce noise to help segmentation.

    Pipeline: 3x3 median blur, CLAHE on the lightness channel of LAB colour
    space, then a gamma curve with gamma = 1.2 (which darkens mid-tones).

    Args:
        image: 8-bit, 3-channel BGR image (as produced by ``cv2.imdecode``
            with ``cv2.IMREAD_COLOR``).

    Returns:
        A new uint8 BGR image with the same shape as ``image``.
    """
    img_denoised = cv2.medianBlur(image, 3)

    # Equalise only lightness so the a/b (colour) channels are left untouched.
    lab = cv2.cvtColor(img_denoised, cv2.COLOR_BGR2LAB)
    l, a, b = cv2.split(lab)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    lab_clahe = cv2.merge((clahe.apply(l), a, b))
    result = cv2.cvtColor(lab_clahe, cv2.COLOR_LAB2BGR)

    # Apply gamma through a 256-entry lookup table. Each input level goes
    # through exactly the same float64 formula and truncating uint8 cast as a
    # per-pixel computation would, so the output is identical — but we avoid
    # allocating several full-resolution float64 temporaries per request.
    gamma = 1.2
    levels = np.arange(256, dtype=np.float64) / 255.0
    lut = np.clip(levels ** gamma * 255, 0, 255).astype(np.uint8)
    return lut[result]

def segment_wound(image: np.ndarray) -> np.ndarray:
    """Return a binary wound mask for a preprocessed BGR image.

    Uses the Keras segmentation model when it was loaded. If the model is
    unavailable, or prediction raises, falls back to 2-cluster KMeans on pixel
    colour and keeps the cluster whose centre has the highest LAB ``a``
    (green-to-red) value.

    Args:
        image: uint8 BGR image (the preprocessed region of interest).

    Returns:
        uint8 array of shape ``image.shape[:2]`` with values 0 or 255.
    """
    if segmentation_model is not None:
        try:
            input_size = segmentation_model.input_shape[1:3]
            # cv2.resize takes (width, height); input_shape is (batch, h, w, ...).
            resized = cv2.resize(image, (input_size[1], input_size[0]))
            norm = np.expand_dims(resized / 255.0, axis=0)
            prediction = segmentation_model.predict(norm, verbose=0)
            # Handle models with multiple outputs: use the first output.
            if isinstance(prediction, list):
                prediction = prediction[0]
            prediction = prediction[0]  # drop the batch dimension
            mask = cv2.resize(prediction.squeeze(), (image.shape[1], image.shape[0]))
            return (mask >= 0.5).astype(np.uint8) * 255
        except Exception as e:
            print(f"⚠️ Segmentation model prediction failed: {e}. Falling back to KMeans.")

    # Fallback method using colour clustering if the primary model fails.
    height, width = image.shape[:2]
    # cv2.kmeans needs at least as many samples as clusters (K=2); a degenerate
    # 1-pixel ROI would otherwise raise and surface as an HTTP 500.
    if height * width < 2:
        return np.zeros((height, width), dtype=np.uint8)

    Z = image.reshape((-1, 3)).astype(np.float32)
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
    _, labels, centers = cv2.kmeans(Z, 2, None, criteria, 5, cv2.KMEANS_PP_CENTERS)
    centers_lab = cv2.cvtColor(centers.reshape(1, -1, 3).astype(np.uint8), cv2.COLOR_BGR2LAB)[0]
    wound_idx = np.argmax(centers_lab[:, 1])  # Assume wound is the reddest cluster
    mask = (labels.reshape((height, width)) == wound_idx).astype(np.uint8) * 255
    return mask

def calculate_metrics(mask: np.ndarray, original_image: np.ndarray):
    """
    Calculate wound size, depth-score and moisture metrics from a mask.

    Args:
        mask: Single-channel uint8 mask, non-zero on wound pixels, with the
            same height/width as ``original_image``.
        original_image: BGR image the mask belongs to. Intensity statistics
            are taken from this image (the caller passes the un-preprocessed
            ROI so CLAHE/gamma do not distort them).

    Returns:
        dict with keys ``area_cm2``, ``length_cm``, ``breadth_cm``,
        ``depth_cm`` and ``moisture`` (rounded floats). Sizes are converted
        with ``PIXELS_PER_CM``. ``depth_cm`` is NOT a physical depth: it is a
        unitless 0-10 score derived from the grey-level standard deviation
        inside the wound (the key name is kept for API compatibility).
        ``moisture`` is a 0-100 score derived from the same statistic, so the
        two move in opposite directions. Everything is 0.0 for an empty mask.
    """
    area_px = cv2.countNonZero(mask)
    if area_px == 0:
        return dict(area_cm2=0.0, length_cm=0.0, breadth_cm=0.0, depth_cm=0.0, moisture=0.0)

    # --- Area and Dimensions ---
    # Area counts every mask pixel; length/breadth use only the largest blob.
    area_cm2 = round(area_px / (PIXELS_PER_CM ** 2), 2)
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        # Defensive branch; rounding now matches the main return path.
        return dict(area_cm2=area_cm2, length_cm=0.0, breadth_cm=0.0, depth_cm=0.0, moisture=0.0)

    # minAreaRect -> ((cx, cy), (w, h), angle): the rotated box's longer side
    # is the length, the shorter side the breadth.
    _, (w, h), _ = cv2.minAreaRect(max(contours, key=cv2.contourArea))
    length_cm = max(w, h) / PIXELS_PER_CM
    breadth_cm = min(w, h) / PIXELS_PER_CM

    # --- Depth score and Moisture ---
    gray_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2GRAY)
    # Standard deviation of grey levels inside the wound, normalised by 127
    # (8-bit data cannot have a std above 127.5, so this is roughly 0-1).
    # Computed once and shared by both scores.
    spread = float(np.std(gray_image[mask.astype(bool)])) / 127.0
    # Heuristic: more shadow/highlight variation is read as a deeper wound.
    depth_score = spread * 10.0
    moisture = max(0.0, 100.0 * (1 - spread))

    return dict(
        area_cm2=area_cm2,
        length_cm=round(length_cm, 2),
        breadth_cm=round(breadth_cm, 2),
        depth_cm=round(depth_score, 1),  # unitless 0-10 score, see docstring
        moisture=round(moisture, 1),
    )

def draw_overlay(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Return a copy of ``image`` with a banded heatmap and white outline on the wound.

    Each wound pixel is coloured by its distance from the mask edge,
    min-max normalised to [0, 1]: green near the edge, blue in the middle
    band, yellow at the core. The colours are alpha-blended (40%) only inside
    the mask; pixels outside it are copied unchanged. The outer contour of the
    mask is then drawn in white, 2 px thick.
    """
    dist = cv2.distanceTransform(mask, cv2.DIST_L2, 5)
    cv2.normalize(dist, dist, 0, 1.0, cv2.NORM_MINMAX)

    # BGR colour per distance band; the three bands are disjoint.
    bands = (
        (dist >= 0.66, (0, 255, 255)),                   # Yellow - Core
        ((dist >= 0.33) & (dist < 0.66), (255, 0, 0)),   # Blue - Moderate
        ((dist > 0) & (dist < 0.33), (0, 255, 0)),       # Green - Periphery
    )
    heatmap = np.zeros_like(image, dtype=np.uint8)
    for region, colour in bands:
        heatmap[region] = colour

    alpha = 0.4
    inside = mask.astype(bool)
    tinted = cv2.addWeighted(image, 1 - alpha, heatmap, alpha, 0)
    overlay = image.copy()
    overlay[inside] = tinted[inside]

    outline, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(overlay, outline, -1, (255, 255, 255), 2)
    return overlay

# --- API Endpoint ---
def _detect_roi(image: np.ndarray) -> tuple:
    """Return ``(x1, y1, x2, y2)`` of the first YOLO box clamped to ``image``, or the full frame.

    Falls back to the whole image when the YOLO model is not loaded, finds no
    box, or raises during prediction.
    """
    height, width = image.shape[:2]
    full_frame = (0, 0, width, height)
    if yolo_model is None:
        return full_frame
    try:
        results = yolo_model.predict(image, verbose=False)
        if not results or not results[0].boxes:
            return full_frame
        coords = results[0].boxes[0].xyxy[0].cpu().numpy().astype(int)
    except Exception as e:
        print(f"⚠️ YOLO detection failed: {e}. Using full image as ROI.")
        return full_frame

    x1, y1, x2, y2 = (int(c) for c in coords[:4])
    # Clamp to the frame: NumPy treats a negative slice index as an offset from
    # the end, so an unclamped coordinate could silently crop the wrong region.
    x1, x2 = (min(max(v, 0), width) for v in (x1, x2))
    y1, y2 = (min(max(v, 0), height) for v in (y1, y2))
    return x1, y1, x2, y2


@app.post("/analyze_wound")
async def analyze(file: UploadFile = File(...)):
    """
    Accepts an image, analyzes the wound, and returns an annotated image
    with metrics in the response headers, maintaining the original API contract.

    Response: PNG of the annotated region of interest, with headers
    ``X-Length-Cm``, ``X-Breadth-Cm``, ``X-Depth-Cm`` (a unitless 0-10 score
    despite the name), ``X-Area-Cm2`` and ``X-Moisture``.

    Raises:
        HTTPException 400: empty upload or undecodable image.
        HTTPException 404: the region of interest is empty.
        HTTPException 500: the annotated image could not be PNG-encoded.
    """
    # NOTE(review): all CV/ML work below runs synchronously on the event loop
    # and blocks other requests; moving it to a threadpool would require making
    # the shared models safe for concurrent use first.
    contents = await file.read()
    # cv2.imdecode raises (instead of returning None) on an empty buffer, which
    # would surface as a 500; treat an empty upload as a client error.
    if not contents:
        raise HTTPException(status_code=400, detail="Invalid or corrupt image file.")
    original_image = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
    if original_image is None:
        raise HTTPException(status_code=400, detail="Invalid or corrupt image file.")

    # 1. Preprocess a copy of the image for object detection and segmentation.
    preprocessed_image = preprocess_image(original_image)

    # 2. Detect the ROI with YOLO if available, otherwise use the whole image.
    x1, y1, x2, y2 = _detect_roi(preprocessed_image)
    original_roi = original_image[y1:y2, x1:x2]
    preprocessed_roi = preprocessed_image[y1:y2, x1:x2]

    if original_roi.size == 0:
        raise HTTPException(status_code=404, detail="Wound region of interest could not be determined.")

    # 3. Get the mask from the preprocessed ROI.
    mask = segment_wound(preprocessed_roi)

    # 4. Calculate metrics using the ORIGINAL ROI for colour/intensity accuracy.
    metrics = calculate_metrics(mask, original_roi)

    # 5. Draw the overlay on the ORIGINAL ROI for correct visualization.
    annotated_image = draw_overlay(original_roi, mask)

    # 6. Encode the image and prepare the response.
    success, out_bytes = cv2.imencode(".png", annotated_image)
    if not success:
        raise HTTPException(status_code=500, detail="Failed to encode output image.")

    headers = {
        "X-Length-Cm": str(metrics["length_cm"]),
        "X-Breadth-Cm": str(metrics["breadth_cm"]),
        "X-Depth-Cm": str(metrics["depth_cm"]),
        "X-Area-Cm2": str(metrics["area_cm2"]),
        "X-Moisture": str(metrics["moisture"]),
    }
    return Response(content=out_bytes.tobytes(), media_type="image/png", headers=headers)