File size: 6,443 Bytes
777ea0e
 
 
 
 
 
 
043484b
 
 
 
 
 
 
 
 
 
8df32ec
 
 
 
 
 
 
777ea0e
 
 
 
 
 
 
635e6fb
 
 
 
 
777ea0e
 
 
 
 
 
 
 
 
 
 
 
 
 
635e6fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217a06b
 
777ea0e
 
 
 
 
217a06b
 
777ea0e
 
635e6fb
 
 
777ea0e
635e6fb
 
 
 
 
 
 
777ea0e
635e6fb
777ea0e
 
 
 
217a06b
 
635e6fb
217a06b
777ea0e
 
 
 
 
 
 
 
 
 
 
 
69036da
 
 
217a06b
69036da
 
777ea0e
 
 
217a06b
777ea0e
217a06b
 
 
 
777ea0e
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""GDScript Coding Assistant — Gradio app (HF Space, ZeroGPU).

Flow per question:  retrieve (CPU) -> generate (ZeroGPU) -> validate (CPU) ->
optional self-correct (up to MAX_FIX_PASSES broken blocks, one GPU call each)
-> render answer + validation + sources.
"""
from __future__ import annotations

import os

# Route the HF cache to persistent storage (mounted at /data when the Space has
# "persistent storage" enabled) BEFORE importing any HF library, so model and
# embedder weights survive restarts and are not re-downloaded on every cold
# boot. No-op when /data isn't present/writable (falls back to the default
# ephemeral cache).
# An HF_HOME that is already set (to a non-empty value) is respected and left
# untouched; the override only happens when /data is an existing, writable
# directory.
if not os.environ.get("HF_HOME") and os.path.isdir("/data") and os.access("/data", os.W_OK):
    os.environ["HF_HOME"] = "/data/huggingface"

# Import spaces BEFORE any torch-importing library (gradio/rag/generate) so
# ZeroGPU can patch CUDA and keep the model GPU-resident. No-op off-Space.
# The except is deliberately broad: any failure while importing `spaces` is
# ignored so the app can still start without it.
try:
    import spaces  # noqa: F401
except Exception:
    pass

import gradio as gr

import rag
import prompt as promptlib
import generate as gen
import validate as gdv

# Cap on auto-correction GPU calls per answer. Each broken block fixed is one
# extra gen.generate() (one @spaces.GPU call), so this bounds the total GPU
# work a single pathological answer (many broken blocks) can trigger.
MAX_FIX_PASSES = 3


def _sources_md(hits: list[rag.Hit]) -> str:
    if not hits:
        return ""
    lines = ["\n\n<details><summary>📚 Retrieved sources</summary>\n"]
    for i, h in enumerate(hits, 1):
        loc = h.repo or "corpus"
        url = h.origin_url or ""
        link = f"[{loc}]({url})" if url.startswith("http") else loc
        lines.append(f"{i}. {link} · `{h.file_path or h.kind}` · score {h.score:.2f}")
    lines.append("\n</details>")
    return "\n".join(lines)


def _autocorrect(answer: str) -> tuple[str, int, int]:
    """Rewrite *answer* with its unparsable ```gdscript fences repaired in place.

    Every fenced block reported by ``gdv.gdscript_block_spans`` is validated.
    A block that fails validation is sent to ``gen.generate`` together with
    the validator's error, but only for the first MAX_FIX_PASSES failing
    blocks (one ``gen.generate`` call per attempt). A proposed fix is spliced
    in only when it validates itself; otherwise the original fence is kept
    verbatim.

    Returns ``(new_answer, num_fixed, num_broken)``. ``num_broken`` counts
    every failing block, including those past the cap that were never sent
    for repair, so the caller can report how many were left.
    """
    chunks: list[str] = []
    tail_from = 0
    repaired = failing = attempts = 0

    for code, start, end in gdv.gdscript_block_spans(answer):
        chunks.append(answer[tail_from:start])    # text preceding this fence
        tail_from = end
        fence = answer[start:end]                 # the full ```...``` span
        verdict = gdv.validate_code(code)

        if verdict.ok:
            chunks.append(fence)
            continue

        failing += 1
        if attempts >= MAX_FIX_PASSES:
            # Budget exhausted: count the block as broken, leave it untouched.
            chunks.append(fence)
            continue

        attempts += 1
        fix_messages = promptlib.build_fix_messages(code, verdict.error)
        candidate = gdv.first_gdscript_block(gen.generate(fix_messages))
        # Accept the candidate only if it validates; a rejected candidate
        # leaves the original fence in place.
        if candidate and gdv.validate_code(candidate).ok:
            fence = f"```gdscript\n{candidate}\n```"
            repaired += 1
        chunks.append(fence)

    chunks.append(answer[tail_from:])
    return "".join(chunks), repaired, failing


def respond(message: str, history, top_k: int, self_correct: bool,
            history_turns: int = promptlib.MAX_HISTORY_TURNS):
    """Answer one chat message and return the Markdown shown to the user.

    Steps: retrieve ``top_k`` hits for the message, build the prompt from the
    message, hits and prior ``history`` (limited to ``history_turns``),
    generate an answer, optionally auto-correct broken GDScript blocks, then
    validate the final answer and append the validation report, the
    auto-correction note, the retrieved sources and (when the retrieval index
    is unavailable) a notice about it.

    A blank or missing message short-circuits with a prompt to ask a question.
    """
    question = (message or "").strip()
    if not question:
        return "Ask a GDScript or Godot question."

    hits = rag.retrieve(question, k=int(top_k))
    answer = gen.generate(
        promptlib.build_messages(question, hits, history=history,
                                 max_turns=int(history_turns)))

    # Optional self-correction. _autocorrect bounds its own GPU usage to
    # MAX_FIX_PASSES calls; here we only summarise what it did.
    fix_note = ""
    if self_correct:
        answer, n_fixed, n_broken = _autocorrect(answer)
        if n_broken:
            summary = f"🔧 Auto-corrected {n_fixed}/{n_broken} broken block(s) in place"
            skipped = n_broken - MAX_FIX_PASSES
            if skipped > 0:
                summary += (f" — capped at {MAX_FIX_PASSES} fix passes, "
                            f"{skipped} not attempted")
            fix_note = f"\n\n**{summary}.**"

    report = gdv.render_report(gdv.validate_answer(answer))

    if rag.index_available():
        index_note = ""
    else:
        index_note = ("\n\n> ⏳ _Retrieval index not loaded yet — answering without "
                      "corpus context. Build & push the index (see DEPLOY.md)._")

    sources = _sources_md(hits)
    # Everything after VALIDATION_DELIM is decoration; the delimiter lets
    # prompt._clean_assistant strip it when this turn is replayed as history.
    return (f"{answer}{promptlib.VALIDATION_DELIM} \n{report}{fix_note}"
            f"{sources}{index_note}")


# Build the UI: a header, a collapsed "Settings" accordion whose widgets are
# passed to respond() as additional inputs, and the chat interface itself.
with gr.Blocks(title="GDScript Coding Assistant", fill_height=True) as demo:
    gr.Markdown(
        "# 🤖 GDScript Coding Assistant\n"
        "RAG over a 91,720-chunk Godot/GDScript corpus · Qwen2.5-Coder-7B · "
        "answers are **syntax-validated with gdtoolkit**."
    )
    with gr.Accordion("Settings", open=False):
        top_k = gr.Slider(2, 10, value=6, step=1, label="Retrieved snippets (k)")
        # The label is derived from MAX_FIX_PASSES so it stays truthful:
        # self-correction repairs every broken block, one GPU call each, up to
        # that cap (it is no longer a single fix).
        self_correct = gr.Checkbox(
            value=True,
            label=f"Auto-correct broken GDScript blocks "
                  f"(up to {MAX_FIX_PASSES} extra GPU calls)")
        # Chat memory is fixed at promptlib.MAX_HISTORY_TURNS. Kept as a
        # non-interactive input (rather than removed) so the /respond API
        # signature stays (message, top_k, self_correct, history_turns).
        history_turns = gr.Slider(
            0, 8, value=promptlib.MAX_HISTORY_TURNS, step=1, interactive=False,
            label=f"Chat memory: fixed at {promptlib.MAX_HISTORY_TURNS} prior turns")

    gr.ChatInterface(
        fn=respond,
        additional_inputs=[top_k, self_correct, history_turns],
        # Each example row is [message, top_k, self_correct, history_turns].
        # The history value comes from the same constant as the slider so the
        # two cannot drift apart.
        examples=[
            [question, 6, True, promptlib.MAX_HISTORY_TURNS]
            for question in (
                "Write a CharacterBody2D top-down movement script",
                "How do I define and emit a custom signal?",
                "Show a typed @export inventory array with @onready",
                "Make an enemy follow the player using a NavigationAgent2D",
            )
        ],
        cache_examples=False,
    )


if __name__ == "__main__":
    # Run rag.warmup() once before serving so its loading cost is paid at
    # startup. A failure here is printed and otherwise ignored, so the UI
    # still launches.
    try:
        rag.warmup()
    except Exception as exc:
        print(f"warmup (rag) skipped: {exc}")

    # Enable the request queue (max_size=16) and start the app.
    queued_app = demo.queue(max_size=16)
    queued_app.launch()