MuratcanKoylan committed on
Commit 192d8d2 · verified · 1 Parent(s): e4c940d

Upload README.md with huggingface_hub

Files changed (1)
  1. README.md +139 -58
README.md CHANGED
@@ -19,69 +19,128 @@ pipeline_tag: text-classification
  # Memory Routing Agent

- A specialized 8B parameter model that **outperforms 104B models** on marketing conversation classification.

- ## Key Results

- ![Model Comparison](assets/model_comparison.png)

- | Metric | Our Model (8B) | Cohere (104B) |
- |--------|----------------|---------------|
- | **Avg F1** | **0.68** | 0.61 |
- | Exact Match | **60%** | 26% |
- | Model Size | 8B | 104B |
- | **Improvement** | **+11.1% F1** | baseline |

- Our 8B model achieves **11.1% higher F1 score** than the 104B teacher model that generated its training data, while being **13x smaller**.

- ## Training Results

- ### Phase 1: Supervised Fine-Tuning (SFT)

  ![SFT Loss](assets/sft_loss.png)

- - **100 training steps** on 2,001 synthetic conversations
- - Loss dropped from **5.47 → 0.26** (95% reduction)
- - Best test loss: **0.105** at step 90

- ### Phase 2: Reinforcement Learning (RL)

  ![RL Reward](assets/rl_reward.png)

- - **30 RL iterations** with importance sampling policy gradient
- - Mean reward improved from **0.73 → 0.93** (+27%)
- - Accuracy maintained at **99.9%+** throughout

- ### Reward Components

- ![RL Components](assets/rl_components.png)

- | Component | Start | End | Description |
- |-----------|-------|-----|-------------|
- | R_F1 | 0.64 | 0.90 | F1 score vs gold labels |
- | R_temp | 0.81 | 0.95 | Temporal alignment |
- | R_parity | 0.86 | 1.00 | Company/user scope |
- | R_eff | 1.00 | 1.00 | Storage efficiency |

- ## Performance by Difficulty

  ![Difficulty Comparison](assets/difficulty_comparison.png)

- | Difficulty | Our Model | Cohere (104B) | Winner |
- |------------|-----------|---------------|--------|
- | Easy | **0.86** | 0.48 | Ours (+79%) |
- | Medium | **0.65** | 0.64 | Ours (+2%) |
- | Hard | 0.50 | **0.72** | Cohere |

- Our model excels at clear-cut cases but the larger model handles ambiguous multi-label scenarios better.

  ## What It Does

- The Memory Routing Agent classifies marketing conversations into 13 categories to determine what information should be stored in an AI assistant's long-term memory:

- - **Company categories**: brand_core, strategic_signatures, knowledge_artifacts, business_priorities, tools_config, performance_context
- - **User categories**: communication_style, strategic_approach, role_context, workflow_patterns, session_history, interaction_preferences
- - **None**: Transactional or irrelevant content

  ## Training Pipeline

@@ -109,6 +168,21 @@ The Memory Routing Agent classifies marketing conversations into 13 categories t
  └─────────────────────────────────────────────────────────────────┘
  ```

  ## Quick Start

  ### Installation
@@ -123,14 +197,13 @@ python -m venv venv
  source venv/bin/activate

  # Install dependencies
- pip install tinker-toolkit python-dotenv cohere
- pip install -e ".[envs]"
  ```

  ### Environment Setup

  ```bash
- # Create .env file
  echo "TINKER_API_KEY=your_tinker_key" >> .env
  echo "COHERE_API_KEY=your_cohere_key" >> .env
  echo "HF_TOKEN=your_huggingface_token" >> .env
@@ -144,7 +217,7 @@ from tinker import types
  from tinker_cookbook import renderers
  from tinker_cookbook.tokenizer_utils import get_tokenizer

- # Load model
  service_client = tinker.ServiceClient()
  checkpoint = "tinker://4f4bae1f-5a95-5f53-a55a-a14f2872825c:train:0/sampler_weights/rl_iter_012"
  sampling_client = service_client.create_sampling_client(model_path=checkpoint)
@@ -174,6 +247,8 @@ print(f"Categories: {response['content']}")
  # Output: company.brand_core
  ```

  ## Project Structure

  ```
@@ -203,6 +278,8 @@ memory-routing-agent/
  └── README.md # This file
  ```

  ## Benchmark

  The Marketing Routing Benchmark contains 50 challenging scenarios across 7 domains:
@@ -223,6 +300,8 @@ The Marketing Routing Benchmark contains 50 challenging scenarios across 7 domai
  python training/final_benchmark.py
  ```

  ## Training Your Own Model

  ### 1. Generate Synthetic Data
@@ -250,27 +329,24 @@ python training/train_v2.py
  python training/final_benchmark.py
  ```

- ## Reward Function
-
- The RL phase uses a composite reward:
-
- ```
- R_total = 0.6 × R_F1 + 0.2 × R_temp + 0.1 × R_parity + 0.1 × R_eff
- ```
-
- | Component | Weight | Description |
- |-----------|--------|-------------|
- | R_F1 | 60% | F1 score vs gold labels |
- | R_temp | 20% | Persistence horizon alignment |
- | R_parity | 10% | Company/user scope correctness |
- | R_eff | 10% | Storage efficiency (≤3 categories) |

  ## Limitations

  - **Multi-label**: Under-predicts when multiple categories apply
- - **Overlap**: Struggles with company/user category overlap
  - **Domain**: Marketing-specific; not tested on other domains

  ## Citation

  ```bibtex
@@ -282,12 +358,17 @@ R_total = 0.6 × R_F1 + 0.2 × R_temp + 0.1 × R_parity + 0.1 × R_eff
  }
  ```

  ## License

  Apache 2.0

  ## Acknowledgments

- - [Thinking Machines](https://thinkingmachines.ai/) for Tinker training platform
  - [Cohere](https://cohere.com/) for Command-R-Plus teacher model
  - [Meta](https://ai.meta.com/) for Llama 3.1 base model
 
 
  # Memory Routing Agent

+ **A specialized 8B model that outperforms its 104B teacher on marketing conversation classification.**
+
+ [![HuggingFace](https://img.shields.io/badge/🤗%20Model-Marketing--Memory--Routing--8B-blue)](https://huggingface.co/MuratcanKoylan/Marketing-Memory-Routing-8B)
+ [![GitHub](https://img.shields.io/badge/GitHub-memory--routing--agent-black)](https://github.com/muratcankoylan/memory-routing-agent)
+ [![License](https://img.shields.io/badge/License-Apache%202.0-green)](LICENSE)
+
+ ---
+
+ ## The Experiment
+
+ This project demonstrates **prompt distillation**: training a small, specialized model to outperform the large model that generated its training data.
+
+ ### The Challenge
+
+ Marketing AI assistants need to remember the right information from conversations. Not everything is worth storing; you need to distinguish between:
+ - **Valuable**: "Our brand voice is professional but approachable" → Store in long-term memory
+ - **Transactional**: "What time is the meeting tomorrow?" → Don't store
+
+ This is a **13-category classification problem** with nuanced distinctions between company-level and user-level information, different persistence horizons, and the critical ability to say "none" for irrelevant content.
+
+ ### The Approach
+
+ 1. **Generate synthetic data** using Cohere Command-R-Plus (104B) as the teacher
+ 2. **Fine-tune Llama-3.1-8B** with LoRA using Tinker's training platform
+ 3. **Apply reinforcement learning** with a custom reward function
+ 4. **Benchmark against the teacher** on challenging, held-out scenarios
+
+ ### The Result
+
+ | Model | Parameters | Avg F1 | Exact Match |
+ |-------|------------|--------|-------------|
+ | **Ours** | **8B** | **0.68** | **60%** |
+ | Cohere Command-R-Plus | 104B | 0.61 | 26% |
+
+ **Our 8B model achieves 11.1% higher F1 and 2.3x better exact-match accuracy than the 104B teacher, while being 13x smaller.**
+
+ The student surpassed the teacher through:
+ - **Focused training**: The model learns only this one task, not general capabilities
+ - **RL refinement**: The reward function optimizes for exact category matching, not just plausible outputs
+ - **Clean data**: Synthetic data with consistent labeling, and no noise from human annotation disagreements
+
+ ---
+
+ ## Training Visualizations
+
+ ### Phase 1: Supervised Fine-Tuning

  ![SFT Loss](assets/sft_loss.png)

+ 100 training steps reduced loss from 5.47 to 0.26 (95% reduction). The model learned the basic classification task in the first epoch.

+ ### Phase 2: Reinforcement Learning

  ![RL Reward](assets/rl_reward.png)

+ 30 RL iterations improved mean reward from 0.73 to 0.93. The reward function combines F1 score, temporal alignment, scope correctness, and storage efficiency.

+ ### Model Comparison

+ ![Model Comparison](assets/model_comparison.png)

+ Our model excels at exact matching (60% vs 26%) because RL optimizes for getting all categories right, not just some.

+ ### Performance by Difficulty

  ![Difficulty Comparison](assets/difficulty_comparison.png)

+ The 8B model dominates on easy cases (+79% F1) and matches on medium cases. The 104B model still wins on hard multi-label scenarios.

+ ---
+
+ ## Key Results
+
+ | Metric | Our Model (8B) | Cohere (104B) |
+ |--------|----------------|---------------|
+ | **Avg F1** | **0.68** | 0.61 |
+ | **Exact Match** | **60%** | 26% |
+ | Any Match | 72% | 82% |
+ | Model Size | 8B | 104B |
+ | **Improvement** | **+11.1% F1** | baseline |
+
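To clarify how the two headline metrics differ, the table's Avg F1 and Exact Match can be computed from per-example category sets. A minimal set-based sketch; these helpers are hypothetical, not the repository's evaluation code:

```python
def f1(pred: set[str], gold: set[str]) -> float:
    """Set-based F1 between predicted and gold category sets."""
    if not pred and not gold:
        return 1.0  # both empty: a perfect "none" prediction
    tp = len(pred & gold)
    if tp == 0:
        return 0.0
    precision = tp / len(pred)
    recall = tp / len(gold)
    return 2 * precision * recall / (precision + recall)

def exact_match(pred: set[str], gold: set[str]) -> bool:
    """Exact match requires every category right and nothing extra."""
    return pred == gold

pred = {"company.brand_core", "user.role_context"}
gold = {"company.brand_core"}
print(f1(pred, gold))           # ≈ 0.667 (precision 0.5, recall 1.0)
print(exact_match(pred, gold))  # False
```

This is why exact match is the stricter column: one spurious category drops it to zero while F1 still gives partial credit.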
+ ### Reward Components (Final RL Iteration)
+
+ | Component | Score | Description |
+ |-----------|-------|-------------|
+ | R_F1 | 0.90 | F1 score vs gold labels |
+ | R_temp | 0.95 | Temporal alignment |
+ | R_parity | 1.00 | Company/user scope |
+ | R_eff | 1.00 | Storage efficiency |
+
+ ---

  ## What It Does

+ The Memory Routing Agent classifies marketing conversations into 13 memory categories:
+
+ ### Company Categories (Long-term business context)
+ | Category | Description | Persistence |
+ |----------|-------------|-------------|
+ | `company.brand_core` | Voice, values, positioning | Long (>1y) |
+ | `company.strategic_signatures` | Decision frameworks | Long (>1y) |
+ | `company.knowledge_artifacts` | Docs, style guides | Long (>1y) |
+ | `company.business_priorities` | Quarterly goals | Short (<3m) |
+ | `company.tools_config` | Integrations, APIs | Medium (~6m) |
+ | `company.performance_context` | Campaign metrics | Rolling (~6m) |
+
+ ### User Categories (Personal preferences)
+ | Category | Description | Persistence |
+ |----------|-------------|-------------|
+ | `user.communication_style` | Tone, format preferences | Long (>1y) |
+ | `user.strategic_approach` | Personal priorities | Long (>1y) |
+ | `user.role_context` | Title, scope | Medium (~1y) |
+ | `user.workflow_patterns` | Review cadence | Medium (~1y) |
+ | `user.session_history` | Immediate context | Short (<2w) |
+ | `user.interaction_preferences` | Coaching style | Evolving |
+
+ ### Special
+ | Category | Description |
+ |----------|-------------|
+ | `none` | Transactional or irrelevant content |

+ ---
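The 13 labels above form a closed vocabulary, so raw model output can be validated before anything is written to memory. A small sketch of that check; the helper name and comma-separated output format are assumptions, not the repository's API:

```python
COMPANY = {"brand_core", "strategic_signatures", "knowledge_artifacts",
           "business_priorities", "tools_config", "performance_context"}
USER = {"communication_style", "strategic_approach", "role_context",
        "workflow_patterns", "session_history", "interaction_preferences"}
# Full closed vocabulary: 6 company + 6 user + "none" = 13 labels.
VALID = {f"company.{c}" for c in COMPANY} | {f"user.{u}" for u in USER} | {"none"}

def parse_labels(raw: str) -> list[str]:
    """Split a comma-separated model output and keep only known categories."""
    parts = [part.strip() for part in raw.split(",") if part.strip()]
    return [label for label in parts if label in VALID]

print(parse_labels("company.brand_core, user.role_context, made_up_label"))
# ['company.brand_core', 'user.role_context']
```

Filtering against the fixed vocabulary keeps a hallucinated label from ever reaching the memory store.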
 
 
  ## Training Pipeline

  └─────────────────────────────────────────────────────────────────┘
  ```

+ ### Reward Function
+
+ ```
+ R_total = 0.6 × R_F1 + 0.2 × R_temp + 0.1 × R_parity + 0.1 × R_eff
+ ```
+
+ | Component | Weight | Description |
+ |-----------|--------|-------------|
+ | R_F1 | 60% | F1 score vs gold labels |
+ | R_temp | 20% | Persistence horizon alignment |
+ | R_parity | 10% | Company/user scope correctness |
+ | R_eff | 10% | Storage efficiency (≤3 categories) |
+
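As a sanity check, the weighted sum above reproduces the reported final mean reward from the final-iteration component scores listed earlier in this README. A minimal sketch, not the repository's reward code; component rewards are assumed to be normalized to [0, 1]:

```python
# Weights from the table above.
WEIGHTS = {"R_F1": 0.6, "R_temp": 0.2, "R_parity": 0.1, "R_eff": 0.1}

def total_reward(components: dict[str, float]) -> float:
    """Composite reward: weighted sum of normalized component rewards."""
    return sum(WEIGHTS[name] * components[name] for name in WEIGHTS)

# Final-iteration component scores from the Reward Components table:
final = {"R_F1": 0.90, "R_temp": 0.95, "R_parity": 1.00, "R_eff": 1.00}
print(round(total_reward(final), 2))  # 0.93
```

0.6×0.90 + 0.2×0.95 + 0.1×1.00 + 0.1×1.00 = 0.93, matching the mean reward reported for the last RL iteration.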
+ ---
+
  ## Quick Start

  ### Installation

  source venv/bin/activate

  # Install dependencies
+ pip install -r requirements.txt
  ```

  ### Environment Setup

  ```bash
+ # Create .env file with your API keys
  echo "TINKER_API_KEY=your_tinker_key" >> .env
  echo "COHERE_API_KEY=your_cohere_key" >> .env
  echo "HF_TOKEN=your_huggingface_token" >> .env

  from tinker_cookbook import renderers
  from tinker_cookbook.tokenizer_utils import get_tokenizer

+ # Load model from Tinker checkpoint
  service_client = tinker.ServiceClient()
  checkpoint = "tinker://4f4bae1f-5a95-5f53-a55a-a14f2872825c:train:0/sampler_weights/rl_iter_012"
  sampling_client = service_client.create_sampling_client(model_path=checkpoint)

  # Output: company.brand_core
  ```
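Downstream code usually needs the scope and persistence horizon implied by an output like `company.brand_core`. A sketch of that post-processing step; the helper and the abbreviated horizon map are hypothetical, with horizon values taken from the category tables above:

```python
# Abbreviated horizon map; the full mapping is in the category tables above.
HORIZON = {
    "company.brand_core": "long",
    "company.business_priorities": "short",
    "user.session_history": "short",
}

def route(category: str) -> dict:
    """Split a 'scope.name' prediction into fields for the memory store."""
    if category == "none":
        return {"store": False}  # transactional content is never stored
    scope, _, name = category.partition(".")
    return {"store": True, "scope": scope, "name": name,
            "horizon": HORIZON.get(category, "unknown")}

print(route("company.brand_core"))
# {'store': True, 'scope': 'company', 'name': 'brand_core', 'horizon': 'long'}
```

The `none` branch matters most in practice: it is what keeps "What time is the meeting tomorrow?" out of long-term memory.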
+ ---
+
  ## Project Structure

  ```
  └── README.md # This file
  ```

+ ---
+
  ## Benchmark

  The Marketing Routing Benchmark contains 50 challenging scenarios across 7 domains:

  python training/final_benchmark.py
  ```

+ ---
+
  ## Training Your Own Model

  ### 1. Generate Synthetic Data

  python training/final_benchmark.py
  ```

+ ---

  ## Limitations

  - **Multi-label**: Under-predicts when multiple categories apply
+ - **Overlap**: Struggles with company/user category overlap on edge cases
  - **Domain**: Marketing-specific; not tested on other domains

+ ---
+
+ ## Links
+
+ - **HuggingFace Model**: [MuratcanKoylan/Marketing-Memory-Routing-8B](https://huggingface.co/MuratcanKoylan/Marketing-Memory-Routing-8B)
+ - **GitHub Repository**: [muratcankoylan/memory-routing-agent](https://github.com/muratcankoylan/memory-routing-agent)
+ - **Training Platform**: [Tinker by Thinking Machines](https://thinkingmachines.ai/)
+
+ ---
+
  ## Citation

  ```bibtex
  }
  ```

+ ---
+
  ## License

  Apache 2.0

+ ---
+
  ## Acknowledgments

+ - [Thinking Machines](https://thinkingmachines.ai/) for the Tinker training platform
  - [Cohere](https://cohere.com/) for Command-R-Plus teacher model
  - [Meta](https://ai.meta.com/) for Llama 3.1 base model
+ - [Anthropic](https://anthropic.com/) for Claude, which assisted in developing this project