# NOTE(review): Hugging Face Spaces page chrome ("Spaces: Running") was pasted
# into this file — not part of the application code.
# Stub out the `audioop` module (removed from the stdlib in Python 3.13)
# so that gradio's audio dependencies can still be imported.
import sys
import types

sys.modules.setdefault('audioop', types.ModuleType('audioop'))
| import io | |
| import json | |
| import base64 | |
| import numpy as np | |
| from PIL import Image, ImageDraw | |
| import torch | |
| import torch.nn as nn | |
| from huggingface_hub import hf_hub_download | |
| import gradio as gr | |
# ── HRNet-W32 ─────────────────────────────────────────────────────────────────
class BasicBlock(nn.Module):
    """Residual unit with two 3x3 convolutions (channel expansion factor 1)."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # Module names/creation order must stay fixed: checkpoint keys match them.
        self.conv1 = nn.Conv2d(inplanes, planes, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample

    def forward(self, x):
        identity = x if self.downsample is None else self.downsample(x)
        h = self.relu(self.bn1(self.conv1(x)))
        h = self.bn2(self.conv2(h))
        return self.relu(h + identity)
class Bottleneck(nn.Module):
    """Residual bottleneck: 1x1 reduce, 3x3, then 1x1 expand to 4x channels."""

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # Module names/creation order must stay fixed: checkpoint keys match them.
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, 3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)
        h = self.relu(self.bn1(self.conv1(x)))
        h = self.relu(self.bn2(self.conv2(h)))
        h = self.bn3(self.conv3(h))
        return self.relu(h + shortcut)
def make_layer(block, inplanes, planes, blocks, stride=1):
    """Stack `blocks` residual units into a Sequential.

    Only the first unit may change stride/width; it receives a 1x1 conv + BN
    projection shortcut whenever the spatial or channel dims change.
    """
    downsample = None
    if stride != 1 or inplanes != planes * block.expansion:
        downsample = nn.Sequential(
            nn.Conv2d(inplanes, planes * block.expansion, 1, stride=stride, bias=False),
            nn.BatchNorm2d(planes * block.expansion),
        )
    units = [block(inplanes, planes, stride, downsample)]
    units.extend(block(planes * block.expansion, planes) for _ in range(blocks - 1))
    return nn.Sequential(*units)
class FuseLayer(nn.Module):
    """HRNet multi-resolution fusion.

    Each output branch i is the ReLU of the sum of every input branch j,
    resampled to branch i's resolution and channel width. Branch index
    increases as resolution halves, so j > i means "lower resolution than i".
    """

    def __init__(self, num_branches, num_channels):
        super().__init__()
        self.num_branches = num_branches
        fuse = []
        for i in range(num_branches):
            row = []
            for j in range(num_branches):
                if j > i:
                    # Lower-res source: 1x1 conv to match channels, then
                    # nearest-neighbour upsample by 2^(j-i) to branch i's size.
                    row.append(nn.Sequential(
                        nn.Conv2d(num_channels[j], num_channels[i], 1, bias=False),
                        nn.BatchNorm2d(num_channels[i]),
                        nn.Upsample(scale_factor=2 ** (j - i), mode='nearest')
                    ))
                elif j == i:
                    # Same branch: passed through untouched.
                    row.append(nn.Identity())
                else:
                    # Higher-res source: chain of stride-2 3x3 convs, one
                    # halving per resolution step (i - j steps in total).
                    convs = []
                    for k in range(i - j):
                        # Only the first conv sees the source width; the rest
                        # already operate at branch i's width.
                        inc = num_channels[j] if k == 0 else num_channels[i]
                        convs += [nn.Conv2d(inc, num_channels[i], 3, stride=2, padding=1, bias=False),
                                  nn.BatchNorm2d(num_channels[i])]
                        # ReLU between downsampling steps, but not after the
                        # final conv/bn pair (the fused sum is ReLU'd instead).
                        if k < i - j - 1:
                            convs.append(nn.ReLU(True))
                    row.append(nn.Sequential(*convs))
            fuse.append(nn.ModuleList(row))
        self.fuse = nn.ModuleList(fuse)
        self.relu = nn.ReLU(True)

    def forward(self, x):
        # x: list of per-branch tensors, highest resolution first.
        out = []
        for i in range(self.num_branches):
            # Seed the sum with branch 0's (possibly resampled) contribution.
            y = x[0] if i == 0 else self.fuse[i][0](x[0])
            for j in range(1, self.num_branches):
                y = y + (x[j] if i == j else self.fuse[i][j](x[j]))
            out.append(self.relu(y))
        return out
class HRStage(nn.Module):
    """One HRNet stage: per-branch residual chains followed by cross-branch fusion."""

    def __init__(self, num_branches, block, num_blocks, num_channels):
        super().__init__()
        branches = []
        for i in range(num_branches):
            # Each branch keeps its own width; blocks never change channels here.
            chain = [block(num_channels[i], num_channels[i]) for _ in range(num_blocks)]
            branches.append(nn.Sequential(*chain))
        self.branches = nn.ModuleList(branches)
        self.fuse = FuseLayer(num_branches, num_channels)

    def forward(self, x):
        # x is a list of per-branch tensors; refine each, then fuse across branches.
        refined = [branch(inp) for branch, inp in zip(self.branches, x)]
        return self.fuse(refined)
class HRNet(nn.Module):
    """Simplified HRNet-W32 backbone with a 1x1-conv heatmap head.

    The stem downsamples by 4x, so the output is one heatmap per joint at
    1/4 of the input resolution (taken from the 32-channel high-res branch).
    """

    def __init__(self, num_joints=19):
        super().__init__()
        # Stem: two stride-2 3x3 convs -> overall 4x downsampling, 64 channels.
        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, 3, stride=2, padding=1, bias=False), nn.BatchNorm2d(64), nn.ReLU(True),
            nn.Conv2d(64, 64, 3, stride=2, padding=1, bias=False), nn.BatchNorm2d(64), nn.ReLU(True),
        )
        # Stage 1: four bottlenecks; output width is 64 * expansion = 256.
        self.layer1 = make_layer(Bottleneck, 64, 64, 4)
        # Transition 1: split the 256-ch tensor into a 32-ch full-res branch
        # and a 64-ch half-res branch.
        self.trans1 = nn.ModuleList([
            nn.Sequential(nn.Conv2d(256, 32, 3, padding=1, bias=False), nn.BatchNorm2d(32), nn.ReLU(True)),
            nn.Sequential(nn.Conv2d(256, 64, 3, stride=2, padding=1, bias=False), nn.BatchNorm2d(64), nn.ReLU(True)),
        ])
        self.stage2 = HRStage(2, BasicBlock, 4, [32, 64])
        # Each later transition derives one new lower-res branch from the
        # current lowest-resolution branch.
        self.trans2 = nn.Sequential(nn.Conv2d(64, 128, 3, stride=2, padding=1, bias=False), nn.BatchNorm2d(128), nn.ReLU(True))
        self.stage3 = nn.Sequential(*[HRStage(3, BasicBlock, 4, [32, 64, 128]) for _ in range(4)])
        self.trans3 = nn.Sequential(nn.Conv2d(128, 256, 3, stride=2, padding=1, bias=False), nn.BatchNorm2d(256), nn.ReLU(True))
        self.stage4 = nn.Sequential(*[HRStage(4, BasicBlock, 4, [32, 64, 128, 256]) for _ in range(3)])
        # Head: 1x1 conv on the highest-resolution branch -> per-joint heatmaps.
        self.head = nn.Conv2d(32, num_joints, 1)

    def forward(self, x):
        x = self.stem(x)
        x = self.layer1(x)
        x = [t(x) for t in self.trans1]
        x = self.stage2(x)
        x = [x[0], x[1], self.trans2(x[1])]
        # stage3/stage4 are Sequentials of HRStage, but each HRStage consumes a
        # LIST of tensors, so iterate the stages manually.
        for m in self.stage3:
            x = m(x)
        x = [x[0], x[1], x[2], self.trans3(x[2])]
        for m in self.stage4:
            x = m(x)
        return self.head(x[0])
# ── Load weights ──────────────────────────────────────────────────────────────
print("Downloading model weights...")
model_path = hf_hub_download(
    repo_id="cwlachap/hrnet-cephalometric-landmark-detection",
    filename="best_model.pth"
)
# NOTE(review): weights_only=False unpickles arbitrary Python objects from the
# checkpoint — acceptable only because the repo above is trusted; confirm
# whether weights_only=True would load this checkpoint before hardening.
checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)
# Checkpoints may nest weights under "model_state_dict" or "state_dict",
# or be a bare state dict; try each in turn.
state_dict = checkpoint.get("model_state_dict", checkpoint.get("state_dict", checkpoint))
model = HRNet(num_joints=19)
# strict=False tolerates key mismatches; the counts are printed for inspection.
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print(f"Loaded. Missing: {len(missing)} | Unexpected: {len(unexpected)}")
model.eval()
print("Model ready.")
# ── Constants ─────────────────────────────────────────────────────────────────
# The 19 landmark ids, in the channel order of the model's heatmap output.
LM_IDS = ['S', 'N', 'Or', 'Po', 'ANS', 'PNS', 'A', 'U1tip', 'L1tip', 'B',
          'Pog', 'Me', 'Gn', 'Go', 'Co', 'L1ap', 'U1ap', 'U6', 'L6']
# Marker colours for the annotated overlay. Landmarks are grouped into three
# colours (blue / green / orange) — presumably by anatomical region; verify.
LM_COLORS = {
    'S': '#58a6ff', 'N': '#58a6ff', 'Or': '#58a6ff', 'Po': '#58a6ff',
    'ANS': '#3fb950', 'PNS': '#3fb950', 'A': '#3fb950',
    'U1tip': '#3fb950', 'U1ap': '#3fb950', 'U6': '#3fb950',
    'B': '#f0883e', 'L1tip': '#f0883e', 'L1ap': '#f0883e',
    'Pog': '#f0883e', 'Me': '#f0883e', 'Gn': '#f0883e',
    'Go': '#f0883e', 'Co': '#f0883e', 'L6': '#f0883e'
}
# Network input size in pixels (width, height); images are resized to this
# before inference.
INPUT_W, INPUT_H = 256, 320
# ── Helpers ───────────────────────────────────────────────────────────────────
def preprocess(pil_img):
    """Resize to the network input size and apply ImageNet normalization.

    Returns a (1, 3, INPUT_H, INPUT_W) float32 tensor in NCHW layout.
    """
    resized = pil_img.convert('RGB').resize((INPUT_W, INPUT_H), Image.BILINEAR)
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    arr = (np.array(resized, dtype=np.float32) / 255.0 - mean) / std
    return torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).float()
def heatmap_to_coords(hm_np, orig_w, orig_h):
    """Convert predicted heatmaps into named landmark coordinates.

    Args:
        hm_np: (num_joints, H, W) numpy array of heatmaps.
        orig_w, orig_h: pixel size of the original (pre-resize) image.

    Returns:
        dict mapping landmark id -> {"x", "y", "confidence"}, x/y clipped to [0, 1].
    """
    coords = {}
    nj, hh, hw = hm_np.shape
    for j in range(min(nj, len(LM_IDS))):
        hm = hm_np[j]
        idx = int(hm.argmax())
        py, px = divmod(idx, hw)
        # Quarter-pixel refinement toward the larger neighbour. BUGFIX: both
        # offsets must be computed from the INTEGER peak before shifting —
        # the original mutated px first and then used the now-float px as an
        # array index, raising on any interior peak.
        if 1 <= px < hw - 1 and 1 <= py < hh - 1:
            dx = 0.25 * np.sign(float(hm[py, px + 1] - hm[py, px - 1]))
            dy = 0.25 * np.sign(float(hm[py + 1, px] - hm[py - 1, px]))
            px, py = px + dx, py + dy
        # NOTE(review): px / hw is already a resize-invariant fraction of image
        # width, so the extra (INPUT_W / orig_w) factor only yields correct
        # normalized coords when the original image is exactly INPUT_W x
        # INPUT_H. Preserved as-is pending confirmation against training code.
        x_norm = float(np.clip((px / hw) * (INPUT_W / orig_w), 0, 1))
        y_norm = float(np.clip((py / hh) * (INPUT_H / orig_h), 0, 1))
        # Confidence is a fixed placeholder, not derived from the heatmap peak.
        coords[LM_IDS[j]] = {"x": round(x_norm, 4), "y": round(y_norm, 4), "confidence": 0.85}
    return coords
def run_detection(pil_img):
    """Run the HRNet model on a PIL image and return named landmark coords."""
    orig_w, orig_h = pil_img.size
    batch = preprocess(pil_img)
    with torch.no_grad():
        heatmaps = model(batch)[0].numpy()
    return heatmap_to_coords(heatmaps, orig_w, orig_h)
# ── Gradio functions ──────────────────────────────────────────────────────────
def detect_visual(pil_img):
    """Gradio handler: annotate the uploaded image and return coords as JSON."""
    if pil_img is None:
        return None, "{}"
    coords = run_detection(pil_img)
    annotated = pil_img.copy().convert("RGB")
    draw = ImageDraw.Draw(annotated)
    width, height = annotated.size
    # Marker radius scales with image width, never below 4 px.
    radius = max(4, width // 120)
    for name, pt in coords.items():
        cx = int(pt['x'] * width)
        cy = int(pt['y'] * height)
        color = LM_COLORS.get(name, '#ffffff')
        draw.ellipse([cx - radius, cy - radius, cx + radius, cy + radius],
                     fill=color, outline='black', width=1)
        draw.text((cx + radius + 2, cy - radius), name, fill=color)
    return annotated, json.dumps({"landmarks": coords}, indent=2)
def detect_api(image_b64: str) -> str:
    """Programmatic entry point: base64-encoded image in, JSON landmarks out.

    Errors are reported as {"error": ...} JSON rather than raised, so remote
    callers always receive a parseable response.
    """
    try:
        raw = base64.b64decode(image_b64)
        pil_img = Image.open(io.BytesIO(raw)).convert('RGB')
        return json.dumps({"landmarks": run_detection(pil_img)})
    except Exception as exc:  # API boundary: report the failure, don't crash
        return json.dumps({"error": str(exc)})
# ── UI — single Blocks interface only, no nested Interface ────────────────────
# NOTE(review): the "β" / "βΆ" characters in the runtime strings below look
# like mojibake of "—" and "▶" — verify the file's encoding before editing
# any user-facing text.
with gr.Blocks(title="OrthoTimes Landmark Detection") as demo:
    gr.Markdown("## OrthoTimes QuickCephTool β HRNet Landmark Detection\nDetects 19 cephalometric landmarks automatically.")
    with gr.Row():
        # Left: upload; right: annotated output plus raw JSON.
        img_in = gr.Image(type="pil", label="Upload lateral cephalogram")
        with gr.Column():
            img_out = gr.Image(type="pil", label="Detected landmarks")
            json_out = gr.Textbox(label="JSON output", lines=12)
    gr.Button("βΆ Detect Landmarks").click(fn=detect_visual, inputs=img_in, outputs=[img_out, json_out])
    # API tab for programmatic access
    with gr.Tab("API"):
        b64_in = gr.Textbox(label="Base64 encoded image", lines=3)
        b64_out = gr.Textbox(label="JSON result", lines=10)
        gr.Button("Detect (API)").click(fn=detect_api, inputs=b64_in, outputs=b64_out)

# Bind to all interfaces on the standard Hugging Face Spaces port.
demo.launch(server_name="0.0.0.0", server_port=7860)