Spaces:
Sleeping
Sleeping
File size: 4,975 Bytes
25f3eef c640a6f 25f3eef c640a6f 25f3eef 2da1714 0c26922 25f3eef c640a6f a8e03fd c640a6f 25f3eef c640a6f 25f3eef c640a6f 25f3eef c640a6f 25f3eef c640a6f 25f3eef c640a6f 0f88e64 de688e6 25f3eef c640a6f 25f3eef c640a6f 25f3eef c640a6f a43bc51 de688e6 071f625 a43bc51 c640a6f 0f88e64 25f3eef de688e6 0aa756a de688e6 0aa756a de688e6 0f88e64 c640a6f 25f3eef 0f88e64 de688e6 25f3eef c640a6f 25f3eef c640a6f 25f3eef c640a6f 25f3eef c640a6f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 | # app.py
import os
import tempfile

import cv2
import gradio as gr
import numpy as np
from PIL import Image, ExifTags

from periodontitis_detection import SimpleDentalSegmentationNoEnhance
# ==========================
# Load model (executed once, when this module is imported)
# ==========================
_UNET_WEIGHTS = "models/unet/best.keras"
_YOLO_WEIGHTS = "best2.pt"

model = SimpleDentalSegmentationNoEnhance(
    unet_model_path=_UNET_WEIGHTS,
    yolo_model_path=_YOLO_WEIGHTS,
)
# ====================================================
# 1. Read DPI from metadata (EXIF / PNG)
# ====================================================
# TIFF/EXIF tag ids used by the EXIF fallback in read_dpi().
_EXIF_X_RESOLUTION = 0x011A     # XResolution
_EXIF_RESOLUTION_UNIT = 0x0128  # ResolutionUnit: 1 = none, 2 = inch, 3 = cm


def _to_positive_float(value):
    """Convert a number, rational, or ``(num, den)`` pair to a positive finite float.

    Returns ``None`` for ``None``, malformed values, a zero denominator,
    zero/negative values, NaN and infinity.
    """
    try:
        if isinstance(value, (tuple, list)):
            num, den = value
            value = float(num) / float(den)
        result = float(value)
    except (TypeError, ValueError, ZeroDivisionError):
        return None
    # The chained comparison is False for NaN as well as for +inf.
    return result if 0 < result < float("inf") else None


def read_dpi(path):
    """Return the horizontal resolution of the image at *path* in DPI.

    Looks first at Pillow's ``img.info["dpi"]`` and, if that is missing or
    unusable, at the EXIF ``XResolution`` tag, converting from pixels per
    centimetre when EXIF ``ResolutionUnit`` is 3.

    Args:
        path: filesystem path of an image file readable by Pillow.

    Returns:
        A positive, finite DPI as ``float``, or ``None`` when the file has no
        usable resolution metadata (``ResolutionUnit`` 1 = "no absolute unit"
        counts as unusable) or cannot be read at all.
    """
    try:
        # Context manager closes the underlying file handle.
        with Image.open(path) as img:
            dpi = img.info.get("dpi")
            if isinstance(dpi, (tuple, list)):
                # Pillow stores (x_dpi, y_dpi); use the horizontal value.
                dpi = dpi[0] if dpi else None
            dpi = _to_positive_float(dpi)
            if dpi is not None:
                return dpi

            # EXIF fallback (rare on x-rays). Public getexif() instead of the
            # private, format-specific _getexif().
            exif = img.getexif()
            x_res = _to_positive_float(exif.get(_EXIF_X_RESOLUTION))
            if x_res is None:
                return None
            # Per the TIFF spec an absent ResolutionUnit means inches (2).
            unit = exif.get(_EXIF_RESOLUTION_UNIT, 2)
            if unit == 2:
                return x_res
            if unit == 3:
                # Pixels per centimetre -> pixels per inch.
                return x_res * 2.54
            return None
    except Exception:
        # Best-effort metadata probe: unreadable or corrupt files simply mean
        # "no DPI". The caller has other scale sources. Unlike the former
        # bare `except:`, KeyboardInterrupt/SystemExit still propagate.
        return None
# ====================================================
# 2. Detect 1 mm tick spacing
# ====================================================
def detect_tick_mm(path, crop_frac=0.80):
    """Estimate ruler tick spacing, in pixels, from a scale bar in the image.

    Heuristic:
    1. Take the vertical strip from ``crop_frac`` of the width to the right edge.
    2. Blur it and binarise with Otsu.
    3. Run Canny and probabilistic Hough on the result.
    4. Keep near-horizontal segments.
    5. Return the median vertical gap between consecutive segments.

    The caller treats one gap as 1 mm; that assumption is not checked here.

    NOTE(review): Canny on a binarised strip yields both the top and the
    bottom edge of a tick. For ticks thicker than ~2 px, gaps then alternate
    between tick thickness and (spacing - thickness), and neither the mean
    nor the median recovers the true spacing. Consider merging edge pairs.

    Args:
        path: image file path.
        crop_frac: fraction of the width where the searched strip begins.
            The default 0.80 searches the right-most 20 %.

    Returns:
        Estimated pixels per tick as ``float``, or ``None`` when the image
        cannot be read, too few horizontal segments are found, or no gap
        passes the plausibility filter.
    """
    try:
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            return None
        h, w = img.shape

        # Right-side crop (where the ruler usually is).
        crop = img[:, int(w * crop_frac):]
        if crop.size == 0:
            return None

        # Binarise, then extract edges of the tick marks.
        blur = cv2.GaussianBlur(crop, (5, 5), 0)
        _, thr = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        edges = cv2.Canny(thr, 50, 150)
        lines = cv2.HoughLinesP(edges, 1, np.pi / 180, threshold=30,
                                minLineLength=10, maxLineGap=5)
        if lines is None:
            return None

        # Near-horizontal segments only (<= 3 px rise, > 5 px run).
        ys = sorted(
            y1 for x1, y1, x2, y2 in lines[:, 0]
            if abs(y1 - y2) <= 3 and abs(x2 - x1) > 5
        )
        if len(ys) < 3:
            return None

        # Gaps <= 2 px are duplicate detections of the same edge; gaps of
        # half the image height or more cannot be tick spacing.
        diffs = np.diff(ys)
        diffs = diffs[(diffs > 2) & (diffs < h // 2)]
        if diffs.size == 0:
            return None

        # Median rather than mean: a single missed tick (double gap) or a
        # stray horizontal edge would otherwise shift the estimate.
        return float(np.median(diffs))
    except Exception:
        # Best-effort heuristic; the caller falls back to another scale
        # source. Not a bare except, so interrupts still propagate.
        return None
# ====================================================
# 3. Compute mm per pixel
# ====================================================
def compute_mm_per_pixel(path, fallback_dpi=453.5714):
    """Return ``(mm_per_pixel, method)`` for the image at *path*.

    Scale sources, tried in order:
      A) DPI from file metadata (``read_dpi``), accepted when > 1;
      B) ruler tick spacing (``detect_tick_mm``), assuming one tick = 1 mm;
      C) the fixed ``fallback_dpi``.

    Args:
        path: image file path.
        fallback_dpi: DPI used when neither A nor B yields a value. The
            default 453.5714 equals 25.4 / 0.056, i.e. 0.056 mm per pixel
            (not 300 DPI).

    Returns:
        ``(mm_per_pixel, method)``, where ``method`` is one of
        ``"metadata"``, ``"tickmarks"`` or ``"fallback"``.

    Raises:
        ValueError: if ``fallback_dpi`` is not positive.
    """
    if not fallback_dpi > 0:
        raise ValueError(f"fallback_dpi must be positive, got {fallback_dpi!r}")

    # A) Metadata DPI.
    # NOTE(review): a generic 72/96 DPI tag also passes the `> 1` check,
    # although such values rarely describe the detector's pixel pitch.
    dpi = read_dpi(path)
    if dpi is not None and dpi > 1:
        return 25.4 / dpi, "metadata"

    # B) Tick marks (1 mm each).
    px_per_tick = detect_tick_mm(path)
    if px_per_tick is not None and px_per_tick > 0:
        return 1.0 / px_per_tick, "tickmarks"

    # C) Fixed fallback scale.
    return 25.4 / fallback_dpi, "fallback"
# ==========================
# Wrapped function
# ==========================
# Mean CEJ–ABC distance (mm) above which a tooth is flagged.
_PERIODONTITIS_THRESHOLD_MM = 2.0


def _build_summary(distance_analyses, mm_per_px):
    """Format per-tooth CEJ–ABC distances in mm and append an overall verdict.

    Args:
        distance_analyses: iterable of dicts with ``"tooth_id"`` and
            ``"analysis"``. ``"analysis"`` is falsy when no measurement exists,
            otherwise a dict whose ``"mean_distance"`` is in pixels.
        mm_per_px: millimetres per pixel.

    Returns:
        Multi-line summary text.
    """
    lines = []
    n_measured = 0
    flagged = False
    for tooth in distance_analyses:
        tid = tooth["tooth_id"]
        analysis = tooth["analysis"]
        if analysis:
            mm = analysis["mean_distance"] * mm_per_px
            lines.append(f"Tooth {tid}: {mm:.2f} mm")
            n_measured += 1
            flagged = flagged or mm > _PERIODONTITIS_THRESHOLD_MM
        else:
            lines.append(f"Tooth {tid}: no valid CEJ–ABC measurement")

    text = "\n".join(lines)
    if flagged:
        text += "\n\n⚠️ You have periodontitis."
    elif n_measured == 0:
        # With no measurements there is nothing to base a negative verdict on.
        text += "\n\n⚠️ No valid CEJ–ABC measurements — unable to assess periodontitis."
    else:
        text += "\n\n✅ You don't have periodontitis."
    return text


def detect_periodontitis(image_np):
    """Run the segmentation model on an uploaded image and summarise distances.

    Args:
        image_np: image array supplied by ``gr.Image(type="numpy")``; treated
            as RGB when it has a channel axis.

    Returns:
        ``(combined_rgb, summary_text)``: the model's annotated ``"combined"``
        image converted from BGR to RGB, and the per-tooth summary text.

    Raises:
        gr.Error: when no image was uploaded or it could not be written to
            the temporary file the model reads.

    NOTE(review): converting the upload to a numpy array drops its file
    metadata, and the temp file is re-encoded by OpenCV. The "metadata"
    branch of ``compute_mm_per_pixel`` therefore likely never applies here.
    Verify.
    """
    if image_np is None:
        raise gr.Error("Please upload a dental X-ray image.")

    # Unique per-request file: the former fixed "temp_input.jpg" let
    # concurrent requests overwrite each other's input.
    fd, temp_path = tempfile.mkstemp(suffix=".jpg")
    os.close(fd)
    try:
        bgr = (cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
               if image_np.ndim == 3 else image_np)
        if not cv2.imwrite(temp_path, bgr):
            raise gr.Error("Could not write the uploaded image for analysis.")

        # Run periodontitis detection.
        results = model.analyze_image(temp_path)
        # Convert combined BGR -> RGB for display.
        combined_rgb = cv2.cvtColor(results["combined"], cv2.COLOR_BGR2RGB)
        # Compute mm scaling. The method label is intentionally not shown.
        mm_per_px, _method = compute_mm_per_pixel(temp_path)
    finally:
        # Always remove the temp file, including when analysis raises.
        try:
            os.remove(temp_path)
        except OSError:
            pass

    return combined_rgb, _build_summary(results["distance_analyses"], mm_per_px)
# ==========================
# Gradio Interface
# ==========================
# Components are created in the same order the original Interface call
# evaluated them: input image, output image, output textbox.
_input_image = gr.Image(type="numpy", label="Upload Dental X-Ray")
_output_components = [
    gr.Image(label="Final Annotated Image (YOLO + CEJ–ABC)"),
    gr.Textbox(label="Analysis Summary (mm)"),
]

demo = gr.Interface(
    fn=detect_periodontitis,
    inputs=_input_image,
    outputs=_output_components,
    title="🦷 Periodontitis Detection & Analysis (mm accurate)",
    description="Outputs CEJ–ABC distances in millimeters.",
)


def _launch():
    """Serve the demo on all interfaces, port 7860, with errors shown in the UI."""
    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)


if __name__ == "__main__":
    _launch()
|