# Uploaded by Respair via huggingface_hub (commit b386992, verified).
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# pylint: skip-file
import importlib
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
import wandb
from einops import rearrange
from megatron.core import parallel_state
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.transformer_config import TransformerConfig
from torch import nn
from typing_extensions import override
from nemo.collections.diffusion.models.dit_llama.dit_llama_model import DiTLlamaModel
from nemo.collections.diffusion.sampler.edm.edm_pipeline import EDMPipeline
from nemo.collections.llm.gpt.model.base import GPTModel
from nemo.lightning import io
from nemo.lightning.megatron_parallel import MaskedTokenLossReduction, MegatronLossReduction
from nemo.lightning.pytorch.optim import OptimizerModule
from .dit.dit_model import DiTCrossAttentionModel
def dit_forward_step(model, batch) -> torch.Tensor:
    """Run the DiT forward pass.

    Every entry of ``batch`` is passed to ``model`` as a keyword argument and the
    model's output is returned unchanged.
    """
    output = model(**batch)
    return output
def dit_data_step(module, dataloader_iter):
    """Prepare the next DiT micro-batch.

    Takes the first element of the next item yielded by ``dataloader_iter``, splits it
    for context parallelism, moves every tensor entry to CUDA (non-blocking) and attaches
    a ``'packed_seq_params'`` entry.

    The packed-sequence offsets are int32 cumulative sums of ``seq_len_q`` and
    ``seq_len_kv``, each prefixed with a leading 0. Self-attention uses the query offsets
    for both q and kv. Cross-attention pairs the query offsets with the kv offsets. Both
    use ``module.qkv_format`` as their layout.

    Returns:
        A new dict holding the prepared batch.
    """
    fetched = next(dataloader_iter)
    batch = get_batch_on_this_cp_rank(fetched[0])

    on_device = {}
    for name, entry in batch.items():
        on_device[name] = entry.to(device='cuda', non_blocking=True) if torch.is_tensor(entry) else entry

    start = torch.zeros(1, dtype=torch.int32, device="cuda")
    q_offsets = torch.cat((start, on_device['seq_len_q'].cumsum(dim=0).to(torch.int32)))
    kv_offsets = torch.cat((start, on_device['seq_len_kv'].cumsum(dim=0).to(torch.int32)))

    layout = module.qkv_format
    on_device['packed_seq_params'] = {
        'self_attention': PackedSeqParams(cu_seqlens_q=q_offsets, cu_seqlens_kv=q_offsets, qkv_format=layout),
        'cross_attention': PackedSeqParams(cu_seqlens_q=q_offsets, cu_seqlens_kv=kv_offsets, qkv_format=layout),
    }
    return on_device
def get_batch_on_this_cp_rank(data: Dict, cp_size: Optional[int] = None, cp_rank: Optional[int] = None):
    """Split the data for context parallelism.

    Keeps only this context-parallel (CP) rank's slice of the sequence-like entries
    ``'video'``, ``'video_latent'``, ``'noise_latent'`` and ``'pos_ids'``:

    * entries with more than 5 dims are first squeezed on dim 0;
    * 5-D entries (unpacked as ``B, C, T, H, W``) are split evenly along dim 2;
    * any other entry is unpacked as ``B, S, D`` and split evenly along dim 1.

    If a non-``None`` ``'loss_mask'`` is present, it is split along dim 1. The number of
    valid tokens in the *unsplit* mask is stored under ``'num_valid_tokens_in_ub'``; this
    is ``None`` when there is no mask. ``data`` is modified in place and returned. When the
    CP world size is 1 or less, ``data`` is returned untouched.

    Args:
        data: Batch dictionary.
        cp_size: CP world size. Queried from ``megatron.core.mpu`` when ``None``.
        cp_rank: This rank's CP index. Queried from ``megatron.core.mpu`` when ``None``.

    Returns:
        The same ``data`` dict.
    """
    if cp_size is None or cp_rank is None:
        from megatron.core import mpu

        if cp_size is None:
            cp_size = mpu.get_context_parallel_world_size()
        if cp_rank is None:
            cp_rank = mpu.get_context_parallel_rank()

    if cp_size <= 1:
        return data

    loss_mask = data.get('loss_mask')
    # Count on the full mask, before it is reduced to this rank's slice.
    num_valid_tokens_in_ub = loss_mask.sum() if loss_mask is not None else None

    for key in ('video', 'video_latent', 'noise_latent', 'pos_ids'):
        value = data.get(key)
        if value is None:
            continue
        if len(value.shape) > 5:
            value = value.squeeze(0)
        if len(value.shape) == 5:
            B, C, T, H, W = value.shape
            data[key] = value.view(B, C, cp_size, T // cp_size, H, W)[:, :, cp_rank, ...].contiguous()
        else:
            B, S, D = value.shape
            data[key] = value.view(B, cp_size, S // cp_size, D)[:, cp_rank, ...].contiguous()

    # TODO: sequence packing
    # Only split the mask when one was provided; the count above already tolerates its absence.
    if loss_mask is not None:
        data['loss_mask'] = loss_mask.view(loss_mask.shape[0], cp_size, loss_mask.shape[1] // cp_size)[
            :, cp_rank, ...
        ].contiguous()
    data['num_valid_tokens_in_ub'] = num_valid_tokens_in_ub
    return data
@dataclass
class DiTConfig(TransformerConfig, io.IOMixin):
    """
    Config for DiT-S model

    Also the base config for the other DiT / DiT-Llama variants in this module.
    """

    crossattn_emb_size: int = 1024
    add_bias_linear: bool = False
    gated_linear_unit: bool = False

    num_layers: int = 12
    hidden_size: int = 384
    max_img_h: int = 80
    max_img_w: int = 80
    max_frames: int = 34
    patch_spatial: int = 2
    num_attention_heads: int = 6

    # These override fields inherited from TransformerConfig, so they must be annotated.
    # A bare class-level assignment is not registered as a new dataclass default, and the
    # generated __init__ would keep the parent's default for the instance value.
    layernorm_epsilon: float = 1e-6
    normalization: str = "RMSNorm"
    layernorm_zero_centered_gamma: bool = False
    # NOTE(review): left as a plain class attribute. It is presumably not a TransformerConfig
    # field, in which case instances see this value. If it is a field, annotate it like the above.
    qk_layernorm_per_head = True

    fp16_lm_cross_entropy: bool = False
    parallel_output: bool = True
    share_embeddings_and_output_weights: bool = True

    hidden_dropout: float = 0
    attention_dropout: float = 0

    bf16: bool = True
    params_dtype: torch.dtype = torch.bfloat16

    vae_module: str = 'nemo.collections.diffusion.vae.diffusers_vae.AutoencoderKLVAE'
    vae_path: Optional[str] = None
    sigma_data: float = 0.5

    in_channels: int = 16

    # Deliberately NOT annotated, so they are not dataclass fields. As plain class attributes
    # holding functions they bind like methods: ``config.data_step_fn(it)`` calls
    # ``dit_data_step(config, it)``, which is why dit_data_step reads ``module.qkv_format``.
    data_step_fn = dit_data_step
    forward_step_fn = dit_forward_step

    replicated_t_embedder = True

    seq_length: int = 2048

    qkv_format: str = 'sbhd'
    attn_mask_type: AttnMaskType = AttnMaskType.no_mask

    @override
    def configure_model(self, tokenizer=None, vp_stage: Optional[int] = None) -> DiTCrossAttentionModel:
        """Configure DiT Model from MCore.

        Args:
            tokenizer: Unused; accepted for interface compatibility.
            vp_stage: Virtual-pipeline stage index. ``None`` is treated as ``0``.

        Returns:
            A ``DiTLlamaModel`` when ``self`` is a ``DiTLlama30BConfig`` (or a subclass of it),
            otherwise a ``DiTCrossAttentionModel``.
        """
        vp_size = self.virtual_pipeline_model_parallel_size
        if vp_size:
            p_size = self.pipeline_model_parallel_size
            assert (
                self.num_layers // p_size
            ) % vp_size == 0, "Make sure the number of model chunks is the same across all pipeline stages."

        model = DiTLlamaModel if isinstance(self, DiTLlama30BConfig) else DiTCrossAttentionModel

        # During fake lightning initialization, pass 0 to bypass the assertion that vp_stage must be
        # non-None when using virtual pipeline model parallelism
        vp_stage = vp_stage or 0

        return model(
            self,
            fp16_lm_cross_entropy=self.fp16_lm_cross_entropy,
            parallel_output=self.parallel_output,
            pre_process=parallel_state.is_pipeline_first_stage(ignore_virtual=False, vp_stage=vp_stage),
            post_process=parallel_state.is_pipeline_last_stage(ignore_virtual=False, vp_stage=vp_stage),
            max_img_h=self.max_img_h,
            max_img_w=self.max_img_w,
            max_frames=self.max_frames,
            patch_spatial=self.patch_spatial,
            vp_stage=vp_stage,
        )

    def configure_vae(self):
        """Dynamically import video tokenizer.

        Imports the class named by ``vae_module`` and instantiates it with ``vae_path``.
        """
        return dynamic_import(self.vae_module)(self.vae_path)
@dataclass
class DiTBConfig(DiTConfig):
    """DiT-B: 12 transformer layers with a 768-wide hidden state."""

    # 768 / 12 = 64 channels per attention head.
    num_attention_heads: int = 12
    hidden_size: int = 768
    num_layers: int = 12
@dataclass
class DiTLConfig(DiTConfig):
    """DiT-L: 24 transformer layers with a 1024-wide hidden state."""

    # 1024 / 16 = 64 channels per attention head.
    num_attention_heads: int = 16
    hidden_size: int = 1024
    num_layers: int = 24
@dataclass
class DiTXLConfig(DiTConfig):
    """DiT-XL: 28 transformer layers with a 1152-wide hidden state."""

    # 1152 / 16 = 72 channels per attention head.
    num_attention_heads: int = 16
    hidden_size: int = 1152
    num_layers: int = 28
@dataclass
class DiT7BConfig(DiTConfig):
    """DiT-7B: 32 transformer layers with a 3072-wide hidden state."""

    # 3072 / 24 = 128 channels per attention head.
    num_attention_heads: int = 24
    hidden_size: int = 3072
    num_layers: int = 32
@dataclass
class DiTLlama30BConfig(DiTConfig):
    """MovieGen 30B.

    Llama-style DiT backbone: a gated MLP with SiLU activation, 8 query groups shared by
    48 attention heads, and RMSNorm. ``DiTConfig.configure_model`` builds a
    ``DiTLlamaModel`` for this config and its subclasses.
    """

    num_layers: int = 48
    hidden_size: int = 6144
    ffn_hidden_size: int = 16384
    num_attention_heads: int = 48
    num_query_groups: int = 8
    # Boolean flags: previously mis-annotated as ``int``.
    gated_linear_unit: bool = True
    bias_activation_fusion: bool = True
    activation_func: Callable = F.silu
    normalization: str = "RMSNorm"
    layernorm_epsilon: float = 1e-5
    max_frames: int = 128
    max_img_h: int = 240
    max_img_w: int = 240
    patch_spatial: int = 2

    init_method_std: float = 0.01
    add_bias_linear: bool = False
    seq_length: int = 256

    # Fusion toggles (bias_activation_fusion is declared once, above).
    masked_softmax_fusion: bool = True
    persist_layer_norm: bool = True
    bias_dropout_fusion: bool = True
@dataclass
class DiTLlama5BConfig(DiTLlama30BConfig):
    """MovieGen 5B: the DiT-Llama backbone shrunk to 32 layers and a 3072-wide hidden state."""

    num_layers: int = 32
    # Attention: 3072 / 24 = 128 channels per head.
    hidden_size: int = 3072
    num_attention_heads: int = 24
    # Inner width of the gated MLP.
    ffn_hidden_size: int = 8192
@dataclass
class DiTLlama1BConfig(DiTLlama30BConfig):
    """MovieGen 1B: the DiT-Llama backbone shrunk to 16 layers and a 2048-wide hidden state."""

    num_layers: int = 16
    # Attention: 2048 / 32 = 64 channels per head.
    hidden_size: int = 2048
    num_attention_heads: int = 32
    # Inner width of the gated MLP.
    ffn_hidden_size: int = 8192
@dataclass
class ECDiTLlama1BConfig(DiTLlama1BConfig):
    """EC-DiT 1B: the MovieGen 1B config with 64 MoE experts and expert-choice load balancing."""

    # Expert layout.
    num_moe_experts: int = 64
    ffn_hidden_size: int = 1024
    # Routing.
    moe_router_topk: int = 1
    moe_router_load_balancing_type: str = 'expert_choice'
    moe_expert_capacity_factor: float = 8
    moe_pad_expert_input_to_capacity: bool = True
    # Dispatch / compute kernels.
    moe_token_dispatcher_type: str = 'alltoall'
    moe_grouped_gemm: bool = True
class DiTModel(GPTModel):
    """
    Diffusion Transformer Model

    Training losses come from an ``EDMPipeline`` built around this module. During
    validation, videos are sampled with the same pipeline, decoded with a lazily built
    VAE, and logged to wandb.
    """

    def __init__(
        self,
        config: Optional[DiTConfig] = None,
        optim: Optional[OptimizerModule] = None,
        model_transform: Optional[Callable[[nn.Module], nn.Module]] = None,
        tokenizer: Optional[Any] = None,
    ):
        # NOTE(review): `tokenizer` is accepted for signature compatibility but is not
        # forwarded to GPTModel.__init__.
        super().__init__(config or DiTConfig(), optim=optim, model_transform=model_transform)
        # Only needed to decode validation samples; built in on_validation_start.
        self.vae = None
        self._training_loss_reduction = None
        self._validation_loss_reduction = None

        self.diffusion_pipeline = EDMPipeline(net=self, sigma_data=self.config.sigma_data)
        self._noise_generator = None
        self.seed = 42

    def load_state_dict(self, state_dict, strict=False):
        """Load ``state_dict`` into the wrapped module.

        NOTE(review): ``strict`` is accepted but ignored; loading always uses
        ``strict=False``. It is kept lenient because checkpoint-loading callers may rely on
        that. Confirm before honoring the flag.
        """
        self.module.load_state_dict(state_dict, strict=False)

    def data_step(self, dataloader_iter) -> Dict[str, Any]:
        """Fetch and prepare the next batch via ``config.data_step_fn``."""
        # For DiTConfig, data_step_fn is a plain (unannotated) class attribute, so the call
        # binds the config object as the function's first (``module``) argument.
        return self.config.data_step_fn(dataloader_iter)

    def forward(self, *args, **kwargs):
        """Delegate to the wrapped module's ``forward``."""
        return self.module.forward(*args, **kwargs)

    def forward_step(self, batch) -> torch.Tensor:
        """Run one EDM training step on ``batch``.

        On the last pipeline stage, the pipeline returns ``(output_batch, loss)`` and the
        loss averaged over its last dimension is returned. On other stages, the pipeline's
        output is returned unchanged.
        """
        if parallel_state.is_pipeline_last_stage(ignore_virtual=False, vp_stage=self.vp_stage):
            _, loss = self.diffusion_pipeline.training_step(batch, 0)
            return torch.mean(loss, dim=-1)
        return self.diffusion_pipeline.training_step(batch, 0)

    def training_step(self, batch, batch_idx=None) -> torch.Tensor:
        """Compute the training loss for ``batch``."""
        # In mcore the loss-function is part of the forward-pass (when labels are provided)
        return self.forward_step(batch)

    def on_validation_start(self):
        """Build the VAE on first use (when ``vae_path`` is set) and move it to CUDA."""
        if self.vae is None:
            if self.config.vae_path is None:
                warnings.warn('vae_path not specified skipping validation')
                return None
            self.vae = self.config.configure_vae()
        self.vae.to('cuda')

    def on_validation_end(self):
        """Move video tokenizer to CPU after validation."""
        if self.vae is not None:
            self.vae.to('cpu')

    def validation_step(self, batch, batch_idx=None) -> Optional[torch.Tensor]:
        """Generate a validation video sample and log it to wandb.

        Samples with the EDM pipeline (guidance 7, 35 steps, and negative prompting when
        ``'neg_t5_text_embeddings'`` is in the batch). It then un-patchifies and decodes the
        first sample, gathers the decoded videos across the data-parallel group that
        contains global rank 0, and logs them from rank 0.

        Returns:
            None. Validation is skipped when no VAE is available.
        """
        if self.vae is None:
            # on_validation_start did not build a VAE (no vae_path), so there is nothing to
            # decode with. Presumably this is the same on every rank, since it follows from
            # config.vae_path, so no rank enters the gather below alone.
            return None

        state_shape = batch['video'].shape
        sample = self.diffusion_pipeline.generate_samples_from_batch(
            batch,
            guidance=7,
            state_shape=state_shape,
            num_steps=35,
            is_negative_prompt='neg_t5_text_embeddings' in batch,
        )

        # TODO visualize more than 1 sample
        C, T, H, W = batch['latent_shape'][0]
        seq_len_q = batch['seq_len_q'][0]

        # Keep the first sample (as a batch of one), truncated to its seq_len_q tokens, and
        # fold the patch tokens back into a latent video.
        sample = rearrange(
            sample[0, None, :seq_len_q],
            'B (T H W) (ph pw pt C) -> B C (T pt) (H ph) (W pw)',
            ph=self.config.patch_spatial,
            pw=self.config.patch_spatial,
            C=C,
            T=T,
            H=H // self.config.patch_spatial,
            W=W // self.config.patch_spatial,
        )

        # Divide by sigma_data before decoding, then map decoder output in [-1, 1] to [0, 1].
        video = (1.0 + self.vae.decode(sample / self.config.sigma_data)).clamp(0, 2) / 2  # [B, 3, T, H, W]
        video = (video * 255).to(torch.uint8).cpu().numpy().astype(np.uint8)
        result = rearrange(video, 'b c t h w -> (b t) c h w')

        # wandb is on the last rank for megatron, first rank for nemo
        wandb_rank = 0

        # Only the data-parallel group whose source rank is global rank 0 takes part.
        if parallel_state.get_data_parallel_src_rank() == wandb_rank:
            if torch.distributed.get_rank() == wandb_rank:
                gather_list = [None for _ in range(parallel_state.get_data_parallel_world_size())]
            else:
                gather_list = None

            torch.distributed.gather_object(
                result, gather_list, wandb_rank, group=parallel_state.get_data_parallel_group()
            )
            if gather_list is not None:
                videos = []
                for clip in gather_list:
                    try:
                        videos.append(wandb.Video(clip, fps=24, format='mp4'))
                    except Exception as e:
                        # Best effort: fall back to wandb's default encoding.
                        warnings.warn(f'Error saving video as mp4: {e}')
                        videos.append(wandb.Video(clip, fps=24))
                wandb.log({'prediction': videos})

        return None

    @property
    def training_loss_reduction(self) -> MaskedTokenLossReduction:
        """Lazily created loss reduction used for training."""
        if self._training_loss_reduction is None:
            self._training_loss_reduction = MaskedTokenLossReduction()
        return self._training_loss_reduction

    @property
    def validation_loss_reduction(self) -> MegatronLossReduction:
        """Lazily created ``DummyLossReduction`` (always reports zero) used for validation."""
        if self._validation_loss_reduction is None:
            self._validation_loss_reduction = DummyLossReduction()
        return self._validation_loss_reduction

    def on_validation_model_zero_grad(self) -> None:
        '''
        Small hack to avoid first validation on resume.
        This will NOT work if the gradient accumulation step should be performed at this point.
        https://github.com/Lightning-AI/pytorch-lightning/discussions/18110
        '''
        super().on_validation_model_zero_grad()
        if self.trainer.ckpt_path is not None and getattr(self, '_restarting_skip_val_flag', True):
            self.trainer.sanity_checking = True
            self._restarting_skip_val_flag = False
class DummyLossReduction(MegatronLossReduction):
    """
    Placeholder loss reduction that always reports a zero loss.

    Every value it returns is a freshly created scalar ``0.0`` tensor on the current CUDA
    device. The constructor flags are stored but not otherwise used by this class.
    """

    def __init__(self, validation_step: bool = False, val_drop_last: bool = True) -> None:
        super().__init__()
        self.validation_step = validation_step
        self.val_drop_last = val_drop_last

    @staticmethod
    def _zero() -> torch.Tensor:
        """Return a new scalar zero tensor on the current CUDA device."""
        return torch.tensor(0.0, device=torch.cuda.current_device())

    def forward(
        self, batch: Dict[str, torch.Tensor], forward_out: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Ignore the inputs and return ``(0.0, {"avg": 0.0})``."""
        loss = self._zero()
        return loss, {"avg": self._zero()}

    def reduce(self, losses_reduced_per_micro_batch) -> torch.Tensor:
        """Ignore the per-micro-batch results and return ``0.0``."""
        return self._zero()
def dynamic_import(full_path):
    """
    Resolve a dotted ``"package.module.Name"`` string to the object it names.

    :param full_path: Dotted path whose last component is an attribute of the module named
        by everything before it (e.g. "package.module.ClassName").
    :return: The attribute looked up on the imported module.
    :raises ImportError: If ``full_path`` has no dot, or the module part cannot be imported.
    :raises AttributeError: If the imported module has no such attribute.
    """
    pieces = full_path.rsplit('.', 1)
    try:
        module_path, attribute_name = pieces
    except ValueError as err:
        raise ImportError(
            f"Invalid full path '{full_path}'. It should contain both module and attribute names."
        ) from err

    try:
        module = importlib.import_module(module_path)
    except ImportError as err:
        raise ImportError(f"Cannot import module '{module_path}'.") from err

    try:
        return getattr(module, attribute_name)
    except AttributeError as err:
        raise AttributeError(f"Module '{module_path}' does not have an attribute '{attribute_name}'.") from err