| from typing import Callable, Iterable, Any, Optional, Union, Sequence, Mapping, Dict |
| import os.path |
| import copy |
| import torch |
| import torch.nn as nn |
| import lightning.pytorch as pl |
| from lightning.pytorch.core.optimizer import LightningOptimizer |
| from lightning.pytorch.utilities.types import OptimizerLRScheduler, STEP_OUTPUT |
| from torch.optim.lr_scheduler import LRScheduler |
| from torch.optim import Optimizer |
| from lightning.pytorch.callbacks import Callback |
|
|
|
|
| from src.models.autoencoder.base import BaseAE, fp2uint8 |
| from src.models.conditioner.base import BaseConditioner |
| from src.utils.model_loader import ModelLoader |
| from src.callbacks.simple_ema import SimpleEMA |
| from src.diffusion.base.sampling import BaseSampler |
| from src.diffusion.base.training import BaseTrainer |
| from src.utils.no_grad import no_grad, filter_nograd_tensors |
| from src.utils.copy import copy_params |
|
|
# NOTE(review): disables AOTAutograd / torch.compile "donated buffer" reuse.
# Presumably required because some backward passes in this training setup
# revisit compiled graphs (e.g. GAN losses) — confirm before removing.
torch._functorch.config.donated_buffer = False


# Callable type aliases. OptimizerCallable / LRSchedulerCallable annotate the
# factory arguments accepted by LightningModel.__init__.
OptimizerCallable = Callable[[Iterable], Optimizer]
LRSchedulerCallable = Callable[[Optimizer], LRScheduler]
EMACallable = Callable[[nn.Module, nn.Module], SimpleEMA]
|
|
|
|
def set_requires_grad(module: nn.Module, requires_grad: bool):
    """Set ``requires_grad`` on every parameter of ``module`` in place.

    Args:
        module: Module whose parameters are toggled.
        requires_grad: New value for each parameter's ``requires_grad`` flag.
    """
    params = module.parameters()
    for tensor in params:
        tensor.requires_grad_(requires_grad)
|
|
|
|
def set_discriminator_trainable(module: nn.Module, requires_grad: bool):
    """Make a discriminator trainable or frozen.

    Prefers the module's own ``set_trainable(flag)`` hook when it defines one
    (so it can decide which sub-parts to toggle). Otherwise falls back to
    toggling ``requires_grad`` on all of its parameters.
    """
    if not hasattr(module, "set_trainable"):
        set_requires_grad(module, requires_grad)
        return
    module.set_trainable(requires_grad)
|
|
|
|
def set_optimizer_initial_lrs(optimizer: Optimizer):
    """Apply per-group ``lr_scale`` once and record each group's ``initial_lr``.

    For every param group that carries an ``lr_scale`` entry, ``lr`` is
    multiplied by it the first time only; a ``_lr_scale_applied`` marker makes
    repeated calls safe. ``initial_lr`` is then filled in from ``lr`` if the
    group does not already have one.
    """
    for pg in optimizer.param_groups:
        already_scaled = pg.get("_lr_scale_applied", False)
        if "lr_scale" in pg and not already_scaled:
            pg["lr"] *= pg["lr_scale"]
            pg["_lr_scale_applied"] = True
        if "initial_lr" not in pg:
            pg["initial_lr"] = pg["lr"]
|
|
class LightningModel(pl.LightningModule):
    """LightningModule wrapping a diffusion denoiser, its EMA copy and an optional discriminator.

    Without a discriminator, training uses Lightning's automatic optimization
    and the loss comes from ``diffusion_trainer(...)``. With a discriminator,
    manual optimization is enabled. Each batch then performs one generator
    update followed by ``d_steps_per_g`` discriminator updates on the same
    generated samples.
    """

    def __init__(self,
                 vae: BaseAE,
                 conditioner: BaseConditioner,
                 denoiser: nn.Module,
                 diffusion_trainer: BaseTrainer,
                 diffusion_sampler: BaseSampler,
                 ema_tracker: SimpleEMA = None,
                 optimizer: OptimizerCallable = None,
                 lr_scheduler: LRSchedulerCallable = None,
                 eval_original_model: bool = False,
                 discriminator: nn.Module = None,
                 d_optimizer: OptimizerCallable = None,
                 d_steps_per_g: int = 1,
                 g_grad_clip: float = 1.0,
                 d_grad_clip: float = 1.0,
                 ):
        """
        Args:
            vae: Autoencoder; ``encode`` is applied to training inputs and
                ``decode`` to sampled latents.
            conditioner: Returns ``(condition, uncondition)`` for labels ``y``.
            denoiser: The network being trained.
            diffusion_trainer: Computes training losses (and, in GAN mode, the
                generator / discriminator steps).
            diffusion_sampler: Sampler used by predict / validation.
            ema_tracker: Optional EMA callback, registered via ``configure_callbacks``.
            optimizer: Factory building the generator optimizer from param groups.
            lr_scheduler: Optional factory building a scheduler for the generator optimizer.
            eval_original_model: Sample with ``denoiser`` instead of ``ema_denoiser``.
            discriminator: Optional discriminator; enables manual GAN optimization.
            d_optimizer: Optional factory for the discriminator optimizer
                (defaults to ``AdamW(lr=2e-4, betas=(0.0, 0.99))``).
            d_steps_per_g: Discriminator updates per generator update
                (values below 1 are treated as 1).
            g_grad_clip: Generator grad-norm clip value; ``None`` or <= 0 disables it.
            d_grad_clip: Discriminator grad-norm clip value; ``None`` or <= 0 disables it.
        """
        super().__init__()
        self.vae = vae
        self.conditioner = conditioner
        self.denoiser = denoiser
        # configure_model() re-syncs these weights from the denoiser via copy_params.
        self.ema_denoiser = copy.deepcopy(self.denoiser)
        self.diffusion_sampler = diffusion_sampler
        self.diffusion_trainer = diffusion_trainer
        self.ema_tracker = ema_tracker
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

        self.eval_original_model = eval_original_model

        # Optional adversarial (GAN) training components.
        self.discriminator = discriminator
        self.d_optimizer = d_optimizer
        self.d_steps_per_g = d_steps_per_g
        self.g_grad_clip = g_grad_clip
        self.d_grad_clip = d_grad_clip
        # Total number of discriminator optimizer steps taken.
        self._d_step_counter = 0

        # Makes configure_model() a one-shot so later stages keep the EMA weights.
        self._configure_model_done = False

        if self.discriminator is not None:
            # Alternating G/D updates need explicit control over both optimizers.
            self.automatic_optimization = False

        # state_dict() omits the VAE, conditioner and sampler (and discriminator
        # "dino." keys), so loading must tolerate missing keys.
        self._strict_loading = False

    def configure_model(self) -> None:
        """Sync EMA weights, freeze non-optimized modules and compile the networks.

        Lightning invokes this hook at the start of every fit/validate/test/predict
        stage and requires it to be idempotent. Without the guard, a later stage
        would overwrite the EMA weights with the raw denoiser weights.
        """
        if getattr(self, "_configure_model_done", False):
            return
        self.trainer.strategy.barrier()
        copy_params(src_model=self.denoiser, dst_model=self.ema_denoiser)

        # Modules that are not optimized here are frozen via the no_grad helper.
        no_grad(self.conditioner)
        no_grad(self.vae)
        no_grad(self.ema_denoiser)

        if self.discriminator is not None:
            self.discriminator.train()
            set_discriminator_trainable(self.discriminator, True)

        self.denoiser.compile()
        self.ema_denoiser.compile()
        self._configure_model_done = True

    def configure_callbacks(self) -> Union[Sequence[Callback], Callback]:
        """Register the EMA tracker as a callback when one was provided."""
        if self.ema_tracker is None:
            return []
        return [self.ema_tracker]

    def configure_optimizers(self) -> OptimizerLRScheduler:
        """Build the generator optimizer (+ optional scheduler) and, in GAN mode, the discriminator optimizer.

        Generator param groups are the denoiser, diffusion_trainer and
        diffusion_sampler parameters (filtered by ``filter_nograd_tensors``);
        the sampler group is given ``lr=1e-3``. The scheduler, if any, is
        attached to the generator optimizer only.
        """
        params_denoiser = filter_nograd_tensors(self.denoiser.parameters())
        params_trainer = filter_nograd_tensors(self.diffusion_trainer.parameters())
        params_sampler = filter_nograd_tensors(self.diffusion_sampler.parameters())
        param_groups = [
            {"params": params_denoiser},
            {"params": params_trainer},
            {"params": params_sampler, "lr": 1e-3},
        ]
        optimizer: torch.optim.Optimizer = self.optimizer(param_groups)
        set_optimizer_initial_lrs(optimizer)

        d_optimizer = None
        if self.discriminator is not None:
            if hasattr(self.discriminator, "optimizer_param_groups"):
                d_params = self.discriminator.optimizer_param_groups()
            else:
                d_params = filter_nograd_tensors(self.discriminator.parameters())
            if self.d_optimizer is None:
                d_optimizer = torch.optim.AdamW(d_params, lr=2e-4, betas=(0.0, 0.99))
            else:
                d_optimizer = self.d_optimizer(d_params)
            set_optimizer_initial_lrs(d_optimizer)

        if self.lr_scheduler is None:
            if d_optimizer is None:
                return dict(optimizer=optimizer)
            return [optimizer, d_optimizer]

        lr_scheduler = self.lr_scheduler(optimizer)
        g_cfg = dict(
            optimizer=optimizer,
            lr_scheduler={
                "scheduler": lr_scheduler,
                "interval": "step",
                "frequency": 1,
                "name": "learning_rate",
            },
        )
        if d_optimizer is None:
            return g_cfg
        return [g_cfg, dict(optimizer=d_optimizer)]

    def on_validation_start(self) -> None:
        self.ema_denoiser.to(torch.float32)

    def on_predict_start(self) -> None:
        self.ema_denoiser.to(torch.float32)

    def on_train_start(self) -> None:
        """Cast the EMA model to fp32, bind it to the EMA tracker and optionally init the discriminator."""
        self.ema_denoiser.to(torch.float32)
        if self.ema_tracker is not None:
            self.ema_tracker.setup_models(net=self.denoiser, ema_net=self.ema_denoiser)
        if (self.discriminator is not None
                and hasattr(self.discriminator, "initialize_from_denoiser")):
            self.discriminator.initialize_from_denoiser(self.denoiser)
            set_discriminator_trainable(self.discriminator, True)

    def _optimizer_param_groups(self, optimizer):
        """Return ``param_groups`` of a raw or LightningOptimizer-wrapped optimizer."""
        if isinstance(optimizer, LightningOptimizer):
            return optimizer.optimizer.param_groups
        return optimizer.param_groups

    def _apply_dynamic_lr_schedule(self, *optimizers):
        """Set each group's lr to ``initial_lr * diffusion_trainer.get_lr_multiplier(global_step)``.

        No-op when the diffusion trainer does not define ``get_lr_multiplier``.
        The multiplier is logged as ``lr_multiplier``.
        """
        if not hasattr(self.diffusion_trainer, "get_lr_multiplier"):
            return
        lr_multiplier = self.diffusion_trainer.get_lr_multiplier(self.global_step)
        for optimizer in optimizers:
            for group in self._optimizer_param_groups(optimizer):
                group["lr"] = group["initial_lr"] * lr_multiplier
        self.log("lr_multiplier", lr_multiplier, prog_bar=True, on_step=True, sync_dist=False)

    def _step_lr_schedulers(self) -> None:
        """Step every configured LR scheduler once.

        Needed under manual optimization, where Lightning does not step
        schedulers itself. ``ReduceLROnPlateau`` needs a metric and is skipped.
        """
        schedulers = self.lr_schedulers()
        if schedulers is None:
            return
        if not isinstance(schedulers, (list, tuple)):
            schedulers = [schedulers]
        for scheduler in schedulers:
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                continue
            scheduler.step()

    def on_load_checkpoint(self, checkpoint):
        """Drop ``pos_embed`` checkpoint entries whose shape differs from the current model.

        Mutates ``checkpoint["state_dict"]`` in place so the subsequent load does
        not raise a size-mismatch RuntimeError for these keys.
        """
        keys_to_check = [
            "denoiser.pos_embed",
            "ema_denoiser.pos_embed",
        ]
        ckpt_state_dict = checkpoint["state_dict"]
        current_state_dict = self.state_dict()

        for key in keys_to_check:
            if key not in ckpt_state_dict or key not in current_state_dict:
                continue
            ckpt_shape = ckpt_state_dict[key].shape
            curr_shape = current_state_dict[key].shape
            if ckpt_shape != curr_shape:
                print(f"[Warning] Shape mismatch for '{key}': "
                      f"Checkpoint {ckpt_shape} vs Current {curr_shape}. "
                      f"Dropping from checkpoint to avoid RuntimeError.")
                del ckpt_state_dict[key]

    def training_step(self, batch, batch_idx):
        """Run one training step on ``batch`` = ``(x, y, metadata)``.

        Inputs are VAE-encoded and conditioned without gradients. Then either the
        plain diffusion loss (automatic optimization) or a manual GAN step is run.

        Returns:
            The generator / diffusion loss tensor.
        """
        x, y, metadata = batch
        if metadata is None:
            metadata = {}
        metadata['global_step'] = self.global_step
        with torch.no_grad():
            x = self.vae.encode(x)
            condition, uncondition = self.conditioner(y, metadata)

        if self.discriminator is None:
            # Automatic optimization: Lightning runs backward/step on the returned loss.
            # NOTE(review): _apply_dynamic_lr_schedule is only applied in GAN mode.
            loss = self.diffusion_trainer(
                self.denoiser, self.ema_denoiser, self.diffusion_sampler,
                x, condition, uncondition, metadata,
            )
            self.log_dict(loss, prog_bar=True, on_step=True, sync_dist=False)
            return loss["loss"]

        return self._gan_training_step(x, condition, uncondition, metadata)

    def _gan_training_step(self, x, condition, uncondition, metadata):
        """Manual-optimization step: one generator update, then discriminator update(s)."""
        opt_g, opt_d = self.optimizers()
        self._apply_dynamic_lr_schedule(opt_g, opt_d)

        x, condition_used, metadata = self.diffusion_trainer.preproprocess(
            x, condition, uncondition, metadata,
        )

        # Generator update, with the discriminator set non-trainable.
        set_discriminator_trainable(self.discriminator, False)
        g_losses, cache = self.diffusion_trainer.generator_step(
            self.denoiser, self.ema_denoiser, self.diffusion_sampler,
            x, condition_used, metadata,
            discriminator=self.discriminator,
        )
        opt_g.zero_grad(set_to_none=True)
        self.manual_backward(g_losses["loss"])
        if self.g_grad_clip is not None and self.g_grad_clip > 0:
            self.clip_gradients(opt_g, gradient_clip_val=self.g_grad_clip,
                                gradient_clip_algorithm="norm")
        opt_g.step()

        d_losses = self._discriminator_updates(opt_d, cache)

        # The configured "interval": "step" scheduler is not stepped by Lightning here.
        self._step_lr_schedulers()

        log_dict = dict(g_losses)
        log_dict.update(d_losses)
        self.log_dict(log_dict, prog_bar=True, on_step=True, sync_dist=False)
        return g_losses["loss"]

    def _discriminator_updates(self, opt_d, cache):
        """Run ``d_steps_per_g`` discriminator updates on the cached generator output.

        Every update reuses the same detached ``cache["pred_img"]`` samples.
        NOTE(review): under manual optimization each ``opt_d.step()`` presumably
        also advances Lightning's ``global_step`` — confirm when using values > 1.

        Returns:
            The loss dict returned by the last ``discriminator_step`` call.
        """
        set_discriminator_trainable(self.discriminator, True)
        fake_img = cache["pred_img"].detach()
        n_steps = max(1, int(self.d_steps_per_g))
        d_losses = {}
        for _ in range(n_steps):
            d_losses = self.diffusion_trainer.discriminator_step(
                self.discriminator,
                fake_img,
                cache["real_img"],
                cache["cond"],
                valid_length_y=cache.get("valid_length_y"),
                gan_mask=cache.get("gan_mask"),
                gan_active=cache.get("gan_active", True),
            )
            opt_d.zero_grad(set_to_none=True)
            self.manual_backward(d_losses["d_loss"])
            if self.d_grad_clip is not None and self.d_grad_clip > 0:
                self.clip_gradients(opt_d, gradient_clip_val=self.d_grad_clip,
                                    gradient_clip_algorithm="norm")
            opt_d.step()
            self._d_step_counter += 1
        return d_losses

    def predict_step(self, batch, batch_idx):
        """Sample from ``xT`` conditioned on ``y`` and return decoded, ``fp2uint8``-converted samples.

        Uses ``ema_denoiser`` unless ``eval_original_model`` is set.
        """
        xT, y, metadata = batch
        with torch.no_grad():
            # NOTE(review): training passes ``metadata`` to the conditioner; this call does not.
            condition, uncondition = self.conditioner(y)

        net = self.denoiser if self.eval_original_model else self.ema_denoiser
        samples = self.diffusion_sampler(net, xT, condition, uncondition)
        samples = self.vae.decode(samples)
        samples = fp2uint8(samples)
        return samples

    def validation_step(self, batch, batch_idx):
        samples = self.predict_step(batch, batch_idx)
        return samples

    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        """Return a reduced state dict for checkpointing.

        Contains this module's own parameters/buffers and the ``denoiser``,
        ``ema_denoiser`` and ``diffusion_trainer`` states. If a discriminator is
        present, its state is included minus keys starting with ``"dino."``
        (presumably a frozen pretrained backbone — confirm). The VAE,
        conditioner and sampler are not included. Extra positional ``*args`` are
        ignored.
        """
        if destination is None:
            destination = {}
        self._save_to_state_dict(destination, prefix, keep_vars)
        self.denoiser.state_dict(
            destination=destination,
            prefix=prefix + "denoiser.",
            keep_vars=keep_vars)
        self.ema_denoiser.state_dict(
            destination=destination,
            prefix=prefix + "ema_denoiser.",
            keep_vars=keep_vars)
        self.diffusion_trainer.state_dict(
            destination=destination,
            prefix=prefix + "diffusion_trainer.",
            keep_vars=keep_vars)
        if self.discriminator is not None:
            d_full = self.discriminator.state_dict(
                destination=None, prefix="", keep_vars=keep_vars,
            )
            for k, v in d_full.items():
                if k.startswith("dino."):
                    continue
                destination[prefix + "discriminator." + k] = v
        return destination