# app.py
"""Gradio front-end that runs periodontitis detection on uploaded dental X-rays."""

import os

import cv2
import gradio as gr
import numpy as np
from PIL import ExifTags, Image

from periodontitis_detection import SimpleDentalSegmentationNoEnhance

# ==========================
# Load model
# ==========================
# Built once at import time so every request reuses the same instance.
model = SimpleDentalSegmentationNoEnhance(
    unet_model_path="models/unet/best.keras",
    yolo_model_path="best2.pt",
)


# ====================================================
# 1. Read DPI from metadata (EXIF / PNG)
# ====================================================
# EXIF / TIFF tag ids and ResolutionUnit values (TIFF 6.0 / EXIF specs).
_EXIF_X_RESOLUTION = 282
_EXIF_RESOLUTION_UNIT = 296
_RESOLUTION_UNIT_INCH = 2  # spec default when the tag is absent
_RESOLUTION_UNIT_CM = 3


def _exif_x_dpi(exif):
    """Return the horizontal resolution in dots per inch from an EXIF mapping.

    Args:
        exif: Mapping of EXIF tag id -> value (e.g. ``Image.getexif()``).

    Returns:
        DPI as a float, or ``None`` when XResolution is missing, has a zero
        denominator, or its ResolutionUnit is not a physical unit.
    """
    value = exif.get(_EXIF_X_RESOLUTION)
    if value is None:
        return None
    if isinstance(value, tuple):  # (numerator, denominator) rational form
        if value[1] == 0:
            return None
        resolution = float(value[0]) / float(value[1])
    else:
        resolution = float(value)  # int or IFDRational

    unit = exif.get(_EXIF_RESOLUTION_UNIT, _RESOLUTION_UNIT_INCH)
    if unit == _RESOLUTION_UNIT_CM:
        return resolution * 2.54  # pixels per cm -> pixels per inch
    if unit == _RESOLUTION_UNIT_INCH:
        return resolution
    # Unit 1 ("no absolute unit") only encodes an aspect ratio, not a density.
    return None


def read_dpi(path):
    """Read the horizontal DPI stored in an image file's metadata.

    PIL's normalised ``info["dpi"]`` is preferred; otherwise the EXIF
    XResolution/ResolutionUnit pair is used.

    Args:
        path: Path of an image file PIL can open.

    Returns:
        DPI as a float, or ``None`` when the file cannot be read or carries no
        usable, finite resolution metadata.
    """
    try:
        with Image.open(path) as img:  # closes the file handle promptly
            if "dpi" in img.info:
                d = img.info["dpi"]
                dpi = float(d[0]) if isinstance(d, (tuple, list)) else float(d)
            else:
                dpi = _exif_x_dpi(img.getexif())
    except Exception:  # best effort: unreadable/odd files simply have no DPI
        return None

    if dpi is None or not np.isfinite(dpi):
        return None
    return dpi
# ====================================================
# 2. Detect 1 mm tick spacing
# ====================================================
def detect_tick_mm(path):
    """Estimate pixels per millimetre from ruler tick marks in the image.

    Heuristic: near-horizontal line segments are searched for in the right-most
    20 % of the grayscale image (where a ruler is assumed to sit), and the
    median vertical gap between successive segments is taken as the spacing of
    1 mm ticks. The median is used so a single missed or spurious tick does not
    skew the estimate the way a mean would.

    Args:
        path: Path of an image readable by ``cv2.imread``.

    Returns:
        Pixels per millimetre as a float, or ``None`` when the image cannot be
        read or too few tick-like segments are found.
    """
    try:
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            return None
        h, w = img.shape

        # Right-side crop (where the ruler usually is).
        crop = img[:, int(w * 0.80):]

        # Binarise, then find edges of the tick marks.
        blur = cv2.GaussianBlur(crop, (5, 5), 0)
        _, thr = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        edges = cv2.Canny(thr, 50, 150)
        lines = cv2.HoughLinesP(
            edges, 1, np.pi / 180, threshold=30, minLineLength=10, maxLineGap=5
        )
    except cv2.error:
        return None

    if lines is None:
        return None

    # HoughLinesP yields rows of (x1, y1, x2, y2); keep near-horizontal ones.
    segs = lines[:, 0, :]
    horizontal = (np.abs(segs[:, 1] - segs[:, 3]) <= 3) & (
        np.abs(segs[:, 2] - segs[:, 0]) > 5
    )
    ys = np.sort(segs[horizontal, 1])
    if ys.size < 3:
        return None

    # Drop tiny gaps (both edges of one tick) and implausibly large ones.
    diffs = np.diff(ys)
    diffs = diffs[(diffs > 2) & (diffs < h // 2)]
    if diffs.size == 0:
        return None
    return float(np.median(diffs))


# ====================================================
# 3. Compute mm per pixel
# ====================================================
# Scale used when neither metadata nor tick marks are usable:
# 25.4 / 453.5714 ~= 0.056 mm per pixel.
# NOTE(review): the origin of this value (presumably the acquisition sensor's
# pixel pitch) is not documented here — confirm.
FALLBACK_DPI = 453.5714


def compute_mm_per_pixel(path):
    """Return ``(mm_per_pixel, method)`` for the image at *path*.

    Sources are tried in order: metadata DPI ("metadata"), detected 1 mm
    tick marks ("tickmarks"), then the fixed ``FALLBACK_DPI`` ("fallback").
    """
    # A) Metadata DPI
    dpi = read_dpi(path)
    if dpi and dpi > 1:
        return 25.4 / dpi, "metadata"

    # B) Tick marks (1 mm apart)
    tick = detect_tick_mm(path)
    if tick and tick > 0:
        return 1.0 / tick, "tickmarks"

    # C) Fallback fixed DPI
    return 25.4 / FALLBACK_DPI, "fallback"


# ==========================
# Wrapped function
# ==========================
# CEJ–ABC mean distance above which a tooth is flagged.
PERIODONTITIS_THRESHOLD_MM = 2.0


def _summarize_distances(distance_analyses, mm_per_px):
    """Build the per-tooth summary text followed by an overall interpretation.

    Args:
        distance_analyses: Iterable of dicts with "tooth_id" and "analysis"
            keys; a truthy "analysis" must contain "mean_distance" in pixels.
        mm_per_px: Millimetres per pixel used to convert distances.

    Returns:
        Multi-line summary string (no calibration-method labels).
    """
    rows = []
    measured_any = False
    has_periodontitis = False

    for tooth in distance_analyses:
        tid = tooth["tooth_id"]
        analysis = tooth["analysis"]
        if not analysis:
            rows.append(f"Tooth {tid}: no valid CEJ–ABC measurement")
            continue
        mm = analysis["mean_distance"] * mm_per_px
        rows.append(f"Tooth {tid}: {mm:.2f} mm")
        measured_any = True
        if mm > PERIODONTITIS_THRESHOLD_MM:
            has_periodontitis = True

    summary_text = "\n".join(rows)
    if has_periodontitis:
        summary_text += "\n\n⚠️ You have periodontitis."
    elif measured_any:
        summary_text += "\n\n✅ You don't have periodontitis."
    else:
        # With zero valid measurements there is no evidence either way.
        summary_text += (
            "\n\n❓ No valid CEJ–ABC measurement was found, "
            "so periodontitis could not be assessed."
        )
    return summary_text


def detect_periodontitis(image_np):
    """Run the model on an uploaded image and summarise CEJ–ABC distances.

    Args:
        image_np: RGB (or 2-D grayscale) image array from ``gr.Image``, or
            ``None`` when nothing was uploaded.

    Returns:
        ``(combined_rgb, summary_text)``: the model's combined annotated image
        converted to RGB, and the per-tooth summary in millimetres.

    Raises:
        gr.Error: if no image was provided or the temp file cannot be written.
    """
    import tempfile

    if image_np is None:
        raise gr.Error("Please upload a dental X-ray image.")

    # COLOR_RGB2BGR needs 3/4 channels; a 2-D grayscale array needs GRAY2BGR.
    if image_np.ndim == 2:
        bgr = cv2.cvtColor(image_np, cv2.COLOR_GRAY2BGR)
    else:
        bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)

    # Unique temp file per request so concurrent requests never clobber each
    # other; removed in `finally` even if analysis raises.
    # NOTE(review): the array carries no DPI metadata, so the "metadata" branch
    # of compute_mm_per_pixel presumably never fires for this temp JPEG — confirm.
    fd, temp_path = tempfile.mkstemp(suffix=".jpg")
    os.close(fd)
    try:
        if not cv2.imwrite(temp_path, bgr):
            raise gr.Error("Could not prepare the uploaded image for analysis.")

        # Run periodontitis detection
        results = model.analyze_image(temp_path)

        # Compute mm scaling (method label intentionally not shown to users)
        mm_per_px, _method = compute_mm_per_pixel(temp_path)
    finally:
        try:
            os.remove(temp_path)
        except OSError:
            pass  # best-effort cleanup of the temp file

    # Convert combined BGR → RGB for display
    combined_rgb = cv2.cvtColor(results["combined"], cv2.COLOR_BGR2RGB)
    summary_text = _summarize_distances(results["distance_analyses"], mm_per_px)
    return combined_rgb, summary_text


# ==========================
# Gradio Interface
# ==========================
demo = gr.Interface(
    fn=detect_periodontitis,
    inputs=gr.Image(type="numpy", label="Upload Dental X-Ray"),
    outputs=[
        gr.Image(label="Final Annotated Image (YOLO + CEJ–ABC)"),
        gr.Textbox(label="Analysis Summary (mm)"),
    ],
    title="🦷 Periodontitis Detection & Analysis (mm accurate)",
    description="Outputs CEJ–ABC distances in millimeters.",
)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)