OpenCircuit
Fix push_to_hub: use upload_folder instead (create_model_card kwarg compat)
6f9f012
#!/usr/bin/env python3
"""BF-Router Trainer Space - QLoRA fine-tuning with live Gradio monitoring."""
import os, json, time, threading, traceback
import gradio as gr
# Shared, mutable progress record. run_training (on the background thread)
# writes to it; the Gradio callbacks get_status/get_logs/test_model read it.
# "tool_acc" is currently never updated anywhere in this file.
status = dict(
    state="initializing",
    epoch=0,
    total_epochs=3,
    loss=0,
    eval_loss=0,
    progress=0,
    step=0,
    max_steps=0,
    log=[],
    agent_acc=0,
    tool_acc=0,
)
def log(msg):
    """Record ``msg`` in the UI log buffer with an HH:MM:SS stamp and echo it.

    The timestamped line goes to ``status["log"]``; the bare message is
    printed (flushed immediately) so it also appears in the Space's stdout.
    """
    stamp = time.strftime("%H:%M:%S")
    status["log"].append(f"[{stamp}] {msg}")
    print(msg, flush=True)
def _evaluate_agent_accuracy(model, tokenizer, samples):
    """Greedy-decode each test prompt and count exact ``agent`` matches.

    Each sample's ``messages`` list is split into a prompt (every turn but the
    last) and a reference reply (the last turn, parsed as JSON). A sample
    counts as correct when the generated JSON's ``"agent"`` equals the
    reference's. Output that is not a JSON object counts as a miss. Samples
    whose *reference* is unreadable are skipped and logged, so they do not
    abort the whole evaluation.

    Returns:
        (correct, total): number of matches and number of evaluated samples.
    """
    import torch

    model.eval()
    device = next(model.parameters()).device
    correct = total = skipped = 0
    for sample in samples:
        msgs = sample["messages"]
        try:
            expected = json.loads(msgs[-1]["content"])
        except (ValueError, TypeError, KeyError, IndexError):
            skipped += 1
            continue
        prompt = tokenizer.apply_chat_template(
            msgs[:-1], tokenize=False, add_generation_prompt=True)
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            # Greedy decoding keeps the metric reproducible between runs.
            # The KV cache is re-enabled here because training set
            # config.use_cache = False.
            out = model.generate(
                **inputs, max_new_tokens=256, do_sample=False,
                use_cache=True, pad_token_id=tokenizer.pad_token_id)
        gen = tokenizer.decode(
            out[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True).strip()
        total += 1
        try:
            pred = json.loads(gen)
        except ValueError:
            continue  # unparseable generation counts as a miss
        if (isinstance(pred, dict) and isinstance(expected, dict)
                and pred.get("agent") == expected.get("agent")):
            correct += 1
    if skipped:
        log("Skipped %d test samples whose reference reply is not JSON"
            % skipped)
    return correct, total


def _push_to_hub(folder, token, repo_id="OpenCircuit/bf-router"):
    """Upload the final adapter folder (minus intermediate checkpoints) to the Hub.

    ``checkpoint-*`` sub-directories hold optimizer/scheduler state from
    ``save_strategy="epoch"``. They are excluded so that only the final
    adapter, tokenizer files and results.json are pushed.
    """
    from huggingface_hub import HfApi

    log("Pushing model to %s..." % repo_id)
    api = HfApi(token=token)
    try:
        api.create_repo(repo_id, exist_ok=True)
    except Exception as exc:
        # This stays non-fatal, as before. A real problem will surface
        # from upload_folder below; it is no longer silently hidden.
        log("create_repo failed (continuing): %s" % exc)
    api.upload_folder(
        folder_path=folder, repo_id=repo_id, repo_type="model", token=token,
        commit_message="Upload BF-Router v0.5 QLoRA adapter (Qwen3-4B)",
        ignore_patterns=["checkpoint-*"])
    log("Model pushed to Hub!")


def run_training():
    """Fine-tune Qwen3-4B with QLoRA, evaluate routing accuracy, and publish.

    Intended to run on a background thread. Progress is reported only through
    the module-level ``status`` dict and ``log()``. Any exception is caught and
    recorded as ``status["state"] = "error"`` so the UI can display it. A
    failed Hub push is logged but does not discard a finished run.
    """
    try:
        import torch
        from datasets import load_dataset
        from transformers import (AutoModelForCausalLM, AutoTokenizer,
                                  BitsAndBytesConfig, TrainerCallback)
        from peft import LoraConfig, TaskType
        from trl import SFTConfig, SFTTrainer

        status["state"] = "loading_data"
        log("Loading training data from OpenCircuit/bf-router-training-data...")
        dataset = load_dataset(
            "OpenCircuit/bf-router-training-data",
            data_files={
                "train": "data/bf_router_merged_train.jsonl",
                "validation": "data/bf_router_merged_val.jsonl",
                "test": "data/bf_router_merged_test.jsonl",
            })
        log("Train: %d, Val: %d, Test: %d" % (
            len(dataset["train"]), len(dataset["validation"]),
            len(dataset["test"])))

        status["state"] = "loading_model"
        log("Loading Qwen3-4B-Instruct-2507 with 4-bit QLoRA...")
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True, bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16)
        base_model = "Qwen/Qwen3-4B-Instruct-2507"
        tokenizer = AutoTokenizer.from_pretrained(
            base_model, trust_remote_code=True)
        # Stop at the chat-turn terminator and reuse it as the pad token.
        tokenizer.eos_token = "<|im_end|>"
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "right"
        model = AutoModelForCausalLM.from_pretrained(
            base_model, quantization_config=bnb_config,
            device_map="auto", trust_remote_code=True)
        # Off for gradient-checkpointed training; evaluation re-enables it.
        model.config.use_cache = False
        log("Model loaded: %dM params" % (model.num_parameters() / 1e6))

        def fmt(sample):
            """Render one chat sample into the single "text" field SFT trains on."""
            text = tokenizer.apply_chat_template(
                sample["messages"], tokenize=False,
                add_generation_prompt=False)
            return {"text": text}

        ftrain = dataset["train"].map(
            fmt, remove_columns=dataset["train"].column_names)
        fval = dataset["validation"].map(
            fmt, remove_columns=dataset["validation"].column_names)

        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM, r=16, lora_alpha=32,
            lora_dropout=0.05, bias="none",
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                            "gate_proj", "up_proj", "down_proj"])
        out_dir = "/app/output/bf-router-v0.5"
        args = SFTConfig(
            output_dir=out_dir, num_train_epochs=status["total_epochs"],
            per_device_train_batch_size=4, per_device_eval_batch_size=4,
            gradient_accumulation_steps=4, gradient_checkpointing=True,
            gradient_checkpointing_kwargs={"use_reentrant": False},
            optim="adamw_torch_fused", learning_rate=2e-4,
            lr_scheduler_type="cosine", warmup_ratio=0.03,
            max_grad_norm=0.3, weight_decay=0.01, bf16=True,
            max_length=2048, dataset_text_field="text",
            logging_steps=10, logging_first_step=True,
            save_strategy="epoch", eval_strategy="epoch", save_total_limit=3,
            load_best_model_at_end=True, metric_for_best_model="eval_loss",
            greater_is_better=False, report_to="none", seed=42)

        class StatusCallback(TrainerCallback):
            """Mirror Trainer log events into the shared ``status`` dict."""

            def on_log(self, args, state, control, logs=None, **kwargs):
                if not logs:
                    return
                # Training-loss and eval-loss entries arrive as separate log
                # events, and the final summary carries neither key. Only
                # overwrite the fields each event actually contains.
                if "epoch" in logs:
                    status["epoch"] = logs["epoch"]
                if "loss" in logs:
                    status["loss"] = logs["loss"]
                if "eval_loss" in logs:
                    status["eval_loss"] = logs["eval_loss"]
                status["step"] = state.global_step
                status["max_steps"] = state.max_steps
                if state.max_steps:
                    status["progress"] = (
                        state.global_step / state.max_steps * 100)

        status["state"] = "training"
        # Per-device product; matches the true effective batch on one GPU.
        effective_batch = (args.per_device_train_batch_size
                           * args.gradient_accumulation_steps)
        log("Starting QLoRA fine-tuning (%d epochs, effective batch=%d)..."
            % (status["total_epochs"], effective_batch))
        trainer = SFTTrainer(
            model=model, processing_class=tokenizer, args=args,
            peft_config=lora_config, train_dataset=ftrain,
            eval_dataset=fval, callbacks=[StatusCallback()])
        trainer.train()
        trainer.save_model(out_dir)
        tokenizer.save_pretrained(out_dir)
        # status["eval_loss"] holds the *last* evaluation. The reloaded best
        # checkpoint is the one with the lowest eval_loss (best_metric).
        if trainer.state.best_metric is not None:
            status["eval_loss"] = trainer.state.best_metric
        log("Final eval loss: %.4f (from best checkpoint)" % status["eval_loss"])

        # Quick accuracy eval on (at most) the first 100 test samples.
        status["state"] = "evaluating"
        log("Evaluating routing accuracy on test set...")
        test_subset = dataset["test"].select(
            range(min(100, len(dataset["test"]))))
        # NOTE(review): with peft_config + load_best_model_at_end,
        # trainer.model should be the PEFT-wrapped model holding the best
        # adapter - confirm against the installed TRL version.
        correct_agent, total = _evaluate_agent_accuracy(
            trainer.model, tokenizer, test_subset)
        acc = correct_agent / total * 100 if total else 0
        status["agent_acc"] = acc
        log("Agent routing accuracy: %.1f%% (%d/%d)"
            % (acc, correct_agent, total))

        # Write results before pushing so results.json ships with the adapter.
        with open(os.path.join(out_dir, "results.json"), "w",
                  encoding="utf-8") as f:
            json.dump({"eval_loss": status["eval_loss"],
                       "agent_accuracy": acc, "total_test": total},
                      f, indent=2)

        hf_token = os.environ.get("HF_TOKEN")
        if hf_token:
            try:
                _push_to_hub(out_dir, hf_token)
            except Exception as exc:
                # The adapter is already on disk and usable from the Test
                # section, so an upload failure must not mark the run failed.
                log("Hub push failed: %s" % exc)
                log(traceback.format_exc())
        else:
            log("HF_TOKEN not set; skipping Hub push.")

        status["state"] = "complete"
        log("Training complete!")
    except Exception as e:
        status["state"] = "error"
        status["error"] = str(e)
        log("ERROR: %s" % str(e))
        log(traceback.format_exc())
# Kick off fine-tuning in the background so the Gradio app defined below can
# start serving immediately; daemon=True means this thread will not keep the
# process alive on its own.
t = threading.Thread(daemon=True, target=run_training)
t.start()
# Gradio UI
# System message sent by test_model ahead of the user's query: describes the
# JSON reply format and lists the seven routable agent ids.
SYSTEM_PROMPT = "".join([
    "You are BF-Router, the intent classifier for BlueprintForge. ",
    "Analyze the user's message and respond with JSON: ",
    '{"agent":"<id>","confidence":<0-1>,"reason":"<why>",',
    '"tools":["<tool1>",...],"chain":[]}. ',
    "Agents: manny (builder), ping (investigator), fuse (debugger), ",
    "bit (planner), mainframe (knowledge), sc (tester), ",
    "willow (human-translator).",
])
def get_status():
    """Render the shared ``status`` dict as a Markdown metrics table.

    Polled by the dashboard; it only reads ``status``. An error line is
    appended when ``status["error"]`` is set (run_training sets it on failure).
    """
    # The former unused state->icon mapping was dead code and is removed.
    parts = [
        "## BF-Router Training\n\n",
        "| Metric | Value |\n|--------|-------|\n",
        "| **State** | %s |\n" % status["state"],
        "| **Progress** | %.1f%% (%d/%d) |\n" % (
            status["progress"], status["step"], status["max_steps"]),
        "| **Epoch** | %.2f / %d |\n" % (
            status["epoch"], status["total_epochs"]),
        "| **Train Loss** | %.4f |\n" % status["loss"],
        "| **Eval Loss** | %.4f |\n" % status["eval_loss"],
        "| **Agent Accuracy** | %.1f%% |\n" % status["agent_acc"],
    ]
    if status.get("error"):
        parts.append("\n**Error:** `%s`" % status["error"])
    return "".join(parts)
def get_logs():
    """Return the 50 most recent log lines joined by newlines for the log box."""
    tail = status["log"][-50:]
    return "\n".join(tail)
# Loaded (tokenizer, model) pair, kept so the 4B base model and adapter are
# read from disk once rather than on every "Route" click.
_router_cache = {}
_router_lock = threading.Lock()


def _load_router():
    """Load the base model plus trained adapter once and return (tokenizer, model)."""
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import PeftModel

    # The lock prevents two concurrent first clicks from loading two copies.
    with _router_lock:
        if "router" not in _router_cache:
            out_dir = "/app/output/bf-router-v0.5"
            tok = AutoTokenizer.from_pretrained(out_dir, trust_remote_code=True)
            # bfloat16 matches the compute dtype used during training
            # (bnb_4bit_compute_dtype / bf16=True in run_training).
            mdl = AutoModelForCausalLM.from_pretrained(
                "Qwen/Qwen3-4B-Instruct-2507",
                device_map="auto", torch_dtype=torch.bfloat16,
                trust_remote_code=True)
            mdl = PeftModel.from_pretrained(mdl, out_dir)
            mdl.eval()
            _router_cache["router"] = (tok, mdl)
        return _router_cache["router"]


def test_model(query):
    """Route ``query`` through the fine-tuned adapter and return the raw reply text.

    Returns a "please wait" message until training reaches "complete". Any
    load or generation failure is returned as an ``"Error: ..."`` string
    rather than raised.
    """
    if status["state"] != "complete":
        return "Training is %s. Please wait for completion." % status["state"]
    try:
        import torch

        tok, mdl = _load_router()
        msgs = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": query},
        ]
        txt = tok.apply_chat_template(
            msgs, tokenize=False, add_generation_prompt=True)
        inp = tok(txt, return_tensors="pt").to(mdl.device)
        with torch.no_grad():
            out = mdl.generate(
                **inp, max_new_tokens=256, temperature=0.3,
                top_p=0.7, do_sample=True, pad_token_id=tok.pad_token_id)
        return tok.decode(
            out[0][inp["input_ids"].shape[1]:], skip_special_tokens=True)
    except Exception as ex:
        return "Error: %s" % str(ex)
# Dashboard: a live status table and log (polled every 5 s) plus a query box
# for trying the trained router once training has completed.
with gr.Blocks(title="BF-Router Trainer") as demo:
    gr.Markdown(
        "# BF-Router Fine-Tuning\n"
        "QLoRA training of Qwen3-4B for BlueprintForge 7-agent routing"
    )
    with gr.Row():
        with gr.Column(scale=1):
            status_md = gr.Markdown(get_status, every=5)
        with gr.Column(scale=2):
            log_box = gr.Textbox(get_logs, label="Training Log", lines=20, every=5)
    gr.Markdown("---\n## Test Model")
    with gr.Row():
        q = gr.Textbox(label="Query", placeholder="Build a health bar for the player")
        btn = gr.Button("Route", variant="primary")
    # test_model returns plain text: status/error messages or raw model output
    # that need not be valid JSON. gr.JSON parses string values as JSON and
    # fails on those, so a Textbox displays the reply verbatim instead.
    out = gr.Textbox(label="BF-Router Response", lines=8)
    btn.click(test_model, inputs=q, outputs=out)

# `every=` polling runs through the queue. Older Gradio releases require it
# to be enabled explicitly; on newer ones this call is a no-op default.
demo.queue()
demo.launch(server_name="0.0.0.0", server_port=7860)