|
|
|
|
|
|
|
|
|
|
|
import functools
import math
import os
import pdb as pdb_original
import random
import sys
from contextlib import contextmanager

import imageio
import numpy as np
import torch
import torch.distributed as dist
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from einops import rearrange
|
|
|
|
|
COLOR_BLUE = "\033[94m" |
|
|
COLOR_RESET = "\033[0m" |
|
|
ptdtype = {None: torch.float32, 'fp32': torch.float32, 'bf16': torch.bfloat16} |
|
|
|
|
|
def rank_zero_only(fn): |
|
|
def wrapped_fn(*args, **kwargs): |
|
|
if not dist.is_initialized() or dist.get_rank() == 0: |
|
|
return fn(*args, **kwargs) |
|
|
return wrapped_fn |
|
|
|
|
|
@rank_zero_only |
|
|
def print_gpu_usage(model_name) -> None: |
|
|
allocated_memory = torch.cuda.memory_allocated() |
|
|
reserved_memory = torch.cuda.memory_reserved() |
|
|
print(f"after {model_name} backward Allocated Memory: {allocated_memory}, Reserved Memory: {reserved_memory}") |
|
|
torch.cuda.empty_cache() |
|
|
|
|
|
def seed_everything(seed=0, allow_tf32=True, benchmark=True, deterministic=False): |
|
|
random.seed(seed) |
|
|
np.random.seed(seed) |
|
|
os.environ['PYTHONHASHSEED'] = str(seed) |
|
|
torch.manual_seed(seed) |
|
|
torch.cuda.manual_seed_all(seed) |
|
|
|
|
|
torch.backends.cudnn.deterministic = deterministic |
|
|
torch.backends.cudnn.benchmark = benchmark |
|
|
|
|
|
|
|
|
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' |
|
|
|
|
|
torch.use_deterministic_algorithms(deterministic) |
|
|
|
|
|
torch.backends.cudnn.allow_tf32 = allow_tf32 |
|
|
torch.backends.cuda.matmul.allow_tf32 = allow_tf32 |
|
|
|
|
|
|
|
|
@rank_zero_only |
|
|
def print_model_summary(models): |
|
|
|
|
|
print(f"{'Layer Name':<20} {'Param #':<20}") |
|
|
print("="*40) |
|
|
total_params = 0 |
|
|
for model in models: |
|
|
for name, module in model.named_children(): |
|
|
params = sum(p.numel() for p in module.parameters()) |
|
|
total_params += params |
|
|
params_str = f"{params/1e6:.2f}M" |
|
|
print(f"{name:<20} {params_str:<20}") |
|
|
print("="*40) |
|
|
print(f"Total number of parameters: {total_params/1e6:.2f}M") |
|
|
|
|
|
def version_checker(base_version, high_version): |
|
|
try: |
|
|
from bytedance.ndtimeline import __version__ |
|
|
from packaging.version import Version |
|
|
if Version(__version__) < Version(base_version) or Version(__version__) >= Version(high_version): |
|
|
raise RuntimeError(f"bytedance.ndtimeline's version should be >={base_version} <{high_version}, but {__version__} found") |
|
|
except ImportError: |
|
|
raise RuntimeError(f"bytedance.ndtimeline's version should be >={base_version} <{high_version}") |
|
|
|
|
|
def is_torch_optim_sch(obj): |
|
|
return isinstance(obj, (optim.Optimizer, optim.lr_scheduler.LambdaLR)) |
|
|
|
|
|
def rearranged_forward(x, func): |
|
|
if x.ndim == 4: |
|
|
x = rearrange(x, "B C H W -> B H W C") |
|
|
elif x.ndim == 5: |
|
|
x = rearrange(x, "B C T H W -> B T H W C") |
|
|
x = func(x) |
|
|
if x.ndim == 4: |
|
|
x = rearrange(x, "B H W C -> B C H W") |
|
|
elif x.ndim == 5: |
|
|
x = rearrange(x, "B T H W C -> B C T H W") |
|
|
return x |
|
|
|
|
|
def is_dtype_16(data): |
|
|
return data.dtype == torch.float16 or data.dtype == torch.bfloat16 |
|
|
|
|
|
@contextmanager |
|
|
def set_tf32_flags(flag): |
|
|
old_matmul_flag = torch.backends.cuda.matmul.allow_tf32 |
|
|
old_cudnn_flag = torch.backends.cudnn.allow_tf32 |
|
|
torch.backends.cuda.matmul.allow_tf32 = flag |
|
|
torch.backends.cudnn.allow_tf32 = flag |
|
|
try: |
|
|
yield |
|
|
finally: |
|
|
|
|
|
torch.backends.cuda.matmul.allow_tf32 = old_matmul_flag |
|
|
torch.backends.cudnn.allow_tf32 = old_cudnn_flag |
|
|
|
|
|
class ByteNASManager: |
|
|
bytenas_dir = { |
|
|
|
|
|
} |
|
|
_current_bytenas = None |
|
|
_username = None |
|
|
|
|
|
@classmethod |
|
|
def set_bytenas(cls, bytenas, username="zhufengda"): |
|
|
cls._current_bytenas = bytenas |
|
|
cls._username = username |
|
|
|
|
|
@classmethod |
|
|
def get_work_dir(cls, use_username=True): |
|
|
if use_username: |
|
|
username = cls._username |
|
|
else: |
|
|
username = "" |
|
|
base_dir = cls.bytenas_dir[cls._current_bytenas] |
|
|
return os.path.join(base_dir, username) |
|
|
|
|
|
@classmethod |
|
|
def __call__(cls, rel_path, use_username=True, prefix=""): |
|
|
return os.path.join(cls.get_work_dir(use_username=use_username), prefix, rel_path) |
|
|
|
|
|
bytenas_manager = ByteNASManager() |
|
|
|
|
|
def get_last_ckpt(root_dir): |
|
|
if not os.path.exists(root_dir): return None |
|
|
ckpt_files = {} |
|
|
for dirpath, dirnames, filenames in os.walk(root_dir): |
|
|
for filename in filenames: |
|
|
if filename.endswith('.ckpt'): |
|
|
num_iter = int(filename.split('.ckpt')[0].split('_')[-1]) |
|
|
ckpt_files[num_iter]=os.path.join(dirpath, filename) |
|
|
iter_list = list(ckpt_files.keys()) |
|
|
if len(iter_list) == 0: return None |
|
|
max_iter = max(iter_list) |
|
|
return ckpt_files[max_iter] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True): |
|
|
n_dims = len(x.shape) |
|
|
if src_dim < 0: |
|
|
src_dim = n_dims + src_dim |
|
|
if dest_dim < 0: |
|
|
dest_dim = n_dims + dest_dim |
|
|
|
|
|
assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims |
|
|
|
|
|
dims = list(range(n_dims)) |
|
|
del dims[src_dim] |
|
|
|
|
|
permutation = [] |
|
|
ctr = 0 |
|
|
for i in range(n_dims): |
|
|
if i == dest_dim: |
|
|
permutation.append(src_dim) |
|
|
else: |
|
|
permutation.append(dims[ctr]) |
|
|
ctr += 1 |
|
|
x = x.permute(permutation) |
|
|
if make_contiguous: |
|
|
x = x.contiguous() |
|
|
return x |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def view_range(x, i, j, shape): |
|
|
shape = tuple(shape) |
|
|
|
|
|
n_dims = len(x.shape) |
|
|
if i < 0: |
|
|
i = n_dims + i |
|
|
|
|
|
if j is None: |
|
|
j = n_dims |
|
|
elif j < 0: |
|
|
j = n_dims + j |
|
|
|
|
|
assert 0 <= i < j <= n_dims |
|
|
|
|
|
x_shape = x.shape |
|
|
target_shape = x_shape[:i] + shape + x_shape[j:] |
|
|
return x.view(target_shape) |
|
|
|
|
|
|
|
|
def accuracy(output, target, topk=(1,)): |
|
|
"""Computes the accuracy over the k top predictions for the specified values of k""" |
|
|
with torch.no_grad(): |
|
|
maxk = max(topk) |
|
|
batch_size = target.size(0) |
|
|
|
|
|
_, pred = output.topk(maxk, 1, True, True) |
|
|
pred = pred.t() |
|
|
correct = pred.eq(target.reshape(1, -1).expand_as(pred)) |
|
|
|
|
|
res = [] |
|
|
for k in topk: |
|
|
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) |
|
|
res.append(correct_k.mul_(100.0 / batch_size)) |
|
|
return res |
|
|
|
|
|
|
|
|
def tensor_slice(x, begin, size): |
|
|
assert all([b >= 0 for b in begin]) |
|
|
size = [l - b if s == -1 else s |
|
|
for s, b, l in zip(size, begin, x.shape)] |
|
|
assert all([s >= 0 for s in size]) |
|
|
|
|
|
slices = [slice(b, b + s) for b, s in zip(begin, size)] |
|
|
return x[slices] |
|
|
|
|
|
|
|
|
def save_video_grid(video, fname, nrow=None, fps=16): |
|
|
b, c, t, h, w = video.shape |
|
|
video = video.permute(0, 2, 3, 4, 1).contiguous() |
|
|
|
|
|
video = (video.detach().cpu().numpy() * 255).astype('uint8') |
|
|
if nrow is None: |
|
|
nrow = math.ceil(math.sqrt(b)) |
|
|
ncol = math.ceil(b / nrow) |
|
|
padding = 1 |
|
|
video_grid = np.zeros((t, (padding + h) * nrow + padding, |
|
|
(padding + w) * ncol + padding, c), dtype='uint8') |
|
|
|
|
|
for i in range(b): |
|
|
r = i // ncol |
|
|
c = i % ncol |
|
|
start_r = (padding + h) * r |
|
|
start_c = (padding + w) * c |
|
|
video_grid[:, start_r:start_r + h, start_c:start_c + w] = video[i] |
|
|
video = [] |
|
|
for i in range(t): |
|
|
video.append(video_grid[i]) |
|
|
imageio.mimsave(fname, video, fps=fps) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def comp_getattr(args, attr_name, default=None): |
|
|
if hasattr(args, attr_name): |
|
|
return getattr(args, attr_name) |
|
|
else: |
|
|
return default |
|
|
|
|
|
|
|
|
def visualize_tensors(t, name=None, nest=0): |
|
|
if name is not None: |
|
|
print(name, "current nest: ", nest) |
|
|
print("type: ", type(t)) |
|
|
if 'dict' in str(type(t)): |
|
|
print(t.keys()) |
|
|
for k in t.keys(): |
|
|
if t[k] is None: |
|
|
print(k, "None") |
|
|
else: |
|
|
if 'Tensor' in str(type(t[k])): |
|
|
print(k, t[k].shape) |
|
|
elif 'dict' in str(type(t[k])): |
|
|
print(k, 'dict') |
|
|
visualize_tensors(t[k], name, nest + 1) |
|
|
elif 'list' in str(type(t[k])): |
|
|
print(k, len(t[k])) |
|
|
visualize_tensors(t[k], name, nest + 1) |
|
|
elif 'list' in str(type(t)): |
|
|
print("list length: ", len(t)) |
|
|
for t2 in t: |
|
|
visualize_tensors(t2, name, nest + 1) |
|
|
elif 'Tensor' in str(type(t)): |
|
|
print(t.shape) |
|
|
else: |
|
|
print(t) |
|
|
return "" |
|
|
|