# ModelGate-Router

GRPO (Group Relative Policy Optimization) fine-tuned routing model for ModelGate's contract-aware query routing. Based on Arch-Router-1.5B. Classifies incoming queries as **simple**, **medium**, or **complex** to route them to the right model tier.

## Results

### Accuracy: Held-Out Eval (54 unseen prompts, zero training overlap)

| Tier | Stock Arch-Router | ModelGate-Router | Improvement |
|------|-------------------|------------------|-------------|
| Simple (33) | 87.9% | 81.8% | -6.1% |
| **Medium (14)** | **14.3%** | **85.7%** | **+71.4%** |
| Complex (7) | 100.0% | 85.7% | -14.3% |
| **Overall (54)** | **70.4%** | **83.3%** | **+13.0%** |

The stock model misclassifies **86% of medium queries** as complex, routing them to expensive premium models when a mid-tier model would suffice. ModelGate-Router fixes this.

### Latency: GGUF Q8_0 + CUDA (RTX 3080 Laptop)

| Metric | Stock | ModelGate-Router | Delta |
|--------|-------|------------------|-------|
| Avg | 62.6ms | 61.7ms | -0.9ms |
| P50 | 61.3ms | 60.5ms | -0.8ms |
| P95 | 67.8ms | 67.3ms | -0.5ms |

**Zero latency overhead.** ModelGate-Router is marginally faster.
### Latency by Inference Backend

| Backend | Model Size | Avg Latency | vs Transformers FP16 |
|---------|-----------|-------------|----------------------|
| Transformers FP16 | ~3.0 GB | 196ms | baseline |
| GGUF Q8_0 (CPU) | 1.6 GB | 768ms | 3.9x slower |
| **GGUF Q8_0 (CUDA)** | **1.6 GB** | **62ms** | **3.2x faster** |

### Quantization Impact on Accuracy

| Backend | Stock Accuracy | ModelGate-Router Accuracy |
|---------|----------------|---------------------------|
| Transformers FP16 | 72.2% | 81.5% |
| GGUF Q8_0 | 70.4% | 83.3% |

Q8_0 quantization causes **no meaningful accuracy degradation**.

### Chain-of-Thought vs Direct Output

We trained two variants: one with chain-of-thought reasoning (`<reasoning>` tags before the answer) and one with direct JSON output. Results on held-out data:

| Variant | Accuracy | Avg Latency | Verdict |
|---------|----------|-------------|---------|
| Stock | 72.2% | 196ms | Baseline |
| **ModelGate-Router (No-CoT)** | **81.5%** | **198ms** | **Best tradeoff** |
| ModelGate-Router (CoT) | 61.1% | 1,787ms | Overfit, 9x slower |

The CoT variant hurt accuracy on unseen data (it overfit to the training format) and added ~1.6s of latency per classification. The No-CoT variant is the clear winner.
## Why Fine-Tune?

The stock `katanemo/Arch-Router-1.5B` was trained for general-purpose intent routing. Our use case is specific: classify queries across customer support, insurance claims, and device protection into three complexity tiers. The stock model has a critical blind spot: it routes nearly all medium-complexity queries to the complex tier, wasting money on premium models.

## Architecture

```
Qwen/Qwen2.5-1.5B-Instruct (base LLM, 1.5B params)
        |
        v  Katanemo fine-tune
katanemo/Arch-Router-1.5B (general intent routing)
        |
        v  GRPO fine-tune (2.3% of params, LoRA rank 32)
ModelGate-Router (domain-specific complexity routing)
        |
        v  GGUF Q8_0 quantization
ModelGate-Router.Q8_0.gguf (1.6 GB, production-ready)
```
## Files

### Model Weights

| File | Size | Description |
|------|------|-------------|
| `ModelGate-Router.Q8_0.gguf` | 1.6 GB | ModelGate-Router in GGUF Q8_0, deploy with llama.cpp |
| `stock_arch_router.Q8_0.gguf` | 1.6 GB | Stock Arch-Router in GGUF Q8_0, for comparison |
| `ModelGate-Router-LoRA/` | 157 MB | ModelGate-Router LoRA adapter (best model) |
| `modelgate_arch_router_lora/` | 157 MB | CoT LoRA adapter (for reference only) |

### Data

| File | Description |
|------|-------------|
| `grpo_training_data.json` | 172 labeled training prompts across 4 domains |
| `grpo_eval_data.json` | 54 held-out eval prompts (zero overlap with training) |

### Scripts

| File | Description |
|------|-------------|
| `grpo_finetune_arch_router.ipynb` | CoT training notebook (Colab/local) |
| `grpo_run_nocot.py` | No-CoT training script (produced the best model) |
| `export_gguf.py` | Merges LoRA and converts to GGUF Q8_0 |
| `bench_gguf.py` | Benchmarks GGUF models via llama.cpp (accuracy + latency) |
| `bench_stock_vs_finetune.py` | Benchmarks via Transformers (FP16, 3-way comparison) |
## Training Data

**172 training examples** across 4 domains and 3 tiers:

| Domain | Count |
|--------|-------|
| customer_support | 51 |
| insurance_claims | 46 |
| device_protection | 37 |
| general | 38 |

| Tier | Count | Examples |
|------|-------|----------|
| simple | 95 | "What is your return policy?", "Is my claim approved?" |
| medium | 51 | "Compare the protection plans available for my new laptop..." |
| complex | 26 | "Analyze the multi-party liability exposure across claims..." |

**54 eval examples**: completely separate prompts, same domain/tier distribution, zero overlap with training data.
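As a sketch of what one labeled prompt might look like, the snippet below uses a hypothetical record shape (the actual field names inside `grpo_training_data.json` are an assumption here, not taken from the file):

```python
from collections import Counter

# Hypothetical record shape -- the real field names in
# grpo_training_data.json may differ.
example_record = {
    "prompt": "Compare the protection plans available for my new laptop",
    "route": "medium",  # ground-truth tier: simple | medium | complex
    "domain": "device_protection",
}

def tier_counts(records):
    """Count records per tier; on the full training set this should
    match the distribution above (95 simple / 51 medium / 26 complex)."""
    return Counter(r["route"] for r in records)

print(tier_counts([example_record]))
```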
## How GRPO Training Works

Unlike supervised fine-tuning, where you provide input-output pairs, GRPO:

1. **Generates** multiple candidate completions per prompt
2. **Scores** each with reward functions
3. **Reinforces** the best completions relative to the group

### No-CoT Reward Functions

| Function | Max Score | Purpose |
|----------|-----------|---------|
| `correctness_reward_func` | 2.0 | Route matches ground truth |
| `valid_route_reward_func` | 0.5 | Output is a valid tier name |
| `json_format_reward_func` | 1.0 | Output is clean JSON with a "route" key |
| `brevity_reward_func` | 0.5 | Rewards short outputs (just the JSON) |
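A minimal sketch of how rewards like these could be implemented. The actual signatures in `grpo_run_nocot.py` may differ (TRL reward functions typically take batched lists); the per-completion scorers and the 40-character brevity threshold below are assumptions:

```python
import json
import re

VALID_ROUTES = {"simple", "medium", "complex"}

def extract_route(completion: str):
    """Pull the "route" value out of a completion, if any JSON object parses."""
    match = re.search(r"\{.*\}", completion, re.DOTALL)
    if not match:
        return None
    try:
        return json.loads(match.group(0)).get("route")
    except json.JSONDecodeError:
        return None

def correctness_reward(completion: str, answer: str) -> float:
    # 2.0 if the predicted route matches ground truth
    return 2.0 if extract_route(completion) == answer else 0.0

def valid_route_reward(completion: str) -> float:
    # 0.5 if the output names any valid tier at all
    return 0.5 if extract_route(completion) in VALID_ROUTES else 0.0

def json_format_reward(completion: str) -> float:
    # 1.0 if the entire output is clean JSON with a "route" key
    try:
        return 1.0 if "route" in json.loads(completion) else 0.0
    except json.JSONDecodeError:
        return 0.0

def brevity_reward(completion: str) -> float:
    # 0.5 for short outputs (just the JSON, no extra chatter)
    return 0.5 if len(completion.strip()) <= 40 else 0.0
```

Summing the maxima gives the 4.0-point ceiling implied by the table, with correctness dominating so the group-relative update favors right answers over merely well-formatted ones.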
## Training Details

| Parameter | Value |
|-----------|-------|
| Base model | `katanemo/Arch-Router-1.5B` |
| Method | GRPO via Unsloth + TRL |
| LoRA rank | 32 |
| Trainable params | 36.9M / 1.58B (2.3%) |
| Training steps | 150 |
| Training time | **2.5 minutes** |
| Hardware | RTX 3080 Laptop (8 GB) |
| VRAM usage | ~6 GB (4-bit quantized during training) |
| Generations per prompt | 4 |
| Learning rate | 5e-6 |
| Max completion length | 64 tokens |
## How to Reproduce

### Train the No-CoT Model

```bash
# Requires: pip install unsloth vllm trl
python finetuning/grpo_run_nocot.py
# Output: ModelGate-Router-LoRA/
```

### Export to GGUF

```bash
python finetuning/export_gguf.py nocot
# Output: finetuning/ModelGate-Router.Q8_0.gguf
```

### Benchmark

```bash
# GGUF benchmark (requires llama-cpp-python with CUDA)
python finetuning/bench_gguf.py

# Transformers FP16 benchmark (3-way: stock vs no-CoT vs CoT)
python finetuning/bench_stock_vs_finetune.py
```
## Production Deployment

The recommended deployment uses `ModelGate-Router.Q8_0.gguf` with llama.cpp:

```python
import json

from llama_cpp import Llama

model = Llama(
    model_path="finetuning/ModelGate-Router.Q8_0.gguf",
    n_ctx=512,
    n_gpu_layers=-1,  # All layers on GPU
)

# Classify a query (routing_prompt is the tier-routing prompt built upstream)
response = model.create_chat_completion(
    messages=[{"role": "user", "content": routing_prompt}],
    max_tokens=30,
    temperature=0,
)
route = json.loads(response["choices"][0]["message"]["content"])["route"]
# route is "simple", "medium", or "complex"
```

**Expected performance**: ~62ms per classification, 83%+ accuracy, 1.6 GB VRAM.
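Even at temperature 0, a production router should not assume the model's output always parses. A hypothetical hardening helper (not part of the shipped scripts) that falls back to the most capable tier on malformed output:

```python
import json

VALID_ROUTES = {"simple", "medium", "complex"}

def parse_route(raw: str, default: str = "complex") -> str:
    """Parse the router's JSON output; fall back to the safest
    (most capable) tier if the output is malformed.

    Hypothetical helper -- the fallback policy is an assumption."""
    try:
        route = json.loads(raw.strip()).get("route")
    except json.JSONDecodeError:
        route = None
    return route if route in VALID_ROUTES else default

print(parse_route('{"route": "medium"}'))  # medium
print(parse_route("garbage output"))       # complex
```

Defaulting to `complex` trades a little cost for answer quality when the router misbehaves; defaulting to `medium` would be the cheaper policy.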
## Route Policies

The three tiers the model classifies into (defined in `backend/services/classifier.py`):

| Tier | Description | Model Tier | Cost |
|------|-------------|-----------|------|
| simple | FAQs, status checks, basic lookups | gpt-4o-mini, gemini-flash | $0.10-0.60/M tokens |
| medium | Multi-step reasoning, comparisons, troubleshooting | gpt-4o, claude-sonnet | $2.50-15.00/M tokens |
| complex | Multi-document analysis, legal/financial reasoning | gemini-2.5-pro, claude-sonnet | $2.50-15.00/M tokens |

Correctly routing simple queries to cheap models instead of premium ones is the core value proposition. The stock model's 14% medium accuracy means it wastes money routing mid-tier queries to expensive models. ModelGate-Router's 86% medium accuracy captures those savings.
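A back-of-envelope sketch of those savings. The token count and the two price points below are illustrative assumptions (picked from the cost ranges in the table), not measured figures:

```python
# Assumptions (illustrative): 1,000 tokens per query, a mid-tier model
# at $2.50/M tokens, and a premium model at $15.00/M tokens.
TOKENS_PER_QUERY = 1_000
MID_COST_PER_M = 2.50
PREMIUM_COST_PER_M = 15.00

def cost_per_medium_query(medium_accuracy: float) -> float:
    """Expected cost of one medium-tier query: correctly routed queries
    hit the mid-tier model, misrouted ones hit the premium model."""
    per_token = (medium_accuracy * MID_COST_PER_M
                 + (1 - medium_accuracy) * PREMIUM_COST_PER_M) / 1_000_000
    return per_token * TOKENS_PER_QUERY

stock = cost_per_medium_query(0.143)  # stock: 14.3% medium accuracy
tuned = cost_per_medium_query(0.857)  # ModelGate-Router: 85.7%
print(f"stock ${stock:.5f}/query, tuned ${tuned:.5f}/query")
# Under these assumptions, roughly a 3x cost reduction per medium query.
```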
## References

- [Arch-Router-1.5B](https://huggingface.co/katanemo/Arch-Router-1.5B): base model
- [Qwen2.5-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2.5-1.5B-Instruct): foundation model
- [Unsloth](https://github.com/unslothai/unsloth): training framework
- [TRL GRPOTrainer](https://huggingface.co/docs/trl/main/en/grpo_trainer): GRPO implementation
- [llama.cpp](https://github.com/ggerganov/llama.cpp): GGUF inference engine
|