# Exported from the Hugging Face Space file viewer (raw view, 11.4 kB).
# Commit bf2fdae: "Fix deprecation warnings and improve error handling".
from __future__ import annotations
import json
import os
from typing import Any, Dict, List, Tuple
import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, pipeline, BitsAndBytesConfig
# Hugging Face access token; required because the router checkpoints are
# private repos. Fail fast at import time rather than mid-request.
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise RuntimeError("HF_TOKEN environment variable must be set for private router checkpoints.")

# System prompt prepended to every request; instructs the model to emit a
# strict-JSON routing plan with exactly the keys listed in REQUIRED_KEYS.
ROUTER_SYSTEM_PROMPT = """You are the Router Agent coordinating Math, Code, and General-Search specialists.\nEmit ONLY strict JSON with keys route_plan, route_rationale, expected_artifacts,\nthinking_outline, handoff_plan, todo_list, difficulty, tags, acceptance_criteria, metrics.\nEach route_plan entry must be a tool call (e.g., /math(...), /code(...), /general-search(...)).\nBe concise but precise. Do not include prose outside of the JSON object."""

# UI-facing model registry: display name -> hub repo id, blurb, and size.
MODELS = {
    "Router-Qwen3-32B-8bit": {
        "repo_id": "Alovestocode/router-qwen3-32b-merged",
        "description": "Router checkpoint on Qwen3 32B merged and quantized for 8-bit ZeroGPU inference.",
        "params_b": 32.0,
    },
    "Router-Gemma3-27B-8bit": {
        "repo_id": "Alovestocode/router-gemma3-merged",
        "description": "Router checkpoint on Gemma3 27B merged and quantized for 8-bit ZeroGPU inference.",
        "params_b": 27.0,
    },
}

# Keys every router plan must contain; checked by validate_router_plan().
REQUIRED_KEYS = [
    "route_plan",
    "route_rationale",
    "expected_artifacts",
    "thinking_outline",
    "handoff_plan",
    "todo_list",
    "difficulty",
    "tags",
    "acceptance_criteria",
    "metrics",
]

# Process-level cache of loaded text-generation pipelines, keyed by the
# display names in MODELS, so each checkpoint is loaded at most once.
PIPELINES: Dict[str, Any] = {}
def load_pipeline(model_name: str):
    """Load (and memoize) the text-generation pipeline for *model_name*.

    Loading strategy, in order:
      1. 8-bit quantization via BitsAndBytesConfig (preferred on ZeroGPU);
      2. full-weight load at bfloat16, then float16, then float32;
      3. a final attempt with the library's default precision.

    The original code spelled out the pipeline(...) call three times; the
    shared kwargs now live in one local helper so the fallback chain cannot
    drift out of sync.

    Returns the cached ``transformers`` pipeline; raises whatever the last
    loading attempt raises if every strategy fails.
    """
    if model_name in PIPELINES:
        return PIPELINES[model_name]
    repo = MODELS[model_name]["repo_id"]
    tokenizer = AutoTokenizer.from_pretrained(repo, token=HF_TOKEN)

    def _build(**extra: Any):
        # Single source of truth for the kwargs shared by every strategy.
        return pipeline(
            task="text-generation",
            model=repo,
            tokenizer=tokenizer,
            trust_remote_code=True,
            device_map="auto",
            use_cache=True,
            token=HF_TOKEN,
            **extra,
        )

    try:
        pipe = _build(
            model_kwargs={"quantization_config": BitsAndBytesConfig(load_in_8bit=True)}
        )
    except Exception as exc:
        print(f"8-bit load failed for {repo}: {exc}. Falling back to higher precision.")
        pipe = None
        for dtype in (torch.bfloat16, torch.float16, torch.float32):
            try:
                pipe = _build(dtype=dtype)
                break
            except Exception:
                continue
        if pipe is None:
            # Last resort: let transformers choose the default precision.
            pipe = _build()
    PIPELINES[model_name] = pipe
    return pipe
def build_router_prompt(
    user_task: str,
    context: str,
    acceptance: str,
    extra_guidance: str,
    difficulty: str,
    tags: str,
) -> str:
    """Assemble the full router prompt from the system prompt and UI inputs.

    Optional sections (acceptance criteria, extra guidance, supporting
    context) are included only when non-blank; difficulty and tags fall
    back to 'intermediate' and 'general' respectively.
    """
    sections: List[str] = [ROUTER_SYSTEM_PROMPT.strip(), "\n### Router Inputs\n"]
    sections += [
        f"Difficulty: {difficulty or 'intermediate'}",
        f"Tags: {tags or 'general'}",
    ]
    acceptance_text = acceptance.strip()
    if acceptance_text:
        sections.append(f"Acceptance criteria: {acceptance_text}")
    guidance_text = extra_guidance.strip()
    if guidance_text:
        sections.append(f"Additional guidance: {guidance_text}")
    context_text = context.strip()
    if context_text:
        sections.append("\n### Supporting context\n" + context_text)
    sections.append("\n### User task\n" + user_task.strip())
    sections.append("\nReturn only JSON.")
    return "\n".join(sections)
def extract_json_from_text(text: str) -> str:
    """Return the first balanced top-level JSON object embedded in *text*.

    Scans from the first '{', tracking brace depth while skipping braces
    that occur inside JSON string literals (including escaped quotes).

    Raises ValueError when no '{' exists or the object never closes.
    """
    opening = text.find("{")
    if opening < 0:
        raise ValueError("Router output did not contain a JSON object.")
    depth = 0
    inside_string = False
    escaped = False
    for position in range(opening, len(text)):
        char = text[position]
        if inside_string:
            # Inside a string literal: honor escapes, watch for the close quote.
            if escaped:
                escaped = False
            elif char == "\\":
                escaped = True
            elif char == '"':
                inside_string = False
        elif char == '"':
            inside_string = True
        elif char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                return text[opening : position + 1]
    raise ValueError("Router output JSON appears truncated.")
def validate_router_plan(plan: Dict[str, Any]) -> Tuple[bool, List[str]]:
    """Check *plan* for the required router fields and basic shapes.

    Returns (ok, issues): ok is True only when no issue was found; issues
    lists human-readable problems in a stable order (missing keys first).
    """
    problems: List[str] = [
        f"Missing key: {key}" for key in REQUIRED_KEYS if key not in plan
    ]
    route_steps = plan.get("route_plan")
    if not (isinstance(route_steps, list) and route_steps):
        problems.append("route_plan must be a non-empty list of tool calls")
    if not isinstance(plan.get("metrics"), dict):
        problems.append("metrics must be an object containing primary/secondary entries")
    checklist = plan.get("todo_list")
    if not (isinstance(checklist, list) and checklist):
        problems.append("todo_list must contain at least one checklist item")
    return not problems, problems
def format_validation_message(ok: bool, issues: List[str]) -> str:
    """Render the validation outcome as a short markdown status line."""
    if not ok:
        bullet_list = "\n".join(f"- {issue}" for issue in issues)
        return f"❌ Issues detected:\n{bullet_list}"
    return "✅ Router plan includes all required fields."
@spaces.GPU(duration=600)
def generate_router_plan(
    user_task: str,
    context: str,
    acceptance: str,
    extra_guidance: str,
    difficulty: str,
    tags: str,
    model_choice: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
) -> Tuple[str, Dict[str, Any], str, str]:
    """Generate, parse, and validate a router plan on ZeroGPU.

    Returns (raw completion, parsed plan dict, validation message, full
    prompt). On any generation failure the plan dict is empty and the
    validation message carries the error; the UI never sees a traceback.

    Raises gr.Error (surfaced as a toast) for missing task text or an
    unknown model choice.
    """
    if not user_task.strip():
        raise gr.Error("User task is required.")
    if model_choice not in MODELS:
        raise gr.Error(f"Invalid model choice: {model_choice}. Available: {list(MODELS.keys())}")
    try:
        prompt = build_router_prompt(
            user_task=user_task,
            context=context,
            acceptance=acceptance,
            extra_guidance=extra_guidance,
            difficulty=difficulty,
            tags=tags,
        )
        generator = load_pipeline(model_choice)
        # The temperature slider goes down to 0.0, but transformers rejects
        # temperature <= 0 when do_sample=True. Fall back to greedy decoding
        # in that case instead of crashing into the except branch below.
        do_sample = temperature > 0
        generation_kwargs: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
        }
        if do_sample:
            generation_kwargs["temperature"] = temperature
            generation_kwargs["top_p"] = top_p
        result = generator(prompt, **generation_kwargs)[0]["generated_text"]
        # text-generation pipelines echo the prompt by default; keep only
        # the newly generated suffix when that is the case.
        completion = result[len(prompt) :].strip() if result.startswith(prompt) else result.strip()
        try:
            json_block = extract_json_from_text(completion)
            plan = json.loads(json_block)
            ok, issues = validate_router_plan(plan)
            validation_msg = format_validation_message(ok, issues)
        except Exception as exc:
            plan = {}
            validation_msg = f"❌ JSON parsing failed: {exc}"
        return completion, plan, validation_msg, prompt
    except Exception as exc:
        # Boundary handler: report the failure in the UI instead of raising.
        error_msg = f"❌ Generation failed: {str(exc)}"
        return "", {}, error_msg, ""
def clear_outputs():
    """Reset the four output widgets (raw, plan, status, prompt)."""
    return ("", {}, "Awaiting generation.", "")
def build_ui():
    """Construct the Gradio Blocks interface for the router demo.

    Returns the assembled ``gr.Blocks`` app; the caller launches it.
    Component creation order must match the inputs/outputs lists wired
    to the buttons at the bottom.
    """
    description = "Use the CourseGPT-Pro router checkpoints (Gemma3/Qwen3) hosted on ZeroGPU to generate structured routing plans."
    # Monospace textareas plus the status colors used by the validation message.
    with gr.Blocks(theme=gr.themes.Soft(), css="""
textarea { font-family: 'JetBrains Mono', 'Fira Code', monospace; }
.status-ok { color: #0d9488; font-weight: 600; }
.status-bad { color: #dc2626; font-weight: 600; }
""") as demo:
        gr.Markdown("# 🛰️ Router Control Room — ZeroGPU" )
        gr.Markdown(description)
        with gr.Row():
            # Left column: free-text task inputs.
            with gr.Column(scale=3):
                user_task = gr.Textbox(
                    label="User Task / Problem Statement",
                    placeholder="Describe the homework-style query that needs routing...",
                    lines=8,
                    value="Explain how to solve a constrained optimization homework problem that mixes calculus and coding steps.",
                )
                context = gr.Textbox(
                    label="Supporting Context (optional)",
                    placeholder="Paste any retrieved evidence, PDFs, or rubric notes.",
                    lines=4,
                )
                acceptance = gr.Textbox(
                    label="Acceptance Criteria",
                    placeholder="Bullet list of 'definition of done' checks.",
                    lines=3,
                    value="- Provide citations for every claim.\n- Ensure /math verifies /code output.",
                )
                extra_guidance = gr.Textbox(
                    label="Additional Guidance",
                    placeholder="Special constraints, tools to avoid, etc.",
                    lines=3,
                )
            # Right column: model selection and decoding controls.
            with gr.Column(scale=2):
                model_choice = gr.Dropdown(
                    label="Router Checkpoint",
                    choices=list(MODELS.keys()),
                    value=list(MODELS.keys())[0] if MODELS else None,
                    allow_custom_value=False,
                )
                difficulty = gr.Radio(
                    label="Difficulty Tier",
                    choices=["introductory", "intermediate", "advanced"],
                    value="advanced",
                    interactive=True,
                )
                tags = gr.Textbox(
                    label="Tags",
                    placeholder="Comma-separated e.g. calculus, optimization, python",
                    value="calculus, optimization, python",
                )
                max_new_tokens = gr.Slider(256, 1024, value=640, step=32, label="Max New Tokens")
                # NOTE(review): the slider allows temperature 0.0, which the
                # generation call must tolerate — confirm against generate_router_plan.
                temperature = gr.Slider(0.0, 1.5, value=0.2, step=0.05, label="Temperature")
                top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
                generate_btn = gr.Button("Generate Router Plan", variant="primary")
                clear_btn = gr.Button("Clear", variant="secondary")
        # Output widgets: raw completion, parsed JSON, status, and the prompt
        # actually sent — the four-tuple returned by generate_router_plan.
        with gr.Row():
            raw_output = gr.Textbox(label="Raw Model Output", lines=12)
            plan_json = gr.JSON(label="Parsed Router Plan")
        validation_msg = gr.Markdown("Awaiting generation.")
        prompt_view = gr.Textbox(label="Full Prompt", lines=10)
        # Wiring: input order here must match generate_router_plan's signature.
        generate_btn.click(
            generate_router_plan,
            inputs=[
                user_task,
                context,
                acceptance,
                extra_guidance,
                difficulty,
                tags,
                model_choice,
                max_new_tokens,
                temperature,
                top_p,
            ],
            outputs=[raw_output, plan_json, validation_msg, prompt_view],
        )
        clear_btn.click(fn=clear_outputs, outputs=[raw_output, plan_json, validation_msg, prompt_view])
    return demo
# Build the app at import time so Spaces runtimes can pick up `demo` directly.
demo = build_ui()

if __name__ == "__main__":  # pragma: no cover
    # Bind all interfaces; Spaces injects PORT, defaulting to 7860 locally.
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))