jayn95 commited on
Commit
8f7e84e
·
verified ·
1 Parent(s): cda72d3

Update periodontitis_detection.py

Browse files
Files changed (1) hide show
  1. periodontitis_detection.py +163 -530
periodontitis_detection.py CHANGED
@@ -5,85 +5,75 @@ import matplotlib.pyplot as plt
5
  import tensorflow as tf
6
  from ultralytics import YOLO
7
 
 
8
  class SimpleDentalSegmentationNoEnhance:
9
- def __init__(self, unet_model_path, yolo_model_path, unet_input_size=(224,224,3)):
10
- # Load TFLite U-Net
11
- self.interpreter = tf.lite.Interpreter(model_path=unet_model_path)
12
- self.interpreter.allocate_tensors()
13
- self.input_details = self.interpreter.get_input_details()
14
- self.output_details = self.interpreter.get_output_details()
15
-
16
- # Force/prefer the desired U-Net input size
17
  self.in_h, self.in_w, self.in_c = unet_input_size
18
 
19
- # Load YOLOv8
20
  self.yolo = YOLO(yolo_model_path)
21
 
22
- print("Models loaded successfully!")
23
- print(f"Using forced U-Net input shape: (1, {self.in_h}, {self.in_w}, {self.in_c})")
24
- print(f"U-Net output shape (raw): {self.output_details[0]['shape']}")
 
25
 
26
  def preprocess_for_unet(self, image_bgr):
27
- img = image_bgr.copy()
28
- proc_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
29
- proc_resized = cv2.resize(proc_rgb, (self.in_w, self.in_h), interpolation=cv2.INTER_LINEAR)
30
- normalized = proc_resized.astype(np.float32) / 255.0
31
- input_tensor = np.expand_dims(normalized, axis=0).astype(np.float32)
32
- return input_tensor, proc_resized
 
 
 
33
 
34
  def run_unet(self, image_bgr):
 
 
 
 
35
  input_tensor, model_resized_image = self.preprocess_for_unet(image_bgr)
36
-
37
- try:
38
- self.interpreter.set_tensor(self.input_details[0]['index'], input_tensor)
39
- self.interpreter.invoke()
40
- output = self.interpreter.get_tensor(self.output_details[0]['index'])
41
- except Exception as e:
42
- print("Interpreter set_tensor failed, attempting to resize input to forced shape:", e)
43
- try:
44
- self.interpreter.resize_tensor_input(self.input_details[0]['index'], [1, self.in_h, self.in_w, self.in_c])
45
- self.interpreter.allocate_tensors()
46
- self.interpreter.set_tensor(self.input_details[0]['index'], input_tensor)
47
- self.interpreter.invoke()
48
- output = self.interpreter.get_tensor(self.output_details[0]['index'])
49
- except Exception as e2:
50
- raise RuntimeError("Failed to run TFLite interpreter") from e2
51
-
52
- out = output[0]
53
 
54
  if out.ndim == 3 and out.shape[2] >= 2:
55
  class_map = np.argmax(out, axis=2).astype(np.uint8)
56
  abc = (class_map == 1).astype(np.uint8)
57
  cej = (class_map == 2).astype(np.uint8)
58
- elif out.ndim == 2:
59
- combined = out
60
- abc = (combined > 0.5).astype(np.uint8)
61
- cej = (combined > 0.8).astype(np.uint8)
62
  else:
63
- h_unet = out.shape[0]
64
- w_unet = out.shape[1] if out.ndim >= 2 else (self.in_w)
65
- abc = np.zeros((h_unet, w_unet), dtype=np.uint8)
66
- cej = np.zeros((h_unet, w_unet), dtype=np.uint8)
67
 
68
  return cej, abc, model_resized_image
69
 
70
  def detect_teeth(self, image_bgr):
 
 
 
 
71
  image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
72
  results = self.yolo(image_rgb)
73
  detections = []
 
74
  for r in results:
75
  boxes = getattr(r, "boxes", None)
76
  if boxes is None:
77
  continue
78
- for i, box in enumerate(boxes):
79
- try:
80
- xyxy = box.xyxy[0].cpu().numpy()
81
- except Exception:
82
- xyxy = np.array(box.xyxy).astype(np.float32).reshape(-1)[:4]
83
- try:
84
- conf = float(box.conf[0].cpu().numpy())
85
- except Exception:
86
- conf = float(box.conf if hasattr(box, "conf") else 0.0)
87
  detections.append({
88
  "bbox": xyxy.astype(np.float32),
89
  "confidence": conf,
@@ -92,555 +82,198 @@ class SimpleDentalSegmentationNoEnhance:
92
  return detections
93
 
94
  def resize_mask_to_original(self, mask, original_shape):
 
95
  h_orig, w_orig = original_shape
96
- mask_resized = cv2.resize((mask * 255).astype(np.uint8), (w_orig, h_orig), interpolation=cv2.INTER_NEAREST)
97
- mask_resized = (mask_resized.astype(np.float32) / 255.0)
98
- return mask_resized
99
 
100
  def extract_abc_uppermost_line_within_bbox(self, abc_mask, bbox):
101
- x1, y1, x2, y2 = bbox
102
- x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
103
  height, width = abc_mask.shape
104
- x1 = max(0, x1)
105
- y1 = max(0, y1)
106
- x2 = min(width-1, x2)
107
- y2 = min(height-1, y2)
108
-
109
  abc_points = []
110
- for x in range(x1, x2+1):
111
- column_abc = np.where(abc_mask[y1:y2+1, x] == 255)[0]
112
- if len(column_abc) > 0:
113
- y_min_relative = np.min(column_abc)
114
- y_absolute = y1 + y_min_relative
115
  abc_points.append([x, y_absolute])
116
-
117
  if len(abc_points) < 2:
118
  return None
119
  return np.array(abc_points, dtype=np.int32).reshape(-1, 1, 2)
120
 
121
  def extract_cej_lowermost_line_within_bbox(self, cej_mask, bbox):
122
- x1, y1, x2, y2 = bbox
123
- x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
124
  height, width = cej_mask.shape
125
- x1 = max(0, x1)
126
- y1 = max(0, y1)
127
- x2 = min(width-1, x2)
128
- y2 = min(height-1, y2)
129
-
130
  cej_points = []
131
- for x in range(x1, x2+1):
132
- column_cej = np.where(cej_mask[y1:y2+1, x] == 255)[0]
133
- if len(column_cej) > 0:
134
- y_max_relative = np.max(column_cej)
135
- y_absolute = y1 + y_max_relative
136
  cej_points.append([x, y_absolute])
137
-
138
  if len(cej_points) < 2:
139
  return None
140
  return np.array(cej_points, dtype=np.int32).reshape(-1, 1, 2)
141
 
142
  def smooth_landmarks(self, points, window_size=5):
 
143
  if points is None or len(points) < window_size:
144
  return points
145
- points_2d = points.reshape(-1, 2)
146
- smoothed_points = []
147
- for i in range(len(points_2d)):
148
- start_idx = max(0, i - window_size // 2)
149
- end_idx = min(len(points_2d), i + window_size // 2 + 1)
150
- window_points = points_2d[start_idx:end_idx]
151
- smoothed_y = np.mean(window_points[:, 1])
152
- smoothed_points.append([points_2d[i][0], smoothed_y])
153
- return np.array(smoothed_points, dtype=np.int32).reshape(-1, 1, 2)
154
 
155
  def compute_cej_abc_distances(self, cej_points, abc_points):
156
- """
157
- Compute distances between CEJ and ABC points.
158
- For each x-coordinate, find the vertical distance between CEJ and ABC.
159
- """
160
  if cej_points is None or abc_points is None:
161
  return None
162
-
163
- cej_2d = cej_points.reshape(-1, 2)
164
- abc_2d = abc_points.reshape(-1, 2)
165
-
166
- # Create dictionaries for quick lookup
167
- cej_dict = {point[0]: point[1] for point in cej_2d}
168
- abc_dict = {point[0]: point[1] for point in abc_2d}
169
-
170
- # Find common x-coordinates
171
  common_x = set(cej_dict.keys()) & set(abc_dict.keys())
172
-
173
  if not common_x:
174
- # If no exact matches, use interpolation
175
  return self.compute_distances_with_interpolation(cej_2d, abc_2d)
176
-
177
- distances = []
178
- connection_points = []
179
-
180
  for x in sorted(common_x):
181
- cej_y = cej_dict[x]
182
- abc_y = abc_dict[x]
183
- distance = abs(abc_y - cej_y) # Vertical distance
184
-
185
- distances.append({
186
- 'x': x,
187
- 'cej_y': cej_y,
188
- 'abc_y': abc_y,
189
- 'distance': distance
190
- })
191
-
192
- connection_points.append([(x, cej_y), (x, abc_y)])
193
-
194
  return {
195
  'distances': distances,
196
- 'connection_points': connection_points,
197
  'mean_distance': np.mean([d['distance'] for d in distances]),
198
- 'max_distance': max([d['distance'] for d in distances]),
199
- 'min_distance': min([d['distance'] for d in distances])
200
  }
201
 
202
  def compute_distances_with_interpolation(self, cej_points, abc_points):
203
- """
204
- Compute distances using interpolation when points don't have exact x-matches.
205
- """
206
- # Get x-range that's common to both curves
207
  cej_x_min, cej_x_max = np.min(cej_points[:, 0]), np.max(cej_points[:, 0])
208
  abc_x_min, abc_x_max = np.min(abc_points[:, 0]), np.max(abc_points[:, 0])
209
-
210
- x_min = max(cej_x_min, abc_x_min)
211
- x_max = min(cej_x_max, abc_x_max)
212
-
213
  if x_min >= x_max:
214
  return None
215
-
216
- # Sample points at regular intervals
217
- num_samples = min(50, x_max - x_min + 1)
218
- x_sample = np.linspace(x_min, x_max, num_samples, dtype=int)
219
-
220
- # Interpolate y-values for both curves
221
- cej_y_interp = np.interp(x_sample, cej_points[:, 0], cej_points[:, 1])
222
- abc_y_interp = np.interp(x_sample, abc_points[:, 0], abc_points[:, 1])
223
-
224
- distances = []
225
- connection_points = []
226
-
227
- for i, x in enumerate(x_sample):
228
- cej_y = cej_y_interp[i]
229
- abc_y = abc_y_interp[i]
230
- distance = abs(abc_y - cej_y)
231
-
232
- distances.append({
233
- 'x': int(x),
234
- 'cej_y': int(cej_y),
235
- 'abc_y': int(abc_y),
236
- 'distance': distance
237
- })
238
-
239
- connection_points.append([(int(x), int(cej_y)), (int(x), int(abc_y))])
240
-
241
  return {
242
  'distances': distances,
243
- 'connection_points': connection_points,
244
  'mean_distance': np.mean([d['distance'] for d in distances]),
245
- 'max_distance': max([d['distance'] for d in distances]),
246
- 'min_distance': min([d['distance'] for d in distances])
247
  }
248
 
249
  def draw_distance_measurements(self, image, distance_analysis, tooth_id):
250
- """
251
- Draw clean distance measurements on the image without text overlays.
252
- """
253
  if distance_analysis is None:
254
  return image
255
-
256
- img_with_distances = image.copy()
257
-
258
- # Draw connection lines with gradient effect
259
- connection_points = distance_analysis['connection_points']
260
  distances = [d['distance'] for d in distance_analysis['distances']]
261
-
262
  if not distances:
263
- return img_with_distances
264
-
265
- # Normalize distances for color mapping
266
- min_dist = min(distances)
267
- max_dist = max(distances)
268
- dist_range = max_dist - min_dist if max_dist != min_dist else 1
269
-
270
- # Draw every 3rd line to reduce clutter, with color coding
271
- for i in range(0, len(connection_points), 3):
272
- start_point, end_point = connection_points[i]
273
- distance = distances[i]
274
-
275
- # Color based on distance (green = small, red = large)
276
- normalized_dist = (distance - min_dist) / dist_range
277
- color_intensity = int(255 * normalized_dist)
278
- color = (0, 255 - color_intensity, color_intensity) # Green to Red
279
-
280
- # Draw thicker line for longer distances
281
- thickness = max(1, int(3 * normalized_dist) + 1)
282
- cv2.line(img_with_distances, start_point, end_point, color, thickness)
283
-
284
- # Add small circles at measurement points
285
- cv2.circle(img_with_distances, start_point, 2, (255, 255, 255), -1)
286
- cv2.circle(img_with_distances, end_point, 2, (255, 255, 255), -1)
287
-
288
- return img_with_distances
289
 
290
  def analyze_image(self, image_path):
 
 
 
 
291
  img_bgr = cv2.imread(image_path)
292
  if img_bgr is None:
293
  raise FileNotFoundError(f"Could not read image: {image_path}")
294
- h_orig, w_orig = img_bgr.shape[:2]
295
 
 
296
  cej_unet, abc_unet, _ = self.run_unet(img_bgr)
 
 
 
297
 
298
- cej_orig = self.resize_mask_to_original(cej_unet, (h_orig, w_orig))
299
- abc_orig = self.resize_mask_to_original(abc_unet, (h_orig, w_orig))
300
 
301
- cej_bin = (cej_orig > 0.5).astype(np.uint8) * 255
302
- abc_bin = (abc_orig > 0.5).astype(np.uint8) * 255
 
303
 
304
- detections = self.detect_teeth(img_bgr)
 
305
 
306
- combined = img_bgr.copy()
307
- all_abc_segments = []
308
- all_cej_segments = []
309
- all_distance_analyses = []
310
-
311
- for i, det in enumerate(detections):
312
- x1, y1, x2, y2 = det["bbox"]
313
- x1i = max(0, int(np.floor(x1)))
314
- y1i = max(0, int(np.floor(y1)))
315
- x2i = min(w_orig - 1, int(np.ceil(x2)))
316
- y2i = min(h_orig - 1, int(np.ceil(y2)))
317
- if x2i <= x1i or y2i <= y1i:
318
- continue
319
 
320
- cv2.rectangle(combined, (x1i, y1i), (x2i, y2i), (0,255,0), 2)
321
-
322
- # ABC
323
- abc_line_segment = self.extract_abc_uppermost_line_within_bbox(abc_bin, (x1i, y1i, x2i, y2i))
324
- abc_data = None
325
- if abc_line_segment is not None and len(abc_line_segment) > 1:
326
- abc_line_segment = self.smooth_landmarks(abc_line_segment, window_size=3)
327
- cv2.polylines(combined, [abc_line_segment], False, (255,0,0), 3)
328
- abc_start = tuple(abc_line_segment[0][0])
329
- abc_end = tuple(abc_line_segment[-1][0])
330
- cv2.circle(combined, abc_start, 4, (255,165,0), -1)
331
- cv2.circle(combined, abc_end, 4, (255,165,0), -1)
332
- abc_data = {
333
- "points": abc_line_segment,
334
- "start": abc_start,
335
- "end": abc_end
336
- }
337
- all_abc_segments.append(abc_data)
338
-
339
- # CEJ
340
- cej_line_segment = self.extract_cej_lowermost_line_within_bbox(cej_bin, (x1i, y1i, x2i, y2i))
341
- cej_data = None
342
- if cej_line_segment is not None and len(cej_line_segment) > 1:
343
- cej_line_segment = self.smooth_landmarks(cej_line_segment, window_size=3)
344
- cv2.polylines(combined, [cej_line_segment], False, (0,0,255), 3)
345
- cej_start = tuple(cej_line_segment[0][0])
346
- cej_end = tuple(cej_line_segment[-1][0])
347
- cv2.circle(combined, cej_start, 4, (0,255,255), -1)
348
- cv2.circle(combined, cej_end, 4, (0,255,255), -1)
349
- cej_data = {
350
- "points": cej_line_segment,
351
- "start": cej_start,
352
- "end": cej_end
353
- }
354
- all_cej_segments.append(cej_data)
355
-
356
- # Compute distances between CEJ and ABC
357
  distance_analysis = None
358
- if abc_data is not None and cej_data is not None:
359
- distance_analysis = self.compute_cej_abc_distances(
360
- cej_data["points"], abc_data["points"]
361
- )
362
- if distance_analysis is not None:
363
- combined = self.draw_distance_measurements(combined, distance_analysis, i+1)
364
-
365
- all_distance_analyses.append({
366
- 'tooth_id': i + 1,
367
- 'analysis': distance_analysis
368
  })
369
 
370
- cv2.putText(combined, f"T{i+1}", (x1i+5, y1i+20), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,255,0), 2)
371
-
372
- results = {
373
  "original": img_bgr,
374
- "cej_mask": cej_bin,
375
- "abc_mask": abc_bin,
376
  "detections": detections,
377
  "combined": combined,
378
- "abc_segments": all_abc_segments,
379
- "cej_segments": all_cej_segments,
380
- "distance_analyses": all_distance_analyses
381
  }
382
- return results
383
-
384
- def print_distance_summary(self, results):
385
- """
386
- Print a summary of distance measurements for all teeth.
387
- """
388
- print("\n" + "="*50)
389
- print("CEJ-ABC DISTANCE ANALYSIS SUMMARY")
390
- print("="*50)
391
-
392
- for tooth_data in results["distance_analyses"]:
393
- tooth_id = tooth_data['tooth_id']
394
- analysis = tooth_data['analysis']
395
-
396
- if analysis is None:
397
- print(f"Tooth {tooth_id}: No distance measurements available")
398
- continue
399
-
400
- print(f"\nTooth {tooth_id}:")
401
- print(f" Mean distance: {analysis['mean_distance']:.2f} pixels")
402
- print(f" Maximum distance: {analysis['max_distance']:.2f} pixels")
403
- print(f" Minimum distance: {analysis['min_distance']:.2f} pixels")
404
- print(f" Number of measurement points: {len(analysis['distances'])}")
405
-
406
- # Show some sample measurements
407
- if len(analysis['distances']) > 0:
408
- sample_size = min(3, len(analysis['distances']))
409
- print(f" Sample measurements:")
410
- for i in range(0, len(analysis['distances']), len(analysis['distances'])//sample_size):
411
- d = analysis['distances'][i]
412
- print(f" X={d['x']}: CEJ_Y={d['cej_y']}, ABC_Y={d['abc_y']}, Distance={d['distance']:.1f}px")
413
-
414
- def create_distance_heatmap(self, results):
415
- """Create a heatmap visualization of distances across all teeth."""
416
- all_distances = []
417
- tooth_labels = []
418
-
419
- for tooth_data in results["distance_analyses"]:
420
- if tooth_data['analysis'] is not None:
421
- distances = [d['distance'] for d in tooth_data['analysis']['distances']]
422
- all_distances.extend(distances)
423
- tooth_labels.extend([f"T{tooth_data['tooth_id']}"] * len(distances))
424
-
425
- if not all_distances:
426
- return None
427
-
428
- # Create histogram data
429
- unique_teeth = list(set(tooth_labels))
430
- tooth_distances = {tooth: [] for tooth in unique_teeth}
431
-
432
- for i, tooth in enumerate(tooth_labels):
433
- tooth_distances[tooth].append(all_distances[i])
434
-
435
- return tooth_distances
436
-
437
- def create_overlay_image(self, results):
438
- """Create an enhanced overlay image with better visualization."""
439
- img = results["original"].copy()
440
- cej_mask = results["cej_mask"]
441
- abc_mask = results["abc_mask"]
442
-
443
- # Create colored overlays
444
- overlay = img.copy()
445
-
446
- # CEJ in red with transparency
447
- cej_colored = np.zeros_like(img)
448
- cej_colored[:, :, 2] = cej_mask # Red channel
449
-
450
- # ABC in blue with transparency
451
- abc_colored = np.zeros_like(img)
452
- abc_colored[:, :, 0] = abc_mask # Blue channel
453
-
454
- # Blend overlays
455
- alpha = 0.4
456
- overlay = cv2.addWeighted(overlay, 1-alpha, cej_colored, alpha, 0)
457
- overlay = cv2.addWeighted(overlay, 1-alpha, abc_colored, alpha, 0)
458
-
459
- # Add tooth detection boxes and labels
460
- for i, det in enumerate(results["detections"]):
461
- x1, y1, x2, y2 = det["bbox"].astype(int)
462
- cv2.rectangle(overlay, (x1, y1), (x2, y2), (0, 255, 0), 2)
463
-
464
- # Add tooth label with background
465
- label = f"Tooth {i+1}"
466
- font = cv2.FONT_HERSHEY_SIMPLEX
467
- font_scale = 0.7
468
- thickness = 2
469
- (text_width, text_height), _ = cv2.getTextSize(label, font, font_scale, thickness)
470
-
471
- cv2.rectangle(overlay, (x1, y1-text_height-10), (x1+text_width+10, y1), (0, 255, 0), -1)
472
- cv2.putText(overlay, label, (x1+5, y1-5), font, font_scale, (0, 0, 0), thickness)
473
-
474
- return overlay
475
-
476
- def visualize_results(self, results, save_path=None):
477
- # Prepare images
478
- orig_rgb = cv2.cvtColor(results["original"], cv2.COLOR_BGR2RGB)
479
- combined_rgb = cv2.cvtColor(results["combined"], cv2.COLOR_BGR2RGB)
480
- overlay_bgr = self.create_overlay_image(results)
481
- overlay_rgb = cv2.cvtColor(overlay_bgr, cv2.COLOR_BGR2RGB)
482
-
483
- # Create figure with custom layout
484
- fig = plt.figure(figsize=(20, 15))
485
-
486
- # Create a grid layout
487
- gs = fig.add_gridspec(3, 4, height_ratios=[1, 1, 0.8], hspace=0.3, wspace=0.2)
488
-
489
- # Top row - Original images
490
- ax1 = fig.add_subplot(gs[0, 0])
491
- ax1.imshow(orig_rgb)
492
- ax1.set_title("Original Image", fontsize=14, fontweight='bold')
493
- ax1.axis("off")
494
-
495
- ax2 = fig.add_subplot(gs[0, 1])
496
- ax2.imshow(overlay_rgb)
497
- ax2.set_title("Segmentation Overlay\n(Red: CEJ, Blue: ABC)", fontsize=14, fontweight='bold')
498
- ax2.axis("off")
499
-
500
- ax3 = fig.add_subplot(gs[0, 2])
501
- ax3.imshow(combined_rgb)
502
- ax3.set_title("Distance Analysis\n(Yellow: Measurements)", fontsize=14, fontweight='bold')
503
- ax3.axis("off")
504
-
505
- # Individual mask visualization
506
- ax4 = fig.add_subplot(gs[0, 3])
507
- # Create combined mask visualization
508
- combined_mask = np.zeros((*results["cej_mask"].shape, 3), dtype=np.uint8)
509
- combined_mask[:, :, 2] = results["cej_mask"] # CEJ in red
510
- combined_mask[:, :, 0] = results["abc_mask"] # ABC in blue
511
- ax4.imshow(combined_mask)
512
- ax4.set_title("Combined Masks\n(Red: CEJ, Blue: ABC)", fontsize=14, fontweight='bold')
513
- ax4.axis("off")
514
-
515
- # Middle row - Distance analysis charts
516
- ax5 = fig.add_subplot(gs[1, :2]) # Span 2 columns
517
-
518
- # Create bar chart of average distances per tooth
519
- tooth_means = []
520
- tooth_labels = []
521
- tooth_maxs = []
522
- tooth_mins = []
523
-
524
- for tooth_data in results["distance_analyses"]:
525
- if tooth_data['analysis'] is not None:
526
- tooth_labels.append(f"T{tooth_data['tooth_id']}")
527
- tooth_means.append(tooth_data['analysis']['mean_distance'])
528
- tooth_maxs.append(tooth_data['analysis']['max_distance'])
529
- tooth_mins.append(tooth_data['analysis']['min_distance'])
530
-
531
- if tooth_means:
532
- x_pos = np.arange(len(tooth_labels))
533
- bars = ax5.bar(x_pos, tooth_means, alpha=0.7, color='skyblue', edgecolor='navy')
534
-
535
- # Add error bars showing min/max range
536
- yerr_lower = np.array(tooth_means) - np.array(tooth_mins)
537
- yerr_upper = np.array(tooth_maxs) - np.array(tooth_means)
538
- ax5.errorbar(x_pos, tooth_means, yerr=[yerr_lower, yerr_upper],
539
- fmt='none', ecolor='red', capsize=5, alpha=0.7)
540
-
541
- ax5.set_xlabel("Tooth Number", fontsize=12, fontweight='bold')
542
- ax5.set_ylabel("Distance (pixels)", fontsize=12, fontweight='bold')
543
- ax5.set_title("CEJ-ABC Distance Analysis by Tooth", fontsize=14, fontweight='bold')
544
- ax5.set_xticks(x_pos)
545
- ax5.set_xticklabels(tooth_labels)
546
- ax5.grid(True, alpha=0.3)
547
-
548
- # Add value labels on bars
549
- for i, (bar, mean_val, max_val, min_val) in enumerate(zip(bars, tooth_means, tooth_maxs, tooth_mins)):
550
- height = bar.get_height()
551
- ax5.text(bar.get_x() + bar.get_width()/2., height + 1,
552
- f'{mean_val:.1f}', ha='center', va='bottom', fontweight='bold')
553
- else:
554
- ax5.text(0.5, 0.5, "No distance measurements available",
555
- ha='center', va='center', transform=ax5.transAxes, fontsize=14)
556
- ax5.set_title("CEJ-ABC Distance Analysis", fontsize=14, fontweight='bold')
557
-
558
- # Distance distribution histogram
559
- ax6 = fig.add_subplot(gs[1, 2:]) # Span 2 columns
560
-
561
- tooth_distances = self.create_distance_heatmap(results)
562
- if tooth_distances:
563
- all_vals = []
564
- labels = []
565
- colors = plt.cm.Set3(np.linspace(0, 1, len(tooth_distances)))
566
-
567
- for i, (tooth, distances) in enumerate(tooth_distances.items()):
568
- ax6.hist(distances, bins=20, alpha=0.6, label=tooth, color=colors[i])
569
- all_vals.extend(distances)
570
-
571
- ax6.set_xlabel("Distance (pixels)", fontsize=12, fontweight='bold')
572
- ax6.set_ylabel("Frequency", fontsize=12, fontweight='bold')
573
- ax6.set_title("Distance Distribution Across All Measurements", fontsize=14, fontweight='bold')
574
- ax6.legend()
575
- ax6.grid(True, alpha=0.3)
576
- else:
577
- ax6.text(0.5, 0.5, "No distance data available for histogram",
578
- ha='center', va='center', transform=ax6.transAxes, fontsize=14)
579
- ax6.set_title("Distance Distribution", fontsize=14, fontweight='bold')
580
-
581
- # Bottom row - Summary statistics table
582
- ax7 = fig.add_subplot(gs[2, :])
583
- ax7.axis('tight')
584
- ax7.axis('off')
585
-
586
- # Create summary table
587
- if tooth_means:
588
- table_data = []
589
- headers = ['Tooth', 'Mean Distance (px)', 'Max Distance (px)', 'Min Distance (px)', 'Range (px)', 'Measurements']
590
-
591
- for tooth_data in results["distance_analyses"]:
592
- if tooth_data['analysis'] is not None:
593
- analysis = tooth_data['analysis']
594
- range_val = analysis['max_distance'] - analysis['min_distance']
595
- num_measurements = len(analysis['distances'])
596
-
597
- table_data.append([
598
- f"T{tooth_data['tooth_id']}",
599
- f"{analysis['mean_distance']:.2f}",
600
- f"{analysis['max_distance']:.2f}",
601
- f"{analysis['min_distance']:.2f}",
602
- f"{range_val:.2f}",
603
- str(num_measurements)
604
- ])
605
-
606
- table = ax7.table(cellText=table_data, colLabels=headers,
607
- cellLoc='center', loc='center')
608
- table.auto_set_font_size(False)
609
- table.set_fontsize(10)
610
- table.scale(1, 2)
611
-
612
- # Style the table
613
- for (i, j), cell in table.get_celld().items():
614
- if i == 0: # Header row
615
- cell.set_text_props(weight='bold', color='white')
616
- cell.set_facecolor('#4472C4')
617
- else:
618
- cell.set_facecolor('#F2F2F2' if i % 2 == 0 else 'white')
619
-
620
- ax7.set_title("Detailed Distance Analysis Summary", fontsize=14, fontweight='bold', pad=20)
621
- else:
622
- ax7.text(0.5, 0.5, "No measurements available for summary table",
623
- ha='center', va='center', transform=ax7.transAxes, fontsize=14)
624
-
625
- plt.suptitle("Comprehensive Dental CEJ-ABC Distance Analysis", fontsize=16, fontweight='bold', y=0.98)
626
-
627
- if save_path:
628
- plt.savefig(save_path, dpi=300, bbox_inches="tight", facecolor='white')
629
- print(f"Saved enhanced visualization to {save_path}")
630
-
631
- return fig
632
 
633
 
634
  if __name__ == "__main__":
635
- unet_model = "unet.keras"
636
- yolo_model = "yolov8n-seg.pt"
637
- image_path = "trial.jpg"
638
 
639
  seg = SimpleDentalSegmentationNoEnhance(unet_model, yolo_model)
640
  res = seg.analyze_image(image_path)
641
-
642
- # Print distance analysis summary
643
- seg.print_distance_summary(res)
644
-
645
- fig = seg.visualize_results(res, save_path="segmentation_with_distances.png")
646
- plt.show()
 
5
  import tensorflow as tf
6
  from ultralytics import YOLO
7
 
8
+
9
  class SimpleDentalSegmentationNoEnhance:
10
+ def __init__(self, unet_model_path, yolo_model_path, unet_input_size=(224, 224, 3)):
11
+ """
12
+ Initialize the dental segmentation and analysis pipeline.
13
+ """
14
+ # Load Keras U-Net model
15
+ self.unet = tf.keras.models.load_model(unet_model_path)
 
 
16
  self.in_h, self.in_w, self.in_c = unet_input_size
17
 
18
+ # Load YOLOv8 (PyTorch) model
19
  self.yolo = YOLO(yolo_model_path)
20
 
21
+ print("Models loaded successfully.")
22
+ print(f"Keras U-Net input shape: {self.unet.input_shape}")
23
+ print(f"Keras U-Net output shape: {self.unet.output_shape}")
24
+ print(f"YOLO model loaded: {yolo_model_path}")
25
 
26
  def preprocess_for_unet(self, image_bgr):
27
+ """
28
+ Prepare a BGR image for U-Net prediction.
29
+ Converts to RGB, resizes, and normalizes.
30
+ """
31
+ img_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
32
+ img_resized = cv2.resize(img_rgb, (self.in_w, self.in_h), interpolation=cv2.INTER_LINEAR)
33
+ img_norm = img_resized.astype(np.float32) / 255.0
34
+ input_tensor = np.expand_dims(img_norm, axis=0)
35
+ return input_tensor, img_resized
36
 
37
  def run_unet(self, image_bgr):
38
+ """
39
+ Run the Keras U-Net model on the given image.
40
+ Returns CEJ and ABC masks.
41
+ """
42
  input_tensor, model_resized_image = self.preprocess_for_unet(image_bgr)
43
+ preds = self.unet.predict(input_tensor, verbose=0)
44
+ out = preds[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  if out.ndim == 3 and out.shape[2] >= 2:
47
  class_map = np.argmax(out, axis=2).astype(np.uint8)
48
  abc = (class_map == 1).astype(np.uint8)
49
  cej = (class_map == 2).astype(np.uint8)
50
+ elif out.ndim == 3 and out.shape[2] == 1:
51
+ binary = out[:, :, 0]
52
+ abc = (binary > 0.5).astype(np.uint8)
53
+ cej = np.zeros_like(abc)
54
  else:
55
+ h, w = out.shape[:2]
56
+ abc = np.zeros((h, w), dtype=np.uint8)
57
+ cej = np.zeros((h, w), dtype=np.uint8)
 
58
 
59
  return cej, abc, model_resized_image
60
 
61
  def detect_teeth(self, image_bgr):
62
+ """
63
+ Detect teeth using YOLOv8 PyTorch model.
64
+ Returns bounding boxes and confidence scores.
65
+ """
66
  image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
67
  results = self.yolo(image_rgb)
68
  detections = []
69
+
70
  for r in results:
71
  boxes = getattr(r, "boxes", None)
72
  if boxes is None:
73
  continue
74
+ for box in boxes:
75
+ xyxy = box.xyxy[0].cpu().numpy()
76
+ conf = float(box.conf[0].cpu().numpy())
 
 
 
 
 
 
77
  detections.append({
78
  "bbox": xyxy.astype(np.float32),
79
  "confidence": conf,
 
82
  return detections
83
 
84
  def resize_mask_to_original(self, mask, original_shape):
85
+ """Resize a predicted mask back to original image size."""
86
  h_orig, w_orig = original_shape
87
+ mask_resized = cv2.resize(mask.astype(np.uint8) * 255, (w_orig, h_orig), interpolation=cv2.INTER_NEAREST)
88
+ return (mask_resized > 127).astype(np.uint8)
 
89
 
90
  def extract_abc_uppermost_line_within_bbox(self, abc_mask, bbox):
91
+ """Extract the uppermost ABC line within a detected tooth bounding box."""
92
+ x1, y1, x2, y2 = map(int, bbox)
93
  height, width = abc_mask.shape
94
+ x1, y1 = max(0, x1), max(0, y1)
95
+ x2, y2 = min(width - 1, x2), min(height - 1, y2)
96
+
 
 
97
  abc_points = []
98
+ for x in range(x1, x2 + 1):
99
+ column = np.where(abc_mask[y1:y2 + 1, x] == 1)[0]
100
+ if len(column) > 0:
101
+ y_absolute = y1 + np.min(column)
 
102
  abc_points.append([x, y_absolute])
103
+
104
  if len(abc_points) < 2:
105
  return None
106
  return np.array(abc_points, dtype=np.int32).reshape(-1, 1, 2)
107
 
108
  def extract_cej_lowermost_line_within_bbox(self, cej_mask, bbox):
109
+ """Extract the lowermost CEJ line within a detected tooth bounding box."""
110
+ x1, y1, x2, y2 = map(int, bbox)
111
  height, width = cej_mask.shape
112
+ x1, y1 = max(0, x1), max(0, y1)
113
+ x2, y2 = min(width - 1, x2), min(height - 1, y2)
114
+
 
 
115
  cej_points = []
116
+ for x in range(x1, x2 + 1):
117
+ column = np.where(cej_mask[y1:y2 + 1, x] == 1)[0]
118
+ if len(column) > 0:
119
+ y_absolute = y1 + np.max(column)
 
120
  cej_points.append([x, y_absolute])
121
+
122
  if len(cej_points) < 2:
123
  return None
124
  return np.array(cej_points, dtype=np.int32).reshape(-1, 1, 2)
125
 
126
  def smooth_landmarks(self, points, window_size=5):
127
+ """Smooth a polyline using a simple moving average."""
128
  if points is None or len(points) < window_size:
129
  return points
130
+ pts = points.reshape(-1, 2)
131
+ smoothed = []
132
+ for i in range(len(pts)):
133
+ start, end = max(0, i - window_size // 2), min(len(pts), i + window_size // 2 + 1)
134
+ smoothed_y = np.mean(pts[start:end, 1])
135
+ smoothed.append([pts[i, 0], smoothed_y])
136
+ return np.array(smoothed, dtype=np.int32).reshape(-1, 1, 2)
 
 
137
 
138
  def compute_cej_abc_distances(self, cej_points, abc_points):
139
+ """Compute vertical distances between CEJ and ABC points."""
 
 
 
140
  if cej_points is None or abc_points is None:
141
  return None
142
+
143
+ cej_2d, abc_2d = cej_points.reshape(-1, 2), abc_points.reshape(-1, 2)
144
+ cej_dict = {x: y for x, y in cej_2d}
145
+ abc_dict = {x: y for x, y in abc_2d}
146
+
 
 
 
 
147
  common_x = set(cej_dict.keys()) & set(abc_dict.keys())
 
148
  if not common_x:
 
149
  return self.compute_distances_with_interpolation(cej_2d, abc_2d)
150
+
151
+ distances, connections = [], []
 
 
152
  for x in sorted(common_x):
153
+ cej_y, abc_y = cej_dict[x], abc_dict[x]
154
+ dist = abs(abc_y - cej_y)
155
+ distances.append({'x': x, 'cej_y': cej_y, 'abc_y': abc_y, 'distance': dist})
156
+ connections.append([(x, cej_y), (x, abc_y)])
157
+
 
 
 
 
 
 
 
 
158
  return {
159
  'distances': distances,
160
+ 'connection_points': connections,
161
  'mean_distance': np.mean([d['distance'] for d in distances]),
162
+ 'max_distance': np.max([d['distance'] for d in distances]),
163
+ 'min_distance': np.min([d['distance'] for d in distances]),
164
  }
165
 
166
  def compute_distances_with_interpolation(self, cej_points, abc_points):
167
+ """Interpolate CEJ and ABC lines when x-coordinates don’t match exactly."""
 
 
 
168
  cej_x_min, cej_x_max = np.min(cej_points[:, 0]), np.max(cej_points[:, 0])
169
  abc_x_min, abc_x_max = np.min(abc_points[:, 0]), np.max(abc_points[:, 0])
170
+ x_min, x_max = max(cej_x_min, abc_x_min), min(cej_x_max, abc_x_max)
 
 
 
171
  if x_min >= x_max:
172
  return None
173
+
174
+ x_samples = np.linspace(x_min, x_max, min(50, int(x_max - x_min) + 1), dtype=int)
175
+ cej_y = np.interp(x_samples, cej_points[:, 0], cej_points[:, 1])
176
+ abc_y = np.interp(x_samples, abc_points[:, 0], abc_points[:, 1])
177
+
178
+ distances, connections = [], []
179
+ for x, cy, ay in zip(x_samples, cej_y, abc_y):
180
+ dist = abs(ay - cy)
181
+ distances.append({'x': int(x), 'cej_y': int(cy), 'abc_y': int(ay), 'distance': dist})
182
+ connections.append([(int(x), int(cy)), (int(x), int(ay))])
183
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  return {
185
  'distances': distances,
186
+ 'connection_points': connections,
187
  'mean_distance': np.mean([d['distance'] for d in distances]),
188
+ 'max_distance': np.max([d['distance'] for d in distances]),
189
+ 'min_distance': np.min([d['distance'] for d in distances]),
190
  }
191
 
192
  def draw_distance_measurements(self, image, distance_analysis, tooth_id):
193
+ """Draw color-coded CEJ-ABC measurement lines."""
 
 
194
  if distance_analysis is None:
195
  return image
196
+ img = image.copy()
197
+ connections = distance_analysis['connection_points']
 
 
 
198
  distances = [d['distance'] for d in distance_analysis['distances']]
199
+
200
  if not distances:
201
+ return img
202
+
203
+ min_d, max_d = min(distances), max(distances)
204
+ dist_range = max_d - min_d if max_d != min_d else 1
205
+
206
+ for i in range(0, len(connections), 3): # draw every 3rd to reduce clutter
207
+ (x1, y1), (x2, y2) = connections[i]
208
+ dist = distances[i]
209
+ norm = (dist - min_d) / dist_range
210
+ color = (0, int(255 * (1 - norm)), int(255 * norm)) # Green→Red
211
+ cv2.line(img, (x1, y1), (x2, y2), color, max(1, int(2 + 2 * norm)))
212
+ return img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
 
214
  def analyze_image(self, image_path):
215
+ """
216
+ Perform full analysis on a dental image:
217
+ segmentation, detection, distance measurement, and visualization.
218
+ """
219
  img_bgr = cv2.imread(image_path)
220
  if img_bgr is None:
221
  raise FileNotFoundError(f"Could not read image: {image_path}")
 
222
 
223
+ h_orig, w_orig = img_bgr.shape[:2]
224
  cej_unet, abc_unet, _ = self.run_unet(img_bgr)
225
+ cej_mask = self.resize_mask_to_original(cej_unet, (h_orig, w_orig))
226
+ abc_mask = self.resize_mask_to_original(abc_unet, (h_orig, w_orig))
227
+ detections = self.detect_teeth(img_bgr)
228
 
229
+ combined = img_bgr.copy()
230
+ all_results = []
231
 
232
+ for det in detections:
233
+ x1, y1, x2, y2 = det["bbox"].astype(int)
234
+ cv2.rectangle(combined, (x1, y1), (x2, y2), (0, 255, 0), 2)
235
 
236
+ abc_line = self.extract_abc_uppermost_line_within_bbox(abc_mask, (x1, y1, x2, y2))
237
+ cej_line = self.extract_cej_lowermost_line_within_bbox(cej_mask, (x1, y1, x2, y2))
238
 
239
+ if abc_line is not None:
240
+ abc_line = self.smooth_landmarks(abc_line)
241
+ cv2.polylines(combined, [abc_line], False, (255, 0, 0), 2)
242
+ if cej_line is not None:
243
+ cej_line = self.smooth_landmarks(cej_line)
244
+ cv2.polylines(combined, [cej_line], False, (0, 0, 255), 2)
 
 
 
 
 
 
 
245
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
  distance_analysis = None
247
+ if cej_line is not None and abc_line is not None:
248
+ distance_analysis = self.compute_cej_abc_distances(cej_line, abc_line)
249
+ if distance_analysis:
250
+ combined = self.draw_distance_measurements(combined, distance_analysis, det["tooth_id"])
251
+
252
+ all_results.append({
253
+ "tooth_id": det["tooth_id"],
254
+ "analysis": distance_analysis
 
 
255
  })
256
 
257
+ return {
 
 
258
  "original": img_bgr,
259
+ "cej_mask": cej_mask,
260
+ "abc_mask": abc_mask,
261
  "detections": detections,
262
  "combined": combined,
263
+ "distance_analyses": all_results
 
 
264
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265
 
266
 
267
  if __name__ == "__main__":
268
+ unet_model = "unet.keras" # Keras model
269
+ yolo_model = "yolov8n-seg.pt" # YOLOv8 PyTorch model
270
+ image_path = "trial2.jpg"
271
 
272
  seg = SimpleDentalSegmentationNoEnhance(unet_model, yolo_model)
273
  res = seg.analyze_image(image_path)
274
+
275
+ plt.figure(figsize=(12, 8))
276
+ plt.imshow(cv2.cvtColor(res["combined"], cv2.COLOR_BGR2RGB))
277
+ plt.title("Dental CEJ–ABC Analysis Result")
278
+ plt.axis("off")
279
+ plt.show()