| | import logging |
| |
|
| | import torch |
| | import torch.distributed as dist |
| | from torch.nn import functional as F |
| |
|
| |
|
def create_logger(logging_dir):
    """
    Create a logger that writes to a log file and stdout.

    Only rank 0 gets real handlers (a stream handler plus a file handler
    writing to ``{logging_dir}/log.txt``); all other ranks get a logger
    with a NullHandler so their messages are silently dropped.
    """
    logger = logging.getLogger(__name__)
    if dist.get_rank() != 0:
        # Non-zero ranks: swallow log records instead of duplicating output.
        logger.addHandler(logging.NullHandler())
        return logger
    logging.basicConfig(
        level=logging.INFO,
        format="[\033[34m%(asctime)s\033[0m] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[
            logging.StreamHandler(),
            logging.FileHandler(f"{logging_dir}/log.txt"),
        ],
    )
    return logger
| |
|
| |
|
@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """
    Step the EMA model towards the current model.

    Only parameters with ``requires_grad=True`` are updated; frozen
    parameters are skipped entirely. The in-place update is
    ``ema = decay * ema + (1 - decay) * param`` for each trainable pair,
    applied in a single fused ``_foreach`` call.

    Args:
        ema_model: module holding the exponential-moving-average weights.
        model: the live training module; must yield parameters in the same
            order as ``ema_model.parameters()``.
        decay: EMA decay rate in [0, 1]; closer to 1 means slower tracking.
    """
    ema_ps = []
    ps = []

    for e, m in zip(ema_model.parameters(), model.parameters()):
        if m.requires_grad:
            ema_ps.append(e)
            ps.append(m)
    # _foreach ops reject empty tensor lists; if every parameter is frozen
    # there is nothing to update, so skip the call instead of crashing.
    if ema_ps:
        torch._foreach_lerp_(ema_ps, ps, 1.0 - decay)
| |
|
| |
|
@torch.no_grad()
def sync_frozen_params_once(ema_model, model):
    """Copy every frozen (requires_grad=False) parameter of ``model`` into
    the positionally-matching parameter of ``ema_model``, in place."""
    pairs = zip(ema_model.parameters(), model.parameters())
    for ema_param, param in pairs:
        if param.requires_grad:
            continue
        ema_param.copy_(param)
| |
|
| |
|
def requires_grad(model, flag=True):
    """
    Set requires_grad flag for all parameters in a model.
    """
    for param in model.parameters():
        param.requires_grad_(flag)
| |
|
def patchify_raster(x, p):
    """
    Flatten a (B, C, H, W) image into a (B, H*W, C) token sequence.

    Tokens are ordered patch-by-patch: patches are traversed in raster
    (row-major) order over the grid, and pixels are traversed in raster
    order within each p x p patch.
    """
    B, C, H, W = x.shape
    assert H % p == 0 and W % p == 0, f"Image dimensions ({H},{W}) must be divisible by patch size {p}"

    grid_h, grid_w = H // p, W // p

    # (B, C, grid_h, p, grid_w, p) -> (B, grid_h, grid_w, p, p, C):
    # bring the patch grid and intra-patch coordinates to the front,
    # channels last, then flatten everything between batch and channels.
    tokens = x.view(B, C, grid_h, p, grid_w, p)
    tokens = tokens.permute(0, 2, 4, 3, 5, 1).contiguous()
    return tokens.reshape(B, -1, C)
| |
|
def unpatchify_raster(x, p, target_shape):
    """
    Inverse of patchify_raster: rebuild a (B, C, H, W) image from a
    (B, N, C) token sequence laid out patch-by-patch.

    Args:
        x: token sequence of shape (B, N, C) with N == H * W.
        p: patch size the sequence was built with.
        target_shape: (H, W) of the output image.
    """
    B, N, C = x.shape
    H, W = target_shape
    grid_h, grid_w = H // p, W // p

    # (B, grid_h, grid_w, p, p, C) -> (B, C, grid_h, p, grid_w, p):
    # undo the patch grouping, interleaving rows back into full image rows.
    img = x.view(B, grid_h, grid_w, p, p, C)
    img = img.permute(0, 5, 1, 3, 2, 4).contiguous()
    return img.reshape(B, C, H, W)
| |
|
def patchify_raster_2d(x: torch.Tensor, p: int, H: int, W: int) -> torch.Tensor:
    """
    Reorder a (N, C1, C2) token sequence from row-major raster order over an
    H x W grid into patch order: patches in raster order over the grid,
    positions in raster order within each p x p patch. Shape is preserved.
    """
    N, C1, C2 = x.shape

    assert N == H * W, f"N ({N}) must equal H*W ({H*W})"
    assert H % p == 0 and W % p == 0, f"H/W ({H}/{W}) must be divisible by patch size {p}"

    grid_h, grid_w = H // p, W // p
    channels = C1 * C2

    # Lay tokens out on the (H, W) grid, split each axis into
    # (patch index, offset within patch), then swap so both patch indices
    # precede both offsets: (gh, p, gw, p, c) -> (gh, gw, p, p, c).
    grid = x.reshape(H, W, channels)
    reordered = grid.view(grid_h, p, grid_w, p, channels).permute(0, 2, 1, 3, 4)

    # Flattening (gh, gw, p, p) yields the patch-ordered sequence.
    return reordered.reshape(N, C1, C2)