---
license: mit
language:
- en
pipeline_tag: text-generation
tags:
- bitnet
- quantization
- early-exit
- layer-skipping
- efficient-transformers
datasets:
- roneneldan/TinyStories
---
# bitskip-v3-earlyexit
BitSkip v3 with 8-bit activation quantization, ternary weights, and Hadamard transform
## Model Description
This model is a 24-layer transformer trained on the TinyStories dataset with an early-exit loss and quadratic layer dropout. The early-exit loss supervises every layer's output, so generation can stop at an intermediate layer, trading quality for speed at inference time.
## Architecture Details
- **Layers**: 24
- **Hidden dimension**: 2048
- **Attention heads**: 32 (64-dimensional each)
- **Key-Value heads**: 8 (Grouped Query Attention with 4:1 ratio)
- **FFN intermediate size**: 4096
- **Position embeddings**: Rotary Position Embeddings (RoPE)
- **Normalization**: RMSNorm
- **Activation**: SwiGLU (for MLP)
- **Parameters**: ~1.06B
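
As a sanity check on the ~1.06B figure, the parameter count can be tallied from the dimensions above. This sketch assumes no bias terms, one RMSNorm weight vector per norm, and an input embedding that is *not* tied to the LM head; none of these details are stated in the card.

```python
# Rough parameter count from the listed dimensions.
# Assumptions: no biases, untied embedding and LM head, one weight vector per RMSNorm.
# The Hadamard transform has no learnable parameters.
VOCAB = 50_257
HIDDEN = 2048
N_LAYERS = 24
N_HEADS = 32
N_KV_HEADS = 8
FFN = 4096

head_dim = HIDDEN // N_HEADS        # 64
kv_dim = N_KV_HEADS * head_dim      # 512: 8 K/V heads shared by 32 query heads (GQA)

attn = 2 * HIDDEN * HIDDEN + 2 * HIDDEN * kv_dim   # q_proj + o_proj, k_proj + v_proj
mlp = 3 * HIDDEN * FFN                              # SwiGLU: gate, up and down projections
norms = 2 * HIDDEN                                  # pre-attention and pre-MLP RMSNorm
per_layer = attn + mlp + norms

embed = VOCAB * HIDDEN
lm_head = VOCAB * HIDDEN
total = embed + N_LAYERS * per_layer + HIDDEN + lm_head  # HIDDEN = final RMSNorm

print(f"per layer: {per_layer:,}")
print(f"total:     {total:,} (~{total / 1e9:.2f}B)")
```

Under these assumptions the total is about 1.06B, of which the 24 blocks account for roughly 0.86B; tying the embedding and LM head would bring it down to about 0.96B.
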
### Quantization Scheme
- **Weights**: Ternary {-1, 0, 1}
- **Activations**: 8-bit quantization (post-Hadamard)
- **Hadamard**: Yes (FWHT)
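
A minimal PyTorch sketch of how these three pieces can fit into one quantization-aware linear layer. It assumes BitNet b1.58-style absmean ternarization with one scale per weight tensor, per-token absmax 8-bit activations, a straight-through estimator for training, and an orthonormal fast Walsh-Hadamard transform applied to every linear layer's input. The card does not specify scale granularity or where the transform is placed, and `HadamardTernaryLinear` is a hypothetical name.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

EPS = 1e-5

def fwht(x: torch.Tensor) -> torch.Tensor:
    """Orthonormal fast Walsh-Hadamard transform over the last dim (power-of-two size)."""
    n = x.shape[-1]
    assert n > 0 and (n & (n - 1)) == 0, "last dimension must be a power of two"
    shape = x.shape
    x = x.reshape(-1, n)
    h = 1
    while h < n:
        # Butterfly on pairs of blocks of width h: (a, b) -> (a + b, a - b)
        x = x.reshape(-1, n // (2 * h), 2, h)
        a, b = x[:, :, 0, :], x[:, :, 1, :]
        x = torch.stack((a + b, a - b), dim=2)
        h *= 2
    return x.reshape(shape) / math.sqrt(n)

def ternary_quant(w: torch.Tensor):
    """Absmean ternarization (assumed): w ~ scale * w_q with w_q in {-1, 0, 1}."""
    scale = w.abs().mean().clamp(min=EPS)
    w_q = (w / scale).round().clamp(-1, 1)
    return w_q, scale

def int8_quant(x: torch.Tensor):
    """Per-token absmax quantization to the signed 8-bit range (assumed granularity)."""
    scale = 127.0 / x.abs().amax(dim=-1, keepdim=True).clamp(min=EPS)
    x_q = (x * scale).round().clamp(-128, 127)
    return x_q, scale

class HadamardTernaryLinear(nn.Linear):
    """Hypothetical QAT layer: FWHT -> 8-bit activations, ternary weights."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = fwht(x)  # rotate activations so outliers are spread across channels
        x_q, x_scale = int8_quant(x)
        w_q, w_scale = ternary_quant(self.weight)
        # Fake quantization with a straight-through estimator: the forward pass uses
        # quantized values, the backward pass treats quantization as the identity.
        x_fq = x + (x_q / x_scale - x).detach()
        w_fq = self.weight + (w_q * w_scale - self.weight).detach()
        return F.linear(x_fq, w_fq, self.bias)

layer = HadamardTernaryLinear(2048, 512, bias=False)  # e.g. a K/V projection shape
y = layer(torch.randn(4, 2048))
```

Because the weights are trained from the start on rotated inputs, the layer simply learns in the Hadamard basis; no inverse transform is needed. Both 2048 and 4096 are powers of two, so the FWHT applies directly to every input width in this architecture.
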
## Training Details
### Dataset
- **Source**: TinyStories (2.1M stories)
- **Tokenizer**: GPT-2 BPE (vocab size: 50,257)
- **Sequence length**: 512 tokens
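
The card does not say how stories were turned into 512-token training sequences. One common approach, sketched here purely as an assumption, is to concatenate the GPT-2-tokenized stories with the `<|endoftext|>` token (id 50256) between them and cut the stream into non-overlapping chunks. `pack_sequences` is a hypothetical helper, and the toy token lists stand in for real tokenizer output.

```python
from typing import Iterable, Iterator, List

EOS_ID = 50256  # GPT-2 <|endoftext|> token id
SEQ_LEN = 512

def pack_sequences(token_lists: Iterable[List[int]], seq_len: int = SEQ_LEN) -> Iterator[List[int]]:
    """Concatenate stories separated by EOS and yield non-overlapping seq_len chunks."""
    buffer: List[int] = []
    for tokens in token_lists:
        buffer.extend(tokens)
        buffer.append(EOS_ID)  # mark the story boundary
        while len(buffer) >= seq_len:
            yield buffer[:seq_len]
            buffer = buffer[seq_len:]
    # A trailing partial chunk (fewer than seq_len tokens) is dropped.

# Toy "tokenized stories" standing in for GPT-2 tokenizer output
toy_stories = [[1] * 300, [2] * 300, [3] * 300]
chunks = list(pack_sequences(toy_stories, seq_len=256))
```

In practice `token_lists` would come from the GPT-2 tokenizer (for example `AutoTokenizer.from_pretrained("gpt2")`). Packing avoids padding waste, at the cost of letting attention cross story boundaries unless a boundary-aware mask is used.
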
### Training Techniques
**Quadratic Layer Dropout:**
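- Progressive dropout: p_l = 0.5 × (l/L)², so deeper layers are skipped more often during training
- Normalized so Σp_l = 1.0
- Never drops the final layer
- Reduces reliance on the deepest layers, which makes earlier layers' predictions more accurate and therefore better early-exit points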
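
A minimal sketch of the schedule and of skipping layers during training. It assumes layers are indexed l = 1..L, that each layer includes its own residual connection (so a skipped layer acts as the identity), and that one skip decision is drawn per layer per batch. It implements the formula as written; the Σp_l = 1.0 normalization is not reproduced because its exact form is not given. `quadratic_dropout_probs` and `LayerDropStack` are hypothetical names.

```python
from typing import List

import torch
import torch.nn as nn

def quadratic_dropout_probs(num_layers: int, p_max: float = 0.5) -> List[float]:
    """p_l = p_max * (l / L)^2 for l = 1..L (assumed indexing); final layer never dropped."""
    probs = [p_max * (l / num_layers) ** 2 for l in range(1, num_layers + 1)]
    probs[-1] = 0.0  # never drop the final layer
    return probs

class LayerDropStack(nn.Module):
    """Hypothetical wrapper that randomly skips whole layers in training mode."""

    def __init__(self, layers: nn.ModuleList, p_max: float = 0.5):
        super().__init__()
        self.layers = layers
        self.drop_probs = quadratic_dropout_probs(len(layers), p_max)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        for layer, p in zip(self.layers, self.drop_probs):
            # One skip decision per layer per batch (assumption); eval mode runs every layer.
            if self.training and torch.rand(()).item() < p:
                continue  # assumes each layer adds its own residual, so skipping = identity
            h = layer(h)
        return h

probs = quadratic_dropout_probs(24)
# Toy layers standing in for transformer blocks
stack = LayerDropStack(nn.ModuleList([nn.Linear(16, 16) for _ in range(24)]))
stack.train()
out = stack(torch.randn(2, 16))
```

With L = 24 the drop probability rises from about 0.001 at layer 1 to about 0.46 at layer 23, while layer 24 is always kept.
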
**Early Exit Loss:**
- All layers share the same LM head
- Loss = main_loss + 0.3 × early_exit_loss
- Layer-proportional weighting: w_i = (i+1)/L for the exit after layer i, so deeper exits are weighted more heavily
- Enables flexible early exit at inference
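
A sketch of the combined loss. It assumes every intermediate layer's output passes through the final norm and the shared LM head (as in LayerSkip), that the early-exit term is the w_i-weighted mean over layers 0..L-2 (0-indexed), and that the last layer's output gives `main_loss`. The card does not state the exact normalization or set of exit layers; the function names are hypothetical.

```python
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F

EXIT_LOSS_SCALE = 0.3  # Loss = main_loss + 0.3 * early_exit_loss

def next_token_loss(h: torch.Tensor, final_norm: nn.Module, lm_head: nn.Module,
                    labels: torch.Tensor) -> torch.Tensor:
    """Cross-entropy of predicting labels[:, t + 1] from position t."""
    logits = lm_head(final_norm(h))  # the same norm and LM head are reused for every exit
    return F.cross_entropy(logits[:, :-1].reshape(-1, logits.size(-1)),
                           labels[:, 1:].reshape(-1))

def early_exit_training_loss(hidden_states: List[torch.Tensor], final_norm: nn.Module,
                             lm_head: nn.Module, labels: torch.Tensor) -> torch.Tensor:
    """hidden_states[i] is the output of layer i (0-indexed), shape (batch, seq, hidden)."""
    num_layers = len(hidden_states)
    main_loss = next_token_loss(hidden_states[-1], final_norm, lm_head, labels)
    # w_i = (i + 1) / L for the intermediate exits i = 0..L-2 (assumed exit set)
    weights = [(i + 1) / num_layers for i in range(num_layers - 1)]
    exit_losses = torch.stack([next_token_loss(h, final_norm, lm_head, labels)
                               for h in hidden_states[:-1]])
    # Assumption: the weighted exit losses are normalized by the sum of the weights
    early_exit_loss = (exit_losses.new_tensor(weights) * exit_losses).sum() / sum(weights)
    return main_loss + EXIT_LOSS_SCALE * early_exit_loss

torch.manual_seed(0)
batch, seq, hidden, vocab, n_layers = 2, 8, 16, 50, 4
hidden_states = [torch.randn(batch, seq, hidden) for _ in range(n_layers)]
lm_head = nn.Linear(hidden, vocab, bias=False)
final_norm = nn.LayerNorm(hidden)  # stand-in for the model's RMSNorm
labels = torch.randint(0, vocab, (batch, seq))
loss = early_exit_training_loss(hidden_states, final_norm, lm_head, labels)
```

Because the head is shared, exiting early at inference only requires applying the final norm and LM head to the hidden state of the chosen layer; no extra per-layer heads need to be stored.
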
### Hyperparameters
- **Optimizer**: AdamW
- **Learning rate**: 6e-4
- **Warmup steps**: 1000
- **Batch size**: 16 (effective: 64)
- **Training steps**: 50000
- **Gradient clipping**: 1.0
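
The hyperparameters map onto standard PyTorch components roughly as follows. The card does not give the learning-rate shape after warmup, the weight decay, or how the effective batch of 64 is reached; this sketch assumes linear warmup followed by cosine decay, AdamW's default weight decay, and 4 gradient-accumulation steps of 16 sequences on one device. The model and loss are placeholders.

```python
import math

import torch
import torch.nn as nn

LR = 6e-4
WARMUP_STEPS = 1000
TOTAL_STEPS = 50_000
MICRO_BATCH = 16
GRAD_ACCUM = 4          # assumption: 16 x 4 = effective batch of 64 on one device
MAX_GRAD_NORM = 1.0

def lr_lambda(step: int) -> float:
    """LR multiplier: linear warmup, then cosine decay to 0 (decay shape is an assumption)."""
    if step < WARMUP_STEPS:
        return (step + 1) / WARMUP_STEPS
    progress = min(1.0, (step - WARMUP_STEPS) / (TOTAL_STEPS - WARMUP_STEPS))
    return 0.5 * (1.0 + math.cos(math.pi * progress))

model = nn.Linear(8, 8)  # placeholder for the real model
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)  # default weight decay (assumption)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

def train_step(micro_batches) -> float:
    """One optimizer step over GRAD_ACCUM micro-batches; returns the mean micro-batch loss."""
    optimizer.zero_grad(set_to_none=True)
    total = 0.0
    for x in micro_batches:
        loss = model(x).pow(2).mean() / GRAD_ACCUM  # placeholder loss, scaled for accumulation
        loss.backward()
        total += loss.item()
    torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
    optimizer.step()
    scheduler.step()
    return total

step_loss = train_step([torch.randn(MICRO_BATCH, 8) for _ in range(GRAD_ACCUM)])
```

Dividing each micro-batch loss by `GRAD_ACCUM` makes the accumulated gradient match that of a single batch of 64, and clipping is applied once per optimizer step, after accumulation.
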
## Performance
### Perplexity (TinyStories validation)
| Exit Layer | Perplexity | Speed (tok/s) |
|------------|------------|---------------|
| All layers | TBD | TBD |
| Layer 18 | TBD | TBD |
| Layer 12 | TBD | TBD |
| Layer 6 | TBD | TBD |
### Training Stability
- **Gradient norms**: TBD
- **Final loss**: TBD
## Usage
### Installation
```bash
pip install transformers torch
```
### Basic Inference
```python
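from transformers import AutoTokenizer, AutoModelForCausalLM
# Load model and tokenizer ("your-username" is a placeholder for the repository owner).
# trust_remote_code=True is required if the architecture ships as custom modeling code.
model = AutoModelForCausalLM.from_pretrained("your-username/bitskip-v3-earlyexit", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("your-username/bitskip-v3-earlyexit")
# Generate text (max_length counts prompt + generated tokens)
inputs = tokenizer("Once upon a time", return_tensors="pt")
outputs = model.generate(**inputs, max_length=100, pad_token_id=tokenizer.eos_token_id)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))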
```
### Early Exit Inference
```python
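# Continues from the Basic Inference example (model, tokenizer, inputs).
# set_exit_layer() is provided by this model's custom code, not by transformers.
model.set_exit_layer(12)  # run only the first 12 of 24 layers, then the shared LM head
outputs = model.generate(**inputs, max_length=100)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
# Skips half of the transformer layers; measured speedup and quality are TBD (see Performance)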
```
### Benchmark Different Exit Layers
```python
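import time
prompt_len = inputs["input_ids"].shape[1]
for exit_layer in [6, 12, 18, 24]:
    model.set_exit_layer(exit_layer)
    start = time.perf_counter()
    outputs = model.generate(**inputs, max_length=100)
    elapsed = time.perf_counter() - start
    new_tokens = outputs.shape[1] - prompt_len
    print(f"Layer {exit_layer}: {new_tokens / elapsed:.1f} tok/s")
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))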
```
## Limitations
- **Inference speed**: Quantization is simulated with fake quantization (as in quantization-aware training) on standard full-precision kernels. Without specialized low-bit kernels, inference is slower than a full-precision model despite the lower bit-width.
- **Training instability**: The 4-bit models (v2) exhibited gradient explosion (gradient norms of 50-110) and required careful hyperparameter tuning.
- **Dataset scope**: Trained only on TinyStories; may not generalize to other domains without fine-tuning
## Citation
If you use this model, please cite:
```bibtex
@article{bitnet,
title={BitNet: Scaling 1-bit Transformers for Large Language Models},
author={Wang, Hongyu and Ma, Shuming and Dong, Li and others},
journal={arXiv preprint arXiv:2310.11453},
year={2023}
}
@article{layerskip,
title={LayerSkip: Enabling Early Exit Inference and Self-Speculative Decoding},
author={Elhoushi, Mostafa and Shrivastava, Akshat and Liskovich, Diana and others},
journal={arXiv preprint arXiv:2404.16710},
year={2024}
}
```
## License
MIT License
## Contact
For questions or issues, please open an issue on the model repository.