# Landmark / app.py
# Last change: commit c70a6c2 "Update app.py" by mujtaba1212 (verified); file size 11.3 kB.
# Patch missing audioop module (removed in Python 3.13) before importing gradio
import sys
import types
# Use the real audioop module when it can be imported (Python <= 3.12, or a
# backport installed on 3.13+). Only when the import fails, register an empty
# placeholder module so that importing gradio does not fail. Any code that
# actually calls audioop functions will still fail (AttributeError), since the
# placeholder has no attributes.
try:
    import audioop  # noqa: F401
except ImportError:
    sys.modules['audioop'] = types.ModuleType('audioop')
import io
import json
import base64
import numpy as np
from PIL import Image, ImageDraw
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
import gradio as gr
# ── HRNet-W32 ─────────────────────────────────────────────────────────────────
class BasicBlock(nn.Module):
    """Residual block of two 3x3 conv + BatchNorm layers.

    The first conv may be strided; the shortcut is either the input itself or
    ``downsample(input)`` when a projection module is supplied. Output channels
    equal ``planes * expansion`` (expansion is 1 here).

    Attribute names are part of the state_dict keys and must stay stable.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample

    def forward(self, x):
        """Return relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))."""
        shortcut = self.downsample(x) if self.downsample else x
        features = self.bn1(self.conv1(x))
        features = self.relu(features)
        features = self.bn2(self.conv2(features))
        return self.relu(features + shortcut)
class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 reduce, 3x3 (possibly strided), 1x1 expand.

    Output channels are ``planes * expansion`` (expansion is 4). The shortcut is
    the input itself, or ``downsample(input)`` when a projection is supplied.

    Attribute names are part of the state_dict keys and must stay stable.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        expanded = planes * 4
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, 3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, expanded, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(expanded)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        """Return relu(bn3(conv3(...)) + shortcut(x)) through the three conv stages."""
        shortcut = self.downsample(x) if self.downsample else x
        h = self.relu(self.bn1(self.conv1(x)))
        h = self.relu(self.bn2(self.conv2(h)))
        h = self.bn3(self.conv3(h))
        return self.relu(h + shortcut)
def make_layer(block, inplanes, planes, blocks, stride=1):
    """Stack ``blocks`` residual blocks of type ``block`` into an nn.Sequential.

    Only the first block receives ``stride`` and, when the stride is not 1 or
    the channel count changes, a 1x1 conv + BatchNorm projection shortcut. The
    remaining blocks map ``planes * block.expansion`` channels to themselves.
    At least one block is always created, even if ``blocks`` < 1.
    """
    out_channels = planes * block.expansion
    needs_projection = stride != 1 or inplanes != out_channels
    if needs_projection:
        downsample = nn.Sequential(
            nn.Conv2d(inplanes, out_channels, 1, stride=stride, bias=False),
            nn.BatchNorm2d(out_channels),
        )
    else:
        downsample = None
    first = block(inplanes, planes, stride, downsample)
    followers = [block(out_channels, planes) for _ in range(blocks - 1)]
    return nn.Sequential(first, *followers)
class FuseLayer(nn.Module):
    """Multi-resolution fusion: each output branch i is the ReLU of the sum of
    all input branches, each resampled to branch i's channels and resolution.

    ``fuse[i][j]`` maps input branch j to output branch i:
      * j > i (lower-resolution source): 1x1 conv + BN, then nearest-neighbour
        upsampling by 2 ** (j - i).
      * j == i: ``nn.Identity`` placeholder (forward uses x[i] directly and
        never calls it).
      * j < i (higher-resolution source): (i - j) stride-2 3x3 conv + BN steps,
        with a ReLU between steps but not after the last one.

    NOTE(review): assumes each branch is exactly half the spatial size of the
    previous one (true for the 320x256 input used in this app); otherwise the
    upsampled/strided tensors will not match in the sum. The construction order
    below determines the state_dict keys, so it must not be reordered.
    """
    def __init__(self, num_branches, num_channels):
        """Args:
            num_branches: number of parallel resolution branches.
            num_channels: channel count per branch, highest resolution first.
        """
        super().__init__()
        self.num_branches = num_branches
        fuse = []
        for i in range(num_branches):
            row = []
            for j in range(num_branches):
                if j > i:
                    # Lower-resolution source: match channels, then upsample.
                    row.append(nn.Sequential(
                        nn.Conv2d(num_channels[j], num_channels[i], 1, bias=False),
                        nn.BatchNorm2d(num_channels[i]),
                        nn.Upsample(scale_factor=2 ** (j - i), mode='nearest')
                    ))
                elif j == i:
                    row.append(nn.Identity())
                else:
                    # Higher-resolution source: halve the resolution (i - j) times.
                    convs = []
                    for k in range(i - j):
                        # Only the first step changes channels from branch j's count.
                        inc = num_channels[j] if k == 0 else num_channels[i]
                        convs += [nn.Conv2d(inc, num_channels[i], 3, stride=2, padding=1, bias=False),
                                  nn.BatchNorm2d(num_channels[i])]
                        # No ReLU after the final step: the fused sum is ReLU'd in forward.
                        if k < i - j - 1:
                            convs.append(nn.ReLU(True))
                    row.append(nn.Sequential(*convs))
            fuse.append(nn.ModuleList(row))
        self.fuse = nn.ModuleList(fuse)
        self.relu = nn.ReLU(True)
    def forward(self, x):
        """Fuse a list of per-branch tensors; returns a list of the same length."""
        out = []
        for i in range(self.num_branches):
            # Start the sum with branch 0 (resampled unless i is 0 itself).
            y = x[0] if i == 0 else self.fuse[i][0](x[0])
            for j in range(1, self.num_branches):
                y = y + (x[j] if i == j else self.fuse[i][j](x[j]))
            out.append(self.relu(y))
        return out
class HRStage(nn.Module):
    """One HRNet module: per-branch residual stacks followed by a FuseLayer.

    Branch ``i`` is a sequence of ``num_blocks`` blocks that keep
    ``num_channels[i]`` channels; the branch outputs are then exchanged across
    resolutions by ``FuseLayer``.
    """

    def __init__(self, num_branches, block, num_blocks, num_channels):
        super().__init__()
        branch_list = []
        for idx in range(num_branches):
            width = num_channels[idx]
            branch_list.append(nn.Sequential(*[block(width, width) for _ in range(num_blocks)]))
        self.branches = nn.ModuleList(branch_list)
        self.fuse = FuseLayer(num_branches, num_channels)

    def forward(self, x):
        """Run each branch on its own input tensor, then fuse the results."""
        branch_outputs = [branch(feature) for branch, feature in zip(self.branches, x)]
        return self.fuse(branch_outputs)
class HRNet(nn.Module):
    """HRNet-W32-style network that outputs one heatmap per landmark.

    Keeps a high-resolution branch (1/4 of the input size) throughout and adds
    branches at 1/8, 1/16 and 1/32 with 32/64/128/256 channels, exchanging
    information between them with ``HRStage``/``FuseLayer``. The prediction
    head reads only the highest-resolution branch.

    NOTE(review): attribute names (``stem``, ``trans1``, ``head``, ...) and
    Sequential indices must match the checkpoint's state_dict keys. Weights are
    loaded with ``strict=False``, so a mismatch silently leaves random weights
    in place; check the Missing/Unexpected counts printed at load time.
    """
    def __init__(self, num_joints=19):
        """Args:
            num_joints: number of output heatmap channels (one per landmark).
        """
        super().__init__()
        # Two stride-2 3x3 convs: 3 -> 64 channels at 1/4 of the input size.
        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, 3, stride=2, padding=1, bias=False), nn.BatchNorm2d(64), nn.ReLU(True),
            nn.Conv2d(64, 64, 3, stride=2, padding=1, bias=False), nn.BatchNorm2d(64), nn.ReLU(True),
        )
        # Four Bottleneck blocks: 64 -> 256 channels (expansion 4), same resolution.
        self.layer1 = make_layer(Bottleneck, 64, 64, 4)
        # Split into two branches: 32 ch at 1/4 and 64 ch at 1/8 (stride 2).
        self.trans1 = nn.ModuleList([
            nn.Sequential(nn.Conv2d(256, 32, 3, padding=1, bias=False), nn.BatchNorm2d(32), nn.ReLU(True)),
            nn.Sequential(nn.Conv2d(256, 64, 3, stride=2, padding=1, bias=False), nn.BatchNorm2d(64), nn.ReLU(True)),
        ])
        # Stage 2: one module over 2 branches.
        self.stage2 = HRStage(2, BasicBlock, 4, [32, 64])
        # New 128-ch branch at 1/16, derived from the 64-ch branch.
        self.trans2 = nn.Sequential(nn.Conv2d(64, 128, 3, stride=2, padding=1, bias=False), nn.BatchNorm2d(128), nn.ReLU(True))
        # Stage 3: four modules over 3 branches.
        self.stage3 = nn.Sequential(*[HRStage(3, BasicBlock, 4, [32, 64, 128]) for _ in range(4)])
        # New 256-ch branch at 1/32, derived from the 128-ch branch.
        self.trans3 = nn.Sequential(nn.Conv2d(128, 256, 3, stride=2, padding=1, bias=False), nn.BatchNorm2d(256), nn.ReLU(True))
        # Stage 4: three modules over 4 branches.
        self.stage4 = nn.Sequential(*[HRStage(4, BasicBlock, 4, [32, 64, 128, 256]) for _ in range(3)])
        # 1x1 conv from the 32-ch high-resolution branch to num_joints heatmaps.
        self.head = nn.Conv2d(32, num_joints, 1)
    def forward(self, x):
        """Return heatmaps for a batch of images.

        Assumes x is (N, 3, H, W) with H and W divisible by 32, giving an
        output of shape (N, num_joints, H/4, W/4) -- TODO confirm for other sizes.
        """
        x = self.stem(x)
        x = self.layer1(x)
        x = [t(x) for t in self.trans1]
        x = self.stage2(x)
        x = [x[0], x[1], self.trans2(x[1])]
        # stage3/stage4 are iterated manually because each module takes and returns a list.
        for m in self.stage3:
            x = m(x)
        x = [x[0], x[1], x[2], self.trans3(x[2])]
        for m in self.stage4:
            x = m(x)
        return self.head(x[0])
# ── Load weights ──────────────────────────────────────────────────────────────
print("Downloading model weights...")
model_path = hf_hub_download(
    repo_id="cwlachap/hrnet-cephalometric-landmark-detection",
    filename="best_model.pth",
)
# SECURITY NOTE(review): weights_only=False lets torch.load unpickle arbitrary
# objects, i.e. run code embedded in the downloaded file. It is kept because
# the checkpoint may hold non-tensor entries; switch to weights_only=True if
# the checkpoint loads cleanly that way.
checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)
if isinstance(checkpoint, nn.Module):
    # A whole pickled model rather than a dict: take its parameters.
    state_dict = checkpoint.state_dict()
else:
    # Training scripts commonly nest the weights under one of these keys.
    state_dict = checkpoint.get("model_state_dict", checkpoint.get("state_dict", checkpoint))
# Checkpoints saved from nn.DataParallel prefix every key with "module.";
# without stripping it no key would match and every weight would stay random.
if state_dict and all(str(key).startswith("module.") for key in state_dict):
    state_dict = {key[len("module."):]: value for key, value in state_dict.items()}
model = HRNet(num_joints=19)
# strict=False tolerates naming differences, so mismatches are reported
# explicitly below instead of being silently ignored.
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print(f"Loaded. Missing: {len(missing)} | Unexpected: {len(unexpected)}")
if missing:
    print(f"WARNING: {len(missing)} model parameters were not in the checkpoint "
          f"and keep their random initialisation, e.g. {list(missing)[:10]}")
if unexpected:
    print(f"WARNING: {len(unexpected)} checkpoint entries were not used, "
          f"e.g. {list(unexpected)[:10]}")
model.eval()
print("Model ready.")
# ── Constants ─────────────────────────────────────────────────────────────────
# Landmark name per model output channel: heatmap channel j is reported as LM_IDS[j].
# NOTE(review): this order cannot be verified from this file and must match the
# checkpoint's training labels. For example, the public 19-landmark ISBI 2015
# cephalometric set uses different landmarks in a different order -- confirm.
LM_IDS = [
    'S', 'N', 'Or', 'Po', 'ANS', 'PNS', 'A', 'U1tip', 'L1tip', 'B',
    'Pog', 'Me', 'Gn', 'Go', 'Co', 'L1ap', 'U1ap', 'U6', 'L6',
]

# Overlay colour per landmark, assigned group-wise (blue / green / orange).
# The groups presumably correspond to cranial-reference, maxillary and
# mandibular landmarks -- TODO confirm.
_LM_COLOR_GROUPS = (
    ('#58a6ff', ('S', 'N', 'Or', 'Po')),
    ('#3fb950', ('ANS', 'PNS', 'A', 'U1tip', 'U1ap', 'U6')),
    ('#f0883e', ('B', 'L1tip', 'L1ap', 'Pog', 'Me', 'Gn', 'Go', 'Co', 'L6')),
)
LM_COLORS = {name: colour for colour, names in _LM_COLOR_GROUPS for name in names}

# Network input size in pixels; every image is resized to INPUT_W x INPUT_H.
INPUT_W, INPUT_H = 256, 320
# ── Helpers ───────────────────────────────────────────────────────────────────
def preprocess(pil_img):
    """Turn a PIL image into a normalised 1x3xHxW float32 tensor for the model.

    The image is converted to RGB and resized (bilinear, aspect ratio not
    preserved) to INPUT_W x INPUT_H. Pixels are scaled to [0, 1] and then
    normalised per channel with the mean/std constants below.
    """
    rgb = pil_img.convert('RGB')
    resized = rgb.resize((INPUT_W, INPUT_H), Image.BILINEAR)
    pixels = np.array(resized, dtype=np.float32) / 255.0
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    normalised = (pixels - mean) / std
    hwc = torch.from_numpy(normalised)
    return hwc.permute(2, 0, 1)[None].float()
def heatmap_to_coords(hm_np, orig_w, orig_h, landmark_ids=None):
    """Convert a stack of per-landmark heatmaps into normalised coordinates.

    For each channel the arg-max is taken as the landmark position. When the
    peak is not on the heatmap border, it is refined by a quarter pixel along
    each axis towards the larger of its two neighbours. Coordinates are
    returned as fractions of the heatmap width/height. The heatmap covers the
    whole network input, which is the whole image resized without cropping or
    padding, so these are also fractions of the original image.

    Args:
        hm_np: array of shape (num_joints, H, W).
        orig_w, orig_h: original image size. Not needed for normalised
            coordinates; accepted for backward compatibility with callers.
        landmark_ids: landmark names in channel order. Defaults to the
            module-level ``LM_IDS``.

    Returns:
        dict mapping each landmark name to ``{"x", "y", "confidence"}``, with
        x and y clipped to [0, 1] and rounded to 4 decimals. Only the first
        ``min(num_joints, len(landmark_ids))`` channels are used.
        "confidence" is a fixed placeholder (0.85); it is not derived from the
        heatmap.
    """
    ids = LM_IDS if landmark_ids is None else landmark_ids
    coords = {}
    nj, hh, hw = hm_np.shape
    for j in range(min(nj, len(ids))):
        hm = hm_np[j]
        py, px = divmod(int(hm.argmax()), hw)
        x, y = float(px), float(py)
        if 1 <= px < hw - 1 and 1 <= py < hh - 1:
            # Index with the integer peak (px, py) for both refinements; a
            # shifted float index would make NumPy raise IndexError.
            x += 0.25 * np.sign(float(hm[py, px + 1] - hm[py, px - 1]))
            y += 0.25 * np.sign(float(hm[py + 1, px] - hm[py - 1, px]))
        # No rescaling by INPUT_W/orig_w: normalised coordinates do not depend
        # on the original image size.
        x_norm = float(np.clip(x / hw, 0, 1))
        y_norm = float(np.clip(y / hh, 0, 1))
        coords[ids[j]] = {"x": round(x_norm, 4), "y": round(y_norm, 4), "confidence": 0.85}
    return coords
def run_detection(pil_img):
    """Detect landmarks in a PIL image with the module-level ``model``.

    The image is preprocessed and run through the model without gradient
    tracking. The heatmaps of the single batch element are converted to a
    landmark dict by ``heatmap_to_coords``, which also receives the original
    width and height.
    """
    width, height = pil_img.size
    batch = preprocess(pil_img)
    with torch.no_grad():
        output = model(batch)
    heatmaps = output[0].numpy()
    return heatmap_to_coords(heatmaps, width, height)
# ── Gradio functions ──────────────────────────────────────────────────────────
def detect_visual(pil_img):
    """Gradio handler: annotate detected landmarks on a copy of the image.

    Returns ``(None, "{}")`` when no image was supplied. Otherwise returns an
    RGB copy with a coloured, black-outlined dot and text label per landmark,
    plus the landmarks as indented JSON under the key "landmarks".
    """
    if pil_img is None:
        return None, "{}"
    landmarks = run_detection(pil_img)
    canvas = pil_img.copy().convert("RGB")
    painter = ImageDraw.Draw(canvas)
    width, height = canvas.size
    # Dot radius scales with image width, with a 4 px minimum.
    radius = max(4, width // 120)
    for name, point in landmarks.items():
        px = int(point['x'] * width)
        py = int(point['y'] * height)
        colour = LM_COLORS.get(name, '#ffffff')
        box = [px - radius, py - radius, px + radius, py + radius]
        painter.ellipse(box, fill=colour, outline='black', width=1)
        painter.text((px + radius + 2, py - radius), name, fill=colour)
    payload = json.dumps({"landmarks": landmarks}, indent=2)
    return canvas, payload
def detect_api(image_b64: str) -> str:
    """Detect landmarks in a base64-encoded image and return a JSON string.

    Args:
        image_b64: the image file's bytes, base64-encoded. A data URL such as
            ``data:image/png;base64,<payload>`` is also accepted; its header up
            to and including the first comma is removed before decoding.

    Returns:
        ``{"landmarks": {...}}`` on success, or ``{"error": "<message>"}`` if
        decoding, image loading or detection fails.
    """
    try:
        payload = image_b64
        if isinstance(payload, str):
            payload = payload.strip()
            # b64decode drops ':', ';' and ',' but keeps the letters and '/'
            # of a data-URL header, which would corrupt the decoded bytes.
            if payload.startswith("data:") and "," in payload:
                payload = payload.split(",", 1)[1]
        img_bytes = base64.b64decode(payload)
        pil_img = Image.open(io.BytesIO(img_bytes)).convert('RGB')
        coords = run_detection(pil_img)
        return json.dumps({"landmarks": coords})
    except Exception as e:
        # API boundary: report any failure as JSON so clients always receive
        # a parseable response instead of a server error.
        return json.dumps({"error": str(e)})
# ── UI — single Blocks interface only, no nested Interface ────────────────────
# Single gr.Blocks app: an interactive detection view plus an "API" tab that
# accepts a base64 image and returns JSON.
with gr.Blocks(title="OrthoTimes Landmark Detection") as demo:
    gr.Markdown("## OrthoTimes QuickCephTool — HRNet Landmark Detection\nDetects 19 cephalometric landmarks automatically.")
    with gr.Row():
        img_in = gr.Image(type="pil", label="Upload lateral cephalogram")
        with gr.Column():
            img_out = gr.Image(type="pil", label="Detected landmarks")
            json_out = gr.Textbox(label="JSON output", lines=12)
    # Interactive detection: annotated image plus pretty-printed JSON.
    gr.Button("▶ Detect Landmarks").click(fn=detect_visual, inputs=img_in, outputs=[img_out, json_out])
    # API tab for programmatic access
    with gr.Tab("API"):
        b64_in = gr.Textbox(label="Base64 encoded image", lines=3)
        b64_out = gr.Textbox(label="JSON result", lines=10)
        gr.Button("Detect (API)").click(fn=detect_api, inputs=b64_in, outputs=b64_out)
# Listen on all interfaces on port 7860 (presumably the port the hosting
# platform expects -- confirm before changing).
demo.launch(server_name="0.0.0.0", server_port=7860)