#!/usr/bin/env python3
"""
=============================================================
Sinhala TTS - Complete Data Extraction Pipeline
=============================================================

Processes YouTube audio into TTS-ready training data.

Pipeline steps:
  1. Download YouTube videos as audio
  2. Source separation (HTDemucs - remove background music)
  3. Audio enhancement (VoiceFixer + DeepFilterNet3)
  4. Speaker diarization (pyannote 3.1)
  5. VAD segmentation (Silero-VAD, 3-20s chunks)
  6. ASR transcription (Whisper large-v3)
  7. Quality filtering (SNR, pitch, speaking rate, speech ratio)
  8. Export as LJSpeech-format dataset

Based on:
  - Emilia-Pipe (arxiv:2407.05361) - pipeline design
  - IndicVoices-R (arxiv:2409.05356) - audio enhancement
  - IndicTTS (arxiv:2211.09536) - target training format

Requirements:
  pip install -U yt-dlp torch torchaudio demucs voicefixer deepfilternet \
                 pyannote.audio faster-whisper librosa soundfile numpy scipy \
                 tqdm pandas certifi

  (openai-whisper is used as a fallback when faster-whisper is not installed;
  simple-diarizer is used as a fallback when pyannote is unavailable.)

GPU recommended. CPU works but is 10-50x slower.

Usage:
  # Process from a video list JSON
  python scripts/data_pipeline.py --video-list tts_channel_eval/unlimited_history_videos.json

  # Process a single video
  python scripts/data_pipeline.py --video-id AJ0Ul2Wl4Pk

  # Process a folder of already-downloaded audio files
  python scripts/data_pipeline.py --audio-dir /path/to/raw_audio/

  # Resume from a checkpoint (skips completed steps)
  python scripts/data_pipeline.py --video-list videos.json --resume

  # Skip steps (e.g. if source audio is already clean)
  python scripts/data_pipeline.py --audio-dir audio/ --skip-separation --skip-enhancement
=============================================================
"""

import os
import sys
import ssl
import json
import argparse
import hashlib
import logging
import random
import shutil
import unicodedata
import warnings
from pathlib import Path
from typing import Optional, Dict, List, Tuple

import numpy as np
import torch
import torchaudio
import soundfile as sf
from tqdm import tqdm

warnings.filterwarnings("ignore")

# macOS SSL fix: point the stdlib and requests at certifi's CA bundle.
try:
    import certifi
    os.environ['SSL_CERT_FILE'] = certifi.where()
    os.environ['REQUESTS_CA_BUNDLE'] = certifi.where()
except ImportError:
    pass

# NOTE(review): this disables TLS certificate verification for every urllib
# download in the process (including torch.hub code downloads). It is kept
# because existing setups rely on it, but it is a security risk and should be
# removed once the certifi fix above is confirmed sufficient.
try:
    ssl._create_default_https_context = ssl._create_unverified_context
except AttributeError:
    pass

# ============================================================
# CONFIG
# ============================================================

SAMPLE_RATE = 22050          # FastPitch target sample rate
DIARIZE_SR = 16000           # pyannote expects 16kHz
MIN_SEGMENT_SEC = 3.0        # minimum utterance length
MAX_SEGMENT_SEC = 20.0       # maximum utterance length (IndicTTS filters >20s)
TARGET_SPEAKER = None        # set after diarization analysis; None = use dominant speaker

# Quality thresholds (IndicVoices-R + Emilia-Pipe)
SNR_THRESHOLD = 25.0         # dB
PITCH_MEAN_MAX = 350.0       # Hz
PITCH_STD_MAX = 150.0        # Hz
SPEAKING_RATE_MAX = 30.0     # chars/second
MIN_SPEECH_RATIO = 0.5       # at least 50% speech in segment
SPEECH_FLOOR_DB = -40.0      # frames this far below the loudest frame count as silence

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    datefmt='%H:%M:%S'
)
log = logging.getLogger("sinhala-tts")


# ============================================================
# STEP 0: State management (resume support)
# ============================================================

class PipelineState:
    """Track per-video, per-step completion flags in a JSON file for --resume."""

    def __init__(self, state_dir: str):
        self.state_dir = Path(state_dir)
        self.state_dir.mkdir(parents=True, exist_ok=True)
        self.state_file = self.state_dir / "pipeline_state.json"
        self.state = self._load()

    def _load(self) -> Dict:
        """Read the state file; fall back to an empty state if missing or unreadable."""
        state: Dict = {}
        if self.state_file.exists():
            try:
                with open(self.state_file, encoding="utf-8") as f:
                    loaded = json.load(f)
                if isinstance(loaded, dict):
                    state = loaded
                else:
                    log.warning(f"State file {self.state_file} is not a JSON object; starting fresh")
            except (OSError, json.JSONDecodeError) as e:
                # A truncated file (e.g. crash mid-write) must not block the run.
                log.warning(f"Could not read state file {self.state_file} ({e}); starting fresh")
        state.setdefault("completed_videos", {})
        state.setdefault("completed_steps", {})
        return state

    def save(self):
        """Persist the state atomically (write to a temp file, then rename)."""
        tmp_file = self.state_file.with_name(self.state_file.name + ".tmp")
        with open(tmp_file, "w", encoding="utf-8") as f:
            json.dump(self.state, f, indent=2)
        os.replace(tmp_file, self.state_file)

    def is_done(self, video_id: str, step: str) -> bool:
        """Return True if `step` was previously marked done for `video_id`."""
        return self.state.get("completed_videos", {}).get(video_id, {}).get(step, False)

    def mark_done(self, video_id: str, step: str):
        """Record `step` as done for `video_id` and save immediately."""
        videos = self.state.setdefault("completed_videos", {})
        videos.setdefault(video_id, {})[step] = True
        self.save()


# ============================================================
# STEP 1: Download
# ============================================================

def download_video(video_id: str, output_dir: Path) -> Optional[Path]:
    """
    Download a YouTube video as mono WAV at the target sample rate.

    Returns the WAV path, or None if the download failed or produced no WAV.
    An already-existing `<video_id>.wav` is reused without downloading.
    """
    import yt_dlp

    output_dir.mkdir(parents=True, exist_ok=True)
    wav_path = output_dir / f"{video_id}.wav"
    if wav_path.exists():
        log.info(f"  [download] {video_id} already exists, skipping")
        return wav_path

    url = f"https://www.youtube.com/watch?v={video_id}"
    dl_opts = {
        'format': 'bestaudio/best',
        'outtmpl': str(output_dir / f"{video_id}.%(ext)s"),
        'postprocessors': [{
            'key': 'FFmpegExtractAudio',
            'preferredcodec': 'wav',
        }],
        'postprocessor_args': {
            'ffmpeg': ['-ac', '1', '-ar', str(SAMPLE_RATE)],
        },
        'quiet': True,
        'no_warnings': True,
        'nocheckcertificate': True,
    }

    try:
        with yt_dlp.YoutubeDL(dl_opts) as ydl:
            ydl.download([url])
        if wav_path.exists():
            return wav_path
        # yt-dlp sometimes adds a double extension; adopt the first match.
        for candidate in sorted(output_dir.glob(f"{video_id}*.wav")):
            candidate.rename(wav_path)
            return wav_path
        log.error(f"  [download] {video_id}: download finished but no WAV was produced")
    except Exception as e:
        log.error(f"  [download] Failed {video_id}: {e}")
    return None


def download_all_videos(video_list: List[Dict], output_dir: Path) -> List[Tuple[str, Path]]:
    """Download all videos from list. Returns [(video_id, wav_path), ...]"""
    output_dir.mkdir(parents=True, exist_ok=True)
    results = []

    for i, v in enumerate(tqdm(video_list, desc="Downloading")):
        vid_id = v["id"]
        title = str(v.get("title") or vid_id)
        log.info(f"[{i+1}/{len(video_list)}] Downloading: {title[:60]}")
        wav_path = download_video(vid_id, output_dir)
        if wav_path and wav_path.exists():
            results.append((vid_id, wav_path))

    log.info(f"Downloaded {len(results)}/{len(video_list)} videos")
    return results


# ============================================================
# STEP 2: Source Separation (HTDemucs)
# ============================================================

def separate_vocals(wav_path: Path, output_dir: Path) -> Optional[Path]:
    """
    Extract vocals using HTDemucs (Meta's hybrid transformer Demucs).
    Removes background music, ambient noise, and effects.

    Based on IndicVoices-R pipeline (arxiv:2409.05356).

    Returns the path of the mono vocals WAV at SAMPLE_RATE. On any failure the
    original `wav_path` is returned so the pipeline can continue unseparated.
    """
    output_path = output_dir / f"{wav_path.stem}_vocals.wav"
    if output_path.exists():
        log.info(f"  [separation] {wav_path.stem} already done, skipping")
        return output_path

    try:
        from demucs.pretrained import get_model
        from demucs.apply import apply_model

        model = get_model("htdemucs")
        model.eval()

        # torch.backends.mps does not exist on older torch builds; probe safely
        # so a missing attribute does not silently disable separation.
        mps_backend = getattr(torch.backends, "mps", None)
        if torch.cuda.is_available():
            device = torch.device("cuda")
        elif mps_backend is not None and mps_backend.is_available():
            device = torch.device("mps")
        else:
            device = torch.device("cpu")
        model.to(device)

        waveform, sr = torchaudio.load(str(wav_path))

        # Demucs expects stereo input at its native sample rate.
        if waveform.shape[0] == 1:
            waveform = waveform.repeat(2, 1)
        if sr != model.samplerate:
            waveform = torchaudio.transforms.Resample(sr, model.samplerate)(waveform)

        # (channels, samples) -> (1, channels, samples)
        waveform = waveform.unsqueeze(0).to(device)

        with torch.no_grad():
            sources = apply_model(model, waveform, device=device)

        # sources shape: (batch, n_sources, channels, samples).
        # Look the stem up by name instead of assuming its position.
        source_names = list(getattr(model, "sources", []))
        vocals_idx = source_names.index("vocals") if "vocals" in source_names else 3
        vocals = sources[0, vocals_idx]  # (channels, samples)

        vocals_mono = vocals.mean(dim=0, keepdim=True).cpu()

        if model.samplerate != SAMPLE_RATE:
            vocals_mono = torchaudio.transforms.Resample(model.samplerate, SAMPLE_RATE)(vocals_mono)

        output_dir.mkdir(parents=True, exist_ok=True)
        torchaudio.save(str(output_path), vocals_mono, SAMPLE_RATE)

        log.info(f"  [separation] Vocals extracted: {output_path.name}")
        return output_path

    except Exception as e:
        log.error(f"  [separation] Failed: {e}")
        # Fall back to original audio
        return wav_path


# ============================================================
# STEP 3: Audio Enhancement (VoiceFixer + DeepFilterNet3)
# ============================================================

def enhance_audio(wav_path: Path, output_dir: Path) -> Optional[Path]:
    """
    Two-stage enhancement from IndicVoices-R (arxiv:2409.05356):
      1. VoiceFixer: dereverberation + bandwidth extension + denoising
      2. DeepFilterNet3: remove remaining artifacts + noise

    Each stage is best-effort: if one fails, the output of the other is used;
    if both fail, the original `wav_path` is returned.
    """
    output_path = output_dir / f"{wav_path.stem}_enhanced.wav"
    if output_path.exists():
        log.info(f"  [enhance] {wav_path.stem} already done, skipping")
        return output_path

    output_dir.mkdir(parents=True, exist_ok=True)
    current_path = wav_path
    vf_output = output_dir / f"{wav_path.stem}_vf.wav"

    # Stage 1: VoiceFixer (dereverberation + restoration)
    try:
        from voicefixer import VoiceFixer
        vf = VoiceFixer()
        vf.restore(
            input=str(current_path),
            output=str(vf_output),
            cuda=torch.cuda.is_available(),
            mode=0  # mode 0 = speech restoration (denoise + dereverb + upsample)
        )
        if vf_output.exists():
            current_path = vf_output
            log.info(f"  [enhance] VoiceFixer done")
    except Exception as e:
        log.warning(f"  [enhance] VoiceFixer failed (continuing): {e}")

    # Stage 2: DeepFilterNet3 (fine noise/artifact removal)
    try:
        from df.enhance import enhance, init_df, load_audio, save_audio

        df_model, df_state, _ = init_df()
        audio, _ = load_audio(str(current_path), sr=df_state.sr())
        enhanced = enhance(df_model, df_state, audio)
        save_audio(str(output_path), enhanced, df_state.sr())

        if output_path.exists():
            # Resample to target SR if DeepFilterNet outputs a different SR
            waveform, sr = torchaudio.load(str(output_path))
            if sr != SAMPLE_RATE:
                waveform = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(waveform)
                torchaudio.save(str(output_path), waveform, SAMPLE_RATE)
            log.info(f"  [enhance] DeepFilterNet3 done")
        elif current_path != wav_path:
            # DeepFilterNet produced nothing without raising: keep VoiceFixer output
            shutil.copy2(str(current_path), str(output_path))
    except Exception as e:
        log.warning(f"  [enhance] DeepFilterNet3 failed (continuing): {e}")
        # Use whatever we have so far
        if current_path != wav_path:
            shutil.copy2(str(current_path), str(output_path))
        else:
            return wav_path

    # Clean up intermediate VoiceFixer file
    if vf_output.exists() and output_path.exists():
        vf_output.unlink()

    return output_path if output_path.exists() else wav_path


# ============================================================
# STEP 4: Speaker Diarization (pyannote 3.1)
# ============================================================

def diarize_audio(wav_path: Path, num_speakers: int = 2,
                  hf_token: Optional[str] = None) -> Dict[str, List[Dict]]:
    """
    Run speaker diarization using pyannote/speaker-diarization-3.1.

    Returns dict: {speaker_label: [{start, end, duration}, ...], ...}

    Falls back to `_diarize_simple` when no HuggingFace token is available,
    when pyannote is not installed, or when the pyannote pipeline fails.

    NOTE: Requires accepting model licenses on HuggingFace:
      - https://huggingface.co/pyannote/segmentation-3.0
      - https://huggingface.co/pyannote/speaker-diarization-3.1
    """
    token = hf_token or os.environ.get("HF_TOKEN")
    if not token:
        log.warning("  [diarize] No HF_TOKEN found. pyannote requires auth.")
        log.warning("  [diarize] Set HF_TOKEN env var or pass --hf-token")
        # Fall back to simple-diarizer (no auth needed)
        return _diarize_simple(wav_path, num_speakers)

    try:
        # Imported inside the try so a missing pyannote install also falls back.
        from pyannote.audio import Pipeline

        pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1",
            use_auth_token=token
        )

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        pipeline.to(device)

        # Load at 16kHz for diarization
        waveform, sr = torchaudio.load(str(wav_path))
        if sr != DIARIZE_SR:
            waveform = torchaudio.transforms.Resample(sr, DIARIZE_SR)(waveform)

        diarization = pipeline(
            {"waveform": waveform, "sample_rate": DIARIZE_SR},
            num_speakers=num_speakers
        )

        speakers: Dict[str, List[Dict]] = {}
        for turn, _, speaker in diarization.itertracks(yield_label=True):
            speakers.setdefault(speaker, []).append({
                "start": round(turn.start, 3),
                "end": round(turn.end, 3),
                "duration": round(turn.end - turn.start, 3),
            })

        log.info(f"  [diarize] Found {len(speakers)} speakers (pyannote)")
        return speakers

    except Exception as e:
        log.warning(f"  [diarize] pyannote failed: {e}")
        log.info("  [diarize] Falling back to simple-diarizer (no auth)")
        return _diarize_simple(wav_path, num_speakers)


def _diarize_simple(wav_path: Path, num_speakers: int = 2) -> Dict[str, List[Dict]]:
    """
    Fallback: simple-diarizer using SpeechBrain ECAPA (no auth needed).

    If that fails as well, the whole file is returned as one "SPEAKER_0" turn.
    """
    try:
        from simple_diarizer.diarizer import Diarizer

        diar = Diarizer(embed_model='ecapa', cluster_method='sc')
        segments = diar.diarize(str(wav_path), num_speakers=num_speakers)

        speakers: Dict[str, List[Dict]] = {}
        for seg in segments:
            speakers.setdefault(str(seg['label']), []).append({
                "start": round(seg['start'], 3),
                "end": round(seg['end'], 3),
                "duration": round(seg['end'] - seg['start'], 3),
            })

        log.info(f"  [diarize] Found {len(speakers)} speakers (simple-diarizer)")
        return speakers

    except Exception as e:
        log.error(f"  [diarize] simple-diarizer also failed: {e}")
        # Last resort: treat entire audio as one speaker
        import librosa
        dur = librosa.get_duration(path=str(wav_path))
        return {"SPEAKER_0": [{"start": 0.0, "end": round(dur, 3), "duration": round(dur, 3)}]}


def select_target_speaker(speakers: Dict[str, List[Dict]],
                          target_speaker: Optional[str] = None) -> str:
    """
    Select which speaker to extract. Default: the one with most speaking time.

    Raises:
        ValueError: if `speakers` is empty (nothing to select from).
    """
    if not speakers:
        raise ValueError("select_target_speaker: no speakers to choose from")

    if target_speaker:
        if target_speaker in speakers:
            return target_speaker
        log.warning(f"  [diarize] Requested speaker {target_speaker!r} not found "
                    f"(available: {sorted(speakers)}); using dominant speaker")

    # Pick speaker with most total duration
    durations = {spk: sum(s["duration"] for s in segs) for spk, segs in speakers.items()}
    best = max(durations, key=durations.get)
    log.info(f"  [diarize] Selected speaker: {best} "
             f"({durations[best]/60:.1f} min / "
             f"{sum(durations.values())/60:.1f} min total)")
    return best


# ============================================================
# STEP 5: VAD Segmentation (Silero-VAD)
# ============================================================

def segment_with_vad(wav_path: Path, speaker_segments: List[Dict],
                     output_dir: Path) -> List[Dict]:
    """
    Fine-grained VAD segmentation within speaker turns.

    Takes diarization segments for one speaker and:
    1. Extracts audio for that speaker
    2. Runs Silero-VAD to find speech boundaries
    3. Splits long segments, merges short ones
    4. Exports individual utterance WAV files (MIN_SEGMENT_SEC-MAX_SEGMENT_SEC each),
       mono, at SAMPLE_RATE, peak-normalized to -3 dBFS

    Returns list of {path, filename, start, end, duration} for each utterance.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    vad_sr = 16000  # rate the VAD is run at

    # Load full audio as mono
    waveform, sr = torchaudio.load(str(wav_path))
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    if sr != vad_sr:
        waveform_16k = torchaudio.transforms.Resample(sr, vad_sr)(waveform)
    else:
        waveform_16k = waveform

    # Built once and reused for every utterance (was rebuilt per utterance).
    to_target_sr = torchaudio.transforms.Resample(sr, SAMPLE_RATE) if sr != SAMPLE_RATE else None

    vad_model, vad_utils = torch.hub.load(
        repo_or_dir='snakers4/silero-vad',
        model='silero_vad',
        force_reload=False,
        trust_repo=True,
    )
    get_speech_timestamps = vad_utils[0]

    min_samples_16k = int(MIN_SEGMENT_SEC * vad_sr)
    target_peak = 10 ** (-3 / 20)  # -3 dBFS
    utterances: List[Dict] = []

    for seg in speaker_segments:
        # Extract this speaker turn at the VAD rate
        start_16k = int(seg["start"] * vad_sr)
        end_16k = int(seg["end"] * vad_sr)
        seg_audio = waveform_16k[0, start_16k:end_16k]

        if len(seg_audio) < min_samples_16k:
            continue

        try:
            speech_ts = get_speech_timestamps(
                seg_audio,
                vad_model,
                sampling_rate=vad_sr,
                min_speech_duration_ms=500,
                min_silence_duration_ms=300,
                speech_pad_ms=100,
                return_seconds=False,
            )
        except Exception as e:
            # Best effort: treat the whole turn as speech, but say so.
            log.warning(f"  [vad] VAD failed on turn at {seg['start']:.1f}s ({e}); using whole turn")
            speech_ts = [{"start": 0, "end": len(seg_audio)}]

        if not speech_ts:
            continue

        for vad_seg in _merge_vad_segments(speech_ts, sr=vad_sr):
            # Duration from integer sample counts: computing it from float
            # second offsets can push an exactly-MAX-length chunk over the limit.
            duration = (vad_seg["end"] - vad_seg["start"]) / vad_sr
            if duration < MIN_SEGMENT_SEC or duration > MAX_SEGMENT_SEC:
                continue

            # Convert back to timestamps in the original audio
            vad_start_sec = seg["start"] + vad_seg["start"] / vad_sr
            vad_end_sec = vad_start_sec + duration

            start_sample = int(vad_start_sec * sr)
            end_sample = int(vad_end_sec * sr)
            utt_audio = waveform[:, start_sample:end_sample]
            if utt_audio.shape[1] == 0:
                continue

            if to_target_sr is not None:
                utt_audio = to_target_sr(utt_audio)

            # Peak normalize
            peak = utt_audio.abs().max()
            if peak > 0:
                utt_audio = utt_audio * (target_peak / peak)

            utt_name = f"{wav_path.stem}_utt{len(utterances):05d}.wav"
            utt_path = output_dir / utt_name
            torchaudio.save(str(utt_path), utt_audio, SAMPLE_RATE)

            utterances.append({
                "path": str(utt_path),
                "filename": utt_name,
                "start": round(vad_start_sec, 3),
                "end": round(vad_end_sec, 3),
                "duration": round(duration, 3),
            })

    log.info(f"  [vad] Extracted {len(utterances)} utterances "
             f"({sum(u['duration'] for u in utterances)/60:.1f} min)")
    return utterances


def _merge_vad_segments(segments: List[Dict], sr: int = 16000,
                        gap_ms: int = 500) -> List[Dict]:
    """
    Merge VAD segments that are close together, then split over-long ones.

    Neighbouring segments separated by less than `gap_ms` are joined, but only
    while the joined span stays within MAX_SEGMENT_SEC, so that cuts fall on
    detected silences rather than mid-speech. A single segment that is itself
    longer than MAX_SEGMENT_SEC is chopped into MAX_SEGMENT_SEC chunks; a
    trailing chunk shorter than MIN_SEGMENT_SEC is dropped.

    `start`/`end` are sample indices at rate `sr`.
    """
    if not segments:
        return []

    gap_samples = int(gap_ms * sr / 1000)
    max_samples = int(MAX_SEGMENT_SEC * sr)

    merged = [{"start": segments[0]["start"], "end": segments[0]["end"]}]
    for seg in segments[1:]:
        last = merged[-1]
        is_close = seg["start"] - last["end"] < gap_samples
        fits = seg["end"] - last["start"] <= max_samples
        if is_close and fits:
            last["end"] = seg["end"]
        else:
            merged.append({"start": seg["start"], "end": seg["end"]})

    final = []
    for seg in merged:
        if seg["end"] - seg["start"] <= max_samples:
            final.append(seg)
            continue
        # Single stretch of speech longer than the maximum: fixed-size chunks
        pos = seg["start"]
        while pos < seg["end"]:
            end = min(pos + max_samples, seg["end"])
            if (end - pos) / sr >= MIN_SEGMENT_SEC:
                final.append({"start": pos, "end": end})
            pos = end

    return final


# ============================================================
# STEP 6: ASR Transcription (Whisper large-v3)
# ============================================================

def transcribe_utterances(utterances: List[Dict],
                          model_size: str = "large-v3") -> List[Dict]:
    """
    Transcribe utterances using Whisper.

    Uses faster-whisper (CTranslate2 backend) if available,
    falls back to openai-whisper.

    Adds "text" and "language_prob" to each utterance dict in place and returns
    the same list. If neither backend is installed the list is returned
    unchanged. A failure on a single utterance yields text "" and
    language_prob 0.0 for that utterance.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Try faster-whisper first (2-4x faster)
    fw_model = None
    try:
        from faster_whisper import WhisperModel

        log.info(f"  [asr] Loading faster-whisper {model_size}...")
        compute_type = "float16" if device == "cuda" else "int8"
        fw_model = WhisperModel(model_size, device=device, compute_type=compute_type)
    except ImportError:
        pass

    if fw_model is not None:
        for utt in tqdm(utterances, desc="Transcribing"):
            try:
                segments, info = fw_model.transcribe(
                    utt["path"],
                    language="si",
                    beam_size=5,
                    best_of=5,
                    temperature=0.0,
                    condition_on_previous_text=False,
                    vad_filter=False,  # we already did VAD
                )
                text = " ".join(seg.text.strip() for seg in segments)
                utt["text"] = text.strip()
                utt["language_prob"] = info.language_probability
            except Exception as e:
                utt["text"] = ""
                utt["language_prob"] = 0.0
                log.warning(f"  [asr] Failed on {utt['filename']}: {e}")
        return utterances

    # Fallback: openai-whisper
    try:
        import whisper

        log.info(f"  [asr] Loading whisper {model_size}...")
        ow_model = whisper.load_model(model_size, device=device)
    except ImportError:
        log.error("  [asr] Neither faster-whisper nor openai-whisper installed!")
        log.error("  Install: pip install faster-whisper  (recommended)")
        log.error("       or: pip install openai-whisper")
        return utterances

    for utt in tqdm(utterances, desc="Transcribing"):
        try:
            result = ow_model.transcribe(
                utt["path"],
                language="si",
                beam_size=5,
                best_of=5,
                temperature=0.0,
                condition_on_previous_text=False,
                no_speech_threshold=0.6,
            )
            utt["text"] = result["text"].strip()
            # result["language"] is a plain language-code string, not a
            # probability table. The language is forced to "si", so no
            # detection is run; record 1.0 as the forced-language value.
            utt["language_prob"] = 1.0
        except Exception as e:
            utt["text"] = ""
            utt["language_prob"] = 0.0
            log.warning(f"  [asr] Failed on {utt['filename']}: {e}")

    return utterances


# ============================================================
# STEP 7: Quality Filtering
# ============================================================

def compute_snr(wav_path: str) -> float:
    """
    Compute approximate SNR (dB) using RMS energy thresholding.

    Frames at or below the 20th RMS percentile are treated as noise, the rest
    as speech. Returns 40.0 when the noise floor is effectively zero and 0.0
    when no frame rises above the noise threshold.
    """
    import librosa
    y, sr = librosa.load(wav_path, sr=SAMPLE_RATE, mono=True)
    rms = librosa.feature.rms(y=y, frame_length=2048, hop_length=512)[0]
    threshold = np.percentile(rms, 20)
    noise = rms[rms <= threshold]
    speech = rms[rms > threshold]
    if len(speech) == 0:
        # Flat energy: nothing distinguishes speech from noise. Without this
        # guard the mean of an empty array yields NaN, which passes the filter.
        return 0.0
    if len(noise) > 0 and np.mean(noise) > 1e-10:
        return float(20 * np.log10(np.mean(speech) / np.mean(noise)))
    return 40.0


def compute_pitch_stats(wav_path: str) -> Tuple[float, float]:
    """Compute pitch (F0) mean and std in Hz over voiced frames; (0.0, 0.0) if none."""
    import librosa
    y, sr = librosa.load(wav_path, sr=SAMPLE_RATE, mono=True)
    f0, _, _ = librosa.pyin(y, fmin=50, fmax=500, sr=sr)
    f0v = f0[~np.isnan(f0)]
    if len(f0v) > 0:
        return float(np.mean(f0v)), float(np.std(f0v))
    return 0.0, 0.0


def compute_speaking_rate(text: str, duration: float) -> float:
    """
    Characters per second, ignoring whitespace and basic ASCII punctuation.

    Counts Unicode code points (not grapheme clusters), so Sinhala vowel signs
    and joiners each count as one character. Returns 0.0 for non-positive
    durations.
    """
    if duration <= 0:
        return 0.0
    ignored = set("!?.,;:\"'()-")
    chars = sum(1 for c in text if c.strip() and c not in ignored)
    return chars / duration


def _speech_ratio(wav_path: str) -> float:
    """Fraction of RMS frames within SPEECH_FLOOR_DB of the loudest frame."""
    import librosa
    y, _ = librosa.load(wav_path, sr=SAMPLE_RATE, mono=True)
    rms = librosa.feature.rms(y=y, frame_length=2048, hop_length=512)[0]
    if len(rms) == 0:
        return 0.0
    loudest = float(rms.max())
    if loudest <= 0:
        return 0.0
    floor = loudest * 10 ** (SPEECH_FLOOR_DB / 20)
    return float(np.mean(rms > floor))


def filter_utterances(utterances: List[Dict]) -> Tuple[List[Dict], List[Dict]]:
    """
    Apply quality filters based on IndicVoices-R and Emilia-Pipe thresholds.

    Each utterance dict is annotated in place with the measured metrics
    (snr_db, pitch_mean_hz, pitch_std_hz, speaking_rate, speech_ratio);
    rejected ones also get "reject_reasons".

    Returns (kept, rejected) lists.
    """
    kept: List[Dict] = []
    rejected: List[Dict] = []

    for utt in tqdm(utterances, desc="Quality filtering"):
        reasons = []
        text = utt.get("text", "")

        # Skip empty transcriptions
        if not text.strip():
            reasons.append("empty_text")

        # Skip very low language probability
        if utt.get("language_prob", 0) < 0.5:
            reasons.append(f"low_lang_prob={utt.get('language_prob', 0):.2f}")

        # Duration check (should already be filtered, but double-check)
        if utt["duration"] < MIN_SEGMENT_SEC or utt["duration"] > MAX_SEGMENT_SEC:
            reasons.append(f"duration={utt['duration']:.1f}s")

        # SNR check
        try:
            snr = compute_snr(utt["path"])
            utt["snr_db"] = round(snr, 1)
            if snr < SNR_THRESHOLD:
                reasons.append(f"low_snr={snr:.1f}dB")
        except Exception:
            utt["snr_db"] = 0.0
            reasons.append("snr_failed")

        # Pitch check (detect multi-speaker leakage or non-speech).
        # A failure here is recorded as 0.0 and does not reject the utterance.
        try:
            pitch_mean, pitch_std = compute_pitch_stats(utt["path"])
            utt["pitch_mean_hz"] = round(pitch_mean, 1)
            utt["pitch_std_hz"] = round(pitch_std, 1)
            if pitch_mean > PITCH_MEAN_MAX:
                reasons.append(f"high_pitch={pitch_mean:.0f}Hz")
            if pitch_std > PITCH_STD_MAX:
                reasons.append(f"high_pitch_var={pitch_std:.0f}Hz")
        except Exception:
            utt["pitch_mean_hz"] = 0.0
            utt["pitch_std_hz"] = 0.0

        # Speaking rate check
        if text:
            rate = compute_speaking_rate(text, utt["duration"])
            utt["speaking_rate"] = round(rate, 1)
            if rate > SPEAKING_RATE_MAX:
                reasons.append(f"fast_speech={rate:.1f}c/s")
            if rate < 1.0 and utt["duration"] > 3.0:
                reasons.append(f"slow_speech={rate:.1f}c/s")

        # Speech ratio (silence-heavy segments). Measured against a floor
        # relative to the loudest frame: thresholding at a percentile of the
        # same frames would give ~0.8 for any input and never reject anything.
        try:
            speech_ratio = _speech_ratio(utt["path"])
            utt["speech_ratio"] = round(speech_ratio, 3)
            if speech_ratio < MIN_SPEECH_RATIO:
                reasons.append(f"low_speech_ratio={speech_ratio:.2f}")
        except Exception:
            utt["speech_ratio"] = 0.0

        if reasons:
            utt["reject_reasons"] = reasons
            rejected.append(utt)
        else:
            kept.append(utt)

    kept_pct = len(kept) / len(utterances) * 100 if utterances else 0.0
    log.info(f"  [filter] Kept {len(kept)}/{len(utterances)} ({kept_pct:.1f}%)")

    # Log rejection stats
    if rejected:
        reason_counts: Dict[str, int] = {}
        for utt in rejected:
            for reason in utt.get("reject_reasons", []):
                key = reason.split("=")[0]
                reason_counts[key] = reason_counts.get(key, 0) + 1
        log.info(f"  [filter] Rejection reasons: {reason_counts}")

    return kept, rejected


# ============================================================
# STEP 8: Export as LJSpeech Format
# ============================================================

def _write_lines(path: Path, lines: List[str]) -> None:
    """Write one entry per line as UTF-8 (an empty list gives an empty file)."""
    with open(path, "w", encoding="utf-8") as f:
        f.write("".join(f"{line}\n" for line in lines))


def export_dataset(utterances: List[Dict], output_dir: Path,
                   val_split: float = 0.05):
    """
    Export as LJSpeech-format dataset for Coqui-TTS FastPitch training.

    Creates:
      output_dir/
        wavs/                  - all WAV files (22050 Hz, mono)
        metadata.csv           - full dataset: filename|text|normalized_text
        metadata_train.csv
        metadata_val.csv
        dataset_stats.json     - corpus statistics

    Utterances with empty text or a missing source WAV are skipped and are not
    counted in the statistics. WAVs are always (re)copied so that a file left
    over from an earlier run can never be paired with a different transcript.

    Returns the statistics dict that is written to dataset_stats.json.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    wavs_dir = output_dir / "wavs"
    wavs_dir.mkdir(exist_ok=True)

    metadata: List[str] = []
    exported: List[Dict] = []
    for i, utt in enumerate(tqdm(utterances, desc="Exporting")):
        # "|" is the field separator and entries are one per line, so neither
        # may appear inside the text.
        text = " ".join(utt.get("text", "").replace("|", " ").split())
        if not text:
            continue

        src = Path(utt["path"])
        if not src.exists():
            log.warning(f"  [export] Missing source WAV, skipping: {src}")
            continue

        new_name = f"si_{i:06d}"
        shutil.copy2(str(src), str(wavs_dir / f"{new_name}.wav"))

        normalized = _normalize_sinhala_text(text)
        metadata.append(f"{new_name}|{text}|{normalized}")
        exported.append(utt)

    # Shuffle and split with a private RNG (does not reseed the global one)
    random.Random(42).shuffle(metadata)
    n_val = max(1, int(len(metadata) * val_split)) if len(metadata) > 1 else 0
    val_lines = metadata[:n_val]
    train_lines = metadata[n_val:]

    _write_lines(output_dir / "metadata.csv", metadata)
    _write_lines(output_dir / "metadata_train.csv", train_lines)
    _write_lines(output_dir / "metadata_val.csv", val_lines)

    # Corpus statistics over what was actually exported
    durations = [u["duration"] for u in exported]
    stats = {
        "total_utterances": len(metadata),
        "train_utterances": len(train_lines),
        "val_utterances": len(val_lines),
        "total_hours": round(sum(durations) / 3600, 2),
        "mean_duration_sec": round(float(np.mean(durations)), 2) if durations else 0.0,
        "median_duration_sec": round(float(np.median(durations)), 2) if durations else 0.0,
        "min_duration_sec": round(min(durations), 2) if durations else 0.0,
        "max_duration_sec": round(max(durations), 2) if durations else 0.0,
        "sample_rate": SAMPLE_RATE,
    }

    # Pitch stats across corpus
    pitches = [u.get("pitch_mean_hz", 0) for u in exported if u.get("pitch_mean_hz", 0) > 0]
    if pitches:
        stats["corpus_pitch_mean_hz"] = round(float(np.mean(pitches)), 1)
        stats["corpus_pitch_std_hz"] = round(float(np.std(pitches)), 1)

    with open(output_dir / "dataset_stats.json", "w", encoding="utf-8") as f:
        json.dump(stats, f, indent=2)

    if not metadata:
        log.warning("  [export] No utterances were exported; dataset is empty")

    log.info(f"\n{'='*60}")
    log.info(f"DATASET EXPORTED")
    log.info(f"{'='*60}")
    log.info(f"  Location:     {output_dir}")
    log.info(f"  Total:        {stats['total_utterances']} utterances")
    log.info(f"  Train:        {stats['train_utterances']}")
    log.info(f"  Val:          {stats['val_utterances']}")
    log.info(f"  Duration:     {stats['total_hours']} hours")
    log.info(f"  Mean length:  {stats['mean_duration_sec']}s")
    if 'corpus_pitch_mean_hz' in stats:
        log.info(f"  Pitch mean:   {stats['corpus_pitch_mean_hz']} Hz")
        log.info(f"  Pitch std:    {stats['corpus_pitch_std_hz']} Hz")
    log.info(f"{'='*60}")

    return stats


def _normalize_sinhala_text(text: str) -> str:
    """
    Basic Sinhala text normalization for TTS.

    - Unicode NFC normalization (canonical decomposition → composition)
    - Remove zero-width characters (except ZWJ which forms conjuncts)
    - Normalize punctuation
    - Collapse whitespace
    """
    # NFC normalization (critical for Brahmic scripts)
    text = unicodedata.normalize('NFC', text)

    # Remove ZWSP, ZWNJ and BOM. Keep \u200D (ZWJ) — it's part of Sinhala
    # conjunct consonants like ක්‍ෂ
    for zero_width in ('\u200B', '\u200C', '\uFEFF'):
        text = text.replace(zero_width, '')

    # Normalize curly quotation marks to ASCII. Written as escapes so the
    # characters cannot be flattened into plain quotes by an editor or tool.
    text = text.replace('\u201C', '"').replace('\u201D', '"')
    text = text.replace('\u2018', "'").replace('\u2019', "'")

    # Replace semicolons and colons with commas (IndicTTS convention)
    text = text.replace(';', ',').replace(':', ',')

    # Remove parentheses and brackets (but keep content)
    text = text.replace('(', '').replace(')', '')
    text = text.replace('[', '').replace(']', '')

    # Collapse whitespace
    text = ' '.join(text.split())

    return text.strip()


# ============================================================
# MAIN
# ============================================================

def parse_args():
    """Build the command-line interface and parse sys.argv."""
    parser = argparse.ArgumentParser(
        description="Sinhala TTS Data Pipeline",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # Input sources (pick one)
    input_group = parser.add_mutually_exclusive_group(required=True)
    input_group.add_argument("--video-list", type=str,
                             help="JSON file with video list (from list_unlimited_history.py)")
    input_group.add_argument("--video-id", type=str,
                             help="Single YouTube video ID")
    input_group.add_argument("--audio-dir", type=str,
                             help="Directory of pre-downloaded audio files")

    # Output
    parser.add_argument("--output-dir", type=str, default="sinhala_tts_dataset",
                        help="Output directory (default: sinhala_tts_dataset)")

    # Pipeline options
    parser.add_argument("--num-speakers", type=int, default=2,
                        help="Expected number of speakers per video (default: 2)")
    parser.add_argument("--target-speaker", type=str, default=None,
                        help="Speaker label to extract (default: auto-select dominant)")
    parser.add_argument("--whisper-model", type=str, default="large-v3",
                        choices=["tiny", "base", "small", "medium", "large", "large-v2", "large-v3"],
                        help="Whisper model size (default: large-v3)")
    parser.add_argument("--hf-token", type=str, default=None,
                        help="HuggingFace token for pyannote (optional)")

    # Skip options
    parser.add_argument("--skip-separation", action="store_true",
                        help="Skip source separation step")
    parser.add_argument("--skip-enhancement", action="store_true",
                        help="Skip audio enhancement step")
    parser.add_argument("--skip-diarization", action="store_true",
                        help="Skip diarization (treat all audio as one speaker)")
    parser.add_argument("--skip-transcription", action="store_true",
                        help="Skip Whisper transcription (need pre-existing transcripts)")

    # Control
    parser.add_argument("--resume", action="store_true",
                        help="Resume from checkpoint (skip completed steps)")
    parser.add_argument("--max-videos", type=int, default=None,
                        help="Process only first N videos (for testing)")
    parser.add_argument("--batch-size", type=int, default=10,
                        help="Process videos in batches of N (default: 10)")
    parser.add_argument("--use-unlimited-history-only", action="store_true",
                        help="Only use Unlimited History videos from the list")

    return parser.parse_args()


def _resolve_videos(args) -> List[Dict]:
    """Turn the chosen input option into a list of {id, title[, _local_path]} dicts."""
    if args.video_list:
        with open(args.video_list, encoding="utf-8") as f:
            data = json.load(f)
        if args.use_unlimited_history_only:
            videos = data.get("unlimited_history", [])
            log.info(f"Using {len(videos)} Unlimited History videos")
        else:
            videos = data.get("unlimited_history", []) + data.get("other", [])
            log.info(f"Using all {len(videos)} videos")
    elif args.video_id:
        videos = [{"id": args.video_id, "title": args.video_id}]
    else:
        audio_dir = Path(args.audio_dir)
        videos = [{"id": f.stem, "title": f.stem, "_local_path": str(f)}
                  for f in sorted(audio_dir.glob("*.wav"))]
        log.info(f"Found {len(videos)} audio files in {audio_dir}")

    # Applies to every input mode, not just --video-list
    if args.max_videos:
        videos = videos[:args.max_videos]
        log.info(f"Limited to first {args.max_videos} videos")

    return videos


def _load_cached_utterances(manifest: Path, need_text: bool) -> Optional[List[Dict]]:
    """
    Load a per-video utterance manifest written by a previous run.

    Returns None (meaning: reprocess the video) if the manifest is unreadable,
    an utterance WAV is missing, or transcripts are required but absent.
    """
    try:
        with open(manifest, encoding="utf-8") as f:
            cached = json.load(f)
    except (OSError, json.JSONDecodeError):
        return None
    if not isinstance(cached, list):
        return None
    for utt in cached:
        if not isinstance(utt, dict) or "path" not in utt or "duration" not in utt:
            return None
        if not Path(utt["path"]).exists():
            return None
        if need_text and "text" not in utt:
            return None
    return cached


def _process_video(v: Dict, args, output_dir: Path, state: PipelineState) -> List[Dict]:
    """Run steps 1-6 for one video and return its (transcribed) utterances."""
    vid_id = v["id"]
    raw_dir = output_dir / "raw_audio"
    separated_dir = output_dir / "separated"
    enhanced_dir = output_dir / "enhanced"
    segments_dir = output_dir / "segments"
    state_dir = output_dir / ".state"
    utt_manifest = state_dir / f"{vid_id}_utterances.json"

    # Resume fast path: a fully processed video is reloaded from its manifest
    # instead of re-running VAD and ASR.
    if args.resume and state.is_done(vid_id, "complete") and utt_manifest.exists():
        cached = _load_cached_utterances(utt_manifest, need_text=not args.skip_transcription)
        if cached is not None:
            log.info(f"  [resume] {vid_id} already complete, reusing {len(cached)} utterances")
            return cached

    # Step 1: Download (or use local file)
    if "_local_path" in v:
        wav_path = Path(v["_local_path"])
    else:
        wav_path = raw_dir / f"{vid_id}.wav"
        already = args.resume and state.is_done(vid_id, "download") and wav_path.exists()
        if not already:
            wav_path = download_video(vid_id, raw_dir)
            if wav_path:
                state.mark_done(vid_id, "download")

    if not wav_path or not wav_path.exists():
        log.warning(f"  Skipping {vid_id}: no audio")
        return []

    current_audio = wav_path

    # Step 2: Source separation
    if not args.skip_separation:
        if args.resume and state.is_done(vid_id, "separation"):
            current_audio = separated_dir / f"{vid_id}_vocals.wav"
            if not current_audio.exists():
                current_audio = wav_path
        else:
            result = separate_vocals(wav_path, separated_dir)
            if result:
                current_audio = result
                state.mark_done(vid_id, "separation")

    # Step 3: Audio enhancement
    if not args.skip_enhancement:
        if args.resume and state.is_done(vid_id, "enhancement"):
            enh_path = enhanced_dir / f"{current_audio.stem}_enhanced.wav"
            if enh_path.exists():
                current_audio = enh_path
        else:
            result = enhance_audio(current_audio, enhanced_dir)
            if result:
                current_audio = result
                state.mark_done(vid_id, "enhancement")

    # Step 4: Speaker diarization
    if not args.skip_diarization:
        diar_file = state_dir / f"{vid_id}_diarization.json"
        if args.resume and state.is_done(vid_id, "diarization") and diar_file.exists():
            with open(diar_file, encoding="utf-8") as f:
                speakers = json.load(f)
        else:
            speakers = diarize_audio(current_audio, args.num_speakers, args.hf_token)
            with open(diar_file, "w", encoding="utf-8") as f:
                json.dump(speakers, f, indent=2)
            state.mark_done(vid_id, "diarization")

        if not speakers:
            log.warning(f"  Skipping {vid_id}: diarization found no speakers")
            return []

        target = select_target_speaker(speakers, args.target_speaker)
        speaker_segments = speakers[target]
    else:
        # No diarization: use full audio
        import librosa
        dur = librosa.get_duration(path=str(current_audio))
        speaker_segments = [{"start": 0.0, "end": round(dur, 3), "duration": round(dur, 3)}]

    # Step 5: VAD segmentation
    utterances = segment_with_vad(current_audio, speaker_segments, segments_dir / vid_id)

    # Step 6: ASR transcription
    if not args.skip_transcription and utterances:
        utterances = transcribe_utterances(utterances, args.whisper_model)

    # Persist so a later --resume can skip this video entirely
    with open(utt_manifest, "w", encoding="utf-8") as f:
        json.dump(utterances, f, indent=2, ensure_ascii=False)
    state.mark_done(vid_id, "complete")

    return utterances


def main():
    """Entry point: resolve inputs, process every video, filter, and export."""
    args = parse_args()

    output_dir = Path(args.output_dir)
    dataset_dir = output_dir / "dataset"

    state = PipelineState(str(output_dir / ".state"))

    log.info("=" * 60)
    log.info("Sinhala TTS Data Pipeline")
    log.info("=" * 60)
    log.info(f"Output: {output_dir}")
    log.info(f"Device: {'CUDA' if torch.cuda.is_available() else 'CPU'}")
    if torch.cuda.is_available():
        log.info(f"GPU: {torch.cuda.get_device_name(0)}")
    log.info("")

    # ---- Resolve input ----
    videos = _resolve_videos(args)
    if not videos:
        log.error("No videos to process!")
        sys.exit(1)

    # ---- Process in batches ----
    all_utterances: List[Dict] = []
    batch_size = max(1, args.batch_size)  # a step of 0 would make range() raise
    total_batches = (len(videos) + batch_size - 1) // batch_size

    for batch_start in range(0, len(videos), batch_size):
        batch = videos[batch_start:batch_start + batch_size]
        batch_num = batch_start // batch_size + 1

        log.info(f"\n{'='*60}")
        log.info(f"BATCH {batch_num}/{total_batches} ({len(batch)} videos)")
        log.info(f"{'='*60}")

        for v in batch:
            vid_id = v["id"]
            title = str(v.get("title") or vid_id)
            log.info(f"\n--- Processing: {title[:60]} ({vid_id}) ---")

            # One bad video must not discard hours of work on the others.
            try:
                utterances = _process_video(v, args, output_dir, state)
            except Exception:
                log.exception(f"  Failed processing {vid_id}; continuing with next video")
                continue

            all_utterances.extend(utterances)
            log.info(f"  Total utterances so far: {len(all_utterances)} "
                     f"({sum(u['duration'] for u in all_utterances)/3600:.1f}h)")

    # ---- Step 7: Quality filtering ----
    log.info(f"\n{'='*60}")
    log.info(f"QUALITY FILTERING ({len(all_utterances)} utterances)")
    log.info(f"{'='*60}")

    if not all_utterances:
        log.error("No utterances extracted! Check logs above for errors.")
        return

    kept, rejected = filter_utterances(all_utterances)

    # Save rejected for inspection
    with open(output_dir / "rejected_utterances.json", "w", encoding="utf-8") as f:
        json.dump(rejected, f, indent=2, ensure_ascii=False)

    if not kept:
        log.error("All utterances were rejected by the quality filters; nothing to export. "
                  f"See {output_dir / 'rejected_utterances.json'} for the reasons.")
        return

    # ---- Step 8: Export dataset ----
    export_dataset(kept, dataset_dir)

    # Save full manifest
    with open(output_dir / "full_manifest.json", "w", encoding="utf-8") as f:
        json.dump(kept, f, indent=2, ensure_ascii=False)

    log.info(f"\n{'='*60}")
    log.info(f"PIPELINE COMPLETE")
    log.info(f"{'='*60}")
    log.info(f"  Raw utterances:     {len(all_utterances)}")
    log.info(f"  After filtering:    {len(kept)}")
    log.info(f"  Rejected:           {len(rejected)}")
    log.info(f"  Dataset:            {dataset_dir}")
    log.info(f"  Next step: python scripts/train_fastpitch.py --dataset {dataset_dir}")


if __name__ == "__main__":
    main()