| import torch | |
| import pickle | |
| from noise_schedulers import NoiseScheduleVE | |
| from torch.utils.checkpoint import checkpoint | |
def model_wrapper(model, noise_schedule, class_labels=None, use_checkpoint=False):
    '''
    Wrap ``model`` into a callable that always returns a noise prediction.

    The returned function evaluates ``model(x, t, class_labels)`` and maps
    its output to ``(x - alpha_t * output) / sigma_t``, where ``alpha_t`` and
    ``sigma_t`` come from ``noise_schedule.marginal_alpha`` and
    ``noise_schedule.marginal_std``.
    NOTE(review): this formula implies ``model`` predicts the clean sample
    (x0) -- confirm against the network that is actually passed in.

    Args:
        model: callable invoked as ``model(x, t, cond)``.
        noise_schedule: object providing ``marginal_alpha(t)`` and
            ``marginal_std(t)``.
        class_labels: conditioning value forwarded to ``model`` on every
            call (``None`` by default).
        use_checkpoint: when True, run ``model`` through
            ``torch.utils.checkpoint.checkpoint`` instead of calling it
            directly.

    Returns:
        ``model_fn(x, t_continuous, *args, **kwargs)``; extra positional and
        keyword arguments are accepted but ignored, and the result is cast
        to ``torch.float64``.
    '''

    def _trailing_axes(coeff, ndim):
        # Append ``ndim - 1`` singleton axes so a per-sample coefficient
        # broadcasts against ``x`` whatever its rank is. For 4-D ``x`` and
        # 1-D ``coeff`` this equals ``coeff[:, None, None, None]``.
        return coeff[(...,) + (None,) * (ndim - 1)]

    def noise_pred_fn(x, t_continuous, cond=None):
        if use_checkpoint:
            # NOTE(review): recent PyTorch releases ask for an explicit
            # ``use_reentrant`` argument here -- confirm against the
            # installed version before changing this call.
            output = checkpoint(model, x, t_continuous, cond)
        else:
            output = model(x, t_continuous, cond)
        alpha_t = noise_schedule.marginal_alpha(t_continuous)
        sigma_t = noise_schedule.marginal_std(t_continuous)
        ndim = x.dim()
        return (x - _trailing_axes(alpha_t, ndim) * output) / _trailing_axes(sigma_t, ndim)

    def model_fn(x, t_continuous, *args, **kwargs):
        # Extra arguments are swallowed so callers with richer call
        # signatures can use this wrapper unchanged.
        return noise_pred_fn(x, t_continuous, class_labels).to(torch.float64)

    return model_fn
def get_pretrained_sde_model(args, requires_grad=False):
    '''
    Load the pickled network stored at ``args.ckp_path`` and wrap it.

    The checkpoint is unpickled, its ``"ema"`` entry is moved to CUDA when
    available (CPU otherwise), and the network is wrapped with
    ``model_wrapper`` using a ``NoiseScheduleVE(schedule='edm')`` schedule.

    Args:
        args: object exposing ``ckp_path``, the path of the pickle file.
        requires_grad: when False, every parameter of the loaded network
            gets ``requires_grad = False``. When True, the flags are left
            exactly as they were loaded (they are not switched on).

    Returns:
        An 8-tuple: the wrapped model function, the raw network, an
        identity function, the noise schedule, and then
        ``net.img_resolution, net.img_channels`` repeated twice.
    '''
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # SECURITY: pickle.load can execute arbitrary code contained in the
    # file; only point ``args.ckp_path`` at checkpoints from a trusted source.
    with open(args.ckp_path, "rb") as ckpt_file:
        payload = pickle.load(ckpt_file)
    net = payload["ema"].to(device)

    if not requires_grad:
        for weight in net.parameters():
            weight.requires_grad = False

    noise_schedule = NoiseScheduleVE(schedule='edm')
    wrapped_model = model_wrapper(net, noise_schedule)

    def _identity(x):
        return x

    resolution = net.img_resolution
    channels = net.img_channels
    return (
        wrapped_model,
        net,
        _identity,
        noise_schedule,
        resolution,
        channels,
        resolution,
        channels,
    )