vivekchakraverty's picture
Auto-correct EVERY broken GDScript block in place (capped at MAX_FIX_PASSES)
635e6fb
"""GDScript Coding Assistant — Gradio app (HF Space, ZeroGPU).
Flow per question: retrieve (CPU) -> generate (ZeroGPU) -> validate (CPU) ->
optional self-correct of broken GDScript blocks (at most MAX_FIX_PASSES fix
passes) -> render answer + validation + sources.
"""
from __future__ import annotations
import os
# Route the HF cache to persistent storage (mounted at /data when the Space has
# "persistent storage" enabled) BEFORE importing any HF library, so model and
# embedder weights survive restarts and are not re-downloaded on every cold
# boot. No-op when /data isn't present/writable (falls back to the default
# ephemeral cache).
if (
    os.path.isdir("/data")
    and os.access("/data", os.W_OK)
    and not os.environ.get("HF_HOME")
):
    # An unset OR empty HF_HOME is treated as "not configured".
    os.environ["HF_HOME"] = "/data/huggingface"
# Import spaces BEFORE any torch-importing library (gradio/rag/generate) so
# ZeroGPU can patch CUDA and keep the model GPU-resident. No-op off-Space.
try:
import spaces # noqa: F401
except Exception:
pass
import gradio as gr
import rag
import prompt as promptlib
import generate as gen
import validate as gdv
# Cap on auto-correction GPU calls per answer. Each broken block fixed is one
# extra gen.generate() (one @spaces.GPU call), so this bounds the total GPU
# work a single pathological answer (many broken blocks) can trigger.
# Broken blocks beyond the cap are still counted by _autocorrect() but are
# left unmodified; respond() reports how many were not attempted.
MAX_FIX_PASSES = 3
def _sources_md(hits: list[rag.Hit]) -> str:
if not hits:
return ""
lines = ["\n\n<details><summary>📚 Retrieved sources</summary>\n"]
for i, h in enumerate(hits, 1):
loc = h.repo or "corpus"
url = h.origin_url or ""
link = f"[{loc}]({url})" if url.startswith("http") else loc
lines.append(f"{i}. {link} · `{h.file_path or h.kind}` · score {h.score:.2f}")
lines.append("\n</details>")
return "\n".join(lines)
def _autocorrect(answer: str) -> tuple[str, int, int]:
    """Rewrite broken ```gdscript fences of *answer* in place.

    Every block reported by ``gdv.gdscript_block_spans`` is validated. For
    the first ``MAX_FIX_PASSES`` blocks that fail, one ``gen.generate()``
    call asks for a corrected version; the fence is replaced only when the
    proposed code itself validates. Text outside the fences is copied
    through untouched.

    Returns ``(new_answer, num_fixed, num_broken)``. ``num_broken`` counts
    every failing block, including those past the cap, so the caller can
    report how many were never attempted.
    """
    out: list[str] = []
    pos = 0
    fixed = broken = attempts = 0

    for code, start, end in gdv.gdscript_block_spans(answer):
        # Copy the prose between the previous fence and this one verbatim.
        out.append(answer[pos:start])
        pos = end
        fence = answer[start:end]  # the whole ```...``` fence

        verdict = gdv.validate_code(code)
        if verdict.ok:
            out.append(fence)
            continue

        broken += 1
        if attempts >= MAX_FIX_PASSES:
            # Budget exhausted: count the block but leave it as written.
            out.append(fence)
            continue

        attempts += 1
        reply = gen.generate(promptlib.build_fix_messages(code, verdict.error))
        candidate = gdv.first_gdscript_block(reply)
        # Accept the fix only if it parses; otherwise the original fence is
        # kept (it stays flagged ❌ in the validation report).
        if candidate and gdv.validate_code(candidate).ok:
            fence = f"```gdscript\n{candidate}\n```"
            fixed += 1
        out.append(fence)

    out.append(answer[pos:])
    return "".join(out), fixed, broken
def respond(message: str, history, top_k: int, self_correct: bool,
            history_turns: int = promptlib.MAX_HISTORY_TURNS):
    """Answer one chat message and return the fully decorated Markdown reply.

    Steps: retrieve ``top_k`` hits for the message, build the prompt
    (including up to ``history_turns`` prior turns), generate an answer,
    optionally auto-correct broken GDScript blocks, validate the final
    answer, and append the validation report, the auto-correct note, the
    retrieved sources and (when the retrieval index is unavailable) a
    warning.

    A blank or ``None`` message short-circuits with a prompt to ask a
    question; nothing is retrieved or generated in that case.
    """
    question = (message or "").strip()
    if not question:
        return "Ask a GDScript or Godot question."

    hits = rag.retrieve(question, k=int(top_k))
    prompt_messages = promptlib.build_messages(
        question, hits, history=history, max_turns=int(history_turns))
    answer = gen.generate(prompt_messages)

    # Self-correction: repair EVERY broken GDScript block in place, capped at
    # MAX_FIX_PASSES GPU calls so a pathological answer can't blow the budget.
    fix_note = ""
    if self_correct:
        answer, n_fixed, n_broken = _autocorrect(answer)
        if n_broken:
            cap_suffix = ""
            if n_broken > MAX_FIX_PASSES:
                not_attempted = n_broken - MAX_FIX_PASSES
                cap_suffix = (f" — capped at {MAX_FIX_PASSES} fix passes, "
                              f"{not_attempted} not attempted")
            fix_note = (f"\n\n**🔧 Auto-corrected {n_fixed}/{n_broken} "
                        f"broken block(s) in place{cap_suffix}.**")

    # Validate the (possibly repaired) answer so the report matches what the
    # user actually sees.
    report = gdv.render_report(gdv.validate_answer(answer))

    if rag.index_available():
        note = ""
    else:
        note = ("\n\n> ⏳ _Retrieval index not loaded yet — answering without "
                "corpus context. Build & push the index (see DEPLOY.md)._")

    sources = _sources_md(hits)
    # The VALIDATION_DELIM prefix lets prompt._clean_assistant strip this
    # decoration when the turn is fed back as multi-turn history.
    return (f"{answer}{promptlib.VALIDATION_DELIM} \n{report}{fix_note}"
            f"{sources}{note}")
# Gradio UI: a header, a collapsed "Settings" accordion holding the extra
# inputs, and the chat interface wired to respond().
with gr.Blocks(title="GDScript Coding Assistant", fill_height=True) as demo:
    gr.Markdown(
        "# 🤖 GDScript Coding Assistant\n"
        "RAG over a 91,720-chunk Godot/GDScript corpus · Qwen2.5-Coder-7B · "
        "answers are **syntax-validated with gdtoolkit**."
    )
    with gr.Accordion("Settings", open=False):
        top_k = gr.Slider(2, 10, value=6, step=1, label="Retrieved snippets (k)")
        # The label reflects the real behaviour: every broken block is
        # eligible for repair, bounded by MAX_FIX_PASSES extra generations.
        self_correct = gr.Checkbox(
            value=True,
            label=(f"Auto-correct broken GDScript blocks "
                   f"(up to {MAX_FIX_PASSES} extra GPU calls)"))
        # Chat memory is fixed at promptlib.MAX_HISTORY_TURNS turns. Kept as
        # a non-interactive input (rather than removed) so the /respond API
        # signature stays (message, top_k, self_correct, history_turns).
        history_turns = gr.Slider(
            0, 8, value=promptlib.MAX_HISTORY_TURNS, step=1, interactive=False,
            label=f"Chat memory: fixed at {promptlib.MAX_HISTORY_TURNS} prior turns")
    # Each example row supplies values for the additional inputs in order:
    # [message, top_k, self_correct, history_turns]. history_turns uses the
    # same constant as the slider so the two cannot drift apart.
    gr.ChatInterface(
        fn=respond,
        additional_inputs=[top_k, self_correct, history_turns],
        examples=[
            ["Write a CharacterBody2D top-down movement script",
             6, True, promptlib.MAX_HISTORY_TURNS],
            ["How do I define and emit a custom signal?",
             6, True, promptlib.MAX_HISTORY_TURNS],
            ["Show a typed @export inventory array with @onready",
             6, True, promptlib.MAX_HISTORY_TURNS],
            ["Make an enemy follow the player using a NavigationAgent2D",
             6, True, promptlib.MAX_HISTORY_TURNS],
        ],
        cache_examples=False,
    )
if __name__ == "__main__":
    # Best-effort warmup at startup: a failure is reported on stdout and the
    # app still launches.
    try:
        rag.warmup()
    except Exception as exc:
        print(f"warmup (rag) skipped: {exc}")
    demo.queue(max_size=16).launch()