import contextlib
import sys, os
sys.path.insert(0, os.path.dirname(__file__))
from diffusers import DDIMScheduler, DDPMPipeline, DDPMScheduler
import torch
from torchvision.utils import save_image
import os
from config import *
from tqdm import tqdm
from unet import Unet
from text_embeddings import get_embedding_model, get_text_embedding
from vae import *
from seed import seed_everything


@contextlib.contextmanager
def _eval_mode(module):
    """Run the enclosed code with *module* in eval mode, then restore its mode."""
    was_training = module.training
    module.eval()
    try:
        yield module
    finally:
        module.train(was_training)


def _resolve_unet(checkpoint_path, unet, device):
    """Return *unet* if given, else build a Unet and load its "ema" weights."""
    if unet is not None:
        return unet
    if checkpoint_path is None:
        raise ValueError("Either `unet` or `checkpoint_path` must be provided.")
    unet = Unet().to(device)
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
    unet.load_state_dict(checkpoint["ema"], strict=True)
    return unet


def _build_ddim_scheduler():
    """Create the DDIM scheduler shared by both image-generation entry points."""
    return DDIMScheduler(
        num_train_timesteps=1000,
        beta_schedule=unet_beta_schedule,
        prediction_type=unet_pred_type,  # epsilon, v_prediction or sample
        rescale_betas_zero_snr=True,  # Match training
        timestep_spacing="trailing",  # Start from pure noise
        clip_sample=False,
        set_alpha_to_one=False,
    )


def _load_vae(device):
    """Build the VAE, load its "vae" weights from the configured path, set eval."""
    vae = VAE().to(device)
    path = f"{vae_checkpoint_dir}/{vae_weight}"
    checkpoint = torch.load(path, map_location=device, weights_only=True)
    vae.load_state_dict(checkpoint["vae"])
    vae.eval()
    return vae


@torch.no_grad()
def ddim_sample(unet, noise_scheduler, shape, null_embedding=None, x_start=None,
                embedding=None, guidance_scale=ddim_guidace_scale,
                num_steps=ddim_num_sampling_steps, eta=0.0, device="cuda",
                guidance_rescale=0.7):
    """Run the DDIM sampling loop and return the final latent batch.

    Args:
        unet: Callable invoked as ``unet(x, timesteps, embedding)``.
        noise_scheduler: Scheduler exposing ``set_timesteps``, ``timesteps``
            and ``step(...).prev_sample``.
        shape: Shape of the initial Gaussian noise; only used when
            ``x_start`` is None.
        null_embedding: Unconditional embedding. When None it is loaded from
            ``null_embedding_dir`` and given a leading batch axis.
        x_start: Optional starting latents; cloned so the caller's tensor is
            not modified.
        embedding: Optional conditioning embedding, expanded to the batch size.
        guidance_scale: Classifier-free guidance weight; guidance is applied
            only when it is > 1 and ``embedding`` is given.
        num_steps: Number of sampling steps passed to the scheduler.
        eta: Forwarded to ``noise_scheduler.step`` (0 -> deterministic DDIM).
        device: Device for the initial noise and timestep tensors.
        guidance_rescale: Blend weight between the std-rescaled guided output
            and the raw guided output (0 disables the rescale).

    Returns:
        The sample tensor after the last scheduler step.
    """
    x = x_start.clone() if x_start is not None else torch.randn(shape, device=device)
    batch_size = x.shape[0]
    noise_scheduler.set_timesteps(num_steps)

    if embedding is not None:
        embedding = embedding.expand(batch_size, -1, -1)
    if null_embedding is None:
        # Original author's note: the stored tensor becomes [1, 77, 1024] here.
        null_embedding = torch.load(
            null_embedding_dir, map_location=device, weights_only=True
        ).unsqueeze(0)
    null_embedding = null_embedding.expand(batch_size, -1, -1)

    use_cfg = embedding is not None and guidance_scale > 1
    if use_cfg:
        # Loop-invariant: unconditional and conditional halves of one batched call.
        cfg_embeddings = torch.cat([null_embedding, embedding])
    else:
        plain_embedding = embedding if embedding is not None else null_embedding

    for t in tqdm(noise_scheduler.timesteps, desc="Sampling timesteps: ", colour=tqdm_colors[-1]):
        # Sized from the actual batch so a custom x_start need not match `shape`.
        t_batch = torch.full((batch_size,), t, device=device, dtype=torch.long)

        if use_cfg:
            out_uncond, out_cond = unet(
                torch.cat([x, x]), torch.cat([t_batch, t_batch]), cfg_embeddings
            ).chunk(2)
            model_out = out_uncond + guidance_scale * (out_cond - out_uncond)

            # Guidance rescale: pull the guided output's std (taken over the
            # whole batch tensor) back to the conditional prediction's std,
            # then blend with the unrescaled output.
            std_cond = out_cond.std()
            std_cfg = model_out.std()
            factor = std_cond / std_cfg.clamp(min=1e-8)
            model_out_rescaled = model_out * factor
            model_out = guidance_rescale * model_out_rescaled + (1 - guidance_rescale) * model_out
        else:
            model_out = unet(x, t_batch, plain_embedding)

        x = noise_scheduler.step(model_output=model_out, timestep=t, sample=x, eta=eta).prev_sample

    return x


@torch.no_grad()
def gen_n_sampled_img(checkpoint_path=None, unet=None, saved_nth_step=unet_max_steps,
                      n=16, scale=ddim_guidace_scale, num_steps=ddim_num_sampling_steps):
    """Sample ``n`` images for one fixed text embedding and save them as a row.

    Args:
        checkpoint_path: Checkpoint whose "ema" entry is loaded when ``unet``
            is None.
        unet: Optional ready-to-use model; takes precedence over
            ``checkpoint_path``.
        saved_nth_step: Only used in the output file name.
        n: Number of images to sample.
        scale: Guidance scale forwarded to ``ddim_sample``.
        num_steps: Number of DDIM steps.

    Raises:
        ValueError: If neither ``unet`` nor ``checkpoint_path`` is given.
    """
    seed_everything(seed=42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(review): the model is built before sampling, as in the original;
    # module construction presumably draws from the seeded RNG, so reordering
    # would change the sampled noise.
    unet = _resolve_unet(checkpoint_path, unet, device)

    text_emb = torch.load(
        f"{embedding_dir}/101654506_8eb26cfb60_3.pt", map_location=device, weights_only=True
    ).unsqueeze(0)
    noise_scheduler = _build_ddim_scheduler()

    # Sample in eval mode without permanently changing a caller-owned model.
    with _eval_mode(unet):
        latent_samples = ddim_sample(
            unet=unet,
            noise_scheduler=noise_scheduler,
            shape=(n, vae_latent_channels, vae_latent_dim, vae_latent_dim),
            embedding=text_emb,
            guidance_scale=scale,
            num_steps=num_steps,
            eta=0.0,
            device=device,
        )

    latent_samples = latent_samples * latent_std  # Reverse scale the latents

    vae = _load_vae(device)
    recon_images = vae.decode_latent_to_img(latent_samples)

    os.makedirs("./backend/core/ddim_sampled_img", exist_ok=True)
    save_image(
        recon_images,
        f"./backend/core/ddim_sampled_img/101654506_8eb26cfb60_step_{saved_nth_step}_num_steps_{num_steps}_scale_{scale}.png",
        nrow=recon_images.size(0),
    )
    print("💾 Successfully saved sampled images")


@torch.no_grad()
def gen_val_sampled_img(checkpoint_path=None, unet=None, saved_nth_step=unet_max_steps,
                        n=4, scale=ddim_guidace_scale, num_steps=ddim_num_sampling_steps,
                        num_prompts=4):
    """Sample a validation grid: one row of ``n`` images per validation embedding.

    Args:
        checkpoint_path: Checkpoint whose "ema" entry is loaded when ``unet``
            is None.
        unet: Optional ready-to-use model; takes precedence over
            ``checkpoint_path``.
        saved_nth_step: Only used in the output file name.
        n: Images per row (also the grid's ``nrow``).
        scale: Guidance scale forwarded to ``ddim_sample``.
        num_steps: Number of DDIM steps.
        num_prompts: Maximum number of embeddings (grid rows) to use, taken
            from the sorted listing of ``unet_val_embeddings_dir``.

    Raises:
        ValueError: If neither ``unet`` nor ``checkpoint_path`` is given.
        RuntimeError: If no validation embedding is selected.
    """
    seed_everything(seed=42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    unet = _resolve_unet(checkpoint_path, unet, device)

    # os.listdir returns entries in arbitrary order; sort so that the grid
    # rows are the same on every run and machine.
    embedding_files = sorted(os.listdir(unet_val_embeddings_dir))[:max(num_prompts, 0)]
    if not embedding_files:
        raise RuntimeError(
            f"No validation embeddings selected from {unet_val_embeddings_dir} "
            f"(num_prompts={num_prompts})"
        )
    val_embeddings = [
        torch.load(f"{unet_val_embeddings_dir}/{name}", map_location=device, weights_only=True)
        .unsqueeze(0)
        .to(device)
        for name in embedding_files
    ]

    noise_scheduler = _build_ddim_scheduler()
    # NOTE(review): the VAE is built before sampling, as in the original;
    # moving it after the loop would presumably shift the seeded RNG stream.
    vae = _load_vae(device)

    img_rows = []
    # Sample in eval mode without permanently changing a caller-owned model.
    with _eval_mode(unet):
        for prompt_embedding in val_embeddings:
            latent_samples = ddim_sample(
                unet=unet,
                noise_scheduler=noise_scheduler,
                shape=(n, vae_latent_channels, vae_latent_dim, vae_latent_dim),
                x_start=None,
                embedding=prompt_embedding,
                guidance_scale=scale,
                num_steps=num_steps,
                eta=0.0,
                device=device,
            )
            latent_samples = latent_samples * latent_std  # Reverse scale the latents
            img_rows.append(vae.decode_latent_to_img(latent_samples))

    os.makedirs(ddim_img_dir, exist_ok=True)
    save_image(
        torch.cat(img_rows, dim=0),
        f"{ddim_img_dir}/val_step_{saved_nth_step}_num_steps_{num_steps}_scale_{scale}.png",
        nrow=n,
    )
    print("💾 Successfully saved validation grid images")


if __name__ == "__main__":
    checkpoint_path = "./backend/core/checkpoints/ldm/ema_loss_0.220238.pth"
    gen_n_sampled_img(checkpoint_path=checkpoint_path, n=16, scale=8, num_steps=100)
    gen_val_sampled_img(checkpoint_path=checkpoint_path, n=4, scale=8, num_steps=100)