# 3DModelGen / llm_parser.py
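"""Local-LLM prompt parsing for 3DModelGen.

Asks a small instruct model for a JSON design spec, then normalizes the
payload and merges it with the heuristic ``parse_prompt`` result.
"""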
from __future__ import annotations
import json
import os
import re
from functools import lru_cache
from typing import Any
from parser import PromptSpec, merge_prompt_specs, parse_prompt
try:
    import spaces  # type: ignore
except Exception:  # pragma: no cover
    # Fallback for environments without the Hugging Face `spaces` package
    # (e.g. local runs outside a ZeroGPU Space): a no-op stand-in whose GPU
    # decorator returns the wrapped function unchanged.
    class _SpacesShim:
        @staticmethod
        def GPU(*args, **kwargs):
            def decorator(fn):
                return fn
            return decorator

    spaces = _SpacesShim()  # type: ignore
DEFAULT_LOCAL_MODEL = os.getenv("PB3D_LOCAL_MODEL", "Qwen/Qwen2.5-1.5B-Instruct")
MODEL_PRESETS = {
"Qwen 2.5 1.5B": "Qwen/Qwen2.5-1.5B-Instruct",
"SmolLM2 1.7B": "HuggingFaceTB/SmolLM2-1.7B-Instruct",
}
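
# Illustrative use of a preset (hypothetical call site):
#   parse_prompt_with_local_llm(prompt, MODEL_PRESETS["SmolLM2 1.7B"])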
JSON_SCHEMA_HINT = {
"object_type": ["cargo_hauler", "fighter", "shuttle", "freighter", "dropship", "drone"],
"scale": ["small", "medium", "large"],
"hull_style": ["boxy", "rounded", "sleek"],
"engine_count": "integer 1-6",
"wing_span": "float 0.0-0.6",
"cargo_ratio": "float 0.0-0.65",
"cockpit_ratio": "float 0.10-0.30",
"fin_height": "float 0.0-0.3",
"landing_gear": "boolean",
"asymmetry": "float 0.0-0.2",
"notes": "short string",
}
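
# A hand-written payload that matches the schema above (illustrative, not
# captured from a model run):
#   {"object_type": "fighter", "scale": "small", "hull_style": "sleek",
#    "engine_count": 2, "wing_span": 0.45, "cargo_ratio": 0.10,
#    "cockpit_ratio": 0.18, "fin_height": 0.20, "landing_gear": true,
#    "asymmetry": 0.0, "notes": "twin-engine interceptor"}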
def _clamp(value: float, low: float, high: float) -> float:
return max(low, min(high, value))
@lru_cache(maxsize=2)
def _load_generation_components(model_id: str):
    """Load and cache the tokenizer/model pair for ``model_id``."""
    from transformers import AutoModelForCausalLM, AutoTokenizer
    import torch

    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Some instruct models ship without a pad token; reuse EOS so that
        # generation has a valid padding id.
        tokenizer.pad_token = tokenizer.eos_token
    has_cuda = torch.cuda.is_available()
    # bfloat16 on GPU keeps memory low; fall back to float32 on CPU.
    torch_dtype = torch.bfloat16 if has_cuda else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        device_map="auto",
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )
    return tokenizer, model
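
# Repeated calls with the same model id reuse the cached tokenizer/model pair;
# maxsize=2 lets both bundled presets stay resident at once.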
@spaces.GPU(duration=45)
def _generate_structured_json(prompt: str, model_id: str) -> dict[str, Any]:
    """Ask the local model to emit a schema-shaped JSON object for ``prompt``."""
    import torch

    tokenizer, model = _load_generation_components(model_id)
    system = (
        "You are a compact design parser for a procedural 3D generator. "
        "Convert the user request into a single JSON object and output JSON only."
    )
user = (
"Return a JSON object using this schema: "
f"{json.dumps(JSON_SCHEMA_HINT)}\n"
"Rules: choose the closest allowed enum values, stay conservative, infer hard-surface sci-fi vehicle structure, "
"never explain anything, never use markdown fences, and keep notes brief.\n"
f"Prompt: {prompt}"
)
messages = [
{"role": "system", "content": system},
{"role": "user", "content": user},
]
    if hasattr(tokenizer, "apply_chat_template"):
        rendered = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    else:
        # Tokenizers without a chat template get a plain transcript prompt.
        rendered = f"System: {system}\nUser: {user}\nAssistant:"
inputs = tokenizer(rendered, return_tensors="pt")
model_device = getattr(model, "device", None)
if model_device is not None:
inputs = {k: v.to(model_device) for k, v in inputs.items()}
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=220,
            # Greedy decoding for reproducible parses; temperature/top_p are
            # passed explicitly as None to override any sampling defaults in
            # the model's generation config.
            do_sample=False,
            temperature=None,
            top_p=None,
            repetition_penalty=1.02,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated continuation, then pull out the first
    # brace-delimited block in case the model added any stray text.
    new_tokens = output[0][inputs["input_ids"].shape[1]:]
    text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    match = re.search(r"\{.*\}", text, flags=re.S)
    if not match:
        raise ValueError("Local model did not return JSON.")
    return json.loads(match.group(0))
def _normalize_llm_payload(payload: dict[str, Any], original_prompt: str) -> PromptSpec:
    """Coerce and clamp a raw LLM payload into a valid ``PromptSpec``."""

    def get_str(name: str, default: str) -> str:
        value = str(payload.get(name, default)).strip().lower()
        return value or default
def get_int(name: str, default: int, low: int, high: int) -> int:
try:
return int(_clamp(int(payload.get(name, default)), low, high))
except Exception:
return default
def get_float(name: str, default: float, low: float, high: float) -> float:
try:
return float(_clamp(float(payload.get(name, default)), low, high))
except Exception:
return default
    # Accept a real boolean or common truthy strings from the model.
    landing_raw = payload.get("landing_gear", True)
    if isinstance(landing_raw, bool):
        landing_gear = landing_raw
    else:
        landing_gear = str(landing_raw).strip().lower() in {"1", "true", "yes", "y"}
return PromptSpec(
object_type=get_str("object_type", "cargo_hauler"),
scale=get_str("scale", "small"),
hull_style=get_str("hull_style", "boxy"),
engine_count=get_int("engine_count", 2, 1, 6),
wing_span=get_float("wing_span", 0.2, 0.0, 0.6),
cargo_ratio=get_float("cargo_ratio", 0.38, 0.0, 0.65),
cockpit_ratio=get_float("cockpit_ratio", 0.18, 0.10, 0.30),
fin_height=get_float("fin_height", 0.0, 0.0, 0.3),
landing_gear=landing_gear,
asymmetry=get_float("asymmetry", 0.0, 0.0, 0.2),
notes=str(payload.get("notes", original_prompt)).strip() or original_prompt,
)
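
# For example (illustrative): a payload of {"engine_count": 12, "wing_span": "2.0",
# "landing_gear": "yes"} normalizes to engine_count=6, wing_span=0.6 and
# landing_gear=True, with all other fields falling back to their defaults.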
def parse_prompt_with_local_llm(prompt: str, model_id: str | None = None) -> PromptSpec:
    """Parse ``prompt`` with the local LLM and merge it with the heuristic parse.

    Raises if the model fails to return JSON; callers can fall back to
    ``parse_prompt`` alone in that case.
    """
    model_id = model_id or DEFAULT_LOCAL_MODEL
    heuristic = parse_prompt(prompt)
    payload = _generate_structured_json(prompt=prompt, model_id=model_id)
    llm_spec = _normalize_llm_payload(payload, original_prompt=prompt)
    return merge_prompt_specs(heuristic, llm_spec)
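

if __name__ == "__main__":
    # Minimal manual smoke test (a sketch: assumes the sibling `parser` module
    # is importable and the default model weights can be downloaded; the first
    # run fetches them from the Hugging Face Hub).
    spec = parse_prompt_with_local_llm("a sleek twin-engine fighter with landing gear")
    print(spec)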