|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Any, Callable, Dict, Optional, Sequence, Tuple |
|
|
|
|
|
import nemo_run as run |
|
|
import torch |
|
|
import torch.distributed |
|
|
import torch.utils.checkpoint |
|
|
import torchvision |
|
|
import wandb |
|
|
from autovae import VAEGenerator |
|
|
from contperceptual_loss import LPIPSWithDiscriminator |
|
|
from diffusers import AutoencoderKL |
|
|
from megatron.core import parallel_state |
|
|
from megatron.core.transformer.enums import ModelType |
|
|
from megatron.core.transformer.module import MegatronModule |
|
|
from megatron.core.transformer.transformer_config import TransformerConfig |
|
|
from megatron.energon import DefaultTaskEncoder, ImageSample |
|
|
from torch import Tensor, nn |
|
|
|
|
|
from nemo import lightning as nl |
|
|
from nemo.collections import llm |
|
|
from nemo.collections.diffusion.data.diffusion_energon_datamodule import DiffusionDataModule |
|
|
from nemo.collections.diffusion.train import pretrain |
|
|
from nemo.collections.llm.gpt.model.base import GPTModel |
|
|
from nemo.lightning.io.mixin import IOMixin |
|
|
from nemo.lightning.megatron_parallel import DataT, MegatronLossReduction, ReductionT |
|
|
from nemo.lightning.pytorch.optim import OptimizerModule |
|
|
|
|
|
|
|
|
class AvgLossReduction(MegatronLossReduction):
    """Loss reduction that simply averages losses across micro-batches."""

    def forward(self, batch: DataT, forward_out: Tensor) -> Tuple[Tensor, ReductionT]:
        """
        Collapse the forward output into a scalar loss.

        Args:
            batch: The batch of data (unused; kept for the reduction interface).
            forward_out: The output tensor from the forward computation.

        Returns:
            A tuple of (loss, reduction dictionary keyed by "avg").
        """
        mean_loss = forward_out.mean()
        return mean_loss, {"avg": mean_loss}

    def reduce(self, losses_reduced_per_micro_batch: Sequence[ReductionT]) -> Tensor:
        """
        Average the per-micro-batch losses into a single scalar.

        Args:
            losses_reduced_per_micro_batch: A sequence of loss dictionaries,
                each produced by `forward` above.

        Returns:
            The averaged loss tensor.
        """
        per_batch = [entry["avg"] for entry in losses_reduced_per_micro_batch]
        return torch.stack(per_batch).mean()
|
|
|
|
|
|
|
|
class VAE(MegatronModule):
    """Variational Autoencoder (VAE) module wrapping a diffusers AutoencoderKL."""

    def __init__(self, config, pretrained_model_name_or_path, search_vae=False):
        """
        Initialize the VAE model.

        Args:
            config: Megatron TransformerConfig, passed through to MegatronModule.
            pretrained_model_name_or_path: Config path or name used to build the
                AutoencoderKL architecture (weights are NOT loaded from it here).
            search_vae: If True, obtain the VAE architecture from AutoVAE
                instead of instantiating it from the diffusers config.
        """
        super().__init__(config)
        if search_vae:
            # Architecture search path: AutoVAE proposes a 16x-compression VAE
            # for 1024px inputs rather than reading a fixed diffusers config.
            self.vae = VAEGenerator(input_resolution=1024, compression_ratio=16)
        else:
            # Build the network from its config only — weights are random at
            # this point and get (partially) warm-started below.
            self.vae = AutoencoderKL.from_config(pretrained_model_name_or_path, weight_dtype=torch.bfloat16)

        # Warm-start from the SDXL VAE: copy every parameter whose name exists
        # in our model AND whose element count matches; everything else keeps
        # its fresh initialization (strict=False tolerates the partial overlap).
        sdxl_vae = AutoencoderKL.from_pretrained(
            'stabilityai/stable-diffusion-xl-base-1.0', subfolder="vae", weight_dtype=torch.bfloat16
        )
        sd_dict = sdxl_vae.state_dict()
        vae_dict = self.vae.state_dict()
        pre_dict = {k: v for k, v in sd_dict.items() if (k in vae_dict) and (vae_dict[k].numel() == v.numel())}
        self.vae.load_state_dict(pre_dict, strict=False)
        # Free the donor model right away; only the copied tensors are kept.
        del sdxl_vae

        # Combined reconstruction/perceptual/adversarial loss. NOTE(review):
        # disc_start=50001 presumably delays the discriminator term until that
        # global step — confirm against LPIPSWithDiscriminator's implementation.
        self.vae_loss = LPIPSWithDiscriminator(
            disc_start=50001,
            logvar_init=0.0,
            kl_weight=0.000001,
            pixelloss_weight=1.0,
            disc_num_layers=3,
            disc_in_channels=3,
            disc_factor=1.0,
            disc_weight=0.5,
            perceptual_weight=1.0,
            use_actnorm=False,
            disc_conditional=False,
            disc_loss="hinge",
        )

    def forward(self, target, global_step):
        """
        Encode and decode `target`, then score the reconstruction.

        Args:
            target: Target images (batch of pixel tensors).
            global_step: Current global step, forwarded to the loss (used by its
                discriminator schedule).

        Returns:
            A tuple (aeloss, log_dict_ae, pred) containing the autoencoder loss,
            a dictionary of metrics to log, and the reconstructed images.
        """
        posterior = self.vae.encode(target).latent_dist
        # Draw a stochastic latent sample from the posterior (not its mode).
        z = posterior.sample()
        pred = self.vae.decode(z).sample
        aeloss, log_dict_ae = self.vae_loss(
            inputs=target,
            reconstructions=pred,
            posteriors=posterior,
            optimizer_idx=0,
            global_step=global_step,
            # Final decoder conv weight — used by the loss to balance gradient
            # magnitudes between reconstruction and adversarial terms.
            last_layer=self.vae.decoder.conv_out.weight,
        )
        return aeloss, log_dict_ae, pred

    def set_input_tensor(self, input_tensor: Tensor) -> None:
        """
        Set input tensor (Megatron pipeline-parallelism hook).

        Args:
            input_tensor: The input tensor to the model.
        """
        # No-op: the VAE does not participate in pipeline parallelism.
        pass
|
|
|
|
|
|
|
|
class VAEModel(GPTModel):
    """A GPTModel wrapper that trains the VAE inside NeMo's Megatron harness."""

    def __init__(
        self,
        pretrained_model_name_or_path: str,
        optim: Optional[OptimizerModule] = None,
        model_transform: Optional[Callable[[nn.Module], nn.Module]] = None,
    ):
        """
        Initialize the VAEModel.

        Args:
            pretrained_model_name_or_path: Path or name of the pretrained VAE
                config, forwarded to the wrapped VAE module.
            optim: Optional optimizer module.
            model_transform: Optional function to transform the model.
        """
        self.pretrained_model_name_or_path = pretrained_model_name_or_path
        # Minimal placeholder config: GPTModel requires a TransformerConfig, but
        # the VAE never uses the transformer stack, so the smallest legal values
        # are used.
        config = TransformerConfig(num_layers=1, hidden_size=1, num_attention_heads=1)
        self.model_type = ModelType.encoder_or_decoder
        super().__init__(config, optim=optim, model_transform=model_transform)

    def configure_model(self) -> None:
        """Lazily build the wrapped VAE module (idempotent across calls)."""
        if not hasattr(self, "module"):
            self.module = VAE(self.config, self.pretrained_model_name_or_path)

    def data_step(self, dataloader_iter) -> Dict[str, Any]:
        """
        Perform a single data step to fetch a batch from the iterator.

        Args:
            dataloader_iter: The dataloader iterator.

        Returns:
            A dictionary with 'pixel_values' ready for the model.
        """
        # NOTE(review): next() presumably yields a (batch, ...) tuple — the
        # batch itself is the first element; confirm against the datamodule.
        batch = next(dataloader_iter)[0]
        # Move to GPU and cast to bf16 without blocking the host thread.
        return {'pixel_values': batch.image.to(device='cuda', dtype=torch.bfloat16, non_blocking=True)}

    def forward(self, *args, **kwargs):
        """
        Forward pass through the underlying module.

        Args:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            The result of the forward pass of self.module (the wrapped VAE).
        """
        return self.module(*args, **kwargs)

    def training_step(self, batch, batch_idx=None) -> torch.Tensor:
        """
        Perform a single training step.

        Args:
            batch: The input batch; must contain 'pixel_values'.
            batch_idx: Batch index (unused).

        Returns:
            The loss tensor.
        """
        loss, log_dict_ae, pred = self(batch["pixel_values"], self.global_step)

        # Log scalar metrics from global rank 0 only, to avoid duplicates.
        if torch.distributed.get_rank() == 0:
            self.log_dict(log_dict_ae)

        return loss

    def validation_step(self, batch, batch_idx=None) -> torch.Tensor:
        """
        Perform a single validation step, logging image grids to wandb.

        Args:
            batch: The input batch; must contain 'pixel_values'.
            batch_idx: Batch index (unused).

        Returns:
            The loss tensor.
        """
        loss, log_dict_ae, pred = self(batch["pixel_values"], self.global_step)

        # Stack originals and reconstructions for visualization. The task
        # encoder shifted pixels by -0.5, so +0.5 maps them back to [0, 1].
        image = torch.cat([batch["pixel_values"].cpu(), pred.cpu()], axis=0)
        image = (image + 0.5).clamp(0, 1)

        # Global rank that owns the wandb run.
        wandb_rank = 0

        # Only data-parallel source ranks hold distinct batches worth
        # gathering; among them, images are gathered onto global rank 0.
        if parallel_state.get_data_parallel_src_rank() == wandb_rank:
            if torch.distributed.get_rank() == wandb_rank:
                gather_list = [None for _ in range(parallel_state.get_data_parallel_world_size())]
            else:
                gather_list = None
            torch.distributed.gather_object(
                image, gather_list, wandb_rank, group=parallel_state.get_data_parallel_group()
            )
            # gather_list is non-None only on the destination rank.
            if gather_list is not None:
                self.log_dict(log_dict_ae)
                wandb.log(
                    {
                        "Original (left), Reconstruction (right)": [
                            wandb.Image(torchvision.utils.make_grid(image)) for _, image in enumerate(gather_list)
                        ]
                    },
                )

        return loss

    @property
    def training_loss_reduction(self) -> AvgLossReduction:
        """Returns the loss reduction method for training."""
        # _training_loss_reduction is presumably initialized (to a falsy value)
        # by the GPTModel base class — TODO confirm.
        if not self._training_loss_reduction:
            self._training_loss_reduction = AvgLossReduction()
        return self._training_loss_reduction

    @property
    def validation_loss_reduction(self) -> AvgLossReduction:
        """Returns the loss reduction method for validation."""
        if not self._validation_loss_reduction:
            self._validation_loss_reduction = AvgLossReduction()
        return self._validation_loss_reduction

    def on_validation_model_zero_grad(self) -> None:
        """
        Hook to handle zero grad on validation model step.

        Used here to skip the first validation when resuming from a checkpoint:
        flipping trainer.sanity_checking presumably makes Lightning discard this
        validation pass — NOTE(review): confirm against the Lightning version in use.
        """
        super().on_validation_model_zero_grad()
        # Trip only once per resume: the flag defaults to True (via getattr)
        # and is cleared after the first validation pass.
        if self.trainer.ckpt_path is not None and getattr(self, '_restarting_skip_val_flag', True):
            self.trainer.sanity_checking = True
            self._restarting_skip_val_flag = False
|
|
|
|
|
|
|
|
def crop_image(img, divisor=16):
    """
    Center-crop an image so both spatial dimensions are divisible by `divisor`.

    Args:
        img: Image tensor with spatial dimensions last (..., H, W).
        divisor: The divisor both output dimensions must be multiples of.

    Returns:
        A view of `img` cropped symmetrically (extra pixel removed from the
        bottom/right when the excess is odd).
    """
    height, width = img.shape[-2], img.shape[-1]

    # How many rows/columns must be removed in total.
    excess_h = height % divisor
    excess_w = width % divisor

    # Split the excess between the two sides; the second side takes the
    # extra pixel when the excess is odd.
    top = excess_h // 2
    left = excess_w // 2
    bottom = height - (excess_h - top)
    right = width - (excess_w - left)

    return img[..., top:bottom, left:right]
|
|
|
|
|
|
|
|
class ImageTaskEncoder(DefaultTaskEncoder, IOMixin):
    """Image task encoder that crops and recenters each image sample."""

    def encode_sample(self, sample: ImageSample) -> ImageSample:
        """
        Encode one image sample: crop to a multiple of 16, then subtract 0.5.

        Args:
            sample: An image sample.

        Returns:
            The transformed image sample.
        """
        encoded = super().encode_sample(sample)
        encoded.image = crop_image(encoded.image, 16)
        # In-place shift recenters pixel values around zero (validation adds
        # the 0.5 back before visualization).
        encoded.image -= 0.5
        return encoded
|
|
|
|
|
|
|
|
@run.cli.factory(target=llm.train)
def train_vae() -> run.Partial:
    """
    Training factory function for VAE.

    Starts from the generic diffusion pretrain recipe and swaps in the VAE
    model, the image datamodule, and VAE-specific optimizer/logging settings.

    Returns:
        A run.Partial recipe for training.
    """
    recipe = pretrain()
    # Replace the default diffusion model with the VAE wrapper; the config.json
    # describes the autoencoder architecture to instantiate.
    recipe.model = run.Config(
        VAEModel,
        pretrained_model_name_or_path='nemo/collections/diffusion/vae/vae16x/config.json',
    )
    recipe.data = run.Config(
        DiffusionDataModule,
        task_encoder=run.Config(ImageTaskEncoder),
        global_batch_size=24,
        num_workers=10,
    )
    # Warm up the LR over 100 steps, then hold it effectively forever (1e9 steps).
    recipe.optim.lr_scheduler = run.Config(nl.lr_scheduler.WarmupHoldPolicyScheduler, warmup_steps=100, hold_steps=1e9)
    recipe.optim.config.lr = 5e-6
    recipe.optim.config.weight_decay = 1e-2
    recipe.log.log_dir = 'nemo_experiments/train_vae'
    # Validate and checkpoint every 1000 steps. NOTE(review): callbacks[0] is
    # assumed to be the checkpoint callback from pretrain() — confirm.
    recipe.trainer.val_check_interval = 1000
    recipe.trainer.callbacks[0].every_n_train_steps = 1000

    return recipe
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Entry point: launch the nemo_run CLI for llm.train, using train_vae
    # as the default recipe factory.
    run.cli.main(llm.train, default_factory=train_vae)
|
|
|