|
|
import json
import os
from io import BytesIO
from pathlib import Path

import boto3
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from botocore.exceptions import ClientError
from huggingface_hub import hf_hub_download
from PIL import Image

from redes_prototipicas_tvt5 import ConvEncoder, PrototypicalNetwork, SmartPadResize
|
|
|
|
|
def load_image_from_s3_url(s3_url, s3_client):
    """Download an image from S3, parsing bucket and key out of the URL.

    Expects URLs shaped like ``https://<bucket>.s3.amazonaws.com/<key>``.
    Returns a PIL image converted to RGB, or None if anything fails.
    """
    try:
        # Drop the scheme, then separate host from the object key.
        stripped = s3_url.replace('https://', '')
        host, _, key = stripped.partition('/')
        bucket = host.split('.s3.amazonaws.com')[0]

        body = s3_client.get_object(Bucket=bucket, Key=key)['Body'].read()
        return Image.open(BytesIO(body)).convert('RGB')
    except Exception as e:
        print(f"❌ Error cargando imagen: {e}")
        return None
|
|
|
|
|
def model_selector(self, model_category):
    """Return (encoder, class_names, prototypes, eval_transform) for a category.

    Resolves the four model components lazily with getattr so a handler that
    has loaded only a subset of the category models (the current __init__
    loads only 'licores') can still serve the categories it does have. The
    original eagerly built a dict of every tuple up front, which raised
    AttributeError for ALL categories as soon as any one model was missing.

    Returns None for an unknown category (same contract as before).
    """
    # Category id -> attribute suffix used on the handler instance.
    category_suffix = {
        182: 'detergentes',
        175: 'mascotas',
        202: 'vinos',
        161: 'cecinas',
        198: 'licores',
    }
    suffix = category_suffix.get(model_category)
    if suffix is None:
        return None
    return (
        getattr(self, f'encoder_{suffix}'),
        getattr(self, f'class_names_{suffix}'),
        getattr(self, f'prototypes_{suffix}'),
        getattr(self, f'eval_transform_{suffix}'),
    )
|
|
|
|
|
def get_ood_thresholds(model_category):
    """Return OOD thresholds tuned for the 512px models.

    Produces a dict with 'similarity_threshold' and 'distance_threshold';
    unknown categories fall back to the generic 0.70 / 0.80 pair.
    """
    per_category = {
        182: (0.70, 0.80),
        175: (0.68, 0.85),
        202: (0.72, 0.75),
        161: (0.69, 0.82),
        198: (0.71, 0.78),
    }
    sim, dist = per_category.get(model_category, (0.70, 0.80))
    return {'similarity_threshold': sim, 'distance_threshold': dist}
|
|
|
|
|
def detect_out_of_distribution(query_features, prototypes, ood_config):
    """Simplified out-of-distribution check against the class prototypes.

    Args:
        query_features: (1, D) embedding of the query image (callers pass an
            L2-normalized vector, so dot products behave like cosine sim).
        prototypes: (C, D) per-class prototype embeddings.
        ood_config: dict with 'similarity_threshold' and 'distance_threshold'.

    Returns:
        (is_ood, ood_score): boolean OOD flag plus a blended quality score.
    """
    # Best similarity to any prototype.
    sims = torch.mm(query_features, prototypes.t()).squeeze(0)
    best_similarity = sims.max().item()

    # Distance to the closest prototype.
    dists = torch.cdist(query_features, prototypes).squeeze(0)
    closest_distance = dists.min().item()

    # OOD when the query is neither similar enough nor close enough.
    is_ood = (
        best_similarity < ood_config['similarity_threshold']
        or closest_distance > ood_config['distance_threshold']
    )

    # Blend similarity with a clipped [0, 1] distance margin (70/30 weights).
    margin = ood_config['distance_threshold'] - closest_distance
    distance_score = max(0, margin / ood_config['distance_threshold'])
    ood_score = 0.7 * best_similarity + 0.3 * distance_score

    # Small penalty so flagged samples score strictly below the raw blend.
    if is_ood:
        ood_score = max(0, ood_score - 0.05)

    return is_ood, ood_score
|
|
|
|
|
def load_classification_model_optimized(model_path, device):
    """Load a 512px prototypical classifier checkpoint.

    Returns (encoder, class_names, prototypes, eval_transform) with the
    networks moved to `device` and switched to eval mode.

    Raises:
        ValueError: if the checkpoint predates prototype support.
    """
    # weights_only=False because the checkpoint stores Python objects (class
    # names, config) alongside tensors. NOTE(review): this deserializes via
    # pickle — only load checkpoints from trusted sources.
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)

    if 'prototypes' not in checkpoint or 'class_names' not in checkpoint:
        raise ValueError("❌ Modelo sin prototipos. Re-entrena con código actualizado.")

    # Architecture hyper-parameters, defaulting to the 512px configuration.
    cfg = checkpoint.get('model_config', {})
    hidden_dim = cfg.get('hidden_dim', 64)
    output_dim = cfg.get('output_dim', 256)
    image_size = cfg.get('image_size', 512)

    print(f"📊 Cargando modelo {image_size}px: {len(checkpoint['class_names'])} clases")

    # Rebuild the network, restore the trained weights, freeze for inference.
    encoder = ConvEncoder(hidden_dim=hidden_dim, output_dim=output_dim).to(device)
    model = PrototypicalNetwork(encoder).to(device)
    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    model.load_state_dict(checkpoint['model_state_dict'])
    encoder.eval()
    model.eval()

    prototypes = checkpoint['prototypes'].to(device)
    class_names = checkpoint['class_names']

    # Same preprocessing as training: smart pad-resize + ImageNet stats.
    eval_transform = transforms.Compose([
        SmartPadResize(target_size=image_size, fill_value=128),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    return encoder, class_names, prototypes, eval_transform
|
|
|
|
|
def load_json_from_s3(json_s3_url):
    """Load a JSON document from S3.

    Returns (parsed_json, s3_client) on success, (None, None) on any failure.

    Security fix: the previous version committed a literal AWS access key and
    secret to source control. Credentials now come from the environment /
    standard AWS provider chain; the leaked keys must be rotated regardless.
    """
    # Passing None for a credential lets boto3 fall back to its default
    # provider chain (env vars, shared config file, instance role, ...).
    session = boto3.Session(
        aws_access_key_id=os.environ.get('AWS_ACCESS_KEY_ID'),
        aws_secret_access_key=os.environ.get('AWS_SECRET_ACCESS_KEY'),
        region_name=os.environ.get('AWS_REGION', 'us-east-1'),
    )
    s3_client = session.client('s3')

    try:
        # URL layout: https://<bucket>.s3.amazonaws.com/<key...>
        url_parts = json_s3_url.replace('https://', '').split('/')
        bucket = url_parts[0].split('.s3.amazonaws.com')[0]
        key = '/'.join(url_parts[1:])

        response = s3_client.get_object(Bucket=bucket, Key=key)
        json_content = response['Body'].read().decode('utf-8')
        return json.loads(json_content), s3_client
    except Exception as e:
        print(f"❌ Error cargando JSON: {e}")
        return None, None
|
|
|
|
|
def classify_saved_bboxes(saved_images, encoder, class_names, prototypes, eval_transform, device, minimal_accuracy, s3_client, model_category):
    """Classify pre-cropped bounding-box images, filtering OOD samples.

    Args:
        saved_images: list of dicts; each needs 'bbox_path' (S3 URL),
            'bbox_id', 'confidence' and the x_min/y_min/x_max/y_max corners.
        encoder: feature encoder network (expected to be in eval mode).
        class_names: class labels aligned with the rows of `prototypes`.
        prototypes: (C, D) class prototype embeddings on `device`.
        eval_transform: preprocessing pipeline producing a model-ready tensor.
        device: torch device used for inference.
        minimal_accuracy: minimum similarity for a prediction to be kept.
        s3_client: boto3 S3 client used to download the crops.
        model_category: category id used to look up OOD thresholds.

    Returns:
        pandas.DataFrame with one row per accepted bbox (empty if none pass).
    """
    if not saved_images:
        return pd.DataFrame()

    print(f"🔄 Clasificando {len(saved_images)} imágenes...")

    ood_config = get_ood_thresholds(model_category)
    results = []
    filtered_count = 0  # total rejected (OOD + below minimal_accuracy)
    ood_count = 0       # subset rejected specifically by the OOD check

    with torch.no_grad():
        for img_info in saved_images:
            try:
                # Download the crop; skip this bbox if it cannot be loaded.
                bbox_image = load_image_from_s3_url(img_info['bbox_path'], s3_client)
                if bbox_image is None:
                    continue

                # Embed and L2-normalize so dot products act as cosine sim.
                query_tensor = eval_transform(bbox_image).unsqueeze(0).to(device)
                query_features = F.normalize(encoder(query_tensor), p=2, dim=1)

                is_ood, ood_score = detect_out_of_distribution(query_features, prototypes, ood_config)

                if is_ood:
                    ood_count += 1
                    filtered_count += 1
                    continue

                # Similarity to every class prototype, sorted best-first.
                # NOTE(review): despite the name, this is ALL class indices
                # in descending order, not just the top 3 — confirm intent.
                similarities = torch.mm(query_features, prototypes.t()).cpu().numpy()[0]
                top3_indices = np.argsort(similarities)[::-1]

                # Keep every class whose similarity clears the floor.
                predictions = []
                accuracies = []

                for idx in top3_indices:
                    if similarities[idx] >= minimal_accuracy:
                        predictions.append(class_names[idx])
                        accuracies.append(round(similarities[idx], 4))

                if not predictions:
                    filtered_count += 1
                    continue

                # Blend raw similarity (90%) with the OOD quality score (10%).
                adjusted_accuracies = [round((acc * 0.9) + (ood_score * 0.1), 4) for acc in accuracies]

                result = {
                    'sku_bb_id': str(img_info['bbox_id']),
                    'predictions': predictions,
                    'accuracy': adjusted_accuracies,
                    'prediccion_principal': predictions[0],
                    'similarity_principal': f"{adjusted_accuracies[0]*100:.2f}%",
                    'bbox_confidence': round(float(img_info['confidence']), 4),
                    'ood_score': round(ood_score, 4),
                    'xmin': img_info['x_min'],
                    'ymin': img_info['y_min'],
                    'xmax': img_info['x_max'],
                    'ymax': img_info['y_max']
                }
                results.append(result)

            except Exception as e:
                # Best-effort batch: one bad crop must not abort the rest.
                print(f"❌ Error en bbox {img_info['bbox_id']}: {e}")
                continue

    print(f"📊 Procesadas: {len(results)}, Filtradas: {filtered_count}, OOD: {ood_count}")
    return pd.DataFrame(results)
|
|
|
|
|
def process_image_with_bboxes(self, image_url, picture_id, visit_id, minimal_accuracy, model_path, train_path, model_category, json_s3_url):
    """Run the full bbox-classification pipeline for one image.

    Args:
        self: handler instance exposing the per-category model attributes.
        image_url, picture_id, visit_id: request metadata (not used by this
            pipeline; kept for interface compatibility).
        minimal_accuracy: minimum similarity for a prediction to be kept.
        model_path, train_path: legacy parameters, unused (callers pass None).
        model_category: category id selecting the model and OOD thresholds.
        json_s3_url: S3 URL of the JSON listing the saved bounding boxes.

    Returns:
        pandas.DataFrame of classified detections (empty on any failure).
    """
    print(f"🚀 Procesando imagen con modelo 512px - Categoría: {model_category}")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # The JSON document describes the pre-cropped bboxes stored in S3.
    saved_images, s3_client = load_json_from_s3(json_s3_url)
    if not saved_images:
        return pd.DataFrame()

    saved_images = saved_images['bounding_boxes']

    # model_selector returns None for unknown categories; previously that
    # surfaced as a cryptic "cannot unpack non-iterable NoneType" message.
    try:
        selection = model_selector(self, model_category)
        if selection is None:
            print(f"❌ Error cargando modelo: categoría no soportada: {model_category}")
            return pd.DataFrame()
        encoder, class_names, prototypes, eval_transform = selection
    except Exception as e:
        print(f"❌ Error cargando modelo: {e}")
        return pd.DataFrame()

    results_df = classify_saved_bboxes(
        saved_images, encoder, class_names, prototypes, eval_transform,
        device, minimal_accuracy, s3_client, model_category
    )

    if not results_df.empty:
        print(f"✅ {len(results_df)} detecciones procesadas")
        print(f"📊 Clases detectadas: {', '.join(results_df['prediccion_principal'].unique())}")

    return results_df
|
|
|
|
|
class EndpointHandler():
    """Inference-endpoint handler serving the 512px prototypical classifiers."""

    def __init__(self, path=""):
        """Initialize with the 512px models (currently only 'licores')."""
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"🚀 Inicializando handler con device: {device}")

        # Fetch the checkpoint from the Hugging Face Hub (cached locally).
        model_filename = "model_curriculum4/prototypical_model_best_licores.pth"
        local_model_path = hf_hub_download(repo_id="Drazcat-AI/redes_prototipicas", filename=model_filename)

        # NOTE(review): only the 'licores' model is loaded here, yet
        # model_selector also references detergentes/mascotas/vinos/cecinas
        # attributes — requests for those categories will fail until the
        # corresponding models are loaded. Confirm whether that is intended.
        self.encoder_licores, self.class_names_licores, self.prototypes_licores, self.eval_transform_licores = load_classification_model_optimized(local_model_path, device)

        print("✅ Handler inicializado")

    def predict_objects(self, image_url, picture_id, visit_id, minimal_accuracy, model_category, json_s3_url):
        """Run prediction with the 512px models.

        model_path/train_path are passed as None: they are legacy parameters
        that process_image_with_bboxes no longer uses.
        """
        return process_image_with_bboxes(
            self, image_url, picture_id, visit_id, minimal_accuracy,
            None, None, model_category, json_s3_url
        )

    def __call__(self, event):
        """Main entry point: validate the event and return an HTTP-style dict."""
        if "inputs" not in event:
            return {"statusCode": 400, "body": json.dumps("Error: No 'inputs' parameter.")}

        event = event["inputs"]

        try:
            predictions = self.predict_objects(
                event["image_url"], event["picture_id"], event["visit_id"],
                event["minimal_accuracy"], event["model_category"], event["json_s3_url"]
            )

            # NOTE(review): to_json already returns a JSON string; wrapping it
            # in json.dumps double-encodes the body (clients receive a quoted
            # JSON string). Kept as-is for wire compatibility — confirm with
            # consumers before changing.
            return {
                "statusCode": 200,
                "body": json.dumps(predictions.to_json(orient='records'))
            }
        except Exception as e:
            return {"statusCode": 500, "body": json.dumps(f"Error: {str(e)}")}