File size: 5,755 Bytes
77e37fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
from __future__ import annotations

import json
import os
import re
from functools import lru_cache
from typing import Any

from parser import PromptSpec, merge_prompt_specs, parse_prompt

try:
    import spaces  # type: ignore
except Exception:  # pragma: no cover
    # Outside a Hugging Face Space the `spaces` package is absent; install a
    # stand-in whose GPU(...) is a no-op decorator factory so the module-level
    # @spaces.GPU(...) usage below keeps working unchanged.
    class _SpacesShim:
        @staticmethod
        def GPU(*_args, **_kwargs):
            return lambda fn: fn

    spaces = _SpacesShim()  # type: ignore


# Default local model id; overridable via the PB3D_LOCAL_MODEL env var.
DEFAULT_LOCAL_MODEL = os.getenv("PB3D_LOCAL_MODEL", "Qwen/Qwen2.5-1.5B-Instruct")
# Human-readable preset label -> Hugging Face model id.
MODEL_PRESETS = {
    "Qwen 2.5 1.5B": "Qwen/Qwen2.5-1.5B-Instruct",
    "SmolLM2 1.7B": "HuggingFaceTB/SmolLM2-1.7B-Instruct",
}

# Schema hint serialized into the LLM prompt. Values describe the allowed
# enums/ranges the model should emit; the same ranges are enforced again by
# clamping when the payload is normalized into a PromptSpec.
JSON_SCHEMA_HINT = {
    "object_type": ["cargo_hauler", "fighter", "shuttle", "freighter", "dropship", "drone"],
    "scale": ["small", "medium", "large"],
    "hull_style": ["boxy", "rounded", "sleek"],
    "engine_count": "integer 1-6",
    "wing_span": "float 0.0-0.6",
    "cargo_ratio": "float 0.0-0.65",
    "cockpit_ratio": "float 0.10-0.30",
    "fin_height": "float 0.0-0.3",
    "landing_gear": "boolean",
    "asymmetry": "float 0.0-0.2",
    "notes": "short string",
}


def _clamp(value: float, low: float, high: float) -> float:
    return max(low, min(high, value))


@lru_cache(maxsize=2)
def _load_generation_components(model_id: str):
    """Load and cache a (tokenizer, model) pair for *model_id*.

    Cached via lru_cache(maxsize=2) so repeated parses reuse already-loaded
    weights. Uses bfloat16 when CUDA is available, float32 otherwise, and
    lets `device_map="auto"` place the model.
    """
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    # Some tokenizers ship without a pad token; reuse EOS so padding works.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    mdl = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=dtype,
        device_map="auto",
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )
    return tok, mdl


@spaces.GPU(duration=45)
def _generate_structured_json(prompt: str, model_id: str) -> dict[str, Any]:
    """Ask the local LLM to turn *prompt* into a schema-shaped JSON dict.

    Renders a system+user chat asking for JSON only, runs deterministic
    (greedy) decoding, and extracts the first ``{...}`` span from the
    completion.

    Raises:
        ValueError: if the completion contains no ``{...}`` span.
        json.JSONDecodeError: if the extracted span is not valid JSON.
    """
    import torch

    tokenizer, model = _load_generation_components(model_id)

    system = (
        "You are a compact design parser for a procedural 3D generator. "
        "Convert the user request into a single JSON object and output JSON only."
    )
    user = (
        "Return a JSON object using this schema: "
        f"{json.dumps(JSON_SCHEMA_HINT)}\n"
        "Rules: choose the closest allowed enum values, stay conservative, infer hard-surface sci-fi vehicle structure, "
        "never explain anything, never use markdown fences, and keep notes brief.\n"
        f"Prompt: {prompt}"
    )

    messages = [
        {"role": "system", "content": system},
        {"role": "user", "content": user},
    ]

    # Prefer the tokenizer's own chat template; fall back to a plain
    # System/User/Assistant transcript for tokenizers without one.
    if hasattr(tokenizer, "apply_chat_template"):
        rendered = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    else:
        rendered = f"System: {system}\nUser: {user}\nAssistant:"

    inputs = tokenizer(rendered, return_tensors="pt")
    # Move input tensors onto the model's device (set by device_map="auto").
    model_device = getattr(model, "device", None)
    if model_device is not None:
        inputs = {k: v.to(model_device) for k, v in inputs.items()}

    with torch.no_grad():
        # do_sample=False -> greedy decoding; temperature/top_p are passed
        # explicitly as None, presumably to neutralize any sampling defaults
        # from the model's generation config — TODO confirm.
        output = model.generate(
            **inputs,
            max_new_tokens=220,
            do_sample=False,
            temperature=None,
            top_p=None,
            repetition_penalty=1.02,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Keep only the newly generated tokens (everything after the prompt).
    new_tokens = output[0][inputs["input_ids"].shape[1]:]
    text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    # Greedy {...} match grabs from the first "{" to the last "}", which
    # tolerates leading/trailing chatter around the JSON object.
    match = re.search(r"\{.*\}", text, flags=re.S)
    if not match:
        raise ValueError("Local model did not return JSON.")
    return json.loads(match.group(0))


def _normalize_llm_payload(payload: dict[str, Any], original_prompt: str) -> PromptSpec:
    """Coerce a raw LLM JSON payload into a validated PromptSpec.

    Every field is defensively converted and clamped into its allowed range;
    any malformed value falls back to a conservative default so a sloppy
    model response never crashes the pipeline.

    Args:
        payload: dict parsed from the model's JSON output (arbitrary values).
        original_prompt: user prompt, used as the fallback for ``notes``.

    Returns:
        A fully populated PromptSpec with all numeric fields in range.
    """
    def get_str(name: str, default: str) -> str:
        value = str(payload.get(name, default)).strip().lower()
        return value or default

    def get_int(name: str, default: int, low: int, high: int) -> int:
        # Route through float() first: LLMs frequently emit integer fields
        # as 2.0 or "2.0", which int() alone rejects and would needlessly
        # discard in favor of the default.
        try:
            return int(_clamp(int(float(payload.get(name, default))), low, high))
        except Exception:
            return default

    def get_float(name: str, default: float, low: float, high: float) -> float:
        try:
            return float(_clamp(float(payload.get(name, default)), low, high))
        except Exception:
            return default

    # Accept a real boolean as-is; otherwise parse common truthy strings.
    landing_raw = payload.get("landing_gear", True)
    if isinstance(landing_raw, bool):
        landing_gear = landing_raw
    else:
        landing_gear = str(landing_raw).strip().lower() in {"1", "true", "yes", "y"}

    return PromptSpec(
        object_type=get_str("object_type", "cargo_hauler"),
        scale=get_str("scale", "small"),
        hull_style=get_str("hull_style", "boxy"),
        engine_count=get_int("engine_count", 2, 1, 6),
        wing_span=get_float("wing_span", 0.2, 0.0, 0.6),
        cargo_ratio=get_float("cargo_ratio", 0.38, 0.0, 0.65),
        cockpit_ratio=get_float("cockpit_ratio", 0.18, 0.10, 0.30),
        fin_height=get_float("fin_height", 0.0, 0.0, 0.3),
        landing_gear=landing_gear,
        asymmetry=get_float("asymmetry", 0.0, 0.0, 0.2),
        notes=str(payload.get("notes", original_prompt)).strip() or original_prompt,
    )


def parse_prompt_with_local_llm(prompt: str, model_id: str | None = None) -> PromptSpec:
    """Parse *prompt* with the local LLM, merged over a heuristic baseline.

    Falls back to DEFAULT_LOCAL_MODEL when *model_id* is None/empty; the
    heuristic parse and the LLM-derived spec are combined via
    merge_prompt_specs.
    """
    chosen_model = model_id or DEFAULT_LOCAL_MODEL
    baseline = parse_prompt(prompt)
    raw_payload = _generate_structured_json(prompt=prompt, model_id=chosen_model)
    refined = _normalize_llm_payload(raw_payload, original_prompt=prompt)
    return merge_prompt_specs(baseline, refined)