# flickr8k-backend/core/sample_ddim.py
# Last commit cf8a031 (Rohan3): "Updated: removed sampling using std_uncond"
# (i.e. the guidance-rescale target no longer uses the unconditional prediction's std)
import sys, os
sys.path.insert(0, os.path.dirname(__file__))
from diffusers import DDIMScheduler, DDPMPipeline, DDPMScheduler
import torch
from torchvision.utils import save_image
import os
from config import *
from tqdm import tqdm
from unet import Unet
from text_embeddings import get_embedding_model, get_text_embedding
from vae import *
from seed import seed_everything
@torch.no_grad()
def ddim_sample(unet, noise_scheduler, shape, null_embedding=None, x_start=None, embedding=None, guidance_scale=ddim_guidace_scale, num_steps=ddim_num_sampling_steps, eta=0.0, device="cuda", guidance_rescale=0.7):
    """Run the DDIM reverse process and return the final latent sample.

    Args:
        unet: denoiser called as ``unet(x, t, context)``; its output is fed to
            ``noise_scheduler.step`` as ``model_output``.
        noise_scheduler: diffusers scheduler (a ``DDIMScheduler`` at the call sites in
            this file) whose ``step`` accepts ``eta``.
        shape: shape of the starting Gaussian noise, used when ``x_start`` is None.
        null_embedding: unconditional text context. When None and it is actually
            needed, it is loaded from ``null_embedding_dir`` and given a leading
            batch dim of 1.
        x_start: optional starting latent; it is cloned, so the caller's tensor is
            never modified.
        embedding: conditional text context with a leading batch dim of 1 (or equal to
            the sample batch); None means unconditional sampling.
        guidance_scale: classifier-free guidance weight. CFG is applied only when it
            is > 1 and ``embedding`` is given.
        num_steps: number of DDIM timesteps.
        eta: DDIM stochasticity; eta=0 -> deterministic DDIM.
        device: device for the generated noise, timestep tensors and the loaded null
            embedding.
        guidance_rescale: blend factor phi between the std-rescaled CFG output and the
            raw CFG output (0 disables rescaling, 1 uses the fully rescaled output).

    Returns:
        The denoised latent, same shape as the starting noise / ``x_start``.
    """
    x = x_start.clone() if x_start is not None else torch.randn(shape, device=device)
    batch = x.shape[0]  # use the real batch size: x_start may not match `shape`
    noise_scheduler.set_timesteps(num_steps)  # set DDIM timesteps

    use_cfg = guidance_scale > 1 and embedding is not None
    if embedding is not None:
        embedding = embedding.expand(batch, -1, -1)
    # The null context is only needed for CFG or for unconditional sampling, so the
    # file on disk is not touched when it would go unused.
    if use_cfg or embedding is None:
        if null_embedding is None:
            null_embedding = torch.load(null_embedding_dir, map_location=device, weights_only=True).unsqueeze(0)
        null_embedding = null_embedding.expand(batch, -1, -1)

    for t in tqdm(noise_scheduler.timesteps, desc="Sampling timesteps: ", colour=tqdm_colors[-1]):
        t_batch = torch.full((batch,), t, device=device, dtype=torch.long)
        if use_cfg:
            # One batched forward pass computes both the unconditional and conditional branches.
            out_uncond, out_cond = unet(
                torch.cat([x, x]),
                torch.cat([t_batch, t_batch]),
                torch.cat([null_embedding, embedding]),
            ).chunk(2)
            model_out = out_uncond + guidance_scale * (out_cond - out_uncond)
            if guidance_rescale > 0:
                # Guidance rescale: bring each sample's CFG output back to the std of its
                # conditional prediction. Stats are per sample (all non-batch dims), not
                # pooled over the batch.
                reduce_dims = tuple(range(1, model_out.ndim))
                std_cond = out_cond.std(dim=reduce_dims, keepdim=True)
                std_cfg = model_out.std(dim=reduce_dims, keepdim=True).clamp(min=1e-8)
                model_out_rescaled = model_out * (std_cond / std_cfg)
                model_out = guidance_rescale * model_out_rescaled + (1 - guidance_rescale) * model_out
        elif embedding is not None:
            model_out = unet(x, t_batch, embedding)
        else:
            model_out = unet(x, t_batch, null_embedding)
        x = noise_scheduler.step(model_output=model_out, timestep=t, sample=x, eta=eta).prev_sample  # DDIM step
    return x
def gen_n_sampled_img(checkpoint_path=None, unet=None, saved_nth_step=unet_max_steps, n=16, scale=ddim_guidace_scale, num_steps=ddim_num_sampling_steps):
    """Sample ``n`` images for one fixed caption embedding and save them as one PNG row.

    Args:
        checkpoint_path: path to a UNet checkpoint whose ``"ema"`` entry holds the
            state dict; required when ``unet`` is None.
        unet: an already-built UNet to sample with. It is switched to eval mode for
            sampling and its previous train/eval mode is restored afterwards.
        saved_nth_step: training step, used only in the output filename.
        n: number of samples, all conditioned on the same caption.
        scale: classifier-free guidance scale.
        num_steps: number of DDIM sampling steps.

    Raises:
        ValueError: if neither ``unet`` nor ``checkpoint_path`` is given.
    """
    seed_everything(seed=42)  # fixed seed so samples are comparable across checkpoints
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if unet is None:
        if checkpoint_path is None:
            raise ValueError("gen_n_sampled_img needs either `unet` or `checkpoint_path`")
        unet = Unet().to(device)
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
        unet.load_state_dict(checkpoint["ema"], strict=True)

    # Embedding of a single fixed caption, shared by every sample.
    text_emb = torch.load(f"{embedding_dir}/101654506_8eb26cfb60_3.pt", map_location=device, weights_only=True).unsqueeze(0)
    noise_scheduler = DDIMScheduler(
        num_train_timesteps=1000,
        beta_schedule=unet_beta_schedule,
        prediction_type=unet_pred_type,  # epsilon, v_prediction or sample
        rescale_betas_zero_snr=True,  # Match training
        timestep_spacing="trailing",  # Start from pure noise
        clip_sample=False,
        set_alpha_to_one=False,
    )

    was_training = unet.training
    unet.eval()
    try:
        latent_samples = ddim_sample(
            unet=unet,
            noise_scheduler=noise_scheduler,
            shape=(n, vae_latent_channels, vae_latent_dim, vae_latent_dim),
            embedding=text_emb,
            guidance_scale=scale,
            num_steps=num_steps,
            eta=0.0,
            device=device,
        )
    finally:
        unet.train(was_training)  # don't leave a caller's model in a changed mode

    # Undo the latent scaling; presumably mirrors the scaling applied to latents in training.
    latent_samples = latent_samples * latent_std

    vae = VAE().to(device)
    checkpoint = torch.load(f"{vae_checkpoint_dir}/{vae_weight}", map_location=device, weights_only=True)
    vae.load_state_dict(checkpoint["vae"])
    vae.eval()
    with torch.no_grad():  # inference only: don't build an autograd graph through the decoder
        recon_images = vae.decode_latent_to_img(latent_samples)

    out_dir = "./backend/core/ddim_sampled_img"
    os.makedirs(out_dir, exist_ok=True)
    save_image(
        recon_images,
        f"{out_dir}/101654506_8eb26cfb60_step_{saved_nth_step}_num_steps_{num_steps}_scale_{scale}.png",
        nrow=recon_images.size(0),  # all samples in a single row
    )
    print("💾 Successfully saved sampled images")
def gen_val_sampled_img(checkpoint_path=None, unet=None, saved_nth_step=unet_max_steps, n=4, scale=ddim_guidace_scale, num_steps=ddim_num_sampling_steps, num_prompts=4):
    """Sample a validation grid of ``n`` images for each of up to ``num_prompts`` captions.

    Caption embeddings are read from ``unet_val_embeddings_dir`` in sorted filename
    order. This keeps each grid row on the same caption across runs, platforms and
    checkpoints.

    Args:
        checkpoint_path: path to a UNet checkpoint whose ``"ema"`` entry holds the
            state dict; required when ``unet`` is None.
        unet: an already-built UNet to sample with. It is switched to eval mode for
            sampling and its previous train/eval mode is restored afterwards.
        saved_nth_step: training step, used only in the output filename.
        n: samples per caption (= grid columns).
        scale: classifier-free guidance scale.
        num_steps: number of DDIM sampling steps.
        num_prompts: maximum number of captions (= grid rows). If the directory holds
            fewer files, all of them are used.

    Raises:
        ValueError: if neither ``unet`` nor ``checkpoint_path`` is given, or if no
            embeddings are selected.
    """
    seed_everything(seed=42)  # fixed seed so grids are comparable across checkpoints
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if unet is None:
        if checkpoint_path is None:
            raise ValueError("gen_val_sampled_img needs either `unet` or `checkpoint_path`")
        unet = Unet().to(device)
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
        unet.load_state_dict(checkpoint["ema"], strict=True)

    # os.listdir order is arbitrary; sort so the same captions land in the same rows.
    embedding_files = sorted(os.listdir(unet_val_embeddings_dir))[:num_prompts]
    if not embedding_files:
        raise ValueError(f"No validation embeddings selected from {unet_val_embeddings_dir} (num_prompts={num_prompts})")
    val_embeddings = torch.cat([
        torch.load(f"{unet_val_embeddings_dir}/{name}", map_location=device, weights_only=True).unsqueeze(0)
        for name in embedding_files
    ])

    noise_scheduler = DDIMScheduler(
        num_train_timesteps=1000,
        beta_schedule=unet_beta_schedule,
        prediction_type=unet_pred_type,  # epsilon, v_prediction or sample
        rescale_betas_zero_snr=True,  # Match training
        timestep_spacing="trailing",  # Start from pure noise
        clip_sample=False,
        set_alpha_to_one=False,
    )
    # Built before sampling (as before) so seeded RNG consumption, and thus the grid, is unchanged.
    vae = VAE().to(device)
    checkpoint = torch.load(f"{vae_checkpoint_dir}/{vae_weight}", map_location=device, weights_only=True)
    vae.load_state_dict(checkpoint["vae"])
    vae.eval()

    img_rows = []
    was_training = unet.training
    unet.eval()
    try:
        for i in range(val_embeddings.shape[0]):
            latent_samples = ddim_sample(
                unet=unet,
                noise_scheduler=noise_scheduler,
                shape=(n, vae_latent_channels, vae_latent_dim, vae_latent_dim),
                x_start=None,  # fresh noise per caption
                embedding=val_embeddings[i:i + 1],
                guidance_scale=scale,
                num_steps=num_steps,
                eta=0.0,
                device=device,
            )
            # Undo the latent scaling; presumably mirrors the scaling applied in training.
            latent_samples = latent_samples * latent_std
            with torch.no_grad():  # inference only: no autograd graph through the decoder
                img_rows.append(vae.decode_latent_to_img(latent_samples))
    finally:
        unet.train(was_training)  # don't leave a caller's model in a changed mode

    os.makedirs(ddim_img_dir, exist_ok=True)
    save_image(
        torch.cat(img_rows, dim=0),
        f"{ddim_img_dir}/val_step_{saved_nth_step}_num_steps_{num_steps}_scale_{scale}.png",
        nrow=n,  # one row per caption
    )
    print("💾 Successfully saved validation grid images")
if __name__ == "__main__":
    # Sample from one EMA checkpoint twice: a 16-image strip for the fixed caption,
    # then the validation grid (4 samples per caption). Both use 100 DDIM steps at
    # guidance scale 8.
    ckpt = "./backend/core/checkpoints/ldm/ema_loss_0.220238.pth"
    shared_kwargs = dict(checkpoint_path=ckpt, scale=8, num_steps=100)
    gen_n_sampled_img(n=16, **shared_kwargs)
    gen_val_sampled_img(n=4, **shared_kwargs)