StarMist0012's picture
Add files using upload-large-folder tool
05b535a verified
"""Utility helper functions."""
import random
import torch
import numpy as np
def set_seed(seed: int) -> None:
    """Seed every random number generator used by this module.

    The same value is handed, in order, to Python's ``random`` module,
    NumPy's global generator and PyTorch. When CUDA is available, all CUDA
    devices are seeded as well.

    Args:
        seed: Value passed to each generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    # Nothing more to do on machines without CUDA.
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
def get_device(device_str: str = "cuda") -> torch.device:
    """
    Resolve a device string to a ``torch.device``.

    A CUDA device is returned only when ``device_str`` is exactly ``"cuda"``
    and CUDA is available; every other case resolves to the CPU.

    Args:
        device_str: Requested device ("cuda" or "cpu").

    Returns:
        The selected ``torch.device``.
    """
    use_cuda = device_str == "cuda" and torch.cuda.is_available()
    return torch.device("cuda" if use_cuda else "cpu")
def format_size(num_bytes: int) -> str:
    """Format a byte count as a human-readable size.

    The value is divided by 1024 until its magnitude drops below 1024 or the
    largest unit (TB) is reached, then rendered with one decimal place
    followed directly by the unit, e.g. ``"1.5KB"``.

    Args:
        num_bytes: Number of bytes. Negative values (e.g. size deltas) are
            scaled by magnitude and keep their sign.

    Returns:
        The formatted size string.
    """
    size = num_bytes  # work on a local so the argument is not reassigned
    for unit in ("B", "KB", "MB", "GB"):
        # Compare on magnitude: a plain `size < 1024` is true for every
        # negative number, which would leave negative sizes stuck in bytes.
        if abs(size) < 1024:
            return f"{size:.1f}{unit}"
        size /= 1024
    # Anything that survives four divisions is reported in terabytes.
    return f"{size:.1f}TB"
def count_parameters(model: torch.nn.Module) -> tuple[int, int]:
    """
    Tally the parameter elements of a model.

    Args:
        model: PyTorch model whose ``parameters()`` are inspected.

    Returns:
        Tuple of (total_params, trainable_params), where the trainable count
        covers only parameters with ``requires_grad`` set.
    """
    n_all = 0
    n_trainable = 0
    for param in model.parameters():
        size = param.numel()
        n_all += size
        if param.requires_grad:
            n_trainable += size
    return n_all, n_trainable
def get_dtype(dtype_str: str) -> torch.dtype:
    """
    Get torch dtype from string.

    Matching ignores case and surrounding whitespace and tolerates a leading
    ``torch.`` prefix, so ``"float16"``, ``" FP16 "`` and ``"torch.float16"``
    all resolve to ``torch.float16``.

    Args:
        dtype_str: Dtype string. Recognized names are float32 / float / fp32,
            float16 / half / fp16, bfloat16 / bf16 and
            float64 / double / fp64.

    Returns:
        The matching torch.dtype. Unrecognized input falls back to
        ``torch.float32`` and a warning is emitted so the fallback is not
        silent.
    """
    # Function-scope import keeps this block self-contained.
    import warnings

    dtype_map = {
        "float32": torch.float32,
        "float": torch.float32,
        "fp32": torch.float32,
        "float16": torch.float16,
        "half": torch.float16,
        "fp16": torch.float16,
        "bfloat16": torch.bfloat16,
        "bf16": torch.bfloat16,
        "float64": torch.float64,
        "double": torch.float64,
        "fp64": torch.float64,
    }

    if isinstance(dtype_str, str):
        key = dtype_str.strip().lower().removeprefix("torch.")
        if key in dtype_map:
            return dtype_map[key]

    # Keep the historical float32 fallback, but make it visible: a typo in a
    # config would otherwise change precision without any indication.
    warnings.warn(
        f"Unknown dtype {dtype_str!r}; falling back to float32.",
        stacklevel=2,
    )
    return torch.float32