Spaces:
Runtime error
Runtime error
| import os | |
| import torch | |
| import numpy as np | |
| import logging | |
| from collections import OrderedDict | |
| from PIL import Image | |
def requires_grad(model, flag=True):
    """
    Enable or disable gradient tracking on every parameter of ``model``.

    Args:
        model: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
        flag: Value assigned to each parameter's ``requires_grad`` attribute.
    """
    for parameter in model.parameters():
        parameter.requires_grad = flag
def create_logger(logging_dir):
    """
    Create a logger that writes to a log file and stdout.

    The log file is ``log.txt`` inside ``logging_dir``; the directory is
    created first if it does not exist yet, so a fresh experiment folder
    does not make the file handler fail.

    Args:
        logging_dir: Directory that will hold ``log.txt``.

    Returns:
        The ``logging.Logger`` named after this module.
    """
    # FileHandler opens the file immediately, so its directory must exist.
    os.makedirs(logging_dir, exist_ok=True)
    log_path = os.path.join(logging_dir, "log.txt")
    # NOTE: basicConfig only configures the root logger when it has no
    # handlers yet; later calls leave an existing configuration in place.
    logging.basicConfig(
        level=logging.INFO,
        format='[\033[34m%(asctime)s\033[0m] %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        handlers=[logging.StreamHandler(), logging.FileHandler(log_path)],
    )
    logger = logging.getLogger(__name__)
    return logger
def update_ema(ema_model, model, decay=0.9999):
    """
    Step the EMA model towards the current model.

    Each EMA parameter is updated in place as
    ``ema = decay * ema + (1 - decay) * param``.

    Args:
        ema_model: Model holding the exponential moving average weights.
        model: Model whose current weights are blended in. Parameter names
            may carry a leading ``"module."`` wrapper prefix, which is
            stripped before looking up the matching EMA parameter.
        decay: EMA decay factor.

    Raises:
        KeyError: If a parameter of ``model`` has no counterpart in
            ``ema_model``.
    """
    wrapper_prefix = "module."
    ema_params = dict(ema_model.named_parameters())
    # The update is pure bookkeeping: run it without autograd so in-place
    # ops are legal even when the EMA parameters still require grad.
    with torch.no_grad():
        for name, param in model.named_parameters():
            # Strip only a leading wrapper prefix; a blanket replace would
            # also mangle names such as "submodule.weight".
            if name.startswith(wrapper_prefix):
                name = name[len(wrapper_prefix):]
            # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
            ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
def center_crop_arr(pil_image, image_size):
    """
    Resize a PIL image and center-crop it to ``image_size`` x ``image_size``.

    Center cropping implementation from ADM:
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126

    Args:
        pil_image: Source ``PIL.Image.Image``.
        image_size: Side length of the square crop.

    Returns:
        A new PIL image built from the central crop of the resized input.
    """
    # Halve the image with a box filter while the shorter side is still at
    # least twice the target size.
    while True:
        width, height = pil_image.size
        if min(width, height) < 2 * image_size:
            break
        pil_image = pil_image.resize((width // 2, height // 2), resample=Image.BOX)

    # Final bicubic resize so the shorter side matches image_size.
    width, height = pil_image.size
    scale = image_size / min(width, height)
    resized_dims = (round(width * scale), round(height * scale))
    pil_image = pil_image.resize(resized_dims, resample=Image.BICUBIC)

    # Cut the central square out of the pixel array (rows first, then columns).
    pixels = np.array(pil_image)
    top = (pixels.shape[0] - image_size) // 2
    left = (pixels.shape[1] - image_size) // 2
    window = pixels[top: top + image_size, left: left + image_size]
    return Image.fromarray(window)
def load_model(ckpt_name):
    """
    Load a DiT checkpoint from a local file.

    Nothing is downloaded: ``ckpt_name`` must already exist on disk.

    Args:
        ckpt_name: Path to the checkpoint file.

    Returns:
        The object stored under the ``"ema"`` key when the checkpoint has
        one (checkpoints from train.py), otherwise the loaded checkpoint
        itself.

    Raises:
        AssertionError: If ``ckpt_name`` is not an existing file.
    """
    assert os.path.isfile(ckpt_name), f'Could not find DiT checkpoint at {ckpt_name}'

    def _keep_storage(storage, loc):
        # Return the storage as deserialized instead of moving it to the
        # device recorded in the checkpoint.
        return storage

    # NOTE(review): torch.load is pickle-based; only load checkpoints from
    # trusted sources.
    loaded = torch.load(ckpt_name, map_location=_keep_storage)
    # Checkpoints from train.py nest the weights under "ema".
    return loaded["ema"] if "ema" in loaded else loaded