#!/usr/bin/env python3
"""
Preprocess TTS datasets for LTX-2.3 audio-only LoRA fine-tuning.

Takes paired (audio, transcript) data and produces the format expected by
the LTX trainer:
    .precomputed/
    β”œβ”€β”€ latents/sample_N.pt         # Dummy video latents (minimal)
    β”œβ”€β”€ conditions/sample_N.pt      # Text embeddings from Gemma
    └── audio_latents/sample_N.pt   # Audio VAE-encoded latents

Supports multiple dataset formats:
  - gemini_synthetic: index.txt with ~-separated fields (id~speaker~lang~sr~samples~dur~phonemes~text)
  - libriheavy: index_ft.txt with ~-separated fields (id~speaker~lang~samples~dur~phonemes~text)
  - manifest: JSON/JSONL with {"audio_filepath": ..., "text": ...}
  - tsv: TSV file with audio_path<TAB>text columns
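
Example index lines (values are illustrative, not taken from a real dataset):
    gemini_synthetic: utt_000042~spk_03~en~24000~240000~10.0~HH AH L OW~Hello world.
    libriheavy:       book_0042~spk_11~en~240000~10000~HH AH L OW~Hello world.
    tsv:              wavs/utt_000042.wav<TAB>Hello world.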

Usage:
    python preprocess_tts_data.py \
        --dataset-type gemini_synthetic \
        --index /mnt/large-datasets/gemini_synthetic_dataset/conversational_dataset_pp/index.txt \
        --audio-dir /mnt/large-datasets/gemini_synthetic_dataset/conversational_dataset_pp/wavs \
        --output-dir /mnt/persistent0/manmay/tts_training_data \
        --max-samples 10000 \
        --max-duration 20.0 \
        --min-duration 3.0
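
    For parallel preprocessing, launch one interleaved shard per GPU, e.g.:
        python preprocess_tts_data.py ... --num-shards 8 --shard 0 --gpu 0 &
        python preprocess_tts_data.py ... --num-shards 8 --shard 1 --gpu 1 &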
"""

import argparse
import json
import logging
import os
import sys
from pathlib import Path

import torch

REPO_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, os.path.join(REPO_DIR, "ltx2"))  # ltx-pipelines on path via ltx2/

MODEL_DIR = REPO_DIR
GEMMA_DIR = os.environ.get("GEMMA_DIR", "gemma-3-12b-it-qat-q4_0-unquantized")


def parse_args():
    p = argparse.ArgumentParser(description="Preprocess TTS data for LTX-2.3 fine-tuning")
    p.add_argument("--dataset-type", required=True,
                   choices=["gemini_synthetic", "libriheavy", "manifest", "tsv"],
                   help="Dataset format type")
    p.add_argument("--index", required=True, help="Path to index/manifest file")
    p.add_argument("--audio-dir", default=None,
                   help="Base directory for audio files (if paths in index are relative)")
    p.add_argument("--output-dir", required=True, help="Output directory for preprocessed data")
    p.add_argument("--checkpoint", default=os.path.join(MODEL_DIR, "ltx-2.3-22b-distilled.safetensors"))
    p.add_argument("--gemma-root", default=GEMMA_DIR)
    p.add_argument("--max-samples", type=int, default=0, help="Max samples to process (0=all)")
    p.add_argument("--max-duration", type=float, default=20.0, help="Max audio duration in seconds")
    p.add_argument("--min-duration", type=float, default=2.0, help="Min audio duration in seconds")
    p.add_argument("--batch-size", type=int, default=8, help="Batch size for text encoding")
    p.add_argument("--skip-existing", action="store_true", help="Skip already processed samples")
    p.add_argument("--audio-only-ckpt", default=None,
                   help="Audio-only checkpoint for VAE encoding (optional, uses full ckpt if not set)")
    p.add_argument("--shard", type=int, default=0, help="Shard index (for parallel processing)")
    p.add_argument("--num-shards", type=int, default=1, help="Total number of shards")
    p.add_argument("--gpu", type=int, default=None, help="GPU device index to use")
    return p.parse_args()


def parse_gemini_synthetic(index_path: str, audio_dir: str | None) -> list[dict]:
    """Parse gemini_synthetic format: id~speaker~lang~sr~samples~dur~phonemes~text"""
    samples = []
    with open(index_path) as f:
        for line in f:
            parts = line.strip().split("~")
            if len(parts) < 8:  # format defines 8 fields; fewer means a malformed line
                continue
            file_id = parts[0]
            text = parts[-1]  # Last field is always the text
            sr = int(parts[3])
            n_samples = int(parts[4])
            duration = n_samples / sr

            # Find audio file
            if audio_dir:
                # Try common extensions
                for ext in [".flac", ".wav", ".mp3"]:
                    audio_path = os.path.join(audio_dir, file_id + ext)
                    if os.path.exists(audio_path):
                        break
                else:
                    continue
            else:
                audio_path = file_id

            samples.append({
                "id": file_id,
                "audio_path": audio_path,
                "text": text,
                "duration": duration,
            })
    return samples


def parse_libriheavy(index_path: str, audio_dir: str | None) -> list[dict]:
    """Parse libriheavy format: id~speaker~lang~samples~dur~phonemes~text"""
    samples = []
    with open(index_path) as f:
        for line in f:
            parts = line.strip().split("~")
            if len(parts) < 7:
                continue
            file_id = parts[0]
            text = parts[-1]
            duration = int(parts[4]) / 1000.0  # dur field is in milliseconds

            if audio_dir:
                for ext in [".flac", ".wav", ".mp3"]:
                    audio_path = os.path.join(audio_dir, file_id + ext)
                    if os.path.exists(audio_path):
                        break
                else:
                    continue
            else:
                audio_path = file_id

            samples.append({
                "id": file_id,
                "audio_path": audio_path,
                "text": text,
                "duration": duration,
            })
    return samples


def parse_manifest(index_path: str, audio_dir: str | None) -> list[dict]:
    """Parse JSON/JSONL manifest with audio_filepath and text fields."""
    samples = []
    with open(index_path) as f:
        for line in f:
            entry = json.loads(line.strip())
            audio_path = entry.get("audio_filepath", entry.get("audio_path", ""))
            text = entry.get("text", entry.get("transcript", ""))
            duration = entry.get("duration", 0.0)

            if audio_dir and not os.path.isabs(audio_path):
                audio_path = os.path.join(audio_dir, audio_path)

            if os.path.exists(audio_path) and text:
                samples.append({
                    "id": Path(audio_path).stem,
                    "audio_path": audio_path,
                    "text": text,
                    "duration": duration,
                })
    return samples


def parse_tsv(index_path: str, audio_dir: str | None) -> list[dict]:
    """Parse TSV file with audio_path<TAB>text."""
    samples = []
    with open(index_path) as f:
        for line in f:
            parts = line.strip().split("\t")
            if len(parts) < 2:
                continue
            audio_path, text = parts[0], parts[1]
            if audio_dir and not os.path.isabs(audio_path):
                audio_path = os.path.join(audio_dir, audio_path)
            if os.path.exists(audio_path):
                samples.append({
                    "id": Path(audio_path).stem,
                    "audio_path": audio_path,
                    "text": text,
                    "duration": 0.0,
                })
    return samples


PARSERS = {
    "gemini_synthetic": parse_gemini_synthetic,
    "libriheavy": parse_libriheavy,
    "manifest": parse_manifest,
    "tsv": parse_tsv,
}
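
# Each parser returns a list of {"id", "audio_path", "text", "duration"} dicts,
# so a format can be smoke-tested in isolation (paths here are placeholders):
#   rows = PARSERS["tsv"]("/tmp/index.tsv", "/tmp/wavs")
#   print(rows[0])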


@torch.inference_mode()
def main():
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
    args = parse_args()

    # Restrict visible GPUs before importing anything that might touch CUDA:
    # the env var is only honored if set before the CUDA context is created.
    if args.gpu is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

    from ltx_core.model.audio_vae import encode_audio as vae_encode_audio
    from ltx_core.types import Audio
    from ltx_pipelines.utils.blocks import AudioConditioner
    from ltx_pipelines.utils.media_io import decode_audio_from_file
    from ltx_trainer.model_loader import load_text_encoder, load_embeddings_processor

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.bfloat16

    # Create output directories
    out = Path(args.output_dir)
    (out / "latents").mkdir(parents=True, exist_ok=True)
    (out / "conditions").mkdir(parents=True, exist_ok=True)
    (out / "audio_latents").mkdir(parents=True, exist_ok=True)

    # Parse dataset
    logging.info(f"Parsing {args.dataset_type} dataset from {args.index}...")
    samples = PARSERS[args.dataset_type](args.index, args.audio_dir)
    logging.info(f"Found {len(samples)} samples")

    # Filter by duration
    before = len(samples)
    samples = [s for s in samples if args.min_duration <= s["duration"] <= args.max_duration]
    logging.info(f"After duration filter [{args.min_duration}s, {args.max_duration}s]: {len(samples)} (dropped {before - len(samples)})")

    if args.max_samples > 0:
        samples = samples[:args.max_samples]
        logging.info(f"Limiting to {len(samples)} samples")

    # Assign global indices before sharding
    for i, s in enumerate(samples):
        s["global_idx"] = i

    # Shard the data for parallel processing
    if args.num_shards > 1:
        total = len(samples)
        samples = samples[args.shard::args.num_shards]
        logging.info(f"Shard {args.shard}/{args.num_shards}: {len(samples)} samples (of {total} total)")

    # ── Step 1: Encode text with Gemma (Blocks 1+2 only) ──
    # The trainer runs Block 3 (embeddings processor/connectors) during training,
    # so we only precompute Blocks 1+2 here (Gemma LLM + feature extractor).
    logging.info("Loading text encoder (Gemma + feature extractor)...")
    text_encoder = load_text_encoder(args.gemma_root, device=device, dtype=dtype)

    # Load feature extractor on CPU first to save GPU memory, then move to device
    logging.info("Loading feature extractor (on CPU first to save GPU memory)...")
    emb_proc = load_embeddings_processor(args.checkpoint, device="cpu", dtype=dtype)
    text_encoder.feature_extractor = emb_proc.feature_extractor.to(device)
    del emb_proc
    torch.cuda.empty_cache()

    logging.info("Encoding text prompts (Blocks 1+2: Gemma + feature extractor)...")
    for i, sample in enumerate(samples):
        gidx = sample["global_idx"]
        cond_path = out / "conditions" / f"sample_{gidx:06d}.pt"
        if args.skip_existing and cond_path.exists():
            continue

        text = sample["text"]
        # Run Blocks 1+2: Gemma LLM β†’ feature extractor
        hidden_states, attention_mask = text_encoder.encode(text)
        video_feats, audio_feats = text_encoder.feature_extractor(
            hidden_states, attention_mask, "left"
        )

        if audio_feats is None:
            audio_feats = video_feats  # fall back to the shared video embeddings
        torch.save({
            "video_prompt_embeds": video_feats.squeeze(0).cpu(),
            "audio_prompt_embeds": audio_feats.squeeze(0).cpu(),
            "prompt_attention_mask": attention_mask.squeeze(0).bool().cpu(),
        }, cond_path)

        if i % 100 == 0:
            logging.info(f"  Text encoding: {i}/{len(samples)}")

    del text_encoder
    torch.cuda.empty_cache()

    # ── Step 2: Encode audio with Audio VAE ──
    ckpt_for_vae = args.audio_only_ckpt or args.checkpoint
    logging.info(f"Loading audio VAE from {ckpt_for_vae}...")

    ac = AudioConditioner(checkpoint_path=ckpt_for_vae, dtype=dtype, device=device)

    logging.info("Encoding audio samples...")
    for idx, sample in enumerate(samples):
        gidx = sample["global_idx"]
        audio_path = out / "audio_latents" / f"sample_{gidx:06d}.pt"
        if args.skip_existing and audio_path.exists():
            continue

        try:
            # Load audio
            voice = decode_audio_from_file(sample["audio_path"], device, 0.0, args.max_duration)
            if voice is None:
                logging.warning(f"  Skipping {sample['id']}: no audio")
                continue

            # Normalize the waveform to [batch, 2, time]; mono audio is
            # duplicated across channels to produce a stereo pair.
            w = voice.waveform
            if w.dim() == 2:  # [channels, time]
                if w.shape[0] == 1:
                    w = w.repeat(2, 1)  # mono -> stereo
                w = w.unsqueeze(0)  # add batch dim
            elif w.dim() == 3 and w.shape[1] == 1:  # [batch, 1, time]
                w = w.repeat(1, 2, 1)  # mono -> stereo
            voice = Audio(waveform=w, sampling_rate=voice.sampling_rate)

            # Encode through Audio VAE
            audio_latent = ac(lambda enc: vae_encode_audio(voice, enc, None))

            # Save audio latent
            torch.save({
                "latents": audio_latent.squeeze(0).cpu(),  # [C=8, T, F=16]
                "sample_rate": 16000,
            }, audio_path)

        except Exception as e:
            logging.warning(f"  Skipping {sample['id']}: {e}")
            continue

        if idx % 100 == 0:
            logging.info(f"  Audio encoding: {idx}/{len(samples)}")

    del ac
    torch.cuda.empty_cache()

    # ── Step 3: Create dummy video latents ──
    logging.info("Creating dummy video latents...")
    # Minimal placeholder video: 1 frame at 2x2 in latent space
    # (64x64 pixels, i.e. the VAE's 32x spatial downsampling), all zeros
    dummy_video = {
        "latents": torch.zeros(128, 1, 2, 2),
        "num_frames": 1,
        "height": 2,
        "width": 2,
        "fps": 24.0,
    }
    for idx, sample in enumerate(samples):
        gidx = sample["global_idx"]
        latent_path = out / "latents" / f"sample_{gidx:06d}.pt"
        if args.skip_existing and latent_path.exists():
            continue
        torch.save(dummy_video, latent_path)

    # ── Summary ──
    n_audio = len(list((out / "audio_latents").glob("*.pt")))
    n_cond = len(list((out / "conditions").glob("*.pt")))
    n_lat = len(list((out / "latents").glob("*.pt")))
    logging.info(f"\nDone! Output: {args.output_dir}")
    logging.info(f"  audio_latents: {n_audio} files")
    logging.info(f"  conditions:    {n_cond} files")
    logging.info(f"  latents:       {n_lat} files")


if __name__ == "__main__":
    main()