Spaces:
Running on Zero
Running on Zero
| """GDScript Coding Assistant — Gradio app (HF Space, ZeroGPU). | |
| Flow per question: retrieve (CPU) -> generate (ZeroGPU) -> validate (CPU) -> | |
| optional self-correct (one GPU call per broken block, at most MAX_FIX_PASSES) -> render answer + validation + sources. | |
| """ | |
from __future__ import annotations

import os

# Route the HF cache to persistent storage (mounted at /data when the Space has
# "persistent storage" enabled) BEFORE importing any HF library, so model and
# embedder weights survive restarts and are not re-downloaded on every cold
# boot. No-op when /data isn't present/writable (falls back to the default
# ephemeral cache) or when HF_HOME is already set explicitly.
if not os.environ.get("HF_HOME") and os.path.isdir("/data") and os.access("/data", os.W_OK):
    os.environ["HF_HOME"] = "/data/huggingface"

# Import spaces BEFORE any torch-importing library (gradio/rag/generate) so
# ZeroGPU can patch CUDA and keep the model GPU-resident. No-op off-Space.
try:
    import spaces  # noqa: F401
except ImportError:
    # Package not installed (e.g. local run): nothing to patch.
    pass
except Exception as e:
    # Best-effort: never block startup on ZeroGPU setup, but surface why the
    # patching didn't happen instead of swallowing it silently.
    print(f"spaces import failed, continuing without ZeroGPU patching: {e}")
| import gradio as gr | |
| import rag | |
| import prompt as promptlib | |
| import generate as gen | |
| import validate as gdv | |
# Upper bound on self-correction GPU calls for a single answer. Repairing one
# broken GDScript block costs one extra gen.generate() -- i.e. one
# @spaces.GPU call -- so this limits how much GPU work an answer containing
# many broken blocks can trigger.
MAX_FIX_PASSES: int = 3
| def _sources_md(hits: list[rag.Hit]) -> str: | |
| if not hits: | |
| return "" | |
| lines = ["\n\n<details><summary>📚 Retrieved sources</summary>\n"] | |
| for i, h in enumerate(hits, 1): | |
| loc = h.repo or "corpus" | |
| url = h.origin_url or "" | |
| link = f"[{loc}]({url})" if url.startswith("http") else loc | |
| lines.append(f"{i}. {link} · `{h.file_path or h.kind}` · score {h.score:.2f}") | |
| lines.append("\n</details>") | |
| return "\n".join(lines) | |
def _autocorrect(answer: str) -> tuple[str, int, int]:
    """Repair broken ```gdscript blocks of ``answer`` in place.

    Every fenced GDScript block found by ``gdv.gdscript_block_spans`` is
    checked with ``gdv.validate_code``. For blocks that fail, at most
    MAX_FIX_PASSES repair attempts are made in total, each costing one
    ``gen.generate`` call (one GPU call). A repair is spliced in only when
    the model's reply contains a GDScript block that itself validates.
    Otherwise the original fence is kept verbatim, and it stays flagged in
    the validation report.

    Repair is best-effort. The answer was already generated, so if a repair
    call raises (e.g. a GPU quota/timeout error), the error is printed and
    the original block is kept, instead of discarding the whole answer.

    Returns:
        ``(new_answer, num_fixed, num_broken)``. ``num_broken`` counts every
        block that failed validation, including any beyond the cap, so the
        caller can report how many were left untouched.
    """
    pieces: list[str] = []
    cursor = num_fixed = num_broken = passes = 0
    for code, start, end in gdv.gdscript_block_spans(answer):
        pieces.append(answer[cursor:start])  # text between fences, untouched
        block_text = answer[start:end]  # the whole ```...``` fence
        res = gdv.validate_code(code)
        if not res.ok:
            num_broken += 1
            if passes < MAX_FIX_PASSES:
                passes += 1
                try:
                    fix_out = gen.generate(promptlib.build_fix_messages(code, res.error))
                except Exception as e:
                    # Best-effort: a failed repair must not cost the user the
                    # answer that was already generated.
                    print(f"autocorrect: fix generation failed, keeping original block: {e}")
                else:
                    fix_code = gdv.first_gdscript_block(fix_out)
                    # Only splice in a fix that actually parses; otherwise keep
                    # the original (it stays flagged ❌ in the validation report).
                    if fix_code and gdv.validate_code(fix_code).ok:
                        block_text = f"```gdscript\n{fix_code}\n```"
                        num_fixed += 1
        pieces.append(block_text)
        cursor = end
    pieces.append(answer[cursor:])
    return "".join(pieces), num_fixed, num_broken
def respond(message: str, history, top_k: int, self_correct: bool,
            history_turns: int = promptlib.MAX_HISTORY_TURNS):
    """Chat callback: answer one user turn and return it as Markdown.

    Steps: retrieve ``top_k`` corpus hits for the question, build the prompt
    (including up to ``history_turns`` prior turns), and generate an answer.
    When ``self_correct`` is set, broken GDScript blocks are repaired in
    place via ``_autocorrect``. The result is then validated, and the
    validation report, an optional auto-correct note, the retrieved sources
    and an optional "index not loaded" warning are appended.

    A blank or whitespace-only message short-circuits with a prompt to ask
    something.
    """
    question = (message or "").strip()
    if not question:
        return "Ask a GDScript or Godot question."

    hits = rag.retrieve(question, k=int(top_k))
    prompt_messages = promptlib.build_messages(
        question, hits, history=history, max_turns=int(history_turns))
    draft = gen.generate(prompt_messages)

    # Optional in-place repair of every broken GDScript block, bounded by
    # MAX_FIX_PASSES GPU calls inside _autocorrect.
    fix_note = ""
    if self_correct:
        draft, fixed_count, broken_count = _autocorrect(draft)
        if broken_count:
            skipped = broken_count - MAX_FIX_PASSES
            cap_text = (f" — capped at {MAX_FIX_PASSES} fix passes, "
                        f"{skipped} not attempted") if skipped > 0 else ""
            fix_note = (f"\n\n**🔧 Auto-corrected {fixed_count}/{broken_count} "
                        f"broken block(s) in place{cap_text}.**")

    report = gdv.render_report(gdv.validate_answer(draft))

    if rag.index_available():
        note = ""
    else:
        note = ("\n\n> ⏳ _Retrieval index not loaded yet — answering without "
                "corpus context. Build & push the index (see DEPLOY.md)._")

    # VALIDATION_DELIM marks where the decoration starts, so
    # prompt._clean_assistant can strip it when this turn is replayed as
    # multi-turn history.
    body = f"{draft}{promptlib.VALIDATION_DELIM} \n{report}{fix_note}"
    return body + _sources_md(hits) + note
# Default settings shared by the widgets and the example rows below, so a
# clicked example can't silently run with settings different from the UI's.
_DEFAULT_TOP_K = 6
_DEFAULT_SELF_CORRECT = True
_EXAMPLE_QUESTIONS = (
    "Write a CharacterBody2D top-down movement script",
    "How do I define and emit a custom signal?",
    "Show a typed @export inventory array with @onready",
    "Make an enemy follow the player using a NavigationAgent2D",
)

with gr.Blocks(title="GDScript Coding Assistant", fill_height=True) as demo:
    gr.Markdown(
        "# 🤖 GDScript Coding Assistant\n"
        "RAG over a 91,720-chunk Godot/GDScript corpus · Qwen2.5-Coder-7B · "
        "answers are **syntax-validated with gdtoolkit**."
    )
    with gr.Accordion("Settings", open=False):
        top_k = gr.Slider(2, 10, value=_DEFAULT_TOP_K, step=1,
                          label="Retrieved snippets (k)")
        # _autocorrect repairs every broken block (one GPU call each), capped
        # at MAX_FIX_PASSES per answer. The label reflects that cost instead
        # of claiming a single fix.
        self_correct = gr.Checkbox(
            value=_DEFAULT_SELF_CORRECT,
            label=f"Auto-correct syntax errors (up to {MAX_FIX_PASSES} extra GPU calls)")
        # Chat memory is fixed at promptlib.MAX_HISTORY_TURNS. Kept as a
        # non-interactive input (rather than removed) so the /respond API
        # signature stays (message, top_k, self_correct, history_turns).
        history_turns = gr.Slider(
            0, 8, value=promptlib.MAX_HISTORY_TURNS, step=1, interactive=False,
            label=f"Chat memory: fixed at {promptlib.MAX_HISTORY_TURNS} prior turns")
    gr.ChatInterface(
        fn=respond,
        additional_inputs=[top_k, self_correct, history_turns],
        # Each row is [message, top_k, self_correct, history_turns]. The
        # settings come from the same defaults as the widgets above.
        examples=[
            [question, _DEFAULT_TOP_K, _DEFAULT_SELF_CORRECT,
             promptlib.MAX_HISTORY_TURNS]
            for question in _EXAMPLE_QUESTIONS
        ],
        cache_examples=False,
    )
if __name__ == "__main__":
    # Warm start before serving: rag.warmup() preloads the index, chunks and
    # embedder (and, per the original note, the model unless stubbed). It is
    # best-effort: any failure is printed and the app still launches.
    try:
        rag.warmup()
    except Exception as exc:
        print(f"warmup (rag) skipped: {exc}")
    queued_app = demo.queue(max_size=16)
    queued_app.launch()