#!/usr/bin/env python3
"""
Preprocess TTS datasets for LTX-2.3 audio-only LoRA fine-tuning.

Takes paired (audio, transcript) data and produces the format expected by
the LTX trainer:

    .precomputed/
    ├── latents/sample_N.pt        # Dummy video latents (minimal)
    ├── conditions/sample_N.pt     # Text embeddings from Gemma
    └── audio_latents/sample_N.pt  # Audio VAE-encoded latents

Supports multiple dataset formats:
- gemini_synthetic: index.txt with ~-separated fields
  (id~speaker~lang~sr~samples~dur~phonemes~text)
- libriheavy: index_ft.txt with ~-separated fields
  (id~speaker~lang~samples~dur~phonemes~text)
- manifest: JSON/JSONL with {"audio_filepath": ..., "text": ...}
- tsv: TSV file with audio_path<TAB>text columns

Usage:
    python preprocess_tts_data.py \
        --dataset-type gemini_synthetic \
        --index /mnt/large-datasets/gemini_synthetic_dataset/conversational_dataset_pp/index.txt \
        --audio-dir /mnt/large-datasets/gemini_synthetic_dataset/conversational_dataset_pp/wavs \
        --output-dir /mnt/persistent0/manmay/tts_training_data \
        --max-samples 10000 \
        --max-duration 20.0 \
        --min-duration 3.0
"""
import argparse
import json
import logging
import os
import sys
from pathlib import Path

import torch
import torchaudio

# This file lives one level below the repo root; the LTX packages
# (ltx_core, ltx_pipelines, ltx_trainer) become importable via the ltx2/ subdir.
REPO_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, os.path.join(REPO_DIR, "ltx2"))

MODEL_DIR = REPO_DIR
GEMMA_DIR = os.environ.get("GEMMA_DIR", "gemma-3-12b-it-qat-q4_0-unquantized")
def parse_args():
    p = argparse.ArgumentParser(description="Preprocess TTS data for LTX-2.3 fine-tuning")
    p.add_argument("--dataset-type", required=True,
                   choices=["gemini_synthetic", "libriheavy", "manifest", "tsv"],
                   help="Dataset format type")
    p.add_argument("--index", required=True, help="Path to index/manifest file")
    p.add_argument("--audio-dir", default=None,
                   help="Base directory for audio files (if paths in index are relative)")
    p.add_argument("--output-dir", required=True, help="Output directory for preprocessed data")
    p.add_argument("--checkpoint", default=os.path.join(MODEL_DIR, "ltx-2.3-22b-distilled.safetensors"))
    p.add_argument("--gemma-root", default=GEMMA_DIR)
    p.add_argument("--max-samples", type=int, default=0, help="Max samples to process (0=all)")
    p.add_argument("--max-duration", type=float, default=20.0, help="Max audio duration in seconds")
    p.add_argument("--min-duration", type=float, default=2.0, help="Min audio duration in seconds")
    p.add_argument("--batch-size", type=int, default=8,
                   help="Batch size for text encoding (currently unused; samples are encoded one at a time)")
    p.add_argument("--skip-existing", action="store_true", help="Skip already processed samples")
    p.add_argument("--audio-only-ckpt", default=None,
                   help="Audio-only checkpoint for VAE encoding (optional, uses full ckpt if not set)")
    p.add_argument("--shard", type=int, default=0, help="Shard index (for parallel processing)")
    p.add_argument("--num-shards", type=int, default=1, help="Total number of shards")
    p.add_argument("--gpu", type=int, default=None, help="GPU device index to use")
    return p.parse_args()
def parse_gemini_synthetic(index_path: str, audio_dir: str | None) -> list[dict]:
    """Parse gemini_synthetic format: id~speaker~lang~sr~samples~dur~phonemes~text"""
    samples = []
    with open(index_path) as f:
        for line in f:
            parts = line.strip().split("~")
            if len(parts) < 8:  # the format above has 8 fields
                continue
            file_id = parts[0]
            text = parts[-1]  # last field is always the text
            sr = int(parts[3])
            n_samples = int(parts[4])
            duration = n_samples / sr
            # Find the audio file
            if audio_dir:
                # Try common extensions; the for-else skips the sample if none exists
                for ext in (".flac", ".wav", ".mp3"):
                    audio_path = os.path.join(audio_dir, file_id + ext)
                    if os.path.exists(audio_path):
                        break
                else:
                    continue
            else:
                audio_path = file_id
            samples.append({
                "id": file_id,
                "audio_path": audio_path,
                "text": text,
                "duration": duration,
            })
    return samples
def parse_libriheavy(index_path: str, audio_dir: str | None) -> list[dict]:
    """Parse libriheavy format: id~speaker~lang~samples~dur~phonemes~text"""
    samples = []
    with open(index_path) as f:
        for line in f:
            parts = line.strip().split("~")
            if len(parts) < 7:
                continue
            file_id = parts[0]
            text = parts[-1]
            n_samples = int(parts[3])  # raw sample count (kept for reference)
            duration = int(parts[4]) / 1000.0  # milliseconds to seconds
            if audio_dir:
                for ext in (".flac", ".wav", ".mp3"):
                    audio_path = os.path.join(audio_dir, file_id + ext)
                    if os.path.exists(audio_path):
                        break
                else:
                    continue
            else:
                audio_path = file_id
            samples.append({
                "id": file_id,
                "audio_path": audio_path,
                "text": text,
                "duration": duration,
            })
    return samples
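# Best-effort duration helper for the manifest/TSV formats below, which may
# lack a duration field. torchaudio.info only reads the header; a read failure
# or a zero frame count is reported as 0.0, i.e. "duration unknown".
def probe_duration(audio_path: str) -> float:
    try:
        info = torchaudio.info(audio_path)
        if info.sample_rate > 0:
            return info.num_frames / info.sample_rate
    except Exception:
        pass
    return 0.0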
def parse_manifest(index_path: str, audio_dir: str | None) -> list[dict]:
    """Parse a JSON/JSONL manifest with audio_filepath and text fields."""
    with open(index_path) as f:
        content = f.read().strip()
    # Accept either a single JSON array or one JSON object per line (JSONL)
    if content.startswith("["):
        entries = json.loads(content)
    else:
        entries = [json.loads(line) for line in content.splitlines() if line.strip()]
    samples = []
    for entry in entries:
        audio_path = entry.get("audio_filepath", entry.get("audio_path", ""))
        text = entry.get("text", entry.get("transcript", ""))
        duration = entry.get("duration", 0.0)
        if audio_dir and not os.path.isabs(audio_path):
            audio_path = os.path.join(audio_dir, audio_path)
        if os.path.exists(audio_path) and text:
            samples.append({
                "id": Path(audio_path).stem,
                "audio_path": audio_path,
                "text": text,
                # Fall back to probing the header when the manifest has no duration
                "duration": duration or probe_duration(audio_path),
            })
    return samples
def parse_tsv(index_path: str, audio_dir: str | None) -> list[dict]:
    """Parse TSV file with audio_path<TAB>text."""
    samples = []
    with open(index_path) as f:
        for line in f:
            parts = line.rstrip("\n").split("\t")
            if len(parts) < 2:
                continue
            audio_path, text = parts[0], parts[1]
            if audio_dir and not os.path.isabs(audio_path):
                audio_path = os.path.join(audio_dir, audio_path)
            if os.path.exists(audio_path):
                samples.append({
                    "id": Path(audio_path).stem,
                    "audio_path": audio_path,
                    "text": text,
                    # TSV has no duration column; probe the header so the
                    # duration filter does not silently drop every row
                    "duration": probe_duration(audio_path),
                })
    return samples
PARSERS = {
    "gemini_synthetic": parse_gemini_synthetic,
    "libriheavy": parse_libriheavy,
    "manifest": parse_manifest,
    "tsv": parse_tsv,
}
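# Optional spot-check sketch: inspect one precomputed triplet by global index.
# Shapes printed are whatever this script wrote; nothing here is required by
# the trainer.
def verify_sample(output_dir: str, idx: int = 0) -> None:
    out = Path(output_dir)
    name = f"sample_{idx:06d}.pt"
    cond = torch.load(out / "conditions" / name, map_location="cpu")
    audio = torch.load(out / "audio_latents" / name, map_location="cpu")
    video = torch.load(out / "latents" / name, map_location="cpu")
    print("video_prompt_embeds:", tuple(cond["video_prompt_embeds"].shape))
    print("audio_prompt_embeds:", tuple(cond["audio_prompt_embeds"].shape))
    print("audio latents:", tuple(audio["latents"].shape), "@", audio["sample_rate"], "Hz")
    print("video latents:", tuple(video["latents"].shape), f"{video['num_frames']} frame(s)")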
def main():
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
    args = parse_args()

    # Restrict visible devices before anything can initialize CUDA
    if args.gpu is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

    # Heavy imports deferred so --help stays fast; sys.path was patched above
    from ltx_core.model.audio_vae import encode_audio as vae_encode_audio
    from ltx_core.types import Audio
    from ltx_pipelines.utils.blocks import AudioConditioner
    from ltx_pipelines.utils.media_io import decode_audio_from_file
    from ltx_trainer.model_loader import load_text_encoder, load_embeddings_processor

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.bfloat16
    # Create output directories
    out = Path(args.output_dir)
    (out / "latents").mkdir(parents=True, exist_ok=True)
    (out / "conditions").mkdir(parents=True, exist_ok=True)
    (out / "audio_latents").mkdir(parents=True, exist_ok=True)

    # Parse dataset
    logging.info(f"Parsing {args.dataset_type} dataset from {args.index}...")
    samples = PARSERS[args.dataset_type](args.index, args.audio_dir)
    logging.info(f"Found {len(samples)} samples")

    # Filter by duration; duration <= 0 means "unknown", so keep those samples
    # and rely on the max-duration clip applied at audio-decode time
    before = len(samples)
    samples = [
        s for s in samples
        if s["duration"] <= 0.0 or args.min_duration <= s["duration"] <= args.max_duration
    ]
    logging.info(
        f"After duration filter [{args.min_duration}s, {args.max_duration}s]: "
        f"{len(samples)} (dropped {before - len(samples)})"
    )

    if args.max_samples > 0:
        samples = samples[:args.max_samples]
        logging.info(f"Limiting to {len(samples)} samples")

    # Assign global indices before sharding so shard outputs never collide
    for i, s in enumerate(samples):
        s["global_idx"] = i

    # Shard the data for parallel processing
    if args.num_shards > 1:
        total = len(samples)
        samples = samples[args.shard::args.num_shards]
        logging.info(f"Shard {args.shard}/{args.num_shards}: {len(samples)} samples (of {total} total)")
    # ── Step 1: Encode text with Gemma (Blocks 1+2 only) ──
    # The trainer runs Block 3 (embeddings processor/connectors) during training,
    # so we only precompute Blocks 1+2 here (Gemma LLM + feature extractor).
    logging.info("Loading text encoder (Gemma + feature extractor)...")
    text_encoder = load_text_encoder(args.gemma_root, device=device, dtype=dtype)

    # Load feature extractor on CPU first to save GPU memory, then move to device
    logging.info("Loading feature extractor (on CPU first to save GPU memory)...")
    emb_proc = load_embeddings_processor(args.checkpoint, device="cpu", dtype=dtype)
    text_encoder.feature_extractor = emb_proc.feature_extractor.to(device)
    del emb_proc
    torch.cuda.empty_cache()

    logging.info("Encoding text prompts (Blocks 1+2: Gemma + feature extractor)...")
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        for i, sample in enumerate(samples):
            gidx = sample["global_idx"]
            cond_path = out / "conditions" / f"sample_{gidx:06d}.pt"
            if args.skip_existing and cond_path.exists():
                continue
            # Run Blocks 1+2: Gemma LLM → feature extractor
            hidden_states, attention_mask = text_encoder.encode(sample["text"])
            video_feats, audio_feats = text_encoder.feature_extractor(
                hidden_states, attention_mask, "left"
            )
            # Fall back to the video stream if no separate audio stream exists
            audio_embeds = audio_feats if audio_feats is not None else video_feats
            torch.save({
                "video_prompt_embeds": video_feats.squeeze(0).cpu(),
                "audio_prompt_embeds": audio_embeds.squeeze(0).cpu(),
                "prompt_attention_mask": attention_mask.squeeze(0).bool().cpu(),
            }, cond_path)
            if i % 100 == 0:
                logging.info(f"  Text encoding: {i}/{len(samples)}")

    del text_encoder
    torch.cuda.empty_cache()
    # ── Step 2: Encode audio with Audio VAE ──
    ckpt_for_vae = args.audio_only_ckpt or args.checkpoint
    logging.info(f"Loading audio VAE from {ckpt_for_vae}...")
    ac = AudioConditioner(checkpoint_path=ckpt_for_vae, dtype=dtype, device=device)

    logging.info("Encoding audio samples...")
    for idx, sample in enumerate(samples):
        gidx = sample["global_idx"]
        audio_path = out / "audio_latents" / f"sample_{gidx:06d}.pt"
        if args.skip_existing and audio_path.exists():
            continue
        try:
            # Load audio, clipped to at most max-duration seconds
            voice = decode_audio_from_file(sample["audio_path"], device, 0.0, args.max_duration)
            if voice is None:
                logging.warning(f"  Skipping {sample['id']}: no audio")
                continue
            # Normalize the waveform to [batch=1, channels=2, time]:
            # duplicate mono to stereo, then add a batch dim if missing
            w = voice.waveform
            if w.dim() == 2:  # [C, T]
                if w.shape[0] == 1:
                    w = w.repeat(2, 1)
                w = w.unsqueeze(0)
            elif w.dim() == 3 and w.shape[1] == 1:  # [B, 1, T]
                w = w.repeat(1, 2, 1)
            voice = Audio(waveform=w, sampling_rate=voice.sampling_rate)
            # Encode through Audio VAE
            with torch.no_grad():  # inference only
                audio_latent = ac(lambda enc: vae_encode_audio(voice, enc, None))
            # Save audio latent
            torch.save({
                "latents": audio_latent.squeeze(0).cpu(),  # [C=8, T, F=16]
                "sample_rate": 16000,
            }, audio_path)
        except Exception as e:
            logging.warning(f"  Skipping {sample['id']}: {e}")
            continue
        if idx % 100 == 0:
            logging.info(f"  Audio encoding: {idx}/{len(samples)}")

    del ac
    torch.cuda.empty_cache()
    # ── Step 3: Create dummy video latents ──
    logging.info("Creating dummy video latents...")
    # Minimal video: 1 frame, 64x64 pixels = 2x2 in latent space
    dummy_video = {
        "latents": torch.zeros(128, 1, 2, 2),
        "num_frames": 1,
        "height": 2,
        "width": 2,
        "fps": 24.0,
    }
    for sample in samples:
        gidx = sample["global_idx"]
        latent_path = out / "latents" / f"sample_{gidx:06d}.pt"
        if args.skip_existing and latent_path.exists():
            continue
        torch.save(dummy_video, latent_path)
    # ── Summary ──
    n_audio = len(list((out / "audio_latents").glob("*.pt")))
    n_cond = len(list((out / "conditions").glob("*.pt")))
    n_lat = len(list((out / "latents").glob("*.pt")))
    logging.info(f"\nDone! Output: {args.output_dir}")
    logging.info(f"  audio_latents: {n_audio} files")
    logging.info(f"  conditions: {n_cond} files")
    logging.info(f"  latents: {n_lat} files")


if __name__ == "__main__":
    main()