Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import json | |
| import os | |
| from typing import Any, Dict, List, Tuple | |
| import gradio as gr | |
| import spaces | |
| import torch | |
| from transformers import AutoTokenizer, pipeline, BitsAndBytesConfig | |
# Access token passed to every Hugging Face download below. The module refuses
# to import without it, so a missing secret is reported at startup.
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN is None or HF_TOKEN == "":
    raise RuntimeError("HF_TOKEN environment variable must be set for private router checkpoints.")
# System instructions placed at the top of every router prompt. Written as
# adjacent literals (one per line of the prompt) so each line stays readable.
ROUTER_SYSTEM_PROMPT = (
    "You are the Router Agent coordinating Math, Code, and General-Search specialists.\n"
    "Emit ONLY strict JSON with keys route_plan, route_rationale, expected_artifacts,\n"
    "thinking_outline, handoff_plan, todo_list, difficulty, tags, acceptance_criteria, metrics.\n"
    "Each route_plan entry must be a tool call (e.g., /math(...), /code(...), /general-search(...)).\n"
    "Be concise but precise. Do not include prose outside of the JSON object."
)
# Registry of selectable router checkpoints, keyed by the label shown in the UI.
# Each entry carries the Hub repository id, a human-readable description and
# the parameter count in billions.
MODELS = {
    "Router-Qwen3-32B-8bit": dict(
        repo_id="Alovestocode/router-qwen3-32b-merged",
        description="Router checkpoint on Qwen3 32B merged and quantized for 8-bit ZeroGPU inference.",
        params_b=32.0,
    ),
    "Router-Gemma3-27B-8bit": dict(
        repo_id="Alovestocode/router-gemma3-merged",
        description="Router checkpoint on Gemma3 27B merged and quantized for 8-bit ZeroGPU inference.",
        params_b=27.0,
    ),
}
# Top-level keys every router plan must contain; the same names are listed in
# ROUTER_SYSTEM_PROMPT and checked by validate_router_plan.
REQUIRED_KEYS = [
    # routing decision
    "route_plan", "route_rationale",
    # planning artefacts
    "expected_artifacts", "thinking_outline", "handoff_plan", "todo_list",
    # task metadata and success measures
    "difficulty", "tags", "acceptance_criteria", "metrics",
]
# Process-wide cache of loaded pipelines, keyed by the MODELS label; filled
# lazily by load_pipeline.
PIPELINES: Dict[str, Any] = dict()
def load_pipeline(model_name: str):
    """Return the text-generation pipeline for ``model_name``, loading it on first use.

    Load attempts are made in this order and stop at the first success:
    8-bit bitsandbytes quantization, then ``bfloat16``, ``float16`` and
    ``float32``, and finally a call that leaves the precision to the library.
    The resulting pipeline is stored in ``PIPELINES`` so later calls for the
    same name return immediately.

    Args:
        model_name: Key into ``MODELS``.

    Returns:
        The loaded ``transformers`` pipeline.

    Raises:
        KeyError: If ``model_name`` is not a key of ``MODELS``.
        Exception: Whatever the last, default-precision attempt raises when
            every earlier attempt has failed as well.
    """
    if model_name in PIPELINES:
        return PIPELINES[model_name]

    repo = MODELS[model_name]["repo_id"]
    tokenizer = AutoTokenizer.from_pretrained(repo, token=HF_TOKEN)

    # Arguments common to every load attempt; only the precision-related
    # keyword differs between attempts.
    base_kwargs: Dict[str, Any] = {
        "task": "text-generation",
        "model": repo,
        "tokenizer": tokenizer,
        # NOTE(review): trust_remote_code runs Python shipped inside the model
        # repository; acceptable only because the repos are our own.
        "trust_remote_code": True,
        "device_map": "auto",
        "use_cache": True,
        "token": HF_TOKEN,
    }

    try:
        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
        pipe = pipeline(
            **base_kwargs,
            model_kwargs={"quantization_config": quantization_config},
        )
    except Exception as exc:
        print(f"8-bit load failed for {repo}: {exc}. Falling back to higher precision.")
    else:
        PIPELINES[model_name] = pipe
        return pipe

    for dtype in (torch.bfloat16, torch.float16, torch.float32):
        try:
            pipe = pipeline(**base_kwargs, dtype=dtype)
        except Exception as exc:
            # Report the failure before moving on so a broken checkpoint or an
            # out-of-memory condition can be diagnosed from the logs.
            print(f"{dtype} load failed for {repo}: {exc}. Trying next precision.")
            continue
        PIPELINES[model_name] = pipe
        return pipe

    # Last resort: let the library choose the precision. Any exception from
    # this call propagates to the caller.
    pipe = pipeline(**base_kwargs)
    PIPELINES[model_name] = pipe
    return pipe
def build_router_prompt(
    user_task: str,
    context: str,
    acceptance: str,
    extra_guidance: str,
    difficulty: str,
    tags: str,
) -> str:
    """Assemble the full text prompt sent to the router model.

    The prompt starts with ``ROUTER_SYSTEM_PROMPT``, followed by a
    "Router Inputs" section (difficulty, tags and, when non-blank, acceptance
    criteria and additional guidance), an optional "Supporting context"
    section, the user task, and a closing reminder to return only JSON.
    Sections are joined with single newlines.

    Args:
        user_task: The problem statement; always included (stripped).
        context: Optional supporting material; omitted when blank.
        acceptance: Optional acceptance criteria; omitted when blank.
        extra_guidance: Optional extra constraints; omitted when blank.
        difficulty: Difficulty label; ``'intermediate'`` is used when falsy.
        tags: Tag string; ``'general'`` is used when falsy.

    Returns:
        The assembled prompt.
    """
    sections = [
        ROUTER_SYSTEM_PROMPT.strip(),
        "\n### Router Inputs\n",
        f"Difficulty: {difficulty or 'intermediate'}",
        f"Tags: {tags or 'general'}",
    ]

    acceptance_text = acceptance.strip()
    if acceptance_text:
        sections.append(f"Acceptance criteria: {acceptance_text}")

    guidance_text = extra_guidance.strip()
    if guidance_text:
        sections.append(f"Additional guidance: {guidance_text}")

    context_text = context.strip()
    if context_text:
        sections.append("\n### Supporting context\n" + context_text)

    sections.append("\n### User task\n" + user_task.strip())
    sections.append("\nReturn only JSON.")
    return "\n".join(sections)
def _matching_brace_index(text: str, start: int) -> int:
    """Return the index of the ``}`` that closes the ``{`` at ``start``, or -1.

    Braces inside double-quoted strings are ignored, and backslash escapes
    inside strings are honoured so an escaped quote does not end the string.
    """
    depth = 0
    in_string = False
    escaped = False
    for idx in range(start, len(text)):
        ch = text[idx]
        if in_string:
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = False
        elif ch == '"':
            in_string = True
        elif ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return idx
    return -1


def extract_json_from_text(text: str) -> str:
    """Return the JSON object embedded in ``text`` as a substring.

    Balanced ``{...}`` spans are examined left to right and the first one that
    parses as JSON is returned, so braces in surrounding prose (for example
    ``{placeholder}``) do not hide the real object. When no balanced span
    parses, the first balanced span is returned unchanged so the caller's own
    ``json.loads`` reports the parse error.

    Args:
        text: Raw model output.

    Returns:
        The selected ``{...}`` substring.

    Raises:
        ValueError: If ``text`` contains no ``{``, or if the first ``{`` is
            never closed (the output looks truncated).
    """
    start = text.find("{")
    if start == -1:
        raise ValueError("Router output did not contain a JSON object.")

    first_balanced = None
    while start != -1:
        end = _matching_brace_index(text, start)
        if end == -1:
            # An unclosed object: nothing further to the right can be trusted.
            break
        candidate = text[start : end + 1]
        try:
            json.loads(candidate)
        except ValueError:
            if first_balanced is None:
                first_balanced = candidate
            # Resume after this span so its nested objects are not returned
            # as if they were the top-level plan.
            start = text.find("{", end + 1)
        else:
            return candidate

    if first_balanced is None:
        raise ValueError("Router output JSON appears truncated.")
    return first_balanced
def validate_router_plan(plan: Dict[str, Any]) -> Tuple[bool, List[str]]:
    """Check a parsed router plan for the fields the UI relies on.

    Reports every key of ``REQUIRED_KEYS`` that is absent, then checks that
    ``route_plan`` and ``todo_list`` are non-empty lists and that ``metrics``
    is a dict. A missing ``route_plan``, ``metrics`` or ``todo_list`` therefore
    yields both a "Missing key" entry and the corresponding type message.

    Args:
        plan: The decoded JSON object.

    Returns:
        ``(ok, issues)`` where ``ok`` is True exactly when ``issues`` is empty.
    """
    problems: List[str] = [
        f"Missing key: {name}" for name in REQUIRED_KEYS if name not in plan
    ]

    steps = plan.get("route_plan")
    if not (isinstance(steps, list) and steps):
        problems.append("route_plan must be a non-empty list of tool calls")

    if not isinstance(plan.get("metrics"), dict):
        problems.append("metrics must be an object containing primary/secondary entries")

    checklist = plan.get("todo_list")
    if not (isinstance(checklist, list) and checklist):
        problems.append("todo_list must contain at least one checklist item")

    return not problems, problems
def format_validation_message(ok: bool, issues: List[str]) -> str:
    """Render a validation result as the Markdown status line shown in the UI.

    Args:
        ok: Whether validation passed.
        issues: Problem descriptions; only used when ``ok`` is false.

    Returns:
        A success line, or a failure header followed by one ``- `` bullet per
        issue on separate lines.
    """
    if not ok:
        header = "❌ Issues detected:\n"
        return header + "\n".join("- {}".format(issue) for issue in issues)
    return "✅ Router plan includes all required fields."
def generate_router_plan(
    user_task: str,
    context: str,
    acceptance: str,
    extra_guidance: str,
    difficulty: str,
    tags: str,
    model_choice: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
) -> Tuple[str, Dict[str, Any], str, str]:
    """Generate a router plan with the chosen checkpoint and validate it.

    Builds the prompt, runs the text-generation pipeline, strips the echoed
    prompt from the output when present, then extracts, parses and validates
    the JSON plan.

    Args:
        user_task: Problem statement; must not be blank.
        context: Optional supporting context.
        acceptance: Optional acceptance criteria.
        extra_guidance: Optional additional guidance.
        difficulty: Difficulty label forwarded to the prompt.
        tags: Tag string forwarded to the prompt.
        model_choice: Key into ``MODELS``.
        max_new_tokens: Generation length limit.
        temperature: Sampling temperature; a value of 0 (or less) selects
            greedy decoding.
        top_p: Nucleus-sampling threshold, used only when sampling.

    Returns:
        ``(completion, plan, validation_message, prompt)``. When parsing
        fails, ``plan`` is ``{}`` and the message describes the parse error.
        When prompt building, model loading or generation fails, the result
        is ``("", {}, "❌ Generation failed: ...", "")``.

    Raises:
        gr.Error: If ``user_task`` is blank or ``model_choice`` is unknown.
    """
    if not user_task.strip():
        raise gr.Error("User task is required.")
    if model_choice not in MODELS:
        raise gr.Error(f"Invalid model choice: {model_choice}. Available: {list(MODELS.keys())}")

    try:
        prompt = build_router_prompt(
            user_task=user_task,
            context=context,
            acceptance=acceptance,
            extra_guidance=extra_guidance,
            difficulty=difficulty,
            tags=tags,
        )
        # NOTE(review): `spaces` is imported by this file but nothing here is
        # wrapped in `spaces.GPU` -- confirm whether the ZeroGPU runtime needs it.
        generator = load_pipeline(model_choice)

        # The UI slider can deliver 0.0, but sampling requires a strictly
        # positive temperature; treat zero as a request for greedy decoding.
        generation_kwargs: Dict[str, Any] = {"max_new_tokens": int(max_new_tokens)}
        if temperature and temperature > 0:
            generation_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
        else:
            generation_kwargs["do_sample"] = False

        result = generator(prompt, **generation_kwargs)[0]["generated_text"]
        # The pipeline may echo the prompt in front of the completion.
        completion = result[len(prompt) :].strip() if result.startswith(prompt) else result.strip()

        try:
            json_block = extract_json_from_text(completion)
            plan = json.loads(json_block)
            ok, issues = validate_router_plan(plan)
            validation_msg = format_validation_message(ok, issues)
        except Exception as exc:
            # Still return the raw completion so the user can see what the
            # model produced.
            plan = {}
            validation_msg = f"❌ JSON parsing failed: {exc}"
        return completion, plan, validation_msg, prompt
    except Exception as exc:
        # UI boundary: report the failure in the status area instead of
        # letting the exception escape the handler.
        error_msg = f"❌ Generation failed: {str(exc)}"
        return "", {}, error_msg, ""
def clear_outputs():
    """Return the initial values for the four output widgets.

    Returns:
        ``(raw_output, plan, status_message, prompt)`` in their reset state.
    """
    raw_text, parsed_plan, status, prompt_text = "", {}, "Awaiting generation.", ""
    return raw_text, parsed_plan, status, prompt_text
def build_ui():
    """Construct the Gradio Blocks interface and wire up its event handlers.

    The page has a row of inputs (task, context, acceptance criteria and
    guidance on the left; checkpoint, difficulty, tags and sampling controls
    on the right), the Generate and Clear buttons, and the outputs (raw model
    text, parsed plan, validation status and the full prompt).

    Returns:
        The ``gr.Blocks`` application object.
    """
    # NOTE(review): indentation was not available for this function; the
    # nesting below (in particular the buttons sitting under the input row
    # rather than inside the right-hand column) is reconstructed -- confirm.
    description = "Use the CourseGPT-Pro router checkpoints (Gemma3/Qwen3) hosted on ZeroGPU to generate structured routing plans."
    custom_css = (
        "\n"
        "textarea { font-family: 'JetBrains Mono', 'Fira Code', monospace; }\n"
        ".status-ok { color: #0d9488; font-weight: 600; }\n"
        ".status-bad { color: #dc2626; font-weight: 600; }\n"
    )
    model_names = list(MODELS)

    with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
        gr.Markdown("# 🛰️ Router Control Room — ZeroGPU")
        gr.Markdown(description)

        with gr.Row():
            # Left column: free-text inputs describing the task.
            with gr.Column(scale=3):
                user_task = gr.Textbox(
                    label="User Task / Problem Statement",
                    placeholder="Describe the homework-style query that needs routing...",
                    lines=8,
                    value="Explain how to solve a constrained optimization homework problem that mixes calculus and coding steps.",
                )
                context = gr.Textbox(
                    label="Supporting Context (optional)",
                    placeholder="Paste any retrieved evidence, PDFs, or rubric notes.",
                    lines=4,
                )
                acceptance = gr.Textbox(
                    label="Acceptance Criteria",
                    placeholder="Bullet list of 'definition of done' checks.",
                    lines=3,
                    value="- Provide citations for every claim.\n- Ensure /math verifies /code output.",
                )
                extra_guidance = gr.Textbox(
                    label="Additional Guidance",
                    placeholder="Special constraints, tools to avoid, etc.",
                    lines=3,
                )
            # Right column: checkpoint selection and generation settings.
            with gr.Column(scale=2):
                model_choice = gr.Dropdown(
                    label="Router Checkpoint",
                    choices=model_names,
                    value=model_names[0] if model_names else None,
                    allow_custom_value=False,
                )
                difficulty = gr.Radio(
                    label="Difficulty Tier",
                    choices=["introductory", "intermediate", "advanced"],
                    value="advanced",
                    interactive=True,
                )
                tags = gr.Textbox(
                    label="Tags",
                    placeholder="Comma-separated e.g. calculus, optimization, python",
                    value="calculus, optimization, python",
                )
                max_new_tokens = gr.Slider(256, 1024, value=640, step=32, label="Max New Tokens")
                temperature = gr.Slider(0.0, 1.5, value=0.2, step=0.05, label="Temperature")
                top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")

        generate_btn = gr.Button("Generate Router Plan", variant="primary")
        clear_btn = gr.Button("Clear", variant="secondary")

        with gr.Row():
            raw_output = gr.Textbox(label="Raw Model Output", lines=12)
            plan_json = gr.JSON(label="Parsed Router Plan")
        validation_msg = gr.Markdown("Awaiting generation.")
        prompt_view = gr.Textbox(label="Full Prompt", lines=10)

        # Input order must match the parameter order of generate_router_plan;
        # output order must match the tuples returned by both handlers.
        input_widgets = [
            user_task,
            context,
            acceptance,
            extra_guidance,
            difficulty,
            tags,
            model_choice,
            max_new_tokens,
            temperature,
            top_p,
        ]
        output_widgets = [raw_output, plan_json, validation_msg, prompt_view]

        generate_btn.click(
            fn=generate_router_plan,
            inputs=input_widgets,
            outputs=output_widgets,
        )
        clear_btn.click(fn=clear_outputs, outputs=output_widgets)

    return demo
# Module-level application object, built once when the module is imported.
demo = build_ui()

if __name__ == "__main__":  # pragma: no cover
    # Listen on all interfaces; the port comes from PORT and defaults to 7860.
    listen_port = int(os.environ.get("PORT", 7860))
    demo.launch(server_name="0.0.0.0", server_port=listen_port)