from pathlib import Path
from typing import Annotated, Literal

from pydantic import (
    BaseModel,
    ConfigDict,
    Discriminator,
    Field,
    Tag,
    ValidationInfo,
    field_validator,
    model_validator,
)

from ltx_trainer.quantization import QuantizationOptions
from ltx_trainer.training_strategies.base_strategy import TrainingStrategyConfigBase
from ltx_trainer.training_strategies.text_to_video import TextToVideoConfig
from ltx_trainer.training_strategies.video_to_video import VideoToVideoConfig


class ConfigBaseModel(BaseModel):
    # Forbid unknown keys so typos in config files fail loudly, and disable
    # Pydantic's protected "model_" namespace so field names such as
    # `model_path` do not trigger namespace-conflict warnings.
    model_config = ConfigDict(extra="forbid", protected_namespaces=())


class ModelConfig(ConfigBaseModel):
    """Configuration for the base model and training mode"""

    model_path: str | Path = Field(
        ...,
        description="Model path - local path to safetensors checkpoint file",
    )

    text_encoder_path: str | Path | None = Field(
        default=None,
        description="Path to text encoder (required for LTX-2/Gemma models, optional for LTXV/T5 models)",
    )

    training_mode: Literal["lora", "full"] = Field(
        default="lora",
        description="Training mode - either LoRA fine-tuning or full model fine-tuning",
    )

    load_checkpoint: str | Path | None = Field(
        default=None,
        description="Path to a checkpoint file or directory to load from. "
        "If a directory is provided, the latest checkpoint will be used.",
    )

    @field_validator("model_path")
    @classmethod
    def validate_model_path(cls, v: str | Path) -> str | Path:
        """Validate that model_path points to an existing local path; remote URLs are rejected."""
        if str(v).startswith(("http://", "https://")):
            raise ValueError(f"Model path cannot be a URL: {v}")

        if not Path(v).exists():
            raise ValueError(f"Model path does not exist: {v}")

        return v


class LoraConfig(ConfigBaseModel):
    """Configuration for LoRA fine-tuning"""

    rank: int = Field(
        default=64,
        description="Rank of LoRA adaptation",
        ge=2,
    )

    alpha: int = Field(
        default=64,
        description="Alpha scaling factor for LoRA",
        ge=1,
    )

    dropout: float = Field(
        default=0.0,
        description="Dropout probability for LoRA layers",
        ge=0.0,
        le=1.0,
    )

    target_modules: list[str] = Field(
        default=["to_k", "to_q", "to_v", "to_out.0"],
        description="List of modules to target with LoRA",
    )
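

# Note: in the standard LoRA formulation (Hu et al.), the adapter update is
# scaled by alpha / rank before being added to the frozen weights:
#
#   h = W0 @ x + (alpha / rank) * B @ A @ x
#
# With the defaults above (alpha=64, rank=64) the update is applied at full
# strength; raising `rank` without raising `alpha` shrinks the effective scale.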


def _get_strategy_discriminator(v: dict | TrainingStrategyConfigBase) -> str:
    """Discriminator function for strategy config union."""
    if isinstance(v, dict):
        return v.get("name", "text_to_video")
    return v.name


TrainingStrategyConfig = Annotated[
    Annotated[TextToVideoConfig, Tag("text_to_video")] | Annotated[VideoToVideoConfig, Tag("video_to_video")],
    Discriminator(_get_strategy_discriminator),
]
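

# Illustrative sketch of how the discriminated union above resolves raw input;
# `TypeAdapter` is standard Pydantic v2, and the empty-dict case assumes all of
# TextToVideoConfig's fields have defaults:
#
#   from pydantic import TypeAdapter
#
#   adapter = TypeAdapter(TrainingStrategyConfig)
#   adapter.validate_python({"name": "video_to_video", ...})  # -> VideoToVideoConfig
#   adapter.validate_python({})  # no "name" key -> "text_to_video" tag -> TextToVideoConfig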


class OptimizationConfig(ConfigBaseModel):
    """Configuration for optimization parameters"""

    learning_rate: float = Field(
        default=5e-4,
        description="Learning rate for optimization",
    )

    steps: int = Field(
        default=3000,
        description="Number of training steps",
    )

    batch_size: int = Field(
        default=2,
        description="Batch size for training",
    )

    gradient_accumulation_steps: int = Field(
        default=1,
        description="Number of steps to accumulate gradients",
    )

    max_grad_norm: float = Field(
        default=1.0,
        description="Maximum gradient norm for clipping",
    )

    optimizer_type: Literal["adamw", "adamw8bit"] = Field(
        default="adamw",
        description="Type of optimizer to use for training",
    )

    scheduler_type: Literal[
        "constant",
        "linear",
        "cosine",
        "cosine_with_restarts",
        "polynomial",
    ] = Field(
        default="linear",
        description="Type of scheduler to use for training",
    )

    scheduler_params: dict = Field(
        default_factory=dict,
        description="Parameters for the scheduler",
    )

    enable_gradient_checkpointing: bool = Field(
        default=False,
        description="Enable gradient checkpointing to save memory at the cost of slower training",
    )
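

# Note: the effective batch size per optimizer step is
# batch_size * gradient_accumulation_steps (further multiplied by the number of
# processes when training on multiple devices). For example, batch_size=2 with
# gradient_accumulation_steps=4 accumulates over 8 samples before each
# parameter update.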


class AccelerationConfig(ConfigBaseModel):
    """Configuration for hardware acceleration and compute optimization"""

    mixed_precision_mode: Literal["no", "fp16", "bf16"] | None = Field(
        default="bf16",
        description="Mixed precision training mode",
    )

    quantization: QuantizationOptions | None = Field(
        default=None,
        description="Quantization precision to use",
    )

    load_text_encoder_in_8bit: bool = Field(
        default=False,
        description="Whether to load the text encoder in 8-bit precision to save memory",
    )


class DataConfig(ConfigBaseModel):
    """Configuration for data loading and processing"""

    preprocessed_data_root: str = Field(
        description="Path to folder containing preprocessed training data",
    )

    num_dataloader_workers: int = Field(
        default=2,
        description="Number of background processes for data loading (0 means synchronous loading)",
        ge=0,
    )


class ValidationConfig(ConfigBaseModel):
    """Configuration for validation during training"""

    prompts: list[str] = Field(
        default_factory=list,
        description="List of prompts to use for validation",
    )

    negative_prompt: str = Field(
        default="worst quality, inconsistent motion, blurry, jittery, distorted",
        description="Negative prompt to use for validation examples",
    )

    images: list[str] | None = Field(
        default=None,
        description="List of image paths to use for validation. "
        "One image path must be provided for each validation prompt",
    )

    reference_videos: list[str] | None = Field(
        default=None,
        description="List of reference video paths to use for validation. "
        "One video path must be provided for each validation prompt",
    )

    video_dims: tuple[int, int, int] = Field(
        default=(960, 544, 97),
        description="Dimensions of validation videos (width, height, frames). "
        "Width and height must be divisible by 32. Frames must satisfy frames % 8 == 1 for LTX-2.",
    )

    @field_validator("video_dims")
    @classmethod
    def validate_video_dims(cls, v: tuple[int, int, int]) -> tuple[int, int, int]:
        """Validate video dimensions for LTX-2 compatibility."""
        width, height, frames = v

        if width % 32 != 0:
            raise ValueError(f"Width ({width}) must be divisible by 32")
        if height % 32 != 0:
            raise ValueError(f"Height ({height}) must be divisible by 32")
        if frames % 8 != 1:
            raise ValueError(f"Frames ({frames}) must satisfy frames % 8 == 1 for LTX-2 (e.g., 1, 9, 17, 25, ...)")

        return v

    frame_rate: float = Field(
        default=25.0,
        description="Frame rate for validation videos",
        gt=0,
    )

    seed: int = Field(
        default=42,
        description="Random seed used when sampling validation videos",
    )

    inference_steps: int = Field(
        default=50,
        description="Number of inference steps for validation",
        gt=0,
    )

    interval: int | None = Field(
        default=100,
        description="Number of steps between validation runs. If None, validation is disabled.",
        gt=0,
    )

    videos_per_prompt: int = Field(
        default=1,
        description="Number of videos to generate per validation prompt",
        gt=0,
    )

    guidance_scale: float = Field(
        default=3.0,
        description="CFG guidance scale to use during validation",
        ge=1.0,
    )

    stg_scale: float = Field(
        default=1.0,
        description="STG (Spatio-Temporal Guidance) scale. 0.0 disables STG. "
        "Recommended value is 1.0. STG is combined with CFG for improved video quality.",
        ge=0.0,
    )

    stg_blocks: list[int] | None = Field(
        default=[29],
        description="Which transformer blocks to perturb for STG. "
        "None means all blocks are perturbed. Recommended for LTX-2: [29].",
    )

    stg_mode: Literal["stg_av", "stg_v"] = Field(
        default="stg_av",
        description="STG mode: 'stg_av' skips both audio and video self-attention, "
        "'stg_v' skips only video self-attention.",
    )

    generate_audio: bool = Field(
        default=True,
        description="Whether to generate audio in validation samples. "
        "Independent of training strategy setting - you can generate audio "
        "in validation even when not training the audio branch.",
    )

    skip_initial_validation: bool = Field(
        default=False,
        description="Skip validation video sampling at step 0 (beginning of training)",
    )

    include_reference_in_output: bool = Field(
        default=False,
        description="For video-to-video training: concatenate the original reference video side-by-side "
        "with the generated output. The reference comes from the input video, not from the model's output.",
    )

    @field_validator("images")
    @classmethod
    def validate_images(cls, v: list[str] | None, info: ValidationInfo) -> list[str] | None:
        """Validate that number of images (if provided) matches number of prompts."""
        if v is None:
            return None

        num_prompts = len(info.data.get("prompts", []))
        if len(v) != num_prompts:
            raise ValueError(f"Number of images ({len(v)}) must match number of prompts ({num_prompts})")

        for image_path in v:
            if not Path(image_path).exists():
                raise ValueError(f"Image path '{image_path}' does not exist")

        return v

    @field_validator("reference_videos")
    @classmethod
    def validate_reference_videos(cls, v: list[str] | None, info: ValidationInfo) -> list[str] | None:
        """Validate that number of reference videos (if provided) matches number of prompts."""
        if v is None:
            return None

        num_prompts = len(info.data.get("prompts", []))
        if len(v) != num_prompts:
            raise ValueError(f"Number of reference videos ({len(v)}) must match number of prompts ({num_prompts})")

        for video_path in v:
            if not Path(video_path).exists():
                raise ValueError(f"Reference video path '{video_path}' does not exist")

        return v


class CheckpointsConfig(ConfigBaseModel):
    """Configuration for model checkpointing during training"""

    interval: int | None = Field(
        default=None,
        description="Number of steps between checkpoint saves. If None, intermediate checkpoints are disabled.",
        gt=0,
    )

    keep_last_n: int = Field(
        default=1,
        description="Number of most recent checkpoints to keep. Set to -1 to keep all checkpoints.",
        ge=-1,
    )


class HubConfig(ConfigBaseModel):
    """Configuration for Hugging Face Hub integration"""

    push_to_hub: bool = Field(default=False, description="Whether to push the model weights to the Hugging Face Hub")
    hub_model_id: str | None = Field(
        default=None, description="Hugging Face Hub repository ID (e.g., 'username/repo-name')"
    )

    @model_validator(mode="after")
    def validate_hub_config(self) -> "HubConfig":
        """Validate that hub_model_id is not None when push_to_hub is True."""
        if self.push_to_hub and not self.hub_model_id:
            raise ValueError("hub_model_id must be specified when push_to_hub is True")
        return self


class WandbConfig(ConfigBaseModel):
    """Configuration for Weights & Biases logging"""

    enabled: bool = Field(
        default=False,
        description="Whether to enable W&B logging",
    )

    project: str = Field(
        default="ltxv-trainer",
        description="W&B project name",
    )

    entity: str | None = Field(
        default=None,
        description="W&B username or team",
    )

    tags: list[str] = Field(
        default_factory=list,
        description="Tags to add to the W&B run",
    )

    log_validation_videos: bool = Field(
        default=True,
        description="Whether to log validation videos to W&B",
    )


class FlowMatchingConfig(ConfigBaseModel):
    """Configuration for flow matching training"""

    timestep_sampling_mode: Literal["uniform", "shifted_logit_normal"] = Field(
        default="shifted_logit_normal",
        description="Mode to use for timestep sampling",
    )

    timestep_sampling_params: dict = Field(
        default_factory=dict,
        description="Parameters for timestep sampling",
    )


class LtxTrainerConfig(ConfigBaseModel):
    """Unified configuration for LTXV training"""

    # Sub-configurations. `model` and `data` carry required fields
    # (model_path, preprocessed_data_root) and therefore have no default;
    # a default_factory here would raise a ValidationError as soon as it ran.
    model: ModelConfig
    lora: LoraConfig | None = Field(default=None)
    training_strategy: TrainingStrategyConfig = Field(
        default_factory=TextToVideoConfig,
        description="Training strategy configuration. Determines the training mode and its parameters.",
    )
    optimization: OptimizationConfig = Field(default_factory=OptimizationConfig)
    acceleration: AccelerationConfig = Field(default_factory=AccelerationConfig)
    data: DataConfig
    validation: ValidationConfig = Field(default_factory=ValidationConfig)
    checkpoints: CheckpointsConfig = Field(default_factory=CheckpointsConfig)
    hub: HubConfig = Field(default_factory=HubConfig)
    flow_matching: FlowMatchingConfig = Field(default_factory=FlowMatchingConfig)
    wandb: WandbConfig = Field(default_factory=WandbConfig)

    # Global settings
    seed: int = Field(
        default=42,
        description="Random seed for reproducibility",
    )

    output_dir: str = Field(
        default="outputs",
        description="Directory to save model outputs",
    )

    @field_validator("output_dir")
    @classmethod
    def expand_output_path(cls, v: str) -> str:
        """Expand user home directory in output path."""
        return str(Path(v).expanduser().resolve())

    @model_validator(mode="after")
    def validate_strategy_compatibility(self) -> "LtxTrainerConfig":
        """Validate that training strategy and other configurations are compatible."""
        # video_to_video validation runs need reference videos to condition on.
        if (
            self.training_strategy.name == "video_to_video"
            and self.validation.interval
            and not self.validation.reference_videos
        ):
            raise ValueError(
                "reference_videos must be provided in validation config when using video_to_video strategy"
            )

        # LoRA mode requires an explicit LoRA configuration.
        if self.model.training_mode == "lora" and self.lora is None:
            raise ValueError("LoRA configuration must be provided when training_mode is 'lora'")

        # The video_to_video strategy currently supports LoRA fine-tuning only.
        if self.training_strategy.name == "video_to_video" and self.model.training_mode != "lora":
            raise ValueError("Training mode must be 'lora' when using video_to_video strategy")

        return self
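

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the trainer's CLI):
    # load a config from a YAML file and validate it. Assumes PyYAML is
    # installed and the file supplies the required `model.model_path` and
    # `data.preprocessed_data_root` fields.
    import sys

    import yaml

    with open(sys.argv[1]) as f:
        raw_config = yaml.safe_load(f)

    trainer_config = LtxTrainerConfig.model_validate(raw_config)
    print(trainer_config.model_dump_json(indent=2))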