| import os |
|
|
| import torch |
|
|
| import os |
| import math |
| import torch |
| import logging |
| import random |
| import subprocess |
| import numpy as np |
| import torch.distributed as dist |
|
|
| |
| from torch import inf |
| from PIL import Image |
| from typing import Union, Iterable |
| from collections import OrderedDict |
| from torch.utils.tensorboard import SummaryWriter |
|
|
| from diffusers.utils import is_bs4_available, is_ftfy_available |
|
|
| import html |
| import re |
| import urllib.parse as ul |
|
|
| if is_bs4_available(): |
| from bs4 import BeautifulSoup |
|
|
| if is_ftfy_available(): |
| import ftfy |
|
|
# Type alias: either a single tensor or any iterable of tensors. Used to
# annotate the `parameters` argument of get_grad_norm and clip_grad_norm_.
| _tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]] |
|
|
def find_model(model_name):
    """
    Load the model weights stored in a local Latte checkpoint file.

    Nothing is downloaded: ``model_name`` must be the path of an existing
    file, otherwise the assertion below fails. The file is deserialized with
    ``torch.load`` using an identity ``map_location`` (each storage is
    returned exactly as it was deserialized), and the entry stored under the
    ``'model'`` key of the checkpoint is returned.
    """
    checkpoint_exists = os.path.isfile(model_name)
    assert checkpoint_exists, f'Could not find Latte checkpoint at {model_name}'

    def _keep_storage(storage, loc):
        # Identity mapping: hand the deserialized storage back untouched.
        return storage

    state = torch.load(model_name, map_location=_keep_storage)
    print('Using model!')
    return state['model']
|
|
| |
| |
| |
|
|
def get_grad_norm(
        parameters: _tensor_or_tensors, norm_type: float = 2.0) -> torch.Tensor:
    r"""
    Compute the total gradient norm of an iterable of parameters.

    Adapted from torch.nn.utils.clip_grad_norm_, but only measures the norm:
    the gradients are never modified.

    The norm is computed over all gradients together, as if they were
    concatenated into a single vector. Parameters whose ``.grad`` is ``None``
    are ignored.

    Args:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor whose gradients are measured
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.

    Returns:
        Total norm of the parameter gradients (viewed as a single vector).
        A zero scalar tensor is returned when no parameter has a gradient.
    """
    params = [parameters] if isinstance(parameters, torch.Tensor) else parameters
    grad_list = [p.grad for p in params if p.grad is not None]
    norm_type = float(norm_type)
    if not grad_list:
        return torch.tensor(0.)

    # Every per-gradient norm is moved to the device of the first gradient
    # so they can be stacked into one tensor.
    target = grad_list[0].device
    if norm_type == inf:
        peaks = [g.detach().abs().max().to(target) for g in grad_list]
        if len(peaks) == 1:
            return peaks[0]
        return torch.max(torch.stack(peaks))

    per_grad = torch.stack(
        [torch.norm(g.detach(), norm_type).to(target) for g in grad_list])
    return torch.norm(per_grad, norm_type)
|
|
|
|
def clip_grad_norm_(
        parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0,
        error_if_nonfinite: bool = False, clip_grad=True) -> torch.Tensor:
    r"""
    Compute, and optionally clip, the gradient norm of an iterable of parameters.

    Adapted from torch.nn.utils.clip_grad_norm_, with an extra ``clip_grad``
    switch.

    The norm is computed over all gradients together, as if they were
    concatenated into a single vector. When clipping is enabled the gradients
    are scaled in-place.

    Args:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        max_norm (float or int): max norm of the gradients
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.
        error_if_nonfinite (bool): if True, an error is thrown if the total
            norm of the gradients from :attr:`parameters` is ``nan``,
            ``inf``, or ``-inf``. Only checked when ``clip_grad`` is True.
            Default: False
        clip_grad (bool): if False, the norm is only measured and the
            gradients are left untouched. Default: True

    Returns:
        Total norm of the parameter gradients (viewed as a single vector),
        measured before any clipping. A zero scalar tensor is returned when
        no parameter has a gradient.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grad_list = [p.grad for p in parameters if p.grad is not None]
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    if not grad_list:
        return torch.tensor(0.)

    # Per-gradient norms are gathered on the device of the first gradient.
    target = grad_list[0].device
    if norm_type == inf:
        peaks = [g.detach().abs().max().to(target) for g in grad_list]
        total_norm = peaks[0] if len(peaks) == 1 else torch.max(torch.stack(peaks))
    else:
        per_grad = torch.stack(
            [torch.norm(g.detach(), norm_type).to(target) for g in grad_list])
        total_norm = torch.norm(per_grad, norm_type)

    if not clip_grad:
        return total_norm

    non_finite = torch.logical_or(total_norm.isnan(), total_norm.isinf())
    if error_if_nonfinite and non_finite:
        raise RuntimeError(
            f'The total norm of order {norm_type} for gradients from '
            '`parameters` is non-finite, so it cannot be clipped. To disable '
            'this error and scale the gradients by the non-finite norm anyway, '
            'set `error_if_nonfinite=False`')

    # Clamping the coefficient at 1.0 means gradients are only ever scaled
    # down; the 1e-6 keeps the division defined for a zero norm.
    scale = torch.clamp(max_norm / (total_norm + 1e-6), max=1.0)
    for g in grad_list:
        g.detach().mul_(scale.to(g.device))
    return total_norm
|
|
|
|
def get_experiment_dir(root_dir, args):
    """
    Append run-configuration suffixes to an experiment directory name.

    Suffixes are added in this order, each only when the matching attribute
    of ``args`` is truthy: ``-Compile`` (``use_compile``), the upper-cased
    ``attention_mode``, ``-Gc`` (``gradient_checkpointing``) and the
    upper-cased ``mixed_precision``. ``-<max_image_size>`` is always appended
    last.

    Returns:
        ``root_dir`` followed by the collected suffixes.
    """
    suffixes = []
    if args.use_compile:
        suffixes.append('-Compile')
    if args.attention_mode:
        suffixes.append(f'-{args.attention_mode.upper()}')
    if args.gradient_checkpointing:
        suffixes.append('-Gc')
    if args.mixed_precision:
        suffixes.append(f'-{args.mixed_precision.upper()}')
    suffixes.append(f'-{args.max_image_size}')
    return root_dir + ''.join(suffixes)
|
|
def get_precision(args):
    """
    Translate ``args.mixed_precision`` into a torch dtype.

    ``"bf16"`` maps to ``torch.bfloat16`` and ``"fp16"`` to ``torch.float16``;
    any other value falls back to ``torch.float32``.
    """
    known_modes = (("bf16", torch.bfloat16), ("fp16", torch.float16))
    for tag, dtype in known_modes:
        if args.mixed_precision == tag:
            return dtype
    return torch.float32
|
|
| |
| |
| |
|
|
def create_logger(logging_dir):
    """
    Create the module logger for the current distributed rank.

    On rank 0 the root logging configuration is set up at INFO level with a
    timestamped format and two handlers: a ``logging.StreamHandler`` (its
    default stream) and a file handler writing to ``<logging_dir>/log.txt``.
    On every other rank the returned logger only gets a ``NullHandler``.

    Requires an initialized process group, since it calls ``dist.get_rank()``.
    """
    on_main_rank = dist.get_rank() == 0
    if on_main_rank:
        console = logging.StreamHandler()
        log_file = logging.FileHandler(f"{logging_dir}/log.txt")
        logging.basicConfig(
            level=logging.INFO,
            format='[%(asctime)s] %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S',
            handlers=[console, log_file],
        )
    logger = logging.getLogger(__name__)
    if not on_main_rank:
        logger.addHandler(logging.NullHandler())
    return logger
|
|
|
|
def create_tensorboard(tensorboard_dir):
    """
    Create the tensorboard writer that saves losses.

    Only rank 0 creates a ``SummaryWriter`` (writing to ``tensorboard_dir``).
    Every other rank gets ``None``; previously those ranks hit an
    ``UnboundLocalError`` because ``writer`` was never assigned.

    Requires an initialized process group, since it calls ``dist.get_rank()``.

    Returns:
        The ``SummaryWriter`` on rank 0, ``None`` on all other ranks.
    """
    if dist.get_rank() != 0:
        return None
    return SummaryWriter(tensorboard_dir)
|
|
|
|
def write_tensorboard(writer, *args):
    """
    Write one scalar (e.g. a loss value) to a tensorboard writer.

    Only rank 0 writes; other ranks return immediately. The first three
    positional values in ``args`` are forwarded, in order, to
    ``writer.add_scalar``; any further values are ignored.

    Requires an initialized process group, since it calls ``dist.get_rank()``.
    """
    if dist.get_rank() != 0:
        return
    tag, scalar_value, global_step = args[0], args[1], args[2]
    writer.add_scalar(tag, scalar_value, global_step)
|
|
|
|
| |
| |
| |
|
|
@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """
    Move the EMA model one step towards the current model.

    For every named parameter of ``model`` the parameter of the same name in
    ``ema_model`` is updated in-place to
    ``decay * ema + (1 - decay) * param``. Runs without gradient tracking.
    """
    averaged = dict(ema_model.named_parameters())
    current = dict(model.named_parameters())

    for key, source in current.items():
        target = averaged[key]
        target.mul_(decay)
        target.add_(source.data, alpha=1 - decay)
|
|
|
|
def requires_grad(model, flag=True):
    """
    Enable or disable gradient tracking for a whole model.

    Assigns ``flag`` to the ``requires_grad`` attribute of every parameter
    yielded by ``model.parameters()``.
    """
    for param in model.parameters():
        param.requires_grad = flag
|
|
|
|
def cleanup():
    """
    End DDP training by destroying the default process group.

    Does nothing when ``torch.distributed`` is unavailable or no process
    group has been initialized, so the function can be called
    unconditionally (for example from a ``finally`` block or after a run
    that failed before distributed setup completed) without raising.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
|
|
|
|
def setup_distributed(backend="nccl", port=None):
    """Initialize distributed training environment.

    Supports both SLURM and launcher-provided environments (e.g.
    torch.distributed.launch); see torch.distributed.init_process_group()
    for more details.

    Under SLURM (``SLURM_JOB_ID`` present) the rank and world size are read
    from ``SLURM_PROCID`` / ``SLURM_NTASKS`` and the ``MASTER_PORT``,
    ``MASTER_ADDR``, ``WORLD_SIZE``, ``LOCAL_RANK`` and ``RANK`` environment
    variables are filled in. Otherwise ``RANK`` and ``WORLD_SIZE`` must
    already be set by the launcher.

    Args:
        backend: backend name passed to ``dist.init_process_group``.
        port: master port to use under SLURM. When ``None``, an existing
            ``MASTER_PORT`` is kept, else ``29567 + <visible GPU count>`` is
            used. Ignored outside SLURM.
    """
    num_gpus = torch.cuda.device_count()

    if "SLURM_JOB_ID" in os.environ:
        rank = int(os.environ["SLURM_PROCID"])
        world_size = int(os.environ["SLURM_NTASKS"])
        node_list = os.environ["SLURM_NODELIST"]
        # The first hostname of the allocation acts as the master address.
        addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1")

        if port is not None:
            os.environ["MASTER_PORT"] = str(port)
        elif "MASTER_PORT" not in os.environ:
            os.environ["MASTER_PORT"] = str(29567 + num_gpus)
        if "MASTER_ADDR" not in os.environ:
            os.environ["MASTER_ADDR"] = addr
        os.environ["WORLD_SIZE"] = str(world_size)
        # Guard the modulo: with no visible GPU (e.g. a CPU-only run with the
        # gloo backend) `rank % 0` would raise ZeroDivisionError.
        local_rank = rank % num_gpus if num_gpus > 0 else 0
        os.environ["LOCAL_RANK"] = str(local_rank)
        os.environ["RANK"] = str(rank)
    else:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])

    dist.init_process_group(
        backend=backend,
        world_size=world_size,
        rank=rank,
    )
|
|
|
|
| |
| |
| |
|
|
def save_video_grid(video, nrow=None):
    """
    Tile a batch of videos into one grid video.

    ``video`` is unpacked as ``(b, t, h, w, c)``. The result is a zero-filled
    ``torch.uint8`` tensor of shape
    ``(t, (h + 1) * nrow + 1, (w + 1) * ncol + 1, c)`` into which clip ``i``
    is copied at grid row ``i // ncol`` and grid column ``i % ncol``, leaving
    a one-pixel gap between cells. Despite the name, nothing is written to
    disk; the grid is returned. The grid shape is printed as a side effect.

    Args:
        video: batch of videos with shape ``(b, t, h, w, c)``.
        nrow: number of grid rows; defaults to ``ceil(sqrt(b))``.

    Returns:
        The grid tensor described above.
    """
    batch, frames, height, width, channels = video.shape

    if nrow is None:
        nrow = math.ceil(math.sqrt(batch))
    ncol = math.ceil(batch / nrow)

    padding = 1
    cell_h = padding + height
    cell_w = padding + width
    grid_shape = (frames, cell_h * nrow + padding, cell_w * ncol + padding, channels)
    video_grid = torch.zeros(grid_shape, dtype=torch.uint8)

    print(video_grid.shape)
    for index in range(batch):
        row, col = divmod(index, ncol)
        top = cell_h * row
        left = cell_w * col
        video_grid[:, top:top + height, left:left + width] = video[index]

    return video_grid
|
|
|
|
| |
| |
| |
|
|
|
|
def collect_env():
    """
    Collect and print information about the running environment.

    Gathers the base environment info from ``mmcv.utils.collect_env``, adds
    the first seven characters of the git hash under the
    ``'MMClassification'`` key, and prints every ``name: value`` pair. Then
    prints ``torch.cuda.get_arch_list()`` and ``torch.version.cuda``.

    ``mmcv`` is imported lazily inside the function, so it is only required
    when this function is actually called.
    """
    from mmcv.utils import collect_env as collect_base_env
    from mmcv.utils import get_git_hash

    env_info = collect_base_env()
    env_info['MMClassification'] = get_git_hash()[:7]

    for name, val in env_info.items():
        print(f'{name}: {val}')

    print(torch.cuda.get_arch_list())
    print(torch.version.cuda)
|
|
|
|
| |
| |
| |
|
|
|
|
# Matches one or more consecutive "bad" punctuation characters: the listed
# symbols plus ) ( ] [ } { | \ / and *. Written as a single raw string so the
# backslashes reach the regex engine without relying on invalid string escape
# sequences (which newer Python versions warn about); the compiled pattern is
# identical to the previous concatenated form.
bad_punct_regex = re.compile(r'[#®•©™&@·º½¾¿¡§~\)\(\]\[\}\{\|\\/\*]{1,}')
|
|
def text_preprocessing(text):
    """
    Clean a caption by passing it through ``clean_caption`` twice in a row.

    Returns the result of the second pass.
    """
    for _ in range(2):
        text = clean_caption(text)
    return text
|
|
def basic_clean(text):
    """
    Repair text with ``ftfy.fix_text``, HTML-unescape it twice, and strip it.

    The double unescape also resolves entities that were escaped twice.
    Relies on the module-level ``ftfy`` import.
    """
    repaired = ftfy.fix_text(text)
    for _ in range(2):
        repaired = html.unescape(repaired)
    return repaired.strip()
|
|
def clean_caption(caption):
    """
    Normalize a raw caption through a fixed sequence of regex clean-ups.

    The caption is converted to ``str``, URL-unquoted, lower-cased, then
    stripped of URLs, HTML markup, @-handles, CJK character ranges, unwanted
    punctuation, ids, file names, size specs and a few boilerplate phrases;
    whitespace is collapsed and leading/trailing quotes and punctuation are
    trimmed. The order of the steps matters, since later patterns run on the
    output of earlier ones.

    Relies on the module-level ``BeautifulSoup`` import, on
    ``bad_punct_regex`` and on ``basic_clean`` (which uses ``ftfy``).

    Args:
        caption: the raw caption; any object, it is passed through ``str``.

    Returns:
        The cleaned, stripped caption string.
    """
    caption = str(caption)
    caption = ul.unquote_plus(caption)
    caption = caption.strip().lower()
    caption = re.sub('<person>', 'person', caption)

    # Remove URLs: first those introduced by "http(s):", then by "www:";
    # both patterns also match bare host names with a listed TLD.
    caption = re.sub(
        r'\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))',
        '', caption)
    caption = re.sub(
        r'\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))',
        '', caption)

    # Keep only the text content of any HTML markup.
    caption = BeautifulSoup(caption, features='html.parser').text

    # Remove @handles.
    caption = re.sub(r'@[\w\d]+\b', '', caption)

    # Remove characters in the code point ranges U+31C0-U+33FF, U+3400-U+4DFF
    # and U+4E00-U+9FFF (CJK-related Unicode blocks).
    caption = re.sub(r'[\u31c0-\u31ef]+', '', caption)
    caption = re.sub(r'[\u31f0-\u31ff]+', '', caption)
    caption = re.sub(r'[\u3200-\u32ff]+', '', caption)
    caption = re.sub(r'[\u3300-\u33ff]+', '', caption)
    caption = re.sub(r'[\u3400-\u4dbf]+', '', caption)
    caption = re.sub(r'[\u4dc0-\u4dff]+', '', caption)
    caption = re.sub(r'[\u4e00-\u9fff]+', '', caption)

    # Collapse every run of dash-like characters into a single "-".
    caption = re.sub(
        r'[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+',
        '-', caption)

    # Normalize typographic quotes to plain ASCII quotes.
    caption = re.sub(r'[`´«»“”¨]', '"', caption)
    caption = re.sub(r'[‘’]', "'", caption)

    # NOTE(review): '"?' matches every double quote (and the empty string),
    # so this removes ALL '"' characters, including the ones produced by the
    # normalization just above. Left unchanged; confirm this is intended.
    caption = re.sub(r'"?', '', caption)
    # Remove every ampersand.
    caption = re.sub(r'&', '', caption)

    # Replace dotted-quad numbers (IP-address-like) with a space.
    caption = re.sub(r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}', ' ', caption)

    # Remove a "d:dd" time followed by whitespace at the very end.
    caption = re.sub(r'\d:\d\d\s+$', '', caption)

    # Replace a literal backslash-n sequence with a space.
    caption = re.sub(r'\\n', ' ', caption)

    # Remove "#" followed by 1-3 digits, "#" followed by 5+ digits, and
    # standalone numbers of 6+ digits.
    caption = re.sub(r'#\d{1,3}\b', '', caption)
    caption = re.sub(r'#\d{5,}\b', '', caption)
    caption = re.sub(r'\b\d{6,}\b', '', caption)
    # Remove file names with the listed extensions.
    caption = re.sub(r'[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)', '', caption)

    # Collapse repeated quotes into one double quote and runs of dots into
    # a space.
    caption = re.sub(r'[\"\']{2,}', r'"', caption)
    caption = re.sub(r'[\.]{2,}', r' ', caption)

    # Replace unwanted punctuation and free-standing dots with a space.
    caption = re.sub(bad_punct_regex, r' ', caption)
    caption = re.sub(r'\s+\.\s+', r' ', caption)

    # When more than three "-" / "_" occur, treat them as word separators
    # and turn all of them into spaces.
    regex2 = re.compile(r'(?:\-|\_)')
    if len(re.findall(regex2, caption)) > 3:
        caption = re.sub(regex2, ' ', caption)

    caption = basic_clean(caption)

    # Remove id-like tokens that mix letters and digits.
    caption = re.sub(r'\b[a-zA-Z]{1,3}\d{3,15}\b', '', caption)
    caption = re.sub(r'\b[a-zA-Z]+\d+[a-zA-Z]+\b', '', caption)
    caption = re.sub(r'\b\d+[a-zA-Z]+\d+\b', '', caption)

    # Remove boilerplate phrases (shipping, download, "click for/on ...",
    # file-type words, page numbers).
    caption = re.sub(r'(worldwide\s+)?(free\s+)?shipping', '', caption)
    caption = re.sub(r'(free\s)?download(\sfree)?', '', caption)
    caption = re.sub(r'\bclick\b\s(?:for|on)\s\w+', '', caption)
    caption = re.sub(r'\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?', '', caption)
    caption = re.sub(r'\bpage\s+\d+\b', '', caption)

    # Replace longer alternating letter/digit tokens with a space.
    caption = re.sub(r'\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b', r' ', caption)

    # Remove "<number>x<number>" size specifications.
    caption = re.sub(r'\b\d+\.?\d*[xх×]\d+\.?\d*\b', '', caption)

    # Tidy spacing around ":" and after ",", "." or "/", then collapse
    # whitespace runs.
    caption = re.sub(r'\b\s+\:\s+', r': ', caption)
    caption = re.sub(r'(\D[,\./])\b', r'\1 ', caption)
    caption = re.sub(r'\s+', ' ', caption)

    # Fix: the stripped value used to be discarded (`caption.strip()` with no
    # assignment), so a leading/trailing space left by the steps above kept
    # the ^/$-anchored patterns below from matching.
    caption = caption.strip()

    # Drop one pair of enclosing quotes, one leading/trailing stray
    # punctuation character, and a caption that is a single ".token".
    caption = re.sub(r'^[\"\']([\w\W]+)[\"\']$', r'\1', caption)
    caption = re.sub(r'^[\'\_,\-\:;]', r'', caption)
    caption = re.sub(r'[\'\_,\-\:\-\+]$', r'', caption)
    caption = re.sub(r'^\.\S+$', '', caption)

    return caption.strip()
|
|
|
|
|
|
|
|
|
|