#!/usr/bin/env python3
"""Warm validation runner — loads base dev + LoRA + all aux models ONCE, then
iterates every speaker in val_config generating each output.

Matches the same generation path as inference.py but keeps Gemma / audio VAE /
velocity model / audio decoder resident across entries.

Inference settings default to the Gradio warm-server values (cfg=2.5, stg=1.5,
modality=1.0, rescale=0, 30 steps, fps=25) — override them with the individual
CLI flags (--cfg-scale, --stg-scale, --modality-scale, --rescale-scale,
--steps, --fps).
"""
import argparse
import logging
import os
import sys
import time
import traceback

import torch
import torchaudio

REPO_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
MODEL_DIR = REPO_DIR
sys.path.insert(0, os.path.join(REPO_DIR, "ltx2"))
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

DEV_FULL_CKPT = os.environ.get(
    "LTX_FULL_CHECKPOINT",
    os.path.join(REPO_DIR, "ltx-2.3-22b-dev.safetensors"),
)
GEMMA_ROOT = os.environ.get(
    "GEMMA_ROOT",
    os.path.expanduser("~/.cache/dramabox/gemma-3-12b-it-bnb-4bit"),
)


def parse_args() -> argparse.Namespace:
    """Parse the command-line flags of the validation runner.

    ``--val-config`` and ``--output-dir`` are required; every other flag has a
    default. The returned namespace is passed unchanged to
    ``WarmValidator.generate``.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--val-config", required=True)
    p.add_argument("--output-dir", required=True)
    p.add_argument("--lora", default=None)
    p.add_argument("--lora-rank", type=int, default=128)
    p.add_argument("--full-checkpoint", default=DEV_FULL_CKPT)
    p.add_argument("--gemma-root", default=GEMMA_ROOT)
    p.add_argument("--cfg-scale", type=float, default=2.5)
    p.add_argument("--stg-scale", type=float, default=1.5)
    p.add_argument("--rescale-scale", type=float, default=0.0)
    p.add_argument("--modality-scale", type=float, default=1.0)
    p.add_argument("--steps", type=int, default=30)
    p.add_argument("--fps", type=float, default=25.0)
    p.add_argument("--stg-block", type=int, default=29)
    p.add_argument("--cfg-clamp", type=float, default=0.0)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--duration-multiplier", type=float, default=1.1)
    # Match Gradio / inference_server.py DEFAULT_NEG exactly
    p.add_argument("--negative-prompt", default=(
        "worst quality, inconsistent, robotic, distorted, noise, static, "
        "muffled, unclear, unnatural, monotone"
    ))
    return p.parse_args()


def estimate_speech_duration(prompt: str, speed: float = 1.0) -> float:
    """Estimate how many seconds of audio are needed to speak *prompt*.

    Only the quoted parts of the prompt are counted when there are any
    (double-quoted spans first, otherwise single-quoted spans); with no quotes
    the whole prompt is counted. The estimate is 0.065 s per character divided
    by ``speed`` (clamped to at least 0.1) plus 1.5 s, rounded to one decimal
    and never less than 3.0 s.
    """
    import re

    # NOTE(review): the single-quote fallback also matches text between two
    # apostrophes (e.g. in contractions) — left as-is to stay in step with
    # inference.py; confirm whether that is intended.
    quoted = re.findall(r'"([^"]*)"', prompt) or re.findall(r"'([^']*)'", prompt)
    text = " ".join(quoted) if quoted else prompt
    duration = len(text) * 0.065 / max(speed, 0.1) + 1.5
    return max(3.0, round(duration, 1))


def _prepare_reference_waveform(waveform, sampling_rate, seconds=10.0, peak_db=-4.0):
    """Bring a reference waveform to a fixed-length, peak-normalised 3-D tensor.

    * 1-D and 2-D inputs get a leading dimension added so the result is 3-D;
      a size-1 second-to-last axis is repeated to size 2 (presumably
      mono -> stereo, with layout (batch, channels, samples) — confirm against
      ``decode_audio_from_file``).
    * The last axis is tiled and then cut to ``int(seconds * sampling_rate)``
      samples.
    * A non-silent signal is rescaled so its absolute peak is
      ``10 ** (peak_db / 20)``.

    Raises:
        ValueError: if the waveform has no samples along its last axis.
    """
    w = waveform
    if w.dim() == 1:
        # Treat a bare sample vector like a single-channel 2-D input.
        w = w.unsqueeze(0)
    if w.dim() == 2:
        if w.shape[0] == 1:
            w = w.repeat(2, 1)
        w = w.unsqueeze(0)
    elif w.dim() == 3 and w.shape[1] == 1:
        w = w.repeat(1, 2, 1)

    if w.shape[-1] == 0:
        # Without this guard the tiling below divides by zero.
        raise ValueError("voice reference contains no audio samples")

    target_samples = int(seconds * sampling_rate)
    if w.shape[-1] < target_samples:
        # Loop short clips until they cover the target length.
        w = w.repeat(1, 1, (target_samples // w.shape[-1]) + 1)
    w = w[..., :target_samples]

    peak = w.abs().max()
    if peak > 0:
        w = w * (10 ** (peak_db / 20) / peak)
    return w


class WarmValidator:
    """Holds every model needed for audio generation so they are loaded once.

    The constructor builds the prompt encoder, audio conditioner, audio decoder
    and the velocity model (optionally with a LoRA merged in); ``generate`` can
    then be called repeatedly without reloading anything.
    """

    def __init__(self, full_checkpoint, gemma_root, lora_path=None, lora_rank=128,
                 device="cuda", dtype=torch.bfloat16):
        """Load all sub-models onto *device*.

        Args:
            full_checkpoint: checkpoint path handed to every sub-model loader.
            gemma_root: path handed to ``PromptEncoder`` as ``gemma_root``.
            lora_path: optional LoRA safetensors file merged into the velocity
                model; ignored (with a warning) when the file does not exist.
            lora_rank: rank (and alpha) used for the LoRA adapter config.
            device: anything accepted by ``torch.device``.
            dtype: dtype passed to every sub-model loader.
        """
        from audio_conditioning import AudioConditionByReferenceLatent  # noqa: F401 (imported by inference.py)
        from ltx_core.components.patchifiers import AudioPatchifier
        from ltx_pipelines.utils.blocks import PromptEncoder, AudioConditioner, AudioDecoder

        self.device = torch.device(device)
        self.dtype = dtype
        self.full_checkpoint = full_checkpoint
        self.gemma_root = gemma_root
        self.patchifier = AudioPatchifier(patch_size=1)

        logging.info("Loading PromptEncoder (Gemma + embeddings_processor)...")
        t0 = time.time()
        self.prompt_encoder = PromptEncoder(
            checkpoint_path=full_checkpoint,
            gemma_root=gemma_root,
            dtype=dtype,
            device=self.device,
            warm=True,
            audio_only=True,
        )
        logging.info(f" PromptEncoder ready in {time.time()-t0:.1f}s")

        logging.info("Loading AudioConditioner (audio VAE encoder)...")
        t0 = time.time()
        self.audio_conditioner = AudioConditioner(
            checkpoint_path=full_checkpoint,
            dtype=dtype,
            device=self.device,
            warm=True,
        )
        logging.info(f" AudioConditioner ready in {time.time()-t0:.1f}s")

        logging.info("Loading AudioDecoder...")
        t0 = time.time()
        self.audio_decoder = AudioDecoder(
            checkpoint_path=full_checkpoint,
            dtype=dtype,
            device=self.device,
            warm=True,
        )
        logging.info(f" AudioDecoder ready in {time.time()-t0:.1f}s")

        logging.info("Building velocity model (audio-only from base dev)...")
        t0 = time.time()
        self.velocity_model = self._build_velocity_model(full_checkpoint, lora_path, lora_rank)
        logging.info(f" Velocity model ready in {time.time()-t0:.1f}s "
                     f"({sum(p.numel() for p in self.velocity_model.parameters()) / 1e9:.1f}B params)")

    def _build_velocity_model(self, checkpoint_path, lora_path, lora_rank):
        """Build the audio-only transformer and optionally merge a LoRA into it.

        Two LoRA key layouts are recognised: keys containing
        ``base_model.model.`` (PEFT) and keys containing ``diffusion_model.``
        (IC-LoRA; only the audio attention / feed-forward keys are used). After
        loading, the adapter is merged and the PEFT wrapper removed.

        A warning is logged — and the base weights are left effectively
        unchanged — when the LoRA file is missing, has an unrecognised key
        layout, or none of its keys match the model.

        Returns:
            The model in eval mode on ``self.device``.
        """
        from ltx_core.loader.registry import DummyRegistry
        from ltx_core.loader.sd_ops import SDOps
        from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder as Builder
        from ltx_core.model.model_protocol import ModelConfigurator
        from ltx_core.model.transformer.attention import AttentionFunction
        from ltx_core.model.transformer.model import LTXModel, LTXModelType
        from ltx_core.model.transformer.rope import LTXRopeType

        # Presumably selects the "model.diffusion_model." keys of the checkpoint
        # and strips that prefix — confirm against SDOps.
        sd_ops = (
            SDOps("AO")
            .with_matching(prefix="model.diffusion_model.")
            .with_replacement("model.diffusion_model.", "")
        )

        class Cfg(ModelConfigurator[LTXModel]):
            """Configurator that instantiates LTXModel in AudioOnly mode."""

            @classmethod
            def from_config(cls, config):
                t = config.get("transformer", {})
                cp = None
                if not t.get("caption_proj_before_connector", False):
                    from ltx_core.model.transformer.text_projection import create_caption_projection
                    # Created on the meta device, i.e. without allocating weights here.
                    with torch.device("meta"):
                        cp = create_caption_projection(t, audio=True)
                return LTXModel(
                    model_type=LTXModelType.AudioOnly,
                    audio_num_attention_heads=t.get("audio_num_attention_heads", 32),
                    audio_attention_head_dim=t.get("audio_attention_head_dim", 64),
                    audio_in_channels=t.get("audio_in_channels", 128),
                    audio_out_channels=t.get("audio_out_channels", 128),
                    num_layers=t.get("num_layers", 48),
                    audio_cross_attention_dim=t.get("audio_cross_attention_dim", 2048),
                    norm_eps=t.get("norm_eps", 1e-6),
                    attention_type=AttentionFunction(t.get("attention_type", "default")),
                    positional_embedding_theta=10000.0,
                    audio_positional_embedding_max_pos=[20.0],
                    timestep_scale_multiplier=t.get("timestep_scale_multiplier", 1000),
                    use_middle_indices_grid=t.get("use_middle_indices_grid", True),
                    rope_type=LTXRopeType(t.get("rope_type", "interleaved")),
                    double_precision_rope=t.get("frequencies_precision", False) == "float64",
                    apply_gated_attention=t.get("apply_gated_attention", False),
                    audio_caption_projection=cp,
                    cross_attention_adaln=t.get("cross_attention_adaln", False),
                )

        builder = Builder(
            model_path=checkpoint_path,
            model_class_configurator=Cfg,
            model_sd_ops=sd_ops,
            registry=DummyRegistry(),
        )
        velocity = builder.build(device=self.device, dtype=self.dtype).to(self.device).eval()

        if lora_path and not os.path.exists(lora_path):
            # Previously skipped silently, which made a typo in --lora look
            # like a successful LoRA validation run.
            logging.warning(f"LoRA file not found, validating the base model only: {lora_path}")

        if lora_path and os.path.exists(lora_path):
            from peft import LoraConfig, get_peft_model
            from safetensors.torch import load_file as st_load

            logging.info(f"Attaching LoRA: {lora_path}")
            lora_sd = st_load(lora_path)
            is_peft = any("base_model.model." in k for k in lora_sd)
            is_iclora = any("diffusion_model." in k for k in lora_sd)

            cfg = LoraConfig(
                r=lora_rank,
                lora_alpha=lora_rank,
                lora_dropout=0.0,
                bias="none",
                target_modules=[
                    "audio_attn1.to_k", "audio_attn1.to_q", "audio_attn1.to_v", "audio_attn1.to_out.0",
                    "audio_attn2.to_k", "audio_attn2.to_q", "audio_attn2.to_v", "audio_attn2.to_out.0",
                    "audio_ff.net.0.proj", "audio_ff.net.2",
                ],
            )
            velocity = get_peft_model(velocity, cfg)

            def load_mapped(mapped, tag):
                # strict=False: the state dict only carries adapter weights.
                _, unexpected = velocity.load_state_dict(mapped, strict=False)
                loaded = len(mapped) - len(unexpected)
                logging.info(f" Loaded {loaded} LoRA weights ({tag})")
                if loaded <= 0:
                    logging.warning(f" No LoRA weights matched the model ({tag}): {lora_path}")

            if is_peft:
                mapped = {}
                for k, v in lora_sd.items():
                    nk = k
                    # Insert the adapter name ("default") unless already present.
                    if ".lora_A.weight" in k and ".lora_A.default.weight" not in k:
                        nk = k.replace(".lora_A.weight", ".lora_A.default.weight")
                    if ".lora_B.weight" in k and ".lora_B.default.weight" not in k:
                        nk = k.replace(".lora_B.weight", ".lora_B.default.weight")
                    mapped[nk] = v
                load_mapped(mapped, "peft")
            elif is_iclora:
                mapped = {}
                for k, v in lora_sd.items():
                    # Keep only the audio branch of the adapter.
                    if not ("audio_attn1" in k or "audio_attn2" in k or "audio_ff" in k):
                        continue
                    nk = k.replace("diffusion_model.", "base_model.model.")
                    nk = nk.replace(".lora_A.weight", ".lora_A.default.weight")
                    nk = nk.replace(".lora_B.weight", ".lora_B.default.weight")
                    mapped[nk] = v
                load_mapped(mapped, "iclora")
            else:
                logging.warning(f" Unrecognised LoRA key format, no weights loaded: {lora_path}")

            velocity = velocity.merge_and_unload()
            logging.info(" Merged LoRA into base weights")

        return velocity

    @torch.inference_mode()
    def generate(self, prompt, output_path, voice_ref=None, args=None):
        """Generate audio for *prompt* and write it to *output_path*.

        Args:
            prompt: text prompt; its length also determines the clip duration
                via ``estimate_speech_duration``.
            output_path: destination audio file; parent directories are created.
            voice_ref: optional path to a reference recording used as audio
                conditioning. A missing or undecodable file is skipped with a
                warning.
            args: namespace with the inference settings produced by
                ``parse_args`` (cfg_scale, stg_scale, steps, fps, seed, ...).
                Required despite the ``None`` default.

        Raises:
            ValueError: if *args* is ``None`` or the voice reference is empty.
        """
        if args is None:
            raise ValueError("generate() requires 'args' (the namespace returned by parse_args())")

        from audio_conditioning import AudioConditionByReferenceLatent
        from ltx_core.batch_split import BatchSplitAdapter
        from ltx_core.components.diffusion_steps import EulerDiffusionStep
        from ltx_core.components.guiders import MultiModalGuider, MultiModalGuiderParams
        from ltx_core.components.noisers import GaussianNoiser
        from ltx_core.components.schedulers import LTX2Scheduler
        from ltx_core.model.audio_vae import encode_audio as vae_encode_audio
        from ltx_core.model.transformer.model import X0Model
        from ltx_core.tools import AudioLatentTools
        from ltx_core.types import Audio, AudioLatentShape, VideoPixelShape
        from ltx_pipelines.utils.denoisers import GuidedDenoiser, SimpleDenoiser
        from ltx_pipelines.utils.gpu_model import gpu_model  # noqa: F401 (deliberately unused, see NOTE below)
        from ltx_pipelines.utils.media_io import decode_audio_from_file
        from ltx_pipelines.utils.samplers import euler_denoising_loop

        t_total = time.time()

        # ---- Duration + shape ----
        gen_dur = estimate_speech_duration(prompt) * args.duration_multiplier
        raw_frames = int(round(gen_dur * args.fps)) + 1
        # Snap to the nearest frame count of the form 8*k + 1.
        num_frames = ((raw_frames - 1 + 4) // 8) * 8 + 1
        pixel_shape = VideoPixelShape(batch=1, frames=num_frames, height=64, width=64, fps=args.fps)
        tgt_shape = AudioLatentShape.from_video_pixel_shape(pixel_shape)

        audio_tools = AudioLatentTools(patchifier=self.patchifier, target_shape=tgt_shape)
        state = audio_tools.create_initial_state(self.device, self.dtype)

        # ---- Voice reference ----
        if voice_ref and not os.path.exists(voice_ref):
            logging.warning(f" Voice reference not found, generating without it: {voice_ref}")
        if voice_ref and os.path.exists(voice_ref):
            # Trailing 0.0 / 10.0 are presumably start offset and max duration
            # in seconds — confirm against decode_audio_from_file.
            voice = decode_audio_from_file(voice_ref, self.device, 0.0, 10.0)
            if voice is None:
                logging.warning(f" Voice reference could not be decoded, generating without it: {voice_ref}")
            else:
                ref_audio = Audio(
                    waveform=_prepare_reference_waveform(voice.waveform, voice.sampling_rate),
                    sampling_rate=voice.sampling_rate,
                )
                ref_latent = self.audio_conditioner(lambda enc: vae_encode_audio(ref_audio, enc, None))
                cond = AudioConditionByReferenceLatent(
                    latent=ref_latent.to(self.device, self.dtype),
                    strength=1.0,
                )
                state = cond.apply_to(latent_state=state, latent_tools=audio_tools)

        # ---- Noise ----
        gen = torch.Generator(device=self.device).manual_seed(args.seed)
        noiser = GaussianNoiser(generator=gen)
        state = noiser(state, noise_scale=1.0)

        # ---- Prompt encode ----
        # The negative prompt is only encoded when CFG is actually active.
        use_cfg = args.cfg_scale > 1.0
        prompts = [prompt, args.negative_prompt] if use_cfg else [prompt]
        ctx = self.prompt_encoder(prompts, streaming_prefetch_count=None)
        a_ctx = ctx[0].audio_encoding
        a_ctx_neg = ctx[1].audio_encoding if use_cfg else None

        # ---- Denoiser ----
        needs_guidance = args.cfg_scale > 1.0 or args.stg_scale > 0.0 or args.modality_scale > 1.0
        if needs_guidance:
            guider = MultiModalGuider(
                params=MultiModalGuiderParams(
                    cfg_scale=args.cfg_scale,
                    stg_scale=args.stg_scale,
                    stg_blocks=[args.stg_block] if args.stg_scale > 0 else [],
                    rescale_scale=args.rescale_scale,
                    modality_scale=args.modality_scale,
                    cfg_clamp_scale=args.cfg_clamp,
                ),
                negative_context=a_ctx_neg,
            )
            denoiser = GuidedDenoiser(
                v_context=None,
                a_context=a_ctx,
                video_guider=None,
                audio_guider=guider,
            )
        else:
            denoiser = SimpleDenoiser(v_context=None, a_context=a_ctx)

        sigmas = LTX2Scheduler().execute(steps=args.steps, latent=state.latent).to(self.device)

        # ---- Denoise ----
        # NOTE: don't wrap in gpu_model() — that context manager moves the
        # model back off GPU on exit, which breaks subsequent iterations of
        # our warm validator. We keep the velocity model resident.
        x0 = X0Model(self.velocity_model)
        batched = BatchSplitAdapter(x0, max_batch_size=1)
        _, audio_state = euler_denoising_loop(
            sigmas=sigmas,
            video_state=None,
            audio_state=state,
            stepper=EulerDiffusionStep(),
            transformer=batched,
            denoiser=denoiser,
        )

        # ---- Decode + save ----
        audio_state = audio_tools.clear_conditioning(audio_state)
        audio_state = audio_tools.unpatchify(audio_state)
        decoded = self.audio_decoder(audio_state.latent)

        wav = decoded.waveform
        if wav.dim() == 1:
            # Add a leading axis so a 1-D result is saved as a 2-D tensor.
            wav = wav.unsqueeze(0)
        os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
        torchaudio.save(output_path, wav.float().cpu(), decoded.sampling_rate)
        logging.info(f" -> {output_path} ({wav.shape[-1]/decoded.sampling_rate:.1f}s, "
                     f"{time.time()-t_total:.1f}s)")


def main() -> None:
    """Load the validation config, warm the models once, generate every entry.

    Each item under ``speakers`` needs ``name`` and ``prompt`` and may carry
    ``reference``. A failure in one entry is logged and counted; the remaining
    entries are still processed.
    """
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
    args = parse_args()

    import yaml

    with open(args.val_config, encoding="utf-8") as f:
        # safe_load returns None for an empty document.
        val_cfg = yaml.safe_load(f) or {}
    # "speakers:" with no value parses to None, so fall back to an empty list.
    speakers = val_cfg.get("speakers") or []
    if not speakers:
        logging.warning(f"No speakers found in {args.val_config}")

    os.makedirs(args.output_dir, exist_ok=True)

    # Build validator once (models warm for all entries).
    validator = WarmValidator(
        full_checkpoint=args.full_checkpoint,
        gemma_root=args.gemma_root,
        lora_path=args.lora,
        lora_rank=args.lora_rank,
        device="cuda" if torch.cuda.is_available() else "cpu",
        dtype=torch.bfloat16,
    )

    n_ok = n_fail = 0
    t0 = time.time()
    for idx, entry in enumerate(speakers):
        # Fallback label in case the entry has no usable "name".
        name = f"#{idx}"
        try:
            # Looked up inside the try so a malformed entry only fails itself.
            name = entry["name"]
            out_path = os.path.join(args.output_dir, f"{name}.wav")
            validator.generate(
                prompt=entry["prompt"],
                output_path=out_path,
                voice_ref=entry.get("reference"),
                args=args,
            )
        except Exception as e:
            n_fail += 1
            logging.warning(f" [{name}] FAILED: {e}")
            traceback.print_exc()
        else:
            n_ok += 1
            logging.info(f" [{name}] OK")

    logging.info(f"Validation done: ok={n_ok} fail={n_fail} in {(time.time()-t0)/60:.1f}min "
                 f"at {args.output_dir}")


if __name__ == "__main__":
    main()