import os
import time
import warnings
from pathlib import Path
from typing import Callable
import torch
import wandb
import yaml
from accelerate import Accelerator, DistributedType
from accelerate.utils import set_seed
from peft import LoraConfig, get_peft_model, get_peft_model_state_dict, set_peft_model_state_dict
from peft.tuners.tuners_utils import BaseTunerLayer
from peft.utils import ModulesToSaveWrapper
from pydantic import BaseModel
from safetensors.torch import load_file, save_file
from torch import Tensor
from torch.optim import AdamW
from torch.optim.lr_scheduler import (
CosineAnnealingLR,
CosineAnnealingWarmRestarts,
LinearLR,
LRScheduler,
PolynomialLR,
StepLR,
)
from torch.utils.data import DataLoader
from torchvision.transforms import functional as F # noqa: N812
from ltx_trainer import logger
from ltx_trainer.config import LtxTrainerConfig
from ltx_trainer.config_display import print_config
from ltx_trainer.datasets import PrecomputedDataset
from ltx_trainer.hf_hub_utils import push_to_hub
from ltx_trainer.model_loader import load_model as load_ltx_model
from ltx_trainer.model_loader import load_text_encoder
from ltx_trainer.progress import TrainingProgress
from ltx_trainer.quantization import quantize_model
from ltx_trainer.timestep_samplers import SAMPLERS
from ltx_trainer.training_strategies import get_training_strategy
from ltx_trainer.utils import get_gpu_memory_gb, open_image_as_srgb
from ltx_trainer.validation_sampler import CachedPromptEmbeddings, GenerationConfig, ValidationSampler
from ltx_trainer.video_utils import read_video, save_video
# Silence the Hugging Face tokenizers fork-parallelism warning by setting the flag explicitly
os.environ["TOKENIZERS_PARALLELISM"] = "true"
# Silence bitsandbytes warnings about casting
warnings.filterwarnings(
"ignore", message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization"
)
# Disable progress bars if not main process
IS_MAIN_PROCESS = os.environ.get("LOCAL_RANK", "0") == "0"
if not IS_MAIN_PROCESS:
from transformers.utils.logging import disable_progress_bar
disable_progress_bar()
StepCallback = Callable[[int, int, list[Path] | None], None]  # (step, total, sampled_video_paths) -> None
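# Example of a conforming callback (a minimal sketch; the function name and log
# message are illustrative, not part of the trainer's API):
#
#     def on_step(step: int, total: int, videos: list[Path] | None) -> None:
#         logger.info(f"{step}/{total} optimization steps done; latest samples: {videos}")
#
#     trainer.train(step_callback=on_step)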
# Interval, in data-loader (micro-batch) steps, between GPU memory usage samples
MEMORY_CHECK_INTERVAL = 200
class TrainingStats(BaseModel):
"""Statistics collected during training"""
total_time_seconds: float
steps_per_second: float
samples_per_second: float
peak_gpu_memory_gb: float
global_batch_size: int
num_processes: int
class LtxvTrainer:
def __init__(self, trainer_config: LtxTrainerConfig) -> None:
self._config = trainer_config
if IS_MAIN_PROCESS:
print_config(trainer_config)
self._training_strategy = get_training_strategy(self._config.training_strategy)
self._cached_validation_embeddings = self._load_text_encoder_and_cache_embeddings()
self._load_models()
self._setup_accelerator()
self._collect_trainable_params()
self._load_checkpoint()
self._prepare_models_for_training()
self._dataset = None
self._global_step = -1
self._checkpoint_paths = []
self._init_wandb()
def train( # noqa: PLR0912, PLR0915
self,
disable_progress_bars: bool = False,
step_callback: StepCallback | None = None,
) -> tuple[Path, TrainingStats]:
"""
Start the training process.
Returns:
Tuple of (saved_model_path, training_stats)
"""
device = self._accelerator.device
cfg = self._config
start_mem = get_gpu_memory_gb(device)
train_start_time = time.time()
# Use the same seed for all processes and ensure deterministic operations
set_seed(cfg.seed)
logger.debug(f"Process {self._accelerator.process_index} using seed: {cfg.seed}")
self._init_optimizer()
self._init_dataloader()
data_iter = iter(self._dataloader)
self._init_timestep_sampler()
# Synchronize all processes after initialization
self._accelerator.wait_for_everyone()
Path(cfg.output_dir).mkdir(parents=True, exist_ok=True)
# Save the training configuration as YAML
self._save_config()
logger.info("🚀 Starting training...")
# Create progress tracking (disabled for non-main processes or when explicitly disabled)
progress_enabled = IS_MAIN_PROCESS and not disable_progress_bars
progress = TrainingProgress(
enabled=progress_enabled,
total_steps=cfg.optimization.steps,
)
if IS_MAIN_PROCESS and disable_progress_bars:
logger.warning("Progress bars disabled. Intermediate status messages will be logged instead.")
self._transformer.train()
self._global_step = 0
peak_mem_during_training = start_mem
sampled_videos_paths = None
with progress:
# Initial validation before training starts
if cfg.validation.interval and not cfg.validation.skip_initial_validation:
sampled_videos_paths = self._sample_videos(progress)
if IS_MAIN_PROCESS and sampled_videos_paths and self._config.wandb.log_validation_videos:
self._log_validation_videos(sampled_videos_paths, cfg.validation.prompts)
self._accelerator.wait_for_everyone()
for step in range(cfg.optimization.steps * cfg.optimization.gradient_accumulation_steps):
# Get next batch, reset the dataloader if needed
try:
batch = next(data_iter)
except StopIteration:
data_iter = iter(self._dataloader)
batch = next(data_iter)
step_start_time = time.time()
with self._accelerator.accumulate(self._transformer):
is_optimization_step = (step + 1) % cfg.optimization.gradient_accumulation_steps == 0
if is_optimization_step:
self._global_step += 1
loss = self._training_step(batch)
self._accelerator.backward(loss)
if self._accelerator.sync_gradients and cfg.optimization.max_grad_norm > 0:
self._accelerator.clip_grad_norm_(
self._trainable_params,
cfg.optimization.max_grad_norm,
)
self._optimizer.step()
self._optimizer.zero_grad()
if self._lr_scheduler is not None:
self._lr_scheduler.step()
# Run validation if needed
if (
cfg.validation.interval
and self._global_step > 0
and self._global_step % cfg.validation.interval == 0
and is_optimization_step
):
if self._accelerator.distributed_type == DistributedType.FSDP:
# FSDP: All processes must participate in validation
sampled_videos_paths = self._sample_videos(progress)
if IS_MAIN_PROCESS and sampled_videos_paths and self._config.wandb.log_validation_videos:
self._log_validation_videos(sampled_videos_paths, cfg.validation.prompts)
# DDP: Only main process runs validation
elif IS_MAIN_PROCESS:
sampled_videos_paths = self._sample_videos(progress)
if sampled_videos_paths and self._config.wandb.log_validation_videos:
self._log_validation_videos(sampled_videos_paths, cfg.validation.prompts)
# Save checkpoint if needed
if (
cfg.checkpoints.interval
and self._global_step > 0
and self._global_step % cfg.checkpoints.interval == 0
and is_optimization_step
):
self._save_checkpoint()
self._accelerator.wait_for_everyone()
# Call step callback if provided
if step_callback and is_optimization_step:
step_callback(self._global_step, cfg.optimization.steps, sampled_videos_paths)
self._accelerator.wait_for_everyone()
# Update progress and log metrics
current_lr = self._optimizer.param_groups[0]["lr"]
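                # Scale the micro-batch time by the accumulation factor so step_time
                # reflects the wall time of one full optimizer step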
step_time = (time.time() - step_start_time) * cfg.optimization.gradient_accumulation_steps
progress.update_training(
loss=loss.item(),
lr=current_lr,
step_time=step_time,
advance=is_optimization_step,
)
# Log metrics to W&B (only on main process and optimization steps)
if IS_MAIN_PROCESS and is_optimization_step:
self._log_metrics(
{
"train/loss": loss.item(),
"train/learning_rate": current_lr,
"train/step_time": step_time,
"train/global_step": self._global_step,
}
)
# Fallback logging when progress bars are disabled
if disable_progress_bars and IS_MAIN_PROCESS and self._global_step % 20 == 0:
elapsed = time.time() - train_start_time
progress_percentage = self._global_step / cfg.optimization.steps
if progress_percentage > 0:
total_estimated = elapsed / progress_percentage
total_time = f"{total_estimated // 3600:.0f}h {(total_estimated % 3600) // 60:.0f}m"
else:
total_time = "calculating..."
logger.info(
f"Step {self._global_step}/{cfg.optimization.steps} - "
f"Loss: {loss.item():.4f}, LR: {current_lr:.2e}, "
f"Time/Step: {step_time:.2f}s, Total Time: {total_time}",
)
# Sample GPU memory periodically
if step % MEMORY_CHECK_INTERVAL == 0:
current_mem = get_gpu_memory_gb(device)
peak_mem_during_training = max(peak_mem_during_training, current_mem)
# Collect final stats
train_end_time = time.time()
end_mem = get_gpu_memory_gb(device)
peak_mem = max(start_mem, end_mem, peak_mem_during_training)
# Calculate steps/second over entire training
total_time_seconds = train_end_time - train_start_time
steps_per_second = cfg.optimization.steps / total_time_seconds
samples_per_second = steps_per_second * self._accelerator.num_processes * cfg.optimization.batch_size
stats = TrainingStats(
total_time_seconds=total_time_seconds,
steps_per_second=steps_per_second,
samples_per_second=samples_per_second,
peak_gpu_memory_gb=peak_mem,
num_processes=self._accelerator.num_processes,
global_batch_size=cfg.optimization.batch_size * self._accelerator.num_processes,
)
saved_path = self._save_checkpoint()
if IS_MAIN_PROCESS:
# Log the training statistics
self._log_training_stats(stats)
# Upload artifacts to hub if enabled
if cfg.hub.push_to_hub:
push_to_hub(saved_path, sampled_videos_paths, self._config)
# Log final stats to W&B
if self._wandb_run is not None:
self._log_metrics(
{
"stats/total_time_minutes": stats.total_time_seconds / 60,
"stats/steps_per_second": stats.steps_per_second,
"stats/samples_per_second": stats.samples_per_second,
"stats/peak_gpu_memory_gb": stats.peak_gpu_memory_gb,
}
)
self._wandb_run.finish()
self._accelerator.wait_for_everyone()
self._accelerator.end_training()
return saved_path, stats
def _training_step(self, batch: dict[str, dict[str, Tensor]]) -> Tensor:
"""Perform a single training step using the configured strategy."""
# Apply embedding connectors to transform pre-computed text embeddings
conditions = batch["conditions"]
video_embeds, audio_embeds, attention_mask = self._text_encoder._run_connectors(
conditions["prompt_embeds"], conditions["prompt_attention_mask"]
)
conditions["video_prompt_embeds"] = video_embeds
conditions["audio_prompt_embeds"] = audio_embeds
conditions["prompt_attention_mask"] = attention_mask
# Use strategy to prepare training inputs (returns ModelInputs with Modality objects)
model_inputs = self._training_strategy.prepare_training_inputs(batch, self._timestep_sampler)
# Run transformer forward pass with Modality-based interface
video_pred, audio_pred = self._transformer(
video=model_inputs.video,
audio=model_inputs.audio,
perturbations=None,
)
# Use strategy to compute loss
loss = self._training_strategy.compute_loss(video_pred, audio_pred, model_inputs)
return loss
def _load_text_encoder_and_cache_embeddings(self) -> list[CachedPromptEmbeddings] | None:
"""Load text encoder, computes and returns validation embeddings."""
# This method:
# 1. Loads the text encoder on GPU
# 2. If validation prompts are configured, computes and caches their embeddings
# 3. Unloads the heavy Gemma model while keeping the lightweight embedding connectors
        # The text encoder object is kept (as self._text_encoder), but with its model, tokenizer,
        # and feature_extractor_linear set to None. Only the embedding connectors remain for training.
# Load text encoder on GPU
logger.debug("Loading text encoder...")
if self._config.acceleration.load_text_encoder_in_8bit:
logger.warning(
"⚠️ load_text_encoder_in_8bit is set to True but 8-bit text encoder loading "
"is not currently implemented. The text encoder will be loaded in bfloat16 precision."
)
self._text_encoder = load_text_encoder(
checkpoint_path=self._config.model.model_path,
gemma_model_path=self._config.model.text_encoder_path,
device="cuda",
dtype=torch.bfloat16,
)
# Cache validation embeddings if prompts are configured
cached_embeddings = None
if self._config.validation.prompts:
logger.info(f"Pre-computing embeddings for {len(self._config.validation.prompts)} validation prompts...")
cached_embeddings = []
            with torch.inference_mode():
                # The negative prompt is shared by all validation prompts, so embed it once
                v_ctx_neg, a_ctx_neg, _ = self._text_encoder(self._config.validation.negative_prompt)
                for prompt in self._config.validation.prompts:
                    v_ctx_pos, a_ctx_pos, _ = self._text_encoder(prompt)
                    cached_embeddings.append(
                        CachedPromptEmbeddings(
                            video_context_positive=v_ctx_pos.cpu(),
                            audio_context_positive=a_ctx_pos.cpu(),
                            video_context_negative=v_ctx_neg.cpu() if v_ctx_neg is not None else None,
                            audio_context_negative=a_ctx_neg.cpu() if a_ctx_neg is not None else None,
                        )
                    )
# Unload heavy components to free VRAM, keeping only the embedding connectors
self._text_encoder.model = None
self._text_encoder.tokenizer = None
self._text_encoder.feature_extractor_linear = None
torch.cuda.empty_cache()
logger.debug("Validation prompt embeddings cached. Gemma model unloaded")
return cached_embeddings
def _load_models(self) -> None:
"""Load the LTX-2 model components."""
# Load audio components if:
# 1. Training strategy requires audio (training the audio branch), OR
# 2. Validation is configured to generate audio (even if not training audio)
load_audio = self._training_strategy.requires_audio or self._config.validation.generate_audio
# Check if we need VAE encoder (for image or reference video conditioning)
need_vae_encoder = (
self._config.validation.images is not None or self._config.validation.reference_videos is not None
)
# Load all model components (except text encoder - already handled)
components = load_ltx_model(
checkpoint_path=self._config.model.model_path,
device="cpu",
dtype=torch.bfloat16,
with_video_vae_encoder=need_vae_encoder, # Needed for image conditioning
with_video_vae_decoder=True, # Needed for validation sampling
with_audio_vae_decoder=load_audio,
with_vocoder=load_audio,
with_text_encoder=False, # Text encoder handled separately
)
# Extract components
self._transformer = components.transformer
self._vae_decoder = components.video_vae_decoder.to(dtype=torch.bfloat16)
self._vae_encoder = components.video_vae_encoder
if self._vae_encoder is not None:
self._vae_encoder = self._vae_encoder.to(dtype=torch.bfloat16)
self._scheduler = components.scheduler
self._audio_vae = components.audio_vae_decoder
self._vocoder = components.vocoder
# Note: self._text_encoder was set in _load_text_encoder_and_cache_embeddings
# Determine initial dtype based on training mode.
# Note: For FSDP + LoRA, we'll cast to FP32 later in _prepare_models_for_training()
# after the accelerator is set up, and we can detect FSDP.
transformer_dtype = torch.bfloat16 if self._config.model.training_mode == "lora" else torch.float32
self._transformer = self._transformer.to(dtype=transformer_dtype)
if self._config.acceleration.quantization is not None:
if self._config.model.training_mode == "full":
raise ValueError("Quantization is not supported in full training mode.")
logger.warning(f"Quantizing model with precision: {self._config.acceleration.quantization}")
self._transformer = quantize_model(
self._transformer,
precision=self._config.acceleration.quantization,
)
# Freeze all models. We later unfreeze the transformer based on training mode.
# Note: embedding_connectors are already frozen (they come from the frozen text encoder)
self._vae_decoder.requires_grad_(False)
if self._vae_encoder is not None:
self._vae_encoder.requires_grad_(False)
self._transformer.requires_grad_(False)
if self._audio_vae is not None:
self._audio_vae.requires_grad_(False)
if self._vocoder is not None:
self._vocoder.requires_grad_(False)
def _collect_trainable_params(self) -> None:
"""Collect trainable parameters based on training mode."""
if self._config.model.training_mode == "lora":
# For LoRA training, first set up LoRA layers
self._setup_lora()
elif self._config.model.training_mode == "full":
# For full training, unfreeze all transformer parameters
self._transformer.requires_grad_(True)
else:
raise ValueError(f"Unknown training mode: {self._config.model.training_mode}")
self._trainable_params = [p for p in self._transformer.parameters() if p.requires_grad]
logger.debug(f"Trainable params count: {sum(p.numel() for p in self._trainable_params):,}")
def _init_timestep_sampler(self) -> None:
"""Initialize the timestep sampler based on the config."""
sampler_cls = SAMPLERS[self._config.flow_matching.timestep_sampling_mode]
self._timestep_sampler = sampler_cls(**self._config.flow_matching.timestep_sampling_params)
def _setup_lora(self) -> None:
"""Configure LoRA adapters for the transformer. Only called in LoRA training mode."""
logger.debug(f"Adding LoRA adapter with rank {self._config.lora.rank}")
lora_config = LoraConfig(
r=self._config.lora.rank,
lora_alpha=self._config.lora.alpha,
target_modules=self._config.lora.target_modules,
lora_dropout=self._config.lora.dropout,
init_lora_weights=True,
)
# Wrap the transformer with PEFT to add LoRA layers
# noinspection PyTypeChecker
self._transformer = get_peft_model(self._transformer, lora_config)
def _load_checkpoint(self) -> None:
"""Load checkpoint if specified in config."""
if not self._config.model.load_checkpoint:
return
checkpoint_path = self._find_checkpoint(self._config.model.load_checkpoint)
if not checkpoint_path:
logger.warning(f"⚠️ Could not find checkpoint at {self._config.model.load_checkpoint}")
return
logger.info(f"📥 Loading checkpoint from {checkpoint_path}")
if self._config.model.training_mode == "full":
self._load_full_checkpoint(checkpoint_path)
else: # LoRA mode
self._load_lora_checkpoint(checkpoint_path)
def _load_full_checkpoint(self, checkpoint_path: Path) -> None:
"""Load full model checkpoint."""
state_dict = load_file(checkpoint_path)
self._transformer.load_state_dict(state_dict, strict=True)
logger.info("✅ Full model checkpoint loaded successfully")
def _load_lora_checkpoint(self, checkpoint_path: Path) -> None:
"""Load LoRA checkpoint with DDP/FSDP compatibility."""
state_dict = load_file(checkpoint_path)
# Adjust layer names to match internal format.
# (Weights are saved in ComfyUI-compatible format, with "diffusion_model." prefix)
state_dict = {k.replace("diffusion_model.", "", 1): v for k, v in state_dict.items()}
        # Load LoRA weights into the base PEFT model
base_model = self._transformer.get_base_model()
set_peft_model_state_dict(base_model, state_dict)
logger.info("✅ LoRA checkpoint loaded successfully")
def _prepare_models_for_training(self) -> None:
"""Prepare models for training with Accelerate."""
# For FSDP + LoRA: Cast entire model to FP32.
# FSDP requires uniform dtype across all parameters in wrapped modules.
# In LoRA mode, PEFT creates LoRA params in FP32 while base model is BF16.
# We cast the base model to FP32 to match the LoRA params.
if self._accelerator.distributed_type == DistributedType.FSDP and self._config.model.training_mode == "lora":
logger.debug("FSDP: casting transformer to FP32 for uniform dtype")
self._transformer = self._transformer.to(dtype=torch.float32)
# Enable gradient checkpointing if requested
# For PeftModel, we need to access the underlying base model
transformer = (
self._transformer.get_base_model() if hasattr(self._transformer, "get_base_model") else self._transformer
)
transformer.set_gradient_checkpointing(self._config.optimization.enable_gradient_checkpointing)
# Keep frozen models on CPU for memory efficiency
self._vae_decoder = self._vae_decoder.to("cpu")
if self._vae_encoder is not None:
self._vae_encoder = self._vae_encoder.to("cpu")
# Embedding connectors are already on GPU from _load_text_encoder_and_cache_embeddings
# noinspection PyTypeChecker
self._transformer = self._accelerator.prepare(self._transformer)
# Log GPU memory usage after model preparation
vram_usage_gb = torch.cuda.memory_allocated() / 1024**3
logger.debug(f"GPU memory usage after models preparation: {vram_usage_gb:.2f} GB")
@staticmethod
def _find_checkpoint(checkpoint_path: str | Path) -> Path | None:
"""Find the checkpoint file to load, handling both file and directory paths."""
checkpoint_path = Path(checkpoint_path)
if checkpoint_path.is_file():
            if checkpoint_path.suffix != ".safetensors":
                raise ValueError(f"Checkpoint file must have a .safetensors extension: {checkpoint_path}")
return checkpoint_path
if checkpoint_path.is_dir():
# Look for checkpoint files in the directory
checkpoints = list(checkpoint_path.rglob("*step_*.safetensors"))
if not checkpoints:
return None
# Sort by step number and return the latest
def _get_step_num(p: Path) -> int:
try:
return int(p.stem.split("step_")[1])
except (IndexError, ValueError):
return -1
latest = max(checkpoints, key=_get_step_num)
return latest
else:
raise ValueError(f"Invalid checkpoint path: {checkpoint_path}. Must be a file or directory.")
def _init_dataloader(self) -> None:
"""Initialize the training data loader using the strategy's data sources."""
if self._dataset is None:
# Get data sources from the training strategy
data_sources = self._training_strategy.get_data_sources()
self._dataset = PrecomputedDataset(self._config.data.preprocessed_data_root, data_sources=data_sources)
logger.debug(f"Loaded dataset with {len(self._dataset):,} samples from sources: {list(data_sources)}")
num_workers = self._config.data.num_dataloader_workers
dataloader = DataLoader(
self._dataset,
batch_size=self._config.optimization.batch_size,
shuffle=True,
drop_last=True,
num_workers=num_workers,
pin_memory=num_workers > 0,
persistent_workers=num_workers > 0,
)
self._dataloader = self._accelerator.prepare(dataloader)
def _init_lora_weights(self) -> None:
"""Initialize LoRA weights for the transformer."""
logger.debug("Initializing LoRA weights...")
for _, module in self._transformer.named_modules():
if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)):
module.reset_lora_parameters(adapter_name="default", init_lora_weights=True)
def _init_optimizer(self) -> None:
"""Initialize the optimizer and learning rate scheduler."""
opt_cfg = self._config.optimization
lr = opt_cfg.learning_rate
if opt_cfg.optimizer_type == "adamw":
optimizer = AdamW(self._trainable_params, lr=lr)
elif opt_cfg.optimizer_type == "adamw8bit":
# noinspection PyUnresolvedReferences
from bitsandbytes.optim import AdamW8bit # noqa: PLC0415
optimizer = AdamW8bit(self._trainable_params, lr=lr)
else:
raise ValueError(f"Unknown optimizer type: {opt_cfg.optimizer_type}")
# Add scheduler initialization
lr_scheduler = self._create_scheduler(optimizer)
# noinspection PyTypeChecker
self._optimizer, self._lr_scheduler = self._accelerator.prepare(optimizer, lr_scheduler)
def _create_scheduler(self, optimizer: torch.optim.Optimizer) -> LRScheduler | None:
"""Create learning rate scheduler based on config."""
scheduler_type = self._config.optimization.scheduler_type
steps = self._config.optimization.steps
params = self._config.optimization.scheduler_params or {}
if scheduler_type is None:
return None
if scheduler_type == "linear":
scheduler = LinearLR(
optimizer,
start_factor=params.pop("start_factor", 1.0),
end_factor=params.pop("end_factor", 0.1),
total_iters=steps,
**params,
)
elif scheduler_type == "cosine":
scheduler = CosineAnnealingLR(
optimizer,
T_max=steps,
eta_min=params.pop("eta_min", 0),
**params,
)
elif scheduler_type == "cosine_with_restarts":
scheduler = CosineAnnealingWarmRestarts(
optimizer,
T_0=params.pop("T_0", steps // 4), # First restart cycle length
T_mult=params.pop("T_mult", 1), # Multiplicative factor for cycle lengths
eta_min=params.pop("eta_min", 5e-5),
**params,
)
elif scheduler_type == "polynomial":
scheduler = PolynomialLR(
optimizer,
total_iters=steps,
power=params.pop("power", 1.0),
**params,
)
elif scheduler_type == "step":
scheduler = StepLR(
optimizer,
step_size=params.pop("step_size", steps // 2),
gamma=params.pop("gamma", 0.1),
**params,
)
elif scheduler_type == "constant":
scheduler = None
else:
raise ValueError(f"Unknown scheduler type: {scheduler_type}")
return scheduler
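    # Example scheduler configuration (a hedged sketch: the YAML keys mirror this
    # config's optimization.scheduler_type / optimization.scheduler_params fields,
    # but the values shown are illustrative, not recommended defaults):
    #
    #     optimization:
    #       scheduler_type: cosine_with_restarts
    #       scheduler_params:
    #         T_0: 250
    #         T_mult: 2
    #         eta_min: 1.0e-5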
def _setup_accelerator(self) -> None:
"""Initialize the Accelerator with the appropriate settings."""
# All distributed setup (DDP/FSDP, number of processes, etc.) is controlled by
# the user's Accelerate configuration (accelerate config / accelerate launch).
self._accelerator = Accelerator(
mixed_precision=self._config.acceleration.mixed_precision_mode,
gradient_accumulation_steps=self._config.optimization.gradient_accumulation_steps,
)
if self._accelerator.num_processes > 1:
logger.info(
f"{self._accelerator.distributed_type.value} distributed training enabled "
f"with {self._accelerator.num_processes} processes"
)
local_batch = self._config.optimization.batch_size
global_batch = self._config.optimization.batch_size * self._accelerator.num_processes
logger.info(f"Local batch size: {local_batch}, global batch size: {global_batch}")
# Log torch.compile status from Accelerate's dynamo plugin
is_compile_enabled = (
hasattr(self._accelerator.state, "dynamo_plugin") and self._accelerator.state.dynamo_plugin.backend != "NO"
)
if is_compile_enabled:
plugin = self._accelerator.state.dynamo_plugin
logger.info(f"🔥 torch.compile enabled via Accelerate: backend={plugin.backend}, mode={plugin.mode}")
if self._accelerator.distributed_type == DistributedType.FSDP:
logger.warning(
"⚠️ FSDP + torch.compile is experimental and may hang on the first training iteration. "
"If this occurs, disable torch.compile by removing dynamo_config from your Accelerate config."
)
if self._accelerator.distributed_type == DistributedType.FSDP and self._config.acceleration.quantization:
logger.warning(
f"FSDP with quantization ({self._config.acceleration.quantization}) may have compatibility issues."
"Monitor training stability and consider disabling quantization if issues arise."
)
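    # Typical launch flow (a sketch; the script name "train.py" is an illustrative
    # assumption, while `accelerate config` / `accelerate launch` are the standard
    # Accelerate CLI commands referenced above):
    #
    #     accelerate config                 # pick DDP or FSDP, number of processes, etc.
    #     accelerate launch train.py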
# Note: Use @torch.no_grad() instead of @torch.inference_mode() to avoid FSDP inplace update errors after validation
@torch.no_grad()
def _sample_videos(self, progress: TrainingProgress) -> list[Path] | None:
"""Run validation by generating videos from validation prompts."""
use_images = self._config.validation.images is not None
use_reference_videos = self._config.validation.reference_videos is not None
generate_audio = self._config.validation.generate_audio
inference_steps = self._config.validation.inference_steps
# Free up GPU memory before validation sampling.
# Zero gradients and empty the cache to reclaim memory.
self._optimizer.zero_grad(set_to_none=True)
torch.cuda.empty_cache()
# Start sampling progress tracking
sampling_ctx = progress.start_sampling(
num_prompts=len(self._config.validation.prompts),
num_steps=inference_steps,
)
# Create validation sampler with loaded models and progress tracking
sampler = ValidationSampler(
transformer=self._transformer,
vae_decoder=self._vae_decoder,
vae_encoder=self._vae_encoder,
text_encoder=None,
audio_decoder=self._audio_vae if generate_audio else None,
vocoder=self._vocoder if generate_audio else None,
sampling_context=sampling_ctx,
)
output_dir = Path(self._config.output_dir) / "samples"
output_dir.mkdir(exist_ok=True, parents=True)
video_paths = []
width, height, num_frames = self._config.validation.video_dims
for prompt_idx, prompt in enumerate(self._config.validation.prompts):
# Update progress to show current video
sampling_ctx.start_video(prompt_idx)
# Load conditioning image if provided
condition_image = None
if use_images:
image_path = self._config.validation.images[prompt_idx]
image = open_image_as_srgb(image_path)
# Convert PIL image to tensor [C, H, W] in [0, 1]
condition_image = F.to_tensor(image)
# Load reference video if provided (for IC-LoRA)
reference_video = None
if use_reference_videos:
ref_video_path = self._config.validation.reference_videos[prompt_idx]
# read_video returns [F, C, H, W] in [0, 1]
reference_video, _ = read_video(ref_video_path, max_frames=num_frames)
# Get cached embeddings for this prompt if available
cached_embeddings = (
self._cached_validation_embeddings[prompt_idx]
if self._cached_validation_embeddings is not None
else None
)
# Create generation config
gen_config = GenerationConfig(
prompt=prompt,
negative_prompt=self._config.validation.negative_prompt,
height=height,
width=width,
num_frames=num_frames,
frame_rate=self._config.validation.frame_rate,
num_inference_steps=inference_steps,
guidance_scale=self._config.validation.guidance_scale,
seed=self._config.validation.seed,
condition_image=condition_image,
reference_video=reference_video,
generate_audio=generate_audio,
include_reference_in_output=self._config.validation.include_reference_in_output,
cached_embeddings=cached_embeddings,
stg_scale=self._config.validation.stg_scale,
stg_blocks=self._config.validation.stg_blocks,
stg_mode=self._config.validation.stg_mode,
)
# Generate sample
video, audio = sampler.generate(
config=gen_config,
device=self._accelerator.device,
)
# Save video
if IS_MAIN_PROCESS:
video_path = output_dir / f"step_{self._global_step:06d}_{prompt_idx + 1}.mp4"
save_video(
video_tensor=video,
output_path=video_path,
fps=self._config.validation.frame_rate,
audio=audio,
audio_sample_rate=self._vocoder.output_sample_rate if audio is not None else None,
)
video_paths.append(video_path)
# Clean up progress tasks
sampling_ctx.cleanup()
# Clear GPU cache after validation
torch.cuda.empty_cache()
rel_outputs_path = output_dir.relative_to(self._config.output_dir)
logger.info(f"🎥 Validation samples for step {self._global_step} saved in {rel_outputs_path}")
return video_paths
@staticmethod
def _log_training_stats(stats: TrainingStats) -> None:
"""Log training statistics."""
stats_str = (
"📊 Training Statistics:\n"
f" - Total time: {stats.total_time_seconds / 60:.1f} minutes\n"
f" - Training speed: {stats.steps_per_second:.2f} steps/second\n"
f" - Samples/second: {stats.samples_per_second:.2f}\n"
f" - Peak GPU memory: {stats.peak_gpu_memory_gb:.2f} GB"
)
if stats.num_processes > 1:
stats_str += f"\n - Number of processes: {stats.num_processes}\n"
stats_str += f" - Global batch size: {stats.global_batch_size}"
logger.info(stats_str)
def _save_checkpoint(self) -> Path | None:
"""Save the model weights."""
is_lora = self._config.model.training_mode == "lora"
is_fsdp = self._accelerator.distributed_type == DistributedType.FSDP
# Prepare paths
save_dir = Path(self._config.output_dir) / "checkpoints"
prefix = "lora" if is_lora else "model"
filename = f"{prefix}_weights_step_{self._global_step:05d}.safetensors"
saved_weights_path = save_dir / filename
# Get state dict (collective operation - all processes must participate)
self._accelerator.wait_for_everyone()
full_state_dict = self._accelerator.get_state_dict(self._transformer)
if not IS_MAIN_PROCESS:
return None
save_dir.mkdir(exist_ok=True, parents=True)
# For LoRA: extract only adapter weights; for full: use as-is
if is_lora:
unwrapped = self._accelerator.unwrap_model(self._transformer, keep_torch_compile=False)
# For FSDP, pass full_state_dict since model params aren't directly accessible
state_dict = get_peft_model_state_dict(unwrapped, state_dict=full_state_dict if is_fsdp else None)
# Remove "base_model.model." prefix added by PEFT
state_dict = {k.replace("base_model.model.", "", 1): v for k, v in state_dict.items()}
# Convert to ComfyUI-compatible format (add "diffusion_model." prefix)
state_dict = {f"diffusion_model.{k}": v for k, v in state_dict.items()}
# Save to disk
save_file(state_dict, saved_weights_path)
else:
# Save to disk
self._accelerator.save(full_state_dict, saved_weights_path)
rel_path = saved_weights_path.relative_to(self._config.output_dir)
logger.info(f"💾 {prefix.capitalize()} weights for step {self._global_step} saved in {rel_path}")
# Keep track of checkpoint paths, and cleanup old checkpoints if needed
self._checkpoint_paths.append(saved_weights_path)
self._cleanup_checkpoints()
return saved_weights_path
def _cleanup_checkpoints(self) -> None:
"""Clean up old checkpoints."""
if 0 < self._config.checkpoints.keep_last_n < len(self._checkpoint_paths):
checkpoints_to_remove = self._checkpoint_paths[: -self._config.checkpoints.keep_last_n]
for old_checkpoint in checkpoints_to_remove:
if old_checkpoint.exists():
old_checkpoint.unlink()
logger.info(f"Removed old checkpoints: {old_checkpoint}")
# Update the list to only contain kept checkpoints
self._checkpoint_paths = self._checkpoint_paths[-self._config.checkpoints.keep_last_n :]
def _save_config(self) -> None:
"""Save the training configuration as a YAML file in the output directory."""
if not IS_MAIN_PROCESS:
return
config_path = Path(self._config.output_dir) / "training_config.yaml"
with open(config_path, "w") as f:
yaml.dump(self._config.model_dump(), f, default_flow_style=False, indent=2)
logger.info(f"💾 Training configuration saved to: {config_path.relative_to(self._config.output_dir)}")
def _init_wandb(self) -> None:
"""Initialize Weights & Biases run."""
if not self._config.wandb.enabled or not IS_MAIN_PROCESS:
self._wandb_run = None
return
wandb_config = self._config.wandb
run = wandb.init(
project=wandb_config.project,
entity=wandb_config.entity,
name=Path(self._config.output_dir).name,
tags=wandb_config.tags,
config=self._config.model_dump(),
)
self._wandb_run = run
def _log_metrics(self, metrics: dict[str, float]) -> None:
"""Log metrics to Weights & Biases."""
if self._wandb_run is not None:
self._wandb_run.log(metrics)
def _log_validation_videos(self, video_paths: list[Path], prompts: list[str]) -> None:
"""Log validation videos to Weights & Biases."""
if not self._config.wandb.log_validation_videos or self._wandb_run is None:
return
# Create lists of videos with their captions
validation_videos = [
wandb.Video(str(video_path), caption=prompt, format="mp4")
for video_path, prompt in zip(video_paths, prompts, strict=False)
]
# Log all videos at once
self._wandb_run.log(
{
"validation_videos": validation_videos,
},
step=self._global_step,
)
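# Example usage (a hedged sketch: "config.yaml" is an illustrative path; the
# available fields are whatever LtxTrainerConfig defines):
#
#     import yaml
#     from ltx_trainer.config import LtxTrainerConfig
#
#     with open("config.yaml") as f:
#         config = LtxTrainerConfig(**yaml.safe_load(f))
#
#     trainer = LtxvTrainer(config)
#     saved_path, stats = trainer.train()
#     print(f"Saved to {saved_path} ({stats.steps_per_second:.2f} steps/s)")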