import torch
import torch.nn as nn
from typing import Tuple, Optional


def normalize_data(data: torch.Tensor,
                   mean: Optional[torch.Tensor] = None,
                   std: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Normalize data to zero mean and unit variance.

    Args:
        data: Input tensor to normalize
        mean: Optional precomputed mean (if None, computed from data)
        std: Optional precomputed std (if None, computed from data)

    Returns:
        Tuple of (normalized_data, mean, std)
    """
    if mean is None:
        mean = data.mean()
    if std is None:
        std = data.std()
    # Avoid division by zero
    std = torch.clamp(std, min=1e-8)
    normalized = (data - mean) / std
    return normalized, mean, std


def denormalize_data(normalized_data: torch.Tensor,
                     mean: torch.Tensor,
                     std: torch.Tensor) -> torch.Tensor:
    """
    Denormalize data using the provided mean and std.

    Args:
        normalized_data: Normalized tensor
        mean: Mean used for normalization
        std: Standard deviation used for normalization

    Returns:
        Denormalized tensor
    """
    return normalized_data * std + mean


def mean_pooling(x: torch.Tensor, dim: int = 1) -> torch.Tensor:
    """
    Apply mean pooling along the specified dimension.

    Args:
        x: Input tensor
        dim: Dimension to pool over

    Returns:
        Mean-pooled tensor
    """
    return x.mean(dim=dim)


def masked_mean_pooling(x: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
    """
    Apply mean pooling along the specified dimension, excluding masked (padded) positions.

    Args:
        x: Input tensor (B, seq_len, dim)
        mask: Boolean mask tensor (B, seq_len) where True indicates real data
        dim: Dimension to pool over (default: 1, the sequence dimension)

    Returns:
        Mean-pooled tensor excluding masked positions
    """
    if mask.dim() == 2 and x.dim() == 3:
        # Expand mask to match x dimensions: (B, seq_len) -> (B, seq_len, 1)
        mask = mask.unsqueeze(-1)
    # Zero out masked positions so they do not contribute to the sum
    masked_x = x * mask.float()
    # Sum over the specified dimension
    sum_x = masked_x.sum(dim=dim)
    # Count non-masked positions
    count = mask.float().sum(dim=dim)
    # Avoid division by zero for fully masked rows
    count = torch.clamp(count, min=1e-8)
    # Mean over real (unmasked) positions only
    return sum_x / count


def pad_sequences(sequences: list,
                  max_length: Optional[int] = None,
                  padding_value: float = -1e9) -> torch.Tensor:
    """
    Pad sequences to the same length with a configurable padding value.

    Args:
        sequences: List of tensors with different lengths
        max_length: Maximum length to pad to (if None, use the longest sequence)
        padding_value: Value to use for padding (default: -1e9, which avoids
            conflicts with meaningful zeros)

    Returns:
        Padded tensor of shape (batch_size, max_length, dim)
    """
    if max_length is None:
        max_length = max(seq.size(0) for seq in sequences)
    batch_size = len(sequences)
    dim = sequences[0].size(-1)
    padded = torch.full((batch_size, max_length, dim), padding_value,
                        dtype=sequences[0].dtype, device=sequences[0].device)
    for i, seq in enumerate(sequences):
        # Truncate sequences longer than max_length
        length = min(seq.size(0), max_length)
        padded[i, :length] = seq[:length]
    return padded
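
# Illustrative sketch: a quick round-trip check for the normalization helpers
# above. The shapes, values, and the _demo_normalization name are assumptions
# made for this example only; they are not part of the module's API.
def _demo_normalization() -> None:
    data = torch.randn(4, 8) * 3.0 + 5.0
    # Normalize using statistics computed from the data itself
    normalized, mean, std = normalize_data(data)
    # Denormalizing with the same statistics should reconstruct the input
    restored = denormalize_data(normalized, mean, std)
    assert torch.allclose(restored, data, atol=1e-5)
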

def create_padding_mask(sequences: list, max_length: Optional[int] = None) -> torch.Tensor:
    """
    Create a padding mask for sequences.

    Args:
        sequences: List of tensors with different lengths
        max_length: Maximum length (if None, use the longest sequence)

    Returns:
        Boolean mask tensor where True indicates real data and False indicates padding
    """
    if max_length is None:
        max_length = max(seq.size(0) for seq in sequences)
    batch_size = len(sequences)
    mask = torch.zeros(batch_size, max_length, dtype=torch.bool,
                       device=sequences[0].device)
    for i, seq in enumerate(sequences):
        length = min(seq.size(0), max_length)
        mask[i, :length] = True
    return mask


def compute_rmse(predictions: torch.Tensor, targets: torch.Tensor) -> float:
    """
    Compute Root Mean Square Error.

    Args:
        predictions: Predicted values
        targets: True target values

    Returns:
        RMSE value
    """
    mse = torch.mean((predictions - targets) ** 2)
    return torch.sqrt(mse).item()


def compute_mae(predictions: torch.Tensor, targets: torch.Tensor) -> float:
    """
    Compute Mean Absolute Error.

    Args:
        predictions: Predicted values
        targets: True target values

    Returns:
        MAE value
    """
    mae = torch.mean(torch.abs(predictions - targets))
    return mae.item()


class EarlyStopping:
    """
    Early stopping utility that halts training when validation loss stops improving.
    """

    def __init__(self, patience: int = 5, min_delta: float = 0.0,
                 restore_best_weights: bool = True):
        """
        Args:
            patience: Number of epochs with no improvement after which training is stopped
            min_delta: Minimum change in the monitored quantity to qualify as an improvement
            restore_best_weights: Whether to restore model weights from the best epoch
        """
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.best_loss = float('inf')
        self.counter = 0
        self.best_weights = None

    def __call__(self, val_loss: float, model: nn.Module) -> bool:
        """
        Check whether training should be stopped.

        Args:
            val_loss: Current validation loss
            model: Model whose weights may be saved or restored

        Returns:
            True if training should be stopped, False otherwise
        """
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            if self.restore_best_weights:
                # Clone each tensor: state_dict().copy() is shallow, so the saved
                # "best" weights would otherwise keep tracking the live parameters
                # as training continues.
                self.best_weights = {k: v.detach().clone()
                                     for k, v in model.state_dict().items()}
        else:
            self.counter += 1
            if self.counter >= self.patience:
                if self.restore_best_weights and self.best_weights is not None:
                    model.load_state_dict(self.best_weights)
                return True
        return False
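
# Illustrative smoke test tying the helpers together, written as a sketch under
# the assumption that this module is run directly; the toy model, loss values,
# and shapes below are invented for demonstration only.
if __name__ == "__main__":
    # Variable-length sequences -> padded batch plus a matching boolean mask
    seqs = [torch.randn(3, 16), torch.randn(5, 16), torch.randn(2, 16)]
    padded = pad_sequences(seqs)                 # (3, 5, 16)
    mask = create_padding_mask(seqs)             # (3, 5), True = real data
    pooled = masked_mean_pooling(padded, mask)   # (3, 16), padding excluded
    print("pooled:", tuple(pooled.shape))

    # Metrics on random predictions/targets
    preds, targets = torch.randn(10), torch.randn(10)
    print("rmse:", compute_rmse(preds, targets), "mae:", compute_mae(preds, targets))

    # EarlyStopping with a toy model and a synthetic validation-loss curve:
    # improvement stalls after epoch 1, so patience=2 triggers a stop at epoch 3.
    model = nn.Linear(16, 1)
    stopper = EarlyStopping(patience=2)
    for epoch, val_loss in enumerate([1.0, 0.8, 0.85, 0.9, 0.95]):
        if stopper(val_loss, model):
            print(f"early stop at epoch {epoch}")
            break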