---
license: mit
language:
- en
pipeline_tag: text-generation
tags:
- bitnet
- quantization
- early-exit
- layer-skipping
- efficient-transformers
datasets:
- roneneldan/TinyStories
---
# bitskip-v2-earlyexit
BitSkip v2 with 4-bit activation quantization, ternary weights, and Hadamard transform
## Model Description
This model implements a 24-layer transformer with early exit loss and quadratic layer dropout for efficient inference. It was trained on the TinyStories dataset with layer-wise auxiliary supervision to enable flexible speed-quality tradeoffs during inference.
## Architecture Details
- **Layers**: 24
- **Hidden dimension**: 2048
- **Attention heads**: 32 (64-dimensional each)
- **Key-Value heads**: 8 (Grouped Query Attention with 4:1 ratio)
- **FFN intermediate size**: 4096
- **Position embeddings**: Rotary Position Embeddings (RoPE)
- **Normalization**: RMSNorm
- **Activation**: SwiGLU (for MLP)
- **Parameters**: ~1.06B
### Quantization Scheme
- **Weights**: Ternary {-1, 0, 1}
- **Activations**: 4-bit quantization (post-Hadamard)
- **Hadamard**: Yes (FWHT)
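The scheme above can be sketched in NumPy. This is an illustrative, framework-free sketch, not the training code: `hadamard_transform` is a standard fast Walsh-Hadamard transform, `ternary_quantize` follows the BitNet-style absmean rounding to {-1, 0, 1}, and `quantize_activations_4bit` assumes per-tensor absmax scaling to the signed 4-bit range [-8, 7] (the exact scaling granularity is not stated in this card).

```python
import numpy as np

def hadamard_transform(x):
    """Normalized fast Walsh-Hadamard transform (FWHT).
    The last dimension of x must be a power of two."""
    orig_shape = x.shape
    n = orig_shape[-1]
    assert n & (n - 1) == 0, "length must be a power of two"
    x = x.astype(np.float64).reshape(-1, n).copy()
    h = 1
    while h < n:
        # Butterfly update over blocks of size 2h
        for i in range(0, n, 2 * h):
            a = x[:, i:i + h].copy()
            b = x[:, i + h:i + 2 * h].copy()
            x[:, i:i + h] = a + b
            x[:, i + h:i + 2 * h] = a - b
        h *= 2
    return (x / np.sqrt(n)).reshape(orig_shape)

def ternary_quantize(w, eps=1e-8):
    """Absmean ternary quantization (BitNet-style):
    scale by mean |w|, then round and clip to {-1, 0, 1}."""
    scale = np.abs(w).mean() + eps
    return np.clip(np.round(w / scale), -1, 1), scale

def quantize_activations_4bit(x, eps=1e-8):
    """Per-tensor absmax quantization of (post-Hadamard)
    activations to signed 4-bit integers in [-8, 7]."""
    scale = np.abs(x).max() + eps
    return np.clip(np.round(x / scale * 7), -8, 7), scale
```

The Hadamard transform spreads outliers across the activation vector before quantization, which is why the 4-bit rounding is applied post-Hadamard.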
## Training Details
### Dataset
- **Source**: TinyStories (2.1M stories)
- **Tokenizer**: GPT-2 BPE (vocab size: 50,257)
- **Sequence length**: 512 tokens
### Training Techniques
**Quadratic Layer Dropout:**
- Progressive dropout: p_l = 0.5 × (l/L)²
- Normalized so Σp_l = 1.0
- Never drops final layer
- Encourages earlier layers to produce predictions that remain accurate on their own
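The schedule above can be sketched as follows. This is a hypothetical helper, not the training code; in particular, the order of zeroing the final layer's probability versus normalizing is an assumption (the card states only that the probabilities are normalized and the final layer is never dropped).

```python
def layer_dropout_probs(num_layers=24, p_max=0.5):
    """Quadratic layer-dropout schedule: p_l = p_max * (l / L)^2,
    with the final layer's probability forced to zero (it is never
    dropped), then rescaled so the probabilities sum to 1.0."""
    L = num_layers
    raw = [p_max * (l / L) ** 2 for l in range(1, L + 1)]
    raw[-1] = 0.0  # never drop the final layer
    total = sum(raw)
    return [p / total for p in raw]
```

Because the schedule is quadratic in depth, later layers are dropped far more often than earlier ones, pushing the network to keep its early representations useful.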
**Early Exit Loss:**
- All layers share the same LM head
- Loss = main_loss + 0.3 × early_exit_loss
- Layer-proportional weighting: w_i = (i+1)/L
- Enables flexible early exit at inference
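The loss combination above can be sketched as below. This is an illustrative sketch, not the training code; normalizing the weighted sum by the total weight is an assumption, since the card gives only the per-layer weights w_i = (i+1)/L and the 0.3 auxiliary coefficient.

```python
def combined_loss(main_loss, per_layer_losses, aux_weight=0.3):
    """Combine the final-layer LM loss with layer-proportional
    early-exit losses: w_i = (i + 1) / L for layer index i,
    normalized by the total weight (assumed here)."""
    L = len(per_layer_losses)
    weights = [(i + 1) / L for i in range(L)]
    total_w = sum(weights)
    aux = sum(w * l for w, l in zip(weights, per_layer_losses)) / total_w
    return main_loss + aux_weight * aux
```

Deeper layers receive larger weights, so the auxiliary signal is strongest where exits are most likely to be used.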
### Hyperparameters
- **Optimizer**: AdamW
- **Learning rate**: 3e-4
- **Warmup steps**: 4000
- **Batch size**: 16 (effective: 64)
- **Training steps**: 50000
- **Gradient clipping**: 0.5
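Given the batch sizes above, the effective batch of 64 implies 4 gradient-accumulation steps. The learning-rate schedule can be sketched as below; the card states only the peak LR (3e-4) and warmup length (4000 steps), so linear warmup and cosine decay to zero are assumptions.

```python
import math

def lr_at_step(step, base_lr=3e-4, warmup_steps=4000, total_steps=50_000):
    """Linear warmup to base_lr, then cosine decay to zero.
    The decay shape is an assumption; the card states only the
    peak LR and the warmup length."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return 0.5 * base_lr * (1 + math.cos(math.pi * progress))
```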
## Performance
### Perplexity (TinyStories validation)
| Exit Layer | Perplexity | Speed (tok/s) |
|------------|------------|---------------|
| All layers | TBD | TBD |
| Layer 18 | TBD | TBD |
| Layer 12 | TBD | TBD |
| Layer 6 | TBD | TBD |
### Training Stability
- **Gradient norms**: 50-110
- **Final loss**: TBD
## Usage
### Installation
```bash
pip install transformers torch
```
### Basic Inference
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
# Load model
model = AutoModelForCausalLM.from_pretrained("your-username/bitskip-v2-earlyexit")
tokenizer = AutoTokenizer.from_pretrained("your-username/bitskip-v2-earlyexit")
# Generate text
inputs = tokenizer("Once upon a time", return_tensors="pt")
outputs = model.generate(**inputs, max_length=100)
print(tokenizer.decode(outputs[0]))
```
### Early Exit Inference
```python
# Exit at layer 12 for faster inference
model.set_exit_layer(12)
outputs = model.generate(**inputs, max_length=100)
# 1.5-2x faster with minimal quality loss
```
### Benchmark Different Exit Layers
```python
# Compare output quality across exit depths
for exit_layer in [6, 12, 18, 24]:
    model.set_exit_layer(exit_layer)
    outputs = model.generate(**inputs, max_length=100)
    print(f"Layer {exit_layer}: {tokenizer.decode(outputs[0])}")
```
## Limitations
- **Inference speed**: Quantized models use fake quantization (QAT) without specialized kernels, resulting in slower inference than full-precision despite lower bit-width
- **Training instability**: 4-bit models (v2) exhibit gradient explosion (norms 50-110) requiring careful hyperparameter tuning
- **Dataset scope**: Trained only on TinyStories; may not generalize to other domains without fine-tuning
## Citation
If you use this model, please cite:
```bibtex
@article{bitnet,
title={BitNet: Scaling 1-bit Transformers for Large Language Models},
author={Wang, Hongyu and Ma, Shuming and Dong, Li and others},
journal={arXiv preprint arXiv:2310.11453},
year={2023}
}
@article{layerskip,
title={LayerSkip: Enabling Early Exit Inference and Self-Speculative Decoding},
author={Elhoushi, Mostafa and Shrivastava, Akshat and Liskovich, Diana and others},
journal={arXiv preprint arXiv:2404.16710},
year={2024}
}
```
## License
MIT License
## Contact
For questions or issues, please open an issue on the model repository.