| | import torch |
| | import torch.nn.functional as F |
| | from PIL import Image |
| | import os |
| | import argparse |
| | from pathlib import Path |
| | import matplotlib.pyplot as plt |
| | import numpy as np |
| | from transformers import CLIPProcessor |
| |
|
| | from model import SiameseCLIPModel |
| | from config import Config |
| |
|
class SimilarityInference:
    """Similarity inference wrapper around a trained ``SiameseCLIPModel``.

    Loads model weights from a checkpoint, pre-processes images and texts with
    the ``CLIPProcessor`` named by ``Config.CLIP_MODEL_NAME`` and exposes
    path-based helpers to score a single pair or rank many images against a
    reference image.
    """

    def __init__(self, checkpoint_path, device=None):
        """Build the model, load the trained weights and prepare the processor.

        Args:
            checkpoint_path: Path to the trained model checkpoint.
            device: Device for inference (e.g. ``"cuda"`` or ``"cpu"``). When
                ``None``, CUDA is used if available, otherwise CPU.

        Raises:
            FileNotFoundError: If ``checkpoint_path`` does not exist.
        """
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)

        print(f"Utilizando dispositivo: {self.device}")

        # The Config class itself (not an instance) is used as the settings holder.
        self.config = Config

        self.model = SiameseCLIPModel(self.config)
        self._load_checkpoint(checkpoint_path)
        self.model.to(self.device)
        self.model.eval()

        # Processor loaded from the same configured CLIP name used to build the model.
        self.processor = CLIPProcessor.from_pretrained(self.config.CLIP_MODEL_NAME)

    def _load_checkpoint(self, checkpoint_path):
        """Load model weights from a checkpoint file into ``self.model``.

        Two layouts are accepted: a dict containing ``'model_state_dict'``
        (its optional ``'epoch'`` / ``'best_val_loss'`` entries are printed),
        or a bare state dict.

        Args:
            checkpoint_path: Path to the checkpoint.

        Raises:
            FileNotFoundError: If ``checkpoint_path`` does not exist.
        """
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"No se encontró el checkpoint en {checkpoint_path}")

        # NOTE(security): torch.load unpickles the file, so only load
        # checkpoints from trusted sources. weights_only=True would restrict
        # this, but may reject checkpoints that store extra training metadata.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        if 'model_state_dict' in checkpoint:
            self.model.load_state_dict(checkpoint['model_state_dict'])
            print(f"Modelo cargado desde {checkpoint_path}")
            print(f"Época: {checkpoint.get('epoch', 'N/A')}")
            print(f"Mejor pérdida de validación: {checkpoint.get('best_val_loss', 'N/A')}")
        else:
            # Plain state-dict layout (no training metadata).
            self.model.load_state_dict(checkpoint)
            print(f"Modelo cargado desde {checkpoint_path} (formato simple)")

    @staticmethod
    def _open_rgb(image_path):
        """Open an image file, convert it to RGB and close the file handle."""
        with Image.open(image_path) as img:
            # convert() loads the pixels and returns a new image, so the
            # result stays usable after the underlying file is closed.
            return img.convert('RGB')

    def preprocess_image(self, image_path):
        """Turn an image file into CLIP pixel values on ``self.device``.

        Args:
            image_path: Path to the image.

        Returns:
            The processor's ``pixel_values`` tensor, moved to ``self.device``.

        Raises:
            FileNotFoundError: If ``image_path`` does not exist.
        """
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"No se encontró la imagen en {image_path}")

        image = self._open_rgb(image_path)
        inputs = self.processor(images=image, return_tensors="pt")
        return inputs.pixel_values.to(self.device)

    def preprocess_text(self, text):
        """Tokenize text with the CLIP processor (padded and truncated).

        Args:
            text: Text to tokenize.

        Returns:
            Dict with ``'input_ids'`` and ``'attention_mask'`` tensors on
            ``self.device``.
        """
        inputs = self.processor(text=text, return_tensors="pt", padding=True, truncation=True)
        return {
            'input_ids': inputs.input_ids.to(self.device),
            'attention_mask': inputs.attention_mask.to(self.device),
        }

    def _compute_similarity(self, image1, image2, text1_inputs=None, text2_inputs=None):
        """Run ``model.calculate_similarity`` on preprocessed inputs.

        Text tensors are forwarded only when both text inputs are given;
        otherwise the comparison uses the images alone.

        Returns:
            The model's (single-element) similarity tensor as a Python float.
        """
        text_kwargs = {}
        if text1_inputs is not None and text2_inputs is not None:
            text_kwargs = {
                'text1_input_ids': text1_inputs['input_ids'],
                'text2_input_ids': text2_inputs['input_ids'],
                'text1_attention_mask': text1_inputs['attention_mask'],
                'text2_attention_mask': text2_inputs['attention_mask'],
            }

        with torch.no_grad():
            similarity = self.model.calculate_similarity(
                image1_pixel_values=image1,
                image2_pixel_values=image2,
                **text_kwargs,
            )
        return similarity.item()

    def calculate_similarity(self, image1_path, image2_path, text1=None, text2=None):
        """Compute the similarity between two images (and optionally their texts).

        Texts are used only when both are given and
        ``config.USE_TEXT_EMBEDDINGS`` is truthy; otherwise only the images
        are compared.

        Args:
            image1_path: Path to the first image.
            image2_path: Path to the second image.
            text1: Optional description of the first image.
            text2: Optional description of the second image.

        Returns:
            The value of ``model.calculate_similarity`` as a float. The
            original documentation describes it as the cosine similarity of
            the embeddings, in [-1, 1].
        """
        image1 = self.preprocess_image(image1_path)
        image2 = self.preprocess_image(image2_path)

        text1_inputs = text2_inputs = None
        if text1 is not None and text2 is not None and self.config.USE_TEXT_EMBEDDINGS:
            text1_inputs = self.preprocess_text(text1)
            text2_inputs = self.preprocess_text(text2)

        return self._compute_similarity(image1, image2, text1_inputs, text2_inputs)

    def calculate_batch_similarities(self, reference_image, comparison_images, reference_text=None, comparison_texts=None):
        """Score several images against one reference image.

        Args:
            reference_image: Path to the reference image.
            comparison_images: Iterable of image paths to compare against it.
            reference_text: Optional description of the reference image.
            comparison_texts: Optional list of descriptions matched to
                ``comparison_images`` by position (it may be shorter).

        Returns:
            List of ``{'image_path': path, 'similarity': float}`` dicts sorted
            by similarity, highest first. Images that fail to process are
            reported with a printed message and omitted.
        """
        ref_image = self.preprocess_image(reference_image)

        ref_text_inputs = None
        if reference_text is not None and self.config.USE_TEXT_EMBEDDINGS:
            ref_text_inputs = self.preprocess_text(reference_text)

        results = []
        for i, comp_image_path in enumerate(comparison_images):
            try:
                comp_image = self.preprocess_image(comp_image_path)

                comp_text_inputs = None
                if comparison_texts is not None and i < len(comparison_texts) and self.config.USE_TEXT_EMBEDDINGS:
                    comp_text_inputs = self.preprocess_text(comparison_texts[i])

                # NOTE(review): pairs lacking a text on either side are scored
                # on images only, so one ranking may mix image-only and
                # image+text scores. Confirm this is intended.
                similarity = self._compute_similarity(ref_image, comp_image, ref_text_inputs, comp_text_inputs)
            except Exception as e:
                # Deliberately best-effort: one bad image must not abort the batch.
                print(f"Error al procesar {comp_image_path}: {e}")
                continue

            results.append({'image_path': comp_image_path, 'similarity': similarity})

        results.sort(key=lambda r: r['similarity'], reverse=True)
        return results

    def visualize_similarities(self, reference_image, comparison_results, num_images=5, figsize=(15, 10)):
        """Show the reference image next to the top-ranked comparison images.

        Args:
            reference_image: Path to the reference image.
            comparison_results: Output of ``calculate_batch_similarities``.
            num_images: Maximum number of comparison images to show.
            figsize: Matplotlib figure size.
        """
        # Clamp so that empty results or a non-positive request still render
        # just the reference panel instead of crashing.
        num_images = max(0, min(num_images, len(comparison_results)))

        # squeeze=False keeps ``axes`` indexable even with a single panel
        # (plt.subplots(1, 1) would otherwise return a bare Axes).
        _, axes = plt.subplots(1, num_images + 1, figsize=figsize, squeeze=False)
        axes = axes[0]

        axes[0].imshow(self._open_rgb(reference_image))
        axes[0].set_title("Imagen de referencia")
        axes[0].axis('off')

        for i, result in enumerate(comparison_results[:num_images], start=1):
            axes[i].imshow(self._open_rgb(result['image_path']))
            axes[i].set_title(f"Sim: {result['similarity']:.4f}")
            axes[i].axis('off')

        plt.tight_layout()
        plt.show()
| |
|
| |
|