Text Generation
Transformers
PEFT
English
gravityllm
spatial-audio
immersive-audio
spatial9
iamf
instruction-tuning
json
lora
qlora
Instructions to use Spatial9/GravityLLM with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Spatial9/GravityLLM with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline
pipe = pipeline("text-generation", model="Spatial9/GravityLLM")
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("Spatial9/GravityLLM", dtype="auto")
- PEFT
How to use Spatial9/GravityLLM with PEFT:
Task type is invalid.
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- vLLM
How to use Spatial9/GravityLLM with vLLM:
Install from pip and serve model
# Install vLLM from pip:
pip install vllm
# Start the vLLM server:
vllm serve "Spatial9/GravityLLM"
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "Spatial9/GravityLLM",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
Use Docker
docker model run hf.co/Spatial9/GravityLLM
- SGLang
How to use Spatial9/GravityLLM with SGLang:
Install from pip and serve model
# Install SGLang from pip:
pip install sglang
# Start the SGLang server:
python3 -m sglang.launch_server \
    --model-path "Spatial9/GravityLLM" \
    --host 0.0.0.0 \
    --port 30000
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "Spatial9/GravityLLM",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
Use Docker images
docker run --gpus all \
    --shm-size 32g \
    -p 30000:30000 \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    --env "HF_TOKEN=<secret>" \
    --ipc=host \
    lmsysorg/sglang:latest \
    python3 -m sglang.launch_server \
        --model-path "Spatial9/GravityLLM" \
        --host 0.0.0.0 \
        --port 30000
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "Spatial9/GravityLLM",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
- Docker Model Runner
How to use Spatial9/GravityLLM with Docker Model Runner:
docker model run hf.co/Spatial9/GravityLLM
| import argparse | |
| import json | |
| import re | |
| from pathlib import Path | |
| from typing import Dict, Tuple | |
| import torch | |
| from datasets import load_dataset | |
| from jsonschema import Draft7Validator | |
| from peft import AutoPeftModelForCausalLM | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
# Instruction header that format_prompt() places in front of every dataset prompt.
# The trailing blank line separates it from the prompt body.
SYSTEM_PREFIX = "".join(
    (
        "You are GravityLLM, a Spatial9 scene generation model. ",
        "Given music constraints and stem features, output ONLY valid Spatial9Scene JSON. ",
        "Do not return markdown. Do not explain your answer.\n\n",
    )
)
| def parse_args() -> argparse.Namespace: | |
| parser = argparse.ArgumentParser(description="Evaluate GravityLLM outputs on a JSONL validation set.") | |
| parser.add_argument("--model_dir", type=str, required=True) | |
| parser.add_argument("--data_file", type=str, default="data/valid.jsonl") | |
| parser.add_argument("--schema_path", type=Path, default=Path("schemas/scene.schema.json")) | |
| parser.add_argument("--max_new_tokens", type=int, default=900) | |
| parser.add_argument("--temperature", type=float, default=0.2) | |
| parser.add_argument("--top_p", type=float, default=0.9) | |
| parser.add_argument("--limit", type=int, default=0, help="0 means evaluate all rows.") | |
| parser.add_argument("--report_path", type=Path, default=Path("reports/eval_report.json")) | |
| return parser.parse_args() | |
| def load_model_and_tokenizer(model_dir: str): | |
| tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True, trust_remote_code=True) | |
| if tokenizer.pad_token is None: | |
| tokenizer.pad_token = tokenizer.eos_token | |
| try: | |
| model = AutoPeftModelForCausalLM.from_pretrained( | |
| model_dir, | |
| torch_dtype=torch.bfloat16 if torch.cuda.is_available() else None, | |
| device_map="auto" if torch.cuda.is_available() else None, | |
| trust_remote_code=True, | |
| ) | |
| except Exception: | |
| model = AutoModelForCausalLM.from_pretrained( | |
| model_dir, | |
| torch_dtype=torch.bfloat16 if torch.cuda.is_available() else None, | |
| device_map="auto" if torch.cuda.is_available() else None, | |
| trust_remote_code=True, | |
| ) | |
| model.eval() | |
| return model, tokenizer | |
| def format_prompt(raw_prompt: str) -> str: | |
| raw_prompt = raw_prompt.strip() | |
| if raw_prompt.lower().startswith("gravityllm:"): | |
| raw_prompt = raw_prompt.split(":", 1)[1].strip() | |
| return SYSTEM_PREFIX + raw_prompt + "\n\nOUTPUT:\n" | |
| def extract_first_json(text: str) -> str: | |
| match = re.search(r"\{.*\}", text, flags=re.DOTALL) | |
| return match.group(0).strip() if match else text.strip() | |
| def validate_schema(schema, output_text: str) -> Tuple[bool, Dict]: | |
| data = json.loads(output_text) | |
| validator = Draft7Validator(schema) | |
| errors = sorted(validator.iter_errors(data), key=lambda e: list(e.path)) | |
| return len(errors) == 0, data | |
| def check_budget(input_payload: Dict, scene_payload: Dict) -> bool: | |
| max_objects = input_payload.get("max_objects") | |
| if max_objects is None: | |
| return True | |
| return len(scene_payload.get("objects", [])) <= max_objects | |
| def check_anchor_rules(input_payload: Dict, scene_payload: Dict) -> bool: | |
| objects = {obj["class"]: obj for obj in scene_payload.get("objects", [])} | |
| for rule in input_payload.get("rules", []): | |
| if rule.get("type") != "anchor": | |
| continue | |
| klass = rule.get("track_class") | |
| obj = objects.get(klass) | |
| if obj is None: | |
| return False | |
| for field in ["az_deg", "el_deg", "dist_m"]: | |
| if float(obj[field]) != float(rule[field]): | |
| return False | |
| return True | |
| def generate_scene(model, tokenizer, prompt_text: str, max_new_tokens: int, temperature: float, top_p: float) -> str: | |
| inputs = tokenizer(prompt_text, return_tensors="pt") | |
| if torch.cuda.is_available(): | |
| inputs = {k: v.to(model.device) for k, v in inputs.items()} | |
| with torch.no_grad(): | |
| outputs = model.generate( | |
| **inputs, | |
| max_new_tokens=max_new_tokens, | |
| do_sample=True, | |
| temperature=temperature, | |
| top_p=top_p, | |
| eos_token_id=tokenizer.eos_token_id, | |
| pad_token_id=tokenizer.pad_token_id, | |
| ) | |
| decoded = tokenizer.decode(outputs[0], skip_special_tokens=True) | |
| prompt_prefix = tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True) | |
| raw_completion = decoded[len(prompt_prefix):].strip() | |
| return extract_first_json(raw_completion) | |
| def main() -> None: | |
| args = parse_args() | |
| schema = json.loads(args.schema_path.read_text(encoding="utf-8")) | |
| ds = load_dataset("json", data_files=args.data_file, split="train") | |
| if args.limit > 0: | |
| ds = ds.select(range(min(args.limit, len(ds)))) | |
| model, tokenizer = load_model_and_tokenizer(args.model_dir) | |
| total = len(ds) | |
| parse_ok = 0 | |
| schema_ok = 0 | |
| budget_ok = 0 | |
| anchor_ok = 0 | |
| samples = [] | |
| for row in ds: | |
| prompt_text = format_prompt(row["prompt"]) | |
| generated = generate_scene(model, tokenizer, prompt_text, args.max_new_tokens, args.temperature, args.top_p) | |
| sample_report = {"prompt": row["prompt"], "generated": generated} | |
| try: | |
| gen_data = json.loads(generated) | |
| parse_ok += 1 | |
| valid, gen_scene = validate_schema(schema, generated) | |
| if valid: | |
| schema_ok += 1 | |
| # Reconstruct input payload from prompt for simple rule checks. | |
| prompt_payload_text = row["prompt"].split("INPUT:\n", 1)[1] | |
| input_payload = json.loads(prompt_payload_text) | |
| if check_budget(input_payload, gen_scene): | |
| budget_ok += 1 | |
| if check_anchor_rules(input_payload, gen_scene): | |
| anchor_ok += 1 | |
| sample_report["schema_valid"] = True | |
| sample_report["budget_pass"] = check_budget(input_payload, gen_scene) | |
| sample_report["anchor_pass"] = check_anchor_rules(input_payload, gen_scene) | |
| else: | |
| sample_report["schema_valid"] = False | |
| except Exception as exc: | |
| sample_report["error"] = str(exc) | |
| samples.append(sample_report) | |
| report = { | |
| "examples": total, | |
| "json_parse_rate": round(parse_ok / total, 4) if total else 0.0, | |
| "schema_valid_rate": round(schema_ok / total, 4) if total else 0.0, | |
| "budget_pass_rate": round(budget_ok / total, 4) if total else 0.0, | |
| "anchor_pass_rate": round(anchor_ok / total, 4) if total else 0.0, | |
| "samples": samples[:10], | |
| } | |
| args.report_path.parent.mkdir(parents=True, exist_ok=True) | |
| args.report_path.write_text(json.dumps(report, indent=2), encoding="utf-8") | |
| print(json.dumps(report, indent=2)) | |
| if __name__ == "__main__": | |
| main() | |