Drazcat-AI committed on
Commit
37cbec3
·
verified ·
1 Parent(s): b9498f6

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +244 -534
handler.py CHANGED
@@ -5,563 +5,273 @@ from PIL import Image
5
  import numpy as np
6
  import pandas as pd
7
  from pathlib import Path
8
- from collections import defaultdict
9
- import requests
10
  import json
11
  from io import BytesIO
12
- import os
13
- from redes_prototipicas_tvt5 import ConvEncoder, PrototypicalNetwork
14
  import boto3
15
  from botocore.exceptions import ClientError
16
  from huggingface_hub import hf_hub_download
17
 
18
- def load_image_from_s3_direct(bucket_name, s3_key, s3_client):
19
- """Cargar imagen directamente desde S3 usando boto3 (RECOMENDADO)"""
20
- try:
21
- print(f"🔄 Cargando imagen desde S3...")
22
- print(f"📦 Bucket: {bucket_name}")
23
- print(f"🗝️ Key: {s3_key}")
24
-
25
- # Descargar objeto desde S3
26
- response = s3_client.get_object(Bucket=bucket_name, Key=s3_key)
27
-
28
- # Leer contenido y convertir a imagen
29
- image_data = response['Body'].read()
30
- bbox_image = Image.open(BytesIO(image_data)).convert('RGB')
31
-
32
- print("✅ Imagen cargada exitosamente")
33
- return bbox_image
34
-
35
- except ClientError as e:
36
- error_code = e.response['Error']['Code']
37
- if error_code == 'NoSuchKey':
38
- print(f"❌ La imagen no existe en S3: {s3_key}")
39
- elif error_code == 'NoSuchBucket':
40
- print(f"❌ El bucket no existe: {bucket_name}")
41
- elif error_code == 'AccessDenied':
42
- print(f"❌ Sin permisos para acceder a: {s3_key}")
43
- else:
44
- print(f"❌ Error de S3: {e}")
45
- return None
46
-
47
- except Exception as e:
48
- print(f"❌ Error cargando imagen: {e}")
49
- return None
50
 
51
  def load_image_from_s3_url(s3_url, s3_client):
52
- """Cargar imagen desde S3 extrayendo bucket y key de la URL"""
53
- try:
54
- # Extraer bucket y key de la URL
55
- # URL formato: https://bucket-name.s3.amazonaws.com/path/to/file.jpg
56
- url_parts = s3_url.replace('https://', '').split('/')
57
- bucket = url_parts[0].split('.s3.amazonaws.com')[0]
58
- key = '/'.join(url_parts[1:])
59
-
60
- return load_image_from_s3_direct(bucket, key, s3_client)
61
-
62
- except Exception as e:
63
- print(f"❌ Error procesando URL: {e}")
64
- return None
65
 
66
  def model_selector(self, model_category):
67
- if model_category == 182:
68
- encoder, class_names, prototypes, eval_transform = self.encoder_detergentes, self.class_names_detergentes, self.prototypes_detergentes, self.eval_transform_detergentes
69
- elif model_category == 175:
70
- encoder, class_names, prototypes, eval_transform = self.encoder_mascotas, self.class_names_mascotas, self.prototypes_mascotas, self.eval_transform_mascotas
71
- elif model_category == 202:
72
- encoder, class_names, prototypes, eval_transform = self.encoder_vinos, self.class_names_vinos, self.prototypes_vinos, self.eval_transform_vinos
73
- elif model_category == 161:
74
- encoder, class_names, prototypes, eval_transform = self.encoder_cecinas, self.class_names_cecinas, self.prototypes_cecinas, self.eval_transform_cecinas
75
- elif model_category == 198:
76
- encoder, class_names, prototypes, eval_transform = self.encoder_licores, self.class_names_licores, self.prototypes_licores, self.eval_transform_licores
77
 
78
- return encoder, class_names, prototypes, eval_transform
79
-
80
- # 🆕 NUEVA FUNCIÓN: Configuración de umbrales OOD por categoría
81
  def get_ood_thresholds(model_category):
82
- """
83
- Configuración de umbrales OOD específicos por categoría de modelo
84
- Estos valores pueden ajustarse según la performance de cada modelo
85
- """
86
- ood_config = {
87
- 182: { # detergentes
88
- 'similarity_threshold': 0.65,
89
- 'distance_threshold': 0.85,
90
- 'confidence_penalty': 0.1
91
- },
92
- 175: { # mascotas
93
- 'similarity_threshold': 0.62,
94
- 'distance_threshold': 0.90,
95
- 'confidence_penalty': 0.1
96
- },
97
- 202: { # vinos
98
- 'similarity_threshold': 0.68,
99
- 'distance_threshold': 0.80,
100
- 'confidence_penalty': 0.1
101
- },
102
- 161: { # cecinas
103
- 'similarity_threshold': 0.64,
104
- 'distance_threshold': 0.88,
105
- 'confidence_penalty': 0.1
106
- },
107
- 198: { # licores
108
- 'similarity_threshold': 0.66,
109
- 'distance_threshold': 0.85,
110
- 'confidence_penalty': 0.1
111
- }
112
- }
113
-
114
- # Configuración por defecto si no se encuentra la categoría
115
- default_config = {
116
- 'similarity_threshold': 0.65,
117
- 'distance_threshold': 0.85,
118
- 'confidence_penalty': 0.1
119
- }
120
-
121
- return ood_config.get(model_category, default_config)
122
 
123
- # 🆕 NUEVA FUNCIÓN: Detección OOD
124
- def detect_out_of_distribution(query_features, prototypes, ood_config, class_names):
125
- """
126
- Detecta si una muestra está fuera de distribución usando múltiples métricas
127
-
128
- Args:
129
- query_features: Features de la imagen query (tensor)
130
- prototypes: Prototipos del modelo (tensor)
131
- ood_config: Configuración de umbrales
132
- class_names: Nombres de las clases
133
-
134
- Returns:
135
- is_ood: bool - True si es OOD
136
- ood_score: float - Puntuación de confianza (0=muy OOD, 1=muy in-distribution)
137
- ood_reason: str - Razón de la decisión
138
- """
139
-
140
- # 1. Calcular similitud coseno con todos los prototipos
141
- similarities = torch.mm(query_features, prototypes.t()).squeeze(0)
142
- max_similarity = similarities.max().item()
143
-
144
- # 2. Calcular distancia euclidiana al prototipo más cercano
145
- distances = torch.cdist(query_features, prototypes).squeeze(0)
146
- min_distance = distances.min().item()
147
-
148
- # 3. Aplicar umbrales
149
- similarity_threshold = ood_config['similarity_threshold']
150
- distance_threshold = ood_config['distance_threshold']
151
- confidence_penalty = ood_config['confidence_penalty']
152
-
153
- # 4. Decisión OOD basada en múltiples criterios
154
- is_ood = False
155
- ood_reasons = []
156
-
157
- # Criterio 1: Similitud muy baja
158
- if max_similarity < similarity_threshold:
159
- is_ood = True
160
- ood_reasons.append(f"similitud_baja({max_similarity:.3f}<{similarity_threshold})")
161
-
162
- # Criterio 2: Distancia muy alta
163
- if min_distance > distance_threshold:
164
- is_ood = True
165
- ood_reasons.append(f"distancia_alta({min_distance:.3f}>{distance_threshold})")
166
-
167
- # 5. Calcular puntuación de confianza OOD
168
- # Combinamos similitud y distancia en una métrica unificada
169
- similarity_score = max_similarity # 0-1, más alto = mejor
170
- distance_score = max(0, (distance_threshold - min_distance) / distance_threshold) # 0-1, más alto = mejor
171
-
172
- # Promedio ponderado (puedes ajustar los pesos)
173
- ood_score = (0.7 * similarity_score + 0.3 * distance_score)
174
-
175
- # Aplicar penalización si es OOD
176
- if is_ood:
177
- ood_score = max(0, ood_score - confidence_penalty)
178
-
179
- # 6. Crear razón legible
180
- if is_ood:
181
- ood_reason = f"OOD_DETECTED: {', '.join(ood_reasons)}"
182
- else:
183
- ood_reason = f"IN_DISTRIBUTION: sim={max_similarity:.3f}, dist={min_distance:.3f}"
184
-
185
- return is_ood, ood_score, ood_reason
186
 
187
- # ✅ NUEVA FUNCIÓN OPTIMIZADA: Cargar modelo sin necesidad de dataset
188
  def load_classification_model_optimized(model_path, device):
189
- """Versión optimizada que carga prototipos directamente del modelo guardado"""
190
-
191
- if not Path(model_path).exists():
192
- raise FileNotFoundError(f"❌ No se encontró el modelo: {model_path}")
193
-
194
- print(f"✅ Cargando modelo optimizado: {model_path}")
195
-
196
- # Cargar checkpoint
197
- checkpoint = torch.load(model_path, map_location=device, weights_only=False)
198
-
199
- # Verificar que el modelo tiene prototipos guardados
200
- if 'prototypes' not in checkpoint or 'class_names' not in checkpoint:
201
- raise ValueError(" El modelo no contiene prototipos guardados. Necesitas re-entrenar con la versión actualizada del código.")
202
-
203
- # Cargar configuración del modelo
204
- model_config = checkpoint.get('model_config', {})
205
- hidden_dim = model_config.get('hidden_dim', 64)
206
- output_dim = model_config.get('output_dim', 256)
207
-
208
- # Cargar arquitectura del modelo
209
- encoder = ConvEncoder(hidden_dim=hidden_dim, output_dim=output_dim).to(device)
210
- model = PrototypicalNetwork(encoder).to(device)
211
-
212
- # Cargar pesos
213
- encoder.load_state_dict(checkpoint['encoder_state_dict'])
214
- model.load_state_dict(checkpoint['model_state_dict'])
215
- encoder.eval()
216
- model.eval()
217
-
218
- # ✅ Cargar prototipos y clases guardados
219
- prototypes = checkpoint['prototypes'].to(device)
220
- class_names = checkpoint['class_names']
221
-
222
- print(f"✅ Modelo cargado correctamente:")
223
- print(f" - Prototipos: {len(class_names)} clases")
224
- print(f" - Dimensión: {prototypes.shape}")
225
- print(f" - Clases: {', '.join(class_names[:5])}{'...' if len(class_names) > 5 else ''}")
226
-
227
- # Transformaciones para evaluación (mismas que en entrenamiento)
228
- eval_transform = transforms.Compose([
229
- transforms.Resize((224, 224)),
230
- transforms.ToTensor(),
231
- transforms.Normalize(mean=[0.485, 0.456, 0.406],
232
- std=[0.229, 0.224, 0.225])
233
- ])
234
-
235
- return encoder, class_names, prototypes, eval_transform
236
 
237
  def load_json_from_s3(json_s3_url):
238
- # Configuración S3
239
- aws_access_key = 'AKIA6BH4GPXQCUZ3PAX5' # Cambiar por tu access key
240
- aws_secret_key = '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' # Cambiar por tu secret key
241
- region_name = 'us-east-1' # Cambiar por tu región
242
- S3_BUCKET_NAME = 'rocketpin-ml-data' # Cambiar por tu bucket
243
-
244
- # Crear sesión y cliente S3
245
- session = boto3.Session(
246
- aws_access_key_id=aws_access_key,
247
- aws_secret_access_key=aws_secret_key,
248
- region_name=region_name
249
- )
250
- s3_client = session.client('s3')
251
 
252
- """Cargar JSON desde S3 usando la URL completa"""
253
- try:
254
- # Extraer bucket y key de la URL
255
- # URL formato: https://bucket-name.s3.amazonaws.com/path/to/file.json
256
- url_parts = json_s3_url.replace('https://', '').split('/')
257
- bucket = url_parts[0].split('.s3.amazonaws.com')[0]
258
- key = '/'.join(url_parts[1:])
259
-
260
- # Descargar objeto desde S3
261
- response = s3_client.get_object(Bucket=bucket, Key=key)
262
-
263
- # Leer contenido y convertir a JSON
264
- json_content = response['Body'].read().decode('utf-8')
265
- json_data = json.loads(json_content)
266
-
267
- print("✅ JSON cargado exitosamente")
268
- return json_data, s3_client
269
-
270
- except ClientError as e:
271
- error_code = e.response['Error']['Code']
272
- if error_code == 'NoSuchKey':
273
- print(f"❌ El archivo no existe en S3: {key}")
274
- elif error_code == 'NoSuchBucket':
275
- print(f"❌ El bucket no existe: {bucket}")
276
- else:
277
- print(f"❌ Error de S3: {e}")
278
- return None
279
-
280
- except Exception as e:
281
- print(f"❌ Error cargando JSON: {e}")
282
- return None
283
 
284
  def classify_saved_bboxes(saved_images, encoder, class_names, prototypes, eval_transform, device, minimal_accuracy, s3_client, model_category):
285
- """
286
- 🆕 MEJORADO: Clasificar las imágenes de bounding boxes guardadas CON DETECCIÓN OOD
287
- """
288
-
289
- if not saved_images:
290
- print("❌ No hay imágenes guardadas para clasificar")
291
- return pd.DataFrame()
292
-
293
- print(f"🔄 Clasificando {len(saved_images)} imágenes guardadas...")
294
- print(f"🎯 Filtro minimal_accuracy: {minimal_accuracy}")
295
-
296
- # 🆕 Obtener configuración OOD para esta categoría
297
- ood_config = get_ood_thresholds(model_category)
298
- print(f"🛡️ Detección OOD activada:")
299
- print(f" - Umbral similitud: {ood_config['similarity_threshold']}")
300
- print(f" - Umbral distancia: {ood_config['distance_threshold']}")
301
-
302
- results = []
303
- filtered_count = 0
304
- ood_detected_count = 0
305
-
306
- with torch.no_grad():
307
- for img_info in saved_images:
308
- try:
309
- # Cargar imagen guardada
310
- bbox_image = load_image_from_s3_url(img_info['bbox_path'], s3_client)
311
-
312
- if bbox_image is None:
313
- print(f"❌ No se pudo cargar imagen: {img_info['bbox_path']}")
314
- continue
315
-
316
- # Transformar para el modelo
317
- query_tensor = eval_transform(bbox_image).unsqueeze(0).to(device)
318
-
319
- # Extraer características
320
- query_features = encoder(query_tensor)
321
- # Normalizar
322
- query_features = F.normalize(query_features, p=2, dim=1)
323
-
324
- # 🆕 DETECCIÓN OOD
325
- is_ood, ood_score, ood_reason = detect_out_of_distribution(
326
- query_features, prototypes, ood_config, class_names
327
- )
328
-
329
- # 🆕 Si es OOD, manejar de forma especial
330
- if is_ood:
331
- ood_detected_count += 1
332
- print(f"🚨 OOD detectado en bbox {img_info['bbox_id']}: {ood_reason}")
333
-
334
- # Opción 1: Filtrar completamente (recomendado)
335
- filtered_count += 1
336
- continue
337
-
338
- # Opción 2: Marcar como "PRODUCTO_DESCONOCIDO" (opcional - descomenta si prefieres esto)
339
- # result = {
340
- # 'sku_bb_id': str(img_info['bbox_id']),
341
- # 'predictions': ['PRODUCTO_DESCONOCIDO'],
342
- # 'accuracy': [round(ood_score, 4)],
343
- # 'prediccion_principal': 'PRODUCTO_DESCONOCIDO',
344
- # 'similarity_principal': f"{ood_score*100:.2f}%",
345
- # 'bbox_confidence': round(float(img_info['confidence']), 4),
346
- # 'ood_detected': True,
347
- # 'ood_reason': ood_reason,
348
- # 'xmin': img_info['x_min'],
349
- # 'ymin': img_info['y_min'],
350
- # 'xmax': img_info['x_max'],
351
- # 'ymax': img_info['y_max']
352
- # }
353
- # results.append(result)
354
- # continue
355
-
356
- # Calcular similitud coseno con prototipos guardados (solo si no es OOD)
357
- similarities = torch.mm(query_features, prototypes.t())
358
- similarities_numpy = similarities.cpu().numpy()[0]
359
-
360
- # Obtener top 3 predicciones
361
- top3_indices = np.argsort(similarities_numpy)[::-1]
362
- top3_predictions = []
363
- top3_similarities = []
364
-
365
- for idx_pred in top3_indices:
366
- prediction = class_names[idx_pred]
367
- similarity = similarities_numpy[idx_pred]
368
- # Solo agregar si cumple con minimal_accuracy
369
- if similarity >= minimal_accuracy:
370
- top3_predictions.append(prediction)
371
- top3_similarities.append(round(similarity, 4))
372
-
373
- # Si no hay predicciones que cumplan con minimal_accuracy, saltar
374
- if len(top3_predictions) == 0:
375
- filtered_count += 1
376
- print(f"🔽 Bbox {img_info['bbox_id']} filtrado: ninguna predicción cumple minimal_accuracy {minimal_accuracy}")
377
- continue
378
-
379
- # 🆕 Aplicar ajuste de confianza basado en OOD score
380
- adjusted_similarities = []
381
- for sim in top3_similarities:
382
- # Combinar similarity original con OOD confidence
383
- adjusted_sim = (sim * 0.8) + (ood_score * 0.2) # Peso 80-20
384
- adjusted_similarities.append(round(adjusted_sim, 4))
385
-
386
- # Guardar predictions y accuracy como listas (solo las que cumplen el filtro)
387
- predictions_list = top3_predictions
388
- similarities_list = adjusted_similarities # 🆕 Usar similarities ajustadas
389
-
390
- # La predicción principal es la primera de la lista filtrada
391
- predicted_class = predictions_list[0]
392
-
393
- # Formatear similarity_principal como porcentaje
394
- similarity_principal_formatted = f"{similarities_list[0]*100:.2f}%"
395
-
396
- # Formatear bbox_confidence con 4 decimales
397
- bbox_confidence_formatted = round(float(img_info['confidence']), 4)
398
-
399
- # 🆕 Agregar resultado con información OOD
400
- result = {
401
- 'sku_bb_id': str(img_info['bbox_id']),
402
- 'predictions': predictions_list,
403
- 'accuracy': similarities_list,
404
- 'prediccion_principal': predicted_class,
405
- 'similarity_principal': similarity_principal_formatted,
406
- 'bbox_confidence': bbox_confidence_formatted,
407
- 'ood_detected': False, # 🆕 No es OOD
408
- 'ood_score': round(ood_score, 4), # 🆕 Puntuación OOD
409
- 'xmin': img_info['x_min'],
410
- 'ymin': img_info['y_min'],
411
- 'xmax': img_info['x_max'],
412
- 'ymax': img_info['y_max']
413
- }
414
-
415
- results.append(result)
416
-
417
- except Exception as e:
418
- print(f"❌ Error clasificando bbox {str(img_info['bbox_id'])}: {e}")
419
- # Agregar entrada de error
420
- results.append({
421
- 'sku_bb_id': str(img_info['bbox_id']),
422
- 'predictions': ['ERROR'],
423
- 'accuracy': [0.0000],
424
- 'prediccion_principal': 'ERROR',
425
- 'similarity_principal': 'ERROR',
426
- 'bbox_confidence': round(float(img_info['confidence']), 4),
427
- 'ood_detected': False,
428
- 'ood_score': 0.0000,
429
- 'xmin': img_info['x_min'],
430
- 'ymin': img_info['y_min'],
431
- 'xmax': img_info['x_max'],
432
- 'ymax': img_info['y_max']
433
- })
434
-
435
- # 🆕 Resumen mejorado con estadísticas OOD
436
- if filtered_count > 0 or ood_detected_count > 0:
437
- print(f"📊 Resumen de filtrado:")
438
- print(f" - Detecciones procesadas: {len(results)}")
439
- print(f" - Detecciones filtradas por accuracy: {filtered_count - ood_detected_count}")
440
- print(f" - 🆕 Detecciones OOD filtradas: {ood_detected_count}")
441
- print(f" - Total filtrado: {filtered_count}")
442
- print(f" - Total original: {len(saved_images)}")
443
- if ood_detected_count > 0:
444
- ood_percentage = (ood_detected_count / len(saved_images)) * 100
445
- print(f" - 🆕 Porcentaje OOD: {ood_percentage:.1f}%")
446
-
447
- return pd.DataFrame(results)
448
 
449
  def process_image_with_bboxes(self, image_url, picture_id, visit_id, minimal_accuracy, model_path, train_path, model_category, json_s3_url):
450
- """Función principal para procesar una imagen: detectar BB, guardar recortes y clasificar"""
451
-
452
- print("="*80)
453
- print("PROCESAMIENTO DE IMAGEN CON BOUNDING BOXES - MODELO OPTIMIZADO V5 + OOD")
454
- print("="*80)
455
- print(f"📸 Imagen: {image_url}")
456
- print(f"🆔 Picture ID: {picture_id}")
457
-
458
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
459
- print(f"💻 Dispositivo: {device}")
460
- print(f"🎯 Filtro minimal_accuracy: {minimal_accuracy}")
461
- print(f"🛡️ Detección OOD activada para categoría: {model_category}") # 🆕
462
-
463
- # Cargar bounding boxes desde S3
464
- saved_images, s3_client = load_json_from_s3(json_s3_url)
465
- saved_images = saved_images['bounding_boxes']
466
-
467
- if not saved_images:
468
- print("❌ No se pudieron cargar las imágenes desde S3")
469
- return pd.DataFrame()
470
-
471
- # 3. Cargar modelo de clasificación (OPTIMIZADO)
472
- print("\n🤖 PASO 3: Cargando modelo de clasificación optimizado...")
473
- try:
474
- encoder, class_names, prototypes, eval_transform = model_selector(self, model_category)
475
- except Exception as e:
476
- print(f" Error cargando modelo: {e}")
477
- return pd.DataFrame()
478
-
479
- # 4. Clasificar imágenes guardadas CON DETECCIÓN OOD
480
- print("\n🔬 PASO 4: Clasificando imágenes guardadas con detección OOD...")
481
- results_df = classify_saved_bboxes(
482
- saved_images, encoder, class_names, prototypes, eval_transform, device,
483
- minimal_accuracy, s3_client, model_category # 🆕 Pasar model_category
484
- )
485
-
486
- # 5. Mostrar resumen
487
- if not results_df.empty:
488
- print(f"\n✅ Procesamiento completado:")
489
- print(f" - Total de detecciones procesadas: {len(results_df)}")
490
- print(f" - Clases detectadas: {results_df['prediccion_principal'].nunique()}")
491
- print(f" - Clases únicas encontradas: {', '.join(results_df['prediccion_principal'].unique())}")
492
-
493
- # 🆕 Estadísticas OOD
494
- if 'ood_score' in results_df.columns:
495
- avg_ood_score = results_df['ood_score'].mean()
496
- min_ood_score = results_df['ood_score'].min()
497
- max_ood_score = results_df['ood_score'].max()
498
- print(f"\n🛡️ Estadísticas OOD:")
499
- print(f" - OOD Score promedio: {avg_ood_score:.4f}")
500
- print(f" - OOD Score mínimo: {min_ood_score:.4f}")
501
- print(f" - OOD Score máximo: {max_ood_score:.4f}")
502
-
503
- # Top predicciones
504
- print(f"\n📊 Top predicciones:")
505
- top_predictions = results_df['prediccion_principal'].value_counts().head(5)
506
- for clase, count in top_predictions.items():
507
- print(f" - {clase}: {count} detecciones")
508
-
509
- # Estadísticas de accuracy
510
- if len(results_df) > 0:
511
- avg_accuracy = results_df['accuracy'].apply(lambda x: x[0] if isinstance(x, list) and len(x) > 0 else 0).mean()
512
- min_accuracy = results_df['accuracy'].apply(lambda x: x[0] if isinstance(x, list) and len(x) > 0 else 0).min()
513
- max_accuracy = results_df['accuracy'].apply(lambda x: x[0] if isinstance(x, list) and len(x) > 0 else 0).max()
514
- print(f"\n📈 Estadísticas de accuracy:")
515
- print(f" - Promedio: {avg_accuracy:.4f}")
516
- print(f" - Mínimo: {min_accuracy:.4f}")
517
- print(f" - Máximo: {max_accuracy:.4f}")
518
- else:
519
- print("❌ No hay detecciones que cumplan con el filtro de accuracy o todas fueron detectadas como OOD")
520
-
521
- return results_df
522
 
523
  class EndpointHandler():
524
- def __init__(self, path=""):
525
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
526
- model_filename_licores = "model_curriculum4/prototypical_model_best_licores.pth"
527
- local_model_path_licores = hf_hub_download(repo_id="Drazcat-AI/redes_prototipicas", filename=model_filename_licores)
528
- self.encoder_licores, self.class_names_licores, self.prototypes_licores, self.eval_transform_licores = load_classification_model_optimized(
529
- local_model_path_licores, device)
 
 
 
 
 
 
530
 
531
- def predict_objects(self, image_url, picture_id, visit_id, minimal_accuracy, model_category, json_s3_url):
532
-
533
- print("Ejecutando clasificación optimizada con prototipos pre-cargados y detección OOD...")
534
- result_df = process_image_with_bboxes(
535
- self, image_url, picture_id, visit_id, minimal_accuracy,
536
- None, None, # model_path y train_path ya no son necesarios
537
- model_category, json_s3_url
538
- )
539
- return result_df
540
 
541
- def __call__(self, event):
542
- if "inputs" not in event:
543
- return {
544
- "statusCode": 400,
545
- "body": json.dumps("Error: Please provide an 'inputs' parameter."),
546
- }
547
- event = event["inputs"]
548
- image_url = event["image_url"]
549
- picture_id = event["picture_id"]
550
- visit_id = event["visit_id"]
551
- minimal_accuracy = event["minimal_accuracy"]
552
- model_category = event["model_category"]
553
- json_s3_url = event["json_s3_url"]
554
-
555
- try:
556
- predictions = self.predict_objects(image_url, picture_id, visit_id, minimal_accuracy, model_category, json_s3_url)
557
- predictions_json = predictions.to_json(orient='records')
558
-
559
- return {
560
- "statusCode": 200,
561
- "body": json.dumps(predictions_json),
562
- }
563
- except Exception as e:
564
- return {
565
- "statusCode": 500,
566
- "body": json.dumps(f"Error: {str(e)}"),
567
- }
 
5
  import numpy as np
6
  import pandas as pd
7
  from pathlib import Path
 
 
8
  import json
9
  from io import BytesIO
 
 
10
  import boto3
11
  from botocore.exceptions import ClientError
12
  from huggingface_hub import hf_hub_download
13
 
14
+ # Imports desde el código de entrenamiento actualizado
15
+ from redes_prototipicas_tvt5 import ConvEncoder, PrototypicalNetwork, SmartPadResize
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  def load_image_from_s3_url(s3_url, s3_client):
18
+ """Cargar imagen desde S3 extrayendo bucket y key de la URL"""
19
+ try:
20
+ url_parts = s3_url.replace('https://', '').split('/')
21
+ bucket = url_parts[0].split('.s3.amazonaws.com')[0]
22
+ key = '/'.join(url_parts[1:])
23
+
24
+ response = s3_client.get_object(Bucket=bucket, Key=key)
25
+ image_data = response['Body'].read()
26
+ return Image.open(BytesIO(image_data)).convert('RGB')
27
+ except Exception as e:
28
+ print(f"❌ Error cargando imagen: {e}")
29
+ return None
 
30
 
31
  def model_selector(self, model_category):
32
+ """Seleccionar modelo según categoría"""
33
+ models = {
34
+ 182: (self.encoder_detergentes, self.class_names_detergentes, self.prototypes_detergentes, self.eval_transform_detergentes),
35
+ 175: (self.encoder_mascotas, self.class_names_mascotas, self.prototypes_mascotas, self.eval_transform_mascotas),
36
+ 202: (self.encoder_vinos, self.class_names_vinos, self.prototypes_vinos, self.eval_transform_vinos),
37
+ 161: (self.encoder_cecinas, self.class_names_cecinas, self.prototypes_cecinas, self.eval_transform_cecinas),
38
+ 198: (self.encoder_licores, self.class_names_licores, self.prototypes_licores, self.eval_transform_licores)
39
+ }
40
+ return models.get(model_category)
 
41
 
 
 
 
42
  def get_ood_thresholds(model_category):
43
+ """Umbrales OOD para modelos 512px"""
44
+ config = {
45
+ 182: {'similarity_threshold': 0.70, 'distance_threshold': 0.80}, # detergentes
46
+ 175: {'similarity_threshold': 0.68, 'distance_threshold': 0.85}, # mascotas
47
+ 202: {'similarity_threshold': 0.72, 'distance_threshold': 0.75}, # vinos
48
+ 161: {'similarity_threshold': 0.69, 'distance_threshold': 0.82}, # cecinas
49
+ 198: {'similarity_threshold': 0.71, 'distance_threshold': 0.78} # licores
50
+ }
51
+ return config.get(model_category, {'similarity_threshold': 0.70, 'distance_threshold': 0.80})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
def detect_out_of_distribution(query_features, prototypes, ood_config):
    """Decide whether a query embedding is out of distribution (OOD).

    Two measurements are taken against all prototypes: the largest dot
    product (``torch.mm``) and the smallest Euclidean distance
    (``torch.cdist``). The sample is OOD when the best dot product is below
    ``similarity_threshold`` or the nearest distance is above
    ``distance_threshold``.

    The dot product equals cosine similarity only when both inputs are
    L2-normalized; the classifier normalizes the query before calling this,
    but the prototypes are used as stored — TODO confirm they are normalized.

    Args:
        query_features: 2-D tensor holding the query embedding as its single
            row (``torch.mm`` requires matrices).
        prototypes: 2-D tensor with one prototype per row.
        ood_config: Dict with ``similarity_threshold`` and
            ``distance_threshold``. An optional ``confidence_penalty`` key
            overrides the amount subtracted from the score of OOD samples
            (default 0.05).

    Returns:
        ``(is_ood, ood_score)`` where ``ood_score`` is
        ``0.7 * max_similarity + 0.3 * distance_score``, reduced by the
        penalty (never below 0) when the sample is OOD.
    """
    similarity_threshold = ood_config['similarity_threshold']
    distance_threshold = ood_config['distance_threshold']
    confidence_penalty = ood_config.get('confidence_penalty', 0.05)

    max_similarity = torch.mm(query_features, prototypes.t()).squeeze(0).max().item()
    min_distance = torch.cdist(query_features, prototypes).squeeze(0).min().item()

    is_ood = (max_similarity < similarity_threshold or
              min_distance > distance_threshold)

    # Fraction of the allowed distance budget still unused, clamped at 0.
    # A non-positive threshold leaves no budget, so the score is 0 instead of
    # dividing by zero.
    if distance_threshold > 0:
        distance_score = max(0.0, (distance_threshold - min_distance) / distance_threshold)
    else:
        distance_score = 0.0

    ood_score = 0.7 * max_similarity + 0.3 * distance_score
    if is_ood:
        ood_score = max(0.0, ood_score - confidence_penalty)

    return is_ood, ood_score
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
 
75
  def load_classification_model_optimized(model_path, device):
76
+ """Cargar modelo 512px únicamente"""
77
+ checkpoint = torch.load(model_path, map_location=device, weights_only=False)
78
+
79
+ if 'prototypes' not in checkpoint or 'class_names' not in checkpoint:
80
+ raise ValueError("❌ Modelo sin prototipos. Re-entrena con código actualizado.")
81
+
82
+ # Configuración del modelo
83
+ model_config = checkpoint.get('model_config', {})
84
+ hidden_dim = model_config.get('hidden_dim', 64)
85
+ output_dim = model_config.get('output_dim', 256)
86
+ image_size = model_config.get('image_size', 512)
87
+
88
+ print(f"📊 Cargando modelo {image_size}px: {len(checkpoint['class_names'])} clases")
89
+
90
+ # Cargar arquitectura y pesos
91
+ encoder = ConvEncoder(hidden_dim=hidden_dim, output_dim=output_dim).to(device)
92
+ model = PrototypicalNetwork(encoder).to(device)
93
+
94
+ encoder.load_state_dict(checkpoint['encoder_state_dict'])
95
+ model.load_state_dict(checkpoint['model_state_dict'])
96
+ encoder.eval()
97
+ model.eval()
98
+
99
+ # Prototipos y clases
100
+ prototypes = checkpoint['prototypes'].to(device)
101
+ class_names = checkpoint['class_names']
102
+
103
+ # Transformaciones 512px con SmartPadResize
104
+ eval_transform = transforms.Compose([
105
+ SmartPadResize(target_size=image_size, fill_value=128),
106
+ transforms.ToTensor(),
107
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
108
+ ])
109
+
110
+ return encoder, class_names, prototypes, eval_transform
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
  def load_json_from_s3(json_s3_url):
113
+ """Cargar JSON desde S3"""
114
+ session = boto3.Session(
115
+ aws_access_key_id='AKIA6BH4GPXQCUZ3PAX5',
116
+ aws_secret_access_key='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',
117
+ region_name='us-east-1'
118
+ )
119
+ s3_client = session.client('s3')
 
 
 
 
 
 
120
 
121
+ try:
122
+ url_parts = json_s3_url.replace('https://', '').split('/')
123
+ bucket = url_parts[0].split('.s3.amazonaws.com')[0]
124
+ key = '/'.join(url_parts[1:])
125
+
126
+ response = s3_client.get_object(Bucket=bucket, Key=key)
127
+ json_content = response['Body'].read().decode('utf-8')
128
+ return json.loads(json_content), s3_client
129
+ except Exception as e:
130
+ print(f"❌ Error cargando JSON: {e}")
131
+ return None, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
  def classify_saved_bboxes(saved_images, encoder, class_names, prototypes, eval_transform, device, minimal_accuracy, s3_client, model_category):
134
+ """Clasificación con detección OOD"""
135
+ if not saved_images:
136
+ return pd.DataFrame()
137
+
138
+ print(f"🔄 Clasificando {len(saved_images)} imágenes...")
139
+
140
+ ood_config = get_ood_thresholds(model_category)
141
+ results = []
142
+ filtered_count = 0
143
+ ood_count = 0
144
+
145
+ with torch.no_grad():
146
+ for img_info in saved_images:
147
+ try:
148
+ # Cargar y transformar imagen
149
+ bbox_image = load_image_from_s3_url(img_info['bbox_path'], s3_client)
150
+ if bbox_image is None:
151
+ continue
152
+
153
+ query_tensor = eval_transform(bbox_image).unsqueeze(0).to(device)
154
+ query_features = F.normalize(encoder(query_tensor), p=2, dim=1)
155
+
156
+ # Detección OOD
157
+ is_ood, ood_score = detect_out_of_distribution(query_features, prototypes, ood_config)
158
+
159
+ if is_ood:
160
+ ood_count += 1
161
+ filtered_count += 1
162
+ continue
163
+
164
+ # Calcular similitudes
165
+ similarities = torch.mm(query_features, prototypes.t()).cpu().numpy()[0]
166
+ top3_indices = np.argsort(similarities)[::-1]
167
+
168
+ # Filtrar por minimal_accuracy
169
+ predictions = []
170
+ accuracies = []
171
+
172
+ for idx in top3_indices:
173
+ if similarities[idx] >= minimal_accuracy:
174
+ predictions.append(class_names[idx])
175
+ accuracies.append(round(similarities[idx], 4))
176
+
177
+ if not predictions:
178
+ filtered_count += 1
179
+ continue
180
+
181
+ # Ajustar con OOD score
182
+ adjusted_accuracies = [round((acc * 0.9) + (ood_score * 0.1), 4) for acc in accuracies]
183
+
184
+ result = {
185
+ 'sku_bb_id': str(img_info['bbox_id']),
186
+ 'predictions': predictions,
187
+ 'accuracy': adjusted_accuracies,
188
+ 'prediccion_principal': predictions[0],
189
+ 'similarity_principal': f"{adjusted_accuracies[0]*100:.2f}%",
190
+ 'bbox_confidence': round(float(img_info['confidence']), 4),
191
+ 'ood_score': round(ood_score, 4),
192
+ 'xmin': img_info['x_min'],
193
+ 'ymin': img_info['y_min'],
194
+ 'xmax': img_info['x_max'],
195
+ 'ymax': img_info['y_max']
196
+ }
197
+ results.append(result)
198
+
199
+ except Exception as e:
200
+ print(f"❌ Error en bbox {img_info['bbox_id']}: {e}")
201
+ continue
202
+
203
+ print(f"📊 Procesadas: {len(results)}, Filtradas: {filtered_count}, OOD: {ood_count}")
204
+ return pd.DataFrame(results)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
 
206
  def process_image_with_bboxes(self, image_url, picture_id, visit_id, minimal_accuracy, model_path, train_path, model_category, json_s3_url):
207
+ """Función principal de procesamiento"""
208
+ print(f"🚀 Procesando imagen con modelo 512px - Categoría: {model_category}")
209
+
210
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
211
+
212
+ # Cargar bounding boxes
213
+ saved_images, s3_client = load_json_from_s3(json_s3_url)
214
+ if not saved_images:
215
+ return pd.DataFrame()
216
+
217
+ saved_images = saved_images['bounding_boxes']
218
+
219
+ # Seleccionar modelo
220
+ try:
221
+ encoder, class_names, prototypes, eval_transform = model_selector(self, model_category)
222
+ except Exception as e:
223
+ print(f"❌ Error cargando modelo: {e}")
224
+ return pd.DataFrame()
225
+
226
+ # Clasificar
227
+ results_df = classify_saved_bboxes(
228
+ saved_images, encoder, class_names, prototypes, eval_transform,
229
+ device, minimal_accuracy, s3_client, model_category
230
+ )
231
+
232
+ if not results_df.empty:
233
+ print(f" {len(results_df)} detecciones procesadas")
234
+ print(f"📊 Clases detectadas: {', '.join(results_df['prediccion_principal'].unique())}")
235
+
236
+ return results_df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
 
238
  class EndpointHandler():
239
+ def __init__(self, path=""):
240
+ """Inicialización con modelos 512px únicamente"""
241
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
242
+ print(f"🚀 Inicializando handler con device: {device}")
243
+
244
+ # Cargar modelo de licores
245
+ model_filename = "model_curriculum4/prototypical_model_best_licores.pth"
246
+ local_model_path = hf_hub_download(repo_id="Drazcat-AI/redes_prototipicas", filename=model_filename)
247
+
248
+ self.encoder_licores, self.class_names_licores, self.prototypes_licores, self.eval_transform_licores = load_classification_model_optimized(local_model_path, device)
249
+
250
+ print("✅ Handler inicializado")
251
 
252
+ def predict_objects(self, image_url, picture_id, visit_id, minimal_accuracy, model_category, json_s3_url):
253
+ """Predicción con modelos 512px"""
254
+ return process_image_with_bboxes(
255
+ self, image_url, picture_id, visit_id, minimal_accuracy,
256
+ None, None, model_category, json_s3_url
257
+ )
 
 
 
258
 
259
+ def __call__(self, event):
260
+ """Método de llamada principal"""
261
+ if "inputs" not in event:
262
+ return {"statusCode": 400, "body": json.dumps("Error: No 'inputs' parameter.")}
263
+
264
+ event = event["inputs"]
265
+
266
+ try:
267
+ predictions = self.predict_objects(
268
+ event["image_url"], event["picture_id"], event["visit_id"],
269
+ event["minimal_accuracy"], event["model_category"], event["json_s3_url"]
270
+ )
271
+
272
+ return {
273
+ "statusCode": 200,
274
+ "body": json.dumps(predictions.to_json(orient='records'))
275
+ }
276
+ except Exception as e:
277
+ return {"statusCode": 500, "body": json.dumps(f"Error: {str(e)}")}