Spaces:
Sleeping
Sleeping
| import os | |
| import cv2 | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import tensorflow as tf | |
| from ultralytics import YOLO | |
class SimpleDentalSegmentationNoEnhance:
    """Dental image pipeline: U-Net landmark segmentation + YOLO tooth detection.

    The U-Net yields two per-pixel masks, CEJ (presumably cemento-enamel
    junction) and ABC (presumably alveolar bone crest). YOLO yields one
    bounding box per tooth. Inside each box, the uppermost ABC boundary and
    the lowermost CEJ boundary are traced column by column. Their vertical
    pixel distance is then measured and drawn onto the image.
    """

    def __init__(self, unet_model_path, yolo_model_path, unet_input_size=(224, 224, 3)):
        """Load the segmentation and detection models.

        Args:
            unet_model_path: Path passed to ``tf.keras.models.load_model``.
            yolo_model_path: Path passed to ``ultralytics.YOLO`` (e.g. a ``.pt`` file).
            unet_input_size: ``(height, width, channels)`` fed to the U-Net.
                Images are resized to ``(height, width)``. ``channels == 1``
                feeds grayscale; any other value feeds RGB.
        """
        self.unet = tf.keras.models.load_model(unet_model_path)
        self.in_h, self.in_w, self.in_c = unet_input_size
        self.yolo = YOLO(yolo_model_path)
        print(f"✅ YOLO model loaded from: {yolo_model_path}")
        print("Models loaded successfully.")
        print(f"Keras U-Net input shape: {self.unet.input_shape}")
        print(f"Keras U-Net output shape: {self.unet.output_shape}")

    def preprocess_for_unet(self, image_bgr):
        """Turn an OpenCV image into a normalized single-image U-Net batch.

        Args:
            image_bgr: BGR image (H, W, 3), or a 2-D grayscale image.

        Returns:
            ``(input_tensor, img_resized)``. ``input_tensor`` is float32 with
            shape ``(1, in_h, in_w, C)``, divided by 255 (i.e. [0, 1] for uint8
            input). ``img_resized`` is the resized image before normalization.
        """
        if self.in_c == 1:
            # Single-channel model: feed grayscale rather than an RGB stack.
            img = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2GRAY) if image_bgr.ndim == 3 else image_bgr
        elif image_bgr.ndim == 2:
            img = cv2.cvtColor(image_bgr, cv2.COLOR_GRAY2RGB)
        else:
            img = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
        img_resized = cv2.resize(img, (self.in_w, self.in_h), interpolation=cv2.INTER_LINEAR)
        img_norm = img_resized.astype(np.float32) / 255.0
        if img_norm.ndim == 2:
            # Restore the channel axis that cv2 drops for single-channel images.
            img_norm = img_norm[:, :, np.newaxis]
        input_tensor = np.expand_dims(img_norm, axis=0)
        return input_tensor, img_resized

    def run_unet(self, image_bgr):
        """Predict CEJ and ABC masks at U-Net resolution.

        Returns:
            ``(cej, abc, model_resized_image)``. ``cej`` and ``abc`` are uint8
            0/1 masks.
        """
        input_tensor, model_resized_image = self.preprocess_for_unet(image_bgr)
        out = self.unet.predict(input_tensor, verbose=0)[0]
        if out.ndim == 2:
            # A model with a squeezed channel axis: treat it as one sigmoid map.
            out = out[:, :, np.newaxis]
        if out.ndim == 3 and out.shape[2] >= 2:
            # Multi-class softmax: label 1 -> ABC, label 2 -> CEJ (label 0 is
            # presumably background - TODO confirm against training labels).
            # With exactly 2 channels the CEJ mask stays empty.
            class_map = np.argmax(out, axis=2).astype(np.uint8)
            abc = (class_map == 1).astype(np.uint8)
            cej = (class_map == 2).astype(np.uint8)
        elif out.ndim == 3 and out.shape[2] == 1:
            # Single binary map: interpreted as ABC only.
            abc = (out[:, :, 0] > 0.5).astype(np.uint8)
            cej = np.zeros_like(abc)
        else:
            # Unrecognized output layout: return empty masks.
            abc = np.zeros((self.in_h, self.in_w), dtype=np.uint8)
            cej = np.zeros_like(abc)
        return cej, abc, model_resized_image

    def detect_teeth(self, image_bgr):
        """Detect teeth with the YOLO model.

        Returns:
            List of dicts with keys ``bbox`` (float32 ``[x1, y1, x2, y2]``),
            ``confidence`` (float) and ``tooth_id`` (1-based detection index).
        """
        # Ultralytics treats numpy inputs as BGR (OpenCV order) and converts
        # to RGB itself, so the frame is passed through unchanged.
        results = self.yolo(image_bgr)
        detections = []
        for result in results:
            boxes = getattr(result, "boxes", None)
            if boxes is None:
                continue
            for box in boxes:
                xyxy = box.xyxy[0].cpu().numpy().astype(np.float32)
                conf = float(box.conf[0].cpu().numpy())
                detections.append({
                    "bbox": xyxy,
                    "confidence": conf,
                    "tooth_id": len(detections) + 1,
                })
        return detections

    def resize_mask_to_original(self, mask, original_shape):
        """Resize a 0/1 mask to ``original_shape`` = (height, width).

        Uses nearest-neighbour interpolation and returns a uint8 0/1 mask.
        """
        h_orig, w_orig = original_shape
        mask_resized = cv2.resize(mask.astype(np.uint8) * 255, (w_orig, h_orig), interpolation=cv2.INTER_NEAREST)
        return (mask_resized > 127).astype(np.uint8)

    def _extract_boundary_within_bbox(self, mask, bbox, uppermost):
        """Trace the top or bottom edge of a 2-D ``mask`` inside ``bbox``.

        The box is clipped to the image. For every column holding at least
        one pixel equal to 1, keep the first row (``uppermost=True``) or the
        last row (``uppermost=False``).

        Returns:
            int32 polyline of shape ``(N, 1, 2)`` with (x, y) points ordered
            by x, or ``None`` if fewer than two columns contain mask pixels
            or the clipped box is empty.
        """
        x1, y1, x2, y2 = (int(v) for v in bbox)
        height, width = mask.shape
        x1, y1 = max(0, x1), max(0, y1)
        x2, y2 = min(width - 1, x2), min(height - 1, y2)
        # A box fully outside the image would otherwise yield negative slice
        # bounds, which Python reinterprets as offsets from the far edge.
        if x2 < x1 or y2 < y1:
            return None
        roi = mask[y1:y2 + 1, x1:x2 + 1] == 1
        has_pixel = roi.any(axis=0)
        if np.count_nonzero(has_pixel) < 2:
            return None
        if uppermost:
            rows = roi.argmax(axis=0)  # argmax returns the first True row
        else:
            # First True in the flipped ROI is the last True in the original.
            rows = roi.shape[0] - 1 - roi[::-1].argmax(axis=0)
        xs = np.arange(x1, x2 + 1)[has_pixel]
        ys = y1 + rows[has_pixel]
        return np.stack([xs, ys], axis=1).astype(np.int32).reshape(-1, 1, 2)

    def extract_abc_uppermost_line_within_bbox(self, abc_mask, bbox):
        """Return the uppermost ABC pixel per column inside ``bbox`` as an (N, 1, 2) int32 polyline, or None."""
        return self._extract_boundary_within_bbox(abc_mask, bbox, uppermost=True)

    def extract_cej_lowermost_line_within_bbox(self, cej_mask, bbox):
        """Return the lowermost CEJ pixel per column inside ``bbox`` as an (N, 1, 2) int32 polyline, or None."""
        return self._extract_boundary_within_bbox(cej_mask, bbox, uppermost=False)

    def smooth_landmarks(self, points, window_size=5):
        """Smooth the y-coordinates of a polyline with a centred moving average.

        x-coordinates are kept. Averaged y values are rounded to the nearest
        integer rather than truncated, which would bias lines toward y = 0.
        Inputs that are ``None`` or shorter than ``window_size`` are returned
        unchanged.
        """
        if points is None or len(points) < window_size:
            return points
        pts = points.reshape(-1, 2)
        n = len(pts)
        half = window_size // 2
        # The window shrinks at both ends instead of padding.
        mean_y = np.array([pts[max(0, i - half):min(n, i + half + 1), 1].mean() for i in range(n)])
        smoothed = np.column_stack([pts[:, 0], np.rint(mean_y)])
        return smoothed.astype(np.int32).reshape(-1, 1, 2)

    @staticmethod
    def _summarize_distances(distances, connections):
        """Bundle per-column distance records with their mean/max/min."""
        values = [d["distance"] for d in distances]
        return {
            "distances": distances,
            "connection_points": connections,
            "mean_distance": np.mean(values),
            "max_distance": np.max(values),
            "min_distance": np.min(values),
        }

    def compute_cej_abc_distances(self, cej_points, abc_points):
        """Measure vertical CEJ-ABC distances at the x positions both lines share.

        If the lines share no x positions, fall back to
        :meth:`compute_distances_with_interpolation`.

        Returns:
            dict with ``distances`` (records with x, cej_y, abc_y, distance),
            ``connection_points`` (pairs of (x, y) endpoints) and
            ``mean_distance``/``max_distance``/``min_distance``; or ``None``
            if either line is ``None`` or no overlap exists.
        """
        if cej_points is None or abc_points is None:
            return None
        cej_2d, abc_2d = cej_points.reshape(-1, 2), abc_points.reshape(-1, 2)
        # Plain ints keep the records JSON/OpenCV friendly.
        cej_by_x = {int(x): int(y) for x, y in cej_2d}
        abc_by_x = {int(x): int(y) for x, y in abc_2d}
        common_x = sorted(cej_by_x.keys() & abc_by_x.keys())
        if not common_x:
            return self.compute_distances_with_interpolation(cej_2d, abc_2d)
        distances, connections = [], []
        for x in common_x:
            cej_y, abc_y = cej_by_x[x], abc_by_x[x]
            distances.append({'x': x, 'cej_y': cej_y, 'abc_y': abc_y, 'distance': abs(abc_y - cej_y)})
            connections.append([(x, cej_y), (x, abc_y)])
        return self._summarize_distances(distances, connections)

    def compute_distances_with_interpolation(self, cej_points, abc_points):
        """Measure CEJ-ABC distances by linear interpolation on the shared x-range.

        Accepts (N, 2) or (N, 1, 2) point arrays in any x order. Samples at
        most 50 integer x positions across the overlapping x-range.

        Returns:
            Same structure as :meth:`compute_cej_abc_distances` (``distance``
            is a float here), or ``None`` if either input is empty or the
            x-ranges overlap in fewer than two positions.
        """
        cej = np.asarray(cej_points).reshape(-1, 2)
        abc = np.asarray(abc_points).reshape(-1, 2)
        if len(cej) == 0 or len(abc) == 0:
            return None
        # np.interp requires increasing sample positions.
        cej = cej[np.argsort(cej[:, 0], kind="stable")]
        abc = abc[np.argsort(abc[:, 0], kind="stable")]
        x_min = max(cej[0, 0], abc[0, 0])
        x_max = min(cej[-1, 0], abc[-1, 0])
        if x_min >= x_max:
            return None
        n_samples = min(50, int(x_max - x_min) + 1)
        x_samples = np.linspace(x_min, x_max, n_samples, dtype=int)
        cej_y = np.interp(x_samples, cej[:, 0], cej[:, 1])
        abc_y = np.interp(x_samples, abc[:, 0], abc[:, 1])
        distances, connections = [], []
        for x, cy, ay in zip(x_samples, cej_y, abc_y):
            distances.append({'x': int(x), 'cej_y': int(cy), 'abc_y': int(ay), 'distance': abs(ay - cy)})
            connections.append([(int(x), int(cy)), (int(x), int(ay))])
        return self._summarize_distances(distances, connections)

    def draw_distance_measurements(self, image, distance_analysis, tooth_id):
        """Draw every third CEJ-ABC connection, coloured green (short) to red (long).

        ``tooth_id`` is accepted for API compatibility but is not used.
        Returns a copy of ``image`` with the lines drawn. If
        ``distance_analysis`` is ``None``, returns ``image`` itself.
        """
        if distance_analysis is None:
            return image
        img = image.copy()
        connections = distance_analysis['connection_points']
        distances = [d['distance'] for d in distance_analysis['distances']]
        if not distances:
            return img
        min_d, max_d = min(distances), max(distances)
        dist_range = max_d - min_d if max_d != min_d else 1
        # Only every 3rd connection is drawn, to reduce clutter.
        for (p1, p2), dist in zip(connections[::3], distances[::3]):
            norm = (dist - min_d) / dist_range
            color = (0, int(255 * (1 - norm)), int(255 * norm))  # BGR: green -> red
            thickness = max(1, int(2 + 2 * norm))
            # OpenCV point arguments must be plain ints.
            cv2.line(img, tuple(int(v) for v in p1), tuple(int(v) for v in p2), color, thickness)
        return img

    @staticmethod
    def _put_outlined_text(img, text, org, scale, color):
        """Draw ``text`` in ``color`` over a thicker black copy for contrast."""
        cv2.putText(img, text, org, cv2.FONT_HERSHEY_SIMPLEX, scale, (0, 0, 0), 3, cv2.LINE_AA)
        cv2.putText(img, text, org, cv2.FONT_HERSHEY_SIMPLEX, scale, color, 1, cv2.LINE_AA)

    def analyze_image(self, image_path):
        """Run segmentation, detection, distance measurement and drawing on one image.

        Args:
            image_path: Path readable by ``cv2.imread``.

        Returns:
            dict with ``original`` (BGR image), ``cej_mask``/``abc_mask``
            (uint8 0/1 masks at original resolution), ``detections``,
            ``combined`` (annotated BGR image) and ``distance_analyses``
            (``{"tooth_id", "analysis"}`` entries for teeth where a distance
            could be measured).

        Raises:
            FileNotFoundError: if ``cv2.imread`` cannot read ``image_path``.
        """
        img_bgr = cv2.imread(image_path)
        if img_bgr is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        h_orig, w_orig = img_bgr.shape[:2]

        cej_unet, abc_unet, _ = self.run_unet(img_bgr)
        cej_mask = self.resize_mask_to_original(cej_unet, (h_orig, w_orig))
        abc_mask = self.resize_mask_to_original(abc_unet, (h_orig, w_orig))

        detections = self.detect_teeth(img_bgr)
        print(f"Detected {len(detections)} teeth")

        # Colour mask pixels (CEJ red, ABC blue; ABC wins where both are set),
        # then blend at 40% so the underlying image stays visible.
        overlay = img_bgr.copy()
        overlay[cej_mask == 1] = (0, 0, 255)
        overlay[abc_mask == 1] = (255, 0, 0)
        alpha = 0.4
        combined = cv2.addWeighted(overlay, alpha, img_bgr, 1 - alpha, 0)

        all_results = []
        for det in detections:
            # OpenCV drawing calls want plain Python ints (truncated like astype(int)).
            x1, y1, x2, y2 = (int(v) for v in det["bbox"])
            cv2.rectangle(combined, (x1, y1), (x2, y2), (0, 255, 0), 3)
            cv2.putText(
                combined,
                f"Tooth {det['tooth_id']}",
                (x1, max(y1 - 5, 15)),  # keep the label below the top edge
                cv2.FONT_HERSHEY_SIMPLEX,
                0.5,
                (0, 255, 0),
                1,
                cv2.LINE_AA,
            )

            bbox = (x1, y1, x2, y2)
            abc_line = self.extract_abc_uppermost_line_within_bbox(abc_mask, bbox)
            cej_line = self.extract_cej_lowermost_line_within_bbox(cej_mask, bbox)
            if abc_line is not None:
                abc_line = self.smooth_landmarks(abc_line)
                cv2.polylines(combined, [abc_line], False, (255, 0, 0), 2)
            if cej_line is not None:
                cej_line = self.smooth_landmarks(cej_line)
                cv2.polylines(combined, [cej_line], False, (0, 0, 255), 2)

            if cej_line is None or abc_line is None:
                continue
            distance_analysis = self.compute_cej_abc_distances(cej_line, abc_line)
            if not distance_analysis:
                continue

            combined = self.draw_distance_measurements(combined, distance_analysis, det["tooth_id"])
            # Mean distance label in yellow with a black outline.
            self._put_outlined_text(
                combined,
                f"{distance_analysis['mean_distance']:.1f}px",
                (x1 + 5, y2 - 10),
                0.6,
                (0, 255, 255),
            )
            all_results.append({
                "tooth_id": det["tooth_id"],
                "analysis": distance_analysis,
            })

        return {
            "original": img_bgr,
            "cej_mask": cej_mask,
            "abc_mask": abc_mask,
            "detections": detections,
            "combined": combined,
            "distance_analyses": all_results,
        }
if __name__ == "__main__":
    # Demo run: analyze one sample image and display the annotated result.
    # Model files: Keras U-Net (segmentation) and YOLOv8 PyTorch weights (detection).
    unet_weights, yolo_weights, sample_image = "unet.keras", "best2.pt", "trial.jpg"

    pipeline = SimpleDentalSegmentationNoEnhance(unet_weights, yolo_weights)
    outcome = pipeline.analyze_image(sample_image)

    # OpenCV works in BGR while matplotlib expects RGB.
    rgb_view = cv2.cvtColor(outcome["combined"], cv2.COLOR_BGR2RGB)
    plt.figure(figsize=(12, 8))
    plt.imshow(rgb_view)
    plt.title("Dental CEJ–ABC Analysis Result")
    plt.axis("off")
    plt.show()