# Initial commit (ebfc6b3) by linoy
from pathlib import Path
from typing import Annotated, Literal
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, ValidationInfo, field_validator, model_validator
from ltx_trainer.quantization import QuantizationOptions
from ltx_trainer.training_strategies.base_strategy import TrainingStrategyConfigBase
from ltx_trainer.training_strategies.text_to_video import TextToVideoConfig
from ltx_trainer.training_strategies.video_to_video import VideoToVideoConfig
class ConfigBaseModel(BaseModel):
    """Base model for all trainer configuration sections."""

    # extra="forbid" makes pydantic raise on unknown keys instead of silently
    # ignoring them, so typos in config files surface as validation errors.
    model_config = ConfigDict(extra="forbid")
class ModelConfig(ConfigBaseModel):
    """Configuration for the base model and training mode"""

    # Base checkpoint to fine-tune; must exist locally (enforced by validate_model_path).
    model_path: str | Path = Field(
        ...,
        description="Model path - local path to safetensors checkpoint file",
    )
    text_encoder_path: str | Path | None = Field(
        default=None,
        description="Path to text encoder (required for LTX-2/Gemma models, optional for LTXV/T5 models)",
    )
    training_mode: Literal["lora", "full"] = Field(
        default="lora",
        description="Training mode - either LoRA fine-tuning or full model fine-tuning",
    )
    load_checkpoint: str | Path | None = Field(
        default=None,
        description="Path to a checkpoint file or directory to load from. "
        "If a directory is provided, the latest checkpoint will be used.",
    )

    @field_validator("model_path")
    @classmethod
    def validate_model_path(cls, v: str | Path) -> str | Path:
        """Reject URLs and require that model_path exists on the local filesystem."""
        as_text = str(v)
        if as_text.startswith("http://") or as_text.startswith("https://"):
            raise ValueError(f"Model path cannot be a URL: {v}")
        if Path(v).exists():
            return v
        raise ValueError(f"Model path does not exist: {v}")
class LoraConfig(ConfigBaseModel):
    """Configuration for LoRA fine-tuning"""

    # rank controls adapter capacity; alpha scales the adapter's contribution.
    rank: int = Field(default=64, description="Rank of LoRA adaptation", ge=2)
    alpha: int = Field(default=64, description="Alpha scaling factor for LoRA", ge=1)
    dropout: float = Field(default=0.0, description="Dropout probability for LoRA layers", ge=0.0, le=1.0)
    # Attention projection layers targeted by default.
    target_modules: list[str] = Field(
        default=["to_k", "to_q", "to_v", "to_out.0"],
        description="List of modules to target with LoRA",
    )
def _get_strategy_discriminator(v: dict | TrainingStrategyConfigBase) -> str:
    """Return the union tag for a strategy config, whether it arrives as a raw
    dict (from a config file) or as an already-constructed config object."""
    if not isinstance(v, dict):
        return v.name
    # Unnamed dicts fall back to the default strategy tag.
    return v.get("name", "text_to_video")
# Union type for all strategy configs with discriminator.
# Pydantic selects the concrete config class by matching the tag returned by
# _get_strategy_discriminator ("text_to_video" by default) against the Tag()
# annotations below.
TrainingStrategyConfig = Annotated[
    Annotated[TextToVideoConfig, Tag("text_to_video")] | Annotated[VideoToVideoConfig, Tag("video_to_video")],
    Discriminator(_get_strategy_discriminator),
]
class OptimizationConfig(ConfigBaseModel):
    """Configuration for optimization parameters"""

    learning_rate: float = Field(default=5e-4, description="Learning rate for optimization")
    steps: int = Field(default=3000, description="Number of training steps")
    batch_size: int = Field(default=2, description="Batch size for training")
    gradient_accumulation_steps: int = Field(default=1, description="Number of steps to accumulate gradients")
    max_grad_norm: float = Field(default=1.0, description="Maximum gradient norm for clipping")
    optimizer_type: Literal["adamw", "adamw8bit"] = Field(
        default="adamw",
        description="Type of optimizer to use for training",
    )
    scheduler_type: Literal["constant", "linear", "cosine", "cosine_with_restarts", "polynomial"] = Field(
        default="linear",
        description="Type of scheduler to use for training",
    )
    # Extra keyword arguments forwarded to the scheduler constructor.
    scheduler_params: dict = Field(default_factory=dict, description="Parameters for the scheduler")
    enable_gradient_checkpointing: bool = Field(
        default=False,
        description="Enable gradient checkpointing to save memory at the cost of slower training",
    )
class AccelerationConfig(ConfigBaseModel):
    """Configuration for hardware acceleration and compute optimization"""

    mixed_precision_mode: Literal["no", "fp16", "bf16"] | None = Field(
        default="bf16",
        description="Mixed precision training mode",
    )
    # None disables quantization entirely.
    quantization: QuantizationOptions | None = Field(default=None, description="Quantization precision to use")
    load_text_encoder_in_8bit: bool = Field(
        default=False,
        description="Whether to load the text encoder in 8-bit precision to save memory",
    )
class DataConfig(ConfigBaseModel):
    """Configuration for data loading and processing"""

    # Required: no default, so the config must always name the data location.
    preprocessed_data_root: str = Field(description="Path to folder containing preprocessed training data")
    num_dataloader_workers: int = Field(
        default=2,
        description="Number of background processes for data loading (0 means synchronous loading)",
        ge=0,
    )
class ValidationConfig(ConfigBaseModel):
    """Configuration for validation during training"""

    prompts: list[str] = Field(
        default_factory=list,
        description="List of prompts to use for validation",
    )
    negative_prompt: str = Field(
        default="worst quality, inconsistent motion, blurry, jittery, distorted",
        description="Negative prompt to use for validation examples",
    )
    images: list[str] | None = Field(
        default=None,
        description="List of image paths to use for validation. "
        "One image path must be provided for each validation prompt",
    )
    reference_videos: list[str] | None = Field(
        default=None,
        description="List of reference video paths to use for validation. "
        "One video path must be provided for each validation prompt",
    )
    video_dims: tuple[int, int, int] = Field(
        default=(960, 544, 97),
        description="Dimensions of validation videos (width, height, frames). "
        "Width and height must be divisible by 32. Frames must satisfy frames % 8 == 1 for LTX-2.",
    )
    frame_rate: float = Field(
        default=25.0,
        description="Frame rate for validation videos",
        gt=0,
    )
    seed: int = Field(
        default=42,
        description="Random seed used when sampling validation videos",
    )
    inference_steps: int = Field(
        default=50,
        description="Number of inference steps for validation",
        gt=0,
    )
    interval: int | None = Field(
        default=100,
        description="Number of steps between validation runs. If None, validation is disabled.",
        gt=0,
    )
    videos_per_prompt: int = Field(
        default=1,
        description="Number of videos to generate per validation prompt",
        gt=0,
    )
    guidance_scale: float = Field(
        default=3.0,
        description="CFG guidance scale to use during validation",
        ge=1.0,
    )
    stg_scale: float = Field(
        default=1.0,
        description="STG (Spatio-Temporal Guidance) scale. 0.0 disables STG. "
        "Recommended value is 1.0. STG is combined with CFG for improved video quality.",
        ge=0.0,
    )
    stg_blocks: list[int] | None = Field(
        default=[29],
        description="Which transformer blocks to perturb for STG. "
        "None means all blocks are perturbed. Recommended for LTX-2: [29].",
    )
    stg_mode: Literal["stg_av", "stg_v"] = Field(
        default="stg_av",
        description="STG mode: 'stg_av' skips both audio and video self-attention, "
        "'stg_v' skips only video self-attention.",
    )
    generate_audio: bool = Field(
        default=True,
        description="Whether to generate audio in validation samples. "
        "Independent of training strategy setting - you can generate audio "
        "in validation even when not training the audio branch.",
    )
    skip_initial_validation: bool = Field(
        default=False,
        description="Skip validation video sampling at step 0 (beginning of training)",
    )
    include_reference_in_output: bool = Field(
        default=False,
        description="For video-to-video training: concatenate the original reference video side-by-side "
        "with the generated output. The reference comes from the input video, not from the model's output.",
    )

    @field_validator("video_dims")
    @classmethod
    def validate_video_dims(cls, v: tuple[int, int, int]) -> tuple[int, int, int]:
        """Validate video dimensions for LTX-2 compatibility.

        Raises:
            ValueError: If width/height are not multiples of 32, or frames % 8 != 1.
        """
        width, height, frames = v
        if width % 32 != 0:
            raise ValueError(f"Width ({width}) must be divisible by 32")
        if height % 32 != 0:
            raise ValueError(f"Height ({height}) must be divisible by 32")
        if frames % 8 != 1:
            raise ValueError(f"Frames ({frames}) must satisfy frames % 8 == 1 for LTX-2 (e.g., 1, 9, 17, 25, ...)")
        return v

    @staticmethod
    def _check_media_paths(paths: list[str], num_prompts: int, count_label: str, path_label: str) -> None:
        """Shared check for media path lists: one path per prompt, every path must exist.

        Args:
            paths: The list of file paths to check.
            num_prompts: Number of validation prompts the list must match.
            count_label: Noun used in the length-mismatch error (e.g. "images").
            path_label: Noun used in the missing-path error (e.g. "Image path").

        Raises:
            ValueError: If the list length differs from the prompt count, or a path is missing.
        """
        if len(paths) != num_prompts:
            raise ValueError(f"Number of {count_label} ({len(paths)}) must match number of prompts ({num_prompts})")
        for path in paths:
            if not Path(path).exists():
                raise ValueError(f"{path_label} '{path}' does not exist")

    @field_validator("images")
    @classmethod
    def validate_images(cls, v: list[str] | None, info: ValidationInfo) -> list[str] | None:
        """Validate that number of images (if provided) matches number of prompts."""
        if v is None:
            return None
        # info.data only contains fields already validated; "prompts" is declared
        # before "images", so it is available here (empty list if it failed).
        cls._check_media_paths(v, len(info.data.get("prompts", [])), "images", "Image path")
        return v

    @field_validator("reference_videos")
    @classmethod
    def validate_reference_videos(cls, v: list[str] | None, info: ValidationInfo) -> list[str] | None:
        """Validate that number of reference videos (if provided) matches number of prompts."""
        if v is None:
            return None
        cls._check_media_paths(v, len(info.data.get("prompts", [])), "reference videos", "Reference video path")
        return v
class CheckpointsConfig(ConfigBaseModel):
    """Configuration for model checkpointing during training"""

    # None turns off intermediate checkpoint saving altogether.
    interval: int | None = Field(
        default=None,
        description="Number of steps between checkpoint saves. If None, intermediate checkpoints are disabled.",
        gt=0,
    )
    # -1 is a sentinel meaning "never delete old checkpoints".
    keep_last_n: int = Field(
        default=1,
        description="Number of most recent checkpoints to keep. Set to -1 to keep all checkpoints.",
        ge=-1,
    )
class HubConfig(ConfigBaseModel):
    """Configuration for Hugging Face Hub integration"""

    push_to_hub: bool = Field(default=False, description="Whether to push the model weights to the Hugging Face Hub")
    hub_model_id: str | None = Field(
        default=None, description="Hugging Face Hub repository ID (e.g., 'username/repo-name')"
    )

    @model_validator(mode="after")
    def validate_hub_config(self) -> "HubConfig":
        """Ensure a repository ID is set whenever pushing to the Hub is enabled."""
        missing_repo_id = self.push_to_hub and not self.hub_model_id
        if missing_repo_id:
            raise ValueError("hub_model_id must be specified when push_to_hub is True")
        return self
class WandbConfig(ConfigBaseModel):
    """Configuration for Weights & Biases logging"""

    enabled: bool = Field(default=False, description="Whether to enable W&B logging")
    project: str = Field(default="ltxv-trainer", description="W&B project name")
    entity: str | None = Field(default=None, description="W&B username or team")
    tags: list[str] = Field(default_factory=list, description="Tags to add to the W&B run")
    log_validation_videos: bool = Field(default=True, description="Whether to log validation videos to W&B")
class FlowMatchingConfig(ConfigBaseModel):
    """Configuration for flow matching training"""

    timestep_sampling_mode: Literal["uniform", "shifted_logit_normal"] = Field(
        default="shifted_logit_normal",
        description="Mode to use for timestep sampling",
    )
    # Extra keyword arguments for the chosen sampling mode.
    timestep_sampling_params: dict = Field(default_factory=dict, description="Parameters for timestep sampling")
class LtxTrainerConfig(ConfigBaseModel):
    """Unified configuration for LTXV training"""

    # Sub-configurations
    # NOTE: `model` is required. The previous `Field(default_factory=ModelConfig)`
    # could never produce a valid default because ModelConfig.model_path has no
    # default, so omitting `model` always failed — it just failed with a confusing
    # error raised from inside the factory instead of a clean "field required".
    model: ModelConfig
    lora: LoraConfig | None = Field(default=None)
    training_strategy: TrainingStrategyConfig = Field(
        default_factory=TextToVideoConfig,
        description="Training strategy configuration. Determines the training mode and its parameters.",
    )
    optimization: OptimizationConfig = Field(default_factory=OptimizationConfig)
    acceleration: AccelerationConfig = Field(default_factory=AccelerationConfig)
    data: DataConfig
    validation: ValidationConfig = Field(default_factory=ValidationConfig)
    checkpoints: CheckpointsConfig = Field(default_factory=CheckpointsConfig)
    hub: HubConfig = Field(default_factory=HubConfig)
    flow_matching: FlowMatchingConfig = Field(default_factory=FlowMatchingConfig)
    wandb: WandbConfig = Field(default_factory=WandbConfig)

    # General configuration
    seed: int = Field(
        default=42,
        description="Random seed for reproducibility",
    )
    output_dir: str = Field(
        default="outputs",
        description="Directory to save model outputs",
    )

    # noinspection PyNestedDecorators
    @field_validator("output_dir")
    @classmethod
    def expand_output_path(cls, v: str) -> str:
        """Expand user home directory in output path."""
        return str(Path(v).expanduser().resolve())

    @model_validator(mode="after")
    def validate_strategy_compatibility(self) -> "LtxTrainerConfig":
        """Validate that training strategy and other configurations are compatible.

        Raises:
            ValueError: If the strategy, training mode, and validation settings conflict.
        """
        is_video_to_video = self.training_strategy.name == "video_to_video"
        # video_to_video validation renders against reference clips, so they are
        # mandatory whenever validation is enabled (interval is not None).
        if is_video_to_video and self.validation.interval and not self.validation.reference_videos:
            raise ValueError(
                "reference_videos must be provided in validation config when using video_to_video strategy"
            )
        # LoRA training needs its hyper-parameters.
        if self.model.training_mode == "lora" and self.lora is None:
            raise ValueError("LoRA configuration must be provided when training_mode is 'lora'")
        # video_to_video is only supported as LoRA fine-tuning.
        if is_video_to_video and self.model.training_mode != "lora":
            raise ValueError("Training mode must be 'lora' when using video_to_video strategy")
        return self