import sys, os
sys.path.insert(0, os.path.dirname(__file__))
from diffusers import DDIMScheduler, DDPMPipeline, DDPMScheduler
import torch
from torchvision.utils import save_image
import os
from config import *
from tqdm import tqdm
from unet import Unet
from text_embeddings import get_embedding_model, get_text_embedding
from vae import *
from seed import seed_everything
@torch.no_grad()
def ddim_sample(unet, noise_scheduler, shape, null_embedding=None, x_start=None, embedding=None, guidance_scale=ddim_guidace_scale, num_steps=ddim_num_sampling_steps, eta=0.0, device="cuda"):  # eta=0 -> deterministic DDIM
    """Run the reverse diffusion loop and return the final latent batch.

    Args:
        unet: Model called as ``unet(latents, timesteps, text_embedding)``.
        noise_scheduler: Scheduler providing ``set_timesteps``, ``timesteps``
            and ``step(model_output=, timestep=, sample=, eta=)``.
        shape: Shape of the initial noise drawn when ``x_start`` is None.
        null_embedding: Unconditional embedding. When omitted and needed, it
            is loaded from ``null_embedding_dir`` and given a leading batch
            dimension.
        x_start: Optional starting latents; cloned, so the caller's tensor is
            never modified. Its batch size takes precedence over ``shape``.
        embedding: Optional conditioning embedding, expanded to the batch size.
        guidance_scale: Classifier-free guidance weight. Guidance is applied
            only when this is > 1 and ``embedding`` is given.
        num_steps: Number of sampling timesteps passed to the scheduler.
        eta: Forwarded to ``noise_scheduler.step``.
        device: Device for the initial noise, timesteps and loaded embedding.

    Returns:
        The ``prev_sample`` produced by the last scheduler step.
    """
    x = x_start.clone() if x_start is not None else torch.randn(shape, device=device)
    # Use the real batch size of the latents: `shape` may disagree with a
    # caller-supplied `x_start`.
    batch_size = x.shape[0]
    noise_scheduler.set_timesteps(num_steps)  # set DDIM timesteps

    if embedding is not None:
        embedding = embedding.expand(batch_size, -1, -1)
    use_cfg = embedding is not None and guidance_scale > 1
    cond_only = embedding is not None and guidance_scale <= 1

    # The unconditional embedding is only needed for guidance or for purely
    # unconditional sampling; skip the disk read on the conditional-only path.
    if not cond_only:
        if null_embedding is None:
            null_embedding = torch.load(null_embedding_dir, map_location=device, weights_only=True).unsqueeze(0)  # [1, 77, 1024]
        null_embedding = null_embedding.expand(batch_size, -1, -1)

    phi = 0.7  # Blend weight between the rescaled and the raw guided output

    for t in tqdm(noise_scheduler.timesteps, desc="Sampling timesteps: ", colour=tqdm_colors[-1]):
        t_batch = torch.full((batch_size,), t, device=device, dtype=torch.long)
        if use_cfg:
            # One forward pass over the stacked [unconditional, conditional] batch.
            out_uncond, out_cond = unet(
                torch.cat([x, x]),
                torch.cat([t_batch, t_batch]),
                torch.cat([null_embedding, embedding]),
            ).chunk(2)
            model_out = out_uncond + guidance_scale * (out_cond - out_uncond)

            # Guidance rescale: pull the guided output's std back towards the
            # conditional prediction's std, then blend with the raw output.
            # NOTE(review): both stds are taken over the whole batch, so
            # samples in a batch influence each other - confirm whether
            # per-sample statistics were intended.
            std_cond = out_cond.std()
            std_cfg = model_out.std()
            factor = std_cond / std_cfg.clamp(min=1e-8)
            model_out_rescaled = model_out * factor
            model_out = phi * model_out_rescaled + (1 - phi) * model_out
        elif cond_only:
            model_out = unet(x, t_batch, embedding)
        else:
            model_out = unet(x, t_batch, null_embedding)
        x = noise_scheduler.step(model_output=model_out, timestep=t, sample=x, eta=eta).prev_sample  # DDIM step
    return x
def gen_n_sampled_img(checkpoint_path=None, unet=None, saved_nth_step=unet_max_steps, n=16, scale=ddim_guidace_scale, num_steps=ddim_num_sampling_steps):
    """Sample ``n`` images for one fixed text embedding and save them as a single-row grid.

    Args:
        checkpoint_path: Checkpoint whose ``"ema"`` entry holds the UNet
            weights. Required when ``unet`` is not given.
        unet: Optional ready-to-use model; when given, no checkpoint is loaded.
        saved_nth_step: Only used to build the output file name.
        n: Number of images to sample.
        scale: Classifier-free guidance scale forwarded to ``ddim_sample``.
        num_steps: Number of DDIM sampling steps.

    Raises:
        ValueError: If neither ``unet`` nor ``checkpoint_path`` is provided.
    """
    seed_everything(seed=42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if unet is None:
        if checkpoint_path is None:
            raise ValueError("Either `unet` or `checkpoint_path` must be provided")
        unet = Unet().to(device)
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
        unet.load_state_dict(checkpoint["ema"], strict=True)

    text_emb = torch.load(f"{embedding_dir}/101654506_8eb26cfb60_3.pt", map_location=device, weights_only=True).unsqueeze(0)
    noise_scheduler = DDIMScheduler(
        num_train_timesteps=1000,
        beta_schedule=unet_beta_schedule,
        prediction_type=unet_pred_type,  # epsilon, v_prediction or sample
        rescale_betas_zero_snr=True,  # Match training
        timestep_spacing="trailing",  # Start from pure noise
        clip_sample=False,
        set_alpha_to_one=False,
    )

    # Always sample in eval mode, then hand the model back in the mode the
    # caller had it in (relevant when a model is passed in mid-training).
    was_training = unet.training
    unet.eval()
    try:
        latent_samples = ddim_sample(
            unet=unet,
            noise_scheduler=noise_scheduler,
            shape=(n, vae_latent_channels, vae_latent_dim, vae_latent_dim),
            embedding=text_emb,
            guidance_scale=scale,
            num_steps=num_steps,
            eta=0.0,
            device=device,
        )
    finally:
        unet.train(was_training)

    latent_samples = latent_samples * latent_std  # Reverse scale the latents

    vae = VAE().to(device)
    vae_path = f"{vae_checkpoint_dir}/{vae_weight}"
    vae_checkpoint = torch.load(vae_path, map_location=device, weights_only=True)
    vae.load_state_dict(vae_checkpoint["vae"])
    vae.eval()
    # Decoding is inference only; avoid building an autograd graph.
    with torch.no_grad():
        recon_images = vae.decode_latent_to_img(latent_samples)

    out_dir = "./backend/core/ddim_sampled_img"
    os.makedirs(out_dir, exist_ok=True)
    save_image(
        recon_images,
        f"{out_dir}/101654506_8eb26cfb60_step_{saved_nth_step}_num_steps_{num_steps}_scale_{scale}.png",
        nrow=recon_images.size(0),  # all images on one row
    )
    print("💾 Successfully saved sampled images")
def gen_val_sampled_img(checkpoint_path=None, unet=None, saved_nth_step=unet_max_steps, n=4, scale=ddim_guidace_scale, num_steps=ddim_num_sampling_steps):
    """Sample a validation grid: one row of ``n`` images per validation embedding.

    At most four embeddings from ``unet_val_embeddings_dir`` are used, taken
    in sorted file-name order so the row order is reproducible.

    Args:
        checkpoint_path: Checkpoint whose ``"ema"`` entry holds the UNet
            weights. Required when ``unet`` is not given.
        unet: Optional ready-to-use model; when given, no checkpoint is loaded.
        saved_nth_step: Only used to build the output file name.
        n: Number of images sampled per embedding (grid row width).
        scale: Classifier-free guidance scale forwarded to ``ddim_sample``.
        num_steps: Number of DDIM sampling steps.

    Raises:
        ValueError: If neither ``unet`` nor ``checkpoint_path`` is provided.
        FileNotFoundError: If ``unet_val_embeddings_dir`` contains no files.
    """
    max_rows = 4  # number of validation prompts shown in the grid

    seed_everything(seed=42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if unet is None:
        if checkpoint_path is None:
            raise ValueError("Either `unet` or `checkpoint_path` must be provided")
        unet = Unet().to(device)
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
        unet.load_state_dict(checkpoint["ema"], strict=True)

    # os.listdir order is arbitrary; sort it, and load only the files used.
    embedding_files = sorted(os.listdir(unet_val_embeddings_dir))[:max_rows]
    if not embedding_files:
        raise FileNotFoundError(f"No validation embeddings found in {unet_val_embeddings_dir}")
    val_embeddings = [
        torch.load(f"{unet_val_embeddings_dir}/{file_name}", map_location=device, weights_only=True).unsqueeze(0).to(device)
        for file_name in embedding_files
    ]

    noise_scheduler = DDIMScheduler(
        num_train_timesteps=1000,
        beta_schedule=unet_beta_schedule,
        prediction_type=unet_pred_type,  # epsilon, v_prediction or sample
        rescale_betas_zero_snr=True,  # Match training
        timestep_spacing="trailing",  # Start from pure noise
        clip_sample=False,
        set_alpha_to_one=False,
    )

    vae = VAE().to(device)
    vae_path = f"{vae_checkpoint_dir}/{vae_weight}"
    vae_checkpoint = torch.load(vae_path, map_location=device, weights_only=True)
    vae.load_state_dict(vae_checkpoint["vae"])
    vae.eval()

    # Always sample in eval mode, then hand the model back in the mode the
    # caller had it in (relevant when a model is passed in mid-training).
    was_training = unet.training
    unet.eval()
    img_rows = []
    try:
        for val_embedding in val_embeddings:
            latent_samples = ddim_sample(
                unet=unet,
                noise_scheduler=noise_scheduler,
                shape=(n, vae_latent_channels, vae_latent_dim, vae_latent_dim),
                x_start=None,
                embedding=val_embedding,
                guidance_scale=scale,
                num_steps=num_steps,
                eta=0.0,
                device=device,
            )
            latent_samples = latent_samples * latent_std  # Reverse scale the latents
            # Decoding is inference only; avoid building an autograd graph.
            with torch.no_grad():
                img_rows.append(vae.decode_latent_to_img(latent_samples))
    finally:
        unet.train(was_training)

    os.makedirs(ddim_img_dir, exist_ok=True)
    save_image(
        torch.cat(img_rows, dim=0),
        f"{ddim_img_dir}/val_step_{saved_nth_step}_num_steps_{num_steps}_scale_{scale}.png",
        nrow=n,  # one grid row per validation embedding
    )
    print("💾 Successfully saved validation grid images")
if __name__ == "__main__":
    # Script entry point: render both image grids from one EMA checkpoint.
    checkpoint_path = "./backend/core/checkpoints/ldm/ema_loss_0.220238.pth"
    gen_n_sampled_img(checkpoint_path=checkpoint_path, n=16, scale=8, num_steps=100)
    gen_val_sampled_img(checkpoint_path=checkpoint_path, n=4, scale=8, num_steps=100)