Klinapps committed on
Commit
03d222a
·
verified ·
1 Parent(s): 809fb82

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +449 -0
  2. hrnet.py +395 -0
  3. requirements.txt +14 -0
app.py ADDED
@@ -0,0 +1,449 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Cephalometric Landmark Detection API
3
+ HRNet-W32 based automatic landmark detection for lateral cephalometric radiographs
4
+
5
+ Space para integración con Klinafy
6
+ """
7
+
8
+ import os
9
+ import json
10
+ import numpy as np
11
+ from PIL import Image
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from huggingface_hub import hf_hub_download
15
+ import gradio as gr
16
+
17
+ from hrnet import get_hrnet_w32
18
+
19
+
20
+ # ============================================================================
21
+ # CONFIGURACIÓN
22
+ # ============================================================================
23
+
24
# Hugging Face Hub repository and file name of the trained checkpoint.
MODEL_REPO = "cwlachap/hrnet-cephalometric-landmark-detection"
MODEL_FILE = "best_model.pth"
# Side length (pixels) of the square image fed to the network.
INPUT_SIZE = 768
# NOTE(review): not referenced in the visible part of this file; the actual
# heatmap size is read from the model output at inference time.
HEATMAP_SIZE = 192
# Number of heatmap channels / landmarks produced by the model.
NUM_LANDMARKS = 19

# Names of the 19 landmarks, indexed by model output channel.
# NOTE(review): this order must match the label order the checkpoint was
# trained with - confirm against the model card of MODEL_REPO.
LANDMARK_NAMES = [
    "S",    # 0 - Sella turcica
    "N",    # 1 - Nasion
    "Or",   # 2 - Orbitale
    "Po",   # 3 - Porion
    "Ba",   # 4 - Basion
    "Pt",   # 5 - Pterygoid point
    "ANS",  # 6 - Anterior Nasal Spine
    "PNS",  # 7 - Posterior Nasal Spine
    "A",    # 8 - Point A (Subspinale)
    "U1T",  # 9 - Upper Incisor Tip
    "U1R",  # 10 - Upper Incisor Root
    "L1T",  # 11 - Lower Incisor Tip
    "L1R",  # 12 - Lower Incisor Root
    "B",    # 13 - Point B (Supramentale)
    "Pog",  # 14 - Pogonion
    "Gn",   # 15 - Gnathion
    "Me",   # 16 - Menton
    "Go",   # 17 - Gonion
    "Ar"    # 18 - Articulare
]

# Drawing colours (RGB), one per anatomical group
LANDMARK_COLORS = {
    "cranial": (255, 0, 0),      # Red - S, N, Ba, Ar
    "orbital": (0, 255, 0),      # Green - Or, Po
    "maxilar": (0, 0, 255),      # Blue - ANS, PNS, A, Pt
    "dental": (255, 255, 0),     # Yellow - U1T, U1R, L1T, L1R
    "mandibular": (255, 0, 255)  # Magenta - B, Pog, Gn, Me, Go
}

# Maps every landmark name to its group, i.e. to a key of LANDMARK_COLORS.
LANDMARK_GROUPS = {
    "S": "cranial", "N": "cranial", "Ba": "cranial", "Ar": "cranial",
    "Or": "orbital", "Po": "orbital",
    "ANS": "maxilar", "PNS": "maxilar", "A": "maxilar", "Pt": "maxilar",
    "U1T": "dental", "U1R": "dental", "L1T": "dental", "L1R": "dental",
    "B": "mandibular", "Pog": "mandibular", "Gn": "mandibular",
    "Me": "mandibular", "Go": "mandibular"
}


# ============================================================================
# MODEL
# ============================================================================

# Process-wide cache for the loaded network; filled in by load_model().
model = None
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
79
+
80
+
81
def load_model():
    """Load the HRNet landmark model from the Hugging Face Hub and cache it.

    The first successful call downloads MODEL_FILE from MODEL_REPO, builds the
    network, loads the weights, moves it to ``device`` and switches it to eval
    mode. Later calls return the cached instance.

    The module-level ``model`` is only assigned once loading has fully
    succeeded, so a failed attempt can never leave an untrained network in
    the cache.

    Returns:
        The ready-to-use network (also stored in the module-level ``model``).
    """
    global model

    if model is not None:
        return model

    print(f"Cargando modelo en {device}...")

    # Download (or reuse the locally cached copy of) the weights
    model_path = hf_hub_download(
        repo_id=MODEL_REPO,
        filename=MODEL_FILE
    )

    # Build into a local variable; the global is published only at the end.
    net = get_hrnet_w32(num_landmarks=NUM_LANDMARKS)

    # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects, so
    # this trusts MODEL_REPO completely. Switch to weights_only=True if the
    # checkpoint turns out to contain only tensors and plain containers.
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)

    # Accept the common checkpoint layouts
    if 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    elif 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint

    # Strip the DataParallel prefix, but only when it really is a prefix
    # (a blind str.replace would also mangle keys that merely contain it).
    prefix = 'module.'
    cleaned_state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

    # strict=False tolerates key mismatches, so report them instead of
    # silently running with randomly initialised layers.
    result = net.load_state_dict(cleaned_state_dict, strict=False)
    if result.missing_keys:
        print(f"ADVERTENCIA: {len(result.missing_keys)} parámetros del modelo "
              f"no están en el checkpoint: {result.missing_keys[:5]}")
    if result.unexpected_keys:
        print(f"ADVERTENCIA: {len(result.unexpected_keys)} claves del checkpoint "
              f"no se usaron: {result.unexpected_keys[:5]}")

    net.to(device)
    net.eval()

    model = net
    print("Modelo cargado exitosamente!")
    return model
122
+
123
+
124
+ # ============================================================================
125
+ # PREPROCESAMIENTO
126
+ # ============================================================================
127
+
128
def preprocess_image(image):
    """
    Turn an input picture into the normalized batch tensor the model expects.

    The picture is converted to RGB, stretched to INPUT_SIZE x INPUT_SIZE
    (aspect ratio is not preserved), scaled to [0, 1] and normalized with the
    ImageNet channel statistics.

    Args:
        image: PIL Image or numpy array

    Returns:
        tensor: float tensor of shape [1, 3, INPUT_SIZE, INPUT_SIZE]
        original_size: (width, height) of the picture before resizing
    """
    pil_img = Image.fromarray(image) if isinstance(image, np.ndarray) else image

    # Remember the source dimensions so predictions can be mapped back later
    original_size = pil_img.size  # (width, height)

    if pil_img.mode != 'RGB':
        pil_img = pil_img.convert('RGB')

    resized = pil_img.resize((INPUT_SIZE, INPUT_SIZE), Image.Resampling.BILINEAR)

    # HWC uint8 -> HWC float in [0, 1]
    pixels = np.array(resized).astype(np.float32) / 255.0

    # ImageNet per-channel statistics
    imagenet_mean = np.array([0.485, 0.456, 0.406])
    imagenet_std = np.array([0.229, 0.224, 0.225])
    normalized = (pixels - imagenet_mean) / imagenet_std

    # HWC -> CHW, then prepend the batch axis
    chw = normalized.transpose(2, 0, 1)
    img_tensor = torch.from_numpy(chw).float().unsqueeze(0)

    return img_tensor, original_size
166
+
167
+
168
+ # ============================================================================
169
+ # POSTPROCESAMIENTO
170
+ # ============================================================================
171
+
172
def get_max_preds(heatmaps):
    """
    Locate the peak of every heatmap.

    Args:
        heatmaps: numpy array [batch, num_landmarks, H, W]

    Returns:
        preds: float32 peak coordinates [batch, num_landmarks, 2] as (x, y)
               in heatmap pixels
        maxvals: peak values [batch, num_landmarks, 1]
    """
    n_batch, n_points = heatmaps.shape[0], heatmaps.shape[1]
    row_len = heatmaps.shape[3]

    # Flatten each map so a single argmax finds its peak
    flat = heatmaps.reshape((n_batch, n_points, -1))
    peak_index = np.argmax(flat, axis=2).astype(np.float32)
    maxvals = np.amax(flat, axis=2).reshape((n_batch, n_points, 1))

    # Flat index -> (column, row)
    preds = np.stack(
        (peak_index % row_len, np.floor(peak_index / row_len)),
        axis=2,
    )

    return preds, maxvals
199
+
200
+
201
def heatmaps_to_landmarks(heatmaps, original_size):
    """
    Convert model heatmaps into named landmark coordinates.

    Each heatmap peak is scaled from heatmap pixels to the pixel grid of the
    original image. Only the first batch element is read.

    Args:
        heatmaps: tensor [1, NUM_LANDMARKS, H, W]
        original_size: (width, height) of the original image

    Returns:
        landmarks: list of dicts with name, x, y, confidence and group;
                   confidence is the raw peak value of the heatmap
    """
    heatmaps_np = heatmaps.cpu().numpy()
    preds, maxvals = get_max_preds(heatmaps_np)

    # Per-axis factor from heatmap pixels to original-image pixels
    orig_w, orig_h = original_size
    heatmap_h = heatmaps_np.shape[2]
    heatmap_w = heatmaps_np.shape[3]
    scale_x = orig_w / heatmap_w
    scale_y = orig_h / heatmap_h

    def _entry(index):
        name = LANDMARK_NAMES[index]
        return {
            "name": name,
            "x": round(float(preds[0, index, 0] * scale_x), 2),
            "y": round(float(preds[0, index, 1] * scale_y), 2),
            "confidence": round(float(maxvals[0, index, 0]), 4),
            "group": LANDMARK_GROUPS[name]
        }

    return [_entry(index) for index in range(NUM_LANDMARKS)]
239
+
240
+
241
+ # ============================================================================
242
+ # INFERENCIA
243
+ # ============================================================================
244
+
245
def detect_landmarks(image):
    """
    Run the full pipeline on one image: preprocess, infer, decode, annotate.

    Args:
        image: PIL Image or numpy array

    Returns:
        landmarks: list of dicts with name, x, y, confidence
        annotated_image: copy of the image with the landmarks drawn on it
    """
    # Returns the cached network, loading it on first use
    net = load_model()

    batch, source_size = preprocess_image(image)
    batch = batch.to(device)

    # Inference only, so no gradients are needed
    with torch.no_grad():
        heatmaps = net(batch)

    points = heatmaps_to_landmarks(heatmaps, source_size)
    return points, draw_landmarks(image, points)
274
+
275
+
276
def draw_landmarks(image, landmarks, radius=5):
    """
    Draw every landmark as a coloured dot with its name next to it.

    The input image is not modified; drawing happens on an RGB copy.

    Args:
        image: PIL Image or numpy array
        landmarks: list of dicts with at least 'x', 'y', 'name' and 'group'
        radius: radius of the dot, in pixels

    Returns:
        PIL Image (RGB) with the landmarks drawn
    """
    from PIL import ImageDraw, ImageFont

    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    # Work on a copy so the caller's image stays untouched
    img_draw = image.copy()
    if img_draw.mode != 'RGB':
        img_draw = img_draw.convert('RGB')

    draw = ImageDraw.Draw(img_draw)

    # Prefer DejaVu Sans; fall back to Pillow's built-in font when the file is
    # missing/unreadable (OSError) or FreeType support is absent (ImportError).
    # Anything else is a real bug and is allowed to propagate.
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 12)
    except (OSError, ImportError):
        font = ImageFont.load_default()

    for lm in landmarks:
        x, y = lm['x'], lm['y']
        color = LANDMARK_COLORS[lm['group']]

        # Filled dot with a white outline
        draw.ellipse([x-radius, y-radius, x+radius, y+radius],
                     fill=color, outline=(255, 255, 255))

        # White label with a black stroke so it stays readable on any background
        draw.text((x+radius+2, y-radius), lm['name'],
                  fill=(255, 255, 255), font=font,
                  stroke_width=1, stroke_fill=(0, 0, 0))

    return img_draw
320
+
321
+
322
+ # ============================================================================
323
+ # INTERFAZ GRADIO
324
+ # ============================================================================
325
+
326
def process_image(image):
    """Gradio UI callback: return (annotated image, JSON text) for an upload.

    Without an image a prompt message is returned in place of the JSON. Any
    error during detection is reported as a JSON document with
    ``success: false`` and no image.
    """
    if image is None:
        return None, "Por favor sube una imagen cefalométrica"

    try:
        landmarks, annotated = detect_landmarks(image)

        # Pretty-printed JSON for the code viewer
        payload = {
            "success": True,
            "num_landmarks": len(landmarks),
            "landmarks": landmarks
        }
        return annotated, json.dumps(payload, indent=2)

    except Exception as e:
        failure = {
            "success": False,
            "error": str(e)
        }
        return None, json.dumps(failure, indent=2)
348
+
349
+
350
def api_predict(image):
    """API callback for the Klinafy integration.

    Returns a dict with ``success`` plus either the detected landmarks or an
    ``error`` message; it never raises for a failed detection.
    """
    if image is None:
        return {"success": False, "error": "No image provided"}

    try:
        # The annotated picture is not needed by API clients
        landmarks, _ = detect_landmarks(image)
        response = {
            "success": True,
            "model": "HRNet-W32",
            "num_landmarks": len(landmarks),
            "landmarks": landmarks
        }
    except Exception as e:
        response = {
            "success": False,
            "error": str(e)
        }

    return response
370
+
371
+
372
+ # ============================================================================
373
+ # CREAR APP
374
+ # ============================================================================
375
+
376
# Load the model at import time so the first request does not pay for the
# download and initialisation.
print("Inicializando modelo...")
load_model()

# Build the Gradio interface
with gr.Blocks(title="Cephalometric Landmark Detection") as demo:
    gr.Markdown("""
    # 🦷 Detección de Landmarks Cefalométricos

    Detección automática de **19 puntos cefalométricos** usando HRNet-W32.

    ### Landmarks detectados:
    - **Craneales** (rojo): S, N, Ba, Ar
    - **Orbitales** (verde): Or, Po
    - **Maxilares** (azul): ANS, PNS, A, Pt
    - **Dentales** (amarillo): U1T, U1R, L1T, L1R
    - **Mandibulares** (magenta): B, Pog, Gn, Me, Go

    ---
    """)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                label="Radiografía Cefalométrica Lateral",
                type="pil"
            )
            detect_btn = gr.Button("🔍 Detectar Landmarks", variant="primary")

        with gr.Column():
            output_image = gr.Image(
                label="Imagen con Landmarks"
            )
            output_json = gr.Code(
                label="Coordenadas (JSON)",
                language="json"
            )

    # Interactive path: annotated image plus pretty-printed JSON
    detect_btn.click(
        fn=process_image,
        inputs=[input_image],
        outputs=[output_image, output_json]
    )

    # Programmatic path: api_predict was previously defined but never
    # registered, so the "predict" endpoint described below did not exist.
    # The trigger and output are hidden; they exist only to expose the
    # function under a stable API name.
    api_output = gr.JSON(visible=False)
    api_trigger = gr.Button(visible=False)
    api_trigger.click(
        fn=api_predict,
        inputs=[input_image],
        outputs=[api_output],
        api_name="predict"
    )

    gr.Markdown("""
    ---
    ### 📡 API Endpoint

    Para integración programática (ej. Klinafy):

    ```javascript
    const response = await fetch('https://YOUR-SPACE.hf.space/api/predict', {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({ data: [base64Image] })
    });
    const result = await response.json();
    // result.data[0] = { success: true, landmarks: [...] }
    ```

    ---
    **Modelo**: HRNet-W32 | **Precisión**: MRE ~1.5mm | **Licencia**: MIT
    """)


# Enable request queuing
demo.queue()

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )
hrnet.py ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HRNet-W32 Architecture for Cephalometric Landmark Detection
3
+ Based on: Deep High-Resolution Representation Learning for Visual Recognition
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+
11
def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 convolution with padding 1."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
15
+
16
+
17
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions.

    The shortcut is the input itself, or ``downsample(input)`` when a
    projection module is supplied.
    """

    # Output channels = planes * expansion
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Main path: conv -> bn -> relu -> conv -> bn
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Shortcut path, projected only when a downsample module exists
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
42
+
43
+
44
class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 reduce, 3x3, 1x1 expand.

    The block outputs ``planes * expansion`` channels. The shortcut is the
    input itself, or ``downsample(input)`` when a projection is supplied.
    """

    # Output channels = planes * expansion
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        expanded = planes * self.expansion
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, expanded, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(expanded)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Main path: three conv/bn stages, ReLU after the first two
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # Shortcut path, projected only when a downsample module exists
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
76
+
77
+
78
class HighResolutionModule(nn.Module):
    """One multi-branch HRNet module: parallel branches followed by fusion.

    Every branch is a stack of residual blocks applied to its own input
    tensor. The branch outputs are then summed across branches after being
    brought to a common shape by the fuse layers.

    NOTE: the order and nesting in which sub-modules are created define the
    parameter names of the state dict, so they must not be rearranged if
    existing checkpoints are to keep loading.
    """

    def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
                 num_channels, fuse_method, multi_scale_output=True):
        """
        Args:
            num_branches: number of parallel branches.
            blocks: residual block class used in every branch.
            num_blocks: per-branch number of blocks.
            num_inchannels: per-branch input channel counts; this list is
                stored by reference and updated in place while the branches
                are built.
            num_channels: per-branch base channel counts of the blocks.
            fuse_method: stored on the instance; not otherwise read here.
            multi_scale_output: when False only the fused output for the
                first branch is produced.

        Raises:
            ValueError: if a per-branch list does not have num_branches items.
        """
        super(HighResolutionModule, self).__init__()
        self._check_branches(num_branches, num_blocks, num_inchannels, num_channels)

        self.num_inchannels = num_inchannels
        self.fuse_method = fuse_method
        self.num_branches = num_branches
        self.multi_scale_output = multi_scale_output

        # Branches must be built first: building them updates
        # self.num_inchannels, which the fuse layers then read.
        self.branches = self._make_branches(num_branches, blocks, num_blocks, num_channels)
        self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU(inplace=True)

    def _check_branches(self, num_branches, num_blocks, num_inchannels, num_channels):
        """Validate that every per-branch list has exactly num_branches entries."""
        if num_branches != len(num_blocks):
            raise ValueError('NUM_BRANCHES != len(NUM_BLOCKS)')
        if num_branches != len(num_channels):
            raise ValueError('NUM_BRANCHES != len(NUM_CHANNELS)')
        if num_branches != len(num_inchannels):
            raise ValueError('NUM_BRANCHES != len(NUM_INCHANNELS)')

    def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride=1):
        """Build the block stack of one branch as an nn.Sequential.

        Side effect: sets self.num_inchannels[branch_index] to the branch's
        output channel count.
        """
        # A projection shortcut is needed when the first block changes the
        # resolution or the channel count.
        downsample = None
        if stride != 1 or self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.num_inchannels[branch_index],
                          num_channels[branch_index] * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(num_channels[branch_index] * block.expansion),
            )

        layers = []
        layers.append(block(self.num_inchannels[branch_index],
                            num_channels[branch_index], stride, downsample))
        # The remaining blocks consume the first block's output channels
        self.num_inchannels[branch_index] = num_channels[branch_index] * block.expansion
        for i in range(1, num_blocks[branch_index]):
            layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index]))

        return nn.Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        """Build all branches and return them as an nn.ModuleList."""
        branches = []
        for i in range(num_branches):
            branches.append(self._make_one_branch(i, block, num_blocks, num_channels))
        return nn.ModuleList(branches)

    def _make_fuse_layers(self):
        """Build the layers that adapt branch j's output for fusion into output i.

        Returns None for a single-branch module. Otherwise entry [i][j] is:
          * j > i: 1x1 conv + BN changing channels j -> i (the spatial resize
            is done in forward());
          * j == i: None (the tensor is used as is);
          * j < i: a chain of (i - j) stride-2 3x3 convs; only the last one
            changes the channel count and it has no ReLU.
        """
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        num_inchannels = self.num_inchannels
        fuse_layers = []
        # Without multi-scale output only the row for branch 0 is needed
        for i in range(num_branches if self.multi_scale_output else 1):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    fuse_layer.append(nn.Sequential(
                        nn.Conv2d(num_inchannels[j], num_inchannels[i], 1, 1, 0, bias=False),
                        nn.BatchNorm2d(num_inchannels[i])))
                elif j == i:
                    fuse_layer.append(None)
                else:
                    conv3x3s = []
                    for k in range(i - j):
                        if k == i - j - 1:
                            # Last step: switch to the target channel count
                            num_outchannels_conv3x3 = num_inchannels[i]
                            conv3x3s.append(nn.Sequential(
                                nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False),
                                nn.BatchNorm2d(num_outchannels_conv3x3)))
                        else:
                            # Intermediate step: keep the source channel count
                            num_outchannels_conv3x3 = num_inchannels[j]
                            conv3x3s.append(nn.Sequential(
                                nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False),
                                nn.BatchNorm2d(num_outchannels_conv3x3),
                                nn.ReLU(inplace=True)))
                    fuse_layer.append(nn.Sequential(*conv3x3s))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def get_num_inchannels(self):
        """Return the per-branch channel counts after the branches were built."""
        return self.num_inchannels

    def forward(self, x):
        """Run all branches and fuse their outputs.

        Args:
            x: list with one tensor per branch. The list is modified in
               place: each entry is replaced by its branch output.

        Returns:
            List of fused tensors, one per row of fuse layers (a single-item
            list for a single-branch module).
        """
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        for i in range(self.num_branches):
            x[i] = self.branches[i](x[i])

        x_fuse = []
        for i in range(len(self.fuse_layers)):
            # Start the sum with branch 0's contribution
            y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
            for j in range(1, self.num_branches):
                if i == j:
                    y = y + x[j]
                elif j > i:
                    # Adapt channels, then resize to branch i's spatial size
                    width_output = x[i].shape[-1]
                    height_output = x[i].shape[-2]
                    y = y + F.interpolate(
                        self.fuse_layers[i][j](x[j]),
                        size=[height_output, width_output],
                        mode='bilinear', align_corners=True)
                else:
                    y = y + self.fuse_layers[i][j](x[j])
            x_fuse.append(self.relu(y))

        return x_fuse
189
+
190
+
191
class HRNetW32(nn.Module):
    """
    HRNet-W32 for Cephalometric Landmark Detection

    Input: batch of 3-channel images (the stem convolution expects exactly
           3 channels, so grayscale data must be replicated to 3 channels);
           intended for 768x768 inputs.
    Output: one heatmap per landmark. The stem applies two stride-2
            convolutions, so a 768x768 input is reduced to 192x192 before
            the stages.

    NOTE: attribute names and module creation order define the state-dict
    keys and must stay stable for existing checkpoints to load.
    """

    def __init__(self, num_landmarks=19):
        """
        Args:
            num_landmarks: number of output heatmap channels.
        """
        super(HRNetW32, self).__init__()

        self.num_landmarks = num_landmarks

        # Stem
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)

        # Stage 1: four bottleneck blocks, 64 -> 256 channels
        self.layer1 = self._make_layer(Bottleneck, 64, 64, 4)

        # Stage 2
        self.stage2_cfg = {
            'NUM_MODULES': 1,
            'NUM_BRANCHES': 2,
            'NUM_BLOCKS': [4, 4],
            'NUM_CHANNELS': [32, 64],
            'BLOCK': BasicBlock,
            'FUSE_METHOD': 'SUM'
        }
        num_channels = self.stage2_cfg['NUM_CHANNELS']
        block = self.stage2_cfg['BLOCK']
        num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))]
        # 256 = output channels of layer1 (64 planes * Bottleneck.expansion)
        self.transition1 = self._make_transition_layer([256], num_channels)
        self.stage2, pre_stage_channels = self._make_stage(self.stage2_cfg, num_channels)

        # Stage 3
        self.stage3_cfg = {
            'NUM_MODULES': 4,
            'NUM_BRANCHES': 3,
            'NUM_BLOCKS': [4, 4, 4],
            'NUM_CHANNELS': [32, 64, 128],
            'BLOCK': BasicBlock,
            'FUSE_METHOD': 'SUM'
        }
        num_channels = self.stage3_cfg['NUM_CHANNELS']
        block = self.stage3_cfg['BLOCK']
        num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))]
        self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels)
        self.stage3, pre_stage_channels = self._make_stage(self.stage3_cfg, num_channels)

        # Stage 4
        self.stage4_cfg = {
            'NUM_MODULES': 3,
            'NUM_BRANCHES': 4,
            'NUM_BLOCKS': [4, 4, 4, 4],
            'NUM_CHANNELS': [32, 64, 128, 256],
            'BLOCK': BasicBlock,
            'FUSE_METHOD': 'SUM'
        }
        num_channels = self.stage4_cfg['NUM_CHANNELS']
        block = self.stage4_cfg['BLOCK']
        num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))]
        self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels)
        # All four resolutions are kept because forward() concatenates them
        self.stage4, pre_stage_channels = self._make_stage(self.stage4_cfg, num_channels, multi_scale_output=True)

        # Head: operates on the channel-wise concatenation of all branches
        last_inp_channels = sum(pre_stage_channels)
        self.head = nn.Sequential(
            nn.Conv2d(last_inp_channels, last_inp_channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(last_inp_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(last_inp_channels, num_landmarks, kernel_size=1, stride=1, padding=0)
        )

    def _make_layer(self, block, inplanes, planes, blocks, stride=1):
        """Stack `blocks` residual blocks; only the first may change stride/channels."""
        # Projection shortcut for the first block when shapes differ
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(inplanes, planes, stride, downsample))
        inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(inplanes, planes))

        return nn.Sequential(*layers)

    def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer):
        """Build the adapters between the branches of two consecutive stages.

        For a branch that already exists, the entry is a 3x3 conv + BN + ReLU
        when the channel count changes, otherwise None. For a new branch the
        entry is a chain of stride-2 3x3 convs fed from the last existing
        branch.
        """
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)

        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(nn.Sequential(
                        nn.Conv2d(num_channels_pre_layer[i], num_channels_cur_layer[i], 3, 1, 1, bias=False),
                        nn.BatchNorm2d(num_channels_cur_layer[i]),
                        nn.ReLU(inplace=True)))
                else:
                    transition_layers.append(None)
            else:
                conv3x3s = []
                for j in range(i + 1 - num_branches_pre):
                    inchannels = num_channels_pre_layer[-1]
                    # Only the final conv of the chain switches to the new width
                    outchannels = num_channels_cur_layer[i] if j == i - num_branches_pre else inchannels
                    conv3x3s.append(nn.Sequential(
                        nn.Conv2d(inchannels, outchannels, 3, 2, 1, bias=False),
                        nn.BatchNorm2d(outchannels),
                        nn.ReLU(inplace=True)))
                transition_layers.append(nn.Sequential(*conv3x3s))

        return nn.ModuleList(transition_layers)

    def _make_stage(self, layer_config, num_inchannels, multi_scale_output=True):
        """Build one stage as a sequence of HighResolutionModule instances.

        Returns:
            (nn.Sequential of the modules, per-branch output channel counts)
        """
        num_modules = layer_config['NUM_MODULES']
        num_branches = layer_config['NUM_BRANCHES']
        num_blocks = layer_config['NUM_BLOCKS']
        num_channels = layer_config['NUM_CHANNELS']
        block = layer_config['BLOCK']
        fuse_method = layer_config['FUSE_METHOD']

        modules = []
        for i in range(num_modules):
            # Only the last module of the stage may drop the lower-resolution outputs
            if not multi_scale_output and i == num_modules - 1:
                reset_multi_scale_output = False
            else:
                reset_multi_scale_output = True
            modules.append(
                HighResolutionModule(num_branches, block, num_blocks, num_inchannels,
                                     num_channels, fuse_method, reset_multi_scale_output)
            )
            num_inchannels = modules[-1].get_num_inchannels()

        return nn.Sequential(*modules), num_inchannels

    def forward(self, x):
        """Compute landmark heatmaps.

        Args:
            x: image batch with 3 channels.

        Returns:
            Tensor with num_landmarks channels at the resolution of the
            highest-resolution branch.
        """
        # Stem
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        # Stage 1
        x = self.layer1(x)

        # Stage 2: every branch is derived from the single stage-1 tensor
        x_list = []
        for i in range(self.stage2_cfg['NUM_BRANCHES']):
            if self.transition1[i] is not None:
                x_list.append(self.transition1[i](x))
            else:
                x_list.append(x)
        y_list = self.stage2(x_list)

        # Stage 3: existing branches continue, the new one starts from the
        # last (lowest-resolution) output of the previous stage
        x_list = []
        for i in range(self.stage3_cfg['NUM_BRANCHES']):
            if self.transition2[i] is not None:
                if i < self.stage2_cfg['NUM_BRANCHES']:
                    x_list.append(self.transition2[i](y_list[i]))
                else:
                    x_list.append(self.transition2[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        y_list = self.stage3(x_list)

        # Stage 4: same scheme as stage 3
        x_list = []
        for i in range(self.stage4_cfg['NUM_BRANCHES']):
            if self.transition3[i] is not None:
                if i < self.stage3_cfg['NUM_BRANCHES']:
                    x_list.append(self.transition3[i](y_list[i]))
                else:
                    x_list.append(self.transition3[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        x = self.stage4(x_list)

        # Upscale to highest resolution
        x0_h, x0_w = x[0].size(2), x[0].size(3)
        x1 = F.interpolate(x[1], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
        x2 = F.interpolate(x[2], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
        x3 = F.interpolate(x[3], size=(x0_h, x0_w), mode='bilinear', align_corners=True)

        # Concatenate all branches along the channel axis
        x = torch.cat([x[0], x1, x2, x3], 1)

        # Head
        x = self.head(x)

        return x
390
+
391
+
392
def get_hrnet_w32(num_landmarks=19):
    """Return a freshly constructed HRNetW32 with ``num_landmarks`` output channels."""
    return HRNetW32(num_landmarks=num_landmarks)
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core ML
2
+ torch>=2.0.0
3
+ torchvision>=0.15.0
4
+
5
+ # Hugging Face
6
+ huggingface_hub>=0.19.0
7
+ gradio>=4.0.0
8
+
9
+ # Image processing
10
+ Pillow>=10.0.0
11
+ numpy>=1.24.0
12
+
13
+ # Optional but useful
14
+ scipy>=1.10.0