from __future__ import annotations

import json
import os
import re
from functools import lru_cache
from typing import Any

from parser import PromptSpec, merge_prompt_specs, parse_prompt
# `spaces` is only available on Hugging Face Spaces (ZeroGPU). Fall back to a
# no-op shim so the module also runs locally.
try:
    import spaces  # type: ignore
except Exception:  # pragma: no cover
    class _SpacesShim:
        @staticmethod
        def GPU(*args, **kwargs):
            # Support both the bare @spaces.GPU form and @spaces.GPU(duration=...).
            if len(args) == 1 and callable(args[0]) and not kwargs:
                return args[0]

            def decorator(fn):
                return fn

            return decorator

    spaces = _SpacesShim()  # type: ignore
# Default local model; override with the PB3D_LOCAL_MODEL environment variable.
DEFAULT_LOCAL_MODEL = os.getenv("PB3D_LOCAL_MODEL", "Qwen/Qwen2.5-1.5B-Instruct")

MODEL_PRESETS = {
    "Qwen 2.5 1.5B": "Qwen/Qwen2.5-1.5B-Instruct",
    "SmolLM2 1.7B": "HuggingFaceTB/SmolLM2-1.7B-Instruct",
}
# Compact schema description embedded in the LLM prompt; each field is either
# an allowed enum list or a bounded numeric range.
JSON_SCHEMA_HINT = {
    "object_type": ["cargo_hauler", "fighter", "shuttle", "freighter", "dropship", "drone"],
    "scale": ["small", "medium", "large"],
    "hull_style": ["boxy", "rounded", "sleek"],
    "engine_count": "integer 1-6",
    "wing_span": "float 0.0-0.6",
    "cargo_ratio": "float 0.0-0.65",
    "cockpit_ratio": "float 0.10-0.30",
    "fin_height": "float 0.0-0.3",
    "landing_gear": "boolean",
    "asymmetry": "float 0.0-0.2",
    "notes": "short string",
}
def _clamp(value: float, low: float, high: float) -> float:
    return max(low, min(high, value))
@lru_cache(maxsize=2)
def _load_generation_components(model_id: str):
    """Load the tokenizer/model pair once per model_id; cached for reuse."""
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # bfloat16 on GPU, float32 on CPU for the small instruct models above.
    torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        device_map="auto",
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )
    return tokenizer, model
def _generate_structured_json(prompt: str, model_id: str) -> dict[str, Any]:
    """Ask the local model to emit a JSON design spec and parse it."""
    import torch

    tokenizer, model = _load_generation_components(model_id)
    system = (
        "You are a compact design parser for a procedural 3D generator. "
        "Convert the user request into a single JSON object and output JSON only."
    )
    user = (
        "Return a JSON object using this schema: "
        f"{json.dumps(JSON_SCHEMA_HINT)}\n"
        "Rules: choose the closest allowed enum values, stay conservative, "
        "infer hard-surface sci-fi vehicle structure, never explain anything, "
        "never use markdown fences, and keep notes brief.\n"
        f"Prompt: {prompt}"
    )
    messages = [
        {"role": "system", "content": system},
        {"role": "user", "content": user},
    ]
    if hasattr(tokenizer, "apply_chat_template"):
        rendered = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    else:
        rendered = f"System: {system}\nUser: {user}\nAssistant:"

    inputs = tokenizer(rendered, return_tensors="pt")
    model_device = getattr(model, "device", None)
    if model_device is not None:
        inputs = {k: v.to(model_device) for k, v in inputs.items()}

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=220,
            do_sample=False,  # greedy decoding keeps the output deterministic
            temperature=None,  # explicitly unset so greedy mode emits no warnings
            top_p=None,
            repetition_penalty=1.02,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, then extract the first JSON object.
    new_tokens = output[0][inputs["input_ids"].shape[1]:]
    text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    match = re.search(r"\{.*\}", text, flags=re.S)
    if not match:
        raise ValueError("Local model did not return JSON.")
    return json.loads(match.group(0))
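# Illustrative example (not from the source) of a payload the model might
# return, shaped by JSON_SCHEMA_HINT:
#   {"object_type": "fighter", "scale": "small", "hull_style": "sleek",
#    "engine_count": 2, "wing_span": 0.45, "cargo_ratio": 0.1,
#    "cockpit_ratio": 0.15, "fin_height": 0.2, "landing_gear": true,
#    "asymmetry": 0.05, "notes": "agile interceptor"}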
def _normalize_llm_payload(payload: dict[str, Any], original_prompt: str) -> PromptSpec:
    """Coerce the raw LLM payload into a PromptSpec, clamping every field."""

    def get_str(name: str, default: str) -> str:
        value = str(payload.get(name, default)).strip().lower()
        return value or default

    def get_int(name: str, default: int, low: int, high: int) -> int:
        try:
            return int(_clamp(int(payload.get(name, default)), low, high))
        except Exception:
            return default

    def get_float(name: str, default: float, low: float, high: float) -> float:
        try:
            return float(_clamp(float(payload.get(name, default)), low, high))
        except Exception:
            return default

    # The model may return booleans as strings ("true", "yes", ...).
    landing_raw = payload.get("landing_gear", True)
    if isinstance(landing_raw, bool):
        landing_gear = landing_raw
    else:
        landing_gear = str(landing_raw).strip().lower() in {"1", "true", "yes", "y"}

    return PromptSpec(
        object_type=get_str("object_type", "cargo_hauler"),
        scale=get_str("scale", "small"),
        hull_style=get_str("hull_style", "boxy"),
        engine_count=get_int("engine_count", 2, 1, 6),
        wing_span=get_float("wing_span", 0.2, 0.0, 0.6),
        cargo_ratio=get_float("cargo_ratio", 0.38, 0.0, 0.65),
        cockpit_ratio=get_float("cockpit_ratio", 0.18, 0.10, 0.30),
        fin_height=get_float("fin_height", 0.0, 0.0, 0.3),
        landing_gear=landing_gear,
        asymmetry=get_float("asymmetry", 0.0, 0.0, 0.2),
        notes=str(payload.get("notes", original_prompt)).strip() or original_prompt,
    )
def parse_prompt_with_local_llm(prompt: str, model_id: str | None = None) -> PromptSpec:
    """Combine the heuristic parse with the LLM's structured spec."""
    model_id = model_id or DEFAULT_LOCAL_MODEL
    heuristic = parse_prompt(prompt)
    payload = _generate_structured_json(prompt=prompt, model_id=model_id)
    llm_spec = _normalize_llm_payload(payload, original_prompt=prompt)
    return merge_prompt_specs(heuristic, llm_spec)
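
# Minimal usage sketch (assumption: this __main__ block is not part of the
# original module, and the demo prompt below is illustrative).
if __name__ == "__main__":
    spec = parse_prompt_with_local_llm("a sleek small fighter with four engines")
    print(spec)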