# (Hugging Face upload artifact, kept as a comment so the file parses)
# BryanW's picture
# Upload folder using huggingface_hub
# 3d1c0e1 verified
# Copyright (c) 2025 FoundationVision
# SPDX-License-Identifier: MIT
# from https://github.com/FoundationVision/LlamaGen/blob/main/utils/distributed.py
import os
import sys
import glob
import torch
import subprocess
import torch.distributed as dist
import datetime
import logging
from infinity.models.videovae.utils.misc import rank_zero_only, COLOR_BLUE, COLOR_RESET
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
ShardingStrategy,
MixedPrecision,
)
from infinity.models.videovae.models.cvivit_vqgan import CViViT_Decoder, CViViT_Encoder
def setup_for_distributed(is_master, logging_dir=""):
    """Configure per-rank console output and rank-0 file logging.

    On every rank the builtin ``print`` is replaced so that only the master
    rank emits output (pass ``force=True`` to print from any rank). On the
    master rank, ``logging`` output is additionally teed into numbered
    ``log_out_N.txt`` / ``log_err_N.txt`` files under *logging_dir*.

    NOTE(review): despite the historical docstring, ``sys.stdout`` /
    ``sys.stderr`` are NOT themselves redirected — only records routed
    through the ``logging`` module are duplicated to the files.

    Args:
        is_master: True on the rank that should keep printing and own the
            log files (typically rank 0).
        logging_dir: directory for the log files; created if missing.
            Only used when ``is_master`` is True.
    """
    import builtins as __builtin__

    class Logger(logging.StreamHandler):
        """StreamHandler that tees each record to an extra file object."""

        def __init__(self, stream, file):
            super().__init__(stream)
            # File handle is deliberately kept open for the life of the
            # process — the handler owns it.
            self.file = file

        def emit(self, record):
            try:
                msg = self.format(record)
                fs = "%s\n"
                # Write to the original stream, then to the file, flushing
                # both so output survives an abrupt crash.
                self.stream.write(fs % msg)
                self.stream.flush()
                self.file.write(fs % msg)
                self.file.flush()
            except Exception:
                # Logging handlers must never raise; delegate to the
                # standard error-reporting hook. (Bug fix: the caught
                # exception object was previously bound but unused.)
                self.handleError(record)

        def isatty(self):
            # Mimic the isatty method usually found in file-like objects.
            return self.stream.isatty()

    # Replace the builtin print so only the master rank (or force=True)
    # actually emits anything.
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print

    if is_master:
        os.makedirs(logging_dir, exist_ok=True)
        existing_logs = glob.glob(os.path.join(logging_dir, 'log_out_*.txt'))
        # Pick the next free log index. Skip files whose trailing token is
        # not an integer instead of crashing (bug fix: int() used to raise
        # on any unexpected matching filename).
        log_numbers = []
        for log in existing_logs:
            suffix = log.split('.txt')[0].split('_')[-1]
            if suffix.isdigit():
                log_numbers.append(int(suffix))
        next_log_number = max(log_numbers) + 1 if log_numbers else 1
        log_out_path = os.path.join(logging_dir, f'log_out_{next_log_number}.txt')
        log_err_path = os.path.join(logging_dir, f'log_err_{next_log_number}.txt')
        logger_stdout = Logger(sys.stdout, open(log_out_path, 'w'))
        logger_stderr = Logger(sys.stderr, open(log_err_path, 'w'))
        logging.basicConfig(level=logging.DEBUG, handlers=[logger_stdout, logger_stderr])
        print(f"{COLOR_BLUE}stdout will be written to {log_out_path}{COLOR_RESET}")
        print(f"{COLOR_BLUE}stderr will be written to {log_err_path}{COLOR_RESET}")
def init_distributed_mode(args, timeout_minutes=15):
    """Initialize torch.distributed (NCCL) from the launch environment.

    Supports two launch modes:
      * torchrun / torch.distributed.launch: RANK / WORLD_SIZE / LOCAL_RANK
        are already exported.
      * SLURM: rank and world size are derived from SLURM_* variables and
        the master address is resolved via ``scontrol``.
    If neither is detected, falls back to non-distributed mode
    (``args.distributed = False``) and returns early.

    Mutates ``args`` in place: sets ``rank``, ``world_size``, ``gpu``,
    ``dist_url``, ``distributed`` and ``dist_backend``; also reads
    ``args.default_root_dir`` as the log directory. Blocks on a barrier
    until all ranks have joined the process group.

    Args:
        args: namespace-like object (argparse Namespace or similar) that
            must expose ``default_root_dir``.
        timeout_minutes: collective-op timeout passed to
            ``init_process_group``.
    """
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        # Launched by torchrun / torch.distributed.launch.
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
        args.dist_url = 'env://'
        os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count())
    elif 'SLURM_PROCID' in os.environ:
        # Launched under SLURM: reconstruct the torchrun-style env vars.
        proc_id = int(os.environ['SLURM_PROCID'])
        ntasks = int(os.environ['SLURM_NTASKS'])
        node_list = os.environ['SLURM_NODELIST']
        num_gpus = torch.cuda.device_count()
        # The first hostname in the node list is used as rendezvous master.
        addr = subprocess.getoutput(
            'scontrol show hostname {} | head -n1'.format(node_list))
        os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '29500')
        os.environ['MASTER_ADDR'] = addr
        os.environ['WORLD_SIZE'] = str(ntasks)
        os.environ['RANK'] = str(proc_id)
        os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
        os.environ['LOCAL_SIZE'] = str(num_gpus)
        args.dist_url = 'env://'
        args.world_size = ntasks
        args.rank = proc_id
        # Assume one task per GPU, tasks packed per node.
        args.gpu = proc_id % num_gpus
    else:
        print('Not using distributed mode')
        args.distributed = False
        return
    args.distributed = True
    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank,
                                         timeout=datetime.timedelta(seconds=timeout_minutes * 60)
                                         )
    # Synchronize all ranks, then silence print on non-zero ranks and set up
    # rank-0 file logging under args.default_root_dir.
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0, args.default_root_dir)
def _FSDP(model: torch.nn.Module, device, zero) -> FSDP:
    """Wrap *model* in FullyShardedDataParallel for this rank.

    Args:
        model: module to shard; CViViT_Encoder / CViViT_Decoder submodules
            are each wrapped as their own FSDP unit via ModuleWrapPolicy.
        device: CUDA device (id) that this rank's shards live on.
        zero: ZeRO-style stage selector — 1 -> HYBRID_SHARD,
            2 -> SHARD_GRAD_OP, 3 -> FULL_SHARD.
            NOTE(review): any other value silently maps to None, i.e.
            FSDP's default strategy (FULL_SHARD) — confirm this fallback
            is intended rather than raising.

    Returns:
        The FSDP-wrapped model.
    """
    model = FSDP(
        model,
        auto_wrap_policy=ModuleWrapPolicy([CViViT_Encoder, CViViT_Decoder]),
        device_id=device,
        sharding_strategy={1:ShardingStrategy.HYBRID_SHARD, 2:ShardingStrategy.SHARD_GRAD_OP, 3:ShardingStrategy.FULL_SHARD}.get(zero),
        # Everything stays in fp32; FSDP is used here purely for sharding,
        # not for mixed-precision training.
        mixed_precision=MixedPrecision(
            param_dtype=torch.float,
            reduce_dtype=torch.float,
            buffer_dtype=torch.float,
        ),
        # Broadcast rank-0 module states so all ranks start identical.
        sync_module_states=True,
        # Throttle all-gathers to bound peak memory.
        limit_all_gathers=True,
        use_orig_params=True,
    )
    torch.cuda.synchronize()
    return model
def reduce_losses(loss_dict, dst=0):
    """Sum each loss tensor across all ranks and average on rank *dst*.

    Args:
        loss_dict: mapping of loss name -> scalar tensor on this rank.
        dst: destination rank that receives the averaged values.

    Returns:
        On rank *dst*: dict of loss name -> averaged Python float.
        On every other rank: dict of loss name -> None.
    """
    names = list(loss_dict)
    stacked = torch.stack([loss_dict[name] for name in names])
    dist.reduce(stacked, dst=dst, op=dist.ReduceOp.SUM)
    # Non-destination ranks hold partial sums only; report nothing there.
    if dist.get_rank() != dst:
        return dict.fromkeys(names)
    stacked /= dist.get_world_size()
    return {name: value.item() for name, value in zip(names, stacked)}
@rank_zero_only
def average_losses(loss_dict_list):
    """Average each loss key over a list of per-step loss dicts (rank 0 only).

    Keys that appear in only some dicts are averaged over the dicts that
    actually contain them.

    Args:
        loss_dict_list: iterable of dicts mapping loss name -> value.

    Returns:
        Dict mapping each loss name to its mean value.
    """
    totals = {}
    counts = {}
    for step_losses in loss_dict_list:
        for name, value in step_losses.items():
            if name not in totals:
                # First occurrence: seed the accumulators.
                totals[name] = value
                counts[name] = 1
            else:
                totals[name] += value
                counts[name] += 1
    return {name: totals[name] / counts[name] for name in totals}