---
license: mit
datasets:
- Magpie-Align/Magpie-Qwen2.5-Pro-1M-v0.1
language:
- en
base_model:
- Qwen/Qwen3-0.6B
---

# Context Merging: from Tokens to Entities and Concepts

This repo contains a minimal research pipeline that compresses input context for Qwen3 by grouping dependent subtokens early, then trains a small adapter to consume the grouped embeddings.

- `prepare_dataset.py` builds a local dataset of grouped embeddings from a base Qwen3 with a custom layer 0 that performs token grouping.
- `train_custom_qwen3.py` fine-tunes a customized Qwen3 that adds a small MLP adapter for grouped inputs, while freezing all weights except layer 0.
- `inference_qwen3_merged.py` runs end-to-end inference by first grouping with the base model, then generating with the trained model that understands grouped inputs. Includes performance metrics and estimated attention-memory savings.

---

## How it works

1. **Layer-0 grouping at prefill**
   A custom decoder layer 0 computes attention on the full token sequence, clusters adjacent tokens using lightweight heuristics plus attention relations, then averages token vectors per group. The grouped result is added back to a residual projection and saved as `grouped_hidden_states`.
2. **Dataset building**
   The dataset builder swaps in the custom layer 0, feeds formatted prompts, extracts the stored `grouped_hidden_states`, and serializes them together with target responses.
3. **Model training**
   The training model wraps Qwen3 with a **GroupedInputMLPAdapter** that processes the grouped embeddings during prefill. Only layer 0 and the adapter are trainable; embeddings, upper layers, final norm, and LM head are frozen. Prefill uses `grouped_inputs` as `inputs_embeds`, then generation proceeds with past-key-values.
4. **Inference**
   The inference runner loads two models: a grouping model with the custom layer 0, and your trained model. It reports token compression, timing, and memory usage. Savings are also estimated with a simple attention-cost proxy that scales with sequence length squared.

---

## Requirements

- Python packages: `torch`, `transformers`, `datasets`, `tqdm`, `psutil`. These are imported directly in the scripts.
- GPU is optional. Scripts detect CUDA and set dtype accordingly.

Install:

```bash
pip install torch transformers datasets tqdm psutil
```

---

## Repository layout

- `prepare_dataset.py` - dataset builder using custom layer 0 grouping.
- `train_custom_qwen3.py` - trainer for grouped-input Qwen3 with an MLP adapter, freezing all but layer 0.
- `inference_qwen3_merged.py` - two-stage inference runner with metrics.

---

## 1 Build the local dataset

Run:

```bash
python prepare_dataset.py
```

Key defaults inside `DatasetProcessor`:

- `model_name="Qwen/Qwen3-0.6B"`
- `dataset_name="Magpie-Align/Magpie-Qwen2.5-Pro-1M-v0.1"`
- `output_dir="./processed_dataset"`
- `batch_size=1`, `max_samples=None`, `save_frequency=1000`

Edit these in the constructor if you need to change them.

The builder formats inputs using a simple system prompt template. It tokenizes, runs layer 0 once per example, captures `grouped_hidden_states`, and buffers results.

**Outputs** under `output_dir`:

- `processed_dataset.pkl` - list of samples with `inputs_embeds` (grouped), `response`, and metadata; see the loading sketch below.
- Additional metadata and sample previews are written alongside, for quick inspection.
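Before training, it can help to sanity-check the pickle. The sketch below assumes the list-of-samples layout described above; adjust the path if you changed `output_dir`, and verify the exact record structure against your own build.

```python
# Minimal sanity check of the builder output (sketch, not part of the pipeline).
# Field names follow the "Outputs" description above.
import pickle

with open("./processed_dataset/processed_dataset.pkl", "rb") as f:
    samples = pickle.load(f)  # expected: a list of per-example records

print(f"loaded {len(samples)} samples")
first = samples[0]
if isinstance(first, dict):
    # e.g. 'inputs_embeds' (grouped embeddings), 'response', plus metadata
    print({key: type(value).__name__ for key, value in first.items()})
else:
    print(type(first))
```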
---

## 2 Train the grouped-input model

Run:

```bash
python train_custom_qwen3.py --mode train
```

Training config defaults (edit in the script if needed):

- `model_name="Qwen/Qwen3-0.6B"`
- `dataset_path="./processed_qwen3_dataset/processed_dataset.pkl"`
- `output_dir="./grouped_qwen3_checkpoint"`
- `batch_size=4`, `learning_rate=5e-4`, `num_epochs=3`, `warmup_steps=100`
- Logging, eval, and checkpoint cadence are configurable.

What is trained:

- A **GroupedInputMLPAdapter** that takes grouped embeddings and returns adapted embeddings, normalized with RMSNorm.
- Only layer 0 and this adapter are trainable; everything else is frozen.

How targets are computed:

- Prefill: pass `grouped_inputs` via `inputs_embeds` with `is_prefill=True`.
- Then feed target response tokens while reusing `past_key_values`.

Checkpoints contain model weights, config, and tokenizer in the epoch folder.

---

## 3 Run inference

### Option A - standalone runner

Quick start:

```bash
python inference_qwen3_merged.py \
  --checkpoint ./grouped_qwen3_checkpoint/epoch_2_best \
  --grouping_model Qwen/Qwen3-0.6B \
  --instruction "Explain attention like I am in 9th grade" \
  --max_length 256 \
  --temperature 0.7 \
  --device cuda
```

CLI options: `--checkpoint`, `--grouping_model`, `--instruction`, `--max_length`, `--temperature`, `--no_sample` for greedy decoding, and `--device` for cuda or cpu.

What it does:

- Loads a grouping model with the custom layer 0 and a trained inference model.
- Phase 1 groups tokens and reports compression. Phase 2 generates with the trained model.
- Reports compression ratio, memory reduction, total time, and tokens per second.

### Option B - use the training script utilities

The trainer exposes helper functions for loading a trained model and running generation with grouped inputs. See `load_trained_model` and `generate_with_grouped_input` in the training script if you prefer a programmatic flow.

---

## Parameters - quick reference

### Dataset builder

- `model_name` - base HF model for grouping, default `Qwen/Qwen3-0.6B`.
- `dataset_name` - source HF dataset, default `Magpie-Align/Magpie-Qwen2.5-Pro-1M-v0.1`.
- `output_dir` - where pickles and metadata go.
- `max_samples` - optional cap for quick tests.

### Training

- `dataset_path` - path to `processed_dataset.pkl`.
- `output_dir` - where checkpoints are written.
- `batch_size`, `learning_rate`, `num_epochs`, `warmup_steps` - training hyperparameters.
- Only layer 0 and the adapter are trainable. Verify with the `requires_grad` settings in `_freeze_layers`.

### Inference

- `--checkpoint` - path to the trained checkpoint folder.
- `--grouping_model` - HF model name used for grouping.
- `--instruction` - user prompt, any language.
- `--max_length`, `--temperature`, `--no_sample`, `--device`.

---

## Notes

- The custom layer 0 is installed by copying weights from the original layer 0, then replacing the module so it can compute groups and cache the grouped states.
- Grouping relies on simple rules over tokens, such as space and newline boundaries, plus attention relations. You can tune the threshold in `CustomQwen3Attention`.

---

## Troubleshooting

- **CUDA memory spikes**: reduce the batch size during training or use fewer samples. Generation is incremental and reuses past-key-values.
- **No grouped states found**: ensure the custom layer 0 is used and `is_initialized` is reset before each prefill.
- **Checkpoint not found**: the inference loader expects `pytorch_model.bin` or `model.safetensors` in the checkpoint directory (see the sketch after this list).
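If the loader complains about a missing checkpoint, a quick look at the folder contents confirms whether one of the expected weight files is present. This is a minimal sketch; the path reuses the example checkpoint from Option A and may differ on your machine.

```python
# Check that a checkpoint folder contains a weight file the inference loader
# can pick up (pytorch_model.bin or model.safetensors).
from pathlib import Path

ckpt = Path("./grouped_qwen3_checkpoint/epoch_2_best")  # adjust to your checkpoint
expected = ("pytorch_model.bin", "model.safetensors")
found = [name for name in expected if (ckpt / name).exists()]
print("weight files found:", found if found else "none - re-check the checkpoint path")
```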
---

## Why this can save memory

If the sequence shrinks from `N` tokens to `G` groups, attention memory scales roughly with `G^2` instead of `N^2`. For example, shrinking 1,000 tokens to 400 groups cuts the quadratic cost to `(400/1000)^2 = 0.16` of the original, an estimated 84% saving. The script prints an estimated savings based on that relation.

---

## Citation

```
@misc{Kolomeitsev2025ContextMerging,
  title  = {Context Merging: from Tokens to Entities and Concepts},
  author = {Konstantin Kolomeitsev},
  year   = {2025}
}
```

## Contact

If you have any questions, please open an issue or contact me at [uol92kot@gmail.com](mailto:uol92kot@gmail.com).