"""Pick a random tar file from each of three HF datasets and extract 6 random audio samples per dataset. Datasets: 1. mitermix/audioset-with-grounded-captions (mp3 in audio-dataset-XXXXXX.tar) 2. laion/captioned-ai-music-snippets (mp3 in suno-dataset-XXXXXX.tar) 3. TTS-AGI/majestrino-unified-detailed-captions-temporal (flac in data/XXXXX.tar) Output: ./samples//. (only the audio file is kept, the tar archive is deleted afterwards to save disk). """ from __future__ import annotations import os import random import shutil import sys import tarfile from pathlib import Path from huggingface_hub import HfFileSystem, hf_hub_download ROOT = Path(__file__).resolve().parent.parent SAMPLES_DIR = ROOT / "samples" SAMPLES_DIR.mkdir(parents=True, exist_ok=True) # Reproducibility random.seed(20260411) DATASETS = [ { "repo_id": "mitermix/audioset-with-grounded-captions", "short": "audioset", "tar_glob": "*.tar", # tars are at repo root }, { "repo_id": "laion/captioned-ai-music-snippets", "short": "music", "tar_glob": "*.tar", }, { "repo_id": "TTS-AGI/majestrino-unified-detailed-captions-temporal", "short": "majestrino", "tar_glob": "data/*.tar", # tars under data/ }, ] AUDIO_EXTS = {".mp3", ".flac", ".wav", ".ogg", ".m4a", ".opus"} SAMPLES_PER_DATASET = 6 def list_tars(repo_id: str, glob_pat: str) -> list[str]: """Return relative tar paths inside the dataset repo.""" fs = HfFileSystem() prefix = f"datasets/{repo_id}/" found = fs.glob(prefix + glob_pat) return [p[len(prefix):] for p in found] def pick_audio_members(tar_path: Path, k: int) -> list[tarfile.TarInfo]: with tarfile.open(tar_path, "r") as tf: members = [m for m in tf.getmembers() if m.isfile() and Path(m.name).suffix.lower() in AUDIO_EXTS] if not members: raise RuntimeError(f"No audio members found in {tar_path}") if len(members) <= k: return members return random.sample(members, k) def extract_members(tar_path: Path, members: list[tarfile.TarInfo], out_dir: Path, short: str) -> list[Path]: out_paths: list[Path] = [] out_dir.mkdir(parents=True, exist_ok=True) with tarfile.open(tar_path, "r") as tf: for m in members: with tf.extractfile(m) as src: if src is None: continue # Flatten and prefix with the dataset short name to avoid # collisions across tars. base = Path(m.name).name stem = Path(base).stem ext = Path(base).suffix.lower() out_path = out_dir / f"{short}__{stem}{ext}" with open(out_path, "wb") as dst: shutil.copyfileobj(src, dst) out_paths.append(out_path) return out_paths def main() -> int: manifest_lines = [] for ds in DATASETS: repo_id = ds["repo_id"] short = ds["short"] out_dir = SAMPLES_DIR / short existing = sorted(out_dir.glob("*")) if out_dir.exists() else [] existing_audio = [p for p in existing if p.suffix.lower() in AUDIO_EXTS] if len(existing_audio) >= SAMPLES_PER_DATASET: print(f"[skip] {short}: already have {len(existing_audio)} samples") for p in existing_audio[:SAMPLES_PER_DATASET]: manifest_lines.append(f"{short}\t{p}") continue print(f"[list] {repo_id}: looking up tar files...") tars = list_tars(repo_id, ds["tar_glob"]) if not tars: print(f"[error] No tar files found in {repo_id}") continue print(f" found {len(tars)} tar files") chosen = random.choice(tars) print(f"[pick] {chosen}") print(f"[download] {repo_id}::{chosen}") local_tar = hf_hub_download( repo_id=repo_id, filename=chosen, repo_type="dataset", cache_dir=str(SAMPLES_DIR / "_cache"), ) local_tar = Path(local_tar) print(f" -> {local_tar} ({local_tar.stat().st_size/1e6:.1f} MB)") print(f"[extract] picking {SAMPLES_PER_DATASET} random audio members") members = pick_audio_members(local_tar, SAMPLES_PER_DATASET) out_paths = extract_members(local_tar, members, out_dir, short) for p in out_paths: print(f" {p}") manifest_lines.append(f"{short}\t{p}") # Reclaim disk: drop the tar and the cache try: local_tar.unlink() # also drop empty parent dirs in the cache cache_dir = SAMPLES_DIR / "_cache" if cache_dir.exists(): shutil.rmtree(cache_dir, ignore_errors=True) except OSError: pass manifest_path = SAMPLES_DIR / "manifest.tsv" manifest_path.write_text("\n".join(manifest_lines) + "\n") print(f"\nWrote manifest with {len(manifest_lines)} samples to {manifest_path}") return 0 if __name__ == "__main__": raise SystemExit(main())