Spaces:
Sleeping
Sleeping
| import torch | |
| import open_clip | |
| import os | |
| from diffusers import DDIMScheduler | |
| from core.unet import Unet | |
| from core.vae import VAE | |
| from core.sample_ddim import ddim_sample | |
| from core.config import * | |
| from core.seed import seed_everything | |
| from huggingface_hub import hf_hub_download | |
| from torchvision.utils import save_image | |
class LDMPipeline:
    """Inference pipeline that turns a text caption into images.

    On construction it loads a UNet, a VAE, an open_clip text encoder with its
    tokenizer, a precomputed null embedding, and a DDIM noise scheduler.
    Weights are downloaded from the Hugging Face Hub; the ``HF_TOKEN``
    environment variable, if set, is passed as the access token.

    Configuration names used here (``embedding_model``, ``embedding_pretrained``,
    ``unet_beta_schedule``, ``unet_pred_type``, ``vae_latent_channels``,
    ``vae_latent_dim``, ``latent_std``) are expected to be provided at module
    level, presumably by ``core.config`` -- TODO confirm.
    """

    # Hub repository holding every checkpoint this pipeline downloads.
    _WEIGHTS_REPO = "Rohan3/flickr8k-ldm-weights"

    def __init__(self, device=None):
        """Select the compute device and load all models.

        Args:
            device: Device to place the models on. When falsy, CUDA is used
                if available, otherwise the CPU.
        """
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._load_models()

    def _fetch_weights(self, subfolder, filename):
        """Download one file from the weights repository and return its local path."""
        return hf_hub_download(
            repo_id=self._WEIGHTS_REPO,
            subfolder=subfolder,
            filename=filename,
            token=os.getenv("HF_TOKEN"),
        )

    def _load_models(self):
        """Instantiate every component and load its weights onto ``self.device``."""
        print("Loading UNET...")
        self.unet = Unet().to(self.device)
        unet_path = self._fetch_weights("ldm", "best_ema.pth")
        checkpoint = torch.load(unet_path, map_location=self.device, weights_only=True)
        # The checkpoint stores the weights used for inference under the "ema" key.
        self.unet.load_state_dict(checkpoint["ema"], strict=True)
        self.unet.eval()

        print("Loading VAE...")
        self.vae = VAE().to(self.device)
        vae_path = self._fetch_weights("vae", "vae_best.pth")
        vae_ckpt = torch.load(vae_path, map_location=self.device, weights_only=True)
        self.vae.load_state_dict(vae_ckpt["vae"])
        self.vae.eval()

        print("Loading CLIP...")
        self.text_model, _, _ = open_clip.create_model_and_transforms(
            embedding_model,
            pretrained=embedding_pretrained,
            device=self.device,
        )
        self.text_model.eval()
        # get_text_embedding() feeds the transformer sequence-first tensors,
        # so the batch-first flag is switched off to match.
        self.text_model.transformer.batch_first = False
        for param in self.text_model.parameters():
            param.requires_grad = False
        self.tokenizer = open_clip.get_tokenizer(embedding_model)

        print("Loading NULL EMBEDDING...")
        null_embedding_path = self._fetch_weights("null_embedding", "null_embedding.pt")
        # A leading batch dimension is added so it lines up with the
        # unsqueezed caption embedding built in generate().
        self.null_embedding = torch.load(
            null_embedding_path,
            map_location=self.device,
            weights_only=True,
        ).unsqueeze(0)

        print("Loading DDIM SCHEDULER...")
        self.noise_scheduler = DDIMScheduler(
            num_train_timesteps=1000,
            beta_schedule=unet_beta_schedule,
            prediction_type=unet_pred_type,
            rescale_betas_zero_snr=True,
            timestep_spacing="trailing",
            clip_sample=False,
            set_alpha_to_one=False,
        )
        print("All models loaded...")

    @torch.no_grad()
    def get_text_embedding(self, caption: str):
        """Return the per-token contextual text embeddings for ``caption``.

        The caption is tokenized, embedded, offset by the positional embedding,
        passed through the text transformer under a causal attention mask, and
        normalized with the model's final layer norm.

        Runs under ``torch.no_grad()``: this is inference only, so no autograd
        graph is recorded and the result never requires grad.

        Args:
            caption: Text prompt to encode.

        Returns:
            Tensor with the leading batch dimension squeezed out. The original
            author's notes give the shape as (77, 1024) for the configured
            model -- not verifiable from this file.
        """
        tokens = self.tokenizer(caption).to(self.device)
        hidden = self.text_model.token_embedding(tokens) + self.text_model.positional_embedding
        hidden = hidden.permute(1, 0, 2)  # (L, N, D): sequence-first for the transformer

        # Causal mask: -inf strictly above the diagonal, 0 elsewhere, so each
        # position can attend only to itself and earlier positions.
        seq_len = hidden.shape[0]
        mask = torch.full((seq_len, seq_len), float("-inf"), device=hidden.device)
        mask.triu_(1)

        hidden = self.text_model.transformer(hidden, attn_mask=mask)
        hidden = hidden.permute(1, 0, 2)  # back to (N, L, D)
        per_token_contextual = self.text_model.ln_final(hidden)
        return per_token_contextual.squeeze(0)

    @torch.no_grad()
    def generate(self, caption: str, num_images: int = 4, num_steps: int = 50, guidance_scale: float = 7.5, seed: int = 42, eta: float = 0):
        """Generate images for ``caption`` with DDIM sampling.

        Runs under ``torch.no_grad()`` so neither the sampling loop nor the VAE
        decode records an autograd graph; the returned tensor does not require
        grad.

        Args:
            caption: Text prompt. Surrounding whitespace and trailing periods
                are removed before encoding.
            num_images: Number of images to sample (batch size of the latents).
            num_steps: Number of sampling steps forwarded to ``ddim_sample``.
            guidance_scale: Guidance scale forwarded to ``ddim_sample``.
            seed: Seed passed to ``seed_everything`` before sampling.
            eta: DDIM ``eta`` forwarded to ``ddim_sample``.

        Returns:
            The decoded images as returned by ``vae.decode_latent_to_img``
            (the original author's note gives the layout as (B, C, H, W)).
        """
        seed_everything(seed)

        # rstrip is a no-op when there is no trailing period, so no guard is needed.
        caption = caption.strip().rstrip(".")
        embedding = self.get_text_embedding(caption).unsqueeze(0)

        latents = ddim_sample(
            unet=self.unet,
            noise_scheduler=self.noise_scheduler,
            shape=(num_images, vae_latent_channels, vae_latent_dim, vae_latent_dim),
            null_embedding=self.null_embedding,
            embedding=embedding,
            guidance_scale=guidance_scale,
            num_steps=num_steps,
            eta=eta,
            device=self.device,
        )
        # Rescale the sampled latents by latent_std before decoding.
        latents = latents * latent_std
        images = self.vae.decode_latent_to_img(latents)
        return images