File size: 4,476 Bytes
4aabce3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a625e96
4aabce3
87b5061
 
 
a625e96
4aabce3
 
 
 
 
 
 
 
 
a625e96
4aabce3
 
 
 
87b5061
 
4aabce3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import torch
import open_clip
import os
from diffusers import DDIMScheduler
from core.unet import Unet
from core.vae import VAE
from core.sample_ddim import ddim_sample
from core.config import *
from core.seed import seed_everything
from huggingface_hub import hf_hub_download
from torchvision.utils import save_image

class LDMPipeline:
    """Text-to-image latent diffusion pipeline.

    On construction, downloads checkpoints from the Hugging Face Hub and loads:
    the UNet (EMA weights), the VAE, an open_clip text encoder, the null
    embedding passed to the sampler for guidance, and a DDIM scheduler.
    Call :meth:`generate` to sample images for a caption.
    """

    # Hugging Face Hub repository holding every checkpoint this pipeline loads.
    WEIGHTS_REPO_ID = "Rohan3/flickr8k-ldm-weights"

    def __init__(self, device=None):
        """Load all models onto ``device``.

        Args:
            device: Target ``torch.device``. Defaults to CUDA when available,
                otherwise CPU.
        """
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._load_models()

    def _download_weights(self, subfolder, filename):
        """Download ``subfolder/filename`` from the weights repo; return its local path.

        The ``HF_TOKEN`` environment variable, if set, is forwarded as the
        access token.
        """
        return hf_hub_download(
            repo_id=self.WEIGHTS_REPO_ID,
            subfolder=subfolder,
            filename=filename,
            token=os.getenv("HF_TOKEN"),
        )

    def _load_models(self):
        """Download checkpoints and build every component used by :meth:`generate`."""
        print("Loading UNET...")
        self.unet = Unet().to(self.device)
        unet_path = self._download_weights("ldm", "best_ema.pth")
        checkpoint = torch.load(unet_path, map_location=self.device, weights_only=True)
        # Sampling uses the EMA weights stored under the "ema" key.
        self.unet.load_state_dict(checkpoint["ema"], strict=True)
        self.unet.eval()

        print("Loading VAE...")
        self.vae = VAE().to(self.device)
        vae_path = self._download_weights("vae", "vae_best.pth")
        vae_ckpt = torch.load(vae_path, map_location=self.device, weights_only=True)
        self.vae.load_state_dict(vae_ckpt["vae"], strict=True)
        self.vae.eval()

        print("Loading CLIP...")
        self.text_model, _, _ = open_clip.create_model_and_transforms(
            embedding_model, pretrained=embedding_pretrained, device=self.device
        )
        self.text_model.eval()
        # get_text_embedding feeds the transformer sequence-first (L, N, D); this
        # flag is meant to make open_clip's Transformer accept that layout.
        # NOTE(review): the meaning of `batch_first` differs across open_clip
        # versions — confirm against the installed version.
        self.text_model.transformer.batch_first = False
        # The text encoder is only used for inference: freeze all its parameters.
        self.text_model.requires_grad_(False)
        self.tokenizer = open_clip.get_tokenizer(embedding_model)

        print("Loading NULL EMBEDDING...")
        null_embedding_path = self._download_weights("null_embedding", "null_embedding.pt")
        # Leading batch dim, presumably to mirror the unsqueeze(0) that generate()
        # applies to the caption embedding — confirm against ddim_sample.
        self.null_embedding = torch.load(
            null_embedding_path, map_location=self.device, weights_only=True
        ).unsqueeze(0)

        print("Loading DDIM SCHEDULER...")
        self.noise_scheduler = DDIMScheduler(
            num_train_timesteps=1000,
            beta_schedule=unet_beta_schedule,
            prediction_type=unet_pred_type,
            rescale_betas_zero_snr=True,
            timestep_spacing="trailing",
            clip_sample=False,
            set_alpha_to_one=False,
        )
        print("All models loaded...")

    @torch.no_grad()
    def get_text_embedding(self, caption: str):
        """Encode ``caption`` into per-token contextual text features.

        Runs the open_clip text tower manually:
        token embedding + positional embedding -> causal transformer -> ``ln_final``.
        No pooling or projection is applied.

        Args:
            caption: The text prompt.

        Returns:
            Tensor of shape ``(seq_len, width)`` with the batch dimension squeezed.
            The original author notes (77, 1024) for the configured model.
        """
        tokens = self.tokenizer(caption).to(self.device)
        x = self.text_model.token_embedding(tokens) + self.text_model.positional_embedding
        x = x.permute(1, 0, 2)  # (N, L, D) -> (L, N, D) for the transformer
        seq_len = x.shape[0]
        # Causal mask: -inf strictly above the diagonal, so each token attends
        # only to itself and earlier tokens.
        causal_mask = torch.full((seq_len, seq_len), float("-inf"), device=x.device).triu(1)
        x = self.text_model.transformer(x, attn_mask=causal_mask)
        x = x.permute(1, 0, 2)  # back to (N, L, D)
        return self.text_model.ln_final(x).squeeze(0)

    @torch.no_grad()
    def generate(self, caption: str, num_images: int = 4, num_steps: int = 50, guidance_scale: float = 7.5, seed: int = 42, eta: float = 0):
        """Sample ``num_images`` images for ``caption`` with DDIM.

        Runs under ``torch.no_grad()``. The UNet and VAE parameters are not
        frozen in ``_load_models``, so without this the sampling loop and the
        decode could record an autograd graph and waste memory.

        Args:
            caption: Text prompt. Surrounding whitespace and trailing periods
                are removed.
            num_images: Batch size of the sampled latents.
            num_steps: Number of DDIM denoising steps.
            guidance_scale: Guidance scale forwarded to ``ddim_sample``.
            seed: Seed passed to ``seed_everything`` for reproducibility.
            eta: DDIM eta forwarded to ``ddim_sample``.

        Returns:
            The output of ``self.vae.decode_latent_to_img``, shaped
            (B, C, H, W) according to the original comment.
        """
        seed_everything(seed)
        caption = caption.strip().rstrip(".")
        embedding = self.get_text_embedding(caption).unsqueeze(0)
        latents = ddim_sample(
            unet=self.unet,
            noise_scheduler=self.noise_scheduler,
            shape=(num_images, vae_latent_channels, vae_latent_dim, vae_latent_dim),
            null_embedding=self.null_embedding,
            embedding=embedding,
            guidance_scale=guidance_scale,
            num_steps=num_steps,
            eta=eta,
            device=self.device,
        )
        # Rescale by latent_std before decoding. Presumably this undoes a
        # normalization applied to VAE latents during training — confirm.
        latents = latents * latent_std
        return self.vae.decode_latent_to_img(latents)  # (B, C, H, W)