Upload 20 files
Browse files- LICENSE +47 -0
- Makefile +17 -0
- README.md +267 -3
- assets/gravityllm_banner.svg +44 -0
- configs/recommended_train_args.json +23 -0
- data/train.jsonl +4 -0
- data/valid.jsonl +2 -0
- evaluate.py +178 -0
- examples/sample_input.json +110 -0
- examples/sample_output.json +166 -0
- infer.py +117 -0
- pyproject.toml +21 -0
- requirements.txt +9 -0
- schemas/scene.schema.json +169 -0
- scripts/push_to_hub.sh +9 -0
- scripts/train_qlora.sh +19 -0
- tools/make_synthetic_dataset.py +118 -0
- tools/validate_scene.py +35 -0
- train.py +245 -0
- upload_to_hub.py +31 -0
LICENSE
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
|
| 10 |
+
|
| 11 |
+
"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
|
| 12 |
+
|
| 13 |
+
"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
|
| 14 |
+
|
| 15 |
+
"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
|
| 16 |
+
|
| 17 |
+
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
|
| 18 |
+
|
| 19 |
+
"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
|
| 20 |
+
|
| 21 |
+
"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work.
|
| 22 |
+
|
| 23 |
+
"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
|
| 24 |
+
|
| 25 |
+
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems.
|
| 26 |
+
|
| 27 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
|
| 28 |
+
|
| 29 |
+
2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
|
| 30 |
+
|
| 31 |
+
3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work.
|
| 32 |
+
|
| 33 |
+
4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
|
| 34 |
+
(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and
|
| 35 |
+
(b) You must cause any modified files to carry prominent notices stating that You changed the files; and
|
| 36 |
+
(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work; and
|
| 37 |
+
(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file.
|
| 38 |
+
|
| 39 |
+
5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work shall be under the terms and conditions of this License.
|
| 40 |
+
|
| 41 |
+
6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor.
|
| 42 |
+
|
| 43 |
+
7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 44 |
+
|
| 45 |
+
8. Limitation of Liability. In no event and under no legal theory shall any Contributor be liable to You for damages arising from the use of the Work.
|
| 46 |
+
|
| 47 |
+
9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer support or warranty protection only on Your own behalf.
|
Makefile
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
install:
|
| 2 |
+
python -m pip install -r requirements.txt
|
| 3 |
+
|
| 4 |
+
train:
|
| 5 |
+
python train.py --qlora --bf16 --num_train_epochs 3 --train_batch_size 1 --gradient_accumulation_steps 16
|
| 6 |
+
|
| 7 |
+
infer:
|
| 8 |
+
python infer.py --model_dir outputs/GravityLLM-Qwen2.5-1.5B-S9 --input_json examples/sample_input.json --validate --output_json outputs/sample_prediction.json
|
| 9 |
+
|
| 10 |
+
eval:
|
| 11 |
+
python evaluate.py --model_dir outputs/GravityLLM-Qwen2.5-1.5B-S9 --data_file data/valid.jsonl --report_path reports/eval_report.json
|
| 12 |
+
|
| 13 |
+
validate-example:
|
| 14 |
+
python tools/validate_scene.py schemas/scene.schema.json examples/sample_output.json
|
| 15 |
+
|
| 16 |
+
synthetic-data:
|
| 17 |
+
python tools/make_synthetic_dataset.py --output data/synthetic_train.jsonl --count 50
|
README.md
CHANGED
|
@@ -1,3 +1,267 @@
|
|
| 1 |
-
---
|
| 2 |
-
|
| 3 |
-
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
language:
|
| 3 |
+
- en
|
| 4 |
+
license: apache-2.0
|
| 5 |
+
library_name: transformers
|
| 6 |
+
pipeline_tag: text-generation
|
| 7 |
+
base_model: Qwen/Qwen2.5-1.5B-Instruct
|
| 8 |
+
tags:
|
| 9 |
+
- gravityllm
|
| 10 |
+
- spatial-audio
|
| 11 |
+
- immersive-audio
|
| 12 |
+
- spatial9
|
| 13 |
+
- iamf
|
| 14 |
+
- instruction-tuning
|
| 15 |
+
- json
|
| 16 |
+
- lora
|
| 17 |
+
- qlora
|
| 18 |
+
- peft
|
| 19 |
+
- transformers
|
| 20 |
+
widget:
|
| 21 |
+
- text: |-
|
| 22 |
+
INPUT:
|
| 23 |
+
{
|
| 24 |
+
"target_format": "iamf",
|
| 25 |
+
"max_objects": 10,
|
| 26 |
+
"style": "club",
|
| 27 |
+
"section": "drop",
|
| 28 |
+
"global": {"bpm": 128, "energy": 0.92},
|
| 29 |
+
"stems": [
|
| 30 |
+
{"id": "v1", "class": "lead_vocal", "lufs": -16.8, "transient": 0.25, "band_energy": {"low": 0.1, "mid": 0.6, "high": 0.3}, "leadness": 0.95},
|
| 31 |
+
{"id": "k1", "class": "kick", "lufs": -10.5, "transient": 0.95, "band_energy": {"low": 0.8, "mid": 0.15, "high": 0.05}, "leadness": 0.25}
|
| 32 |
+
],
|
| 33 |
+
"rules": [
|
| 34 |
+
{"type": "anchor", "track_class": "lead_vocal", "az_deg": 0, "el_deg": 10, "dist_m": 1.6},
|
| 35 |
+
{"type": "mono_low_end", "hz_below": 120}
|
| 36 |
+
]
|
| 37 |
+
}
|
| 38 |
+
---
|
| 39 |
+
|
| 40 |
+

|
| 41 |
+
|
| 42 |
+
# GravityLLM
|
| 43 |
+
|
| 44 |
+
GravityLLM is a compact instruction-tuned model for **constraint-conditioned spatial scene generation**.
|
| 45 |
+
It turns **music constraints + stem descriptors** into strict **Spatial9Scene JSON** for immersive audio pipelines such as IAMF, binaural, and bed-plus-object rendering workflows.
|
| 46 |
+
|
| 47 |
+
> **Status**
|
| 48 |
+
> This repository is **training-ready and Hub-ready**.
|
| 49 |
+
> The zip includes code, schema, sample data, evaluation, and upload helpers.
|
| 50 |
+
> It does **not** include fine-tuned weights yet. After training, upload the contents of your `outputs/...` folder as the actual model repo.
|
| 51 |
+
|
| 52 |
+
## What makes this repo solid
|
| 53 |
+
|
| 54 |
+
- Proper instruction fine-tuning with **prompt masking**, so the loss is applied to the target JSON instead of the instruction prefix.
|
| 55 |
+
- **LoRA** and **QLoRA** training paths for efficient fine-tuning on small-to-medium GPUs.
|
| 56 |
+
- Strict **JSON Schema** validation for production-safe outputs.
|
| 57 |
+
- Built-in **evaluation** for parse rate, schema-valid rate, object-budget pass rate, and anchor-rule pass rate.
|
| 58 |
+
- Clean **Hugging Face upload** helper with `upload_folder`.
|
| 59 |
+
- Ready-made **sample data**, **sample scene**, and **recommended training config**.
|
| 60 |
+
|
| 61 |
+
## Model contract
|
| 62 |
+
|
| 63 |
+
### Input
|
| 64 |
+
A structured payload describing:
|
| 65 |
+
|
| 66 |
+
- target format
|
| 67 |
+
- object budget
|
| 68 |
+
- style and section
|
| 69 |
+
- per-stem descriptors
|
| 70 |
+
- hard rules such as anchors, low-end centering, width targets, and masking constraints
|
| 71 |
+
|
| 72 |
+
### Output
|
| 73 |
+
A single valid JSON object matching `schemas/scene.schema.json`.
|
| 74 |
+
|
| 75 |
+
### Example input
|
| 76 |
+
```json
|
| 77 |
+
{
|
| 78 |
+
"target_format": "iamf",
|
| 79 |
+
"max_objects": 10,
|
| 80 |
+
"style": "club",
|
| 81 |
+
"section": "drop",
|
| 82 |
+
"global": {"bpm": 128, "energy": 0.92},
|
| 83 |
+
"stems": [
|
| 84 |
+
{"id": "v1", "class": "lead_vocal", "lufs": -16.8, "transient": 0.25, "band_energy": {"low": 0.1, "mid": 0.6, "high": 0.3}, "leadness": 0.95},
|
| 85 |
+
{"id": "k1", "class": "kick", "lufs": -10.5, "transient": 0.95, "band_energy": {"low": 0.8, "mid": 0.15, "high": 0.05}, "leadness": 0.25}
|
| 86 |
+
],
|
| 87 |
+
"rules": [
|
| 88 |
+
{"type": "anchor", "track_class": "lead_vocal", "az_deg": 0, "el_deg": 10, "dist_m": 1.6},
|
| 89 |
+
{"type": "mono_low_end", "hz_below": 120}
|
| 90 |
+
]
|
| 91 |
+
}
|
| 92 |
+
```
|
| 93 |
+
|
| 94 |
+
### Example output
|
| 95 |
+
```json
|
| 96 |
+
{
|
| 97 |
+
"version": "1.0",
|
| 98 |
+
"bed": {"layout": "iamf", "loudness_target_lufs": -14.0, "room_preset": "club_medium"},
|
| 99 |
+
"objects": [
|
| 100 |
+
{
|
| 101 |
+
"id": "v1",
|
| 102 |
+
"class": "lead_vocal",
|
| 103 |
+
"az_deg": 0,
|
| 104 |
+
"el_deg": 10,
|
| 105 |
+
"dist_m": 1.6,
|
| 106 |
+
"width": 0.15,
|
| 107 |
+
"gain_db": 0.0,
|
| 108 |
+
"reverb_send": 0.18,
|
| 109 |
+
"early_reflections": 0.22,
|
| 110 |
+
"motion": [
|
| 111 |
+
{"t": 0.0, "az_deg": 0, "el_deg": 10, "dist_m": 1.6},
|
| 112 |
+
{"t": 1.0, "az_deg": 0, "el_deg": 10, "dist_m": 1.6}
|
| 113 |
+
]
|
| 114 |
+
}
|
| 115 |
+
],
|
| 116 |
+
"constraints_applied": [
|
| 117 |
+
"anchor:lead_vocal@0/10/1.6",
|
| 118 |
+
"mono_low_end<120Hz"
|
| 119 |
+
]
|
| 120 |
+
}
|
| 121 |
+
```
|
| 122 |
+
|
| 123 |
+
## Repository layout
|
| 124 |
+
|
| 125 |
+
```text
|
| 126 |
+
GravityLLM-HuggingFace-Repo/
|
| 127 |
+
├── README.md
|
| 128 |
+
├── LICENSE
|
| 129 |
+
├── Makefile
|
| 130 |
+
├── pyproject.toml
|
| 131 |
+
├── requirements.txt
|
| 132 |
+
├── train.py
|
| 133 |
+
├── infer.py
|
| 134 |
+
├── evaluate.py
|
| 135 |
+
├── upload_to_hub.py
|
| 136 |
+
├── assets/
|
| 137 |
+
│ └── gravityllm_banner.svg
|
| 138 |
+
├── configs/
|
| 139 |
+
│ └── recommended_train_args.json
|
| 140 |
+
├── data/
|
| 141 |
+
│ ├── train.jsonl
|
| 142 |
+
│ └── valid.jsonl
|
| 143 |
+
├── examples/
|
| 144 |
+
│ ├── sample_input.json
|
| 145 |
+
│ └── sample_output.json
|
| 146 |
+
├── schemas/
|
| 147 |
+
│ └── scene.schema.json
|
| 148 |
+
├── scripts/
|
| 149 |
+
│ ├── push_to_hub.sh
|
| 150 |
+
│ └── train_qlora.sh
|
| 151 |
+
└── tools/
|
| 152 |
+
├── make_synthetic_dataset.py
|
| 153 |
+
└── validate_scene.py
|
| 154 |
+
```
|
| 155 |
+
|
| 156 |
+
## Quick start
|
| 157 |
+
|
| 158 |
+
### 1) Install
|
| 159 |
+
```bash
|
| 160 |
+
python -m pip install -r requirements.txt
|
| 161 |
+
```
|
| 162 |
+
|
| 163 |
+
### 2) Train with QLoRA
|
| 164 |
+
```bash
|
| 165 |
+
bash scripts/train_qlora.sh
|
| 166 |
+
```
|
| 167 |
+
|
| 168 |
+
Or run directly:
|
| 169 |
+
|
| 170 |
+
```bash
|
| 171 |
+
python train.py --model Qwen/Qwen2.5-1.5B-Instruct --train_file data/train.jsonl --valid_file data/valid.jsonl --output_dir outputs/GravityLLM-Qwen2.5-1.5B-S9 --max_length 2048 --num_train_epochs 3 --learning_rate 2e-4 --train_batch_size 1 --eval_batch_size 1 --gradient_accumulation_steps 16 --warmup_ratio 0.03 --save_steps 100 --eval_steps 100 --qlora --bf16
|
| 172 |
+
```
|
| 173 |
+
|
| 174 |
+
### 3) Generate a scene
|
| 175 |
+
```bash
|
| 176 |
+
python infer.py --model_dir outputs/GravityLLM-Qwen2.5-1.5B-S9 --input_json examples/sample_input.json --validate --output_json outputs/sample_prediction.json
|
| 177 |
+
```
|
| 178 |
+
|
| 179 |
+
### 4) Evaluate
|
| 180 |
+
```bash
|
| 181 |
+
python evaluate.py --model_dir outputs/GravityLLM-Qwen2.5-1.5B-S9 --data_file data/valid.jsonl --report_path reports/eval_report.json
|
| 182 |
+
```
|
| 183 |
+
|
| 184 |
+
### 5) Validate any output
|
| 185 |
+
```bash
|
| 186 |
+
python tools/validate_scene.py schemas/scene.schema.json outputs/sample_prediction.json
|
| 187 |
+
```
|
| 188 |
+
|
| 189 |
+
## Push to the Hugging Face Hub
|
| 190 |
+
|
| 191 |
+
### From a trained output folder
|
| 192 |
+
```bash
|
| 193 |
+
python upload_to_hub.py --folder_path outputs/GravityLLM-Qwen2.5-1.5B-S9 --repo_id YOUR_NAMESPACE/GravityLLM-Qwen2.5-1.5B-S9
|
| 194 |
+
```
|
| 195 |
+
|
| 196 |
+
### Or with the helper script
|
| 197 |
+
```bash
|
| 198 |
+
bash scripts/push_to_hub.sh outputs/GravityLLM-Qwen2.5-1.5B-S9 YOUR_NAMESPACE/GravityLLM-Qwen2.5-1.5B-S9
|
| 199 |
+
```
|
| 200 |
+
|
| 201 |
+
## Dataset format
|
| 202 |
+
|
| 203 |
+
Training files are JSONL with two fields per row:
|
| 204 |
+
|
| 205 |
+
```json
|
| 206 |
+
{
|
| 207 |
+
"prompt": "GravityLLM: Output ONLY valid JSON matching the Spatial9Scene schema.\n\nINPUT:\n{...}",
|
| 208 |
+
"completion": "{... valid Spatial9Scene JSON ...}"
|
| 209 |
+
}
|
| 210 |
+
```
|
| 211 |
+
|
| 212 |
+
The provided sample dataset is intentionally small. Replace it with your real production examples as soon as possible.
|
| 213 |
+
|
| 214 |
+
## Recommended data strategy
|
| 215 |
+
|
| 216 |
+
For a strong first release:
|
| 217 |
+
|
| 218 |
+
1. Collect a few hundred high-quality gold examples from expert-authored scenes.
|
| 219 |
+
2. Keep the schema stable, and quantize continuous values (angles, distances, gains, widths) to a fixed precision so outputs are consistent across examples.
|
| 220 |
+
3. Encode hard rules explicitly instead of relying on vague prose.
|
| 221 |
+
4. Run evaluation after every fine-tune.
|
| 222 |
+
5. Add a post-processor to enforce hard constraints if the runtime must be deterministic.
|
| 223 |
+
|
| 224 |
+
## Suggested training roadmap
|
| 225 |
+
|
| 226 |
+
### v0
|
| 227 |
+
- Small curated dataset
|
| 228 |
+
- QLoRA adapter
|
| 229 |
+
- Schema-valid JSON only
|
| 230 |
+
- Anchor and budget constraints
|
| 231 |
+
|
| 232 |
+
### v1
|
| 233 |
+
- More genres and sections
|
| 234 |
+
- Better masking and width rules
|
| 235 |
+
- Object motion patterns
|
| 236 |
+
- Automatic validation and repair loop
|
| 237 |
+
|
| 238 |
+
### v2
|
| 239 |
+
- Preference tuning on human A/B judgments
|
| 240 |
+
- A dedicated reward signal for clarity, masking avoidance, and translation safety
|
| 241 |
+
|
| 242 |
+
## Intended use
|
| 243 |
+
|
| 244 |
+
GravityLLM is designed for:
|
| 245 |
+
|
| 246 |
+
- music-tech pipelines
|
| 247 |
+
- Spatial9 scene authoring
|
| 248 |
+
- assisted immersive-audio layout generation
|
| 249 |
+
- IAMF-ready authoring workflows
|
| 250 |
+
- renderer-side JSON generation
|
| 251 |
+
|
| 252 |
+
## Limitations
|
| 253 |
+
|
| 254 |
+
- This repo does not include trained weights out of the box.
|
| 255 |
+
- The model only knows what you teach it through your dataset.
|
| 256 |
+
- Raw audio is not consumed directly here; the training pipeline expects structured stem features.
|
| 257 |
+
- Production systems should still validate outputs and optionally apply a rule-based correction pass.
|
| 258 |
+
|
| 259 |
+
## Safety and reliability
|
| 260 |
+
|
| 261 |
+
- Always validate generated scenes against the JSON schema.
|
| 262 |
+
- Keep low-end centering as a hard rule outside the model if that is non-negotiable.
|
| 263 |
+
- Treat the model as a scene proposal engine, not an oracle.
|
| 264 |
+
|
| 265 |
+
## License
|
| 266 |
+
|
| 267 |
+
This repository is released under Apache-2.0.
|
assets/gravityllm_banner.svg
ADDED
|
|
configs/recommended_train_args.json
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"base_model": "Qwen/Qwen2.5-1.5B-Instruct",
|
| 3 |
+
"train_file": "data/train.jsonl",
|
| 4 |
+
"valid_file": "data/valid.jsonl",
|
| 5 |
+
"output_dir": "outputs/GravityLLM-Qwen2.5-1.5B-S9",
|
| 6 |
+
"max_length": 2048,
|
| 7 |
+
"num_train_epochs": 3,
|
| 8 |
+
"learning_rate": 0.0002,
|
| 9 |
+
"train_batch_size": 1,
|
| 10 |
+
"eval_batch_size": 1,
|
| 11 |
+
"gradient_accumulation_steps": 16,
|
| 12 |
+
"warmup_ratio": 0.03,
|
| 13 |
+
"weight_decay": 0.0,
|
| 14 |
+
"logging_steps": 10,
|
| 15 |
+
"save_steps": 100,
|
| 16 |
+
"eval_steps": 100,
|
| 17 |
+
"seed": 42,
|
| 18 |
+
"qlora": true,
|
| 19 |
+
"bf16": true,
|
| 20 |
+
"lora_r": 16,
|
| 21 |
+
"lora_alpha": 32,
|
| 22 |
+
"lora_dropout": 0.05
|
| 23 |
+
}
|
data/train.jsonl
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"prompt": "GravityLLM: Output ONLY valid JSON matching the Spatial9Scene schema.\n\nINPUT:\n{\n \"target_format\": \"iamf\",\n \"max_objects\": 10,\n \"style\": \"club\",\n \"section\": \"drop\",\n \"global\": {\n \"bpm\": 128,\n \"energy\": 0.92\n },\n \"stems\": [\n {\n \"id\": \"v1\",\n \"class\": \"lead_vocal\",\n \"lufs\": -16.8,\n \"transient\": 0.25,\n \"band_energy\": {\n \"low\": 0.1,\n \"mid\": 0.6,\n \"high\": 0.3\n },\n \"leadness\": 0.95\n },\n {\n \"id\": \"k1\",\n \"class\": \"kick\",\n \"lufs\": -10.5,\n \"transient\": 0.95,\n \"band_energy\": {\n \"low\": 0.8,\n \"mid\": 0.15,\n \"high\": 0.05\n },\n \"leadness\": 0.25\n },\n {\n \"id\": \"b1\",\n \"class\": \"bass\",\n \"lufs\": -12.2,\n \"transient\": 0.55,\n \"band_energy\": {\n \"low\": 0.85,\n \"mid\": 0.12,\n \"high\": 0.03\n },\n \"leadness\": 0.35\n },\n {\n \"id\": \"p1\",\n \"class\": \"pad\",\n \"lufs\": -20.3,\n \"transient\": 0.1,\n \"band_energy\": {\n \"low\": 0.2,\n \"mid\": 0.5,\n \"high\": 0.3\n },\n \"leadness\": 0.1\n },\n {\n \"id\": \"s1\",\n \"class\": \"synth_lead\",\n \"lufs\": -18.0,\n \"transient\": 0.4,\n \"band_energy\": {\n \"low\": 0.1,\n \"mid\": 0.55,\n \"high\": 0.35\n },\n \"leadness\": 0.75\n },\n {\n \"id\": \"fx1\",\n \"class\": \"fx\",\n \"lufs\": -23.0,\n \"transient\": 0.2,\n \"band_energy\": {\n \"low\": 0.1,\n \"mid\": 0.3,\n \"high\": 0.6\n },\n \"leadness\": 0.05\n }\n ],\n \"rules\": [\n {\n \"type\": \"anchor\",\n \"track_class\": \"lead_vocal\",\n \"az_deg\": 0,\n \"el_deg\": 10,\n \"dist_m\": 1.6\n },\n {\n \"type\": \"mono_low_end\",\n \"hz_below\": 120\n },\n {\n \"type\": \"width_pref\",\n \"track_class\": \"pad\",\n \"min_width\": 0.75\n },\n {\n \"type\": \"avoid_band_masking\",\n \"mask_target\": \"lead_vocal\",\n \"band_hz\": [\n 1500,\n 4500\n ]\n }\n ]\n}", "completion": "{\n \"version\": \"1.0\",\n \"bed\": {\n \"layout\": \"iamf\",\n \"loudness_target_lufs\": -14.0,\n \"room_preset\": \"club_medium\"\n },\n \"objects\": [\n {\n \"id\": 
\"v1\",\n \"class\": \"lead_vocal\",\n \"az_deg\": 0,\n \"el_deg\": 10,\n \"dist_m\": 1.6,\n \"width\": 0.15,\n \"gain_db\": 0.0,\n \"reverb_send\": 0.18,\n \"early_reflections\": 0.22,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": 0,\n \"el_deg\": 10,\n \"dist_m\": 1.6\n },\n {\n \"t\": 1.0,\n \"az_deg\": 0,\n \"el_deg\": 10,\n \"dist_m\": 1.6\n }\n ]\n },\n {\n \"id\": \"k1\",\n \"class\": \"kick\",\n \"az_deg\": 0,\n \"el_deg\": 0,\n \"dist_m\": 2.2,\n \"width\": 0.0,\n \"gain_db\": 0.0,\n \"reverb_send\": 0.02,\n \"early_reflections\": 0.05,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": 0,\n \"el_deg\": 0,\n \"dist_m\": 2.2\n },\n {\n \"t\": 1.0,\n \"az_deg\": 0,\n \"el_deg\": 0,\n \"dist_m\": 2.2\n }\n ]\n },\n {\n \"id\": \"b1\",\n \"class\": \"bass\",\n \"az_deg\": 0,\n \"el_deg\": -5,\n \"dist_m\": 2.6,\n \"width\": 0.05,\n \"gain_db\": -0.5,\n \"reverb_send\": 0.03,\n \"early_reflections\": 0.06,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": 0,\n \"el_deg\": -5,\n \"dist_m\": 2.6\n },\n {\n \"t\": 1.0,\n \"az_deg\": 0,\n \"el_deg\": -5,\n \"dist_m\": 2.6\n }\n ]\n },\n {\n \"id\": \"s1\",\n \"class\": \"synth_lead\",\n \"az_deg\": 25,\n \"el_deg\": 15,\n \"dist_m\": 2.0,\n \"width\": 0.35,\n \"gain_db\": -1.0,\n \"reverb_send\": 0.12,\n \"early_reflections\": 0.1,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": 20,\n \"el_deg\": 15,\n \"dist_m\": 2.0\n },\n {\n \"t\": 1.0,\n \"az_deg\": 30,\n \"el_deg\": 15,\n \"dist_m\": 2.0\n }\n ]\n },\n {\n \"id\": \"p1\",\n \"class\": \"pad\",\n \"az_deg\": -70,\n \"el_deg\": 20,\n \"dist_m\": 4.0,\n \"width\": 0.85,\n \"gain_db\": -3.0,\n \"reverb_send\": 0.28,\n \"early_reflections\": 0.12,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": -70,\n \"el_deg\": 20,\n \"dist_m\": 4.0\n },\n {\n \"t\": 1.0,\n \"az_deg\": -70,\n \"el_deg\": 20,\n \"dist_m\": 4.0\n }\n ]\n },\n {\n \"id\": \"fx1\",\n \"class\": \"fx\",\n \"az_deg\": 110,\n \"el_deg\": 30,\n \"dist_m\": 6.0,\n \"width\": 0.65,\n \"gain_db\": -6.0,\n 
\"reverb_send\": 0.4,\n \"early_reflections\": 0.08,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": 100,\n \"el_deg\": 30,\n \"dist_m\": 6.0\n },\n {\n \"t\": 1.0,\n \"az_deg\": 130,\n \"el_deg\": 30,\n \"dist_m\": 6.0\n }\n ]\n }\n ],\n \"constraints_applied\": [\n \"anchor:lead_vocal@0/10/1.6\",\n \"mono_low_end<120Hz\",\n \"pad_width>=0.75\",\n \"avoid_masking_lead_1500-4500Hz\"\n ]\n}"}
|
| 2 |
+
{"prompt": "GravityLLM: Output ONLY valid JSON matching the Spatial9Scene schema.\n\nINPUT:\n{\n \"target_format\": \"iamf\",\n \"max_objects\": 8,\n \"style\": \"cinematic\",\n \"section\": \"break\",\n \"global\": {\n \"bpm\": 90,\n \"energy\": 0.45\n },\n \"stems\": [\n {\n \"id\": \"d1\",\n \"class\": \"dialogue\",\n \"lufs\": -18.0,\n \"transient\": 0.35,\n \"band_energy\": {\n \"low\": 0.05,\n \"mid\": 0.75,\n \"high\": 0.2\n },\n \"leadness\": 0.98\n },\n {\n \"id\": \"st1\",\n \"class\": \"strings\",\n \"lufs\": -24.0,\n \"transient\": 0.08,\n \"band_energy\": {\n \"low\": 0.15,\n \"mid\": 0.6,\n \"high\": 0.25\n },\n \"leadness\": 0.25\n },\n {\n \"id\": \"p1\",\n \"class\": \"pad\",\n \"lufs\": -25.5,\n \"transient\": 0.05,\n \"band_energy\": {\n \"low\": 0.25,\n \"mid\": 0.55,\n \"high\": 0.2\n },\n \"leadness\": 0.1\n },\n {\n \"id\": \"fx1\",\n \"class\": \"fx\",\n \"lufs\": -26.0,\n \"transient\": 0.15,\n \"band_energy\": {\n \"low\": 0.1,\n \"mid\": 0.25,\n \"high\": 0.65\n },\n \"leadness\": 0.05\n }\n ],\n \"rules\": [\n {\n \"type\": \"anchor\",\n \"track_class\": \"dialogue\",\n \"az_deg\": 0,\n \"el_deg\": 5,\n \"dist_m\": 1.8\n },\n {\n \"type\": \"keep_dialogue_clear\",\n \"band_hz\": [\n 1000,\n 4000\n ]\n },\n {\n \"type\": \"width_pref\",\n \"track_class\": \"pad\",\n \"min_width\": 0.8\n }\n ]\n}", "completion": "{\n \"version\": \"1.0\",\n \"bed\": {\n \"layout\": \"iamf\",\n \"loudness_target_lufs\": -16.0,\n \"room_preset\": \"cinema_large\"\n },\n \"objects\": [\n {\n \"id\": \"d1\",\n \"class\": \"dialogue\",\n \"az_deg\": 0,\n \"el_deg\": 5,\n \"dist_m\": 1.8,\n \"width\": 0.1,\n \"gain_db\": 0.0,\n \"reverb_send\": 0.08,\n \"early_reflections\": 0.18,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": 0,\n \"el_deg\": 5,\n \"dist_m\": 1.8\n },\n {\n \"t\": 1.0,\n \"az_deg\": 0,\n \"el_deg\": 5,\n \"dist_m\": 1.8\n }\n ]\n },\n {\n \"id\": \"st1\",\n \"class\": \"strings\",\n \"az_deg\": -35,\n \"el_deg\": 18,\n \"dist_m\": 4.2,\n 
\"width\": 0.4,\n \"gain_db\": -4.0,\n \"reverb_send\": 0.24,\n \"early_reflections\": 0.1,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": -35,\n \"el_deg\": 18,\n \"dist_m\": 4.2\n },\n {\n \"t\": 1.0,\n \"az_deg\": -20,\n \"el_deg\": 18,\n \"dist_m\": 4.2\n }\n ]\n },\n {\n \"id\": \"p1\",\n \"class\": \"pad\",\n \"az_deg\": -80,\n \"el_deg\": 18,\n \"dist_m\": 5.5,\n \"width\": 0.85,\n \"gain_db\": -6.0,\n \"reverb_send\": 0.35,\n \"early_reflections\": 0.12,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": -80,\n \"el_deg\": 18,\n \"dist_m\": 5.5\n },\n {\n \"t\": 1.0,\n \"az_deg\": -80,\n \"el_deg\": 18,\n \"dist_m\": 5.5\n }\n ]\n },\n {\n \"id\": \"fx1\",\n \"class\": \"fx\",\n \"az_deg\": 95,\n \"el_deg\": 28,\n \"dist_m\": 7.0,\n \"width\": 0.7,\n \"gain_db\": -10.0,\n \"reverb_send\": 0.55,\n \"early_reflections\": 0.08,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": 80,\n \"el_deg\": 28,\n \"dist_m\": 7.0\n },\n {\n \"t\": 1.0,\n \"az_deg\": 120,\n \"el_deg\": 28,\n \"dist_m\": 7.0\n }\n ]\n }\n ],\n \"constraints_applied\": [\n \"anchor:dialogue@0/5/1.8\",\n \"keep_dialogue_clear_1000-4000Hz\",\n \"pad_width>=0.80\"\n ]\n}"}
|
| 3 |
+
{"prompt": "GravityLLM: Output ONLY valid JSON matching the Spatial9Scene schema.\n\nINPUT:\n{\n \"target_format\": \"iamf\",\n \"max_objects\": 12,\n \"style\": \"live_stage\",\n \"section\": \"chorus\",\n \"global\": {\n \"bpm\": 112,\n \"energy\": 0.78\n },\n \"stems\": [\n {\n \"id\": \"lv1\",\n \"class\": \"lead_vocal\",\n \"lufs\": -15.5,\n \"transient\": 0.3,\n \"band_energy\": {\n \"low\": 0.08,\n \"mid\": 0.72,\n \"high\": 0.2\n },\n \"leadness\": 0.96\n },\n {\n \"id\": \"g1\",\n \"class\": \"guitar\",\n \"lufs\": -17.4,\n \"transient\": 0.46,\n \"band_energy\": {\n \"low\": 0.15,\n \"mid\": 0.6,\n \"high\": 0.25\n },\n \"leadness\": 0.5\n },\n {\n \"id\": \"g2\",\n \"class\": \"guitar\",\n \"lufs\": -18.1,\n \"transient\": 0.44,\n \"band_energy\": {\n \"low\": 0.12,\n \"mid\": 0.62,\n \"high\": 0.26\n },\n \"leadness\": 0.42\n },\n {\n \"id\": \"dr1\",\n \"class\": \"drums_bus\",\n \"lufs\": -11.8,\n \"transient\": 0.88,\n \"band_energy\": {\n \"low\": 0.4,\n \"mid\": 0.4,\n \"high\": 0.2\n },\n \"leadness\": 0.35\n },\n {\n \"id\": \"b1\",\n \"class\": \"bass\",\n \"lufs\": -12.4,\n \"transient\": 0.61,\n \"band_energy\": {\n \"low\": 0.82,\n \"mid\": 0.14,\n \"high\": 0.04\n },\n \"leadness\": 0.3\n },\n {\n \"id\": \"crowd\",\n \"class\": \"ambience\",\n \"lufs\": -27.0,\n \"transient\": 0.05,\n \"band_energy\": {\n \"low\": 0.2,\n \"mid\": 0.5,\n \"high\": 0.3\n },\n \"leadness\": 0.02\n }\n ],\n \"rules\": [\n {\n \"type\": \"anchor\",\n \"track_class\": \"lead_vocal\",\n \"az_deg\": 0,\n \"el_deg\": 8,\n \"dist_m\": 1.7\n },\n {\n \"type\": \"mono_low_end\",\n \"hz_below\": 120\n },\n {\n \"type\": \"stage_pair\",\n \"left_class\": \"guitar\",\n \"right_class\": \"guitar\"\n }\n ]\n}", "completion": "{\n \"version\": \"1.0\",\n \"bed\": {\n \"layout\": \"iamf\",\n \"loudness_target_lufs\": -15.0,\n \"room_preset\": \"stage_live\"\n },\n \"objects\": [\n {\n \"id\": \"lv1\",\n \"class\": \"lead_vocal\",\n \"az_deg\": 0,\n \"el_deg\": 8,\n 
\"dist_m\": 1.7,\n \"width\": 0.18,\n \"gain_db\": 0.0,\n \"reverb_send\": 0.16,\n \"early_reflections\": 0.24,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": 0,\n \"el_deg\": 8,\n \"dist_m\": 1.7\n },\n {\n \"t\": 1.0,\n \"az_deg\": 0,\n \"el_deg\": 8,\n \"dist_m\": 1.7\n }\n ]\n },\n {\n \"id\": \"g1\",\n \"class\": \"guitar\",\n \"az_deg\": -35,\n \"el_deg\": 6,\n \"dist_m\": 2.8,\n \"width\": 0.26,\n \"gain_db\": -1.5,\n \"reverb_send\": 0.11,\n \"early_reflections\": 0.15,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": -35,\n \"el_deg\": 6,\n \"dist_m\": 2.8\n },\n {\n \"t\": 1.0,\n \"az_deg\": -30,\n \"el_deg\": 6,\n \"dist_m\": 2.8\n }\n ]\n },\n {\n \"id\": \"g2\",\n \"class\": \"guitar\",\n \"az_deg\": 35,\n \"el_deg\": 6,\n \"dist_m\": 2.8,\n \"width\": 0.26,\n \"gain_db\": -1.5,\n \"reverb_send\": 0.11,\n \"early_reflections\": 0.15,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": 35,\n \"el_deg\": 6,\n \"dist_m\": 2.8\n },\n {\n \"t\": 1.0,\n \"az_deg\": 30,\n \"el_deg\": 6,\n \"dist_m\": 2.8\n }\n ]\n },\n {\n \"id\": \"dr1\",\n \"class\": \"drums_bus\",\n \"az_deg\": 0,\n \"el_deg\": 0,\n \"dist_m\": 2.4,\n \"width\": 0.2,\n \"gain_db\": -0.5,\n \"reverb_send\": 0.08,\n \"early_reflections\": 0.12,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": 0,\n \"el_deg\": 0,\n \"dist_m\": 2.4\n },\n {\n \"t\": 1.0,\n \"az_deg\": 0,\n \"el_deg\": 0,\n \"dist_m\": 2.4\n }\n ]\n },\n {\n \"id\": \"b1\",\n \"class\": \"bass\",\n \"az_deg\": 0,\n \"el_deg\": -5,\n \"dist_m\": 2.8,\n \"width\": 0.05,\n \"gain_db\": -0.8,\n \"reverb_send\": 0.04,\n \"early_reflections\": 0.05,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": 0,\n \"el_deg\": -5,\n \"dist_m\": 2.8\n },\n {\n \"t\": 1.0,\n \"az_deg\": 0,\n \"el_deg\": -5,\n \"dist_m\": 2.8\n }\n ]\n },\n {\n \"id\": \"crowd\",\n \"class\": \"ambience\",\n \"az_deg\": 180,\n \"el_deg\": 12,\n \"dist_m\": 10.0,\n \"width\": 1.0,\n \"gain_db\": -12.0,\n \"reverb_send\": 0.6,\n \"early_reflections\": 0.1,\n \"motion\": [\n 
{\n \"t\": 0.0,\n \"az_deg\": 160,\n \"el_deg\": 12,\n \"dist_m\": 10.0\n },\n {\n \"t\": 1.0,\n \"az_deg\": 200,\n \"el_deg\": 12,\n \"dist_m\": 10.0\n }\n ]\n }\n ],\n \"constraints_applied\": [\n \"anchor:lead_vocal@0/8/1.7\",\n \"mono_low_end<120Hz\",\n \"stage_pair:guitars\"\n ]\n}"}
|
| 4 |
+
{"prompt": "GravityLLM: Output ONLY valid JSON matching the Spatial9Scene schema.\n\nINPUT:\n{\n \"target_format\": \"binaural\",\n \"max_objects\": 6,\n \"style\": \"ambient\",\n \"section\": \"intro\",\n \"global\": {\n \"bpm\": 72,\n \"energy\": 0.28\n },\n \"stems\": [\n {\n \"id\": \"p1\",\n \"class\": \"pad\",\n \"lufs\": -23.5,\n \"transient\": 0.04,\n \"band_energy\": {\n \"low\": 0.25,\n \"mid\": 0.52,\n \"high\": 0.23\n },\n \"leadness\": 0.18\n },\n {\n \"id\": \"pi1\",\n \"class\": \"piano\",\n \"lufs\": -20.2,\n \"transient\": 0.22,\n \"band_energy\": {\n \"low\": 0.1,\n \"mid\": 0.68,\n \"high\": 0.22\n },\n \"leadness\": 0.52\n },\n {\n \"id\": \"fx1\",\n \"class\": \"fx\",\n \"lufs\": -26.4,\n \"transient\": 0.11,\n \"band_energy\": {\n \"low\": 0.07,\n \"mid\": 0.2,\n \"high\": 0.73\n },\n \"leadness\": 0.08\n },\n {\n \"id\": \"a1\",\n \"class\": \"ambience\",\n \"lufs\": -28.0,\n \"transient\": 0.03,\n \"band_energy\": {\n \"low\": 0.18,\n \"mid\": 0.42,\n \"high\": 0.4\n },\n \"leadness\": 0.01\n }\n ],\n \"rules\": [\n {\n \"type\": \"anchor\",\n \"track_class\": \"piano\",\n \"az_deg\": -12,\n \"el_deg\": 6,\n \"dist_m\": 2.4\n },\n {\n \"type\": \"width_pref\",\n \"track_class\": \"pad\",\n \"min_width\": 0.85\n }\n ]\n}", "completion": "{\n \"version\": \"1.0\",\n \"bed\": {\n \"layout\": \"binaural\",\n \"loudness_target_lufs\": -18.0,\n \"room_preset\": \"ambient_wide\"\n },\n \"objects\": [\n {\n \"id\": \"pi1\",\n \"class\": \"piano\",\n \"az_deg\": -12,\n \"el_deg\": 6,\n \"dist_m\": 2.4,\n \"width\": 0.22,\n \"gain_db\": 0.0,\n \"reverb_send\": 0.18,\n \"early_reflections\": 0.17,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": -12,\n \"el_deg\": 6,\n \"dist_m\": 2.4\n },\n {\n \"t\": 1.0,\n \"az_deg\": -8,\n \"el_deg\": 6,\n \"dist_m\": 2.4\n }\n ]\n },\n {\n \"id\": \"p1\",\n \"class\": \"pad\",\n \"az_deg\": -95,\n \"el_deg\": 20,\n \"dist_m\": 6.0,\n \"width\": 0.9,\n \"gain_db\": -4.0,\n \"reverb_send\": 0.35,\n 
\"early_reflections\": 0.12,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": -95,\n \"el_deg\": 20,\n \"dist_m\": 6.0\n },\n {\n \"t\": 1.0,\n \"az_deg\": -70,\n \"el_deg\": 20,\n \"dist_m\": 6.0\n }\n ]\n },\n {\n \"id\": \"fx1\",\n \"class\": \"fx\",\n \"az_deg\": 100,\n \"el_deg\": 35,\n \"dist_m\": 8.0,\n \"width\": 0.75,\n \"gain_db\": -8.0,\n \"reverb_send\": 0.55,\n \"early_reflections\": 0.06,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": 90,\n \"el_deg\": 35,\n \"dist_m\": 8.0\n },\n {\n \"t\": 1.0,\n \"az_deg\": 130,\n \"el_deg\": 35,\n \"dist_m\": 8.0\n }\n ]\n },\n {\n \"id\": \"a1\",\n \"class\": \"ambience\",\n \"az_deg\": 180,\n \"el_deg\": 10,\n \"dist_m\": 12.0,\n \"width\": 1.0,\n \"gain_db\": -10.0,\n \"reverb_send\": 0.65,\n \"early_reflections\": 0.08,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": 160,\n \"el_deg\": 10,\n \"dist_m\": 12.0\n },\n {\n \"t\": 1.0,\n \"az_deg\": 200,\n \"el_deg\": 10,\n \"dist_m\": 12.0\n }\n ]\n }\n ],\n \"constraints_applied\": [\n \"anchor:piano@-12/6/2.4\",\n \"pad_width>=0.85\"\n ]\n}"}
|
data/valid.jsonl
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"prompt": "GravityLLM: Output ONLY valid JSON matching the Spatial9Scene schema.\n\nINPUT:\n{\n \"target_format\": \"iamf\",\n \"max_objects\": 8,\n \"style\": \"cinematic\",\n \"section\": \"break\",\n \"global\": {\n \"bpm\": 90,\n \"energy\": 0.45\n },\n \"stems\": [\n {\n \"id\": \"d1\",\n \"class\": \"dialogue\",\n \"lufs\": -18.0,\n \"transient\": 0.35,\n \"band_energy\": {\n \"low\": 0.05,\n \"mid\": 0.75,\n \"high\": 0.2\n },\n \"leadness\": 0.98\n },\n {\n \"id\": \"st1\",\n \"class\": \"strings\",\n \"lufs\": -24.0,\n \"transient\": 0.08,\n \"band_energy\": {\n \"low\": 0.15,\n \"mid\": 0.6,\n \"high\": 0.25\n },\n \"leadness\": 0.25\n },\n {\n \"id\": \"p1\",\n \"class\": \"pad\",\n \"lufs\": -25.5,\n \"transient\": 0.05,\n \"band_energy\": {\n \"low\": 0.25,\n \"mid\": 0.55,\n \"high\": 0.2\n },\n \"leadness\": 0.1\n },\n {\n \"id\": \"fx1\",\n \"class\": \"fx\",\n \"lufs\": -26.0,\n \"transient\": 0.15,\n \"band_energy\": {\n \"low\": 0.1,\n \"mid\": 0.25,\n \"high\": 0.65\n },\n \"leadness\": 0.05\n }\n ],\n \"rules\": [\n {\n \"type\": \"anchor\",\n \"track_class\": \"dialogue\",\n \"az_deg\": 0,\n \"el_deg\": 5,\n \"dist_m\": 1.8\n },\n {\n \"type\": \"keep_dialogue_clear\",\n \"band_hz\": [\n 1000,\n 4000\n ]\n },\n {\n \"type\": \"width_pref\",\n \"track_class\": \"pad\",\n \"min_width\": 0.8\n }\n ]\n}", "completion": "{\n \"version\": \"1.0\",\n \"bed\": {\n \"layout\": \"iamf\",\n \"loudness_target_lufs\": -16.0,\n \"room_preset\": \"cinema_large\"\n },\n \"objects\": [\n {\n \"id\": \"d1\",\n \"class\": \"dialogue\",\n \"az_deg\": 0,\n \"el_deg\": 5,\n \"dist_m\": 1.8,\n \"width\": 0.1,\n \"gain_db\": 0.0,\n \"reverb_send\": 0.08,\n \"early_reflections\": 0.18,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": 0,\n \"el_deg\": 5,\n \"dist_m\": 1.8\n },\n {\n \"t\": 1.0,\n \"az_deg\": 0,\n \"el_deg\": 5,\n \"dist_m\": 1.8\n }\n ]\n },\n {\n \"id\": \"st1\",\n \"class\": \"strings\",\n \"az_deg\": -35,\n \"el_deg\": 18,\n \"dist_m\": 4.2,\n 
\"width\": 0.4,\n \"gain_db\": -4.0,\n \"reverb_send\": 0.24,\n \"early_reflections\": 0.1,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": -35,\n \"el_deg\": 18,\n \"dist_m\": 4.2\n },\n {\n \"t\": 1.0,\n \"az_deg\": -20,\n \"el_deg\": 18,\n \"dist_m\": 4.2\n }\n ]\n },\n {\n \"id\": \"p1\",\n \"class\": \"pad\",\n \"az_deg\": -80,\n \"el_deg\": 18,\n \"dist_m\": 5.5,\n \"width\": 0.85,\n \"gain_db\": -6.0,\n \"reverb_send\": 0.35,\n \"early_reflections\": 0.12,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": -80,\n \"el_deg\": 18,\n \"dist_m\": 5.5\n },\n {\n \"t\": 1.0,\n \"az_deg\": -80,\n \"el_deg\": 18,\n \"dist_m\": 5.5\n }\n ]\n },\n {\n \"id\": \"fx1\",\n \"class\": \"fx\",\n \"az_deg\": 95,\n \"el_deg\": 28,\n \"dist_m\": 7.0,\n \"width\": 0.7,\n \"gain_db\": -10.0,\n \"reverb_send\": 0.55,\n \"early_reflections\": 0.08,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": 80,\n \"el_deg\": 28,\n \"dist_m\": 7.0\n },\n {\n \"t\": 1.0,\n \"az_deg\": 120,\n \"el_deg\": 28,\n \"dist_m\": 7.0\n }\n ]\n }\n ],\n \"constraints_applied\": [\n \"anchor:dialogue@0/5/1.8\",\n \"keep_dialogue_clear_1000-4000Hz\",\n \"pad_width>=0.80\"\n ]\n}"}
|
| 2 |
+
{"prompt": "GravityLLM: Output ONLY valid JSON matching the Spatial9Scene schema.\n\nINPUT:\n{\n \"target_format\": \"iamf\",\n \"max_objects\": 12,\n \"style\": \"live_stage\",\n \"section\": \"chorus\",\n \"global\": {\n \"bpm\": 112,\n \"energy\": 0.78\n },\n \"stems\": [\n {\n \"id\": \"lv1\",\n \"class\": \"lead_vocal\",\n \"lufs\": -15.5,\n \"transient\": 0.3,\n \"band_energy\": {\n \"low\": 0.08,\n \"mid\": 0.72,\n \"high\": 0.2\n },\n \"leadness\": 0.96\n },\n {\n \"id\": \"g1\",\n \"class\": \"guitar\",\n \"lufs\": -17.4,\n \"transient\": 0.46,\n \"band_energy\": {\n \"low\": 0.15,\n \"mid\": 0.6,\n \"high\": 0.25\n },\n \"leadness\": 0.5\n },\n {\n \"id\": \"g2\",\n \"class\": \"guitar\",\n \"lufs\": -18.1,\n \"transient\": 0.44,\n \"band_energy\": {\n \"low\": 0.12,\n \"mid\": 0.62,\n \"high\": 0.26\n },\n \"leadness\": 0.42\n },\n {\n \"id\": \"dr1\",\n \"class\": \"drums_bus\",\n \"lufs\": -11.8,\n \"transient\": 0.88,\n \"band_energy\": {\n \"low\": 0.4,\n \"mid\": 0.4,\n \"high\": 0.2\n },\n \"leadness\": 0.35\n },\n {\n \"id\": \"b1\",\n \"class\": \"bass\",\n \"lufs\": -12.4,\n \"transient\": 0.61,\n \"band_energy\": {\n \"low\": 0.82,\n \"mid\": 0.14,\n \"high\": 0.04\n },\n \"leadness\": 0.3\n },\n {\n \"id\": \"crowd\",\n \"class\": \"ambience\",\n \"lufs\": -27.0,\n \"transient\": 0.05,\n \"band_energy\": {\n \"low\": 0.2,\n \"mid\": 0.5,\n \"high\": 0.3\n },\n \"leadness\": 0.02\n }\n ],\n \"rules\": [\n {\n \"type\": \"anchor\",\n \"track_class\": \"lead_vocal\",\n \"az_deg\": 0,\n \"el_deg\": 8,\n \"dist_m\": 1.7\n },\n {\n \"type\": \"mono_low_end\",\n \"hz_below\": 120\n },\n {\n \"type\": \"stage_pair\",\n \"left_class\": \"guitar\",\n \"right_class\": \"guitar\"\n }\n ]\n}", "completion": "{\n \"version\": \"1.0\",\n \"bed\": {\n \"layout\": \"iamf\",\n \"loudness_target_lufs\": -15.0,\n \"room_preset\": \"stage_live\"\n },\n \"objects\": [\n {\n \"id\": \"lv1\",\n \"class\": \"lead_vocal\",\n \"az_deg\": 0,\n \"el_deg\": 8,\n 
\"dist_m\": 1.7,\n \"width\": 0.18,\n \"gain_db\": 0.0,\n \"reverb_send\": 0.16,\n \"early_reflections\": 0.24,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": 0,\n \"el_deg\": 8,\n \"dist_m\": 1.7\n },\n {\n \"t\": 1.0,\n \"az_deg\": 0,\n \"el_deg\": 8,\n \"dist_m\": 1.7\n }\n ]\n },\n {\n \"id\": \"g1\",\n \"class\": \"guitar\",\n \"az_deg\": -35,\n \"el_deg\": 6,\n \"dist_m\": 2.8,\n \"width\": 0.26,\n \"gain_db\": -1.5,\n \"reverb_send\": 0.11,\n \"early_reflections\": 0.15,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": -35,\n \"el_deg\": 6,\n \"dist_m\": 2.8\n },\n {\n \"t\": 1.0,\n \"az_deg\": -30,\n \"el_deg\": 6,\n \"dist_m\": 2.8\n }\n ]\n },\n {\n \"id\": \"g2\",\n \"class\": \"guitar\",\n \"az_deg\": 35,\n \"el_deg\": 6,\n \"dist_m\": 2.8,\n \"width\": 0.26,\n \"gain_db\": -1.5,\n \"reverb_send\": 0.11,\n \"early_reflections\": 0.15,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": 35,\n \"el_deg\": 6,\n \"dist_m\": 2.8\n },\n {\n \"t\": 1.0,\n \"az_deg\": 30,\n \"el_deg\": 6,\n \"dist_m\": 2.8\n }\n ]\n },\n {\n \"id\": \"dr1\",\n \"class\": \"drums_bus\",\n \"az_deg\": 0,\n \"el_deg\": 0,\n \"dist_m\": 2.4,\n \"width\": 0.2,\n \"gain_db\": -0.5,\n \"reverb_send\": 0.08,\n \"early_reflections\": 0.12,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": 0,\n \"el_deg\": 0,\n \"dist_m\": 2.4\n },\n {\n \"t\": 1.0,\n \"az_deg\": 0,\n \"el_deg\": 0,\n \"dist_m\": 2.4\n }\n ]\n },\n {\n \"id\": \"b1\",\n \"class\": \"bass\",\n \"az_deg\": 0,\n \"el_deg\": -5,\n \"dist_m\": 2.8,\n \"width\": 0.05,\n \"gain_db\": -0.8,\n \"reverb_send\": 0.04,\n \"early_reflections\": 0.05,\n \"motion\": [\n {\n \"t\": 0.0,\n \"az_deg\": 0,\n \"el_deg\": -5,\n \"dist_m\": 2.8\n },\n {\n \"t\": 1.0,\n \"az_deg\": 0,\n \"el_deg\": -5,\n \"dist_m\": 2.8\n }\n ]\n },\n {\n \"id\": \"crowd\",\n \"class\": \"ambience\",\n \"az_deg\": 180,\n \"el_deg\": 12,\n \"dist_m\": 10.0,\n \"width\": 1.0,\n \"gain_db\": -12.0,\n \"reverb_send\": 0.6,\n \"early_reflections\": 0.1,\n \"motion\": [\n 
{\n \"t\": 0.0,\n \"az_deg\": 160,\n \"el_deg\": 12,\n \"dist_m\": 10.0\n },\n {\n \"t\": 1.0,\n \"az_deg\": 200,\n \"el_deg\": 12,\n \"dist_m\": 10.0\n }\n ]\n }\n ],\n \"constraints_applied\": [\n \"anchor:lead_vocal@0/8/1.7\",\n \"mono_low_end<120Hz\",\n \"stage_pair:guitars\"\n ]\n}"}
|
evaluate.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import re
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Dict, Tuple
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from datasets import load_dataset
|
| 9 |
+
from jsonschema import Draft7Validator
|
| 10 |
+
from peft import AutoPeftModelForCausalLM
|
| 11 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 12 |
+
|
| 13 |
+
SYSTEM_PREFIX = (
|
| 14 |
+
"You are GravityLLM, a Spatial9 scene generation model. "
|
| 15 |
+
"Given music constraints and stem features, output ONLY valid Spatial9Scene JSON. "
|
| 16 |
+
"Do not return markdown. Do not explain your answer.\n\n"
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def parse_args() -> argparse.Namespace:
    """Build and parse the command-line arguments for the evaluation run."""
    cli = argparse.ArgumentParser(
        description="Evaluate GravityLLM outputs on a JSONL validation set."
    )
    cli.add_argument("--model_dir", type=str, required=True)
    cli.add_argument("--data_file", type=str, default="data/valid.jsonl")
    cli.add_argument("--schema_path", type=Path, default=Path("schemas/scene.schema.json"))
    cli.add_argument("--max_new_tokens", type=int, default=900)
    cli.add_argument("--temperature", type=float, default=0.2)
    cli.add_argument("--top_p", type=float, default=0.9)
    cli.add_argument("--limit", type=int, default=0, help="0 means evaluate all rows.")
    cli.add_argument("--report_path", type=Path, default=Path("reports/eval_report.json"))
    return cli.parse_args()
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def load_model_and_tokenizer(model_dir: str):
    """Load tokenizer plus model, preferring a PEFT adapter checkpoint.

    The directory is first treated as a PEFT adapter; if that load raises for
    any reason, it is retried as a plain causal-LM checkpoint. This keeps the
    script usable with both QLoRA adapter outputs and merged models.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Generation requires a pad token; reuse EOS when none is configured.
        tokenizer.pad_token = tokenizer.eos_token

    on_gpu = torch.cuda.is_available()
    load_kwargs = {
        "torch_dtype": torch.bfloat16 if on_gpu else None,
        "device_map": "auto" if on_gpu else None,
        "trust_remote_code": True,
    }
    try:
        # Adapter-style checkpoint (e.g. QLoRA training output).
        model = AutoPeftModelForCausalLM.from_pretrained(model_dir, **load_kwargs)
    except Exception:
        # Not a PEFT adapter (or the adapter load failed) -- load as a full model.
        model = AutoModelForCausalLM.from_pretrained(model_dir, **load_kwargs)
    model.eval()
    return model, tokenizer
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def format_prompt(raw_prompt: str) -> str:
    """Wrap a dataset prompt with the system prefix and an OUTPUT marker.

    A leading "GravityLLM:" tag (case-insensitive) is stripped first so the
    system prefix is not duplicated in the final prompt.
    """
    text = raw_prompt.strip()
    if text.lower().startswith("gravityllm:"):
        _, _, remainder = text.partition(":")
        text = remainder.strip()
    return f"{SYSTEM_PREFIX}{text}\n\nOUTPUT:\n"
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def extract_first_json(text: str) -> str:
    """Return the outermost {...} span in *text*, or the stripped text itself.

    The greedy match runs from the first "{" to the last "}", so a complete
    JSON object surrounded by chatter is recovered intact.
    """
    found = re.search(r"\{.*\}", text, flags=re.DOTALL)
    if found is None:
        return text.strip()
    return found.group(0).strip()
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def validate_schema(schema, output_text: str) -> Tuple[bool, Dict]:
    """Parse *output_text* and check it against the JSON *schema*.

    Returns ``(is_valid, parsed_payload)``. Propagates ``json.JSONDecodeError``
    when the text is not JSON at all; schema violations only flip the boolean.
    """
    payload = json.loads(output_text)
    # any() short-circuits on the first violation; the caller only needs a bool.
    has_errors = any(Draft7Validator(schema).iter_errors(payload))
    return not has_errors, payload
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def check_budget(input_payload: Dict, scene_payload: Dict) -> bool:
    """True when the scene keeps its object count within ``max_objects``.

    A missing budget in the input payload makes the check pass vacuously.
    """
    budget = input_payload.get("max_objects")
    if budget is None:
        return True
    object_count = len(scene_payload.get("objects", []))
    return object_count <= budget
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def check_anchor_rules(input_payload: Dict, scene_payload: Dict) -> bool:
    """Check that every "anchor" rule is satisfied by the generated scene.

    For each anchor rule, at least one object of the anchored class must sit
    at the requested (az_deg, el_deg, dist_m) position.

    Fixes over the previous version:
    - duplicate classes no longer shadow each other (the old implementation
      built a dict keyed by class, so only the last object of each class was
      ever checked);
    - a generated object missing a positional field now counts as a failure
      instead of raising ``KeyError``;
    - positions are compared with a small absolute tolerance rather than exact
      float equality, so JSON float round-tripping cannot cause spurious
      failures.
    """
    objects = scene_payload.get("objects", [])
    for rule in input_payload.get("rules", []):
        if rule.get("type") != "anchor":
            continue
        klass = rule.get("track_class")
        candidates = [obj for obj in objects if obj.get("class") == klass]
        if not any(_matches_anchor(obj, rule) for obj in candidates):
            return False
    return True


def _matches_anchor(obj: Dict, rule: Dict) -> bool:
    """True when *obj* sits at the position demanded by anchor *rule*."""
    for field in ("az_deg", "el_deg", "dist_m"):
        value = obj.get(field)
        if value is None:
            return False
        # Tolerate float round-trip noise from JSON serialization.
        if abs(float(value) - float(rule[field])) > 1e-6:
            return False
    return True
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def generate_scene(model, tokenizer, prompt_text: str, max_new_tokens: int, temperature: float, top_p: float) -> str:
    """Sample one completion for *prompt_text* and return its first JSON span.

    The completion is recovered by slicing the generated token ids past the
    prompt length and decoding only the continuation. (The previous version
    decoded the full sequence and chopped off ``len(decoded_prompt)``
    characters, which mis-slices whenever the tokenizer does not round-trip
    the prompt text exactly -- e.g. with normalization or special tokens.)
    """
    inputs = tokenizer(prompt_text, return_tensors="pt")
    if torch.cuda.is_available():
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    prompt_len = inputs["input_ids"].shape[1]
    continuation_ids = outputs[0][prompt_len:]
    raw_completion = tokenizer.decode(continuation_ids, skip_special_tokens=True).strip()
    return extract_first_json(raw_completion)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def main() -> None:
    """Run generation over the validation set and write a metrics report.

    Tracks four rates: JSON parseability, schema validity, object-budget
    compliance, and anchor-rule compliance. The first 10 per-example records
    are embedded in the report for spot-checking.

    Fixes over the previous version: the generated text was parsed with
    ``json.loads`` and then re-parsed inside ``validate_schema``, and both
    rule checks were invoked twice per row (once for the counters, once for
    the per-sample record). Each is now done exactly once.
    """
    args = parse_args()
    schema = json.loads(args.schema_path.read_text(encoding="utf-8"))
    ds = load_dataset("json", data_files=args.data_file, split="train")
    if args.limit > 0:
        ds = ds.select(range(min(args.limit, len(ds))))

    model, tokenizer = load_model_and_tokenizer(args.model_dir)

    total = len(ds)
    parse_ok = 0
    schema_ok = 0
    budget_ok = 0
    anchor_ok = 0
    samples = []

    for row in ds:
        prompt_text = format_prompt(row["prompt"])
        generated = generate_scene(
            model, tokenizer, prompt_text, args.max_new_tokens, args.temperature, args.top_p
        )

        sample_report = {"prompt": row["prompt"], "generated": generated}
        try:
            # validate_schema parses the JSON itself; a decode failure lands in
            # the except branch together with any other per-example error.
            valid, gen_scene = validate_schema(schema, generated)
            parse_ok += 1
            if valid:
                schema_ok += 1
                # Reconstruct the input payload from the prompt for rule checks.
                prompt_payload_text = row["prompt"].split("INPUT:\n", 1)[1]
                input_payload = json.loads(prompt_payload_text)
                # Evaluate each rule check once; reuse the result for both the
                # aggregate counters and the per-sample record.
                budget_pass = check_budget(input_payload, gen_scene)
                anchor_pass = check_anchor_rules(input_payload, gen_scene)
                budget_ok += int(budget_pass)
                anchor_ok += int(anchor_pass)
                sample_report["schema_valid"] = True
                sample_report["budget_pass"] = budget_pass
                sample_report["anchor_pass"] = anchor_pass
            else:
                sample_report["schema_valid"] = False
        except Exception as exc:
            sample_report["error"] = str(exc)

        samples.append(sample_report)

    report = {
        "examples": total,
        "json_parse_rate": round(parse_ok / total, 4) if total else 0.0,
        "schema_valid_rate": round(schema_ok / total, 4) if total else 0.0,
        "budget_pass_rate": round(budget_ok / total, 4) if total else 0.0,
        "anchor_pass_rate": round(anchor_ok / total, 4) if total else 0.0,
        "samples": samples[:10],
    }

    args.report_path.parent.mkdir(parents=True, exist_ok=True)
    args.report_path.write_text(json.dumps(report, indent=2), encoding="utf-8")
    print(json.dumps(report, indent=2))
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
if __name__ == "__main__":
|
| 178 |
+
main()
|
examples/sample_input.json
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"target_format": "iamf",
|
| 3 |
+
"max_objects": 10,
|
| 4 |
+
"style": "club",
|
| 5 |
+
"section": "drop",
|
| 6 |
+
"global": {
|
| 7 |
+
"bpm": 128,
|
| 8 |
+
"energy": 0.92
|
| 9 |
+
},
|
| 10 |
+
"stems": [
|
| 11 |
+
{
|
| 12 |
+
"id": "v1",
|
| 13 |
+
"class": "lead_vocal",
|
| 14 |
+
"lufs": -16.8,
|
| 15 |
+
"transient": 0.25,
|
| 16 |
+
"band_energy": {
|
| 17 |
+
"low": 0.1,
|
| 18 |
+
"mid": 0.6,
|
| 19 |
+
"high": 0.3
|
| 20 |
+
},
|
| 21 |
+
"leadness": 0.95
|
| 22 |
+
},
|
| 23 |
+
{
|
| 24 |
+
"id": "k1",
|
| 25 |
+
"class": "kick",
|
| 26 |
+
"lufs": -10.5,
|
| 27 |
+
"transient": 0.95,
|
| 28 |
+
"band_energy": {
|
| 29 |
+
"low": 0.8,
|
| 30 |
+
"mid": 0.15,
|
| 31 |
+
"high": 0.05
|
| 32 |
+
},
|
| 33 |
+
"leadness": 0.25
|
| 34 |
+
},
|
| 35 |
+
{
|
| 36 |
+
"id": "b1",
|
| 37 |
+
"class": "bass",
|
| 38 |
+
"lufs": -12.2,
|
| 39 |
+
"transient": 0.55,
|
| 40 |
+
"band_energy": {
|
| 41 |
+
"low": 0.85,
|
| 42 |
+
"mid": 0.12,
|
| 43 |
+
"high": 0.03
|
| 44 |
+
},
|
| 45 |
+
"leadness": 0.35
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"id": "p1",
|
| 49 |
+
"class": "pad",
|
| 50 |
+
"lufs": -20.3,
|
| 51 |
+
"transient": 0.1,
|
| 52 |
+
"band_energy": {
|
| 53 |
+
"low": 0.2,
|
| 54 |
+
"mid": 0.5,
|
| 55 |
+
"high": 0.3
|
| 56 |
+
},
|
| 57 |
+
"leadness": 0.1
|
| 58 |
+
},
|
| 59 |
+
{
|
| 60 |
+
"id": "s1",
|
| 61 |
+
"class": "synth_lead",
|
| 62 |
+
"lufs": -18.0,
|
| 63 |
+
"transient": 0.4,
|
| 64 |
+
"band_energy": {
|
| 65 |
+
"low": 0.1,
|
| 66 |
+
"mid": 0.55,
|
| 67 |
+
"high": 0.35
|
| 68 |
+
},
|
| 69 |
+
"leadness": 0.75
|
| 70 |
+
},
|
| 71 |
+
{
|
| 72 |
+
"id": "fx1",
|
| 73 |
+
"class": "fx",
|
| 74 |
+
"lufs": -23.0,
|
| 75 |
+
"transient": 0.2,
|
| 76 |
+
"band_energy": {
|
| 77 |
+
"low": 0.1,
|
| 78 |
+
"mid": 0.3,
|
| 79 |
+
"high": 0.6
|
| 80 |
+
},
|
| 81 |
+
"leadness": 0.05
|
| 82 |
+
}
|
| 83 |
+
],
|
| 84 |
+
"rules": [
|
| 85 |
+
{
|
| 86 |
+
"type": "anchor",
|
| 87 |
+
"track_class": "lead_vocal",
|
| 88 |
+
"az_deg": 0,
|
| 89 |
+
"el_deg": 10,
|
| 90 |
+
"dist_m": 1.6
|
| 91 |
+
},
|
| 92 |
+
{
|
| 93 |
+
"type": "mono_low_end",
|
| 94 |
+
"hz_below": 120
|
| 95 |
+
},
|
| 96 |
+
{
|
| 97 |
+
"type": "width_pref",
|
| 98 |
+
"track_class": "pad",
|
| 99 |
+
"min_width": 0.75
|
| 100 |
+
},
|
| 101 |
+
{
|
| 102 |
+
"type": "avoid_band_masking",
|
| 103 |
+
"mask_target": "lead_vocal",
|
| 104 |
+
"band_hz": [
|
| 105 |
+
1500,
|
| 106 |
+
4500
|
| 107 |
+
]
|
| 108 |
+
}
|
| 109 |
+
]
|
| 110 |
+
}
|
examples/sample_output.json
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"version": "1.0",
|
| 3 |
+
"bed": {
|
| 4 |
+
"layout": "iamf",
|
| 5 |
+
"loudness_target_lufs": -14.0,
|
| 6 |
+
"room_preset": "club_medium"
|
| 7 |
+
},
|
| 8 |
+
"objects": [
|
| 9 |
+
{
|
| 10 |
+
"id": "v1",
|
| 11 |
+
"class": "lead_vocal",
|
| 12 |
+
"az_deg": 0,
|
| 13 |
+
"el_deg": 10,
|
| 14 |
+
"dist_m": 1.6,
|
| 15 |
+
"width": 0.15,
|
| 16 |
+
"gain_db": 0.0,
|
| 17 |
+
"reverb_send": 0.18,
|
| 18 |
+
"early_reflections": 0.22,
|
| 19 |
+
"motion": [
|
| 20 |
+
{
|
| 21 |
+
"t": 0.0,
|
| 22 |
+
"az_deg": 0,
|
| 23 |
+
"el_deg": 10,
|
| 24 |
+
"dist_m": 1.6
|
| 25 |
+
},
|
| 26 |
+
{
|
| 27 |
+
"t": 1.0,
|
| 28 |
+
"az_deg": 0,
|
| 29 |
+
"el_deg": 10,
|
| 30 |
+
"dist_m": 1.6
|
| 31 |
+
}
|
| 32 |
+
]
|
| 33 |
+
},
|
| 34 |
+
{
|
| 35 |
+
"id": "k1",
|
| 36 |
+
"class": "kick",
|
| 37 |
+
"az_deg": 0,
|
| 38 |
+
"el_deg": 0,
|
| 39 |
+
"dist_m": 2.2,
|
| 40 |
+
"width": 0.0,
|
| 41 |
+
"gain_db": 0.0,
|
| 42 |
+
"reverb_send": 0.02,
|
| 43 |
+
"early_reflections": 0.05,
|
| 44 |
+
"motion": [
|
| 45 |
+
{
|
| 46 |
+
"t": 0.0,
|
| 47 |
+
"az_deg": 0,
|
| 48 |
+
"el_deg": 0,
|
| 49 |
+
"dist_m": 2.2
|
| 50 |
+
},
|
| 51 |
+
{
|
| 52 |
+
"t": 1.0,
|
| 53 |
+
"az_deg": 0,
|
| 54 |
+
"el_deg": 0,
|
| 55 |
+
"dist_m": 2.2
|
| 56 |
+
}
|
| 57 |
+
]
|
| 58 |
+
},
|
| 59 |
+
{
|
| 60 |
+
"id": "b1",
|
| 61 |
+
"class": "bass",
|
| 62 |
+
"az_deg": 0,
|
| 63 |
+
"el_deg": -5,
|
| 64 |
+
"dist_m": 2.6,
|
| 65 |
+
"width": 0.05,
|
| 66 |
+
"gain_db": -0.5,
|
| 67 |
+
"reverb_send": 0.03,
|
| 68 |
+
"early_reflections": 0.06,
|
| 69 |
+
"motion": [
|
| 70 |
+
{
|
| 71 |
+
"t": 0.0,
|
| 72 |
+
"az_deg": 0,
|
| 73 |
+
"el_deg": -5,
|
| 74 |
+
"dist_m": 2.6
|
| 75 |
+
},
|
| 76 |
+
{
|
| 77 |
+
"t": 1.0,
|
| 78 |
+
"az_deg": 0,
|
| 79 |
+
"el_deg": -5,
|
| 80 |
+
"dist_m": 2.6
|
| 81 |
+
}
|
| 82 |
+
]
|
| 83 |
+
},
|
| 84 |
+
{
|
| 85 |
+
"id": "s1",
|
| 86 |
+
"class": "synth_lead",
|
| 87 |
+
"az_deg": 25,
|
| 88 |
+
"el_deg": 15,
|
| 89 |
+
"dist_m": 2.0,
|
| 90 |
+
"width": 0.35,
|
| 91 |
+
"gain_db": -1.0,
|
| 92 |
+
"reverb_send": 0.12,
|
| 93 |
+
"early_reflections": 0.1,
|
| 94 |
+
"motion": [
|
| 95 |
+
{
|
| 96 |
+
"t": 0.0,
|
| 97 |
+
"az_deg": 20,
|
| 98 |
+
"el_deg": 15,
|
| 99 |
+
"dist_m": 2.0
|
| 100 |
+
},
|
| 101 |
+
{
|
| 102 |
+
"t": 1.0,
|
| 103 |
+
"az_deg": 30,
|
| 104 |
+
"el_deg": 15,
|
| 105 |
+
"dist_m": 2.0
|
| 106 |
+
}
|
| 107 |
+
]
|
| 108 |
+
},
|
| 109 |
+
{
|
| 110 |
+
"id": "p1",
|
| 111 |
+
"class": "pad",
|
| 112 |
+
"az_deg": -70,
|
| 113 |
+
"el_deg": 20,
|
| 114 |
+
"dist_m": 4.0,
|
| 115 |
+
"width": 0.85,
|
| 116 |
+
"gain_db": -3.0,
|
| 117 |
+
"reverb_send": 0.28,
|
| 118 |
+
"early_reflections": 0.12,
|
| 119 |
+
"motion": [
|
| 120 |
+
{
|
| 121 |
+
"t": 0.0,
|
| 122 |
+
"az_deg": -70,
|
| 123 |
+
"el_deg": 20,
|
| 124 |
+
"dist_m": 4.0
|
| 125 |
+
},
|
| 126 |
+
{
|
| 127 |
+
"t": 1.0,
|
| 128 |
+
"az_deg": -70,
|
| 129 |
+
"el_deg": 20,
|
| 130 |
+
"dist_m": 4.0
|
| 131 |
+
}
|
| 132 |
+
]
|
| 133 |
+
},
|
| 134 |
+
{
|
| 135 |
+
"id": "fx1",
|
| 136 |
+
"class": "fx",
|
| 137 |
+
"az_deg": 110,
|
| 138 |
+
"el_deg": 30,
|
| 139 |
+
"dist_m": 6.0,
|
| 140 |
+
"width": 0.65,
|
| 141 |
+
"gain_db": -6.0,
|
| 142 |
+
"reverb_send": 0.4,
|
| 143 |
+
"early_reflections": 0.08,
|
| 144 |
+
"motion": [
|
| 145 |
+
{
|
| 146 |
+
"t": 0.0,
|
| 147 |
+
"az_deg": 100,
|
| 148 |
+
"el_deg": 30,
|
| 149 |
+
"dist_m": 6.0
|
| 150 |
+
},
|
| 151 |
+
{
|
| 152 |
+
"t": 1.0,
|
| 153 |
+
"az_deg": 130,
|
| 154 |
+
"el_deg": 30,
|
| 155 |
+
"dist_m": 6.0
|
| 156 |
+
}
|
| 157 |
+
]
|
| 158 |
+
}
|
| 159 |
+
],
|
| 160 |
+
"constraints_applied": [
|
| 161 |
+
"anchor:lead_vocal@0/10/1.6",
|
| 162 |
+
"mono_low_end<120Hz",
|
| 163 |
+
"pad_width>=0.75",
|
| 164 |
+
"avoid_masking_lead_1500-4500Hz"
|
| 165 |
+
]
|
| 166 |
+
}
|
infer.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import re
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from jsonschema import Draft7Validator
|
| 8 |
+
from peft import AutoPeftModelForCausalLM
|
| 9 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 10 |
+
|
| 11 |
+
SYSTEM_PREFIX = (
|
| 12 |
+
"You are GravityLLM, a Spatial9 scene generation model. "
|
| 13 |
+
"Given music constraints and stem features, output ONLY valid Spatial9Scene JSON. "
|
| 14 |
+
"Do not return markdown. Do not explain your answer.\n\n"
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def parse_args() -> argparse.Namespace:
    """Build and parse the CLI arguments for a single inference run."""
    cli = argparse.ArgumentParser(
        description="Run GravityLLM inference on a Spatial9 constraint payload."
    )
    cli.add_argument(
        "--model_dir", type=str, required=True,
        help="Path or Hub repo id for trained model or adapter.",
    )
    cli.add_argument("--input_json", type=Path, required=True)
    cli.add_argument("--schema_path", type=Path, default=Path("schemas/scene.schema.json"))
    cli.add_argument("--output_json", type=Path, default=None)
    cli.add_argument("--max_new_tokens", type=int, default=900)
    cli.add_argument("--temperature", type=float, default=0.35)
    cli.add_argument("--top_p", type=float, default=0.9)
    cli.add_argument("--validate", action="store_true")
    return cli.parse_args()
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def load_model_and_tokenizer(model_dir: str):
    """Load the tokenizer and model stored at *model_dir*.

    A PEFT-adapter load is attempted first; when it raises, the directory is
    retried as a plain causal-LM checkpoint, so both adapter outputs and
    merged models are supported.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Ensure generation has a pad token; EOS is the conventional stand-in.
        tokenizer.pad_token = tokenizer.eos_token

    use_cuda = torch.cuda.is_available()
    common = dict(
        torch_dtype=torch.bfloat16 if use_cuda else None,
        device_map="auto" if use_cuda else None,
        trust_remote_code=True,
    )
    try:
        model = AutoPeftModelForCausalLM.from_pretrained(model_dir, **common)
    except Exception:
        model = AutoModelForCausalLM.from_pretrained(model_dir, **common)
    model.eval()
    return model, tokenizer
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def extract_first_json(text: str) -> str:
    """Return the widest {...} region of *text* (first "{" to last "}").

    Falls back to the stripped input when no brace pair is present.
    """
    match = re.search(r"\{.*\}", text, re.DOTALL)
    return (match.group(0) if match else text).strip()
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def validate_output(schema_path: Path, output_text: str):
    """Validate *output_text* against the JSON schema stored at *schema_path*.

    Returns ``(parsed_payload, errors)`` where errors is the list of schema
    violations sorted by their location in the document. Raises on invalid
    JSON in either input.
    """
    schema = json.loads(schema_path.read_text(encoding="utf-8"))
    payload = json.loads(output_text)
    violations = Draft7Validator(schema).iter_errors(payload)
    return payload, sorted(violations, key=lambda err: list(err.path))
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def main() -> None:
    """Generate one Spatial9 scene for the input payload and print it.

    Optionally validates the generated JSON against the scene schema
    (--validate) and/or writes it to --output_json.

    Fix over the previous version: the completion is now recovered by slicing
    the generated token ids past the prompt length and decoding only the
    continuation. Decoding the full sequence and chopping off
    ``len(decoded_prompt)`` characters mis-slices whenever the tokenizer does
    not round-trip the prompt text exactly.
    """
    args = parse_args()
    payload = json.loads(args.input_json.read_text(encoding="utf-8"))

    model, tokenizer = load_model_and_tokenizer(args.model_dir)
    prompt = SYSTEM_PREFIX + "INPUT:\n" + json.dumps(payload, ensure_ascii=False, indent=2) + "\n\nOUTPUT:\n"

    inputs = tokenizer(prompt, return_tensors="pt")
    if torch.cuda.is_available():
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=args.max_new_tokens,
            do_sample=True,
            temperature=args.temperature,
            top_p=args.top_p,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    prompt_len = inputs["input_ids"].shape[1]
    raw_completion = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()
    json_text = extract_first_json(raw_completion)

    if args.validate:
        try:
            _, errors = validate_output(args.schema_path, json_text)
            if errors:
                print("Validation: INVALID")
                for err in errors[:20]:
                    path = ".".join(str(p) for p in err.path)
                    print(f"- {path}: {err.message}")
            else:
                print("Validation: VALID")
        except Exception as exc:
            # Best-effort report; inference output is still printed below.
            print(f"Validation failed: {exc}")

    if args.output_json:
        args.output_json.parent.mkdir(parents=True, exist_ok=True)
        args.output_json.write_text(json_text + "\n", encoding="utf-8")
        print(f"Wrote output to {args.output_json}")

    print(json_text)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
if __name__ == "__main__":
|
| 117 |
+
main()
|
pyproject.toml
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "gravityllm"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
description = "Constraint-driven Spatial9 scene generation for immersive audio."
|
| 5 |
+
readme = "README.md"
|
| 6 |
+
requires-python = ">=3.10"
|
| 7 |
+
license = {text = "Apache-2.0"}
|
| 8 |
+
dependencies = [
|
| 9 |
+
"accelerate>=0.31.0",
|
| 10 |
+
"bitsandbytes>=0.43.0",
|
| 11 |
+
"datasets>=2.19.0",
|
| 12 |
+
"huggingface_hub>=0.24.0",
|
| 13 |
+
"jsonschema>=4.22.0",
|
| 14 |
+
"peft>=0.11.0",
|
| 15 |
+
"safetensors>=0.4.3",
|
| 16 |
+
"torch",
|
| 17 |
+
"transformers>=4.41.0",
|
| 18 |
+
]
|
| 19 |
+
|
| 20 |
+
[tool.setuptools]
|
| 21 |
+
py-modules = ["train", "infer", "evaluate", "upload_to_hub"]
|
requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
accelerate>=0.31.0
|
| 2 |
+
bitsandbytes>=0.43.0
|
| 3 |
+
datasets>=2.19.0
|
| 4 |
+
huggingface_hub>=0.24.0
|
| 5 |
+
jsonschema>=4.22.0
|
| 6 |
+
peft>=0.11.0
|
| 7 |
+
safetensors>=0.4.3
|
| 8 |
+
torch
|
| 9 |
+
transformers>=4.41.0
|
schemas/scene.schema.json
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"$schema": "http://json-schema.org/draft-07/schema#",
|
| 3 |
+
"title": "Spatial9Scene",
|
| 4 |
+
"type": "object",
|
| 5 |
+
"required": [
|
| 6 |
+
"version",
|
| 7 |
+
"bed",
|
| 8 |
+
"objects",
|
| 9 |
+
"constraints_applied"
|
| 10 |
+
],
|
| 11 |
+
"properties": {
|
| 12 |
+
"version": {
|
| 13 |
+
"type": "string"
|
| 14 |
+
},
|
| 15 |
+
"bed": {
|
| 16 |
+
"type": "object",
|
| 17 |
+
"required": [
|
| 18 |
+
"layout",
|
| 19 |
+
"loudness_target_lufs",
|
| 20 |
+
"room_preset"
|
| 21 |
+
],
|
| 22 |
+
"properties": {
|
| 23 |
+
"layout": {
|
| 24 |
+
"type": "string",
|
| 25 |
+
"enum": [
|
| 26 |
+
"binaural",
|
| 27 |
+
"5.1.4",
|
| 28 |
+
"7.1.4",
|
| 29 |
+
"iamf"
|
| 30 |
+
]
|
| 31 |
+
},
|
| 32 |
+
"loudness_target_lufs": {
|
| 33 |
+
"type": "number"
|
| 34 |
+
},
|
| 35 |
+
"room_preset": {
|
| 36 |
+
"type": "string"
|
| 37 |
+
}
|
| 38 |
+
},
|
| 39 |
+
"additionalProperties": false
|
| 40 |
+
},
|
| 41 |
+
"objects": {
|
| 42 |
+
"type": "array",
|
| 43 |
+
"minItems": 1,
|
| 44 |
+
"maxItems": 32,
|
| 45 |
+
"items": {
|
| 46 |
+
"type": "object",
|
| 47 |
+
"required": [
|
| 48 |
+
"id",
|
| 49 |
+
"class",
|
| 50 |
+
"az_deg",
|
| 51 |
+
"el_deg",
|
| 52 |
+
"dist_m",
|
| 53 |
+
"width",
|
| 54 |
+
"gain_db",
|
| 55 |
+
"reverb_send",
|
| 56 |
+
"early_reflections",
|
| 57 |
+
"motion"
|
| 58 |
+
],
|
| 59 |
+
"properties": {
|
| 60 |
+
"id": {
|
| 61 |
+
"type": "string"
|
| 62 |
+
},
|
| 63 |
+
"class": {
|
| 64 |
+
"type": "string",
|
| 65 |
+
"enum": [
|
| 66 |
+
"lead_vocal",
|
| 67 |
+
"back_vocal",
|
| 68 |
+
"kick",
|
| 69 |
+
"snare",
|
| 70 |
+
"hihat",
|
| 71 |
+
"bass",
|
| 72 |
+
"drums_bus",
|
| 73 |
+
"pad",
|
| 74 |
+
"synth_lead",
|
| 75 |
+
"guitar",
|
| 76 |
+
"piano",
|
| 77 |
+
"strings",
|
| 78 |
+
"brass",
|
| 79 |
+
"fx",
|
| 80 |
+
"dialogue",
|
| 81 |
+
"ambience",
|
| 82 |
+
"other"
|
| 83 |
+
]
|
| 84 |
+
},
|
| 85 |
+
"az_deg": {
|
| 86 |
+
"type": "number",
|
| 87 |
+
"minimum": -180,
|
| 88 |
+
"maximum": 180
|
| 89 |
+
},
|
| 90 |
+
"el_deg": {
|
| 91 |
+
"type": "number",
|
| 92 |
+
"minimum": -45,
|
| 93 |
+
"maximum": 90
|
| 94 |
+
},
|
| 95 |
+
"dist_m": {
|
| 96 |
+
"type": "number",
|
| 97 |
+
"minimum": 0.5,
|
| 98 |
+
"maximum": 15.0
|
| 99 |
+
},
|
| 100 |
+
"width": {
|
| 101 |
+
"type": "number",
|
| 102 |
+
"minimum": 0.0,
|
| 103 |
+
"maximum": 1.0
|
| 104 |
+
},
|
| 105 |
+
"gain_db": {
|
| 106 |
+
"type": "number",
|
| 107 |
+
"minimum": -60,
|
| 108 |
+
"maximum": 12
|
| 109 |
+
},
|
| 110 |
+
"reverb_send": {
|
| 111 |
+
"type": "number",
|
| 112 |
+
"minimum": 0.0,
|
| 113 |
+
"maximum": 1.0
|
| 114 |
+
},
|
| 115 |
+
"early_reflections": {
|
| 116 |
+
"type": "number",
|
| 117 |
+
"minimum": 0.0,
|
| 118 |
+
"maximum": 1.0
|
| 119 |
+
},
|
| 120 |
+
"motion": {
|
| 121 |
+
"type": "array",
|
| 122 |
+
"maxItems": 64,
|
| 123 |
+
"items": {
|
| 124 |
+
"type": "object",
|
| 125 |
+
"required": [
|
| 126 |
+
"t",
|
| 127 |
+
"az_deg",
|
| 128 |
+
"el_deg",
|
| 129 |
+
"dist_m"
|
| 130 |
+
],
|
| 131 |
+
"properties": {
|
| 132 |
+
"t": {
|
| 133 |
+
"type": "number",
|
| 134 |
+
"minimum": 0.0,
|
| 135 |
+
"maximum": 1.0
|
| 136 |
+
},
|
| 137 |
+
"az_deg": {
|
| 138 |
+
"type": "number",
|
| 139 |
+
"minimum": -180,
|
| 140 |
+
"maximum": 180
|
| 141 |
+
},
|
| 142 |
+
"el_deg": {
|
| 143 |
+
"type": "number",
|
| 144 |
+
"minimum": -45,
|
| 145 |
+
"maximum": 90
|
| 146 |
+
},
|
| 147 |
+
"dist_m": {
|
| 148 |
+
"type": "number",
|
| 149 |
+
"minimum": 0.5,
|
| 150 |
+
"maximum": 15.0
|
| 151 |
+
}
|
| 152 |
+
},
|
| 153 |
+
"additionalProperties": false
|
| 154 |
+
}
|
| 155 |
+
}
|
| 156 |
+
},
|
| 157 |
+
"additionalProperties": false
|
| 158 |
+
}
|
| 159 |
+
},
|
| 160 |
+
"constraints_applied": {
|
| 161 |
+
"type": "array",
|
| 162 |
+
"items": {
|
| 163 |
+
"type": "string"
|
| 164 |
+
},
|
| 165 |
+
"maxItems": 64
|
| 166 |
+
}
|
| 167 |
+
},
|
| 168 |
+
"additionalProperties": false
|
| 169 |
+
}
|
scripts/push_to_hub.sh
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Thin wrapper around upload_to_hub.py: push a local artifact folder to the
# Hugging Face Hub as <namespace/repo>.
set -euo pipefail

# Require both positional arguments before invoking Python.
if [ $# -lt 2 ]; then
  echo "Usage: scripts/push_to_hub.sh <folder_path> <namespace/repo>"
  exit 1
fi

python upload_to_hub.py --folder_path "$1" --repo_id "$2"
|
scripts/train_qlora.sh
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Launch a QLoRA fine-tuning run of GravityLLM with the project's reference
# hyperparameters (4-bit base model + LoRA adapters, bf16 compute).
set -euo pipefail

python train.py \
  --model Qwen/Qwen2.5-1.5B-Instruct \
  --train_file data/train.jsonl \
  --valid_file data/valid.jsonl \
  --output_dir outputs/GravityLLM-Qwen2.5-1.5B-S9 \
  --max_length 2048 \
  --num_train_epochs 3 \
  --learning_rate 2e-4 \
  --train_batch_size 1 \
  --eval_batch_size 1 \
  --gradient_accumulation_steps 16 \
  --warmup_ratio 0.03 \
  --logging_steps 10 \
  --save_steps 100 \
  --eval_steps 100 \
  --qlora --bf16
|
tools/make_synthetic_dataset.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import random
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
# (stem class name, fixed "leadness" prior) pairs; build_example() emits one
# synthetic stem per entry, in this order.
CLASSES = [
    ("lead_vocal", 0.95),
    ("kick", 0.25),
    ("bass", 0.35),
    ("pad", 0.10),
    ("synth_lead", 0.70),
    ("fx", 0.05),
]

# Style name -> bed settings copied into the synthetic target scene.
STYLE_PRESETS = {
    "club": {"layout": "iamf", "room_preset": "club_medium", "lufs": -14.0},
    "cinematic": {"layout": "iamf", "room_preset": "cinema_large", "lufs": -16.0},
    "live_stage": {"layout": "iamf", "room_preset": "stage_live", "lufs": -15.0},
}
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def build_example(seed: int) -> dict:
    """Build one synthetic (prompt, completion) training row.

    The RNG is seeded per example, so a given ``seed`` always produces the
    same row. NOTE(review): the output is fixed by the exact sequence of
    ``rng.*`` calls below — do not reorder them.
    """
    rng = random.Random(seed)
    style = rng.choice(list(STYLE_PRESETS.keys()))
    bpm = rng.choice([90, 100, 110, 120, 128, 140])
    energy = round(rng.uniform(0.35, 0.95), 2)
    max_objects = rng.choice([8, 10, 12])

    # One synthetic stem per (class, leadness) entry in CLASSES.
    stems = []
    for idx, (klass, leadness) in enumerate(CLASSES, start=1):
        stems.append(
            {
                "id": f"{klass[:1]}{idx}",
                "class": klass,
                "lufs": round(rng.uniform(-25.0, -10.0), 1),
                "transient": round(rng.uniform(0.05, 0.98), 2),
                "band_energy": {
                    "low": round(rng.uniform(0.05, 0.9), 2),
                    "mid": round(rng.uniform(0.05, 0.9), 2),
                    "high": round(rng.uniform(0.05, 0.9), 2),
                },
                "leadness": leadness,
            }
        )

    # Model input: constraints + stem features the model must respect.
    payload = {
        "target_format": "iamf",
        "max_objects": max_objects,
        "style": style,
        "section": rng.choice(["intro", "verse", "break", "drop"]),
        "global": {"bpm": bpm, "energy": energy},
        "stems": stems,
        "rules": [
            {"type": "anchor", "track_class": "lead_vocal", "az_deg": 0, "el_deg": 10, "dist_m": 1.6},
            {"type": "mono_low_end", "hz_below": 120},
            {"type": "width_pref", "track_class": "pad", "min_width": 0.75},
        ],
    }

    # Target completion: a minimal Spatial9Scene that honors the rules above
    # (only the anchored lead vocal is placed as an object here).
    preset = STYLE_PRESETS[style]
    output = {
        "version": "1.0",
        "bed": {
            "layout": preset["layout"],
            "loudness_target_lufs": preset["lufs"],
            "room_preset": preset["room_preset"],
        },
        "objects": [
            {
                "id": stems[0]["id"],
                "class": "lead_vocal",
                "az_deg": 0,
                "el_deg": 10,
                "dist_m": 1.6,
                "width": 0.15,
                "gain_db": 0.0,
                "reverb_send": 0.18,
                "early_reflections": 0.2,
                "motion": [
                    {"t": 0.0, "az_deg": 0, "el_deg": 10, "dist_m": 1.6},
                    {"t": 1.0, "az_deg": 0, "el_deg": 10, "dist_m": 1.6},
                ],
            }
        ],
        "constraints_applied": [
            "anchor:lead_vocal@0/10/1.6",
            "mono_low_end<120Hz",
            "pad_width>=0.75",
        ],
    }

    prompt = (
        "GravityLLM: Output ONLY valid JSON matching the Spatial9Scene schema.\n\n"
        "INPUT:\n" + json.dumps(payload, indent=2)
    )

    return {"prompt": prompt, "completion": json.dumps(output, indent=2)}
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def main() -> None:
    """CLI entry point: write --count synthetic examples as one JSONL file."""
    cli = argparse.ArgumentParser(description="Generate a small synthetic GravityLLM dataset.")
    cli.add_argument("--output", type=Path, default=Path("data/synthetic_train.jsonl"))
    cli.add_argument("--count", type=int, default=25)
    cli.add_argument("--seed", type=int, default=42)
    opts = cli.parse_args()

    opts.output.parent.mkdir(parents=True, exist_ok=True)
    # Each row gets its own deterministic seed: base seed + offset.
    rows = (build_example(opts.seed + offset) for offset in range(opts.count))
    with opts.output.open("w", encoding="utf-8") as handle:
        handle.writelines(json.dumps(row, ensure_ascii=False) + "\n" for row in rows)

    print(f"Wrote {opts.count} examples to {opts.output}")


if __name__ == "__main__":
    main()
|
tools/validate_scene.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
from jsonschema import Draft7Validator
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def validate(schema_path: Path, json_path: Path) -> int:
    """Validate *json_path* against *schema_path*; return 0 if valid, 1 otherwise.

    Prints "VALID"/"INVALID" plus up to 50 error locations to stdout.
    """
    validator = Draft7Validator(json.loads(schema_path.read_text(encoding="utf-8")))
    document = json.loads(json_path.read_text(encoding="utf-8"))

    problems = sorted(validator.iter_errors(document), key=lambda e: list(e.path))
    if problems:
        print("INVALID")
        for problem in problems[:50]:
            location = ".".join(str(part) for part in problem.path)
            print(f"- {location}: {problem.message}")
        return 1

    print("VALID")
    return 0
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def main() -> int:
    """Parse the CLI and run schema validation; return the process exit code."""
    cli = argparse.ArgumentParser(description="Validate a GravityLLM scene JSON against the Spatial9Scene schema.")
    cli.add_argument("schema_path", type=Path)
    cli.add_argument("json_path", type=Path)
    opts = cli.parse_args()
    return validate(opts.schema_path, opts.json_path)


if __name__ == "__main__":
    raise SystemExit(main())
|
train.py
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
import random
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Dict, List
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from datasets import load_dataset
|
| 11 |
+
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
|
| 12 |
+
from transformers import (
|
| 13 |
+
AutoModelForCausalLM,
|
| 14 |
+
AutoTokenizer,
|
| 15 |
+
BitsAndBytesConfig,
|
| 16 |
+
Trainer,
|
| 17 |
+
TrainingArguments,
|
| 18 |
+
set_seed,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
# Instruction header prepended to every prompt by format_prompt(); steers the
# model to emit raw Spatial9Scene JSON with no markdown and no explanations.
SYSTEM_PREFIX = (
    "You are GravityLLM, a Spatial9 scene generation model. "
    "Given music constraints and stem features, output ONLY valid Spatial9Scene JSON. "
    "Do not return markdown. Do not explain your answer. "
    "Respect hard constraints such as object budgets, anchor positions, and low-end centering.\n\n"
)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def parse_args() -> argparse.Namespace:
    """Build and parse the command-line interface for GravityLLM fine-tuning."""
    p = argparse.ArgumentParser(description="Fine-tune GravityLLM for Spatial9 scene generation.")

    # Model / data / output locations.
    p.add_argument("--model", type=str, default="Qwen/Qwen2.5-1.5B-Instruct")
    p.add_argument("--train_file", type=str, default="data/train.jsonl")
    p.add_argument("--valid_file", type=str, default="data/valid.jsonl")
    p.add_argument("--output_dir", type=str, default="outputs/GravityLLM-Qwen2.5-1.5B-S9")
    p.add_argument("--max_length", type=int, default=2048)

    # Optimization schedule.
    p.add_argument("--num_train_epochs", type=float, default=1.0)
    p.add_argument("--learning_rate", type=float, default=2e-4)
    p.add_argument("--train_batch_size", type=int, default=1)
    p.add_argument("--eval_batch_size", type=int, default=1)
    p.add_argument("--gradient_accumulation_steps", type=int, default=16)
    p.add_argument("--warmup_ratio", type=float, default=0.03)
    p.add_argument("--weight_decay", type=float, default=0.0)
    p.add_argument("--logging_steps", type=int, default=10)
    p.add_argument("--save_steps", type=int, default=200)
    p.add_argument("--eval_steps", type=int, default=200)
    p.add_argument("--seed", type=int, default=42)

    # Adapter (PEFT) options.
    p.add_argument("--lora", action="store_true", help="Enable LoRA adapters.")
    p.add_argument("--qlora", action="store_true", help="Enable 4-bit QLoRA training.")
    p.add_argument("--lora_r", type=int, default=16)
    p.add_argument("--lora_alpha", type=int, default=32)
    p.add_argument("--lora_dropout", type=float, default=0.05)

    # Mixed-precision flags.
    p.add_argument("--bf16", action="store_true")
    p.add_argument("--fp16", action="store_true")

    # Hub publishing.
    p.add_argument("--push_to_hub", action="store_true")
    p.add_argument("--hub_model_id", type=str, default=None)
    p.add_argument("--hub_private_repo", action="store_true")
    return p.parse_args()
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def load_jsonl(file_path: str):
    """Load a JSONL file as a single Hugging Face ``datasets`` split."""
    return load_dataset("json", data_files=file_path, split="train")
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def format_prompt(raw_prompt: str) -> str:
    """Wrap a raw dataset prompt with the system prefix and an OUTPUT marker.

    A leading "GravityLLM:" tag (any letter case) is stripped so the system
    prefix is not duplicated in spirit.
    """
    text = raw_prompt.strip()
    if text.lower().startswith("gravityllm:"):
        _, _, remainder = text.partition(":")
        text = remainder.strip()
    return SYSTEM_PREFIX + text + "\n\nOUTPUT:\n"
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def tokenize_example(example: Dict[str, str], tokenizer, max_length: int) -> Dict[str, List[int]]:
    """Tokenize one prompt/completion row for causal-LM fine-tuning.

    Prompt tokens are masked from the loss with -100; only completion tokens
    (plus EOS) are supervised. Sequences longer than *max_length* are
    truncated from the right.
    """
    prompt_ids = tokenizer(format_prompt(example["prompt"]), add_special_tokens=False)["input_ids"]
    target_text = example["completion"].strip() + tokenizer.eos_token
    completion_ids = tokenizer(target_text, add_special_tokens=False)["input_ids"]

    input_ids = (prompt_ids + completion_ids)[:max_length]
    labels = ([-100] * len(prompt_ids) + completion_ids)[:max_length]

    return {
        "input_ids": input_ids,
        "attention_mask": [1] * len(input_ids),
        "labels": labels,
    }
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
@dataclass
class CausalDataCollator:
    """Right-pads variable-length examples into rectangular LongTensor batches.

    Padding positions receive ``pad_token_id`` in ``input_ids``, 0 in the
    attention mask, and ``label_pad_token_id`` (-100, ignored by the loss)
    in ``labels``.
    """

    pad_token_id: int
    label_pad_token_id: int = -100

    def __call__(self, features):
        target_len = max(len(f["input_ids"]) for f in features)

        padded = {"input_ids": [], "attention_mask": [], "labels": []}
        for feature in features:
            fill = target_len - len(feature["input_ids"])
            padded["input_ids"].append(feature["input_ids"] + [self.pad_token_id] * fill)
            padded["attention_mask"].append(feature["attention_mask"] + [0] * fill)
            padded["labels"].append(feature["labels"] + [self.label_pad_token_id] * fill)

        return {key: torch.tensor(rows, dtype=torch.long) for key, rows in padded.items()}
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def prepare_model(args: argparse.Namespace):
    """Load the base causal LM and optionally wrap it for LoRA/QLoRA training.

    Returns either the plain model, or a PEFT-wrapped model when --lora or
    --qlora is set.
    """
    model_kwargs = {}
    if args.qlora:
        # 4-bit NF4 quantization; compute dtype follows the precision flag.
        compute_dtype = torch.bfloat16 if args.bf16 else torch.float16
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=compute_dtype,
        )
        # device_map is only set for QLoRA so bitsandbytes can place shards.
        model_kwargs["device_map"] = "auto"

    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=torch.bfloat16 if args.bf16 else (torch.float16 if args.fp16 else None),
        trust_remote_code=True,
        **model_kwargs,
    )
    # KV-cache is disabled for training (incompatible with gradient checkpointing).
    model.config.use_cache = False

    if args.qlora:
        # PEFT helper that readies a quantized model for k-bit fine-tuning.
        model = prepare_model_for_kbit_training(model)

    if args.lora or args.qlora:
        lora_config = LoraConfig(
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
            # Attach adapters to every linear layer rather than a named subset.
            target_modules="all-linear",
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

    return model
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def main() -> None:
    """End-to-end fine-tune: tokenize data, build the model, train, save artifacts."""
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    set_seed(args.seed)

    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=True, trust_remote_code=True)
    # Right padding matches CausalDataCollator, which appends pad tokens at the end.
    tokenizer.padding_side = "right"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    train_ds = load_jsonl(args.train_file)
    # Evaluation is optional: skipped entirely when the valid file is absent.
    valid_ds = load_jsonl(args.valid_file) if args.valid_file and Path(args.valid_file).exists() else None

    train_ds = train_ds.map(
        lambda row: tokenize_example(row, tokenizer, args.max_length),
        remove_columns=train_ds.column_names,
        desc="Tokenizing train set",
    )
    if valid_ds is not None:
        valid_ds = valid_ds.map(
            lambda row: tokenize_example(row, tokenizer, args.max_length),
            remove_columns=valid_ds.column_names,
            desc="Tokenizing valid set",
        )

    model = prepare_model(args)

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        overwrite_output_dir=True,
        num_train_epochs=args.num_train_epochs,
        learning_rate=args.learning_rate,
        per_device_train_batch_size=args.train_batch_size,
        per_device_eval_batch_size=args.eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        warmup_ratio=args.warmup_ratio,
        weight_decay=args.weight_decay,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        eval_steps=args.eval_steps,
        # Periodic eval only when a validation split exists.
        evaluation_strategy="steps" if valid_ds is not None else "no",
        save_strategy="steps",
        bf16=args.bf16,
        fp16=args.fp16,
        report_to="none",
        gradient_checkpointing=True,
        lr_scheduler_type="cosine",
        # Paged optimizer pairs with adapter/quantized training; plain AdamW otherwise.
        optim="paged_adamw_32bit" if (args.lora or args.qlora) else "adamw_torch",
        max_grad_norm=1.0,
        push_to_hub=args.push_to_hub,
        hub_model_id=args.hub_model_id,
        hub_private_repo=args.hub_private_repo,
        hub_strategy="end" if args.push_to_hub else "every_save",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=valid_ds,
        data_collator=CausalDataCollator(pad_token_id=tokenizer.pad_token_id),
        tokenizer=tokenizer,
    )

    train_result = trainer.train()
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)

    # Persist training metrics alongside the model weights.
    metrics = train_result.metrics
    with open(Path(args.output_dir) / "training_metrics.json", "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)

    # Record the full CLI configuration plus dataset sizes for reproducibility.
    run_meta = vars(args).copy()
    run_meta["train_examples"] = len(train_ds)
    run_meta["valid_examples"] = len(valid_ds) if valid_ds is not None else 0
    with open(Path(args.output_dir) / "run_config.json", "w", encoding="utf-8") as f:
        json.dump(run_meta, f, indent=2)

    if args.push_to_hub:
        trainer.push_to_hub(commit_message="Add GravityLLM fine-tuned adapter")
    print(f"Training complete. Artifacts saved to: {args.output_dir}")


if __name__ == "__main__":
    main()
|
upload_to_hub.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
from huggingface_hub import HfApi
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def parse_args() -> argparse.Namespace:
    """Build and parse the upload CLI (folder, target repo, visibility)."""
    cli = argparse.ArgumentParser(description="Upload a local GravityLLM folder to the Hugging Face Hub.")
    cli.add_argument("--folder_path", type=Path, required=True, help="Local folder to upload.")
    cli.add_argument("--repo_id", type=str, required=True, help="Namespace/repo-name on Hugging Face.")
    cli.add_argument("--repo_type", type=str, default="model", choices=["model", "dataset", "space"])
    cli.add_argument("--private", action="store_true")
    cli.add_argument("--commit_message", type=str, default="Upload GravityLLM artifacts")
    return cli.parse_args()
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def main() -> None:
    """Create the target repo if needed, then upload the folder contents."""
    args = parse_args()
    hub = HfApi()
    # exist_ok=True makes repeat uploads to the same repo idempotent.
    hub.create_repo(repo_id=args.repo_id, repo_type=args.repo_type, private=args.private, exist_ok=True)
    hub.upload_folder(
        folder_path=str(args.folder_path),
        repo_id=args.repo_id,
        repo_type=args.repo_type,
        commit_message=args.commit_message,
    )
    print(f"Uploaded {args.folder_path} to https://huggingface.co/{args.repo_id}")


if __name__ == "__main__":
    main()
|