LD3 / models /edm_uncond.py
vinhtong97's picture
Upload folder using huggingface_hub
d382778 verified
import torch
import pickle
from noise_schedulers import NoiseScheduleVE
from torch.utils.checkpoint import checkpoint
def model_wrapper(model, noise_schedule, class_labels=None, use_checkpoint=False):
    """Wrap a network so that the returned callable always predicts noise.

    The network output is converted with ``(x - alpha_t * output) / sigma_t``.
    That formula treats the output as the clean-sample estimate in
    ``x = alpha_t * x0 + sigma_t * eps``.

    Args:
        model: callable invoked as ``model(x, t, cond)``; its output must
            broadcast against ``x``.
        noise_schedule: object providing ``marginal_alpha(t)`` and
            ``marginal_std(t)``, each returning a tensor with one value per
            sample (or a single value).
        class_labels: conditioning passed as ``cond`` on every model call.
        use_checkpoint: if True, run ``model`` through
            ``torch.utils.checkpoint.checkpoint`` (recompute activations in the
            backward pass instead of storing them).

    Returns:
        ``model_fn(x, t_continuous, *args, **kwargs)``, which returns the
        noise prediction as a ``torch.float64`` tensor. Any extra positional
        or keyword arguments are accepted and ignored.
    """
    def noise_pred_fn(x, t_continuous, cond=None):
        if use_checkpoint:
            # Passing use_reentrant explicitly keeps the previous (reentrant)
            # default while avoiding PyTorch's "pass use_reentrant" warning.
            output = checkpoint(model, x, t_continuous, cond, use_reentrant=True)
        else:
            output = model(x, t_continuous, cond)
        alpha_t = noise_schedule.marginal_alpha(t_continuous)
        sigma_t = noise_schedule.marginal_std(t_continuous)
        # Reshape per-sample coefficients to (B, 1, ..., 1) so they broadcast
        # over x of any rank; for 4-D x this equals [:, None, None, None].
        coeff_shape = (-1,) + (1,) * (x.dim() - 1)
        return (x - alpha_t.reshape(coeff_shape) * output) / sigma_t.reshape(coeff_shape)

    def model_fn(x, t_continuous, *args, **kwargs):
        # Extra *args/**kwargs are intentionally ignored.
        return noise_pred_fn(x, t_continuous, class_labels).to(torch.float64)

    return model_fn
def get_pretrained_sde_model(args, requires_grad=False):
    """Load the pickled ``"ema"`` network and wrap it as a noise predictor.

    Args:
        args: object with a ``ckp_path`` attribute naming the pickle file.
        requires_grad: when False (the default), every parameter of the loaded
            network is frozen. When True, parameters are left exactly as they
            were unpickled; gradients are not switched on here.
            NOTE(review): if the checkpoint stores frozen parameters,
            requires_grad=True will not unfreeze them. Confirm that this is
            intended.

    Returns:
        An 8-tuple ``(noise_fn, net, identity, noise_schedule, resolution,
        channels, resolution, channels)``. Here ``noise_fn`` is
        ``model_wrapper(net, noise_schedule)``, ``identity`` is
        ``lambda x: x``, the schedule is ``NoiseScheduleVE(schedule='edm')``,
        and resolution/channels are read from ``net.img_resolution`` and
        ``net.img_channels``. The last pair repeats the previous one.
    """
    target_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # NOTE(review): pickle.load can execute arbitrary code stored in the file;
    # only point ckp_path at checkpoints from a trusted source.
    with open(args.ckp_path, "rb") as ckpt_file:
        network = pickle.load(ckpt_file)["ema"]
    network = network.to(target_device)

    if not requires_grad:
        network.requires_grad_(False)

    schedule = NoiseScheduleVE(schedule='edm')
    resolution, channels = network.img_resolution, network.img_channels
    return (
        model_wrapper(network, schedule),
        network,
        lambda x: x,
        schedule,
        resolution,
        channels,
        resolution,
        channels,
    )