import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import pandas as pd
from pathlib import Path
import json
import os
from io import BytesIO
import boto3
from botocore.exceptions import ClientError
from huggingface_hub import hf_hub_download

# Imports from the updated training code
from redes_prototipicas_tvt5 import ConvEncoder, PrototypicalNetwork, SmartPadResize


def load_image_from_s3_url(s3_url, s3_client):
    """Load an image from S3 by extracting bucket and key from the URL.

    Expects a virtual-hosted-style URL like
    ``https://<bucket>.s3.amazonaws.com/<key>``. Returns a PIL RGB image,
    or None on any failure (best-effort: callers skip unloadable images).
    """
    try:
        url_parts = s3_url.replace('https://', '').split('/')
        bucket = url_parts[0].split('.s3.amazonaws.com')[0]
        key = '/'.join(url_parts[1:])
        response = s3_client.get_object(Bucket=bucket, Key=key)
        image_data = response['Body'].read()
        return Image.open(BytesIO(image_data)).convert('RGB')
    except Exception as e:
        print(f"❌ Error cargando imagen: {e}")
        return None


def model_selector(self, model_category):
    """Select the loaded model bundle for a category id.

    Returns a tuple (encoder, class_names, prototypes, eval_transform)
    read from attributes on the handler instance, or None for an
    unknown category (callers unpack inside a try/except).

    NOTE(review): only the licores model (198) is loaded in
    EndpointHandler.__init__; the other attributes must be populated
    elsewhere or lookups for those categories will raise AttributeError.
    """
    models = {
        182: (self.encoder_detergentes, self.class_names_detergentes,
              self.prototypes_detergentes, self.eval_transform_detergentes),
        175: (self.encoder_mascotas, self.class_names_mascotas,
              self.prototypes_mascotas, self.eval_transform_mascotas),
        202: (self.encoder_vinos, self.class_names_vinos,
              self.prototypes_vinos, self.eval_transform_vinos),
        161: (self.encoder_cecinas, self.class_names_cecinas,
              self.prototypes_cecinas, self.eval_transform_cecinas),
        198: (self.encoder_licores, self.class_names_licores,
              self.prototypes_licores, self.eval_transform_licores),
    }
    return models.get(model_category)


def get_ood_thresholds(model_category):
    """Out-of-distribution thresholds for the 512px models.

    Returns a dict with 'similarity_threshold' (minimum cosine
    similarity to the closest prototype) and 'distance_threshold'
    (maximum Euclidean distance). Unknown categories fall back to a
    generic default.
    """
    config = {
        182: {'similarity_threshold': 0.70, 'distance_threshold': 0.80},  # detergentes
        175: {'similarity_threshold': 0.68, 'distance_threshold': 0.85},  # mascotas
        202: {'similarity_threshold': 0.72, 'distance_threshold': 0.75},  # vinos
        161: {'similarity_threshold': 0.69, 'distance_threshold': 0.82},  # cecinas
        198: {'similarity_threshold': 0.71, 'distance_threshold': 0.78},  # licores
    }
    return config.get(model_category,
                      {'similarity_threshold': 0.70, 'distance_threshold': 0.80})


def detect_out_of_distribution(query_features, prototypes, ood_config):
    """Simplified OOD detection against class prototypes.

    A query is flagged OOD when its best cosine similarity falls below
    the similarity threshold OR its closest Euclidean distance exceeds
    the distance threshold. Also returns a combined confidence score
    (0.7 * similarity + 0.3 * normalized distance margin), slightly
    penalized when the sample is OOD.
    """
    similarities = torch.mm(query_features, prototypes.t()).squeeze(0)
    max_similarity = similarities.max().item()

    distances = torch.cdist(query_features, prototypes).squeeze(0)
    min_distance = distances.min().item()

    # OOD criteria: either signal alone is enough to reject
    is_ood = (max_similarity < ood_config['similarity_threshold'] or
              min_distance > ood_config['distance_threshold'])

    # Combined score
    similarity_score = max_similarity
    distance_score = max(0, (ood_config['distance_threshold'] - min_distance)
                         / ood_config['distance_threshold'])
    ood_score = (0.7 * similarity_score + 0.3 * distance_score)

    if is_ood:
        ood_score = max(0, ood_score - 0.05)

    return is_ood, ood_score


def load_classification_model_optimized(model_path, device):
    """Load a 512px prototypical-network checkpoint.

    Returns (encoder, class_names, prototypes, eval_transform).
    Raises ValueError when the checkpoint lacks precomputed prototypes
    or class names (i.e. it was produced by an older training script).

    NOTE(review): weights_only=False unpickles arbitrary objects — only
    load checkpoints from a trusted source.
    """
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)

    if 'prototypes' not in checkpoint or 'class_names' not in checkpoint:
        raise ValueError("❌ Modelo sin prototipos. Re-entrena con código actualizado.")

    # Model configuration (falls back to the training defaults)
    model_config = checkpoint.get('model_config', {})
    hidden_dim = model_config.get('hidden_dim', 64)
    output_dim = model_config.get('output_dim', 256)
    image_size = model_config.get('image_size', 512)

    print(f"📊 Cargando modelo {image_size}px: {len(checkpoint['class_names'])} clases")

    # Build architecture and load weights
    encoder = ConvEncoder(hidden_dim=hidden_dim, output_dim=output_dim).to(device)
    model = PrototypicalNetwork(encoder).to(device)
    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    model.load_state_dict(checkpoint['model_state_dict'])
    encoder.eval()
    model.eval()

    # Prototypes and class names
    prototypes = checkpoint['prototypes'].to(device)
    class_names = checkpoint['class_names']

    # 512px transforms with SmartPadResize (ImageNet normalization)
    eval_transform = transforms.Compose([
        SmartPadResize(target_size=image_size, fill_value=128),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    return encoder, class_names, prototypes, eval_transform


def load_json_from_s3(json_s3_url):
    """Load a JSON document from S3.

    Returns (parsed_json, s3_client) so callers can reuse the client for
    subsequent image downloads, or (None, None) on failure.

    SECURITY: an AWS access key pair was previously hardcoded here and
    committed to source control — those credentials are compromised and
    MUST be rotated. Credentials are now taken from the environment
    (falling back to boto3's default provider chain when unset).
    """
    session = boto3.Session(
        aws_access_key_id=os.environ.get('AWS_ACCESS_KEY_ID'),
        aws_secret_access_key=os.environ.get('AWS_SECRET_ACCESS_KEY'),
        region_name=os.environ.get('AWS_REGION', 'us-east-1'),
    )
    s3_client = session.client('s3')

    try:
        url_parts = json_s3_url.replace('https://', '').split('/')
        bucket = url_parts[0].split('.s3.amazonaws.com')[0]
        key = '/'.join(url_parts[1:])
        response = s3_client.get_object(Bucket=bucket, Key=key)
        json_content = response['Body'].read().decode('utf-8')
        return json.loads(json_content), s3_client
    except Exception as e:
        print(f"❌ Error cargando JSON: {e}")
        return None, None


def classify_saved_bboxes(saved_images, encoder, class_names, prototypes,
                          eval_transform, device, minimal_accuracy, s3_client,
                          model_category):
    """Classify cropped bounding-box images with OOD filtering.

    For each bbox record (expects keys 'bbox_path', 'bbox_id',
    'confidence', 'x_min'/'y_min'/'x_max'/'y_max'), loads the crop from
    S3, embeds it, rejects OOD samples, and keeps predictions whose
    prototype similarity meets ``minimal_accuracy``. Returns a DataFrame
    of surviving detections (empty if none).
    """
    if not saved_images:
        return pd.DataFrame()

    print(f"🔄 Clasificando {len(saved_images)} imágenes...")

    ood_config = get_ood_thresholds(model_category)
    results = []
    filtered_count = 0
    ood_count = 0

    with torch.no_grad():
        for img_info in saved_images:
            try:
                # Load and transform the crop
                bbox_image = load_image_from_s3_url(img_info['bbox_path'], s3_client)
                if bbox_image is None:
                    continue

                query_tensor = eval_transform(bbox_image).unsqueeze(0).to(device)
                # L2-normalize so the dot product below is cosine similarity
                query_features = F.normalize(encoder(query_tensor), p=2, dim=1)

                # OOD detection
                is_ood, ood_score = detect_out_of_distribution(
                    query_features, prototypes, ood_config)
                if is_ood:
                    ood_count += 1
                    filtered_count += 1
                    continue

                # Similarities against every prototype, ranked best-first
                similarities = torch.mm(query_features, prototypes.t()).cpu().numpy()[0]
                ranked_indices = np.argsort(similarities)[::-1]

                # Keep only predictions above minimal_accuracy
                predictions = []
                accuracies = []
                for idx in ranked_indices:
                    if similarities[idx] >= minimal_accuracy:
                        predictions.append(class_names[idx])
                        accuracies.append(round(similarities[idx], 4))

                if not predictions:
                    filtered_count += 1
                    continue

                # Blend similarity with the OOD confidence score
                adjusted_accuracies = [round((acc * 0.9) + (ood_score * 0.1), 4)
                                       for acc in accuracies]

                result = {
                    'sku_bb_id': str(img_info['bbox_id']),
                    'predictions': predictions,
                    'accuracy': adjusted_accuracies,
                    'prediccion_principal': predictions[0],
                    'similarity_principal': f"{adjusted_accuracies[0]*100:.2f}%",
                    'bbox_confidence': round(float(img_info['confidence']), 4),
                    'ood_score': round(ood_score, 4),
                    'xmin': img_info['x_min'],
                    'ymin': img_info['y_min'],
                    'xmax': img_info['x_max'],
                    'ymax': img_info['y_max'],
                }
                results.append(result)

            except Exception as e:
                # Best-effort: a bad bbox must not abort the whole batch
                print(f"❌ Error en bbox {img_info['bbox_id']}: {e}")
                continue

    print(f"📊 Procesadas: {len(results)}, Filtradas: {filtered_count}, OOD: {ood_count}")
    return pd.DataFrame(results)


def process_image_with_bboxes(self, image_url, picture_id, visit_id,
                              minimal_accuracy, model_path, train_path,
                              model_category, json_s3_url):
    """Main processing pipeline: load bboxes, pick a model, classify.

    ``model_path`` and ``train_path`` are accepted for interface
    compatibility but unused — models are preloaded on the handler.
    Returns the classification DataFrame (empty on any failure).
    """
    print(f"🚀 Procesando imagen con modelo 512px - Categoría: {model_category}")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load bounding boxes from the JSON manifest in S3
    saved_images, s3_client = load_json_from_s3(json_s3_url)
    if not saved_images:
        return pd.DataFrame()
    saved_images = saved_images['bounding_boxes']

    # Select the preloaded model for this category (None/missing
    # attributes surface here as an exception)
    try:
        encoder, class_names, prototypes, eval_transform = model_selector(
            self, model_category)
    except Exception as e:
        print(f"❌ Error cargando modelo: {e}")
        return pd.DataFrame()

    # Classify every crop
    results_df = classify_saved_bboxes(
        saved_images, encoder, class_names, prototypes, eval_transform,
        device, minimal_accuracy, s3_client, model_category
    )

    if not results_df.empty:
        print(f"✅ {len(results_df)} detecciones procesadas")
        print(f"📊 Clases detectadas: {', '.join(results_df['prediccion_principal'].unique())}")

    return results_df


class EndpointHandler():
    def __init__(self, path=""):
        """Initialize the handler with the 512px models only.

        Currently only the licores (198) model is downloaded from the
        HuggingFace Hub; other categories in model_selector require
        their attributes to be loaded separately.
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"🚀 Inicializando handler con device: {device}")

        # Load the licores model checkpoint from the Hub
        model_filename = "model_curriculum4/prototypical_model_best_licores.pth"
        local_model_path = hf_hub_download(repo_id="Drazcat-AI/redes_prototipicas",
                                           filename=model_filename)
        (self.encoder_licores, self.class_names_licores,
         self.prototypes_licores, self.eval_transform_licores) = \
            load_classification_model_optimized(local_model_path, device)

        print("✅ Handler inicializado")

    def predict_objects(self, image_url, picture_id, visit_id, minimal_accuracy,
                        model_category, json_s3_url):
        """Run prediction with the 512px models.

        model_path/train_path are passed as None — models are preloaded.
        """
        return process_image_with_bboxes(
            self, image_url, picture_id, visit_id, minimal_accuracy,
            None, None, model_category, json_s3_url
        )

    def __call__(self, event):
        """Main entry point: validate the event and dispatch prediction.

        Expects ``event["inputs"]`` with keys image_url, picture_id,
        visit_id, minimal_accuracy, model_category, json_s3_url.
        """
        if "inputs" not in event:
            return {"statusCode": 400,
                    "body": json.dumps("Error: No 'inputs' parameter.")}

        event = event["inputs"]
        try:
            predictions = self.predict_objects(
                event["image_url"],
                event["picture_id"],
                event["visit_id"],
                event["minimal_accuracy"],
                event["model_category"],
                event["json_s3_url"]
            )
            # NOTE(review): to_json already returns a JSON string, so the
            # body is double-encoded JSON — kept for caller compatibility.
            return {
                "statusCode": 200,
                "body": json.dumps(predictions.to_json(orient='records'))
            }
        except Exception as e:
            return {"statusCode": 500,
                    "body": json.dumps(f"Error: {str(e)}")}