Spaces:
Running
Running
#!/usr/bin/env python3
"""BF-Router Trainer Space - QLoRA fine-tuning with live Gradio monitoring."""
import json
import os
import threading
import time
import traceback

import gradio as gr
# Shared, mutable snapshot of training state. Written by the background
# training thread and polled by the Gradio UI every few seconds.
status = {
    "state": "initializing",
    "epoch": 0,
    "total_epochs": 3,
    "loss": 0,
    "eval_loss": 0,
    "progress": 0,
    "step": 0,
    "max_steps": 0,
    "log": [],
    "agent_acc": 0,
    "tool_acc": 0,
}
def log(msg):
    """Record *msg* with an HH:MM:SS timestamp in the shared status log.

    Also echoes to stdout (flushed) so messages show up promptly in the
    Space's container logs.

    Args:
        msg: Human-readable message to record.
    """
    # f-string instead of dated %-formatting.
    status["log"].append(f"[{time.strftime('%H:%M:%S')}] {msg}")
    print(msg, flush=True)
def run_training() -> None:
    """Full training pipeline: load data, QLoRA fine-tune Qwen3-4B, evaluate, push to Hub.

    Intended to run in a background daemon thread. Progress and metrics are
    published into the module-level `status` dict so the Gradio UI can poll
    them; any exception flips status["state"] to "error" instead of letting
    the thread die silently.
    """
    try:
        # Heavy ML imports are deferred into the thread so the Gradio UI can
        # come up immediately while these libraries load.
        import torch
        from datasets import load_dataset
        from transformers import (AutoModelForCausalLM, AutoTokenizer,
                                  BitsAndBytesConfig, TrainerCallback)
        # NOTE(review): PeftModel is imported but unused in this function
        # (only test_model() uses it) — confirm before removing.
        from peft import LoraConfig, TaskType, PeftModel
        from trl import SFTConfig, SFTTrainer

        status["state"] = "loading_data"
        log("Loading training data from OpenCircuit/bf-router-training-data...")
        dataset = load_dataset("OpenCircuit/bf-router-training-data", data_files={"train": "data/bf_router_merged_train.jsonl", "validation": "data/bf_router_merged_val.jsonl", "test": "data/bf_router_merged_test.jsonl"})
        log("Train: %d, Val: %d, Test: %d" % (
            len(dataset["train"]), len(dataset["validation"]), len(dataset["test"])))

        status["state"] = "loading_model"
        log("Loading Qwen3-4B-Instruct-2507 with 4-bit QLoRA...")
        # NF4 + double quantization shrinks the 4B base model to fit a single
        # GPU; compute still runs in bfloat16.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True, bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)
        base_model = "Qwen/Qwen3-4B-Instruct-2507"
        tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
        tokenizer.eos_token = "<|im_end|>"  # ChatML end-of-turn marker
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "right"
        model = AutoModelForCausalLM.from_pretrained(
            base_model, quantization_config=bnb_config,
            device_map="auto", trust_remote_code=True)
        # KV cache is incompatible with gradient checkpointing during training.
        model.config.use_cache = False
        log("Model loaded: %dM params" % (model.num_parameters() / 1e6))

        def fmt(s):
            # Render one {"messages": [...]} sample into a single
            # chat-template string for SFT on the "text" field.
            text = tokenizer.apply_chat_template(
                s["messages"], tokenize=False, add_generation_prompt=False)
            return {"text": text}
        ftrain = dataset["train"].map(fmt, remove_columns=dataset["train"].column_names)
        fval = dataset["validation"].map(fmt, remove_columns=dataset["validation"].column_names)

        # LoRA on all attention + MLP projections.
        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM, r=16, lora_alpha=32,
            lora_dropout=0.05, bias="none",
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                            "gate_proj", "up_proj", "down_proj"])
        out_dir = "/app/output/bf-router-v0.5"
        # Effective batch = 4 (per-device) * 4 (grad accumulation) = 16.
        args = SFTConfig(
            output_dir=out_dir, num_train_epochs=3,
            per_device_train_batch_size=4, per_device_eval_batch_size=4,
            gradient_accumulation_steps=4, gradient_checkpointing=True,
            gradient_checkpointing_kwargs={"use_reentrant": False},
            optim="adamw_torch_fused", learning_rate=2e-4,
            lr_scheduler_type="cosine", warmup_ratio=0.03,
            max_grad_norm=0.3, weight_decay=0.01, bf16=True,
            max_length=2048, dataset_text_field="text",
            logging_steps=10, logging_first_step=True,
            save_strategy="epoch", eval_strategy="epoch", save_total_limit=3,
            load_best_model_at_end=True, metric_for_best_model="eval_loss",
            greater_is_better=False, report_to="none", seed=42)

        class StatusCallback(TrainerCallback):
            # Mirrors trainer log events into the shared `status` dict so the
            # UI can display live progress.
            def on_log(self, a, state, control, logs=None, **kw):
                if logs:
                    status["epoch"] = logs.get("epoch", 0)
                    status["loss"] = logs.get("loss", logs.get("eval_loss", 0))
                    if "eval_loss" in logs:
                        status["eval_loss"] = logs["eval_loss"]
                status["step"] = state.global_step
                status["max_steps"] = state.max_steps
                if state.max_steps:
                    status["progress"] = state.global_step / state.max_steps * 100

        status["state"] = "training"
        log("Starting QLoRA fine-tuning (3 epochs, effective batch=16)...")
        trainer = SFTTrainer(
            model=model, processing_class=tokenizer, args=args,
            peft_config=lora_config, train_dataset=ftrain,
            eval_dataset=fval, callbacks=[StatusCallback()])
        trainer.train()
        trainer.save_model(out_dir)
        tokenizer.save_pretrained(out_dir)
        log("Final eval loss: %.4f (from best checkpoint)" % status["eval_loss"])

        # Quick accuracy eval
        status["state"] = "evaluating"
        log("Evaluating routing accuracy on test set...")
        correct_agent = 0
        total = 0
        # Cap at 100 samples to keep the post-training eval quick.
        test_subset = dataset["test"].select(range(min(100, len(dataset["test"]))))
        model.eval()
        device = next(model.parameters()).device
        for sample in test_subset:
            msgs = sample["messages"]
            # Last message holds the gold JSON routing decision; everything
            # before it is the prompt.
            expected = json.loads(msgs[-1]["content"])
            inp = tokenizer.apply_chat_template(
                msgs[:-1], tokenize=False, add_generation_prompt=True)
            inputs = tokenizer(inp, return_tensors="pt").to(device)
            with torch.no_grad():
                # NOTE(review): do_sample=True makes this accuracy figure
                # non-deterministic — confirm greedy decoding wasn't intended.
                out = model.generate(
                    **inputs, max_new_tokens=256, temperature=0.3,
                    top_p=0.7, do_sample=True,
                    pad_token_id=tokenizer.pad_token_id)
            gen = tokenizer.decode(
                out[0][inputs["input_ids"].shape[1]:],
                skip_special_tokens=True).strip()
            try:
                pred = json.loads(gen)
                if pred.get("agent") == expected.get("agent"):
                    correct_agent += 1
            except Exception:
                # Malformed JSON from the model simply counts as a miss.
                pass
            total += 1
        acc = correct_agent / total * 100 if total else 0
        status["agent_acc"] = acc
        log("Agent routing accuracy: %.1f%% (%d/%d)" % (acc, correct_agent, total))

        # Push to Hub
        hf_token = os.environ.get("HF_TOKEN")
        if hf_token:
            log("Pushing model to OpenCircuit/bf-router...")
            from huggingface_hub import HfApi
            api = HfApi(token=hf_token)
            try:
                api.create_repo("OpenCircuit/bf-router", exist_ok=True)
            except Exception:
                # Best-effort: repo may already exist or creation may be
                # forbidden; the upload below will surface real failures.
                pass
            api.upload_folder(folder_path=out_dir, repo_id="OpenCircuit/bf-router", repo_type="model", token=hf_token, commit_message="Upload BF-Router v0.5 QLoRA adapter (Qwen3-4B)")
            log("Model pushed to Hub!")

        status["state"] = "complete"
        log("Training complete!")
        # Persist headline metrics next to the adapter weights.
        with open(os.path.join(out_dir, "results.json"), "w") as f:
            json.dump({"eval_loss": status["eval_loss"],
                       "agent_accuracy": acc, "total_test": total}, f, indent=2)
    except Exception as e:
        # Surface any failure in the UI instead of dying silently in the thread.
        status["state"] = "error"
        status["error"] = str(e)
        log("ERROR: %s" % str(e))
        log(traceback.format_exc())
# Start training in background
# Daemon thread: the process can exit (or be restarted) without waiting for
# training to finish; the Gradio UI below starts serving immediately.
t = threading.Thread(target=run_training, daemon=True)
t.start()
# Gradio UI
# Routing system prompt: fixes the JSON response schema and enumerates the
# seven BlueprintForge agents the router can dispatch to.
SYSTEM_PROMPT = (
    "You are BF-Router, the intent classifier for BlueprintForge. "
    "Analyze the user's message and respond with JSON: "
    '{"agent":"<id>","confidence":<0-1>,"reason":"<why>",'
    '"tools":["<tool1>",...],"chain":[]}. '
    "Agents: manny (builder), ping (investigator), fuse (debugger), "
    "bit (planner), mainframe (knowledge), sc (tester), "
    "willow (human-translator)."
)
def get_status():
    """Render the current training state as a Markdown table for the UI.

    Returns:
        Markdown string with state, progress, epoch, losses and routing
        accuracy, plus the error message when training has failed.
    """
    # Fix: the original built an `icons` mapping that was never used — dead
    # code removed. Output is unchanged.
    md = (
        "## BF-Router Training\n\n"
        "| Metric | Value |\n|--------|-------|\n"
        f"| **State** | {status['state']} |\n"
        f"| **Progress** | {status['progress']:.1f}% "
        f"({status['step']}/{status['max_steps']}) |\n"
        f"| **Epoch** | {status['epoch']:.2f} / {status['total_epochs']} |\n"
        f"| **Train Loss** | {status['loss']:.4f} |\n"
        f"| **Eval Loss** | {status['eval_loss']:.4f} |\n"
        f"| **Agent Accuracy** | {status['agent_acc']:.1f}% |\n"
    )
    if status.get("error"):
        md += f"\n**Error:** `{status['error']}`"
    return md
def get_logs():
    """Return the 50 most recent log lines as one newline-joined string."""
    recent_lines = status["log"][-50:]
    return "\n".join(recent_lines)
def test_model(query):
    """Route *query* through the fine-tuned adapter and return the raw output.

    Args:
        query: Natural-language request to classify.

    Returns:
        The model's generated text (expected to be a JSON routing decision),
        or a human-readable message while training is incomplete or on error.
    """
    if status["state"] != "complete":
        return "Training is %s. Please wait for completion." % status["state"]
    try:
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer
        from peft import PeftModel

        out_dir = "/app/output/bf-router-v0.5"
        # Fix: the original reloaded the 4B base model + adapter from disk on
        # every button click. Memoize tokenizer/model on the function object
        # so only the first request pays the load cost.
        cache = getattr(test_model, "_cache", None)
        if cache is None:
            tok = AutoTokenizer.from_pretrained(out_dir, trust_remote_code=True)
            mdl = AutoModelForCausalLM.from_pretrained(
                "Qwen/Qwen3-4B-Instruct-2507",
                device_map="auto", torch_dtype=torch.float16, trust_remote_code=True)
            mdl = PeftModel.from_pretrained(mdl, out_dir)
            cache = test_model._cache = (tok, mdl)
        tok, mdl = cache

        msgs = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": query}
        ]
        txt = tok.apply_chat_template(
            msgs, tokenize=False, add_generation_prompt=True)
        inp = tok(txt, return_tensors="pt").to(mdl.device)
        with torch.no_grad():
            out = mdl.generate(
                **inp, max_new_tokens=256, temperature=0.3,
                top_p=0.7, do_sample=True)
        # Decode only the newly generated tokens, not the prompt.
        return tok.decode(
            out[0][inp["input_ids"].shape[1]:], skip_special_tokens=True)
    except Exception as ex:
        return "Error: %s" % str(ex)
# Assemble the monitoring/testing UI. The status panel and log box poll
# their value callables every 5 seconds while the training thread runs.
with gr.Blocks(title="BF-Router Trainer") as demo:
    gr.Markdown("# BF-Router Fine-Tuning\nQLoRA training of Qwen3-4B for BlueprintForge 7-agent routing")
    with gr.Row():
        with gr.Column(scale=1):
            status_md = gr.Markdown(get_status, every=5)
        with gr.Column(scale=2):
            log_box = gr.Textbox(get_logs, label="Training Log", lines=20, every=5)
    gr.Markdown("---\n## Test Model")
    with gr.Row():
        query_box = gr.Textbox(label="Query", placeholder="Build a health bar for the player")
        route_btn = gr.Button("Route", variant="primary")
    response_json = gr.JSON(label="BF-Router Response")
    route_btn.click(test_model, inputs=query_box, outputs=response_json)

# Bind to all interfaces on the standard HF Spaces port.
demo.launch(server_name="0.0.0.0", server_port=7860)