Spaces:
Sleeping
Sleeping
| import sys, os | |
| sys.path.insert(0, os.path.dirname(__file__)) | |
| from diffusers import DDIMScheduler, DDPMPipeline, DDPMScheduler | |
| import torch | |
| from torchvision.utils import save_image | |
| import os | |
| from config import * | |
| from tqdm import tqdm | |
| from unet import Unet | |
| from text_embeddings import get_embedding_model, get_text_embedding | |
| from vae import * | |
| from seed import seed_everything | |
@torch.no_grad()  # sampling only: without this, autograd history accumulates across every denoising step
def ddim_sample(unet, noise_scheduler, shape, null_embedding=None, x_start=None, embedding=None, guidance_scale=ddim_guidace_scale, num_steps=ddim_num_sampling_steps, eta=0.0, device="cuda", guidance_rescale=0.7):
    """Run the DDIM reverse process and return the final (still scaled) latents.

    Args:
        unet: denoiser, called as ``unet(x, t_batch, context)``.
        noise_scheduler: DDIM scheduler; ``set_timesteps(num_steps)`` is called on it,
            so its timestep state is reset by every call.
        shape: shape of the initial Gaussian noise; only used when ``x_start`` is None.
        null_embedding: unconditional context. Loaded from ``null_embedding_dir`` when
            it is needed and not supplied. Expanded along dim 0 to the batch size.
        x_start: optional starting latents; cloned, never modified in place.
        embedding: conditional context, expanded along dim 0 to the batch size
            (so its leading dim must be 1 or already equal to the batch size).
        guidance_scale: classifier-free guidance scale; CFG is used only when it is
            > 1 and ``embedding`` is given.
        num_steps: number of DDIM sampling steps.
        eta: DDIM stochasticity; 0 gives deterministic DDIM.
        device: device for the initial noise, the timestep tensor and a loaded
            null embedding.
        guidance_rescale: blend weight (phi) of the CFG std-rescale; 0 disables it.
            The default 0.7 is the value previously hard-coded here.

    Returns:
        The latents after the last scheduler step, same shape as the starting latents.
    """
    x = x_start.clone() if x_start is not None else torch.randn(shape, device=device)
    batch = x.shape[0]  # size by the actual latents, which may come from x_start rather than `shape`
    noise_scheduler.set_timesteps(num_steps)  # set DDIM timesteps

    use_cfg = guidance_scale > 1 and embedding is not None
    if embedding is not None:
        embedding = embedding.expand(batch, -1, -1)
    # The null (unconditional) context is only needed for CFG or when no embedding is given.
    if use_cfg or embedding is None:
        if null_embedding is None:
            null_embedding = torch.load(null_embedding_dir, map_location=device, weights_only=True).unsqueeze(0)  # [1, 77, 1024]
        null_embedding = null_embedding.expand(batch, -1, -1)

    for t in tqdm(noise_scheduler.timesteps, desc="Sampling timesteps: ", colour=tqdm_colors[-1]):
        t_batch = torch.full((batch,), t, device=device, dtype=torch.long)
        if use_cfg:
            # One batched forward pass for the unconditional and conditional predictions.
            out_uncond, out_cond = unet(torch.cat([x, x]), torch.cat([t_batch, t_batch]), torch.cat([null_embedding, embedding])).chunk(2)
            model_out = out_uncond + guidance_scale * (out_cond - out_uncond)
            if guidance_rescale > 0:
                # Guidance rescale: bring each sample's CFG output back to the std of its
                # conditional prediction. Stds are per sample (all dims except batch) so
                # samples in the same batch do not influence each other.
                reduce_dims = tuple(range(1, model_out.ndim))
                std_cond = out_cond.std(dim=reduce_dims, keepdim=True)
                std_cfg = model_out.std(dim=reduce_dims, keepdim=True).clamp(min=1e-8)
                model_out_rescaled = model_out * (std_cond / std_cfg)
                model_out = guidance_rescale * model_out_rescaled + (1 - guidance_rescale) * model_out
        elif embedding is not None:
            model_out = unet(x, t_batch, embedding)
        else:
            model_out = unet(x, t_batch, null_embedding)
        x = noise_scheduler.step(model_output=model_out, timestep=t, sample=x, eta=eta).prev_sample  # DDIM step
    return x
def _make_ddim_scheduler():
    """Build the DDIMScheduler used by both sampling entry points (settings must match training)."""
    return DDIMScheduler(
        num_train_timesteps=1000,
        beta_schedule=unet_beta_schedule,
        prediction_type=unet_pred_type,  # epsilon, v_prediction or sample
        rescale_betas_zero_snr=True,  # Match training
        timestep_spacing="trailing",  # Start from pure noise
        clip_sample=False,
        set_alpha_to_one=False,
    )


def _load_ema_unet(checkpoint_path, device):
    """Build a Unet on *device*, load the checkpoint's "ema" weights and put it in eval mode."""
    if checkpoint_path is None:
        raise ValueError("checkpoint_path is required when no unet is passed")
    unet = Unet().to(device)
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
    unet.load_state_dict(checkpoint["ema"], strict=True)
    unet.eval()
    return unet


def _load_vae(device):
    """Build a VAE on *device*, load "vae" weights from vae_checkpoint_dir/vae_weight, set eval mode."""
    vae = VAE().to(device)
    path = f"{vae_checkpoint_dir}/{vae_weight}"
    checkpoint = torch.load(path, map_location=device, weights_only=True)
    vae.load_state_dict(checkpoint["vae"])
    vae.eval()
    return vae


def _sample_in_eval_mode(unet, **sample_kwargs):
    """Call ddim_sample with *unet* in eval mode, then restore the unet's previous train/eval mode.

    A caller-supplied unet (e.g. one currently being trained) is therefore left as it was found.
    """
    was_training = unet.training
    unet.eval()
    try:
        return ddim_sample(unet=unet, **sample_kwargs)
    finally:
        unet.train(was_training)


@torch.no_grad()  # inference only: avoid building autograd graphs through sampling and decoding
def gen_n_sampled_img(checkpoint_path=None, unet=None, saved_nth_step=unet_max_steps, n=16, scale=ddim_guidace_scale, num_steps=ddim_num_sampling_steps):
    """Sample *n* images for one fixed caption embedding and save them as a single-row PNG.

    Args:
        checkpoint_path: LDM checkpoint whose "ema" weights are loaded when *unet* is None.
        unet: an already-built UNet to sample with instead of loading one.
        saved_nth_step: step number, used only in the output filename.
        n: number of images to sample.
        scale: classifier-free guidance scale passed to ddim_sample.
        num_steps: number of DDIM sampling steps.

    Raises:
        ValueError: if both *unet* and *checkpoint_path* are None.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if unet is None:
        unet = _load_ema_unet(checkpoint_path, device)
    text_emb = torch.load(f"{embedding_dir}/101654506_8eb26cfb60_3.pt", map_location=device, weights_only=True).unsqueeze(0)
    noise_scheduler = _make_ddim_scheduler()
    vae = _load_vae(device)
    # Seed right before sampling, after every model is built, so that RNG draws made
    # while constructing models (e.g. weight init on CPU) cannot shift the starting noise.
    seed_everything(seed=42)
    latent_samples = _sample_in_eval_mode(
        unet,
        noise_scheduler=noise_scheduler,
        shape=(n, vae_latent_channels, vae_latent_dim, vae_latent_dim),
        embedding=text_emb,
        guidance_scale=scale,
        num_steps=num_steps,
        eta=0.0,
        device=device,
    )
    latent_samples = latent_samples * latent_std  # Reverse scale the latents
    recon_images = vae.decode_latent_to_img(latent_samples)
    out_dir = "./backend/core/ddim_sampled_img"
    os.makedirs(out_dir, exist_ok=True)
    save_image(recon_images, f"{out_dir}/101654506_8eb26cfb60_step_{saved_nth_step}_num_steps_{num_steps}_scale_{scale}.png", nrow=recon_images.size(0))
    print("💾 Successfully saved sampled images")


@torch.no_grad()  # inference only: avoid building autograd graphs through sampling and decoding
def gen_val_sampled_img(checkpoint_path=None, unet=None, saved_nth_step=unet_max_steps, n=4, scale=ddim_guidace_scale, num_steps=ddim_num_sampling_steps, num_prompts=4):
    """Sample a validation grid: one row of *n* images per validation embedding.

    Args:
        checkpoint_path: LDM checkpoint whose "ema" weights are loaded when *unet* is None.
        unet: an already-built UNet to sample with instead of loading one.
        saved_nth_step: step number, used only in the output filename.
        n: images per row (samples per prompt).
        scale: classifier-free guidance scale passed to ddim_sample.
        num_steps: number of DDIM sampling steps.
        num_prompts: maximum number of validation embeddings (grid rows) to use;
            fewer rows are produced if the directory holds fewer files.

    Raises:
        ValueError: if both *unet* and *checkpoint_path* are None.
        FileNotFoundError: if unet_val_embeddings_dir contains no files.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if unet is None:
        unet = _load_ema_unet(checkpoint_path, device)
    # os.listdir order is arbitrary; sort so each row maps to the same prompt on every run.
    emb_names = sorted(os.listdir(unet_val_embeddings_dir))[:num_prompts]
    if not emb_names:
        raise FileNotFoundError(f"No validation embeddings found in {unet_val_embeddings_dir}")
    val_embeddings = [
        torch.load(f"{unet_val_embeddings_dir}/{name}", map_location=device, weights_only=True).unsqueeze(0)
        for name in emb_names
    ]
    noise_scheduler = _make_ddim_scheduler()
    vae = _load_vae(device)
    # Seed right before sampling, after every model is built (see gen_n_sampled_img).
    seed_everything(seed=42)
    img_rows = []
    for embedding in val_embeddings:
        latent_samples = _sample_in_eval_mode(
            unet,
            noise_scheduler=noise_scheduler,
            shape=(n, vae_latent_channels, vae_latent_dim, vae_latent_dim),
            x_start=None,  # fresh noise for every row
            embedding=embedding,
            guidance_scale=scale,
            num_steps=num_steps,
            eta=0.0,
            device=device,
        )
        latent_samples = latent_samples * latent_std  # Reverse scale the latents
        img_rows.append(vae.decode_latent_to_img(latent_samples))
    os.makedirs(ddim_img_dir, exist_ok=True)
    save_image(torch.cat(img_rows, dim=0), f"{ddim_img_dir}/val_step_{saved_nth_step}_num_steps_{num_steps}_scale_{scale}.png", nrow=n)
    print("💾 Successfully saved validation grid images")
if __name__ == "__main__":
    # Script entry: sample a 16-image row and a 4-per-prompt validation grid from a
    # saved LDM checkpoint, using guidance scale 8 and 100 DDIM steps for both.
    ckpt = "./backend/core/checkpoints/ldm/ema_loss_0.220238.pth"
    shared_kwargs = {"checkpoint_path": ckpt, "scale": 8, "num_steps": 100}
    gen_n_sampled_img(n=16, **shared_kwargs)
    gen_val_sampled_img(n=4, **shared_kwargs)