#!/usr/bin/env python3
"""
=============================================================
Sinhala TTS - Phase 2: Cloud GPU Processing Pipeline
=============================================================
Runs on HF Jobs. Reads raw audio from an HF dataset repo, processes it
through the full pipeline, and pushes the results back.

Pipeline:
  1. Download raw audio from HF dataset repo
  2. Source separation (HTDemucs → vocals only)
  3. Audio enhancement (VoiceFixer + DeepFilterNet3)
  4. Speaker diarization (pyannote 3.1 / simple-diarizer fallback)
  5. VAD segmentation (Silero-VAD, 3-20s chunks)
  6. ASR transcription (Whisper large-v3)
  7. Quality filtering
  8. Export as LJSpeech-format dataset → push to Hub

Usage (on HF Jobs - configured via hf_jobs tool):
    python scripts/cloud_pipeline.py \
        --source-repo outlawmold/sinhala-tts-raw-audio \
        --output-repo outlawmold/sinhala-tts-dataset \
        --batch-size 5

Environment:
    HF_TOKEN           enables pyannote diarization (gated model) and Hub writes
    PIPELINE_LOG_FILE  log file path (default: /app/pipeline.log)
=============================================================
"""

import os
import sys
import json
import argparse
import logging
import shutil
import tempfile
import warnings
from pathlib import Path
from typing import Optional, Dict, List, Tuple

import numpy as np
import torch
import torchaudio
import soundfile as sf
from tqdm import tqdm

# Suppress all Python warnings for the whole run.
warnings.filterwarnings("ignore")

# ============================================================
# CONFIG
# ============================================================
SAMPLE_RATE = 22050      # sample rate of every exported utterance (Hz)
DIARIZE_SR = 16000       # sample rate fed to the diarization models (Hz)
MIN_SEGMENT_SEC = 3.0    # shortest utterance kept (seconds)
MAX_SEGMENT_SEC = 20.0   # longest utterance kept; longer VAD spans are split (seconds)

# Quality thresholds (IndicVoices-R + Emilia-Pipe)
SNR_THRESHOLD = 25.0        # dB; utterances below this are rejected
PITCH_MEAN_MAX = 350.0      # Hz
PITCH_STD_MAX = 150.0       # Hz
SPEAKING_RATE_MAX = 30.0    # non-punctuation characters per second
MIN_SPEECH_RATIO = 0.5      # fraction of RMS frames above the 20th-percentile floor

# Where the run log is written. Override when /app is absent or read-only.
LOG_FILE = os.environ.get("PIPELINE_LOG_FILE", "/app/pipeline.log")


def _log_handlers() -> List[logging.Handler]:
    """Return a stdout handler plus a file handler for LOG_FILE when possible.

    This runs at import time. A missing or unwritable log directory therefore
    only disables file logging (with a note on stderr) instead of raising and
    aborting the job before any work starts.
    """
    handlers: List[logging.Handler] = [logging.StreamHandler(sys.stdout)]
    try:
        Path(LOG_FILE).parent.mkdir(parents=True, exist_ok=True)
        handlers.append(logging.FileHandler(LOG_FILE))
    except OSError as e:
        print(f"WARNING: file logging disabled ({LOG_FILE}: {e})", file=sys.stderr)
    return handlers


logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    datefmt='%H:%M:%S',
    handlers=_log_handlers(),
)
log = logging.getLogger("sinhala-tts-cloud")

# ============================================================
# HUB I/O
# ============================================================
def get_api():
    """Return a new ``HfApi`` client.

    Authentication follows huggingface_hub's default token resolution
    (e.g. the ``HF_TOKEN`` environment variable or a cached login).
    """
    from huggingface_hub import HfApi
    return HfApi()


def _repo_entry_path(entry) -> str:
    """Return the repo-relative path of a Hub listing entry.

    ``HfApi.list_repo_tree`` yields ``RepoFile``/``RepoFolder`` objects, which
    expose ``.path``. ``.rfilename`` exists only on ``RepoSibling`` objects
    (e.g. from ``model_info``/``dataset_info``), so fall back to it to
    accept either type.
    """
    return getattr(entry, "path", None) or getattr(entry, "rfilename", "")


def download_raw_audio(
    source_repo: str,
    work_dir: Path,
    video_ids: List[str] = None,
    exclude_ids: Optional[set] = None,
    limit: Optional[int] = None,
) -> List[Path]:
    """Download raw ``audio/*.wav`` files from an HF dataset repo.

    Args:
        source_repo: dataset repo id with an ``audio/`` folder of WAV files.
        work_dir: working directory; files are placed in ``work_dir / "raw"``.
        video_ids: if given, only files whose stem is in this list are fetched.
        exclude_ids: stems to skip entirely (e.g. already-processed videos), so
            they are never downloaded.
        limit: maximum number of files to fetch after filtering.

    Returns:
        Local paths of the files that are available, either already present or
        freshly downloaded. Per-file download errors are logged and skipped.
    """
    from huggingface_hub import hf_hub_download

    api = get_api()
    audio_dir = work_dir / "raw"
    audio_dir.mkdir(parents=True, exist_ok=True)

    # List available audio files (non-recursive listing of the audio/ folder).
    entries = api.list_repo_tree(source_repo, repo_type="dataset", path_in_repo="audio")
    wav_files = [p for p in (_repo_entry_path(e) for e in entries) if p.endswith(".wav")]
    log.info(f"Found {len(wav_files)} audio files in {source_repo}")

    # Filter to specific video IDs if requested
    if video_ids:
        wanted = set(video_ids)
        wav_files = [p for p in wav_files if Path(p).stem in wanted]
        log.info(f"Filtered to {len(wav_files)} requested videos")
    if exclude_ids:
        wav_files = [p for p in wav_files if Path(p).stem not in exclude_ids]
    if limit is not None:
        wav_files = wav_files[:limit]

    downloaded = []
    for repo_path in wav_files:
        local_path = audio_dir / Path(repo_path).name
        if local_path.exists():
            downloaded.append(local_path)
            continue
        try:
            dl_path = hf_hub_download(
                repo_id=source_repo,
                filename=repo_path,
                repo_type="dataset",
                local_dir=str(work_dir / "_hub_cache"),
            )
            shutil.copy2(dl_path, str(local_path))
            downloaded.append(local_path)
        except Exception as e:
            log.error(f"Failed to download {repo_path}: {e}")

    log.info(f"Downloaded {len(downloaded)} audio files")
    return downloaded


def load_processing_state(output_repo: str) -> dict:
    """Load ``processing_state.json`` from the output repo.

    Returns a fresh state only when the repo or the state file does not exist.
    Any other failure (network, auth, corrupt JSON) is re-raised. Starting
    from an empty state in that case would make the next
    ``save_processing_state`` overwrite the recorded progress.
    """
    from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError

    fresh = {"completed_videos": [], "total_utterances": 0, "total_hours": 0.0}
    api = get_api()
    try:
        path = api.hf_hub_download(
            repo_id=output_repo,
            filename="processing_state.json",
            repo_type="dataset",
        )
    except (EntryNotFoundError, RepositoryNotFoundError):
        return fresh
    except Exception as e:
        log.error(f"Could not read processing state from {output_repo}: {e}")
        raise

    with open(path, encoding="utf-8") as f:
        state = json.load(f)
    # Tolerate state files written before a key existed.
    for key, default in fresh.items():
        state.setdefault(key, default)
    return state


def save_processing_state(output_repo: str, state: dict):
    """Upload *state* as ``processing_state.json`` to the output repo."""
    api = get_api()
    api.upload_file(
        path_or_fileobj=json.dumps(state, indent=2).encode("utf-8"),
        path_in_repo="processing_state.json",
        repo_id=output_repo,
        repo_type="dataset",
        commit_message=f"Update state: {len(state['completed_videos'])} videos, {state['total_hours']:.1f}h",
    )


def upload_utterances_batch(
    utterances: List[dict],
    output_repo: str,
    video_id: str,
) -> bool:
    """Upload one video's utterances (WAVs + per-video metadata) in one commit.

    WAVs go to ``wavs/<file name>``. The utterance dicts go to
    ``metadata/<video_id>.json``.

    Returns:
        True if the commit succeeded, False if it failed (the error is logged).
        Callers should not treat the video as done on False.
    """
    from huggingface_hub import CommitOperationAdd

    api = get_api()
    operations = [
        CommitOperationAdd(
            path_in_repo=f"wavs/{Path(utt['path']).name}",
            path_or_fileobj=str(utt["path"]),
        )
        for utt in utterances
        if Path(utt["path"]).exists()
    ]

    # Also upload per-video metadata
    meta_bytes = json.dumps(utterances, indent=2, ensure_ascii=False).encode("utf-8")
    operations.append(
        CommitOperationAdd(
            path_in_repo=f"metadata/{video_id}.json",
            path_or_fileobj=meta_bytes,
        )
    )

    try:
        api.create_commit(
            repo_id=output_repo,
            repo_type="dataset",
            operations=operations,
            commit_message=f"Add {len(utterances)} utterances from {video_id}",
        )
    except Exception as e:
        log.error(f" [upload] Failed to push {video_id}: {e}")
        return False
    log.info(f" [upload] Pushed {len(utterances)} utterances for {video_id}")
    return True


def upload_final_dataset(
    all_utterances: List[dict],
    dataset_dir: Path,
    output_repo: str,
    stats: dict,
) -> bool:
    """Upload the LJSpeech metadata CSVs, stats and dataset card.

    The WAVs themselves are not uploaded here; ``upload_utterances_batch``
    already pushed them per video. *all_utterances* is accepted for
    interface compatibility but is not used.

    Returns:
        True on success, False if the commit failed (the error is logged).
    """
    from huggingface_hub import CommitOperationAdd

    api = get_api()
    operations = [
        CommitOperationAdd(path_in_repo=csv_name, path_or_fileobj=str(dataset_dir / csv_name))
        for csv_name in ("metadata.csv", "metadata_train.csv", "metadata_val.csv")
        if (dataset_dir / csv_name).exists()
    ]
    operations.append(
        CommitOperationAdd(
            path_in_repo="dataset_stats.json",
            path_or_fileobj=json.dumps(stats, indent=2).encode("utf-8"),
        )
    )
    operations.append(
        CommitOperationAdd(
            path_in_repo="README.md",
            path_or_fileobj=_generate_dataset_readme(stats).encode("utf-8"),
        )
    )

    try:
        api.create_commit(
            repo_id=output_repo,
            repo_type="dataset",
            operations=operations,
            commit_message=f"Final dataset: {stats['total_utterances']} utterances, {stats['total_hours']}h",
        )
    except Exception as e:
        log.error(f"Failed to push final dataset: {e}")
        return False
    log.info(f"Final dataset pushed to {output_repo}")
    return True


def _generate_dataset_readme(stats: dict) -> str:
    """Render the dataset card (README.md) from the ``export_dataset`` stats.

    Missing stats keys render as 0.

    NOTE(review): in the reviewed copy of this file the card template was
    truncated right after ``size_categories``. The rest of the template and
    the ``separate_vocals`` signature that followed it were unreadable. The
    text below is a reconstruction; diff against version control and restore
    any wording that was lost.
    """
    def stat(key):
        return stats.get(key, 0)

    return f"""---
language:
- si
license: cc-by-4.0
task_categories:
- text-to-speech
- automatic-speech-recognition
pretty_name: Sinhala TTS Dataset (Unlimited History)
size_categories:
- 10K<n<100K
---

# Sinhala TTS Dataset (Unlimited History)

Sinhala speech segmented into {MIN_SEGMENT_SEC:g}-{MAX_SEGMENT_SEC:g} s utterances with
automatic transcriptions, in LJSpeech layout.

## Contents

| Split | Utterances |
|-------|-----------:|
| train | {stat('train_utterances')} |
| val | {stat('val_utterances')} |
| total | {stat('total_utterances')} |

- Total audio: {stat('total_hours')} h
- Sample rate: {stat('sample_rate')} Hz
- Mean / median utterance duration: {stat('mean_duration_sec')} s / {stat('median_duration_sec')} s

## Layout

- `wavs/<id>.wav`: utterance audio (mono, peak-normalized to -3 dBFS)
- `metadata.csv`, `metadata_train.csv`, `metadata_val.csv`: `id|text|normalized_text`
- `metadata/<video_id>.json`: per-video utterance details (timestamps, SNR, pitch, speaking rate)
- `dataset_stats.json`: corpus statistics
- `rejected_utterances.json`: utterances dropped by the quality filter, with reasons

## Processing pipeline

1. Source separation (HTDemucs, vocals stem)
2. Enhancement (VoiceFixer + DeepFilterNet3)
3. Speaker diarization (pyannote 3.1, simple-diarizer fallback); the speaker with the most speaking time per video is kept
4. Silero-VAD segmentation
5. Transcription with Whisper (default large-v3), language forced to Sinhala
6. Quality filtering (SNR, pitch, speaking rate, speech ratio)
"""


# ============================================================
# AUDIO PROCESSING
# ============================================================

# Heavy models (Demucs, VoiceFixer, DeepFilterNet, pyannote, Silero, Whisper)
# are created once per process and reused for every video, instead of being
# re-instantiated inside each call.
_MODEL_CACHE: Dict[object, object] = {}


def _cached_model(key, factory):
    """Return the object cached under *key*, building it with ``factory()`` on first use.

    A factory that raises is not cached, so the next call retries.
    """
    if key not in _MODEL_CACHE:
        _MODEL_CACHE[key] = factory()
    return _MODEL_CACHE[key]


def _best_device() -> torch.device:
    """Pick CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def separate_vocals(wav_path: Path, output_dir: Path) -> Path:
    """HTDemucs source separation.

    Writes ``<output_dir>/<stem>_vocals.wav`` (mono, SAMPLE_RATE) and returns
    its path. An existing output is reused as-is. On any failure the input
    path is returned unchanged, so the pipeline continues on the original audio.
    """
    output_path = output_dir / f"{wav_path.stem}_vocals.wav"
    if output_path.exists():
        return output_path

    try:
        from demucs.apply import apply_model
        from demucs.pretrained import get_model

        device = _best_device()

        def _load_demucs():
            demucs_model = get_model("htdemucs")
            demucs_model.eval()
            demucs_model.to(device)
            return demucs_model

        model = _cached_model(("htdemucs", str(device)), _load_demucs)

        waveform, sr = torchaudio.load(str(wav_path))
        # Feed exactly two channels: duplicate mono, drop extra channels.
        if waveform.shape[0] == 1:
            waveform = waveform.repeat(2, 1)
        elif waveform.shape[0] > 2:
            waveform = waveform[:2]
        if sr != model.samplerate:
            waveform = torchaudio.transforms.Resample(sr, model.samplerate)(waveform)

        # split=True processes the track in overlapping windows to bound memory on long audio.
        with torch.no_grad():
            sources = apply_model(
                model, waveform.unsqueeze(0).to(device), device=device, split=True, overlap=0.25
            )

        # Select the stem by name instead of assuming its index in model.sources.
        vocals = sources[0, model.sources.index("vocals")]
        vocals_mono = vocals.mean(dim=0, keepdim=True).cpu()
        if model.samplerate != SAMPLE_RATE:
            vocals_mono = torchaudio.transforms.Resample(model.samplerate, SAMPLE_RATE)(vocals_mono)

        output_dir.mkdir(parents=True, exist_ok=True)
        torchaudio.save(str(output_path), vocals_mono, SAMPLE_RATE)
        log.info(f" [separation] Done: {output_path.name}")
        return output_path
    except Exception as e:
        log.warning(f" [separation] Failed ({e}), using original audio")
        return wav_path


def enhance_audio(wav_path: Path, output_dir: Path) -> Path:
    """VoiceFixer + DeepFilterNet3 enhancement.

    Writes ``<output_dir>/<stem>_enhanced.wav`` and returns its path; an
    existing output is reused. Each stage is best effort. A failing stage is
    logged and skipped. If neither stage produced audio, *wav_path* is
    returned unchanged.
    """
    output_path = output_dir / f"{wav_path.stem}_enhanced.wav"
    if output_path.exists():
        return output_path
    output_dir.mkdir(parents=True, exist_ok=True)

    current_path = wav_path
    vf_output = output_dir / f"{wav_path.stem}_vf.wav"

    # Stage 1: VoiceFixer
    try:
        from voicefixer import VoiceFixer

        voicefixer = _cached_model("voicefixer", VoiceFixer)
        voicefixer.restore(
            input=str(current_path),
            output=str(vf_output),
            cuda=torch.cuda.is_available(),
            mode=0,
        )
        if vf_output.exists():
            current_path = vf_output
            log.info(" [enhance] VoiceFixer done")
    except Exception as e:
        log.warning(f" [enhance] VoiceFixer failed: {e}")

    # Stage 2: DeepFilterNet3, then resample the result to SAMPLE_RATE
    try:
        from df.enhance import enhance, init_df, load_audio, save_audio

        df_model, df_state, _ = _cached_model("deepfilternet", init_df)
        audio, _ = load_audio(str(current_path), sr=df_state.sr())
        enhanced = enhance(df_model, df_state, audio)
        save_audio(str(output_path), enhanced, df_state.sr())
        if output_path.exists():
            waveform, sr = torchaudio.load(str(output_path))
            if sr != SAMPLE_RATE:
                waveform = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(waveform)
            torchaudio.save(str(output_path), waveform, SAMPLE_RATE)
            log.info(" [enhance] DeepFilterNet3 done")
        elif current_path != wav_path:
            shutil.copy2(str(current_path), str(output_path))
    except Exception as e:
        log.warning(f" [enhance] DeepFilterNet3 failed: {e}")
        if current_path == wav_path:
            return wav_path
        # Keep the VoiceFixer result as the enhanced output.
        shutil.copy2(str(current_path), str(output_path))

    # Cleanup VoiceFixer temp
    if vf_output.exists() and output_path.exists() and vf_output != output_path:
        vf_output.unlink()
    return output_path if output_path.exists() else wav_path


def _span(start: float, end: float) -> Dict[str, float]:
    """Build a ``{"start", "end", "duration"}`` record (seconds, rounded to ms)."""
    return {"start": round(start, 3), "end": round(end, 3), "duration": round(end - start, 3)}


def _diarize_with_pyannote(
    wav_path: Path, num_speakers: int, token: str
) -> Optional[Dict[str, List[Dict]]]:
    """Diarize with pyannote 3.1. Return ``{label: [spans]}``, or None on any failure."""
    try:
        from pyannote.audio import Pipeline as PyannotePipeline

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        def _load_pipeline():
            pipe = PyannotePipeline.from_pretrained(
                "pyannote/speaker-diarization-3.1",
                use_auth_token=token,
            )
            pipe.to(device)
            return pipe

        pipeline = _cached_model(("pyannote-3.1", str(device)), _load_pipeline)

        waveform, sr = torchaudio.load(str(wav_path))
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        if sr != DIARIZE_SR:
            waveform = torchaudio.transforms.Resample(sr, DIARIZE_SR)(waveform)

        diarization = pipeline(
            {"waveform": waveform, "sample_rate": DIARIZE_SR},
            num_speakers=num_speakers,
        )

        speakers: Dict[str, List[Dict]] = {}
        for turn, _, label in diarization.itertracks(yield_label=True):
            speakers.setdefault(label, []).append(_span(turn.start, turn.end))
        log.info(f" [diarize] {len(speakers)} speakers (pyannote 3.1)")
        return speakers
    except Exception as e:
        log.warning(f" [diarize] pyannote failed: {e}")
        return None


def _diarize_with_simple_diarizer(
    wav_path: Path, num_speakers: int
) -> Optional[Dict[str, List[Dict]]]:
    """Diarize with simple-diarizer (ECAPA + spectral clustering). Return None on failure."""
    original_load = torchaudio.load
    scratch_dir = None
    try:
        from simple_diarizer.diarizer import Diarizer

        # soundfile-based stand-in for torchaudio.load (compatibility shim).
        # It is installed only while simple-diarizer runs and restored in
        # `finally`, so later pipeline stages keep the real loader.
        def _sf_load(uri, frame_offset=0, num_frames=-1, normalize=True, channels_first=True, **kwargs):
            stop = None if num_frames == -1 else frame_offset + num_frames
            data, samplerate = sf.read(uri, start=frame_offset, stop=stop, dtype='float32')
            tensor = torch.from_numpy(data)
            if tensor.ndim == 1:
                tensor = tensor.unsqueeze(0)
            elif channels_first:
                tensor = tensor.T
            return tensor, samplerate

        # Give simple-diarizer a 16 kHz mono copy, so it does not need its own
        # format conversion (which, TODO confirm, shells out to ffmpeg and
        # writes next to the input). Timestamps are in seconds, so they are unaffected.
        scratch_dir = Path(tempfile.mkdtemp(prefix="diarize_"))
        diar_input = scratch_dir / f"{wav_path.stem}_16k.wav"
        audio, sr = sf.read(str(wav_path), dtype="float32", always_2d=True)
        mono = torch.from_numpy(audio.mean(axis=1)).unsqueeze(0)
        if sr != DIARIZE_SR:
            mono = torchaudio.transforms.Resample(sr, DIARIZE_SR)(mono)
        sf.write(str(diar_input), mono.squeeze(0).numpy(), DIARIZE_SR)

        torchaudio.load = _sf_load
        diar = Diarizer(embed_model='ecapa', cluster_method='sc')
        segments = diar.diarize(str(diar_input), num_speakers=num_speakers)

        speakers: Dict[str, List[Dict]] = {}
        for seg in segments:
            speakers.setdefault(str(seg['label']), []).append(_span(seg['start'], seg['end']))
        log.info(f" [diarize] {len(speakers)} speakers (simple-diarizer)")
        return speakers
    except Exception as e:
        log.error(f" [diarize] All diarization failed: {e}")
        return None
    finally:
        torchaudio.load = original_load
        if scratch_dir is not None:
            shutil.rmtree(scratch_dir, ignore_errors=True)


def diarize_audio(wav_path: Path, num_speakers: int = 2) -> Dict[str, List[Dict]]:
    """Speaker diarization with pyannote 3.1 (fallback to simple-diarizer).

    pyannote is tried only when ``HF_TOKEN`` is set. If both back-ends fail or
    find no speech, the whole file is returned as a single speaker, so the
    result is never empty.

    Returns:
        ``{speaker_label: [{"start", "end", "duration"}, ...]}`` in seconds.
    """
    token = os.environ.get("HF_TOKEN")
    speakers = _diarize_with_pyannote(wav_path, num_speakers, token) if token else None
    if not speakers:
        speakers = _diarize_with_simple_diarizer(wav_path, num_speakers)
    if not speakers:
        import librosa
        duration = librosa.get_duration(path=str(wav_path))
        speakers = {"SPEAKER_0": [_span(0.0, duration)]}
    return speakers


def select_target_speaker(speakers: Dict[str, List[Dict]]) -> str:
    """Return the label of the speaker with the most total speaking time.

    Raises:
        ValueError: if *speakers* is empty.
    """
    if not speakers:
        raise ValueError("select_target_speaker: no speakers to choose from")
    durations = {spk: sum(s["duration"] for s in segs) for spk, segs in speakers.items()}
    best = max(durations, key=durations.get)
    log.info(f" [diarize] Target: {best} ({durations[best]/60:.1f}min / {sum(durations.values())/60:.1f}min)")
    return best


def _load_silero_vad():
    """Load the Silero VAD model and its helper functions via torch.hub."""
    return torch.hub.load(
        'snakers4/silero-vad', 'silero_vad',
        force_reload=False, trust_repo=True,
    )


def segment_with_vad(wav_path: Path, speaker_segments: List[Dict], output_dir: Path) -> List[Dict]:
    """Silero-VAD segmentation within speaker turns.

    Each speaker turn is run through VAD at 16 kHz. Nearby speech is merged
    and over-long spans are split (see ``_merge_vad_segments``). Spans between
    MIN_SEGMENT_SEC and MAX_SEGMENT_SEC are cut from the full-resolution
    audio, resampled to SAMPLE_RATE, peak-normalized to -3 dBFS and written
    to *output_dir*.

    Returns:
        One dict per written file with ``path``, ``filename``, ``start``,
        ``end`` and ``duration`` (seconds, relative to *wav_path*).
    """
    vad_sr = 16000
    output_dir.mkdir(parents=True, exist_ok=True)

    waveform, sr = torchaudio.load(str(wav_path))
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    waveform_16k = waveform if sr == vad_sr else torchaudio.transforms.Resample(sr, vad_sr)(waveform)

    vad_model, vad_utils = _cached_model("silero_vad", _load_silero_vad)
    get_speech_timestamps = vad_utils[0]

    # Loop invariants: one resampler for every utterance of this file.
    to_output_sr = torchaudio.transforms.Resample(sr, SAMPLE_RATE) if sr != SAMPLE_RATE else None
    peak_target = 10 ** (-3 / 20)
    min_turn_samples = int(MIN_SEGMENT_SEC * vad_sr)

    utterances: List[Dict] = []
    for seg in speaker_segments:
        seg_audio = waveform_16k[0, int(seg["start"] * vad_sr):int(seg["end"] * vad_sr)]
        if len(seg_audio) < min_turn_samples:
            continue

        try:
            speech_ts = get_speech_timestamps(
                seg_audio, vad_model,
                sampling_rate=vad_sr,
                min_speech_duration_ms=500,
                min_silence_duration_ms=300,
                speech_pad_ms=100,
                return_seconds=False,
            )
        except Exception:
            # Treat the whole turn as speech if VAD itself fails.
            speech_ts = [{"start": 0, "end": len(seg_audio)}]

        for vad_seg in _merge_vad_segments(speech_ts, sr=vad_sr):
            start_sec = seg["start"] + vad_seg["start"] / vad_sr
            end_sec = seg["start"] + vad_seg["end"] / vad_sr
            duration = end_sec - start_sec
            if duration < MIN_SEGMENT_SEC or duration > MAX_SEGMENT_SEC:
                continue

            utt_audio = waveform[:, int(start_sec * sr):int(end_sec * sr)]
            if utt_audio.shape[-1] == 0:
                continue
            if to_output_sr is not None:
                utt_audio = to_output_sr(utt_audio)
            peak = utt_audio.abs().max()
            if peak > 0:
                utt_audio = utt_audio * (peak_target / peak)

            utt_name = f"{wav_path.stem}_utt{len(utterances):05d}.wav"
            utt_path = output_dir / utt_name
            torchaudio.save(str(utt_path), utt_audio, SAMPLE_RATE)
            utterances.append({
                "path": str(utt_path),
                "filename": utt_name,
                "start": round(start_sec, 3),
                "end": round(end_sec, 3),
                "duration": round(duration, 3),
            })

    log.info(f" [vad] {len(utterances)} utterances ({sum(u['duration'] for u in utterances)/60:.1f}min)")
    return utterances


def _merge_vad_segments(segments, sr=16000, gap_ms=500):
    """Merge VAD spans separated by less than *gap_ms*, then split over-long spans.

    Args:
        segments: ``[{"start", "end"}, ...]`` in samples, assumed sorted by start.
        sr: sample rate of the sample indices.
        gap_ms: spans closer than this are joined.

    Returns:
        Spans in samples. A span longer than MAX_SEGMENT_SEC is cut into
        MAX_SEGMENT_SEC pieces, and a trailing piece shorter than
        MIN_SEGMENT_SEC is dropped.
    """
    if not segments:
        return []

    gap_samples = int(gap_ms * sr / 1000)
    merged = [{"start": segments[0]["start"], "end": segments[0]["end"]}]
    for seg in segments[1:]:
        if seg["start"] - merged[-1]["end"] < gap_samples:
            merged[-1]["end"] = seg["end"]
        else:
            merged.append({"start": seg["start"], "end": seg["end"]})

    chunk = int(MAX_SEGMENT_SEC * sr)
    final = []
    for seg in merged:
        if (seg["end"] - seg["start"]) / sr <= MAX_SEGMENT_SEC:
            final.append(seg)
            continue
        pos = seg["start"]
        while pos < seg["end"]:
            end = min(pos + chunk, seg["end"])
            if (end - pos) / sr >= MIN_SEGMENT_SEC:
                final.append({"start": pos, "end": end})
            pos = end
    return final


def transcribe_utterances(utterances: List[Dict], model_size: str = "large-v3") -> List[Dict]:
    """Whisper (faster-whisper) transcription with the language forced to Sinhala.

    Sets ``text`` and ``language_prob`` on each utterance dict in place and
    returns the same list. An utterance that fails gets ``text=""`` and
    ``language_prob=0.0``; failures are counted and logged. If faster-whisper
    is not installed, the list is returned untouched.
    """
    try:
        from faster_whisper import WhisperModel
    except ImportError:
        log.error(" [asr] faster-whisper not installed!")
        return utterances

    device = "cuda" if torch.cuda.is_available() else "cpu"
    compute_type = "float16" if device == "cuda" else "int8"
    cache_key = ("whisper", model_size, device, compute_type)
    if cache_key not in _MODEL_CACHE:
        log.info(f" [asr] Loading faster-whisper {model_size} on {device}...")
    model = _cached_model(
        cache_key, lambda: WhisperModel(model_size, device=device, compute_type=compute_type)
    )

    failures = 0
    for utt in tqdm(utterances, desc="Transcribing", leave=False):
        try:
            segments, info = model.transcribe(
                utt["path"],
                language="si",
                beam_size=5,
                best_of=5,
                temperature=0.0,
                condition_on_previous_text=False,
                vad_filter=False,
            )
            # `segments` is lazy; decoding happens while it is consumed here.
            utt["text"] = " ".join(seg.text.strip() for seg in segments).strip()
            # NOTE(review): faster-whisper appears to report language_probability
            # as 1 when `language` is given explicitly. If so, the low_lang_prob
            # check in filter_utterances never fires. Confirm against the
            # installed version.
            utt["language_prob"] = info.language_probability
        except Exception as e:
            failures += 1
            log.debug(f" [asr] {utt.get('filename', utt['path'])} failed: {e}")
            utt["text"] = ""
            utt["language_prob"] = 0.0

    if failures:
        log.warning(f" [asr] {failures}/{len(utterances)} utterances failed to transcribe")
    return utterances


def _load_mono(wav_path: str) -> Tuple[np.ndarray, int]:
    """Load *wav_path* with librosa as mono audio resampled to SAMPLE_RATE."""
    import librosa
    return librosa.load(wav_path, sr=SAMPLE_RATE, mono=True)


def _frame_rms(y: np.ndarray) -> np.ndarray:
    """Frame-wise RMS of *y* (2048-sample frames, 512 hop)."""
    import librosa
    return librosa.feature.rms(y=y, frame_length=2048, hop_length=512)[0]


def _snr_db(y: np.ndarray) -> float:
    """Rough SNR in dB: mean RMS above vs. at-or-below the 20th-percentile frame.

    Returns 40.0 when the noise floor is effectively zero. Returns 0.0 when
    no frame rises above the floor (previously this produced NaN, which also
    slipped past the ``snr < SNR_THRESHOLD`` check).
    """
    rms = _frame_rms(y)
    threshold = np.percentile(rms, 20)
    noise = rms[rms <= threshold]
    speech = rms[rms > threshold]
    if len(noise) == 0 or np.mean(noise) <= 1e-10:
        return 40.0
    if len(speech) == 0:
        return 0.0
    return float(20 * np.log10(np.mean(speech) / np.mean(noise)))


def _pitch_stats(y: np.ndarray, sr: int) -> Tuple[float, float]:
    """Mean and std of voiced pYIN F0 (50-500 Hz), or (0.0, 0.0) if nothing is voiced."""
    import librosa
    f0, _, _ = librosa.pyin(y, fmin=50, fmax=500, sr=sr)
    voiced = f0[~np.isnan(f0)]
    if len(voiced) == 0:
        return 0.0, 0.0
    return float(np.mean(voiced)), float(np.std(voiced))


def _speech_ratio(y: np.ndarray) -> float:
    """Fraction of RMS frames above the 20th-percentile frame level."""
    rms = _frame_rms(y)
    threshold = np.percentile(rms, 20)
    return float(np.sum(rms > threshold) / len(rms))


def compute_snr(wav_path: str) -> float:
    """Estimate the SNR (dB) of a WAV file; see ``_snr_db``."""
    y, _ = _load_mono(wav_path)
    return _snr_db(y)


def compute_pitch_stats(wav_path: str) -> Tuple[float, float]:
    """Return (mean, std) of voiced F0 in Hz for a WAV file; see ``_pitch_stats``."""
    y, sr = _load_mono(wav_path)
    return _pitch_stats(y, sr)


def filter_utterances(utterances: List[Dict]) -> Tuple[List[Dict], List[Dict]]:
    """Quality filtering.

    Annotates each utterance with ``snr_db``, ``pitch_mean_hz``,
    ``pitch_std_hz``, ``speaking_rate`` (when it has text) and
    ``speech_ratio``. Rejected utterances also get ``reject_reasons``.

    Returns:
        ``(kept, rejected)``.
    """
    from collections import Counter

    kept, rejected = [], []
    for utt in tqdm(utterances, desc="Filtering", leave=False):
        reasons: List[str] = []
        text = utt.get("text") or ""

        if not text.strip():
            reasons.append("empty_text")
        lang_prob = utt.get("language_prob", 0)
        if lang_prob < 0.5:
            reasons.append(f"low_lang_prob={lang_prob:.2f}")
        duration = utt["duration"]
        if duration < MIN_SEGMENT_SEC or duration > MAX_SEGMENT_SEC:
            reasons.append(f"duration={duration:.1f}s")

        # Decode once; the SNR, pitch and speech-ratio checks all reuse it.
        try:
            y, sr = _load_mono(utt["path"])
        except Exception:
            y, sr = None, SAMPLE_RATE

        snr = None
        if y is not None:
            try:
                snr = _snr_db(y)
            except Exception:
                snr = None
        if snr is None:
            utt["snr_db"] = 0.0
            reasons.append("snr_failed")
        else:
            utt["snr_db"] = round(snr, 1)
            if snr < SNR_THRESHOLD:
                reasons.append(f"low_snr={snr:.1f}dB")

        pitch_mean, pitch_std = 0.0, 0.0
        if y is not None:
            try:
                pitch_mean, pitch_std = _pitch_stats(y, sr)
            except Exception:
                pitch_mean, pitch_std = 0.0, 0.0
        utt["pitch_mean_hz"] = round(pitch_mean, 1)
        utt["pitch_std_hz"] = round(pitch_std, 1)
        if pitch_mean > PITCH_MEAN_MAX:
            reasons.append(f"high_pitch={pitch_mean:.0f}Hz")
        if pitch_std > PITCH_STD_MAX:
            reasons.append(f"high_pitch_var={pitch_std:.0f}Hz")

        if text:
            # Count non-whitespace, non-punctuation characters per second.
            chars = len([c for c in text if c.strip() and c not in "!?.,;:\"'()-"])
            rate = chars / duration if duration > 0 else 0
            utt["speaking_rate"] = round(rate, 1)
            if rate > SPEAKING_RATE_MAX:
                reasons.append(f"fast_speech={rate:.1f}c/s")
            if rate < 1.0 and duration > 3.0:
                reasons.append(f"slow_speech={rate:.1f}c/s")

        speech_ratio = 0.0
        if y is not None:
            try:
                speech_ratio = _speech_ratio(y)
            except Exception:
                speech_ratio = 0.0
            else:
                if speech_ratio < MIN_SPEECH_RATIO:
                    reasons.append(f"low_speech_ratio={speech_ratio:.2f}")
        utt["speech_ratio"] = round(speech_ratio, 3)

        if reasons:
            utt["reject_reasons"] = reasons
            rejected.append(utt)
        else:
            kept.append(utt)

    log.info(f" [filter] Kept {len(kept)}/{len(utterances)} ({len(kept)/max(1,len(utterances))*100:.1f}%)")
    if rejected:
        reason_counts = Counter(
            reason.split("=")[0] for u in rejected for reason in u.get("reject_reasons", [])
        )
        log.info(f" [filter] Rejections: {dict(reason_counts)}")
    return kept, rejected


def normalize_sinhala_text(text: str) -> str:
    """Normalize a transcript for the ``normalized_text`` column.

    Applies NFC, removes ZWNJ (U+200C), maps curly quotes to ASCII, turns
    ``;``/``:`` into commas, drops round/square brackets and collapses
    whitespace.
    """
    import unicodedata
    text = unicodedata.normalize('NFC', text)
    text = text.replace('\u200C', '')
    text = text.replace('\u201c', '"').replace('\u201d', '"')
    text = text.replace('\u2018', "'").replace('\u2019', "'")
    text = text.replace(';', ',').replace(':', ',')
    text = text.replace('(', '').replace(')', '')
    text = text.replace('[', '').replace(']', '')
    return ' '.join(text.split()).strip()


def _csv_safe(text: str) -> str:
    """Make *text* safe for one LJSpeech field: no ``|`` separators or line breaks."""
    return text.replace("|", " ").replace("\r", " ").replace("\n", " ").strip()


def export_dataset(utterances: List[Dict], output_dir: Path, val_split: float = 0.05) -> dict:
    """Export as LJSpeech format.

    Writes ``metadata.csv``, ``metadata_train.csv``, ``metadata_val.csv``
    (``id|text|normalized_text``), ``dataset_stats.json`` and copies any
    still-present WAVs into ``output_dir/wavs``. Utterances without text
    are skipped.

    Each row id is the utterance's own file stem, which is the name
    ``upload_utterances_batch`` pushed to ``wavs/`` on the Hub, so rows resolve
    to files in the output repo. An utterance without a ``filename`` falls
    back to ``si_<index>``.

    Returns:
        The stats dict (also written to ``dataset_stats.json``).
    """
    import random

    output_dir.mkdir(parents=True, exist_ok=True)
    wavs_dir = output_dir / "wavs"
    wavs_dir.mkdir(exist_ok=True)

    rows: List[str] = []
    exported: List[Dict] = []
    for i, utt in enumerate(tqdm(utterances, desc="Exporting", leave=False)):
        text = _csv_safe(utt.get("text") or "")
        if not text:
            continue
        filename = utt.get("filename")
        name = Path(filename).stem if filename else f"si_{i:06d}"
        src = Path(utt["path"]) if utt.get("path") else None
        dst = wavs_dir / f"{name}.wav"
        if src is not None and src.exists() and not dst.exists():
            shutil.copy2(str(src), str(dst))
        rows.append(f"{name}|{text}|{normalize_sinhala_text(text)}")
        exported.append(utt)

    # Private RNG: same deterministic shuffle, without reseeding global `random`.
    random.Random(42).shuffle(rows)
    # At least one validation row, but never the whole (non-empty) set.
    n_val = min(max(1, int(len(rows) * val_split)), max(0, len(rows) - 1))

    def _write_csv(csv_name: str, lines: List[str]) -> None:
        (output_dir / csv_name).write_text("".join(f"{line}\n" for line in lines), encoding="utf-8")

    _write_csv("metadata.csv", rows)
    _write_csv("metadata_train.csv", rows[n_val:])
    _write_csv("metadata_val.csv", rows[:n_val])

    # Stats describe only the rows that were actually exported.
    durations = [float(u["duration"]) for u in exported]
    stats = {
        "total_utterances": len(rows),
        "train_utterances": len(rows) - n_val,
        "val_utterances": n_val,
        "total_hours": round(sum(durations) / 3600, 2),
        "mean_duration_sec": round(float(np.mean(durations)), 2) if durations else 0.0,
        "median_duration_sec": round(float(np.median(durations)), 2) if durations else 0.0,
        "min_duration_sec": round(min(durations), 2) if durations else 0.0,
        "max_duration_sec": round(max(durations), 2) if durations else 0.0,
        "sample_rate": SAMPLE_RATE,
    }
    pitches = [u.get("pitch_mean_hz", 0) for u in exported if u.get("pitch_mean_hz", 0) > 0]
    if pitches:
        stats["corpus_pitch_mean_hz"] = round(float(np.mean(pitches)), 1)
        stats["corpus_pitch_std_hz"] = round(float(np.std(pitches)), 1)

    (output_dir / "dataset_stats.json").write_text(json.dumps(stats, indent=2))

    log.info(f"\n{'='*60}")
    log.info(f"DATASET: {stats['total_utterances']} utts, {stats['total_hours']}h")
    log.info(f" Train: {stats['train_utterances']}, Val: {stats['val_utterances']}")
    log.info(f" Duration: {stats['mean_duration_sec']}s mean, {stats['median_duration_sec']}s median")
    log.info(f"{'='*60}")
    return stats


# ============================================================
# MAIN PIPELINE
# ============================================================

def process_one_video(
    wav_path: Path,
    work_dir: Path,
    whisper_model: str,
    num_speakers: int,
    skip_separation: bool,
    skip_enhancement: bool,
) -> Tuple[List[Dict], List[Dict]]:
    """Full pipeline for one video. Returns (kept_utterances, rejected_utterances).

    Stages: separation → enhancement → diarization (dominant speaker only)
    → VAD segmentation → transcription → quality filtering.
    """
    vid_id = wav_path.stem

    # Step 1: Source separation
    current_audio = wav_path
    if not skip_separation:
        current_audio = separate_vocals(wav_path, work_dir / "separated")

    # Step 2: Enhancement
    if not skip_enhancement:
        current_audio = enhance_audio(current_audio, work_dir / "enhanced")

    # Step 3: Diarization
    speakers = diarize_audio(current_audio, num_speakers)
    target = select_target_speaker(speakers)

    # Step 4: VAD segmentation
    utterances = segment_with_vad(current_audio, speakers[target], work_dir / "segments" / vid_id)
    if not utterances:
        return [], []

    # Step 5: Transcription
    utterances = transcribe_utterances(utterances, whisper_model)

    # Step 6: Quality filtering
    return filter_utterances(utterances)


def _load_previous_utterances(output_repo: str, video_ids) -> List[Dict]:
    """Return the kept utterances that earlier runs uploaded for *video_ids*.

    Reads the ``metadata/<video_id>.json`` files written by
    ``upload_utterances_batch``, so the final CSVs cover every completed
    video, not only those processed in this run. Best effort: failures are
    logged and whatever could be read is returned.
    """
    if not video_ids:
        return []
    from huggingface_hub import snapshot_download

    try:
        local_root = Path(snapshot_download(
            repo_id=output_repo,
            repo_type="dataset",
            allow_patterns=["metadata/*.json"],
        ))
    except Exception as e:
        log.warning(f"Could not fetch previous metadata from {output_repo}: {e}")
        return []

    previous: List[Dict] = []
    for vid in sorted(video_ids):
        meta_path = local_root / "metadata" / f"{vid}.json"
        if not meta_path.exists():
            continue  # videos with no kept utterances never uploaded metadata
        try:
            previous.extend(json.loads(meta_path.read_text(encoding="utf-8")))
        except (OSError, ValueError) as e:
            log.warning(f"Skipping unreadable metadata for {vid}: {e}")
    log.info(f"Loaded {len(previous)} utterances from previously completed videos")
    return previous


def _cleanup_video(work_dir: Path, vid_id: str, wav_path: Path) -> None:
    """Delete one video's intermediate files to save disk.

    Removes files named ``<vid_id>*`` under ``separated``/``enhanced``, the
    ``segments/<vid_id>`` directory, the raw WAV and its hub-download copy.
    Directories are never passed to ``unlink``. Previously, ``unlink`` on the
    ``segments/<vid_id>`` directory raised and aborted cleanup.
    """
    for subdir in ("separated", "enhanced"):
        d = work_dir / subdir
        if d.exists():
            for f in d.rglob(f"{vid_id}*"):
                if f.is_file():
                    f.unlink(missing_ok=True)
    shutil.rmtree(work_dir / "segments" / vid_id, ignore_errors=True)
    (work_dir / "_hub_cache" / "audio" / wav_path.name).unlink(missing_ok=True)
    wav_path.unlink(missing_ok=True)


def main():
    """CLI entry point: process pending videos, upload results, publish the dataset."""
    parser = argparse.ArgumentParser(description="Sinhala TTS Cloud Pipeline (Phase 2)")
    parser.add_argument("--source-repo", required=True, help="HF dataset repo with raw audio")
    parser.add_argument("--output-repo", required=True, help="HF dataset repo for processed output")
    parser.add_argument("--whisper-model", default="large-v3")
    parser.add_argument("--num-speakers", type=int, default=2)
    # NOTE: accepted for CLI compatibility but currently unused; videos are
    # processed and uploaded one at a time.
    parser.add_argument("--batch-size", type=int, default=5, help="Videos per processing batch")
    parser.add_argument("--max-videos", type=int, default=None)
    parser.add_argument("--skip-separation", action="store_true")
    parser.add_argument("--skip-enhancement", action="store_true")
    parser.add_argument("--video-ids", type=str, default=None, help="Comma-separated video IDs to process")
    args = parser.parse_args()

    work_dir = Path("/app/work")
    work_dir.mkdir(parents=True, exist_ok=True)
    dataset_dir = work_dir / "dataset"

    log.info("=" * 60)
    log.info("Sinhala TTS Cloud Pipeline (Phase 2)")
    log.info("=" * 60)
    log.info(f"Source: {args.source_repo}")
    log.info(f"Output: {args.output_repo}")
    log.info(f"Device: {'CUDA — ' + torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")

    # Create output repo
    api = get_api()
    api.create_repo(repo_id=args.output_repo, repo_type="dataset", exist_ok=True)

    # Load state
    state = load_processing_state(args.output_repo)
    completed = set(state["completed_videos"])
    # Totals recorded by earlier runs; this run's numbers are added on top.
    base_utterances = int(state.get("total_utterances", 0))
    base_hours = float(state.get("total_hours", 0.0))
    log.info(f"Already completed: {len(completed)} videos")

    # Download raw audio (completed videos are skipped before download)
    video_ids = None
    if args.video_ids:
        video_ids = [v.strip() for v in args.video_ids.split(",") if v.strip()] or None
    raw_files = download_raw_audio(
        args.source_repo, work_dir, video_ids,
        exclude_ids=completed,
        limit=args.max_videos or None,
    )
    log.info(f"To process: {len(raw_files)} videos")
    if not raw_files:
        log.info("Nothing to process!")
        return

    all_kept: List[Dict] = []
    all_rejected: List[Dict] = []
    for i, wav_path in enumerate(raw_files, start=1):
        vid_id = wav_path.stem
        log.info(f"\n[{i}/{len(raw_files)}] Processing: {vid_id}")
        try:
            kept, rejected = process_one_video(
                wav_path, work_dir, args.whisper_model,
                args.num_speakers, args.skip_separation, args.skip_enhancement,
            )

            # Upload utterances for this video; a failed upload leaves the
            # video pending so a later run retries it.
            if kept and not upload_utterances_batch(kept, args.output_repo, vid_id):
                log.error(f" Upload failed for {vid_id}; not marking it completed")
                continue

            all_kept.extend(kept)
            all_rejected.extend(rejected)

            # Update state (cumulative across runs)
            state["completed_videos"].append(vid_id)
            state["total_utterances"] = base_utterances + len(all_kept)
            state["total_hours"] = round(base_hours + sum(u["duration"] for u in all_kept) / 3600, 2)
            save_processing_state(args.output_repo, state)
            log.info(f" TOTAL so far: {state['total_utterances']} utterances, {state['total_hours']}h")

            _cleanup_video(work_dir, vid_id, wav_path)
        except Exception as e:
            log.error(f" FAILED: {e}")
            import traceback
            traceback.print_exc()
            continue

    # Export final dataset (this run's utterances + those of earlier runs)
    if all_kept:
        log.info(f"\n{'='*60}")
        log.info("EXPORTING FINAL DATASET")
        log.info(f"{'='*60}")
        dataset_utterances = _load_previous_utterances(args.output_repo, completed) + all_kept
        stats = export_dataset(dataset_utterances, dataset_dir)
        upload_final_dataset(dataset_utterances, dataset_dir, args.output_repo, stats)

        # Also upload rejected for inspection
        rej_bytes = json.dumps(all_rejected, indent=2, ensure_ascii=False).encode("utf-8")
        try:
            api.upload_file(
                path_or_fileobj=rej_bytes,
                path_in_repo="rejected_utterances.json",
                repo_id=args.output_repo,
                repo_type="dataset",
                commit_message=f"Rejected: {len(all_rejected)} utterances",
            )
        except Exception as e:
            log.error(f"Failed to push rejected utterances: {e}")

    log.info(f"\n{'='*60}")
    log.info("PIPELINE COMPLETE")
    log.info(f"{'='*60}")
    log.info(f" Processed: {len(state['completed_videos'])} videos")
    log.info(f" Kept: {len(all_kept)} utterances")
    log.info(f" Rejected: {len(all_rejected)} utterances")
    log.info(f" Total hours: {state['total_hours']}")
    log.info(f" Output: https://huggingface.co/datasets/{args.output_repo}")


if __name__ == "__main__":
    main()