Respair's picture
Upload folder using huggingface_hub
b386992 verified
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, Optional, Sequence, Tuple
import nemo_run as run
import torch
import torch.distributed
import torch.utils.checkpoint
import torchvision
import wandb
from autovae import VAEGenerator
from contperceptual_loss import LPIPSWithDiscriminator
from diffusers import AutoencoderKL
from megatron.core import parallel_state
from megatron.core.transformer.enums import ModelType
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.energon import DefaultTaskEncoder, ImageSample
from torch import Tensor, nn
from nemo import lightning as nl
from nemo.collections import llm
from nemo.collections.diffusion.data.diffusion_energon_datamodule import DiffusionDataModule
from nemo.collections.diffusion.train import pretrain
from nemo.collections.llm.gpt.model.base import GPTModel
from nemo.lightning.io.mixin import IOMixin
from nemo.lightning.megatron_parallel import DataT, MegatronLossReduction, ReductionT
from nemo.lightning.pytorch.optim import OptimizerModule
class AvgLossReduction(MegatronLossReduction):
    """Performs average loss reduction: mean within each micro-batch, then mean across micro-batches."""

    def forward(self, batch: DataT, forward_out: Tensor) -> Tuple[Tensor, ReductionT]:
        """
        Reduce one micro-batch's forward output to a scalar loss.

        Args:
            batch: The micro-batch that produced ``forward_out`` (not used here).
            forward_out: The output tensor from forward computation.

        Returns:
            A tuple ``(loss, {"avg": loss})`` where ``loss`` is ``forward_out.mean()``.
            The dictionary is what ``reduce`` consumes later.
        """
        loss = forward_out.mean()
        return loss, {"avg": loss}

    def reduce(self, losses_reduced_per_micro_batch: Sequence[ReductionT]) -> Tensor:
        """
        Reduce losses across micro-batches by averaging their ``"avg"`` entries.

        Args:
            losses_reduced_per_micro_batch: A sequence of the dictionaries returned
                by ``forward``.

        Returns:
            The averaged loss tensor, or a scalar zero tensor when the sequence is
            empty.
        """
        if not losses_reduced_per_micro_batch:
            # torch.stack([]) raises. Presumably a rank can receive no micro-batch
            # losses (e.g. a non-last pipeline stage); NeMo's built-in reductions
            # guard the same case by returning zero.
            device = torch.cuda.current_device() if torch.cuda.is_available() else None
            return torch.tensor(0.0, device=device)
        losses = torch.stack([loss["avg"] for loss in losses_reduced_per_micro_batch])
        return losses.mean()
class VAE(MegatronModule):
    """Variational Autoencoder (VAE) module paired with an LPIPS + discriminator loss."""

    def __init__(self, config, pretrained_model_name_or_path, search_vae=False):
        """
        Initialize the VAE model and its loss.

        Args:
            config: Transformer configuration. It is only forwarded to
                ``MegatronModule``.
            pretrained_model_name_or_path: Config path or name passed to
                ``AutoencoderKL.from_config``. Used only when ``search_vae`` is False.
            search_vae: If True, let AutoVAE search for a target VAE architecture
                instead of building one from ``pretrained_model_name_or_path``.
        """
        super().__init__(config)
        if search_vae:
            # The generator is a search driver, not a model. The searched model is
            # what gets trained.
            # NOTE(review): forward() assumes the searched model exposes the
            # AutoencoderKL interface (encode(...).latent_dist, decode(...).sample,
            # decoder.conv_out) — confirm against autovae.
            generator = VAEGenerator(input_resolution=1024, compression_ratio=16)
            self.vae = generator.search_for_target_vae(parameters_budget=895.178707, cuda_max_mem=0)
        else:
            # NOTE(review): diffusers documents `torch_dtype`, not `weight_dtype`.
            # This kwarg may be ignored, leaving the weights in the default dtype — confirm.
            self.vae = AutoencoderKL.from_config(pretrained_model_name_or_path, weight_dtype=torch.bfloat16)
            # Warm-start from the SDXL VAE wherever parameter names and shapes line up.
            sdxl_vae = AutoencoderKL.from_pretrained(
                'stabilityai/stable-diffusion-xl-base-1.0', subfolder="vae", weight_dtype=torch.bfloat16
            )
            sd_dict = sdxl_vae.state_dict()
            vae_dict = self.vae.state_dict()
            # Compare full shapes, not just numel. load_state_dict(strict=False)
            # tolerates missing/unexpected keys but still raises on size mismatches.
            pre_dict = {k: v for k, v in sd_dict.items() if k in vae_dict and vae_dict[k].shape == v.shape}
            self.vae.load_state_dict(pre_dict, strict=False)
            del sdxl_vae
        self.vae_loss = LPIPSWithDiscriminator(
            disc_start=50001,
            logvar_init=0.0,
            kl_weight=0.000001,
            pixelloss_weight=1.0,
            disc_num_layers=3,
            disc_in_channels=3,
            disc_factor=1.0,
            disc_weight=0.5,
            perceptual_weight=1.0,
            use_actnorm=False,
            disc_conditional=False,
            disc_loss="hinge",
        )

    def forward(self, target, global_step):
        """
        Encode ``target``, sample a latent, decode it, and score the reconstruction.

        Args:
            target: Target images passed to ``self.vae.encode``.
            global_step: Current global step, forwarded to the loss
                (``disc_start`` is set to 50001 in ``__init__``).

        Returns:
            A tuple ``(aeloss, log_dict_ae, pred)``: the loss computed with
            ``optimizer_idx=0``, its log dictionary, and the decoded images.
        """
        posterior = self.vae.encode(target).latent_dist
        # Draw a sample from the posterior rather than using its mode.
        z = posterior.sample()
        pred = self.vae.decode(z).sample
        # NOTE(review): only optimizer_idx=0 is ever requested in this file. If the
        # discriminator is meant to train after disc_start, a discriminator step
        # (optimizer_idx=1) appears to be missing — confirm.
        aeloss, log_dict_ae = self.vae_loss(
            inputs=target,
            reconstructions=pred,
            posteriors=posterior,
            optimizer_idx=0,
            global_step=global_step,
            last_layer=self.vae.decoder.conv_out.weight,
        )
        return aeloss, log_dict_ae, pred

    def set_input_tensor(self, input_tensor: Tensor) -> None:
        """
        Accept an input tensor from the pipeline schedule (no-op).

        This module ignores ``input_tensor``.

        Args:
            input_tensor: The input tensor to the model.
        """
        pass
class VAEModel(GPTModel):
    """A GPTModel wrapper that trains ``VAE`` through the GPTModel training plumbing."""

    def __init__(
        self,
        pretrained_model_name_or_path: str,
        optim: Optional[OptimizerModule] = None,
        model_transform: Optional[Callable[[nn.Module], nn.Module]] = None,
    ):
        """
        Initialize the VAEModel.

        Args:
            pretrained_model_name_or_path: Path or name handed to ``VAE`` in
                ``configure_model``.
            optim: Optional optimizer module.
            model_transform: Optional function to transform the model.
        """
        self.pretrained_model_name_or_path = pretrained_model_name_or_path
        # Minimal placeholder config: GPTModel requires a TransformerConfig, and
        # VAE only forwards it to MegatronModule.
        config = TransformerConfig(num_layers=1, hidden_size=1, num_attention_heads=1)
        self.model_type = ModelType.encoder_or_decoder
        super().__init__(config, optim=optim, model_transform=model_transform)

    def configure_model(self) -> None:
        """Create the underlying ``VAE`` module once (idempotent)."""
        if not hasattr(self, "module"):
            self.module = VAE(self.config, self.pretrained_model_name_or_path)

    def data_step(self, dataloader_iter) -> Dict[str, Any]:
        """
        Fetch the next batch and move its images to CUDA as bfloat16.

        Args:
            dataloader_iter: The dataloader iterator. Each item is indexed with
                ``[0]`` to get a sample exposing ``.image``.

        Returns:
            A dictionary with ``'pixel_values'`` ready for the model.
        """
        batch = next(dataloader_iter)[0]
        return {'pixel_values': batch.image.to(device='cuda', dtype=torch.bfloat16, non_blocking=True)}

    def forward(self, *args, **kwargs):
        """
        Forward pass through the underlying module.

        Args:
            *args: Positional arguments for ``self.module``.
            **kwargs: Keyword arguments for ``self.module``.

        Returns:
            The result of the forward pass of ``self.module``.
        """
        return self.module(*args, **kwargs)

    def training_step(self, batch, batch_idx=None) -> torch.Tensor:
        """
        Perform a single training step.

        The loss dictionary is logged only on global rank 0.

        Args:
            batch: The input batch (from ``data_step``).
            batch_idx: Batch index (unused).

        Returns:
            The loss tensor.
        """
        loss, log_dict_ae, pred = self(batch["pixel_values"], self.global_step)
        if torch.distributed.get_rank() == 0:
            self.log_dict(log_dict_ae)
        return loss

    def validation_step(self, batch, batch_idx=None) -> torch.Tensor:
        """
        Perform a single validation step and log original/reconstruction pairs to wandb.

        Images from every rank in the data-parallel group that contains rank 0 are
        gathered onto rank 0 and logged there.

        Args:
            batch: The input batch (from ``data_step``).
            batch_idx: Batch index (unused).

        Returns:
            The loss tensor.
        """
        loss, log_dict_ae, pred = self(batch["pixel_values"], self.global_step)

        # Interleave as (orig0, recon0, orig1, recon1, ...). A 2-column grid then
        # shows each original to the left of its reconstruction, matching the
        # wandb caption for any micro-batch size.
        image = torch.stack([batch["pixel_values"].cpu(), pred.cpu()], dim=1).flatten(0, 1)
        # Undo the -0.5 shift applied by ImageTaskEncoder, then clamp for display.
        image = (image + 0.5).clamp(0, 1)

        # wandb is on the last rank for megatron, first rank for nemo
        wandb_rank = 0

        if parallel_state.get_data_parallel_src_rank() == wandb_rank:
            # Only the destination rank provides a receive buffer for gather_object.
            if torch.distributed.get_rank() == wandb_rank:
                gather_list = [None] * parallel_state.get_data_parallel_world_size()
            else:
                gather_list = None
            torch.distributed.gather_object(
                image, gather_list, wandb_rank, group=parallel_state.get_data_parallel_group()
            )
            if gather_list is not None:
                self.log_dict(log_dict_ae)
                wandb.log(
                    {
                        "Original (left), Reconstruction (right)": [
                            wandb.Image(torchvision.utils.make_grid(gathered, nrow=2)) for gathered in gather_list
                        ]
                    },
                )
        return loss

    @property
    def training_loss_reduction(self) -> AvgLossReduction:
        """Return the (lazily created) loss reduction used for training."""
        if self._training_loss_reduction is None:
            self._training_loss_reduction = AvgLossReduction()
        return self._training_loss_reduction

    @property
    def validation_loss_reduction(self) -> AvgLossReduction:
        """Return the (lazily created) loss reduction used for validation."""
        if self._validation_loss_reduction is None:
            self._validation_loss_reduction = AvgLossReduction()
        return self._validation_loss_reduction

    def on_validation_model_zero_grad(self) -> None:
        """
        Hook to handle zero grad on the validation model step.

        Used here to skip the first validation on resume. When the trainer was
        started from a checkpoint, the first validation run is flagged as sanity
        checking exactly once.
        """
        super().on_validation_model_zero_grad()
        if self.trainer.ckpt_path is not None and getattr(self, '_restarting_skip_val_flag', True):
            self.trainer.sanity_checking = True
            self._restarting_skip_val_flag = False
def crop_image(img, divisor=16):
    """
    Center-crop ``img`` so its last two dimensions are multiples of ``divisor``.

    The remainder of each dimension is split between its two edges. When the
    remainder is odd, the extra row/column is removed from the bottom/right
    edge. Leading dimensions are left untouched.

    Args:
        img: Object with a ``shape`` attribute that supports ``[..., rows, cols]``
            slicing (e.g. a torch tensor). The last two dims are cropped.
        divisor: Value both cropped sizes must be divisible by.

    Returns:
        The result of basic slicing ``img[..., top:bottom, left:right]``.
    """
    height, width = img.shape[-2], img.shape[-1]
    excess_h = height % divisor
    excess_w = width % divisor
    top = excess_h // 2
    left = excess_w // 2
    # Any odd leftover pixel is taken off the far (bottom / right) edge.
    bottom = height - (excess_h - top)
    right = width - (excess_w - left)
    return img[..., top:bottom, left:right]
class ImageTaskEncoder(DefaultTaskEncoder, IOMixin):
    """Task encoder that center-crops images to a multiple of 16 and shifts them by -0.5."""

    def encode_sample(self, sample: ImageSample) -> ImageSample:
        """
        Run the default encoding, then crop and shift the sample's image.

        Args:
            sample: An image sample.

        Returns:
            The encoded sample with ``image`` cropped so both spatial sizes are
            divisible by 16 and shifted by -0.5.
        """
        encoded = super().encode_sample(sample)
        cropped = crop_image(encoded.image, 16)
        # In-place shift on the cropped slice (same semantics as `image -= 0.5`).
        cropped -= 0.5
        encoded.image = cropped
        return encoded
@run.cli.factory(target=llm.train)
def train_vae() -> run.Partial:
    """
    Build the default training recipe for the VAE.

    Starts from the diffusion ``pretrain()`` recipe and overrides the following:
    - the model (``VAEModel`` built from the vae16x config);
    - the datamodule (``DiffusionDataModule`` with ``ImageTaskEncoder``);
    - a warmup/hold LR scheduler, learning rate, and weight decay;
    - the log directory and the 1000-step validation interval;
    - ``every_n_train_steps`` on the first trainer callback.

    Returns:
        A run.Partial recipe for training.
    """
    recipe = pretrain()

    model_cfg = run.Config(
        VAEModel,
        pretrained_model_name_or_path='nemo/collections/diffusion/vae/vae16x/config.json',
    )
    recipe.model = model_cfg

    encoder_cfg = run.Config(ImageTaskEncoder)
    data_cfg = run.Config(
        DiffusionDataModule,
        task_encoder=encoder_cfg,
        global_batch_size=24,
        num_workers=10,
    )
    recipe.data = data_cfg

    optim = recipe.optim
    optim.lr_scheduler = run.Config(nl.lr_scheduler.WarmupHoldPolicyScheduler, warmup_steps=100, hold_steps=1e9)
    optim.config.lr = 5e-6
    optim.config.weight_decay = 1e-2

    recipe.log.log_dir = 'nemo_experiments/train_vae'

    trainer = recipe.trainer
    trainer.val_check_interval = 1000
    trainer.callbacks[0].every_n_train_steps = 1000

    return recipe
if __name__ == "__main__":
    # Entry point: hand llm.train to the NeMo-Run CLI with train_vae as the
    # default_factory (presumably overridable from the command line — confirm).
    run.cli.main(llm.train, default_factory=train_vae)