mujtaba1212 commited on
Commit
fd19e5e
Β·
verified Β·
1 Parent(s): 7eebeae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +172 -306
app.py CHANGED
@@ -1,45 +1,30 @@
1
- """
2
- OrthoTimes QuickCephTool β€” HRNet Landmark Detection API
3
- Hugging Face Space backend.
4
-
5
- Loads cwlachap/hrnet-cephalometric-landmark-detection, runs inference,
6
- and returns 19 landmark coordinates as normalised (0-1) JSON.
7
-
8
- The /detect endpoint is called by the HTML tool with a base64-encoded image.
9
- """
10
-
11
  import io
12
  import json
13
  import base64
14
  import numpy as np
15
- from PIL import Image
16
  import torch
17
  import torch.nn as nn
18
- import torch.nn.functional as F
19
  from huggingface_hub import hf_hub_download
20
  import gradio as gr
21
 
22
- # ─── Model architecture (HRNet-W32 heatmap output) ───────────────────────────
23
- # Minimal re-implementation matching the checkpoint structure.
24
- # Adapted from the official HRNet codebase.
25
 
26
  class BasicBlock(nn.Module):
27
  expansion = 1
28
  def __init__(self, inplanes, planes, stride=1, downsample=None):
29
  super().__init__()
30
  self.conv1 = nn.Conv2d(inplanes, planes, 3, stride=stride, padding=1, bias=False)
31
- self.bn1 = nn.BatchNorm2d(planes, momentum=0.1)
32
- self.relu = nn.ReLU(inplace=True)
33
  self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
34
- self.bn2 = nn.BatchNorm2d(planes, momentum=0.1)
35
  self.downsample = downsample
36
 
37
  def forward(self, x):
38
- residual = x
39
  out = self.relu(self.bn1(self.conv1(x)))
40
  out = self.bn2(self.conv2(out))
41
- if self.downsample: residual = self.downsample(x)
42
- return self.relu(out + residual)
43
 
44
 
45
  class Bottleneck(nn.Module):
@@ -47,346 +32,227 @@ class Bottleneck(nn.Module):
47
  def __init__(self, inplanes, planes, stride=1, downsample=None):
48
  super().__init__()
49
  self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
50
- self.bn1 = nn.BatchNorm2d(planes, momentum=0.1)
51
  self.conv2 = nn.Conv2d(planes, planes, 3, stride=stride, padding=1, bias=False)
52
- self.bn2 = nn.BatchNorm2d(planes, momentum=0.1)
53
  self.conv3 = nn.Conv2d(planes, planes * 4, 1, bias=False)
54
- self.bn3 = nn.BatchNorm2d(planes * 4, momentum=0.1)
55
- self.relu = nn.ReLU(inplace=True)
56
  self.downsample = downsample
57
 
58
  def forward(self, x):
59
- residual = x
60
  out = self.relu(self.bn1(self.conv1(x)))
61
  out = self.relu(self.bn2(self.conv2(out)))
62
  out = self.bn3(self.conv3(out))
63
- if self.downsample: residual = self.downsample(x)
64
- return self.relu(out + residual)
65
 
66
 
67
- class HRModule(nn.Module):
68
- def __init__(self, num_branches, block, num_blocks, num_inchannels, num_channels, fuse_method='SUM'):
69
- super().__init__()
70
- self.num_branches = num_branches
71
- self.fuse_method = fuse_method
72
- self.branches = self._make_branches(num_branches, block, num_blocks, num_channels)
73
- self.fuse_layers = self._make_fuse_layers(num_inchannels, num_channels)
74
- self.relu = nn.ReLU(True)
75
-
76
- def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride=1):
77
- layers = []
78
- for i in range(num_blocks[branch_index]):
79
- layers.append(block(num_channels[branch_index], num_channels[branch_index]))
80
- return nn.Sequential(*layers)
81
-
82
- def _make_branches(self, num_branches, block, num_blocks, num_channels):
83
- return nn.ModuleList([
84
- self._make_one_branch(i, block, num_blocks, num_channels)
85
- for i in range(num_branches)
86
- ])
87
 
88
- def _make_fuse_layers(self, num_inchannels, num_channels):
89
- fuse_layers = []
90
- for i in range(self.num_branches):
91
- fuse_layer = []
92
- for j in range(self.num_branches):
 
 
 
 
93
  if j > i:
94
- fuse_layer.append(nn.Sequential(
95
- nn.Conv2d(num_inchannels[j], num_inchannels[i], 1, bias=False),
96
- nn.BatchNorm2d(num_inchannels[i], momentum=0.1),
97
- nn.Upsample(scale_factor=2**(j-i), mode='nearest')
98
  ))
99
  elif j == i:
100
- fuse_layer.append(None)
101
  else:
102
- conv3x3s = []
103
  for k in range(i - j):
104
- if k == i - j - 1:
105
- conv3x3s.append(nn.Sequential(
106
- nn.Conv2d(num_inchannels[j], num_inchannels[i], 3, stride=2, padding=1, bias=False),
107
- nn.BatchNorm2d(num_inchannels[i], momentum=0.1)
108
- ))
109
- else:
110
- conv3x3s.append(nn.Sequential(
111
- nn.Conv2d(num_inchannels[j], num_inchannels[j], 3, stride=2, padding=1, bias=False),
112
- nn.BatchNorm2d(num_inchannels[j], momentum=0.1),
113
- nn.ReLU(True)
114
- ))
115
- fuse_layer.append(nn.Sequential(*conv3x3s))
116
- fuse_layers.append(nn.ModuleList(fuse_layer))
117
- return nn.ModuleList(fuse_layers)
118
 
119
  def forward(self, x):
120
- for i, branch in enumerate(self.branches):
121
- x[i] = branch(x[i])
122
- x_fuse = []
123
- for i in range(len(self.fuse_layers)):
124
- y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
125
  for j in range(1, self.num_branches):
126
- if i == j:
127
- y = y + x[j]
128
- elif j > i:
129
- y = y + self.fuse_layers[i][j](x[j])
130
- else:
131
- y = y + self.fuse_layers[i][j](x[j])
132
- x_fuse.append(self.relu(y))
133
- return x_fuse
 
 
 
 
 
 
 
 
 
134
 
135
 
136
  class HRNet(nn.Module):
137
- """HRNet-W32 for heatmap-based landmark detection."""
138
  def __init__(self, num_joints=19):
139
  super().__init__()
140
- self.num_joints = num_joints
141
- # Stem
142
- self.conv1 = nn.Conv2d(3, 64, 3, stride=2, padding=1, bias=False)
143
- self.bn1 = nn.BatchNorm2d(64, momentum=0.1)
144
- self.conv2 = nn.Conv2d(64, 64, 3, stride=2, padding=1, bias=False)
145
- self.bn2 = nn.BatchNorm2d(64, momentum=0.1)
146
- self.relu = nn.ReLU(inplace=True)
147
- # Layer1
148
- self.layer1 = self._make_layer(Bottleneck, 64, 64, 4)
149
- # Transition1 (256 β†’ [32, 64])
150
- self.transition1 = nn.ModuleList([
151
- nn.Sequential(nn.Conv2d(256,32,3,padding=1,bias=False), nn.BatchNorm2d(32,momentum=0.1), nn.ReLU(True)),
152
- nn.Sequential(nn.Sequential(nn.Conv2d(256,64,3,stride=2,padding=1,bias=False), nn.BatchNorm2d(64,momentum=0.1), nn.ReLU(True)))
153
- ])
154
- # Stage2
155
- self.stage2 = nn.Sequential(HRModule(2, BasicBlock, [4,4], [32,64], [32,64]))
156
- # Transition2
157
- self.transition2 = nn.ModuleList([
158
- None, None,
159
- nn.Sequential(nn.Conv2d(64,128,3,stride=2,padding=1,bias=False), nn.BatchNorm2d(128,momentum=0.1), nn.ReLU(True))
160
- ])
161
- # Stage3
162
- self.stage3 = nn.Sequential(*[HRModule(3, BasicBlock, [4,4,4], [32,64,128], [32,64,128]) for _ in range(4)])
163
- # Transition3
164
- self.transition3 = nn.ModuleList([
165
- None, None, None,
166
- nn.Sequential(nn.Conv2d(128,256,3,stride=2,padding=1,bias=False), nn.BatchNorm2d(256,momentum=0.1), nn.ReLU(True))
167
  ])
168
- # Stage4
169
- self.stage4 = nn.Sequential(*[HRModule(4, BasicBlock, [4,4,4,4], [32,64,128,256], [32,64,128,256]) for _ in range(3)])
170
- # Head
171
- self.final_layer = nn.Conv2d(32, num_joints, 1)
172
-
173
- def _make_layer(self, block, inplanes, planes, blocks, stride=1):
174
- downsample = None
175
- if stride != 1 or inplanes != planes * block.expansion:
176
- downsample = nn.Sequential(
177
- nn.Conv2d(inplanes, planes*block.expansion, 1, stride=stride, bias=False),
178
- nn.BatchNorm2d(planes*block.expansion, momentum=0.1)
179
- )
180
- layers = [block(inplanes, planes, stride, downsample)]
181
- for _ in range(1, blocks):
182
- layers.append(block(planes*block.expansion, planes))
183
- return nn.Sequential(*layers)
184
 
185
  def forward(self, x):
186
- x = self.relu(self.bn1(self.conv1(x)))
187
- x = self.relu(self.bn2(self.conv2(x)))
188
  x = self.layer1(x)
189
- xl = [t(x) if t else x for t in self.transition1]
190
- xl = list(self.stage2[0](xl))
191
- xl2 = []
192
- for i, t in enumerate(self.transition2):
193
- if t is None:
194
- xl2.append(xl[i] if i < len(xl) else xl[-1])
195
- else:
196
- xl2.append(t(xl[-1]))
197
- xl = xl2
198
  for m in self.stage3:
199
- xl = m(xl)
200
- xl3 = []
201
- for i, t in enumerate(self.transition3):
202
- if t is None:
203
- xl3.append(xl[i] if i < len(xl) else xl[-1])
204
- else:
205
- xl3.append(t(xl[-1]))
206
- xl = xl3
207
  for m in self.stage4:
208
- xl = m(xl)
209
- return self.final_layer(xl[0])
210
 
211
 
212
- # ─── Load model ──────────────────────────────────────────────────────────────
213
  print("Downloading model weights...")
214
  model_path = hf_hub_download(
215
  repo_id="cwlachap/hrnet-cephalometric-landmark-detection",
216
  filename="best_model.pth"
217
  )
 
 
218
  model = HRNet(num_joints=19)
219
- checkpoint = torch.load(model_path, map_location='cpu')
220
- state = checkpoint.get('model_state_dict', checkpoint)
221
- model.load_state_dict(state, strict=False)
222
  model.eval()
223
- print("Model loaded.")
224
-
225
- # ─── Landmark mapping ────────────────────────────────────────────────────────
226
- # HRNet model outputs 19 landmarks in this order (ISBI 2015 dataset standard)
227
- HRNET_LANDMARKS = [
228
- 'S', # 0 Sella
229
- 'N', # 1 Nasion
230
- 'Or', # 2 Orbitale
231
- 'Po', # 3 Porion
232
- 'ANS', # 4 ANS
233
- 'PNS', # 5 PNS
234
- 'A', # 6 Point A
235
- 'U1tip', # 7 Upper incisor tip
236
- 'L1tip', # 8 Lower incisor tip
237
- 'B', # 9 Point B
238
- 'Pog', # 10 Pogonion
239
- 'Me', # 11 Menton
240
- 'Gn', # 12 Gnathion
241
- 'Go', # 13 Gonion
242
- 'Co', # 14 Condylion (some datasets use Ar here)
243
- 'L1ap', # 15 Lower incisor apex
244
- 'U1ap', # 16 Upper incisor apex
245
- 'U6', # 17 Upper molar
246
- 'L6', # 18 Lower molar
247
- ]
248
-
249
- # ─── Preprocessing ───────────────────────────────────────────────────────────
250
- INPUT_W, INPUT_H = 256, 320 # model input size
251
-
252
- def preprocess(img_pil):
253
- """Convert PIL image β†’ normalised tensor [1,3,H,W]."""
254
- img = img_pil.convert('RGB').resize((INPUT_W, INPUT_H), Image.BILINEAR)
255
- arr = np.array(img, dtype=np.float32) / 255.0
256
- mean = np.array([0.485, 0.456, 0.406])
257
- std = np.array([0.229, 0.224, 0.225])
258
- arr = (arr - mean) / std
259
- tensor = torch.from_numpy(arr).permute(2,0,1).unsqueeze(0).float()
260
- return tensor
261
-
262
- def heatmap_to_coords(heatmaps, orig_w, orig_h):
263
- """
264
- heatmaps: [1, num_joints, H, W]
265
- Returns list of (x_norm, y_norm) tuples in original image space.
266
- """
267
- hm = heatmaps[0] # [num_joints, H, W]
268
- num_joints, hm_h, hm_w = hm.shape
269
- coords = []
270
- for j in range(num_joints):
271
- flat = hm[j].reshape(-1)
272
- idx = int(flat.argmax())
273
- py = idx // hm_w
274
- px = idx % hm_w
275
- # Sub-pixel refinement: nudge toward neighbouring maxima
276
- if 1 <= px < hm_w-1 and 1 <= py < hm_h-1:
277
- dx = float(hm[j, py, px+1] - hm[j, py, px-1])
278
- dy = float(hm[j, py+1, px] - hm[j, py-1, px])
279
- px += 0.25 * np.sign(dx)
280
- py += 0.25 * np.sign(dy)
281
- # Normalise to 0-1 in original image space
282
- x_norm = (px / hm_w) * (INPUT_W / orig_w)
283
- y_norm = (py / hm_h) * (INPUT_H / orig_h)
284
- # Clamp
285
- x_norm = float(np.clip(x_norm, 0.0, 1.0))
286
- y_norm = float(np.clip(y_norm, 0.0, 1.0))
287
- coords.append((x_norm, y_norm))
288
- return coords
289
 
290
- # ─── Inference function ───────────────────────────────────────────────────────
291
- def detect_landmarks(image_b64: str, mime_type: str = "image/jpeg") -> str:
292
- """
293
- Accepts base64-encoded image string.
294
- Returns JSON: {"landmarks": {"S": {"x":..,"y":..}, ...}}
295
- """
296
- try:
297
- img_bytes = base64.b64decode(image_b64)
298
- img_pil = Image.open(io.BytesIO(img_bytes)).convert('RGB')
299
- orig_w, orig_h = img_pil.size
300
 
301
- tensor = preprocess(img_pil)
302
- with torch.no_grad():
303
- heatmaps = model(tensor)
 
 
 
 
 
304
 
305
- coords = heatmap_to_coords(heatmaps.numpy(), orig_w, orig_h)
306
 
307
- result = {}
308
- for i, lm_id in enumerate(HRNET_LANDMARKS):
309
- if i < len(coords):
310
- result[lm_id] = {"x": round(coords[i][0], 4),
311
- "y": round(coords[i][1], 4),
312
- "confidence": 0.85}
313
 
314
- return json.dumps({"landmarks": result, "notes": "HRNet-W32 detection"})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
315
 
316
- except Exception as e:
317
- return json.dumps({"error": str(e)})
318
 
 
 
 
 
 
 
319
 
320
- # ─── Gradio interface ─────────────────────────────────────────────────────────
321
- # We expose both a visual demo AND a pure API endpoint
322
 
323
- def demo_fn(image):
324
- """Visual demo: accepts PIL image, returns annotated image + JSON."""
325
- if image is None:
326
  return None, "{}"
327
- buf = io.BytesIO()
328
- image.save(buf, format='JPEG')
329
- b64 = base64.b64encode(buf.getvalue()).decode()
330
- result_json = detect_landmarks(b64)
331
- result = json.loads(result_json)
332
-
333
- # Draw landmarks on image for visual output
334
- from PIL import ImageDraw, ImageFont
335
- draw = ImageDraw.Draw(image)
336
- w, h = image.size
337
- colors = {'S':'#58a6ff','N':'#58a6ff','Or':'#58a6ff','Po':'#58a6ff',
338
- 'ANS':'#3fb950','PNS':'#3fb950','A':'#3fb950',
339
- 'U1tip':'#3fb950','U1ap':'#3fb950','U6':'#3fb950',
340
- 'B':'#f0883e','L1tip':'#f0883e','L1ap':'#f0883e',
341
- 'Pog':'#f0883e','Me':'#f0883e','Gn':'#f0883e',
342
- 'Go':'#f0883e','Co':'#f0883e','L6':'#f0883e'}
343
-
344
- for lm_id, pt in result.get('landmarks', {}).items():
345
- cx = int(pt['x'] * w)
346
- cy = int(pt['y'] * h)
347
- col = colors.get(lm_id, '#ffffff')
348
- r = 5
349
- draw.ellipse([cx-r, cy-r, cx+r, cy+r], fill=col, outline='black')
350
- draw.text((cx+7, cy-8), lm_id, fill=col)
351
-
352
- return image, json.dumps(result, indent=2)
353
-
354
- def api_fn(image_b64, mime_type="image/jpeg"):
355
- """Pure JSON API β€” called by the HTML tool."""
356
- return detect_landmarks(image_b64, mime_type)
357
-
358
- with gr.Blocks(title="OrthoTimes β€” Ceph Landmark Detection") as demo:
359
- gr.Markdown("""
360
- # OrthoTimes QuickCephTool β€” Landmark Detection API
361
- **HRNet-W32** pretrained on cephalometric radiographs.
362
- Detects 19 landmarks with MRE ~1.2–1.6 mm.
363
-
364
- ### API Usage (from JavaScript):
365
- ```js
366
- const result = await fetch(
367
- "https://YOUR-SPACE.hf.space/run/api",
368
- { method:"POST", headers:{"Content-Type":"application/json"},
369
- body: JSON.stringify({ data: [base64String, "image/jpeg"] }) }
370
- );
371
- const json = await result.json();
372
- const landmarks = JSON.parse(json.data[0]).landmarks;
373
- ```
374
- """)
375
 
376
  with gr.Row():
377
- inp = gr.Image(type='pil', label='Upload lateral cephalogram')
378
  with gr.Column():
379
- out_img = gr.Image(type='pil', label='Detected landmarks')
380
- out_json = gr.Textbox(label='JSON output', lines=20)
381
 
382
- gr.Button("Detect Landmarks").click(demo_fn, inputs=inp, outputs=[out_img, out_json])
383
 
384
- # Headless API endpoint
385
  gr.Interface(
386
- fn=api_fn,
387
- inputs=[gr.Textbox(label="base64 image"), gr.Textbox(label="mime type", value="image/jpeg")],
388
  outputs=gr.Textbox(label="JSON result"),
389
- api_name="api"
390
  )
391
 
392
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
1
  import io
2
  import json
3
  import base64
4
  import numpy as np
5
+ from PIL import Image, ImageDraw
6
  import torch
7
  import torch.nn as nn
 
8
  from huggingface_hub import hf_hub_download
9
  import gradio as gr
10
 
11
+ # ── HRNet-W32 ─────────────────────────────────────────────────────────────────
 
 
12
 
13
class BasicBlock(nn.Module):
    """HRNet basic residual unit: two 3x3 convs with a shortcut connection."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # Attribute names must match the pretrained checkpoint's state-dict keys.
        self.conv1 = nn.Conv2d(inplanes, planes, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample

    def forward(self, x):
        # Shortcut is the raw input, projected when shapes differ.
        shortcut = self.downsample(x) if self.downsample else x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.relu(out + shortcut)
 
28
 
29
 
30
class Bottleneck(nn.Module):
    """HRNet bottleneck unit: 1x1 reduce, 3x3, 1x1 expand (4x) with shortcut."""

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # Attribute names must match the pretrained checkpoint's state-dict keys.
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, 3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        # Shortcut is the raw input, projected when shapes differ.
        shortcut = self.downsample(x) if self.downsample else x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        return self.relu(out + shortcut)
 
48
 
49
 
50
def make_layer(block, inplanes, planes, blocks, stride=1):
    """Stack `blocks` residual units of `block`; the first unit carries the
    stride and, when the channel count changes, a 1x1 projection shortcut."""
    out_planes = planes * block.expansion
    downsample = None
    if stride != 1 or inplanes != out_planes:
        # Project the shortcut so it matches the first unit's output shape.
        downsample = nn.Sequential(
            nn.Conv2d(inplanes, out_planes, 1, stride=stride, bias=False),
            nn.BatchNorm2d(out_planes),
        )
    units = [block(inplanes, planes, stride, downsample)]
    units.extend(block(out_planes, planes) for _ in range(blocks - 1))
    return nn.Sequential(*units)
 
 
 
 
 
 
 
 
 
61
 
62
+
63
class FuseLayer(nn.Module):
    """Cross-resolution fusion: each output branch is the ReLU of the sum of
    every input branch, resampled to that branch's resolution and width."""

    def __init__(self, num_branches, num_channels):
        super().__init__()
        self.num_branches = num_branches
        # self.fuse[i][j] maps branch j's feature map to branch i's shape.
        # Attribute name `fuse` must match the checkpoint's state-dict keys.
        fuse = []
        for i in range(num_branches):
            row = []
            for j in range(num_branches):
                if j > i:
                    # Coarser branch: 1x1 channel match, then nearest upsample.
                    row.append(nn.Sequential(
                        nn.Conv2d(num_channels[j], num_channels[i], 1, bias=False),
                        nn.BatchNorm2d(num_channels[i]),
                        nn.Upsample(scale_factor=2 ** (j - i), mode='nearest'),
                    ))
                elif j == i:
                    row.append(nn.Identity())
                else:
                    # Finer branch: one stride-2 3x3 conv per resolution step,
                    # ReLU between steps but not after the last one.
                    steps = i - j
                    convs = []
                    for k in range(steps):
                        inc = num_channels[j] if k == 0 else num_channels[i]
                        convs.append(nn.Conv2d(inc, num_channels[i], 3, stride=2, padding=1, bias=False))
                        convs.append(nn.BatchNorm2d(num_channels[i]))
                        if k < steps - 1:
                            convs.append(nn.ReLU(True))
                    row.append(nn.Sequential(*convs))
            fuse.append(nn.ModuleList(row))
        self.fuse = nn.ModuleList(fuse)
        self.relu = nn.ReLU(True)

    def forward(self, x):
        fused = []
        for i, row in enumerate(self.fuse):
            total = x[0] if i == 0 else row[0](x[0])
            for j in range(1, self.num_branches):
                total = total + (x[j] if i == j else row[j](x[j]))
            fused.append(self.relu(total))
        return fused
100
+
101
+
102
class HRStage(nn.Module):
    """One HRNet stage: a stack of residual units on every branch in parallel,
    followed by cross-branch fusion."""

    def __init__(self, num_branches, block, num_blocks, num_channels):
        super().__init__()
        # Attribute names `branches`/`fuse` must match the checkpoint's keys.
        branch_list = []
        for i in range(num_branches):
            units = [block(num_channels[i], num_channels[i]) for _ in range(num_blocks)]
            branch_list.append(nn.Sequential(*units))
        self.branches = nn.ModuleList(branch_list)
        self.fuse = FuseLayer(num_branches, num_channels)

    def forward(self, x):
        processed = [branch(feat) for branch, feat in zip(self.branches, x)]
        return self.fuse(processed)
114
 
115
 
116
class HRNet(nn.Module):
    """HRNet-W32 with a 1x1 head that emits one heatmap per landmark.

    Input:  [N, 3, H, W] image tensor.
    Output: [N, num_joints, H/4, W/4] heatmaps (the stem applies two stride-2
            convs; branch 0 keeps that resolution throughout).

    NOTE(review): attribute names (stem, layer1, trans1..3, stage2..4, head)
    must stay in sync with the checkpoint's state-dict keys — the weights are
    loaded with strict=False, so a rename would fail silently.
    """
    def __init__(self, num_joints=19):
        super().__init__()
        # Stem: two stride-2 3x3 convs -> 64 channels at 1/4 resolution.
        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, 3, stride=2, padding=1, bias=False), nn.BatchNorm2d(64), nn.ReLU(True),
            nn.Conv2d(64, 64, 3, stride=2, padding=1, bias=False), nn.BatchNorm2d(64), nn.ReLU(True),
        )
        # Four bottlenecks: 64 -> 256 channels (expansion 4).
        self.layer1 = make_layer(Bottleneck, 64, 64, 4)
        # Split the 256-channel map into two branches: 32 ch (same res) and 64 ch (1/2 res).
        self.trans1 = nn.ModuleList([
            nn.Sequential(nn.Conv2d(256, 32, 3, padding=1, bias=False), nn.BatchNorm2d(32), nn.ReLU(True)),
            nn.Sequential(nn.Conv2d(256, 64, 3, stride=2, padding=1, bias=False), nn.BatchNorm2d(64), nn.ReLU(True)),
        ])
        self.stage2 = HRStage(2, BasicBlock, 4, [32, 64])
        # Each transition derives one new, coarser branch from the current coarsest.
        self.trans2 = nn.Sequential(nn.Conv2d(64, 128, 3, stride=2, padding=1, bias=False), nn.BatchNorm2d(128), nn.ReLU(True))
        self.stage3 = nn.Sequential(*[HRStage(3, BasicBlock, 4, [32, 64, 128]) for _ in range(4)])
        self.trans3 = nn.Sequential(nn.Conv2d(128, 256, 3, stride=2, padding=1, bias=False), nn.BatchNorm2d(256), nn.ReLU(True))
        self.stage4 = nn.Sequential(*[HRStage(4, BasicBlock, 4, [32, 64, 128, 256]) for _ in range(3)])
        # 1x1 conv on the highest-resolution branch -> one heatmap per joint.
        self.head = nn.Conv2d(32, num_joints, 1)

    def forward(self, x):
        x = self.stem(x)
        x = self.layer1(x)
        x = [t(x) for t in self.trans1]            # -> 2 branches [32ch, 64ch]
        x = self.stage2(x)
        x = [x[0], x[1], self.trans2(x[1])]        # add 3rd branch (128ch) from branch 1
        for m in self.stage3:
            x = m(x)
        x = [x[0], x[1], x[2], self.trans3(x[2])]  # add 4th branch (256ch) from branch 2
        for m in self.stage4:
            x = m(x)
        return self.head(x[0])                     # heatmaps from the finest branch
147
 
148
 
149
# ── Load weights ──────────────────────────────────────────────────────────────
# Download the pretrained checkpoint from the Hub (cached across restarts)
# and load it into a freshly constructed HRNet.
print("Downloading model weights...")
model_path = hf_hub_download(
    repo_id="cwlachap/hrnet-cephalometric-landmark-detection",
    filename="best_model.pth"
)
# NOTE(review): weights_only=False unpickles arbitrary Python objects from the
# downloaded file — acceptable only because the repo id is fixed and trusted.
checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)
# The checkpoint may be a raw state dict or wrapped under a known key.
state_dict = checkpoint.get("model_state_dict", checkpoint.get("state_dict", checkpoint))
model = HRNet(num_joints=19)
# strict=False tolerates key mismatches; the counts printed below show how
# well the checkpoint keys actually lined up with this architecture.
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print(f"Loaded. Missing: {len(missing)} | Unexpected: {len(unexpected)}")
model.eval()
print("Model ready.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
 
163
# ── Constants ─────────────────────────────────────────────────────────────────
# Landmark order matches the model's 19 output channels (ISBI 2015 convention).
LM_IDS = ['S', 'N', 'Or', 'Po', 'ANS', 'PNS', 'A', 'U1tip', 'L1tip', 'B',
          'Pog', 'Me', 'Gn', 'Go', 'Co', 'L1ap', 'U1ap', 'U6', 'L6']

# Marker colours for the demo overlay, grouped by anatomical region.
LM_COLORS = {
    'S': '#58a6ff', 'N': '#58a6ff', 'Or': '#58a6ff', 'Po': '#58a6ff',
    'ANS': '#3fb950', 'PNS': '#3fb950', 'A': '#3fb950',
    'U1tip': '#3fb950', 'U1ap': '#3fb950', 'U6': '#3fb950',
    'B': '#f0883e', 'L1tip': '#f0883e', 'L1ap': '#f0883e',
    'Pog': '#f0883e', 'Me': '#f0883e', 'Gn': '#f0883e',
    'Go': '#f0883e', 'Co': '#f0883e', 'L6': '#f0883e'
}

# Model input resolution: width x height the image is resized to.
INPUT_W, INPUT_H = 256, 320

# ── Helpers ───────────────────────────────────────────────────────────────────
def preprocess(pil_img):
    """Resize to the model input size and ImageNet-normalise.

    Returns a [1, 3, INPUT_H, INPUT_W] float tensor.
    """
    img = pil_img.convert('RGB').resize((INPUT_W, INPUT_H), Image.BILINEAR)
    arr = np.array(img, dtype=np.float32) / 255.0
    arr = (arr - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]
    return torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).float()


def heatmap_to_coords(hm_np, orig_w, orig_h):
    """Decode heatmaps into normalised (0-1) landmark coordinates.

    hm_np: [num_joints, H, W] numpy array of heatmaps.
    orig_w, orig_h: original image size; kept for interface compatibility —
        the fraction-based normalisation below no longer needs them.
    Returns {landmark_id: {"x": .., "y": .., "confidence": ..}}.
    """
    coords = {}
    nj, hh, hw = hm_np.shape
    for j in range(min(nj, len(LM_IDS))):
        hm = hm_np[j]
        idx = int(hm.argmax())
        py, px = divmod(idx, hw)
        # Quarter-pixel refinement toward the stronger neighbour.
        # Both gradients must be read at the INTEGER peak before shifting:
        # mutating px first would index the y-gradient with a float (IndexError)
        # and sample the wrong column.
        if 1 <= px < hw - 1 and 1 <= py < hh - 1:
            dx = float(hm[py, px + 1] - hm[py, px - 1])
            dy = float(hm[py + 1, px] - hm[py - 1, px])
            px += 0.25 * np.sign(dx)
            py += 0.25 * np.sign(dy)
        # The heatmap fraction px/hw IS the normalised coordinate in the
        # original image (the resize preserves fractional position); scaling
        # by INPUT_W / orig_w would double-apply the resize.
        x_norm = float(np.clip(px / hw, 0.0, 1.0))
        y_norm = float(np.clip(py / hh, 0.0, 1.0))
        # Confidence is a fixed placeholder — the model exposes no calibrated score.
        coords[LM_IDS[j]] = {"x": round(x_norm, 4), "y": round(y_norm, 4), "confidence": 0.85}
    return coords
201
 
 
 
202
 
203
def run_detection(pil_img):
    """Run the model on a PIL image; return {landmark_id: {x, y, confidence}}."""
    orig_w, orig_h = pil_img.size
    batch = preprocess(pil_img)
    with torch.no_grad():
        heatmaps = model(batch)[0].numpy()
    return heatmap_to_coords(heatmaps, orig_w, orig_h)
209
 
 
 
210
 
211
# ── Gradio functions ──────────────────────────────────────────────────────────
def detect_visual(pil_img):
    """Demo handler: return the image annotated with landmarks plus the JSON."""
    if pil_img is None:
        return None, "{}"
    coords = run_detection(pil_img)
    annotated = pil_img.copy().convert("RGB")
    draw = ImageDraw.Draw(annotated)
    w, h = annotated.size
    radius = max(4, w // 120)  # marker size scales with image width
    for lm_id, pt in coords.items():
        cx = int(pt['x'] * w)
        cy = int(pt['y'] * h)
        colour = LM_COLORS.get(lm_id, '#ffffff')
        draw.ellipse([cx - radius, cy - radius, cx + radius, cy + radius],
                     fill=colour, outline='black', width=1)
        draw.text((cx + radius + 2, cy - radius), lm_id, fill=colour)
    return annotated, json.dumps({"landmarks": coords}, indent=2)
226
+
227
+
228
def detect_api(image_b64: str) -> str:
    """Headless endpoint: base64-encoded image in, landmark JSON string out."""
    try:
        raw = base64.b64decode(image_b64)
        pil_img = Image.open(io.BytesIO(raw)).convert('RGB')
        return json.dumps({"landmarks": run_detection(pil_img)})
    except Exception as e:
        # Surface failures as JSON so the calling page never sees a bare 500.
        return json.dumps({"error": str(e)})
236
+
237
+
238
# ── UI ────────────────────────────────────────────────────────────────────────
# Visual demo (image in, annotated image + JSON out) plus a headless textbox
# endpoint that the external HTML tool calls with a base64 payload.
with gr.Blocks(title="OrthoTimes Landmark Detection") as demo:
    gr.Markdown("## OrthoTimes QuickCephTool — HRNet Landmark Detection\nDetects 19 cephalometric landmarks automatically.")

    with gr.Row():
        img_in = gr.Image(type="pil", label="Upload lateral cephalogram")
        with gr.Column():
            img_out = gr.Image(type="pil", label="Detected landmarks")
            json_out = gr.Textbox(label="JSON output", lines=12)

    gr.Button("▶ Detect Landmarks").click(fn=detect_visual, inputs=img_in, outputs=[img_out, json_out])

    # Headless API endpoint — called by the HTML tool
    # NOTE(review): nesting gr.Interface inside gr.Blocks and passing api_name
    # to the constructor looks version-sensitive — confirm against the pinned
    # gradio release that /api/detect is actually exposed.
    gr.Interface(
        fn=detect_api,
        inputs=gr.Textbox(label="Base64 image"),
        outputs=gr.Textbox(label="JSON result"),
        api_name="detect"
    )

demo.launch()