#!/usr/bin/env python3
"""BF-Router Trainer Space - QLoRA fine-tuning with live Gradio monitoring.

Running this module starts QLoRA training of Qwen3-4B-Instruct-2507 in a
daemon thread and serves a Gradio dashboard (port 7860) that polls the shared
``status`` dict and lets users query the finished adapter.
"""
import json
import os
import threading
import time
import traceback

import gradio as gr

BASE_MODEL = "Qwen/Qwen3-4B-Instruct-2507"
DATASET_REPO = "OpenCircuit/bf-router-training-data"
HUB_REPO = "OpenCircuit/bf-router"
OUTPUT_DIR = "/app/output/bf-router-v0.5"
NUM_EPOCHS = 3
# Number of test samples scored by the post-training routing accuracy check.
EVAL_SAMPLE_LIMIT = 100

# Shared progress state: written by the training thread, read by the Gradio
# polling callbacks.
status = {"state": "initializing", "epoch": 0, "total_epochs": NUM_EPOCHS,
          "loss": 0, "eval_loss": 0, "progress": 0, "step": 0,
          "max_steps": 0, "log": [], "agent_acc": 0, "tool_acc": 0}

# Fine-tuned model/tokenizer reused by the "Route" button, so the UI does not
# load a second copy of the 4B model on every query.
_inference = {}
_inference_lock = threading.Lock()


def log(msg):
    """Append a timestamped line to ``status["log"]`` and echo it to stdout."""
    status["log"].append("[%s] %s" % (time.strftime("%H:%M:%S"), msg))
    print(msg, flush=True)


def _parse_json_object(text):
    """Return *text* parsed as a JSON object (dict), or None if it is not one."""
    try:
        value = json.loads(text)
    except (TypeError, ValueError):
        return None
    return value if isinstance(value, dict) else None


def _generate_route(model, tokenizer, messages, max_new_tokens=256):
    """Generate the assistant reply to *messages* and return it stripped.

    Decoding is greedy so routing answers (and the reported accuracy) are
    reproducible. The KV cache is enabled explicitly because training sets
    ``model.config.use_cache = False``.
    """
    import torch

    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True)
    device = next(model.parameters()).device
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        output = model.generate(
            **inputs, max_new_tokens=max_new_tokens, do_sample=False,
            use_cache=True, pad_token_id=tokenizer.pad_token_id)
    prompt_len = inputs["input_ids"].shape[1]
    return tokenizer.decode(
        output[0][prompt_len:], skip_special_tokens=True).strip()


def _evaluate_agent_accuracy(model, tokenizer, test_split,
                             limit=EVAL_SAMPLE_LIMIT):
    """Score agent-routing accuracy on the first *limit* test samples.

    Each sample's last message is the reference JSON answer and the preceding
    messages form the prompt. Samples whose reference is not a JSON object
    cannot be scored; they are skipped and counted in the log instead of
    aborting the whole run.

    Returns:
        ``(correct, total)`` where ``total`` counts only scorable samples.
    """
    subset = test_split.select(range(min(limit, len(test_split))))
    correct = total = skipped = 0
    for sample in subset:
        messages = sample["messages"]
        expected = _parse_json_object(messages[-1]["content"])
        if expected is None:
            skipped += 1
            continue
        predicted = _parse_json_object(
            _generate_route(model, tokenizer, messages[:-1]))
        total += 1
        # Unparseable or non-object predictions count as misses.
        if predicted is not None and predicted.get("agent") == expected.get("agent"):
            correct += 1
    if skipped:
        log("Skipped %d test samples whose reference answer is not a JSON "
            "object" % skipped)
    return correct, total


def _push_to_hub(hf_token):
    """Upload the final adapter in OUTPUT_DIR to HUB_REPO (checkpoints excluded)."""
    from huggingface_hub import HfApi

    api = HfApi(token=hf_token)
    try:
        api.create_repo(HUB_REPO, exist_ok=True)
    except Exception as exc:  # repo may already exist / token may lack create rights
        log("create_repo failed, attempting upload anyway: %s" % exc)
    api.upload_folder(
        folder_path=OUTPUT_DIR, repo_id=HUB_REPO, repo_type="model",
        token=hf_token,
        commit_message="Upload BF-Router v0.5 QLoRA adapter (Qwen3-4B)",
        # Intermediate Trainer checkpoints carry optimizer state and are not
        # needed to use the final adapter.
        ignore_patterns=["checkpoint-*"])


def run_training():
    """Load data and model, run QLoRA SFT, evaluate routing, save and push.

    Intended to run in a background thread: progress and failures are
    reported only through ``status`` / ``log``; exceptions do not propagate.
    """
    try:
        import torch
        from datasets import load_dataset
        from peft import LoraConfig, TaskType
        from transformers import (AutoModelForCausalLM, AutoTokenizer,
                                  BitsAndBytesConfig, TrainerCallback)
        from trl import SFTConfig, SFTTrainer

        status["state"] = "loading_data"
        log("Loading training data from %s..." % DATASET_REPO)
        dataset = load_dataset(
            DATASET_REPO,
            data_files={"train": "data/bf_router_merged_train.jsonl",
                        "validation": "data/bf_router_merged_val.jsonl",
                        "test": "data/bf_router_merged_test.jsonl"})
        log("Train: %d, Val: %d, Test: %d" % (
            len(dataset["train"]), len(dataset["validation"]),
            len(dataset["test"])))

        status["state"] = "loading_model"
        log("Loading Qwen3-4B-Instruct-2507 with 4-bit QLoRA...")
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True, bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16)
        tokenizer = AutoTokenizer.from_pretrained(
            BASE_MODEL, trust_remote_code=True)
        # Chat turns end with <|im_end|>; padding reuses it (right-padded for
        # training).
        tokenizer.eos_token = "<|im_end|>"
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "right"
        model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL, quantization_config=bnb_config, device_map="auto",
            trust_remote_code=True)
        # The KV cache is not used with gradient checkpointing during training.
        model.config.use_cache = False
        log("Model loaded: %dM params" % (model.num_parameters() / 1e6))

        def fmt(sample):
            """Render a ``messages`` conversation into a single training string."""
            text = tokenizer.apply_chat_template(
                sample["messages"], tokenize=False,
                add_generation_prompt=False)
            return {"text": text}

        ftrain = dataset["train"].map(
            fmt, remove_columns=dataset["train"].column_names)
        fval = dataset["validation"].map(
            fmt, remove_columns=dataset["validation"].column_names)

        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM, r=16, lora_alpha=32,
            lora_dropout=0.05, bias="none",
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                            "gate_proj", "up_proj", "down_proj"])

        args = SFTConfig(
            output_dir=OUTPUT_DIR, num_train_epochs=NUM_EPOCHS,
            per_device_train_batch_size=4, per_device_eval_batch_size=4,
            gradient_accumulation_steps=4, gradient_checkpointing=True,
            gradient_checkpointing_kwargs={"use_reentrant": False},
            optim="adamw_torch_fused", learning_rate=2e-4,
            lr_scheduler_type="cosine", warmup_ratio=0.03, max_grad_norm=0.3,
            weight_decay=0.01, bf16=True, max_length=2048,
            dataset_text_field="text", logging_steps=10,
            logging_first_step=True, save_strategy="epoch",
            eval_strategy="epoch", save_total_limit=3,
            load_best_model_at_end=True, metric_for_best_model="eval_loss",
            greater_is_better=False, report_to="none", seed=42)

        class StatusCallback(TrainerCallback):
            """Mirror Trainer log events into ``status`` for the dashboard."""

            def on_log(self, train_args, state, control, logs=None, **kwargs):
                if logs:
                    status["epoch"] = logs.get("epoch", status["epoch"])
                    # Only training-step logs carry "loss"; eval logs and the
                    # end-of-training summary must not overwrite it.
                    if "loss" in logs:
                        status["loss"] = logs["loss"]
                    if "eval_loss" in logs:
                        status["eval_loss"] = logs["eval_loss"]
                status["step"] = state.global_step
                status["max_steps"] = state.max_steps
                if state.max_steps:
                    status["progress"] = (
                        state.global_step / state.max_steps * 100)

        status["state"] = "training"
        effective_batch = (args.per_device_train_batch_size
                           * args.gradient_accumulation_steps)
        log("Starting QLoRA fine-tuning (%d epochs, effective batch=%d)..."
            % (NUM_EPOCHS, effective_batch))
        trainer = SFTTrainer(
            model=model, processing_class=tokenizer, args=args,
            peft_config=lora_config, train_dataset=ftrain, eval_dataset=fval,
            callbacks=[StatusCallback()])
        trainer.train()
        trainer.save_model(OUTPUT_DIR)
        tokenizer.save_pretrained(OUTPUT_DIR)

        # load_best_model_at_end restored the best checkpoint; its loss is
        # best_metric, not necessarily the last logged eval_loss.
        if trainer.state.best_metric is not None:
            status["eval_loss"] = trainer.state.best_metric
        log("Final eval loss: %.4f (from best checkpoint)" % status["eval_loss"])

        # Quick accuracy eval
        status["state"] = "evaluating"
        log("Evaluating routing accuracy on test set...")
        tuned_model = trainer.model
        tuned_model.eval()
        correct_agent, total = _evaluate_agent_accuracy(
            tuned_model, tokenizer, dataset["test"])
        acc = correct_agent / total * 100 if total else 0
        status["agent_acc"] = acc
        log("Agent routing accuracy: %.1f%% (%d/%d)" % (acc, correct_agent, total))

        # Make the in-memory model available to the UI instead of reloading.
        with _inference_lock:
            _inference["tokenizer"] = tokenizer
            _inference["model"] = tuned_model

        # Write results before pushing so they are uploaded with the adapter.
        with open(os.path.join(OUTPUT_DIR, "results.json"), "w",
                  encoding="utf-8") as f:
            json.dump({"eval_loss": status["eval_loss"],
                       "agent_accuracy": acc,
                       "total_test": total}, f, indent=2)

        # Push to Hub (best effort: the adapter is already saved locally).
        hf_token = os.environ.get("HF_TOKEN")
        if hf_token:
            log("Pushing model to %s..." % HUB_REPO)
            try:
                _push_to_hub(hf_token)
                log("Model pushed to Hub!")
            except Exception as exc:
                log("ERROR pushing to Hub: %s" % exc)
                log(traceback.format_exc())

        status["state"] = "complete"
        log("Training complete!")
    except Exception as e:
        status["state"] = "error"
        status["error"] = str(e)
        log("ERROR: %s" % str(e))
        log(traceback.format_exc())


# NOTE(review): the JSON template below looks like placeholder tokens (e.g.
# inside "agent":"" and "tools":[""]) may have been stripped at some point;
# it must match the system prompt used in the training data — confirm.
SYSTEM_PROMPT = (
    'You are BF-Router, the intent classifier for BlueprintForge. '
    'Analyze the user\'s message and respond with JSON: '
    '{"agent":"","confidence":<0-1>,"reason":"",'
    '"tools":["",...],"chain":[]}. '
    'Agents: manny (builder), ping (investigator), fuse (debugger), '
    'bit (planner), mainframe (knowledge), sc (tester), '
    'willow (human-translator).'
)


def get_status():
    """Render ``status`` as a Markdown metrics table for the dashboard."""
    md = "## BF-Router Training\n\n"
    md += "| Metric | Value |\n|--------|-------|\n"
    md += "| **State** | %s |\n" % status["state"]
    md += "| **Progress** | %.1f%% (%d/%d) |\n" % (
        status["progress"], status["step"], status["max_steps"])
    md += "| **Epoch** | %.2f / %d |\n" % (status["epoch"], status["total_epochs"])
    md += "| **Train Loss** | %.4f |\n" % status["loss"]
    md += "| **Eval Loss** | %.4f |\n" % status["eval_loss"]
    md += "| **Agent Accuracy** | %.1f%% |\n" % status["agent_acc"]
    if status.get("error"):
        md += "\n**Error:** `%s`" % status["error"]
    return md


def get_logs():
    """Return the 50 most recent log lines, newline-joined."""
    return "\n".join(status["log"][-50:])


def _get_inference_model():
    """Return ``(model, tokenizer)`` for the UI, loading the saved adapter once.

    Normally the training thread has already cached the tuned model; the
    disk load is a fallback and its result is cached as well.
    """
    with _inference_lock:
        if "model" not in _inference:
            import torch
            from peft import PeftModel
            from transformers import AutoModelForCausalLM, AutoTokenizer

            tok = AutoTokenizer.from_pretrained(OUTPUT_DIR, trust_remote_code=True)
            base = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL, device_map="auto", torch_dtype=torch.bfloat16,
                trust_remote_code=True)
            mdl = PeftModel.from_pretrained(base, OUTPUT_DIR)
            mdl.eval()
            _inference["tokenizer"] = tok
            _inference["model"] = mdl
        return _inference["model"], _inference["tokenizer"]


def test_model(query):
    """Route *query* with the fine-tuned model.

    Always returns a dict, because the output component is ``gr.JSON``, which
    cannot display arbitrary non-JSON strings. The model's reply is returned
    as parsed JSON when it is a JSON object, otherwise as ``{"raw": text}``.
    """
    if status["state"] != "complete":
        return {"status": "Training is %s. Please wait for completion."
                % status["state"]}
    try:
        model, tok = _get_inference_model()
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": query},
        ]
        reply = _generate_route(model, tok, messages)
    except Exception as ex:
        return {"error": "Error: %s" % str(ex)}
    parsed = _parse_json_object(reply)
    return parsed if parsed is not None else {"raw": reply}


# Start training in background
training_thread = threading.Thread(target=run_training, daemon=True)
training_thread.start()

# Gradio UI
with gr.Blocks(title="BF-Router Trainer") as demo:
    gr.Markdown(
        "# BF-Router Fine-Tuning\n"
        "QLoRA training of Qwen3-4B for BlueprintForge 7-agent routing"
    )
    with gr.Row():
        with gr.Column(scale=1):
            status_md = gr.Markdown(get_status, every=5)
        with gr.Column(scale=2):
            log_box = gr.Textbox(get_logs, label="Training Log", lines=20,
                                 every=5)
    gr.Markdown("---\n## Test Model")
    with gr.Row():
        query_box = gr.Textbox(label="Query",
                               placeholder="Build a health bar for the player")
        route_btn = gr.Button("Route", variant="primary")
    response_json = gr.JSON(label="BF-Router Response")
    route_btn.click(test_model, inputs=query_box, outputs=response_json)

demo.launch(server_name="0.0.0.0", server_port=7860)