import torch
import open_clip
import os
from diffusers import DDIMScheduler
from core.unet import Unet
from core.vae import VAE
from core.sample_ddim import ddim_sample
from core.config import *
from core.seed import seed_everything
from huggingface_hub import hf_hub_download
from torchvision.utils import save_image

# Hugging Face Hub repository hosting every checkpoint this pipeline loads.
_WEIGHTS_REPO_ID = "Rohan3/flickr8k-ldm-weights"


class LDMPipeline:
    """Text-to-image latent diffusion pipeline.

    Loads a UNet (EMA weights), a VAE, an open_clip text encoder, a stored
    null (unconditional) embedding and a DDIM scheduler, then samples images
    from a caption through ``ddim_sample`` with a guidance scale.
    """

    def __init__(self, device=None):
        """Build the pipeline and load all models.

        Args:
            device: Target device; defaults to CUDA when available, else CPU.
        """
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._load_models()

    @staticmethod
    def _download(subfolder, filename):
        """Download *filename* from *subfolder* of the weights repo; return the local path.

        The ``HF_TOKEN`` environment variable, if set, is passed as the auth token.
        """
        return hf_hub_download(
            repo_id=_WEIGHTS_REPO_ID,
            subfolder=subfolder,
            filename=filename,
            token=os.getenv("HF_TOKEN"),
        )

    def _load_models(self):
        """Download checkpoints and build the UNet, VAE, text encoder, null embedding and scheduler."""
        print("Loading UNET...")
        self.unet = Unet().to(self.device)
        # Load to CPU first: only the "ema" entry is used, and load_state_dict
        # copies tensors onto the model's device itself, so staging the whole
        # checkpoint on the GPU would only waste memory.
        checkpoint = torch.load(self._download("ldm", "best_ema.pth"), map_location="cpu", weights_only=True)
        self.unet.load_state_dict(checkpoint["ema"], strict=True)
        del checkpoint
        self.unet.eval()

        print("Loading VAE...")
        self.vae = VAE().to(self.device)
        vae_ckpt = torch.load(self._download("vae", "vae_best.pth"), map_location="cpu", weights_only=True)
        self.vae.load_state_dict(vae_ckpt["vae"])
        del vae_ckpt
        self.vae.eval()

        print("Loading CLIP...")
        self.text_model, _, _ = open_clip.create_model_and_transforms(
            embedding_model, pretrained=embedding_pretrained, device=self.device
        )
        self.text_model.eval()
        # get_text_embedding feeds the transformer sequence-first (L, N, D) input.
        self.text_model.transformer.batch_first = False
        for p in self.text_model.parameters():
            p.requires_grad = False
        self.tokenizer = open_clip.get_tokenizer(embedding_model)

        print("Loading NULL EMBEDDING...")
        # Leading batch dim added so it matches the caption embedding built in generate().
        self.null_embedding = torch.load(
            self._download("null_embedding", "null_embedding.pt"),
            map_location=self.device,
            weights_only=True,
        ).unsqueeze(0)

        print("Loading DDIM SCHEDULER...")
        self.noise_scheduler = DDIMScheduler(
            num_train_timesteps=1000,
            beta_schedule=unet_beta_schedule,
            prediction_type=unet_pred_type,
            rescale_betas_zero_snr=True,
            timestep_spacing="trailing",
            clip_sample=False,
            set_alpha_to_one=False,
        )
        print("All models loaded...")

    @torch.no_grad()
    def get_text_embedding(self, caption: str):
        """Return per-token contextual CLIP text embeddings for *caption*.

        Runs the open_clip text tower (token + positional embedding, causally
        masked transformer, final LayerNorm) but skips pooling/projection, so
        every token position is kept.

        Returns:
            Tensor with the batch dim squeezed; (77, 1024) for a single caption
            with the model this was written against.
        """
        tokens = self.tokenizer(caption).to(self.device)
        x = self.text_model.token_embedding(tokens)
        x = x + self.text_model.positional_embedding
        x = x.permute(1, 0, 2)  # (N, L, D) -> (L, N, D) for the transformer

        # Causal mask: -inf strictly above the diagonal so each token attends
        # only to itself and earlier tokens.
        seq_len = x.shape[0]
        mask = torch.empty(seq_len, seq_len, device=x.device)
        mask.fill_(float("-inf"))
        mask.triu_(1)

        x = self.text_model.transformer(x, attn_mask=mask)
        x = x.permute(1, 0, 2)  # back to (N, L, D)
        per_token_contextual = self.text_model.ln_final(x)  # (1, 77, 1024)
        return per_token_contextual.squeeze(0)  # (77, 1024)

    @torch.no_grad()
    def generate(self, caption: str, num_images: int = 4, num_steps: int = 50,
                 guidance_scale: float = 7.5, seed: int = 42, eta: float = 0):
        """Sample images for *caption*.

        Runs without autograd: the UNet/VAE parameters still require grad, so
        without no_grad the sampling loop could build an unused graph.

        Args:
            caption: Text prompt; surrounding whitespace and trailing periods are removed.
            num_images: Number of images to sample (must be >= 1).
            num_steps: Number of DDIM steps (must be >= 1).
            guidance_scale: Guidance scale forwarded to ``ddim_sample``.
            seed: Passed to ``seed_everything`` before sampling (global RNG side effect).
            eta: DDIM eta forwarded to ``ddim_sample``.

        Returns:
            Whatever ``self.vae.decode_latent_to_img`` returns, presumably an image
            batch of shape (B, C, H, W).

        Raises:
            ValueError: If ``num_images`` or ``num_steps`` is less than 1.
        """
        if num_images < 1:
            raise ValueError(f"num_images must be >= 1, got {num_images}")
        if num_steps < 1:
            raise ValueError(f"num_steps must be >= 1, got {num_steps}")

        seed_everything(seed)
        caption = caption.strip().rstrip(".")
        embedding = self.get_text_embedding(caption).unsqueeze(0)

        latents = ddim_sample(
            unet=self.unet,
            noise_scheduler=self.noise_scheduler,
            shape=(num_images, vae_latent_channels, vae_latent_dim, vae_latent_dim),
            null_embedding=self.null_embedding,
            embedding=embedding,
            guidance_scale=guidance_scale,
            num_steps=num_steps,
            eta=eta,
            device=self.device,
        )
        # NOTE(review): presumably undoes a latent normalisation applied during
        # training (latents divided by latent_std) — confirm against training code.
        latents = latents * latent_std
        images = self.vae.decode_latent_to_img(latents)  # (B, C, H, W)
        return images