#!/usr/bin/env python3
"""
LTX-2.3 TTS with IC-LoRA voice cloning.

Uses AudioConditionByReferenceLatent to append reference audio tokens to the
end of the target sequence. Auto-detects distilled vs dev checkpoint and
selects the appropriate denoiser (SimpleDenoiser / GuidedDenoiser) and sigma
schedule. Leverages the official euler_denoising_loop (or heun_denoising_loop
with --sampler heun), AudioLatentTools, GaussianNoiser, and X0Model wrapper
throughout.

Usage (distilled):
    python tts_iclora.py \
        --voice-sample reference.wav \
        --prompt "A woman speaks clearly: The weather today will be sunny." \
        --output tts_output.wav

Usage (dev):
    python tts_iclora.py \
        --voice-sample reference.wav \
        --prompt "A woman speaks clearly: The weather today will be sunny." \
        --checkpoint ltx-2.3-22b-dev-audio-only.safetensors \
        --full-checkpoint ltx-2.3-22b-dev.safetensors \
        --output tts_output.wav
"""

import argparse
import json
import logging
import os
import re
import struct
import sys
import time
from pathlib import Path

import torch
import torchaudio

# NOTE(review): os.path.join() with a single argument is a no-op wrapper, and
# REPO_DIR is not referenced anywhere in the visible code -- confirm it is
# still needed.
REPO_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Make the sibling "ltx2" checkout importable (ltx_core / ltx_pipelines).
sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "ltx2"))
# ltx-pipelines already on path via ltx2/
# Also add the local directory so audio_conditioning.py is importable
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

# Default location of the transformer / VAE checkpoints: "<parent dir>/models".
MODEL_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "models")
# NOTE(review): this default names an *unquantized* Gemma checkpoint, while
# --bnb-4bit is on by default and its help text says it is required for the
# pre-quantized unsloth/gemma-3-12b-it-bnb-4bit weights -- confirm the two
# defaults are meant to be used together.
GEMMA_DIR = os.environ.get("GEMMA_DIR", "gemma-3-12b-it-qat-q4_0-unquantized")

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def detect_model_type(checkpoint_path: str) -> str:
    """Detect if checkpoint is distilled or dev by checking filename and metadata.

    Returns "distilled" if the lower-cased path contains "distilled", else
    "dev" if it contains "dev"; otherwise falls back to the safetensors
    header (see the corrupted span below).
    """
    # NOTE(review): the substring checks run on the *full* path, so a directory
    # name such as "/home/devuser/..." decides the result. Checking
    # os.path.basename(checkpoint_path) would be more robust.
    path_lower = checkpoint_path.lower()
    if "distilled" in path_lower:
        return "distilled"
    if "dev" in path_lower:
        return "dev"
    # Fallback: try to read safetensors metadata
    try:
        with open(checkpoint_path, "rb") as f:
            # NOTE(review): SOURCE is corrupted from here. Everything between
            # `struct.unpack("` and ` float:` was lost -- apparently a "<...>"
            # span (e.g. a "<Q" format string through the "->" of a return
            # annotation) was stripped as markup. The missing text contains at
            # least: the rest of this metadata fallback (and presumably this
            # function's final return), the module-level ``_LAUGH_VERBS``
            # mapping (regex -> base seconds) iterated below, and the head of
            # ``def _contextual_laugh_duration(text: str) -> float:`` (name and
            # parameter known from its call site in _estimate_nonverbal_duration).
            # Restore this span from version control; it is not reconstructed here.
            header_size = struct.unpack(" float:
    """Context-aware laugh budget, in seconds.

    For each match of a ``_LAUGH_VERBS`` pattern (regex -> base seconds) in
    the prompt, look at the text immediately *after* the verb (the next 40
    characters, optionally skipping one "-ly" adverb) and scale the base
    duration:
      - short modifiers (briefly, shortly, once, quickly)     -> 0.4x base
      - long modifiers  (maniacally, heartily, ..., loudly,
        long, or "between phrases")                           -> 1.2x base
      - default (no mod / neutral)                            -> 1.0x base

    Also reward phonetic repetition inside quotes -- 'Hahahahahaha' buys more
    time than 'Haha' -- at ~0.2s per "ha"/"he" syllable beyond the second.
    Only runs of three or more syllables are considered.
    """
    # "softly" / "quietly" describe volume not length, so keep at default 1.0x.
    short_mod = re.compile(
        r"^\s*(?:[a-z]+ly )?(?:briefly|shortly|once|quickly)", re.IGNORECASE)
    # NOTE(review): neither modifier regex ends in \b, so prefixes also match
    # (e.g. "laughs longingly" is treated as the long modifier "long").
    long_mod = re.compile(
        r"^\s*(?:[a-z]+ly )?(?:maniacally|heartily|uproariously|uncontrollably|"
        r"hysterically|darkly|wickedly|evilly|loudly|long)"
        r"|^\s*between phrases", re.IGNORECASE)
    total = 0.0
    for pat, base_dur in _LAUGH_VERBS.items():
        for m in re.finditer(pat, text, re.IGNORECASE):
            # Window of text following the verb, where the modifier is expected.
            ctx = text[m.end(): m.end() + 40]
            if short_mod.match(ctx):
                total += base_dur * 0.4
            elif long_mod.match(ctx):
                total += base_dur * 1.2
            else:
                total += base_dur
    # Phonetic laugh repetition inside quotes (runs need >= 3 syllables to match):
    #   'Haha'                   = 2 syllables (not matched, no bonus)
    #   'Hahahaha'               = 4 syllables (+0.4s)
    #   'Hehehehahahahahahahaha' = 11 syllables (+1.8s)
    # Both double- and single-quoted spans are scanned; the single-quote regex
    # keeps contraction apostrophes (don't, it's) inside the quote.
    for q in re.findall(r'"([^"]+)"', text) + re.findall(r"'((?:[^']|'(?![\s.,!?)\]]))+)'", text):
        for run in re.findall(r"(?:h[ae]){3,}|(?:h[ae][ \-]?){3,}", q, re.IGNORECASE):
            syls = len(re.findall(r"h[ae]", run, re.IGNORECASE))
            total += 0.2 * max(syls - 2, 0)
    return total


def _estimate_nonverbal_duration(text: str) -> float:
    """Estimate extra duration (seconds) for non-verbal sounds and actions in the prompt.

    Each entry in PATTERNS adds its flat duration once per case-insensitive
    match anywhere in ``text``. Laugh-verb handling lives in
    ``_contextual_laugh_duration`` so cackle / chuckle / laugh budgets scale
    with the adjective ("maniacally" vs "briefly") and with the repetition
    length of 'Ha'/'He' tokens inside quotes.
    """
    # NOTE(review): patterns are counted independently, so overlapping entries
    # stack: "pauses briefly" matches both `pauses? briefly` (0.3) and
    # `pauses?` (0.5) = 0.8s -- MORE than a plain "pause" (0.5s) -- and
    # "long pause" costs 1.0 + 0.5 = 1.5s. Probably unintended; consider
    # excluding the generic `pauses?` when a more specific pattern matched.
    PATTERNS = {
        # Breathing / sighs
        r'\bsighs?\b': 0.8,
        r'\bshaky breath\b': 1.0,
        r'\bbreathing deeply\b': 1.0,
        r'\bgasps?\b': 0.5,
        r'\bburps?\b': 0.5,
        r'\byawns?\b': 1.0,
        r'\bpants?\b': 0.8,
        r'\bwheezes?\b': 0.8,
        r'\bcoughs?\b': 0.8,
        r'\bsniffles?\b': 0.5,
        r'\bsnorts?\b': 0.3,
        r'\bgroans?\b': 0.8,
        # Pauses (trimmed; earlier values over-budgeted silence)
        r'\blong pause\b': 1.0,
        r'\bpauses? briefly\b': 0.3,
        r'\bpauses?\b': 0.5,
        r'\bsilence\b': 1.0,
        r'\blets? the .{1,20} hang\b': 1.0,
        r'\blets? .{1,20} sink in\b': 1.0,
        # Physical actions that produce sound
        r'\bslams?\b': 0.5,
        r'\bclaps?\b': 0.3,
        r'\bdraws? (?:his|her|a) sword\b': 0.5,
        r'\btakes? a (?:drag|swig|sip|drink)\b': 0.5,
        r'\bwhistles?\b': 1.0,
        r'\bhums?\b': 0.8,
        # Vocal actions (not in quotes but take time)
        r'\bmutters?\b': 1.5,
        r'\bmumbles?\b': 1.0,
        r'\bwhispers?\b': 0.0,
        r'\bclears? (?:his|her) throat\b': 0.5,
        r'\bgulps?\b': 0.5,
        r'\bswallows?\b': 0.5,
        # (laugh / chuckle / cackle / giggle / snicker handled by
        #  _contextual_laugh_duration, called at the end of this function --
        #  modifier-aware, not flat.)
        # Emotional transitions
        r'\bvoice (?:breaks?|cracks?|trembles?|drops?|rises?)\b': 0.5,
        r'\bsteadies? (?:him|her)self\b': 1.0,
        r'\bcatches? (?:his|her) breath\b': 1.0,
        r'\bcomposes? (?:him|her)self\b': 0.8,
        # Scene transitions that imply time
        r'\bdemeanor shifts?\b': 0.5,
        r'\bsettles? in\b': 0.5,
        r'\bleans? in\b': 0.3,
        r'\bwipes? (?:his|her) eyes\b': 0.5,
    }
    extra = 0.0
    for pattern, dur in PATTERNS.items():
        extra += dur * len(re.findall(pattern, text, re.IGNORECASE))
    extra += _contextual_laugh_duration(text)
    return extra


def estimate_speech_duration(text: str, speed: float = 1.0) -> float:
    """Estimate speech duration (seconds) from spoken content + non-verbal actions.

    Extracts spoken text by priority:
      1. Quoted text -- "..." first; '...' only if there are no double
         quotes (official prompt guide format). Quotes of 3 words or fewer
         are dropped as likely scene directions.
      2. Text after the first colon -- simple "Speaker: dialogue" format
      3. Full text -- fallback

    The spoken length is converted at 14 chars/s (slowed to 0.6x below 40
    chars and 0.8x below 80 chars), scaled by ``speed``, plus 0.3s per
    '.', '!' or '?'. Also scans the full prompt for non-verbal cues (laughs,
    pauses, sighs, gasps, etc.) and adds estimated duration for each.

    Returns:
        The estimate plus 2.0s of headroom, rounded to 0.1s, never below 3.0s.
    """
    # Try double quotes first (clean, no contraction issues)
    quotes = re.findall(r'"([^"]+)"', text)
    if not quotes:
        # Single quotes: allow apostrophes in contractions (don't, can't, it's)
        # Match ' to ' but apostrophes NOT followed by space/punctuation are kept inside
        quotes = re.findall(r"'((?:[^']|'(?![\s.,!?)\]]))+)'", text)
    # Filter out short fragments (scene directions like "He pauses")
    # NOTE(review): SOURCE lost its indentation; this filter is placed at
    # function level (its example is double-quoted, so it should apply to both
    # quote styles) -- confirm against version control.
    quotes = [q for q in quotes if len(q.split()) > 3]
    if quotes:
        spoken = " ".join(quotes)
    elif ":" in text:
        spoken = text.split(":", 1)[1].strip()
    else:
        spoken = text

    CHARS_PER_SEC = 14.0
    text_len = len(spoken)
    # Short utterances are spoken more slowly per character.
    if text_len < 40:
        chars_per_sec = CHARS_PER_SEC * 0.6
    elif text_len < 80:
        chars_per_sec = CHARS_PER_SEC * 0.8
    else:
        chars_per_sec = CHARS_PER_SEC
    # NOTE(review): speed == 0 raises ZeroDivisionError below; a negative speed
    # yields a negative estimate that the final max() floors to 3.0s.
    chars_per_sec *= speed
    duration = text_len / chars_per_sec
    # NOTE(review): every '.' counts, so an ellipsis "..." adds 0.9s -- may be
    # intentional (it usually signals a pause).
    sentence_count = spoken.count(".") + spoken.count("!") + spoken.count("?")
    duration += sentence_count * 0.3
    # Add time for non-verbal sounds/actions in the full prompt
    duration += _estimate_nonverbal_duration(text)
    return max(3.0, round(duration + 2.0, 1))


def parse_args() -> argparse.Namespace:
    """Parse command-line options.

    Options whose help says "auto" (cfg/stg/rescale/modality scales, fps,
    steps) default to None here and are resolved in main() from the detected
    model type.
    """
    p = argparse.ArgumentParser(description="LTX-2.3 TTS with IC-LoRA voice cloning")
    p.add_argument("--voice-sample", default=None, help="Voice reference WAV")
    p.add_argument("--no-ref", action="store_true", help="Skip voice reference conditioning (raw base model)")
    p.add_argument("--prompt", required=True, help="Text/scene description to synthesize")
    p.add_argument("--output", default="tts_output.wav")
    p.add_argument("--ref-duration", type=float, default=10.0, help="Seconds of voice reference to use")
    p.add_argument("--gen-duration", type=float, default=0.0,
                   help="Target output duration in seconds (0 = auto from prompt + multiplier). "
                        "Set explicitly for long-form prompts (e.g. --gen-duration 30 for music). "
                        "Outputs >20.5s automatically engage the end-of-clip silence-prior patch.")
    p.add_argument("--pad-start", type=float, default=0.0,
                   help="Prepend N seconds of silent padding, trimmed after decode (use 0 for clean starts)")
    p.add_argument("--speed", type=float, default=1.0)
    p.add_argument("--duration-multiplier", type=float, default=1.0,
                   help="Multiply auto-estimated duration by this factor (e.g. 1.1 for 10%% more breathing room)")
    p.add_argument("--checkpoint", default=os.path.join(MODEL_DIR, "ltx-2.3-audio-only.safetensors"))
    p.add_argument("--full-checkpoint", default=os.path.join(MODEL_DIR, "ltx-2.3-22b-distilled.safetensors"))
    p.add_argument("--gemma-root", default=GEMMA_DIR)
    # NOTE(review): store_true with default=True makes --bnb-4bit a no-op; only
    # --no-bnb-4bit changes the value.
    p.add_argument("--bnb-4bit", dest="bnb_4bit", action="store_true", default=True,
                   help="Load Gemma text encoder via the bitsandbytes 4-bit path "
                        "(required for the default unsloth/gemma-3-12b-it-bnb-4bit "
                        "pre-quantized weights). Default: on.")
    p.add_argument("--no-bnb-4bit", dest="bnb_4bit", action="store_false",
                   help="Disable the bitsandbytes path (use only if --gemma-root "
                        "points at an unquantized Gemma checkpoint).")
    p.add_argument("--lora", default=None, help="Path to trained IC-LoRA .safetensors (audio-only)")
    p.add_argument("--lora-rank", type=int, default=128, help="LoRA rank (must match training)")
    # NOTE(review): --id-guidance-scale is parsed but never read in main().
    p.add_argument("--id-guidance-scale", type=float, default=3.0, help="Identity guidance scale (0=disabled)")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--no-watermark", action="store_true",
                   help="Skip Perth audio watermarking on the output (default: watermark on).")
    p.add_argument("--sampler", choices=["euler", "heun"], default="euler",
                   help="Denoising loop. 'heun' = jkass_quality 2nd-order predictor-corrector (~2x model calls, cleaner audio).")
    # Auto-set based on model type but overridable
    p.add_argument("--cfg-scale", type=float, default=None, help="CFG scale (auto: 1.0 distilled, 7.0 dev)")
    p.add_argument("--stg-scale", type=float, default=None, help="STG scale (auto: 0.0 distilled, 1.0 dev)")
    p.add_argument("--stg-block", type=int, default=29, help="Block index for STG perturbation")
    p.add_argument("--rescale-scale", type=float, default=None,
                   help="Latent CFG std-rescale (default auto: cfg-aware schedule that prevents "
                        "output clipping at high cfg; pass any float in [0,1] to override).")
    p.add_argument("--modality-scale", type=float, default=None, help="Modality (auto: 1.0 distilled, 3.0 dev)")
    p.add_argument("--cfg-clamp", type=float, default=0.0, help="Clamp guided pred std to N * cond std (0=disabled)")
    p.add_argument("--steps", type=int, default=None, help="Override steps (auto: distilled sigmas / 30 dev)")
    p.add_argument("--fps", type=float, default=None, help="FPS (auto: 24.0 distilled, 25.0 dev)")
    p.add_argument(
        "--negative-prompt",
        default=(
            "worst quality, inconsistent motion, blurry, jittery, distorted, "
            "robotic voice, echo, background noise, off-sync audio, repetitive speech"
        ),
        help="Negative prompt for CFG (dev model)",
    )
    return p.parse_args()


@torch.inference_mode()
def main() -> None:
    """Run one end-to-end TTS generation from the CLI arguments and write a WAV.

    Pipeline: resolve model-type-dependent defaults -> pick the output
    duration -> build the target audio latent state -> optionally encode the
    voice reference and append it as conditioning tokens -> add Gaussian
    noise -> encode the prompt (plus negative prompt when cfg > 1) -> build
    the audio-only transformer (optionally merging an IC-LoRA) -> denoise ->
    strip the reference tokens and unpatchify -> patch latent frames 512-513
    -> decode with the audio VAE -> trim --pad-start -> watermark -> save.
    """
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
    args = parse_args()
    t0 = time.time()

    # ---- Imports (deferred to avoid startup cost when checking --help) ----
    from audio_conditioning import AudioConditionByReferenceLatent
    from ltx_core.batch_split import BatchSplitAdapter
    from ltx_core.components.diffusion_steps import EulerDiffusionStep
    from ltx_core.components.guiders import MultiModalGuider, MultiModalGuiderParams
    from ltx_core.components.noisers import GaussianNoiser
    from ltx_core.components.patchifiers import AudioPatchifier
    from ltx_core.components.schedulers import LTX2Scheduler
    from ltx_core.loader.registry import DummyRegistry
    from ltx_core.loader.sd_ops import SDOps
    from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder as Builder
    from ltx_core.model.audio_vae import encode_audio as vae_encode_audio
    from ltx_core.model.model_protocol import ModelConfigurator
    from ltx_core.model.transformer.attention import AttentionFunction
    from ltx_core.model.transformer.model import LTXModel, LTXModelType, X0Model
    from ltx_core.model.transformer.rope import LTXRopeType
    from ltx_core.tools import AudioLatentTools
    from ltx_core.types import Audio, AudioLatentShape, LatentState, VideoPixelShape
    from ltx_pipelines.utils.blocks import AudioConditioner, AudioDecoder, PromptEncoder
    from ltx_pipelines.utils.constants import DISTILLED_SIGMA_VALUES
    from ltx_pipelines.utils.denoisers import GuidedDenoiser, SimpleDenoiser
    from ltx_pipelines.utils.gpu_model import gpu_model
    from ltx_pipelines.utils.media_io import decode_audio_from_file
    from ltx_pipelines.utils.samplers import euler_denoising_loop, heun_denoising_loop

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.bfloat16
    patchifier = AudioPatchifier(patch_size=1)

    # ---- Detect model type and set defaults ----
    # NOTE(review): detection uses --full-checkpoint, not --checkpoint (the
    # transformer actually denoised). Passing a dev --checkpoint with the
    # default distilled --full-checkpoint selects distilled defaults.
    model_type = detect_model_type(args.full_checkpoint)
    logging.info(f"Detected model type: {model_type}")
    is_distilled = model_type == "distilled"

    if args.cfg_scale is None:
        args.cfg_scale = 1.0 if is_distilled else 7.0
    if args.stg_scale is None:
        args.stg_scale = 0.0 if is_distilled else 1.0
    if args.rescale_scale is None:
        # Auto cfg-aware rescale: imported from inference_server to keep one source of truth.
        # NOTE(review): the import runs even for distilled models, where its
        # result is unused -- it will fail if inference_server is not importable.
        from inference_server import auto_rescale_for_cfg
        args.rescale_scale = 0.0 if is_distilled else auto_rescale_for_cfg(args.cfg_scale)
    if args.modality_scale is None:
        args.modality_scale = 1.0 if is_distilled else 3.0
    if args.fps is None:
        args.fps = 24.0 if is_distilled else 25.0
    logging.info(
        f"Params: cfg={args.cfg_scale}, stg={args.stg_scale}, rescale={args.rescale_scale}, "
        f"modality={args.modality_scale}, fps={args.fps}"
    )

    # ---- Auto duration ----
    # --duration-multiplier only applies to the auto-estimated duration.
    if args.gen_duration <= 0:
        args.gen_duration = estimate_speech_duration(args.prompt, args.speed)
        if args.duration_multiplier != 1.0:
            args.gen_duration = round(args.gen_duration * args.duration_multiplier, 1)
        logging.info(f"Auto duration: {args.gen_duration}s for {len(args.prompt)} chars"
                     f"{f' (x{args.duration_multiplier})' if args.duration_multiplier != 1.0 else ''}")

    # ---- Compute target shape (include pad_start in duration) ----
    padded_duration = args.gen_duration + args.pad_start
    raw_frames = int(round(padded_duration * args.fps)) + 1
    # Round (raw_frames - 1) to the nearest multiple of 8, then add 1, so the
    # frame count has the form 8k + 1.
    num_frames = ((raw_frames - 1 + 4) // 8) * 8 + 1
    # Only an audio latent is generated; height/width presumably do not affect
    # the audio shape derived below -- confirm against AudioLatentShape.
    pixel_shape = VideoPixelShape(batch=1, frames=num_frames, height=64, width=64, fps=args.fps)
    tgt_shape = AudioLatentShape.from_video_pixel_shape(pixel_shape)
    logging.info(f"Target shape: {tgt_shape} ({args.gen_duration}s, {num_frames} frames)")

    # ---- AudioLatentTools for target ----
    audio_tools = AudioLatentTools(patchifier=patchifier, target_shape=tgt_shape)

    # ---- Create initial state ----
    state = audio_tools.create_initial_state(device, dtype)
    logging.info(
        f"Initial state: latent={state.latent.shape}, positions={state.positions.shape}, "
        f"denoise_mask={state.denoise_mask.shape}"
    )

    # NOTE(review): when neither --no-ref nor --voice-sample is given, this
    # silently falls through to the unconditioned base model.
    if not args.no_ref and args.voice_sample:
        # ---- Encode voice reference ----
        logging.info(f"Loading voice reference: {args.voice_sample}")
        voice = decode_audio_from_file(args.voice_sample, device, 0.0, args.ref_duration)
        if voice is None:
            raise ValueError(f"Could not load audio from {args.voice_sample}")
        # Normalize to (batch, 2 channels, samples): duplicate mono to stereo
        # and add a batch dimension when missing.
        w = voice.waveform
        if w.dim() == 2:
            if w.shape[0] == 1:
                w = w.repeat(2, 1)
            w = w.unsqueeze(0)
        elif w.dim() == 3 and w.shape[1] == 1:
            w = w.repeat(1, 2, 1)
        # Tile clips shorter than --ref-duration, then trim to exactly that length.
        # NOTE(review): an empty decoded clip (0 samples) would divide by zero here.
        target_samples = int(args.ref_duration * voice.sampling_rate)
        if w.shape[-1] < target_samples:
            w = w.repeat(1, 1, (target_samples // w.shape[-1]) + 1)
        w = w[..., :target_samples]
        # Peak normalize reference
        peak = w.abs().max()
        if peak > 0:
            target_peak = 10 ** (-4.0 / 20)  # -4dB
            w = w * (target_peak / peak)
            logging.info(f"Normalized reference: peak {peak:.4f} -> {target_peak:.4f}")
        voice = Audio(waveform=w, sampling_rate=voice.sampling_rate)

        logging.info("Encoding voice through Audio VAE...")
        ac = AudioConditioner(checkpoint_path=args.full_checkpoint, dtype=dtype, device=device)
        ref_latent = ac(lambda enc: vae_encode_audio(voice, enc, None))
        # Free the VAE encoder before the transformer is built.
        del ac
        torch.cuda.empty_cache()
        logging.info(f"Reference latent: {ref_latent.shape}")

        # ---- Apply conditioning: append ref tokens to END ----
        conditioning = AudioConditionByReferenceLatent(latent=ref_latent.to(device, dtype), strength=1.0)
        state = conditioning.apply_to(latent_state=state, latent_tools=audio_tools)
        logging.info(
            f"After conditioning: latent={state.latent.shape}, positions={state.positions.shape}, "
            f"attention_mask={'None' if state.attention_mask is None else state.attention_mask.shape}"
        )
    else:
        logging.info("No voice reference — running raw base model")

    # ---- Apply noise ----
    # Seeded generator so a given --seed reproduces the same initial noise.
    generator = torch.Generator(device=device).manual_seed(args.seed)
    noiser = GaussianNoiser(generator=generator)
    noised_state = noiser(state, noise_scale=1.0)
    logging.info("Applied Gaussian noise to state")

    # ---- Encode prompt ----
    # The negative prompt is only encoded when CFG is active (cfg > 1).
    use_cfg = args.cfg_scale > 1.0
    logging.info("Encoding prompt...")
    pe = PromptEncoder(checkpoint_path=args.full_checkpoint, gemma_root=args.gemma_root, dtype=dtype,
                       device=device, use_bnb_4bit=args.bnb_4bit, warm=True)
    prompts_to_encode = [args.prompt]
    if use_cfg:
        prompts_to_encode.append(args.negative_prompt)
    ctx = pe(prompts_to_encode, streaming_prefetch_count=None)
    a_ctx = ctx[0].audio_encoding
    a_ctx_neg = ctx[1].audio_encoding if use_cfg else None
    # Free the text encoder before the transformer is built.
    del pe
    torch.cuda.empty_cache()
    logging.info(f"Prompt encoded: a_ctx={a_ctx.shape}"
                 + (f", a_ctx_neg={a_ctx_neg.shape}" if a_ctx_neg is not None else ""))

    # ---- Build audio-only model ----
    logging.info("Building audio-only model...")
    # Keep only "model.diffusion_model.*" tensors and strip that prefix.
    audio_only_sd_ops = SDOps("AO").with_matching(prefix="model.diffusion_model.").with_replacement(
        "model.diffusion_model.", ""
    )

    class AudioOnlyConfigurator(ModelConfigurator[LTXModel]):
        """Build an audio-only LTXModel from the checkpoint config's "transformer"
        section, falling back to the hard-coded defaults below for missing keys."""

        @classmethod
        def from_config(cls, config):
            t = config.get("transformer", {})
            cp = None
            if not t.get("caption_proj_before_connector", False):
                from ltx_core.model.transformer.text_projection import create_caption_projection
                # Created on the meta device (no storage allocated); presumably
                # materialized when the builder loads weights -- confirm.
                with torch.device("meta"):
                    cp = create_caption_projection(t, audio=True)
            return LTXModel(
                model_type=LTXModelType.AudioOnly,
                audio_num_attention_heads=t.get("audio_num_attention_heads", 32),
                audio_attention_head_dim=t.get("audio_attention_head_dim", 64),
                audio_in_channels=t.get("audio_in_channels", 128),
                audio_out_channels=t.get("audio_out_channels", 128),
                num_layers=t.get("num_layers", 48),
                audio_cross_attention_dim=t.get("audio_cross_attention_dim", 2048),
                norm_eps=t.get("norm_eps", 1e-6),
                attention_type=AttentionFunction(t.get("attention_type", "default")),
                # NOTE(review): theta and max_pos are hard-coded rather than read
                # from the config -- confirm they match the checkpoint.
                positional_embedding_theta=10000.0,
                audio_positional_embedding_max_pos=[20.0],
                timestep_scale_multiplier=t.get("timestep_scale_multiplier", 1000),
                use_middle_indices_grid=t.get("use_middle_indices_grid", True),
                rope_type=LTXRopeType(t.get("rope_type", "interleaved")),
                double_precision_rope=t.get("frequencies_precision", False) == "float64",
                apply_gated_attention=t.get("apply_gated_attention", False),
                audio_caption_projection=cp,
                cross_attention_adaln=t.get("cross_attention_adaln", False),
            )

    builder = Builder(
        model_path=args.checkpoint,
        model_class_configurator=AudioOnlyConfigurator,
        model_sd_ops=audio_only_sd_ops,
        registry=DummyRegistry(),
    )
    velocity_model = builder.build(device=device, dtype=dtype).to(device).eval()

    # ---- Load LoRA weights (if provided) ----
    # NOTE(review): a --lora path that does not exist is silently ignored and
    # generation proceeds without the adapter; consider raising instead.
    if args.lora and os.path.exists(args.lora):
        from peft import LoraConfig, get_peft_model
        from safetensors.torch import load_file as st_load

        logging.info(f"Loading LoRA: {args.lora}")
        lora_sd = st_load(args.lora)
        # Key-prefix sniffing decides how checkpoint keys are remapped below.
        is_peft_format = any("base_model.model." in k for k in lora_sd.keys())
        is_original_idlora = any("diffusion_model." in k for k in lora_sd.keys())

        lora_config = LoraConfig(
            r=args.lora_rank,
            lora_alpha=args.lora_rank,
            lora_dropout=0.0,
            bias="none",
            target_modules=[
                "audio_attn1.to_k", "audio_attn1.to_q", "audio_attn1.to_v", "audio_attn1.to_out.0",
                "audio_attn2.to_k", "audio_attn2.to_q", "audio_attn2.to_v", "audio_attn2.to_out.0",
                "audio_ff.net.0.proj", "audio_ff.net.2",
            ],
        )
        velocity_model = get_peft_model(velocity_model, lora_config)

        if is_peft_format:
            # PEFT checkpoints saved without the adapter name: insert ".default".
            mapped_sd = {}
            for k, v in lora_sd.items():
                new_key = k
                if ".lora_A.weight" in k and ".lora_A.default.weight" not in k:
                    new_key = k.replace(".lora_A.weight", ".lora_A.default.weight")
                if ".lora_B.weight" in k and ".lora_B.default.weight" not in k:
                    new_key = k.replace(".lora_B.weight", ".lora_B.default.weight")
                mapped_sd[new_key] = v
            missing, unexpected = velocity_model.load_state_dict(mapped_sd, strict=False)
            loaded = len(mapped_sd) - len(unexpected)
            logging.info(f"Loaded {loaded} LoRA weights (peft format)")
        elif is_original_idlora:
            # Original ID-LoRA: keep audio-branch tensors only and rewrite the
            # "diffusion_model." prefix to PEFT's "base_model.model." layout.
            audio_keys = {
                k: v for k, v in lora_sd.items()
                if "audio_attn1" in k or "audio_attn2" in k or "audio_ff" in k
            }
            mapped_sd = {}
            for k, v in audio_keys.items():
                new_key = k.replace("diffusion_model.", "base_model.model.")
                new_key = new_key.replace(".lora_A.weight", ".lora_A.default.weight")
                new_key = new_key.replace(".lora_B.weight", ".lora_B.default.weight")
                mapped_sd[new_key] = v
            missing, unexpected = velocity_model.load_state_dict(mapped_sd, strict=False)
            loaded = len(mapped_sd) - len(unexpected)
            logging.info(f"Loaded {loaded} LoRA weights (original ID-LoRA)")
        # NOTE(review): if neither key format is recognised, no weights are
        # loaded and freshly initialised adapters are merged without any warning
        # (presumably a no-op with PEFT's default init -- confirm); consider
        # logging a warning or raising.

        velocity_model = velocity_model.merge_and_unload()
        logging.info("Merged LoRA into model")

    logging.info(f"Model: {sum(p.numel() for p in velocity_model.parameters()) / 1e9:.1f}B params")

    # ---- Wrap velocity model in X0Model ----
    x0_model = X0Model(velocity_model)

    # ---- Build denoiser and sigmas ----
    stepper = EulerDiffusionStep()

    # ---- Sigma schedule ----
    # NOTE(review): noised_state.latent already includes any appended reference
    # tokens; if LTX2Scheduler derives its shift from the token count, the
    # reference length influences the schedule -- confirm that is intended.
    if is_distilled:
        if args.steps is not None and args.steps > 0:
            sigmas = LTX2Scheduler().execute(steps=args.steps, latent=noised_state.latent).to(device)
            logging.info(f"Distilled with custom {args.steps}-step schedule")
        else:
            sigmas = torch.tensor(DISTILLED_SIGMA_VALUES, dtype=torch.float32, device=device)
            logging.info(f"Distilled {len(DISTILLED_SIGMA_VALUES) - 1}-step schedule")
    else:
        steps = args.steps if args.steps is not None and args.steps > 0 else 30
        sigmas = LTX2Scheduler().execute(steps=steps, latent=noised_state.latent).to(device)
        logging.info(f"Dev {steps}-step schedule")

    # ---- Denoiser: use GuidedDenoiser if any guidance is active, SimpleDenoiser otherwise ----
    needs_guidance = args.cfg_scale > 1.0 or args.stg_scale > 0.0 or args.modality_scale > 1.0
    if needs_guidance:
        audio_guider = MultiModalGuider(
            params=MultiModalGuiderParams(
                cfg_scale=args.cfg_scale,
                stg_scale=args.stg_scale,
                stg_blocks=[args.stg_block] if args.stg_scale > 0 else [],
                rescale_scale=args.rescale_scale,
                modality_scale=args.modality_scale,
                cfg_clamp_scale=args.cfg_clamp,
            ),
            # None when cfg <= 1 (the negative prompt was not encoded).
            negative_context=a_ctx_neg,
        )
        denoiser = GuidedDenoiser(
            v_context=None,
            a_context=a_ctx,
            video_guider=None,
            audio_guider=audio_guider,
        )
        logging.info(f"GuidedDenoiser: cfg={args.cfg_scale}, stg={args.stg_scale}, "
                     f"rescale={args.rescale_scale}, modality={args.modality_scale}")
    else:
        denoiser = SimpleDenoiser(v_context=None, a_context=a_ctx)
        logging.info("SimpleDenoiser (no guidance)")

    logging.info(f"Sigmas: {sigmas.tolist()}")

    # ---- Denoising loop ----
    logging.info(f"Running denoising loop ({len(sigmas) - 1} steps)...")
    with gpu_model(x0_model) as model:
        batched_model = BatchSplitAdapter(model, max_batch_size=1)
        denoise_fn = heun_denoising_loop if args.sampler == "heun" else euler_denoising_loop
        # Audio-only run: the video state is None and its output is discarded.
        _, audio_state = denoise_fn(
            sigmas=sigmas,
            video_state=None,
            audio_state=noised_state,
            stepper=stepper,
            transformer=batched_model,
            denoiser=denoiser,
        )

    del velocity_model, x0_model
    torch.cuda.empty_cache()

    # ---- Strip ref tokens and unpatchify ----
    logging.info("Stripping conditioning and unpatchifying...")
    audio_state = audio_tools.clear_conditioning(audio_state)
    audio_state = audio_tools.unpatchify(audio_state)
    logging.info(f"Final latent shape: {audio_state.latent.shape}")

    # ---- End-of-clip silence-prior fix ----
    # Base LTX-2.3 22B was trained on audio clips ≤ ~20 s and learned a strong
    # "clip-end silence" prior at the next patchifier-aligned latent boundary
    # (frame 513 = 8 × 64 + 1). For longer outputs that prior leaks through as
    # a ~30 ms hard silence dip near 20.4 s. Linearly interpolating frames
    # 512–513 between their neighbours (511 and 514) removes the dip cleanly.
    # The latent is indexed as 4-D with frames on dim 2.
    latent_in = audio_state.latent
    # NOTE(review): this guard only guarantees index 513 exists, but f1 = 514
    # is read below, so a latent with exactly 514 frames raises IndexError.
    # The safe bound is `latent_in.shape[2] > 514` (i.e. > f1).
    if latent_in.shape[2] > 513:
        f0, f1 = 511, 514
        n = f1 - f0
        patched = latent_in.clone()
        for f in (512, 513):
            t = (f - f0) / n
            patched[:, :, f, :] = (1.0 - t) * latent_in[:, :, f0, :] + t * latent_in[:, :, f1, :]
        latent_in = patched

    # ---- Decode audio ----
    logging.info("Decoding audio...")
    ad = AudioDecoder(checkpoint_path=args.full_checkpoint, dtype=dtype, device=device)
    decoded = ad(latent_in)
    del ad
    torch.cuda.empty_cache()

    # Ensure at least (channels, samples). NOTE(review): a 3-D (batch,
    # channels, samples) waveform would pass through unchanged and break the
    # channel handling and torchaudio.save below -- confirm AudioDecoder's
    # output rank.
    wav = decoded.waveform
    if wav.dim() == 1:
        wav = wav.unsqueeze(0)
    sr = decoded.sampling_rate

    # Trim leading pad if --pad-start was used
    if args.pad_start > 0:
        trim_samples = int(args.pad_start * sr)
        wav = wav[..., trim_samples:]
        logging.info(f"Trimmed {args.pad_start}s ({trim_samples} samples) of start padding")

    # Apply Perth (Perceptual Threshold) imperceptible neural watermark — see
    # https://github.com/resemble-ai/perth. Mono waveform required; if stereo,
    # we average to mono for the watermark and broadcast back. Skip on
    # --no-watermark for debugging.
    # NOTE(review): broadcasting the watermarked mono signal back to every
    # channel discards any stereo difference in the output. Watermarking is
    # best-effort: any failure is logged and the unwatermarked audio is saved.
    wav_cpu = wav.float().cpu()
    if not getattr(args, "no_watermark", False):
        try:
            import perth
            import numpy as np
            wm = perth.PerthImplicitWatermarker()
            mono = wav_cpu.mean(dim=0).numpy() if wav_cpu.shape[0] > 1 else wav_cpu[0].numpy()
            mono_wm = wm.apply_watermark(mono, sample_rate=sr)
            mono_wm_t = torch.from_numpy(np.asarray(mono_wm, dtype=np.float32)).unsqueeze(0)
            wav_cpu = mono_wm_t if wav_cpu.shape[0] == 1 else mono_wm_t.repeat(wav_cpu.shape[0], 1)
        except Exception as e:
            logging.warning(f"Perth watermark skipped ({e})")

    os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
    torchaudio.save(args.output, wav_cpu, sr)

    elapsed = time.time() - t0
    logging.info(f"Output: {args.output} ({wav.shape[-1] / sr:.1f}s)")
    logging.info(f"Total time: {elapsed:.1f}s")


if __name__ == "__main__":
    main()