Ani14 committed on
Commit
2619f31
·
verified ·
1 Parent(s): 6c507c1

Update predict.py

Browse files
Files changed (1) hide show
  1. predict.py +23 -41
predict.py CHANGED
@@ -4,25 +4,24 @@ import numpy as np
4
  from ultralytics import YOLO
5
  import tensorflow as tf
6
  import io
7
- from typing import Union # <--- IMPORT THIS
8
 
9
  # --- Configuration ---
10
- PIXELS_PER_CM = 50.0 # Calibrate this value: pixels per centimeter
11
 
12
  # --- App Initialization ---
13
  app = FastAPI(
14
  title="Wound Analysis API",
15
  description="An API to analyze wound images and return an annotated image with data in headers.",
16
- version="3.3.0" # Version updated for Python 3.9 compatibility
17
  )
18
 
19
  # --- Model Loading ---
20
  def load_models():
21
- """Loads ML models safely, allowing the app to run even if a model is missing."""
22
  segmentation_model, yolo_model = None, None
23
  try:
24
- segmentation_model = tf.keras.models.load_model("segmentation_model.h5")
25
- print("Segmentation model 'segmentation_model.h5' loaded successfully.")
26
  except Exception as e:
27
  print(f"Warning: Could not load segmentation model. Using fallback. Error: {e}")
28
 
@@ -39,26 +38,20 @@ segmentation_model, yolo_model = load_models()
39
  # --- Helper Functions ---
40
 
41
  def preprocess_image(image: np.ndarray) -> np.ndarray:
42
- """Applies denoising, CLAHE, and gamma correction to the image."""
43
  img_denoised = cv2.medianBlur(image, 3)
44
-
45
  lab = cv2.cvtColor(img_denoised, cv2.COLOR_BGR2LAB)
46
  l_channel, a_channel, b_channel = cv2.split(lab)
47
  clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
48
  l_clahe = clahe.apply(l_channel)
49
  lab_clahe = cv2.merge([l_clahe, a_channel, b_channel])
50
  img_clahe = cv2.cvtColor(lab_clahe, cv2.COLOR_LAB2BGR)
51
-
52
  gamma = 1.2
53
  img_float = img_clahe.astype(np.float32) / 255.0
54
  img_gamma = np.power(img_float, gamma)
55
  return (img_gamma * 255).astype(np.uint8)
56
 
57
- # --- FIX APPLIED HERE ---
58
  def detect_wound_region_yolo(image: np.ndarray) -> Union[tuple, None]:
59
- """Detects the primary wound region using YOLO and returns its bounding box."""
60
- if not yolo_model:
61
- return None
62
  try:
63
  results = yolo_model.predict(image, verbose=False)
64
  if results and results[0].boxes:
@@ -69,7 +62,6 @@ def detect_wound_region_yolo(image: np.ndarray) -> Union[tuple, None]:
69
  print(f"YOLO prediction failed: {e}")
70
  return None
71
 
72
- # --- FIX APPLIED HERE ---
73
  def segment_wound_with_model(image: np.ndarray) -> Union[np.ndarray, None]:
74
  """Segments the wound using the U-Net model."""
75
  if not segmentation_model:
@@ -79,7 +71,19 @@ def segment_wound_with_model(image: np.ndarray) -> Union[np.ndarray, None]:
79
  img_resized = cv2.resize(image, (input_shape[1], input_shape[0]))
80
  img_norm = np.expand_dims(img_resized.astype(np.float32) / 255.0, axis=0)
81
 
82
- pred_mask = segmentation_model.predict(img_norm, verbose=0)[0]
 
 
 
 
 
 
 
 
 
 
 
 
83
  pred_mask_resized = cv2.resize(pred_mask, (image.shape[1], image.shape[0]))
84
 
85
  return (pred_mask_resized.squeeze() >= 0.5).astype(np.uint8) * 255
@@ -88,16 +92,12 @@ def segment_wound_with_model(image: np.ndarray) -> Union[np.ndarray, None]:
88
  return None
89
 
90
  def segment_wound_with_fallback(image: np.ndarray) -> np.ndarray:
91
- """Segments the wound using K-means color clustering as a fallback."""
92
  pixels = image.reshape((-1, 3)).astype(np.float32)
93
  criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
94
  _, labels, centers = cv2.kmeans(pixels, 2, None, criteria, 5, cv2.KMEANS_PP_CENTERS)
95
-
96
  centers_lab = cv2.cvtColor(centers.reshape(1, -1, 3).astype(np.uint8), cv2.COLOR_BGR2LAB)[0]
97
  wound_cluster_idx = np.argmax(centers_lab[:, 1])
98
-
99
  mask = (labels.reshape(image.shape[:2]) == wound_cluster_idx).astype(np.uint8) * 255
100
-
101
  contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
102
  if contours:
103
  largest_contour = max(contours, key=cv2.contourArea)
@@ -107,19 +107,13 @@ def segment_wound_with_fallback(image: np.ndarray) -> np.ndarray:
107
  return mask
108
 
109
  def calculate_metrics(mask: np.ndarray, image: np.ndarray) -> dict:
110
- """Calculates area, dimensions, depth, and moisture from the final mask and image."""
111
  wound_pixels = cv2.countNonZero(mask)
112
  if wound_pixels == 0:
113
- return {
114
- "area_cm2": 0.0, "length_cm": 0.0, "breadth_cm": 0.0,
115
- "depth_score": 0.0, "moisture_score": 0.0
116
- }
117
 
118
  area_cm2 = wound_pixels / (PIXELS_PER_CM ** 2)
119
-
120
  contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
121
  largest_contour = max(contours, key=cv2.contourArea)
122
-
123
  (_, (width, height), _) = cv2.minAreaRect(largest_contour)
124
  length_cm = max(width, height) / PIXELS_PER_CM
125
  breadth_cm = min(width, height) / PIXELS_PER_CM
@@ -130,28 +124,19 @@ def calculate_metrics(mask: np.ndarray, image: np.ndarray) -> dict:
130
 
131
  mean_a = np.mean(lab_img[:, :, 1][mask_bool])
132
  depth_score = mean_a - 128.0
133
-
134
  texture_std = np.std(gray_img[mask_bool])
135
  moisture_score = max(0.0, 100.0 * (1.0 - texture_std / 127.0))
136
 
137
- return {
138
- "area_cm2": area_cm2, "length_cm": length_cm, "breadth_cm": breadth_cm,
139
- "depth_score": depth_score, "moisture_score": moisture_score
140
- }
141
 
142
  def create_visual_overlay(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
143
- """Generates a colored overlay on the image based on the wound mask."""
144
- if cv2.countNonZero(mask) == 0:
145
- return image
146
-
147
  dist = cv2.distanceTransform(mask, cv2.DIST_L2, 5)
148
  cv2.normalize(dist, dist, 0, 1.0, cv2.NORM_MINMAX)
149
-
150
  overlay = np.zeros_like(image)
151
  overlay[dist >= 0.66] = (0, 0, 255)
152
  overlay[(dist >= 0.33) & (dist < 0.66)] = (255, 0, 0)
153
  overlay[(dist > 0) & (dist < 0.33)] = (0, 255, 0)
154
-
155
  blended = cv2.addWeighted(image, 0.7, overlay, 0.3, 0)
156
  annotated_img = image.copy()
157
  annotated_img[mask.astype(bool)] = blended[mask.astype(bool)]
@@ -167,7 +152,6 @@ async def analyze_wound(file: UploadFile = File(...)):
167
  raise HTTPException(status_code=400, detail="Invalid or corrupt image file")
168
 
169
  processed_image = preprocess_image(original_image)
170
-
171
  bbox = detect_wound_region_yolo(processed_image)
172
  if bbox:
173
  xmin, ymin, xmax, ymax = bbox
@@ -180,7 +164,6 @@ async def analyze_wound(file: UploadFile = File(...)):
180
  mask = segment_wound_with_fallback(cropped_image)
181
 
182
  metrics = calculate_metrics(mask, cropped_image)
183
-
184
  full_mask = np.zeros(original_image.shape[:2], dtype=np.uint8)
185
  if bbox:
186
  full_mask[ymin:ymax, xmin:xmax] = mask
@@ -188,7 +171,6 @@ async def analyze_wound(file: UploadFile = File(...)):
188
  full_mask = mask
189
 
190
  annotated_image = create_visual_overlay(original_image, full_mask)
191
-
192
  success, png_data = cv2.imencode(".png", annotated_image)
193
  if not success:
194
  raise HTTPException(status_code=500, detail="Failed to encode output image")
@@ -200,4 +182,4 @@ async def analyze_wound(file: UploadFile = File(...)):
200
  "Wound-Depth-Score": f"{metrics['depth_score']:.1f}",
201
  "Wound-Moisture-Score": f"{metrics['moisture_score']:.0f}"
202
  }
203
- return Response(content=png_data.tobytes(), media_type="image/png", headers=headers)
 
4
  from ultralytics import YOLO
5
  import tensorflow as tf
6
  import io
7
+ from typing import Union
8
 
9
  # --- Configuration ---
10
+ PIXELS_PER_CM = 50.0
11
 
12
  # --- App Initialization ---
13
  app = FastAPI(
14
  title="Wound Analysis API",
15
  description="An API to analyze wound images and return an annotated image with data in headers.",
16
+ version="3.4.0" # Version updated for prediction output fix
17
  )
18
 
19
  # --- Model Loading ---
20
  def load_models():
 
21
  segmentation_model, yolo_model = None, None
22
  try:
23
+ segmentation_model = tf.keras.models.load_model("segmentation model.h5")
24
+ print("Segmentation model 'segmentation model.h5' loaded successfully.")
25
  except Exception as e:
26
  print(f"Warning: Could not load segmentation model. Using fallback. Error: {e}")
27
 
 
38
  # --- Helper Functions ---
39
 
40
def preprocess_image(image: np.ndarray) -> np.ndarray:
    """Denoise, boost local contrast (CLAHE on the L channel), and gamma-correct a BGR image.

    Returns a uint8 BGR image with the same shape as the input.
    """
    # 3x3 median filter suppresses salt-and-pepper noise while preserving edges.
    denoised = cv2.medianBlur(image, 3)

    # Equalize only the lightness channel so the color balance is untouched.
    lightness, chroma_a, chroma_b = cv2.split(cv2.cvtColor(denoised, cv2.COLOR_BGR2LAB))
    equalized = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(lightness)
    contrast_boosted = cv2.cvtColor(
        cv2.merge([equalized, chroma_a, chroma_b]), cv2.COLOR_LAB2BGR
    )

    # Gamma > 1 applies a power curve on normalized intensities (slightly darkens mid-tones).
    normalized = contrast_boosted.astype(np.float32) / 255.0
    corrected = np.power(normalized, 1.2)
    return (corrected * 255).astype(np.uint8)
52
 
 
53
  def detect_wound_region_yolo(image: np.ndarray) -> Union[tuple, None]:
54
+ if not yolo_model: return None
 
 
55
  try:
56
  results = yolo_model.predict(image, verbose=False)
57
  if results and results[0].boxes:
 
62
  print(f"YOLO prediction failed: {e}")
63
  return None
64
 
 
65
  def segment_wound_with_model(image: np.ndarray) -> Union[np.ndarray, None]:
66
  """Segments the wound using the U-Net model."""
67
  if not segmentation_model:
 
71
  img_resized = cv2.resize(image, (input_shape[1], input_shape[0]))
72
  img_norm = np.expand_dims(img_resized.astype(np.float32) / 255.0, axis=0)
73
 
74
+ prediction = segmentation_model.predict(img_norm, verbose=0)
75
+
76
+ # --- FIX APPLIED HERE ---
77
+ # Handle cases where the model returns a list of outputs
78
+ if isinstance(prediction, list):
79
+ pred_mask = prediction[0]
80
+ else:
81
+ pred_mask = prediction
82
+
83
+ # The output of predict() is batched, so get the first item.
84
+ pred_mask = pred_mask[0]
85
+ # --- END OF FIX ---
86
+
87
  pred_mask_resized = cv2.resize(pred_mask, (image.shape[1], image.shape[0]))
88
 
89
  return (pred_mask_resized.squeeze() >= 0.5).astype(np.uint8) * 255
 
92
  return None
93
 
94
  def segment_wound_with_fallback(image: np.ndarray) -> np.ndarray:
 
95
  pixels = image.reshape((-1, 3)).astype(np.float32)
96
  criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
97
  _, labels, centers = cv2.kmeans(pixels, 2, None, criteria, 5, cv2.KMEANS_PP_CENTERS)
 
98
  centers_lab = cv2.cvtColor(centers.reshape(1, -1, 3).astype(np.uint8), cv2.COLOR_BGR2LAB)[0]
99
  wound_cluster_idx = np.argmax(centers_lab[:, 1])
 
100
  mask = (labels.reshape(image.shape[:2]) == wound_cluster_idx).astype(np.uint8) * 255
 
101
  contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
102
  if contours:
103
  largest_contour = max(contours, key=cv2.contourArea)
 
107
  return mask
108
 
109
  def calculate_metrics(mask: np.ndarray, image: np.ndarray) -> dict:
 
110
  wound_pixels = cv2.countNonZero(mask)
111
  if wound_pixels == 0:
112
+ return {"area_cm2": 0.0, "length_cm": 0.0, "breadth_cm": 0.0, "depth_score": 0.0, "moisture_score": 0.0}
 
 
 
113
 
114
  area_cm2 = wound_pixels / (PIXELS_PER_CM ** 2)
 
115
  contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
116
  largest_contour = max(contours, key=cv2.contourArea)
 
117
  (_, (width, height), _) = cv2.minAreaRect(largest_contour)
118
  length_cm = max(width, height) / PIXELS_PER_CM
119
  breadth_cm = min(width, height) / PIXELS_PER_CM
 
124
 
125
  mean_a = np.mean(lab_img[:, :, 1][mask_bool])
126
  depth_score = mean_a - 128.0
 
127
  texture_std = np.std(gray_img[mask_bool])
128
  moisture_score = max(0.0, 100.0 * (1.0 - texture_std / 127.0))
129
 
130
+ return {"area_cm2": area_cm2, "length_cm": length_cm, "breadth_cm": breadth_cm, "depth_score": depth_score, "moisture_score": moisture_score}
 
 
 
131
 
132
  def create_visual_overlay(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
133
+ if cv2.countNonZero(mask) == 0: return image
 
 
 
134
  dist = cv2.distanceTransform(mask, cv2.DIST_L2, 5)
135
  cv2.normalize(dist, dist, 0, 1.0, cv2.NORM_MINMAX)
 
136
  overlay = np.zeros_like(image)
137
  overlay[dist >= 0.66] = (0, 0, 255)
138
  overlay[(dist >= 0.33) & (dist < 0.66)] = (255, 0, 0)
139
  overlay[(dist > 0) & (dist < 0.33)] = (0, 255, 0)
 
140
  blended = cv2.addWeighted(image, 0.7, overlay, 0.3, 0)
141
  annotated_img = image.copy()
142
  annotated_img[mask.astype(bool)] = blended[mask.astype(bool)]
 
152
  raise HTTPException(status_code=400, detail="Invalid or corrupt image file")
153
 
154
  processed_image = preprocess_image(original_image)
 
155
  bbox = detect_wound_region_yolo(processed_image)
156
  if bbox:
157
  xmin, ymin, xmax, ymax = bbox
 
164
  mask = segment_wound_with_fallback(cropped_image)
165
 
166
  metrics = calculate_metrics(mask, cropped_image)
 
167
  full_mask = np.zeros(original_image.shape[:2], dtype=np.uint8)
168
  if bbox:
169
  full_mask[ymin:ymax, xmin:xmax] = mask
 
171
  full_mask = mask
172
 
173
  annotated_image = create_visual_overlay(original_image, full_mask)
 
174
  success, png_data = cv2.imencode(".png", annotated_image)
175
  if not success:
176
  raise HTTPException(status_code=500, detail="Failed to encode output image")
 
182
  "Wound-Depth-Score": f"{metrics['depth_score']:.1f}",
183
  "Wound-Moisture-Score": f"{metrics['moisture_score']:.0f}"
184
  }
185
+ return Response(content=png_data.tobytes(), media_type="image/png", headers=headers)