#!/usr/bin/env python3
"""
Audio-Only IC-LoRA Training for Voice Cloning on LTX-2.3.

Uses the IC-LoRA pattern: reference audio tokens are APPENDED to the end of
the target sequence using AudioConditionByReferenceLatent. Loss is computed
only on target tokens; reference tokens remain clean (denoise_mask=0).
This follows the official video-to-video IC-LoRA strategy closely, but adapted
for the audio-only modality path.

Usage (single GPU):
    CUDA_VISIBLE_DEVICES=0 python train_audio_iclora.py --data-dir ... --speaker-index ...

Usage (multi-GPU with accelerate):
    CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch --num_processes=4 train_audio_iclora.py ...
"""
import argparse
import logging
import math
import os
import random
import shutil
import sys
import time
from collections import defaultdict
from pathlib import Path

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

REPO_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, os.path.join(REPO_DIR, "ltx2"))
# ltx-pipelines already on path via ltx2/
MODEL_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Import audio conditioning item from our module
sys.path.insert(0, MODEL_DIR)
from audio_conditioning import AudioConditionByReferenceLatent
# ─── Timestep Sampling ───
class DistilledTimestepSampler:
    """Sample timesteps from the distilled sigma schedule.

    The distilled model was trained to denoise at these specific sigma values.
    We sample uniformly from the intervals between consecutive sigmas,
    matching the distribution the model actually operates on.
    """

    # Distilled 8-step sigma values (boundaries of denoising intervals)
    SIGMAS = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]

    def __init__(self, jitter: float = 0.02):
        self.jitter = jitter

    def sample(self, batch_size: int, seq_length: int = None, device: torch.device = None) -> torch.Tensor:
        n_intervals = len(self.SIGMAS) - 1
        interval_idx = torch.randint(0, n_intervals, (batch_size,), device=device)
        t = torch.rand(batch_size, device=device)
        sigma_high = torch.tensor([self.SIGMAS[i] for i in interval_idx], device=device)
        sigma_low = torch.tensor([self.SIGMAS[i + 1] for i in interval_idx], device=device)
        sigma = sigma_low + t * (sigma_high - sigma_low)
        return sigma.clamp(0.01, 0.99)
class ShiftedLogitNormalTimestepSampler:
    """Shifted logit-normal distribution; the shift depends on sequence length."""

    def __init__(self, std: float = 1.0, eps: float = 1e-3, uniform_prob: float = 0.1):
        self.std = std
        self.eps = eps
        self.uniform_prob = uniform_prob
        self.normal_999_percentile = 3.0902 * std
        self.normal_005_percentile = -2.5758 * std

    def sample(self, batch_size: int, seq_length: int, device: torch.device = None) -> torch.Tensor:
        mu = self._get_shift(seq_length)
        normal = torch.randn(batch_size, device=device) * self.std + mu
        logitnormal = torch.sigmoid(normal)
        p999 = torch.sigmoid(torch.tensor(mu + self.normal_999_percentile, device=device))
        p005 = torch.sigmoid(torch.tensor(mu + self.normal_005_percentile, device=device))
        stretched = (logitnormal - p005) / (p999 - p005)
        stretched = torch.where(stretched >= self.eps, stretched, 2 * self.eps - stretched)
        stretched = stretched.clamp(0, 1)
        uniform = (1 - self.eps) * torch.rand(batch_size, device=device) + self.eps
        prob = torch.rand(batch_size, device=device)
        return torch.where(prob > self.uniform_prob, stretched, uniform)

    def _get_shift(self, seq_length, min_tok=1024, max_tok=4096, min_s=0.95, max_s=2.05):
        m = (max_s - min_s) / (max_tok - min_tok)
        return m * seq_length + (min_s - m * min_tok)
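
# With the default _get_shift bounds, the shift grows linearly from ~0.95 at 1024
# tokens to ~2.05 at 4096 tokens, i.e. longer sequences are pushed toward noisier
# (higher-sigma) timesteps on average.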
# ─── Dataset ───
def build_speaker_map(index_paths, data_dirs):
    """Map speaker → [(data_dir, sample_idx)] from index file(s).

    The sample index comes from field 0 of the `~`-delimited row when it
    parses as int (allows subset indexes that keep original sample numbers),
    otherwise we fall back to the row's line number (legacy behaviour for
    string-keyed indexes like tts_training_data_podcast).
    """
    speaker_to_samples = defaultdict(list)
    for index_path, data_dir in zip(index_paths, data_dirs):
        with open(index_path) as f:
            for line_num, line in enumerate(f):
                parts = line.strip().split("~")
                if len(parts) < 7:
                    continue
                try:
                    idx = int(parts[0])
                except ValueError:
                    idx = line_num
                speaker_id = parts[1]
                speaker_to_samples[speaker_id].append((data_dir, idx))
    return {k: v for k, v in speaker_to_samples.items() if len(v) >= 2}
class IDLoRADataset(Dataset):
    # Silence-latent reference loaded once, used to detect and strip any
    # leading silence frames baked into the preprocessed audio_latents. The
    # training loop ALREADY prepends 0-25 random silence frames, so we don't
    # want accidental silence in the source data compounding on top.
    _silence_ref = None

    @classmethod
    def _load_silence_ref(cls):
        if cls._silence_ref is None:
            p = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
                             "assets", "silence_latent_frame.pt")
            if os.path.exists(p):
                cls._silence_ref = torch.load(p, weights_only=True).float().squeeze()  # [C, F]
        return cls._silence_ref
    def __init__(self, speaker_map):
        self.samples = []
        self.speaker_map = {}
        for speaker, entries in speaker_map.items():
            valid = []
            for data_dir, idx in entries:
                audio_path = Path(data_dir) / "audio_latents" / f"sample_{idx:06d}.pt"
                cond_path = Path(data_dir) / "conditions" / f"sample_{idx:06d}.pt"
                if audio_path.exists() and cond_path.exists():
                    valid.append((data_dir, idx))
            if len(valid) >= 2:
                self.speaker_map[speaker] = valid
        for speaker, entries in self.speaker_map.items():
            for entry in entries:
                self.samples.append((entry, speaker))
        IDLoRADataset._load_silence_ref()

    def __len__(self):
        return len(self.samples)

    def _load_sample(self, data_dir, idx):
        base = Path(data_dir)
        audio = torch.load(base / "audio_latents" / f"sample_{idx:06d}.pt", weights_only=False)
        # Prefer prefix-stripped text embeddings if they exist (re-encoded with
        # just the quoted dialogue, dropping the "A woman says, " / "A man
        # speaks with X accent, " scene-description prefix).
        stripped = base / "conditions_stripped" / f"sample_{idx:06d}.pt"
        cond_path = stripped if stripped.exists() else base / "conditions" / f"sample_{idx:06d}.pt"
        cond = torch.load(cond_path, weights_only=False)
        if isinstance(audio, dict):
            audio = audio.get("audio_latent", audio.get("latent", list(audio.values())[0]))
        if audio.dim() == 2:
            audio = audio.unsqueeze(0)
        audio_feats = cond.get("audio_prompt_embeds", cond.get("prompt_embeds"))
        attn_mask = cond.get("prompt_attention_mask")
        # The audio_connector has num_learnable_registers=128 and asserts the
        # input sequence length is divisible by 128. Our new preprocessing
        # saved trimmed conditions (dropping left-padding to save disk), which
        # produces short/irregular sequence lengths. Left-pad back to the next
        # multiple of 128 with zeros (matching the tokenizer's left-padding
        # convention) so this assertion holds.
        REG = 128
        L = audio_feats.shape[0]
        target_L = ((L + REG - 1) // REG) * REG
        if target_L != L:
            pad_len = target_L - L
            pad_emb = torch.zeros(pad_len, audio_feats.shape[1],
                                  dtype=audio_feats.dtype)
            pad_mask = torch.zeros(pad_len, dtype=attn_mask.dtype)
            audio_feats = torch.cat([pad_emb, audio_feats], dim=0)
            attn_mask = torch.cat([pad_mask, attn_mask], dim=0)
        return audio, audio_feats, attn_mask
    def __getitem__(self, idx):
        (data_dir, tgt_idx), speaker = self.samples[idx]
        tgt_latent, audio_feats, attn_mask = self._load_sample(data_dir, tgt_idx)
        # Drop the reference entirely for non-voice-cloning categories:
        # - SFX samples (speaker starts with "sfx_"): descriptive sound events,
        #   no speaker identity to clone.
        # - Song/music samples (suno dataset): prompts describe the music style,
        #   reference audio doesn't transfer anything useful.
        # Return a zero-length ref so the model trains target-only for these.
        drop_ref = speaker.startswith("sfx_") or "preprocessed_ltx_suno" in str(data_dir)
        if drop_ref:
            C, F_dim = tgt_latent.shape[0], tgt_latent.shape[2]
            ref_latent = torch.zeros(C, 0, F_dim, dtype=tgt_latent.dtype)
        else:
            entries = self.speaker_map[speaker]
            ref_entry = random.choice([e for e in entries if e[1] != tgt_idx])
            ref_latent, _, _ = self._load_sample(*ref_entry)
        return {
            "tgt_latent": tgt_latent,
            "ref_latent": ref_latent,
            "audio_features": audio_feats,
            "attention_mask": attn_mask,
        }
# ─── Model building ───
def build_audio_only_model(checkpoint_path, device, dtype):
    from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder as Builder
    from ltx_core.loader.registry import DummyRegistry
    from ltx_core.loader.sd_ops import SDOps
    from ltx_core.model.transformer.model import LTXModel, LTXModelType
    from ltx_core.model.model_protocol import ModelConfigurator
    from ltx_core.model.transformer.attention import AttentionFunction
    from ltx_core.model.transformer.rope import LTXRopeType

    sd_ops = SDOps("AO").with_matching(prefix="model.diffusion_model.").with_replacement("model.diffusion_model.", "")

    class Cfg(ModelConfigurator[LTXModel]):
        @classmethod
        def from_config(cls, config):
            t = config.get("transformer", {})
            cp = None
            if not t.get("caption_proj_before_connector", False):
                from ltx_core.model.transformer.text_projection import create_caption_projection
                with torch.device("meta"):
                    cp = create_caption_projection(t, audio=True)
            return LTXModel(
                model_type=LTXModelType.AudioOnly,
                audio_num_attention_heads=t.get("audio_num_attention_heads", 32),
                audio_attention_head_dim=t.get("audio_attention_head_dim", 64),
                audio_in_channels=t.get("audio_in_channels", 128),
                audio_out_channels=t.get("audio_out_channels", 128),
                num_layers=t.get("num_layers", 48),
                audio_cross_attention_dim=t.get("audio_cross_attention_dim", 2048),
                norm_eps=t.get("norm_eps", 1e-6),
                attention_type=AttentionFunction(t.get("attention_type", "default")),
                positional_embedding_theta=t.get("positional_embedding_theta", 10000.0),
                audio_positional_embedding_max_pos=t.get("audio_positional_embedding_max_pos", [20]),
                timestep_scale_multiplier=t.get("timestep_scale_multiplier", 1000),
                use_middle_indices_grid=t.get("use_middle_indices_grid", True),
                rope_type=LTXRopeType(t.get("rope_type", "interleaved")),
                double_precision_rope=t.get("frequencies_precision", False) == "float64",
                apply_gated_attention=t.get("apply_gated_attention", False),
                audio_caption_projection=cp,
                cross_attention_adaln=t.get("cross_attention_adaln", False),
            )

    builder = Builder(model_path=checkpoint_path, model_class_configurator=Cfg,
                      model_sd_ops=sd_ops, registry=DummyRegistry())
    return builder.build(device=device, dtype=dtype)
def load_audio_connector(checkpoint_path, device, dtype):
    # ltx-trainer already on path via ltx2/
    from ltx_trainer.model_loader import load_embeddings_processor
    emb_proc = load_embeddings_processor(checkpoint_path, device=device, dtype=dtype)
    connector = emb_proc.audio_connector
    del emb_proc
    return connector


def apply_lora(model, rank, alpha, dropout=0.0):
    from peft import LoraConfig, get_peft_model
    config = LoraConfig(
        r=rank, lora_alpha=alpha, lora_dropout=dropout, bias="none",
        target_modules=[
            # Self-attention over audio tokens (voice-transfer pathway via ref).
            "audio_attn1.to_k", "audio_attn1.to_q", "audio_attn1.to_v", "audio_attn1.to_out.0",
            # Cross-attention (audio → text context) NOT adapted: keep the base
            # model's prompt→audio behaviour intact and rely on dataset balance
            # to drive expressiveness. (v15c tried this with adaLN unfreeze,
            # that proved too destructive; v16 tries it adaLN-frozen.)
            # FFN → non-linear capacity for style/phonetic adaptation.
            "audio_ff.net.0.proj", "audio_ff.net.2",
        ],
    )
    model = get_peft_model(model, config)
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    logging.info(f"LoRA: {trainable:,} trainable / {total:,} total ({100*trainable/total:.1f}%)")
    return model
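
# Note on apply_lora defaults: PEFT scales the LoRA update by lora_alpha / r, so
# with the default rank == alpha == 128 the adapter delta is added at full
# strength (scale 1.0).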
def prepare_audio_context(audio_connector, audio_features, attention_mask, device, dtype):
    from ltx_core.text_encoders.gemma.embeddings_processor import convert_to_additive_mask
    audio_features = audio_features.to(device=device, dtype=dtype)
    attention_mask = attention_mask.to(device=device)
    if audio_features.shape[0] > 1:
        results = []
        for i in range(audio_features.shape[0]):
            feat_i = audio_features[i:i + 1]
            mask_i = attention_mask[i:i + 1]
            additive = convert_to_additive_mask(mask_i, feat_i.dtype)
            enc_i, _ = audio_connector(feat_i, additive)
            results.append(enc_i)
        return torch.cat(results, dim=0)
    additive_mask = convert_to_additive_mask(attention_mask, audio_features.dtype)
    audio_encoded, _ = audio_connector(audio_features, additive_mask)
    return audio_encoded
# ─── Validation ───
def _unwrap_model_safe(model):
    """Strip DDP / peft wrappers without going through accelerate.unwrap_model,
    which imports deepspeed → broken in our env (torch API drift)."""
    while hasattr(model, "module"):
        model = model.module
    return model


def run_validation(lora_path, val_config_path, output_dir, step, lora_rank=128):
    """Call validate.py in a subprocess. It loads TTSServer (the same stack
    the warm server / Gradio app uses), attaches our LoRA, then iterates every
    entry in val_config with the same inference settings the user tests with.
    Single subprocess amortises the model-load cost across all val entries.

    Forces validation onto VAL_GPU (default "0") because training already
    occupies the rest. Override via TRAIN_VAL_GPU env var.
    """
    import subprocess
    val_dir = os.path.join(output_dir, "validation", f"step_{step:05d}")
    os.makedirs(val_dir, exist_ok=True)
    script = os.path.join(os.path.dirname(__file__), "validate.py")
    cmd = [
        sys.executable, script,
        "--val-config", val_config_path,
        "--output-dir", val_dir,
        "--lora", lora_path,
        "--lora-rank", str(lora_rank),
        # Use raw estimator output (no +10% buffer) so we can hear
        # whether the model needs more/less duration at current quality.
        "--duration-multiplier", "1.0",
    ]
    log_path = os.path.join(val_dir, "validate.log")
    env = os.environ.copy()
    # Validation needs its OWN GPU (training fills the others).
    env["CUDA_VISIBLE_DEVICES"] = os.environ.get("TRAIN_VAL_GPU", "0")
    try:
        with open(log_path, "w") as logf:
            result = subprocess.run(
                cmd, stdout=logf, stderr=subprocess.STDOUT, timeout=1800, env=env,
            )
        if result.returncode == 0:
            logging.info(f"  Validation step {step}: OK → {val_dir}")
        else:
            logging.warning(f"  Validation step {step} FAILED (see {log_path})")
    except subprocess.TimeoutExpired:
        logging.warning(f"  Validation step {step} TIMEOUT (>30min)")
# ─── Args ───
def parse_args():
    # First pass: pull out --config so its values can become argparse defaults.
    cfg_parser = argparse.ArgumentParser(add_help=False)
    cfg_parser.add_argument("--config", default=None,
                            help="YAML file with default values for any of the flags below. "
                                 "Explicit CLI flags still override the YAML.")
    cfg_args, remaining = cfg_parser.parse_known_args()
    yaml_defaults: dict = {}
    if cfg_args.config:
        import yaml as _yaml
        with open(cfg_args.config) as f:
            yaml_defaults = _yaml.safe_load(f) or {}
        # YAML keys are dashes-or-underscores → normalize to argparse dest (underscore).
        yaml_defaults = {k.replace("-", "_"): v for k, v in yaml_defaults.items()}

    def _yaml(name, fallback):
        return yaml_defaults.get(name, fallback)

    p = argparse.ArgumentParser(
        parents=[cfg_parser],
        description="Audio-Only IC-LoRA Training for Voice Cloning",
    )
    p.add_argument("--data-dir", required="data_dir" not in yaml_defaults,
                   nargs="+", default=_yaml("data_dir", None))
    p.add_argument("--speaker-index", required="speaker_index" not in yaml_defaults,
                   nargs="+", default=_yaml("speaker_index", None))
    p.add_argument("--output-dir", default=_yaml("output_dir", os.path.join(MODEL_DIR, "tts_iclora_v1")))
    p.add_argument("--checkpoint", default=_yaml("checkpoint", os.path.join(MODEL_DIR, "dramabox-dit-v1.safetensors")))
    p.add_argument("--full-checkpoint", default=_yaml("full_checkpoint", os.path.join(MODEL_DIR, "dramabox-audio-components.safetensors")))
    p.add_argument("--base-model", choices=["distilled", "dev"], default=_yaml("base_model", "dev"),
                   help="Base model type: distilled uses DistilledTimestepSampler, dev uses ShiftedLogitNormal")
    p.add_argument("--lora-rank", type=int, default=_yaml("lora_rank", 128))
    p.add_argument("--lora-alpha", type=int, default=_yaml("lora_alpha", 128))
    p.add_argument("--lora-dropout", type=float, default=_yaml("lora_dropout", 0.0),
                   help="Dropout applied to LoRA A/B matrices during training. "
                        "Recommended ~0.1 for small datasets to regularize.")
    p.add_argument("--resume-lora", default=_yaml("resume_lora", None))
    p.add_argument("--resume-step-offset", type=int, default=_yaml("resume_step_offset", None),
                   help="Step to add when naming saved checkpoints. If None, inferred "
                        "from --resume-lora filename (e.g. lora_step_10000.safetensors → 10000). "
                        "Set to 0 to start numbering at 0 regardless.")
    p.add_argument("--ref-ratio", type=float, default=_yaml("ref_ratio", 0.3),
                   help="Fraction of target length to use as reference (default 0.3)")
    p.add_argument("--max-ref-tokens", type=int, default=_yaml("max_ref_tokens", 200),
                   help="Maximum reference tokens after patchification (default 200)")
    p.add_argument("--text-dropout", type=float, default=_yaml("text_dropout", 0.0),
                   help="Probability of dropping text conditioning (forces reliance on voice ref)")
    p.add_argument("--steps", type=int, default=_yaml("steps", 30000))
    p.add_argument("--lr", type=float, default=_yaml("lr", 3e-5))
    p.add_argument("--lr-scheduler", choices=["cosine", "linear", "constant"], default=_yaml("lr_scheduler", "cosine"))
    p.add_argument("--batch-size", type=int, default=_yaml("batch_size", 1))
    p.add_argument("--grad-accum", type=int, default=_yaml("grad_accum", 4))
    p.add_argument("--max-grad-norm", type=float, default=_yaml("max_grad_norm", 1.0))
    p.add_argument("--save-every", type=int, default=_yaml("save_every", 1000))
    p.add_argument("--log-every", type=int, default=_yaml("log_every", 50))
    p.add_argument("--seed", type=int, default=_yaml("seed", 42))
    p.add_argument("--warmup-steps", type=int, default=_yaml("warmup_steps", 100))
    p.add_argument("--val-config", default=_yaml("val_config", None))
    return p.parse_args(remaining)
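
# Illustrative --config YAML (keys mirror the CLI flags above; dashes and
# underscores both work, and the paths below are placeholders):
#
#   data_dir:
#     - /path/to/preprocessed_dataset
#   speaker_index:
#     - /path/to/speaker_index.txt
#   output_dir: ./tts_iclora_v1
#   lora_rank: 128
#   lora_alpha: 128
#   lr: 3.0e-5
#   steps: 30000
#   val_config: /path/to/val_config.yaml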
# ─── Main ───
def main():
    from accelerate import Accelerator
    from accelerate.utils import set_seed

    args = parse_args()
    accelerator = Accelerator(
        gradient_accumulation_steps=args.grad_accum,
        mixed_precision="bf16",
    )
    is_main = accelerator.is_main_process
    if is_main:
        logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
    else:
        logging.basicConfig(level=logging.WARNING)
    set_seed(args.seed)
    device = accelerator.device
    dtype = torch.bfloat16
    os.makedirs(args.output_dir, exist_ok=True)

    # Save training args
    if is_main:
        import yaml
        args_dict = vars(args).copy()
        args_dict["_meta"] = {
            "world_size": accelerator.num_processes,
            "dtype": str(dtype),
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
            "script": "train_audio_iclora.py",
            "pattern": "IC-LoRA (ref appended to end)",
        }
        with open(os.path.join(args.output_dir, "training_args.yaml"), "w") as f:
            yaml.dump(args_dict, f, default_flow_style=False, sort_keys=False)
    from ltx_core.components.patchifiers import AudioPatchifier
    from ltx_core.model.transformer.modality import Modality
    from ltx_core.guidance.perturbations import BatchedPerturbationConfig
    from ltx_core.tools import AudioLatentTools
    from ltx_core.types import AudioLatentShape, LatentState
    from ltx_pipelines.utils.helpers import modality_from_latent_state, timesteps_from_mask

    # Build speaker map
    if is_main:
        logging.info("Building speaker map...")
    speaker_map = build_speaker_map(args.speaker_index, args.data_dir)
    if is_main:
        logging.info(f"Speaker map: {len(speaker_map)} speakers, "
                     f"{sum(len(v) for v in speaker_map.values())} samples")

    # Load model
    if is_main:
        logging.info("Loading audio-only model...")
    model = build_audio_only_model(args.checkpoint, device, dtype)
    if is_main:
        logging.info("Loading audio connector...")
    audio_connector = load_audio_connector(args.full_checkpoint, device, dtype)
    audio_connector.eval()
    for p in audio_connector.parameters():
        p.requires_grad = False

    if is_main:
        logging.info(f"Applying LoRA (rank={args.lora_rank}, alpha={args.lora_alpha})...")
    model = apply_lora(model, args.lora_rank, args.lora_alpha, args.lora_dropout)

    # Resume from checkpoint
    if args.resume_lora:
        from safetensors.torch import load_file as st_load
        if is_main:
            logging.info(f"Resuming from: {args.resume_lora}")
        lora_sd = st_load(args.resume_lora)
        mapped = {}
        for k, v in lora_sd.items():
            nk = k.replace(".lora_A.weight", ".lora_A.default.weight").replace(
                ".lora_B.weight", ".lora_B.default.weight")
            mapped[nk] = v
        model.load_state_dict(mapped, strict=False)

    # Determine step offset for save filenames. Without this, resuming a run
    # restarts step numbering at 0 and would overwrite earlier phase-1
    # checkpoints with the same save_every cadence.
    if args.resume_step_offset is None:
        resume_offset = 0
        if args.resume_lora:
            import re as _re
            m = _re.search(r"lora_step_(\d+)", os.path.basename(args.resume_lora))
            if m:
                resume_offset = int(m.group(1))
        args.resume_step_offset = resume_offset
    if is_main and args.resume_step_offset:
        logging.info(f"Save-step offset: +{args.resume_step_offset}")

    model.train()
    model.base_model.model.set_gradient_checkpointing(True)

    # Dataset & DataLoader
    dataset = IDLoRADataset(speaker_map)
    if is_main:
        logging.info(f"Dataset: {len(dataset)} samples, {len(dataset.speaker_map)} speakers")
    def collate_fn(batch):
        """Pad variable-length audio to max in batch, track real lengths for loss masking."""
        max_tgt_T = max(b["tgt_latent"].shape[1] for b in batch)  # [C, T, F]
        max_ref_T = max(b["ref_latent"].shape[1] for b in batch)
        C = batch[0]["tgt_latent"].shape[0]
        F_dim = batch[0]["tgt_latent"].shape[2]
        tgt_list, ref_list, feat_list, mask_list = [], [], [], []
        tgt_lengths, ref_lengths = [], []
        for b in batch:
            tgt = b["tgt_latent"]
            ref = b["ref_latent"]
            tgt_lengths.append(tgt.shape[1])
            ref_lengths.append(ref.shape[1])
            if tgt.shape[1] < max_tgt_T:
                pad = torch.zeros(C, max_tgt_T - tgt.shape[1], F_dim, dtype=tgt.dtype)
                tgt = torch.cat([tgt, pad], dim=1)
            tgt_list.append(tgt)
            if ref.shape[1] < max_ref_T:
                pad = torch.zeros(C, max_ref_T - ref.shape[1], F_dim, dtype=ref.dtype)
                ref = torch.cat([ref, pad], dim=1)
            ref_list.append(ref)
            feat_list.append(b["audio_features"])
            mask_list.append(b["attention_mask"])
        return {
            "tgt_latent": torch.stack(tgt_list),
            "ref_latent": torch.stack(ref_list),
            "audio_features": torch.stack(feat_list),
            "attention_mask": torch.stack(mask_list),
            "tgt_lengths": torch.tensor(tgt_lengths),
            "ref_lengths": torch.tensor(ref_lengths),
        }

    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=2,
                            pin_memory=True, drop_last=True, collate_fn=collate_fn)
    # Optimizer & Scheduler
    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=args.lr, betas=(0.9, 0.999), weight_decay=0.01,
    )
    from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR, ConstantLR
    warmup = LinearLR(optimizer, start_factor=0.01, end_factor=1.0, total_iters=args.warmup_steps)
    remaining = args.steps - args.warmup_steps
    if args.lr_scheduler == "cosine":
        # Warmup -> constant hold (20% of remaining) -> cosine decay
        hold_steps = max(remaining // 5, 0)
        decay_steps = max(remaining - hold_steps, 1)
        hold_sched = ConstantLR(optimizer, factor=1.0, total_iters=hold_steps)
        decay_sched = CosineAnnealingLR(optimizer, T_max=decay_steps, eta_min=1e-6)
        scheduler = SequentialLR(
            optimizer,
            [warmup, hold_sched, decay_sched],
            milestones=[args.warmup_steps, args.warmup_steps + hold_steps],
        )
    elif args.lr_scheduler == "linear":
        main_sched = LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=max(remaining, 1))
        scheduler = SequentialLR(optimizer, [warmup, main_sched], milestones=[args.warmup_steps])
    else:
        main_sched = ConstantLR(optimizer, factor=1.0, total_iters=max(remaining, 1))
        scheduler = SequentialLR(optimizer, [warmup, main_sched], milestones=[args.warmup_steps])
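    # With the default args (lr_scheduler=cosine, steps=30000, warmup_steps=100)
    # this works out to: 100 warmup steps, a 5,980-step constant hold, then a
    # 23,920-step cosine decay down to eta_min=1e-6.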
    # Prepare with Accelerate, but NOT the scheduler. AcceleratedScheduler
    # calls the underlying scheduler.step() `num_processes` times per sync,
    # which silently scales down our warmup/cosine spans by that factor.
    # We call scheduler.step() ourselves, gated on sync_gradients: exactly
    # one advance per optimizer step, as the yaml spec intends.
    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

    patchifier = AudioPatchifier(patch_size=1)

    # Select timestep sampler based on base model type
    if args.base_model == "distilled":
        timestep_sampler = DistilledTimestepSampler()
        if is_main:
            logging.info("Using DistilledTimestepSampler (matching distilled model sigmas)")
    else:
        timestep_sampler = ShiftedLogitNormalTimestepSampler()
        if is_main:
            logging.info("Using ShiftedLogitNormalTimestepSampler (dev model)")
    # Training loop
    if is_main:
        logging.info(f"Training: {args.steps} steps, lr={args.lr}, scheduler={args.lr_scheduler}, "
                     f"batch={args.batch_size}, grad_accum={args.grad_accum}, "
                     f"world_size={accelerator.num_processes}, "
                     f"ref_ratio={args.ref_ratio}, max_ref_tokens={args.max_ref_tokens}")
        logging.info("IC-LoRA pattern: ref tokens APPENDED to target, loss on target only")

    data_iter = iter(dataloader)
    step = 0
    accum_loss = 0.0
    best_loss = float("inf")
    best_step = 0
    t0 = time.time()
    total_micro_steps = args.steps * args.grad_accum

    for micro_step in range(total_micro_steps):
        try:
            batch = next(data_iter)
        except StopIteration:
            data_iter = iter(dataloader)
            batch = next(data_iter)

        is_opt_step = (micro_step + 1) % args.grad_accum == 0
        if is_opt_step:
            step += 1
        with accelerator.accumulate(model):
            tgt_latent = batch["tgt_latent"].to(dtype=dtype)       # [B, C, max_tgt_T, F]
            ref_latent = batch["ref_latent"].to(dtype=dtype)       # [B, C, max_ref_T, F]
            tgt_lengths = batch["tgt_lengths"].to(device=device)   # [B]
            B = tgt_latent.shape[0]

            # ── Random silence padding (0-1s) ── ltx_audio_tts baseline.
            # User observed reference-audio leak at the end of generations when this
            # was reduced to 5 (v14) or 10 frames (v16/v17): the model seemed
            # to use the extra target budget to regurgitate ref content. Full
            # 25 frames (0-1s, avg 500ms) was apparently load-bearing for
            # regularising the boundary and reducing hallucinations.
            # Uses the real silence latent (not zeros) so the VAE decodes it as
            # true silence instead of static noise.
            max_pad_frames = 25  # ~1s at 25 latent frames/sec
            pad_frames = random.randint(0, max_pad_frames)
            if pad_frames > 0:
                C, F_dim = tgt_latent.shape[1], tgt_latent.shape[3]
                if not hasattr(args, '_silence_frame') or args._silence_frame is None:
                    _sf_path = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "assets", "silence_latent_frame.pt")
                    if os.path.exists(_sf_path):
                        args._silence_frame = torch.load(_sf_path, weights_only=True)  # [C, 1, F]
                        if is_main:
                            logging.info(f"Loaded silence latent from {_sf_path}")
                    else:
                        args._silence_frame = False  # fallback to zeros
                        if is_main:
                            logging.warning("silence_latent_frame.pt not found, using zeros")
                if args._silence_frame is not False:
                    sf = args._silence_frame.to(dtype=dtype, device=device)        # [C, 1, F]
                    silence_pad = sf.unsqueeze(0).expand(B, -1, pad_frames, -1)    # [B, C, pad, F]
                else:
                    silence_pad = torch.zeros(B, C, pad_frames, F_dim, dtype=dtype, device=device)
                tgt_latent = torch.cat([silence_pad, tgt_latent], dim=2)

            # Cap reference to max_ref_tokens (in latent frames, before patchification).
            # After patchification, ref_T tokens == ref frames (patch_size=1).
            ref_T_frames = min(ref_latent.shape[2], args.max_ref_tokens)
            ref_latent = ref_latent[:, :, :ref_T_frames, :]
            tgt_T_frames = tgt_latent.shape[2]  # max (padded) target frames
            # ── Step 1: Create target AudioLatentShape and AudioLatentTools ──
            tgt_shape = AudioLatentShape(
                batch=B,
                channels=tgt_latent.shape[1],   # 8
                frames=tgt_T_frames,
                mel_bins=tgt_latent.shape[3],   # 16
            )
            audio_tools = AudioLatentTools(
                patchifier=patchifier,
                target_shape=tgt_shape,
            )

            # ── Step 2: Create initial state from target latent ──
            # create_initial_state patchifies: [B, C, T, F] -> [B, T, C*F].
            # It also creates denoise_mask=1 (all target tokens will be denoised)
            # and computes temporal positions.
            state = audio_tools.create_initial_state(
                device=device,
                dtype=dtype,
                initial_latent=tgt_latent,
            )
            # state.latent: [B, tgt_T, 128], state.denoise_mask: [B, tgt_T, 1]
            # state.positions: [B, 1, tgt_T, 2]
            tgt_T = audio_tools.target_shape.token_count()  # == tgt_T_frames

            # ── Step 3: Apply flow-matching noise to target BEFORE appending ref ──
            # Sample sigma
            total_tokens = tgt_T + ref_T_frames
            sigma = timestep_sampler.sample(B, total_tokens, device=device)
            sigma_exp = sigma.view(-1, 1, 1)  # [B, 1, 1]
            noise = torch.randn_like(state.latent)  # [B, tgt_T, 128]
            noisy_tgt = (1 - sigma_exp) * state.latent + sigma_exp * noise
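            # Rectified-flow interpolation: x_sigma = (1 - sigma) * x0 + sigma * noise,
            # so the velocity d x_sigma / d sigma = noise - x0, which is exactly the
            # regression target built in Step 9 below.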
            # Replace the latent in state with the noisy version
            # (clean_latent stays clean for the post_process_latent pattern).
            state = LatentState(
                latent=noisy_tgt,
                denoise_mask=state.denoise_mask,
                positions=state.positions,
                clean_latent=state.clean_latent,
                attention_mask=state.attention_mask,
            )

            # ── Step 4: Append reference tokens using AudioConditionByReferenceLatent ──
            # This appends ref tokens to the END with denoise_mask=0 (frozen/clean).
            # Skip entirely when ref_T=0 (SFX / song samples): the model trains
            # target-only for those categories since there's no voice to clone.
            if ref_T_frames > 0:
                ref_conditioning = AudioConditionByReferenceLatent(
                    latent=ref_latent,
                    strength=1.0,  # 1.0 = ref fully clean (denoise_mask=0)
                )
                state = ref_conditioning.apply_to(
                    latent_state=state,
                    latent_tools=audio_tools,
                )
                # state.latent: [B, tgt_T + ref_T, 128]
                # state.denoise_mask: [B, tgt_T + ref_T, 1]
                #   target tokens: 1.0 (denoise), ref tokens: 0.0 (frozen)
                # state.positions: [B, 1, tgt_T + ref_T, 2]

            # ── Step 5: Build loss mask for target tokens (excluding padding) ──
            # loss_mask: 1 for real target tokens, 0 for padding and ref tokens
            loss_mask = torch.zeros(B, tgt_T, device=device)
            for b_idx in range(B):
                real_len = min(tgt_lengths[b_idx].item(), tgt_T)
                loss_mask[b_idx, :real_len] = 1.0
            # ── Step 6: Prepare text context ──
            # Text conditioning dropout: randomly zero out the text context to force
            # the model to rely on the voice reference for identity/style.
            with torch.no_grad():
                audio_context = prepare_audio_context(
                    audio_connector, batch["audio_features"],
                    batch["attention_mask"], device, dtype)
            if args.text_dropout > 0 and random.random() < args.text_dropout:
                audio_context = torch.zeros_like(audio_context)

            # ── Step 7: Build Modality using modality_from_latent_state ──
            # timesteps = sigma * denoise_mask (ref gets 0, target gets sigma)
            audio_mod = modality_from_latent_state(
                state=state,
                context=audio_context,
                sigma=sigma,
                enabled=True,
            )

            # ── Step 8: Forward pass ──
            perturbations = BatchedPerturbationConfig.empty(B)
            with torch.autocast(device_type="cuda", dtype=dtype):
                _, velocity_pred = model(video=None, audio=audio_mod, perturbations=perturbations)

            # ── Step 9: Compute loss (IC-LoRA pattern) ──
            # Target is at the FRONT (indices 0..tgt_T), ref at the END.
            # Velocity target = noise - clean.
            tgt_patchified = audio_tools.patchifier.patchify(tgt_latent)  # [B, tgt_T, 128]
            target_velocity = noise - tgt_patchified
            # Extract target portion of prediction
            pred_tgt = velocity_pred[:, :tgt_T]  # [B, tgt_T, 128]
            # MSE loss with mask: only on real target tokens (not padding or ref)
            per_token_mse = (pred_tgt - target_velocity).pow(2).mean(dim=-1)  # [B, tgt_T]
            loss = per_token_mse.mul(loss_mask).div(loss_mask.mean().clamp(min=1e-6)).mean()
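            # Note: mean(per_token_mse * loss_mask) / mean(loss_mask) is algebraically
            # sum(per_token_mse * loss_mask) / sum(loss_mask), i.e. the MSE averaged
            # over the real target tokens only.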
            accelerator.backward(loss)
            if accelerator.sync_gradients and args.max_grad_norm > 0:
                accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()
            # Only advance the LR scheduler once per OPTIMIZER step (not per
            # micro-step). Mirrors AcceleratedOptimizer.step(), which is
            # internally gated on sync_gradients.
            if accelerator.sync_gradients:
                scheduler.step()
            accum_loss += loss.item()
        # Logging & saving on optimization steps only
        if is_opt_step and step % args.log_every == 0 and is_main:
            avg_loss = accum_loss / (args.log_every * args.grad_accum)
            lr = optimizer.param_groups[0]["lr"]
            elapsed = time.time() - t0
            sps = step / elapsed if elapsed > 0 else 0
            eta = (args.steps - step) / sps if sps > 0 else 0
            logging.info(
                f"Step {step}/{args.steps} | loss={avg_loss:.4f} | lr={lr:.2e} | "
                f"tgt_T={tgt_T} ref_T={ref_T_frames} total={tgt_T + ref_T_frames} | "
                f"{sps:.1f} steps/s | ETA {eta/60:.0f}min"
            )
            # Save best whenever loss improves: no warmup gate, so we can
            # observe best checkpoints during warmup too.
            if avg_loss < best_loss:
                best_loss = avg_loss
                old_best = os.path.join(args.output_dir, f"best_step_{best_step:05d}.safetensors")
                best_step = step + args.resume_step_offset
                new_best = os.path.join(args.output_dir, f"best_step_{best_step:05d}.safetensors")
                unwrapped = _unwrap_model_safe(model)
                unwrapped.save_pretrained(args.output_dir)
                adapter = os.path.join(args.output_dir, "adapter_model.safetensors")
                if os.path.exists(adapter):
                    shutil.copy(adapter, new_best)
                if old_best != new_best and os.path.exists(old_best):
                    os.remove(old_best)
                logging.info(f"New best: loss={best_loss:.4f} at step {best_step}")
            accum_loss = 0.0

        if is_opt_step and step % args.save_every == 0 and is_main:
            global_step = step + args.resume_step_offset
            save_path = os.path.join(args.output_dir, f"lora_step_{global_step:05d}.safetensors")
            logging.info(f"Saving: {save_path}")
            unwrapped = _unwrap_model_safe(model)
            unwrapped.save_pretrained(args.output_dir)
            adapter = os.path.join(args.output_dir, "adapter_model.safetensors")
            if os.path.exists(adapter):
                shutil.copy(adapter, save_path)
            if args.val_config:
                logging.info(f"Running validation at step {global_step}...")
                model.eval()
                run_validation(save_path, args.val_config, args.output_dir, global_step,
                               lora_rank=args.lora_rank)
                model.train()
    # Final save
    if is_main:
        unwrapped = _unwrap_model_safe(model)
        unwrapped.save_pretrained(args.output_dir)
        adapter = os.path.join(args.output_dir, "adapter_model.safetensors")
        global_step = step + args.resume_step_offset
        save_path = os.path.join(args.output_dir, f"lora_step_{global_step:05d}.safetensors")
        if os.path.exists(adapter):
            shutil.copy(adapter, save_path)
        logging.info(f"Training complete! {step} steps in {time.time()-t0:.0f}s")
        logging.info(f"Best loss: {best_loss:.4f} at step {best_step}")


if __name__ == "__main__":
    main()