Spaces:
Sleeping
Sleeping
Update predict.py
Browse files- predict.py +23 -41
predict.py
CHANGED
|
@@ -4,25 +4,24 @@ import numpy as np
|
|
| 4 |
from ultralytics import YOLO
|
| 5 |
import tensorflow as tf
|
| 6 |
import io
|
| 7 |
-
from typing import Union
|
| 8 |
|
| 9 |
# --- Configuration ---
|
| 10 |
-
PIXELS_PER_CM = 50.0
|
| 11 |
|
| 12 |
# --- App Initialization ---
|
| 13 |
app = FastAPI(
|
| 14 |
title="Wound Analysis API",
|
| 15 |
description="An API to analyze wound images and return an annotated image with data in headers.",
|
| 16 |
-
version="3.
|
| 17 |
)
|
| 18 |
|
| 19 |
# --- Model Loading ---
|
| 20 |
def load_models():
|
| 21 |
-
"""Loads ML models safely, allowing the app to run even if a model is missing."""
|
| 22 |
segmentation_model, yolo_model = None, None
|
| 23 |
try:
|
| 24 |
-
segmentation_model = tf.keras.models.load_model("
|
| 25 |
-
print("Segmentation model '
|
| 26 |
except Exception as e:
|
| 27 |
print(f"Warning: Could not load segmentation model. Using fallback. Error: {e}")
|
| 28 |
|
|
@@ -39,26 +38,20 @@ segmentation_model, yolo_model = load_models()
|
|
| 39 |
# --- Helper Functions ---
|
| 40 |
|
| 41 |
def preprocess_image(image: np.ndarray) -> np.ndarray:
|
| 42 |
-
"""Applies denoising, CLAHE, and gamma correction to the image."""
|
| 43 |
img_denoised = cv2.medianBlur(image, 3)
|
| 44 |
-
|
| 45 |
lab = cv2.cvtColor(img_denoised, cv2.COLOR_BGR2LAB)
|
| 46 |
l_channel, a_channel, b_channel = cv2.split(lab)
|
| 47 |
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
|
| 48 |
l_clahe = clahe.apply(l_channel)
|
| 49 |
lab_clahe = cv2.merge([l_clahe, a_channel, b_channel])
|
| 50 |
img_clahe = cv2.cvtColor(lab_clahe, cv2.COLOR_LAB2BGR)
|
| 51 |
-
|
| 52 |
gamma = 1.2
|
| 53 |
img_float = img_clahe.astype(np.float32) / 255.0
|
| 54 |
img_gamma = np.power(img_float, gamma)
|
| 55 |
return (img_gamma * 255).astype(np.uint8)
|
| 56 |
|
| 57 |
-
# --- FIX APPLIED HERE ---
|
| 58 |
def detect_wound_region_yolo(image: np.ndarray) -> Union[tuple, None]:
|
| 59 |
-
|
| 60 |
-
if not yolo_model:
|
| 61 |
-
return None
|
| 62 |
try:
|
| 63 |
results = yolo_model.predict(image, verbose=False)
|
| 64 |
if results and results[0].boxes:
|
|
@@ -69,7 +62,6 @@ def detect_wound_region_yolo(image: np.ndarray) -> Union[tuple, None]:
|
|
| 69 |
print(f"YOLO prediction failed: {e}")
|
| 70 |
return None
|
| 71 |
|
| 72 |
-
# --- FIX APPLIED HERE ---
|
| 73 |
def segment_wound_with_model(image: np.ndarray) -> Union[np.ndarray, None]:
|
| 74 |
"""Segments the wound using the U-Net model."""
|
| 75 |
if not segmentation_model:
|
|
@@ -79,7 +71,19 @@ def segment_wound_with_model(image: np.ndarray) -> Union[np.ndarray, None]:
|
|
| 79 |
img_resized = cv2.resize(image, (input_shape[1], input_shape[0]))
|
| 80 |
img_norm = np.expand_dims(img_resized.astype(np.float32) / 255.0, axis=0)
|
| 81 |
|
| 82 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
pred_mask_resized = cv2.resize(pred_mask, (image.shape[1], image.shape[0]))
|
| 84 |
|
| 85 |
return (pred_mask_resized.squeeze() >= 0.5).astype(np.uint8) * 255
|
|
@@ -88,16 +92,12 @@ def segment_wound_with_model(image: np.ndarray) -> Union[np.ndarray, None]:
|
|
| 88 |
return None
|
| 89 |
|
| 90 |
def segment_wound_with_fallback(image: np.ndarray) -> np.ndarray:
|
| 91 |
-
"""Segments the wound using K-means color clustering as a fallback."""
|
| 92 |
pixels = image.reshape((-1, 3)).astype(np.float32)
|
| 93 |
criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
|
| 94 |
_, labels, centers = cv2.kmeans(pixels, 2, None, criteria, 5, cv2.KMEANS_PP_CENTERS)
|
| 95 |
-
|
| 96 |
centers_lab = cv2.cvtColor(centers.reshape(1, -1, 3).astype(np.uint8), cv2.COLOR_BGR2LAB)[0]
|
| 97 |
wound_cluster_idx = np.argmax(centers_lab[:, 1])
|
| 98 |
-
|
| 99 |
mask = (labels.reshape(image.shape[:2]) == wound_cluster_idx).astype(np.uint8) * 255
|
| 100 |
-
|
| 101 |
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
| 102 |
if contours:
|
| 103 |
largest_contour = max(contours, key=cv2.contourArea)
|
|
@@ -107,19 +107,13 @@ def segment_wound_with_fallback(image: np.ndarray) -> np.ndarray:
|
|
| 107 |
return mask
|
| 108 |
|
| 109 |
def calculate_metrics(mask: np.ndarray, image: np.ndarray) -> dict:
|
| 110 |
-
"""Calculates area, dimensions, depth, and moisture from the final mask and image."""
|
| 111 |
wound_pixels = cv2.countNonZero(mask)
|
| 112 |
if wound_pixels == 0:
|
| 113 |
-
return {
|
| 114 |
-
"area_cm2": 0.0, "length_cm": 0.0, "breadth_cm": 0.0,
|
| 115 |
-
"depth_score": 0.0, "moisture_score": 0.0
|
| 116 |
-
}
|
| 117 |
|
| 118 |
area_cm2 = wound_pixels / (PIXELS_PER_CM ** 2)
|
| 119 |
-
|
| 120 |
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
| 121 |
largest_contour = max(contours, key=cv2.contourArea)
|
| 122 |
-
|
| 123 |
(_, (width, height), _) = cv2.minAreaRect(largest_contour)
|
| 124 |
length_cm = max(width, height) / PIXELS_PER_CM
|
| 125 |
breadth_cm = min(width, height) / PIXELS_PER_CM
|
|
@@ -130,28 +124,19 @@ def calculate_metrics(mask: np.ndarray, image: np.ndarray) -> dict:
|
|
| 130 |
|
| 131 |
mean_a = np.mean(lab_img[:, :, 1][mask_bool])
|
| 132 |
depth_score = mean_a - 128.0
|
| 133 |
-
|
| 134 |
texture_std = np.std(gray_img[mask_bool])
|
| 135 |
moisture_score = max(0.0, 100.0 * (1.0 - texture_std / 127.0))
|
| 136 |
|
| 137 |
-
return {
|
| 138 |
-
"area_cm2": area_cm2, "length_cm": length_cm, "breadth_cm": breadth_cm,
|
| 139 |
-
"depth_score": depth_score, "moisture_score": moisture_score
|
| 140 |
-
}
|
| 141 |
|
| 142 |
def create_visual_overlay(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
|
| 143 |
-
|
| 144 |
-
if cv2.countNonZero(mask) == 0:
|
| 145 |
-
return image
|
| 146 |
-
|
| 147 |
dist = cv2.distanceTransform(mask, cv2.DIST_L2, 5)
|
| 148 |
cv2.normalize(dist, dist, 0, 1.0, cv2.NORM_MINMAX)
|
| 149 |
-
|
| 150 |
overlay = np.zeros_like(image)
|
| 151 |
overlay[dist >= 0.66] = (0, 0, 255)
|
| 152 |
overlay[(dist >= 0.33) & (dist < 0.66)] = (255, 0, 0)
|
| 153 |
overlay[(dist > 0) & (dist < 0.33)] = (0, 255, 0)
|
| 154 |
-
|
| 155 |
blended = cv2.addWeighted(image, 0.7, overlay, 0.3, 0)
|
| 156 |
annotated_img = image.copy()
|
| 157 |
annotated_img[mask.astype(bool)] = blended[mask.astype(bool)]
|
|
@@ -167,7 +152,6 @@ async def analyze_wound(file: UploadFile = File(...)):
|
|
| 167 |
raise HTTPException(status_code=400, detail="Invalid or corrupt image file")
|
| 168 |
|
| 169 |
processed_image = preprocess_image(original_image)
|
| 170 |
-
|
| 171 |
bbox = detect_wound_region_yolo(processed_image)
|
| 172 |
if bbox:
|
| 173 |
xmin, ymin, xmax, ymax = bbox
|
|
@@ -180,7 +164,6 @@ async def analyze_wound(file: UploadFile = File(...)):
|
|
| 180 |
mask = segment_wound_with_fallback(cropped_image)
|
| 181 |
|
| 182 |
metrics = calculate_metrics(mask, cropped_image)
|
| 183 |
-
|
| 184 |
full_mask = np.zeros(original_image.shape[:2], dtype=np.uint8)
|
| 185 |
if bbox:
|
| 186 |
full_mask[ymin:ymax, xmin:xmax] = mask
|
|
@@ -188,7 +171,6 @@ async def analyze_wound(file: UploadFile = File(...)):
|
|
| 188 |
full_mask = mask
|
| 189 |
|
| 190 |
annotated_image = create_visual_overlay(original_image, full_mask)
|
| 191 |
-
|
| 192 |
success, png_data = cv2.imencode(".png", annotated_image)
|
| 193 |
if not success:
|
| 194 |
raise HTTPException(status_code=500, detail="Failed to encode output image")
|
|
@@ -200,4 +182,4 @@ async def analyze_wound(file: UploadFile = File(...)):
|
|
| 200 |
"Wound-Depth-Score": f"{metrics['depth_score']:.1f}",
|
| 201 |
"Wound-Moisture-Score": f"{metrics['moisture_score']:.0f}"
|
| 202 |
}
|
| 203 |
-
return Response(content=png_data.tobytes(), media_type="image/png", headers=headers)
|
|
|
|
| 4 |
from ultralytics import YOLO
|
| 5 |
import tensorflow as tf
|
| 6 |
import io
|
| 7 |
+
from typing import Union
|
| 8 |
|
| 9 |
# --- Configuration ---
|
| 10 |
+
PIXELS_PER_CM = 50.0
|
| 11 |
|
| 12 |
# --- App Initialization ---
|
| 13 |
app = FastAPI(
|
| 14 |
title="Wound Analysis API",
|
| 15 |
description="An API to analyze wound images and return an annotated image with data in headers.",
|
| 16 |
+
version="3.4.0" # Version updated for prediction output fix
|
| 17 |
)
|
| 18 |
|
| 19 |
# --- Model Loading ---
|
| 20 |
def load_models():
|
|
|
|
| 21 |
segmentation_model, yolo_model = None, None
|
| 22 |
try:
|
| 23 |
+
segmentation_model = tf.keras.models.load_model("segmentation model.h5")
|
| 24 |
+
print("Segmentation model 'segmentation model.h5' loaded successfully.")
|
| 25 |
except Exception as e:
|
| 26 |
print(f"Warning: Could not load segmentation model. Using fallback. Error: {e}")
|
| 27 |
|
|
|
|
| 38 |
# --- Helper Functions ---
|
| 39 |
|
| 40 |
def preprocess_image(image: np.ndarray) -> np.ndarray:
|
|
|
|
| 41 |
img_denoised = cv2.medianBlur(image, 3)
|
|
|
|
| 42 |
lab = cv2.cvtColor(img_denoised, cv2.COLOR_BGR2LAB)
|
| 43 |
l_channel, a_channel, b_channel = cv2.split(lab)
|
| 44 |
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
|
| 45 |
l_clahe = clahe.apply(l_channel)
|
| 46 |
lab_clahe = cv2.merge([l_clahe, a_channel, b_channel])
|
| 47 |
img_clahe = cv2.cvtColor(lab_clahe, cv2.COLOR_LAB2BGR)
|
|
|
|
| 48 |
gamma = 1.2
|
| 49 |
img_float = img_clahe.astype(np.float32) / 255.0
|
| 50 |
img_gamma = np.power(img_float, gamma)
|
| 51 |
return (img_gamma * 255).astype(np.uint8)
|
| 52 |
|
|
|
|
| 53 |
def detect_wound_region_yolo(image: np.ndarray) -> Union[tuple, None]:
|
| 54 |
+
if not yolo_model: return None
|
|
|
|
|
|
|
| 55 |
try:
|
| 56 |
results = yolo_model.predict(image, verbose=False)
|
| 57 |
if results and results[0].boxes:
|
|
|
|
| 62 |
print(f"YOLO prediction failed: {e}")
|
| 63 |
return None
|
| 64 |
|
|
|
|
| 65 |
def segment_wound_with_model(image: np.ndarray) -> Union[np.ndarray, None]:
|
| 66 |
"""Segments the wound using the U-Net model."""
|
| 67 |
if not segmentation_model:
|
|
|
|
| 71 |
img_resized = cv2.resize(image, (input_shape[1], input_shape[0]))
|
| 72 |
img_norm = np.expand_dims(img_resized.astype(np.float32) / 255.0, axis=0)
|
| 73 |
|
| 74 |
+
prediction = segmentation_model.predict(img_norm, verbose=0)
|
| 75 |
+
|
| 76 |
+
# --- FIX APPLIED HERE ---
|
| 77 |
+
# Handle cases where the model returns a list of outputs
|
| 78 |
+
if isinstance(prediction, list):
|
| 79 |
+
pred_mask = prediction[0]
|
| 80 |
+
else:
|
| 81 |
+
pred_mask = prediction
|
| 82 |
+
|
| 83 |
+
# The output of predict() is batched, so get the first item.
|
| 84 |
+
pred_mask = pred_mask[0]
|
| 85 |
+
# --- END OF FIX ---
|
| 86 |
+
|
| 87 |
pred_mask_resized = cv2.resize(pred_mask, (image.shape[1], image.shape[0]))
|
| 88 |
|
| 89 |
return (pred_mask_resized.squeeze() >= 0.5).astype(np.uint8) * 255
|
|
|
|
| 92 |
return None
|
| 93 |
|
| 94 |
def segment_wound_with_fallback(image: np.ndarray) -> np.ndarray:
|
|
|
|
| 95 |
pixels = image.reshape((-1, 3)).astype(np.float32)
|
| 96 |
criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
|
| 97 |
_, labels, centers = cv2.kmeans(pixels, 2, None, criteria, 5, cv2.KMEANS_PP_CENTERS)
|
|
|
|
| 98 |
centers_lab = cv2.cvtColor(centers.reshape(1, -1, 3).astype(np.uint8), cv2.COLOR_BGR2LAB)[0]
|
| 99 |
wound_cluster_idx = np.argmax(centers_lab[:, 1])
|
|
|
|
| 100 |
mask = (labels.reshape(image.shape[:2]) == wound_cluster_idx).astype(np.uint8) * 255
|
|
|
|
| 101 |
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
| 102 |
if contours:
|
| 103 |
largest_contour = max(contours, key=cv2.contourArea)
|
|
|
|
| 107 |
return mask
|
| 108 |
|
| 109 |
def calculate_metrics(mask: np.ndarray, image: np.ndarray) -> dict:
|
|
|
|
| 110 |
wound_pixels = cv2.countNonZero(mask)
|
| 111 |
if wound_pixels == 0:
|
| 112 |
+
return {"area_cm2": 0.0, "length_cm": 0.0, "breadth_cm": 0.0, "depth_score": 0.0, "moisture_score": 0.0}
|
|
|
|
|
|
|
|
|
|
| 113 |
|
| 114 |
area_cm2 = wound_pixels / (PIXELS_PER_CM ** 2)
|
|
|
|
| 115 |
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
| 116 |
largest_contour = max(contours, key=cv2.contourArea)
|
|
|
|
| 117 |
(_, (width, height), _) = cv2.minAreaRect(largest_contour)
|
| 118 |
length_cm = max(width, height) / PIXELS_PER_CM
|
| 119 |
breadth_cm = min(width, height) / PIXELS_PER_CM
|
|
|
|
| 124 |
|
| 125 |
mean_a = np.mean(lab_img[:, :, 1][mask_bool])
|
| 126 |
depth_score = mean_a - 128.0
|
|
|
|
| 127 |
texture_std = np.std(gray_img[mask_bool])
|
| 128 |
moisture_score = max(0.0, 100.0 * (1.0 - texture_std / 127.0))
|
| 129 |
|
| 130 |
+
return {"area_cm2": area_cm2, "length_cm": length_cm, "breadth_cm": breadth_cm, "depth_score": depth_score, "moisture_score": moisture_score}
|
|
|
|
|
|
|
|
|
|
| 131 |
|
| 132 |
def create_visual_overlay(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
|
| 133 |
+
if cv2.countNonZero(mask) == 0: return image
|
|
|
|
|
|
|
|
|
|
| 134 |
dist = cv2.distanceTransform(mask, cv2.DIST_L2, 5)
|
| 135 |
cv2.normalize(dist, dist, 0, 1.0, cv2.NORM_MINMAX)
|
|
|
|
| 136 |
overlay = np.zeros_like(image)
|
| 137 |
overlay[dist >= 0.66] = (0, 0, 255)
|
| 138 |
overlay[(dist >= 0.33) & (dist < 0.66)] = (255, 0, 0)
|
| 139 |
overlay[(dist > 0) & (dist < 0.33)] = (0, 255, 0)
|
|
|
|
| 140 |
blended = cv2.addWeighted(image, 0.7, overlay, 0.3, 0)
|
| 141 |
annotated_img = image.copy()
|
| 142 |
annotated_img[mask.astype(bool)] = blended[mask.astype(bool)]
|
|
|
|
| 152 |
raise HTTPException(status_code=400, detail="Invalid or corrupt image file")
|
| 153 |
|
| 154 |
processed_image = preprocess_image(original_image)
|
|
|
|
| 155 |
bbox = detect_wound_region_yolo(processed_image)
|
| 156 |
if bbox:
|
| 157 |
xmin, ymin, xmax, ymax = bbox
|
|
|
|
| 164 |
mask = segment_wound_with_fallback(cropped_image)
|
| 165 |
|
| 166 |
metrics = calculate_metrics(mask, cropped_image)
|
|
|
|
| 167 |
full_mask = np.zeros(original_image.shape[:2], dtype=np.uint8)
|
| 168 |
if bbox:
|
| 169 |
full_mask[ymin:ymax, xmin:xmax] = mask
|
|
|
|
| 171 |
full_mask = mask
|
| 172 |
|
| 173 |
annotated_image = create_visual_overlay(original_image, full_mask)
|
|
|
|
| 174 |
success, png_data = cv2.imencode(".png", annotated_image)
|
| 175 |
if not success:
|
| 176 |
raise HTTPException(status_code=500, detail="Failed to encode output image")
|
|
|
|
| 182 |
"Wound-Depth-Score": f"{metrics['depth_score']:.1f}",
|
| 183 |
"Wound-Moisture-Score": f"{metrics['moisture_score']:.0f}"
|
| 184 |
}
|
| 185 |
+
return Response(content=png_data.tobytes(), media_type="image/png", headers=headers)
|