| import torch | |
| from diffusers import AutoencoderKL | |
| from ddpm import DDPMScheduler | |
| from clip import CLIPTextEncoder | |
| from diffusion import Diffusion | |
# Target device for every model in the pipeline ("cpu" here; switch to
# "cuda" to run on a GPU — torch.load below also maps weights to this device).
device = "cpu"
def load_model(
    vae_weight_path="stabilityai/sd-vae-ft-mse",
    clip_weight_path="openai/clip-vit-base-patch32",
    diffusion_weight_path="./weights/emoji_diffusion_model.pth",
    h_dim=320,
    n_head=8,
):
    """Build the frozen inference pipeline for emoji diffusion sampling.

    Loads the pretrained VAE and CLIP text encoder, restores the trained
    diffusion model from ``diffusion_weight_path``, moves everything to the
    module-level ``device``, switches all models to eval mode, and freezes
    their parameters (this pipeline is inference-only, so no gradients are
    ever needed).

    Args:
        vae_weight_path: HF hub id or local path of the pretrained VAE.
        clip_weight_path: HF hub id or local path of the CLIP text encoder.
        diffusion_weight_path: Path to the trained diffusion state dict.
        h_dim: Hidden dimension of the Diffusion model (must match the
            checkpoint being loaded).
        n_head: Number of attention heads (must match the checkpoint).

    Returns:
        Tuple ``(vae, clip, diffusion, noise_scheduler)``.
    """
    vae = AutoencoderKL.from_pretrained(vae_weight_path)
    vae.requires_grad_(False)  # inference only — no gradients
    vae.eval()
    vae = vae.to(device)

    clip = CLIPTextEncoder(clip_weight_path=clip_weight_path).to(device)
    clip.requires_grad_(False)  # freeze, consistent with the VAE above
    clip.eval()

    diffusion = Diffusion(h_dim, n_head).to(device)
    # NOTE(review): consider torch.load(..., weights_only=True) once the
    # minimum supported torch version allows it — safer for untrusted files.
    diffusion.load_state_dict(
        torch.load(diffusion_weight_path, map_location=device))
    diffusion.requires_grad_(False)  # freeze, consistent with the VAE above
    diffusion.eval()

    # Generator pinned to the same device so scheduler sampling is
    # reproducible and device-consistent.
    random_generator = torch.Generator(device=device)
    noise_scheduler = DDPMScheduler(random_generator)
    return vae, clip, diffusion, noise_scheduler