| import json |
| import os |
| import re |
| import ast |
| from pathlib import Path |
|
|
| import gradio as gr |
| import httpx |
|
|
# Choices offered by the "Language" dropdown in the UI.
LANGUAGES = [
    "Python",
    "JavaScript",
    "TypeScript",
    "Rust",
    "Go",
    "C++",
]
|
|
|
|
def load_local_env() -> None:
    """Load ``KEY=VALUE`` pairs from ``.env`` in the working directory.

    Variables that already exist in ``os.environ`` are left untouched, so the
    real environment always wins over the file. A missing file is not an
    error.

    Parsing rules:
      * blank lines, ``#`` comments (also when indented) and lines without
        ``=`` are ignored;
      * an optional leading ``export`` keyword is dropped;
      * one pair of matching single or double quotes around the value is
        removed;
      * entries whose key is empty are skipped.
    """
    env_path = Path(".env")
    if not env_path.exists():
        return

    for raw_line in env_path.read_text(encoding="utf-8").splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue

        key, _, value = line.partition("=")
        key = key.strip()
        # Accept shell-style "export KEY=value" lines.
        words = key.split(None, 1)
        if len(words) == 2 and words[0] == "export":
            key = words[1]
        if not key:
            # "=value" has no usable name; os.environ rejects empty keys.
            continue

        value = value.strip()
        if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
            # KEY="value" should yield value, not "value" with the quotes.
            value = value[1:-1]

        os.environ.setdefault(key, value)
|
|
|
|
# Populate os.environ from .env first so the lookups below can see its values.
load_local_env()

# Opt out of Gradio analytics unless the environment already decided otherwise.
os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "False")

# Remote endpoints; either one is None when its variable is not set.
MODAL_VERIFIER_URL = os.getenv("MODAL_VERIFIER_URL")
MODAL_SANDBOX_URL = os.getenv("MODAL_SANDBOX_URL")
|
|
|
|
def load_static(filename: str) -> str:
    """Return the UTF-8 text of ``static/<filename>``.

    The ``static`` directory is resolved relative to the current working
    directory. Errors from reading the file (for example a missing file)
    propagate to the caller.
    """
    asset_path = Path("static") / filename
    return asset_path.read_text(encoding="utf-8")
|
|
|
|
| def endpoint_url(url: str | None, path: str) -> str | None: |
| if not url: |
| return None |
| clean = url.rstrip("/") |
| if clean.endswith(path): |
| return clean |
| return f"{clean}{path}" |
|
|
|
|
# Markup plus inline module script for the browser-side half of the app.
#
# The HTML renders the architecture rail, the model-load controls, the streaming
# code view and the status bar. The script inlines static/engine.js and
# static/ui.js (read once at import time through load_static) and then defines
# three globals used by the Gradio event wiring:
#   * window.initEngine          - loads the local model and reports progress;
#   * window.runLocalGeneration  - streams a draft, writes it into the hidden
#                                  "draft-output-hidden" textbox and clicks the
#                                  hidden "trigger-verify-btn" button;
#   * window.applyVerification   - parses the verifier's JSON verdict and
#                                  updates the UI (PASS / ERROR / correction).
#
# Because this is an f-string, literal JavaScript braces are doubled ({{ }}).
# Generation failures are logged and formatted with formatBrowserError, the same
# way model-load failures are, so thrown strings or message-less objects do not
# render as "undefined". A verdict that parses to something other than an object
# (for example JSON null) is reported as invalid instead of raising.
#
# NOTE(review): this relies on <script> content inside gr.HTML being executed by
# the installed Gradio version - confirm.
custom_html = f"""
<div id="split-brain-root">
  <div class="brain-rail" aria-label="Split-brain architecture">
    <div class="brain-node local-node">
      <span class="brain-label">Local Draft</span>
      <strong>WebGPU 1.5B</strong>
      <small>fast browser stream</small>
    </div>
    <div class="brain-pulse" aria-hidden="true">
      <span></span>
      <span></span>
      <span></span>
    </div>
    <div class="brain-node cloud-node">
      <span class="brain-label">Cloud Check</span>
      <strong>Modal A10G 14B</strong>
      <small>llama.cpp verifier</small>
    </div>
  </div>
  <div class="webgpu-notice" id="webgpu-warning" hidden>
    WebGPU not detected. Use Chrome 113+ on desktop for local inference.
  </div>
  <div id="load-section" class="load-section">
    <button id="load-btn" class="local-button" onclick="window.initEngine()">Load 1.5B Model</button>
    <div class="loading-bar"><div class="loading-bar-fill" id="load-progress"></div></div>
    <span id="load-status" class="load-status">Model not loaded</span>
  </div>
  <div class="stream-shell">
    <div class="stream-toolbar">
      <span>Speculative draft</span>
      <span id="stream-phase">Idle</span>
    </div>
    <pre id="stream-display" class="code-stream">Waiting for model load...</pre>
  </div>
  <div class="status-bar">
    <span id="status-text">Idle</span>
    <span id="token-count">0 tok/s</span>
    <span id="verifier-status">Verifier idle</span>
  </div>
</div>
<script type="module">
{load_static("engine.js")}
{load_static("ui.js")}

const warning = document.getElementById("webgpu-warning");
const loadButton = document.getElementById("load-btn");

function formatBrowserError(error) {{
  if (!error) return "unknown error";
  if (error.message) return error.message;
  if (typeof error === "string") return error;
  try {{
    return JSON.stringify(error);
  }} catch (_err) {{
    return String(error);
  }}
}}

function findGradioInput(id) {{
  const root = document.getElementById(id);
  if (!root) return null;
  if (root.matches("input, textarea")) return root;
  return root.querySelector("input, textarea");
}}

function findGradioButton(id) {{
  const root = document.getElementById(id);
  if (!root) return null;
  if (root.matches("button")) return root;
  return root.querySelector("button");
}}

function cleanGeneratedCode(code) {{
  if (!code) return "";
  return stripMarkdownCodeFence(code);
}}

if (!isWebGPUSupported()) {{
  warning.hidden = false;
  loadButton.disabled = true;
  setStatus("Chrome 113+ with WebGPU required", "warning");
}}

window.initEngine = async function() {{
  loadButton.disabled = true;
  document.getElementById("load-status").textContent = "Loading model weights...";
  try {{
    await loadModel((progress) => {{
      if (progress.status === "attempt") {{
        document.getElementById("load-status").textContent = `Trying ${{progress.dtype}} WebGPU weights...`;
        return;
      }}
      const value = progress.progress ? Math.round(progress.progress) : 0;
      document.getElementById("load-progress").style.width = `${{value}}%`;
      if (progress.file) {{
        document.getElementById("load-status").textContent = `${{progress.file}} (${{progress.dtype || "auto"}}) - ${{value}}%`;
      }}
    }});
    document.getElementById("load-progress").style.width = "100%";
    document.getElementById("load-status").textContent = `Model ready - WebGPU active (${{getActiveDtype() || "auto"}})`;
    document.getElementById("load-section").classList.add("loaded");
    setStatus("Ready", "success");
  }} catch (error) {{
    console.error("Model load failed", error);
    loadButton.disabled = false;
    setStatus(`Model load failed: ${{formatBrowserError(error)}}`, "warning");
    document.getElementById("load-status").textContent = "Load failed";
  }}
}};

window.runLocalGeneration = async function(prompt, language) {{
  if (!prompt || !prompt.trim()) {{
    setStatus("Enter a prompt first", "warning");
    return [];
  }}

  reset();
  setVerifierStatus("IDLE");
  setStatus("Generating locally (WebGPU)...", "neutral");

  let tokenCount = 0;
  const startTime = Date.now();

  try {{
    const fullCode = await generateCode(
      prompt,
      language,
      (token) => {{
        appendToken(token);
        tokenCount += 1;
        const elapsed = Math.max((Date.now() - startTime) / 1000, 0.1);
        document.getElementById("token-count").textContent = `${{Math.round(tokenCount / elapsed)}} tok/s`;
      }},
      () => {{
        setStatus("Local generation complete. Verifier warming up...", "neutral");
        setVerifierStatus("CHECKING");
      }}
    );
    const cleanCode = cleanGeneratedCode(fullCode);
    if (cleanCode !== getCurrentCode()) {{
      setCode(cleanCode);
    }}

    const hidden = findGradioInput("draft-output-hidden");
    const trigger = findGradioButton("trigger-verify-btn");
    if (!hidden || !trigger) {{
      setStatus("Gradio verification bridge not ready", "warning");
      return [];
    }}

    hidden.value = cleanCode;
    hidden.dispatchEvent(new Event("input", {{ bubbles: true }}));
    trigger.click();
  }} catch (error) {{
    console.error("Generation failed", error);
    setStatus(`Generation failed: ${{formatBrowserError(error)}}`, "warning");
  }}
  return [];
}};

window.applyVerification = function(verdictJson) {{
  if (!verdictJson) return [];
  let verdict;
  try {{
    verdict = JSON.parse(verdictJson);
  }} catch (error) {{
    setStatus("Verifier returned invalid JSON", "warning");
    return [];
  }}
  if (!verdict || typeof verdict !== "object") {{
    setStatus("Verifier returned invalid JSON", "warning");
    return [];
  }}

  if (verdict.verdict === "PASS") {{
    setVerifierStatus("PASS");
    setStatus("Verified clean", "success");
  }} else if (verdict.verdict === "ERROR") {{
    setVerifierStatus("ERROR");
    setStatus(`Verifier failed: ${{verdict.reason || "unknown error"}}`, "warning");
  }} else {{
    verdict.corrected_code = cleanGeneratedCode(verdict.corrected_code || "");
    rollbackAndReplace(verdict.corrected_code, verdict.reason || "Verifier supplied a correction", verdict.verdict);
  }}
  return [];
}};
</script>
"""
|
|
|
|
async def verify_with_modal(prompt: str, draft_code: str, language: str) -> str:
    """Send a draft to the remote verifier and return its verdict as JSON text.

    Args:
        prompt: The user's original request.
        draft_code: Locally generated code; Markdown fences are stripped first.
        language: Display name of the target language; sent lower-cased.

    Returns:
        The verifier's raw response body on success. When ``MODAL_VERIFIER_URL``
        is not configured, a ``PASS`` verdict noting the local fallback. On any
        failure, a JSON object with ``"verdict": "ERROR"`` and a ``reason``.

    This function never raises: it is the boundary between the UI and the
    remote service, so every failure is converted into an ``ERROR`` verdict.
    """
    draft_code = strip_markdown_code_fence(draft_code)
    verifier_url = endpoint_url(MODAL_VERIFIER_URL, "/verify")
    if not verifier_url:
        return json.dumps(
            {
                "verdict": "PASS",
                "reason": "MODAL_VERIFIER_URL is not configured; local demo fallback used.",
            }
        )

    try:
        async with httpx.AsyncClient(timeout=180.0) as client:
            response = await client.post(
                verifier_url,
                json={"prompt": prompt, "draft_code": draft_code, "language": language.lower()},
            )
            response.raise_for_status()
            return response.text
    except Exception as exc:
        # Some exceptions (timeouts in particular) have an empty message, which
        # would leave the UI with nothing to show; fall back to the class name.
        message = str(exc)
        kind = type(exc).__name__
        reason = f"{kind}: {message}" if message else kind
        return json.dumps({"verdict": "ERROR", "reason": reason})
|
|
|
|
async def execute_in_sandbox(code: str) -> dict:
    """Run *code* in the remote sandbox and return its result mapping.

    Args:
        code: Source to execute; Markdown fences are stripped first.

    Returns:
        The JSON object returned by the sandbox endpoint. When the sandbox is
        not configured, the request fails, or the response is not a JSON
        object, a dict with empty ``stdout``, an explanatory ``stderr`` and
        ``returncode`` ``-1`` is returned instead, so callers can always render
        a result.
    """
    code = strip_markdown_code_fence(code)
    sandbox_url = endpoint_url(MODAL_SANDBOX_URL, "/execute")
    if not sandbox_url:
        return {"stdout": "", "stderr": "Sandbox not configured", "returncode": -1}

    try:
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.post(sandbox_url, json={"code": code})
            response.raise_for_status()
            result = response.json()
    except (httpx.HTTPError, httpx.InvalidURL, ValueError) as exc:
        # httpx.HTTPError covers transport and HTTP-status failures; ValueError
        # covers a body that is not valid JSON.
        detail = str(exc) or type(exc).__name__
        return {"stdout": "", "stderr": f"Sandbox request failed: {detail}", "returncode": -1}

    if not isinstance(result, dict):
        # Callers read the result with .get(); anything else is unusable.
        return {
            "stdout": "",
            "stderr": "Sandbox returned an unexpected response",
            "returncode": -1,
        }
    return result
|
|
|
|
def code_from_verdict(draft_code: str, verdict_json: str) -> str:
    """Pick the code to run: the verifier's correction if any, else the draft.

    Args:
        draft_code: Locally generated code (Markdown fences are stripped).
        verdict_json: JSON text produced by the verifier; may be empty.

    Returns:
        The fence-stripped ``corrected_code`` from the verdict when it is a
        non-empty string; otherwise the fence-stripped draft. An empty,
        malformed or non-object verdict falls back to the draft.
    """
    draft_code = strip_markdown_code_fence(draft_code)
    if not verdict_json:
        return draft_code

    try:
        verdict = json.loads(verdict_json)
    except (json.JSONDecodeError, TypeError):
        return draft_code

    # Valid JSON is not necessarily an object (null, list, string, number).
    if not isinstance(verdict, dict):
        return draft_code

    corrected = verdict.get("corrected_code")
    if not isinstance(corrected, str) or not corrected:
        corrected = draft_code
    return strip_markdown_code_fence(corrected)
|
|
|
|
# Opening fence at the very start of the text, with an optional language tag.
_LEADING_FENCE = re.compile(r"^```(?:[a-zA-Z0-9_+#.-]+)?\s*\n?")
# A fence that carries a language tag and nothing else on the rest of its line.
_TAGGED_FENCE = re.compile(r"```[a-zA-Z0-9_+#.-]+[ \t]*(?:\r?\n|$)")


def strip_markdown_code_fence(code: str) -> str:
    """Extract bare source code from model output that may contain Markdown.

    Handled shapes:
      * text starting with a fence (optionally tagged): the content up to the
        next fence is kept;
      * prose followed by a *tagged* fence on its own line (for example
        "Here is the code:" then a python-tagged fence): the content of that
        block is kept. A tagged fence is treated as an opening fence, since
        closing fences carry no language tag;
      * code followed by a bare fence: the text before the fence is kept;
      * no fence at all: the text is used as is.

    The result is finally passed through ``trim_markdown_explanation``.

    Args:
        code: Raw text; ``None`` or empty input yields ``""``.

    Returns:
        The extracted code with surrounding whitespace removed.
    """
    text = (code or "").strip()
    if not text:
        return ""

    leading = _LEADING_FENCE.match(text)
    if leading:
        body = text[leading.end():]
    else:
        first_fence = text.find("```")
        if first_fence < 0:
            return trim_markdown_explanation(text)

        # Only a fence that begins its line can open a block.
        line_start = text.rfind("\n", 0, first_fence) + 1
        begins_line = not text[line_start:first_fence].strip()
        tagged = _TAGGED_FENCE.match(text, first_fence) if begins_line else None
        if tagged is None:
            # Bare fence: treat it as the end of unfenced code.
            return trim_markdown_explanation(text[:first_fence])
        body = text[tagged.end():]

    closing = body.find("```")
    if closing >= 0:
        body = body[:closing]
    return trim_markdown_explanation(body)
|
|
|
|
# First line of trailing prose. List items and headings are only recognised at
# column 0, so indented code such as "    # comment" or a continuation line
# "    - b" is not mistaken for Markdown. Headings need at least two hashes
# because a single "# text" line is an ordinary comment in Python.
_EXPLANATION_LINE = re.compile(
    r"^(?:"
    r"(?:[-*]|\d+\.)\s+"
    r"|#{2,6}\s+"
    r"|\s*(?:Explanation\s*:|Steps\s*:|Notes?\s*:|The code\b|This code\b)"
    r")",
    re.IGNORECASE,
)


def trim_markdown_explanation(text: str) -> str:
    """Drop explanatory prose that follows the code in *text*.

    Lines are kept up to, but not including, the first line that looks like
    the start of an explanation: a Markdown bullet or numbered item at column
    0, a heading of level two or deeper, or a line beginning with phrases such
    as "Explanation:", "Steps:", "Note:", "The code" or "This code"
    (case-insensitive).

    Args:
        text: Code possibly followed by prose.

    Returns:
        The kept lines joined with newlines, stripped of surrounding
        whitespace.
    """
    kept = []
    for line in text.splitlines():
        if _EXPLANATION_LINE.match(line):
            break
        kept.append(line)
    return "\n".join(kept).strip()
|
|
|
|
async def run_sandbox(language: str, draft_code: str, verdict_json: str) -> str:
    """Execute the current code in the sandbox and format the outcome as text.

    Args:
        language: Selected language; only Python (any letter case) is run.
        draft_code: Locally generated draft.
        verdict_json: Verifier verdict, used to prefer corrected code.

    Returns:
        A plain-text report with the return code, stdout and stderr, or a
        short message when the language is unsupported or no code exists.
    """
    if language.lower() != "python":
        return "Sandbox execution is currently wired for Python only."

    runnable = prepare_python_for_sandbox(code_from_verdict(draft_code, verdict_json))
    if not runnable.strip():
        return "No generated code is available yet."

    outcome = await execute_in_sandbox(runnable)
    report = (
        f"returncode: {outcome.get('returncode', '')}",
        "",
        "stdout:",
        outcome.get("stdout", "") or "<empty>",
        "",
        "stderr:",
        outcome.get("stderr", "") or "<empty>",
    )
    return "\n".join(report)
|
|
|
|
def prepare_python_for_sandbox(code: str) -> str:
    """Make definition-only Python produce a run by calling an entry function.

    If the module has no statements that execute at top level, a
    ``if __name__ == "__main__":`` block is appended that calls the first
    top-level function which can be called without arguments. Code that does
    not parse, already executes something at top level, or has no such
    function is returned unchanged (after fence stripping).

    Docstrings and other bare constant expressions are not counted as
    execution: they do nothing when run, so a module that consists of a
    docstring plus definitions still gets its entry call.

    Args:
        code: Python source, possibly wrapped in Markdown fences.

    Returns:
        The source to send to the sandbox.
    """
    code = strip_markdown_code_fence(code)
    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError):
        # ValueError: e.g. source containing null bytes on some Python versions.
        return code

    executable_nodes = (
        ast.Assign,
        ast.AugAssign,
        ast.AnnAssign,
        ast.Assert,
        ast.Delete,
        ast.Expr,
        ast.For,
        ast.AsyncFor,
        ast.While,
        ast.If,
        ast.Match,
        ast.Raise,
        ast.Return,
        ast.Try,
        # try/except* is its own node type on interpreters that provide it.
        getattr(ast, "TryStar", ast.Try),
        ast.With,
        ast.AsyncWith,
    )

    def executes_at_top_level(node: ast.stmt) -> bool:
        # A bare constant (module docstring, stray literal) has no effect.
        if isinstance(node, ast.Expr) and isinstance(node.value, ast.Constant):
            return False
        return isinstance(node, executable_nodes)

    if any(executes_at_top_level(node) for node in tree.body):
        return code

    for node in tree.body:
        if isinstance(node, ast.FunctionDef) and function_has_no_required_args(node):
            return f'{code}\n\nif __name__ == "__main__":\n {node.name}()\n'

    return code
|
|
|
|
def function_has_no_required_args(node: ast.FunctionDef) -> bool:
    """Tell whether the function defined by *node* can be called as ``name()``.

    Args:
        node: A function definition from a parsed module.

    Returns:
        ``True`` when every positional parameter (positional-only or regular)
        and every keyword-only parameter has a default value. ``*args`` and
        ``**kwargs`` never require a value and are ignored.
    """
    signature = node.args

    # Defaults always belong to the trailing positional parameters, so all of
    # them are optional exactly when the two counts match.
    positional_count = len(signature.posonlyargs) + len(signature.args)
    if positional_count != len(signature.defaults):
        return False

    # kw_defaults holds None for each keyword-only parameter without a default.
    return all(default is not None for default in signature.kw_defaults)
|
|
|
|
# UI layout and event wiring.
#
# Flow: "Generate -> Verify" runs window.runLocalGeneration in the browser, which
# fills the hidden draft textbox and clicks the hidden verify button. That click
# calls the verifier on the server and stores its JSON in the hidden verdict
# textbox; the chained step then hands the verdict to window.applyVerification.
with gr.Blocks(
    title="Split-Brain Co-Pilot",
    css=load_static("style.css"),
    theme=gr.themes.Base(primary_hue="blue", neutral_hue="slate"),
) as demo:
    gr.HTML(
        """
        <section class="app-header">
          <p class="eyebrow">Build Small Hackathon · 15.5B parameters total</p>
          <h1>Split-Brain Co-Pilot</h1>
          <p>One small model drafts in your browser. Another small model checks it on Modal. The UI shows the handoff, verdict, rollback, and executable proof.</p>
          <div class="badge-row" aria-label="Project badges">
            <span>WebGPU local-first</span>
            <span>llama.cpp verifier</span>
            <span>Modal sandbox</span>
            <span>Custom Gradio UI</span>
          </div>
        </section>
        <div class="space-init" id="space-init">Space initializing...</div>
        <script>
          requestAnimationFrame(() => {
            const el = document.getElementById("space-init");
            if (el) el.hidden = true;
          });
        </script>
        """
    )

    with gr.Row(equal_height=False):
        with gr.Column(scale=2, min_width=320):
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Write a Python function that finds all prime numbers up to n using a segmented sieve, handling edge cases.",
                lines=6,
            )
            language_select = gr.Dropdown(choices=LANGUAGES, value="Python", label="Language")
            generate_btn = gr.Button("Generate -> Verify", variant="primary")
        with gr.Column(scale=3, min_width=420):
            gr.HTML(custom_html)
            # Bridge components: driven from JavaScript and styled through the
            # "bridge-hidden" class rather than used directly.
            draft_hidden = gr.Textbox(
                label="draft bridge",
                elem_id="draft-output-hidden",
                elem_classes=["bridge-hidden"],
            )
            verify_trigger = gr.Button(
                "verify",
                elem_id="trigger-verify-btn",
                elem_classes=["bridge-hidden"],
            )
            verdict_output = gr.Textbox(
                label="verdict",
                elem_classes=["bridge-hidden"],
            )

    with gr.Row():
        sandbox_btn = gr.Button("Run Python Sandbox", variant="secondary")
        sandbox_output = gr.Code(label="Sandbox Execution Output", language="shell")

    # Browser-only step: no server function, just the local generation script.
    generate_btn.click(
        fn=None,
        inputs=[prompt_input, language_select],
        outputs=[],
        js="(prompt, lang) => window.runLocalGeneration(prompt, lang)",
    )

    async def run_verification(prompt: str, draft_code: str, language: str) -> str:
        """Forward the draft to the remote verifier and return its JSON verdict."""
        return await verify_with_modal(prompt, draft_code, language)

    # Apply the verdict as a step chained after the verification call instead of
    # listening for a change of the verdict textbox: two runs that produce the
    # same verdict text do not change the value, so a change listener would not
    # fire and the UI would stay in its "checking" state.
    verify_trigger.click(
        fn=run_verification,
        inputs=[prompt_input, draft_hidden, language_select],
        outputs=[verdict_output],
    ).then(
        fn=None,
        inputs=[verdict_output],
        outputs=[],
        js="(verdict) => window.applyVerification(verdict)",
    )

    sandbox_btn.click(
        fn=run_sandbox,
        inputs=[language_select, draft_hidden, verdict_output],
        outputs=[sandbox_output],
    )
|
|
|
|
if __name__ == "__main__":
    # Host and port come from the environment, defaulting to a local-only
    # server on port 7860.
    host = os.environ.get("GRADIO_SERVER_NAME", "127.0.0.1")
    port = int(os.environ.get("GRADIO_SERVER_PORT", "7860"))
    demo.launch(server_name=host, server_port=port, show_api=False)
|
|