NelAlan commited on
Commit
c640a6f
·
verified ·
1 Parent(s): a81f7b7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +127 -23
app.py CHANGED
@@ -1,65 +1,169 @@
1
  # app.py
2
  import gradio as gr
3
  import cv2
 
 
 
 
4
  from periodontitis_detection import SimpleDentalSegmentationNoEnhance
5
 
6
  # ==========================
7
- # 1️⃣ Load models once
8
  # ==========================
9
  model = SimpleDentalSegmentationNoEnhance(
10
- unet_model_path="unet.keras", # same filenames as your repo
11
  yolo_model_path="best2.pt"
12
  )
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  # ==========================
15
- # 2️⃣ Define wrapper for Gradio
16
  # ==========================
17
  def detect_periodontitis(image_np):
18
- """
19
- Gradio sends image as a NumPy RGB array.
20
- We temporarily save it to a file path since analyze_image() needs a path.
21
- """
22
  temp_path = "temp_input.jpg"
23
  cv2.imwrite(temp_path, cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR))
24
 
25
- # Run full pipeline
26
  results = model.analyze_image(temp_path)
27
 
28
- # Convert OpenCV BGR → RGB for Gradio display
29
  combined_rgb = cv2.cvtColor(results["combined"], cv2.COLOR_BGR2RGB)
30
 
31
- # Optional: summarize measurements for text output
 
 
 
32
  summaries = []
33
  for tooth in results["distance_analyses"]:
34
- tooth_id = tooth["tooth_id"]
35
  analysis = tooth["analysis"]
 
36
  if analysis:
37
- mean_d = analysis["mean_distance"]
38
- summaries.append(f"Tooth {tooth_id}: mean={mean_d:.2f}px")
 
39
  else:
40
- summaries.append(f"Tooth {tooth_id}: no valid CEJ–ABC measurement")
 
 
41
 
42
- summary_text = "\n".join(summaries) if summaries else "No detections found."
 
 
 
 
43
 
44
  return combined_rgb, summary_text
45
 
46
 
47
  # ==========================
48
- # 3️⃣ Build Gradio Interface
49
  # ==========================
50
  demo = gr.Interface(
51
  fn=detect_periodontitis,
52
  inputs=gr.Image(type="numpy", label="Upload Dental X-Ray"),
53
  outputs=[
54
  gr.Image(label="Final Annotated Image (YOLO + CEJ–ABC)"),
55
- gr.Textbox(label="Analysis Summary"),
56
  ],
57
- title="🦷 Periodontitis Detection & Analysis",
58
- description=(
59
- "Automatically detects teeth (YOLOv8), segments CEJ/ABC (U-Net), "
60
- "and measures CEJ–ABC distances per tooth to assess bone loss."
61
- ),
62
  )
63
 
64
  if __name__ == "__main__":
65
- demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
 
1
  # app.py
2
  import gradio as gr
3
  import cv2
4
+ import numpy as np
5
+ from PIL import Image, ExifTags
6
+ import os
7
+
8
  from periodontitis_detection import SimpleDentalSegmentationNoEnhance
9
 
10
  # ==========================
11
+ # Load model
12
  # ==========================
13
  model = SimpleDentalSegmentationNoEnhance(
14
+ unet_model_path="unet.keras",
15
  yolo_model_path="best2.pt"
16
  )
17
 
18
+ # ====================================================
19
+ # 1. Read DPI from metadata (EXIF / PNG)
20
+ # ====================================================
21
+ def read_dpi(path):
22
+ try:
23
+ img = Image.open(path)
24
+ info = img.info
25
+
26
+ # PIL standard DPI field
27
+ if "dpi" in info:
28
+ d = info["dpi"]
29
+ if isinstance(d, (tuple, list)):
30
+ return float(d[0])
31
+ return float(d)
32
+
33
+ # EXIF resolution (rare on xrays)
34
+ exif = img._getexif()
35
+ if exif:
36
+ for tag_id, value in exif.items():
37
+ tag = ExifTags.TAGS.get(tag_id, tag_id)
38
+ if tag == "XResolution":
39
+ if isinstance(value, tuple) and value[1] != 0:
40
+ return float(value[0]) / float(value[1])
41
+ return float(value)
42
+ except:
43
+ pass
44
+
45
+ return None
46
+
47
+
48
+ # ====================================================
49
+ # 2. Detect 1 mm tick spacing
50
+ # ====================================================
51
+ def detect_tick_mm(path):
52
+ try:
53
+ img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
54
+ if img is None:
55
+ return None
56
+
57
+ h, w = img.shape
58
+
59
+ # Right-side crop (where ruler usually is)
60
+ crop = img[:, int(w * 0.80):]
61
+
62
+ # Threshold for tick marks
63
+ blur = cv2.GaussianBlur(crop, (5, 5), 0)
64
+ _, thr = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
65
+
66
+ edges = cv2.Canny(thr, 50, 150)
67
+
68
+ lines = cv2.HoughLinesP(edges, 1, np.pi/180, threshold=30,
69
+ minLineLength=10, maxLineGap=5)
70
+ if lines is None:
71
+ return None
72
+
73
+ ys = []
74
+ for l in lines:
75
+ x1, y1, x2, y2 = l[0]
76
+ if abs(y1 - y2) <= 3 and abs(x2 - x1) > 5:
77
+ ys.append(y1)
78
+
79
+ if len(ys) < 3:
80
+ return None
81
+
82
+ ys = sorted(ys)
83
+ diffs = np.diff(ys)
84
+
85
+ diffs = diffs[(diffs > 2) & (diffs < h // 2)]
86
+ if len(diffs) == 0:
87
+ return None
88
+
89
+ px_per_mm = float(np.mean(diffs))
90
+ return px_per_mm
91
+ except:
92
+ return None
93
+
94
+
95
+ # ====================================================
96
+ # 3. Compute mm per pixel
97
+ # ====================================================
98
+ def compute_mm_per_pixel(path):
99
+ # A) Metadata DPI
100
+ dpi = read_dpi(path)
101
+ if dpi and dpi > 1:
102
+ return (25.4 / dpi), "metadata"
103
+
104
+ # B) Tick marks (1 mm)
105
+ tick = detect_tick_mm(path)
106
+ if tick and tick > 0:
107
+ return (1.0 / tick), "tickmarks"
108
+
109
+ # C) Fallback 300 DPI
110
+ return (25.4 / 300.0), "fallback"
111
+
112
+
113
  # ==========================
114
+ # Wrapped function
115
  # ==========================
116
  def detect_periodontitis(image_np):
117
+ # Save temporary image for model + mm scaling
 
 
 
118
  temp_path = "temp_input.jpg"
119
  cv2.imwrite(temp_path, cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR))
120
 
121
+ # Run periodontitis detection
122
  results = model.analyze_image(temp_path)
123
 
124
+ # Convert combined BGR → RGB for display
125
  combined_rgb = cv2.cvtColor(results["combined"], cv2.COLOR_BGR2RGB)
126
 
127
+ # Compute mm scaling
128
+ mm_per_px, method = compute_mm_per_pixel(temp_path)
129
+
130
+ # Summaries — CLEAN (no method labels)
131
  summaries = []
132
  for tooth in results["distance_analyses"]:
133
+ tid = tooth["tooth_id"]
134
  analysis = tooth["analysis"]
135
+
136
  if analysis:
137
+ px = analysis["mean_distance"]
138
+ mm = px * mm_per_px
139
+ summaries.append(f"Tooth {tid}: {mm:.2f} mm")
140
  else:
141
+ summaries.append(f"Tooth {tid}: no valid CEJ–ABC measurement")
142
+
143
+ summary_text = "\n".join(summaries)
144
 
145
+ # Remove temp
146
+ try:
147
+ os.remove(temp_path)
148
+ except:
149
+ pass
150
 
151
  return combined_rgb, summary_text
152
 
153
 
154
  # ==========================
155
+ # Gradio Interface
156
  # ==========================
157
  demo = gr.Interface(
158
  fn=detect_periodontitis,
159
  inputs=gr.Image(type="numpy", label="Upload Dental X-Ray"),
160
  outputs=[
161
  gr.Image(label="Final Annotated Image (YOLO + CEJ–ABC)"),
162
+ gr.Textbox(label="Analysis Summary (mm)"),
163
  ],
164
+ title="🦷 Periodontitis Detection & Analysis (mm accurate)",
165
+ description="Outputs CEJ–ABC distances in millimeters."
 
 
 
166
  )
167
 
168
  if __name__ == "__main__":
169
+ demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)