import torch
from diffusers import AutoencoderKL

from ddpm import DDPMScheduler
from clip import CLIPTextEncoder
from diffusion import Diffusion

# Default compute device; load_model() accepts an override per call.
device = "cpu"


def load_model(
    device: str = device,
    vae_weight_path: str = "stabilityai/sd-vae-ft-mse",
    clip_weight_path: str = "openai/clip-vit-base-patch32",
    diffusion_weight_path: str = "./weights/emoji_diffusion_model.pth",
):
    """Load the frozen VAE, CLIP text encoder, diffusion model and scheduler.

    All models are placed on ``device`` and switched to eval mode; the VAE
    additionally has gradients disabled (it is used only as a fixed
    encoder/decoder).

    Args:
        device: Torch device string for all modules (default: module-level
            ``device``, i.e. ``"cpu"``).
        vae_weight_path: Hugging Face repo id (or local path) for the
            pretrained ``AutoencoderKL`` weights.
        clip_weight_path: Hugging Face repo id (or local path) for the CLIP
            text-encoder weights.
        diffusion_weight_path: Path to the trained diffusion model
            ``state_dict`` checkpoint.

    Returns:
        Tuple ``(vae, clip, diffusion, noise_scheduler)``.
    """
    # Pretrained VAE, frozen: no gradients, eval mode.
    vae = AutoencoderKL.from_pretrained(vae_weight_path)
    vae.requires_grad_(False)
    vae.eval()
    vae = vae.to(device)

    # CLIP text encoder used for text conditioning.
    clip = CLIPTextEncoder(clip_weight_path=clip_weight_path).to(device)
    clip.eval()

    # Diffusion backbone; hidden dim / head count must match training config.
    h_dim = 320
    n_head = 8
    diffusion = Diffusion(h_dim, n_head).to(device)
    # NOTE(review): torch.load without weights_only unpickles arbitrary
    # objects — consider weights_only=True if the checkpoint could be
    # untrusted (confirm the installed torch version supports it).
    diffusion.load_state_dict(
        torch.load(diffusion_weight_path, map_location=device))
    diffusion.eval()

    # Dedicated RNG so noise sampling is reproducible per generator seed.
    random_generator = torch.Generator(device=device)
    noise_scheduler = DDPMScheduler(random_generator)
    return vae, clip, diffusion, noise_scheduler