# wound_detect / predict.py
# (Hugging Face file-viewer residue removed; last change: "Update predict.py", commit 17d50ed)
from fastapi import FastAPI, File, UploadFile, HTTPException, Response
import cv2
import numpy as np
import io
from typing import Union
# --- Load Models ---
def load_models():
    """Load the segmentation and detection models, tolerating failures.

    Each model is loaded independently inside its own try block, so a
    missing dependency or weights file for one does not prevent the
    other from loading. A model that cannot be loaded is returned as
    None, letting the rest of the service fall back to classical
    image-processing methods.

    Returns:
        tuple: (segmentation_model_or_None, yolo_model_or_None)
    """
    seg_net = None
    detector = None

    try:
        import tensorflow as tf
        # Ensure you have a valid model file at this path
        seg_net = tf.keras.models.load_model("segmentation_model.h5")
        print("✅ Segmentation model loaded.")
    except Exception as err:
        print(f"⚠️ Failed to load segmentation model: {err}")

    try:
        from ultralytics import YOLO
        # Ensure you have a valid model file at this path
        detector = YOLO("best.pt")
        print("✅ YOLO model loaded.")
    except Exception as err:
        print(f"⚠️ Failed to load YOLO model: {err}")

    return seg_net, detector
# Load both models once at import time; requests share these instances.
# Either may be None, in which case the handlers fall back gracefully.
segmentation_model, yolo_model = load_models()

# --- Configuration ---
# This value should be calibrated by taking a picture of a ruler.
# Measure a known length (e.g., 5cm) in pixels, then divide pixels by cm.
PIXELS_PER_CM = 50.0

# FastAPI application object; uvicorn (or similar) serves this.
app = FastAPI(
    title="Wound Analyzer",
    version="10.2", # Version with improved depth logic
    description="Analyzes wound images, keeping the original API contract."
)
# --- Image Processing ---
def preprocess_image(image: np.ndarray) -> np.ndarray:
    """Enhance a BGR image for segmentation.

    Pipeline: median-blur denoise, CLAHE on the LAB lightness channel
    for local contrast, then a mild gamma correction.
    """
    denoised = cv2.medianBlur(image, 3)

    # Equalize contrast on the lightness channel only, leaving color intact.
    lightness, chan_a, chan_b = cv2.split(cv2.cvtColor(denoised, cv2.COLOR_BGR2LAB))
    equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    equalized = cv2.merge((equalizer.apply(lightness), chan_a, chan_b))
    enhanced = cv2.cvtColor(equalized, cv2.COLOR_LAB2BGR)

    # Gamma > 1 slightly darkens mid-tones after the CLAHE brightening.
    corrected = np.clip((enhanced / 255.0) ** 1.2 * 255, 0, 255)
    return corrected.astype(np.uint8)
def segment_wound(image: np.ndarray) -> np.ndarray:
    """Segment the wound in a preprocessed BGR image.

    Uses the neural segmentation model when available; on any failure
    (or when the model is absent) falls back to a 2-cluster KMeans in
    color space, taking the reddest cluster as the wound.

    Returns:
        uint8 mask of the same height/width as `image`, 255 inside the wound.
    """
    if segmentation_model:
        try:
            target_h, target_w = segmentation_model.input_shape[1:3]
            batch = np.expand_dims(cv2.resize(image, (target_w, target_h)) / 255.0, axis=0)
            pred = segmentation_model.predict(batch, verbose=0)
            # Multi-output models return a list; keep the first head.
            if isinstance(pred, list):
                pred = pred[0]
            pred = pred[0]  # drop the batch dimension
            full_mask = cv2.resize(pred.squeeze(), (image.shape[1], image.shape[0]))
            return (full_mask >= 0.5).astype(np.uint8) * 255
        except Exception as e:
            print(f"⚠️ Segmentation model prediction failed: {e}. Falling back to KMeans.")

    # Fallback: cluster pixels into 2 groups; the cluster whose center has
    # the highest LAB 'a' value (most red) is assumed to be the wound.
    pixels = image.reshape((-1, 3)).astype(np.float32)
    stop_rule = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
    _, cluster_ids, centers = cv2.kmeans(pixels, 2, None, stop_rule, 5, cv2.KMEANS_PP_CENTERS)
    centers_lab = cv2.cvtColor(centers.reshape(1, -1, 3).astype(np.uint8), cv2.COLOR_BGR2LAB)[0]
    wound_cluster = np.argmax(centers_lab[:, 1])
    return (cluster_ids.reshape(image.shape[:2]) == wound_cluster).astype(np.uint8) * 255
def calculate_metrics(mask: np.ndarray, original_image: np.ndarray):
    """
    Calculate all wound metrics from a binary mask and the original ROI.

    Args:
        mask: uint8 binary mask (0 / 255) marking wound pixels.
        original_image: the corresponding BGR region of interest, used for
            intensity-based depth and moisture estimates.

    Returns:
        dict with keys area_cm2, length_cm, breadth_cm, depth_cm and
        moisture. Note: depth_cm is a *relative* 0-10 shadow-based score
        derived from intensity variation, not a true depth in centimeters.
    """
    area_px = cv2.countNonZero(mask)
    if area_px == 0:
        # No wound pixels: every metric is zero.
        return dict(area_cm2=0.0, length_cm=0.0, breadth_cm=0.0, depth_cm=0.0, moisture=0.0)

    # --- Area and Dimensions ---
    area_cm2 = area_px / (PIXELS_PER_CM ** 2)
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:  # Mask has pixels but no contour could be extracted
        return dict(area_cm2=area_cm2, length_cm=0.0, breadth_cm=0.0, depth_cm=0.0, moisture=0.0)
    # Minimum-area rotated rectangle around the largest contour gives an
    # orientation-independent length and breadth.
    rect = cv2.minAreaRect(max(contours, key=cv2.contourArea))
    (w, h) = rect[1]
    length_cm, breadth_cm = max(w, h) / PIXELS_PER_CM, min(w, h) / PIXELS_PER_CM

    # --- Depth and Moisture Calculation ---
    mask_bool = mask.astype(bool)
    gray_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2GRAY)
    # **ENHANCED DEPTH CALCULATION**: Use standard deviation of pixel intensity.
    # A higher standard deviation implies more shadows and highlights, suggesting
    # greater depth. Computed once and shared with the moisture estimate
    # (the original code redundantly recomputed np.std over the same pixels).
    intensity_std_dev = np.std(gray_image[mask_bool])
    depth_score = (intensity_std_dev / 127.0) * 10.0  # Scale to a 0-10 range
    # Moisture is the complementary measure: smoother (less varied) intensity
    # inside the wound reads as wetter tissue.
    moisture = max(0.0, 100.0 * (1 - intensity_std_dev / 127.0))

    return dict(
        area_cm2=round(area_cm2, 2),
        length_cm=round(length_cm, 2),
        breadth_cm=round(breadth_cm, 2),
        depth_cm=round(depth_score, 1),  # Relative depth score, not centimeters
        moisture=round(moisture, 1)
    )
def draw_overlay(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Render a depth-zone heatmap and a white boundary over the wound.

    Zones come from a normalized distance transform of the mask:
    core (yellow), moderate (blue) and periphery (green). The heatmap
    is alpha-blended only inside the mask; the rest of the image is
    returned untouched.
    """
    depth_map = cv2.distanceTransform(mask, cv2.DIST_L2, 5)
    cv2.normalize(depth_map, depth_map, 0, 1.0, cv2.NORM_MINMAX)

    zone_colors = np.zeros_like(image, dtype=np.uint8)
    zone_colors[depth_map >= 0.66] = (0, 255, 255)                        # Yellow - Core
    zone_colors[(depth_map >= 0.33) & (depth_map < 0.66)] = (255, 0, 0)   # Blue - Moderate
    zone_colors[(depth_map > 0) & (depth_map < 0.33)] = (0, 255, 0)       # Green - Periphery

    # Blend the whole frame once, then copy only the in-mask pixels over.
    mix_ratio = 0.4
    blended_full = cv2.addWeighted(image, 1 - mix_ratio, zone_colors, mix_ratio, 0)
    inside = mask.astype(bool)
    result = image.copy()
    result[inside] = blended_full[inside]

    # Draw a clean, white contour around the wound boundary.
    outlines, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(result, outlines, -1, (255, 255, 255), 2)
    return result
# --- API Endpoint ---
@app.post("/analyze_wound")
async def analyze(file: UploadFile = File(...)) -> Response:
    """
    Accepts an image, analyzes the wound, and returns an annotated image
    with metrics in the response headers, maintaining the original API contract.

    Raises:
        HTTPException 400 if the upload cannot be decoded as an image,
        404 if the detected ROI is empty, 500 if PNG encoding fails.
    """
    contents = await file.read()
    original_image = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
    if original_image is None:
        raise HTTPException(status_code=400, detail="Invalid or corrupt image file.")
    # 1. Preprocess a copy of the image for object detection and segmentation
    preprocessed_image = preprocess_image(original_image)
    # 2. Detect ROI using YOLO if available, otherwise use the whole image
    # Default ROI is the full frame as (x1, y1, x2, y2).
    roi_coords = (0, 0, original_image.shape[1], original_image.shape[0])
    if yolo_model:
        try:
            results = yolo_model.predict(preprocessed_image, verbose=False)
            if results and results[0].boxes:
                # Only the first detected box is used as the ROI.
                coords = results[0].boxes[0].xyxy[0].cpu().numpy().astype(int)
                roi_coords = (coords[0], coords[1], coords[2], coords[3])
        except Exception as e:
            print(f"⚠️ YOLO detection failed: {e}. Using full image as ROI.")
    x1, y1, x2, y2 = roi_coords
    original_roi = original_image[y1:y2, x1:x2]
    preprocessed_roi = preprocessed_image[y1:y2, x1:x2]
    if original_roi.size == 0:
        raise HTTPException(status_code=404, detail="Wound region of interest could not be determined.")
    # 3. Get mask from the preprocessed ROI
    mask = segment_wound(preprocessed_roi)
    # 4. Calculate metrics using the ORIGINAL ROI for color/intensity accuracy
    metrics = calculate_metrics(mask, original_roi)
    # 5. Draw overlay on the ORIGINAL ROI for correct visualization
    annotated_image = draw_overlay(original_roi, mask)
    # 6. Encode image and prepare response
    success, out_bytes = cv2.imencode(".png", annotated_image)
    if not success:
        raise HTTPException(status_code=500, detail="Failed to encode output image.")
    # Metrics travel in custom X- headers so the body can remain a raw PNG.
    headers = {
        "X-Length-Cm": str(metrics["length_cm"]),
        "X-Breadth-Cm": str(metrics["breadth_cm"]),
        "X-Depth-Cm": str(metrics["depth_cm"]),
        "X-Area-Cm2": str(metrics["area_cm2"]),
        "X-Moisture": str(metrics["moisture"]),
    }
    return Response(content=out_bytes.tobytes(), media_type="image/png", headers=headers)