# Copyright (c) 2025 FoundationVision # SPDX-License-Identifier: MIT import torch import torch.distributed as dist import imageio import os import random import math import numpy as np from einops import rearrange import torch.optim as optim import torch.optim.lr_scheduler as lr_scheduler import sys import pdb as pdb_original from contextlib import contextmanager COLOR_BLUE = "\033[94m" COLOR_RESET = "\033[0m" ptdtype = {None: torch.float32, 'fp32': torch.float32, 'bf16': torch.bfloat16} def rank_zero_only(fn): def wrapped_fn(*args, **kwargs): if not dist.is_initialized() or dist.get_rank() == 0: return fn(*args, **kwargs) return wrapped_fn @rank_zero_only def print_gpu_usage(model_name) -> None: allocated_memory = torch.cuda.memory_allocated() reserved_memory = torch.cuda.memory_reserved() print(f"after {model_name} backward Allocated Memory: {allocated_memory}, Reserved Memory: {reserved_memory}") torch.cuda.empty_cache() def seed_everything(seed=0, allow_tf32=True, benchmark=True, deterministic=False): random.seed(seed) np.random.seed(seed) os.environ['PYTHONHASHSEED'] = str(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.backends.cudnn.deterministic = deterministic torch.backends.cudnn.benchmark = benchmark # default False in torch 2.3.1 # See https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' # See https://pytorch.org/docs/stable/notes/randomness.html torch.use_deterministic_algorithms(deterministic) torch.backends.cudnn.allow_tf32 = allow_tf32 # default True in torch 2.3.1 torch.backends.cuda.matmul.allow_tf32 = allow_tf32 # default True in torch 2.3.1 # Function to print model summary in table format @rank_zero_only def print_model_summary(models): # Table headers print(f"{'Layer Name':<20} {'Param #':<20}") print("="*40) total_params = 0 for model in models: for name, module in model.named_children(): params = sum(p.numel() for p in module.parameters()) total_params += params params_str = f"{params/1e6:.2f}M" print(f"{name:<20} {params_str:<20}") print("="*40) print(f"Total number of parameters: {total_params/1e6:.2f}M") def version_checker(base_version, high_version): try: from bytedance.ndtimeline import __version__ from packaging.version import Version if Version(__version__) < Version(base_version) or Version(__version__) >= Version(high_version): raise RuntimeError(f"bytedance.ndtimeline's version should be >={base_version} <{high_version}, but {__version__} found") except ImportError: raise RuntimeError(f"bytedance.ndtimeline's version should be >={base_version} <{high_version}") def is_torch_optim_sch(obj): return isinstance(obj, (optim.Optimizer, optim.lr_scheduler.LambdaLR)) def rearranged_forward(x, func): if x.ndim == 4: x = rearrange(x, "B C H W -> B H W C") elif x.ndim == 5: x = rearrange(x, "B C T H W -> B T H W C") x = func(x) if x.ndim == 4: x = rearrange(x, "B H W C -> B C H W") elif x.ndim == 5: x = rearrange(x, "B T H W C -> B C T H W") return x def is_dtype_16(data): return data.dtype == torch.float16 or data.dtype == torch.bfloat16 @contextmanager def set_tf32_flags(flag): old_matmul_flag = torch.backends.cuda.matmul.allow_tf32 old_cudnn_flag = torch.backends.cudnn.allow_tf32 torch.backends.cuda.matmul.allow_tf32 = flag torch.backends.cudnn.allow_tf32 = flag try: yield finally: # Restore the original flags torch.backends.cuda.matmul.allow_tf32 = old_matmul_flag torch.backends.cudnn.allow_tf32 = old_cudnn_flag class ByteNASManager: bytenas_dir = { } _current_bytenas = None _username = None @classmethod def set_bytenas(cls, bytenas, username="zhufengda"): cls._current_bytenas = bytenas cls._username = username @classmethod def get_work_dir(cls, use_username=True): if use_username: username = cls._username else: username = "" base_dir = cls.bytenas_dir[cls._current_bytenas] return os.path.join(base_dir, username) @classmethod def __call__(cls, rel_path, use_username=True, prefix=""): return os.path.join(cls.get_work_dir(use_username=use_username), prefix, rel_path) bytenas_manager = ByteNASManager() def get_last_ckpt(root_dir): if not os.path.exists(root_dir): return None ckpt_files = {} for dirpath, dirnames, filenames in os.walk(root_dir): for filename in filenames: if filename.endswith('.ckpt'): num_iter = int(filename.split('.ckpt')[0].split('_')[-1]) ckpt_files[num_iter]=os.path.join(dirpath, filename) iter_list = list(ckpt_files.keys()) if len(iter_list) == 0: return None max_iter = max(iter_list) return ckpt_files[max_iter] # Shifts src_tf dim to dest dim # i.e. shift_dim(x, 1, -1) would be (b, c, t, h, w) -> (b, t, h, w, c) def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True): n_dims = len(x.shape) if src_dim < 0: src_dim = n_dims + src_dim if dest_dim < 0: dest_dim = n_dims + dest_dim assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims dims = list(range(n_dims)) del dims[src_dim] permutation = [] ctr = 0 for i in range(n_dims): if i == dest_dim: permutation.append(src_dim) else: permutation.append(dims[ctr]) ctr += 1 x = x.permute(permutation) if make_contiguous: x = x.contiguous() return x # reshapes tensor start from dim i (inclusive) # to dim j (exclusive) to the desired shape # e.g. if x.shape = (b, thw, c) then # view_range(x, 1, 2, (t, h, w)) returns # x of shape (b, t, h, w, c) def view_range(x, i, j, shape): shape = tuple(shape) n_dims = len(x.shape) if i < 0: i = n_dims + i if j is None: j = n_dims elif j < 0: j = n_dims + j assert 0 <= i < j <= n_dims x_shape = x.shape target_shape = x_shape[:i] + shape + x_shape[j:] return x.view(target_shape) def accuracy(output, target, topk=(1,)): """Computes the accuracy over the k top predictions for the specified values of k""" with torch.no_grad(): maxk = max(topk) batch_size = target.size(0) _, pred = output.topk(maxk, 1, True, True) pred = pred.t() correct = pred.eq(target.reshape(1, -1).expand_as(pred)) res = [] for k in topk: correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) res.append(correct_k.mul_(100.0 / batch_size)) return res def tensor_slice(x, begin, size): assert all([b >= 0 for b in begin]) size = [l - b if s == -1 else s for s, b, l in zip(size, begin, x.shape)] assert all([s >= 0 for s in size]) slices = [slice(b, b + s) for b, s in zip(begin, size)] return x[slices] def save_video_grid(video, fname, nrow=None, fps=16): b, c, t, h, w = video.shape video = video.permute(0, 2, 3, 4, 1).contiguous() video = (video.detach().cpu().numpy() * 255).astype('uint8') if nrow is None: nrow = math.ceil(math.sqrt(b)) ncol = math.ceil(b / nrow) padding = 1 video_grid = np.zeros((t, (padding + h) * nrow + padding, (padding + w) * ncol + padding, c), dtype='uint8') # print(video_grid.shape) for i in range(b): r = i // ncol c = i % ncol start_r = (padding + h) * r start_c = (padding + w) * c video_grid[:, start_r:start_r + h, start_c:start_c + w] = video[i] video = [] for i in range(t): video.append(video_grid[i]) imageio.mimsave(fname, video, fps=fps) # skvideo.io.vwrite(fname, video_grid, inputdict={'-r': '5'}) # print('saved videos to', fname) def comp_getattr(args, attr_name, default=None): if hasattr(args, attr_name): return getattr(args, attr_name) else: return default def visualize_tensors(t, name=None, nest=0): if name is not None: print(name, "current nest: ", nest) print("type: ", type(t)) if 'dict' in str(type(t)): print(t.keys()) for k in t.keys(): if t[k] is None: print(k, "None") else: if 'Tensor' in str(type(t[k])): print(k, t[k].shape) elif 'dict' in str(type(t[k])): print(k, 'dict') visualize_tensors(t[k], name, nest + 1) elif 'list' in str(type(t[k])): print(k, len(t[k])) visualize_tensors(t[k], name, nest + 1) elif 'list' in str(type(t)): print("list length: ", len(t)) for t2 in t: visualize_tensors(t2, name, nest + 1) elif 'Tensor' in str(type(t)): print(t.shape) else: print(t) return ""