Spaces:
Sleeping
Sleeping
| # app.py | |
| import gradio as gr | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image, ExifTags | |
| import os | |
| from periodontitis_detection import SimpleDentalSegmentationNoEnhance | |
# ==========================
# Load model
# ==========================
# Constructed once at import time; detect_periodontitis() reuses this single
# instance for every request.
_MODEL_PATHS = {
    "unet_model_path": "models/unet/best.keras",
    "yolo_model_path": "best2.pt",
}
model = SimpleDentalSegmentationNoEnhance(**_MODEL_PATHS)
# ====================================================
# 1. Read DPI from metadata (EXIF / PNG)
# ====================================================
def read_dpi(path):
    """Return the horizontal resolution (dots per inch) stored in an image file.

    Pillow's ``info["dpi"]`` is checked first. If it is absent, the EXIF
    ``XResolution`` tag is used, converted from dots-per-centimetre when the
    EXIF ``ResolutionUnit`` tag says the unit is centimetres.

    Args:
        path: Path of the image file to inspect.

    Returns:
        The X resolution in DPI as a positive finite float, or ``None`` when
        no usable value is present or the file cannot be read.
    """
    import math  # local: only this helper needs it

    x_resolution_tag = 0x011A    # EXIF XResolution (282)
    resolution_unit_tag = 0x0128  # EXIF ResolutionUnit (296): 2 = inch, 3 = cm

    try:
        # The context manager closes the file handle Image.open() keeps open.
        with Image.open(path) as img:
            raw = img.info.get("dpi")
            if raw is not None:
                # Usually an (x_dpi, y_dpi) pair; use the horizontal value.
                value = float(raw[0] if isinstance(raw, (tuple, list)) else raw)
            else:
                # getexif() is Pillow's public API; _getexif() is private and
                # plugin-specific.
                exif = img.getexif()
                x_res = exif.get(x_resolution_tag)
                if x_res is None:
                    return None
                if isinstance(x_res, tuple):
                    # Some Pillow paths expose rationals as (num, den) tuples.
                    num, den = x_res
                    if den == 0:
                        return None
                    value = float(num) / float(den)
                else:
                    value = float(x_res)
                if exif.get(resolution_unit_tag) == 3:
                    value *= 2.54  # dots per centimetre -> dots per inch
    except Exception:
        # Best effort: an unreadable file or malformed metadata means "no DPI".
        return None

    # Reject NaN (e.g. a 0/0 rational), infinities and non-positive values.
    if not math.isfinite(value) or value <= 0:
        return None
    return value
# ====================================================
# 2. Detect 1 mm tick spacing
# ====================================================
def detect_tick_mm(path):
    """Estimate the pixel spacing of 1 mm ruler tick marks in an image.

    The right-most 20 % of the image (where a ruler is assumed to be) is
    binarised with Otsu's threshold. Near-horizontal line segments are found
    on one direction of vertical intensity transition, so each tick yields one
    edge. Duplicate detections are merged into tick centres, and the median
    gap between consecutive centres is returned.

    Args:
        path: Path of the image file (read as 8-bit grayscale).

    Returns:
        Estimated pixels per millimetre (float, assuming ticks are 1 mm
        apart). Returns ``None`` when the image cannot be read, fewer than
        three ticks are found, or the gaps are too irregular to look like a
        ruler.
    """
    try:
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            return None
        h, w = img.shape
        # Right-side crop (where the ruler is assumed to be).
        crop = img[:, int(w * 0.80):]
        if crop.size == 0:
            return None

        blur = cv2.GaussianBlur(crop, (5, 5), 0)
        _, thr = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

        # Keep only transitions of ONE direction (darker above -> brighter
        # below). Canny would mark both the top and the bottom border of every
        # tick. Gaps would then alternate between tick thickness and
        # (spacing - thickness), and averaging them gives roughly half the
        # real spacing.
        grad = cv2.Sobel(thr, cv2.CV_16S, 0, 1, ksize=3)
        edges = np.where(grad > 0, 255, 0).astype(np.uint8)

        lines = cv2.HoughLinesP(edges, 1, np.pi / 180, threshold=30,
                                minLineLength=10, maxLineGap=5)
        if lines is None:
            return None

        # Row of every near-horizontal segment.
        ys = sorted(
            (y1 + y2) / 2.0
            for x1, y1, x2, y2 in lines[:, 0]
            if abs(y1 - y2) <= 3 and abs(x2 - x1) > 5
        )
        if len(ys) < 3:
            return None

        # Merge detections of the same physical tick (rows <= 2 px apart)
        # into one centre.
        ticks = []
        group = [ys[0]]
        for y in ys[1:]:
            if y - group[-1] <= 2:
                group.append(y)
            else:
                ticks.append(sum(group) / len(group))
                group = [y]
        ticks.append(sum(group) / len(group))
        if len(ticks) < 3:
            return None

        diffs = np.diff(ticks)
        diffs = diffs[diffs < h // 2]
        if diffs.size < 2:
            return None

        # The median is robust to a missed tick or a stray line; the mean is not.
        spacing = float(np.median(diffs))
        # A real ruler is evenly spaced: require most gaps to agree within 20 %.
        regular = np.abs(diffs - spacing) <= 0.2 * spacing
        if regular.sum() < max(2, 0.5 * diffs.size):
            return None
        return spacing
    except Exception:
        # Best effort: any OpenCV/numeric failure means "no ruler found".
        return None
# ====================================================
# 3. Compute mm per pixel
# ====================================================
def compute_mm_per_pixel(path, fallback_dpi=453.5714):
    """Return ``(mm_per_pixel, source)`` for the image at *path*.

    The sources are tried in order:
      A) DPI from the file's metadata (``read_dpi``)            -> "metadata"
      B) 1 mm ruler tick marks found in the image (``detect_tick_mm``)
                                                                -> "tickmarks"
      C) an assumed resolution of ``fallback_dpi``              -> "fallback"

    Args:
        path: Path of the image file.
        fallback_dpi: Resolution assumed when neither A nor B succeeds. The
            default of 453.5714 DPI corresponds to about 0.056 mm per pixel.

    Returns:
        A ``(float, str)`` tuple: millimetres per pixel and the source label.
    """
    # A) Metadata DPI.
    # NOTE(review): 72/96 DPI are common editor placeholders rather than
    # physical resolutions; consider rejecting implausible values -- confirm
    # against real sensor output.
    dpi = read_dpi(path)
    if dpi and dpi > 1:
        return 25.4 / dpi, "metadata"

    # B) Tick marks (assumed 1 mm apart).
    px_per_mm = detect_tick_mm(path)
    if px_per_mm and px_per_mm > 0:
        return 1.0 / px_per_mm, "tickmarks"

    # C) Fixed fallback resolution (default ~453.57 DPI, i.e. 0.056 mm/px).
    return 25.4 / fallback_dpi, "fallback"
# ==========================
# Wrapped function
# ==========================
def _summarize_measurements(distance_analyses, mm_per_px, threshold_mm=2.0):
    """Format per-tooth CEJ–ABC distances in mm and append an interpretation.

    Args:
        distance_analyses: ``results["distance_analyses"]`` from the model.
            Each entry is read for ``"tooth_id"`` and ``"analysis"``. A falsy
            analysis means no valid measurement. Otherwise its
            ``"mean_distance"`` is multiplied by ``mm_per_px``, so it is
            presumably in pixels.
        mm_per_px: Millimetres per pixel used for the conversion.
        threshold_mm: Any tooth whose distance exceeds this flags the result.

    Returns:
        The multi-line summary text shown in the UI.
    """
    lines = []
    measured = 0
    flagged = False
    for tooth in distance_analyses:
        tid = tooth["tooth_id"]
        analysis = tooth["analysis"]
        if analysis:
            mm = analysis["mean_distance"] * mm_per_px
            lines.append(f"Tooth {tid}: {mm:.2f} mm")
            measured += 1
            if mm > threshold_mm:
                flagged = True
        else:
            lines.append(f"Tooth {tid}: no valid CEJ–ABC measurement")

    summary_text = "\n".join(lines)
    if flagged:
        summary_text += "\n\n⚠️ You have periodontitis."
    elif measured:
        summary_text += "\n\n✅ You don't have periodontitis."
    else:
        # Nothing was measured, so a negative result would not be supported
        # by any data.
        summary_text += (
            "\n\n❔ No valid CEJ–ABC measurement was found; "
            "periodontitis could not be assessed."
        )
    return summary_text


def detect_periodontitis(image_np):
    """Run the periodontitis pipeline on one uploaded X-ray.

    Args:
        image_np: Array from the ``gr.Image(type="numpy")`` input, treated as
            RGB (a 2-D array is treated as grayscale). ``None`` when the user
            submits without uploading an image.

    Returns:
        ``(combined_rgb, summary_text)``: the model's ``"combined"`` image
        converted from BGR to RGB, and the per-tooth summary text.

    Raises:
        gr.Error: if no image was supplied or it could not be written to disk.
    """
    import tempfile  # local: only this function needs it

    if image_np is None:
        raise gr.Error("Please upload a dental X-ray image.")

    if image_np.ndim == 2:
        bgr = cv2.cvtColor(image_np, cv2.COLOR_GRAY2BGR)
    else:
        bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)

    # Unique temp file per call. A fixed name would let concurrent requests
    # overwrite each other's input.
    fd, temp_path = tempfile.mkstemp(suffix=".jpg")
    os.close(fd)
    try:
        if not cv2.imwrite(temp_path, bgr):
            raise gr.Error("Could not prepare the uploaded image for analysis.")
        # Run periodontitis detection.
        results = model.analyze_image(temp_path)
        # Compute mm scaling from the same file before it is deleted.
        mm_per_px, _method = compute_mm_per_pixel(temp_path)
    finally:
        # Always clean up, even if the model raised.
        try:
            os.remove(temp_path)
        except OSError:
            pass

    # Convert combined BGR -> RGB for display.
    combined_rgb = cv2.cvtColor(results["combined"], cv2.COLOR_BGR2RGB)
    summary_text = _summarize_measurements(results["distance_analyses"], mm_per_px)
    return combined_rgb, summary_text
# ==========================
# Gradio Interface
# ==========================
# Single X-ray upload in; annotated image plus text summary out.
_xray_input = gr.Image(type="numpy", label="Upload Dental X-Ray")
_result_outputs = [
    gr.Image(label="Final Annotated Image (YOLO + CEJ–ABC)"),
    gr.Textbox(label="Analysis Summary (mm)"),
]

demo = gr.Interface(
    fn=detect_periodontitis,
    inputs=_xray_input,
    outputs=_result_outputs,
    title="🦷 Periodontitis Detection & Analysis (mm accurate)",
    description="Outputs CEJ–ABC distances in millimeters.",
)


def _serve():
    """Launch the app on 0.0.0.0:7860 with show_error=True."""
    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)


if __name__ == "__main__":
    _serve()