File size: 20,510 Bytes
bc8c4af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
import sys
import torch, types
from PIL import Image
from typing import Optional, Union
from einops import rearrange
import numpy as np
from PIL import Image
from tqdm import tqdm
from typing import Optional

from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit

from ..models.wan_video_dit import WanModel, sinusoidal_embedding_1d, set_to_torch_norm
from ..models.wan_video_text_encoder import WanTextEncoder, HuggingfaceTokenizer
from ..models.wan_video_vae import WanVideoVAE
from ..models.mova_audio_dit import MovaAudioDit
from ..models.mova_audio_vae import DacVAE
from ..models.mova_dual_tower_bridge import DualTowerConditionalBridge
from ..utils.data.audio import convert_to_mono, resample_waveform


class MovaAudioVideoPipeline(BasePipeline):
    """Pipeline that denoises video latents and audio latents jointly.

    The pipeline holds a text encoder, one or two video DiTs (``video_dit``
    and the optional ``video_dit2``), an audio DiT, a dual-tower bridge, and
    one VAE per modality. Inputs are prepared by the ``units`` list and the
    per-step prediction is computed by ``model_fn_mova_audio_video``.
    """

    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
        super().__init__(
            device=device, torch_dtype=torch_dtype,
            height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1
        )
        self.scheduler = FlowMatchScheduler("Wan")
        self.tokenizer: HuggingfaceTokenizer = None
        self.text_encoder: WanTextEncoder = None
        self.video_dit: WanModel = None # high noise model
        self.video_dit2: WanModel = None # low noise model
        self.audio_dit: MovaAudioDit = None
        self.dual_tower_bridge: DualTowerConditionalBridge = None
        self.video_vae: WanVideoVAE = None
        self.audio_vae: DacVAE = None

        # Attribute names loaded onto the device for the denoising loop.
        # The second tuple is used once the loop has switched to ``video_dit2``.
        self.in_iteration_models = ("video_dit", "audio_dit", "dual_tower_bridge")
        self.in_iteration_models_2 = ("video_dit2", "audio_dit", "dual_tower_bridge")

        # Units run in this order inside ``__call__`` before denoising starts.
        self.units = [
            MovaAudioVideoUnit_ShapeChecker(),
            MovaAudioVideoUnit_NoiseInitializer(),
            MovaAudioVideoUnit_InputVideoEmbedder(),
            MovaAudioVideoUnit_InputAudioEmbedder(),
            MovaAudioVideoUnit_PromptEmbedder(),
            MovaAudioVideoUnit_ImageEmbedderVAE(),
            MovaAudioVideoUnit_UnifiedSequenceParallel(),
        ]
        self.model_fn = model_fn_mova_audio_video
        self.compilable_models = ["video_dit", "video_dit2", "audio_dit"]

    def enable_usp(self):
        """Patch self-attention of every loaded DiT for unified sequence parallel.

        Each block's ``self_attn.forward`` is rebound to ``usp_attn_forward``.
        Models that are not loaded (``None``) are skipped; in particular
        ``video_dit2`` is optional, and ``from_pretrained`` leaves it as
        ``None`` when only one video DiT is available.
        """
        from ..utils.xfuser import get_sequence_parallel_world_size, usp_attn_forward
        for model in (self.video_dit, self.audio_dit, self.video_dit2):
            if model is None:
                continue
            for block in model.blocks:
                block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
        self.sp_size = get_sequence_parallel_world_size()
        self.use_unified_sequence_parallel = True

    @staticmethod
    def from_pretrained(
        torch_dtype: torch.dtype = torch.bfloat16,
        device: Union[str, torch.device] = get_device_type(),
        model_configs: list[ModelConfig] = [],
        tokenizer_config: ModelConfig = ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="tokenizer/"),
        use_usp: bool = False,
        vram_limit: float = None,
    ):
        """Build a pipeline and populate it from the given model configs.

        Args:
            torch_dtype: Dtype handed to the pipeline constructor.
            device: Target device. When ``use_usp`` is set and
                ``torch.distributed`` is initialized, it is replaced by
                ``get_device_name()``.
            model_configs: Model configs forwarded to
                ``download_and_load_models``.
            tokenizer_config: Location of the tokenizer files; ``None`` skips
                tokenizer creation.
            use_usp: Initialize and enable unified sequence parallel.
            vram_limit: Forwarded to ``download_and_load_models``.

        Returns:
            The configured ``MovaAudioVideoPipeline``.
        """
        if use_usp:
            from ..utils.xfuser import initialize_usp
            initialize_usp(device)
            import torch.distributed as dist
            from ..core.device.npu_compatible_device import get_device_name
            if dist.is_available() and dist.is_initialized():
                device = get_device_name()
        # Initialize pipeline
        pipe = MovaAudioVideoPipeline(device=device, torch_dtype=torch_dtype)
        model_pool = pipe.download_and_load_models(model_configs, vram_limit)

        # Fetch models
        pipe.text_encoder = model_pool.fetch_model("wan_video_text_encoder")
        dit = model_pool.fetch_model("wan_video_dit", index=2)
        if isinstance(dit, list):
            pipe.video_dit, pipe.video_dit2 = dit
        else:
            # Only one video DiT was returned; ``video_dit2`` stays ``None``.
            pipe.video_dit = dit
        pipe.audio_dit = model_pool.fetch_model("mova_audio_dit")
        pipe.dual_tower_bridge = model_pool.fetch_model("mova_dual_tower_bridge")
        pipe.video_vae = model_pool.fetch_model("wan_video_vae")
        pipe.audio_vae = model_pool.fetch_model("mova_audio_vae")
        set_to_torch_norm([pipe.video_dit, pipe.audio_dit, pipe.dual_tower_bridge] + ([pipe.video_dit2] if pipe.video_dit2 is not None else []))

        # Size division factor
        if pipe.video_vae is not None:
            pipe.height_division_factor = pipe.video_vae.upsampling_factor * 2
            pipe.width_division_factor = pipe.video_vae.upsampling_factor * 2

        # Initialize tokenizer and processor
        if tokenizer_config is not None:
            tokenizer_config.download_if_necessary()
            pipe.tokenizer = HuggingfaceTokenizer(name=tokenizer_config.path, seq_len=512, clean='whitespace')

        # Unified Sequence Parallel
        if use_usp: pipe.enable_usp()

        # VRAM Management
        pipe.vram_management_enabled = pipe.check_vram_management_state()
        return pipe

    @torch.no_grad()
    def __call__(
        self,
        # Prompt
        prompt: str,
        negative_prompt: Optional[str] = "",
        # Image-to-video
        input_image: Optional[Image.Image] = None,
        # First-last-frame-to-video
        end_image: Optional[Image.Image] = None,
        # Video-to-video
        denoising_strength: Optional[float] = 1.0,
        # Randomness
        seed: Optional[int] = None,
        rand_device: Optional[str] = "cpu",
        # Shape
        height: Optional[int] = 352,
        width: Optional[int] = 640,
        num_frames: Optional[int] = 81,
        frame_rate: Optional[int] = 24,
        # Classifier-free guidance
        cfg_scale: Optional[float] = 5.0,
        # Boundary
        switch_DiT_boundary: Optional[float] = 0.9,
        # Scheduler
        num_inference_steps: Optional[int] = 50,
        sigma_shift: Optional[float] = 5.0,
        # VAE tiling
        tiled: Optional[bool] = True,
        tile_size: Optional[tuple[int, int]] = (30, 52),
        tile_stride: Optional[tuple[int, int]] = (15, 26),
        # progress_bar
        progress_bar_cmd=tqdm,
    ):
        """Generate a video and its audio track from a text prompt.

        Runs every unit in ``self.units`` to prepare the inputs, iterates the
        scheduler timesteps with classifier-free guidance, then decodes both
        latents with their VAEs.

        ``switch_DiT_boundary`` controls the hand-over between the two video
        DiTs: once ``timestep < switch_DiT_boundary * 1000`` and a second DiT
        is loaded, ``video_dit2`` replaces ``video_dit`` for the remaining
        steps.

        Returns:
            ``(video, audio)`` as produced by ``vae_output_to_video`` and
            ``output_audio_format_check``.
        """
        # Scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)

        # Inputs
        inputs_posi = {
            "prompt": prompt,
        }
        inputs_nega = {
            "negative_prompt": negative_prompt,
        }
        inputs_shared = {
            "input_image": input_image,
            "end_image": end_image,
            "denoising_strength": denoising_strength,
            "seed": seed, "rand_device": rand_device,
            "height": height, "width": width, "num_frames": num_frames, "frame_rate": frame_rate,
            "cfg_scale": cfg_scale,
            "sigma_shift": sigma_shift,
            "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
        }
        for unit in self.units:
            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)

        # Denoise
        self.load_models_to_device(self.in_iteration_models)
        models = {name: getattr(self, name) for name in self.in_iteration_models}
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            # Switch DiT if necessary (done at most once: afterwards the
            # "video_dit" slot already holds ``video_dit2``).
            if timestep.item() < switch_DiT_boundary * 1000 and self.video_dit2 is not None and models["video_dit"] is not self.video_dit2:
                self.load_models_to_device(self.in_iteration_models_2)
                models["video_dit"] = self.video_dit2
            # Timestep
            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
            noise_pred_video, noise_pred_audio = self.cfg_guided_model_fn(
                self.model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )
            # Scheduler
            inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id, noise_pred=noise_pred_video, **inputs_shared)
            inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id, noise_pred=noise_pred_audio, **inputs_shared)

        # Decode
        self.load_models_to_device(['video_vae'])
        video = self.video_vae.decode(inputs_shared["video_latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        video = self.vae_output_to_video(video)
        self.load_models_to_device(["audio_vae"])
        audio = self.audio_vae.decode(inputs_shared["audio_latents"])
        audio = self.output_audio_format_check(audio)
        self.load_models_to_device([])
        return video, audio


class MovaAudioVideoUnit_ShapeChecker(PipelineUnit):
    """Unit that passes the requested shape through the pipeline's shape check.

    ``height``, ``width`` and ``num_frames`` are handed to
    ``pipe.check_resize_height_width`` and the returned values are written
    back under the same keys.
    """

    def __init__(self):
        shape_keys = ("height", "width", "num_frames")
        super().__init__(input_params=shape_keys, output_params=shape_keys)

    def process(self, pipe: MovaAudioVideoPipeline, height, width, num_frames):
        checked = pipe.check_resize_height_width(height, width, num_frames)
        checked_height, checked_width, checked_frames = checked
        return {
            "height": checked_height,
            "width": checked_width,
            "num_frames": checked_frames,
        }


class MovaAudioVideoUnit_NoiseInitializer(PipelineUnit):
    """Unit that draws the initial noise tensors for video and audio latents.

    Both tensors come from ``pipe.generate_noise`` with the same ``seed`` and
    ``rand_device``; the video noise is drawn first.
    """

    def __init__(self):
        super().__init__(
            input_params=("height", "width", "num_frames", "seed", "rand_device", "frame_rate"),
            output_params=("video_noise", "audio_noise")
        )

    def process(self, pipe: MovaAudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate):
        # Video latent grid: (num_frames - 1) // 4 + 1 temporal steps, and
        # spatial dims divided by the VAE's upsampling factor.
        video_vae = pipe.video_vae
        latent_frames = (num_frames - 1) // 4 + 1
        latent_height = height // video_vae.upsampling_factor
        latent_width = width // video_vae.upsampling_factor
        video_noise = pipe.generate_noise(
            (1, video_vae.model.z_dim, latent_frames, latent_height, latent_width),
            seed=seed,
            rand_device=rand_device,
        )

        # Audio latent length: waveform samples covering the clip duration,
        # ceil-divided by the audio VAE's hop length.
        audio_vae = pipe.audio_vae
        waveform_samples = int(audio_vae.sample_rate * num_frames / frame_rate)
        latent_steps = (waveform_samples - 1) // int(audio_vae.hop_length) + 1
        audio_noise = pipe.generate_noise(
            (1, audio_vae.latent_dim, latent_steps),
            seed=seed,
            rand_device=rand_device,
        )

        return {"video_noise": video_noise, "audio_noise": audio_noise}


class MovaAudioVideoUnit_InputVideoEmbedder(PipelineUnit):
    """Unit that chooses the starting video latents.

    Without an input video, or when the scheduler is not in training mode,
    the noise is used directly as ``video_latents``. Otherwise the input
    video is VAE-encoded and returned as ``input_latents``.
    """

    def __init__(self):
        super().__init__(
            input_params=("input_video", "video_noise", "tiled", "tile_size", "tile_stride"),
            output_params=("video_latents", "input_latents"),
            onload_model_names=("video_vae",)
        )

    def process(self, pipe: MovaAudioVideoPipeline, input_video, video_noise, tiled, tile_size, tile_stride):
        should_encode = input_video is not None and pipe.scheduler.training
        if not should_encode:
            return {"video_latents": video_noise}

        pipe.load_models_to_device(self.onload_model_names)
        frames = pipe.preprocess_video(input_video)
        encoded = pipe.video_vae.encode(
            frames,
            device=pipe.device,
            tiled=tiled,
            tile_size=tile_size,
            tile_stride=tile_stride,
        )
        input_latents = encoded.to(dtype=pipe.torch_dtype, device=pipe.device)
        return {"input_latents": input_latents}


class MovaAudioVideoUnit_InputAudioEmbedder(PipelineUnit):
    """Unit that chooses the starting audio latents.

    Without input audio, or when the scheduler is not in training mode, the
    noise is used directly as ``audio_latents``. Otherwise the waveform is
    converted to mono, resampled to the audio VAE's sample rate, encoded,
    and the mode of the returned distribution is emitted as
    ``audio_input_latents``.
    """

    def __init__(self):
        super().__init__(
            input_params=("input_audio", "audio_noise"),
            output_params=("audio_latents", "audio_input_latents"),
            onload_model_names=("audio_vae",)
        )

    def process(self, pipe: MovaAudioVideoPipeline, input_audio, audio_noise):
        should_encode = input_audio is not None and pipe.scheduler.training
        if not should_encode:
            return {"audio_latents": audio_noise}

        pipe.load_models_to_device(self.onload_model_names)
        # ``input_audio`` is a (waveform, sample_rate) pair.
        waveform, source_rate = input_audio
        waveform = convert_to_mono(waveform)
        waveform = resample_waveform(waveform, source_rate, pipe.audio_vae.sample_rate)
        waveform = pipe.audio_vae.preprocess(waveform.unsqueeze(0), pipe.audio_vae.sample_rate)
        # ``encode`` returns five values; only the first one is used.
        posterior, _, _, _, _ = pipe.audio_vae.encode(waveform)
        return {"audio_input_latents": posterior.mode()}


class MovaAudioVideoUnit_PromptEmbedder(PipelineUnit):
    """Unit that turns the positive and negative prompt into ``context``.

    It runs separately for the positive and the negative branch: the
    ``prompt`` argument of ``process`` is read from ``"prompt"`` and
    ``"negative_prompt"`` respectively.
    """

    def __init__(self):
        super().__init__(
            seperate_cfg=True,
            input_params_posi={"prompt": "prompt"},
            input_params_nega={"prompt": "negative_prompt"},
            output_params=("context",),
            onload_model_names=("text_encoder",)
        )

    def encode_prompt(self, pipe: MovaAudioVideoPipeline, prompt):
        """Tokenize ``prompt``, encode it, and zero the padded positions.

        Args:
            pipe: Pipeline providing ``tokenizer``, ``text_encoder`` and
                ``device``.
            prompt: Text passed to the tokenizer.

        Returns:
            The text-encoder output with every position at or beyond each
            sample's own token count set to zero.
        """
        ids, mask = pipe.tokenizer(
            prompt,
            padding="max_length",
            max_length=512,
            truncation=True,
            add_special_tokens=True,
            return_mask=True,
            return_tensors="pt",
        )
        ids = ids.to(pipe.device)
        mask = mask.to(pipe.device)
        # Number of non-padding tokens per sample.
        seq_lens = mask.gt(0).sum(dim=1).long()
        prompt_emb = pipe.text_encoder(ids, mask)
        # Zero each sample's padded tail using that sample's own length.
        # Indexing by ``i`` matters for batched prompts: slicing every row
        # with ``[:, v:]`` would cut all rows down to the shortest prompt.
        for i, v in enumerate(seq_lens):
            prompt_emb[i, v:] = 0
        return prompt_emb

    def process(self, pipe: MovaAudioVideoPipeline, prompt) -> dict:
        """Load the text encoder and return the prompt embedding as ``context``."""
        pipe.load_models_to_device(self.onload_model_names)
        prompt_emb = self.encode_prompt(pipe, prompt)
        return {"context": prompt_emb}


class MovaAudioVideoUnit_ImageEmbedderVAE(PipelineUnit):
    """Unit that builds the image-conditioning tensor ``y``.

    The first frame (and optionally the last frame) is placed into an
    otherwise all-zero clip, VAE-encoded, and stacked under a mask that
    marks which frames were supplied. Nothing is emitted when there is no
    input image or the video DiT does not require a VAE embedding.
    """

    def __init__(self):
        super().__init__(
            input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
            output_params=("y",),
            onload_model_names=("video_vae",)
        )

    def process(self, pipe: MovaAudioVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride):
        if input_image is None or not pipe.video_dit.require_vae_embedding:
            return {}
        pipe.load_models_to_device(self.onload_model_names)

        device = pipe.device
        mask_h, mask_w = height // 8, width // 8

        first_frame = pipe.preprocess_image(input_image.resize((width, height))).to(device)

        # Per-frame mask: 1 for supplied frames, 0 for frames to be generated.
        frame_mask = torch.ones(1, num_frames, mask_h, mask_w, device=device)
        frame_mask[:, 1:] = 0

        if end_image is None:
            blank = torch.zeros(3, num_frames - 1, height, width).to(first_frame.device)
            clip_parts = [first_frame.transpose(0, 1), blank]
        else:
            last_frame = pipe.preprocess_image(end_image.resize((width, height))).to(device)
            blank = torch.zeros(3, num_frames - 2, height, width).to(first_frame.device)
            clip_parts = [first_frame.transpose(0, 1), blank, last_frame.transpose(0, 1)]
            frame_mask[:, -1:] = 1
        vae_input = torch.concat(clip_parts, dim=1)

        # Repeat the first frame's mask four times, then fold the frame axis
        # into groups of four that become the leading (channel) axis.
        leading = torch.repeat_interleave(frame_mask[:, 0:1], repeats=4, dim=1)
        frame_mask = torch.concat([leading, frame_mask[:, 1:]], dim=1)
        frame_mask = frame_mask.view(1, frame_mask.shape[1] // 4, 4, mask_h, mask_w)
        frame_mask = frame_mask.transpose(1, 2)[0]

        vae_input = vae_input.to(dtype=pipe.torch_dtype, device=device)
        encoded = pipe.video_vae.encode(
            [vae_input],
            device=device,
            tiled=tiled,
            tile_size=tile_size,
            tile_stride=tile_stride,
        )[0]
        encoded = encoded.to(dtype=pipe.torch_dtype, device=device)

        # Mask channels first, then the encoded latents; add a batch axis.
        y = torch.concat([frame_mask, encoded]).unsqueeze(0)
        y = y.to(dtype=pipe.torch_dtype, device=device)
        return {"y": y}


class MovaAudioVideoUnit_UnifiedSequenceParallel(PipelineUnit):
    """Unit that reports whether unified sequence parallel is enabled.

    The flag is read from the pipeline and exposed to the model function as
    a plain boolean; a missing attribute counts as disabled.
    """

    def __init__(self):
        super().__init__(input_params=(), output_params=("use_unified_sequence_parallel",))

    def process(self, pipe: MovaAudioVideoPipeline):
        usp_enabled = bool(getattr(pipe, "use_unified_sequence_parallel", False))
        return {"use_unified_sequence_parallel": usp_enabled}


def model_fn_mova_audio_video(
    video_dit: WanModel,
    audio_dit: MovaAudioDit,
    dual_tower_bridge: DualTowerConditionalBridge,
    video_latents: torch.Tensor = None,
    audio_latents: torch.Tensor = None,
    timestep: torch.Tensor = None,
    context: torch.Tensor = None,
    y: Optional[torch.Tensor] = None,
    frame_rate: Optional[int] = 24,
    use_unified_sequence_parallel: bool = False,
    use_gradient_checkpointing: bool = False,
    use_gradient_checkpointing_offload: bool = False,
    **kwargs,
):
    """Run one forward pass of the video and audio DiT towers.

    Both towers are embedded, then their blocks are executed in lockstep for
    the first ``len(audio_dit.blocks)`` block indices. Before a block pair,
    ``dual_tower_bridge`` is applied to both hidden states whenever
    ``should_interact(block_id, "a2v")`` is true. Any remaining video blocks
    run afterwards on their own, and finally each tower's head and
    ``unpatchify`` are applied.

    Args:
        video_dit: Video tower.
        audio_dit: Audio tower.
        dual_tower_bridge: Module that updates both hidden states at the
            block indices it selects.
        video_latents: Input to the video tower.
        audio_latents: Input to the audio tower.
        timestep: Timestep tensor; embedded separately by each tower.
        context: Text embedding; projected separately by each tower.
        y: Optional conditioning tensor, concatenated to ``video_latents``
            along dim 1 when given.
        frame_rate: Passed to the bridge as ``video_fps``.
        use_unified_sequence_parallel: When true, the chunk/gather helpers
            from ``utils.xfuser`` wrap every block call; otherwise they are
            identity functions.
        use_gradient_checkpointing: Forwarded to ``gradient_checkpoint_forward``
            and to the bridge.
        use_gradient_checkpointing_offload: Forwarded likewise.
        **kwargs: Accepted and ignored.

    Returns:
        ``(video_x, audio_x)``: the two towers' outputs after head and
        unpatchify.
    """
    video_x, audio_x = video_latents, audio_latents
    # First-Last Frame
    if y is not None:
        video_x = torch.cat([video_x, y], dim=1)

    # Timestep
    # The same timestep is embedded by each tower; ``*_t`` feeds the head and
    # ``*_t_mod`` (unflattened to 6 x dim) feeds the blocks.
    video_t = video_dit.time_embedding(sinusoidal_embedding_1d(video_dit.freq_dim, timestep))
    video_t_mod = video_dit.time_projection(video_t).unflatten(1, (6, video_dit.dim))
    audio_t = audio_dit.time_embedding(sinusoidal_embedding_1d(audio_dit.freq_dim, timestep))
    audio_t_mod = audio_dit.time_projection(audio_t).unflatten(1, (6, audio_dit.dim))

    # Context
    video_context = video_dit.text_embedding(context)
    audio_context = audio_dit.text_embedding(context)

    # Patchify
    # Video: (b, c, f, h, w) -> (b, f*h*w, c). The full sequence length is kept
    # so the chunks can be gathered back to it later.
    video_x = video_dit.patch_embedding(video_x)
    f_v, h, w = video_x.shape[2:]
    video_x = rearrange(video_x, 'b c f h w -> b (f h w) c').contiguous()
    seq_len_video = video_x.shape[1]

    # Audio: (b, c, f) -> (b, f, c).
    audio_x = audio_dit.patch_embedding(audio_x)
    f_a = audio_x.shape[2]
    audio_x = rearrange(audio_x, 'b c f -> b f c').contiguous()
    seq_len_audio = audio_x.shape[1]

    # Freqs
    # Video positional frequencies: the three tables are sliced by frame,
    # height and width, broadcast over the grid and concatenated on the last dim.
    video_freqs = torch.cat([
        video_dit.freqs[0][:f_v].view(f_v, 1, 1, -1).expand(f_v, h, w, -1),
        video_dit.freqs[1][:h].view(1, h, 1, -1).expand(f_v, h, w, -1),
        video_dit.freqs[2][:w].view(1, 1, w, -1).expand(f_v, h, w, -1)
    ], dim=-1).reshape(f_v * h * w, 1, -1).to(video_x.device)
    # Audio positional frequencies: all three tables are sliced by the same
    # audio length ``f_a``.
    audio_freqs = torch.cat([
        audio_dit.freqs[0][:f_a].view(f_a, -1).expand(f_a, -1),
        audio_dit.freqs[1][:f_a].view(f_a, -1).expand(f_a, -1),
        audio_dit.freqs[2][:f_a].view(f_a, -1).expand(f_a, -1),
    ], dim=-1).reshape(f_a, 1, -1).to(audio_x.device)

    # Frequencies used only by the bridge (passed as x_freqs / y_freqs below).
    video_rope, audio_rope = dual_tower_bridge.build_aligned_freqs(
        video_fps=frame_rate,
        grid_size=(f_v, h, w),
        audio_steps=audio_x.shape[1],
        device=video_x.device,
        dtype=video_x.dtype,
    )
    # usp func
    if use_unified_sequence_parallel:
        from ..utils.xfuser import get_current_chunk, gather_all_chunks
    else:
        # Identity stand-ins so the block loop below is the same in both modes.
        get_current_chunk = lambda x, dim=1: x
        gather_all_chunks = lambda x, seq_len, dim=1: x
    # Forward blocks
    # The paired loop covers the audio tower's block count; video blocks beyond
    # that run in the second loop.
    for block_id in range(len(audio_dit.blocks)):
        # The bridge receives the full (gathered) sequences of both towers.
        if dual_tower_bridge.should_interact(block_id, "a2v"):
            video_x, audio_x = dual_tower_bridge(
                block_id,
                video_x,
                audio_x,
                x_freqs=video_rope,
                y_freqs=audio_rope,
                condition_scale=1.0,
                video_grid_size=(f_v, h, w),
                use_gradient_checkpointing=use_gradient_checkpointing,
                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
            )
        # Chunk -> block -> gather for each tower, so the next bridge call sees
        # full sequences again.
        video_x = get_current_chunk(video_x, dim=1)
        video_x = gradient_checkpoint_forward(
            video_dit.blocks[block_id],
            use_gradient_checkpointing,
            use_gradient_checkpointing_offload,
            video_x, video_context, video_t_mod, video_freqs
        )
        video_x = gather_all_chunks(video_x, seq_len=seq_len_video, dim=1)
        audio_x = get_current_chunk(audio_x, dim=1)
        audio_x = gradient_checkpoint_forward(
            audio_dit.blocks[block_id],
            use_gradient_checkpointing,
            use_gradient_checkpointing_offload,
            audio_x, audio_context, audio_t_mod, audio_freqs
        )
        audio_x = gather_all_chunks(audio_x, seq_len=seq_len_audio, dim=1)

    # Remaining video-only blocks: no bridge calls here, so the sequence is
    # chunked once before the loop and gathered once after it.
    video_x = get_current_chunk(video_x, dim=1)
    for block_id in range(len(audio_dit.blocks), len(video_dit.blocks)):
        video_x = gradient_checkpoint_forward(
            video_dit.blocks[block_id],
            use_gradient_checkpointing,
            use_gradient_checkpointing_offload,
            video_x, video_context, video_t_mod, video_freqs
        )
    video_x = gather_all_chunks(video_x, seq_len=seq_len_video, dim=1)

    # Head
    video_x = video_dit.head(video_x, video_t)
    video_x = video_dit.unpatchify(video_x, (f_v, h, w))

    audio_x = audio_dit.head(audio_x, audio_t)
    audio_x = audio_dit.unpatchify(audio_x, (f_a,))
    return video_x, audio_x