Spaces:
Paused
Paused
| import os | |
| import time | |
| import warnings | |
| from pathlib import Path | |
| from typing import Callable | |
| import torch | |
| import wandb | |
| import yaml | |
| from accelerate import Accelerator, DistributedType | |
| from accelerate.utils import set_seed | |
| from peft import LoraConfig, get_peft_model, get_peft_model_state_dict, set_peft_model_state_dict | |
| from peft.tuners.tuners_utils import BaseTunerLayer | |
| from peft.utils import ModulesToSaveWrapper | |
| from pydantic import BaseModel | |
| from safetensors.torch import load_file, save_file | |
| from torch import Tensor | |
| from torch.optim import AdamW | |
| from torch.optim.lr_scheduler import ( | |
| CosineAnnealingLR, | |
| CosineAnnealingWarmRestarts, | |
| LinearLR, | |
| LRScheduler, | |
| PolynomialLR, | |
| StepLR, | |
| ) | |
| from torch.utils.data import DataLoader | |
| from torchvision.transforms import functional as F # noqa: N812 | |
| from ltx_trainer import logger | |
| from ltx_trainer.config import LtxTrainerConfig | |
| from ltx_trainer.config_display import print_config | |
| from ltx_trainer.datasets import PrecomputedDataset | |
| from ltx_trainer.hf_hub_utils import push_to_hub | |
| from ltx_trainer.model_loader import load_model as load_ltx_model | |
| from ltx_trainer.model_loader import load_text_encoder | |
| from ltx_trainer.progress import TrainingProgress | |
| from ltx_trainer.quantization import quantize_model | |
| from ltx_trainer.timestep_samplers import SAMPLERS | |
| from ltx_trainer.training_strategies import get_training_strategy | |
| from ltx_trainer.utils import get_gpu_memory_gb, open_image_as_srgb | |
| from ltx_trainer.validation_sampler import CachedPromptEmbeddings, GenerationConfig, ValidationSampler | |
| from ltx_trainer.video_utils import read_video, save_video | |
# Disable irrelevant warnings from transformers
# (setting TOKENIZERS_PARALLELISM explicitly silences the tokenizer fork warning)
os.environ["TOKENIZERS_PARALLELISM"] = "true"

# Silence bitsandbytes warnings about casting
warnings.filterwarnings(
    "ignore", message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization"
)

# Disable progress bars if not main process.
# LOCAL_RANK is set by the distributed launcher; an unset value means single-process.
IS_MAIN_PROCESS = os.environ.get("LOCAL_RANK", "0") == "0"
if not IS_MAIN_PROCESS:
    from transformers.utils.logging import disable_progress_bar

    disable_progress_bar()

# Callback signature: (step, total, list[sampled_video_path]) -> None
StepCallback = Callable[[int, int, list[Path]], None]

# How often (in micro-steps) GPU memory usage is sampled during training.
MEMORY_CHECK_INTERVAL = 200
class TrainingStats(BaseModel):
    """Statistics collected during training"""

    # Wall-clock duration of the whole training run, in seconds.
    total_time_seconds: float
    # Optimization steps per second over the entire run.
    steps_per_second: float
    # steps/s * num_processes * per-process batch size.
    samples_per_second: float
    # Peak GPU memory observed (sampled periodically), in GB.
    peak_gpu_memory_gb: float
    # Per-process batch size multiplied by the number of processes.
    global_batch_size: int
    # Number of distributed processes used for training.
    num_processes: int
| class LtxvTrainer: | |
    def __init__(self, trainer_config: LtxTrainerConfig) -> None:
        """Build all training state from the config.

        NOTE(review): call order below is load-bearing — validation embeddings
        are cached (and the heavy text encoder unloaded) before the transformer
        is loaded, and the accelerator must exist before models are prepared.
        """
        self._config = trainer_config

        if IS_MAIN_PROCESS:
            print_config(trainer_config)

        self._training_strategy = get_training_strategy(self._config.training_strategy)

        # Loads the text encoder, caches validation prompt embeddings, then
        # unloads the heavy model (keeping only the embedding connectors).
        self._cached_validation_embeddings = self._load_text_encoder_and_cache_embeddings()

        self._load_models()
        self._setup_accelerator()
        self._collect_trainable_params()
        self._load_checkpoint()
        self._prepare_models_for_training()

        self._dataset = None  # Created lazily in _init_dataloader()
        self._global_step = -1  # Stays -1 until train() starts stepping
        self._checkpoint_paths = []  # Paths of checkpoints saved so far
        self._init_wandb()
    def train(  # noqa: PLR0912, PLR0915
        self,
        disable_progress_bars: bool = False,
        step_callback: StepCallback | None = None,
    ) -> tuple[Path, TrainingStats]:
        """
        Start the training process.

        Args:
            disable_progress_bars: If True, periodic status lines are logged
                instead of rendering progress bars.
            step_callback: Optional callable invoked after each optimization
                step with (global_step, total_steps, sampled_video_paths).

        Returns:
            Tuple of (saved_model_path, training_stats)
        """
        device = self._accelerator.device
        cfg = self._config
        start_mem = get_gpu_memory_gb(device)
        train_start_time = time.time()

        # Use the same seed for all processes and ensure deterministic operations
        set_seed(cfg.seed)
        logger.debug(f"Process {self._accelerator.process_index} using seed: {cfg.seed}")

        self._init_optimizer()
        self._init_dataloader()
        data_iter = iter(self._dataloader)
        self._init_timestep_sampler()

        # Synchronize all processes after initialization
        self._accelerator.wait_for_everyone()

        Path(cfg.output_dir).mkdir(parents=True, exist_ok=True)

        # Save the training configuration as YAML
        self._save_config()

        logger.info("🚀 Starting training...")

        # Create progress tracking (disabled for non-main processes or when explicitly disabled)
        progress_enabled = IS_MAIN_PROCESS and not disable_progress_bars
        progress = TrainingProgress(
            enabled=progress_enabled,
            total_steps=cfg.optimization.steps,
        )
        if IS_MAIN_PROCESS and disable_progress_bars:
            logger.warning("Progress bars disabled. Intermediate status messages will be logged instead.")

        self._transformer.train()
        self._global_step = 0
        peak_mem_during_training = start_mem
        sampled_videos_paths = None

        with progress:
            # Initial validation before training starts
            if cfg.validation.interval and not cfg.validation.skip_initial_validation:
                sampled_videos_paths = self._sample_videos(progress)
                if IS_MAIN_PROCESS and sampled_videos_paths and self._config.wandb.log_validation_videos:
                    self._log_validation_videos(sampled_videos_paths, cfg.validation.prompts)
                self._accelerator.wait_for_everyone()

            # One micro-step per iteration; an optimization step completes every
            # `gradient_accumulation_steps` micro-steps.
            for step in range(cfg.optimization.steps * cfg.optimization.gradient_accumulation_steps):
                # Get next batch, reset the dataloader if needed
                try:
                    batch = next(data_iter)
                except StopIteration:
                    data_iter = iter(self._dataloader)
                    batch = next(data_iter)

                step_start_time = time.time()

                with self._accelerator.accumulate(self._transformer):
                    is_optimization_step = (step + 1) % cfg.optimization.gradient_accumulation_steps == 0
                    if is_optimization_step:
                        self._global_step += 1

                    loss = self._training_step(batch)
                    self._accelerator.backward(loss)

                    # Clip only when gradients are synced (accumulation boundary)
                    if self._accelerator.sync_gradients and cfg.optimization.max_grad_norm > 0:
                        self._accelerator.clip_grad_norm_(
                            self._trainable_params,
                            cfg.optimization.max_grad_norm,
                        )

                    self._optimizer.step()
                    self._optimizer.zero_grad()
                    if self._lr_scheduler is not None:
                        self._lr_scheduler.step()

                # Run validation if needed
                if (
                    cfg.validation.interval
                    and self._global_step > 0
                    and self._global_step % cfg.validation.interval == 0
                    and is_optimization_step
                ):
                    if self._accelerator.distributed_type == DistributedType.FSDP:
                        # FSDP: All processes must participate in validation
                        sampled_videos_paths = self._sample_videos(progress)
                        if IS_MAIN_PROCESS and sampled_videos_paths and self._config.wandb.log_validation_videos:
                            self._log_validation_videos(sampled_videos_paths, cfg.validation.prompts)
                    # DDP: Only main process runs validation
                    elif IS_MAIN_PROCESS:
                        sampled_videos_paths = self._sample_videos(progress)
                        if sampled_videos_paths and self._config.wandb.log_validation_videos:
                            self._log_validation_videos(sampled_videos_paths, cfg.validation.prompts)

                # Save checkpoint if needed
                if (
                    cfg.checkpoints.interval
                    and self._global_step > 0
                    and self._global_step % cfg.checkpoints.interval == 0
                    and is_optimization_step
                ):
                    self._save_checkpoint()
                    self._accelerator.wait_for_everyone()

                # Call step callback if provided
                if step_callback and is_optimization_step:
                    step_callback(self._global_step, cfg.optimization.steps, sampled_videos_paths)
                    self._accelerator.wait_for_everyone()

                # Update progress and log metrics
                current_lr = self._optimizer.param_groups[0]["lr"]
                # Scale micro-step duration up to approximate a full optimization step
                step_time = (time.time() - step_start_time) * cfg.optimization.gradient_accumulation_steps
                progress.update_training(
                    loss=loss.item(),
                    lr=current_lr,
                    step_time=step_time,
                    advance=is_optimization_step,
                )

                # Log metrics to W&B (only on main process and optimization steps)
                if IS_MAIN_PROCESS and is_optimization_step:
                    self._log_metrics(
                        {
                            "train/loss": loss.item(),
                            "train/learning_rate": current_lr,
                            "train/step_time": step_time,
                            "train/global_step": self._global_step,
                        }
                    )

                # Fallback logging when progress bars are disabled
                if disable_progress_bars and IS_MAIN_PROCESS and self._global_step % 20 == 0:
                    elapsed = time.time() - train_start_time
                    progress_percentage = self._global_step / cfg.optimization.steps
                    if progress_percentage > 0:
                        # Linear extrapolation of total run time from progress so far
                        total_estimated = elapsed / progress_percentage
                        total_time = f"{total_estimated // 3600:.0f}h {(total_estimated % 3600) // 60:.0f}m"
                    else:
                        total_time = "calculating..."
                    logger.info(
                        f"Step {self._global_step}/{cfg.optimization.steps} - "
                        f"Loss: {loss.item():.4f}, LR: {current_lr:.2e}, "
                        f"Time/Step: {step_time:.2f}s, Total Time: {total_time}",
                    )

                # Sample GPU memory periodically
                if step % MEMORY_CHECK_INTERVAL == 0:
                    current_mem = get_gpu_memory_gb(device)
                    peak_mem_during_training = max(peak_mem_during_training, current_mem)

        # Collect final stats
        train_end_time = time.time()
        end_mem = get_gpu_memory_gb(device)
        peak_mem = max(start_mem, end_mem, peak_mem_during_training)

        # Calculate steps/second over entire training
        total_time_seconds = train_end_time - train_start_time
        steps_per_second = cfg.optimization.steps / total_time_seconds
        samples_per_second = steps_per_second * self._accelerator.num_processes * cfg.optimization.batch_size

        stats = TrainingStats(
            total_time_seconds=total_time_seconds,
            steps_per_second=steps_per_second,
            samples_per_second=samples_per_second,
            peak_gpu_memory_gb=peak_mem,
            num_processes=self._accelerator.num_processes,
            global_batch_size=cfg.optimization.batch_size * self._accelerator.num_processes,
        )

        # Always save a final checkpoint at the end of training
        saved_path = self._save_checkpoint()

        if IS_MAIN_PROCESS:
            # Log the training statistics
            self._log_training_stats(stats)

            # Upload artifacts to hub if enabled
            if cfg.hub.push_to_hub:
                push_to_hub(saved_path, sampled_videos_paths, self._config)

            # Log final stats to W&B
            if self._wandb_run is not None:
                self._log_metrics(
                    {
                        "stats/total_time_minutes": stats.total_time_seconds / 60,
                        "stats/steps_per_second": stats.steps_per_second,
                        "stats/samples_per_second": stats.samples_per_second,
                        "stats/peak_gpu_memory_gb": stats.peak_gpu_memory_gb,
                    }
                )
                self._wandb_run.finish()

        self._accelerator.wait_for_everyone()
        self._accelerator.end_training()

        return saved_path, stats
    def _training_step(self, batch: dict[str, dict[str, Tensor]]) -> Tensor:
        """Perform a single training step using the configured strategy.

        Args:
            batch: Precomputed batch; ``batch["conditions"]`` must contain
                "prompt_embeds" and "prompt_attention_mask" tensors.

        Returns:
            The training loss computed by the training strategy.
        """
        # Apply embedding connectors to transform pre-computed text embeddings.
        # NOTE(review): calls the text encoder's private _run_connectors();
        # only the lightweight connectors remain loaded at this point (the
        # Gemma model was unloaded in _load_text_encoder_and_cache_embeddings).
        conditions = batch["conditions"]
        video_embeds, audio_embeds, attention_mask = self._text_encoder._run_connectors(
            conditions["prompt_embeds"], conditions["prompt_attention_mask"]
        )
        # Mutates the batch in place so the strategy sees the connector outputs
        conditions["video_prompt_embeds"] = video_embeds
        conditions["audio_prompt_embeds"] = audio_embeds
        conditions["prompt_attention_mask"] = attention_mask

        # Use strategy to prepare training inputs (returns ModelInputs with Modality objects)
        model_inputs = self._training_strategy.prepare_training_inputs(batch, self._timestep_sampler)

        # Run transformer forward pass with Modality-based interface
        video_pred, audio_pred = self._transformer(
            video=model_inputs.video,
            audio=model_inputs.audio,
            perturbations=None,
        )

        # Use strategy to compute loss
        loss = self._training_strategy.compute_loss(video_pred, audio_pred, model_inputs)
        return loss
    def _load_text_encoder_and_cache_embeddings(self) -> list[CachedPromptEmbeddings] | None:
        """Load text encoder, computes and returns validation embeddings.

        This method:
        1. Loads the text encoder on GPU
        2. If validation prompts are configured, computes and caches their embeddings
        3. Unloads the heavy Gemma model while keeping the lightweight embedding connectors

        The text encoder is kept (as self._text_encoder) but with model/tokenizer/feature_extractor
        set to None. Only the embedding connectors remain for use during training.

        Returns:
            One CachedPromptEmbeddings per validation prompt, or None when no
            validation prompts are configured.
        """
        # Load text encoder on GPU
        logger.debug("Loading text encoder...")
        if self._config.acceleration.load_text_encoder_in_8bit:
            logger.warning(
                "⚠️ load_text_encoder_in_8bit is set to True but 8-bit text encoder loading "
                "is not currently implemented. The text encoder will be loaded in bfloat16 precision."
            )
        self._text_encoder = load_text_encoder(
            checkpoint_path=self._config.model.model_path,
            gemma_model_path=self._config.model.text_encoder_path,
            device="cuda",
            dtype=torch.bfloat16,
        )

        # Cache validation embeddings if prompts are configured
        cached_embeddings = None
        if self._config.validation.prompts:
            logger.info(f"Pre-computing embeddings for {len(self._config.validation.prompts)} validation prompts...")
            cached_embeddings = []
            with torch.inference_mode():
                for prompt in self._config.validation.prompts:
                    # Each encoder call returns (video_context, audio_context, mask);
                    # the mask is not needed for cached validation sampling.
                    v_ctx_pos, a_ctx_pos, _ = self._text_encoder(prompt)
                    v_ctx_neg, a_ctx_neg, _ = self._text_encoder(self._config.validation.negative_prompt)
                    cached_embeddings.append(
                        CachedPromptEmbeddings(
                            # Kept on CPU to free GPU memory until sampling time
                            video_context_positive=v_ctx_pos.cpu(),
                            audio_context_positive=a_ctx_pos.cpu(),
                            video_context_negative=v_ctx_neg.cpu() if v_ctx_neg is not None else None,
                            audio_context_negative=a_ctx_neg.cpu() if a_ctx_neg is not None else None,
                        )
                    )

        # Unload heavy components to free VRAM, keeping only the embedding connectors
        self._text_encoder.model = None
        self._text_encoder.tokenizer = None
        self._text_encoder.feature_extractor_linear = None
        torch.cuda.empty_cache()
        logger.debug("Validation prompt embeddings cached. Gemma model unloaded")

        return cached_embeddings
    def _load_models(self) -> None:
        """Load the LTX-2 model components.

        Loads transformer, VAEs, scheduler and (optionally) audio components
        onto CPU in bfloat16; device placement happens later via Accelerate.
        """
        # Load audio components if:
        # 1. Training strategy requires audio (training the audio branch), OR
        # 2. Validation is configured to generate audio (even if not training audio)
        load_audio = self._training_strategy.requires_audio or self._config.validation.generate_audio

        # Check if we need VAE encoder (for image or reference video conditioning)
        need_vae_encoder = (
            self._config.validation.images is not None or self._config.validation.reference_videos is not None
        )

        # Load all model components (except text encoder - already handled)
        components = load_ltx_model(
            checkpoint_path=self._config.model.model_path,
            device="cpu",
            dtype=torch.bfloat16,
            with_video_vae_encoder=need_vae_encoder,  # Needed for image conditioning
            with_video_vae_decoder=True,  # Needed for validation sampling
            with_audio_vae_decoder=load_audio,
            with_vocoder=load_audio,
            with_text_encoder=False,  # Text encoder handled separately
        )

        # Extract components
        self._transformer = components.transformer
        self._vae_decoder = components.video_vae_decoder.to(dtype=torch.bfloat16)
        self._vae_encoder = components.video_vae_encoder
        if self._vae_encoder is not None:
            self._vae_encoder = self._vae_encoder.to(dtype=torch.bfloat16)
        self._scheduler = components.scheduler
        self._audio_vae = components.audio_vae_decoder
        self._vocoder = components.vocoder
        # Note: self._text_encoder was set in _load_text_encoder_and_cache_embeddings

        # Determine initial dtype based on training mode.
        # Note: For FSDP + LoRA, we'll cast to FP32 later in _prepare_models_for_training()
        # after the accelerator is set up, and we can detect FSDP.
        transformer_dtype = torch.bfloat16 if self._config.model.training_mode == "lora" else torch.float32
        self._transformer = self._transformer.to(dtype=transformer_dtype)

        if self._config.acceleration.quantization is not None:
            if self._config.model.training_mode == "full":
                raise ValueError("Quantization is not supported in full training mode.")
            logger.warning(f"Quantizing model with precision: {self._config.acceleration.quantization}")
            self._transformer = quantize_model(
                self._transformer,
                precision=self._config.acceleration.quantization,
            )

        # Freeze all models. We later unfreeze the transformer based on training mode.
        # Note: embedding_connectors are already frozen (they come from the frozen text encoder)
        self._vae_decoder.requires_grad_(False)
        if self._vae_encoder is not None:
            self._vae_encoder.requires_grad_(False)
        self._transformer.requires_grad_(False)
        if self._audio_vae is not None:
            self._audio_vae.requires_grad_(False)
        if self._vocoder is not None:
            self._vocoder.requires_grad_(False)
| def _collect_trainable_params(self) -> None: | |
| """Collect trainable parameters based on training mode.""" | |
| if self._config.model.training_mode == "lora": | |
| # For LoRA training, first set up LoRA layers | |
| self._setup_lora() | |
| elif self._config.model.training_mode == "full": | |
| # For full training, unfreeze all transformer parameters | |
| self._transformer.requires_grad_(True) | |
| else: | |
| raise ValueError(f"Unknown training mode: {self._config.model.training_mode}") | |
| self._trainable_params = [p for p in self._transformer.parameters() if p.requires_grad] | |
| logger.debug(f"Trainable params count: {sum(p.numel() for p in self._trainable_params):,}") | |
| def _init_timestep_sampler(self) -> None: | |
| """Initialize the timestep sampler based on the config.""" | |
| sampler_cls = SAMPLERS[self._config.flow_matching.timestep_sampling_mode] | |
| self._timestep_sampler = sampler_cls(**self._config.flow_matching.timestep_sampling_params) | |
| def _setup_lora(self) -> None: | |
| """Configure LoRA adapters for the transformer. Only called in LoRA training mode.""" | |
| logger.debug(f"Adding LoRA adapter with rank {self._config.lora.rank}") | |
| lora_config = LoraConfig( | |
| r=self._config.lora.rank, | |
| lora_alpha=self._config.lora.alpha, | |
| target_modules=self._config.lora.target_modules, | |
| lora_dropout=self._config.lora.dropout, | |
| init_lora_weights=True, | |
| ) | |
| # Wrap the transformer with PEFT to add LoRA layers | |
| # noinspection PyTypeChecker | |
| self._transformer = get_peft_model(self._transformer, lora_config) | |
| def _load_checkpoint(self) -> None: | |
| """Load checkpoint if specified in config.""" | |
| if not self._config.model.load_checkpoint: | |
| return | |
| checkpoint_path = self._find_checkpoint(self._config.model.load_checkpoint) | |
| if not checkpoint_path: | |
| logger.warning(f"⚠️ Could not find checkpoint at {self._config.model.load_checkpoint}") | |
| return | |
| logger.info(f"📥 Loading checkpoint from {checkpoint_path}") | |
| if self._config.model.training_mode == "full": | |
| self._load_full_checkpoint(checkpoint_path) | |
| else: # LoRA mode | |
| self._load_lora_checkpoint(checkpoint_path) | |
| def _load_full_checkpoint(self, checkpoint_path: Path) -> None: | |
| """Load full model checkpoint.""" | |
| state_dict = load_file(checkpoint_path) | |
| self._transformer.load_state_dict(state_dict, strict=True) | |
| logger.info("✅ Full model checkpoint loaded successfully") | |
| def _load_lora_checkpoint(self, checkpoint_path: Path) -> None: | |
| """Load LoRA checkpoint with DDP/FSDP compatibility.""" | |
| state_dict = load_file(checkpoint_path) | |
| # Adjust layer names to match internal format. | |
| # (Weights are saved in ComfyUI-compatible format, with "diffusion_model." prefix) | |
| state_dict = {k.replace("diffusion_model.", "", 1): v for k, v in state_dict.items()} | |
| # Load LoRA weights and verify all weights were loaded | |
| base_model = self._transformer.get_base_model() | |
| set_peft_model_state_dict(base_model, state_dict) | |
| logger.info("✅ LoRA checkpoint loaded successfully") | |
    def _prepare_models_for_training(self) -> None:
        """Prepare models for training with Accelerate.

        Handles FSDP-specific dtype casting, gradient checkpointing, CPU
        offload of frozen components, and wrapping the transformer via
        ``accelerator.prepare()``.
        """
        # For FSDP + LoRA: Cast entire model to FP32.
        # FSDP requires uniform dtype across all parameters in wrapped modules.
        # In LoRA mode, PEFT creates LoRA params in FP32 while base model is BF16.
        # We cast the base model to FP32 to match the LoRA params.
        if self._accelerator.distributed_type == DistributedType.FSDP and self._config.model.training_mode == "lora":
            logger.debug("FSDP: casting transformer to FP32 for uniform dtype")
            self._transformer = self._transformer.to(dtype=torch.float32)

        # Enable gradient checkpointing if requested
        # For PeftModel, we need to access the underlying base model
        transformer = (
            self._transformer.get_base_model() if hasattr(self._transformer, "get_base_model") else self._transformer
        )
        transformer.set_gradient_checkpointing(self._config.optimization.enable_gradient_checkpointing)

        # Keep frozen models on CPU for memory efficiency
        self._vae_decoder = self._vae_decoder.to("cpu")
        if self._vae_encoder is not None:
            self._vae_encoder = self._vae_encoder.to("cpu")

        # Embedding connectors are already on GPU from _load_text_encoder_and_cache_embeddings

        # Wrap the transformer for distributed training (DDP/FSDP) / compile
        # noinspection PyTypeChecker
        self._transformer = self._accelerator.prepare(self._transformer)

        # Log GPU memory usage after model preparation
        vram_usage_gb = torch.cuda.memory_allocated() / 1024**3
        logger.debug(f"GPU memory usage after models preparation: {vram_usage_gb:.2f} GB")
| def _find_checkpoint(checkpoint_path: str | Path) -> Path | None: | |
| """Find the checkpoint file to load, handling both file and directory paths.""" | |
| checkpoint_path = Path(checkpoint_path) | |
| if checkpoint_path.is_file(): | |
| if not checkpoint_path.suffix == ".safetensors": | |
| raise ValueError(f"Checkpoint file must have a .safetensors extension: {checkpoint_path}") | |
| return checkpoint_path | |
| if checkpoint_path.is_dir(): | |
| # Look for checkpoint files in the directory | |
| checkpoints = list(checkpoint_path.rglob("*step_*.safetensors")) | |
| if not checkpoints: | |
| return None | |
| # Sort by step number and return the latest | |
| def _get_step_num(p: Path) -> int: | |
| try: | |
| return int(p.stem.split("step_")[1]) | |
| except (IndexError, ValueError): | |
| return -1 | |
| latest = max(checkpoints, key=_get_step_num) | |
| return latest | |
| else: | |
| raise ValueError(f"Invalid checkpoint path: {checkpoint_path}. Must be a file or directory.") | |
    def _init_dataloader(self) -> None:
        """Initialize the training data loader using the strategy's data sources."""
        if self._dataset is None:
            # Get data sources from the training strategy
            data_sources = self._training_strategy.get_data_sources()
            self._dataset = PrecomputedDataset(self._config.data.preprocessed_data_root, data_sources=data_sources)
            logger.debug(f"Loaded dataset with {len(self._dataset):,} samples from sources: {list(data_sources)}")

        num_workers = self._config.data.num_dataloader_workers
        dataloader = DataLoader(
            self._dataset,
            batch_size=self._config.optimization.batch_size,
            shuffle=True,
            drop_last=True,  # Keep batch sizes uniform across steps/processes
            num_workers=num_workers,
            # NOTE(review): pin_memory is tied to num_workers here — presumably
            # intentional (skip pinning in the single-process data path); confirm.
            pin_memory=num_workers > 0,
            persistent_workers=num_workers > 0,
        )
        # Wrap with Accelerate so batches are sharded/dispatched per process
        self._dataloader = self._accelerator.prepare(dataloader)
| def _init_lora_weights(self) -> None: | |
| """Initialize LoRA weights for the transformer.""" | |
| logger.debug("Initializing LoRA weights...") | |
| for _, module in self._transformer.named_modules(): | |
| if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)): | |
| module.reset_lora_parameters(adapter_name="default", init_lora_weights=True) | |
| def _init_optimizer(self) -> None: | |
| """Initialize the optimizer and learning rate scheduler.""" | |
| opt_cfg = self._config.optimization | |
| lr = opt_cfg.learning_rate | |
| if opt_cfg.optimizer_type == "adamw": | |
| optimizer = AdamW(self._trainable_params, lr=lr) | |
| elif opt_cfg.optimizer_type == "adamw8bit": | |
| # noinspection PyUnresolvedReferences | |
| from bitsandbytes.optim import AdamW8bit # noqa: PLC0415 | |
| optimizer = AdamW8bit(self._trainable_params, lr=lr) | |
| else: | |
| raise ValueError(f"Unknown optimizer type: {opt_cfg.optimizer_type}") | |
| # Add scheduler initialization | |
| lr_scheduler = self._create_scheduler(optimizer) | |
| # noinspection PyTypeChecker | |
| self._optimizer, self._lr_scheduler = self._accelerator.prepare(optimizer, lr_scheduler) | |
| def _create_scheduler(self, optimizer: torch.optim.Optimizer) -> LRScheduler | None: | |
| """Create learning rate scheduler based on config.""" | |
| scheduler_type = self._config.optimization.scheduler_type | |
| steps = self._config.optimization.steps | |
| params = self._config.optimization.scheduler_params or {} | |
| if scheduler_type is None: | |
| return None | |
| if scheduler_type == "linear": | |
| scheduler = LinearLR( | |
| optimizer, | |
| start_factor=params.pop("start_factor", 1.0), | |
| end_factor=params.pop("end_factor", 0.1), | |
| total_iters=steps, | |
| **params, | |
| ) | |
| elif scheduler_type == "cosine": | |
| scheduler = CosineAnnealingLR( | |
| optimizer, | |
| T_max=steps, | |
| eta_min=params.pop("eta_min", 0), | |
| **params, | |
| ) | |
| elif scheduler_type == "cosine_with_restarts": | |
| scheduler = CosineAnnealingWarmRestarts( | |
| optimizer, | |
| T_0=params.pop("T_0", steps // 4), # First restart cycle length | |
| T_mult=params.pop("T_mult", 1), # Multiplicative factor for cycle lengths | |
| eta_min=params.pop("eta_min", 5e-5), | |
| **params, | |
| ) | |
| elif scheduler_type == "polynomial": | |
| scheduler = PolynomialLR( | |
| optimizer, | |
| total_iters=steps, | |
| power=params.pop("power", 1.0), | |
| **params, | |
| ) | |
| elif scheduler_type == "step": | |
| scheduler = StepLR( | |
| optimizer, | |
| step_size=params.pop("step_size", steps // 2), | |
| gamma=params.pop("gamma", 0.1), | |
| **params, | |
| ) | |
| elif scheduler_type == "constant": | |
| scheduler = None | |
| else: | |
| raise ValueError(f"Unknown scheduler type: {scheduler_type}") | |
| return scheduler | |
    def _setup_accelerator(self) -> None:
        """Initialize the Accelerator with the appropriate settings."""
        # All distributed setup (DDP/FSDP, number of processes, etc.) is controlled by
        # the user's Accelerate configuration (accelerate config / accelerate launch).
        self._accelerator = Accelerator(
            mixed_precision=self._config.acceleration.mixed_precision_mode,
            gradient_accumulation_steps=self._config.optimization.gradient_accumulation_steps,
        )
        if self._accelerator.num_processes > 1:
            logger.info(
                f"{self._accelerator.distributed_type.value} distributed training enabled "
                f"with {self._accelerator.num_processes} processes"
            )
            local_batch = self._config.optimization.batch_size
            global_batch = self._config.optimization.batch_size * self._accelerator.num_processes
            logger.info(f"Local batch size: {local_batch}, global batch size: {global_batch}")

        # Log torch.compile status from Accelerate's dynamo plugin
        is_compile_enabled = (
            hasattr(self._accelerator.state, "dynamo_plugin") and self._accelerator.state.dynamo_plugin.backend != "NO"
        )
        if is_compile_enabled:
            plugin = self._accelerator.state.dynamo_plugin
            logger.info(f"🔥 torch.compile enabled via Accelerate: backend={plugin.backend}, mode={plugin.mode}")
            if self._accelerator.distributed_type == DistributedType.FSDP:
                logger.warning(
                    "⚠️ FSDP + torch.compile is experimental and may hang on the first training iteration. "
                    "If this occurs, disable torch.compile by removing dynamo_config from your Accelerate config."
                )

        if self._accelerator.distributed_type == DistributedType.FSDP and self._config.acceleration.quantization:
            # NOTE(review): the two concatenated fragments below are missing a
            # separating space ("...issues.Monitor...") in the logged message.
            logger.warning(
                f"FSDP with quantization ({self._config.acceleration.quantization}) may have compatibility issues."
                "Monitor training stability and consider disabling quantization if issues arise."
            )
| # Note: Use @torch.no_grad() instead of @torch.inference_mode() to avoid FSDP inplace update errors after validation | |
| def _sample_videos(self, progress: TrainingProgress) -> list[Path] | None: | |
| """Run validation by generating videos from validation prompts.""" | |
| use_images = self._config.validation.images is not None | |
| use_reference_videos = self._config.validation.reference_videos is not None | |
| generate_audio = self._config.validation.generate_audio | |
| inference_steps = self._config.validation.inference_steps | |
| # Free up GPU memory before validation sampling. | |
| # Zero gradients and empty the cache to reclaim memory. | |
| self._optimizer.zero_grad(set_to_none=True) | |
| torch.cuda.empty_cache() | |
| # Start sampling progress tracking | |
| sampling_ctx = progress.start_sampling( | |
| num_prompts=len(self._config.validation.prompts), | |
| num_steps=inference_steps, | |
| ) | |
| # Create validation sampler with loaded models and progress tracking | |
| sampler = ValidationSampler( | |
| transformer=self._transformer, | |
| vae_decoder=self._vae_decoder, | |
| vae_encoder=self._vae_encoder, | |
| text_encoder=None, | |
| audio_decoder=self._audio_vae if generate_audio else None, | |
| vocoder=self._vocoder if generate_audio else None, | |
| sampling_context=sampling_ctx, | |
| ) | |
| output_dir = Path(self._config.output_dir) / "samples" | |
| output_dir.mkdir(exist_ok=True, parents=True) | |
| video_paths = [] | |
| width, height, num_frames = self._config.validation.video_dims | |
| for prompt_idx, prompt in enumerate(self._config.validation.prompts): | |
| # Update progress to show current video | |
| sampling_ctx.start_video(prompt_idx) | |
| # Load conditioning image if provided | |
| condition_image = None | |
| if use_images: | |
| image_path = self._config.validation.images[prompt_idx] | |
| image = open_image_as_srgb(image_path) | |
| # Convert PIL image to tensor [C, H, W] in [0, 1] | |
| condition_image = F.to_tensor(image) | |
| # Load reference video if provided (for IC-LoRA) | |
| reference_video = None | |
| if use_reference_videos: | |
| ref_video_path = self._config.validation.reference_videos[prompt_idx] | |
| # read_video returns [F, C, H, W] in [0, 1] | |
| reference_video, _ = read_video(ref_video_path, max_frames=num_frames) | |
| # Get cached embeddings for this prompt if available | |
| cached_embeddings = ( | |
| self._cached_validation_embeddings[prompt_idx] | |
| if self._cached_validation_embeddings is not None | |
| else None | |
| ) | |
| # Create generation config | |
| gen_config = GenerationConfig( | |
| prompt=prompt, | |
| negative_prompt=self._config.validation.negative_prompt, | |
| height=height, | |
| width=width, | |
| num_frames=num_frames, | |
| frame_rate=self._config.validation.frame_rate, | |
| num_inference_steps=inference_steps, | |
| guidance_scale=self._config.validation.guidance_scale, | |
| seed=self._config.validation.seed, | |
| condition_image=condition_image, | |
| reference_video=reference_video, | |
| generate_audio=generate_audio, | |
| include_reference_in_output=self._config.validation.include_reference_in_output, | |
| cached_embeddings=cached_embeddings, | |
| stg_scale=self._config.validation.stg_scale, | |
| stg_blocks=self._config.validation.stg_blocks, | |
| stg_mode=self._config.validation.stg_mode, | |
| ) | |
| # Generate sample | |
| video, audio = sampler.generate( | |
| config=gen_config, | |
| device=self._accelerator.device, | |
| ) | |
| # Save video | |
| if IS_MAIN_PROCESS: | |
| video_path = output_dir / f"step_{self._global_step:06d}_{prompt_idx + 1}.mp4" | |
| save_video( | |
| video_tensor=video, | |
| output_path=video_path, | |
| fps=self._config.validation.frame_rate, | |
| audio=audio, | |
| audio_sample_rate=self._vocoder.output_sample_rate if audio is not None else None, | |
| ) | |
| video_paths.append(video_path) | |
| # Clean up progress tasks | |
| sampling_ctx.cleanup() | |
| # Clear GPU cache after validation | |
| torch.cuda.empty_cache() | |
| rel_outputs_path = output_dir.relative_to(self._config.output_dir) | |
| logger.info(f"🎥 Validation samples for step {self._global_step} saved in {rel_outputs_path}") | |
| return video_paths | |
def _log_training_stats(stats: TrainingStats) -> None:
    """Log a human-readable summary of the completed training run.

    NOTE(review): defined without ``self`` although it sits among instance
    methods — presumably a static/module-level helper; confirm call sites.
    """
    lines = [
        "📊 Training Statistics:",
        f" - Total time: {stats.total_time_seconds / 60:.1f} minutes",
        f" - Training speed: {stats.steps_per_second:.2f} steps/second",
        f" - Samples/second: {stats.samples_per_second:.2f}",
        f" - Peak GPU memory: {stats.peak_gpu_memory_gb:.2f} GB",
    ]
    # Distributed-only details are appended after the always-present lines.
    if stats.num_processes > 1:
        lines.append(f" - Number of processes: {stats.num_processes}")
        lines.append(f" - Global batch size: {stats.global_batch_size}")
    logger.info("\n".join(lines))
def _save_checkpoint(self) -> Path | None:
    """Save the model weights.

    Saves the transformer weights for the current global step to a
    safetensors file under ``<output_dir>/checkpoints``. All processes must
    enter this method — ``get_state_dict`` is a collective operation — but
    only the main process writes to disk.

    Returns:
        Path to the saved weights file on the main process; ``None`` on
        non-main processes.
    """
    is_lora = self._config.model.training_mode == "lora"
    is_fsdp = self._accelerator.distributed_type == DistributedType.FSDP
    # Prepare paths
    save_dir = Path(self._config.output_dir) / "checkpoints"
    prefix = "lora" if is_lora else "model"
    filename = f"{prefix}_weights_step_{self._global_step:05d}.safetensors"
    saved_weights_path = save_dir / filename
    # Get state dict (collective operation - all processes must participate)
    self._accelerator.wait_for_everyone()
    full_state_dict = self._accelerator.get_state_dict(self._transformer)
    # Non-main processes return only AFTER the collective call above.
    if not IS_MAIN_PROCESS:
        return None
    save_dir.mkdir(exist_ok=True, parents=True)
    # For LoRA: extract only adapter weights; for full: use as-is
    if is_lora:
        unwrapped = self._accelerator.unwrap_model(self._transformer, keep_torch_compile=False)
        # For FSDP, pass full_state_dict since model params aren't directly accessible
        state_dict = get_peft_model_state_dict(unwrapped, state_dict=full_state_dict if is_fsdp else None)
        # Remove "base_model.model." prefix added by PEFT
        state_dict = {k.replace("base_model.model.", "", 1): v for k, v in state_dict.items()}
        # Convert to ComfyUI-compatible format (add "diffusion_model." prefix)
        state_dict = {f"diffusion_model.{k}": v for k, v in state_dict.items()}
        # Save to disk
        save_file(state_dict, saved_weights_path)
    else:
        # Save to disk
        self._accelerator.save(full_state_dict, saved_weights_path)
    rel_path = saved_weights_path.relative_to(self._config.output_dir)
    logger.info(f"💾 {prefix.capitalize()} weights for step {self._global_step} saved in {rel_path}")
    # Keep track of checkpoint paths, and cleanup old checkpoints if needed
    self._checkpoint_paths.append(saved_weights_path)
    self._cleanup_checkpoints()
    return saved_weights_path
| def _cleanup_checkpoints(self) -> None: | |
| """Clean up old checkpoints.""" | |
| if 0 < self._config.checkpoints.keep_last_n < len(self._checkpoint_paths): | |
| checkpoints_to_remove = self._checkpoint_paths[: -self._config.checkpoints.keep_last_n] | |
| for old_checkpoint in checkpoints_to_remove: | |
| if old_checkpoint.exists(): | |
| old_checkpoint.unlink() | |
| logger.info(f"Removed old checkpoints: {old_checkpoint}") | |
| # Update the list to only contain kept checkpoints | |
| self._checkpoint_paths = self._checkpoint_paths[-self._config.checkpoints.keep_last_n :] | |
def _save_config(self) -> None:
    """Save the training configuration as a YAML file in the output directory.

    Only the main process writes the file; other ranks return immediately.
    The config is serialized from the pydantic model via ``model_dump()``.
    """
    if not IS_MAIN_PROCESS:
        return
    config_path = Path(self._config.output_dir) / "training_config.yaml"
    # Explicit UTF-8 so the dump is portable regardless of the platform's
    # default locale encoding (config values may contain non-ASCII text).
    with open(config_path, "w", encoding="utf-8") as f:
        yaml.dump(self._config.model_dump(), f, default_flow_style=False, indent=2)
    logger.info(f"💾 Training configuration saved to: {config_path.relative_to(self._config.output_dir)}")
| def _init_wandb(self) -> None: | |
| """Initialize Weights & Biases run.""" | |
| if not self._config.wandb.enabled or not IS_MAIN_PROCESS: | |
| self._wandb_run = None | |
| return | |
| wandb_config = self._config.wandb | |
| run = wandb.init( | |
| project=wandb_config.project, | |
| entity=wandb_config.entity, | |
| name=Path(self._config.output_dir).name, | |
| tags=wandb_config.tags, | |
| config=self._config.model_dump(), | |
| ) | |
| self._wandb_run = run | |
| def _log_metrics(self, metrics: dict[str, float]) -> None: | |
| """Log metrics to Weights & Biases.""" | |
| if self._wandb_run is not None: | |
| self._wandb_run.log(metrics) | |
| def _log_validation_videos(self, video_paths: list[Path], prompts: list[str]) -> None: | |
| """Log validation videos to Weights & Biases.""" | |
| if not self._config.wandb.log_validation_videos or self._wandb_run is None: | |
| return | |
| # Create lists of videos with their captions | |
| validation_videos = [ | |
| wandb.Video(str(video_path), caption=prompt, format="mp4") | |
| for video_path, prompt in zip(video_paths, prompts, strict=False) | |
| ] | |
| # Log all videos at once | |
| self._wandb_run.log( | |
| { | |
| "validation_videos": validation_videos, | |
| }, | |
| step=self._global_step, | |
| ) | |