| | import torch |
| | import torch.nn as nn |
| | import numpy as np |
| | from typing import Tuple, Callable, Optional |
| |
|
| |
|
def normalize_data(data: torch.Tensor, mean: Optional[torch.Tensor] = None,
                   std: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Normalize data to zero mean and unit variance.

    Args:
        data: Input tensor to normalize
        mean: Optional precomputed mean (if None, computed over all elements of data).
            A plain Python number is also accepted and converted to a tensor.
        std: Optional precomputed std (if None, computed over all elements of data).
            A plain Python number is also accepted and converted to a tensor.

    Returns:
        Tuple of (normalized_data, mean, std), where std is the value actually
        used for the division (i.e. after clamping to a minimum of 1e-8)
    """
    # Dtype used when a caller hands in a Python scalar instead of a tensor.
    # Falling back to the default float dtype avoids truncating e.g. mean=2.5
    # when data happens to be an integer tensor.
    scalar_dtype = data.dtype if data.is_floating_point() else torch.get_default_dtype()

    if mean is None:
        mean = data.mean()
    elif not torch.is_tensor(mean):
        mean = torch.as_tensor(mean, dtype=scalar_dtype, device=data.device)

    if std is None:
        if data.numel() > 1:
            std = data.std()
        else:
            # The sample standard deviation of fewer than two values is NaN,
            # and clamping does not remove NaN. Treat it as zero spread so the
            # clamp below yields a finite divisor.
            std = torch.zeros((), dtype=scalar_dtype, device=data.device)
    elif not torch.is_tensor(std):
        std = torch.as_tensor(std, dtype=scalar_dtype, device=data.device)

    # Guard against division by zero for constant data.
    std = torch.clamp(std, min=1e-8)

    normalized = (data - mean) / std
    return normalized, mean, std
| |
|
| |
|
def denormalize_data(normalized_data: torch.Tensor, mean: torch.Tensor,
                     std: torch.Tensor) -> torch.Tensor:
    """
    Undo a mean/std normalization.

    Args:
        normalized_data: Tensor that was previously normalized
        mean: Mean that was subtracted during normalization
        std: Standard deviation that the data was divided by during normalization

    Returns:
        Tensor mapped back to the original scale (normalized_data * std + mean)
    """
    # Reverse the two normalization steps in opposite order: rescale, then shift.
    rescaled = normalized_data * std
    return rescaled + mean
| |
|
| |
|
def mean_pooling(x: torch.Tensor, dim: int = 1) -> torch.Tensor:
    """
    Average a tensor along one dimension.

    Args:
        x: Input tensor
        dim: Dimension that is reduced by averaging (default: 1)

    Returns:
        Tensor with the given dimension removed, holding the mean along it
    """
    pooled = torch.mean(x, dim=dim)
    return pooled
| |
|
| |
|
def masked_mean_pooling(x: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
    """
    Apply mean pooling along specified dimension, excluding masked (padded) positions.

    Masked positions are removed by selection rather than by multiplication
    alone, so non-finite values (inf / NaN) stored in padded positions cannot
    leak into the result.

    Args:
        x: Input tensor (B, seq_len, dim)
        mask: Boolean mask tensor (B, seq_len) where True indicates real data
        dim: Dimension to pool over (default: 1, sequence dimension)

    Returns:
        Mean-pooled tensor excluding masked positions. Rows with no valid
        position pool to zero (the count is clamped to avoid division by zero).
    """
    if mask.dim() == 2 and x.dim() == 3:
        # Add a trailing axis so the mask broadcasts over the feature dimension.
        mask = mask.unsqueeze(-1)

    weights = mask.float()

    # Zero out masked entries first: multiplying inf/NaN by a zero weight
    # would produce NaN and poison the sum.
    cleaned = torch.where(weights != 0, x, torch.zeros_like(x))
    masked_x = cleaned * weights

    # Sum of the valid entries and number of valid entries along the pooled axis.
    sum_x = masked_x.sum(dim=dim)
    count = weights.sum(dim=dim)

    # Avoid division by zero for fully masked rows.
    count = torch.clamp(count, min=1e-8)

    return sum_x / count
| |
|
| |
|
| |
|
| |
|
def pad_sequences(sequences: list, max_length: Optional[int] = None,
                  padding_value: float = -1e9) -> torch.Tensor:
    """
    Pad sequences to the same length with a configurable padding value.

    Sequences longer than max_length are truncated to their first max_length
    steps. The dtype, device and trailing (per-step) shape of the output are
    taken from the first sequence.

    Args:
        sequences: Non-empty list of tensors whose first dimension (length) may differ
        max_length: Maximum length to pad to (if None, use longest sequence)
        padding_value: Value to use for padding (default: -1e9, avoids conflict with meaningful zeros)

    Returns:
        Padded tensor of shape (batch_size, max_length, *feature_shape), where
        feature_shape is everything after the first dimension of a sequence,
        e.g. (batch_size, max_length, dim) for 2-D sequences and
        (batch_size, max_length) for 1-D sequences

    Raises:
        ValueError: If sequences is empty
    """
    if not sequences:
        raise ValueError("pad_sequences() requires at least one sequence")

    if max_length is None:
        max_length = max(seq.size(0) for seq in sequences)

    first = sequences[0]
    # Use the whole trailing shape, not just the last dimension, so 1-D
    # sequences and sequences with several feature axes are padded correctly.
    feature_shape = tuple(first.shape[1:])

    padded = torch.full((len(sequences), max_length, *feature_shape), padding_value,
                        dtype=first.dtype, device=first.device)

    for i, seq in enumerate(sequences):
        # Truncate sequences that exceed max_length.
        length = min(seq.size(0), max_length)
        padded[i, :length] = seq[:length]

    return padded
| |
|
| |
|
def create_padding_mask(sequences: list, max_length: Optional[int] = None) -> torch.Tensor:
    """
    Build a boolean mask marking which positions of a padded batch hold real data.

    Args:
        sequences: List of tensors whose first dimension (length) may differ
        max_length: Width of the mask (if None, the longest sequence length is used)

    Returns:
        Boolean tensor of shape (batch_size, max_length) on the device of the
        first sequence; True marks real data, False marks padding
    """
    if max_length is None:
        max_length = max(item.size(0) for item in sequences)

    target_device = sequences[0].device

    # Effective length of every sequence, capped at the mask width.
    valid_lengths = torch.tensor(
        [min(item.size(0), max_length) for item in sequences],
        dtype=torch.long,
        device=target_device,
    )

    # Position p of row i is real data exactly when p < valid_lengths[i].
    positions = torch.arange(max_length, device=target_device)
    return positions.unsqueeze(0) < valid_lengths.unsqueeze(1)
| |
|
| |
|
| |
|
| |
|
def compute_rmse(predictions: torch.Tensor, targets: torch.Tensor) -> float:
    """
    Root Mean Square Error between predictions and targets.

    Args:
        predictions: Predicted values
        targets: True target values

    Returns:
        RMSE over all elements, as a Python float
    """
    residuals = predictions - targets
    mean_squared = (residuals ** 2).mean()
    return mean_squared.sqrt().item()
| |
|
| |
|
def compute_mae(predictions: torch.Tensor, targets: torch.Tensor) -> float:
    """
    Mean Absolute Error between predictions and targets.

    Args:
        predictions: Predicted values
        targets: True target values

    Returns:
        MAE over all elements, as a Python float
    """
    absolute_errors = (predictions - targets).abs()
    return absolute_errors.mean().item()
| |
|
| |
|
class EarlyStopping:
    """
    Early stopping utility to stop training when validation loss stops improving.

    Call the instance once per epoch with the current validation loss and the
    model. It tracks the best loss seen so far and signals a stop once the loss
    has failed to improve for `patience` consecutive calls.
    """

    def __init__(self, patience: int = 5, min_delta: float = 0.0,
                 restore_best_weights: bool = True):
        """
        Args:
            patience: Number of epochs with no improvement after which training will be stopped
            min_delta: Minimum change in monitored quantity to qualify as improvement
            restore_best_weights: Whether to restore model weights from the best epoch
        """
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights

        self.best_loss = float('inf')  # lowest validation loss seen so far
        self.counter = 0               # consecutive calls without improvement
        self.best_weights = None       # snapshot of the state dict at best_loss

    def __call__(self, val_loss: float, model: nn.Module) -> bool:
        """
        Check if training should be stopped.

        When the stop condition is reached and restore_best_weights is set, the
        snapshot taken at the best loss is loaded back into the model.

        Args:
            val_loss: Current validation loss
            model: Model to potentially save weights for

        Returns:
            True if training should be stopped, False otherwise
        """
        import copy

        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            if self.restore_best_weights:
                # A shallow dict copy would keep references to the live
                # parameter tensors, so the "best" snapshot would silently
                # follow later training steps. Deep-copy to freeze the values.
                self.best_weights = copy.deepcopy(model.state_dict())
        else:
            self.counter += 1

        if self.counter >= self.patience:
            if self.restore_best_weights and self.best_weights is not None:
                model.load_state_dict(self.best_weights)
            return True

        return False