import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import Dataset, DataLoader import pandas as pd import numpy as np from PIL import Image, ImageOps import torchvision.transforms as transforms import os from transformers import ViTForImageClassification, ViTConfig from sklearn.metrics import accuracy_score, classification_report import matplotlib.pyplot as plt import seaborn as sns from tqdm import tqdm from typing import List, Tuple, Dict, Optional import json import warnings warnings.filterwarnings('ignore') # ============================================================================ # CONFIGURACIÓN PARA JUPYTER NOTEBOOK # ============================================================================ # CONFIGURAR ESTOS PATHS SEGÚN TU ESTRUCTURA DE DATOS DATA_PATH = "datasets/peru_cencosud_categories-2" # Cambiar por tu path de datos SAVE_PATH = "vit_multiclass_model" # Donde guardar el modelo entrenado MODEL_NAME = "google/vit-base-patch16-224" # Modelo ViT preentrenado # CONFIGURACIÓN DE IMAGEN IMAGE_SIZE = 800 # Resolución objetivo PADDING_COLOR = (128, 128, 128) # Color de padding (gris medio) # HIPERPARÁMETROS OPTIMIZADOS PARA 26K IMÁGENES / 90 CLASES EPOCHS = 30 # Más épocas por la cantidad de datos y clases BATCH_SIZE = 8 # Aumentado para mejor estabilidad LEARNING_RATE = 1e-4 # Reducido para mejor convergencia WEIGHT_DECAY = 1e-4 # Regularización WARMUP_EPOCHS = 3 # Warmup para estabilidad inicial # ============================================================================ # PROCESADOR DE IMÁGENES PERSONALIZADO # ============================================================================ class PaddingImageProcessor: """Procesador de imágenes personalizado que mantiene aspect ratio con padding""" def __init__(self, target_size: int = 1280, padding_color: tuple = (128, 128, 128)): """ Args: target_size: Tamaño objetivo (cuadrado) padding_color: Color del padding en RGB """ self.target_size = target_size self.padding_color = padding_color # Transforms para normalización (valores estándar de ImageNet) self.normalize = transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ) def pad_to_square(self, image: Image.Image) -> Image.Image: """Aplica padding para hacer la imagen cuadrada manteniendo aspect ratio""" width, height = image.size # Determinar el tamaño del cuadrado (el lado más largo) max_size = max(width, height) # Crear imagen cuadrada con color de padding padded_image = Image.new('RGB', (max_size, max_size), self.padding_color) # Calcular posición para centrar la imagen original left = (max_size - width) // 2 top = (max_size - height) // 2 # Pegar la imagen original en el centro padded_image.paste(image, (left, top)) return padded_image def __call__(self, image: Image.Image) -> torch.Tensor: """ Procesa una imagen aplicando padding + resize Args: image: Imagen PIL en formato RGB Returns: Tensor procesado listo para el modelo """ # 1. Aplicar padding para hacer cuadrada padded_image = self.pad_to_square(image) # 2. Resize a la resolución objetivo manteniendo aspect ratio (ya es cuadrada) resized_image = padded_image.resize((self.target_size, self.target_size), Image.Resampling.LANCZOS) # 3. Convertir a tensor y normalizar # Convertir PIL a tensor [0, 1] transform_to_tensor = transforms.ToTensor() tensor_image = transform_to_tensor(resized_image) # 4. Normalizar con valores de ImageNet normalized_image = self.normalize(tensor_image) return normalized_image # ============================================================================ # DATASET PERSONALIZADO # ============================================================================ class MultiClassImageDataset(Dataset): """Dataset personalizado para clasificación multi-clase de imágenes""" def __init__(self, csv_path: str, images_dir: str, image_processor: PaddingImageProcessor, class_columns: List[str], filename_column: str): """ Args: csv_path: Ruta al archivo CSV con las anotaciones images_dir: Directorio que contiene las imágenes image_processor: Procesador personalizado de imágenes class_columns: Lista de nombres de columnas que representan las clases filename_column: Nombre de la columna que contiene los nombres de archivos """ self.df = pd.read_csv(csv_path) self.images_dir = images_dir self.image_processor = image_processor self.class_columns = class_columns self.filename_column = filename_column print(f"Dataset cargado desde {csv_path}: {len(self.df)} imágenes") print(f"Columnas de clases: {class_columns}") def __len__(self): return len(self.df) def __getitem__(self, idx): row = self.df.iloc[idx] # Cargar imagen usando la columna de filename detectada img_path = os.path.join(self.images_dir, row[self.filename_column]) try: image = Image.open(img_path).convert('RGB') except Exception as e: print(f"Error cargando imagen {img_path}: {e}") # Crear imagen dummy si hay error image = Image.new('RGB', (224, 224), color='black') # Procesar imagen con padding + resize personalizado processed_image = self.image_processor(image) # Crear tensor de etiquetas multi-clase labels = torch.tensor([row[col] for col in self.class_columns], dtype=torch.float32) return processed_image, labels # ============================================================================ # ENTRENADOR ViT # ============================================================================ class ViTMultiClassTrainer: """Entrenador para ViT con clasificación multi-clase""" def __init__(self, data_path: str, model_name: str = "google/vit-base-patch16-224"): """ Args: data_path: Ruta base donde están los directorios train/valid/test model_name: Nombre del modelo ViT preentrenado """ self.data_path = data_path self.model_name = model_name self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(f"Usando dispositivo: {self.device}") # Inicializar procesador personalizado self.image_processor = PaddingImageProcessor( target_size=IMAGE_SIZE, padding_color=PADDING_COLOR ) print(f"Procesador de imágenes configurado: {IMAGE_SIZE}px con padding {PADDING_COLOR}") # Detectar estructura de datos automáticamente self._detect_data_structure() def _find_csv_in_folder(self, folder_path: str) -> Optional[str]: """Busca el archivo CSV en una carpeta específica""" if not os.path.exists(folder_path): return None csv_files = [f for f in os.listdir(folder_path) if f.endswith('.csv')] if len(csv_files) == 0: print(f"No se encontró CSV en {folder_path}") return None elif len(csv_files) == 1: csv_path = os.path.join(folder_path, csv_files[0]) print(f"CSV encontrado: {csv_path}") return csv_path else: # Si hay múltiples CSVs, tomar el primero csv_path = os.path.join(folder_path, csv_files[0]) print(f"Múltiples CSVs en {folder_path}, usando: {csv_files[0]}") return csv_path def _detect_filename_column(self, df: pd.DataFrame) -> str: """Detecta la columna que contiene los nombres de archivos""" possible_names = ['filename', 'image', 'image_name', 'file', 'name', 'img'] for col in possible_names: if col in df.columns: return col # Si no encuentra ninguna, usar la primera columna print(f"No se encontró columna de filename conocida. Usando: {df.columns[0]}") return df.columns[0] def _detect_data_structure(self): """Detecta automáticamente la estructura de datos y clases""" print("Detectando estructura de datos...") # Buscar CSV en carpeta de entrenamiento train_folder = os.path.join(self.data_path, 'train') train_csv = self._find_csv_in_folder(train_folder) if train_csv is None: raise FileNotFoundError(f"No se encontró CSV en {train_folder}") # Cargar CSV para detectar columnas df = pd.read_csv(train_csv) print(f"Columnas encontradas: {list(df.columns)}") # Detectar columna de filename self.filename_column = self._detect_filename_column(df) print(f"Columna de archivos detectada: {self.filename_column}") # Las demás columnas son las clases self.class_columns = [col for col in df.columns if col != self.filename_column] self.num_classes = len(self.class_columns) if self.num_classes == 0: raise ValueError("No se encontraron columnas de clases") print(f"Clases detectadas ({self.num_classes}): {self.class_columns}") # Verificar otras carpetas for split in ['valid', 'test']: split_folder = os.path.join(self.data_path, split) if os.path.exists(split_folder): csv_path = self._find_csv_in_folder(split_folder) if csv_path: print(f"Carpeta {split}: CSV encontrado") else: print(f"Carpeta {split}: Sin CSV") else: print(f"Carpeta {split}: No existe") def _create_datasets(self) -> Tuple[Dataset, Optional[Dataset], Optional[Dataset]]: """Crea los datasets de entrenamiento, validación y prueba""" datasets = {} for split in ['train', 'valid', 'test']: split_folder = os.path.join(self.data_path, split) csv_path = self._find_csv_in_folder(split_folder) if csv_path is not None: datasets[split] = MultiClassImageDataset( csv_path=csv_path, images_dir=split_folder, image_processor=self.image_processor, class_columns=self.class_columns, filename_column=self.filename_column ) else: datasets[split] = None return datasets.get('train'), datasets.get('valid'), datasets.get('test') def _create_model(self): """Crea el modelo ViT para clasificación multi-clase con resolución personalizada""" # Configurar el modelo para la nueva resolución config = ViTConfig.from_pretrained(self.model_name) # Calcular el número de patches para la nueva resolución patch_size = config.patch_size num_patches = (IMAGE_SIZE // patch_size) ** 2 # Actualizar configuración config.image_size = IMAGE_SIZE config.num_labels = self.num_classes print(f"Configuración del modelo:") print(f" - Resolución de imagen: {IMAGE_SIZE}x{IMAGE_SIZE}") print(f" - Tamaño de patch: {patch_size}x{patch_size}") print(f" - Número de patches: {num_patches}") print(f" - Número de clases: {self.num_classes}") # Cargar modelo preentrenado con nueva configuración model = ViTForImageClassification.from_pretrained( self.model_name, config=config, ignore_mismatched_sizes=True ) # Modificar la cabeza de clasificación para multi-clase model.classifier = nn.Linear(model.config.hidden_size, self.num_classes) return model.to(self.device) def _calculate_multilabel_accuracy(self, labels, preds): """Calcula la precisión para clasificación multi-etiqueta""" labels = np.array(labels) preds = np.array(preds) # Precisión exacta (todas las etiquetas deben coincidir) exact_match = np.all(labels == preds, axis=1).mean() return exact_match def _save_model(self, model, save_path): """Guarda el modelo entrenado""" os.makedirs(save_path, exist_ok=True) # Guardar modelo model.save_pretrained(save_path) # Guardar configuración del procesador personalizado processor_config = { 'target_size': IMAGE_SIZE, 'padding_color': PADDING_COLOR, 'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225] } with open(f'{save_path}/processor_config.json', 'w') as f: json.dump(processor_config, f, indent=2) # Guardar información de las clases class_info = { 'class_columns': self.class_columns, 'filename_column': self.filename_column, 'num_classes': self.num_classes, 'image_size': IMAGE_SIZE } with open(f'{save_path}/class_info.json', 'w') as f: json.dump(class_info, f, indent=2) print(f"Modelo guardado en: {save_path}") def _plot_training_metrics(self, train_losses, valid_losses, train_accs, valid_accs, save_path): """Plotea las métricas de entrenamiento""" fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5)) # Pérdidas epochs = range(1, len(train_losses) + 1) ax1.plot(epochs, train_losses, 'b-', label='Train Loss') if valid_losses: ax1.plot(epochs, valid_losses, 'r-', label='Valid Loss') ax1.set_title('Pérdida durante el entrenamiento') ax1.set_xlabel('Época') ax1.set_ylabel('Pérdida') ax1.legend() ax1.grid(True) # Precisión ax2.plot(epochs, train_accs, 'b-', label='Train Accuracy') if valid_accs: ax2.plot(epochs, valid_accs, 'r-', label='Valid Accuracy') ax2.set_title('Precisión durante el entrenamiento') ax2.set_xlabel('Época') ax2.set_ylabel('Precisión') ax2.legend() ax2.grid(True) plt.tight_layout() plt.savefig(f'{save_path}/training_metrics.png', dpi=300, bbox_inches='tight') plt.show() print(f"Gráficas guardadas en: {save_path}/training_metrics.png") def train(self, epochs: int = 30, batch_size: int = 16, learning_rate: float = 1e-4, save_path: str = 'vit_multiclass_model'): """ Entrena el modelo ViT Args: epochs: Número de épocas batch_size: Tamaño del lote learning_rate: Tasa de aprendizaje save_path: Ruta donde guardar el modelo entrenado """ # Crear datasets train_dataset, valid_dataset, test_dataset = self._create_datasets() if train_dataset is None: raise ValueError("No se pudo cargar el dataset de entrenamiento") # Crear data loaders train_loader = DataLoader( train_dataset, batch_size=batch_size, shuffle=True, num_workers=2 ) valid_loader = None if valid_dataset is not None: valid_loader = DataLoader( valid_dataset, batch_size=batch_size, shuffle=False, num_workers=2 ) # Crear modelo model = self._create_model() # Optimizador y función de pérdida optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=WEIGHT_DECAY) criterion = nn.BCEWithLogitsLoss() # Para clasificación multi-clase # Scheduler mejorado para datasets grandes total_steps = len(train_loader) * epochs warmup_steps = len(train_loader) * WARMUP_EPOCHS scheduler = optim.lr_scheduler.OneCycleLR( optimizer, max_lr=learning_rate, total_steps=total_steps, pct_start=warmup_steps/total_steps, anneal_strategy='cos' ) # Métricas de entrenamiento train_losses = [] valid_losses = [] train_accuracies = [] valid_accuracies = [] # Variables para guardar el mejor modelo best_valid_acc = 0.0 best_epoch = 0 patience_counter = 0 patience = 5 # Épocas sin mejora antes de early stopping print(f"\nIniciando entrenamiento por {epochs} épocas...") print(f"Clases: {self.class_columns}") print(f"🎯 Guardado automático del mejor modelo activado") print("=" * 60) for epoch in range(epochs): # Entrenamiento model.train() train_loss = 0.0 train_preds = [] train_labels = [] train_pbar = tqdm(train_loader, desc=f'Época {epoch+1}/{epochs} - Entrenamiento') for batch_idx, (images, labels) in enumerate(train_pbar): images, labels = images.to(self.device), labels.to(self.device) optimizer.zero_grad() outputs = model(pixel_values=images).logits loss = criterion(outputs, labels) loss.backward() optimizer.step() scheduler.step() # Actualizar cada batch para OneCycleLR train_loss += loss.item() # Calcular predicciones (umbral 0.5 para multi-clase) preds = torch.sigmoid(outputs) > 0.5 train_preds.extend(preds.cpu().numpy()) train_labels.extend(labels.cpu().numpy()) train_pbar.set_postfix({'Loss': f'{loss.item():.4f}'}) # Calcular métricas de entrenamiento avg_train_loss = train_loss / len(train_loader) train_acc = self._calculate_multilabel_accuracy(train_labels, train_preds) train_losses.append(avg_train_loss) train_accuracies.append(train_acc) # Validación if valid_loader is not None: model.eval() valid_loss = 0.0 valid_preds = [] valid_labels = [] with torch.no_grad(): valid_pbar = tqdm(valid_loader, desc=f'Época {epoch+1}/{epochs} - Validación') for images, labels in valid_pbar: images, labels = images.to(self.device), labels.to(self.device) outputs = model(pixel_values=images).logits loss = criterion(outputs, labels) valid_loss += loss.item() preds = torch.sigmoid(outputs) > 0.5 valid_preds.extend(preds.cpu().numpy()) valid_labels.extend(labels.cpu().numpy()) valid_pbar.set_postfix({'Loss': f'{loss.item():.4f}'}) avg_valid_loss = valid_loss / len(valid_loader) valid_acc = self._calculate_multilabel_accuracy(valid_labels, valid_preds) valid_losses.append(avg_valid_loss) valid_accuracies.append(valid_acc) print(f'Época {epoch+1}/{epochs}:') print(f' Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.4f}') print(f' Valid Loss: {avg_valid_loss:.4f}, Valid Acc: {valid_acc:.4f}') # Guardar mejor modelo automáticamente if valid_acc > best_valid_acc: best_valid_acc = valid_acc best_epoch = epoch + 1 patience_counter = 0 # Guardar mejor modelo best_model_path = f"{save_path}_best" self._save_model(model, best_model_path) print(f' 🎯 ¡Nuevo mejor modelo guardado! Accuracy: {valid_acc:.4f}') else: patience_counter += 1 print(f' 📊 Mejor accuracy sigue siendo: {best_valid_acc:.4f} (época {best_epoch})') if patience_counter >= patience: print(f' ⏹️ Early stopping: {patience} épocas sin mejora') break else: print(f'Época {epoch+1}/{epochs}:') print(f' Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.4f}') current_lr = scheduler.get_last_lr()[0] print(f' Learning Rate: {current_lr:.2e}') print('-' * 60) # Guardar modelo final final_model_path = f"{save_path}_final" self._save_model(model, final_model_path) # Resumen de guardado print(f"\n📁 Modelos guardados:") if valid_loader is not None: print(f" 🎯 Mejor modelo: {save_path}_best (época {best_epoch}, acc: {best_valid_acc:.4f})") print(f" 📋 Modelo final: {final_model_path} (última época)") # Guardar métricas metrics = { 'train_losses': train_losses, 'valid_losses': valid_losses, 'train_accuracies': train_accuracies, 'valid_accuracies': valid_accuracies, 'class_columns': self.class_columns, 'filename_column': self.filename_column, 'best_valid_acc': best_valid_acc, 'best_epoch': best_epoch } with open(f'{final_model_path}/training_metrics.json', 'w') as f: json.dump(metrics, f, indent=2) # Plotear métricas self._plot_training_metrics(train_losses, valid_losses, train_accuracies, valid_accuracies, final_model_path) print("\n¡Entrenamiento completado!") print(f"Modelo guardado con resolución {IMAGE_SIZE}x{IMAGE_SIZE}") print(f"Uso de memoria optimizado con batch size {batch_size}") return model # ============================================================================ # FUNCIÓN PRINCIPAL PARA JUPYTER # ============================================================================ def train_model(): """Función principal para entrenar el modelo en Jupyter""" print("=== Entrenamiento de ViT Multi-Clasificación ===") print(f"Ruta de datos: {DATA_PATH}") print(f"Épocas: {EPOCHS}") print(f"Batch size: {BATCH_SIZE}") print(f"Learning rate: {LEARNING_RATE}") print(f"Modelo: {MODEL_NAME}") print("=" * 50) # Crear entrenador trainer = ViTMultiClassTrainer( data_path=DATA_PATH, model_name=MODEL_NAME ) # Entrenar modelo model = trainer.train( epochs=EPOCHS, batch_size=BATCH_SIZE, learning_rate=LEARNING_RATE, save_path=SAVE_PATH ) return model # ============================================================================ # EJECUCIÓN DIRECTA PARA JUPYTER # ============================================================================ # Descomenta la siguiente línea para ejecutar directamente if __name__ == "__main__": model = train_model()