# whisper-captioning-ensemble / scripts/sample_datasets.py
# Uploaded by ChristophSchuhmann with huggingface_hub (commit 4aed451, verified).
"""Pick a random tar file from each of three HF datasets and extract 6 random
audio samples per dataset.
Datasets:
1. mitermix/audioset-with-grounded-captions (mp3 in audio-dataset-XXXXXX.tar)
2. laion/captioned-ai-music-snippets (mp3 in suno-dataset-XXXXXX.tar)
3. TTS-AGI/majestrino-unified-detailed-captions-temporal
(flac in data/XXXXX.tar)
Output: ./samples/<dataset_short_name>/<key>.<ext> (only the audio file is kept,
the tar archive is deleted afterwards to save disk).
"""
from __future__ import annotations
import os
import random
import shutil
import sys
import tarfile
from pathlib import Path
from huggingface_hub import HfFileSystem, hf_hub_download
# Output layout: <root>/samples/, where <root> is the parent of the folder
# holding this script. The directory is created at import time.
ROOT = Path(__file__).resolve().parent.parent
SAMPLES_DIR = ROOT.joinpath("samples")
SAMPLES_DIR.mkdir(exist_ok=True, parents=True)

# Seed the module-global RNG once so tar and member picks are repeatable.
random.seed(20260411)

# One entry per source dataset:
#   repo_id  - Hugging Face dataset repo to sample from
#   short    - output sub-directory name and output filename prefix
#   tar_glob - glob, relative to the repo root, selecting the tar shards
DATASETS = [
    dict(
        repo_id="mitermix/audioset-with-grounded-captions",
        short="audioset",
        tar_glob="*.tar",  # shards live at the repo root
    ),
    dict(
        repo_id="laion/captioned-ai-music-snippets",
        short="music",
        tar_glob="*.tar",
    ),
    dict(
        repo_id="TTS-AGI/majestrino-unified-detailed-captions-temporal",
        short="majestrino",
        tar_glob="data/*.tar",  # shards live under data/
    ),
]

# File extensions (lower-case, with leading dot) treated as audio.
AUDIO_EXTS = {".mp3", ".flac", ".wav", ".ogg", ".m4a", ".opus"}

# Number of audio files to extract for each dataset.
SAMPLES_PER_DATASET = 6
def list_tars(repo_id: str, glob_pat: str) -> list[str]:
    """Return the repo-relative paths of tar files in a HF dataset repo.

    Args:
        repo_id: Dataset repo id, e.g. ``"laion/captioned-ai-music-snippets"``.
        glob_pat: Glob evaluated relative to the repo root, e.g. ``"*.tar"``
            or ``"data/*.tar"``.

    Returns:
        Matching paths with the ``datasets/<repo_id>/`` prefix stripped, in
        sorted order. Sorting makes a seeded ``random.choice`` over the result
        reproducible regardless of the order the remote listing comes back in.
    """
    fs = HfFileSystem()
    prefix = f"datasets/{repo_id}/"
    found = fs.glob(prefix + glob_pat)
    # Only strip the prefix from paths that actually carry it; slicing a path
    # without it would silently produce a bogus filename.
    return sorted(p[len(prefix):] for p in found if p.startswith(prefix))
def pick_audio_members(tar_path: Path, k: int) -> list[tarfile.TarInfo]:
    """Choose up to *k* random audio entries from the tar archive at *tar_path*.

    An entry counts as audio when it is a regular file whose lower-cased
    extension is in ``AUDIO_EXTS``. If there are at most *k* such entries, all
    of them are returned in archive order. Otherwise *k* are drawn without
    replacement with the module-global ``random`` RNG.

    Raises:
        RuntimeError: the archive contains no audio entries.
    """
    def _is_audio(info: tarfile.TarInfo) -> bool:
        return info.isfile() and Path(info.name).suffix.lower() in AUDIO_EXTS

    with tarfile.open(tar_path, "r") as archive:
        candidates = list(filter(_is_audio, archive.getmembers()))

    if not candidates:
        raise RuntimeError(f"No audio members found in {tar_path}")
    return candidates if k >= len(candidates) else random.sample(candidates, k)
def extract_members(tar_path: Path, members: list[tarfile.TarInfo],
                    out_dir: Path, short: str) -> list[Path]:
    """Copy the given tar members into *out_dir* as flat, prefixed files.

    Each member is written to ``out_dir / f"{short}__{stem}{ext}"``, where
    ``stem`` and ``ext`` come from the member's base name (the extension is
    lower-cased). Directory components inside the archive are dropped. If
    two members passed in the same call flatten to the same name, later ones
    get a ``_1``, ``_2``, ... suffix instead of silently overwriting the
    earlier file. Files left on disk by a previous run are overwritten.

    Args:
        tar_path: Path of the tar archive the members belong to.
        members: Entries to extract, e.g. as returned by ``pick_audio_members``.
        out_dir: Destination directory; created if missing.
        short: Dataset short name used as the filename prefix.

    Returns:
        The paths written, in the order of *members*. Members that have no
        data stream (``TarFile.extractfile`` returns ``None``, e.g.
        directories) are skipped.
    """
    out_paths: list[Path] = []
    used_names: set[str] = set()
    out_dir.mkdir(parents=True, exist_ok=True)
    with tarfile.open(tar_path, "r") as tf:
        for m in members:
            src = tf.extractfile(m)
            # extractfile() returns None for members without file data. Check
            # before entering ``with``: None is not a context manager.
            if src is None:
                continue
            base = Path(m.name).name
            stem = Path(base).stem
            ext = Path(base).suffix.lower()
            name = f"{short}__{stem}{ext}"
            # Flattening can map members from different archive directories
            # onto the same name; disambiguate within this call.
            n = 0
            while name in used_names:
                n += 1
                name = f"{short}__{stem}_{n}{ext}"
            used_names.add(name)
            out_path = out_dir / name
            with src, open(out_path, "wb") as dst:
                shutil.copyfileobj(src, dst)
            out_paths.append(out_path)
    return out_paths
def _drop_download(local_tar: Path | None, cache_dir: Path) -> None:
    """Best-effort removal of a downloaded tar and of the whole download cache.

    Errors are deliberately ignored: cleanup only reclaims disk space and must
    not abort sampling of the remaining datasets.
    """
    if local_tar is not None:
        try:
            local_tar.unlink()
        except OSError:
            pass
    # Done independently of the unlink above, so a failed unlink does not
    # leave the cache behind.
    # NOTE(review): the downloaded path may be a link into the cache's own
    # storage, in which case only this rmtree actually frees the space.
    shutil.rmtree(cache_dir, ignore_errors=True)


def main() -> int:
    """Sample audio from every dataset in ``DATASETS`` and write a manifest.

    For each dataset:
    - If its output directory already holds at least ``SAMPLES_PER_DATASET``
      audio files, the first ones (sorted) are reused.
    - Otherwise one tar is chosen at random, downloaded into
      ``SAMPLES_DIR / "_cache"``, and up to ``SAMPLES_PER_DATASET`` audio
      members are extracted. The cache is then removed, even if extraction
      failed.

    ``SAMPLES_DIR / "manifest.tsv"`` is always written, with one
    ``<short>\\t<path>`` line per sample.

    Returns:
        Process exit code: 0 if every dataset yielded samples, 1 if at least
        one dataset had no tar files or no audio in the chosen tar.
    """
    cache_dir = SAMPLES_DIR / "_cache"
    manifest_lines: list[str] = []
    failed: list[str] = []
    for ds in DATASETS:
        repo_id = ds["repo_id"]
        short = ds["short"]
        out_dir = SAMPLES_DIR / short

        existing = sorted(out_dir.glob("*")) if out_dir.exists() else []
        existing_audio = [p for p in existing if p.suffix.lower() in AUDIO_EXTS]
        if len(existing_audio) >= SAMPLES_PER_DATASET:
            print(f"[skip] {short}: already have {len(existing_audio)} samples")
            manifest_lines.extend(
                f"{short}\t{p}" for p in existing_audio[:SAMPLES_PER_DATASET]
            )
            # NOTE(review): skipped datasets draw nothing from the global
            # RNG, so picks for later datasets differ from a fresh full run.
            continue
        # NOTE(review): a partially filled out_dir (fewer than
        # SAMPLES_PER_DATASET files) gets a full new batch on top, so it can
        # end up holding more files than the manifest lists.

        print(f"[list] {repo_id}: looking up tar files...")
        tars = list_tars(repo_id, ds["tar_glob"])
        if not tars:
            print(f"[error] No tar files found in {repo_id}", file=sys.stderr)
            failed.append(short)
            continue
        print(f" found {len(tars)} tar files")
        chosen = random.choice(tars)
        print(f"[pick] {chosen}")
        print(f"[download] {repo_id}::{chosen}")

        local_tar = None
        try:
            local_tar = Path(hf_hub_download(
                repo_id=repo_id,
                filename=chosen,
                repo_type="dataset",
                cache_dir=str(cache_dir),
            ))
            print(f" -> {local_tar} ({local_tar.stat().st_size/1e6:.1f} MB)")
            print(f"[extract] picking {SAMPLES_PER_DATASET} random audio members")
            members = pick_audio_members(local_tar, SAMPLES_PER_DATASET)
            out_paths = extract_members(local_tar, members, out_dir, short)
        except RuntimeError as err:
            # pick_audio_members raises RuntimeError when the tar has no audio;
            # record it and carry on with the other datasets.
            print(f"[error] {short}: {err}", file=sys.stderr)
            failed.append(short)
            continue
        finally:
            # Reclaim disk whether or not extraction succeeded.
            _drop_download(local_tar, cache_dir)

        for p in out_paths:
            print(f" {p}")
            manifest_lines.append(f"{short}\t{p}")

    manifest_path = SAMPLES_DIR / "manifest.tsv"
    manifest_path.write_text(
        "".join(f"{line}\n" for line in manifest_lines), encoding="utf-8"
    )
    print(f"\nWrote manifest with {len(manifest_lines)} samples to {manifest_path}")
    if failed:
        print(f"[error] no samples for: {', '.join(failed)}", file=sys.stderr)
        return 1
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())