# NOTE(review): stray VCS metadata was pasted at the top of this file
# (author "alverciito", commit message "upload safetensors and refactor
# research files", hash dbd79bd). Kept here as a comment so the module
# parses; safe to delete.
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# #
# This file was created by: Alberto Palomo Alonso #
# Universidad de Alcalá - Escuela Politécnica Superior #
# #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# Import statements:
import contextlib

import numpy as np
import torch
import tqdm

from .setup import Setup, HookMonitor
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
def train_step(
        # Always granted:
        model: torch.nn.Module,
        data: torch.utils.data.DataLoader,
        loss: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        controller: Setup,
        # Not always granted:
        scheduler: torch.optim.lr_scheduler.LRScheduler | None = None,
) -> float:
    """
    Performs one training epoch: forward pass, loss computation, backward pass
    and optimizer step for every batch, with optional AMP gradient scaling when
    ``controller.autoscaler`` is set.

    Parameters:
        model (torch.nn.Module): The model to be trained.
        data (torch.utils.data.DataLoader): DataLoader yielding batches shaped
            as ``(x, y)``, ``(x, y, x_mask)``, ``(x, y, x_mask, y_mask)`` or
            ``(x, y, x_mask, y_mask, *extra_args)``.
        loss (torch.nn.Module): Loss function; invoked as ``loss(y_hat, y)`` or,
            when a target mask is present, ``loss(y_hat, y, y_mask)``.
        optimizer (torch.optim.Optimizer): Optimizer used for gradient updates.
        controller (Setup): Setup object providing device, logger, watcher
            configuration, epoch counter and metric registration.
        scheduler (torch.optim.lr_scheduler.LRScheduler, optional): Learning
            rate scheduler; stepped once per call (i.e. once per epoch).

    Returns:
        float: The mean loss over all batches of this training step.

    Raises:
        ValueError: If a batch contains fewer than two elements.
    """
    # Train mode:
    model.to(controller.device)
    model.train()
    # Train the model for dataloaders or iterators:
    losses = list()
    with HookMonitor(model, controller.watcher['activations'], controller.logger) as hooks:
        with tqdm.tqdm(data, desc=f'\rTraining epoch {controller.epoch}', leave=True) as pbar:
            # (fix) dropped the bogus `pbar: torch.DataLoader` annotation:
            # `torch.DataLoader` does not exist, and `pbar` is a tqdm wrapper.
            for i, element in enumerate(pbar):
                # 1. Gather elements. `args` and both masks are initialized
                # up-front so every branch leaves them in a defined state:
                args = tuple()
                x_m, y_m = None, None
                if len(element) == 2:
                    # Prediction:
                    x, y = element
                elif len(element) == 3:
                    # Prediction with x_mask:
                    x, y, x_m = element
                elif len(element) == 4:
                    # Prediction with x_mask and y_mask:
                    x, y, x_m, y_m = element
                elif len(element) > 4:
                    # More input arguments forwarded verbatim to the model:
                    x, y = element[0], element[1]
                    x_m, y_m = element[2], element[3]
                    args = element[4:]
                else:
                    raise ValueError("DataLoader elements must have at least two elements.")
                # 2. Load data to device:
                x, y = x.to(controller.device, non_blocking=True), y.to(controller.device, non_blocking=True)
                optimizer.zero_grad()
                if x_m is not None:
                    x_m = x_m.to(controller.device, non_blocking=True)
                if y_m is not None:
                    y_m = y_m.to(controller.device, non_blocking=True)
                # 3. TRAIN - Control autocast (mem-speed):
                if controller.autoscaler is not None:
                    # (fix) forward runs under autocast; backward runs OUTSIDE
                    # the region, as recommended by the PyTorch AMP docs:
                    with torch.amp.autocast(enabled=(controller.device.type == 'cuda'),
                                            device_type=controller.device.type):
                        # Forward:
                        y_hat = model(x, x_m, *args) if x_m is not None else model(x)
                        loss_metric = loss(y_hat, y, y_m) if y_m is not None else loss(y_hat, y)
                    # Backward (scaled to avoid fp16 gradient underflow):
                    controller.autoscaler.scale(loss_metric).backward()
                    controller.autoscaler.step(optimizer)
                    controller.autoscaler.update()
                else:
                    # Forward:
                    y_hat = model(x, x_m, *args) if x_m is not None else model(x)
                    loss_metric = loss(y_hat, y, y_m) if y_m is not None else loss(y_hat, y)
                    # Backward:
                    loss_metric.backward()
                    optimizer.step()
                # 4. Append to metrics:
                losses.append(loss_metric.item())
                # 5. Monitor hooks:
                # NOTE(review): `controller.replay_id[0]` presumably selects the
                # batch index whose predictions are stored for replay — confirm
                # against Setup.
                if controller.replay_id[0] == i:
                    controller.register_replay(predicted=y_hat, target=y, mask=y_m)
        # Write in summary writer (per epoch):
        mean_loss = float(np.mean(np.array(losses)))
        # ================ WATCH ================
        # Register parameters:
        for name, parameter in model.named_parameters():
            controller.register(name, parameter)
        # Register train:
        controller.register('loss', mean_loss)
        # Register hooks (done while the HookMonitor context is still open so
        # the collected statistics are guaranteed to be available):
        for layer_name, layer_stats in hooks.get_stats().items():
            for func_name, item in layer_stats.items():
                controller.register(f'{func_name}/{layer_name}', torch.Tensor([item])[0])
    # ================ CONTROL ================
    # Scheduler step:
    if scheduler is not None:
        # Register the LR in effect during this epoch before stepping:
        controller.register('lr', scheduler.get_last_lr()[0])
        scheduler.step()
    # Write for logger:
    controller.logger.info(f"Epoch [{controller.epoch}]: loss = {mean_loss:.8f}")
    # Checkpointing:
    controller.check(model, optimizer, scheduler)
    return mean_loss
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
def validation_step(
        # Always granted:
        model: torch.nn.Module,
        data: torch.utils.data.DataLoader,
        loss: torch.nn.Module,
        controller: Setup,
        additional_metrics: dict = (),
) -> dict:
    """
    Performs one validation epoch with gradients disabled: forward pass, loss
    computation and any additional metrics over every batch.

    Parameters:
        model (torch.nn.Module): The model to be validated.
        data (torch.utils.data.DataLoader): DataLoader yielding batches shaped
            as ``(x, y)``, ``(x, y, x_mask)``, ``(x, y, x_mask, y_mask)`` or
            ``(x, y, x_mask, y_mask, *extra_args)``.
        loss (torch.nn.Module): Loss function; invoked as ``loss(y_hat, y)`` or,
            when a target mask is present, ``loss(y_hat, y, y_mask)``.
        controller (Setup): Setup object providing device, logger, epoch counter
            and metric registration.
        additional_metrics (dict): Mapping ``name -> callable``; each callable is
            invoked as ``metric(y_hat, y, y_mask)`` (the mask may be ``None``).
            The empty-tuple default simply means "no extra metrics".

    Returns:
        dict: Mean value per additional metric, plus the mean loss under the key
        ``'loss'``. (The previous docstring incorrectly documented a ``float``.)

    Raises:
        ValueError: If a batch contains fewer than two elements.
    """
    # Validation mode:
    model.to(controller.device)
    model.eval()
    # Per-batch accumulators:
    losses = list()
    metrics: dict[str, list | float] = {name: list() for name in additional_metrics}
    with torch.no_grad():
        with tqdm.tqdm(data, desc=f'\rValidation epoch {controller.epoch}', leave=True) as pbar:
            # (fix) dropped the bogus `pbar: torch.DataLoader` annotation:
            # `torch.DataLoader` does not exist, and `pbar` is a tqdm wrapper.
            for element in pbar:
                # Gather elements. (fix) `args` and both masks are initialized
                # up-front: the original never assigned `args` in the
                # len(element) == 4 branch, raising NameError on such batches.
                args = tuple()
                x_m, y_m = None, None
                if len(element) == 2:
                    # Prediction:
                    x, y = element
                elif len(element) == 3:
                    # Prediction with x_mask:
                    x, y, x_m = element
                elif len(element) == 4:
                    # Prediction with x_mask and y_mask:
                    x, y, x_m, y_m = element
                elif len(element) > 4:
                    # More input arguments forwarded verbatim to the model:
                    x, y = element[0], element[1]
                    x_m, y_m = element[2], element[3]
                    args = element[4:]
                else:
                    raise ValueError("DataLoader elements must have at least two elements.")
                # Load data to device:
                x, y = x.to(controller.device, non_blocking=True), y.to(controller.device, non_blocking=True)
                if x_m is not None:
                    x_m = x_m.to(controller.device, non_blocking=True)
                if y_m is not None:
                    y_m = y_m.to(controller.device, non_blocking=True)
                # Control autocast (mem-speed). A null context replaces the two
                # duplicated forward branches of the original implementation:
                if controller.autoscaler is not None:
                    amp_ctx = torch.amp.autocast(enabled=(controller.device.type == 'cuda'),
                                                 device_type=controller.device.type)
                else:
                    amp_ctx = contextlib.nullcontext()
                with amp_ctx:
                    # Forward:
                    y_hat = model(x, x_m, *args) if x_m is not None else model(x)
                    loss_metric = loss(y_hat, y, y_m) if y_m is not None else loss(y_hat, y)
                    # Compute additional metrics (mask may be None, matching the
                    # original call convention):
                    if additional_metrics:
                        for name, additional_metric in additional_metrics.items():
                            metrics[name].append(additional_metric(y_hat, y, y_m).item())
                # Append to metrics:
                losses.append(loss_metric.item())
    # Convert:
    mean_loss = float(np.mean(np.array(losses)))
    # Additional metrics (per-batch lists collapsed to epoch means):
    for name, variable in metrics.items():
        metrics[name] = float(np.mean(variable))
    metrics['loss'] = mean_loss
    # Write to register:
    controller.register("val_loss", mean_loss)
    # Write for logger:
    controller.logger.info(f"Epoch [{controller.epoch}]: val_loss = {mean_loss:.8f}")
    return metrics
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# END OF FILE #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #