# NOTE: residue from the Hugging Face web file viewer, not part of the program:
# uploaded by lzanardos9 ("Upload 16 files"), commit 45132f0 (verified), 15.9 kB.
from __future__ import annotations
import base64
import json
import os
import tempfile
from pathlib import Path
from typing import Any, Dict
import gradio as gr
from huggingface_hub import InferenceClient
from huggingface_hub.errors import HfHubHTTPError, InferenceTimeoutError
from utils.scene_tools import (
SCHEMA,
extract_first_json_block,
heuristic_scene,
parse_json_text,
plot_scene,
scene_markdown,
scene_table,
validate_scene,
)
# Filesystem layout: bundled assets and example payloads sit next to this file.
ROOT = Path(__file__).resolve().parent
ASSETS = ROOT / "assets"
EXAMPLES = ROOT / "examples"

APP_TITLE = "GravityLLM"

# Runtime configuration, overridable through environment variables.
DEFAULT_MODEL_ID = os.environ.get("GRAVITYLLM_MODEL_ID", "your-namespace/GravityLLM-AutoPosition")
DEFAULT_BACKEND = os.environ.get("GRAVITYLLM_BACKEND", "hybrid")
# HF_TOKEN wins; HUGGINGFACEHUB_API_TOKEN is consulted only when HF_TOKEN is unset/empty.
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACEHUB_API_TOKEN")

# Instruction text prepended to every remote prompt.
SYSTEM_PREFIX = (
    "You are GravityLLM (Spatial9 AutoPosition SLM). "
    "Generate ONLY valid JSON matching the Spatial9Scene schema. "
    "No markdown. No explanation. No code fences.\n\n"
)

# Dropdown label -> example payload file (insertion order drives the dropdown order).
EXAMPLE_FILES = {
    label: EXAMPLES / file_name
    for label, file_name in (
        ("Club Drop", "club_drop.json"),
        ("Cinematic Break", "cinematic_break.json"),
        ("Podcast Voice", "podcast_voice.json"),
    )
}
def logo_data_uri() -> str:
    """Return ``assets/spatial9_logo.png`` encoded as an inline base64 ``data:`` URI.

    Raises:
        OSError: if the logo file cannot be read.
    """
    uri_prefix = "data:image/png;base64,"
    raw_png = (ASSETS / "spatial9_logo.png").read_bytes()
    encoded = base64.b64encode(raw_png).decode("utf-8")
    return uri_prefix + encoded
def load_example(name: str) -> str:
    """Return the UTF-8 text of the example payload labelled *name*.

    Unknown labels fall back to the first entry of ``EXAMPLE_FILES``.
    """
    first_example = next(iter(EXAMPLE_FILES.values()))
    chosen = EXAMPLE_FILES.get(name, first_example)
    return chosen.read_text(encoding="utf-8")
def build_prompt(payload: Dict[str, Any]) -> str:
    """Assemble the completion prompt: system prefix, pretty-printed INPUT, then an open OUTPUT slot."""
    return "".join(
        (
            SYSTEM_PREFIX,
            "INPUT:\n",
            json.dumps(payload, ensure_ascii=False, indent=2),
            "\nOUTPUT:\n",
        )
    )
def remote_generate(
    payload: Dict[str, Any],
    model_id: str,
    temperature: float,
    top_p: float,
    max_new_tokens: int,
    use_grammar: bool,
) -> tuple[Dict[str, Any], str]:
    """Generate a scene with a remote model via ``InferenceClient.text_generation``.

    Args:
        payload: Constraint + stem payload, serialized into the prompt by ``build_prompt``.
        model_id: Hugging Face model repo id or inference endpoint URL.
        temperature: Sampling temperature. Values <= 0 (or ``None``) request greedy
            decoding instead of being sent as-is.
        top_p: Nucleus-sampling threshold. Sent only when sampling and strictly
            between 0 and 1.
        max_new_tokens: Generation budget (coerced to ``int``).
        use_grammar: When True, first try a JSON-schema grammar-constrained call.
            If it fails, retry once without the grammar.

    Returns:
        ``(scene, backend_label)``, where ``scene`` is the parsed JSON reply.

    Raises:
        RuntimeError: ``use_grammar`` is True and both attempts failed. The message
            carries both errors.
        Exception: ``use_grammar`` is False and the single attempt failed. The
            original client/parsing exception is re-raised unchanged.
    """
    prompt = build_prompt(payload)
    client = InferenceClient(model=model_id, token=HF_TOKEN)
    call_kwargs: Dict[str, Any] = dict(
        prompt=prompt,
        model=model_id,
        max_new_tokens=int(max_new_tokens),
        repetition_penalty=1.05,
        return_full_text=False,
    )
    # Text Generation Inference rejects temperature <= 0 and top_p outside the open
    # interval (0, 1), yet the UI sliders allow 0.0 and 1.0. Translate those edge
    # values instead of sending requests that are guaranteed to be rejected.
    if temperature is not None and temperature > 0:
        call_kwargs["temperature"] = float(temperature)
        if top_p is not None and 0.0 < top_p < 1.0:
            call_kwargs["top_p"] = float(top_p)
    else:
        # Temperature 0 means "deterministic": decode greedily and omit top_p,
        # which would otherwise switch sampling back on.
        call_kwargs["do_sample"] = False
    if use_grammar:
        call_kwargs["grammar"] = {"type": "json", "value": SCHEMA}

    try:
        response = client.text_generation(**call_kwargs)
        scene = parse_json_text(response)
        return scene, f"remote-model ({model_id})"
    except Exception as first_error:
        if not use_grammar:
            raise
        # The endpoint may not support grammar-constrained decoding: retry once without it.
        call_kwargs.pop("grammar", None)
        try:
            response = client.text_generation(**call_kwargs)
            scene = parse_json_text(response)
            return scene, f"remote-model ({model_id}, grammar-fallback)"
        except Exception as second_error:
            raise RuntimeError(
                f"{type(first_error).__name__}: {first_error}\n\n"
                f"Fallback: {type(second_error).__name__}: {second_error}"
            ) from second_error
def write_download_file(scene: Dict[str, Any]) -> str:
    """Write *scene* as pretty-printed UTF-8 JSON to a fresh temp file and return its path.

    The file is named ``gravityllm_scene_*.json`` and is not deleted by this function.
    """
    handle, out_path = tempfile.mkstemp(prefix="gravityllm_scene_", suffix=".json")
    # Reuse the descriptor mkstemp opened; the context manager closes it on every path.
    with os.fdopen(handle, "w", encoding="utf-8") as out_file:
        out_file.write(json.dumps(scene, ensure_ascii=False, indent=2))
    return out_path
def generate_scene(
    payload_text: str,
    model_id: str,
    backend: str,
    temperature: float,
    top_p: float,
    max_new_tokens: int,
    use_grammar: bool,
):
    """Handle the "Generate Spatial Scene" click.

    The handler parses the payload, gets a scene from the selected backend,
    validates it, and prepares every output component.

    Backends:
        - ``"remote-model"``: remote inference only. Failures are reported in the UI.
        - ``"hybrid"``: remote inference, falling back to ``heuristic_scene`` on any failure.
        - anything else (e.g. ``"rules-engine demo"``): ``heuristic_scene`` only.

    Returns:
        A 6-tuple ``(scene_json_text, summary_markdown, figure, table_rows,
        download_path, status_text)``. On invalid input, or when remote-only
        generation fails, the scene/figure/download slots are empty.
    """
    try:
        payload = parse_json_text(payload_text)
    except Exception as exc:
        msg = f"### Invalid input JSON\n\n- {type(exc).__name__}: {exc}"
        return "", msg, None, [], None, "Fix the input JSON and try again."
    # The prompt builder is annotated to take a dict payload. Reject arrays and
    # scalars early with a readable message instead of failing further down.
    if not isinstance(payload, dict):
        msg = (
            "### Invalid input JSON\n\n"
            f"- Expected a JSON object at the top level, got {type(payload).__name__}."
        )
        return "", msg, None, [], None, "Fix the input JSON and try again."

    backend = backend or DEFAULT_BACKEND
    model_id = (model_id or "").strip() or DEFAULT_MODEL_ID
    scene = None
    backend_used = "rules-engine demo"
    status = "Scene generated."

    if backend in {"remote-model", "hybrid"}:
        try:
            scene, backend_used = remote_generate(
                payload, model_id, temperature, top_p, max_new_tokens, use_grammar
            )
            status = f"Generated from remote model: {model_id}"
        except Exception as exc:
            # Catch broadly at this UI boundary. With use_grammar=False,
            # remote_generate re-raises whatever the client or parser raised
            # (e.g. connection errors). A narrow tuple let those escape and
            # skip the hybrid fallback entirely.
            if backend == "remote-model":
                msg = f"### Remote generation failed\n\n- {type(exc).__name__}: {exc}"
                return "", msg, None, [], None, "Remote inference failed."
            scene = heuristic_scene(payload)
            backend_used = "rules-engine demo (remote fallback)"
            status = f"Remote generation failed; heuristic fallback used. Details: {type(exc).__name__}: {exc}"

    if scene is None:
        scene = heuristic_scene(payload)

    valid, errors = validate_scene(scene)
    download_path = write_download_file(scene)
    figure = plot_scene(scene)
    table = scene_table(scene)
    summary = scene_markdown(scene, valid, errors, backend_used)
    return json.dumps(scene, ensure_ascii=False, indent=2), summary, figure, table, download_path, status
def validate_only(scene_text: str):
    """Validate a pasted scene JSON without generating anything.

    Returns:
        ``(summary_markdown, figure, table_rows)``. If the text cannot be
        parsed, returns an error summary with no figure and no rows.
    """
    try:
        parsed_scene = parse_json_text(scene_text)
    except Exception as err:
        error_md = f"### Invalid scene JSON\n\n- {type(err).__name__}: {err}"
        return error_md, None, []
    else:
        is_valid, problems = validate_scene(parsed_scene)
        return (
            scene_markdown(parsed_scene, is_valid, problems, "manual validation"),
            plot_scene(parsed_scene),
            scene_table(parsed_scene),
        )
def build_payload(target_format, style, section, bpm, energy, max_objects):
    """Return a starter payload as pretty-printed JSON text.

    The UI selections are combined with a fixed five-stem set (lead, kick,
    bass, pad, fx) and three default placement rules. ``bpm`` and
    ``max_objects`` are coerced to ``int`` and ``energy`` to ``float``.
    """
    # (id, class, lufs, transient, (low, mid, high) band energy, leadness)
    stem_specs = (
        ("lead", "lead_vocal", -17.0, 0.25, (0.08, 0.67, 0.25), 0.96),
        ("kick", "kick", -10.6, 0.96, (0.82, 0.12, 0.06), 0.22),
        ("bass", "bass", -12.5, 0.58, (0.86, 0.10, 0.04), 0.30),
        ("pad", "pad", -21.5, 0.05, (0.20, 0.50, 0.30), 0.08),
        ("fx", "fx", -24.0, 0.22, (0.10, 0.24, 0.66), 0.04),
    )
    stems = [
        {
            "id": stem_id,
            "class": stem_class,
            "lufs": lufs,
            "transient": transient,
            "band_energy": dict(zip(("low", "mid", "high"), bands)),
            "leadness": leadness,
        }
        for stem_id, stem_class, lufs, transient, bands, leadness in stem_specs
    ]
    default_rules = [
        {"type": "anchor", "track_class": "lead_vocal", "az_deg": 0, "el_deg": 10, "dist_m": 1.6},
        {"type": "mono_low_end", "hz_below": 120},
        {"type": "width_pref", "track_class": "pad", "min_width": 0.75},
    ]
    document = {
        "target_format": target_format,
        "max_objects": int(max_objects),
        "style": style,
        "section": section,
        "global": {"bpm": int(bpm), "energy": float(energy)},
        "stems": stems,
        "rules": default_rules,
    }
    return json.dumps(document, ensure_ascii=False, indent=2)
# Banner markup rendered by gr.HTML at the top of the app. This f-string is
# evaluated once at import time: logo_data_uri() reads the logo PNG from disk
# and inlines it as base64, so a missing/unreadable logo file fails app startup.
# Class names here are styled by the module-level `css` string.
hero_html = f"""
<div class="hero-wrap">
<div class="hero-left">
<div class="hero-logo-card">
<img class="hero-logo" src="{logo_data_uri()}" alt="Spatial9 logo"/>
</div>
</div>
<div class="hero-right">
<div class="eyebrow">SPATIAL9 • HUGGING FACE SPACE</div>
<h1>GravityLLM Studio</h1>
<p class="hero-copy">
Constraint-conditioned immersive scene generation with schema-guided JSON output,
remote Hugging Face inference, heuristic fallback, and a live spatial preview.
</p>
<div class="hero-chips">
<span>IAMF Ready</span>
<span>Schema Validated</span>
<span>Spatial Preview</span>
<span>Branded Demo</span>
</div>
</div>
</div>
"""
# Custom stylesheet passed to gr.Blocks(css=...). It defines the palette as
# CSS variables, a tinted page background, the two-column hero layout (stacked
# below 900px wide), logo card and chip styling, and hides the Gradio footer.
# NOTE(review): `.card-note` and `.block-panel` are not referenced by any
# elem_classes in this file; confirm they are still needed.
css = """
:root {
--g-bg: #f6f9ff;
--g-panel: rgba(255,255,255,0.86);
--g-panel-strong: rgba(255,255,255,0.96);
--g-line: #dbe7f6;
--g-ink: #15233d;
--g-sub: #5f728f;
--g-accent: #1f6fe5;
--g-accent-2: #0f9bb9;
}
.gradio-container {
background:
radial-gradient(circle at top left, rgba(55,120,246,0.10), transparent 32%),
radial-gradient(circle at bottom right, rgba(15,155,185,0.10), transparent 28%),
var(--g-bg);
}
.hero-wrap {
display: grid;
grid-template-columns: 280px 1fr;
gap: 28px;
padding: 22px 8px 12px 8px;
align-items: center;
}
.hero-logo-card {
background: linear-gradient(180deg, rgba(255,255,255,0.96), rgba(247,250,255,0.90));
border: 1px solid var(--g-line);
box-shadow: 0 18px 42px rgba(31,58,114,0.08);
border-radius: 28px;
padding: 24px;
display: flex;
justify-content: center;
align-items: center;
min-height: 170px;
}
.hero-logo {
width: 100%;
max-width: 220px;
object-fit: contain;
}
.hero-right h1 {
font-size: 2.5rem;
margin: 0;
color: var(--g-ink);
}
.hero-copy {
color: var(--g-sub);
font-size: 1.06rem;
line-height: 1.6;
max-width: 780px;
}
.eyebrow {
color: var(--g-accent);
font-size: 0.92rem;
letter-spacing: 0.14em;
font-weight: 700;
margin-bottom: 8px;
}
.hero-chips {
display: flex;
flex-wrap: wrap;
gap: 10px;
margin-top: 14px;
}
.hero-chips span {
background: rgba(239,246,255,0.96);
border: 1px solid #cfe0fb;
color: #28558f;
border-radius: 999px;
padding: 8px 12px;
font-size: 0.9rem;
font-weight: 600;
}
.card-note {
color: var(--g-sub);
}
.block-panel {
background: var(--g-panel);
border: 1px solid var(--g-line);
border-radius: 22px;
padding: 10px;
}
footer {visibility: hidden;}
@media (max-width: 900px) {
.hero-wrap {grid-template-columns: 1fr;}
}
"""
# Gradio UI. Inside each `with` context, components are laid out in creation
# order, so the statement order below defines the on-screen layout.
# NOTE(review): the nesting below was reconstructed from a copy with its
# indentation stripped. Confirm which Tab/Row/Column each component belongs to.
with gr.Blocks(
    title=f"{APP_TITLE} Studio",
    fill_width=True,
    css=css,
    theme=gr.themes.Soft(
        primary_hue="blue",
        secondary_hue="cyan",
        neutral_hue="slate",
        radius_size="lg",
    ),
) as demo:
    # Branded banner built from the module-level `hero_html` markup.
    gr.HTML(hero_html)
    with gr.Tabs():
        # Tab 1: edit a payload, pick a backend/sampling settings, generate a scene.
        with gr.Tab("GravityLLM Studio"):
            with gr.Row():
                with gr.Column(scale=11):
                    example_name = gr.Dropdown(
                        choices=list(EXAMPLE_FILES.keys()),
                        value="Club Drop",
                        label="Example payload",
                    )
                    load_btn = gr.Button("Load Example", variant="secondary")
                    payload_box = gr.Code(
                        value=load_example("Club Drop"),
                        language="json",
                        label="Constraint + stem feature payload",
                        lines=26,
                    )
                with gr.Column(scale=6):
                    model_id = gr.Textbox(
                        value=DEFAULT_MODEL_ID,
                        label="Model repo or endpoint",
                        info="Set your Hugging Face model repo id or inference endpoint URL.",
                    )
                    backend = gr.Dropdown(
                        choices=["hybrid", "remote-model", "rules-engine demo"],
                        # Use "hybrid" when GRAVITYLLM_BACKEND holds a value not offered in `choices`.
                        value=DEFAULT_BACKEND if DEFAULT_BACKEND in {"hybrid", "remote-model", "rules-engine demo"} else "hybrid",
                        label="Backend",
                    )
                    temperature = gr.Slider(0.0, 1.2, value=0.2, step=0.05, label="Temperature")
                    top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
                    max_new_tokens = gr.Slider(128, 1400, value=900, step=16, label="Max new tokens")
                    use_grammar = gr.Checkbox(
                        value=True,
                        label="Use JSON schema grammar when remote backend supports it",
                    )
                    run_btn = gr.Button("Generate Spatial Scene", variant="primary")
                    status = gr.Textbox(label="Status", interactive=False)
            # Results: scene JSON + download on the left; summary, preview and table on the right.
            with gr.Row():
                with gr.Column(scale=9):
                    output_box = gr.Code(language="json", label="Generated Spatial9Scene JSON", lines=26)
                    download = gr.File(label="Download scene JSON")
                with gr.Column(scale=7):
                    summary = gr.Markdown("### Ready\n\nLoad an example or paste your own payload.")
                    plot = gr.Plot(label="Spatial scene preview")
                    object_table = gr.Dataframe(
                        headers=["id", "class", "az_deg", "el_deg", "dist_m", "width", "gain_db"],
                        datatype=["str", "str", "number", "number", "number", "number", "number"],
                        row_count=(0, "dynamic"),
                        col_count=(7, "fixed"),
                        label="Object inspector",
                    )
        # Tab 2: compose a starter payload from a few high-level controls.
        with gr.Tab("Prompt Builder"):
            gr.Markdown("Build a starter payload, then send it to GravityLLM Studio.")
            with gr.Row():
                target_format = gr.Dropdown(["iamf", "binaural", "5.1.4", "7.1.4"], value="iamf", label="Target format")
                style = gr.Dropdown(["club", "cinematic", "podcast", "live", "intimate"], value="club", label="Style")
                section = gr.Dropdown(["intro", "verse", "break", "drop", "full"], value="drop", label="Section")
            with gr.Row():
                bpm = gr.Slider(0, 200, value=128, step=1, label="BPM")
                energy = gr.Slider(0.0, 1.0, value=0.92, step=0.01, label="Energy")
                max_objects_builder = gr.Slider(1, 32, value=10, step=1, label="Max objects")
            build_btn = gr.Button("Build Payload", variant="primary")
            builder_output = gr.Code(language="json", label="Starter payload", lines=24)
            send_to_studio_btn = gr.Button("Send to Studio", variant="secondary")
        # Tab 3: validate and preview an existing scene without generating one.
        with gr.Tab("Validate Existing Scene"):
            scene_input = gr.Code(language="json", label="Paste a Spatial9Scene JSON", lines=24)
            validate_btn = gr.Button("Validate Scene", variant="primary")
            validate_summary = gr.Markdown()
            validate_plot = gr.Plot()
            validate_table = gr.Dataframe(
                headers=["id", "class", "az_deg", "el_deg", "dist_m", "width", "gain_db"],
                datatype=["str", "str", "number", "number", "number", "number", "number"],
                row_count=(0, "dynamic"),
                col_count=(7, "fixed"),
                label="Validated object inspector",
            )
        # Tab 4: static banner and documentation.
        with gr.Tab("About"):
            gr.Image(value=str(ASSETS / "gravityllm_space_banner.png"), label="GravityLLM banner", show_download_button=False, show_fullscreen_button=False)
            gr.Markdown(
                """
### What this Space does
- Turns **constraints + stem descriptors** into **Spatial9Scene JSON**
- Can call a remote Hugging Face model repo through `InferenceClient`
- Falls back to a deterministic **rules engine** so the demo stays usable
- Validates outputs against the included JSON schema
- Renders a spatial top-down preview of object positions
### Environment variables
- `GRAVITYLLM_MODEL_ID` — model repo id or endpoint URL
- `HF_TOKEN` — required if the model is gated or private
- `GRAVITYLLM_BACKEND` — optional default: `hybrid`, `remote-model`, or `rules-engine demo`
### Recommended setup
1. Upload your GravityLLM model repo.
2. Train and push weights.
3. Upload this Space repo.
4. Set `GRAVITYLLM_MODEL_ID` in the Space settings.
"""
            )
    # Event wiring (kept inside the Blocks context).
    load_btn.click(fn=load_example, inputs=example_name, outputs=payload_box)
    build_btn.click(
        fn=build_payload,
        inputs=[target_format, style, section, bpm, energy, max_objects_builder],
        outputs=builder_output,
    )
    # Copy the builder's JSON verbatim into the Studio payload editor.
    send_to_studio_btn.click(fn=lambda x: x, inputs=builder_output, outputs=payload_box)
    # Output order must match generate_scene's 6-tuple return.
    run_btn.click(
        fn=generate_scene,
        inputs=[payload_box, model_id, backend, temperature, top_p, max_new_tokens, use_grammar],
        outputs=[output_box, summary, plot, object_table, download, status],
    )
    # Output order must match validate_only's 3-tuple return.
    validate_btn.click(
        fn=validate_only,
        inputs=scene_input,
        outputs=[validate_summary, validate_plot, validate_table],
    )

# Start the Gradio server only when run as a script.
if __name__ == "__main__":
    demo.launch()