|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.optim as optim |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
from PIL import Image, ImageOps |
|
|
import torchvision.transforms as transforms |
|
|
import os |
|
|
from transformers import ViTForImageClassification, ViTConfig |
|
|
from sklearn.metrics import accuracy_score, classification_report |
|
|
import matplotlib.pyplot as plt |
|
|
import seaborn as sns |
|
|
from tqdm import tqdm |
|
|
from typing import List, Tuple, Dict, Optional |
|
|
import json |
|
|
import warnings |
|
|
# Suppress every Python warning for the whole process.
warnings.filterwarnings('ignore')


# --- Paths and pretrained checkpoint -------------------------------------
# Dataset root; the trainer looks for 'train', 'valid' and 'test'
# sub-folders inside it.
DATA_PATH = "datasets/peru_cencosud_categories-2"
# Base name for saved models; train() appends "_best" and "_final" to it.
SAVE_PATH = "vit_multiclass_model"
# Checkpoint identifier handed to ViTConfig / ViTForImageClassification.
MODEL_NAME = "google/vit-base-patch16-224"


# --- Image preprocessing --------------------------------------------------
# Side length (pixels) of the square model input; also written into the
# model config as image_size.
IMAGE_SIZE = 800
# RGB fill colour used when padding an image to a square.
PADDING_COLOR = (128, 128, 128)


# --- Training hyper-parameters --------------------------------------------
EPOCHS = 30
BATCH_SIZE = 8
LEARNING_RATE = 1e-4
# Weight decay passed to the AdamW optimizer.
WEIGHT_DECAY = 1e-4
# Number of epochs used to derive the warm-up fraction of the LR schedule.
WARMUP_EPOCHS = 3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PaddingImageProcessor:
    """Image processor that pads to a square (keeping aspect ratio), resizes and normalizes."""

    def __init__(self, target_size: int = 1280, padding_color: tuple = (128, 128, 128)):
        """
        Args:
            target_size: Side length of the square output, in pixels.
            padding_color: RGB colour used to fill the padded area.
        """
        self.padding_color = padding_color
        self.target_size = target_size

        # Per-channel normalization applied after the tensor conversion
        # (presumably the usual ImageNet statistics — confirm against the
        # checkpoint's own preprocessing).
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )

    def pad_to_square(self, image: Image.Image) -> Image.Image:
        """Return a square RGB copy of ``image``, centred on a padding-coloured canvas."""
        src_w, src_h = image.size

        # The canvas side is the longer of the two source dimensions.
        side = src_w if src_w >= src_h else src_h
        canvas = Image.new('RGB', (side, side), self.padding_color)

        # Integer offsets that centre the source on the canvas.
        offset = ((side - src_w) // 2, (side - src_h) // 2)
        canvas.paste(image, offset)
        return canvas

    def __call__(self, image: Image.Image) -> torch.Tensor:
        """
        Pad, resize and normalize one image.

        Args:
            image: PIL image (documented by the original as RGB).

        Returns:
            Normalized tensor produced by ``ToTensor`` followed by
            ``self.normalize``.
        """
        edge = self.target_size
        squared = self.pad_to_square(image).resize(
            (edge, edge), Image.Resampling.LANCZOS
        )
        as_tensor = transforms.ToTensor()(squared)
        return self.normalize(as_tensor)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MultiClassImageDataset(Dataset):
    """Dataset yielding ``(processed_image, labels)`` pairs described by a CSV file.

    Each CSV row holds an image file name plus one column per class; the
    class columns are read as a float32 label vector.
    """

    def __init__(self, csv_path: str, images_dir: str, image_processor: "PaddingImageProcessor",
                 class_columns: List[str], filename_column: str):
        """
        Args:
            csv_path: Path to the CSV file with the annotations.
            images_dir: Directory containing the image files.
            image_processor: Callable turning a PIL image into a tensor.
            class_columns: Names of the CSV columns holding the class labels.
            filename_column: Name of the CSV column holding the file names.

        Raises:
            KeyError: If the CSV lacks the filename column or a class column.
        """
        self.df = pd.read_csv(csv_path)
        self.images_dir = images_dir
        self.image_processor = image_processor
        self.class_columns = class_columns
        self.filename_column = filename_column

        # Fail fast on a schema mismatch instead of crashing later inside a
        # DataLoader worker on the first __getitem__ call.
        missing = [
            col for col in [filename_column, *class_columns]
            if col not in self.df.columns
        ]
        if missing:
            raise KeyError(f"Columnas faltantes en {csv_path}: {missing}")

        # Materialise file names and labels once; building a pandas Series
        # per item via iloc is needlessly slow.
        self._filenames = self.df[filename_column].tolist()
        self.labels = torch.tensor(
            self.df[list(class_columns)].to_numpy(dtype=np.float32),
            dtype=torch.float32,
        )

        print(f"Dataset cargado desde {csv_path}: {len(self.df)} imágenes")
        print(f"Columnas de clases: {class_columns}")

    def __len__(self):
        """Return the number of rows (images) in the CSV."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(processed_image, labels)`` for row ``idx``.

        If the image cannot be loaded, the error is printed and a black
        224x224 placeholder is processed instead, so one bad file does not
        abort the run.
        """
        # str() so numeric file names read from the CSV still form a path.
        img_path = os.path.join(self.images_dir, str(self._filenames[idx]))

        try:
            # The context manager closes the underlying file promptly;
            # convert() returns an independent, fully loaded image.
            with Image.open(img_path) as handle:
                image = handle.convert('RGB')
        except Exception as e:  # deliberate best-effort: keep training going
            print(f"Error cargando imagen {img_path}: {e}")
            image = Image.new('RGB', (224, 224), color='black')

        processed_image = self.image_processor(image)

        # clone() so callers cannot mutate the cached label matrix.
        labels = self.labels[idx].clone()

        return processed_image, labels
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ViTMultiClassTrainer:
    """Fine-tunes a ViT image classifier on a multi-label dataset.

    The dataset layout (CSV location, filename column, class columns) is
    detected from ``<data_path>/train`` at construction time; ``train()``
    then runs the optimisation loop, saves the best and final models and
    writes the training metrics.
    """

    def __init__(self, data_path: str, model_name: str = "google/vit-base-patch16-224"):
        """
        Args:
            data_path: Base directory holding the train/valid/test folders.
            model_name: Name of the pretrained ViT checkpoint.

        Raises:
            FileNotFoundError: If no CSV is found in the train folder.
            ValueError: If the train CSV has no class columns.
        """
        self.data_path = data_path
        self.model_name = model_name
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Usando dispositivo: {self.device}")

        self.image_processor = PaddingImageProcessor(
            target_size=IMAGE_SIZE,
            padding_color=PADDING_COLOR
        )
        print(f"Procesador de imágenes configurado: {IMAGE_SIZE}px con padding {PADDING_COLOR}")

        self._detect_data_structure()

    def _find_csv_in_folder(self, folder_path: str) -> Optional[str]:
        """Return the path of the CSV file inside ``folder_path``, or None.

        When several CSV files exist, the alphabetically first one is used
        so the choice is reproducible (``os.listdir`` order is arbitrary).
        """
        if not os.path.isdir(folder_path):
            return None

        csv_files = sorted(f for f in os.listdir(folder_path) if f.endswith('.csv'))

        if not csv_files:
            print(f"No se encontró CSV en {folder_path}")
            return None

        csv_path = os.path.join(folder_path, csv_files[0])
        if len(csv_files) == 1:
            print(f"CSV encontrado: {csv_path}")
        else:
            print(f"Múltiples CSVs en {folder_path}, usando: {csv_files[0]}")
        return csv_path

    def _detect_filename_column(self, df: pd.DataFrame) -> str:
        """Return the name of the column that holds the image file names.

        Known names are tried in order; if none matches, the first column
        of ``df`` is used.

        Raises:
            ValueError: If ``df`` has no columns at all.
        """
        possible_names = ['filename', 'image', 'image_name', 'file', 'name', 'img']

        for col in possible_names:
            if col in df.columns:
                return col

        if len(df.columns) == 0:
            raise ValueError("El CSV no contiene columnas")

        print(f"No se encontró columna de filename conocida. Usando: {df.columns[0]}")
        return df.columns[0]

    def _detect_data_structure(self):
        """Detect the filename column and class columns from the train CSV.

        Sets ``self.filename_column``, ``self.class_columns`` and
        ``self.num_classes`` and reports whether the valid/test splits have
        a CSV.

        Raises:
            FileNotFoundError: If the train folder has no CSV.
            ValueError: If no class columns remain after removing the
                filename column.
        """
        print("Detectando estructura de datos...")

        train_folder = os.path.join(self.data_path, 'train')
        train_csv = self._find_csv_in_folder(train_folder)

        if train_csv is None:
            raise FileNotFoundError(f"No se encontró CSV en {train_folder}")

        df = pd.read_csv(train_csv)
        print(f"Columnas encontradas: {list(df.columns)}")

        self.filename_column = self._detect_filename_column(df)
        print(f"Columna de archivos detectada: {self.filename_column}")

        # Every column other than the filename column is treated as a class.
        self.class_columns = [col for col in df.columns if col != self.filename_column]
        self.num_classes = len(self.class_columns)

        if self.num_classes == 0:
            raise ValueError("No se encontraron columnas de clases")

        print(f"Clases detectadas ({self.num_classes}): {self.class_columns}")

        # Informational check of the optional splits.
        for split in ['valid', 'test']:
            split_folder = os.path.join(self.data_path, split)
            if os.path.exists(split_folder):
                csv_path = self._find_csv_in_folder(split_folder)
                if csv_path:
                    print(f"Carpeta {split}: CSV encontrado")
                else:
                    print(f"Carpeta {split}: Sin CSV")
            else:
                print(f"Carpeta {split}: No existe")

    def _create_datasets(self) -> Tuple[Optional[Dataset], Optional[Dataset], Optional[Dataset]]:
        """Build the train, valid and test datasets.

        Returns:
            ``(train, valid, test)``; an entry is None when its folder has
            no CSV file.
        """
        datasets = {}

        for split in ['train', 'valid', 'test']:
            split_folder = os.path.join(self.data_path, split)
            csv_path = self._find_csv_in_folder(split_folder)

            if csv_path is None:
                datasets[split] = None
                continue

            datasets[split] = MultiClassImageDataset(
                csv_path=csv_path,
                images_dir=split_folder,
                image_processor=self.image_processor,
                class_columns=self.class_columns,
                filename_column=self.filename_column
            )

        return datasets['train'], datasets['valid'], datasets['test']

    def _create_model(self):
        """Create the ViT classifier configured for IMAGE_SIZE and the detected classes.

        Returns:
            The model, moved to ``self.device``.
        """
        config = ViTConfig.from_pretrained(self.model_name)

        patch_size = config.patch_size
        num_patches = (IMAGE_SIZE // patch_size) ** 2

        config.image_size = IMAGE_SIZE
        config.num_labels = self.num_classes

        print(f"Configuración del modelo:")
        print(f" - Resolución de imagen: {IMAGE_SIZE}x{IMAGE_SIZE}")
        print(f" - Tamaño de patch: {patch_size}x{patch_size}")
        print(f" - Número de patches: {num_patches}")
        print(f" - Número de clases: {self.num_classes}")

        # NOTE(review): when IMAGE_SIZE differs from the checkpoint's native
        # resolution, the position-embedding table presumably has a different
        # shape than the pretrained one. ignore_mismatched_sizes=True lets the
        # load succeed but leaves mismatched tensors newly initialised rather
        # than pretrained — confirm this is intended, or consider
        # interpolating the pretrained position embeddings instead.
        model = ViTForImageClassification.from_pretrained(
            self.model_name,
            config=config,
            ignore_mismatched_sizes=True
        )

        # Fresh classification head with one logit per class.
        model.classifier = nn.Linear(model.config.hidden_size, self.num_classes)

        return model.to(self.device)

    def _calculate_multilabel_accuracy(self, labels, preds):
        """Return the exact-match accuracy for multi-label predictions.

        A sample counts as correct only if every label matches.

        Args:
            labels: Sequence of per-sample label vectors.
            preds: Sequence of per-sample prediction vectors.

        Returns:
            Fraction of samples whose full label vector matches, as a
            Python float; 0.0 when there are no samples.
        """
        labels = np.asarray(labels)
        preds = np.asarray(preds)

        # An empty input has no second axis; avoid the axis error.
        if len(labels) == 0:
            return 0.0

        return float(np.all(labels == preds, axis=1).mean())

    @staticmethod
    def _warmup_fraction(epochs: int, warmup_epochs: int) -> float:
        """Return the OneCycleLR ``pct_start`` for the given warm-up length.

        Equals ``warmup_epochs / epochs`` while that leaves room for
        annealing; otherwise (warm-up as long as or longer than the run,
        which OneCycleLR rejects when above 1) falls back to 0.3.
        """
        fraction = warmup_epochs / epochs
        if 0.0 <= fraction < 1.0:
            return fraction
        return 0.3

    def _save_model(self, model, save_path):
        """Save the model plus processor and class metadata into ``save_path``.

        Writes the model via ``save_pretrained`` and two JSON files:
        ``processor_config.json`` and ``class_info.json``.
        """
        os.makedirs(save_path, exist_ok=True)

        model.save_pretrained(save_path)

        # Preprocessing parameters needed to reproduce the inputs later.
        processor_config = {
            'target_size': IMAGE_SIZE,
            'padding_color': PADDING_COLOR,
            'mean': [0.485, 0.456, 0.406],
            'std': [0.229, 0.224, 0.225]
        }
        with open(os.path.join(save_path, 'processor_config.json'), 'w', encoding='utf-8') as f:
            json.dump(processor_config, f, indent=2)

        class_info = {
            'class_columns': self.class_columns,
            'filename_column': self.filename_column,
            'num_classes': self.num_classes,
            'image_size': IMAGE_SIZE
        }
        with open(os.path.join(save_path, 'class_info.json'), 'w', encoding='utf-8') as f:
            json.dump(class_info, f, indent=2)

        print(f"Modelo guardado en: {save_path}")

    def _plot_training_metrics(self, train_losses, valid_losses, train_accs, valid_accs, save_path):
        """Plot loss and accuracy curves and save them as ``training_metrics.png``.

        Validation curves are drawn only when their lists are non-empty.
        """
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

        epochs = range(1, len(train_losses) + 1)

        # Loss curves.
        ax1.plot(epochs, train_losses, 'b-', label='Train Loss')
        if valid_losses:
            ax1.plot(range(1, len(valid_losses) + 1), valid_losses, 'r-', label='Valid Loss')
        ax1.set_title('Pérdida durante el entrenamiento')
        ax1.set_xlabel('Época')
        ax1.set_ylabel('Pérdida')
        ax1.legend()
        ax1.grid(True)

        # Accuracy curves.
        ax2.plot(epochs, train_accs, 'b-', label='Train Accuracy')
        if valid_accs:
            ax2.plot(range(1, len(valid_accs) + 1), valid_accs, 'r-', label='Valid Accuracy')
        ax2.set_title('Precisión durante el entrenamiento')
        ax2.set_xlabel('Época')
        ax2.set_ylabel('Precisión')
        ax2.legend()
        ax2.grid(True)

        plt.tight_layout()
        plt.savefig(f'{save_path}/training_metrics.png', dpi=300, bbox_inches='tight')
        plt.show()
        # Release the figure so repeated runs do not accumulate open figures.
        plt.close(fig)

        print(f"Gráficas guardadas en: {save_path}/training_metrics.png")

    def _run_epoch(self, model, loader, criterion, desc, optimizer=None, scheduler=None):
        """Run one pass over ``loader``; trains when ``optimizer`` is given, else evaluates.

        Returns:
            ``(average_loss, exact_match_accuracy)`` over the pass.
        """
        training = optimizer is not None
        model.train(training)

        total_loss = 0.0
        all_preds = []
        all_labels = []

        # Gradients are only tracked while training.
        with torch.set_grad_enabled(training):
            progress = tqdm(loader, desc=desc)
            for images, labels in progress:
                images, labels = images.to(self.device), labels.to(self.device)

                if training:
                    optimizer.zero_grad()

                outputs = model(pixel_values=images).logits
                loss = criterion(outputs, labels)

                if training:
                    loss.backward()
                    optimizer.step()
                    # The scheduler is stepped per batch, not per epoch.
                    if scheduler is not None:
                        scheduler.step()

                total_loss += loss.item()

                # A class is predicted when its sigmoid probability exceeds 0.5.
                preds = torch.sigmoid(outputs) > 0.5
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

                progress.set_postfix({'Loss': f'{loss.item():.4f}'})

        avg_loss = total_loss / len(loader)
        accuracy = self._calculate_multilabel_accuracy(all_labels, all_preds)
        return avg_loss, accuracy

    def train(self,
              epochs: int = 30,
              batch_size: int = 16,
              learning_rate: float = 1e-4,
              save_path: str = 'vit_multiclass_model'):
        """
        Train the ViT model.

        The best model (by validation exact-match accuracy) is saved to
        ``<save_path>_best`` and the last-epoch model to
        ``<save_path>_final``, together with metrics JSON and plots.
        Training stops early after 5 validated epochs without improvement.

        Args:
            epochs: Number of epochs.
            batch_size: Batch size.
            learning_rate: Peak learning rate of the one-cycle schedule.
            save_path: Base path used for the saved models.

        Returns:
            The model as it is after the last executed epoch.

        Raises:
            ValueError: If ``epochs`` is below 1, or the training dataset is
                missing or empty.
        """
        if epochs < 1:
            raise ValueError(f"epochs debe ser >= 1, se recibió {epochs}")

        train_dataset, valid_dataset, _ = self._create_datasets()

        if train_dataset is None:
            raise ValueError("No se pudo cargar el dataset de entrenamiento")
        if len(train_dataset) == 0:
            raise ValueError("El dataset de entrenamiento está vacío")

        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=2
        )

        # An empty validation split is treated like a missing one; averaging
        # over zero batches would otherwise divide by zero.
        valid_loader = None
        if valid_dataset is not None and len(valid_dataset) > 0:
            valid_loader = DataLoader(
                valid_dataset,
                batch_size=batch_size,
                shuffle=False,
                num_workers=2
            )

        model = self._create_model()

        optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=WEIGHT_DECAY)
        # Independent sigmoid/BCE per class: multi-label objective.
        criterion = nn.BCEWithLogitsLoss()

        scheduler = optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=learning_rate,
            total_steps=len(train_loader) * epochs,
            pct_start=self._warmup_fraction(epochs, WARMUP_EPOCHS),
            anneal_strategy='cos'
        )

        train_losses = []
        valid_losses = []
        train_accuracies = []
        valid_accuracies = []

        best_valid_acc = 0.0
        best_epoch = 0  # 0 means "no validated epoch yet"
        patience_counter = 0
        patience = 5
        best_model_path = f"{save_path}_best"

        print(f"\nIniciando entrenamiento por {epochs} épocas...")
        print(f"Clases: {self.class_columns}")
        print(f"🎯 Guardado automático del mejor modelo activado")
        print("=" * 60)

        for epoch in range(epochs):
            avg_train_loss, train_acc = self._run_epoch(
                model, train_loader, criterion,
                desc=f'Época {epoch+1}/{epochs} - Entrenamiento',
                optimizer=optimizer, scheduler=scheduler
            )
            train_losses.append(avg_train_loss)
            train_accuracies.append(train_acc)

            if valid_loader is not None:
                avg_valid_loss, valid_acc = self._run_epoch(
                    model, valid_loader, criterion,
                    desc=f'Época {epoch+1}/{epochs} - Validación'
                )
                valid_losses.append(avg_valid_loss)
                valid_accuracies.append(valid_acc)

                print(f'Época {epoch+1}/{epochs}:')
                print(f' Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.4f}')
                print(f' Valid Loss: {avg_valid_loss:.4f}, Valid Acc: {valid_acc:.4f}')

                # The first validated epoch always counts as the best so far;
                # otherwise an accuracy stuck at 0.0 would never save a model.
                if best_epoch == 0 or valid_acc > best_valid_acc:
                    best_valid_acc = valid_acc
                    best_epoch = epoch + 1
                    patience_counter = 0

                    self._save_model(model, best_model_path)
                    print(f' 🎯 ¡Nuevo mejor modelo guardado! Accuracy: {valid_acc:.4f}')
                else:
                    patience_counter += 1
                    print(f' 📊 Mejor accuracy sigue siendo: {best_valid_acc:.4f} (época {best_epoch})')

                    if patience_counter >= patience:
                        print(f' ⏹️ Early stopping: {patience} épocas sin mejora')
                        break
            else:
                print(f'Época {epoch+1}/{epochs}:')
                print(f' Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.4f}')

            current_lr = scheduler.get_last_lr()[0]
            print(f' Learning Rate: {current_lr:.2e}')
            print('-' * 60)

        # Always keep the last-epoch weights next to the best ones.
        final_model_path = f"{save_path}_final"
        self._save_model(model, final_model_path)

        print(f"\n📁 Modelos guardados:")
        if valid_loader is not None:
            print(f" 🎯 Mejor modelo: {save_path}_best (época {best_epoch}, acc: {best_valid_acc:.4f})")
        print(f" 📋 Modelo final: {final_model_path} (última época)")

        metrics = {
            'train_losses': train_losses,
            'valid_losses': valid_losses,
            'train_accuracies': train_accuracies,
            'valid_accuracies': valid_accuracies,
            'class_columns': self.class_columns,
            'filename_column': self.filename_column,
            'best_valid_acc': best_valid_acc,
            'best_epoch': best_epoch
        }

        with open(os.path.join(final_model_path, 'training_metrics.json'), 'w', encoding='utf-8') as f:
            json.dump(metrics, f, indent=2)

        self._plot_training_metrics(train_losses, valid_losses, train_accuracies, valid_accuracies, final_model_path)

        print("\n¡Entrenamiento completado!")
        print(f"Modelo guardado con resolución {IMAGE_SIZE}x{IMAGE_SIZE}")
        print(f"Uso de memoria optimizado con batch size {batch_size}")
        return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_model():
    """Train the model with the module-level settings and return it.

    Prints a summary of the configuration, builds a ``ViTMultiClassTrainer``
    for ``DATA_PATH`` / ``MODEL_NAME`` and runs ``train`` with ``EPOCHS``,
    ``BATCH_SIZE``, ``LEARNING_RATE`` and ``SAVE_PATH``.
    """
    banner = (
        "=== Entrenamiento de ViT Multi-Clasificación ===",
        f"Ruta de datos: {DATA_PATH}",
        f"Épocas: {EPOCHS}",
        f"Batch size: {BATCH_SIZE}",
        f"Learning rate: {LEARNING_RATE}",
        f"Modelo: {MODEL_NAME}",
        "=" * 50,
    )
    for line in banner:
        print(line)

    # Constructing the trainer already inspects the dataset layout.
    trainer = ViTMultiClassTrainer(data_path=DATA_PATH, model_name=MODEL_NAME)

    run_settings = {
        'epochs': EPOCHS,
        'batch_size': BATCH_SIZE,
        'learning_rate': LEARNING_RATE,
        'save_path': SAVE_PATH,
    }
    return trainer.train(**run_settings)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Script entry point: run the full training pipeline when the file is
# executed directly; the trained model is kept in the module-level `model`.
if __name__ == "__main__":
    model = train_model()