sinhala-tts / scripts /finalize_distributed_dataset.py
outlawmold's picture
Fix critical issues, migrate to IndicF5 fine-tuning, update pipeline
dd75f48
#!/usr/bin/env python3
"""
Build one canonical LJSpeech dataset from distributed per-video outputs.
Distributed workers write:
videos/<video_id>/dataset/metadata.csv
videos/<video_id>/dataset/wavs/*.wav
done/<video_id>.json
This script downloads only videos that have a done marker, renumbers wavs
deterministically, and writes root metadata/train/val/stats files locally.
Optionally, it can upload the finalized dataset back to the same HF dataset repo.
"""
import argparse
import json
import os
import random
import shutil
from pathlib import Path
import numpy as np
# Sample rate reported in dataset_stats.json. This script only records it; it
# does not resample or check the downloaded wavs.
SAMPLE_RATE = 22050
# Default HF dataset repo that holds both the per-video worker outputs and the
# finalized dataset (overridable with --repo).
OUTPUT_REPO = "outlawmold/sinhala-tts-dataset"
# Route Hub traffic through hf-mirror.com unless the caller already chose an
# endpoint. huggingface_hub is imported lazily inside the functions below, so
# this assignment happens before that import.
# NOTE(review): confirm the mirror accepts writes before relying on --upload.
if "HF_ENDPOINT" not in os.environ:
    os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
def list_paths(api, repo: str, prefix: str):
    """Return repo-relative paths of every file under ``prefix`` in a dataset repo.

    The listing is recursive. Folder entries are skipped: only tree items that
    expose ``rfilename`` are kept.

    Args:
        api: An ``HfApi``-like object providing ``list_repo_tree``.
        repo: Dataset repo id.
        prefix: Folder inside the repo to list (e.g. ``"done"`` or ``"wavs"``).

    Returns:
        A list of path strings. It is empty when ``prefix`` does not exist
        (HTTP 404).

    Raises:
        Any other error from the Hub client is re-raised. Previously every
        failure was reported as "no files". A transient error while listing
        ``done/`` then produced an empty dataset, which ``--upload`` would
        publish over the existing one.
    """
    try:
        # list_repo_tree is consumed inside the try, so errors raised lazily
        # while paging through the tree are handled here too.
        return [
            item.rfilename
            for item in api.list_repo_tree(
                repo,
                repo_type="dataset",
                path_in_repo=prefix,
                recursive=True,
            )
            if hasattr(item, "rfilename")
        ]
    except Exception as exc:
        response = getattr(exc, "response", None)
        if getattr(response, "status_code", None) == 404:
            # A missing folder is legitimate (e.g. nothing finalized yet).
            return []
        raise
def file_exists(api, repo: str, path_in_repo: str) -> bool:
    """Report whether ``path_in_repo`` exists in the dataset repo ``repo``.

    This is best-effort: any exception raised by the lookup is treated as
    "does not exist" rather than propagated.
    """
    query = {"repo_id": repo, "repo_type": "dataset", "filename": path_in_repo}
    try:
        found = api.file_exists(**query)
    except Exception:
        found = False
    return found
def done_video_ids(api, repo: str):
    """Return the sorted ids of videos that have a ``done/<video_id>.json`` marker.

    The id is the marker's file stem. Non-JSON files under ``done/`` are ignored.
    """
    markers = [entry for entry in list_paths(api, repo, "done") if entry.endswith(".json")]
    ids = [Path(marker).stem for marker in markers]
    ids.sort()
    return ids
def read_text_file(api, repo: str, path_in_repo: str) -> str:
    """Download ``path_in_repo`` from the dataset repo and return it as UTF-8 text.

    ``force_download=True`` is passed, so a previously cached copy is not reused.
    """
    local_copy = Path(
        api.hf_hub_download(
            repo_id=repo,
            repo_type="dataset",
            filename=path_in_repo,
            force_download=True,
        )
    )
    return local_copy.read_text(encoding="utf-8")
def download_file(api, repo: str, path_in_repo: str, local_path: Path):
    """Fetch ``path_in_repo`` from the dataset repo and copy it to ``local_path``.

    The file is downloaded with ``force_download=True``. Missing parent
    directories of ``local_path`` are created. Download errors propagate
    before anything is written locally.
    """
    fetched = api.hf_hub_download(
        repo_id=repo,
        repo_type="dataset",
        filename=path_in_repo,
        force_download=True,
    )
    destination_dir = local_path.parent
    destination_dir.mkdir(parents=True, exist_ok=True)
    shutil.copy2(fetched, local_path)
def parse_metadata(text: str):
    """Parse ``name|text|normalized`` rows from an LJSpeech-style metadata file.

    Each line is stripped of surrounding whitespace. Blank lines and lines
    with fewer than two ``|`` separators are skipped. Anything after the
    second ``|`` stays in the third field. Returns a list of 3-tuples of
    strings; the individual fields are not stripped.
    """
    cleaned = (raw.strip() for raw in text.splitlines())
    split_rows = (entry.split("|", 2) for entry in cleaned if entry)
    return [tuple(fields) for fields in split_rows if len(fields) == 3]
def _split_videos(video_ids, val_split: float, seed: int = 42):
    """Deterministically split ``video_ids`` into ``(train_ids, val_ids)`` lists.

    Ids are sorted, then shuffled with a private ``random.Random(seed)``. This
    gives the same order as ``random.seed(seed); random.shuffle(ids)`` without
    resetting the module-level RNG for the rest of the process. The first
    ``int(len(ids) * val_split)`` shuffled ids go to validation; at least one
    goes there when ``val_split > 0``. At least one video is always kept for
    training.
    """
    shuffled = sorted(video_ids)
    random.Random(seed).shuffle(shuffled)
    n_val = int(len(shuffled) * val_split) if val_split > 0 else 0
    if val_split > 0:
        n_val = max(1, n_val)
    # Never let validation take every video. A single-video run would
    # otherwise produce an empty metadata_train.csv.
    n_val = max(0, min(n_val, len(shuffled) - 1))
    return shuffled[n_val:], shuffled[:n_val]


def _write_lines(path: Path, lines) -> None:
    """Write ``lines`` to ``path`` as UTF-8, each newline-terminated (empty file if none)."""
    path.write_text("".join(f"{line}\n" for line in lines), encoding="utf-8")


def _load_soundfile():
    """Import ``soundfile`` for duration stats, or return None (with a warning) if unavailable."""
    try:
        import soundfile
    except (ImportError, OSError) as e:
        # OSError is caught too, in case the native libsndfile cannot be
        # loaded at import time.
        print(f"[warn] soundfile unavailable, duration stats will be 0 ({e})")
        return None
    return soundfile


def build_dataset(repo: str, output_dir: Path, val_split: float):
    """Merge every finished video's clips into one local LJSpeech-style dataset.

    Only videos with a ``done/<video_id>.json`` marker are used. Their clips
    are renamed ``si_000000.wav``, ``si_000001.wav``, ... in sorted-video
    order, then per-video metadata row order. Numbering is therefore
    deterministic for a given repo state. Clips whose wav cannot be
    downloaded are skipped without consuming an index.

    Files written into ``output_dir``:
    ``metadata.csv``, ``metadata_train.csv``, ``metadata_val.csv``,
    ``distributed_manifest.json`` and ``dataset_stats.json``, plus
    ``wavs/*.wav``. Wavs in ``wavs/`` that no metadata row references
    (left over from earlier runs) are deleted.

    Args:
        repo: Hugging Face dataset repo id holding the worker outputs.
        output_dir: Local destination directory; created if needed.
        val_split: Fraction of *videos* (not clips) assigned to validation.

    Returns:
        The stats dict that is also written to ``dataset_stats.json``.
    """
    from huggingface_hub import HfApi

    api = HfApi()
    output_dir.mkdir(parents=True, exist_ok=True)
    wavs_dir = output_dir / "wavs"
    wavs_dir.mkdir(exist_ok=True)
    # Resolved once up front. Previously the import sat inside a per-clip
    # bare except, so a missing package silently produced 0 hours.
    sf = _load_soundfile()

    videos = done_video_ids(api, repo)
    metadata = []
    manifest = []
    durations = []
    next_index = 0
    for video_id in videos:
        metadata_path = f"videos/{video_id}/dataset/metadata.csv"
        try:
            rows = parse_metadata(read_text_file(api, repo, metadata_path))
        except Exception as e:
            print(f"[skip] {video_id}: no metadata ({e})")
            continue
        for old_name, text, normalized in rows:
            old_wav = f"videos/{video_id}/dataset/wavs/{old_name}.wav"
            new_name = f"si_{next_index:06d}"
            new_wav = wavs_dir / f"{new_name}.wav"
            try:
                download_file(api, repo, old_wav, new_wav)
            except Exception as e:
                # next_index is not advanced, so the next clip reuses this name.
                print(f"[skip] {video_id}/{old_name}: missing wav ({e})")
                continue
            metadata.append(f"{new_name}|{text}|{normalized}")
            manifest.append({
                "name": new_name,
                "source_video": video_id,
                "source_name": old_name,
            })
            if sf is not None:
                try:
                    info = sf.info(str(new_wav))
                    durations.append(info.frames / info.samplerate)
                except Exception as e:
                    print(f"[warn] {new_name}: could not read duration ({e})")
            next_index += 1

    # Remove wavs left over from an earlier, larger run in the same output_dir.
    # Otherwise uploading output_dir would publish clips no metadata row references.
    referenced = {f"{entry['name']}.wav" for entry in manifest}
    stale = [path for path in wavs_dir.glob("*.wav") if path.name not in referenced]
    for path in stale:
        path.unlink()
    if stale:
        print(f"[clean] removed {len(stale)} stale wav(s) from {wavs_dir}")

    # Video-level train/val split to prevent data leakage between splits.
    video_to_lines = {}
    for line, entry in zip(metadata, manifest):
        video_to_lines.setdefault(entry["source_video"], []).append(line)
    train_videos, val_videos = _split_videos(video_to_lines, val_split)
    train_lines = [line for vid in train_videos for line in video_to_lines[vid]]
    val_lines = [line for vid in val_videos for line in video_to_lines[vid]]

    _write_lines(output_dir / "metadata.csv", metadata)
    _write_lines(output_dir / "metadata_train.csv", train_lines)
    _write_lines(output_dir / "metadata_val.csv", val_lines)
    (output_dir / "distributed_manifest.json").write_text(
        json.dumps(manifest, indent=2, ensure_ascii=False),
        encoding="utf-8",
    )

    stats = {
        "total_utterances": len(metadata),
        "train_utterances": len(train_lines),
        "val_utterances": len(val_lines),
        "total_videos": len(video_to_lines),
        "train_videos": len(train_videos),
        "val_videos": len(val_videos),
        "source_videos_done": len(videos),
        "source_videos_with_metadata": len({m["source_video"] for m in manifest}),
        "total_hours": round(sum(durations) / 3600, 2) if durations else 0.0,
        "mean_duration_sec": round(float(np.mean(durations)), 2) if durations else 0.0,
        "median_duration_sec": round(float(np.median(durations)), 2) if durations else 0.0,
        "min_duration_sec": round(min(durations), 2) if durations else 0.0,
        "max_duration_sec": round(max(durations), 2) if durations else 0.0,
        "sample_rate": SAMPLE_RATE,
    }
    (output_dir / "dataset_stats.json").write_text(json.dumps(stats, indent=2), encoding="utf-8")
    return stats
def upload_dataset(repo: str, output_dir: Path):
    """Replace the finalized dataset in ``repo`` with the contents of ``output_dir``.

    Everything in ``output_dir`` is uploaded in a single commit. The same
    commit deletes remote files that match ``wavs/*.wav`` or one of the
    root-level metadata/stats/manifest names and are not part of this upload.
    The delete patterns only target ``wavs/`` and those root files, so the
    per-video worker outputs (``videos/``, ``done/``) are left alone.

    Previously deletion happened in a separate commit before ``upload_folder``.
    If the upload then failed, the remote dataset was left with no wavs and no
    metadata.
    """
    from huggingface_hub import HfApi

    api = HfApi()
    root_files = [
        "metadata.csv",
        "metadata_train.csv",
        "metadata_val.csv",
        "dataset_stats.json",
        "distributed_manifest.json",
    ]
    api.upload_folder(
        folder_path=str(output_dir),
        repo_id=repo,
        repo_type="dataset",
        commit_message="Finalize distributed dataset",
        # Deletions are applied atomically with the additions. Files that are
        # re-uploaded are simply overwritten rather than deleted.
        delete_patterns=["wavs/*.wav", *root_files],
    )
def parse_args():
    """Parse command-line options for the finalizer.

    Options:
        --repo        HF dataset repo to read worker outputs from (and upload to).
        --output-dir  Local directory that receives the finalized dataset.
        --val-split   Fraction of videos assigned to the validation split.
        --upload      Push the finalized dataset back to --repo.
    """
    cli = argparse.ArgumentParser(description="Finalize distributed Sinhala TTS outputs")
    add = cli.add_argument
    add("--repo", default=OUTPUT_REPO)
    add("--output-dir", default="pipeline_output/distributed_final_dataset")
    add("--val-split", type=float, default=0.05)
    add("--upload", action="store_true")
    parsed = cli.parse_args()
    return parsed
def main():
    """Command-line entry point: build the dataset locally, then optionally upload it."""
    opts = parse_args()
    out_dir = Path(opts.output_dir)
    summary = build_dataset(opts.repo, out_dir, opts.val_split)
    print(f"Wrote {summary['total_utterances']} utterances to {out_dir}")
    print(f"Hours: {summary['total_hours']}")
    if not opts.upload:
        return
    upload_dataset(opts.repo, out_dir)
    print(f"Uploaded finalized dataset to https://huggingface.co/datasets/{opts.repo}")


# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()