# Patch missing audioop module (removed in Python 3.13) before importing gradio.
# gradio imports `audioop` indirectly (presumably through pydub). Only install
# an empty stub when the real module cannot be imported (stdlib <= 3.12, or an
# installed backport), so a working audioop is never shadowed. The stub has no
# functions; this app only processes images.
import sys
import types

try:
    import audioop  # noqa: F401
except ImportError:
    sys.modules['audioop'] = types.ModuleType('audioop')

import io
import json
import base64

import numpy as np
from PIL import Image, ImageDraw
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
import gradio as gr


# ── HRNet-W32 ─────────────────────────────────────────────────────────────────
# NOTE: attribute names and the order of layers inside every nn.Sequential
# make up the state_dict keys (e.g. "stem.0.weight",
# "stage2.fuse.fuse.1.0.0.weight"). They must match the checkpoint, so do not
# rename or reorder them.

class BasicBlock(nn.Module):
    """Residual block of two 3x3 convs; outputs `planes` channels."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        # Optional projection applied to the shortcut when shapes differ.
        self.downsample = downsample

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        identity = x if self.downsample is None else self.downsample(x)
        return self.relu(out + identity)


class Bottleneck(nn.Module):
    """Residual 1x1 -> 3x3 -> 1x1 block; outputs `planes * 4` channels."""

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, 3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        identity = x if self.downsample is None else self.downsample(x)
        return self.relu(out + identity)


def make_layer(block, inplanes, planes, blocks, stride=1):
    """Stack `blocks` residual blocks of type `block`.

    The first block gets a 1x1 conv + BN shortcut projection whenever the
    stride or the channel count changes. The remaining blocks take
    `planes * block.expansion` input channels.
    """
    downsample = None
    if stride != 1 or inplanes != planes * block.expansion:
        downsample = nn.Sequential(
            nn.Conv2d(inplanes, planes * block.expansion, 1, stride=stride, bias=False),
            nn.BatchNorm2d(planes * block.expansion),
        )
    layers = [block(inplanes, planes, stride, downsample)]
    for _ in range(1, blocks):
        layers.append(block(planes * block.expansion, planes))
    return nn.Sequential(*layers)


class FuseLayer(nn.Module):
    """Exchange information across HRNet resolution branches.

    Output branch i is the ReLU of the sum over every input branch j, each
    resampled to branch i's channel count and resolution:
      * j > i (lower resolution): 1x1 conv + BN, then nearest upsample by
        2**(j - i);
      * j == i: identity;
      * j < i (higher resolution): (i - j) stride-2 3x3 conv + BN steps,
        with ReLU between steps but not after the last one.
    """

    def __init__(self, num_branches, num_channels):
        super().__init__()
        self.num_branches = num_branches
        fuse = []
        for i in range(num_branches):
            row = []
            for j in range(num_branches):
                if j > i:
                    row.append(nn.Sequential(
                        nn.Conv2d(num_channels[j], num_channels[i], 1, bias=False),
                        nn.BatchNorm2d(num_channels[i]),
                        nn.Upsample(scale_factor=2 ** (j - i), mode='nearest'),
                    ))
                elif j == i:
                    row.append(nn.Identity())
                else:
                    convs = []
                    for k in range(i - j):
                        inc = num_channels[j] if k == 0 else num_channels[i]
                        convs += [
                            nn.Conv2d(inc, num_channels[i], 3, stride=2, padding=1, bias=False),
                            nn.BatchNorm2d(num_channels[i]),
                        ]
                        if k < i - j - 1:
                            convs.append(nn.ReLU(True))
                    row.append(nn.Sequential(*convs))
            fuse.append(nn.ModuleList(row))
        self.fuse = nn.ModuleList(fuse)
        self.relu = nn.ReLU(True)

    def forward(self, x):
        out = []
        for i in range(self.num_branches):
            y = x[0] if i == 0 else self.fuse[i][0](x[0])
            for j in range(1, self.num_branches):
                y = y + (x[j] if i == j else self.fuse[i][j](x[j]))
            out.append(self.relu(y))
        return out


class HRStage(nn.Module):
    """One HRNet module: `num_blocks` blocks per branch, then a FuseLayer."""

    def __init__(self, num_branches, block, num_blocks, num_channels):
        super().__init__()
        self.branches = nn.ModuleList([
            nn.Sequential(*[block(num_channels[i], num_channels[i]) for _ in range(num_blocks)])
            for i in range(num_branches)
        ])
        self.fuse = FuseLayer(num_branches, num_channels)

    def forward(self, x):
        x = [branch(xi) for branch, xi in zip(self.branches, x)]
        return self.fuse(x)


class HRNet(nn.Module):
    """HRNet-W32 backbone with a 1x1 conv heatmap head on the top branch.

    The stem has two stride-2 convs, so the returned heatmaps are 1/4 of the
    input resolution, with shape (N, num_joints, H/4, W/4). Branches go down
    to 1/32. H and W should be multiples of 32 so the fuse layers' up- and
    downsampled tensors line up (the 320x256 input from preprocess() does).
    """

    def __init__(self, num_joints=19):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, 3, stride=2, padding=1, bias=False), nn.BatchNorm2d(64), nn.ReLU(True),
            nn.Conv2d(64, 64, 3, stride=2, padding=1, bias=False), nn.BatchNorm2d(64), nn.ReLU(True),
        )
        self.layer1 = make_layer(Bottleneck, 64, 64, 4)
        self.trans1 = nn.ModuleList([
            nn.Sequential(nn.Conv2d(256, 32, 3, padding=1, bias=False),
                          nn.BatchNorm2d(32), nn.ReLU(True)),
            nn.Sequential(nn.Conv2d(256, 64, 3, stride=2, padding=1, bias=False),
                          nn.BatchNorm2d(64), nn.ReLU(True)),
        ])
        self.stage2 = HRStage(2, BasicBlock, 4, [32, 64])
        self.trans2 = nn.Sequential(nn.Conv2d(64, 128, 3, stride=2, padding=1, bias=False),
                                    nn.BatchNorm2d(128), nn.ReLU(True))
        self.stage3 = nn.Sequential(*[HRStage(3, BasicBlock, 4, [32, 64, 128]) for _ in range(4)])
        self.trans3 = nn.Sequential(nn.Conv2d(128, 256, 3, stride=2, padding=1, bias=False),
                                    nn.BatchNorm2d(256), nn.ReLU(True))
        self.stage4 = nn.Sequential(*[HRStage(4, BasicBlock, 4, [32, 64, 128, 256]) for _ in range(3)])
        self.head = nn.Conv2d(32, num_joints, 1)

    def forward(self, x):
        x = self.stem(x)
        x = self.layer1(x)
        x = [t(x) for t in self.trans1]
        x = self.stage2(x)
        # Each transition adds a new, half-resolution branch from the lowest one.
        x = [x[0], x[1], self.trans2(x[1])]
        for m in self.stage3:
            x = m(x)
        x = [x[0], x[1], x[2], self.trans3(x[2])]
        for m in self.stage4:
            x = m(x)
        return self.head(x[0])


# ── Load weights ──────────────────────────────────────────────────────────────
def _load_state_dict(path):
    """Read a checkpoint file and return the model state dict inside it.

    Accepts a bare state dict, a training checkpoint holding it under
    "model_state_dict" or "state_dict", or a pickled nn.Module. A "module."
    prefix on every key (left by nn.DataParallel) is stripped so the keys
    match HRNet's parameter names.
    """
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # objects, i.e. run code embedded in the file. This is tolerable only
    # because the file comes from one fixed Hugging Face repo. Prefer
    # weights_only=True if the checkpoint holds only tensors/plain containers.
    checkpoint = torch.load(path, map_location="cpu", weights_only=False)
    if isinstance(checkpoint, nn.Module):
        return checkpoint.state_dict()
    state_dict = checkpoint.get("model_state_dict", checkpoint.get("state_dict", checkpoint))
    prefix = "module."
    if state_dict and all(key.startswith(prefix) for key in state_dict):
        state_dict = {key[len(prefix):]: value for key, value in state_dict.items()}
    return state_dict


print("Downloading model weights...")
model_path = hf_hub_download(
    repo_id="cwlachap/hrnet-cephalometric-landmark-detection",
    filename="best_model.pth",
)
model = HRNet(num_joints=19)
# strict=False tolerates missing/unexpected keys (shape mismatches still
# raise). Every missing key is a layer left at random initialisation, so say
# so loudly instead of silently serving garbage predictions.
missing, unexpected = model.load_state_dict(_load_state_dict(model_path), strict=False)
print(f"Loaded. Missing: {len(missing)} | Unexpected: {len(unexpected)}")
if missing:
    print(f"WARNING: {len(missing)} model parameters were not found in the checkpoint "
          f"and keep their random initialisation, e.g. {list(missing)[:5]}")
model.eval()
print("Model ready.")

# ── Constants ─────────────────────────────────────────────────────────────────
# Landmark id for each heatmap channel, in output-channel order.
# NOTE(review): this order must match the order the checkpoint was trained
# with; that cannot be verified from this file.
LM_IDS = ['S', 'N', 'Or', 'Po', 'ANS', 'PNS', 'A', 'U1tip', 'L1tip', 'B',
          'Pog', 'Me', 'Gn', 'Go', 'Co', 'L1ap', 'U1ap', 'U6', 'L6']
# Drawing colours: blue = cranial base, green = maxilla/upper teeth,
# orange = mandible/lower teeth.
LM_COLORS = {
    'S': '#58a6ff', 'N': '#58a6ff', 'Or': '#58a6ff', 'Po': '#58a6ff',
    'ANS': '#3fb950', 'PNS': '#3fb950', 'A': '#3fb950', 'U1tip': '#3fb950',
    'U1ap': '#3fb950', 'U6': '#3fb950',
    'B': '#f0883e', 'L1tip': '#f0883e', 'L1ap': '#f0883e', 'Pog': '#f0883e',
    'Me': '#f0883e', 'Gn': '#f0883e', 'Go': '#f0883e', 'Co': '#f0883e', 'L6': '#f0883e',
}
# Network input size in pixels; both are multiples of 32 (see HRNet).
INPUT_W, INPUT_H = 256, 320
# ImageNet channel statistics used for input normalization.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


# ── Helpers ───────────────────────────────────────────────────────────────────
def preprocess(pil_img):
    """Resize (stretch) to INPUT_W x INPUT_H, normalize, return a (1, 3, H, W) float tensor."""
    img = pil_img.convert('RGB').resize((INPUT_W, INPUT_H), Image.BILINEAR)
    arr = np.asarray(img, dtype=np.float32) / 255.0
    arr = (arr - IMAGENET_MEAN) / IMAGENET_STD
    return torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).float()


def heatmap_to_coords(hm_np, orig_w, orig_h):
    """Decode per-landmark heatmaps into normalized image coordinates.

    Args:
        hm_np: array of shape (num_joints, hh, hw), one heatmap per landmark.
        orig_w, orig_h: original image size. Not needed: preprocess()
            stretches the whole image to the network input, so a fraction of
            the heatmap width/height is the same fraction of the original
            image. Kept for backward compatibility.

    Returns:
        Dict mapping landmark id -> {"x", "y", "confidence"}. x and y are in
        [0, 1] relative to the image width/height. "confidence" is the
        heatmap peak clipped to [0, 1]: a raw activation, not a calibrated
        probability.
    """
    coords = {}
    nj, hh, hw = hm_np.shape
    for j, lm_id in enumerate(LM_IDS[:nj]):
        hm = hm_np[j]
        py, px = divmod(int(hm.argmax()), hw)
        peak = float(hm[py, px])
        x, y = float(px), float(py)
        # Quarter-pixel shift toward the higher neighbour. Both offsets are
        # read at the integer peak; indexing with the shifted (float)
        # coordinate would raise IndexError in NumPy.
        if 1 <= px < hw - 1 and 1 <= py < hh - 1:
            x += 0.25 * float(np.sign(hm[py, px + 1] - hm[py, px - 1]))
            y += 0.25 * float(np.sign(hm[py + 1, px] - hm[py - 1, px]))
        x_norm = float(np.clip(x / hw, 0.0, 1.0))
        y_norm = float(np.clip(y / hh, 0.0, 1.0))
        coords[lm_id] = {
            "x": round(x_norm, 4),
            "y": round(y_norm, 4),
            "confidence": round(float(np.clip(peak, 0.0, 1.0)), 4),
        }
    return coords


def run_detection(pil_img):
    """Run the model on a PIL image and return heatmap_to_coords() output."""
    orig_w, orig_h = pil_img.size
    tensor = preprocess(pil_img)
    with torch.no_grad():
        hm = model(tensor)[0].numpy()
    return heatmap_to_coords(hm, orig_w, orig_h)


# ── Gradio functions ──────────────────────────────────────────────────────────
def detect_visual(pil_img):
    """Gradio handler: draw detected landmarks on the uploaded image.

    Returns (annotated RGB image, pretty-printed JSON string), or
    (None, "{}") when no image was uploaded.
    """
    if pil_img is None:
        return None, "{}"
    coords = run_detection(pil_img)
    out = pil_img.convert("RGB")  # convert() returns a new image; the input is untouched
    draw = ImageDraw.Draw(out)
    w, h = out.size
    r = max(4, w // 120)  # marker radius scales with image width
    for lm_id, pt in coords.items():
        cx, cy = int(pt['x'] * w), int(pt['y'] * h)
        col = LM_COLORS.get(lm_id, '#ffffff')
        draw.ellipse([cx - r, cy - r, cx + r, cy + r], fill=col, outline='black', width=1)
        draw.text((cx + r + 2, cy - r), lm_id, fill=col)
    return out, json.dumps({"landmarks": coords}, indent=2)


def detect_api(image_b64: str) -> str:
    """Programmatic entry point: base64 image in, JSON string out.

    Accepts raw base64 or a "data:<mime>;base64,<payload>" URL. Never
    raises: any failure is returned as {"error": "<message>"}, which is the
    contract API callers rely on.
    """
    try:
        payload = (image_b64 or "").strip()
        if payload.startswith("data:") and "," in payload:
            payload = payload.split(",", 1)[1]
        img_bytes = base64.b64decode(payload)
        pil_img = Image.open(io.BytesIO(img_bytes)).convert('RGB')
        coords = run_detection(pil_img)
        return json.dumps({"landmarks": coords})
    except Exception as e:  # top-level API boundary: report, don't crash
        return json.dumps({"error": str(e)})


# ── UI — single Blocks interface only, no nested Interface ────────────────────
with gr.Blocks(title="OrthoTimes Landmark Detection") as demo:
    gr.Markdown("## OrthoTimes QuickCephTool — HRNet Landmark Detection\n"
                "Detects 19 cephalometric landmarks automatically.")
    with gr.Row():
        img_in = gr.Image(type="pil", label="Upload lateral cephalogram")
        with gr.Column():
            img_out = gr.Image(type="pil", label="Detected landmarks")
            json_out = gr.Textbox(label="JSON output", lines=12)
    gr.Button("▶ Detect Landmarks").click(fn=detect_visual, inputs=img_in,
                                          outputs=[img_out, json_out])

    # API tab for programmatic access
    with gr.Tab("API"):
        b64_in = gr.Textbox(label="Base64 encoded image", lines=3)
        b64_out = gr.Textbox(label="JSON result", lines=10)
        gr.Button("Detect (API)").click(fn=detect_api, inputs=b64_in, outputs=b64_out)


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)