File size: 5,258 Bytes
4aed451
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""Pick a random tar file from each of three HF datasets and extract 6 random
audio samples per dataset.

Datasets:
  1. mitermix/audioset-with-grounded-captions   (mp3 in audio-dataset-XXXXXX.tar)
  2. laion/captioned-ai-music-snippets          (mp3 in suno-dataset-XXXXXX.tar)
  3. TTS-AGI/majestrino-unified-detailed-captions-temporal
                                                (flac in data/XXXXX.tar)

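Output: <root>/samples/<dataset_short_name>/<dataset_short_name>__<member_stem>.<ext>,
where <root> is the parent of this script's directory, plus <root>/samples/manifest.tsv
listing every sample. Only the audio files are kept; the downloaded tar archive and its
download cache are deleted afterwards to save disk.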
"""

from __future__ import annotations

import os
import random
import shutil
import sys
import tarfile
from pathlib import Path

from huggingface_hub import HfFileSystem, hf_hub_download

# Repository root: the parent of the directory that contains this script.
ROOT = Path(__file__).resolve().parent.parent
SAMPLES_DIR = ROOT / "samples"
SAMPLES_DIR.mkdir(parents=True, exist_ok=True)

# Fixed seed so the random tar/member choices are reproducible.
random.seed(20260411)

# One entry per source dataset; ``tar_glob`` is relative to the repo root.
DATASETS = [
    dict(
        repo_id="mitermix/audioset-with-grounded-captions",
        short="audioset",
        tar_glob="*.tar",  # shards sit at the repo root
    ),
    dict(
        repo_id="laion/captioned-ai-music-snippets",
        short="music",
        tar_glob="*.tar",
    ),
    dict(
        repo_id="TTS-AGI/majestrino-unified-detailed-captions-temporal",
        short="majestrino",
        tar_glob="data/*.tar",  # shards live under data/
    ),
]

# Lowercased file suffixes that count as audio.
AUDIO_EXTS = set(".mp3 .flac .wav .ogg .m4a .opus".split())
# Number of audio files to keep per dataset.
SAMPLES_PER_DATASET = 6


def list_tars(repo_id: str, glob_pat: str) -> list[str]:
    """Return repo-relative paths of files in a HF dataset repo matching a glob.

    Args:
        repo_id: Dataset repo id, e.g. ``"laion/captioned-ai-music-snippets"``.
        glob_pat: Glob relative to the repo root, e.g. ``"*.tar"`` or
            ``"data/*.tar"``.

    Returns:
        The matching paths with the ``datasets/<repo_id>/`` prefix stripped,
        sorted. Sorting makes a seeded ``random.choice`` over the result
        reproducible without relying on the order the remote listing returns.
    """
    fs = HfFileSystem()
    prefix = f"datasets/{repo_id}/"
    found = fs.glob(prefix + glob_pat)
    return sorted(p[len(prefix):] for p in found)


def pick_audio_members(tar_path: Path, k: int) -> list[tarfile.TarInfo]:
    """Choose up to *k* random audio file members from the tar at *tar_path*.

    A member qualifies when it is a regular file and its lowercased suffix is
    in AUDIO_EXTS. With *k* or fewer candidates, all of them are returned in
    archive order. Otherwise ``random.sample`` draws *k* using the module RNG.

    Raises:
        RuntimeError: if the archive contains no audio members at all.
    """
    with tarfile.open(tar_path, "r") as archive:
        candidates = []
        for info in archive.getmembers():
            if not info.isfile():
                continue
            if Path(info.name).suffix.lower() not in AUDIO_EXTS:
                continue
            candidates.append(info)

    if not candidates:
        raise RuntimeError(f"No audio members found in {tar_path}")
    return candidates if len(candidates) <= k else random.sample(candidates, k)


def extract_members(tar_path: Path, members: list[tarfile.TarInfo],
                    out_dir: Path, short: str) -> list[Path]:
    """Copy the given tar members into *out_dir* as flat files.

    Each member is written to ``out_dir / f"{short}__{stem}{ext}"``, where
    ``stem`` is the member's basename without its suffix and ``ext`` is the
    lowercased suffix. Directory components of the member name are dropped,
    so nothing is written outside *out_dir*. If two members flatten to the
    same file name within one call, later ones get ``_1``, ``_2``, ... added
    to the stem instead of overwriting the earlier file.

    Members for which ``TarFile.extractfile`` returns ``None`` (non-file
    members such as directories) are skipped.

    Args:
        tar_path: Tar archive to read.
        members: Members of that archive to extract.
        out_dir: Destination directory; created if missing.
        short: Dataset short name, used as a file-name prefix.

    Returns:
        Paths of the files written, in the order of *members*.
    """
    out_paths: list[Path] = []
    taken: set[Path] = set()  # output names already used in this call
    out_dir.mkdir(parents=True, exist_ok=True)
    with tarfile.open(tar_path, "r") as tf:
        for m in members:
            src = tf.extractfile(m)
            # Check for None *before* entering a context manager:
            # ``with None as src`` would raise instead of skipping.
            if src is None:
                continue
            base = Path(m.name).name
            stem = Path(base).stem
            ext = Path(base).suffix.lower()
            out_path = out_dir / f"{short}__{stem}{ext}"
            # Flattening can map e.g. "a/x.mp3" and "b/x.mp3" to the same
            # name; disambiguate rather than silently overwrite.
            n = 1
            while out_path in taken:
                out_path = out_dir / f"{short}__{stem}_{n}{ext}"
                n += 1
            taken.add(out_path)
            with src, open(out_path, "wb") as dst:
                shutil.copyfileobj(src, dst)
            out_paths.append(out_path)
    return out_paths


def _existing_audio(out_dir: Path) -> list[Path]:
    """Return audio files already present in *out_dir*, sorted by path.

    Returns an empty list when *out_dir* does not exist.
    """
    if not out_dir.exists():
        return []
    return [p for p in sorted(out_dir.glob("*")) if p.suffix.lower() in AUDIO_EXTS]


def _remove_download_cache(cache_dir: Path) -> None:
    """Best-effort removal of the script-private download cache *cache_dir*.

    The tar is downloaded with ``cache_dir=cache_dir``, so deleting the whole
    directory reclaims the tar's space along with the cache's own files.
    Errors are ignored: failing to free disk space should not abort the run.
    """
    shutil.rmtree(cache_dir, ignore_errors=True)


def _sample_dataset(ds: dict) -> list[str]:
    """Produce samples for one DATASETS entry and return its manifest lines.

    If the dataset's output directory already holds at least
    SAMPLES_PER_DATASET audio files, nothing is downloaded and the first
    SAMPLES_PER_DATASET of them are reported. Otherwise one tar is chosen at
    random from the repo, downloaded into SAMPLES_DIR/_cache, sampled, and
    the cache directory is removed.

    Each manifest line is ``<short><TAB><path>``.

    Raises:
        RuntimeError: if the repo has no tar matching ``ds["tar_glob"]``, or
            the chosen tar contains no audio members.
    """
    repo_id = ds["repo_id"]
    short = ds["short"]
    out_dir = SAMPLES_DIR / short

    existing_audio = _existing_audio(out_dir)
    if len(existing_audio) >= SAMPLES_PER_DATASET:
        print(f"[skip] {short}: already have {len(existing_audio)} samples")
        return [f"{short}\t{p}" for p in existing_audio[:SAMPLES_PER_DATASET]]

    print(f"[list] {repo_id}: looking up tar files...")
    tars = list_tars(repo_id, ds["tar_glob"])
    if not tars:
        raise RuntimeError(f"No tar files found in {repo_id}")
    print(f"        found {len(tars)} tar files")
    chosen = random.choice(tars)
    print(f"[pick] {chosen}")

    cache_dir = SAMPLES_DIR / "_cache"
    print(f"[download] {repo_id}::{chosen}")
    local_tar = Path(hf_hub_download(
        repo_id=repo_id,
        filename=chosen,
        repo_type="dataset",
        cache_dir=str(cache_dir),
    ))
    try:
        print(f"           -> {local_tar} ({local_tar.stat().st_size/1e6:.1f} MB)")
        print(f"[extract] picking {SAMPLES_PER_DATASET} random audio members")
        members = pick_audio_members(local_tar, SAMPLES_PER_DATASET)
        out_paths = extract_members(local_tar, members, out_dir, short)
    finally:
        # Once the tar is on disk, always reclaim the space, even if
        # extraction fails (e.g. the tar holds no audio members).
        _remove_download_cache(cache_dir)

    lines = []
    for p in out_paths:
        print(f"          {p}")
        lines.append(f"{short}\t{p}")
    return lines


def main() -> int:
    """Sample every dataset in DATASETS and write SAMPLES_DIR/manifest.tsv.

    A dataset that fails with a RuntimeError (no tar files, or no audio in
    the chosen tar) is reported on stderr and skipped. The remaining datasets
    are still sampled, and the manifest lists those that succeeded.

    Returns:
        Process exit status: 0 if every dataset succeeded, 1 otherwise.
    """
    manifest_lines: list[str] = []
    failed: list[str] = []
    for ds in DATASETS:
        try:
            manifest_lines.extend(_sample_dataset(ds))
        except RuntimeError as exc:
            print(f"[error] {exc}", file=sys.stderr)
            failed.append(ds["repo_id"])

    manifest_path = SAMPLES_DIR / "manifest.tsv"
    # Newline-terminate every line; an empty manifest is an empty file.
    manifest_path.write_text("".join(f"{line}\n" for line in manifest_lines),
                             encoding="utf-8")
    print(f"\nWrote manifest with {len(manifest_lines)} samples to {manifest_path}")
    if failed:
        print(f"[error] {len(failed)} dataset(s) failed: {', '.join(failed)}",
              file=sys.stderr)
        return 1
    return 0


if __name__ == "__main__":
    # sys.exit raises SystemExit with main()'s return value as the exit code.
    sys.exit(main())