Drazcat-AI commited on
Commit
860195d
·
verified ·
1 Parent(s): 0bd7e9b

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +190 -13
handler.py CHANGED
@@ -77,6 +77,113 @@ def model_selector(self, model_category):
77
 
78
  return encoder, class_names, prototypes, eval_transform
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  # ✅ NUEVA FUNCIÓN OPTIMIZADA: Cargar modelo sin necesidad de dataset
81
  def load_classification_model_optimized(model_path, device):
82
  """Versión optimizada que carga prototipos directamente del modelo guardado"""
@@ -174,8 +281,10 @@ def load_json_from_s3(json_s3_url):
174
  print(f"❌ Error cargando JSON: {e}")
175
  return None
176
 
177
- def classify_saved_bboxes(saved_images, encoder, class_names, prototypes, eval_transform, device, minimal_accuracy, s3_client):
178
- """Clasificar las imágenes de bounding boxes guardadas"""
 
 
179
 
180
  if not saved_images:
181
  print("❌ No hay imágenes guardadas para clasificar")
@@ -184,8 +293,16 @@ def classify_saved_bboxes(saved_images, encoder, class_names, prototypes, eval_t
184
  print(f"🔄 Clasificando {len(saved_images)} imágenes guardadas...")
185
  print(f"🎯 Filtro minimal_accuracy: {minimal_accuracy}")
186
 
 
 
 
 
 
 
187
  results = []
188
  filtered_count = 0
 
 
189
  with torch.no_grad():
190
  for img_info in saved_images:
191
  try:
@@ -204,7 +321,39 @@ def classify_saved_bboxes(saved_images, encoder, class_names, prototypes, eval_t
204
  # Normalizar
205
  query_features = F.normalize(query_features, p=2, dim=1)
206
 
207
- # Calcular similitud coseno con prototipos guardados
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  similarities = torch.mm(query_features, prototypes.t())
209
  similarities_numpy = similarities.cpu().numpy()[0]
210
 
@@ -227,9 +376,16 @@ def classify_saved_bboxes(saved_images, encoder, class_names, prototypes, eval_t
227
  print(f"🔽 Bbox {img_info['bbox_id']} filtrado: ninguna predicción cumple minimal_accuracy {minimal_accuracy}")
228
  continue
229
 
 
 
 
 
 
 
 
230
  # Guardar predictions y accuracy como listas (solo las que cumplen el filtro)
231
  predictions_list = top3_predictions
232
- similarities_list = top3_similarities
233
 
234
  # La predicción principal es la primera de la lista filtrada
235
  predicted_class = predictions_list[0]
@@ -240,7 +396,7 @@ def classify_saved_bboxes(saved_images, encoder, class_names, prototypes, eval_t
240
  # Formatear bbox_confidence con 4 decimales
241
  bbox_confidence_formatted = round(float(img_info['confidence']), 4)
242
 
243
- # Agregar resultado
244
  result = {
245
  'sku_bb_id': str(img_info['bbox_id']),
246
  'predictions': predictions_list,
@@ -248,6 +404,8 @@ def classify_saved_bboxes(saved_images, encoder, class_names, prototypes, eval_t
248
  'prediccion_principal': predicted_class,
249
  'similarity_principal': similarity_principal_formatted,
250
  'bbox_confidence': bbox_confidence_formatted,
 
 
251
  'xmin': img_info['x_min'],
252
  'ymin': img_info['y_min'],
253
  'xmax': img_info['x_max'],
@@ -266,17 +424,25 @@ def classify_saved_bboxes(saved_images, encoder, class_names, prototypes, eval_t
266
  'prediccion_principal': 'ERROR',
267
  'similarity_principal': 'ERROR',
268
  'bbox_confidence': round(float(img_info['confidence']), 4),
 
 
269
  'xmin': img_info['x_min'],
270
  'ymin': img_info['y_min'],
271
  'xmax': img_info['x_max'],
272
  'ymax': img_info['y_max']
273
  })
274
 
275
- if filtered_count > 0:
 
276
  print(f"📊 Resumen de filtrado:")
277
  print(f" - Detecciones procesadas: {len(results)}")
278
- print(f" - Detecciones filtradas: {filtered_count}")
 
 
279
  print(f" - Total original: {len(saved_images)}")
 
 
 
280
 
281
  return pd.DataFrame(results)
282
 
@@ -284,7 +450,7 @@ def process_image_with_bboxes(self, image_url, picture_id, visit_id, minimal_acc
284
  """Función principal para procesar una imagen: detectar BB, guardar recortes y clasificar"""
285
 
286
  print("="*80)
287
- print("PROCESAMIENTO DE IMAGEN CON BOUNDING BOXES - MODELO OPTIMIZADO V5")
288
  print("="*80)
289
  print(f"📸 Imagen: {image_url}")
290
  print(f"🆔 Picture ID: {picture_id}")
@@ -292,6 +458,7 @@ def process_image_with_bboxes(self, image_url, picture_id, visit_id, minimal_acc
292
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
293
  print(f"💻 Dispositivo: {device}")
294
  print(f"🎯 Filtro minimal_accuracy: {minimal_accuracy}")
 
295
 
296
  # Cargar bounding boxes desde S3
297
  saved_images, s3_client = load_json_from_s3(json_s3_url)
@@ -309,11 +476,11 @@ def process_image_with_bboxes(self, image_url, picture_id, visit_id, minimal_acc
309
  print(f"❌ Error cargando modelo: {e}")
310
  return pd.DataFrame()
311
 
312
- # 4. Clasificar imágenes guardadas
313
- print("\n🔬 PASO 4: Clasificando imágenes guardadas...")
314
  results_df = classify_saved_bboxes(
315
  saved_images, encoder, class_names, prototypes, eval_transform, device,
316
- minimal_accuracy, s3_client
317
  )
318
 
319
  # 5. Mostrar resumen
@@ -323,6 +490,16 @@ def process_image_with_bboxes(self, image_url, picture_id, visit_id, minimal_acc
323
  print(f" - Clases detectadas: {results_df['prediccion_principal'].nunique()}")
324
  print(f" - Clases únicas encontradas: {', '.join(results_df['prediccion_principal'].unique())}")
325
 
 
 
 
 
 
 
 
 
 
 
326
  # Top predicciones
327
  print(f"\n📊 Top predicciones:")
328
  top_predictions = results_df['prediccion_principal'].value_counts().head(5)
@@ -339,7 +516,7 @@ def process_image_with_bboxes(self, image_url, picture_id, visit_id, minimal_acc
339
  print(f" - Mínimo: {min_accuracy:.4f}")
340
  print(f" - Máximo: {max_accuracy:.4f}")
341
  else:
342
- print("❌ No hay detecciones que cumplan con el filtro de accuracy")
343
 
344
  return results_df
345
 
@@ -353,7 +530,7 @@ class EndpointHandler():
353
 
354
  def predict_objects(self, image_url, picture_id, visit_id, minimal_accuracy, model_category, json_s3_url):
355
 
356
- print("Ejecutando clasificación optimizada con prototipos pre-cargados...")
357
  result_df = process_image_with_bboxes(
358
  self, image_url, picture_id, visit_id, minimal_accuracy,
359
  None, None, # model_path y train_path ya no son necesarios
 
77
 
78
  return encoder, class_names, prototypes, eval_transform
79
 
80
+ # 🆕 NUEVA FUNCIÓN: Configuración de umbrales OOD por categoría
81
+ def get_ood_thresholds(model_category):
82
+ """
83
+ Configuración de umbrales OOD específicos por categoría de modelo
84
+ Estos valores pueden ajustarse según la performance de cada modelo
85
+ """
86
+ ood_config = {
87
+ 182: { # detergentes
88
+ 'similarity_threshold': 0.65,
89
+ 'distance_threshold': 0.85,
90
+ 'confidence_penalty': 0.1
91
+ },
92
+ 175: { # mascotas
93
+ 'similarity_threshold': 0.62,
94
+ 'distance_threshold': 0.90,
95
+ 'confidence_penalty': 0.1
96
+ },
97
+ 202: { # vinos
98
+ 'similarity_threshold': 0.68,
99
+ 'distance_threshold': 0.80,
100
+ 'confidence_penalty': 0.1
101
+ },
102
+ 161: { # cecinas
103
+ 'similarity_threshold': 0.64,
104
+ 'distance_threshold': 0.88,
105
+ 'confidence_penalty': 0.1
106
+ },
107
+ 198: { # licores
108
+ 'similarity_threshold': 0.66,
109
+ 'distance_threshold': 0.85,
110
+ 'confidence_penalty': 0.1
111
+ }
112
+ }
113
+
114
+ # Configuración por defecto si no se encuentra la categoría
115
+ default_config = {
116
+ 'similarity_threshold': 0.65,
117
+ 'distance_threshold': 0.85,
118
+ 'confidence_penalty': 0.1
119
+ }
120
+
121
+ return ood_config.get(model_category, default_config)
122
+
123
+ # 🆕 NUEVA FUNCIÓN: Detección OOD
124
+ def detect_out_of_distribution(query_features, prototypes, ood_config, class_names):
125
+ """
126
+ Detecta si una muestra está fuera de distribución usando múltiples métricas
127
+
128
+ Args:
129
+ query_features: Features de la imagen query (tensor)
130
+ prototypes: Prototipos del modelo (tensor)
131
+ ood_config: Configuración de umbrales
132
+ class_names: Nombres de las clases
133
+
134
+ Returns:
135
+ is_ood: bool - True si es OOD
136
+ ood_score: float - Puntuación de confianza (0=muy OOD, 1=muy in-distribution)
137
+ ood_reason: str - Razón de la decisión
138
+ """
139
+
140
+ # 1. Calcular similitud coseno con todos los prototipos
141
+ similarities = torch.mm(query_features, prototypes.t()).squeeze(0)
142
+ max_similarity = similarities.max().item()
143
+
144
+ # 2. Calcular distancia euclidiana al prototipo más cercano
145
+ distances = torch.cdist(query_features, prototypes).squeeze(0)
146
+ min_distance = distances.min().item()
147
+
148
+ # 3. Aplicar umbrales
149
+ similarity_threshold = ood_config['similarity_threshold']
150
+ distance_threshold = ood_config['distance_threshold']
151
+ confidence_penalty = ood_config['confidence_penalty']
152
+
153
+ # 4. Decisión OOD basada en múltiples criterios
154
+ is_ood = False
155
+ ood_reasons = []
156
+
157
+ # Criterio 1: Similitud muy baja
158
+ if max_similarity < similarity_threshold:
159
+ is_ood = True
160
+ ood_reasons.append(f"similitud_baja({max_similarity:.3f}<{similarity_threshold})")
161
+
162
+ # Criterio 2: Distancia muy alta
163
+ if min_distance > distance_threshold:
164
+ is_ood = True
165
+ ood_reasons.append(f"distancia_alta({min_distance:.3f}>{distance_threshold})")
166
+
167
+ # 5. Calcular puntuación de confianza OOD
168
+ # Combinamos similitud y distancia en una métrica unificada
169
+ similarity_score = max_similarity # 0-1, más alto = mejor
170
+ distance_score = max(0, (distance_threshold - min_distance) / distance_threshold) # 0-1, más alto = mejor
171
+
172
+ # Promedio ponderado (puedes ajustar los pesos)
173
+ ood_score = (0.7 * similarity_score + 0.3 * distance_score)
174
+
175
+ # Aplicar penalización si es OOD
176
+ if is_ood:
177
+ ood_score = max(0, ood_score - confidence_penalty)
178
+
179
+ # 6. Crear razón legible
180
+ if is_ood:
181
+ ood_reason = f"OOD_DETECTED: {', '.join(ood_reasons)}"
182
+ else:
183
+ ood_reason = f"IN_DISTRIBUTION: sim={max_similarity:.3f}, dist={min_distance:.3f}"
184
+
185
+ return is_ood, ood_score, ood_reason
186
+
187
  # ✅ NUEVA FUNCIÓN OPTIMIZADA: Cargar modelo sin necesidad de dataset
188
  def load_classification_model_optimized(model_path, device):
189
  """Versión optimizada que carga prototipos directamente del modelo guardado"""
 
281
  print(f"❌ Error cargando JSON: {e}")
282
  return None
283
 
284
+ def classify_saved_bboxes(saved_images, encoder, class_names, prototypes, eval_transform, device, minimal_accuracy, s3_client, model_category):
285
+ """
286
+ 🆕 MEJORADO: Clasificar las imágenes de bounding boxes guardadas CON DETECCIÓN OOD
287
+ """
288
 
289
  if not saved_images:
290
  print("❌ No hay imágenes guardadas para clasificar")
 
293
  print(f"🔄 Clasificando {len(saved_images)} imágenes guardadas...")
294
  print(f"🎯 Filtro minimal_accuracy: {minimal_accuracy}")
295
 
296
+ # 🆕 Obtener configuración OOD para esta categoría
297
+ ood_config = get_ood_thresholds(model_category)
298
+ print(f"🛡️ Detección OOD activada:")
299
+ print(f" - Umbral similitud: {ood_config['similarity_threshold']}")
300
+ print(f" - Umbral distancia: {ood_config['distance_threshold']}")
301
+
302
  results = []
303
  filtered_count = 0
304
+ ood_detected_count = 0
305
+
306
  with torch.no_grad():
307
  for img_info in saved_images:
308
  try:
 
321
  # Normalizar
322
  query_features = F.normalize(query_features, p=2, dim=1)
323
 
324
+ # 🆕 DETECCIÓN OOD
325
+ is_ood, ood_score, ood_reason = detect_out_of_distribution(
326
+ query_features, prototypes, ood_config, class_names
327
+ )
328
+
329
+ # 🆕 Si es OOD, manejar de forma especial
330
+ if is_ood:
331
+ ood_detected_count += 1
332
+ print(f"🚨 OOD detectado en bbox {img_info['bbox_id']}: {ood_reason}")
333
+
334
+ # Opción 1: Filtrar completamente (recomendado)
335
+ filtered_count += 1
336
+ continue
337
+
338
+ # Opción 2: Marcar como "PRODUCTO_DESCONOCIDO" (opcional - descomenta si prefieres esto)
339
+ # result = {
340
+ # 'sku_bb_id': str(img_info['bbox_id']),
341
+ # 'predictions': ['PRODUCTO_DESCONOCIDO'],
342
+ # 'accuracy': [round(ood_score, 4)],
343
+ # 'prediccion_principal': 'PRODUCTO_DESCONOCIDO',
344
+ # 'similarity_principal': f"{ood_score*100:.2f}%",
345
+ # 'bbox_confidence': round(float(img_info['confidence']), 4),
346
+ # 'ood_detected': True,
347
+ # 'ood_reason': ood_reason,
348
+ # 'xmin': img_info['x_min'],
349
+ # 'ymin': img_info['y_min'],
350
+ # 'xmax': img_info['x_max'],
351
+ # 'ymax': img_info['y_max']
352
+ # }
353
+ # results.append(result)
354
+ # continue
355
+
356
+ # Calcular similitud coseno con prototipos guardados (solo si no es OOD)
357
  similarities = torch.mm(query_features, prototypes.t())
358
  similarities_numpy = similarities.cpu().numpy()[0]
359
 
 
376
  print(f"🔽 Bbox {img_info['bbox_id']} filtrado: ninguna predicción cumple minimal_accuracy {minimal_accuracy}")
377
  continue
378
 
379
+ # 🆕 Aplicar ajuste de confianza basado en OOD score
380
+ adjusted_similarities = []
381
+ for sim in top3_similarities:
382
+ # Combinar similarity original con OOD confidence
383
+ adjusted_sim = (sim * 0.8) + (ood_score * 0.2) # Peso 80-20
384
+ adjusted_similarities.append(round(adjusted_sim, 4))
385
+
386
  # Guardar predictions y accuracy como listas (solo las que cumplen el filtro)
387
  predictions_list = top3_predictions
388
+ similarities_list = adjusted_similarities # 🆕 Usar similarities ajustadas
389
 
390
  # La predicción principal es la primera de la lista filtrada
391
  predicted_class = predictions_list[0]
 
396
  # Formatear bbox_confidence con 4 decimales
397
  bbox_confidence_formatted = round(float(img_info['confidence']), 4)
398
 
399
+ # 🆕 Agregar resultado con información OOD
400
  result = {
401
  'sku_bb_id': str(img_info['bbox_id']),
402
  'predictions': predictions_list,
 
404
  'prediccion_principal': predicted_class,
405
  'similarity_principal': similarity_principal_formatted,
406
  'bbox_confidence': bbox_confidence_formatted,
407
+ 'ood_detected': False, # 🆕 No es OOD
408
+ 'ood_score': round(ood_score, 4), # 🆕 Puntuación OOD
409
  'xmin': img_info['x_min'],
410
  'ymin': img_info['y_min'],
411
  'xmax': img_info['x_max'],
 
424
  'prediccion_principal': 'ERROR',
425
  'similarity_principal': 'ERROR',
426
  'bbox_confidence': round(float(img_info['confidence']), 4),
427
+ 'ood_detected': False,
428
+ 'ood_score': 0.0000,
429
  'xmin': img_info['x_min'],
430
  'ymin': img_info['y_min'],
431
  'xmax': img_info['x_max'],
432
  'ymax': img_info['y_max']
433
  })
434
 
435
+ # 🆕 Resumen mejorado con estadísticas OOD
436
+ if filtered_count > 0 or ood_detected_count > 0:
437
  print(f"📊 Resumen de filtrado:")
438
  print(f" - Detecciones procesadas: {len(results)}")
439
+ print(f" - Detecciones filtradas por accuracy: {filtered_count - ood_detected_count}")
440
+ print(f" - 🆕 Detecciones OOD filtradas: {ood_detected_count}")
441
+ print(f" - Total filtrado: {filtered_count}")
442
  print(f" - Total original: {len(saved_images)}")
443
+ if ood_detected_count > 0:
444
+ ood_percentage = (ood_detected_count / len(saved_images)) * 100
445
+ print(f" - 🆕 Porcentaje OOD: {ood_percentage:.1f}%")
446
 
447
  return pd.DataFrame(results)
448
 
 
450
  """Función principal para procesar una imagen: detectar BB, guardar recortes y clasificar"""
451
 
452
  print("="*80)
453
+ print("PROCESAMIENTO DE IMAGEN CON BOUNDING BOXES - MODELO OPTIMIZADO V5 + OOD")
454
  print("="*80)
455
  print(f"📸 Imagen: {image_url}")
456
  print(f"🆔 Picture ID: {picture_id}")
 
458
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
459
  print(f"💻 Dispositivo: {device}")
460
  print(f"🎯 Filtro minimal_accuracy: {minimal_accuracy}")
461
+ print(f"🛡️ Detección OOD activada para categoría: {model_category}") # 🆕
462
 
463
  # Cargar bounding boxes desde S3
464
  saved_images, s3_client = load_json_from_s3(json_s3_url)
 
476
  print(f"❌ Error cargando modelo: {e}")
477
  return pd.DataFrame()
478
 
479
+ # 4. Clasificar imágenes guardadas CON DETECCIÓN OOD
480
+ print("\n🔬 PASO 4: Clasificando imágenes guardadas con detección OOD...")
481
  results_df = classify_saved_bboxes(
482
  saved_images, encoder, class_names, prototypes, eval_transform, device,
483
+ minimal_accuracy, s3_client, model_category # 🆕 Pasar model_category
484
  )
485
 
486
  # 5. Mostrar resumen
 
490
  print(f" - Clases detectadas: {results_df['prediccion_principal'].nunique()}")
491
  print(f" - Clases únicas encontradas: {', '.join(results_df['prediccion_principal'].unique())}")
492
 
493
+ # 🆕 Estadísticas OOD
494
+ if 'ood_score' in results_df.columns:
495
+ avg_ood_score = results_df['ood_score'].mean()
496
+ min_ood_score = results_df['ood_score'].min()
497
+ max_ood_score = results_df['ood_score'].max()
498
+ print(f"\n🛡️ Estadísticas OOD:")
499
+ print(f" - OOD Score promedio: {avg_ood_score:.4f}")
500
+ print(f" - OOD Score mínimo: {min_ood_score:.4f}")
501
+ print(f" - OOD Score máximo: {max_ood_score:.4f}")
502
+
503
  # Top predicciones
504
  print(f"\n📊 Top predicciones:")
505
  top_predictions = results_df['prediccion_principal'].value_counts().head(5)
 
516
  print(f" - Mínimo: {min_accuracy:.4f}")
517
  print(f" - Máximo: {max_accuracy:.4f}")
518
  else:
519
+ print("❌ No hay detecciones que cumplan con el filtro de accuracy o todas fueron detectadas como OOD")
520
 
521
  return results_df
522
 
 
530
 
531
  def predict_objects(self, image_url, picture_id, visit_id, minimal_accuracy, model_category, json_s3_url):
532
 
533
+ print("Ejecutando clasificación optimizada con prototipos pre-cargados y detección OOD...")
534
  result_df = process_image_with_bboxes(
535
  self, image_url, picture_id, visit_id, minimal_accuracy,
536
  None, None, # model_path y train_path ya no son necesarios