Upload README.md with huggingface_hub
README.md (CHANGED)

# Memory Routing Agent

**A specialized 8B model that outperforms its 104B teacher on marketing conversation classification.**

[Model on Hugging Face](https://huggingface.co/MuratcanKoylan/Marketing-Memory-Routing-8B)
[GitHub Repository](https://github.com/muratcankoylan/memory-routing-agent)
[License: Apache 2.0](LICENSE)

---

## The Experiment

This project demonstrates **prompt distillation**: training a small, specialized model to outperform the large model that generated its training data.

### The Challenge

Marketing AI assistants need to remember the right information from conversations. Not everything is worth storing; you need to distinguish between:

- **Valuable**: "Our brand voice is professional but approachable" → Store in long-term memory
- **Transactional**: "What time is the meeting tomorrow?" → Don't store

This is a **13-category classification problem** with nuanced distinctions between company-level and user-level information, different persistence horizons, and the critical ability to say "none" for irrelevant content.
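
To make the input/output contract concrete, here is a small illustrative sketch (the messages echo the examples above; the exact output format is an assumption, and the category names are defined in the taxonomy later in this README):

```python
# Illustrative routing examples; labels come from the 13-category taxonomy below.
examples = [
    {
        "message": "Our brand voice is professional but approachable.",
        "categories": ["company.brand_core"],  # valuable -> store long-term
    },
    {
        "message": "What time is the meeting tomorrow?",
        "categories": ["none"],                # transactional -> don't store
    },
]
```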

### The Approach

1. **Generate synthetic data** using Cohere Command-R-Plus (104B) as the teacher
2. **Fine-tune Llama-3.1-8B** with LoRA using Tinker's training platform
3. **Apply reinforcement learning** with a custom reward function
4. **Benchmark against the teacher** on challenging, held-out scenarios

### The Result

| Model | Parameters | Avg F1 | Exact Match |
|-------|------------|--------|-------------|
| **Ours** | **8B** | **0.68** | **60%** |
| Cohere Command-R-Plus | 104B | 0.61 | 26% |

**Our 8B model achieves 11.1% higher F1 and 2.3x better exact match accuracy than the 104B teacher, while being 13x smaller.**

The student surpassed the teacher through:

- **Focused training**: The model only learns this one task, not general capabilities
- **RL refinement**: The reward function optimizes for exact category matching, not just plausible outputs
- **Clean data**: Synthetic data with consistent labeling, no noise from human annotation disagreements

---

## Training Visualizations

### Phase 1: Supervised Fine-Tuning



100 training steps reduced loss from 5.47 to 0.26 (95% reduction). The model learned the basic classification task in the first epoch.

### Phase 2: Reinforcement Learning



30 RL iterations improved mean reward from 0.73 to 0.93. The reward function combines F1 score, temporal alignment, scope correctness, and storage efficiency.

### Model Comparison



Our model excels at exact matching (60% vs 26%) because RL optimizes for getting all categories right, not just some.

### Performance by Difficulty



The 8B model dominates on easy cases (+79% F1) and matches on medium cases. The 104B model still wins on hard multi-label scenarios.

---

## Key Results

| Metric | Our Model (8B) | Cohere (104B) |
|--------|----------------|---------------|
| **Avg F1** | **0.68** | 0.61 |
| **Exact Match** | **60%** | 26% |
| Any Match | 72% | 82% |
| Model Size | 8B | 104B |
| **Improvement** | **+11.1% F1** | baseline |
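
As a point of reference, here is a minimal sketch of how per-example multi-label F1 and exact match can be computed over category sets (conceptually consistent with the table above, but not necessarily the exact code in `training/final_benchmark.py`):

```python
def f1_score(predicted: set[str], gold: set[str]) -> float:
    """Set-based F1 between predicted and gold category sets."""
    if not predicted and not gold:
        return 1.0
    if not predicted or not gold:
        return 0.0
    tp = len(predicted & gold)
    precision, recall = tp / len(predicted), tp / len(gold)
    return 0.0 if precision + recall == 0 else 2 * precision * recall / (precision + recall)

def exact_match(predicted: set[str], gold: set[str]) -> bool:
    """True only when every gold category is predicted and nothing extra is added."""
    return predicted == gold

# A partially correct prediction earns F1 credit but no exact-match credit.
pred, gold = {"company.brand_core"}, {"company.brand_core", "user.role_context"}
print(round(f1_score(pred, gold), 2), exact_match(pred, gold))  # 0.67 False
```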

### Reward Components (Final RL Iteration)

| Component | Score | Description |
|-----------|-------|-------------|
| R_F1 | 0.90 | F1 score vs gold labels |
| R_temp | 0.95 | Temporal alignment |
| R_parity | 1.00 | Company/user scope |
| R_eff | 1.00 | Storage efficiency |

---

## What It Does

The Memory Routing Agent classifies marketing conversations into 13 memory categories:

### Company Categories (Long-term business context)

| Category | Description | Persistence |
|----------|-------------|-------------|
| `company.brand_core` | Voice, values, positioning | Long (>1y) |
| `company.strategic_signatures` | Decision frameworks | Long (>1y) |
| `company.knowledge_artifacts` | Docs, style guides | Long (>1y) |
| `company.business_priorities` | Quarterly goals | Short (<3m) |
| `company.tools_config` | Integrations, APIs | Medium (~6m) |
| `company.performance_context` | Campaign metrics | Rolling (~6m) |

### User Categories (Personal preferences)

| Category | Description | Persistence |
|----------|-------------|-------------|
| `user.communication_style` | Tone, format preferences | Long (>1y) |
| `user.strategic_approach` | Personal priorities | Long (>1y) |
| `user.role_context` | Title, scope | Medium (~1y) |
| `user.workflow_patterns` | Review cadence | Medium (~1y) |
| `user.session_history` | Immediate context | Short (<2w) |
| `user.interaction_preferences` | Coaching style | Evolving |

### Special

| Category | Description |
|----------|-------------|
| `none` | Transactional or irrelevant content |
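
A minimal sketch of the taxonomy as a lookup table, useful for validating model outputs (category names and horizons are copied from the tables above; the dictionary itself is illustrative, not part of the released code):

```python
# Category -> rough persistence horizon (from the tables above).
PERSISTENCE = {
    "company.brand_core": "long (>1y)",
    "company.strategic_signatures": "long (>1y)",
    "company.knowledge_artifacts": "long (>1y)",
    "company.business_priorities": "short (<3m)",
    "company.tools_config": "medium (~6m)",
    "company.performance_context": "rolling (~6m)",
    "user.communication_style": "long (>1y)",
    "user.strategic_approach": "long (>1y)",
    "user.role_context": "medium (~1y)",
    "user.workflow_patterns": "medium (~1y)",
    "user.session_history": "short (<2w)",
    "user.interaction_preferences": "evolving",
    "none": None,  # transactional content is not stored
}

def keep_known(labels: list[str]) -> list[str]:
    """Drop any label that is not one of the 13 categories."""
    return [label for label in labels if label in PERSISTENCE]
```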

---

## Training Pipeline

### Reward Function

The RL phase uses a composite reward:

```
R_total = 0.6 × R_F1 + 0.2 × R_temp + 0.1 × R_parity + 0.1 × R_eff
```

| Component | Weight | Description |
|-----------|--------|-------------|
| R_F1 | 60% | F1 score vs gold labels |
| R_temp | 20% | Persistence horizon alignment |
| R_parity | 10% | Company/user scope correctness |
| R_eff | 10% | Storage efficiency (≤3 categories) |
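
A minimal sketch of the composite reward (weights match the formula above; the component scores themselves would come from the task-specific scoring code, which is not shown here):

```python
# Weights from R_total = 0.6*R_F1 + 0.2*R_temp + 0.1*R_parity + 0.1*R_eff
def total_reward(r_f1: float, r_temp: float, r_parity: float, r_eff: float) -> float:
    """Composite reward; each component is assumed to lie in [0, 1]."""
    return 0.6 * r_f1 + 0.2 * r_temp + 0.1 * r_parity + 0.1 * r_eff

# With the final-iteration component scores reported above:
print(round(total_reward(0.90, 0.95, 1.00, 1.00), 2))  # 0.93
```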

---

## Quick Start

### Installation

```bash
# Create and activate a virtual environment
python -m venv venv
source venv/bin/activate

# Install dependencies
pip install -r requirements.txt
```

### Environment Setup

```bash
# Create .env file with your API keys
echo "TINKER_API_KEY=your_tinker_key" >> .env
echo "COHERE_API_KEY=your_cohere_key" >> .env
echo "HF_TOKEN=your_huggingface_token" >> .env
```
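
To read those keys back from Python, here is a minimal sketch using `python-dotenv` (an assumption about tooling; the training scripts may load configuration differently):

```python
import os
from dotenv import load_dotenv  # pip install python-dotenv

load_dotenv()  # reads the .env file created above
tinker_key = os.getenv("TINKER_API_KEY")
cohere_key = os.getenv("COHERE_API_KEY")
hf_token = os.getenv("HF_TOKEN")
```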

### Usage

```python
import tinker
from tinker import types
from tinker_cookbook import renderers
from tinker_cookbook.tokenizer_utils import get_tokenizer

# Load model from Tinker checkpoint
service_client = tinker.ServiceClient()
checkpoint = "tinker://4f4bae1f-5a95-5f53-a55a-a14f2872825c:train:0/sampler_weights/rl_iter_012"
sampling_client = service_client.create_sampling_client(model_path=checkpoint)

# ... render the conversation, sample, and collect the reply into `response` ...

print(f"Categories: {response['content']}")
# Output: company.brand_core
```
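
The reply arrives as plain text; a small post-processing sketch, assuming the model returns a comma-separated list of category names (the exact reply format is an assumption rather than documented behavior):

```python
def parse_categories(raw: str) -> list[str]:
    """Split a reply like 'company.brand_core, user.role_context' into labels."""
    return [part.strip() for part in raw.split(",") if part.strip()]

print(parse_categories("company.brand_core"))  # ['company.brand_core']
```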

---

## Project Structure

```
└── README.md              # This file
```

---

## Benchmark

The Marketing Routing Benchmark contains 50 challenging scenarios across 7 domains. Run it with:

```bash
python training/final_benchmark.py
```

---

## Training Your Own Model

### 1. Generate Synthetic Data

---

## Limitations

- **Multi-label**: Under-predicts when multiple categories apply
- **Overlap**: Struggles with company/user category overlap on edge cases
- **Domain**: Marketing-specific; not tested on other domains

---

## Links

- **Hugging Face Model**: [MuratcanKoylan/Marketing-Memory-Routing-8B](https://huggingface.co/MuratcanKoylan/Marketing-Memory-Routing-8B)
- **GitHub Repository**: [muratcankoylan/memory-routing-agent](https://github.com/muratcankoylan/memory-routing-agent)
- **Training Platform**: [Tinker by Thinking Machines](https://thinkingmachines.ai/)

---

## Citation

```bibtex
}
```

---

## License

Apache 2.0

---

## Acknowledgments

- [Thinking Machines](https://thinkingmachines.ai/) for the Tinker training platform
- [Cohere](https://cohere.com/) for the Command-R-Plus teacher model
- [Meta](https://ai.meta.com/) for the Llama 3.1 base model
- [Anthropic](https://anthropic.com/) for Claude, which assisted in developing this project