|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import random |
|
|
from dataclasses import dataclass |
|
|
from typing import Any, List, Optional |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from einops import rearrange |
|
|
from megatron.energon import DefaultTaskEncoder, Sample, SkipSample |
|
|
from megatron.energon.task_encoder.base import stateless |
|
|
from megatron.energon.task_encoder.cooking import Cooker, basic_sample_keys |
|
|
|
|
|
from nemo.lightning.io.mixin import IOMixin |
|
|
from nemo.utils.sequence_packing_utils import first_fit_decreasing |
|
|
|
|
|
|
|
|
@dataclass
class DiffusionSample(Sample):
    """
    Container for one (possibly packed) diffusion training sample.

    Instances behave like their query sequence length under ``+`` and ``<``:
    adding two samples (or a sample and an ``int``) yields an ``int`` and
    comparison orders samples by ``seq_len_q``. This lets length-based
    utilities such as ``sum`` and ``sorted`` operate directly on samples.

    Attributes:
        video (torch.Tensor): Video latents (C T H W).
        t5_text_embeddings (torch.Tensor): Text embeddings (S D).
        t5_text_mask (torch.Tensor): Mask for text embeddings.
        loss_mask (torch.Tensor): Mask indicating valid positions for loss computation.
        image_size (Optional[torch.Tensor]): Tensor containing image dimensions.
        fps (Optional[torch.Tensor]): Frame rate of the video.
        num_frames (Optional[torch.Tensor]): Number of frames in the video.
        padding_mask (Optional[torch.Tensor]): Mask indicating padding positions.
        seq_len_q (Optional[torch.Tensor]): Sequence length for query embeddings.
        seq_len_kv (Optional[torch.Tensor]): Sequence length for key/value embeddings.
        pos_ids (Optional[torch.Tensor]): Positional IDs.
        latent_shape (Optional[torch.Tensor]): Shape of the latent tensor.
    """

    video: torch.Tensor
    t5_text_embeddings: torch.Tensor
    t5_text_mask: torch.Tensor
    loss_mask: torch.Tensor
    image_size: Optional[torch.Tensor] = None
    fps: Optional[torch.Tensor] = None
    num_frames: Optional[torch.Tensor] = None
    padding_mask: Optional[torch.Tensor] = None
    seq_len_q: Optional[torch.Tensor] = None
    seq_len_kv: Optional[torch.Tensor] = None
    pos_ids: Optional[torch.Tensor] = None
    latent_shape: Optional[torch.Tensor] = None

    def to_dict(self) -> dict:
        """Return the tensor fields of this sample as a plain dictionary."""
        return {
            'video': self.video,
            't5_text_embeddings': self.t5_text_embeddings,
            't5_text_mask': self.t5_text_mask,
            'loss_mask': self.loss_mask,
            'image_size': self.image_size,
            'fps': self.fps,
            'num_frames': self.num_frames,
            'padding_mask': self.padding_mask,
            'seq_len_q': self.seq_len_q,
            'seq_len_kv': self.seq_len_kv,
            'pos_ids': self.pos_ids,
            'latent_shape': self.latent_shape,
        }

    def __add__(self, other: Any) -> int:
        """Return this sample's query length plus another sample's length or an integer."""
        if isinstance(other, DiffusionSample):
            addend = other.seq_len_q.item()
        elif isinstance(other, int):
            addend = other
        else:
            raise NotImplementedError
        return self.seq_len_q.item() + addend

    def __radd__(self, other: Any) -> int:
        """Support ``int + sample`` so that ``sum()`` over samples yields a total length."""
        if not isinstance(other, int):
            raise NotImplementedError
        return self.seq_len_q.item() + other

    def __lt__(self, other: Any) -> bool:
        """Order samples by query sequence length; an integer is compared as a length."""
        if isinstance(other, DiffusionSample):
            threshold = other.seq_len_q.item()
        elif isinstance(other, int):
            threshold = other
        else:
            raise NotImplementedError
        return self.seq_len_q.item() < threshold
|
|
|
|
|
|
|
|
def cook(sample: dict) -> dict:
    """
    Convert a raw energon sample into the dictionary consumed by the task encoder.

    Args:
        sample (dict): Raw sample; must provide the '.json', '.pth' and '.pickle' entries.

    Returns:
        dict: Everything returned by `basic_sample_keys(sample)` plus:
            - 'json': the sample's '.json' entry (meta data like resolution, aspect ratio, fps, etc.)
            - 'pth': the sample's '.pth' entry (video latent tensor)
            - 'pickle': the sample's '.pickle' entry (text embeddings)
    """
    base_fields = basic_sample_keys(sample)
    metadata = sample['.json']
    latent = sample['.pth']
    text_embeddings = sample['.pickle']
    return dict(**base_fields, json=metadata, pth=latent, pickle=text_embeddings)
|
|
|
|
|
|
|
|
class BasicDiffusionTaskEncoder(DefaultTaskEncoder, IOMixin):
    """
    Task encoder that turns cooked image/video latent samples into ``DiffusionSample`` objects.

    Attributes:
        cookers (list): Cooker objects applied to raw samples before encoding.
        max_frames (Optional[int]): If set, latents are truncated to at most this many frames.
        text_embedding_padding_size (int): Text embeddings are truncated/zero-padded to this length.
        seq_length (Optional[int]): Fixed token length. Longer samples are skipped; when
            ``max_seq_length`` is ``None`` shorter samples are zero-padded up to this length.
        max_seq_length (Optional[int]): Token budget of a packed sequence. Longer samples are
            skipped; when set, ``batch`` treats its input as an already packed sample.
        patch_spatial (int): Spatial patch size applied to both H and W.
        patch_temporal (int): Temporal patch size applied to T.
        aesthetic_score (float): Samples whose ``aesthetic_score`` metadata is below this are skipped.

    Methods:
        encode_sample(sample): Patchify one cooked sample and attach text embeddings and metadata.
        select_samples_to_pack(samples): Group samples into bins of at most ``max_seq_length`` tokens.
        pack_selected_samples(samples): Concatenate one bin of samples into a single sample.
        batch(samples): Build the dictionary handed to the model.
    """

    cookers = [
        Cooker(cook),
    ]

    def __init__(
        self,
        *args,
        max_frames: Optional[int] = None,
        text_embedding_padding_size: int = 512,
        seq_length: Optional[int] = None,
        max_seq_length: Optional[int] = None,
        patch_spatial: int = 2,
        patch_temporal: int = 1,
        aesthetic_score: float = 0.0,
        **kwargs,
    ):
        """Store the encoding options; remaining arguments are forwarded to the parent classes."""
        super().__init__(*args, **kwargs)
        self.max_frames = max_frames
        self.text_embedding_padding_size = text_embedding_padding_size
        self.seq_length = seq_length
        self.max_seq_length = max_seq_length
        self.patch_spatial = patch_spatial
        self.patch_temporal = patch_temporal
        self.aesthetic_score = aesthetic_score

    @stateless(restore_seeds=True)
    def encode_sample(self, sample: dict) -> DiffusionSample:
        """
        Encode one cooked video / text sample.

        Args:
            sample (dict): Cooked sample with 'pth' (video latent, C T H W), 'json' (metadata)
                and 'pickle' (text embeddings as a numpy array; for videos the embeddings are
                taken from index 0), plus the basic energon sample keys.

        Returns:
            DiffusionSample: Patchified latent of shape (tokens, patch_dim) together with text
                embeddings, masks, positional ids and shape metadata.

        Raises:
            SkipSample: If the latent contains NaN/Inf values or magnitudes above 1e3, if the
                aesthetic score is below the configured threshold, or if the token count exceeds
                ``seq_length`` / ``max_seq_length``.
        """
        video_latent = sample['pth']

        # Reject numerically broken latents.
        if torch.isnan(video_latent).any() or torch.isinf(video_latent).any():
            raise SkipSample()
        if torch.max(torch.abs(video_latent)) > 1e3:
            raise SkipSample()

        info = sample['json']
        if info['aesthetic_score'] < self.aesthetic_score:
            raise SkipSample()

        # Decide image vs. video from the stored latent, before any truncation, because this
        # flag selects which metadata keys and which text-embedding layout the sample carries.
        is_image = video_latent.shape[1] == 1

        # Truncate BEFORE deriving any shape-dependent quantity so that seq_len, pos_ids,
        # loss_mask, padding amounts and latent_shape all describe the tensor that is returned.
        if self.max_frames is not None:
            video_latent = video_latent[:, : self.max_frames, :, :]

        C, T, H, W = video_latent.shape
        # Number of tokens after patchification.
        seq_len = T * H * W // self.patch_spatial**2 // self.patch_temporal

        if self.seq_length is not None and seq_len > self.seq_length:
            raise SkipSample()
        if self.max_seq_length is not None and seq_len > self.max_seq_length:
            raise SkipSample()

        # Patchify: one row per (t, h, w) patch, features are the flattened patch contents.
        video_latent = rearrange(
            video_latent,
            'C (T pt) (H ph) (W pw) -> (T H W) (ph pw pt C)',
            ph=self.patch_spatial,
            pw=self.patch_spatial,
            pt=self.patch_temporal,
        )

        if is_image:
            t5_text_embeddings = torch.from_numpy(sample['pickle']).to(torch.bfloat16)
        else:
            t5_text_embeddings = torch.from_numpy(sample['pickle'][0]).to(torch.bfloat16)
        t5_text_embeddings_seq_length = t5_text_embeddings.shape[0]

        # Bring the text embeddings to exactly text_embedding_padding_size rows.
        if t5_text_embeddings_seq_length > self.text_embedding_padding_size:
            t5_text_embeddings = t5_text_embeddings[: self.text_embedding_padding_size]
        else:
            t5_text_embeddings = F.pad(
                t5_text_embeddings,
                (
                    0,
                    0,
                    0,
                    self.text_embedding_padding_size - t5_text_embeddings_seq_length,
                ),
            )
        # NOTE(review): the mask length is the text length before truncation/padding, so it can
        # differ from the padded embedding length - confirm that consumers expect this.
        t5_text_mask = torch.ones(t5_text_embeddings_seq_length, dtype=torch.bfloat16)

        if is_image:
            h, w = info['image_height'], info['image_width']
            fps = torch.tensor([30] * 1, dtype=torch.bfloat16)
            num_frames = torch.tensor([1] * 1, dtype=torch.bfloat16)
        else:
            h, w = info['height'], info['width']
            fps = torch.tensor([info['framerate']] * 1, dtype=torch.bfloat16)
            num_frames = torch.tensor([info['num_frames']] * 1, dtype=torch.bfloat16)
        image_size = torch.tensor([[h, w, h, w]] * 1, dtype=torch.bfloat16)

        # One (t, h, w) index triple per token, in the same order as the patchified latent.
        pos_ids = rearrange(
            pos_id_3d.get_pos_id_3d(t=T // self.patch_temporal, h=H // self.patch_spatial, w=W // self.patch_spatial),
            'T H W d -> (T H W) d',
        )

        if self.seq_length is not None and self.max_seq_length is None:
            # Fixed-length mode: zero-pad tokens and pos_ids, and mask the padding out of the loss.
            pos_ids = F.pad(pos_ids, (0, 0, 0, self.seq_length - seq_len))
            loss_mask = torch.zeros(self.seq_length, dtype=torch.bfloat16)
            loss_mask[:seq_len] = 1
            video_latent = F.pad(video_latent, (0, 0, 0, self.seq_length - seq_len))
        else:
            loss_mask = torch.ones(seq_len, dtype=torch.bfloat16)

        return DiffusionSample(
            __key__=sample['__key__'],
            __restore_key__=sample['__restore_key__'],
            __subflavor__=None,
            __subflavors__=sample['__subflavors__'],
            video=video_latent,
            t5_text_embeddings=t5_text_embeddings,
            t5_text_mask=t5_text_mask,
            image_size=image_size,
            fps=fps,
            num_frames=num_frames,
            loss_mask=loss_mask,
            seq_len_q=torch.tensor(seq_len, dtype=torch.int32),
            seq_len_kv=torch.tensor(self.text_embedding_padding_size, dtype=torch.int32),
            pos_ids=pos_ids,
            latent_shape=torch.tensor([C, T, H, W], dtype=torch.int32),
        )

    def select_samples_to_pack(self, samples: List[DiffusionSample]) -> List[List[DiffusionSample]]:
        """
        Selects sequences to pack for mixed image-video training.

        Samples are binned with ``first_fit_decreasing`` against ``max_seq_length`` and the
        resulting bins are shuffled in place before being returned.
        """
        results = first_fit_decreasing(samples, self.max_seq_length)
        random.shuffle(results)
        return results

    @stateless
    def pack_selected_samples(self, samples: List[DiffusionSample]) -> DiffusionSample:
        """
        Construct a new Diffusion sample by concatenating the sequences.

        Token-level tensors (video, loss_mask, pos_ids) are concatenated and zero-padded to
        ``max_seq_length``; text tensors are concatenated; per-sample scalars and shapes are
        stacked along a new leading dimension.
        """

        def stack(attr):
            return torch.stack([getattr(sample, attr) for sample in samples], dim=0)

        def cat(attr):
            return torch.cat([getattr(sample, attr) for sample in samples], dim=0)

        video = concat_pad([i.video for i in samples], self.max_seq_length)
        loss_mask = concat_pad([i.loss_mask for i in samples], self.max_seq_length)
        pos_ids = concat_pad([i.pos_ids for i in samples], self.max_seq_length)

        return DiffusionSample(
            __key__=",".join([s.__key__ for s in samples]),
            __restore_key__=(),
            __subflavor__=None,
            __subflavors__=samples[0].__subflavors__,
            video=video,
            t5_text_embeddings=cat('t5_text_embeddings'),
            t5_text_mask=cat('t5_text_mask'),
            loss_mask=loss_mask,
            seq_len_q=stack('seq_len_q'),
            seq_len_kv=stack('seq_len_kv'),
            pos_ids=pos_ids,
            latent_shape=stack('latent_shape'),
        )

    @stateless
    def batch(self, samples: List[DiffusionSample]) -> dict:
        """
        Return dictionary with data for batch.

        Without packing (``max_seq_length`` is ``None``) the parent class batches the samples and
        the result is converted with ``to_dict()``. With packing, the first sample is used and a
        leading batch dimension of size 1 is added to its token-level and text tensors.
        """
        if self.max_seq_length is None:
            # No packing: defer to the default batching of the parent task encoder.
            return super().batch(samples).to_dict()

        # Packing: NOTE(review) only samples[0] is used - presumably the loader runs with a
        # batch size of one packed sample in this mode; confirm against the data module.
        sample = samples[0]
        # Use the non-mutating unsqueeze so the caller's sample tensors keep their shapes.
        return dict(
            video=sample.video.unsqueeze(0),
            t5_text_embeddings=sample.t5_text_embeddings.unsqueeze(0),
            t5_text_mask=sample.t5_text_mask.unsqueeze(0),
            loss_mask=sample.loss_mask.unsqueeze(0),
            seq_len_q=sample.seq_len_q,
            seq_len_kv=sample.seq_len_kv,
            pos_ids=sample.pos_ids.unsqueeze(0),
            latent_shape=sample.latent_shape,
        )
|
|
|
|
|
|
|
|
class PosID3D:
    """
    Generates 3D positional IDs for video data.

    A CPU grid of shape (max_t, max_h, max_w, 3) is cached; entry ``[t, h, w]`` holds the
    index triple ``(t, h, w)``. The grid is regenerated with larger bounds whenever a request
    exceeds the current maxima.

    Attributes:
        max_t (int): Maximum number of time frames.
        max_h (int): Maximum height dimension.
        max_w (int): Maximum width dimension.
        grid (torch.Tensor): Cached positional-id grid.
    """

    def __init__(self, *, max_t=32, max_h=128, max_w=128):
        self.max_t = max_t
        self.max_h = max_h
        self.max_w = max_w
        self.generate_pos_id()

    def generate_pos_id(self):
        """Generates a grid of positional IDs based on max_t, max_h, and max_w."""
        self.grid = torch.stack(
            torch.meshgrid(
                torch.arange(self.max_t, device='cpu'),
                torch.arange(self.max_h, device='cpu'),
                torch.arange(self.max_w, device='cpu'),
                # Matrix ('ij') indexing is what the (t, h, w) layout relies on; passing it
                # explicitly avoids the deprecation warning for the implicit default.
                indexing='ij',
            ),
            dim=-1,
        )

    def get_pos_id_3d(self, *, t, h, w):
        """
        Retrieves positional IDs for specified dimensions.

        Args:
            t (int): Number of temporal positions.
            h (int): Number of vertical positions.
            w (int): Number of horizontal positions.

        Returns:
            torch.Tensor: Slice of the cached grid with shape (t, h, w, 3).
        """
        if t > self.max_t or h > self.max_h or w > self.max_w:
            # Grow the cached grid so the request fits; the maxima never shrink.
            self.max_t = max(self.max_t, t)
            self.max_h = max(self.max_h, h)
            self.max_w = max(self.max_w, w)
            self.generate_pos_id()
        return self.grid[:t, :h, :w]
|
|
|
|
|
|
|
|
def pad_divisible(x, padding_value=0):
    """
    Zero-pad a tensor along dim 0 until its length is a multiple of ``padding_value``.

    Args:
        x (torch.Tensor): Input tensor.
        padding_value (int): Despite its name, this is the divisor the first dimension must be
            a multiple of, not the fill value. A value of 0 disables padding.

    Returns:
        torch.Tensor: ``x`` itself when no padding is required, otherwise a new tensor whose
            leading rows are ``x`` and whose trailing rows are zeros.
    """
    # A divisor of zero means "leave the tensor untouched".
    if padding_value == 0:
        return x

    length = x.size(0)
    remainder = length % padding_value
    extra_rows = (padding_value - remainder) % padding_value
    if extra_rows <= 0:
        return x

    # Allocate the enlarged buffer once and copy the original rows into its head.
    padded_shape = [length + extra_rows, *x.shape[1:]]
    padded = torch.zeros(padded_shape, dtype=x.dtype, device=x.device)
    padded[:length] = x
    return padded
|
|
|
|
|
|
|
|
def concat_pad(tensor_list, max_seq_length):
    """
    Efficiently concatenates a list of tensors along the first dimension and pads with zeros
    to reach max_seq_length.

    The output buffer is allocated once and each tensor is copied into its slot, so no
    intermediate concatenated tensor is created. Dtype, device and trailing dimensions are
    taken from the first tensor.

    Args:
        tensor_list (list of torch.Tensor): List of tensors to concatenate and pad.
        max_seq_length (int): The desired size of the first dimension of the output tensor.

    Returns:
        torch.Tensor: A tensor of shape [max_seq_length, ...], where ... represents the remaining dimensions.

    Raises:
        ValueError: If ``tensor_list`` is empty, or if the combined length of the tensors
            exceeds ``max_seq_length``.
    """
    if not tensor_list:
        raise ValueError("concat_pad requires at least one tensor")

    # Validate up front: without this check an overflow surfaces as an opaque broadcasting
    # error, or a length-1 tensor written past the end is silently dropped.
    total_length = sum(tensor.shape[0] for tensor in tensor_list)
    if total_length > max_seq_length:
        raise ValueError(
            f"combined length {total_length} of the tensors exceeds max_seq_length {max_seq_length}"
        )

    first = tensor_list[0]
    result = torch.zeros((max_seq_length, *first.shape[1:]), dtype=first.dtype, device=first.device)

    current_index = 0
    for tensor in tensor_list:
        length = tensor.shape[0]
        result[current_index : current_index + length] = tensor
        current_index += length

    return result
|
|
|
|
|
|
|
|
# Module-level PosID3D instance created at import time with the default bounds; it is the
# object BasicDiffusionTaskEncoder.encode_sample queries for positional ids.
pos_id_3d = PosID3D()
|
|
|
|
|
|
|
|
def cook_raw_images(sample: dict) -> dict:
    """
    Convert a raw energon image sample into a dictionary with renamed image/text entries.

    Args:
        sample (dict): Raw sample; must provide the 'jpg', 'png' and 'txt' entries.

    Returns:
        dict: Everything returned by `basic_sample_keys(sample)` plus:
            - 'images': the sample's 'jpg' entry (original images)
            - 'hint': the sample's 'png' entry (control images)
            - 'txt': the sample's 'txt' entry (raw text)
    """
    base_fields = basic_sample_keys(sample)
    original_images = sample['jpg']
    control_images = sample['png']
    raw_text = sample['txt']
    return dict(**base_fields, images=original_images, hint=control_images, txt=raw_text)
|
|
|
|
|
|
|
|
class RawImageDiffusionTaskEncoder(DefaultTaskEncoder, IOMixin):
    """
    Dummy task encoder that takes raw image input on CrudeDataset.

    It only registers ``cook_raw_images`` as its cooker; all other behavior is inherited
    from the parent classes.
    """

    cookers = [Cooker(cook_raw_images)]
|
|
|