Drazcat-AI commited on
Commit
05bbddd
·
verified ·
1 Parent(s): bd025b7

Update train_categories.py

Browse files
Files changed (1) hide show
  1. train_categories.py +613 -613
train_categories.py CHANGED
@@ -1,614 +1,614 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.optim as optim
4
- from torch.utils.data import Dataset, DataLoader
5
- import pandas as pd
6
- import numpy as np
7
- from PIL import Image, ImageOps
8
- import torchvision.transforms as transforms
9
- import os
10
- from transformers import ViTForImageClassification, ViTConfig
11
- from sklearn.metrics import accuracy_score, classification_report
12
- import matplotlib.pyplot as plt
13
- import seaborn as sns
14
- from tqdm import tqdm
15
- from typing import List, Tuple, Dict, Optional
16
- import json
17
- import warnings
18
- warnings.filterwarnings('ignore')
19
-
20
- # ============================================================================
21
- # CONFIGURACIÓN PARA JUPYTER NOTEBOOK
22
- # ============================================================================
23
-
24
- # CONFIGURAR ESTOS PATHS SEGÚN TU ESTRUCTURA DE DATOS
25
- DATA_PATH = "datasets/peru_cencosud_categories-2" # Cambiar por tu path de datos
26
- SAVE_PATH = "vit_multiclass_model" # Donde guardar el modelo entrenado
27
- MODEL_NAME = "google/vit-base-patch16-224" # Modelo ViT preentrenado
28
-
29
- # CONFIGURACIÓN DE IMAGEN
30
- IMAGE_SIZE = 800 # Resolución objetivo
31
- PADDING_COLOR = (128, 128, 128) # Color de padding (gris medio)
32
-
33
- # HIPERPARÁMETROS OPTIMIZADOS PARA 26K IMÁGENES / 90 CLASES
34
- EPOCHS = 30 # Más épocas por la cantidad de datos y clases
35
- BATCH_SIZE = 8 # Aumentado para mejor estabilidad
36
- LEARNING_RATE = 1e-4 # Reducido para mejor convergencia
37
- WEIGHT_DECAY = 1e-4 # Regularización
38
- WARMUP_EPOCHS = 3 # Warmup para estabilidad inicial
39
-
40
- # ============================================================================
41
- # PROCESADOR DE IMÁGENES PERSONALIZADO
42
- # ============================================================================
43
-
44
- class PaddingImageProcessor:
45
- """Procesador de imágenes personalizado que mantiene aspect ratio con padding"""
46
-
47
- def __init__(self, target_size: int = 1280, padding_color: tuple = (128, 128, 128)):
48
- """
49
- Args:
50
- target_size: Tamaño objetivo (cuadrado)
51
- padding_color: Color del padding en RGB
52
- """
53
- self.target_size = target_size
54
- self.padding_color = padding_color
55
-
56
- # Transforms para normalización (valores estándar de ImageNet)
57
- self.normalize = transforms.Normalize(
58
- mean=[0.485, 0.456, 0.406],
59
- std=[0.229, 0.224, 0.225]
60
- )
61
-
62
- def pad_to_square(self, image: Image.Image) -> Image.Image:
63
- """Aplica padding para hacer la imagen cuadrada manteniendo aspect ratio"""
64
- width, height = image.size
65
-
66
- # Determinar el tamaño del cuadrado (el lado más largo)
67
- max_size = max(width, height)
68
-
69
- # Crear imagen cuadrada con color de padding
70
- padded_image = Image.new('RGB', (max_size, max_size), self.padding_color)
71
-
72
- # Calcular posición para centrar la imagen original
73
- left = (max_size - width) // 2
74
- top = (max_size - height) // 2
75
-
76
- # Pegar la imagen original en el centro
77
- padded_image.paste(image, (left, top))
78
-
79
- return padded_image
80
-
81
- def __call__(self, image: Image.Image) -> torch.Tensor:
82
- """
83
- Procesa una imagen aplicando padding + resize
84
-
85
- Args:
86
- image: Imagen PIL en formato RGB
87
-
88
- Returns:
89
- Tensor procesado listo para el modelo
90
- """
91
- # 1. Aplicar padding para hacer cuadrada
92
- padded_image = self.pad_to_square(image)
93
-
94
- # 2. Resize a la resolución objetivo manteniendo aspect ratio (ya es cuadrada)
95
- resized_image = padded_image.resize((self.target_size, self.target_size), Image.Resampling.LANCZOS)
96
-
97
- # 3. Convertir a tensor y normalizar
98
- # Convertir PIL a tensor [0, 1]
99
- transform_to_tensor = transforms.ToTensor()
100
- tensor_image = transform_to_tensor(resized_image)
101
-
102
- # 4. Normalizar con valores de ImageNet
103
- normalized_image = self.normalize(tensor_image)
104
-
105
- return normalized_image
106
-
107
- # ============================================================================
108
- # DATASET PERSONALIZADO
109
- # ============================================================================
110
-
111
- class MultiClassImageDataset(Dataset):
112
- """Dataset personalizado para clasificación multi-clase de imágenes"""
113
-
114
- def __init__(self, csv_path: str, images_dir: str, image_processor: PaddingImageProcessor,
115
- class_columns: List[str], filename_column: str):
116
- """
117
- Args:
118
- csv_path: Ruta al archivo CSV con las anotaciones
119
- images_dir: Directorio que contiene las imágenes
120
- image_processor: Procesador personalizado de imágenes
121
- class_columns: Lista de nombres de columnas que representan las clases
122
- filename_column: Nombre de la columna que contiene los nombres de archivos
123
- """
124
- self.df = pd.read_csv(csv_path)
125
- self.images_dir = images_dir
126
- self.image_processor = image_processor
127
- self.class_columns = class_columns
128
- self.filename_column = filename_column
129
-
130
- print(f"Dataset cargado desde {csv_path}: {len(self.df)} imágenes")
131
- print(f"Columnas de clases: {class_columns}")
132
-
133
- def __len__(self):
134
- return len(self.df)
135
-
136
- def __getitem__(self, idx):
137
- row = self.df.iloc[idx]
138
-
139
- # Cargar imagen usando la columna de filename detectada
140
- img_path = os.path.join(self.images_dir, row[self.filename_column])
141
- try:
142
- image = Image.open(img_path).convert('RGB')
143
- except Exception as e:
144
- print(f"Error cargando imagen {img_path}: {e}")
145
- # Crear imagen dummy si hay error
146
- image = Image.new('RGB', (224, 224), color='black')
147
-
148
- # Procesar imagen con padding + resize personalizado
149
- processed_image = self.image_processor(image)
150
-
151
- # Crear tensor de etiquetas multi-clase
152
- labels = torch.tensor([row[col] for col in self.class_columns], dtype=torch.float32)
153
-
154
- return processed_image, labels
155
-
156
- # ============================================================================
157
- # ENTRENADOR ViT
158
- # ============================================================================
159
-
160
- class ViTMultiClassTrainer:
161
- """Entrenador para ViT con clasificación multi-clase"""
162
-
163
- def __init__(self, data_path: str, model_name: str = "google/vit-base-patch16-224"):
164
- """
165
- Args:
166
- data_path: Ruta base donde están los directorios train/valid/test
167
- model_name: Nombre del modelo ViT preentrenado
168
- """
169
- self.data_path = data_path
170
- self.model_name = model_name
171
- self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
172
- print(f"Usando dispositivo: {self.device}")
173
-
174
- # Inicializar procesador personalizado
175
- self.image_processor = PaddingImageProcessor(
176
- target_size=IMAGE_SIZE,
177
- padding_color=PADDING_COLOR
178
- )
179
- print(f"Procesador de imágenes configurado: {IMAGE_SIZE}px con padding {PADDING_COLOR}")
180
-
181
- # Detectar estructura de datos automáticamente
182
- self._detect_data_structure()
183
-
184
- def _find_csv_in_folder(self, folder_path: str) -> Optional[str]:
185
- """Busca el archivo CSV en una carpeta específica"""
186
- if not os.path.exists(folder_path):
187
- return None
188
-
189
- csv_files = [f for f in os.listdir(folder_path) if f.endswith('.csv')]
190
-
191
- if len(csv_files) == 0:
192
- print(f"No se encontró CSV en {folder_path}")
193
- return None
194
- elif len(csv_files) == 1:
195
- csv_path = os.path.join(folder_path, csv_files[0])
196
- print(f"CSV encontrado: {csv_path}")
197
- return csv_path
198
- else:
199
- # Si hay múltiples CSVs, tomar el primero
200
- csv_path = os.path.join(folder_path, csv_files[0])
201
- print(f"Múltiples CSVs en {folder_path}, usando: {csv_files[0]}")
202
- return csv_path
203
-
204
- def _detect_filename_column(self, df: pd.DataFrame) -> str:
205
- """Detecta la columna que contiene los nombres de archivos"""
206
- possible_names = ['filename', 'image', 'image_name', 'file', 'name', 'img']
207
-
208
- for col in possible_names:
209
- if col in df.columns:
210
- return col
211
-
212
- # Si no encuentra ninguna, usar la primera columna
213
- print(f"No se encontró columna de filename conocida. Usando: {df.columns[0]}")
214
- return df.columns[0]
215
-
216
- def _detect_data_structure(self):
217
- """Detecta automáticamente la estructura de datos y clases"""
218
- print("Detectando estructura de datos...")
219
-
220
- # Buscar CSV en carpeta de entrenamiento
221
- train_folder = os.path.join(self.data_path, 'train')
222
- train_csv = self._find_csv_in_folder(train_folder)
223
-
224
- if train_csv is None:
225
- raise FileNotFoundError(f"No se encontró CSV en {train_folder}")
226
-
227
- # Cargar CSV para detectar columnas
228
- df = pd.read_csv(train_csv)
229
- print(f"Columnas encontradas: {list(df.columns)}")
230
-
231
- # Detectar columna de filename
232
- self.filename_column = self._detect_filename_column(df)
233
- print(f"Columna de archivos detectada: {self.filename_column}")
234
-
235
- # Las demás columnas son las clases
236
- self.class_columns = [col for col in df.columns if col != self.filename_column]
237
- self.num_classes = len(self.class_columns)
238
-
239
- if self.num_classes == 0:
240
- raise ValueError("No se encontraron columnas de clases")
241
-
242
- print(f"Clases detectadas ({self.num_classes}): {self.class_columns}")
243
-
244
- # Verificar otras carpetas
245
- for split in ['valid', 'test']:
246
- split_folder = os.path.join(self.data_path, split)
247
- if os.path.exists(split_folder):
248
- csv_path = self._find_csv_in_folder(split_folder)
249
- if csv_path:
250
- print(f"Carpeta {split}: CSV encontrado")
251
- else:
252
- print(f"Carpeta {split}: Sin CSV")
253
- else:
254
- print(f"Carpeta {split}: No existe")
255
-
256
- def _create_datasets(self) -> Tuple[Dataset, Optional[Dataset], Optional[Dataset]]:
257
- """Crea los datasets de entrenamiento, validación y prueba"""
258
- datasets = {}
259
-
260
- for split in ['train', 'valid', 'test']:
261
- split_folder = os.path.join(self.data_path, split)
262
- csv_path = self._find_csv_in_folder(split_folder)
263
-
264
- if csv_path is not None:
265
- datasets[split] = MultiClassImageDataset(
266
- csv_path=csv_path,
267
- images_dir=split_folder,
268
- image_processor=self.image_processor,
269
- class_columns=self.class_columns,
270
- filename_column=self.filename_column
271
- )
272
- else:
273
- datasets[split] = None
274
-
275
- return datasets.get('train'), datasets.get('valid'), datasets.get('test')
276
-
277
- def _create_model(self):
278
- """Crea el modelo ViT para clasificación multi-clase con resolución personalizada"""
279
- # Configurar el modelo para la nueva resolución
280
- config = ViTConfig.from_pretrained(self.model_name)
281
-
282
- # Calcular el número de patches para la nueva resolución
283
- patch_size = config.patch_size
284
- num_patches = (IMAGE_SIZE // patch_size) ** 2
285
-
286
- # Actualizar configuración
287
- config.image_size = IMAGE_SIZE
288
- config.num_labels = self.num_classes
289
-
290
- print(f"Configuración del modelo:")
291
- print(f" - Resolución de imagen: {IMAGE_SIZE}x{IMAGE_SIZE}")
292
- print(f" - Tamaño de patch: {patch_size}x{patch_size}")
293
- print(f" - Número de patches: {num_patches}")
294
- print(f" - Número de clases: {self.num_classes}")
295
-
296
- # Cargar modelo preentrenado con nueva configuración
297
- model = ViTForImageClassification.from_pretrained(
298
- self.model_name,
299
- config=config,
300
- ignore_mismatched_sizes=True
301
- )
302
-
303
- # Modificar la cabeza de clasificación para multi-clase
304
- model.classifier = nn.Linear(model.config.hidden_size, self.num_classes)
305
-
306
- return model.to(self.device)
307
-
308
- def _calculate_multilabel_accuracy(self, labels, preds):
309
- """Calcula la precisión para clasificación multi-etiqueta"""
310
- labels = np.array(labels)
311
- preds = np.array(preds)
312
-
313
- # Precisión exacta (todas las etiquetas deben coincidir)
314
- exact_match = np.all(labels == preds, axis=1).mean()
315
- return exact_match
316
-
317
- def _save_model(self, model, save_path):
318
- """Guarda el modelo entrenado"""
319
- os.makedirs(save_path, exist_ok=True)
320
-
321
- # Guardar modelo
322
- model.save_pretrained(save_path)
323
-
324
- # Guardar configuración del procesador personalizado
325
- processor_config = {
326
- 'target_size': IMAGE_SIZE,
327
- 'padding_color': PADDING_COLOR,
328
- 'mean': [0.485, 0.456, 0.406],
329
- 'std': [0.229, 0.224, 0.225]
330
- }
331
-
332
- with open(f'{save_path}/processor_config.json', 'w') as f:
333
- json.dump(processor_config, f, indent=2)
334
-
335
- # Guardar información de las clases
336
- class_info = {
337
- 'class_columns': self.class_columns,
338
- 'filename_column': self.filename_column,
339
- 'num_classes': self.num_classes,
340
- 'image_size': IMAGE_SIZE
341
- }
342
-
343
- with open(f'{save_path}/class_info.json', 'w') as f:
344
- json.dump(class_info, f, indent=2)
345
-
346
- print(f"Modelo guardado en: {save_path}")
347
-
348
- def _plot_training_metrics(self, train_losses, valid_losses, train_accs, valid_accs, save_path):
349
- """Plotea las métricas de entrenamiento"""
350
- fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
351
-
352
- # Pérdidas
353
- epochs = range(1, len(train_losses) + 1)
354
- ax1.plot(epochs, train_losses, 'b-', label='Train Loss')
355
- if valid_losses:
356
- ax1.plot(epochs, valid_losses, 'r-', label='Valid Loss')
357
- ax1.set_title('Pérdida durante el entrenamiento')
358
- ax1.set_xlabel('Época')
359
- ax1.set_ylabel('Pérdida')
360
- ax1.legend()
361
- ax1.grid(True)
362
-
363
- # Precisión
364
- ax2.plot(epochs, train_accs, 'b-', label='Train Accuracy')
365
- if valid_accs:
366
- ax2.plot(epochs, valid_accs, 'r-', label='Valid Accuracy')
367
- ax2.set_title('Precisión durante el entrenamiento')
368
- ax2.set_xlabel('Época')
369
- ax2.set_ylabel('Precisión')
370
- ax2.legend()
371
- ax2.grid(True)
372
-
373
- plt.tight_layout()
374
- plt.savefig(f'{save_path}/training_metrics.png', dpi=300, bbox_inches='tight')
375
- plt.show()
376
-
377
- print(f"Gráficas guardadas en: {save_path}/training_metrics.png")
378
-
379
- def train(self,
380
- epochs: int = 30,
381
- batch_size: int = 16,
382
- learning_rate: float = 1e-4,
383
- save_path: str = 'vit_multiclass_model'):
384
- """
385
- Entrena el modelo ViT
386
-
387
- Args:
388
- epochs: Número de épocas
389
- batch_size: Tamaño del lote
390
- learning_rate: Tasa de aprendizaje
391
- save_path: Ruta donde guardar el modelo entrenado
392
- """
393
-
394
- # Crear datasets
395
- train_dataset, valid_dataset, test_dataset = self._create_datasets()
396
-
397
- if train_dataset is None:
398
- raise ValueError("No se pudo cargar el dataset de entrenamiento")
399
-
400
- # Crear data loaders
401
- train_loader = DataLoader(
402
- train_dataset,
403
- batch_size=batch_size,
404
- shuffle=True,
405
- num_workers=2
406
- )
407
-
408
- valid_loader = None
409
- if valid_dataset is not None:
410
- valid_loader = DataLoader(
411
- valid_dataset,
412
- batch_size=batch_size,
413
- shuffle=False,
414
- num_workers=2
415
- )
416
-
417
- # Crear modelo
418
- model = self._create_model()
419
-
420
- # Optimizador y función de pérdida
421
- optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=WEIGHT_DECAY)
422
- criterion = nn.BCEWithLogitsLoss() # Para clasificación multi-clase
423
-
424
- # Scheduler mejorado para datasets grandes
425
- total_steps = len(train_loader) * epochs
426
- warmup_steps = len(train_loader) * WARMUP_EPOCHS
427
-
428
- scheduler = optim.lr_scheduler.OneCycleLR(
429
- optimizer,
430
- max_lr=learning_rate,
431
- total_steps=total_steps,
432
- pct_start=warmup_steps/total_steps,
433
- anneal_strategy='cos'
434
- )
435
-
436
- # Métricas de entrenamiento
437
- train_losses = []
438
- valid_losses = []
439
- train_accuracies = []
440
- valid_accuracies = []
441
-
442
- # Variables para guardar el mejor modelo
443
- best_valid_acc = 0.0
444
- best_epoch = 0
445
- patience_counter = 0
446
- patience = 5 # Épocas sin mejora antes de early stopping
447
-
448
- print(f"\nIniciando entrenamiento por {epochs} épocas...")
449
- print(f"Clases: {self.class_columns}")
450
- print(f"🎯 Guardado automático del mejor modelo activado")
451
- print("=" * 60)
452
-
453
- for epoch in range(epochs):
454
- # Entrenamiento
455
- model.train()
456
- train_loss = 0.0
457
- train_preds = []
458
- train_labels = []
459
-
460
- train_pbar = tqdm(train_loader, desc=f'Época {epoch+1}/{epochs} - Entrenamiento')
461
- for batch_idx, (images, labels) in enumerate(train_pbar):
462
- images, labels = images.to(self.device), labels.to(self.device)
463
-
464
- optimizer.zero_grad()
465
- outputs = model(pixel_values=images).logits
466
- loss = criterion(outputs, labels)
467
- loss.backward()
468
- optimizer.step()
469
- scheduler.step() # Actualizar cada batch para OneCycleLR
470
-
471
- train_loss += loss.item()
472
-
473
- # Calcular predicciones (umbral 0.5 para multi-clase)
474
- preds = torch.sigmoid(outputs) > 0.5
475
- train_preds.extend(preds.cpu().numpy())
476
- train_labels.extend(labels.cpu().numpy())
477
-
478
- train_pbar.set_postfix({'Loss': f'{loss.item():.4f}'})
479
-
480
- # Calcular métricas de entrenamiento
481
- avg_train_loss = train_loss / len(train_loader)
482
- train_acc = self._calculate_multilabel_accuracy(train_labels, train_preds)
483
-
484
- train_losses.append(avg_train_loss)
485
- train_accuracies.append(train_acc)
486
-
487
- # Validación
488
- if valid_loader is not None:
489
- model.eval()
490
- valid_loss = 0.0
491
- valid_preds = []
492
- valid_labels = []
493
-
494
- with torch.no_grad():
495
- valid_pbar = tqdm(valid_loader, desc=f'Época {epoch+1}/{epochs} - Validación')
496
- for images, labels in valid_pbar:
497
- images, labels = images.to(self.device), labels.to(self.device)
498
-
499
- outputs = model(pixel_values=images).logits
500
- loss = criterion(outputs, labels)
501
-
502
- valid_loss += loss.item()
503
-
504
- preds = torch.sigmoid(outputs) > 0.5
505
- valid_preds.extend(preds.cpu().numpy())
506
- valid_labels.extend(labels.cpu().numpy())
507
-
508
- valid_pbar.set_postfix({'Loss': f'{loss.item():.4f}'})
509
-
510
- avg_valid_loss = valid_loss / len(valid_loader)
511
- valid_acc = self._calculate_multilabel_accuracy(valid_labels, valid_preds)
512
-
513
- valid_losses.append(avg_valid_loss)
514
- valid_accuracies.append(valid_acc)
515
-
516
- print(f'Época {epoch+1}/{epochs}:')
517
- print(f' Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.4f}')
518
- print(f' Valid Loss: {avg_valid_loss:.4f}, Valid Acc: {valid_acc:.4f}')
519
-
520
- # Guardar mejor modelo automáticamente
521
- if valid_acc > best_valid_acc:
522
- best_valid_acc = valid_acc
523
- best_epoch = epoch + 1
524
- patience_counter = 0
525
-
526
- # Guardar mejor modelo
527
- best_model_path = f"{save_path}_best"
528
- self._save_model(model, best_model_path)
529
- print(f' 🎯 ¡Nuevo mejor modelo guardado! Accuracy: {valid_acc:.4f}')
530
- else:
531
- patience_counter += 1
532
- print(f' 📊 Mejor accuracy sigue siendo: {best_valid_acc:.4f} (época {best_epoch})')
533
- if patience_counter >= patience:
534
- print(f' ⏹️ Early stopping: {patience} épocas sin mejora')
535
- break
536
- else:
537
- print(f'Época {epoch+1}/{epochs}:')
538
- print(f' Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.4f}')
539
-
540
- current_lr = scheduler.get_last_lr()[0]
541
- print(f' Learning Rate: {current_lr:.2e}')
542
- print('-' * 60)
543
-
544
- # Guardar modelo final
545
- final_model_path = f"{save_path}_final"
546
- self._save_model(model, final_model_path)
547
-
548
- # Resumen de guardado
549
- print(f"\n📁 Modelos guardados:")
550
- if valid_loader is not None:
551
- print(f" 🎯 Mejor modelo: {save_path}_best (época {best_epoch}, acc: {best_valid_acc:.4f})")
552
- print(f" 📋 Modelo final: {final_model_path} (última época)")
553
-
554
- # Guardar métricas
555
- metrics = {
556
- 'train_losses': train_losses,
557
- 'valid_losses': valid_losses,
558
- 'train_accuracies': train_accuracies,
559
- 'valid_accuracies': valid_accuracies,
560
- 'class_columns': self.class_columns,
561
- 'filename_column': self.filename_column,
562
- 'best_valid_acc': best_valid_acc,
563
- 'best_epoch': best_epoch
564
- }
565
-
566
- with open(f'{final_model_path}/training_metrics.json', 'w') as f:
567
- json.dump(metrics, f, indent=2)
568
-
569
- # Plotear métricas
570
- self._plot_training_metrics(train_losses, valid_losses, train_accuracies, valid_accuracies, final_model_path)
571
-
572
- print("\n¡Entrenamiento completado!")
573
- print(f"Modelo guardado con resolución {IMAGE_SIZE}x{IMAGE_SIZE}")
574
- print(f"Uso de memoria optimizado con batch size {batch_size}")
575
- return model
576
-
577
- # ============================================================================
578
- # FUNCIÓN PRINCIPAL PARA JUPYTER
579
- # ============================================================================
580
-
581
- def train_model():
582
- """Función principal para entrenar el modelo en Jupyter"""
583
-
584
- print("=== Entrenamiento de ViT Multi-Clasificación ===")
585
- print(f"Ruta de datos: {DATA_PATH}")
586
- print(f"Épocas: {EPOCHS}")
587
- print(f"Batch size: {BATCH_SIZE}")
588
- print(f"Learning rate: {LEARNING_RATE}")
589
- print(f"Modelo: {MODEL_NAME}")
590
- print("=" * 50)
591
-
592
- # Crear entrenador
593
- trainer = ViTMultiClassTrainer(
594
- data_path=DATA_PATH,
595
- model_name=MODEL_NAME
596
- )
597
-
598
- # Entrenar modelo
599
- model = trainer.train(
600
- epochs=EPOCHS,
601
- batch_size=BATCH_SIZE,
602
- learning_rate=LEARNING_RATE,
603
- save_path=SAVE_PATH
604
- )
605
-
606
- return model
607
-
608
- # ============================================================================
609
- # EJECUCIÓN DIRECTA PARA JUPYTER
610
- # ============================================================================
611
-
612
- # Descomenta la siguiente línea para ejecutar directamente
613
- if __name__ == "__main__":
614
  model = train_model()
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from torch.utils.data import Dataset, DataLoader
5
+ import pandas as pd
6
+ import numpy as np
7
+ from PIL import Image, ImageOps
8
+ import torchvision.transforms as transforms
9
+ import os
10
+ from transformers import ViTForImageClassification, ViTConfig
11
+ from sklearn.metrics import accuracy_score, classification_report
12
+ import matplotlib.pyplot as plt
13
+ import seaborn as sns
14
+ from tqdm import tqdm
15
+ from typing import List, Tuple, Dict, Optional
16
+ import json
17
+ import warnings
18
+ warnings.filterwarnings('ignore')
19
+
20
+ # ============================================================================
21
+ # CONFIGURACIÓN PARA JUPYTER NOTEBOOK
22
+ # ============================================================================
23
+
24
+ # CONFIGURAR ESTOS PATHS SEGÚN TU ESTRUCTURA DE DATOS
25
+ DATA_PATH = "datasets/peru_cencosud_categories-2" # Cambiar por tu path de datos
26
+ SAVE_PATH = "vit_multiclass_model" # Donde guardar el modelo entrenado
27
+ MODEL_NAME = "google/vit-base-patch16-224" # Modelo ViT preentrenado
28
+
29
+ # CONFIGURACIÓN DE IMAGEN
30
+ IMAGE_SIZE = 800 # Resolución objetivo
31
+ PADDING_COLOR = (128, 128, 128) # Color de padding (gris medio)
32
+
33
+ # HIPERPARÁMETROS OPTIMIZADOS PARA 26K IMÁGENES / 90 CLASES
34
+ EPOCHS = 30 # Más épocas por la cantidad de datos y clases
35
+ BATCH_SIZE = 8 # Aumentado para mejor estabilidad
36
+ LEARNING_RATE = 1e-4 # Reducido para mejor convergencia
37
+ WEIGHT_DECAY = 1e-4 # Regularización
38
+ WARMUP_EPOCHS = 3 # Warmup para estabilidad inicial
39
+
40
+ # ============================================================================
41
+ # PROCESADOR DE IMÁGENES PERSONALIZADO
42
+ # ============================================================================
43
+
44
+ class PaddingImageProcessor:
45
+ """Procesador de imágenes personalizado que mantiene aspect ratio con padding"""
46
+
47
+ def __init__(self, target_size: int = 1280, padding_color: tuple = (128, 128, 128)):
48
+ """
49
+ Args:
50
+ target_size: Tamaño objetivo (cuadrado)
51
+ padding_color: Color del padding en RGB
52
+ """
53
+ self.target_size = target_size
54
+ self.padding_color = padding_color
55
+
56
+ # Transforms para normalización (valores estándar de ImageNet)
57
+ self.normalize = transforms.Normalize(
58
+ mean=[0.485, 0.456, 0.406],
59
+ std=[0.229, 0.224, 0.225]
60
+ )
61
+
62
+ def pad_to_square(self, image: Image.Image) -> Image.Image:
63
+ """Aplica padding para hacer la imagen cuadrada manteniendo aspect ratio"""
64
+ width, height = image.size
65
+
66
+ # Determinar el tamaño del cuadrado (el lado más largo)
67
+ max_size = max(width, height)
68
+
69
+ # Crear imagen cuadrada con color de padding
70
+ padded_image = Image.new('RGB', (max_size, max_size), self.padding_color)
71
+
72
+ # Calcular posición para centrar la imagen original
73
+ left = (max_size - width) // 2
74
+ top = (max_size - height) // 2
75
+
76
+ # Pegar la imagen original en el centro
77
+ padded_image.paste(image, (left, top))
78
+
79
+ return padded_image
80
+
81
+ def __call__(self, image: Image.Image) -> torch.Tensor:
82
+ """
83
+ Procesa una imagen aplicando padding + resize
84
+
85
+ Args:
86
+ image: Imagen PIL en formato RGB
87
+
88
+ Returns:
89
+ Tensor procesado listo para el modelo
90
+ """
91
+ # 1. Aplicar padding para hacer cuadrada
92
+ padded_image = self.pad_to_square(image)
93
+
94
+ # 2. Resize a la resolución objetivo manteniendo aspect ratio (ya es cuadrada)
95
+ resized_image = padded_image.resize((self.target_size, self.target_size), Image.Resampling.LANCZOS)
96
+
97
+ # 3. Convertir a tensor y normalizar
98
+ # Convertir PIL a tensor [0, 1]
99
+ transform_to_tensor = transforms.ToTensor()
100
+ tensor_image = transform_to_tensor(resized_image)
101
+
102
+ # 4. Normalizar con valores de ImageNet
103
+ normalized_image = self.normalize(tensor_image)
104
+
105
+ return normalized_image
106
+
107
+ # ============================================================================
108
+ # DATASET PERSONALIZADO
109
+ # ============================================================================
110
+
111
+ class MultiClassImageDataset(Dataset):
112
+ """Dataset personalizado para clasificación multi-clase de imágenes"""
113
+
114
+ def __init__(self, csv_path: str, images_dir: str, image_processor: PaddingImageProcessor,
115
+ class_columns: List[str], filename_column: str):
116
+ """
117
+ Args:
118
+ csv_path: Ruta al archivo CSV con las anotaciones
119
+ images_dir: Directorio que contiene las imágenes
120
+ image_processor: Procesador personalizado de imágenes
121
+ class_columns: Lista de nombres de columnas que representan las clases
122
+ filename_column: Nombre de la columna que contiene los nombres de archivos
123
+ """
124
+ self.df = pd.read_csv(csv_path)
125
+ self.images_dir = images_dir
126
+ self.image_processor = image_processor
127
+ self.class_columns = class_columns
128
+ self.filename_column = filename_column
129
+
130
+ print(f"Dataset cargado desde {csv_path}: {len(self.df)} imágenes")
131
+ print(f"Columnas de clases: {class_columns}")
132
+
133
+ def __len__(self):
134
+ return len(self.df)
135
+
136
+ def __getitem__(self, idx):
137
+ row = self.df.iloc[idx]
138
+
139
+ # Cargar imagen usando la columna de filename detectada
140
+ img_path = os.path.join(self.images_dir, row[self.filename_column])
141
+ try:
142
+ image = Image.open(img_path).convert('RGB')
143
+ except Exception as e:
144
+ print(f"Error cargando imagen {img_path}: {e}")
145
+ # Crear imagen dummy si hay error
146
+ image = Image.new('RGB', (224, 224), color='black')
147
+
148
+ # Procesar imagen con padding + resize personalizado
149
+ processed_image = self.image_processor(image)
150
+
151
+ # Crear tensor de etiquetas multi-clase
152
+ labels = torch.tensor([row[col] for col in self.class_columns], dtype=torch.float32)
153
+
154
+ return processed_image, labels
155
+
156
+ # ============================================================================
157
+ # ENTRENADOR ViT
158
+ # ============================================================================
159
+
160
+ class ViTMultiClassTrainer:
161
+ """Entrenador para ViT con clasificación multi-clase"""
162
+
163
+ def __init__(self, data_path: str, model_name: str = "google/vit-base-patch16-224"):
164
+ """
165
+ Args:
166
+ data_path: Ruta base donde están los directorios train/valid/test
167
+ model_name: Nombre del modelo ViT preentrenado
168
+ """
169
+ self.data_path = data_path
170
+ self.model_name = model_name
171
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
172
+ print(f"Usando dispositivo: {self.device}")
173
+
174
+ # Inicializar procesador personalizado
175
+ self.image_processor = PaddingImageProcessor(
176
+ target_size=IMAGE_SIZE,
177
+ padding_color=PADDING_COLOR
178
+ )
179
+ print(f"Procesador de imágenes configurado: {IMAGE_SIZE}px con padding {PADDING_COLOR}")
180
+
181
+ # Detectar estructura de datos automáticamente
182
+ self._detect_data_structure()
183
+
184
+ def _find_csv_in_folder(self, folder_path: str) -> Optional[str]:
185
+ """Busca el archivo CSV en una carpeta específica"""
186
+ if not os.path.exists(folder_path):
187
+ return None
188
+
189
+ csv_files = [f for f in os.listdir(folder_path) if f.endswith('.csv')]
190
+
191
+ if len(csv_files) == 0:
192
+ print(f"No se encontró CSV en {folder_path}")
193
+ return None
194
+ elif len(csv_files) == 1:
195
+ csv_path = os.path.join(folder_path, csv_files[0])
196
+ print(f"CSV encontrado: {csv_path}")
197
+ return csv_path
198
+ else:
199
+ # Si hay múltiples CSVs, tomar el primero
200
+ csv_path = os.path.join(folder_path, csv_files[0])
201
+ print(f"Múltiples CSVs en {folder_path}, usando: {csv_files[0]}")
202
+ return csv_path
203
+
204
+ def _detect_filename_column(self, df: pd.DataFrame) -> str:
205
+ """Detecta la columna que contiene los nombres de archivos"""
206
+ possible_names = ['filename', 'image', 'image_name', 'file', 'name', 'img']
207
+
208
+ for col in possible_names:
209
+ if col in df.columns:
210
+ return col
211
+
212
+ # Si no encuentra ninguna, usar la primera columna
213
+ print(f"No se encontró columna de filename conocida. Usando: {df.columns[0]}")
214
+ return df.columns[0]
215
+
216
+ def _detect_data_structure(self):
217
+ """Detecta automáticamente la estructura de datos y clases"""
218
+ print("Detectando estructura de datos...")
219
+
220
+ # Buscar CSV en carpeta de entrenamiento
221
+ train_folder = os.path.join(self.data_path, 'train')
222
+ train_csv = self._find_csv_in_folder(train_folder)
223
+
224
+ if train_csv is None:
225
+ raise FileNotFoundError(f"No se encontró CSV en {train_folder}")
226
+
227
+ # Cargar CSV para detectar columnas
228
+ df = pd.read_csv(train_csv)
229
+ print(f"Columnas encontradas: {list(df.columns)}")
230
+
231
+ # Detectar columna de filename
232
+ self.filename_column = self._detect_filename_column(df)
233
+ print(f"Columna de archivos detectada: {self.filename_column}")
234
+
235
+ # Las demás columnas son las clases
236
+ self.class_columns = [col for col in df.columns if col != self.filename_column]
237
+ self.num_classes = len(self.class_columns)
238
+
239
+ if self.num_classes == 0:
240
+ raise ValueError("No se encontraron columnas de clases")
241
+
242
+ print(f"Clases detectadas ({self.num_classes}): {self.class_columns}")
243
+
244
+ # Verificar otras carpetas
245
+ for split in ['valid', 'test']:
246
+ split_folder = os.path.join(self.data_path, split)
247
+ if os.path.exists(split_folder):
248
+ csv_path = self._find_csv_in_folder(split_folder)
249
+ if csv_path:
250
+ print(f"Carpeta {split}: CSV encontrado")
251
+ else:
252
+ print(f"Carpeta {split}: Sin CSV")
253
+ else:
254
+ print(f"Carpeta {split}: No existe")
255
+
256
+ def _create_datasets(self) -> Tuple[Dataset, Optional[Dataset], Optional[Dataset]]:
257
+ """Crea los datasets de entrenamiento, validación y prueba"""
258
+ datasets = {}
259
+
260
+ for split in ['train', 'valid', 'test']:
261
+ split_folder = os.path.join(self.data_path, split)
262
+ csv_path = self._find_csv_in_folder(split_folder)
263
+
264
+ if csv_path is not None:
265
+ datasets[split] = MultiClassImageDataset(
266
+ csv_path=csv_path,
267
+ images_dir=split_folder,
268
+ image_processor=self.image_processor,
269
+ class_columns=self.class_columns,
270
+ filename_column=self.filename_column
271
+ )
272
+ else:
273
+ datasets[split] = None
274
+
275
+ return datasets.get('train'), datasets.get('valid'), datasets.get('test')
276
+
277
+ def _create_model(self):
278
+ """Crea el modelo ViT para clasificación multi-clase con resolución personalizada"""
279
+ # Configurar el modelo para la nueva resolución
280
+ config = ViTConfig.from_pretrained(self.model_name)
281
+
282
+ # Calcular el número de patches para la nueva resolución
283
+ patch_size = config.patch_size
284
+ num_patches = (IMAGE_SIZE // patch_size) ** 2
285
+
286
+ # Actualizar configuración
287
+ config.image_size = IMAGE_SIZE
288
+ config.num_labels = self.num_classes
289
+
290
+ print(f"Configuración del modelo:")
291
+ print(f" - Resolución de imagen: {IMAGE_SIZE}x{IMAGE_SIZE}")
292
+ print(f" - Tamaño de patch: {patch_size}x{patch_size}")
293
+ print(f" - Número de patches: {num_patches}")
294
+ print(f" - Número de clases: {self.num_classes}")
295
+
296
+ # Cargar modelo preentrenado con nueva configuración
297
+ model = ViTForImageClassification.from_pretrained(
298
+ self.model_name,
299
+ config=config,
300
+ ignore_mismatched_sizes=True
301
+ )
302
+
303
+ # Modificar la cabeza de clasificación para multi-clase
304
+ model.classifier = nn.Linear(model.config.hidden_size, self.num_classes)
305
+
306
+ return model.to(self.device)
307
+
308
+ def _calculate_multilabel_accuracy(self, labels, preds):
309
+ """Calcula la precisión para clasificación multi-etiqueta"""
310
+ labels = np.array(labels)
311
+ preds = np.array(preds)
312
+
313
+ # Precisión exacta (todas las etiquetas deben coincidir)
314
+ exact_match = np.all(labels == preds, axis=1).mean()
315
+ return exact_match
316
+
317
+ def _save_model(self, model, save_path):
318
+ """Guarda el modelo entrenado"""
319
+ os.makedirs(save_path, exist_ok=True)
320
+
321
+ # Guardar modelo
322
+ model.save_pretrained(save_path)
323
+
324
+ # Guardar configuración del procesador personalizado
325
+ processor_config = {
326
+ 'target_size': IMAGE_SIZE,
327
+ 'padding_color': PADDING_COLOR,
328
+ 'mean': [0.485, 0.456, 0.406],
329
+ 'std': [0.229, 0.224, 0.225]
330
+ }
331
+
332
+ with open(f'{save_path}/processor_config.json', 'w') as f:
333
+ json.dump(processor_config, f, indent=2)
334
+
335
+ # Guardar información de las clases
336
+ class_info = {
337
+ 'class_columns': self.class_columns,
338
+ 'filename_column': self.filename_column,
339
+ 'num_classes': self.num_classes,
340
+ 'image_size': IMAGE_SIZE
341
+ }
342
+
343
+ with open(f'{save_path}/class_info.json', 'w') as f:
344
+ json.dump(class_info, f, indent=2)
345
+
346
+ print(f"Modelo guardado en: {save_path}")
347
+
348
+ def _plot_training_metrics(self, train_losses, valid_losses, train_accs, valid_accs, save_path):
349
+ """Plotea las métricas de entrenamiento"""
350
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
351
+
352
+ # Pérdidas
353
+ epochs = range(1, len(train_losses) + 1)
354
+ ax1.plot(epochs, train_losses, 'b-', label='Train Loss')
355
+ if valid_losses:
356
+ ax1.plot(epochs, valid_losses, 'r-', label='Valid Loss')
357
+ ax1.set_title('Pérdida durante el entrenamiento')
358
+ ax1.set_xlabel('Época')
359
+ ax1.set_ylabel('Pérdida')
360
+ ax1.legend()
361
+ ax1.grid(True)
362
+
363
+ # Precisión
364
+ ax2.plot(epochs, train_accs, 'b-', label='Train Accuracy')
365
+ if valid_accs:
366
+ ax2.plot(epochs, valid_accs, 'r-', label='Valid Accuracy')
367
+ ax2.set_title('Precisión durante el entrenamiento')
368
+ ax2.set_xlabel('Época')
369
+ ax2.set_ylabel('Precisión')
370
+ ax2.legend()
371
+ ax2.grid(True)
372
+
373
+ plt.tight_layout()
374
+ plt.savefig(f'{save_path}/training_metrics.png', dpi=300, bbox_inches='tight')
375
+ plt.show()
376
+
377
+ print(f"Gráficas guardadas en: {save_path}/training_metrics.png")
378
+
379
+ def train(self,
380
+ epochs: int = 30,
381
+ batch_size: int = 16,
382
+ learning_rate: float = 1e-4,
383
+ save_path: str = 'vit_multiclass_model'):
384
+ """
385
+ Entrena el modelo ViT
386
+
387
+ Args:
388
+ epochs: Número de épocas
389
+ batch_size: Tamaño del lote
390
+ learning_rate: Tasa de aprendizaje
391
+ save_path: Ruta donde guardar el modelo entrenado
392
+ """
393
+
394
+ # Crear datasets
395
+ train_dataset, valid_dataset, test_dataset = self._create_datasets()
396
+
397
+ if train_dataset is None:
398
+ raise ValueError("No se pudo cargar el dataset de entrenamiento")
399
+
400
+ # Crear data loaders
401
+ train_loader = DataLoader(
402
+ train_dataset,
403
+ batch_size=batch_size,
404
+ shuffle=True,
405
+ num_workers=2
406
+ )
407
+
408
+ valid_loader = None
409
+ if valid_dataset is not None:
410
+ valid_loader = DataLoader(
411
+ valid_dataset,
412
+ batch_size=batch_size,
413
+ shuffle=False,
414
+ num_workers=2
415
+ )
416
+
417
+ # Crear modelo
418
+ model = self._create_model()
419
+
420
+ # Optimizador y función de pérdida
421
+ optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=WEIGHT_DECAY)
422
+ criterion = nn.BCEWithLogitsLoss() # Para clasificación multi-clase
423
+
424
+ # Scheduler mejorado para datasets grandes
425
+ total_steps = len(train_loader) * epochs
426
+ warmup_steps = len(train_loader) * WARMUP_EPOCHS
427
+
428
+ scheduler = optim.lr_scheduler.OneCycleLR(
429
+ optimizer,
430
+ max_lr=learning_rate,
431
+ total_steps=total_steps,
432
+ pct_start=warmup_steps/total_steps,
433
+ anneal_strategy='cos'
434
+ )
435
+
436
+ # Métricas de entrenamiento
437
+ train_losses = []
438
+ valid_losses = []
439
+ train_accuracies = []
440
+ valid_accuracies = []
441
+
442
+ # Variables para guardar el mejor modelo
443
+ best_valid_acc = 0.0
444
+ best_epoch = 0
445
+ patience_counter = 0
446
+ patience = 5 # Épocas sin mejora antes de early stopping
447
+
448
+ print(f"\nIniciando entrenamiento por {epochs} épocas...")
449
+ print(f"Clases: {self.class_columns}")
450
+ print(f"🎯 Guardado automático del mejor modelo activado")
451
+ print("=" * 60)
452
+
453
+ for epoch in range(epochs):
454
+ # Entrenamiento
455
+ model.train()
456
+ train_loss = 0.0
457
+ train_preds = []
458
+ train_labels = []
459
+
460
+ train_pbar = tqdm(train_loader, desc=f'Época {epoch+1}/{epochs} - Entrenamiento')
461
+ for batch_idx, (images, labels) in enumerate(train_pbar):
462
+ images, labels = images.to(self.device), labels.to(self.device)
463
+
464
+ optimizer.zero_grad()
465
+ outputs = model(pixel_values=images).logits
466
+ loss = criterion(outputs, labels)
467
+ loss.backward()
468
+ optimizer.step()
469
+ scheduler.step() # Actualizar cada batch para OneCycleLR
470
+
471
+ train_loss += loss.item()
472
+
473
+ # Calcular predicciones (umbral 0.5 para multi-clase)
474
+ preds = torch.sigmoid(outputs) > 0.5
475
+ train_preds.extend(preds.cpu().numpy())
476
+ train_labels.extend(labels.cpu().numpy())
477
+
478
+ train_pbar.set_postfix({'Loss': f'{loss.item():.4f}'})
479
+
480
+ # Calcular métricas de entrenamiento
481
+ avg_train_loss = train_loss / len(train_loader)
482
+ train_acc = self._calculate_multilabel_accuracy(train_labels, train_preds)
483
+
484
+ train_losses.append(avg_train_loss)
485
+ train_accuracies.append(train_acc)
486
+
487
+ # Validación
488
+ if valid_loader is not None:
489
+ model.eval()
490
+ valid_loss = 0.0
491
+ valid_preds = []
492
+ valid_labels = []
493
+
494
+ with torch.no_grad():
495
+ valid_pbar = tqdm(valid_loader, desc=f'Época {epoch+1}/{epochs} - Validación')
496
+ for images, labels in valid_pbar:
497
+ images, labels = images.to(self.device), labels.to(self.device)
498
+
499
+ outputs = model(pixel_values=images).logits
500
+ loss = criterion(outputs, labels)
501
+
502
+ valid_loss += loss.item()
503
+
504
+ preds = torch.sigmoid(outputs) > 0.5
505
+ valid_preds.extend(preds.cpu().numpy())
506
+ valid_labels.extend(labels.cpu().numpy())
507
+
508
+ valid_pbar.set_postfix({'Loss': f'{loss.item():.4f}'})
509
+
510
+ avg_valid_loss = valid_loss / len(valid_loader)
511
+ valid_acc = self._calculate_multilabel_accuracy(valid_labels, valid_preds)
512
+
513
+ valid_losses.append(avg_valid_loss)
514
+ valid_accuracies.append(valid_acc)
515
+
516
+ print(f'Época {epoch+1}/{epochs}:')
517
+ print(f' Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.4f}')
518
+ print(f' Valid Loss: {avg_valid_loss:.4f}, Valid Acc: {valid_acc:.4f}')
519
+
520
+ # Guardar mejor modelo automáticamente
521
+ if valid_acc > best_valid_acc:
522
+ best_valid_acc = valid_acc
523
+ best_epoch = epoch + 1
524
+ patience_counter = 0
525
+
526
+ # Guardar mejor modelo
527
+ best_model_path = f"{save_path}_best"
528
+ self._save_model(model, best_model_path)
529
+ print(f' 🎯 ¡Nuevo mejor modelo guardado! Accuracy: {valid_acc:.4f}')
530
+ else:
531
+ patience_counter += 1
532
+ print(f' 📊 Mejor accuracy sigue siendo: {best_valid_acc:.4f} (época {best_epoch})')
533
+ if patience_counter >= patience:
534
+ print(f' ⏹️ Early stopping: {patience} épocas sin mejora')
535
+ break
536
+ else:
537
+ print(f'Época {epoch+1}/{epochs}:')
538
+ print(f' Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.4f}')
539
+
540
+ current_lr = scheduler.get_last_lr()[0]
541
+ print(f' Learning Rate: {current_lr:.2e}')
542
+ print('-' * 60)
543
+
544
+ # Guardar modelo final
545
+ final_model_path = f"{save_path}_final"
546
+ self._save_model(model, final_model_path)
547
+
548
+ # Resumen de guardado
549
+ print(f"\n📁 Modelos guardados:")
550
+ if valid_loader is not None:
551
+ print(f" 🎯 Mejor modelo: {save_path}_best (época {best_epoch}, acc: {best_valid_acc:.4f})")
552
+ print(f" 📋 Modelo final: {final_model_path} (última época)")
553
+
554
+ # Guardar métricas
555
+ metrics = {
556
+ 'train_losses': train_losses,
557
+ 'valid_losses': valid_losses,
558
+ 'train_accuracies': train_accuracies,
559
+ 'valid_accuracies': valid_accuracies,
560
+ 'class_columns': self.class_columns,
561
+ 'filename_column': self.filename_column,
562
+ 'best_valid_acc': best_valid_acc,
563
+ 'best_epoch': best_epoch
564
+ }
565
+
566
+ with open(f'{final_model_path}/training_metrics.json', 'w') as f:
567
+ json.dump(metrics, f, indent=2)
568
+
569
+ # Plotear métricas
570
+ self._plot_training_metrics(train_losses, valid_losses, train_accuracies, valid_accuracies, final_model_path)
571
+
572
+ print("\n¡Entrenamiento completado!")
573
+ print(f"Modelo guardado con resolución {IMAGE_SIZE}x{IMAGE_SIZE}")
574
+ print(f"Uso de memoria optimizado con batch size {batch_size}")
575
+ return model
576
+
577
+ # ============================================================================
578
+ # FUNCIÓN PRINCIPAL PARA JUPYTER
579
+ # ============================================================================
580
+
581
+ def train_model():
582
+ """Función principal para entrenar el modelo en Jupyter"""
583
+
584
+ print("=== Entrenamiento de ViT Multi-Clasificación ===")
585
+ print(f"Ruta de datos: {DATA_PATH}")
586
+ print(f"Épocas: {EPOCHS}")
587
+ print(f"Batch size: {BATCH_SIZE}")
588
+ print(f"Learning rate: {LEARNING_RATE}")
589
+ print(f"Modelo: {MODEL_NAME}")
590
+ print("=" * 50)
591
+
592
+ # Crear entrenador
593
+ trainer = ViTMultiClassTrainer(
594
+ data_path=DATA_PATH,
595
+ model_name=MODEL_NAME
596
+ )
597
+
598
+ # Entrenar modelo
599
+ model = trainer.train(
600
+ epochs=EPOCHS,
601
+ batch_size=BATCH_SIZE,
602
+ learning_rate=LEARNING_RATE,
603
+ save_path=SAVE_PATH
604
+ )
605
+
606
+ return model
607
+
608
+ # ============================================================================
609
+ # EJECUCIÓN DIRECTA PARA JUPYTER
610
+ # ============================================================================
611
+
612
+ # Descomenta la siguiente línea para ejecutar directamente
613
+ if __name__ == "__main__":
614
  model = train_model()