jayn95 commited on
Commit
9473443
·
verified ·
1 Parent(s): 7758f8d

Update periodontitis_detection.py

Browse files
Files changed (1) hide show
  1. periodontitis_detection.py +645 -645
periodontitis_detection.py CHANGED
@@ -1,646 +1,646 @@
1
- import os
2
- import cv2
3
- import numpy as np
4
- import matplotlib.pyplot as plt
5
- import tensorflow as tf
6
- from ultralytics import YOLO
7
-
8
- class SimpleDentalSegmentationNoEnhance:
9
- def __init__(self, unet_model_path, yolo_model_path, unet_input_size=(224,224,3)):
10
- # Load TFLite U-Net
11
- self.interpreter = tf.lite.Interpreter(model_path=unet_model_path)
12
- self.interpreter.allocate_tensors()
13
- self.input_details = self.interpreter.get_input_details()
14
- self.output_details = self.interpreter.get_output_details()
15
-
16
- # Force/prefer the desired U-Net input size
17
- self.in_h, self.in_w, self.in_c = unet_input_size
18
-
19
- # Load YOLOv8
20
- self.yolo = YOLO(yolo_model_path)
21
-
22
- print("Models loaded successfully!")
23
- print(f"Using forced U-Net input shape: (1, {self.in_h}, {self.in_w}, {self.in_c})")
24
- print(f"U-Net output shape (raw): {self.output_details[0]['shape']}")
25
-
26
- def preprocess_for_unet(self, image_bgr):
27
- img = image_bgr.copy()
28
- proc_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
29
- proc_resized = cv2.resize(proc_rgb, (self.in_w, self.in_h), interpolation=cv2.INTER_LINEAR)
30
- normalized = proc_resized.astype(np.float32) / 255.0
31
- input_tensor = np.expand_dims(normalized, axis=0).astype(np.float32)
32
- return input_tensor, proc_resized
33
-
34
- def run_unet(self, image_bgr):
35
- input_tensor, model_resized_image = self.preprocess_for_unet(image_bgr)
36
-
37
- try:
38
- self.interpreter.set_tensor(self.input_details[0]['index'], input_tensor)
39
- self.interpreter.invoke()
40
- output = self.interpreter.get_tensor(self.output_details[0]['index'])
41
- except Exception as e:
42
- print("Interpreter set_tensor failed, attempting to resize input to forced shape:", e)
43
- try:
44
- self.interpreter.resize_tensor_input(self.input_details[0]['index'], [1, self.in_h, self.in_w, self.in_c])
45
- self.interpreter.allocate_tensors()
46
- self.interpreter.set_tensor(self.input_details[0]['index'], input_tensor)
47
- self.interpreter.invoke()
48
- output = self.interpreter.get_tensor(self.output_details[0]['index'])
49
- except Exception as e2:
50
- raise RuntimeError("Failed to run TFLite interpreter") from e2
51
-
52
- out = output[0]
53
-
54
- if out.ndim == 3 and out.shape[2] >= 2:
55
- class_map = np.argmax(out, axis=2).astype(np.uint8)
56
- abc = (class_map == 1).astype(np.uint8)
57
- cej = (class_map == 2).astype(np.uint8)
58
- elif out.ndim == 2:
59
- combined = out
60
- abc = (combined > 0.5).astype(np.uint8)
61
- cej = (combined > 0.8).astype(np.uint8)
62
- else:
63
- h_unet = out.shape[0]
64
- w_unet = out.shape[1] if out.ndim >= 2 else (self.in_w)
65
- abc = np.zeros((h_unet, w_unet), dtype=np.uint8)
66
- cej = np.zeros((h_unet, w_unet), dtype=np.uint8)
67
-
68
- return cej, abc, model_resized_image
69
-
70
- def detect_teeth(self, image_bgr):
71
- image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
72
- results = self.yolo(image_rgb)
73
- detections = []
74
- for r in results:
75
- boxes = getattr(r, "boxes", None)
76
- if boxes is None:
77
- continue
78
- for i, box in enumerate(boxes):
79
- try:
80
- xyxy = box.xyxy[0].cpu().numpy()
81
- except Exception:
82
- xyxy = np.array(box.xyxy).astype(np.float32).reshape(-1)[:4]
83
- try:
84
- conf = float(box.conf[0].cpu().numpy())
85
- except Exception:
86
- conf = float(box.conf if hasattr(box, "conf") else 0.0)
87
- detections.append({
88
- "bbox": xyxy.astype(np.float32),
89
- "confidence": conf,
90
- "tooth_id": len(detections) + 1
91
- })
92
- return detections
93
-
94
- def resize_mask_to_original(self, mask, original_shape):
95
- h_orig, w_orig = original_shape
96
- mask_resized = cv2.resize((mask * 255).astype(np.uint8), (w_orig, h_orig), interpolation=cv2.INTER_NEAREST)
97
- mask_resized = (mask_resized.astype(np.float32) / 255.0)
98
- return mask_resized
99
-
100
- def extract_abc_uppermost_line_within_bbox(self, abc_mask, bbox):
101
- x1, y1, x2, y2 = bbox
102
- x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
103
- height, width = abc_mask.shape
104
- x1 = max(0, x1)
105
- y1 = max(0, y1)
106
- x2 = min(width-1, x2)
107
- y2 = min(height-1, y2)
108
-
109
- abc_points = []
110
- for x in range(x1, x2+1):
111
- column_abc = np.where(abc_mask[y1:y2+1, x] == 255)[0]
112
- if len(column_abc) > 0:
113
- y_min_relative = np.min(column_abc)
114
- y_absolute = y1 + y_min_relative
115
- abc_points.append([x, y_absolute])
116
-
117
- if len(abc_points) < 2:
118
- return None
119
- return np.array(abc_points, dtype=np.int32).reshape(-1, 1, 2)
120
-
121
- def extract_cej_lowermost_line_within_bbox(self, cej_mask, bbox):
122
- x1, y1, x2, y2 = bbox
123
- x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
124
- height, width = cej_mask.shape
125
- x1 = max(0, x1)
126
- y1 = max(0, y1)
127
- x2 = min(width-1, x2)
128
- y2 = min(height-1, y2)
129
-
130
- cej_points = []
131
- for x in range(x1, x2+1):
132
- column_cej = np.where(cej_mask[y1:y2+1, x] == 255)[0]
133
- if len(column_cej) > 0:
134
- y_max_relative = np.max(column_cej)
135
- y_absolute = y1 + y_max_relative
136
- cej_points.append([x, y_absolute])
137
-
138
- if len(cej_points) < 2:
139
- return None
140
- return np.array(cej_points, dtype=np.int32).reshape(-1, 1, 2)
141
-
142
- def smooth_landmarks(self, points, window_size=5):
143
- if points is None or len(points) < window_size:
144
- return points
145
- points_2d = points.reshape(-1, 2)
146
- smoothed_points = []
147
- for i in range(len(points_2d)):
148
- start_idx = max(0, i - window_size // 2)
149
- end_idx = min(len(points_2d), i + window_size // 2 + 1)
150
- window_points = points_2d[start_idx:end_idx]
151
- smoothed_y = np.mean(window_points[:, 1])
152
- smoothed_points.append([points_2d[i][0], smoothed_y])
153
- return np.array(smoothed_points, dtype=np.int32).reshape(-1, 1, 2)
154
-
155
- def compute_cej_abc_distances(self, cej_points, abc_points):
156
- """
157
- Compute distances between CEJ and ABC points.
158
- For each x-coordinate, find the vertical distance between CEJ and ABC.
159
- """
160
- if cej_points is None or abc_points is None:
161
- return None
162
-
163
- cej_2d = cej_points.reshape(-1, 2)
164
- abc_2d = abc_points.reshape(-1, 2)
165
-
166
- # Create dictionaries for quick lookup
167
- cej_dict = {point[0]: point[1] for point in cej_2d}
168
- abc_dict = {point[0]: point[1] for point in abc_2d}
169
-
170
- # Find common x-coordinates
171
- common_x = set(cej_dict.keys()) & set(abc_dict.keys())
172
-
173
- if not common_x:
174
- # If no exact matches, use interpolation
175
- return self.compute_distances_with_interpolation(cej_2d, abc_2d)
176
-
177
- distances = []
178
- connection_points = []
179
-
180
- for x in sorted(common_x):
181
- cej_y = cej_dict[x]
182
- abc_y = abc_dict[x]
183
- distance = abs(abc_y - cej_y) # Vertical distance
184
-
185
- distances.append({
186
- 'x': x,
187
- 'cej_y': cej_y,
188
- 'abc_y': abc_y,
189
- 'distance': distance
190
- })
191
-
192
- connection_points.append([(x, cej_y), (x, abc_y)])
193
-
194
- return {
195
- 'distances': distances,
196
- 'connection_points': connection_points,
197
- 'mean_distance': np.mean([d['distance'] for d in distances]),
198
- 'max_distance': max([d['distance'] for d in distances]),
199
- 'min_distance': min([d['distance'] for d in distances])
200
- }
201
-
202
- def compute_distances_with_interpolation(self, cej_points, abc_points):
203
- """
204
- Compute distances using interpolation when points don't have exact x-matches.
205
- """
206
- # Get x-range that's common to both curves
207
- cej_x_min, cej_x_max = np.min(cej_points[:, 0]), np.max(cej_points[:, 0])
208
- abc_x_min, abc_x_max = np.min(abc_points[:, 0]), np.max(abc_points[:, 0])
209
-
210
- x_min = max(cej_x_min, abc_x_min)
211
- x_max = min(cej_x_max, abc_x_max)
212
-
213
- if x_min >= x_max:
214
- return None
215
-
216
- # Sample points at regular intervals
217
- num_samples = min(50, x_max - x_min + 1)
218
- x_sample = np.linspace(x_min, x_max, num_samples, dtype=int)
219
-
220
- # Interpolate y-values for both curves
221
- cej_y_interp = np.interp(x_sample, cej_points[:, 0], cej_points[:, 1])
222
- abc_y_interp = np.interp(x_sample, abc_points[:, 0], abc_points[:, 1])
223
-
224
- distances = []
225
- connection_points = []
226
-
227
- for i, x in enumerate(x_sample):
228
- cej_y = cej_y_interp[i]
229
- abc_y = abc_y_interp[i]
230
- distance = abs(abc_y - cej_y)
231
-
232
- distances.append({
233
- 'x': int(x),
234
- 'cej_y': int(cej_y),
235
- 'abc_y': int(abc_y),
236
- 'distance': distance
237
- })
238
-
239
- connection_points.append([(int(x), int(cej_y)), (int(x), int(abc_y))])
240
-
241
- return {
242
- 'distances': distances,
243
- 'connection_points': connection_points,
244
- 'mean_distance': np.mean([d['distance'] for d in distances]),
245
- 'max_distance': max([d['distance'] for d in distances]),
246
- 'min_distance': min([d['distance'] for d in distances])
247
- }
248
-
249
- def draw_distance_measurements(self, image, distance_analysis, tooth_id):
250
- """
251
- Draw clean distance measurements on the image without text overlays.
252
- """
253
- if distance_analysis is None:
254
- return image
255
-
256
- img_with_distances = image.copy()
257
-
258
- # Draw connection lines with gradient effect
259
- connection_points = distance_analysis['connection_points']
260
- distances = [d['distance'] for d in distance_analysis['distances']]
261
-
262
- if not distances:
263
- return img_with_distances
264
-
265
- # Normalize distances for color mapping
266
- min_dist = min(distances)
267
- max_dist = max(distances)
268
- dist_range = max_dist - min_dist if max_dist != min_dist else 1
269
-
270
- # Draw every 3rd line to reduce clutter, with color coding
271
- for i in range(0, len(connection_points), 3):
272
- start_point, end_point = connection_points[i]
273
- distance = distances[i]
274
-
275
- # Color based on distance (green = small, red = large)
276
- normalized_dist = (distance - min_dist) / dist_range
277
- color_intensity = int(255 * normalized_dist)
278
- color = (0, 255 - color_intensity, color_intensity) # Green to Red
279
-
280
- # Draw thicker line for longer distances
281
- thickness = max(1, int(3 * normalized_dist) + 1)
282
- cv2.line(img_with_distances, start_point, end_point, color, thickness)
283
-
284
- # Add small circles at measurement points
285
- cv2.circle(img_with_distances, start_point, 2, (255, 255, 255), -1)
286
- cv2.circle(img_with_distances, end_point, 2, (255, 255, 255), -1)
287
-
288
- return img_with_distances
289
-
290
- def analyze_image(self, image_path):
291
- img_bgr = cv2.imread(image_path)
292
- if img_bgr is None:
293
- raise FileNotFoundError(f"Could not read image: {image_path}")
294
- h_orig, w_orig = img_bgr.shape[:2]
295
-
296
- cej_unet, abc_unet, _ = self.run_unet(img_bgr)
297
-
298
- cej_orig = self.resize_mask_to_original(cej_unet, (h_orig, w_orig))
299
- abc_orig = self.resize_mask_to_original(abc_unet, (h_orig, w_orig))
300
-
301
- cej_bin = (cej_orig > 0.5).astype(np.uint8) * 255
302
- abc_bin = (abc_orig > 0.5).astype(np.uint8) * 255
303
-
304
- detections = self.detect_teeth(img_bgr)
305
-
306
- combined = img_bgr.copy()
307
- all_abc_segments = []
308
- all_cej_segments = []
309
- all_distance_analyses = []
310
-
311
- for i, det in enumerate(detections):
312
- x1, y1, x2, y2 = det["bbox"]
313
- x1i = max(0, int(np.floor(x1)))
314
- y1i = max(0, int(np.floor(y1)))
315
- x2i = min(w_orig - 1, int(np.ceil(x2)))
316
- y2i = min(h_orig - 1, int(np.ceil(y2)))
317
- if x2i <= x1i or y2i <= y1i:
318
- continue
319
-
320
- cv2.rectangle(combined, (x1i, y1i), (x2i, y2i), (0,255,0), 2)
321
-
322
- # ABC
323
- abc_line_segment = self.extract_abc_uppermost_line_within_bbox(abc_bin, (x1i, y1i, x2i, y2i))
324
- abc_data = None
325
- if abc_line_segment is not None and len(abc_line_segment) > 1:
326
- abc_line_segment = self.smooth_landmarks(abc_line_segment, window_size=3)
327
- cv2.polylines(combined, [abc_line_segment], False, (255,0,0), 3)
328
- abc_start = tuple(abc_line_segment[0][0])
329
- abc_end = tuple(abc_line_segment[-1][0])
330
- cv2.circle(combined, abc_start, 4, (255,165,0), -1)
331
- cv2.circle(combined, abc_end, 4, (255,165,0), -1)
332
- abc_data = {
333
- "points": abc_line_segment,
334
- "start": abc_start,
335
- "end": abc_end
336
- }
337
- all_abc_segments.append(abc_data)
338
-
339
- # CEJ
340
- cej_line_segment = self.extract_cej_lowermost_line_within_bbox(cej_bin, (x1i, y1i, x2i, y2i))
341
- cej_data = None
342
- if cej_line_segment is not None and len(cej_line_segment) > 1:
343
- cej_line_segment = self.smooth_landmarks(cej_line_segment, window_size=3)
344
- cv2.polylines(combined, [cej_line_segment], False, (0,0,255), 3)
345
- cej_start = tuple(cej_line_segment[0][0])
346
- cej_end = tuple(cej_line_segment[-1][0])
347
- cv2.circle(combined, cej_start, 4, (0,255,255), -1)
348
- cv2.circle(combined, cej_end, 4, (0,255,255), -1)
349
- cej_data = {
350
- "points": cej_line_segment,
351
- "start": cej_start,
352
- "end": cej_end
353
- }
354
- all_cej_segments.append(cej_data)
355
-
356
- # Compute distances between CEJ and ABC
357
- distance_analysis = None
358
- if abc_data is not None and cej_data is not None:
359
- distance_analysis = self.compute_cej_abc_distances(
360
- cej_data["points"], abc_data["points"]
361
- )
362
- if distance_analysis is not None:
363
- combined = self.draw_distance_measurements(combined, distance_analysis, i+1)
364
-
365
- all_distance_analyses.append({
366
- 'tooth_id': i + 1,
367
- 'analysis': distance_analysis
368
- })
369
-
370
- cv2.putText(combined, f"T{i+1}", (x1i+5, y1i+20), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,255,0), 2)
371
-
372
- results = {
373
- "original": img_bgr,
374
- "cej_mask": cej_bin,
375
- "abc_mask": abc_bin,
376
- "detections": detections,
377
- "combined": combined,
378
- "abc_segments": all_abc_segments,
379
- "cej_segments": all_cej_segments,
380
- "distance_analyses": all_distance_analyses
381
- }
382
- return results
383
-
384
- def print_distance_summary(self, results):
385
- """
386
- Print a summary of distance measurements for all teeth.
387
- """
388
- print("\n" + "="*50)
389
- print("CEJ-ABC DISTANCE ANALYSIS SUMMARY")
390
- print("="*50)
391
-
392
- for tooth_data in results["distance_analyses"]:
393
- tooth_id = tooth_data['tooth_id']
394
- analysis = tooth_data['analysis']
395
-
396
- if analysis is None:
397
- print(f"Tooth {tooth_id}: No distance measurements available")
398
- continue
399
-
400
- print(f"\nTooth {tooth_id}:")
401
- print(f" Mean distance: {analysis['mean_distance']:.2f} pixels")
402
- print(f" Maximum distance: {analysis['max_distance']:.2f} pixels")
403
- print(f" Minimum distance: {analysis['min_distance']:.2f} pixels")
404
- print(f" Number of measurement points: {len(analysis['distances'])}")
405
-
406
- # Show some sample measurements
407
- if len(analysis['distances']) > 0:
408
- sample_size = min(3, len(analysis['distances']))
409
- print(f" Sample measurements:")
410
- for i in range(0, len(analysis['distances']), len(analysis['distances'])//sample_size):
411
- d = analysis['distances'][i]
412
- print(f" X={d['x']}: CEJ_Y={d['cej_y']}, ABC_Y={d['abc_y']}, Distance={d['distance']:.1f}px")
413
-
414
- def create_distance_heatmap(self, results):
415
- """Create a heatmap visualization of distances across all teeth."""
416
- all_distances = []
417
- tooth_labels = []
418
-
419
- for tooth_data in results["distance_analyses"]:
420
- if tooth_data['analysis'] is not None:
421
- distances = [d['distance'] for d in tooth_data['analysis']['distances']]
422
- all_distances.extend(distances)
423
- tooth_labels.extend([f"T{tooth_data['tooth_id']}"] * len(distances))
424
-
425
- if not all_distances:
426
- return None
427
-
428
- # Create histogram data
429
- unique_teeth = list(set(tooth_labels))
430
- tooth_distances = {tooth: [] for tooth in unique_teeth}
431
-
432
- for i, tooth in enumerate(tooth_labels):
433
- tooth_distances[tooth].append(all_distances[i])
434
-
435
- return tooth_distances
436
-
437
- def create_overlay_image(self, results):
438
- """Create an enhanced overlay image with better visualization."""
439
- img = results["original"].copy()
440
- cej_mask = results["cej_mask"]
441
- abc_mask = results["abc_mask"]
442
-
443
- # Create colored overlays
444
- overlay = img.copy()
445
-
446
- # CEJ in red with transparency
447
- cej_colored = np.zeros_like(img)
448
- cej_colored[:, :, 2] = cej_mask # Red channel
449
-
450
- # ABC in blue with transparency
451
- abc_colored = np.zeros_like(img)
452
- abc_colored[:, :, 0] = abc_mask # Blue channel
453
-
454
- # Blend overlays
455
- alpha = 0.4
456
- overlay = cv2.addWeighted(overlay, 1-alpha, cej_colored, alpha, 0)
457
- overlay = cv2.addWeighted(overlay, 1-alpha, abc_colored, alpha, 0)
458
-
459
- # Add tooth detection boxes and labels
460
- for i, det in enumerate(results["detections"]):
461
- x1, y1, x2, y2 = det["bbox"].astype(int)
462
- cv2.rectangle(overlay, (x1, y1), (x2, y2), (0, 255, 0), 2)
463
-
464
- # Add tooth label with background
465
- label = f"Tooth {i+1}"
466
- font = cv2.FONT_HERSHEY_SIMPLEX
467
- font_scale = 0.7
468
- thickness = 2
469
- (text_width, text_height), _ = cv2.getTextSize(label, font, font_scale, thickness)
470
-
471
- cv2.rectangle(overlay, (x1, y1-text_height-10), (x1+text_width+10, y1), (0, 255, 0), -1)
472
- cv2.putText(overlay, label, (x1+5, y1-5), font, font_scale, (0, 0, 0), thickness)
473
-
474
- return overlay
475
-
476
- def visualize_results(self, results, save_path=None):
477
- # Prepare images
478
- orig_rgb = cv2.cvtColor(results["original"], cv2.COLOR_BGR2RGB)
479
- combined_rgb = cv2.cvtColor(results["combined"], cv2.COLOR_BGR2RGB)
480
- overlay_bgr = self.create_overlay_image(results)
481
- overlay_rgb = cv2.cvtColor(overlay_bgr, cv2.COLOR_BGR2RGB)
482
-
483
- # Create figure with custom layout
484
- fig = plt.figure(figsize=(20, 15))
485
-
486
- # Create a grid layout
487
- gs = fig.add_gridspec(3, 4, height_ratios=[1, 1, 0.8], hspace=0.3, wspace=0.2)
488
-
489
- # Top row - Original images
490
- ax1 = fig.add_subplot(gs[0, 0])
491
- ax1.imshow(orig_rgb)
492
- ax1.set_title("Original Image", fontsize=14, fontweight='bold')
493
- ax1.axis("off")
494
-
495
- ax2 = fig.add_subplot(gs[0, 1])
496
- ax2.imshow(overlay_rgb)
497
- ax2.set_title("Segmentation Overlay\n(Red: CEJ, Blue: ABC)", fontsize=14, fontweight='bold')
498
- ax2.axis("off")
499
-
500
- ax3 = fig.add_subplot(gs[0, 2])
501
- ax3.imshow(combined_rgb)
502
- ax3.set_title("Distance Analysis\n(Yellow: Measurements)", fontsize=14, fontweight='bold')
503
- ax3.axis("off")
504
-
505
- # Individual mask visualization
506
- ax4 = fig.add_subplot(gs[0, 3])
507
- # Create combined mask visualization
508
- combined_mask = np.zeros((*results["cej_mask"].shape, 3), dtype=np.uint8)
509
- combined_mask[:, :, 2] = results["cej_mask"] # CEJ in red
510
- combined_mask[:, :, 0] = results["abc_mask"] # ABC in blue
511
- ax4.imshow(combined_mask)
512
- ax4.set_title("Combined Masks\n(Red: CEJ, Blue: ABC)", fontsize=14, fontweight='bold')
513
- ax4.axis("off")
514
-
515
- # Middle row - Distance analysis charts
516
- ax5 = fig.add_subplot(gs[1, :2]) # Span 2 columns
517
-
518
- # Create bar chart of average distances per tooth
519
- tooth_means = []
520
- tooth_labels = []
521
- tooth_maxs = []
522
- tooth_mins = []
523
-
524
- for tooth_data in results["distance_analyses"]:
525
- if tooth_data['analysis'] is not None:
526
- tooth_labels.append(f"T{tooth_data['tooth_id']}")
527
- tooth_means.append(tooth_data['analysis']['mean_distance'])
528
- tooth_maxs.append(tooth_data['analysis']['max_distance'])
529
- tooth_mins.append(tooth_data['analysis']['min_distance'])
530
-
531
- if tooth_means:
532
- x_pos = np.arange(len(tooth_labels))
533
- bars = ax5.bar(x_pos, tooth_means, alpha=0.7, color='skyblue', edgecolor='navy')
534
-
535
- # Add error bars showing min/max range
536
- yerr_lower = np.array(tooth_means) - np.array(tooth_mins)
537
- yerr_upper = np.array(tooth_maxs) - np.array(tooth_means)
538
- ax5.errorbar(x_pos, tooth_means, yerr=[yerr_lower, yerr_upper],
539
- fmt='none', ecolor='red', capsize=5, alpha=0.7)
540
-
541
- ax5.set_xlabel("Tooth Number", fontsize=12, fontweight='bold')
542
- ax5.set_ylabel("Distance (pixels)", fontsize=12, fontweight='bold')
543
- ax5.set_title("CEJ-ABC Distance Analysis by Tooth", fontsize=14, fontweight='bold')
544
- ax5.set_xticks(x_pos)
545
- ax5.set_xticklabels(tooth_labels)
546
- ax5.grid(True, alpha=0.3)
547
-
548
- # Add value labels on bars
549
- for i, (bar, mean_val, max_val, min_val) in enumerate(zip(bars, tooth_means, tooth_maxs, tooth_mins)):
550
- height = bar.get_height()
551
- ax5.text(bar.get_x() + bar.get_width()/2., height + 1,
552
- f'{mean_val:.1f}', ha='center', va='bottom', fontweight='bold')
553
- else:
554
- ax5.text(0.5, 0.5, "No distance measurements available",
555
- ha='center', va='center', transform=ax5.transAxes, fontsize=14)
556
- ax5.set_title("CEJ-ABC Distance Analysis", fontsize=14, fontweight='bold')
557
-
558
- # Distance distribution histogram
559
- ax6 = fig.add_subplot(gs[1, 2:]) # Span 2 columns
560
-
561
- tooth_distances = self.create_distance_heatmap(results)
562
- if tooth_distances:
563
- all_vals = []
564
- labels = []
565
- colors = plt.cm.Set3(np.linspace(0, 1, len(tooth_distances)))
566
-
567
- for i, (tooth, distances) in enumerate(tooth_distances.items()):
568
- ax6.hist(distances, bins=20, alpha=0.6, label=tooth, color=colors[i])
569
- all_vals.extend(distances)
570
-
571
- ax6.set_xlabel("Distance (pixels)", fontsize=12, fontweight='bold')
572
- ax6.set_ylabel("Frequency", fontsize=12, fontweight='bold')
573
- ax6.set_title("Distance Distribution Across All Measurements", fontsize=14, fontweight='bold')
574
- ax6.legend()
575
- ax6.grid(True, alpha=0.3)
576
- else:
577
- ax6.text(0.5, 0.5, "No distance data available for histogram",
578
- ha='center', va='center', transform=ax6.transAxes, fontsize=14)
579
- ax6.set_title("Distance Distribution", fontsize=14, fontweight='bold')
580
-
581
- # Bottom row - Summary statistics table
582
- ax7 = fig.add_subplot(gs[2, :])
583
- ax7.axis('tight')
584
- ax7.axis('off')
585
-
586
- # Create summary table
587
- if tooth_means:
588
- table_data = []
589
- headers = ['Tooth', 'Mean Distance (px)', 'Max Distance (px)', 'Min Distance (px)', 'Range (px)', 'Measurements']
590
-
591
- for tooth_data in results["distance_analyses"]:
592
- if tooth_data['analysis'] is not None:
593
- analysis = tooth_data['analysis']
594
- range_val = analysis['max_distance'] - analysis['min_distance']
595
- num_measurements = len(analysis['distances'])
596
-
597
- table_data.append([
598
- f"T{tooth_data['tooth_id']}",
599
- f"{analysis['mean_distance']:.2f}",
600
- f"{analysis['max_distance']:.2f}",
601
- f"{analysis['min_distance']:.2f}",
602
- f"{range_val:.2f}",
603
- str(num_measurements)
604
- ])
605
-
606
- table = ax7.table(cellText=table_data, colLabels=headers,
607
- cellLoc='center', loc='center')
608
- table.auto_set_font_size(False)
609
- table.set_fontsize(10)
610
- table.scale(1, 2)
611
-
612
- # Style the table
613
- for (i, j), cell in table.get_celld().items():
614
- if i == 0: # Header row
615
- cell.set_text_props(weight='bold', color='white')
616
- cell.set_facecolor('#4472C4')
617
- else:
618
- cell.set_facecolor('#F2F2F2' if i % 2 == 0 else 'white')
619
-
620
- ax7.set_title("Detailed Distance Analysis Summary", fontsize=14, fontweight='bold', pad=20)
621
- else:
622
- ax7.text(0.5, 0.5, "No measurements available for summary table",
623
- ha='center', va='center', transform=ax7.transAxes, fontsize=14)
624
-
625
- plt.suptitle("Comprehensive Dental CEJ-ABC Distance Analysis", fontsize=16, fontweight='bold', y=0.98)
626
-
627
- if save_path:
628
- plt.savefig(save_path, dpi=300, bbox_inches="tight", facecolor='white')
629
- print(f"Saved enhanced visualization to {save_path}")
630
-
631
- return fig
632
-
633
-
634
- if __name__ == "__main__":
635
- unet_model = "models/dental_segmentation_model_augment_100epochs.tflite"
636
- yolo_model = "models/best v8n-seg_float16.tflite"
637
- image_path = "trial2.jpg"
638
-
639
- seg = SimpleDentalSegmentationNoEnhance(unet_model, yolo_model)
640
- res = seg.analyze_image(image_path)
641
-
642
- # Print distance analysis summary
643
- seg.print_distance_summary(res)
644
-
645
- fig = seg.visualize_results(res, save_path="segmentation_with_distances.png")
646
  plt.show()
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ import tensorflow as tf
6
+ from ultralytics import YOLO
7
+
8
+ class SimpleDentalSegmentationNoEnhance:
9
+ def __init__(self, unet_model_path, yolo_model_path, unet_input_size=(224,224,3)):
10
+ # Load TFLite U-Net
11
+ self.interpreter = tf.lite.Interpreter(model_path=unet_model_path)
12
+ self.interpreter.allocate_tensors()
13
+ self.input_details = self.interpreter.get_input_details()
14
+ self.output_details = self.interpreter.get_output_details()
15
+
16
+ # Force/prefer the desired U-Net input size
17
+ self.in_h, self.in_w, self.in_c = unet_input_size
18
+
19
+ # Load YOLOv8
20
+ self.yolo = YOLO(yolo_model_path)
21
+
22
+ print("Models loaded successfully!")
23
+ print(f"Using forced U-Net input shape: (1, {self.in_h}, {self.in_w}, {self.in_c})")
24
+ print(f"U-Net output shape (raw): {self.output_details[0]['shape']}")
25
+
26
+ def preprocess_for_unet(self, image_bgr):
27
+ img = image_bgr.copy()
28
+ proc_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
29
+ proc_resized = cv2.resize(proc_rgb, (self.in_w, self.in_h), interpolation=cv2.INTER_LINEAR)
30
+ normalized = proc_resized.astype(np.float32) / 255.0
31
+ input_tensor = np.expand_dims(normalized, axis=0).astype(np.float32)
32
+ return input_tensor, proc_resized
33
+
34
+ def run_unet(self, image_bgr):
35
+ input_tensor, model_resized_image = self.preprocess_for_unet(image_bgr)
36
+
37
+ try:
38
+ self.interpreter.set_tensor(self.input_details[0]['index'], input_tensor)
39
+ self.interpreter.invoke()
40
+ output = self.interpreter.get_tensor(self.output_details[0]['index'])
41
+ except Exception as e:
42
+ print("Interpreter set_tensor failed, attempting to resize input to forced shape:", e)
43
+ try:
44
+ self.interpreter.resize_tensor_input(self.input_details[0]['index'], [1, self.in_h, self.in_w, self.in_c])
45
+ self.interpreter.allocate_tensors()
46
+ self.interpreter.set_tensor(self.input_details[0]['index'], input_tensor)
47
+ self.interpreter.invoke()
48
+ output = self.interpreter.get_tensor(self.output_details[0]['index'])
49
+ except Exception as e2:
50
+ raise RuntimeError("Failed to run TFLite interpreter") from e2
51
+
52
+ out = output[0]
53
+
54
+ if out.ndim == 3 and out.shape[2] >= 2:
55
+ class_map = np.argmax(out, axis=2).astype(np.uint8)
56
+ abc = (class_map == 1).astype(np.uint8)
57
+ cej = (class_map == 2).astype(np.uint8)
58
+ elif out.ndim == 2:
59
+ combined = out
60
+ abc = (combined > 0.5).astype(np.uint8)
61
+ cej = (combined > 0.8).astype(np.uint8)
62
+ else:
63
+ h_unet = out.shape[0]
64
+ w_unet = out.shape[1] if out.ndim >= 2 else (self.in_w)
65
+ abc = np.zeros((h_unet, w_unet), dtype=np.uint8)
66
+ cej = np.zeros((h_unet, w_unet), dtype=np.uint8)
67
+
68
+ return cej, abc, model_resized_image
69
+
70
+ def detect_teeth(self, image_bgr):
71
+ image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
72
+ results = self.yolo(image_rgb)
73
+ detections = []
74
+ for r in results:
75
+ boxes = getattr(r, "boxes", None)
76
+ if boxes is None:
77
+ continue
78
+ for i, box in enumerate(boxes):
79
+ try:
80
+ xyxy = box.xyxy[0].cpu().numpy()
81
+ except Exception:
82
+ xyxy = np.array(box.xyxy).astype(np.float32).reshape(-1)[:4]
83
+ try:
84
+ conf = float(box.conf[0].cpu().numpy())
85
+ except Exception:
86
+ conf = float(box.conf if hasattr(box, "conf") else 0.0)
87
+ detections.append({
88
+ "bbox": xyxy.astype(np.float32),
89
+ "confidence": conf,
90
+ "tooth_id": len(detections) + 1
91
+ })
92
+ return detections
93
+
94
+ def resize_mask_to_original(self, mask, original_shape):
95
+ h_orig, w_orig = original_shape
96
+ mask_resized = cv2.resize((mask * 255).astype(np.uint8), (w_orig, h_orig), interpolation=cv2.INTER_NEAREST)
97
+ mask_resized = (mask_resized.astype(np.float32) / 255.0)
98
+ return mask_resized
99
+
100
+ def extract_abc_uppermost_line_within_bbox(self, abc_mask, bbox):
101
+ x1, y1, x2, y2 = bbox
102
+ x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
103
+ height, width = abc_mask.shape
104
+ x1 = max(0, x1)
105
+ y1 = max(0, y1)
106
+ x2 = min(width-1, x2)
107
+ y2 = min(height-1, y2)
108
+
109
+ abc_points = []
110
+ for x in range(x1, x2+1):
111
+ column_abc = np.where(abc_mask[y1:y2+1, x] == 255)[0]
112
+ if len(column_abc) > 0:
113
+ y_min_relative = np.min(column_abc)
114
+ y_absolute = y1 + y_min_relative
115
+ abc_points.append([x, y_absolute])
116
+
117
+ if len(abc_points) < 2:
118
+ return None
119
+ return np.array(abc_points, dtype=np.int32).reshape(-1, 1, 2)
120
+
121
+ def extract_cej_lowermost_line_within_bbox(self, cej_mask, bbox):
122
+ x1, y1, x2, y2 = bbox
123
+ x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
124
+ height, width = cej_mask.shape
125
+ x1 = max(0, x1)
126
+ y1 = max(0, y1)
127
+ x2 = min(width-1, x2)
128
+ y2 = min(height-1, y2)
129
+
130
+ cej_points = []
131
+ for x in range(x1, x2+1):
132
+ column_cej = np.where(cej_mask[y1:y2+1, x] == 255)[0]
133
+ if len(column_cej) > 0:
134
+ y_max_relative = np.max(column_cej)
135
+ y_absolute = y1 + y_max_relative
136
+ cej_points.append([x, y_absolute])
137
+
138
+ if len(cej_points) < 2:
139
+ return None
140
+ return np.array(cej_points, dtype=np.int32).reshape(-1, 1, 2)
141
+
142
+ def smooth_landmarks(self, points, window_size=5):
143
+ if points is None or len(points) < window_size:
144
+ return points
145
+ points_2d = points.reshape(-1, 2)
146
+ smoothed_points = []
147
+ for i in range(len(points_2d)):
148
+ start_idx = max(0, i - window_size // 2)
149
+ end_idx = min(len(points_2d), i + window_size // 2 + 1)
150
+ window_points = points_2d[start_idx:end_idx]
151
+ smoothed_y = np.mean(window_points[:, 1])
152
+ smoothed_points.append([points_2d[i][0], smoothed_y])
153
+ return np.array(smoothed_points, dtype=np.int32).reshape(-1, 1, 2)
154
+
155
+ def compute_cej_abc_distances(self, cej_points, abc_points):
156
+ """
157
+ Compute distances between CEJ and ABC points.
158
+ For each x-coordinate, find the vertical distance between CEJ and ABC.
159
+ """
160
+ if cej_points is None or abc_points is None:
161
+ return None
162
+
163
+ cej_2d = cej_points.reshape(-1, 2)
164
+ abc_2d = abc_points.reshape(-1, 2)
165
+
166
+ # Create dictionaries for quick lookup
167
+ cej_dict = {point[0]: point[1] for point in cej_2d}
168
+ abc_dict = {point[0]: point[1] for point in abc_2d}
169
+
170
+ # Find common x-coordinates
171
+ common_x = set(cej_dict.keys()) & set(abc_dict.keys())
172
+
173
+ if not common_x:
174
+ # If no exact matches, use interpolation
175
+ return self.compute_distances_with_interpolation(cej_2d, abc_2d)
176
+
177
+ distances = []
178
+ connection_points = []
179
+
180
+ for x in sorted(common_x):
181
+ cej_y = cej_dict[x]
182
+ abc_y = abc_dict[x]
183
+ distance = abs(abc_y - cej_y) # Vertical distance
184
+
185
+ distances.append({
186
+ 'x': x,
187
+ 'cej_y': cej_y,
188
+ 'abc_y': abc_y,
189
+ 'distance': distance
190
+ })
191
+
192
+ connection_points.append([(x, cej_y), (x, abc_y)])
193
+
194
+ return {
195
+ 'distances': distances,
196
+ 'connection_points': connection_points,
197
+ 'mean_distance': np.mean([d['distance'] for d in distances]),
198
+ 'max_distance': max([d['distance'] for d in distances]),
199
+ 'min_distance': min([d['distance'] for d in distances])
200
+ }
201
+
202
+ def compute_distances_with_interpolation(self, cej_points, abc_points):
203
+ """
204
+ Compute distances using interpolation when points don't have exact x-matches.
205
+ """
206
+ # Get x-range that's common to both curves
207
+ cej_x_min, cej_x_max = np.min(cej_points[:, 0]), np.max(cej_points[:, 0])
208
+ abc_x_min, abc_x_max = np.min(abc_points[:, 0]), np.max(abc_points[:, 0])
209
+
210
+ x_min = max(cej_x_min, abc_x_min)
211
+ x_max = min(cej_x_max, abc_x_max)
212
+
213
+ if x_min >= x_max:
214
+ return None
215
+
216
+ # Sample points at regular intervals
217
+ num_samples = min(50, x_max - x_min + 1)
218
+ x_sample = np.linspace(x_min, x_max, num_samples, dtype=int)
219
+
220
+ # Interpolate y-values for both curves
221
+ cej_y_interp = np.interp(x_sample, cej_points[:, 0], cej_points[:, 1])
222
+ abc_y_interp = np.interp(x_sample, abc_points[:, 0], abc_points[:, 1])
223
+
224
+ distances = []
225
+ connection_points = []
226
+
227
+ for i, x in enumerate(x_sample):
228
+ cej_y = cej_y_interp[i]
229
+ abc_y = abc_y_interp[i]
230
+ distance = abs(abc_y - cej_y)
231
+
232
+ distances.append({
233
+ 'x': int(x),
234
+ 'cej_y': int(cej_y),
235
+ 'abc_y': int(abc_y),
236
+ 'distance': distance
237
+ })
238
+
239
+ connection_points.append([(int(x), int(cej_y)), (int(x), int(abc_y))])
240
+
241
+ return {
242
+ 'distances': distances,
243
+ 'connection_points': connection_points,
244
+ 'mean_distance': np.mean([d['distance'] for d in distances]),
245
+ 'max_distance': max([d['distance'] for d in distances]),
246
+ 'min_distance': min([d['distance'] for d in distances])
247
+ }
248
+
249
+ def draw_distance_measurements(self, image, distance_analysis, tooth_id):
250
+ """
251
+ Draw clean distance measurements on the image without text overlays.
252
+ """
253
+ if distance_analysis is None:
254
+ return image
255
+
256
+ img_with_distances = image.copy()
257
+
258
+ # Draw connection lines with gradient effect
259
+ connection_points = distance_analysis['connection_points']
260
+ distances = [d['distance'] for d in distance_analysis['distances']]
261
+
262
+ if not distances:
263
+ return img_with_distances
264
+
265
+ # Normalize distances for color mapping
266
+ min_dist = min(distances)
267
+ max_dist = max(distances)
268
+ dist_range = max_dist - min_dist if max_dist != min_dist else 1
269
+
270
+ # Draw every 3rd line to reduce clutter, with color coding
271
+ for i in range(0, len(connection_points), 3):
272
+ start_point, end_point = connection_points[i]
273
+ distance = distances[i]
274
+
275
+ # Color based on distance (green = small, red = large)
276
+ normalized_dist = (distance - min_dist) / dist_range
277
+ color_intensity = int(255 * normalized_dist)
278
+ color = (0, 255 - color_intensity, color_intensity) # Green to Red
279
+
280
+ # Draw thicker line for longer distances
281
+ thickness = max(1, int(3 * normalized_dist) + 1)
282
+ cv2.line(img_with_distances, start_point, end_point, color, thickness)
283
+
284
+ # Add small circles at measurement points
285
+ cv2.circle(img_with_distances, start_point, 2, (255, 255, 255), -1)
286
+ cv2.circle(img_with_distances, end_point, 2, (255, 255, 255), -1)
287
+
288
+ return img_with_distances
289
+
290
+ def analyze_image(self, image_path):
291
+ img_bgr = cv2.imread(image_path)
292
+ if img_bgr is None:
293
+ raise FileNotFoundError(f"Could not read image: {image_path}")
294
+ h_orig, w_orig = img_bgr.shape[:2]
295
+
296
+ cej_unet, abc_unet, _ = self.run_unet(img_bgr)
297
+
298
+ cej_orig = self.resize_mask_to_original(cej_unet, (h_orig, w_orig))
299
+ abc_orig = self.resize_mask_to_original(abc_unet, (h_orig, w_orig))
300
+
301
+ cej_bin = (cej_orig > 0.5).astype(np.uint8) * 255
302
+ abc_bin = (abc_orig > 0.5).astype(np.uint8) * 255
303
+
304
+ detections = self.detect_teeth(img_bgr)
305
+
306
+ combined = img_bgr.copy()
307
+ all_abc_segments = []
308
+ all_cej_segments = []
309
+ all_distance_analyses = []
310
+
311
+ for i, det in enumerate(detections):
312
+ x1, y1, x2, y2 = det["bbox"]
313
+ x1i = max(0, int(np.floor(x1)))
314
+ y1i = max(0, int(np.floor(y1)))
315
+ x2i = min(w_orig - 1, int(np.ceil(x2)))
316
+ y2i = min(h_orig - 1, int(np.ceil(y2)))
317
+ if x2i <= x1i or y2i <= y1i:
318
+ continue
319
+
320
+ cv2.rectangle(combined, (x1i, y1i), (x2i, y2i), (0,255,0), 2)
321
+
322
+ # ABC
323
+ abc_line_segment = self.extract_abc_uppermost_line_within_bbox(abc_bin, (x1i, y1i, x2i, y2i))
324
+ abc_data = None
325
+ if abc_line_segment is not None and len(abc_line_segment) > 1:
326
+ abc_line_segment = self.smooth_landmarks(abc_line_segment, window_size=3)
327
+ cv2.polylines(combined, [abc_line_segment], False, (255,0,0), 3)
328
+ abc_start = tuple(abc_line_segment[0][0])
329
+ abc_end = tuple(abc_line_segment[-1][0])
330
+ cv2.circle(combined, abc_start, 4, (255,165,0), -1)
331
+ cv2.circle(combined, abc_end, 4, (255,165,0), -1)
332
+ abc_data = {
333
+ "points": abc_line_segment,
334
+ "start": abc_start,
335
+ "end": abc_end
336
+ }
337
+ all_abc_segments.append(abc_data)
338
+
339
+ # CEJ
340
+ cej_line_segment = self.extract_cej_lowermost_line_within_bbox(cej_bin, (x1i, y1i, x2i, y2i))
341
+ cej_data = None
342
+ if cej_line_segment is not None and len(cej_line_segment) > 1:
343
+ cej_line_segment = self.smooth_landmarks(cej_line_segment, window_size=3)
344
+ cv2.polylines(combined, [cej_line_segment], False, (0,0,255), 3)
345
+ cej_start = tuple(cej_line_segment[0][0])
346
+ cej_end = tuple(cej_line_segment[-1][0])
347
+ cv2.circle(combined, cej_start, 4, (0,255,255), -1)
348
+ cv2.circle(combined, cej_end, 4, (0,255,255), -1)
349
+ cej_data = {
350
+ "points": cej_line_segment,
351
+ "start": cej_start,
352
+ "end": cej_end
353
+ }
354
+ all_cej_segments.append(cej_data)
355
+
356
+ # Compute distances between CEJ and ABC
357
+ distance_analysis = None
358
+ if abc_data is not None and cej_data is not None:
359
+ distance_analysis = self.compute_cej_abc_distances(
360
+ cej_data["points"], abc_data["points"]
361
+ )
362
+ if distance_analysis is not None:
363
+ combined = self.draw_distance_measurements(combined, distance_analysis, i+1)
364
+
365
+ all_distance_analyses.append({
366
+ 'tooth_id': i + 1,
367
+ 'analysis': distance_analysis
368
+ })
369
+
370
+ cv2.putText(combined, f"T{i+1}", (x1i+5, y1i+20), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,255,0), 2)
371
+
372
+ results = {
373
+ "original": img_bgr,
374
+ "cej_mask": cej_bin,
375
+ "abc_mask": abc_bin,
376
+ "detections": detections,
377
+ "combined": combined,
378
+ "abc_segments": all_abc_segments,
379
+ "cej_segments": all_cej_segments,
380
+ "distance_analyses": all_distance_analyses
381
+ }
382
+ return results
383
+
384
+ def print_distance_summary(self, results):
385
+ """
386
+ Print a summary of distance measurements for all teeth.
387
+ """
388
+ print("\n" + "="*50)
389
+ print("CEJ-ABC DISTANCE ANALYSIS SUMMARY")
390
+ print("="*50)
391
+
392
+ for tooth_data in results["distance_analyses"]:
393
+ tooth_id = tooth_data['tooth_id']
394
+ analysis = tooth_data['analysis']
395
+
396
+ if analysis is None:
397
+ print(f"Tooth {tooth_id}: No distance measurements available")
398
+ continue
399
+
400
+ print(f"\nTooth {tooth_id}:")
401
+ print(f" Mean distance: {analysis['mean_distance']:.2f} pixels")
402
+ print(f" Maximum distance: {analysis['max_distance']:.2f} pixels")
403
+ print(f" Minimum distance: {analysis['min_distance']:.2f} pixels")
404
+ print(f" Number of measurement points: {len(analysis['distances'])}")
405
+
406
+ # Show some sample measurements
407
+ if len(analysis['distances']) > 0:
408
+ sample_size = min(3, len(analysis['distances']))
409
+ print(f" Sample measurements:")
410
+ for i in range(0, len(analysis['distances']), len(analysis['distances'])//sample_size):
411
+ d = analysis['distances'][i]
412
+ print(f" X={d['x']}: CEJ_Y={d['cej_y']}, ABC_Y={d['abc_y']}, Distance={d['distance']:.1f}px")
413
+
414
+ def create_distance_heatmap(self, results):
415
+ """Create a heatmap visualization of distances across all teeth."""
416
+ all_distances = []
417
+ tooth_labels = []
418
+
419
+ for tooth_data in results["distance_analyses"]:
420
+ if tooth_data['analysis'] is not None:
421
+ distances = [d['distance'] for d in tooth_data['analysis']['distances']]
422
+ all_distances.extend(distances)
423
+ tooth_labels.extend([f"T{tooth_data['tooth_id']}"] * len(distances))
424
+
425
+ if not all_distances:
426
+ return None
427
+
428
+ # Create histogram data
429
+ unique_teeth = list(set(tooth_labels))
430
+ tooth_distances = {tooth: [] for tooth in unique_teeth}
431
+
432
+ for i, tooth in enumerate(tooth_labels):
433
+ tooth_distances[tooth].append(all_distances[i])
434
+
435
+ return tooth_distances
436
+
437
+ def create_overlay_image(self, results):
438
+ """Create an enhanced overlay image with better visualization."""
439
+ img = results["original"].copy()
440
+ cej_mask = results["cej_mask"]
441
+ abc_mask = results["abc_mask"]
442
+
443
+ # Create colored overlays
444
+ overlay = img.copy()
445
+
446
+ # CEJ in red with transparency
447
+ cej_colored = np.zeros_like(img)
448
+ cej_colored[:, :, 2] = cej_mask # Red channel
449
+
450
+ # ABC in blue with transparency
451
+ abc_colored = np.zeros_like(img)
452
+ abc_colored[:, :, 0] = abc_mask # Blue channel
453
+
454
+ # Blend overlays
455
+ alpha = 0.4
456
+ overlay = cv2.addWeighted(overlay, 1-alpha, cej_colored, alpha, 0)
457
+ overlay = cv2.addWeighted(overlay, 1-alpha, abc_colored, alpha, 0)
458
+
459
+ # Add tooth detection boxes and labels
460
+ for i, det in enumerate(results["detections"]):
461
+ x1, y1, x2, y2 = det["bbox"].astype(int)
462
+ cv2.rectangle(overlay, (x1, y1), (x2, y2), (0, 255, 0), 2)
463
+
464
+ # Add tooth label with background
465
+ label = f"Tooth {i+1}"
466
+ font = cv2.FONT_HERSHEY_SIMPLEX
467
+ font_scale = 0.7
468
+ thickness = 2
469
+ (text_width, text_height), _ = cv2.getTextSize(label, font, font_scale, thickness)
470
+
471
+ cv2.rectangle(overlay, (x1, y1-text_height-10), (x1+text_width+10, y1), (0, 255, 0), -1)
472
+ cv2.putText(overlay, label, (x1+5, y1-5), font, font_scale, (0, 0, 0), thickness)
473
+
474
+ return overlay
475
+
476
+ def visualize_results(self, results, save_path=None):
477
+ # Prepare images
478
+ orig_rgb = cv2.cvtColor(results["original"], cv2.COLOR_BGR2RGB)
479
+ combined_rgb = cv2.cvtColor(results["combined"], cv2.COLOR_BGR2RGB)
480
+ overlay_bgr = self.create_overlay_image(results)
481
+ overlay_rgb = cv2.cvtColor(overlay_bgr, cv2.COLOR_BGR2RGB)
482
+
483
+ # Create figure with custom layout
484
+ fig = plt.figure(figsize=(20, 15))
485
+
486
+ # Create a grid layout
487
+ gs = fig.add_gridspec(3, 4, height_ratios=[1, 1, 0.8], hspace=0.3, wspace=0.2)
488
+
489
+ # Top row - Original images
490
+ ax1 = fig.add_subplot(gs[0, 0])
491
+ ax1.imshow(orig_rgb)
492
+ ax1.set_title("Original Image", fontsize=14, fontweight='bold')
493
+ ax1.axis("off")
494
+
495
+ ax2 = fig.add_subplot(gs[0, 1])
496
+ ax2.imshow(overlay_rgb)
497
+ ax2.set_title("Segmentation Overlay\n(Red: CEJ, Blue: ABC)", fontsize=14, fontweight='bold')
498
+ ax2.axis("off")
499
+
500
+ ax3 = fig.add_subplot(gs[0, 2])
501
+ ax3.imshow(combined_rgb)
502
+ ax3.set_title("Distance Analysis\n(Yellow: Measurements)", fontsize=14, fontweight='bold')
503
+ ax3.axis("off")
504
+
505
+ # Individual mask visualization
506
+ ax4 = fig.add_subplot(gs[0, 3])
507
+ # Create combined mask visualization
508
+ combined_mask = np.zeros((*results["cej_mask"].shape, 3), dtype=np.uint8)
509
+ combined_mask[:, :, 2] = results["cej_mask"] # CEJ in red
510
+ combined_mask[:, :, 0] = results["abc_mask"] # ABC in blue
511
+ ax4.imshow(combined_mask)
512
+ ax4.set_title("Combined Masks\n(Red: CEJ, Blue: ABC)", fontsize=14, fontweight='bold')
513
+ ax4.axis("off")
514
+
515
+ # Middle row - Distance analysis charts
516
+ ax5 = fig.add_subplot(gs[1, :2]) # Span 2 columns
517
+
518
+ # Create bar chart of average distances per tooth
519
+ tooth_means = []
520
+ tooth_labels = []
521
+ tooth_maxs = []
522
+ tooth_mins = []
523
+
524
+ for tooth_data in results["distance_analyses"]:
525
+ if tooth_data['analysis'] is not None:
526
+ tooth_labels.append(f"T{tooth_data['tooth_id']}")
527
+ tooth_means.append(tooth_data['analysis']['mean_distance'])
528
+ tooth_maxs.append(tooth_data['analysis']['max_distance'])
529
+ tooth_mins.append(tooth_data['analysis']['min_distance'])
530
+
531
+ if tooth_means:
532
+ x_pos = np.arange(len(tooth_labels))
533
+ bars = ax5.bar(x_pos, tooth_means, alpha=0.7, color='skyblue', edgecolor='navy')
534
+
535
+ # Add error bars showing min/max range
536
+ yerr_lower = np.array(tooth_means) - np.array(tooth_mins)
537
+ yerr_upper = np.array(tooth_maxs) - np.array(tooth_means)
538
+ ax5.errorbar(x_pos, tooth_means, yerr=[yerr_lower, yerr_upper],
539
+ fmt='none', ecolor='red', capsize=5, alpha=0.7)
540
+
541
+ ax5.set_xlabel("Tooth Number", fontsize=12, fontweight='bold')
542
+ ax5.set_ylabel("Distance (pixels)", fontsize=12, fontweight='bold')
543
+ ax5.set_title("CEJ-ABC Distance Analysis by Tooth", fontsize=14, fontweight='bold')
544
+ ax5.set_xticks(x_pos)
545
+ ax5.set_xticklabels(tooth_labels)
546
+ ax5.grid(True, alpha=0.3)
547
+
548
+ # Add value labels on bars
549
+ for i, (bar, mean_val, max_val, min_val) in enumerate(zip(bars, tooth_means, tooth_maxs, tooth_mins)):
550
+ height = bar.get_height()
551
+ ax5.text(bar.get_x() + bar.get_width()/2., height + 1,
552
+ f'{mean_val:.1f}', ha='center', va='bottom', fontweight='bold')
553
+ else:
554
+ ax5.text(0.5, 0.5, "No distance measurements available",
555
+ ha='center', va='center', transform=ax5.transAxes, fontsize=14)
556
+ ax5.set_title("CEJ-ABC Distance Analysis", fontsize=14, fontweight='bold')
557
+
558
+ # Distance distribution histogram
559
+ ax6 = fig.add_subplot(gs[1, 2:]) # Span 2 columns
560
+
561
+ tooth_distances = self.create_distance_heatmap(results)
562
+ if tooth_distances:
563
+ all_vals = []
564
+ labels = []
565
+ colors = plt.cm.Set3(np.linspace(0, 1, len(tooth_distances)))
566
+
567
+ for i, (tooth, distances) in enumerate(tooth_distances.items()):
568
+ ax6.hist(distances, bins=20, alpha=0.6, label=tooth, color=colors[i])
569
+ all_vals.extend(distances)
570
+
571
+ ax6.set_xlabel("Distance (pixels)", fontsize=12, fontweight='bold')
572
+ ax6.set_ylabel("Frequency", fontsize=12, fontweight='bold')
573
+ ax6.set_title("Distance Distribution Across All Measurements", fontsize=14, fontweight='bold')
574
+ ax6.legend()
575
+ ax6.grid(True, alpha=0.3)
576
+ else:
577
+ ax6.text(0.5, 0.5, "No distance data available for histogram",
578
+ ha='center', va='center', transform=ax6.transAxes, fontsize=14)
579
+ ax6.set_title("Distance Distribution", fontsize=14, fontweight='bold')
580
+
581
+ # Bottom row - Summary statistics table
582
+ ax7 = fig.add_subplot(gs[2, :])
583
+ ax7.axis('tight')
584
+ ax7.axis('off')
585
+
586
+ # Create summary table
587
+ if tooth_means:
588
+ table_data = []
589
+ headers = ['Tooth', 'Mean Distance (px)', 'Max Distance (px)', 'Min Distance (px)', 'Range (px)', 'Measurements']
590
+
591
+ for tooth_data in results["distance_analyses"]:
592
+ if tooth_data['analysis'] is not None:
593
+ analysis = tooth_data['analysis']
594
+ range_val = analysis['max_distance'] - analysis['min_distance']
595
+ num_measurements = len(analysis['distances'])
596
+
597
+ table_data.append([
598
+ f"T{tooth_data['tooth_id']}",
599
+ f"{analysis['mean_distance']:.2f}",
600
+ f"{analysis['max_distance']:.2f}",
601
+ f"{analysis['min_distance']:.2f}",
602
+ f"{range_val:.2f}",
603
+ str(num_measurements)
604
+ ])
605
+
606
+ table = ax7.table(cellText=table_data, colLabels=headers,
607
+ cellLoc='center', loc='center')
608
+ table.auto_set_font_size(False)
609
+ table.set_fontsize(10)
610
+ table.scale(1, 2)
611
+
612
+ # Style the table
613
+ for (i, j), cell in table.get_celld().items():
614
+ if i == 0: # Header row
615
+ cell.set_text_props(weight='bold', color='white')
616
+ cell.set_facecolor('#4472C4')
617
+ else:
618
+ cell.set_facecolor('#F2F2F2' if i % 2 == 0 else 'white')
619
+
620
+ ax7.set_title("Detailed Distance Analysis Summary", fontsize=14, fontweight='bold', pad=20)
621
+ else:
622
+ ax7.text(0.5, 0.5, "No measurements available for summary table",
623
+ ha='center', va='center', transform=ax7.transAxes, fontsize=14)
624
+
625
+ plt.suptitle("Comprehensive Dental CEJ-ABC Distance Analysis", fontsize=16, fontweight='bold', y=0.98)
626
+
627
+ if save_path:
628
+ plt.savefig(save_path, dpi=300, bbox_inches="tight", facecolor='white')
629
+ print(f"Saved enhanced visualization to {save_path}")
630
+
631
+ return fig
632
+
633
+
634
+ if __name__ == "__main__":
635
+ unet_model = "unet.keras"
636
+ yolo_model = "yolov8n-seg.pt"
637
+ image_path = "trial.jpg"
638
+
639
+ seg = SimpleDentalSegmentationNoEnhance(unet_model, yolo_model)
640
+ res = seg.analyze_image(image_path)
641
+
642
+ # Print distance analysis summary
643
+ seg.print_distance_summary(res)
644
+
645
+ fig = seg.visualize_results(res, save_path="segmentation_with_distances.png")
646
  plt.show()