File size: 3,789 Bytes
b7720f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import argparse
import json
import random
from pathlib import Path


# Stem classes emitted in every example, as (class name, leadness) pairs.
# The leadness value is copied verbatim into each stem's "leadness" field.
CLASSES = [
    ("lead_vocal", 0.95),
    ("kick", 0.25),
    ("bass", 0.35),
    ("pad", 0.10),
    ("synth_lead", 0.70),
    ("fx", 0.05),
]

# Style name -> bed settings for the target scene. "lufs" becomes the bed's
# "loudness_target_lufs".
STYLE_PRESETS = {
    "club": {"layout": "iamf", "room_preset": "club_medium", "lufs": -14.0},
    "cinematic": {"layout": "iamf", "room_preset": "cinema_large", "lufs": -16.0},
    "live_stage": {"layout": "iamf", "room_preset": "stage_live", "lufs": -15.0},
}

# Rule parameters shared by the input payload's "rules" and the target scene
# ("objects" / "constraints_applied"). Defining them once guarantees the
# prompt and the completion can never disagree about them.
_LEAD_ANCHOR = {"az_deg": 0, "el_deg": 10, "dist_m": 1.6}
_MONO_LOW_END_HZ = 120
_PAD_MIN_WIDTH = 0.75


def _make_stems(rng: random.Random) -> list:
    """Draw one synthetic stem descriptor for each entry in ``CLASSES``.

    Each stem id is the first letter of its class followed by its 1-based
    position in ``CLASSES`` (e.g. ``"l1"``). Loudness, transient and band
    energies are drawn from ``rng`` in a fixed order, so the result depends
    only on the generator's state.
    """
    return [
        {
            "id": f"{klass[:1]}{idx}",
            "class": klass,
            "lufs": round(rng.uniform(-25.0, -10.0), 1),
            "transient": round(rng.uniform(0.05, 0.98), 2),
            "band_energy": {
                "low": round(rng.uniform(0.05, 0.9), 2),
                "mid": round(rng.uniform(0.05, 0.9), 2),
                "high": round(rng.uniform(0.05, 0.9), 2),
            },
            "leadness": leadness,
        }
        for idx, (klass, leadness) in enumerate(CLASSES, start=1)
    ]


def _lead_vocal_id(stems: list) -> str:
    """Return the id of the stem whose class is ``"lead_vocal"``.

    Raises:
        ValueError: if no stem has class ``"lead_vocal"``.
    """
    for stem in stems:
        if stem["class"] == "lead_vocal":
            return stem["id"]
    raise ValueError("CLASSES must contain a 'lead_vocal' entry")


def build_example(seed: int) -> dict:
    """Build one deterministic prompt/completion training pair.

    The same ``seed`` always produces the same example.

    Args:
        seed: Seed for a private ``random.Random`` instance; the module-level
            ``random`` state is neither used nor modified.

    Returns:
        ``{"prompt": str, "completion": str}``. The prompt is an instruction
        line followed by the input payload as indented JSON. The completion
        is the target scene as indented JSON.

    Raises:
        ValueError: if ``CLASSES`` has no ``lead_vocal`` entry.
    """
    rng = random.Random(seed)
    # NOTE: the order of the rng draws below defines the dataset for a given
    # seed; reordering them changes every generated example.
    style = rng.choice(list(STYLE_PRESETS.keys()))
    bpm = rng.choice([90, 100, 110, 120, 128, 140])
    energy = round(rng.uniform(0.35, 0.95), 2)
    max_objects = rng.choice([8, 10, 12])
    stems = _make_stems(rng)
    section = rng.choice(["intro", "verse", "break", "drop"])

    payload = {
        "target_format": "iamf",
        "max_objects": max_objects,
        "style": style,
        "section": section,
        "global": {"bpm": bpm, "energy": energy},
        "stems": stems,
        "rules": [
            {"type": "anchor", "track_class": "lead_vocal", **_LEAD_ANCHOR},
            {"type": "mono_low_end", "hz_below": _MONO_LOW_END_HZ},
            {"type": "width_pref", "track_class": "pad", "min_width": _PAD_MIN_WIDTH},
        ],
    }

    preset = STYLE_PRESETS[style]
    output = {
        "version": "1.0",
        "bed": {
            "layout": preset["layout"],
            "loudness_target_lufs": preset["lufs"],
            "room_preset": preset["room_preset"],
        },
        # NOTE(review): only the lead vocal gets an object, yet all three
        # rules are reported in "constraints_applied" — confirm this is the
        # intended training target.
        "objects": [
            {
                # Looked up by class so reordering CLASSES cannot mislabel it.
                "id": _lead_vocal_id(stems),
                "class": "lead_vocal",
                **_LEAD_ANCHOR,
                "width": 0.15,
                "gain_db": 0.0,
                "reverb_send": 0.18,
                "early_reflections": 0.2,
                # Two identical keyframes: the anchored vocal does not move.
                "motion": [
                    {"t": 0.0, **_LEAD_ANCHOR},
                    {"t": 1.0, **_LEAD_ANCHOR},
                ],
            }
        ],
        "constraints_applied": [
            "anchor:lead_vocal@{az_deg}/{el_deg}/{dist_m}".format(**_LEAD_ANCHOR),
            f"mono_low_end<{_MONO_LOW_END_HZ}Hz",
            f"pad_width>={_PAD_MIN_WIDTH}",
        ],
    }

    prompt = (
        "GravityLLM: Output ONLY valid JSON matching the Spatial9Scene schema.\n\n"
        "INPUT:\n" + json.dumps(payload, indent=2)
    )

    return {"prompt": prompt, "completion": json.dumps(output, indent=2)}


def main(argv=None) -> None:
    """Command-line entry point: write ``--count`` synthetic examples as JSONL.

    Example ``i`` (0-based) is built from seed ``--seed + i``, so a run is
    fully reproducible from its arguments. Missing parent directories of
    ``--output`` are created, and an existing file is overwritten.

    Args:
        argv: Argument list to parse. ``None`` (the default) means
            ``sys.argv[1:]``, as with ``argparse.ArgumentParser.parse_args``.

    Raises:
        SystemExit: with status 2 (via ``parser.error``) on invalid arguments,
            including a negative ``--count``.
    """
    parser = argparse.ArgumentParser(description="Generate a small synthetic GravityLLM dataset.")
    parser.add_argument("--output", type=Path, default=Path("data/synthetic_train.jsonl"))
    parser.add_argument("--count", type=int, default=25)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args(argv)
    if args.count < 0:
        # range() would silently yield nothing while the summary below
        # reported a negative count; fail fast before touching the disk.
        parser.error(f"--count must be non-negative, got {args.count}")

    args.output.parent.mkdir(parents=True, exist_ok=True)
    with args.output.open("w", encoding="utf-8") as f:
        for i in range(args.count):
            row = build_example(args.seed + i)
            # One JSON object per line (JSONL); keep non-ASCII text readable.
            f.write(json.dumps(row, ensure_ascii=False) + "\n")

    print(f"Wrote {args.count} examples to {args.output}")


if __name__ == "__main__":
    main()