File size: 11,600 Bytes
8fb36c3
 
 
 
 
 
 
 
 
 
 
 
 
37cbec3
 
8fb36c3
 
37cbec3
 
 
 
 
 
 
 
 
 
 
 
8fb36c3
 
37cbec3
 
 
 
 
 
 
 
 
8fb36c3
860195d
37cbec3
 
 
 
 
 
 
 
 
860195d
37cbec3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
860195d
3ba5712
37cbec3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8fb36c3
 
37cbec3
 
 
 
 
 
 
8fb36c3
37cbec3
 
 
 
 
 
 
 
 
 
 
8fb36c3
860195d
37cbec3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8fb36c3
 
37cbec3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8fb36c3
 
37cbec3
 
 
 
 
 
 
 
 
 
 
 
8fb36c3
37cbec3
 
 
 
 
 
8fb36c3
37cbec3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import pandas as pd
from pathlib import Path
import json
from io import BytesIO
import boto3
from botocore.exceptions import ClientError
from huggingface_hub import hf_hub_download

# Imports desde el código de entrenamiento actualizado
from redes_prototipicas_tvt5 import ConvEncoder, PrototypicalNetwork, SmartPadResize

def load_image_from_s3_url(s3_url, s3_client):
    """Fetch an image from S3 given its HTTPS URL and return it as an RGB PIL image.

    The bucket name is taken from the virtual-hosted host portion of the URL
    (everything before '.s3.amazonaws.com') and the key is the remainder of
    the path. Returns None when the download or decode fails.
    """
    try:
        host, _, object_key = s3_url.replace('https://', '').partition('/')
        bucket = host.split('.s3.amazonaws.com')[0]

        obj = s3_client.get_object(Bucket=bucket, Key=object_key)
        raw_bytes = obj['Body'].read()
        return Image.open(BytesIO(raw_bytes)).convert('RGB')
    except Exception as e:
        print(f"❌ Error cargando imagen: {e}")
        return None

def model_selector(self, model_category):
    """Return the (encoder, class_names, prototypes, eval_transform) tuple
    for a category id, or None when the category is unknown.

    Category ids: 182=detergentes, 175=mascotas, 202=vinos,
    161=cecinas, 198=licores.

    Fix: the previous dict literal evaluated every `self.encoder_*` /
    `self.prototypes_*` attribute eagerly, so selecting ANY category raised
    AttributeError unless all five models were loaded (EndpointHandler only
    loads licores). Attributes are now resolved lazily for the requested
    category only.
    """
    suffix_by_category = {
        182: 'detergentes',
        175: 'mascotas',
        202: 'vinos',
        161: 'cecinas',
        198: 'licores',
    }
    suffix = suffix_by_category.get(model_category)
    if suffix is None:
        return None
    return (
        getattr(self, f'encoder_{suffix}'),
        getattr(self, f'class_names_{suffix}'),
        getattr(self, f'prototypes_{suffix}'),
        getattr(self, f'eval_transform_{suffix}'),
    )

def get_ood_thresholds(model_category):
    """Return the per-category OOD thresholds for the 512px models.

    Each entry pairs a minimum cosine-similarity threshold with a maximum
    distance threshold; unknown categories fall back to 0.70 / 0.80.
    """
    fallback = {'similarity_threshold': 0.70, 'distance_threshold': 0.80}
    per_category = {
        182: {'similarity_threshold': 0.70, 'distance_threshold': 0.80},  # detergentes
        175: {'similarity_threshold': 0.68, 'distance_threshold': 0.85},  # mascotas
        202: {'similarity_threshold': 0.72, 'distance_threshold': 0.75},  # vinos
        161: {'similarity_threshold': 0.69, 'distance_threshold': 0.82},  # cecinas
        198: {'similarity_threshold': 0.71, 'distance_threshold': 0.78},  # licores
    }
    return per_category.get(model_category, fallback)

def detect_out_of_distribution(query_features, prototypes, ood_config):
    """Decide whether a query embedding is out-of-distribution (OOD).

    A sample is flagged OOD when its best similarity against the class
    prototypes falls below the configured similarity threshold, or when its
    closest distance exceeds the distance threshold. Returns
    (is_ood, ood_score) where ood_score blends similarity (70%) with a
    normalized distance margin (30%), minus a 0.05 penalty when OOD.
    """
    sim_threshold = ood_config['similarity_threshold']
    dist_threshold = ood_config['distance_threshold']

    sim_to_prototypes = torch.mm(query_features, prototypes.t()).squeeze(0)
    best_similarity = sim_to_prototypes.max().item()

    dist_to_prototypes = torch.cdist(query_features, prototypes).squeeze(0)
    closest_distance = dist_to_prototypes.min().item()

    # OOD when either criterion fails
    is_ood = (best_similarity < sim_threshold or
              closest_distance > dist_threshold)

    # Combined confidence score
    distance_margin = max(0, (dist_threshold - closest_distance) / dist_threshold)
    ood_score = 0.7 * best_similarity + 0.3 * distance_margin

    if is_ood:
        ood_score = max(0, ood_score - 0.05)

    return is_ood, ood_score

def load_classification_model_optimized(model_path, device):
    """Load a 512px prototypical checkpoint and rebuild its inference pieces.

    Returns (encoder, class_names, prototypes, eval_transform). Raises
    ValueError when the checkpoint lacks precomputed prototypes/classes.
    """
    # NOTE(review): weights_only=False executes arbitrary pickled code —
    # only load checkpoints from trusted sources.
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)

    if 'prototypes' not in checkpoint or 'class_names' not in checkpoint:
        raise ValueError("❌ Modelo sin prototipos. Re-entrena con código actualizado.")

    # Architecture hyper-parameters, with the training defaults as fallback
    cfg = checkpoint.get('model_config', {})
    hidden_dim = cfg.get('hidden_dim', 64)
    output_dim = cfg.get('output_dim', 256)
    image_size = cfg.get('image_size', 512)

    print(f"📊 Cargando modelo {image_size}px: {len(checkpoint['class_names'])} clases")

    # Rebuild the architecture, restore weights, switch to inference mode
    encoder = ConvEncoder(hidden_dim=hidden_dim, output_dim=output_dim).to(device)
    model = PrototypicalNetwork(encoder).to(device)
    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    model.load_state_dict(checkpoint['model_state_dict'])
    encoder.eval()
    model.eval()

    # Precomputed class prototypes and their labels
    prototypes = checkpoint['prototypes'].to(device)
    class_names = checkpoint['class_names']

    # Same preprocessing pipeline used at training time (pad-to-square + resize)
    eval_transform = transforms.Compose([
        SmartPadResize(target_size=image_size, fill_value=128),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    return encoder, class_names, prototypes, eval_transform

def load_json_from_s3(json_s3_url):
    """Download and parse a JSON document from S3.

    Returns (parsed_json, s3_client) on success or (None, None) on any
    download/parse failure.

    SECURITY FIX: an AWS access key and secret were previously hard-coded
    here. Those credentials must be considered leaked and rotated.
    Credentials are now resolved through boto3's standard chain
    (environment variables, shared config file, or attached IAM role).
    """
    import os  # local import: keeps this security fix self-contained

    session = boto3.Session(region_name=os.environ.get('AWS_REGION', 'us-east-1'))
    s3_client = session.client('s3')

    try:
        host, _, key = json_s3_url.replace('https://', '').partition('/')
        bucket = host.split('.s3.amazonaws.com')[0]

        response = s3_client.get_object(Bucket=bucket, Key=key)
        json_content = response['Body'].read().decode('utf-8')
        return json.loads(json_content), s3_client
    except Exception as e:
        print(f"❌ Error cargando JSON: {e}")
        return None, None

def classify_saved_bboxes(saved_images, encoder, class_names, prototypes, eval_transform, device, minimal_accuracy, s3_client, model_category):
    """Classify every saved bbox crop against the class prototypes,
    discarding out-of-distribution crops and low-similarity predictions.

    Returns a DataFrame with one row per accepted crop (empty when no
    crops survive filtering or the input list is empty).
    """
    if not saved_images:
        return pd.DataFrame()

    print(f"🔄 Clasificando {len(saved_images)} imágenes...")

    ood_config = get_ood_thresholds(model_category)
    rows = []
    filtered_count = 0
    ood_count = 0

    with torch.no_grad():
        for img_info in saved_images:
            try:
                # Download the crop; skip it when the fetch fails
                crop = load_image_from_s3_url(img_info['bbox_path'], s3_client)
                if crop is None:
                    continue

                # Embed and L2-normalize so cosine similarity is a dot product
                batch = eval_transform(crop).unsqueeze(0).to(device)
                features = F.normalize(encoder(batch), p=2, dim=1)

                # Drop out-of-distribution crops entirely
                is_ood, ood_score = detect_out_of_distribution(features, prototypes, ood_config)
                if is_ood:
                    ood_count += 1
                    filtered_count += 1
                    continue

                # Similarity against every class prototype, best first
                sims = torch.mm(features, prototypes.t()).cpu().numpy()[0]
                ranked = np.argsort(sims)[::-1]

                # Keep only classes whose similarity clears minimal_accuracy
                predictions = [class_names[i] for i in ranked if sims[i] >= minimal_accuracy]
                accuracies = [round(sims[i], 4) for i in ranked if sims[i] >= minimal_accuracy]

                if not predictions:
                    filtered_count += 1
                    continue

                # Blend classifier similarity (90%) with the OOD score (10%)
                adjusted = [round((acc * 0.9) + (ood_score * 0.1), 4) for acc in accuracies]

                rows.append({
                    'sku_bb_id': str(img_info['bbox_id']),
                    'predictions': predictions,
                    'accuracy': adjusted,
                    'prediccion_principal': predictions[0],
                    'similarity_principal': f"{adjusted[0]*100:.2f}%",
                    'bbox_confidence': round(float(img_info['confidence']), 4),
                    'ood_score': round(ood_score, 4),
                    'xmin': img_info['x_min'],
                    'ymin': img_info['y_min'],
                    'xmax': img_info['x_max'],
                    'ymax': img_info['y_max']
                })

            except Exception as e:
                print(f"❌ Error en bbox {img_info['bbox_id']}: {e}")
                continue

    print(f"📊 Procesadas: {len(rows)}, Filtradas: {filtered_count}, OOD: {ood_count}")
    return pd.DataFrame(rows)

def process_image_with_bboxes(self, image_url, picture_id, visit_id, minimal_accuracy, model_path, train_path, model_category, json_s3_url):
    """Main processing pipeline: fetch bboxes from S3, pick the category
    model, and classify every crop.

    image_url/picture_id/visit_id/model_path/train_path are kept for
    interface compatibility but unused here. Returns a (possibly empty)
    pandas DataFrame of classified detections.
    """
    print(f"🚀 Procesando imagen con modelo 512px - Categoría: {model_category}")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Cargar bounding boxes
    payload, s3_client = load_json_from_s3(json_s3_url)
    if not payload:
        return pd.DataFrame()

    # Fail gracefully when the JSON lacks the expected key instead of
    # raising an uncaught KeyError
    saved_images = payload.get('bounding_boxes')
    if not saved_images:
        print("❌ JSON sin 'bounding_boxes'")
        return pd.DataFrame()

    # Seleccionar modelo
    try:
        model_data = model_selector(self, model_category)
        if model_data is None:
            # Explicit message instead of the opaque "cannot unpack
            # non-iterable NoneType" the bare unpack used to produce
            print(f"❌ Error cargando modelo: categoría desconocida {model_category}")
            return pd.DataFrame()
        encoder, class_names, prototypes, eval_transform = model_data
    except Exception as e:
        print(f"❌ Error cargando modelo: {e}")
        return pd.DataFrame()

    # Clasificar
    results_df = classify_saved_bboxes(
        saved_images, encoder, class_names, prototypes, eval_transform, 
        device, minimal_accuracy, s3_client, model_category
    )

    if not results_df.empty:
        print(f"✅ {len(results_df)} detecciones procesadas")
        print(f"📊 Clases detectadas: {', '.join(results_df['prediccion_principal'].unique())}")

    return results_df

class EndpointHandler():
    """HTTP inference handler for the 512px prototypical-network models.

    NOTE(review): __init__ only loads the licores model, yet
    model_selector references attributes for five categories — requests
    for any other category will fail attribute lookup; confirm intended.
    """

    def __init__(self, path=""):
        """Initialize the handler, loading the 512px licores model only.

        path: unused; kept to match the hosting platform's handler signature.
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"🚀 Inicializando handler con device: {device}")
        
        # Download the licores checkpoint from the Hugging Face Hub
        # (cached locally by hf_hub_download)
        model_filename = "model_curriculum4/prototypical_model_best_licores.pth"
        local_model_path = hf_hub_download(repo_id="Drazcat-AI/redes_prototipicas", filename=model_filename)
        
        self.encoder_licores, self.class_names_licores, self.prototypes_licores, self.eval_transform_licores = load_classification_model_optimized(local_model_path, device)
        
        print("✅ Handler inicializado")

    def predict_objects(self, image_url, picture_id, visit_id, minimal_accuracy, model_category, json_s3_url):
        """Run the bbox-classification pipeline for one picture.

        model_path/train_path are passed as None — models are resolved
        through model_selector, not from disk paths.
        """
        return process_image_with_bboxes(
            self, image_url, picture_id, visit_id, minimal_accuracy, 
            None, None, model_category, json_s3_url
        )

    def __call__(self, event):
        """Entry point for the hosting platform.

        Expects event['inputs'] with image_url, picture_id, visit_id,
        minimal_accuracy, model_category and json_s3_url. Returns a
        Lambda-style {statusCode, body} dict.
        """
        if "inputs" not in event:
            return {"statusCode": 400, "body": json.dumps("Error: No 'inputs' parameter.")}
        
        event = event["inputs"]
        
        try:
            predictions = self.predict_objects(
                event["image_url"], event["picture_id"], event["visit_id"],
                event["minimal_accuracy"], event["model_category"], event["json_s3_url"]
            )
            
            # NOTE(review): to_json already returns a JSON string, which
            # json.dumps encodes a second time — clients receive a doubly
            # encoded body; confirm consumers expect this.
            return {
                "statusCode": 200,
                "body": json.dumps(predictions.to_json(orient='records'))
            }
        except Exception as e:
            return {"statusCode": 500, "body": json.dumps(f"Error: {str(e)}")}