Spaces:
Running
Running
| from fastapi import FastAPI, File, UploadFile, HTTPException, Response | |
| import cv2 | |
| import numpy as np | |
| import io | |
| from typing import Union | |
# --- Load Models ---
def load_models():
    """Try to load both inference models, tolerating the failure of either.

    The two loaders run independently: a failure in one (package not
    installed, model file missing or unreadable, ...) is printed and that
    slot is left as None, while the other loader still runs.

    Returns:
        (segmentation_model, yolo_model) -- either element may be None.
    """

    def _load_segmentation():
        import tensorflow as tf
        # Ensure you have a valid model file at this path
        return tf.keras.models.load_model("segmentation_model.h5")

    def _load_yolo():
        from ultralytics import YOLO
        # Ensure you have a valid model file at this path
        return YOLO("best.pt")

    attempts = (
        (_load_segmentation, "✅ Segmentation model loaded.", "⚠️ Failed to load segmentation model"),
        (_load_yolo, "✅ YOLO model loaded.", "⚠️ Failed to load YOLO model"),
    )
    loaded = []
    for loader, success_msg, failure_prefix in attempts:
        model = None
        try:
            model = loader()
            print(success_msg)
        except Exception as exc:
            print(f"{failure_prefix}: {exc}")
        loaded.append(model)
    return tuple(loaded)
segmentation_model, yolo_model = load_models()  # loaded once, at import time

# --- Configuration ---
# Pixel-to-centimetre scale used to convert pixel measurements to cm.
# Calibrate it from a photo of a ruler: measure a known length (e.g. 5 cm)
# in pixels, then divide that pixel count by the length in cm.
PIXELS_PER_CM = 50.0

app = FastAPI(
    title="Wound Analyzer",
    description="Analyzes wound images, keeping the original API contract.",
    version="10.2",  # Version with improved depth logic
)
# --- Image Processing ---
def preprocess_image(
    image: np.ndarray,
    *,
    gamma: float = 1.2,
    clip_limit: float = 2.0,
    tile_grid_size: tuple = (8, 8),
) -> np.ndarray:
    """Enhance an image for segmentation: denoise, boost local contrast, apply gamma.

    Steps:
      1. 3x3 median blur to suppress speckle noise.
      2. CLAHE on the L channel of CIE L*a*b* (the a*/b* colour channels are
         left untouched), then back to BGR.
      3. Power-law curve ``out = 255 * (in / 255) ** gamma``; gamma > 1
         darkens mid-tones.

    Args:
        image: 8-bit, 3-channel BGR image.
        gamma: exponent of the tone curve in step 3.
        clip_limit: CLAHE contrast clip limit.
        tile_grid_size: CLAHE tile grid as (columns, rows).
        The defaults reproduce the previously hard-coded values.

    Returns:
        uint8 BGR image with the same shape as ``image``.
    """
    denoised = cv2.medianBlur(image, 3)
    l_chan, a_chan, b_chan = cv2.split(cv2.cvtColor(denoised, cv2.COLOR_BGR2LAB))
    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
    lab_eq = cv2.merge((clahe.apply(l_chan), a_chan, b_chan))
    contrasted = cv2.cvtColor(lab_eq, cv2.COLOR_LAB2BGR)
    # The tone curve depends only on the 8-bit input level. Evaluate it once
    # for the 256 possible levels and index into it, instead of building a
    # full-size float64 copy of the image and raising every pixel to a power.
    lut = np.clip((np.arange(256) / 255.0) ** gamma * 255, 0, 255).astype(np.uint8)
    return lut[contrasted]
def segment_wound(image: np.ndarray) -> np.ndarray:
    """Return a binary wound mask for a preprocessed image.

    The module-level ``segmentation_model`` is used when it is loaded. The
    function falls back to 2-cluster k-means on pixel colours when any of
    these holds:
      * the model is missing;
      * the model raises;
      * the model's output cannot be turned into a 2-D mask of the image's size.
    In the fallback, the cluster whose centre has the larger a* value
    (CIE L*a*b*) is taken as the wound.

    Args:
        image: 3-channel uint8 BGR image (e.g. the output of preprocess_image).

    Returns:
        uint8 array of shape ``image.shape[:2]``: 255 on wound pixels, 0 elsewhere.
        An image with fewer than two pixels yields an all-zero mask.
    """
    if segmentation_model is not None:
        try:
            # input_shape[1:3] is read as (height, width), i.e. this assumes
            # a channels-last model; cv2.resize takes (width, height).
            in_h, in_w = segmentation_model.input_shape[1:3]
            resized = cv2.resize(image, (in_w, in_h))
            batch = np.expand_dims(resized / 255.0, axis=0)
            prediction = segmentation_model.predict(batch, verbose=0)
            # Handle models with multiple outputs: use the first one.
            if isinstance(prediction, list):
                prediction = prediction[0]
            prediction = prediction[0]  # drop the batch dimension
            mask = cv2.resize(prediction.squeeze(), (image.shape[1], image.shape[0]))
            # A multi-channel output (e.g. per-class probabilities) survives
            # squeeze() and would produce a 3-D mask, while callers expect a
            # 2-D one. Treat that as a model failure so the fallback runs.
            if mask.shape != image.shape[:2]:
                raise ValueError(
                    f"model output resized to {mask.shape}, expected a single-channel mask"
                )
            return (mask >= 0.5).astype(np.uint8) * 255
        except Exception as e:
            print(f"⚠️ Segmentation model prediction failed: {e}. Falling back to KMeans.")

    # Fallback method: cluster pixel colours into two groups.
    samples = image.reshape((-1, 3)).astype(np.float32)
    if samples.shape[0] < 2:
        # cv2.kmeans requires at least as many samples as clusters. An image
        # this small cannot be split, so report no wound.
        return np.zeros(image.shape[:2], dtype=np.uint8)
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
    _, labels, centers = cv2.kmeans(samples, 2, None, criteria, 5, cv2.KMEANS_PP_CENTERS)
    # Convert the two BGR centres to L*a*b* and pick the one with the larger
    # a* (green-red axis), i.e. assume the wound is the redder cluster.
    centers_lab = cv2.cvtColor(centers.reshape(1, -1, 3).astype(np.uint8), cv2.COLOR_BGR2LAB)[0]
    wound_idx = np.argmax(centers_lab[:, 1])
    return (labels.reshape(image.shape[:2]) == wound_idx).astype(np.uint8) * 255
def calculate_metrics(mask: np.ndarray, original_image: np.ndarray) -> dict:
    """
    Calculate wound size, depth score and moisture score from a segmentation mask.

    Args:
        mask: single-channel uint8 mask, non-zero on wound pixels. It is used to
            index ``original_image``, so it must have the same height/width.
        original_image: 3-channel BGR image the mask was computed for.

    Returns:
        A dict of floats with these keys:
          * ``area_cm2``: wound pixel count converted with PIXELS_PER_CM.
          * ``length_cm`` / ``breadth_cm``: the long and short sides of the
            minimum-area rotated rectangle around the largest external contour.
          * ``depth_cm``: NOT a physical depth, despite the key name (kept for
            API compatibility). It is a relative score in [0, 10], based on the
            standard deviation of grey-level intensity inside the wound.
          * ``moisture``: a score in [0, 100] from the same standard deviation,
            inverted (lower spread gives a higher score).

        All values are 0.0 when the mask is empty.
    """
    area_px = cv2.countNonZero(mask)
    if area_px == 0:
        return dict(area_cm2=0.0, length_cm=0.0, breadth_cm=0.0, depth_cm=0.0, moisture=0.0)

    # --- Area and Dimensions ---
    area_cm2 = area_px / (PIXELS_PER_CM ** 2)
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        # Round like the main path so callers always see the same precision.
        return dict(
            area_cm2=round(area_cm2, 2), length_cm=0.0, breadth_cm=0.0, depth_cm=0.0, moisture=0.0
        )
    (w, h) = cv2.minAreaRect(max(contours, key=cv2.contourArea))[1]
    length_cm = max(w, h) / PIXELS_PER_CM
    breadth_cm = min(w, h) / PIXELS_PER_CM

    # --- Depth and Moisture ---
    gray_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2GRAY)
    # Both scores derive from the same statistic, so compute it once.
    # For 8-bit input the std can reach 127.5, so this ratio can slightly
    # exceed 1. The depth score is clamped to keep it in its 0-10 range.
    spread = float(np.std(gray_image[mask.astype(bool)])) / 127.0
    depth_score = min(10.0, spread * 10.0)
    moisture = max(0.0, 100.0 * (1 - spread))

    return dict(
        area_cm2=round(area_cm2, 2),
        length_cm=round(length_cm, 2),
        breadth_cm=round(breadth_cm, 2),
        depth_cm=round(depth_score, 1),  # relative score, not centimetres
        moisture=round(moisture, 1),
    )
def draw_overlay(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Return a copy of ``image`` with a zoned heat-map tint and outline over the mask.

    Each wound pixel is coloured by its distance to the nearest background
    pixel, min-max normalised to [0, 1]:
      * >= 0.66: yellow (core);
      * [0.33, 0.66): blue (moderate);
      * (0, 0.33): green (periphery).

    The tint is blended at 40% opacity and only inside the mask. Outer
    contours are then drawn in white, 2 px thick. ``mask`` must be an 8-bit
    single-channel image, as required by cv2.distanceTransform.
    """
    dist = cv2.distanceTransform(mask, cv2.DIST_L2, 5)
    cv2.normalize(dist, dist, 0, 1.0, cv2.NORM_MINMAX)

    # (region, BGR colour) pairs; the regions are disjoint.
    bands = (
        (dist >= 0.66, (0, 255, 255)),                  # Yellow - Core
        ((dist >= 0.33) & (dist < 0.66), (255, 0, 0)),  # Blue - Moderate
        ((dist > 0) & (dist < 0.33), (0, 255, 0)),      # Green - Periphery
    )
    heatmap = np.zeros_like(image, dtype=np.uint8)
    for region, bgr in bands:
        heatmap[region] = bgr

    alpha = 0.4
    tinted = cv2.addWeighted(image, 1 - alpha, heatmap, alpha, 0)
    inside = mask.astype(bool)[:, :, np.newaxis]
    result = np.where(inside, tinted, image)

    outline, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(result, outline, -1, (255, 255, 255), 2)
    return result
# --- API Endpoint ---
# NOTE(review): the source showed no route decorator on this handler, which
# left the FastAPI app with no routes at all. "/analyze" is an assumed path;
# confirm it matches what existing clients call.
@app.post("/analyze")
async def analyze(file: UploadFile = File(...)):
    """
    Accept an uploaded image, analyze the wound, and return an annotated PNG.

    The response body is the annotated region of interest: the first YOLO box
    when one is detected, otherwise the whole image. Metrics are returned as
    response headers:
      * X-Length-Cm
      * X-Breadth-Cm
      * X-Depth-Cm: a relative score from calculate_metrics, not a physical depth
      * X-Area-Cm2
      * X-Moisture

    Raises:
        HTTPException(400): the upload is empty or cannot be decoded as an image.
        HTTPException(404): the region of interest is empty.
        HTTPException(500): the annotated image cannot be encoded as PNG.
    """
    # NOTE(review): preprocessing and inference are synchronous and run on the
    # event loop, so concurrent requests are serialised. Offloading them to a
    # worker thread would first require confirming the models are safe to
    # share across threads.
    contents = await file.read()
    if not contents:
        # cv2.imdecode raises on an empty buffer instead of returning None,
        # which would otherwise surface as a 500.
        raise HTTPException(status_code=400, detail="Invalid or corrupt image file.")
    original_image = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
    if original_image is None:
        raise HTTPException(status_code=400, detail="Invalid or corrupt image file.")
    img_h, img_w = original_image.shape[:2]

    # 1. Preprocess a copy of the image for object detection and segmentation
    preprocessed_image = preprocess_image(original_image)

    # 2. Detect ROI using YOLO if available, otherwise use the whole image
    roi_coords = (0, 0, img_w, img_h)
    if yolo_model is not None:
        try:
            results = yolo_model.predict(preprocessed_image, verbose=False)
            if results and results[0].boxes:
                # Use the first detected box.
                coords = results[0].boxes[0].xyxy[0].cpu().numpy().astype(int)
                # Clamp to the frame: a negative index would otherwise slice
                # from the end of the array and select the wrong region.
                roi_coords = (
                    min(max(int(coords[0]), 0), img_w),
                    min(max(int(coords[1]), 0), img_h),
                    min(max(int(coords[2]), 0), img_w),
                    min(max(int(coords[3]), 0), img_h),
                )
        except Exception as e:
            print(f"⚠️ YOLO detection failed: {e}. Using full image as ROI.")

    x1, y1, x2, y2 = roi_coords
    original_roi = original_image[y1:y2, x1:x2]
    preprocessed_roi = preprocessed_image[y1:y2, x1:x2]
    if original_roi.size == 0:
        raise HTTPException(status_code=404, detail="Wound region of interest could not be determined.")

    # 3. Get mask from the preprocessed ROI
    mask = segment_wound(preprocessed_roi)
    # 4. Calculate metrics using the ORIGINAL ROI for color/intensity accuracy
    metrics = calculate_metrics(mask, original_roi)
    # 5. Draw overlay on the ORIGINAL ROI for correct visualization
    annotated_image = draw_overlay(original_roi, mask)

    # 6. Encode image and prepare response
    success, out_bytes = cv2.imencode(".png", annotated_image)
    if not success:
        raise HTTPException(status_code=500, detail="Failed to encode output image.")
    headers = {
        "X-Length-Cm": str(metrics["length_cm"]),
        "X-Breadth-Cm": str(metrics["breadth_cm"]),
        "X-Depth-Cm": str(metrics["depth_cm"]),
        "X-Area-Cm2": str(metrics["area_cm2"]),
        "X-Moisture": str(metrics["moisture"]),
    }
    return Response(content=out_bytes.tobytes(), media_type="image/png", headers=headers)