flickr8k-backend / app /model.py
Rohan3's picture
Updated: VAE, UNet, config, text embeddings, model and main
a625e96
import torch
import open_clip
import os
from diffusers import DDIMScheduler
from core.unet import Unet
from core.vae import VAE
from core.sample_ddim import ddim_sample
from core.config import *
from core.seed import seed_everything
from huggingface_hub import hf_hub_download
from torchvision.utils import save_image
class LDMPipeline:
    """Caption-to-image inference pipeline built from a UNet, a VAE and a CLIP text encoder.

    On construction the pipeline downloads the UNet (EMA weights), the VAE and
    a precomputed null embedding from the Hugging Face Hub, creates an
    open_clip text model plus tokenizer, and configures a ``DDIMScheduler``.
    ``generate`` then turns a caption into decoded images.

    Configuration names used below (``embedding_model``,
    ``embedding_pretrained``, ``unet_beta_schedule``, ``unet_pred_type``,
    ``vae_latent_channels``, ``vae_latent_dim``, ``latent_std``) are not
    defined in this class; presumably they come from the star import of
    ``core.config`` -- confirm there.
    """

    # Hub repository that stores every checkpoint this pipeline downloads.
    _WEIGHTS_REPO = "Rohan3/flickr8k-ldm-weights"

    def __init__(self, device=None):
        """Pick the compute device and load every model onto it.

        Args:
            device: Device to run on. When falsy (e.g. ``None``), CUDA is used
                if available, otherwise the CPU.
        """
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._load_models()

    @classmethod
    def _download_weight(cls, subfolder: str, filename: str) -> str:
        """Download one file from the weights repo and return its local path.

        The access token is read from the ``HF_TOKEN`` environment variable;
        ``None`` is passed through when the variable is unset.
        """
        return hf_hub_download(
            repo_id=cls._WEIGHTS_REPO,
            subfolder=subfolder,
            filename=filename,
            token=os.getenv("HF_TOKEN"),
        )

    @staticmethod
    def _normalize_caption(caption: str) -> str:
        """Strip surrounding whitespace, then drop any trailing periods.

        ``rstrip(".")`` is already a no-op when the caption does not end with
        a period, so no ``endswith`` guard is needed.
        """
        return caption.strip().rstrip(".")

    def _load_models(self) -> None:
        """Download checkpoints and build the UNet, VAE, text encoder and scheduler."""
        print("Loading UNET...")
        self.unet = Unet().to(self.device)
        unet_path = self._download_weight("ldm", "best_ema.pth")
        checkpoint = torch.load(unet_path, map_location=self.device, weights_only=True)
        # The checkpoint stores the EMA weights under the "ema" key.
        self.unet.load_state_dict(checkpoint["ema"], strict=True)
        self.unet.eval()

        print("Loading VAE...")
        self.vae = VAE().to(self.device)
        vae_path = self._download_weight("vae", "vae_best.pth")
        vae_ckpt = torch.load(vae_path, map_location=self.device, weights_only=True)
        self.vae.load_state_dict(vae_ckpt["vae"])
        self.vae.eval()

        print("Loading CLIP...")
        self.text_model, _, _ = open_clip.create_model_and_transforms(
            embedding_model,
            pretrained=embedding_pretrained,
            device=self.device,
        )
        self.text_model.eval()
        # get_text_embedding feeds the transformer sequence-first (L, N, D)
        # tensors, so batch-first handling is switched off here.
        # NOTE(review): relies on the open_clip transformer honouring this
        # attribute -- confirm against the installed open_clip version.
        self.text_model.transformer.batch_first = False
        # The text encoder is used for inference only; freeze its parameters.
        self.text_model.requires_grad_(False)
        self.tokenizer = open_clip.get_tokenizer(embedding_model)

        print("Loading NULL EMBEDDING...")
        null_embedding_path = self._download_weight("null_embedding", "null_embedding.pt")
        # A leading dimension is added to the stored tensor (presumably the
        # batch dimension expected by ddim_sample -- confirm).
        self.null_embedding = torch.load(
            null_embedding_path,
            map_location=self.device,
            weights_only=True,
        ).unsqueeze(0)

        print("Loading DDIM SCHEDULER...")
        self.noise_scheduler = DDIMScheduler(
            num_train_timesteps=1000,
            beta_schedule=unet_beta_schedule,
            prediction_type=unet_pred_type,
            rescale_betas_zero_snr=True,
            timestep_spacing="trailing",
            clip_sample=False,
            set_alpha_to_one=False,
        )
        print("All models loaded...")

    @torch.no_grad()
    def get_text_embedding(self, caption: str) -> torch.Tensor:
        """Return the per-token contextual embedding of ``caption``.

        The caption is tokenized, embedded, offset by the positional
        embedding, passed through the text transformer under a causal mask and
        finally through ``ln_final``. No pooling or projection is applied.

        Args:
            caption: Text to encode.

        Returns:
            The ``ln_final`` output with dimension 0 squeezed out when it has
            size 1 (the original author notes ``(77, 1024)`` for the
            configured model -- not verifiable from this file).
        """
        tokens = self.tokenizer(caption).to(self.device)
        hidden = self.text_model.token_embedding(tokens) + self.text_model.positional_embedding
        hidden = hidden.permute(1, 0, 2)  # (N, L, D) -> (L, N, D) for the transformer

        # Causal mask: -inf strictly above the diagonal, 0 elsewhere, so each
        # token only attends to itself and to earlier positions.
        seq_len = hidden.shape[0]
        causal_mask = torch.full(
            (seq_len, seq_len), float("-inf"), device=hidden.device
        ).triu_(1)

        hidden = self.text_model.transformer(hidden, attn_mask=causal_mask)
        hidden = hidden.permute(1, 0, 2)  # back to (N, L, D)
        per_token_contextual = self.text_model.ln_final(hidden)
        return per_token_contextual.squeeze(0)

    @torch.no_grad()
    def generate(self, caption: str, num_images: int = 4, num_steps: int = 50, guidance_scale: float = 7.5, seed: int = 42, eta: float = 0):
        """Generate images for ``caption`` via DDIM sampling and VAE decoding.

        Runs with autograd disabled: only the text encoder's parameters are
        frozen at load time, so without this the UNet and VAE forward passes
        could record autograd graphs for every sampling step and waste memory
        during pure inference.

        Args:
            caption: Prompt text. Surrounding whitespace and trailing periods
                are removed before encoding.
            num_images: Batch size of the sampled latents.
            num_steps: Forwarded to ``ddim_sample`` as ``num_steps``.
            guidance_scale: Forwarded to ``ddim_sample`` as ``guidance_scale``.
            seed: Passed to ``seed_everything`` before sampling.
            eta: Forwarded to ``ddim_sample`` as ``eta``.

        Returns:
            Whatever ``self.vae.decode_latent_to_img`` returns for the sampled
            latents (the original author notes ``(B, C, H, W)`` -- confirm in
            ``core.vae``).
        """
        seed_everything(seed)
        caption = self._normalize_caption(caption)
        embedding = self.get_text_embedding(caption).unsqueeze(0)

        latents = ddim_sample(
            unet=self.unet,
            noise_scheduler=self.noise_scheduler,
            shape=(num_images, vae_latent_channels, vae_latent_dim, vae_latent_dim),
            null_embedding=self.null_embedding,
            embedding=embedding,
            guidance_scale=guidance_scale,
            num_steps=num_steps,
            eta=eta,
            device=self.device,
        )
        # Multiply by latent_std before decoding (presumably undoing the
        # latent normalisation used in training -- confirm in core.config).
        latents = latents * latent_std
        return self.vae.decode_latent_to_img(latents)