Klinapps commited on
Commit
11e7d47
·
verified ·
1 Parent(s): 17ade45

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +295 -320
  2. requirements.txt +0 -8
app.py CHANGED
@@ -1,105 +1,272 @@
1
  """
2
  Cephalometric Landmark Detection API
3
- HRNet-W32 based automatic landmark detection for lateral cephalometric radiographs
4
-
5
- Space para integración con Klinafy
6
  """
7
 
8
  import os
9
  import json
10
  import numpy as np
11
- from PIL import Image
12
  import torch
13
- import torch.nn.functional as F
14
  from huggingface_hub import hf_hub_download
15
  import gradio as gr
16
 
17
- from hrnet import get_hrnet_w32
18
-
19
-
20
  # ============================================================================
21
  # CONFIGURACIÓN
22
  # ============================================================================
23
 
24
  MODEL_REPO = "cwlachap/hrnet-cephalometric-landmark-detection"
25
  MODEL_FILE = "best_model.pth"
26
- INPUT_SIZE = 768
27
- HEATMAP_SIZE = 192
28
  NUM_LANDMARKS = 19
 
29
 
30
- # Nombres de los 19 landmarks en orden del modelo
31
  LANDMARK_NAMES = [
32
- "S", # 0 - Sella turcica
33
- "N", # 1 - Nasion
34
- "Or", # 2 - Orbitale
35
- "Po", # 3 - Porion
36
- "Ba", # 4 - Basion
37
- "Pt", # 5 - Pterygoid point
38
- "ANS", # 6 - Anterior Nasal Spine
39
- "PNS", # 7 - Posterior Nasal Spine
40
- "A", # 8 - Point A (Subspinale)
41
- "U1T", # 9 - Upper Incisor Tip
42
- "U1R", # 10 - Upper Incisor Root
43
- "L1T", # 11 - Lower Incisor Tip
44
- "L1R", # 12 - Lower Incisor Root
45
- "B", # 13 - Point B (Supramentale)
46
- "Pog", # 14 - Pogonion
47
- "Gn", # 15 - Gnathion
48
- "Me", # 16 - Menton
49
- "Go", # 17 - Gonion
50
- "Ar" # 18 - Articulare
51
  ]
52
 
53
- # Colores para visualización (RGB)
54
  LANDMARK_COLORS = {
55
- "cranial": (255, 0, 0), # Rojo - S, N, Ba, Ar
56
- "orbital": (0, 255, 0), # Verde - Or, Po
57
- "maxilar": (0, 0, 255), # Azul - ANS, PNS, A, Pt
58
- "dental": (255, 255, 0), # Amarillo - U1T, U1R, L1T, L1R
59
- "mandibular": (255, 0, 255) # Magenta - B, Pog, Gn, Me, Go
 
60
  }
61
 
62
- LANDMARK_GROUPS = {
63
- "S": "cranial", "N": "cranial", "Ba": "cranial", "Ar": "cranial",
64
- "Or": "orbital", "Po": "orbital",
65
- "ANS": "maxilar", "PNS": "maxilar", "A": "maxilar", "Pt": "maxilar",
66
- "U1T": "dental", "U1R": "dental", "L1T": "dental", "L1R": "dental",
67
- "B": "mandibular", "Pog": "mandibular", "Gn": "mandibular",
68
- "Me": "mandibular", "Go": "mandibular"
69
- }
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
  # ============================================================================
73
- # MODELO
74
  # ============================================================================
75
 
76
- # Variable global para el modelo
77
  model = None
78
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
79
 
80
-
81
  def load_model():
82
- """Carga el modelo HRNet desde Hugging Face Hub"""
83
  global model
84
-
85
  if model is not None:
86
  return model
87
 
88
  print(f"Cargando modelo en {device}...")
89
-
90
- # Descargar pesos
91
- model_path = hf_hub_download(
92
- repo_id=MODEL_REPO,
93
- filename=MODEL_FILE
94
- )
95
-
96
- # Crear modelo
97
- model = get_hrnet_w32(num_landmarks=NUM_LANDMARKS)
98
-
99
- # Cargar pesos
100
  checkpoint = torch.load(model_path, map_location=device, weights_only=False)
101
 
102
- # Manejar diferentes formatos de checkpoint
103
  if 'model_state_dict' in checkpoint:
104
  state_dict = checkpoint['model_state_dict']
105
  elif 'state_dict' in checkpoint:
@@ -107,339 +274,147 @@ def load_model():
107
  else:
108
  state_dict = checkpoint
109
 
110
- # Limpiar prefijos si existen
 
 
 
 
 
111
  new_state_dict = {}
112
  for k, v in state_dict.items():
113
- name = k.replace('module.', '') # Remover prefijo de DataParallel
 
 
 
114
  new_state_dict[name] = v
115
 
116
- model.load_state_dict(new_state_dict, strict=False)
 
 
 
 
 
 
 
 
 
 
117
  model.to(device)
118
  model.eval()
119
-
120
- print("Modelo cargado exitosamente!")
121
  return model
122
 
 
 
 
 
 
 
 
 
 
123
 
124
- # ============================================================================
125
- # PREPROCESAMIENTO
126
- # ============================================================================
127
-
128
- def preprocess_image(image):
129
- """
130
- Preprocesa la imagen para el modelo
131
-
132
- Args:
133
- image: PIL Image o numpy array
134
 
135
- Returns:
136
- tensor: Tensor normalizado [1, 3, 768, 768]
137
- original_size: (width, height) original
138
- """
139
- # Convertir a PIL si es necesario
140
  if isinstance(image, np.ndarray):
141
  image = Image.fromarray(image)
142
 
143
- # Guardar tamaño original
144
- original_size = image.size # (width, height)
145
-
146
- # Convertir a RGB si es necesario
147
  if image.mode != 'RGB':
148
  image = image.convert('RGB')
149
 
150
- # Redimensionar a 768x768
151
- image = image.resize((INPUT_SIZE, INPUT_SIZE), Image.Resampling.BILINEAR)
152
-
153
- # Convertir a tensor
154
- img_array = np.array(image).astype(np.float32) / 255.0
155
 
156
- # Normalizar con ImageNet stats
157
  mean = np.array([0.485, 0.456, 0.406])
158
  std = np.array([0.229, 0.224, 0.225])
159
  img_array = (img_array - mean) / std
160
 
161
- # Cambiar a formato CHW y agregar batch dimension
162
- img_tensor = torch.from_numpy(img_array.transpose(2, 0, 1)).float()
163
- img_tensor = img_tensor.unsqueeze(0)
164
-
165
- return img_tensor, original_size
166
-
167
-
168
- # ============================================================================
169
- # POSTPROCESAMIENTO
170
- # ============================================================================
171
-
172
- def get_max_preds(heatmaps):
173
- """
174
- Obtiene las coordenadas del máximo de cada heatmap
175
-
176
- Args:
177
- heatmaps: tensor [batch, num_landmarks, H, W]
178
 
179
- Returns:
180
- preds: coordenadas [batch, num_landmarks, 2]
181
- maxvals: valores de confianza [batch, num_landmarks, 1]
182
- """
183
- batch_size = heatmaps.shape[0]
184
- num_joints = heatmaps.shape[1]
185
- width = heatmaps.shape[3]
186
-
187
- heatmaps_reshaped = heatmaps.reshape((batch_size, num_joints, -1))
188
- idx = np.argmax(heatmaps_reshaped, axis=2)
189
- maxvals = np.amax(heatmaps_reshaped, axis=2)
190
-
191
- maxvals = maxvals.reshape((batch_size, num_joints, 1))
192
- idx = idx.reshape((batch_size, num_joints, 1))
193
-
194
- preds = np.tile(idx, (1, 1, 2)).astype(np.float32)
195
- preds[:, :, 0] = (preds[:, :, 0]) % width
196
- preds[:, :, 1] = np.floor((preds[:, :, 1]) / width)
197
-
198
- return preds, maxvals
199
-
200
-
201
- def heatmaps_to_landmarks(heatmaps, original_size):
202
- """
203
- Convierte heatmaps a coordenadas de landmarks
204
-
205
- Args:
206
- heatmaps: tensor [1, 19, H, W]
207
- original_size: (width, height) de la imagen original
208
-
209
- Returns:
210
- landmarks: lista de dicts con name, x, y, confidence
211
- """
212
- heatmaps_np = heatmaps.cpu().numpy()
213
 
214
- # Obtener coordenadas del máximo
215
- preds, maxvals = get_max_preds(heatmaps_np)
216
 
217
- # Escalar a tamaño original
218
  orig_w, orig_h = original_size
219
- heatmap_h, heatmap_w = heatmaps_np.shape[2], heatmaps_np.shape[3]
220
 
221
  scale_x = orig_w / heatmap_w
222
  scale_y = orig_h / heatmap_h
223
 
224
  landmarks = []
225
  for i in range(NUM_LANDMARKS):
226
- x = float(preds[0, i, 0] * scale_x)
227
- y = float(preds[0, i, 1] * scale_y)
228
- conf = float(maxvals[0, i, 0])
229
-
230
  landmarks.append({
231
  "name": LANDMARK_NAMES[i],
232
- "x": round(x, 2),
233
- "y": round(y, 2),
234
- "confidence": round(conf, 4),
235
- "group": LANDMARK_GROUPS[LANDMARK_NAMES[i]]
236
  })
237
 
238
- return landmarks
239
-
240
-
241
- # ============================================================================
242
- # INFERENCIA
243
- # ============================================================================
244
-
245
- def detect_landmarks(image):
246
- """
247
- Detecta landmarks cefalométricos en una imagen
248
-
249
- Args:
250
- image: PIL Image o numpy array
251
-
252
- Returns:
253
- landmarks: lista de dicts con name, x, y, confidence
254
- annotated_image: imagen con landmarks dibujados
255
- """
256
- # Cargar modelo si no está cargado
257
- model = load_model()
258
-
259
- # Preprocesar
260
- img_tensor, original_size = preprocess_image(image)
261
- img_tensor = img_tensor.to(device)
262
-
263
- # Inferencia
264
- with torch.no_grad():
265
- heatmaps = model(img_tensor)
266
-
267
- # Postprocesar
268
- landmarks = heatmaps_to_landmarks(heatmaps, original_size)
269
-
270
- # Crear imagen anotada
271
  annotated = draw_landmarks(image, landmarks)
272
-
273
  return landmarks, annotated
274
 
275
-
276
- def draw_landmarks(image, landmarks, radius=5):
277
- """
278
- Dibuja los landmarks en la imagen
279
-
280
- Args:
281
- image: PIL Image o numpy array
282
- landmarks: lista de dicts con coordenadas
283
- radius: radio del círculo
284
-
285
- Returns:
286
- PIL Image con landmarks dibujados
287
- """
288
- from PIL import ImageDraw, ImageFont
289
-
290
  if isinstance(image, np.ndarray):
291
  image = Image.fromarray(image)
292
-
293
- # Crear copia para dibujar
294
- img_draw = image.copy()
295
- if img_draw.mode != 'RGB':
296
- img_draw = img_draw.convert('RGB')
297
-
298
  draw = ImageDraw.Draw(img_draw)
299
 
300
- # Intentar cargar fuente, usar default si falla
301
  try:
302
- font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 12)
303
  except:
304
  font = ImageFont.load_default()
305
 
306
  for lm in landmarks:
307
  x, y = lm['x'], lm['y']
308
- color = LANDMARK_COLORS[lm['group']]
309
-
310
- # Dibujar círculo
311
- draw.ellipse([x-radius, y-radius, x+radius, y+radius],
312
- fill=color, outline=(255, 255, 255))
313
-
314
- # Dibujar nombre
315
- draw.text((x+radius+2, y-radius), lm['name'],
316
- fill=(255, 255, 255), font=font,
317
- stroke_width=1, stroke_fill=(0, 0, 0))
318
 
319
  return img_draw
320
 
321
-
322
  # ============================================================================
323
  # INTERFAZ GRADIO
324
  # ============================================================================
325
 
326
  def process_image(image):
327
- """Función principal para Gradio"""
328
  if image is None:
329
- return None, "Por favor sube una imagen cefalométrica"
330
 
331
  try:
332
  landmarks, annotated = detect_landmarks(image)
333
-
334
- # Formatear JSON para mostrar
335
- json_output = json.dumps({
336
- "success": True,
337
- "num_landmarks": len(landmarks),
338
- "landmarks": landmarks
339
- }, indent=2)
340
-
341
- return annotated, json_output
342
-
343
- except Exception as e:
344
- return None, json.dumps({
345
- "success": False,
346
- "error": str(e)
347
- }, indent=2)
348
-
349
-
350
- def api_predict(image):
351
- """Endpoint API para integración con Klinafy"""
352
- if image is None:
353
- return {"success": False, "error": "No image provided"}
354
-
355
- try:
356
- landmarks, _ = detect_landmarks(image)
357
-
358
- return {
359
  "success": True,
360
- "model": "HRNet-W32",
361
  "num_landmarks": len(landmarks),
362
  "landmarks": landmarks
363
  }
364
-
365
  except Exception as e:
366
- return {
367
- "success": False,
368
- "error": str(e)
369
- }
370
-
371
 
372
- # ============================================================================
373
- # CREAR APP
374
- # ============================================================================
375
-
376
- # Cargar modelo al inicio
377
- print("Inicializando modelo...")
378
  load_model()
379
 
380
- # Crear interfaz Gradio
381
- with gr.Blocks(title="Cephalometric Landmark Detection") as demo:
382
- gr.Markdown("""
383
- # 🦷 Detección de Landmarks Cefalométricos
384
-
385
- Detección automática de **19 puntos cefalométricos** usando HRNet-W32.
386
-
387
- ### Landmarks detectados:
388
- - **Craneales** (rojo): S, N, Ba, Ar
389
- - **Orbitales** (verde): Or, Po
390
- - **Maxilares** (azul): ANS, PNS, A, Pt
391
- - **Dentales** (amarillo): U1T, U1R, L1T, L1R
392
- - **Mandibulares** (magenta): B, Pog, Gn, Me, Go
393
-
394
- ---
395
- """)
396
 
397
  with gr.Row():
398
  with gr.Column():
399
- input_image = gr.Image(
400
- label="Radiografía Cefalométrica Lateral",
401
- type="pil"
402
- )
403
- detect_btn = gr.Button("🔍 Detectar Landmarks", variant="primary")
404
-
405
  with gr.Column():
406
- output_image = gr.Image(
407
- label="Imagen con Landmarks"
408
- )
409
- output_json = gr.Code(
410
- label="Coordenadas (JSON)",
411
- language="json"
412
- )
413
-
414
- detect_btn.click(
415
- fn=process_image,
416
- inputs=[input_image],
417
- outputs=[output_image, output_json]
418
- )
419
-
420
- gr.Markdown("""
421
- ---
422
- ### 📡 API Endpoint
423
 
424
- Para integración programática (ej. Klinafy):
425
-
426
- ```javascript
427
- const response = await fetch('https://YOUR-SPACE.hf.space/api/predict', {
428
- method: 'POST',
429
- headers: { 'Content-Type': 'application/json' },
430
- body: JSON.stringify({ data: [base64Image] })
431
- });
432
- const result = await response.json();
433
- // result.data[0] = { success: true, landmarks: [...] }
434
- ```
435
-
436
- ---
437
- **Modelo**: HRNet-W32 | **Precisión**: MRE ~1.5mm | **Licencia**: MIT
438
- """)
439
-
440
 
441
- # Habilitar API
442
  demo.queue()
443
 
444
- iif __name__ == "__main__":
445
  demo.launch(ssr_mode=False)
 
1
  """
2
  Cephalometric Landmark Detection API
3
+ HRNet-W32 para 19 landmarks cefalométricos
 
 
4
  """
5
 
6
  import os
7
  import json
8
  import numpy as np
9
+ from PIL import Image, ImageDraw, ImageFont
10
  import torch
11
+ import torch.nn as nn
12
  from huggingface_hub import hf_hub_download
13
  import gradio as gr
14
 
 
 
 
15
  # ============================================================================
16
  # CONFIGURACIÓN
17
  # ============================================================================
18
 
19
  MODEL_REPO = "cwlachap/hrnet-cephalometric-landmark-detection"
20
  MODEL_FILE = "best_model.pth"
 
 
21
  NUM_LANDMARKS = 19
22
+ INPUT_SIZE = 768
23
 
 
24
  LANDMARK_NAMES = [
25
+ "S", "N", "Or", "Po", "Ba", "Pt", "ANS", "PNS", "A",
26
+ "U1T", "U1R", "L1T", "L1R", "B", "Pog", "Gn", "Me", "Go", "Ar"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  ]
28
 
 
29
  LANDMARK_COLORS = {
30
+ 'S': (255, 0, 0), 'N': (255, 0, 0), 'Ba': (255, 0, 0), 'Ar': (255, 0, 0),
31
+ 'Or': (0, 255, 0), 'Po': (0, 255, 0),
32
+ 'ANS': (0, 100, 255), 'PNS': (0, 100, 255), 'A': (0, 100, 255), 'Pt': (0, 100, 255),
33
+ 'U1T': (255, 255, 0), 'U1R': (255, 255, 0), 'L1T': (255, 255, 0), 'L1R': (255, 255, 0),
34
+ 'B': (255, 0, 255), 'Pog': (255, 0, 255), 'Gn': (255, 0, 255),
35
+ 'Me': (255, 0, 255), 'Go': (255, 0, 255)
36
  }
37
 
38
+ # ============================================================================
39
+ # ARQUITECTURA HRNET-W32
40
+ # ============================================================================
 
 
 
 
 
41
 
42
+ BN_MOMENTUM = 0.1
43
+
44
+ def conv3x3(in_planes, out_planes, stride=1):
45
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
46
+
47
+ class BasicBlock(nn.Module):
48
+ expansion = 1
49
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
50
+ super(BasicBlock, self).__init__()
51
+ self.conv1 = conv3x3(inplanes, planes, stride)
52
+ self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
53
+ self.relu = nn.ReLU(inplace=True)
54
+ self.conv2 = conv3x3(planes, planes)
55
+ self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
56
+ self.downsample = downsample
57
+
58
+ def forward(self, x):
59
+ residual = x
60
+ out = self.relu(self.bn1(self.conv1(x)))
61
+ out = self.bn2(self.conv2(out))
62
+ if self.downsample is not None:
63
+ residual = self.downsample(x)
64
+ return self.relu(out + residual)
65
+
66
+ class Bottleneck(nn.Module):
67
+ expansion = 4
68
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
69
+ super(Bottleneck, self).__init__()
70
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
71
+ self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
72
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
73
+ self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
74
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
75
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion, momentum=BN_MOMENTUM)
76
+ self.relu = nn.ReLU(inplace=True)
77
+ self.downsample = downsample
78
+
79
+ def forward(self, x):
80
+ residual = x
81
+ out = self.relu(self.bn1(self.conv1(x)))
82
+ out = self.relu(self.bn2(self.conv2(out)))
83
+ out = self.bn3(self.conv3(out))
84
+ if self.downsample is not None:
85
+ residual = self.downsample(x)
86
+ return self.relu(out + residual)
87
+
88
+ class HighResolutionModule(nn.Module):
89
+ def __init__(self, num_branches, blocks, num_blocks, num_inchannels, num_channels, fuse_method, multi_scale_output=True):
90
+ super(HighResolutionModule, self).__init__()
91
+ self.num_inchannels = num_inchannels
92
+ self.num_branches = num_branches
93
+ self.multi_scale_output = multi_scale_output
94
+ self.branches = self._make_branches(num_branches, blocks, num_blocks, num_channels)
95
+ self.fuse_layers = self._make_fuse_layers()
96
+ self.relu = nn.ReLU(True)
97
+
98
+ def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride=1):
99
+ downsample = None
100
+ if stride != 1 or self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
101
+ downsample = nn.Sequential(
102
+ nn.Conv2d(self.num_inchannels[branch_index], num_channels[branch_index] * block.expansion, 1, stride, bias=False),
103
+ nn.BatchNorm2d(num_channels[branch_index] * block.expansion, momentum=BN_MOMENTUM))
104
+ layers = [block(self.num_inchannels[branch_index], num_channels[branch_index], stride, downsample)]
105
+ self.num_inchannels[branch_index] = num_channels[branch_index] * block.expansion
106
+ for _ in range(1, num_blocks[branch_index]):
107
+ layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index]))
108
+ return nn.Sequential(*layers)
109
+
110
+ def _make_branches(self, num_branches, block, num_blocks, num_channels):
111
+ return nn.ModuleList([self._make_one_branch(i, block, num_blocks, num_channels) for i in range(num_branches)])
112
+
113
+ def _make_fuse_layers(self):
114
+ if self.num_branches == 1:
115
+ return None
116
+ fuse_layers = []
117
+ for i in range(self.num_branches if self.multi_scale_output else 1):
118
+ fuse_layer = []
119
+ for j in range(self.num_branches):
120
+ if j > i:
121
+ fuse_layer.append(nn.Sequential(
122
+ nn.Conv2d(self.num_inchannels[j], self.num_inchannels[i], 1, bias=False),
123
+ nn.BatchNorm2d(self.num_inchannels[i]),
124
+ nn.Upsample(scale_factor=2**(j-i), mode='nearest')))
125
+ elif j == i:
126
+ fuse_layer.append(None)
127
+ else:
128
+ conv3x3s = []
129
+ for k in range(i-j):
130
+ out_ch = self.num_inchannels[i] if k == i - j - 1 else self.num_inchannels[j]
131
+ conv3x3s.append(nn.Sequential(
132
+ nn.Conv2d(self.num_inchannels[j], out_ch, 3, 2, 1, bias=False),
133
+ nn.BatchNorm2d(out_ch),
134
+ nn.ReLU(True) if k < i - j - 1 else nn.Identity()))
135
+ fuse_layer.append(nn.Sequential(*conv3x3s))
136
+ fuse_layers.append(nn.ModuleList(fuse_layer))
137
+ return nn.ModuleList(fuse_layers)
138
+
139
+ def get_num_inchannels(self):
140
+ return self.num_inchannels
141
+
142
+ def forward(self, x):
143
+ if self.num_branches == 1:
144
+ return [self.branches[0](x[0])]
145
+ for i in range(self.num_branches):
146
+ x[i] = self.branches[i](x[i])
147
+ x_fuse = []
148
+ for i in range(len(self.fuse_layers)):
149
+ y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
150
+ for j in range(1, self.num_branches):
151
+ if i == j:
152
+ y = y + x[j]
153
+ else:
154
+ y = y + self.fuse_layers[i][j](x[j])
155
+ x_fuse.append(self.relu(y))
156
+ return x_fuse
157
+
158
+ blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck}
159
+
160
+ class HRNet(nn.Module):
161
+ def __init__(self, num_joints=19):
162
+ super(HRNet, self).__init__()
163
+ self.inplanes = 64
164
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
165
+ self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
166
+ self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
167
+ self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
168
+ self.relu = nn.ReLU(inplace=True)
169
+ self.layer1 = self._make_layer(Bottleneck, 64, 4)
170
+
171
+ self.stage2_cfg = {'NUM_MODULES': 1, 'NUM_BRANCHES': 2, 'BLOCK': 'BASIC', 'NUM_BLOCKS': [4, 4], 'NUM_CHANNELS': [32, 64]}
172
+ num_channels = [ch * BasicBlock.expansion for ch in self.stage2_cfg['NUM_CHANNELS']]
173
+ self.transition1 = self._make_transition_layer([256], num_channels)
174
+ self.stage2, pre_stage_channels = self._make_stage(self.stage2_cfg, num_channels)
175
+
176
+ self.stage3_cfg = {'NUM_MODULES': 4, 'NUM_BRANCHES': 3, 'BLOCK': 'BASIC', 'NUM_BLOCKS': [4, 4, 4], 'NUM_CHANNELS': [32, 64, 128]}
177
+ num_channels = [ch * BasicBlock.expansion for ch in self.stage3_cfg['NUM_CHANNELS']]
178
+ self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels)
179
+ self.stage3, pre_stage_channels = self._make_stage(self.stage3_cfg, num_channels)
180
+
181
+ self.stage4_cfg = {'NUM_MODULES': 3, 'NUM_BRANCHES': 4, 'BLOCK': 'BASIC', 'NUM_BLOCKS': [4, 4, 4, 4], 'NUM_CHANNELS': [32, 64, 128, 256]}
182
+ num_channels = [ch * BasicBlock.expansion for ch in self.stage4_cfg['NUM_CHANNELS']]
183
+ self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels)
184
+ self.stage4, pre_stage_channels = self._make_stage(self.stage4_cfg, num_channels, multi_scale_output=False)
185
+
186
+ self.final_layer = nn.Conv2d(pre_stage_channels[0], num_joints, kernel_size=1, stride=1, padding=0)
187
+
188
+ def _make_transition_layer(self, num_channels_pre, num_channels_cur):
189
+ num_branches_cur = len(num_channels_cur)
190
+ num_branches_pre = len(num_channels_pre)
191
+ transition_layers = []
192
+ for i in range(num_branches_cur):
193
+ if i < num_branches_pre:
194
+ if num_channels_cur[i] != num_channels_pre[i]:
195
+ transition_layers.append(nn.Sequential(
196
+ nn.Conv2d(num_channels_pre[i], num_channels_cur[i], 3, 1, 1, bias=False),
197
+ nn.BatchNorm2d(num_channels_cur[i]), nn.ReLU(inplace=True)))
198
+ else:
199
+ transition_layers.append(None)
200
+ else:
201
+ conv3x3s = []
202
+ for j in range(i + 1 - num_branches_pre):
203
+ inchannels = num_channels_pre[-1]
204
+ outchannels = num_channels_cur[i] if j == i - num_branches_pre else inchannels
205
+ conv3x3s.append(nn.Sequential(
206
+ nn.Conv2d(inchannels, outchannels, 3, 2, 1, bias=False),
207
+ nn.BatchNorm2d(outchannels), nn.ReLU(inplace=True)))
208
+ transition_layers.append(nn.Sequential(*conv3x3s))
209
+ return nn.ModuleList(transition_layers)
210
+
211
+ def _make_layer(self, block, planes, blocks, stride=1):
212
+ downsample = None
213
+ if stride != 1 or self.inplanes != planes * block.expansion:
214
+ downsample = nn.Sequential(
215
+ nn.Conv2d(self.inplanes, planes * block.expansion, 1, stride, bias=False),
216
+ nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM))
217
+ layers = [block(self.inplanes, planes, stride, downsample)]
218
+ self.inplanes = planes * block.expansion
219
+ for _ in range(1, blocks):
220
+ layers.append(block(self.inplanes, planes))
221
+ return nn.Sequential(*layers)
222
+
223
+ def _make_stage(self, layer_config, num_inchannels, multi_scale_output=True):
224
+ num_modules = layer_config['NUM_MODULES']
225
+ num_branches = layer_config['NUM_BRANCHES']
226
+ num_blocks = layer_config['NUM_BLOCKS']
227
+ num_channels = layer_config['NUM_CHANNELS']
228
+ block = blocks_dict[layer_config['BLOCK']]
229
+ modules = []
230
+ for i in range(num_modules):
231
+ reset_multi_scale = multi_scale_output or i < num_modules - 1
232
+ modules.append(HighResolutionModule(num_branches, block, num_blocks, num_inchannels, num_channels, 'SUM', reset_multi_scale))
233
+ num_inchannels = modules[-1].get_num_inchannels()
234
+ return nn.Sequential(*modules), num_inchannels
235
+
236
+ def forward(self, x):
237
+ x = self.relu(self.bn1(self.conv1(x)))
238
+ x = self.relu(self.bn2(self.conv2(x)))
239
+ x = self.layer1(x)
240
+ x_list = [self.transition1[i](x) if self.transition1[i] else x for i in range(self.stage2_cfg['NUM_BRANCHES'])]
241
+ y_list = self.stage2(x_list)
242
+ x_list = []
243
+ for i in range(self.stage3_cfg['NUM_BRANCHES']):
244
+ idx = min(i, len(y_list)-1)
245
+ x_list.append(self.transition2[i](y_list[idx]) if self.transition2[i] else y_list[i])
246
+ y_list = self.stage3(x_list)
247
+ x_list = []
248
+ for i in range(self.stage4_cfg['NUM_BRANCHES']):
249
+ idx = min(i, len(y_list)-1)
250
+ x_list.append(self.transition3[i](y_list[idx]) if self.transition3[i] else y_list[i])
251
+ y_list = self.stage4(x_list)
252
+ return self.final_layer(y_list[0])
253
 
254
  # ============================================================================
255
+ # MODELO GLOBAL
256
  # ============================================================================
257
 
 
258
  model = None
259
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
260
 
 
261
  def load_model():
 
262
  global model
 
263
  if model is not None:
264
  return model
265
 
266
  print(f"Cargando modelo en {device}...")
267
+ model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)
 
 
 
 
 
 
 
 
 
 
268
  checkpoint = torch.load(model_path, map_location=device, weights_only=False)
269
 
 
270
  if 'model_state_dict' in checkpoint:
271
  state_dict = checkpoint['model_state_dict']
272
  elif 'state_dict' in checkpoint:
 
274
  else:
275
  state_dict = checkpoint
276
 
277
+ # Analizar estructura del checkpoint
278
+ print(f"Keys en checkpoint: {len(state_dict)}")
279
+ sample_keys = list(state_dict.keys())[:5]
280
+ print(f"Ejemplo de keys: {sample_keys}")
281
+
282
+ # Limpiar prefijos comunes
283
  new_state_dict = {}
284
  for k, v in state_dict.items():
285
+ name = k
286
+ for prefix in ['module.', 'backbone.', 'model.']:
287
+ if name.startswith(prefix):
288
+ name = name[len(prefix):]
289
  new_state_dict[name] = v
290
 
291
+ model = HRNet(num_joints=NUM_LANDMARKS)
292
+
293
+ try:
294
+ model.load_state_dict(new_state_dict, strict=True)
295
+ print("✓ Pesos cargados correctamente (strict=True)")
296
+ except Exception as e:
297
+ print(f"⚠ Carga estricta falló: {e}")
298
+ missing, unexpected = model.load_state_dict(new_state_dict, strict=False)
299
+ print(f" - Keys faltantes: {len(missing)}")
300
+ print(f" - Keys inesperadas: {len(unexpected)}")
301
+
302
  model.to(device)
303
  model.eval()
304
+ print("✓ Modelo listo!")
 
305
  return model
306
 
307
+ def get_max_preds(batch_heatmaps):
308
+ batch_size, num_joints, h, w = batch_heatmaps.shape
309
+ heatmaps_reshaped = batch_heatmaps.reshape((batch_size, num_joints, -1))
310
+ idx = np.argmax(heatmaps_reshaped, 2)
311
+ maxvals = np.amax(heatmaps_reshaped, 2)
312
+ preds = np.zeros((batch_size, num_joints, 2), dtype=np.float32)
313
+ preds[:, :, 0] = idx % w
314
+ preds[:, :, 1] = idx // w
315
+ return preds, maxvals.reshape((batch_size, num_joints, 1))
316
 
317
+ def detect_landmarks(image):
318
+ model = load_model()
 
 
 
 
 
 
 
 
319
 
 
 
 
 
 
320
  if isinstance(image, np.ndarray):
321
  image = Image.fromarray(image)
322
 
323
+ original_size = image.size
 
 
 
324
  if image.mode != 'RGB':
325
  image = image.convert('RGB')
326
 
327
+ image_resized = image.resize((INPUT_SIZE, INPUT_SIZE), Image.Resampling.BILINEAR)
 
 
 
 
328
 
329
+ img_array = np.array(image_resized).astype(np.float32) / 255.0
330
  mean = np.array([0.485, 0.456, 0.406])
331
  std = np.array([0.229, 0.224, 0.225])
332
  img_array = (img_array - mean) / std
333
 
334
+ img_tensor = torch.from_numpy(img_array.transpose(2, 0, 1)).float().unsqueeze(0).to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
335
 
336
+ with torch.no_grad():
337
+ output = model(img_tensor)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
338
 
339
+ heatmaps = output.cpu().numpy()
340
+ preds, maxvals = get_max_preds(heatmaps)
341
 
342
+ heatmap_h, heatmap_w = heatmaps.shape[2], heatmaps.shape[3]
343
  orig_w, orig_h = original_size
 
344
 
345
  scale_x = orig_w / heatmap_w
346
  scale_y = orig_h / heatmap_h
347
 
348
  landmarks = []
349
  for i in range(NUM_LANDMARKS):
 
 
 
 
350
  landmarks.append({
351
  "name": LANDMARK_NAMES[i],
352
+ "x": round(float(preds[0, i, 0] * scale_x), 2),
353
+ "y": round(float(preds[0, i, 1] * scale_y), 2),
354
+ "confidence": round(float(maxvals[0, i, 0]), 4)
 
355
  })
356
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
357
  annotated = draw_landmarks(image, landmarks)
 
358
  return landmarks, annotated
359
 
360
+ def draw_landmarks(image, landmarks, radius=6):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
361
  if isinstance(image, np.ndarray):
362
  image = Image.fromarray(image)
363
+ img_draw = image.copy().convert('RGB')
 
 
 
 
 
364
  draw = ImageDraw.Draw(img_draw)
365
 
 
366
  try:
367
+ font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 14)
368
  except:
369
  font = ImageFont.load_default()
370
 
371
  for lm in landmarks:
372
  x, y = lm['x'], lm['y']
373
+ color = LANDMARK_COLORS.get(lm['name'], (255, 255, 255))
374
+ draw.ellipse([x-radius, y-radius, x+radius, y+radius], fill=color, outline=(255, 255, 255), width=2)
375
+ draw.text((x+radius+3, y-7), lm['name'], fill=(255, 255, 255), font=font, stroke_width=2, stroke_fill=(0, 0, 0))
 
 
 
 
 
 
 
376
 
377
  return img_draw
378
 
 
379
  # ============================================================================
380
  # INTERFAZ GRADIO
381
  # ============================================================================
382
 
383
  def process_image(image):
 
384
  if image is None:
385
+ return None, json.dumps({"error": "Por favor sube una imagen"}, indent=2)
386
 
387
  try:
388
  landmarks, annotated = detect_landmarks(image)
389
+ result = {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
390
  "success": True,
 
391
  "num_landmarks": len(landmarks),
392
  "landmarks": landmarks
393
  }
394
+ return annotated, json.dumps(result, indent=2)
395
  except Exception as e:
396
+ import traceback
397
+ return None, json.dumps({"success": False, "error": str(e), "traceback": traceback.format_exc()}, indent=2)
 
 
 
398
 
399
+ print("=" * 50)
400
+ print("Inicializando Cephalometric Landmark Detection...")
401
+ print("=" * 50)
 
 
 
402
  load_model()
403
 
404
+ with gr.Blocks(title="Cephalometric Landmark Detection", theme=gr.themes.Soft()) as demo:
405
+ gr.Markdown("# 🦷 Detección de Landmarks Cefalométricos\n\nDetección automática de **19 puntos anatómicos** usando HRNet-W32.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
406
 
407
  with gr.Row():
408
  with gr.Column():
409
+ input_image = gr.Image(label="📤 Radiografía", type="pil", height=400)
410
+ detect_btn = gr.Button("🔍 Detectar", variant="primary", size="lg")
 
 
 
 
411
  with gr.Column():
412
+ output_image = gr.Image(label="📍 Resultado", height=400)
413
+ output_json = gr.Code(label="📋 JSON", language="json", lines=12)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
414
 
415
+ detect_btn.click(fn=process_image, inputs=[input_image], outputs=[output_image, output_json])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
416
 
 
417
  demo.queue()
418
 
419
+ if __name__ == "__main__":
420
  demo.launch(ssr_mode=False)
requirements.txt CHANGED
@@ -1,14 +1,6 @@
1
- # Core ML
2
  torch>=2.0.0
3
  torchvision>=0.15.0
4
-
5
- # Hugging Face
6
  huggingface_hub>=0.19.0
7
  gradio>=4.0.0
8
-
9
- # Image processing
10
  Pillow>=10.0.0
11
  numpy>=1.24.0
12
-
13
- # Optional but useful
14
- scipy>=1.10.0
 
 
1
  torch>=2.0.0
2
  torchvision>=0.15.0
 
 
3
  huggingface_hub>=0.19.0
4
  gradio>=4.0.0
 
 
5
  Pillow>=10.0.0
6
  numpy>=1.24.0