---
license: mit
datasets:
- Magpie-Align/Magpie-Qwen2.5-Pro-1M-v0.1
language:
- en
base_model:
- Qwen/Qwen3-0.6B
---
# Context Merging: from Tokens to Entities and Concepts
This repo contains a minimal research pipeline that compresses input context for Qwen3 by grouping dependent subtokens early, then trains a small adapter to consume the grouped embeddings.
- `prepare_dataset.py` builds a local dataset of grouped embeddings from a base Qwen3 with a custom layer 0 that performs token grouping.
- `train_custom_qwen3.py` fine-tunes a customized Qwen3 that adds a small MLP adapter for grouped inputs, while freezing all weights except layer 0 and the adapter.
- `inference_qwen3_merged.py` runs end-to-end inference by first grouping with the base model, then generating with the trained model that understands grouped inputs. It also reports performance metrics and estimated attention-memory savings.
---
## How it works
1. **Layer-0 grouping at prefill**
A custom decoder layer 0 computes attention on the full token sequence, clusters adjacent tokens using lightweight heuristics plus attention relations, then averages token vectors per group. The grouped result is added back to a residual projection and saved as `grouped_hidden_states`.
2. **Dataset building**
The dataset builder swaps in the custom layer 0, feeds formatted prompts, extracts the stored `grouped_hidden_states`, and serializes them together with target responses.
3. **Model training**
The training model wraps Qwen3 with a **GroupedInputMLPAdapter** that processes the grouped embeddings during prefill. Only layer 0 and the adapter are trainable; embeddings, upper layers, final norm, and LM head are frozen. Prefill uses `grouped_inputs` as `inputs_embeds`, then generation proceeds with past-key-values.
4. **Inference**
The inference runner loads two models: a grouping model with the custom layer 0, and your trained model. It reports token compression, timing, and memory usage. Savings are also estimated with a simple attention-cost proxy that scales with sequence length squared.
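
The grouping and averaging described in step 1 can be illustrated with a small, dependency-free sketch. It is not the repository's implementation: the boundary markers (`Ġ`, `Ċ`), the `attn_to_prev` scores, the `threshold` value, and both function names are assumptions chosen for illustration, and the real layer operates on tensors rather than Python lists.

```python
from typing import List, Sequence

# Assumed byte-level BPE markers for a leading space and a newline.
BOUNDARY_MARKERS = ("Ġ", "Ċ")


def assign_groups(
    tokens: Sequence[str],
    attn_to_prev: Sequence[float],
    threshold: float = 0.3,
) -> List[int]:
    """Give each token a group id; adjacent tokens may share one.

    Illustrative rule: a token joins the previous group only if it does
    not start a new word and its attention to the previous token is at
    least `threshold`. Otherwise it opens a new group.
    """
    group_ids: List[int] = []
    current = -1
    for i, tok in enumerate(tokens):
        starts_word = tok.startswith(BOUNDARY_MARKERS)
        attached = i > 0 and not starts_word and attn_to_prev[i] >= threshold
        if not attached:
            current += 1
        group_ids.append(current)
    return group_ids


def mean_pool(
    vectors: Sequence[Sequence[float]],
    group_ids: Sequence[int],
) -> List[List[float]]:
    """Average the vectors that share a group id (ids must be 0, 1, 2, ... in order)."""
    sums: List[List[float]] = []
    counts: List[int] = []
    for vec, gid in zip(vectors, group_ids):
        if gid == len(sums):  # first token of a new group
            sums.append([0.0] * len(vec))
            counts.append(0)
        for d, x in enumerate(vec):
            sums[gid][d] += x
        counts[gid] += 1
    return [[x / n for x in row] for row, n in zip(sums, counts)]


tokens = ["Con", "text", "Ġmer", "ging", "Ċ"]
attn_to_prev = [0.0, 0.5, 0.4, 0.6, 0.9]
group_ids = assign_groups(tokens, attn_to_prev)
print(group_ids)  # [0, 0, 1, 1, 2]: five tokens become three groups
```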
---
## Requirements
- Python packages: `torch`, `transformers`, `datasets`, `tqdm`, `psutil`. These are imported directly in the scripts.
- GPU is optional. Scripts detect CUDA and set dtype accordingly.
Install:
```bash
pip install torch transformers datasets tqdm psutil
```
---
## Repository layout
- `prepare_dataset.py` - dataset builder using custom layer 0 grouping.
- `train_custom_qwen3.py` - trainer for grouped-input Qwen3 with an MLP adapter, freezing all but layer 0 and the adapter.
- `inference_qwen3_merged.py` - two-stage inference runner with metrics.
---
## 1 Build the local dataset
Run:
```bash
python prepare_dataset.py
```
Key defaults inside `DatasetProcessor`:
- `model_name="Qwen/Qwen3-0.6B"`
- `dataset_name="Magpie-Align/Magpie-Qwen2.5-Pro-1M-v0.1"`
- `output_dir="./processed_dataset"`
- `batch_size=1`, `max_samples=None`, `save_frequency=1000`
Edit these in the constructor if you need to change them.
The builder formats inputs using a simple system prompt template.
It tokenizes, runs layer 0 once per example, captures `grouped_hidden_states`, and buffers results.
**Outputs** under `output_dir`:
- `processed_dataset.pkl` - list of samples with `inputs_embeds` (grouped), `response`, and metadata.
- Additional metadata and sample previews are written alongside, for quick inspection.
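
To sanity-check the output before training, a short inspection helper along these lines may be useful. It is a sketch that assumes each sample is a Python `dict` (the keys `inputs_embeds` and `response` come from the description above; `summarize_samples` is a hypothetical name, not part of the scripts).

```python
import pickle
from collections import Counter
from pathlib import Path
from typing import Any, Dict, Union


def summarize_samples(path: Union[str, Path]) -> Dict[str, Any]:
    """Load the pickled sample list and report its size and the keys it uses."""
    # Only unpickle files you produced yourself: pickle can run arbitrary code.
    with Path(path).open("rb") as f:
        samples = pickle.load(f)
    # Count how many samples carry each key, to spot incomplete records.
    key_counts = Counter(key for sample in samples for key in sample)
    return {"num_samples": len(samples), "key_counts": dict(key_counts)}


# Example call, assuming the default output_dir of the dataset builder:
# print(summarize_samples("./processed_dataset/processed_dataset.pkl"))
```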
---
## 2 Train the grouped-input model
Run:
```bash
python train_custom_qwen3.py --mode train
```
Training config defaults (edit in the script if needed):
- `model_name="Qwen/Qwen3-0.6B"`
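- `dataset_path="./processed_qwen3_dataset/processed_dataset.pkl"` (the dataset builder's default `output_dir` is `./processed_dataset`, so make sure the two paths match before training)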
- `output_dir="./grouped_qwen3_checkpoint"`
- `batch_size=4`, `learning_rate=5e-4`, `num_epochs=3`, `warmup_steps=100`
- Logging, eval, and checkpoint cadence are configurable.
What is trained:
- A **GroupedInputMLPAdapter** that takes grouped embeddings and returns adapted embeddings, normalized with RMSNorm.
- Only layer 0 and this adapter are trainable; everything else is frozen.
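
The freezing rule can be expressed as a name-based filter over the model's parameters. The sketch below is an assumption about how such a rule might look, not the contents of `_freeze_layers`: the `model.layers.0.` prefix follows the usual Hugging Face naming for decoder layers, while `grouped_input_adapter.` is a hypothetical attribute name for the adapter.

```python
from typing import Iterable, List, Tuple

TRAINABLE_PREFIXES = (
    "model.layers.0.",         # trailing dot so layers 10, 11, ... do not match
    "grouped_input_adapter.",  # hypothetical name of the adapter submodule
)


def is_trainable(param_name: str) -> bool:
    """Return True for parameters that should keep requires_grad=True."""
    return param_name.startswith(TRAINABLE_PREFIXES)


def split_by_trainability(names: Iterable[str]) -> Tuple[List[str], List[str]]:
    """Partition parameter names into (trainable, frozen) lists."""
    trainable: List[str] = []
    frozen: List[str] = []
    for name in names:
        (trainable if is_trainable(name) else frozen).append(name)
    return trainable, frozen


# With a loaded PyTorch model, the rule could be applied like this:
# for name, param in model.named_parameters():
#     param.requires_grad = is_trainable(name)
```

Printing the trainable list after freezing is a quick way to confirm that nothing outside layer 0 and the adapter receives gradients.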
How targets are computed:
- Prefill: pass `grouped_inputs` via `inputs_embeds` with `is_prefill=True`.
- Then feed target response tokens while reusing `past_key_values`.
Each checkpoint is written to a per-epoch folder that contains the model weights, config, and tokenizer.
---
## 3 Run inference
### Option A - standalone runner
Quick start:
```bash
python inference_qwen3_merged.py \
--checkpoint ./grouped_qwen3_checkpoint/epoch_2_best \
--grouping_model Qwen/Qwen3-0.6B \
--instruction "Explain attention like I am in 9th grade" \
--max_length 256 \
--temperature 0.7 \
--device cuda
```
CLI options: `--checkpoint`, `--grouping_model`, `--instruction`, `--max_length`, `--temperature`, `--no_sample` (greedy decoding instead of sampling), and `--device` (`cuda` or `cpu`).
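
For readers who want to see how these flags fit together, here is a hypothetical `argparse` sketch that mirrors them. Only the flag names come from this README; the types, the defaults (borrowed from the quick-start command), and the mapping of `--no_sample` to `do_sample=False` are assumptions, and the actual script may differ.

```python
import argparse
from typing import Any, Dict, Optional, Sequence


def build_parser() -> argparse.ArgumentParser:
    """Build a parser with the documented flags (defaults are assumed)."""
    parser = argparse.ArgumentParser(description="Two-stage grouped-input inference (sketch)")
    parser.add_argument("--checkpoint", required=True, help="trained checkpoint folder")
    parser.add_argument("--grouping_model", default="Qwen/Qwen3-0.6B")
    parser.add_argument("--instruction", required=True, help="user prompt")
    parser.add_argument("--max_length", type=int, default=256)
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--no_sample", action="store_true", help="use greedy decoding")
    parser.add_argument("--device", choices=["cuda", "cpu"], default="cpu")
    return parser


def sampling_kwargs(args: argparse.Namespace) -> Dict[str, Any]:
    """Translate the flags into generation settings; temperature only matters when sampling."""
    if args.no_sample:
        return {"do_sample": False}
    return {"do_sample": True, "temperature": args.temperature}


def parse(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse an explicit argument list, or sys.argv when argv is None."""
    return build_parser().parse_args(argv)
```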
What it does:
- Loads a grouping model with the custom layer 0 and a trained inference model.
- Phase 1 groups tokens and reports compression. Phase 2 generates with the trained model.
- Reports compression ratio, memory reduction, total time, and tokens per second.
### Option B - use the training script utilities
The trainer exposes helper functions for loading a trained model and running generation with grouped inputs. See `load_trained_model` and `generate_with_grouped_input` in the training script if you prefer a programmatic flow.
---
## Parameters - quick reference
### Dataset builder
- `model_name` - base HF model for grouping, default Qwen/Qwen3-0.6B.
- `dataset_name` - source HF dataset split, default Magpie-Align... Qwen2.5-Pro-1M.
- `output_dir` - where pickles and metadata go.
- `max_samples` - optional cap for quick tests.
### Training
- `dataset_path` - path to `processed_dataset.pkl`.
- `output_dir` - where checkpoints are written.
- `batch_size`, `learning_rate`, `num_epochs`, `warmup_steps` - training hyperparameters.
- Only layer 0 and the adapter are trainable. Verify with `requires_grad` settings in `_freeze_layers`.
### Inference
- `--checkpoint` - path to trained checkpoint folder.
- `--grouping_model` - HF model name used for grouping.
- `--instruction` - user prompt, any language.
- `--max_length`, `--temperature`, `--no_sample`, `--device` - length limit, sampling temperature, greedy decoding switch, and device (`cuda` or `cpu`).
---
## Notes
- The custom layer 0 is installed by copying weights from the original layer 0, then replacing the module so it can compute groups and cache the grouped states.
- Grouping relies on simple rules over tokens like space and newline boundaries plus attention relations. You can tune the threshold in `CustomQwen3Attention`.
---
## Troubleshooting
- **CUDA memory spikes**: reduce batch size during training or use fewer samples. Generation is incremental and reuses past-key-values.
- **No grouped states found**: ensure the custom layer 0 is used and `is_initialized` is reset before each prefill.
- **Checkpoint not found**: the inference loader expects `pytorch_model.bin` or `model.safetensors` in the checkpoint directory.
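
A quick pre-flight check for the last item might look like the sketch below. The two file names come from the note above; the helper name `find_weights` and the preference for `model.safetensors` when both files exist are assumptions.

```python
from pathlib import Path
from typing import Optional, Union

# File names the inference loader is described as accepting.
WEIGHT_FILES = ("model.safetensors", "pytorch_model.bin")


def find_weights(checkpoint_dir: Union[str, Path]) -> Optional[Path]:
    """Return the first weight file found in the checkpoint folder, else None."""
    root = Path(checkpoint_dir)
    for name in WEIGHT_FILES:
        candidate = root / name
        if candidate.is_file():
            return candidate
    return None


# Example: fail early with a readable message instead of a loader traceback.
# if find_weights("./grouped_qwen3_checkpoint/epoch_2_best") is None:
#     raise SystemExit("No weight file found in the checkpoint directory")
```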
---
## Why this can save memory
If the sequence shrinks from `N` tokens to `G` groups, attention memory scales roughly with `G^2` instead of `N^2`, which corresponds to a relative reduction of about `1 - (G/N)^2`. The script prints an estimated savings based on that relation.
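
As a worked example of this proxy, the sketch below computes the relative reduction under the quadratic assumption. The function name is hypothetical and the exact formula used by the script may differ; this only illustrates the `G^2` versus `N^2` relation.

```python
def estimated_attention_savings(num_tokens: int, num_groups: int) -> float:
    """Fraction of attention memory saved if cost scales with length squared."""
    if num_tokens <= 0 or not 0 < num_groups <= num_tokens:
        raise ValueError("expected 0 < num_groups <= num_tokens")
    return 1.0 - (num_groups / num_tokens) ** 2


# Halving the sequence length removes three quarters of the quadratic cost.
print(estimated_attention_savings(1000, 500))  # 0.75
```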
---
## Citation
```bibtex
@misc{Kolomeitsev2025ContextMerging,
  title = {Context Merging: from Tokens to Entities and Concepts},
  author = {Konstantin Kolomeitsev},
  year = {2025}
}
```
## Contact
If you have any questions, please open an issue or contact me at [uol92kot@gmail.com](mailto:uol92kot@gmail.com).