# Provenance (captured from the Hugging Face file view): author wjnwjn59,
# commit 2af49ec ("update model"), original file size 894 bytes.
import torch
from diffusers import AutoencoderKL
from ddpm import DDPMScheduler
from clip import CLIPTextEncoder
from diffusion import Diffusion
# Single target device for every module built by load_model(): the VAE, the
# CLIP text encoder and the diffusion network are all moved here, and the
# diffusion checkpoint is mapped here when it is loaded.
device: str = "cpu"
def load_model(*, weights_path="./weights/emoji_diffusion_model.pth"):
    """Build the inference components and return them as a tuple.

    Every module is moved to the module-level ``device`` and switched to
    eval mode.

    Args:
        weights_path: Path to the state dict for the ``Diffusion`` network.
            The default is the original hard-coded location. Because it is a
            relative path, it is resolved against the current working
            directory.

    Returns:
        A ``(vae, clip, diffusion, noise_scheduler)`` tuple.

    Raises:
        FileNotFoundError: If ``weights_path`` does not exist.
        RuntimeError: If the checkpoint's keys or tensor shapes do not match
            ``Diffusion(h_dim, n_head)``. This is raised by
            ``load_state_dict``, which is strict by default.
    """
    # Pretrained VAE from diffusers. Its parameters are frozen here. Note that
    # from_pretrained may read from the local HF cache or download from the Hub.
    vae_weight_path = "stabilityai/sd-vae-ft-mse"
    vae = AutoencoderKL.from_pretrained(vae_weight_path)
    vae.requires_grad_(False)
    vae.eval()
    vae = vae.to(device)

    # Text encoder wrapper (project class) built from a CLIP model id.
    clip_weight_path = "openai/clip-vit-base-patch32"
    clip = CLIPTextEncoder(clip_weight_path=clip_weight_path).to(device)
    clip.eval()

    # Architecture hyper-parameters. They must match the ones the checkpoint
    # was trained with, or the strict load_state_dict below will fail.
    h_dim = 320
    n_head = 8
    diffusion = Diffusion(h_dim, n_head).to(device)
    # The file is consumed directly as a state dict, so it only needs tensors
    # and plain containers. weights_only=True stops torch.load from
    # unpickling arbitrary objects.
    state_dict = torch.load(weights_path, map_location=device, weights_only=True)
    diffusion.load_state_dict(state_dict)
    diffusion.eval()

    # Dedicated RNG on the same device, handed to the project's DDPM scheduler.
    random_generator = torch.Generator(device=device)
    noise_scheduler = DDPMScheduler(random_generator)
    return vae, clip, diffusion, noise_scheduler