---
license: mit
datasets:
- Magpie-Align/Magpie-Qwen2.5-Pro-1M-v0.1
language:
- en
base_model:
- Qwen/Qwen3-0.6B
---

# Context Merging: from Tokens to Entities and Concepts

This repo contains a minimal research pipeline that compresses input context for Qwen3 by grouping dependent subtokens early, then trains a small adapter to consume the grouped embeddings.

- `prepare_dataset.py` builds a local dataset of grouped embeddings from a base Qwen3 with a custom layer 0 that performs token grouping.
- `train_custom_qwen3.py` fine-tunes a customized Qwen3 that adds a small MLP adapter for grouped inputs, while freezing all weights except layer 0.
- `inference_qwen3_merged.py` runs end-to-end inference by first grouping with the base model, then generating with the trained model that understands grouped inputs. It also reports performance metrics and estimated attention-memory savings.

---

## How it works

1. **Layer-0 grouping at prefill**

   A custom decoder layer 0 computes attention on the full token sequence, clusters adjacent tokens using lightweight heuristics plus attention relations, then averages token vectors per group. The grouped result is added back to a residual projection and saved as `grouped_hidden_states`.

2. **Dataset building**

   The dataset builder swaps in the custom layer 0, feeds formatted prompts, extracts the stored `grouped_hidden_states`, and serializes them together with target responses.

3. **Model training**

   The training model wraps Qwen3 with a **GroupedInputMLPAdapter** that processes the grouped embeddings during prefill. Only layer 0 and the adapter are trainable; embeddings, upper layers, final norm, and LM head are frozen. Prefill uses `grouped_inputs` as `inputs_embeds`, then generation proceeds with past-key-values.

4. **Inference**

   The inference runner loads two models: a grouping model with the custom layer 0, and your trained model. It reports token compression, timing, and memory usage. Savings are also estimated with a simple attention-cost proxy that scales with sequence length squared.
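
The grouping-and-averaging step can be sketched with a toy version. This is not the repo's exact algorithm: the attention-relation signal and the residual projection are omitted, and the `Ġ`-prefix boundary rule is an assumption about BPE-style tokenizers.

```python
import torch

def group_and_average(tokens, hidden, starts_new_group):
    """Cluster adjacent tokens into groups and average their vectors.

    tokens: list of subword strings
    hidden: (seq_len, dim) tensor of layer-0 hidden states
    starts_new_group: predicate deciding whether a token opens a new group
    """
    group_ids, gid = [], -1
    for tok in tokens:
        if gid < 0 or starts_new_group(tok):
            gid += 1
        group_ids.append(gid)
    ids = torch.tensor(group_ids)
    n_groups = gid + 1
    # scatter-mean: sum vectors per group, then divide by group sizes
    sums = torch.zeros(n_groups, hidden.shape[1]).index_add_(0, ids, hidden)
    counts = torch.zeros(n_groups).index_add_(0, ids, torch.ones(len(tokens)))
    return sums / counts.unsqueeze(1), ids

# Example: tokens where a leading "Ġ" marks a whitespace boundary,
# so "Ġmerg" + "ing" collapse into one group (5 tokens -> 4 groups).
tokens = ["Ġcontext", "Ġmerg", "ing", "Ġsaves", "Ġmemory"]
hidden = torch.arange(10, dtype=torch.float32).reshape(5, 2)
grouped, ids = group_and_average(tokens, hidden, lambda t: t.startswith("Ġ"))
print(grouped.shape)  # torch.Size([4, 2])
```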

---

## Requirements

- Python packages: `torch`, `transformers`, `datasets`, `tqdm`, `psutil`. These are imported directly in the scripts.
- GPU is optional. Scripts detect CUDA and set dtype accordingly.

Install:

```bash
pip install torch transformers datasets tqdm psutil
```

---

## Repository layout

- `prepare_dataset.py` - dataset builder using custom layer 0 grouping.
- `train_custom_qwen3.py` - trainer for grouped-input Qwen3 with an MLP adapter, freezing all but layer 0.
- `inference_qwen3_merged.py` - two-stage inference runner with metrics.

---

## 1 Build the local dataset

Run:

```bash
python prepare_dataset.py
```

Key defaults inside `DatasetProcessor`:

- `model_name="Qwen/Qwen3-0.6B"`
- `dataset_name="Magpie-Align/Magpie-Qwen2.5-Pro-1M-v0.1"`
- `output_dir="./processed_dataset"`
- `batch_size=1`, `max_samples=None`, `save_frequency=1000`

Edit these in the constructor if you need to change them.

The builder formats inputs using a simple system prompt template. It tokenizes, runs layer 0 once per example, captures `grouped_hidden_states`, and buffers results.

**Outputs** under `output_dir`:

- `processed_dataset.pkl` - list of samples with `inputs_embeds` (grouped), `response`, and metadata.
- Additional metadata and sample previews are written alongside for quick inspection.
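
A minimal sketch of what one serialized sample might look like. The field names follow the list above; the 1024-dim hidden size and the exact `metadata` keys are assumptions for illustration.

```python
import pickle
import torch

# Hypothetical sample mirroring the fields described above; the real
# builder also stores extra metadata alongside.
sample = {
    "inputs_embeds": torch.randn(12, 1024),  # grouped prefill embeddings
    "response": "Attention lets each token look at the others.",
    "metadata": {"original_tokens": 20, "grouped_tokens": 12},
}

path = "processed_dataset_demo.pkl"
with open(path, "wb") as f:
    pickle.dump([sample], f)  # the real file holds a list of such samples

with open(path, "rb") as f:
    samples = pickle.load(f)
print(len(samples), samples[0]["inputs_embeds"].shape)
```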

---

## 2 Train the grouped-input model

Run:

```bash
python train_custom_qwen3.py --mode train
```

Training config defaults (edit in the script if needed):

- `model_name="Qwen/Qwen3-0.6B"`
- `dataset_path="./processed_qwen3_dataset/processed_dataset.pkl"`
- `output_dir="./grouped_qwen3_checkpoint"`
- `batch_size=4`, `learning_rate=5e-4`, `num_epochs=3`, `warmup_steps=100`
- Logging, eval, and checkpoint cadence are configurable.

What is trained:

- A **GroupedInputMLPAdapter** that takes grouped embeddings and returns adapted embeddings, normalized with RMSNorm.
- Only layer 0 and this adapter are trainable; everything else is frozen.
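
The freezing scheme can be sketched on a toy module tree. The module names here are stand-ins for the real model's; the actual filter lives in `_freeze_layers`.

```python
from torch import nn

# Toy stand-in for the wrapped model: the real code wraps Qwen3, but the
# freezing logic is the same parameter-name filter.
model = nn.ModuleDict({
    "embed_tokens": nn.Embedding(100, 16),
    "layers": nn.ModuleList([nn.Linear(16, 16) for _ in range(4)]),
    "adapter": nn.Linear(16, 16),  # stands in for GroupedInputMLPAdapter
    "lm_head": nn.Linear(16, 100),
})

# Freeze everything except layer 0 and the adapter.
for name, p in model.named_parameters():
    p.requires_grad = name.startswith(("layers.0.", "adapter."))

trainable = [n for n, p in model.named_parameters() if p.requires_grad]
print(trainable)  # only layers.0.* and adapter.* remain trainable
```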

How targets are computed:

- Prefill: pass `grouped_inputs` via `inputs_embeds` with `is_prefill=True`.
- Then feed target response tokens while reusing `past_key_values`.
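
The prefill-then-decode reuse of the cache can be sketched with a toy single-head attention and an explicit key/value cache. This illustrates the mechanism only, not the repo's code.

```python
import torch

d = 16
Wq, Wk, Wv = (torch.randn(d, d) for _ in range(3))

def attend(x, cache_k, cache_v):
    """One single-head attention pass that reuses cached keys/values."""
    q = x @ Wq
    k = torch.cat([cache_k, x @ Wk])  # cache grows with each call
    v = torch.cat([cache_v, x @ Wv])
    w = torch.softmax(q @ k.T / d ** 0.5, dim=-1)
    return w @ v, k, v

# Prefill: grouped embeddings (G rows instead of N tokens) fill the cache.
grouped = torch.randn(5, d)  # G = 5 groups
_, k_cache, v_cache = attend(grouped, torch.empty(0, d), torch.empty(0, d))

# Decode: each new token attends over the grouped cache plus itself.
tok = torch.randn(1, d)
out, k_cache, v_cache = attend(tok, k_cache, v_cache)
print(out.shape, k_cache.shape)  # cache now holds 6 entries
```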

Checkpoints contain model weights, config, and tokenizer in the epoch folder.

---

## 3 Run inference

### Option A - standalone runner

Quick start:

```bash
python inference_qwen3_merged.py \
  --checkpoint ./grouped_qwen3_checkpoint/epoch_2_best \
  --grouping_model Qwen/Qwen3-0.6B \
  --instruction "Explain attention like I am in 9th grade" \
  --max_length 256 \
  --temperature 0.7 \
  --device cuda
```

CLI options: `--checkpoint`, `--grouping_model`, `--instruction`, `--max_length`, `--temperature`, `--no_sample` for greedy decoding, and `--device` for `cuda` or `cpu`.

What it does:

- Loads a grouping model with the custom layer 0 and a trained inference model.
- Phase 1 groups tokens and reports compression. Phase 2 generates with the trained model.
- Reports compression ratio, memory reduction, total time, and tokens per second.

### Option B - use the training script utilities

The trainer exposes helper functions for loading a trained model and running generation with grouped inputs. See `load_trained_model` and `generate_with_grouped_input` in the training script if you prefer a programmatic flow.

---

## Parameters - quick reference

### Dataset builder

- `model_name` - base HF model for grouping, default `Qwen/Qwen3-0.6B`.
- `dataset_name` - source HF dataset, default `Magpie-Align/Magpie-Qwen2.5-Pro-1M-v0.1`.
- `output_dir` - where pickles and metadata go.
- `max_samples` - optional cap for quick tests.

### Training

- `dataset_path` - path to `processed_dataset.pkl`.
- `output_dir` - where checkpoints are written.
- `batch_size`, `learning_rate`, `num_epochs`, `warmup_steps` - training hyperparameters.
- Only layer 0 and the adapter are trainable. Verify with the `requires_grad` settings in `_freeze_layers`.

### Inference

- `--checkpoint` - path to the trained checkpoint folder.
- `--grouping_model` - HF model name used for grouping.
- `--instruction` - user prompt, any language.
- `--max_length`, `--temperature`, `--no_sample`, `--device`.

---

## Notes

- The custom layer 0 is installed by copying weights from the original layer 0, then replacing the module so it can compute groups and cache the grouped states.
- Grouping relies on simple rules over tokens, such as space and newline boundaries, plus attention relations. You can tune the threshold in `CustomQwen3Attention`.

---

## Troubleshooting

- **CUDA memory spikes**: reduce the batch size during training or use fewer samples. Generation is incremental and reuses past-key-values.
- **No grouped states found**: ensure the custom layer 0 is in use and `is_initialized` is reset before each prefill.
- **Checkpoint not found**: the inference loader expects `pytorch_model.bin` or `model.safetensors` in the checkpoint directory.

---

## Why this can save memory

If the sequence shrinks from `N` tokens to `G` groups, attention memory scales roughly with `G^2` instead of `N^2`. The script prints an estimated savings based on that relation.
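
Under that proxy, the estimated saving is `1 - (G/N)^2`. A minimal sketch (the function name and the example numbers are illustrative, not taken from the script):

```python
def attention_memory_savings(n_tokens: int, n_groups: int) -> float:
    """Estimated fraction of attention memory saved when N tokens
    become G groups, using the quadratic-cost proxy G^2 vs N^2."""
    return 1.0 - (n_groups / n_tokens) ** 2

# e.g. compressing 1000 tokens into 600 groups
print(f"{attention_memory_savings(1000, 600):.0%}")  # 64%
```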

---

## Citation

```bibtex
@misc{Kolomeitsev2025ContextMerging,
  title  = {Context Merging: from Tokens to Entities and Concepts},
  author = {Konstantin Kolomeitsev},
  year   = {2025}
}
```

## Contact

If you have any questions, please open an issue or contact me at [uol92kot@gmail.com](mailto:uol92kot@gmail.com).