sinhala-tts / scripts /data_pipeline.py
outlawmold's picture
feat(macos): implement Apple Silicon optimizations and switch to wav2vec2 ASR
1a2a2b3
#!/usr/bin/env python3
"""
=============================================================
Sinhala TTS - Complete Data Extraction Pipeline
=============================================================
Processes YouTube audio into TTS-ready training data.
Pipeline steps:
1. Download YouTube videos as audio
2. Source separation (HTDemucs - remove background music)
3. Audio enhancement (VoiceFixer + DeepFilterNet3)
4. Speaker diarization (pyannote 3.1)
5. VAD segmentation (Silero-VAD, 3-20s chunks)
6. ASR transcription (Whisper large-v3)
7. Quality filtering (DNSMOS, SNR, pitch, speaking rate)
8. Export as LJSpeech-format dataset
Based on:
- Emilia-Pipe (arxiv:2407.05361) - pipeline design
- IndicVoices-R (arxiv:2409.05356) - audio enhancement
- IndicTTS (arxiv:2211.09536) - target training format
Requirements:
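pip install -U yt-dlp torch torchaudio demucs voicefixer deepfilternet \
pyannote.audio faster-whisper librosa soundfile numpy scipy \
tqdm pandas certifi
Optional fallbacks: openai-whisper (ASR; note that the PyPI package named
"whisper" is unrelated) and simple-diarizer (diarization without HF auth).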
GPU recommended. CPU works but is 10-50x slower.
Usage:
# Process from a video list JSON
python scripts/data_pipeline.py --video-list tts_channel_eval/unlimited_history_videos.json
# Process a single video
python scripts/data_pipeline.py --video-id AJ0Ul2Wl4Pk
# Process a folder of already-downloaded audio files
python scripts/data_pipeline.py --audio-dir /path/to/raw_audio/
# Resume from a checkpoint (skips completed steps)
python scripts/data_pipeline.py --video-list videos.json --resume
# Skip steps (e.g. if source audio is already clean)
python scripts/data_pipeline.py --audio-dir audio/ --skip-separation --skip-enhancement
=============================================================
"""
import os
import sys
import ssl
import json
import argparse
import hashlib
import logging
import warnings
from pathlib import Path
from typing import Optional, Dict, List, Tuple
import numpy as np
import torch
import torchaudio
import soundfile as sf
from tqdm import tqdm
warnings.filterwarnings("ignore")
# macOS SSL fix: point OpenSSL (SSL_CERT_FILE) and requests
# (REQUESTS_CA_BUNDLE) at certifi's CA bundle so HTTPS downloads can be
# verified even when the Python install has no system trust store configured.
try:
    import certifi
    os.environ['SSL_CERT_FILE'] = certifi.where()
    os.environ['REQUESTS_CA_BUNDLE'] = certifi.where()
except ImportError:
    pass
# Turning certificate verification off process-wide exposes every HTTPS
# download (including code fetched through torch.hub) to man-in-the-middle
# tampering, so it is opt-in only: set SINHALA_TTS_INSECURE_SSL=1 as a last
# resort when the certifi fix above is not enough.
if os.environ.get("SINHALA_TTS_INSECURE_SSL") == "1":
    try:
        ssl._create_default_https_context = ssl._create_unverified_context
        print("WARNING: TLS certificate verification is DISABLED "
              "(SINHALA_TTS_INSECURE_SSL=1)", file=sys.stderr)
    except AttributeError:
        pass
# ============================================================
# CONFIG
# ============================================================
# --- Audio / segmentation settings ---
SAMPLE_RATE: int = 22_050          # sample rate of every exported clip (FastPitch target)
DIARIZE_SR: int = 16_000           # rate audio is resampled to before pyannote diarization
MIN_SEGMENT_SEC: float = 3.0       # shortest utterance kept, in seconds
MAX_SEGMENT_SEC: float = 20.0      # longest utterance kept, in seconds (IndicTTS drops > 20 s)
TARGET_SPEAKER: Optional[str] = None  # explicit speaker label; None = use the dominant speaker

# --- Quality-filter thresholds (IndicVoices-R + Emilia-Pipe) ---
SNR_THRESHOLD: float = 25.0        # dB; clips with a lower estimated SNR are rejected
PITCH_MEAN_MAX: float = 350.0      # Hz; upper bound on mean voiced F0
PITCH_STD_MAX: float = 150.0       # Hz; upper bound on F0 standard deviation
SPEAKING_RATE_MAX: float = 30.0    # characters per second
MIN_SPEECH_RATIO: float = 0.5      # minimum fraction of frames classified as speech
# Root logging: timestamped, level-tagged lines at INFO and above.
logging.basicConfig(
    datefmt='%H:%M:%S',
    format='%(asctime)s [%(levelname)s] %(message)s',
    level=logging.INFO,
)
# Shared logger used by every pipeline step in this script.
log = logging.getLogger("sinhala-tts")
# ============================================================
# STEP 0: State management (resume support)
# ============================================================
class PipelineState:
    """Track per-video step completion on disk so interrupted runs can resume.

    State is stored as JSON in ``<state_dir>/pipeline_state.json`` with the
    shape ``{"completed_videos": {video_id: {step: True}}, "completed_steps": {}}``
    and is rewritten after every :meth:`mark_done` call.
    """

    def __init__(self, state_dir: str):
        self.state_dir = Path(state_dir)
        self.state_dir.mkdir(parents=True, exist_ok=True)
        self.state_file = self.state_dir / "pipeline_state.json"
        self.state = self._load()

    def _load(self) -> Dict:
        """Return the saved state, or a fresh empty state if none exists."""
        state = {"completed_videos": {}, "completed_steps": {}}
        if self.state_file.exists():
            with open(self.state_file, encoding="utf-8") as f:
                state = json.load(f)
        # Older or hand-edited files may lack the keys mark_done() writes into.
        state.setdefault("completed_videos", {})
        state.setdefault("completed_steps", {})
        return state

    def save(self):
        """Persist the state atomically.

        The JSON is written to a sibling temp file and then renamed over the
        real file, so a crash mid-write leaves the previous state intact
        instead of a truncated document that would break the next --resume.
        """
        tmp_file = self.state_file.with_name(self.state_file.name + ".tmp")
        with open(tmp_file, "w", encoding="utf-8") as f:
            json.dump(self.state, f, indent=2)
        os.replace(tmp_file, self.state_file)

    def is_done(self, video_id: str, step: str) -> bool:
        """Return True if ``step`` has been recorded as finished for ``video_id``."""
        return bool(self.state.get("completed_videos", {}).get(video_id, {}).get(step, False))

    def mark_done(self, video_id: str, step: str):
        """Record ``step`` as finished for ``video_id`` and save immediately."""
        videos = self.state.setdefault("completed_videos", {})
        videos.setdefault(video_id, {})[step] = True
        self.save()
# ============================================================
# STEP 1: Download
# ============================================================
def download_video(video_id: str, output_dir: Path) -> Optional[Path]:
    """
    Download a YouTube video's audio as a mono WAV at SAMPLE_RATE.

    The file is written to ``<output_dir>/<video_id>.wav``; an existing file
    is reused without downloading. TLS certificates are verified unless the
    environment variable SINHALA_TTS_INSECURE_SSL is set to "1".

    Returns:
        The WAV path, or None if the download failed or produced no WAV.
    """
    import yt_dlp

    wav_path = output_dir / f"{video_id}.wav"
    if wav_path.exists():
        log.info(f" [download] {video_id} already exists, skipping")
        return wav_path

    url = f"https://www.youtube.com/watch?v={video_id}"
    dl_opts = {
        'format': 'bestaudio/best',
        'outtmpl': str(output_dir / f"{video_id}.%(ext)s"),
        'postprocessors': [{
            'key': 'FFmpegExtractAudio',
            'preferredcodec': 'wav',
        }],
        # ffmpeg args: downmix to mono (-ac 1) at the target rate (-ar).
        'postprocessor_args': {
            'ffmpeg': ['-ac', '1', '-ar', str(SAMPLE_RATE)],
        },
        'quiet': True,
        'no_warnings': True,
        # Skipping certificate checks is opt-in only (man-in-the-middle risk).
        'nocheckcertificate': os.environ.get("SINHALA_TTS_INSECURE_SSL") == "1",
    }
    try:
        with yt_dlp.YoutubeDL(dl_opts) as ydl:
            ydl.download([url])
        if not wav_path.exists():
            # yt-dlp sometimes adds a double extension; adopt the first match.
            for candidate in output_dir.glob(f"{video_id}*.wav"):
                candidate.rename(wav_path)
                break
    except Exception as e:
        log.error(f" [download] Failed {video_id}: {e}")
        return None

    if wav_path.exists():
        return wav_path
    log.error(f" [download] Failed {video_id}: no WAV file produced")
    return None
def download_all_videos(video_list: List[Dict], output_dir: Path) -> List[Tuple[str, Path]]:
    """
    Download every entry of ``video_list`` into ``output_dir``.

    Each entry is a dict with an ``"id"`` key and an optional ``"title"``
    (used only for logging). Returns ``(video_id, wav_path)`` pairs, in input
    order, for the entries whose download produced an existing WAV file.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    total = len(video_list)
    downloaded: List[Tuple[str, Path]] = []
    for position, entry in enumerate(tqdm(video_list, desc="Downloading"), start=1):
        video_id = entry["id"]
        log.info(f"[{position}/{total}] Downloading: {entry.get('title', video_id)[:60]}")
        path = download_video(video_id, output_dir)
        if path is None or not path.exists():
            continue
        downloaded.append((video_id, path))
    log.info(f"Downloaded {len(downloaded)}/{total} videos")
    return downloaded
# ============================================================
# STEP 2: Source Separation (HTDemucs)
# ============================================================
def separate_vocals(wav_path: Path, output_dir: Path) -> Optional[Path]:
    """
    Extract the vocal stem with HTDemucs (Meta's hybrid transformer Demucs),
    removing background music and effects. Based on the IndicVoices-R
    pipeline (arxiv:2409.05356).

    Writes ``<output_dir>/<stem>_vocals.wav`` (mono, SAMPLE_RATE) and returns
    its path; an existing output file is reused. On any failure the error is
    logged and ``wav_path`` itself is returned so later steps can still run
    on the unseparated audio.
    """
    output_path = output_dir / f"{wav_path.stem}_vocals.wav"
    if output_path.exists():
        log.info(f" [separation] {wav_path.stem} already done, skipping")
        return output_path

    try:
        from demucs.pretrained import get_model
        from demucs.apply import apply_model

        model = get_model("htdemucs")
        model.eval()
        if torch.cuda.is_available():
            device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            device = torch.device("mps")
        else:
            device = torch.device("cpu")
        model.to(device)

        waveform, sr = torchaudio.load(str(wav_path))  # (channels, samples)
        # The model needs exactly `audio_channels` inputs: duplicate mono and
        # downmix anything with more channels than that.
        n_channels = model.audio_channels
        if waveform.shape[0] != n_channels:
            waveform = waveform.mean(dim=0, keepdim=True).repeat(n_channels, 1)
        if sr != model.samplerate:
            waveform = torchaudio.transforms.Resample(sr, model.samplerate)(waveform)

        # Mirror demucs' own separate CLI: standardize the mix before
        # separation and undo it on the stems afterwards.
        ref = waveform.mean(dim=0)
        ref_mean = ref.mean()
        ref_std = ref.std() + 1e-8  # epsilon guards silent input
        mix = ((waveform - ref_mean) / ref_std).unsqueeze(0)  # (1, channels, samples)

        # The mix stays on the CPU: apply_model runs each chunk on `device`
        # and stores the full result on the mix's device, so long recordings
        # do not have to fit in GPU/MPS memory at once.
        with torch.no_grad():
            sources = apply_model(model, mix, device=device)
        sources = sources * ref_std + ref_mean  # (batch, n_sources, channels, samples)

        # Look the stem up by name instead of assuming its position.
        vocals = sources[0, model.sources.index("vocals")]
        vocals_mono = vocals.mean(dim=0, keepdim=True)
        if model.samplerate != SAMPLE_RATE:
            vocals_mono = torchaudio.transforms.Resample(model.samplerate, SAMPLE_RATE)(vocals_mono)

        output_dir.mkdir(parents=True, exist_ok=True)
        torchaudio.save(str(output_path), vocals_mono.cpu(), SAMPLE_RATE)
        log.info(f" [separation] Vocals extracted: {output_path.name}")
        return output_path
    except Exception as e:
        log.error(f" [separation] Failed: {e}")
        # Fall back to the original, unseparated audio.
        return wav_path
# ============================================================
# STEP 3: Audio Enhancement (VoiceFixer + DeepFilterNet3)
# ============================================================
def enhance_audio(wav_path: Path, output_dir: Path) -> Optional[Path]:
    """
    Clean up speech with the two-stage chain from IndicVoices-R (arxiv:2409.05356).

    Stage 1 runs VoiceFixer's ``restore`` on the input; stage 2 runs
    DeepFilterNet3 on whatever stage 1 produced (or on the input itself if
    stage 1 produced nothing). Either stage may fail independently: failures
    are logged and the best audio obtained so far is kept.

    Returns:
        ``<output_dir>/<stem>_enhanced.wav`` (resampled to SAMPLE_RATE when
        DeepFilterNet3 ran) if any enhanced audio was produced, otherwise
        ``wav_path`` unchanged.
    """
    import shutil

    enhanced_path = output_dir / f"{wav_path.stem}_enhanced.wav"
    if enhanced_path.exists():
        log.info(f" [enhance] {wav_path.stem} already done, skipping")
        return enhanced_path
    output_dir.mkdir(parents=True, exist_ok=True)

    voicefixer_path = output_dir / f"{wav_path.stem}_vf.wav"
    stage2_input = wav_path  # switched to the VoiceFixer output if stage 1 succeeds

    # Stage 1: VoiceFixer restoration.
    try:
        from voicefixer import VoiceFixer
        fixer = VoiceFixer()
        fixer.restore(
            input=str(stage2_input),
            output=str(voicefixer_path),
            cuda=torch.cuda.is_available(),
            mode=0,  # base restoration model (see the VoiceFixer README for modes 1-2)
        )
        if voicefixer_path.exists():
            stage2_input = voicefixer_path
            log.info(" [enhance] VoiceFixer done")
    except Exception as err:
        log.warning(f" [enhance] VoiceFixer failed (continuing): {err}")

    have_stage1 = stage2_input != wav_path

    # Stage 2: DeepFilterNet3 on top of stage 1's result.
    try:
        from df.enhance import enhance, init_df, load_audio, save_audio
        df_model, df_state, _ = init_df()
        noisy, _ = load_audio(str(stage2_input), sr=df_state.sr())
        cleaned = enhance(df_model, df_state, noisy)
        save_audio(str(enhanced_path), cleaned, df_state.sr())
        if enhanced_path.exists():
            # DeepFilterNet writes at its own rate; bring it to SAMPLE_RATE.
            audio, rate = torchaudio.load(str(enhanced_path))
            if rate != SAMPLE_RATE:
                audio = torchaudio.transforms.Resample(rate, SAMPLE_RATE)(audio)
                torchaudio.save(str(enhanced_path), audio, SAMPLE_RATE)
            log.info(" [enhance] DeepFilterNet3 done")
        elif have_stage1:
            # DeepFilterNet wrote nothing: keep the VoiceFixer result instead.
            shutil.copy2(str(stage2_input), str(enhanced_path))
    except Exception as err:
        log.warning(f" [enhance] DeepFilterNet3 failed (continuing): {err}")
        if not have_stage1:
            return wav_path
        shutil.copy2(str(stage2_input), str(enhanced_path))

    # The intermediate VoiceFixer file is only needed until the final file exists.
    if voicefixer_path.exists() and enhanced_path.exists():
        voicefixer_path.unlink()
    return enhanced_path if enhanced_path.exists() else wav_path
# ============================================================
# STEP 4: Speaker Diarization (pyannote 3.1)
# ============================================================
def diarize_audio(wav_path: Path, num_speakers: int = 2,
                  hf_token: Optional[str] = None) -> Dict[str, List[Dict]]:
    """
    Run speaker diarization with pyannote/speaker-diarization-3.1.

    Falls back to ``_diarize_simple`` (no HuggingFace auth needed) when no
    token is available (``hf_token`` or the ``HF_TOKEN`` env var), when
    pyannote is not installed, or when the pyannote pipeline fails.

    Returns:
        ``{speaker_label: [{"start", "end", "duration"}, ...], ...}`` with
        times in seconds rounded to milliseconds.

    NOTE: Requires accepting model licenses on HuggingFace:
      - https://huggingface.co/pyannote/segmentation-3.0
      - https://huggingface.co/pyannote/speaker-diarization-3.1
    """
    token = hf_token or os.environ.get("HF_TOKEN")
    if not token:
        log.warning(" [diarize] No HF_TOKEN found. pyannote requires auth.")
        log.warning(" [diarize] Set HF_TOKEN env var or pass --hf-token")
        return _diarize_simple(wav_path, num_speakers)

    try:
        # Imported inside the try so a missing or broken pyannote install
        # degrades to the fallback diarizer instead of aborting the run.
        from pyannote.audio import Pipeline

        pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1",
            use_auth_token=token,
        )
        if pipeline is None:
            # pyannote 3.x returns None (instead of raising) when the gated
            # model cannot be fetched with this token.
            raise RuntimeError("could not load pyannote/speaker-diarization-3.1 "
                               "(check HF_TOKEN and accepted model licenses)")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        pipeline.to(device)

        # pyannote expects 16 kHz input.
        waveform, sr = torchaudio.load(str(wav_path))
        if sr != DIARIZE_SR:
            waveform = torchaudio.transforms.Resample(sr, DIARIZE_SR)(waveform)

        diarization = pipeline(
            {"waveform": waveform, "sample_rate": DIARIZE_SR},
            num_speakers=num_speakers,
        )
        speakers: Dict[str, List[Dict]] = {}
        for turn, _, speaker in diarization.itertracks(yield_label=True):
            speakers.setdefault(speaker, []).append({
                "start": round(turn.start, 3),
                "end": round(turn.end, 3),
                "duration": round(turn.end - turn.start, 3),
            })
        log.info(f" [diarize] Found {len(speakers)} speakers (pyannote)")
        return speakers
    except Exception as e:
        log.warning(f" [diarize] pyannote failed: {e}")
        log.info(" [diarize] Falling back to simple-diarizer (no auth)")
        return _diarize_simple(wav_path, num_speakers)
def _diarize_simple(wav_path: Path, num_speakers: int = 2) -> Dict[str, List[Dict]]:
    """
    Fallback diarization with simple-diarizer (SpeechBrain ECAPA embeddings
    plus spectral clustering); needs no HuggingFace auth.

    Returns the same ``{label: [{"start", "end", "duration"}, ...]}`` shape
    as ``diarize_audio``. If simple-diarizer fails, the whole file is returned
    as a single ``"SPEAKER_0"`` turn.
    """
    try:
        from simple_diarizer.diarizer import Diarizer
        diarizer = Diarizer(embed_model='ecapa', cluster_method='sc')
        turns = diarizer.diarize(str(wav_path), num_speakers=num_speakers)
        by_speaker: Dict[str, List[Dict]] = {}
        for turn in turns:
            begin, finish = turn['start'], turn['end']
            by_speaker.setdefault(str(turn['label']), []).append({
                "start": round(begin, 3),
                "end": round(finish, 3),
                "duration": round(finish - begin, 3),
            })
        log.info(f" [diarize] Found {len(by_speaker)} speakers (simple-diarizer)")
        return by_speaker
    except Exception as e:
        log.error(f" [diarize] simple-diarizer also failed: {e}")

    # Last resort: attribute the entire recording to one speaker.
    import librosa
    total = round(librosa.get_duration(path=str(wav_path)), 3)
    return {"SPEAKER_0": [{"start": 0.0, "end": total, "duration": total}]}
def select_target_speaker(speakers: Dict[str, List[Dict]],
                          target_speaker: Optional[str] = None) -> str:
    """
    Choose which diarized speaker's audio to extract.

    Args:
        speakers: Diarization output, ``{label: [{"duration": secs, ...}, ...]}``.
        target_speaker: Explicit label to use when it is a key of ``speakers``.

    Returns:
        ``target_speaker`` if present in ``speakers``; otherwise the label
        with the largest total ``duration`` (the first one wins on ties).

    Raises:
        ValueError: if ``speakers`` is empty (diarization found no speech).
    """
    if target_speaker and target_speaker in speakers:
        return target_speaker
    if not speakers:
        raise ValueError("select_target_speaker: diarization returned no speakers")
    if target_speaker:
        log.warning(f" [diarize] Requested speaker {target_speaker!r} not found; "
                    f"available: {sorted(speakers)}. Using the dominant speaker.")

    durations = {spk: sum(s["duration"] for s in segs) for spk, segs in speakers.items()}
    best = max(durations, key=durations.get)
    log.info(f" [diarize] Selected speaker: {best} "
             f"({durations[best]/60:.1f} min / "
             f"{sum(durations.values())/60:.1f} min total)")
    return best
# ============================================================
# STEP 5: VAD Segmentation (Silero-VAD)
# ============================================================
def segment_with_vad(wav_path: Path, speaker_segments: List[Dict],
                     output_dir: Path) -> List[Dict]:
    """
    Cut one speaker's diarization turns into utterance WAV files with Silero-VAD.

    For each turn in ``speaker_segments`` (dicts with ``start``/``end`` in
    seconds) this:
      1. skips turns shorter than MIN_SEGMENT_SEC,
      2. runs Silero-VAD on the turn at 16 kHz (falling back to the whole
         turn if VAD raises),
      3. merges/splits the speech regions via ``_merge_vad_segments``,
      4. keeps regions between MIN_SEGMENT_SEC and MAX_SEGMENT_SEC,
         peak-normalizes them to -3 dBFS and writes them to ``output_dir``
         at SAMPLE_RATE.

    Returns:
        A list of ``{path, filename, start, end, duration}`` dicts, with
        ``start``/``end`` in seconds from the beginning of ``wav_path``.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    vad_sr = 16000  # Silero-VAD operates at 16 kHz

    waveform, sr = torchaudio.load(str(wav_path))
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)  # downmix to mono
    if sr != vad_sr:
        waveform_16k = torchaudio.transforms.Resample(sr, vad_sr)(waveform)
    else:
        waveform_16k = waveform
    # Built once and reused for every utterance (its kernel depends only on
    # the two rates).
    to_target_sr = (torchaudio.transforms.Resample(sr, SAMPLE_RATE)
                    if sr != SAMPLE_RATE else None)
    target_peak = 10 ** (-3 / 20)  # -3 dBFS

    vad_model, vad_utils = torch.hub.load(
        repo_or_dir='snakers4/silero-vad',
        model='silero_vad',
        force_reload=False,
        trust_repo=True,
    )
    get_speech_timestamps = vad_utils[0]

    utterances = []
    for seg in speaker_segments:
        seg_audio = waveform_16k[0, int(seg["start"] * vad_sr):int(seg["end"] * vad_sr)]
        if len(seg_audio) < int(MIN_SEGMENT_SEC * vad_sr):
            continue

        try:
            speech_ts = get_speech_timestamps(
                seg_audio,
                vad_model,
                sampling_rate=vad_sr,
                min_speech_duration_ms=500,
                min_silence_duration_ms=300,
                speech_pad_ms=100,
                return_seconds=False,
            )
        except Exception:
            # Best effort: if VAD fails, treat the whole turn as speech.
            speech_ts = [{"start": 0, "end": len(seg_audio)}]
        if not speech_ts:
            continue

        # VAD offsets are in 16 kHz samples relative to the turn start.
        for vad_seg in _merge_vad_segments(speech_ts, sr=vad_sr):
            vad_start_sec = seg["start"] + vad_seg["start"] / vad_sr
            vad_end_sec = seg["start"] + vad_seg["end"] / vad_sr
            duration = vad_end_sec - vad_start_sec
            if duration < MIN_SEGMENT_SEC or duration > MAX_SEGMENT_SEC:
                continue

            # Cut from the full-rate mono audio, then convert to SAMPLE_RATE.
            utt_audio = waveform[:, int(vad_start_sec * sr):int(vad_end_sec * sr)]
            if to_target_sr is not None:
                utt_audio = to_target_sr(utt_audio)

            peak = utt_audio.abs().max()
            if peak > 0:
                utt_audio = utt_audio * (target_peak / peak)

            utt_name = f"{wav_path.stem}_utt{len(utterances):05d}.wav"
            utt_path = output_dir / utt_name
            torchaudio.save(str(utt_path), utt_audio, SAMPLE_RATE)
            utterances.append({
                "path": str(utt_path),
                "filename": utt_name,
                "start": round(vad_start_sec, 3),
                "end": round(vad_end_sec, 3),
                "duration": round(duration, 3),
            })

    log.info(f" [vad] Extracted {len(utterances)} utterances "
             f"({sum(u['duration'] for u in utterances)/60:.1f} min)")
    return utterances
def _merge_vad_segments(segments: List[Dict], sr: int = 16000,
gap_ms: int = 500) -> List[Dict]:
"""Merge VAD segments that are close together."""
if not segments:
return []
gap_samples = int(gap_ms * sr / 1000)
merged = [{"start": segments[0]["start"], "end": segments[0]["end"]}]
for seg in segments[1:]:
if seg["start"] - merged[-1]["end"] < gap_samples:
merged[-1]["end"] = seg["end"]
else:
merged.append({"start": seg["start"], "end": seg["end"]})
# Split segments that are too long
final = []
for seg in merged:
duration_sec = (seg["end"] - seg["start"]) / sr
if duration_sec > MAX_SEGMENT_SEC:
# Split into chunks at MAX_SEGMENT_SEC boundaries
chunk_samples = int(MAX_SEGMENT_SEC * sr)
pos = seg["start"]
while pos < seg["end"]:
end = min(pos + chunk_samples, seg["end"])
if (end - pos) / sr >= MIN_SEGMENT_SEC:
final.append({"start": pos, "end": end})
pos = end
else:
final.append(seg)
return final
# ============================================================
# STEP 6: ASR Transcription (Whisper large-v3)
# ============================================================
def transcribe_utterances(utterances: List[Dict],
                          model_size: str = "large-v3") -> List[Dict]:
    """
    Transcribe utterance WAVs with Whisper, forcing the language to Sinhala ("si").

    Uses faster-whisper (CTranslate2 backend) when installed, otherwise
    openai-whisper. Each utterance dict is updated in place with ``"text"``
    and ``"language_prob"``; a clip that fails to decode gets ``""`` and
    ``0.0``. If neither backend is installed an error is logged and the
    utterances are returned untouched.

    Returns:
        The same ``utterances`` list.
    """
    # Keep the ImportError handling around the import only, so an ImportError
    # raised while loading or running the model is not mistaken for
    # "faster-whisper is not installed".
    try:
        from faster_whisper import WhisperModel
    except ImportError:
        WhisperModel = None

    if WhisperModel is not None:
        log.info(f" [asr] Loading faster-whisper {model_size}...")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        compute_type = "float16" if device == "cuda" else "int8"
        model = WhisperModel(model_size, device=device, compute_type=compute_type)
        for utt in tqdm(utterances, desc="Transcribing"):
            try:
                segments, info = model.transcribe(
                    utt["path"],
                    language="si",
                    beam_size=5,
                    best_of=5,
                    temperature=0.0,
                    condition_on_previous_text=False,
                    vad_filter=False,  # VAD was already applied upstream
                )
                # `segments` is lazy; decoding actually happens in this join.
                text = " ".join(seg.text.strip() for seg in segments)
                utt["text"] = text.strip()
                utt["language_prob"] = info.language_probability
            except Exception as e:
                utt["text"] = ""
                utt["language_prob"] = 0.0
                log.warning(f" [asr] Failed on {utt['filename']}: {e}")
        return utterances

    try:
        import whisper
    except ImportError:
        log.error(" [asr] Neither faster-whisper nor openai-whisper installed!")
        log.error(" Install: pip install faster-whisper (recommended)")
        log.error(" or: pip install openai-whisper")
        return utterances

    log.info(f" [asr] Loading whisper {model_size}...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = whisper.load_model(model_size, device=device)
    for utt in tqdm(utterances, desc="Transcribing"):
        try:
            result = model.transcribe(
                utt["path"],
                language="si",
                beam_size=5,
                best_of=5,
                temperature=0.0,
                condition_on_previous_text=False,
                no_speech_threshold=0.6,
            )
            utt["text"] = result["text"].strip()
            # openai-whisper reports the language as a code string, not a
            # {lang: prob} mapping, so there is no probability to read. With
            # the language forced to "si", report 1.0 -- the value
            # faster-whisper reports when the language is given explicitly.
            utt["language_prob"] = 1.0 if result.get("language") == "si" else 0.0
        except Exception as e:
            utt["text"] = ""
            utt["language_prob"] = 0.0
            log.warning(f" [asr] Failed on {utt['filename']}: {e}")
    return utterances
# ============================================================
# STEP 7: Quality Filtering
# ============================================================
def compute_snr(wav_path: str) -> float:
    """
    Estimate a clip's SNR in dB from frame RMS energies.

    Frames at or below the 20th-percentile RMS are treated as noise and the
    rest as speech; the result is ``20 * log10(mean speech RMS / mean noise RMS)``.
    Returns 40.0 when the noise floor is effectively zero, and 0.0 when no
    frame rises above it (silent or constant-level audio). Without that
    guard such clips produced NaN, which never compares below SNR_THRESHOLD.
    """
    import librosa
    y, _ = librosa.load(wav_path, sr=SAMPLE_RATE, mono=True)
    rms = librosa.feature.rms(y=y, frame_length=2048, hop_length=512)[0]
    threshold = np.percentile(rms, 20)
    noise = rms[rms <= threshold]  # never empty: the minimum is <= any percentile
    speech = rms[rms > threshold]
    if speech.size == 0 or np.mean(speech) <= 1e-10:
        return 0.0
    noise_level = np.mean(noise)
    if noise_level > 1e-10:
        return float(20 * np.log10(np.mean(speech) / noise_level))
    return 40.0
def compute_pitch_stats(wav_path: str) -> Tuple[float, float]:
    """
    Return ``(mean, std)`` of a clip's voiced fundamental frequency, in Hz.

    The clip is loaded mono at SAMPLE_RATE and tracked with librosa's pYIN
    over 50-500 Hz. Frames with a NaN estimate are ignored, and ``(0.0, 0.0)``
    is returned when no frame has one.
    """
    import librosa
    signal, rate = librosa.load(wav_path, sr=SAMPLE_RATE, mono=True)
    track = librosa.pyin(signal, fmin=50, fmax=500, sr=rate)[0]
    voiced = track[np.logical_not(np.isnan(track))]
    if not len(voiced):
        return 0.0, 0.0
    return float(voiced.mean()), float(voiced.std())
def compute_speaking_rate(text: str, duration: float) -> float:
    """
    Approximate speaking rate in grapheme clusters per second.

    Only base characters are counted. Whitespace, punctuation (any Unicode
    ``P*`` category), combining marks (``M*``, e.g. Sinhala vowel signs and
    the al-lakuna) and format characters (``Cf``, e.g. ZWJ) are skipped. This
    way each Sinhala grapheme cluster, which approximates a syllable, counts
    once rather than once per code point. Returns 0.0 when ``duration`` is
    not positive.
    """
    import unicodedata

    if duration <= 0:
        return 0.0
    count = 0
    for ch in text:
        if ch.isspace():
            continue
        category = unicodedata.category(ch)
        if category[0] in ("P", "M") or category == "Cf":
            continue
        count += 1
    return count / duration
def filter_utterances(utterances: List[Dict]) -> Tuple[List[Dict], List[Dict]]:
    """
    Apply quality filters based on IndicVoices-R and Emilia-Pipe thresholds.

    An utterance is rejected for any of: empty text; language_prob < 0.5;
    duration outside [MIN_SEGMENT_SEC, MAX_SEGMENT_SEC]; SNR below
    SNR_THRESHOLD (or SNR computation failing); pitch mean/std above
    PITCH_MEAN_MAX / PITCH_STD_MAX; speaking rate above SPEAKING_RATE_MAX, or
    below 1.0 for clips longer than 3 s; speech ratio below MIN_SPEECH_RATIO.

    Measured values are stored on each dict (``snr_db``, ``pitch_mean_hz``,
    ``pitch_std_hz``, ``speaking_rate`` when text is present, ``speech_ratio``);
    rejected dicts also get ``reject_reasons``.

    Returns:
        ``(kept, rejected)`` lists of the original dict objects.
    """
    import librosa

    # Speech-ratio heuristic: a frame counts as speech when its RMS is within
    # 40 dB of the clip's loudest frame. Tune together with MIN_SPEECH_RATIO.
    speech_floor = 10 ** (-40.0 / 20)

    kept: List[Dict] = []
    rejected: List[Dict] = []
    for utt in tqdm(utterances, desc="Quality filtering"):
        reasons = []

        if not utt.get("text", "").strip():
            reasons.append("empty_text")
        if utt.get("language_prob", 0) < 0.5:
            reasons.append(f"low_lang_prob={utt.get('language_prob', 0):.2f}")
        # Duration is normally enforced during segmentation; double-check.
        if utt["duration"] < MIN_SEGMENT_SEC or utt["duration"] > MAX_SEGMENT_SEC:
            reasons.append(f"duration={utt['duration']:.1f}s")

        try:
            snr = compute_snr(utt["path"])
            utt["snr_db"] = round(snr, 1)
            if snr < SNR_THRESHOLD:
                reasons.append(f"low_snr={snr:.1f}dB")
        except Exception:
            utt["snr_db"] = 0.0
            reasons.append("snr_failed")

        # Pitch outliers hint at leakage from another speaker or non-speech.
        try:
            pitch_mean, pitch_std = compute_pitch_stats(utt["path"])
            utt["pitch_mean_hz"] = round(pitch_mean, 1)
            utt["pitch_std_hz"] = round(pitch_std, 1)
            if pitch_mean > PITCH_MEAN_MAX:
                reasons.append(f"high_pitch={pitch_mean:.0f}Hz")
            if pitch_std > PITCH_STD_MAX:
                reasons.append(f"high_pitch_var={pitch_std:.0f}Hz")
        except Exception:
            utt["pitch_mean_hz"] = 0.0
            utt["pitch_std_hz"] = 0.0

        if utt.get("text"):
            rate = compute_speaking_rate(utt["text"], utt["duration"])
            utt["speaking_rate"] = round(rate, 1)
            if rate > SPEAKING_RATE_MAX:
                reasons.append(f"fast_speech={rate:.1f}c/s")
            if rate < 1.0 and utt["duration"] > 3.0:
                reasons.append(f"slow_speech={rate:.1f}c/s")

        # Speech ratio (catches silence-heavy segments). The threshold is
        # relative to the loudest frame: a percentile threshold labels a
        # fixed share of frames as speech by construction, so the ratio could
        # never fall below MIN_SPEECH_RATIO. An all-silent clip scores 0.
        try:
            y, _ = librosa.load(utt["path"], sr=SAMPLE_RATE, mono=True)
            rms = librosa.feature.rms(y=y, frame_length=2048, hop_length=512)[0]
            speech_ratio = float(np.mean(rms > rms.max() * speech_floor))
            utt["speech_ratio"] = round(speech_ratio, 3)
            if speech_ratio < MIN_SPEECH_RATIO:
                reasons.append(f"low_speech_ratio={speech_ratio:.2f}")
        except Exception:
            utt["speech_ratio"] = 0.0

        if reasons:
            utt["reject_reasons"] = reasons
            rejected.append(utt)
        else:
            kept.append(utt)

    total = len(utterances)
    kept_pct = len(kept) / total * 100 if total else 0.0  # no ZeroDivisionError on empty input
    log.info(f" [filter] Kept {len(kept)}/{total} ({kept_pct:.1f}%)")

    if rejected:
        reason_counts = {}
        for u in rejected:
            for r in u.get("reject_reasons", []):
                key = r.split("=")[0]
                reason_counts[key] = reason_counts.get(key, 0) + 1
        log.info(f" [filter] Rejection reasons: {reason_counts}")
    return kept, rejected
# ============================================================
# STEP 8: Export as LJSpeech Format
# ============================================================
def export_dataset(utterances: List[Dict], output_dir: Path,
                   val_split: float = 0.05):
    """
    Export an LJSpeech-format dataset for Coqui-TTS FastPitch training.

    Creates:
        output_dir/
            wavs/              - one WAV per exported utterance (si_NNNNNN.wav)
            metadata.csv       - full dataset: filename|text|normalized_text
            metadata_train.csv
            metadata_val.csv   - ``val_split`` fraction (at least one line)
            dataset_stats.json - corpus statistics

    Utterances with empty text or a missing source WAV are skipped, and all
    statistics describe only the exported utterances. File names keep the
    utterance's index in ``utterances``. The train/val split is shuffled with
    a fixed seed (42).

    Returns:
        The statistics dict that is also written to dataset_stats.json.
    """
    import random
    import shutil

    output_dir.mkdir(parents=True, exist_ok=True)
    wavs_dir = output_dir / "wavs"
    wavs_dir.mkdir(exist_ok=True)

    metadata = []
    exported = []
    for i, utt in enumerate(tqdm(utterances, desc="Exporting")):
        text = utt.get("text", "").strip()
        src = Path(utt["path"])
        # Check before copying so skipped utterances leave no orphan WAVs.
        if not text or not src.exists():
            continue
        new_name = f"si_{i:06d}"
        # Always overwrite: a WAV left by an earlier export under the same name
        # would otherwise be paired with this utterance's transcript.
        shutil.copy2(str(src), str(wavs_dir / f"{new_name}.wav"))
        # "|" is the LJSpeech field separator and must not appear in text.
        text = text.replace("|", " ")
        normalized = _normalize_sinhala_text(text)
        metadata.append(f"{new_name}|{text}|{normalized}")
        exported.append(utt)

    # Private RNG: reproducible split without reseeding the global `random`.
    random.Random(42).shuffle(metadata)
    n_val = max(1, int(len(metadata) * val_split))
    val_lines = metadata[:n_val]
    train_lines = metadata[n_val:]

    for name, lines in (("metadata.csv", metadata),
                        ("metadata_train.csv", train_lines),
                        ("metadata_val.csv", val_lines)):
        with open(output_dir / name, "w", encoding="utf-8") as f:
            # One record per line; an empty split yields an empty file rather
            # than a lone blank line that a line-based loader may misread.
            f.writelines(line + "\n" for line in lines)

    durations = [u["duration"] for u in exported]

    def _summary(fn) -> float:
        # Summary statistic of the exported durations; 0.0 when nothing was exported.
        return round(float(fn(durations)), 2) if durations else 0.0

    stats = {
        "total_utterances": len(metadata),
        "train_utterances": len(train_lines),
        "val_utterances": len(val_lines),
        "total_hours": round(sum(durations) / 3600, 2),
        "mean_duration_sec": _summary(np.mean),
        "median_duration_sec": _summary(np.median),
        "min_duration_sec": _summary(min),
        "max_duration_sec": _summary(max),
        "sample_rate": SAMPLE_RATE,
    }
    # Pitch stats across the exported corpus (0 means "not measured").
    pitches = [u.get("pitch_mean_hz", 0) for u in exported if u.get("pitch_mean_hz", 0) > 0]
    if pitches:
        stats["corpus_pitch_mean_hz"] = round(float(np.mean(pitches)), 1)
        stats["corpus_pitch_std_hz"] = round(float(np.std(pitches)), 1)
    with open(output_dir / "dataset_stats.json", "w", encoding="utf-8") as f:
        json.dump(stats, f, indent=2)

    log.info(f"\n{'='*60}")
    log.info("DATASET EXPORTED")
    log.info(f"{'='*60}")
    log.info(f" Location: {output_dir}")
    log.info(f" Total: {stats['total_utterances']} utterances")
    log.info(f" Train: {stats['train_utterances']}")
    log.info(f" Val: {stats['val_utterances']}")
    log.info(f" Duration: {stats['total_hours']} hours")
    log.info(f" Mean length: {stats['mean_duration_sec']}s")
    if 'corpus_pitch_mean_hz' in stats:
        log.info(f" Pitch mean: {stats['corpus_pitch_mean_hz']} Hz")
        log.info(f" Pitch std: {stats['corpus_pitch_std_hz']} Hz")
    log.info(f"{'='*60}")
    return stats
def _normalize_sinhala_text(text: str) -> str:
"""
Basic Sinhala text normalization for TTS.
- Unicode NFC normalization (canonical decomposition → composition)
- Remove zero-width characters (except ZWJ which forms conjuncts)
- Normalize punctuation
- Collapse whitespace
"""
import unicodedata
# NFC normalization (critical for Brahmic scripts)
text = unicodedata.normalize('NFC', text)
# Remove zero-width non-joiner (ZWNJ) but keep ZWJ (conjunct former)
text = text.replace('\u200C', '') # ZWNJ
# Keep \u200D (ZWJ) — it's part of Sinhala conjunct consonants like ක්‍ෂ
# Normalize quotation marks
text = text.replace('"', '"').replace('"', '"')
text = text.replace(''', "'").replace(''', "'")
# Replace semicolons and colons with commas (IndicTTS convention)
text = text.replace(';', ',').replace(':', ',')
# Remove parentheses (but keep content)
text = text.replace('(', '').replace(')', '')
text = text.replace('[', '').replace(']', '')
# Collapse whitespace
text = ' '.join(text.split())
return text.strip()
# ============================================================
# MAIN
# ============================================================
def parse_args():
    """
    Define the command-line interface and parse ``sys.argv``.

    Exactly one input source (--video-list, --video-id or --audio-dir) is
    required; everything else has a default. Returns the argparse Namespace.
    """
    cli = argparse.ArgumentParser(
        description="Sinhala TTS Data Pipeline",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # Input source: exactly one is required.
    source = cli.add_mutually_exclusive_group(required=True)
    source.add_argument("--video-list", type=str,
                        help="JSON file with video list (from list_unlimited_history.py)")
    source.add_argument("--video-id", type=str,
                        help="Single YouTube video ID")
    source.add_argument("--audio-dir", type=str,
                        help="Directory of pre-downloaded audio files")

    # Where results go.
    cli.add_argument("--output-dir", type=str, default="sinhala_tts_dataset",
                     help="Output directory (default: sinhala_tts_dataset)")

    # Model / pipeline tuning.
    cli.add_argument("--num-speakers", type=int, default=2,
                     help="Expected number of speakers per video (default: 2)")
    cli.add_argument("--target-speaker", type=str, default=None,
                     help="Speaker label to extract (default: auto-select dominant)")
    whisper_sizes = ["tiny", "base", "small", "medium", "large", "large-v2", "large-v3"]
    cli.add_argument("--whisper-model", type=str, default="large-v3",
                     choices=whisper_sizes,
                     help="Whisper model size (default: large-v3)")
    cli.add_argument("--hf-token", type=str, default=None,
                     help="HuggingFace token for pyannote (optional)")

    # Boolean switches that bypass individual pipeline steps.
    skip_flags = (
        ("--skip-separation", "Skip source separation step"),
        ("--skip-enhancement", "Skip audio enhancement step"),
        ("--skip-diarization", "Skip diarization (treat all audio as one speaker)"),
        ("--skip-transcription", "Skip Whisper transcription (need pre-existing transcripts)"),
    )
    for flag, help_text in skip_flags:
        cli.add_argument(flag, action="store_true", help=help_text)

    # Run control.
    cli.add_argument("--resume", action="store_true",
                     help="Resume from checkpoint (skip completed steps)")
    cli.add_argument("--max-videos", type=int, default=None,
                     help="Process only first N videos (for testing)")
    cli.add_argument("--batch-size", type=int, default=10,
                     help="Process videos in batches of N (default: 10)")
    cli.add_argument("--use-unlimited-history-only", action="store_true",
                     help="Only use Unlimited History videos from the list")

    return cli.parse_args()
def _resolve_input_videos(args: argparse.Namespace) -> List[Dict]:
    """Build the list of video descriptors selected by the CLI input option.

    Each descriptor is a dict with an ``"id"`` and a ``"title"``; entries found
    via ``--audio-dir`` also carry ``"_local_path"`` so the download step is
    bypassed. ``--max-videos`` is applied to every input mode.

    Returns:
        The (possibly truncated) list of descriptors; empty if nothing was found.
    """
    videos: List[Dict] = []
    if args.video_list:
        # JSON is UTF-8 by spec and the list may hold non-ASCII (e.g. Sinhala)
        # titles, so don't depend on the platform default encoding.
        with open(args.video_list, encoding="utf-8") as fh:
            data = json.load(fh)
        if args.use_unlimited_history_only:
            videos = data.get("unlimited_history", [])
            log.info(f"Using {len(videos)} Unlimited History videos")
        else:
            videos = data.get("unlimited_history", []) + data.get("other", [])
            log.info(f"Using all {len(videos)} videos")
    elif args.video_id:
        videos = [{"id": args.video_id, "title": args.video_id}]
    elif args.audio_dir:
        audio_dir = Path(args.audio_dir)
        videos = [
            {"id": wav.stem, "title": wav.stem, "_local_path": str(wav)}
            for wav in sorted(audio_dir.glob("*.wav"))
        ]
        log.info(f"Found {len(videos)} audio files in {audio_dir}")

    if args.max_videos and args.max_videos > 0:
        videos = videos[:args.max_videos]
        log.info(f"Limited to first {args.max_videos} videos")
    return videos


def _save_utterance_cache(cache_path: Path, utterances: List[Dict]) -> None:
    """Persist one video's utterance list so ``--resume`` can skip VAD and ASR.

    Best-effort: if the list cannot be serialised or written, a warning is
    logged and any stale cache file is removed so it is never reused.
    """
    try:
        # Serialise fully before touching the file so a failure can't leave a
        # truncated JSON behind.
        payload = json.dumps(utterances, indent=2, ensure_ascii=False)
        cache_path.parent.mkdir(parents=True, exist_ok=True)
        cache_path.write_text(payload, encoding="utf-8")
    except (TypeError, ValueError, OSError) as err:
        log.warning(f" Could not cache utterances at {cache_path}: {err}")
        try:
            cache_path.unlink()
        except OSError:
            pass


def _process_one_video(v: Dict, args: argparse.Namespace, state: "PipelineState",
                       output_dir: Path) -> Optional[List[Dict]]:
    """Run pipeline steps 1-6 (download -> ASR) for a single video.

    With ``--resume``, a step whose checkpoint is recorded in ``state`` reuses
    its on-disk output; if that output is missing the step is recomputed
    rather than silently skipped. A video already marked ``"complete"`` reloads
    its cached utterance list, so VAD and transcription are not repeated (the
    cache reflects the flags used when it was written, e.g.
    ``--skip-transcription``).

    Returns:
        The utterance dicts for this video, or ``None`` if no audio was available.
    """
    raw_dir = output_dir / "raw_audio"
    separated_dir = output_dir / "separated"
    enhanced_dir = output_dir / "enhanced"
    segments_dir = output_dir / "segments"
    state_dir = output_dir / ".state"

    vid_id = v["id"]
    # `or` rather than a .get() default also covers an explicit "title": null.
    title = str(v.get("title") or vid_id)
    log.info(f"\n--- Processing: {title[:60]} ({vid_id}) ---")

    utt_cache = state_dir / f"{vid_id}_utterances.json"
    if args.resume and state.is_done(vid_id, "complete") and utt_cache.exists():
        with open(utt_cache, encoding="utf-8") as fh:
            cached = json.load(fh)
        log.info(f" Resumed {len(cached)} cached utterances")
        return cached

    # Step 1: Download (or use local file)
    if "_local_path" in v:
        wav_path = Path(v["_local_path"])
    else:
        wav_path = None
        if args.resume and state.is_done(vid_id, "download"):
            candidate = raw_dir / f"{vid_id}.wav"
            if candidate.exists():
                wav_path = candidate
        if wav_path is None:
            wav_path = download_video(vid_id, raw_dir)
            if wav_path:
                state.mark_done(vid_id, "download")
    if not wav_path or not wav_path.exists():
        log.warning(f" Skipping {vid_id}: no audio")
        return None
    current_audio = wav_path

    # Step 2: Source separation
    if not args.skip_separation:
        vocals_path = separated_dir / f"{vid_id}_vocals.wav"
        if args.resume and state.is_done(vid_id, "separation") and vocals_path.exists():
            current_audio = vocals_path
        else:
            result = separate_vocals(wav_path, separated_dir)
            if result:
                current_audio = result
                state.mark_done(vid_id, "separation")

    # Step 3: Audio enhancement
    if not args.skip_enhancement:
        enh_path = enhanced_dir / f"{current_audio.stem}_enhanced.wav"
        if args.resume and state.is_done(vid_id, "enhancement") and enh_path.exists():
            current_audio = enh_path
        else:
            result = enhance_audio(current_audio, enhanced_dir)
            if result:
                current_audio = result
                state.mark_done(vid_id, "enhancement")

    # Step 4: Speaker diarization
    if args.skip_diarization:
        # No diarization: treat the full audio as one speaker turn.
        import librosa
        dur = round(librosa.get_duration(path=str(current_audio)), 3)
        speaker_segments = [{"start": 0.0, "end": dur, "duration": dur}]
    else:
        diar_file = state_dir / f"{vid_id}_diarization.json"
        if args.resume and state.is_done(vid_id, "diarization") and diar_file.exists():
            with open(diar_file) as fh:
                speakers = json.load(fh)
        else:
            speakers = diarize_audio(current_audio, args.num_speakers, args.hf_token)
            # Save diarization results (also when a resume found the file missing).
            diar_file.parent.mkdir(parents=True, exist_ok=True)
            with open(diar_file, "w") as fh:
                json.dump(speakers, fh, indent=2)
            state.mark_done(vid_id, "diarization")
        target = select_target_speaker(speakers, args.target_speaker)
        speaker_segments = speakers[target]

    # Step 5: VAD segmentation
    utterances = segment_with_vad(current_audio, speaker_segments, segments_dir / vid_id)

    # Step 6: ASR transcription
    if not args.skip_transcription and utterances:
        utterances = transcribe_utterances(utterances, args.whisper_model)

    _save_utterance_cache(utt_cache, utterances)
    state.mark_done(vid_id, "complete")
    return utterances


def _finalize_dataset(all_utterances: List[Dict], output_dir: Path, dataset_dir: Path) -> None:
    """Run steps 7-8: quality-filter all utterances, then export and write manifests."""
    # ---- Step 7: Quality filtering ----
    log.info(f"\n{'='*60}")
    log.info(f"QUALITY FILTERING ({len(all_utterances)} utterances)")
    log.info(f"{'='*60}")
    if not all_utterances:
        log.error("No utterances extracted! Check logs above for errors.")
        return

    kept, rejected = filter_utterances(all_utterances)
    # Save rejected for inspection
    with open(output_dir / "rejected_utterances.json", "w", encoding="utf-8") as fh:
        json.dump(rejected, fh, indent=2, ensure_ascii=False)

    # ---- Step 8: Export dataset ----
    export_dataset(kept, dataset_dir)
    # Save full manifest
    with open(output_dir / "full_manifest.json", "w", encoding="utf-8") as fh:
        json.dump(kept, fh, indent=2, ensure_ascii=False)

    log.info(f"\n{'='*60}")
    log.info("PIPELINE COMPLETE")
    log.info(f"{'='*60}")
    log.info(f" Raw utterances: {len(all_utterances)}")
    log.info(f" After filtering: {len(kept)}")
    log.info(f" Rejected: {len(rejected)}")
    log.info(f" Dataset: {dataset_dir}")
    log.info(f" Next step: python scripts/train_fastpitch.py --dataset {dataset_dir}")


def main():
    """CLI entry point: resolve inputs, run steps 1-6 per video in batches, then filter and export.

    A failure in one video is logged with its traceback and the run continues
    with the next video; failed IDs are summarised before filtering.
    Exits with status 1 if no input videos/audio files were found.
    """
    args = parse_args()
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    dataset_dir = output_dir / "dataset"
    state = PipelineState(str(output_dir / ".state"))

    log.info("=" * 60)
    log.info("Sinhala TTS Data Pipeline")
    log.info("=" * 60)
    log.info(f"Output: {output_dir}")
    log.info(f"Device: {'CUDA' if torch.cuda.is_available() else 'CPU'}")
    if torch.cuda.is_available():
        log.info(f"GPU: {torch.cuda.get_device_name(0)}")
    log.info("")

    # ---- Resolve input ----
    videos = _resolve_input_videos(args)
    if not videos:
        log.error("No videos to process!")
        sys.exit(1)

    # ---- Process in batches ----
    all_utterances: List[Dict] = []
    total_seconds = 0.0  # running total, so the per-video log isn't O(n) each time
    failed: List[str] = []
    # range() rejects a step <= 0, so clamp nonsensical --batch-size values.
    batch_size = max(1, args.batch_size)
    total_batches = (len(videos) + batch_size - 1) // batch_size
    for batch_start in range(0, len(videos), batch_size):
        batch = videos[batch_start:batch_start + batch_size]
        batch_num = batch_start // batch_size + 1
        log.info(f"\n{'='*60}")
        log.info(f"BATCH {batch_num}/{total_batches} ({len(batch)} videos)")
        log.info(f"{'='*60}")
        for v in batch:
            vid_id = v.get("id")
            try:
                utterances = _process_one_video(v, args, state, output_dir)
            except Exception:
                # Top-level boundary: keep one bad video from aborting a long run.
                log.exception(f" Failed to process {vid_id}; continuing with next video")
                failed.append(str(vid_id))
                continue
            if utterances is None:
                continue
            all_utterances.extend(utterances)
            total_seconds += sum(u["duration"] for u in utterances)
            log.info(f" Total utterances so far: {len(all_utterances)} "
                     f"({total_seconds/3600:.1f}h)")

    if failed:
        log.warning(f"{len(failed)} video(s) failed: {', '.join(failed)}")

    _finalize_dataset(all_utterances, output_dir, dataset_dir)


if __name__ == "__main__":
    main()