File size: 4,975 Bytes
25f3eef
 
 
c640a6f
 
 
 
25f3eef
 
 
c640a6f
25f3eef
 
2da1714
0c26922
25f3eef
 
c640a6f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a8e03fd
c640a6f
 
25f3eef
c640a6f
25f3eef
 
c640a6f
25f3eef
 
 
c640a6f
25f3eef
 
c640a6f
25f3eef
 
c640a6f
 
 
0f88e64
 
de688e6
 
25f3eef
c640a6f
25f3eef
c640a6f
25f3eef
c640a6f
 
a43bc51
de688e6
 
071f625
a43bc51
c640a6f
0f88e64
25f3eef
de688e6
 
0aa756a
de688e6
0aa756a
de688e6
0f88e64
c640a6f
 
 
 
25f3eef
0f88e64
de688e6
25f3eef
c640a6f
25f3eef
 
 
 
 
 
c640a6f
25f3eef
c640a6f
 
25f3eef
 
 
c640a6f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
# app.py
import os
import tempfile

import cv2
import gradio as gr
import numpy as np
from PIL import Image, ExifTags

from periodontitis_detection import SimpleDentalSegmentationNoEnhance

# ==========================
# Load model
# ==========================
# Instantiated once at import time; every request handled by this module
# shares this single instance. Both weight paths are relative, so they are
# resolved against the process's working directory.
model = SimpleDentalSegmentationNoEnhance(
    yolo_model_path="best2.pt",  # YOLO model weights
    unet_model_path="models/unet/best.keras",  # U-Net (Keras) model weights
)

# ====================================================
# 1. Read DPI from metadata (EXIF / PNG)
# ====================================================
def read_dpi(path):
    """Return the horizontal resolution (dots per inch) stored in an image file.

    Sources, tried in order:
      1. Pillow's ``info["dpi"]`` entry (set by some format plugins from the
         file's density header).
      2. The EXIF ``XResolution`` tag, scaled according to ``ResolutionUnit``.

    Args:
        path: Path to an image file readable by Pillow.

    Returns:
        The DPI as a float, or ``None`` when the file cannot be opened, has no
        resolution metadata, or declares a resolution with no absolute unit.
        The value is not range-checked; callers should reject implausible
        values (e.g. 0).
    """
    try:
        # `with` closes the file handle that Image.open otherwise keeps open
        # (it decodes lazily) until garbage collection.
        with Image.open(path) as img:
            d = img.info.get("dpi")
            if d is not None:
                return float(d[0]) if isinstance(d, (tuple, list)) else float(d)

            # EXIF resolution (rare on X-rays). getexif() is Pillow's public
            # accessor; the private _getexif() is plugin-specific.
            tags = {ExifTags.TAGS.get(k, k): v for k, v in img.getexif().items()}
            x_res = tags.get("XResolution")
            if x_res is None:
                return None

            if isinstance(x_res, tuple):  # older Pillow: (numerator, denominator)
                num, den = x_res
                if not den:
                    return None
                dpi = float(num) / float(den)
            else:  # rational object or plain number
                dpi = float(x_res)

            # EXIF ResolutionUnit: 1 = no absolute unit, 2 = inch (default),
            # 3 = centimetre. Ignoring it would misread px/cm as DPI.
            unit = tags.get("ResolutionUnit", 2)
            if unit == 3:
                return dpi * 2.54
            if unit == 1:
                return None
            return dpi
    except Exception:
        # Best-effort probe: an unreadable file or malformed metadata simply
        # means "resolution unknown"; the caller falls back to other methods.
        return None


# ====================================================
# 2. Detect 1 mm tick spacing
# ====================================================
def detect_tick_mm(path, *, crop_frac=0.80):
    """Estimate pixels-per-millimetre from ruler tick marks in an image.

    Only the right-hand strip of the image (from ``crop_frac`` of the width to
    the right edge) is searched. The strip is blurred, Otsu-thresholded and
    edge-detected. Near-horizontal segments found by a probabilistic Hough
    transform are treated as tick marks. The mean vertical gap between
    consecutive segments is returned as the pixel spacing of 1 mm ticks.

    Args:
        path: Path to an image file readable by ``cv2.imread``.
        crop_frac: Keyword-only. Fraction of the width at which the searched
            strip starts (0.80 = rightmost 20 %). Must be in ``[0, 1)``.

    Returns:
        Estimated pixels per millimetre as a float, or ``None`` if the image
        cannot be read or too few tick-like segments are found.

    Raises:
        ValueError: If ``crop_frac`` is outside ``[0, 1)``.
    """
    if not 0.0 <= crop_frac < 1.0:
        raise ValueError(f"crop_frac must be in [0, 1), got {crop_frac!r}")

    try:
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            return None

        h, w = img.shape

        # Right-side strip, where the ruler is expected to be.
        crop = img[:, int(w * crop_frac):]

        # Binarise (Otsu picks the threshold), then find edges of tick marks.
        blur = cv2.GaussianBlur(crop, (5, 5), 0)
        _, thr = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        edges = cv2.Canny(thr, 50, 150)

        lines = cv2.HoughLinesP(edges, 1, np.pi / 180, threshold=30,
                                minLineLength=10, maxLineGap=5)
        if lines is None:
            return None

        # Keep only near-horizontal segments (<= 3 px vertical drift, > 5 px wide).
        ys = sorted(
            y1
            for x1, y1, x2, y2 in lines[:, 0]
            if abs(y1 - y2) <= 3 and abs(x2 - x1) > 5
        )
        if len(ys) < 3:
            return None

        # Drop near-duplicate rows (<= 2 px) and implausibly large gaps.
        # NOTE(review): Canny marks both the upper and lower edge of a tick, so
        # ticks thicker than ~2 px can contribute two y-values each, which would
        # pull the mean gap below the true tick spacing. TODO confirm on real
        # ruler images.
        diffs = np.diff(ys)
        diffs = diffs[(diffs > 2) & (diffs < h // 2)]
        if len(diffs) == 0:
            return None

        return float(np.mean(diffs))
    except Exception:
        # Best-effort heuristic: any OpenCV/numpy failure means "no ruler found".
        return None


# ====================================================
# 3. Compute mm per pixel
# ====================================================
def compute_mm_per_pixel(path, fallback_dpi=453.5714):
    """Estimate the physical size of one pixel of the image at ``path``.

    Strategies, tried in order:
      A) resolution metadata from ``read_dpi`` (accepted only if > 1 DPI);
      B) spacing of 1 mm ruler ticks from ``detect_tick_mm`` (if > 0);
      C) a fixed fallback resolution of ``fallback_dpi``.

    Args:
        path: Path to the image file.
        fallback_dpi: Resolution assumed when neither A nor B yields a value.
            The default, 453.5714 DPI, corresponds to 25.4 / 453.5714
            ≈ 0.056 mm per pixel (the origin of this figure is not documented
            here; presumably a typical sensor pixel pitch — TODO confirm).

    Returns:
        ``(mm_per_pixel, source)`` where ``source`` is one of ``"metadata"``,
        ``"tickmarks"`` or ``"fallback"``.

    Raises:
        ValueError: If ``fallback_dpi`` is not positive.
    """
    if fallback_dpi <= 0:
        raise ValueError(f"fallback_dpi must be positive, got {fallback_dpi!r}")

    mm_per_inch = 25.4

    # A) Metadata DPI. A value <= 1 (or NaN) is treated as "not specified".
    dpi = read_dpi(path)
    if dpi is not None and dpi > 1:
        return mm_per_inch / dpi, "metadata"

    # B) Ruler tick marks, assumed to be 1 mm apart.
    px_per_mm = detect_tick_mm(path)
    if px_per_mm is not None and px_per_mm > 0:
        return 1.0 / px_per_mm, "tickmarks"

    # C) Fixed fallback resolution (default ≈0.056 mm/px, not 300 DPI).
    return mm_per_inch / fallback_dpi, "fallback"


# ==========================
# Wrapped function
# ==========================
def detect_periodontitis(image_np):
    """Run the periodontitis pipeline on an uploaded image and summarise it.

    The image is written to a private temporary JPEG because
    ``model.analyze_image`` takes a file path. Each tooth's mean CEJ–ABC
    distance (pixels) is converted to millimetres with
    ``compute_mm_per_pixel``. Any tooth whose mean distance exceeds 2.0 mm is
    reported as periodontitis.

    Args:
        image_np: Image array from ``gr.Image(type="numpy")`` (RGB; grayscale
            and RGBA arrays are also accepted), or ``None`` when the user
            submitted without an image.

    Returns:
        ``(annotated_rgb, summary_text)``: the model's ``"combined"``
        visualisation converted to RGB, and a multi-line text summary.
        Returns ``(None, <prompt message>)`` when no image was given.

    Raises:
        RuntimeError: If the temporary image cannot be written.
    """
    if image_np is None:
        return None, "Please upload a dental X-ray image."

    # Gradio delivers RGB(A); OpenCV writes BGR.
    if image_np.ndim == 2:
        bgr = image_np  # grayscale: no channel reordering needed
    elif image_np.shape[2] == 4:
        bgr = cv2.cvtColor(image_np, cv2.COLOR_RGBA2BGR)
    else:
        bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)

    # A unique temp file per request. A fixed name in the working directory
    # would let concurrent requests overwrite each other's input.
    # NOTE(review): this re-encoded JPEG likely carries none of the original
    # upload's resolution metadata, so the "metadata" branch of
    # compute_mm_per_pixel probably never fires here. Consider
    # gr.Image(type="filepath") to keep it. TODO confirm.
    fd, temp_path = tempfile.mkstemp(suffix=".jpg")
    os.close(fd)  # cv2 opens the path itself; release our handle first
    try:
        if not cv2.imwrite(temp_path, bgr):
            raise RuntimeError(f"Could not write temporary image to {temp_path}")

        # Run periodontitis detection
        results = model.analyze_image(temp_path)

        # Compute mm scaling (which strategy was used is not shown to users)
        mm_per_px, _scale_source = compute_mm_per_pixel(temp_path)
    finally:
        # Always clean up, even if the model raised.
        try:
            os.remove(temp_path)
        except OSError:
            pass

    # Convert combined BGR → RGB for display
    combined_rgb = cv2.cvtColor(results["combined"], cv2.COLOR_BGR2RGB)

    summaries = []
    has_periodontitis = False
    measured = 0  # teeth with a valid CEJ–ABC measurement

    for tooth in results["distance_analyses"]:
        tid = tooth["tooth_id"]
        analysis = tooth["analysis"]

        if analysis:
            mm = analysis["mean_distance"] * mm_per_px
            summaries.append(f"Tooth {tid}: {mm:.2f} mm")
            measured += 1
            # Mean CEJ–ABC distance above 2 mm is this app's periodontitis cut-off.
            if mm > 2.0:
                has_periodontitis = True
        else:
            summaries.append(f"Tooth {tid}: no valid CEJ–ABC measurement")

    summary_text = "\n".join(summaries)

    # Add interpretation. Don't give a negative verdict without any evidence.
    if has_periodontitis:
        summary_text += "\n\n⚠️ You have periodontitis."
    elif measured == 0:
        summary_text += (
            "\n\nℹ️ No valid CEJ–ABC measurement was found, "
            "so no assessment could be made."
        )
    else:
        summary_text += "\n\n✅ You don't have periodontitis."

    return combined_rgb, summary_text
    
# ==========================
# Gradio Interface
# ==========================
# One numpy image in; annotated image plus text summary out, both produced by
# detect_periodontitis.
# NOTE(review): "(mm accurate)" may overstate precision when the pixel scale
# comes from the hard-coded fallback rather than metadata or a ruler.
demo = gr.Interface(
    title="🦷 Periodontitis Detection & Analysis (mm accurate)",
    description="Outputs CEJ–ABC distances in millimeters.",
    fn=detect_periodontitis,
    inputs=gr.Image(type="numpy", label="Upload Dental X-Ray"),
    outputs=[
        gr.Image(label="Final Annotated Image (YOLO + CEJ–ABC)"),
        gr.Textbox(label="Analysis Summary (mm)"),
    ],
)

if __name__ == "__main__":
    # Bind to all network interfaces on port 7860, with show_error enabled.
    demo.launch(show_error=True, server_port=7860, server_name="0.0.0.0")