# NeMo_Canary/nemo/collections/diffusion/data/diffusion_fake_datamodule.py
# (provenance: uploaded via huggingface_hub, commit b386992 verified)
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import lightning.pytorch as pl
import torch
from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
from torch.utils.data import DataLoader
from nemo.collections.diffusion.models.model import DiTConfig
from nemo.lightning.pytorch.plugins import MegatronDataSampler
from .diffusion_taskencoder import pos_id_3d
class PosEmb3D:
    """Generates and provides 3D positional embeddings for video data.

    Maintains a cached integer position grid of shape
    (max_t, max_h, max_w, 3) and lazily regrows it whenever a larger
    extent is requested.
    """

    def __init__(self, *, max_t=96, max_h=960, max_w=960):
        self.max_t = max_t
        self.max_h = max_h
        self.max_w = max_w
        self.generate_pos_id()

    def generate_pos_id(self):
        """Generates the positional ID grid based on max_t, max_h, and max_w."""
        # indexing='ij' makes the (t, h, w) axis order explicit; it matches
        # the legacy torch.meshgrid default this code relied on and silences
        # the deprecation warning about an unspecified indexing argument.
        self.grid = torch.stack(
            torch.meshgrid(
                torch.arange(self.max_t, device='cpu'),
                torch.arange(self.max_h, device='cpu'),
                torch.arange(self.max_w, device='cpu'),
                indexing='ij',
            ),
            dim=-1,
        )

    def get_pos_id_3d(self, *, t, h, w):
        """Retrieves a subset of the positional IDs for the specified dimensions.

        Parameters:
            t (int): Number of time frames.
            h (int): Height dimension.
            w (int): Width dimension.

        Returns:
            torch.Tensor: The positional IDs tensor with shape (t, h, w, 3).
        """
        # Grow the cached grid (never shrink) if any requested extent exceeds it.
        if t > self.max_t or h > self.max_h or w > self.max_w:
            self.max_t = max(self.max_t, t)
            self.max_h = max(self.max_h, h)
            self.max_w = max(self.max_w, w)
            self.generate_pos_id()
        return self.grid[:t, :h, :w]
class DiTVideoLatentFakeDataset(torch.utils.data.Dataset):
    """A fake dataset for generating synthetic video latent data.

    Every sample carries the same synthetic payload: a constant latent
    "video" tensor, random text embeddings, and an all-ones loss mask,
    shaped to mimic real pre-encoded video data.
    """

    def __init__(
        self,
        n_frames,
        max_h,
        max_w,
        patch_size,
        in_channels,
        crossattn_emb_size,
        max_text_seqlen=512,
        seq_length=8192,
    ):
        self.max_t = n_frames
        self.max_height = max_h
        self.max_width = max_w
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.text_dim = crossattn_emb_size
        self.text_seqlen = max_text_seqlen
        self.seq_length = seq_length

    def __len__(self):
        """Returns the total number of samples (effectively unbounded)."""
        return 100000000

    def __getitem__(self, idx):
        """Generates a single sample of data.

        Parameters:
            idx (int): Index of the data sample (unused; every sample is synthetic).

        Returns:
            dict: A dictionary containing video latent data and related information.
        """
        p = self.patch_size
        c = self.in_channels
        # Constant latent of patch tokens: (seq_length, channels * patch_size**2).
        video_latent = torch.full((self.seq_length, c * p**2), 0.5, dtype=torch.bfloat16)
        text_embedding = torch.randn(self.text_seqlen, self.text_dim, dtype=torch.bfloat16)
        # NOTE: a 3D position grid used to be computed here and then discarded;
        # that dead (and per-sample expensive) call was removed. 'pos_ids' is
        # intentionally all zeros for this fake dataset.
        return {
            'video': video_latent,
            't5_text_embeddings': text_embedding,
            'seq_len_q': torch.tensor([video_latent.shape[0]], dtype=torch.int32).squeeze(),
            'seq_len_kv': torch.tensor([self.text_seqlen], dtype=torch.int32).squeeze(),
            'pos_ids': torch.zeros((self.seq_length, 3), dtype=torch.int32),
            'loss_mask': torch.ones(video_latent.shape[0], dtype=torch.bfloat16),
        }

    def _collate_fn(self, batch):
        """A default implementation of a collation function.

        Users should override this method to define custom data loaders.
        """
        return torch.utils.data.dataloader.default_collate(batch)

    def collate_fn(self, batch):
        """Method that user passes as a functor to DataLoader.

        The method optionally performs neural type checking and adds types to the outputs.
        Please note, subclasses of Dataset should not implement `input_types`.

        Usage:
            dataloader = torch.utils.data.DataLoader(
                ....,
                collate_fn=dataset.collate_fn,
                ....
            )

        Returns:
            Collated batch, with or without types.
        """
        return self._collate_fn(batch)
class VideoLatentFakeDataModule(pl.LightningDataModule):
    """A LightningDataModule for generating fake video latent data for training.

    Builds a single `DiTVideoLatentFakeDataset` (shared by the train and
    validation loaders) sized from the provided model config.
    """

    def __init__(
        self,
        model_config: DiTConfig,
        seq_length: int = 2048,
        micro_batch_size: int = 1,
        global_batch_size: int = 8,
        num_workers: int = 1,
        pin_memory: bool = True,
        task_encoder=None,
        use_train_split_for_val: bool = False,
    ) -> None:
        super().__init__()
        self.seq_length = seq_length
        self.micro_batch_size = micro_batch_size
        self.global_batch_size = global_batch_size
        self.num_workers = num_workers
        # BUG FIX: pin_memory was accepted but never stored, and the loader
        # hard-coded pin_memory=True — passing pin_memory=False was silently
        # ignored. Store it so _create_dataloader honors the caller's choice.
        self.pin_memory = pin_memory
        # Kept for interface compatibility; not consumed by this fake module.
        self.task_encoder = task_encoder
        self.use_train_split_for_val = use_train_split_for_val
        self.model_config = model_config
        self.data_sampler = MegatronDataSampler(
            seq_len=self.seq_length,
            micro_batch_size=micro_batch_size,
            global_batch_size=global_batch_size,
        )

    def setup(self, stage: str = "") -> None:
        """Sets up the dataset for training and validation.

        Parameters:
            stage (str): Optional stage argument (unused).
        """
        self._train_ds = DiTVideoLatentFakeDataset(
            n_frames=self.model_config.max_frames,
            max_h=self.model_config.max_img_h,
            max_w=self.model_config.max_img_w,
            patch_size=self.model_config.patch_spatial,
            in_channels=self.model_config.in_channels,
            crossattn_emb_size=self.model_config.crossattn_emb_size,
        )

    def train_dataloader(self) -> TRAIN_DATALOADERS:
        """Returns the training DataLoader."""
        if not hasattr(self, "_train_ds"):
            self.setup()
        return self._create_dataloader(self._train_ds)

    def val_dataloader(self) -> EVAL_DATALOADERS:
        """Returns the validation DataLoader (same fake dataset as training)."""
        if not hasattr(self, "_train_ds"):
            self.setup()
        return self._create_dataloader(self._train_ds)

    def _create_dataloader(self, dataset, **kwargs) -> DataLoader:
        """Creates a DataLoader for the given dataset.

        Parameters:
            dataset (Dataset): The dataset to load.
            **kwargs: Additional arguments for DataLoader.

        Returns:
            DataLoader: The DataLoader instance.
        """
        return DataLoader(
            dataset,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            # persistent_workers=True is invalid when num_workers == 0;
            # only enable it when worker processes actually exist.
            persistent_workers=self.num_workers > 0,
            collate_fn=dataset.collate_fn,
            **kwargs,
        )