Commit b7720f0 (verified) by lzanardos9 · Parent(s): 413e01a

Upload 20 files
LICENSE ADDED
@@ -0,0 +1,47 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.

"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work.

"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems.

"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work.

4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work; and
(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file.

5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work shall be under the terms and conditions of this License.

6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor.

7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

8. Limitation of Liability. In no event and under no legal theory shall any Contributor be liable to You for damages arising from the use of the Work.

9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer support or warranty protection only on Your own behalf.
Makefile ADDED
@@ -0,0 +1,17 @@
install:
	python -m pip install -r requirements.txt

train:
	python train.py --qlora --bf16 --num_train_epochs 3 --train_batch_size 1 --gradient_accumulation_steps 16

infer:
	python infer.py --model_dir outputs/GravityLLM-Qwen2.5-1.5B-S9 --input_json examples/sample_input.json --validate --output_json outputs/sample_prediction.json

eval:
	python evaluate.py --model_dir outputs/GravityLLM-Qwen2.5-1.5B-S9 --data_file data/valid.jsonl --report_path reports/eval_report.json

validate-example:
	python tools/validate_scene.py schemas/scene.schema.json examples/sample_output.json

synthetic-data:
	python tools/make_synthetic_dataset.py --output data/synthetic_train.jsonl --count 50
README.md CHANGED
@@ -1,3 +1,267 @@
- ---
- license: gpl-3.0
- ---
---
language:
- en
license: apache-2.0
library_name: transformers
pipeline_tag: text-generation
base_model: Qwen/Qwen2.5-1.5B-Instruct
tags:
- gravityllm
- spatial-audio
- immersive-audio
- spatial9
- iamf
- instruction-tuning
- json
- lora
- qlora
- peft
- transformers
widget:
- text: |-
    INPUT:
    {
      "target_format": "iamf",
      "max_objects": 10,
      "style": "club",
      "section": "drop",
      "global": {"bpm": 128, "energy": 0.92},
      "stems": [
        {"id": "v1", "class": "lead_vocal", "lufs": -16.8, "transient": 0.25, "band_energy": {"low": 0.1, "mid": 0.6, "high": 0.3}, "leadness": 0.95},
        {"id": "k1", "class": "kick", "lufs": -10.5, "transient": 0.95, "band_energy": {"low": 0.8, "mid": 0.15, "high": 0.05}, "leadness": 0.25}
      ],
      "rules": [
        {"type": "anchor", "track_class": "lead_vocal", "az_deg": 0, "el_deg": 10, "dist_m": 1.6},
        {"type": "mono_low_end", "hz_below": 120}
      ]
    }
---

![GravityLLM banner](assets/gravityllm_banner.svg)

# GravityLLM

GravityLLM is a compact instruction-tuned model for **constraint-conditioned spatial scene generation**.
It turns **music constraints + stem descriptors** into strict **Spatial9Scene JSON** for immersive audio pipelines such as IAMF, binaural, and bed-plus-object rendering workflows.

> **Status**
> This repository is **training-ready and Hub-ready**.
> It includes code, schema, sample data, evaluation, and upload helpers.
> It does **not** include fine-tuned weights yet. After training, upload the contents of your `outputs/...` folder as the actual model repo.

## What makes this repo solid

- Proper instruction fine-tuning with **prompt masking**, so the loss is applied to the target JSON instead of the instruction prefix.
- **LoRA** and **QLoRA** training paths for efficient fine-tuning on small-to-medium GPUs.
- Strict **JSON Schema** validation for production-safe outputs.
- Built-in **evaluation** for parse rate, schema-valid rate, object-budget pass rate, and anchor-rule pass rate.
- Clean **Hugging Face upload** helper with `upload_folder`.
- Ready-made **sample data**, **sample scene**, and **recommended training config**.
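
The prompt-masking idea above can be sketched in a few lines. This is an illustrative sketch, not the actual `train.py` code; the token IDs and EOS ID are made up, and `-100` is the label index that PyTorch's cross-entropy loss ignores.

```python
# Illustrative sketch of prompt masking (not the actual train.py code):
# prompt tokens get label -100, the index ignored by PyTorch cross-entropy,
# so the loss is computed only on the completion (the target JSON).
IGNORE_INDEX = -100

def build_labels(prompt_ids, completion_ids, eos_id=2):
    """Concatenate prompt and completion; mask the prompt in the labels."""
    input_ids = prompt_ids + completion_ids + [eos_id]
    labels = [IGNORE_INDEX] * len(prompt_ids) + completion_ids + [eos_id]
    return input_ids, labels

input_ids, labels = build_labels([10, 11, 12], [20, 21])
# input_ids == [10, 11, 12, 20, 21, 2]
# labels    == [-100, -100, -100, 20, 21, 2]
```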

## Model contract

### Input
A structured payload describing:

- target format
- object budget
- style and section
- per-stem descriptors
- hard rules such as anchors, low-end centering, width targets, and masking constraints

### Output
A single valid JSON object matching `schemas/scene.schema.json`.

### Example input
```json
{
  "target_format": "iamf",
  "max_objects": 10,
  "style": "club",
  "section": "drop",
  "global": {"bpm": 128, "energy": 0.92},
  "stems": [
    {"id": "v1", "class": "lead_vocal", "lufs": -16.8, "transient": 0.25, "band_energy": {"low": 0.1, "mid": 0.6, "high": 0.3}, "leadness": 0.95},
    {"id": "k1", "class": "kick", "lufs": -10.5, "transient": 0.95, "band_energy": {"low": 0.8, "mid": 0.15, "high": 0.05}, "leadness": 0.25}
  ],
  "rules": [
    {"type": "anchor", "track_class": "lead_vocal", "az_deg": 0, "el_deg": 10, "dist_m": 1.6},
    {"type": "mono_low_end", "hz_below": 120}
  ]
}
```

### Example output
```json
{
  "version": "1.0",
  "bed": {"layout": "iamf", "loudness_target_lufs": -14.0, "room_preset": "club_medium"},
  "objects": [
    {
      "id": "v1",
      "class": "lead_vocal",
      "az_deg": 0,
      "el_deg": 10,
      "dist_m": 1.6,
      "width": 0.15,
      "gain_db": 0.0,
      "reverb_send": 0.18,
      "early_reflections": 0.22,
      "motion": [
        {"t": 0.0, "az_deg": 0, "el_deg": 10, "dist_m": 1.6},
        {"t": 1.0, "az_deg": 0, "el_deg": 10, "dist_m": 1.6}
      ]
    }
  ],
  "constraints_applied": [
    "anchor:lead_vocal@0/10/1.6",
    "mono_low_end<120Hz"
  ]
}
```
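
The schema file is authoritative, but the shape of the contract can be checked cheaply before full validation. A minimal structural check, sketched from the example above (the field names come from this README; the set of required keys is an assumption for illustration):

```python
import json

# Required top-level keys, taken from the Spatial9Scene examples in this README.
REQUIRED_KEYS = {"version", "bed", "objects", "constraints_applied"}

def quick_check(scene_text, max_objects=10):
    """Parse model output and apply a few structural checks."""
    scene = json.loads(scene_text)  # raises ValueError on malformed JSON
    missing = REQUIRED_KEYS - scene.keys()
    if missing:
        raise ValueError(f"missing keys: {missing}")
    if len(scene["objects"]) > max_objects:
        raise ValueError("object budget exceeded")
    return scene

scene = quick_check(json.dumps({
    "version": "1.0",
    "bed": {"layout": "iamf"},
    "objects": [{"id": "v1", "class": "lead_vocal", "az_deg": 0}],
    "constraints_applied": [],
}))
```

For production use, prefer `tools/validate_scene.py` against `schemas/scene.schema.json`.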

## Repository layout

```text
GravityLLM-HuggingFace-Repo/
├── README.md
├── LICENSE
├── Makefile
├── pyproject.toml
├── requirements.txt
├── train.py
├── infer.py
├── evaluate.py
├── upload_to_hub.py
├── assets/
│   └── gravityllm_banner.svg
├── configs/
│   └── recommended_train_args.json
├── data/
│   ├── train.jsonl
│   └── valid.jsonl
├── examples/
│   ├── sample_input.json
│   └── sample_output.json
├── schemas/
│   └── scene.schema.json
├── scripts/
│   ├── push_to_hub.sh
│   └── train_qlora.sh
└── tools/
    ├── make_synthetic_dataset.py
    └── validate_scene.py
```

## Quick start

### 1) Install
```bash
python -m pip install -r requirements.txt
```

### 2) Train with QLoRA
```bash
bash scripts/train_qlora.sh
```

Or run directly:

```bash
python train.py \
  --model Qwen/Qwen2.5-1.5B-Instruct \
  --train_file data/train.jsonl \
  --valid_file data/valid.jsonl \
  --output_dir outputs/GravityLLM-Qwen2.5-1.5B-S9 \
  --max_length 2048 \
  --num_train_epochs 3 \
  --learning_rate 2e-4 \
  --train_batch_size 1 \
  --eval_batch_size 1 \
  --gradient_accumulation_steps 16 \
  --warmup_ratio 0.03 \
  --save_steps 100 \
  --eval_steps 100 \
  --qlora --bf16
```

### 3) Generate a scene
```bash
python infer.py \
  --model_dir outputs/GravityLLM-Qwen2.5-1.5B-S9 \
  --input_json examples/sample_input.json \
  --validate \
  --output_json outputs/sample_prediction.json
```
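
`infer.py --validate` handles validation for you. Independently of that script, a small hypothetical helper can recover the first complete JSON object when decoding appends stray text around it (a common failure mode for generative models):

```python
import json

def extract_first_json(text):
    """Return the first balanced {...} span in text, parsed as JSON.

    Hypothetical helper, not part of infer.py. Caveat: the brace counter
    does not account for braces inside string values, which is usually
    acceptable for scene JSON but makes this a sketch, not a parser.
    """
    start = text.index("{")
    depth = 0
    for i, ch in enumerate(text[start:], start):
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return json.loads(text[start:i + 1])
    raise ValueError("no complete JSON object found")

obj = extract_first_json('Sure! {"version": "1.0"} Hope this helps.')
# obj == {"version": "1.0"}
```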

### 4) Evaluate
```bash
python evaluate.py \
  --model_dir outputs/GravityLLM-Qwen2.5-1.5B-S9 \
  --data_file data/valid.jsonl \
  --report_path reports/eval_report.json
```
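
`evaluate.py` is the source of truth for the reported numbers; as a rough sketch of what two of the metrics mean (assumed definitions, not the script's actual code):

```python
import json

def parse_rate(outputs):
    """Fraction of generations that parse as JSON at all."""
    ok = 0
    for text in outputs:
        try:
            json.loads(text)
            ok += 1
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            pass
    return ok / len(outputs)

def budget_pass_rate(scenes, max_objects):
    """Fraction of parsed scenes that respect the object budget."""
    return sum(len(s["objects"]) <= max_objects for s in scenes) / len(scenes)

rate = parse_rate(['{"version": "1.0"}', 'not json'])  # 0.5
```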

### 5) Validate any output
```bash
python tools/validate_scene.py schemas/scene.schema.json outputs/sample_prediction.json
```

## Push to the Hugging Face Hub

### From a trained output folder
```bash
python upload_to_hub.py \
  --folder_path outputs/GravityLLM-Qwen2.5-1.5B-S9 \
  --repo_id YOUR_NAMESPACE/GravityLLM-Qwen2.5-1.5B-S9
```

### Or with the helper script
```bash
bash scripts/push_to_hub.sh outputs/GravityLLM-Qwen2.5-1.5B-S9 YOUR_NAMESPACE/GravityLLM-Qwen2.5-1.5B-S9
```
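
Before pushing, it can help to confirm the output folder actually contains adapter weights. The file names below assume a PEFT/QLoRA run and are illustrative; they are not a contract of `upload_to_hub.py`:

```python
from pathlib import Path

# Typical artifacts of a PEFT/QLoRA run (assumed names; check your own outputs).
EXPECTED = ["adapter_config.json", "adapter_model.safetensors"]

def missing_files(folder):
    """Return the expected files that are absent from the output folder."""
    folder = Path(folder)
    return [name for name in EXPECTED if not (folder / name).exists()]
```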

## Dataset format

Training files are JSONL with two fields per row:

```json
{
  "prompt": "GravityLLM: Output ONLY valid JSON matching the Spatial9Scene schema.\n\nINPUT:\n{...}",
  "completion": "{... valid Spatial9Scene JSON ...}"
}
```

The provided sample dataset is intentionally small. Replace it with your real production examples as soon as possible.
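
Rows in this format can be generated programmatically. A sketch (the instruction prefix matches the samples shipped in `data/train.jsonl`; the helper itself is hypothetical):

```python
import json

# Fixed instruction prefix used by the rows in data/train.jsonl.
PREFIX = "GravityLLM: Output ONLY valid JSON matching the Spatial9Scene schema."

def make_row(input_payload, scene):
    """Serialize one {prompt, completion} training row as a JSONL line."""
    prompt = f"{PREFIX}\n\nINPUT:\n{json.dumps(input_payload, indent=2)}"
    return json.dumps({"prompt": prompt, "completion": json.dumps(scene, indent=2)})

row = make_row({"target_format": "iamf", "max_objects": 4}, {"version": "1.0"})
```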

## Recommended data strategy

For a strong first release:

1. Collect a few hundred high-quality gold examples from expert-authored scenes.
2. Keep the schema stable and the numeric fields quantized (snapped to a fixed grid of values).
3. Encode hard rules explicitly instead of relying on vague prose.
4. Run evaluation after every fine-tune.
5. Add a post-processor to enforce hard constraints if the runtime must be deterministic.
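
A deterministic correction pass (step 5 above) can be sketched as follows. The rule and field names follow the examples in this README; the helper is illustrative rather than shipped code:

```python
def enforce_rules(scene, rules):
    """Overwrite scene objects in place so hard constraints hold exactly."""
    by_class = {}
    for obj in scene["objects"]:
        by_class.setdefault(obj["class"], []).append(obj)
    for rule in rules:
        targets = by_class.get(rule.get("track_class"), [])
        if rule["type"] == "anchor":
            for obj in targets:
                obj["az_deg"] = rule["az_deg"]
                obj["el_deg"] = rule["el_deg"]
                obj["dist_m"] = rule["dist_m"]
        elif rule["type"] == "width_pref":
            for obj in targets:
                obj["width"] = max(obj["width"], rule["min_width"])
        # Other rule types (e.g. mono_low_end) are handled elsewhere.
    return scene

scene = {"objects": [
    {"class": "pad", "width": 0.5, "az_deg": -70, "el_deg": 20, "dist_m": 4.0},
]}
fixed = enforce_rules(scene, [
    {"type": "width_pref", "track_class": "pad", "min_width": 0.8},
])
```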

## Suggested training roadmap

### v0
- Small curated dataset
- QLoRA adapter
- Schema-valid JSON only
- Anchor and budget constraints

### v1
- More genres and sections
- Better masking and width rules
- Object motion patterns
- Automatic validation and repair loop

### v2
- Preference tuning on human A/B judgments
- A dedicated reward signal for clarity, masking avoidance, and translation safety

## Intended use

GravityLLM is designed for:

- music-tech pipelines
- Spatial9 scene authoring
- assisted immersive-audio layout generation
- IAMF-ready authoring workflows
- renderer-side JSON generation

## Limitations

- This repo does not include trained weights out of the box.
- The model only knows what you teach it through your dataset.
- Raw audio is not consumed directly here; the training pipeline expects structured stem features.
- Production systems should still validate outputs and optionally apply a rule-based correction pass.

## Safety and reliability

- Always validate generated scenes against the JSON schema.
- Keep low-end centering as a hard rule outside the model if that is non-negotiable.
- Treat the model as a scene proposal engine, not an oracle.

## License

This repository is released under Apache-2.0.
assets/gravityllm_banner.svg ADDED
configs/recommended_train_args.json ADDED
@@ -0,0 +1,23 @@
{
  "base_model": "Qwen/Qwen2.5-1.5B-Instruct",
  "train_file": "data/train.jsonl",
  "valid_file": "data/valid.jsonl",
  "output_dir": "outputs/GravityLLM-Qwen2.5-1.5B-S9",
  "max_length": 2048,
  "num_train_epochs": 3,
  "learning_rate": 0.0002,
  "train_batch_size": 1,
  "eval_batch_size": 1,
  "gradient_accumulation_steps": 16,
  "warmup_ratio": 0.03,
  "weight_decay": 0.0,
  "logging_steps": 10,
  "save_steps": 100,
  "eval_steps": 100,
  "seed": 42,
  "qlora": true,
  "bf16": true,
  "lora_r": 16,
  "lora_alpha": 32,
  "lora_dropout": 0.05
}
data/train.jsonl ADDED
@@ -0,0 +1,4 @@
+ {"prompt": "GravityLLM: Output ONLY valid JSON matching the Spatial9Scene schema.\n\nINPUT:\n{\n \"target_format\": \"iamf\",\n \"max_objects\": 10,\n \"style\": \"club\",\n \"section\": \"drop\",\n \"global\": {\n \"bpm\": 128,\n \"energy\": 0.92\n },\n \"stems\": [\n {\n \"id\": \"v1\",\n \"class\": \"lead_vocal\",\n \"lufs\": -16.8,\n \"transient\": 0.25,\n \"band_energy\": {\n \"low\": 0.1,\n \"mid\": 0.6,\n \"high\": 0.3\n },\n \"leadness\": 0.95\n },\n {\n \"id\": \"k1\",\n \"class\": \"kick\",\n \"lufs\": -10.5,\n \"transient\": 0.95,\n \"band_energy\": {\n \"low\": 0.8,\n \"mid\": 0.15,\n \"high\": 0.05\n },\n \"leadness\": 0.25\n },\n {\n \"id\": \"b1\",\n \"class\": \"bass\",\n \"lufs\": -12.2,\n \"transient\": 0.55,\n \"band_energy\": {\n \"low\": 0.85,\n \"mid\": 0.12,\n \"high\": 0.03\n },\n \"leadness\": 0.35\n },\n {\n \"id\": \"p1\",\n \"class\": \"pad\",\n \"lufs\": -20.3,\n \"transient\": 0.1,\n \"band_energy\": {\n \"low\": 0.2,\n \"mid\": 0.5,\n \"high\": 0.3\n },\n \"leadness\": 0.1\n },\n {\n \"id\": \"s1\",\n \"class\": \"synth_lead\",\n \"lufs\": -18.0,\n \"transient\": 0.4,\n \"band_energy\": {\n \"low\": 0.1,\n \"mid\": 0.55,\n \"high\": 0.35\n },\n \"leadness\": 0.75\n },\n {\n \"id\": \"fx1\",\n \"class\": \"fx\",\n \"lufs\": -23.0,\n \"transient\": 0.2,\n \"band_energy\": {\n \"low\": 0.1,\n \"mid\": 0.3,\n \"high\": 0.6\n },\n \"leadness\": 0.05\n }\n ],\n \"rules\": [\n {\n \"type\": \"anchor\",\n \"track_class\": \"lead_vocal\",\n \"az_deg\": 0,\n \"el_deg\": 10,\n \"dist_m\": 1.6\n },\n {\n \"type\": \"mono_low_end\",\n \"hz_below\": 120\n },\n {\n \"type\": \"width_pref\",\n \"track_class\": \"pad\",\n \"min_width\": 0.75\n },\n {\n \"type\": \"avoid_band_masking\",\n \"mask_target\": \"lead_vocal\",\n \"band_hz\": [\n 1500,\n 4500\n ]\n }\n ]\n}", "completion": "{\n \"version\": \"1.0\",\n \"bed\": {\n \"layout\": \"iamf\",\n \"loudness_target_lufs\": -14.0,\n \"room_preset\": \"club_medium\"\n },\n \"objects\": [\n {\n \"id\": 
\"v1\",\n \"class\": \"lead_vocal\",\n \"az_deg\": 0,\n \"el_deg\": 10,\n \"dist_m\": 1.6,\n \"width\": 0.15,\n \"gain_db\": 0.0,\n \"reverb_send\": 0.18,\n \"early_reflections\": 0.22,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": 0,\n \"el_deg\": 10,\n \"dist_m\": 1.6\n },\n {\n \"t\": 1.0,\n \"az_deg\": 0,\n \"el_deg\": 10,\n \"dist_m\": 1.6\n }\n ]\n },\n {\n \"id\": \"k1\",\n \"class\": \"kick\",\n \"az_deg\": 0,\n \"el_deg\": 0,\n \"dist_m\": 2.2,\n \"width\": 0.0,\n \"gain_db\": 0.0,\n \"reverb_send\": 0.02,\n \"early_reflections\": 0.05,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": 0,\n \"el_deg\": 0,\n \"dist_m\": 2.2\n },\n {\n \"t\": 1.0,\n \"az_deg\": 0,\n \"el_deg\": 0,\n \"dist_m\": 2.2\n }\n ]\n },\n {\n \"id\": \"b1\",\n \"class\": \"bass\",\n \"az_deg\": 0,\n \"el_deg\": -5,\n \"dist_m\": 2.6,\n \"width\": 0.05,\n \"gain_db\": -0.5,\n \"reverb_send\": 0.03,\n \"early_reflections\": 0.06,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": 0,\n \"el_deg\": -5,\n \"dist_m\": 2.6\n },\n {\n \"t\": 1.0,\n \"az_deg\": 0,\n \"el_deg\": -5,\n \"dist_m\": 2.6\n }\n ]\n },\n {\n \"id\": \"s1\",\n \"class\": \"synth_lead\",\n \"az_deg\": 25,\n \"el_deg\": 15,\n \"dist_m\": 2.0,\n \"width\": 0.35,\n \"gain_db\": -1.0,\n \"reverb_send\": 0.12,\n \"early_reflections\": 0.1,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": 20,\n \"el_deg\": 15,\n \"dist_m\": 2.0\n },\n {\n \"t\": 1.0,\n \"az_deg\": 30,\n \"el_deg\": 15,\n \"dist_m\": 2.0\n }\n ]\n },\n {\n \"id\": \"p1\",\n \"class\": \"pad\",\n \"az_deg\": -70,\n \"el_deg\": 20,\n \"dist_m\": 4.0,\n \"width\": 0.85,\n \"gain_db\": -3.0,\n \"reverb_send\": 0.28,\n \"early_reflections\": 0.12,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": -70,\n \"el_deg\": 20,\n \"dist_m\": 4.0\n },\n {\n \"t\": 1.0,\n \"az_deg\": -70,\n \"el_deg\": 20,\n \"dist_m\": 4.0\n }\n ]\n },\n {\n \"id\": \"fx1\",\n \"class\": \"fx\",\n \"az_deg\": 110,\n \"el_deg\": 30,\n \"dist_m\": 6.0,\n \"width\": 0.65,\n \"gain_db\": -6.0,\n 
\"reverb_send\": 0.4,\n \"early_reflections\": 0.08,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": 100,\n \"el_deg\": 30,\n \"dist_m\": 6.0\n },\n {\n \"t\": 1.0,\n \"az_deg\": 130,\n \"el_deg\": 30,\n \"dist_m\": 6.0\n }\n ]\n }\n ],\n \"constraints_applied\": [\n \"anchor:lead_vocal@0/10/1.6\",\n \"mono_low_end<120Hz\",\n \"pad_width>=0.75\",\n \"avoid_masking_lead_1500-4500Hz\"\n ]\n}"}
+ {"prompt": "GravityLLM: Output ONLY valid JSON matching the Spatial9Scene schema.\n\nINPUT:\n{\n \"target_format\": \"iamf\",\n \"max_objects\": 8,\n \"style\": \"cinematic\",\n \"section\": \"break\",\n \"global\": {\n \"bpm\": 90,\n \"energy\": 0.45\n },\n \"stems\": [\n {\n \"id\": \"d1\",\n \"class\": \"dialogue\",\n \"lufs\": -18.0,\n \"transient\": 0.35,\n \"band_energy\": {\n \"low\": 0.05,\n \"mid\": 0.75,\n \"high\": 0.2\n },\n \"leadness\": 0.98\n },\n {\n \"id\": \"st1\",\n \"class\": \"strings\",\n \"lufs\": -24.0,\n \"transient\": 0.08,\n \"band_energy\": {\n \"low\": 0.15,\n \"mid\": 0.6,\n \"high\": 0.25\n },\n \"leadness\": 0.25\n },\n {\n \"id\": \"p1\",\n \"class\": \"pad\",\n \"lufs\": -25.5,\n \"transient\": 0.05,\n \"band_energy\": {\n \"low\": 0.25,\n \"mid\": 0.55,\n \"high\": 0.2\n },\n \"leadness\": 0.1\n },\n {\n \"id\": \"fx1\",\n \"class\": \"fx\",\n \"lufs\": -26.0,\n \"transient\": 0.15,\n \"band_energy\": {\n \"low\": 0.1,\n \"mid\": 0.25,\n \"high\": 0.65\n },\n \"leadness\": 0.05\n }\n ],\n \"rules\": [\n {\n \"type\": \"anchor\",\n \"track_class\": \"dialogue\",\n \"az_deg\": 0,\n \"el_deg\": 5,\n \"dist_m\": 1.8\n },\n {\n \"type\": \"keep_dialogue_clear\",\n \"band_hz\": [\n 1000,\n 4000\n ]\n },\n {\n \"type\": \"width_pref\",\n \"track_class\": \"pad\",\n \"min_width\": 0.8\n }\n ]\n}", "completion": "{\n \"version\": \"1.0\",\n \"bed\": {\n \"layout\": \"iamf\",\n \"loudness_target_lufs\": -16.0,\n \"room_preset\": \"cinema_large\"\n },\n \"objects\": [\n {\n \"id\": \"d1\",\n \"class\": \"dialogue\",\n \"az_deg\": 0,\n \"el_deg\": 5,\n \"dist_m\": 1.8,\n \"width\": 0.1,\n \"gain_db\": 0.0,\n \"reverb_send\": 0.08,\n \"early_reflections\": 0.18,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": 0,\n \"el_deg\": 5,\n \"dist_m\": 1.8\n },\n {\n \"t\": 1.0,\n \"az_deg\": 0,\n \"el_deg\": 5,\n \"dist_m\": 1.8\n }\n ]\n },\n {\n \"id\": \"st1\",\n \"class\": \"strings\",\n \"az_deg\": -35,\n \"el_deg\": 18,\n \"dist_m\": 4.2,\n 
\"width\": 0.4,\n \"gain_db\": -4.0,\n \"reverb_send\": 0.24,\n \"early_reflections\": 0.1,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": -35,\n \"el_deg\": 18,\n \"dist_m\": 4.2\n },\n {\n \"t\": 1.0,\n \"az_deg\": -20,\n \"el_deg\": 18,\n \"dist_m\": 4.2\n }\n ]\n },\n {\n \"id\": \"p1\",\n \"class\": \"pad\",\n \"az_deg\": -80,\n \"el_deg\": 18,\n \"dist_m\": 5.5,\n \"width\": 0.85,\n \"gain_db\": -6.0,\n \"reverb_send\": 0.35,\n \"early_reflections\": 0.12,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": -80,\n \"el_deg\": 18,\n \"dist_m\": 5.5\n },\n {\n \"t\": 1.0,\n \"az_deg\": -80,\n \"el_deg\": 18,\n \"dist_m\": 5.5\n }\n ]\n },\n {\n \"id\": \"fx1\",\n \"class\": \"fx\",\n \"az_deg\": 95,\n \"el_deg\": 28,\n \"dist_m\": 7.0,\n \"width\": 0.7,\n \"gain_db\": -10.0,\n \"reverb_send\": 0.55,\n \"early_reflections\": 0.08,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": 80,\n \"el_deg\": 28,\n \"dist_m\": 7.0\n },\n {\n \"t\": 1.0,\n \"az_deg\": 120,\n \"el_deg\": 28,\n \"dist_m\": 7.0\n }\n ]\n }\n ],\n \"constraints_applied\": [\n \"anchor:dialogue@0/5/1.8\",\n \"keep_dialogue_clear_1000-4000Hz\",\n \"pad_width>=0.80\"\n ]\n}"}
+ {"prompt": "GravityLLM: Output ONLY valid JSON matching the Spatial9Scene schema.\n\nINPUT:\n{\n \"target_format\": \"iamf\",\n \"max_objects\": 12,\n \"style\": \"live_stage\",\n \"section\": \"chorus\",\n \"global\": {\n \"bpm\": 112,\n \"energy\": 0.78\n },\n \"stems\": [\n {\n \"id\": \"lv1\",\n \"class\": \"lead_vocal\",\n \"lufs\": -15.5,\n \"transient\": 0.3,\n \"band_energy\": {\n \"low\": 0.08,\n \"mid\": 0.72,\n \"high\": 0.2\n },\n \"leadness\": 0.96\n },\n {\n \"id\": \"g1\",\n \"class\": \"guitar\",\n \"lufs\": -17.4,\n \"transient\": 0.46,\n \"band_energy\": {\n \"low\": 0.15,\n \"mid\": 0.6,\n \"high\": 0.25\n },\n \"leadness\": 0.5\n },\n {\n \"id\": \"g2\",\n \"class\": \"guitar\",\n \"lufs\": -18.1,\n \"transient\": 0.44,\n \"band_energy\": {\n \"low\": 0.12,\n \"mid\": 0.62,\n \"high\": 0.26\n },\n \"leadness\": 0.42\n },\n {\n \"id\": \"dr1\",\n \"class\": \"drums_bus\",\n \"lufs\": -11.8,\n \"transient\": 0.88,\n \"band_energy\": {\n \"low\": 0.4,\n \"mid\": 0.4,\n \"high\": 0.2\n },\n \"leadness\": 0.35\n },\n {\n \"id\": \"b1\",\n \"class\": \"bass\",\n \"lufs\": -12.4,\n \"transient\": 0.61,\n \"band_energy\": {\n \"low\": 0.82,\n \"mid\": 0.14,\n \"high\": 0.04\n },\n \"leadness\": 0.3\n },\n {\n \"id\": \"crowd\",\n \"class\": \"ambience\",\n \"lufs\": -27.0,\n \"transient\": 0.05,\n \"band_energy\": {\n \"low\": 0.2,\n \"mid\": 0.5,\n \"high\": 0.3\n },\n \"leadness\": 0.02\n }\n ],\n \"rules\": [\n {\n \"type\": \"anchor\",\n \"track_class\": \"lead_vocal\",\n \"az_deg\": 0,\n \"el_deg\": 8,\n \"dist_m\": 1.7\n },\n {\n \"type\": \"mono_low_end\",\n \"hz_below\": 120\n },\n {\n \"type\": \"stage_pair\",\n \"left_class\": \"guitar\",\n \"right_class\": \"guitar\"\n }\n ]\n}", "completion": "{\n \"version\": \"1.0\",\n \"bed\": {\n \"layout\": \"iamf\",\n \"loudness_target_lufs\": -15.0,\n \"room_preset\": \"stage_live\"\n },\n \"objects\": [\n {\n \"id\": \"lv1\",\n \"class\": \"lead_vocal\",\n \"az_deg\": 0,\n \"el_deg\": 8,\n 
\"dist_m\": 1.7,\n \"width\": 0.18,\n \"gain_db\": 0.0,\n \"reverb_send\": 0.16,\n \"early_reflections\": 0.24,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": 0,\n \"el_deg\": 8,\n \"dist_m\": 1.7\n },\n {\n \"t\": 1.0,\n \"az_deg\": 0,\n \"el_deg\": 8,\n \"dist_m\": 1.7\n }\n ]\n },\n {\n \"id\": \"g1\",\n \"class\": \"guitar\",\n \"az_deg\": -35,\n \"el_deg\": 6,\n \"dist_m\": 2.8,\n \"width\": 0.26,\n \"gain_db\": -1.5,\n \"reverb_send\": 0.11,\n \"early_reflections\": 0.15,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": -35,\n \"el_deg\": 6,\n \"dist_m\": 2.8\n },\n {\n \"t\": 1.0,\n \"az_deg\": -30,\n \"el_deg\": 6,\n \"dist_m\": 2.8\n }\n ]\n },\n {\n \"id\": \"g2\",\n \"class\": \"guitar\",\n \"az_deg\": 35,\n \"el_deg\": 6,\n \"dist_m\": 2.8,\n \"width\": 0.26,\n \"gain_db\": -1.5,\n \"reverb_send\": 0.11,\n \"early_reflections\": 0.15,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": 35,\n \"el_deg\": 6,\n \"dist_m\": 2.8\n },\n {\n \"t\": 1.0,\n \"az_deg\": 30,\n \"el_deg\": 6,\n \"dist_m\": 2.8\n }\n ]\n },\n {\n \"id\": \"dr1\",\n \"class\": \"drums_bus\",\n \"az_deg\": 0,\n \"el_deg\": 0,\n \"dist_m\": 2.4,\n \"width\": 0.2,\n \"gain_db\": -0.5,\n \"reverb_send\": 0.08,\n \"early_reflections\": 0.12,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": 0,\n \"el_deg\": 0,\n \"dist_m\": 2.4\n },\n {\n \"t\": 1.0,\n \"az_deg\": 0,\n \"el_deg\": 0,\n \"dist_m\": 2.4\n }\n ]\n },\n {\n \"id\": \"b1\",\n \"class\": \"bass\",\n \"az_deg\": 0,\n \"el_deg\": -5,\n \"dist_m\": 2.8,\n \"width\": 0.05,\n \"gain_db\": -0.8,\n \"reverb_send\": 0.04,\n \"early_reflections\": 0.05,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": 0,\n \"el_deg\": -5,\n \"dist_m\": 2.8\n },\n {\n \"t\": 1.0,\n \"az_deg\": 0,\n \"el_deg\": -5,\n \"dist_m\": 2.8\n }\n ]\n },\n {\n \"id\": \"crowd\",\n \"class\": \"ambience\",\n \"az_deg\": 180,\n \"el_deg\": 12,\n \"dist_m\": 10.0,\n \"width\": 1.0,\n \"gain_db\": -12.0,\n \"reverb_send\": 0.6,\n \"early_reflections\": 0.1,\n \"motion\": [\n 
{\n \"t\": 0.0,\n \"az_deg\": 160,\n \"el_deg\": 12,\n \"dist_m\": 10.0\n },\n {\n \"t\": 1.0,\n \"az_deg\": 200,\n \"el_deg\": 12,\n \"dist_m\": 10.0\n }\n ]\n }\n ],\n \"constraints_applied\": [\n \"anchor:lead_vocal@0/8/1.7\",\n \"mono_low_end<120Hz\",\n \"stage_pair:guitars\"\n ]\n}"}
+ {"prompt": "GravityLLM: Output ONLY valid JSON matching the Spatial9Scene schema.\n\nINPUT:\n{\n \"target_format\": \"binaural\",\n \"max_objects\": 6,\n \"style\": \"ambient\",\n \"section\": \"intro\",\n \"global\": {\n \"bpm\": 72,\n \"energy\": 0.28\n },\n \"stems\": [\n {\n \"id\": \"p1\",\n \"class\": \"pad\",\n \"lufs\": -23.5,\n \"transient\": 0.04,\n \"band_energy\": {\n \"low\": 0.25,\n \"mid\": 0.52,\n \"high\": 0.23\n },\n \"leadness\": 0.18\n },\n {\n \"id\": \"pi1\",\n \"class\": \"piano\",\n \"lufs\": -20.2,\n \"transient\": 0.22,\n \"band_energy\": {\n \"low\": 0.1,\n \"mid\": 0.68,\n \"high\": 0.22\n },\n \"leadness\": 0.52\n },\n {\n \"id\": \"fx1\",\n \"class\": \"fx\",\n \"lufs\": -26.4,\n \"transient\": 0.11,\n \"band_energy\": {\n \"low\": 0.07,\n \"mid\": 0.2,\n \"high\": 0.73\n },\n \"leadness\": 0.08\n },\n {\n \"id\": \"a1\",\n \"class\": \"ambience\",\n \"lufs\": -28.0,\n \"transient\": 0.03,\n \"band_energy\": {\n \"low\": 0.18,\n \"mid\": 0.42,\n \"high\": 0.4\n },\n \"leadness\": 0.01\n }\n ],\n \"rules\": [\n {\n \"type\": \"anchor\",\n \"track_class\": \"piano\",\n \"az_deg\": -12,\n \"el_deg\": 6,\n \"dist_m\": 2.4\n },\n {\n \"type\": \"width_pref\",\n \"track_class\": \"pad\",\n \"min_width\": 0.85\n }\n ]\n}", "completion": "{\n \"version\": \"1.0\",\n \"bed\": {\n \"layout\": \"binaural\",\n \"loudness_target_lufs\": -18.0,\n \"room_preset\": \"ambient_wide\"\n },\n \"objects\": [\n {\n \"id\": \"pi1\",\n \"class\": \"piano\",\n \"az_deg\": -12,\n \"el_deg\": 6,\n \"dist_m\": 2.4,\n \"width\": 0.22,\n \"gain_db\": 0.0,\n \"reverb_send\": 0.18,\n \"early_reflections\": 0.17,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": -12,\n \"el_deg\": 6,\n \"dist_m\": 2.4\n },\n {\n \"t\": 1.0,\n \"az_deg\": -8,\n \"el_deg\": 6,\n \"dist_m\": 2.4\n }\n ]\n },\n {\n \"id\": \"p1\",\n \"class\": \"pad\",\n \"az_deg\": -95,\n \"el_deg\": 20,\n \"dist_m\": 6.0,\n \"width\": 0.9,\n \"gain_db\": -4.0,\n \"reverb_send\": 0.35,\n 
\"early_reflections\": 0.12,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": -95,\n \"el_deg\": 20,\n \"dist_m\": 6.0\n },\n {\n \"t\": 1.0,\n \"az_deg\": -70,\n \"el_deg\": 20,\n \"dist_m\": 6.0\n }\n ]\n },\n {\n \"id\": \"fx1\",\n \"class\": \"fx\",\n \"az_deg\": 100,\n \"el_deg\": 35,\n \"dist_m\": 8.0,\n \"width\": 0.75,\n \"gain_db\": -8.0,\n \"reverb_send\": 0.55,\n \"early_reflections\": 0.06,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": 90,\n \"el_deg\": 35,\n \"dist_m\": 8.0\n },\n {\n \"t\": 1.0,\n \"az_deg\": 130,\n \"el_deg\": 35,\n \"dist_m\": 8.0\n }\n ]\n },\n {\n \"id\": \"a1\",\n \"class\": \"ambience\",\n \"az_deg\": 180,\n \"el_deg\": 10,\n \"dist_m\": 12.0,\n \"width\": 1.0,\n \"gain_db\": -10.0,\n \"reverb_send\": 0.65,\n \"early_reflections\": 0.08,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": 160,\n \"el_deg\": 10,\n \"dist_m\": 12.0\n },\n {\n \"t\": 1.0,\n \"az_deg\": 200,\n \"el_deg\": 10,\n \"dist_m\": 12.0\n }\n ]\n }\n ],\n \"constraints_applied\": [\n \"anchor:piano@-12/6/2.4\",\n \"pad_width>=0.85\"\n ]\n}"}
data/valid.jsonl ADDED
@@ -0,0 +1,2 @@
+ {"prompt": "GravityLLM: Output ONLY valid JSON matching the Spatial9Scene schema.\n\nINPUT:\n{\n \"target_format\": \"iamf\",\n \"max_objects\": 8,\n \"style\": \"cinematic\",\n \"section\": \"break\",\n \"global\": {\n \"bpm\": 90,\n \"energy\": 0.45\n },\n \"stems\": [\n {\n \"id\": \"d1\",\n \"class\": \"dialogue\",\n \"lufs\": -18.0,\n \"transient\": 0.35,\n \"band_energy\": {\n \"low\": 0.05,\n \"mid\": 0.75,\n \"high\": 0.2\n },\n \"leadness\": 0.98\n },\n {\n \"id\": \"st1\",\n \"class\": \"strings\",\n \"lufs\": -24.0,\n \"transient\": 0.08,\n \"band_energy\": {\n \"low\": 0.15,\n \"mid\": 0.6,\n \"high\": 0.25\n },\n \"leadness\": 0.25\n },\n {\n \"id\": \"p1\",\n \"class\": \"pad\",\n \"lufs\": -25.5,\n \"transient\": 0.05,\n \"band_energy\": {\n \"low\": 0.25,\n \"mid\": 0.55,\n \"high\": 0.2\n },\n \"leadness\": 0.1\n },\n {\n \"id\": \"fx1\",\n \"class\": \"fx\",\n \"lufs\": -26.0,\n \"transient\": 0.15,\n \"band_energy\": {\n \"low\": 0.1,\n \"mid\": 0.25,\n \"high\": 0.65\n },\n \"leadness\": 0.05\n }\n ],\n \"rules\": [\n {\n \"type\": \"anchor\",\n \"track_class\": \"dialogue\",\n \"az_deg\": 0,\n \"el_deg\": 5,\n \"dist_m\": 1.8\n },\n {\n \"type\": \"keep_dialogue_clear\",\n \"band_hz\": [\n 1000,\n 4000\n ]\n },\n {\n \"type\": \"width_pref\",\n \"track_class\": \"pad\",\n \"min_width\": 0.8\n }\n ]\n}", "completion": "{\n \"version\": \"1.0\",\n \"bed\": {\n \"layout\": \"iamf\",\n \"loudness_target_lufs\": -16.0,\n \"room_preset\": \"cinema_large\"\n },\n \"objects\": [\n {\n \"id\": \"d1\",\n \"class\": \"dialogue\",\n \"az_deg\": 0,\n \"el_deg\": 5,\n \"dist_m\": 1.8,\n \"width\": 0.1,\n \"gain_db\": 0.0,\n \"reverb_send\": 0.08,\n \"early_reflections\": 0.18,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": 0,\n \"el_deg\": 5,\n \"dist_m\": 1.8\n },\n {\n \"t\": 1.0,\n \"az_deg\": 0,\n \"el_deg\": 5,\n \"dist_m\": 1.8\n }\n ]\n },\n {\n \"id\": \"st1\",\n \"class\": \"strings\",\n \"az_deg\": -35,\n \"el_deg\": 18,\n \"dist_m\": 4.2,\n 
\"width\": 0.4,\n \"gain_db\": -4.0,\n \"reverb_send\": 0.24,\n \"early_reflections\": 0.1,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": -35,\n \"el_deg\": 18,\n \"dist_m\": 4.2\n },\n {\n \"t\": 1.0,\n \"az_deg\": -20,\n \"el_deg\": 18,\n \"dist_m\": 4.2\n }\n ]\n },\n {\n \"id\": \"p1\",\n \"class\": \"pad\",\n \"az_deg\": -80,\n \"el_deg\": 18,\n \"dist_m\": 5.5,\n \"width\": 0.85,\n \"gain_db\": -6.0,\n \"reverb_send\": 0.35,\n \"early_reflections\": 0.12,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": -80,\n \"el_deg\": 18,\n \"dist_m\": 5.5\n },\n {\n \"t\": 1.0,\n \"az_deg\": -80,\n \"el_deg\": 18,\n \"dist_m\": 5.5\n }\n ]\n },\n {\n \"id\": \"fx1\",\n \"class\": \"fx\",\n \"az_deg\": 95,\n \"el_deg\": 28,\n \"dist_m\": 7.0,\n \"width\": 0.7,\n \"gain_db\": -10.0,\n \"reverb_send\": 0.55,\n \"early_reflections\": 0.08,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": 80,\n \"el_deg\": 28,\n \"dist_m\": 7.0\n },\n {\n \"t\": 1.0,\n \"az_deg\": 120,\n \"el_deg\": 28,\n \"dist_m\": 7.0\n }\n ]\n }\n ],\n \"constraints_applied\": [\n \"anchor:dialogue@0/5/1.8\",\n \"keep_dialogue_clear_1000-4000Hz\",\n \"pad_width>=0.80\"\n ]\n}"}
+ {"prompt": "GravityLLM: Output ONLY valid JSON matching the Spatial9Scene schema.\n\nINPUT:\n{\n \"target_format\": \"iamf\",\n \"max_objects\": 12,\n \"style\": \"live_stage\",\n \"section\": \"chorus\",\n \"global\": {\n \"bpm\": 112,\n \"energy\": 0.78\n },\n \"stems\": [\n {\n \"id\": \"lv1\",\n \"class\": \"lead_vocal\",\n \"lufs\": -15.5,\n \"transient\": 0.3,\n \"band_energy\": {\n \"low\": 0.08,\n \"mid\": 0.72,\n \"high\": 0.2\n },\n \"leadness\": 0.96\n },\n {\n \"id\": \"g1\",\n \"class\": \"guitar\",\n \"lufs\": -17.4,\n \"transient\": 0.46,\n \"band_energy\": {\n \"low\": 0.15,\n \"mid\": 0.6,\n \"high\": 0.25\n },\n \"leadness\": 0.5\n },\n {\n \"id\": \"g2\",\n \"class\": \"guitar\",\n \"lufs\": -18.1,\n \"transient\": 0.44,\n \"band_energy\": {\n \"low\": 0.12,\n \"mid\": 0.62,\n \"high\": 0.26\n },\n \"leadness\": 0.42\n },\n {\n \"id\": \"dr1\",\n \"class\": \"drums_bus\",\n \"lufs\": -11.8,\n \"transient\": 0.88,\n \"band_energy\": {\n \"low\": 0.4,\n \"mid\": 0.4,\n \"high\": 0.2\n },\n \"leadness\": 0.35\n },\n {\n \"id\": \"b1\",\n \"class\": \"bass\",\n \"lufs\": -12.4,\n \"transient\": 0.61,\n \"band_energy\": {\n \"low\": 0.82,\n \"mid\": 0.14,\n \"high\": 0.04\n },\n \"leadness\": 0.3\n },\n {\n \"id\": \"crowd\",\n \"class\": \"ambience\",\n \"lufs\": -27.0,\n \"transient\": 0.05,\n \"band_energy\": {\n \"low\": 0.2,\n \"mid\": 0.5,\n \"high\": 0.3\n },\n \"leadness\": 0.02\n }\n ],\n \"rules\": [\n {\n \"type\": \"anchor\",\n \"track_class\": \"lead_vocal\",\n \"az_deg\": 0,\n \"el_deg\": 8,\n \"dist_m\": 1.7\n },\n {\n \"type\": \"mono_low_end\",\n \"hz_below\": 120\n },\n {\n \"type\": \"stage_pair\",\n \"left_class\": \"guitar\",\n \"right_class\": \"guitar\"\n }\n ]\n}", "completion": "{\n \"version\": \"1.0\",\n \"bed\": {\n \"layout\": \"iamf\",\n \"loudness_target_lufs\": -15.0,\n \"room_preset\": \"stage_live\"\n },\n \"objects\": [\n {\n \"id\": \"lv1\",\n \"class\": \"lead_vocal\",\n \"az_deg\": 0,\n \"el_deg\": 8,\n 
\"dist_m\": 1.7,\n \"width\": 0.18,\n \"gain_db\": 0.0,\n \"reverb_send\": 0.16,\n \"early_reflections\": 0.24,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": 0,\n \"el_deg\": 8,\n \"dist_m\": 1.7\n },\n {\n \"t\": 1.0,\n \"az_deg\": 0,\n \"el_deg\": 8,\n \"dist_m\": 1.7\n }\n ]\n },\n {\n \"id\": \"g1\",\n \"class\": \"guitar\",\n \"az_deg\": -35,\n \"el_deg\": 6,\n \"dist_m\": 2.8,\n \"width\": 0.26,\n \"gain_db\": -1.5,\n \"reverb_send\": 0.11,\n \"early_reflections\": 0.15,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": -35,\n \"el_deg\": 6,\n \"dist_m\": 2.8\n },\n {\n \"t\": 1.0,\n \"az_deg\": -30,\n \"el_deg\": 6,\n \"dist_m\": 2.8\n }\n ]\n },\n {\n \"id\": \"g2\",\n \"class\": \"guitar\",\n \"az_deg\": 35,\n \"el_deg\": 6,\n \"dist_m\": 2.8,\n \"width\": 0.26,\n \"gain_db\": -1.5,\n \"reverb_send\": 0.11,\n \"early_reflections\": 0.15,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": 35,\n \"el_deg\": 6,\n \"dist_m\": 2.8\n },\n {\n \"t\": 1.0,\n \"az_deg\": 30,\n \"el_deg\": 6,\n \"dist_m\": 2.8\n }\n ]\n },\n {\n \"id\": \"dr1\",\n \"class\": \"drums_bus\",\n \"az_deg\": 0,\n \"el_deg\": 0,\n \"dist_m\": 2.4,\n \"width\": 0.2,\n \"gain_db\": -0.5,\n \"reverb_send\": 0.08,\n \"early_reflections\": 0.12,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": 0,\n \"el_deg\": 0,\n \"dist_m\": 2.4\n },\n {\n \"t\": 1.0,\n \"az_deg\": 0,\n \"el_deg\": 0,\n \"dist_m\": 2.4\n }\n ]\n },\n {\n \"id\": \"b1\",\n \"class\": \"bass\",\n \"az_deg\": 0,\n \"el_deg\": -5,\n \"dist_m\": 2.8,\n \"width\": 0.05,\n \"gain_db\": -0.8,\n \"reverb_send\": 0.04,\n \"early_reflections\": 0.05,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": 0,\n \"el_deg\": -5,\n \"dist_m\": 2.8\n },\n {\n \"t\": 1.0,\n \"az_deg\": 0,\n \"el_deg\": -5,\n \"dist_m\": 2.8\n }\n ]\n },\n {\n \"id\": \"crowd\",\n \"class\": \"ambience\",\n \"az_deg\": 180,\n \"el_deg\": 12,\n \"dist_m\": 10.0,\n \"width\": 1.0,\n \"gain_db\": -12.0,\n \"reverb_send\": 0.6,\n \"early_reflections\": 0.1,\n \"motion\": [\n 
{\n \"t\": 0.0,\n \"az_deg\": 160,\n \"el_deg\": 12,\n \"dist_m\": 10.0\n },\n {\n \"t\": 1.0,\n \"az_deg\": 200,\n \"el_deg\": 12,\n \"dist_m\": 10.0\n }\n ]\n }\n ],\n \"constraints_applied\": [\n \"anchor:lead_vocal@0/8/1.7\",\n \"mono_low_end<120Hz\",\n \"stage_pair:guitars\"\n ]\n}"}
evaluate.py ADDED
@@ -0,0 +1,178 @@
+ import argparse
+ import json
+ import re
+ from pathlib import Path
+ from typing import Dict, Tuple
+
+ import torch
+ from datasets import load_dataset
+ from jsonschema import Draft7Validator
+ from peft import AutoPeftModelForCausalLM
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ SYSTEM_PREFIX = (
+     "You are GravityLLM, a Spatial9 scene generation model. "
+     "Given music constraints and stem features, output ONLY valid Spatial9Scene JSON. "
+     "Do not return markdown. Do not explain your answer.\n\n"
+ )
+
+
+ def parse_args() -> argparse.Namespace:
+     parser = argparse.ArgumentParser(description="Evaluate GravityLLM outputs on a JSONL validation set.")
+     parser.add_argument("--model_dir", type=str, required=True)
+     parser.add_argument("--data_file", type=str, default="data/valid.jsonl")
+     parser.add_argument("--schema_path", type=Path, default=Path("schemas/scene.schema.json"))
+     parser.add_argument("--max_new_tokens", type=int, default=900)
+     parser.add_argument("--temperature", type=float, default=0.2)
+     parser.add_argument("--top_p", type=float, default=0.9)
+     parser.add_argument("--limit", type=int, default=0, help="0 means evaluate all rows.")
+     parser.add_argument("--report_path", type=Path, default=Path("reports/eval_report.json"))
+     return parser.parse_args()
+
+
+ def load_model_and_tokenizer(model_dir: str):
+     tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True, trust_remote_code=True)
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+
+     # Prefer loading as a PEFT adapter; fall back to a plain causal LM checkpoint.
+     try:
+         model = AutoPeftModelForCausalLM.from_pretrained(
+             model_dir,
+             torch_dtype=torch.bfloat16 if torch.cuda.is_available() else None,
+             device_map="auto" if torch.cuda.is_available() else None,
+             trust_remote_code=True,
+         )
+     except Exception:
+         model = AutoModelForCausalLM.from_pretrained(
+             model_dir,
+             torch_dtype=torch.bfloat16 if torch.cuda.is_available() else None,
+             device_map="auto" if torch.cuda.is_available() else None,
+             trust_remote_code=True,
+         )
+     model.eval()
+     return model, tokenizer
+
+
+ def format_prompt(raw_prompt: str) -> str:
+     raw_prompt = raw_prompt.strip()
+     if raw_prompt.lower().startswith("gravityllm:"):
+         raw_prompt = raw_prompt.split(":", 1)[1].strip()
+     return SYSTEM_PREFIX + raw_prompt + "\n\nOUTPUT:\n"
+
+
+ def extract_first_json(text: str) -> str:
+     match = re.search(r"\{.*\}", text, flags=re.DOTALL)
+     return match.group(0).strip() if match else text.strip()
+
+
+ def validate_schema(schema, output_text: str) -> Tuple[bool, Dict]:
+     data = json.loads(output_text)
+     validator = Draft7Validator(schema)
+     errors = sorted(validator.iter_errors(data), key=lambda e: list(e.path))
+     return len(errors) == 0, data
+
+
+ def check_budget(input_payload: Dict, scene_payload: Dict) -> bool:
+     max_objects = input_payload.get("max_objects")
+     if max_objects is None:
+         return True
+     return len(scene_payload.get("objects", [])) <= max_objects
+
+
+ def check_anchor_rules(input_payload: Dict, scene_payload: Dict) -> bool:
+     objects = {obj["class"]: obj for obj in scene_payload.get("objects", [])}
+     for rule in input_payload.get("rules", []):
+         if rule.get("type") != "anchor":
+             continue
+         klass = rule.get("track_class")
+         obj = objects.get(klass)
+         if obj is None:
+             return False
+         for field in ["az_deg", "el_deg", "dist_m"]:
+             if float(obj[field]) != float(rule[field]):
+                 return False
+     return True
+
+
+ def generate_scene(model, tokenizer, prompt_text: str, max_new_tokens: int, temperature: float, top_p: float) -> str:
+     inputs = tokenizer(prompt_text, return_tensors="pt")
+     if torch.cuda.is_available():
+         inputs = {k: v.to(model.device) for k, v in inputs.items()}
+
+     with torch.no_grad():
+         outputs = model.generate(
+             **inputs,
+             max_new_tokens=max_new_tokens,
+             do_sample=True,
+             temperature=temperature,
+             top_p=top_p,
+             eos_token_id=tokenizer.eos_token_id,
+             pad_token_id=tokenizer.pad_token_id,
+         )
+
+     decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
+     prompt_prefix = tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
+     raw_completion = decoded[len(prompt_prefix):].strip()
+     return extract_first_json(raw_completion)
+
+
+ def main() -> None:
+     args = parse_args()
+     schema = json.loads(args.schema_path.read_text(encoding="utf-8"))
+     ds = load_dataset("json", data_files=args.data_file, split="train")
+     if args.limit > 0:
+         ds = ds.select(range(min(args.limit, len(ds))))
+
+     model, tokenizer = load_model_and_tokenizer(args.model_dir)
+
+     total = len(ds)
+     parse_ok = 0
+     schema_ok = 0
+     budget_ok = 0
+     anchor_ok = 0
+     samples = []
+
+     for row in ds:
+         prompt_text = format_prompt(row["prompt"])
+         generated = generate_scene(model, tokenizer, prompt_text, args.max_new_tokens, args.temperature, args.top_p)
+
+         sample_report = {"prompt": row["prompt"], "generated": generated}
+         try:
+             # validate_schema raises on JSON parse failure, so reaching the next
+             # line means the output parsed.
+             valid, gen_scene = validate_schema(schema, generated)
+             parse_ok += 1
+             if valid:
+                 schema_ok += 1
+                 # Reconstruct the input payload from the prompt for simple rule checks.
+                 prompt_payload_text = row["prompt"].split("INPUT:\n", 1)[1]
+                 input_payload = json.loads(prompt_payload_text)
+                 budget_pass = check_budget(input_payload, gen_scene)
+                 anchor_pass = check_anchor_rules(input_payload, gen_scene)
+                 if budget_pass:
+                     budget_ok += 1
+                 if anchor_pass:
+                     anchor_ok += 1
+                 sample_report["schema_valid"] = True
+                 sample_report["budget_pass"] = budget_pass
+                 sample_report["anchor_pass"] = anchor_pass
+             else:
+                 sample_report["schema_valid"] = False
+         except Exception as exc:
+             sample_report["error"] = str(exc)
+
+         samples.append(sample_report)
+
+     report = {
+         "examples": total,
+         "json_parse_rate": round(parse_ok / total, 4) if total else 0.0,
+         "schema_valid_rate": round(schema_ok / total, 4) if total else 0.0,
+         "budget_pass_rate": round(budget_ok / total, 4) if total else 0.0,
+         "anchor_pass_rate": round(anchor_ok / total, 4) if total else 0.0,
+         "samples": samples[:10],
+     }
+
+     args.report_path.parent.mkdir(parents=True, exist_ok=True)
+     args.report_path.write_text(json.dumps(report, indent=2), encoding="utf-8")
+     print(json.dumps(report, indent=2))
+
+
+ if __name__ == "__main__":
+     main()
examples/sample_input.json ADDED
@@ -0,0 +1,110 @@
+ {
+   "target_format": "iamf",
+   "max_objects": 10,
+   "style": "club",
+   "section": "drop",
+   "global": {
+     "bpm": 128,
+     "energy": 0.92
+   },
+   "stems": [
+     {
+       "id": "v1",
+       "class": "lead_vocal",
+       "lufs": -16.8,
+       "transient": 0.25,
+       "band_energy": {
+         "low": 0.1,
+         "mid": 0.6,
+         "high": 0.3
+       },
+       "leadness": 0.95
+     },
+     {
+       "id": "k1",
+       "class": "kick",
+       "lufs": -10.5,
+       "transient": 0.95,
+       "band_energy": {
+         "low": 0.8,
+         "mid": 0.15,
+         "high": 0.05
+       },
+       "leadness": 0.25
+     },
+     {
+       "id": "b1",
+       "class": "bass",
+       "lufs": -12.2,
+       "transient": 0.55,
+       "band_energy": {
+         "low": 0.85,
+         "mid": 0.12,
+         "high": 0.03
+       },
+       "leadness": 0.35
+     },
+     {
+       "id": "p1",
+       "class": "pad",
+       "lufs": -20.3,
+       "transient": 0.1,
+       "band_energy": {
+         "low": 0.2,
+         "mid": 0.5,
+         "high": 0.3
+       },
+       "leadness": 0.1
+     },
+     {
+       "id": "s1",
+       "class": "synth_lead",
+       "lufs": -18.0,
+       "transient": 0.4,
+       "band_energy": {
+         "low": 0.1,
+         "mid": 0.55,
+         "high": 0.35
+       },
+       "leadness": 0.75
+     },
+     {
+       "id": "fx1",
+       "class": "fx",
+       "lufs": -23.0,
+       "transient": 0.2,
+       "band_energy": {
+         "low": 0.1,
+         "mid": 0.3,
+         "high": 0.6
+       },
+       "leadness": 0.05
+     }
+   ],
+   "rules": [
+     {
+       "type": "anchor",
+       "track_class": "lead_vocal",
+       "az_deg": 0,
+       "el_deg": 10,
+       "dist_m": 1.6
+     },
+     {
+       "type": "mono_low_end",
+       "hz_below": 120
+     },
+     {
+       "type": "width_pref",
+       "track_class": "pad",
+       "min_width": 0.75
+     },
+     {
+       "type": "avoid_band_masking",
+       "mask_target": "lead_vocal",
+       "band_hz": [
+         1500,
+         4500
+       ]
+     }
+   ]
+ }
examples/sample_output.json ADDED
@@ -0,0 +1,166 @@
+ {
+   "version": "1.0",
+   "bed": {
+     "layout": "iamf",
+     "loudness_target_lufs": -14.0,
+     "room_preset": "club_medium"
+   },
+   "objects": [
+     {
+       "id": "v1",
+       "class": "lead_vocal",
+       "az_deg": 0,
+       "el_deg": 10,
+       "dist_m": 1.6,
+       "width": 0.15,
+       "gain_db": 0.0,
+       "reverb_send": 0.18,
+       "early_reflections": 0.22,
+       "motion": [
+         {
+           "t": 0.0,
+           "az_deg": 0,
+           "el_deg": 10,
+           "dist_m": 1.6
+         },
+         {
+           "t": 1.0,
+           "az_deg": 0,
+           "el_deg": 10,
+           "dist_m": 1.6
+         }
+       ]
+     },
+     {
+       "id": "k1",
+       "class": "kick",
+       "az_deg": 0,
+       "el_deg": 0,
+       "dist_m": 2.2,
+       "width": 0.0,
+       "gain_db": 0.0,
+       "reverb_send": 0.02,
+       "early_reflections": 0.05,
+       "motion": [
+         {
+           "t": 0.0,
+           "az_deg": 0,
+           "el_deg": 0,
+           "dist_m": 2.2
+         },
+         {
+           "t": 1.0,
+           "az_deg": 0,
+           "el_deg": 0,
+           "dist_m": 2.2
+         }
+       ]
+     },
+     {
+       "id": "b1",
+       "class": "bass",
+       "az_deg": 0,
+       "el_deg": -5,
+       "dist_m": 2.6,
+       "width": 0.05,
+       "gain_db": -0.5,
+       "reverb_send": 0.03,
+       "early_reflections": 0.06,
+       "motion": [
+         {
+           "t": 0.0,
+           "az_deg": 0,
+           "el_deg": -5,
+           "dist_m": 2.6
+         },
+         {
+           "t": 1.0,
+           "az_deg": 0,
+           "el_deg": -5,
+           "dist_m": 2.6
+         }
+       ]
+     },
+     {
+       "id": "s1",
+       "class": "synth_lead",
+       "az_deg": 25,
+       "el_deg": 15,
+       "dist_m": 2.0,
+       "width": 0.35,
+       "gain_db": -1.0,
+       "reverb_send": 0.12,
+       "early_reflections": 0.1,
+       "motion": [
+         {
+           "t": 0.0,
+           "az_deg": 20,
+           "el_deg": 15,
+           "dist_m": 2.0
+         },
+         {
+           "t": 1.0,
+           "az_deg": 30,
+           "el_deg": 15,
+           "dist_m": 2.0
+         }
+       ]
+     },
+     {
+       "id": "p1",
+       "class": "pad",
+       "az_deg": -70,
+       "el_deg": 20,
+       "dist_m": 4.0,
+       "width": 0.85,
+       "gain_db": -3.0,
+       "reverb_send": 0.28,
+       "early_reflections": 0.12,
+       "motion": [
+         {
+           "t": 0.0,
+           "az_deg": -70,
+           "el_deg": 20,
+           "dist_m": 4.0
+         },
+         {
+           "t": 1.0,
+           "az_deg": -70,
+           "el_deg": 20,
+           "dist_m": 4.0
+         }
+       ]
+     },
+     {
+       "id": "fx1",
+       "class": "fx",
+       "az_deg": 110,
+       "el_deg": 30,
+       "dist_m": 6.0,
+       "width": 0.65,
+       "gain_db": -6.0,
+       "reverb_send": 0.4,
+       "early_reflections": 0.08,
+       "motion": [
+         {
+           "t": 0.0,
+           "az_deg": 100,
+           "el_deg": 30,
+           "dist_m": 6.0
+         },
+         {
+           "t": 1.0,
+           "az_deg": 130,
+           "el_deg": 30,
+           "dist_m": 6.0
+         }
+       ]
+     }
+   ],
+   "constraints_applied": [
+     "anchor:lead_vocal@0/10/1.6",
+     "mono_low_end<120Hz",
+     "pad_width>=0.75",
+     "avoid_masking_lead_1500-4500Hz"
+   ]
+ }
infer.py ADDED
@@ -0,0 +1,117 @@
+ import argparse
+ import json
+ import re
+ from pathlib import Path
+
+ import torch
+ from jsonschema import Draft7Validator
+ from peft import AutoPeftModelForCausalLM
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ SYSTEM_PREFIX = (
+     "You are GravityLLM, a Spatial9 scene generation model. "
+     "Given music constraints and stem features, output ONLY valid Spatial9Scene JSON. "
+     "Do not return markdown. Do not explain your answer.\n\n"
+ )
+
+
+ def parse_args() -> argparse.Namespace:
+     parser = argparse.ArgumentParser(description="Run GravityLLM inference on a Spatial9 constraint payload.")
+     parser.add_argument("--model_dir", type=str, required=True, help="Path or Hub repo id for trained model or adapter.")
+     parser.add_argument("--input_json", type=Path, required=True)
+     parser.add_argument("--schema_path", type=Path, default=Path("schemas/scene.schema.json"))
+     parser.add_argument("--output_json", type=Path, default=None)
+     parser.add_argument("--max_new_tokens", type=int, default=900)
+     parser.add_argument("--temperature", type=float, default=0.35)
+     parser.add_argument("--top_p", type=float, default=0.9)
+     parser.add_argument("--validate", action="store_true")
+     return parser.parse_args()
+
+
+ def load_model_and_tokenizer(model_dir: str):
+     tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True, trust_remote_code=True)
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+
+     # Prefer loading as a PEFT adapter; fall back to a plain causal LM checkpoint.
+     try:
+         model = AutoPeftModelForCausalLM.from_pretrained(
+             model_dir,
+             torch_dtype=torch.bfloat16 if torch.cuda.is_available() else None,
+             device_map="auto" if torch.cuda.is_available() else None,
+             trust_remote_code=True,
+         )
+     except Exception:
+         model = AutoModelForCausalLM.from_pretrained(
+             model_dir,
+             torch_dtype=torch.bfloat16 if torch.cuda.is_available() else None,
+             device_map="auto" if torch.cuda.is_available() else None,
+             trust_remote_code=True,
+         )
+     model.eval()
+     return model, tokenizer
+
+
+ def extract_first_json(text: str) -> str:
+     match = re.search(r"\{.*\}", text, flags=re.DOTALL)
+     return match.group(0).strip() if match else text.strip()
+
+
+ def validate_output(schema_path: Path, output_text: str):
+     schema = json.loads(schema_path.read_text(encoding="utf-8"))
+     data = json.loads(output_text)
+     validator = Draft7Validator(schema)
+     errors = sorted(validator.iter_errors(data), key=lambda e: list(e.path))
+     return data, errors
+
+
+ def main() -> None:
+     args = parse_args()
+     payload = json.loads(args.input_json.read_text(encoding="utf-8"))
+
+     model, tokenizer = load_model_and_tokenizer(args.model_dir)
+     prompt = SYSTEM_PREFIX + "INPUT:\n" + json.dumps(payload, ensure_ascii=False, indent=2) + "\n\nOUTPUT:\n"
+
+     inputs = tokenizer(prompt, return_tensors="pt")
+     if torch.cuda.is_available():
+         inputs = {k: v.to(model.device) for k, v in inputs.items()}
+
+     with torch.no_grad():
+         outputs = model.generate(
+             **inputs,
+             max_new_tokens=args.max_new_tokens,
+             do_sample=True,
+             temperature=args.temperature,
+             top_p=args.top_p,
+             eos_token_id=tokenizer.eos_token_id,
+             pad_token_id=tokenizer.pad_token_id,
+         )
+
+     decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
+     prompt_prefix = tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
+     raw_completion = decoded[len(prompt_prefix):].strip()
+     json_text = extract_first_json(raw_completion)
+
+     if args.validate:
+         try:
+             _, errors = validate_output(args.schema_path, json_text)
+             if errors:
+                 print("Validation: INVALID")
+                 for err in errors[:20]:
+                     path = ".".join(str(p) for p in err.path)
+                     print(f"- {path}: {err.message}")
+             else:
+                 print("Validation: VALID")
+         except Exception as exc:
+             print(f"Validation failed: {exc}")
+
+     if args.output_json:
+         args.output_json.parent.mkdir(parents=True, exist_ok=True)
+         args.output_json.write_text(json_text + "\n", encoding="utf-8")
+         print(f"Wrote output to {args.output_json}")
+
+     print(json_text)
+
+
+ if __name__ == "__main__":
+     main()
pyproject.toml ADDED
@@ -0,0 +1,21 @@
+ [project]
+ name = "gravityllm"
+ version = "0.1.0"
+ description = "Constraint-driven Spatial9 scene generation for immersive audio."
+ readme = "README.md"
+ requires-python = ">=3.10"
+ license = {text = "Apache-2.0"}
+ dependencies = [
+     "accelerate>=0.31.0",
+     "bitsandbytes>=0.43.0",
+     "datasets>=2.19.0",
+     "huggingface_hub>=0.24.0",
+     "jsonschema>=4.22.0",
+     "peft>=0.11.0",
+     "safetensors>=0.4.3",
+     "torch",
+     "transformers>=4.41.0",
+ ]
+
+ [tool.setuptools]
+ py-modules = ["train", "infer", "evaluate", "upload_to_hub"]
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ accelerate>=0.31.0
+ bitsandbytes>=0.43.0
+ datasets>=2.19.0
+ huggingface_hub>=0.24.0
+ jsonschema>=4.22.0
+ peft>=0.11.0
+ safetensors>=0.4.3
+ torch
+ transformers>=4.41.0
schemas/scene.schema.json ADDED
@@ -0,0 +1,169 @@
+ {
+   "$schema": "http://json-schema.org/draft-07/schema#",
+   "title": "Spatial9Scene",
+   "type": "object",
+   "required": [
+     "version",
+     "bed",
+     "objects",
+     "constraints_applied"
+   ],
+   "properties": {
+     "version": {
+       "type": "string"
+     },
+     "bed": {
+       "type": "object",
+       "required": [
+         "layout",
+         "loudness_target_lufs",
+         "room_preset"
+       ],
+       "properties": {
+         "layout": {
+           "type": "string",
+           "enum": [
+             "binaural",
+             "5.1.4",
+             "7.1.4",
+             "iamf"
+           ]
+         },
+         "loudness_target_lufs": {
+           "type": "number"
+         },
+         "room_preset": {
+           "type": "string"
+         }
+       },
+       "additionalProperties": false
+     },
+     "objects": {
+       "type": "array",
+       "minItems": 1,
+       "maxItems": 32,
+       "items": {
+         "type": "object",
+         "required": [
+           "id",
+           "class",
+           "az_deg",
+           "el_deg",
+           "dist_m",
+           "width",
+           "gain_db",
+           "reverb_send",
+           "early_reflections",
+           "motion"
+         ],
+         "properties": {
+           "id": {
+             "type": "string"
+           },
+           "class": {
+             "type": "string",
+             "enum": [
+               "lead_vocal",
+               "back_vocal",
+               "kick",
+               "snare",
+               "hihat",
+               "bass",
+               "drums_bus",
+               "pad",
+               "synth_lead",
+               "guitar",
+               "piano",
+               "strings",
+               "brass",
+               "fx",
+               "dialogue",
+               "ambience",
+               "other"
+             ]
+           },
+           "az_deg": {
+             "type": "number",
+             "minimum": -180,
+             "maximum": 180
+           },
+           "el_deg": {
+             "type": "number",
+             "minimum": -45,
+             "maximum": 90
+           },
+           "dist_m": {
+             "type": "number",
+             "minimum": 0.5,
+             "maximum": 15.0
+           },
+           "width": {
+             "type": "number",
+             "minimum": 0.0,
+             "maximum": 1.0
+           },
+           "gain_db": {
+             "type": "number",
+             "minimum": -60,
+             "maximum": 12
+           },
+           "reverb_send": {
+             "type": "number",
+             "minimum": 0.0,
+             "maximum": 1.0
+           },
+           "early_reflections": {
+             "type": "number",
+             "minimum": 0.0,
+             "maximum": 1.0
+           },
+           "motion": {
+             "type": "array",
+             "maxItems": 64,
+             "items": {
+               "type": "object",
+               "required": [
+                 "t",
+                 "az_deg",
+                 "el_deg",
+                 "dist_m"
+               ],
+               "properties": {
+                 "t": {
+                   "type": "number",
+                   "minimum": 0.0,
+                   "maximum": 1.0
+                 },
+                 "az_deg": {
+                   "type": "number",
+                   "minimum": -180,
+                   "maximum": 180
+                 },
+                 "el_deg": {
+                   "type": "number",
+                   "minimum": -45,
+                   "maximum": 90
+                 },
+                 "dist_m": {
+                   "type": "number",
+                   "minimum": 0.5,
+                   "maximum": 15.0
+                 }
+               },
+               "additionalProperties": false
+             }
+           }
+         },
+         "additionalProperties": false
+       }
+     },
+     "constraints_applied": {
+       "type": "array",
+       "items": {
+         "type": "string"
+       },
+       "maxItems": 64
+     }
+   },
+   "additionalProperties": false
+ }
scripts/push_to_hub.sh ADDED
@@ -0,0 +1,9 @@
+ #!/usr/bin/env bash
+ set -euo pipefail
+
+ if [ $# -lt 2 ]; then
+     echo "Usage: scripts/push_to_hub.sh <folder_path> <namespace/repo>"
+     exit 1
+ fi
+
+ python upload_to_hub.py --folder_path "$1" --repo_id "$2"
scripts/train_qlora.sh ADDED
@@ -0,0 +1,19 @@
+ #!/usr/bin/env bash
+ set -euo pipefail
+
+ python train.py \
+     --model Qwen/Qwen2.5-1.5B-Instruct \
+     --train_file data/train.jsonl \
+     --valid_file data/valid.jsonl \
+     --output_dir outputs/GravityLLM-Qwen2.5-1.5B-S9 \
+     --max_length 2048 \
+     --num_train_epochs 3 \
+     --learning_rate 2e-4 \
+     --train_batch_size 1 \
+     --eval_batch_size 1 \
+     --gradient_accumulation_steps 16 \
+     --warmup_ratio 0.03 \
+     --logging_steps 10 \
+     --save_steps 100 \
+     --eval_steps 100 \
+     --qlora --bf16
tools/make_synthetic_dataset.py ADDED
@@ -0,0 +1,118 @@
+ import argparse
+ import json
+ import random
+ from pathlib import Path
+
+
+ CLASSES = [
+     ("lead_vocal", 0.95),
+     ("kick", 0.25),
+     ("bass", 0.35),
+     ("pad", 0.10),
+     ("synth_lead", 0.70),
+     ("fx", 0.05),
+ ]
+
+ STYLE_PRESETS = {
+     "club": {"layout": "iamf", "room_preset": "club_medium", "lufs": -14.0},
+     "cinematic": {"layout": "iamf", "room_preset": "cinema_large", "lufs": -16.0},
+     "live_stage": {"layout": "iamf", "room_preset": "stage_live", "lufs": -15.0},
+ }
+
+
+ def build_example(seed: int) -> dict:
+     rng = random.Random(seed)
+     style = rng.choice(list(STYLE_PRESETS.keys()))
+     bpm = rng.choice([90, 100, 110, 120, 128, 140])
+     energy = round(rng.uniform(0.35, 0.95), 2)
+     max_objects = rng.choice([8, 10, 12])
+
+     stems = []
+     for idx, (klass, leadness) in enumerate(CLASSES, start=1):
+         stems.append(
+             {
+                 "id": f"{klass[:1]}{idx}",
+                 "class": klass,
+                 "lufs": round(rng.uniform(-25.0, -10.0), 1),
+                 "transient": round(rng.uniform(0.05, 0.98), 2),
+                 "band_energy": {
+                     "low": round(rng.uniform(0.05, 0.9), 2),
+                     "mid": round(rng.uniform(0.05, 0.9), 2),
+                     "high": round(rng.uniform(0.05, 0.9), 2),
+                 },
+                 "leadness": leadness,
+             }
+         )
+
+     payload = {
+         "target_format": "iamf",
+         "max_objects": max_objects,
+         "style": style,
+         "section": rng.choice(["intro", "verse", "break", "drop"]),
+         "global": {"bpm": bpm, "energy": energy},
+         "stems": stems,
+         "rules": [
+             {"type": "anchor", "track_class": "lead_vocal", "az_deg": 0, "el_deg": 10, "dist_m": 1.6},
+             {"type": "mono_low_end", "hz_below": 120},
+             {"type": "width_pref", "track_class": "pad", "min_width": 0.75},
+         ],
+     }
+
+     preset = STYLE_PRESETS[style]
+     output = {
+         "version": "1.0",
+         "bed": {
+             "layout": preset["layout"],
+             "loudness_target_lufs": preset["lufs"],
+             "room_preset": preset["room_preset"],
+         },
+         "objects": [
+             {
+                 "id": stems[0]["id"],
+                 "class": "lead_vocal",
+                 "az_deg": 0,
+                 "el_deg": 10,
+                 "dist_m": 1.6,
+                 "width": 0.15,
+                 "gain_db": 0.0,
+                 "reverb_send": 0.18,
+                 "early_reflections": 0.2,
+                 "motion": [
+                     {"t": 0.0, "az_deg": 0, "el_deg": 10, "dist_m": 1.6},
+                     {"t": 1.0, "az_deg": 0, "el_deg": 10, "dist_m": 1.6},
+                 ],
+             }
+         ],
+         "constraints_applied": [
+             "anchor:lead_vocal@0/10/1.6",
+             "mono_low_end<120Hz",
+             "pad_width>=0.75",
+         ],
+     }
+
+     prompt = (
+         "GravityLLM: Output ONLY valid JSON matching the Spatial9Scene schema.\n\n"
+         "INPUT:\n" + json.dumps(payload, indent=2)
+     )
+
+     return {"prompt": prompt, "completion": json.dumps(output, indent=2)}
+
+
+ def main() -> None:
+     parser = argparse.ArgumentParser(description="Generate a small synthetic GravityLLM dataset.")
+     parser.add_argument("--output", type=Path, default=Path("data/synthetic_train.jsonl"))
+     parser.add_argument("--count", type=int, default=25)
+     parser.add_argument("--seed", type=int, default=42)
+     args = parser.parse_args()
+
+     args.output.parent.mkdir(parents=True, exist_ok=True)
+     with args.output.open("w", encoding="utf-8") as f:
+         for i in range(args.count):
+             row = build_example(args.seed + i)
+             f.write(json.dumps(row, ensure_ascii=False) + "\n")
+
+     print(f"Wrote {args.count} examples to {args.output}")
+
+
+ if __name__ == "__main__":
+     main()
tools/validate_scene.py ADDED
@@ -0,0 +1,35 @@
+ import argparse
+ import json
+ from pathlib import Path
+
+ from jsonschema import Draft7Validator
+
+
+ def validate(schema_path: Path, json_path: Path) -> int:
+     schema = json.loads(schema_path.read_text(encoding="utf-8"))
+     payload = json.loads(json_path.read_text(encoding="utf-8"))
+
+     validator = Draft7Validator(schema)
+     # Stringify path elements so string keys and integer list indices can be
+     # compared without raising TypeError during sorting.
+     errors = sorted(validator.iter_errors(payload), key=lambda e: [str(p) for p in e.path])
+
+     if not errors:
+         print("VALID")
+         return 0
+
+     print("INVALID")
+     for err in errors[:50]:
+         path = ".".join(str(p) for p in err.path)
+         print(f"- {path}: {err.message}")
+     return 1
+
+
+ def main() -> int:
+     parser = argparse.ArgumentParser(description="Validate a GravityLLM scene JSON against the Spatial9Scene schema.")
+     parser.add_argument("schema_path", type=Path)
+     parser.add_argument("json_path", type=Path)
+     args = parser.parse_args()
+     return validate(args.schema_path, args.json_path)
+
+
+ if __name__ == "__main__":
+     raise SystemExit(main())
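jsonschema reports each error's location as a sequence of dict keys and list indices; the validator joins these into a dotted path for printing. A stdlib-only sketch of that join, with made-up paths standing in for real validation errors:

```python
from collections import deque

# jsonschema exposes error locations as deques mixing string keys and integer
# list indices; joining them after str() gives a readable dotted path.
paths = [deque(["objects", 0, "az_deg"]), deque(["meta"])]
formatted = [".".join(str(part) for part in p) for p in paths]
print(formatted)  # ['objects.0.az_deg', 'meta']
```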
train.py ADDED
@@ -0,0 +1,245 @@
+ import argparse
+ import json
+ import os
+ import random
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Dict, List
+
+ import torch
+ from datasets import load_dataset
+ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     BitsAndBytesConfig,
+     Trainer,
+     TrainingArguments,
+     set_seed,
+ )
+
+ SYSTEM_PREFIX = (
+     "You are GravityLLM, a Spatial9 scene generation model. "
+     "Given music constraints and stem features, output ONLY valid Spatial9Scene JSON. "
+     "Do not return markdown. Do not explain your answer. "
+     "Respect hard constraints such as object budgets, anchor positions, and low-end centering.\n\n"
+ )
+
+
+ def parse_args() -> argparse.Namespace:
+     parser = argparse.ArgumentParser(description="Fine-tune GravityLLM for Spatial9 scene generation.")
+     parser.add_argument("--model", type=str, default="Qwen/Qwen2.5-1.5B-Instruct")
+     parser.add_argument("--train_file", type=str, default="data/train.jsonl")
+     parser.add_argument("--valid_file", type=str, default="data/valid.jsonl")
+     parser.add_argument("--output_dir", type=str, default="outputs/GravityLLM-Qwen2.5-1.5B-S9")
+     parser.add_argument("--max_length", type=int, default=2048)
+
+     parser.add_argument("--num_train_epochs", type=float, default=1.0)
+     parser.add_argument("--learning_rate", type=float, default=2e-4)
+     parser.add_argument("--train_batch_size", type=int, default=1)
+     parser.add_argument("--eval_batch_size", type=int, default=1)
+     parser.add_argument("--gradient_accumulation_steps", type=int, default=16)
+     parser.add_argument("--warmup_ratio", type=float, default=0.03)
+     parser.add_argument("--weight_decay", type=float, default=0.0)
+     parser.add_argument("--logging_steps", type=int, default=10)
+     parser.add_argument("--save_steps", type=int, default=200)
+     parser.add_argument("--eval_steps", type=int, default=200)
+     parser.add_argument("--seed", type=int, default=42)
+
+     parser.add_argument("--lora", action="store_true", help="Enable LoRA adapters.")
+     parser.add_argument("--qlora", action="store_true", help="Enable 4-bit QLoRA training.")
+     parser.add_argument("--lora_r", type=int, default=16)
+     parser.add_argument("--lora_alpha", type=int, default=32)
+     parser.add_argument("--lora_dropout", type=float, default=0.05)
+
+     parser.add_argument("--bf16", action="store_true")
+     parser.add_argument("--fp16", action="store_true")
+
+     parser.add_argument("--push_to_hub", action="store_true")
+     parser.add_argument("--hub_model_id", type=str, default=None)
+     parser.add_argument("--hub_private_repo", action="store_true")
+     return parser.parse_args()
+
+
+ def load_jsonl(file_path: str):
+     return load_dataset("json", data_files=file_path, split="train")
+
+
+ def format_prompt(raw_prompt: str) -> str:
+     raw_prompt = raw_prompt.strip()
+     if raw_prompt.lower().startswith("gravityllm:"):
+         raw_prompt = raw_prompt.split(":", 1)[1].strip()
+     return SYSTEM_PREFIX + raw_prompt + "\n\nOUTPUT:\n"
+
+
+ def tokenize_example(example: Dict[str, str], tokenizer, max_length: int) -> Dict[str, List[int]]:
+     prompt_text = format_prompt(example["prompt"])
+     completion_text = example["completion"].strip()
+
+     prompt_ids = tokenizer(prompt_text, add_special_tokens=False)["input_ids"]
+     completion_ids = tokenizer(completion_text + tokenizer.eos_token, add_special_tokens=False)["input_ids"]
+
+     input_ids = prompt_ids + completion_ids
+     labels = [-100] * len(prompt_ids) + completion_ids
+
+     if len(input_ids) > max_length:
+         input_ids = input_ids[:max_length]
+         labels = labels[:max_length]
+
+     attention_mask = [1] * len(input_ids)
+     return {
+         "input_ids": input_ids,
+         "attention_mask": attention_mask,
+         "labels": labels,
+     }
+
+
+ @dataclass
+ class CausalDataCollator:
+     pad_token_id: int
+     label_pad_token_id: int = -100
+
+     def __call__(self, features):
+         max_len = max(len(f["input_ids"]) for f in features)
+
+         input_ids = []
+         attention_mask = []
+         labels = []
+
+         for f in features:
+             pad_len = max_len - len(f["input_ids"])
+             input_ids.append(f["input_ids"] + [self.pad_token_id] * pad_len)
+             attention_mask.append(f["attention_mask"] + [0] * pad_len)
+             labels.append(f["labels"] + [self.label_pad_token_id] * pad_len)
+
+         batch = {
+             "input_ids": torch.tensor(input_ids, dtype=torch.long),
+             "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
+             "labels": torch.tensor(labels, dtype=torch.long),
+         }
+         return batch
+
+
+ def prepare_model(args: argparse.Namespace):
+     model_kwargs = {}
+     if args.qlora:
+         compute_dtype = torch.bfloat16 if args.bf16 else torch.float16
+         model_kwargs["quantization_config"] = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_quant_type="nf4",
+             bnb_4bit_use_double_quant=True,
+             bnb_4bit_compute_dtype=compute_dtype,
+         )
+         model_kwargs["device_map"] = "auto"
+
+     model = AutoModelForCausalLM.from_pretrained(
+         args.model,
+         torch_dtype=torch.bfloat16 if args.bf16 else (torch.float16 if args.fp16 else None),
+         trust_remote_code=True,
+         **model_kwargs,
+     )
+     model.config.use_cache = False
+
+     if args.qlora:
+         model = prepare_model_for_kbit_training(model)
+
+     if args.lora or args.qlora:
+         lora_config = LoraConfig(
+             r=args.lora_r,
+             lora_alpha=args.lora_alpha,
+             lora_dropout=args.lora_dropout,
+             bias="none",
+             task_type="CAUSAL_LM",
+             target_modules="all-linear",
+         )
+         model = get_peft_model(model, lora_config)
+         model.print_trainable_parameters()
+
+     return model
+
+
+ def main() -> None:
+     args = parse_args()
+     os.makedirs(args.output_dir, exist_ok=True)
+     set_seed(args.seed)
+
+     tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=True, trust_remote_code=True)
+     tokenizer.padding_side = "right"
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+
+     train_ds = load_jsonl(args.train_file)
+     valid_ds = load_jsonl(args.valid_file) if args.valid_file and Path(args.valid_file).exists() else None
+
+     train_ds = train_ds.map(
+         lambda row: tokenize_example(row, tokenizer, args.max_length),
+         remove_columns=train_ds.column_names,
+         desc="Tokenizing train set",
+     )
+     if valid_ds is not None:
+         valid_ds = valid_ds.map(
+             lambda row: tokenize_example(row, tokenizer, args.max_length),
+             remove_columns=valid_ds.column_names,
+             desc="Tokenizing valid set",
+         )
+
+     model = prepare_model(args)
+
+     training_args = TrainingArguments(
+         output_dir=args.output_dir,
+         overwrite_output_dir=True,
+         num_train_epochs=args.num_train_epochs,
+         learning_rate=args.learning_rate,
+         per_device_train_batch_size=args.train_batch_size,
+         per_device_eval_batch_size=args.eval_batch_size,
+         gradient_accumulation_steps=args.gradient_accumulation_steps,
+         warmup_ratio=args.warmup_ratio,
+         weight_decay=args.weight_decay,
+         logging_steps=args.logging_steps,
+         save_steps=args.save_steps,
+         eval_steps=args.eval_steps,
+         evaluation_strategy="steps" if valid_ds is not None else "no",
+         save_strategy="steps",
+         bf16=args.bf16,
+         fp16=args.fp16,
+         report_to="none",
+         gradient_checkpointing=True,
+         lr_scheduler_type="cosine",
+         optim="paged_adamw_32bit" if (args.lora or args.qlora) else "adamw_torch",
+         max_grad_norm=1.0,
+         push_to_hub=args.push_to_hub,
+         hub_model_id=args.hub_model_id,
+         hub_private_repo=args.hub_private_repo,
+         hub_strategy="end" if args.push_to_hub else "every_save",
+     )
+
+     trainer = Trainer(
+         model=model,
+         args=training_args,
+         train_dataset=train_ds,
+         eval_dataset=valid_ds,
+         data_collator=CausalDataCollator(pad_token_id=tokenizer.pad_token_id),
+         tokenizer=tokenizer,
+     )
+
+     train_result = trainer.train()
+     trainer.save_model(args.output_dir)
+     tokenizer.save_pretrained(args.output_dir)
+
+     metrics = train_result.metrics
+     with open(Path(args.output_dir) / "training_metrics.json", "w", encoding="utf-8") as f:
+         json.dump(metrics, f, indent=2)
+
+     run_meta = vars(args).copy()
+     run_meta["train_examples"] = len(train_ds)
+     run_meta["valid_examples"] = len(valid_ds) if valid_ds is not None else 0
+     with open(Path(args.output_dir) / "run_config.json", "w", encoding="utf-8") as f:
+         json.dump(run_meta, f, indent=2)
+
+     if args.push_to_hub:
+         trainer.push_to_hub(commit_message="Add GravityLLM fine-tuned adapter")
+     print(f"Training complete. Artifacts saved to: {args.output_dir}")
+
+
+ if __name__ == "__main__":
+     main()
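The core of `tokenize_example` is the prompt-masking scheme: prompt tokens get label `-100` so only completion tokens contribute to the causal-LM loss. It can be sketched with plain token-id lists, no tokenizer needed (the IDs below are illustrative):

```python
# Sketch of causal-LM loss masking: prompt tokens are labeled -100 (the
# ignore index used by PyTorch's cross-entropy) so only the completion
# contributes to the loss. Token IDs here are made up for illustration.
IGNORE_INDEX = -100

def mask_prompt(prompt_ids, completion_ids, max_length):
    input_ids = prompt_ids + completion_ids
    labels = [IGNORE_INDEX] * len(prompt_ids) + completion_ids
    # Truncate both views together so positions stay aligned.
    return input_ids[:max_length], labels[:max_length]

inp, lab = mask_prompt([11, 12, 13], [21, 22], max_length=10)
print(inp)  # [11, 12, 13, 21, 22]
print(lab)  # [-100, -100, -100, 21, 22]
```

One consequence worth noting: if a prompt alone exceeds `max_length`, every surviving label is `-100` and the example contributes nothing to training.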
upload_to_hub.py ADDED
@@ -0,0 +1,31 @@
+ import argparse
+ from pathlib import Path
+
+ from huggingface_hub import HfApi
+
+
+ def parse_args() -> argparse.Namespace:
+     parser = argparse.ArgumentParser(description="Upload a local GravityLLM folder to the Hugging Face Hub.")
+     parser.add_argument("--folder_path", type=Path, required=True, help="Local folder to upload.")
+     parser.add_argument("--repo_id", type=str, required=True, help="Namespace/repo-name on Hugging Face.")
+     parser.add_argument("--repo_type", type=str, default="model", choices=["model", "dataset", "space"])
+     parser.add_argument("--private", action="store_true")
+     parser.add_argument("--commit_message", type=str, default="Upload GravityLLM artifacts")
+     return parser.parse_args()
+
+
+ def main() -> None:
+     args = parse_args()
+     api = HfApi()
+     api.create_repo(repo_id=args.repo_id, repo_type=args.repo_type, private=args.private, exist_ok=True)
+     api.upload_folder(
+         folder_path=str(args.folder_path),
+         repo_id=args.repo_id,
+         repo_type=args.repo_type,
+         commit_message=args.commit_message,
+     )
+     print(f"Uploaded {args.folder_path} to https://huggingface.co/{args.repo_id}")
+
+
+ if __name__ == "__main__":
+     main()
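The upload script's CLI can be exercised without touching the Hub by handing `parse_args` an explicit argument list. A small sketch mirroring the parser above (the folder and repo names are hypothetical):

```python
import argparse
from pathlib import Path

# Mirror of upload_to_hub.py's argument parser, driven by an explicit argv
# list instead of sys.argv. "outputs/GravityLLM" and "acme/GravityLLM-S9"
# are placeholder names, not real paths or repos.
parser = argparse.ArgumentParser()
parser.add_argument("--folder_path", type=Path, required=True)
parser.add_argument("--repo_id", type=str, required=True)
parser.add_argument("--repo_type", type=str, default="model", choices=["model", "dataset", "space"])
parser.add_argument("--private", action="store_true")

args = parser.parse_args(
    ["--folder_path", "outputs/GravityLLM", "--repo_id", "acme/GravityLLM-S9", "--private"]
)
print(args.repo_type, args.private)  # model True
```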