# jayn95's picture
# Update app.py
# 095f5cb verified
# app.py
import os
import tempfile

import cv2
import gradio as gr

from periodontitis_detection import SimpleDentalSegmentationNoEnhance
# ==========================
# 1️⃣ Load models once
# ==========================
# Resolve the weight files against this script's directory instead of the
# process working directory, so the app still starts when it is launched
# from somewhere else (e.g. `python path/to/app.py`).
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
model = SimpleDentalSegmentationNoEnhance(
    unet_model_path=os.path.join(_APP_DIR, "unet.keras"),  # same filenames as your repo
    yolo_model_path=os.path.join(_APP_DIR, "best2.pt"),
)
# ==========================
# 2️⃣ Define wrapper for Gradio
# ==========================
def _summarize_distances(distance_analyses):
    """Format one text line per tooth from the pipeline's distance analyses.

    Each entry is read as a mapping with "tooth_id" and "analysis" keys; a
    truthy "analysis" must provide "mean_distance" (presumably pixels, given
    the "px" suffix — confirm against periodontitis_detection).

    Returns:
        The lines joined with newlines, or "No detections found." when
        there are no entries.
    """
    lines = []
    for tooth in distance_analyses:
        tooth_id = tooth["tooth_id"]
        analysis = tooth["analysis"]
        if analysis:
            lines.append(f"Tooth {tooth_id}: mean={analysis['mean_distance']:.2f}px")
        else:
            lines.append(f"Tooth {tooth_id}: no valid CEJ–ABC measurement")
    return "\n".join(lines) if lines else "No detections found."


def detect_periodontitis(image_np):
    """Analyse one uploaded X-ray and return the annotated image and a summary.

    Gradio sends the upload as a NumPy RGB array, but model.analyze_image()
    only accepts a file path, so the array is written to a private temporary
    file that is removed again once the analysis has run.

    Args:
        image_np: Uploaded image as an RGB NumPy array, or None when the user
            submits without choosing an image.

    Returns:
        (combined_rgb, summary_text): the pipeline's "combined" image
        converted from BGR to RGB for display, and a newline-joined
        per-tooth CEJ–ABC summary.

    Raises:
        gr.Error: if no image was uploaded or the temporary file could not
            be written; Gradio shows the message to the user.
    """
    if image_np is None:
        raise gr.Error("Please upload a dental X-ray image first.")

    # Use a unique file per call. A fixed filename would be shared by every
    # request (and by any other process in the same directory), so
    # overlapping requests could analyse each other's upload. PNG is
    # lossless, so the models see the exact uploaded pixels rather than a
    # JPEG re-encode.
    fd, temp_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)  # cv2.imwrite reopens the path by name
    try:
        bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
        # cv2.imwrite signals failure through its return value.
        if not cv2.imwrite(temp_path, bgr):
            raise gr.Error("Could not save the uploaded image for analysis.")
        # Run full pipeline
        results = model.analyze_image(temp_path)
    finally:
        try:
            os.remove(temp_path)
        except OSError:
            pass  # best-effort cleanup; never mask the original error

    # Convert OpenCV BGR -> RGB for Gradio display
    combined_rgb = cv2.cvtColor(results["combined"], cv2.COLOR_BGR2RGB)
    return combined_rgb, _summarize_distances(results["distance_analyses"])
# ==========================
# 3️⃣ Build Gradio Interface
# ==========================
# One X-ray upload in; the annotated overlay and a per-tooth text summary
# out. The output order must match the (image, text) tuple returned by
# detect_periodontitis.
demo = gr.Interface(
    title="🦷 Periodontitis Detection & Analysis",
    description=(
        "Automatically detects teeth (YOLOv8), "
        "segments CEJ/ABC (U-Net), and measures "
        "CEJ–ABC distances per tooth to assess bone loss."
    ),
    fn=detect_periodontitis,
    inputs=gr.Image(label="Upload Dental X-Ray", type="numpy"),
    outputs=[
        gr.Image(label="Final Annotated Image (YOLO + CEJ–ABC)"),
        gr.Textbox(label="Analysis Summary"),
    ],
)

if __name__ == "__main__":
    # Listen on all interfaces on port 7860 and surface Python errors in the UI.
    demo.launch(show_error=True, server_name="0.0.0.0", server_port=7860)