# Deepdent-Periodontitistry / periodontitis_detection.py
# Last change: "Update periodontitis_detection.py" by jayn95 (commit f6e3545, verified)
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from ultralytics import YOLO
class SimpleDentalSegmentationNoEnhance:
    """
    CEJ/ABC landmark analysis for dental images.

    Pipeline: a Keras U-Net produces CEJ and ABC pixel masks, a YOLOv8 model
    detects tooth bounding boxes, and inside each box the uppermost ABC line
    and the lowermost CEJ line are traced and their vertical pixel distance
    is measured. Images are fed to the models without any enhancement step.
    """

    def __init__(self, unet_model_path, yolo_model_path, unet_input_size=(224, 224, 3)):
        """
        Load the U-Net and YOLO models.

        Args:
            unet_model_path: Path accepted by ``tf.keras.models.load_model``.
            yolo_model_path: Weights path accepted by ``ultralytics.YOLO``.
            unet_input_size: ``(height, width, channels)`` the U-Net input is
                resized to. ``channels`` of 1 selects grayscale input;
                any other value feeds RGB.
        """
        self.unet = tf.keras.models.load_model(unet_model_path)
        self.in_h, self.in_w, self.in_c = unet_input_size
        self.yolo = YOLO(yolo_model_path)
        print("Models loaded successfully.")
        print(f"Keras U-Net input shape: {self.unet.input_shape}")
        print(f"Keras U-Net output shape: {self.unet.output_shape}")
        print(f"YOLO model loaded: {yolo_model_path}")

    def preprocess_for_unet(self, image_bgr):
        """
        Prepare a BGR image for U-Net prediction.

        Converts to RGB (or to grayscale when ``in_c == 1``), resizes to
        ``(in_h, in_w)`` with bilinear interpolation and divides by 255.

        Args:
            image_bgr: 3-channel BGR image, e.g. from ``cv2.imread``.

        Returns:
            ``(input_tensor, img_resized)``: a float32 batch of one image with
            shape ``(1, in_h, in_w, C)``, and the resized image before scaling.
        """
        if self.in_c == 1:
            img = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2GRAY)
        else:
            img = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
        img_resized = cv2.resize(img, (self.in_w, self.in_h), interpolation=cv2.INTER_LINEAR)
        if img_resized.ndim == 2:
            # cv2 drops the channel axis for grayscale; Keras expects (H, W, 1).
            img_resized = img_resized[:, :, np.newaxis]
        img_norm = img_resized.astype(np.float32) / 255.0
        input_tensor = np.expand_dims(img_norm, axis=0)
        return input_tensor, img_resized

    def run_unet(self, image_bgr):
        """
        Run the U-Net and split its prediction into CEJ and ABC masks.

        How the prediction is decoded depends on its channel count:
          * >= 2 channels: per-pixel argmax; class 1 -> ABC, class 2 -> CEJ.
          * 1 channel: thresholded at 0.5 as the ABC mask; CEJ stays empty.
          * otherwise: both masks are empty.

        Returns:
            ``(cej, abc, model_resized_image)``: two uint8 0/1 masks at the
            U-Net output resolution, plus the resized network input image.
        """
        input_tensor, model_resized_image = self.preprocess_for_unet(image_bgr)
        out = self.unet.predict(input_tensor, verbose=0)[0]
        if out.ndim == 3 and out.shape[2] >= 2:
            # NOTE(review): assumes class 0 = background, 1 = ABC, 2 = CEJ.
            # With exactly 2 output channels argmax can never be 2, so the CEJ
            # mask would always be empty — confirm against the training labels.
            class_map = np.argmax(out, axis=2).astype(np.uint8)
            abc = (class_map == 1).astype(np.uint8)
            cej = (class_map == 2).astype(np.uint8)
        elif out.ndim == 3 and out.shape[2] == 1:
            abc = (out[:, :, 0] > 0.5).astype(np.uint8)
            cej = np.zeros_like(abc)
        else:
            h, w = out.shape[:2]
            abc = np.zeros((h, w), dtype=np.uint8)
            cej = np.zeros((h, w), dtype=np.uint8)
        return cej, abc, model_resized_image

    def detect_teeth(self, image_bgr):
        """
        Detect teeth with the YOLOv8 model.

        The BGR image is passed to Ultralytics unchanged. For NumPy inputs the
        library expects HWC BGR (the ``cv2.imread`` layout), so converting to
        RGB first would swap the colour channels the detector sees.

        Returns:
            List of dicts with ``"bbox"`` (float32 ``[x1, y1, x2, y2]``),
            ``"confidence"`` (float) and ``"tooth_id"`` (1-based, in detection
            order).
        """
        results = self.yolo(image_bgr)
        detections = []
        for result in results:
            boxes = getattr(result, "boxes", None)
            if boxes is None:
                continue
            all_xyxy = boxes.xyxy.cpu().numpy()
            all_conf = boxes.conf.cpu().numpy()
            for xyxy, conf in zip(all_xyxy, all_conf):
                detections.append({
                    "bbox": xyxy.astype(np.float32),
                    "confidence": float(conf),
                    "tooth_id": len(detections) + 1,
                })
        return detections

    def resize_mask_to_original(self, mask, original_shape):
        """
        Resize a 0/1 mask to ``original_shape`` = ``(height, width)``.

        Nearest-neighbour interpolation keeps the mask binary. Any non-zero
        input pixel is treated as foreground; the mask is not scaled by 255
        first, which could overflow uint8.

        Returns:
            uint8 0/1 mask of shape ``(height, width)``.
        """
        h_orig, w_orig = original_shape
        resized = cv2.resize(mask.astype(np.uint8), (w_orig, h_orig), interpolation=cv2.INTER_NEAREST)
        return (resized > 0).astype(np.uint8)

    @staticmethod
    def _boundary_points_within_bbox(mask, bbox, lowermost):
        """
        Collect one boundary point per column of ``mask`` inside ``bbox``.

        For every column ``x`` in ``[x1, x2]`` with at least one pixel equal
        to 1 in rows ``[y1, y2]``, keep the smallest such row
        (``lowermost=False``) or the largest (``lowermost=True``). The box is
        clamped to the image first.

        Returns:
            int32 array of shape ``(N, 1, 2)`` holding ``(x, y)`` points in
            increasing ``x`` (the layout ``cv2.polylines`` takes), or ``None``
            when fewer than two columns contain mask pixels.
        """
        x1, y1, x2, y2 = map(int, bbox)
        height, width = mask.shape
        x1, y1 = max(0, x1), max(0, y1)
        x2, y2 = min(width - 1, x2), min(height - 1, y2)
        # A box lying outside the image becomes inverted after clamping; stop
        # here so a negative y2 cannot wrap the slice to the bottom rows.
        if x1 > x2 or y1 > y2:
            return None
        region = mask[y1:y2 + 1, x1:x2 + 1] == 1
        has_pixel = region.any(axis=0)
        if np.count_nonzero(has_pixel) < 2:
            return None
        if lowermost:
            # argmax finds the first True; on the row-reversed region that is
            # the last True of the original column.
            rows = region.shape[0] - 1 - np.argmax(region[::-1], axis=0)
        else:
            rows = np.argmax(region, axis=0)
        xs = np.arange(x1, x2 + 1)[has_pixel]
        ys = y1 + rows[has_pixel]
        return np.stack([xs, ys], axis=1).astype(np.int32).reshape(-1, 1, 2)

    def extract_abc_uppermost_line_within_bbox(self, abc_mask, bbox):
        """Return the uppermost ABC pixel per column inside ``bbox`` as an (N, 1, 2) int32 polyline, or None."""
        return self._boundary_points_within_bbox(abc_mask, bbox, lowermost=False)

    def extract_cej_lowermost_line_within_bbox(self, cej_mask, bbox):
        """Return the lowermost CEJ pixel per column inside ``bbox`` as an (N, 1, 2) int32 polyline, or None."""
        return self._boundary_points_within_bbox(cej_mask, bbox, lowermost=True)

    def smooth_landmarks(self, points, window_size=5):
        """
        Smooth a polyline's y-coordinates with a centred moving average.

        Each y becomes the mean of the y-values within ``window_size // 2``
        positions on either side; the window is truncated at the ends. The
        x-coordinates are kept. Means are rounded to the nearest pixel;
        truncating them would bias the line toward smaller y.

        Args:
            points: array reshapeable to ``(N, 2)`` of ``(x, y)``, or None.
            window_size: nominal window length.

        Returns:
            int32 array of shape ``(N, 1, 2)``. ``points`` itself is returned
            when it is None, when ``window_size <= 1``, or when it has fewer
            than ``window_size`` points.
        """
        if points is None or window_size <= 1 or len(points) < window_size:
            return points
        pts = points.reshape(-1, 2)
        n = len(pts)
        half = window_size // 2
        idx = np.arange(n)
        starts = np.maximum(0, idx - half)
        ends = np.minimum(n, idx + half + 1)
        # Prefix sums give each window's sum in O(1); integer y-values are
        # summed exactly in float64.
        csum = np.concatenate(([0.0], np.cumsum(pts[:, 1], dtype=np.float64)))
        means = (csum[ends] - csum[starts]) / (ends - starts)
        smoothed = np.column_stack([pts[:, 0], np.rint(means)])
        return smoothed.astype(np.int32).reshape(-1, 1, 2)

    @staticmethod
    def _summarize_distances(distances, connections):
        """Package per-column measurements with their mean/max/min distance."""
        values = [d['distance'] for d in distances]
        return {
            'distances': distances,
            'connection_points': connections,
            'mean_distance': float(np.mean(values)),
            'max_distance': max(values),
            'min_distance': min(values),
        }

    def compute_cej_abc_distances(self, cej_points, abc_points):
        """
        Measure vertical CEJ-to-ABC distances at the x-columns both lines share.

        Falls back to :meth:`compute_distances_with_interpolation` when the
        lines share no exact x-coordinate.

        Args:
            cej_points, abc_points: arrays reshapeable to ``(N, 2)`` of
                ``(x, y)`` pixel points, or None.

        Returns:
            Dict with ``'distances'`` (list of ``{'x', 'cej_y', 'abc_y',
            'distance'}``), ``'connection_points'`` (list of
            ``[(x, cej_y), (x, abc_y)]`` with plain ints) and
            ``'mean_distance'`` / ``'max_distance'`` / ``'min_distance'``.
            Returns None when either input is None or the lines do not
            overlap in x.
        """
        if cej_points is None or abc_points is None:
            return None
        cej_2d, abc_2d = cej_points.reshape(-1, 2), abc_points.reshape(-1, 2)
        # Plain ints so results are OpenCV-point and serialization friendly.
        cej_by_x = {int(x): int(y) for x, y in cej_2d}
        abc_by_x = {int(x): int(y) for x, y in abc_2d}
        common_x = sorted(cej_by_x.keys() & abc_by_x.keys())
        if not common_x:
            return self.compute_distances_with_interpolation(cej_2d, abc_2d)
        distances, connections = [], []
        for x in common_x:
            cej_y, abc_y = cej_by_x[x], abc_by_x[x]
            distances.append({'x': x, 'cej_y': cej_y, 'abc_y': abc_y, 'distance': abs(abc_y - cej_y)})
            connections.append([(x, cej_y), (x, abc_y)])
        return self._summarize_distances(distances, connections)

    def compute_distances_with_interpolation(self, cej_points, abc_points):
        """
        Measure CEJ-to-ABC distances by linearly interpolating both lines.

        Samples at most 50 integer x-positions across the x-range the two
        lines share.

        Args:
            cej_points, abc_points: arrays reshapeable to ``(N, 2)`` of
                ``(x, y)`` points.

        Returns:
            Same structure as :meth:`compute_cej_abc_distances`, with float
            ``'distance'`` values. Returns None when the x-ranges do not
            overlap in more than one point.
        """
        cej = np.asarray(cej_points).reshape(-1, 2)
        abc = np.asarray(abc_points).reshape(-1, 2)
        # np.interp requires increasing sample positions.
        cej = cej[np.argsort(cej[:, 0], kind="stable")]
        abc = abc[np.argsort(abc[:, 0], kind="stable")]
        x_min = max(cej[:, 0].min(), abc[:, 0].min())
        x_max = min(cej[:, 0].max(), abc[:, 0].max())
        if x_min >= x_max:
            return None
        x_samples = np.linspace(x_min, x_max, min(50, int(x_max - x_min) + 1), dtype=int)
        cej_y = np.interp(x_samples, cej[:, 0], cej[:, 1])
        abc_y = np.interp(x_samples, abc[:, 0], abc[:, 1])
        distances, connections = [], []
        for x, cy, ay in zip(x_samples, cej_y, abc_y):
            distances.append({'x': int(x), 'cej_y': int(cy), 'abc_y': int(ay), 'distance': abs(ay - cy)})
            connections.append([(int(x), int(cy)), (int(x), int(ay))])
        return self._summarize_distances(distances, connections)

    def draw_distance_measurements(self, image, distance_analysis, tooth_id, step=3):
        """
        Draw colour-coded CEJ-ABC measurement segments on a copy of ``image``.

        Segment colour runs from green (smallest distance within this tooth)
        to red (largest), and thickness grows from 2 to 4 px with it.

        Args:
            image: BGR image to draw on (not modified).
            distance_analysis: result of :meth:`compute_cej_abc_distances`, or None.
            tooth_id: not used by the drawing; kept for API compatibility.
            step: draw every ``step``-th segment to reduce clutter.

        Returns:
            The annotated copy. If ``distance_analysis`` is None, ``image``
            itself is returned.
        """
        if distance_analysis is None:
            return image
        img = image.copy()
        connections = distance_analysis['connection_points']
        distances = [d['distance'] for d in distance_analysis['distances']]
        if not distances:
            return img
        min_d, max_d = min(distances), max(distances)
        dist_range = (max_d - min_d) or 1  # avoid dividing by zero when all equal
        stride = max(1, int(step))
        for (cej_pt, abc_pt), dist in list(zip(connections, distances))[::stride]:
            norm = (dist - min_d) / dist_range
            color = (0, int(255 * (1 - norm)), int(255 * norm))  # BGR: green -> red
            start = (int(cej_pt[0]), int(cej_pt[1]))
            end = (int(abc_pt[0]), int(abc_pt[1]))
            cv2.line(img, start, end, color, int(2 + 2 * norm))
        return img

    @staticmethod
    def _put_outlined_text(img, text, origin):
        """Draw yellow ``text`` over a thicker black copy so it stays legible on any background."""
        cv2.putText(img, text, origin, cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 3, cv2.LINE_AA)
        cv2.putText(img, text, origin, cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 1, cv2.LINE_AA)

    def analyze_image(self, image_path):
        """
        Run segmentation, tooth detection, distance measurement and drawing.

        Args:
            image_path: path readable by ``cv2.imread``.

        Returns:
            Dict with ``"original"`` (BGR image), ``"cej_mask"`` and
            ``"abc_mask"`` (uint8 0/1 at original resolution),
            ``"detections"`` (see :meth:`detect_teeth`), ``"combined"``
            (annotated BGR image) and ``"distance_analyses"`` (list of
            ``{"tooth_id", "analysis"}``; ``analysis`` may be None).

        Raises:
            FileNotFoundError: if the image cannot be read.
        """
        img_bgr = cv2.imread(image_path)
        if img_bgr is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        h_orig, w_orig = img_bgr.shape[:2]
        cej_unet, abc_unet, _ = self.run_unet(img_bgr)
        cej_mask = self.resize_mask_to_original(cej_unet, (h_orig, w_orig))
        abc_mask = self.resize_mask_to_original(abc_unet, (h_orig, w_orig))
        detections = self.detect_teeth(img_bgr)
        print(f"Detected {len(detections)} teeth")

        # Semi-transparent mask overlay: CEJ in red, ABC in blue (BGR).
        overlay = img_bgr.copy()
        overlay[cej_mask == 1] = (0, 0, 255)
        overlay[abc_mask == 1] = (255, 0, 0)
        alpha = 0.4
        combined = cv2.addWeighted(overlay, alpha, img_bgr, 1 - alpha, 0)

        all_results = []
        for det in detections:
            x1, y1, x2, y2 = (int(v) for v in det["bbox"])
            cv2.rectangle(combined, (x1, y1), (x2, y2), (0, 255, 0), 3)
            cv2.putText(
                combined,
                f"Tooth {det['tooth_id']}",
                (x1, max(y1 - 5, 15)),  # keep the label inside the top edge
                cv2.FONT_HERSHEY_SIMPLEX,
                0.5,
                (0, 255, 0),
                1,
                cv2.LINE_AA,
            )

            abc_line = self.extract_abc_uppermost_line_within_bbox(abc_mask, (x1, y1, x2, y2))
            cej_line = self.extract_cej_lowermost_line_within_bbox(cej_mask, (x1, y1, x2, y2))
            if abc_line is not None:
                abc_line = self.smooth_landmarks(abc_line)
                cv2.polylines(combined, [abc_line], False, (255, 0, 0), 2)
            if cej_line is not None:
                cej_line = self.smooth_landmarks(cej_line)
                cv2.polylines(combined, [cej_line], False, (0, 0, 255), 2)

            distance_analysis = None
            if cej_line is not None and abc_line is not None:
                distance_analysis = self.compute_cej_abc_distances(cej_line, abc_line)
            if distance_analysis:
                combined = self.draw_distance_measurements(combined, distance_analysis, det["tooth_id"])
                self._put_outlined_text(
                    combined, f"{distance_analysis['mean_distance']:.1f}px", (x1 + 5, y2 - 10)
                )
            all_results.append({
                "tooth_id": det["tooth_id"],
                "analysis": distance_analysis,
            })
        return {
            "original": img_bgr,
            "cej_mask": cej_mask,
            "abc_mask": abc_mask,
            "detections": detections,
            "combined": combined,
            "distance_analyses": all_results,
        }
if __name__ == "__main__":
    # Script entry point: analyse one image and display the annotated result.
    # Paths default to the previously hard-coded values, so running the script
    # with no arguments behaves as before.
    import argparse

    parser = argparse.ArgumentParser(description="Dental CEJ–ABC analysis on a single image.")
    parser.add_argument("image_path", nargs="?", default="trial.jpg", help="input image (default: trial.jpg)")
    parser.add_argument("--unet", default="unet.keras", help="Keras U-Net model path (default: unet.keras)")
    parser.add_argument("--yolo", default="best2.pt", help="YOLOv8 PyTorch weights path (default: best2.pt)")
    args = parser.parse_args()

    seg = SimpleDentalSegmentationNoEnhance(args.unet, args.yolo)
    res = seg.analyze_image(args.image_path)
    plt.figure(figsize=(12, 8))
    plt.imshow(cv2.cvtColor(res["combined"], cv2.COLOR_BGR2RGB))
    plt.title("Dental CEJ–ABC Analysis Result")
    plt.axis("off")
    plt.show()