license: apache-2.0
language:
  - en
tags:
  - prompt-routing
  - complexity-classifier
  - deberta-v3
  - llm-router
  - cost-optimization
datasets:
  - RowRed/ComplexityRouter
  - OpenAssistant/oasst2
base_model:
  - microsoft/deberta-v3-base

ComplexityRouter: A Complexity-Based LLM Router

ComplexityRouter is a lightweight prompt-complexity classifier fine-tuned from microsoft/deberta-v3-base. Its training data is synthetic: I collected prompts from the Open Assistant Conversations Dataset Release 2 (OASST2), wrote some myself, and had Qwen3.5-4B (non-thinking mode) rewrite others to be more complex. Qwen3.5-4B (non-thinking mode) then labeled 4,400 of these prompts with a complexity level.

The model assigns each prompt to one of four complexity levels, which makes it useful for routing queries to an appropriate LLM tier.

Model Details

Model Description

  • Model type: Text Classification (multi‑class)
  • Language: English
  • Backbone: microsoft/deberta-v3-base
  • License: Apache‑2.0
  • Finetuned from model: microsoft/deberta-v3-base
  • Training data: OASST2 + synthetic augmentations + manually created prompts
    Labels generated by Qwen3.5‑4B (non‑thinking mode).

Model Sources

Uses

Direct Use

Route prompts to appropriate LLM tiers based on predicted complexity:

| Level | Meaning | Suggested LLM Tier |
|---|---|---|
| 0 (Trivial) | Simple lookups, basic Q&A (e.g., “What is 2+2?”) | Fast/cheap local model |
| 1 (Simple) | Moderate reasoning, basic domain knowledge | Mid‑tier model |
| 2 (Moderate) | Complex reasoning, deep knowledge required | Strong model |
| 3 (Complex) | Very complex reasoning, niche expertise | Frontier API model |

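Recommended routing strategy: collapse the four levels into three tiers. Send levels 0 and 1 to the fast/cheap tier, level 2 to the standard tier, and level 3 to the premium tier. On the 600-sample internal test split the model reaches 93.0% adjacent (±1) accuracy, so only 7.0% of predictions land more than one level away from the label.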
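A minimal sketch of the three-tier grouping recommended above. The `Tier` enum, the `LEVEL_TO_TIER` table, and the `route` function are hypothetical names chosen for illustration; they are not shipped with the model, and the tier labels stand in for whatever model endpoints you actually use.

```python
from enum import Enum


class Tier(Enum):
    """Hypothetical tier names; map these to your own model endpoints."""
    FAST = "fast/cheap"
    STANDARD = "standard"
    PREMIUM = "premium"


# Levels 0 and 1 share the fast/cheap tier, level 2 is standard, level 3 is premium.
LEVEL_TO_TIER = {
    0: Tier.FAST,
    1: Tier.FAST,
    2: Tier.STANDARD,
    3: Tier.PREMIUM,
}


def route(level: int) -> Tier:
    """Return the tier for a predicted complexity level (0 to 3)."""
    if level not in LEVEL_TO_TIER:
        raise ValueError(f"unknown complexity level: {level!r}")
    return LEVEL_TO_TIER[level]


if __name__ == "__main__":
    for level in range(4):
        print(level, route(level).value)
```

Keeping the mapping in a plain dictionary makes it easy to try a different grouping (for example, giving level 1 its own mid-tier model as in the table) without touching the classifier.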

Out‑of‑Scope Use

  • Multi‑turn conversation routing (single prompts only).
  • Non‑English prompts (training data was English‑only).
  • Prompts requiring image or multimodal understanding.

Bias, Risks, and Limitations

  • Training data is synthetic and may not represent all real‑world prompt distributions.
  • Level 1 (Simple) and Level 2 (Moderate) have the lowest per‑class F1 scores (0.529 and 0.622 on the internal test); the boundaries between adjacent levels are inherently ambiguous.
  • The model may struggle with very domain‑specific technical jargon.
  • Performance may degrade on prompts that are very different from the training distribution.

Notice

This is my first attempt at a publicly shared fine-tune. There are probably lots of issues, but I thought the idea was sound. I might make a second (hopefully better) version eventually, but I'm not sure where to get lots of high-quality open-source data.

Training Details

Training Data

| Split | Samples | Source File | Notes |
|---|---|---|---|
| Training | 2,800 | TRAINING.jsonl | Used for model training |
| Validation | 600 | TRAINING.jsonl | Used for early stopping / hyperparameter tuning |
| Test (internal) | 600 | TRAINING.jsonl | Used for in‑distribution evaluation (reported results) |
| Test (held‑out) | 400 | TEST.jsonl | Fully independent test set |

Total unique prompts: 4,400

Class distribution (training): Level 0: 762 (27.2%) • Level 1: 674 (24.1%) • Level 2: 795 (28.4%) • Level 3: 569 (20.3%)

Training Procedure

  • Hardware: NVIDIA T4 (16 GB VRAM, Google Colab)
  • Framework: PyTorch 2.11 + Hugging Face Transformers
  • Optimizer: AdamW (lr=2e-5, weight_decay=0.01)
  • Scheduler: Linear warmup (10% of steps) → linear decay
  • Loss: Weighted Cross‑Entropy (classification) + MSE (regression)
  • Batch size: 16 (effective 32 with gradient accumulation)
  • Epochs: 7 (early stopping patience = 3)
  • Training time: ~18 minutes
  • Class balancing: sqrt‑scaled class weights + weighted random sampler
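The card does not spell out the exact formulas behind "sqrt‑scaled class weights" or the warmup/decay schedule, so the following is only a sketch under stated assumptions: the weights are taken as the square root of inverse class frequency, normalised to a mean of 1, and the schedule mirrors the usual linear warmup followed by linear decay to zero. The function names are hypothetical, and the step count assumes 2,800 samples per epoch, an effective batch size of 32, and all 7 epochs running without early stopping.

```python
import math

# Training-split class counts from the class distribution listed above.
CLASS_COUNTS = [762, 674, 795, 569]


def sqrt_scaled_weights(counts):
    """Assumed formula: sqrt of inverse class frequency, normalised to mean 1."""
    total = sum(counts)
    k = len(counts)
    raw = [math.sqrt(total / (k * c)) for c in counts]
    mean = sum(raw) / k
    return [w / mean for w in raw]


def lr_multiplier(step, total_steps, warmup_frac=0.1):
    """Linear warmup from 0 to 1 over the first warmup_frac of steps, then linear decay to 0."""
    warmup_steps = max(1, int(total_steps * warmup_frac))
    if step < warmup_steps:
        return step / warmup_steps
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


if __name__ == "__main__":
    weights = sqrt_scaled_weights(CLASS_COUNTS)
    print([round(w, 3) for w in weights])

    # Assumed step count: ceil(2800 / 32) optimizer steps per epoch, 7 epochs.
    total_steps = math.ceil(2800 / 32) * 7
    base_lr = 2e-5
    for step in (0, total_steps // 10, total_steps // 2, total_steps):
        print(step, base_lr * lr_multiplier(step, total_steps))
```

Square-root scaling softens the correction compared with plain inverse-frequency weights, which seems sensible here because the classes are only mildly imbalanced and a weighted sampler is already in use.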

Evaluation Results

All results below are reported on the 600-sample internal test split of TRAINING.jsonl, which was held out from training and validation.

| Metric | Value |
|---|---|
| Exact Match Accuracy | 64.5% |
| Adjacent (±1) Accuracy | 93.0% |
| Macro F1 | 0.663 |
| Weighted F1 | 0.653 |

Per‑Class Performance (internal test, 600 samples)

| Level | Precision | Recall | F1 | Support |
|---|---|---|---|---|
| L0 (Trivial) | 0.658 | 0.626 | 0.642 | 163 |
| L1 (Simple) | 0.457 | 0.628 | 0.529 | 145 |
| L2 (Moderate) | 0.683 | 0.571 | 0.622 | 170 |
| L3 (Complex) | 0.933 | 0.795 | 0.858 | 122 |

Confusion Matrix (internal test, 600 samples)

| | Pred L0 | Pred L1 | Pred L2 | Pred L3 |
|---|---|---|---|---|
| True L0 | 102 | 46 | 13 | 2 |
| True L1 | 35 | 91 | 18 | 1 |
| True L2 | 15 | 54 | 97 | 4 |
| True L3 | 3 | 8 | 14 | 97 |
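The headline metrics can be recomputed directly from the confusion matrix above, which is a handy sanity check if you evaluate the model on your own data. This is a small pure-Python sketch; the `summarize` function is a hypothetical helper, not part of the released code.

```python
# Confusion matrix from the internal test: rows are true levels, columns are predicted levels.
CM = [
    [102, 46, 13, 2],
    [35, 91, 18, 1],
    [15, 54, 97, 4],
    [3, 8, 14, 97],
]


def summarize(cm):
    """Compute exact accuracy, adjacent (+/-1) accuracy, macro F1, and weighted F1."""
    n = len(cm)
    total = sum(sum(row) for row in cm)
    exact = sum(cm[i][i] for i in range(n)) / total
    # A prediction counts as adjacent-correct if it is at most one level off.
    adjacent = sum(
        cm[i][j] for i in range(n) for j in range(n) if abs(i - j) <= 1
    ) / total

    f1s, supports = [], []
    for c in range(n):
        tp = cm[c][c]
        predicted = sum(cm[r][c] for r in range(n))  # column sum
        support = sum(cm[c])                         # row sum
        precision = tp / predicted if predicted else 0.0
        recall = tp / support if support else 0.0
        denom = precision + recall
        f1s.append(2 * precision * recall / denom if denom else 0.0)
        supports.append(support)

    return {
        "exact": exact,
        "adjacent": adjacent,
        "macro_f1": sum(f1s) / n,
        "weighted_f1": sum(f * s for f, s in zip(f1s, supports)) / total,
    }


if __name__ == "__main__":
    for name, value in summarize(CM).items():
        print(f"{name}: {value:.3f}")
```

Run on the matrix above, this reproduces the reported 64.5% exact accuracy, 93.0% adjacent accuracy, 0.663 macro F1, and 0.653 weighted F1.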

How to Get Started with the Model

from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn as nn

class PromptComplexityRouter(nn.Module):
    """DeBERTa-v3 backbone with a small MLP head that outputs one logit per complexity level."""

    def __init__(self, backbone="microsoft/deberta-v3-base", num_labels=4):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(backbone)
        hidden_size = self.backbone.config.hidden_size
        self.classifier = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(hidden_size, 256),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_labels),
        )

    def forward(self, input_ids, attention_mask):
        outputs = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        # Use the hidden state of the first ([CLS]) token as the prompt representation.
        cls_output = outputs.last_hidden_state[:, 0, :]
        return self.classifier(cls_output)

# Load the tokenizer from the Hub and the fine-tuned weights from a local file.
# This assumes pytorch_model.bin has been downloaded into the working directory.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("RowRed/ComplexityRouter")
model = PromptComplexityRouter()
# strict=False tolerates checkpoint keys that do not match this module
# (missing or unexpected), so a mismatch will not raise an error here.
model.load_state_dict(
    torch.load("pytorch_model.bin", map_location=device),
    strict=False
)
model.to(device)
model.eval()  # disable dropout for inference

# Predict complexity levels for a batch of prompts
prompts = ["What is 2+2?", "Explain quantum entanglement in detail."]
encoded = tokenizer(prompts, padding=True, truncation=True, return_tensors="pt").to(device)
with torch.no_grad():
    logits = model(encoded["input_ids"], encoded["attention_mask"])
    probs = torch.softmax(logits, dim=-1)
    predictions = torch.argmax(probs, dim=-1)

for prompt, level in zip(prompts, predictions):
    print(f"Level {level.item()}: {prompt}")

Citation

If you use this model, please cite:

@software{ComplexityRouter,
  author = {RowRed},
  title = {ComplexityRouter},
  year = {2026},
  url = {https://huggingface.co/RowRed/ComplexityRouter}
}

Additionally, acknowledge the base dataset and labeling model:

@dataset{oasst2,
  author = {OpenAssistant Contributors},
  title = {Open Assistant Conversations Dataset Release 2},
  year = {2023},
  url = {https://huggingface.co/datasets/OpenAssistant/oasst2}
}

@software{qwen3.5-4b,
  author = {Qwen Team},
  title = {Qwen3.5-4B},
  year = {2026},
  url = {https://huggingface.co/Qwen/Qwen3.5-4B}
}

License

This model is released under Apache‑2.0. The backbone (microsoft/deberta-v3-base) is MIT‑licensed. The training dataset is derived from OASST2 (Apache‑2.0) and Qwen3.5‑4B outputs (Apache‑2.0).