import os
import random
import datetime

import numpy as np
import torch
import torch.distributed as dist
def setup_ddp():
"""
Initializes the distributed data parallel environment.
This function relies on environment variables set by `torchrun` or a similar
launcher. It initializes the process group and sets the CUDA device for the
current process.
Returns:
tuple: A tuple containing (rank, world_size, local_rank).
"""
if not dist.is_available():
raise RuntimeError("torch.distributed is not available.")
dist.init_process_group(backend="nccl")
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
print(
f"[DDP Setup] Global Rank: {rank}/{world_size}, "
f"Local Rank (GPU): {local_rank} on device {torch.cuda.current_device()}"
)
return rank, world_size, local_rank
def cleanup_ddp():
"""Cleans up the distributed process group."""
if dist.is_initialized():
dist.destroy_process_group()
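# Illustrative usage sketch (an assumption, not part of the original module):
# setup_ddp() and cleanup_ddp() are typically paired in a training entry point
# launched with a command such as `torchrun --nproc_per_node=<num_gpus> train.py`,
# where the script name and GPU count are placeholders.
def _example_ddp_entry_point(train_fn):
    """Run `train_fn(rank, world_size, local_rank)` inside a DDP context (illustrative)."""
    rank, world_size, local_rank = setup_ddp()
    try:
        train_fn(rank, world_size, local_rank)
    finally:
        # Always tear down the process group, even if training raises.
        cleanup_ddp()
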
def set_seed(seed: int, rank: int = 0):
"""
Sets the random seed for reproducibility across all relevant libraries.
Args:
seed (int): The base seed value.
rank (int): The process rank, used to ensure different processes have
different seeds, which can be important for data loading.
"""
actual_seed = seed + rank
random.seed(actual_seed)
np.random.seed(actual_seed)
torch.manual_seed(actual_seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(actual_seed)
# The two lines below can impact performance, so they are often
# reserved for final experiments where reproducibility is critical.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def get_model_size(model: torch.nn.Module) -> str:
"""
Calculates the number of trainable parameters in a PyTorch model and returns
it as a human-readable string.
Args:
model (torch.nn.Module): The PyTorch model.
Returns:
str: A string representing the model size (e.g., "175.0B", "7.1M", "50.5K").
"""
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
if total_params >= 1e9:
return f"{total_params / 1e9:.1f}B" # Billions
elif total_params >= 1e6:
return f"{total_params / 1e6:.1f}M" # Millions
else:
return f"{total_params / 1e3:.1f}K" # Thousands
def reduce_tensor(tensor: torch.Tensor, world_size: int, op=dist.ReduceOp.SUM) -> torch.Tensor:
"""
Reduces a tensor's value across all processes in a distributed setup.
Args:
tensor (torch.Tensor): The tensor to be reduced.
world_size (int): The total number of processes.
op (dist.ReduceOp, optional): The reduction operation (SUM, AVG, etc.).
Defaults to dist.ReduceOp.SUM.
Returns:
torch.Tensor: The reduced tensor, which will be identical on all processes.
"""
rt = tensor.clone()
dist.all_reduce(rt, op=op)
# Note: `dist.ReduceOp.AVG` is available in newer torch versions.
# For compatibility, manual division is sometimes used after a SUM.
if op == dist.ReduceOp.AVG:
rt /= world_size
return rt
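# Illustrative usage sketch (an assumption, not part of the original module):
# averaging a per-rank scalar metric (e.g. a loss value) so that every rank
# logs the same number. Assumes the process group has already been initialized
# via setup_ddp(); the helper name is hypothetical.
def average_across_ranks(value: float, world_size: int, device: torch.device) -> float:
    """Average a per-rank scalar across all ranks (illustrative helper)."""
    t = torch.tensor([value], dtype=torch.float32, device=device)
    # SUM then divide, matching the compatibility path in reduce_tensor().
    t = reduce_tensor(t, world_size, op=dist.ReduceOp.SUM)
    return (t / world_size).item()
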
def format_time(seconds: float) -> str:
"""
Formats a duration in seconds into a human-readable H:M:S string.
Args:
seconds (float): The total seconds.
Returns:
str: The formatted time string (e.g., "0:15:32").
"""
return str(datetime.timedelta(seconds=int(seconds)))
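

if __name__ == "__main__":
    # Minimal single-process smoke test (illustrative, an assumption rather
    # than part of the original module): exercises only the helpers that do
    # not require an initialized process group.
    set_seed(42)
    tiny_model = torch.nn.Linear(1024, 1024)
    print(f"Tiny model size: {get_model_size(tiny_model)}")  # -> 1.0M
    print(f"Formatted time:  {format_time(932)}")            # -> 0:15:32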