|
|
"""Validation sampling for LTX-2 training using ltx-core components. |
|
|
|
|
|
This module provides a simplified validation pipeline for generating samples during training, |
|
|
using the new ltx-core components (VideoLatentTools, AudioLatentTools, LatentState, etc.). |
|
|
""" |
|
|
|
|
|
from dataclasses import dataclass, replace |
|
|
from typing import TYPE_CHECKING, Literal |
|
|
|
|
|
import torch |
|
|
from einops import rearrange |
|
|
from torch import Tensor |
|
|
|
|
|
from ltx_core.guidance.perturbations import ( |
|
|
BatchedPerturbationConfig, |
|
|
Perturbation, |
|
|
PerturbationConfig, |
|
|
PerturbationType, |
|
|
) |
|
|
from ltx_core.model.transformer.modality import Modality |
|
|
from ltx_core.model.transformer.model import X0Model |
|
|
from ltx_core.pipeline.components.diffusion_steps import EulerDiffusionStep |
|
|
from ltx_core.pipeline.components.guiders import CFGGuider, STGGuider |
|
|
from ltx_core.pipeline.components.noisers import GaussianNoiser |
|
|
from ltx_core.pipeline.components.patchifiers import ( |
|
|
AudioLatentShape, |
|
|
AudioPatchifier, |
|
|
VideoLatentPatchifier, |
|
|
VideoLatentShape, |
|
|
get_pixel_coords, |
|
|
) |
|
|
from ltx_core.pipeline.components.protocols import VideoPixelShape |
|
|
from ltx_core.pipeline.components.schedulers import LTX2Scheduler |
|
|
from ltx_core.pipeline.conditioning.tools import AudioLatentTools, LatentState, VideoLatentTools |
|
|
from ltx_core.tiling import SpatialTilingConfig, TemporalTilingConfig, TilingConfig |
|
|
from ltx_trainer.progress import SamplingContext |
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from ltx_core.model.audio_vae.audio_vae import Decoder as AudioDecoder |
|
|
from ltx_core.model.audio_vae.vocoder import Vocoder |
|
|
from ltx_core.model.clip.gemma.encoders.av_encoder import AVGemmaTextEncoderModel |
|
|
from ltx_core.model.transformer.model import LTXModel |
|
|
from ltx_core.model.video_vae.video_vae import Decoder as VideoDecoder |
|
|
from ltx_core.model.video_vae.video_vae import Encoder as VideoEncoder |
|
|
|
|
|
|
|
|
VIDEO_SCALE_FACTORS = (8, 32, 32) |
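# A worked example of what these (time, height, width) VAE scale factors imply
# (matching the shape math in _decode_video_latent below): num_frames=97,
# height=544, width=960 maps to a latent grid of 97 // 8 + 1 = 13 frames,
# 544 // 32 = 17 rows, and 960 // 32 = 30 columns; the +1 accounts for the
# causal first frame.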
|
|
|
|
|
|
|
|
@dataclass |
|
|
class CachedPromptEmbeddings: |
|
|
"""Pre-computed text embeddings for a validation prompt. |
|
|
|
|
|
These embeddings are computed once at training start and reused for all validation runs, |
|
|
avoiding the need to load the full Gemma text encoder during validation. |
|
|
""" |
|
|
|
|
|
video_context_positive: Tensor |
|
|
audio_context_positive: Tensor |
|
|
video_context_negative: Tensor | None = None |
|
|
audio_context_negative: Tensor | None = None |
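
    # A minimal sketch of how these embeddings might be pre-computed at training
    # start, assuming a text encoder with the call signature used in
    # _encode_prompts below (video_ctx, audio_ctx, _ = encoder(prompt)):
    #
    #   v_pos, a_pos, _ = text_encoder(prompt)
    #   v_neg, a_neg, _ = text_encoder(negative_prompt)
    #   cached = CachedPromptEmbeddings(
    #       video_context_positive=v_pos.cpu(),
    #       audio_context_positive=a_pos.cpu(),
    #       video_context_negative=v_neg.cpu(),
    #       audio_context_negative=a_neg.cpu(),
    #   )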
|
|
|
|
|
|
|
|
@dataclass |
|
|
class TiledDecodingConfig: |
|
|
"""Configuration for tiled video decoding to reduce VRAM usage. |
|
|
|
|
|
Tiled decoding splits the latent tensor into overlapping tiles, decodes each |
|
|
tile individually, and blends them together. This significantly reduces peak |
|
|
VRAM usage at the cost of slightly slower decoding. |
|
|
|
|
|
Defaults match the recommended values from ltx-core tests. |
|
|
""" |
|
|
|
|
|
enabled: bool = True |
|
|
tile_size_pixels: int = 192 |
|
|
tile_overlap_pixels: int = 64 |
|
|
tile_size_frames: int = 48 |
|
|
tile_overlap_frames: int = 24 |
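
    # Rough intuition, not the tiler's exact formula: with the defaults above,
    # 192 px tiles advance with a stride of 192 - 64 = 128 px, so a 960 px wide
    # frame needs roughly ceil((960 - 192) / 128) + 1 = 7 tile columns. Larger
    # tiles mean fewer decoder calls but higher peak VRAM.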
|
|
|
|
|
|
|
|
@dataclass |
|
|
class GenerationConfig: |
|
|
"""Configuration for video/audio generation.""" |
|
|
|
|
|
prompt: str |
|
|
negative_prompt: str = "" |
|
|
height: int = 544 |
|
|
width: int = 960 |
|
|
num_frames: int = 97 |
|
|
frame_rate: float = 25.0 |
|
|
num_inference_steps: int = 30 |
|
|
guidance_scale: float = 3.0 |
|
|
seed: int = 42 |
|
|
condition_image: Tensor | None = None |
|
|
reference_video: Tensor | None = None |
|
|
generate_audio: bool = True |
|
|
include_reference_in_output: bool = False |
|
|
cached_embeddings: CachedPromptEmbeddings | None = None |
|
|
stg_scale: float = 0.0 |
|
|
stg_blocks: list[int] | None = None |
|
|
stg_mode: Literal["stg_av", "stg_v"] = "stg_av" |
|
|
|
|
|
tiled_decoding: TiledDecodingConfig | Literal[False] | None = None |
|
|
|
|
|
def __post_init__(self) -> None: |
|
|
"""Apply default tiled decoding config if not provided.""" |
|
|
if self.tiled_decoding is None: |
|
|
|
|
|
            # object.__setattr__ keeps this working even if the dataclass is made frozen.
            object.__setattr__(self, "tiled_decoding", TiledDecodingConfig())
|
|
elif self.tiled_decoding is False: |
|
|
|
|
|
object.__setattr__(self, "tiled_decoding", TiledDecodingConfig(enabled=False)) |
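
# Example configuration, as a sketch (the dimension constraints noted below are
# enforced by ValidationSampler._validate_config):
#
#   config = GenerationConfig(
#       prompt="A red fox running through fresh snow",
#       negative_prompt="blurry, distorted",
#       num_frames=97,           # must satisfy num_frames % 8 == 1
#       height=544, width=960,   # must be divisible by 32
#       tiled_decoding=False,    # normalized to TiledDecodingConfig(enabled=False)
#   )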
|
|
|
|
|
|
|
|
class ValidationSampler: |
|
|
"""Generates validation samples during training using ltx-core components. |
|
|
|
|
|
This class provides a simplified interface for generating video (and optionally audio) |
|
|
samples during training validation. It supports: |
|
|
- Text-to-video generation |
|
|
- Image-to-video generation (first frame conditioning) |
|
|
- Video-to-video generation (IC-LoRA reference video conditioning) |
|
|
- Optional audio generation |
|
|
|
|
|
The implementation follows the patterns from ltx_pipelines.single_stage. |
|
|
|
|
|
Text embeddings can be provided either via: |
|
|
- A full text_encoder (encodes prompts on-the-fly) |
|
|
- Pre-computed cached_embeddings (avoids loading Gemma during validation) |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
transformer: "LTXModel", |
|
|
vae_decoder: "VideoDecoder", |
|
|
vae_encoder: "VideoEncoder | None", |
|
|
text_encoder: "AVGemmaTextEncoderModel | None" = None, |
|
|
audio_decoder: "AudioDecoder | None" = None, |
|
|
vocoder: "Vocoder | None" = None, |
|
|
sampling_context: SamplingContext | None = None, |
|
|
): |
|
|
"""Initialize the validation sampler. |
|
|
|
|
|
Args: |
|
|
transformer: LTX-2 transformer model |
|
|
vae_decoder: Video VAE decoder |
|
|
vae_encoder: Video VAE encoder (for image/video conditioning), can be None if not needed |
|
|
text_encoder: Gemma text encoder with embeddings connector (optional if cached_embeddings in config) |
|
|
audio_decoder: Optional audio VAE decoder (for audio generation) |
|
|
vocoder: Optional vocoder (for audio generation) |
|
|
sampling_context: Optional SamplingContext for progress display during denoising |
|
|
""" |
|
|
self._transformer = transformer |
|
|
self._vae_decoder = vae_decoder |
|
|
self._vae_encoder = vae_encoder |
|
|
self._text_encoder = text_encoder |
|
|
self._audio_decoder = audio_decoder |
|
|
self._vocoder = vocoder |
|
|
self._sampling_context = sampling_context |
|
|
|
|
|
|
|
|
self._video_patchifier = VideoLatentPatchifier(patch_size=1) |
|
|
self._audio_patchifier = AudioPatchifier(patch_size=1) |
|
|
|
|
|
|
|
|
@torch.no_grad() |
|
|
def generate( |
|
|
self, |
|
|
config: GenerationConfig, |
|
|
device: torch.device | str = "cuda", |
|
|
) -> tuple[Tensor, Tensor | None]: |
|
|
"""Generate a video (and optionally audio) sample. |
|
|
|
|
|
Args: |
|
|
config: Generation configuration |
|
|
device: Device to run generation on |
|
|
|
|
|
Returns: |
|
|
Tuple of: |
|
|
- video: Video tensor [C, F, H, W] in [0, 1] (float32) |
|
|
- audio: Audio waveform tensor [C, samples] or None |
|
|
""" |
|
|
device = torch.device(device) if isinstance(device, str) else device |
|
|
self._validate_config(config) |
|
|
|
|
|
|
|
|
if config.reference_video is not None: |
|
|
return self._generate_with_reference(config, device) |
|
|
return self._generate_standard(config, device) |
|
|
|
|
|
def _generate_standard(self, config: GenerationConfig, device: torch.device) -> tuple[Tensor, Tensor | None]: |
|
|
"""Standard generation (text-to-video or image-to-video).""" |
|
|
|
|
|
v_ctx_pos, a_ctx_pos, v_ctx_neg, a_ctx_neg = self._get_prompt_embeddings(config, device) |
|
|
|
|
|
|
|
|
generator = torch.Generator(device=device).manual_seed(config.seed) |
|
|
|
|
|
|
|
|
video_tools = self._create_video_latent_tools(config) |
|
|
audio_tools = self._create_audio_latent_tools(config) if config.generate_audio else None |
|
|
|
|
|
|
|
|
video_clean_state = video_tools.create_initial_state(device=device, dtype=torch.bfloat16) |
|
|
audio_clean_state = ( |
|
|
            audio_tools.create_initial_state(device=device, dtype=torch.bfloat16) if audio_tools is not None else None
|
|
) |
|
|
|
|
|
|
|
|
if config.condition_image is not None: |
|
|
video_clean_state = self._apply_image_conditioning( |
|
|
video_clean_state, config.condition_image, config, device |
|
|
) |
|
|
|
|
|
|
|
|
noiser = GaussianNoiser(generator=generator) |
|
|
video_state = noiser(latent_state=video_clean_state, noise_scale=1.0) |
|
|
        audio_state = noiser(latent_state=audio_clean_state, noise_scale=1.0) if audio_clean_state is not None else None
|
|
|
|
|
|
|
|
video_state, audio_state = self._run_denoising( |
|
|
config=config, |
|
|
video_state=video_state, |
|
|
audio_state=audio_state, |
|
|
video_clean_state=video_clean_state, |
|
|
audio_clean_state=audio_clean_state, |
|
|
v_ctx_pos=v_ctx_pos, |
|
|
a_ctx_pos=a_ctx_pos, |
|
|
v_ctx_neg=v_ctx_neg, |
|
|
a_ctx_neg=a_ctx_neg, |
|
|
device=device, |
|
|
) |
|
|
|
|
|
|
|
|
video_state = video_tools.clear_conditioning(video_state) |
|
|
video_state = video_tools.unpatchify(video_state) |
|
|
video_output = self._decode_video(video_state, device, config.tiled_decoding) |
|
|
|
|
|
audio_output = None |
|
|
if audio_state is not None and audio_tools is not None: |
|
|
audio_state = audio_tools.clear_conditioning(audio_state) |
|
|
audio_state = audio_tools.unpatchify(audio_state) |
|
|
audio_output = self._decode_audio(audio_state, device) |
|
|
|
|
|
return video_output, audio_output |
|
|
|
|
|
def _generate_with_reference(self, config: GenerationConfig, device: torch.device) -> tuple[Tensor, Tensor | None]: |
|
|
"""Generate with reference video conditioning (IC-LoRA style). |
|
|
|
|
|
For IC-LoRA: |
|
|
- Reference video latents are concatenated with target latents |
|
|
- Reference latents have timestep=0 (clean, not denoised) |
|
|
- Target latents are denoised normally |
|
|
- If condition_image is also provided, the first frame of the target is conditioned |
|
|
- If include_reference_in_output is True, the preprocessed reference video |
|
|
is concatenated side-by-side with the generated video |
|
|
""" |
|
|
|
|
|
v_ctx_pos, a_ctx_pos, v_ctx_neg, a_ctx_neg = self._get_prompt_embeddings(config, device) |
|
|
|
|
|
|
|
|
generator = torch.Generator(device=device).manual_seed(config.seed) |
|
|
|
|
|
|
|
|
ref_video_preprocessed = self._preprocess_reference_video(config) |
|
|
ref_latent, ref_positions = self._encode_video(ref_video_preprocessed, config.frame_rate, device) |
|
|
ref_seq_len = ref_latent.shape[1] |
|
|
|
|
|
|
|
|
video_tools = self._create_video_latent_tools(config) |
|
|
target_clean_state = video_tools.create_initial_state(device=device, dtype=torch.bfloat16) |
|
|
|
|
|
|
|
|
if config.condition_image is not None: |
|
|
target_clean_state = self._apply_image_conditioning( |
|
|
target_clean_state, config.condition_image, config, device |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
        # Reference tokens are never denoised; match the target mask's dtype for the concatenation below.
        ref_denoise_mask = torch.zeros(1, ref_seq_len, 1, device=device, dtype=target_clean_state.denoise_mask.dtype)
|
|
combined_clean_state = LatentState( |
|
|
latent=torch.cat([ref_latent, target_clean_state.latent], dim=1), |
|
|
denoise_mask=torch.cat([ref_denoise_mask, target_clean_state.denoise_mask], dim=1), |
|
|
positions=torch.cat([ref_positions, target_clean_state.positions], dim=2), |
|
|
clean_latent=torch.cat([ref_latent, target_clean_state.clean_latent], dim=1), |
|
|
) |
|
|
|
|
|
|
|
|
noiser = GaussianNoiser(generator=generator) |
|
|
combined_state = noiser(latent_state=combined_clean_state, noise_scale=1.0) |
|
|
|
|
|
|
|
|
audio_tools = self._create_audio_latent_tools(config) if config.generate_audio else None |
|
|
audio_clean_state = ( |
|
|
            audio_tools.create_initial_state(device=device, dtype=torch.bfloat16) if audio_tools is not None else None
|
|
) |
|
|
        audio_state = noiser(latent_state=audio_clean_state, noise_scale=1.0) if audio_clean_state is not None else None
|
|
|
|
|
|
|
|
combined_state, audio_state = self._run_denoising( |
|
|
config=config, |
|
|
video_state=combined_state, |
|
|
audio_state=audio_state, |
|
|
video_clean_state=combined_clean_state, |
|
|
audio_clean_state=audio_clean_state, |
|
|
v_ctx_pos=v_ctx_pos, |
|
|
a_ctx_pos=a_ctx_pos, |
|
|
v_ctx_neg=v_ctx_neg, |
|
|
a_ctx_neg=a_ctx_neg, |
|
|
device=device, |
|
|
) |
|
|
|
|
|
|
|
|
target_latent = combined_state.latent[:, ref_seq_len:] |
|
|
video_output = self._decode_video_latent(target_latent, config, device) |
|
|
|
|
|
|
|
|
if config.include_reference_in_output: |
|
|
|
|
|
|
|
|
ref_video_pixels = ref_video_preprocessed[0].cpu() |
|
|
|
|
|
ref_video_pixels = ((ref_video_pixels + 1.0) / 2.0).clamp(0.0, 1.0) |
|
|
video_output = self._concatenate_videos_side_by_side(ref_video_pixels, video_output) |
|
|
|
|
|
|
|
|
audio_output = None |
|
|
if audio_state is not None and audio_tools is not None: |
|
|
audio_state = audio_tools.clear_conditioning(audio_state) |
|
|
audio_state = audio_tools.unpatchify(audio_state) |
|
|
audio_output = self._decode_audio(audio_state, device) |
|
|
|
|
|
return video_output, audio_output |
|
|
|
|
|
def _create_video_latent_tools(self, config: GenerationConfig) -> VideoLatentTools: |
|
|
"""Create video latent tools for the given configuration.""" |
|
|
pixel_shape = VideoPixelShape( |
|
|
batch=1, |
|
|
frames=config.num_frames, |
|
|
height=config.height, |
|
|
width=config.width, |
|
|
fps=config.frame_rate, |
|
|
) |
|
|
return VideoLatentTools( |
|
|
patchifier=self._video_patchifier, |
|
|
target_shape=VideoLatentShape.from_pixel_shape(shape=pixel_shape), |
|
|
fps=config.frame_rate, |
|
|
scale_factors=VIDEO_SCALE_FACTORS, |
|
|
causal_fix=True, |
|
|
) |
|
|
|
|
|
def _create_audio_latent_tools(self, config: GenerationConfig) -> AudioLatentTools: |
|
|
"""Create audio latent tools for the given configuration.""" |
|
|
return AudioLatentTools( |
|
|
patchifier=self._audio_patchifier, |
|
|
target_shape=AudioLatentShape.from_duration(batch=1, duration=config.num_frames / config.frame_rate), |
|
|
) |
|
|
|
|
|
def _apply_image_conditioning( |
|
|
self, video_state: LatentState, image: Tensor, config: GenerationConfig, device: torch.device |
|
|
) -> LatentState: |
|
|
"""Apply first-frame image conditioning to the video state.""" |
|
|
|
|
|
encoded_image = self._encode_conditioning_image(image, config.height, config.width, device) |
|
|
|
|
|
|
|
|
patchified_image = self._video_patchifier.patchify(encoded_image) |
|
|
num_image_tokens = patchified_image.shape[1] |
|
|
|
|
|
|
|
|
new_latent = video_state.latent.clone() |
|
|
new_latent[:, :num_image_tokens] = patchified_image.to(new_latent.dtype) |
|
|
|
|
|
|
|
|
new_clean_latent = video_state.clean_latent.clone() |
|
|
new_clean_latent[:, :num_image_tokens] = patchified_image.to(new_clean_latent.dtype) |
|
|
|
|
|
|
|
|
new_denoise_mask = video_state.denoise_mask.clone() |
|
|
new_denoise_mask[:, :num_image_tokens] = 0.0 |
|
|
|
|
|
return LatentState( |
|
|
latent=new_latent, |
|
|
denoise_mask=new_denoise_mask, |
|
|
positions=video_state.positions, |
|
|
clean_latent=new_clean_latent, |
|
|
) |
|
|
|
|
|
@staticmethod |
|
|
def _preprocess_reference_video(config: GenerationConfig) -> Tensor: |
|
|
"""Preprocess reference video: resize, crop, and convert to model input format. |
|
|
|
|
|
Args: |
|
|
config: Generation configuration with reference_video |
|
|
|
|
|
Returns: |
|
|
Preprocessed video tensor [B, C, F, H, W] in [-1, 1] range |
|
|
""" |
|
|
ref_video = config.reference_video |
|
|
target_height, target_width = config.height, config.width |
|
|
current_height, current_width = ref_video.shape[2:] |
|
|
|
|
|
|
|
|
if current_height != target_height or current_width != target_width: |
|
|
aspect_ratio = current_width / current_height |
|
|
target_aspect_ratio = target_width / target_height |
|
|
|
|
|
if aspect_ratio > target_aspect_ratio: |
|
|
resize_height, resize_width = target_height, int(target_height * aspect_ratio) |
|
|
else: |
|
|
resize_height, resize_width = int(target_width / aspect_ratio), target_width |
|
|
|
|
|
ref_video = torch.nn.functional.interpolate( |
|
|
ref_video, size=(resize_height, resize_width), mode="bilinear", align_corners=False |
|
|
) |
|
|
|
|
|
|
|
|
h_start = (resize_height - target_height) // 2 |
|
|
w_start = (resize_width - target_width) // 2 |
|
|
ref_video = ref_video[:, :, h_start : h_start + target_height, w_start : w_start + target_width] |
|
|
|
|
|
|
|
|
ref_video = rearrange(ref_video, "f c h w -> 1 c f h w") |
|
|
        # Trim the frame count to satisfy F % 8 == 1, as the causal video VAE's temporal downsampling requires.
        valid_frames = (ref_video.shape[2] - 1) // 8 * 8 + 1
|
|
ref_video = ref_video[:, :, :valid_frames] |
|
|
|
|
|
|
|
|
return ref_video * 2.0 - 1.0 |
|
|
|
|
|
def _encode_video(self, video: Tensor, fps: float, device: torch.device) -> tuple[Tensor, Tensor]: |
|
|
"""Encode video to patchified latents and compute positions. |
|
|
|
|
|
Args: |
|
|
video: Video tensor [B, C, F, H, W] in [-1, 1] range |
|
|
fps: Frame rate for temporal position scaling |
|
|
device: Device to run encoding on |
|
|
|
|
|
Returns: |
|
|
Tuple of (patchified_latents, positions) |
|
|
""" |
|
|
video = video.to(device=device, dtype=torch.float32) |
|
|
|
|
|
|
|
|
self._vae_encoder.to(device) |
|
|
with torch.autocast(device_type=str(device).split(":")[0], dtype=torch.bfloat16): |
|
|
latents = self._vae_encoder(video) |
|
|
self._vae_encoder.to("cpu") |
|
|
|
|
|
latents = latents.to(torch.bfloat16) |
|
|
patchified = self._video_patchifier.patchify(latents) |
|
|
|
|
|
|
|
|
latent_shape = VideoLatentShape( |
|
|
batch=1, |
|
|
channels=latents.shape[1], |
|
|
frames=latents.shape[2], |
|
|
height=latents.shape[3], |
|
|
width=latents.shape[4], |
|
|
) |
|
|
latent_coords = self._video_patchifier.get_patch_grid_bounds(output_shape=latent_shape, device=device) |
|
|
positions = get_pixel_coords(latent_coords, scale_factors=VIDEO_SCALE_FACTORS, causal_fix=True) |
|
|
positions = positions.to(torch.bfloat16) |
|
|
        # Normalize the temporal coordinate by the frame rate so positions are time-based rather than frame-based.
        positions[:, 0, ...] = positions[:, 0, ...] / fps
|
|
|
|
|
return patchified, positions |
|
|
|
|
|
def _run_denoising( |
|
|
self, |
|
|
config: GenerationConfig, |
|
|
video_state: LatentState, |
|
|
audio_state: LatentState | None, |
|
|
video_clean_state: LatentState, |
|
|
audio_clean_state: LatentState | None, |
|
|
v_ctx_pos: Tensor, |
|
|
a_ctx_pos: Tensor, |
|
|
v_ctx_neg: Tensor | None, |
|
|
a_ctx_neg: Tensor | None, |
|
|
device: torch.device, |
|
|
) -> tuple[LatentState, LatentState | None]: |
|
|
"""Run the denoising loop using X0 prediction with CFG and optional STG.""" |
|
|
scheduler = LTX2Scheduler() |
|
|
sigmas = scheduler.execute(steps=config.num_inference_steps).to(device).float() |
|
|
stepper = EulerDiffusionStep() |
|
|
cfg_guider = CFGGuider(config.guidance_scale) |
|
|
stg_guider = STGGuider(config.stg_scale) |
|
|
|
|
|
|
|
|
stg_perturbation_config = self._build_stg_perturbation_config(config) if stg_guider.enabled() else None |
|
|
|
|
|
|
|
|
video = Modality( |
|
|
enabled=True, |
|
|
latent=video_state.latent, |
|
|
timesteps=video_state.denoise_mask, |
|
|
positions=video_state.positions, |
|
|
context=v_ctx_pos, |
|
|
context_mask=None, |
|
|
) |
|
|
|
|
|
|
|
|
audio: Modality | None = None |
|
|
if audio_state is not None: |
|
|
audio = Modality( |
|
|
enabled=True, |
|
|
latent=audio_state.latent, |
|
|
timesteps=audio_state.denoise_mask, |
|
|
positions=audio_state.positions, |
|
|
context=a_ctx_pos, |
|
|
context_mask=None, |
|
|
) |
|
|
|
|
|
|
|
|
self._transformer.to(device) |
|
|
x0_model = X0Model(self._transformer) |
|
|
|
|
|
with torch.autocast(device_type=str(device).split(":")[0], dtype=torch.bfloat16): |
|
|
for step_idx, sigma in enumerate(sigmas[:-1]): |
|
|
|
|
|
video = replace( |
|
|
video, |
|
|
latent=video_state.latent, |
|
|
timesteps=sigma * video_state.denoise_mask, |
|
|
positions=video_state.positions, |
|
|
) |
|
|
|
|
|
if audio is not None and audio_state is not None: |
|
|
audio = replace( |
|
|
audio, |
|
|
latent=audio_state.latent, |
|
|
timesteps=sigma * audio_state.denoise_mask, |
|
|
positions=audio_state.positions, |
|
|
) |
|
|
|
|
|
|
|
|
pos_video, pos_audio = x0_model(video=video, audio=audio, perturbations=None) |
|
|
denoised_video, denoised_audio = pos_video, pos_audio |
|
|
|
|
|
|
|
|
if cfg_guider.enabled() and v_ctx_neg is not None: |
|
|
video_neg = replace(video, context=v_ctx_neg) |
|
|
                    audio_neg = (
                        replace(audio, context=a_ctx_neg) if audio is not None and a_ctx_neg is not None else None
                    )
|
|
neg_video, neg_audio = x0_model(video=video_neg, audio=audio_neg, perturbations=None) |
|
|
|
|
|
denoised_video = denoised_video + cfg_guider.delta(pos_video, neg_video) |
|
|
                    if audio is not None and denoised_audio is not None and neg_audio is not None:
|
|
denoised_audio = denoised_audio + cfg_guider.delta(pos_audio, neg_audio) |
|
|
|
|
|
|
|
|
if stg_guider.enabled() and stg_perturbation_config is not None: |
|
|
perturbed_video, perturbed_audio = x0_model( |
|
|
video=video, audio=audio, perturbations=stg_perturbation_config |
|
|
) |
|
|
denoised_video = denoised_video + stg_guider.delta(pos_video, perturbed_video) |
|
|
if audio is not None and denoised_audio is not None and perturbed_audio is not None: |
|
|
denoised_audio = denoised_audio + stg_guider.delta(pos_audio, perturbed_audio) |
|
|
|
|
|
|
|
|
                # Hold conditioning tokens (denoise_mask == 0) at their clean latent values.
                denoised_video = denoised_video * video_state.denoise_mask + video_clean_state.latent.float() * (
|
|
1 - video_state.denoise_mask |
|
|
) |
|
|
if audio is not None and audio_state is not None and audio_clean_state is not None: |
|
|
denoised_audio = denoised_audio * audio_state.denoise_mask + audio_clean_state.latent.float() * ( |
|
|
1 - audio_state.denoise_mask |
|
|
) |
|
|
|
|
|
|
|
|
video_state = replace( |
|
|
video_state, |
|
|
latent=stepper.step( |
|
|
sample=video.latent, denoised_sample=denoised_video, sigmas=sigmas, step_index=step_idx |
|
|
), |
|
|
) |
|
|
if audio is not None and audio_state is not None: |
|
|
audio_state = replace( |
|
|
audio_state, |
|
|
latent=stepper.step( |
|
|
sample=audio.latent, denoised_sample=denoised_audio, sigmas=sigmas, step_index=step_idx |
|
|
), |
|
|
) |
|
|
|
|
|
|
|
|
if self._sampling_context is not None: |
|
|
self._sampling_context.advance_step() |
|
|
|
|
|
return video_state, audio_state |
|
|
|
|
|
@staticmethod |
|
|
def _build_stg_perturbation_config(config: GenerationConfig) -> BatchedPerturbationConfig: |
|
|
"""Build the perturbation config for STG based on the stg_mode.""" |
|
|
|
|
|
perturbations: list[Perturbation] = [ |
|
|
Perturbation(type=PerturbationType.SKIP_VIDEO_SELF_ATTN, blocks=config.stg_blocks) |
|
|
] |
|
|
|
|
|
|
|
|
if config.stg_mode == "stg_av": |
|
|
perturbations.append(Perturbation(type=PerturbationType.SKIP_AUDIO_SELF_ATTN, blocks=config.stg_blocks)) |
|
|
|
|
|
perturbation_config = PerturbationConfig(perturbations=perturbations) |
|
|
|
|
|
return BatchedPerturbationConfig(perturbations=[perturbation_config]) |
|
|
|
|
|
def _decode_video_latent(self, latent: Tensor, config: GenerationConfig, device: torch.device) -> Tensor: |
|
|
"""Decode patchified video latent to pixel space.""" |
|
|
|
|
|
latent_frames = config.num_frames // VIDEO_SCALE_FACTORS[0] + 1 |
|
|
latent_height = config.height // VIDEO_SCALE_FACTORS[1] |
|
|
latent_width = config.width // VIDEO_SCALE_FACTORS[2] |
|
|
|
|
|
unpatchified = self._video_patchifier.unpatchify( |
|
|
latent, |
|
|
output_shape=VideoLatentShape( |
|
|
height=latent_height, |
|
|
width=latent_width, |
|
|
frames=latent_frames, |
|
|
batch=1, |
|
|
                channels=128,  # latent channel count of the LTX video VAE
|
|
), |
|
|
) |
|
|
|
|
|
|
|
|
self._vae_decoder.to(device) |
|
|
unpatchified = unpatchified.to(dtype=torch.bfloat16) |
|
|
tiled_config = config.tiled_decoding |
|
|
|
|
|
if tiled_config is not None and tiled_config.enabled: |
|
|
|
|
|
tiling_config = TilingConfig( |
|
|
spatial_config=SpatialTilingConfig( |
|
|
tile_size_in_pixels=tiled_config.tile_size_pixels, |
|
|
tile_overlap_in_pixels=tiled_config.tile_overlap_pixels, |
|
|
), |
|
|
temporal_config=TemporalTilingConfig( |
|
|
tile_size_in_frames=tiled_config.tile_size_frames, |
|
|
tile_overlap_in_frames=tiled_config.tile_overlap_frames, |
|
|
), |
|
|
) |
|
|
chunks = [] |
|
|
for video_chunk, _ in self._vae_decoder.tiled_decode( |
|
|
unpatchified, |
|
|
tiling_config=tiling_config, |
|
|
): |
|
|
chunks.append(video_chunk) |
|
|
decoded_video = torch.cat(chunks, dim=2) |
|
|
else: |
|
|
|
|
|
decoded_video = self._vae_decoder(unpatchified) |
|
|
|
|
|
decoded_video = ((decoded_video + 1.0) / 2.0).clamp(0.0, 1.0) |
|
|
self._vae_decoder.to("cpu") |
|
|
|
|
|
return decoded_video[0].float().cpu() |
|
|
|
|
|
def _validate_config(self, config: GenerationConfig) -> None: |
|
|
"""Validate generation configuration.""" |
|
|
if config.height % 32 != 0 or config.width % 32 != 0: |
|
|
raise ValueError(f"height and width must be divisible by 32, got {config.height}x{config.width}") |
|
|
if config.num_frames % 8 != 1: |
|
|
raise ValueError(f"num_frames must satisfy num_frames % 8 == 1, got {config.num_frames}") |
|
|
if config.generate_audio and (self._audio_decoder is None or self._vocoder is None): |
|
|
raise ValueError("Audio generation requires audio_decoder and vocoder") |
|
|
if config.condition_image is not None and self._vae_encoder is None: |
|
|
raise ValueError("Image conditioning requires vae_encoder") |
|
|
if config.reference_video is not None and self._vae_encoder is None: |
|
|
raise ValueError("Reference video conditioning requires vae_encoder") |
|
|
|
|
|
|
|
|
if config.cached_embeddings is None and self._text_encoder is None: |
|
|
raise ValueError("Either text_encoder or config.cached_embeddings must be provided") |
|
|
|
|
|
def _get_prompt_embeddings( |
|
|
self, config: GenerationConfig, device: torch.device |
|
|
) -> tuple[Tensor, Tensor, Tensor | None, Tensor | None]: |
|
|
"""Get prompt embeddings from config cache or encode on-the-fly.""" |
|
|
if config.cached_embeddings is not None: |
|
|
|
|
|
cached = config.cached_embeddings |
|
|
v_ctx_pos = cached.video_context_positive.to(device) |
|
|
a_ctx_pos = cached.audio_context_positive.to(device) |
|
|
v_ctx_neg = cached.video_context_negative.to(device) if cached.video_context_negative is not None else None |
|
|
a_ctx_neg = cached.audio_context_negative.to(device) if cached.audio_context_negative is not None else None |
|
|
return v_ctx_pos, a_ctx_pos, v_ctx_neg, a_ctx_neg |
|
|
|
|
|
|
|
|
return self._encode_prompts(config, device) |
|
|
|
|
|
def _encode_prompts( |
|
|
self, config: GenerationConfig, device: torch.device |
|
|
) -> tuple[Tensor, Tensor, Tensor | None, Tensor | None]: |
|
|
"""Encode positive and negative prompts using the text encoder.""" |
|
|
self._text_encoder.to(device) |
|
|
v_ctx_pos, a_ctx_pos, _ = self._text_encoder(config.prompt) |
|
|
v_ctx_neg, a_ctx_neg = None, None |
|
|
if config.guidance_scale != 1.0: |
|
|
v_ctx_neg, a_ctx_neg, _ = self._text_encoder(config.negative_prompt) |
|
|
|
|
|
|
|
|
|
|
|
        # Offload the heavy Gemma submodules back to CPU after encoding.
        self._text_encoder.model.to("cpu")
|
|
self._text_encoder.feature_extractor_linear.to("cpu") |
|
|
|
|
|
return v_ctx_pos, a_ctx_pos, v_ctx_neg, a_ctx_neg |
|
|
|
|
|
def _decode_video( |
|
|
self, video_state: LatentState, device: torch.device, tiled_config: TiledDecodingConfig | None = None |
|
|
) -> Tensor: |
|
|
"""Decode video latents to pixel space. |
|
|
|
|
|
Args: |
|
|
video_state: Video latent state to decode |
|
|
device: Device to run decoding on |
|
|
tiled_config: Optional tiled decoding configuration for reduced VRAM usage |
|
|
|
|
|
Returns: |
|
|
Decoded video tensor [C, F, H, W] in [0, 1] range |
|
|
""" |
|
|
self._vae_decoder.to(device) |
|
|
|
|
|
latent = video_state.latent.to(dtype=torch.bfloat16) |
|
|
|
|
|
if tiled_config is not None and tiled_config.enabled: |
|
|
|
|
|
tiling_config = TilingConfig( |
|
|
spatial_config=SpatialTilingConfig( |
|
|
tile_size_in_pixels=tiled_config.tile_size_pixels, |
|
|
tile_overlap_in_pixels=tiled_config.tile_overlap_pixels, |
|
|
), |
|
|
temporal_config=TemporalTilingConfig( |
|
|
tile_size_in_frames=tiled_config.tile_size_frames, |
|
|
tile_overlap_in_frames=tiled_config.tile_overlap_frames, |
|
|
), |
|
|
) |
|
|
chunks = [] |
|
|
for video_chunk, _ in self._vae_decoder.tiled_decode( |
|
|
latent, |
|
|
tiling_config=tiling_config, |
|
|
): |
|
|
chunks.append(video_chunk) |
|
|
decoded_video = torch.cat(chunks, dim=2) |
|
|
else: |
|
|
|
|
|
decoded_video = self._vae_decoder(latent) |
|
|
|
|
|
decoded_video = ((decoded_video + 1.0) / 2.0).clamp(0.0, 1.0) |
|
|
self._vae_decoder.to("cpu") |
|
|
return decoded_video[0].float().cpu() |
|
|
|
|
|
def _decode_audio(self, audio_state: LatentState, device: torch.device) -> Tensor: |
|
|
"""Decode audio latents to waveform.""" |
|
|
self._audio_decoder.to(device) |
|
|
|
|
|
latent = audio_state.latent.to(dtype=torch.bfloat16) |
|
|
decoded_audio = self._audio_decoder(latent) |
|
|
self._audio_decoder.to("cpu") |
|
|
|
|
|
self._vocoder.to(device) |
|
|
audio_waveform = self._vocoder(decoded_audio) |
|
|
self._vocoder.to("cpu") |
|
|
|
|
|
return audio_waveform.squeeze(0).float().cpu() |
|
|
|
|
|
@staticmethod |
|
|
def _concatenate_videos_side_by_side(left_video: Tensor, right_video: Tensor) -> Tensor: |
|
|
"""Concatenate two videos side-by-side (horizontally). |
|
|
|
|
|
If the videos have different frame counts, the shorter one is padded with |
|
|
its last frame repeated. |
|
|
|
|
|
Args: |
|
|
left_video: Left video tensor [C, F1, H, W] in [0, 1] |
|
|
right_video: Right video tensor [C, F2, H, W] in [0, 1] |
|
|
|
|
|
Returns: |
|
|
Concatenated video tensor [C, max(F1,F2), H, W*2] in [0, 1] |
|
|
""" |
|
|
left_frames = left_video.shape[1] |
|
|
right_frames = right_video.shape[1] |
|
|
|
|
|
|
|
|
if left_frames < right_frames: |
|
|
padding = left_video[:, -1:, :, :].expand(-1, right_frames - left_frames, -1, -1) |
|
|
left_video = torch.cat([left_video, padding], dim=1) |
|
|
elif right_frames < left_frames: |
|
|
padding = right_video[:, -1:, :, :].expand(-1, left_frames - right_frames, -1, -1) |
|
|
right_video = torch.cat([right_video, padding], dim=1) |
|
|
|
|
|
|
|
|
return torch.cat([left_video, right_video], dim=3) |
|
|
|
|
|
def _encode_conditioning_image( |
|
|
self, |
|
|
image: Tensor, |
|
|
target_height: int, |
|
|
target_width: int, |
|
|
device: torch.device, |
|
|
) -> Tensor: |
|
|
"""Encode a conditioning image to latent space. |
|
|
|
|
|
The image is resized to cover the target dimensions while preserving aspect ratio, |
|
|
then center-cropped to exactly match the target size. |
|
|
""" |
|
|
|
|
|
current_height, current_width = image.shape[1:] |
|
|
|
|
|
|
|
|
if current_height != target_height or current_width != target_width: |
|
|
aspect_ratio = current_width / current_height |
|
|
target_aspect_ratio = target_width / target_height |
|
|
|
|
|
if aspect_ratio > target_aspect_ratio: |
|
|
|
|
|
resize_height = target_height |
|
|
resize_width = int(target_height * aspect_ratio) |
|
|
else: |
|
|
|
|
|
resize_height = int(target_width / aspect_ratio) |
|
|
resize_width = target_width |
|
|
|
|
|
image = rearrange(image, "c h w -> 1 c h w") |
|
|
image = torch.nn.functional.interpolate( |
|
|
image, size=(resize_height, resize_width), mode="bilinear", align_corners=False |
|
|
) |
|
|
|
|
|
|
|
|
h_start = (resize_height - target_height) // 2 |
|
|
w_start = (resize_width - target_width) // 2 |
|
|
image = image[:, :, h_start : h_start + target_height, w_start : w_start + target_width] |
|
|
else: |
|
|
image = rearrange(image, "c h w -> 1 c h w") |
|
|
|
|
|
|
|
|
image = rearrange(image, "b c h w -> b c 1 h w") |
|
|
image = (image * 2.0 - 1.0).to(device=device, dtype=torch.float32) |
|
|
|
|
|
|
|
|
self._vae_encoder.to(device) |
|
|
with torch.autocast(device_type=str(device).split(":")[0], dtype=torch.bfloat16): |
|
|
encoded = self._vae_encoder(image) |
|
|
self._vae_encoder.to("cpu") |
|
|
|
|
|
return encoded |
|
|
|