jayn95's picture
Update app.py
a43bc51 verified
# app.py
import os
import tempfile

import cv2
import gradio as gr
import numpy as np
from PIL import Image, ExifTags

from periodontitis_detection import SimpleDentalSegmentationNoEnhance
# ==========================
# Load model
# ==========================
# Weight files for the two stages of the pipeline: the Keras U-Net
# segmentation model and the YOLO detector (paths are relative to the
# working directory the app is launched from).
_UNET_WEIGHTS = "models/unet/best.keras"
_YOLO_WEIGHTS = "best2.pt"

model = SimpleDentalSegmentationNoEnhance(
    unet_model_path=_UNET_WEIGHTS,
    yolo_model_path=_YOLO_WEIGHTS,
)
# ====================================================
# 1. Read DPI from metadata (EXIF / PNG)
# ====================================================
def read_dpi(path):
    """Return the horizontal resolution (DPI) recorded in an image file.

    Checks Pillow's normalised ``info["dpi"]`` field first, then the EXIF
    ``XResolution`` tag.

    Args:
        path: Filesystem path of the image.

    Returns:
        The X resolution as a float, or ``None`` when the file cannot be
        opened/parsed or carries no resolution metadata. The value is not
        validated here; malformed files may yield 0 or NaN, so callers must
        sanity-check it.
    """
    try:
        # The context manager releases the file handle; a bare Image.open()
        # keeps it open until the object happens to be garbage-collected.
        with Image.open(path) as img:
            dpi = img.info.get("dpi")
            if dpi is not None:
                # Pillow normally reports an (x_dpi, y_dpi) pair.
                if isinstance(dpi, (tuple, list)):
                    return float(dpi[0])
                return float(dpi)

            # EXIF resolution (rare on X-rays). getexif() is the public,
            # format-independent API; _getexif() is private and JPEG-only.
            for tag_id, value in img.getexif().items():
                if ExifTags.TAGS.get(tag_id, tag_id) != "XResolution":
                    continue
                if isinstance(value, tuple):
                    # Legacy (numerator, denominator) rational.
                    if value[1] == 0:
                        return None
                    return float(value[0]) / float(value[1])
                return float(value)
    except Exception:
        # Metadata is optional: any IO/parse failure means "unknown DPI".
        # Unlike a bare except, this lets KeyboardInterrupt/SystemExit through.
        pass
    return None
# ====================================================
# 2. Detect 1 mm tick spacing
# ====================================================
def detect_tick_mm(path):
    """Estimate the spacing, in pixels, of 1 mm ruler tick marks in an image.

    Searches the right-most 20 % of the image (where a ruler is assumed to
    be) for short near-horizontal line segments and takes the typical
    vertical gap between consecutive segments as the 1 mm spacing.

    Args:
        path: Filesystem path of the image.

    Returns:
        Pixels per millimetre as a float, or ``None`` when the image cannot
        be read or too few tick-like segments are found.
    """
    try:
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            return None
        h, w = img.shape

        # Right-side crop (where the ruler usually is).
        crop = img[:, int(w * 0.80):]

        # Binarise with Otsu, then extract edges of the tick marks.
        blur = cv2.GaussianBlur(crop, (5, 5), 0)
        _, thr = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        edges = cv2.Canny(thr, 50, 150)
        lines = cv2.HoughLinesP(edges, 1, np.pi / 180, threshold=30,
                                minLineLength=10, maxLineGap=5)
        if lines is None:
            return None

        # Keep near-horizontal segments (|dy| <= 3 px, |dx| > 5 px).
        # NOTE(review): Canny may report both the top and bottom edge of a
        # thick tick, giving two segments per tick — TODO confirm on real
        # rulers.
        ys = []
        for seg in lines:
            x1, y1, x2, y2 = seg[0]
            if abs(y1 - y2) <= 3 and abs(x2 - x1) > 5:
                ys.append(y1)
        if len(ys) < 3:
            return None

        diffs = np.diff(sorted(ys))
        # Drop near-duplicate detections of the same tick (<= 2 px apart)
        # and jumps of half the image height or more.
        diffs = diffs[(diffs > 2) & (diffs < h // 2)]
        if diffs.size == 0:
            return None

        # Median rather than mean: a single large gap (missing tick, jump
        # between separate groups of horizontal edges) still passes the
        # h // 2 filter and would otherwise dominate the average.
        return float(np.median(diffs))
    except Exception:
        # Best effort: any OpenCV failure just means "no ruler found".
        return None
# ====================================================
# 3. Compute mm per pixel
# ====================================================
def compute_mm_per_pixel(path, fallback_dpi=453.5714):
    """Return the physical pixel size of an image and how it was determined.

    Tries, in order: the resolution stored in the file's metadata
    (``read_dpi``), ruler tick marks assumed to be 1 mm apart
    (``detect_tick_mm``), and finally a fixed fallback resolution.

    Args:
        path: Filesystem path of the image.
        fallback_dpi: Resolution assumed when neither metadata nor tick
            marks are usable. The default (~453.57 DPI, ~0.056 mm/px) is
            the value this app has always used.

    Returns:
        A ``(mm_per_pixel, source)`` tuple where ``source`` is one of
        ``"metadata"``, ``"tickmarks"`` or ``"fallback"``.

    Raises:
        ValueError: If ``fallback_dpi`` is not positive.
    """
    if not fallback_dpi > 0:
        raise ValueError(f"fallback_dpi must be positive, got {fallback_dpi!r}")
    mm_per_inch = 25.4

    # A) Metadata DPI. The ``> 1`` check also rejects 0 and NaN.
    dpi = read_dpi(path)
    if dpi and dpi > 1:
        return (mm_per_inch / dpi), "metadata"

    # B) Tick marks, assumed to be 1 mm apart (value is pixels per mm).
    px_per_mm = detect_tick_mm(path)
    if px_per_mm and px_per_mm > 0:
        return (1.0 / px_per_mm), "tickmarks"

    # C) Fixed fallback resolution. The old comment called this "300 DPI",
    # but the value has always been ~453.57 DPI.
    return (mm_per_inch / fallback_dpi), "fallback"
# ==========================
# Wrapped function
# ==========================
def _build_summary(distance_analyses, mm_per_px, threshold_mm=2.0):
    """Format per-tooth CEJ–ABC distances in mm plus an overall verdict line."""
    lines = []
    measured_any = False
    has_periodontitis = False
    for tooth in distance_analyses:
        tid = tooth["tooth_id"]
        analysis = tooth["analysis"]
        if analysis:
            mm = analysis["mean_distance"] * mm_per_px
            lines.append(f"Tooth {tid}: {mm:.2f} mm")
            measured_any = True
            # A mean CEJ–ABC distance above threshold_mm flags the image.
            if mm > threshold_mm:
                has_periodontitis = True
        else:
            lines.append(f"Tooth {tid}: no valid CEJ–ABC measurement")

    if has_periodontitis:
        verdict = "⚠️ You have periodontitis."
    elif measured_any:
        verdict = "✅ You don't have periodontitis."
    else:
        # Without a single measurement there is no basis for a negative
        # result, so don't report one.
        verdict = ("⚠️ No valid CEJ–ABC measurement could be made; "
                   "unable to assess periodontitis.")

    body = "\n".join(lines)
    return "\n\n".join(part for part in (body, verdict) if part)


def detect_periodontitis(image_np):
    """Run the periodontitis pipeline on an uploaded X-ray for the Gradio UI.

    Args:
        image_np: RGB image array as delivered by
            ``gr.Image(type="numpy")``, or ``None`` when nothing was uploaded.

    Returns:
        ``(annotated_rgb_image, summary_text)``. The image is ``None`` when no
        input was provided.
    """
    if image_np is None:
        return None, "Please upload a dental X-ray image."

    # The model and the mm-scaling helpers take a file path. Use a unique temp
    # file so that concurrent requests cannot overwrite each other's input.
    fd, temp_path = tempfile.mkstemp(suffix=".jpg")
    os.close(fd)
    try:
        if not cv2.imwrite(temp_path, cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)):
            raise gr.Error("Could not encode the uploaded image.")

        # Run periodontitis detection.
        results = model.analyze_image(temp_path)

        # Compute mm scaling. The source label is intentionally not shown
        # to the user.
        mm_per_px, _source = compute_mm_per_pixel(temp_path)
    finally:
        # Always clean up, even if the model raised.
        try:
            os.remove(temp_path)
        except OSError:
            pass

    # Convert the combined BGR overlay to RGB for display.
    combined_rgb = cv2.cvtColor(results["combined"], cv2.COLOR_BGR2RGB)
    summary_text = _build_summary(results["distance_analyses"], mm_per_px)
    return combined_rgb, summary_text
# ==========================
# Gradio Interface
# ==========================
# Components are created in the same order as before: input image first,
# then the annotated output image and the summary textbox.
_xray_input = gr.Image(type="numpy", label="Upload Dental X-Ray")
_result_outputs = [
    gr.Image(label="Final Annotated Image (YOLO + CEJ–ABC)"),
    gr.Textbox(label="Analysis Summary (mm)"),
]

demo = gr.Interface(
    fn=detect_periodontitis,
    inputs=_xray_input,
    outputs=_result_outputs,
    title="🦷 Periodontitis Detection & Analysis (mm accurate)",
    description="Outputs CEJ–ABC distances in millimeters.",
)

if __name__ == "__main__":
    # Listen on all interfaces on the conventional Gradio/Spaces port.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
    )