from contextlib import contextmanager
import hashlib
import math
from pathlib import Path
import shutil
import urllib.request
import warnings

from PIL import Image
import torch
from torch import nn, optim
from torch.utils import data


def hf_datasets_augs_helper(examples, transform, image_key, mode='RGB'):
    """Apply passed in transforms for HuggingFace Datasets."""
    images = [transform(image.convert(mode)) for image in examples[image_key]]
    return {image_key: images}
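
# Illustrative usage sketch (assumes the HuggingFace `datasets` library and a
# torchvision transform `tf`; 'image' is a hypothetical column name):
#     from functools import partial
#     dataset.set_transform(partial(hf_datasets_augs_helper, transform=tf, image_key='image'))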


def append_dims(x, target_dims):
    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
    expanded = x[(...,) + (None,) * dims_to_append]
    # MPS tensors can return bad values if the newly added axes are indexed into
    # directly; detaching and cloning works around this.
    return expanded.detach().clone() if expanded.device.type == 'mps' else expanded
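
# Illustrative example: broadcast per-sample noise levels against an image batch.
#     sigma = torch.rand(8)               # shape (8,)
#     x = torch.randn(8, 3, 64, 64)       # shape (8, 3, 64, 64)
#     sigma = append_dims(sigma, x.ndim)  # shape (8, 1, 1, 1), broadcasts against x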


def n_params(module):
    """Returns the number of parameters in a module."""
    return sum(p.numel() for p in module.parameters())


def download_file(path, url, digest=None):
    """Downloads a file if it does not exist, optionally checking its SHA-256 hash."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    if not path.exists():
        with urllib.request.urlopen(url) as response, open(path, 'wb') as f:
            shutil.copyfileobj(response, f)
    if digest is not None:
        file_digest = hashlib.sha256(path.read_bytes()).hexdigest()
        if digest != file_digest:
            raise OSError(f'hash of {path} (url: {url}) failed to validate')
    return path
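
# Illustrative usage sketch (the URL and digest below are placeholders, not
# real artifacts):
#     path = download_file('checkpoints/model.safetensors',
#                          'https://example.com/model.safetensors',
#                          digest='<expected sha256 hex digest>')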


@contextmanager
def train_mode(model, mode=True):
    """A context manager that places a model into training mode and restores
    the previous mode on exit."""
    modes = [module.training for module in model.modules()]
    try:
        yield model.train(mode)
    finally:
        for i, module in enumerate(model.modules()):
            module.training = modes[i]


def eval_mode(model):
    """A context manager that places a model into evaluation mode and restores
    the previous mode on exit."""
    return train_mode(model, False)
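
# Illustrative usage: evaluate without permanently leaving training mode.
#     with eval_mode(model):
#         preds = model(inputs)
#     # the per-module training flags are restored here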


@torch.no_grad()
def ema_update(model, averaged_model, decay):
    """Incorporates updated model parameters into an exponential moving averaged
    version of a model. It should be called after each optimizer step."""
    model_params = dict(model.named_parameters())
    averaged_params = dict(averaged_model.named_parameters())
    assert model_params.keys() == averaged_params.keys()

    for name, param in model_params.items():
        averaged_params[name].mul_(decay).add_(param, alpha=1 - decay)

    model_buffers = dict(model.named_buffers())
    averaged_buffers = dict(averaged_model.named_buffers())
    assert model_buffers.keys() == averaged_buffers.keys()

    for name, buf in model_buffers.items():
        averaged_buffers[name].copy_(buf)
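
# Illustrative usage: `model_ema` is assumed to be a deepcopy of `model` made
# before training starts.
#     loss.backward()
#     optimizer.step()
#     ema_update(model, model_ema, decay=0.999)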


class EMAWarmup:
    """Implements an EMA warmup using an inverse decay schedule.

    If inv_gamma=1 and power=1, implements a simple average. inv_gamma=1, power=2/3 are
    good values for models you plan to train for a million or more steps (reaches decay
    factor 0.999 at 31.6K steps, 0.9999 at 1M steps), inv_gamma=1, power=3/4 for models
    you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
    215.4K steps).

    Args:
        inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
        power (float): Exponential factor of EMA warmup. Default: 1.
        min_value (float): The minimum EMA decay rate. Default: 0.
        max_value (float): The maximum EMA decay rate. Default: 1.
        start_at (int): The epoch to start averaging at. Default: 0.
        last_epoch (int): The index of the last epoch. Default: 0.
    """

    def __init__(self, inv_gamma=1., power=1., min_value=0., max_value=1., start_at=0,
                 last_epoch=0):
        self.inv_gamma = inv_gamma
        self.power = power
        self.min_value = min_value
        self.max_value = max_value
        self.start_at = start_at
        self.last_epoch = last_epoch

    def state_dict(self):
        """Returns the state of the class as a :class:`dict`."""
        return dict(self.__dict__.items())

    def load_state_dict(self, state_dict):
        """Loads the class's state.

        Args:
            state_dict (dict): EMA warmup state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def get_value(self):
        """Gets the current EMA decay rate."""
        epoch = self.last_epoch - self.start_at
        # Before start_at, use decay 0 so the averaged model just copies the model.
        value = 1 - (1 + max(0, epoch) / self.inv_gamma) ** -self.power
        return 0. if epoch < 0 else min(self.max_value, max(self.min_value, value))

    def step(self):
        """Updates the step count."""
        self.last_epoch += 1
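
# Illustrative sketch combining EMAWarmup with ema_update in a training loop
# (`train_step` and `model_ema` are assumed to be defined by the caller):
#     ema_sched = EMAWarmup(power=2 / 3)  # for ~1M-step runs, per the docstring
#     for batch in dataloader:
#         train_step(model, batch)
#         ema_update(model, model_ema, ema_sched.get_value())
#         ema_sched.step()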


class InverseLR(optim.lr_scheduler._LRScheduler):
    """Implements an inverse decay learning rate schedule with an optional exponential
    warmup. When last_epoch=-1, sets initial lr as lr.

    inv_gamma is the number of steps/epochs required for the learning rate to decay to
    (1 / 2)**power of its original value.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1.
        power (float): Exponential factor of learning rate decay. Default: 1.
        warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable).
            Default: 0.
        min_lr (float): The minimum learning rate. Default: 0.
        last_epoch (int): The index of the last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
    """

    def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., min_lr=0.,
                 last_epoch=-1, verbose=False):
        self.inv_gamma = inv_gamma
        self.power = power
        if not 0. <= warmup < 1:
            raise ValueError('Invalid value for warmup')
        self.warmup = warmup
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.")

        return self._get_closed_form_lr()

    def _get_closed_form_lr(self):
        warmup = 1 - self.warmup ** (self.last_epoch + 1)
        lr_mult = (1 + self.last_epoch / self.inv_gamma) ** -self.power
        return [warmup * max(self.min_lr, base_lr * lr_mult)
                for base_lr in self.base_lrs]
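
# Illustrative usage: with inv_gamma=20000 and power=1, the learning rate reaches
# half its initial value after 20000 steps.
#     sched = InverseLR(opt, inv_gamma=20000, power=1., warmup=0.99)
#     for batch in dataloader:
#         train_step(model, batch, opt)
#         sched.step()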


class ExponentialLR(optim.lr_scheduler._LRScheduler):
    """Implements an exponential learning rate schedule with an optional exponential
    warmup. When last_epoch=-1, sets initial lr as lr. Decays the learning rate
    continuously by decay (default 0.5) every num_steps steps.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        num_steps (float): The number of steps over which the learning rate decays
            by a factor of decay.
        decay (float): The factor by which to decay the learning rate every num_steps
            steps. Default: 0.5.
        warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable).
            Default: 0.
        min_lr (float): The minimum learning rate. Default: 0.
        last_epoch (int): The index of the last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
    """

    def __init__(self, optimizer, num_steps, decay=0.5, warmup=0., min_lr=0.,
                 last_epoch=-1, verbose=False):
        self.num_steps = num_steps
        self.decay = decay
        if not 0. <= warmup < 1:
            raise ValueError('Invalid value for warmup')
        self.warmup = warmup
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.")

        return self._get_closed_form_lr()

    def _get_closed_form_lr(self):
        warmup = 1 - self.warmup ** (self.last_epoch + 1)
        lr_mult = (self.decay ** (1 / self.num_steps)) ** self.last_epoch
        return [warmup * max(self.min_lr, base_lr * lr_mult)
                for base_lr in self.base_lrs]
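
# Illustrative usage: halve the learning rate every 50000 steps.
#     sched = ExponentialLR(opt, num_steps=50000, decay=0.5)
# At step n the multiplier is 0.5 ** (n / 50000), so lr reaches base_lr / 2 at
# step 50000 and base_lr / 4 at step 100000.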


def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32):
    """Draws samples from a lognormal distribution."""
    return (torch.randn(shape, device=device, dtype=dtype) * scale + loc).exp()


def rand_log_logistic(shape, loc=0., scale=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32):
    """Draws samples from an optionally truncated log-logistic distribution."""
    min_value = torch.as_tensor(min_value, device=device, dtype=torch.float64)
    max_value = torch.as_tensor(max_value, device=device, dtype=torch.float64)
    min_cdf = min_value.log().sub(loc).div(scale).sigmoid()
    max_cdf = max_value.log().sub(loc).div(scale).sigmoid()
    u = torch.rand(shape, device=device, dtype=torch.float64) * (max_cdf - min_cdf) + min_cdf
    return u.logit().mul(scale).add(loc).exp().to(dtype)


def rand_log_uniform(shape, min_value, max_value, device='cpu', dtype=torch.float32):
    """Draws samples from a log-uniform distribution."""
    min_value = math.log(min_value)
    max_value = math.log(max_value)
    return (torch.rand(shape, device=device, dtype=dtype) * (max_value - min_value) + min_value).exp()


def rand_v_diffusion(shape, sigma_data=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32):
    """Draws samples from a truncated v-diffusion training timestep distribution."""
    min_cdf = math.atan(min_value / sigma_data) * 2 / math.pi
    max_cdf = math.atan(max_value / sigma_data) * 2 / math.pi
    u = torch.rand(shape, device=device, dtype=dtype) * (max_cdf - min_cdf) + min_cdf
    return torch.tan(u * math.pi / 2) * sigma_data
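
# Illustrative example: these helpers typically draw per-sample noise levels for
# diffusion model training (batch_size and device are assumed to be defined by
# the caller):
#     sigma = rand_v_diffusion([batch_size], sigma_data=1., device=device)
# The other rand_* helpers above follow a similar calling convention and differ
# mainly in the distribution they sample from.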


def rand_split_log_normal(shape, loc, scale_1, scale_2, device='cpu', dtype=torch.float32):
    """Draws samples from a split lognormal distribution."""
    n = torch.randn(shape, device=device, dtype=dtype).abs()
    u = torch.rand(shape, device=device, dtype=dtype)
    n_left = n * -scale_1 + loc
    n_right = n * scale_2 + loc
    ratio = scale_1 / (scale_1 + scale_2)
    return torch.where(u < ratio, n_left, n_right).exp()


class FolderOfImages(data.Dataset):
    """Recursively finds all images in a directory. It does not support
    classes/targets."""

    IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp'}

    def __init__(self, root, transform=None):
        super().__init__()
        self.root = Path(root)
        self.transform = nn.Identity() if transform is None else transform
        self.paths = sorted(path for path in self.root.rglob('*') if path.suffix.lower() in self.IMG_EXTENSIONS)

    def __repr__(self):
        return f'FolderOfImages(root="{self.root}", len: {len(self)})'

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, key):
        path = self.paths[key]
        with open(path, 'rb') as f:
            image = Image.open(f).convert('RGB')
        image = self.transform(image)
        return image,
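
# Illustrative usage: each item is a 1-tuple, so unpack accordingly (assumes
# `from torchvision import transforms`; the path is a placeholder).
#     dataset = FolderOfImages('data/images', transform=transforms.ToTensor())
#     loader = data.DataLoader(dataset, batch_size=32, shuffle=True)
#     for image, in loader:  # note the trailing comma
#         ...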


class CSVLogger:
    def __init__(self, filename, columns):
        self.filename = Path(filename)
        self.columns = columns
        if self.filename.exists():
            self.file = open(self.filename, 'a')
        else:
            self.file = open(self.filename, 'w')
            self.write(*self.columns)

    def write(self, *args):
        # Values are written verbatim; this does no CSV quoting or escaping.
        print(*args, sep=',', file=self.file, flush=True)
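
# Illustrative usage:
#     logger = CSVLogger('metrics.csv', ['step', 'loss'])
#     logger.write(step, loss.item())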


@contextmanager
def tf32_mode(cudnn=None, matmul=None):
    """A context manager that sets whether TF32 is allowed on cuDNN or matmul."""
    cudnn_old = torch.backends.cudnn.allow_tf32
    matmul_old = torch.backends.cuda.matmul.allow_tf32
    try:
        if cudnn is not None:
            torch.backends.cudnn.allow_tf32 = cudnn
        if matmul is not None:
            torch.backends.cuda.matmul.allow_tf32 = matmul
        yield
    finally:
        if cudnn is not None:
            torch.backends.cudnn.allow_tf32 = cudnn_old
        if matmul is not None:
            torch.backends.cuda.matmul.allow_tf32 = matmul_old
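
# Illustrative usage: temporarily allow TF32 for both cuDNN and matmul, restoring
# the previous settings on exit.
#     with tf32_mode(cudnn=True, matmul=True):
#         out = model(x)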