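"""Router Control Room: a Gradio ZeroGPU Space for the CourseGPT-Pro router checkpoints.

Loads the merged Router-Qwen3-32B / Router-Gemma3-27B checkpoints (8-bit when
possible), streams a strict-JSON routing plan for a user task, and validates
the plan against the required schema before displaying it.
"""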
from __future__ import annotations

import json
import os
import re
from typing import Any, Dict, List, Tuple

import gradio as gr
import spaces
import torch
from transformers import (
    AutoTokenizer,
    BitsAndBytesConfig,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
    pipeline,
)
from threading import Thread

HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise RuntimeError("HF_TOKEN environment variable must be set for private router checkpoints.")

PLAN_END_TOKEN = "<|end_of_plan|>"
STOP_SEQUENCES = [PLAN_END_TOKEN, "</json>", "</JSON>"]

ROUTER_SYSTEM_PROMPT = """You are the Router Agent coordinating Math, Code, and General-Search specialists.
Emit EXACTLY ONE strict JSON object with keys route_plan, route_rationale, expected_artifacts,
thinking_outline, handoff_plan, todo_list, difficulty, tags, acceptance_criteria, metrics.
Rules:
- No markdown/code fences, no natural-language prologues or epilogues.
- route_plan must be an ordered list of tool invocations such as /math(...), /code(...), /general-search(...).
- todo_list must map each checklist item to the responsible tool.
- metrics must include primary and secondary arrays (add optional *_guidance fields when they exist).
- After the closing brace of the JSON object, immediately append the sentinel <|end_of_plan|>.
Example output:
{
  "route_plan": ["/general-search(...)"],
  "route_rationale": "...",
  ...
}<|end_of_plan|>
Return nothing else."""

MODELS = {
    "Router-Qwen3-32B-8bit": {
        "repo_id": "Alovestocode/router-qwen3-32b-merged",
        "description": "Router checkpoint on Qwen3 32B merged and quantized for 8-bit ZeroGPU inference.",
        "params_b": 32.0,
    },
    "Router-Gemma3-27B-8bit": {
        "repo_id": "Alovestocode/router-gemma3-merged",
        "description": "Router checkpoint on Gemma3 27B merged and quantized for 8-bit ZeroGPU inference.",
        "params_b": 27.0,
    },
}

REQUIRED_KEYS = [
    "route_plan",
    "route_rationale",
    "expected_artifacts",
    "thinking_outline",
    "handoff_plan",
    "todo_list",
    "difficulty",
    "tags",
    "acceptance_criteria",
    "metrics",
]

PIPELINES: Dict[str, Any] = {}
TOKENIZER_CACHE: Dict[str, Any] = {}
WARMED_REMAINING = False
TOOL_PATTERN = re.compile(r"^/[a-z0-9_-]+\(.*\)$", re.IGNORECASE)
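
# generate() does not honor STOP_SEQUENCES on its own: without a stopping
# criterion the worker thread keeps producing tokens until max_new_tokens even
# after the <|end_of_plan|> sentinel appears, and thread.join() blocks on it.
# A minimal sketch of a server-side stop; it decodes only a short tail of the
# output each step and assumes every stop sequence fits inside that window.
class StopOnSequences(StoppingCriteria):
    def __init__(self, tokenizer: Any, stops: List[str], tail_tokens: int = 16) -> None:
        self.tokenizer = tokenizer
        self.stops = stops
        self.tail_tokens = tail_tokens

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs: Any) -> bool:
        # Decode the last few generated tokens and check for any stop marker.
        tail = self.tokenizer.decode(input_ids[0][-self.tail_tokens:], skip_special_tokens=False)
        return any(stop in tail for stop in self.stops)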


def get_tokenizer(repo: str):
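    """Return a cached tokenizer configured for left padding and truncation."""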
    tok = TOKENIZER_CACHE.get(repo)
    if tok is not None:
        return tok
    tok = AutoTokenizer.from_pretrained(repo, token=HF_TOKEN)
    tok.padding_side = "left"
    tok.truncation_side = "left"
    if tok.pad_token_id is None and tok.eos_token_id is not None:
        tok.pad_token_id = tok.eos_token_id
    TOKENIZER_CACHE[repo] = tok
    return tok


def load_pipeline(model_name: str):
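    """Return a cached text-generation pipeline, trying 8-bit first, then bf16/fp16/fp32."""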
    if model_name in PIPELINES:
        return PIPELINES[model_name]

    repo = MODELS[model_name]["repo_id"]
    tokenizer = get_tokenizer(repo)

    try:
        quant_config = BitsAndBytesConfig(load_in_8bit=True)
        pipe = pipeline(
            task="text-generation",
            model=repo,
            tokenizer=tokenizer,
            trust_remote_code=True,
            device_map="auto",
            model_kwargs={"quantization_config": quant_config},
            use_cache=True,
            token=HF_TOKEN,
        )
        pipe.model.eval()
        PIPELINES[model_name] = pipe
        _schedule_background_warm(model_name)
        return pipe
    except Exception as exc:
        print(f"8-bit load failed for {repo}: {exc}. Falling back to higher precision.")

    for dtype in (torch.bfloat16, torch.float16, torch.float32):
        try:
            pipe = pipeline(
                task="text-generation",
                model=repo,
                tokenizer=tokenizer,
                trust_remote_code=True,
                device_map="auto",
                dtype=dtype,
                use_cache=True,
                token=HF_TOKEN,
            )
            pipe.model.eval()
            PIPELINES[model_name] = pipe
            _schedule_background_warm(model_name)
            return pipe
        except Exception:
            continue

    pipe = pipeline(
        task="text-generation",
        model=repo,
        tokenizer=tokenizer,
        trust_remote_code=True,
        device_map="auto",
        use_cache=True,
        token=HF_TOKEN,
    )
    pipe.model.eval()
    PIPELINES[model_name] = pipe
    _schedule_background_warm(model_name)
    return pipe


def _schedule_background_warm(loaded_model: str) -> None:
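    """Kick off a one-time daemon thread that warms the remaining checkpoints.

    Disabled when ROUTER_WARM_REMAINING is set to anything other than 1/true.
    """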
    global WARMED_REMAINING
    if WARMED_REMAINING:
        return
    warm_remaining = os.environ.get("ROUTER_WARM_REMAINING", "1")
    if warm_remaining not in {"1", "true", "True"}:
        return

    remaining = [name for name in MODELS if name not in PIPELINES]
    if not remaining:
        WARMED_REMAINING = True
        return

    def _warm_all():
        # Without this declaration the assignment below would only create a
        # local variable: the enclosing function's `global` does not reach
        # into this nested function.
        global WARMED_REMAINING
        for name in remaining:
            try:
                print(f"Background warm start for {name}")
                load_pipeline(name)
            except Exception as exc:  # pragma: no cover
                print(f"Warm start failed for {name}: {exc}")
        WARMED_REMAINING = True

    Thread(target=_warm_all, daemon=True).start()


def build_router_prompt(
    user_task: str,
    context: str,
    acceptance: str,
    extra_guidance: str,
    difficulty: str,
    tags: str,
) -> str:
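    """Assemble the full router prompt from the system prompt and user-supplied fields."""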
    prompt_parts = [ROUTER_SYSTEM_PROMPT.strip(), "\n### Router Inputs\n"]
    prompt_parts.append(f"Difficulty: {difficulty or 'intermediate'}")
    prompt_parts.append(f"Tags: {tags or 'general'}")
    if acceptance.strip():
        prompt_parts.append(f"Acceptance criteria: {acceptance.strip()}")
    if extra_guidance.strip():
        prompt_parts.append(f"Additional guidance: {extra_guidance.strip()}")
    if context.strip():
        prompt_parts.append("\n### Supporting context\n" + context.strip())
    prompt_parts.append("\n### User task\n" + user_task.strip())
    prompt_parts.append("\nReturn only JSON.")
    return "\n".join(prompt_parts)


def extract_json_from_text(text: str) -> str:
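    """Return the first balanced top-level JSON object in text, honoring strings and escapes."""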
    start = text.find("{")
    if start == -1:
        raise ValueError("Router output did not contain a JSON object.")
    depth = 0
    in_string = False
    escape = False
    for idx in range(start, len(text)):
        ch = text[idx]
        if in_string:
            if escape:
                escape = False
            elif ch == "\\":
                escape = True
            elif ch == '"':
                in_string = False
            continue
        if ch == '"':
            in_string = True
            continue
        if ch == '{':
            depth += 1
        elif ch == '}':
            depth -= 1
            if depth == 0:
                return text[start : idx + 1]
    raise ValueError("Router output JSON appears truncated.")


def is_function_call(text: str) -> bool:
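    """Return True when text is a single tool invocation such as /math(...)."""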
    return bool(TOOL_PATTERN.match(text.strip()))


def validate_router_plan(plan: Dict[str, Any]) -> Tuple[bool, List[str]]:
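    """Validate the plan against REQUIRED_KEYS, normalizing route_plan and todo_list in place."""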
    issues: List[str] = []
    for key in REQUIRED_KEYS:
        if key not in plan:
            issues.append(f"Missing key: {key}")

    route_plan = plan.get("route_plan")
    if isinstance(route_plan, str) and is_function_call(route_plan):
        plan["route_plan"] = [route_plan]
        route_plan = plan["route_plan"]
    if not isinstance(route_plan, list) or not route_plan:
        issues.append("route_plan must be a non-empty list of tool calls")
    else:
        cleaned: List[str] = []
        for entry in route_plan:
            if isinstance(entry, str) and is_function_call(entry.strip().strip("'\"")):
                cleaned.append(entry.strip().strip("'\""))
            else:
                issues.append(f"route_plan entry is not a tool call: {entry}")
        if cleaned:
            plan["route_plan"] = cleaned

    metrics = plan.get("metrics")
    if not isinstance(metrics, dict):
        issues.append("metrics must be an object containing primary/secondary entries")
    todo = plan.get("todo_list")
    if not isinstance(todo, list) or not todo:
        issues.append("todo_list must contain at least one checklist item")
    else:
        cleaned_todo: List[str] = []
        for entry in todo:
            if isinstance(entry, str):
                text = entry.strip()
                if not text.startswith("- ["):
                    text = text.lstrip("- ")
                    text = f"- [ ] {text}"
                cleaned_todo.append(text)
            else:
                issues.append("todo_list entry must be a string")
        if cleaned_todo:
            plan["todo_list"] = cleaned_todo

    return len(issues) == 0, issues


def format_validation_message(ok: bool, issues: List[str]) -> str:
    if ok:
        return "βœ… Router plan includes all required fields."
    bullets = "\n".join(f"- {issue}" for issue in issues)
    return f"❌ Issues detected:\n{bullets}"


@spaces.GPU(duration=600)
def generate_router_plan_streaming(
    user_task: str,
    context: str,
    acceptance: str,
    extra_guidance: str,
    difficulty: str,
    tags: str,
    model_choice: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
):
    """Generator function for streaming token output."""
    if not user_task.strip():
        yield "", {}, "❌ User task is required.", ""
        return
    
    if model_choice not in MODELS:
        yield "", {}, f"❌ Invalid model choice: {model_choice}. Available: {list(MODELS.keys())}", ""
        return

    try:
        prompt = build_router_prompt(
            user_task=user_task,
            context=context,
            acceptance=acceptance,
            extra_guidance=extra_guidance,
            difficulty=difficulty,
            tags=tags,
        )

        generator = load_pipeline(model_choice)
        
        # Get the underlying model and tokenizer
        model = generator.model
        tokenizer = generator.tokenizer
        
        # Set up streaming
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        
        # Prepare inputs
        inputs = tokenizer(prompt, return_tensors="pt")
        if hasattr(model, 'device'):
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
        elif torch.cuda.is_available():
            inputs = {k: v.cuda() for k, v in inputs.items()}
        
        # Start generation in a separate thread
        generation_kwargs = {
            **inputs,
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "do_sample": True,
            "streamer": streamer,
            "eos_token_id": tokenizer.eos_token_id,
            "pad_token_id": tokenizer.pad_token_id or tokenizer.eos_token_id,
        }

        def _generate():
            with torch.inference_mode():
                model.generate(**generation_kwargs)

        thread = Thread(target=_generate)
        thread.start()
        
        # Stream tokens
        completion = ""
        parsed_plan: Dict[str, Any] | None = None
        validation_msg = "πŸ”„ Generating..."

        for new_text in streamer:
            completion += new_text
            chunk = completion
            finished = False
            display_plan = parsed_plan or {}

            chunk, finished = trim_at_stop_sequences(chunk)

            try:
                json_block = extract_json_from_text(chunk)
                candidate_plan = json.loads(json_block)
                ok, issues = validate_router_plan(candidate_plan)
                validation_msg = format_validation_message(ok, issues)
                parsed_plan = candidate_plan if ok else parsed_plan
                display_plan = candidate_plan
            except Exception:
                # Ignore until JSON is complete
                pass

            yield chunk, display_plan, validation_msg, prompt

            if finished:
                completion = chunk
                break

        # Final processing after streaming completes
        thread.join()

        completion = trim_at_stop_sequences(completion.strip())[0]
        if parsed_plan is None:
            try:
                json_block = extract_json_from_text(completion)
                parsed_plan = json.loads(json_block)
                ok, issues = validate_router_plan(parsed_plan)
                validation_msg = format_validation_message(ok, issues)
            except Exception as exc:
                parsed_plan = {}
                validation_msg = f"❌ JSON parsing failed: {exc}"

        yield completion, parsed_plan, validation_msg, prompt
        
    except Exception as exc:
        error_msg = f"❌ Generation failed: {str(exc)}"
        yield "", {}, error_msg, ""


def clear_outputs():
    return "", {}, "Awaiting generation.", ""


def build_ui():
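    """Build the Gradio Blocks UI and wire the streaming-generate and clear callbacks."""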
    description = "Use the CourseGPT-Pro router checkpoints (Gemma3/Qwen3) hosted on ZeroGPU to generate structured routing plans."
    with gr.Blocks(theme=gr.themes.Soft(), css="""
        textarea { font-family: 'JetBrains Mono', 'Fira Code', monospace; }
        .status-ok { color: #0d9488; font-weight: 600; }
        .status-bad { color: #dc2626; font-weight: 600; }
    """) as demo:
        gr.Markdown("# πŸ›°οΈ Router Control Room β€” ZeroGPU" )
        gr.Markdown(description)

        with gr.Row():
            with gr.Column(scale=3):
                user_task = gr.Textbox(
                    label="User Task / Problem Statement",
                    placeholder="Describe the homework-style query that needs routing...",
                    lines=8,
                    value="Explain how to solve a constrained optimization homework problem that mixes calculus and coding steps.",
                )
                context = gr.Textbox(
                    label="Supporting Context (optional)",
                    placeholder="Paste any retrieved evidence, PDFs, or rubric notes.",
                    lines=4,
                )
                acceptance = gr.Textbox(
                    label="Acceptance Criteria",
                    placeholder="Bullet list of 'definition of done' checks.",
                    lines=3,
                    value="- Provide citations for every claim.\n- Ensure /math verifies /code output.",
                )
                extra_guidance = gr.Textbox(
                    label="Additional Guidance",
                    placeholder="Special constraints, tools to avoid, etc.",
                    lines=3,
                )
            with gr.Column(scale=2):
                model_choice = gr.Dropdown(
                    label="Router Checkpoint",
                    choices=list(MODELS.keys()),
                    value=list(MODELS.keys())[0] if MODELS else None,
                    allow_custom_value=False,
                )
                difficulty = gr.Radio(
                    label="Difficulty Tier",
                    choices=["introductory", "intermediate", "advanced"],
                    value="advanced",
                    interactive=True,
                )
                tags = gr.Textbox(
                    label="Tags",
                    placeholder="Comma-separated e.g. calculus, optimization, python",
                    value="calculus, optimization, python",
                )
                max_new_tokens = gr.Slider(256, 20000, value=16000, step=32, label="Max New Tokens")
                temperature = gr.Slider(0.0, 1.5, value=0.2, step=0.05, label="Temperature")
                top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")

        generate_btn = gr.Button("Generate Router Plan", variant="primary")
        clear_btn = gr.Button("Clear", variant="secondary")

        with gr.Row():
            raw_output = gr.Textbox(label="Raw Model Output", lines=12)
            plan_json = gr.JSON(label="Parsed Router Plan")
        validation_msg = gr.Markdown("Awaiting generation.")
        prompt_view = gr.Textbox(label="Full Prompt", lines=10)

        generate_btn.click(
            generate_router_plan_streaming,
            inputs=[
                user_task,
                context,
                acceptance,
                extra_guidance,
                difficulty,
                tags,
                model_choice,
                max_new_tokens,
                temperature,
                top_p,
            ],
            outputs=[raw_output, plan_json, validation_msg, prompt_view],
            show_progress="full",
            api_name="/generate_router_plan_streaming",
        )

        clear_btn.click(
            fn=clear_outputs,
            outputs=[raw_output, plan_json, validation_msg, prompt_view],
            api_name="/clear_outputs",
        )

    return demo



def _prefetch_from_env() -> None:
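    """Eagerly load checkpoints listed in ROUTER_PREFETCH_MODELS (comma-separated,
    or ALL), or failing that the single ROUTER_PREFETCH_MODEL."""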
    entries = os.environ.get("ROUTER_PREFETCH_MODELS")
    if entries:
        names = [item.strip() for item in entries.split(",") if item.strip()]
    else:
        single = os.environ.get("ROUTER_PREFETCH_MODEL")
        names = [single] if single else []

    if names == ["ALL"] or names == ["all"]:
        names = list(MODELS.keys())

    for name in names:
        if name not in MODELS:
            print(f"Prefetch skipped, unknown model: {name}")
            continue
        try:
            load_pipeline(name)
            print(f"Prefetched router model: {name}")
        except Exception as exc:  # pragma: no cover
            print(f"Prefetch failed for {name}: {exc}")


_prefetch_from_env()

def trim_at_stop_sequences(text: str) -> Tuple[str, bool]:
    """Truncate text at the earliest stop sequence; the flag reports whether one was hit."""
    earliest = None
    for stop in STOP_SEQUENCES:
        idx = text.find(stop)
        if idx != -1 and (earliest is None or idx < earliest):
            earliest = idx
    if earliest is not None:
        return text[:earliest], True
    return text, False


demo = build_ui()

if __name__ == "__main__":  # pragma: no cover
    demo.launch(
        server_name="0.0.0.0",
        server_port=int(os.environ.get("PORT", 7860)),
        show_api=True,
    )
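
# Example client call against the streaming endpoint (a sketch; the Space id
# below is hypothetical -- substitute the actual owner/space):
#
#   from gradio_client import Client
#   client = Client("<owner>/<space-name>")
#   raw, plan, status, prompt = client.predict(
#       "Route this mixed calculus + coding homework task.",  # user_task
#       "", "", "",                                    # context, acceptance, extra_guidance
#       "advanced", "calculus, optimization, python",  # difficulty, tags
#       "Router-Qwen3-32B-8bit",                       # model_choice
#       16000, 0.2, 0.9,                               # max_new_tokens, temperature, top_p
#       api_name="/generate_router_plan_streaming",
#   )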