File size: 6,296 Bytes
3d1c0e1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
# Copyright (c) 2025 FoundationVision
# SPDX-License-Identifier: MIT
# from https://github.com/FoundationVision/LlamaGen/blob/main/utils/distributed.py
import os
import sys
import glob
import torch
import subprocess
import torch.distributed as dist
import datetime
import logging
from infinity.models.videovae.utils.misc import rank_zero_only, COLOR_BLUE, COLOR_RESET
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
ShardingStrategy,
MixedPrecision,
)
from infinity.models.videovae.models.cvivit_vqgan import CViViT_Decoder, CViViT_Encoder
def setup_for_distributed(is_master, logging_dir=""):
    """
    Disable ``print`` on non-master processes and, on the master, tee all
    ``logging`` records to numbered files under ``logging_dir``.

    Args:
        is_master: True only on the rank-0 process. Other ranks have their
            ``print`` calls suppressed unless the caller passes ``force=True``.
        logging_dir: directory for ``log_out_N.txt`` / ``log_err_N.txt``
            files; created if missing. Only used when ``is_master`` is True.
    """
    import builtins as __builtin__

    class Logger(logging.StreamHandler):
        """StreamHandler that duplicates every record to an extra file."""

        def __init__(self, stream, file):
            super().__init__(stream)
            # Plain file object; kept open for the lifetime of the process
            # so log lines can be flushed immediately on every record.
            self.file = file

        def emit(self, record):
            try:
                msg = self.format(record)
                stream = self.stream
                fs = "%s\n"
                # Stream to the original stream and then flush
                stream.write(fs % msg)
                stream.flush()
                # Stream to the file and then flush
                self.file.write(fs % msg)
                self.file.flush()
            except Exception:
                # Fix: the exception was previously bound to an unused `e`.
                # Delegate to logging's standard error handling.
                self.handleError(record)

        def isatty(self):
            # Mimic the isatty method usually found in file-like objects
            return self.stream.isatty()

    # Replace the builtin print so only the master rank emits output;
    # any rank can still opt in per-call with print(..., force=True).
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print

    if is_master:
        os.makedirs(logging_dir, exist_ok=True)
        # Pick the next unused log index so repeated runs never overwrite
        # an earlier run's logs in the same directory.
        existing_logs = glob.glob(os.path.join(logging_dir, 'log_out_*.txt'))
        log_numbers = [int(log.split('.txt')[0].split('_')[-1]) for log in existing_logs]
        next_log_number = max(log_numbers) + 1 if log_numbers else 1
        log_out_path = os.path.join(logging_dir, f'log_out_{next_log_number}.txt')
        log_err_path = os.path.join(logging_dir, f'log_err_{next_log_number}.txt')
        # NOTE(review): both handlers are attached to the root logger, so every
        # record is written to BOTH files (the split is by stream, not level) —
        # confirm this duplication is intended. Files are deliberately not
        # closed; they must outlive this function.
        logger_stdout = Logger(sys.stdout, open(log_out_path, 'w'))
        logger_stderr = Logger(sys.stderr, open(log_err_path, 'w'))
        logging.basicConfig(level=logging.DEBUG, handlers=[logger_stdout, logger_stderr])
        print(f"{COLOR_BLUE}stdout will be written to {log_out_path}{COLOR_RESET}")
        print(f"{COLOR_BLUE}stderr will be written to {log_err_path}{COLOR_RESET}")
def init_distributed_mode(args, timeout_minutes=15):
    """
    Initialise torch.distributed from the environment.

    Supports two launchers: torchrun-style (``RANK``/``WORLD_SIZE``/
    ``LOCAL_RANK`` already exported) and SLURM (rank and master address
    derived from scheduler variables). Populates ``args.rank``,
    ``args.world_size``, ``args.gpu``, ``args.dist_url``,
    ``args.dist_backend`` and ``args.distributed``; when neither launcher
    is detected, sets ``args.distributed = False`` and returns without
    creating a process group.

    Args:
        args: namespace mutated in place (must provide
            ``args.default_root_dir`` for log redirection).
        timeout_minutes: NCCL collective timeout, in minutes.
    """
    env = os.environ
    if 'RANK' in env and 'WORLD_SIZE' in env:
        # Launched via torchrun / torch.distributed.launch.
        args.rank = int(env["RANK"])
        args.world_size = int(env['WORLD_SIZE'])
        args.gpu = int(env['LOCAL_RANK'])
        args.dist_url = 'env://'
        env['LOCAL_SIZE'] = str(torch.cuda.device_count())
    elif 'SLURM_PROCID' in env:
        # Launched under SLURM: reconstruct the torchrun-style variables.
        slurm_rank = int(env['SLURM_PROCID'])
        slurm_tasks = int(env['SLURM_NTASKS'])
        gpus_per_node = torch.cuda.device_count()
        # First hostname in the node list acts as the rendezvous master.
        master_addr = subprocess.getoutput(
            'scontrol show hostname {} | head -n1'.format(env['SLURM_NODELIST']))
        env['MASTER_PORT'] = env.get('MASTER_PORT', '29500')
        env['MASTER_ADDR'] = master_addr
        env['WORLD_SIZE'] = str(slurm_tasks)
        env['RANK'] = str(slurm_rank)
        env['LOCAL_RANK'] = str(slurm_rank % gpus_per_node)
        env['LOCAL_SIZE'] = str(gpus_per_node)
        args.dist_url = 'env://'
        args.world_size = slurm_tasks
        args.rank = slurm_rank
        args.gpu = slurm_rank % gpus_per_node
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True
    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(
        backend=args.dist_backend,
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
        timeout=datetime.timedelta(minutes=timeout_minutes),
    )
    torch.distributed.barrier()
    # Silence non-zero ranks and tee rank-0 logs to files.
    setup_for_distributed(args.rank == 0, args.default_root_dir)
def _FSDP(model: torch.nn.Module, device, zero) -> FSDP:
    """
    Wrap ``model`` in FullyShardedDataParallel, auto-wrapping at the
    CViViT encoder/decoder module boundary.

    Args:
        model: module to wrap.
        device: CUDA device (id) for this rank.
        zero: ZeRO-stage selector — 1: HYBRID_SHARD, 2: SHARD_GRAD_OP,
            3: FULL_SHARD; any other value passes ``None`` and falls back
            to FSDP's default sharding strategy.

    Returns:
        The FSDP-wrapped model.
    """
    strategy_by_stage = {
        1: ShardingStrategy.HYBRID_SHARD,
        2: ShardingStrategy.SHARD_GRAD_OP,
        3: ShardingStrategy.FULL_SHARD,
    }
    # All comm/compute dtypes pinned to fp32: mixed precision effectively off.
    full_precision = MixedPrecision(
        param_dtype=torch.float,
        reduce_dtype=torch.float,
        buffer_dtype=torch.float,
    )
    wrapped = FSDP(
        model,
        auto_wrap_policy=ModuleWrapPolicy([CViViT_Encoder, CViViT_Decoder]),
        device_id=device,
        sharding_strategy=strategy_by_stage.get(zero),
        mixed_precision=full_precision,
        sync_module_states=True,
        limit_all_gathers=True,
        use_orig_params=True,
    )
    # Ensure sharding/device transfer work has finished before returning.
    torch.cuda.synchronize()
    return wrapped
def reduce_losses(loss_dict, dst=0):
    """
    SUM-reduce a dict of scalar loss tensors onto rank ``dst`` and average
    them there by world size.

    Args:
        loss_dict: mapping of loss name -> scalar tensor (stackable).
        dst: destination rank receiving the reduced values.

    Returns:
        On rank ``dst``: {name: averaged float}. On every other rank:
        {name: None} (reduce only delivers valid values to ``dst``).
    """
    names = list(loss_dict)
    stacked = torch.stack([loss_dict[n] for n in names])
    dist.reduce(stacked, dst=dst, op=dist.ReduceOp.SUM)
    # Non-destination ranks hold no meaningful result after reduce.
    if dist.get_rank() != dst:
        return {n: None for n in names}
    stacked /= dist.get_world_size()
    return {n: val.item() for n, val in zip(names, stacked)}
@rank_zero_only
def average_losses(loss_dict_list):
    """
    Average a list of loss dicts key-wise; keys may differ across dicts
    (each key is divided by its own occurrence count).

    Decorated with the project's ``rank_zero_only`` — presumably a no-op
    off rank 0; verify against its definition.

    Args:
        loss_dict_list: iterable of {loss name: value} dicts.

    Returns:
        {loss name: mean value over the dicts that contained it}.
    """
    totals = {}
    counts = {}
    for losses in loss_dict_list:
        for name, value in losses.items():
            if name not in totals:
                totals[name] = value
                counts[name] = 1
            else:
                # += preserved: may accumulate in place for tensor values.
                totals[name] += value
                counts[name] += 1
    return {name: totals[name] / counts[name] for name in totals}
|