File size: 7,226 Bytes
4aabce3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cf8a031
 
4aabce3
 
 
 
 
 
 
 
cf8a031
4aabce3
 
cf8a031
 
 
4aabce3
a625e96
4aabce3
cf8a031
 
4aabce3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import sys, os
sys.path.insert(0, os.path.dirname(__file__))
from diffusers import DDIMScheduler, DDPMPipeline, DDPMScheduler
import torch
from torchvision.utils import save_image
import os
from config import *
from tqdm import tqdm
from unet import Unet
from text_embeddings import get_embedding_model, get_text_embedding
from vae import *
from seed import seed_everything

@torch.no_grad()
def ddim_sample(unet, noise_scheduler, shape, null_embedding=None, x_start=None, embedding=None, guidance_scale=ddim_guidace_scale, num_steps=ddim_num_sampling_steps, eta=0.0, device="cuda", guidance_rescale=0.7): # eta=0 -> deterministic DDIM
    """Run the reverse diffusion loop and return the final sample.

    Three prediction modes are selected from the arguments:
      * `embedding` given and `guidance_scale > 1`: classifier-free guidance
        (one batched forward pass over [unconditional, conditional]) followed
        by guidance rescaling.
      * `embedding` given and `guidance_scale <= 1`: conditional prediction only.
      * `embedding` is None: prediction conditioned on the null embedding.

    Args:
        unet: callable invoked as `unet(x, t_batch, embedding)`.
        noise_scheduler: scheduler exposing `set_timesteps`, `timesteps` and
            `step(model_output=, timestep=, sample=, eta=)` (a diffusers
            DDIMScheduler in this file's callers).
        shape: shape of the initial noise; only used when `x_start` is None.
        null_embedding: unconditional embedding with a leading batch dim of 1.
            When None it is loaded from `null_embedding_dir`, but only if the
            selected mode actually needs it.
        x_start: optional initial sample; it is cloned, never modified.
        embedding: optional conditioning embedding with a leading batch dim of 1
            (it is expanded to the batch size).
        guidance_scale: classifier-free guidance weight.
        num_steps: number of sampling timesteps passed to the scheduler.
        eta: forwarded to `noise_scheduler.step`.
        device: device for the initial noise, the timestep batch and the
            loaded null embedding.
        guidance_rescale: blend factor between the std-rescaled guided output
            and the raw guided output; 0 disables rescaling. The default 0.7
            is the value that used to be hard-coded.

    Returns:
        The sample produced after the last scheduler step.
    """
    x = x_start.clone() if x_start is not None else torch.randn(shape, device=device)
    # Derive the batch size from the actual sample: `shape[0]` is wrong whenever
    # `x_start` is supplied with a different batch size than `shape`.
    batch_size = x.shape[0]
    noise_scheduler.set_timesteps(num_steps) # set DDIM timesteps

    use_cfg = embedding is not None and guidance_scale > 1
    if embedding is not None:
        embedding = embedding.expand(batch_size, -1, -1)
    if use_cfg or embedding is None:
        # The null embedding is only consumed by the CFG and unconditional
        # paths, so avoid the disk read on the conditional-only path.
        if null_embedding is None:
            null_embedding = torch.load(null_embedding_dir, map_location=device, weights_only=True).unsqueeze(0) # original author notes shape [1, 77, 1024]
        null_embedding = null_embedding.expand(batch_size, -1, -1)

    for t in tqdm(noise_scheduler.timesteps, desc="Sampling timesteps: ", colour=tqdm_colors[-1]):
        t_batch = torch.full((batch_size,), t, device=device, dtype=torch.long)
        if use_cfg:
            out_uncond, out_cond = unet(torch.cat([x, x]), torch.cat([t_batch, t_batch]), torch.cat([null_embedding, embedding])).chunk(2)
            model_out = out_uncond + guidance_scale * (out_cond - out_uncond)
            if guidance_rescale > 0:
                # Guidance rescale: pull the std of the guided output back to
                # the std of the conditional output. Statistics are taken per
                # sample (all non-batch dims) so that one sample's result does
                # not depend on the other samples sharing its batch.
                reduce_dims = tuple(range(1, model_out.dim()))
                std_cond = out_cond.std(dim=reduce_dims, keepdim=True)
                std_cfg = model_out.std(dim=reduce_dims, keepdim=True)
                model_out_rescaled = model_out * (std_cond / std_cfg.clamp(min=1e-8))
                # Final blend between rescaled and raw guided predictions.
                model_out = guidance_rescale * model_out_rescaled + (1 - guidance_rescale) * model_out
        elif embedding is not None:
            model_out = unet(x, t_batch, embedding)
        else:
            model_out = unet(x, t_batch, null_embedding)
        x = noise_scheduler.step(model_output=model_out, timestep=t, sample=x, eta=eta).prev_sample # DDIM step
    return x

def gen_n_sampled_img(checkpoint_path=None, unet=None, saved_nth_step=unet_max_steps, n=16, scale=ddim_guidace_scale, num_steps=ddim_num_sampling_steps):
    """Sample `n` latents for one fixed text embedding, decode them and save a PNG row.

    Args:
        checkpoint_path: checkpoint file whose "ema" entry holds the UNet
            weights; only read when `unet` is None.
        unet: an already constructed UNet. It is switched to eval mode and is
            not switched back.
        saved_nth_step: step number embedded in the output file name.
        n: number of images to sample.
        scale: classifier-free guidance scale forwarded to `ddim_sample`.
        num_steps: number of DDIM sampling steps.

    Raises:
        ValueError: if neither `unet` nor `checkpoint_path` is provided.

    Side effects:
        Re-seeds via `seed_everything(seed=42)`, creates
        "./backend/core/ddim_sampled_img" if needed and writes one image there.
    """
    if unet is None and checkpoint_path is None:
        # Fail early with a clear message instead of an obscure error inside torch.load(None).
        raise ValueError("gen_n_sampled_img requires either `unet` or `checkpoint_path`")
    seed_everything(seed=42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if unet is None:
        unet = Unet().to(device)
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
        unet.load_state_dict(checkpoint["ema"], strict=True)
    unet.eval()
    text_emb = torch.load(f"{embedding_dir}/101654506_8eb26cfb60_3.pt", map_location=device, weights_only=True).unsqueeze(0)
    noise_scheduler = DDIMScheduler(
        num_train_timesteps=1000,
        beta_schedule=unet_beta_schedule,
        prediction_type=unet_pred_type, # epsilon, v_prediction or sample
        rescale_betas_zero_snr=True,      # Match training
        timestep_spacing="trailing",       # Start from pure noise
        clip_sample=False,
        set_alpha_to_one=False
    )
    latent_samples = ddim_sample(
        unet=unet,
        noise_scheduler=noise_scheduler,
        shape=(n, vae_latent_channels, vae_latent_dim, vae_latent_dim),
        embedding=text_emb,
        guidance_scale=scale,
        num_steps=num_steps,
        eta=0.0,
        device=device
    )
    latent_samples = latent_samples * latent_std # Reverse scale the latents
    # The VAE is built after sampling on purpose: constructing it earlier could
    # consume random numbers and change the sampled noise for the fixed seed.
    vae = VAE().to(device)
    path = f"{vae_checkpoint_dir}/{vae_weight}"
    checkpoint = torch.load(path, map_location=device, weights_only=True)
    vae.load_state_dict(checkpoint["vae"])
    vae.eval()
    # Decoding is inference only; skip building an autograd graph for it.
    with torch.no_grad():
        recon_images = vae.decode_latent_to_img(latent_samples)
    os.makedirs("./backend/core/ddim_sampled_img", exist_ok=True)
    save_image(recon_images, f"./backend/core/ddim_sampled_img/101654506_8eb26cfb60_step_{saved_nth_step}_num_steps_{num_steps}_scale_{scale}.png",nrow=recon_images.size(0))
    print("💾 Successfully saved sampled images")

def gen_val_sampled_img(checkpoint_path=None, unet=None, saved_nth_step=unet_max_steps, n=4, scale=ddim_guidace_scale, num_steps=ddim_num_sampling_steps, num_prompts=4):
    """Sample a validation grid: one row of `n` images per validation embedding.

    Args:
        checkpoint_path: checkpoint file whose "ema" entry holds the UNet
            weights; only read when `unet` is None.
        unet: an already constructed UNet. It is switched to eval mode and is
            not switched back.
        saved_nth_step: step number embedded in the output file name.
        n: number of images per row (samples per embedding).
        scale: classifier-free guidance scale forwarded to `ddim_sample`.
        num_steps: number of DDIM sampling steps.
        num_prompts: maximum number of validation embeddings (grid rows) to
            render. Defaults to 4, the count that used to be hard-coded.

    Raises:
        ValueError: if neither `unet` nor `checkpoint_path` is provided.
        RuntimeError: if `unet_val_embeddings_dir` contains no files.

    Side effects:
        Re-seeds via `seed_everything(seed=42)`, creates `ddim_img_dir` if
        needed and writes one grid image there.
    """
    if unet is None and checkpoint_path is None:
        # Fail early with a clear message instead of an obscure error inside torch.load(None).
        raise ValueError("gen_val_sampled_img requires either `unet` or `checkpoint_path`")
    seed_everything(seed=42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if unet is None:
        unet = Unet().to(device)
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
        unet.load_state_dict(checkpoint["ema"], strict=True)
    unet.eval()
    # os.listdir returns entries in arbitrary order; sort so the chosen prompts
    # and their row order are reproducible, and load only the files that are used.
    embedding_files = sorted(os.listdir(unet_val_embeddings_dir))[:num_prompts]
    if not embedding_files:
        raise RuntimeError(f"No validation embeddings found in {unet_val_embeddings_dir}")
    val_embeddings = torch.cat([torch.load(f"{unet_val_embeddings_dir}/{name}", map_location=device, weights_only=True).unsqueeze(0) for name in embedding_files]).to(device)
    noise_scheduler = DDIMScheduler(
        num_train_timesteps=1000,
        beta_schedule=unet_beta_schedule,
        prediction_type=unet_pred_type, # epsilon, v_prediction or sample
        rescale_betas_zero_snr=True,      # Match training
        timestep_spacing="trailing",       # Start from pure noise
        clip_sample=False,
        set_alpha_to_one=False
    )
    vae = VAE().to(device)
    path = f"{vae_checkpoint_dir}/{vae_weight}"
    checkpoint = torch.load(path, map_location=device, weights_only=True)
    vae.load_state_dict(checkpoint["vae"])
    vae.eval()
    img_rows = []
    # Iterate over the embeddings actually available instead of a fixed range(4),
    # which produced an empty slice (and a failing expand) with fewer than 4 files.
    for i in range(val_embeddings.shape[0]):
        latent_samples = ddim_sample(
            unet=unet,
            noise_scheduler=noise_scheduler,
            shape=(n, vae_latent_channels, vae_latent_dim, vae_latent_dim),
            x_start=None,
            embedding=val_embeddings[i:i+1],
            guidance_scale=scale,
            num_steps=num_steps,
            eta=0.0,
            device=device
        )
        latent_samples = latent_samples * latent_std # Reverse scale the latents
        # Decoding is inference only; skip building an autograd graph for it.
        with torch.no_grad():
            recon_images = vae.decode_latent_to_img(latent_samples)
        img_rows.append(recon_images)
    os.makedirs(ddim_img_dir, exist_ok=True)
    save_image(torch.cat(img_rows, dim=0), f"{ddim_img_dir}/val_step_{saved_nth_step}_num_steps_{num_steps}_scale_{scale}.png",nrow=n)
    print("💾 Successfully saved validation grid images")

if __name__ == "__main__":
    # Script entry point: render both the single-prompt row and the validation
    # grid from the same EMA checkpoint with identical sampling settings.
    ema_checkpoint = "./backend/core/checkpoints/ldm/ema_loss_0.220238.pth"
    sampling_kwargs = dict(checkpoint_path=ema_checkpoint, scale=8, num_steps=100)
    gen_n_sampled_img(n=16, **sampling_kwargs)
    gen_val_sampled_img(n=4, **sampling_kwargs)