# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                                                           #
#   This file was created by: Alberto Palomo Alonso         #
# Universidad de Alcalá - Escuela Politécnica Superior      #
#                                                           #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# Import statements:
import contextlib
from typing import Optional

import torch
import numpy as np
import tqdm

from .setup import Setup, HookMonitor


# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                      PRIVATE HELPERS                      #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
def _unpack_batch(element) -> tuple:
    """
    Split one DataLoader element into ``(x, y, x_m, y_m, args)``.

    Accepted layouts:
        (x, y)                   -> x_m = y_m = None, args = ()
        (x, y, x_m)              -> y_m = None, args = ()
        (x, y, x_m, y_m)         -> args = ()
        (x, y, x_m, y_m, *rest)  -> args = tuple(rest), forwarded to the model

    ``args`` is always assigned, so a 4-element batch never leaves it unbound.

    Raises:
        ValueError: If the element has fewer than two entries.
    """
    n_items = len(element)
    if n_items < 2:
        raise ValueError("DataLoader elements must have at least two elements.")
    x, y = element[0], element[1]
    x_m = element[2] if n_items > 2 else None
    y_m = element[3] if n_items > 3 else None
    args = tuple(element[4:]) if n_items > 4 else tuple()
    return x, y, x_m, y_m, args


def _to_device_optional(tensor, device):
    """Move ``tensor`` to ``device`` with ``non_blocking=True``; ``None`` passes through unchanged."""
    return None if tensor is None else tensor.to(device, non_blocking=True)


def _autocast_context(controller: Setup):
    """
    Return the context manager that wraps the forward pass.

    When ``controller.autoscaler`` is set, this is ``torch.amp.autocast`` for the controller's
    device type, enabled only when that device is CUDA. Otherwise it is a no-op context.
    """
    if controller.autoscaler is None:
        return contextlib.nullcontext()
    return torch.amp.autocast(enabled=(controller.device.type == 'cuda'),
                              device_type=controller.device.type)


def _forward(model, loss, x, y, x_m, y_m, args) -> tuple:
    """
    Run the model and the loss on one batch and return ``(y_hat, loss_metric)``.

    The input mask and any extra args reach the model only when ``x_m`` is not None.
    The target mask reaches the loss only when ``y_m`` is not None.
    """
    y_hat = model(x, x_m, *args) if x_m is not None else model(x)
    loss_metric = loss(y_hat, y, y_m) if y_m is not None else loss(y_hat, y)
    return y_hat, loss_metric


def _mean(values) -> float:
    """Mean of ``values`` as a Python float; NaN for an empty sequence (avoids NumPy's empty-slice warning)."""
    return float(np.mean(values)) if len(values) else float('nan')


# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                                                           #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
def train_step(
        # Always granted:
        model: torch.nn.Module,
        data: torch.utils.data.DataLoader,
        loss: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        controller: Setup,
        # Not always granted:
        scheduler: torch.optim.lr_scheduler.LRScheduler = None,
) -> float:
    """
    Run one training epoch over ``data``.

    Each batch goes through these stages: forward pass, loss, backward pass and optimizer step.

    Parameters:
        model (torch.nn.Module): The model to be trained. It is moved to ``controller.device``.
        data (torch.utils.data.DataLoader): Yields batches shaped as ``(x, y[, x_m[, y_m[, *args]]])``.
        loss (torch.nn.Module): Loss function. It is called as ``loss(y_hat, y)``, or as
            ``loss(y_hat, y, y_m)`` when a target mask is present.
        optimizer (torch.optim.Optimizer): Optimizer used for gradient updates.
        controller (Setup): Holds device, epoch, autoscaler, logger, watcher config and the
            register/checkpoint hooks.
        scheduler (torch.optim.lr_scheduler.LRScheduler, optional): Stepped once per call,
            after the epoch. Its last learning rate is registered as ``'lr'`` before stepping.

    Returns:
        float: The mean batch loss for this epoch (NaN if ``data`` is empty).

    Raises:
        ValueError: If a batch has fewer than two elements.
    """
    # Train mode:
    device = controller.device
    model.to(device)
    model.train()

    losses = list()
    with HookMonitor(model, controller.watcher['activations'], controller.logger) as hooks:
        with tqdm.tqdm(data, desc=f'\rTraining epoch {controller.epoch}', leave=True) as pbar:
            for i, element in enumerate(pbar):
                # 1. Gather elements:
                x, y, x_m, y_m, args = _unpack_batch(element)

                # 2. Load data to device:
                x = x.to(device, non_blocking=True)
                y = y.to(device, non_blocking=True)
                x_m = _to_device_optional(x_m, device)
                y_m = _to_device_optional(y_m, device)
                optimizer.zero_grad()

                # 3. Forward (and loss) under autocast when an autoscaler is configured:
                with _autocast_context(controller):
                    y_hat, loss_metric = _forward(model, loss, x, y, x_m, y_m, args)

                # Backward runs OUTSIDE autocast, following PyTorch's AMP guidance.
                if controller.autoscaler is not None:
                    controller.autoscaler.scale(loss_metric).backward()
                    controller.autoscaler.step(optimizer)
                    controller.autoscaler.update()
                else:
                    loss_metric.backward()
                    optimizer.step()

                # 4. Append to metrics:
                losses.append(loss_metric.item())

                # 5. Store the replay sample for the configured batch index:
                if controller.replay_id[0] == i:
                    controller.register_replay(predicted=y_hat, target=y, mask=y_m)

        mean_loss = _mean(losses)

        # ================ WATCH ================
        # Register parameters:
        for name, parameter in model.named_parameters():
            controller.register(name, parameter)
        # Register train loss:
        controller.register('loss', mean_loss)
        # Register hooks. This runs while the monitor is still active, so the stats are available
        # regardless of what HookMonitor does on exit.
        for layer_name, layer_stats in hooks.get_stats().items():
            for func_name, item in layer_stats.items():
                controller.register(f'{func_name}/{layer_name}', torch.Tensor([item])[0])

    # ================ CONTROL ================
    # Scheduler step (the lr used during this epoch is registered first):
    if scheduler is not None:
        controller.register('lr', scheduler.get_last_lr()[0])
        scheduler.step()

    # Write for logger:
    controller.logger.info(f"Epoch [{controller.epoch}]: loss = {mean_loss:.8f}")

    # Checkpointing is delegated to the controller:
    controller.check(model, optimizer, scheduler)
    return mean_loss


# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                                                           #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
def validation_step(
        # Always granted:
        model: torch.nn.Module,
        data: torch.utils.data.DataLoader,
        loss: torch.nn.Module,
        controller: Setup,
        additional_metrics: Optional[dict] = None,
) -> dict:
    """
    Run one validation epoch over ``data`` without gradient tracking.

    Each batch goes through a forward pass, the loss and every additional metric.

    Parameters:
        model (torch.nn.Module): The model to be validated. It is moved to ``controller.device``
            and put in eval mode.
        data (torch.utils.data.DataLoader): Yields batches shaped as ``(x, y[, x_m[, y_m[, *args]]])``.
        loss (torch.nn.Module): Loss function. It is called as ``loss(y_hat, y)``, or as
            ``loss(y_hat, y, y_m)`` when a target mask is present.
        controller (Setup): Holds device, epoch, autoscaler, logger and the register hook.
        additional_metrics (dict, optional): Maps a name to a callable ``f(y_hat, y, y_m)`` that
            returns a tensor. ``y_m`` may be None. ``None`` or an empty mapping means no extra metrics.

    Returns:
        dict: The mean of each additional metric under its name, plus ``'loss'`` holding the mean
        validation loss. Means are NaN if ``data`` is empty.

    Raises:
        ValueError: If a batch has fewer than two elements.
    """
    if not additional_metrics:
        additional_metrics = {}

    # Validation mode:
    device = controller.device
    model.to(device)
    model.eval()

    losses = list()
    metric_values = {name: list() for name in additional_metrics}
    with torch.no_grad():
        with tqdm.tqdm(data, desc=f'\rValidation epoch {controller.epoch}', leave=True) as pbar:
            for element in pbar:
                # Gather elements (args is always defined, including for 4-element batches):
                x, y, x_m, y_m, args = _unpack_batch(element)

                # Load data to device:
                x = x.to(device, non_blocking=True)
                y = y.to(device, non_blocking=True)
                x_m = _to_device_optional(x_m, device)
                y_m = _to_device_optional(y_m, device)

                # Forward, loss and additional metrics under autocast when configured:
                with _autocast_context(controller):
                    y_hat, loss_metric = _forward(model, loss, x, y, x_m, y_m, args)
                    for name, metric_fn in additional_metrics.items():
                        metric_values[name].append(metric_fn(y_hat, y, y_m).item())

                losses.append(loss_metric.item())

    # Reduce to per-epoch means:
    mean_loss = _mean(losses)
    metrics = {name: _mean(values) for name, values in metric_values.items()}
    metrics['loss'] = mean_loss

    # Write to register:
    controller.register("val_loss", mean_loss)
    # Write for logger:
    controller.logger.info(f"Epoch [{controller.epoch}]: val_loss = {mean_loss:.8f}")
    return metrics


# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                        END OF FILE                        #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #