# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
from dataclasses import dataclass
from typing import Any, List, Optional

import torch
import torch.nn.functional as F
from einops import rearrange
from megatron.energon import DefaultTaskEncoder, Sample, SkipSample
from megatron.energon.task_encoder.base import stateless
from megatron.energon.task_encoder.cooking import Cooker, basic_sample_keys

from nemo.lightning.io.mixin import IOMixin
from nemo.utils.sequence_packing_utils import first_fit_decreasing


@dataclass
class DiffusionSample(Sample):
    """
    Data class representing a sample for diffusion tasks.

    Attributes:
        video (torch.Tensor): Video latents (C T H W).
        t5_text_embeddings (torch.Tensor): Text embeddings (S D).
        t5_text_mask (torch.Tensor): Mask for text embeddings.
        loss_mask (torch.Tensor): Mask indicating valid positions for loss computation.
        image_size (Optional[torch.Tensor]): Tensor containing image dimensions.
        fps (Optional[torch.Tensor]): Frame rate of the video.
        num_frames (Optional[torch.Tensor]): Number of frames in the video.
        padding_mask (Optional[torch.Tensor]): Mask indicating padding positions.
        seq_len_q (Optional[torch.Tensor]): Sequence length for query embeddings.
        seq_len_kv (Optional[torch.Tensor]): Sequence length for key/value embeddings.
        pos_ids (Optional[torch.Tensor]): Positional IDs.
        latent_shape (Optional[torch.Tensor]): Shape of the latent tensor.
    """

    video: torch.Tensor  # video latents (C T H W)
    t5_text_embeddings: torch.Tensor  # (S D)
    t5_text_mask: torch.Tensor  # (S)
    loss_mask: torch.Tensor
    image_size: Optional[torch.Tensor] = None
    fps: Optional[torch.Tensor] = None
    num_frames: Optional[torch.Tensor] = None
    padding_mask: Optional[torch.Tensor] = None
    seq_len_q: Optional[torch.Tensor] = None
    seq_len_kv: Optional[torch.Tensor] = None
    pos_ids: Optional[torch.Tensor] = None
    latent_shape: Optional[torch.Tensor] = None

    def to_dict(self) -> dict:
        """Converts the sample to a dictionary."""
        return dict(
            video=self.video,
            t5_text_embeddings=self.t5_text_embeddings,
            t5_text_mask=self.t5_text_mask,
            loss_mask=self.loss_mask,
            image_size=self.image_size,
            fps=self.fps,
            num_frames=self.num_frames,
            padding_mask=self.padding_mask,
            seq_len_q=self.seq_len_q,
            seq_len_kv=self.seq_len_kv,
            pos_ids=self.pos_ids,
            latent_shape=self.latent_shape,
        )

    def __add__(self, other: Any) -> int:
        """Adds the sequence length of this sample with another sample or integer."""
        if isinstance(other, DiffusionSample):
            # Combine the values of the two instances
            return self.seq_len_q.item() + other.seq_len_q.item()
        elif isinstance(other, int):
            # Add an integer to the value
            return self.seq_len_q.item() + other
        raise NotImplementedError

    def __radd__(self, other: Any) -> int:
        """Handles reverse addition for summing with integers."""
        # This is called when sum() or similar starts from a non-DiffusionSample value,
        # e.g., sum(samples) begins with 0 + samples[0], which invokes __radd__.
        if isinstance(other, int):
            return self.seq_len_q.item() + other
        raise NotImplementedError

    def __lt__(self, other: Any) -> bool:
        """Compares this sample's sequence length with another sample or integer."""
        if isinstance(other, DiffusionSample):
            return self.seq_len_q.item() < other.seq_len_q.item()
        elif isinstance(other, int):
            return self.seq_len_q.item() < other
        raise NotImplementedError
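

# Illustrative sketch (not part of the original API): the __add__/__radd__/__lt__
# hooks above let bin-packing helpers treat DiffusionSample objects as plain token
# counts, e.g. summing a bin's total length or sorting samples longest-first.
def _example_sample_arithmetic(samples: List[DiffusionSample]) -> None:
    total_tokens = sum(samples)  # starts as 0 + samples[0], which calls __radd__
    longest_first = sorted(samples, reverse=True)  # __lt__ compares seq_len_q
    print(total_tokens, [s.seq_len_q.item() for s in longest_first])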


def cook(sample: dict) -> dict:
    """
    Processes a raw sample dictionary from energon dataset and returns a new dictionary with specific keys.

    Args:
        sample (dict): The input dictionary containing the raw sample data.

    Returns:
        dict: A new dictionary containing the processed sample data with the following keys:
            - All keys from the result of `basic_sample_keys(sample)`
            - 'json': Contains metadata such as resolution, aspect ratio, fps, etc.
            - 'pth': Contains the video latent tensor
            - 'pickle': Contains the text embeddings
    """
    return dict(
        **basic_sample_keys(sample),
        json=sample['.json'],
        pth=sample['.pth'],
        pickle=sample['.pickle'],
    )
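

# Hedged sketch (values assumed, not taken from a real dataset): a crude energon
# sample for the cooker above carries raw '.json', '.pth', and '.pickle' entries,
# which cook() exposes as 'json', 'pth', and 'pickle' alongside the standard keys.
def _example_raw_sample() -> dict:
    return {
        '__key__': 'shard_000/000001',
        '__restore_key__': (),
        '__subflavor__': None,
        '__subflavors__': {},
        '.json': {'height': 480, 'width': 640, 'framerate': 24, 'num_frames': 33, 'aesthetic_score': 5.0},
        '.pth': torch.zeros(16, 8, 60, 80),  # video latent (C T H W); shape assumed
        '.pickle': None,  # serialized text embeddings (a numpy array in practice)
    }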


class BasicDiffusionTaskEncoder(DefaultTaskEncoder, IOMixin):
    """
    BasicDiffusionTaskEncoder is a class that encodes image/video samples for diffusion tasks.
    Attributes:
        cookers (list): A list of Cooker objects used for processing.
        max_frames (int, optional): The maximum number of frames to keep from the video. Defaults to None.
        text_embedding_padding_size (int): The padding size for text embeddings. Defaults to 512.
        seq_length (int, optional): Fixed sequence length; longer samples are skipped and shorter ones are padded.
        max_seq_length (int, optional): Maximum total length when packing sequences; longer samples are skipped.
        patch_spatial (int): Spatial patch size used to patchify the latent. Defaults to 2.
        patch_temporal (int): Temporal patch size used to patchify the latent. Defaults to 1.
        aesthetic_score (float): Minimum aesthetic score; lower-scoring samples are skipped. Defaults to 0.0.
    Methods:
        __init__(*args, max_frames=None, text_embedding_padding_size=512, **kwargs):
            Initializes the BasicDiffusionTaskEncoder with optional maximum frames and text embedding padding size.
        encode_sample(sample: dict) -> dict:
            Encodes a given sample dictionary containing video and text data.
            Args:
                sample (dict): A dictionary containing 'pth' for video latent and 'json' for additional info.
            Returns:
                dict: A dictionary containing encoded video, text embeddings, text mask, and loss mask.
            Raises:
                SkipSample: If the video latent contains NaNs, Infs, or extreme values, if the
                    aesthetic score is below the threshold, or if the sequence exceeds the
                    configured maximum length.
    """

    cookers = [
        Cooker(cook),
    ]

    def __init__(
        self,
        *args,
        max_frames: Optional[int] = None,
        text_embedding_padding_size: int = 512,
        seq_length: Optional[int] = None,
        max_seq_length: Optional[int] = None,
        patch_spatial: int = 2,
        patch_temporal: int = 1,
        aesthetic_score: float = 0.0,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.max_frames = max_frames
        self.text_embedding_padding_size = text_embedding_padding_size
        self.seq_length = seq_length
        self.max_seq_length = max_seq_length
        self.patch_spatial = patch_spatial
        self.patch_temporal = patch_temporal
        self.aesthetic_score = aesthetic_score

    @stateless(restore_seeds=True)
    def encode_sample(self, sample: dict) -> dict:
        """
        Encodes video / text sample.
        """
        video_latent = sample['pth']

        if torch.isnan(video_latent).any() or torch.isinf(video_latent).any():
            raise SkipSample()
        if torch.max(torch.abs(video_latent)) > 1e3:
            raise SkipSample()

        info = sample['json']
        if info['aesthetic_score'] < self.aesthetic_score:
            raise SkipSample()

        # Truncate before computing shapes so seq_len and pos_ids match the kept frames.
        if self.max_frames is not None:
            video_latent = video_latent[:, : self.max_frames, :, :]

        C, T, H, W = video_latent.shape
        seq_len = T * H * W // self.patch_spatial**2 // self.patch_temporal
        is_image = T == 1

        if self.seq_length is not None and seq_len > self.seq_length:
            raise SkipSample()
        if self.max_seq_length is not None and seq_len > self.max_seq_length:
            raise SkipSample()

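        # Patchify: fold each (pt x ph x pw) latent block into a single token, giving a
        # sequence of T*H*W / (ph*pw*pt) tokens of dim ph*pw*pt*C, matching seq_len above.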
        video_latent = rearrange(
            video_latent,
            'C (T pt) (H ph) (W pw) -> (T H W) (ph pw pt C)',
            ph=self.patch_spatial,
            pw=self.patch_spatial,
            pt=self.patch_temporal,
        )

        if is_image:
            t5_text_embeddings = torch.from_numpy(sample['pickle']).to(torch.bfloat16)
        else:
            t5_text_embeddings = torch.from_numpy(sample['pickle'][0]).to(torch.bfloat16)
        t5_text_embeddings_seq_length = t5_text_embeddings.shape[0]

        if t5_text_embeddings_seq_length > self.text_embedding_padding_size:
            t5_text_embeddings = t5_text_embeddings[: self.text_embedding_padding_size]
        else:
            t5_text_embeddings = F.pad(
                t5_text_embeddings,
                (
                    0,
                    0,
                    0,
                    self.text_embedding_padding_size - t5_text_embeddings_seq_length,
                ),
            )
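        # The mask length follows the original (pre-padding) embedding length.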
        t5_text_mask = torch.ones(t5_text_embeddings_seq_length, dtype=torch.bfloat16)

        if is_image:
            h, w = info['image_height'], info['image_width']
            fps = torch.tensor([30], dtype=torch.bfloat16)
            num_frames = torch.tensor([1], dtype=torch.bfloat16)
        else:
            h, w = info['height'], info['width']
            fps = torch.tensor([info['framerate']], dtype=torch.bfloat16)
            num_frames = torch.tensor([info['num_frames']], dtype=torch.bfloat16)
        image_size = torch.tensor([[h, w, h, w]], dtype=torch.bfloat16)

        pos_ids = rearrange(
            pos_id_3d.get_pos_id_3d(t=T // self.patch_temporal, h=H // self.patch_spatial, w=W // self.patch_spatial),
            'T H W d -> (T H W) d',
        )

        if self.seq_length is not None and self.max_seq_length is None:
            pos_ids = F.pad(pos_ids, (0, 0, 0, self.seq_length - seq_len))
            loss_mask = torch.zeros(self.seq_length, dtype=torch.bfloat16)
            loss_mask[:seq_len] = 1
            video_latent = F.pad(video_latent, (0, 0, 0, self.seq_length - seq_len))
        else:
            loss_mask = torch.ones(seq_len, dtype=torch.bfloat16)

        return DiffusionSample(
            __key__=sample['__key__'],
            __restore_key__=sample['__restore_key__'],
            __subflavor__=None,
            __subflavors__=sample['__subflavors__'],
            video=video_latent,
            t5_text_embeddings=t5_text_embeddings,
            t5_text_mask=t5_text_mask,
            image_size=image_size,
            fps=fps,
            num_frames=num_frames,
            loss_mask=loss_mask,
            seq_len_q=torch.tensor(seq_len, dtype=torch.int32),
            seq_len_kv=torch.tensor(self.text_embedding_padding_size, dtype=torch.int32),
            pos_ids=pos_ids,
            latent_shape=torch.tensor([C, T, H, W], dtype=torch.int32),
        )

    def select_samples_to_pack(self, samples: List[DiffusionSample]) -> List[List[DiffusionSample]]:
        """
        Selects sequences to pack for mixed image-video training.
        """
        results = first_fit_decreasing(samples, self.max_seq_length)
        random.shuffle(results)
        return results

    @stateless
    def pack_selected_samples(self, samples: List[DiffusionSample]) -> DiffusionSample:
        """Construct a new Diffusion sample by concatenating the sequences."""

        def stack(attr):
            return torch.stack([getattr(sample, attr) for sample in samples], dim=0)

        def cat(attr):
            return torch.cat([getattr(sample, attr) for sample in samples], dim=0)

        video = concat_pad([i.video for i in samples], self.max_seq_length)
        loss_mask = concat_pad([i.loss_mask for i in samples], self.max_seq_length)
        pos_ids = concat_pad([i.pos_ids for i in samples], self.max_seq_length)

        return DiffusionSample(
            __key__=",".join([s.__key__ for s in samples]),
            __restore_key__=(),  # Will be set by energon based on `samples`
            __subflavor__=None,
            __subflavors__=samples[0].__subflavors__,
            video=video,
            t5_text_embeddings=cat('t5_text_embeddings'),
            t5_text_mask=cat('t5_text_mask'),
            # image_size=stack('image_size'),
            # fps=stack('fps'),
            # num_frames=stack('num_frames'),
            loss_mask=loss_mask,
            seq_len_q=stack('seq_len_q'),
            seq_len_kv=stack('seq_len_kv'),
            pos_ids=pos_ids,
            latent_shape=stack('latent_shape'),
        )

    @stateless
    def batch(self, samples: List[DiffusionSample]) -> dict:
        """Return dictionary with data for batch."""
        if self.max_seq_length is None:
            # no packing
            return super().batch(samples).to_dict()

        # packing
        sample = samples[0]
        return dict(
            video=sample.video.unsqueeze_(0),
            t5_text_embeddings=sample.t5_text_embeddings.unsqueeze_(0),
            t5_text_mask=sample.t5_text_mask.unsqueeze_(0),
            loss_mask=sample.loss_mask.unsqueeze_(0),
            # image_size=sample.image_size,
            # fps=sample.fps,
            # num_frames=sample.num_frames,
            # padding_mask=sample.padding_mask.unsqueeze_(0),
            seq_len_q=sample.seq_len_q,
            seq_len_kv=sample.seq_len_kv,
            pos_ids=sample.pos_ids.unsqueeze_(0),
            latent_shape=sample.latent_shape,
        )


class PosID3D:
    """
    Generates 3D positional IDs for video data.

    Attributes:
        max_t (int): Maximum number of time frames.
        max_h (int): Maximum height dimension.
        max_w (int): Maximum width dimension.
    """

    def __init__(self, *, max_t=32, max_h=128, max_w=128):
        self.max_t = max_t
        self.max_h = max_h
        self.max_w = max_w
        self.generate_pos_id()

    def generate_pos_id(self):
        """Generates a grid of positional IDs based on max_t, max_h, and max_w."""
        self.grid = torch.stack(
            torch.meshgrid(
                torch.arange(self.max_t, device='cpu'),
                torch.arange(self.max_h, device='cpu'),
                torch.arange(self.max_w, device='cpu'),
                indexing='ij',
            ),
            dim=-1,
        )

    def get_pos_id_3d(self, *, t, h, w):
        """Retrieves positional IDs for specified dimensions."""
        if t > self.max_t or h > self.max_h or w > self.max_w:
            self.max_t = max(self.max_t, t)
            self.max_h = max(self.max_h, h)
            self.max_w = max(self.max_w, w)
            self.generate_pos_id()
        return self.grid[:t, :h, :w]
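

# Hedged usage sketch: PosID3D memoizes a (max_t, max_h, max_w, 3) grid of (t, h, w)
# coordinates and regrows it on demand; the dimensions here are illustrative only.
def _example_pos_ids() -> torch.Tensor:
    ids = PosID3D().get_pos_id_3d(t=4, h=30, w=40)  # shape (4, 30, 40, 3)
    return rearrange(ids, 'T H W d -> (T H W) d')  # flattened as in encode_sample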


def pad_divisible(x, padding_value=0):
    """
    Pads the input tensor to make its size divisible by a specified value.

    Args:
        x (torch.Tensor): Input tensor.
        padding_value (int): The divisor; the first dimension is zero-padded up to the next multiple of this value. 0 disables padding.

    Returns:
        torch.Tensor: Padded tensor.
    """
    if padding_value == 0:
        return x
    # Get the size of the first dimension
    n = x.size(0)

    # Compute the padding needed to make the first dimension divisible by padding_value
    padding_needed = (padding_value - n % padding_value) % padding_value

    if padding_needed == 0:
        return x

    # Create a new shape with the padded first dimension
    new_shape = list(x.shape)
    new_shape[0] += padding_needed

    # Create a new tensor filled with zeros
    x_padded = torch.zeros(new_shape, dtype=x.dtype, device=x.device)

    # Assign the original tensor to the beginning of the new tensor
    x_padded[:n] = x
    return x_padded
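

# Hedged sketch: with a divisor of 4, pad_divisible pads a length-10 sequence to 12.
def _example_pad_divisible() -> torch.Size:
    x = torch.ones(10, 3)
    return pad_divisible(x, 4).shape  # torch.Size([12, 3])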


def concat_pad(tensor_list, max_seq_length):
    """
    Efficiently concatenates a list of tensors along the first dimension and pads with zeros
    to reach max_seq_length.

    Args:
        tensor_list (list of torch.Tensor): List of tensors to concatenate and pad.
        max_seq_length (int): The desired size of the first dimension of the output tensor.

    Returns:
        torch.Tensor: A tensor of shape [max_seq_length, ...], where ... represents the remaining dimensions.
    """

    # Get common properties from the first tensor
    other_shape = tensor_list[0].shape[1:]
    dtype = tensor_list[0].dtype
    device = tensor_list[0].device

    # Initialize the result tensor with zeros
    result = torch.zeros((max_seq_length, *other_shape), dtype=dtype, device=device)

    current_index = 0
    for tensor in tensor_list:
        length = tensor.shape[0]
        # Copy each tensor into its slice; assumes the total length does not exceed max_seq_length
        result[current_index : current_index + length] = tensor
        current_index += length

    return result
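

# Hedged sketch: concat_pad packs two token sequences (lengths 3 and 4) into a single
# zero-padded buffer of max_seq_length rows, as used by pack_selected_samples above.
def _example_concat_pad() -> torch.Tensor:
    a, b = torch.ones(3, 2), torch.full((4, 2), 2.0)
    return concat_pad([a, b], max_seq_length=10)  # rows 0-2 ones, 3-6 twos, rest zeros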


pos_id_3d = PosID3D()


def cook_raw_images(sample: dict) -> dict:
    """
    Processes a raw sample dictionary from energon dataset and returns a new dictionary with specific keys.

    Args:
        sample (dict): The input dictionary containing the raw sample data.

    Returns:
        dict: A new dictionary containing the processed sample data with the following keys:
            - All keys from the result of `basic_sample_keys(sample)`
            - 'images': The original images (from the 'jpg' field)
            - 'hint': The control images (from the 'png' field)
            - 'txt': The raw text
    """
    return dict(
        **basic_sample_keys(sample),
        images=sample['jpg'],
        hint=sample['png'],
        txt=sample['txt'],
    )


class RawImageDiffusionTaskEncoder(DefaultTaskEncoder, IOMixin):
    """
    Dummy task encoder that takes raw image input from a CrudeDataset.
    """

    cookers = [
        # Cooker(cook),
        Cooker(cook_raw_images),
    ]