# Hugging Face Space upload metadata: app.py, commit 4ea5bee (verified), uploaded by Klinapps
"""
Cephalometric Landmark Detection API
HRNet-W32 - 19 Landmarks Cefalométricos
Versión Final con Visualización Mejorada
"""
import os
import json
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
import gradio as gr
# ============================================================================
# CONFIGURATION
# ============================================================================
MODEL_REPO = "cwlachap/hrnet-cephalometric-landmark-detection"  # HF Hub repo with trained weights
MODEL_FILE = "best_model.pth"
NUM_LANDMARKS = 19  # standard ISBI cephalometric landmark count
INPUT_SIZE = 768  # square side length of the network input, in pixels
# Landmark order follows the ISBI challenge / cwlachap training convention.
LANDMARK_NAMES = [
    "Sella (S)",
    "Nasion (N)",
    "Orbitale (Or)",
    "Porion (Po)",
    "Point A",
    "Point B",
    "Pogonion (Pog)",
    "Menton (Me)",
    "Gnathion (Gn)",
    "Gonion (Go)",
    "Lower Incisor (L1)",
    "Upper Incisor (U1)",
    "Upper Lip (UL)",
    "Lower Lip (LL)",
    "Subnasale (Sn)",
    "Soft Tissue Pog (Pog')",
    "PNS",
    "ANS",
    "Articulare (Ar)"
]
# Short labels rendered on the annotated image (same order as LANDMARK_NAMES).
LANDMARK_ABBREV = [
    "S", "N", "Or", "Po", "A", "B", "Pog", "Me", "Gn", "Go",
    "L1", "U1", "UL", "LL", "Sn", "Pog'", "PNS", "ANS", "Ar"
]
# Display colors per anatomical group; "indices" refer to LANDMARK_NAMES order.
# (Group-name keys are shown verbatim in the UI legend, so they stay as-is.)
LANDMARK_GROUPS = {
    "Craneal Base": {"indices": [0, 1, 18], "color": (255, 80, 80)},  # red
    "Orbital": {"indices": [2, 3], "color": (80, 255, 80)},  # green
    "Maxilar": {"indices": [4, 16, 17], "color": (80, 150, 255)},  # blue
    "Mandibular": {"indices": [5, 6, 7, 8, 9], "color": (255, 80, 255)},  # magenta
    "Dental": {"indices": [10, 11], "color": (255, 255, 80)},  # yellow
    "Tejido Blando": {"indices": [12, 13, 14, 15], "color": (80, 255, 255)}  # cyan
}
def get_color_for_landmark(idx):
    """Return the RGB color of the anatomical group containing landmark *idx*.

    Landmarks that belong to no group render as white.
    """
    return next(
        (grp["color"] for grp in LANDMARK_GROUPS.values() if idx in grp["indices"]),
        (255, 255, 255),
    )
# ============================================================================
# ARQUITECTURA HRNET-W32
# ============================================================================
BN_MOMENTUM = 0.1  # BatchNorm momentum used by every norm layer in the network


def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 convolution with padding 1 (spatial-size preserving at stride 1)."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
class BasicBlock(nn.Module):
    """ResNet basic residual block: two 3x3 convs plus a shortcut connection."""

    expansion = 1  # output channels == planes * expansion

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # Attribute names are kept exactly as in the checkpoint: they form
        # the state-dict keys consumed by load_model().
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.downsample = downsample

    def forward(self, x):
        # Project the shortcut when shape changes, otherwise pass-through.
        shortcut = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.relu(out + shortcut)
class Bottleneck(nn.Module):
    """ResNet bottleneck block: 1x1 reduce -> 3x3 -> 1x1 expand (4x) + shortcut."""

    expansion = 4  # output channels == planes * expansion

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        # Attribute names match the pretrained checkpoint's state-dict keys.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        # Project the shortcut when shape changes, otherwise pass-through.
        shortcut = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        return self.relu(out + shortcut)
class HighResolutionModule(nn.Module):
    """One HRNet stage module: parallel multi-resolution branches + fusion.

    Each branch processes features at its own resolution; the fuse layers
    built by _make_fuse_layers up/down-sample so that every output branch
    receives (summed) information from all input branches.
    """

    def __init__(self, num_branches, blocks, num_blocks, num_inchannels, num_channels, fuse_method, multi_scale_output=True):
        super(HighResolutionModule, self).__init__()
        self.num_inchannels = num_inchannels
        self.num_branches = num_branches
        self.multi_scale_output = multi_scale_output
        self.branches = self._make_branches(num_branches, blocks, num_blocks, num_channels)
        self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU(True)

    def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride=1):
        """Stack residual blocks for one branch, projecting the shortcut if needed."""
        downsample = None
        if stride != 1 or self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.num_inchannels[branch_index], num_channels[branch_index] * block.expansion, 1, stride, bias=False),
                nn.BatchNorm2d(num_channels[branch_index] * block.expansion, momentum=BN_MOMENTUM))
        layers = [block(self.num_inchannels[branch_index], num_channels[branch_index], stride, downsample)]
        # Record the branch's new channel count; fusion layers depend on it.
        self.num_inchannels[branch_index] = num_channels[branch_index] * block.expansion
        for _ in range(1, num_blocks[branch_index]):
            layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index]))
        return nn.Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        """Build one residual-block branch per resolution."""
        return nn.ModuleList([self._make_one_branch(i, block, num_blocks, num_channels) for i in range(num_branches)])

    def _make_fuse_layers(self):
        """Build the cross-resolution exchange paths (None on the diagonal)."""
        if self.num_branches == 1:
            return None
        fuse_layers = []
        for i in range(self.num_branches if self.multi_scale_output else 1):
            fuse_layer = []
            for j in range(self.num_branches):
                if j > i:
                    # Lower-resolution source: 1x1 conv then nearest-neighbor upsample.
                    fuse_layer.append(nn.Sequential(
                        nn.Conv2d(self.num_inchannels[j], self.num_inchannels[i], 1, bias=False),
                        nn.BatchNorm2d(self.num_inchannels[i]),
                        nn.Upsample(scale_factor=2**(j-i), mode='nearest')))
                elif j == i:
                    # Same resolution: identity (handled as None in forward()).
                    fuse_layer.append(None)
                else:
                    # Higher-resolution source: chain of stride-2 3x3 convs;
                    # only the last conv changes channels, only the earlier
                    # ones carry a ReLU.
                    conv3x3s = []
                    for k in range(i-j):
                        out_ch = self.num_inchannels[i] if k == i - j - 1 else self.num_inchannels[j]
                        conv3x3s.append(nn.Sequential(
                            nn.Conv2d(self.num_inchannels[j], out_ch, 3, 2, 1, bias=False),
                            nn.BatchNorm2d(out_ch),
                            nn.ReLU(True) if k < i - j - 1 else nn.Identity()))
                    fuse_layer.append(nn.Sequential(*conv3x3s))
            fuse_layers.append(nn.ModuleList(fuse_layer))
        return nn.ModuleList(fuse_layers)

    def get_num_inchannels(self):
        """Channel counts per branch after this module (input to the next one)."""
        return self.num_inchannels

    def forward(self, x):
        # x is a list of tensors, one per branch/resolution.
        if self.num_branches == 1:
            return [self.branches[0](x[0])]
        for i in range(self.num_branches):
            x[i] = self.branches[i](x[i])
        x_fuse = []
        # For each output resolution i, sum the (resampled) contributions of
        # every branch j, then apply ReLU.
        for i in range(len(self.fuse_layers)):
            y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
            for j in range(1, self.num_branches):
                if i == j:
                    y = y + x[j]
                else:
                    y = y + self.fuse_layers[i][j](x[j])
            x_fuse.append(self.relu(y))
        return x_fuse
# Maps the stage-config 'BLOCK' string to the corresponding block class.
blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck}
class HRNet(nn.Module):
    """HRNet-W32 backbone with a 1x1 head producing one heatmap per landmark.

    The stem downsamples the input 4x (two stride-2 convs); later stages keep
    a high-resolution branch (width 32) alongside progressively
    lower-resolution branches and fuse them after every module. The final 1x1
    conv maps the high-resolution features to num_joints heatmaps, so the
    output spatial size is input_size / 4.
    """

    def __init__(self, num_joints=19):
        super(HRNet, self).__init__()
        self.inplanes = 64
        # Stem: two stride-2 3x3 convs -> 1/4 resolution, 64 channels.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(Bottleneck, 64, 4)  # -> 256 channels (64 * expansion 4)
        # Stage 2: 2 branches at widths 32/64.
        self.stage2_cfg = {'NUM_MODULES': 1, 'NUM_BRANCHES': 2, 'BLOCK': 'BASIC', 'NUM_BLOCKS': [4, 4], 'NUM_CHANNELS': [32, 64]}
        num_channels = [ch * BasicBlock.expansion for ch in self.stage2_cfg['NUM_CHANNELS']]
        self.transition1 = self._make_transition_layer([256], num_channels)
        self.stage2, pre_stage_channels = self._make_stage(self.stage2_cfg, num_channels)
        # Stage 3: 3 branches at widths 32/64/128.
        self.stage3_cfg = {'NUM_MODULES': 4, 'NUM_BRANCHES': 3, 'BLOCK': 'BASIC', 'NUM_BLOCKS': [4, 4, 4], 'NUM_CHANNELS': [32, 64, 128]}
        num_channels = [ch * BasicBlock.expansion for ch in self.stage3_cfg['NUM_CHANNELS']]
        self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels)
        self.stage3, pre_stage_channels = self._make_stage(self.stage3_cfg, num_channels)
        # Stage 4: 4 branches; multi_scale_output=False keeps only the
        # highest-resolution output.
        self.stage4_cfg = {'NUM_MODULES': 3, 'NUM_BRANCHES': 4, 'BLOCK': 'BASIC', 'NUM_BLOCKS': [4, 4, 4, 4], 'NUM_CHANNELS': [32, 64, 128, 256]}
        num_channels = [ch * BasicBlock.expansion for ch in self.stage4_cfg['NUM_CHANNELS']]
        self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels)
        self.stage4, pre_stage_channels = self._make_stage(self.stage4_cfg, num_channels, multi_scale_output=False)
        # 1x1 head: high-resolution features -> per-joint heatmaps.
        self.final_layer = nn.Conv2d(pre_stage_channels[0], num_joints, kernel_size=1, stride=1, padding=0)

    def _make_transition_layer(self, num_channels_pre, num_channels_cur):
        """Build per-branch adapters between stages.

        Existing branches get a 3x3 conv when channel counts differ (None for
        identity); each new, lower-resolution branch is created by stride-2
        conv(s) from the previous stage's deepest branch.
        """
        num_branches_cur = len(num_channels_cur)
        num_branches_pre = len(num_channels_pre)
        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur[i] != num_channels_pre[i]:
                    transition_layers.append(nn.Sequential(
                        nn.Conv2d(num_channels_pre[i], num_channels_cur[i], 3, 1, 1, bias=False),
                        nn.BatchNorm2d(num_channels_cur[i]), nn.ReLU(inplace=True)))
                else:
                    transition_layers.append(None)
            else:
                # New branch: downsample from the last pre-stage branch.
                conv3x3s = []
                for j in range(i + 1 - num_branches_pre):
                    inchannels = num_channels_pre[-1]
                    # Only the final conv in the chain changes channel count.
                    outchannels = num_channels_cur[i] if j == i - num_branches_pre else inchannels
                    conv3x3s.append(nn.Sequential(
                        nn.Conv2d(inchannels, outchannels, 3, 2, 1, bias=False),
                        nn.BatchNorm2d(outchannels), nn.ReLU(inplace=True)))
                transition_layers.append(nn.Sequential(*conv3x3s))
        return nn.ModuleList(transition_layers)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack *blocks* residual blocks, projecting the shortcut when needed."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion, 1, stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM))
        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def _make_stage(self, layer_config, num_inchannels, multi_scale_output=True):
        """Build a stage as a sequence of HighResolutionModule instances."""
        num_modules = layer_config['NUM_MODULES']
        num_branches = layer_config['NUM_BRANCHES']
        num_blocks = layer_config['NUM_BLOCKS']
        num_channels = layer_config['NUM_CHANNELS']
        block = blocks_dict[layer_config['BLOCK']]
        modules = []
        for i in range(num_modules):
            # Only the stage's last module may collapse to single-scale output.
            reset_multi_scale = multi_scale_output or i < num_modules - 1
            modules.append(HighResolutionModule(num_branches, block, num_blocks, num_inchannels, num_channels, 'SUM', reset_multi_scale))
            num_inchannels = modules[-1].get_num_inchannels()
        return nn.Sequential(*modules), num_inchannels

    def forward(self, x):
        # Stem -> 1/4 resolution.
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.layer1(x)
        x_list = [self.transition1[i](x) if self.transition1[i] else x for i in range(self.stage2_cfg['NUM_BRANCHES'])]
        y_list = self.stage2(x_list)
        x_list = []
        for i in range(self.stage3_cfg['NUM_BRANCHES']):
            # New branches are derived from the deepest available output.
            idx = min(i, len(y_list)-1)
            x_list.append(self.transition2[i](y_list[idx]) if self.transition2[i] else y_list[i])
        y_list = self.stage3(x_list)
        x_list = []
        for i in range(self.stage4_cfg['NUM_BRANCHES']):
            idx = min(i, len(y_list)-1)
            x_list.append(self.transition3[i](y_list[idx]) if self.transition3[i] else y_list[i])
        y_list = self.stage4(x_list)
        # Heatmaps from the highest-resolution branch only.
        return self.final_layer(y_list[0])
# ============================================================================
# GLOBAL MODEL
# ============================================================================
# Singleton model instance, populated lazily by load_model().
model = None
# Prefer GPU when available; weights and input tensors are moved here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def load_model():
    """Download and initialize the global HRNet model on first call.

    Subsequent calls return the cached instance. Handles checkpoints saved
    either as a raw state dict or nested under 'model_state_dict' /
    'state_dict', and strips common wrapper prefixes ('module.',
    'backbone.', 'model.') from parameter names.
    """
    global model
    if model is not None:
        return model
    print(f"Cargando modelo en {device}...")
    model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)
    # NOTE(review): weights_only=False unpickles arbitrary objects; acceptable
    # only because the checkpoint comes from a pinned public repo — verify trust.
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    for key in ('model_state_dict', 'state_dict'):
        if key in checkpoint:
            state_dict = checkpoint[key]
            break
    else:
        state_dict = checkpoint
    # Normalize parameter names saved under DataParallel/backbone wrappers.
    cleaned = {}
    for raw_name, tensor in state_dict.items():
        name = raw_name
        for prefix in ('module.', 'backbone.', 'model.'):
            if name.startswith(prefix):
                name = name[len(prefix):]
        cleaned[name] = tensor
    model = HRNet(num_joints=NUM_LANDMARKS)
    # strict=False tolerates keys that do not match the architecture exactly.
    model.load_state_dict(cleaned, strict=False)
    model.to(device)
    model.eval()
    print("✓ Modelo cargado!")
    return model
def detect_landmarks(image):
    """Run HRNet on *image* and return all 19 landmarks in original-image pixels.

    Args:
        image: PIL Image or numpy array, any mode/size.

    Returns:
        List of dicts with keys id, name, abbrev, x, y (rounded to 0.1 px in
        the original image's coordinate system) and confidence (peak heatmap
        value, rounded to 0.01).
    """
    net = load_model()
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    original_size = image.size
    if image.mode != 'RGB':
        image = image.convert('RGB')
    # Fixed-size network input with ImageNet normalization.
    resized = image.resize((INPUT_SIZE, INPUT_SIZE), Image.Resampling.BILINEAR)
    arr = np.array(resized).astype(np.float32) / 255.0
    arr = (arr - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])
    tensor = torch.from_numpy(arr.transpose(2, 0, 1)).float().unsqueeze(0).to(device)
    with torch.no_grad():
        heatmaps = net(tensor).cpu().numpy()
    _, num_joints, h, w = heatmaps.shape
    # Decode each heatmap's argmax into (x, y) heatmap coordinates.
    flat = heatmaps.reshape((1, num_joints, -1))
    best = flat.argmax(axis=2)
    confidences = flat.max(axis=2)
    coords = np.zeros((1, num_joints, 2), dtype=np.float32)
    coords[:, :, 0] = best % w
    coords[:, :, 1] = best // w
    # Scale heatmap coordinates back to the original image resolution.
    orig_w, orig_h = original_size
    scale_x = orig_w / w
    scale_y = orig_h / h
    return [
        {
            "id": i,
            "name": LANDMARK_NAMES[i],
            "abbrev": LANDMARK_ABBREV[i],
            "x": round(float(coords[0, i, 0] * scale_x), 1),
            "y": round(float(coords[0, i, 1] * scale_y), 1),
            "confidence": round(float(confidences[0, i]), 2),
        }
        for i in range(NUM_LANDMARKS)
    ]
def draw_landmarks_clean(image, landmarks, show_labels=True):
    """Annotate *image* with the detected landmarks.

    Each point is drawn as a group-colored dot with a black outline. When
    show_labels is True, the abbreviation is placed at a collision-free spot
    (via find_label_position) on a black background and connected to its
    point with a thin gray line whenever it lands more than 15 px away.

    Args:
        image: PIL Image or numpy array at original resolution.
        landmarks: landmark dicts from detect_landmarks().
        show_labels: draw abbreviation labels next to the points.

    Returns:
        A new annotated RGB PIL Image; the input is left untouched.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    img_draw = image.copy().convert('RGB')
    draw = ImageDraw.Draw(img_draw)
    # Only the label font is ever used for drawing; the bold variant loaded
    # previously was dead code and has been removed.
    try:
        font_label = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 11)
    except OSError:  # DejaVu not installed; fall back to PIL's built-in font
        font_label = ImageFont.load_default()
    radius = 6
    # Draw all points first so labels are painted on top of them.
    for lm in landmarks:
        x, y = lm['x'], lm['y']
        color = get_color_for_landmark(lm['id'])
        # Black outline circle, then the colored dot on top.
        draw.ellipse([x-radius-1, y-radius-1, x+radius+1, y+radius+1], fill=(0, 0, 0))
        draw.ellipse([x-radius, y-radius, x+radius, y+radius], fill=color)
    if show_labels:
        label_positions = []
        for lm in landmarks:
            x, y = lm['x'], lm['y']
            # Pick a label spot that avoids previously placed labels.
            best_pos = find_label_position(x, y, lm['abbrev'], label_positions, font_label, draw, img_draw.size)
            label_positions.append(best_pos)
            lx, ly = best_pos
            text = lm['abbrev']
            # Connector line when the label landed away from its point.
            dist = ((lx - x)**2 + (ly - y)**2)**0.5
            if dist > 15:
                draw.line([(x, y), (lx, ly)], fill=(100, 100, 100), width=1)
            # Black background box behind the text for readability. (The
            # canvas is RGB, so the alpha in the previous (0, 0, 0, 200) fill
            # was ignored anyway; plain black is the explicit equivalent.)
            bbox = draw.textbbox((lx, ly), text, font=font_label)
            padding = 2
            draw.rectangle([bbox[0]-padding, bbox[1]-padding, bbox[2]+padding, bbox[3]+padding],
                           fill=(0, 0, 0))
            color = get_color_for_landmark(lm['id'])
            draw.text((lx, ly), text, fill=color, font=font_label)
    return img_draw
def find_label_position(x, y, text, existing_positions, font, draw, img_size):
    """Pick a label anchor near (x, y) that stays on-image and avoids overlap.

    Candidate offsets are tried in preference order (right, left, above,
    below, then diagonals); the first one that keeps the text inside a 5-px
    image margin and does not collide with any already-placed label wins.
    Falls back to "right of the point" when every candidate fails.
    """
    bbox = draw.textbbox((0, 0), text, font=font)
    text_w = bbox[2] - bbox[0]
    text_h = bbox[3] - bbox[1]
    img_w, img_h = img_size

    candidates = (
        (12, -5),             # right
        (12, 5),              # right, below
        (-text_w - 12, -5),   # left
        (-text_w - 12, 5),    # left, below
        (5, -text_h - 8),     # above
        (5, 12),              # below
        (15, -10),            # upper-right diagonal
        (15, 10),             # lower-right diagonal
        (-text_w - 15, -10),  # upper-left diagonal
        (-text_w - 15, 10),   # lower-left diagonal
    )

    def _fits(lx, ly):
        # Keep a 5-px margin from every image border.
        if lx < 5 or ly < 5 or lx + text_w > img_w - 5 or ly + text_h > img_h - 5:
            return False
        # Reject spots too close to any label placed so far.
        return all(
            abs(lx - ex) >= text_w + 5 or abs(ly - ey) >= text_h + 3
            for ex, ey in existing_positions
        )

    for ox, oy in candidates:
        pos = (x + ox, y + oy)
        if _fits(*pos):
            return pos
    # Every candidate failed: default to just right of the point.
    return (x + 12, y - 5)
def create_legend_image(landmarks):
    """Render a dark side-panel legend listing every landmark by group.

    Args:
        landmarks: landmark dicts from detect_landmarks(), indexed by id as
            referenced from LANDMARK_GROUPS.

    Returns:
        A 280x520 RGB PIL Image with one color-coded section per group.
    """
    width = 280
    height = 520
    legend = Image.new('RGB', (width, height), (30, 30, 30))
    draw = ImageDraw.Draw(legend)
    # Narrowed from a bare `except:` — ImageFont.truetype raises OSError
    # when the font file is unreadable/missing.
    try:
        font_title = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 14)
        font_text = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 11)
    except OSError:
        font_title = ImageFont.load_default()
        font_text = font_title
    y_pos = 10
    draw.text((10, y_pos), "LANDMARKS DETECTADOS", fill=(255, 255, 255), font=font_title)
    y_pos += 25
    draw.line([(10, y_pos), (width-10, y_pos)], fill=(100, 100, 100), width=1)
    y_pos += 10
    for group_name, group_data in LANDMARK_GROUPS.items():
        # Group heading in the group's own color.
        draw.text((10, y_pos), group_name, fill=group_data["color"], font=font_title)
        y_pos += 18
        for idx in group_data["indices"]:
            lm = landmarks[idx]
            # Color swatch followed by "ABBR: (x, y)".
            draw.ellipse([15, y_pos+2, 23, y_pos+10], fill=group_data["color"])
            text = f"{lm['abbrev']}: ({lm['x']:.0f}, {lm['y']:.0f})"
            draw.text((30, y_pos), text, fill=(220, 220, 220), font=font_text)
            y_pos += 16
        y_pos += 8
    return legend
def combine_images(main_image, legend):
    """Paste *legend* to the right of *main_image* on a dark canvas.

    The legend is scaled down proportionally when taller than the main
    image; a 10-px gutter separates the two panels.

    Returns:
        A new RGB PIL Image containing both panels.
    """
    main_w, main_h = main_image.size
    legend_w, legend_h = legend.size
    if legend_h > main_h:
        # Shrink the legend so it fits the main image's height.
        scale = main_h / legend_h
        legend = legend.resize((int(legend_w * scale), main_h), Image.Resampling.LANCZOS)
        legend_w, legend_h = legend.size
    canvas_size = (main_w + legend_w + 10, max(main_h, legend_h))
    combined = Image.new('RGB', canvas_size, (30, 30, 30))
    combined.paste(main_image, (0, 0))
    combined.paste(legend, (main_w + 10, 0))
    return combined
# ============================================================================
# INTERFAZ GRADIO
# ============================================================================
def process_image(image, show_labels):
    """Gradio callback: detect landmarks, return annotated image + JSON text.

    Returns (combined_image, json_string) on success, or (None, message)
    when no image was provided or an error occurred.
    """
    if image is None:
        return None, "Por favor sube una imagen"
    try:
        landmarks = detect_landmarks(image)
        annotated = draw_landmarks_clean(image, landmarks, show_labels)
        combined = combine_images(annotated, create_legend_image(landmarks))
        payload = {
            "num_landmarks": len(landmarks),
            "landmarks": [
                {"id": lm["id"], "name": lm["name"], "x": lm["x"], "y": lm["y"]}
                for lm in landmarks
            ],
        }
        return combined, json.dumps(payload, indent=2)
    except Exception as e:
        # Surface the full traceback in the UI instead of crashing the app.
        import traceback
        return None, f"Error: {e}\n{traceback.format_exc()}"
print("=" * 50)
print("Cephalometric Landmark Detection v1.0")
print("=" * 50)
# Eagerly load the model at startup so the first user request has no delay.
load_model()
# Gradio UI: upload column on the left, annotated result + JSON on the right.
with gr.Blocks(title="Cephalometric Landmark Detection", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 🦷 Detección de Landmarks Cefalométricos
Detección automática de **19 puntos anatómicos** en radiografías cefalométricas laterales usando HRNet-W32.
""")
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(label="📤 Subir Radiografía", type="pil", height=400)
            show_labels = gr.Checkbox(label="Mostrar etiquetas en imagen", value=True)
            detect_btn = gr.Button("🔍 Detectar Landmarks", variant="primary", size="lg")
        with gr.Column(scale=2):
            output_image = gr.Image(label="📍 Resultado con Leyenda", height=500)
            with gr.Accordion("📋 Datos JSON", open=False):
                output_json = gr.Code(label="Coordenadas", language="json", lines=10)
    gr.Markdown("""
---
### 📍 Grupos de Landmarks
| Color | Grupo | Landmarks |
|-------|-------|-----------|
| 🔴 | Craneal | Sella (S), Nasion (N), Articulare (Ar) |
| 🟢 | Orbital | Orbitale (Or), Porion (Po) |
| 🔵 | Maxilar | Point A, ANS, PNS |
| 🟣 | Mandibular | Point B, Pogonion, Menton, Gnathion, Gonion |
| 🟡 | Dental | Upper Incisor (U1), Lower Incisor (L1) |
| 🩵 | Tejido Blando | Upper Lip, Lower Lip, Subnasale, Soft Tissue Pog |
---
> ⚠️ **Nota:** Los puntos detectados son una aproximación inicial. En pacientes en crecimiento y desarrollo o con condiciones anatómicas atípicas, se recomienda ajuste manual.
""")
    # Wire the detect button to the inference callback.
    detect_btn.click(fn=process_image, inputs=[input_image, show_labels], outputs=[output_image, output_json])
# Enable request queueing so concurrent users are served in order.
demo.queue()
if __name__ == "__main__":
    # Start the Gradio server (server-side rendering disabled).
    demo.launch(ssr_mode=False)