Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,45 +1,30 @@
|
|
| 1 |
-
"""
|
| 2 |
-
OrthoTimes QuickCephTool β HRNet Landmark Detection API
|
| 3 |
-
Hugging Face Space backend.
|
| 4 |
-
|
| 5 |
-
Loads cwlachap/hrnet-cephalometric-landmark-detection, runs inference,
|
| 6 |
-
and returns 19 landmark coordinates as normalised (0-1) JSON.
|
| 7 |
-
|
| 8 |
-
The /detect endpoint is called by the HTML tool with a base64-encoded image.
|
| 9 |
-
"""
|
| 10 |
-
|
| 11 |
import io
|
| 12 |
import json
|
| 13 |
import base64
|
| 14 |
import numpy as np
|
| 15 |
-
from PIL import Image
|
| 16 |
import torch
|
| 17 |
import torch.nn as nn
|
| 18 |
-
import torch.nn.functional as F
|
| 19 |
from huggingface_hub import hf_hub_download
|
| 20 |
import gradio as gr
|
| 21 |
|
| 22 |
-
# ββ
|
| 23 |
-
# Minimal re-implementation matching the checkpoint structure.
|
| 24 |
-
# Adapted from the official HRNet codebase.
|
| 25 |
|
| 26 |
class BasicBlock(nn.Module):
|
| 27 |
expansion = 1
|
| 28 |
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
| 29 |
super().__init__()
|
| 30 |
self.conv1 = nn.Conv2d(inplanes, planes, 3, stride=stride, padding=1, bias=False)
|
| 31 |
-
self.bn1
|
| 32 |
-
self.relu
|
| 33 |
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
|
| 34 |
-
self.bn2
|
| 35 |
self.downsample = downsample
|
| 36 |
|
| 37 |
def forward(self, x):
|
| 38 |
-
residual = x
|
| 39 |
out = self.relu(self.bn1(self.conv1(x)))
|
| 40 |
out = self.bn2(self.conv2(out))
|
| 41 |
-
|
| 42 |
-
return self.relu(out + residual)
|
| 43 |
|
| 44 |
|
| 45 |
class Bottleneck(nn.Module):
|
|
@@ -47,346 +32,227 @@ class Bottleneck(nn.Module):
|
|
| 47 |
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
| 48 |
super().__init__()
|
| 49 |
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
|
| 50 |
-
self.bn1
|
| 51 |
self.conv2 = nn.Conv2d(planes, planes, 3, stride=stride, padding=1, bias=False)
|
| 52 |
-
self.bn2
|
| 53 |
self.conv3 = nn.Conv2d(planes, planes * 4, 1, bias=False)
|
| 54 |
-
self.bn3
|
| 55 |
-
self.relu
|
| 56 |
self.downsample = downsample
|
| 57 |
|
| 58 |
def forward(self, x):
|
| 59 |
-
residual = x
|
| 60 |
out = self.relu(self.bn1(self.conv1(x)))
|
| 61 |
out = self.relu(self.bn2(self.conv2(out)))
|
| 62 |
out = self.bn3(self.conv3(out))
|
| 63 |
-
|
| 64 |
-
return self.relu(out + residual)
|
| 65 |
|
| 66 |
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
for i in range(num_blocks[branch_index]):
|
| 79 |
-
layers.append(block(num_channels[branch_index], num_channels[branch_index]))
|
| 80 |
-
return nn.Sequential(*layers)
|
| 81 |
-
|
| 82 |
-
def _make_branches(self, num_branches, block, num_blocks, num_channels):
|
| 83 |
-
return nn.ModuleList([
|
| 84 |
-
self._make_one_branch(i, block, num_blocks, num_channels)
|
| 85 |
-
for i in range(num_branches)
|
| 86 |
-
])
|
| 87 |
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
if j > i:
|
| 94 |
-
|
| 95 |
-
nn.Conv2d(
|
| 96 |
-
nn.BatchNorm2d(
|
| 97 |
-
nn.Upsample(scale_factor=2**(j-i), mode='nearest')
|
| 98 |
))
|
| 99 |
elif j == i:
|
| 100 |
-
|
| 101 |
else:
|
| 102 |
-
|
| 103 |
for k in range(i - j):
|
| 104 |
-
if k ==
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
))
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
nn.ReLU(True)
|
| 114 |
-
))
|
| 115 |
-
fuse_layer.append(nn.Sequential(*conv3x3s))
|
| 116 |
-
fuse_layers.append(nn.ModuleList(fuse_layer))
|
| 117 |
-
return nn.ModuleList(fuse_layers)
|
| 118 |
|
| 119 |
def forward(self, x):
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
for i in range(len(self.fuse_layers)):
|
| 124 |
-
y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
|
| 125 |
for j in range(1, self.num_branches):
|
| 126 |
-
if i == j
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
|
| 136 |
class HRNet(nn.Module):
|
| 137 |
-
"""HRNet-W32 for heatmap-based landmark detection."""
|
| 138 |
def __init__(self, num_joints=19):
|
| 139 |
super().__init__()
|
| 140 |
-
self.
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
self.
|
| 145 |
-
self.
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
self.layer1 = self._make_layer(Bottleneck, 64, 64, 4)
|
| 149 |
-
# Transition1 (256 β [32, 64])
|
| 150 |
-
self.transition1 = nn.ModuleList([
|
| 151 |
-
nn.Sequential(nn.Conv2d(256,32,3,padding=1,bias=False), nn.BatchNorm2d(32,momentum=0.1), nn.ReLU(True)),
|
| 152 |
-
nn.Sequential(nn.Sequential(nn.Conv2d(256,64,3,stride=2,padding=1,bias=False), nn.BatchNorm2d(64,momentum=0.1), nn.ReLU(True)))
|
| 153 |
-
])
|
| 154 |
-
# Stage2
|
| 155 |
-
self.stage2 = nn.Sequential(HRModule(2, BasicBlock, [4,4], [32,64], [32,64]))
|
| 156 |
-
# Transition2
|
| 157 |
-
self.transition2 = nn.ModuleList([
|
| 158 |
-
None, None,
|
| 159 |
-
nn.Sequential(nn.Conv2d(64,128,3,stride=2,padding=1,bias=False), nn.BatchNorm2d(128,momentum=0.1), nn.ReLU(True))
|
| 160 |
-
])
|
| 161 |
-
# Stage3
|
| 162 |
-
self.stage3 = nn.Sequential(*[HRModule(3, BasicBlock, [4,4,4], [32,64,128], [32,64,128]) for _ in range(4)])
|
| 163 |
-
# Transition3
|
| 164 |
-
self.transition3 = nn.ModuleList([
|
| 165 |
-
None, None, None,
|
| 166 |
-
nn.Sequential(nn.Conv2d(128,256,3,stride=2,padding=1,bias=False), nn.BatchNorm2d(256,momentum=0.1), nn.ReLU(True))
|
| 167 |
])
|
| 168 |
-
|
| 169 |
-
self.
|
| 170 |
-
|
| 171 |
-
self.
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
downsample = None
|
| 175 |
-
if stride != 1 or inplanes != planes * block.expansion:
|
| 176 |
-
downsample = nn.Sequential(
|
| 177 |
-
nn.Conv2d(inplanes, planes*block.expansion, 1, stride=stride, bias=False),
|
| 178 |
-
nn.BatchNorm2d(planes*block.expansion, momentum=0.1)
|
| 179 |
-
)
|
| 180 |
-
layers = [block(inplanes, planes, stride, downsample)]
|
| 181 |
-
for _ in range(1, blocks):
|
| 182 |
-
layers.append(block(planes*block.expansion, planes))
|
| 183 |
-
return nn.Sequential(*layers)
|
| 184 |
|
| 185 |
def forward(self, x):
|
| 186 |
-
x = self.
|
| 187 |
-
x = self.relu(self.bn2(self.conv2(x)))
|
| 188 |
x = self.layer1(x)
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
for i, t in enumerate(self.transition2):
|
| 193 |
-
if t is None:
|
| 194 |
-
xl2.append(xl[i] if i < len(xl) else xl[-1])
|
| 195 |
-
else:
|
| 196 |
-
xl2.append(t(xl[-1]))
|
| 197 |
-
xl = xl2
|
| 198 |
for m in self.stage3:
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
for i, t in enumerate(self.transition3):
|
| 202 |
-
if t is None:
|
| 203 |
-
xl3.append(xl[i] if i < len(xl) else xl[-1])
|
| 204 |
-
else:
|
| 205 |
-
xl3.append(t(xl[-1]))
|
| 206 |
-
xl = xl3
|
| 207 |
for m in self.stage4:
|
| 208 |
-
|
| 209 |
-
return self.
|
| 210 |
|
| 211 |
|
| 212 |
-
# ββ
|
| 213 |
print("Downloading model weights...")
|
| 214 |
model_path = hf_hub_download(
|
| 215 |
repo_id="cwlachap/hrnet-cephalometric-landmark-detection",
|
| 216 |
filename="best_model.pth"
|
| 217 |
)
|
|
|
|
|
|
|
| 218 |
model = HRNet(num_joints=19)
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
model.load_state_dict(state, strict=False)
|
| 222 |
model.eval()
|
| 223 |
-
print("Model
|
| 224 |
-
|
| 225 |
-
# βββ Landmark mapping ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 226 |
-
# HRNet model outputs 19 landmarks in this order (ISBI 2015 dataset standard)
|
| 227 |
-
HRNET_LANDMARKS = [
|
| 228 |
-
'S', # 0 Sella
|
| 229 |
-
'N', # 1 Nasion
|
| 230 |
-
'Or', # 2 Orbitale
|
| 231 |
-
'Po', # 3 Porion
|
| 232 |
-
'ANS', # 4 ANS
|
| 233 |
-
'PNS', # 5 PNS
|
| 234 |
-
'A', # 6 Point A
|
| 235 |
-
'U1tip', # 7 Upper incisor tip
|
| 236 |
-
'L1tip', # 8 Lower incisor tip
|
| 237 |
-
'B', # 9 Point B
|
| 238 |
-
'Pog', # 10 Pogonion
|
| 239 |
-
'Me', # 11 Menton
|
| 240 |
-
'Gn', # 12 Gnathion
|
| 241 |
-
'Go', # 13 Gonion
|
| 242 |
-
'Co', # 14 Condylion (some datasets use Ar here)
|
| 243 |
-
'L1ap', # 15 Lower incisor apex
|
| 244 |
-
'U1ap', # 16 Upper incisor apex
|
| 245 |
-
'U6', # 17 Upper molar
|
| 246 |
-
'L6', # 18 Lower molar
|
| 247 |
-
]
|
| 248 |
-
|
| 249 |
-
# βββ Preprocessing βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 250 |
-
INPUT_W, INPUT_H = 256, 320 # model input size
|
| 251 |
-
|
| 252 |
-
def preprocess(img_pil):
|
| 253 |
-
"""Convert PIL image β normalised tensor [1,3,H,W]."""
|
| 254 |
-
img = img_pil.convert('RGB').resize((INPUT_W, INPUT_H), Image.BILINEAR)
|
| 255 |
-
arr = np.array(img, dtype=np.float32) / 255.0
|
| 256 |
-
mean = np.array([0.485, 0.456, 0.406])
|
| 257 |
-
std = np.array([0.229, 0.224, 0.225])
|
| 258 |
-
arr = (arr - mean) / std
|
| 259 |
-
tensor = torch.from_numpy(arr).permute(2,0,1).unsqueeze(0).float()
|
| 260 |
-
return tensor
|
| 261 |
-
|
| 262 |
-
def heatmap_to_coords(heatmaps, orig_w, orig_h):
|
| 263 |
-
"""
|
| 264 |
-
heatmaps: [1, num_joints, H, W]
|
| 265 |
-
Returns list of (x_norm, y_norm) tuples in original image space.
|
| 266 |
-
"""
|
| 267 |
-
hm = heatmaps[0] # [num_joints, H, W]
|
| 268 |
-
num_joints, hm_h, hm_w = hm.shape
|
| 269 |
-
coords = []
|
| 270 |
-
for j in range(num_joints):
|
| 271 |
-
flat = hm[j].reshape(-1)
|
| 272 |
-
idx = int(flat.argmax())
|
| 273 |
-
py = idx // hm_w
|
| 274 |
-
px = idx % hm_w
|
| 275 |
-
# Sub-pixel refinement: nudge toward neighbouring maxima
|
| 276 |
-
if 1 <= px < hm_w-1 and 1 <= py < hm_h-1:
|
| 277 |
-
dx = float(hm[j, py, px+1] - hm[j, py, px-1])
|
| 278 |
-
dy = float(hm[j, py+1, px] - hm[j, py-1, px])
|
| 279 |
-
px += 0.25 * np.sign(dx)
|
| 280 |
-
py += 0.25 * np.sign(dy)
|
| 281 |
-
# Normalise to 0-1 in original image space
|
| 282 |
-
x_norm = (px / hm_w) * (INPUT_W / orig_w)
|
| 283 |
-
y_norm = (py / hm_h) * (INPUT_H / orig_h)
|
| 284 |
-
# Clamp
|
| 285 |
-
x_norm = float(np.clip(x_norm, 0.0, 1.0))
|
| 286 |
-
y_norm = float(np.clip(y_norm, 0.0, 1.0))
|
| 287 |
-
coords.append((x_norm, y_norm))
|
| 288 |
-
return coords
|
| 289 |
|
| 290 |
-
# ββ
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
Accepts base64-encoded image string.
|
| 294 |
-
Returns JSON: {"landmarks": {"S": {"x":..,"y":..}, ...}}
|
| 295 |
-
"""
|
| 296 |
-
try:
|
| 297 |
-
img_bytes = base64.b64decode(image_b64)
|
| 298 |
-
img_pil = Image.open(io.BytesIO(img_bytes)).convert('RGB')
|
| 299 |
-
orig_w, orig_h = img_pil.size
|
| 300 |
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 304 |
|
| 305 |
-
|
| 306 |
|
| 307 |
-
result = {}
|
| 308 |
-
for i, lm_id in enumerate(HRNET_LANDMARKS):
|
| 309 |
-
if i < len(coords):
|
| 310 |
-
result[lm_id] = {"x": round(coords[i][0], 4),
|
| 311 |
-
"y": round(coords[i][1], 4),
|
| 312 |
-
"confidence": 0.85}
|
| 313 |
|
| 314 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 315 |
|
| 316 |
-
except Exception as e:
|
| 317 |
-
return json.dumps({"error": str(e)})
|
| 318 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 319 |
|
| 320 |
-
# βββ Gradio interface βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 321 |
-
# We expose both a visual demo AND a pure API endpoint
|
| 322 |
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
if
|
| 326 |
return None, "{}"
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
def api_fn(image_b64, mime_type="image/jpeg"):
|
| 355 |
-
"""Pure JSON API β called by the HTML tool."""
|
| 356 |
-
return detect_landmarks(image_b64, mime_type)
|
| 357 |
-
|
| 358 |
-
with gr.Blocks(title="OrthoTimes β Ceph Landmark Detection") as demo:
|
| 359 |
-
gr.Markdown("""
|
| 360 |
-
# OrthoTimes QuickCephTool β Landmark Detection API
|
| 361 |
-
**HRNet-W32** pretrained on cephalometric radiographs.
|
| 362 |
-
Detects 19 landmarks with MRE ~1.2β1.6 mm.
|
| 363 |
-
|
| 364 |
-
### API Usage (from JavaScript):
|
| 365 |
-
```js
|
| 366 |
-
const result = await fetch(
|
| 367 |
-
"https://YOUR-SPACE.hf.space/run/api",
|
| 368 |
-
{ method:"POST", headers:{"Content-Type":"application/json"},
|
| 369 |
-
body: JSON.stringify({ data: [base64String, "image/jpeg"] }) }
|
| 370 |
-
);
|
| 371 |
-
const json = await result.json();
|
| 372 |
-
const landmarks = JSON.parse(json.data[0]).landmarks;
|
| 373 |
-
```
|
| 374 |
-
""")
|
| 375 |
|
| 376 |
with gr.Row():
|
| 377 |
-
|
| 378 |
with gr.Column():
|
| 379 |
-
|
| 380 |
-
|
| 381 |
|
| 382 |
-
gr.Button("Detect Landmarks").click(
|
| 383 |
|
| 384 |
-
# Headless API endpoint
|
| 385 |
gr.Interface(
|
| 386 |
-
fn=
|
| 387 |
-
inputs=
|
| 388 |
outputs=gr.Textbox(label="JSON result"),
|
| 389 |
-
api_name="
|
| 390 |
)
|
| 391 |
|
| 392 |
-
demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import io
|
| 2 |
import json
|
| 3 |
import base64
|
| 4 |
import numpy as np
|
| 5 |
+
from PIL import Image, ImageDraw
|
| 6 |
import torch
|
| 7 |
import torch.nn as nn
|
|
|
|
| 8 |
from huggingface_hub import hf_hub_download
|
| 9 |
import gradio as gr
|
| 10 |
|
| 11 |
+
# ── HRNet-W32 ─────────────────────────────────────────────────────────────────
|
|
|
|
|
|
|
| 12 |
|
| 13 |
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions, each followed by batch norm.

    Args:
        inplanes: Number of input channels.
        planes: Number of output channels.
        stride: Stride of the first convolution.
        downsample: Optional module applied to the block input to build the
            shortcut branch. When ``None`` the input itself is the shortcut.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut)``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Compare against None explicitly: container modules such as
        # nn.Sequential define __len__, so a truthiness test would silently
        # treat an empty container as "no downsample".
        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(out + shortcut)
|
|
|
|
| 28 |
|
| 29 |
|
| 30 |
class Bottleneck(nn.Module):
|
|
|
|
| 32 |
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
| 33 |
super().__init__()
|
| 34 |
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
|
| 35 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
| 36 |
self.conv2 = nn.Conv2d(planes, planes, 3, stride=stride, padding=1, bias=False)
|
| 37 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
| 38 |
self.conv3 = nn.Conv2d(planes, planes * 4, 1, bias=False)
|
| 39 |
+
self.bn3 = nn.BatchNorm2d(planes * 4)
|
| 40 |
+
self.relu = nn.ReLU(inplace=True)
|
| 41 |
self.downsample = downsample
|
| 42 |
|
| 43 |
def forward(self, x):
    """Run the three conv/BN stages and add the shortcut before the final ReLU.

    Args:
        x: Input feature map.

    Returns:
        ``relu(bn3(conv3(...)) + shortcut)`` where the shortcut is
        ``self.downsample(x)`` when a downsample module was supplied and
        ``x`` otherwise.
    """
    out = self.relu(self.bn1(self.conv1(x)))
    out = self.relu(self.bn2(self.conv2(out)))
    out = self.bn3(self.conv3(out))
    # Compare against None explicitly: container modules such as
    # nn.Sequential define __len__, so a truthiness test would silently
    # treat an empty container as "no downsample".
    shortcut = x if self.downsample is None else self.downsample(x)
    return self.relu(out + shortcut)
|
|
|
|
| 48 |
|
| 49 |
|
| 50 |
+
def make_layer(block, inplanes, planes, blocks, stride=1):
    """Stack ``blocks`` residual blocks of type ``block`` into one Sequential.

    Args:
        block: Block class exposing an ``expansion`` attribute and accepting
            ``(inplanes, planes, stride, downsample)``.
        inplanes: Channels entering the first block.
        planes: Base channel width handed to every block.
        blocks: Number of blocks in the stack.
        stride: Stride of the first block; later blocks use the default.

    Returns:
        ``nn.Sequential`` holding the blocks in order.
    """
    out_planes = planes * block.expansion

    # The first block needs a projected shortcut whenever its output differs
    # from its input in resolution (stride) or in channel count.
    shortcut = None
    if not (stride == 1 and inplanes == out_planes):
        shortcut = nn.Sequential(
            nn.Conv2d(inplanes, out_planes, 1, stride=stride, bias=False),
            nn.BatchNorm2d(out_planes),
        )

    stack = [block(inplanes, planes, stride, shortcut)]
    stack.extend(block(out_planes, planes) for _ in range(1, blocks))
    return nn.Sequential(*stack)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
+
|
| 63 |
+
class FuseLayer(nn.Module):
    """Exchange information between parallel branches.

    ``self.fuse[i][j]`` converts the features of branch ``j`` so they can be
    summed into branch ``i``:

    * ``j > i``: 1x1 conv + BN, then nearest-neighbour upsampling by
      ``2 ** (j - i)``.
    * ``j == i``: identity.
    * ``j < i``: ``i - j`` stride-2 3x3 conv + BN steps, with a ReLU after
      every step except the last.

    Args:
        num_branches: Number of parallel branches.
        num_channels: Channel count of each branch, indexed by branch.
    """

    def __init__(self, num_branches, num_channels):
        super().__init__()
        self.num_branches = num_branches
        grid = []
        for dst in range(num_branches):
            links = [self._link(src, dst, num_channels) for src in range(num_branches)]
            grid.append(nn.ModuleList(links))
        self.fuse = nn.ModuleList(grid)
        self.relu = nn.ReLU(True)

    @staticmethod
    def _link(src, dst, num_channels):
        """Build the module that maps branch ``src`` features onto branch ``dst``."""
        c_src, c_dst = num_channels[src], num_channels[dst]

        if src == dst:
            return nn.Identity()

        if src > dst:
            return nn.Sequential(
                nn.Conv2d(c_src, c_dst, 1, bias=False),
                nn.BatchNorm2d(c_dst),
                nn.Upsample(scale_factor=2 ** (src - dst), mode='nearest'),
            )

        steps = dst - src
        layers = []
        for step in range(steps):
            # Only the first step sees the source width; every step already
            # outputs the destination width.
            c_in = c_src if step == 0 else c_dst
            layers.append(nn.Conv2d(c_in, c_dst, 3, stride=2, padding=1, bias=False))
            layers.append(nn.BatchNorm2d(c_dst))
            if step < steps - 1:
                layers.append(nn.ReLU(True))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return one fused, ReLU-activated tensor per branch.

        Args:
            x: Sequence of per-branch feature maps, indexed by branch.
        """
        fused = []
        for dst in range(self.num_branches):
            links = self.fuse[dst]
            total = x[0] if dst == 0 else links[0](x[0])
            for src in range(1, self.num_branches):
                term = x[src] if src == dst else links[src](x[src])
                total = total + term
            fused.append(self.relu(total))
        return fused
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class HRStage(nn.Module):
    """One multi-branch stage: independent per-branch blocks, then fusion.

    Args:
        num_branches: Number of parallel branches.
        block: Residual block class, called as ``block(channels, channels)``.
        num_blocks: Number of blocks stacked in every branch.
        num_channels: Channel count of each branch, indexed by branch.
    """

    def __init__(self, num_branches, block, num_blocks, num_channels):
        super().__init__()
        towers = []
        for idx in range(num_branches):
            width = num_channels[idx]
            towers.append(nn.Sequential(*[block(width, width) for _ in range(num_blocks)]))
        self.branches = nn.ModuleList(towers)
        self.fuse = FuseLayer(num_branches, num_channels)

    def forward(self, x):
        """Run each branch on its own input, then fuse across branches.

        Args:
            x: Sequence of per-branch feature maps, paired with
                ``self.branches`` by position.

        Returns:
            The list produced by ``self.fuse``.
        """
        per_branch = []
        for tower, feat in zip(self.branches, x):
            per_branch.append(tower(feat))
        return self.fuse(per_branch)
|
| 114 |
|
| 115 |
|
| 116 |
class HRNet(nn.Module):
    """HRNet-W32-style backbone with a 1x1 heatmap head.

    Branch widths are 32 / 64 / 128 / 256. The stem applies two stride-2
    convolutions, so heatmaps come out at one quarter of the input
    resolution. Only the highest-resolution branch feeds the head.

    Args:
        num_joints: Number of output heatmap channels.
    """

    def __init__(self, num_joints=19):
        super().__init__()
        self.stem = nn.Sequential(
            *self._conv_bn_relu(3, 64, stride=2),
            *self._conv_bn_relu(64, 64, stride=2),
        )
        self.layer1 = make_layer(Bottleneck, 64, 64, 4)
        # Split the 256-channel layer1 output into the first two branches.
        self.trans1 = nn.ModuleList([
            nn.Sequential(*self._conv_bn_relu(256, 32)),
            nn.Sequential(*self._conv_bn_relu(256, 64, stride=2)),
        ])
        self.stage2 = HRStage(2, BasicBlock, 4, [32, 64])
        self.trans2 = nn.Sequential(*self._conv_bn_relu(64, 128, stride=2))
        self.stage3 = nn.Sequential(*[HRStage(3, BasicBlock, 4, [32, 64, 128]) for _ in range(4)])
        self.trans3 = nn.Sequential(*self._conv_bn_relu(128, 256, stride=2))
        self.stage4 = nn.Sequential(*[HRStage(4, BasicBlock, 4, [32, 64, 128, 256]) for _ in range(3)])
        self.head = nn.Conv2d(32, num_joints, 1)

    @staticmethod
    def _conv_bn_relu(c_in, c_out, stride=1):
        """Return ``[3x3 conv (no bias), BatchNorm, in-place ReLU]`` as a list."""
        return [
            nn.Conv2d(c_in, c_out, 3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(c_out),
            nn.ReLU(True),
        ]

    def forward(self, x):
        """Return the heatmaps computed from the highest-resolution branch."""
        feats = self.layer1(self.stem(x))

        pyramid = [split(feats) for split in self.trans1]
        pyramid = self.stage2(pyramid)

        # Each new, coarser branch is derived from the current coarsest one.
        pyramid = pyramid[:2] + [self.trans2(pyramid[1])]
        for module in self.stage3:
            pyramid = module(pyramid)

        pyramid = pyramid[:3] + [self.trans3(pyramid[2])]
        for module in self.stage4:
            pyramid = module(pyramid)

        return self.head(pyramid[0])
|
| 147 |
|
| 148 |
|
| 149 |
+
# ── Load weights ──────────────────────────────────────────────────────────────
|
| 150 |
# Fetch the checkpoint from the Hugging Face Hub and load it into the model.
print("Downloading model weights...")
model_path = hf_hub_download(
    repo_id="cwlachap/hrnet-cephalometric-landmark-detection",
    filename="best_model.pth"
)
# NOTE(security): weights_only=False makes torch.load unpickle arbitrary
# Python objects, i.e. the downloaded file can execute code. Only acceptable
# while the Hub repository is trusted; switch to weights_only=True if the
# checkpoint contains nothing but tensors.
checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)
# The weights may sit under "model_state_dict" or "state_dict", or the
# checkpoint may itself be the state dict.
state_dict = checkpoint.get("model_state_dict", checkpoint.get("state_dict", checkpoint))
model = HRNet(num_joints=19)
# strict=False tolerates key mismatches, so report them instead of hiding them.
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print(f"Loaded. Missing: {len(missing)} | Unexpected: {len(unexpected)}")
if missing:
    # Parameters absent from the checkpoint keep their random initialisation,
    # which makes the predictions meaningless for the affected layers.
    print(f"WARNING: {len(missing)} model parameters were not found in the "
          f"checkpoint and remain randomly initialised, e.g. {list(missing)[:5]}")
if unexpected:
    print(f"WARNING: {len(unexpected)} checkpoint entries were ignored, "
          f"e.g. {list(unexpected)[:5]}")
model.eval()
print("Model ready.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
|
| 163 |
+
# ── Constants ─────────────────────────────────────────────────────────────────
|
| 164 |
+
# Landmark identifiers; heatmap_to_coords labels heatmap channel j as LM_IDS[j].
LM_IDS = [
    'S', 'N', 'Or', 'Po', 'ANS',
    'PNS', 'A', 'U1tip', 'L1tip', 'B',
    'Pog', 'Me', 'Gn', 'Go', 'Co',
    'L1ap', 'U1ap', 'U6', 'L6',
]

# Marker colour per landmark, used when drawing the overlay. Landmarks are
# grouped into three colours.
LM_COLORS = {
    **dict.fromkeys(('S', 'N', 'Or', 'Po'), '#58a6ff'),
    **dict.fromkeys(('ANS', 'PNS', 'A', 'U1tip', 'U1ap', 'U6'), '#3fb950'),
    **dict.fromkeys(('B', 'L1tip', 'L1ap', 'Pog', 'Me', 'Gn', 'Go', 'Co', 'L6'), '#f0883e'),
}

# Size (width, height) every image is resized to before inference.
INPUT_W = 256
INPUT_H = 320
|
| 177 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
|
| 179 |
+
# ── Helpers ───────────────────────────────────────────────────────────────────
|
| 180 |
+
def preprocess(pil_img):
    """Turn a PIL image into a normalised float tensor for the model.

    The image is converted to RGB, resized to ``(INPUT_W, INPUT_H)`` with
    bilinear resampling, scaled to the 0-1 range and normalised per channel.

    Args:
        pil_img: Source PIL image.

    Returns:
        Float tensor with a leading batch axis of size 1 and channels first.
    """
    resized = pil_img.convert('RGB').resize((INPUT_W, INPUT_H), Image.BILINEAR)
    pixels = np.array(resized, dtype=np.float32) / 255.0
    # Per-channel mean / std (presumably the ImageNet statistics the
    # checkpoint was trained with; confirm against the training code).
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    normalised = (pixels - mean) / std
    # Channels-last array -> channels-first, plus a batch axis.
    channels_first = torch.from_numpy(normalised).permute(2, 0, 1)
    return channels_first.unsqueeze(0).float()
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def heatmap_to_coords(hm_np, orig_w, orig_h, lm_ids=None):
    """Convert per-landmark heatmaps into normalised (0-1) coordinates.

    For every heatmap the peak location is taken and, when the peak is not on
    the border, nudged a quarter of a pixel towards the larger neighbour on
    each axis. The position is then divided by the heatmap size.

    Args:
        hm_np: Array of shape ``(num_joints, height, width)``.
        orig_w: Width of the original image. Kept for backward compatibility;
            normalised coordinates do not depend on it.
        orig_h: Height of the original image. Kept for backward compatibility;
            normalised coordinates do not depend on it.
        lm_ids: Landmark names, one per heatmap channel. Defaults to the
            module-level ``LM_IDS``.

    Returns:
        Dict mapping landmark name to ``{"x", "y", "confidence"}`` with ``x``
        and ``y`` rounded to four decimals. Only as many landmarks as there
        are both names and heatmaps are returned.
    """
    if lm_ids is None:
        lm_ids = LM_IDS

    _, hh, hw = hm_np.shape
    coords = {}
    for name, hm in zip(lm_ids, hm_np):
        py, px = divmod(int(hm.argmax()), hw)

        # Keep the integer peak indices for array access and refine separate
        # float copies. Reusing a float as an index raises IndexError.
        x, y = float(px), float(py)
        if 1 <= px < hw - 1 and 1 <= py < hh - 1:
            x += 0.25 * float(np.sign(hm[py, px + 1] - hm[py, px - 1]))
            y += 0.25 * float(np.sign(hm[py + 1, px] - hm[py - 1, px]))

        # The whole image was stretched onto the model input, so the fraction
        # across the heatmap already equals the fraction across the original
        # image. No rescaling by the original size is needed.
        x_norm = float(np.clip(x / hw, 0.0, 1.0))
        y_norm = float(np.clip(y / hh, 0.0, 1.0))

        # NOTE: confidence is a fixed placeholder, not derived from the heatmap.
        coords[name] = {"x": round(x_norm, 4), "y": round(y_norm, 4), "confidence": 0.85}
    return coords
|
| 201 |
|
|
|
|
|
|
|
| 202 |
|
| 203 |
+
def run_detection(pil_img):
    """Run the model on a PIL image and return the landmark dictionary.

    Args:
        pil_img: Source PIL image.

    Returns:
        The result of ``heatmap_to_coords`` for the first item of the batch.
    """
    width, height = pil_img.size
    batch = preprocess(pil_img)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        heatmaps = model(batch)[0].numpy()
    return heatmap_to_coords(heatmaps, width, height)
|
| 209 |
|
|
|
|
|
|
|
| 210 |
|
| 211 |
+
# ── Gradio functions ──────────────────────────────────────────────────────────
|
| 212 |
+
def detect_visual(pil_img):
    """Detect landmarks and draw them on a copy of the image.

    Args:
        pil_img: Uploaded PIL image, or ``None`` when nothing was uploaded.

    Returns:
        ``(annotated_image, json_text)``. For a ``None`` input the result is
        ``(None, "{}")``.
    """
    if pil_img is None:
        return None, "{}"

    landmarks = run_detection(pil_img)

    canvas = pil_img.copy().convert("RGB")
    pen = ImageDraw.Draw(canvas)
    width, height = canvas.size
    # Marker radius grows with image width but never drops below 4 px.
    radius = max(4, width // 120)

    for name, point in landmarks.items():
        # Landmarks are normalised, so scale them back to pixel positions.
        x_px = int(point['x'] * width)
        y_px = int(point['y'] * height)
        colour = LM_COLORS.get(name, '#ffffff')
        box = [x_px - radius, y_px - radius, x_px + radius, y_px + radius]
        pen.ellipse(box, fill=colour, outline='black', width=1)
        pen.text((x_px + radius + 2, y_px - radius), name, fill=colour)

    report = json.dumps({"landmarks": landmarks}, indent=2)
    return canvas, report
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def detect_api(image_b64: str) -> str:
    """Headless endpoint: decode a base64 image and return landmarks as JSON.

    Args:
        image_b64: Base64-encoded image. A ``data:<mime>;base64,`` URL prefix,
            as produced by browser ``FileReader.readAsDataURL``, is accepted
            and stripped.

    Returns:
        JSON text: ``{"landmarks": {...}}`` on success, ``{"error": "..."}``
        on any failure. The function never raises.
    """
    try:
        payload = image_b64.strip() if isinstance(image_b64, str) else ""
        # Without this, the "data:image/...;base64," header would be decoded
        # as if it were image data and corrupt the image bytes.
        if payload.startswith("data:"):
            payload = payload.partition(",")[2].strip()
        if not payload:
            raise ValueError("empty image payload")

        img_bytes = base64.b64decode(payload)
        pil_img = Image.open(io.BytesIO(img_bytes)).convert('RGB')
        coords = run_detection(pil_img)
        return json.dumps({"landmarks": coords})
    except Exception as e:
        # API boundary: report every failure to the caller as JSON.
        return json.dumps({"error": str(e)})
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
# ── UI ────────────────────────────────────────────────────────────────────────
|
| 239 |
+
with gr.Blocks(title="OrthoTimes Landmark Detection") as demo:
|
| 240 |
+
gr.Markdown("## OrthoTimes QuickCephTool β HRNet Landmark Detection\nDetects 19 cephalometric landmarks automatically.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 241 |
|
| 242 |
with gr.Row():
|
| 243 |
+
img_in = gr.Image(type="pil", label="Upload lateral cephalogram")
|
| 244 |
with gr.Column():
|
| 245 |
+
img_out = gr.Image(type="pil", label="Detected landmarks")
|
| 246 |
+
json_out = gr.Textbox(label="JSON output", lines=12)
|
| 247 |
|
| 248 |
+
gr.Button("βΆ Detect Landmarks").click(fn=detect_visual, inputs=img_in, outputs=[img_out, json_out])
|
| 249 |
|
| 250 |
+
# Headless API endpoint — called by the HTML tool
|
| 251 |
gr.Interface(
|
| 252 |
+
fn=detect_api,
|
| 253 |
+
inputs=gr.Textbox(label="Base64 image"),
|
| 254 |
outputs=gr.Textbox(label="JSON result"),
|
| 255 |
+
api_name="detect"
|
| 256 |
)
|
| 257 |
|
| 258 |
+
demo.launch()
|