File size: 38,108 Bytes
ebfc6b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
"""Validation sampling for LTX-2 training using ltx-core components.

This module provides a simplified validation pipeline for generating samples during training,
using the new ltx-core components (VideoLatentTools, AudioLatentTools, LatentState, etc.).
"""

from dataclasses import dataclass, replace
from typing import TYPE_CHECKING, Literal

import torch
from einops import rearrange
from torch import Tensor

from ltx_core.guidance.perturbations import (
    BatchedPerturbationConfig,
    Perturbation,
    PerturbationConfig,
    PerturbationType,
)
from ltx_core.model.transformer.modality import Modality
from ltx_core.model.transformer.model import X0Model
from ltx_core.pipeline.components.diffusion_steps import EulerDiffusionStep
from ltx_core.pipeline.components.guiders import CFGGuider, STGGuider
from ltx_core.pipeline.components.noisers import GaussianNoiser
from ltx_core.pipeline.components.patchifiers import (
    AudioLatentShape,
    AudioPatchifier,
    VideoLatentPatchifier,
    VideoLatentShape,
    get_pixel_coords,
)
from ltx_core.pipeline.components.protocols import VideoPixelShape
from ltx_core.pipeline.components.schedulers import LTX2Scheduler
from ltx_core.pipeline.conditioning.tools import AudioLatentTools, LatentState, VideoLatentTools
from ltx_core.tiling import SpatialTilingConfig, TemporalTilingConfig, TilingConfig
from ltx_trainer.progress import SamplingContext

if TYPE_CHECKING:
    from ltx_core.model.audio_vae.audio_vae import Decoder as AudioDecoder
    from ltx_core.model.audio_vae.vocoder import Vocoder
    from ltx_core.model.clip.gemma.encoders.av_encoder import AVGemmaTextEncoderModel
    from ltx_core.model.transformer.model import LTXModel
    from ltx_core.model.video_vae.video_vae import Decoder as VideoDecoder
    from ltx_core.model.video_vae.video_vae import Encoder as VideoEncoder

# Video VAE scale factors (temporal, height, width): pixel-to-latent ratio along each axis.
# Passed as `scale_factors` to VideoLatentTools and get_pixel_coords, and used to derive the
# latent frame/height/width counts from the requested pixel dimensions when decoding.
VIDEO_SCALE_FACTORS = (8, 32, 32)


@dataclass
class CachedPromptEmbeddings:
    """Pre-computed text embeddings for a validation prompt.

    These embeddings are computed once at training start and reused for all validation runs,
    avoiding the need to load the full Gemma text encoder during validation.

    The positive contexts are required; the negative contexts are optional and default to None.
    """

    video_context_positive: Tensor  # [1, seq_len, hidden_dim]
    audio_context_positive: Tensor  # [1, seq_len, hidden_dim]
    # Negative-prompt contexts. NOTE(review): presumably the same layout as the positive contexts and
    # consumed as the CFG negative branch - confirm against _get_prompt_embeddings.
    video_context_negative: Tensor | None = None
    audio_context_negative: Tensor | None = None


@dataclass
class TiledDecodingConfig:
    """Configuration for tiled video decoding to reduce VRAM usage.

    Tiled decoding splits the latent tensor into overlapping tiles, decodes each
    tile individually, and blends them together. This significantly reduces peak
    VRAM usage at the cost of slightly slower decoding.

    Defaults match the recommended values from ltx-core tests.

    The size/overlap constraints noted on each field are documented only; this class does not
    validate them. The spatial fields are forwarded to ``SpatialTilingConfig`` and the temporal
    fields to ``TemporalTilingConfig`` when the video latent is decoded.
    """

    enabled: bool = True  # Whether to use tiled decoding (enabled by default)
    tile_size_pixels: int = 192  # Spatial tile size in pixels (must be ≥64 and divisible by 32)
    tile_overlap_pixels: int = 64  # Spatial tile overlap in pixels (must be divisible by 32)
    tile_size_frames: int = 48  # Temporal tile size in frames (must be ≥16 and divisible by 8)
    tile_overlap_frames: int = 24  # Temporal tile overlap in frames (must be divisible by 8)


@dataclass
class GenerationConfig:
    """Settings for a single validation generation (video, optionally with audio).

    ``tiled_decoding`` accepts three spellings at construction time and is normalized in
    ``__post_init__`` so that it always holds a ``TiledDecodingConfig`` afterwards:
    ``None`` becomes the default config, ``False`` becomes a config with ``enabled=False``,
    and an explicit ``TiledDecodingConfig`` instance is kept unchanged.
    """

    # --- Prompting ---
    prompt: str  # Text prompt for generation
    negative_prompt: str = ""  # Negative prompt to avoid unwanted artifacts

    # --- Output geometry and timing ---
    height: int = 544  # Output video height in pixels
    width: int = 960  # Output video width in pixels
    num_frames: int = 97  # Number of frames to generate
    frame_rate: float = 25.0  # Frame rate for temporal position scaling

    # --- Sampling ---
    num_inference_steps: int = 30  # Number of denoising steps
    guidance_scale: float = 3.0  # CFG guidance scale
    seed: int = 42  # Random seed for reproducibility

    # --- Conditioning inputs ---
    condition_image: Tensor | None = None  # Optional first frame image for image-to-video
    reference_video: Tensor | None = None  # For IC-LoRA: [F, C, H, W] in [0, 1]

    # --- Output options ---
    generate_audio: bool = True  # Whether to generate audio alongside video
    include_reference_in_output: bool = False  # For IC-LoRA: concatenate original reference with generated output

    # --- Text embeddings ---
    cached_embeddings: CachedPromptEmbeddings | None = None  # Pre-computed text embeddings (avoids loading Gemma)

    # --- STG (perturbation-based guidance) ---
    stg_scale: float = 0.0  # STG strength (0.0 = disabled, recommended: 1.0)
    stg_blocks: list[int] | None = None  # Transformer blocks to perturb (None = all, recommended: [29])
    stg_mode: Literal["stg_av", "stg_v"] = "stg_av"  # STG mode: "stg_av" (audio+video) or "stg_v" (video only)

    # --- Decoding ---
    # None = use defaults (tiling enabled), False = disable tiling, or a TiledDecodingConfig for custom settings
    tiled_decoding: TiledDecodingConfig | Literal[False] | None = None

    def __post_init__(self) -> None:
        """Normalize ``tiled_decoding`` into a concrete ``TiledDecodingConfig``."""
        requested = self.tiled_decoding
        if requested is None:
            # Nothing specified: fall back to the default (tiling enabled) configuration.
            self.tiled_decoding = TiledDecodingConfig()
        elif requested is False:
            # Explicit opt-out: keep a config object, but with tiling switched off.
            self.tiled_decoding = TiledDecodingConfig(enabled=False)


class ValidationSampler:
    """Generates validation samples during training using ltx-core components.

    This class provides a simplified interface for generating video (and optionally audio)
    samples during training validation. It supports:
    - Text-to-video generation
    - Image-to-video generation (first frame conditioning)
    - Video-to-video generation (IC-LoRA reference video conditioning)
    - Optional audio generation

    The implementation follows the patterns from ltx_pipelines.single_stage.

    Text embeddings can be provided either via:
    - A full text_encoder (encodes prompts on-the-fly)
    - Pre-computed cached_embeddings (avoids loading Gemma during validation)
    """

    def __init__(
        self,
        transformer: "LTXModel",
        vae_decoder: "VideoDecoder",
        vae_encoder: "VideoEncoder | None",
        text_encoder: "AVGemmaTextEncoderModel | None" = None,
        audio_decoder: "AudioDecoder | None" = None,
        vocoder: "Vocoder | None" = None,
        sampling_context: SamplingContext | None = None,
    ):
        """Store the model components and build the latent patchifiers.

        Args:
            transformer: LTX-2 transformer model.
            vae_decoder: Video VAE decoder.
            vae_encoder: Video VAE encoder, used for image/video conditioning; may be None when
                no conditioning is needed.
            text_encoder: Gemma text encoder with embeddings connector; optional when the
                generation config carries cached embeddings.
            audio_decoder: Optional audio VAE decoder (for audio generation).
            vocoder: Optional vocoder (for audio generation).
            sampling_context: Optional SamplingContext; when set, it is advanced once per
                denoising step to drive a progress display.
        """
        # Patchifiers for video and audio latents, both built with patch_size=1.
        self._video_patchifier = VideoLatentPatchifier(patch_size=1)
        self._audio_patchifier = AudioPatchifier(patch_size=1)

        # Video components.
        self._transformer = transformer
        self._vae_encoder = vae_encoder
        self._vae_decoder = vae_decoder

        # Text and audio components (all optional).
        self._text_encoder = text_encoder
        self._audio_decoder = audio_decoder
        self._vocoder = vocoder

        # Progress reporting hook.
        self._sampling_context = sampling_context

    # Note: @torch.no_grad() is used rather than @torch.inference_mode() to avoid FSDP inplace update
    # errors after validation
    @torch.no_grad()
    def generate(
        self,
        config: GenerationConfig,
        device: torch.device | str = "cuda",
    ) -> tuple[Tensor, Tensor | None]:
        """Produce one validation sample: a video and, optionally, audio.

        The config is validated first. A config carrying a ``reference_video`` is handled by the
        reference-conditioned path; every other config goes through the standard path.

        Args:
            config: Generation configuration
            device: Device to run generation on; a string is converted to ``torch.device``

        Returns:
            Tuple of:
                - video: Video tensor [C, F, H, W] in [0, 1] (float32)
                - audio: Audio waveform tensor [C, samples] or None
        """
        if isinstance(device, str):
            device = torch.device(device)
        self._validate_config(config)

        # No reference video -> plain text-to-video / image-to-video generation
        if config.reference_video is None:
            return self._generate_standard(config, device)
        return self._generate_with_reference(config, device)

    def _generate_standard(self, config: GenerationConfig, device: torch.device) -> tuple[Tensor, Tensor | None]:
        """Standard generation (text-to-video or image-to-video).

        Pipeline: fetch prompt embeddings, build clean initial latent states, optionally pin the
        first-frame tokens to the conditioning image, add noise, run the denoising loop, then
        unpatchify and decode the results.

        Args:
            config: Generation configuration (``reference_video`` is not used on this path)
            device: Device on which states are created and models are run

        Returns:
            Tuple of (decoded video, decoded audio or None when audio generation is disabled)
        """
        # Get prompt embeddings (from cache or encode on-the-fly)
        v_ctx_pos, a_ctx_pos, v_ctx_neg, a_ctx_neg = self._get_prompt_embeddings(config, device)

        # Setup generator
        generator = torch.Generator(device=device).manual_seed(config.seed)

        # Create latent tools
        video_tools = self._create_video_latent_tools(config)
        audio_tools = self._create_audio_latent_tools(config) if config.generate_audio else None

        # Create initial states
        video_clean_state = video_tools.create_initial_state(device=device, dtype=torch.bfloat16)
        audio_clean_state = (
            audio_tools.create_initial_state(device=device, dtype=torch.bfloat16) if audio_tools else None
        )

        # Apply image conditioning if provided
        if config.condition_image is not None:
            video_clean_state = self._apply_image_conditioning(
                video_clean_state, config.condition_image, config, device
            )

        # Add noise
        # A single noiser (one seeded generator) is used for video first, then audio, so the
        # call order here determines which random draws each modality receives.
        noiser = GaussianNoiser(generator=generator)
        video_state = noiser(latent_state=video_clean_state, noise_scale=1.0)
        audio_state = noiser(latent_state=audio_clean_state, noise_scale=1.0) if audio_clean_state else None

        # Run denoising loop
        video_state, audio_state = self._run_denoising(
            config=config,
            video_state=video_state,
            audio_state=audio_state,
            video_clean_state=video_clean_state,
            audio_clean_state=audio_clean_state,
            v_ctx_pos=v_ctx_pos,
            a_ctx_pos=a_ctx_pos,
            v_ctx_neg=v_ctx_neg,
            a_ctx_neg=a_ctx_neg,
            device=device,
        )

        # Decode outputs
        video_state = video_tools.clear_conditioning(video_state)
        video_state = video_tools.unpatchify(video_state)
        video_output = self._decode_video(video_state, device, config.tiled_decoding)

        # Audio is decoded only when it was generated (both the state and its tools exist)
        audio_output = None
        if audio_state is not None and audio_tools is not None:
            audio_state = audio_tools.clear_conditioning(audio_state)
            audio_state = audio_tools.unpatchify(audio_state)
            audio_output = self._decode_audio(audio_state, device)

        return video_output, audio_output

    def _generate_with_reference(self, config: GenerationConfig, device: torch.device) -> tuple[Tensor, Tensor | None]:
        """Generate with reference video conditioning (IC-LoRA style).

        For IC-LoRA:
        - Reference video latents are concatenated with target latents
        - Reference latents have timestep=0 (clean, not denoised)
        - Target latents are denoised normally
        - If condition_image is also provided, the first frame of the target is conditioned
        - If include_reference_in_output is True, the preprocessed reference video
          is concatenated side-by-side with the generated video

        Args:
            config: Generation configuration; ``reference_video`` must be set
            device: Device on which states are created and models are run

        Returns:
            Tuple of (decoded video, decoded audio or None when audio generation is disabled)
        """
        # Get prompt embeddings (from cache or encode on-the-fly)
        v_ctx_pos, a_ctx_pos, v_ctx_neg, a_ctx_neg = self._get_prompt_embeddings(config, device)

        # Setup generator
        generator = torch.Generator(device=device).manual_seed(config.seed)

        # Preprocess and encode reference video
        ref_video_preprocessed = self._preprocess_reference_video(config)
        ref_latent, ref_positions = self._encode_video(ref_video_preprocessed, config.frame_rate, device)
        # Number of reference tokens (dim 1 of the patchified latent); used later to split them off again
        ref_seq_len = ref_latent.shape[1]

        # Create target video state
        video_tools = self._create_video_latent_tools(config)
        target_clean_state = video_tools.create_initial_state(device=device, dtype=torch.bfloat16)

        # Apply first-frame image conditioning to target if provided
        if config.condition_image is not None:
            target_clean_state = self._apply_image_conditioning(
                target_clean_state, config.condition_image, config, device
            )

        # Create combined state (reference + target)
        # denoise_mask shape is [B, seq_len, 1] after patchification
        # Reference tokens come first; an all-zero mask marks them as "do not denoise".
        # Latents and masks are joined along dim 1, positions along dim 2.
        ref_denoise_mask = torch.zeros(1, ref_seq_len, 1, device=device, dtype=torch.float32)
        combined_clean_state = LatentState(
            latent=torch.cat([ref_latent, target_clean_state.latent], dim=1),
            denoise_mask=torch.cat([ref_denoise_mask, target_clean_state.denoise_mask], dim=1),
            positions=torch.cat([ref_positions, target_clean_state.positions], dim=2),
            clean_latent=torch.cat([ref_latent, target_clean_state.clean_latent], dim=1),
        )

        # Add noise (only to the target portion via denoise_mask)
        noiser = GaussianNoiser(generator=generator)
        combined_state = noiser(latent_state=combined_clean_state, noise_scale=1.0)

        # Create audio state if needed
        # The same noiser (and therefore the same seeded generator) is reused after the video draw.
        audio_tools = self._create_audio_latent_tools(config) if config.generate_audio else None
        audio_clean_state = (
            audio_tools.create_initial_state(device=device, dtype=torch.bfloat16) if audio_tools else None
        )
        audio_state = noiser(latent_state=audio_clean_state, noise_scale=1.0) if audio_clean_state else None

        # Run denoising loop
        combined_state, audio_state = self._run_denoising(
            config=config,
            video_state=combined_state,
            audio_state=audio_state,
            video_clean_state=combined_clean_state,
            audio_clean_state=audio_clean_state,
            v_ctx_pos=v_ctx_pos,
            a_ctx_pos=a_ctx_pos,
            v_ctx_neg=v_ctx_neg,
            a_ctx_neg=a_ctx_neg,
            device=device,
        )

        # Extract target portion and decode
        # (drops the leading ref_seq_len reference tokens that were prepended above)
        target_latent = combined_state.latent[:, ref_seq_len:]
        video_output = self._decode_video_latent(target_latent, config, device)

        # Optionally concatenate original reference video side-by-side
        if config.include_reference_in_output:
            # Use preprocessed reference (already resized/cropped, in pixel space)
            # Convert from [B, C, F, H, W] to [C, F, H, W]
            ref_video_pixels = ref_video_preprocessed[0].cpu()
            # Normalize from [-1, 1] to [0, 1]
            ref_video_pixels = ((ref_video_pixels + 1.0) / 2.0).clamp(0.0, 1.0)
            video_output = self._concatenate_videos_side_by_side(ref_video_pixels, video_output)

        # Decode audio
        audio_output = None
        if audio_state is not None and audio_tools is not None:
            audio_state = audio_tools.clear_conditioning(audio_state)
            audio_state = audio_tools.unpatchify(audio_state)
            audio_output = self._decode_audio(audio_state, device)

        return video_output, audio_output

    def _create_video_latent_tools(self, config: GenerationConfig) -> VideoLatentTools:
        """Build the VideoLatentTools matching the requested output size.

        The pixel-space request (batch of 1, frames, height, width, fps) is converted to a
        latent shape, which becomes the tools' target shape.
        """
        fps = config.frame_rate
        requested_pixels = VideoPixelShape(
            batch=1,
            frames=config.num_frames,
            height=config.height,
            width=config.width,
            fps=fps,
        )
        latent_shape = VideoLatentShape.from_pixel_shape(shape=requested_pixels)

        return VideoLatentTools(
            patchifier=self._video_patchifier,
            target_shape=latent_shape,
            fps=fps,
            scale_factors=VIDEO_SCALE_FACTORS,
            causal_fix=True,
        )

    def _create_audio_latent_tools(self, config: GenerationConfig) -> AudioLatentTools:
        """Build the AudioLatentTools sized to the clip duration (num_frames / frame_rate)."""
        clip_duration = config.num_frames / config.frame_rate
        audio_shape = AudioLatentShape.from_duration(batch=1, duration=clip_duration)
        return AudioLatentTools(patchifier=self._audio_patchifier, target_shape=audio_shape)

    def _apply_image_conditioning(
        self, video_state: LatentState, image: Tensor, config: GenerationConfig, device: torch.device
    ) -> LatentState:
        """Pin the leading tokens of the video state to an encoded conditioning image.

        The image is encoded and patchified; the resulting tokens overwrite the same number of
        leading tokens (dim 1) in both ``latent`` and ``clean_latent``, and the denoise mask is
        zeroed over that range so those tokens are not denoised. The input state is not
        modified; a new LatentState sharing the original ``positions`` is returned.
        """
        image_latent = self._encode_conditioning_image(image, config.height, config.width, device)
        image_tokens = self._video_patchifier.patchify(image_latent)
        # Token range (along dim 1) occupied by the conditioning image
        image_span = slice(0, image_tokens.shape[1])

        conditioned_latent = video_state.latent.clone()
        conditioned_latent[:, image_span] = image_tokens.to(conditioned_latent.dtype)

        # The conditioning image is already clean, so it is mirrored into clean_latent too
        conditioned_clean = video_state.clean_latent.clone()
        conditioned_clean[:, image_span] = image_tokens.to(conditioned_clean.dtype)

        # Mask value 0 => leave these tokens untouched during denoising
        conditioned_mask = video_state.denoise_mask.clone()
        conditioned_mask[:, image_span] = 0.0

        return LatentState(
            latent=conditioned_latent,
            denoise_mask=conditioned_mask,
            positions=video_state.positions,
            clean_latent=conditioned_clean,
        )

    @staticmethod
    def _preprocess_reference_video(config: GenerationConfig) -> Tensor:
        """Preprocess reference video: resize, crop, and convert to model input format.

        When the reference size differs from the target, the video is resized with bilinear
        interpolation so that it covers the target while keeping its aspect ratio, then
        center-cropped to exactly ``config.height`` x ``config.width``. The frame count is
        trimmed to the largest value of the form ``k * 8 + 1`` and values are mapped from
        [0, 1] to [-1, 1].

        Args:
            config: Generation configuration with reference_video ([F, C, H, W] in [0, 1])

        Returns:
            Preprocessed video tensor [B, C, F, H, W] in [-1, 1] range
        """
        ref_video = config.reference_video  # [F, C, H, W] in [0, 1]
        target_height, target_width = config.height, config.width
        current_height, current_width = ref_video.shape[2:]

        # Resize maintaining aspect ratio and center crop if needed
        if current_height != target_height or current_width != target_width:
            aspect_ratio = current_width / current_height
            target_aspect_ratio = target_width / target_height

            # The resized frame must cover the target in both dimensions. The float maths followed
            # by int() truncation can land one pixel short of the target (e.g. a 2112x1920
            # reference with a 1056x960 target: int(1056 / 1.1) is 959). That would make the crop
            # start negative and silently yield a wrongly-sized crop, so clamp to the target size.
            if aspect_ratio > target_aspect_ratio:
                resize_height = target_height
                resize_width = max(target_width, int(target_height * aspect_ratio))
            else:
                resize_height = max(target_height, int(target_width / aspect_ratio))
                resize_width = target_width

            ref_video = torch.nn.functional.interpolate(
                ref_video, size=(resize_height, resize_width), mode="bilinear", align_corners=False
            )

            # Center crop (starts are guaranteed non-negative thanks to the clamp above)
            h_start = (resize_height - target_height) // 2
            w_start = (resize_width - target_width) // 2
            ref_video = ref_video[:, :, h_start : h_start + target_height, w_start : w_start + target_width]

        # Convert to [B, C, F, H, W] and trim to valid frame count (k*8 + 1)
        ref_video = rearrange(ref_video, "f c h w -> 1 c f h w")
        valid_frames = (ref_video.shape[2] - 1) // 8 * 8 + 1
        ref_video = ref_video[:, :, :valid_frames]

        # Convert to [-1, 1] range
        return ref_video * 2.0 - 1.0

    def _encode_video(self, video: Tensor, fps: float, device: torch.device) -> tuple[Tensor, Tensor]:
        """Encode video to patchified latents and compute positions.

        The VAE encoder is moved to ``device`` for the encode and is always moved back to the
        CPU afterwards, even if encoding raises, so a failed validation run does not leave it
        occupying accelerator memory.

        Args:
            video: Video tensor [B, C, F, H, W] in [-1, 1] range
            fps: Frame rate for temporal position scaling
            device: Device to run encoding on

        Returns:
            Tuple of (patchified_latents, positions), both in bfloat16

        Raises:
            ValueError: If the sampler was constructed without a video VAE encoder.
        """
        if self._vae_encoder is None:
            raise ValueError("A video VAE encoder is required to encode a video, but vae_encoder is None")

        video = video.to(device=device, dtype=torch.float32)

        # Encode with VAE
        self._vae_encoder.to(device)
        try:
            # torch.device(...).type gives the bare device kind ("cuda", "cpu", ...) without an index
            with torch.autocast(device_type=torch.device(device).type, dtype=torch.bfloat16):
                latents = self._vae_encoder(video)
        finally:
            self._vae_encoder.to("cpu")

        latents = latents.to(torch.bfloat16)
        patchified = self._video_patchifier.patchify(latents)

        # Compute positions
        latent_shape = VideoLatentShape(
            batch=1,
            channels=latents.shape[1],
            frames=latents.shape[2],
            height=latents.shape[3],
            width=latents.shape[4],
        )
        latent_coords = self._video_patchifier.get_patch_grid_bounds(output_shape=latent_shape, device=device)
        positions = get_pixel_coords(latent_coords, scale_factors=VIDEO_SCALE_FACTORS, causal_fix=True)
        positions = positions.to(torch.bfloat16)
        # Index 0 along dim 1 is divided by fps (frame-rate scaling of the temporal coordinate)
        positions[:, 0, ...] = positions[:, 0, ...] / fps

        return patchified, positions

    def _run_denoising(
        self,
        config: GenerationConfig,
        video_state: LatentState,
        audio_state: LatentState | None,
        video_clean_state: LatentState,
        audio_clean_state: LatentState | None,
        v_ctx_pos: Tensor,
        a_ctx_pos: Tensor,
        v_ctx_neg: Tensor | None,
        a_ctx_neg: Tensor | None,
        device: torch.device,
    ) -> tuple[LatentState, LatentState | None]:
        """Run the denoising loop using X0 prediction with CFG and optional STG.

        For every sigma except the last one the loop: runs the positive-prompt pass, optionally
        adds a CFG delta (extra pass with the negative context) and an STG delta (extra pass
        with perturbations), blends the clean latent back in wherever the denoise mask is 0,
        and takes one Euler step.

        Args:
            config: Supplies num_inference_steps, guidance_scale, stg_scale and the STG settings
            video_state: Noised video latent state to start from
            audio_state: Noised audio latent state, or None when audio is not generated
            video_clean_state: Clean video state whose latent is restored for masked-out tokens
            audio_clean_state: Clean audio state (same role as above), or None
            v_ctx_pos: Positive text context for the video modality
            a_ctx_pos: Positive text context for the audio modality
            v_ctx_neg: Negative video context; CFG is skipped when this is None
            a_ctx_neg: Negative audio context used in the CFG pass when audio is present
            device: Device the transformer is moved to and sigmas are placed on

        Returns:
            Tuple of (final video state, final audio state or None)

        Note:
            The transformer is moved to ``device`` here and is not moved back by this method.
        """
        scheduler = LTX2Scheduler()
        sigmas = scheduler.execute(steps=config.num_inference_steps).to(device).float()
        stepper = EulerDiffusionStep()
        cfg_guider = CFGGuider(config.guidance_scale)
        stg_guider = STGGuider(config.stg_scale)

        # Build STG perturbation config if STG is enabled
        stg_perturbation_config = self._build_stg_perturbation_config(config) if stg_guider.enabled() else None

        # Create initial modalities (will be updated each step via replace())
        video = Modality(
            enabled=True,
            latent=video_state.latent,
            timesteps=video_state.denoise_mask,
            positions=video_state.positions,
            context=v_ctx_pos,
            context_mask=None,
        )

        # Audio modality is None when not generating audio
        audio: Modality | None = None
        if audio_state is not None:
            audio = Modality(
                enabled=True,
                latent=audio_state.latent,
                timesteps=audio_state.denoise_mask,
                positions=audio_state.positions,
                context=a_ctx_pos,
                context_mask=None,
            )

        # Wrap transformer with X0Model to convert velocity predictions to denoised outputs
        self._transformer.to(device)
        x0_model = X0Model(self._transformer)

        with torch.autocast(device_type=str(device).split(":")[0], dtype=torch.bfloat16):
            # The final sigma is only a step target, so iteration stops one short of it
            for step_idx, sigma in enumerate(sigmas[:-1]):
                # Update modalities with current state and timesteps
                # (timestep = sigma * mask, so tokens with mask 0 are presented at timestep 0)
                video = replace(
                    video,
                    latent=video_state.latent,
                    timesteps=sigma * video_state.denoise_mask,
                    positions=video_state.positions,
                )

                if audio is not None and audio_state is not None:
                    audio = replace(
                        audio,
                        latent=audio_state.latent,
                        timesteps=sigma * audio_state.denoise_mask,
                        positions=audio_state.positions,
                    )

                # Run model (positive pass) - X0Model returns denoised outputs
                pos_video, pos_audio = x0_model(video=video, audio=audio, perturbations=None)
                denoised_video, denoised_audio = pos_video, pos_audio

                # Apply CFG if guidance_scale != 1.0
                # (requires a negative video context; only the text context differs from the positive pass)
                if cfg_guider.enabled() and v_ctx_neg is not None:
                    video_neg = replace(video, context=v_ctx_neg)
                    audio_neg = replace(audio, context=a_ctx_neg) if audio is not None else None
                    neg_video, neg_audio = x0_model(video=video_neg, audio=audio_neg, perturbations=None)

                    denoised_video = denoised_video + cfg_guider.delta(pos_video, neg_video)
                    if audio is not None and denoised_audio is not None:
                        denoised_audio = denoised_audio + cfg_guider.delta(pos_audio, neg_audio)

                # Apply STG if stg_scale != 0.0
                # (same inputs as the positive pass, but run with the perturbation config)
                if stg_guider.enabled() and stg_perturbation_config is not None:
                    perturbed_video, perturbed_audio = x0_model(
                        video=video, audio=audio, perturbations=stg_perturbation_config
                    )
                    denoised_video = denoised_video + stg_guider.delta(pos_video, perturbed_video)
                    if audio is not None and denoised_audio is not None and perturbed_audio is not None:
                        denoised_audio = denoised_audio + stg_guider.delta(pos_audio, perturbed_audio)

                # Apply conditioning mask (keep conditioned tokens clean)
                # mask == 1 keeps the prediction, mask == 0 restores the clean latent
                denoised_video = denoised_video * video_state.denoise_mask + video_clean_state.latent.float() * (
                    1 - video_state.denoise_mask
                )
                if audio is not None and audio_state is not None and audio_clean_state is not None:
                    denoised_audio = denoised_audio * audio_state.denoise_mask + audio_clean_state.latent.float() * (
                        1 - audio_state.denoise_mask
                    )

                # Euler step
                video_state = replace(
                    video_state,
                    latent=stepper.step(
                        sample=video.latent, denoised_sample=denoised_video, sigmas=sigmas, step_index=step_idx
                    ),
                )
                if audio is not None and audio_state is not None:
                    audio_state = replace(
                        audio_state,
                        latent=stepper.step(
                            sample=audio.latent, denoised_sample=denoised_audio, sigmas=sigmas, step_index=step_idx
                        ),
                    )

                # Update progress
                if self._sampling_context is not None:
                    self._sampling_context.advance_step()

        return video_state, audio_state

    @staticmethod
    def _build_stg_perturbation_config(config: GenerationConfig) -> BatchedPerturbationConfig:
        """Assemble the STG perturbation config selected by ``config.stg_mode``.

        Video self-attention is always skipped; in ``"stg_av"`` mode audio self-attention is
        skipped as well. Both perturbations target ``config.stg_blocks``. The result wraps a
        single PerturbationConfig, since validation runs with a batch size of 1.
        """
        # Video self-attention skipping is unconditional; audio is added for stg_av only
        skipped_attention = [PerturbationType.SKIP_VIDEO_SELF_ATTN]
        if config.stg_mode == "stg_av":
            skipped_attention.append(PerturbationType.SKIP_AUDIO_SELF_ATTN)

        perturbations: list[Perturbation] = [
            Perturbation(type=skip_type, blocks=config.stg_blocks) for skip_type in skipped_attention
        ]

        # One PerturbationConfig per batch element (batch size is 1 for validation)
        return BatchedPerturbationConfig(perturbations=[PerturbationConfig(perturbations=perturbations)])

    def _decode_video_latent(self, latent: Tensor, config: GenerationConfig, device: torch.device) -> Tensor:
        """Decode patchified video latent to pixel space.

        The latent is unpatchified to the latent grid derived from ``config``
        (batch 1, 128 channels), decoded with the VAE decoder (tiled when
        ``config.tiled_decoding`` is set and enabled), then rescaled with
        ``(x + 1) / 2`` and clamped to [0, 1].

        The decoder is moved to ``device`` for decoding and is always moved back
        to the CPU afterwards, including when decoding raises.

        Args:
            latent: Patchified video latent to decode.
            config: Generation config supplying ``num_frames``, ``height``, ``width``
                and the optional ``tiled_decoding`` settings.
            device: Device to run the decoder on.

        Returns:
            First batch element of the decoded video, float32 on the CPU, in [0, 1].
        """
        # Unpatchify to the latent grid implied by the requested output size
        latent_shape = VideoLatentShape(
            height=config.height // VIDEO_SCALE_FACTORS[1],
            width=config.width // VIDEO_SCALE_FACTORS[2],
            frames=config.num_frames // VIDEO_SCALE_FACTORS[0] + 1,
            batch=1,
            channels=128,
        )
        unpatchified = self._video_patchifier.unpatchify(latent, output_shape=latent_shape)

        decoder = self._vae_decoder
        decoder.to(device)
        try:
            # Decode - ensure bfloat16 to match decoder weights
            unpatchified = unpatchified.to(dtype=torch.bfloat16)
            tiled_config = config.tiled_decoding

            if tiled_config is not None and tiled_config.enabled:
                # Use tiled decoding for reduced VRAM
                tiling_config = TilingConfig(
                    spatial_config=SpatialTilingConfig(
                        tile_size_in_pixels=tiled_config.tile_size_pixels,
                        tile_overlap_in_pixels=tiled_config.tile_overlap_pixels,
                    ),
                    temporal_config=TemporalTilingConfig(
                        tile_size_in_frames=tiled_config.tile_size_frames,
                        tile_overlap_in_frames=tiled_config.tile_overlap_frames,
                    ),
                )
                # Each yielded item is a (chunk, extra) pair; only the chunk is kept.
                chunks = [
                    video_chunk
                    for video_chunk, _ in decoder.tiled_decode(unpatchified, tiling_config=tiling_config)
                ]
                # dim=2 is presumably the frame axis of the decoder output
                decoded_video = torch.cat(chunks, dim=2)
            else:
                # Standard full decoding
                decoded_video = decoder(unpatchified)

            decoded_video = ((decoded_video + 1.0) / 2.0).clamp(0.0, 1.0)
        finally:
            # Offload even when decoding fails so the decoder does not stay on the device
            decoder.to("cpu")

        return decoded_video[0].float().cpu()

    def _validate_config(self, config: GenerationConfig) -> None:
        """Validate generation configuration."""
        if config.height % 32 != 0 or config.width % 32 != 0:
            raise ValueError(f"height and width must be divisible by 32, got {config.height}x{config.width}")
        if config.num_frames % 8 != 1:
            raise ValueError(f"num_frames must satisfy num_frames % 8 == 1, got {config.num_frames}")
        if config.generate_audio and (self._audio_decoder is None or self._vocoder is None):
            raise ValueError("Audio generation requires audio_decoder and vocoder")
        if config.condition_image is not None and self._vae_encoder is None:
            raise ValueError("Image conditioning requires vae_encoder")
        if config.reference_video is not None and self._vae_encoder is None:
            raise ValueError("Reference video conditioning requires vae_encoder")

        # Validate prompt embedding source
        if config.cached_embeddings is None and self._text_encoder is None:
            raise ValueError("Either text_encoder or config.cached_embeddings must be provided")

    def _get_prompt_embeddings(
        self, config: GenerationConfig, device: torch.device
    ) -> tuple[Tensor, Tensor, Tensor | None, Tensor | None]:
        """Get prompt embeddings from config cache or encode on-the-fly."""
        if config.cached_embeddings is not None:
            # Use pre-computed embeddings from config
            cached = config.cached_embeddings
            v_ctx_pos = cached.video_context_positive.to(device)
            a_ctx_pos = cached.audio_context_positive.to(device)
            v_ctx_neg = cached.video_context_negative.to(device) if cached.video_context_negative is not None else None
            a_ctx_neg = cached.audio_context_negative.to(device) if cached.audio_context_negative is not None else None
            return v_ctx_pos, a_ctx_pos, v_ctx_neg, a_ctx_neg

        # Fall back to encoding on-the-fly
        return self._encode_prompts(config, device)

    def _encode_prompts(
        self, config: GenerationConfig, device: torch.device
    ) -> tuple[Tensor, Tensor, Tensor | None, Tensor | None]:
        """Encode positive and negative prompts using the text encoder."""
        self._text_encoder.to(device)
        v_ctx_pos, a_ctx_pos, _ = self._text_encoder(config.prompt)
        v_ctx_neg, a_ctx_neg = None, None
        if config.guidance_scale != 1.0:
            v_ctx_neg, a_ctx_neg, _ = self._text_encoder(config.negative_prompt)

        # Move the base Gemma model to CPU but keep embeddings connectors on GPU
        # as this module is also used during training
        self._text_encoder.model.to("cpu")
        self._text_encoder.feature_extractor_linear.to("cpu")

        return v_ctx_pos, a_ctx_pos, v_ctx_neg, a_ctx_neg

    def _decode_video(
        self, video_state: LatentState, device: torch.device, tiled_config: TiledDecodingConfig | None = None
    ) -> Tensor:
        """Decode video latents to pixel space.

        Args:
            video_state: Video latent state to decode
            device: Device to run decoding on
            tiled_config: Optional tiled decoding configuration for reduced VRAM usage

        Returns:
            Decoded video tensor [C, F, H, W] in [0, 1] range
        """
        self._vae_decoder.to(device)
        # Ensure latent is bfloat16 to match decoder weights
        latent = video_state.latent.to(dtype=torch.bfloat16)

        if tiled_config is not None and tiled_config.enabled:
            # Use tiled decoding for reduced VRAM
            tiling_config = TilingConfig(
                spatial_config=SpatialTilingConfig(
                    tile_size_in_pixels=tiled_config.tile_size_pixels,
                    tile_overlap_in_pixels=tiled_config.tile_overlap_pixels,
                ),
                temporal_config=TemporalTilingConfig(
                    tile_size_in_frames=tiled_config.tile_size_frames,
                    tile_overlap_in_frames=tiled_config.tile_overlap_frames,
                ),
            )
            chunks = []
            for video_chunk, _ in self._vae_decoder.tiled_decode(
                latent,
                tiling_config=tiling_config,
            ):
                chunks.append(video_chunk)
            decoded_video = torch.cat(chunks, dim=2)
        else:
            # Standard full decoding
            decoded_video = self._vae_decoder(latent)

        decoded_video = ((decoded_video + 1.0) / 2.0).clamp(0.0, 1.0)
        self._vae_decoder.to("cpu")
        return decoded_video[0].float().cpu()

    def _decode_audio(self, audio_state: LatentState, device: torch.device) -> Tensor:
        """Decode audio latents to waveform."""
        self._audio_decoder.to(device)
        # Ensure latent is bfloat16 to match decoder weights
        latent = audio_state.latent.to(dtype=torch.bfloat16)
        decoded_audio = self._audio_decoder(latent)
        self._audio_decoder.to("cpu")

        self._vocoder.to(device)
        audio_waveform = self._vocoder(decoded_audio)
        self._vocoder.to("cpu")

        return audio_waveform.squeeze(0).float().cpu()

    @staticmethod
    def _concatenate_videos_side_by_side(left_video: Tensor, right_video: Tensor) -> Tensor:
        """Concatenate two videos side-by-side (horizontally).

        If the videos have different frame counts, the shorter one is padded with
        its last frame repeated.

        Args:
            left_video: Left video tensor [C, F1, H, W] in [0, 1]
            right_video: Right video tensor [C, F2, H, W] in [0, 1]

        Returns:
            Concatenated video tensor [C, max(F1,F2), H, W*2] in [0, 1]
        """
        left_frames = left_video.shape[1]
        right_frames = right_video.shape[1]

        # Pad shorter video by repeating last frame
        if left_frames < right_frames:
            padding = left_video[:, -1:, :, :].expand(-1, right_frames - left_frames, -1, -1)
            left_video = torch.cat([left_video, padding], dim=1)
        elif right_frames < left_frames:
            padding = right_video[:, -1:, :, :].expand(-1, left_frames - right_frames, -1, -1)
            right_video = torch.cat([right_video, padding], dim=1)

        # Concatenate along width dimension
        return torch.cat([left_video, right_video], dim=3)

    def _encode_conditioning_image(
        self,
        image: Tensor,
        target_height: int,
        target_width: int,
        device: torch.device,
    ) -> Tensor:
        """Encode a conditioning image to latent space.

        The image is resized to cover the target dimensions while preserving aspect ratio,
        then center-cropped to exactly match the target size.
        """
        # image is [C, H, W] in [0, 1]  # noqa: ERA001
        current_height, current_width = image.shape[1:]

        # Resize maintaining aspect ratio (cover target, then center crop)
        if current_height != target_height or current_width != target_width:
            aspect_ratio = current_width / current_height
            target_aspect_ratio = target_width / target_height

            if aspect_ratio > target_aspect_ratio:
                # Image is wider than target - resize to match height, crop width
                resize_height = target_height
                resize_width = int(target_height * aspect_ratio)
            else:
                # Image is taller than target - resize to match width, crop height
                resize_height = int(target_width / aspect_ratio)
                resize_width = target_width

            image = rearrange(image, "c h w -> 1 c h w")
            image = torch.nn.functional.interpolate(
                image, size=(resize_height, resize_width), mode="bilinear", align_corners=False
            )

            # Center crop to target dimensions
            h_start = (resize_height - target_height) // 2
            w_start = (resize_width - target_width) // 2
            image = image[:, :, h_start : h_start + target_height, w_start : w_start + target_width]
        else:
            image = rearrange(image, "c h w -> 1 c h w")

        # Add frame dimension and convert to [-1, 1]
        image = rearrange(image, "b c h w -> b c 1 h w")
        image = (image * 2.0 - 1.0).to(device=device, dtype=torch.float32)

        # Encode
        self._vae_encoder.to(device)
        with torch.autocast(device_type=str(device).split(":")[0], dtype=torch.bfloat16):
            encoded = self._vae_encoder(image)
        self._vae_encoder.to("cpu")

        return encoded