Dramabox / src /inference.py
Manmay Nakhashi
Add explicit Target Duration slider (0–60s) + gen_duration kwarg
5cc51a5
#!/usr/bin/env python3
"""
LTX-2.3 TTS with IC-LoRA voice cloning.
Uses AudioConditionByReferenceLatent to append reference audio tokens to the
end of the target sequence. Auto-detects distilled vs dev checkpoint and
selects the appropriate denoiser (SimpleDenoiser / GuidedDenoiser) and sigma
schedule. Leverages the official euler_denoising_loop, AudioLatentTools,
GaussianNoiser, and X0Model wrapper throughout.
Usage (distilled):
python tts_iclora.py \
--voice-sample reference.wav \
--prompt "A woman speaks clearly: The weather today will be sunny." \
--output tts_output.wav
Usage (dev):
python tts_iclora.py \
--voice-sample reference.wav \
--prompt "A woman speaks clearly: The weather today will be sunny." \
--checkpoint ltx-2.3-22b-dev-audio-only.safetensors \
--full-checkpoint ltx-2.3-22b-dev.safetensors \
--output tts_output.wav
"""
import argparse
import json
import logging
import os
import re
import struct
import sys
import time
from pathlib import Path
import torch
import torchaudio
# Resolve the layout once: REPO_DIR is the parent of the directory holding this file.
_SRC_DIR = os.path.dirname(os.path.abspath(__file__))
REPO_DIR = os.path.dirname(_SRC_DIR)
# Put <repo>/ltx2 on the import path (ltx-pipelines is reached through it).
sys.path.insert(0, os.path.join(REPO_DIR, "ltx2"))
# Also add this file's own directory so audio_conditioning.py is importable.
sys.path.insert(0, _SRC_DIR)
# Default location of model checkpoints, and the Gemma text-encoder root
# (overridable through the GEMMA_DIR environment variable).
MODEL_DIR = os.path.join(REPO_DIR, "models")
GEMMA_DIR = os.environ.get("GEMMA_DIR", "gemma-3-12b-it-qat-q4_0-unquantized")
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def detect_model_type(checkpoint_path: str) -> str:
    """Classify a checkpoint as ``"distilled"`` or ``"dev"``.

    Detection order:
      1. The checkpoint *file name* (case-insensitive substring match).
         Parent directories are deliberately ignored so that locations such
         as ``/dev/shm/`` or ``/home/devuser/`` cannot be mistaken for a
         "dev" checkpoint.
      2. The ``model_version`` entry of the safetensors ``__metadata__``
         header, if the file can be read.
      3. ``"distilled"`` as the default.

    Args:
        checkpoint_path: Path to the checkpoint file. It does not need to
            exist; unreadable files simply fall through to the default.

    Returns:
        ``"distilled"`` or ``"dev"``.
    """
    name_lower = os.path.basename(checkpoint_path).lower()
    if "distilled" in name_lower:
        return "distilled"
    if "dev" in name_lower:
        return "dev"

    # Fallback: read the safetensors header, laid out as an 8-byte
    # little-endian length followed by that many bytes of JSON.
    max_header_bytes = 100 * 1024 * 1024  # sanity cap for non-safetensors files
    try:
        with open(checkpoint_path, "rb") as f:
            (header_size,) = struct.unpack("<Q", f.read(8))
            if header_size > max_header_bytes:
                raise ValueError("implausible safetensors header size")
            header = json.loads(f.read(header_size).decode())
    except (OSError, struct.error, ValueError):
        # Missing, truncated or non-safetensors file: detection is
        # best-effort, so fall back to the default below.
        header = None

    metadata = header.get("__metadata__") if isinstance(header, dict) else None
    version = metadata.get("model_version", "") if isinstance(metadata, dict) else ""
    version_lower = str(version).lower()
    if "distilled" in version_lower:
        return "distilled"
    if "dev" in version_lower:
        return "dev"

    # Default to distilled (most common for audio-only)
    return "distilled"
_LAUGH_VERBS = {
    # base seconds per occurrence; gets scaled by the modifier found nearby.
    # Verb regex covers inflections: laugh/laughs/laughed/laughing.
    r"\blaugh(?:s|ed|ing)?\b": 1.5,
    r"\bcackl(?:e|es|ed|ing)\b": 1.5,
    r"\bchuckl(?:e|es|ed|ing)\b": 1.0,
    r"\bgiggl(?:e|es|ed|ing)\b": 1.0,
    r"\bsnicker(?:s|ed|ing)?\b": 0.8,
    r"\bcru?el laugh\b": 1.5,
}

# Modifier patterns are matched against the text right after a laugh verb.
# "softly" / "quietly" describe volume not length, so they stay at 1.0x.
_SHORT_LAUGH_MOD = re.compile(
    r"^\s*(?:[a-z]+ly )?(?:briefly|shortly|once|quickly)",
    re.IGNORECASE)
_LONG_LAUGH_MOD = re.compile(
    r"^\s*(?:[a-z]+ly )?(?:maniacally|heartily|uproariously|uncontrollably|"
    r"hysterically|darkly|wickedly|evilly|loudly|long)"
    r"|^\s*between phrases", re.IGNORECASE)

# Quoted-speech extraction. The single-quote form keeps apostrophes that are
# not followed by whitespace/punctuation (contractions) inside the quote.
_DOUBLE_QUOTED = re.compile(r'"([^"]+)"')
_SINGLE_QUOTED = re.compile(r"'((?:[^']|'(?![\s.,!?)\]]))+)'")

# A phonetic laugh is three or more 'ha'/'he' syllables, optionally separated
# by a space or hyphen.
_LAUGH_RUN = re.compile(r"(?:h[ae]){3,}|(?:h[ae][ \-]?){3,}", re.IGNORECASE)
_LAUGH_SYLLABLE = re.compile(r"h[ae]", re.IGNORECASE)


def _contextual_laugh_duration(text: str) -> float:
    """Return the seconds to budget for laughter described in ``text``.

    Two contributions are summed:

    * Laugh verbs. Every match of a ``_LAUGH_VERBS`` pattern adds its base
      duration, scaled by the modifier found in the 40 characters that
      follow the verb:
        - short modifiers (briefly, shortly, once, quickly) -> 0.4x base
        - long modifiers (maniacally, heartily, ...)        -> 1.2x base
        - anything else                                     -> 1.0x base
    * Phonetic repetition inside quotes. Each run of three or more
      'ha'/'he' syllables adds 0.2s per syllable beyond the second, so
      'Hahahaha' (4 syllables) adds 0.4s while 'Haha' adds nothing.

    Args:
        text: The full prompt.

    Returns:
        Extra duration in seconds (0.0 when no laughter is found).
    """
    total = 0.0
    for pat, base_dur in _LAUGH_VERBS.items():
        for m in re.finditer(pat, text, re.IGNORECASE):
            ctx = text[m.end(): m.end() + 40]
            if _SHORT_LAUGH_MOD.match(ctx):
                total += base_dur * 0.4
            elif _LONG_LAUGH_MOD.match(ctx):
                total += base_dur * 1.2
            else:
                total += base_dur

    # Collect double-quoted speech first, then look for single-quoted speech
    # only in what is left. Scanning the full text for both would count a
    # laugh twice whenever an apostrophe inside a double-quoted line opens a
    # spurious single-quote match, or when double quotes are nested inside
    # single quotes.
    double_quoted = _DOUBLE_QUOTED.findall(text)
    outside_double = _DOUBLE_QUOTED.sub(" ", text)
    single_quoted = _SINGLE_QUOTED.findall(outside_double)

    for quoted in double_quoted + single_quoted:
        for run in _LAUGH_RUN.findall(quoted):
            syllables = len(_LAUGH_SYLLABLE.findall(run))
            total += 0.2 * max(syllables - 2, 0)
    return total
def _estimate_nonverbal_duration(text: str) -> float:
    """Return the extra seconds implied by non-verbal cues in the prompt.

    Every cue below is a case-insensitive regex paired with a flat number of
    seconds added per occurrence. Laughter is not in the table: it is
    delegated to ``_contextual_laugh_duration`` so that laugh / chuckle /
    cackle budgets scale with the modifier ("maniacally" vs "briefly") and
    with the length of 'Ha'/'He' repetitions inside quotes.
    """
    cue_seconds = (
        # Breathing / sighs
        (r'\bsighs?\b', 0.8),
        (r'\bshaky breath\b', 1.0),
        (r'\bbreathing deeply\b', 1.0),
        (r'\bgasps?\b', 0.5),
        (r'\bburps?\b', 0.5),
        (r'\byawns?\b', 1.0),
        (r'\bpants?\b', 0.8),
        (r'\bwheezes?\b', 0.8),
        (r'\bcoughs?\b', 0.8),
        (r'\bsniffles?\b', 0.5),
        (r'\bsnorts?\b', 0.3),
        (r'\bgroans?\b', 0.8),
        # Pauses (trimmed; earlier values over-budgeted silence)
        (r'\blong pause\b', 1.0),
        (r'\bpauses? briefly\b', 0.3),
        (r'\bpauses?\b', 0.5),
        (r'\bsilence\b', 1.0),
        (r'\blets? the .{1,20} hang\b', 1.0),
        (r'\blets? .{1,20} sink in\b', 1.0),
        # Physical actions that produce sound
        (r'\bslams?\b', 0.5),
        (r'\bclaps?\b', 0.3),
        (r'\bdraws? (?:his|her|a) sword\b', 0.5),
        (r'\btakes? a (?:drag|swig|sip|drink)\b', 0.5),
        (r'\bwhistles?\b', 1.0),
        (r'\bhums?\b', 0.8),
        # Vocal actions (not in quotes but take time)
        (r'\bmutters?\b', 1.5),
        (r'\bmumbles?\b', 1.0),
        (r'\bwhispers?\b', 0.0),
        (r'\bclears? (?:his|her) throat\b', 0.5),
        (r'\bgulps?\b', 0.5),
        (r'\bswallows?\b', 0.5),
        # Emotional transitions
        (r'\bvoice (?:breaks?|cracks?|trembles?|drops?|rises?)\b', 0.5),
        (r'\bsteadies? (?:him|her)self\b', 1.0),
        (r'\bcatches? (?:his|her) breath\b', 1.0),
        (r'\bcomposes? (?:him|her)self\b', 0.8),
        # Scene transitions that imply time
        (r'\bdemeanor shifts?\b', 0.5),
        (r'\bsettles? in\b', 0.5),
        (r'\bleans? in\b', 0.3),
        (r'\bwipes? (?:his|her) eyes\b', 0.5),
    )

    total = 0.0
    for cue, seconds in cue_seconds:
        hits = sum(1 for _ in re.finditer(cue, text, re.IGNORECASE))
        total += seconds * hits

    # Modifier-aware laughter budget is added last.
    total += _contextual_laugh_duration(text)
    return total
def estimate_speech_duration(text: str, speed: float = 1.0) -> float:
    """Estimate speech duration from spoken content + non-verbal actions.

    The spoken text is extracted by priority:
      1. Quoted text ("..." first, '...' only when there are no double
         quotes) -- official prompt guide format. Quotes of three words or
         fewer are dropped as likely scene directions.
      2. Text after the first colon -- simple "Speaker: dialogue" format.
      3. The full text -- fallback.

    The character count is converted to seconds with a rate that is slower
    for short texts, 0.3s is added per sentence terminator, and the full
    prompt is scanned for non-verbal cues (laughs, pauses, sighs, ...).

    Args:
        text: The full prompt.
        speed: Speaking-rate multiplier; must be greater than zero.

    Returns:
        Estimated duration in seconds, rounded to 0.1s, including a fixed
        2.0s margin and never less than 3.0.

    Raises:
        ValueError: If ``speed`` is not a positive number. (Previously zero
            raised ``ZeroDivisionError`` and negative values silently
            produced the 3.0s minimum.)
    """
    if not speed > 0:
        raise ValueError(f"speed must be > 0, got {speed!r}")

    # Try double quotes first (clean, no contraction issues)
    quotes = re.findall(r'"([^"]+)"', text)
    if not quotes:
        # Single quotes: allow apostrophes in contractions (don't, can't, it's)
        # Match ' to ' but apostrophes NOT followed by space/punctuation are kept inside
        quotes = re.findall(r"'((?:[^']|'(?![\s.,!?)\]]))+)'", text)
    # Filter out short fragments (scene directions like "He pauses")
    quotes = [q for q in quotes if len(q.split()) > 3]

    if quotes:
        spoken = " ".join(quotes)
    elif ":" in text:
        spoken = text.split(":", 1)[1].strip()
    else:
        spoken = text

    CHARS_PER_SEC = 14.0
    text_len = len(spoken)
    # Short utterances are budgeted at a slower rate than long ones.
    if text_len < 40:
        chars_per_sec = CHARS_PER_SEC * 0.6
    elif text_len < 80:
        chars_per_sec = CHARS_PER_SEC * 0.8
    else:
        chars_per_sec = CHARS_PER_SEC
    chars_per_sec *= speed

    duration = text_len / chars_per_sec
    sentence_count = spoken.count(".") + spoken.count("!") + spoken.count("?")
    duration += sentence_count * 0.3
    # Add time for non-verbal sounds/actions in the full prompt
    duration += _estimate_nonverbal_duration(text)
    return max(3.0, round(duration + 2.0, 1))
def parse_args():
    """Define the command-line interface and parse ``sys.argv``.

    Only ``--prompt`` is required. Guidance, step and fps options default to
    ``None`` so the caller can fill them in after the model type is known.

    Returns:
        The ``argparse.Namespace`` produced by the parser.
    """
    parser = argparse.ArgumentParser(description="LTX-2.3 TTS with IC-LoRA voice cloning")
    add = parser.add_argument

    default_checkpoint = os.path.join(MODEL_DIR, "ltx-2.3-audio-only.safetensors")
    default_full_checkpoint = os.path.join(MODEL_DIR, "ltx-2.3-22b-distilled.safetensors")

    # ---- Inputs / outputs ----
    add("--voice-sample", default=None, help="Voice reference WAV")
    add("--no-ref", action="store_true", help="Skip voice reference conditioning (raw base model)")
    add("--prompt", required=True, help="Text/scene description to synthesize")
    add("--output", default="tts_output.wav")

    # ---- Durations / pacing ----
    add("--ref-duration", type=float, default=10.0, help="Seconds of voice reference to use")
    add(
        "--gen-duration",
        type=float,
        default=0.0,
        help="Target output duration in seconds (0 = auto from prompt + multiplier). "
             "Set explicitly for long-form prompts (e.g. --gen-duration 30 for music). "
             "Outputs >20.5s automatically engage the end-of-clip silence-prior patch.",
    )
    add(
        "--pad-start",
        type=float,
        default=0.0,
        help="Prepend N seconds of silent padding, trimmed after decode (use 0 for clean starts)",
    )
    add("--speed", type=float, default=1.0)
    add(
        "--duration-multiplier",
        type=float,
        default=1.0,
        help="Multiply auto-estimated duration by this factor (e.g. 1.1 for 10%% more breathing room)",
    )

    # ---- Model files ----
    add("--checkpoint", default=default_checkpoint)
    add("--full-checkpoint", default=default_full_checkpoint)
    add("--gemma-root", default=GEMMA_DIR)
    # --bnb-4bit / --no-bnb-4bit share one destination; the flag is on by default.
    add(
        "--bnb-4bit",
        dest="bnb_4bit",
        action="store_true",
        default=True,
        help="Load Gemma text encoder via the bitsandbytes 4-bit path "
             "(required for the default unsloth/gemma-3-12b-it-bnb-4bit "
             "pre-quantized weights). Default: on.",
    )
    add(
        "--no-bnb-4bit",
        dest="bnb_4bit",
        action="store_false",
        help="Disable the bitsandbytes path (use only if --gemma-root "
             "points at an unquantized Gemma checkpoint).",
    )
    add("--lora", default=None, help="Path to trained IC-LoRA .safetensors (audio-only)")
    add("--lora-rank", type=int, default=128, help="LoRA rank (must match training)")
    add("--id-guidance-scale", type=float, default=3.0, help="Identity guidance scale (0=disabled)")
    add("--seed", type=int, default=42)
    add(
        "--no-watermark",
        action="store_true",
        help="Skip Perth audio watermarking on the output (default: watermark on).",
    )

    # ---- Sampling / guidance ----
    add(
        "--sampler",
        choices=["euler", "heun"],
        default="euler",
        help="Denoising loop. 'heun' = jkass_quality 2nd-order predictor-corrector (~2x model calls, cleaner audio).",
    )
    # The None-default options below are auto-set based on model type but overridable.
    add("--cfg-scale", type=float, default=None, help="CFG scale (auto: 1.0 distilled, 7.0 dev)")
    add("--stg-scale", type=float, default=None, help="STG scale (auto: 0.0 distilled, 1.0 dev)")
    add("--stg-block", type=int, default=29, help="Block index for STG perturbation")
    add(
        "--rescale-scale",
        type=float,
        default=None,
        help="Latent CFG std-rescale (default auto: cfg-aware schedule that prevents "
             "output clipping at high cfg; pass any float in [0,1] to override).",
    )
    add("--modality-scale", type=float, default=None, help="Modality (auto: 1.0 distilled, 3.0 dev)")
    add("--cfg-clamp", type=float, default=0.0, help="Clamp guided pred std to N * cond std (0=disabled)")
    add("--steps", type=int, default=None, help="Override steps (auto: distilled sigmas / 30 dev)")
    add("--fps", type=float, default=None, help="FPS (auto: 24.0 distilled, 25.0 dev)")
    add(
        "--negative-prompt",
        default=(
            "worst quality, inconsistent motion, blurry, jittery, distorted, "
            "robotic voice, echo, background noise, off-sync audio, repetitive speech"
        ),
        help="Negative prompt for CFG (dev model)",
    )

    return parser.parse_args()
@torch.inference_mode()
def main():
    """Run the command-line TTS pipeline end to end and write the output WAV.

    Steps, in order: parse CLI arguments, fill in per-model-type defaults,
    derive the target duration and latent shape, optionally encode the voice
    reference and append it as conditioning, noise the state, encode the
    prompt(s), build the audio-only transformer (merging a LoRA when one is
    given), run the selected denoising loop, patch the end-of-clip silence
    dip on long outputs, decode to a waveform, optionally watermark it, and
    save it to ``args.output``.

    Takes no arguments and returns ``None``; runs under
    ``torch.inference_mode()``.

    Raises:
        ValueError: If the voice reference cannot be loaded or is empty.
    """
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
    args = parse_args()
    t0 = time.time()

    # ---- Imports (deferred to avoid startup cost when checking --help) ----
    from audio_conditioning import AudioConditionByReferenceLatent
    from ltx_core.batch_split import BatchSplitAdapter
    from ltx_core.components.diffusion_steps import EulerDiffusionStep
    from ltx_core.components.guiders import MultiModalGuider, MultiModalGuiderParams
    from ltx_core.components.noisers import GaussianNoiser
    from ltx_core.components.patchifiers import AudioPatchifier
    from ltx_core.components.schedulers import LTX2Scheduler
    from ltx_core.loader.registry import DummyRegistry
    from ltx_core.loader.sd_ops import SDOps
    from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder as Builder
    from ltx_core.model.audio_vae import encode_audio as vae_encode_audio
    from ltx_core.model.model_protocol import ModelConfigurator
    from ltx_core.model.transformer.attention import AttentionFunction
    from ltx_core.model.transformer.model import LTXModel, LTXModelType, X0Model
    from ltx_core.model.transformer.rope import LTXRopeType
    from ltx_core.tools import AudioLatentTools
    from ltx_core.types import Audio, AudioLatentShape, LatentState, VideoPixelShape
    from ltx_pipelines.utils.blocks import AudioConditioner, AudioDecoder, PromptEncoder
    from ltx_pipelines.utils.constants import DISTILLED_SIGMA_VALUES
    from ltx_pipelines.utils.denoisers import GuidedDenoiser, SimpleDenoiser
    from ltx_pipelines.utils.gpu_model import gpu_model
    from ltx_pipelines.utils.media_io import decode_audio_from_file
    from ltx_pipelines.utils.samplers import euler_denoising_loop, heun_denoising_loop

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.bfloat16
    patchifier = AudioPatchifier(patch_size=1)

    # ---- Detect model type and set defaults ----
    model_type = detect_model_type(args.full_checkpoint)
    logging.info(f"Detected model type: {model_type}")
    is_distilled = model_type == "distilled"
    if args.cfg_scale is None:
        args.cfg_scale = 1.0 if is_distilled else 7.0
    if args.stg_scale is None:
        args.stg_scale = 0.0 if is_distilled else 1.0
    if args.rescale_scale is None:
        # Auto cfg-aware rescale: imported from inference_server to keep one source of truth.
        from inference_server import auto_rescale_for_cfg
        args.rescale_scale = 0.0 if is_distilled else auto_rescale_for_cfg(args.cfg_scale)
    if args.modality_scale is None:
        args.modality_scale = 1.0 if is_distilled else 3.0
    if args.fps is None:
        args.fps = 24.0 if is_distilled else 25.0
    logging.info(
        f"Params: cfg={args.cfg_scale}, stg={args.stg_scale}, rescale={args.rescale_scale}, "
        f"modality={args.modality_scale}, fps={args.fps}"
    )

    # ---- Auto duration ----
    # The multiplier only scales the auto-estimated duration; an explicit
    # --gen-duration is used as given.
    if args.gen_duration <= 0:
        args.gen_duration = estimate_speech_duration(args.prompt, args.speed)
        if args.duration_multiplier != 1.0:
            args.gen_duration = round(args.gen_duration * args.duration_multiplier, 1)
        logging.info(f"Auto duration: {args.gen_duration}s for {len(args.prompt)} chars"
                     f"{f' (x{args.duration_multiplier})' if args.duration_multiplier != 1.0 else ''}")

    # ---- Compute target shape (include pad_start in duration) ----
    padded_duration = args.gen_duration + args.pad_start
    raw_frames = int(round(padded_duration * args.fps)) + 1
    # Snap to the nearest frame count of the form 8*k + 1.
    num_frames = ((raw_frames - 1 + 4) // 8) * 8 + 1
    pixel_shape = VideoPixelShape(batch=1, frames=num_frames, height=64, width=64, fps=args.fps)
    tgt_shape = AudioLatentShape.from_video_pixel_shape(pixel_shape)
    logging.info(f"Target shape: {tgt_shape} ({args.gen_duration}s, {num_frames} frames)")

    # ---- AudioLatentTools for target ----
    audio_tools = AudioLatentTools(patchifier=patchifier, target_shape=tgt_shape)

    # ---- Create initial state ----
    state = audio_tools.create_initial_state(device, dtype)
    logging.info(
        f"Initial state: latent={state.latent.shape}, positions={state.positions.shape}, "
        f"denoise_mask={state.denoise_mask.shape}"
    )

    if not args.no_ref and args.voice_sample:
        # ---- Encode voice reference ----
        logging.info(f"Loading voice reference: {args.voice_sample}")
        voice = decode_audio_from_file(args.voice_sample, device, 0.0, args.ref_duration)
        if voice is None:
            raise ValueError(f"Could not load audio from {args.voice_sample}")
        w = voice.waveform
        # Bring the waveform to a batched two-channel layout: a 2-D mono
        # tensor is duplicated to two channels and given a leading batch
        # dim; a 3-D tensor with one channel is duplicated along dim 1.
        if w.dim() == 2:
            if w.shape[0] == 1:
                w = w.repeat(2, 1)
            w = w.unsqueeze(0)
        elif w.dim() == 3 and w.shape[1] == 1:
            w = w.repeat(1, 2, 1)
        if w.shape[-1] == 0:
            # An empty clip cannot be tiled (and would divide by zero below).
            raise ValueError(f"Voice reference contains no audio samples: {args.voice_sample}")
        # Tile short references, then cut to exactly ref_duration seconds.
        target_samples = int(args.ref_duration * voice.sampling_rate)
        if w.shape[-1] < target_samples:
            w = w.repeat(1, 1, (target_samples // w.shape[-1]) + 1)
        w = w[..., :target_samples]
        # Peak normalize reference
        peak = w.abs().max()
        if peak > 0:
            target_peak = 10 ** (-4.0 / 20)  # -4dB
            w = w * (target_peak / peak)
            logging.info(f"Normalized reference: peak {peak:.4f} -> {target_peak:.4f}")
        voice = Audio(waveform=w, sampling_rate=voice.sampling_rate)

        logging.info("Encoding voice through Audio VAE...")
        ac = AudioConditioner(checkpoint_path=args.full_checkpoint, dtype=dtype, device=device)
        ref_latent = ac(lambda enc: vae_encode_audio(voice, enc, None))
        del ac
        torch.cuda.empty_cache()
        logging.info(f"Reference latent: {ref_latent.shape}")

        # ---- Apply conditioning: append ref tokens to END ----
        conditioning = AudioConditionByReferenceLatent(latent=ref_latent.to(device, dtype), strength=1.0)
        state = conditioning.apply_to(latent_state=state, latent_tools=audio_tools)
        logging.info(
            f"After conditioning: latent={state.latent.shape}, positions={state.positions.shape}, "
            f"attention_mask={'None' if state.attention_mask is None else state.attention_mask.shape}"
        )
    else:
        logging.info("No voice reference — running raw base model")

    # ---- Apply noise ----
    generator = torch.Generator(device=device).manual_seed(args.seed)
    noiser = GaussianNoiser(generator=generator)
    noised_state = noiser(state, noise_scale=1.0)
    logging.info("Applied Gaussian noise to state")

    # ---- Encode prompt ----
    # The negative prompt is only encoded when CFG is actually active.
    use_cfg = args.cfg_scale > 1.0
    logging.info("Encoding prompt...")
    pe = PromptEncoder(checkpoint_path=args.full_checkpoint, gemma_root=args.gemma_root, dtype=dtype, device=device,
                       use_bnb_4bit=args.bnb_4bit, warm=True)
    prompts_to_encode = [args.prompt]
    if use_cfg:
        prompts_to_encode.append(args.negative_prompt)
    ctx = pe(prompts_to_encode, streaming_prefetch_count=None)
    a_ctx = ctx[0].audio_encoding
    a_ctx_neg = ctx[1].audio_encoding if use_cfg else None
    del pe
    torch.cuda.empty_cache()
    logging.info(f"Prompt encoded: a_ctx={a_ctx.shape}" + (f", a_ctx_neg={a_ctx_neg.shape}" if a_ctx_neg is not None else ""))

    # ---- Build audio-only model ----
    logging.info("Building audio-only model...")
    audio_only_sd_ops = SDOps("AO").with_matching(prefix="model.diffusion_model.").with_replacement(
        "model.diffusion_model.", ""
    )

    class AudioOnlyConfigurator(ModelConfigurator[LTXModel]):
        """Builds an audio-only ``LTXModel`` from the checkpoint's config dict."""

        @classmethod
        def from_config(cls, config):
            t = config.get("transformer", {})
            cp = None
            # A caption projection is created here only when the config does
            # not place it before the connector.
            if not t.get("caption_proj_before_connector", False):
                from ltx_core.model.transformer.text_projection import create_caption_projection
                with torch.device("meta"):
                    cp = create_caption_projection(t, audio=True)
            return LTXModel(
                model_type=LTXModelType.AudioOnly,
                audio_num_attention_heads=t.get("audio_num_attention_heads", 32),
                audio_attention_head_dim=t.get("audio_attention_head_dim", 64),
                audio_in_channels=t.get("audio_in_channels", 128),
                audio_out_channels=t.get("audio_out_channels", 128),
                num_layers=t.get("num_layers", 48),
                audio_cross_attention_dim=t.get("audio_cross_attention_dim", 2048),
                norm_eps=t.get("norm_eps", 1e-6),
                attention_type=AttentionFunction(t.get("attention_type", "default")),
                positional_embedding_theta=10000.0,
                audio_positional_embedding_max_pos=[20.0],
                timestep_scale_multiplier=t.get("timestep_scale_multiplier", 1000),
                use_middle_indices_grid=t.get("use_middle_indices_grid", True),
                rope_type=LTXRopeType(t.get("rope_type", "interleaved")),
                double_precision_rope=t.get("frequencies_precision", False) == "float64",
                apply_gated_attention=t.get("apply_gated_attention", False),
                audio_caption_projection=cp,
                cross_attention_adaln=t.get("cross_attention_adaln", False),
            )

    builder = Builder(
        model_path=args.checkpoint,
        model_class_configurator=AudioOnlyConfigurator,
        model_sd_ops=audio_only_sd_ops,
        registry=DummyRegistry(),
    )
    velocity_model = builder.build(device=device, dtype=dtype).to(device).eval()

    # ---- Load LoRA weights (if provided) ----
    if args.lora and not os.path.exists(args.lora):
        # Previously a wrong path was skipped without any message, which made
        # a typo indistinguishable from a LoRA that had no audible effect.
        logging.warning(f"LoRA file not found, continuing without LoRA: {args.lora}")
    elif args.lora:
        from peft import LoraConfig, get_peft_model
        from safetensors.torch import load_file as st_load

        logging.info(f"Loading LoRA: {args.lora}")
        lora_sd = st_load(args.lora)
        is_peft_format = any("base_model.model." in k for k in lora_sd.keys())
        is_original_idlora = any("diffusion_model." in k for k in lora_sd.keys())
        lora_config = LoraConfig(
            r=args.lora_rank,
            lora_alpha=args.lora_rank,
            lora_dropout=0.0,
            bias="none",
            target_modules=[
                "audio_attn1.to_k",
                "audio_attn1.to_q",
                "audio_attn1.to_v",
                "audio_attn1.to_out.0",
                "audio_attn2.to_k",
                "audio_attn2.to_q",
                "audio_attn2.to_v",
                "audio_attn2.to_out.0",
                "audio_ff.net.0.proj",
                "audio_ff.net.2",
            ],
        )
        velocity_model = get_peft_model(velocity_model, lora_config)
        if is_peft_format:
            # Insert the "default" adapter name into keys that lack it.
            mapped_sd = {}
            for k, v in lora_sd.items():
                new_key = k
                if ".lora_A.weight" in k and ".lora_A.default.weight" not in k:
                    new_key = k.replace(".lora_A.weight", ".lora_A.default.weight")
                if ".lora_B.weight" in k and ".lora_B.default.weight" not in k:
                    new_key = k.replace(".lora_B.weight", ".lora_B.default.weight")
                mapped_sd[new_key] = v
            _missing, unexpected = velocity_model.load_state_dict(mapped_sd, strict=False)
            loaded = len(mapped_sd) - len(unexpected)
            logging.info(f"Loaded {loaded} LoRA weights (peft format)")
        elif is_original_idlora:
            # Keep only the audio branches and rename keys to the peft layout.
            audio_keys = {
                k: v
                for k, v in lora_sd.items()
                if "audio_attn1" in k or "audio_attn2" in k or "audio_ff" in k
            }
            mapped_sd = {}
            for k, v in audio_keys.items():
                new_key = k.replace("diffusion_model.", "base_model.model.")
                new_key = new_key.replace(".lora_A.weight", ".lora_A.default.weight")
                new_key = new_key.replace(".lora_B.weight", ".lora_B.default.weight")
                mapped_sd[new_key] = v
            _missing, unexpected = velocity_model.load_state_dict(mapped_sd, strict=False)
            loaded = len(mapped_sd) - len(unexpected)
            logging.info(f"Loaded {loaded} LoRA weights (original ID-LoRA)")
        else:
            logging.warning(f"Unrecognized LoRA key format in {args.lora}; no LoRA weights were loaded")
        velocity_model = velocity_model.merge_and_unload()
        logging.info("Merged LoRA into model")
    logging.info(f"Model: {sum(p.numel() for p in velocity_model.parameters()) / 1e9:.1f}B params")

    # ---- Wrap velocity model in X0Model ----
    x0_model = X0Model(velocity_model)

    # ---- Build denoiser and sigmas ----
    stepper = EulerDiffusionStep()

    # ---- Sigma schedule ----
    if is_distilled:
        if args.steps is not None and args.steps > 0:
            sigmas = LTX2Scheduler().execute(steps=args.steps, latent=noised_state.latent).to(device)
            logging.info(f"Distilled with custom {args.steps}-step schedule")
        else:
            sigmas = torch.tensor(DISTILLED_SIGMA_VALUES, dtype=torch.float32, device=device)
            logging.info(f"Distilled {len(DISTILLED_SIGMA_VALUES) - 1}-step schedule")
    else:
        steps = args.steps if args.steps is not None and args.steps > 0 else 30
        sigmas = LTX2Scheduler().execute(steps=steps, latent=noised_state.latent).to(device)
        logging.info(f"Dev {steps}-step schedule")

    # ---- Denoiser: use GuidedDenoiser if any guidance is active, SimpleDenoiser otherwise ----
    needs_guidance = args.cfg_scale > 1.0 or args.stg_scale > 0.0 or args.modality_scale > 1.0
    if needs_guidance:
        audio_guider = MultiModalGuider(
            params=MultiModalGuiderParams(
                cfg_scale=args.cfg_scale,
                stg_scale=args.stg_scale,
                stg_blocks=[args.stg_block] if args.stg_scale > 0 else [],
                rescale_scale=args.rescale_scale,
                modality_scale=args.modality_scale,
                cfg_clamp_scale=args.cfg_clamp,
            ),
            negative_context=a_ctx_neg,
        )
        denoiser = GuidedDenoiser(
            v_context=None,
            a_context=a_ctx,
            video_guider=None,
            audio_guider=audio_guider,
        )
        logging.info(f"GuidedDenoiser: cfg={args.cfg_scale}, stg={args.stg_scale}, "
                     f"rescale={args.rescale_scale}, modality={args.modality_scale}")
    else:
        denoiser = SimpleDenoiser(v_context=None, a_context=a_ctx)
        logging.info("SimpleDenoiser (no guidance)")
    logging.info(f"Sigmas: {sigmas.tolist()}")

    # ---- Denoising loop ----
    logging.info(f"Running denoising loop ({len(sigmas) - 1} steps)...")
    with gpu_model(x0_model) as model:
        batched_model = BatchSplitAdapter(model, max_batch_size=1)
        denoise_fn = heun_denoising_loop if args.sampler == "heun" else euler_denoising_loop
        _, audio_state = denoise_fn(
            sigmas=sigmas,
            video_state=None,
            audio_state=noised_state,
            stepper=stepper,
            transformer=batched_model,
            denoiser=denoiser,
        )
    del velocity_model, x0_model
    torch.cuda.empty_cache()

    # ---- Strip ref tokens and unpatchify ----
    logging.info("Stripping conditioning and unpatchifying...")
    audio_state = audio_tools.clear_conditioning(audio_state)
    audio_state = audio_tools.unpatchify(audio_state)
    logging.info(f"Final latent shape: {audio_state.latent.shape}")

    # ---- End-of-clip silence-prior fix ----
    # Base LTX-2.3 22B was trained on audio clips ≤ ~20 s and learned a strong
    # "clip-end silence" prior at the next patchifier-aligned latent boundary
    # (frame 513 = 8 × 64 + 1). For longer outputs that prior leaks through as
    # a ~30 ms hard silence dip near 20.4 s. Linearly interpolating frames
    # 512–513 between their neighbours (511 and 514) removes the dip cleanly.
    latent_in = audio_state.latent
    f0, f1 = 511, 514
    # Frame f1 is read below, so the time axis (dim 2) must hold at least
    # f1 + 1 frames. The previous "> 513" guard admitted a 514-frame latent
    # and then indexed past its end.
    if latent_in.shape[2] > f1:
        n = f1 - f0
        patched = latent_in.clone()
        for f in (512, 513):
            t = (f - f0) / n
            patched[:, :, f, :] = (1.0 - t) * latent_in[:, :, f0, :] + t * latent_in[:, :, f1, :]
        latent_in = patched

    # ---- Decode audio ----
    logging.info("Decoding audio...")
    ad = AudioDecoder(checkpoint_path=args.full_checkpoint, dtype=dtype, device=device)
    decoded = ad(latent_in)
    del ad
    torch.cuda.empty_cache()
    wav = decoded.waveform
    if wav.dim() == 1:
        wav = wav.unsqueeze(0)
    sr = decoded.sampling_rate

    # Trim leading pad if --pad-start was used
    if args.pad_start > 0:
        trim_samples = int(args.pad_start * sr)
        wav = wav[..., trim_samples:]
        logging.info(f"Trimmed {args.pad_start}s ({trim_samples} samples) of start padding")

    # Apply Perth (Perceptual Threshold) imperceptible neural watermark — see
    # https://github.com/resemble-ai/perth. Mono waveform required; if stereo,
    # we average to mono for the watermark and broadcast back. Skip on
    # --no-watermark for debugging.
    wav_cpu = wav.float().cpu()
    if not getattr(args, "no_watermark", False):
        try:
            import perth
            import numpy as np
            wm = perth.PerthImplicitWatermarker()
            mono = wav_cpu.mean(dim=0).numpy() if wav_cpu.shape[0] > 1 else wav_cpu[0].numpy()
            mono_wm = wm.apply_watermark(mono, sample_rate=sr)
            mono_wm_t = torch.from_numpy(np.asarray(mono_wm, dtype=np.float32)).unsqueeze(0)
            wav_cpu = mono_wm_t if wav_cpu.shape[0] == 1 else mono_wm_t.repeat(wav_cpu.shape[0], 1)
        except Exception as e:
            # Watermarking is best-effort: on any failure the un-watermarked
            # audio is still written.
            logging.warning(f"Perth watermark skipped ({e})")

    os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
    torchaudio.save(args.output, wav_cpu, sr)
    elapsed = time.time() - t0
    logging.info(f"Output: {args.output} ({wav.shape[-1] / sr:.1f}s)")
    logging.info(f"Total time: {elapsed:.1f}s")
# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()