"""Utility helper functions.""" import random import torch import numpy as np def set_seed(seed: int) -> None: """Set random seed for reproducibility.""" random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) def get_device(device_str: str = "cuda") -> torch.device: """ Get torch device. Args: device_str: Device string ("cuda" or "cpu") Returns: torch.device instance """ if device_str == "cuda" and torch.cuda.is_available(): return torch.device("cuda") return torch.device("cpu") def format_size(num_bytes: int) -> str: """Format bytes to human-readable size.""" for unit in ["B", "KB", "MB", "GB"]: if num_bytes < 1024: return f"{num_bytes:.1f}{unit}" num_bytes /= 1024 return f"{num_bytes:.1f}TB" def count_parameters(model: torch.nn.Module) -> tuple[int, int]: """ Count parameters in model. Args: model: PyTorch model Returns: Tuple of (total_params, trainable_params) """ total_params = sum(p.numel() for p in model.parameters()) trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) return total_params, trainable_params def get_dtype(dtype_str: str) -> torch.dtype: """ Get torch dtype from string. Args: dtype_str: Dtype string (float32, float16, bfloat16) Returns: torch.dtype """ dtype_map = { "float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16, } return dtype_map.get(dtype_str, torch.float32)