import os
import sys

# Put this file's directory first on sys.path before the imports below run.
sys.path.insert(0, os.path.dirname(__file__))

from diffusers import DDPMScheduler
import torch
from torch import nn, optim
from torch.optim.lr_scheduler import (
    ReduceLROnPlateau,
    LambdaLR,
    SequentialLR,
    CosineAnnealingLR,
    LinearLR,
    ConstantLR,
)
import torch.nn.functional as F
from dataloader import latent_embedding_dataloader
from config import *
from tqdm import tqdm
# from noise import NoiseScheduler
from seed import seed_everything
from unet import Unet
from ema import EMA
from sample_ddim import gen_n_sampled_img, gen_val_sampled_img
import numpy as np
import random

# Directory the DDPM noise-scheduler config is written to on every checkpoint
# save and read back from when resuming.
_NOISE_SCHEDULER_DIR = "./checkpoints/noise_scheduler"


def _load_checkpoint(checkpoint_path, device, unet, optimizer, scaler, ema):
    """Restore training state from ``checkpoint_path`` in place.

    Restores RNG state (when stored), UNet weights, optimizer, grad scaler and
    EMA shadow weights, and reloads the noise scheduler from
    ``_NOISE_SCHEDULER_DIR``.

    Returns:
        ``(global_step, noise_scheduler)`` -- the step stored in the
        checkpoint and the reloaded scheduler.
    """
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)

    if "rng_state" in checkpoint:
        torch.set_rng_state(checkpoint["rng_state"].cpu())
    if torch.cuda.is_available() and "cuda_rng_state" in checkpoint:
        torch.cuda.set_rng_state(checkpoint["cuda_rng_state"].cpu())

    unet.load_state_dict(checkpoint["unet"], strict=True)
    # The step stored in the checkpoint is authoritative; the file name is not
    # parsed, so renamed checkpoint files still resume correctly.
    global_step = checkpoint["step"]
    optimizer.load_state_dict(checkpoint["optimizer"])
    scaler.load_state_dict(checkpoint["scaler"])
    noise_scheduler = DDPMScheduler.from_pretrained(_NOISE_SCHEDULER_DIR)
    ema.shadow = {k: v.to(device) for k, v in checkpoint["ema"].items()}
    ema.num_updates = checkpoint["ema_num_updates"]
    # for _ in range(global_step): scheduler.step()

    print("Loaded checkpoint successfully")
    return global_step, noise_scheduler


def _save_checkpoint(unet, ema, optimizer, scaler, noise_scheduler, global_step, avg_loss):
    """Write a resumable checkpoint and the noise-scheduler config to disk.

    The file name embeds ``global_step`` and ``avg_loss`` (6 decimals).
    """
    state = {
        "step": global_step,
        "unet": unet.state_dict(),
        "ema": ema.shadow,
        "ema_num_updates": ema.num_updates,
        "optimizer": optimizer.state_dict(),
        "scaler": scaler.state_dict(),
        "rng_state": torch.get_rng_state(),
    }
    # torch.cuda.get_rng_state() fails on machines without CUDA; the loader
    # already treats "cuda_rng_state" as optional.
    if torch.cuda.is_available():
        state["cuda_rng_state"] = torch.cuda.get_rng_state()

    torch.save(state, f"{unet_checkpoint_dir}/ema_step_{global_step}_{avg_loss:.6f}.pth")
    noise_scheduler.save_pretrained(_NOISE_SCHEDULER_DIR)


def _run_validation(unet, ema, global_step, use_ema):
    """Generate sample images at ``global_step``, optionally with EMA weights.

    When ``use_ema`` is true the EMA shadow weights are applied for sampling
    and always restored afterwards, even if sampling raises. The model is put
    back into train mode on normal completion.
    """
    unet.eval()
    applied_ema = False
    # restore() stays inside no_grad, exactly as apply_shadow() does.
    with torch.no_grad():
        try:
            if use_ema:
                ema.apply_shadow(unet)
                applied_ema = True
            gen_n_sampled_img(
                unet=unet,
                saved_nth_step=global_step,
                n=16,
                scale=ddim_guidace_scale,
                num_steps=ddim_num_sampling_steps,
            )
            gen_val_sampled_img(
                unet=unet,
                saved_nth_step=global_step,
                n=4,
                scale=ddim_guidace_scale,
                num_steps=ddim_num_sampling_steps,
            )
        finally:
            if applied_ema:
                ema.restore(unet)
    unet.train()


def train_unet_ddpm_simple(use_checkpoint=False, checkpoint_path=None):
    """Train the latent UNet with a DDPM noise schedule, AMP and EMA.

    Each step samples a random timestep per latent, adds noise with the
    scheduler, and regresses the UNet output onto either the velocity target
    (when ``unet_pred_type == "v_prediction"``) or the noise, using MSE.
    For roughly 10% of samples the conditioning embedding is replaced by the
    stored null embedding (presumably for classifier-free guidance -- confirm
    against the sampler).

    Periodically (in multiples of the hard-coded ``epoch_size``) the function
    logs the running average loss, writes a checkpoint, and generates sample
    images. The running average is reset each time a checkpoint is written.

    Args:
        use_checkpoint: If true, try to resume from ``checkpoint_path``.
        checkpoint_path: Path to a checkpoint written by this function. If it
            is requested but missing, a warning is printed and training starts
            from scratch.

    Raises:
        RuntimeError: If the data loader yields no batches, which would
            otherwise spin forever.
    """
    os.makedirs(unet_checkpoint_dir, exist_ok=True)
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print("DEVICE", device)

    train_loader = latent_embedding_dataloader()
    unet = Unet().to(device)
    optimizer = torch.optim.AdamW(unet.parameters(), lr=unet_optim_lr, weight_decay=0)
    noise_scheduler = DDPMScheduler(
        num_train_timesteps=1000,
        beta_schedule=unet_beta_schedule,
        prediction_type=unet_pred_type,
        rescale_betas_zero_snr=True,
    )

    # All periodic actions are expressed in multiples of a fixed epoch length.
    epoch_size = 1265
    max_steps = epoch_size * 4000
    log_every = epoch_size * 2
    save_every = epoch_size * 8
    val_every = epoch_size * 4
    ema_warmup_steps = epoch_size * 20

    null_embedding = torch.load(
        null_embedding_dir, map_location=device, weights_only=True
    ).unsqueeze(0)
    ema = EMA(unet, decay=0.999, warmup_steps=ema_warmup_steps)
    scaler = torch.amp.GradScaler("cuda", enabled=use_cuda)
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_steps, eta_min=1e-6)

    global_step = 0
    total_mse = 0.0
    total_samples = 0

    if use_checkpoint and checkpoint_path:
        if os.path.exists(checkpoint_path):
            global_step, noise_scheduler = _load_checkpoint(
                checkpoint_path, device, unet, optimizer, scaler, ema
            )
        else:
            print(f"Checkpoint not found at {checkpoint_path!r}; training from scratch")

    unet.train()
    while global_step < max_steps:
        step_at_epoch_start = global_step

        for latent, embedding in tqdm(train_loader, desc="Train: ", colour=tqdm_colors[5]):
            if global_step >= max_steps:
                break

            latent, embedding = latent.to(device), embedding.to(device)
            assert null_embedding.shape[1:] == embedding.shape[1:], f"{null_embedding.shape}"
            assert embedding.dim() == 3, f"Expected 3D embedding (B, seq, dim), got {embedding.shape}"

            batch_size = latent.size(0)
            t = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (batch_size,), device=device
            ).long()
            noise = torch.randn_like(latent)
            noised_latent = noise_scheduler.add_noise(latent, noise, t)

            # Swap in the null embedding for ~10% of the samples in the batch.
            mask = torch.rand(batch_size, 1, 1, device=device) < 0.1
            embedding = torch.where(mask, null_embedding.expand(batch_size, -1, -1), embedding)

            optimizer.zero_grad(set_to_none=True)
            with torch.amp.autocast(device_type="cuda", enabled=use_cuda):
                if unet_pred_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latent, noise, t)
                else:
                    target = noise
                pred = unet(noised_latent, t, embedding)
                loss = F.mse_loss(pred, target, reduction="mean")

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)

            # A drop in the loss scale after update() means the optimizer step
            # was skipped (inf/NaN grads); only track weights in the EMA when
            # the step was actually applied.
            old_scale = scaler.get_scale()
            scaler.step(optimizer)
            scaler.update()
            if scaler.get_scale() >= old_scale:
                ema.update(unet)
            # scheduler.step()

            global_step += 1
            total_mse += loss.item() * batch_size
            total_samples += batch_size

            # global_step was just incremented, so it can never equal the step
            # the run was resumed from; no extra "just loaded" guard is needed.
            if global_step % log_every == 0:
                # print(f" Step {global_step} | Loss: {total_mse / total_samples:.6f} | LR: {scheduler.get_last_lr()[0]:.2e}")
                if total_samples > 0:
                    print(f" Step {global_step} | Loss: {total_mse / total_samples:.6f}")

            if global_step % save_every == 0:
                avg_loss = total_mse / total_samples if total_samples else float("nan")
                _save_checkpoint(
                    unet, ema, optimizer, scaler, noise_scheduler, global_step, avg_loss
                )
                total_mse = 0.0
                total_samples = 0

            if global_step % val_every == 0:
                _run_validation(
                    unet, ema, global_step, use_ema=global_step >= ema_warmup_steps
                )

        # An empty loader would make the outer while-loop spin forever.
        if global_step == step_at_epoch_start and global_step < max_steps:
            raise RuntimeError("train_loader yielded no batches; cannot make progress")

    print("✅ Training finished")


if __name__ == "__main__":
    seed_everything(42)
    # Built with os.path.join: the previous literal used backslashes, which
    # are invalid escape sequences (\c, \l, \e) and only work on Windows.
    train_unet_ddpm_simple(
        use_checkpoint=True,
        checkpoint_path=os.path.join(
            ".", "backend", "core", "checkpoints", "ldm", "ema_step_1163800_0.262618.pth"
        ),
    )