# File: sinhala-tts/scripts/cloud_pipeline.py
# Author: outlawmold
# Last commit: feat(macos): implement Apple Silicon optimizations and switch to wav2vec2 ASR
# Commit id: 1a2a2b3
#!/usr/bin/env python3
"""
=============================================================
Sinhala TTS - Phase 2: Cloud GPU Processing Pipeline
=============================================================
Runs on HF Jobs. Reads raw audio from HF dataset repo,
processes through the full pipeline, and pushes results back.
Pipeline:
1. Download raw audio from HF dataset repo
2. Source separation (HTDemucs → vocals only)
3. Audio enhancement (VoiceFixer + DeepFilterNet3)
4. Speaker diarization (pyannote 3.1 / simple-diarizer fallback)
5. VAD segmentation (Silero-VAD, 3-20s chunks)
6. ASR transcription (Whisper large-v3)
7. Quality filtering
8. Export as LJSpeech-format dataset → push to Hub
Usage (on HF Jobs - configured via hf_jobs tool):
python scripts/cloud_pipeline.py \
--source-repo outlawmold/sinhala-tts-raw-audio \
--output-repo outlawmold/sinhala-tts-dataset \
--batch-size 5
=============================================================
"""
import os
import sys
import json
import argparse
import logging
import shutil
import tempfile
import warnings
from pathlib import Path
from typing import Optional, Dict, List, Tuple
import numpy as np
import torch
import torchaudio
import soundfile as sf
from tqdm import tqdm
# Silence every warning category for the whole process.
warnings.filterwarnings("ignore")
# ============================================================
# CONFIG
# ============================================================
# Sample rate (Hz) that processed audio and exported utterances are written at.
SAMPLE_RATE = 22050
# Sample rate (Hz) the waveform is resampled to before pyannote diarization.
DIARIZE_SR = 16000
# Utterances shorter/longer than these bounds (seconds) are dropped.
MIN_SEGMENT_SEC = 3.0
MAX_SEGMENT_SEC = 20.0
# Quality thresholds (IndicVoices-R + Emilia-Pipe)
# Minimum estimated SNR in dB (see compute_snr).
SNR_THRESHOLD = 25.0
# Upper bounds on per-utterance pitch mean / standard deviation, in Hz.
PITCH_MEAN_MAX = 350.0
PITCH_STD_MAX = 150.0
# Upper bound on non-punctuation characters per second of audio.
SPEAKING_RATE_MAX = 30.0
# Minimum fraction of RMS frames above the utterance's 20th-percentile energy.
MIN_SPEECH_RATIO = 0.5
# Log to stdout and to a file.
# NOTE(review): the FileHandler is created at import time, so /app presumably
# has to exist and be writable before this module is loaded — confirm.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    datefmt='%H:%M:%S',
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler('/app/pipeline.log'),
    ]
)
# Logger shared by every step in this file.
log = logging.getLogger("sinhala-tts-cloud")
# ============================================================
# HUB I/O
# ============================================================
def get_api():
    """Build and return a fresh ``huggingface_hub.HfApi`` client."""
    # The Hub client is imported inside the function, as elsewhere in this file.
    import huggingface_hub

    return huggingface_hub.HfApi()
def download_raw_audio(
    source_repo: str,
    work_dir: Path,
    video_ids: Optional[List[str]] = None,
) -> List[Path]:
    """Download raw ``.wav`` files from the ``audio/`` folder of a HF dataset repo.

    Args:
        source_repo: Dataset repo id to read from.
        work_dir: Local working directory; files end up in ``work_dir / "raw"``.
        video_ids: Optional list of file stems to restrict the download to.
            ``None`` or an empty list downloads every ``.wav`` found.

    Returns:
        Local paths of all files that are available after the call (already
        present or freshly downloaded). Files whose download failed are logged
        and left out.
    """
    from huggingface_hub import hf_hub_download

    api = get_api()
    audio_dir = work_dir / "raw"
    audio_dir.mkdir(parents=True, exist_ok=True)

    def _remote_path(entry) -> str:
        # Tree listings can contain folder entries as well as files; read the
        # path defensively so an entry without ``rfilename`` cannot raise.
        return getattr(entry, "rfilename", None) or getattr(entry, "path", None) or ""

    # List available audio files
    entries = api.list_repo_tree(source_repo, repo_type="dataset", path_in_repo="audio")
    wav_files = [p for p in map(_remote_path, entries) if p.endswith(".wav")]
    log.info(f"Found {len(wav_files)} audio files in {source_repo}")

    # Filter to specific video IDs if requested
    if video_ids:
        # Strip whitespace so "--video-ids a, b" still matches "b".
        vid_set = {v.strip() for v in video_ids if v and v.strip()}
        wav_files = [p for p in wav_files if Path(p).stem in vid_set]
        log.info(f"Filtered to {len(wav_files)} requested videos")

    downloaded = []
    for remote in wav_files:
        local_path = audio_dir / Path(remote).name
        if local_path.exists():
            # Already fetched by an earlier (possibly interrupted) run.
            downloaded.append(local_path)
            continue
        try:
            dl_path = hf_hub_download(
                repo_id=source_repo,
                filename=remote,
                repo_type="dataset",
                local_dir=str(work_dir / "_hub_cache"),
            )
            shutil.copy2(dl_path, str(local_path))
            downloaded.append(local_path)
        except Exception as e:
            # Best effort: one broken file must not abort the whole batch.
            log.error(f"Failed to download {remote}: {e}")

    log.info(f"Downloaded {len(downloaded)} audio files")
    return downloaded
def load_processing_state(output_repo: str) -> dict:
    """Load ``processing_state.json`` from the output dataset repo.

    Args:
        output_repo: Dataset repo id that holds the state file.

    Returns:
        The stored state, with any missing key filled in from the defaults
        ``{"completed_videos": [], "total_utterances": 0, "total_hours": 0.0}``.
        When the file cannot be fetched or parsed, the defaults are returned
        and the reason is logged.
    """
    fresh = {"completed_videos": [], "total_utterances": 0, "total_hours": 0.0}
    api = get_api()
    try:
        path = api.hf_hub_download(
            repo_id=output_repo,
            filename="processing_state.json",
            repo_type="dataset",
        )
        with open(path, encoding="utf-8") as f:
            loaded = json.load(f)
    except Exception as e:
        # Still best effort (a first run has no state file), but say why:
        # a transient Hub error here makes the run start from scratch.
        log.warning(f"Could not load processing state from {output_repo} ({e}); starting fresh")
        return fresh

    if not isinstance(loaded, dict):
        log.warning(f"Ignoring malformed processing state in {output_repo}; starting fresh")
        return fresh

    # Older or partial state files may lack keys that callers index directly.
    return {**fresh, **loaded}
def save_processing_state(output_repo: str, state: dict):
    """Serialize ``state`` as JSON and upload it as ``processing_state.json``.

    The commit message reports the number of completed videos and the total
    hours stored in ``state``.
    """
    client = get_api()
    payload = json.dumps(state, indent=2).encode("utf-8")
    upload_kwargs = dict(
        path_or_fileobj=payload,
        path_in_repo="processing_state.json",
        repo_id=output_repo,
        repo_type="dataset",
        commit_message=f"Update state: {len(state['completed_videos'])} videos, {state['total_hours']:.1f}h",
    )
    client.upload_file(**upload_kwargs)
def upload_utterances_batch(
    utterances: List[dict],
    output_repo: str,
    video_id: str,
):
    """Upload processed utterances (WAV + metadata) for one video in one commit.

    Args:
        utterances: Utterance dicts; each must carry a local ``"path"``.
        output_repo: Dataset repo id to push to.
        video_id: Source video id; names the ``metadata/<video_id>.json`` file.

    Returns:
        ``True`` when the commit succeeded and ``False`` when it failed, so
        callers can avoid marking the video as done. Failures are logged,
        never raised.
    """
    from huggingface_hub import CommitOperationAdd

    api = get_api()
    operations = []
    missing = 0
    for utt in utterances:
        wav_path = Path(utt["path"])
        if not wav_path.exists():
            missing += 1
            continue
        operations.append(
            CommitOperationAdd(
                path_in_repo=f"wavs/{wav_path.name}",
                path_or_fileobj=str(wav_path),
            )
        )
    if missing:
        log.warning(f" [upload] {missing} utterance file(s) missing locally for {video_id}")
    n_wavs = len(operations)

    # Also upload per-video metadata
    meta_bytes = json.dumps(utterances, indent=2, ensure_ascii=False).encode("utf-8")
    operations.append(
        CommitOperationAdd(
            path_in_repo=f"metadata/{video_id}.json",
            path_or_fileobj=meta_bytes,
        )
    )

    try:
        api.create_commit(
            repo_id=output_repo,
            repo_type="dataset",
            operations=operations,
            commit_message=f"Add {len(utterances)} utterances from {video_id}",
        )
    except Exception as e:
        log.error(f" [upload] Failed to push {video_id}: {e}")
        return False

    log.info(f" [upload] Pushed {n_wavs} utterances for {video_id}")
    return True
def upload_final_dataset(
    all_utterances: List[dict],
    dataset_dir: Path,
    output_repo: str,
    stats: dict,
):
    """Push the metadata CSVs, stats JSON and README to the output repo.

    Only the metadata CSVs that exist in ``dataset_dir`` are included.
    ``all_utterances`` is accepted but not read here. Commit errors are
    logged rather than raised.
    """
    from huggingface_hub import CommitOperationAdd

    client = get_api()

    # Metadata CSVs: only the ones export actually produced.
    operations = [
        CommitOperationAdd(
            path_in_repo=name,
            path_or_fileobj=str(dataset_dir / name),
        )
        for name in ("metadata.csv", "metadata_train.csv", "metadata_val.csv")
        if (dataset_dir / name).exists()
    ]

    # Stats, then README, both uploaded from in-memory bytes.
    stats_payload = json.dumps(stats, indent=2).encode("utf-8")
    operations.append(
        CommitOperationAdd(path_in_repo="dataset_stats.json", path_or_fileobj=stats_payload)
    )
    readme_text = _generate_dataset_readme(stats)
    operations.append(
        CommitOperationAdd(path_in_repo="README.md", path_or_fileobj=readme_text.encode("utf-8"))
    )

    if not operations:
        return
    try:
        client.create_commit(
            repo_id=output_repo,
            repo_type="dataset",
            operations=operations,
            commit_message=f"Final dataset: {stats['total_utterances']} utterances, {stats['total_hours']}h",
        )
        log.info(f"Final dataset pushed to {output_repo}")
    except Exception as err:
        log.error(f"Failed to push final dataset: {err}")
def _generate_dataset_readme(stats: dict) -> str:
return f"""---
language:
- si
license: cc-by-4.0
task_categories:
- text-to-speech
- automatic-speech-recognition
pretty_name: Sinhala TTS Dataset (Unlimited History)
size_categories:
- 10K<n<100K
tags:
- sinhala
- tts
- speech
---
# Sinhala TTS Dataset
Clean, segmented Sinhala speech from the "Unlimited History" YouTube series by @sunchare.
## Dataset Statistics
| Metric | Value |
|---|---|
| Total utterances | {stats.get('total_utterances', 'N/A')} |
| Training set | {stats.get('train_utterances', 'N/A')} |
| Validation set | {stats.get('val_utterances', 'N/A')} |
| Total hours | {stats.get('total_hours', 'N/A')} |
| Mean duration | {stats.get('mean_duration_sec', 'N/A')}s |
| Sample rate | {stats.get('sample_rate', 22050)} Hz |
## Processing Pipeline
Raw YouTube audio → HTDemucs (source separation) → VoiceFixer + DeepFilterNet3 (enhancement) →
pyannote 3.1 (diarization) → Silero-VAD (segmentation) → Whisper large-v3 (transcription) →
Quality filtering (SNR≥25dB, pitch, speaking rate)
## Format
LJSpeech-compatible:
- `wavs/` — mono WAV files at 22050 Hz
- `metadata.csv` — `filename|text|normalized_text`
- `metadata_train.csv` / `metadata_val.csv` — train/val splits
## Source
[NU1's VLOG (@sunchare)](https://www.youtube.com/@sunchare) - "Unlimited History" series on Sri Lankan history.
"""
# ============================================================
# PROCESSING STEPS (adapted from data_pipeline.py)
# ============================================================
def separate_vocals(wav_path: Path, output_dir: Path) -> Path:
    """Isolate the vocal stem of ``wav_path`` with HTDemucs.

    Args:
        wav_path: Input audio file.
        output_dir: Directory for ``<stem>_vocals.wav`` (created on demand).

    Returns:
        Path of the mono vocals file at ``SAMPLE_RATE``. If that file already
        exists it is returned without recomputation. If separation fails for
        any reason, a warning is logged and ``wav_path`` is returned.
    """
    output_path = output_dir / f"{wav_path.stem}_vocals.wav"
    if output_path.exists():
        return output_path
    try:
        from demucs.pretrained import get_model
        from demucs.apply import apply_model

        model = get_model("htdemucs")
        model.eval()

        mps = getattr(torch.backends, "mps", None)  # absent on older torch builds
        if torch.cuda.is_available():
            device = torch.device("cuda")
        elif mps is not None and mps.is_available():
            device = torch.device("mps")
        else:
            device = torch.device("cpu")
        model.to(device)

        waveform, sr = torchaudio.load(str(wav_path))
        if waveform.shape[0] == 1:
            # The model is fed a stereo mix; duplicate the mono channel.
            waveform = waveform.repeat(2, 1)
        if sr != model.samplerate:
            waveform = torchaudio.transforms.Resample(sr, model.samplerate)(waveform)

        # Keep the full-length mix on the CPU and let apply_model move each
        # chunk to `device`, so long recordings do not exhaust accelerator memory.
        mix = waveform.unsqueeze(0)
        with torch.no_grad():
            sources = apply_model(model, mix, device=device, split=True, overlap=0.25)

        # Look the vocals stem up by name; fall back to index 3
        # (drums, bass, other, vocals) if the model does not list its sources.
        stem_names = list(getattr(model, "sources", None) or [])
        vocals_idx = stem_names.index("vocals") if "vocals" in stem_names else 3

        # Bring the stem to the CPU *before* resampling: the Resample module
        # below is built on the CPU and cannot process an accelerator tensor.
        vocals_mono = sources[0, vocals_idx].mean(dim=0, keepdim=True).cpu()
        if model.samplerate != SAMPLE_RATE:
            vocals_mono = torchaudio.transforms.Resample(model.samplerate, SAMPLE_RATE)(vocals_mono)

        output_dir.mkdir(parents=True, exist_ok=True)
        torchaudio.save(str(output_path), vocals_mono, SAMPLE_RATE)
        log.info(f" [separation] Done: {output_path.name}")
        return output_path
    except Exception as e:
        # Deliberate best effort: the pipeline continues on the unseparated audio.
        log.warning(f" [separation] Failed ({e}), using original audio")
        return wav_path
def enhance_audio(wav_path: Path, output_dir: Path) -> Path:
    """Run VoiceFixer and then DeepFilterNet3 over ``wav_path``.

    Returns ``output_dir / "<stem>_enhanced.wav"`` when it exists (either
    beforehand or after processing); otherwise the untouched ``wav_path``.
    Each stage is optional: a failing stage is logged and skipped.
    """
    stem = wav_path.stem
    enhanced_file = output_dir / f"{stem}_enhanced.wav"
    if enhanced_file.exists():
        return enhanced_file

    output_dir.mkdir(parents=True, exist_ok=True)
    stage_input = wav_path

    # --- Stage 1: VoiceFixer -------------------------------------------
    try:
        from voicefixer import VoiceFixer

        fixer = VoiceFixer()
        restored_file = output_dir / f"{stem}_vf.wav"
        fixer.restore(
            input=str(stage_input),
            output=str(restored_file),
            cuda=torch.cuda.is_available(),
            mode=0,
        )
        if restored_file.exists():
            stage_input = restored_file
        log.info(f" [enhance] VoiceFixer done")
    except Exception as err:
        log.warning(f" [enhance] VoiceFixer failed: {err}")

    # --- Stage 2: DeepFilterNet3 ---------------------------------------
    try:
        from df.enhance import enhance, init_df, load_audio, save_audio

        df_model, df_state, _ = init_df()
        noisy, _ = load_audio(str(stage_input), sr=df_state.sr())
        cleaned = enhance(df_model, df_state, noisy)
        save_audio(str(enhanced_file), cleaned, df_state.sr())
        if enhanced_file.exists():
            # Bring the result to the pipeline-wide sample rate.
            samples, rate = torchaudio.load(str(enhanced_file))
            if rate != SAMPLE_RATE:
                samples = torchaudio.transforms.Resample(rate, SAMPLE_RATE)(samples)
                torchaudio.save(str(enhanced_file), samples, SAMPLE_RATE)
            log.info(f" [enhance] DeepFilterNet3 done")
        elif stage_input != wav_path:
            shutil.copy2(str(stage_input), str(enhanced_file))
    except Exception as err:
        log.warning(f" [enhance] DeepFilterNet3 failed: {err}")
        if stage_input == wav_path:
            # Neither stage produced anything usable.
            return wav_path
        # Keep the VoiceFixer result as the final output.
        shutil.copy2(str(stage_input), str(enhanced_file))

    # Cleanup VoiceFixer temp
    leftover = output_dir / f"{stem}_vf.wav"
    if leftover.exists() and enhanced_file.exists() and leftover != enhanced_file:
        leftover.unlink()

    if enhanced_file.exists():
        return enhanced_file
    return wav_path
def diarize_audio(wav_path: Path, num_speakers: int = 2) -> Dict[str, List[Dict]]:
    """Speaker diarization with pyannote 3.1 (fallback to simple-diarizer).

    Args:
        wav_path: Audio file to diarize.
        num_speakers: Number of speakers passed to the diarizer.

    Returns:
        Mapping ``speaker label -> list of turns``; each turn is a dict with
        ``start``, ``end`` and ``duration`` in seconds (rounded to 3 places).
        pyannote is tried only when ``HF_TOKEN`` is set. If both diarizers
        fail, the whole file is returned as one ``SPEAKER_0`` turn.
    """
    def _turn(start: float, end: float) -> Dict:
        return {
            "start": round(start, 3),
            "end": round(end, 3),
            "duration": round(end - start, 3),
        }

    token = os.environ.get("HF_TOKEN")

    # Try pyannote first
    if token:
        try:
            from pyannote.audio import Pipeline as PyannotePipeline

            pipeline = PyannotePipeline.from_pretrained(
                "pyannote/speaker-diarization-3.1",
                use_auth_token=token,
            )
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            pipeline.to(device)
            waveform, sr = torchaudio.load(str(wav_path))
            if sr != DIARIZE_SR:
                waveform = torchaudio.transforms.Resample(sr, DIARIZE_SR)(waveform)
            diarization = pipeline(
                {"waveform": waveform, "sample_rate": DIARIZE_SR},
                num_speakers=num_speakers,
            )
            speakers: Dict[str, List[Dict]] = {}
            for turn, _, speaker in diarization.itertracks(yield_label=True):
                speakers.setdefault(speaker, []).append(_turn(turn.start, turn.end))
            log.info(f" [diarize] {len(speakers)} speakers (pyannote 3.1)")
            return speakers
        except Exception as e:
            log.warning(f" [diarize] pyannote failed: {e}")

    # Fallback: simple-diarizer
    try:
        from simple_diarizer.diarizer import Diarizer

        # Monkeypatch torchaudio.load for compatibility
        def _fixed_load(uri, frame_offset=0, num_frames=-1, normalize=True, channels_first=True, **kwargs):
            stop = None if num_frames == -1 else frame_offset + num_frames
            data, samplerate = sf.read(uri, start=frame_offset, stop=stop, dtype='float32')
            tensor = torch.from_numpy(data)
            if tensor.ndim == 1:
                tensor = tensor.unsqueeze(0)
            elif channels_first:
                tensor = tensor.T
            return tensor, samplerate

        original_load = torchaudio.load
        torchaudio.load = _fixed_load
        try:
            diar = Diarizer(embed_model='ecapa', cluster_method='sc')
            segments = diar.diarize(str(wav_path), num_speakers=num_speakers)
        finally:
            # Undo the patch so it does not leak into later torchaudio.load
            # calls elsewhere in the process.
            torchaudio.load = original_load

        speakers = {}
        for seg in segments:
            speakers.setdefault(str(seg['label']), []).append(_turn(seg['start'], seg['end']))
        log.info(f" [diarize] {len(speakers)} speakers (simple-diarizer)")
        return speakers
    except Exception as e:
        log.error(f" [diarize] All diarization failed: {e}")
        # Last resort: treat the whole recording as a single speaker.
        import librosa

        dur = librosa.get_duration(path=str(wav_path))
        return {"SPEAKER_0": [_turn(0.0, dur)]}
def select_target_speaker(speakers: Dict[str, List[Dict]]) -> str:
    """Return the label of the speaker whose turns add up to the longest time."""
    talk_time = {}
    for label, turns in speakers.items():
        talk_time[label] = sum(turn["duration"] for turn in turns)
    winner = max(talk_time, key=lambda label: talk_time[label])
    overall = sum(talk_time.values())
    log.info(f" [diarize] Target: {winner} ({talk_time[winner]/60:.1f}min / {overall/60:.1f}min)")
    return winner
def segment_with_vad(wav_path: Path, speaker_segments: List[Dict], output_dir: Path) -> List[Dict]:
    """Cut the target speaker's turns into utterances with Silero-VAD.

    Args:
        wav_path: Audio file to segment.
        speaker_segments: Turns of the target speaker (``start``/``end`` in seconds).
        output_dir: Directory the utterance WAVs are written to (created on demand).

    Returns:
        One dict per written utterance with ``path``, ``filename``, ``start``,
        ``end`` and ``duration`` (seconds). Utterances are mono, resampled to
        ``SAMPLE_RATE`` and peak-normalised to -3 dBFS; only durations within
        ``[MIN_SEGMENT_SEC, MAX_SEGMENT_SEC]`` are kept.
    """
    vad_sr = 16000  # rate the VAD is run at
    output_dir.mkdir(parents=True, exist_ok=True)

    waveform, sr = torchaudio.load(str(wav_path))
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    if sr != vad_sr:
        waveform_16k = torchaudio.transforms.Resample(sr, vad_sr)(waveform)
    else:
        waveform_16k = waveform
    # Built once and reused instead of once per utterance.
    to_output_sr = torchaudio.transforms.Resample(sr, SAMPLE_RATE) if sr != SAMPLE_RATE else None

    vad_model, vad_utils = torch.hub.load(
        'snakers4/silero-vad', 'silero_vad', force_reload=False, trust_repo=True,
    )
    get_speech_timestamps = vad_utils[0]

    utterances = []
    utt_idx = 0
    for seg in speaker_segments:
        s16 = int(seg["start"] * vad_sr)
        e16 = int(seg["end"] * vad_sr)
        seg_audio = waveform_16k[0, s16:e16]
        if len(seg_audio) < int(MIN_SEGMENT_SEC * vad_sr):
            continue
        try:
            speech_ts = get_speech_timestamps(
                seg_audio, vad_model, sampling_rate=vad_sr,
                min_speech_duration_ms=500, min_silence_duration_ms=300,
                speech_pad_ms=100, return_seconds=False,
            )
        except Exception as e:
            # Fall back to treating the whole turn as speech, but leave a trace.
            log.warning(f" [vad] VAD failed on turn at {seg['start']:.1f}s ({e}); using the whole turn")
            speech_ts = [{"start": 0, "end": len(seg_audio)}]
        if not speech_ts:
            continue

        for vad_seg in _merge_vad_segments(speech_ts, sr=vad_sr):
            vad_start_sec = seg["start"] + vad_seg["start"] / vad_sr
            vad_end_sec = seg["start"] + vad_seg["end"] / vad_sr
            # Take the duration from the sample counts: subtracting the two
            # offset timestamps can yield e.g. 20.000000000000004 and wrongly
            # reject a chunk of exactly MAX_SEGMENT_SEC.
            duration = (vad_seg["end"] - vad_seg["start"]) / vad_sr
            if duration < MIN_SEGMENT_SEC or duration > MAX_SEGMENT_SEC:
                continue

            start_sample = int(vad_start_sec * sr)
            end_sample = int(vad_end_sec * sr)
            utt_audio = waveform[:, start_sample:end_sample]
            if utt_audio.shape[-1] == 0:
                # Turn lies past the end of the audio; nothing to write.
                continue
            if to_output_sr is not None:
                utt_audio = to_output_sr(utt_audio)

            # Peak-normalise to -3 dBFS.
            peak = utt_audio.abs().max()
            if peak > 0:
                utt_audio = utt_audio * (10 ** (-3 / 20) / peak)

            utt_name = f"{wav_path.stem}_utt{utt_idx:05d}.wav"
            utt_path = output_dir / utt_name
            torchaudio.save(str(utt_path), utt_audio, SAMPLE_RATE)
            utterances.append({
                "path": str(utt_path),
                "filename": utt_name,
                "start": round(vad_start_sec, 3),
                "end": round(vad_end_sec, 3),
                "duration": round(duration, 3),
            })
            utt_idx += 1

    log.info(f" [vad] {len(utterances)} utterances ({sum(u['duration'] for u in utterances)/60:.1f}min)")
    return utterances
def _merge_vad_segments(segments, sr=16000, gap_ms=500):
if not segments:
return []
gap_samples = int(gap_ms * sr / 1000)
merged = [{"start": segments[0]["start"], "end": segments[0]["end"]}]
for seg in segments[1:]:
if seg["start"] - merged[-1]["end"] < gap_samples:
merged[-1]["end"] = seg["end"]
else:
merged.append({"start": seg["start"], "end": seg["end"]})
final = []
for seg in merged:
dur = (seg["end"] - seg["start"]) / sr
if dur > MAX_SEGMENT_SEC:
chunk = int(MAX_SEGMENT_SEC * sr)
pos = seg["start"]
while pos < seg["end"]:
end = min(pos + chunk, seg["end"])
if (end - pos) / sr >= MIN_SEGMENT_SEC:
final.append({"start": pos, "end": end})
pos = end
else:
final.append(seg)
return final
def transcribe_utterances(utterances: List[Dict], model_size: str = "large-v3") -> List[Dict]:
    """Transcribe every utterance with faster-whisper (language fixed to ``si``).

    Args:
        utterances: Utterance dicts with a ``"path"`` to a WAV file.
        model_size: faster-whisper model name.

    Returns:
        The same list, each dict updated in place with ``text`` and
        ``language_prob``. An utterance that fails to transcribe, or every
        utterance when faster-whisper is not installed, gets ``""`` / ``0.0``.
    """
    try:
        from faster_whisper import WhisperModel
    except ImportError:
        log.error(" [asr] faster-whisper not installed!")
        for utt in utterances:
            # Leave a consistent schema for the filtering step.
            utt.setdefault("text", "")
            utt.setdefault("language_prob", 0.0)
        return utterances

    device = "cuda" if torch.cuda.is_available() else "cpu"
    compute_type = "float16" if device == "cuda" else "int8"
    log.info(f" [asr] Loading faster-whisper {model_size} on {device}...")
    model = WhisperModel(model_size, device=device, compute_type=compute_type)

    failures = 0
    for utt in tqdm(utterances, desc="Transcribing", leave=False):
        try:
            segments, info = model.transcribe(
                utt["path"], language="si", beam_size=5, best_of=5,
                temperature=0.0, condition_on_previous_text=False, vad_filter=False,
            )
            utt["text"] = " ".join(seg.text.strip() for seg in segments).strip()
            # NOTE(review): with an explicit language this probability is
            # presumably constant, making the downstream check a no-op — confirm.
            utt["language_prob"] = info.language_probability
        except Exception as e:
            failures += 1
            log.warning(f" [asr] Failed on {utt.get('filename', utt.get('path'))}: {e}")
            utt["text"] = ""
            utt["language_prob"] = 0.0
    if failures:
        log.warning(f" [asr] {failures}/{len(utterances)} utterances failed to transcribe")
    return utterances
def compute_snr(wav_path: str) -> float:
    """Estimate the SNR (dB) of a file from its frame-wise RMS energy.

    Frames at or below the 20th RMS percentile are treated as noise and the
    rest as speech; the result is ``20 * log10(mean(speech) / mean(noise))``.

    Returns:
        The estimate in dB; ``40.0`` when the noise floor is effectively
        zero; ``0.0`` when no frame rises above the noise threshold.
    """
    import librosa

    y, sr = librosa.load(wav_path, sr=SAMPLE_RATE, mono=True)
    rms = librosa.feature.rms(y=y, frame_length=2048, hop_length=512)[0]
    threshold = np.percentile(rms, 20)
    noise = rms[rms <= threshold]
    speech = rms[rms > threshold]
    if len(noise) > 0 and np.mean(noise) > 1e-10:
        if len(speech) == 0:
            # Flat energy: every frame counts as noise. The mean of an empty
            # array is NaN, and NaN would pass a `snr < threshold` check.
            return 0.0
        return float(20 * np.log10(np.mean(speech) / np.mean(noise)))
    return 40.0
def compute_pitch_stats(wav_path: str) -> Tuple[float, float]:
    """Return ``(mean, std)`` of the pYIN pitch track over voiced frames.

    Frames where pYIN reports NaN are ignored; ``(0.0, 0.0)`` is returned
    when no frame has a pitch estimate.
    """
    import librosa

    signal, rate = librosa.load(wav_path, sr=SAMPLE_RATE, mono=True)
    pitch_track = librosa.pyin(signal, fmin=50, fmax=500, sr=rate)[0]
    voiced = pitch_track[np.logical_not(np.isnan(pitch_track))]
    if len(voiced) == 0:
        return 0.0, 0.0
    return float(np.mean(voiced)), float(np.std(voiced))
def filter_utterances(utterances: List[Dict]) -> Tuple[List[Dict], List[Dict]]:
    """Split utterances into ``(kept, rejected)`` by the quality checks.

    Checks, in order: empty text, language probability, duration bounds, SNR,
    pitch mean/std, speaking rate, speech ratio. Measured values are stored on
    each utterance dict; rejected ones also get ``reject_reasons``.
    """
    import librosa

    punctuation = "!?.,;:\"'()-"
    accepted, discarded = [], []
    for item in tqdm(utterances, desc="Filtering", leave=False):
        problems = []

        # --- transcript / language ---
        if not item.get("text", "").strip():
            problems.append("empty_text")
        if item.get("language_prob", 0) < 0.5:
            problems.append(f"low_lang_prob={item.get('language_prob', 0):.2f}")

        # --- duration ---
        if item["duration"] < MIN_SEGMENT_SEC or item["duration"] > MAX_SEGMENT_SEC:
            problems.append(f"duration={item['duration']:.1f}s")

        # --- signal-to-noise ratio ---
        try:
            snr_value = compute_snr(item["path"])
            item["snr_db"] = round(snr_value, 1)
            if snr_value < SNR_THRESHOLD:
                problems.append(f"low_snr={snr_value:.1f}dB")
        except Exception:
            item["snr_db"] = 0.0
            problems.append("snr_failed")

        # --- pitch (a failed measurement is not a rejection reason) ---
        try:
            mean_f0, std_f0 = compute_pitch_stats(item["path"])
            item["pitch_mean_hz"] = round(mean_f0, 1)
            item["pitch_std_hz"] = round(std_f0, 1)
            if mean_f0 > PITCH_MEAN_MAX:
                problems.append(f"high_pitch={mean_f0:.0f}Hz")
            if std_f0 > PITCH_STD_MAX:
                problems.append(f"high_pitch_var={std_f0:.0f}Hz")
        except Exception:
            item["pitch_mean_hz"] = 0.0
            item["pitch_std_hz"] = 0.0

        # --- speaking rate in characters per second ---
        if item.get("text"):
            n_chars = sum(1 for ch in item["text"] if ch.strip() and ch not in punctuation)
            if item["duration"] > 0:
                cps = n_chars / item["duration"]
            else:
                cps = 0
            item["speaking_rate"] = round(cps, 1)
            if cps > SPEAKING_RATE_MAX:
                problems.append(f"fast_speech={cps:.1f}c/s")
            if cps < 1.0 and item["duration"] > 3.0:
                problems.append(f"slow_speech={cps:.1f}c/s")

        # --- share of frames above the 20th-percentile energy ---
        try:
            samples, _ = librosa.load(item["path"], sr=SAMPLE_RATE, mono=True)
            energy = librosa.feature.rms(y=samples, frame_length=2048, hop_length=512)[0]
            floor = np.percentile(energy, 20)
            active_share = float(np.sum(energy > floor) / len(energy))
            item["speech_ratio"] = round(active_share, 3)
            if active_share < MIN_SPEECH_RATIO:
                problems.append(f"low_speech_ratio={active_share:.2f}")
        except Exception:
            item["speech_ratio"] = 0.0

        if not problems:
            accepted.append(item)
            continue
        item["reject_reasons"] = problems
        discarded.append(item)

    log.info(f" [filter] Kept {len(accepted)}/{len(utterances)} ({len(accepted)/max(1,len(utterances))*100:.1f}%)")
    if discarded:
        # Tally reasons by their name, ignoring the "=value" suffix.
        tally = {}
        for item in discarded:
            for reason in item.get("reject_reasons", []):
                label = reason.split("=")[0]
                tally[label] = tally.get(label, 0) + 1
        log.info(f" [filter] Rejections: {tally}")
    return accepted, discarded
def normalize_sinhala_text(text: str) -> str:
    """Return ``text`` NFC-normalised with punctuation simplified.

    Removes U+200C, round and square brackets; maps curly quotes to straight
    ones and ``;`` / ``:`` to ``,``; collapses whitespace runs to one space.
    """
    import unicodedata

    # Each entry is a single-character substitution; None deletes the character.
    substitutions = {
        0x200C: None,
        0x201C: '"',
        0x201D: '"',
        0x2018: "'",
        0x2019: "'",
        ord(';'): ',',
        ord(':'): ',',
        ord('('): None,
        ord(')'): None,
        ord('['): None,
        ord(']'): None,
    }
    cleaned = unicodedata.normalize('NFC', text).translate(substitutions)
    words = cleaned.split()
    return ' '.join(words).strip()
def export_dataset(utterances: List[Dict], output_dir: Path, val_split: float = 0.05) -> dict:
    """Export utterances as an LJSpeech-format dataset.

    Copies each transcribed utterance to ``output_dir/wavs/si_<index>.wav`` and
    writes ``metadata.csv``, ``metadata_train.csv``, ``metadata_val.csv`` (rows
    of ``name|text|normalized_text``) plus ``dataset_stats.json``.

    Args:
        utterances: Utterance dicts with ``path``, ``duration`` and ``text``.
        output_dir: Dataset directory (created on demand).
        val_split: Fraction of rows written to the validation split (at least
            one row whenever there is any row).

    Returns:
        The statistics dict that is also written to ``dataset_stats.json``.
    """
    import random

    output_dir.mkdir(parents=True, exist_ok=True)
    wavs_dir = output_dir / "wavs"
    wavs_dir.mkdir(exist_ok=True)

    metadata = []
    exported = []
    for i, utt in enumerate(tqdm(utterances, desc="Exporting", leave=False)):
        text = utt.get("text", "").strip()
        if not text:
            # No transcript means no metadata row, so do not copy the audio either.
            continue
        # "|" and line breaks are the field/row separators of the metadata file.
        text = text.replace("|", " ").replace("\r", " ").replace("\n", " ")
        name = f"si_{i:06d}"
        new_path = wavs_dir / f"{name}.wav"
        src = Path(utt["path"])
        if src.exists() and not new_path.exists():
            shutil.copy2(str(src), str(new_path))
        metadata.append(f"{name}|{text}|{normalize_sinhala_text(text)}")
        exported.append(utt)

    # Private RNG: same deterministic shuffle without reseeding the global one.
    random.Random(42).shuffle(metadata)
    n_total = len(metadata)
    n_val = min(n_total, max(1, int(n_total * val_split)))

    def _rows(rows):
        # An empty split becomes an empty file, not a single blank row.
        return "".join(f"{row}\n" for row in rows)

    (output_dir / "metadata.csv").write_text(_rows(metadata), encoding="utf-8")
    (output_dir / "metadata_train.csv").write_text(_rows(metadata[n_val:]), encoding="utf-8")
    (output_dir / "metadata_val.csv").write_text(_rows(metadata[:n_val]), encoding="utf-8")

    # Statistics describe the rows that were actually exported.
    durations = [u["duration"] for u in exported]
    stats = {
        "total_utterances": n_total,
        "train_utterances": n_total - n_val,
        "val_utterances": n_val,
        "total_hours": round(sum(durations) / 3600, 2),
        "mean_duration_sec": round(float(np.mean(durations)), 2) if durations else 0.0,
        "median_duration_sec": round(float(np.median(durations)), 2) if durations else 0.0,
        "min_duration_sec": round(min(durations), 2) if durations else 0.0,
        "max_duration_sec": round(max(durations), 2) if durations else 0.0,
        "sample_rate": SAMPLE_RATE,
    }
    pitches = [u.get("pitch_mean_hz", 0) for u in exported if u.get("pitch_mean_hz", 0) > 0]
    if pitches:
        stats["corpus_pitch_mean_hz"] = round(float(np.mean(pitches)), 1)
        stats["corpus_pitch_std_hz"] = round(float(np.std(pitches)), 1)
    (output_dir / "dataset_stats.json").write_text(json.dumps(stats, indent=2), encoding="utf-8")

    log.info(f"\n{'='*60}")
    log.info(f"DATASET: {stats['total_utterances']} utts, {stats['total_hours']}h")
    log.info(f" Train: {stats['train_utterances']}, Val: {stats['val_utterances']}")
    log.info(f" Duration: {stats['mean_duration_sec']}s mean, {stats['median_duration_sec']}s median")
    log.info(f"{'='*60}")
    return stats
# ============================================================
# MAIN PIPELINE
# ============================================================
def process_one_video(
    wav_path: Path,
    work_dir: Path,
    whisper_model: str,
    num_speakers: int,
    skip_separation: bool,
    skip_enhancement: bool,
) -> Tuple[List[Dict], List[Dict]]:
    """Run every processing step on one video's audio.

    Steps: optional source separation, optional enhancement, diarization and
    target-speaker selection, VAD segmentation, transcription, filtering.
    Returns ``(kept_utterances, rejected_utterances)``; both are empty when
    segmentation yields nothing.
    """
    video_id = wav_path.stem

    # Steps 1-2: optional clean-up stages, each feeding the next.
    audio = wav_path if skip_separation else separate_vocals(wav_path, work_dir / "separated")
    if not skip_enhancement:
        audio = enhance_audio(audio, work_dir / "enhanced")

    # Step 3: keep only the turns of the dominant speaker.
    by_speaker = diarize_audio(audio, num_speakers)
    target_turns = by_speaker[select_target_speaker(by_speaker)]

    # Step 4: cut the turns into utterances.
    segments = segment_with_vad(audio, target_turns, work_dir / "segments" / video_id)
    if not segments:
        return [], []

    # Steps 5-6: transcribe, then apply the quality checks.
    transcribed = transcribe_utterances(segments, whisper_model)
    kept, rejected = filter_utterances(transcribed)
    return kept, rejected
def _cleanup_video_files(work_dir: Path, vid_id: str, keep: set) -> None:
    """Delete one video's intermediate files, sparing the paths in ``keep``."""
    for subdir in ("separated", "enhanced", "segments"):
        root = work_dir / subdir
        if not root.exists():
            continue
        # Snapshot the matches before deleting anything under the walked tree.
        for entry in list(root.rglob(f"{vid_id}*")):
            # The pattern also matches the per-video segments *directory*;
            # unlink() on a directory raises, so only plain files are removed.
            if entry.is_file() and entry not in keep:
                entry.unlink(missing_ok=True)


def main():
    """CLI entry point: download, process, upload, then export the dataset.

    Each video is processed independently: its kept utterances are pushed,
    the processing state is saved, and its intermediate files are removed.
    A failing video is logged and skipped. After the loop, the utterances
    kept in this run are exported in LJSpeech format and pushed together with
    the list of rejected utterances.
    """
    parser = argparse.ArgumentParser(description="Sinhala TTS Cloud Pipeline (Phase 2)")
    parser.add_argument("--source-repo", required=True, help="HF dataset repo with raw audio")
    parser.add_argument("--output-repo", required=True, help="HF dataset repo for processed output")
    parser.add_argument("--whisper-model", default="large-v3")
    parser.add_argument("--num-speakers", type=int, default=2)
    # NOTE(review): --batch-size is parsed but not read anywhere in this function.
    parser.add_argument("--batch-size", type=int, default=5, help="Videos per processing batch")
    parser.add_argument("--max-videos", type=int, default=None)
    parser.add_argument("--skip-separation", action="store_true")
    parser.add_argument("--skip-enhancement", action="store_true")
    parser.add_argument("--video-ids", type=str, default=None, help="Comma-separated video IDs to process")
    args = parser.parse_args()

    work_dir = Path("/app/work")
    work_dir.mkdir(parents=True, exist_ok=True)
    dataset_dir = work_dir / "dataset"

    log.info("=" * 60)
    log.info("Sinhala TTS Cloud Pipeline (Phase 2)")
    log.info("=" * 60)
    log.info(f"Source: {args.source_repo}")
    log.info(f"Output: {args.output_repo}")
    log.info(f"Device: {'CUDA — ' + torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")

    # Create output repo
    api = get_api()
    api.create_repo(repo_id=args.output_repo, repo_type="dataset", exist_ok=True)

    # Load state
    state = load_processing_state(args.output_repo)
    completed = set(state.setdefault("completed_videos", []))
    # Totals from earlier runs; this run's numbers are added on top of them
    # instead of overwriting them.
    prior_utterances = int(state.get("total_utterances", 0) or 0)
    prior_hours = float(state.get("total_hours", 0.0) or 0.0)
    state["total_utterances"] = prior_utterances
    state["total_hours"] = prior_hours
    log.info(f"Already completed: {len(completed)} videos")

    # Download raw audio
    video_ids = None
    if args.video_ids:
        video_ids = [v.strip() for v in args.video_ids.split(",") if v.strip()]
    raw_files = download_raw_audio(args.source_repo, work_dir, video_ids)

    # Filter out already completed
    raw_files = [f for f in raw_files if f.stem not in completed]
    if args.max_videos:
        raw_files = raw_files[:args.max_videos]
    log.info(f"To process: {len(raw_files)} videos")
    if not raw_files:
        log.info("Nothing to process!")
        return

    all_kept = []
    all_rejected = []
    for i, wav_path in enumerate(raw_files):
        vid_id = wav_path.stem
        log.info(f"\n[{i+1}/{len(raw_files)}] Processing: {vid_id}")
        try:
            kept, rejected = process_one_video(
                wav_path, work_dir, args.whisper_model, args.num_speakers,
                args.skip_separation, args.skip_enhancement,
            )

            # Upload utterances for this video. An explicit False means the
            # push failed: leave the video un-completed (and its raw file in
            # place) so a later run retries it.
            if kept and upload_utterances_batch(kept, args.output_repo, vid_id) is False:
                log.error(f" Upload failed for {vid_id}; it will be retried on the next run")
                continue

            all_kept.extend(kept)
            all_rejected.extend(rejected)

            # Update state
            state["completed_videos"].append(vid_id)
            state["total_utterances"] = prior_utterances + len(all_kept)
            state["total_hours"] = round(prior_hours + sum(u["duration"] for u in all_kept) / 3600, 2)
            save_processing_state(args.output_repo, state)
            log.info(f" TOTAL so far: {state['total_utterances']} utterances, {state['total_hours']}h")

            # Cleanup this video's intermediate files to save disk. Kept
            # utterance WAVs stay until the final export has copied them.
            _cleanup_video_files(work_dir, vid_id, {Path(u["path"]) for u in kept})
            wav_path.unlink(missing_ok=True)
        except Exception as e:
            # log.exception records the traceback in the log file as well.
            log.exception(f" FAILED: {e}")
            continue

    # Export final dataset
    # NOTE(review): only this run's utterances are exported, so the metadata
    # pushed after a resumed run presumably omits earlier runs — confirm.
    if all_kept:
        log.info(f"\n{'='*60}")
        log.info(f"EXPORTING FINAL DATASET")
        log.info(f"{'='*60}")
        stats = export_dataset(all_kept, dataset_dir)
        upload_final_dataset(all_kept, dataset_dir, args.output_repo, stats)

    # Also upload rejected for inspection
    rej_bytes = json.dumps(all_rejected, indent=2, ensure_ascii=False).encode("utf-8")
    try:
        api.upload_file(
            path_or_fileobj=rej_bytes,
            path_in_repo="rejected_utterances.json",
            repo_id=args.output_repo,
            repo_type="dataset",
            commit_message=f"Rejected: {len(all_rejected)} utterances",
        )
    except Exception as e:
        # Inspection aid only; do not lose the final summary over it.
        log.error(f"Failed to push rejected utterances: {e}")

    log.info(f"\n{'='*60}")
    log.info(f"PIPELINE COMPLETE")
    log.info(f"{'='*60}")
    log.info(f" Processed: {len(state['completed_videos'])} videos")
    log.info(f" Kept: {len(all_kept)} utterances")
    log.info(f" Rejected: {len(all_rejected)} utterances")
    log.info(f" Total hours: {state['total_hours']}")
    log.info(f" Output: https://huggingface.co/datasets/{args.output_repo}")


if __name__ == "__main__":
    main()