Drazcat-AI commited on
Commit
8fb36c3
·
verified ·
1 Parent(s): 5589ea9

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +435 -434
handler.py CHANGED
@@ -1,435 +1,436 @@
1
- import torch
2
- import torch.nn.functional as F
3
- import torchvision.transforms as transforms
4
- from PIL import Image
5
- import numpy as np
6
- import pandas as pd
7
- from pathlib import Path
8
- from collections import defaultdict
9
- import requests
10
- import json
11
- from io import BytesIO
12
- import os
13
- from redes_prototipicas_tvt5 import ConvEncoder, PrototypicalNetwork
14
- import boto3
15
- from botocore.exceptions import ClientError
16
-
17
- def load_image_from_s3_direct(bucket_name, s3_key, s3_client):
18
-
19
- """Cargar imagen directamente desde S3 usando boto3 (RECOMENDADO)"""
20
- try:
21
- print(f"🔄 Cargando imagen desde S3...")
22
- print(f"📦 Bucket: {bucket_name}")
23
- print(f"🗝️ Key: {s3_key}")
24
-
25
- # Descargar objeto desde S3
26
- response = s3_client.get_object(Bucket=bucket_name, Key=s3_key)
27
-
28
- # Leer contenido y convertir a imagen
29
- image_data = response['Body'].read()
30
- bbox_image = Image.open(BytesIO(image_data)).convert('RGB')
31
-
32
- print("✅ Imagen cargada exitosamente")
33
- return bbox_image
34
-
35
- except ClientError as e:
36
- error_code = e.response['Error']['Code']
37
- if error_code == 'NoSuchKey':
38
- print(f"❌ La imagen no existe en S3: {s3_key}")
39
- elif error_code == 'NoSuchBucket':
40
- print(f"❌ El bucket no existe: {bucket_name}")
41
- elif error_code == 'AccessDenied':
42
- print(f"❌ Sin permisos para acceder a: {s3_key}")
43
- else:
44
- print(f"❌ Error de S3: {e}")
45
- return None
46
-
47
- except Exception as e:
48
- print(f"❌ Error cargando imagen: {e}")
49
- return None
50
-
51
- def load_image_from_s3_url(s3_url, s3_client):
52
- """Cargar imagen desde S3 extrayendo bucket y key de la URL"""
53
- try:
54
- # Extraer bucket y key de la URL
55
- # URL formato: https://bucket-name.s3.amazonaws.com/path/to/file.jpg
56
- url_parts = s3_url.replace('https://', '').split('/')
57
- bucket = url_parts[0].split('.s3.amazonaws.com')[0]
58
- key = '/'.join(url_parts[1:])
59
-
60
- return load_image_from_s3_direct(bucket, key, s3_client)
61
-
62
- except Exception as e:
63
- print(f"❌ Error procesando URL: {e}")
64
- return None
65
-
66
- def model_selector(self, model_category):
67
- if model_category == "bebidas_gas":
68
- encoder, class_names, prototypes, eval_transform = self.encoder_bebidas_gas, self.class_names_bebidas_gas, self.prototypes_bebidas_gas, self.eval_transform_bebidas_gas
69
- elif model_category == "detergentes":
70
- encoder, class_names, prototypes, eval_transform = self.encoder_detergentes, self.class_names_detergentes, self.prototypes_detergentes, self.eval_transform_detergentes
71
-
72
- return encoder, class_names, prototypes, eval_transform
73
-
74
- def load_classification_model(model_path, train_path, device):
75
-
76
- if Path(model_path).exists():
77
- actual_model_path = model_path
78
- model_name = "MODELO_ESPECIFICADO"
79
- print(f"✅ Usando modelo especificado: {model_path}")
80
- else:
81
- raise FileNotFoundError(f"❌ No se encontró ningún modelo en las rutas esperadas")
82
-
83
- # Cargar modelo con la arquitectura correcta (256 dims)
84
- encoder = ConvEncoder(hidden_dim=64, output_dim=256).to(device)
85
- model = PrototypicalNetwork(encoder).to(device)
86
-
87
- # Cargar pesos con weights_only=False para compatibilidad
88
- checkpoint = torch.load(actual_model_path, map_location=device, weights_only=False)
89
- encoder.load_state_dict(checkpoint['encoder_state_dict'])
90
- model.load_state_dict(checkpoint['model_state_dict'])
91
- model.eval()
92
-
93
- print(f"✅ Modelo de clasificación cargado correctamente ({model_name})")
94
-
95
- # Transformaciones para evaluación
96
- eval_transform = transforms.Compose([
97
- transforms.Resize((224, 224)),
98
- transforms.ToTensor(),
99
- transforms.Normalize(mean=[0.485, 0.456, 0.406],
100
- std=[0.229, 0.224, 0.225])
101
- ])
102
-
103
- # Crear prototipos robustos usando múltiples shots del dataset de entrenamiento
104
- print("🔄 Creando prototipos de clases...")
105
- class_images = defaultdict(list)
106
-
107
- # Cargar imágenes del train para crear prototipos
108
- for img_path in Path(train_path).glob('*'):
109
- if img_path.suffix.lower() in {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}:
110
- parts = img_path.stem.split('_')[:-1]
111
- class_name = '_'.join(parts) if parts else img_path.stem
112
-
113
- # Usar hasta 5 imágenes por clase para prototipos robustos
114
- if len(class_images[class_name]) < 5:
115
- try:
116
- image = Image.open(img_path).convert('RGB')
117
- image_tensor = eval_transform(image).unsqueeze(0).to(device)
118
- class_images[class_name].append(image_tensor)
119
- except Exception as e:
120
- pass
121
-
122
- # Crear prototipos con normalización
123
- class_names = sorted(class_images.keys())
124
- prototypes = []
125
-
126
- with torch.no_grad():
127
- for class_name in class_names:
128
- if class_images[class_name]:
129
- # Concatenar imágenes de la clase
130
- class_tensors = torch.cat(class_images[class_name], dim=0)
131
- # Extraer características
132
- class_features = encoder(class_tensors)
133
- # Normalizar (como hace el modelo)
134
- class_features = F.normalize(class_features, p=2, dim=1)
135
- # Promediar para obtener prototipo
136
- prototype = class_features.mean(dim=0, keepdim=True)
137
- # Normalizar el prototipo también
138
- prototype = F.normalize(prototype, p=2, dim=1)
139
- prototypes.append(prototype)
140
-
141
- prototypes = torch.cat(prototypes, dim=0)
142
- print(f"✅ Prototipos creados para {len(class_names)} clases")
143
-
144
- return encoder, class_names, prototypes, eval_transform
145
-
146
- def load_json_from_s3(json_s3_url):
147
-
148
- # Configuración S3
149
- aws_access_key = 'AKIA6BH4GPXQCUZ3PAX5' # Cambiar por tu access key
150
- aws_secret_key = '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' # Cambiar por tu secret key
151
- region_name = 'us-east-1' # Cambiar por tu región
152
- S3_BUCKET_NAME = 'rocketpin-ml-data' # Cambiar por tu bucket
153
-
154
- # Crear sesión y cliente S3
155
- session = boto3.Session(
156
- aws_access_key_id=aws_access_key,
157
- aws_secret_access_key=aws_secret_key,
158
- region_name=region_name
159
- )
160
- s3_client = session.client('s3')
161
-
162
- """Cargar JSON desde S3 usando la URL completa"""
163
- try:
164
- # Extraer bucket y key de la URL
165
- # URL formato: https://bucket-name.s3.amazonaws.com/path/to/file.json
166
- url_parts = json_s3_url.replace('https://', '').split('/')
167
- bucket = url_parts[0].split('.s3.amazonaws.com')[0]
168
- key = '/'.join(url_parts[1:])
169
-
170
- #print(f"🔄 Cargando JSON desde S3...")
171
- #print(f"📦 Bucket: {bucket}")
172
- #print(f"🗝️ Key: {key}")
173
-
174
- # Descargar objeto desde S3
175
- response = s3_client.get_object(Bucket=bucket, Key=key)
176
-
177
- # Leer contenido y convertir a JSON
178
- json_content = response['Body'].read().decode('utf-8')
179
- json_data = json.loads(json_content)
180
-
181
- print("✅ JSON cargado exitosamente")
182
- return json_data, s3_client
183
-
184
- except ClientError as e:
185
- error_code = e.response['Error']['Code']
186
- if error_code == 'NoSuchKey':
187
- print(f"❌ El archivo no existe en S3: {key}")
188
- elif error_code == 'NoSuchBucket':
189
- print(f"❌ El bucket no existe: {bucket}")
190
- else:
191
- print(f"❌ Error de S3: {e}")
192
- return None
193
-
194
- except Exception as e:
195
- print(f"❌ Error cargando JSON: {e}")
196
- return None
197
-
198
- def classify_saved_bboxes(saved_images, encoder, class_names, prototypes, eval_transform, device, minimal_accuracy, s3_client):
199
- """Clasificar las imágenes de bounding boxes guardadas"""
200
-
201
- if not saved_images:
202
- print("❌ No hay imágenes guardadas para clasificar")
203
- return pd.DataFrame()
204
-
205
- print(f"🔄 Clasificando {len(saved_images)} imágenes guardadas...")
206
- print(f"🎯 Filtro minimal_accuracy: {minimal_accuracy}")
207
-
208
- results = []
209
- filtered_count = 0
210
- with torch.no_grad():
211
- for img_info in saved_images:
212
- try:
213
- #if True:
214
- # Cargar imagen guardada
215
- #response = requests.get(img_info['bbox_path'])
216
- #bbox_image = Image.open(img_info['bbox_path']).convert('RGB')
217
- bbox_image = load_image_from_s3_url(img_info['bbox_path'], s3_client)
218
-
219
- # Transformar para el modelo
220
- query_tensor = eval_transform(bbox_image).unsqueeze(0).to(device)
221
-
222
- # Extraer características
223
- query_features = encoder(query_tensor)
224
- # Normalizar
225
- query_features = F.normalize(query_features, p=2, dim=1)
226
-
227
- # Calcular similitud coseno
228
- similarities = torch.mm(query_features, prototypes.t())
229
- similarities_numpy = similarities.cpu().numpy()[0]
230
-
231
- # Obtener top 3 predicciones
232
- #top3_indices = np.argsort(similarities_numpy)[::-1][:3]
233
- top3_indices = np.argsort(similarities_numpy)[::-1]
234
- top3_predictions = []
235
- top3_similarities = []
236
-
237
- for idx_pred in top3_indices:
238
- prediction = class_names[idx_pred]
239
- similarity = similarities_numpy[idx_pred]
240
- # Solo agregar si cumple con minimal_accuracy
241
- if similarity >= minimal_accuracy:
242
- top3_predictions.append(prediction)
243
- top3_similarities.append(round(similarity, 4))
244
-
245
- # Si no hay predicciones que cumplan con minimal_accuracy, saltar
246
- if len(top3_predictions) == 0:
247
- filtered_count += 1
248
- print(f"🔽 Bbox {img_info['bbox_id']} filtrado: ninguna predicción cumple minimal_accuracy {minimal_accuracy}")
249
- continue
250
-
251
- # Guardar predictions y accuracy como listas (solo las que cumplen el filtro)
252
- predictions_list = top3_predictions
253
- similarities_list = top3_similarities
254
-
255
- # La predicción principal es la primera de la lista filtrada
256
- predicted_class = predictions_list[0]
257
-
258
- # Formatear similarity_principal como porcentaje
259
- similarity_principal_formatted = f"{similarities_list[0]*100:.2f}%"
260
-
261
- # Formatear bbox_confidence con 4 decimales
262
- bbox_confidence_formatted = round(float(img_info['confidence']), 4)
263
-
264
- # Agregar resultado
265
- result = {
266
- 'sku_bb_id': str(img_info['bbox_id']),
267
- 'predictions': predictions_list,
268
- 'accuracy': similarities_list,
269
- 'prediccion_principal': predicted_class,
270
- 'similarity_principal': similarity_principal_formatted,
271
- 'bbox_confidence': bbox_confidence_formatted,
272
- 'xmin': img_info['x_min'],
273
- 'ymin': img_info['y_min'],
274
- 'xmax': img_info['x_max'],
275
- 'ymax': img_info['y_max']
276
- }
277
-
278
- results.append(result)
279
- #"""
280
- except Exception as e:
281
- print(f"❌ Error clasificando bbox {str(img_info['bbox_id'])}: {e}")
282
- # Agregar entrada de error
283
- results.append({
284
- 'sku_bb_id': str(img_info['bbox_id']),
285
- 'predictions': ['ERROR'],
286
- 'accuracy': [0.0000],
287
- 'prediccion_principal': 'ERROR',
288
- 'similarity_principal': 'ERROR',
289
- 'bbox_confidence': round(float(img_info['confidence']), 4),
290
- 'xmin': img_info['x_min'],
291
- 'ymin': img_info['y_min'],
292
- 'xmax': img_info['x_max'],
293
- 'ymax': img_info['y_max']
294
- })
295
- #"""
296
-
297
- if filtered_count > 0:
298
- print(f"📊 Resumen de filtrado:")
299
- print(f" - Detecciones procesadas: {len(results)}")
300
- print(f" - Detecciones filtradas: {filtered_count}")
301
- print(f" - Total original: {len(saved_images)}")
302
-
303
- return pd.DataFrame(results)
304
-
305
- def process_image_with_bboxes(self, image_url, picture_id, visit_id, minimal_accuracy, model_path, train_path, model_category, json_s3_url):
306
- """Función principal para procesar una imagen: detectar BB, guardar recortes y clasificar"""
307
-
308
- print("="*80)
309
- print("PROCESAMIENTO DE IMAGEN CON BOUNDING BOXES - MODELO OPTIMIZADO V4")
310
- print("="*80)
311
- print(f"📸 Imagen: {image_url}")
312
- print(f"🆔 Picture ID: {picture_id}")
313
-
314
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
315
- print(f"💻 Dispositivo: {device}")
316
- print(f"🎯 Filtro minimal_accuracy: {minimal_accuracy}")
317
-
318
- #logica de carga de bbs con s3
319
- saved_images, s3_client = load_json_from_s3(json_s3_url)
320
- saved_images = saved_images['bounding_boxes']
321
-
322
- if not saved_images:
323
- print("❌ No se pudieron guardar las imágenes recortadas")
324
- return pd.DataFrame()
325
-
326
- # 3. Cargar modelo de clasificación
327
- print("\n🤖 PASO 3: Cargando modelo de clasificación...")
328
- try:
329
- encoder, class_names, prototypes, eval_transform = model_selector(self, model_category)
330
- except Exception as e:
331
- print(f"❌ Error cargando modelo: {e}")
332
- return pd.DataFrame()
333
-
334
- # 4. Clasificar imágenes guardadas
335
- print("\n🔬 PASO 4: Clasificando imágenes guardadas...")
336
- results_df = classify_saved_bboxes(
337
- saved_images, encoder, class_names, prototypes, eval_transform, device,
338
- minimal_accuracy, s3_client
339
- )
340
-
341
- # 5. Mostrar resumen
342
- if not results_df.empty:
343
- print(f"\n✅ Procesamiento completado:")
344
- print(f" - Total de detecciones procesadas: {len(results_df)}")
345
- print(f" - Clases detectadas: {results_df['prediccion_principal'].nunique()}")
346
- print(f" - Clases únicas encontradas: {', '.join(results_df['prediccion_principal'].unique())}")
347
-
348
- # Top predicciones
349
- print(f"\n📊 Top predicciones:")
350
- top_predictions = results_df['prediccion_principal'].value_counts().head(5)
351
- for clase, count in top_predictions.items():
352
- print(f" - {clase}: {count} detecciones")
353
-
354
- # Estadísticas de accuracy
355
- if len(results_df) > 0:
356
- avg_accuracy = results_df['accuracy'].apply(lambda x: x[0] if isinstance(x, list) and len(x) > 0 else 0).mean()
357
- min_accuracy = results_df['accuracy'].apply(lambda x: x[0] if isinstance(x, list) and len(x) > 0 else 0).min()
358
- max_accuracy = results_df['accuracy'].apply(lambda x: x[0] if isinstance(x, list) and len(x) > 0 else 0).max()
359
- print(f"\n📈 Estadísticas de accuracy:")
360
- print(f" - Promedio: {avg_accuracy:.4f}")
361
- print(f" - Mínimo: {min_accuracy:.4f}")
362
- print(f" - Máximo: {max_accuracy:.4f}")
363
- else:
364
- print("❌ No hay detecciones que cumplan con el filtro de accuracy")
365
-
366
- return results_df
367
-
368
- class EndpointHandler():
369
- def __init__(self):
370
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
371
- model_path_detergentes = "model_curriculum4/prototypical_model_best_detergentes.pth"
372
- train_path_detergentes = "datasets/detergentes/train"
373
- #hf_hub_download(repo_id="Drazcat-AI/redes_prototipicas", filename=model_path_detergentes)
374
- self.encoder_detergentes, self.class_names_detergentes, self.prototypes_detergentes, self.eval_transform_detergentes = load_classification_model(model_path_detergentes, train_path_detergentes, device)
375
- #model_path_bebidas_gas = "model_curriculum4/prototypical_model_best_bebidas_gas.pth"
376
- #train_path_bebidas_gas = "datasets/bebidas_gas/train"
377
- #hf_hub_download(repo_id="Drazcat-AI/redes_prototipicas", filename="model_curriculum4/prototypical_model_best_bebidas_gas.pth")
378
- #self.encoder_bebidas_gas, self.class_names_bebidas_gas, self.prototypes_bebidas_gas, self.eval_transform_bebidas_gas = load_classification_model(model_path_bebidas_gas, train_path_bebidas_gas, device)
379
-
380
- def predict_objects(self, image_url, picture_id, visit_id, minimal_accuracy, model_category, json_s3_url):
381
-
382
- model_path="model_curriculum4/prototypical_model_best_" + model_category + ".pth"
383
- train_path ="datasets/" + model_category + "/train"
384
-
385
- print("Ejecutando test con una imagen...")
386
- result_df = process_image_with_bboxes(self, image_url, picture_id, visit_id, minimal_accuracy, model_path, train_path, model_category, json_s3_url)
387
- return result_df
388
-
389
- def __call__(self, event):
390
-
391
- image_url = event["image_url"]
392
- picture_id = event["picture_id"]
393
- visit_id = event["visit_id"]
394
- minimal_accuracy = event["minimal_accuracy"]
395
- model_category = event["model_category"]
396
- client = event["client"]
397
- json_s3_url = event["json_s3_url"]
398
-
399
- try:
400
- #if True:
401
-
402
- predictions = self.predict_objects(image_url, picture_id, visit_id, minimal_accuracy, model_category, json_s3_url)
403
- predictions_json = predictions.to_json(orient='records')
404
- #print(predictions)
405
- return {
406
- "statusCode": 200,
407
- "body": json.dumps(predictions_json),
408
- }
409
- except Exception as e:
410
- return {
411
- "statusCode": 500,
412
- "body": json.dumps(f"Error: {str(e)}"),
413
- }
414
-
415
- """
416
- # Instanciar la clase
417
- handler = EndpointHandler()
418
-
419
- # Preparar el evento con los datos necesarios
420
- event_data = {
421
- "image_url": "https://dmnoqeddtk0uw.cloudfront.net/peru_cencosud/visits/34/pi/upload_image385772781681046090.jpg",
422
- "picture_id": 11025,
423
- "visit_id": 34,
424
- "minimal_accuracy": 0.0,
425
- "model_category": "detergentes",
426
- "json_s3_url": "https://rocketpin-ml-data.s3.amazonaws.com/redes_prototipicas/bounding_boxes_images/results/visit_34/bboxes_results_picture_11025_visit_34.json"
427
- }
428
-
429
- # Ejecutar la predicción
430
- response = handler(event_data)
431
-
432
- # Verificar el resultado
433
- print(f"Status Code: {response['statusCode']}")
434
- print(f"Body: {response['body']}")
 
435
  """
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torchvision.transforms as transforms
4
+ from PIL import Image
5
+ import numpy as np
6
+ import pandas as pd
7
+ from pathlib import Path
8
+ from collections import defaultdict
9
+ import requests
10
+ import json
11
+ from io import BytesIO
12
+ import os
13
+ from redes_prototipicas_tvt5 import ConvEncoder, PrototypicalNetwork
14
+ import boto3
15
+ from botocore.exceptions import ClientError
16
+ from huggingface_hub import hf_hub_download
17
+
18
+ def load_image_from_s3_direct(bucket_name, s3_key, s3_client):
19
+
20
+ """Cargar imagen directamente desde S3 usando boto3 (RECOMENDADO)"""
21
+ try:
22
+ print(f"🔄 Cargando imagen desde S3...")
23
+ print(f"📦 Bucket: {bucket_name}")
24
+ print(f"🗝️ Key: {s3_key}")
25
+
26
+ # Descargar objeto desde S3
27
+ response = s3_client.get_object(Bucket=bucket_name, Key=s3_key)
28
+
29
+ # Leer contenido y convertir a imagen
30
+ image_data = response['Body'].read()
31
+ bbox_image = Image.open(BytesIO(image_data)).convert('RGB')
32
+
33
+ print("✅ Imagen cargada exitosamente")
34
+ return bbox_image
35
+
36
+ except ClientError as e:
37
+ error_code = e.response['Error']['Code']
38
+ if error_code == 'NoSuchKey':
39
+ print(f"❌ La imagen no existe en S3: {s3_key}")
40
+ elif error_code == 'NoSuchBucket':
41
+ print(f"❌ El bucket no existe: {bucket_name}")
42
+ elif error_code == 'AccessDenied':
43
+ print(f"❌ Sin permisos para acceder a: {s3_key}")
44
+ else:
45
+ print(f"❌ Error de S3: {e}")
46
+ return None
47
+
48
+ except Exception as e:
49
+ print(f"❌ Error cargando imagen: {e}")
50
+ return None
51
+
52
+ def load_image_from_s3_url(s3_url, s3_client):
53
+ """Cargar imagen desde S3 extrayendo bucket y key de la URL"""
54
+ try:
55
+ # Extraer bucket y key de la URL
56
+ # URL formato: https://bucket-name.s3.amazonaws.com/path/to/file.jpg
57
+ url_parts = s3_url.replace('https://', '').split('/')
58
+ bucket = url_parts[0].split('.s3.amazonaws.com')[0]
59
+ key = '/'.join(url_parts[1:])
60
+
61
+ return load_image_from_s3_direct(bucket, key, s3_client)
62
+
63
+ except Exception as e:
64
+ print(f"❌ Error procesando URL: {e}")
65
+ return None
66
+
67
+ def model_selector(self, model_category):
68
+ if model_category == "bebidas_gas":
69
+ encoder, class_names, prototypes, eval_transform = self.encoder_bebidas_gas, self.class_names_bebidas_gas, self.prototypes_bebidas_gas, self.eval_transform_bebidas_gas
70
+ elif model_category == "detergentes":
71
+ encoder, class_names, prototypes, eval_transform = self.encoder_detergentes, self.class_names_detergentes, self.prototypes_detergentes, self.eval_transform_detergentes
72
+
73
+ return encoder, class_names, prototypes, eval_transform
74
+
75
+ def load_classification_model(model_path, train_path, device):
76
+
77
+ if Path(model_path).exists():
78
+ actual_model_path = model_path
79
+ model_name = "MODELO_ESPECIFICADO"
80
+ print(f"✅ Usando modelo especificado: {model_path}")
81
+ else:
82
+ raise FileNotFoundError(f"❌ No se encontró ningún modelo en las rutas esperadas")
83
+
84
+ # Cargar modelo con la arquitectura correcta (256 dims)
85
+ encoder = ConvEncoder(hidden_dim=64, output_dim=256).to(device)
86
+ model = PrototypicalNetwork(encoder).to(device)
87
+
88
+ # Cargar pesos con weights_only=False para compatibilidad
89
+ checkpoint = torch.load(actual_model_path, map_location=device, weights_only=False)
90
+ encoder.load_state_dict(checkpoint['encoder_state_dict'])
91
+ model.load_state_dict(checkpoint['model_state_dict'])
92
+ model.eval()
93
+
94
+ print(f"✅ Modelo de clasificación cargado correctamente ({model_name})")
95
+
96
+ # Transformaciones para evaluación
97
+ eval_transform = transforms.Compose([
98
+ transforms.Resize((224, 224)),
99
+ transforms.ToTensor(),
100
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
101
+ std=[0.229, 0.224, 0.225])
102
+ ])
103
+
104
+ # Crear prototipos robustos usando múltiples shots del dataset de entrenamiento
105
+ print("🔄 Creando prototipos de clases...")
106
+ class_images = defaultdict(list)
107
+
108
+ # Cargar imágenes del train para crear prototipos
109
+ for img_path in Path(train_path).glob('*'):
110
+ if img_path.suffix.lower() in {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}:
111
+ parts = img_path.stem.split('_')[:-1]
112
+ class_name = '_'.join(parts) if parts else img_path.stem
113
+
114
+ # Usar hasta 5 imágenes por clase para prototipos robustos
115
+ if len(class_images[class_name]) < 5:
116
+ try:
117
+ image = Image.open(img_path).convert('RGB')
118
+ image_tensor = eval_transform(image).unsqueeze(0).to(device)
119
+ class_images[class_name].append(image_tensor)
120
+ except Exception as e:
121
+ pass
122
+
123
+ # Crear prototipos con normalización
124
+ class_names = sorted(class_images.keys())
125
+ prototypes = []
126
+
127
+ with torch.no_grad():
128
+ for class_name in class_names:
129
+ if class_images[class_name]:
130
+ # Concatenar imágenes de la clase
131
+ class_tensors = torch.cat(class_images[class_name], dim=0)
132
+ # Extraer características
133
+ class_features = encoder(class_tensors)
134
+ # Normalizar (como hace el modelo)
135
+ class_features = F.normalize(class_features, p=2, dim=1)
136
+ # Promediar para obtener prototipo
137
+ prototype = class_features.mean(dim=0, keepdim=True)
138
+ # Normalizar el prototipo también
139
+ prototype = F.normalize(prototype, p=2, dim=1)
140
+ prototypes.append(prototype)
141
+
142
+ prototypes = torch.cat(prototypes, dim=0)
143
+ print(f"✅ Prototipos creados para {len(class_names)} clases")
144
+
145
+ return encoder, class_names, prototypes, eval_transform
146
+
147
+ def load_json_from_s3(json_s3_url):
148
+
149
+ # Configuración S3
150
+ aws_access_key = 'AKIA6BH4GPXQCUZ3PAX5' # Cambiar por tu access key
151
+ aws_secret_key = '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' # Cambiar por tu secret key
152
+ region_name = 'us-east-1' # Cambiar por tu región
153
+ S3_BUCKET_NAME = 'rocketpin-ml-data' # Cambiar por tu bucket
154
+
155
+ # Crear sesión y cliente S3
156
+ session = boto3.Session(
157
+ aws_access_key_id=aws_access_key,
158
+ aws_secret_access_key=aws_secret_key,
159
+ region_name=region_name
160
+ )
161
+ s3_client = session.client('s3')
162
+
163
+ """Cargar JSON desde S3 usando la URL completa"""
164
+ try:
165
+ # Extraer bucket y key de la URL
166
+ # URL formato: https://bucket-name.s3.amazonaws.com/path/to/file.json
167
+ url_parts = json_s3_url.replace('https://', '').split('/')
168
+ bucket = url_parts[0].split('.s3.amazonaws.com')[0]
169
+ key = '/'.join(url_parts[1:])
170
+
171
+ #print(f"🔄 Cargando JSON desde S3...")
172
+ #print(f"📦 Bucket: {bucket}")
173
+ #print(f"🗝️ Key: {key}")
174
+
175
+ # Descargar objeto desde S3
176
+ response = s3_client.get_object(Bucket=bucket, Key=key)
177
+
178
+ # Leer contenido y convertir a JSON
179
+ json_content = response['Body'].read().decode('utf-8')
180
+ json_data = json.loads(json_content)
181
+
182
+ print("✅ JSON cargado exitosamente")
183
+ return json_data, s3_client
184
+
185
+ except ClientError as e:
186
+ error_code = e.response['Error']['Code']
187
+ if error_code == 'NoSuchKey':
188
+ print(f"❌ El archivo no existe en S3: {key}")
189
+ elif error_code == 'NoSuchBucket':
190
+ print(f"❌ El bucket no existe: {bucket}")
191
+ else:
192
+ print(f"❌ Error de S3: {e}")
193
+ return None
194
+
195
+ except Exception as e:
196
+ print(f"❌ Error cargando JSON: {e}")
197
+ return None
198
+
199
+ def classify_saved_bboxes(saved_images, encoder, class_names, prototypes, eval_transform, device, minimal_accuracy, s3_client):
200
+ """Clasificar las imágenes de bounding boxes guardadas"""
201
+
202
+ if not saved_images:
203
+ print("❌ No hay imágenes guardadas para clasificar")
204
+ return pd.DataFrame()
205
+
206
+ print(f"🔄 Clasificando {len(saved_images)} imágenes guardadas...")
207
+ print(f"🎯 Filtro minimal_accuracy: {minimal_accuracy}")
208
+
209
+ results = []
210
+ filtered_count = 0
211
+ with torch.no_grad():
212
+ for img_info in saved_images:
213
+ try:
214
+ #if True:
215
+ # Cargar imagen guardada
216
+ #response = requests.get(img_info['bbox_path'])
217
+ #bbox_image = Image.open(img_info['bbox_path']).convert('RGB')
218
+ bbox_image = load_image_from_s3_url(img_info['bbox_path'], s3_client)
219
+
220
+ # Transformar para el modelo
221
+ query_tensor = eval_transform(bbox_image).unsqueeze(0).to(device)
222
+
223
+ # Extraer características
224
+ query_features = encoder(query_tensor)
225
+ # Normalizar
226
+ query_features = F.normalize(query_features, p=2, dim=1)
227
+
228
+ # Calcular similitud coseno
229
+ similarities = torch.mm(query_features, prototypes.t())
230
+ similarities_numpy = similarities.cpu().numpy()[0]
231
+
232
+ # Obtener top 3 predicciones
233
+ #top3_indices = np.argsort(similarities_numpy)[::-1][:3]
234
+ top3_indices = np.argsort(similarities_numpy)[::-1]
235
+ top3_predictions = []
236
+ top3_similarities = []
237
+
238
+ for idx_pred in top3_indices:
239
+ prediction = class_names[idx_pred]
240
+ similarity = similarities_numpy[idx_pred]
241
+ # Solo agregar si cumple con minimal_accuracy
242
+ if similarity >= minimal_accuracy:
243
+ top3_predictions.append(prediction)
244
+ top3_similarities.append(round(similarity, 4))
245
+
246
+ # Si no hay predicciones que cumplan con minimal_accuracy, saltar
247
+ if len(top3_predictions) == 0:
248
+ filtered_count += 1
249
+ print(f"🔽 Bbox {img_info['bbox_id']} filtrado: ninguna predicción cumple minimal_accuracy {minimal_accuracy}")
250
+ continue
251
+
252
+ # Guardar predictions y accuracy como listas (solo las que cumplen el filtro)
253
+ predictions_list = top3_predictions
254
+ similarities_list = top3_similarities
255
+
256
+ # La predicción principal es la primera de la lista filtrada
257
+ predicted_class = predictions_list[0]
258
+
259
+ # Formatear similarity_principal como porcentaje
260
+ similarity_principal_formatted = f"{similarities_list[0]*100:.2f}%"
261
+
262
+ # Formatear bbox_confidence con 4 decimales
263
+ bbox_confidence_formatted = round(float(img_info['confidence']), 4)
264
+
265
+ # Agregar resultado
266
+ result = {
267
+ 'sku_bb_id': str(img_info['bbox_id']),
268
+ 'predictions': predictions_list,
269
+ 'accuracy': similarities_list,
270
+ 'prediccion_principal': predicted_class,
271
+ 'similarity_principal': similarity_principal_formatted,
272
+ 'bbox_confidence': bbox_confidence_formatted,
273
+ 'xmin': img_info['x_min'],
274
+ 'ymin': img_info['y_min'],
275
+ 'xmax': img_info['x_max'],
276
+ 'ymax': img_info['y_max']
277
+ }
278
+
279
+ results.append(result)
280
+ #"""
281
+ except Exception as e:
282
+ print(f"❌ Error clasificando bbox {str(img_info['bbox_id'])}: {e}")
283
+ # Agregar entrada de error
284
+ results.append({
285
+ 'sku_bb_id': str(img_info['bbox_id']),
286
+ 'predictions': ['ERROR'],
287
+ 'accuracy': [0.0000],
288
+ 'prediccion_principal': 'ERROR',
289
+ 'similarity_principal': 'ERROR',
290
+ 'bbox_confidence': round(float(img_info['confidence']), 4),
291
+ 'xmin': img_info['x_min'],
292
+ 'ymin': img_info['y_min'],
293
+ 'xmax': img_info['x_max'],
294
+ 'ymax': img_info['y_max']
295
+ })
296
+ #"""
297
+
298
+ if filtered_count > 0:
299
+ print(f"📊 Resumen de filtrado:")
300
+ print(f" - Detecciones procesadas: {len(results)}")
301
+ print(f" - Detecciones filtradas: {filtered_count}")
302
+ print(f" - Total original: {len(saved_images)}")
303
+
304
+ return pd.DataFrame(results)
305
+
306
+ def process_image_with_bboxes(self, image_url, picture_id, visit_id, minimal_accuracy, model_path, train_path, model_category, json_s3_url):
307
+ """Función principal para procesar una imagen: detectar BB, guardar recortes y clasificar"""
308
+
309
+ print("="*80)
310
+ print("PROCESAMIENTO DE IMAGEN CON BOUNDING BOXES - MODELO OPTIMIZADO V4")
311
+ print("="*80)
312
+ print(f"📸 Imagen: {image_url}")
313
+ print(f"🆔 Picture ID: {picture_id}")
314
+
315
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
316
+ print(f"💻 Dispositivo: {device}")
317
+ print(f"🎯 Filtro minimal_accuracy: {minimal_accuracy}")
318
+
319
+ #logica de carga de bbs con s3
320
+ saved_images, s3_client = load_json_from_s3(json_s3_url)
321
+ saved_images = saved_images['bounding_boxes']
322
+
323
+ if not saved_images:
324
+ print("❌ No se pudieron guardar las imágenes recortadas")
325
+ return pd.DataFrame()
326
+
327
+ # 3. Cargar modelo de clasificación
328
+ print("\n🤖 PASO 3: Cargando modelo de clasificación...")
329
+ try:
330
+ encoder, class_names, prototypes, eval_transform = model_selector(self, model_category)
331
+ except Exception as e:
332
+ print(f"❌ Error cargando modelo: {e}")
333
+ return pd.DataFrame()
334
+
335
+ # 4. Clasificar imágenes guardadas
336
+ print("\n🔬 PASO 4: Clasificando imágenes guardadas...")
337
+ results_df = classify_saved_bboxes(
338
+ saved_images, encoder, class_names, prototypes, eval_transform, device,
339
+ minimal_accuracy, s3_client
340
+ )
341
+
342
+ # 5. Mostrar resumen
343
+ if not results_df.empty:
344
+ print(f"\n✅ Procesamiento completado:")
345
+ print(f" - Total de detecciones procesadas: {len(results_df)}")
346
+ print(f" - Clases detectadas: {results_df['prediccion_principal'].nunique()}")
347
+ print(f" - Clases únicas encontradas: {', '.join(results_df['prediccion_principal'].unique())}")
348
+
349
+ # Top predicciones
350
+ print(f"\n📊 Top predicciones:")
351
+ top_predictions = results_df['prediccion_principal'].value_counts().head(5)
352
+ for clase, count in top_predictions.items():
353
+ print(f" - {clase}: {count} detecciones")
354
+
355
+ # Estadísticas de accuracy
356
+ if len(results_df) > 0:
357
+ avg_accuracy = results_df['accuracy'].apply(lambda x: x[0] if isinstance(x, list) and len(x) > 0 else 0).mean()
358
+ min_accuracy = results_df['accuracy'].apply(lambda x: x[0] if isinstance(x, list) and len(x) > 0 else 0).min()
359
+ max_accuracy = results_df['accuracy'].apply(lambda x: x[0] if isinstance(x, list) and len(x) > 0 else 0).max()
360
+ print(f"\n📈 Estadísticas de accuracy:")
361
+ print(f" - Promedio: {avg_accuracy:.4f}")
362
+ print(f" - Mínimo: {min_accuracy:.4f}")
363
+ print(f" - Máximo: {max_accuracy:.4f}")
364
+ else:
365
+ print("❌ No hay detecciones que cumplan con el filtro de accuracy")
366
+
367
+ return results_df
368
+
369
+ class EndpointHandler():
370
+ def __init__(self):
371
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
372
+ model_path_detergentes = "model_curriculum4/prototypical_model_best_detergentes.pth"
373
+ train_path_detergentes = "datasets/detergentes/train"
374
+ hf_hub_download(repo_id="Drazcat-AI/redes_prototipicas", filename=model_path_detergentes)
375
+ self.encoder_detergentes, self.class_names_detergentes, self.prototypes_detergentes, self.eval_transform_detergentes = load_classification_model(model_path_detergentes, train_path_detergentes, device)
376
+ #model_path_bebidas_gas = "model_curriculum4/prototypical_model_best_bebidas_gas.pth"
377
+ #train_path_bebidas_gas = "datasets/bebidas_gas/train"
378
+ #hf_hub_download(repo_id="Drazcat-AI/redes_prototipicas", filename="model_curriculum4/prototypical_model_best_bebidas_gas.pth")
379
+ #self.encoder_bebidas_gas, self.class_names_bebidas_gas, self.prototypes_bebidas_gas, self.eval_transform_bebidas_gas = load_classification_model(model_path_bebidas_gas, train_path_bebidas_gas, device)
380
+
381
+ def predict_objects(self, image_url, picture_id, visit_id, minimal_accuracy, model_category, json_s3_url):
382
+
383
+ model_path="model_curriculum4/prototypical_model_best_" + model_category + ".pth"
384
+ train_path ="datasets/" + model_category + "/train"
385
+
386
+ print("Ejecutando test con una imagen...")
387
+ result_df = process_image_with_bboxes(self, image_url, picture_id, visit_id, minimal_accuracy, model_path, train_path, model_category, json_s3_url)
388
+ return result_df
389
+
390
+ def __call__(self, event):
391
+
392
+ image_url = event["image_url"]
393
+ picture_id = event["picture_id"]
394
+ visit_id = event["visit_id"]
395
+ minimal_accuracy = event["minimal_accuracy"]
396
+ model_category = event["model_category"]
397
+ client = event["client"]
398
+ json_s3_url = event["json_s3_url"]
399
+
400
+ try:
401
+ #if True:
402
+
403
+ predictions = self.predict_objects(image_url, picture_id, visit_id, minimal_accuracy, model_category, json_s3_url)
404
+ predictions_json = predictions.to_json(orient='records')
405
+ #print(predictions)
406
+ return {
407
+ "statusCode": 200,
408
+ "body": json.dumps(predictions_json),
409
+ }
410
+ except Exception as e:
411
+ return {
412
+ "statusCode": 500,
413
+ "body": json.dumps(f"Error: {str(e)}"),
414
+ }
415
+
416
+ """
417
+ # Instanciar la clase
418
+ handler = EndpointHandler()
419
+
420
+ # Preparar el evento con los datos necesarios
421
+ event_data = {
422
+ "image_url": "https://dmnoqeddtk0uw.cloudfront.net/peru_cencosud/visits/34/pi/upload_image385772781681046090.jpg",
423
+ "picture_id": 11025,
424
+ "visit_id": 34,
425
+ "minimal_accuracy": 0.0,
426
+ "model_category": "detergentes",
427
+ "json_s3_url": "https://rocketpin-ml-data.s3.amazonaws.com/redes_prototipicas/bounding_boxes_images/results/visit_34/bboxes_results_picture_11025_visit_34.json"
428
+ }
429
+
430
+ # Ejecutar la predicción
431
+ response = handler(event_data)
432
+
433
+ # Verificar el resultado
434
+ print(f"Status Code: {response['statusCode']}")
435
+ print(f"Body: {response['body']}")
436
  """