| """Pick a random tar file from each of three HF datasets and extract 6 random |
| audio samples per dataset. |
| |
| Datasets: |
| 1. mitermix/audioset-with-grounded-captions (mp3 in audio-dataset-XXXXXX.tar) |
| 2. laion/captioned-ai-music-snippets (mp3 in suno-dataset-XXXXXX.tar) |
| 3. TTS-AGI/majestrino-unified-detailed-captions-temporal |
| (flac in data/XXXXX.tar) |
| |
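| Output: <ROOT>/samples/<dataset_short_name>/<dataset_short_name>__<key>.<ext>, |
| where <ROOT> is the parent of the directory containing this script, plus |
| <ROOT>/samples/manifest.tsv with one tab-separated (short name, path) line per |
| sample. Only the audio files are kept; the downloaded tar archive is deleted |
| afterwards to save disk space. |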
| """ |
|
|
| from __future__ import annotations |
|
|
| import os |
| import random |
| import shutil |
| import sys |
| import tarfile |
| from pathlib import Path |
|
|
| from huggingface_hub import HfFileSystem, hf_hub_download |
|
|
# Parent of the directory that contains this script.
ROOT = Path(__file__).resolve().parent.parent
# Extracted audio, the temporary download cache and the manifest all live here.
SAMPLES_DIR = ROOT / "samples"
# NOTE: this runs at import time, not only when the script is executed.
SAMPLES_DIR.mkdir(parents=True, exist_ok=True)


# Seed the global RNG with a fixed value; the tar choice (random.choice) and the
# member sampling (random.sample) further down both draw from it.
random.seed(20260411)


# One entry per dataset to sample from:
#   repo_id  - Hugging Face dataset repository id
#   short    - name of the output sub-directory and prefix of every output file
#   tar_glob - glob, relative to the repo root, handed to list_tars()
DATASETS = [
    {
        "repo_id": "mitermix/audioset-with-grounded-captions",
        "short": "audioset",
        "tar_glob": "*.tar",
    },
    {
        "repo_id": "laion/captioned-ai-music-snippets",
        "short": "music",
        "tar_glob": "*.tar",
    },
    {
        "repo_id": "TTS-AGI/majestrino-unified-detailed-captions-temporal",
        "short": "majestrino",
        "tar_glob": "data/*.tar",
    },
]


# File suffixes treated as audio; always compared against a lower-cased suffix.
AUDIO_EXTS = {".mp3", ".flac", ".wav", ".ogg", ".m4a", ".opus"}
# Number of audio members extracted per dataset, and the number of existing
# files at which a dataset is considered done and skipped.
SAMPLES_PER_DATASET = 6
|
|
|
|
def list_tars(repo_id: str, glob_pat: str) -> list[str]:
    """List the tar archives of a Hugging Face dataset repository.

    Globs ``datasets/<repo_id>/<glob_pat>`` through ``HfFileSystem`` and
    returns every match with the leading ``datasets/<repo_id>/`` cut off,
    i.e. as a path relative to the repository root.
    """
    repo_root = f"datasets/{repo_id}/"
    skip = len(repo_root)
    matches = HfFileSystem().glob(repo_root + glob_pat)
    relative: list[str] = []
    for hit in matches:
        relative.append(hit[skip:])
    return relative
|
|
|
|
def pick_audio_members(tar_path: Path, k: int) -> list[tarfile.TarInfo]:
    """Choose up to ``k`` audio entries from the tar archive at ``tar_path``.

    An entry counts as audio when it is a regular file whose lower-cased
    suffix is in ``AUDIO_EXTS``. When the archive holds ``k`` or fewer such
    entries, all of them are returned in archive order; otherwise ``k`` of
    them are drawn with ``random.sample``.

    Raises:
        RuntimeError: the archive contains no audio entries at all.
    """
    candidates: list[tarfile.TarInfo] = []
    with tarfile.open(tar_path, "r") as archive:
        for entry in archive.getmembers():
            if not entry.isfile():
                continue
            if Path(entry.name).suffix.lower() not in AUDIO_EXTS:
                continue
            candidates.append(entry)

    if not candidates:
        raise RuntimeError(f"No audio members found in {tar_path}")
    if len(candidates) > k:
        return random.sample(candidates, k)
    return candidates
|
|
|
|
def extract_members(tar_path: Path, members: list[tarfile.TarInfo],
                    out_dir: Path, short: str) -> list[Path]:
    """Copy the given tar members into ``out_dir`` and return the written paths.

    Each member is written as ``<short>__<stem><ext>``, where stem and
    extension come from the member's base name (any directory part inside the
    archive is dropped) and the extension is lower-cased. Members that share
    a base name therefore end up in the same output file.

    Args:
        tar_path: Archive to read from.
        members: Entries of that archive to extract.
        out_dir: Destination directory; created if it does not exist.
        short: Prefix put in front of every output file name.

    Returns:
        The output paths, in the order of ``members``. Members that carry no
        file data (directories, devices, ...) are skipped.
    """
    out_paths: list[Path] = []
    out_dir.mkdir(parents=True, exist_ok=True)
    with tarfile.open(tar_path, "r") as tf:
        for m in members:
            # extractfile() returns None for members without file data. Check
            # that *before* using the result as a context manager: entering
            # ``with None`` raises instead of reaching the skip.
            src = tf.extractfile(m)
            if src is None:
                continue

            base = Path(Path(m.name).name)
            out_path = out_dir / f"{short}__{base.stem}{base.suffix.lower()}"
            with src, open(out_path, "wb") as dst:
                shutil.copyfileobj(src, dst)
            out_paths.append(out_path)
    return out_paths
|
|
|
|
def main() -> int:
    """Fetch sample audio for every entry in ``DATASETS`` and write a manifest.

    For each dataset:

    * if its output directory already holds at least ``SAMPLES_PER_DATASET``
      audio files, the first ``SAMPLES_PER_DATASET`` of them (sorted by path)
      are listed in the manifest and nothing is downloaded;
    * otherwise one tar archive is picked at random, downloaded into
      ``SAMPLES_DIR / "_cache"``, up to ``SAMPLES_PER_DATASET`` audio members
      are extracted, and the download cache is removed again.

    Finally ``SAMPLES_DIR / "manifest.tsv"`` is written with one
    ``<short>\\t<path>`` line per sample.

    Returns:
        0. A dataset without any tar file is reported and skipped.

    Raises:
        RuntimeError: propagated from ``pick_audio_members`` when the chosen
            archive contains no audio. The download cache is still removed.
    """
    manifest_lines: list[str] = []
    cache_dir = SAMPLES_DIR / "_cache"

    for ds in DATASETS:
        repo_id = ds["repo_id"]
        short = ds["short"]
        out_dir = SAMPLES_DIR / short

        existing = sorted(out_dir.glob("*")) if out_dir.exists() else []
        existing_audio = [p for p in existing if p.suffix.lower() in AUDIO_EXTS]
        if len(existing_audio) >= SAMPLES_PER_DATASET:
            print(f"[skip] {short}: already have {len(existing_audio)} samples")
            manifest_lines.extend(
                f"{short}\t{p}" for p in existing_audio[:SAMPLES_PER_DATASET]
            )
            continue

        print(f"[list] {repo_id}: looking up tar files...")
        tars = list_tars(repo_id, ds["tar_glob"])
        if not tars:
            print(f"[error] No tar files found in {repo_id}")
            continue
        print(f" found {len(tars)} tar files")
        chosen = random.choice(tars)
        print(f"[pick] {chosen}")

        print(f"[download] {repo_id}::{chosen}")
        local_tar = None
        try:
            local_tar = Path(
                hf_hub_download(
                    repo_id=repo_id,
                    filename=chosen,
                    repo_type="dataset",
                    cache_dir=str(cache_dir),
                )
            )
            print(f" -> {local_tar} ({local_tar.stat().st_size/1e6:.1f} MB)")

            print(f"[extract] picking {SAMPLES_PER_DATASET} random audio members")
            members = pick_audio_members(local_tar, SAMPLES_PER_DATASET)
            out_paths = extract_members(local_tar, members, out_dir, short)
            for p in out_paths:
                print(f" {p}")
                manifest_lines.append(f"{short}\t{p}")
        finally:
            # Clean up in ``finally`` so that a failed download or extraction
            # does not leave a (potentially very large) archive behind. The
            # two steps are independent: a failing unlink must not prevent
            # the cache directory from being removed.
            if local_tar is not None:
                try:
                    local_tar.unlink()
                except OSError as err:
                    print(f"[warn] could not delete {local_tar}: {err}",
                          file=sys.stderr)
            shutil.rmtree(cache_dir, ignore_errors=True)
            if cache_dir.exists():
                print(f"[warn] could not fully remove {cache_dir}",
                      file=sys.stderr)

    manifest_path = SAMPLES_DIR / "manifest.tsv"
    # Explicit encoding: sample paths may contain non-ASCII characters and the
    # locale default is not UTF-8 on every platform.
    manifest_path.write_text("\n".join(manifest_lines) + "\n", encoding="utf-8")
    print(f"\nWrote manifest with {len(manifest_lines)} samples to {manifest_path}")
    return 0
|
|
|
|
# Script entry point: run main() and use its return value as the exit status.
if __name__ == "__main__":
    sys.exit(main())
|
|