from typing import Callable, Iterable, Any, Optional, Union, Sequence, Mapping, Dict
import os.path
import copy

import torch
import torch.nn as nn
import lightning.pytorch as pl
from lightning.pytorch.core.optimizer import LightningOptimizer
from lightning.pytorch.utilities.types import OptimizerLRScheduler, STEP_OUTPUT
from torch.optim.lr_scheduler import LRScheduler
from torch.optim import Optimizer
from lightning.pytorch.callbacks import Callback

from src.models.autoencoder.base import BaseAE, fp2uint8
from src.models.conditioner.base import BaseConditioner
from src.utils.model_loader import ModelLoader
from src.callbacks.simple_ema import SimpleEMA
from src.diffusion.base.sampling import BaseSampler
from src.diffusion.base.training import BaseTrainer
from src.utils.no_grad import no_grad, filter_nograd_tensors
from src.utils.copy import copy_params

# Process-wide functorch setting, applied at import time.
# NOTE(review): presumably required for torch.compile combined with the
# backward passes used here -- confirm before removing.
torch._functorch.config.donated_buffer = False

EMACallable = Callable[[nn.Module, nn.Module], SimpleEMA]
OptimizerCallable = Callable[[Iterable], Optimizer]
LRSchedulerCallable = Callable[[Optimizer], LRScheduler]


def set_requires_grad(module: nn.Module, requires_grad: bool) -> None:
    """Set ``requires_grad`` on every parameter of *module*."""
    for param in module.parameters():
        param.requires_grad_(requires_grad)


def set_discriminator_trainable(module: nn.Module, requires_grad: bool) -> None:
    """Toggle whether *module*'s parameters receive gradients.

    If the module defines its own ``set_trainable`` hook, that hook is used
    (letting the module decide which sub-modules are toggled); otherwise every
    parameter is toggled via :func:`set_requires_grad`.
    """
    if hasattr(module, "set_trainable"):
        module.set_trainable(requires_grad)
    else:
        set_requires_grad(module, requires_grad)


def set_optimizer_initial_lrs(optimizer: Optimizer) -> None:
    """Apply per-group ``lr_scale`` once and record each group's ``initial_lr``.

    A param group carrying an ``lr_scale`` key has its ``lr`` multiplied by
    that factor; the ``_lr_scale_applied`` flag keeps repeated calls from
    scaling twice. ``initial_lr`` is only written when the group does not
    already have one; ``LightningModel._apply_dynamic_lr_schedule`` uses it as
    the base learning rate.
    """
    for group in optimizer.param_groups:
        if "lr_scale" in group and not group.get("_lr_scale_applied", False):
            group["lr"] *= group["lr_scale"]
            group["_lr_scale_applied"] = True
        group.setdefault("initial_lr", group["lr"])


class LightningModel(pl.LightningModule):
    """Diffusion training / sampling module with optional adversarial fine-tuning.

    Without a ``discriminator`` the module uses Lightning's automatic
    optimization and delegates the loss to ``diffusion_trainer``. With a
    ``discriminator`` it switches to manual optimization and performs one
    generator (denoiser) update followed by one discriminator update per batch.
    """

    def __init__(self,
                 vae: BaseAE,
                 conditioner: BaseConditioner,
                 denoiser: nn.Module,
                 diffusion_trainer: BaseTrainer,
                 diffusion_sampler: BaseSampler,
                 ema_tracker: SimpleEMA = None,
                 optimizer: OptimizerCallable = None,
                 lr_scheduler: LRSchedulerCallable = None,
                 eval_original_model: bool = False,
                 # ---- optional adversarial fine-tuning ----
                 discriminator: nn.Module = None,
                 d_optimizer: OptimizerCallable = None,
                 d_steps_per_g: int = 1,
                 g_grad_clip: float = 1.0,
                 d_grad_clip: float = 1.0,
                 ):
        """Store sub-modules and training configuration.

        Args:
            vae: Encodes training inputs (``training_step``) and decodes
                samples (``predict_step``).
            conditioner: Maps labels/prompts ``y`` to ``(condition, uncondition)``.
            denoiser: The trained network; an EMA copy is deep-copied from it.
            diffusion_trainer: Computes losses (and, in GAN mode, the
                generator / discriminator steps).
            diffusion_sampler: Sampler used by ``predict_step``; its trainable
                parameters get their own param group.
            ema_tracker: Optional EMA callback; skipped when ``None``.
            optimizer: Factory building the generator optimizer from param groups.
            lr_scheduler: Optional factory building a scheduler for the
                generator optimizer.
            eval_original_model: Sample with ``denoiser`` instead of the EMA copy.
            discriminator: Enables GAN fine-tuning (manual optimization) when set.
            d_optimizer: Factory for the discriminator optimizer; defaults to
                ``AdamW(lr=2e-4, betas=(0.0, 0.99))`` when ``None``.
            d_steps_per_g: NOTE(review): stored but currently not consulted by
                the training loop -- one discriminator step always follows each
                generator step.
            g_grad_clip: Max grad norm for the generator; ``None`` or ``<= 0``
                disables clipping.
            d_grad_clip: Max grad norm for the discriminator; ``None`` or
                ``<= 0`` disables clipping.
        """
        super().__init__()
        self.vae = vae
        self.conditioner = conditioner
        self.denoiser = denoiser
        self.ema_denoiser = copy.deepcopy(self.denoiser)
        self.diffusion_sampler = diffusion_sampler
        self.diffusion_trainer = diffusion_trainer
        self.ema_tracker = ema_tracker
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.eval_original_model = eval_original_model

        # discriminator / GAN fine-tuning
        self.discriminator = discriminator
        self.d_optimizer = d_optimizer
        self.d_steps_per_g = d_steps_per_g
        self.g_grad_clip = g_grad_clip
        self.d_grad_clip = d_grad_clip
        self._d_step_counter = 0  # NOTE(review): currently unused
        if self.discriminator is not None:
            # manual optimization is required for two-optimizer GAN training
            self.automatic_optimization = False

        # Non-strict loading lets on_load_checkpoint drop mismatched keys.
        self._strict_loading = False

    def configure_model(self) -> None:
        """Sync ranks, initialise the EMA copy, freeze fixed parts and compile."""
        self.trainer.strategy.barrier()
        copy_params(src_model=self.denoiser, dst_model=self.ema_denoiser)

        # disable grad for conditioner, vae and the EMA copy
        no_grad(self.conditioner)
        no_grad(self.vae)
        # diffusion_sampler is left untouched on purpose: its trainable
        # parameters get their own param group in configure_optimizers.
        no_grad(self.ema_denoiser)

        # GAN: the discriminator is put in train mode and made trainable via
        # set_discriminator_trainable. NOTE(review): keeping a frozen backbone
        # (e.g. DINOv2) relies on the discriminator's own set_trainable hook;
        # without that hook every parameter becomes trainable.
        if self.discriminator is not None:
            self.discriminator.train()
            set_discriminator_trainable(self.discriminator, True)

        # torch.compile
        self.denoiser.compile()
        self.ema_denoiser.compile()

    def configure_callbacks(self) -> Union[Sequence[Callback], Callback]:
        """Register the EMA tracker as a callback when one was provided."""
        if self.ema_tracker is None:
            return []
        return [self.ema_tracker]

    def configure_optimizers(self) -> OptimizerLRScheduler:
        """Build the generator optimizer and, in GAN mode, the discriminator one.

        Returns a single optimizer config without a discriminator, or a list
        ``[generator, discriminator]`` with one. The optional LR scheduler is
        attached to the generator optimizer only.
        """
        params_denoiser = filter_nograd_tensors(self.denoiser.parameters())
        params_trainer = filter_nograd_tensors(self.diffusion_trainer.parameters())
        params_sampler = filter_nograd_tensors(self.diffusion_sampler.parameters())
        param_groups = [
            {"params": params_denoiser},
            {"params": params_trainer},
            # sampler parameters use their own fixed learning rate
            {"params": params_sampler, "lr": 1e-3},
        ]
        optimizer: torch.optim.Optimizer = self.optimizer(param_groups)
        set_optimizer_initial_lrs(optimizer)

        # ---- GAN: also build a discriminator optimizer ----
        d_optimizer = None
        if self.discriminator is not None:
            if hasattr(self.discriminator, "optimizer_param_groups"):
                d_params = self.discriminator.optimizer_param_groups()
            else:
                d_params = filter_nograd_tensors(self.discriminator.parameters())
            if self.d_optimizer is None:
                d_optimizer = torch.optim.AdamW(d_params, lr=2e-4, betas=(0.0, 0.99))
            else:
                d_optimizer = self.d_optimizer(d_params)
            set_optimizer_initial_lrs(d_optimizer)

        if self.lr_scheduler is None:
            if d_optimizer is None:
                return dict(optimizer=optimizer)
            return [optimizer, d_optimizer]

        lr_scheduler = self.lr_scheduler(optimizer)
        # "interval"/"frequency" only take effect under automatic optimization;
        # in GAN (manual) mode the scheduler is stepped by _step_lr_schedulers.
        g_cfg = dict(
            optimizer=optimizer,
            lr_scheduler={
                "scheduler": lr_scheduler,
                "interval": "step",
                "frequency": 1,
                "name": "learning_rate",
            },
        )
        if d_optimizer is None:
            return g_cfg
        return [g_cfg, dict(optimizer=d_optimizer)]

    def on_validation_start(self) -> None:
        self.ema_denoiser.to(torch.float32)

    def on_predict_start(self) -> None:
        self.ema_denoiser.to(torch.float32)

    # sanity check before training start
    def on_train_start(self) -> None:
        """Cast the EMA copy to fp32, wire up EMA and optionally init the discriminator."""
        self.ema_denoiser.to(torch.float32)
        if self.ema_tracker is not None:
            self.ema_tracker.setup_models(net=self.denoiser, ema_net=self.ema_denoiser)
        if (self.discriminator is not None
                and hasattr(self.discriminator, "initialize_from_denoiser")):
            self.discriminator.initialize_from_denoiser(self.denoiser)
            # re-apply trainability after the discriminator was (re)initialised
            set_discriminator_trainable(self.discriminator, True)

    def _optimizer_param_groups(self, optimizer):
        """Return the param groups of a raw or Lightning-wrapped optimizer."""
        if isinstance(optimizer, LightningOptimizer):
            return optimizer.optimizer.param_groups
        return optimizer.param_groups

    def _apply_dynamic_lr_schedule(self, *optimizers):
        """Set ``lr = initial_lr * multiplier`` on every group of *optimizers*.

        The multiplier comes from ``diffusion_trainer.get_lr_multiplier``; this
        is a no-op when the trainer does not define that method.
        """
        if not hasattr(self.diffusion_trainer, "get_lr_multiplier"):
            return
        lr_multiplier = self.diffusion_trainer.get_lr_multiplier(self.global_step)
        for optimizer in optimizers:
            for group in self._optimizer_param_groups(optimizer):
                group["lr"] = group["initial_lr"] * lr_multiplier
        self.log("lr_multiplier", lr_multiplier, prog_bar=True, on_step=True,
                 sync_dist=False)

    def _step_lr_schedulers(self) -> None:
        """Call ``step()`` on every LR scheduler from ``configure_optimizers``.

        Needed under manual optimization, where Lightning never steps the
        schedulers itself. ``lr_schedulers()`` returns ``None``, a single
        scheduler, or a list, so all three shapes are handled.
        """
        schedulers = self.lr_schedulers()
        if schedulers is None:
            return
        if not isinstance(schedulers, (list, tuple)):
            schedulers = [schedulers]
        for scheduler in schedulers:
            scheduler.step()

    def on_load_checkpoint(self, checkpoint):
        """Drop positional-embedding tensors whose shape no longer matches.

        Mismatched ``pos_embed`` entries are deleted from the checkpoint's
        ``state_dict`` so that (non-strict) loading keeps the current values
        instead of raising.
        """
        keys_to_check = ["denoiser.pos_embed", "ema_denoiser.pos_embed"]
        ckpt_state_dict = checkpoint["state_dict"]
        current_state_dict = self.state_dict()
        for key in keys_to_check:
            if key not in ckpt_state_dict or key not in current_state_dict:
                continue
            ckpt_shape = ckpt_state_dict[key].shape
            curr_shape = current_state_dict[key].shape
            if ckpt_shape != curr_shape:
                print(f"[Warning] Shape mismatch for '{key}': "
                      f"Checkpoint {ckpt_shape} vs Current {curr_shape}. "
                      f"Dropping from checkpoint to avoid RuntimeError.")
                del ckpt_state_dict[key]

    def training_step(self, batch, batch_idx):
        """Run one training step on ``(x, y, metadata)``.

        ``x`` is passed through ``vae.encode`` and ``y`` through the
        conditioner, both without gradients. Returns the (generator) loss.
        """
        x, y, metadata = batch
        if metadata is None:
            metadata = {}
        # NOTE(review): under manual (GAN) optimization Lightning increments
        # global_step on every optimizer.step(), i.e. by 2 per batch here --
        # confirm consumers of metadata['global_step'] expect that.
        metadata['global_step'] = self.global_step
        with torch.no_grad():
            x = self.vae.encode(x)
            condition, uncondition = self.conditioner(y, metadata)

        # --------- non-GAN path: identical to the original implementation ----
        if self.discriminator is None:
            loss = self.diffusion_trainer(
                self.denoiser, self.ema_denoiser, self.diffusion_sampler,
                x, condition, uncondition, metadata,
            )
            self.log_dict(loss, prog_bar=True, on_step=True, sync_dist=False)
            return loss["loss"]

        return self._gan_training_step(x, condition, uncondition, metadata)

    def _gan_training_step(self, x, condition, uncondition, metadata):
        """Perform one manual generator update then one discriminator update."""
        opt_g, opt_d = self.optimizers()
        self._apply_dynamic_lr_schedule(opt_g, opt_d)

        # Sample / drop conditioning the same way BaseTrainer does.
        # NOTE: preproprocess returns (x, condition, metadata) -- ordering
        # matters; do NOT swap to (condition, _, metadata).
        x, condition_used, metadata = self.diffusion_trainer.preproprocess(
            x, condition, uncondition, metadata,
        )

        # ===== Generator step =====
        # Discriminator params are made non-trainable while backpropagating
        # the generator loss.
        set_discriminator_trainable(self.discriminator, False)
        g_losses, cache = self.diffusion_trainer.generator_step(
            self.denoiser, self.ema_denoiser, self.diffusion_sampler,
            x, condition_used, metadata,
            discriminator=self.discriminator,
        )
        opt_g.zero_grad(set_to_none=True)
        self.manual_backward(g_losses["loss"])
        if self.g_grad_clip is not None and self.g_grad_clip > 0:
            self.clip_gradients(opt_g, gradient_clip_val=self.g_grad_clip,
                                gradient_clip_algorithm="norm")
        opt_g.step()
        # Lightning does not step LR schedulers under manual optimization;
        # without this the configured generator scheduler would never advance.
        self._step_lr_schedulers()

        # ===== Discriminator step =====
        set_discriminator_trainable(self.discriminator, True)
        d_losses = self.diffusion_trainer.discriminator_step(
            self.discriminator,
            cache["pred_img"].detach(),  # no gradient flows back to the generator
            cache["real_img"],
            cache["cond"],
            valid_length_y=cache.get("valid_length_y"),
            gan_mask=cache.get("gan_mask"),
            gan_active=cache.get("gan_active", True),
        )
        opt_d.zero_grad(set_to_none=True)
        self.manual_backward(d_losses["d_loss"])
        if self.d_grad_clip is not None and self.d_grad_clip > 0:
            self.clip_gradients(opt_d, gradient_clip_val=self.d_grad_clip,
                                gradient_clip_algorithm="norm")
        opt_d.step()

        log_dict = dict(g_losses)
        log_dict.update(d_losses)
        self.log_dict(log_dict, prog_bar=True, on_step=True, sync_dist=False)
        return g_losses["loss"]

    def predict_step(self, batch, batch_idx):
        """Sample from noise ``xT`` conditioned on ``y`` and return uint8 images."""
        xT, y, metadata = batch
        with torch.no_grad():
            condition, uncondition = self.conditioner(y)

        # sample images
        if self.eval_original_model:
            samples = self.diffusion_sampler(self.denoiser, xT, condition, uncondition)
        else:
            samples = self.diffusion_sampler(self.ema_denoiser, xT, condition, uncondition)

        samples = self.vae.decode(samples)
        # fp32 -1,1 -> uint8 0,255
        samples = fp2uint8(samples)
        return samples

    def validation_step(self, batch, batch_idx):
        """Validation produces samples exactly like ``predict_step``."""
        return self.predict_step(batch, batch_idx)

    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        """Return a state dict restricted to the trainable / EMA parts.

        Includes this module's own tensors, ``denoiser``, ``ema_denoiser``,
        ``diffusion_trainer`` and (when present) the discriminator minus any
        key starting with ``dino.``. The VAE, conditioner and sampler are not
        saved. Extra positional ``*args`` are ignored.
        """
        if destination is None:
            destination = {}
        self._save_to_state_dict(destination, prefix, keep_vars)
        self.denoiser.state_dict(
            destination=destination,
            prefix=prefix + "denoiser.",
            keep_vars=keep_vars)
        self.ema_denoiser.state_dict(
            destination=destination,
            prefix=prefix + "ema_denoiser.",
            keep_vars=keep_vars)
        self.diffusion_trainer.state_dict(
            destination=destination,
            prefix=prefix + "diffusion_trainer.",
            keep_vars=keep_vars)
        if self.discriminator is not None:
            # Skip the "dino." sub-tree. NOTE(review): assumed to be a frozen
            # backbone that can be reconstructed (e.g. from torch.hub).
            d_full = self.discriminator.state_dict(
                destination=None, prefix="", keep_vars=keep_vars,
            )
            for k, v in d_full.items():
                if k.startswith("dino."):
                    continue
                destination[prefix + "discriminator." + k] = v
        return destination