---
language:
- en
tags:
- pytorch
- causal-lm
license: apache-2.0
---

# Sparse GPT-J 6B

## Model Description
Sparse GPT-J 6B is a pruned variant of the original [GPT-J 6B](https://huggingface.co/EleutherAI/gpt-j-6b) model. The vast majority of its linear layers (all except `lm_head`) have 40% unstructured sparsity: 40% of the individual weights in each of those layers are zero, with no block or channel pattern.

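To make "40% unstructured sparsity" concrete, here is a minimal pure-Python sketch. It zeroes the 40% smallest-magnitude weights of each matrix, ranking entries across the whole matrix with no row, column, or block pattern, and it leaves `lm_head` untouched. The card does not say how the pruning mask was chosen, so magnitude pruning is an assumption made for illustration. The dict-of-matrices "model", the layer names, and the helpers `magnitude_prune`, `sparsity_of`, and `prune_model` are all hypothetical.

```python
def magnitude_prune(weights, sparsity):
    """Return a copy of a 2-D weight matrix with the `sparsity` fraction of
    smallest-magnitude entries set to zero (assumed criterion: magnitude).

    Unstructured: entries are ranked across the whole matrix, so zeros can
    land anywhere. Ties in magnitude are broken by (row, column) position.
    """
    if not 0.0 <= sparsity < 1.0:
        raise ValueError("sparsity must be in [0, 1)")
    ranked = sorted(
        (abs(w), r, c) for r, row in enumerate(weights) for c, w in enumerate(row)
    )
    n_zero = round(sparsity * len(ranked))
    pruned = [list(row) for row in weights]  # copy; the input is not modified
    for _, r, c in ranked[:n_zero]:
        pruned[r][c] = 0.0
    return pruned


def sparsity_of(weights):
    """Fraction of entries in a 2-D matrix that are exactly zero."""
    total = sum(len(row) for row in weights)
    zeros = sum(1 for row in weights for w in row if w == 0.0)
    return zeros / total


def prune_model(layers, sparsity=0.4, skip=("lm_head",)):
    """Prune every named matrix except those in `skip` (hypothetical layout)."""
    return {
        name: ([list(row) for row in w] if name in skip else magnitude_prune(w, sparsity))
        for name, w in layers.items()
    }


# Hypothetical toy "model": layer name -> weight matrix.
layers = {
    "h.0.mlp.fc_in": [[0.5, -0.1, 0.3, 0.05, -0.7], [0.2, -0.9, 0.01, 0.4, -0.25]],
    "lm_head": [[0.1, 0.2], [0.3, 0.4]],
}
pruned = prune_model(layers)
print(sparsity_of(pruned["h.0.mlp.fc_in"]))  # 0.4
print(sparsity_of(pruned["lm_head"]))        # 0.0
```

A zero-fraction check like `sparsity_of` is the kind of test one could also run on each linear layer of a real sparse checkpoint to confirm the advertised 40%.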
<figure>

| Hyperparameter       | Value      |
|----------------------|------------|
| \\(n_{parameters}\\) | 6053381344 |
| \\(n_{layers}\\)     | 28*        |
| \\(d_{model}\\)      | 4096       |
| \\(d_{ff}\\)         | 16384      |
| \\(n_{heads}\\)      | 16         |
| \\(d_{head}\\)       | 256        |
| \\(n_{ctx}\\)        | 2048       |
| \\(n_{vocab}\\)      | 50257/50400† (same tokenizer as GPT-2/3) |
| Positional Encoding  | Rotary Position Embedding (RoPE) |
| RoPE Dimensions      | [64](https://github.com/kingoflolz/mesh-transformer-jax/blob/f2aa66e0925de6593dcbb70e72399b97b4130482/mesh_transformer/layers.py#L223) |
<figcaption><p><strong>*</strong> Each layer consists of one feedforward block and one self-attention block.</p>
<p><strong>†</strong> Although the embedding matrix has a size of 50400, only 50257 entries are used by the GPT-2 tokenizer.</p></figcaption></figure>

The model consists of 28 layers with a model dimension of 4096 and a feedforward dimension of 16384. The model
dimension is split into 16 heads, each with a dimension of 256. Rotary Position Embedding (RoPE) is applied to 64
dimensions of each head. The model uses a tokenization vocabulary of 50257, with the same set of BPEs as
GPT-2/GPT-3.

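RoPE covers only 64 of each head's 256 dimensions, and the following minimal pure-Python sketch shows what that partial rotation does to a single head vector. It assumes standard RoPE frequencies with base 10000, that the rotated dimensions are the first 64, and that adjacent pairs (0, 1), (2, 3), ... are rotated together. These layout details are assumptions, and the reference implementation linked in the table is authoritative. The function names `apply_partial_rope` and `dot` are hypothetical.

```python
import math


def apply_partial_rope(head, position, rotary_dim=64, base=10000.0):
    """Rotate the first `rotary_dim` entries of one attention-head vector by
    position-dependent angles; the remaining entries pass through unchanged.

    Assumption: the pair starting at even index i is rotated with frequency
    base ** (-i / rotary_dim), as in standard RoPE.
    """
    if rotary_dim % 2 != 0 or rotary_dim > len(head):
        raise ValueError("rotary_dim must be even and at most the head dimension")
    out = list(head)
    for i in range(0, rotary_dim, 2):
        theta = position * base ** (-i / rotary_dim)  # angle grows with position
        cos_t, sin_t = math.cos(theta), math.sin(theta)
        x0, x1 = head[i], head[i + 1]
        out[i] = x0 * cos_t - x1 * sin_t  # 2-D rotation of the pair (x0, x1)
        out[i + 1] = x0 * sin_t + x1 * cos_t
    return out


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


d_head = 256  # from the hyperparameter table
q = [math.sin(0.1 * j) for j in range(d_head)]
k = [math.cos(0.3 * j) for j in range(d_head)]

q_rot = apply_partial_rope(q, position=7)
print(q_rot[64:] == q[64:])  # True: dimensions 64..255 carry no position signal

# Query-key scores depend only on the relative offset (7 - 3 == 14 - 10).
s1 = dot(apply_partial_rope(q, 7), apply_partial_rope(k, 3))
s2 = dot(apply_partial_rope(q, 14), apply_partial_rope(k, 10))
print(math.isclose(s1, s2, rel_tol=1e-9, abs_tol=1e-9))  # True
```

Because each pair is only rotated, vector norms are preserved and the rotated part of a query-key dot product encodes relative position, while the other 192 dimensions of each head are position-independent.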
## Evaluation results
The sparse model was evaluated on the `lambada_openai` task with lm_eval (EleutherAI's lm-evaluation-harness), and its accuracy is compared with that of the dense GPT-J 6B model at two precisions, FP32 and BF16.

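The card does not define the accuracy "fluctuation" column. Working the arithmetic, both reported values equal the relative change of the sparse accuracy against the dense FP32 accuracy (0.6831), including the BF16 row. The snippet reproduces those figures from the table's accuracies. It is a consistency check only, not the authors' evaluation code, and `relative_change_pct` is a hypothetical helper.

```python
# Accuracies copied from the evaluation table (lambada_openai, lm_eval).
dense = {"FP32": 0.6831, "BF16": 0.6771}
sparse = {"FP32": 0.6922, "BF16": 0.6874}


def relative_change_pct(acc, baseline):
    """Relative change of `acc` with respect to `baseline`, in percent."""
    return (acc - baseline) / baseline * 100.0


baseline = dense["FP32"]  # both reported figures match this single baseline
for precision in ("FP32", "BF16"):
    change = relative_change_pct(sparse[precision], baseline)
    print(f"{precision}: {change:+.2f}%")
# FP32: +1.33%
# BF16: +0.63%

# Same-precision comparison for BF16, for contrast:
print(f"BF16 vs. dense BF16: {relative_change_pct(sparse['BF16'], dense['BF16']):+.2f}%")
# BF16 vs. dense BF16: +1.52%
```

Measured against the dense BF16 model instead, the sparse BF16 gain would be about +1.52% rather than +0.63%.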
<figure>

| Sparsity | Dataset        | Precision | Dense Acc ↑ | Sparse Acc ↑ | Acc fluctuation (vs. dense FP32) |
|----------|----------------|-----------|-------------|--------------|----------------------------------|
| 40%      | lambada_openai | FP32      | 0.6831      | 0.6922       | +1.33%                           |
| 40%      | lambada_openai | BF16      | 0.6771      | 0.6874       | +0.63%                           |