# flickr8k-backend / core / train_unet.py
# Author: Rohan3
# Commit a625e96: "Updated: VAE, UNet, config, text embeddings, model and main"
import sys, os
sys.path.insert(0, os.path.dirname(__file__))
from diffusers import DDPMScheduler
import torch
from torch import nn, optim
from torch.optim.lr_scheduler import ReduceLROnPlateau, LambdaLR, SequentialLR, CosineAnnealingLR, LinearLR, ConstantLR
import torch.nn.functional as F
from dataloader import latent_embedding_dataloader
from config import *
from tqdm import tqdm
# from noise import NoiseScheduler
from seed import seed_everything
from unet import Unet
from ema import EMA
from sample_ddim import gen_n_sampled_img, gen_val_sampled_img
import numpy as np
import random
# Directory the DDPMScheduler config is saved to at every checkpoint and
# reloaded from on resume (relative to the current working directory).
_NOISE_SCHEDULER_DIR = "./checkpoints/noise_scheduler"


def _load_checkpoint(checkpoint_path, device, unet, optimizer, scaler, ema):
    """Restore UNet/optimizer/scaler/EMA/RNG state in place and return the saved step."""
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
    if "rng_state" in checkpoint:
        torch.set_rng_state(checkpoint["rng_state"].cpu())
    if torch.cuda.is_available() and "cuda_rng_state" in checkpoint:
        torch.cuda.set_rng_state(checkpoint["cuda_rng_state"].cpu())
    unet.load_state_dict(checkpoint["unet"], strict=True)
    optimizer.load_state_dict(checkpoint["optimizer"])
    scaler.load_state_dict(checkpoint["scaler"])
    ema.shadow = {k: v.to(device) for k, v in checkpoint["ema"].items()}
    ema.num_updates = checkpoint["ema_num_updates"]
    # The step is read from the checkpoint itself rather than parsed out of the
    # file name, which broke for names without "step_" in them.
    return checkpoint["step"]


def _save_checkpoint(path, step, unet, ema, optimizer, scaler):
    """Write a resumable training checkpoint (model, EMA, optimizer, scaler, RNG) to *path*."""
    state = {
        "step": step,
        "unet": unet.state_dict(),
        "ema": ema.shadow,
        "ema_num_updates": ema.num_updates,
        "optimizer": optimizer.state_dict(),
        "scaler": scaler.state_dict(),
        "rng_state": torch.get_rng_state(),
    }
    # torch.cuda.get_rng_state() initialises CUDA and raises on hosts without
    # it; _load_checkpoint already treats this key as optional.
    if torch.cuda.is_available():
        state["cuda_rng_state"] = torch.cuda.get_rng_state()
    torch.save(state, path)


def _sample_images(unet, ema, step, use_ema):
    """Render sample images via gen_n_sampled_img / gen_val_sampled_img, optionally with EMA weights."""
    unet.eval()
    applied_ema = False
    try:
        with torch.no_grad():
            if use_ema:
                ema.apply_shadow(unet)
                applied_ema = True
            gen_n_sampled_img(unet=unet, saved_nth_step=step, n=16, scale=ddim_guidace_scale, num_steps=ddim_num_sampling_steps)
            gen_val_sampled_img(unet=unet, saved_nth_step=step, n=4, scale=ddim_guidace_scale, num_steps=ddim_num_sampling_steps)
    finally:
        # Always put the live (non-EMA) weights back and return to train mode.
        if applied_ema:
            ema.restore(unet)
        unet.train()


def train_unet_ddpm_simple(use_checkpoint=False, checkpoint_path=None):
    """Train the conditional latent UNet with a DDPM noise schedule.

    Iterates ``latent_embedding_dataloader()`` until ``max_steps`` optimizer
    steps have been taken. Each step noises the latents at random timesteps,
    replaces ~10% of the conditioning embeddings with the null embedding, and
    regresses the UNet output onto the noise (or onto the velocity when
    ``unet_pred_type == "v_prediction"``). Training uses AMP on CUDA, clips
    gradients to norm 1.0, and keeps an EMA copy of the weights. The running
    loss is logged periodically, a resumable checkpoint is written to
    ``unet_checkpoint_dir``, and sample images are rendered (with EMA weights
    once the EMA warmup has elapsed).

    Args:
        use_checkpoint: If True, resume from ``checkpoint_path``. A missing
            path prints a warning and training starts from scratch.
        checkpoint_path: ``.pth`` file previously written by this function.
    """
    os.makedirs(unet_checkpoint_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_cuda = device.type == "cuda"
    print("DEVICE", device)

    train_loader = latent_embedding_dataloader()
    unet = Unet().to(device)
    optimizer = torch.optim.AdamW(unet.parameters(), lr=unet_optim_lr, weight_decay=0)
    # NOTE(review): diffusers documents rescale_betas_zero_snr for use with
    # v_prediction; confirm unet_pred_type in config matches.
    noise_scheduler = DDPMScheduler(
        num_train_timesteps=1000,
        beta_schedule=unet_beta_schedule,
        prediction_type=unet_pred_type,
        rescale_betas_zero_snr=True,
    )

    # All intervals are counted in optimizer steps. epoch_size is a hard-coded
    # number of batches per epoch; presumably it should match len(train_loader),
    # so verify it if the dataset or batch size changes.
    epoch_size = 1265
    max_steps = epoch_size * 4000
    log_every = epoch_size * 2
    save_every = epoch_size * 8
    val_every = epoch_size * 4
    ema_warmup_steps = epoch_size * 20

    # Unconditional embedding for classifier-free-guidance dropout. The leading
    # dim of 1 lets it be expanded to the batch size below.
    null_embedding = torch.load(null_embedding_dir, map_location=device, weights_only=True).unsqueeze(0)
    ema = EMA(unet, decay=0.999, warmup_steps=ema_warmup_steps)
    scaler = torch.amp.GradScaler("cuda", enabled=use_cuda)

    global_step = 0
    total_mse = total_samples = 0  # running loss sum / sample count since last save
    if use_checkpoint:
        if checkpoint_path and os.path.exists(checkpoint_path):
            global_step = _load_checkpoint(checkpoint_path, device, unet, optimizer, scaler, ema)
            noise_scheduler = DDPMScheduler.from_pretrained(_NOISE_SCHEDULER_DIR)
            print("Loaded checkpoint successfully")
        else:
            print(f"WARNING: checkpoint not found ({checkpoint_path!r}); training from scratch")

    unet.train()
    while global_step < max_steps:
        for latent, embedding in tqdm(train_loader, desc="Train: ", colour=tqdm_colors[5]):
            if global_step >= max_steps:
                break
            latent, embedding = latent.to(device), embedding.to(device)
            assert null_embedding.shape[1:] == embedding.shape[1:], f"{null_embedding.shape}"
            assert embedding.dim() == 3, f"Expected 3D embedding (B, seq, dim), got {embedding.shape}"
            batch_size = latent.size(0)

            # Forward diffusion: one random timestep per example.
            t = torch.randint(0, noise_scheduler.config.num_train_timesteps, (batch_size,), device=device).long()
            noise = torch.randn_like(latent)
            noised_latent = noise_scheduler.add_noise(latent, noise, t)

            # Classifier-free guidance: swap ~10% of conditions for the null embedding.
            drop_cond = torch.rand(batch_size, 1, 1, device=device) < 0.1
            embedding = torch.where(drop_cond, null_embedding.expand(batch_size, -1, -1), embedding)

            optimizer.zero_grad(set_to_none=True)
            with torch.amp.autocast(device_type=device.type, enabled=use_cuda):
                target = noise_scheduler.get_velocity(latent, noise, t) if unet_pred_type == "v_prediction" else noise
                pred = unet(noised_latent, t, embedding)
                loss = F.mse_loss(pred, target, reduction="mean")

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)  # clip on true (unscaled) gradients
            torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
            old_scale = scaler.get_scale()
            scaler.step(optimizer)
            scaler.update()
            # A drop in scale means the step was skipped (inf/NaN grads), so
            # only fold the weights into the EMA when a real update happened.
            if scaler.get_scale() >= old_scale:
                ema.update(unet)

            global_step += 1
            total_mse += loss.item() * batch_size
            total_samples += batch_size

            if global_step % log_every == 0:
                if total_samples > 0:
                    print(f" Step {global_step} | Loss: {total_mse / total_samples:.6f}")

            if global_step % save_every == 0:
                _save_checkpoint(
                    f"{unet_checkpoint_dir}/ema_step_{global_step}_{total_mse / total_samples:.6f}.pth",
                    global_step, unet, ema, optimizer, scaler,
                )
                noise_scheduler.save_pretrained(_NOISE_SCHEDULER_DIR)
                total_samples = total_mse = 0

            if global_step % val_every == 0:
                _sample_images(unet, ema, global_step, use_ema=global_step >= ema_warmup_steps)

    print("✅ Training finished")
if __name__ == "__main__":
    seed_everything(42)
    # Forward slashes resolve on every OS. The previous backslash path only
    # worked on Windows and relied on invalid string escapes ("\c", "\l", "\e").
    # On other systems it was silently skipped, so training restarted from scratch.
    train_unet_ddpm_simple(
        use_checkpoint=True,
        checkpoint_path="./backend/core/checkpoints/ldm/ema_step_1163800_0.262618.pth",
    )