--- |
|
|
license: mit |
|
|
language: |
|
|
- en |
|
|
pipeline_tag: text-generation |
|
|
tags: |
|
|
- bitnet |
|
|
- quantization |
|
|
- early-exit |
|
|
- layer-skipping |
|
|
- efficient-transformers |
|
|
datasets: |
|
|
- roneneldan/TinyStories |
|
|
--- |
|
|
|
|
|
# bitskip-v2-earlyexit |
|
|
|
|
|
BitSkip v2: a BitNet-style transformer with ternary weights, 4-bit activation quantization, a Hadamard transform on activations, and early-exit training.
|
|
|
|
|
## Model Description |
|
|
|
|
|
This model is a 24-layer transformer trained with early-exit loss and quadratic layer dropout for efficient inference. It was trained on the TinyStories dataset with layer-wise auxiliary supervision to enable flexible speed-quality tradeoffs at inference time.
|
|
|
|
|
## Architecture Details |
|
|
|
|
|
- **Layers**: 24 |
|
|
- **Hidden dimension**: 2048 |
|
|
- **Attention heads**: 32 (64-dimensional each) |
|
|
- **Key-Value heads**: 8 (Grouped Query Attention with 4:1 ratio) |
|
|
- **FFN intermediate size**: 4096 |
|
|
- **Position embeddings**: Rotary Position Embeddings (RoPE) |
|
|
- **Normalization**: RMSNorm |
|
|
- **Activation**: SwiGLU (for MLP) |
|
|
- **Parameters**: ~1.06B |
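
As a sanity check on the parameter count, here is a rough estimate from the sizes above (a back-of-the-envelope sketch that assumes an untied LM head and ignores RMSNorm parameters):

```python
# Rough parameter-count estimate from the architecture sizes above.
# Assumption: untied LM head; RMSNorm/bias parameters ignored.
vocab, d_model, d_ff, n_layers = 50_257, 2048, 4096, 24
n_heads, n_kv_heads, head_dim = 32, 8, 64

attn = (
    d_model * n_heads * head_dim            # W_q
    + 2 * d_model * n_kv_heads * head_dim   # W_k, W_v (GQA)
    + n_heads * head_dim * d_model          # W_o
)
ffn = 3 * d_model * d_ff                    # SwiGLU: gate, up, down projections
embeddings = 2 * vocab * d_model            # input embeddings + untied LM head

total = n_layers * (attn + ffn) + embeddings
print(f"~{total / 1e9:.2f}B parameters")    # ~1.06B
```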
|
|
|
|
|
### Quantization Scheme |
|
|
|
|
|
- **Weights**: Ternary {-1, 0, 1} |
|
|
- **Activations**: 4-bit quantization (post-Hadamard) |
|
|
- **Hadamard transform**: fast Walsh-Hadamard transform (FWHT) applied to activations before quantization (sketched below)
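
A minimal sketch of this quantization path (absmean-style ternary weights, an orthonormal FWHT on activations, then symmetric 4-bit fake quantization). The helper names and the per-token scaling are illustrative assumptions, not the exact training code:

```python
import torch

def quantize_weights_ternary(w: torch.Tensor) -> torch.Tensor:
    """Map weights to {-1, 0, +1} scaled by their mean absolute value (absmean-style)."""
    scale = w.abs().mean().clamp(min=1e-5)
    return torch.clamp((w / scale).round(), -1, 1) * scale

def hadamard_transform(x: torch.Tensor) -> torch.Tensor:
    """Fast Walsh-Hadamard transform over the last dimension (must be a power of 2)."""
    orig_shape = x.shape
    n = orig_shape[-1]
    x = x.reshape(-1, n)
    h = 1
    while h < n:
        x = x.reshape(-1, n // (2 * h), 2, h)
        a, b = x[:, :, 0, :], x[:, :, 1, :]
        x = torch.stack((a + b, a - b), dim=2).reshape(-1, n)
        h *= 2
    return (x / n ** 0.5).reshape(orig_shape)

def quantize_activations_4bit(x: torch.Tensor) -> torch.Tensor:
    """Symmetric per-token 4-bit fake quantization applied after the Hadamard transform."""
    x = hadamard_transform(x)
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-5) / 7  # int4 range [-8, 7]
    return (x / scale).round().clamp(-8, 7) * scale
```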
|
|
|
|
|
## Training Details |
|
|
|
|
|
### Dataset |
|
|
- **Source**: TinyStories (2.1M stories) |
|
|
- **Tokenizer**: GPT-2 BPE (vocab size: 50,257) |
|
|
- **Sequence length**: 512 tokens |
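
A minimal data-preparation sketch consistent with the settings above (GPT-2 BPE, 512-token sequences); packing stories into fixed-length blocks is an assumption:

```python
from datasets import load_dataset
from transformers import AutoTokenizer

SEQ_LEN = 512

dataset = load_dataset("roneneldan/TinyStories", split="train")
tokenizer = AutoTokenizer.from_pretrained("gpt2")  # GPT-2 BPE, vocab size 50,257

def tokenize_and_chunk(batch):
    # Tokenize each story, append EOS, then pack into fixed 512-token blocks.
    ids = []
    for text in batch["text"]:
        ids.extend(tokenizer(text)["input_ids"] + [tokenizer.eos_token_id])
    usable = (len(ids) // SEQ_LEN) * SEQ_LEN
    return {"input_ids": [ids[i:i + SEQ_LEN] for i in range(0, usable, SEQ_LEN)]}

tokenized = dataset.map(
    tokenize_and_chunk, batched=True, remove_columns=dataset.column_names
)
```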
|
|
|
|
|
### Training Techniques |
|
|
|
|
|
**Quadratic Layer Dropout:** |
|
|
- Per-layer dropout probability grows quadratically with depth: p_l = 0.5 × (l/L)²


- Probabilities are rescaled so that Σ p_l = 1.0


- The final layer is never dropped


- Trains earlier layers to produce representations that remain useful when later layers are skipped (sketched below)
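
A minimal sketch of this schedule (the rescaling details and the per-layer skip decision are illustrative assumptions):

```python
import torch

def layer_dropout_probs(num_layers: int = 24, base: float = 0.5) -> list[float]:
    """Quadratic skip probabilities p_l = base * (l/L)^2, final layer pinned to 0, rescaled to sum to 1."""
    raw = [base * ((l + 1) / num_layers) ** 2 for l in range(num_layers)]
    raw[-1] = 0.0                       # never drop the final layer
    total = sum(raw)
    return [p / total for p in raw]

def maybe_skip(layer, hidden_states, p_skip: float, training: bool):
    """During training, replace the layer with the identity with probability p_skip."""
    if training and torch.rand(()).item() < p_skip:
        return hidden_states
    return layer(hidden_states)
```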
|
|
|
|
|
**Early Exit Loss:** |
|
|
- All layers share a single LM head, so any layer's hidden states can be decoded into logits


- Total loss = main_loss + 0.3 × early_exit_loss, where early_exit_loss sums per-layer terms


- Each layer's term is weighted proportionally to depth: w_i = (i+1)/L


- Enables flexible early exit at inference time (see the sketch below)
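
A minimal sketch of this objective, assuming per-layer cross-entropy through the shared LM head (names are illustrative):

```python
import torch.nn.functional as F

def combined_loss(layer_hiddens, lm_head, labels, aux_weight: float = 0.3):
    """main_loss (last layer) + aux_weight * sum_i w_i * loss_i, with w_i = (i+1)/L.

    layer_hiddens: list of L tensors [batch, seq, hidden], one per layer.
    lm_head: the single output projection shared by all layers.
    """
    L = len(layer_hiddens)

    def lm_loss(hidden):
        logits = lm_head(hidden)
        # Shift so position t predicts token t+1.
        return F.cross_entropy(
            logits[:, :-1].reshape(-1, logits.size(-1)),
            labels[:, 1:].reshape(-1),
        )

    main_loss = lm_loss(layer_hiddens[-1])
    early_exit_loss = sum(((i + 1) / L) * lm_loss(h) for i, h in enumerate(layer_hiddens))
    return main_loss + aux_weight * early_exit_loss
```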
|
|
|
|
|
### Hyperparameters |
|
|
|
|
|
- **Optimizer**: AdamW |
|
|
- **Learning rate**: 3e-4 |
|
|
- **Warmup steps**: 4000 |
|
|
- **Batch size**: 16 (effective: 64) |
|
|
- **Training steps**: 50000 |
|
|
- **Gradient clipping**: 0.5 |
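
A minimal training-loop sketch of these settings; the gradient-accumulation factor of 4 (16 × 4 = 64 effective) and the linear warmup/decay schedule are assumptions, and `model` / `train_loader` are presumed to be constructed elsewhere:

```python
import torch
from transformers import get_linear_schedule_with_warmup

ACCUM_STEPS = 4       # assumed: batch 16 x 4 accumulation steps = effective batch 64
TOTAL_STEPS = 50_000

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=4_000, num_training_steps=TOTAL_STEPS
)

for step, batch in enumerate(train_loader):
    loss = model(**batch).loss / ACCUM_STEPS
    loss.backward()
    if (step + 1) % ACCUM_STEPS == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
```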
|
|
|
|
|
## Performance |
|
|
|
|
|
### Perplexity (TinyStories validation) |
|
|
|
|
|
| Exit Layer | Perplexity | Speed (tok/s) | |
|
|
|------------|------------|---------------| |
|
|
| All layers | TBD | TBD | |
|
|
| Layer 18 | TBD | TBD | |
|
|
| Layer 12 | TBD | TBD | |
|
|
| Layer 6 | TBD | TBD | |
|
|
|
|
|
### Training Stability |
|
|
|
|
|
- **Gradient norms**: 50-110 |
|
|
- **Final loss**: TBD |
|
|
|
|
|
## Usage |
|
|
|
|
|
### Installation |
|
|
|
|
|
```bash |
|
|
pip install transformers torch |
|
|
``` |
|
|
|
|
|
### Basic Inference |
|
|
|
|
|
```python |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
|
|
# Load model |
|
|
model = AutoModelForCausalLM.from_pretrained("your-username/bitskip-v2-earlyexit") |
|
|
tokenizer = AutoTokenizer.from_pretrained("your-username/bitskip-v2-earlyexit") |
|
|
|
|
|
# Generate text |
|
|
inputs = tokenizer("Once upon a time", return_tensors="pt") |
|
|
outputs = model.generate(**inputs, max_length=100) |
|
|
print(tokenizer.decode(outputs[0])) |
|
|
``` |
|
|
|
|
|
### Early Exit Inference |
|
|
|
|
|
```python |
|
|
# Exit at layer 12 for faster inference |
|
|
model.set_exit_layer(12) |
|
|
outputs = model.generate(**inputs, max_length=100) |
|
|
# Roughly 1.5-2x faster; quality impact depends on the exit layer (see the perplexity table above)
|
|
``` |
|
|
|
|
|
### Benchmark Different Exit Layers |
|
|
|
|
|
```python |
|
|
for exit_layer in [6, 12, 18, 24]: |
|
|
model.set_exit_layer(exit_layer) |
|
|
outputs = model.generate(**inputs, max_length=100) |
|
|
print(f"Layer {exit_layer}: {tokenizer.decode(outputs[0])}") |
|
|
``` |
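
The perplexity/speed table above can be filled in with an evaluation loop along these lines (a sketch that assumes the `set_exit_layer` API shown above, a standard HF-style `.loss` output, and a `validation_texts` iterable):

```python
import math
import time
import torch

@torch.no_grad()
def evaluate_exit_layer(model, tokenizer, texts, exit_layer):
    """Perplexity and rough throughput for a given exit layer (illustrative)."""
    model.set_exit_layer(exit_layer)
    model.eval()
    total_nll, total_tokens, start = 0.0, 0, time.time()
    for text in texts:
        enc = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        out = model(input_ids=enc["input_ids"], labels=enc["input_ids"])
        n_tokens = enc["input_ids"].numel() - 1   # tokens that receive a prediction
        total_nll += out.loss.item() * n_tokens
        total_tokens += n_tokens
    ppl = math.exp(total_nll / total_tokens)
    return ppl, total_tokens / (time.time() - start)

for exit_layer in [6, 12, 18, 24]:
    ppl, speed = evaluate_exit_layer(model, tokenizer, validation_texts, exit_layer)
    print(f"Layer {exit_layer}: perplexity {ppl:.2f}, {speed:.0f} tok/s")
```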
|
|
|
|
|
## Limitations |
|
|
|
|
|
- **Inference speed**: Quantized models use fake quantization (QAT) without specialized kernels, resulting in slower inference than full-precision despite lower bit-width |
|
|
- **Training instability**: 4-bit models (v2) exhibit gradient explosion (norms 50-110) requiring careful hyperparameter tuning |
|
|
- **Dataset scope**: Trained only on TinyStories; may not generalize to other domains without fine-tuning |
|
|
|
|
|
## Citation |
|
|
|
|
|
If you use this model, please cite: |
|
|
|
|
|
```bibtex |
|
|
@article{bitnet, |
|
|
title={BitNet: Scaling 1-bit Transformers for Large Language Models}, |
|
|
author={Wang, Hongyu and Ma, Shuming and Dong, Li and others}, |
|
|
journal={arXiv preprint arXiv:2310.11453}, |
|
|
year={2023} |
|
|
} |
|
|
|
|
|
@article{layerskip, |
|
|
title={LayerSkip: Enabling Early Exit Inference and Self-Speculative Decoding}, |
|
|
author={Elhoushi, Mostafa and Shrivastava, Akshat and Liskovich, Diana and others}, |
|
|
journal={arXiv preprint arXiv:2404.16710}, |
|
|
year={2024} |
|
|
} |
|
|
``` |
|
|
|
|
|
## License |
|
|
|
|
|
MIT License |
|
|
|
|
|
## Contact |
|
|
|
|
|
For questions or issues, please open an issue on the model repository. |
|
|
|