| """Validation sampling for LTX-2 training using ltx-core components. | |
| This module provides a simplified validation pipeline for generating samples during training, | |
| using the new ltx-core components (VideoLatentTools, AudioLatentTools, LatentState, etc.). | |
| """ | |
| from dataclasses import dataclass, replace | |
| from typing import TYPE_CHECKING, Literal | |
| import torch | |
| from einops import rearrange | |
| from torch import Tensor | |
| from ltx_core.guidance.perturbations import ( | |
| BatchedPerturbationConfig, | |
| Perturbation, | |
| PerturbationConfig, | |
| PerturbationType, | |
| ) | |
| from ltx_core.model.transformer.modality import Modality | |
| from ltx_core.model.transformer.model import X0Model | |
| from ltx_core.pipeline.components.diffusion_steps import EulerDiffusionStep | |
| from ltx_core.pipeline.components.guiders import CFGGuider, STGGuider | |
| from ltx_core.pipeline.components.noisers import GaussianNoiser | |
| from ltx_core.pipeline.components.patchifiers import ( | |
| AudioLatentShape, | |
| AudioPatchifier, | |
| VideoLatentPatchifier, | |
| VideoLatentShape, | |
| get_pixel_coords, | |
| ) | |
| from ltx_core.pipeline.components.protocols import VideoPixelShape | |
| from ltx_core.pipeline.components.schedulers import LTX2Scheduler | |
| from ltx_core.pipeline.conditioning.tools import AudioLatentTools, LatentState, VideoLatentTools | |
| from ltx_core.tiling import SpatialTilingConfig, TemporalTilingConfig, TilingConfig | |
| from ltx_trainer.progress import SamplingContext | |
| if TYPE_CHECKING: | |
| from ltx_core.model.audio_vae.audio_vae import Decoder as AudioDecoder | |
| from ltx_core.model.audio_vae.vocoder import Vocoder | |
| from ltx_core.model.clip.gemma.encoders.av_encoder import AVGemmaTextEncoderModel | |
| from ltx_core.model.transformer.model import LTXModel | |
| from ltx_core.model.video_vae.video_vae import Decoder as VideoDecoder | |
| from ltx_core.model.video_vae.video_vae import Encoder as VideoEncoder | |
| # Video VAE scale factors (temporal, height, width) | |
| VIDEO_SCALE_FACTORS = (8, 32, 32) | |
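
# Worked example of the latent geometry implied by these scale factors (illustrative
# arithmetic only, mirroring the math used in _decode_video_latent below, and assuming
# the patchifier emits one token per latent cell when patch_size=1): at the default
# 97 frames of 544x960 video, the latent grid is
#   frames: 97 // 8 + 1 = 13, height: 544 // 32 = 17, width: 960 // 32 = 30,
# i.e. roughly 13 * 17 * 30 = 6630 video tokens after patchification.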


@dataclass
class CachedPromptEmbeddings:
    """Pre-computed text embeddings for a validation prompt.

    These embeddings are computed once at training start and reused for all validation runs,
    avoiding the need to load the full Gemma text encoder during validation.
    """

    video_context_positive: Tensor  # [1, seq_len, hidden_dim]
    audio_context_positive: Tensor  # [1, seq_len, hidden_dim]
    video_context_negative: Tensor | None = None
    audio_context_negative: Tensor | None = None
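
# Sketch of how these embeddings might be pre-computed once at training start (hedged:
# it assumes the text encoder returns (video_context, audio_context, mask) tuples as in
# _encode_prompts below; the variable names here are illustrative, not part of this module):
#
#   v_pos, a_pos, _ = text_encoder(prompt)
#   v_neg, a_neg, _ = text_encoder(negative_prompt)
#   cached = CachedPromptEmbeddings(
#       video_context_positive=v_pos.cpu(),
#       audio_context_positive=a_pos.cpu(),
#       video_context_negative=v_neg.cpu(),
#       audio_context_negative=a_neg.cpu(),
#   )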


@dataclass
class TiledDecodingConfig:
    """Configuration for tiled video decoding to reduce VRAM usage.

    Tiled decoding splits the latent tensor into overlapping tiles, decodes each
    tile individually, and blends them together. This significantly reduces peak
    VRAM usage at the cost of slightly slower decoding.

    Defaults match the recommended values from ltx-core tests.
    """

    enabled: bool = True  # Whether to use tiled decoding (enabled by default)
    tile_size_pixels: int = 192  # Spatial tile size in pixels (must be ≥64 and divisible by 32)
    tile_overlap_pixels: int = 64  # Spatial tile overlap in pixels (must be divisible by 32)
    tile_size_frames: int = 48  # Temporal tile size in frames (must be ≥16 and divisible by 8)
    tile_overlap_frames: int = 24  # Temporal tile overlap in frames (must be divisible by 8)
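
# Rough tile-count intuition (an assumption about the tiling scheme, not taken from
# ltx-core): with a stride of tile_size - tile_overlap, the defaults give a spatial
# stride of 192 - 64 = 128 px and a temporal stride of 48 - 24 = 24 frames, so a
# 960-px-wide frame would be covered by roughly ceil((960 - 192) / 128) + 1 = 7 tiles per row.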


@dataclass(frozen=True)
class GenerationConfig:
    """Configuration for video/audio generation."""

    prompt: str  # Text prompt for generation
    negative_prompt: str = ""  # Negative prompt to avoid unwanted artifacts
    height: int = 544  # Output video height in pixels
    width: int = 960  # Output video width in pixels
    num_frames: int = 97  # Number of frames to generate
    frame_rate: float = 25.0  # Frame rate for temporal position scaling
    num_inference_steps: int = 30  # Number of denoising steps
    guidance_scale: float = 3.0  # CFG guidance scale
    seed: int = 42  # Random seed for reproducibility
    condition_image: Tensor | None = None  # Optional first frame image for image-to-video
    reference_video: Tensor | None = None  # For IC-LoRA: [F, C, H, W] in [0, 1]
    generate_audio: bool = True  # Whether to generate audio alongside video
    include_reference_in_output: bool = False  # For IC-LoRA: concatenate original reference with generated output
    cached_embeddings: CachedPromptEmbeddings | None = None  # Pre-computed text embeddings (avoids loading Gemma)
    stg_scale: float = 0.0  # STG strength (0.0 = disabled, recommended: 1.0)
    stg_blocks: list[int] | None = None  # Transformer blocks to perturb (None = all, recommended: [29])
    stg_mode: Literal["stg_av", "stg_v"] = "stg_av"  # STG mode: "stg_av" (audio+video) or "stg_v" (video only)
    # Tiled decoding config: None = use defaults (enabled), False = disable, or TiledDecodingConfig for custom settings
    tiled_decoding: TiledDecodingConfig | Literal[False] | None = None

    def __post_init__(self) -> None:
        """Apply default tiled decoding config if not provided."""
        if self.tiled_decoding is None:
            # Use default config with tiling enabled
            object.__setattr__(self, "tiled_decoding", TiledDecodingConfig())
        elif self.tiled_decoding is False:
            # Explicitly disabled - use config with enabled=False
            object.__setattr__(self, "tiled_decoding", TiledDecodingConfig(enabled=False))
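
# Example configurations (illustrative sketches only; `first_frame` is a placeholder
# [C, H, W] tensor in [0, 1], not a name defined in this module):
#
#   text_to_video = GenerationConfig(prompt="a red fox running through snow")
#   image_to_video = GenerationConfig(
#       prompt="a red fox running through snow",
#       condition_image=first_frame,
#       stg_scale=1.0,
#       stg_blocks=[29],
#   )
#   no_tiling = GenerationConfig(prompt="a red fox running through snow", tiled_decoding=False)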


class ValidationSampler:
    """Generates validation samples during training using ltx-core components.

    This class provides a simplified interface for generating video (and optionally audio)
    samples during training validation. It supports:

    - Text-to-video generation
    - Image-to-video generation (first frame conditioning)
    - Video-to-video generation (IC-LoRA reference video conditioning)
    - Optional audio generation

    The implementation follows the patterns from ltx_pipelines.single_stage.

    Text embeddings can be provided either via:

    - A full text_encoder (encodes prompts on-the-fly)
    - Pre-computed cached_embeddings (avoids loading Gemma during validation)
    """

    def __init__(
        self,
        transformer: "LTXModel",
        vae_decoder: "VideoDecoder",
        vae_encoder: "VideoEncoder | None",
        text_encoder: "AVGemmaTextEncoderModel | None" = None,
        audio_decoder: "AudioDecoder | None" = None,
        vocoder: "Vocoder | None" = None,
        sampling_context: SamplingContext | None = None,
    ):
        """Initialize the validation sampler.

        Args:
            transformer: LTX-2 transformer model
            vae_decoder: Video VAE decoder
            vae_encoder: Video VAE encoder (for image/video conditioning), can be None if not needed
            text_encoder: Gemma text encoder with embeddings connector (optional if cached_embeddings in config)
            audio_decoder: Optional audio VAE decoder (for audio generation)
            vocoder: Optional vocoder (for audio generation)
            sampling_context: Optional SamplingContext for progress display during denoising
        """
        self._transformer = transformer
        self._vae_decoder = vae_decoder
        self._vae_encoder = vae_encoder
        self._text_encoder = text_encoder
        self._audio_decoder = audio_decoder
        self._vocoder = vocoder
        self._sampling_context = sampling_context

        # Patchifiers
        self._video_patchifier = VideoLatentPatchifier(patch_size=1)
        self._audio_patchifier = AudioPatchifier(patch_size=1)

    # Note: Use @torch.no_grad() instead of @torch.inference_mode() to avoid FSDP inplace update errors after validation
    @torch.no_grad()
    def generate(
        self,
        config: GenerationConfig,
        device: torch.device | str = "cuda",
    ) -> tuple[Tensor, Tensor | None]:
        """Generate a video (and optionally audio) sample.

        Args:
            config: Generation configuration
            device: Device to run generation on

        Returns:
            Tuple of:
                - video: Video tensor [C, F, H, W] in [0, 1] (float32)
                - audio: Audio waveform tensor [C, samples] or None
        """
        device = torch.device(device) if isinstance(device, str) else device
        self._validate_config(config)

        # Route to appropriate generation method
        if config.reference_video is not None:
            return self._generate_with_reference(config, device)
        return self._generate_standard(config, device)

    def _generate_standard(self, config: GenerationConfig, device: torch.device) -> tuple[Tensor, Tensor | None]:
        """Standard generation (text-to-video or image-to-video)."""
        # Get prompt embeddings (from cache or encode on-the-fly)
        v_ctx_pos, a_ctx_pos, v_ctx_neg, a_ctx_neg = self._get_prompt_embeddings(config, device)

        # Setup generator
        generator = torch.Generator(device=device).manual_seed(config.seed)

        # Create latent tools
        video_tools = self._create_video_latent_tools(config)
        audio_tools = self._create_audio_latent_tools(config) if config.generate_audio else None

        # Create initial states
        video_clean_state = video_tools.create_initial_state(device=device, dtype=torch.bfloat16)
        audio_clean_state = (
            audio_tools.create_initial_state(device=device, dtype=torch.bfloat16) if audio_tools else None
        )

        # Apply image conditioning if provided
        if config.condition_image is not None:
            video_clean_state = self._apply_image_conditioning(
                video_clean_state, config.condition_image, config, device
            )

        # Add noise
        noiser = GaussianNoiser(generator=generator)
        video_state = noiser(latent_state=video_clean_state, noise_scale=1.0)
        audio_state = noiser(latent_state=audio_clean_state, noise_scale=1.0) if audio_clean_state else None

        # Run denoising loop
        video_state, audio_state = self._run_denoising(
            config=config,
            video_state=video_state,
            audio_state=audio_state,
            video_clean_state=video_clean_state,
            audio_clean_state=audio_clean_state,
            v_ctx_pos=v_ctx_pos,
            a_ctx_pos=a_ctx_pos,
            v_ctx_neg=v_ctx_neg,
            a_ctx_neg=a_ctx_neg,
            device=device,
        )

        # Decode outputs
        video_state = video_tools.clear_conditioning(video_state)
        video_state = video_tools.unpatchify(video_state)
        video_output = self._decode_video(video_state, device, config.tiled_decoding)

        audio_output = None
        if audio_state is not None and audio_tools is not None:
            audio_state = audio_tools.clear_conditioning(audio_state)
            audio_state = audio_tools.unpatchify(audio_state)
            audio_output = self._decode_audio(audio_state, device)

        return video_output, audio_output

    def _generate_with_reference(self, config: GenerationConfig, device: torch.device) -> tuple[Tensor, Tensor | None]:
        """Generate with reference video conditioning (IC-LoRA style).

        For IC-LoRA:
        - Reference video latents are concatenated with target latents
        - Reference latents have timestep=0 (clean, not denoised)
        - Target latents are denoised normally
        - If condition_image is also provided, the first frame of the target is conditioned
        - If include_reference_in_output is True, the preprocessed reference video
          is concatenated side-by-side with the generated video
        """
        # Get prompt embeddings (from cache or encode on-the-fly)
        v_ctx_pos, a_ctx_pos, v_ctx_neg, a_ctx_neg = self._get_prompt_embeddings(config, device)

        # Setup generator
        generator = torch.Generator(device=device).manual_seed(config.seed)

        # Preprocess and encode reference video
        ref_video_preprocessed = self._preprocess_reference_video(config)
        ref_latent, ref_positions = self._encode_video(ref_video_preprocessed, config.frame_rate, device)
        ref_seq_len = ref_latent.shape[1]

        # Create target video state
        video_tools = self._create_video_latent_tools(config)
        target_clean_state = video_tools.create_initial_state(device=device, dtype=torch.bfloat16)

        # Apply first-frame image conditioning to target if provided
        if config.condition_image is not None:
            target_clean_state = self._apply_image_conditioning(
                target_clean_state, config.condition_image, config, device
            )

        # Create combined state (reference + target)
        # denoise_mask shape is [B, seq_len, 1] after patchification
        ref_denoise_mask = torch.zeros(1, ref_seq_len, 1, device=device, dtype=torch.float32)
        combined_clean_state = LatentState(
            latent=torch.cat([ref_latent, target_clean_state.latent], dim=1),
            denoise_mask=torch.cat([ref_denoise_mask, target_clean_state.denoise_mask], dim=1),
            positions=torch.cat([ref_positions, target_clean_state.positions], dim=2),
            clean_latent=torch.cat([ref_latent, target_clean_state.clean_latent], dim=1),
        )

        # Add noise (only to the target portion via denoise_mask)
        noiser = GaussianNoiser(generator=generator)
        combined_state = noiser(latent_state=combined_clean_state, noise_scale=1.0)

        # Create audio state if needed
        audio_tools = self._create_audio_latent_tools(config) if config.generate_audio else None
        audio_clean_state = (
            audio_tools.create_initial_state(device=device, dtype=torch.bfloat16) if audio_tools else None
        )
        audio_state = noiser(latent_state=audio_clean_state, noise_scale=1.0) if audio_clean_state else None

        # Run denoising loop
        combined_state, audio_state = self._run_denoising(
            config=config,
            video_state=combined_state,
            audio_state=audio_state,
            video_clean_state=combined_clean_state,
            audio_clean_state=audio_clean_state,
            v_ctx_pos=v_ctx_pos,
            a_ctx_pos=a_ctx_pos,
            v_ctx_neg=v_ctx_neg,
            a_ctx_neg=a_ctx_neg,
            device=device,
        )

        # Extract target portion and decode
        target_latent = combined_state.latent[:, ref_seq_len:]
        video_output = self._decode_video_latent(target_latent, config, device)

        # Optionally concatenate original reference video side-by-side
        if config.include_reference_in_output:
            # Use preprocessed reference (already resized/cropped, in pixel space)
            # Convert from [B, C, F, H, W] to [C, F, H, W]
            ref_video_pixels = ref_video_preprocessed[0].cpu()
            # Normalize from [-1, 1] to [0, 1]
            ref_video_pixels = ((ref_video_pixels + 1.0) / 2.0).clamp(0.0, 1.0)
            video_output = self._concatenate_videos_side_by_side(ref_video_pixels, video_output)

        # Decode audio
        audio_output = None
        if audio_state is not None and audio_tools is not None:
            audio_state = audio_tools.clear_conditioning(audio_state)
            audio_state = audio_tools.unpatchify(audio_state)
            audio_output = self._decode_audio(audio_state, device)

        return video_output, audio_output

    def _create_video_latent_tools(self, config: GenerationConfig) -> VideoLatentTools:
        """Create video latent tools for the given configuration."""
        pixel_shape = VideoPixelShape(
            batch=1,
            frames=config.num_frames,
            height=config.height,
            width=config.width,
            fps=config.frame_rate,
        )
        return VideoLatentTools(
            patchifier=self._video_patchifier,
            target_shape=VideoLatentShape.from_pixel_shape(shape=pixel_shape),
            fps=config.frame_rate,
            scale_factors=VIDEO_SCALE_FACTORS,
            causal_fix=True,
        )

    def _create_audio_latent_tools(self, config: GenerationConfig) -> AudioLatentTools:
        """Create audio latent tools for the given configuration."""
        return AudioLatentTools(
            patchifier=self._audio_patchifier,
            target_shape=AudioLatentShape.from_duration(batch=1, duration=config.num_frames / config.frame_rate),
        )

    def _apply_image_conditioning(
        self, video_state: LatentState, image: Tensor, config: GenerationConfig, device: torch.device
    ) -> LatentState:
        """Apply first-frame image conditioning to the video state."""
        # Encode the image
        encoded_image = self._encode_conditioning_image(image, config.height, config.width, device)

        # Patchify the encoded image (single frame)
        patchified_image = self._video_patchifier.patchify(encoded_image)  # [1, 1, C] -> [1, num_patches, C]
        num_image_tokens = patchified_image.shape[1]

        # Update the first frame tokens in the latent
        new_latent = video_state.latent.clone()
        new_latent[:, :num_image_tokens] = patchified_image.to(new_latent.dtype)

        # Update clean_latent as well (conditioning image is clean)
        new_clean_latent = video_state.clean_latent.clone()
        new_clean_latent[:, :num_image_tokens] = patchified_image.to(new_clean_latent.dtype)

        # Set denoise_mask to 0 for conditioned tokens (don't denoise them)
        new_denoise_mask = video_state.denoise_mask.clone()
        new_denoise_mask[:, :num_image_tokens] = 0.0

        return LatentState(
            latent=new_latent,
            denoise_mask=new_denoise_mask,
            positions=video_state.positions,
            clean_latent=new_clean_latent,
        )

    def _preprocess_reference_video(self, config: GenerationConfig) -> Tensor:
        """Preprocess reference video: resize, crop, and convert to model input format.

        Args:
            config: Generation configuration with reference_video

        Returns:
            Preprocessed video tensor [B, C, F, H, W] in [-1, 1] range
        """
        ref_video = config.reference_video  # [F, C, H, W] in [0, 1]
        target_height, target_width = config.height, config.width
        current_height, current_width = ref_video.shape[2:]

        # Resize maintaining aspect ratio and center crop if needed
        if current_height != target_height or current_width != target_width:
            aspect_ratio = current_width / current_height
            target_aspect_ratio = target_width / target_height
            if aspect_ratio > target_aspect_ratio:
                resize_height, resize_width = target_height, int(target_height * aspect_ratio)
            else:
                resize_height, resize_width = int(target_width / aspect_ratio), target_width
            ref_video = torch.nn.functional.interpolate(
                ref_video, size=(resize_height, resize_width), mode="bilinear", align_corners=False
            )
            # Center crop
            h_start = (resize_height - target_height) // 2
            w_start = (resize_width - target_width) // 2
            ref_video = ref_video[:, :, h_start : h_start + target_height, w_start : w_start + target_width]

        # Convert to [B, C, F, H, W] and trim to valid frame count (k*8 + 1)
        ref_video = rearrange(ref_video, "f c h w -> 1 c f h w")
        valid_frames = (ref_video.shape[2] - 1) // 8 * 8 + 1
        ref_video = ref_video[:, :, :valid_frames]

        # Convert to [-1, 1] range
        return ref_video * 2.0 - 1.0

    def _encode_video(self, video: Tensor, fps: float, device: torch.device) -> tuple[Tensor, Tensor]:
        """Encode video to patchified latents and compute positions.

        Args:
            video: Video tensor [B, C, F, H, W] in [-1, 1] range
            fps: Frame rate for temporal position scaling
            device: Device to run encoding on

        Returns:
            Tuple of (patchified_latents, positions)
        """
        video = video.to(device=device, dtype=torch.float32)

        # Encode with VAE
        self._vae_encoder.to(device)
        with torch.autocast(device_type=str(device).split(":")[0], dtype=torch.bfloat16):
            latents = self._vae_encoder(video)
        self._vae_encoder.to("cpu")

        latents = latents.to(torch.bfloat16)
        patchified = self._video_patchifier.patchify(latents)

        # Compute positions
        latent_shape = VideoLatentShape(
            batch=1,
            channels=latents.shape[1],
            frames=latents.shape[2],
            height=latents.shape[3],
            width=latents.shape[4],
        )
        latent_coords = self._video_patchifier.get_patch_grid_bounds(output_shape=latent_shape, device=device)
        positions = get_pixel_coords(latent_coords, scale_factors=VIDEO_SCALE_FACTORS, causal_fix=True)
        positions = positions.to(torch.bfloat16)
        positions[:, 0, ...] = positions[:, 0, ...] / fps

        return patchified, positions

    def _run_denoising(
        self,
        config: GenerationConfig,
        video_state: LatentState,
        audio_state: LatentState | None,
        video_clean_state: LatentState,
        audio_clean_state: LatentState | None,
        v_ctx_pos: Tensor,
        a_ctx_pos: Tensor,
        v_ctx_neg: Tensor | None,
        a_ctx_neg: Tensor | None,
        device: torch.device,
    ) -> tuple[LatentState, LatentState | None]:
        """Run the denoising loop using X0 prediction with CFG and optional STG."""
        scheduler = LTX2Scheduler()
        sigmas = scheduler.execute(steps=config.num_inference_steps).to(device).float()
        stepper = EulerDiffusionStep()
        cfg_guider = CFGGuider(config.guidance_scale)
        stg_guider = STGGuider(config.stg_scale)

        # Build STG perturbation config if STG is enabled
        stg_perturbation_config = self._build_stg_perturbation_config(config) if stg_guider.enabled() else None

        # Create initial modalities (will be updated each step via replace())
        video = Modality(
            enabled=True,
            latent=video_state.latent,
            timesteps=video_state.denoise_mask,
            positions=video_state.positions,
            context=v_ctx_pos,
            context_mask=None,
        )

        # Audio modality is None when not generating audio
        audio: Modality | None = None
        if audio_state is not None:
            audio = Modality(
                enabled=True,
                latent=audio_state.latent,
                timesteps=audio_state.denoise_mask,
                positions=audio_state.positions,
                context=a_ctx_pos,
                context_mask=None,
            )

        # Wrap transformer with X0Model to convert velocity predictions to denoised outputs
        self._transformer.to(device)
        x0_model = X0Model(self._transformer)

        with torch.autocast(device_type=str(device).split(":")[0], dtype=torch.bfloat16):
            for step_idx, sigma in enumerate(sigmas[:-1]):
                # Update modalities with current state and timesteps
                video = replace(
                    video,
                    latent=video_state.latent,
                    timesteps=sigma * video_state.denoise_mask,
                    positions=video_state.positions,
                )
                if audio is not None and audio_state is not None:
                    audio = replace(
                        audio,
                        latent=audio_state.latent,
                        timesteps=sigma * audio_state.denoise_mask,
                        positions=audio_state.positions,
                    )

                # Run model (positive pass) - X0Model returns denoised outputs
                pos_video, pos_audio = x0_model(video=video, audio=audio, perturbations=None)
                denoised_video, denoised_audio = pos_video, pos_audio

                # Apply CFG if guidance_scale != 1.0
                if cfg_guider.enabled() and v_ctx_neg is not None:
                    video_neg = replace(video, context=v_ctx_neg)
                    audio_neg = replace(audio, context=a_ctx_neg) if audio is not None else None
                    neg_video, neg_audio = x0_model(video=video_neg, audio=audio_neg, perturbations=None)
                    denoised_video = denoised_video + cfg_guider.delta(pos_video, neg_video)
                    if audio is not None and denoised_audio is not None:
                        denoised_audio = denoised_audio + cfg_guider.delta(pos_audio, neg_audio)

                # Apply STG if stg_scale != 0.0
                if stg_guider.enabled() and stg_perturbation_config is not None:
                    perturbed_video, perturbed_audio = x0_model(
                        video=video, audio=audio, perturbations=stg_perturbation_config
                    )
                    denoised_video = denoised_video + stg_guider.delta(pos_video, perturbed_video)
                    if audio is not None and denoised_audio is not None and perturbed_audio is not None:
                        denoised_audio = denoised_audio + stg_guider.delta(pos_audio, perturbed_audio)

                # Apply conditioning mask (keep conditioned tokens clean)
                denoised_video = denoised_video * video_state.denoise_mask + video_clean_state.latent.float() * (
                    1 - video_state.denoise_mask
                )
                if audio is not None and audio_state is not None and audio_clean_state is not None:
                    denoised_audio = denoised_audio * audio_state.denoise_mask + audio_clean_state.latent.float() * (
                        1 - audio_state.denoise_mask
                    )

                # Euler step
                video_state = replace(
                    video_state,
                    latent=stepper.step(
                        sample=video.latent, denoised_sample=denoised_video, sigmas=sigmas, step_index=step_idx
                    ),
                )
                if audio is not None and audio_state is not None:
                    audio_state = replace(
                        audio_state,
                        latent=stepper.step(
                            sample=audio.latent, denoised_sample=denoised_audio, sigmas=sigmas, step_index=step_idx
                        ),
                    )

                # Update progress
                if self._sampling_context is not None:
                    self._sampling_context.advance_step()

        return video_state, audio_state

    def _build_stg_perturbation_config(self, config: GenerationConfig) -> BatchedPerturbationConfig:
        """Build the perturbation config for STG based on the stg_mode."""
        # Always skip video self-attention for STG
        perturbations: list[Perturbation] = [
            Perturbation(type=PerturbationType.SKIP_VIDEO_SELF_ATTN, blocks=config.stg_blocks)
        ]
        # Optionally also skip audio self-attention (stg_av mode)
        if config.stg_mode == "stg_av":
            perturbations.append(Perturbation(type=PerturbationType.SKIP_AUDIO_SELF_ATTN, blocks=config.stg_blocks))
        perturbation_config = PerturbationConfig(perturbations=perturbations)
        # Batch size is 1 for validation
        return BatchedPerturbationConfig(perturbations=[perturbation_config])

    def _decode_video_latent(self, latent: Tensor, config: GenerationConfig, device: torch.device) -> Tensor:
        """Decode patchified video latent to pixel space."""
        # Unpatchify
        latent_frames = config.num_frames // VIDEO_SCALE_FACTORS[0] + 1
        latent_height = config.height // VIDEO_SCALE_FACTORS[1]
        latent_width = config.width // VIDEO_SCALE_FACTORS[2]
        unpatchified = self._video_patchifier.unpatchify(
            latent,
            output_shape=VideoLatentShape(
                height=latent_height,
                width=latent_width,
                frames=latent_frames,
                batch=1,
                channels=128,
            ),
        )

        # Decode - ensure bfloat16 to match decoder weights
        self._vae_decoder.to(device)
        unpatchified = unpatchified.to(dtype=torch.bfloat16)

        tiled_config = config.tiled_decoding
        if tiled_config is not None and tiled_config.enabled:
            # Use tiled decoding for reduced VRAM
            tiling_config = TilingConfig(
                spatial_config=SpatialTilingConfig(
                    tile_size_in_pixels=tiled_config.tile_size_pixels,
                    tile_overlap_in_pixels=tiled_config.tile_overlap_pixels,
                ),
                temporal_config=TemporalTilingConfig(
                    tile_size_in_frames=tiled_config.tile_size_frames,
                    tile_overlap_in_frames=tiled_config.tile_overlap_frames,
                ),
            )
            chunks = []
            for video_chunk, _ in self._vae_decoder.tiled_decode(
                unpatchified,
                tiling_config=tiling_config,
            ):
                chunks.append(video_chunk)
            decoded_video = torch.cat(chunks, dim=2)
        else:
            # Standard full decoding
            decoded_video = self._vae_decoder(unpatchified)

        decoded_video = ((decoded_video + 1.0) / 2.0).clamp(0.0, 1.0)
        self._vae_decoder.to("cpu")
        return decoded_video[0].float().cpu()

    def _validate_config(self, config: GenerationConfig) -> None:
        """Validate generation configuration."""
        if config.height % 32 != 0 or config.width % 32 != 0:
            raise ValueError(f"height and width must be divisible by 32, got {config.height}x{config.width}")
        if config.num_frames % 8 != 1:
            raise ValueError(f"num_frames must satisfy num_frames % 8 == 1, got {config.num_frames}")
        if config.generate_audio and (self._audio_decoder is None or self._vocoder is None):
            raise ValueError("Audio generation requires audio_decoder and vocoder")
        if config.condition_image is not None and self._vae_encoder is None:
            raise ValueError("Image conditioning requires vae_encoder")
        if config.reference_video is not None and self._vae_encoder is None:
            raise ValueError("Reference video conditioning requires vae_encoder")
        # Validate prompt embedding source
        if config.cached_embeddings is None and self._text_encoder is None:
            raise ValueError("Either text_encoder or config.cached_embeddings must be provided")

    def _get_prompt_embeddings(
        self, config: GenerationConfig, device: torch.device
    ) -> tuple[Tensor, Tensor, Tensor | None, Tensor | None]:
        """Get prompt embeddings from config cache or encode on-the-fly."""
        if config.cached_embeddings is not None:
            # Use pre-computed embeddings from config
            cached = config.cached_embeddings
            v_ctx_pos = cached.video_context_positive.to(device)
            a_ctx_pos = cached.audio_context_positive.to(device)
            v_ctx_neg = cached.video_context_negative.to(device) if cached.video_context_negative is not None else None
            a_ctx_neg = cached.audio_context_negative.to(device) if cached.audio_context_negative is not None else None
            return v_ctx_pos, a_ctx_pos, v_ctx_neg, a_ctx_neg

        # Fall back to encoding on-the-fly
        return self._encode_prompts(config, device)

    def _encode_prompts(
        self, config: GenerationConfig, device: torch.device
    ) -> tuple[Tensor, Tensor, Tensor | None, Tensor | None]:
        """Encode positive and negative prompts using the text encoder."""
        self._text_encoder.to(device)
        v_ctx_pos, a_ctx_pos, _ = self._text_encoder(config.prompt)
        v_ctx_neg, a_ctx_neg = None, None
        if config.guidance_scale != 1.0:
            v_ctx_neg, a_ctx_neg, _ = self._text_encoder(config.negative_prompt)
        # Move the base Gemma model to CPU but keep embeddings connectors on GPU
        # as this module is also used during training
        self._text_encoder.model.to("cpu")
        self._text_encoder.feature_extractor_linear.to("cpu")
        return v_ctx_pos, a_ctx_pos, v_ctx_neg, a_ctx_neg

    def _decode_video(
        self, video_state: LatentState, device: torch.device, tiled_config: TiledDecodingConfig | None = None
    ) -> Tensor:
        """Decode video latents to pixel space.

        Args:
            video_state: Video latent state to decode
            device: Device to run decoding on
            tiled_config: Optional tiled decoding configuration for reduced VRAM usage

        Returns:
            Decoded video tensor [C, F, H, W] in [0, 1] range
        """
        self._vae_decoder.to(device)
        # Ensure latent is bfloat16 to match decoder weights
        latent = video_state.latent.to(dtype=torch.bfloat16)

        if tiled_config is not None and tiled_config.enabled:
            # Use tiled decoding for reduced VRAM
            tiling_config = TilingConfig(
                spatial_config=SpatialTilingConfig(
                    tile_size_in_pixels=tiled_config.tile_size_pixels,
                    tile_overlap_in_pixels=tiled_config.tile_overlap_pixels,
                ),
                temporal_config=TemporalTilingConfig(
                    tile_size_in_frames=tiled_config.tile_size_frames,
                    tile_overlap_in_frames=tiled_config.tile_overlap_frames,
                ),
            )
            chunks = []
            for video_chunk, _ in self._vae_decoder.tiled_decode(
                latent,
                tiling_config=tiling_config,
            ):
                chunks.append(video_chunk)
            decoded_video = torch.cat(chunks, dim=2)
        else:
            # Standard full decoding
            decoded_video = self._vae_decoder(latent)

        decoded_video = ((decoded_video + 1.0) / 2.0).clamp(0.0, 1.0)
        self._vae_decoder.to("cpu")
        return decoded_video[0].float().cpu()

    def _decode_audio(self, audio_state: LatentState, device: torch.device) -> Tensor:
        """Decode audio latents to waveform."""
        self._audio_decoder.to(device)
        # Ensure latent is bfloat16 to match decoder weights
        latent = audio_state.latent.to(dtype=torch.bfloat16)
        decoded_audio = self._audio_decoder(latent)
        self._audio_decoder.to("cpu")

        self._vocoder.to(device)
        audio_waveform = self._vocoder(decoded_audio)
        self._vocoder.to("cpu")

        return audio_waveform.squeeze(0).float().cpu()

    def _concatenate_videos_side_by_side(self, left_video: Tensor, right_video: Tensor) -> Tensor:
        """Concatenate two videos side-by-side (horizontally).

        If the videos have different frame counts, the shorter one is padded with
        its last frame repeated.

        Args:
            left_video: Left video tensor [C, F1, H, W] in [0, 1]
            right_video: Right video tensor [C, F2, H, W] in [0, 1]

        Returns:
            Concatenated video tensor [C, max(F1,F2), H, W*2] in [0, 1]
        """
        left_frames = left_video.shape[1]
        right_frames = right_video.shape[1]

        # Pad shorter video by repeating last frame
        if left_frames < right_frames:
            padding = left_video[:, -1:, :, :].expand(-1, right_frames - left_frames, -1, -1)
            left_video = torch.cat([left_video, padding], dim=1)
        elif right_frames < left_frames:
            padding = right_video[:, -1:, :, :].expand(-1, left_frames - right_frames, -1, -1)
            right_video = torch.cat([right_video, padding], dim=1)

        # Concatenate along width dimension
        return torch.cat([left_video, right_video], dim=3)

    def _encode_conditioning_image(
        self,
        image: Tensor,
        target_height: int,
        target_width: int,
        device: torch.device,
    ) -> Tensor:
        """Encode a conditioning image to latent space.

        The image is resized to cover the target dimensions while preserving aspect ratio,
        then center-cropped to exactly match the target size.
        """
        # image is [C, H, W] in [0, 1]  # noqa: ERA001
        current_height, current_width = image.shape[1:]

        # Resize maintaining aspect ratio (cover target, then center crop)
        if current_height != target_height or current_width != target_width:
            aspect_ratio = current_width / current_height
            target_aspect_ratio = target_width / target_height
            if aspect_ratio > target_aspect_ratio:
                # Image is wider than target - resize to match height, crop width
                resize_height = target_height
                resize_width = int(target_height * aspect_ratio)
            else:
                # Image is taller than target - resize to match width, crop height
                resize_height = int(target_width / aspect_ratio)
                resize_width = target_width
            image = rearrange(image, "c h w -> 1 c h w")
            image = torch.nn.functional.interpolate(
                image, size=(resize_height, resize_width), mode="bilinear", align_corners=False
            )
            # Center crop to target dimensions
            h_start = (resize_height - target_height) // 2
            w_start = (resize_width - target_width) // 2
            image = image[:, :, h_start : h_start + target_height, w_start : w_start + target_width]
        else:
            image = rearrange(image, "c h w -> 1 c h w")

        # Add frame dimension and convert to [-1, 1]
        image = rearrange(image, "b c h w -> b c 1 h w")
        image = (image * 2.0 - 1.0).to(device=device, dtype=torch.float32)

        # Encode
        self._vae_encoder.to(device)
        with torch.autocast(device_type=str(device).split(":")[0], dtype=torch.bfloat16):
            encoded = self._vae_encoder(image)
        self._vae_encoder.to("cpu")
        return encoded
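

# Minimal end-to-end usage sketch (hedged: model loading is project-specific, and the
# variable names below are placeholders rather than loaders defined in this module):
#
#   sampler = ValidationSampler(
#       transformer=transformer,
#       vae_decoder=video_vae_decoder,
#       vae_encoder=video_vae_encoder,
#       text_encoder=gemma_text_encoder,
#       audio_decoder=audio_vae_decoder,
#       vocoder=vocoder,
#   )
#   config = GenerationConfig(prompt="a red fox running through snow", num_frames=97)
#   video, audio = sampler.generate(config, device="cuda")
#   # video: [C, F, H, W] float32 in [0, 1]; audio: [C, samples] or None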