Spaces:
Sleeping
Sleeping
| import sys, os | |
| sys.path.insert(0, os.path.dirname(__file__)) | |
| from diffusers import DDPMScheduler | |
| import torch | |
| from torch import nn, optim | |
| from torch.optim.lr_scheduler import ReduceLROnPlateau, LambdaLR, SequentialLR, CosineAnnealingLR, LinearLR, ConstantLR | |
| import torch.nn.functional as F | |
| from dataloader import latent_embedding_dataloader | |
| from config import * | |
| from tqdm import tqdm | |
| # from noise import NoiseScheduler | |
| from seed import seed_everything | |
| from unet import Unet | |
| from ema import EMA | |
| from sample_ddim import gen_n_sampled_img, gen_val_sampled_img | |
| import numpy as np | |
| import random | |
def train_unet_ddpm_simple(use_checkpoint=False, checkpoint_path=None):
    """Train the U-Net on (latent, embedding) batches with a DDPM noise schedule.

    The loop runs for a fixed number of optimizer steps, with mixed precision
    enabled when CUDA is available, gradient-norm clipping at 1.0, and an EMA
    copy of the weights. It periodically prints the running loss, writes a
    checkpoint into ``unet_checkpoint_dir`` and generates sample images.

    Args:
        use_checkpoint: When true (and ``checkpoint_path`` exists), resume the
            model, optimizer, grad scaler, EMA and RNG state from that file.
        checkpoint_path: Path of a checkpoint previously written by this
            function. Ignored if falsy or if the file does not exist.

    Raises:
        RuntimeError: If the training dataloader yields no batches, which
            would otherwise make the outer loop spin forever.
    """
    os.makedirs(unet_checkpoint_dir, exist_ok=True)
    cuda_available = torch.cuda.is_available()
    device = torch.device("cuda" if cuda_available else "cpu")
    print("DEVICE", device)

    train_loader = latent_embedding_dataloader()
    unet = Unet().to(device)
    optimizer = torch.optim.AdamW(unet.parameters(), lr=unet_optim_lr, weight_decay=0)
    noise_scheduler = DDPMScheduler(
        num_train_timesteps=1000,
        beta_schedule=unet_beta_schedule,
        prediction_type=unet_pred_type,
        rescale_betas_zero_snr=True,
    )

    # All periodic actions are expressed in multiples of this hard-coded step
    # count (presumably the number of batches per epoch -- TODO confirm).
    epoch_size = 1265
    max_steps = epoch_size * 4000
    log_every = epoch_size * 2
    save_every = epoch_size * 8
    val_every = epoch_size * 4
    ema_warmup_steps = epoch_size * 20

    # Embedding substituted for the real conditioning on a random subset of
    # samples; a leading batch axis is added so it can be expanded per batch.
    null_embedding = torch.load(
        null_embedding_dir, map_location=device, weights_only=True
    ).unsqueeze(0)
    ema = EMA(unet, decay=0.999, warmup_steps=ema_warmup_steps)
    scaler = torch.amp.GradScaler("cuda", enabled=cuda_available)

    global_step = loaded_step = 0
    total_mse = total_samples = 0

    if use_checkpoint and checkpoint_path and os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
        if "rng_state" in checkpoint:
            torch.set_rng_state(checkpoint["rng_state"].cpu())
        if cuda_available and "cuda_rng_state" in checkpoint:
            torch.cuda.set_rng_state(checkpoint["cuda_rng_state"].cpu())
        unet.load_state_dict(checkpoint["unet"], strict=True)
        # The step stored inside the checkpoint is authoritative; parsing it
        # out of the filename breaks as soon as the file is renamed.
        global_step = loaded_step = checkpoint["step"]
        optimizer.load_state_dict(checkpoint["optimizer"])
        scaler.load_state_dict(checkpoint["scaler"])
        noise_scheduler = DDPMScheduler.from_pretrained("./checkpoints/noise_scheduler")
        ema.shadow = {k: v.to(device) for k, v in checkpoint["ema"].items()}
        ema.num_updates = checkpoint["ema_num_updates"]
        # Report success only once every component has actually been restored.
        print("Loaded checkpoint successfully")

    unet.train()
    while global_step < max_steps:
        step_at_epoch_start = global_step
        for latent, embedding in tqdm(train_loader, desc="Train: ", colour=tqdm_colors[5]):
            if global_step >= max_steps:
                break
            latent, embedding = latent.to(device), embedding.to(device)
            assert null_embedding.shape[1:] == embedding.shape[1:], f"{null_embedding.shape}"
            assert embedding.dim() == 3, f"Expected 3D embedding (B, seq, dim), got {embedding.shape}"

            batch_size = latent.size(0)
            t = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (batch_size,), device=device
            ).long()
            noise = torch.randn_like(latent)
            noised_latent = noise_scheduler.add_noise(latent, noise, t)

            # Replace the conditioning with the null embedding for ~10% of the
            # samples (presumably to train the unconditional branch used by
            # guided sampling -- TODO confirm against the sampler).
            mask = torch.rand(batch_size, 1, 1, device=device) < 0.1
            embedding = torch.where(mask, null_embedding.expand(batch_size, -1, -1), embedding)

            optimizer.zero_grad(set_to_none=True)
            with torch.amp.autocast(device_type="cuda", enabled=cuda_available):
                if unet_pred_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latent, noise, t)
                else:
                    target = noise
                pred = unet(noised_latent, t, embedding)
                loss = F.mse_loss(pred, target, reduction="mean")

            scaler.scale(loss).backward()
            # Unscale before clipping so the threshold applies to true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)

            old_scale = scaler.get_scale()
            scaler.step(optimizer)
            scaler.update()
            # A decreased scale means GradScaler saw inf/NaN gradients and
            # skipped the optimizer step; do not fold that step into the EMA.
            if scaler.get_scale() >= old_scale:
                ema.update(unet)

            global_step += 1
            total_mse += loss.item() * batch_size
            total_samples += batch_size

            if global_step % log_every == 0 and global_step != loaded_step:
                if total_samples > 0:
                    print(f" Step {global_step} | Loss: {total_mse / total_samples:.6f}")

            if global_step % save_every == 0 and global_step != loaded_step:
                state = {
                    "step": global_step,
                    "unet": unet.state_dict(),
                    "ema": ema.shadow,
                    "ema_num_updates": ema.num_updates,
                    "optimizer": optimizer.state_dict(),
                    "scaler": scaler.state_dict(),
                    "rng_state": torch.get_rng_state(),
                }
                # Querying the CUDA RNG on a CPU-only host raises, so only
                # record it when CUDA exists (the loader already treats the
                # key as optional).
                if cuda_available:
                    state["cuda_rng_state"] = torch.cuda.get_rng_state()
                torch.save(
                    state,
                    f"{unet_checkpoint_dir}/ema_step_{global_step}_{total_mse / total_samples:.6f}.pth",
                )
                noise_scheduler.save_pretrained(f"./checkpoints/noise_scheduler")
                # The running loss is averaged over the window since the last save.
                total_samples = total_mse = 0

            if global_step % val_every == 0 and global_step != loaded_step:
                unet.eval()
                applied_ema = False
                with torch.no_grad():
                    try:
                        # Sample with EMA weights once the EMA warm-up is over.
                        if global_step >= ema_warmup_steps:
                            ema.apply_shadow(unet)
                            applied_ema = True
                        gen_n_sampled_img(
                            unet=unet,
                            saved_nth_step=global_step,
                            n=16,
                            scale=ddim_guidace_scale,
                            num_steps=ddim_num_sampling_steps,
                        )
                        gen_val_sampled_img(
                            unet=unet,
                            saved_nth_step=global_step,
                            n=4,
                            scale=ddim_guidace_scale,
                            num_steps=ddim_num_sampling_steps,
                        )
                    finally:
                        # Always put the raw training weights back, even if
                        # sampling failed.
                        if applied_ema:
                            ema.restore(unet)
                        unet.train()

        # Every non-empty pass advances global_step at least once; no progress
        # means the loader is empty and the while-loop would never terminate.
        if global_step == step_at_epoch_start:
            raise RuntimeError("train_loader yielded no batches; cannot make training progress")

    print("✅ Training finished")
if __name__ == "__main__":
    # Fix the seed before any model or data construction.
    seed_everything(42)
    # Forward slashes: the previous literal used backslashes in a non-raw
    # string ("\c", "\l", "\e" are invalid escape sequences) and only resolved
    # on Windows. Forward slashes resolve on Windows and POSIX alike.
    train_unet_ddpm_simple(
        use_checkpoint=True,
        checkpoint_path="./backend/core/checkpoints/ldm/ema_step_1163800_0.262618.pth",
    )