# NeMo_Canary/nemo/collections/diffusion/data/diffusion_taskencoder.py
# Uploaded by Respair using huggingface_hub (revision b386992, verified).
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from dataclasses import dataclass
from typing import Any, List, Optional
import torch
import torch.nn.functional as F
from einops import rearrange
from megatron.energon import DefaultTaskEncoder, Sample, SkipSample
from megatron.energon.task_encoder.base import stateless
from megatron.energon.task_encoder.cooking import Cooker, basic_sample_keys
from nemo.lightning.io.mixin import IOMixin
from nemo.utils.sequence_packing_utils import first_fit_decreasing
@dataclass
class DiffusionSample(Sample):
    """
    One training sample for diffusion models, ordered and summed by its query sequence length.

    Attributes:
        video (torch.Tensor): Video latents (C T H W).
        t5_text_embeddings (torch.Tensor): Text embeddings (S D).
        t5_text_mask (torch.Tensor): Mask for text embeddings.
        loss_mask (torch.Tensor): Mask indicating valid positions for loss computation.
        image_size (Optional[torch.Tensor]): Tensor containing image dimensions.
        fps (Optional[torch.Tensor]): Frame rate of the video.
        num_frames (Optional[torch.Tensor]): Number of frames in the video.
        padding_mask (Optional[torch.Tensor]): Mask indicating padding positions.
        seq_len_q (Optional[torch.Tensor]): Sequence length for query embeddings.
        seq_len_kv (Optional[torch.Tensor]): Sequence length for key/value embeddings.
        pos_ids (Optional[torch.Tensor]): Positional IDs.
        latent_shape (Optional[torch.Tensor]): Shape of the latent tensor.
    """

    # Required tensors.
    video: torch.Tensor  # video latents (C T H W)
    t5_text_embeddings: torch.Tensor  # (S D)
    t5_text_mask: torch.Tensor  # 1
    loss_mask: torch.Tensor
    # Optional metadata; left as None when not provided.
    image_size: Optional[torch.Tensor] = None
    fps: Optional[torch.Tensor] = None
    num_frames: Optional[torch.Tensor] = None
    padding_mask: Optional[torch.Tensor] = None
    seq_len_q: Optional[torch.Tensor] = None
    seq_len_kv: Optional[torch.Tensor] = None
    pos_ids: Optional[torch.Tensor] = None
    latent_shape: Optional[torch.Tensor] = None

    def to_dict(self) -> dict:
        """Return the diffusion-specific fields as a plain dictionary."""
        exported = (
            'video',
            't5_text_embeddings',
            't5_text_mask',
            'loss_mask',
            'image_size',
            'fps',
            'num_frames',
            'padding_mask',
            'seq_len_q',
            'seq_len_kv',
            'pos_ids',
            'latent_shape',
        )
        return {field_name: getattr(self, field_name) for field_name in exported}

    def __add__(self, other: Any) -> int:
        """Return this sample's query length plus another sample's query length or an integer."""
        if isinstance(other, DiffusionSample):
            own_length = self.seq_len_q.item()
            return own_length + other.seq_len_q.item()
        if isinstance(other, int):
            return self.seq_len_q.item() + other
        raise NotImplementedError

    def __radd__(self, other: Any) -> int:
        """Support `int + sample`, e.g. the implicit `0 + sample` that `sum()` starts with."""
        if not isinstance(other, int):
            raise NotImplementedError
        return self.seq_len_q.item() + other

    def __lt__(self, other: Any) -> bool:
        """Order samples by query length; an integer is compared against the length directly."""
        if isinstance(other, DiffusionSample):
            own_length = self.seq_len_q.item()
            return own_length < other.seq_len_q.item()
        if isinstance(other, int):
            return self.seq_len_q.item() < other
        raise NotImplementedError
def cook(sample: dict) -> dict:
    """
    Convert a raw energon sample into the dictionary consumed by the diffusion task encoder.

    Args:
        sample (dict): Raw sample; must contain the '.json', '.pth' and '.pickle' entries.

    Returns:
        dict: Everything returned by `basic_sample_keys(sample)` plus:
            - 'json': metadata such as resolution, aspect ratio, fps, etc.
            - 'pth': the video latent tensor
            - 'pickle': the text embeddings
    """
    base_fields = basic_sample_keys(sample)
    payload = {
        'json': sample['.json'],
        'pth': sample['.pth'],
        'pickle': sample['.pickle'],
    }
    # Keyword expansion (rather than dict.update) keeps the TypeError on duplicate keys.
    return dict(**base_fields, **payload)
class BasicDiffusionTaskEncoder(DefaultTaskEncoder, IOMixin):
    """
    Task encoder that turns cooked image/video latent samples into `DiffusionSample`s.

    Attributes:
        cookers (list): Cooker objects applied to raw samples (here: `cook`).
        max_frames (int, optional): If set, latents are truncated to at most this many frames.
        text_embedding_padding_size (int): Text embeddings are truncated / zero-padded to this length.
        seq_length (int, optional): Fixed token length. When set and `max_seq_length` is None,
            every sample is zero-padded to this length; longer samples are skipped.
        max_seq_length (int, optional): Token budget of a packed sequence. When set, sequence
            packing is used and samples longer than this are skipped.
        patch_spatial (int): Spatial patch size used to turn the latent into tokens.
        patch_temporal (int): Temporal patch size used to turn the latent into tokens.
        aesthetic_score (float): Samples whose metadata 'aesthetic_score' is below this are skipped.

    `encode_sample` raises `SkipSample` when the latent contains NaN/Inf, when its absolute
    maximum exceeds 1e3, when the aesthetic score is too low, or when the token count exceeds
    `seq_length` / `max_seq_length`.
    """

    cookers = [
        Cooker(cook),
    ]

    def __init__(
        self,
        *args,
        max_frames: Optional[int] = None,
        text_embedding_padding_size: int = 512,
        seq_length: Optional[int] = None,
        max_seq_length: Optional[int] = None,
        patch_spatial: int = 2,
        patch_temporal: int = 1,
        aesthetic_score: float = 0.0,
        **kwargs,
    ):
        """Store the encoding options; remaining arguments are forwarded to the base classes."""
        super().__init__(*args, **kwargs)
        self.max_frames = max_frames
        self.text_embedding_padding_size = text_embedding_padding_size
        self.seq_length = seq_length
        self.max_seq_length = max_seq_length
        self.patch_spatial = patch_spatial
        self.patch_temporal = patch_temporal
        self.aesthetic_score = aesthetic_score

    @stateless(restore_seeds=True)
    def encode_sample(self, sample: dict) -> DiffusionSample:
        """
        Encode one cooked video / text sample.

        Args:
            sample (dict): Cooked sample with 'pth' (latent, unpacked as C T H W), 'json'
                (metadata) and 'pickle' (text embeddings as a numpy array; for videos the
                first element is used), plus the basic energon sample keys.

        Returns:
            DiffusionSample: Patchified latent tokens with text embeddings, masks,
            positional IDs and metadata tensors.

        Raises:
            SkipSample: If the latent has NaN/Inf values, its absolute maximum exceeds 1e3,
                the aesthetic score is below the threshold, or the token count exceeds
                `seq_length` / `max_seq_length`.
        """
        video_latent = sample['pth']

        if torch.isnan(video_latent).any() or torch.isinf(video_latent).any():
            raise SkipSample()
        if torch.max(torch.abs(video_latent)) > 1e3:
            raise SkipSample()

        info = sample['json']
        if info['aesthetic_score'] < self.aesthetic_score:
            raise SkipSample()

        # Image vs. video is decided on the stored latent, before any frame truncation,
        # because it selects the layout of the 'pickle' and 'json' entries below.
        is_image = video_latent.shape[1] == 1

        # Truncate first so that every derived quantity (token count, positional IDs,
        # loss mask, latent shape) describes the tensor that is actually emitted.
        if self.max_frames is not None:
            video_latent = video_latent[:, : self.max_frames, :, :]

        C, T, H, W = video_latent.shape
        seq_len = T * H * W // self.patch_spatial**2 // self.patch_temporal

        if self.seq_length is not None and seq_len > self.seq_length:
            raise SkipSample()
        if self.max_seq_length is not None and seq_len > self.max_seq_length:
            raise SkipSample()

        # Patchify: one token per (pt, ph, pw) patch, channels flattened into the feature dim.
        video_latent = rearrange(
            video_latent,
            'C (T pt) (H ph) (W pw) -> (T H W) (ph pw pt C)',
            ph=self.patch_spatial,
            pw=self.patch_spatial,
            pt=self.patch_temporal,
        )

        raw_text_embeddings = sample['pickle'] if is_image else sample['pickle'][0]
        t5_text_embeddings = torch.from_numpy(raw_text_embeddings).to(torch.bfloat16)
        t5_text_embeddings_seq_length = t5_text_embeddings.shape[0]

        if t5_text_embeddings_seq_length > self.text_embedding_padding_size:
            t5_text_embeddings = t5_text_embeddings[: self.text_embedding_padding_size]
        else:
            # Zero-pad along the sequence dimension only.
            t5_text_embeddings = F.pad(
                t5_text_embeddings,
                (
                    0,
                    0,
                    0,
                    self.text_embedding_padding_size - t5_text_embeddings_seq_length,
                ),
            )
        # The mask covers the valid text tokens; it must never be longer than the
        # (possibly truncated) embeddings.
        t5_text_mask = torch.ones(
            min(t5_text_embeddings_seq_length, self.text_embedding_padding_size),
            dtype=torch.bfloat16,
        )

        if is_image:
            h, w = info['image_height'], info['image_width']
            fps = torch.tensor([30] * 1, dtype=torch.bfloat16)
            num_frames = torch.tensor([1] * 1, dtype=torch.bfloat16)
        else:
            h, w = info['height'], info['width']
            fps = torch.tensor([info['framerate']] * 1, dtype=torch.bfloat16)
            num_frames = torch.tensor([info['num_frames']] * 1, dtype=torch.bfloat16)
        image_size = torch.tensor([[h, w, h, w]] * 1, dtype=torch.bfloat16)

        pos_ids = rearrange(
            pos_id_3d.get_pos_id_3d(t=T // self.patch_temporal, h=H // self.patch_spatial, w=W // self.patch_spatial),
            'T H W d -> (T H W) d',
        )

        if self.seq_length is not None and self.max_seq_length is None:
            # Fixed-length mode (no packing): pad tokens and positional IDs up to seq_length
            # and mark only the real tokens in the loss mask.
            pos_ids = F.pad(pos_ids, (0, 0, 0, self.seq_length - seq_len))
            loss_mask = torch.zeros(self.seq_length, dtype=torch.bfloat16)
            loss_mask[:seq_len] = 1
            video_latent = F.pad(video_latent, (0, 0, 0, self.seq_length - seq_len))
        else:
            loss_mask = torch.ones(seq_len, dtype=torch.bfloat16)

        return DiffusionSample(
            __key__=sample['__key__'],
            __restore_key__=sample['__restore_key__'],
            __subflavor__=None,
            __subflavors__=sample['__subflavors__'],
            video=video_latent,
            t5_text_embeddings=t5_text_embeddings,
            t5_text_mask=t5_text_mask,
            image_size=image_size,
            fps=fps,
            num_frames=num_frames,
            loss_mask=loss_mask,
            seq_len_q=torch.tensor(seq_len, dtype=torch.int32),
            seq_len_kv=torch.tensor(self.text_embedding_padding_size, dtype=torch.int32),
            pos_ids=pos_ids,
            latent_shape=torch.tensor([C, T, H, W], dtype=torch.int32),
        )

    def select_samples_to_pack(self, samples: List[DiffusionSample]) -> List[List[DiffusionSample]]:
        """
        Group samples into packs for mixed image-video training.

        The grouping is delegated to `first_fit_decreasing` with `max_seq_length` as the
        capacity; the resulting groups are then shuffled in place.
        """
        results = first_fit_decreasing(samples, self.max_seq_length)
        random.shuffle(results)
        return results

    @stateless
    def pack_selected_samples(self, samples: List[DiffusionSample]) -> DiffusionSample:
        """
        Build one packed `DiffusionSample` from a group of samples.

        Token-level tensors (video, loss mask, positional IDs) are concatenated and
        zero-padded to `max_seq_length`; text tensors are concatenated; per-sample
        scalars (sequence lengths, latent shapes) are stacked.
        """

        def stack(attr):
            return torch.stack([getattr(sample, attr) for sample in samples], dim=0)

        def cat(attr):
            return torch.cat([getattr(sample, attr) for sample in samples], dim=0)

        video = concat_pad([i.video for i in samples], self.max_seq_length)
        loss_mask = concat_pad([i.loss_mask for i in samples], self.max_seq_length)
        pos_ids = concat_pad([i.pos_ids for i in samples], self.max_seq_length)

        return DiffusionSample(
            __key__=",".join([s.__key__ for s in samples]),
            __restore_key__=(),  # Will be set by energon based on `samples`
            __subflavor__=None,
            __subflavors__=samples[0].__subflavors__,
            video=video,
            t5_text_embeddings=cat('t5_text_embeddings'),
            t5_text_mask=cat('t5_text_mask'),
            # image_size=stack('image_size'),
            # fps=stack('fps'),
            # num_frames=stack('num_frames'),
            loss_mask=loss_mask,
            seq_len_q=stack('seq_len_q'),
            seq_len_kv=stack('seq_len_kv'),
            pos_ids=pos_ids,
            latent_shape=stack('latent_shape'),
        )

    @stateless
    def batch(self, samples: List[DiffusionSample]) -> dict:
        """
        Return a dictionary with the data for one batch.

        Without packing (`max_seq_length is None`) the default batching of the base class
        is used. With packing, only the first element of `samples` is used and a leading
        batch dimension of size 1 is added to its token-level tensors.
        """
        if self.max_seq_length is None:
            # no packing
            return super().batch(samples).to_dict()

        # packing
        sample = samples[0]
        # Use out-of-place unsqueeze so the sample's own tensors are left untouched
        # (an in-place unsqueeze_ would add another dimension each time batch() sees
        # the same sample object).
        return dict(
            video=sample.video.unsqueeze(0),
            t5_text_embeddings=sample.t5_text_embeddings.unsqueeze(0),
            t5_text_mask=sample.t5_text_mask.unsqueeze(0),
            loss_mask=sample.loss_mask.unsqueeze(0),
            # image_size=sample.image_size,
            # fps=sample.fps,
            # num_frames=sample.num_frames,
            # padding_mask=sample.padding_mask.unsqueeze(0),
            seq_len_q=sample.seq_len_q,
            seq_len_kv=sample.seq_len_kv,
            pos_ids=sample.pos_ids.unsqueeze(0),
            latent_shape=sample.latent_shape,
        )
class PosID3D:
    """
    Generates 3D positional IDs for video data.

    A CPU grid of shape (max_t, max_h, max_w, 3) is cached; entry [t, h, w] holds the
    triple (t, h, w). The grid is regenerated with larger bounds whenever a request
    exceeds the current ones.

    Attributes:
        max_t (int): Maximum number of time frames currently covered by the grid.
        max_h (int): Maximum height currently covered by the grid.
        max_w (int): Maximum width currently covered by the grid.
        grid (torch.Tensor): The cached positional-ID grid.
    """

    def __init__(self, *, max_t=32, max_h=128, max_w=128):
        self.max_t = max_t
        self.max_h = max_h
        self.max_w = max_w
        self.generate_pos_id()

    def generate_pos_id(self):
        """Generates a grid of positional IDs based on max_t, max_h, and max_w."""
        self.grid = torch.stack(
            torch.meshgrid(
                torch.arange(self.max_t, device='cpu'),
                torch.arange(self.max_h, device='cpu'),
                torch.arange(self.max_w, device='cpu'),
                # 'ij' is the layout torch used implicitly before; passing it explicitly
                # avoids the deprecation warning for a missing `indexing` argument.
                indexing='ij',
            ),
            dim=-1,
        )

    def get_pos_id_3d(self, *, t, h, w):
        """
        Retrieves positional IDs for the specified dimensions.

        Args:
            t (int): Number of time steps requested.
            h (int): Height requested.
            w (int): Width requested.

        Returns:
            torch.Tensor: A slice of the cached grid with shape (t, h, w, 3).
        """
        if t > self.max_t or h > self.max_h or w > self.max_w:
            # Grow the cached grid so the request fits, keeping any bound that is already larger.
            self.max_t = max(self.max_t, t)
            self.max_h = max(self.max_h, h)
            self.max_w = max(self.max_w, w)
            self.generate_pos_id()
        return self.grid[:t, :h, :w]
def pad_divisible(x, padding_value=0):
    """
    Zero-pad the first dimension of `x` up to the next multiple of `padding_value`.

    Args:
        x (torch.Tensor): Input tensor.
        padding_value (int): The first dimension is made divisible by this value.
            A value of 0 disables padding.

    Returns:
        torch.Tensor: `x` itself when no padding is required, otherwise a new tensor
        holding `x` followed by zero rows.
    """
    if padding_value == 0:
        return x

    length = x.size(0)
    # Rows missing to reach the next multiple of padding_value.
    missing = (padding_value - length % padding_value) % padding_value
    if missing <= 0:
        return x

    padded_shape = [length + missing, *x.shape[1:]]
    padded = x.new_zeros(padded_shape)
    padded[:length] = x
    return padded
def concat_pad(tensor_list, max_seq_length):
    """
    Lay the given tensors out back-to-back along dim 0 inside a zero buffer of length
    `max_seq_length`.

    Args:
        tensor_list (list of torch.Tensor): Tensors to concatenate; trailing dimensions,
            dtype and device are taken from the first one.
        max_seq_length (int): Size of the first dimension of the output tensor.

    Returns:
        torch.Tensor: A tensor of shape [max_seq_length, ...] whose unused tail is zero.
    """
    head = tensor_list[0]
    packed = head.new_zeros((max_seq_length, *head.shape[1:]))

    offset = 0
    for chunk in tensor_list:
        end = offset + chunk.shape[0]
        # No bounds check here: the chunks are expected to fit into max_seq_length.
        packed[offset:end] = chunk
        offset = end
    return packed
# Shared module-level PosID3D instance (default bounds); BasicDiffusionTaskEncoder.encode_sample
# reads its positional IDs from this object.
pos_id_3d = PosID3D()
def cook_raw_images(sample: dict) -> dict:
    """
    Convert a raw energon sample holding raw images and text into the cooked dictionary.

    Args:
        sample (dict): Raw sample; must contain the 'jpg', 'png' and 'txt' entries.

    Returns:
        dict: Everything returned by `basic_sample_keys(sample)` plus:
            - 'images': the original image, read from sample['jpg']
            - 'hint': the control image, read from sample['png']
            - 'txt': the raw text
    """
    base_fields = basic_sample_keys(sample)
    payload = {
        'images': sample['jpg'],
        'hint': sample['png'],
        'txt': sample['txt'],
    }
    # Keyword expansion (rather than dict.update) keeps the TypeError on duplicate keys.
    return dict(**base_fields, **payload)
class RawImageDiffusionTaskEncoder(DefaultTaskEncoder, IOMixin):
    """
    Dummy task encoder that takes raw image input on CrudeDataset.

    Raw samples are cooked with `cook_raw_images`; no other behaviour is overridden.
    """

    cookers = [Cooker(cook_raw_images)]