import importlib
import logging

import torch
from torch.utils.checkpoint import checkpoint

from noise_schedulers import NoiseScheduleVP
from omegaconf import OmegaConf


def expand_dims(v, dims):
    """
    Expand the tensor `v` to the dim `dims`.

    Args:
        `v`: a PyTorch tensor with shape [N].
        `dims`: an `int`.
    Returns:
        a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
    """
    return v[(...,) + (None,) * (dims - 1)]


def model_wrapper(
    model,
    noise_schedule,
    model_type="noise",
    model_kwargs=None,
    guidance_type="uncond",
    guidance_scale=1.0,
    classifier_fn=None,
    classifier_kwargs=None,
    use_checkpoint=False,
):
    """
    Wrap a diffusion model into a continuous-time noise-prediction function for
    DPM-Solver style samplers.

    Args:
        model: the raw model, called as `model(x, t_input, cond, **model_kwargs)`.
        noise_schedule: object exposing `schedule`, `total_N`, `marginal_alpha`
            and `marginal_std` (e.g. `NoiseScheduleVP`).
        model_type: what the model predicts: "noise", "x_start", "v" or "score".
        model_kwargs: extra keyword arguments forwarded to `model`.
        guidance_type: "uncond", "classifier" or "classifier-free".
        guidance_scale: guidance strength for the guided modes.
        classifier_fn: classifier returning log-probabilities (classifier guidance only).
        classifier_kwargs: extra keyword arguments forwarded to `classifier_fn`.
        use_checkpoint: if True, run the model under activation checkpointing to
            trade compute for GPU memory.

    Returns:
        `model_fn(x, t_continuous, condition=None, unconditional_condition=None)`
        returning the (guided) noise prediction with the same shape as `x`.
    """
    # Validate up front, before building any closures.
    # BUG FIX: the original assert omitted "score" even though noise_pred_fn
    # implements a "score" branch, which made that model type unusable.
    assert model_type in ["noise", "x_start", "v", "score"]
    assert guidance_type in ["uncond", "classifier", "classifier-free"]
    # Avoid the mutable-default-argument pitfall: bind fresh dicts per call.
    model_kwargs = {} if model_kwargs is None else model_kwargs
    classifier_kwargs = {} if classifier_kwargs is None else classifier_kwargs

    def get_model_input_time(t_continuous):
        """
        Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
        For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in
        [0, 1000 * (N - 1) / N]. For continuous-time DPMs, we just use `t_continuous`.
        """
        if noise_schedule.schedule == "discrete":
            return (t_continuous - 1.0 / noise_schedule.total_N) * 1000.0
        else:
            return t_continuous

    def noise_pred_fn(x, t_continuous, cond=None):
        """Run `model` and convert its output to a noise (epsilon) prediction."""
        # Broadcast a scalar time to the whole batch.
        if t_continuous.reshape((-1,)).shape[0] == 1:
            t_continuous = t_continuous.expand((x.shape[0]))
        t_input = get_model_input_time(t_continuous)
        if use_checkpoint:
            output = checkpoint(model, x, t_input, cond, **model_kwargs)
        else:
            output = model(x, t_input, cond, **model_kwargs)
        if model_type == "noise":
            return output
        elif model_type == "x_start":
            # eps = (x - alpha_t * x0) / sigma_t
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            dims = x.dim()
            return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
        elif model_type == "v":
            # eps = alpha_t * v + sigma_t * x
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            dims = x.dim()
            return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
        elif model_type == "score":
            # eps = -sigma_t * score
            sigma_t = noise_schedule.marginal_std(t_continuous)
            dims = x.dim()
            return -expand_dims(sigma_t, dims) * output

    def cond_grad_fn(x, t_input, condition):
        """
        Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
        """
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
            return torch.autograd.grad(log_prob.sum(), x_in)[0]

    def model_fn(x, t_continuous, condition=None, unconditional_condition=None, *args, **kwargs):
        """
        The noise prediction model function that is used for DPM-Solver.
        """
        if t_continuous.reshape((-1,)).shape[0] == 1:
            t_continuous = t_continuous.expand((x.shape[0]))
        if guidance_type == "uncond":
            return noise_pred_fn(x, t_continuous)
        elif guidance_type == "classifier":
            assert classifier_fn is not None
            assert condition is not None
            t_input = get_model_input_time(t_continuous)
            cond_grad = cond_grad_fn(x, t_input, condition)
            sigma_t = noise_schedule.marginal_std(t_continuous)
            noise = noise_pred_fn(x, t_continuous)
            return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
        elif guidance_type == "classifier-free":
            if guidance_scale == 1 or unconditional_condition is None:
                assert condition is not None
                return noise_pred_fn(x, t_continuous, cond=condition)
            else:
                assert condition is not None and unconditional_condition is not None
                # Batch the unconditional and conditional passes into one forward call.
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t_continuous] * 2)
                c_in = torch.cat([unconditional_condition, condition])
                noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
                return noise_uncond + guidance_scale * (noise - noise_uncond)

    return model_fn


def get_obj_from_str(string, reload=False):
    """Resolve a dotted path such as "pkg.mod.Class" to the named attribute."""
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def instantiate_from_config(config):
    """Instantiate `config["target"]` with `config["params"]` (LDM-style config)."""
    if "target" not in config:
        # Sentinel configs used by latent-diffusion mean "no module here".
        if config == '__is_first_stage__' or config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


def load_model_from_config(config, ckp_path, verbose=False):
    """
    Build the model described by `config` and load the checkpoint at `ckp_path`.

    The checkpoint is loaded on CPU; the model is then moved to CUDA and set to
    eval mode before being returned.
    """
    logging.info(f"Loading model from {ckp_path}")
    pl_sd = torch.load(ckp_path, map_location="cpu")
    if "global_step" in pl_sd:
        logging.info(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    # strict=False: report (rather than fail on) missing/unexpected keys.
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        logging.info("missing keys:")
        logging.info(m)
    if len(u) > 0 and verbose:
        logging.info("unexpected keys:")
        logging.info(u)
    model.cuda()
    model.eval()
    logging.info("Model loaded from {}".format(ckp_path))
    return model


def load_ema_weights(model):
    """Copy the EMA weights into `model.model`, storing the current weights first."""
    model.model_ema.store(model.model.parameters())
    model.model_ema.copy_to(model.model)


def get_pretrained_ldm_model(args):
    """
    Load an unconditional pretrained LDM and wrap it for DPM-Solver sampling.

    Returns:
        (model_fn, model, decoder, noise_schedule, latent_size, channels, image_size, 3)
    """
    config = OmegaConf.load(args.config)
    model = load_model_from_config(config, args.ckp_path)
    if args.use_ema:
        load_ema_weights(model)
    for param in model.parameters():
        param.requires_grad = False
    noise_schedule = NoiseScheduleVP("discrete", alphas_cumprod=model.alphas_cumprod)
    noise_schedule.lambda_min = noise_schedule.marginal_lambda(noise_schedule.T).item()
    noise_schedule.lambda_max = noise_schedule.marginal_lambda(1.0 / noise_schedule.total_N).item()
    model_fn = model_wrapper(
        lambda x, t, c: model.apply_model(x, t, c),
        noise_schedule,
        model_type="noise",
        guidance_type="uncond",
        use_checkpoint=args.low_gpu,
    )
    return model_fn, model, model.decode_first_stage, noise_schedule, args.H // args.f, args.C, args.H, 3


def get_pretrained_conditioned_ldm_model(args):
    """
    Load a conditional pretrained LDM and wrap it for classifier-free-guided sampling.

    NOTE(review): unlike `get_pretrained_ldm_model`, this variant neither loads
    EMA weights nor freezes parameters — confirm the asymmetry is intentional.

    Returns:
        (model_fn, model, decoder, noise_schedule, latent_size, channels, image_size, 3)
    """
    config = OmegaConf.load(args.config)
    model = load_model_from_config(config, args.ckp_path)
    noise_schedule = NoiseScheduleVP("discrete", alphas_cumprod=model.alphas_cumprod)
    noise_schedule.lambda_min = noise_schedule.marginal_lambda(noise_schedule.T).item()
    noise_schedule.lambda_max = noise_schedule.marginal_lambda(1.0 / noise_schedule.total_N).item()
    model_fn = model_wrapper(
        lambda x, t, c: model.apply_model(x, t, c),
        noise_schedule,
        model_type="noise",
        guidance_type="classifier-free",
        guidance_scale=args.scale,
        use_checkpoint=args.low_gpu,
    )
    return model_fn, model, model.decode_first_stage, noise_schedule, args.H // args.f, args.C, args.H, 3