Add files using upload-large-folder tool
Browse files- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-30000/generation_config.json +7 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-30000/image_processing_evabyte.py +204 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-30000/model.safetensors.index.json +450 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-30000/modeling_evabyte.py +912 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-30000/multibyte_decoding_evabyte.py +881 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-30000/preprocessor_config.json +18 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-30000/processing_evabyte.py +287 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-30000/processor_config.json +6 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-30000/special_tokens_map.json +98 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-30000/tokenization_evabyte.py +246 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-30000/tokenizer_config.json +596 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/README.md +105 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/config.json +48 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/configuration_evabyte.py +99 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/eva.py +424 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/eva_agg_kernel.py +1766 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/eva_cache.py +761 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/eva_prep_kv_kernel.py +1017 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/eva_pt_ref.py +420 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/generation_config.json +7 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/image_processing_evabyte.py +204 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/model.safetensors.index.json +450 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/modeling_evabyte.py +912 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/multibyte_decoding_evabyte.py +881 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/preprocessor_config.json +18 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/processing_evabyte.py +287 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/processor_config.json +6 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/special_tokens_map.json +98 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/tokenization_evabyte.py +246 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/tokenizer_config.json +596 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/README.md +105 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/config.json +48 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/configuration_evabyte.py +99 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/eva.py +424 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/eva_agg_kernel.py +1766 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/eva_cache.py +761 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/eva_prep_kv_kernel.py +1017 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/eva_pt_ref.py +420 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/generation_config.json +7 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/image_processing_evabyte.py +204 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/model.safetensors.index.json +450 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/modeling_evabyte.py +912 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/multibyte_decoding_evabyte.py +881 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/preprocessor_config.json +18 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/processing_evabyte.py +287 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/processor_config.json +6 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/special_tokens_map.json +98 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/tokenization_evabyte.py +246 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/tokenizer_config.json +596 -0
- ckpts/ocpython_14b_bsz-2m_seq16k_docmask_multipredc2r8_90dynamic-10raw_transsentinel_minsize0ent98line16ow16pack_100B_2m_new_2_step-10000/README.md +105 -0
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-30000/generation_config.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_from_model_config": true,
|
| 3 |
+
"bos_token_id": 1,
|
| 4 |
+
"eos_token_id": 2,
|
| 5 |
+
"pad_token_id": 0,
|
| 6 |
+
"transformers_version": "4.47.1"
|
| 7 |
+
}
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-30000/image_processing_evabyte.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
"""Image processor class for EvaByte."""
|
| 3 |
+
|
| 4 |
+
from typing import Dict, List, Optional, Union, Tuple
|
| 5 |
+
|
| 6 |
+
import io
|
| 7 |
+
from transformers.image_processing_utils import BaseImageProcessor
|
| 8 |
+
from transformers.image_utils import (
|
| 9 |
+
ImageInput,
|
| 10 |
+
PILImageResampling,
|
| 11 |
+
valid_images,
|
| 12 |
+
validate_preprocess_arguments,
|
| 13 |
+
)
|
| 14 |
+
from PIL import Image
|
| 15 |
+
|
| 16 |
+
def _get_qtable_bytes():
|
| 17 |
+
return {
|
| 18 |
+
5: b'\xff\xd8\xff\xdb\x00C\x00\xa0nx\x8cxd\xa0\x8c\x82\x8c\xb4\xaa\xa0\xbe\xf0\xff\xff\xf0\xdc\xdc\xf0\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xdb\x00C\x01\xa0\xb4\xb4\xf0\xd2\xf0\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xd9',
|
| 19 |
+
10: b'\xff\xd8\xff\xdb\x00C\x00P7<F<2PFAFZUP_x\xc8\x82xnnx\xf5\xaf\xb9\x91\xc8\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xdb\x00C\x01PZZxix\xeb\x82\x82\xeb\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xd9',
|
| 20 |
+
15: b'\xff\xd8\xff\xdb\x00C\x005%(/(!5/+/<95?P\x85WPIIP\xa3u{a\x85\xc1\xaa\xcb\xc8\xbe\xaa\xba\xb7\xd5\xf0\xff\xff\xd5\xe2\xff\xe6\xb7\xba\xff\xff\xff\xff\xff\xff\xff\xff\xff\xce\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xdb\x00C\x015<<PFP\x9dWW\x9d\xff\xdc\xba\xdc\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xd9',
|
| 21 |
+
20: b'\xff\xd8\xff\xdb\x00C\x00(\x1c\x1e#\x1e\x19(#!#-+(0<dA<77<{X]Id\x91\x80\x99\x96\x8f\x80\x8c\x8a\xa0\xb4\xe6\xc3\xa0\xaa\xda\xad\x8a\x8c\xc8\xff\xcb\xda\xee\xf5\xff\xff\xff\x9b\xc1\xff\xff\xff\xfa\xff\xe6\xfd\xff\xf8\xff\xdb\x00C\x01(--<5<vAAv\xf8\xa5\x8c\xa5\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xff\xd9',
|
| 22 |
+
25: b'\xff\xd8\xff\xdb\x00C\x00 \x16\x18\x1c\x18\x14 \x1c\x1a\x1c$" &0P40,,0bFJ:Ptfzxrfpn\x80\x90\xb8\x9c\x80\x88\xae\x8anp\xa0\xda\xa2\xae\xbe\xc4\xce\xd0\xce|\x9a\xe2\xf2\xe0\xc8\xf0\xb8\xca\xce\xc6\xff\xdb\x00C\x01 $$0*0^44^\xc6\x84p\x84\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xff\xd9',
|
| 23 |
+
30: b'\xff\xd8\xff\xdb\x00C\x00\x1b\x12\x14\x17\x14\x11\x1b\x17\x16\x17\x1e\x1c\x1b (B+(%%(Q:=0B`Ued_U][jx\x99\x81jq\x90s[]\x85\xb5\x86\x90\x9e\xa3\xab\xad\xabg\x80\xbc\xc9\xba\xa6\xc7\x99\xa8\xab\xa4\xff\xdb\x00C\x01\x1b\x1e\x1e(#(N++N\xa4n]n\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xff\xd9',
|
| 24 |
+
50: b'\xff\xd8\xff\xdb\x00C\x00\x10\x0b\x0c\x0e\x0c\n\x10\x0e\r\x0e\x12\x11\x10\x13\x18(\x1a\x18\x16\x16\x181#%\x1d(:3=<9387@H\\N@DWE78PmQW_bghg>Mqypdx\\egc\xff\xdb\x00C\x01\x10\x12\x12\x18\x15\x18/\x1a\x1a/cB8Bcccccccccccccccccccccccccccccccccccccccccccccccccc\xff\xd9',
|
| 25 |
+
75: b'\xff\xd8\xff\xdb\x00C\x00\x08\x06\x06\x07\x06\x05\x08\x07\x07\x07\t\t\x08\n\x0c\x14\r\x0c\x0b\x0b\x0c\x19\x12\x13\x0f\x14\x1d\x1a\x1f\x1e\x1d\x1a\x1c\x1c $.\' ",#\x1c\x1c(7),01444\x1f\'9=82<.342\xff\xdb\x00C\x01\x08\t\t\x0c\x0b\x0c\x18\r\r\x182!\x1c!22222222222222222222222222222222222222222222222222\xff\xd9',
|
| 26 |
+
95: b'\xff\xd8\xff\xdb\x00C\x00\x02\x01\x01\x01\x01\x01\x02\x01\x01\x01\x02\x02\x02\x02\x02\x04\x03\x02\x02\x02\x02\x05\x04\x04\x03\x04\x06\x05\x06\x06\x06\x05\x06\x06\x06\x07\t\x08\x06\x07\t\x07\x06\x06\x08\x0b\x08\t\n\n\n\n\n\x06\x08\x0b\x0c\x0b\n\x0c\t\n\n\n\xff\xdb\x00C\x01\x02\x02\x02\x02\x02\x02\x05\x03\x03\x05\n\x07\x06\x07\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\xff\xd9',
|
| 27 |
+
100: b'\xff\xd8\xff\xdb\x00C\x00\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\xff\xdb\x00C\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\xff\xd9',
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _resize_if_exceeding_max_len(
|
| 32 |
+
width: int, height: int, min_len: Optional[int] = 16, max_len: Optional[int] = None
|
| 33 |
+
) -> Tuple[int, int]:
|
| 34 |
+
"""
|
| 35 |
+
Get the output size of the image after resizing given a dictionary specifying the max and min sizes.
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
height (`int`):
|
| 39 |
+
Height of the input image.
|
| 40 |
+
width (`int`):
|
| 41 |
+
Width of the input image.
|
| 42 |
+
max_len (`Dict[str, int]`, *optional*, defaults to the maximum size of the image):
|
| 43 |
+
Defines the maximum dimensions of the image.
|
| 44 |
+
|
| 45 |
+
Returns:
|
| 46 |
+
The output size of the image after resizing.
|
| 47 |
+
"""
|
| 48 |
+
max_len = max(height, width) if max_len is None else max_len
|
| 49 |
+
aspect_ratio = width / height
|
| 50 |
+
|
| 51 |
+
if width >= height and width > max_len:
|
| 52 |
+
width = max_len
|
| 53 |
+
height = int(width / aspect_ratio)
|
| 54 |
+
if height % 2 != 0:
|
| 55 |
+
height += 1
|
| 56 |
+
elif height > width and height > max_len:
|
| 57 |
+
height = max_len
|
| 58 |
+
width = int(height * aspect_ratio)
|
| 59 |
+
if width % 2 != 0:
|
| 60 |
+
width += 1
|
| 61 |
+
|
| 62 |
+
# Avoid resizing to a size smaller than 1
|
| 63 |
+
height = max(height, min_len)
|
| 64 |
+
width = max(width, min_len)
|
| 65 |
+
return width, height
|
| 66 |
+
|
| 67 |
+
class EvaByteImageProcessor(BaseImageProcessor):
|
| 68 |
+
|
| 69 |
+
model_input_names = []
|
| 70 |
+
|
| 71 |
+
def __init__(
|
| 72 |
+
self,
|
| 73 |
+
do_resize: bool = True,
|
| 74 |
+
resample: PILImageResampling = PILImageResampling.LANCZOS,
|
| 75 |
+
size: Dict[str, int] = None,
|
| 76 |
+
do_convert_rgb: bool = True,
|
| 77 |
+
jpeg_quality: int = 25,
|
| 78 |
+
jpeg_subsampling: str = "4:2:0",
|
| 79 |
+
jpeg_streamtype: str = 2,
|
| 80 |
+
jpeg_restart_marker_blocks: int = 1,
|
| 81 |
+
**kwargs,
|
| 82 |
+
) -> None:
|
| 83 |
+
super().__init__(**kwargs)
|
| 84 |
+
self.do_resize = do_resize
|
| 85 |
+
self.resample = resample
|
| 86 |
+
self.size = size if size is not None else {"longest_edge": 384}
|
| 87 |
+
self.do_convert_rgb = do_convert_rgb
|
| 88 |
+
self.jpeg_quality = jpeg_quality
|
| 89 |
+
self.jpeg_subsampling = jpeg_subsampling
|
| 90 |
+
self.jpeg_streamtype = jpeg_streamtype
|
| 91 |
+
self.jpeg_restart_marker_blocks = jpeg_restart_marker_blocks
|
| 92 |
+
|
| 93 |
+
def jpeg_encode(
|
| 94 |
+
self,
|
| 95 |
+
image,
|
| 96 |
+
jpeg_quality,
|
| 97 |
+
jpeg_subsampling,
|
| 98 |
+
jpeg_streamtype,
|
| 99 |
+
jpeg_restart_marker_blocks,
|
| 100 |
+
):
|
| 101 |
+
with io.BytesIO() as output:
|
| 102 |
+
image.save(
|
| 103 |
+
output,
|
| 104 |
+
format="JPEG",
|
| 105 |
+
quality=jpeg_quality,
|
| 106 |
+
subsampling=jpeg_subsampling,
|
| 107 |
+
streamtype=jpeg_streamtype,
|
| 108 |
+
restart_marker_blocks=jpeg_restart_marker_blocks
|
| 109 |
+
)
|
| 110 |
+
jpeg_bytes = output.getvalue()
|
| 111 |
+
return jpeg_bytes
|
| 112 |
+
|
| 113 |
+
def jpeg_merge_qtables(
|
| 114 |
+
self,
|
| 115 |
+
image_bytes,
|
| 116 |
+
jpeg_quality=None,
|
| 117 |
+
):
|
| 118 |
+
if jpeg_quality is None:
|
| 119 |
+
jpeg_quality = self.jpeg_quality
|
| 120 |
+
qtable_bytes = _get_qtable_bytes()[jpeg_quality]
|
| 121 |
+
return image_bytes[:2] + qtable_bytes[2:-2] + image_bytes[2:]
|
| 122 |
+
|
| 123 |
+
def resize(
|
| 124 |
+
self,
|
| 125 |
+
image: Image,
|
| 126 |
+
size: Dict[str, int],
|
| 127 |
+
resample: PILImageResampling = PILImageResampling.LANCZOS,
|
| 128 |
+
) -> Image:
|
| 129 |
+
if "longest_edge" in size:
|
| 130 |
+
width, height = image.size
|
| 131 |
+
# Find the output size, when rescaling the longest edge to max_len and preserving the aspect ratio
|
| 132 |
+
width, height = _resize_if_exceeding_max_len(width, height, max_len=size["longest_edge"])
|
| 133 |
+
size = (width, height)
|
| 134 |
+
elif "width" in size and "height" in size:
|
| 135 |
+
size = (size["width"], size["height"])
|
| 136 |
+
else:
|
| 137 |
+
raise ValueError("size must be a dictionary with key 'longest_edge' or 'height' and 'width'.")
|
| 138 |
+
resized_image = image.resize(size, resample=resample)
|
| 139 |
+
return resized_image
|
| 140 |
+
|
| 141 |
+
def preprocess(
|
| 142 |
+
self,
|
| 143 |
+
images: ImageInput,
|
| 144 |
+
do_resize: bool = None,
|
| 145 |
+
resample = None,
|
| 146 |
+
size: Dict[str, int] = None,
|
| 147 |
+
do_convert_rgb: bool = None,
|
| 148 |
+
jpeg_quality: int = None,
|
| 149 |
+
jpeg_subsampling: str = None,
|
| 150 |
+
jpeg_streamtype: str = None,
|
| 151 |
+
jpeg_restart_marker_blocks: int = None,
|
| 152 |
+
):
|
| 153 |
+
do_resize = do_resize if do_resize is not None else self.do_resize
|
| 154 |
+
size = size if size is not None else self.size
|
| 155 |
+
resample = resample if resample is not None else self.resample
|
| 156 |
+
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
| 157 |
+
|
| 158 |
+
jpeg_quality = jpeg_quality if jpeg_quality is not None else self.jpeg_quality
|
| 159 |
+
jpeg_subsampling = jpeg_subsampling if jpeg_subsampling is not None else self.jpeg_subsampling
|
| 160 |
+
jpeg_streamtype = jpeg_streamtype if jpeg_streamtype is not None else self.jpeg_streamtype
|
| 161 |
+
jpeg_restart_marker_blocks = jpeg_restart_marker_blocks if jpeg_restart_marker_blocks is not None else self.jpeg_restart_marker_blocks
|
| 162 |
+
|
| 163 |
+
if images is not None and not valid_images(images):
|
| 164 |
+
raise ValueError(
|
| 165 |
+
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
| 166 |
+
"torch.Tensor, tf.Tensor or jax.ndarray."
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
validate_preprocess_arguments(
|
| 170 |
+
do_resize=do_resize,
|
| 171 |
+
size=size,
|
| 172 |
+
resample=resample,
|
| 173 |
+
)
|
| 174 |
+
images_list = images
|
| 175 |
+
if do_convert_rgb:
|
| 176 |
+
images_list = [
|
| 177 |
+
[
|
| 178 |
+
image.convert("RGB") for image in images
|
| 179 |
+
]
|
| 180 |
+
for images in images_list
|
| 181 |
+
]
|
| 182 |
+
|
| 183 |
+
if do_resize:
|
| 184 |
+
images_list = [
|
| 185 |
+
[
|
| 186 |
+
self.resize(image=image, size=size, resample=resample)
|
| 187 |
+
for image in images
|
| 188 |
+
]
|
| 189 |
+
for images in images_list
|
| 190 |
+
]
|
| 191 |
+
|
| 192 |
+
jpeg_bytes = [
|
| 193 |
+
[
|
| 194 |
+
self.jpeg_encode(
|
| 195 |
+
image,
|
| 196 |
+
jpeg_quality,
|
| 197 |
+
jpeg_subsampling,
|
| 198 |
+
jpeg_streamtype,
|
| 199 |
+
jpeg_restart_marker_blocks
|
| 200 |
+
) for image in images
|
| 201 |
+
]
|
| 202 |
+
for images in images_list
|
| 203 |
+
]
|
| 204 |
+
return jpeg_bytes
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-30000/model.safetensors.index.json
ADDED
|
@@ -0,0 +1,450 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"metadata": {
|
| 3 |
+
"total_size": 57058938880
|
| 4 |
+
},
|
| 5 |
+
"weight_map": {
|
| 6 |
+
"model.embed_tokens.weight": "model-00001-of-00003.safetensors",
|
| 7 |
+
"model.layers.0.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 8 |
+
"model.layers.0.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 9 |
+
"model.layers.1.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
|
| 10 |
+
"model.layers.1.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
|
| 11 |
+
"model.layers.1.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 12 |
+
"model.layers.1.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 13 |
+
"model.layers.3.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 14 |
+
"model.layers.4.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 15 |
+
"model.layers.4.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 16 |
+
"model.layers.5.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 17 |
+
"model.layers.5.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 18 |
+
"model.layers.7.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 19 |
+
"model.layers.8.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 20 |
+
"model.layers.8.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 21 |
+
"model.layers.10.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 22 |
+
"model.layers.12.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 23 |
+
"model.layers.12.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 24 |
+
"model.layers.12.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 25 |
+
"model.layers.13.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
|
| 26 |
+
"model.layers.13.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
|
| 27 |
+
"model.layers.13.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 28 |
+
"model.layers.13.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 29 |
+
"model.layers.15.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 30 |
+
"model.layers.16.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 31 |
+
"model.layers.16.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 32 |
+
"model.layers.18.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 33 |
+
"model.layers.20.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 34 |
+
"model.layers.20.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 35 |
+
"model.layers.20.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 36 |
+
"model.layers.21.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
|
| 37 |
+
"model.layers.21.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
|
| 38 |
+
"model.layers.21.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 39 |
+
"model.layers.21.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 40 |
+
"model.layers.23.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 41 |
+
"model.layers.24.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 42 |
+
"model.layers.24.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 43 |
+
"model.layers.26.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 44 |
+
"model.layers.28.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 45 |
+
"model.layers.28.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 46 |
+
"model.layers.28.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 47 |
+
"model.layers.29.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
|
| 48 |
+
"model.layers.29.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
|
| 49 |
+
"model.layers.29.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 50 |
+
"model.layers.29.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 51 |
+
"model.layers.31.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 52 |
+
"model.layers.32.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 53 |
+
"model.layers.32.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 54 |
+
"model.layers.34.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 55 |
+
"model.layers.36.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 56 |
+
"model.layers.36.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 57 |
+
"model.layers.36.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 58 |
+
"model.layers.37.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
|
| 59 |
+
"model.layers.37.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
|
| 60 |
+
"model.layers.37.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 61 |
+
"model.layers.37.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 62 |
+
"model.layers.39.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 63 |
+
"model.layers.2.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
|
| 64 |
+
"model.layers.26.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
|
| 65 |
+
"model.layers.0.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 66 |
+
"model.layers.1.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 67 |
+
"model.layers.1.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 68 |
+
"model.layers.3.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
|
| 69 |
+
"model.layers.3.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 70 |
+
"model.layers.5.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 71 |
+
"model.layers.7.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 72 |
+
"model.layers.9.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 73 |
+
"model.layers.9.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 74 |
+
"model.layers.9.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 75 |
+
"model.layers.10.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 76 |
+
"model.layers.10.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 77 |
+
"model.layers.12.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 78 |
+
"model.layers.14.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 79 |
+
"model.layers.14.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 80 |
+
"model.layers.15.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 81 |
+
"model.layers.15.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 82 |
+
"model.layers.17.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 83 |
+
"model.layers.18.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 84 |
+
"model.layers.18.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 85 |
+
"model.layers.20.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 86 |
+
"model.layers.21.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 87 |
+
"model.layers.21.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 88 |
+
"model.layers.23.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 89 |
+
"model.layers.25.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 90 |
+
"model.layers.25.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 91 |
+
"model.layers.25.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 92 |
+
"model.layers.26.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 93 |
+
"model.layers.26.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 94 |
+
"model.layers.27.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
|
| 95 |
+
"model.layers.28.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 96 |
+
"model.layers.29.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 97 |
+
"model.layers.29.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 98 |
+
"model.layers.31.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 99 |
+
"model.layers.33.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 100 |
+
"model.layers.33.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 101 |
+
"model.layers.33.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 102 |
+
"model.layers.34.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 103 |
+
"model.layers.34.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 104 |
+
"model.layers.36.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 105 |
+
"model.layers.37.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 106 |
+
"model.layers.37.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 107 |
+
"model.layers.39.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 108 |
+
"model.layers.3.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
|
| 109 |
+
"model.layers.27.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
|
| 110 |
+
"model.layers.0.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 111 |
+
"model.layers.1.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 112 |
+
"model.layers.1.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 113 |
+
"model.layers.2.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 114 |
+
"model.layers.2.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 115 |
+
"model.layers.2.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 116 |
+
"model.layers.3.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 117 |
+
"model.layers.3.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 118 |
+
"model.layers.4.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
|
| 119 |
+
"model.layers.5.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 120 |
+
"model.layers.5.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 121 |
+
"model.layers.6.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 122 |
+
"model.layers.6.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 123 |
+
"model.layers.8.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 124 |
+
"model.layers.9.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 125 |
+
"model.layers.9.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 126 |
+
"model.layers.11.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 127 |
+
"model.layers.13.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 128 |
+
"model.layers.13.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 129 |
+
"model.layers.13.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 130 |
+
"model.layers.14.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 131 |
+
"model.layers.14.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 132 |
+
"model.layers.16.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 133 |
+
"model.layers.17.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 134 |
+
"model.layers.17.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 135 |
+
"model.layers.19.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 136 |
+
"model.layers.21.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 137 |
+
"model.layers.21.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 138 |
+
"model.layers.21.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 139 |
+
"model.layers.22.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 140 |
+
"model.layers.22.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 141 |
+
"model.layers.24.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 142 |
+
"model.layers.25.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 143 |
+
"model.layers.25.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 144 |
+
"model.layers.27.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 145 |
+
"model.layers.28.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
|
| 146 |
+
"model.layers.29.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 147 |
+
"model.layers.29.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 148 |
+
"model.layers.29.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 149 |
+
"model.layers.30.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 150 |
+
"model.layers.30.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 151 |
+
"model.layers.32.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 152 |
+
"model.layers.33.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 153 |
+
"model.layers.33.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 154 |
+
"model.layers.35.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 155 |
+
"model.layers.37.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 156 |
+
"model.layers.37.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 157 |
+
"model.layers.37.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 158 |
+
"model.layers.38.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 159 |
+
"model.layers.38.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 160 |
+
"model.layers.4.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
|
| 161 |
+
"model.layers.28.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
|
| 162 |
+
"model.layers.5.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
|
| 163 |
+
"model.layers.0.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
|
| 164 |
+
"model.layers.0.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 165 |
+
"model.layers.2.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 166 |
+
"model.layers.4.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 167 |
+
"model.layers.6.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 168 |
+
"model.layers.8.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 169 |
+
"model.layers.8.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 170 |
+
"model.layers.8.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 171 |
+
"model.layers.9.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
|
| 172 |
+
"model.layers.9.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
|
| 173 |
+
"model.layers.9.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 174 |
+
"model.layers.9.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 175 |
+
"model.layers.11.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 176 |
+
"model.layers.12.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
|
| 177 |
+
"model.layers.12.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 178 |
+
"model.layers.12.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 179 |
+
"model.layers.14.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 180 |
+
"model.layers.16.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 181 |
+
"model.layers.16.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 182 |
+
"model.layers.16.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 183 |
+
"model.layers.17.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
|
| 184 |
+
"model.layers.17.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
|
| 185 |
+
"model.layers.17.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 186 |
+
"model.layers.17.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 187 |
+
"model.layers.19.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 188 |
+
"model.layers.20.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 189 |
+
"model.layers.20.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 190 |
+
"model.layers.22.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 191 |
+
"model.layers.23.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
|
| 192 |
+
"model.layers.23.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 193 |
+
"model.layers.23.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 194 |
+
"model.layers.25.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 195 |
+
"model.layers.27.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 196 |
+
"model.layers.27.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 197 |
+
"model.layers.27.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 198 |
+
"model.layers.28.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 199 |
+
"model.layers.28.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 200 |
+
"model.layers.30.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 201 |
+
"model.layers.32.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 202 |
+
"model.layers.32.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 203 |
+
"model.layers.32.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 204 |
+
"model.layers.33.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
|
| 205 |
+
"model.layers.33.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
|
| 206 |
+
"model.layers.33.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 207 |
+
"model.layers.33.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 208 |
+
"model.layers.35.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 209 |
+
"model.layers.36.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 210 |
+
"model.layers.36.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 211 |
+
"model.layers.38.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 212 |
+
"model.layers.0.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 213 |
+
"model.layers.1.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 214 |
+
"model.layers.1.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 215 |
+
"model.layers.2.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 216 |
+
"model.layers.2.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 217 |
+
"model.layers.2.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 218 |
+
"model.layers.3.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 219 |
+
"model.layers.3.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 220 |
+
"model.layers.5.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
|
| 221 |
+
"model.layers.5.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 222 |
+
"model.layers.5.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 223 |
+
"model.layers.6.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 224 |
+
"model.layers.6.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 225 |
+
"model.layers.8.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 226 |
+
"model.layers.10.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 227 |
+
"model.layers.10.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 228 |
+
"model.layers.10.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 229 |
+
"model.layers.11.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 230 |
+
"model.layers.11.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 231 |
+
"model.layers.13.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 232 |
+
"model.layers.14.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 233 |
+
"model.layers.14.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 234 |
+
"model.layers.16.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 235 |
+
"model.layers.18.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 236 |
+
"model.layers.18.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 237 |
+
"model.layers.18.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 238 |
+
"model.layers.19.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 239 |
+
"model.layers.19.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 240 |
+
"model.layers.21.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 241 |
+
"model.layers.22.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 242 |
+
"model.layers.22.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 243 |
+
"model.layers.24.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 244 |
+
"model.layers.26.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 245 |
+
"model.layers.26.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 246 |
+
"model.layers.26.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 247 |
+
"model.layers.27.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 248 |
+
"model.layers.27.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 249 |
+
"model.layers.29.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 250 |
+
"model.layers.30.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 251 |
+
"model.layers.30.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 252 |
+
"model.layers.32.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 253 |
+
"model.layers.34.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 254 |
+
"model.layers.34.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 255 |
+
"model.layers.34.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 256 |
+
"model.layers.35.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 257 |
+
"model.layers.35.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 258 |
+
"model.layers.37.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 259 |
+
"model.layers.38.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 260 |
+
"model.layers.38.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 261 |
+
"model.layers.6.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
|
| 262 |
+
"model.layers.30.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
|
| 263 |
+
"model.layers.6.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
|
| 264 |
+
"model.layers.30.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
|
| 265 |
+
"model.layers.7.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
|
| 266 |
+
"model.layers.31.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
|
| 267 |
+
"model.layers.7.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
|
| 268 |
+
"model.layers.31.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
|
| 269 |
+
"model.layers.8.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
|
| 270 |
+
"model.layers.32.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
|
| 271 |
+
"model.layers.2.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
|
| 272 |
+
"model.layers.14.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
|
| 273 |
+
"model.layers.14.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
|
| 274 |
+
"model.layers.22.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
|
| 275 |
+
"model.layers.22.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
|
| 276 |
+
"model.layers.38.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
|
| 277 |
+
"model.layers.38.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
|
| 278 |
+
"model.layers.0.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 279 |
+
"model.layers.1.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 280 |
+
"model.layers.3.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 281 |
+
"model.layers.3.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 282 |
+
"model.layers.3.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 283 |
+
"model.layers.4.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 284 |
+
"model.layers.4.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 285 |
+
"model.layers.4.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 286 |
+
"model.layers.4.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 287 |
+
"model.layers.5.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 288 |
+
"model.layers.5.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 289 |
+
"model.layers.7.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 290 |
+
"model.layers.7.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 291 |
+
"model.layers.7.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 292 |
+
"model.layers.8.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 293 |
+
"model.layers.8.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 294 |
+
"model.layers.10.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 295 |
+
"model.layers.11.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 296 |
+
"model.layers.11.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 297 |
+
"model.layers.13.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 298 |
+
"model.layers.15.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 299 |
+
"model.layers.15.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 300 |
+
"model.layers.15.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 301 |
+
"model.layers.16.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 302 |
+
"model.layers.16.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 303 |
+
"model.layers.18.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 304 |
+
"model.layers.19.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 305 |
+
"model.layers.19.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 306 |
+
"model.layers.21.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 307 |
+
"model.layers.23.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 308 |
+
"model.layers.23.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 309 |
+
"model.layers.23.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 310 |
+
"model.layers.24.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 311 |
+
"model.layers.24.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 312 |
+
"model.layers.26.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 313 |
+
"model.layers.27.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 314 |
+
"model.layers.27.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 315 |
+
"model.layers.29.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 316 |
+
"model.layers.31.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 317 |
+
"model.layers.31.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 318 |
+
"model.layers.31.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 319 |
+
"model.layers.32.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 320 |
+
"model.layers.32.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 321 |
+
"model.layers.34.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 322 |
+
"model.layers.35.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 323 |
+
"model.layers.35.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 324 |
+
"model.layers.37.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 325 |
+
"model.layers.39.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 326 |
+
"model.layers.39.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 327 |
+
"model.layers.39.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 328 |
+
"model.norm.weight": "model-00003-of-00003.safetensors",
|
| 329 |
+
"lm_head.weight": "model-00003-of-00003.safetensors",
|
| 330 |
+
"model.layers.0.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 331 |
+
"model.layers.2.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 332 |
+
"model.layers.4.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 333 |
+
"model.layers.6.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 334 |
+
"model.layers.6.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 335 |
+
"model.layers.7.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 336 |
+
"model.layers.7.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 337 |
+
"model.layers.8.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
|
| 338 |
+
"model.layers.9.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 339 |
+
"model.layers.10.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 340 |
+
"model.layers.10.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 341 |
+
"model.layers.12.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 342 |
+
"model.layers.13.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 343 |
+
"model.layers.13.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 344 |
+
"model.layers.15.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 345 |
+
"model.layers.17.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 346 |
+
"model.layers.17.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 347 |
+
"model.layers.17.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 348 |
+
"model.layers.18.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 349 |
+
"model.layers.18.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 350 |
+
"model.layers.20.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 351 |
+
"model.layers.22.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 352 |
+
"model.layers.22.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 353 |
+
"model.layers.22.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 354 |
+
"model.layers.23.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 355 |
+
"model.layers.23.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 356 |
+
"model.layers.25.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 357 |
+
"model.layers.26.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 358 |
+
"model.layers.26.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 359 |
+
"model.layers.28.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 360 |
+
"model.layers.30.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 361 |
+
"model.layers.30.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 362 |
+
"model.layers.30.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 363 |
+
"model.layers.31.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 364 |
+
"model.layers.31.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 365 |
+
"model.layers.32.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
|
| 366 |
+
"model.layers.33.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 367 |
+
"model.layers.34.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 368 |
+
"model.layers.34.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 369 |
+
"model.layers.36.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 370 |
+
"model.layers.38.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 371 |
+
"model.layers.38.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 372 |
+
"model.layers.38.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 373 |
+
"model.layers.39.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 374 |
+
"model.layers.39.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 375 |
+
"model.layers.10.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
|
| 376 |
+
"model.layers.34.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
|
| 377 |
+
"model.layers.10.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
|
| 378 |
+
"model.layers.34.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
|
| 379 |
+
"model.layers.0.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 380 |
+
"model.layers.2.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 381 |
+
"model.layers.4.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 382 |
+
"model.layers.6.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 383 |
+
"model.layers.6.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 384 |
+
"model.layers.7.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 385 |
+
"model.layers.7.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 386 |
+
"model.layers.9.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 387 |
+
"model.layers.11.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
|
| 388 |
+
"model.layers.11.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 389 |
+
"model.layers.11.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 390 |
+
"model.layers.11.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 391 |
+
"model.layers.28.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 392 |
+
"model.layers.28.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 393 |
+
"model.layers.30.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 394 |
+
"model.layers.31.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 395 |
+
"model.layers.31.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 396 |
+
"model.layers.33.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 397 |
+
"model.layers.35.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
|
| 398 |
+
"model.layers.35.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 399 |
+
"model.layers.35.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 400 |
+
"model.layers.35.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 401 |
+
"model.layers.36.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 402 |
+
"model.layers.36.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 403 |
+
"model.layers.38.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 404 |
+
"model.layers.39.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 405 |
+
"model.layers.39.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 406 |
+
"model.layers.16.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
|
| 407 |
+
"model.layers.16.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
|
| 408 |
+
"model.layers.24.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
|
| 409 |
+
"model.layers.24.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
|
| 410 |
+
"model.layers.11.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
|
| 411 |
+
"model.layers.12.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 412 |
+
"model.layers.12.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 413 |
+
"model.layers.14.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 414 |
+
"model.layers.14.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 415 |
+
"model.layers.15.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 416 |
+
"model.layers.15.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 417 |
+
"model.layers.17.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 418 |
+
"model.layers.19.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 419 |
+
"model.layers.19.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 420 |
+
"model.layers.19.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 421 |
+
"model.layers.20.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 422 |
+
"model.layers.20.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 423 |
+
"model.layers.22.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 424 |
+
"model.layers.24.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 425 |
+
"model.layers.24.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 426 |
+
"model.layers.24.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 427 |
+
"model.layers.25.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 428 |
+
"model.layers.25.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 429 |
+
"model.layers.27.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 430 |
+
"model.layers.35.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
|
| 431 |
+
"model.layers.12.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
|
| 432 |
+
"model.layers.36.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
|
| 433 |
+
"model.layers.36.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
|
| 434 |
+
"model.layers.0.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
|
| 435 |
+
"model.layers.15.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
|
| 436 |
+
"model.layers.20.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
|
| 437 |
+
"model.layers.20.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
|
| 438 |
+
"model.layers.25.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
|
| 439 |
+
"model.layers.25.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
|
| 440 |
+
"model.layers.15.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
|
| 441 |
+
"model.layers.39.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
|
| 442 |
+
"model.layers.39.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
|
| 443 |
+
"model.layers.18.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
|
| 444 |
+
"model.layers.18.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
|
| 445 |
+
"model.layers.23.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
|
| 446 |
+
"model.layers.19.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
|
| 447 |
+
"model.layers.19.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
|
| 448 |
+
"model.layers.26.self_attn.adaptive_phi": "model-00003-of-00003.safetensors"
|
| 449 |
+
}
|
| 450 |
+
}
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-30000/modeling_evabyte.py
ADDED
|
@@ -0,0 +1,912 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional, Tuple, Union
|
| 2 |
+
import math
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
import torch.utils.checkpoint
|
| 6 |
+
from torch import nn
|
| 7 |
+
from torch.nn import CrossEntropyLoss
|
| 8 |
+
from transformers.activations import ACT2FN
|
| 9 |
+
from transformers.cache_utils import Cache
|
| 10 |
+
from transformers.modeling_outputs import (
|
| 11 |
+
BaseModelOutputWithPast,
|
| 12 |
+
CausalLMOutputWithPast,
|
| 13 |
+
)
|
| 14 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 15 |
+
|
| 16 |
+
from .configuration_evabyte import EvaByteConfig
|
| 17 |
+
from .multibyte_decoding_evabyte import MultiByteDecodingMixin
|
| 18 |
+
try:
|
| 19 |
+
import triton
|
| 20 |
+
USE_TRITON_IMPL = True
|
| 21 |
+
from .eva import EvaAttention
|
| 22 |
+
from .eva_agg_kernel import triton_eva_agg_fwd
|
| 23 |
+
from .eva_prep_kv_kernel import triton_eva_prep_kv_fwd
|
| 24 |
+
except ImportError:
|
| 25 |
+
USE_TRITON_IMPL = False
|
| 26 |
+
print("WARNING: triton is not installed, using fallback EVA which might be slow and throw errors")
|
| 27 |
+
from .eva_pt_ref import EvaAttention
|
| 28 |
+
from .eva_cache import EvaCache, EvaStaticCacheForTriton
|
| 29 |
+
|
| 30 |
+
MASK_MIN_VALUE = -10e10
|
| 31 |
+
|
| 32 |
+
def prepare_eva_attention_mask(
|
| 33 |
+
seq_len,
|
| 34 |
+
device,
|
| 35 |
+
chunk_size,
|
| 36 |
+
window_size,
|
| 37 |
+
use_cache=False,
|
| 38 |
+
cache=None
|
| 39 |
+
):
|
| 40 |
+
"""
|
| 41 |
+
Prepare attention masks for EVA.
|
| 42 |
+
|
| 43 |
+
"""
|
| 44 |
+
chunk_causal_mask = None
|
| 45 |
+
window_causal_mask = None
|
| 46 |
+
if use_cache:
|
| 47 |
+
cached_seq_len = cache.get_seq_length()
|
| 48 |
+
total_seq_len = seq_len + cached_seq_len
|
| 49 |
+
# cached_seq_len will be 0 during prefilling
|
| 50 |
+
# padded_seq_len = chunk_size * math.ceil(total_seq_len / chunk_size)
|
| 51 |
+
padded_seq_len = window_size * math.ceil(total_seq_len / window_size)
|
| 52 |
+
num_chunks = padded_seq_len // chunk_size
|
| 53 |
+
else:
|
| 54 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
| 55 |
+
assert seq_len % chunk_size == 0
|
| 56 |
+
num_chunks = seq_len // chunk_size
|
| 57 |
+
|
| 58 |
+
assert seq_len % window_size == 0
|
| 59 |
+
|
| 60 |
+
# create causal mask
|
| 61 |
+
################################
|
| 62 |
+
# generate chunked causal masks
|
| 63 |
+
################################
|
| 64 |
+
# [b, h, j, c, c]
|
| 65 |
+
chunks_per_window = window_size // chunk_size
|
| 66 |
+
if num_chunks >= chunks_per_window:
|
| 67 |
+
chunk_causal_mask = torch.ones(
|
| 68 |
+
(chunk_size, num_chunks, num_chunks),
|
| 69 |
+
device=device,
|
| 70 |
+
dtype=torch.bool
|
| 71 |
+
).triu(0)
|
| 72 |
+
|
| 73 |
+
num_blocks = num_chunks // chunks_per_window
|
| 74 |
+
chunk_causal_mask = chunk_causal_mask.reshape(
|
| 75 |
+
chunk_size,
|
| 76 |
+
num_blocks,
|
| 77 |
+
chunks_per_window,
|
| 78 |
+
num_blocks,
|
| 79 |
+
chunks_per_window
|
| 80 |
+
).transpose(-2, -3)
|
| 81 |
+
|
| 82 |
+
block_diag_zero = (
|
| 83 |
+
torch.eye(num_blocks, device=device, dtype=torch.bool)
|
| 84 |
+
.unsqueeze(-1)
|
| 85 |
+
.unsqueeze(-1)
|
| 86 |
+
.unsqueeze(0)
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
# Set diagonal blocks to zero
|
| 90 |
+
chunk_causal_mask = chunk_causal_mask.masked_fill(block_diag_zero, True)
|
| 91 |
+
|
| 92 |
+
# Reshape back to original size
|
| 93 |
+
chunk_causal_mask = (
|
| 94 |
+
chunk_causal_mask
|
| 95 |
+
.transpose(-2, -3)
|
| 96 |
+
.reshape(chunk_size, num_chunks, num_chunks)
|
| 97 |
+
.transpose(-2, -3)
|
| 98 |
+
.reshape(chunk_size * num_chunks, num_chunks)
|
| 99 |
+
.unsqueeze(0)
|
| 100 |
+
.unsqueeze(0)
|
| 101 |
+
)
|
| 102 |
+
else:
|
| 103 |
+
chunk_causal_mask = torch.ones(
|
| 104 |
+
(1, 1, chunk_size, num_chunks, num_chunks),
|
| 105 |
+
device=device,
|
| 106 |
+
dtype=torch.bool,
|
| 107 |
+
).triu(0).transpose(-2, -3) # [1, 1, c, j, c]
|
| 108 |
+
chunk_causal_mask = chunk_causal_mask.reshape(
|
| 109 |
+
1, 1, chunk_size * num_chunks, num_chunks
|
| 110 |
+
) # [1, 1, n, c]
|
| 111 |
+
|
| 112 |
+
if use_cache:
|
| 113 |
+
chunk_causal_mask = chunk_causal_mask[..., cached_seq_len : cached_seq_len + seq_len, :]
|
| 114 |
+
|
| 115 |
+
window_causal_mask = torch.ones(
|
| 116 |
+
(1, 1, 1, window_size, window_size),
|
| 117 |
+
device=device
|
| 118 |
+
).triu(1).to(torch.bool)
|
| 119 |
+
return (chunk_causal_mask, window_causal_mask)
|
| 120 |
+
|
| 121 |
+
def pad_to_multiple(tensor, multiple, dim=-2, value=0, create_mask=False, left_padding=False):
|
| 122 |
+
assert dim < 0 # only accept ``dim'' index in a reverse manner
|
| 123 |
+
seqlen = int(tensor.shape[dim])
|
| 124 |
+
m = seqlen / multiple
|
| 125 |
+
if m.is_integer():
|
| 126 |
+
if create_mask:
|
| 127 |
+
return tensor, torch.ones(size=(tensor.shape[0], tensor.shape[dim]), dtype=torch.bool, device=tensor.device)
|
| 128 |
+
else:
|
| 129 |
+
return tensor
|
| 130 |
+
remainder = math.ceil(m) * multiple - seqlen
|
| 131 |
+
pad_offset = (0,) * (-1 - dim) * 2
|
| 132 |
+
if left_padding:
|
| 133 |
+
padded_res = F.pad(tensor, (*pad_offset, remainder, 0), value=value)
|
| 134 |
+
else:
|
| 135 |
+
padded_res = F.pad(tensor, (*pad_offset, 0, remainder), value=value)
|
| 136 |
+
if create_mask:
|
| 137 |
+
# assume dim 0 is the batch size
|
| 138 |
+
padding_mask = torch.ones(size=(padded_res.shape[0], padded_res.shape[dim]), dtype=torch.bool, device=padded_res.device)
|
| 139 |
+
if left_padding:
|
| 140 |
+
padding_mask[:, :remainder] = False
|
| 141 |
+
else:
|
| 142 |
+
padding_mask[:, -remainder:] = False
|
| 143 |
+
return padded_res, padding_mask
|
| 144 |
+
else:
|
| 145 |
+
return padded_res
|
| 146 |
+
|
| 147 |
+
class EvaByteRMSNorm(nn.Module):
|
| 148 |
+
def __init__(self, config):
|
| 149 |
+
super().__init__()
|
| 150 |
+
self.config = config
|
| 151 |
+
self.fp32_ln = True
|
| 152 |
+
self.variance_epsilon = config.rms_norm_eps
|
| 153 |
+
self.add_unit_offset = config.norm_add_unit_offset
|
| 154 |
+
if self.add_unit_offset:
|
| 155 |
+
self.weight = nn.Parameter(torch.zeros(config.hidden_size))
|
| 156 |
+
else:
|
| 157 |
+
self.weight = nn.Parameter(torch.ones(config.hidden_size))
|
| 158 |
+
|
| 159 |
+
def forward(self, hidden_states):
|
| 160 |
+
_hidden_states = hidden_states.to(torch.float32 if self.fp32_ln else torch.bfloat16)
|
| 161 |
+
|
| 162 |
+
variance = _hidden_states.pow(2).mean(-1, keepdim=True)
|
| 163 |
+
_hidden_states = _hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
| 164 |
+
if self.add_unit_offset:
|
| 165 |
+
return ((1 + self.weight) * _hidden_states).type_as(hidden_states)
|
| 166 |
+
else:
|
| 167 |
+
return (self.weight * _hidden_states).type_as(hidden_states)
|
| 168 |
+
|
| 169 |
+
class EvaByteRotaryEmbedding(torch.nn.Module):
|
| 170 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
| 171 |
+
super().__init__()
|
| 172 |
+
|
| 173 |
+
self.dim = dim
|
| 174 |
+
self.max_position_embeddings = max_position_embeddings
|
| 175 |
+
self.base = base
|
| 176 |
+
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
|
| 177 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 178 |
+
|
| 179 |
+
self._set_cos_sin_cache(seq_len=max_position_embeddings,
|
| 180 |
+
device=self.inv_freq.device,
|
| 181 |
+
dtype=torch.get_default_dtype())
|
| 182 |
+
|
| 183 |
+
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
| 184 |
+
self.max_seq_len_cached = seq_len
|
| 185 |
+
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
| 186 |
+
|
| 187 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
| 188 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 189 |
+
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
| 190 |
+
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def forward(self, x, seq_len=None):
|
| 194 |
+
# x: [bs, num_attention_heads, seq_len, head_size]
|
| 195 |
+
if seq_len > self.max_seq_len_cached:
|
| 196 |
+
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
| 197 |
+
|
| 198 |
+
# return (
|
| 199 |
+
# self.cos_cached[:seq_len].to(dtype=x.dtype),
|
| 200 |
+
# self.sin_cached[:seq_len].to(dtype=x.dtype),
|
| 201 |
+
# )
|
| 202 |
+
if seq_len < self.max_seq_len_cached:
|
| 203 |
+
cos_slice = self.cos_cached.split(seq_len, dim=0)[0]
|
| 204 |
+
sin_slice = self.sin_cached.split(seq_len, dim=0)[0]
|
| 205 |
+
else:
|
| 206 |
+
cos_slice = self.cos_cached
|
| 207 |
+
sin_slice = self.sin_cached
|
| 208 |
+
|
| 209 |
+
return (
|
| 210 |
+
cos_slice.to(dtype=x.dtype),
|
| 211 |
+
sin_slice.to(dtype=x.dtype),
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
class EvaByteLinearScalingRotaryEmbedding(EvaByteRotaryEmbedding):
|
| 217 |
+
"""EvaByteRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
| 218 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
| 219 |
+
self.scaling_factor = scaling_factor
|
| 220 |
+
super().__init__(dim, max_position_embeddings, base, device)
|
| 221 |
+
|
| 222 |
+
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
| 223 |
+
self.max_seq_len_cached = seq_len
|
| 224 |
+
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
| 225 |
+
t = t / self.scaling_factor
|
| 226 |
+
|
| 227 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
| 228 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
| 229 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 230 |
+
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
| 231 |
+
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
class EvaByteDynamicNTKScalingRotaryEmbedding(EvaByteRotaryEmbedding):
|
| 235 |
+
"""EvaByteRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
| 236 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
| 237 |
+
self.scaling_factor = scaling_factor
|
| 238 |
+
super().__init__(dim, max_position_embeddings, base, device)
|
| 239 |
+
|
| 240 |
+
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
| 241 |
+
self.max_seq_len_cached = seq_len
|
| 242 |
+
|
| 243 |
+
if seq_len > self.max_position_embeddings:
|
| 244 |
+
base = self.base * ((self.scaling_factor * seq_len / self.max_position_embeddings) -
|
| 245 |
+
(self.scaling_factor - 1))**(self.dim / (self.dim - 2))
|
| 246 |
+
inv_freq = 1.0 / (base**(torch.arange(0, self.dim, 2).float().to(device) / self.dim))
|
| 247 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 248 |
+
|
| 249 |
+
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
| 250 |
+
|
| 251 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
| 252 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
| 253 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 254 |
+
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
| 255 |
+
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
class EvaByteMLP(nn.Module):
|
| 259 |
+
def __init__(self, config, layer_idx: int = None):
|
| 260 |
+
super().__init__()
|
| 261 |
+
self.hidden_size = config.hidden_size
|
| 262 |
+
self.intermediate_size = config.intermediate_size
|
| 263 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 264 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 265 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
| 266 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
| 267 |
+
self.layer_idx = layer_idx
|
| 268 |
+
self.config = config
|
| 269 |
+
|
| 270 |
+
def forward(self, x):
|
| 271 |
+
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
| 272 |
+
return down_proj
|
| 273 |
+
|
| 274 |
+
class EvaByteDecoderLayer(nn.Module):
|
| 275 |
+
def __init__(self, config: EvaByteConfig, layer_idx: int = None):
|
| 276 |
+
super().__init__()
|
| 277 |
+
self.config = config
|
| 278 |
+
self.hidden_size = config.hidden_size
|
| 279 |
+
self.self_attn = EvaAttention(config=config, layer_idx=layer_idx)
|
| 280 |
+
self.mlp = EvaByteMLP(config, layer_idx=layer_idx)
|
| 281 |
+
self.input_layernorm = EvaByteRMSNorm(config)
|
| 282 |
+
self.post_attention_layernorm = EvaByteRMSNorm(config)
|
| 283 |
+
|
| 284 |
+
def forward(
|
| 285 |
+
self,
|
| 286 |
+
hidden_states: torch.Tensor,
|
| 287 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 288 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 289 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
| 290 |
+
output_attentions: Optional[bool] = False,
|
| 291 |
+
use_cache: Optional[bool] = False,
|
| 292 |
+
cos: Optional[torch.Tensor] = None,
|
| 293 |
+
sin: Optional[torch.Tensor] = None,
|
| 294 |
+
multibyte_decoding: Optional[bool] = False,
|
| 295 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
| 296 |
+
residual = hidden_states
|
| 297 |
+
if self.config.fp32_skip_add:
|
| 298 |
+
residual = residual.float()
|
| 299 |
+
|
| 300 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 301 |
+
|
| 302 |
+
# Self Attention
|
| 303 |
+
hidden_states, self_attn_weights, present_key_value = self.self_attn(hidden_states=hidden_states,
|
| 304 |
+
attention_mask=attention_mask,
|
| 305 |
+
position_ids=position_ids,
|
| 306 |
+
past_key_value=past_key_value,
|
| 307 |
+
output_attentions=output_attentions,
|
| 308 |
+
use_cache=use_cache,
|
| 309 |
+
cos=cos,
|
| 310 |
+
sin=sin,
|
| 311 |
+
multibyte_decoding=multibyte_decoding)
|
| 312 |
+
hidden_states = (residual + hidden_states).to(hidden_states.dtype)
|
| 313 |
+
|
| 314 |
+
# Fully Connected
|
| 315 |
+
residual = hidden_states
|
| 316 |
+
if self.config.fp32_skip_add:
|
| 317 |
+
residual = residual.float()
|
| 318 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 319 |
+
hidden_states = self.mlp(hidden_states)
|
| 320 |
+
hidden_states = (residual + hidden_states).to(hidden_states.dtype)
|
| 321 |
+
|
| 322 |
+
outputs = (hidden_states, )
|
| 323 |
+
|
| 324 |
+
if output_attentions:
|
| 325 |
+
outputs += (self_attn_weights, )
|
| 326 |
+
|
| 327 |
+
if use_cache:
|
| 328 |
+
outputs += (present_key_value, )
|
| 329 |
+
return outputs
|
| 330 |
+
|
| 331 |
+
class EvaBytePreTrainedModel(PreTrainedModel):
|
| 332 |
+
config_class = EvaByteConfig
|
| 333 |
+
base_model_prefix = "model"
|
| 334 |
+
supports_gradient_checkpointing = True
|
| 335 |
+
_no_split_modules = ["EvaByteDecoderLayer"]
|
| 336 |
+
_skip_keys_device_placement = "past_key_values"
|
| 337 |
+
|
| 338 |
+
def _init_weights(self, module):
|
| 339 |
+
std = getattr(self.config, "initializer_range", 0.02)
|
| 340 |
+
if isinstance(module, nn.Linear):
|
| 341 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 342 |
+
if module.bias is not None:
|
| 343 |
+
module.bias.data.zero_()
|
| 344 |
+
elif isinstance(module, nn.Embedding):
|
| 345 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 346 |
+
if module.padding_idx is not None:
|
| 347 |
+
module.weight.data[module.padding_idx].zero_()
|
| 348 |
+
|
| 349 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
| 350 |
+
if isinstance(module, EvaByteModel):
|
| 351 |
+
module.gradient_checkpointing = value
|
| 352 |
+
|
| 353 |
+
class EvaByteModel(EvaBytePreTrainedModel):
|
| 354 |
+
"""
|
| 355 |
+
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`EvaByteDecoderLayer`]
|
| 356 |
+
|
| 357 |
+
Args:
|
| 358 |
+
config: EvaByteConfig
|
| 359 |
+
"""
|
| 360 |
+
def __init__(self, config: EvaByteConfig):
|
| 361 |
+
super().__init__(config)
|
| 362 |
+
self.padding_idx = config.pad_token_id
|
| 363 |
+
self.vocab_size = config.vocab_size
|
| 364 |
+
self.hidden_size = config.hidden_size
|
| 365 |
+
self.num_heads = config.num_attention_heads
|
| 366 |
+
self.head_dim = self.hidden_size // self.num_heads
|
| 367 |
+
self.max_position_embeddings = self.config.max_position_embeddings
|
| 368 |
+
|
| 369 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 370 |
+
self.layers = nn.ModuleList([EvaByteDecoderLayer(config, layer_idx=layer_idx) for layer_idx in range(config.num_hidden_layers)])
|
| 371 |
+
self.norm = EvaByteRMSNorm(config)
|
| 372 |
+
|
| 373 |
+
self.gradient_checkpointing = False
|
| 374 |
+
self.rope = config.rope_theta
|
| 375 |
+
# Initialize weights and apply final processing
|
| 376 |
+
self.post_init()
|
| 377 |
+
self._init_rope()
|
| 378 |
+
|
| 379 |
+
def _init_rope(self):
|
| 380 |
+
if self.config.rope_scaling is None:
|
| 381 |
+
self.rotary_emb = EvaByteRotaryEmbedding(self.head_dim,
|
| 382 |
+
max_position_embeddings=self.max_position_embeddings,
|
| 383 |
+
base=self.rope)
|
| 384 |
+
else:
|
| 385 |
+
scaling_type = self.config.rope_scaling["type"]
|
| 386 |
+
scaling_factor = self.config.rope_scaling["factor"]
|
| 387 |
+
if scaling_type == "linear":
|
| 388 |
+
self.rotary_emb = EvaByteLinearScalingRotaryEmbedding(
|
| 389 |
+
self.head_dim,
|
| 390 |
+
max_position_embeddings=self.max_position_embeddings,
|
| 391 |
+
scaling_factor=scaling_factor,
|
| 392 |
+
base=self.rope)
|
| 393 |
+
elif scaling_type == "dynamic":
|
| 394 |
+
self.rotary_emb = EvaByteDynamicNTKScalingRotaryEmbedding(
|
| 395 |
+
self.head_dim,
|
| 396 |
+
max_position_embeddings=self.max_position_embeddings,
|
| 397 |
+
scaling_factor=scaling_factor,
|
| 398 |
+
base=self.rope)
|
| 399 |
+
else:
|
| 400 |
+
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
| 401 |
+
|
| 402 |
+
def get_input_embeddings(self):
|
| 403 |
+
return self.embed_tokens
|
| 404 |
+
|
| 405 |
+
def set_input_embeddings(self, value):
|
| 406 |
+
self.embed_tokens = value
|
| 407 |
+
|
| 408 |
+
def _helper_padding_mask(
|
| 409 |
+
self,
|
| 410 |
+
padding_mask,
|
| 411 |
+
causal_mask
|
| 412 |
+
):
|
| 413 |
+
padding_mask = torch.logical_or(padding_mask, padding_mask.transpose(-1, -2))
|
| 414 |
+
return torch.logical_or(padding_mask, causal_mask)
|
| 415 |
+
|
| 416 |
+
def _prepare_eva_generation_attn_mask_triton(
|
| 417 |
+
self,
|
| 418 |
+
attention_mask,
|
| 419 |
+
input_ids,
|
| 420 |
+
use_cache,
|
| 421 |
+
past_key_values
|
| 422 |
+
):
|
| 423 |
+
batch_size, seq_len = input_ids.shape
|
| 424 |
+
if use_cache and past_key_values.get_seq_length() > 0:
|
| 425 |
+
# decoding phase
|
| 426 |
+
if past_key_values.rf_mask[0] is not None:
|
| 427 |
+
cur_rf_mask = torch.zeros(
|
| 428 |
+
(batch_size, 1, seq_len, 1),
|
| 429 |
+
dtype=past_key_values.rf_mask[0].dtype,
|
| 430 |
+
device=past_key_values.rf_mask[0].device
|
| 431 |
+
)
|
| 432 |
+
else:
|
| 433 |
+
cur_rf_mask = None
|
| 434 |
+
|
| 435 |
+
if past_key_values.s_mask[0] is not None:
|
| 436 |
+
cur_s_mask = torch.zeros(
|
| 437 |
+
(batch_size, 1, seq_len, 1),
|
| 438 |
+
dtype=past_key_values.s_mask[0].dtype,
|
| 439 |
+
device=past_key_values.s_mask[0].device
|
| 440 |
+
)
|
| 441 |
+
else:
|
| 442 |
+
cur_s_mask = None
|
| 443 |
+
|
| 444 |
+
seen_tokens = past_key_values.get_seq_length()
|
| 445 |
+
if seen_tokens <= self.config.window_size:
|
| 446 |
+
rfa_chunks_dummy_mask = None
|
| 447 |
+
else:
|
| 448 |
+
if cur_s_mask is not None:
|
| 449 |
+
chunks_per_window = int(self.config.window_size // self.config.chunk_size)
|
| 450 |
+
# the ongoing decoding step would be (seen_seq_len + 1)-th token
|
| 451 |
+
num_windows_seen_so_far = seen_tokens // self.config.window_size
|
| 452 |
+
rfa_chunks_dummy_mask = torch.zeros(
|
| 453 |
+
(batch_size, 1, seq_len, num_windows_seen_so_far * chunks_per_window),
|
| 454 |
+
dtype=past_key_values.s_mask[0].dtype,
|
| 455 |
+
device=past_key_values.s_mask[0].device
|
| 456 |
+
)
|
| 457 |
+
else:
|
| 458 |
+
rfa_chunks_dummy_mask = None
|
| 459 |
+
# rf_mask and cur_mask are 0s because we do not want to mask them
|
| 460 |
+
return (cur_s_mask, cur_rf_mask, rfa_chunks_dummy_mask)
|
| 461 |
+
|
| 462 |
+
if attention_mask is not None and torch.any(attention_mask == 0.0):
|
| 463 |
+
# convert 0 -> padding to 1 -> padding
|
| 464 |
+
padded_attention_mask = pad_to_multiple(
|
| 465 |
+
attention_mask,
|
| 466 |
+
self.config.window_size,
|
| 467 |
+
dim=-1,
|
| 468 |
+
value=0,
|
| 469 |
+
create_mask=False,
|
| 470 |
+
left_padding=False
|
| 471 |
+
)
|
| 472 |
+
# convert 0 -> padding to 1 -> padding
|
| 473 |
+
padded_rf_mask = ~padded_attention_mask.unsqueeze(1).unsqueeze(-1).to(torch.bool) # [b, 1, n, 1]
|
| 474 |
+
# [b, 1, w, j, 1]
|
| 475 |
+
padded_w_attn_mask = padded_rf_mask.reshape(batch_size, 1, -1, self.config.window_size, 1).to(torch.bool)
|
| 476 |
+
# [b, 1, w, j, 1] [b, 1, w, 1, j] -> [b, 1, w, j, j]
|
| 477 |
+
w_padding_mask = torch.logical_or(padded_w_attn_mask, padded_w_attn_mask.transpose(-1, -2))
|
| 478 |
+
w_causal_mask = torch.ones(
|
| 479 |
+
(1, 1, 1, self.config.window_size, self.config.window_size),
|
| 480 |
+
device=input_ids.device
|
| 481 |
+
).triu(1).to(torch.bool)
|
| 482 |
+
s_mask = torch.logical_or(w_padding_mask, w_causal_mask)
|
| 483 |
+
s_mask = s_mask.reshape(batch_size, 1, -1, self.config.window_size)
|
| 484 |
+
s_mask = s_mask[..., :seq_len, :]
|
| 485 |
+
# negate the attention mask to get the padding mask
|
| 486 |
+
rf_mask = ~attention_mask.unsqueeze(1).unsqueeze(-1).to(torch.bool) # [b, 1, n, 1]
|
| 487 |
+
return (s_mask, rf_mask)
|
| 488 |
+
else:
|
| 489 |
+
return (None, None)
|
| 490 |
+
|
| 491 |
+
def _prepare_eva_generation_attn_mask(
|
| 492 |
+
self,
|
| 493 |
+
attention_mask,
|
| 494 |
+
input_ids,
|
| 495 |
+
use_cache,
|
| 496 |
+
past_key_values
|
| 497 |
+
):
|
| 498 |
+
batch_size, seq_len = input_ids.shape
|
| 499 |
+
if use_cache and past_key_values.get_seq_length() > 0:
|
| 500 |
+
# decoding phase
|
| 501 |
+
if past_key_values.rf_mask[0] is not None:
|
| 502 |
+
rf_mask = torch.zeros(
|
| 503 |
+
(batch_size, 1, seq_len, 1),
|
| 504 |
+
dtype=past_key_values.rf_mask[0].dtype,
|
| 505 |
+
device=past_key_values.rf_mask[0].device
|
| 506 |
+
)
|
| 507 |
+
else:
|
| 508 |
+
rf_mask = None
|
| 509 |
+
|
| 510 |
+
cur_causal_mask = torch.zeros(
|
| 511 |
+
(batch_size, 1, seq_len, 1),
|
| 512 |
+
dtype=torch.bool,
|
| 513 |
+
device=input_ids.device
|
| 514 |
+
)
|
| 515 |
+
|
| 516 |
+
chunk_causal_mask = torch.ones(
|
| 517 |
+
(batch_size, 1, seq_len, 1),
|
| 518 |
+
dtype=torch.bool,
|
| 519 |
+
device=input_ids.device
|
| 520 |
+
)
|
| 521 |
+
# chunk_causal_mask are 1s because we will mask them by default and
|
| 522 |
+
# will be unmasked when the current singleton attention is processed over
|
| 523 |
+
return (None, cur_causal_mask, chunk_causal_mask, rf_mask)
|
| 524 |
+
|
| 525 |
+
true_num_chunks = seq_len // self.config.chunk_size
|
| 526 |
+
chunk_causal_mask, _ = prepare_eva_attention_mask(
|
| 527 |
+
seq_len,
|
| 528 |
+
input_ids.device,
|
| 529 |
+
self.config.chunk_size,
|
| 530 |
+
self.config.window_size,
|
| 531 |
+
use_cache=use_cache,
|
| 532 |
+
cache=past_key_values
|
| 533 |
+
)
|
| 534 |
+
chunk_causal_mask = chunk_causal_mask[..., :seq_len, :true_num_chunks]
|
| 535 |
+
if attention_mask is not None and torch.any(attention_mask == 0.0):
|
| 536 |
+
# convert 0 -> padding to 1 -> padding
|
| 537 |
+
rf_mask = ~attention_mask.unsqueeze(1).unsqueeze(-1).to(torch.bool) # [b, 1, n, 1]
|
| 538 |
+
else:
|
| 539 |
+
rf_mask = None
|
| 540 |
+
|
| 541 |
+
if seq_len < self.config.window_size:
|
| 542 |
+
cur_window_mask = torch.ones(
|
| 543 |
+
(1, 1, seq_len, seq_len),
|
| 544 |
+
device=input_ids.device
|
| 545 |
+
).triu(1).to(torch.bool)
|
| 546 |
+
if rf_mask is not None:
|
| 547 |
+
cur_window_mask = self._helper_padding_mask(rf_mask, cur_window_mask)
|
| 548 |
+
prev_window_mask = None
|
| 549 |
+
else:
|
| 550 |
+
if seq_len % self.config.window_size == 0:
|
| 551 |
+
num_windows = seq_len // self.config.window_size
|
| 552 |
+
cur_window_mask = None
|
| 553 |
+
prev_window_mask = torch.ones(
|
| 554 |
+
(1, 1, num_windows, self.config.window_size, self.config.window_size),
|
| 555 |
+
device=input_ids.device
|
| 556 |
+
).triu(1).to(torch.bool)
|
| 557 |
+
if rf_mask is not None:
|
| 558 |
+
prev_rf_mask = rf_mask.reshape(batch_size, 1, -1, self.config.window_size, 1)
|
| 559 |
+
prev_window_mask = self._helper_padding_mask(prev_rf_mask, prev_window_mask)
|
| 560 |
+
else:
|
| 561 |
+
num_windows = seq_len // self.config.window_size
|
| 562 |
+
remainder_tokens = seq_len % self.config.window_size
|
| 563 |
+
cur_window_mask = torch.ones(
|
| 564 |
+
(1, 1, remainder_tokens, remainder_tokens),
|
| 565 |
+
device=input_ids.device
|
| 566 |
+
).triu(1).to(torch.bool)
|
| 567 |
+
prev_window_mask = torch.ones(
|
| 568 |
+
(1, 1, num_windows, self.config.window_size, self.config.window_size),
|
| 569 |
+
device=input_ids.device
|
| 570 |
+
).triu(1).to(torch.bool)
|
| 571 |
+
if rf_mask is not None:
|
| 572 |
+
prev_rf_mask, cur_rf_mask = torch.split(rf_mask, [seq_len - remainder_tokens, remainder_tokens], dim=-2)
|
| 573 |
+
cur_window_mask = self._helper_padding_mask(cur_rf_mask, cur_window_mask)
|
| 574 |
+
prev_rf_mask = prev_rf_mask.reshape(batch_size, 1, -1, self.config.window_size, 1)
|
| 575 |
+
prev_window_mask = self._helper_padding_mask(prev_rf_mask, prev_window_mask)
|
| 576 |
+
|
| 577 |
+
return (prev_window_mask, cur_window_mask, chunk_causal_mask, rf_mask)
|
| 578 |
+
|
| 579 |
+
def forward(
|
| 580 |
+
self,
|
| 581 |
+
input_ids: torch.LongTensor = None,
|
| 582 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 583 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 584 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 585 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 586 |
+
use_cache: Optional[bool] = None,
|
| 587 |
+
output_attentions: Optional[bool] = None,
|
| 588 |
+
output_hidden_states: Optional[bool] = None,
|
| 589 |
+
return_dict: Optional[bool] = None,
|
| 590 |
+
multibyte_decoding: Optional[bool] = None,
|
| 591 |
+
) -> Tuple:
|
| 592 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 593 |
+
output_hidden_states = (output_hidden_states
|
| 594 |
+
if output_hidden_states is not None else self.config.output_hidden_states)
|
| 595 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 596 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 597 |
+
|
| 598 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 599 |
+
raise ValueError(
|
| 600 |
+
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
|
| 601 |
+
)
|
| 602 |
+
|
| 603 |
+
if self.gradient_checkpointing and self.training and use_cache:
|
| 604 |
+
raise ValueError("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
|
| 605 |
+
|
| 606 |
+
batch_size, seq_len = input_ids.shape
|
| 607 |
+
#### Step 0. Hack
|
| 608 |
+
if (not self.training) and (not use_cache) and (not multibyte_decoding):
|
| 609 |
+
# forward-only inference mode.
|
| 610 |
+
# We tweak use_cache to be True to reuse code for generation
|
| 611 |
+
use_cache = True
|
| 612 |
+
device = input_ids.device if input_ids is not None else None
|
| 613 |
+
if position_ids is None:
|
| 614 |
+
position_ids = torch.arange(0, seq_len, device=device, dtype=int).reshape(1, -1).expand(batch_size, -1)
|
| 615 |
+
|
| 616 |
+
#### Step 1. Prepare caches if in inference mode
|
| 617 |
+
if use_cache:
|
| 618 |
+
if past_key_values is not None:
|
| 619 |
+
assert isinstance(past_key_values, Cache)
|
| 620 |
+
else:
|
| 621 |
+
if not USE_TRITON_IMPL:
|
| 622 |
+
past_key_values = EvaCache()
|
| 623 |
+
else:
|
| 624 |
+
past_key_values = EvaStaticCacheForTriton(
|
| 625 |
+
input_ids.shape[0],
|
| 626 |
+
self.config.num_attention_heads,
|
| 627 |
+
self.config.window_size,
|
| 628 |
+
self.config.hidden_size // self.config.num_attention_heads,
|
| 629 |
+
self.config.num_hidden_layers,
|
| 630 |
+
self.embed_tokens.weight.dtype,
|
| 631 |
+
self.embed_tokens.weight.device,
|
| 632 |
+
)
|
| 633 |
+
|
| 634 |
+
if not multibyte_decoding:
|
| 635 |
+
if use_cache:
|
| 636 |
+
if USE_TRITON_IMPL:
|
| 637 |
+
causal_mask = self._prepare_eva_generation_attn_mask_triton(
|
| 638 |
+
attention_mask,
|
| 639 |
+
input_ids,
|
| 640 |
+
use_cache,
|
| 641 |
+
past_key_values
|
| 642 |
+
)
|
| 643 |
+
else:
|
| 644 |
+
causal_mask = self._prepare_eva_generation_attn_mask(
|
| 645 |
+
attention_mask,
|
| 646 |
+
input_ids,
|
| 647 |
+
use_cache,
|
| 648 |
+
past_key_values
|
| 649 |
+
)
|
| 650 |
+
else:
|
| 651 |
+
assert self.training
|
| 652 |
+
assert seq_len % self.config.window_size == 0, "Training is only tested for sequences that are a multiple of window_size"
|
| 653 |
+
# for training, we need to pass in the attention mask
|
| 654 |
+
# usually calculated by _prepare_training_attn_mask()
|
| 655 |
+
causal_mask = attention_mask
|
| 656 |
+
else:
|
| 657 |
+
assert use_cache
|
| 658 |
+
causal_mask = attention_mask
|
| 659 |
+
|
| 660 |
+
if inputs_embeds is None:
|
| 661 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
| 662 |
+
|
| 663 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 664 |
+
max_seq_length = past_seen_tokens + inputs_embeds.shape[1]
|
| 665 |
+
|
| 666 |
+
hidden_states = inputs_embeds
|
| 667 |
+
|
| 668 |
+
if position_ids is None:
|
| 669 |
+
assert not use_cache, "during decoding we must explicitly pass position_ids to the model call"
|
| 670 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
| 671 |
+
position_ids = torch.arange(past_seen_tokens, max_seq_length, device=device, dtype=int).reshape(1, -1).expand(batch_size, -1)
|
| 672 |
+
|
| 673 |
+
cos, sin = self.rotary_emb(hidden_states, seq_len=max_seq_length)
|
| 674 |
+
assert len(cos.shape) == 2, f"cos should be of shape (max_seq_len, head_dim), got {cos.shape} instead"
|
| 675 |
+
assert sin.shape == cos.shape, f"sin should be of shape (max_seq_len, head_dim), got {sin.shape} instead"
|
| 676 |
+
assert len(position_ids.shape) == 2, f"position_ids should be of 2D, got {position_ids.shape} instead"
|
| 677 |
+
cos = cos[position_ids, :]
|
| 678 |
+
sin = sin[position_ids, :]
|
| 679 |
+
cos = cos.unsqueeze(1)
|
| 680 |
+
sin = sin.unsqueeze(1)
|
| 681 |
+
|
| 682 |
+
# decoder layers
|
| 683 |
+
all_hidden_states = () if output_hidden_states else None
|
| 684 |
+
all_self_attns = () if output_attentions else None
|
| 685 |
+
next_decoder_cache = None
|
| 686 |
+
|
| 687 |
+
for decoder_layer in self.layers:
|
| 688 |
+
if output_hidden_states:
|
| 689 |
+
all_hidden_states += (hidden_states, )
|
| 690 |
+
|
| 691 |
+
if self.gradient_checkpointing and self.training:
|
| 692 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
| 693 |
+
decoder_layer.__call__,
|
| 694 |
+
hidden_states,
|
| 695 |
+
causal_mask,
|
| 696 |
+
position_ids,
|
| 697 |
+
past_key_values,
|
| 698 |
+
output_attentions,
|
| 699 |
+
use_cache,
|
| 700 |
+
cos,
|
| 701 |
+
sin,
|
| 702 |
+
multibyte_decoding,
|
| 703 |
+
)
|
| 704 |
+
else:
|
| 705 |
+
layer_outputs = decoder_layer(
|
| 706 |
+
hidden_states,
|
| 707 |
+
attention_mask=causal_mask,
|
| 708 |
+
position_ids=position_ids,
|
| 709 |
+
past_key_value=past_key_values,
|
| 710 |
+
output_attentions=output_attentions,
|
| 711 |
+
use_cache=use_cache,
|
| 712 |
+
cos=cos,
|
| 713 |
+
sin=sin,
|
| 714 |
+
multibyte_decoding=multibyte_decoding,
|
| 715 |
+
)
|
| 716 |
+
|
| 717 |
+
hidden_states = layer_outputs[0]
|
| 718 |
+
|
| 719 |
+
if use_cache:
|
| 720 |
+
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
| 721 |
+
|
| 722 |
+
if output_attentions:
|
| 723 |
+
all_self_attns += (layer_outputs[1], )
|
| 724 |
+
|
| 725 |
+
hidden_states = self.norm(hidden_states)
|
| 726 |
+
|
| 727 |
+
# add hidden states from the last decoder layer
|
| 728 |
+
if output_hidden_states:
|
| 729 |
+
all_hidden_states += (hidden_states, )
|
| 730 |
+
|
| 731 |
+
next_cache = next_decoder_cache if use_cache else None
|
| 732 |
+
if not return_dict:
|
| 733 |
+
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
| 734 |
+
|
| 735 |
+
return BaseModelOutputWithPast(
|
| 736 |
+
last_hidden_state=hidden_states,
|
| 737 |
+
past_key_values=next_cache,
|
| 738 |
+
hidden_states=all_hidden_states,
|
| 739 |
+
attentions=all_self_attns,
|
| 740 |
+
)
|
| 741 |
+
|
| 742 |
+
|
| 743 |
+
class EvaByteForCausalLM(EvaBytePreTrainedModel, MultiByteDecodingMixin):
|
| 744 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 745 |
+
|
| 746 |
+
def __init__(self, config):
|
| 747 |
+
EvaBytePreTrainedModel.__init__(self, config)
|
| 748 |
+
|
| 749 |
+
self.model = EvaByteModel(config)
|
| 750 |
+
self.vocab_size = config.vocab_size
|
| 751 |
+
# define multibyte prediction heads
|
| 752 |
+
if hasattr(config, "num_pred_heads") and config.num_pred_heads > 1:
|
| 753 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size * config.num_pred_heads, bias=False)
|
| 754 |
+
else:
|
| 755 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 756 |
+
|
| 757 |
+
self.post_init()
|
| 758 |
+
|
| 759 |
+
def get_input_embeddings(self):
|
| 760 |
+
return self.model.embed_tokens
|
| 761 |
+
|
| 762 |
+
def set_input_embeddings(self, value):
|
| 763 |
+
self.model.embed_tokens = value
|
| 764 |
+
|
| 765 |
+
def get_output_embeddings(self):
|
| 766 |
+
return self.lm_head
|
| 767 |
+
|
| 768 |
+
def set_output_embeddings(self, new_embeddings):
|
| 769 |
+
self.lm_head = new_embeddings
|
| 770 |
+
|
| 771 |
+
def set_decoder(self, decoder):
|
| 772 |
+
self.model = decoder
|
| 773 |
+
|
| 774 |
+
def get_decoder(self):
|
| 775 |
+
return self.model
|
| 776 |
+
|
| 777 |
+
def forward(
|
| 778 |
+
self,
|
| 779 |
+
input_ids: torch.LongTensor = None,
|
| 780 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 781 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 782 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 783 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 784 |
+
labels: Optional[torch.LongTensor] = None,
|
| 785 |
+
use_cache: Optional[bool] = None,
|
| 786 |
+
output_attentions: Optional[bool] = None,
|
| 787 |
+
output_hidden_states: Optional[bool] = None,
|
| 788 |
+
return_dict: Optional[bool] = None,
|
| 789 |
+
return_all_pred_logits: Optional[bool] = None,
|
| 790 |
+
multibyte_decoding: Optional[bool] = None) -> Union[Tuple, CausalLMOutputWithPast]:
|
| 791 |
+
|
| 792 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 793 |
+
output_hidden_states = (output_hidden_states
|
| 794 |
+
if output_hidden_states is not None else self.config.output_hidden_states)
|
| 795 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 796 |
+
|
| 797 |
+
if input_ids is None:
|
| 798 |
+
assert past_key_values is None
|
| 799 |
+
|
| 800 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
| 801 |
+
outputs = self.model(
|
| 802 |
+
input_ids=input_ids,
|
| 803 |
+
attention_mask=attention_mask,
|
| 804 |
+
position_ids=position_ids,
|
| 805 |
+
past_key_values=past_key_values,
|
| 806 |
+
inputs_embeds=inputs_embeds,
|
| 807 |
+
use_cache=use_cache,
|
| 808 |
+
output_attentions=output_attentions,
|
| 809 |
+
output_hidden_states=output_hidden_states,
|
| 810 |
+
return_dict=return_dict,
|
| 811 |
+
multibyte_decoding=multibyte_decoding,
|
| 812 |
+
)
|
| 813 |
+
|
| 814 |
+
hidden_states = outputs[0]
|
| 815 |
+
|
| 816 |
+
logits = self.lm_head(hidden_states)
|
| 817 |
+
if self.config.fp32_logits:
|
| 818 |
+
logits = logits.float()
|
| 819 |
+
|
| 820 |
+
loss = None
|
| 821 |
+
if labels is not None:
|
| 822 |
+
loss_fct = CrossEntropyLoss(reduction="none")
|
| 823 |
+
if hasattr(self.config, "num_pred_heads") and self.config.num_pred_heads > 1:
|
| 824 |
+
shift_logits = logits.view(logits.shape[0], logits.shape[1], self.config.num_pred_heads, self.config.vocab_size)
|
| 825 |
+
# shift_logits = shift_logits.view(-1, logits.shape[1] * self.config.num_pred_heads, self.config.vocab_size)
|
| 826 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
| 827 |
+
else:
|
| 828 |
+
shift_logits = logits.view(-1, self.config.vocab_size)
|
| 829 |
+
shift_labels = labels.view(-1)
|
| 830 |
+
# Enable model parallelism
|
| 831 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
| 832 |
+
loss = loss_fct(shift_logits, shift_labels)
|
| 833 |
+
|
| 834 |
+
if hasattr(self.config, "num_pred_heads") and self.config.num_pred_heads > 1:
|
| 835 |
+
all_pred_logits = logits.reshape(logits.shape[0], logits.shape[1], self.config.num_pred_heads, self.config.vocab_size)
|
| 836 |
+
|
| 837 |
+
if return_all_pred_logits:
|
| 838 |
+
logits = all_pred_logits
|
| 839 |
+
else:
|
| 840 |
+
logits = all_pred_logits[..., 0, :]
|
| 841 |
+
|
| 842 |
+
if not return_dict:
|
| 843 |
+
output = (logits, ) + outputs[1:]
|
| 844 |
+
return (loss, ) + output if loss is not None else output
|
| 845 |
+
|
| 846 |
+
return CausalLMOutputWithPast(
|
| 847 |
+
loss=loss,
|
| 848 |
+
logits=logits,
|
| 849 |
+
past_key_values=outputs.past_key_values,
|
| 850 |
+
hidden_states=outputs.hidden_states,
|
| 851 |
+
attentions=outputs.attentions,
|
| 852 |
+
)
|
| 853 |
+
|
| 854 |
+
|
| 855 |
+
def prepare_inputs_for_generation(self,
|
| 856 |
+
input_ids,
|
| 857 |
+
past_key_values=None,
|
| 858 |
+
attention_mask=None,
|
| 859 |
+
inputs_embeds=None,
|
| 860 |
+
use_cache=True,
|
| 861 |
+
**kwargs):
|
| 862 |
+
# prefill phase:
|
| 863 |
+
# input_ids: b x s
|
| 864 |
+
# attention_mask: None if no padding or b x s
|
| 865 |
+
# position_ids : b x s
|
| 866 |
+
|
| 867 |
+
# token gen phase:
|
| 868 |
+
# input_ids : b x 1
|
| 869 |
+
# attention_mask: b x 1 x s
|
| 870 |
+
# position_ids: b x 1
|
| 871 |
+
past_length = 0
|
| 872 |
+
if past_key_values is not None:
|
| 873 |
+
assert isinstance(past_key_values, Cache)
|
| 874 |
+
past_length = past_key_values.get_seq_length()
|
| 875 |
+
|
| 876 |
+
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
| 877 |
+
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
|
| 878 |
+
elif past_length < input_ids.shape[1]:
|
| 879 |
+
input_ids = input_ids[:, past_length:]
|
| 880 |
+
|
| 881 |
+
position_ids = kwargs.get("position_ids", None)
|
| 882 |
+
if attention_mask is not None and position_ids is None:
|
| 883 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
| 884 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
| 885 |
+
if past_key_values:
|
| 886 |
+
position_ids = position_ids[:, -input_ids.shape[1]:]
|
| 887 |
+
|
| 888 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
| 889 |
+
if inputs_embeds is not None and past_key_values is None:
|
| 890 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
| 891 |
+
else:
|
| 892 |
+
model_inputs = {"input_ids": input_ids}
|
| 893 |
+
|
| 894 |
+
# must initialize position_ids at each step during GPU inference
|
| 895 |
+
assert position_ids is not None
|
| 896 |
+
model_inputs.update(
|
| 897 |
+
{
|
| 898 |
+
"position_ids": position_ids,
|
| 899 |
+
"past_key_values": past_key_values,
|
| 900 |
+
"use_cache": use_cache,
|
| 901 |
+
"attention_mask": attention_mask,
|
| 902 |
+
}
|
| 903 |
+
)
|
| 904 |
+
return model_inputs
|
| 905 |
+
|
| 906 |
+
@staticmethod
|
| 907 |
+
def _reorder_cache(past_key_values, beam_idx):
|
| 908 |
+
reordered_past = ()
|
| 909 |
+
for layer_past in past_key_values:
|
| 910 |
+
reordered_past += (tuple(
|
| 911 |
+
past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), )
|
| 912 |
+
return reordered_past
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-30000/multibyte_decoding_evabyte.py
ADDED
|
@@ -0,0 +1,881 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# The implementation of multibyte deocidng is largely adapted from
|
| 3 |
+
# Medusa decoding: https://github.com/FasterDecoding/Medusa
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from transformers.generation.stopping_criteria import (
|
| 7 |
+
MaxLengthCriteria,
|
| 8 |
+
StoppingCriteriaList,
|
| 9 |
+
)
|
| 10 |
+
from typing import Union, List
|
| 11 |
+
from .eva_cache import EvaStaticCacheForTriton
|
| 12 |
+
from .eva_prep_kv_kernel import triton_eva_prep_kv_fwd
|
| 13 |
+
|
| 14 |
+
class MultibyteEosTokenCriteria:
|
| 15 |
+
"""
|
| 16 |
+
This class implements a simple stopping criteria to stop generation whenever
|
| 17 |
+
the "end-of-sequence" token is generated in the last `new_tokens` tokens.
|
| 18 |
+
|
| 19 |
+
Adapted from
|
| 20 |
+
https://github.com/huggingface/transformers/blob/main/src/transformers/generation/stopping_criteria.py#L446
|
| 21 |
+
By default, it uses the `model.generation_config.eos_token_id`.
|
| 22 |
+
|
| 23 |
+
Args:
|
| 24 |
+
eos_token_id (`Union[int, List[int]]`):
|
| 25 |
+
The id(s) of the *end-of-sequence* token.
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
def __init__(self, eos_token_ids: Union[int, List[int]]):
|
| 29 |
+
if isinstance(eos_token_ids, int):
|
| 30 |
+
eos_token_ids = [eos_token_ids]
|
| 31 |
+
self.eos_token_ids = eos_token_ids
|
| 32 |
+
|
| 33 |
+
def __call__(self, input_ids: torch.LongTensor, new_tokens: int) -> bool:
|
| 34 |
+
current_input_len = input_ids.shape[-1]
|
| 35 |
+
new_token_ids = input_ids[:, current_input_len - new_tokens:]
|
| 36 |
+
for eos_token_id in self.eos_token_ids:
|
| 37 |
+
if torch.any(new_token_ids == eos_token_id):
|
| 38 |
+
return True
|
| 39 |
+
return False
|
| 40 |
+
|
| 41 |
+
def build_tree(spec):
|
| 42 |
+
nodes_at_depth = []
|
| 43 |
+
nodes_at_depth.append([()]) # Root at depth 1
|
| 44 |
+
|
| 45 |
+
for d in range(1, len(spec) + 1):
|
| 46 |
+
prev_nodes = nodes_at_depth[d - 1]
|
| 47 |
+
spec_list = spec[d - 1]
|
| 48 |
+
current_nodes = []
|
| 49 |
+
for node_idx, node in enumerate(prev_nodes):
|
| 50 |
+
if node_idx < len(spec_list):
|
| 51 |
+
num_children = spec_list[node_idx]
|
| 52 |
+
else:
|
| 53 |
+
num_children = 0
|
| 54 |
+
for child_idx in range(num_children):
|
| 55 |
+
new_node = node + (child_idx,)
|
| 56 |
+
current_nodes.append(new_node)
|
| 57 |
+
nodes_at_depth.append(current_nodes)
|
| 58 |
+
|
| 59 |
+
# Flatten the list of nodes, excluding the root node if desired
|
| 60 |
+
all_nodes = [node for depth_nodes in nodes_at_depth for node in depth_nodes if node]
|
| 61 |
+
return all_nodes
|
| 62 |
+
|
| 63 |
+
evabyte_7b_95 = build_tree(
|
| 64 |
+
[
|
| 65 |
+
[10],
|
| 66 |
+
[10, 8, 2, 2, 1, 1],
|
| 67 |
+
[10, 4, 2, 1, 0, 0, 0, 0, 0, 0, 2, 1, 1, 0, 0, 0, 0, 0, 1],
|
| 68 |
+
[8, 2, 2, 1, 0, 0, 0, 0, 0, 0, 1],
|
| 69 |
+
[6, 2, 1, 1],
|
| 70 |
+
[4, 2, 1, 1],
|
| 71 |
+
[4, 2, 1],
|
| 72 |
+
]
|
| 73 |
+
)
|
| 74 |
+
evabyte_7b_31 = build_tree(
|
| 75 |
+
[
|
| 76 |
+
[4],
|
| 77 |
+
[3, 2, 1, 1],
|
| 78 |
+
[3, 2, 1, 1],
|
| 79 |
+
[2, 1, 1],
|
| 80 |
+
[2, 1],
|
| 81 |
+
[2, 1],
|
| 82 |
+
[2, 1],
|
| 83 |
+
]
|
| 84 |
+
)
|
| 85 |
+
TOPK = 10 # topk for sparse tree (10 is a placeholder and it is sufficient)
|
| 86 |
+
|
| 87 |
+
def pad_path(path, length, pad_value=-2):
|
| 88 |
+
"""
|
| 89 |
+
Pad the given path list with a specific value up to a specified length.
|
| 90 |
+
|
| 91 |
+
Parameters:
|
| 92 |
+
- path (list): The original list that needs padding.
|
| 93 |
+
- length (int): The desired length of the padded list.
|
| 94 |
+
- pad_value (optional, default=-2): The value to use for padding.
|
| 95 |
+
|
| 96 |
+
Returns:
|
| 97 |
+
- list: A new list based on the original path but padded to the desired length.
|
| 98 |
+
|
| 99 |
+
Example:
|
| 100 |
+
>>> pad_path([1,2,3], 5)
|
| 101 |
+
[1, 2, 3, -2, -2]
|
| 102 |
+
|
| 103 |
+
Note:
|
| 104 |
+
If the given path is already longer than the specified length,
|
| 105 |
+
then no padding occurs, and the original path is returned.
|
| 106 |
+
"""
|
| 107 |
+
return path + [pad_value] * (length - len(path))
|
| 108 |
+
|
| 109 |
+
def reset_past_key_values(passed_key_values):
|
| 110 |
+
"""
|
| 111 |
+
Resets the current lengths in the passed key-values to zero.
|
| 112 |
+
|
| 113 |
+
This function is designed to be used during the evaluation of a baseline model.
|
| 114 |
+
It iterates through each layer's key-values and sets their current lengths to zero,
|
| 115 |
+
effectively resetting their state.
|
| 116 |
+
|
| 117 |
+
Args:
|
| 118 |
+
- passed_key_values (list of torch.Tensor): Contains past hidden states and past attention values for each layer.
|
| 119 |
+
|
| 120 |
+
Returns:
|
| 121 |
+
- passed_key_values (list of torch.Tensor): Updated past hidden states and past attention values with reset lengths.
|
| 122 |
+
"""
|
| 123 |
+
for i in range(len(passed_key_values)):
|
| 124 |
+
for j in range(2):
|
| 125 |
+
passed_key_values[i][j].current_length.fill_(0)
|
| 126 |
+
return passed_key_values
|
| 127 |
+
|
| 128 |
+
def get_nucleus_one_token(logit, temperature, top_p):
|
| 129 |
+
"""
|
| 130 |
+
Performs token sampling based on the nucleus (top-p) sampling method.
|
| 131 |
+
|
| 132 |
+
This function selects a token from a given logit distribution using the nucleus sampling strategy.
|
| 133 |
+
It allows for more controlled and diverse generation compared to traditional top-k sampling.
|
| 134 |
+
|
| 135 |
+
Args:
|
| 136 |
+
logit (torch.Tensor): The logits from a language model output, expected to be a 2D tensor (BxC).
|
| 137 |
+
temperature (float): A temperature parameter to control the randomness in sampling.
|
| 138 |
+
Higher values increase diversity, lower values make selections more deterministic.
|
| 139 |
+
top_p (float): The cumulative probability threshold for nucleus sampling.
|
| 140 |
+
It controls the size of the set of high-probability tokens to consider for sampling.
|
| 141 |
+
|
| 142 |
+
Returns:
|
| 143 |
+
torch.Tensor: A tensor containing the indices of the sampled tokens.
|
| 144 |
+
"""
|
| 145 |
+
if top_p >= 1:
|
| 146 |
+
return torch.multinomial(F.softmax(logit / temperature, dim=-1), 1)
|
| 147 |
+
logit = logit / temperature
|
| 148 |
+
probs = torch.softmax(logit, dim=-1)
|
| 149 |
+
sorted_logits, sorted_indices = torch.sort(probs, descending=True)
|
| 150 |
+
cum_probs = torch.cumsum(sorted_logits, dim=-1)
|
| 151 |
+
sorted_indices_to_remove = cum_probs > top_p
|
| 152 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
| 153 |
+
sorted_indices_to_remove[..., 0] = 0
|
| 154 |
+
indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
|
| 155 |
+
logit[indices_to_remove] = float('-inf')
|
| 156 |
+
sampled_tokens = torch.multinomial(F.softmax(logit, dim=-1), 1)
|
| 157 |
+
return sampled_tokens
|
| 158 |
+
|
| 159 |
+
def get_typical_one_token(logit, temperature, posterior_threshold, posterior_alpha):
|
| 160 |
+
"""
|
| 161 |
+
Implements token sampling based on the typical sampling method.
|
| 162 |
+
|
| 163 |
+
This function selects a token from a given logit distribution using the typical sampling strategy,
|
| 164 |
+
aiming to balance between diversity and likelihood in a more nuanced way compared to traditional methods.
|
| 165 |
+
|
| 166 |
+
Args:
|
| 167 |
+
logit (torch.Tensor): The logits from a language model output, expected to be a 2D tensor.
|
| 168 |
+
temperature (float): A parameter to control the randomness in sampling.
|
| 169 |
+
Higher values increase diversity, lower values make selections more deterministic.
|
| 170 |
+
posterior_threshold (float): A threshold to decide the lower bound of probabilities to be considered for sampling.
|
| 171 |
+
posterior_alpha (float): A scaling factor applied to the entropy-based adaptive threshold.
|
| 172 |
+
|
| 173 |
+
Returns:
|
| 174 |
+
torch.Tensor: A tensor containing the indices of the sampled tokens.
|
| 175 |
+
"""
|
| 176 |
+
logit = logit / temperature
|
| 177 |
+
probs = torch.softmax(logit, dim=-1)
|
| 178 |
+
entropy = -torch.sum(
|
| 179 |
+
probs * torch.log(probs + 1e-5), dim=-1
|
| 180 |
+
)
|
| 181 |
+
threshold = torch.minimum(
|
| 182 |
+
torch.ones_like(entropy) * posterior_threshold,
|
| 183 |
+
torch.exp(-entropy) * posterior_alpha,
|
| 184 |
+
)
|
| 185 |
+
indices_to_remove = probs < threshold.unsqueeze(-1)
|
| 186 |
+
logit[indices_to_remove] = float('-inf')
|
| 187 |
+
sampled_tokens = torch.multinomial(F.softmax(logit, dim=-1), 1)
|
| 188 |
+
return sampled_tokens
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def generate_medusa_buffers(medusa_choices, device="cuda"):
|
| 193 |
+
"""
|
| 194 |
+
Generate buffers for the Medusa structure based on the provided choices.
|
| 195 |
+
|
| 196 |
+
Parameters:
|
| 197 |
+
- medusa_choices (list): A nested list representing tree in the Medusa structure.
|
| 198 |
+
- device (str): Device to which the tensors should be moved. Default is "cuda".
|
| 199 |
+
|
| 200 |
+
Returns:
|
| 201 |
+
- dict: A dictionary containing buffers related to the Medusa structure.
|
| 202 |
+
"""
|
| 203 |
+
|
| 204 |
+
# Sort the medusa_choices based on their lengths and then their values
|
| 205 |
+
sorted_medusa_choices = sorted(medusa_choices, key=lambda x: (len(x), x))
|
| 206 |
+
medusa_len = len(sorted_medusa_choices) + 1
|
| 207 |
+
|
| 208 |
+
# Initialize depth_counts to keep track of how many choices have a particular depth
|
| 209 |
+
depth_counts = [0] * max([len(path) for path in sorted_medusa_choices])
|
| 210 |
+
for path in sorted_medusa_choices:
|
| 211 |
+
depth_counts[len(path) - 1] += 1
|
| 212 |
+
|
| 213 |
+
# Create the attention mask for Medusa
|
| 214 |
+
medusa_attn_mask = torch.eye(medusa_len, medusa_len)
|
| 215 |
+
medusa_attn_mask[:, 0] = 1
|
| 216 |
+
start = 0
|
| 217 |
+
for i in range(len(depth_counts)):
|
| 218 |
+
for j in range(depth_counts[i]):
|
| 219 |
+
cur_medusa_choice = sorted_medusa_choices[start + j]
|
| 220 |
+
# retrieve ancestor position
|
| 221 |
+
if len(cur_medusa_choice) == 1:
|
| 222 |
+
continue
|
| 223 |
+
ancestor_idx = []
|
| 224 |
+
for c in range(len(cur_medusa_choice) - 1):
|
| 225 |
+
ancestor_idx.append(sorted_medusa_choices.index(cur_medusa_choice[:c+1]) + 1)
|
| 226 |
+
medusa_attn_mask[j + start + 1, ancestor_idx] = 1
|
| 227 |
+
start += depth_counts[i]
|
| 228 |
+
|
| 229 |
+
# Generate tree indices for the Medusa structure
|
| 230 |
+
medusa_tree_indices = torch.zeros(medusa_len, dtype=torch.long)
|
| 231 |
+
medusa_tree_indices[0] = 0
|
| 232 |
+
start = 0
|
| 233 |
+
for i in range(len(depth_counts)):
|
| 234 |
+
for j in range(depth_counts[i]):
|
| 235 |
+
cur_medusa_choice = sorted_medusa_choices[start + j]
|
| 236 |
+
medusa_tree_indices[start + j + 1] = cur_medusa_choice[-1] + TOPK * i + 1
|
| 237 |
+
start += depth_counts[i]
|
| 238 |
+
|
| 239 |
+
# Generate position IDs for the Medusa structure
|
| 240 |
+
medusa_position_ids = torch.zeros(medusa_len, dtype=torch.long)
|
| 241 |
+
start = 0
|
| 242 |
+
for i in range(len(depth_counts)):
|
| 243 |
+
medusa_position_ids[start + 1: start + depth_counts[i] + 1] = i + 1
|
| 244 |
+
start += depth_counts[i]
|
| 245 |
+
|
| 246 |
+
# Generate retrieval indices for Medusa structure verification
|
| 247 |
+
retrieve_indices_nest = []
|
| 248 |
+
retrieve_paths = []
|
| 249 |
+
for i in range(len(sorted_medusa_choices)):
|
| 250 |
+
cur_medusa_choice = sorted_medusa_choices[-i-1]
|
| 251 |
+
retrieve_indice = []
|
| 252 |
+
if cur_medusa_choice in retrieve_paths:
|
| 253 |
+
continue
|
| 254 |
+
else:
|
| 255 |
+
for c in range(len(cur_medusa_choice)):
|
| 256 |
+
retrieve_indice.append(sorted_medusa_choices.index(cur_medusa_choice[:c+1]))
|
| 257 |
+
retrieve_paths.append(cur_medusa_choice[:c+1])
|
| 258 |
+
retrieve_indices_nest.append(retrieve_indice)
|
| 259 |
+
max_length = max([len(x) for x in retrieve_indices_nest])
|
| 260 |
+
retrieve_indices = [pad_path(path, max_length) for path in retrieve_indices_nest]
|
| 261 |
+
retrieve_indices = torch.tensor(retrieve_indices, dtype=torch.long)
|
| 262 |
+
retrieve_indices = retrieve_indices + 1
|
| 263 |
+
retrieve_indices = torch.cat([torch.zeros((retrieve_indices.shape[0], 1), dtype=torch.long), retrieve_indices], dim=1)
|
| 264 |
+
|
| 265 |
+
# Aggregate the generated buffers into a dictionary
|
| 266 |
+
medusa_buffers = {
|
| 267 |
+
"medusa_attn_mask": medusa_attn_mask.unsqueeze(0).unsqueeze(0),
|
| 268 |
+
"tree_indices": medusa_tree_indices,
|
| 269 |
+
"medusa_position_ids": medusa_position_ids.unsqueeze(0),
|
| 270 |
+
"retrieve_indices": retrieve_indices,
|
| 271 |
+
}
|
| 272 |
+
|
| 273 |
+
# Move the tensors in the dictionary to the specified device
|
| 274 |
+
medusa_buffers = {
|
| 275 |
+
k: v.clone().to(device)
|
| 276 |
+
if isinstance(v, torch.Tensor)
|
| 277 |
+
else torch.tensor(v, device=device)
|
| 278 |
+
for k, v in medusa_buffers.items()
|
| 279 |
+
}
|
| 280 |
+
return medusa_buffers
|
| 281 |
+
|
| 282 |
+
def generate_candidates(
|
| 283 |
+
medusa_logits,
|
| 284 |
+
logits,
|
| 285 |
+
tree_indices,
|
| 286 |
+
retrieve_indices,
|
| 287 |
+
temperature = 0,
|
| 288 |
+
posterior_threshold=0.3,
|
| 289 |
+
posterior_alpha = 0.09,
|
| 290 |
+
top_p=0.8,
|
| 291 |
+
sampling = 'typical',
|
| 292 |
+
fast = False
|
| 293 |
+
):
|
| 294 |
+
# Say we have 3 heads, and the top-4 for each head are:
|
| 295 |
+
# [10, 3, 8, 4]
|
| 296 |
+
# [9, 5, 1, 6]
|
| 297 |
+
# [7, 16, 3, 2]
|
| 298 |
+
|
| 299 |
+
# candidates_id = 10
|
| 300 |
+
if temperature == 0 or fast:
|
| 301 |
+
candidates_ids = torch.argmax(logits[:, -1]).unsqueeze(0)
|
| 302 |
+
else:
|
| 303 |
+
if sampling == 'typical':
|
| 304 |
+
candidates_ids = get_typical_one_token(logits[:, -1], temperature, posterior_threshold, posterior_alpha).squeeze(0)
|
| 305 |
+
elif sampling == 'nucleus':
|
| 306 |
+
candidates_ids = get_nucleus_one_token(logits[:, -1], temperature, top_p).squeeze(0)
|
| 307 |
+
else:
|
| 308 |
+
raise NotImplementedError
|
| 309 |
+
|
| 310 |
+
# this calculates the top-k medusa logits
|
| 311 |
+
# candidates_medusa_id = [
|
| 312 |
+
# [9, 5, 1, 6]
|
| 313 |
+
# [7, 16, 3, 2]
|
| 314 |
+
# ]
|
| 315 |
+
candidates_medusa_ids = torch.topk(medusa_logits[:, 0, -1], TOPK, dim=-1).indices
|
| 316 |
+
|
| 317 |
+
# [10, 9, 5, 1, 6, 7, 16, 3, 2]
|
| 318 |
+
candidate_ids = torch.cat([candidates_ids, candidates_medusa_ids.view(-1)], dim=-1)
|
| 319 |
+
|
| 320 |
+
# based on the pre-defined tree_indices, select the corresponding candidates
|
| 321 |
+
# if we select top-2 and top-3 for the two heads (we select top-1 for the first head):
|
| 322 |
+
# tree_candidates = [10, 9, 5, 7, 16, 3, 7, 16, 3]
|
| 323 |
+
tree_candidate_ids = candidate_ids[tree_indices]
|
| 324 |
+
|
| 325 |
+
# tree_candidate_ids = [10, 9, 5, 7, 16, 3, 7, 16, 3, 0]
|
| 326 |
+
# Sometimes the tree_indices are padded, so we append a zero here
|
| 327 |
+
# so that all padded indices select the appended zero.
|
| 328 |
+
tree_candidate_ids_ext = torch.cat(
|
| 329 |
+
[
|
| 330 |
+
tree_candidate_ids,
|
| 331 |
+
torch.zeros((1), dtype=torch.long, device=tree_candidate_ids.device)
|
| 332 |
+
],
|
| 333 |
+
dim=0
|
| 334 |
+
)
|
| 335 |
+
# [[10, 9, 7], [10, 9, 16], [10, 9, 3], [10, 5, 7], [10, 5, 16], [10, 5, 3]]
|
| 336 |
+
unflattened_candidate_ids = tree_candidate_ids_ext[retrieve_indices]
|
| 337 |
+
|
| 338 |
+
tree_candidate_ids = tree_candidate_ids.unsqueeze(0)
|
| 339 |
+
|
| 340 |
+
return tree_candidate_ids, unflattened_candidate_ids
|
| 341 |
+
|
| 342 |
+
def get_nucleus_posterior_mask(logits, candidates, temperature, top_p):
|
| 343 |
+
"""
|
| 344 |
+
Generates a posterior mask for token candidates using nucleus (top-p) sampling.
|
| 345 |
+
|
| 346 |
+
This function applies nucleus sampling to a set of logits, and then generates a mask indicating
|
| 347 |
+
which candidate tokens are selected. It adapts the sampling strategy to accommodate for
|
| 348 |
+
temperature scaling and cumulative probability thresholding.
|
| 349 |
+
|
| 350 |
+
Args:
|
| 351 |
+
logits (torch.Tensor): A tensor of logits from a language model output.
|
| 352 |
+
candidates (torch.Tensor): A tensor of candidate tokens to compare against sampled tokens.
|
| 353 |
+
temperature (float): A parameter to scale the logits, controlling randomness in sampling.
|
| 354 |
+
top_p (float): The cumulative probability threshold for nucleus sampling.
|
| 355 |
+
|
| 356 |
+
Returns:
|
| 357 |
+
torch.Tensor: A posterior mask indicating which candidate tokens match the sampled tokens.
|
| 358 |
+
"""
|
| 359 |
+
# adapted from https://github.com/huggingface/transformers/blob/18a879f47576822aa1a5c49aecb27d89bfa5fa69/examples/run_generation.py#L79
|
| 360 |
+
|
| 361 |
+
# Apply temperature
|
| 362 |
+
logits = logits[:, :-1] / temperature
|
| 363 |
+
n_samples, n_tokens = logits.shape[0], logits.shape[1]
|
| 364 |
+
logits = logits.view(n_samples*n_tokens, -1)
|
| 365 |
+
if top_p >= 1:
|
| 366 |
+
sampled_tokens = torch.multinomial(F.softmax(logits, dim=-1), 1)
|
| 367 |
+
sampled_tokens = sampled_tokens.view(n_samples, n_tokens)
|
| 368 |
+
posterior_mask = (candidates[:, 1:] == sampled_tokens).int()
|
| 369 |
+
return posterior_mask
|
| 370 |
+
# Convert to probabilities (softmax)
|
| 371 |
+
probs = F.softmax(logits, dim=-1)
|
| 372 |
+
# Sort the probabilities
|
| 373 |
+
sorted_logits, sorted_indices = torch.sort(probs, descending=True)
|
| 374 |
+
|
| 375 |
+
# Compute cumulative probabilities
|
| 376 |
+
cum_probs = torch.cumsum(sorted_logits, dim=-1)
|
| 377 |
+
|
| 378 |
+
# Create mask for the top-p nucleus
|
| 379 |
+
sorted_indices_to_remove = cum_probs > top_p
|
| 380 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
| 381 |
+
sorted_indices_to_remove[..., 0] = 0
|
| 382 |
+
|
| 383 |
+
indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
# Remove low-probability tokens
|
| 387 |
+
logits[indices_to_remove] = float('-inf')
|
| 388 |
+
# Sample from the remaining tokens
|
| 389 |
+
sampled_tokens = torch.multinomial(F.softmax(logits, dim=-1), 1)
|
| 390 |
+
sampled_tokens = sampled_tokens.view(n_samples, n_tokens)
|
| 391 |
+
# Create a mask for selected tokens
|
| 392 |
+
posterior_mask = (candidates[:, 1:] == sampled_tokens).int()
|
| 393 |
+
|
| 394 |
+
return posterior_mask
|
| 395 |
+
|
| 396 |
+
def get_typical_posterior_mask(logits, candidates, temperature, posterior_threshold, posterior_alpha):
|
| 397 |
+
"""
|
| 398 |
+
Args:
|
| 399 |
+
logits (torch.Tensor): A tensor of logits from a language model output.
|
| 400 |
+
candidates (torch.Tensor): A tensor of candidate tokens to compare against sampled tokens.
|
| 401 |
+
temperature (float): A parameter to scale the logits, controlling randomness in sampling.
|
| 402 |
+
posterior_threshold (float): The minimum threshold for probabilities to be considered in sampling.
|
| 403 |
+
posterior_alpha (float): A scaling factor applied to the entropy-based adaptive threshold.
|
| 404 |
+
|
| 405 |
+
Returns:
|
| 406 |
+
torch.Tensor: A posterior mask indicating which candidate tokens match the sampled tokens.
|
| 407 |
+
"""
|
| 408 |
+
logits = logits[:, :-1] / temperature
|
| 409 |
+
n_samples, n_tokens = logits.shape[0], logits.shape[1]
|
| 410 |
+
logits = logits.view(n_samples*n_tokens, -1)
|
| 411 |
+
probs = F.softmax(logits, dim=-1)
|
| 412 |
+
entropy = -torch.sum(
|
| 413 |
+
probs * torch.log(probs + 1e-5), dim=-1
|
| 414 |
+
)
|
| 415 |
+
threshold = torch.minimum(
|
| 416 |
+
torch.ones_like(entropy) * posterior_threshold,
|
| 417 |
+
torch.exp(-entropy) * posterior_alpha,
|
| 418 |
+
)
|
| 419 |
+
indices_to_remove = probs < threshold.unsqueeze(-1)
|
| 420 |
+
logits[indices_to_remove] = float('-inf')
|
| 421 |
+
sampled_tokens = torch.multinomial(F.softmax(logits, dim=-1), 1)
|
| 422 |
+
sampled_tokens = sampled_tokens.view(n_samples, n_tokens)
|
| 423 |
+
posterior_mask = (candidates[:, 1:] == sampled_tokens).int()
|
| 424 |
+
return posterior_mask
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
def evaluate_posterior(
|
| 429 |
+
logits,
|
| 430 |
+
candidates,
|
| 431 |
+
temperature,
|
| 432 |
+
posterior_threshold=0.3,
|
| 433 |
+
posterior_alpha = 0.09,
|
| 434 |
+
top_p=0.8,
|
| 435 |
+
sampling = 'typical',
|
| 436 |
+
fast = True
|
| 437 |
+
):
|
| 438 |
+
if logits.shape[1] <= 1:
|
| 439 |
+
return torch.tensor(0, dtype=torch.long, device=candidates.device), 0
|
| 440 |
+
# Greedy decoding based on temperature value
|
| 441 |
+
if temperature == 0:
|
| 442 |
+
# Find the tokens that match the maximum logits for each position in the sequence
|
| 443 |
+
posterior_mask = (
|
| 444 |
+
candidates[:, 1:] == torch.argmax(logits[:, :-1], dim=-1)
|
| 445 |
+
).int()
|
| 446 |
+
candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
|
| 447 |
+
accept_length = candidates_accept_length.max().item()
|
| 448 |
+
# Choose the best candidate
|
| 449 |
+
if accept_length == 0:
|
| 450 |
+
# Default to the first candidate if none are accepted
|
| 451 |
+
best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
|
| 452 |
+
else:
|
| 453 |
+
best_candidate = torch.argmax(candidates_accept_length).to(torch.long)
|
| 454 |
+
return best_candidate, accept_length
|
| 455 |
+
elif sampling == 'typical':
|
| 456 |
+
if fast:
|
| 457 |
+
posterior_prob = torch.softmax(logits[:, :-1] / temperature, dim=-1)
|
| 458 |
+
candidates_prob = torch.gather(
|
| 459 |
+
posterior_prob, dim=-1, index=candidates[:, 1:].unsqueeze(-1)
|
| 460 |
+
).squeeze(-1)
|
| 461 |
+
posterior_entropy = -torch.sum(
|
| 462 |
+
posterior_prob * torch.log(posterior_prob + 1e-5), dim=-1
|
| 463 |
+
) # torch.sum(torch.log(*)) is faster than torch.prod
|
| 464 |
+
threshold = torch.minimum(
|
| 465 |
+
torch.ones_like(posterior_entropy) * posterior_threshold,
|
| 466 |
+
torch.exp(-posterior_entropy) * posterior_alpha,
|
| 467 |
+
)
|
| 468 |
+
posterior_mask = candidates_prob > threshold
|
| 469 |
+
candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
|
| 470 |
+
|
| 471 |
+
# Choose the best candidate based on the evaluated posterior probabilities
|
| 472 |
+
accept_length = candidates_accept_length.max().item()
|
| 473 |
+
if accept_length == 0:
|
| 474 |
+
# If no candidates are accepted, just choose the first one
|
| 475 |
+
best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
|
| 476 |
+
else:
|
| 477 |
+
best_candidates = torch.where(candidates_accept_length == accept_length)[0]
|
| 478 |
+
# Accept the best one according to likelihood
|
| 479 |
+
likelihood = torch.sum(
|
| 480 |
+
torch.log(candidates_prob[best_candidates, :accept_length]), dim=-1
|
| 481 |
+
)
|
| 482 |
+
best_candidate = best_candidates[torch.argmax(likelihood)]
|
| 483 |
+
return best_candidate, accept_length
|
| 484 |
+
# Calculate posterior probabilities and thresholds for candidate selection
|
| 485 |
+
posterior_mask = get_typical_posterior_mask(logits, candidates, temperature, posterior_threshold, posterior_alpha)
|
| 486 |
+
candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
|
| 487 |
+
# Choose the best candidate based on the evaluated posterior probabilities
|
| 488 |
+
accept_length = candidates_accept_length.max().item()
|
| 489 |
+
|
| 490 |
+
if accept_length == 0:
|
| 491 |
+
# If no candidates are accepted, just choose the first one
|
| 492 |
+
best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
|
| 493 |
+
else:
|
| 494 |
+
best_candidate = torch.argmax(candidates_accept_length).to(torch.long)
|
| 495 |
+
# Accept the best one according to likelihood
|
| 496 |
+
return best_candidate, accept_length
|
| 497 |
+
elif sampling == 'nucleus':
|
| 498 |
+
assert top_p < 1.0 + 1e-6, "top_p should between 0 and 1"
|
| 499 |
+
posterior_mask = get_nucleus_posterior_mask(logits, candidates, temperature, top_p)
|
| 500 |
+
candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
|
| 501 |
+
accept_length = candidates_accept_length.max().item()
|
| 502 |
+
# Choose the best candidate
|
| 503 |
+
if accept_length == 0:
|
| 504 |
+
# Default to the first candidate if none are accepted
|
| 505 |
+
best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
|
| 506 |
+
else:
|
| 507 |
+
best_candidate = torch.argmax(candidates_accept_length).to(torch.long)
|
| 508 |
+
return best_candidate, accept_length
|
| 509 |
+
else:
|
| 510 |
+
raise NotImplementedError
|
| 511 |
+
|
| 512 |
+
def update_inference_inputs(
|
| 513 |
+
input_ids,
|
| 514 |
+
medusa_logits,
|
| 515 |
+
logits,
|
| 516 |
+
candidate_ids,
|
| 517 |
+
best_candidate,
|
| 518 |
+
accept_length,
|
| 519 |
+
):
|
| 520 |
+
input_ids = torch.cat(
|
| 521 |
+
[
|
| 522 |
+
input_ids,
|
| 523 |
+
candidate_ids[None, best_candidate, : accept_length + 1]
|
| 524 |
+
],
|
| 525 |
+
dim=-1
|
| 526 |
+
)
|
| 527 |
+
logits = logits[
|
| 528 |
+
None, best_candidate, accept_length : accept_length + 1
|
| 529 |
+
]
|
| 530 |
+
medusa_logits = medusa_logits[
|
| 531 |
+
:, None, best_candidate, accept_length : accept_length + 1
|
| 532 |
+
]
|
| 533 |
+
# Update the new token counter
|
| 534 |
+
new_token = accept_length + 1
|
| 535 |
+
return input_ids, medusa_logits, logits, new_token
|
| 536 |
+
|
| 537 |
+
def split_logits(full_logits):
|
| 538 |
+
# logits has shape [b, n, heads, vocab_size]
|
| 539 |
+
logits = full_logits[..., 0, :]
|
| 540 |
+
medusa_logits = full_logits[..., 1:, :].permute(2, 0, 1, 3)
|
| 541 |
+
return medusa_logits, logits
|
| 542 |
+
|
| 543 |
+
class MultiByteDecodingMixin:
|
| 544 |
+
def multi_byte_pred_update_cache(
|
| 545 |
+
self,
|
| 546 |
+
past_key_values,
|
| 547 |
+
retrieve_indices,
|
| 548 |
+
best_candidate,
|
| 549 |
+
new_tokens,
|
| 550 |
+
):
|
| 551 |
+
prev_window_len = past_key_values.get_past_window_pos(0)
|
| 552 |
+
select_indices = (
|
| 553 |
+
retrieve_indices[best_candidate, : new_tokens] + prev_window_len
|
| 554 |
+
)
|
| 555 |
+
for layer_idx in range(self.config.num_hidden_layers):
|
| 556 |
+
|
| 557 |
+
past_key_values.update_past_len(new_tokens, layer_idx)
|
| 558 |
+
|
| 559 |
+
past_window_k = past_key_values.past_window_k[layer_idx]
|
| 560 |
+
past_window_v = past_key_values.past_window_v[layer_idx]
|
| 561 |
+
|
| 562 |
+
tgt_window_k = past_window_k[..., select_indices, :]
|
| 563 |
+
tgt_window_v = past_window_v[..., select_indices, :]
|
| 564 |
+
|
| 565 |
+
dst_window_k = past_window_k[..., prev_window_len : prev_window_len + new_tokens, :]
|
| 566 |
+
dst_window_v = past_window_v[..., prev_window_len : prev_window_len + new_tokens, :]
|
| 567 |
+
|
| 568 |
+
dst_window_k.copy_(tgt_window_k, non_blocking=True)
|
| 569 |
+
dst_window_v.copy_(tgt_window_v, non_blocking=True)
|
| 570 |
+
|
| 571 |
+
new_window_len = prev_window_len + new_tokens
|
| 572 |
+
if new_window_len >= self.config.window_size:
|
| 573 |
+
assert new_window_len < 2 * self.config.window_size
|
| 574 |
+
|
| 575 |
+
dump_k = past_window_k[..., :self.config.window_size, :].clone()
|
| 576 |
+
dump_v = past_window_v[..., :self.config.window_size, :].clone()
|
| 577 |
+
|
| 578 |
+
_window_len = new_window_len - self.config.window_size
|
| 579 |
+
|
| 580 |
+
if _window_len > 0:
|
| 581 |
+
new_window_k = past_window_k[..., self.config.window_size : new_window_len, :]
|
| 582 |
+
new_window_v = past_window_v[..., self.config.window_size : new_window_len, :]
|
| 583 |
+
|
| 584 |
+
_dst_window_k = past_window_k[..., : _window_len, :]
|
| 585 |
+
_dst_window_v = past_window_v[..., : _window_len, :]
|
| 586 |
+
|
| 587 |
+
_dst_window_k.copy_(new_window_k, non_blocking=True)
|
| 588 |
+
_dst_window_v.copy_(new_window_v, non_blocking=True)
|
| 589 |
+
|
| 590 |
+
past_key_values.past_window_pos[layer_idx] = _window_len
|
| 591 |
+
else:
|
| 592 |
+
dump_k = None
|
| 593 |
+
dump_v = None
|
| 594 |
+
past_key_values.past_window_pos[layer_idx] = new_window_len
|
| 595 |
+
|
| 596 |
+
if dump_k is not None and dump_v is not None:
|
| 597 |
+
rfa_k, rfa_v = triton_eva_prep_kv_fwd(
|
| 598 |
+
dump_k, dump_v,
|
| 599 |
+
self.model.layers[layer_idx].self_attn.adaptive_mu_k,
|
| 600 |
+
self.model.layers[layer_idx].self_attn.adaptive_phi,
|
| 601 |
+
None,
|
| 602 |
+
self.model.layers[layer_idx].self_attn.head_dim_scaling,
|
| 603 |
+
self.model.layers[layer_idx].self_attn.chunk_size
|
| 604 |
+
)
|
| 605 |
+
rfa_k, rfa_v = past_key_values.update_chunk_rfas(
|
| 606 |
+
rfa_k, rfa_v, layer_idx
|
| 607 |
+
)
|
| 608 |
+
return past_key_values
|
| 609 |
+
|
| 610 |
+
def _multi_byte_pred_update_cache_when_prefil_len_eq_window_size(
|
| 611 |
+
self,
|
| 612 |
+
past_key_values,
|
| 613 |
+
):
|
| 614 |
+
prev_window_len = past_key_values.get_past_window_pos(0)
|
| 615 |
+
for layer_idx in range(self.config.num_hidden_layers):
|
| 616 |
+
|
| 617 |
+
past_window_k = past_key_values.past_window_k[layer_idx]
|
| 618 |
+
past_window_v = past_key_values.past_window_v[layer_idx]
|
| 619 |
+
|
| 620 |
+
new_window_len = prev_window_len
|
| 621 |
+
if new_window_len == self.config.window_size:
|
| 622 |
+
dump_k = past_window_k[..., :self.config.window_size, :].clone()
|
| 623 |
+
dump_v = past_window_v[..., :self.config.window_size, :].clone()
|
| 624 |
+
past_key_values.past_window_pos[layer_idx] = 0
|
| 625 |
+
|
| 626 |
+
if dump_k is not None and dump_v is not None:
|
| 627 |
+
rfa_k, rfa_v = triton_eva_prep_kv_fwd(
|
| 628 |
+
dump_k, dump_v,
|
| 629 |
+
self.model.layers[layer_idx].self_attn.adaptive_mu_k,
|
| 630 |
+
self.model.layers[layer_idx].self_attn.adaptive_phi,
|
| 631 |
+
None,
|
| 632 |
+
self.model.layers[layer_idx].self_attn.head_dim_scaling,
|
| 633 |
+
self.model.layers[layer_idx].self_attn.chunk_size
|
| 634 |
+
)
|
| 635 |
+
rfa_k, rfa_v = past_key_values.update_chunk_rfas(
|
| 636 |
+
rfa_k, rfa_v, layer_idx
|
| 637 |
+
)
|
| 638 |
+
return past_key_values
|
| 639 |
+
|
| 640 |
+
def multi_byte_pred_update_attn_mask(
|
| 641 |
+
self,
|
| 642 |
+
last_iter_new_tokens,
|
| 643 |
+
tree_candidate_ids,
|
| 644 |
+
past_attn_mask,
|
| 645 |
+
medusa_attn_mask,
|
| 646 |
+
past_key_values,
|
| 647 |
+
):
|
| 648 |
+
batch_size, tree_candidate_len = tree_candidate_ids.shape
|
| 649 |
+
seen_tokens = past_key_values.get_seq_length()
|
| 650 |
+
# NOTE: past_key_values has been updated so now
|
| 651 |
+
# seen_tokens incldues new tokens from the last tree iteration
|
| 652 |
+
assert seen_tokens > 0
|
| 653 |
+
# so one iteration would not cross two windows
|
| 654 |
+
assert last_iter_new_tokens < self.config.window_size
|
| 655 |
+
|
| 656 |
+
if past_attn_mask is not None and seen_tokens < self.config.window_size:
|
| 657 |
+
past_attn_mask = torch.cat(
|
| 658 |
+
[
|
| 659 |
+
past_attn_mask,
|
| 660 |
+
torch.ones(
|
| 661 |
+
[batch_size, 1, tree_candidate_len, last_iter_new_tokens],
|
| 662 |
+
dtype=torch.bool,
|
| 663 |
+
device=self.device
|
| 664 |
+
)
|
| 665 |
+
],
|
| 666 |
+
dim=-1
|
| 667 |
+
)
|
| 668 |
+
else:
|
| 669 |
+
# we initialize attn mask each time when
|
| 670 |
+
# 1. the model crosses the window bounary, or
|
| 671 |
+
# 2. after prefilling
|
| 672 |
+
chunks_per_window = int(self.config.window_size // self.config.chunk_size)
|
| 673 |
+
|
| 674 |
+
window_tokens = seen_tokens % self.config.window_size
|
| 675 |
+
num_windows_seen_so_far = seen_tokens // self.config.window_size
|
| 676 |
+
attn_mask_len = num_windows_seen_so_far * chunks_per_window + window_tokens
|
| 677 |
+
past_attn_mask = torch.ones(
|
| 678 |
+
(batch_size, 1, tree_candidate_len, attn_mask_len),
|
| 679 |
+
dtype=torch.bool,
|
| 680 |
+
device=self.device
|
| 681 |
+
)
|
| 682 |
+
|
| 683 |
+
# note that 1 indicates the position is not masked
|
| 684 |
+
tree_attn_mask = torch.cat(
|
| 685 |
+
[
|
| 686 |
+
past_attn_mask,
|
| 687 |
+
medusa_attn_mask.to(torch.bool)
|
| 688 |
+
],
|
| 689 |
+
dim=-1
|
| 690 |
+
)
|
| 691 |
+
return tree_attn_mask, past_attn_mask
|
| 692 |
+
|
| 693 |
+
@torch.no_grad()
|
| 694 |
+
def multi_byte_generate(
|
| 695 |
+
self,
|
| 696 |
+
input_ids,
|
| 697 |
+
attention_mask=None,
|
| 698 |
+
temperature=0.0,
|
| 699 |
+
max_length=None,
|
| 700 |
+
max_new_tokens=None,
|
| 701 |
+
stopping_criteria=None,
|
| 702 |
+
posterior_threshold=0.09,
|
| 703 |
+
posterior_alpha=0.3,
|
| 704 |
+
top_p=0.8,
|
| 705 |
+
sampling='typical',
|
| 706 |
+
fast=True,
|
| 707 |
+
do_sample=False,
|
| 708 |
+
medusa_choices=None,
|
| 709 |
+
return_acc_lengths=False
|
| 710 |
+
):
|
| 711 |
+
if do_sample or temperature > 0.0:
|
| 712 |
+
fast = False
|
| 713 |
+
|
| 714 |
+
### Prepare `max_length` depending on other stopping criteria.
|
| 715 |
+
if max_new_tokens is not None:
|
| 716 |
+
max_length = max_new_tokens + input_ids.shape[-1]
|
| 717 |
+
elif max_new_tokens is None and max_length is None:
|
| 718 |
+
max_length = getattr(self.config, "max_position_embeddings", 32768)
|
| 719 |
+
|
| 720 |
+
### Set up stopping criteria
|
| 721 |
+
eos_stop_criteria = MultibyteEosTokenCriteria(self.generation_config.eos_token_id)
|
| 722 |
+
stop_criteria = StoppingCriteriaList()
|
| 723 |
+
if max_length is not None:
|
| 724 |
+
max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
|
| 725 |
+
stop_criteria.append(
|
| 726 |
+
MaxLengthCriteria(
|
| 727 |
+
max_length=max_length,
|
| 728 |
+
max_position_embeddings=max_position_embeddings,
|
| 729 |
+
)
|
| 730 |
+
)
|
| 731 |
+
if stopping_criteria is not None and len(stopping_criteria) > 0:
|
| 732 |
+
stop_criteria.extend(stopping_criteria)
|
| 733 |
+
|
| 734 |
+
assert input_ids.shape[0] == 1, "Only support batch size 1 for now"
|
| 735 |
+
assert attention_mask is None, "Only support attention mask None for now"
|
| 736 |
+
# Avoid modifying the input_ids in-place
|
| 737 |
+
input_ids = input_ids.clone()
|
| 738 |
+
position_ids = torch.arange(0, input_ids.shape[1], device=self.device, dtype=int).reshape(1, -1)
|
| 739 |
+
|
| 740 |
+
####################################################
|
| 741 |
+
# 0. initialize the medusa buffers
|
| 742 |
+
####################################################
|
| 743 |
+
if medusa_choices is None:
|
| 744 |
+
medusa_choices = evabyte_7b_95
|
| 745 |
+
medusa_buffers = generate_medusa_buffers(
|
| 746 |
+
medusa_choices, device=self.device
|
| 747 |
+
)
|
| 748 |
+
|
| 749 |
+
past_key_values = EvaStaticCacheForTriton(
|
| 750 |
+
input_ids.shape[0],
|
| 751 |
+
self.config.num_attention_heads,
|
| 752 |
+
# we add 256 to allow tree ids
|
| 753 |
+
self.config.window_size + 256,
|
| 754 |
+
self.config.hidden_size // self.config.num_attention_heads,
|
| 755 |
+
self.config.num_hidden_layers,
|
| 756 |
+
self.lm_head.weight.dtype,
|
| 757 |
+
self.lm_head.weight.device,
|
| 758 |
+
)
|
| 759 |
+
# prefill to get medusa logits and logits
|
| 760 |
+
full_logits, past_key_values = self.forward(
|
| 761 |
+
input_ids,
|
| 762 |
+
attention_mask=attention_mask,
|
| 763 |
+
position_ids=position_ids,
|
| 764 |
+
use_cache=True,
|
| 765 |
+
past_key_values=past_key_values,
|
| 766 |
+
return_all_pred_logits=True,
|
| 767 |
+
multibyte_decoding=False,
|
| 768 |
+
)
|
| 769 |
+
# handles an edge case where the prefill length == window_size
|
| 770 |
+
# we force the previous window to be dumped into RFA chunks
|
| 771 |
+
past_key_values = self._multi_byte_pred_update_cache_when_prefil_len_eq_window_size(
|
| 772 |
+
past_key_values
|
| 773 |
+
)
|
| 774 |
+
medusa_logits, logits = split_logits(full_logits)
|
| 775 |
+
|
| 776 |
+
past_attn_mask = None
|
| 777 |
+
last_iter_new_tokens = 0
|
| 778 |
+
max_iters = 32768
|
| 779 |
+
if return_acc_lengths:
|
| 780 |
+
acc_lengths = []
|
| 781 |
+
for _ in range(max_iters):
|
| 782 |
+
####################################################
|
| 783 |
+
# 1. generate candidate_ids with topk predictions from Medusa heads
|
| 784 |
+
####################################################
|
| 785 |
+
tree_candidate_ids, unflattened_candidate_ids = generate_candidates(
|
| 786 |
+
medusa_logits,
|
| 787 |
+
logits,
|
| 788 |
+
medusa_buffers["tree_indices"],
|
| 789 |
+
medusa_buffers["retrieve_indices"],
|
| 790 |
+
temperature=temperature,
|
| 791 |
+
posterior_alpha=posterior_alpha,
|
| 792 |
+
posterior_threshold=posterior_threshold,
|
| 793 |
+
top_p=top_p,
|
| 794 |
+
sampling=sampling,
|
| 795 |
+
fast=fast,
|
| 796 |
+
)
|
| 797 |
+
|
| 798 |
+
####################################################
|
| 799 |
+
# 2. Build the medusa attention mask and position ids
|
| 800 |
+
####################################################
|
| 801 |
+
# NOTE: 1 indicates the position is not masked
|
| 802 |
+
medusa_attn_mask, past_attn_mask = self.multi_byte_pred_update_attn_mask(
|
| 803 |
+
last_iter_new_tokens,
|
| 804 |
+
tree_candidate_ids,
|
| 805 |
+
past_attn_mask,
|
| 806 |
+
medusa_buffers["medusa_attn_mask"],
|
| 807 |
+
past_key_values,
|
| 808 |
+
)
|
| 809 |
+
medusa_position_ids = medusa_buffers["medusa_position_ids"] + input_ids.shape[1]
|
| 810 |
+
|
| 811 |
+
####################################################
|
| 812 |
+
# 3. tree decoding
|
| 813 |
+
####################################################
|
| 814 |
+
tree_full_logits, past_key_values = self.forward(
|
| 815 |
+
tree_candidate_ids,
|
| 816 |
+
past_key_values=past_key_values,
|
| 817 |
+
attention_mask=medusa_attn_mask,
|
| 818 |
+
position_ids=medusa_position_ids,
|
| 819 |
+
return_all_pred_logits=True,
|
| 820 |
+
multibyte_decoding=True,
|
| 821 |
+
)
|
| 822 |
+
_medusa_logits, _logits = split_logits(tree_full_logits)
|
| 823 |
+
medusa_logits = _medusa_logits[..., 0, medusa_buffers["retrieve_indices"], :]
|
| 824 |
+
logits = _logits[..., 0, medusa_buffers["retrieve_indices"], :]
|
| 825 |
+
|
| 826 |
+
####################################################
|
| 827 |
+
# 4. candidate selection
|
| 828 |
+
####################################################
|
| 829 |
+
# if the current iteration, with tree tokens, crosses window
|
| 830 |
+
# boundaries, trim the condidate_ids to be within the window
|
| 831 |
+
# so that those exceeded tokens (which will be inaccurate)
|
| 832 |
+
# will not be considered
|
| 833 |
+
tree_depth = unflattened_candidate_ids.shape[-1]
|
| 834 |
+
if tree_depth + past_key_values.get_past_window_pos(0) > self.config.window_size:
|
| 835 |
+
max_acc_len = self.config.window_size - past_key_values.get_past_window_pos(0)
|
| 836 |
+
_trimmed_unflattened_candidate_ids = unflattened_candidate_ids[:, :max_acc_len]
|
| 837 |
+
_trimmed_logits = logits[:, :max_acc_len]
|
| 838 |
+
else:
|
| 839 |
+
_trimmed_unflattened_candidate_ids = unflattened_candidate_ids
|
| 840 |
+
_trimmed_logits = logits
|
| 841 |
+
best_candidate, accept_length = evaluate_posterior(
|
| 842 |
+
_trimmed_logits,
|
| 843 |
+
_trimmed_unflattened_candidate_ids,
|
| 844 |
+
temperature,
|
| 845 |
+
posterior_threshold,
|
| 846 |
+
posterior_alpha,
|
| 847 |
+
top_p=top_p,
|
| 848 |
+
sampling=sampling,
|
| 849 |
+
fast=fast
|
| 850 |
+
)
|
| 851 |
+
|
| 852 |
+
####################################################
|
| 853 |
+
# 5. update model inputs and caches
|
| 854 |
+
####################################################
|
| 855 |
+
input_ids, medusa_logits, logits, last_iter_new_tokens = update_inference_inputs(
|
| 856 |
+
input_ids,
|
| 857 |
+
medusa_logits,
|
| 858 |
+
logits,
|
| 859 |
+
unflattened_candidate_ids,
|
| 860 |
+
best_candidate,
|
| 861 |
+
accept_length,
|
| 862 |
+
)
|
| 863 |
+
|
| 864 |
+
past_key_values = self.multi_byte_pred_update_cache(
|
| 865 |
+
past_key_values,
|
| 866 |
+
medusa_buffers["retrieve_indices"],
|
| 867 |
+
best_candidate,
|
| 868 |
+
last_iter_new_tokens,
|
| 869 |
+
)
|
| 870 |
+
|
| 871 |
+
if return_acc_lengths:
|
| 872 |
+
acc_lengths.append(last_iter_new_tokens)
|
| 873 |
+
if stop_criteria(input_ids, None) or eos_stop_criteria(input_ids, last_iter_new_tokens):
|
| 874 |
+
if return_acc_lengths:
|
| 875 |
+
return input_ids, acc_lengths
|
| 876 |
+
else:
|
| 877 |
+
return input_ids
|
| 878 |
+
if return_acc_lengths:
|
| 879 |
+
return input_ids, acc_lengths
|
| 880 |
+
else:
|
| 881 |
+
return input_ids
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-30000/preprocessor_config.json
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"auto_map": {
|
| 3 |
+
"AutoImageProcessor": "image_processing_evabyte.EvaByteImageProcessor",
|
| 4 |
+
"AutoProcessor": "processing_evabyte.EvaByteProcessor"
|
| 5 |
+
},
|
| 6 |
+
"do_convert_rgb": true,
|
| 7 |
+
"do_resize": true,
|
| 8 |
+
"image_processor_type": "EvaByteImageProcessor",
|
| 9 |
+
"jpeg_quality": 25,
|
| 10 |
+
"jpeg_restart_marker_blocks": 1,
|
| 11 |
+
"jpeg_streamtype": 2,
|
| 12 |
+
"jpeg_subsampling": "4:2:0",
|
| 13 |
+
"processor_class": "EvaByteProcessor",
|
| 14 |
+
"resample": 1,
|
| 15 |
+
"size": {
|
| 16 |
+
"longest_edge": 384
|
| 17 |
+
}
|
| 18 |
+
}
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-30000/processing_evabyte.py
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
"""
|
| 3 |
+
Processor class for EvaByte.
|
| 4 |
+
"""
|
| 5 |
+
import base64
|
| 6 |
+
from io import BytesIO
|
| 7 |
+
|
| 8 |
+
import requests
|
| 9 |
+
import os
|
| 10 |
+
import PIL
|
| 11 |
+
from PIL import Image
|
| 12 |
+
|
| 13 |
+
from typing import List, Optional, Union
|
| 14 |
+
|
| 15 |
+
from transformers.feature_extraction_utils import BatchFeature
|
| 16 |
+
from transformers.image_utils import ImageInput, is_valid_image
|
| 17 |
+
from transformers.processing_utils import ProcessorMixin
|
| 18 |
+
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
| 19 |
+
from transformers.utils import TensorType, to_py_obj
|
| 20 |
+
|
| 21 |
+
def fetch_image(image: Union[str, "PIL.Image.Image"]) -> Image.Image:
|
| 22 |
+
image_obj = None
|
| 23 |
+
if isinstance(image, Image.Image):
|
| 24 |
+
image_obj = image
|
| 25 |
+
elif image.startswith("http://") or image.startswith("https://"):
|
| 26 |
+
image_obj = Image.open(BytesIO(requests.get(image, timeout=None).content))
|
| 27 |
+
elif os.path.isfile(image):
|
| 28 |
+
image_obj = Image.open(image)
|
| 29 |
+
elif image.startswith("data:image/"):
|
| 30 |
+
image = image.split(",")[1]
|
| 31 |
+
# Try to load as base64
|
| 32 |
+
try:
|
| 33 |
+
b64 = base64.decodebytes(image.encode())
|
| 34 |
+
image = PIL.Image.open(BytesIO(b64))
|
| 35 |
+
except Exception as e:
|
| 36 |
+
raise ValueError(
|
| 37 |
+
f"Incorrect image source. Must be a valid URL starting with `http://` or `https://`, a valid path to an image file, or a base64 encoded string. Got {image}. Failed with {e}"
|
| 38 |
+
)
|
| 39 |
+
else:
|
| 40 |
+
image_obj = Image.open(image)
|
| 41 |
+
if image_obj is None:
|
| 42 |
+
raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
|
| 43 |
+
|
| 44 |
+
return image_obj
|
| 45 |
+
|
| 46 |
+
def is_url(val) -> bool:
|
| 47 |
+
return isinstance(val, str) and val.startswith("http")
|
| 48 |
+
|
| 49 |
+
def is_file(val) -> bool:
|
| 50 |
+
return isinstance(val, str) and os.path.isfile(val)
|
| 51 |
+
|
| 52 |
+
def is_image_or_image_url(elem):
|
| 53 |
+
return is_url(elem) or is_valid_image(elem) or is_file(elem)
|
| 54 |
+
|
| 55 |
+
vl_chat_template = """
|
| 56 |
+
{{- bos_token }}
|
| 57 |
+
{%- if messages[0]['role'] == 'system' %}
|
| 58 |
+
{%- set system_message = messages[0]['content'] %}
|
| 59 |
+
{%- set messages = messages[1:] %}
|
| 60 |
+
{%- else %}
|
| 61 |
+
{%- set system_message = "" %}
|
| 62 |
+
{%- endif %}
|
| 63 |
+
|
| 64 |
+
{{- '<|start_header_id|>system<|end_header_id|>\n\n' + system_message + '<|eot_id|>'}}
|
| 65 |
+
|
| 66 |
+
{%- for message in messages %}
|
| 67 |
+
{%- if (message['role'] != 'user') and (message['role'] != 'assistant') %}
|
| 68 |
+
{{- raise_exception('Conversation roles must be user or assistant') }}
|
| 69 |
+
{%- endif %}
|
| 70 |
+
|
| 71 |
+
{%- if message['content'] is string %}
|
| 72 |
+
{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] + '<|eot_id|>' }}
|
| 73 |
+
{%- else %}
|
| 74 |
+
{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' }}
|
| 75 |
+
{%- for content in message['content'] %}
|
| 76 |
+
{%- if content['type'] == 'image' %}
|
| 77 |
+
{{- '<image_placeholder>\n' }}
|
| 78 |
+
{%- elif content['type'] == 'text' %}
|
| 79 |
+
{{- content['text'] }}
|
| 80 |
+
{%- endif %}
|
| 81 |
+
{%- endfor %}
|
| 82 |
+
{{- '<|eot_id|>' }}
|
| 83 |
+
{%- endif %}
|
| 84 |
+
{%- endfor %}
|
| 85 |
+
|
| 86 |
+
{%- if add_generation_prompt %}
|
| 87 |
+
{{- '<|start_header_id|>' + 'assistant' + '<|end_header_id|>\n\n' }}
|
| 88 |
+
{%- endif %}
|
| 89 |
+
"""
|
| 90 |
+
|
| 91 |
+
class EvaByteProcessor(ProcessorMixin):
|
| 92 |
+
r"""
|
| 93 |
+
Constructs a EvaByte processor which wraps a EvaByte image processor and a EvaByte tokenizer into a single processor.
|
| 94 |
+
|
| 95 |
+
[`EvaByteProcessor`] offers all the functionalities of [`EvaByteImageProcessor`] and [`EvaByteTokenizer`]. See the
|
| 96 |
+
[`~EvaByteProcessor.__call__`] and [`~EvaByteProcessor.decode`] for more information.
|
| 97 |
+
|
| 98 |
+
Args:
|
| 99 |
+
image_processor ([`EvaByteImageProcessor`], *optional*):
|
| 100 |
+
The image processor is a required input.
|
| 101 |
+
tokenizer ([`EvaByteTokenizer`], *optional*):
|
| 102 |
+
The tokenizer is a required input.
|
| 103 |
+
"""
|
| 104 |
+
|
| 105 |
+
attributes = ["image_processor", "tokenizer"]
|
| 106 |
+
image_processor_class = "AutoImageProcessor"
|
| 107 |
+
tokenizer_class = "AutoTokenizer"
|
| 108 |
+
|
| 109 |
+
def __init__(self, image_processor=None, tokenizer=None, **kwargs):
|
| 110 |
+
if image_processor is None:
|
| 111 |
+
raise ValueError("You need to specify an `image_processor`.")
|
| 112 |
+
if tokenizer is None:
|
| 113 |
+
raise ValueError("You need to specify a `tokenizer`.")
|
| 114 |
+
|
| 115 |
+
super().__init__(image_processor, tokenizer)
|
| 116 |
+
self.t2v_token_id = self.tokenizer.convert_tokens_to_ids("<t2v_token>")
|
| 117 |
+
self.v2t_token_id = self.tokenizer.convert_tokens_to_ids("<v2t_token>")
|
| 118 |
+
self.image_placeholder = "<image_placeholder>"
|
| 119 |
+
self.vl_chat_template = vl_chat_template
|
| 120 |
+
|
| 121 |
+
def __call__(
|
| 122 |
+
self,
|
| 123 |
+
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
|
| 124 |
+
images: ImageInput = None,
|
| 125 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
| 126 |
+
strip_ending_sentinel: bool = False,
|
| 127 |
+
encode_only: bool = False,
|
| 128 |
+
**kwargs
|
| 129 |
+
) -> Union[BatchFeature, List[List[int]]]:
|
| 130 |
+
# processing pipeline:
|
| 131 |
+
# 1. read images or videos from paths
|
| 132 |
+
# 2. use image_processor to convert images / videos to byte streams
|
| 133 |
+
if images is not None:
|
| 134 |
+
if isinstance(images, bytes):
|
| 135 |
+
image_bytes_list = [[images]]
|
| 136 |
+
elif isinstance(images, list) and isinstance(images[0], bytes):
|
| 137 |
+
image_bytes_list = [images]
|
| 138 |
+
elif isinstance(images, list) and isinstance(images[0], list) and isinstance(images[0][0], bytes):
|
| 139 |
+
image_bytes_list = images
|
| 140 |
+
else:
|
| 141 |
+
if is_image_or_image_url(images):
|
| 142 |
+
images = [[images]]
|
| 143 |
+
elif isinstance(images, list) and is_image_or_image_url(images[0]):
|
| 144 |
+
images = [images]
|
| 145 |
+
elif (
|
| 146 |
+
not isinstance(images, list)
|
| 147 |
+
and not isinstance(images[0], list)
|
| 148 |
+
and not is_image_or_image_url(images[0][0])
|
| 149 |
+
):
|
| 150 |
+
raise ValueError(
|
| 151 |
+
"Invalid input images. Please provide a single image or a list of images or a list of list of images."
|
| 152 |
+
)
|
| 153 |
+
# Load images if they are URLs
|
| 154 |
+
images = [[fetch_image(im) if is_url(im) or is_file(im) else im for im in sample] for sample in images]
|
| 155 |
+
image_bytes_list = self.image_processor(images=images, **kwargs)
|
| 156 |
+
|
| 157 |
+
if not isinstance(text, list):
|
| 158 |
+
text = [text]
|
| 159 |
+
assert len(text) == 1, "Only support batch size 1 for now"
|
| 160 |
+
assert len(text) == len(image_bytes_list), "text and image_bytes_list must have the same length"
|
| 161 |
+
# TODO: invoke SequenceFeatureExtractor to get batched inputs
|
| 162 |
+
|
| 163 |
+
# 3. tokenize the text and put images / videos byte streams into the placeholders
|
| 164 |
+
# surrounded by special tokens like "<image>" and "</image>"
|
| 165 |
+
batch_input_ids = []
|
| 166 |
+
if not encode_only:
|
| 167 |
+
batch_attention_mask = []
|
| 168 |
+
else:
|
| 169 |
+
batch_attention_mask = None
|
| 170 |
+
|
| 171 |
+
for t, image_bytes in zip(text, image_bytes_list):
|
| 172 |
+
text_splits = t.split(self.image_placeholder)
|
| 173 |
+
if len(text_splits) != len(image_bytes) + 1:
|
| 174 |
+
raise ValueError(
|
| 175 |
+
f"The number of image tokens should be equal to the number of images, "
|
| 176 |
+
f"but got {len(text_splits)} and {len(image_bytes) + 1}"
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
input_ids = [self.tokenizer.bos_token_id]
|
| 180 |
+
for i, text_part in enumerate(text_splits):
|
| 181 |
+
# each text part must be non-empty because we added markers around placeholders
|
| 182 |
+
split_tokens = self.tokenizer.encode(text_part, add_special_tokens=False)
|
| 183 |
+
input_ids.extend(split_tokens)
|
| 184 |
+
# Add image bytes after each text part except the last one
|
| 185 |
+
if i < len(image_bytes):
|
| 186 |
+
input_ids.append(self.t2v_token_id)
|
| 187 |
+
input_ids.extend([b + self.tokenizer.offset for b in image_bytes[i]])
|
| 188 |
+
input_ids.append(self.v2t_token_id)
|
| 189 |
+
|
| 190 |
+
if strip_ending_sentinel and (input_ids[-1] in [self.t2v_token_id, self.v2t_token_id]):
|
| 191 |
+
input_ids = input_ids[:-1]
|
| 192 |
+
|
| 193 |
+
batch_input_ids.append(input_ids)
|
| 194 |
+
if not encode_only:
|
| 195 |
+
batch_attention_mask.append([1] * len(input_ids))
|
| 196 |
+
|
| 197 |
+
if not encode_only:
|
| 198 |
+
# 4. return batch of features
|
| 199 |
+
inputs = BatchFeature({
|
| 200 |
+
"input_ids": batch_input_ids,
|
| 201 |
+
"attention_mask": batch_attention_mask
|
| 202 |
+
}, tensor_type=return_tensors)
|
| 203 |
+
return inputs
|
| 204 |
+
# # Pad sequences
|
| 205 |
+
# padded_inputs = self.tokenizer.pad(
|
| 206 |
+
# {"input_ids": batch_input_ids},
|
| 207 |
+
# padding=True,
|
| 208 |
+
# return_attention_mask=True,
|
| 209 |
+
# return_tensors=return_tensors,
|
| 210 |
+
# )
|
| 211 |
+
# return BatchFeature(data=padded_inputs)
|
| 212 |
+
else:
|
| 213 |
+
return batch_input_ids
|
| 214 |
+
|
| 215 |
+
def image_tokens_to_bytes(self, image_token_ids, jpeg_quality=None):
|
| 216 |
+
image_bytes = bytes([token_id - self.tokenizer.offset for token_id in image_token_ids])
|
| 217 |
+
image_bytes = self.image_processor.jpeg_merge_qtables(image_bytes, jpeg_quality)
|
| 218 |
+
return image_bytes
|
| 219 |
+
|
| 220 |
+
def batch_decode(self, sequences, **kwargs):
|
| 221 |
+
"""
|
| 222 |
+
This method forwards all its arguments to EvaByteTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
|
| 223 |
+
refer to the docstring of this method for more information.
|
| 224 |
+
"""
|
| 225 |
+
rets = [self.decode(seq, **kwargs) for seq in sequences]
|
| 226 |
+
return tuple(map(list, zip(*rets)))
|
| 227 |
+
|
| 228 |
+
def decode(self, token_ids, **kwargs):
|
| 229 |
+
"""
|
| 230 |
+
Decodes a sequence of input_ids, handling image tokens separately.
|
| 231 |
+
Returns a tuple of (decoded_text, images), where images is a list of bytes.
|
| 232 |
+
"""
|
| 233 |
+
if kwargs and "jpeg_quality" in kwargs:
|
| 234 |
+
kwargs = kwargs.copy()
|
| 235 |
+
jpeg_quality = kwargs.pop("jpeg_quality")
|
| 236 |
+
else:
|
| 237 |
+
jpeg_quality = None
|
| 238 |
+
|
| 239 |
+
token_ids = to_py_obj(token_ids)
|
| 240 |
+
# Find indices of t2v_token_id and v2t_token_id
|
| 241 |
+
t2v_indices = [i for i, token_id in enumerate(token_ids) if token_id == self.t2v_token_id]
|
| 242 |
+
v2t_indices = [i for i, token_id in enumerate(token_ids) if token_id == self.v2t_token_id]
|
| 243 |
+
|
| 244 |
+
# Check for correct pairing of t2v and v2t tokens
|
| 245 |
+
if len(t2v_indices) != len(v2t_indices):
|
| 246 |
+
raise ValueError("Mismatched number of t2v and v2t tokens in token_ids: {} and {}".format(t2v_indices, v2t_indices))
|
| 247 |
+
|
| 248 |
+
# Ensure t2v and v2t tokens are in the correct order
|
| 249 |
+
for t2v_idx, v2t_idx in zip(t2v_indices, v2t_indices):
|
| 250 |
+
if t2v_idx >= v2t_idx:
|
| 251 |
+
raise ValueError("Found t2v_token_id after v2t_token_id in token_ids")
|
| 252 |
+
|
| 253 |
+
# Initialize the start index
|
| 254 |
+
images = []
|
| 255 |
+
decoded_text = ""
|
| 256 |
+
|
| 257 |
+
start = 0
|
| 258 |
+
# Iterate over pairs of t2v and v2t indices
|
| 259 |
+
for t2v_idx, v2t_idx in zip(t2v_indices, v2t_indices):
|
| 260 |
+
# Decode text tokens before the image
|
| 261 |
+
text_token_ids = token_ids[start:t2v_idx]
|
| 262 |
+
if len(text_token_ids) > 0:
|
| 263 |
+
decoded_text += self.tokenizer.decode(text_token_ids, **kwargs)
|
| 264 |
+
|
| 265 |
+
# Insert image placeholder
|
| 266 |
+
decoded_text += self.image_placeholder
|
| 267 |
+
|
| 268 |
+
# Extract image tokens and convert them to bytes
|
| 269 |
+
image_token_ids = token_ids[t2v_idx + 1 : v2t_idx]
|
| 270 |
+
image_bytes = self.image_tokens_to_bytes(image_token_ids, jpeg_quality)
|
| 271 |
+
images.append(image_bytes)
|
| 272 |
+
|
| 273 |
+
# Update the start index to the token after v2t_token_id
|
| 274 |
+
start = v2t_idx + 1
|
| 275 |
+
|
| 276 |
+
# Decode any remaining text tokens after the last image
|
| 277 |
+
if start < len(token_ids):
|
| 278 |
+
text_token_ids = token_ids[start:]
|
| 279 |
+
decoded_text += self.tokenizer.decode(text_token_ids, **kwargs)
|
| 280 |
+
|
| 281 |
+
return decoded_text, images
|
| 282 |
+
|
| 283 |
+
@property
|
| 284 |
+
def model_input_names(self):
|
| 285 |
+
tokenizer_input_names = self.tokenizer.model_input_names
|
| 286 |
+
image_processor_input_names = self.image_processor.model_input_names
|
| 287 |
+
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-30000/processor_config.json
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"auto_map": {
|
| 3 |
+
"AutoProcessor": "processing_evabyte.EvaByteProcessor"
|
| 4 |
+
},
|
| 5 |
+
"processor_class": "EvaByteProcessor"
|
| 6 |
+
}
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-30000/special_tokens_map.json
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"additional_special_tokens": [
|
| 3 |
+
"<repo_name>",
|
| 4 |
+
"<file_sep>",
|
| 5 |
+
"<t2v_token>",
|
| 6 |
+
"<v2t_token>",
|
| 7 |
+
"<|start_header_id|>",
|
| 8 |
+
"<|end_header_id|>",
|
| 9 |
+
"<|eot_id|>",
|
| 10 |
+
"<extra_id_12>",
|
| 11 |
+
"<extra_id_13>",
|
| 12 |
+
"<extra_id_14>",
|
| 13 |
+
"<extra_id_15>",
|
| 14 |
+
"<extra_id_16>",
|
| 15 |
+
"<extra_id_17>",
|
| 16 |
+
"<extra_id_18>",
|
| 17 |
+
"<extra_id_19>",
|
| 18 |
+
"<extra_id_20>",
|
| 19 |
+
"<extra_id_21>",
|
| 20 |
+
"<extra_id_22>",
|
| 21 |
+
"<extra_id_23>",
|
| 22 |
+
"<extra_id_24>",
|
| 23 |
+
"<extra_id_25>",
|
| 24 |
+
"<extra_id_26>",
|
| 25 |
+
"<extra_id_27>",
|
| 26 |
+
"<extra_id_28>",
|
| 27 |
+
"<extra_id_29>",
|
| 28 |
+
"<extra_id_30>",
|
| 29 |
+
"<extra_id_31>",
|
| 30 |
+
"<extra_id_32>",
|
| 31 |
+
"<extra_id_33>",
|
| 32 |
+
"<extra_id_34>",
|
| 33 |
+
"<extra_id_35>",
|
| 34 |
+
"<extra_id_36>",
|
| 35 |
+
"<extra_id_37>",
|
| 36 |
+
"<extra_id_38>",
|
| 37 |
+
"<extra_id_39>",
|
| 38 |
+
"<extra_id_40>",
|
| 39 |
+
"<extra_id_41>",
|
| 40 |
+
"<extra_id_42>",
|
| 41 |
+
"<extra_id_43>",
|
| 42 |
+
"<extra_id_44>",
|
| 43 |
+
"<extra_id_45>",
|
| 44 |
+
"<extra_id_46>",
|
| 45 |
+
"<extra_id_47>",
|
| 46 |
+
"<extra_id_48>",
|
| 47 |
+
"<extra_id_49>",
|
| 48 |
+
"<extra_id_50>",
|
| 49 |
+
"<extra_id_51>",
|
| 50 |
+
"<extra_id_52>",
|
| 51 |
+
"<extra_id_53>",
|
| 52 |
+
"<extra_id_54>",
|
| 53 |
+
"<extra_id_55>",
|
| 54 |
+
"<extra_id_56>",
|
| 55 |
+
"<extra_id_57>",
|
| 56 |
+
"<extra_id_58>",
|
| 57 |
+
"<extra_id_59>",
|
| 58 |
+
"<extra_id_60>",
|
| 59 |
+
"<extra_id_61>",
|
| 60 |
+
"<extra_id_62>",
|
| 61 |
+
"<extra_id_63>"
|
| 62 |
+
],
|
| 63 |
+
"bos_token": {
|
| 64 |
+
"content": "<bos>",
|
| 65 |
+
"lstrip": false,
|
| 66 |
+
"normalized": true,
|
| 67 |
+
"rstrip": false,
|
| 68 |
+
"single_word": false
|
| 69 |
+
},
|
| 70 |
+
"eos_token": {
|
| 71 |
+
"content": "<eos>",
|
| 72 |
+
"lstrip": false,
|
| 73 |
+
"normalized": true,
|
| 74 |
+
"rstrip": false,
|
| 75 |
+
"single_word": false
|
| 76 |
+
},
|
| 77 |
+
"pad_token": {
|
| 78 |
+
"content": "<pad>",
|
| 79 |
+
"lstrip": false,
|
| 80 |
+
"normalized": true,
|
| 81 |
+
"rstrip": false,
|
| 82 |
+
"single_word": false
|
| 83 |
+
},
|
| 84 |
+
"sep_token": {
|
| 85 |
+
"content": "<sep>",
|
| 86 |
+
"lstrip": false,
|
| 87 |
+
"normalized": true,
|
| 88 |
+
"rstrip": false,
|
| 89 |
+
"single_word": false
|
| 90 |
+
},
|
| 91 |
+
"unk_token": {
|
| 92 |
+
"content": "<unk>",
|
| 93 |
+
"lstrip": false,
|
| 94 |
+
"normalized": true,
|
| 95 |
+
"rstrip": false,
|
| 96 |
+
"single_word": false
|
| 97 |
+
}
|
| 98 |
+
}
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-30000/tokenization_evabyte.py
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
|
| 3 |
+
""" Tokenization class for model EvaByte."""
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
from typing import List, Optional, Tuple
|
| 7 |
+
|
| 8 |
+
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
|
| 9 |
+
from transformers.utils import logging
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
logger = logging.get_logger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
chat_template = """
|
| 16 |
+
{{- bos_token }}
|
| 17 |
+
{%- if messages[0]['role'] == 'system' %}
|
| 18 |
+
{%- set system_message = messages[0]['content'] %}
|
| 19 |
+
{%- set messages = messages[1:] %}
|
| 20 |
+
{%- else %}
|
| 21 |
+
{%- set system_message = "" %}
|
| 22 |
+
{%- endif %}
|
| 23 |
+
|
| 24 |
+
{{- '<|start_header_id|>system<|end_header_id|>\n\n' + system_message + '<|eot_id|>'}}
|
| 25 |
+
|
| 26 |
+
{%- for message in messages %}
|
| 27 |
+
{%- if (message['role'] != 'user') and (message['role'] != 'assistant') %}
|
| 28 |
+
{{- raise_exception('Conversation roles must be user or assistant') }}
|
| 29 |
+
{%- endif %}
|
| 30 |
+
|
| 31 |
+
{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] + '<|eot_id|>' }}
|
| 32 |
+
{%- endfor %}
|
| 33 |
+
|
| 34 |
+
{%- if add_generation_prompt %}
|
| 35 |
+
{{- '<|start_header_id|>' + 'assistant' + '<|end_header_id|>\n\n' }}
|
| 36 |
+
{%- endif %}
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
class EvaByteTokenizer(PreTrainedTokenizer):
|
| 40 |
+
def __init__(
|
| 41 |
+
self,
|
| 42 |
+
bos_token="<bos>",
|
| 43 |
+
eos_token="<eos>",
|
| 44 |
+
unk_token="<unk>",
|
| 45 |
+
sep_token="<sep>",
|
| 46 |
+
pad_token="<pad>",
|
| 47 |
+
extra_ids=59,
|
| 48 |
+
additional_special_tokens=None,
|
| 49 |
+
clean_up_tokenization_spaces=False,
|
| 50 |
+
**kwargs,
|
| 51 |
+
) -> None:
|
| 52 |
+
num_base_special_tokens = 5
|
| 53 |
+
# Add extra_ids to the special token list
|
| 54 |
+
if extra_ids > 0 and additional_special_tokens is None:
|
| 55 |
+
additional_special_tokens = [f"<extra_id_{i}>" for i in range(num_base_special_tokens, extra_ids + num_base_special_tokens)]
|
| 56 |
+
elif extra_ids > 0 and additional_special_tokens is not None and len(additional_special_tokens) > 0:
|
| 57 |
+
# Check that we have the right number of extra_id special tokens
|
| 58 |
+
extra_tokens = len(set(filter(lambda x: bool("extra_id" in str(x)), additional_special_tokens)))
|
| 59 |
+
if extra_tokens != extra_ids:
|
| 60 |
+
raise ValueError(
|
| 61 |
+
f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are"
|
| 62 |
+
" provided to EvaByteTokenizer. In this case the additional_special_tokens must include the"
|
| 63 |
+
" extra_ids tokens"
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
#### override some reserved tokens to support chat template
|
| 67 |
+
for i, token in enumerate(additional_special_tokens):
|
| 68 |
+
if token == "<extra_id_5>":
|
| 69 |
+
token = "<repo_name>"
|
| 70 |
+
elif token == "<extra_id_6>":
|
| 71 |
+
token = "<file_sep>"
|
| 72 |
+
elif token == "<extra_id_7>":
|
| 73 |
+
token = "<t2v_token>"
|
| 74 |
+
elif token == "<extra_id_8>":
|
| 75 |
+
token = "<v2t_token>"
|
| 76 |
+
elif token == "<extra_id_9>":
|
| 77 |
+
token = "<|start_header_id|>"
|
| 78 |
+
elif token == "<extra_id_10>":
|
| 79 |
+
token = "<|end_header_id|>"
|
| 80 |
+
elif token == "<extra_id_11>":
|
| 81 |
+
token = "<|eot_id|>"
|
| 82 |
+
additional_special_tokens[i] = token
|
| 83 |
+
|
| 84 |
+
# lstrip and rstrip are set to False because we don't want to strip the whitespace from the special tokens
|
| 85 |
+
# this would be important for the byte tokenizer
|
| 86 |
+
pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
|
| 87 |
+
bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
|
| 88 |
+
eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
|
| 89 |
+
unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
|
| 90 |
+
sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token
|
| 91 |
+
|
| 92 |
+
self._added_tokens_decoder = {
|
| 93 |
+
0: pad_token,
|
| 94 |
+
1: bos_token,
|
| 95 |
+
2: eos_token,
|
| 96 |
+
3: unk_token, # unk_token is a placeholder
|
| 97 |
+
4: sep_token,
|
| 98 |
+
**{i: AddedToken(t, lstrip=False, rstrip=False) for i, t in enumerate(additional_special_tokens, start=num_base_special_tokens)},
|
| 99 |
+
}
|
| 100 |
+
self.offset = len(self._added_tokens_decoder)
|
| 101 |
+
self._utf_vocab_size = 2**8 # utf is 8 bits
|
| 102 |
+
self.add_bos_token = True
|
| 103 |
+
self.add_eos_token = False
|
| 104 |
+
super().__init__(
|
| 105 |
+
pad_token=pad_token,
|
| 106 |
+
bos_token=bos_token,
|
| 107 |
+
eos_token=eos_token,
|
| 108 |
+
unk_token=unk_token,
|
| 109 |
+
sep_token=sep_token,
|
| 110 |
+
extra_ids=0,
|
| 111 |
+
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
| 112 |
+
additional_special_tokens=additional_special_tokens,
|
| 113 |
+
**kwargs,
|
| 114 |
+
)
|
| 115 |
+
self.chat_template = chat_template
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
@property
|
| 119 |
+
def vocab_size(self):
|
| 120 |
+
return self._utf_vocab_size
|
| 121 |
+
|
| 122 |
+
def get_vocab(self):
|
| 123 |
+
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size + self.offset)}
|
| 124 |
+
vocab.update(self.added_tokens_encoder)
|
| 125 |
+
return vocab
|
| 126 |
+
|
| 127 |
+
# Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.build_inputs_with_special_tokens
|
| 128 |
+
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
|
| 129 |
+
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
|
| 130 |
+
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
|
| 131 |
+
|
| 132 |
+
output = bos_token_id + token_ids_0 + eos_token_id
|
| 133 |
+
|
| 134 |
+
if token_ids_1 is not None:
|
| 135 |
+
output = output + bos_token_id + token_ids_1 + eos_token_id
|
| 136 |
+
|
| 137 |
+
return output
|
| 138 |
+
|
| 139 |
+
# Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.get_special_tokens_mask
|
| 140 |
+
def get_special_tokens_mask(
|
| 141 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
| 142 |
+
) -> List[int]:
|
| 143 |
+
"""
|
| 144 |
+
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
|
| 145 |
+
special tokens using the tokenizer `prepare_for_model` method.
|
| 146 |
+
|
| 147 |
+
Args:
|
| 148 |
+
token_ids_0 (`List[int]`):
|
| 149 |
+
List of IDs.
|
| 150 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 151 |
+
Optional second list of IDs for sequence pairs.
|
| 152 |
+
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
|
| 153 |
+
Whether or not the token list is already formatted with special tokens for the model.
|
| 154 |
+
|
| 155 |
+
Returns:
|
| 156 |
+
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
| 157 |
+
"""
|
| 158 |
+
if already_has_special_tokens:
|
| 159 |
+
return super().get_special_tokens_mask(
|
| 160 |
+
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
bos_token_id = [1] if self.add_bos_token else []
|
| 164 |
+
eos_token_id = [1] if self.add_eos_token else []
|
| 165 |
+
|
| 166 |
+
if token_ids_1 is None:
|
| 167 |
+
return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
|
| 168 |
+
return (
|
| 169 |
+
bos_token_id
|
| 170 |
+
+ ([0] * len(token_ids_0))
|
| 171 |
+
+ eos_token_id
|
| 172 |
+
+ bos_token_id
|
| 173 |
+
+ ([0] * len(token_ids_1))
|
| 174 |
+
+ eos_token_id
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
# Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.create_token_type_ids_from_sequences
|
| 178 |
+
def create_token_type_ids_from_sequences(
|
| 179 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
| 180 |
+
) -> List[int]:
|
| 181 |
+
"""
|
| 182 |
+
Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
|
| 183 |
+
sequence pair mask has the following format:
|
| 184 |
+
|
| 185 |
+
```
|
| 186 |
+
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
|
| 187 |
+
| first sequence | second sequence |
|
| 188 |
+
```
|
| 189 |
+
|
| 190 |
+
if token_ids_1 is None, only returns the first portion of the mask (0s).
|
| 191 |
+
|
| 192 |
+
Args:
|
| 193 |
+
token_ids_0 (`List[int]`):
|
| 194 |
+
List of ids.
|
| 195 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 196 |
+
Optional second list of IDs for sequence pairs.
|
| 197 |
+
|
| 198 |
+
Returns:
|
| 199 |
+
`List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
|
| 200 |
+
"""
|
| 201 |
+
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
|
| 202 |
+
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
|
| 203 |
+
|
| 204 |
+
output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
|
| 205 |
+
|
| 206 |
+
if token_ids_1 is not None:
|
| 207 |
+
output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
|
| 208 |
+
|
| 209 |
+
return output
|
| 210 |
+
|
| 211 |
+
def _tokenize(self, text: str) -> List[str]:
|
| 212 |
+
"""Take as input a string and return a list of strings (tokens) for words/sub-words"""
|
| 213 |
+
tokens = [chr(i) for i in text.encode("utf-8")]
|
| 214 |
+
return tokens
|
| 215 |
+
|
| 216 |
+
def _convert_token_to_id(self, token):
|
| 217 |
+
"""Converts a token (str) in an id using the vocab."""
|
| 218 |
+
|
| 219 |
+
if len(token) != 1:
|
| 220 |
+
token_id = None
|
| 221 |
+
else:
|
| 222 |
+
token_id = ord(token) + self.offset
|
| 223 |
+
|
| 224 |
+
return token_id
|
| 225 |
+
|
| 226 |
+
def _convert_id_to_token(self, index):
|
| 227 |
+
"""Converts an index (integer) to a byte (str) using the vocab."""
|
| 228 |
+
token = chr(index - self.offset)
|
| 229 |
+
return token
|
| 230 |
+
|
| 231 |
+
def convert_tokens_to_string(self, tokens):
|
| 232 |
+
"""Converts a sequence of bytes (string) to a single string."""
|
| 233 |
+
bstring = b""
|
| 234 |
+
for token in tokens:
|
| 235 |
+
if token in self.added_tokens_decoder:
|
| 236 |
+
tok_string = self.added_tokens_decoder[token].encode("utf-8")
|
| 237 |
+
elif token in self.added_tokens_encoder:
|
| 238 |
+
tok_string = token.encode("utf-8")
|
| 239 |
+
else:
|
| 240 |
+
tok_string = bytes([ord(token)])
|
| 241 |
+
bstring += tok_string
|
| 242 |
+
string = bstring.decode("utf-8", errors="ignore")
|
| 243 |
+
return string
|
| 244 |
+
|
| 245 |
+
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
| 246 |
+
return ()
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-30000/tokenizer_config.json
ADDED
|
@@ -0,0 +1,596 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"added_tokens_decoder": {
|
| 3 |
+
"0": {
|
| 4 |
+
"content": "<pad>",
|
| 5 |
+
"lstrip": false,
|
| 6 |
+
"normalized": true,
|
| 7 |
+
"rstrip": false,
|
| 8 |
+
"single_word": false,
|
| 9 |
+
"special": true
|
| 10 |
+
},
|
| 11 |
+
"1": {
|
| 12 |
+
"content": "<bos>",
|
| 13 |
+
"lstrip": false,
|
| 14 |
+
"normalized": true,
|
| 15 |
+
"rstrip": false,
|
| 16 |
+
"single_word": false,
|
| 17 |
+
"special": true
|
| 18 |
+
},
|
| 19 |
+
"2": {
|
| 20 |
+
"content": "<eos>",
|
| 21 |
+
"lstrip": false,
|
| 22 |
+
"normalized": true,
|
| 23 |
+
"rstrip": false,
|
| 24 |
+
"single_word": false,
|
| 25 |
+
"special": true
|
| 26 |
+
},
|
| 27 |
+
"3": {
|
| 28 |
+
"content": "<unk>",
|
| 29 |
+
"lstrip": false,
|
| 30 |
+
"normalized": true,
|
| 31 |
+
"rstrip": false,
|
| 32 |
+
"single_word": false,
|
| 33 |
+
"special": true
|
| 34 |
+
},
|
| 35 |
+
"4": {
|
| 36 |
+
"content": "<sep>",
|
| 37 |
+
"lstrip": false,
|
| 38 |
+
"normalized": true,
|
| 39 |
+
"rstrip": false,
|
| 40 |
+
"single_word": false,
|
| 41 |
+
"special": true
|
| 42 |
+
},
|
| 43 |
+
"5": {
|
| 44 |
+
"content": "<repo_name>",
|
| 45 |
+
"lstrip": false,
|
| 46 |
+
"normalized": true,
|
| 47 |
+
"rstrip": false,
|
| 48 |
+
"single_word": false,
|
| 49 |
+
"special": false
|
| 50 |
+
},
|
| 51 |
+
"6": {
|
| 52 |
+
"content": "<file_sep>",
|
| 53 |
+
"lstrip": false,
|
| 54 |
+
"normalized": true,
|
| 55 |
+
"rstrip": false,
|
| 56 |
+
"single_word": false,
|
| 57 |
+
"special": false
|
| 58 |
+
},
|
| 59 |
+
"7": {
|
| 60 |
+
"content": "<t2v_token>",
|
| 61 |
+
"lstrip": false,
|
| 62 |
+
"normalized": true,
|
| 63 |
+
"rstrip": false,
|
| 64 |
+
"single_word": false,
|
| 65 |
+
"special": false
|
| 66 |
+
},
|
| 67 |
+
"8": {
|
| 68 |
+
"content": "<v2t_token>",
|
| 69 |
+
"lstrip": false,
|
| 70 |
+
"normalized": true,
|
| 71 |
+
"rstrip": false,
|
| 72 |
+
"single_word": false,
|
| 73 |
+
"special": false
|
| 74 |
+
},
|
| 75 |
+
"9": {
|
| 76 |
+
"content": "<|start_header_id|>",
|
| 77 |
+
"lstrip": false,
|
| 78 |
+
"normalized": true,
|
| 79 |
+
"rstrip": false,
|
| 80 |
+
"single_word": false,
|
| 81 |
+
"special": false
|
| 82 |
+
},
|
| 83 |
+
"10": {
|
| 84 |
+
"content": "<|end_header_id|>",
|
| 85 |
+
"lstrip": false,
|
| 86 |
+
"normalized": true,
|
| 87 |
+
"rstrip": false,
|
| 88 |
+
"single_word": false,
|
| 89 |
+
"special": false
|
| 90 |
+
},
|
| 91 |
+
"11": {
|
| 92 |
+
"content": "<|eot_id|>",
|
| 93 |
+
"lstrip": false,
|
| 94 |
+
"normalized": true,
|
| 95 |
+
"rstrip": false,
|
| 96 |
+
"single_word": false,
|
| 97 |
+
"special": false
|
| 98 |
+
},
|
| 99 |
+
"12": {
|
| 100 |
+
"content": "<extra_id_12>",
|
| 101 |
+
"lstrip": false,
|
| 102 |
+
"normalized": true,
|
| 103 |
+
"rstrip": false,
|
| 104 |
+
"single_word": false,
|
| 105 |
+
"special": false
|
| 106 |
+
},
|
| 107 |
+
"13": {
|
| 108 |
+
"content": "<extra_id_13>",
|
| 109 |
+
"lstrip": false,
|
| 110 |
+
"normalized": true,
|
| 111 |
+
"rstrip": false,
|
| 112 |
+
"single_word": false,
|
| 113 |
+
"special": false
|
| 114 |
+
},
|
| 115 |
+
"14": {
|
| 116 |
+
"content": "<extra_id_14>",
|
| 117 |
+
"lstrip": false,
|
| 118 |
+
"normalized": true,
|
| 119 |
+
"rstrip": false,
|
| 120 |
+
"single_word": false,
|
| 121 |
+
"special": false
|
| 122 |
+
},
|
| 123 |
+
"15": {
|
| 124 |
+
"content": "<extra_id_15>",
|
| 125 |
+
"lstrip": false,
|
| 126 |
+
"normalized": true,
|
| 127 |
+
"rstrip": false,
|
| 128 |
+
"single_word": false,
|
| 129 |
+
"special": false
|
| 130 |
+
},
|
| 131 |
+
"16": {
|
| 132 |
+
"content": "<extra_id_16>",
|
| 133 |
+
"lstrip": false,
|
| 134 |
+
"normalized": true,
|
| 135 |
+
"rstrip": false,
|
| 136 |
+
"single_word": false,
|
| 137 |
+
"special": false
|
| 138 |
+
},
|
| 139 |
+
"17": {
|
| 140 |
+
"content": "<extra_id_17>",
|
| 141 |
+
"lstrip": false,
|
| 142 |
+
"normalized": true,
|
| 143 |
+
"rstrip": false,
|
| 144 |
+
"single_word": false,
|
| 145 |
+
"special": false
|
| 146 |
+
},
|
| 147 |
+
"18": {
|
| 148 |
+
"content": "<extra_id_18>",
|
| 149 |
+
"lstrip": false,
|
| 150 |
+
"normalized": true,
|
| 151 |
+
"rstrip": false,
|
| 152 |
+
"single_word": false,
|
| 153 |
+
"special": false
|
| 154 |
+
},
|
| 155 |
+
"19": {
|
| 156 |
+
"content": "<extra_id_19>",
|
| 157 |
+
"lstrip": false,
|
| 158 |
+
"normalized": true,
|
| 159 |
+
"rstrip": false,
|
| 160 |
+
"single_word": false,
|
| 161 |
+
"special": false
|
| 162 |
+
},
|
| 163 |
+
"20": {
|
| 164 |
+
"content": "<extra_id_20>",
|
| 165 |
+
"lstrip": false,
|
| 166 |
+
"normalized": true,
|
| 167 |
+
"rstrip": false,
|
| 168 |
+
"single_word": false,
|
| 169 |
+
"special": false
|
| 170 |
+
},
|
| 171 |
+
"21": {
|
| 172 |
+
"content": "<extra_id_21>",
|
| 173 |
+
"lstrip": false,
|
| 174 |
+
"normalized": true,
|
| 175 |
+
"rstrip": false,
|
| 176 |
+
"single_word": false,
|
| 177 |
+
"special": false
|
| 178 |
+
},
|
| 179 |
+
"22": {
|
| 180 |
+
"content": "<extra_id_22>",
|
| 181 |
+
"lstrip": false,
|
| 182 |
+
"normalized": true,
|
| 183 |
+
"rstrip": false,
|
| 184 |
+
"single_word": false,
|
| 185 |
+
"special": false
|
| 186 |
+
},
|
| 187 |
+
"23": {
|
| 188 |
+
"content": "<extra_id_23>",
|
| 189 |
+
"lstrip": false,
|
| 190 |
+
"normalized": true,
|
| 191 |
+
"rstrip": false,
|
| 192 |
+
"single_word": false,
|
| 193 |
+
"special": false
|
| 194 |
+
},
|
| 195 |
+
"24": {
|
| 196 |
+
"content": "<extra_id_24>",
|
| 197 |
+
"lstrip": false,
|
| 198 |
+
"normalized": true,
|
| 199 |
+
"rstrip": false,
|
| 200 |
+
"single_word": false,
|
| 201 |
+
"special": false
|
| 202 |
+
},
|
| 203 |
+
"25": {
|
| 204 |
+
"content": "<extra_id_25>",
|
| 205 |
+
"lstrip": false,
|
| 206 |
+
"normalized": true,
|
| 207 |
+
"rstrip": false,
|
| 208 |
+
"single_word": false,
|
| 209 |
+
"special": false
|
| 210 |
+
},
|
| 211 |
+
"26": {
|
| 212 |
+
"content": "<extra_id_26>",
|
| 213 |
+
"lstrip": false,
|
| 214 |
+
"normalized": true,
|
| 215 |
+
"rstrip": false,
|
| 216 |
+
"single_word": false,
|
| 217 |
+
"special": false
|
| 218 |
+
},
|
| 219 |
+
"27": {
|
| 220 |
+
"content": "<extra_id_27>",
|
| 221 |
+
"lstrip": false,
|
| 222 |
+
"normalized": true,
|
| 223 |
+
"rstrip": false,
|
| 224 |
+
"single_word": false,
|
| 225 |
+
"special": false
|
| 226 |
+
},
|
| 227 |
+
"28": {
|
| 228 |
+
"content": "<extra_id_28>",
|
| 229 |
+
"lstrip": false,
|
| 230 |
+
"normalized": true,
|
| 231 |
+
"rstrip": false,
|
| 232 |
+
"single_word": false,
|
| 233 |
+
"special": false
|
| 234 |
+
},
|
| 235 |
+
"29": {
|
| 236 |
+
"content": "<extra_id_29>",
|
| 237 |
+
"lstrip": false,
|
| 238 |
+
"normalized": true,
|
| 239 |
+
"rstrip": false,
|
| 240 |
+
"single_word": false,
|
| 241 |
+
"special": false
|
| 242 |
+
},
|
| 243 |
+
"30": {
|
| 244 |
+
"content": "<extra_id_30>",
|
| 245 |
+
"lstrip": false,
|
| 246 |
+
"normalized": true,
|
| 247 |
+
"rstrip": false,
|
| 248 |
+
"single_word": false,
|
| 249 |
+
"special": false
|
| 250 |
+
},
|
| 251 |
+
"31": {
|
| 252 |
+
"content": "<extra_id_31>",
|
| 253 |
+
"lstrip": false,
|
| 254 |
+
"normalized": true,
|
| 255 |
+
"rstrip": false,
|
| 256 |
+
"single_word": false,
|
| 257 |
+
"special": false
|
| 258 |
+
},
|
| 259 |
+
"32": {
|
| 260 |
+
"content": "<extra_id_32>",
|
| 261 |
+
"lstrip": false,
|
| 262 |
+
"normalized": true,
|
| 263 |
+
"rstrip": false,
|
| 264 |
+
"single_word": false,
|
| 265 |
+
"special": false
|
| 266 |
+
},
|
| 267 |
+
"33": {
|
| 268 |
+
"content": "<extra_id_33>",
|
| 269 |
+
"lstrip": false,
|
| 270 |
+
"normalized": true,
|
| 271 |
+
"rstrip": false,
|
| 272 |
+
"single_word": false,
|
| 273 |
+
"special": false
|
| 274 |
+
},
|
| 275 |
+
"34": {
|
| 276 |
+
"content": "<extra_id_34>",
|
| 277 |
+
"lstrip": false,
|
| 278 |
+
"normalized": true,
|
| 279 |
+
"rstrip": false,
|
| 280 |
+
"single_word": false,
|
| 281 |
+
"special": false
|
| 282 |
+
},
|
| 283 |
+
"35": {
|
| 284 |
+
"content": "<extra_id_35>",
|
| 285 |
+
"lstrip": false,
|
| 286 |
+
"normalized": true,
|
| 287 |
+
"rstrip": false,
|
| 288 |
+
"single_word": false,
|
| 289 |
+
"special": false
|
| 290 |
+
},
|
| 291 |
+
"36": {
|
| 292 |
+
"content": "<extra_id_36>",
|
| 293 |
+
"lstrip": false,
|
| 294 |
+
"normalized": true,
|
| 295 |
+
"rstrip": false,
|
| 296 |
+
"single_word": false,
|
| 297 |
+
"special": false
|
| 298 |
+
},
|
| 299 |
+
"37": {
|
| 300 |
+
"content": "<extra_id_37>",
|
| 301 |
+
"lstrip": false,
|
| 302 |
+
"normalized": true,
|
| 303 |
+
"rstrip": false,
|
| 304 |
+
"single_word": false,
|
| 305 |
+
"special": false
|
| 306 |
+
},
|
| 307 |
+
"38": {
|
| 308 |
+
"content": "<extra_id_38>",
|
| 309 |
+
"lstrip": false,
|
| 310 |
+
"normalized": true,
|
| 311 |
+
"rstrip": false,
|
| 312 |
+
"single_word": false,
|
| 313 |
+
"special": false
|
| 314 |
+
},
|
| 315 |
+
"39": {
|
| 316 |
+
"content": "<extra_id_39>",
|
| 317 |
+
"lstrip": false,
|
| 318 |
+
"normalized": true,
|
| 319 |
+
"rstrip": false,
|
| 320 |
+
"single_word": false,
|
| 321 |
+
"special": false
|
| 322 |
+
},
|
| 323 |
+
"40": {
|
| 324 |
+
"content": "<extra_id_40>",
|
| 325 |
+
"lstrip": false,
|
| 326 |
+
"normalized": true,
|
| 327 |
+
"rstrip": false,
|
| 328 |
+
"single_word": false,
|
| 329 |
+
"special": false
|
| 330 |
+
},
|
| 331 |
+
"41": {
|
| 332 |
+
"content": "<extra_id_41>",
|
| 333 |
+
"lstrip": false,
|
| 334 |
+
"normalized": true,
|
| 335 |
+
"rstrip": false,
|
| 336 |
+
"single_word": false,
|
| 337 |
+
"special": false
|
| 338 |
+
},
|
| 339 |
+
"42": {
|
| 340 |
+
"content": "<extra_id_42>",
|
| 341 |
+
"lstrip": false,
|
| 342 |
+
"normalized": true,
|
| 343 |
+
"rstrip": false,
|
| 344 |
+
"single_word": false,
|
| 345 |
+
"special": false
|
| 346 |
+
},
|
| 347 |
+
"43": {
|
| 348 |
+
"content": "<extra_id_43>",
|
| 349 |
+
"lstrip": false,
|
| 350 |
+
"normalized": true,
|
| 351 |
+
"rstrip": false,
|
| 352 |
+
"single_word": false,
|
| 353 |
+
"special": false
|
| 354 |
+
},
|
| 355 |
+
"44": {
|
| 356 |
+
"content": "<extra_id_44>",
|
| 357 |
+
"lstrip": false,
|
| 358 |
+
"normalized": true,
|
| 359 |
+
"rstrip": false,
|
| 360 |
+
"single_word": false,
|
| 361 |
+
"special": false
|
| 362 |
+
},
|
| 363 |
+
"45": {
|
| 364 |
+
"content": "<extra_id_45>",
|
| 365 |
+
"lstrip": false,
|
| 366 |
+
"normalized": true,
|
| 367 |
+
"rstrip": false,
|
| 368 |
+
"single_word": false,
|
| 369 |
+
"special": false
|
| 370 |
+
},
|
| 371 |
+
"46": {
|
| 372 |
+
"content": "<extra_id_46>",
|
| 373 |
+
"lstrip": false,
|
| 374 |
+
"normalized": true,
|
| 375 |
+
"rstrip": false,
|
| 376 |
+
"single_word": false,
|
| 377 |
+
"special": false
|
| 378 |
+
},
|
| 379 |
+
"47": {
|
| 380 |
+
"content": "<extra_id_47>",
|
| 381 |
+
"lstrip": false,
|
| 382 |
+
"normalized": true,
|
| 383 |
+
"rstrip": false,
|
| 384 |
+
"single_word": false,
|
| 385 |
+
"special": false
|
| 386 |
+
},
|
| 387 |
+
"48": {
|
| 388 |
+
"content": "<extra_id_48>",
|
| 389 |
+
"lstrip": false,
|
| 390 |
+
"normalized": true,
|
| 391 |
+
"rstrip": false,
|
| 392 |
+
"single_word": false,
|
| 393 |
+
"special": false
|
| 394 |
+
},
|
| 395 |
+
"49": {
|
| 396 |
+
"content": "<extra_id_49>",
|
| 397 |
+
"lstrip": false,
|
| 398 |
+
"normalized": true,
|
| 399 |
+
"rstrip": false,
|
| 400 |
+
"single_word": false,
|
| 401 |
+
"special": false
|
| 402 |
+
},
|
| 403 |
+
"50": {
|
| 404 |
+
"content": "<extra_id_50>",
|
| 405 |
+
"lstrip": false,
|
| 406 |
+
"normalized": true,
|
| 407 |
+
"rstrip": false,
|
| 408 |
+
"single_word": false,
|
| 409 |
+
"special": false
|
| 410 |
+
},
|
| 411 |
+
"51": {
|
| 412 |
+
"content": "<extra_id_51>",
|
| 413 |
+
"lstrip": false,
|
| 414 |
+
"normalized": true,
|
| 415 |
+
"rstrip": false,
|
| 416 |
+
"single_word": false,
|
| 417 |
+
"special": false
|
| 418 |
+
},
|
| 419 |
+
"52": {
|
| 420 |
+
"content": "<extra_id_52>",
|
| 421 |
+
"lstrip": false,
|
| 422 |
+
"normalized": true,
|
| 423 |
+
"rstrip": false,
|
| 424 |
+
"single_word": false,
|
| 425 |
+
"special": false
|
| 426 |
+
},
|
| 427 |
+
"53": {
|
| 428 |
+
"content": "<extra_id_53>",
|
| 429 |
+
"lstrip": false,
|
| 430 |
+
"normalized": true,
|
| 431 |
+
"rstrip": false,
|
| 432 |
+
"single_word": false,
|
| 433 |
+
"special": false
|
| 434 |
+
},
|
| 435 |
+
"54": {
|
| 436 |
+
"content": "<extra_id_54>",
|
| 437 |
+
"lstrip": false,
|
| 438 |
+
"normalized": true,
|
| 439 |
+
"rstrip": false,
|
| 440 |
+
"single_word": false,
|
| 441 |
+
"special": false
|
| 442 |
+
},
|
| 443 |
+
"55": {
|
| 444 |
+
"content": "<extra_id_55>",
|
| 445 |
+
"lstrip": false,
|
| 446 |
+
"normalized": true,
|
| 447 |
+
"rstrip": false,
|
| 448 |
+
"single_word": false,
|
| 449 |
+
"special": false
|
| 450 |
+
},
|
| 451 |
+
"56": {
|
| 452 |
+
"content": "<extra_id_56>",
|
| 453 |
+
"lstrip": false,
|
| 454 |
+
"normalized": true,
|
| 455 |
+
"rstrip": false,
|
| 456 |
+
"single_word": false,
|
| 457 |
+
"special": false
|
| 458 |
+
},
|
| 459 |
+
"57": {
|
| 460 |
+
"content": "<extra_id_57>",
|
| 461 |
+
"lstrip": false,
|
| 462 |
+
"normalized": true,
|
| 463 |
+
"rstrip": false,
|
| 464 |
+
"single_word": false,
|
| 465 |
+
"special": false
|
| 466 |
+
},
|
| 467 |
+
"58": {
|
| 468 |
+
"content": "<extra_id_58>",
|
| 469 |
+
"lstrip": false,
|
| 470 |
+
"normalized": true,
|
| 471 |
+
"rstrip": false,
|
| 472 |
+
"single_word": false,
|
| 473 |
+
"special": false
|
| 474 |
+
},
|
| 475 |
+
"59": {
|
| 476 |
+
"content": "<extra_id_59>",
|
| 477 |
+
"lstrip": false,
|
| 478 |
+
"normalized": true,
|
| 479 |
+
"rstrip": false,
|
| 480 |
+
"single_word": false,
|
| 481 |
+
"special": false
|
| 482 |
+
},
|
| 483 |
+
"60": {
|
| 484 |
+
"content": "<extra_id_60>",
|
| 485 |
+
"lstrip": false,
|
| 486 |
+
"normalized": true,
|
| 487 |
+
"rstrip": false,
|
| 488 |
+
"single_word": false,
|
| 489 |
+
"special": false
|
| 490 |
+
},
|
| 491 |
+
"61": {
|
| 492 |
+
"content": "<extra_id_61>",
|
| 493 |
+
"lstrip": false,
|
| 494 |
+
"normalized": true,
|
| 495 |
+
"rstrip": false,
|
| 496 |
+
"single_word": false,
|
| 497 |
+
"special": false
|
| 498 |
+
},
|
| 499 |
+
"62": {
|
| 500 |
+
"content": "<extra_id_62>",
|
| 501 |
+
"lstrip": false,
|
| 502 |
+
"normalized": true,
|
| 503 |
+
"rstrip": false,
|
| 504 |
+
"single_word": false,
|
| 505 |
+
"special": false
|
| 506 |
+
},
|
| 507 |
+
"63": {
|
| 508 |
+
"content": "<extra_id_63>",
|
| 509 |
+
"lstrip": false,
|
| 510 |
+
"normalized": true,
|
| 511 |
+
"rstrip": false,
|
| 512 |
+
"single_word": false,
|
| 513 |
+
"special": false
|
| 514 |
+
}
|
| 515 |
+
},
|
| 516 |
+
"additional_special_tokens": [
|
| 517 |
+
"<repo_name>",
|
| 518 |
+
"<file_sep>",
|
| 519 |
+
"<t2v_token>",
|
| 520 |
+
"<v2t_token>",
|
| 521 |
+
"<|start_header_id|>",
|
| 522 |
+
"<|end_header_id|>",
|
| 523 |
+
"<|eot_id|>",
|
| 524 |
+
"<extra_id_12>",
|
| 525 |
+
"<extra_id_13>",
|
| 526 |
+
"<extra_id_14>",
|
| 527 |
+
"<extra_id_15>",
|
| 528 |
+
"<extra_id_16>",
|
| 529 |
+
"<extra_id_17>",
|
| 530 |
+
"<extra_id_18>",
|
| 531 |
+
"<extra_id_19>",
|
| 532 |
+
"<extra_id_20>",
|
| 533 |
+
"<extra_id_21>",
|
| 534 |
+
"<extra_id_22>",
|
| 535 |
+
"<extra_id_23>",
|
| 536 |
+
"<extra_id_24>",
|
| 537 |
+
"<extra_id_25>",
|
| 538 |
+
"<extra_id_26>",
|
| 539 |
+
"<extra_id_27>",
|
| 540 |
+
"<extra_id_28>",
|
| 541 |
+
"<extra_id_29>",
|
| 542 |
+
"<extra_id_30>",
|
| 543 |
+
"<extra_id_31>",
|
| 544 |
+
"<extra_id_32>",
|
| 545 |
+
"<extra_id_33>",
|
| 546 |
+
"<extra_id_34>",
|
| 547 |
+
"<extra_id_35>",
|
| 548 |
+
"<extra_id_36>",
|
| 549 |
+
"<extra_id_37>",
|
| 550 |
+
"<extra_id_38>",
|
| 551 |
+
"<extra_id_39>",
|
| 552 |
+
"<extra_id_40>",
|
| 553 |
+
"<extra_id_41>",
|
| 554 |
+
"<extra_id_42>",
|
| 555 |
+
"<extra_id_43>",
|
| 556 |
+
"<extra_id_44>",
|
| 557 |
+
"<extra_id_45>",
|
| 558 |
+
"<extra_id_46>",
|
| 559 |
+
"<extra_id_47>",
|
| 560 |
+
"<extra_id_48>",
|
| 561 |
+
"<extra_id_49>",
|
| 562 |
+
"<extra_id_50>",
|
| 563 |
+
"<extra_id_51>",
|
| 564 |
+
"<extra_id_52>",
|
| 565 |
+
"<extra_id_53>",
|
| 566 |
+
"<extra_id_54>",
|
| 567 |
+
"<extra_id_55>",
|
| 568 |
+
"<extra_id_56>",
|
| 569 |
+
"<extra_id_57>",
|
| 570 |
+
"<extra_id_58>",
|
| 571 |
+
"<extra_id_59>",
|
| 572 |
+
"<extra_id_60>",
|
| 573 |
+
"<extra_id_61>",
|
| 574 |
+
"<extra_id_62>",
|
| 575 |
+
"<extra_id_63>"
|
| 576 |
+
],
|
| 577 |
+
"auto_map": {
|
| 578 |
+
"AutoProcessor": "processing_evabyte.EvaByteProcessor",
|
| 579 |
+
"AutoTokenizer": [
|
| 580 |
+
"tokenization_evabyte.EvaByteTokenizer",
|
| 581 |
+
null
|
| 582 |
+
]
|
| 583 |
+
},
|
| 584 |
+
"bos_token": "<bos>",
|
| 585 |
+
"chat_template": "\n{{- bos_token }}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content'] %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{{- '<|start_header_id|>system<|end_header_id|>\n\n' + system_message + '<|eot_id|>'}}\n\n{%- for message in messages %}\n {%- if (message['role'] != 'user') and (message['role'] != 'assistant') %}\n {{- raise_exception('Conversation roles must be user or assistant') }}\n {%- endif %}\n\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] + '<|eot_id|>' }}\n{%- endfor %}\n\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>' + 'assistant' + '<|end_header_id|>\n\n' }}\n{%- endif %}\n",
|
| 586 |
+
"clean_up_tokenization_spaces": false,
|
| 587 |
+
"eos_token": "<eos>",
|
| 588 |
+
"extra_ids": 0,
|
| 589 |
+
"extra_special_tokens": {},
|
| 590 |
+
"model_max_length": 1000000000000000019884624838656,
|
| 591 |
+
"pad_token": "<pad>",
|
| 592 |
+
"processor_class": "EvaByteProcessor",
|
| 593 |
+
"sep_token": "<sep>",
|
| 594 |
+
"tokenizer_class": "EvaByteTokenizer",
|
| 595 |
+
"unk_token": "<unk>"
|
| 596 |
+
}
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/README.md
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
---
|
| 4 |
+
# EvaByte Model Card
|
| 5 |
+
|
| 6 |
+
**EvaByte** is a 6.5B **byte-level language model** built upon an improved architecture with multibyte prediction and EVA -- an efficient attention mechanism designed for scalability and performance. Trained on 1.5T bytes spanning natural language text, math, and code, EvaByte demonstrates the viability of efficient byte-level processing at scale -- rivaling top open-source tokenizer-based LMs using 5x less training data, excelling in coding tasks, and decoding up to 2x faster.
|
| 7 |
+
|
| 8 |
+
## Model Resources
|
| 9 |
+
|
| 10 |
+
- **Repository:** https://github.com/openevabyte/evabyte
|
| 11 |
+
- **Blog:** https://hkunlp.github.io/blog/2025/evabyte and https://sambanova.ai/blog/evabyte-efficient-byte-level-language-models-at-scale
|
| 12 |
+
- **Paper:** Coming soon
|
| 13 |
+
|
| 14 |
+
## Model Details
|
| 15 |
+
|
| 16 |
+
EvaByte is trained using the performant SambaNova SN30 RDU system with a batch size of 8M bytes and 32K context length. The training process consists of 3 phases: after pre-training on 1.2T bytes (yielding **EvaByte-Phase1**), two independent annealing runs (100B and 200B bytes respectively) are conducted with learning rate linearly decayed from 1e-4 to 0. The resulting checkpoints are merged via model soup (**EvaByte**), which then undergoes supervised fine-tuning (**EvaByte-SFT**).
|
| 17 |
+
|
| 18 |
+
| Stage | Model |
|
| 19 |
+
|:----- |:-----|
|
| 20 |
+
| Base (before annealing) | [EvaByte-Phase1](https://huggingface.co/evabyte/EvaByte-Phase1) |
|
| 21 |
+
| Base | [EvaByte](https://huggingface.co/evabyte/EvaByte) <-- you are here |
|
| 22 |
+
| SFT | [EvaByte-SFT](https://huggingface.co/evabyte/EvaByte-SFT) |
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
## Usage
|
| 26 |
+
|
| 27 |
+
**Note:** Make sure to set `trust_remote_code=True` when loading the model (or tokenizer), as our implementation includes custom code.
|
| 28 |
+
|
| 29 |
+
The code snippet below demonstrates EvaByte-6.5B for completion:
|
| 30 |
+
|
| 31 |
+
```python
|
| 32 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 33 |
+
import torch
|
| 34 |
+
|
| 35 |
+
# Load model and tokenizer
|
| 36 |
+
tokenizer = AutoTokenizer.from_pretrained("evabyte/EvaByte", trust_remote_code=True)
|
| 37 |
+
model = AutoModelForCausalLM.from_pretrained("evabyte/EvaByte", torch_dtype=torch.bfloat16, trust_remote_code=True).eval().to("cuda")
|
| 38 |
+
|
| 39 |
+
prompt = "The quick brown fox jumps "
|
| 40 |
+
|
| 41 |
+
# Tokenize input
|
| 42 |
+
# Option 1: standard HF tokenizer interface
|
| 43 |
+
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
|
| 44 |
+
|
| 45 |
+
# Option 2: Direct UTF-8 byte encoding with offset
|
| 46 |
+
# Note: Each byte is offset by 64 with <bos> prepended.
|
| 47 |
+
input_ids = torch.tensor([[1] + [b + 64 for b in prompt.encode("utf-8")]]).to("cuda")
|
| 48 |
+
|
| 49 |
+
# byte-by-byte generation (default)
|
| 50 |
+
generation_output = model.generate(
|
| 51 |
+
input_ids=input_ids,
|
| 52 |
+
max_new_tokens=32
|
| 53 |
+
)
|
| 54 |
+
# alternatively, use faster multibyte generation
|
| 55 |
+
generation_output = model.multi_byte_generate(
|
| 56 |
+
input_ids=input_ids,
|
| 57 |
+
max_new_tokens=32
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
# Decode and print the output
|
| 61 |
+
response = tokenizer.decode(
|
| 62 |
+
generation_output[0][input_ids.shape[1]:],
|
| 63 |
+
skip_special_tokens=False,
|
| 64 |
+
clean_up_tokenization_spaces=False
|
| 65 |
+
)
|
| 66 |
+
print(response)
|
| 67 |
+
# Sample output:
|
| 68 |
+
# over the lazy dog.\n\nThe quick
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
### ⚙️ Generation Modes
|
| 72 |
+
|
| 73 |
+
EvaByte supports two generation interfaces:
|
| 74 |
+
- `model.generate()`: The default generation method compatible with Huggingface `transformers` library. This approach generates one byte at a time and might be slow.
|
| 75 |
+
- `model.multi_byte_generate()`: A faster alternative that generates multiple bytes per step and usually yields the same result as `model.generate()` under greedy decoding, with the implementation adapted from [Medusa](https://github.com/FasterDecoding/Medusa). `model.multi_byte_generate()` supports a subset of arguments in `model.generate()`:
|
| 76 |
+
- `input_ids`: the input byte ids.
|
| 77 |
+
- `temperature`: the temperature for sampling.
|
| 78 |
+
- `max_length`: the maximum length of the generated sequence.
|
| 79 |
+
- `max_new_tokens`: the maximum number of new bytes to generate.
|
| 80 |
+
- `stopping_criteria`: the [stopping criteria](https://huggingface.co/docs/transformers/v4.47.1/en/internal/generation_utils#transformers.StoppingCriteria) for generation.
|
| 81 |
+
- `top_p`: the top-p parameter for sampling.
|
| 82 |
+
- `do_sample`: greedy decoding or sampling.
|
| 83 |
+
|
| 84 |
+
**Notes and Limitations:**
|
| 85 |
+
- `device_map="auto"` is not supported for >2 GPUs.
|
| 86 |
+
- Only batch size of 1 (with `attention_mask=None`) is supported for decoding.
|
| 87 |
+
- `torch_dtype=torch.bfloat16` is required.
|
| 88 |
+
- The multibyte generation `model.multi_byte_generate()` might return extra bytes after the end-of-sequence sentinel, due to the nature of the multibyte decoding. Manual truncation or cleaning may be needed.
|
| 89 |
+
|
| 90 |
+
## Bias, Risks, and Limitations
|
| 91 |
+
As a pretrained base model, **EvaByte** has not been fine-tuned for chat or instruction following, so users should not expect reliable performance in conversational or instruction-based tasks. Like other base models, it does not incorporate any moderation mechanisms, making it possible to generate potentially harmful or inappropriate content.
|
| 92 |
+
|
| 93 |
+
## Evaluation
|
| 94 |
+
|
| 95 |
+
For detailed evaluation results, check out our blog post at [SambaNova](https://sambanova.ai/blog/evabyte-efficient-byte-level-language-models-at-scale) or [HKUNLP](https://hkunlp.github.io/blog/2025/evabyte).
|
| 96 |
+
|
| 97 |
+
## Citation
|
| 98 |
+
```bibtex
|
| 99 |
+
@misc{evabyte,
|
| 100 |
+
title = {EvaByte: Efficient Byte-level Language Models at Scale},
|
| 101 |
+
url = {https://hkunlp.github.io/blog/2025/evabyte},
|
| 102 |
+
author = {Lin Zheng and Xueliang Zhao and Guangtao Wang and Chen Wu and David Dong and Angela Wang and Mingran Wang and Yun Du and Haige Bo and Amol Sharma and Bo Li and Kejie Zhang and Changran Hu and Urmish Thakker and Lingpeng Kong},
|
| 103 |
+
year = {2025}
|
| 104 |
+
}
|
| 105 |
+
```
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/config.json
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_name_or_path": null,
|
| 3 |
+
"architectures": [
|
| 4 |
+
"EvaByteForCausalLM"
|
| 5 |
+
],
|
| 6 |
+
"attention_bias": false,
|
| 7 |
+
"attention_class": "eva",
|
| 8 |
+
"attention_dropout": 0.0,
|
| 9 |
+
"auto_map": {
|
| 10 |
+
"AutoConfig": "configuration_evabyte.EvaByteConfig",
|
| 11 |
+
"AutoModelForCausalLM": "modeling_evabyte.EvaByteForCausalLM"
|
| 12 |
+
},
|
| 13 |
+
"bos_token_id": 1,
|
| 14 |
+
"chunk_size": 16,
|
| 15 |
+
"eos_token_id": 2,
|
| 16 |
+
"fp32_ln": true,
|
| 17 |
+
"fp32_logits": true,
|
| 18 |
+
"fp32_skip_add": false,
|
| 19 |
+
"hidden_act": "silu",
|
| 20 |
+
"hidden_size": 5120,
|
| 21 |
+
"init_cutoff_factor": null,
|
| 22 |
+
"init_fn": "v2",
|
| 23 |
+
"init_std": 0.01275,
|
| 24 |
+
"initializer_range": 0.01275,
|
| 25 |
+
"intermediate_size": 16384,
|
| 26 |
+
"lazy_init": true,
|
| 27 |
+
"max_position_embeddings": 16384,
|
| 28 |
+
"max_seq_length": 16384,
|
| 29 |
+
"mixedp_attn": true,
|
| 30 |
+
"model_type": "evabyte",
|
| 31 |
+
"norm_add_unit_offset": true,
|
| 32 |
+
"num_attention_heads": 40,
|
| 33 |
+
"num_chunks": null,
|
| 34 |
+
"num_hidden_layers": 40,
|
| 35 |
+
"num_key_value_heads": 40,
|
| 36 |
+
"num_pred_heads": 1,
|
| 37 |
+
"pad_token_id": 0,
|
| 38 |
+
"return_dict": false,
|
| 39 |
+
"rms_norm_eps": 1e-06,
|
| 40 |
+
"rope_scaling": null,
|
| 41 |
+
"rope_theta": 100000.0,
|
| 42 |
+
"tie_word_embeddings": false,
|
| 43 |
+
"torch_dtype": "bfloat16",
|
| 44 |
+
"transformers_version": "4.47.1",
|
| 45 |
+
"use_cache": true,
|
| 46 |
+
"vocab_size": 320,
|
| 47 |
+
"window_size": 2048
|
| 48 |
+
}
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/configuration_evabyte.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" EvaByte configuration"""
|
| 2 |
+
|
| 3 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 4 |
+
|
| 5 |
+
class EvaByteConfig(PretrainedConfig):
|
| 6 |
+
model_type = "evabyte"
|
| 7 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 8 |
+
|
| 9 |
+
def __init__(
|
| 10 |
+
self,
|
| 11 |
+
vocab_size=320,
|
| 12 |
+
hidden_size=4096,
|
| 13 |
+
intermediate_size=11008,
|
| 14 |
+
num_hidden_layers=32,
|
| 15 |
+
num_attention_heads=32,
|
| 16 |
+
num_key_value_heads=None,
|
| 17 |
+
hidden_act="silu",
|
| 18 |
+
max_position_embeddings=2048,
|
| 19 |
+
initializer_range=0.02,
|
| 20 |
+
rms_norm_eps=1e-6,
|
| 21 |
+
use_cache=True,
|
| 22 |
+
pad_token_id=None,
|
| 23 |
+
bos_token_id=1,
|
| 24 |
+
eos_token_id=2,
|
| 25 |
+
tie_word_embeddings=False,
|
| 26 |
+
rope_theta=10000.0,
|
| 27 |
+
rope_scaling=None,
|
| 28 |
+
attention_bias=False,
|
| 29 |
+
attention_dropout=0.0,
|
| 30 |
+
norm_add_unit_offset=False,
|
| 31 |
+
init_fn="mitchell",
|
| 32 |
+
init_std=0.006,
|
| 33 |
+
init_cutoff_factor=None,
|
| 34 |
+
attention_class="mha",
|
| 35 |
+
window_size=512,
|
| 36 |
+
num_chunks=None,
|
| 37 |
+
chunk_size=256,
|
| 38 |
+
**kwargs,
|
| 39 |
+
):
|
| 40 |
+
self.vocab_size = vocab_size
|
| 41 |
+
self.max_position_embeddings = max_position_embeddings
|
| 42 |
+
self.hidden_size = hidden_size
|
| 43 |
+
self.intermediate_size = intermediate_size
|
| 44 |
+
self.num_hidden_layers = num_hidden_layers
|
| 45 |
+
self.num_attention_heads = num_attention_heads
|
| 46 |
+
|
| 47 |
+
# for backward compatibility
|
| 48 |
+
if num_key_value_heads is None:
|
| 49 |
+
num_key_value_heads = num_attention_heads
|
| 50 |
+
|
| 51 |
+
self.num_key_value_heads = num_key_value_heads
|
| 52 |
+
self.hidden_act = hidden_act
|
| 53 |
+
self.initializer_range = initializer_range
|
| 54 |
+
self.rms_norm_eps = rms_norm_eps
|
| 55 |
+
self.use_cache = use_cache
|
| 56 |
+
self.rope_theta = rope_theta
|
| 57 |
+
self.rope_scaling = rope_scaling
|
| 58 |
+
self._rope_scaling_validation()
|
| 59 |
+
self.attention_bias = attention_bias
|
| 60 |
+
self.attention_dropout = attention_dropout
|
| 61 |
+
|
| 62 |
+
self.norm_add_unit_offset = norm_add_unit_offset
|
| 63 |
+
self.init_fn = init_fn
|
| 64 |
+
self.init_std = init_std
|
| 65 |
+
self.init_cutoff_factor = init_cutoff_factor
|
| 66 |
+
|
| 67 |
+
# Attention-specific paramters
|
| 68 |
+
self.attention_class = attention_class
|
| 69 |
+
self.window_size = window_size
|
| 70 |
+
self.num_chunks = num_chunks
|
| 71 |
+
self.chunk_size = chunk_size
|
| 72 |
+
|
| 73 |
+
super().__init__(
|
| 74 |
+
pad_token_id=pad_token_id,
|
| 75 |
+
bos_token_id=bos_token_id,
|
| 76 |
+
eos_token_id=eos_token_id,
|
| 77 |
+
tie_word_embeddings=tie_word_embeddings,
|
| 78 |
+
**kwargs,
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
def _rope_scaling_validation(self):
|
| 82 |
+
"""
|
| 83 |
+
Validate the `rope_scaling` configuration.
|
| 84 |
+
"""
|
| 85 |
+
if self.rope_scaling is None:
|
| 86 |
+
return
|
| 87 |
+
|
| 88 |
+
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
|
| 89 |
+
raise ValueError(
|
| 90 |
+
"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}"
|
| 91 |
+
)
|
| 92 |
+
rope_scaling_type = self.rope_scaling.get("type", None)
|
| 93 |
+
rope_scaling_factor = self.rope_scaling.get("factor", None)
|
| 94 |
+
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
|
| 95 |
+
raise ValueError(
|
| 96 |
+
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
|
| 97 |
+
)
|
| 98 |
+
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
|
| 99 |
+
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/eva.py
ADDED
|
@@ -0,0 +1,424 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Optional, Tuple, List, Any, Union
|
| 2 |
+
import torch
|
| 3 |
+
from torch import nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from .eva_agg_kernel import eva_agg_func_triton
|
| 6 |
+
from .eva_prep_kv_kernel import eva_prep_kv_func_triton
|
| 7 |
+
try:
|
| 8 |
+
import triton
|
| 9 |
+
USE_TRITON_IMPL = True
|
| 10 |
+
except ImportError:
|
| 11 |
+
USE_TRITON_IMPL = False
|
| 12 |
+
raise ImportError("Triton is not installed. Please install it by running `pip install triton`.")
|
| 13 |
+
|
| 14 |
+
def rotate_half(x: torch.Tensor) -> torch.Tensor:
|
| 15 |
+
"""
|
| 16 |
+
Rotates half the hidden dims (last dim) of the input.
|
| 17 |
+
Args:
|
| 18 |
+
x: Rotary embedded tensor
|
| 19 |
+
Return:
|
| 20 |
+
Tensor with half of last dim negated and rotated to the front.
|
| 21 |
+
"""
|
| 22 |
+
x1, x2 = x.split(x.shape[-1] // 2, dim=-1)
|
| 23 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 24 |
+
|
| 25 |
+
def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
|
| 26 |
+
position_ids: torch.Tensor) -> torch.Tensor:
|
| 27 |
+
"""
|
| 28 |
+
Apply rotary embedding (cos, sin) to the query and key tensor on the sequence dimension.
|
| 29 |
+
|
| 30 |
+
The legends for dimensions are defined as:
|
| 31 |
+
num_heads: number of attention heads
|
| 32 |
+
current_seq_len: the current batch's sequence length, should be either 1 or max_seq_len
|
| 33 |
+
max_seq_len: the static sequence length, different from current_seq_len in cached inference case where it is always
|
| 34 |
+
maximum lenghth, e.g. the length of static sequence length of KV cache
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
q: Query tensor, of size (batch_size, num_heads, current_seq_len, head_dim)
|
| 39 |
+
k: Key tensor, of size (batch_size, num_key_value_heads, current_seq_len, head_dim)
|
| 40 |
+
cos: Cosine base of rotary embedding, of size (max_seq_len, head_dim)
|
| 41 |
+
sin: Sine base of rotary embedding, of size (max_seq_len, head_dim)
|
| 42 |
+
position_ids: The position indices of the tokens corresponding to the query and key tensors. It has a size of
|
| 43 |
+
(batch_size, current_seq_len).
|
| 44 |
+
|
| 45 |
+
Returns:
|
| 46 |
+
Embedded query and key tensor of same size as input.
|
| 47 |
+
|
| 48 |
+
"""
|
| 49 |
+
bs, nheads, cur_seq_len, head_dim = q.shape
|
| 50 |
+
assert len(
|
| 51 |
+
k.shape) == 4, f"k should be of shape (batch_size, num_heads, current_seq_len, head_dim), got {k.shape} instead"
|
| 52 |
+
assert k.shape[0] == bs, f"k has a different batch_size {k.shape[0]} compared to q {bs}"
|
| 53 |
+
assert list(k.shape[2:]) == [cur_seq_len,
|
| 54 |
+
head_dim], f"k has different current_seq_len and/or head_dim compared to q"
|
| 55 |
+
assert cos.shape[3] == head_dim, f"cos should have dim of head dim {head_dim}, got {cos.shape[3]} instead"
|
| 56 |
+
assert list(position_ids.shape) in [[bs, cur_seq_len], [1, cur_seq_len]],\
|
| 57 |
+
f"position_ids should be of shape {[bs, cur_seq_len]} or {[1, cur_seq_len]}, got {position_ids.shape} instead"
|
| 58 |
+
|
| 59 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 60 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 61 |
+
return q_embed, k_embed
|
| 62 |
+
|
| 63 |
+
class EvaAttention(nn.Module):
|
| 64 |
+
"""
|
| 65 |
+
Causal EVA for language modeling.
|
| 66 |
+
"""
|
| 67 |
+
|
| 68 |
+
def __init__(self, config, layer_idx: Optional[int] = None):
|
| 69 |
+
super().__init__()
|
| 70 |
+
self.config = config
|
| 71 |
+
self.layer_idx = layer_idx
|
| 72 |
+
self.hidden_size = config.hidden_size
|
| 73 |
+
self.num_heads = config.num_attention_heads
|
| 74 |
+
self.head_dim = self.hidden_size // self.num_heads
|
| 75 |
+
self.head_dim_scaling = self.head_dim ** -0.5
|
| 76 |
+
|
| 77 |
+
self.max_position_embeddings = config.max_position_embeddings
|
| 78 |
+
|
| 79 |
+
if (self.head_dim * self.num_heads) != self.hidden_size:
|
| 80 |
+
raise ValueError(
|
| 81 |
+
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
| 82 |
+
f" and `num_heads`: {self.num_heads})."
|
| 83 |
+
)
|
| 84 |
+
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
| 85 |
+
self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
| 86 |
+
self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
| 87 |
+
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
| 88 |
+
|
| 89 |
+
self.window_size = config.window_size
|
| 90 |
+
|
| 91 |
+
self.num_chunks = config.num_chunks
|
| 92 |
+
self.chunk_size = config.chunk_size
|
| 93 |
+
if self.chunk_size is not None:
|
| 94 |
+
assert self.window_size >= self.chunk_size and self.window_size % self.chunk_size == 0
|
| 95 |
+
# chunk_size overrides the number of landmarks
|
| 96 |
+
self.num_chunks = None
|
| 97 |
+
|
| 98 |
+
self.chunks_per_window = int(self.window_size // self.chunk_size)
|
| 99 |
+
self.adaptive_phi = nn.Parameter(
|
| 100 |
+
torch.randn(
|
| 101 |
+
1,
|
| 102 |
+
self.num_heads,
|
| 103 |
+
1,
|
| 104 |
+
1,
|
| 105 |
+
self.head_dim
|
| 106 |
+
).clamp(-1., 1.) * self.head_dim_scaling
|
| 107 |
+
)
|
| 108 |
+
self.adaptive_mu_k = nn.Parameter(
|
| 109 |
+
torch.randn(
|
| 110 |
+
1,
|
| 111 |
+
self.num_heads,
|
| 112 |
+
1,
|
| 113 |
+
1,
|
| 114 |
+
self.head_dim
|
| 115 |
+
).clamp(-1., 1.) * self.head_dim_scaling
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
def _triton_forward(
|
| 119 |
+
self,
|
| 120 |
+
hidden_states: torch.Tensor,
|
| 121 |
+
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
|
| 122 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 123 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
| 124 |
+
output_attentions: bool = False,
|
| 125 |
+
use_cache: bool = False,
|
| 126 |
+
cos: Optional[torch.Tensor] = None,
|
| 127 |
+
sin: Optional[torch.Tensor] = None,
|
| 128 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 129 |
+
assert not output_attentions
|
| 130 |
+
bsz, q_len, _ = hidden_states.size()
|
| 131 |
+
|
| 132 |
+
if use_cache:
|
| 133 |
+
if past_key_value is None:
|
| 134 |
+
raise ValueError
|
| 135 |
+
assert isinstance(attention_mask, tuple)
|
| 136 |
+
|
| 137 |
+
# infer the model's running mode
|
| 138 |
+
is_prefilling = use_cache and past_key_value.get_seq_length(self.layer_idx) == 0
|
| 139 |
+
is_decoding = use_cache and past_key_value.get_seq_length(self.layer_idx) > 0
|
| 140 |
+
|
| 141 |
+
if is_prefilling:
|
| 142 |
+
assert len(attention_mask) == 2
|
| 143 |
+
window_mask, intra_chunk_mask = attention_mask
|
| 144 |
+
chunk_mask = None
|
| 145 |
+
elif is_decoding:
|
| 146 |
+
assert len(attention_mask) == 3
|
| 147 |
+
window_mask, intra_chunk_mask, chunk_mask = attention_mask
|
| 148 |
+
else:
|
| 149 |
+
if attention_mask is not None:
|
| 150 |
+
assert isinstance(attention_mask, tuple) and len(attention_mask) == 3
|
| 151 |
+
window_mask, chunk_mask, intra_chunk_mask = attention_mask
|
| 152 |
+
else:
|
| 153 |
+
window_mask, chunk_mask, intra_chunk_mask = None, None, None
|
| 154 |
+
|
| 155 |
+
############################################
|
| 156 |
+
# compute q, k, v from hidden states
|
| 157 |
+
############################################
|
| 158 |
+
# [b, h, q_len, d]
|
| 159 |
+
q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 160 |
+
# [b, h, kv_len, d]
|
| 161 |
+
k = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 162 |
+
# [b, h, kv_len, d]
|
| 163 |
+
v = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 164 |
+
|
| 165 |
+
if use_cache:
|
| 166 |
+
past_key_value.update_past_len(q.shape[-2], self.layer_idx)
|
| 167 |
+
|
| 168 |
+
############################################
|
| 169 |
+
# apply rotary positional embeddings to q, k
|
| 170 |
+
############################################
|
| 171 |
+
q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
|
| 172 |
+
|
| 173 |
+
############################################
|
| 174 |
+
# update and get cached singleton tokens
|
| 175 |
+
# update and cache k and v for calculating chunk-level RFAs
|
| 176 |
+
############################################
|
| 177 |
+
if use_cache:
|
| 178 |
+
s_k, s_v, dump_k, dump_v = past_key_value.update_singletons_and_chunks(
|
| 179 |
+
k,
|
| 180 |
+
v,
|
| 181 |
+
self.layer_idx,
|
| 182 |
+
self.window_size,
|
| 183 |
+
)
|
| 184 |
+
else:
|
| 185 |
+
s_k, s_v = k, v
|
| 186 |
+
dump_k, dump_v = k, v
|
| 187 |
+
|
| 188 |
+
if use_cache:
|
| 189 |
+
singleton_mask, dump_rf_mask = past_key_value.update_mask(
|
| 190 |
+
s_mask=window_mask,
|
| 191 |
+
rf_mask=intra_chunk_mask,
|
| 192 |
+
layer_idx=self.layer_idx,
|
| 193 |
+
window_size=self.window_size,
|
| 194 |
+
)
|
| 195 |
+
else:
|
| 196 |
+
singleton_mask = window_mask
|
| 197 |
+
dump_rf_mask = intra_chunk_mask
|
| 198 |
+
|
| 199 |
+
if dump_k is not None and dump_v is not None:
|
| 200 |
+
# 1. in prefilling, the input shape is
|
| 201 |
+
# dump_k/dump_v: [b, h, n, d]
|
| 202 |
+
# rfa_k/rfa_v: [b, h, n // c, d]
|
| 203 |
+
# 2. in decoding, the input shape is
|
| 204 |
+
# k/v: [b, h, w, d]
|
| 205 |
+
# rfa_k/rfa_v: [b, h, w//c, d]
|
| 206 |
+
# 3. in forward inference; the seq_len is already divisible
|
| 207 |
+
rfa_k, rfa_v = eva_prep_kv_func_triton(
|
| 208 |
+
dump_k, dump_v,
|
| 209 |
+
self.adaptive_mu_k, self.adaptive_phi,
|
| 210 |
+
dump_rf_mask, self.head_dim_scaling, self.chunk_size
|
| 211 |
+
)
|
| 212 |
+
# rfa_mask = get_rfa_chunk_mask(dump_rf_mask)
|
| 213 |
+
if use_cache:
|
| 214 |
+
rfa_k, rfa_v = past_key_value.update_chunk_rfas(
|
| 215 |
+
rfa_k, rfa_v, self.layer_idx
|
| 216 |
+
)
|
| 217 |
+
elif use_cache:
|
| 218 |
+
# if there are not enough elements within the last chunk,
|
| 219 |
+
# we will only use the cached chunk-level RFAs
|
| 220 |
+
rfa_k, rfa_v = past_key_value.get_chunk_rfas(self.layer_idx)
|
| 221 |
+
else:
|
| 222 |
+
rfa_k, rfa_v = None, None
|
| 223 |
+
|
| 224 |
+
############################################
|
| 225 |
+
# compute the full attention output
|
| 226 |
+
############################################
|
| 227 |
+
if is_prefilling:
|
| 228 |
+
# prefilling
|
| 229 |
+
# 1. in prefilling, the input shape is
|
| 230 |
+
# q: [b, h, n, d]
|
| 231 |
+
# k/v: [b, h, n, d]
|
| 232 |
+
# rfa_k/rfa_v: [b, h, n // c, d]
|
| 233 |
+
attn_output = eva_agg_func_triton(
|
| 234 |
+
q, s_k, s_v,
|
| 235 |
+
rfa_k, rfa_v,
|
| 236 |
+
singleton_mask, chunk_mask,
|
| 237 |
+
self.head_dim_scaling, self.window_size, self.chunks_per_window
|
| 238 |
+
)
|
| 239 |
+
elif is_decoding:
|
| 240 |
+
# 2. in decoding, the input shape is
|
| 241 |
+
# q: [b, h, 1, d] or [b, h, z, d] (for multi-byte prediction)
|
| 242 |
+
# k/v: [b, h, 1 + s, d]
|
| 243 |
+
# rfa_k/rfa_v: [b, h, n // c, d]
|
| 244 |
+
if rfa_k is not None and rfa_v is not None:
|
| 245 |
+
# we only take the chunk-level RFAs not in the current window
|
| 246 |
+
seen_seq_len = past_key_value.get_seq_length(self.layer_idx)
|
| 247 |
+
if seen_seq_len <= self.window_size:
|
| 248 |
+
agg_k = s_k
|
| 249 |
+
agg_v = s_v
|
| 250 |
+
attn_mask = singleton_mask
|
| 251 |
+
else:
|
| 252 |
+
# NOTE: we already updated the cache so the length now
|
| 253 |
+
# includes the current token
|
| 254 |
+
# we subtract 1 from seen_seq_len because we want
|
| 255 |
+
# if seen_seq_len = 2048 -> num_windows_seen_so_far = 0
|
| 256 |
+
# if seen_seq_len = 4096 -> num_windows_seen_so_far = 1
|
| 257 |
+
# if seen_seq_len = 4097 -> num_windows_seen_so_far = 2
|
| 258 |
+
# NOTE the cat order should be taken care of;
|
| 259 |
+
# should align with the order based on which
|
| 260 |
+
# the attention mask is constructed
|
| 261 |
+
num_windows_seen_so_far = (seen_seq_len - 1) // self.window_size
|
| 262 |
+
agg_k = torch.cat([s_k, rfa_k[..., :num_windows_seen_so_far * self.chunks_per_window, :]], dim=-2)
|
| 263 |
+
agg_v = torch.cat([s_v, rfa_v[..., :num_windows_seen_so_far * self.chunks_per_window, :]], dim=-2)
|
| 264 |
+
if singleton_mask is not None:
|
| 265 |
+
assert chunk_mask is not None
|
| 266 |
+
attn_mask = torch.cat([singleton_mask, chunk_mask], dim=-1)
|
| 267 |
+
else:
|
| 268 |
+
attn_mask = singleton_mask
|
| 269 |
+
else:
|
| 270 |
+
agg_k = s_k
|
| 271 |
+
agg_v = s_v
|
| 272 |
+
attn_mask = singleton_mask
|
| 273 |
+
attn_output = F.scaled_dot_product_attention(
|
| 274 |
+
q, agg_k, agg_v,
|
| 275 |
+
attn_mask=attn_mask,
|
| 276 |
+
is_causal=False,
|
| 277 |
+
dropout_p=0.0,
|
| 278 |
+
scale=self.head_dim_scaling
|
| 279 |
+
)
|
| 280 |
+
else:
|
| 281 |
+
# 3. in single-forward inference
|
| 282 |
+
attn_output = eva_agg_func_triton(
|
| 283 |
+
q, s_k, s_v,
|
| 284 |
+
rfa_k, rfa_v,
|
| 285 |
+
singleton_mask, chunk_mask,
|
| 286 |
+
self.head_dim_scaling, self.window_size, self.chunks_per_window
|
| 287 |
+
)
|
| 288 |
+
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
| 289 |
+
raise ValueError(
|
| 290 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
| 291 |
+
f" {attn_output.size()}"
|
| 292 |
+
)
|
| 293 |
+
attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
|
| 294 |
+
attn_output = self.o_proj(attn_output)
|
| 295 |
+
attn_weights = None
|
| 296 |
+
return attn_output, attn_weights, past_key_value
|
| 297 |
+
|
| 298 |
+
def _multibyte_decoding_forward(
|
| 299 |
+
self,
|
| 300 |
+
hidden_states: torch.Tensor,
|
| 301 |
+
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
|
| 302 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 303 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
| 304 |
+
output_attentions: bool = False,
|
| 305 |
+
use_cache: bool = False,
|
| 306 |
+
cos: Optional[torch.Tensor] = None,
|
| 307 |
+
sin: Optional[torch.Tensor] = None,
|
| 308 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 309 |
+
# during multi-byte forwarding, we only read caches and do not update them
|
| 310 |
+
assert not output_attentions
|
| 311 |
+
bsz, q_len, _ = hidden_states.size()
|
| 312 |
+
|
| 313 |
+
if use_cache and past_key_value is None:
|
| 314 |
+
raise ValueError
|
| 315 |
+
|
| 316 |
+
assert USE_TRITON_IMPL
|
| 317 |
+
assert isinstance(attention_mask, torch.Tensor) and attention_mask.dtype == torch.bool
|
| 318 |
+
|
| 319 |
+
assert use_cache and past_key_value.get_seq_length(self.layer_idx) > 0
|
| 320 |
+
|
| 321 |
+
############################################
|
| 322 |
+
# compute q, k, v from hidden states
|
| 323 |
+
############################################
|
| 324 |
+
# [b, h, q_len, d]
|
| 325 |
+
q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 326 |
+
# [b, h, kv_len, d]
|
| 327 |
+
k = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 328 |
+
# [b, h, kv_len, d]
|
| 329 |
+
v = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 330 |
+
|
| 331 |
+
############################################
|
| 332 |
+
# apply rotary positional embeddings to q, k
|
| 333 |
+
############################################
|
| 334 |
+
q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
|
| 335 |
+
|
| 336 |
+
############################################
|
| 337 |
+
# update and get cached singleton tokens
|
| 338 |
+
############################################
|
| 339 |
+
input_len = k.shape[-2]
|
| 340 |
+
window_pos = past_key_value.past_window_pos[self.layer_idx]
|
| 341 |
+
new_window_pos = window_pos + input_len
|
| 342 |
+
|
| 343 |
+
past_key_value.past_window_k[self.layer_idx][:, :, window_pos : new_window_pos, :] = k
|
| 344 |
+
past_key_value.past_window_v[self.layer_idx][:, :, window_pos : new_window_pos, :] = v
|
| 345 |
+
s_k = past_key_value.past_window_k[self.layer_idx][:, :, : new_window_pos, :]
|
| 346 |
+
s_v = past_key_value.past_window_v[self.layer_idx][:, :, : new_window_pos, :]
|
| 347 |
+
|
| 348 |
+
rfa_k, rfa_v = past_key_value.get_chunk_rfas(self.layer_idx)
|
| 349 |
+
|
| 350 |
+
############################################
|
| 351 |
+
# compute the full attention output
|
| 352 |
+
############################################
|
| 353 |
+
# 2. in decoding, the input shape is
|
| 354 |
+
# q: [b, h, 1, d] or [b, h, z, d] (for multi-byte prediction)
|
| 355 |
+
# k/v: [b, h, 1 + s, d]
|
| 356 |
+
# rfa_k/rfa_v: [b, h, n // c, d]
|
| 357 |
+
if rfa_k is not None and rfa_v is not None:
|
| 358 |
+
# NOTE the cat order should be taken care of;
|
| 359 |
+
# should align with the order based on which
|
| 360 |
+
# the attention mask is constructed
|
| 361 |
+
# agg_k = torch.cat([s_k, rfa_k], dim=-2)
|
| 362 |
+
# agg_v = torch.cat([s_v, rfa_v], dim=-2)
|
| 363 |
+
agg_k = torch.cat([rfa_k, s_k], dim=-2)
|
| 364 |
+
agg_v = torch.cat([rfa_v, s_v], dim=-2)
|
| 365 |
+
else:
|
| 366 |
+
agg_k = s_k
|
| 367 |
+
agg_v = s_v
|
| 368 |
+
attn_output = F.scaled_dot_product_attention(
|
| 369 |
+
q, agg_k, agg_v,
|
| 370 |
+
attn_mask=attention_mask,
|
| 371 |
+
is_causal=False,
|
| 372 |
+
dropout_p=0.0,
|
| 373 |
+
scale=self.head_dim_scaling
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
+
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
| 377 |
+
raise ValueError(
|
| 378 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
| 379 |
+
f" {attn_output.size()}"
|
| 380 |
+
)
|
| 381 |
+
attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
|
| 382 |
+
attn_output = self.o_proj(attn_output)
|
| 383 |
+
attn_weights = None
|
| 384 |
+
return attn_output, attn_weights, past_key_value
|
| 385 |
+
|
| 386 |
+
def forward(
|
| 387 |
+
self,
|
| 388 |
+
hidden_states: torch.Tensor,
|
| 389 |
+
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
|
| 390 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 391 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
| 392 |
+
output_attentions: bool = False,
|
| 393 |
+
use_cache: bool = False,
|
| 394 |
+
cos: Optional[torch.Tensor] = None,
|
| 395 |
+
sin: Optional[torch.Tensor] = None,
|
| 396 |
+
multibyte_decoding: Optional[bool] = False,
|
| 397 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 398 |
+
assert not output_attentions
|
| 399 |
+
if use_cache and past_key_value is None:
|
| 400 |
+
raise ValueError
|
| 401 |
+
|
| 402 |
+
assert USE_TRITON_IMPL
|
| 403 |
+
if use_cache and multibyte_decoding:
|
| 404 |
+
return self._multibyte_decoding_forward(
|
| 405 |
+
hidden_states,
|
| 406 |
+
attention_mask=attention_mask,
|
| 407 |
+
position_ids=position_ids,
|
| 408 |
+
past_key_value=past_key_value,
|
| 409 |
+
output_attentions=output_attentions,
|
| 410 |
+
use_cache=use_cache,
|
| 411 |
+
cos=cos,
|
| 412 |
+
sin=sin,
|
| 413 |
+
)
|
| 414 |
+
else:
|
| 415 |
+
return self._triton_forward(
|
| 416 |
+
hidden_states,
|
| 417 |
+
attention_mask=attention_mask,
|
| 418 |
+
position_ids=position_ids,
|
| 419 |
+
past_key_value=past_key_value,
|
| 420 |
+
output_attentions=output_attentions,
|
| 421 |
+
use_cache=use_cache,
|
| 422 |
+
cos=cos,
|
| 423 |
+
sin=sin,
|
| 424 |
+
)
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/eva_agg_kernel.py
ADDED
|
@@ -0,0 +1,1766 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import math
|
| 3 |
+
import torch
|
| 4 |
+
import triton
|
| 5 |
+
import triton.language as tl
|
| 6 |
+
|
| 7 |
+
@triton.heuristics(
|
| 8 |
+
{
|
| 9 |
+
"EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
|
| 10 |
+
"EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
|
| 11 |
+
"EVEN_W": lambda args: args["WINDOW_SIZE"] % args["BLOCK_N"] == 0,
|
| 12 |
+
"EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
|
| 13 |
+
}
|
| 14 |
+
)
|
| 15 |
+
@triton.jit
|
| 16 |
+
def _bwd_eva_agg_kernel_dkdv(
|
| 17 |
+
Q,
|
| 18 |
+
K,
|
| 19 |
+
V,
|
| 20 |
+
WindowMask,
|
| 21 |
+
DO,
|
| 22 |
+
LSE,
|
| 23 |
+
DO_T_O,
|
| 24 |
+
DK,
|
| 25 |
+
DV,
|
| 26 |
+
softmax_scale,
|
| 27 |
+
stride_qb, stride_qh, stride_qm,
|
| 28 |
+
stride_kb, stride_kh, stride_kn,
|
| 29 |
+
stride_vb, stride_vh, stride_vn,
|
| 30 |
+
stride_window_mask_b, stride_window_mask_m,
|
| 31 |
+
stride_do_b, stride_do_h, stride_do_m,
|
| 32 |
+
stride_lse_b, stride_lse_h,
|
| 33 |
+
stride_do_t_o_b, stride_do_t_o_h,
|
| 34 |
+
stride_dk_b, stride_dk_h, stride_dk_n,
|
| 35 |
+
stride_dv_b, stride_dv_h, stride_dv_n,
|
| 36 |
+
nheads,
|
| 37 |
+
seqlen_q,
|
| 38 |
+
seqlen_k,
|
| 39 |
+
headdim,
|
| 40 |
+
WINDOW_SIZE: tl.constexpr,
|
| 41 |
+
MASK_TYPE: tl.constexpr,
|
| 42 |
+
BLOCK_HEADDIM: tl.constexpr,
|
| 43 |
+
EVEN_M: tl.constexpr,
|
| 44 |
+
EVEN_N: tl.constexpr,
|
| 45 |
+
EVEN_W: tl.constexpr,
|
| 46 |
+
EVEN_HEADDIM: tl.constexpr,
|
| 47 |
+
BLOCK_M: tl.constexpr,
|
| 48 |
+
BLOCK_N: tl.constexpr,
|
| 49 |
+
):
|
| 50 |
+
off_bh = tl.program_id(1)
|
| 51 |
+
off_h = off_bh % nheads
|
| 52 |
+
off_b = off_bh // nheads
|
| 53 |
+
|
| 54 |
+
start_n = tl.program_id(0)
|
| 55 |
+
# determine which window the current KV block belongs to
|
| 56 |
+
offs_w = (start_n * BLOCK_N) // WINDOW_SIZE
|
| 57 |
+
offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
| 58 |
+
offs_m = tl.arange(0, BLOCK_M)
|
| 59 |
+
offs_d = tl.arange(0, BLOCK_HEADDIM)
|
| 60 |
+
|
| 61 |
+
# initialize pointers
|
| 62 |
+
q_ptrs = (
|
| 63 |
+
Q +
|
| 64 |
+
off_b * stride_qb +
|
| 65 |
+
off_h * stride_qh +
|
| 66 |
+
offs_m[:, None] * stride_qm + offs_d[None, :]
|
| 67 |
+
)
|
| 68 |
+
k_ptrs = (
|
| 69 |
+
K +
|
| 70 |
+
off_b * stride_kb +
|
| 71 |
+
off_h * stride_kh +
|
| 72 |
+
offs_n[:, None] * stride_kn + offs_d[None, :]
|
| 73 |
+
)
|
| 74 |
+
v_ptrs = (
|
| 75 |
+
V +
|
| 76 |
+
off_b * stride_vb +
|
| 77 |
+
off_h * stride_vh +
|
| 78 |
+
offs_n[:, None] * stride_vn + offs_d[None, :]
|
| 79 |
+
)
|
| 80 |
+
do_ptrs = (
|
| 81 |
+
DO +
|
| 82 |
+
off_b * stride_do_b +
|
| 83 |
+
off_h * stride_do_h +
|
| 84 |
+
offs_m[:, None] * stride_do_m + offs_d[None, :]
|
| 85 |
+
)
|
| 86 |
+
do_t_o_ptrs = (
|
| 87 |
+
DO_T_O +
|
| 88 |
+
off_b * stride_do_t_o_b +
|
| 89 |
+
off_h * stride_do_t_o_h +
|
| 90 |
+
offs_m[:, None]
|
| 91 |
+
)
|
| 92 |
+
lse_ptrs = (
|
| 93 |
+
LSE +
|
| 94 |
+
off_b * stride_lse_b +
|
| 95 |
+
off_h * stride_lse_h +
|
| 96 |
+
offs_m[:, None]
|
| 97 |
+
)
|
| 98 |
+
if MASK_TYPE == 1:
|
| 99 |
+
m_ptrs = (
|
| 100 |
+
WindowMask +
|
| 101 |
+
off_b * stride_window_mask_b +
|
| 102 |
+
(offs_m[:, None] * stride_window_mask_m + offs_n[None, :])
|
| 103 |
+
)
|
| 104 |
+
dk_ptrs = (
|
| 105 |
+
DK +
|
| 106 |
+
off_b * stride_dk_b +
|
| 107 |
+
off_h * stride_dk_h +
|
| 108 |
+
offs_n[:, None] * stride_dk_n + offs_d[None, :]
|
| 109 |
+
)
|
| 110 |
+
dv_ptrs = (
|
| 111 |
+
DV +
|
| 112 |
+
off_b * stride_dv_b +
|
| 113 |
+
off_h * stride_dv_h +
|
| 114 |
+
offs_n[:, None] * stride_dv_n + offs_d[None, :]
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
# 1. for singletons
|
| 118 |
+
# determine start and end of query block
|
| 119 |
+
begin_m = ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M
|
| 120 |
+
end_m = tl.minimum((offs_w + 1) * WINDOW_SIZE, seqlen_q)
|
| 121 |
+
|
| 122 |
+
dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
|
| 123 |
+
dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
|
| 124 |
+
if EVEN_N & EVEN_M:
|
| 125 |
+
if EVEN_HEADDIM:
|
| 126 |
+
k = tl.load(k_ptrs)
|
| 127 |
+
v = tl.load(v_ptrs)
|
| 128 |
+
else:
|
| 129 |
+
k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
|
| 130 |
+
v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
|
| 131 |
+
else:
|
| 132 |
+
if EVEN_HEADDIM:
|
| 133 |
+
k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
|
| 134 |
+
v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
|
| 135 |
+
else:
|
| 136 |
+
k = tl.load(
|
| 137 |
+
k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0
|
| 138 |
+
)
|
| 139 |
+
v = tl.load(
|
| 140 |
+
v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0
|
| 141 |
+
)
|
| 142 |
+
for start_m in range(begin_m, end_m, BLOCK_M):
|
| 143 |
+
start_m = tl.multiple_of(start_m, BLOCK_M)
|
| 144 |
+
# load q, do, and lse
|
| 145 |
+
if EVEN_M & EVEN_N:
|
| 146 |
+
if EVEN_HEADDIM:
|
| 147 |
+
q = tl.load(
|
| 148 |
+
q_ptrs + start_m * stride_qm
|
| 149 |
+
)
|
| 150 |
+
do = tl.load(
|
| 151 |
+
do_ptrs + start_m * stride_do_m
|
| 152 |
+
)
|
| 153 |
+
else:
|
| 154 |
+
q = tl.load(
|
| 155 |
+
q_ptrs + start_m * stride_qm,
|
| 156 |
+
mask=offs_d[None, :] < headdim,
|
| 157 |
+
other=0.0
|
| 158 |
+
)
|
| 159 |
+
do = tl.load(
|
| 160 |
+
do_ptrs + start_m * stride_do_m,
|
| 161 |
+
mask=offs_d[None, :] < headdim,
|
| 162 |
+
other=0.0
|
| 163 |
+
)
|
| 164 |
+
do_t_o = tl.load(
|
| 165 |
+
do_t_o_ptrs + start_m
|
| 166 |
+
)
|
| 167 |
+
lse = tl.load(
|
| 168 |
+
lse_ptrs + start_m
|
| 169 |
+
)
|
| 170 |
+
else:
|
| 171 |
+
if EVEN_HEADDIM:
|
| 172 |
+
q = tl.load(
|
| 173 |
+
q_ptrs + start_m * stride_qm,
|
| 174 |
+
mask=(start_m + offs_m)[:, None] < seqlen_q,
|
| 175 |
+
other=0.0
|
| 176 |
+
)
|
| 177 |
+
do = tl.load(
|
| 178 |
+
do_ptrs + start_m * stride_do_m,
|
| 179 |
+
mask=(start_m + offs_m)[:, None] < seqlen_q,
|
| 180 |
+
other=0.0
|
| 181 |
+
)
|
| 182 |
+
else:
|
| 183 |
+
q = tl.load(
|
| 184 |
+
q_ptrs + start_m * stride_qm,
|
| 185 |
+
mask=((start_m + offs_m)[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
|
| 186 |
+
other=0.0
|
| 187 |
+
)
|
| 188 |
+
do = tl.load(
|
| 189 |
+
do_ptrs + start_m * stride_do_m,
|
| 190 |
+
mask=((start_m + offs_m)[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
|
| 191 |
+
other=0.0
|
| 192 |
+
)
|
| 193 |
+
do_t_o = tl.load(
|
| 194 |
+
do_t_o_ptrs + start_m,
|
| 195 |
+
mask=(start_m + offs_m)[:, None] < seqlen_q,
|
| 196 |
+
other=0.0
|
| 197 |
+
)
|
| 198 |
+
lse = tl.load(
|
| 199 |
+
lse_ptrs + start_m,
|
| 200 |
+
mask=(start_m + offs_m)[:, None] < seqlen_q,
|
| 201 |
+
other=0.0
|
| 202 |
+
)
|
| 203 |
+
lse = tl.where(lse == float("-inf"), 0.0, lse)
|
| 204 |
+
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
| 205 |
+
qk += tl.dot(q, tl.trans(k))
|
| 206 |
+
if not EVEN_M:
|
| 207 |
+
qk += tl.where((start_m + offs_m)[:, None] < seqlen_q, 0, float("-inf"))
|
| 208 |
+
|
| 209 |
+
if MASK_TYPE == 1:
|
| 210 |
+
if EVEN_M & EVEN_W:
|
| 211 |
+
mask = tl.load(
|
| 212 |
+
m_ptrs + (start_m * stride_window_mask_m) - (offs_w * WINDOW_SIZE)
|
| 213 |
+
)
|
| 214 |
+
else:
|
| 215 |
+
mask = tl.load(
|
| 216 |
+
m_ptrs + (start_m * stride_window_mask_m) - (offs_w * WINDOW_SIZE),
|
| 217 |
+
mask=((start_m + offs_m)[:, None] < seqlen_q)
|
| 218 |
+
& (((start_m * stride_window_mask_m) - (offs_w * WINDOW_SIZE) + offs_n)[None, :] < WINDOW_SIZE),
|
| 219 |
+
other=1,
|
| 220 |
+
)
|
| 221 |
+
# Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
|
| 222 |
+
# can then fuse the mult and add into an fma instruction. But if we have bias we need to
|
| 223 |
+
# to multiply with softmax_scale here.
|
| 224 |
+
# we assume mask already implies the causal masking
|
| 225 |
+
qk = qk * softmax_scale
|
| 226 |
+
qk = tl.where(mask, float("-inf"), qk)
|
| 227 |
+
p = tl.exp(qk - lse)
|
| 228 |
+
else:
|
| 229 |
+
qk += tl.where((start_m + offs_m)[:, None] >= offs_n[None, :], 0, float("-inf"))
|
| 230 |
+
p = tl.exp(qk * softmax_scale - lse)
|
| 231 |
+
|
| 232 |
+
# dp [M, N]
|
| 233 |
+
dp = tl.dot(do, tl.trans(v))
|
| 234 |
+
# p [M, N], dp [M, N], do_t_o [M, 1] -> ds [M, N]
|
| 235 |
+
ds = (p * (dp - do_t_o) * softmax_scale).to(q.dtype)
|
| 236 |
+
# p is fp32 and [M, N], convert to q.dtype
|
| 237 |
+
# do [M, D] -> dv [N, D]
|
| 238 |
+
dv += tl.dot(tl.trans(p.to(do.dtype)), do)
|
| 239 |
+
# dk [N, D]
|
| 240 |
+
dk += tl.dot(tl.trans(ds), q)
|
| 241 |
+
if EVEN_N & EVEN_M:
|
| 242 |
+
if EVEN_HEADDIM:
|
| 243 |
+
tl.store(dv_ptrs, dv)
|
| 244 |
+
tl.store(dk_ptrs, dk)
|
| 245 |
+
else:
|
| 246 |
+
tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
|
| 247 |
+
tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
|
| 248 |
+
else:
|
| 249 |
+
if EVEN_HEADDIM:
|
| 250 |
+
tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
|
| 251 |
+
tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
|
| 252 |
+
else:
|
| 253 |
+
tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
|
| 254 |
+
tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
|
| 255 |
+
|
| 256 |
+
@triton.heuristics(
|
| 257 |
+
{
|
| 258 |
+
"EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
|
| 259 |
+
"EVEN_C": lambda args: args["nchunks"] % args["BLOCK_N"] == 0,
|
| 260 |
+
"EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
|
| 261 |
+
}
|
| 262 |
+
)
|
| 263 |
+
@triton.jit
|
| 264 |
+
def _bwd_eva_agg_kernel_drfa_kv(
|
| 265 |
+
Q,
|
| 266 |
+
RFA_K,
|
| 267 |
+
RFA_V,
|
| 268 |
+
ChunkMask,
|
| 269 |
+
DO,
|
| 270 |
+
LSE,
|
| 271 |
+
DO_T_O,
|
| 272 |
+
D_RFA_K,
|
| 273 |
+
D_RFA_V,
|
| 274 |
+
softmax_scale,
|
| 275 |
+
stride_qb, stride_qh, stride_qm,
|
| 276 |
+
stride_rfa_kb, stride_rfa_kh, stride_rfa_kc,
|
| 277 |
+
stride_rfa_vb, stride_rfa_vh, stride_rfa_vc,
|
| 278 |
+
stride_chunk_mask_b, stride_chunk_mask_m,
|
| 279 |
+
stride_do_b, stride_do_h, stride_do_m,
|
| 280 |
+
stride_lse_b, stride_lse_h,
|
| 281 |
+
stride_do_t_o_b, stride_do_t_o_h,
|
| 282 |
+
stride_d_rfa_k_b, stride_d_rfa_k_h, stride_d_rfa_k_c,
|
| 283 |
+
stride_d_rfa_v_b, stride_d_rfa_v_h, stride_d_rfa_v_c,
|
| 284 |
+
nheads,
|
| 285 |
+
seqlen_q,
|
| 286 |
+
nchunks,
|
| 287 |
+
headdim,
|
| 288 |
+
CHUNKS_PER_WINDOW: tl.constexpr,
|
| 289 |
+
WINDOW_SIZE: tl.constexpr,
|
| 290 |
+
MASK_TYPE: tl.constexpr,
|
| 291 |
+
BLOCK_HEADDIM: tl.constexpr,
|
| 292 |
+
EVEN_M: tl.constexpr,
|
| 293 |
+
EVEN_C: tl.constexpr,
|
| 294 |
+
EVEN_HEADDIM: tl.constexpr,
|
| 295 |
+
BLOCK_M: tl.constexpr,
|
| 296 |
+
BLOCK_N: tl.constexpr,
|
| 297 |
+
):
|
| 298 |
+
off_bh = tl.program_id(1)
|
| 299 |
+
off_h = off_bh % nheads
|
| 300 |
+
off_b = off_bh // nheads
|
| 301 |
+
start_c = tl.program_id(0)
|
| 302 |
+
# there are 128 chunks per window
|
| 303 |
+
offs_c = start_c * BLOCK_N + tl.arange(0, BLOCK_N)
|
| 304 |
+
# determine which window the current KV block belongs to
|
| 305 |
+
offs_w = (start_c * BLOCK_N) // CHUNKS_PER_WINDOW
|
| 306 |
+
offs_m = tl.arange(0, BLOCK_M)
|
| 307 |
+
offs_d = tl.arange(0, BLOCK_HEADDIM)
|
| 308 |
+
|
| 309 |
+
# initialize pointers
|
| 310 |
+
q_ptrs = (
|
| 311 |
+
Q +
|
| 312 |
+
off_b * stride_qb +
|
| 313 |
+
off_h * stride_qh +
|
| 314 |
+
(offs_m[:, None] * stride_qm + offs_d[None, :])
|
| 315 |
+
)
|
| 316 |
+
do_ptrs = (
|
| 317 |
+
DO +
|
| 318 |
+
off_b * stride_do_b +
|
| 319 |
+
off_h * stride_do_h +
|
| 320 |
+
(offs_m[:, None] * stride_do_m + offs_d[None, :])
|
| 321 |
+
)
|
| 322 |
+
do_t_o_ptrs = (
|
| 323 |
+
DO_T_O +
|
| 324 |
+
off_b * stride_do_t_o_b +
|
| 325 |
+
off_h * stride_do_t_o_h +
|
| 326 |
+
(offs_m[:, None])
|
| 327 |
+
)
|
| 328 |
+
lse_ptrs = (
|
| 329 |
+
LSE +
|
| 330 |
+
off_b * stride_lse_b +
|
| 331 |
+
off_h * stride_lse_h +
|
| 332 |
+
(offs_m[:, None])
|
| 333 |
+
)
|
| 334 |
+
rfa_k_ptrs = (
|
| 335 |
+
RFA_K +
|
| 336 |
+
off_b * stride_rfa_kb +
|
| 337 |
+
off_h * stride_rfa_kh +
|
| 338 |
+
(offs_c[:, None] * stride_rfa_kc + offs_d[None, :])
|
| 339 |
+
)
|
| 340 |
+
rfa_v_ptrs = (
|
| 341 |
+
RFA_V +
|
| 342 |
+
off_b * stride_rfa_vb +
|
| 343 |
+
off_h * stride_rfa_vh +
|
| 344 |
+
(offs_c[:, None] * stride_rfa_vc + offs_d[None, :])
|
| 345 |
+
)
|
| 346 |
+
if MASK_TYPE == 1:
|
| 347 |
+
rfa_m_ptrs = (
|
| 348 |
+
ChunkMask +
|
| 349 |
+
off_b * stride_chunk_mask_b +
|
| 350 |
+
(offs_m[:, None] * stride_chunk_mask_m + offs_c[None, :])
|
| 351 |
+
)
|
| 352 |
+
d_rfa_k_ptrs = (
|
| 353 |
+
D_RFA_K +
|
| 354 |
+
off_b * stride_d_rfa_k_b +
|
| 355 |
+
off_h * stride_d_rfa_k_h +
|
| 356 |
+
(offs_c[:, None] * stride_d_rfa_k_c + offs_d[None, :])
|
| 357 |
+
)
|
| 358 |
+
d_rfa_v_ptrs = (
|
| 359 |
+
D_RFA_V +
|
| 360 |
+
off_b * stride_d_rfa_v_b +
|
| 361 |
+
off_h * stride_d_rfa_v_h +
|
| 362 |
+
(offs_c[:, None] * stride_d_rfa_v_c + offs_d[None, :])
|
| 363 |
+
)
|
| 364 |
+
|
| 365 |
+
d_rfa_k = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
|
| 366 |
+
d_rfa_v = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
|
| 367 |
+
if EVEN_C & EVEN_M:
|
| 368 |
+
if EVEN_HEADDIM:
|
| 369 |
+
rfa_k = tl.load(rfa_k_ptrs)
|
| 370 |
+
rfa_v = tl.load(rfa_v_ptrs)
|
| 371 |
+
else:
|
| 372 |
+
rfa_k = tl.load(rfa_k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
|
| 373 |
+
rfa_v = tl.load(rfa_v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
|
| 374 |
+
else:
|
| 375 |
+
if EVEN_HEADDIM:
|
| 376 |
+
rfa_k = tl.load(rfa_k_ptrs, mask=offs_c[:, None] < nchunks, other=0.0)
|
| 377 |
+
rfa_v = tl.load(rfa_v_ptrs, mask=offs_c[:, None] < nchunks, other=0.0)
|
| 378 |
+
else:
|
| 379 |
+
rfa_k = tl.load(
|
| 380 |
+
rfa_k_ptrs, mask=(offs_c[:, None] < nchunks) & (offs_d[None, :] < headdim), other=0.0
|
| 381 |
+
)
|
| 382 |
+
rfa_v = tl.load(
|
| 383 |
+
rfa_v_ptrs, mask=(offs_c[:, None] < nchunks) & (offs_d[None, :] < headdim), other=0.0
|
| 384 |
+
)
|
| 385 |
+
begin_m = tl.minimum((offs_w + 1) * WINDOW_SIZE, seqlen_q)
|
| 386 |
+
end_m = seqlen_q
|
| 387 |
+
for start_m in range(begin_m, end_m, BLOCK_M):
|
| 388 |
+
start_m = tl.multiple_of(start_m, BLOCK_M)
|
| 389 |
+
# load q, do, and lse
|
| 390 |
+
if EVEN_M:
|
| 391 |
+
if EVEN_HEADDIM:
|
| 392 |
+
q = tl.load(
|
| 393 |
+
q_ptrs + start_m * stride_qm
|
| 394 |
+
)
|
| 395 |
+
do = tl.load(
|
| 396 |
+
do_ptrs + start_m * stride_do_m
|
| 397 |
+
)
|
| 398 |
+
else:
|
| 399 |
+
q = tl.load(
|
| 400 |
+
q_ptrs + start_m * stride_qm,
|
| 401 |
+
mask=offs_d[None, :] < headdim,
|
| 402 |
+
other=0.0
|
| 403 |
+
)
|
| 404 |
+
do = tl.load(
|
| 405 |
+
do_ptrs + start_m * stride_do_m,
|
| 406 |
+
mask=offs_d[None, :] < headdim,
|
| 407 |
+
other=0.0
|
| 408 |
+
)
|
| 409 |
+
do_t_o = tl.load(
|
| 410 |
+
do_t_o_ptrs + start_m
|
| 411 |
+
)
|
| 412 |
+
lse = tl.load(
|
| 413 |
+
lse_ptrs + start_m
|
| 414 |
+
)
|
| 415 |
+
else:
|
| 416 |
+
if EVEN_HEADDIM:
|
| 417 |
+
q = tl.load(
|
| 418 |
+
q_ptrs + start_m * stride_qm,
|
| 419 |
+
mask=(start_m + offs_m)[:, None] < seqlen_q,
|
| 420 |
+
other=0.0
|
| 421 |
+
)
|
| 422 |
+
do = tl.load(
|
| 423 |
+
do_ptrs + start_m * stride_do_m,
|
| 424 |
+
mask=(start_m + offs_m)[:, None] < seqlen_q,
|
| 425 |
+
other=0.0
|
| 426 |
+
)
|
| 427 |
+
else:
|
| 428 |
+
q = tl.load(
|
| 429 |
+
q_ptrs + start_m * stride_qm,
|
| 430 |
+
mask=((start_m + offs_m)[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
|
| 431 |
+
other=0.0
|
| 432 |
+
)
|
| 433 |
+
do = tl.load(
|
| 434 |
+
do_ptrs + start_m * stride_do_m,
|
| 435 |
+
mask=((start_m + offs_m)[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
|
| 436 |
+
other=0.0
|
| 437 |
+
)
|
| 438 |
+
do_t_o = tl.load(
|
| 439 |
+
do_t_o_ptrs + start_m,
|
| 440 |
+
mask=(start_m + offs_m)[:, None] < seqlen_q,
|
| 441 |
+
other=0.0
|
| 442 |
+
)
|
| 443 |
+
lse = tl.load(
|
| 444 |
+
lse_ptrs + start_m,
|
| 445 |
+
mask=(start_m + offs_m)[:, None] < seqlen_q,
|
| 446 |
+
other=0.0
|
| 447 |
+
)
|
| 448 |
+
lse = tl.where(lse == float("-inf"), 0.0, lse)
|
| 449 |
+
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
| 450 |
+
qk += tl.dot(q, tl.trans(rfa_k))
|
| 451 |
+
if not EVEN_M:
|
| 452 |
+
qk += tl.where((start_m + offs_m)[:, None] < seqlen_q, 0, float("-inf"))
|
| 453 |
+
|
| 454 |
+
if MASK_TYPE == 1:
|
| 455 |
+
if EVEN_M & EVEN_C:
|
| 456 |
+
mask = tl.load(
|
| 457 |
+
rfa_m_ptrs + (start_m * stride_chunk_mask_m)
|
| 458 |
+
)
|
| 459 |
+
else:
|
| 460 |
+
mask = tl.load(
|
| 461 |
+
rfa_m_ptrs + (start_m * stride_chunk_mask_m),
|
| 462 |
+
mask=((start_m + offs_m)[:, None] < seqlen_q)
|
| 463 |
+
& (offs_c[None, :] < nchunks),
|
| 464 |
+
other=1,
|
| 465 |
+
)
|
| 466 |
+
# Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
|
| 467 |
+
# can then fuse the mult and add into an fma instruction. But if we have bias we need to
|
| 468 |
+
# to multiply with softmax_scale here.
|
| 469 |
+
# we assume mask already implies the causal masking
|
| 470 |
+
qk = qk * softmax_scale
|
| 471 |
+
qk = tl.where(mask, float("-inf"), qk)
|
| 472 |
+
p = tl.exp(qk - lse)
|
| 473 |
+
else:
|
| 474 |
+
p = tl.exp(qk * softmax_scale - lse)
|
| 475 |
+
|
| 476 |
+
dp = tl.dot(do, tl.trans(rfa_v))
|
| 477 |
+
ds = (p * (dp - do_t_o) * softmax_scale).to(q.dtype)
|
| 478 |
+
# p is fp32, convert to q.dtype
|
| 479 |
+
d_rfa_v += tl.dot(tl.trans(p.to(do.dtype)), do)
|
| 480 |
+
# move softmax_scale to ds to save computation
|
| 481 |
+
d_rfa_k += tl.dot(tl.trans(ds), q)
|
| 482 |
+
if EVEN_C & EVEN_M:
|
| 483 |
+
if EVEN_HEADDIM:
|
| 484 |
+
tl.store(d_rfa_v_ptrs, d_rfa_v)
|
| 485 |
+
tl.store(d_rfa_k_ptrs, d_rfa_k)
|
| 486 |
+
else:
|
| 487 |
+
tl.store(d_rfa_v_ptrs, d_rfa_v, mask=offs_d[None, :] < headdim)
|
| 488 |
+
tl.store(d_rfa_k_ptrs, d_rfa_k, mask=offs_d[None, :] < headdim)
|
| 489 |
+
else:
|
| 490 |
+
if EVEN_HEADDIM:
|
| 491 |
+
tl.store(d_rfa_v_ptrs, d_rfa_v, mask=offs_c[:, None] < nchunks)
|
| 492 |
+
tl.store(d_rfa_k_ptrs, d_rfa_k, mask=offs_c[:, None] < nchunks)
|
| 493 |
+
else:
|
| 494 |
+
tl.store(d_rfa_v_ptrs, d_rfa_v, mask=(offs_c[:, None] < nchunks) & (offs_d[None, :] < headdim))
|
| 495 |
+
tl.store(d_rfa_k_ptrs, d_rfa_k, mask=(offs_c[:, None] < nchunks) & (offs_d[None, :] < headdim))
|
| 496 |
+
|
| 497 |
+
@triton.heuristics(
|
| 498 |
+
{
|
| 499 |
+
"EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
|
| 500 |
+
"EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
|
| 501 |
+
"EVEN_C": lambda args: args["nchunks"] % args["BLOCK_N"] == 0,
|
| 502 |
+
"EVEN_W": lambda args: args["WINDOW_SIZE"] % args["BLOCK_N"] == 0,
|
| 503 |
+
"EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
|
| 504 |
+
}
|
| 505 |
+
)
|
| 506 |
+
@triton.jit
|
| 507 |
+
def _bwd_eva_agg_kernel_dq(
|
| 508 |
+
Q,
|
| 509 |
+
K,
|
| 510 |
+
V,
|
| 511 |
+
RFA_K,
|
| 512 |
+
RFA_V,
|
| 513 |
+
WindowMask,
|
| 514 |
+
ChunkMask,
|
| 515 |
+
DO,
|
| 516 |
+
LSE,
|
| 517 |
+
DO_T_O,
|
| 518 |
+
DQ,
|
| 519 |
+
softmax_scale,
|
| 520 |
+
stride_qb, stride_qh, stride_qm,
|
| 521 |
+
stride_kb, stride_kh, stride_kn,
|
| 522 |
+
stride_vb, stride_vh, stride_vn,
|
| 523 |
+
stride_rfa_kb, stride_rfa_kh, stride_rfa_kc,
|
| 524 |
+
stride_rfa_vb, stride_rfa_vh, stride_rfa_vc,
|
| 525 |
+
stride_window_mask_b, stride_window_mask_m,
|
| 526 |
+
stride_chunk_mask_b, stride_chunk_mask_m,
|
| 527 |
+
stride_do_b, stride_do_h, stride_do_m,
|
| 528 |
+
stride_lse_b, stride_lse_h,
|
| 529 |
+
stride_do_t_o_b, stride_do_t_o_h,
|
| 530 |
+
stride_dq_b, stride_dq_h, stride_dq_m,
|
| 531 |
+
nheads,
|
| 532 |
+
seqlen_q,
|
| 533 |
+
seqlen_k,
|
| 534 |
+
nchunks,
|
| 535 |
+
headdim,
|
| 536 |
+
CHUNKS_PER_WINDOW: tl.constexpr,
|
| 537 |
+
WINDOW_SIZE: tl.constexpr,
|
| 538 |
+
MASK_TYPE: tl.constexpr,
|
| 539 |
+
EMPTY_RFA_KV: tl.constexpr,
|
| 540 |
+
BLOCK_HEADDIM: tl.constexpr,
|
| 541 |
+
EVEN_M: tl.constexpr,
|
| 542 |
+
EVEN_N: tl.constexpr,
|
| 543 |
+
EVEN_W: tl.constexpr,
|
| 544 |
+
EVEN_C: tl.constexpr,
|
| 545 |
+
EVEN_HEADDIM: tl.constexpr,
|
| 546 |
+
BLOCK_M: tl.constexpr,
|
| 547 |
+
BLOCK_N: tl.constexpr,
|
| 548 |
+
):
|
| 549 |
+
start_m = tl.program_id(0)
|
| 550 |
+
off_bh = tl.program_id(1)
|
| 551 |
+
off_h = off_bh % nheads
|
| 552 |
+
off_b = off_bh // nheads
|
| 553 |
+
# initialize offsets
|
| 554 |
+
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 555 |
+
offs_w = (start_m * BLOCK_M) // WINDOW_SIZE
|
| 556 |
+
offs_n = tl.arange(0, BLOCK_N)
|
| 557 |
+
offs_c = tl.arange(0, BLOCK_N)
|
| 558 |
+
offs_d = tl.arange(0, BLOCK_HEADDIM)
|
| 559 |
+
# TODO: add paratheses or not
|
| 560 |
+
q_ptrs = (
|
| 561 |
+
Q +
|
| 562 |
+
off_b * stride_qb +
|
| 563 |
+
off_h * stride_qh +
|
| 564 |
+
(offs_m[:, None] * stride_qm + offs_d[None, :])
|
| 565 |
+
)
|
| 566 |
+
k_ptrs = (
|
| 567 |
+
K +
|
| 568 |
+
off_b * stride_kb +
|
| 569 |
+
off_h * stride_kh +
|
| 570 |
+
(offs_n[:, None] * stride_kn + offs_d[None, :])
|
| 571 |
+
)
|
| 572 |
+
v_ptrs = (
|
| 573 |
+
V +
|
| 574 |
+
off_b * stride_vb +
|
| 575 |
+
off_h * stride_vh +
|
| 576 |
+
(offs_n[:, None] * stride_vn + offs_d[None, :])
|
| 577 |
+
)
|
| 578 |
+
if EMPTY_RFA_KV == 0:
|
| 579 |
+
rfa_k_ptrs = (
|
| 580 |
+
RFA_K +
|
| 581 |
+
off_b * stride_rfa_kb +
|
| 582 |
+
off_h * stride_rfa_kh +
|
| 583 |
+
(offs_c[:, None] * stride_rfa_kc + offs_d[None, :])
|
| 584 |
+
)
|
| 585 |
+
rfa_v_ptrs = (
|
| 586 |
+
RFA_V +
|
| 587 |
+
off_b * stride_rfa_vb +
|
| 588 |
+
off_h * stride_rfa_vh +
|
| 589 |
+
(offs_c[:, None] * stride_rfa_vc + offs_d[None, :])
|
| 590 |
+
)
|
| 591 |
+
dq_ptrs = (
|
| 592 |
+
DQ +
|
| 593 |
+
off_b * stride_dq_b +
|
| 594 |
+
off_h * stride_dq_h +
|
| 595 |
+
(offs_m[:, None] * stride_dq_m + offs_d[None, :])
|
| 596 |
+
)
|
| 597 |
+
do_ptrs = (
|
| 598 |
+
DO +
|
| 599 |
+
off_b * stride_do_b +
|
| 600 |
+
off_h * stride_do_h +
|
| 601 |
+
(offs_m[:, None] * stride_do_m + offs_d[None, :])
|
| 602 |
+
)
|
| 603 |
+
do_t_o_ptrs = (
|
| 604 |
+
DO_T_O +
|
| 605 |
+
off_b * stride_do_t_o_b +
|
| 606 |
+
off_h * stride_do_t_o_h +
|
| 607 |
+
offs_m[:, None]
|
| 608 |
+
)
|
| 609 |
+
lse_ptrs = (
|
| 610 |
+
LSE +
|
| 611 |
+
off_b * stride_lse_b +
|
| 612 |
+
off_h * stride_lse_h +
|
| 613 |
+
offs_m[:, None]
|
| 614 |
+
)
|
| 615 |
+
### load q, do, do_t_o, lse ####
|
| 616 |
+
if EVEN_M:
|
| 617 |
+
if EVEN_HEADDIM:
|
| 618 |
+
q = tl.load(
|
| 619 |
+
q_ptrs
|
| 620 |
+
)
|
| 621 |
+
do = tl.load(
|
| 622 |
+
do_ptrs
|
| 623 |
+
)
|
| 624 |
+
else:
|
| 625 |
+
q = tl.load(
|
| 626 |
+
q_ptrs,
|
| 627 |
+
mask=offs_d[None, :] < headdim,
|
| 628 |
+
other=0.0
|
| 629 |
+
)
|
| 630 |
+
do = tl.load(
|
| 631 |
+
do_ptrs,
|
| 632 |
+
mask=offs_d[None, :] < headdim,
|
| 633 |
+
other=0.0
|
| 634 |
+
)
|
| 635 |
+
do_t_o = tl.load(
|
| 636 |
+
do_t_o_ptrs
|
| 637 |
+
)
|
| 638 |
+
lse = tl.load(
|
| 639 |
+
lse_ptrs
|
| 640 |
+
)
|
| 641 |
+
else:
|
| 642 |
+
if EVEN_HEADDIM:
|
| 643 |
+
q = tl.load(
|
| 644 |
+
q_ptrs,
|
| 645 |
+
mask=offs_m[:, None] < seqlen_q,
|
| 646 |
+
other=0.0
|
| 647 |
+
)
|
| 648 |
+
do = tl.load(
|
| 649 |
+
do_ptrs,
|
| 650 |
+
mask=offs_m[:, None] < seqlen_q,
|
| 651 |
+
other=0.0
|
| 652 |
+
)
|
| 653 |
+
else:
|
| 654 |
+
q = tl.load(
|
| 655 |
+
q_ptrs,
|
| 656 |
+
mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
|
| 657 |
+
other=0.0
|
| 658 |
+
)
|
| 659 |
+
do = tl.load(
|
| 660 |
+
do_ptrs,
|
| 661 |
+
mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
|
| 662 |
+
other=0.0
|
| 663 |
+
)
|
| 664 |
+
do_t_o = tl.load(
|
| 665 |
+
do_t_o_ptrs,
|
| 666 |
+
mask=offs_m[:, None] < seqlen_q,
|
| 667 |
+
other=0.0
|
| 668 |
+
)
|
| 669 |
+
lse = tl.load(
|
| 670 |
+
lse_ptrs,
|
| 671 |
+
mask=offs_m[:, None] < seqlen_q,
|
| 672 |
+
other=0.0
|
| 673 |
+
)
|
| 674 |
+
lse = tl.where(lse == float("-inf"), 0.0, lse)
|
| 675 |
+
lse *= 1.4426950408889634 # log2(e)
|
| 676 |
+
qk_scale = softmax_scale
|
| 677 |
+
qk_scale *= 1.4426950408889634 # log2(e)
|
| 678 |
+
if MASK_TYPE == 1:
|
| 679 |
+
window_mask_ptrs = (
|
| 680 |
+
WindowMask +
|
| 681 |
+
off_b * stride_window_mask_b +
|
| 682 |
+
(offs_m[:, None] * stride_window_mask_m + offs_n[None, :])
|
| 683 |
+
)
|
| 684 |
+
if EMPTY_RFA_KV == 0:
|
| 685 |
+
chunk_mask_ptrs = (
|
| 686 |
+
ChunkMask +
|
| 687 |
+
off_b * stride_chunk_mask_b +
|
| 688 |
+
(offs_m[:, None] * stride_chunk_mask_m + offs_c[None, :])
|
| 689 |
+
)
|
| 690 |
+
|
| 691 |
+
dq = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
|
| 692 |
+
# loop over k, v and update accumulator
|
| 693 |
+
# Iterate over local singletons;
|
| 694 |
+
# so we only iterate over blocks within the current window
|
| 695 |
+
start_idx_n = offs_w * WINDOW_SIZE
|
| 696 |
+
end_idx_n = tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
|
| 697 |
+
for start_n in range(start_idx_n, end_idx_n, BLOCK_N):
|
| 698 |
+
start_n = tl.multiple_of(start_n, BLOCK_N)
|
| 699 |
+
if EVEN_N & EVEN_M:
|
| 700 |
+
if EVEN_HEADDIM:
|
| 701 |
+
k = tl.load(
|
| 702 |
+
k_ptrs + start_n * stride_kn
|
| 703 |
+
)
|
| 704 |
+
else:
|
| 705 |
+
k = tl.load(
|
| 706 |
+
k_ptrs + start_n * stride_kn,
|
| 707 |
+
mask=offs_d[None, :] < headdim,
|
| 708 |
+
other=0.0
|
| 709 |
+
)
|
| 710 |
+
else:
|
| 711 |
+
if EVEN_HEADDIM:
|
| 712 |
+
k = tl.load(
|
| 713 |
+
k_ptrs + start_n * stride_kn,
|
| 714 |
+
mask=(start_n + offs_n)[:, None] < seqlen_k,
|
| 715 |
+
other=0.0,
|
| 716 |
+
)
|
| 717 |
+
else:
|
| 718 |
+
k = tl.load(
|
| 719 |
+
k_ptrs + start_n * stride_kn,
|
| 720 |
+
mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
|
| 721 |
+
other=0.0,
|
| 722 |
+
)
|
| 723 |
+
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
| 724 |
+
qk += tl.dot(q, tl.trans(k))
|
| 725 |
+
# Trying to combine the two masks seem to make the result wrong
|
| 726 |
+
if not EVEN_N: # Need to mask out otherwise the softmax is wrong
|
| 727 |
+
qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
|
| 728 |
+
|
| 729 |
+
if MASK_TYPE == 1:
|
| 730 |
+
if EVEN_M & EVEN_W:
|
| 731 |
+
window_mask = tl.load(
|
| 732 |
+
window_mask_ptrs + start_n - start_idx_n
|
| 733 |
+
)
|
| 734 |
+
else:
|
| 735 |
+
window_mask = tl.load(
|
| 736 |
+
window_mask_ptrs + start_n - start_idx_n,
|
| 737 |
+
mask=(offs_m[:, None] < seqlen_q)
|
| 738 |
+
& ((start_n - start_idx_n + offs_n)[None, :] < WINDOW_SIZE),
|
| 739 |
+
other=1,
|
| 740 |
+
)
|
| 741 |
+
# Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
|
| 742 |
+
# can then fuse the mult and add into an fma instruction. But if we have bias we need to
|
| 743 |
+
# to multiply with softmax_scale here.
|
| 744 |
+
# we assume mask already implies the causal masking
|
| 745 |
+
qk = qk * qk_scale
|
| 746 |
+
qk = tl.where(window_mask, float("-inf"), qk)
|
| 747 |
+
p = tl.exp2(qk - lse)
|
| 748 |
+
else:
|
| 749 |
+
qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
|
| 750 |
+
p = tl.exp2(qk * qk_scale - lse)
|
| 751 |
+
|
| 752 |
+
if EVEN_N & EVEN_M:
|
| 753 |
+
if EVEN_HEADDIM:
|
| 754 |
+
v = tl.load(
|
| 755 |
+
v_ptrs + start_n * stride_vn
|
| 756 |
+
)
|
| 757 |
+
else:
|
| 758 |
+
v = tl.load(
|
| 759 |
+
v_ptrs + start_n * stride_vn,
|
| 760 |
+
mask=offs_d[None, :] < headdim,
|
| 761 |
+
other=0.0
|
| 762 |
+
)
|
| 763 |
+
else:
|
| 764 |
+
if EVEN_HEADDIM:
|
| 765 |
+
v = tl.load(
|
| 766 |
+
v_ptrs + start_n * stride_vn,
|
| 767 |
+
mask=(start_n + offs_n)[:, None] < seqlen_k,
|
| 768 |
+
other=0.0,
|
| 769 |
+
)
|
| 770 |
+
else:
|
| 771 |
+
v = tl.load(
|
| 772 |
+
v_ptrs + start_n * stride_vn,
|
| 773 |
+
mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
|
| 774 |
+
other=0.0,
|
| 775 |
+
)
|
| 776 |
+
dp = tl.dot(do, tl.trans(v))
|
| 777 |
+
ds = (p * (dp - do_t_o) * softmax_scale).to(q.dtype)
|
| 778 |
+
dq += tl.dot(ds, k)
|
| 779 |
+
|
| 780 |
+
if EMPTY_RFA_KV == 0:
|
| 781 |
+
# Iterate over RFA chunks
|
| 782 |
+
# we only iterate over chunks before the current local singleton window
|
| 783 |
+
end_idx_c = tl.minimum(offs_w * CHUNKS_PER_WINDOW, nchunks)
|
| 784 |
+
for start_c in range(0, end_idx_c, BLOCK_N):
|
| 785 |
+
start_c = tl.multiple_of(start_c, BLOCK_N)
|
| 786 |
+
# -- compute qk ----
|
| 787 |
+
if EVEN_C & EVEN_M:
|
| 788 |
+
if EVEN_HEADDIM:
|
| 789 |
+
rfa_k = tl.load(
|
| 790 |
+
rfa_k_ptrs + start_c * stride_rfa_kc
|
| 791 |
+
)
|
| 792 |
+
else:
|
| 793 |
+
rfa_k = tl.load(
|
| 794 |
+
rfa_k_ptrs + start_c * stride_rfa_kc,
|
| 795 |
+
mask=offs_d[None, :] < headdim,
|
| 796 |
+
other=0.0
|
| 797 |
+
)
|
| 798 |
+
else:
|
| 799 |
+
if EVEN_HEADDIM:
|
| 800 |
+
rfa_k = tl.load(
|
| 801 |
+
rfa_k_ptrs + start_c * stride_rfa_kc,
|
| 802 |
+
mask=(start_c + offs_c)[:, None] < nchunks,
|
| 803 |
+
other=0.0,
|
| 804 |
+
)
|
| 805 |
+
else:
|
| 806 |
+
rfa_k = tl.load(
|
| 807 |
+
rfa_k_ptrs + start_c * stride_rfa_kc,
|
| 808 |
+
mask=((start_c + offs_c)[:, None] < nchunks) & (offs_d[None, :] < headdim),
|
| 809 |
+
other=0.0,
|
| 810 |
+
)
|
| 811 |
+
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
| 812 |
+
qk += tl.dot(q, tl.trans(rfa_k))
|
| 813 |
+
# Trying to combine the two masks seem to make the result wrong
|
| 814 |
+
if not EVEN_C: # Need to mask out otherwise the softmax is wrong
|
| 815 |
+
qk += tl.where((start_c + offs_c)[None, :] < nchunks, 0, float("-inf"))
|
| 816 |
+
|
| 817 |
+
if MASK_TYPE == 1:
|
| 818 |
+
if EVEN_C & EVEN_M:
|
| 819 |
+
chunk_mask = tl.load(
|
| 820 |
+
chunk_mask_ptrs + start_c
|
| 821 |
+
)
|
| 822 |
+
else:
|
| 823 |
+
chunk_mask = tl.load(
|
| 824 |
+
chunk_mask_ptrs + start_c,
|
| 825 |
+
mask=(offs_m[:, None] < seqlen_q) & ((start_c + offs_c)[None, :] < nchunks),
|
| 826 |
+
other=1,
|
| 827 |
+
)
|
| 828 |
+
# Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
|
| 829 |
+
# can then fuse the mult and add into an fma instruction. But if we have bias we need to
|
| 830 |
+
# to multiply with softmax_scale here.
|
| 831 |
+
# we assume mask already implies the causal masking
|
| 832 |
+
qk = qk * qk_scale
|
| 833 |
+
qk = tl.where(chunk_mask, float("-inf"), qk)
|
| 834 |
+
p = tl.exp2(qk - lse)
|
| 835 |
+
else:
|
| 836 |
+
p = tl.exp2(qk * qk_scale - lse)
|
| 837 |
+
|
| 838 |
+
if EVEN_C & EVEN_M:
|
| 839 |
+
if EVEN_HEADDIM:
|
| 840 |
+
rfa_v = tl.load(
|
| 841 |
+
rfa_v_ptrs + start_c * stride_rfa_vc
|
| 842 |
+
)
|
| 843 |
+
else:
|
| 844 |
+
rfa_v = tl.load(
|
| 845 |
+
rfa_v_ptrs + start_c * stride_rfa_vc,
|
| 846 |
+
mask=offs_d[None, :] < headdim,
|
| 847 |
+
other=0.0
|
| 848 |
+
)
|
| 849 |
+
else:
|
| 850 |
+
if EVEN_HEADDIM:
|
| 851 |
+
rfa_v = tl.load(
|
| 852 |
+
rfa_v_ptrs + start_c * stride_rfa_vc,
|
| 853 |
+
mask=(start_c + offs_n)[:, None] < nchunks,
|
| 854 |
+
other=0.0,
|
| 855 |
+
)
|
| 856 |
+
else:
|
| 857 |
+
rfa_v = tl.load(
|
| 858 |
+
rfa_v_ptrs + start_c * stride_rfa_vc,
|
| 859 |
+
mask=((start_c + offs_n)[:, None] < nchunks) & (offs_d[None, :] < headdim),
|
| 860 |
+
other=0.0,
|
| 861 |
+
)
|
| 862 |
+
dp = tl.dot(do, tl.trans(rfa_v))
|
| 863 |
+
ds = (p * (dp - do_t_o) * softmax_scale).to(q.dtype)
|
| 864 |
+
dq += tl.dot(ds, rfa_k)
|
| 865 |
+
|
| 866 |
+
start_m = tl.program_id(0)
|
| 867 |
+
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 868 |
+
offs_d = tl.arange(0, BLOCK_HEADDIM)
|
| 869 |
+
dq_ptrs = (
|
| 870 |
+
DQ +
|
| 871 |
+
off_b * stride_dq_b +
|
| 872 |
+
off_h * stride_dq_h +
|
| 873 |
+
(offs_m[:, None] * stride_dq_m + offs_d[None, :])
|
| 874 |
+
)
|
| 875 |
+
if EVEN_M:
|
| 876 |
+
if EVEN_HEADDIM:
|
| 877 |
+
tl.store(
|
| 878 |
+
dq_ptrs, dq
|
| 879 |
+
)
|
| 880 |
+
else:
|
| 881 |
+
tl.store(
|
| 882 |
+
dq_ptrs, dq,
|
| 883 |
+
mask=offs_d[None, :] < headdim
|
| 884 |
+
)
|
| 885 |
+
else:
|
| 886 |
+
if EVEN_HEADDIM:
|
| 887 |
+
tl.store(
|
| 888 |
+
dq_ptrs, dq,
|
| 889 |
+
mask=offs_m[:, None] < seqlen_q
|
| 890 |
+
)
|
| 891 |
+
else:
|
| 892 |
+
tl.store(
|
| 893 |
+
dq_ptrs, dq,
|
| 894 |
+
mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
|
| 895 |
+
)
|
| 896 |
+
|
| 897 |
+
_capability_90_config = {
|
| 898 |
+
"fwd": {
|
| 899 |
+
(torch.bfloat16, 64): (128, 128, 4, 3),
|
| 900 |
+
(torch.bfloat16, 128): (128, 128, 8, 3),
|
| 901 |
+
(torch.float32, 64): (128, 64, 8, 3),
|
| 902 |
+
(torch.float32, 128): (64, 32, 4, 3),
|
| 903 |
+
},
|
| 904 |
+
"bwd_dq": {
|
| 905 |
+
(torch.bfloat16, 64): (128, 64, 4, 3),
|
| 906 |
+
(torch.bfloat16, 128): (128, 64, 8, 3),
|
| 907 |
+
(torch.float32, 64): (128, 64, 8, 2),
|
| 908 |
+
(torch.float32, 128): (32, 32, 4, 2),
|
| 909 |
+
},
|
| 910 |
+
"bwd_dkdv": {
|
| 911 |
+
(torch.bfloat16, 64): (128, 64, 4, 2),
|
| 912 |
+
(torch.bfloat16, 128): (128, 64, 8, 2),
|
| 913 |
+
(torch.float32, 64): (128, 64, 8, 2),
|
| 914 |
+
(torch.float32, 128): (32, 32, 4, 1),
|
| 915 |
+
},
|
| 916 |
+
"bwd_drfa_kv": {
|
| 917 |
+
(torch.bfloat16, 64): (128, 64, 4, 2),
|
| 918 |
+
(torch.bfloat16, 128): (128, 64, 8, 2),
|
| 919 |
+
(torch.float32, 64): (128, 64, 8, 2),
|
| 920 |
+
(torch.float32, 128): (32, 32, 4, 1),
|
| 921 |
+
}
|
| 922 |
+
}
|
| 923 |
+
|
| 924 |
+
_capability_80_config = {
|
| 925 |
+
"fwd": {
|
| 926 |
+
(torch.bfloat16, 64): (64, 64, 4, 3),
|
| 927 |
+
(torch.bfloat16, 128): (64, 64, 8, 3),
|
| 928 |
+
(torch.float32, 64): (64, 32, 4, 2),
|
| 929 |
+
(torch.float32, 128): (64, 32, 8, 1),
|
| 930 |
+
},
|
| 931 |
+
"bwd_dq": {
|
| 932 |
+
(torch.bfloat16, 64): (64, 64, 4, 3),
|
| 933 |
+
(torch.bfloat16, 128): (64, 32, 4, 2),
|
| 934 |
+
(torch.float32, 64): (32, 32, 4, 2),
|
| 935 |
+
(torch.float32, 128): (32, 32, 4, 2),
|
| 936 |
+
},
|
| 937 |
+
"bwd_dkdv": {
|
| 938 |
+
(torch.bfloat16, 64): (64, 64, 4, 3),
|
| 939 |
+
(torch.bfloat16, 128): (32, 32, 4, 2),
|
| 940 |
+
(torch.float32, 64): (32, 32, 4, 1),
|
| 941 |
+
(torch.float32, 128): (16, 64, 8, 1),
|
| 942 |
+
},
|
| 943 |
+
"bwd_drfa_kv": {
|
| 944 |
+
(torch.bfloat16, 64): (64, 64, 4, 3),
|
| 945 |
+
(torch.bfloat16, 128): (64, 32, 4, 3),
|
| 946 |
+
(torch.float32, 64): (32, 32, 4, 1),
|
| 947 |
+
(torch.float32, 128): (32, 32, 4, 1),
|
| 948 |
+
}
|
| 949 |
+
}
|
| 950 |
+
|
| 951 |
+
def _get_config(dtype, head_dim, mode) -> tuple[int, int, int, int]:
|
| 952 |
+
capability = torch.cuda.get_device_capability()
|
| 953 |
+
if capability >= (9, 0):
|
| 954 |
+
kernel_config = _capability_90_config[mode].get((dtype, head_dim), (32, 32, 4, 1))
|
| 955 |
+
elif capability >= (8, 0):
|
| 956 |
+
kernel_config = _capability_80_config[mode].get((dtype, head_dim), (16, 16, 4, 1))
|
| 957 |
+
else:
|
| 958 |
+
if mode == "fwd":
|
| 959 |
+
if dtype == torch.float32:
|
| 960 |
+
kernel_config = (32, 16, 4, 2)
|
| 961 |
+
else:
|
| 962 |
+
kernel_config = (64, 32, 4, 2)
|
| 963 |
+
else:
|
| 964 |
+
if dtype == torch.float32:
|
| 965 |
+
kernel_config = (16, 16, 4, 1)
|
| 966 |
+
else:
|
| 967 |
+
kernel_config = (32, 32, 4, 1)
|
| 968 |
+
return kernel_config
|
| 969 |
+
|
| 970 |
+
@triton.heuristics(
|
| 971 |
+
{
|
| 972 |
+
"EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
|
| 973 |
+
"EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
|
| 974 |
+
"EVEN_C": lambda args: args["nchunks"] % args["BLOCK_N"] == 0,
|
| 975 |
+
"EVEN_W": lambda args: args["WINDOW_SIZE"] % args["BLOCK_N"] == 0,
|
| 976 |
+
"EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
|
| 977 |
+
}
|
| 978 |
+
)
|
| 979 |
+
@triton.jit
|
| 980 |
+
def _fwd_eva_agg_kernel(
|
| 981 |
+
Q,
|
| 982 |
+
K,
|
| 983 |
+
V,
|
| 984 |
+
RFA_K,
|
| 985 |
+
RFA_V,
|
| 986 |
+
WindowMask,
|
| 987 |
+
ChunkMask,
|
| 988 |
+
Out,
|
| 989 |
+
LSE,
|
| 990 |
+
softmax_scale,
|
| 991 |
+
stride_qb, stride_qh, stride_qm,
|
| 992 |
+
stride_kb, stride_kh, stride_kn,
|
| 993 |
+
stride_vb, stride_vh, stride_vn,
|
| 994 |
+
stride_rfa_kb, stride_rfa_kh, stride_rfa_kc,
|
| 995 |
+
stride_rfa_vb, stride_rfa_vh, stride_rfa_vc,
|
| 996 |
+
stride_window_mask_b, stride_window_mask_m,
|
| 997 |
+
stride_chunk_mask_b, stride_chunk_mask_m,
|
| 998 |
+
stride_ob, stride_oh, stride_om,
|
| 999 |
+
stride_lse_b, stride_lse_h,
|
| 1000 |
+
nheads,
|
| 1001 |
+
seqlen_q,
|
| 1002 |
+
seqlen_k,
|
| 1003 |
+
nchunks,
|
| 1004 |
+
headdim,
|
| 1005 |
+
CHUNKS_PER_WINDOW: tl.constexpr,
|
| 1006 |
+
WINDOW_SIZE: tl.constexpr,
|
| 1007 |
+
MASK_TYPE: tl.constexpr,
|
| 1008 |
+
EMPTY_RFA_KV: tl.constexpr,
|
| 1009 |
+
BLOCK_HEADDIM: tl.constexpr,
|
| 1010 |
+
EVEN_M: tl.constexpr,
|
| 1011 |
+
EVEN_N: tl.constexpr,
|
| 1012 |
+
EVEN_W: tl.constexpr,
|
| 1013 |
+
EVEN_C: tl.constexpr,
|
| 1014 |
+
EVEN_HEADDIM: tl.constexpr,
|
| 1015 |
+
BLOCK_M: tl.constexpr,
|
| 1016 |
+
BLOCK_N: tl.constexpr,
|
| 1017 |
+
):
|
| 1018 |
+
start_m = tl.program_id(0)
|
| 1019 |
+
off_bh = tl.program_id(1)
|
| 1020 |
+
off_h = off_bh % nheads
|
| 1021 |
+
off_b = off_bh // nheads
|
| 1022 |
+
# initialize offsets
|
| 1023 |
+
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 1024 |
+
offs_w = (start_m * BLOCK_M) // WINDOW_SIZE
|
| 1025 |
+
offs_n = tl.arange(0, BLOCK_N)
|
| 1026 |
+
offs_c = tl.arange(0, BLOCK_N)
|
| 1027 |
+
offs_d = tl.arange(0, BLOCK_HEADDIM)
|
| 1028 |
+
# TODO: add paratheses or not
|
| 1029 |
+
q_ptrs = (
|
| 1030 |
+
Q +
|
| 1031 |
+
off_b * stride_qb +
|
| 1032 |
+
off_h * stride_qh +
|
| 1033 |
+
(offs_m[:, None] * stride_qm + offs_d[None, :])
|
| 1034 |
+
)
|
| 1035 |
+
k_ptrs = (
|
| 1036 |
+
K +
|
| 1037 |
+
off_b * stride_kb +
|
| 1038 |
+
off_h * stride_kh +
|
| 1039 |
+
(offs_n[:, None] * stride_kn + offs_d[None, :])
|
| 1040 |
+
)
|
| 1041 |
+
v_ptrs = (
|
| 1042 |
+
V +
|
| 1043 |
+
off_b * stride_vb +
|
| 1044 |
+
off_h * stride_vh +
|
| 1045 |
+
(offs_n[:, None] * stride_vn + offs_d[None, :])
|
| 1046 |
+
)
|
| 1047 |
+
if EMPTY_RFA_KV == 0:
|
| 1048 |
+
rfa_k_ptrs = (
|
| 1049 |
+
RFA_K +
|
| 1050 |
+
off_b * stride_rfa_kb +
|
| 1051 |
+
off_h * stride_rfa_kh +
|
| 1052 |
+
(offs_c[:, None] * stride_rfa_kc + offs_d[None, :])
|
| 1053 |
+
)
|
| 1054 |
+
rfa_v_ptrs = (
|
| 1055 |
+
RFA_V +
|
| 1056 |
+
off_b * stride_rfa_vb +
|
| 1057 |
+
off_h * stride_rfa_vh +
|
| 1058 |
+
(offs_c[:, None] * stride_rfa_vc + offs_d[None, :])
|
| 1059 |
+
)
|
| 1060 |
+
|
| 1061 |
+
qk_scale = softmax_scale
|
| 1062 |
+
qk_scale *= 1.4426950408889634 # log2(e)
|
| 1063 |
+
if MASK_TYPE == 1:
|
| 1064 |
+
window_mask_ptrs = (
|
| 1065 |
+
WindowMask +
|
| 1066 |
+
off_b * stride_window_mask_b +
|
| 1067 |
+
(offs_m[:, None] * stride_window_mask_m + offs_n[None, :])
|
| 1068 |
+
)
|
| 1069 |
+
if EMPTY_RFA_KV == 0:
|
| 1070 |
+
chunk_mask_ptrs = (
|
| 1071 |
+
ChunkMask +
|
| 1072 |
+
off_b * stride_chunk_mask_b +
|
| 1073 |
+
(offs_m[:, None] * stride_chunk_mask_m + offs_c[None, :])
|
| 1074 |
+
)
|
| 1075 |
+
|
| 1076 |
+
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
| 1077 |
+
d_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
| 1078 |
+
acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
|
| 1079 |
+
# load q: it will stay in SRAM throughout
|
| 1080 |
+
# [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
|
| 1081 |
+
# tl.load(q_ptrs), we get the wrong output!
|
| 1082 |
+
if EVEN_M & EVEN_N:
|
| 1083 |
+
if EVEN_HEADDIM:
|
| 1084 |
+
q = tl.load(
|
| 1085 |
+
q_ptrs
|
| 1086 |
+
)
|
| 1087 |
+
else:
|
| 1088 |
+
q = tl.load(
|
| 1089 |
+
q_ptrs,
|
| 1090 |
+
mask=offs_d[None, :] < headdim,
|
| 1091 |
+
other=0.0
|
| 1092 |
+
)
|
| 1093 |
+
else:
|
| 1094 |
+
if EVEN_HEADDIM:
|
| 1095 |
+
q = tl.load(
|
| 1096 |
+
q_ptrs,
|
| 1097 |
+
mask=offs_m[:, None] < seqlen_q,
|
| 1098 |
+
other=0.0
|
| 1099 |
+
)
|
| 1100 |
+
else:
|
| 1101 |
+
q = tl.load(
|
| 1102 |
+
q_ptrs,
|
| 1103 |
+
mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
|
| 1104 |
+
other=0.0
|
| 1105 |
+
)
|
| 1106 |
+
# loop over k, v and update accumulator
|
| 1107 |
+
# Iterate over local singletons;
|
| 1108 |
+
# so we only iterate over blocks within the current window
|
| 1109 |
+
start_idx_n = offs_w * WINDOW_SIZE
|
| 1110 |
+
end_idx_n = tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
|
| 1111 |
+
for start_n in range(start_idx_n, end_idx_n, BLOCK_N):
|
| 1112 |
+
start_n = tl.multiple_of(start_n, BLOCK_N)
|
| 1113 |
+
# -- compute qk ----
|
| 1114 |
+
if EVEN_N & EVEN_M:
|
| 1115 |
+
if EVEN_HEADDIM:
|
| 1116 |
+
k = tl.load(
|
| 1117 |
+
k_ptrs + start_n * stride_kn
|
| 1118 |
+
)
|
| 1119 |
+
else:
|
| 1120 |
+
k = tl.load(
|
| 1121 |
+
k_ptrs + start_n * stride_kn,
|
| 1122 |
+
mask=offs_d[None, :] < headdim,
|
| 1123 |
+
other=0.0
|
| 1124 |
+
)
|
| 1125 |
+
else:
|
| 1126 |
+
if EVEN_HEADDIM:
|
| 1127 |
+
k = tl.load(
|
| 1128 |
+
k_ptrs + start_n * stride_kn,
|
| 1129 |
+
mask=(start_n + offs_n)[:, None] < seqlen_k,
|
| 1130 |
+
other=0.0,
|
| 1131 |
+
)
|
| 1132 |
+
else:
|
| 1133 |
+
k = tl.load(
|
| 1134 |
+
k_ptrs + start_n * stride_kn,
|
| 1135 |
+
mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
|
| 1136 |
+
other=0.0,
|
| 1137 |
+
)
|
| 1138 |
+
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
| 1139 |
+
qk += tl.dot(q, tl.trans(k))
|
| 1140 |
+
# Trying to combine the two masks seem to make the result wrong
|
| 1141 |
+
if not EVEN_N: # Need to mask out otherwise the softmax is wrong
|
| 1142 |
+
qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
|
| 1143 |
+
|
| 1144 |
+
if MASK_TYPE == 1:
|
| 1145 |
+
if EVEN_M & EVEN_W:
|
| 1146 |
+
window_mask = tl.load(
|
| 1147 |
+
window_mask_ptrs + start_n - start_idx_n
|
| 1148 |
+
)
|
| 1149 |
+
else:
|
| 1150 |
+
window_mask = tl.load(
|
| 1151 |
+
window_mask_ptrs + start_n - start_idx_n,
|
| 1152 |
+
mask=(offs_m[:, None] < seqlen_q)
|
| 1153 |
+
& ((start_n - start_idx_n + offs_n)[None, :] < WINDOW_SIZE),
|
| 1154 |
+
other=1,
|
| 1155 |
+
)
|
| 1156 |
+
# Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
|
| 1157 |
+
# can then fuse the mult and add into an fma instruction. But if we have bias we need to
|
| 1158 |
+
# to multiply with softmax_scale here.
|
| 1159 |
+
# we assume mask already implies the causal masking
|
| 1160 |
+
qk = qk * qk_scale
|
| 1161 |
+
qk = tl.where(window_mask, float("-inf"), qk)
|
| 1162 |
+
m_ij = tl.maximum(tl.max(qk, 1), m_i)
|
| 1163 |
+
masked_out_rows = (m_ij == float("-inf"))
|
| 1164 |
+
m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
|
| 1165 |
+
p = tl.exp2(qk - m_ij_masked[:, None])
|
| 1166 |
+
else:
|
| 1167 |
+
qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
|
| 1168 |
+
m_ij = tl.maximum(tl.max(qk, 1) * qk_scale, m_i)
|
| 1169 |
+
masked_out_rows = (m_ij == float("-inf"))
|
| 1170 |
+
m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
|
| 1171 |
+
p = tl.exp2(qk * qk_scale - m_ij_masked[:, None])
|
| 1172 |
+
|
| 1173 |
+
d_ij = tl.sum(p, 1)
|
| 1174 |
+
|
| 1175 |
+
# scale acc_o
|
| 1176 |
+
prev_scale = tl.exp2(m_i - m_ij_masked)
|
| 1177 |
+
# # -- update output accumulator --
|
| 1178 |
+
acc_o = acc_o * prev_scale[:, None]
|
| 1179 |
+
# update acc_o
|
| 1180 |
+
if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition
|
| 1181 |
+
if EVEN_HEADDIM:
|
| 1182 |
+
v = tl.load(
|
| 1183 |
+
v_ptrs + start_n * stride_vn
|
| 1184 |
+
)
|
| 1185 |
+
else:
|
| 1186 |
+
v = tl.load(
|
| 1187 |
+
v_ptrs + start_n * stride_vn,
|
| 1188 |
+
mask=offs_d[None, :] < headdim,
|
| 1189 |
+
other=0.0
|
| 1190 |
+
)
|
| 1191 |
+
else:
|
| 1192 |
+
if EVEN_HEADDIM:
|
| 1193 |
+
v = tl.load(
|
| 1194 |
+
v_ptrs + start_n * stride_vn,
|
| 1195 |
+
mask=(start_n + offs_n)[:, None] < seqlen_k,
|
| 1196 |
+
other=0.0,
|
| 1197 |
+
)
|
| 1198 |
+
else:
|
| 1199 |
+
v = tl.load(
|
| 1200 |
+
v_ptrs + start_n * stride_vn,
|
| 1201 |
+
mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
|
| 1202 |
+
other=0.0,
|
| 1203 |
+
)
|
| 1204 |
+
p = p.to(v.dtype)
|
| 1205 |
+
acc_o = tl.dot(p, v, acc_o)
|
| 1206 |
+
|
| 1207 |
+
# -- update statistics
|
| 1208 |
+
d_i = d_i * prev_scale + d_ij
|
| 1209 |
+
m_i = m_ij
|
| 1210 |
+
|
| 1211 |
+
if EMPTY_RFA_KV == 0:
|
| 1212 |
+
# Iterate over RFA chunks
|
| 1213 |
+
# we only iterate over chunks before the current local singleton window
|
| 1214 |
+
end_idx_c = tl.minimum(offs_w * CHUNKS_PER_WINDOW, nchunks)
|
| 1215 |
+
for start_c in range(0, end_idx_c, BLOCK_N):
|
| 1216 |
+
start_c = tl.multiple_of(start_c, BLOCK_N)
|
| 1217 |
+
# -- compute qk ----
|
| 1218 |
+
if EVEN_C & EVEN_M:
|
| 1219 |
+
if EVEN_HEADDIM:
|
| 1220 |
+
rfa_k = tl.load(
|
| 1221 |
+
rfa_k_ptrs + start_c * stride_rfa_kc
|
| 1222 |
+
)
|
| 1223 |
+
else:
|
| 1224 |
+
rfa_k = tl.load(
|
| 1225 |
+
rfa_k_ptrs + start_c * stride_rfa_kc,
|
| 1226 |
+
mask=offs_d[None, :] < headdim,
|
| 1227 |
+
other=0.0
|
| 1228 |
+
)
|
| 1229 |
+
else:
|
| 1230 |
+
if EVEN_HEADDIM:
|
| 1231 |
+
rfa_k = tl.load(
|
| 1232 |
+
rfa_k_ptrs + start_c * stride_rfa_kc,
|
| 1233 |
+
mask=(start_c + offs_c)[:, None] < nchunks,
|
| 1234 |
+
other=0.0,
|
| 1235 |
+
)
|
| 1236 |
+
else:
|
| 1237 |
+
rfa_k = tl.load(
|
| 1238 |
+
rfa_k_ptrs + start_c * stride_rfa_kc,
|
| 1239 |
+
mask=((start_c + offs_c)[:, None] < nchunks) & (offs_d[None, :] < headdim),
|
| 1240 |
+
other=0.0,
|
| 1241 |
+
)
|
| 1242 |
+
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
| 1243 |
+
qk += tl.dot(q, tl.trans(rfa_k))
|
| 1244 |
+
# Trying to combine the two masks seem to make the result wrong
|
| 1245 |
+
if not EVEN_C: # Need to mask out otherwise the softmax is wrong
|
| 1246 |
+
qk += tl.where((start_c + offs_c)[None, :] < nchunks, 0, float("-inf"))
|
| 1247 |
+
|
| 1248 |
+
if MASK_TYPE == 1:
|
| 1249 |
+
if EVEN_C & EVEN_M:
|
| 1250 |
+
chunk_mask = tl.load(
|
| 1251 |
+
chunk_mask_ptrs + start_c
|
| 1252 |
+
)
|
| 1253 |
+
else:
|
| 1254 |
+
chunk_mask = tl.load(
|
| 1255 |
+
chunk_mask_ptrs + start_c,
|
| 1256 |
+
mask=(offs_m[:, None] < seqlen_q) & ((start_c + offs_c)[None, :] < nchunks),
|
| 1257 |
+
other=1,
|
| 1258 |
+
)
|
| 1259 |
+
# Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
|
| 1260 |
+
# can then fuse the mult and add into an fma instruction. But if we have bias we need to
|
| 1261 |
+
# to multiply with softmax_scale here.
|
| 1262 |
+
# we assume mask already implies the causal masking
|
| 1263 |
+
qk = qk * qk_scale
|
| 1264 |
+
qk = tl.where(chunk_mask, float("-inf"), qk)
|
| 1265 |
+
m_ij = tl.maximum(tl.max(qk, 1), m_i)
|
| 1266 |
+
masked_out_rows = (m_ij == float("-inf"))
|
| 1267 |
+
m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
|
| 1268 |
+
p = tl.exp2(qk - m_ij_masked[:, None])
|
| 1269 |
+
else:
|
| 1270 |
+
m_ij = tl.maximum(tl.max(qk, 1) * qk_scale, m_i)
|
| 1271 |
+
masked_out_rows = (m_ij == float("-inf"))
|
| 1272 |
+
m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
|
| 1273 |
+
p = tl.exp2(qk * qk_scale - m_ij_masked[:, None])
|
| 1274 |
+
|
| 1275 |
+
d_ij = tl.sum(p, 1)
|
| 1276 |
+
|
| 1277 |
+
# scale acc_o
|
| 1278 |
+
prev_scale = tl.exp2(m_i - m_ij_masked)
|
| 1279 |
+
# # -- update output accumulator --
|
| 1280 |
+
acc_o = acc_o * prev_scale[:, None]
|
| 1281 |
+
# update acc_o
|
| 1282 |
+
# TODO: If we just do "if EVEN_N", there seems to be some race condition ?
|
| 1283 |
+
if EVEN_C & EVEN_M:
|
| 1284 |
+
if EVEN_HEADDIM:
|
| 1285 |
+
rfa_v = tl.load(
|
| 1286 |
+
rfa_v_ptrs + start_c * stride_rfa_vc
|
| 1287 |
+
)
|
| 1288 |
+
else:
|
| 1289 |
+
rfa_v = tl.load(
|
| 1290 |
+
rfa_v_ptrs + start_c * stride_rfa_vc,
|
| 1291 |
+
mask=offs_d[None, :] < headdim,
|
| 1292 |
+
other=0.0
|
| 1293 |
+
)
|
| 1294 |
+
else:
|
| 1295 |
+
if EVEN_HEADDIM:
|
| 1296 |
+
rfa_v = tl.load(
|
| 1297 |
+
rfa_v_ptrs + start_c * stride_rfa_vc,
|
| 1298 |
+
mask=(start_c + offs_n)[:, None] < nchunks,
|
| 1299 |
+
other=0.0,
|
| 1300 |
+
)
|
| 1301 |
+
else:
|
| 1302 |
+
rfa_v = tl.load(
|
| 1303 |
+
rfa_v_ptrs + start_c * stride_rfa_vc,
|
| 1304 |
+
mask=((start_c + offs_n)[:, None] < nchunks) & (offs_d[None, :] < headdim),
|
| 1305 |
+
other=0.0,
|
| 1306 |
+
)
|
| 1307 |
+
p = p.to(rfa_v.dtype)
|
| 1308 |
+
acc_o = tl.dot(p, rfa_v, acc_o)
|
| 1309 |
+
|
| 1310 |
+
# -- update statistics
|
| 1311 |
+
d_i = d_i * prev_scale + d_ij
|
| 1312 |
+
m_i = m_ij
|
| 1313 |
+
|
| 1314 |
+
# for rows that are all -inf, set d_i to 1.0
|
| 1315 |
+
d_i = tl.where(d_i == 0.0, 1.0, d_i)
|
| 1316 |
+
# multiply by log(2)
|
| 1317 |
+
lse_m = (m_i + tl.math.log2(d_i)) * 0.6931471805599453
|
| 1318 |
+
acc_o = acc_o / d_i[:, None]
|
| 1319 |
+
# TODO: understand why rematerialize offsets to save registers?
|
| 1320 |
+
start_m = tl.program_id(0)
|
| 1321 |
+
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 1322 |
+
offs_d = tl.arange(0, BLOCK_HEADDIM)
|
| 1323 |
+
out_ptrs = (
|
| 1324 |
+
Out +
|
| 1325 |
+
off_b * stride_ob +
|
| 1326 |
+
off_h * stride_oh +
|
| 1327 |
+
(offs_m[:, None] * stride_om + offs_d[None, :])
|
| 1328 |
+
)
|
| 1329 |
+
if EVEN_M:
|
| 1330 |
+
if EVEN_HEADDIM:
|
| 1331 |
+
tl.store(
|
| 1332 |
+
out_ptrs, acc_o
|
| 1333 |
+
)
|
| 1334 |
+
else:
|
| 1335 |
+
tl.store(
|
| 1336 |
+
out_ptrs, acc_o,
|
| 1337 |
+
mask=offs_d[None, :] < headdim
|
| 1338 |
+
)
|
| 1339 |
+
else:
|
| 1340 |
+
if EVEN_HEADDIM:
|
| 1341 |
+
tl.store(
|
| 1342 |
+
out_ptrs, acc_o,
|
| 1343 |
+
mask=offs_m[:, None] < seqlen_q
|
| 1344 |
+
)
|
| 1345 |
+
else:
|
| 1346 |
+
tl.store(
|
| 1347 |
+
out_ptrs, acc_o,
|
| 1348 |
+
mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
|
| 1349 |
+
)
|
| 1350 |
+
lse_ptrs = (
|
| 1351 |
+
LSE +
|
| 1352 |
+
off_b * stride_lse_b +
|
| 1353 |
+
off_h * stride_lse_h +
|
| 1354 |
+
offs_m
|
| 1355 |
+
)
|
| 1356 |
+
if EVEN_M:
|
| 1357 |
+
tl.store(
|
| 1358 |
+
lse_ptrs, lse_m,
|
| 1359 |
+
)
|
| 1360 |
+
else:
|
| 1361 |
+
tl.store(
|
| 1362 |
+
lse_ptrs, lse_m,
|
| 1363 |
+
mask=offs_m < seqlen_q
|
| 1364 |
+
)
|
| 1365 |
+
|
| 1366 |
+
def triton_eva_agg_fwd(
|
| 1367 |
+
q, k, v, rfa_k, rfa_v,
|
| 1368 |
+
window_mask,
|
| 1369 |
+
chunk_mask,
|
| 1370 |
+
softmax_scale,
|
| 1371 |
+
window_size,
|
| 1372 |
+
chunks_per_window
|
| 1373 |
+
):
|
| 1374 |
+
if rfa_k is None and rfa_v is None:
|
| 1375 |
+
empty_rfa_kv = 1
|
| 1376 |
+
|
| 1377 |
+
q, k, v = [
|
| 1378 |
+
x if x.stride(-1) == 1 else x.contiguous()
|
| 1379 |
+
for x in [q, k, v]
|
| 1380 |
+
]
|
| 1381 |
+
else:
|
| 1382 |
+
assert rfa_k is not None and rfa_v is not None, "Both rfa_k and rfa_v must either be None or have values at the same time."
|
| 1383 |
+
empty_rfa_kv = 0
|
| 1384 |
+
|
| 1385 |
+
q, k, v, rfa_k, rfa_v = [
|
| 1386 |
+
x if x.stride(-1) == 1 else x.contiguous()
|
| 1387 |
+
for x in [q, k, v, rfa_k, rfa_v]
|
| 1388 |
+
]
|
| 1389 |
+
|
| 1390 |
+
# shape constraints
|
| 1391 |
+
batch, nheads, seqlen_q, head_dim = q.shape
|
| 1392 |
+
_, _, seqlen_k, _ = k.shape
|
| 1393 |
+
if empty_rfa_kv == 0:
|
| 1394 |
+
nchunks = rfa_k.shape[-2]
|
| 1395 |
+
assert rfa_k.shape == (batch, nheads, nchunks, head_dim)
|
| 1396 |
+
assert rfa_v.shape == (batch, nheads, nchunks, head_dim)
|
| 1397 |
+
assert q.dtype == k.dtype == v.dtype == rfa_k.dtype == rfa_v.dtype
|
| 1398 |
+
else:
|
| 1399 |
+
nchunks = 0
|
| 1400 |
+
assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
|
| 1401 |
+
assert k.shape == (batch, nheads, seqlen_k, head_dim)
|
| 1402 |
+
assert v.shape == (batch, nheads, seqlen_k, head_dim)
|
| 1403 |
+
|
| 1404 |
+
assert head_dim <= 128, "We only test head dimensions up to 128"
|
| 1405 |
+
# assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
|
| 1406 |
+
assert q.dtype in [torch.bfloat16, torch.float], "Only support bf16 and fp32 for now"
|
| 1407 |
+
assert q.is_cuda and k.is_cuda and v.is_cuda
|
| 1408 |
+
softmax_scale = softmax_scale or 1.0 / math.sqrt(head_dim)
|
| 1409 |
+
|
| 1410 |
+
mask_type = 0
|
| 1411 |
+
if window_mask is not None:
|
| 1412 |
+
mask_type = 1
|
| 1413 |
+
assert window_mask.dtype == torch.bool
|
| 1414 |
+
assert window_mask.is_cuda
|
| 1415 |
+
assert window_mask.dim() == 4
|
| 1416 |
+
assert window_mask.shape == (batch, 1, seqlen_q, window_size)
|
| 1417 |
+
if window_mask.stride(-1) != 1:
|
| 1418 |
+
window_mask = window_mask.contiguous()
|
| 1419 |
+
|
| 1420 |
+
assert chunk_mask is not None
|
| 1421 |
+
assert chunk_mask.dtype == torch.bool
|
| 1422 |
+
assert chunk_mask.is_cuda
|
| 1423 |
+
assert chunk_mask.dim() == 4
|
| 1424 |
+
assert chunk_mask.shape == (batch, 1, seqlen_q, nchunks)
|
| 1425 |
+
if chunk_mask.stride(-1) != 1:
|
| 1426 |
+
chunk_mask = chunk_mask.contiguous()
|
| 1427 |
+
|
| 1428 |
+
chunk_mask_strides = (
|
| 1429 |
+
(chunk_mask.stride(0), chunk_mask.stride(2))
|
| 1430 |
+
if mask_type == 1 else
|
| 1431 |
+
(0, 0)
|
| 1432 |
+
)
|
| 1433 |
+
window_mask_strides = (
|
| 1434 |
+
(window_mask.stride(0), window_mask.stride(2))
|
| 1435 |
+
if mask_type == 1 else
|
| 1436 |
+
(0, 0)
|
| 1437 |
+
)
|
| 1438 |
+
|
| 1439 |
+
rfa_k_strides = (
|
| 1440 |
+
(rfa_k.stride(0), rfa_k.stride(1), rfa_k.stride(2))
|
| 1441 |
+
if empty_rfa_kv == 0 else
|
| 1442 |
+
(0, 0, 0)
|
| 1443 |
+
)
|
| 1444 |
+
rfa_v_strides = (
|
| 1445 |
+
(rfa_v.stride(0), rfa_v.stride(1), rfa_v.stride(2))
|
| 1446 |
+
if empty_rfa_kv == 0 else
|
| 1447 |
+
(0, 0, 0)
|
| 1448 |
+
)
|
| 1449 |
+
|
| 1450 |
+
o = torch.empty_like(q)
|
| 1451 |
+
lse = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
| 1452 |
+
|
| 1453 |
+
BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16)
|
| 1454 |
+
|
| 1455 |
+
BLOCK_M, BLOCK_N, num_warps, num_stages = _get_config(q.dtype, head_dim, "fwd")
|
| 1456 |
+
|
| 1457 |
+
assert chunks_per_window >= BLOCK_N, "chunks_per_window must be greater than BLOCK"
|
| 1458 |
+
assert chunks_per_window % BLOCK_N == 0, "chunks_per_window must be a multiple of BLOCK_N"
|
| 1459 |
+
|
| 1460 |
+
grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
|
| 1461 |
+
_fwd_eva_agg_kernel[grid](
|
| 1462 |
+
q,
|
| 1463 |
+
k,
|
| 1464 |
+
v,
|
| 1465 |
+
rfa_k,
|
| 1466 |
+
rfa_v,
|
| 1467 |
+
window_mask,
|
| 1468 |
+
chunk_mask,
|
| 1469 |
+
o,
|
| 1470 |
+
lse,
|
| 1471 |
+
softmax_scale,
|
| 1472 |
+
q.stride(0), q.stride(1), q.stride(2),
|
| 1473 |
+
k.stride(0), k.stride(1), k.stride(2),
|
| 1474 |
+
v.stride(0), v.stride(1), v.stride(2),
|
| 1475 |
+
rfa_k_strides[0], rfa_k_strides[1], rfa_k_strides[2],
|
| 1476 |
+
rfa_v_strides[0], rfa_v_strides[1], rfa_v_strides[2],
|
| 1477 |
+
window_mask_strides[0], window_mask_strides[1],
|
| 1478 |
+
chunk_mask_strides[0], chunk_mask_strides[1],
|
| 1479 |
+
o.stride(0), o.stride(1), o.stride(2),
|
| 1480 |
+
lse.stride(0), lse.stride(1),
|
| 1481 |
+
nheads,
|
| 1482 |
+
seqlen_q,
|
| 1483 |
+
seqlen_k,
|
| 1484 |
+
nchunks,
|
| 1485 |
+
head_dim,
|
| 1486 |
+
chunks_per_window,
|
| 1487 |
+
window_size,
|
| 1488 |
+
mask_type,
|
| 1489 |
+
empty_rfa_kv,
|
| 1490 |
+
BLOCK_HEADDIM,
|
| 1491 |
+
BLOCK_M=BLOCK_M,
|
| 1492 |
+
BLOCK_N=BLOCK_N,
|
| 1493 |
+
num_warps=num_warps,
|
| 1494 |
+
num_stages=num_stages,
|
| 1495 |
+
)
|
| 1496 |
+
return o, lse
|
| 1497 |
+
|
| 1498 |
+
def triton_eva_agg_bwd(
|
| 1499 |
+
do,
|
| 1500 |
+
q, k, v, rfa_k, rfa_v,
|
| 1501 |
+
window_mask, chunk_mask,
|
| 1502 |
+
o, lse,
|
| 1503 |
+
dq, dk, dv, d_rfa_k, d_rfa_v,
|
| 1504 |
+
softmax_scale,
|
| 1505 |
+
window_size,
|
| 1506 |
+
chunks_per_window,
|
| 1507 |
+
empty_rfa_kv,
|
| 1508 |
+
mask_type,
|
| 1509 |
+
):
|
| 1510 |
+
if do.stride(-1) != 1:
|
| 1511 |
+
do = do.contiguous()
|
| 1512 |
+
|
| 1513 |
+
# shape constraints
|
| 1514 |
+
batch, nheads, seqlen_q, head_dim = q.shape
|
| 1515 |
+
_, _, seqlen_k, _ = k.shape
|
| 1516 |
+
if empty_rfa_kv == 0:
|
| 1517 |
+
nchunks = rfa_k.shape[-2]
|
| 1518 |
+
assert rfa_k.shape == (batch, nheads, nchunks, head_dim)
|
| 1519 |
+
assert rfa_v.shape == (batch, nheads, nchunks, head_dim)
|
| 1520 |
+
assert d_rfa_k.stride(-1) == d_rfa_v.stride(-1) == 1
|
| 1521 |
+
assert q.dtype == k.dtype == v.dtype == rfa_k.dtype == rfa_v.dtype
|
| 1522 |
+
else:
|
| 1523 |
+
nchunks = 0
|
| 1524 |
+
assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
|
| 1525 |
+
|
| 1526 |
+
assert lse.shape == (batch, nheads, seqlen_q)
|
| 1527 |
+
assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == rfa_k.stride(-1) == rfa_v.stride(-1) == 1
|
| 1528 |
+
assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1
|
| 1529 |
+
softmax_scale = softmax_scale or 1.0 / math.sqrt(head_dim)
|
| 1530 |
+
|
| 1531 |
+
assert head_dim <= 128, "We only test head dimensions up to 128"
|
| 1532 |
+
|
| 1533 |
+
window_mask_strides = (
|
| 1534 |
+
(window_mask.stride(0), window_mask.stride(2))
|
| 1535 |
+
if mask_type == 1 else
|
| 1536 |
+
(0, 0)
|
| 1537 |
+
)
|
| 1538 |
+
chunk_mask_strides = (
|
| 1539 |
+
(chunk_mask.stride(0), chunk_mask.stride(2))
|
| 1540 |
+
if mask_type == 1 else
|
| 1541 |
+
(0, 0)
|
| 1542 |
+
)
|
| 1543 |
+
|
| 1544 |
+
rfa_k_strides = (
|
| 1545 |
+
(rfa_k.stride(0), rfa_k.stride(1), rfa_k.stride(2))
|
| 1546 |
+
if empty_rfa_kv == 0 else
|
| 1547 |
+
(0, 0, 0)
|
| 1548 |
+
)
|
| 1549 |
+
rfa_v_strides = (
|
| 1550 |
+
(rfa_v.stride(0), rfa_v.stride(1), rfa_v.stride(2))
|
| 1551 |
+
if empty_rfa_kv == 0 else
|
| 1552 |
+
(0, 0, 0)
|
| 1553 |
+
)
|
| 1554 |
+
|
| 1555 |
+
d_rfa_k_strides = (
|
| 1556 |
+
(d_rfa_k.stride(0), d_rfa_k.stride(1), d_rfa_k.stride(2))
|
| 1557 |
+
if empty_rfa_kv == 0 else
|
| 1558 |
+
(0, 0, 0)
|
| 1559 |
+
)
|
| 1560 |
+
d_rfa_v_strides = (
|
| 1561 |
+
(d_rfa_v.stride(0), d_rfa_v.stride(1), d_rfa_v.stride(2))
|
| 1562 |
+
if empty_rfa_kv == 0 else
|
| 1563 |
+
(0, 0, 0)
|
| 1564 |
+
)
|
| 1565 |
+
|
| 1566 |
+
BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16)
|
| 1567 |
+
|
| 1568 |
+
do_t_o = torch.sum(do.to(torch.float32) * o.to(torch.float32), dim=-1).to(do.dtype)
|
| 1569 |
+
|
| 1570 |
+
BLOCK_M, BLOCK_N, num_warps, num_stages = _get_config(q.dtype, head_dim, "bwd_dq")
|
| 1571 |
+
|
| 1572 |
+
assert chunks_per_window >= BLOCK_N, "chunks_per_window must be greater than BLOCK"
|
| 1573 |
+
assert chunks_per_window % BLOCK_N == 0, "chunks_per_window must be a multiple of BLOCK"
|
| 1574 |
+
grid = lambda META: (
|
| 1575 |
+
triton.cdiv(seqlen_q, META["BLOCK_M"]),
|
| 1576 |
+
batch * nheads,
|
| 1577 |
+
)
|
| 1578 |
+
_bwd_eva_agg_kernel_dq[grid](
|
| 1579 |
+
q,
|
| 1580 |
+
k,
|
| 1581 |
+
v,
|
| 1582 |
+
rfa_k,
|
| 1583 |
+
rfa_v,
|
| 1584 |
+
window_mask,
|
| 1585 |
+
chunk_mask,
|
| 1586 |
+
do,
|
| 1587 |
+
lse,
|
| 1588 |
+
do_t_o,
|
| 1589 |
+
dq,
|
| 1590 |
+
softmax_scale,
|
| 1591 |
+
q.stride(0), q.stride(1), q.stride(2),
|
| 1592 |
+
k.stride(0), k.stride(1), k.stride(2),
|
| 1593 |
+
v.stride(0), v.stride(1), v.stride(2),
|
| 1594 |
+
rfa_k_strides[0], rfa_k_strides[1], rfa_k_strides[2],
|
| 1595 |
+
rfa_v_strides[0], rfa_v_strides[1], rfa_v_strides[2],
|
| 1596 |
+
window_mask_strides[0], window_mask_strides[1],
|
| 1597 |
+
chunk_mask_strides[0], chunk_mask_strides[1],
|
| 1598 |
+
do.stride(0), do.stride(1), do.stride(2),
|
| 1599 |
+
lse.stride(0), lse.stride(1),
|
| 1600 |
+
do_t_o.stride(0), do_t_o.stride(1),
|
| 1601 |
+
dq.stride(0), dq.stride(1), dq.stride(2),
|
| 1602 |
+
nheads,
|
| 1603 |
+
seqlen_q,
|
| 1604 |
+
seqlen_k,
|
| 1605 |
+
nchunks,
|
| 1606 |
+
head_dim,
|
| 1607 |
+
chunks_per_window,
|
| 1608 |
+
window_size,
|
| 1609 |
+
mask_type,
|
| 1610 |
+
empty_rfa_kv,
|
| 1611 |
+
BLOCK_HEADDIM,
|
| 1612 |
+
BLOCK_M=BLOCK_M,
|
| 1613 |
+
BLOCK_N=BLOCK_N,
|
| 1614 |
+
num_warps=num_warps,
|
| 1615 |
+
num_stages=num_stages,
|
| 1616 |
+
)
|
| 1617 |
+
|
| 1618 |
+
BLOCK_M, BLOCK_N, num_warps, num_stages = _get_config(q.dtype, head_dim, "bwd_dkdv")
|
| 1619 |
+
grid = lambda META: (
|
| 1620 |
+
triton.cdiv(seqlen_k, META["BLOCK_N"]),
|
| 1621 |
+
batch * nheads,
|
| 1622 |
+
)
|
| 1623 |
+
_bwd_eva_agg_kernel_dkdv[grid](
|
| 1624 |
+
q,
|
| 1625 |
+
k,
|
| 1626 |
+
v,
|
| 1627 |
+
window_mask,
|
| 1628 |
+
do,
|
| 1629 |
+
lse,
|
| 1630 |
+
do_t_o,
|
| 1631 |
+
dk,
|
| 1632 |
+
dv,
|
| 1633 |
+
softmax_scale,
|
| 1634 |
+
q.stride(0), q.stride(1), q.stride(2),
|
| 1635 |
+
k.stride(0), k.stride(1), k.stride(2),
|
| 1636 |
+
v.stride(0), v.stride(1), v.stride(2),
|
| 1637 |
+
window_mask_strides[0], window_mask_strides[1],
|
| 1638 |
+
do.stride(0), do.stride(1), do.stride(2),
|
| 1639 |
+
lse.stride(0), lse.stride(1),
|
| 1640 |
+
do_t_o.stride(0), do_t_o.stride(1),
|
| 1641 |
+
dk.stride(0), dk.stride(1), dk.stride(2),
|
| 1642 |
+
dv.stride(0), dv.stride(1), dv.stride(2),
|
| 1643 |
+
nheads,
|
| 1644 |
+
seqlen_q,
|
| 1645 |
+
seqlen_k,
|
| 1646 |
+
head_dim,
|
| 1647 |
+
window_size,
|
| 1648 |
+
mask_type,
|
| 1649 |
+
BLOCK_HEADDIM,
|
| 1650 |
+
BLOCK_M=BLOCK_M,
|
| 1651 |
+
BLOCK_N=BLOCK_N,
|
| 1652 |
+
num_warps=num_warps,
|
| 1653 |
+
num_stages=num_stages,
|
| 1654 |
+
)
|
| 1655 |
+
if empty_rfa_kv == 0:
|
| 1656 |
+
BLOCK_M, BLOCK_N, num_warps, num_stages = _get_config(q.dtype, head_dim, "bwd_drfa_kv")
|
| 1657 |
+
grid = lambda META: (
|
| 1658 |
+
triton.cdiv(nchunks, META["BLOCK_N"]),
|
| 1659 |
+
batch * nheads,
|
| 1660 |
+
)
|
| 1661 |
+
_bwd_eva_agg_kernel_drfa_kv[grid](
|
| 1662 |
+
q,
|
| 1663 |
+
rfa_k,
|
| 1664 |
+
rfa_v,
|
| 1665 |
+
chunk_mask,
|
| 1666 |
+
do,
|
| 1667 |
+
lse,
|
| 1668 |
+
do_t_o,
|
| 1669 |
+
d_rfa_k,
|
| 1670 |
+
d_rfa_v,
|
| 1671 |
+
softmax_scale,
|
| 1672 |
+
q.stride(0), q.stride(1), q.stride(2),
|
| 1673 |
+
rfa_k_strides[0], rfa_k_strides[1], rfa_k_strides[2],
|
| 1674 |
+
rfa_v_strides[0], rfa_v_strides[1], rfa_v_strides[2],
|
| 1675 |
+
chunk_mask_strides[0], chunk_mask_strides[1],
|
| 1676 |
+
do.stride(0), do.stride(1), do.stride(2),
|
| 1677 |
+
lse.stride(0), lse.stride(1),
|
| 1678 |
+
do_t_o.stride(0), do_t_o.stride(1),
|
| 1679 |
+
d_rfa_k_strides[0], d_rfa_k_strides[1], d_rfa_k_strides[2],
|
| 1680 |
+
d_rfa_v_strides[0], d_rfa_v_strides[1], d_rfa_v_strides[2],
|
| 1681 |
+
nheads,
|
| 1682 |
+
seqlen_q,
|
| 1683 |
+
nchunks,
|
| 1684 |
+
head_dim,
|
| 1685 |
+
chunks_per_window,
|
| 1686 |
+
window_size,
|
| 1687 |
+
mask_type,
|
| 1688 |
+
BLOCK_HEADDIM,
|
| 1689 |
+
BLOCK_M=BLOCK_M,
|
| 1690 |
+
BLOCK_N=BLOCK_N,
|
| 1691 |
+
num_warps=num_warps,
|
| 1692 |
+
num_stages=num_stages,
|
| 1693 |
+
)
|
| 1694 |
+
|
| 1695 |
+
|
| 1696 |
+
class EvaAggFunc(torch.autograd.Function):
|
| 1697 |
+
@staticmethod
|
| 1698 |
+
def forward(ctx, q, k, v, rfa_k, rfa_v, window_mask, chunk_mask, softmax_scale=None, window_size=None, chunks_per_window=None):
|
| 1699 |
+
if rfa_k is None and rfa_v is None:
|
| 1700 |
+
empty_rfa_kv = 1
|
| 1701 |
+
else:
|
| 1702 |
+
assert rfa_k is not None and rfa_v is not None, "Both rfa_k and rfa_v must either be None or have values at the same time."
|
| 1703 |
+
empty_rfa_kv = 0
|
| 1704 |
+
|
| 1705 |
+
if window_mask is not None:
|
| 1706 |
+
mask_type = 1
|
| 1707 |
+
else:
|
| 1708 |
+
mask_type = 0
|
| 1709 |
+
o, lse = triton_eva_agg_fwd(
|
| 1710 |
+
q, k, v, rfa_k, rfa_v, window_mask, chunk_mask, softmax_scale, window_size, chunks_per_window
|
| 1711 |
+
)
|
| 1712 |
+
ctx.save_for_backward(q, k, v, o, lse, rfa_k, rfa_v, window_mask, chunk_mask)
|
| 1713 |
+
ctx.softmax_scale = softmax_scale
|
| 1714 |
+
ctx.window_size = window_size
|
| 1715 |
+
ctx.chunks_per_window = chunks_per_window
|
| 1716 |
+
ctx.empty_rfa_kv = empty_rfa_kv
|
| 1717 |
+
ctx.mask_type = mask_type
|
| 1718 |
+
return o
|
| 1719 |
+
|
| 1720 |
+
@staticmethod
|
| 1721 |
+
def backward(ctx, do):
|
| 1722 |
+
q, k, v, o, lse, rfa_k, rfa_v, window_mask, chunk_mask = ctx.saved_tensors
|
| 1723 |
+
dq = torch.empty_like(q)
|
| 1724 |
+
dk = torch.empty_like(k)
|
| 1725 |
+
dv = torch.empty_like(v)
|
| 1726 |
+
if ctx.empty_rfa_kv == 0:
|
| 1727 |
+
d_rfa_k = torch.empty_like(rfa_k)
|
| 1728 |
+
d_rfa_v = torch.empty_like(rfa_v)
|
| 1729 |
+
else:
|
| 1730 |
+
d_rfa_k = None
|
| 1731 |
+
d_rfa_v = None
|
| 1732 |
+
triton_eva_agg_bwd(
|
| 1733 |
+
do,
|
| 1734 |
+
q,
|
| 1735 |
+
k,
|
| 1736 |
+
v,
|
| 1737 |
+
rfa_k,
|
| 1738 |
+
rfa_v,
|
| 1739 |
+
window_mask,
|
| 1740 |
+
chunk_mask,
|
| 1741 |
+
o,
|
| 1742 |
+
lse,
|
| 1743 |
+
dq,
|
| 1744 |
+
dk,
|
| 1745 |
+
dv,
|
| 1746 |
+
d_rfa_k,
|
| 1747 |
+
d_rfa_v,
|
| 1748 |
+
softmax_scale=ctx.softmax_scale,
|
| 1749 |
+
window_size=ctx.window_size,
|
| 1750 |
+
chunks_per_window=ctx.chunks_per_window,
|
| 1751 |
+
empty_rfa_kv=ctx.empty_rfa_kv,
|
| 1752 |
+
mask_type=ctx.mask_type,
|
| 1753 |
+
)
|
| 1754 |
+
return dq, dk, dv, d_rfa_k, d_rfa_v, None, None, None, None, None
|
| 1755 |
+
|
| 1756 |
+
|
| 1757 |
+
def eva_agg_func_triton(
|
| 1758 |
+
q, k, v, rfa_k, rfa_v,
|
| 1759 |
+
window_mask, chunk_mask,
|
| 1760 |
+
softmax_scale=None, window_size=None, chunks_per_window=None,
|
| 1761 |
+
):
|
| 1762 |
+
return EvaAggFunc.apply(
|
| 1763 |
+
q, k, v, rfa_k, rfa_v,
|
| 1764 |
+
window_mask, chunk_mask,
|
| 1765 |
+
softmax_scale, window_size, chunks_per_window,
|
| 1766 |
+
)
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/eva_cache.py
ADDED
|
@@ -0,0 +1,761 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Optional, Tuple, List, Any, Union
|
| 2 |
+
import torch
|
| 3 |
+
from transformers.cache_utils import Cache
|
| 4 |
+
|
| 5 |
+
class EvaCache(Cache):
|
| 6 |
+
"""
|
| 7 |
+
A cache that grows dynamically as more tokens are generated. This is the default for generative models.
|
| 8 |
+
|
| 9 |
+
It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
|
| 10 |
+
`[batch_size, num_heads, seq_len, head_dim]`.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
def __init__(self) -> None:
|
| 14 |
+
self.w_k: List[torch.Tensor] = []
|
| 15 |
+
self.w_v: List[torch.Tensor] = []
|
| 16 |
+
|
| 17 |
+
self.rf_q: List[torch.Tensor] = []
|
| 18 |
+
self.rf_k: List[torch.Tensor] = []
|
| 19 |
+
self.rf_v: List[torch.Tensor] = []
|
| 20 |
+
|
| 21 |
+
self.softmax_phi_k_v: List[torch.Tensor] = []
|
| 22 |
+
self.log_sum_phi_k: List[torch.Tensor] = []
|
| 23 |
+
self.rf_k_bar: List[torch.Tensor] = []
|
| 24 |
+
self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
|
| 25 |
+
|
| 26 |
+
# attention masks temporary buffer
|
| 27 |
+
self.rf_mask: List[Optional[torch.Tensor]] = []
|
| 28 |
+
self.s_mask: List[torch.Tensor] = []
|
| 29 |
+
self.chunk_mask: List[torch.Tensor] = []
|
| 30 |
+
|
| 31 |
+
def __len__(self):
|
| 32 |
+
"""
|
| 33 |
+
Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
|
| 34 |
+
to the number of layers in the model.
|
| 35 |
+
"""
|
| 36 |
+
return len(self.w_k)
|
| 37 |
+
|
| 38 |
+
def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
|
| 39 |
+
"""Given the sequence length of the new inputs, returns the usable length of the cache."""
|
| 40 |
+
# Cache without size limit -> all cache is usable
|
| 41 |
+
# Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
|
| 42 |
+
# length, we will need to evict part of the cache (and thus not all cache is usable)
|
| 43 |
+
max_length = self.get_max_length()
|
| 44 |
+
previous_seq_length = self.get_seq_length(layer_idx)
|
| 45 |
+
if max_length is not None and previous_seq_length + new_seq_length > max_length:
|
| 46 |
+
return max_length - new_seq_length
|
| 47 |
+
return previous_seq_length
|
| 48 |
+
|
| 49 |
+
def reorder_cache(self, beam_idx: torch.LongTensor):
|
| 50 |
+
"""Reorders the cache for beam search, given the selected beam indices."""
|
| 51 |
+
for layer_idx in range(len(self.w_k)):
|
| 52 |
+
device = self.w_k[layer_idx].device
|
| 53 |
+
self.w_k[layer_idx] = self.w_k[layer_idx].index_select(0, beam_idx.to(device))
|
| 54 |
+
|
| 55 |
+
device = self.w_v[layer_idx].device
|
| 56 |
+
self.w_v[layer_idx] = self.w_v[layer_idx].index_select(0, beam_idx.to(device))
|
| 57 |
+
|
| 58 |
+
device = self.rf_q[layer_idx].device
|
| 59 |
+
self.rf_q[layer_idx] = self.rf_q[layer_idx].index_select(0, beam_idx.to(device))
|
| 60 |
+
|
| 61 |
+
device = self.rf_k[layer_idx].device
|
| 62 |
+
self.rf_k[layer_idx] = self.rf_k[layer_idx].index_select(0, beam_idx.to(device))
|
| 63 |
+
|
| 64 |
+
device = self.rf_v[layer_idx].device
|
| 65 |
+
self.rf_v[layer_idx] = self.rf_v[layer_idx].index_select(0, beam_idx.to(device))
|
| 66 |
+
|
| 67 |
+
device = self.softmax_phi_k_v[layer_idx].device
|
| 68 |
+
self.softmax_phi_k_v[layer_idx] = self.softmax_phi_k_v[layer_idx].index_select(0, beam_idx.to(device))
|
| 69 |
+
|
| 70 |
+
device = self.log_sum_phi_k[layer_idx].device
|
| 71 |
+
self.log_sum_phi_k[layer_idx] = self.log_sum_phi_k[layer_idx].index_select(0, beam_idx.to(device))
|
| 72 |
+
|
| 73 |
+
device = self.rf_k_bar[layer_idx].device
|
| 74 |
+
self.rf_k_bar[layer_idx] = self.rf_k_bar[layer_idx].index_select(0, beam_idx.to(device))
|
| 75 |
+
|
| 76 |
+
device = self.rf_mask[layer_idx].device
|
| 77 |
+
self.rf_mask[layer_idx] = self.rf_mask[layer_idx].index_select(0, beam_idx.to(device))
|
| 78 |
+
|
| 79 |
+
device = self.s_mask[layer_idx].device
|
| 80 |
+
self.s_mask[layer_idx] = self.s_mask[layer_idx].index_select(0, beam_idx.to(device))
|
| 81 |
+
|
| 82 |
+
device = self.chunk_mask[layer_idx].device
|
| 83 |
+
self.chunk_mask[layer_idx] = self.chunk_mask[layer_idx].index_select(0, beam_idx.to(device))
|
| 84 |
+
@property
|
| 85 |
+
def seen_tokens(self):
|
| 86 |
+
if hasattr(self, "_seen_tokens"):
|
| 87 |
+
return self._seen_tokens
|
| 88 |
+
else:
|
| 89 |
+
return None
|
| 90 |
+
|
| 91 |
+
def update_past_len(
|
| 92 |
+
self,
|
| 93 |
+
cur_q_len: int,
|
| 94 |
+
layer_idx: int
|
| 95 |
+
):
|
| 96 |
+
# Update the number of seen tokens
|
| 97 |
+
if layer_idx == 0:
|
| 98 |
+
self._seen_tokens += cur_q_len
|
| 99 |
+
return self._seen_tokens
|
| 100 |
+
|
| 101 |
+
def update_mask(
|
| 102 |
+
self,
|
| 103 |
+
prev_s_mask,
|
| 104 |
+
cur_s_mask,
|
| 105 |
+
chunk_mask,
|
| 106 |
+
rf_mask,
|
| 107 |
+
layer_idx,
|
| 108 |
+
window_size,
|
| 109 |
+
chunk_size,
|
| 110 |
+
):
|
| 111 |
+
############################################
|
| 112 |
+
# compute masks for singletons
|
| 113 |
+
############################################
|
| 114 |
+
q_len = None
|
| 115 |
+
if len(self.s_mask) <= layer_idx:
|
| 116 |
+
q_len = chunk_mask.shape[-2]
|
| 117 |
+
# prefill stage
|
| 118 |
+
# q is of shape [b, h, n, d]
|
| 119 |
+
if q_len < window_size:
|
| 120 |
+
assert prev_s_mask is None
|
| 121 |
+
|
| 122 |
+
# w_v = # [b, h, 1, j, d]
|
| 123 |
+
# store the past window-wise key-value pairs
|
| 124 |
+
self.s_mask.append(cur_s_mask[..., -1:, :] if cur_s_mask is not None else prev_s_mask[..., -1, -1:, :])
|
| 125 |
+
else:
|
| 126 |
+
# decoding stage
|
| 127 |
+
prev_s_mask = None
|
| 128 |
+
|
| 129 |
+
cached_s_mask = self.s_mask[layer_idx]
|
| 130 |
+
assert cached_s_mask is not None
|
| 131 |
+
if cached_s_mask.shape[-1] == window_size:
|
| 132 |
+
cur_s_mask = cur_s_mask
|
| 133 |
+
else:
|
| 134 |
+
cur_s_mask = torch.cat([cached_s_mask, cur_s_mask], dim=-1)
|
| 135 |
+
|
| 136 |
+
# store the past window-wise key-value pairs
|
| 137 |
+
self.s_mask[layer_idx] = cur_s_mask
|
| 138 |
+
|
| 139 |
+
############################################
|
| 140 |
+
# compute masks for intra-chunks
|
| 141 |
+
############################################
|
| 142 |
+
dump_rf_mask = None
|
| 143 |
+
if len(self.rf_mask) <= layer_idx:
|
| 144 |
+
# initialize chunk stats
|
| 145 |
+
# prefill stage
|
| 146 |
+
if q_len < chunk_size:
|
| 147 |
+
cur_rf_mask = rf_mask
|
| 148 |
+
else:
|
| 149 |
+
if q_len % chunk_size == 0:
|
| 150 |
+
dump_rf_mask = rf_mask
|
| 151 |
+
cur_rf_mask = None
|
| 152 |
+
else:
|
| 153 |
+
remainder_tokens = q_len % chunk_size
|
| 154 |
+
if rf_mask is not None:
|
| 155 |
+
dump_rf_mask, cur_rf_mask = torch.split(rf_mask, [q_len - remainder_tokens, remainder_tokens], dim=-2)
|
| 156 |
+
else:
|
| 157 |
+
dump_rf_mask = None
|
| 158 |
+
cur_rf_mask = None
|
| 159 |
+
self.rf_mask.append(cur_rf_mask)
|
| 160 |
+
else:
|
| 161 |
+
past_rf_mask = self.rf_mask[layer_idx]
|
| 162 |
+
if past_rf_mask is not None:
|
| 163 |
+
# when decoding tokens, we always assume the
|
| 164 |
+
# incoming token mask is 0 (not masked)
|
| 165 |
+
cur_rf_mask = torch.cat([past_rf_mask, rf_mask], dim=-2)
|
| 166 |
+
else:
|
| 167 |
+
# we do not need to use rf_mask anymore after we receive generated tokens
|
| 168 |
+
cur_rf_mask = None
|
| 169 |
+
# We need to store rf_k_bar and RFA-results that
|
| 170 |
+
# compute the per-chunk RFA.
|
| 171 |
+
|
| 172 |
+
# Dump the chunk if the len of current chunk reaches <chunk_size>.
|
| 173 |
+
if cur_rf_mask is not None and cur_rf_mask.shape[-2] == chunk_size:
|
| 174 |
+
dump_rf_mask = cur_rf_mask
|
| 175 |
+
cur_rf_mask = None
|
| 176 |
+
|
| 177 |
+
self.rf_mask[layer_idx] = cur_rf_mask
|
| 178 |
+
|
| 179 |
+
############################################
|
| 180 |
+
# compute masks for inter chunks
|
| 181 |
+
############################################
|
| 182 |
+
if len(self.chunk_mask) <= layer_idx:
|
| 183 |
+
# prefill stage
|
| 184 |
+
# q is of shape [b, h, n, d]
|
| 185 |
+
if q_len < window_size:
|
| 186 |
+
cur_chunk_mask = chunk_mask
|
| 187 |
+
prev_chunk_mask = None
|
| 188 |
+
else:
|
| 189 |
+
if q_len % window_size == 0:
|
| 190 |
+
cur_chunk_mask = None
|
| 191 |
+
prev_chunk_mask = chunk_mask
|
| 192 |
+
else:
|
| 193 |
+
remainder_tokens = q_len % window_size
|
| 194 |
+
# [b, h, n-r, d] [b, h, r, d]
|
| 195 |
+
prev_chunk_mask, cur_chunk_mask = torch.split(chunk_mask, [q_len - remainder_tokens, remainder_tokens], dim=-2)
|
| 196 |
+
bsz, num_heads, _, head_dim = prev_chunk_mask.shape
|
| 197 |
+
prev_chunk_mask = prev_chunk_mask.reshape(bsz, num_heads, -1, window_size, head_dim)
|
| 198 |
+
|
| 199 |
+
assert prev_s_mask is not None
|
| 200 |
+
if prev_s_mask.shape[-3] == 1 and prev_chunk_mask.shape[-3] > 1:
|
| 201 |
+
# need to expand
|
| 202 |
+
prev_s_mask = prev_s_mask.expand(-1, -1, prev_chunk_mask.shape[-3], -1, -1)
|
| 203 |
+
# w_v = # [b, h, 1, j, d]
|
| 204 |
+
# store the past window-wise key-value pairs
|
| 205 |
+
self.chunk_mask.append(cur_chunk_mask[..., -1:, :] if cur_chunk_mask is not None else prev_chunk_mask[..., -1, -1:, :])
|
| 206 |
+
else:
|
| 207 |
+
# decoding stage
|
| 208 |
+
prev_chunk_mask = None
|
| 209 |
+
cur_chunk_mask = self.chunk_mask[layer_idx]
|
| 210 |
+
|
| 211 |
+
# if the current sequence length reaches <chunk_size>,
|
| 212 |
+
# we append a new 1 to the end of chunk_mask
|
| 213 |
+
seen_seq_len = self.get_seq_length(layer_idx)
|
| 214 |
+
if seen_seq_len > 0 and seen_seq_len % chunk_size == 0:
|
| 215 |
+
past_chunk_mask = self.chunk_mask[layer_idx]
|
| 216 |
+
if past_chunk_mask is not None:
|
| 217 |
+
# when decoding tokens, we always assume the
|
| 218 |
+
# incoming token mask is 0 (not masked)
|
| 219 |
+
cur_chunk_mask = torch.cat([past_chunk_mask, chunk_mask], dim=-1)
|
| 220 |
+
else:
|
| 221 |
+
cur_chunk_mask = chunk_mask
|
| 222 |
+
self.chunk_mask[layer_idx] = cur_chunk_mask
|
| 223 |
+
|
| 224 |
+
# if the len of current sequence reaches <window_size> + 1,
|
| 225 |
+
# we turn on the mask for most recent chunks
|
| 226 |
+
if seen_seq_len > 0 and seen_seq_len % window_size == 1:
|
| 227 |
+
cur_chunk_mask = self.chunk_mask[layer_idx]
|
| 228 |
+
# we do not need to use rf_mask anymore after we receive generated tokens
|
| 229 |
+
num_chunks_per_window = window_size // chunk_size
|
| 230 |
+
cur_chunk_mask[..., -num_chunks_per_window:] = False
|
| 231 |
+
self.chunk_mask[layer_idx] = cur_chunk_mask
|
| 232 |
+
|
| 233 |
+
return (prev_s_mask, cur_s_mask, prev_chunk_mask, cur_chunk_mask, dump_rf_mask)
|
| 234 |
+
|
| 235 |
+
def update_singletons(
|
| 236 |
+
self,
|
| 237 |
+
q,
|
| 238 |
+
k,
|
| 239 |
+
v,
|
| 240 |
+
layer_idx,
|
| 241 |
+
window_size,
|
| 242 |
+
):
|
| 243 |
+
if len(self.w_k) <= layer_idx:
|
| 244 |
+
# prefill stage
|
| 245 |
+
# q is of shape [b, h, n, d]
|
| 246 |
+
q_len = q.shape[-2]
|
| 247 |
+
if q_len < window_size:
|
| 248 |
+
w_q = q
|
| 249 |
+
w_k = k
|
| 250 |
+
w_v = v
|
| 251 |
+
past_w_q = past_w_k = past_w_v = None
|
| 252 |
+
else:
|
| 253 |
+
if q_len % window_size == 0:
|
| 254 |
+
w_q = None
|
| 255 |
+
w_k = None
|
| 256 |
+
w_v = None
|
| 257 |
+
past_w_q = q
|
| 258 |
+
past_w_k = k
|
| 259 |
+
past_w_v = v
|
| 260 |
+
else:
|
| 261 |
+
remainder_tokens = q_len % window_size
|
| 262 |
+
# [b, h, n-r, d] [b, h, r, d]
|
| 263 |
+
past_w_q, w_q = torch.split(q, [q_len - remainder_tokens, remainder_tokens], dim=-2)
|
| 264 |
+
past_w_k, w_k = torch.split(k, [q_len - remainder_tokens, remainder_tokens], dim=-2)
|
| 265 |
+
past_w_v, w_v = torch.split(v, [q_len - remainder_tokens, remainder_tokens], dim=-2)
|
| 266 |
+
bsz, num_heads, _, head_dim = past_w_q.shape
|
| 267 |
+
past_w_q = past_w_q.reshape(bsz, num_heads, -1, window_size, head_dim)
|
| 268 |
+
past_w_k = past_w_k.reshape(bsz, num_heads, -1, window_size, head_dim)
|
| 269 |
+
past_w_v = past_w_v.reshape(bsz, num_heads, -1, window_size, head_dim)
|
| 270 |
+
# w_q = q[..., None, -window_size:, :] # [b, h, 1, j, d]
|
| 271 |
+
# w_k = # [b, h, 1, j, d]
|
| 272 |
+
# w_v = # [b, h, 1, j, d]
|
| 273 |
+
# store the past window-wise key-value pairs
|
| 274 |
+
# if w_k is None, it means we happen to pass in a sqeuence that is divisible by window_size
|
| 275 |
+
# we leave the cache with window_size-sized kv pairs to be cleared next iteration
|
| 276 |
+
self.w_k.append(w_k if w_k is not None else past_w_k[..., -1, :, :])
|
| 277 |
+
self.w_v.append(w_v if w_v is not None else past_w_v[..., -1, :, :])
|
| 278 |
+
else:
|
| 279 |
+
# decoding stage
|
| 280 |
+
past_w_q = past_w_k = past_w_v = None
|
| 281 |
+
# this is implemented as either a sliding window or fixed window
|
| 282 |
+
w_q = q # [b, h, 1, d]
|
| 283 |
+
w_k = k # [b, h, 1, d]
|
| 284 |
+
w_v = v # [b, h, 1, d]
|
| 285 |
+
|
| 286 |
+
cached_w_k = self.w_k[layer_idx]
|
| 287 |
+
assert cached_w_k is not None # [b, h, j, d]
|
| 288 |
+
if cached_w_k.shape[-2] == window_size:
|
| 289 |
+
w_k = w_k
|
| 290 |
+
else:
|
| 291 |
+
w_k = torch.cat([cached_w_k, w_k], dim=-2)
|
| 292 |
+
|
| 293 |
+
cached_w_v = self.w_v[layer_idx]
|
| 294 |
+
assert cached_w_v is not None
|
| 295 |
+
if cached_w_v.shape[-2] == window_size:
|
| 296 |
+
w_v = w_v
|
| 297 |
+
else:
|
| 298 |
+
w_v = torch.cat([cached_w_v, w_v], dim=-2)
|
| 299 |
+
|
| 300 |
+
# store the past window-wise key-value pairs
|
| 301 |
+
self.w_k[layer_idx] = w_k
|
| 302 |
+
self.w_v[layer_idx] = w_v
|
| 303 |
+
return (past_w_q, past_w_k, past_w_v), (w_q, w_k, w_v)
|
| 304 |
+
|
| 305 |
+
def update_chunks(
|
| 306 |
+
self,
|
| 307 |
+
q,
|
| 308 |
+
k,
|
| 309 |
+
v,
|
| 310 |
+
layer_idx,
|
| 311 |
+
chunk_size
|
| 312 |
+
):
|
| 313 |
+
q_len = q.shape[-2]
|
| 314 |
+
dump_q = None
|
| 315 |
+
dump_k = None
|
| 316 |
+
dump_v = None
|
| 317 |
+
if len(self.rf_q) <= layer_idx:
|
| 318 |
+
# initialize chunk stats
|
| 319 |
+
# prefill stage
|
| 320 |
+
if q_len < chunk_size:
|
| 321 |
+
rf_q = q
|
| 322 |
+
rf_k = k
|
| 323 |
+
rf_v = v
|
| 324 |
+
else:
|
| 325 |
+
if q_len % chunk_size == 0:
|
| 326 |
+
rf_q = None
|
| 327 |
+
rf_k = None
|
| 328 |
+
rf_v = None
|
| 329 |
+
dump_q = q
|
| 330 |
+
dump_k = k
|
| 331 |
+
dump_v = v
|
| 332 |
+
else:
|
| 333 |
+
remainder_tokens = q_len % chunk_size
|
| 334 |
+
# [b, h, n-r, d] [b, h, r, d]
|
| 335 |
+
dump_q, rf_q = torch.split(q, [q_len - remainder_tokens, remainder_tokens], dim=-2)
|
| 336 |
+
dump_k, rf_k = torch.split(k, [q_len - remainder_tokens, remainder_tokens], dim=-2)
|
| 337 |
+
dump_v, rf_v = torch.split(v, [q_len - remainder_tokens, remainder_tokens], dim=-2)
|
| 338 |
+
self.rf_q.append(rf_q)
|
| 339 |
+
self.rf_k.append(rf_k)
|
| 340 |
+
self.rf_v.append(rf_v)
|
| 341 |
+
else:
|
| 342 |
+
# decode tokens
|
| 343 |
+
# add query, key & value to the current chunk.
|
| 344 |
+
past_rf_q = self.rf_q[layer_idx]
|
| 345 |
+
if past_rf_q is not None:
|
| 346 |
+
rf_q = torch.cat([past_rf_q, q], dim=-2)
|
| 347 |
+
else:
|
| 348 |
+
rf_q = q
|
| 349 |
+
|
| 350 |
+
past_rf_k = self.rf_k[layer_idx]
|
| 351 |
+
if past_rf_k is not None:
|
| 352 |
+
rf_k = torch.cat([past_rf_k, k], dim=-2)
|
| 353 |
+
else:
|
| 354 |
+
rf_k = k
|
| 355 |
+
|
| 356 |
+
past_rf_v = self.rf_v[layer_idx]
|
| 357 |
+
if past_rf_v is not None:
|
| 358 |
+
rf_v = torch.cat([past_rf_v, v], dim=-2)
|
| 359 |
+
else:
|
| 360 |
+
rf_v = v
|
| 361 |
+
|
| 362 |
+
# We need to store rf_k_bar and RFA-results that
|
| 363 |
+
# compute the per-chunk RFA.
|
| 364 |
+
|
| 365 |
+
# Dump the chunk if the len of current chunk reaches <chunk_size>.
|
| 366 |
+
if rf_q.shape[-2] == chunk_size:
|
| 367 |
+
dump_q = rf_q
|
| 368 |
+
dump_k = rf_k
|
| 369 |
+
dump_v = rf_v
|
| 370 |
+
# clear the chunk
|
| 371 |
+
rf_q = None
|
| 372 |
+
rf_k = None
|
| 373 |
+
rf_v = None
|
| 374 |
+
|
| 375 |
+
self.rf_q[layer_idx] = rf_q
|
| 376 |
+
self.rf_k[layer_idx] = rf_k
|
| 377 |
+
self.rf_v[layer_idx] = rf_v
|
| 378 |
+
|
| 379 |
+
return dump_q, dump_k, dump_v
|
| 380 |
+
|
| 381 |
+
def update_chunk_rfas(
|
| 382 |
+
self,
|
| 383 |
+
softmax_phi_k_v,
|
| 384 |
+
log_sum_phi_k,
|
| 385 |
+
rf_k_bar,
|
| 386 |
+
layer_idx,
|
| 387 |
+
random_feature_dim
|
| 388 |
+
):
|
| 389 |
+
if len(self.softmax_phi_k_v) <= layer_idx:
|
| 390 |
+
# prefill stage
|
| 391 |
+
self.softmax_phi_k_v.append(softmax_phi_k_v)
|
| 392 |
+
self.log_sum_phi_k.append(log_sum_phi_k)
|
| 393 |
+
self.rf_k_bar.append(rf_k_bar)
|
| 394 |
+
else:
|
| 395 |
+
# token decoding
|
| 396 |
+
past_softmax_phi_k_v = self.softmax_phi_k_v[layer_idx]
|
| 397 |
+
past_log_sum_phi_k = self.log_sum_phi_k[layer_idx]
|
| 398 |
+
past_rf_k_bar = self.rf_k_bar[layer_idx]
|
| 399 |
+
|
| 400 |
+
if past_softmax_phi_k_v is not None:
|
| 401 |
+
if random_feature_dim == 1:
|
| 402 |
+
dim = -2
|
| 403 |
+
else:
|
| 404 |
+
dim = -3
|
| 405 |
+
softmax_phi_k_v = torch.cat([past_softmax_phi_k_v, softmax_phi_k_v], dim=dim)
|
| 406 |
+
|
| 407 |
+
if past_log_sum_phi_k is not None:
|
| 408 |
+
if random_feature_dim == 1:
|
| 409 |
+
dim = -2
|
| 410 |
+
else:
|
| 411 |
+
dim = -3
|
| 412 |
+
log_sum_phi_k = torch.cat([past_log_sum_phi_k, log_sum_phi_k], dim=dim)
|
| 413 |
+
|
| 414 |
+
if past_rf_k_bar is not None:
|
| 415 |
+
rf_k_bar = torch.cat([past_rf_k_bar, rf_k_bar], dim=-2)
|
| 416 |
+
|
| 417 |
+
self.softmax_phi_k_v[layer_idx] = softmax_phi_k_v
|
| 418 |
+
self.log_sum_phi_k[layer_idx] = log_sum_phi_k
|
| 419 |
+
self.rf_k_bar[layer_idx] = rf_k_bar
|
| 420 |
+
|
| 421 |
+
return softmax_phi_k_v, log_sum_phi_k, rf_k_bar
|
| 422 |
+
|
| 423 |
+
def get_chunk_rfas(self, layer_idx):
|
| 424 |
+
if len(self.softmax_phi_k_v) <= layer_idx:
|
| 425 |
+
return (
|
| 426 |
+
None,
|
| 427 |
+
None,
|
| 428 |
+
None
|
| 429 |
+
)
|
| 430 |
+
else:
|
| 431 |
+
return (
|
| 432 |
+
self.softmax_phi_k_v[layer_idx],
|
| 433 |
+
self.log_sum_phi_k[layer_idx],
|
| 434 |
+
self.rf_k_bar[layer_idx]
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
+
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
|
| 438 |
+
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
|
| 439 |
+
if len(self.w_k) <= layer_idx:
|
| 440 |
+
return 0
|
| 441 |
+
return self._seen_tokens
|
| 442 |
+
|
| 443 |
+
def get_max_length(self) -> Optional[int]:
|
| 444 |
+
"""Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
|
| 445 |
+
return None
|
| 446 |
+
|
| 447 |
+
def update(
|
| 448 |
+
self,
|
| 449 |
+
layer_idx: int,
|
| 450 |
+
cache_kwargs: Optional[Dict[str, Any]] = None,
|
| 451 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 452 |
+
raise NotImplementedError("`update` is not used in Eva Cache.")
|
| 453 |
+
|
| 454 |
+
class EvaStaticCacheForTriton(Cache):
|
| 455 |
+
"""
|
| 456 |
+
A variant of EvaCache for eva's triton kernels
|
| 457 |
+
"""
|
| 458 |
+
|
| 459 |
+
def __init__(
|
| 460 |
+
self,
|
| 461 |
+
batch_size,
|
| 462 |
+
num_key_value_heads,
|
| 463 |
+
window_size,
|
| 464 |
+
head_dim,
|
| 465 |
+
num_layers,
|
| 466 |
+
dtype,
|
| 467 |
+
device
|
| 468 |
+
) -> None:
|
| 469 |
+
self.past_window_k: List[torch.Tensor] = []
|
| 470 |
+
self.past_window_v: List[torch.Tensor] = []
|
| 471 |
+
|
| 472 |
+
cache_shape = (batch_size, num_key_value_heads, window_size, head_dim)
|
| 473 |
+
for idx in range(num_layers):
|
| 474 |
+
new_window_k = torch.zeros(cache_shape, dtype=dtype, device=device)
|
| 475 |
+
new_window_v = torch.zeros(cache_shape, dtype=dtype, device=device)
|
| 476 |
+
self.past_window_k.append(new_window_k)
|
| 477 |
+
self.past_window_v.append(new_window_v)
|
| 478 |
+
|
| 479 |
+
self.past_window_pos: List[int] = []
|
| 480 |
+
|
| 481 |
+
self.rfa_k: List[torch.Tensor] = []
|
| 482 |
+
self.rfa_v: List[torch.Tensor] = []
|
| 483 |
+
# self.rfa_mask: List[torch.Tensor] = []
|
| 484 |
+
|
| 485 |
+
self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
|
| 486 |
+
|
| 487 |
+
# attention masks temporary buffer
|
| 488 |
+
self.rf_mask: List[Optional[torch.Tensor]] = []
|
| 489 |
+
self.s_mask: List[torch.Tensor] = []
|
| 490 |
+
|
| 491 |
+
def __len__(self):
|
| 492 |
+
"""
|
| 493 |
+
Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
|
| 494 |
+
to the number of layers in the model.
|
| 495 |
+
"""
|
| 496 |
+
return len(self.past_window_pos)
|
| 497 |
+
|
| 498 |
+
def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
|
| 499 |
+
"""Given the sequence length of the new inputs, returns the usable length of the cache."""
|
| 500 |
+
# Cache without size limit -> all cache is usable
|
| 501 |
+
# Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
|
| 502 |
+
# length, we will need to evict part of the cache (and thus not all cache is usable)
|
| 503 |
+
max_length = self.get_max_length()
|
| 504 |
+
previous_seq_length = self.get_seq_length(layer_idx)
|
| 505 |
+
if max_length is not None and previous_seq_length + new_seq_length > max_length:
|
| 506 |
+
return max_length - new_seq_length
|
| 507 |
+
return previous_seq_length
|
| 508 |
+
|
| 509 |
+
def reorder_cache(self, beam_idx: torch.LongTensor):
|
| 510 |
+
"""Reorders the cache for beam search, given the selected beam indices."""
|
| 511 |
+
for layer_idx in range(len(self.past_window_k)):
|
| 512 |
+
device = self.past_window_k[layer_idx].device
|
| 513 |
+
self.past_window_k[layer_idx] = self.past_window_k[layer_idx].index_select(0, beam_idx.to(device))
|
| 514 |
+
|
| 515 |
+
device = self.past_window_v[layer_idx].device
|
| 516 |
+
self.past_window_v[layer_idx] = self.past_window_v[layer_idx].index_select(0, beam_idx.to(device))
|
| 517 |
+
|
| 518 |
+
device = self.rfa_k[layer_idx].device
|
| 519 |
+
self.rfa_k[layer_idx] = self.rfa_k[layer_idx].index_select(0, beam_idx.to(device))
|
| 520 |
+
|
| 521 |
+
device = self.rfa_v[layer_idx].device
|
| 522 |
+
self.rfa_v[layer_idx] = self.rfa_v[layer_idx].index_select(0, beam_idx.to(device))
|
| 523 |
+
|
| 524 |
+
# device = self.rfa_mask[layer_idx].device
|
| 525 |
+
# self.rfa_mask[layer_idx] = self.rfa_mask[layer_idx].index_select(0, beam_idx.to(device))
|
| 526 |
+
|
| 527 |
+
device = self.rf_mask[layer_idx].device
|
| 528 |
+
self.rf_mask[layer_idx] = self.rf_mask[layer_idx].index_select(0, beam_idx.to(device))
|
| 529 |
+
|
| 530 |
+
device = self.s_mask[layer_idx].device
|
| 531 |
+
self.s_mask[layer_idx] = self.s_mask[layer_idx].index_select(0, beam_idx.to(device))
|
| 532 |
+
|
| 533 |
+
@property
|
| 534 |
+
def seen_tokens(self):
|
| 535 |
+
if hasattr(self, "_seen_tokens"):
|
| 536 |
+
return self._seen_tokens
|
| 537 |
+
else:
|
| 538 |
+
return None
|
| 539 |
+
|
| 540 |
+
def update_past_len(
|
| 541 |
+
self,
|
| 542 |
+
cur_q_len: int,
|
| 543 |
+
layer_idx: int
|
| 544 |
+
):
|
| 545 |
+
# Update the number of seen tokens
|
| 546 |
+
if layer_idx == 0:
|
| 547 |
+
self._seen_tokens += cur_q_len
|
| 548 |
+
return self._seen_tokens
|
| 549 |
+
|
| 550 |
+
def update_mask(
|
| 551 |
+
self,
|
| 552 |
+
s_mask,
|
| 553 |
+
rf_mask,
|
| 554 |
+
layer_idx,
|
| 555 |
+
window_size,
|
| 556 |
+
):
|
| 557 |
+
############################################
|
| 558 |
+
# compute masks for singletons
|
| 559 |
+
############################################
|
| 560 |
+
if len(self.s_mask) <= layer_idx:
|
| 561 |
+
# prefill stage
|
| 562 |
+
# q is of shape [b, h, n, d]
|
| 563 |
+
# s_v = # [b, h, 1, j, d]
|
| 564 |
+
# store the past window-wise key-value pairs
|
| 565 |
+
if s_mask is None:
|
| 566 |
+
cur_s_mask = None
|
| 567 |
+
else:
|
| 568 |
+
q_len = s_mask.shape[-2]
|
| 569 |
+
# s_mask is of shape [b, h, n, w]
|
| 570 |
+
# let r = q_len % window_size
|
| 571 |
+
# if r == 0, the mask to be appended is of shape [b, h, 1, w]
|
| 572 |
+
# otherwise, r < w, the mask to be appended is of shape [b, h, 1, r]
|
| 573 |
+
remainder_tokens = q_len % window_size
|
| 574 |
+
if remainder_tokens == 0:
|
| 575 |
+
cur_s_mask = None
|
| 576 |
+
else:
|
| 577 |
+
cur_s_mask = s_mask[..., -1:, :remainder_tokens]
|
| 578 |
+
self.s_mask.append(cur_s_mask)
|
| 579 |
+
# we use the passed s_mask for subsequent computations
|
| 580 |
+
dump_s_mask = s_mask
|
| 581 |
+
else:
|
| 582 |
+
# decoding stage
|
| 583 |
+
past_s_mask = self.s_mask[layer_idx]
|
| 584 |
+
if past_s_mask is None:
|
| 585 |
+
assert s_mask is None
|
| 586 |
+
cur_s_mask = None
|
| 587 |
+
else:
|
| 588 |
+
assert s_mask is not None
|
| 589 |
+
cur_s_mask = torch.cat([past_s_mask, s_mask], dim=-1)
|
| 590 |
+
|
| 591 |
+
dump_s_mask = cur_s_mask
|
| 592 |
+
if cur_s_mask is not None and cur_s_mask.shape[-1] == window_size:
|
| 593 |
+
cur_s_mask = None
|
| 594 |
+
# store the past window-wise key-value pairs
|
| 595 |
+
self.s_mask[layer_idx] = cur_s_mask
|
| 596 |
+
|
| 597 |
+
############################################
|
| 598 |
+
# compute masks for intra-chunks
|
| 599 |
+
############################################
|
| 600 |
+
dump_rf_mask = None
|
| 601 |
+
if len(self.rf_mask) <= layer_idx:
|
| 602 |
+
# initialize chunk stats
|
| 603 |
+
# prefill stage
|
| 604 |
+
if rf_mask is None:
|
| 605 |
+
cur_rf_mask = None
|
| 606 |
+
else:
|
| 607 |
+
q_len = rf_mask.shape[-2]
|
| 608 |
+
if q_len < window_size:
|
| 609 |
+
dump_rf_mask = None
|
| 610 |
+
cur_rf_mask = rf_mask
|
| 611 |
+
else:
|
| 612 |
+
if q_len % window_size == 0:
|
| 613 |
+
dump_rf_mask = rf_mask
|
| 614 |
+
cur_rf_mask = None
|
| 615 |
+
else:
|
| 616 |
+
remainder_tokens = q_len % window_size
|
| 617 |
+
dump_rf_mask, cur_rf_mask = torch.split(rf_mask, [q_len - remainder_tokens, remainder_tokens], dim=-2)
|
| 618 |
+
self.rf_mask.append(cur_rf_mask)
|
| 619 |
+
else:
|
| 620 |
+
past_rf_mask = self.rf_mask[layer_idx]
|
| 621 |
+
if past_rf_mask is not None:
|
| 622 |
+
# when decoding tokens, we always assume the
|
| 623 |
+
# incoming token mask is 0 (not masked)
|
| 624 |
+
cur_rf_mask = torch.cat([past_rf_mask, rf_mask], dim=-2)
|
| 625 |
+
else:
|
| 626 |
+
cur_rf_mask = None
|
| 627 |
+
|
| 628 |
+
if cur_rf_mask is not None and cur_rf_mask.shape[-2] == window_size:
|
| 629 |
+
dump_rf_mask = cur_rf_mask
|
| 630 |
+
cur_rf_mask = None
|
| 631 |
+
|
| 632 |
+
self.rf_mask[layer_idx] = cur_rf_mask
|
| 633 |
+
|
| 634 |
+
return dump_s_mask, dump_rf_mask
|
| 635 |
+
|
| 636 |
+
def update_singletons_and_chunks(
|
| 637 |
+
self,
|
| 638 |
+
k,
|
| 639 |
+
v,
|
| 640 |
+
layer_idx,
|
| 641 |
+
window_size,
|
| 642 |
+
):
|
| 643 |
+
if len(self.past_window_pos) <= layer_idx:
|
| 644 |
+
# prefill stage
|
| 645 |
+
s_k = k
|
| 646 |
+
s_v = v
|
| 647 |
+
input_len = k.shape[-2]
|
| 648 |
+
window_pos = 0
|
| 649 |
+
if input_len <= window_size:
|
| 650 |
+
new_window_pos = window_pos + input_len
|
| 651 |
+
|
| 652 |
+
cached_window_k = k
|
| 653 |
+
cached_window_v = v
|
| 654 |
+
dump_k = None
|
| 655 |
+
dump_v = None
|
| 656 |
+
else:
|
| 657 |
+
remainder_tokens = input_len % window_size
|
| 658 |
+
if remainder_tokens == 0:
|
| 659 |
+
remainder_tokens = window_size
|
| 660 |
+
new_window_pos = window_pos + remainder_tokens
|
| 661 |
+
|
| 662 |
+
# [b, h, n-r, d] [b, h, r, d]
|
| 663 |
+
cached_window_k = k[..., -remainder_tokens:, :]
|
| 664 |
+
cached_window_v = v[..., -remainder_tokens:, :]
|
| 665 |
+
dump_k = k[..., :-remainder_tokens, :]
|
| 666 |
+
dump_v = v[..., :-remainder_tokens, :]
|
| 667 |
+
# store the past window-wise key-value pairs
|
| 668 |
+
self.past_window_k[layer_idx][:, :, window_pos : new_window_pos, :] = cached_window_k
|
| 669 |
+
self.past_window_v[layer_idx][:, :, window_pos : new_window_pos, :] = cached_window_v
|
| 670 |
+
self.past_window_pos.append(new_window_pos)
|
| 671 |
+
else:
|
| 672 |
+
# decoding stage
|
| 673 |
+
# if the previous cache has full tokens,
|
| 674 |
+
# roll back to the first elements
|
| 675 |
+
if self.past_window_pos[layer_idx] == window_size:
|
| 676 |
+
self.past_window_pos[layer_idx] = 0
|
| 677 |
+
dump_k = self.past_window_k[layer_idx].clone()
|
| 678 |
+
dump_v = self.past_window_v[layer_idx].clone()
|
| 679 |
+
else:
|
| 680 |
+
dump_k = None
|
| 681 |
+
dump_v = None
|
| 682 |
+
|
| 683 |
+
input_len = k.shape[-2]
|
| 684 |
+
window_pos = self.past_window_pos[layer_idx]
|
| 685 |
+
new_window_pos = window_pos + input_len
|
| 686 |
+
|
| 687 |
+
self.past_window_k[layer_idx][:, :, window_pos : new_window_pos, :] = k
|
| 688 |
+
self.past_window_v[layer_idx][:, :, window_pos : new_window_pos, :] = v
|
| 689 |
+
|
| 690 |
+
s_k = self.past_window_k[layer_idx][:, :, : new_window_pos, :]
|
| 691 |
+
s_v = self.past_window_v[layer_idx][:, :, : new_window_pos, :]
|
| 692 |
+
|
| 693 |
+
self.past_window_pos[layer_idx] = new_window_pos
|
| 694 |
+
|
| 695 |
+
return s_k, s_v, dump_k, dump_v
|
| 696 |
+
|
| 697 |
+
def update_chunk_rfas(
|
| 698 |
+
self,
|
| 699 |
+
rfa_k,
|
| 700 |
+
rfa_v,
|
| 701 |
+
layer_idx,
|
| 702 |
+
):
|
| 703 |
+
if len(self.rfa_k) <= layer_idx:
|
| 704 |
+
# prefill stage
|
| 705 |
+
self.rfa_k.append(rfa_k)
|
| 706 |
+
self.rfa_v.append(rfa_v)
|
| 707 |
+
else:
|
| 708 |
+
# token decoding
|
| 709 |
+
past_rfa_k = self.rfa_k[layer_idx]
|
| 710 |
+
past_rfa_v = self.rfa_v[layer_idx]
|
| 711 |
+
|
| 712 |
+
if past_rfa_k is not None:
|
| 713 |
+
rfa_k = torch.cat([past_rfa_k, rfa_k], dim=-2)
|
| 714 |
+
|
| 715 |
+
if past_rfa_v is not None:
|
| 716 |
+
rfa_v = torch.cat([past_rfa_v, rfa_v], dim=-2)
|
| 717 |
+
|
| 718 |
+
self.rfa_k[layer_idx] = rfa_k
|
| 719 |
+
self.rfa_v[layer_idx] = rfa_v
|
| 720 |
+
|
| 721 |
+
return rfa_k, rfa_v
|
| 722 |
+
|
| 723 |
+
def get_past_window_pos(self, layer_idx):
|
| 724 |
+
if len(self.past_window_pos) <= layer_idx:
|
| 725 |
+
return None
|
| 726 |
+
else:
|
| 727 |
+
return self.past_window_pos[layer_idx]
|
| 728 |
+
|
| 729 |
+
def get_past_window_kv(self, layer_idx):
|
| 730 |
+
if len(self.past_window_pos) <= layer_idx:
|
| 731 |
+
return None, None
|
| 732 |
+
else:
|
| 733 |
+
return (
|
| 734 |
+
self.past_window_k[layer_idx][:, :, : self.past_window_pos[layer_idx], :],
|
| 735 |
+
self.past_window_v[layer_idx][:, :, : self.past_window_pos[layer_idx], :]
|
| 736 |
+
)
|
| 737 |
+
|
| 738 |
+
def get_chunk_rfas(self, layer_idx):
|
| 739 |
+
if len(self.rfa_k) <= layer_idx:
|
| 740 |
+
return None, None
|
| 741 |
+
else:
|
| 742 |
+
return self.rfa_k[layer_idx], self.rfa_v[layer_idx]
|
| 743 |
+
|
| 744 |
+
def get_seq_length(self, layer_idx = 0) -> int:
|
| 745 |
+
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
|
| 746 |
+
# layer_idx must be provided since otherwise
|
| 747 |
+
# any layer > 0 can only get the updated _seen_tokens
|
| 748 |
+
if len(self.past_window_pos) <= layer_idx:
|
| 749 |
+
return 0
|
| 750 |
+
return self._seen_tokens
|
| 751 |
+
|
| 752 |
+
def get_max_length(self) -> Optional[int]:
|
| 753 |
+
"""Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
|
| 754 |
+
return None
|
| 755 |
+
|
| 756 |
+
def update(
|
| 757 |
+
self,
|
| 758 |
+
layer_idx: int,
|
| 759 |
+
cache_kwargs: Optional[Dict[str, Any]] = None,
|
| 760 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 761 |
+
raise NotImplementedError("`update` is not used in Eva Cache.")
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/eva_prep_kv_kernel.py
ADDED
|
@@ -0,0 +1,1017 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import math
|
| 3 |
+
import torch
|
| 4 |
+
import triton
|
| 5 |
+
import triton.language as tl
|
| 6 |
+
|
| 7 |
+
@triton.heuristics(
|
| 8 |
+
{
|
| 9 |
+
"EVEN_N": lambda args: args["seqlen"] % args["BLOCK_N"] == 0,
|
| 10 |
+
"EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
|
| 11 |
+
}
|
| 12 |
+
)
|
| 13 |
+
@triton.jit
|
| 14 |
+
def _fwd_eva_prep_kv_kernel(
|
| 15 |
+
K, # [b, h, n, d]
|
| 16 |
+
V, # [b, h, n, d]
|
| 17 |
+
PARAM_MU, # [1, h, 1, 1, d]
|
| 18 |
+
PARAM_PHI, # [1, h, 1, 1, d]
|
| 19 |
+
Mask, # [b, h, n, 1]
|
| 20 |
+
Out_RFA_K, # [b, h, c, d]
|
| 21 |
+
Out_RFA_V, # [b, h, c, d]
|
| 22 |
+
softmax_scale,
|
| 23 |
+
stride_kb, stride_kh, stride_kn,
|
| 24 |
+
stride_vb, stride_vh, stride_vn,
|
| 25 |
+
stride_mu_h,
|
| 26 |
+
stride_phi_h,
|
| 27 |
+
stride_mb, stride_mn,
|
| 28 |
+
stride_ok_b, stride_ok_h, stride_ok_c,
|
| 29 |
+
stride_ov_b, stride_ov_h, stride_ov_c,
|
| 30 |
+
nheads,
|
| 31 |
+
seqlen,
|
| 32 |
+
nchunks,
|
| 33 |
+
headdim,
|
| 34 |
+
CHUNKS_PER_BLOCK: tl.constexpr,
|
| 35 |
+
CHUNK_SIZE: tl.constexpr,
|
| 36 |
+
MASK_TYPE: tl.constexpr,
|
| 37 |
+
BLOCK_HEADDIM: tl.constexpr,
|
| 38 |
+
EVEN_N: tl.constexpr,
|
| 39 |
+
EVEN_HEADDIM: tl.constexpr,
|
| 40 |
+
BLOCK_N: tl.constexpr,
|
| 41 |
+
):
|
| 42 |
+
start_n = tl.program_id(0)
|
| 43 |
+
offs_bh = tl.program_id(1)
|
| 44 |
+
offs_h = offs_bh % nheads
|
| 45 |
+
offs_b = offs_bh // nheads
|
| 46 |
+
# initialize offsets
|
| 47 |
+
# we load BLOCK_N keys and values each time, and
|
| 48 |
+
# reshape it to [CHUNKS_PER_BLOCK, CHUNK_SIZE]
|
| 49 |
+
offs_c = tl.arange(0, CHUNKS_PER_BLOCK)
|
| 50 |
+
offs_m = tl.arange(0, CHUNK_SIZE)
|
| 51 |
+
offs_d = tl.arange(0, BLOCK_HEADDIM)
|
| 52 |
+
|
| 53 |
+
k_ptrs = (
|
| 54 |
+
K +
|
| 55 |
+
offs_b * stride_kb +
|
| 56 |
+
offs_h * stride_kh +
|
| 57 |
+
(
|
| 58 |
+
(
|
| 59 |
+
start_n * BLOCK_N +
|
| 60 |
+
offs_c[:, None, None] * CHUNK_SIZE +
|
| 61 |
+
offs_m[None, :, None]
|
| 62 |
+
) * stride_kn +
|
| 63 |
+
offs_d[None, None, :]
|
| 64 |
+
)
|
| 65 |
+
)
|
| 66 |
+
v_ptrs = (
|
| 67 |
+
V +
|
| 68 |
+
offs_b * stride_vb +
|
| 69 |
+
offs_h * stride_vh +
|
| 70 |
+
(
|
| 71 |
+
(
|
| 72 |
+
start_n * BLOCK_N +
|
| 73 |
+
offs_c[:, None, None] * CHUNK_SIZE +
|
| 74 |
+
offs_m[None, :, None]
|
| 75 |
+
) * stride_vn +
|
| 76 |
+
offs_d[None, None, :]
|
| 77 |
+
)
|
| 78 |
+
)
|
| 79 |
+
param_mu_ptrs = (
|
| 80 |
+
PARAM_MU +
|
| 81 |
+
offs_h * stride_mu_h +
|
| 82 |
+
offs_d[None, None, :]
|
| 83 |
+
)
|
| 84 |
+
param_phi_ptrs = (
|
| 85 |
+
PARAM_PHI +
|
| 86 |
+
offs_h * stride_phi_h +
|
| 87 |
+
offs_d[None, None, :]
|
| 88 |
+
)
|
| 89 |
+
log2e = 1.4426950408889634
|
| 90 |
+
if MASK_TYPE == 1:
|
| 91 |
+
m_ptrs = (
|
| 92 |
+
Mask +
|
| 93 |
+
offs_b * stride_mb +
|
| 94 |
+
(
|
| 95 |
+
(
|
| 96 |
+
start_n * BLOCK_N +
|
| 97 |
+
offs_c[:, None] * CHUNK_SIZE +
|
| 98 |
+
offs_m[None, :]
|
| 99 |
+
) * stride_mn
|
| 100 |
+
)
|
| 101 |
+
)
|
| 102 |
+
if EVEN_N:
|
| 103 |
+
if EVEN_HEADDIM:
|
| 104 |
+
k = tl.load(
|
| 105 |
+
k_ptrs
|
| 106 |
+
)
|
| 107 |
+
else:
|
| 108 |
+
k = tl.load(
|
| 109 |
+
k_ptrs,
|
| 110 |
+
mask=offs_d[None, None, :] < headdim,
|
| 111 |
+
other=0.0
|
| 112 |
+
)
|
| 113 |
+
else:
|
| 114 |
+
if EVEN_HEADDIM:
|
| 115 |
+
k = tl.load(
|
| 116 |
+
k_ptrs,
|
| 117 |
+
mask=(
|
| 118 |
+
start_n * BLOCK_N +
|
| 119 |
+
offs_c[:, None, None] * CHUNK_SIZE +
|
| 120 |
+
offs_m[None, :, None]
|
| 121 |
+
) < seqlen,
|
| 122 |
+
other=0.0
|
| 123 |
+
)
|
| 124 |
+
else:
|
| 125 |
+
k = tl.load(
|
| 126 |
+
k_ptrs,
|
| 127 |
+
mask=(
|
| 128 |
+
(
|
| 129 |
+
start_n * BLOCK_N +
|
| 130 |
+
offs_c[:, None, None] * CHUNK_SIZE +
|
| 131 |
+
offs_m[None, :, None]
|
| 132 |
+
) < seqlen
|
| 133 |
+
) & (offs_d[None, None, :] < headdim),
|
| 134 |
+
other=0.0
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
param_mu = tl.load(param_mu_ptrs).to(k.dtype)
|
| 138 |
+
rfa_k_c_w = tl.zeros([CHUNKS_PER_BLOCK, CHUNK_SIZE], dtype=tl.float32)
|
| 139 |
+
rfa_k_c_w += tl.sum(k * param_mu, axis=-1)
|
| 140 |
+
rfa_k_c_w *= log2e
|
| 141 |
+
if MASK_TYPE == 1:
|
| 142 |
+
if EVEN_N:
|
| 143 |
+
mask = tl.load(
|
| 144 |
+
m_ptrs
|
| 145 |
+
)
|
| 146 |
+
else:
|
| 147 |
+
mask = tl.load(
|
| 148 |
+
m_ptrs,
|
| 149 |
+
mask=(
|
| 150 |
+
start_n * BLOCK_N +
|
| 151 |
+
offs_c[:, None] * CHUNK_SIZE +
|
| 152 |
+
offs_m[None, :]
|
| 153 |
+
) < seqlen,
|
| 154 |
+
other=1,
|
| 155 |
+
)
|
| 156 |
+
rfa_k_c_w = tl.where(mask, float("-inf"), rfa_k_c_w)
|
| 157 |
+
|
| 158 |
+
m_rfa_k_c_w = tl.max(rfa_k_c_w, axis=-1)
|
| 159 |
+
masked_out_rows_rfa_k = (m_rfa_k_c_w == float("-inf"))
|
| 160 |
+
m_rfa_k_c_w_masked = tl.where(masked_out_rows_rfa_k, 0, m_rfa_k_c_w)
|
| 161 |
+
rfa_k_c_w = tl.exp2(rfa_k_c_w - m_rfa_k_c_w_masked[:, None])
|
| 162 |
+
denom_k = tl.sum(rfa_k_c_w, axis=-1)
|
| 163 |
+
denom_k = tl.where(denom_k == 0.0, 1.0, denom_k)
|
| 164 |
+
rfa_k_c_w = rfa_k_c_w / denom_k[:, None]
|
| 165 |
+
rfa_k_c = tl.sum(k * rfa_k_c_w[:, :, None].to(k.dtype), axis=-2)
|
| 166 |
+
# TODO: understand why rematerialize offsets to save registers?
|
| 167 |
+
offs_out_c = start_n * CHUNKS_PER_BLOCK + tl.arange(0, CHUNKS_PER_BLOCK)
|
| 168 |
+
out_rfa_k_ptrs = (
|
| 169 |
+
Out_RFA_K +
|
| 170 |
+
offs_b * stride_ok_b +
|
| 171 |
+
offs_h * stride_ok_h +
|
| 172 |
+
(offs_out_c[:, None] * stride_ok_c + offs_d[None, :])
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
if EVEN_N:
|
| 176 |
+
if EVEN_HEADDIM:
|
| 177 |
+
tl.store(
|
| 178 |
+
out_rfa_k_ptrs, rfa_k_c
|
| 179 |
+
)
|
| 180 |
+
else:
|
| 181 |
+
tl.store(
|
| 182 |
+
out_rfa_k_ptrs, rfa_k_c,
|
| 183 |
+
mask=offs_d[None, :] < headdim
|
| 184 |
+
)
|
| 185 |
+
else:
|
| 186 |
+
if EVEN_HEADDIM:
|
| 187 |
+
tl.store(
|
| 188 |
+
out_rfa_k_ptrs, rfa_k_c,
|
| 189 |
+
mask=offs_out_c[:, None] < nchunks
|
| 190 |
+
)
|
| 191 |
+
else:
|
| 192 |
+
tl.store(
|
| 193 |
+
out_rfa_k_ptrs, rfa_k_c,
|
| 194 |
+
mask=(offs_out_c[:, None] < nchunks) & (offs_d[None, :] < headdim)
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
param_phi = tl.load(param_phi_ptrs).to(k.dtype)
|
| 199 |
+
rfa_v_c_w = tl.zeros([CHUNKS_PER_BLOCK, CHUNK_SIZE], dtype=tl.float32)
|
| 200 |
+
rfa_v_c_w += tl.sum(k * param_phi, axis=-1)
|
| 201 |
+
rfa_v_c_w -= (0.5 * tl.sum(k * k, axis=-1))
|
| 202 |
+
rfa_v_c_w *= log2e * softmax_scale
|
| 203 |
+
if not EVEN_N: # Need to mask out otherwise the softmax is wrong
|
| 204 |
+
rfa_v_c_w += tl.where(
|
| 205 |
+
(
|
| 206 |
+
start_n * BLOCK_N +
|
| 207 |
+
offs_c[:, None] * CHUNK_SIZE +
|
| 208 |
+
offs_m[None, :]
|
| 209 |
+
) < seqlen,
|
| 210 |
+
0,
|
| 211 |
+
float("-inf")
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
if MASK_TYPE == 1:
|
| 215 |
+
rfa_v_c_w = tl.where(mask, float("-inf"), rfa_v_c_w)
|
| 216 |
+
|
| 217 |
+
if EVEN_N:
|
| 218 |
+
if EVEN_HEADDIM:
|
| 219 |
+
v = tl.load(
|
| 220 |
+
v_ptrs
|
| 221 |
+
)
|
| 222 |
+
else:
|
| 223 |
+
v = tl.load(
|
| 224 |
+
v_ptrs,
|
| 225 |
+
mask=offs_d[None, None, :] < headdim,
|
| 226 |
+
other=0.0
|
| 227 |
+
)
|
| 228 |
+
else:
|
| 229 |
+
if EVEN_HEADDIM:
|
| 230 |
+
v = tl.load(
|
| 231 |
+
v_ptrs,
|
| 232 |
+
mask=(
|
| 233 |
+
start_n * BLOCK_N +
|
| 234 |
+
offs_c[:, None, None] * CHUNK_SIZE +
|
| 235 |
+
offs_m[None, :, None]
|
| 236 |
+
) < seqlen,
|
| 237 |
+
other=0.0
|
| 238 |
+
)
|
| 239 |
+
else:
|
| 240 |
+
v = tl.load(
|
| 241 |
+
v_ptrs,
|
| 242 |
+
mask=(
|
| 243 |
+
(
|
| 244 |
+
start_n * BLOCK_N +
|
| 245 |
+
offs_c[:, None, None] * CHUNK_SIZE +
|
| 246 |
+
offs_m[None, :, None]
|
| 247 |
+
) < seqlen
|
| 248 |
+
) & (offs_d[None, None, :] < headdim),
|
| 249 |
+
other=0.0
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
m_rfa_v_c_w = tl.max(rfa_v_c_w, axis=-1)
|
| 254 |
+
masked_out_rows_rfa_v = (m_rfa_v_c_w == float("-inf"))
|
| 255 |
+
m_rfa_v_c_w_masked = tl.where(masked_out_rows_rfa_v, 0, m_rfa_v_c_w)
|
| 256 |
+
rfa_v_c_w = tl.exp2(rfa_v_c_w - m_rfa_v_c_w_masked[:, None])
|
| 257 |
+
denom_v = tl.sum(rfa_v_c_w, axis=-1)
|
| 258 |
+
denom_v = tl.where(denom_v == 0.0, 1.0, denom_v)
|
| 259 |
+
rfa_v_c_w = rfa_v_c_w / denom_v[:, None]
|
| 260 |
+
rfa_v_c = tl.sum(v * rfa_v_c_w[:, :, None].to(v.dtype), axis=-2)
|
| 261 |
+
|
| 262 |
+
offs_out_c = start_n * CHUNKS_PER_BLOCK + tl.arange(0, CHUNKS_PER_BLOCK)
|
| 263 |
+
out_rfa_v_ptrs = (
|
| 264 |
+
Out_RFA_V +
|
| 265 |
+
offs_b * stride_ov_b +
|
| 266 |
+
offs_h * stride_ov_h +
|
| 267 |
+
(offs_out_c[:, None] * stride_ov_c + offs_d[None, :])
|
| 268 |
+
)
|
| 269 |
+
if EVEN_N:
|
| 270 |
+
if EVEN_HEADDIM:
|
| 271 |
+
tl.store(
|
| 272 |
+
out_rfa_v_ptrs, rfa_v_c
|
| 273 |
+
)
|
| 274 |
+
else:
|
| 275 |
+
tl.store(
|
| 276 |
+
out_rfa_v_ptrs, rfa_v_c,
|
| 277 |
+
mask=offs_d[None, :] < headdim
|
| 278 |
+
)
|
| 279 |
+
else:
|
| 280 |
+
if EVEN_HEADDIM:
|
| 281 |
+
tl.store(
|
| 282 |
+
out_rfa_v_ptrs, rfa_v_c,
|
| 283 |
+
mask=offs_out_c[:, None] < nchunks
|
| 284 |
+
)
|
| 285 |
+
else:
|
| 286 |
+
tl.store(
|
| 287 |
+
out_rfa_v_ptrs, rfa_v_c,
|
| 288 |
+
mask=(offs_out_c[:, None] < nchunks) & (offs_d[None, :] < headdim)
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
@triton.heuristics(
|
| 294 |
+
{
|
| 295 |
+
"EVEN_N": lambda args: args["seqlen"] % args["BLOCK_N"] == 0,
|
| 296 |
+
"EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
|
| 297 |
+
}
|
| 298 |
+
)
|
| 299 |
+
@triton.jit
|
| 300 |
+
def _bwd_eva_prep_kv_kernel(
|
| 301 |
+
RFA_K, # [b, h, c, d]
|
| 302 |
+
RFA_V, # [b, h, c, d]
|
| 303 |
+
K, # [b, h, n, d]
|
| 304 |
+
V, # [b, h, n, d]
|
| 305 |
+
PARAM_MU, # [1, h, 1, 1, d]
|
| 306 |
+
PARAM_PHI, # [1, h, 1, 1, d]
|
| 307 |
+
Mask, # [b, h, n, 1]
|
| 308 |
+
D_RFA_K, # [b, h, c, d]
|
| 309 |
+
D_RFA_V, # [b, h, c, d]
|
| 310 |
+
D_K, # [b, h, n, d]
|
| 311 |
+
D_V, # [b, h, n, d]
|
| 312 |
+
D_PARAM_MU_PARTIAL, # [b, h, g, d]
|
| 313 |
+
D_PARAM_PHI_PARTIAL, # [b, h, g, d]
|
| 314 |
+
softmax_scale,
|
| 315 |
+
stride_rfa_k_b, stride_rfa_k_h, stride_rfa_k_c,
|
| 316 |
+
stride_rfa_v_b, stride_rfa_v_h, stride_rfa_v_c,
|
| 317 |
+
stride_kb, stride_kh, stride_kn,
|
| 318 |
+
stride_vb, stride_vh, stride_vn,
|
| 319 |
+
stride_mu_h,
|
| 320 |
+
stride_phi_h,
|
| 321 |
+
stride_mb, stride_mn,
|
| 322 |
+
stride_d_rfa_k_b, stride_d_rfa_k_h, stride_d_rfa_k_c,
|
| 323 |
+
stride_d_rfa_v_b, stride_d_rfa_v_h, stride_d_rfa_v_c,
|
| 324 |
+
stride_d_k_b, stride_d_k_h, stride_d_k_n,
|
| 325 |
+
stride_d_v_b, stride_d_v_h, stride_d_v_n,
|
| 326 |
+
stride_d_mu_b, stride_d_mu_h, stride_d_mu_g,
|
| 327 |
+
stride_d_phi_b, stride_d_phi_h, stride_d_phi_g,
|
| 328 |
+
nheads,
|
| 329 |
+
seqlen,
|
| 330 |
+
nchunks,
|
| 331 |
+
headdim,
|
| 332 |
+
CHUNKS_PER_BLOCK: tl.constexpr,
|
| 333 |
+
CHUNK_SIZE: tl.constexpr,
|
| 334 |
+
MASK_TYPE: tl.constexpr,
|
| 335 |
+
BLOCK_HEADDIM: tl.constexpr,
|
| 336 |
+
EVEN_N: tl.constexpr,
|
| 337 |
+
EVEN_HEADDIM: tl.constexpr,
|
| 338 |
+
BLOCK_N: tl.constexpr,
|
| 339 |
+
):
|
| 340 |
+
start_n = tl.program_id(0)
|
| 341 |
+
offs_bh = tl.program_id(1)
|
| 342 |
+
offs_h = offs_bh % nheads
|
| 343 |
+
offs_b = offs_bh // nheads
|
| 344 |
+
# initialize offsets
|
| 345 |
+
# we load BLOCK_N keys and values each time, and
|
| 346 |
+
# reshape it to [CHUNKS_PER_BLOCK, CHUNK_SIZE]
|
| 347 |
+
offs_c = tl.arange(0, CHUNKS_PER_BLOCK)
|
| 348 |
+
offs_m = tl.arange(0, CHUNK_SIZE)
|
| 349 |
+
offs_d = tl.arange(0, BLOCK_HEADDIM)
|
| 350 |
+
|
| 351 |
+
offs_rfa_c = start_n * CHUNKS_PER_BLOCK + offs_c
|
| 352 |
+
|
| 353 |
+
k_ptrs = (
|
| 354 |
+
K +
|
| 355 |
+
offs_b * stride_kb +
|
| 356 |
+
offs_h * stride_kh +
|
| 357 |
+
(
|
| 358 |
+
(
|
| 359 |
+
start_n * BLOCK_N +
|
| 360 |
+
offs_c[:, None, None] * CHUNK_SIZE +
|
| 361 |
+
offs_m[None, :, None]
|
| 362 |
+
) * stride_kn +
|
| 363 |
+
offs_d[None, None, :]
|
| 364 |
+
)
|
| 365 |
+
)
|
| 366 |
+
rfa_k_ptrs = (
|
| 367 |
+
RFA_K +
|
| 368 |
+
offs_b * stride_rfa_k_b +
|
| 369 |
+
offs_h * stride_rfa_k_h +
|
| 370 |
+
(offs_rfa_c[:, None] * stride_rfa_k_c + offs_d[None, :])
|
| 371 |
+
)
|
| 372 |
+
rfa_v_ptrs = (
|
| 373 |
+
RFA_V +
|
| 374 |
+
offs_b * stride_rfa_v_b +
|
| 375 |
+
offs_h * stride_rfa_v_h +
|
| 376 |
+
(offs_rfa_c[:, None] * stride_rfa_v_c + offs_d[None, :])
|
| 377 |
+
)
|
| 378 |
+
|
| 379 |
+
d_rfa_k_ptrs = (
|
| 380 |
+
D_RFA_K +
|
| 381 |
+
offs_b * stride_d_rfa_k_b +
|
| 382 |
+
offs_h * stride_d_rfa_k_h +
|
| 383 |
+
(offs_rfa_c[:, None] * stride_d_rfa_k_c + offs_d[None, :])
|
| 384 |
+
)
|
| 385 |
+
d_rfa_v_ptrs = (
|
| 386 |
+
D_RFA_V +
|
| 387 |
+
offs_b * stride_d_rfa_v_b +
|
| 388 |
+
offs_h * stride_d_rfa_v_h +
|
| 389 |
+
(offs_rfa_c[:, None] * stride_d_rfa_v_c + offs_d[None, :])
|
| 390 |
+
)
|
| 391 |
+
|
| 392 |
+
param_mu_ptrs = (
|
| 393 |
+
PARAM_MU +
|
| 394 |
+
offs_h * stride_mu_h +
|
| 395 |
+
offs_d[None, None, :]
|
| 396 |
+
)
|
| 397 |
+
param_phi_ptrs = (
|
| 398 |
+
PARAM_PHI +
|
| 399 |
+
offs_h * stride_phi_h +
|
| 400 |
+
offs_d[None, None, :]
|
| 401 |
+
)
|
| 402 |
+
|
| 403 |
+
log2e = 1.4426950408889634
|
| 404 |
+
if MASK_TYPE == 1:
|
| 405 |
+
m_ptrs = (
|
| 406 |
+
Mask +
|
| 407 |
+
offs_b * stride_mb +
|
| 408 |
+
(
|
| 409 |
+
(
|
| 410 |
+
start_n * BLOCK_N +
|
| 411 |
+
offs_c[:, None] * CHUNK_SIZE +
|
| 412 |
+
offs_m[None, :]
|
| 413 |
+
) * stride_mn
|
| 414 |
+
)
|
| 415 |
+
)
|
| 416 |
+
if EVEN_N:
|
| 417 |
+
if EVEN_HEADDIM:
|
| 418 |
+
k = tl.load(
|
| 419 |
+
k_ptrs
|
| 420 |
+
)
|
| 421 |
+
else:
|
| 422 |
+
k = tl.load(
|
| 423 |
+
k_ptrs,
|
| 424 |
+
mask=offs_d[None, None, :] < headdim,
|
| 425 |
+
other=0.0
|
| 426 |
+
)
|
| 427 |
+
else:
|
| 428 |
+
if EVEN_HEADDIM:
|
| 429 |
+
k = tl.load(
|
| 430 |
+
k_ptrs,
|
| 431 |
+
mask=(
|
| 432 |
+
start_n * BLOCK_N +
|
| 433 |
+
offs_c[:, None, None] * CHUNK_SIZE +
|
| 434 |
+
offs_m[None, :, None]
|
| 435 |
+
) < seqlen,
|
| 436 |
+
other=0.0
|
| 437 |
+
)
|
| 438 |
+
else:
|
| 439 |
+
k = tl.load(
|
| 440 |
+
k_ptrs,
|
| 441 |
+
mask=(
|
| 442 |
+
(
|
| 443 |
+
start_n * BLOCK_N +
|
| 444 |
+
offs_c[:, None, None] * CHUNK_SIZE +
|
| 445 |
+
offs_m[None, :, None]
|
| 446 |
+
) < seqlen
|
| 447 |
+
) & (offs_d[None, None, :] < headdim),
|
| 448 |
+
other=0.0
|
| 449 |
+
)
|
| 450 |
+
|
| 451 |
+
if EVEN_N:
|
| 452 |
+
if EVEN_HEADDIM:
|
| 453 |
+
rfa_k = tl.load(
|
| 454 |
+
rfa_k_ptrs
|
| 455 |
+
)
|
| 456 |
+
else:
|
| 457 |
+
rfa_k = tl.load(
|
| 458 |
+
rfa_k_ptrs,
|
| 459 |
+
mask=offs_d[None, :] < headdim,
|
| 460 |
+
other=0.0
|
| 461 |
+
)
|
| 462 |
+
else:
|
| 463 |
+
if EVEN_HEADDIM:
|
| 464 |
+
rfa_k = tl.load(
|
| 465 |
+
rfa_k_ptrs,
|
| 466 |
+
mask=offs_rfa_c[:, None] < nchunks,
|
| 467 |
+
other=0.0
|
| 468 |
+
)
|
| 469 |
+
else:
|
| 470 |
+
rfa_k = tl.load(
|
| 471 |
+
rfa_k_ptrs,
|
| 472 |
+
mask=(offs_rfa_c[:, None] < nchunks) & (offs_d[None, :] < headdim),
|
| 473 |
+
other=0.0
|
| 474 |
+
)
|
| 475 |
+
|
| 476 |
+
if EVEN_N:
|
| 477 |
+
if EVEN_HEADDIM:
|
| 478 |
+
d_rfa_k = tl.load(
|
| 479 |
+
d_rfa_k_ptrs
|
| 480 |
+
)
|
| 481 |
+
else:
|
| 482 |
+
d_rfa_k = tl.load(
|
| 483 |
+
d_rfa_k_ptrs,
|
| 484 |
+
mask=offs_d[None, :] < headdim,
|
| 485 |
+
other=0.0
|
| 486 |
+
)
|
| 487 |
+
else:
|
| 488 |
+
if EVEN_HEADDIM:
|
| 489 |
+
d_rfa_k = tl.load(
|
| 490 |
+
d_rfa_k_ptrs,
|
| 491 |
+
mask=offs_rfa_c[:, None] < nchunks,
|
| 492 |
+
other=0.0
|
| 493 |
+
)
|
| 494 |
+
else:
|
| 495 |
+
d_rfa_k = tl.load(
|
| 496 |
+
d_rfa_k_ptrs,
|
| 497 |
+
mask=(offs_rfa_c[:, None] < nchunks) & (offs_d[None, :] < headdim),
|
| 498 |
+
other=0.0
|
| 499 |
+
)
|
| 500 |
+
|
| 501 |
+
param_mu = tl.load(param_mu_ptrs).to(k.dtype)
|
| 502 |
+
mu_c_w = tl.zeros([CHUNKS_PER_BLOCK, CHUNK_SIZE], dtype=tl.float32)
|
| 503 |
+
mu_c_w += tl.sum(k * param_mu, axis=-1)
|
| 504 |
+
mu_c_w *= log2e
|
| 505 |
+
|
| 506 |
+
if not EVEN_N: # Need to mask out otherwise the softmax is wrong
|
| 507 |
+
mu_c_w += tl.where(
|
| 508 |
+
(
|
| 509 |
+
start_n * BLOCK_N +
|
| 510 |
+
offs_c[:, None] * CHUNK_SIZE +
|
| 511 |
+
offs_m[None, :]
|
| 512 |
+
) < seqlen,
|
| 513 |
+
0,
|
| 514 |
+
float("-inf")
|
| 515 |
+
)
|
| 516 |
+
|
| 517 |
+
if MASK_TYPE == 1:
|
| 518 |
+
if EVEN_N:
|
| 519 |
+
mask = tl.load(
|
| 520 |
+
m_ptrs
|
| 521 |
+
)
|
| 522 |
+
else:
|
| 523 |
+
mask = tl.load(
|
| 524 |
+
m_ptrs,
|
| 525 |
+
mask=(
|
| 526 |
+
start_n * BLOCK_N +
|
| 527 |
+
offs_c[:, None] * CHUNK_SIZE +
|
| 528 |
+
offs_m[None, :]
|
| 529 |
+
) < seqlen,
|
| 530 |
+
other=1,
|
| 531 |
+
)
|
| 532 |
+
mu_c_w = tl.where(mask, float("-inf"), mu_c_w)
|
| 533 |
+
|
| 534 |
+
# [c, w]
|
| 535 |
+
m_mu_c_w = tl.max(mu_c_w, axis=-1)
|
| 536 |
+
masked_out_rows_mu = (m_mu_c_w == float("-inf"))
|
| 537 |
+
m_mu_c_w_masked = tl.where(masked_out_rows_mu, 0, m_mu_c_w)
|
| 538 |
+
mu_c_w = tl.exp2(mu_c_w - m_mu_c_w_masked[:, None])
|
| 539 |
+
denom_mu = tl.sum(mu_c_w, axis=-1)
|
| 540 |
+
denom_mu = tl.where(denom_mu == 0.0, 1.0, denom_mu)
|
| 541 |
+
mu_tilde_c_w = mu_c_w / denom_mu[:, None]
|
| 542 |
+
mu_tilde_c_w = mu_tilde_c_w.to(k.dtype)
|
| 543 |
+
# [c, d] [c, w, d] -> [c, w]
|
| 544 |
+
d_mu_tilde_c_w = tl.sum(d_rfa_k[:, None, :] * k, axis=-1)
|
| 545 |
+
# [c, d] [c, d] -> [c]
|
| 546 |
+
d_out_rfa_k_t_rfa_k = tl.sum(d_rfa_k * rfa_k, axis=-1)[:, None]
|
| 547 |
+
d_mu_c_w = (d_mu_tilde_c_w - d_out_rfa_k_t_rfa_k) * mu_tilde_c_w
|
| 548 |
+
|
| 549 |
+
# [c, w] [c, w, d] -> [d]
|
| 550 |
+
d_param_mu = tl.sum(tl.sum(d_mu_c_w[:, :, None] * k, axis=0), axis=0)
|
| 551 |
+
# [c, w] [c, d] + [c, w] [1, 1, d] -> [c, w, d]
|
| 552 |
+
d_k = mu_tilde_c_w[:, :, None] * d_rfa_k[:, None, :] + d_mu_c_w[:, :, None] * param_mu
|
| 553 |
+
|
| 554 |
+
d_param_mu_partial_ptrs = (
|
| 555 |
+
D_PARAM_MU_PARTIAL +
|
| 556 |
+
offs_b * stride_d_mu_b +
|
| 557 |
+
offs_h * stride_d_mu_h +
|
| 558 |
+
start_n * stride_d_mu_g +
|
| 559 |
+
offs_d
|
| 560 |
+
)
|
| 561 |
+
if EVEN_HEADDIM:
|
| 562 |
+
tl.store(
|
| 563 |
+
d_param_mu_partial_ptrs, d_param_mu
|
| 564 |
+
)
|
| 565 |
+
else:
|
| 566 |
+
tl.store(
|
| 567 |
+
d_param_mu_partial_ptrs, d_param_mu,
|
| 568 |
+
mask=offs_d < headdim
|
| 569 |
+
)
|
| 570 |
+
|
| 571 |
+
|
| 572 |
+
v_ptrs = (
|
| 573 |
+
V +
|
| 574 |
+
offs_b * stride_vb +
|
| 575 |
+
offs_h * stride_vh +
|
| 576 |
+
(
|
| 577 |
+
(
|
| 578 |
+
start_n * BLOCK_N +
|
| 579 |
+
offs_c[:, None, None] * CHUNK_SIZE +
|
| 580 |
+
offs_m[None, :, None]
|
| 581 |
+
) * stride_vn +
|
| 582 |
+
offs_d[None, None, :]
|
| 583 |
+
)
|
| 584 |
+
)
|
| 585 |
+
if EVEN_N:
|
| 586 |
+
if EVEN_HEADDIM:
|
| 587 |
+
v = tl.load(
|
| 588 |
+
v_ptrs
|
| 589 |
+
)
|
| 590 |
+
else:
|
| 591 |
+
v = tl.load(
|
| 592 |
+
v_ptrs,
|
| 593 |
+
mask=offs_d[None, None, :] < headdim,
|
| 594 |
+
other=0.0
|
| 595 |
+
)
|
| 596 |
+
else:
|
| 597 |
+
if EVEN_HEADDIM:
|
| 598 |
+
v = tl.load(
|
| 599 |
+
v_ptrs,
|
| 600 |
+
mask=(
|
| 601 |
+
start_n * BLOCK_N +
|
| 602 |
+
offs_c[:, None, None] * CHUNK_SIZE +
|
| 603 |
+
offs_m[None, :, None]
|
| 604 |
+
) < seqlen,
|
| 605 |
+
other=0.0
|
| 606 |
+
)
|
| 607 |
+
else:
|
| 608 |
+
v = tl.load(
|
| 609 |
+
v_ptrs,
|
| 610 |
+
mask=(
|
| 611 |
+
(
|
| 612 |
+
start_n * BLOCK_N +
|
| 613 |
+
offs_c[:, None, None] * CHUNK_SIZE +
|
| 614 |
+
offs_m[None, :, None]
|
| 615 |
+
) < seqlen
|
| 616 |
+
) & (offs_d[None, None, :] < headdim),
|
| 617 |
+
other=0.0
|
| 618 |
+
)
|
| 619 |
+
|
| 620 |
+
|
| 621 |
+
if EVEN_N:
|
| 622 |
+
if EVEN_HEADDIM:
|
| 623 |
+
rfa_v = tl.load(
|
| 624 |
+
rfa_v_ptrs
|
| 625 |
+
)
|
| 626 |
+
else:
|
| 627 |
+
rfa_v = tl.load(
|
| 628 |
+
rfa_v_ptrs,
|
| 629 |
+
mask=offs_d[None, :] < headdim,
|
| 630 |
+
other=0.0
|
| 631 |
+
)
|
| 632 |
+
else:
|
| 633 |
+
if EVEN_HEADDIM:
|
| 634 |
+
rfa_v = tl.load(
|
| 635 |
+
rfa_v_ptrs,
|
| 636 |
+
mask=offs_rfa_c[:, None] < nchunks,
|
| 637 |
+
other=0.0
|
| 638 |
+
)
|
| 639 |
+
else:
|
| 640 |
+
rfa_v = tl.load(
|
| 641 |
+
rfa_v_ptrs,
|
| 642 |
+
mask=(offs_rfa_c[:, None] < nchunks) & (offs_d[None, :] < headdim),
|
| 643 |
+
other=0.0
|
| 644 |
+
)
|
| 645 |
+
|
| 646 |
+
if EVEN_N:
|
| 647 |
+
if EVEN_HEADDIM:
|
| 648 |
+
d_rfa_v = tl.load(
|
| 649 |
+
d_rfa_v_ptrs
|
| 650 |
+
)
|
| 651 |
+
else:
|
| 652 |
+
d_rfa_v = tl.load(
|
| 653 |
+
d_rfa_v_ptrs,
|
| 654 |
+
mask=offs_d[None, :] < headdim,
|
| 655 |
+
other=0.0
|
| 656 |
+
)
|
| 657 |
+
else:
|
| 658 |
+
if EVEN_HEADDIM:
|
| 659 |
+
d_rfa_v = tl.load(
|
| 660 |
+
d_rfa_v_ptrs,
|
| 661 |
+
mask=offs_rfa_c[:, None] < nchunks,
|
| 662 |
+
other=0.0
|
| 663 |
+
)
|
| 664 |
+
else:
|
| 665 |
+
d_rfa_v = tl.load(
|
| 666 |
+
d_rfa_v_ptrs,
|
| 667 |
+
mask=(offs_rfa_c[:, None] < nchunks) & (offs_d[None, :] < headdim),
|
| 668 |
+
other=0.0
|
| 669 |
+
)
|
| 670 |
+
|
| 671 |
+
param_phi = tl.load(param_phi_ptrs).to(k.dtype)
|
| 672 |
+
phi_c_w = tl.zeros([CHUNKS_PER_BLOCK, CHUNK_SIZE], dtype=tl.float32)
|
| 673 |
+
phi_c_w += tl.sum(k * param_phi, axis=-1)
|
| 674 |
+
phi_c_w -= (0.5 * tl.sum(k * k, axis=-1))
|
| 675 |
+
phi_c_w *= log2e * softmax_scale
|
| 676 |
+
if not EVEN_N: # Need to mask out otherwise the softmax is wrong
|
| 677 |
+
phi_c_w += tl.where(
|
| 678 |
+
(
|
| 679 |
+
start_n * BLOCK_N +
|
| 680 |
+
offs_c[:, None] * CHUNK_SIZE +
|
| 681 |
+
offs_m[None, :]
|
| 682 |
+
) < seqlen,
|
| 683 |
+
0,
|
| 684 |
+
float("-inf")
|
| 685 |
+
)
|
| 686 |
+
|
| 687 |
+
if MASK_TYPE == 1:
|
| 688 |
+
phi_c_w = tl.where(mask, float("-inf"), phi_c_w)
|
| 689 |
+
|
| 690 |
+
|
| 691 |
+
m_phi_c_w = tl.max(phi_c_w, axis=-1)
|
| 692 |
+
masked_out_rows_phi = (m_phi_c_w == float("-inf"))
|
| 693 |
+
m_phi_c_w_masked = tl.where(masked_out_rows_phi, 0, m_phi_c_w)
|
| 694 |
+
phi_c_w = tl.exp2(phi_c_w - m_phi_c_w_masked[:, None])
|
| 695 |
+
denom_phi = tl.sum(phi_c_w, axis=-1)
|
| 696 |
+
denom_phi = tl.where(denom_phi == 0.0, 1.0, denom_phi)
|
| 697 |
+
phi_tilde_c_w = phi_c_w / denom_phi[:, None]
|
| 698 |
+
# phi_c_w = tl.exp2(phi_c_w - tl.max(phi_c_w, axis=-1)[:, None])
|
| 699 |
+
# phi_tilde_c_w = phi_c_w / tl.sum(phi_c_w, axis=-1)[:, None]
|
| 700 |
+
phi_tilde_c_w = phi_tilde_c_w.to(k.dtype)
|
| 701 |
+
d_phi_tilde_c_w = tl.sum(d_rfa_v[:, None, :] * v, axis=-1)
|
| 702 |
+
d_out_rfa_v_t_rfa_v = tl.sum(d_rfa_v * rfa_v, axis=-1)[:, None]
|
| 703 |
+
d_phi_c_w = (d_phi_tilde_c_w.to(tl.float32) - d_out_rfa_v_t_rfa_v.to(tl.float32)) * phi_tilde_c_w
|
| 704 |
+
|
| 705 |
+
d_param_phi = tl.sum(tl.sum(d_phi_c_w[:, :, None] * k * softmax_scale, axis=0), axis=0)
|
| 706 |
+
d_v = phi_tilde_c_w[:, :, None] * d_rfa_v[:, None, :]
|
| 707 |
+
# [c, w, d] + [c, w] * [1, 1, d] - [c, w, d]
|
| 708 |
+
d_k = d_k + softmax_scale * d_phi_c_w[:, :, None] * (param_phi - k)
|
| 709 |
+
|
| 710 |
+
d_k_ptrs = (
|
| 711 |
+
D_K +
|
| 712 |
+
offs_b * stride_d_k_b +
|
| 713 |
+
offs_h * stride_d_k_h +
|
| 714 |
+
(
|
| 715 |
+
(
|
| 716 |
+
start_n * BLOCK_N +
|
| 717 |
+
offs_c[:, None, None] * CHUNK_SIZE +
|
| 718 |
+
offs_m[None, :, None]
|
| 719 |
+
) * stride_d_k_n +
|
| 720 |
+
offs_d[None, None, :]
|
| 721 |
+
)
|
| 722 |
+
)
|
| 723 |
+
d_v_ptrs = (
|
| 724 |
+
D_V +
|
| 725 |
+
offs_b * stride_d_v_b +
|
| 726 |
+
offs_h * stride_d_v_h +
|
| 727 |
+
(
|
| 728 |
+
(
|
| 729 |
+
start_n * BLOCK_N +
|
| 730 |
+
offs_c[:, None, None] * CHUNK_SIZE +
|
| 731 |
+
offs_m[None, :, None]
|
| 732 |
+
) * stride_d_v_n +
|
| 733 |
+
offs_d[None, None, :]
|
| 734 |
+
)
|
| 735 |
+
)
|
| 736 |
+
if EVEN_N:
|
| 737 |
+
if EVEN_HEADDIM:
|
| 738 |
+
tl.store(
|
| 739 |
+
d_k_ptrs, d_k
|
| 740 |
+
)
|
| 741 |
+
tl.store(
|
| 742 |
+
d_v_ptrs, d_v
|
| 743 |
+
)
|
| 744 |
+
else:
|
| 745 |
+
tl.store(
|
| 746 |
+
d_k_ptrs, d_k,
|
| 747 |
+
mask=offs_d[None, None, :] < headdim
|
| 748 |
+
)
|
| 749 |
+
tl.store(
|
| 750 |
+
d_v_ptrs, d_v,
|
| 751 |
+
mask=offs_d[None, None, :] < headdim
|
| 752 |
+
)
|
| 753 |
+
else:
|
| 754 |
+
if EVEN_HEADDIM:
|
| 755 |
+
tl.store(
|
| 756 |
+
d_k_ptrs, d_k,
|
| 757 |
+
mask=(
|
| 758 |
+
(
|
| 759 |
+
start_n * BLOCK_N +
|
| 760 |
+
offs_c[:, None, None] * CHUNK_SIZE +
|
| 761 |
+
offs_m[None, :, None]
|
| 762 |
+
) < seqlen
|
| 763 |
+
),
|
| 764 |
+
)
|
| 765 |
+
tl.store(
|
| 766 |
+
d_v_ptrs, d_v,
|
| 767 |
+
mask=(
|
| 768 |
+
(
|
| 769 |
+
start_n * BLOCK_N +
|
| 770 |
+
offs_c[:, None, None] * CHUNK_SIZE +
|
| 771 |
+
offs_m[None, :, None]
|
| 772 |
+
) < seqlen
|
| 773 |
+
),
|
| 774 |
+
)
|
| 775 |
+
else:
|
| 776 |
+
tl.store(
|
| 777 |
+
d_k_ptrs, d_k,
|
| 778 |
+
mask=(
|
| 779 |
+
(
|
| 780 |
+
start_n * BLOCK_N +
|
| 781 |
+
offs_c[:, None, None] * CHUNK_SIZE +
|
| 782 |
+
offs_m[None, :, None]
|
| 783 |
+
) < seqlen
|
| 784 |
+
) & (offs_d[None, None, :] < headdim),
|
| 785 |
+
)
|
| 786 |
+
tl.store(
|
| 787 |
+
d_v_ptrs, d_v,
|
| 788 |
+
mask=(
|
| 789 |
+
(
|
| 790 |
+
start_n * BLOCK_N +
|
| 791 |
+
offs_c[:, None, None] * CHUNK_SIZE +
|
| 792 |
+
offs_m[None, :, None]
|
| 793 |
+
) < seqlen
|
| 794 |
+
) & (offs_d[None, None, :] < headdim),
|
| 795 |
+
)
|
| 796 |
+
d_param_phi_partial_ptrs = (
|
| 797 |
+
D_PARAM_PHI_PARTIAL +
|
| 798 |
+
offs_b * stride_d_phi_b +
|
| 799 |
+
offs_h * stride_d_phi_h +
|
| 800 |
+
start_n * stride_d_phi_g +
|
| 801 |
+
offs_d
|
| 802 |
+
)
|
| 803 |
+
if EVEN_HEADDIM:
|
| 804 |
+
tl.store(
|
| 805 |
+
d_param_phi_partial_ptrs, d_param_phi
|
| 806 |
+
)
|
| 807 |
+
else:
|
| 808 |
+
tl.store(
|
| 809 |
+
d_param_phi_partial_ptrs, d_param_phi,
|
| 810 |
+
mask=offs_d < headdim
|
| 811 |
+
)
|
| 812 |
+
|
| 813 |
+
def triton_eva_prep_kv_fwd(k, v, param_mu, param_phi, mask, softmax_scale, chunksize):
|
| 814 |
+
k, v, param_mu, param_phi = [
|
| 815 |
+
x if x.stride(-1) == 1 else x.contiguous()
|
| 816 |
+
for x in [k, v, param_mu, param_phi]
|
| 817 |
+
]
|
| 818 |
+
|
| 819 |
+
# shape constraints
|
| 820 |
+
batch, nheads, seqlen, head_dim = k.shape
|
| 821 |
+
assert seqlen % chunksize == 0, "seqlen must be divisible by chunksize"
|
| 822 |
+
nchunks = seqlen // chunksize
|
| 823 |
+
assert k.shape == (batch, nheads, seqlen, head_dim)
|
| 824 |
+
assert v.shape == (batch, nheads, seqlen, head_dim)
|
| 825 |
+
assert param_mu.shape == (1, nheads, 1, 1, head_dim)
|
| 826 |
+
assert param_phi.shape == (1, nheads, 1, 1, head_dim)
|
| 827 |
+
assert head_dim <= 128, "We only test head dimensions up to 128"
|
| 828 |
+
assert k.dtype == v.dtype == param_mu.dtype == param_phi.dtype, "All tensors must have the same type"
|
| 829 |
+
assert k.dtype in [torch.bfloat16, torch.float], "Only support bf16 and fp32 for now"
|
| 830 |
+
assert k.is_cuda and v.is_cuda
|
| 831 |
+
softmax_scale = softmax_scale or 1.0 / math.sqrt(head_dim)
|
| 832 |
+
|
| 833 |
+
mask_type = 0
|
| 834 |
+
if mask is not None:
|
| 835 |
+
mask_type = 1
|
| 836 |
+
assert mask.dtype == torch.bool
|
| 837 |
+
assert mask.is_cuda
|
| 838 |
+
assert mask.dim() == 4
|
| 839 |
+
assert mask.shape == (batch, 1, seqlen, 1)
|
| 840 |
+
if mask.stride(-1) != 1:
|
| 841 |
+
mask = mask.contiguous()
|
| 842 |
+
mask_strides = (
|
| 843 |
+
(mask.stride(0), mask.stride(2))
|
| 844 |
+
if mask_type == 1 else
|
| 845 |
+
(0, 0)
|
| 846 |
+
)
|
| 847 |
+
out_rfa_k = torch.empty((batch, nheads, nchunks, head_dim), dtype=k.dtype, device=k.device)
|
| 848 |
+
out_rfa_v = torch.empty((batch, nheads, nchunks, head_dim), dtype=v.dtype, device=v.device)
|
| 849 |
+
|
| 850 |
+
BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16)
|
| 851 |
+
BLOCK = 128
|
| 852 |
+
num_warps = 4 if head_dim <= 64 else 8
|
| 853 |
+
|
| 854 |
+
assert (BLOCK > chunksize) & (BLOCK % chunksize) == 0, "BLOCK must be divisible by chunksize"
|
| 855 |
+
chunks_per_block = BLOCK // chunksize
|
| 856 |
+
|
| 857 |
+
grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_N"]), batch * nheads)
|
| 858 |
+
_fwd_eva_prep_kv_kernel[grid](
|
| 859 |
+
k,
|
| 860 |
+
v,
|
| 861 |
+
param_mu,
|
| 862 |
+
param_phi,
|
| 863 |
+
mask,
|
| 864 |
+
out_rfa_k,
|
| 865 |
+
out_rfa_v,
|
| 866 |
+
softmax_scale,
|
| 867 |
+
k.stride(0), k.stride(1), k.stride(2),
|
| 868 |
+
v.stride(0), v.stride(1), v.stride(2),
|
| 869 |
+
param_mu.stride(1),
|
| 870 |
+
param_phi.stride(1),
|
| 871 |
+
mask_strides[0], mask_strides[1],
|
| 872 |
+
out_rfa_k.stride(0), out_rfa_k.stride(1), out_rfa_k.stride(2),
|
| 873 |
+
out_rfa_v.stride(0), out_rfa_v.stride(1), out_rfa_v.stride(2),
|
| 874 |
+
nheads,
|
| 875 |
+
seqlen,
|
| 876 |
+
nchunks,
|
| 877 |
+
head_dim,
|
| 878 |
+
chunks_per_block,
|
| 879 |
+
chunksize,
|
| 880 |
+
mask_type,
|
| 881 |
+
BLOCK_HEADDIM,
|
| 882 |
+
BLOCK_N=BLOCK,
|
| 883 |
+
num_warps=num_warps,
|
| 884 |
+
num_stages=1,
|
| 885 |
+
)
|
| 886 |
+
return out_rfa_k, out_rfa_v
|
| 887 |
+
|
| 888 |
+
def triton_eva_prep_kv_bwd(
|
| 889 |
+
d_rfa_k, d_rfa_v,
|
| 890 |
+
k, v, param_mu, param_phi,
|
| 891 |
+
mask,
|
| 892 |
+
rfa_k, rfa_v,
|
| 893 |
+
d_k, d_v, d_param_mu, d_param_phi,
|
| 894 |
+
softmax_scale,
|
| 895 |
+
mask_type,
|
| 896 |
+
chunksize
|
| 897 |
+
):
|
| 898 |
+
d_rfa_k, d_rfa_v = [
|
| 899 |
+
x if x.stride(-1) == 1 else x.contiguous()
|
| 900 |
+
for x in [d_rfa_k, d_rfa_v]
|
| 901 |
+
]
|
| 902 |
+
|
| 903 |
+
# shape constraints
|
| 904 |
+
batch, nheads, seqlen, head_dim = k.shape
|
| 905 |
+
assert seqlen % chunksize == 0, "seqlen must be divisible by chunksize"
|
| 906 |
+
nchunks = seqlen // chunksize
|
| 907 |
+
softmax_scale = softmax_scale or 1.0 / math.sqrt(head_dim)
|
| 908 |
+
|
| 909 |
+
mask_strides = (
|
| 910 |
+
(mask.stride(0), mask.stride(2))
|
| 911 |
+
if mask_type == 1 else
|
| 912 |
+
(0, 0)
|
| 913 |
+
)
|
| 914 |
+
|
| 915 |
+
BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16)
|
| 916 |
+
BLOCK = 128
|
| 917 |
+
num_warps = 4 if head_dim <= 64 else 8
|
| 918 |
+
|
| 919 |
+
assert (BLOCK > chunksize) & (BLOCK % chunksize) == 0, "BLOCK must be divisible by chunksize"
|
| 920 |
+
chunks_per_block = BLOCK // chunksize
|
| 921 |
+
|
| 922 |
+
partial_groups = triton.cdiv(seqlen, BLOCK)
|
| 923 |
+
d_param_mu_partial = torch.zeros((batch, nheads, partial_groups, head_dim), dtype=torch.float32, device=d_rfa_k.device)
|
| 924 |
+
d_param_phi_partial = torch.zeros((batch, nheads, partial_groups, head_dim), dtype=torch.float32, device=d_rfa_k.device)
|
| 925 |
+
grid = lambda META: (partial_groups, batch * nheads)
|
| 926 |
+
_bwd_eva_prep_kv_kernel[grid](
|
| 927 |
+
rfa_k, # [b, h, c, d]
|
| 928 |
+
rfa_v, # [b, h, c, d]
|
| 929 |
+
k, # [b, h, n, d]
|
| 930 |
+
v, # [b, h, n, d]
|
| 931 |
+
param_mu, # [1, h, 1, 1, d]
|
| 932 |
+
param_phi, # [1, h, 1, 1, d]
|
| 933 |
+
mask, # [b, h, n, 1]
|
| 934 |
+
d_rfa_k, # [b, h, c, d]
|
| 935 |
+
d_rfa_v, # [b, h, c, d]
|
| 936 |
+
d_k, # [b, h, n, d]
|
| 937 |
+
d_v, # [b, h, n, d]
|
| 938 |
+
d_param_mu_partial, # [b, h, g, d]
|
| 939 |
+
d_param_phi_partial, # [b, h, g, d]
|
| 940 |
+
softmax_scale,
|
| 941 |
+
rfa_k.stride(0), rfa_k.stride(1), rfa_k.stride(2),
|
| 942 |
+
rfa_v.stride(0), rfa_v.stride(1), rfa_v.stride(2),
|
| 943 |
+
k.stride(0), k.stride(1), k.stride(2),
|
| 944 |
+
v.stride(0), v.stride(1), v.stride(2),
|
| 945 |
+
param_mu.stride(1),
|
| 946 |
+
param_phi.stride(1),
|
| 947 |
+
mask_strides[0], mask_strides[1],
|
| 948 |
+
d_rfa_k.stride(0), d_rfa_k.stride(1), d_rfa_k.stride(2),
|
| 949 |
+
d_rfa_v.stride(0), d_rfa_v.stride(1), d_rfa_v.stride(2),
|
| 950 |
+
d_k.stride(0), d_k.stride(1), d_k.stride(2),
|
| 951 |
+
d_v.stride(0), d_v.stride(1), d_v.stride(2),
|
| 952 |
+
d_param_mu_partial.stride(0), d_param_mu_partial.stride(1), d_param_mu_partial.stride(2),
|
| 953 |
+
d_param_phi_partial.stride(0), d_param_phi_partial.stride(1), d_param_phi_partial.stride(2),
|
| 954 |
+
nheads,
|
| 955 |
+
seqlen,
|
| 956 |
+
nchunks,
|
| 957 |
+
head_dim,
|
| 958 |
+
chunks_per_block,
|
| 959 |
+
chunksize,
|
| 960 |
+
mask_type,
|
| 961 |
+
BLOCK_HEADDIM,
|
| 962 |
+
BLOCK_N=BLOCK,
|
| 963 |
+
num_warps=num_warps,
|
| 964 |
+
num_stages=1,
|
| 965 |
+
)
|
| 966 |
+
d_param_mu.copy_(d_param_mu_partial.sum(dim=(0, -2), keepdim=True).unsqueeze(-2).to(d_param_mu.dtype))
|
| 967 |
+
d_param_phi.copy_(d_param_phi_partial.sum(dim=(0, -2), keepdim=True).unsqueeze(-2).to(d_param_phi.dtype))
|
| 968 |
+
|
| 969 |
+
|
| 970 |
+
|
| 971 |
+
class EvaPrepKVFunc(torch.autograd.Function):
|
| 972 |
+
@staticmethod
|
| 973 |
+
def forward(ctx, k, v, param_mu, param_phi, mask, softmax_scale=None, chunksize=None):
|
| 974 |
+
if mask is not None:
|
| 975 |
+
mask_type = 1
|
| 976 |
+
else:
|
| 977 |
+
mask_type = 0
|
| 978 |
+
rfa_k, rfa_v = triton_eva_prep_kv_fwd(
|
| 979 |
+
k, v, param_mu, param_phi, mask, softmax_scale, chunksize
|
| 980 |
+
)
|
| 981 |
+
ctx.save_for_backward(k, v, param_mu, param_phi, mask, rfa_k, rfa_v)
|
| 982 |
+
ctx.softmax_scale = softmax_scale
|
| 983 |
+
ctx.chunksize = chunksize
|
| 984 |
+
ctx.mask_type = mask_type
|
| 985 |
+
return rfa_k, rfa_v
|
| 986 |
+
|
| 987 |
+
@staticmethod
|
| 988 |
+
def backward(ctx, d_rfa_k, d_rfa_v):
|
| 989 |
+
k, v, param_mu, param_phi, mask, rfa_k, rfa_v = ctx.saved_tensors
|
| 990 |
+
d_k = torch.empty_like(k)
|
| 991 |
+
d_v = torch.empty_like(v)
|
| 992 |
+
d_param_mu = torch.empty_like(param_mu)
|
| 993 |
+
d_param_phi = torch.empty_like(param_phi)
|
| 994 |
+
triton_eva_prep_kv_bwd(
|
| 995 |
+
d_rfa_k, d_rfa_v,
|
| 996 |
+
k, v, param_mu, param_phi,
|
| 997 |
+
mask,
|
| 998 |
+
rfa_k, rfa_v,
|
| 999 |
+
d_k, d_v, d_param_mu, d_param_phi,
|
| 1000 |
+
ctx.softmax_scale,
|
| 1001 |
+
ctx.mask_type,
|
| 1002 |
+
ctx.chunksize
|
| 1003 |
+
)
|
| 1004 |
+
return d_k, d_v, d_param_mu, d_param_phi, None, None, None
|
| 1005 |
+
|
| 1006 |
+
def eva_prep_kv_func_triton(
|
| 1007 |
+
k, v,
|
| 1008 |
+
param_mu, param_phi,
|
| 1009 |
+
mask,
|
| 1010 |
+
softmax_scale=None, chunksize=None
|
| 1011 |
+
):
|
| 1012 |
+
return EvaPrepKVFunc.apply(
|
| 1013 |
+
k, v,
|
| 1014 |
+
param_mu, param_phi,
|
| 1015 |
+
mask,
|
| 1016 |
+
softmax_scale, chunksize
|
| 1017 |
+
)
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/eva_pt_ref.py
ADDED
|
@@ -0,0 +1,420 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional, Tuple, Union
|
| 2 |
+
import torch
|
| 3 |
+
from torch import nn
|
| 4 |
+
|
| 5 |
+
MASK_MIN_VALUE = -10e10
|
| 6 |
+
|
| 7 |
+
def rotate_half(x: torch.Tensor) -> torch.Tensor:
|
| 8 |
+
"""
|
| 9 |
+
Rotates half the hidden dims (last dim) of the input.
|
| 10 |
+
Args:
|
| 11 |
+
x: Rotary embedded tensor
|
| 12 |
+
Return:
|
| 13 |
+
Tensor with half of last dim negated and rotated to the front.
|
| 14 |
+
"""
|
| 15 |
+
x1, x2 = x.split(x.shape[-1] // 2, dim=-1)
|
| 16 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 17 |
+
|
| 18 |
+
def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
|
| 19 |
+
position_ids: torch.Tensor) -> torch.Tensor:
|
| 20 |
+
"""
|
| 21 |
+
Apply rotary embedding (cos, sin) to the query and key tensor on the sequence dimension.
|
| 22 |
+
|
| 23 |
+
The legends for dimensions are defined as:
|
| 24 |
+
num_heads: number of attention heads
|
| 25 |
+
current_seq_len: the current batch's sequence length, should be either 1 or max_seq_len
|
| 26 |
+
max_seq_len: the static sequence length, different from current_seq_len in cached inference case where it is always
|
| 27 |
+
maximum lenghth, e.g. the length of static sequence length of KV cache
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
q: Query tensor, of size (batch_size, num_heads, current_seq_len, head_dim)
|
| 32 |
+
k: Key tensor, of size (batch_size, num_key_value_heads, current_seq_len, head_dim)
|
| 33 |
+
cos: Cosine base of rotary embedding, of size (max_seq_len, head_dim)
|
| 34 |
+
sin: Sine base of rotary embedding, of size (max_seq_len, head_dim)
|
| 35 |
+
position_ids: The position indices of the tokens corresponding to the query and key tensors. It has a size of
|
| 36 |
+
(batch_size, current_seq_len).
|
| 37 |
+
|
| 38 |
+
Returns:
|
| 39 |
+
Embedded query and key tensor of same size as input.
|
| 40 |
+
|
| 41 |
+
"""
|
| 42 |
+
bs, nheads, cur_seq_len, head_dim = q.shape
|
| 43 |
+
assert len(
|
| 44 |
+
k.shape) == 4, f"k should be of shape (batch_size, num_heads, current_seq_len, head_dim), got {k.shape} instead"
|
| 45 |
+
assert k.shape[0] == bs, f"k has a different batch_size {k.shape[0]} compared to q {bs}"
|
| 46 |
+
assert list(k.shape[2:]) == [cur_seq_len,
|
| 47 |
+
head_dim], f"k has different current_seq_len and/or head_dim compared to q"
|
| 48 |
+
assert cos.shape[3] == head_dim, f"cos should have dim of head dim {head_dim}, got {cos.shape[3]} instead"
|
| 49 |
+
assert list(position_ids.shape) in [[bs, cur_seq_len], [1, cur_seq_len]],\
|
| 50 |
+
f"position_ids should be of shape {[bs, cur_seq_len]} or {[1, cur_seq_len]}, got {position_ids.shape} instead"
|
| 51 |
+
|
| 52 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 53 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 54 |
+
return q_embed, k_embed
|
| 55 |
+
|
| 56 |
+
def attention_op(
|
| 57 |
+
q,
|
| 58 |
+
k,
|
| 59 |
+
v,
|
| 60 |
+
attn_mask,
|
| 61 |
+
mixedp_attn,
|
| 62 |
+
head_dim_scaling
|
| 63 |
+
):
|
| 64 |
+
attn = torch.matmul(q, k.transpose(-2, -1))
|
| 65 |
+
if mixedp_attn:
|
| 66 |
+
attn = attn.to(torch.float)
|
| 67 |
+
attn = attn * head_dim_scaling
|
| 68 |
+
if attn_mask is not None:
|
| 69 |
+
attn = attn.masked_fill(attn_mask, MASK_MIN_VALUE)
|
| 70 |
+
|
| 71 |
+
attn_weights = torch.softmax(attn, dim=-1).to(q.dtype)
|
| 72 |
+
attn_output = torch.matmul(attn_weights, v)
|
| 73 |
+
return attn_output
|
| 74 |
+
|
| 75 |
+
def prm_projection(
|
| 76 |
+
x: torch.Tensor,
|
| 77 |
+
projection_matrix: torch.Tensor,
|
| 78 |
+
mixedp_attn: bool = False
|
| 79 |
+
):
|
| 80 |
+
"""
|
| 81 |
+
Constructs nonnegative kernel features for fast softmax attention.
|
| 82 |
+
Args:
|
| 83 |
+
x: input for which features are computed
|
| 84 |
+
projection_matrix: random matrix used to compute features
|
| 85 |
+
Returns:
|
| 86 |
+
Random features for fast attention.
|
| 87 |
+
"""
|
| 88 |
+
# x : [..., m, d]
|
| 89 |
+
# proj : [..., r, d]
|
| 90 |
+
scaling_factor = (x.shape[-1] ** -0.5)
|
| 91 |
+
proj_x = torch.matmul(projection_matrix, x.transpose(-1, -2)) # [..., r, m]
|
| 92 |
+
norm = torch.sum(x ** 2, dim=-1).unsqueeze(-2) * 0.5 # [..., 1]
|
| 93 |
+
if mixedp_attn:
|
| 94 |
+
proj_x = proj_x.to(torch.float)
|
| 95 |
+
norm = norm.to(torch.float)
|
| 96 |
+
phi_x = scaling_factor * (proj_x - norm)
|
| 97 |
+
return phi_x
|
| 98 |
+
|
| 99 |
+
class EvaAttention(nn.Module):
|
| 100 |
+
def __init__(self, config, layer_idx: Optional[int] = None):
|
| 101 |
+
super().__init__()
|
| 102 |
+
self.config = config
|
| 103 |
+
self.layer_idx = layer_idx
|
| 104 |
+
self.hidden_size = config.hidden_size
|
| 105 |
+
self.num_heads = config.num_attention_heads
|
| 106 |
+
self.head_dim = self.hidden_size // self.num_heads
|
| 107 |
+
self.head_dim_scaling = self.head_dim ** -0.5
|
| 108 |
+
|
| 109 |
+
self.max_position_embeddings = config.max_position_embeddings
|
| 110 |
+
|
| 111 |
+
if (self.head_dim * self.num_heads) != self.hidden_size:
|
| 112 |
+
raise ValueError(
|
| 113 |
+
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
| 114 |
+
f" and `num_heads`: {self.num_heads})."
|
| 115 |
+
)
|
| 116 |
+
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
| 117 |
+
self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
| 118 |
+
self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
| 119 |
+
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
| 120 |
+
|
| 121 |
+
self.window_size = config.window_size
|
| 122 |
+
|
| 123 |
+
self.num_chunks = config.num_chunks
|
| 124 |
+
self.chunk_size = config.chunk_size
|
| 125 |
+
if self.chunk_size is not None:
|
| 126 |
+
assert self.window_size >= self.chunk_size and self.window_size % self.chunk_size == 0
|
| 127 |
+
# chunk_size overrides the number of landmarks
|
| 128 |
+
self.num_chunks = None
|
| 129 |
+
|
| 130 |
+
self.chunks_per_window = int(self.window_size // self.chunk_size)
|
| 131 |
+
self.random_feature_dim = 1
|
| 132 |
+
self.adaptive_phi = nn.Parameter(
|
| 133 |
+
torch.randn(
|
| 134 |
+
1,
|
| 135 |
+
self.num_heads,
|
| 136 |
+
1,
|
| 137 |
+
1,
|
| 138 |
+
self.head_dim
|
| 139 |
+
).clamp(-1., 1.) * self.head_dim_scaling
|
| 140 |
+
)
|
| 141 |
+
self.adaptive_mu_k = nn.Parameter(
|
| 142 |
+
torch.randn(
|
| 143 |
+
1,
|
| 144 |
+
self.num_heads,
|
| 145 |
+
1,
|
| 146 |
+
1,
|
| 147 |
+
self.head_dim
|
| 148 |
+
).clamp(-1., 1.) * self.head_dim_scaling
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
def _generate_feature_map(self, rf_q, rf_k, rf_v):
|
| 152 |
+
rf_k_logits = torch.sum(self.adaptive_mu_k.to(rf_k.dtype) * rf_k, dim=-1, keepdim=True) # b h c m 1
|
| 153 |
+
if self.config.mixedp_attn:
|
| 154 |
+
rf_k_logits = rf_k_logits.to(torch.float)
|
| 155 |
+
rf_k_weights = torch.softmax(rf_k_logits, dim=-2).to(rf_k.dtype)
|
| 156 |
+
rf_k_bar = torch.sum(rf_k_weights * rf_k, dim=-2)
|
| 157 |
+
weights = self.adaptive_phi.to(rf_k.dtype)
|
| 158 |
+
return weights, rf_k_bar
|
| 159 |
+
|
| 160 |
+
def _calculate_chunk_rfa_cache(self, rf_q, rf_k, rf_v, weights, rf_mask=None):
|
| 161 |
+
proj_x = torch.sum(weights * rf_k, dim=-1, keepdim=True)
|
| 162 |
+
norm = torch.sum(rf_k ** 2, dim=-1, keepdim=True) * 0.5 # [..., 1]
|
| 163 |
+
if self.config.mixedp_attn:
|
| 164 |
+
proj_x = proj_x.to(torch.float)
|
| 165 |
+
norm = norm.to(torch.float)
|
| 166 |
+
log_phi_k = self.head_dim_scaling * (proj_x - norm)
|
| 167 |
+
|
| 168 |
+
if rf_mask is not None:
|
| 169 |
+
log_phi_k = log_phi_k.masked_fill(rf_mask, MASK_MIN_VALUE)
|
| 170 |
+
|
| 171 |
+
# [b, h, c, m, r]
|
| 172 |
+
softmax_phi_k = torch.softmax(log_phi_k, dim=-2).to(rf_k.dtype)
|
| 173 |
+
softmax_phi_k_v = torch.sum(softmax_phi_k * rf_v, dim=-2)
|
| 174 |
+
# [b, h, c, r, m] [b, h, c, m, d] -> [b, h, c, r, d]
|
| 175 |
+
# softmax_phi_k_v = torch.matmul(softmax_phi_k.transpose(-1, -2), rf_v).squeeze(-2)
|
| 176 |
+
log_sum_phi_k = None
|
| 177 |
+
return softmax_phi_k_v, log_sum_phi_k
|
| 178 |
+
|
| 179 |
+
def _calculate_chunk_rfa(self, q, softmax_phi_k_v, log_sum_phi_k, weights):
|
| 180 |
+
if self.random_feature_dim == 1:
|
| 181 |
+
# when r = 1, the snis weights becomes 1, so this takes no effect
|
| 182 |
+
# [b, h, c, r, d] -> [b, h, c, d]
|
| 183 |
+
return softmax_phi_k_v
|
| 184 |
+
else:
|
| 185 |
+
# [b, h, c, r, d] [b, h, 1, s, d] -> [b, h, c, r, s]
|
| 186 |
+
log_phi_q = prm_projection(q.unsqueeze(-3), weights, self.config.mixedp_attn)
|
| 187 |
+
# [b, h, c, r, s] [b, h, c, r, 1] -> [b, h, c, r, s]
|
| 188 |
+
sniw = torch.softmax(log_phi_q + log_sum_phi_k, dim=-1).to(q.dtype)
|
| 189 |
+
# [b, h, c, r, s] [b, h, c, r, d] -> [b, h, c, s, d] -> [b, h, s, c, d]
|
| 190 |
+
rfa_per_chunk = torch.matmul(sniw.transpose(-1, -2), softmax_phi_k_v).transpose(-3, -2)
|
| 191 |
+
return rfa_per_chunk
|
| 192 |
+
|
| 193 |
+
def window_partition(self, x, window_size=None):
|
| 194 |
+
window_size = window_size if window_size is not None else self.window_size
|
| 195 |
+
|
| 196 |
+
gw, d = x.shape[-2:]
|
| 197 |
+
leading_dims = x.shape[:-2]
|
| 198 |
+
n_groups = gw // window_size
|
| 199 |
+
return x.reshape(*leading_dims, n_groups, window_size, d)
|
| 200 |
+
|
| 201 |
+
def window_merge(self, x, window_size=None):
|
| 202 |
+
g, w, d = x.shape[-3:]
|
| 203 |
+
leading_dims = x.shape[:-3]
|
| 204 |
+
return x.reshape(*leading_dims, g * w, d)
|
| 205 |
+
|
| 206 |
+
def forward(
|
| 207 |
+
self,
|
| 208 |
+
hidden_states: torch.Tensor,
|
| 209 |
+
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
|
| 210 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 211 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
| 212 |
+
output_attentions: bool = False,
|
| 213 |
+
use_cache: bool = False,
|
| 214 |
+
cos: Optional[torch.Tensor] = None,
|
| 215 |
+
sin: Optional[torch.Tensor] = None,
|
| 216 |
+
multibyte_decoding: Optional[bool] = False,
|
| 217 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 218 |
+
assert not output_attentions
|
| 219 |
+
bsz, q_len, _ = hidden_states.size()
|
| 220 |
+
|
| 221 |
+
############################################
|
| 222 |
+
# initialize past states if not provided
|
| 223 |
+
############################################
|
| 224 |
+
if use_cache and past_key_value is None:
|
| 225 |
+
raise ValueError
|
| 226 |
+
if use_cache and multibyte_decoding:
|
| 227 |
+
raise NotImplementedError("Multibyte decoding is not supported for PyTorch native implementation")
|
| 228 |
+
# assert isinstance(attention_mask, tuple)
|
| 229 |
+
if len(attention_mask) == 4:
|
| 230 |
+
assert use_cache
|
| 231 |
+
prev_causal_mask, cur_causal_mask, chunk_causal_mask, intra_chunk_mask = attention_mask
|
| 232 |
+
elif len(attention_mask) == 3:
|
| 233 |
+
assert not use_cache
|
| 234 |
+
window_causal_mask, chunk_causal_mask, intra_chunk_mask = attention_mask
|
| 235 |
+
else:
|
| 236 |
+
raise NotImplementedError("Only attention-mask tuple with length 2 or 3 is supported")
|
| 237 |
+
|
| 238 |
+
############################################
|
| 239 |
+
# compute q, k, v from hidden states
|
| 240 |
+
############################################
|
| 241 |
+
# [b, h, q_len, d]
|
| 242 |
+
q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 243 |
+
# [b, h, kv_len, d]
|
| 244 |
+
k = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 245 |
+
# [b, h, kv_len, d]
|
| 246 |
+
v = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 247 |
+
|
| 248 |
+
if use_cache:
|
| 249 |
+
past_key_value.update_past_len(q.shape[-2], self.layer_idx)
|
| 250 |
+
|
| 251 |
+
############################################
|
| 252 |
+
# apply rotary positional embeddings to q, k
|
| 253 |
+
############################################
|
| 254 |
+
q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
|
| 255 |
+
|
| 256 |
+
############################################
|
| 257 |
+
# compute q, k, v stats for the local window
|
| 258 |
+
############################################
|
| 259 |
+
if use_cache:
|
| 260 |
+
(prev_w_q, prev_w_k, prev_w_v), (cur_w_q, cur_w_k, cur_w_v) = past_key_value.update_singletons(
|
| 261 |
+
q,
|
| 262 |
+
k,
|
| 263 |
+
v,
|
| 264 |
+
self.layer_idx,
|
| 265 |
+
self.window_size,
|
| 266 |
+
)
|
| 267 |
+
else:
|
| 268 |
+
prev_w_q = self.window_partition(q) # [b, h, w, i, d]
|
| 269 |
+
prev_w_k = self.window_partition(k) # [b, h, w, j, d]
|
| 270 |
+
prev_w_v = self.window_partition(v) # [b, h, w, j, d]
|
| 271 |
+
# during training, we assume window_size divides seq_len so no remainders
|
| 272 |
+
cur_w_q = cur_w_k = cur_w_v = None
|
| 273 |
+
|
| 274 |
+
############################################
|
| 275 |
+
# compute q, k, v stats for chunk-level RFAs
|
| 276 |
+
############################################
|
| 277 |
+
if use_cache:
|
| 278 |
+
dump_q, dump_k, dump_v = past_key_value.update_chunks(q, k, v, self.layer_idx, self.chunk_size)
|
| 279 |
+
else:
|
| 280 |
+
dump_q, dump_k, dump_v = q, k, v
|
| 281 |
+
|
| 282 |
+
if use_cache:
|
| 283 |
+
prev_s_mask, cur_s_mask, prev_chunk_mask, cur_chunk_mask, dump_rf_mask = past_key_value.update_mask(
|
| 284 |
+
prev_s_mask=prev_causal_mask,
|
| 285 |
+
cur_s_mask=cur_causal_mask,
|
| 286 |
+
chunk_mask=chunk_causal_mask,
|
| 287 |
+
rf_mask=intra_chunk_mask,
|
| 288 |
+
layer_idx=self.layer_idx,
|
| 289 |
+
window_size=self.window_size,
|
| 290 |
+
chunk_size=self.chunk_size,
|
| 291 |
+
)
|
| 292 |
+
else:
|
| 293 |
+
prev_s_mask = self.window_partition(prev_causal_mask) # [1, 1, w, i, j]
|
| 294 |
+
cur_s_mask = None
|
| 295 |
+
prev_chunk_mask = self.window_partition(chunk_causal_mask)
|
| 296 |
+
cur_chunk_mask = None
|
| 297 |
+
dump_rf_mask = intra_chunk_mask
|
| 298 |
+
if prev_s_mask.shape[-3] == 1:
|
| 299 |
+
# need to expand
|
| 300 |
+
prev_s_mask = prev_s_mask.expand(-1, -1, prev_chunk_mask.shape[-3], -1, -1)
|
| 301 |
+
|
| 302 |
+
if (
|
| 303 |
+
dump_q is not None and
|
| 304 |
+
dump_k is not None and
|
| 305 |
+
dump_v is not None
|
| 306 |
+
):
|
| 307 |
+
# [b, h, c, j, d]
|
| 308 |
+
rf_q = self.window_partition(dump_q, window_size=self.chunk_size)
|
| 309 |
+
# [b, h, c, j, d]
|
| 310 |
+
rf_k = self.window_partition(dump_k, window_size=self.chunk_size)
|
| 311 |
+
# [b, h, c, j, d]
|
| 312 |
+
rf_v = self.window_partition(dump_v, window_size=self.chunk_size)
|
| 313 |
+
|
| 314 |
+
if dump_rf_mask is not None:
|
| 315 |
+
rf_mask = self.window_partition(dump_rf_mask, window_size=self.chunk_size)
|
| 316 |
+
rf_q = rf_q.masked_fill(rf_mask, 0.)
|
| 317 |
+
rf_k = rf_k.masked_fill(rf_mask, 0.)
|
| 318 |
+
rf_v = rf_v.masked_fill(rf_mask, 0.)
|
| 319 |
+
else:
|
| 320 |
+
rf_mask = None
|
| 321 |
+
else:
|
| 322 |
+
rf_q = None
|
| 323 |
+
rf_k = None
|
| 324 |
+
rf_v = None
|
| 325 |
+
rf_mask = None
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
if rf_q is not None:
|
| 329 |
+
# import pdb; pdb.set_trace()
|
| 330 |
+
weights, rf_k_bar = self._generate_feature_map(rf_q, rf_k, rf_v)
|
| 331 |
+
softmax_phi_k_v, log_sum_phi_k = self._calculate_chunk_rfa_cache(rf_q, rf_k, rf_v, weights, rf_mask=rf_mask)
|
| 332 |
+
if use_cache:
|
| 333 |
+
softmax_phi_k_v, log_sum_phi_k, rf_k_bar = past_key_value.update_chunk_rfas(
|
| 334 |
+
softmax_phi_k_v, log_sum_phi_k, rf_k_bar, self.layer_idx, 1
|
| 335 |
+
)
|
| 336 |
+
elif use_cache:
|
| 337 |
+
weights = None
|
| 338 |
+
softmax_phi_k_v, log_sum_phi_k, rf_k_bar = past_key_value.get_chunk_rfas(self.layer_idx)
|
| 339 |
+
else:
|
| 340 |
+
weights = None
|
| 341 |
+
softmax_phi_k_v = None
|
| 342 |
+
log_sum_phi_k = None
|
| 343 |
+
rf_k_bar = None
|
| 344 |
+
|
| 345 |
+
if rf_k_bar is not None:
|
| 346 |
+
rfa_per_chunk = self._calculate_chunk_rfa(q, softmax_phi_k_v, log_sum_phi_k, weights)
|
| 347 |
+
############################################
|
| 348 |
+
# compute meta-attention weights for
|
| 349 |
+
# - group-wise RFAs and
|
| 350 |
+
# - singletons (equivalent to exact local attention)
|
| 351 |
+
############################################
|
| 352 |
+
if prev_w_k is not None:
|
| 353 |
+
if rf_k_bar is not None:
|
| 354 |
+
num_windows = prev_w_k.shape[-3]
|
| 355 |
+
# rf_k_bar and rfa_per_chunk take the shape [b, h, c, d]
|
| 356 |
+
# -> [b, h, 1, c, d] -> [b, h, w, c, d]
|
| 357 |
+
prev_rf_k_bar = rf_k_bar.unsqueeze(-3).expand(-1, -1, num_windows, -1, -1)
|
| 358 |
+
prev_rfa_per_chunk = rfa_per_chunk.unsqueeze(-3).expand(-1, -1, num_windows, -1, -1)
|
| 359 |
+
prev_agg_k = torch.cat([prev_w_k, prev_rf_k_bar], dim=-2)
|
| 360 |
+
prev_agg_v = torch.cat([prev_w_v, prev_rfa_per_chunk], dim=-2)
|
| 361 |
+
|
| 362 |
+
prev_attn_mask = torch.cat([prev_s_mask, prev_chunk_mask], dim=-1)
|
| 363 |
+
else:
|
| 364 |
+
prev_agg_k = prev_w_k
|
| 365 |
+
prev_agg_v = prev_w_v
|
| 366 |
+
prev_attn_mask = prev_s_mask
|
| 367 |
+
|
| 368 |
+
prev_attn_output = attention_op(
|
| 369 |
+
q=prev_w_q,
|
| 370 |
+
k=prev_agg_k,
|
| 371 |
+
v=prev_agg_v,
|
| 372 |
+
attn_mask=prev_attn_mask,
|
| 373 |
+
mixedp_attn=self.config.mixedp_attn,
|
| 374 |
+
head_dim_scaling=self.head_dim_scaling
|
| 375 |
+
)
|
| 376 |
+
prev_attn_output = self.window_merge(prev_attn_output)
|
| 377 |
+
|
| 378 |
+
if cur_w_k is not None:
|
| 379 |
+
if rf_k_bar is not None:
|
| 380 |
+
# rf_k_bar and rfa_per_chunk take the shape [b, h, c, d]
|
| 381 |
+
# cur_w_k and cur_w_v also has shape [b, h, m, d]
|
| 382 |
+
cur_agg_k = torch.cat([cur_w_k, rf_k_bar], dim=-2)
|
| 383 |
+
cur_agg_v = torch.cat([cur_w_v, rfa_per_chunk], dim=-2)
|
| 384 |
+
|
| 385 |
+
cur_attn_mask = torch.cat([cur_s_mask, cur_chunk_mask], dim=-1)
|
| 386 |
+
else:
|
| 387 |
+
cur_agg_k = cur_w_k
|
| 388 |
+
cur_agg_v = cur_w_v
|
| 389 |
+
cur_attn_mask = cur_s_mask
|
| 390 |
+
|
| 391 |
+
cur_attn_output = attention_op(
|
| 392 |
+
q=cur_w_q,
|
| 393 |
+
k=cur_agg_k,
|
| 394 |
+
v=cur_agg_v,
|
| 395 |
+
attn_mask=cur_attn_mask,
|
| 396 |
+
mixedp_attn=self.config.mixedp_attn,
|
| 397 |
+
head_dim_scaling=self.head_dim_scaling
|
| 398 |
+
)
|
| 399 |
+
|
| 400 |
+
if prev_w_k is not None and cur_w_k is not None:
|
| 401 |
+
attn_output = torch.cat([prev_attn_output, cur_attn_output], dim=-2)
|
| 402 |
+
elif prev_w_k is not None:
|
| 403 |
+
attn_output = prev_attn_output
|
| 404 |
+
elif cur_w_k is not None:
|
| 405 |
+
attn_output = cur_attn_output
|
| 406 |
+
else:
|
| 407 |
+
raise ValueError("There must be some bug")
|
| 408 |
+
|
| 409 |
+
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
| 410 |
+
raise ValueError(
|
| 411 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
| 412 |
+
f" {attn_output.size()}"
|
| 413 |
+
)
|
| 414 |
+
|
| 415 |
+
attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
|
| 416 |
+
attn_output = self.o_proj(attn_output)
|
| 417 |
+
|
| 418 |
+
attn_weights = None
|
| 419 |
+
|
| 420 |
+
return attn_output, attn_weights, past_key_value
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/generation_config.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_from_model_config": true,
|
| 3 |
+
"bos_token_id": 1,
|
| 4 |
+
"eos_token_id": 2,
|
| 5 |
+
"pad_token_id": 0,
|
| 6 |
+
"transformers_version": "4.47.1"
|
| 7 |
+
}
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/image_processing_evabyte.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
"""Image processor class for EvaByte."""
|
| 3 |
+
|
| 4 |
+
from typing import Dict, List, Optional, Union, Tuple
|
| 5 |
+
|
| 6 |
+
import io
|
| 7 |
+
from transformers.image_processing_utils import BaseImageProcessor
|
| 8 |
+
from transformers.image_utils import (
|
| 9 |
+
ImageInput,
|
| 10 |
+
PILImageResampling,
|
| 11 |
+
valid_images,
|
| 12 |
+
validate_preprocess_arguments,
|
| 13 |
+
)
|
| 14 |
+
from PIL import Image
|
| 15 |
+
|
| 16 |
+
def _get_qtable_bytes():
|
| 17 |
+
return {
|
| 18 |
+
5: b'\xff\xd8\xff\xdb\x00C\x00\xa0nx\x8cxd\xa0\x8c\x82\x8c\xb4\xaa\xa0\xbe\xf0\xff\xff\xf0\xdc\xdc\xf0\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xdb\x00C\x01\xa0\xb4\xb4\xf0\xd2\xf0\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xd9',
|
| 19 |
+
10: b'\xff\xd8\xff\xdb\x00C\x00P7<F<2PFAFZUP_x\xc8\x82xnnx\xf5\xaf\xb9\x91\xc8\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xdb\x00C\x01PZZxix\xeb\x82\x82\xeb\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xd9',
|
| 20 |
+
15: b'\xff\xd8\xff\xdb\x00C\x005%(/(!5/+/<95?P\x85WPIIP\xa3u{a\x85\xc1\xaa\xcb\xc8\xbe\xaa\xba\xb7\xd5\xf0\xff\xff\xd5\xe2\xff\xe6\xb7\xba\xff\xff\xff\xff\xff\xff\xff\xff\xff\xce\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xdb\x00C\x015<<PFP\x9dWW\x9d\xff\xdc\xba\xdc\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xd9',
|
| 21 |
+
20: b'\xff\xd8\xff\xdb\x00C\x00(\x1c\x1e#\x1e\x19(#!#-+(0<dA<77<{X]Id\x91\x80\x99\x96\x8f\x80\x8c\x8a\xa0\xb4\xe6\xc3\xa0\xaa\xda\xad\x8a\x8c\xc8\xff\xcb\xda\xee\xf5\xff\xff\xff\x9b\xc1\xff\xff\xff\xfa\xff\xe6\xfd\xff\xf8\xff\xdb\x00C\x01(--<5<vAAv\xf8\xa5\x8c\xa5\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xff\xd9',
|
| 22 |
+
25: b'\xff\xd8\xff\xdb\x00C\x00 \x16\x18\x1c\x18\x14 \x1c\x1a\x1c$" &0P40,,0bFJ:Ptfzxrfpn\x80\x90\xb8\x9c\x80\x88\xae\x8anp\xa0\xda\xa2\xae\xbe\xc4\xce\xd0\xce|\x9a\xe2\xf2\xe0\xc8\xf0\xb8\xca\xce\xc6\xff\xdb\x00C\x01 $$0*0^44^\xc6\x84p\x84\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xff\xd9',
|
| 23 |
+
30: b'\xff\xd8\xff\xdb\x00C\x00\x1b\x12\x14\x17\x14\x11\x1b\x17\x16\x17\x1e\x1c\x1b (B+(%%(Q:=0B`Ued_U][jx\x99\x81jq\x90s[]\x85\xb5\x86\x90\x9e\xa3\xab\xad\xabg\x80\xbc\xc9\xba\xa6\xc7\x99\xa8\xab\xa4\xff\xdb\x00C\x01\x1b\x1e\x1e(#(N++N\xa4n]n\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xff\xd9',
|
| 24 |
+
50: b'\xff\xd8\xff\xdb\x00C\x00\x10\x0b\x0c\x0e\x0c\n\x10\x0e\r\x0e\x12\x11\x10\x13\x18(\x1a\x18\x16\x16\x181#%\x1d(:3=<9387@H\\N@DWE78PmQW_bghg>Mqypdx\\egc\xff\xdb\x00C\x01\x10\x12\x12\x18\x15\x18/\x1a\x1a/cB8Bcccccccccccccccccccccccccccccccccccccccccccccccccc\xff\xd9',
|
| 25 |
+
75: b'\xff\xd8\xff\xdb\x00C\x00\x08\x06\x06\x07\x06\x05\x08\x07\x07\x07\t\t\x08\n\x0c\x14\r\x0c\x0b\x0b\x0c\x19\x12\x13\x0f\x14\x1d\x1a\x1f\x1e\x1d\x1a\x1c\x1c $.\' ",#\x1c\x1c(7),01444\x1f\'9=82<.342\xff\xdb\x00C\x01\x08\t\t\x0c\x0b\x0c\x18\r\r\x182!\x1c!22222222222222222222222222222222222222222222222222\xff\xd9',
|
| 26 |
+
95: b'\xff\xd8\xff\xdb\x00C\x00\x02\x01\x01\x01\x01\x01\x02\x01\x01\x01\x02\x02\x02\x02\x02\x04\x03\x02\x02\x02\x02\x05\x04\x04\x03\x04\x06\x05\x06\x06\x06\x05\x06\x06\x06\x07\t\x08\x06\x07\t\x07\x06\x06\x08\x0b\x08\t\n\n\n\n\n\x06\x08\x0b\x0c\x0b\n\x0c\t\n\n\n\xff\xdb\x00C\x01\x02\x02\x02\x02\x02\x02\x05\x03\x03\x05\n\x07\x06\x07\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\xff\xd9',
|
| 27 |
+
100: b'\xff\xd8\xff\xdb\x00C\x00\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\xff\xdb\x00C\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\xff\xd9',
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _resize_if_exceeding_max_len(
|
| 32 |
+
width: int, height: int, min_len: Optional[int] = 16, max_len: Optional[int] = None
|
| 33 |
+
) -> Tuple[int, int]:
|
| 34 |
+
"""
|
| 35 |
+
Get the output size of the image after resizing given a dictionary specifying the max and min sizes.
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
height (`int`):
|
| 39 |
+
Height of the input image.
|
| 40 |
+
width (`int`):
|
| 41 |
+
Width of the input image.
|
| 42 |
+
max_len (`Dict[str, int]`, *optional*, defaults to the maximum size of the image):
|
| 43 |
+
Defines the maximum dimensions of the image.
|
| 44 |
+
|
| 45 |
+
Returns:
|
| 46 |
+
The output size of the image after resizing.
|
| 47 |
+
"""
|
| 48 |
+
max_len = max(height, width) if max_len is None else max_len
|
| 49 |
+
aspect_ratio = width / height
|
| 50 |
+
|
| 51 |
+
if width >= height and width > max_len:
|
| 52 |
+
width = max_len
|
| 53 |
+
height = int(width / aspect_ratio)
|
| 54 |
+
if height % 2 != 0:
|
| 55 |
+
height += 1
|
| 56 |
+
elif height > width and height > max_len:
|
| 57 |
+
height = max_len
|
| 58 |
+
width = int(height * aspect_ratio)
|
| 59 |
+
if width % 2 != 0:
|
| 60 |
+
width += 1
|
| 61 |
+
|
| 62 |
+
# Avoid resizing to a size smaller than 1
|
| 63 |
+
height = max(height, min_len)
|
| 64 |
+
width = max(width, min_len)
|
| 65 |
+
return width, height
|
| 66 |
+
|
| 67 |
+
class EvaByteImageProcessor(BaseImageProcessor):
|
| 68 |
+
|
| 69 |
+
model_input_names = []
|
| 70 |
+
|
| 71 |
+
def __init__(
|
| 72 |
+
self,
|
| 73 |
+
do_resize: bool = True,
|
| 74 |
+
resample: PILImageResampling = PILImageResampling.LANCZOS,
|
| 75 |
+
size: Dict[str, int] = None,
|
| 76 |
+
do_convert_rgb: bool = True,
|
| 77 |
+
jpeg_quality: int = 25,
|
| 78 |
+
jpeg_subsampling: str = "4:2:0",
|
| 79 |
+
jpeg_streamtype: str = 2,
|
| 80 |
+
jpeg_restart_marker_blocks: int = 1,
|
| 81 |
+
**kwargs,
|
| 82 |
+
) -> None:
|
| 83 |
+
super().__init__(**kwargs)
|
| 84 |
+
self.do_resize = do_resize
|
| 85 |
+
self.resample = resample
|
| 86 |
+
self.size = size if size is not None else {"longest_edge": 384}
|
| 87 |
+
self.do_convert_rgb = do_convert_rgb
|
| 88 |
+
self.jpeg_quality = jpeg_quality
|
| 89 |
+
self.jpeg_subsampling = jpeg_subsampling
|
| 90 |
+
self.jpeg_streamtype = jpeg_streamtype
|
| 91 |
+
self.jpeg_restart_marker_blocks = jpeg_restart_marker_blocks
|
| 92 |
+
|
| 93 |
+
def jpeg_encode(
|
| 94 |
+
self,
|
| 95 |
+
image,
|
| 96 |
+
jpeg_quality,
|
| 97 |
+
jpeg_subsampling,
|
| 98 |
+
jpeg_streamtype,
|
| 99 |
+
jpeg_restart_marker_blocks,
|
| 100 |
+
):
|
| 101 |
+
with io.BytesIO() as output:
|
| 102 |
+
image.save(
|
| 103 |
+
output,
|
| 104 |
+
format="JPEG",
|
| 105 |
+
quality=jpeg_quality,
|
| 106 |
+
subsampling=jpeg_subsampling,
|
| 107 |
+
streamtype=jpeg_streamtype,
|
| 108 |
+
restart_marker_blocks=jpeg_restart_marker_blocks
|
| 109 |
+
)
|
| 110 |
+
jpeg_bytes = output.getvalue()
|
| 111 |
+
return jpeg_bytes
|
| 112 |
+
|
| 113 |
+
def jpeg_merge_qtables(
|
| 114 |
+
self,
|
| 115 |
+
image_bytes,
|
| 116 |
+
jpeg_quality=None,
|
| 117 |
+
):
|
| 118 |
+
if jpeg_quality is None:
|
| 119 |
+
jpeg_quality = self.jpeg_quality
|
| 120 |
+
qtable_bytes = _get_qtable_bytes()[jpeg_quality]
|
| 121 |
+
return image_bytes[:2] + qtable_bytes[2:-2] + image_bytes[2:]
|
| 122 |
+
|
| 123 |
+
def resize(
|
| 124 |
+
self,
|
| 125 |
+
image: Image,
|
| 126 |
+
size: Dict[str, int],
|
| 127 |
+
resample: PILImageResampling = PILImageResampling.LANCZOS,
|
| 128 |
+
) -> Image:
|
| 129 |
+
if "longest_edge" in size:
|
| 130 |
+
width, height = image.size
|
| 131 |
+
# Find the output size, when rescaling the longest edge to max_len and preserving the aspect ratio
|
| 132 |
+
width, height = _resize_if_exceeding_max_len(width, height, max_len=size["longest_edge"])
|
| 133 |
+
size = (width, height)
|
| 134 |
+
elif "width" in size and "height" in size:
|
| 135 |
+
size = (size["width"], size["height"])
|
| 136 |
+
else:
|
| 137 |
+
raise ValueError("size must be a dictionary with key 'longest_edge' or 'height' and 'width'.")
|
| 138 |
+
resized_image = image.resize(size, resample=resample)
|
| 139 |
+
return resized_image
|
| 140 |
+
|
| 141 |
+
def preprocess(
|
| 142 |
+
self,
|
| 143 |
+
images: ImageInput,
|
| 144 |
+
do_resize: bool = None,
|
| 145 |
+
resample = None,
|
| 146 |
+
size: Dict[str, int] = None,
|
| 147 |
+
do_convert_rgb: bool = None,
|
| 148 |
+
jpeg_quality: int = None,
|
| 149 |
+
jpeg_subsampling: str = None,
|
| 150 |
+
jpeg_streamtype: str = None,
|
| 151 |
+
jpeg_restart_marker_blocks: int = None,
|
| 152 |
+
):
|
| 153 |
+
do_resize = do_resize if do_resize is not None else self.do_resize
|
| 154 |
+
size = size if size is not None else self.size
|
| 155 |
+
resample = resample if resample is not None else self.resample
|
| 156 |
+
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
| 157 |
+
|
| 158 |
+
jpeg_quality = jpeg_quality if jpeg_quality is not None else self.jpeg_quality
|
| 159 |
+
jpeg_subsampling = jpeg_subsampling if jpeg_subsampling is not None else self.jpeg_subsampling
|
| 160 |
+
jpeg_streamtype = jpeg_streamtype if jpeg_streamtype is not None else self.jpeg_streamtype
|
| 161 |
+
jpeg_restart_marker_blocks = jpeg_restart_marker_blocks if jpeg_restart_marker_blocks is not None else self.jpeg_restart_marker_blocks
|
| 162 |
+
|
| 163 |
+
if images is not None and not valid_images(images):
|
| 164 |
+
raise ValueError(
|
| 165 |
+
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
| 166 |
+
"torch.Tensor, tf.Tensor or jax.ndarray."
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
validate_preprocess_arguments(
|
| 170 |
+
do_resize=do_resize,
|
| 171 |
+
size=size,
|
| 172 |
+
resample=resample,
|
| 173 |
+
)
|
| 174 |
+
images_list = images
|
| 175 |
+
if do_convert_rgb:
|
| 176 |
+
images_list = [
|
| 177 |
+
[
|
| 178 |
+
image.convert("RGB") for image in images
|
| 179 |
+
]
|
| 180 |
+
for images in images_list
|
| 181 |
+
]
|
| 182 |
+
|
| 183 |
+
if do_resize:
|
| 184 |
+
images_list = [
|
| 185 |
+
[
|
| 186 |
+
self.resize(image=image, size=size, resample=resample)
|
| 187 |
+
for image in images
|
| 188 |
+
]
|
| 189 |
+
for images in images_list
|
| 190 |
+
]
|
| 191 |
+
|
| 192 |
+
jpeg_bytes = [
|
| 193 |
+
[
|
| 194 |
+
self.jpeg_encode(
|
| 195 |
+
image,
|
| 196 |
+
jpeg_quality,
|
| 197 |
+
jpeg_subsampling,
|
| 198 |
+
jpeg_streamtype,
|
| 199 |
+
jpeg_restart_marker_blocks
|
| 200 |
+
) for image in images
|
| 201 |
+
]
|
| 202 |
+
for images in images_list
|
| 203 |
+
]
|
| 204 |
+
return jpeg_bytes
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/model.safetensors.index.json
ADDED
|
@@ -0,0 +1,450 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"metadata": {
|
| 3 |
+
"total_size": 57058938880
|
| 4 |
+
},
|
| 5 |
+
"weight_map": {
|
| 6 |
+
"model.embed_tokens.weight": "model-00001-of-00003.safetensors",
|
| 7 |
+
"model.layers.0.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 8 |
+
"model.layers.0.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 9 |
+
"model.layers.1.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
|
| 10 |
+
"model.layers.1.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
|
| 11 |
+
"model.layers.1.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 12 |
+
"model.layers.1.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 13 |
+
"model.layers.3.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 14 |
+
"model.layers.4.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 15 |
+
"model.layers.4.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 16 |
+
"model.layers.5.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 17 |
+
"model.layers.5.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 18 |
+
"model.layers.7.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 19 |
+
"model.layers.8.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 20 |
+
"model.layers.8.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 21 |
+
"model.layers.10.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 22 |
+
"model.layers.12.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 23 |
+
"model.layers.12.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 24 |
+
"model.layers.12.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 25 |
+
"model.layers.13.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
|
| 26 |
+
"model.layers.13.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
|
| 27 |
+
"model.layers.13.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 28 |
+
"model.layers.13.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 29 |
+
"model.layers.15.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 30 |
+
"model.layers.16.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 31 |
+
"model.layers.16.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 32 |
+
"model.layers.18.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 33 |
+
"model.layers.20.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 34 |
+
"model.layers.20.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 35 |
+
"model.layers.20.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 36 |
+
"model.layers.21.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
|
| 37 |
+
"model.layers.21.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
|
| 38 |
+
"model.layers.21.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 39 |
+
"model.layers.21.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 40 |
+
"model.layers.23.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 41 |
+
"model.layers.24.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 42 |
+
"model.layers.24.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 43 |
+
"model.layers.26.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 44 |
+
"model.layers.28.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 45 |
+
"model.layers.28.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 46 |
+
"model.layers.28.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 47 |
+
"model.layers.29.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
|
| 48 |
+
"model.layers.29.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
|
| 49 |
+
"model.layers.29.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 50 |
+
"model.layers.29.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 51 |
+
"model.layers.31.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 52 |
+
"model.layers.32.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 53 |
+
"model.layers.32.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 54 |
+
"model.layers.34.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 55 |
+
"model.layers.36.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 56 |
+
"model.layers.36.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 57 |
+
"model.layers.36.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 58 |
+
"model.layers.37.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
|
| 59 |
+
"model.layers.37.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
|
| 60 |
+
"model.layers.37.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 61 |
+
"model.layers.37.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 62 |
+
"model.layers.39.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 63 |
+
"model.layers.2.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
|
| 64 |
+
"model.layers.26.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
|
| 65 |
+
"model.layers.0.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 66 |
+
"model.layers.1.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 67 |
+
"model.layers.1.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 68 |
+
"model.layers.3.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
|
| 69 |
+
"model.layers.3.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 70 |
+
"model.layers.5.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 71 |
+
"model.layers.7.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 72 |
+
"model.layers.9.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 73 |
+
"model.layers.9.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 74 |
+
"model.layers.9.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 75 |
+
"model.layers.10.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 76 |
+
"model.layers.10.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 77 |
+
"model.layers.12.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 78 |
+
"model.layers.14.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 79 |
+
"model.layers.14.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 80 |
+
"model.layers.15.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 81 |
+
"model.layers.15.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 82 |
+
"model.layers.17.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 83 |
+
"model.layers.18.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 84 |
+
"model.layers.18.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 85 |
+
"model.layers.20.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 86 |
+
"model.layers.21.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 87 |
+
"model.layers.21.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 88 |
+
"model.layers.23.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 89 |
+
"model.layers.25.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 90 |
+
"model.layers.25.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 91 |
+
"model.layers.25.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 92 |
+
"model.layers.26.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 93 |
+
"model.layers.26.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 94 |
+
"model.layers.27.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
|
| 95 |
+
"model.layers.28.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 96 |
+
"model.layers.29.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 97 |
+
"model.layers.29.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 98 |
+
"model.layers.31.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 99 |
+
"model.layers.33.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 100 |
+
"model.layers.33.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 101 |
+
"model.layers.33.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 102 |
+
"model.layers.34.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 103 |
+
"model.layers.34.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 104 |
+
"model.layers.36.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 105 |
+
"model.layers.37.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 106 |
+
"model.layers.37.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 107 |
+
"model.layers.39.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 108 |
+
"model.layers.3.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
|
| 109 |
+
"model.layers.27.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
|
| 110 |
+
"model.layers.0.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 111 |
+
"model.layers.1.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 112 |
+
"model.layers.1.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 113 |
+
"model.layers.2.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 114 |
+
"model.layers.2.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 115 |
+
"model.layers.2.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 116 |
+
"model.layers.3.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 117 |
+
"model.layers.3.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 118 |
+
"model.layers.4.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
|
| 119 |
+
"model.layers.5.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 120 |
+
"model.layers.5.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 121 |
+
"model.layers.6.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 122 |
+
"model.layers.6.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 123 |
+
"model.layers.8.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 124 |
+
"model.layers.9.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 125 |
+
"model.layers.9.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 126 |
+
"model.layers.11.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 127 |
+
"model.layers.13.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 128 |
+
"model.layers.13.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 129 |
+
"model.layers.13.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 130 |
+
"model.layers.14.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 131 |
+
"model.layers.14.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 132 |
+
"model.layers.16.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 133 |
+
"model.layers.17.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 134 |
+
"model.layers.17.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 135 |
+
"model.layers.19.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 136 |
+
"model.layers.21.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 137 |
+
"model.layers.21.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 138 |
+
"model.layers.21.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 139 |
+
"model.layers.22.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 140 |
+
"model.layers.22.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 141 |
+
"model.layers.24.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 142 |
+
"model.layers.25.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 143 |
+
"model.layers.25.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 144 |
+
"model.layers.27.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 145 |
+
"model.layers.28.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
|
| 146 |
+
"model.layers.29.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 147 |
+
"model.layers.29.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 148 |
+
"model.layers.29.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 149 |
+
"model.layers.30.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 150 |
+
"model.layers.30.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 151 |
+
"model.layers.32.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 152 |
+
"model.layers.33.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 153 |
+
"model.layers.33.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 154 |
+
"model.layers.35.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 155 |
+
"model.layers.37.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 156 |
+
"model.layers.37.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 157 |
+
"model.layers.37.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 158 |
+
"model.layers.38.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 159 |
+
"model.layers.38.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 160 |
+
"model.layers.4.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
|
| 161 |
+
"model.layers.28.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
|
| 162 |
+
"model.layers.5.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
|
| 163 |
+
"model.layers.0.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
|
| 164 |
+
"model.layers.0.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 165 |
+
"model.layers.2.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 166 |
+
"model.layers.4.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 167 |
+
"model.layers.6.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 168 |
+
"model.layers.8.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 169 |
+
"model.layers.8.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 170 |
+
"model.layers.8.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 171 |
+
"model.layers.9.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
|
| 172 |
+
"model.layers.9.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
|
| 173 |
+
"model.layers.9.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 174 |
+
"model.layers.9.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 175 |
+
"model.layers.11.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 176 |
+
"model.layers.12.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
|
| 177 |
+
"model.layers.12.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 178 |
+
"model.layers.12.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 179 |
+
"model.layers.14.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 180 |
+
"model.layers.16.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 181 |
+
"model.layers.16.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 182 |
+
"model.layers.16.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 183 |
+
"model.layers.17.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
|
| 184 |
+
"model.layers.17.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
|
| 185 |
+
"model.layers.17.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 186 |
+
"model.layers.17.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 187 |
+
"model.layers.19.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 188 |
+
"model.layers.20.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 189 |
+
"model.layers.20.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 190 |
+
"model.layers.22.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 191 |
+
"model.layers.23.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
|
| 192 |
+
"model.layers.23.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 193 |
+
"model.layers.23.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 194 |
+
"model.layers.25.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 195 |
+
"model.layers.27.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 196 |
+
"model.layers.27.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 197 |
+
"model.layers.27.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 198 |
+
"model.layers.28.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 199 |
+
"model.layers.28.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 200 |
+
"model.layers.30.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 201 |
+
"model.layers.32.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 202 |
+
"model.layers.32.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 203 |
+
"model.layers.32.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 204 |
+
"model.layers.33.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
|
| 205 |
+
"model.layers.33.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
|
| 206 |
+
"model.layers.33.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 207 |
+
"model.layers.33.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 208 |
+
"model.layers.35.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 209 |
+
"model.layers.36.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 210 |
+
"model.layers.36.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 211 |
+
"model.layers.38.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 212 |
+
"model.layers.0.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 213 |
+
"model.layers.1.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 214 |
+
"model.layers.1.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 215 |
+
"model.layers.2.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 216 |
+
"model.layers.2.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 217 |
+
"model.layers.2.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 218 |
+
"model.layers.3.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 219 |
+
"model.layers.3.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 220 |
+
"model.layers.5.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
|
| 221 |
+
"model.layers.5.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 222 |
+
"model.layers.5.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 223 |
+
"model.layers.6.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 224 |
+
"model.layers.6.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 225 |
+
"model.layers.8.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 226 |
+
"model.layers.10.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 227 |
+
"model.layers.10.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 228 |
+
"model.layers.10.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 229 |
+
"model.layers.11.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 230 |
+
"model.layers.11.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 231 |
+
"model.layers.13.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 232 |
+
"model.layers.14.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 233 |
+
"model.layers.14.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 234 |
+
"model.layers.16.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 235 |
+
"model.layers.18.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 236 |
+
"model.layers.18.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 237 |
+
"model.layers.18.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 238 |
+
"model.layers.19.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 239 |
+
"model.layers.19.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 240 |
+
"model.layers.21.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 241 |
+
"model.layers.22.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 242 |
+
"model.layers.22.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 243 |
+
"model.layers.24.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 244 |
+
"model.layers.26.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 245 |
+
"model.layers.26.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 246 |
+
"model.layers.26.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 247 |
+
"model.layers.27.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 248 |
+
"model.layers.27.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 249 |
+
"model.layers.29.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 250 |
+
"model.layers.30.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 251 |
+
"model.layers.30.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 252 |
+
"model.layers.32.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 253 |
+
"model.layers.34.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 254 |
+
"model.layers.34.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 255 |
+
"model.layers.34.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 256 |
+
"model.layers.35.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 257 |
+
"model.layers.35.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 258 |
+
"model.layers.37.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 259 |
+
"model.layers.38.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 260 |
+
"model.layers.38.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 261 |
+
"model.layers.6.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
|
| 262 |
+
"model.layers.30.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
|
| 263 |
+
"model.layers.6.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
|
| 264 |
+
"model.layers.30.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
|
| 265 |
+
"model.layers.7.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
|
| 266 |
+
"model.layers.31.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
|
| 267 |
+
"model.layers.7.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
|
| 268 |
+
"model.layers.31.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
|
| 269 |
+
"model.layers.8.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
|
| 270 |
+
"model.layers.32.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
|
| 271 |
+
"model.layers.2.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
|
| 272 |
+
"model.layers.14.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
|
| 273 |
+
"model.layers.14.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
|
| 274 |
+
"model.layers.22.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
|
| 275 |
+
"model.layers.22.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
|
| 276 |
+
"model.layers.38.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
|
| 277 |
+
"model.layers.38.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
|
| 278 |
+
"model.layers.0.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 279 |
+
"model.layers.1.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 280 |
+
"model.layers.3.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 281 |
+
"model.layers.3.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 282 |
+
"model.layers.3.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 283 |
+
"model.layers.4.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 284 |
+
"model.layers.4.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 285 |
+
"model.layers.4.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 286 |
+
"model.layers.4.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 287 |
+
"model.layers.5.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 288 |
+
"model.layers.5.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 289 |
+
"model.layers.7.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 290 |
+
"model.layers.7.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 291 |
+
"model.layers.7.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 292 |
+
"model.layers.8.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 293 |
+
"model.layers.8.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 294 |
+
"model.layers.10.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 295 |
+
"model.layers.11.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 296 |
+
"model.layers.11.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 297 |
+
"model.layers.13.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 298 |
+
"model.layers.15.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 299 |
+
"model.layers.15.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 300 |
+
"model.layers.15.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 301 |
+
"model.layers.16.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 302 |
+
"model.layers.16.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 303 |
+
"model.layers.18.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 304 |
+
"model.layers.19.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 305 |
+
"model.layers.19.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 306 |
+
"model.layers.21.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 307 |
+
"model.layers.23.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 308 |
+
"model.layers.23.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 309 |
+
"model.layers.23.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 310 |
+
"model.layers.24.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 311 |
+
"model.layers.24.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 312 |
+
"model.layers.26.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 313 |
+
"model.layers.27.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 314 |
+
"model.layers.27.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 315 |
+
"model.layers.29.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 316 |
+
"model.layers.31.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 317 |
+
"model.layers.31.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 318 |
+
"model.layers.31.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 319 |
+
"model.layers.32.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 320 |
+
"model.layers.32.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 321 |
+
"model.layers.34.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 322 |
+
"model.layers.35.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 323 |
+
"model.layers.35.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 324 |
+
"model.layers.37.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 325 |
+
"model.layers.39.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 326 |
+
"model.layers.39.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 327 |
+
"model.layers.39.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 328 |
+
"model.norm.weight": "model-00003-of-00003.safetensors",
|
| 329 |
+
"lm_head.weight": "model-00003-of-00003.safetensors",
|
| 330 |
+
"model.layers.0.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 331 |
+
"model.layers.2.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 332 |
+
"model.layers.4.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 333 |
+
"model.layers.6.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 334 |
+
"model.layers.6.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 335 |
+
"model.layers.7.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 336 |
+
"model.layers.7.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 337 |
+
"model.layers.8.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
|
| 338 |
+
"model.layers.9.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 339 |
+
"model.layers.10.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 340 |
+
"model.layers.10.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 341 |
+
"model.layers.12.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 342 |
+
"model.layers.13.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 343 |
+
"model.layers.13.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 344 |
+
"model.layers.15.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 345 |
+
"model.layers.17.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 346 |
+
"model.layers.17.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 347 |
+
"model.layers.17.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 348 |
+
"model.layers.18.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 349 |
+
"model.layers.18.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 350 |
+
"model.layers.20.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 351 |
+
"model.layers.22.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 352 |
+
"model.layers.22.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 353 |
+
"model.layers.22.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 354 |
+
"model.layers.23.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 355 |
+
"model.layers.23.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 356 |
+
"model.layers.25.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 357 |
+
"model.layers.26.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 358 |
+
"model.layers.26.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 359 |
+
"model.layers.28.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 360 |
+
"model.layers.30.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 361 |
+
"model.layers.30.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 362 |
+
"model.layers.30.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 363 |
+
"model.layers.31.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 364 |
+
"model.layers.31.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 365 |
+
"model.layers.32.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
|
| 366 |
+
"model.layers.33.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 367 |
+
"model.layers.34.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 368 |
+
"model.layers.34.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 369 |
+
"model.layers.36.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 370 |
+
"model.layers.38.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 371 |
+
"model.layers.38.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 372 |
+
"model.layers.38.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 373 |
+
"model.layers.39.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 374 |
+
"model.layers.39.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 375 |
+
"model.layers.10.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
|
| 376 |
+
"model.layers.34.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
|
| 377 |
+
"model.layers.10.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
|
| 378 |
+
"model.layers.34.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
|
| 379 |
+
"model.layers.0.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 380 |
+
"model.layers.2.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 381 |
+
"model.layers.4.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 382 |
+
"model.layers.6.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 383 |
+
"model.layers.6.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 384 |
+
"model.layers.7.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 385 |
+
"model.layers.7.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 386 |
+
"model.layers.9.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 387 |
+
"model.layers.11.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
|
| 388 |
+
"model.layers.11.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 389 |
+
"model.layers.11.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 390 |
+
"model.layers.11.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 391 |
+
"model.layers.28.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 392 |
+
"model.layers.28.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 393 |
+
"model.layers.30.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 394 |
+
"model.layers.31.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 395 |
+
"model.layers.31.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 396 |
+
"model.layers.33.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 397 |
+
"model.layers.35.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
|
| 398 |
+
"model.layers.35.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 399 |
+
"model.layers.35.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 400 |
+
"model.layers.35.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 401 |
+
"model.layers.36.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 402 |
+
"model.layers.36.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 403 |
+
"model.layers.38.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 404 |
+
"model.layers.39.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 405 |
+
"model.layers.39.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 406 |
+
"model.layers.16.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
|
| 407 |
+
"model.layers.16.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
|
| 408 |
+
"model.layers.24.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
|
| 409 |
+
"model.layers.24.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
|
| 410 |
+
"model.layers.11.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
|
| 411 |
+
"model.layers.12.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 412 |
+
"model.layers.12.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 413 |
+
"model.layers.14.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 414 |
+
"model.layers.14.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 415 |
+
"model.layers.15.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 416 |
+
"model.layers.15.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 417 |
+
"model.layers.17.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 418 |
+
"model.layers.19.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 419 |
+
"model.layers.19.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 420 |
+
"model.layers.19.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 421 |
+
"model.layers.20.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 422 |
+
"model.layers.20.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 423 |
+
"model.layers.22.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 424 |
+
"model.layers.24.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 425 |
+
"model.layers.24.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 426 |
+
"model.layers.24.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 427 |
+
"model.layers.25.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 428 |
+
"model.layers.25.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 429 |
+
"model.layers.27.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 430 |
+
"model.layers.35.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
|
| 431 |
+
"model.layers.12.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
|
| 432 |
+
"model.layers.36.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
|
| 433 |
+
"model.layers.36.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
|
| 434 |
+
"model.layers.0.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
|
| 435 |
+
"model.layers.15.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
|
| 436 |
+
"model.layers.20.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
|
| 437 |
+
"model.layers.20.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
|
| 438 |
+
"model.layers.25.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
|
| 439 |
+
"model.layers.25.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
|
| 440 |
+
"model.layers.15.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
|
| 441 |
+
"model.layers.39.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
|
| 442 |
+
"model.layers.39.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
|
| 443 |
+
"model.layers.18.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
|
| 444 |
+
"model.layers.18.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
|
| 445 |
+
"model.layers.23.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
|
| 446 |
+
"model.layers.19.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
|
| 447 |
+
"model.layers.19.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
|
| 448 |
+
"model.layers.26.self_attn.adaptive_phi": "model-00003-of-00003.safetensors"
|
| 449 |
+
}
|
| 450 |
+
}
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/modeling_evabyte.py
ADDED
|
@@ -0,0 +1,912 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional, Tuple, Union
|
| 2 |
+
import math
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
import torch.utils.checkpoint
|
| 6 |
+
from torch import nn
|
| 7 |
+
from torch.nn import CrossEntropyLoss
|
| 8 |
+
from transformers.activations import ACT2FN
|
| 9 |
+
from transformers.cache_utils import Cache
|
| 10 |
+
from transformers.modeling_outputs import (
|
| 11 |
+
BaseModelOutputWithPast,
|
| 12 |
+
CausalLMOutputWithPast,
|
| 13 |
+
)
|
| 14 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 15 |
+
|
| 16 |
+
from .configuration_evabyte import EvaByteConfig
|
| 17 |
+
from .multibyte_decoding_evabyte import MultiByteDecodingMixin
|
| 18 |
+
try:
|
| 19 |
+
import triton
|
| 20 |
+
USE_TRITON_IMPL = True
|
| 21 |
+
from .eva import EvaAttention
|
| 22 |
+
from .eva_agg_kernel import triton_eva_agg_fwd
|
| 23 |
+
from .eva_prep_kv_kernel import triton_eva_prep_kv_fwd
|
| 24 |
+
except ImportError:
|
| 25 |
+
USE_TRITON_IMPL = False
|
| 26 |
+
print("WARNING: triton is not installed, using fallback EVA which might be slow and throw errors")
|
| 27 |
+
from .eva_pt_ref import EvaAttention
|
| 28 |
+
from .eva_cache import EvaCache, EvaStaticCacheForTriton
|
| 29 |
+
|
| 30 |
+
MASK_MIN_VALUE = -10e10
|
| 31 |
+
|
| 32 |
+
def prepare_eva_attention_mask(
|
| 33 |
+
seq_len,
|
| 34 |
+
device,
|
| 35 |
+
chunk_size,
|
| 36 |
+
window_size,
|
| 37 |
+
use_cache=False,
|
| 38 |
+
cache=None
|
| 39 |
+
):
|
| 40 |
+
"""
|
| 41 |
+
Prepare attention masks for EVA.
|
| 42 |
+
|
| 43 |
+
"""
|
| 44 |
+
chunk_causal_mask = None
|
| 45 |
+
window_causal_mask = None
|
| 46 |
+
if use_cache:
|
| 47 |
+
cached_seq_len = cache.get_seq_length()
|
| 48 |
+
total_seq_len = seq_len + cached_seq_len
|
| 49 |
+
# cached_seq_len will be 0 during prefilling
|
| 50 |
+
# padded_seq_len = chunk_size * math.ceil(total_seq_len / chunk_size)
|
| 51 |
+
padded_seq_len = window_size * math.ceil(total_seq_len / window_size)
|
| 52 |
+
num_chunks = padded_seq_len // chunk_size
|
| 53 |
+
else:
|
| 54 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
| 55 |
+
assert seq_len % chunk_size == 0
|
| 56 |
+
num_chunks = seq_len // chunk_size
|
| 57 |
+
|
| 58 |
+
assert seq_len % window_size == 0
|
| 59 |
+
|
| 60 |
+
# create causal mask
|
| 61 |
+
################################
|
| 62 |
+
# generate chunked causal masks
|
| 63 |
+
################################
|
| 64 |
+
# [b, h, j, c, c]
|
| 65 |
+
chunks_per_window = window_size // chunk_size
|
| 66 |
+
if num_chunks >= chunks_per_window:
|
| 67 |
+
chunk_causal_mask = torch.ones(
|
| 68 |
+
(chunk_size, num_chunks, num_chunks),
|
| 69 |
+
device=device,
|
| 70 |
+
dtype=torch.bool
|
| 71 |
+
).triu(0)
|
| 72 |
+
|
| 73 |
+
num_blocks = num_chunks // chunks_per_window
|
| 74 |
+
chunk_causal_mask = chunk_causal_mask.reshape(
|
| 75 |
+
chunk_size,
|
| 76 |
+
num_blocks,
|
| 77 |
+
chunks_per_window,
|
| 78 |
+
num_blocks,
|
| 79 |
+
chunks_per_window
|
| 80 |
+
).transpose(-2, -3)
|
| 81 |
+
|
| 82 |
+
block_diag_zero = (
|
| 83 |
+
torch.eye(num_blocks, device=device, dtype=torch.bool)
|
| 84 |
+
.unsqueeze(-1)
|
| 85 |
+
.unsqueeze(-1)
|
| 86 |
+
.unsqueeze(0)
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
# Set diagonal blocks to zero
|
| 90 |
+
chunk_causal_mask = chunk_causal_mask.masked_fill(block_diag_zero, True)
|
| 91 |
+
|
| 92 |
+
# Reshape back to original size
|
| 93 |
+
chunk_causal_mask = (
|
| 94 |
+
chunk_causal_mask
|
| 95 |
+
.transpose(-2, -3)
|
| 96 |
+
.reshape(chunk_size, num_chunks, num_chunks)
|
| 97 |
+
.transpose(-2, -3)
|
| 98 |
+
.reshape(chunk_size * num_chunks, num_chunks)
|
| 99 |
+
.unsqueeze(0)
|
| 100 |
+
.unsqueeze(0)
|
| 101 |
+
)
|
| 102 |
+
else:
|
| 103 |
+
chunk_causal_mask = torch.ones(
|
| 104 |
+
(1, 1, chunk_size, num_chunks, num_chunks),
|
| 105 |
+
device=device,
|
| 106 |
+
dtype=torch.bool,
|
| 107 |
+
).triu(0).transpose(-2, -3) # [1, 1, c, j, c]
|
| 108 |
+
chunk_causal_mask = chunk_causal_mask.reshape(
|
| 109 |
+
1, 1, chunk_size * num_chunks, num_chunks
|
| 110 |
+
) # [1, 1, n, c]
|
| 111 |
+
|
| 112 |
+
if use_cache:
|
| 113 |
+
chunk_causal_mask = chunk_causal_mask[..., cached_seq_len : cached_seq_len + seq_len, :]
|
| 114 |
+
|
| 115 |
+
window_causal_mask = torch.ones(
|
| 116 |
+
(1, 1, 1, window_size, window_size),
|
| 117 |
+
device=device
|
| 118 |
+
).triu(1).to(torch.bool)
|
| 119 |
+
return (chunk_causal_mask, window_causal_mask)
|
| 120 |
+
|
| 121 |
+
def pad_to_multiple(tensor, multiple, dim=-2, value=0, create_mask=False, left_padding=False):
|
| 122 |
+
assert dim < 0 # only accept ``dim'' index in a reverse manner
|
| 123 |
+
seqlen = int(tensor.shape[dim])
|
| 124 |
+
m = seqlen / multiple
|
| 125 |
+
if m.is_integer():
|
| 126 |
+
if create_mask:
|
| 127 |
+
return tensor, torch.ones(size=(tensor.shape[0], tensor.shape[dim]), dtype=torch.bool, device=tensor.device)
|
| 128 |
+
else:
|
| 129 |
+
return tensor
|
| 130 |
+
remainder = math.ceil(m) * multiple - seqlen
|
| 131 |
+
pad_offset = (0,) * (-1 - dim) * 2
|
| 132 |
+
if left_padding:
|
| 133 |
+
padded_res = F.pad(tensor, (*pad_offset, remainder, 0), value=value)
|
| 134 |
+
else:
|
| 135 |
+
padded_res = F.pad(tensor, (*pad_offset, 0, remainder), value=value)
|
| 136 |
+
if create_mask:
|
| 137 |
+
# assume dim 0 is the batch size
|
| 138 |
+
padding_mask = torch.ones(size=(padded_res.shape[0], padded_res.shape[dim]), dtype=torch.bool, device=padded_res.device)
|
| 139 |
+
if left_padding:
|
| 140 |
+
padding_mask[:, :remainder] = False
|
| 141 |
+
else:
|
| 142 |
+
padding_mask[:, -remainder:] = False
|
| 143 |
+
return padded_res, padding_mask
|
| 144 |
+
else:
|
| 145 |
+
return padded_res
|
| 146 |
+
|
| 147 |
+
class EvaByteRMSNorm(nn.Module):
|
| 148 |
+
def __init__(self, config):
|
| 149 |
+
super().__init__()
|
| 150 |
+
self.config = config
|
| 151 |
+
self.fp32_ln = True
|
| 152 |
+
self.variance_epsilon = config.rms_norm_eps
|
| 153 |
+
self.add_unit_offset = config.norm_add_unit_offset
|
| 154 |
+
if self.add_unit_offset:
|
| 155 |
+
self.weight = nn.Parameter(torch.zeros(config.hidden_size))
|
| 156 |
+
else:
|
| 157 |
+
self.weight = nn.Parameter(torch.ones(config.hidden_size))
|
| 158 |
+
|
| 159 |
+
def forward(self, hidden_states):
|
| 160 |
+
_hidden_states = hidden_states.to(torch.float32 if self.fp32_ln else torch.bfloat16)
|
| 161 |
+
|
| 162 |
+
variance = _hidden_states.pow(2).mean(-1, keepdim=True)
|
| 163 |
+
_hidden_states = _hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
| 164 |
+
if self.add_unit_offset:
|
| 165 |
+
return ((1 + self.weight) * _hidden_states).type_as(hidden_states)
|
| 166 |
+
else:
|
| 167 |
+
return (self.weight * _hidden_states).type_as(hidden_states)
|
| 168 |
+
|
| 169 |
+
class EvaByteRotaryEmbedding(torch.nn.Module):
|
| 170 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
| 171 |
+
super().__init__()
|
| 172 |
+
|
| 173 |
+
self.dim = dim
|
| 174 |
+
self.max_position_embeddings = max_position_embeddings
|
| 175 |
+
self.base = base
|
| 176 |
+
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
|
| 177 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 178 |
+
|
| 179 |
+
self._set_cos_sin_cache(seq_len=max_position_embeddings,
|
| 180 |
+
device=self.inv_freq.device,
|
| 181 |
+
dtype=torch.get_default_dtype())
|
| 182 |
+
|
| 183 |
+
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
| 184 |
+
self.max_seq_len_cached = seq_len
|
| 185 |
+
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
| 186 |
+
|
| 187 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
| 188 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 189 |
+
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
| 190 |
+
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def forward(self, x, seq_len=None):
|
| 194 |
+
# x: [bs, num_attention_heads, seq_len, head_size]
|
| 195 |
+
if seq_len > self.max_seq_len_cached:
|
| 196 |
+
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
| 197 |
+
|
| 198 |
+
# return (
|
| 199 |
+
# self.cos_cached[:seq_len].to(dtype=x.dtype),
|
| 200 |
+
# self.sin_cached[:seq_len].to(dtype=x.dtype),
|
| 201 |
+
# )
|
| 202 |
+
if seq_len < self.max_seq_len_cached:
|
| 203 |
+
cos_slice = self.cos_cached.split(seq_len, dim=0)[0]
|
| 204 |
+
sin_slice = self.sin_cached.split(seq_len, dim=0)[0]
|
| 205 |
+
else:
|
| 206 |
+
cos_slice = self.cos_cached
|
| 207 |
+
sin_slice = self.sin_cached
|
| 208 |
+
|
| 209 |
+
return (
|
| 210 |
+
cos_slice.to(dtype=x.dtype),
|
| 211 |
+
sin_slice.to(dtype=x.dtype),
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
class EvaByteLinearScalingRotaryEmbedding(EvaByteRotaryEmbedding):
|
| 217 |
+
"""EvaByteRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
| 218 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
| 219 |
+
self.scaling_factor = scaling_factor
|
| 220 |
+
super().__init__(dim, max_position_embeddings, base, device)
|
| 221 |
+
|
| 222 |
+
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
| 223 |
+
self.max_seq_len_cached = seq_len
|
| 224 |
+
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
| 225 |
+
t = t / self.scaling_factor
|
| 226 |
+
|
| 227 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
| 228 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
| 229 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 230 |
+
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
| 231 |
+
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
class EvaByteDynamicNTKScalingRotaryEmbedding(EvaByteRotaryEmbedding):
|
| 235 |
+
"""EvaByteRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
| 236 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
| 237 |
+
self.scaling_factor = scaling_factor
|
| 238 |
+
super().__init__(dim, max_position_embeddings, base, device)
|
| 239 |
+
|
| 240 |
+
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
| 241 |
+
self.max_seq_len_cached = seq_len
|
| 242 |
+
|
| 243 |
+
if seq_len > self.max_position_embeddings:
|
| 244 |
+
base = self.base * ((self.scaling_factor * seq_len / self.max_position_embeddings) -
|
| 245 |
+
(self.scaling_factor - 1))**(self.dim / (self.dim - 2))
|
| 246 |
+
inv_freq = 1.0 / (base**(torch.arange(0, self.dim, 2).float().to(device) / self.dim))
|
| 247 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 248 |
+
|
| 249 |
+
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
| 250 |
+
|
| 251 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
| 252 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
| 253 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 254 |
+
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
| 255 |
+
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
class EvaByteMLP(nn.Module):
|
| 259 |
+
def __init__(self, config, layer_idx: int = None):
|
| 260 |
+
super().__init__()
|
| 261 |
+
self.hidden_size = config.hidden_size
|
| 262 |
+
self.intermediate_size = config.intermediate_size
|
| 263 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 264 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 265 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
| 266 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
| 267 |
+
self.layer_idx = layer_idx
|
| 268 |
+
self.config = config
|
| 269 |
+
|
| 270 |
+
def forward(self, x):
|
| 271 |
+
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
| 272 |
+
return down_proj
|
| 273 |
+
|
| 274 |
+
class EvaByteDecoderLayer(nn.Module):
|
| 275 |
+
def __init__(self, config: EvaByteConfig, layer_idx: int = None):
|
| 276 |
+
super().__init__()
|
| 277 |
+
self.config = config
|
| 278 |
+
self.hidden_size = config.hidden_size
|
| 279 |
+
self.self_attn = EvaAttention(config=config, layer_idx=layer_idx)
|
| 280 |
+
self.mlp = EvaByteMLP(config, layer_idx=layer_idx)
|
| 281 |
+
self.input_layernorm = EvaByteRMSNorm(config)
|
| 282 |
+
self.post_attention_layernorm = EvaByteRMSNorm(config)
|
| 283 |
+
|
| 284 |
+
def forward(
|
| 285 |
+
self,
|
| 286 |
+
hidden_states: torch.Tensor,
|
| 287 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 288 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 289 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
| 290 |
+
output_attentions: Optional[bool] = False,
|
| 291 |
+
use_cache: Optional[bool] = False,
|
| 292 |
+
cos: Optional[torch.Tensor] = None,
|
| 293 |
+
sin: Optional[torch.Tensor] = None,
|
| 294 |
+
multibyte_decoding: Optional[bool] = False,
|
| 295 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
| 296 |
+
residual = hidden_states
|
| 297 |
+
if self.config.fp32_skip_add:
|
| 298 |
+
residual = residual.float()
|
| 299 |
+
|
| 300 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 301 |
+
|
| 302 |
+
# Self Attention
|
| 303 |
+
hidden_states, self_attn_weights, present_key_value = self.self_attn(hidden_states=hidden_states,
|
| 304 |
+
attention_mask=attention_mask,
|
| 305 |
+
position_ids=position_ids,
|
| 306 |
+
past_key_value=past_key_value,
|
| 307 |
+
output_attentions=output_attentions,
|
| 308 |
+
use_cache=use_cache,
|
| 309 |
+
cos=cos,
|
| 310 |
+
sin=sin,
|
| 311 |
+
multibyte_decoding=multibyte_decoding)
|
| 312 |
+
hidden_states = (residual + hidden_states).to(hidden_states.dtype)
|
| 313 |
+
|
| 314 |
+
# Fully Connected
|
| 315 |
+
residual = hidden_states
|
| 316 |
+
if self.config.fp32_skip_add:
|
| 317 |
+
residual = residual.float()
|
| 318 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 319 |
+
hidden_states = self.mlp(hidden_states)
|
| 320 |
+
hidden_states = (residual + hidden_states).to(hidden_states.dtype)
|
| 321 |
+
|
| 322 |
+
outputs = (hidden_states, )
|
| 323 |
+
|
| 324 |
+
if output_attentions:
|
| 325 |
+
outputs += (self_attn_weights, )
|
| 326 |
+
|
| 327 |
+
if use_cache:
|
| 328 |
+
outputs += (present_key_value, )
|
| 329 |
+
return outputs
|
| 330 |
+
|
| 331 |
+
class EvaBytePreTrainedModel(PreTrainedModel):
|
| 332 |
+
config_class = EvaByteConfig
|
| 333 |
+
base_model_prefix = "model"
|
| 334 |
+
supports_gradient_checkpointing = True
|
| 335 |
+
_no_split_modules = ["EvaByteDecoderLayer"]
|
| 336 |
+
_skip_keys_device_placement = "past_key_values"
|
| 337 |
+
|
| 338 |
+
def _init_weights(self, module):
|
| 339 |
+
std = getattr(self.config, "initializer_range", 0.02)
|
| 340 |
+
if isinstance(module, nn.Linear):
|
| 341 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 342 |
+
if module.bias is not None:
|
| 343 |
+
module.bias.data.zero_()
|
| 344 |
+
elif isinstance(module, nn.Embedding):
|
| 345 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 346 |
+
if module.padding_idx is not None:
|
| 347 |
+
module.weight.data[module.padding_idx].zero_()
|
| 348 |
+
|
| 349 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
| 350 |
+
if isinstance(module, EvaByteModel):
|
| 351 |
+
module.gradient_checkpointing = value
|
| 352 |
+
|
| 353 |
+
class EvaByteModel(EvaBytePreTrainedModel):
|
| 354 |
+
"""
|
| 355 |
+
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`EvaByteDecoderLayer`]
|
| 356 |
+
|
| 357 |
+
Args:
|
| 358 |
+
config: EvaByteConfig
|
| 359 |
+
"""
|
| 360 |
+
def __init__(self, config: EvaByteConfig):
|
| 361 |
+
super().__init__(config)
|
| 362 |
+
self.padding_idx = config.pad_token_id
|
| 363 |
+
self.vocab_size = config.vocab_size
|
| 364 |
+
self.hidden_size = config.hidden_size
|
| 365 |
+
self.num_heads = config.num_attention_heads
|
| 366 |
+
self.head_dim = self.hidden_size // self.num_heads
|
| 367 |
+
self.max_position_embeddings = self.config.max_position_embeddings
|
| 368 |
+
|
| 369 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 370 |
+
self.layers = nn.ModuleList([EvaByteDecoderLayer(config, layer_idx=layer_idx) for layer_idx in range(config.num_hidden_layers)])
|
| 371 |
+
self.norm = EvaByteRMSNorm(config)
|
| 372 |
+
|
| 373 |
+
self.gradient_checkpointing = False
|
| 374 |
+
self.rope = config.rope_theta
|
| 375 |
+
# Initialize weights and apply final processing
|
| 376 |
+
self.post_init()
|
| 377 |
+
self._init_rope()
|
| 378 |
+
|
| 379 |
+
def _init_rope(self):
|
| 380 |
+
if self.config.rope_scaling is None:
|
| 381 |
+
self.rotary_emb = EvaByteRotaryEmbedding(self.head_dim,
|
| 382 |
+
max_position_embeddings=self.max_position_embeddings,
|
| 383 |
+
base=self.rope)
|
| 384 |
+
else:
|
| 385 |
+
scaling_type = self.config.rope_scaling["type"]
|
| 386 |
+
scaling_factor = self.config.rope_scaling["factor"]
|
| 387 |
+
if scaling_type == "linear":
|
| 388 |
+
self.rotary_emb = EvaByteLinearScalingRotaryEmbedding(
|
| 389 |
+
self.head_dim,
|
| 390 |
+
max_position_embeddings=self.max_position_embeddings,
|
| 391 |
+
scaling_factor=scaling_factor,
|
| 392 |
+
base=self.rope)
|
| 393 |
+
elif scaling_type == "dynamic":
|
| 394 |
+
self.rotary_emb = EvaByteDynamicNTKScalingRotaryEmbedding(
|
| 395 |
+
self.head_dim,
|
| 396 |
+
max_position_embeddings=self.max_position_embeddings,
|
| 397 |
+
scaling_factor=scaling_factor,
|
| 398 |
+
base=self.rope)
|
| 399 |
+
else:
|
| 400 |
+
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
| 401 |
+
|
| 402 |
+
def get_input_embeddings(self):
|
| 403 |
+
return self.embed_tokens
|
| 404 |
+
|
| 405 |
+
def set_input_embeddings(self, value):
|
| 406 |
+
self.embed_tokens = value
|
| 407 |
+
|
| 408 |
+
def _helper_padding_mask(
|
| 409 |
+
self,
|
| 410 |
+
padding_mask,
|
| 411 |
+
causal_mask
|
| 412 |
+
):
|
| 413 |
+
padding_mask = torch.logical_or(padding_mask, padding_mask.transpose(-1, -2))
|
| 414 |
+
return torch.logical_or(padding_mask, causal_mask)
|
| 415 |
+
|
| 416 |
+
def _prepare_eva_generation_attn_mask_triton(
|
| 417 |
+
self,
|
| 418 |
+
attention_mask,
|
| 419 |
+
input_ids,
|
| 420 |
+
use_cache,
|
| 421 |
+
past_key_values
|
| 422 |
+
):
|
| 423 |
+
batch_size, seq_len = input_ids.shape
|
| 424 |
+
if use_cache and past_key_values.get_seq_length() > 0:
|
| 425 |
+
# decoding phase
|
| 426 |
+
if past_key_values.rf_mask[0] is not None:
|
| 427 |
+
cur_rf_mask = torch.zeros(
|
| 428 |
+
(batch_size, 1, seq_len, 1),
|
| 429 |
+
dtype=past_key_values.rf_mask[0].dtype,
|
| 430 |
+
device=past_key_values.rf_mask[0].device
|
| 431 |
+
)
|
| 432 |
+
else:
|
| 433 |
+
cur_rf_mask = None
|
| 434 |
+
|
| 435 |
+
if past_key_values.s_mask[0] is not None:
|
| 436 |
+
cur_s_mask = torch.zeros(
|
| 437 |
+
(batch_size, 1, seq_len, 1),
|
| 438 |
+
dtype=past_key_values.s_mask[0].dtype,
|
| 439 |
+
device=past_key_values.s_mask[0].device
|
| 440 |
+
)
|
| 441 |
+
else:
|
| 442 |
+
cur_s_mask = None
|
| 443 |
+
|
| 444 |
+
seen_tokens = past_key_values.get_seq_length()
|
| 445 |
+
if seen_tokens <= self.config.window_size:
|
| 446 |
+
rfa_chunks_dummy_mask = None
|
| 447 |
+
else:
|
| 448 |
+
if cur_s_mask is not None:
|
| 449 |
+
chunks_per_window = int(self.config.window_size // self.config.chunk_size)
|
| 450 |
+
# the ongoing decoding step would be (seen_seq_len + 1)-th token
|
| 451 |
+
num_windows_seen_so_far = seen_tokens // self.config.window_size
|
| 452 |
+
rfa_chunks_dummy_mask = torch.zeros(
|
| 453 |
+
(batch_size, 1, seq_len, num_windows_seen_so_far * chunks_per_window),
|
| 454 |
+
dtype=past_key_values.s_mask[0].dtype,
|
| 455 |
+
device=past_key_values.s_mask[0].device
|
| 456 |
+
)
|
| 457 |
+
else:
|
| 458 |
+
rfa_chunks_dummy_mask = None
|
| 459 |
+
# rf_mask and cur_mask are 0s because we do not want to mask them
|
| 460 |
+
return (cur_s_mask, cur_rf_mask, rfa_chunks_dummy_mask)
|
| 461 |
+
|
| 462 |
+
if attention_mask is not None and torch.any(attention_mask == 0.0):
|
| 463 |
+
# convert 0 -> padding to 1 -> padding
|
| 464 |
+
padded_attention_mask = pad_to_multiple(
|
| 465 |
+
attention_mask,
|
| 466 |
+
self.config.window_size,
|
| 467 |
+
dim=-1,
|
| 468 |
+
value=0,
|
| 469 |
+
create_mask=False,
|
| 470 |
+
left_padding=False
|
| 471 |
+
)
|
| 472 |
+
# convert 0 -> padding to 1 -> padding
|
| 473 |
+
padded_rf_mask = ~padded_attention_mask.unsqueeze(1).unsqueeze(-1).to(torch.bool) # [b, 1, n, 1]
|
| 474 |
+
# [b, 1, w, j, 1]
|
| 475 |
+
padded_w_attn_mask = padded_rf_mask.reshape(batch_size, 1, -1, self.config.window_size, 1).to(torch.bool)
|
| 476 |
+
# [b, 1, w, j, 1] [b, 1, w, 1, j] -> [b, 1, w, j, j]
|
| 477 |
+
w_padding_mask = torch.logical_or(padded_w_attn_mask, padded_w_attn_mask.transpose(-1, -2))
|
| 478 |
+
w_causal_mask = torch.ones(
|
| 479 |
+
(1, 1, 1, self.config.window_size, self.config.window_size),
|
| 480 |
+
device=input_ids.device
|
| 481 |
+
).triu(1).to(torch.bool)
|
| 482 |
+
s_mask = torch.logical_or(w_padding_mask, w_causal_mask)
|
| 483 |
+
s_mask = s_mask.reshape(batch_size, 1, -1, self.config.window_size)
|
| 484 |
+
s_mask = s_mask[..., :seq_len, :]
|
| 485 |
+
# negate the attention mask to get the padding mask
|
| 486 |
+
rf_mask = ~attention_mask.unsqueeze(1).unsqueeze(-1).to(torch.bool) # [b, 1, n, 1]
|
| 487 |
+
return (s_mask, rf_mask)
|
| 488 |
+
else:
|
| 489 |
+
return (None, None)
|
| 490 |
+
|
| 491 |
+
def _prepare_eva_generation_attn_mask(
|
| 492 |
+
self,
|
| 493 |
+
attention_mask,
|
| 494 |
+
input_ids,
|
| 495 |
+
use_cache,
|
| 496 |
+
past_key_values
|
| 497 |
+
):
|
| 498 |
+
batch_size, seq_len = input_ids.shape
|
| 499 |
+
if use_cache and past_key_values.get_seq_length() > 0:
|
| 500 |
+
# decoding phase
|
| 501 |
+
if past_key_values.rf_mask[0] is not None:
|
| 502 |
+
rf_mask = torch.zeros(
|
| 503 |
+
(batch_size, 1, seq_len, 1),
|
| 504 |
+
dtype=past_key_values.rf_mask[0].dtype,
|
| 505 |
+
device=past_key_values.rf_mask[0].device
|
| 506 |
+
)
|
| 507 |
+
else:
|
| 508 |
+
rf_mask = None
|
| 509 |
+
|
| 510 |
+
cur_causal_mask = torch.zeros(
|
| 511 |
+
(batch_size, 1, seq_len, 1),
|
| 512 |
+
dtype=torch.bool,
|
| 513 |
+
device=input_ids.device
|
| 514 |
+
)
|
| 515 |
+
|
| 516 |
+
chunk_causal_mask = torch.ones(
|
| 517 |
+
(batch_size, 1, seq_len, 1),
|
| 518 |
+
dtype=torch.bool,
|
| 519 |
+
device=input_ids.device
|
| 520 |
+
)
|
| 521 |
+
# chunk_causal_mask are 1s because we will mask them by default and
|
| 522 |
+
# will be unmasked when the current singleton attention is processed over
|
| 523 |
+
return (None, cur_causal_mask, chunk_causal_mask, rf_mask)
|
| 524 |
+
|
| 525 |
+
true_num_chunks = seq_len // self.config.chunk_size
|
| 526 |
+
chunk_causal_mask, _ = prepare_eva_attention_mask(
|
| 527 |
+
seq_len,
|
| 528 |
+
input_ids.device,
|
| 529 |
+
self.config.chunk_size,
|
| 530 |
+
self.config.window_size,
|
| 531 |
+
use_cache=use_cache,
|
| 532 |
+
cache=past_key_values
|
| 533 |
+
)
|
| 534 |
+
chunk_causal_mask = chunk_causal_mask[..., :seq_len, :true_num_chunks]
|
| 535 |
+
if attention_mask is not None and torch.any(attention_mask == 0.0):
|
| 536 |
+
# convert 0 -> padding to 1 -> padding
|
| 537 |
+
rf_mask = ~attention_mask.unsqueeze(1).unsqueeze(-1).to(torch.bool) # [b, 1, n, 1]
|
| 538 |
+
else:
|
| 539 |
+
rf_mask = None
|
| 540 |
+
|
| 541 |
+
if seq_len < self.config.window_size:
|
| 542 |
+
cur_window_mask = torch.ones(
|
| 543 |
+
(1, 1, seq_len, seq_len),
|
| 544 |
+
device=input_ids.device
|
| 545 |
+
).triu(1).to(torch.bool)
|
| 546 |
+
if rf_mask is not None:
|
| 547 |
+
cur_window_mask = self._helper_padding_mask(rf_mask, cur_window_mask)
|
| 548 |
+
prev_window_mask = None
|
| 549 |
+
else:
|
| 550 |
+
if seq_len % self.config.window_size == 0:
|
| 551 |
+
num_windows = seq_len // self.config.window_size
|
| 552 |
+
cur_window_mask = None
|
| 553 |
+
prev_window_mask = torch.ones(
|
| 554 |
+
(1, 1, num_windows, self.config.window_size, self.config.window_size),
|
| 555 |
+
device=input_ids.device
|
| 556 |
+
).triu(1).to(torch.bool)
|
| 557 |
+
if rf_mask is not None:
|
| 558 |
+
prev_rf_mask = rf_mask.reshape(batch_size, 1, -1, self.config.window_size, 1)
|
| 559 |
+
prev_window_mask = self._helper_padding_mask(prev_rf_mask, prev_window_mask)
|
| 560 |
+
else:
|
| 561 |
+
num_windows = seq_len // self.config.window_size
|
| 562 |
+
remainder_tokens = seq_len % self.config.window_size
|
| 563 |
+
cur_window_mask = torch.ones(
|
| 564 |
+
(1, 1, remainder_tokens, remainder_tokens),
|
| 565 |
+
device=input_ids.device
|
| 566 |
+
).triu(1).to(torch.bool)
|
| 567 |
+
prev_window_mask = torch.ones(
|
| 568 |
+
(1, 1, num_windows, self.config.window_size, self.config.window_size),
|
| 569 |
+
device=input_ids.device
|
| 570 |
+
).triu(1).to(torch.bool)
|
| 571 |
+
if rf_mask is not None:
|
| 572 |
+
prev_rf_mask, cur_rf_mask = torch.split(rf_mask, [seq_len - remainder_tokens, remainder_tokens], dim=-2)
|
| 573 |
+
cur_window_mask = self._helper_padding_mask(cur_rf_mask, cur_window_mask)
|
| 574 |
+
prev_rf_mask = prev_rf_mask.reshape(batch_size, 1, -1, self.config.window_size, 1)
|
| 575 |
+
prev_window_mask = self._helper_padding_mask(prev_rf_mask, prev_window_mask)
|
| 576 |
+
|
| 577 |
+
return (prev_window_mask, cur_window_mask, chunk_causal_mask, rf_mask)
|
| 578 |
+
|
| 579 |
+
def forward(
|
| 580 |
+
self,
|
| 581 |
+
input_ids: torch.LongTensor = None,
|
| 582 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 583 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 584 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 585 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 586 |
+
use_cache: Optional[bool] = None,
|
| 587 |
+
output_attentions: Optional[bool] = None,
|
| 588 |
+
output_hidden_states: Optional[bool] = None,
|
| 589 |
+
return_dict: Optional[bool] = None,
|
| 590 |
+
multibyte_decoding: Optional[bool] = None,
|
| 591 |
+
) -> Tuple:
|
| 592 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 593 |
+
output_hidden_states = (output_hidden_states
|
| 594 |
+
if output_hidden_states is not None else self.config.output_hidden_states)
|
| 595 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 596 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 597 |
+
|
| 598 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 599 |
+
raise ValueError(
|
| 600 |
+
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
|
| 601 |
+
)
|
| 602 |
+
|
| 603 |
+
if self.gradient_checkpointing and self.training and use_cache:
|
| 604 |
+
raise ValueError("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
|
| 605 |
+
|
| 606 |
+
batch_size, seq_len = input_ids.shape
|
| 607 |
+
#### Step 0. Hack
|
| 608 |
+
if (not self.training) and (not use_cache) and (not multibyte_decoding):
|
| 609 |
+
# forward-only inference mode.
|
| 610 |
+
# We tweak use_cache to be True to reuse code for generation
|
| 611 |
+
use_cache = True
|
| 612 |
+
device = input_ids.device if input_ids is not None else None
|
| 613 |
+
if position_ids is None:
|
| 614 |
+
position_ids = torch.arange(0, seq_len, device=device, dtype=int).reshape(1, -1).expand(batch_size, -1)
|
| 615 |
+
|
| 616 |
+
#### Step 1. Prepare caches if in inference mode
|
| 617 |
+
if use_cache:
|
| 618 |
+
if past_key_values is not None:
|
| 619 |
+
assert isinstance(past_key_values, Cache)
|
| 620 |
+
else:
|
| 621 |
+
if not USE_TRITON_IMPL:
|
| 622 |
+
past_key_values = EvaCache()
|
| 623 |
+
else:
|
| 624 |
+
past_key_values = EvaStaticCacheForTriton(
|
| 625 |
+
input_ids.shape[0],
|
| 626 |
+
self.config.num_attention_heads,
|
| 627 |
+
self.config.window_size,
|
| 628 |
+
self.config.hidden_size // self.config.num_attention_heads,
|
| 629 |
+
self.config.num_hidden_layers,
|
| 630 |
+
self.embed_tokens.weight.dtype,
|
| 631 |
+
self.embed_tokens.weight.device,
|
| 632 |
+
)
|
| 633 |
+
|
| 634 |
+
if not multibyte_decoding:
|
| 635 |
+
if use_cache:
|
| 636 |
+
if USE_TRITON_IMPL:
|
| 637 |
+
causal_mask = self._prepare_eva_generation_attn_mask_triton(
|
| 638 |
+
attention_mask,
|
| 639 |
+
input_ids,
|
| 640 |
+
use_cache,
|
| 641 |
+
past_key_values
|
| 642 |
+
)
|
| 643 |
+
else:
|
| 644 |
+
causal_mask = self._prepare_eva_generation_attn_mask(
|
| 645 |
+
attention_mask,
|
| 646 |
+
input_ids,
|
| 647 |
+
use_cache,
|
| 648 |
+
past_key_values
|
| 649 |
+
)
|
| 650 |
+
else:
|
| 651 |
+
assert self.training
|
| 652 |
+
assert seq_len % self.config.window_size == 0, "Training is only tested for sequences that are a multiple of window_size"
|
| 653 |
+
# for training, we need to pass in the attention mask
|
| 654 |
+
# usually calculated by _prepare_training_attn_mask()
|
| 655 |
+
causal_mask = attention_mask
|
| 656 |
+
else:
|
| 657 |
+
assert use_cache
|
| 658 |
+
causal_mask = attention_mask
|
| 659 |
+
|
| 660 |
+
if inputs_embeds is None:
|
| 661 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
| 662 |
+
|
| 663 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 664 |
+
max_seq_length = past_seen_tokens + inputs_embeds.shape[1]
|
| 665 |
+
|
| 666 |
+
hidden_states = inputs_embeds
|
| 667 |
+
|
| 668 |
+
if position_ids is None:
|
| 669 |
+
assert not use_cache, "during decoding we must explicitly pass position_ids to the model call"
|
| 670 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
| 671 |
+
position_ids = torch.arange(past_seen_tokens, max_seq_length, device=device, dtype=int).reshape(1, -1).expand(batch_size, -1)
|
| 672 |
+
|
| 673 |
+
cos, sin = self.rotary_emb(hidden_states, seq_len=max_seq_length)
|
| 674 |
+
assert len(cos.shape) == 2, f"cos should be of shape (max_seq_len, head_dim), got {cos.shape} instead"
|
| 675 |
+
assert sin.shape == cos.shape, f"sin should be of shape (max_seq_len, head_dim), got {sin.shape} instead"
|
| 676 |
+
assert len(position_ids.shape) == 2, f"position_ids should be of 2D, got {position_ids.shape} instead"
|
| 677 |
+
cos = cos[position_ids, :]
|
| 678 |
+
sin = sin[position_ids, :]
|
| 679 |
+
cos = cos.unsqueeze(1)
|
| 680 |
+
sin = sin.unsqueeze(1)
|
| 681 |
+
|
| 682 |
+
# decoder layers
|
| 683 |
+
all_hidden_states = () if output_hidden_states else None
|
| 684 |
+
all_self_attns = () if output_attentions else None
|
| 685 |
+
next_decoder_cache = None
|
| 686 |
+
|
| 687 |
+
for decoder_layer in self.layers:
|
| 688 |
+
if output_hidden_states:
|
| 689 |
+
all_hidden_states += (hidden_states, )
|
| 690 |
+
|
| 691 |
+
if self.gradient_checkpointing and self.training:
|
| 692 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
| 693 |
+
decoder_layer.__call__,
|
| 694 |
+
hidden_states,
|
| 695 |
+
causal_mask,
|
| 696 |
+
position_ids,
|
| 697 |
+
past_key_values,
|
| 698 |
+
output_attentions,
|
| 699 |
+
use_cache,
|
| 700 |
+
cos,
|
| 701 |
+
sin,
|
| 702 |
+
multibyte_decoding,
|
| 703 |
+
)
|
| 704 |
+
else:
|
| 705 |
+
layer_outputs = decoder_layer(
|
| 706 |
+
hidden_states,
|
| 707 |
+
attention_mask=causal_mask,
|
| 708 |
+
position_ids=position_ids,
|
| 709 |
+
past_key_value=past_key_values,
|
| 710 |
+
output_attentions=output_attentions,
|
| 711 |
+
use_cache=use_cache,
|
| 712 |
+
cos=cos,
|
| 713 |
+
sin=sin,
|
| 714 |
+
multibyte_decoding=multibyte_decoding,
|
| 715 |
+
)
|
| 716 |
+
|
| 717 |
+
hidden_states = layer_outputs[0]
|
| 718 |
+
|
| 719 |
+
if use_cache:
|
| 720 |
+
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
| 721 |
+
|
| 722 |
+
if output_attentions:
|
| 723 |
+
all_self_attns += (layer_outputs[1], )
|
| 724 |
+
|
| 725 |
+
hidden_states = self.norm(hidden_states)
|
| 726 |
+
|
| 727 |
+
# add hidden states from the last decoder layer
|
| 728 |
+
if output_hidden_states:
|
| 729 |
+
all_hidden_states += (hidden_states, )
|
| 730 |
+
|
| 731 |
+
next_cache = next_decoder_cache if use_cache else None
|
| 732 |
+
if not return_dict:
|
| 733 |
+
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
| 734 |
+
|
| 735 |
+
return BaseModelOutputWithPast(
|
| 736 |
+
last_hidden_state=hidden_states,
|
| 737 |
+
past_key_values=next_cache,
|
| 738 |
+
hidden_states=all_hidden_states,
|
| 739 |
+
attentions=all_self_attns,
|
| 740 |
+
)
|
| 741 |
+
|
| 742 |
+
|
| 743 |
+
class EvaByteForCausalLM(EvaBytePreTrainedModel, MultiByteDecodingMixin):
|
| 744 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 745 |
+
|
| 746 |
+
def __init__(self, config):
|
| 747 |
+
EvaBytePreTrainedModel.__init__(self, config)
|
| 748 |
+
|
| 749 |
+
self.model = EvaByteModel(config)
|
| 750 |
+
self.vocab_size = config.vocab_size
|
| 751 |
+
# define multibyte prediction heads
|
| 752 |
+
if hasattr(config, "num_pred_heads") and config.num_pred_heads > 1:
|
| 753 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size * config.num_pred_heads, bias=False)
|
| 754 |
+
else:
|
| 755 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 756 |
+
|
| 757 |
+
self.post_init()
|
| 758 |
+
|
| 759 |
+
def get_input_embeddings(self):
|
| 760 |
+
return self.model.embed_tokens
|
| 761 |
+
|
| 762 |
+
def set_input_embeddings(self, value):
|
| 763 |
+
self.model.embed_tokens = value
|
| 764 |
+
|
| 765 |
+
def get_output_embeddings(self):
|
| 766 |
+
return self.lm_head
|
| 767 |
+
|
| 768 |
+
def set_output_embeddings(self, new_embeddings):
|
| 769 |
+
self.lm_head = new_embeddings
|
| 770 |
+
|
| 771 |
+
def set_decoder(self, decoder):
|
| 772 |
+
self.model = decoder
|
| 773 |
+
|
| 774 |
+
def get_decoder(self):
|
| 775 |
+
return self.model
|
| 776 |
+
|
| 777 |
+
def forward(
|
| 778 |
+
self,
|
| 779 |
+
input_ids: torch.LongTensor = None,
|
| 780 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 781 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 782 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 783 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 784 |
+
labels: Optional[torch.LongTensor] = None,
|
| 785 |
+
use_cache: Optional[bool] = None,
|
| 786 |
+
output_attentions: Optional[bool] = None,
|
| 787 |
+
output_hidden_states: Optional[bool] = None,
|
| 788 |
+
return_dict: Optional[bool] = None,
|
| 789 |
+
return_all_pred_logits: Optional[bool] = None,
|
| 790 |
+
multibyte_decoding: Optional[bool] = None) -> Union[Tuple, CausalLMOutputWithPast]:
|
| 791 |
+
|
| 792 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 793 |
+
output_hidden_states = (output_hidden_states
|
| 794 |
+
if output_hidden_states is not None else self.config.output_hidden_states)
|
| 795 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 796 |
+
|
| 797 |
+
if input_ids is None:
|
| 798 |
+
assert past_key_values is None
|
| 799 |
+
|
| 800 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
| 801 |
+
outputs = self.model(
|
| 802 |
+
input_ids=input_ids,
|
| 803 |
+
attention_mask=attention_mask,
|
| 804 |
+
position_ids=position_ids,
|
| 805 |
+
past_key_values=past_key_values,
|
| 806 |
+
inputs_embeds=inputs_embeds,
|
| 807 |
+
use_cache=use_cache,
|
| 808 |
+
output_attentions=output_attentions,
|
| 809 |
+
output_hidden_states=output_hidden_states,
|
| 810 |
+
return_dict=return_dict,
|
| 811 |
+
multibyte_decoding=multibyte_decoding,
|
| 812 |
+
)
|
| 813 |
+
|
| 814 |
+
hidden_states = outputs[0]
|
| 815 |
+
|
| 816 |
+
logits = self.lm_head(hidden_states)
|
| 817 |
+
if self.config.fp32_logits:
|
| 818 |
+
logits = logits.float()
|
| 819 |
+
|
| 820 |
+
loss = None
|
| 821 |
+
if labels is not None:
|
| 822 |
+
loss_fct = CrossEntropyLoss(reduction="none")
|
| 823 |
+
if hasattr(self.config, "num_pred_heads") and self.config.num_pred_heads > 1:
|
| 824 |
+
shift_logits = logits.view(logits.shape[0], logits.shape[1], self.config.num_pred_heads, self.config.vocab_size)
|
| 825 |
+
# shift_logits = shift_logits.view(-1, logits.shape[1] * self.config.num_pred_heads, self.config.vocab_size)
|
| 826 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
| 827 |
+
else:
|
| 828 |
+
shift_logits = logits.view(-1, self.config.vocab_size)
|
| 829 |
+
shift_labels = labels.view(-1)
|
| 830 |
+
# Enable model parallelism
|
| 831 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
| 832 |
+
loss = loss_fct(shift_logits, shift_labels)
|
| 833 |
+
|
| 834 |
+
if hasattr(self.config, "num_pred_heads") and self.config.num_pred_heads > 1:
|
| 835 |
+
all_pred_logits = logits.reshape(logits.shape[0], logits.shape[1], self.config.num_pred_heads, self.config.vocab_size)
|
| 836 |
+
|
| 837 |
+
if return_all_pred_logits:
|
| 838 |
+
logits = all_pred_logits
|
| 839 |
+
else:
|
| 840 |
+
logits = all_pred_logits[..., 0, :]
|
| 841 |
+
|
| 842 |
+
if not return_dict:
|
| 843 |
+
output = (logits, ) + outputs[1:]
|
| 844 |
+
return (loss, ) + output if loss is not None else output
|
| 845 |
+
|
| 846 |
+
return CausalLMOutputWithPast(
|
| 847 |
+
loss=loss,
|
| 848 |
+
logits=logits,
|
| 849 |
+
past_key_values=outputs.past_key_values,
|
| 850 |
+
hidden_states=outputs.hidden_states,
|
| 851 |
+
attentions=outputs.attentions,
|
| 852 |
+
)
|
| 853 |
+
|
| 854 |
+
|
| 855 |
+
def prepare_inputs_for_generation(self,
|
| 856 |
+
input_ids,
|
| 857 |
+
past_key_values=None,
|
| 858 |
+
attention_mask=None,
|
| 859 |
+
inputs_embeds=None,
|
| 860 |
+
use_cache=True,
|
| 861 |
+
**kwargs):
|
| 862 |
+
# prefill phase:
|
| 863 |
+
# input_ids: b x s
|
| 864 |
+
# attention_mask: None if no padding or b x s
|
| 865 |
+
# position_ids : b x s
|
| 866 |
+
|
| 867 |
+
# token gen phase:
|
| 868 |
+
# input_ids : b x 1
|
| 869 |
+
# attention_mask: b x 1 x s
|
| 870 |
+
# position_ids: b x 1
|
| 871 |
+
past_length = 0
|
| 872 |
+
if past_key_values is not None:
|
| 873 |
+
assert isinstance(past_key_values, Cache)
|
| 874 |
+
past_length = past_key_values.get_seq_length()
|
| 875 |
+
|
| 876 |
+
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
| 877 |
+
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
|
| 878 |
+
elif past_length < input_ids.shape[1]:
|
| 879 |
+
input_ids = input_ids[:, past_length:]
|
| 880 |
+
|
| 881 |
+
position_ids = kwargs.get("position_ids", None)
|
| 882 |
+
if attention_mask is not None and position_ids is None:
|
| 883 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
| 884 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
| 885 |
+
if past_key_values:
|
| 886 |
+
position_ids = position_ids[:, -input_ids.shape[1]:]
|
| 887 |
+
|
| 888 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
| 889 |
+
if inputs_embeds is not None and past_key_values is None:
|
| 890 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
| 891 |
+
else:
|
| 892 |
+
model_inputs = {"input_ids": input_ids}
|
| 893 |
+
|
| 894 |
+
# must initialize position_ids at each step during GPU inference
|
| 895 |
+
assert position_ids is not None
|
| 896 |
+
model_inputs.update(
|
| 897 |
+
{
|
| 898 |
+
"position_ids": position_ids,
|
| 899 |
+
"past_key_values": past_key_values,
|
| 900 |
+
"use_cache": use_cache,
|
| 901 |
+
"attention_mask": attention_mask,
|
| 902 |
+
}
|
| 903 |
+
)
|
| 904 |
+
return model_inputs
|
| 905 |
+
|
| 906 |
+
@staticmethod
|
| 907 |
+
def _reorder_cache(past_key_values, beam_idx):
|
| 908 |
+
reordered_past = ()
|
| 909 |
+
for layer_past in past_key_values:
|
| 910 |
+
reordered_past += (tuple(
|
| 911 |
+
past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), )
|
| 912 |
+
return reordered_past
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/multibyte_decoding_evabyte.py
ADDED
|
@@ -0,0 +1,881 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# The implementation of multibyte deocidng is largely adapted from
|
| 3 |
+
# Medusa decoding: https://github.com/FasterDecoding/Medusa
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from transformers.generation.stopping_criteria import (
|
| 7 |
+
MaxLengthCriteria,
|
| 8 |
+
StoppingCriteriaList,
|
| 9 |
+
)
|
| 10 |
+
from typing import Union, List
|
| 11 |
+
from .eva_cache import EvaStaticCacheForTriton
|
| 12 |
+
from .eva_prep_kv_kernel import triton_eva_prep_kv_fwd
|
| 13 |
+
|
| 14 |
+
class MultibyteEosTokenCriteria:
|
| 15 |
+
"""
|
| 16 |
+
This class implements a simple stopping criteria to stop generation whenever
|
| 17 |
+
the "end-of-sequence" token is generated in the last `new_tokens` tokens.
|
| 18 |
+
|
| 19 |
+
Adapted from
|
| 20 |
+
https://github.com/huggingface/transformers/blob/main/src/transformers/generation/stopping_criteria.py#L446
|
| 21 |
+
By default, it uses the `model.generation_config.eos_token_id`.
|
| 22 |
+
|
| 23 |
+
Args:
|
| 24 |
+
eos_token_id (`Union[int, List[int]]`):
|
| 25 |
+
The id(s) of the *end-of-sequence* token.
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
def __init__(self, eos_token_ids: Union[int, List[int]]):
|
| 29 |
+
if isinstance(eos_token_ids, int):
|
| 30 |
+
eos_token_ids = [eos_token_ids]
|
| 31 |
+
self.eos_token_ids = eos_token_ids
|
| 32 |
+
|
| 33 |
+
def __call__(self, input_ids: torch.LongTensor, new_tokens: int) -> bool:
|
| 34 |
+
current_input_len = input_ids.shape[-1]
|
| 35 |
+
new_token_ids = input_ids[:, current_input_len - new_tokens:]
|
| 36 |
+
for eos_token_id in self.eos_token_ids:
|
| 37 |
+
if torch.any(new_token_ids == eos_token_id):
|
| 38 |
+
return True
|
| 39 |
+
return False
|
| 40 |
+
|
| 41 |
+
def build_tree(spec):
|
| 42 |
+
nodes_at_depth = []
|
| 43 |
+
nodes_at_depth.append([()]) # Root at depth 1
|
| 44 |
+
|
| 45 |
+
for d in range(1, len(spec) + 1):
|
| 46 |
+
prev_nodes = nodes_at_depth[d - 1]
|
| 47 |
+
spec_list = spec[d - 1]
|
| 48 |
+
current_nodes = []
|
| 49 |
+
for node_idx, node in enumerate(prev_nodes):
|
| 50 |
+
if node_idx < len(spec_list):
|
| 51 |
+
num_children = spec_list[node_idx]
|
| 52 |
+
else:
|
| 53 |
+
num_children = 0
|
| 54 |
+
for child_idx in range(num_children):
|
| 55 |
+
new_node = node + (child_idx,)
|
| 56 |
+
current_nodes.append(new_node)
|
| 57 |
+
nodes_at_depth.append(current_nodes)
|
| 58 |
+
|
| 59 |
+
# Flatten the list of nodes, excluding the root node if desired
|
| 60 |
+
all_nodes = [node for depth_nodes in nodes_at_depth for node in depth_nodes if node]
|
| 61 |
+
return all_nodes
|
| 62 |
+
|
| 63 |
+
evabyte_7b_95 = build_tree(
|
| 64 |
+
[
|
| 65 |
+
[10],
|
| 66 |
+
[10, 8, 2, 2, 1, 1],
|
| 67 |
+
[10, 4, 2, 1, 0, 0, 0, 0, 0, 0, 2, 1, 1, 0, 0, 0, 0, 0, 1],
|
| 68 |
+
[8, 2, 2, 1, 0, 0, 0, 0, 0, 0, 1],
|
| 69 |
+
[6, 2, 1, 1],
|
| 70 |
+
[4, 2, 1, 1],
|
| 71 |
+
[4, 2, 1],
|
| 72 |
+
]
|
| 73 |
+
)
|
| 74 |
+
evabyte_7b_31 = build_tree(
|
| 75 |
+
[
|
| 76 |
+
[4],
|
| 77 |
+
[3, 2, 1, 1],
|
| 78 |
+
[3, 2, 1, 1],
|
| 79 |
+
[2, 1, 1],
|
| 80 |
+
[2, 1],
|
| 81 |
+
[2, 1],
|
| 82 |
+
[2, 1],
|
| 83 |
+
]
|
| 84 |
+
)
|
| 85 |
+
TOPK = 10 # topk for sparse tree (10 is a placeholder and it is sufficient)
|
| 86 |
+
|
| 87 |
+
def pad_path(path, length, pad_value=-2):
|
| 88 |
+
"""
|
| 89 |
+
Pad the given path list with a specific value up to a specified length.
|
| 90 |
+
|
| 91 |
+
Parameters:
|
| 92 |
+
- path (list): The original list that needs padding.
|
| 93 |
+
- length (int): The desired length of the padded list.
|
| 94 |
+
- pad_value (optional, default=-2): The value to use for padding.
|
| 95 |
+
|
| 96 |
+
Returns:
|
| 97 |
+
- list: A new list based on the original path but padded to the desired length.
|
| 98 |
+
|
| 99 |
+
Example:
|
| 100 |
+
>>> pad_path([1,2,3], 5)
|
| 101 |
+
[1, 2, 3, -2, -2]
|
| 102 |
+
|
| 103 |
+
Note:
|
| 104 |
+
If the given path is already longer than the specified length,
|
| 105 |
+
then no padding occurs, and the original path is returned.
|
| 106 |
+
"""
|
| 107 |
+
return path + [pad_value] * (length - len(path))
|
| 108 |
+
|
| 109 |
+
def reset_past_key_values(passed_key_values):
|
| 110 |
+
"""
|
| 111 |
+
Resets the current lengths in the passed key-values to zero.
|
| 112 |
+
|
| 113 |
+
This function is designed to be used during the evaluation of a baseline model.
|
| 114 |
+
It iterates through each layer's key-values and sets their current lengths to zero,
|
| 115 |
+
effectively resetting their state.
|
| 116 |
+
|
| 117 |
+
Args:
|
| 118 |
+
- passed_key_values (list of torch.Tensor): Contains past hidden states and past attention values for each layer.
|
| 119 |
+
|
| 120 |
+
Returns:
|
| 121 |
+
- passed_key_values (list of torch.Tensor): Updated past hidden states and past attention values with reset lengths.
|
| 122 |
+
"""
|
| 123 |
+
for i in range(len(passed_key_values)):
|
| 124 |
+
for j in range(2):
|
| 125 |
+
passed_key_values[i][j].current_length.fill_(0)
|
| 126 |
+
return passed_key_values
|
| 127 |
+
|
| 128 |
+
def get_nucleus_one_token(logit, temperature, top_p):
|
| 129 |
+
"""
|
| 130 |
+
Performs token sampling based on the nucleus (top-p) sampling method.
|
| 131 |
+
|
| 132 |
+
This function selects a token from a given logit distribution using the nucleus sampling strategy.
|
| 133 |
+
It allows for more controlled and diverse generation compared to traditional top-k sampling.
|
| 134 |
+
|
| 135 |
+
Args:
|
| 136 |
+
logit (torch.Tensor): The logits from a language model output, expected to be a 2D tensor (BxC).
|
| 137 |
+
temperature (float): A temperature parameter to control the randomness in sampling.
|
| 138 |
+
Higher values increase diversity, lower values make selections more deterministic.
|
| 139 |
+
top_p (float): The cumulative probability threshold for nucleus sampling.
|
| 140 |
+
It controls the size of the set of high-probability tokens to consider for sampling.
|
| 141 |
+
|
| 142 |
+
Returns:
|
| 143 |
+
torch.Tensor: A tensor containing the indices of the sampled tokens.
|
| 144 |
+
"""
|
| 145 |
+
if top_p >= 1:
|
| 146 |
+
return torch.multinomial(F.softmax(logit / temperature, dim=-1), 1)
|
| 147 |
+
logit = logit / temperature
|
| 148 |
+
probs = torch.softmax(logit, dim=-1)
|
| 149 |
+
sorted_logits, sorted_indices = torch.sort(probs, descending=True)
|
| 150 |
+
cum_probs = torch.cumsum(sorted_logits, dim=-1)
|
| 151 |
+
sorted_indices_to_remove = cum_probs > top_p
|
| 152 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
| 153 |
+
sorted_indices_to_remove[..., 0] = 0
|
| 154 |
+
indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
|
| 155 |
+
logit[indices_to_remove] = float('-inf')
|
| 156 |
+
sampled_tokens = torch.multinomial(F.softmax(logit, dim=-1), 1)
|
| 157 |
+
return sampled_tokens
|
| 158 |
+
|
| 159 |
+
def get_typical_one_token(logit, temperature, posterior_threshold, posterior_alpha):
|
| 160 |
+
"""
|
| 161 |
+
Implements token sampling based on the typical sampling method.
|
| 162 |
+
|
| 163 |
+
This function selects a token from a given logit distribution using the typical sampling strategy,
|
| 164 |
+
aiming to balance between diversity and likelihood in a more nuanced way compared to traditional methods.
|
| 165 |
+
|
| 166 |
+
Args:
|
| 167 |
+
logit (torch.Tensor): The logits from a language model output, expected to be a 2D tensor.
|
| 168 |
+
temperature (float): A parameter to control the randomness in sampling.
|
| 169 |
+
Higher values increase diversity, lower values make selections more deterministic.
|
| 170 |
+
posterior_threshold (float): A threshold to decide the lower bound of probabilities to be considered for sampling.
|
| 171 |
+
posterior_alpha (float): A scaling factor applied to the entropy-based adaptive threshold.
|
| 172 |
+
|
| 173 |
+
Returns:
|
| 174 |
+
torch.Tensor: A tensor containing the indices of the sampled tokens.
|
| 175 |
+
"""
|
| 176 |
+
logit = logit / temperature
|
| 177 |
+
probs = torch.softmax(logit, dim=-1)
|
| 178 |
+
entropy = -torch.sum(
|
| 179 |
+
probs * torch.log(probs + 1e-5), dim=-1
|
| 180 |
+
)
|
| 181 |
+
threshold = torch.minimum(
|
| 182 |
+
torch.ones_like(entropy) * posterior_threshold,
|
| 183 |
+
torch.exp(-entropy) * posterior_alpha,
|
| 184 |
+
)
|
| 185 |
+
indices_to_remove = probs < threshold.unsqueeze(-1)
|
| 186 |
+
logit[indices_to_remove] = float('-inf')
|
| 187 |
+
sampled_tokens = torch.multinomial(F.softmax(logit, dim=-1), 1)
|
| 188 |
+
return sampled_tokens
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def generate_medusa_buffers(medusa_choices, device="cuda"):
|
| 193 |
+
"""
|
| 194 |
+
Generate buffers for the Medusa structure based on the provided choices.
|
| 195 |
+
|
| 196 |
+
Parameters:
|
| 197 |
+
- medusa_choices (list): A nested list representing tree in the Medusa structure.
|
| 198 |
+
- device (str): Device to which the tensors should be moved. Default is "cuda".
|
| 199 |
+
|
| 200 |
+
Returns:
|
| 201 |
+
- dict: A dictionary containing buffers related to the Medusa structure.
|
| 202 |
+
"""
|
| 203 |
+
|
| 204 |
+
# Sort the medusa_choices based on their lengths and then their values
|
| 205 |
+
sorted_medusa_choices = sorted(medusa_choices, key=lambda x: (len(x), x))
|
| 206 |
+
medusa_len = len(sorted_medusa_choices) + 1
|
| 207 |
+
|
| 208 |
+
# Initialize depth_counts to keep track of how many choices have a particular depth
|
| 209 |
+
depth_counts = [0] * max([len(path) for path in sorted_medusa_choices])
|
| 210 |
+
for path in sorted_medusa_choices:
|
| 211 |
+
depth_counts[len(path) - 1] += 1
|
| 212 |
+
|
| 213 |
+
# Create the attention mask for Medusa
|
| 214 |
+
medusa_attn_mask = torch.eye(medusa_len, medusa_len)
|
| 215 |
+
medusa_attn_mask[:, 0] = 1
|
| 216 |
+
start = 0
|
| 217 |
+
for i in range(len(depth_counts)):
|
| 218 |
+
for j in range(depth_counts[i]):
|
| 219 |
+
cur_medusa_choice = sorted_medusa_choices[start + j]
|
| 220 |
+
# retrieve ancestor position
|
| 221 |
+
if len(cur_medusa_choice) == 1:
|
| 222 |
+
continue
|
| 223 |
+
ancestor_idx = []
|
| 224 |
+
for c in range(len(cur_medusa_choice) - 1):
|
| 225 |
+
ancestor_idx.append(sorted_medusa_choices.index(cur_medusa_choice[:c+1]) + 1)
|
| 226 |
+
medusa_attn_mask[j + start + 1, ancestor_idx] = 1
|
| 227 |
+
start += depth_counts[i]
|
| 228 |
+
|
| 229 |
+
# Generate tree indices for the Medusa structure
|
| 230 |
+
medusa_tree_indices = torch.zeros(medusa_len, dtype=torch.long)
|
| 231 |
+
medusa_tree_indices[0] = 0
|
| 232 |
+
start = 0
|
| 233 |
+
for i in range(len(depth_counts)):
|
| 234 |
+
for j in range(depth_counts[i]):
|
| 235 |
+
cur_medusa_choice = sorted_medusa_choices[start + j]
|
| 236 |
+
medusa_tree_indices[start + j + 1] = cur_medusa_choice[-1] + TOPK * i + 1
|
| 237 |
+
start += depth_counts[i]
|
| 238 |
+
|
| 239 |
+
# Generate position IDs for the Medusa structure
|
| 240 |
+
medusa_position_ids = torch.zeros(medusa_len, dtype=torch.long)
|
| 241 |
+
start = 0
|
| 242 |
+
for i in range(len(depth_counts)):
|
| 243 |
+
medusa_position_ids[start + 1: start + depth_counts[i] + 1] = i + 1
|
| 244 |
+
start += depth_counts[i]
|
| 245 |
+
|
| 246 |
+
# Generate retrieval indices for Medusa structure verification
|
| 247 |
+
retrieve_indices_nest = []
|
| 248 |
+
retrieve_paths = []
|
| 249 |
+
for i in range(len(sorted_medusa_choices)):
|
| 250 |
+
cur_medusa_choice = sorted_medusa_choices[-i-1]
|
| 251 |
+
retrieve_indice = []
|
| 252 |
+
if cur_medusa_choice in retrieve_paths:
|
| 253 |
+
continue
|
| 254 |
+
else:
|
| 255 |
+
for c in range(len(cur_medusa_choice)):
|
| 256 |
+
retrieve_indice.append(sorted_medusa_choices.index(cur_medusa_choice[:c+1]))
|
| 257 |
+
retrieve_paths.append(cur_medusa_choice[:c+1])
|
| 258 |
+
retrieve_indices_nest.append(retrieve_indice)
|
| 259 |
+
max_length = max([len(x) for x in retrieve_indices_nest])
|
| 260 |
+
retrieve_indices = [pad_path(path, max_length) for path in retrieve_indices_nest]
|
| 261 |
+
retrieve_indices = torch.tensor(retrieve_indices, dtype=torch.long)
|
| 262 |
+
retrieve_indices = retrieve_indices + 1
|
| 263 |
+
retrieve_indices = torch.cat([torch.zeros((retrieve_indices.shape[0], 1), dtype=torch.long), retrieve_indices], dim=1)
|
| 264 |
+
|
| 265 |
+
# Aggregate the generated buffers into a dictionary
|
| 266 |
+
medusa_buffers = {
|
| 267 |
+
"medusa_attn_mask": medusa_attn_mask.unsqueeze(0).unsqueeze(0),
|
| 268 |
+
"tree_indices": medusa_tree_indices,
|
| 269 |
+
"medusa_position_ids": medusa_position_ids.unsqueeze(0),
|
| 270 |
+
"retrieve_indices": retrieve_indices,
|
| 271 |
+
}
|
| 272 |
+
|
| 273 |
+
# Move the tensors in the dictionary to the specified device
|
| 274 |
+
medusa_buffers = {
|
| 275 |
+
k: v.clone().to(device)
|
| 276 |
+
if isinstance(v, torch.Tensor)
|
| 277 |
+
else torch.tensor(v, device=device)
|
| 278 |
+
for k, v in medusa_buffers.items()
|
| 279 |
+
}
|
| 280 |
+
return medusa_buffers
|
| 281 |
+
|
| 282 |
+
def generate_candidates(
|
| 283 |
+
medusa_logits,
|
| 284 |
+
logits,
|
| 285 |
+
tree_indices,
|
| 286 |
+
retrieve_indices,
|
| 287 |
+
temperature = 0,
|
| 288 |
+
posterior_threshold=0.3,
|
| 289 |
+
posterior_alpha = 0.09,
|
| 290 |
+
top_p=0.8,
|
| 291 |
+
sampling = 'typical',
|
| 292 |
+
fast = False
|
| 293 |
+
):
|
| 294 |
+
# Say we have 3 heads, and the top-4 for each head are:
|
| 295 |
+
# [10, 3, 8, 4]
|
| 296 |
+
# [9, 5, 1, 6]
|
| 297 |
+
# [7, 16, 3, 2]
|
| 298 |
+
|
| 299 |
+
# candidates_id = 10
|
| 300 |
+
if temperature == 0 or fast:
|
| 301 |
+
candidates_ids = torch.argmax(logits[:, -1]).unsqueeze(0)
|
| 302 |
+
else:
|
| 303 |
+
if sampling == 'typical':
|
| 304 |
+
candidates_ids = get_typical_one_token(logits[:, -1], temperature, posterior_threshold, posterior_alpha).squeeze(0)
|
| 305 |
+
elif sampling == 'nucleus':
|
| 306 |
+
candidates_ids = get_nucleus_one_token(logits[:, -1], temperature, top_p).squeeze(0)
|
| 307 |
+
else:
|
| 308 |
+
raise NotImplementedError
|
| 309 |
+
|
| 310 |
+
# this calculates the top-k medusa logits
|
| 311 |
+
# candidates_medusa_id = [
|
| 312 |
+
# [9, 5, 1, 6]
|
| 313 |
+
# [7, 16, 3, 2]
|
| 314 |
+
# ]
|
| 315 |
+
candidates_medusa_ids = torch.topk(medusa_logits[:, 0, -1], TOPK, dim=-1).indices
|
| 316 |
+
|
| 317 |
+
# [10, 9, 5, 1, 6, 7, 16, 3, 2]
|
| 318 |
+
candidate_ids = torch.cat([candidates_ids, candidates_medusa_ids.view(-1)], dim=-1)
|
| 319 |
+
|
| 320 |
+
# based on the pre-defined tree_indices, select the corresponding candidates
|
| 321 |
+
# if we select top-2 and top-3 for the two heads (we select top-1 for the first head):
|
| 322 |
+
# tree_candidates = [10, 9, 5, 7, 16, 3, 7, 16, 3]
|
| 323 |
+
tree_candidate_ids = candidate_ids[tree_indices]
|
| 324 |
+
|
| 325 |
+
# tree_candidate_ids = [10, 9, 5, 7, 16, 3, 7, 16, 3, 0]
|
| 326 |
+
# Sometimes the tree_indices are padded, so we append a zero here
|
| 327 |
+
# so that all padded indices select the appended zero.
|
| 328 |
+
tree_candidate_ids_ext = torch.cat(
|
| 329 |
+
[
|
| 330 |
+
tree_candidate_ids,
|
| 331 |
+
torch.zeros((1), dtype=torch.long, device=tree_candidate_ids.device)
|
| 332 |
+
],
|
| 333 |
+
dim=0
|
| 334 |
+
)
|
| 335 |
+
# [[10, 9, 7], [10, 9, 16], [10, 9, 3], [10, 5, 7], [10, 5, 16], [10, 5, 3]]
|
| 336 |
+
unflattened_candidate_ids = tree_candidate_ids_ext[retrieve_indices]
|
| 337 |
+
|
| 338 |
+
tree_candidate_ids = tree_candidate_ids.unsqueeze(0)
|
| 339 |
+
|
| 340 |
+
return tree_candidate_ids, unflattened_candidate_ids
|
| 341 |
+
|
| 342 |
+
def get_nucleus_posterior_mask(logits, candidates, temperature, top_p):
|
| 343 |
+
"""
|
| 344 |
+
Generates a posterior mask for token candidates using nucleus (top-p) sampling.
|
| 345 |
+
|
| 346 |
+
This function applies nucleus sampling to a set of logits, and then generates a mask indicating
|
| 347 |
+
which candidate tokens are selected. It adapts the sampling strategy to accommodate for
|
| 348 |
+
temperature scaling and cumulative probability thresholding.
|
| 349 |
+
|
| 350 |
+
Args:
|
| 351 |
+
logits (torch.Tensor): A tensor of logits from a language model output.
|
| 352 |
+
candidates (torch.Tensor): A tensor of candidate tokens to compare against sampled tokens.
|
| 353 |
+
temperature (float): A parameter to scale the logits, controlling randomness in sampling.
|
| 354 |
+
top_p (float): The cumulative probability threshold for nucleus sampling.
|
| 355 |
+
|
| 356 |
+
Returns:
|
| 357 |
+
torch.Tensor: A posterior mask indicating which candidate tokens match the sampled tokens.
|
| 358 |
+
"""
|
| 359 |
+
# adapted from https://github.com/huggingface/transformers/blob/18a879f47576822aa1a5c49aecb27d89bfa5fa69/examples/run_generation.py#L79
|
| 360 |
+
|
| 361 |
+
# Apply temperature
|
| 362 |
+
logits = logits[:, :-1] / temperature
|
| 363 |
+
n_samples, n_tokens = logits.shape[0], logits.shape[1]
|
| 364 |
+
logits = logits.view(n_samples*n_tokens, -1)
|
| 365 |
+
if top_p >= 1:
|
| 366 |
+
sampled_tokens = torch.multinomial(F.softmax(logits, dim=-1), 1)
|
| 367 |
+
sampled_tokens = sampled_tokens.view(n_samples, n_tokens)
|
| 368 |
+
posterior_mask = (candidates[:, 1:] == sampled_tokens).int()
|
| 369 |
+
return posterior_mask
|
| 370 |
+
# Convert to probabilities (softmax)
|
| 371 |
+
probs = F.softmax(logits, dim=-1)
|
| 372 |
+
# Sort the probabilities
|
| 373 |
+
sorted_logits, sorted_indices = torch.sort(probs, descending=True)
|
| 374 |
+
|
| 375 |
+
# Compute cumulative probabilities
|
| 376 |
+
cum_probs = torch.cumsum(sorted_logits, dim=-1)
|
| 377 |
+
|
| 378 |
+
# Create mask for the top-p nucleus
|
| 379 |
+
sorted_indices_to_remove = cum_probs > top_p
|
| 380 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
| 381 |
+
sorted_indices_to_remove[..., 0] = 0
|
| 382 |
+
|
| 383 |
+
indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
# Remove low-probability tokens
|
| 387 |
+
logits[indices_to_remove] = float('-inf')
|
| 388 |
+
# Sample from the remaining tokens
|
| 389 |
+
sampled_tokens = torch.multinomial(F.softmax(logits, dim=-1), 1)
|
| 390 |
+
sampled_tokens = sampled_tokens.view(n_samples, n_tokens)
|
| 391 |
+
# Create a mask for selected tokens
|
| 392 |
+
posterior_mask = (candidates[:, 1:] == sampled_tokens).int()
|
| 393 |
+
|
| 394 |
+
return posterior_mask
|
| 395 |
+
|
| 396 |
+
def get_typical_posterior_mask(logits, candidates, temperature, posterior_threshold, posterior_alpha):
|
| 397 |
+
"""
|
| 398 |
+
Args:
|
| 399 |
+
logits (torch.Tensor): A tensor of logits from a language model output.
|
| 400 |
+
candidates (torch.Tensor): A tensor of candidate tokens to compare against sampled tokens.
|
| 401 |
+
temperature (float): A parameter to scale the logits, controlling randomness in sampling.
|
| 402 |
+
posterior_threshold (float): The minimum threshold for probabilities to be considered in sampling.
|
| 403 |
+
posterior_alpha (float): A scaling factor applied to the entropy-based adaptive threshold.
|
| 404 |
+
|
| 405 |
+
Returns:
|
| 406 |
+
torch.Tensor: A posterior mask indicating which candidate tokens match the sampled tokens.
|
| 407 |
+
"""
|
| 408 |
+
logits = logits[:, :-1] / temperature
|
| 409 |
+
n_samples, n_tokens = logits.shape[0], logits.shape[1]
|
| 410 |
+
logits = logits.view(n_samples*n_tokens, -1)
|
| 411 |
+
probs = F.softmax(logits, dim=-1)
|
| 412 |
+
entropy = -torch.sum(
|
| 413 |
+
probs * torch.log(probs + 1e-5), dim=-1
|
| 414 |
+
)
|
| 415 |
+
threshold = torch.minimum(
|
| 416 |
+
torch.ones_like(entropy) * posterior_threshold,
|
| 417 |
+
torch.exp(-entropy) * posterior_alpha,
|
| 418 |
+
)
|
| 419 |
+
indices_to_remove = probs < threshold.unsqueeze(-1)
|
| 420 |
+
logits[indices_to_remove] = float('-inf')
|
| 421 |
+
sampled_tokens = torch.multinomial(F.softmax(logits, dim=-1), 1)
|
| 422 |
+
sampled_tokens = sampled_tokens.view(n_samples, n_tokens)
|
| 423 |
+
posterior_mask = (candidates[:, 1:] == sampled_tokens).int()
|
| 424 |
+
return posterior_mask
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
def evaluate_posterior(
|
| 429 |
+
logits,
|
| 430 |
+
candidates,
|
| 431 |
+
temperature,
|
| 432 |
+
posterior_threshold=0.3,
|
| 433 |
+
posterior_alpha = 0.09,
|
| 434 |
+
top_p=0.8,
|
| 435 |
+
sampling = 'typical',
|
| 436 |
+
fast = True
|
| 437 |
+
):
|
| 438 |
+
if logits.shape[1] <= 1:
|
| 439 |
+
return torch.tensor(0, dtype=torch.long, device=candidates.device), 0
|
| 440 |
+
# Greedy decoding based on temperature value
|
| 441 |
+
if temperature == 0:
|
| 442 |
+
# Find the tokens that match the maximum logits for each position in the sequence
|
| 443 |
+
posterior_mask = (
|
| 444 |
+
candidates[:, 1:] == torch.argmax(logits[:, :-1], dim=-1)
|
| 445 |
+
).int()
|
| 446 |
+
candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
|
| 447 |
+
accept_length = candidates_accept_length.max().item()
|
| 448 |
+
# Choose the best candidate
|
| 449 |
+
if accept_length == 0:
|
| 450 |
+
# Default to the first candidate if none are accepted
|
| 451 |
+
best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
|
| 452 |
+
else:
|
| 453 |
+
best_candidate = torch.argmax(candidates_accept_length).to(torch.long)
|
| 454 |
+
return best_candidate, accept_length
|
| 455 |
+
elif sampling == 'typical':
|
| 456 |
+
if fast:
|
| 457 |
+
posterior_prob = torch.softmax(logits[:, :-1] / temperature, dim=-1)
|
| 458 |
+
candidates_prob = torch.gather(
|
| 459 |
+
posterior_prob, dim=-1, index=candidates[:, 1:].unsqueeze(-1)
|
| 460 |
+
).squeeze(-1)
|
| 461 |
+
posterior_entropy = -torch.sum(
|
| 462 |
+
posterior_prob * torch.log(posterior_prob + 1e-5), dim=-1
|
| 463 |
+
) # torch.sum(torch.log(*)) is faster than torch.prod
|
| 464 |
+
threshold = torch.minimum(
|
| 465 |
+
torch.ones_like(posterior_entropy) * posterior_threshold,
|
| 466 |
+
torch.exp(-posterior_entropy) * posterior_alpha,
|
| 467 |
+
)
|
| 468 |
+
posterior_mask = candidates_prob > threshold
|
| 469 |
+
candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
|
| 470 |
+
|
| 471 |
+
# Choose the best candidate based on the evaluated posterior probabilities
|
| 472 |
+
accept_length = candidates_accept_length.max().item()
|
| 473 |
+
if accept_length == 0:
|
| 474 |
+
# If no candidates are accepted, just choose the first one
|
| 475 |
+
best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
|
| 476 |
+
else:
|
| 477 |
+
best_candidates = torch.where(candidates_accept_length == accept_length)[0]
|
| 478 |
+
# Accept the best one according to likelihood
|
| 479 |
+
likelihood = torch.sum(
|
| 480 |
+
torch.log(candidates_prob[best_candidates, :accept_length]), dim=-1
|
| 481 |
+
)
|
| 482 |
+
best_candidate = best_candidates[torch.argmax(likelihood)]
|
| 483 |
+
return best_candidate, accept_length
|
| 484 |
+
# Calculate posterior probabilities and thresholds for candidate selection
|
| 485 |
+
posterior_mask = get_typical_posterior_mask(logits, candidates, temperature, posterior_threshold, posterior_alpha)
|
| 486 |
+
candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
|
| 487 |
+
# Choose the best candidate based on the evaluated posterior probabilities
|
| 488 |
+
accept_length = candidates_accept_length.max().item()
|
| 489 |
+
|
| 490 |
+
if accept_length == 0:
|
| 491 |
+
# If no candidates are accepted, just choose the first one
|
| 492 |
+
best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
|
| 493 |
+
else:
|
| 494 |
+
best_candidate = torch.argmax(candidates_accept_length).to(torch.long)
|
| 495 |
+
# Accept the best one according to likelihood
|
| 496 |
+
return best_candidate, accept_length
|
| 497 |
+
elif sampling == 'nucleus':
|
| 498 |
+
assert top_p < 1.0 + 1e-6, "top_p should between 0 and 1"
|
| 499 |
+
posterior_mask = get_nucleus_posterior_mask(logits, candidates, temperature, top_p)
|
| 500 |
+
candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
|
| 501 |
+
accept_length = candidates_accept_length.max().item()
|
| 502 |
+
# Choose the best candidate
|
| 503 |
+
if accept_length == 0:
|
| 504 |
+
# Default to the first candidate if none are accepted
|
| 505 |
+
best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
|
| 506 |
+
else:
|
| 507 |
+
best_candidate = torch.argmax(candidates_accept_length).to(torch.long)
|
| 508 |
+
return best_candidate, accept_length
|
| 509 |
+
else:
|
| 510 |
+
raise NotImplementedError
|
| 511 |
+
|
| 512 |
+
def update_inference_inputs(
|
| 513 |
+
input_ids,
|
| 514 |
+
medusa_logits,
|
| 515 |
+
logits,
|
| 516 |
+
candidate_ids,
|
| 517 |
+
best_candidate,
|
| 518 |
+
accept_length,
|
| 519 |
+
):
|
| 520 |
+
input_ids = torch.cat(
|
| 521 |
+
[
|
| 522 |
+
input_ids,
|
| 523 |
+
candidate_ids[None, best_candidate, : accept_length + 1]
|
| 524 |
+
],
|
| 525 |
+
dim=-1
|
| 526 |
+
)
|
| 527 |
+
logits = logits[
|
| 528 |
+
None, best_candidate, accept_length : accept_length + 1
|
| 529 |
+
]
|
| 530 |
+
medusa_logits = medusa_logits[
|
| 531 |
+
:, None, best_candidate, accept_length : accept_length + 1
|
| 532 |
+
]
|
| 533 |
+
# Update the new token counter
|
| 534 |
+
new_token = accept_length + 1
|
| 535 |
+
return input_ids, medusa_logits, logits, new_token
|
| 536 |
+
|
| 537 |
+
def split_logits(full_logits):
|
| 538 |
+
# logits has shape [b, n, heads, vocab_size]
|
| 539 |
+
logits = full_logits[..., 0, :]
|
| 540 |
+
medusa_logits = full_logits[..., 1:, :].permute(2, 0, 1, 3)
|
| 541 |
+
return medusa_logits, logits
|
| 542 |
+
|
| 543 |
+
class MultiByteDecodingMixin:
|
| 544 |
+
def multi_byte_pred_update_cache(
|
| 545 |
+
self,
|
| 546 |
+
past_key_values,
|
| 547 |
+
retrieve_indices,
|
| 548 |
+
best_candidate,
|
| 549 |
+
new_tokens,
|
| 550 |
+
):
|
| 551 |
+
prev_window_len = past_key_values.get_past_window_pos(0)
|
| 552 |
+
select_indices = (
|
| 553 |
+
retrieve_indices[best_candidate, : new_tokens] + prev_window_len
|
| 554 |
+
)
|
| 555 |
+
for layer_idx in range(self.config.num_hidden_layers):
|
| 556 |
+
|
| 557 |
+
past_key_values.update_past_len(new_tokens, layer_idx)
|
| 558 |
+
|
| 559 |
+
past_window_k = past_key_values.past_window_k[layer_idx]
|
| 560 |
+
past_window_v = past_key_values.past_window_v[layer_idx]
|
| 561 |
+
|
| 562 |
+
tgt_window_k = past_window_k[..., select_indices, :]
|
| 563 |
+
tgt_window_v = past_window_v[..., select_indices, :]
|
| 564 |
+
|
| 565 |
+
dst_window_k = past_window_k[..., prev_window_len : prev_window_len + new_tokens, :]
|
| 566 |
+
dst_window_v = past_window_v[..., prev_window_len : prev_window_len + new_tokens, :]
|
| 567 |
+
|
| 568 |
+
dst_window_k.copy_(tgt_window_k, non_blocking=True)
|
| 569 |
+
dst_window_v.copy_(tgt_window_v, non_blocking=True)
|
| 570 |
+
|
| 571 |
+
new_window_len = prev_window_len + new_tokens
|
| 572 |
+
if new_window_len >= self.config.window_size:
|
| 573 |
+
assert new_window_len < 2 * self.config.window_size
|
| 574 |
+
|
| 575 |
+
dump_k = past_window_k[..., :self.config.window_size, :].clone()
|
| 576 |
+
dump_v = past_window_v[..., :self.config.window_size, :].clone()
|
| 577 |
+
|
| 578 |
+
_window_len = new_window_len - self.config.window_size
|
| 579 |
+
|
| 580 |
+
if _window_len > 0:
|
| 581 |
+
new_window_k = past_window_k[..., self.config.window_size : new_window_len, :]
|
| 582 |
+
new_window_v = past_window_v[..., self.config.window_size : new_window_len, :]
|
| 583 |
+
|
| 584 |
+
_dst_window_k = past_window_k[..., : _window_len, :]
|
| 585 |
+
_dst_window_v = past_window_v[..., : _window_len, :]
|
| 586 |
+
|
| 587 |
+
_dst_window_k.copy_(new_window_k, non_blocking=True)
|
| 588 |
+
_dst_window_v.copy_(new_window_v, non_blocking=True)
|
| 589 |
+
|
| 590 |
+
past_key_values.past_window_pos[layer_idx] = _window_len
|
| 591 |
+
else:
|
| 592 |
+
dump_k = None
|
| 593 |
+
dump_v = None
|
| 594 |
+
past_key_values.past_window_pos[layer_idx] = new_window_len
|
| 595 |
+
|
| 596 |
+
if dump_k is not None and dump_v is not None:
|
| 597 |
+
rfa_k, rfa_v = triton_eva_prep_kv_fwd(
|
| 598 |
+
dump_k, dump_v,
|
| 599 |
+
self.model.layers[layer_idx].self_attn.adaptive_mu_k,
|
| 600 |
+
self.model.layers[layer_idx].self_attn.adaptive_phi,
|
| 601 |
+
None,
|
| 602 |
+
self.model.layers[layer_idx].self_attn.head_dim_scaling,
|
| 603 |
+
self.model.layers[layer_idx].self_attn.chunk_size
|
| 604 |
+
)
|
| 605 |
+
rfa_k, rfa_v = past_key_values.update_chunk_rfas(
|
| 606 |
+
rfa_k, rfa_v, layer_idx
|
| 607 |
+
)
|
| 608 |
+
return past_key_values
|
| 609 |
+
|
| 610 |
+
def _multi_byte_pred_update_cache_when_prefil_len_eq_window_size(
|
| 611 |
+
self,
|
| 612 |
+
past_key_values,
|
| 613 |
+
):
|
| 614 |
+
prev_window_len = past_key_values.get_past_window_pos(0)
|
| 615 |
+
for layer_idx in range(self.config.num_hidden_layers):
|
| 616 |
+
|
| 617 |
+
past_window_k = past_key_values.past_window_k[layer_idx]
|
| 618 |
+
past_window_v = past_key_values.past_window_v[layer_idx]
|
| 619 |
+
|
| 620 |
+
new_window_len = prev_window_len
|
| 621 |
+
if new_window_len == self.config.window_size:
|
| 622 |
+
dump_k = past_window_k[..., :self.config.window_size, :].clone()
|
| 623 |
+
dump_v = past_window_v[..., :self.config.window_size, :].clone()
|
| 624 |
+
past_key_values.past_window_pos[layer_idx] = 0
|
| 625 |
+
|
| 626 |
+
if dump_k is not None and dump_v is not None:
|
| 627 |
+
rfa_k, rfa_v = triton_eva_prep_kv_fwd(
|
| 628 |
+
dump_k, dump_v,
|
| 629 |
+
self.model.layers[layer_idx].self_attn.adaptive_mu_k,
|
| 630 |
+
self.model.layers[layer_idx].self_attn.adaptive_phi,
|
| 631 |
+
None,
|
| 632 |
+
self.model.layers[layer_idx].self_attn.head_dim_scaling,
|
| 633 |
+
self.model.layers[layer_idx].self_attn.chunk_size
|
| 634 |
+
)
|
| 635 |
+
rfa_k, rfa_v = past_key_values.update_chunk_rfas(
|
| 636 |
+
rfa_k, rfa_v, layer_idx
|
| 637 |
+
)
|
| 638 |
+
return past_key_values
|
| 639 |
+
|
| 640 |
+
def multi_byte_pred_update_attn_mask(
|
| 641 |
+
self,
|
| 642 |
+
last_iter_new_tokens,
|
| 643 |
+
tree_candidate_ids,
|
| 644 |
+
past_attn_mask,
|
| 645 |
+
medusa_attn_mask,
|
| 646 |
+
past_key_values,
|
| 647 |
+
):
|
| 648 |
+
batch_size, tree_candidate_len = tree_candidate_ids.shape
|
| 649 |
+
seen_tokens = past_key_values.get_seq_length()
|
| 650 |
+
# NOTE: past_key_values has been updated so now
|
| 651 |
+
# seen_tokens incldues new tokens from the last tree iteration
|
| 652 |
+
assert seen_tokens > 0
|
| 653 |
+
# so one iteration would not cross two windows
|
| 654 |
+
assert last_iter_new_tokens < self.config.window_size
|
| 655 |
+
|
| 656 |
+
if past_attn_mask is not None and seen_tokens < self.config.window_size:
|
| 657 |
+
past_attn_mask = torch.cat(
|
| 658 |
+
[
|
| 659 |
+
past_attn_mask,
|
| 660 |
+
torch.ones(
|
| 661 |
+
[batch_size, 1, tree_candidate_len, last_iter_new_tokens],
|
| 662 |
+
dtype=torch.bool,
|
| 663 |
+
device=self.device
|
| 664 |
+
)
|
| 665 |
+
],
|
| 666 |
+
dim=-1
|
| 667 |
+
)
|
| 668 |
+
else:
|
| 669 |
+
# we initialize attn mask each time when
|
| 670 |
+
# 1. the model crosses the window bounary, or
|
| 671 |
+
# 2. after prefilling
|
| 672 |
+
chunks_per_window = int(self.config.window_size // self.config.chunk_size)
|
| 673 |
+
|
| 674 |
+
window_tokens = seen_tokens % self.config.window_size
|
| 675 |
+
num_windows_seen_so_far = seen_tokens // self.config.window_size
|
| 676 |
+
attn_mask_len = num_windows_seen_so_far * chunks_per_window + window_tokens
|
| 677 |
+
past_attn_mask = torch.ones(
|
| 678 |
+
(batch_size, 1, tree_candidate_len, attn_mask_len),
|
| 679 |
+
dtype=torch.bool,
|
| 680 |
+
device=self.device
|
| 681 |
+
)
|
| 682 |
+
|
| 683 |
+
# note that 1 indicates the position is not masked
|
| 684 |
+
tree_attn_mask = torch.cat(
|
| 685 |
+
[
|
| 686 |
+
past_attn_mask,
|
| 687 |
+
medusa_attn_mask.to(torch.bool)
|
| 688 |
+
],
|
| 689 |
+
dim=-1
|
| 690 |
+
)
|
| 691 |
+
return tree_attn_mask, past_attn_mask
|
| 692 |
+
|
| 693 |
+
@torch.no_grad()
|
| 694 |
+
def multi_byte_generate(
|
| 695 |
+
self,
|
| 696 |
+
input_ids,
|
| 697 |
+
attention_mask=None,
|
| 698 |
+
temperature=0.0,
|
| 699 |
+
max_length=None,
|
| 700 |
+
max_new_tokens=None,
|
| 701 |
+
stopping_criteria=None,
|
| 702 |
+
posterior_threshold=0.09,
|
| 703 |
+
posterior_alpha=0.3,
|
| 704 |
+
top_p=0.8,
|
| 705 |
+
sampling='typical',
|
| 706 |
+
fast=True,
|
| 707 |
+
do_sample=False,
|
| 708 |
+
medusa_choices=None,
|
| 709 |
+
return_acc_lengths=False
|
| 710 |
+
):
|
| 711 |
+
if do_sample or temperature > 0.0:
|
| 712 |
+
fast = False
|
| 713 |
+
|
| 714 |
+
### Prepare `max_length` depending on other stopping criteria.
|
| 715 |
+
if max_new_tokens is not None:
|
| 716 |
+
max_length = max_new_tokens + input_ids.shape[-1]
|
| 717 |
+
elif max_new_tokens is None and max_length is None:
|
| 718 |
+
max_length = getattr(self.config, "max_position_embeddings", 32768)
|
| 719 |
+
|
| 720 |
+
### Set up stopping criteria
|
| 721 |
+
eos_stop_criteria = MultibyteEosTokenCriteria(self.generation_config.eos_token_id)
|
| 722 |
+
stop_criteria = StoppingCriteriaList()
|
| 723 |
+
if max_length is not None:
|
| 724 |
+
max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
|
| 725 |
+
stop_criteria.append(
|
| 726 |
+
MaxLengthCriteria(
|
| 727 |
+
max_length=max_length,
|
| 728 |
+
max_position_embeddings=max_position_embeddings,
|
| 729 |
+
)
|
| 730 |
+
)
|
| 731 |
+
if stopping_criteria is not None and len(stopping_criteria) > 0:
|
| 732 |
+
stop_criteria.extend(stopping_criteria)
|
| 733 |
+
|
| 734 |
+
assert input_ids.shape[0] == 1, "Only support batch size 1 for now"
|
| 735 |
+
assert attention_mask is None, "Only support attention mask None for now"
|
| 736 |
+
# Avoid modifying the input_ids in-place
|
| 737 |
+
input_ids = input_ids.clone()
|
| 738 |
+
position_ids = torch.arange(0, input_ids.shape[1], device=self.device, dtype=int).reshape(1, -1)
|
| 739 |
+
|
| 740 |
+
####################################################
|
| 741 |
+
# 0. initialize the medusa buffers
|
| 742 |
+
####################################################
|
| 743 |
+
if medusa_choices is None:
|
| 744 |
+
medusa_choices = evabyte_7b_95
|
| 745 |
+
medusa_buffers = generate_medusa_buffers(
|
| 746 |
+
medusa_choices, device=self.device
|
| 747 |
+
)
|
| 748 |
+
|
| 749 |
+
past_key_values = EvaStaticCacheForTriton(
|
| 750 |
+
input_ids.shape[0],
|
| 751 |
+
self.config.num_attention_heads,
|
| 752 |
+
# we add 256 to allow tree ids
|
| 753 |
+
self.config.window_size + 256,
|
| 754 |
+
self.config.hidden_size // self.config.num_attention_heads,
|
| 755 |
+
self.config.num_hidden_layers,
|
| 756 |
+
self.lm_head.weight.dtype,
|
| 757 |
+
self.lm_head.weight.device,
|
| 758 |
+
)
|
| 759 |
+
# prefill to get medusa logits and logits
|
| 760 |
+
full_logits, past_key_values = self.forward(
|
| 761 |
+
input_ids,
|
| 762 |
+
attention_mask=attention_mask,
|
| 763 |
+
position_ids=position_ids,
|
| 764 |
+
use_cache=True,
|
| 765 |
+
past_key_values=past_key_values,
|
| 766 |
+
return_all_pred_logits=True,
|
| 767 |
+
multibyte_decoding=False,
|
| 768 |
+
)
|
| 769 |
+
# handles an edge case where the prefill length == window_size
|
| 770 |
+
# we force the previous window to be dumped into RFA chunks
|
| 771 |
+
past_key_values = self._multi_byte_pred_update_cache_when_prefil_len_eq_window_size(
|
| 772 |
+
past_key_values
|
| 773 |
+
)
|
| 774 |
+
medusa_logits, logits = split_logits(full_logits)
|
| 775 |
+
|
| 776 |
+
past_attn_mask = None
|
| 777 |
+
last_iter_new_tokens = 0
|
| 778 |
+
max_iters = 32768
|
| 779 |
+
if return_acc_lengths:
|
| 780 |
+
acc_lengths = []
|
| 781 |
+
for _ in range(max_iters):
|
| 782 |
+
####################################################
|
| 783 |
+
# 1. generate candidate_ids with topk predictions from Medusa heads
|
| 784 |
+
####################################################
|
| 785 |
+
tree_candidate_ids, unflattened_candidate_ids = generate_candidates(
|
| 786 |
+
medusa_logits,
|
| 787 |
+
logits,
|
| 788 |
+
medusa_buffers["tree_indices"],
|
| 789 |
+
medusa_buffers["retrieve_indices"],
|
| 790 |
+
temperature=temperature,
|
| 791 |
+
posterior_alpha=posterior_alpha,
|
| 792 |
+
posterior_threshold=posterior_threshold,
|
| 793 |
+
top_p=top_p,
|
| 794 |
+
sampling=sampling,
|
| 795 |
+
fast=fast,
|
| 796 |
+
)
|
| 797 |
+
|
| 798 |
+
####################################################
|
| 799 |
+
# 2. Build the medusa attention mask and position ids
|
| 800 |
+
####################################################
|
| 801 |
+
# NOTE: 1 indicates the position is not masked
|
| 802 |
+
medusa_attn_mask, past_attn_mask = self.multi_byte_pred_update_attn_mask(
|
| 803 |
+
last_iter_new_tokens,
|
| 804 |
+
tree_candidate_ids,
|
| 805 |
+
past_attn_mask,
|
| 806 |
+
medusa_buffers["medusa_attn_mask"],
|
| 807 |
+
past_key_values,
|
| 808 |
+
)
|
| 809 |
+
medusa_position_ids = medusa_buffers["medusa_position_ids"] + input_ids.shape[1]
|
| 810 |
+
|
| 811 |
+
####################################################
|
| 812 |
+
# 3. tree decoding
|
| 813 |
+
####################################################
|
| 814 |
+
tree_full_logits, past_key_values = self.forward(
|
| 815 |
+
tree_candidate_ids,
|
| 816 |
+
past_key_values=past_key_values,
|
| 817 |
+
attention_mask=medusa_attn_mask,
|
| 818 |
+
position_ids=medusa_position_ids,
|
| 819 |
+
return_all_pred_logits=True,
|
| 820 |
+
multibyte_decoding=True,
|
| 821 |
+
)
|
| 822 |
+
_medusa_logits, _logits = split_logits(tree_full_logits)
|
| 823 |
+
medusa_logits = _medusa_logits[..., 0, medusa_buffers["retrieve_indices"], :]
|
| 824 |
+
logits = _logits[..., 0, medusa_buffers["retrieve_indices"], :]
|
| 825 |
+
|
| 826 |
+
####################################################
|
| 827 |
+
# 4. candidate selection
|
| 828 |
+
####################################################
|
| 829 |
+
# if the current iteration, with tree tokens, crosses window
|
| 830 |
+
# boundaries, trim the condidate_ids to be within the window
|
| 831 |
+
# so that those exceeded tokens (which will be inaccurate)
|
| 832 |
+
# will not be considered
|
| 833 |
+
tree_depth = unflattened_candidate_ids.shape[-1]
|
| 834 |
+
if tree_depth + past_key_values.get_past_window_pos(0) > self.config.window_size:
|
| 835 |
+
max_acc_len = self.config.window_size - past_key_values.get_past_window_pos(0)
|
| 836 |
+
_trimmed_unflattened_candidate_ids = unflattened_candidate_ids[:, :max_acc_len]
|
| 837 |
+
_trimmed_logits = logits[:, :max_acc_len]
|
| 838 |
+
else:
|
| 839 |
+
_trimmed_unflattened_candidate_ids = unflattened_candidate_ids
|
| 840 |
+
_trimmed_logits = logits
|
| 841 |
+
best_candidate, accept_length = evaluate_posterior(
|
| 842 |
+
_trimmed_logits,
|
| 843 |
+
_trimmed_unflattened_candidate_ids,
|
| 844 |
+
temperature,
|
| 845 |
+
posterior_threshold,
|
| 846 |
+
posterior_alpha,
|
| 847 |
+
top_p=top_p,
|
| 848 |
+
sampling=sampling,
|
| 849 |
+
fast=fast
|
| 850 |
+
)
|
| 851 |
+
|
| 852 |
+
####################################################
|
| 853 |
+
# 5. update model inputs and caches
|
| 854 |
+
####################################################
|
| 855 |
+
input_ids, medusa_logits, logits, last_iter_new_tokens = update_inference_inputs(
|
| 856 |
+
input_ids,
|
| 857 |
+
medusa_logits,
|
| 858 |
+
logits,
|
| 859 |
+
unflattened_candidate_ids,
|
| 860 |
+
best_candidate,
|
| 861 |
+
accept_length,
|
| 862 |
+
)
|
| 863 |
+
|
| 864 |
+
past_key_values = self.multi_byte_pred_update_cache(
|
| 865 |
+
past_key_values,
|
| 866 |
+
medusa_buffers["retrieve_indices"],
|
| 867 |
+
best_candidate,
|
| 868 |
+
last_iter_new_tokens,
|
| 869 |
+
)
|
| 870 |
+
|
| 871 |
+
if return_acc_lengths:
|
| 872 |
+
acc_lengths.append(last_iter_new_tokens)
|
| 873 |
+
if stop_criteria(input_ids, None) or eos_stop_criteria(input_ids, last_iter_new_tokens):
|
| 874 |
+
if return_acc_lengths:
|
| 875 |
+
return input_ids, acc_lengths
|
| 876 |
+
else:
|
| 877 |
+
return input_ids
|
| 878 |
+
if return_acc_lengths:
|
| 879 |
+
return input_ids, acc_lengths
|
| 880 |
+
else:
|
| 881 |
+
return input_ids
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/preprocessor_config.json
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"auto_map": {
|
| 3 |
+
"AutoImageProcessor": "image_processing_evabyte.EvaByteImageProcessor",
|
| 4 |
+
"AutoProcessor": "processing_evabyte.EvaByteProcessor"
|
| 5 |
+
},
|
| 6 |
+
"do_convert_rgb": true,
|
| 7 |
+
"do_resize": true,
|
| 8 |
+
"image_processor_type": "EvaByteImageProcessor",
|
| 9 |
+
"jpeg_quality": 25,
|
| 10 |
+
"jpeg_restart_marker_blocks": 1,
|
| 11 |
+
"jpeg_streamtype": 2,
|
| 12 |
+
"jpeg_subsampling": "4:2:0",
|
| 13 |
+
"processor_class": "EvaByteProcessor",
|
| 14 |
+
"resample": 1,
|
| 15 |
+
"size": {
|
| 16 |
+
"longest_edge": 384
|
| 17 |
+
}
|
| 18 |
+
}
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/processing_evabyte.py
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
"""
|
| 3 |
+
Processor class for EvaByte.
|
| 4 |
+
"""
|
| 5 |
+
import base64
|
| 6 |
+
from io import BytesIO
|
| 7 |
+
|
| 8 |
+
import requests
|
| 9 |
+
import os
|
| 10 |
+
import PIL
|
| 11 |
+
from PIL import Image
|
| 12 |
+
|
| 13 |
+
from typing import List, Optional, Union
|
| 14 |
+
|
| 15 |
+
from transformers.feature_extraction_utils import BatchFeature
|
| 16 |
+
from transformers.image_utils import ImageInput, is_valid_image
|
| 17 |
+
from transformers.processing_utils import ProcessorMixin
|
| 18 |
+
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
| 19 |
+
from transformers.utils import TensorType, to_py_obj
|
| 20 |
+
|
| 21 |
+
def fetch_image(image: Union[str, "PIL.Image.Image"]) -> Image.Image:
|
| 22 |
+
image_obj = None
|
| 23 |
+
if isinstance(image, Image.Image):
|
| 24 |
+
image_obj = image
|
| 25 |
+
elif image.startswith("http://") or image.startswith("https://"):
|
| 26 |
+
image_obj = Image.open(BytesIO(requests.get(image, timeout=None).content))
|
| 27 |
+
elif os.path.isfile(image):
|
| 28 |
+
image_obj = Image.open(image)
|
| 29 |
+
elif image.startswith("data:image/"):
|
| 30 |
+
image = image.split(",")[1]
|
| 31 |
+
# Try to load as base64
|
| 32 |
+
try:
|
| 33 |
+
b64 = base64.decodebytes(image.encode())
|
| 34 |
+
image = PIL.Image.open(BytesIO(b64))
|
| 35 |
+
except Exception as e:
|
| 36 |
+
raise ValueError(
|
| 37 |
+
f"Incorrect image source. Must be a valid URL starting with `http://` or `https://`, a valid path to an image file, or a base64 encoded string. Got {image}. Failed with {e}"
|
| 38 |
+
)
|
| 39 |
+
else:
|
| 40 |
+
image_obj = Image.open(image)
|
| 41 |
+
if image_obj is None:
|
| 42 |
+
raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
|
| 43 |
+
|
| 44 |
+
return image_obj
|
| 45 |
+
|
| 46 |
+
def is_url(val) -> bool:
|
| 47 |
+
return isinstance(val, str) and val.startswith("http")
|
| 48 |
+
|
| 49 |
+
def is_file(val) -> bool:
|
| 50 |
+
return isinstance(val, str) and os.path.isfile(val)
|
| 51 |
+
|
| 52 |
+
def is_image_or_image_url(elem):
|
| 53 |
+
return is_url(elem) or is_valid_image(elem) or is_file(elem)
|
| 54 |
+
|
| 55 |
+
vl_chat_template = """
|
| 56 |
+
{{- bos_token }}
|
| 57 |
+
{%- if messages[0]['role'] == 'system' %}
|
| 58 |
+
{%- set system_message = messages[0]['content'] %}
|
| 59 |
+
{%- set messages = messages[1:] %}
|
| 60 |
+
{%- else %}
|
| 61 |
+
{%- set system_message = "" %}
|
| 62 |
+
{%- endif %}
|
| 63 |
+
|
| 64 |
+
{{- '<|start_header_id|>system<|end_header_id|>\n\n' + system_message + '<|eot_id|>'}}
|
| 65 |
+
|
| 66 |
+
{%- for message in messages %}
|
| 67 |
+
{%- if (message['role'] != 'user') and (message['role'] != 'assistant') %}
|
| 68 |
+
{{- raise_exception('Conversation roles must be user or assistant') }}
|
| 69 |
+
{%- endif %}
|
| 70 |
+
|
| 71 |
+
{%- if message['content'] is string %}
|
| 72 |
+
{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] + '<|eot_id|>' }}
|
| 73 |
+
{%- else %}
|
| 74 |
+
{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' }}
|
| 75 |
+
{%- for content in message['content'] %}
|
| 76 |
+
{%- if content['type'] == 'image' %}
|
| 77 |
+
{{- '<image_placeholder>\n' }}
|
| 78 |
+
{%- elif content['type'] == 'text' %}
|
| 79 |
+
{{- content['text'] }}
|
| 80 |
+
{%- endif %}
|
| 81 |
+
{%- endfor %}
|
| 82 |
+
{{- '<|eot_id|>' }}
|
| 83 |
+
{%- endif %}
|
| 84 |
+
{%- endfor %}
|
| 85 |
+
|
| 86 |
+
{%- if add_generation_prompt %}
|
| 87 |
+
{{- '<|start_header_id|>' + 'assistant' + '<|end_header_id|>\n\n' }}
|
| 88 |
+
{%- endif %}
|
| 89 |
+
"""
|
| 90 |
+
|
| 91 |
+
class EvaByteProcessor(ProcessorMixin):
|
| 92 |
+
r"""
|
| 93 |
+
Constructs a EvaByte processor which wraps a EvaByte image processor and a EvaByte tokenizer into a single processor.
|
| 94 |
+
|
| 95 |
+
[`EvaByteProcessor`] offers all the functionalities of [`EvaByteImageProcessor`] and [`EvaByteTokenizer`]. See the
|
| 96 |
+
[`~EvaByteProcessor.__call__`] and [`~EvaByteProcessor.decode`] for more information.
|
| 97 |
+
|
| 98 |
+
Args:
|
| 99 |
+
image_processor ([`EvaByteImageProcessor`], *optional*):
|
| 100 |
+
The image processor is a required input.
|
| 101 |
+
tokenizer ([`EvaByteTokenizer`], *optional*):
|
| 102 |
+
The tokenizer is a required input.
|
| 103 |
+
"""
|
| 104 |
+
|
| 105 |
+
attributes = ["image_processor", "tokenizer"]
|
| 106 |
+
image_processor_class = "AutoImageProcessor"
|
| 107 |
+
tokenizer_class = "AutoTokenizer"
|
| 108 |
+
|
| 109 |
+
def __init__(self, image_processor=None, tokenizer=None, **kwargs):
|
| 110 |
+
if image_processor is None:
|
| 111 |
+
raise ValueError("You need to specify an `image_processor`.")
|
| 112 |
+
if tokenizer is None:
|
| 113 |
+
raise ValueError("You need to specify a `tokenizer`.")
|
| 114 |
+
|
| 115 |
+
super().__init__(image_processor, tokenizer)
|
| 116 |
+
self.t2v_token_id = self.tokenizer.convert_tokens_to_ids("<t2v_token>")
|
| 117 |
+
self.v2t_token_id = self.tokenizer.convert_tokens_to_ids("<v2t_token>")
|
| 118 |
+
self.image_placeholder = "<image_placeholder>"
|
| 119 |
+
self.vl_chat_template = vl_chat_template
|
| 120 |
+
|
| 121 |
+
def __call__(
|
| 122 |
+
self,
|
| 123 |
+
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
|
| 124 |
+
images: ImageInput = None,
|
| 125 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
| 126 |
+
strip_ending_sentinel: bool = False,
|
| 127 |
+
encode_only: bool = False,
|
| 128 |
+
**kwargs
|
| 129 |
+
) -> Union[BatchFeature, List[List[int]]]:
|
| 130 |
+
# processing pipeline:
|
| 131 |
+
# 1. read images or videos from paths
|
| 132 |
+
# 2. use image_processor to convert images / videos to byte streams
|
| 133 |
+
if images is not None:
|
| 134 |
+
if isinstance(images, bytes):
|
| 135 |
+
image_bytes_list = [[images]]
|
| 136 |
+
elif isinstance(images, list) and isinstance(images[0], bytes):
|
| 137 |
+
image_bytes_list = [images]
|
| 138 |
+
elif isinstance(images, list) and isinstance(images[0], list) and isinstance(images[0][0], bytes):
|
| 139 |
+
image_bytes_list = images
|
| 140 |
+
else:
|
| 141 |
+
if is_image_or_image_url(images):
|
| 142 |
+
images = [[images]]
|
| 143 |
+
elif isinstance(images, list) and is_image_or_image_url(images[0]):
|
| 144 |
+
images = [images]
|
| 145 |
+
elif (
|
| 146 |
+
not isinstance(images, list)
|
| 147 |
+
and not isinstance(images[0], list)
|
| 148 |
+
and not is_image_or_image_url(images[0][0])
|
| 149 |
+
):
|
| 150 |
+
raise ValueError(
|
| 151 |
+
"Invalid input images. Please provide a single image or a list of images or a list of list of images."
|
| 152 |
+
)
|
| 153 |
+
# Load images if they are URLs
|
| 154 |
+
images = [[fetch_image(im) if is_url(im) or is_file(im) else im for im in sample] for sample in images]
|
| 155 |
+
image_bytes_list = self.image_processor(images=images, **kwargs)
|
| 156 |
+
|
| 157 |
+
if not isinstance(text, list):
|
| 158 |
+
text = [text]
|
| 159 |
+
assert len(text) == 1, "Only support batch size 1 for now"
|
| 160 |
+
assert len(text) == len(image_bytes_list), "text and image_bytes_list must have the same length"
|
| 161 |
+
# TODO: invoke SequenceFeatureExtractor to get batched inputs
|
| 162 |
+
|
| 163 |
+
# 3. tokenize the text and put images / videos byte streams into the placeholders
|
| 164 |
+
# surrounded by special tokens like "<image>" and "</image>"
|
| 165 |
+
batch_input_ids = []
|
| 166 |
+
if not encode_only:
|
| 167 |
+
batch_attention_mask = []
|
| 168 |
+
else:
|
| 169 |
+
batch_attention_mask = None
|
| 170 |
+
|
| 171 |
+
for t, image_bytes in zip(text, image_bytes_list):
|
| 172 |
+
text_splits = t.split(self.image_placeholder)
|
| 173 |
+
if len(text_splits) != len(image_bytes) + 1:
|
| 174 |
+
raise ValueError(
|
| 175 |
+
f"The number of image tokens should be equal to the number of images, "
|
| 176 |
+
f"but got {len(text_splits)} and {len(image_bytes) + 1}"
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
input_ids = [self.tokenizer.bos_token_id]
|
| 180 |
+
for i, text_part in enumerate(text_splits):
|
| 181 |
+
# each text part must be non-empty because we added markers around placeholders
|
| 182 |
+
split_tokens = self.tokenizer.encode(text_part, add_special_tokens=False)
|
| 183 |
+
input_ids.extend(split_tokens)
|
| 184 |
+
# Add image bytes after each text part except the last one
|
| 185 |
+
if i < len(image_bytes):
|
| 186 |
+
input_ids.append(self.t2v_token_id)
|
| 187 |
+
input_ids.extend([b + self.tokenizer.offset for b in image_bytes[i]])
|
| 188 |
+
input_ids.append(self.v2t_token_id)
|
| 189 |
+
|
| 190 |
+
if strip_ending_sentinel and (input_ids[-1] in [self.t2v_token_id, self.v2t_token_id]):
|
| 191 |
+
input_ids = input_ids[:-1]
|
| 192 |
+
|
| 193 |
+
batch_input_ids.append(input_ids)
|
| 194 |
+
if not encode_only:
|
| 195 |
+
batch_attention_mask.append([1] * len(input_ids))
|
| 196 |
+
|
| 197 |
+
if not encode_only:
|
| 198 |
+
# 4. return batch of features
|
| 199 |
+
inputs = BatchFeature({
|
| 200 |
+
"input_ids": batch_input_ids,
|
| 201 |
+
"attention_mask": batch_attention_mask
|
| 202 |
+
}, tensor_type=return_tensors)
|
| 203 |
+
return inputs
|
| 204 |
+
# # Pad sequences
|
| 205 |
+
# padded_inputs = self.tokenizer.pad(
|
| 206 |
+
# {"input_ids": batch_input_ids},
|
| 207 |
+
# padding=True,
|
| 208 |
+
# return_attention_mask=True,
|
| 209 |
+
# return_tensors=return_tensors,
|
| 210 |
+
# )
|
| 211 |
+
# return BatchFeature(data=padded_inputs)
|
| 212 |
+
else:
|
| 213 |
+
return batch_input_ids
|
| 214 |
+
|
| 215 |
+
def image_tokens_to_bytes(self, image_token_ids, jpeg_quality=None):
|
| 216 |
+
image_bytes = bytes([token_id - self.tokenizer.offset for token_id in image_token_ids])
|
| 217 |
+
image_bytes = self.image_processor.jpeg_merge_qtables(image_bytes, jpeg_quality)
|
| 218 |
+
return image_bytes
|
| 219 |
+
|
| 220 |
+
def batch_decode(self, sequences, **kwargs):
|
| 221 |
+
"""
|
| 222 |
+
This method forwards all its arguments to EvaByteTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
|
| 223 |
+
refer to the docstring of this method for more information.
|
| 224 |
+
"""
|
| 225 |
+
rets = [self.decode(seq, **kwargs) for seq in sequences]
|
| 226 |
+
return tuple(map(list, zip(*rets)))
|
| 227 |
+
|
| 228 |
+
def decode(self, token_ids, **kwargs):
|
| 229 |
+
"""
|
| 230 |
+
Decodes a sequence of input_ids, handling image tokens separately.
|
| 231 |
+
Returns a tuple of (decoded_text, images), where images is a list of bytes.
|
| 232 |
+
"""
|
| 233 |
+
if kwargs and "jpeg_quality" in kwargs:
|
| 234 |
+
kwargs = kwargs.copy()
|
| 235 |
+
jpeg_quality = kwargs.pop("jpeg_quality")
|
| 236 |
+
else:
|
| 237 |
+
jpeg_quality = None
|
| 238 |
+
|
| 239 |
+
token_ids = to_py_obj(token_ids)
|
| 240 |
+
# Find indices of t2v_token_id and v2t_token_id
|
| 241 |
+
t2v_indices = [i for i, token_id in enumerate(token_ids) if token_id == self.t2v_token_id]
|
| 242 |
+
v2t_indices = [i for i, token_id in enumerate(token_ids) if token_id == self.v2t_token_id]
|
| 243 |
+
|
| 244 |
+
# Check for correct pairing of t2v and v2t tokens
|
| 245 |
+
if len(t2v_indices) != len(v2t_indices):
|
| 246 |
+
raise ValueError("Mismatched number of t2v and v2t tokens in token_ids: {} and {}".format(t2v_indices, v2t_indices))
|
| 247 |
+
|
| 248 |
+
# Ensure t2v and v2t tokens are in the correct order
|
| 249 |
+
for t2v_idx, v2t_idx in zip(t2v_indices, v2t_indices):
|
| 250 |
+
if t2v_idx >= v2t_idx:
|
| 251 |
+
raise ValueError("Found t2v_token_id after v2t_token_id in token_ids")
|
| 252 |
+
|
| 253 |
+
# Initialize the start index
|
| 254 |
+
images = []
|
| 255 |
+
decoded_text = ""
|
| 256 |
+
|
| 257 |
+
start = 0
|
| 258 |
+
# Iterate over pairs of t2v and v2t indices
|
| 259 |
+
for t2v_idx, v2t_idx in zip(t2v_indices, v2t_indices):
|
| 260 |
+
# Decode text tokens before the image
|
| 261 |
+
text_token_ids = token_ids[start:t2v_idx]
|
| 262 |
+
if len(text_token_ids) > 0:
|
| 263 |
+
decoded_text += self.tokenizer.decode(text_token_ids, **kwargs)
|
| 264 |
+
|
| 265 |
+
# Insert image placeholder
|
| 266 |
+
decoded_text += self.image_placeholder
|
| 267 |
+
|
| 268 |
+
# Extract image tokens and convert them to bytes
|
| 269 |
+
image_token_ids = token_ids[t2v_idx + 1 : v2t_idx]
|
| 270 |
+
image_bytes = self.image_tokens_to_bytes(image_token_ids, jpeg_quality)
|
| 271 |
+
images.append(image_bytes)
|
| 272 |
+
|
| 273 |
+
# Update the start index to the token after v2t_token_id
|
| 274 |
+
start = v2t_idx + 1
|
| 275 |
+
|
| 276 |
+
# Decode any remaining text tokens after the last image
|
| 277 |
+
if start < len(token_ids):
|
| 278 |
+
text_token_ids = token_ids[start:]
|
| 279 |
+
decoded_text += self.tokenizer.decode(text_token_ids, **kwargs)
|
| 280 |
+
|
| 281 |
+
return decoded_text, images
|
| 282 |
+
|
| 283 |
+
@property
|
| 284 |
+
def model_input_names(self):
|
| 285 |
+
tokenizer_input_names = self.tokenizer.model_input_names
|
| 286 |
+
image_processor_input_names = self.image_processor.model_input_names
|
| 287 |
+
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/processor_config.json
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"auto_map": {
|
| 3 |
+
"AutoProcessor": "processing_evabyte.EvaByteProcessor"
|
| 4 |
+
},
|
| 5 |
+
"processor_class": "EvaByteProcessor"
|
| 6 |
+
}
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/special_tokens_map.json
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"additional_special_tokens": [
|
| 3 |
+
"<repo_name>",
|
| 4 |
+
"<file_sep>",
|
| 5 |
+
"<t2v_token>",
|
| 6 |
+
"<v2t_token>",
|
| 7 |
+
"<|start_header_id|>",
|
| 8 |
+
"<|end_header_id|>",
|
| 9 |
+
"<|eot_id|>",
|
| 10 |
+
"<extra_id_12>",
|
| 11 |
+
"<extra_id_13>",
|
| 12 |
+
"<extra_id_14>",
|
| 13 |
+
"<extra_id_15>",
|
| 14 |
+
"<extra_id_16>",
|
| 15 |
+
"<extra_id_17>",
|
| 16 |
+
"<extra_id_18>",
|
| 17 |
+
"<extra_id_19>",
|
| 18 |
+
"<extra_id_20>",
|
| 19 |
+
"<extra_id_21>",
|
| 20 |
+
"<extra_id_22>",
|
| 21 |
+
"<extra_id_23>",
|
| 22 |
+
"<extra_id_24>",
|
| 23 |
+
"<extra_id_25>",
|
| 24 |
+
"<extra_id_26>",
|
| 25 |
+
"<extra_id_27>",
|
| 26 |
+
"<extra_id_28>",
|
| 27 |
+
"<extra_id_29>",
|
| 28 |
+
"<extra_id_30>",
|
| 29 |
+
"<extra_id_31>",
|
| 30 |
+
"<extra_id_32>",
|
| 31 |
+
"<extra_id_33>",
|
| 32 |
+
"<extra_id_34>",
|
| 33 |
+
"<extra_id_35>",
|
| 34 |
+
"<extra_id_36>",
|
| 35 |
+
"<extra_id_37>",
|
| 36 |
+
"<extra_id_38>",
|
| 37 |
+
"<extra_id_39>",
|
| 38 |
+
"<extra_id_40>",
|
| 39 |
+
"<extra_id_41>",
|
| 40 |
+
"<extra_id_42>",
|
| 41 |
+
"<extra_id_43>",
|
| 42 |
+
"<extra_id_44>",
|
| 43 |
+
"<extra_id_45>",
|
| 44 |
+
"<extra_id_46>",
|
| 45 |
+
"<extra_id_47>",
|
| 46 |
+
"<extra_id_48>",
|
| 47 |
+
"<extra_id_49>",
|
| 48 |
+
"<extra_id_50>",
|
| 49 |
+
"<extra_id_51>",
|
| 50 |
+
"<extra_id_52>",
|
| 51 |
+
"<extra_id_53>",
|
| 52 |
+
"<extra_id_54>",
|
| 53 |
+
"<extra_id_55>",
|
| 54 |
+
"<extra_id_56>",
|
| 55 |
+
"<extra_id_57>",
|
| 56 |
+
"<extra_id_58>",
|
| 57 |
+
"<extra_id_59>",
|
| 58 |
+
"<extra_id_60>",
|
| 59 |
+
"<extra_id_61>",
|
| 60 |
+
"<extra_id_62>",
|
| 61 |
+
"<extra_id_63>"
|
| 62 |
+
],
|
| 63 |
+
"bos_token": {
|
| 64 |
+
"content": "<bos>",
|
| 65 |
+
"lstrip": false,
|
| 66 |
+
"normalized": true,
|
| 67 |
+
"rstrip": false,
|
| 68 |
+
"single_word": false
|
| 69 |
+
},
|
| 70 |
+
"eos_token": {
|
| 71 |
+
"content": "<eos>",
|
| 72 |
+
"lstrip": false,
|
| 73 |
+
"normalized": true,
|
| 74 |
+
"rstrip": false,
|
| 75 |
+
"single_word": false
|
| 76 |
+
},
|
| 77 |
+
"pad_token": {
|
| 78 |
+
"content": "<pad>",
|
| 79 |
+
"lstrip": false,
|
| 80 |
+
"normalized": true,
|
| 81 |
+
"rstrip": false,
|
| 82 |
+
"single_word": false
|
| 83 |
+
},
|
| 84 |
+
"sep_token": {
|
| 85 |
+
"content": "<sep>",
|
| 86 |
+
"lstrip": false,
|
| 87 |
+
"normalized": true,
|
| 88 |
+
"rstrip": false,
|
| 89 |
+
"single_word": false
|
| 90 |
+
},
|
| 91 |
+
"unk_token": {
|
| 92 |
+
"content": "<unk>",
|
| 93 |
+
"lstrip": false,
|
| 94 |
+
"normalized": true,
|
| 95 |
+
"rstrip": false,
|
| 96 |
+
"single_word": false
|
| 97 |
+
}
|
| 98 |
+
}
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/tokenization_evabyte.py
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
|
| 3 |
+
""" Tokenization class for model EvaByte."""
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
from typing import List, Optional, Tuple
|
| 7 |
+
|
| 8 |
+
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
|
| 9 |
+
from transformers.utils import logging
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
logger = logging.get_logger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
chat_template = """
|
| 16 |
+
{{- bos_token }}
|
| 17 |
+
{%- if messages[0]['role'] == 'system' %}
|
| 18 |
+
{%- set system_message = messages[0]['content'] %}
|
| 19 |
+
{%- set messages = messages[1:] %}
|
| 20 |
+
{%- else %}
|
| 21 |
+
{%- set system_message = "" %}
|
| 22 |
+
{%- endif %}
|
| 23 |
+
|
| 24 |
+
{{- '<|start_header_id|>system<|end_header_id|>\n\n' + system_message + '<|eot_id|>'}}
|
| 25 |
+
|
| 26 |
+
{%- for message in messages %}
|
| 27 |
+
{%- if (message['role'] != 'user') and (message['role'] != 'assistant') %}
|
| 28 |
+
{{- raise_exception('Conversation roles must be user or assistant') }}
|
| 29 |
+
{%- endif %}
|
| 30 |
+
|
| 31 |
+
{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] + '<|eot_id|>' }}
|
| 32 |
+
{%- endfor %}
|
| 33 |
+
|
| 34 |
+
{%- if add_generation_prompt %}
|
| 35 |
+
{{- '<|start_header_id|>' + 'assistant' + '<|end_header_id|>\n\n' }}
|
| 36 |
+
{%- endif %}
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
class EvaByteTokenizer(PreTrainedTokenizer):
|
| 40 |
+
def __init__(
|
| 41 |
+
self,
|
| 42 |
+
bos_token="<bos>",
|
| 43 |
+
eos_token="<eos>",
|
| 44 |
+
unk_token="<unk>",
|
| 45 |
+
sep_token="<sep>",
|
| 46 |
+
pad_token="<pad>",
|
| 47 |
+
extra_ids=59,
|
| 48 |
+
additional_special_tokens=None,
|
| 49 |
+
clean_up_tokenization_spaces=False,
|
| 50 |
+
**kwargs,
|
| 51 |
+
) -> None:
|
| 52 |
+
num_base_special_tokens = 5
|
| 53 |
+
# Add extra_ids to the special token list
|
| 54 |
+
if extra_ids > 0 and additional_special_tokens is None:
|
| 55 |
+
additional_special_tokens = [f"<extra_id_{i}>" for i in range(num_base_special_tokens, extra_ids + num_base_special_tokens)]
|
| 56 |
+
elif extra_ids > 0 and additional_special_tokens is not None and len(additional_special_tokens) > 0:
|
| 57 |
+
# Check that we have the right number of extra_id special tokens
|
| 58 |
+
extra_tokens = len(set(filter(lambda x: bool("extra_id" in str(x)), additional_special_tokens)))
|
| 59 |
+
if extra_tokens != extra_ids:
|
| 60 |
+
raise ValueError(
|
| 61 |
+
f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are"
|
| 62 |
+
" provided to EvaByteTokenizer. In this case the additional_special_tokens must include the"
|
| 63 |
+
" extra_ids tokens"
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
#### override some reserved tokens to support chat template
|
| 67 |
+
for i, token in enumerate(additional_special_tokens):
|
| 68 |
+
if token == "<extra_id_5>":
|
| 69 |
+
token = "<repo_name>"
|
| 70 |
+
elif token == "<extra_id_6>":
|
| 71 |
+
token = "<file_sep>"
|
| 72 |
+
elif token == "<extra_id_7>":
|
| 73 |
+
token = "<t2v_token>"
|
| 74 |
+
elif token == "<extra_id_8>":
|
| 75 |
+
token = "<v2t_token>"
|
| 76 |
+
elif token == "<extra_id_9>":
|
| 77 |
+
token = "<|start_header_id|>"
|
| 78 |
+
elif token == "<extra_id_10>":
|
| 79 |
+
token = "<|end_header_id|>"
|
| 80 |
+
elif token == "<extra_id_11>":
|
| 81 |
+
token = "<|eot_id|>"
|
| 82 |
+
additional_special_tokens[i] = token
|
| 83 |
+
|
| 84 |
+
# lstrip and rstrip are set to False because we don't want to strip the whitespace from the special tokens
|
| 85 |
+
# this would be important for the byte tokenizer
|
| 86 |
+
pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
|
| 87 |
+
bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
|
| 88 |
+
eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
|
| 89 |
+
unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
|
| 90 |
+
sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token
|
| 91 |
+
|
| 92 |
+
self._added_tokens_decoder = {
|
| 93 |
+
0: pad_token,
|
| 94 |
+
1: bos_token,
|
| 95 |
+
2: eos_token,
|
| 96 |
+
3: unk_token, # unk_token is a placeholder
|
| 97 |
+
4: sep_token,
|
| 98 |
+
**{i: AddedToken(t, lstrip=False, rstrip=False) for i, t in enumerate(additional_special_tokens, start=num_base_special_tokens)},
|
| 99 |
+
}
|
| 100 |
+
self.offset = len(self._added_tokens_decoder)
|
| 101 |
+
self._utf_vocab_size = 2**8 # utf is 8 bits
|
| 102 |
+
self.add_bos_token = True
|
| 103 |
+
self.add_eos_token = False
|
| 104 |
+
super().__init__(
|
| 105 |
+
pad_token=pad_token,
|
| 106 |
+
bos_token=bos_token,
|
| 107 |
+
eos_token=eos_token,
|
| 108 |
+
unk_token=unk_token,
|
| 109 |
+
sep_token=sep_token,
|
| 110 |
+
extra_ids=0,
|
| 111 |
+
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
| 112 |
+
additional_special_tokens=additional_special_tokens,
|
| 113 |
+
**kwargs,
|
| 114 |
+
)
|
| 115 |
+
self.chat_template = chat_template
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
@property
|
| 119 |
+
def vocab_size(self):
|
| 120 |
+
return self._utf_vocab_size
|
| 121 |
+
|
| 122 |
+
def get_vocab(self):
|
| 123 |
+
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size + self.offset)}
|
| 124 |
+
vocab.update(self.added_tokens_encoder)
|
| 125 |
+
return vocab
|
| 126 |
+
|
| 127 |
+
# Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.build_inputs_with_special_tokens
|
| 128 |
+
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
|
| 129 |
+
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
|
| 130 |
+
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
|
| 131 |
+
|
| 132 |
+
output = bos_token_id + token_ids_0 + eos_token_id
|
| 133 |
+
|
| 134 |
+
if token_ids_1 is not None:
|
| 135 |
+
output = output + bos_token_id + token_ids_1 + eos_token_id
|
| 136 |
+
|
| 137 |
+
return output
|
| 138 |
+
|
| 139 |
+
# Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.get_special_tokens_mask
|
| 140 |
+
def get_special_tokens_mask(
|
| 141 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
| 142 |
+
) -> List[int]:
|
| 143 |
+
"""
|
| 144 |
+
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
|
| 145 |
+
special tokens using the tokenizer `prepare_for_model` method.
|
| 146 |
+
|
| 147 |
+
Args:
|
| 148 |
+
token_ids_0 (`List[int]`):
|
| 149 |
+
List of IDs.
|
| 150 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 151 |
+
Optional second list of IDs for sequence pairs.
|
| 152 |
+
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
|
| 153 |
+
Whether or not the token list is already formatted with special tokens for the model.
|
| 154 |
+
|
| 155 |
+
Returns:
|
| 156 |
+
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
| 157 |
+
"""
|
| 158 |
+
if already_has_special_tokens:
|
| 159 |
+
return super().get_special_tokens_mask(
|
| 160 |
+
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
bos_token_id = [1] if self.add_bos_token else []
|
| 164 |
+
eos_token_id = [1] if self.add_eos_token else []
|
| 165 |
+
|
| 166 |
+
if token_ids_1 is None:
|
| 167 |
+
return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
|
| 168 |
+
return (
|
| 169 |
+
bos_token_id
|
| 170 |
+
+ ([0] * len(token_ids_0))
|
| 171 |
+
+ eos_token_id
|
| 172 |
+
+ bos_token_id
|
| 173 |
+
+ ([0] * len(token_ids_1))
|
| 174 |
+
+ eos_token_id
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
# Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.create_token_type_ids_from_sequences
|
| 178 |
+
def create_token_type_ids_from_sequences(
|
| 179 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
| 180 |
+
) -> List[int]:
|
| 181 |
+
"""
|
| 182 |
+
Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
|
| 183 |
+
sequence pair mask has the following format:
|
| 184 |
+
|
| 185 |
+
```
|
| 186 |
+
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
|
| 187 |
+
| first sequence | second sequence |
|
| 188 |
+
```
|
| 189 |
+
|
| 190 |
+
if token_ids_1 is None, only returns the first portion of the mask (0s).
|
| 191 |
+
|
| 192 |
+
Args:
|
| 193 |
+
token_ids_0 (`List[int]`):
|
| 194 |
+
List of ids.
|
| 195 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 196 |
+
Optional second list of IDs for sequence pairs.
|
| 197 |
+
|
| 198 |
+
Returns:
|
| 199 |
+
`List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
|
| 200 |
+
"""
|
| 201 |
+
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
|
| 202 |
+
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
|
| 203 |
+
|
| 204 |
+
output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
|
| 205 |
+
|
| 206 |
+
if token_ids_1 is not None:
|
| 207 |
+
output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
|
| 208 |
+
|
| 209 |
+
return output
|
| 210 |
+
|
| 211 |
+
def _tokenize(self, text: str) -> List[str]:
|
| 212 |
+
"""Take as input a string and return a list of strings (tokens) for words/sub-words"""
|
| 213 |
+
tokens = [chr(i) for i in text.encode("utf-8")]
|
| 214 |
+
return tokens
|
| 215 |
+
|
| 216 |
+
def _convert_token_to_id(self, token):
|
| 217 |
+
"""Converts a token (str) in an id using the vocab."""
|
| 218 |
+
|
| 219 |
+
if len(token) != 1:
|
| 220 |
+
token_id = None
|
| 221 |
+
else:
|
| 222 |
+
token_id = ord(token) + self.offset
|
| 223 |
+
|
| 224 |
+
return token_id
|
| 225 |
+
|
| 226 |
+
def _convert_id_to_token(self, index):
|
| 227 |
+
"""Converts an index (integer) to a byte (str) using the vocab."""
|
| 228 |
+
token = chr(index - self.offset)
|
| 229 |
+
return token
|
| 230 |
+
|
| 231 |
+
def convert_tokens_to_string(self, tokens):
|
| 232 |
+
"""Converts a sequence of bytes (string) to a single string."""
|
| 233 |
+
bstring = b""
|
| 234 |
+
for token in tokens:
|
| 235 |
+
if token in self.added_tokens_decoder:
|
| 236 |
+
tok_string = self.added_tokens_decoder[token].encode("utf-8")
|
| 237 |
+
elif token in self.added_tokens_encoder:
|
| 238 |
+
tok_string = token.encode("utf-8")
|
| 239 |
+
else:
|
| 240 |
+
tok_string = bytes([ord(token)])
|
| 241 |
+
bstring += tok_string
|
| 242 |
+
string = bstring.decode("utf-8", errors="ignore")
|
| 243 |
+
return string
|
| 244 |
+
|
| 245 |
+
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
| 246 |
+
return ()
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/tokenizer_config.json
ADDED
|
@@ -0,0 +1,596 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"added_tokens_decoder": {
|
| 3 |
+
"0": {
|
| 4 |
+
"content": "<pad>",
|
| 5 |
+
"lstrip": false,
|
| 6 |
+
"normalized": true,
|
| 7 |
+
"rstrip": false,
|
| 8 |
+
"single_word": false,
|
| 9 |
+
"special": true
|
| 10 |
+
},
|
| 11 |
+
"1": {
|
| 12 |
+
"content": "<bos>",
|
| 13 |
+
"lstrip": false,
|
| 14 |
+
"normalized": true,
|
| 15 |
+
"rstrip": false,
|
| 16 |
+
"single_word": false,
|
| 17 |
+
"special": true
|
| 18 |
+
},
|
| 19 |
+
"2": {
|
| 20 |
+
"content": "<eos>",
|
| 21 |
+
"lstrip": false,
|
| 22 |
+
"normalized": true,
|
| 23 |
+
"rstrip": false,
|
| 24 |
+
"single_word": false,
|
| 25 |
+
"special": true
|
| 26 |
+
},
|
| 27 |
+
"3": {
|
| 28 |
+
"content": "<unk>",
|
| 29 |
+
"lstrip": false,
|
| 30 |
+
"normalized": true,
|
| 31 |
+
"rstrip": false,
|
| 32 |
+
"single_word": false,
|
| 33 |
+
"special": true
|
| 34 |
+
},
|
| 35 |
+
"4": {
|
| 36 |
+
"content": "<sep>",
|
| 37 |
+
"lstrip": false,
|
| 38 |
+
"normalized": true,
|
| 39 |
+
"rstrip": false,
|
| 40 |
+
"single_word": false,
|
| 41 |
+
"special": true
|
| 42 |
+
},
|
| 43 |
+
"5": {
|
| 44 |
+
"content": "<repo_name>",
|
| 45 |
+
"lstrip": false,
|
| 46 |
+
"normalized": true,
|
| 47 |
+
"rstrip": false,
|
| 48 |
+
"single_word": false,
|
| 49 |
+
"special": false
|
| 50 |
+
},
|
| 51 |
+
"6": {
|
| 52 |
+
"content": "<file_sep>",
|
| 53 |
+
"lstrip": false,
|
| 54 |
+
"normalized": true,
|
| 55 |
+
"rstrip": false,
|
| 56 |
+
"single_word": false,
|
| 57 |
+
"special": false
|
| 58 |
+
},
|
| 59 |
+
"7": {
|
| 60 |
+
"content": "<t2v_token>",
|
| 61 |
+
"lstrip": false,
|
| 62 |
+
"normalized": true,
|
| 63 |
+
"rstrip": false,
|
| 64 |
+
"single_word": false,
|
| 65 |
+
"special": false
|
| 66 |
+
},
|
| 67 |
+
"8": {
|
| 68 |
+
"content": "<v2t_token>",
|
| 69 |
+
"lstrip": false,
|
| 70 |
+
"normalized": true,
|
| 71 |
+
"rstrip": false,
|
| 72 |
+
"single_word": false,
|
| 73 |
+
"special": false
|
| 74 |
+
},
|
| 75 |
+
"9": {
|
| 76 |
+
"content": "<|start_header_id|>",
|
| 77 |
+
"lstrip": false,
|
| 78 |
+
"normalized": true,
|
| 79 |
+
"rstrip": false,
|
| 80 |
+
"single_word": false,
|
| 81 |
+
"special": false
|
| 82 |
+
},
|
| 83 |
+
"10": {
|
| 84 |
+
"content": "<|end_header_id|>",
|
| 85 |
+
"lstrip": false,
|
| 86 |
+
"normalized": true,
|
| 87 |
+
"rstrip": false,
|
| 88 |
+
"single_word": false,
|
| 89 |
+
"special": false
|
| 90 |
+
},
|
| 91 |
+
"11": {
|
| 92 |
+
"content": "<|eot_id|>",
|
| 93 |
+
"lstrip": false,
|
| 94 |
+
"normalized": true,
|
| 95 |
+
"rstrip": false,
|
| 96 |
+
"single_word": false,
|
| 97 |
+
"special": false
|
| 98 |
+
},
|
| 99 |
+
"12": {
|
| 100 |
+
"content": "<extra_id_12>",
|
| 101 |
+
"lstrip": false,
|
| 102 |
+
"normalized": true,
|
| 103 |
+
"rstrip": false,
|
| 104 |
+
"single_word": false,
|
| 105 |
+
"special": false
|
| 106 |
+
},
|
| 107 |
+
"13": {
|
| 108 |
+
"content": "<extra_id_13>",
|
| 109 |
+
"lstrip": false,
|
| 110 |
+
"normalized": true,
|
| 111 |
+
"rstrip": false,
|
| 112 |
+
"single_word": false,
|
| 113 |
+
"special": false
|
| 114 |
+
},
|
| 115 |
+
"14": {
|
| 116 |
+
"content": "<extra_id_14>",
|
| 117 |
+
"lstrip": false,
|
| 118 |
+
"normalized": true,
|
| 119 |
+
"rstrip": false,
|
| 120 |
+
"single_word": false,
|
| 121 |
+
"special": false
|
| 122 |
+
},
|
| 123 |
+
"15": {
|
| 124 |
+
"content": "<extra_id_15>",
|
| 125 |
+
"lstrip": false,
|
| 126 |
+
"normalized": true,
|
| 127 |
+
"rstrip": false,
|
| 128 |
+
"single_word": false,
|
| 129 |
+
"special": false
|
| 130 |
+
},
|
| 131 |
+
"16": {
|
| 132 |
+
"content": "<extra_id_16>",
|
| 133 |
+
"lstrip": false,
|
| 134 |
+
"normalized": true,
|
| 135 |
+
"rstrip": false,
|
| 136 |
+
"single_word": false,
|
| 137 |
+
"special": false
|
| 138 |
+
},
|
| 139 |
+
"17": {
|
| 140 |
+
"content": "<extra_id_17>",
|
| 141 |
+
"lstrip": false,
|
| 142 |
+
"normalized": true,
|
| 143 |
+
"rstrip": false,
|
| 144 |
+
"single_word": false,
|
| 145 |
+
"special": false
|
| 146 |
+
},
|
| 147 |
+
"18": {
|
| 148 |
+
"content": "<extra_id_18>",
|
| 149 |
+
"lstrip": false,
|
| 150 |
+
"normalized": true,
|
| 151 |
+
"rstrip": false,
|
| 152 |
+
"single_word": false,
|
| 153 |
+
"special": false
|
| 154 |
+
},
|
| 155 |
+
"19": {
|
| 156 |
+
"content": "<extra_id_19>",
|
| 157 |
+
"lstrip": false,
|
| 158 |
+
"normalized": true,
|
| 159 |
+
"rstrip": false,
|
| 160 |
+
"single_word": false,
|
| 161 |
+
"special": false
|
| 162 |
+
},
|
| 163 |
+
"20": {
|
| 164 |
+
"content": "<extra_id_20>",
|
| 165 |
+
"lstrip": false,
|
| 166 |
+
"normalized": true,
|
| 167 |
+
"rstrip": false,
|
| 168 |
+
"single_word": false,
|
| 169 |
+
"special": false
|
| 170 |
+
},
|
| 171 |
+
"21": {
|
| 172 |
+
"content": "<extra_id_21>",
|
| 173 |
+
"lstrip": false,
|
| 174 |
+
"normalized": true,
|
| 175 |
+
"rstrip": false,
|
| 176 |
+
"single_word": false,
|
| 177 |
+
"special": false
|
| 178 |
+
},
|
| 179 |
+
"22": {
|
| 180 |
+
"content": "<extra_id_22>",
|
| 181 |
+
"lstrip": false,
|
| 182 |
+
"normalized": true,
|
| 183 |
+
"rstrip": false,
|
| 184 |
+
"single_word": false,
|
| 185 |
+
"special": false
|
| 186 |
+
},
|
| 187 |
+
"23": {
|
| 188 |
+
"content": "<extra_id_23>",
|
| 189 |
+
"lstrip": false,
|
| 190 |
+
"normalized": true,
|
| 191 |
+
"rstrip": false,
|
| 192 |
+
"single_word": false,
|
| 193 |
+
"special": false
|
| 194 |
+
},
|
| 195 |
+
"24": {
|
| 196 |
+
"content": "<extra_id_24>",
|
| 197 |
+
"lstrip": false,
|
| 198 |
+
"normalized": true,
|
| 199 |
+
"rstrip": false,
|
| 200 |
+
"single_word": false,
|
| 201 |
+
"special": false
|
| 202 |
+
},
|
| 203 |
+
"25": {
|
| 204 |
+
"content": "<extra_id_25>",
|
| 205 |
+
"lstrip": false,
|
| 206 |
+
"normalized": true,
|
| 207 |
+
"rstrip": false,
|
| 208 |
+
"single_word": false,
|
| 209 |
+
"special": false
|
| 210 |
+
},
|
| 211 |
+
"26": {
|
| 212 |
+
"content": "<extra_id_26>",
|
| 213 |
+
"lstrip": false,
|
| 214 |
+
"normalized": true,
|
| 215 |
+
"rstrip": false,
|
| 216 |
+
"single_word": false,
|
| 217 |
+
"special": false
|
| 218 |
+
},
|
| 219 |
+
"27": {
|
| 220 |
+
"content": "<extra_id_27>",
|
| 221 |
+
"lstrip": false,
|
| 222 |
+
"normalized": true,
|
| 223 |
+
"rstrip": false,
|
| 224 |
+
"single_word": false,
|
| 225 |
+
"special": false
|
| 226 |
+
},
|
| 227 |
+
"28": {
|
| 228 |
+
"content": "<extra_id_28>",
|
| 229 |
+
"lstrip": false,
|
| 230 |
+
"normalized": true,
|
| 231 |
+
"rstrip": false,
|
| 232 |
+
"single_word": false,
|
| 233 |
+
"special": false
|
| 234 |
+
},
|
| 235 |
+
"29": {
|
| 236 |
+
"content": "<extra_id_29>",
|
| 237 |
+
"lstrip": false,
|
| 238 |
+
"normalized": true,
|
| 239 |
+
"rstrip": false,
|
| 240 |
+
"single_word": false,
|
| 241 |
+
"special": false
|
| 242 |
+
},
|
| 243 |
+
"30": {
|
| 244 |
+
"content": "<extra_id_30>",
|
| 245 |
+
"lstrip": false,
|
| 246 |
+
"normalized": true,
|
| 247 |
+
"rstrip": false,
|
| 248 |
+
"single_word": false,
|
| 249 |
+
"special": false
|
| 250 |
+
},
|
| 251 |
+
"31": {
|
| 252 |
+
"content": "<extra_id_31>",
|
| 253 |
+
"lstrip": false,
|
| 254 |
+
"normalized": true,
|
| 255 |
+
"rstrip": false,
|
| 256 |
+
"single_word": false,
|
| 257 |
+
"special": false
|
| 258 |
+
},
|
| 259 |
+
"32": {
|
| 260 |
+
"content": "<extra_id_32>",
|
| 261 |
+
"lstrip": false,
|
| 262 |
+
"normalized": true,
|
| 263 |
+
"rstrip": false,
|
| 264 |
+
"single_word": false,
|
| 265 |
+
"special": false
|
| 266 |
+
},
|
| 267 |
+
"33": {
|
| 268 |
+
"content": "<extra_id_33>",
|
| 269 |
+
"lstrip": false,
|
| 270 |
+
"normalized": true,
|
| 271 |
+
"rstrip": false,
|
| 272 |
+
"single_word": false,
|
| 273 |
+
"special": false
|
| 274 |
+
},
|
| 275 |
+
"34": {
|
| 276 |
+
"content": "<extra_id_34>",
|
| 277 |
+
"lstrip": false,
|
| 278 |
+
"normalized": true,
|
| 279 |
+
"rstrip": false,
|
| 280 |
+
"single_word": false,
|
| 281 |
+
"special": false
|
| 282 |
+
},
|
| 283 |
+
"35": {
|
| 284 |
+
"content": "<extra_id_35>",
|
| 285 |
+
"lstrip": false,
|
| 286 |
+
"normalized": true,
|
| 287 |
+
"rstrip": false,
|
| 288 |
+
"single_word": false,
|
| 289 |
+
"special": false
|
| 290 |
+
},
|
| 291 |
+
"36": {
|
| 292 |
+
"content": "<extra_id_36>",
|
| 293 |
+
"lstrip": false,
|
| 294 |
+
"normalized": true,
|
| 295 |
+
"rstrip": false,
|
| 296 |
+
"single_word": false,
|
| 297 |
+
"special": false
|
| 298 |
+
},
|
| 299 |
+
"37": {
|
| 300 |
+
"content": "<extra_id_37>",
|
| 301 |
+
"lstrip": false,
|
| 302 |
+
"normalized": true,
|
| 303 |
+
"rstrip": false,
|
| 304 |
+
"single_word": false,
|
| 305 |
+
"special": false
|
| 306 |
+
},
|
| 307 |
+
"38": {
|
| 308 |
+
"content": "<extra_id_38>",
|
| 309 |
+
"lstrip": false,
|
| 310 |
+
"normalized": true,
|
| 311 |
+
"rstrip": false,
|
| 312 |
+
"single_word": false,
|
| 313 |
+
"special": false
|
| 314 |
+
},
|
| 315 |
+
"39": {
|
| 316 |
+
"content": "<extra_id_39>",
|
| 317 |
+
"lstrip": false,
|
| 318 |
+
"normalized": true,
|
| 319 |
+
"rstrip": false,
|
| 320 |
+
"single_word": false,
|
| 321 |
+
"special": false
|
| 322 |
+
},
|
| 323 |
+
"40": {
|
| 324 |
+
"content": "<extra_id_40>",
|
| 325 |
+
"lstrip": false,
|
| 326 |
+
"normalized": true,
|
| 327 |
+
"rstrip": false,
|
| 328 |
+
"single_word": false,
|
| 329 |
+
"special": false
|
| 330 |
+
},
|
| 331 |
+
"41": {
|
| 332 |
+
"content": "<extra_id_41>",
|
| 333 |
+
"lstrip": false,
|
| 334 |
+
"normalized": true,
|
| 335 |
+
"rstrip": false,
|
| 336 |
+
"single_word": false,
|
| 337 |
+
"special": false
|
| 338 |
+
},
|
| 339 |
+
"42": {
|
| 340 |
+
"content": "<extra_id_42>",
|
| 341 |
+
"lstrip": false,
|
| 342 |
+
"normalized": true,
|
| 343 |
+
"rstrip": false,
|
| 344 |
+
"single_word": false,
|
| 345 |
+
"special": false
|
| 346 |
+
},
|
| 347 |
+
"43": {
|
| 348 |
+
"content": "<extra_id_43>",
|
| 349 |
+
"lstrip": false,
|
| 350 |
+
"normalized": true,
|
| 351 |
+
"rstrip": false,
|
| 352 |
+
"single_word": false,
|
| 353 |
+
"special": false
|
| 354 |
+
},
|
| 355 |
+
"44": {
|
| 356 |
+
"content": "<extra_id_44>",
|
| 357 |
+
"lstrip": false,
|
| 358 |
+
"normalized": true,
|
| 359 |
+
"rstrip": false,
|
| 360 |
+
"single_word": false,
|
| 361 |
+
"special": false
|
| 362 |
+
},
|
| 363 |
+
"45": {
|
| 364 |
+
"content": "<extra_id_45>",
|
| 365 |
+
"lstrip": false,
|
| 366 |
+
"normalized": true,
|
| 367 |
+
"rstrip": false,
|
| 368 |
+
"single_word": false,
|
| 369 |
+
"special": false
|
| 370 |
+
},
|
| 371 |
+
"46": {
|
| 372 |
+
"content": "<extra_id_46>",
|
| 373 |
+
"lstrip": false,
|
| 374 |
+
"normalized": true,
|
| 375 |
+
"rstrip": false,
|
| 376 |
+
"single_word": false,
|
| 377 |
+
"special": false
|
| 378 |
+
},
|
| 379 |
+
"47": {
|
| 380 |
+
"content": "<extra_id_47>",
|
| 381 |
+
"lstrip": false,
|
| 382 |
+
"normalized": true,
|
| 383 |
+
"rstrip": false,
|
| 384 |
+
"single_word": false,
|
| 385 |
+
"special": false
|
| 386 |
+
},
|
| 387 |
+
"48": {
|
| 388 |
+
"content": "<extra_id_48>",
|
| 389 |
+
"lstrip": false,
|
| 390 |
+
"normalized": true,
|
| 391 |
+
"rstrip": false,
|
| 392 |
+
"single_word": false,
|
| 393 |
+
"special": false
|
| 394 |
+
},
|
| 395 |
+
"49": {
|
| 396 |
+
"content": "<extra_id_49>",
|
| 397 |
+
"lstrip": false,
|
| 398 |
+
"normalized": true,
|
| 399 |
+
"rstrip": false,
|
| 400 |
+
"single_word": false,
|
| 401 |
+
"special": false
|
| 402 |
+
},
|
| 403 |
+
"50": {
|
| 404 |
+
"content": "<extra_id_50>",
|
| 405 |
+
"lstrip": false,
|
| 406 |
+
"normalized": true,
|
| 407 |
+
"rstrip": false,
|
| 408 |
+
"single_word": false,
|
| 409 |
+
"special": false
|
| 410 |
+
},
|
| 411 |
+
"51": {
|
| 412 |
+
"content": "<extra_id_51>",
|
| 413 |
+
"lstrip": false,
|
| 414 |
+
"normalized": true,
|
| 415 |
+
"rstrip": false,
|
| 416 |
+
"single_word": false,
|
| 417 |
+
"special": false
|
| 418 |
+
},
|
| 419 |
+
"52": {
|
| 420 |
+
"content": "<extra_id_52>",
|
| 421 |
+
"lstrip": false,
|
| 422 |
+
"normalized": true,
|
| 423 |
+
"rstrip": false,
|
| 424 |
+
"single_word": false,
|
| 425 |
+
"special": false
|
| 426 |
+
},
|
| 427 |
+
"53": {
|
| 428 |
+
"content": "<extra_id_53>",
|
| 429 |
+
"lstrip": false,
|
| 430 |
+
"normalized": true,
|
| 431 |
+
"rstrip": false,
|
| 432 |
+
"single_word": false,
|
| 433 |
+
"special": false
|
| 434 |
+
},
|
| 435 |
+
"54": {
|
| 436 |
+
"content": "<extra_id_54>",
|
| 437 |
+
"lstrip": false,
|
| 438 |
+
"normalized": true,
|
| 439 |
+
"rstrip": false,
|
| 440 |
+
"single_word": false,
|
| 441 |
+
"special": false
|
| 442 |
+
},
|
| 443 |
+
"55": {
|
| 444 |
+
"content": "<extra_id_55>",
|
| 445 |
+
"lstrip": false,
|
| 446 |
+
"normalized": true,
|
| 447 |
+
"rstrip": false,
|
| 448 |
+
"single_word": false,
|
| 449 |
+
"special": false
|
| 450 |
+
},
|
| 451 |
+
"56": {
|
| 452 |
+
"content": "<extra_id_56>",
|
| 453 |
+
"lstrip": false,
|
| 454 |
+
"normalized": true,
|
| 455 |
+
"rstrip": false,
|
| 456 |
+
"single_word": false,
|
| 457 |
+
"special": false
|
| 458 |
+
},
|
| 459 |
+
"57": {
|
| 460 |
+
"content": "<extra_id_57>",
|
| 461 |
+
"lstrip": false,
|
| 462 |
+
"normalized": true,
|
| 463 |
+
"rstrip": false,
|
| 464 |
+
"single_word": false,
|
| 465 |
+
"special": false
|
| 466 |
+
},
|
| 467 |
+
"58": {
|
| 468 |
+
"content": "<extra_id_58>",
|
| 469 |
+
"lstrip": false,
|
| 470 |
+
"normalized": true,
|
| 471 |
+
"rstrip": false,
|
| 472 |
+
"single_word": false,
|
| 473 |
+
"special": false
|
| 474 |
+
},
|
| 475 |
+
"59": {
|
| 476 |
+
"content": "<extra_id_59>",
|
| 477 |
+
"lstrip": false,
|
| 478 |
+
"normalized": true,
|
| 479 |
+
"rstrip": false,
|
| 480 |
+
"single_word": false,
|
| 481 |
+
"special": false
|
| 482 |
+
},
|
| 483 |
+
"60": {
|
| 484 |
+
"content": "<extra_id_60>",
|
| 485 |
+
"lstrip": false,
|
| 486 |
+
"normalized": true,
|
| 487 |
+
"rstrip": false,
|
| 488 |
+
"single_word": false,
|
| 489 |
+
"special": false
|
| 490 |
+
},
|
| 491 |
+
"61": {
|
| 492 |
+
"content": "<extra_id_61>",
|
| 493 |
+
"lstrip": false,
|
| 494 |
+
"normalized": true,
|
| 495 |
+
"rstrip": false,
|
| 496 |
+
"single_word": false,
|
| 497 |
+
"special": false
|
| 498 |
+
},
|
| 499 |
+
"62": {
|
| 500 |
+
"content": "<extra_id_62>",
|
| 501 |
+
"lstrip": false,
|
| 502 |
+
"normalized": true,
|
| 503 |
+
"rstrip": false,
|
| 504 |
+
"single_word": false,
|
| 505 |
+
"special": false
|
| 506 |
+
},
|
| 507 |
+
"63": {
|
| 508 |
+
"content": "<extra_id_63>",
|
| 509 |
+
"lstrip": false,
|
| 510 |
+
"normalized": true,
|
| 511 |
+
"rstrip": false,
|
| 512 |
+
"single_word": false,
|
| 513 |
+
"special": false
|
| 514 |
+
}
|
| 515 |
+
},
|
| 516 |
+
"additional_special_tokens": [
|
| 517 |
+
"<repo_name>",
|
| 518 |
+
"<file_sep>",
|
| 519 |
+
"<t2v_token>",
|
| 520 |
+
"<v2t_token>",
|
| 521 |
+
"<|start_header_id|>",
|
| 522 |
+
"<|end_header_id|>",
|
| 523 |
+
"<|eot_id|>",
|
| 524 |
+
"<extra_id_12>",
|
| 525 |
+
"<extra_id_13>",
|
| 526 |
+
"<extra_id_14>",
|
| 527 |
+
"<extra_id_15>",
|
| 528 |
+
"<extra_id_16>",
|
| 529 |
+
"<extra_id_17>",
|
| 530 |
+
"<extra_id_18>",
|
| 531 |
+
"<extra_id_19>",
|
| 532 |
+
"<extra_id_20>",
|
| 533 |
+
"<extra_id_21>",
|
| 534 |
+
"<extra_id_22>",
|
| 535 |
+
"<extra_id_23>",
|
| 536 |
+
"<extra_id_24>",
|
| 537 |
+
"<extra_id_25>",
|
| 538 |
+
"<extra_id_26>",
|
| 539 |
+
"<extra_id_27>",
|
| 540 |
+
"<extra_id_28>",
|
| 541 |
+
"<extra_id_29>",
|
| 542 |
+
"<extra_id_30>",
|
| 543 |
+
"<extra_id_31>",
|
| 544 |
+
"<extra_id_32>",
|
| 545 |
+
"<extra_id_33>",
|
| 546 |
+
"<extra_id_34>",
|
| 547 |
+
"<extra_id_35>",
|
| 548 |
+
"<extra_id_36>",
|
| 549 |
+
"<extra_id_37>",
|
| 550 |
+
"<extra_id_38>",
|
| 551 |
+
"<extra_id_39>",
|
| 552 |
+
"<extra_id_40>",
|
| 553 |
+
"<extra_id_41>",
|
| 554 |
+
"<extra_id_42>",
|
| 555 |
+
"<extra_id_43>",
|
| 556 |
+
"<extra_id_44>",
|
| 557 |
+
"<extra_id_45>",
|
| 558 |
+
"<extra_id_46>",
|
| 559 |
+
"<extra_id_47>",
|
| 560 |
+
"<extra_id_48>",
|
| 561 |
+
"<extra_id_49>",
|
| 562 |
+
"<extra_id_50>",
|
| 563 |
+
"<extra_id_51>",
|
| 564 |
+
"<extra_id_52>",
|
| 565 |
+
"<extra_id_53>",
|
| 566 |
+
"<extra_id_54>",
|
| 567 |
+
"<extra_id_55>",
|
| 568 |
+
"<extra_id_56>",
|
| 569 |
+
"<extra_id_57>",
|
| 570 |
+
"<extra_id_58>",
|
| 571 |
+
"<extra_id_59>",
|
| 572 |
+
"<extra_id_60>",
|
| 573 |
+
"<extra_id_61>",
|
| 574 |
+
"<extra_id_62>",
|
| 575 |
+
"<extra_id_63>"
|
| 576 |
+
],
|
| 577 |
+
"auto_map": {
|
| 578 |
+
"AutoProcessor": "processing_evabyte.EvaByteProcessor",
|
| 579 |
+
"AutoTokenizer": [
|
| 580 |
+
"tokenization_evabyte.EvaByteTokenizer",
|
| 581 |
+
null
|
| 582 |
+
]
|
| 583 |
+
},
|
| 584 |
+
"bos_token": "<bos>",
|
| 585 |
+
"chat_template": "\n{{- bos_token }}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content'] %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{{- '<|start_header_id|>system<|end_header_id|>\n\n' + system_message + '<|eot_id|>'}}\n\n{%- for message in messages %}\n {%- if (message['role'] != 'user') and (message['role'] != 'assistant') %}\n {{- raise_exception('Conversation roles must be user or assistant') }}\n {%- endif %}\n\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] + '<|eot_id|>' }}\n{%- endfor %}\n\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>' + 'assistant' + '<|end_header_id|>\n\n' }}\n{%- endif %}\n",
|
| 586 |
+
"clean_up_tokenization_spaces": false,
|
| 587 |
+
"eos_token": "<eos>",
|
| 588 |
+
"extra_ids": 0,
|
| 589 |
+
"extra_special_tokens": {},
|
| 590 |
+
"model_max_length": 1000000000000000019884624838656,
|
| 591 |
+
"pad_token": "<pad>",
|
| 592 |
+
"processor_class": "EvaByteProcessor",
|
| 593 |
+
"sep_token": "<sep>",
|
| 594 |
+
"tokenizer_class": "EvaByteTokenizer",
|
| 595 |
+
"unk_token": "<unk>"
|
| 596 |
+
}
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/README.md
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
---
|
| 4 |
+
# EvaByte Model Card
|
| 5 |
+
|
| 6 |
+
**EvaByte** is a 6.5B **byte-level language model** built upon an improved architecture with multibyte prediction and EVA -- an efficient attention mechanism designed for scalability and performance. Trained on 1.5T bytes spanning natural language text, math, and code, EvaByte demonstrates the viability of efficient byte-level processing at scale -- rivaling top open-source tokenizer-based LMs using 5x less training data, excelling in coding tasks, and decoding up to 2x faster.
|
| 7 |
+
|
| 8 |
+
## Model Resources
|
| 9 |
+
|
| 10 |
+
- **Repository:** https://github.com/openevabyte/evabyte
|
| 11 |
+
- **Blog:** https://hkunlp.github.io/blog/2025/evabyte and https://sambanova.ai/blog/evabyte-efficient-byte-level-language-models-at-scale
|
| 12 |
+
- **Paper:** Coming soon
|
| 13 |
+
|
| 14 |
+
## Model Details
|
| 15 |
+
|
| 16 |
+
EvaByte is trained using the performant SambaNova SN30 RDU system with a batch size of 8M bytes and 32K context length. The training process consists of 3 phases: after pre-training on 1.2T bytes (yielding **EvaByte-Phase1**), two independent annealing runs (100B and 200B bytes respectively) are conducted with learning rate linearly decayed from 1e-4 to 0. The resulting checkpoints are merged via model soup (**EvaByte**), which then undergoes supervised fine-tuning (**EvaByte-SFT**).
|
| 17 |
+
|
| 18 |
+
| Stage | Model |
|
| 19 |
+
|:----- |:-----|
|
| 20 |
+
| Base (before annealing) | [EvaByte-Phase1](https://huggingface.co/evabyte/EvaByte-Phase1) |
|
| 21 |
+
| Base | [EvaByte](https://huggingface.co/evabyte/EvaByte) <-- you are here |
|
| 22 |
+
| SFT | [EvaByte-SFT](https://huggingface.co/evabyte/EvaByte-SFT) |
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
## Usage
|
| 26 |
+
|
| 27 |
+
**Note:** Make sure to set `trust_remote_code=True` when loading the model (or tokenizer), as our implementation includes custom code.
|
| 28 |
+
|
| 29 |
+
The code snippet below demonstrates EvaByte-6.5B for completion:
|
| 30 |
+
|
| 31 |
+
```python
|
| 32 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 33 |
+
import torch
|
| 34 |
+
|
| 35 |
+
# Load model and tokenizer
|
| 36 |
+
tokenizer = AutoTokenizer.from_pretrained("evabyte/EvaByte", trust_remote_code=True)
|
| 37 |
+
model = AutoModelForCausalLM.from_pretrained("evabyte/EvaByte", torch_dtype=torch.bfloat16, trust_remote_code=True).eval().to("cuda")
|
| 38 |
+
|
| 39 |
+
prompt = "The quick brown fox jumps "
|
| 40 |
+
|
| 41 |
+
# Tokenize input
|
| 42 |
+
# Option 1: standard HF tokenizer interface
|
| 43 |
+
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
|
| 44 |
+
|
| 45 |
+
# Option 2: Direct UTF-8 byte encoding with offset
|
| 46 |
+
# Note: Each byte is offset by 64 with <bos> prepended.
|
| 47 |
+
input_ids = torch.tensor([[1] + [b + 64 for b in prompt.encode("utf-8")]]).to("cuda")
|
| 48 |
+
|
| 49 |
+
# byte-by-byte generation (default)
|
| 50 |
+
generation_output = model.generate(
|
| 51 |
+
input_ids=input_ids,
|
| 52 |
+
max_new_tokens=32
|
| 53 |
+
)
|
| 54 |
+
# alternatively, use faster multibyte generation
|
| 55 |
+
generation_output = model.multi_byte_generate(
|
| 56 |
+
input_ids=input_ids,
|
| 57 |
+
max_new_tokens=32
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
# Decode and print the output
|
| 61 |
+
response = tokenizer.decode(
|
| 62 |
+
generation_output[0][input_ids.shape[1]:],
|
| 63 |
+
skip_special_tokens=False,
|
| 64 |
+
clean_up_tokenization_spaces=False
|
| 65 |
+
)
|
| 66 |
+
print(response)
|
| 67 |
+
# Sample output:
|
| 68 |
+
# over the lazy dog.\n\nThe quick
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
### ⚙️ Generation Modes
|
| 72 |
+
|
| 73 |
+
EvaByte supports two generation interfaces:
|
| 74 |
+
- `model.generate()`: The default generation method compatible with Huggingface `transformers` library. This approach generates one byte at a time and might be slow.
|
| 75 |
+
- `model.multi_byte_generate()`: A faster alternative that generates multiple bytes per step and usually yields the same result as `model.generate()` under greedy decoding, with the implementation adapted from [Medusa](https://github.com/FasterDecoding/Medusa). `model.multi_byte_generate()` supports a subset of arguments in `model.generate()`:
|
| 76 |
+
- `input_ids`: the input byte ids.
|
| 77 |
+
- `temperature`: the temperature for sampling.
|
| 78 |
+
- `max_length`: the maximum length of the generated sequence.
|
| 79 |
+
- `max_new_tokens`: the maximum number of new bytes to generate.
|
| 80 |
+
- `stopping_criteria`: the [stopping criteria](https://huggingface.co/docs/transformers/v4.47.1/en/internal/generation_utils#transformers.StoppingCriteria) for generation.
|
| 81 |
+
- `top_p`: the top-p parameter for sampling.
|
| 82 |
+
- `do_sample`: greedy decoding or sampling.
|
| 83 |
+
|
| 84 |
+
**Notes and Limitations:**
|
| 85 |
+
- `device_map="auto"` is not supported for >2 GPUs.
|
| 86 |
+
- Only batch size of 1 (with `attention_mask=None`) is supported for decoding.
|
| 87 |
+
- `torch_dtype=torch.bfloat16` is required.
|
| 88 |
+
- The multibyte generation `model.multi_byte_generate()` might return extra bytes after the end-of-sequence sentinel, due to the nature of the multibyte decoding. Manual truncation or cleaning may be needed.
|
| 89 |
+
|
| 90 |
+
## Bias, Risks, and Limitations
|
| 91 |
+
As a pretrained base model, **EvaByte** has not been fine-tuned for chat or instruction following, so users should not expect reliable performance in conversational or instruction-based tasks. Like other base models, it does not incorporate any moderation mechanisms, making it possible to generate potentially harmful or inappropriate content.
|
| 92 |
+
|
| 93 |
+
## Evaluation
|
| 94 |
+
|
| 95 |
+
For detailed evaluation results, check out our blog post at [SambaNova](https://sambanova.ai/blog/evabyte-efficient-byte-level-language-models-at-scale) or [HKUNLP](https://hkunlp.github.io/blog/2025/evabyte).
|
| 96 |
+
|
| 97 |
+
## Citation
|
| 98 |
+
```bibtex
|
| 99 |
+
@misc{evabyte,
|
| 100 |
+
title = {EvaByte: Efficient Byte-level Language Models at Scale},
|
| 101 |
+
url = {https://hkunlp.github.io/blog/2025/evabyte},
|
| 102 |
+
author = {Lin Zheng and Xueliang Zhao and Guangtao Wang and Chen Wu and David Dong and Angela Wang and Mingran Wang and Yun Du and Haige Bo and Amol Sharma and Bo Li and Kejie Zhang and Changran Hu and Urmish Thakker and Lingpeng Kong},
|
| 103 |
+
year = {2025}
|
| 104 |
+
}
|
| 105 |
+
```
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/config.json
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_name_or_path": null,
|
| 3 |
+
"architectures": [
|
| 4 |
+
"EvaByteForCausalLM"
|
| 5 |
+
],
|
| 6 |
+
"attention_bias": false,
|
| 7 |
+
"attention_class": "eva",
|
| 8 |
+
"attention_dropout": 0.0,
|
| 9 |
+
"auto_map": {
|
| 10 |
+
"AutoConfig": "configuration_evabyte.EvaByteConfig",
|
| 11 |
+
"AutoModelForCausalLM": "modeling_evabyte.EvaByteForCausalLM"
|
| 12 |
+
},
|
| 13 |
+
"bos_token_id": 1,
|
| 14 |
+
"chunk_size": 16,
|
| 15 |
+
"eos_token_id": 2,
|
| 16 |
+
"fp32_ln": true,
|
| 17 |
+
"fp32_logits": true,
|
| 18 |
+
"fp32_skip_add": false,
|
| 19 |
+
"hidden_act": "silu",
|
| 20 |
+
"hidden_size": 5120,
|
| 21 |
+
"init_cutoff_factor": null,
|
| 22 |
+
"init_fn": "v2",
|
| 23 |
+
"init_std": 0.01275,
|
| 24 |
+
"initializer_range": 0.01275,
|
| 25 |
+
"intermediate_size": 16384,
|
| 26 |
+
"lazy_init": true,
|
| 27 |
+
"max_position_embeddings": 16384,
|
| 28 |
+
"max_seq_length": 16384,
|
| 29 |
+
"mixedp_attn": true,
|
| 30 |
+
"model_type": "evabyte",
|
| 31 |
+
"norm_add_unit_offset": true,
|
| 32 |
+
"num_attention_heads": 40,
|
| 33 |
+
"num_chunks": null,
|
| 34 |
+
"num_hidden_layers": 40,
|
| 35 |
+
"num_key_value_heads": 40,
|
| 36 |
+
"num_pred_heads": 1,
|
| 37 |
+
"pad_token_id": 0,
|
| 38 |
+
"return_dict": false,
|
| 39 |
+
"rms_norm_eps": 1e-06,
|
| 40 |
+
"rope_scaling": null,
|
| 41 |
+
"rope_theta": 100000.0,
|
| 42 |
+
"tie_word_embeddings": false,
|
| 43 |
+
"torch_dtype": "bfloat16",
|
| 44 |
+
"transformers_version": "4.47.1",
|
| 45 |
+
"use_cache": true,
|
| 46 |
+
"vocab_size": 320,
|
| 47 |
+
"window_size": 2048
|
| 48 |
+
}
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/configuration_evabyte.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" EvaByte configuration"""
|
| 2 |
+
|
| 3 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 4 |
+
|
| 5 |
+
class EvaByteConfig(PretrainedConfig):
|
| 6 |
+
model_type = "evabyte"
|
| 7 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 8 |
+
|
| 9 |
+
def __init__(
|
| 10 |
+
self,
|
| 11 |
+
vocab_size=320,
|
| 12 |
+
hidden_size=4096,
|
| 13 |
+
intermediate_size=11008,
|
| 14 |
+
num_hidden_layers=32,
|
| 15 |
+
num_attention_heads=32,
|
| 16 |
+
num_key_value_heads=None,
|
| 17 |
+
hidden_act="silu",
|
| 18 |
+
max_position_embeddings=2048,
|
| 19 |
+
initializer_range=0.02,
|
| 20 |
+
rms_norm_eps=1e-6,
|
| 21 |
+
use_cache=True,
|
| 22 |
+
pad_token_id=None,
|
| 23 |
+
bos_token_id=1,
|
| 24 |
+
eos_token_id=2,
|
| 25 |
+
tie_word_embeddings=False,
|
| 26 |
+
rope_theta=10000.0,
|
| 27 |
+
rope_scaling=None,
|
| 28 |
+
attention_bias=False,
|
| 29 |
+
attention_dropout=0.0,
|
| 30 |
+
norm_add_unit_offset=False,
|
| 31 |
+
init_fn="mitchell",
|
| 32 |
+
init_std=0.006,
|
| 33 |
+
init_cutoff_factor=None,
|
| 34 |
+
attention_class="mha",
|
| 35 |
+
window_size=512,
|
| 36 |
+
num_chunks=None,
|
| 37 |
+
chunk_size=256,
|
| 38 |
+
**kwargs,
|
| 39 |
+
):
|
| 40 |
+
self.vocab_size = vocab_size
|
| 41 |
+
self.max_position_embeddings = max_position_embeddings
|
| 42 |
+
self.hidden_size = hidden_size
|
| 43 |
+
self.intermediate_size = intermediate_size
|
| 44 |
+
self.num_hidden_layers = num_hidden_layers
|
| 45 |
+
self.num_attention_heads = num_attention_heads
|
| 46 |
+
|
| 47 |
+
# for backward compatibility
|
| 48 |
+
if num_key_value_heads is None:
|
| 49 |
+
num_key_value_heads = num_attention_heads
|
| 50 |
+
|
| 51 |
+
self.num_key_value_heads = num_key_value_heads
|
| 52 |
+
self.hidden_act = hidden_act
|
| 53 |
+
self.initializer_range = initializer_range
|
| 54 |
+
self.rms_norm_eps = rms_norm_eps
|
| 55 |
+
self.use_cache = use_cache
|
| 56 |
+
self.rope_theta = rope_theta
|
| 57 |
+
self.rope_scaling = rope_scaling
|
| 58 |
+
self._rope_scaling_validation()
|
| 59 |
+
self.attention_bias = attention_bias
|
| 60 |
+
self.attention_dropout = attention_dropout
|
| 61 |
+
|
| 62 |
+
self.norm_add_unit_offset = norm_add_unit_offset
|
| 63 |
+
self.init_fn = init_fn
|
| 64 |
+
self.init_std = init_std
|
| 65 |
+
self.init_cutoff_factor = init_cutoff_factor
|
| 66 |
+
|
| 67 |
+
# Attention-specific paramters
|
| 68 |
+
self.attention_class = attention_class
|
| 69 |
+
self.window_size = window_size
|
| 70 |
+
self.num_chunks = num_chunks
|
| 71 |
+
self.chunk_size = chunk_size
|
| 72 |
+
|
| 73 |
+
super().__init__(
|
| 74 |
+
pad_token_id=pad_token_id,
|
| 75 |
+
bos_token_id=bos_token_id,
|
| 76 |
+
eos_token_id=eos_token_id,
|
| 77 |
+
tie_word_embeddings=tie_word_embeddings,
|
| 78 |
+
**kwargs,
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
def _rope_scaling_validation(self):
|
| 82 |
+
"""
|
| 83 |
+
Validate the `rope_scaling` configuration.
|
| 84 |
+
"""
|
| 85 |
+
if self.rope_scaling is None:
|
| 86 |
+
return
|
| 87 |
+
|
| 88 |
+
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
|
| 89 |
+
raise ValueError(
|
| 90 |
+
"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}"
|
| 91 |
+
)
|
| 92 |
+
rope_scaling_type = self.rope_scaling.get("type", None)
|
| 93 |
+
rope_scaling_factor = self.rope_scaling.get("factor", None)
|
| 94 |
+
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
|
| 95 |
+
raise ValueError(
|
| 96 |
+
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
|
| 97 |
+
)
|
| 98 |
+
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
|
| 99 |
+
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/eva.py
ADDED
|
@@ -0,0 +1,424 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Optional, Tuple, List, Any, Union
|
| 2 |
+
import torch
|
| 3 |
+
from torch import nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from .eva_agg_kernel import eva_agg_func_triton
|
| 6 |
+
from .eva_prep_kv_kernel import eva_prep_kv_func_triton
|
| 7 |
+
try:
|
| 8 |
+
import triton
|
| 9 |
+
USE_TRITON_IMPL = True
|
| 10 |
+
except ImportError:
|
| 11 |
+
USE_TRITON_IMPL = False
|
| 12 |
+
raise ImportError("Triton is not installed. Please install it by running `pip install triton`.")
|
| 13 |
+
|
| 14 |
+
def rotate_half(x: torch.Tensor) -> torch.Tensor:
|
| 15 |
+
"""
|
| 16 |
+
Rotates half the hidden dims (last dim) of the input.
|
| 17 |
+
Args:
|
| 18 |
+
x: Rotary embedded tensor
|
| 19 |
+
Return:
|
| 20 |
+
Tensor with half of last dim negated and rotated to the front.
|
| 21 |
+
"""
|
| 22 |
+
x1, x2 = x.split(x.shape[-1] // 2, dim=-1)
|
| 23 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 24 |
+
|
| 25 |
+
def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
|
| 26 |
+
position_ids: torch.Tensor) -> torch.Tensor:
|
| 27 |
+
"""
|
| 28 |
+
Apply rotary embedding (cos, sin) to the query and key tensor on the sequence dimension.
|
| 29 |
+
|
| 30 |
+
The legends for dimensions are defined as:
|
| 31 |
+
num_heads: number of attention heads
|
| 32 |
+
current_seq_len: the current batch's sequence length, should be either 1 or max_seq_len
|
| 33 |
+
max_seq_len: the static sequence length, different from current_seq_len in cached inference case where it is always
|
| 34 |
+
maximum lenghth, e.g. the length of static sequence length of KV cache
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
q: Query tensor, of size (batch_size, num_heads, current_seq_len, head_dim)
|
| 39 |
+
k: Key tensor, of size (batch_size, num_key_value_heads, current_seq_len, head_dim)
|
| 40 |
+
cos: Cosine base of rotary embedding, of size (max_seq_len, head_dim)
|
| 41 |
+
sin: Sine base of rotary embedding, of size (max_seq_len, head_dim)
|
| 42 |
+
position_ids: The position indices of the tokens corresponding to the query and key tensors. It has a size of
|
| 43 |
+
(batch_size, current_seq_len).
|
| 44 |
+
|
| 45 |
+
Returns:
|
| 46 |
+
Embedded query and key tensor of same size as input.
|
| 47 |
+
|
| 48 |
+
"""
|
| 49 |
+
bs, nheads, cur_seq_len, head_dim = q.shape
|
| 50 |
+
assert len(
|
| 51 |
+
k.shape) == 4, f"k should be of shape (batch_size, num_heads, current_seq_len, head_dim), got {k.shape} instead"
|
| 52 |
+
assert k.shape[0] == bs, f"k has a different batch_size {k.shape[0]} compared to q {bs}"
|
| 53 |
+
assert list(k.shape[2:]) == [cur_seq_len,
|
| 54 |
+
head_dim], f"k has different current_seq_len and/or head_dim compared to q"
|
| 55 |
+
assert cos.shape[3] == head_dim, f"cos should have dim of head dim {head_dim}, got {cos.shape[3]} instead"
|
| 56 |
+
assert list(position_ids.shape) in [[bs, cur_seq_len], [1, cur_seq_len]],\
|
| 57 |
+
f"position_ids should be of shape {[bs, cur_seq_len]} or {[1, cur_seq_len]}, got {position_ids.shape} instead"
|
| 58 |
+
|
| 59 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 60 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 61 |
+
return q_embed, k_embed
|
| 62 |
+
|
| 63 |
+
class EvaAttention(nn.Module):
|
| 64 |
+
"""
|
| 65 |
+
Causal EVA for language modeling.
|
| 66 |
+
"""
|
| 67 |
+
|
| 68 |
+
def __init__(self, config, layer_idx: Optional[int] = None):
|
| 69 |
+
super().__init__()
|
| 70 |
+
self.config = config
|
| 71 |
+
self.layer_idx = layer_idx
|
| 72 |
+
self.hidden_size = config.hidden_size
|
| 73 |
+
self.num_heads = config.num_attention_heads
|
| 74 |
+
self.head_dim = self.hidden_size // self.num_heads
|
| 75 |
+
self.head_dim_scaling = self.head_dim ** -0.5
|
| 76 |
+
|
| 77 |
+
self.max_position_embeddings = config.max_position_embeddings
|
| 78 |
+
|
| 79 |
+
if (self.head_dim * self.num_heads) != self.hidden_size:
|
| 80 |
+
raise ValueError(
|
| 81 |
+
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
| 82 |
+
f" and `num_heads`: {self.num_heads})."
|
| 83 |
+
)
|
| 84 |
+
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
| 85 |
+
self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
| 86 |
+
self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
| 87 |
+
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
| 88 |
+
|
| 89 |
+
self.window_size = config.window_size
|
| 90 |
+
|
| 91 |
+
self.num_chunks = config.num_chunks
|
| 92 |
+
self.chunk_size = config.chunk_size
|
| 93 |
+
if self.chunk_size is not None:
|
| 94 |
+
assert self.window_size >= self.chunk_size and self.window_size % self.chunk_size == 0
|
| 95 |
+
# chunk_size overrides the number of landmarks
|
| 96 |
+
self.num_chunks = None
|
| 97 |
+
|
| 98 |
+
self.chunks_per_window = int(self.window_size // self.chunk_size)
|
| 99 |
+
self.adaptive_phi = nn.Parameter(
|
| 100 |
+
torch.randn(
|
| 101 |
+
1,
|
| 102 |
+
self.num_heads,
|
| 103 |
+
1,
|
| 104 |
+
1,
|
| 105 |
+
self.head_dim
|
| 106 |
+
).clamp(-1., 1.) * self.head_dim_scaling
|
| 107 |
+
)
|
| 108 |
+
self.adaptive_mu_k = nn.Parameter(
|
| 109 |
+
torch.randn(
|
| 110 |
+
1,
|
| 111 |
+
self.num_heads,
|
| 112 |
+
1,
|
| 113 |
+
1,
|
| 114 |
+
self.head_dim
|
| 115 |
+
).clamp(-1., 1.) * self.head_dim_scaling
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
def _triton_forward(
|
| 119 |
+
self,
|
| 120 |
+
hidden_states: torch.Tensor,
|
| 121 |
+
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
|
| 122 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 123 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
| 124 |
+
output_attentions: bool = False,
|
| 125 |
+
use_cache: bool = False,
|
| 126 |
+
cos: Optional[torch.Tensor] = None,
|
| 127 |
+
sin: Optional[torch.Tensor] = None,
|
| 128 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 129 |
+
assert not output_attentions
|
| 130 |
+
bsz, q_len, _ = hidden_states.size()
|
| 131 |
+
|
| 132 |
+
if use_cache:
|
| 133 |
+
if past_key_value is None:
|
| 134 |
+
raise ValueError
|
| 135 |
+
assert isinstance(attention_mask, tuple)
|
| 136 |
+
|
| 137 |
+
# infer the model's running mode
|
| 138 |
+
is_prefilling = use_cache and past_key_value.get_seq_length(self.layer_idx) == 0
|
| 139 |
+
is_decoding = use_cache and past_key_value.get_seq_length(self.layer_idx) > 0
|
| 140 |
+
|
| 141 |
+
if is_prefilling:
|
| 142 |
+
assert len(attention_mask) == 2
|
| 143 |
+
window_mask, intra_chunk_mask = attention_mask
|
| 144 |
+
chunk_mask = None
|
| 145 |
+
elif is_decoding:
|
| 146 |
+
assert len(attention_mask) == 3
|
| 147 |
+
window_mask, intra_chunk_mask, chunk_mask = attention_mask
|
| 148 |
+
else:
|
| 149 |
+
if attention_mask is not None:
|
| 150 |
+
assert isinstance(attention_mask, tuple) and len(attention_mask) == 3
|
| 151 |
+
window_mask, chunk_mask, intra_chunk_mask = attention_mask
|
| 152 |
+
else:
|
| 153 |
+
window_mask, chunk_mask, intra_chunk_mask = None, None, None
|
| 154 |
+
|
| 155 |
+
############################################
|
| 156 |
+
# compute q, k, v from hidden states
|
| 157 |
+
############################################
|
| 158 |
+
# [b, h, q_len, d]
|
| 159 |
+
q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 160 |
+
# [b, h, kv_len, d]
|
| 161 |
+
k = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 162 |
+
# [b, h, kv_len, d]
|
| 163 |
+
v = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 164 |
+
|
| 165 |
+
if use_cache:
|
| 166 |
+
past_key_value.update_past_len(q.shape[-2], self.layer_idx)
|
| 167 |
+
|
| 168 |
+
############################################
|
| 169 |
+
# apply rotary positional embeddings to q, k
|
| 170 |
+
############################################
|
| 171 |
+
q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
|
| 172 |
+
|
| 173 |
+
############################################
|
| 174 |
+
# update and get cached singleton tokens
|
| 175 |
+
# update and cache k and v for calculating chunk-level RFAs
|
| 176 |
+
############################################
|
| 177 |
+
if use_cache:
|
| 178 |
+
s_k, s_v, dump_k, dump_v = past_key_value.update_singletons_and_chunks(
|
| 179 |
+
k,
|
| 180 |
+
v,
|
| 181 |
+
self.layer_idx,
|
| 182 |
+
self.window_size,
|
| 183 |
+
)
|
| 184 |
+
else:
|
| 185 |
+
s_k, s_v = k, v
|
| 186 |
+
dump_k, dump_v = k, v
|
| 187 |
+
|
| 188 |
+
if use_cache:
|
| 189 |
+
singleton_mask, dump_rf_mask = past_key_value.update_mask(
|
| 190 |
+
s_mask=window_mask,
|
| 191 |
+
rf_mask=intra_chunk_mask,
|
| 192 |
+
layer_idx=self.layer_idx,
|
| 193 |
+
window_size=self.window_size,
|
| 194 |
+
)
|
| 195 |
+
else:
|
| 196 |
+
singleton_mask = window_mask
|
| 197 |
+
dump_rf_mask = intra_chunk_mask
|
| 198 |
+
|
| 199 |
+
if dump_k is not None and dump_v is not None:
|
| 200 |
+
# 1. in prefilling, the input shape is
|
| 201 |
+
# dump_k/dump_v: [b, h, n, d]
|
| 202 |
+
# rfa_k/rfa_v: [b, h, n // c, d]
|
| 203 |
+
# 2. in decoding, the input shape is
|
| 204 |
+
# k/v: [b, h, w, d]
|
| 205 |
+
# rfa_k/rfa_v: [b, h, w//c, d]
|
| 206 |
+
# 3. in forward inference; the seq_len is already divisible
|
| 207 |
+
rfa_k, rfa_v = eva_prep_kv_func_triton(
|
| 208 |
+
dump_k, dump_v,
|
| 209 |
+
self.adaptive_mu_k, self.adaptive_phi,
|
| 210 |
+
dump_rf_mask, self.head_dim_scaling, self.chunk_size
|
| 211 |
+
)
|
| 212 |
+
# rfa_mask = get_rfa_chunk_mask(dump_rf_mask)
|
| 213 |
+
if use_cache:
|
| 214 |
+
rfa_k, rfa_v = past_key_value.update_chunk_rfas(
|
| 215 |
+
rfa_k, rfa_v, self.layer_idx
|
| 216 |
+
)
|
| 217 |
+
elif use_cache:
|
| 218 |
+
# if there are not enough elements within the last chunk,
|
| 219 |
+
# we will only use the cached chunk-level RFAs
|
| 220 |
+
rfa_k, rfa_v = past_key_value.get_chunk_rfas(self.layer_idx)
|
| 221 |
+
else:
|
| 222 |
+
rfa_k, rfa_v = None, None
|
| 223 |
+
|
| 224 |
+
############################################
|
| 225 |
+
# compute the full attention output
|
| 226 |
+
############################################
|
| 227 |
+
if is_prefilling:
|
| 228 |
+
# prefilling
|
| 229 |
+
# 1. in prefilling, the input shape is
|
| 230 |
+
# q: [b, h, n, d]
|
| 231 |
+
# k/v: [b, h, n, d]
|
| 232 |
+
# rfa_k/rfa_v: [b, h, n // c, d]
|
| 233 |
+
attn_output = eva_agg_func_triton(
|
| 234 |
+
q, s_k, s_v,
|
| 235 |
+
rfa_k, rfa_v,
|
| 236 |
+
singleton_mask, chunk_mask,
|
| 237 |
+
self.head_dim_scaling, self.window_size, self.chunks_per_window
|
| 238 |
+
)
|
| 239 |
+
elif is_decoding:
|
| 240 |
+
# 2. in decoding, the input shape is
|
| 241 |
+
# q: [b, h, 1, d] or [b, h, z, d] (for multi-byte prediction)
|
| 242 |
+
# k/v: [b, h, 1 + s, d]
|
| 243 |
+
# rfa_k/rfa_v: [b, h, n // c, d]
|
| 244 |
+
if rfa_k is not None and rfa_v is not None:
|
| 245 |
+
# we only take the chunk-level RFAs not in the current window
|
| 246 |
+
seen_seq_len = past_key_value.get_seq_length(self.layer_idx)
|
| 247 |
+
if seen_seq_len <= self.window_size:
|
| 248 |
+
agg_k = s_k
|
| 249 |
+
agg_v = s_v
|
| 250 |
+
attn_mask = singleton_mask
|
| 251 |
+
else:
|
| 252 |
+
# NOTE: we already updated the cache so the length now
|
| 253 |
+
# includes the current token
|
| 254 |
+
# we subtract 1 from seen_seq_len because we want
|
| 255 |
+
# if seen_seq_len = 2048 -> num_windows_seen_so_far = 0
|
| 256 |
+
# if seen_seq_len = 4096 -> num_windows_seen_so_far = 1
|
| 257 |
+
# if seen_seq_len = 4097 -> num_windows_seen_so_far = 2
|
| 258 |
+
# NOTE the cat order should be taken care of;
|
| 259 |
+
# should align with the order based on which
|
| 260 |
+
# the attention mask is constructed
|
| 261 |
+
num_windows_seen_so_far = (seen_seq_len - 1) // self.window_size
|
| 262 |
+
agg_k = torch.cat([s_k, rfa_k[..., :num_windows_seen_so_far * self.chunks_per_window, :]], dim=-2)
|
| 263 |
+
agg_v = torch.cat([s_v, rfa_v[..., :num_windows_seen_so_far * self.chunks_per_window, :]], dim=-2)
|
| 264 |
+
if singleton_mask is not None:
|
| 265 |
+
assert chunk_mask is not None
|
| 266 |
+
attn_mask = torch.cat([singleton_mask, chunk_mask], dim=-1)
|
| 267 |
+
else:
|
| 268 |
+
attn_mask = singleton_mask
|
| 269 |
+
else:
|
| 270 |
+
agg_k = s_k
|
| 271 |
+
agg_v = s_v
|
| 272 |
+
attn_mask = singleton_mask
|
| 273 |
+
attn_output = F.scaled_dot_product_attention(
|
| 274 |
+
q, agg_k, agg_v,
|
| 275 |
+
attn_mask=attn_mask,
|
| 276 |
+
is_causal=False,
|
| 277 |
+
dropout_p=0.0,
|
| 278 |
+
scale=self.head_dim_scaling
|
| 279 |
+
)
|
| 280 |
+
else:
|
| 281 |
+
# 3. in single-forward inference
|
| 282 |
+
attn_output = eva_agg_func_triton(
|
| 283 |
+
q, s_k, s_v,
|
| 284 |
+
rfa_k, rfa_v,
|
| 285 |
+
singleton_mask, chunk_mask,
|
| 286 |
+
self.head_dim_scaling, self.window_size, self.chunks_per_window
|
| 287 |
+
)
|
| 288 |
+
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
| 289 |
+
raise ValueError(
|
| 290 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
| 291 |
+
f" {attn_output.size()}"
|
| 292 |
+
)
|
| 293 |
+
attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
|
| 294 |
+
attn_output = self.o_proj(attn_output)
|
| 295 |
+
attn_weights = None
|
| 296 |
+
return attn_output, attn_weights, past_key_value
|
| 297 |
+
|
| 298 |
+
def _multibyte_decoding_forward(
|
| 299 |
+
self,
|
| 300 |
+
hidden_states: torch.Tensor,
|
| 301 |
+
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
|
| 302 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 303 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
| 304 |
+
output_attentions: bool = False,
|
| 305 |
+
use_cache: bool = False,
|
| 306 |
+
cos: Optional[torch.Tensor] = None,
|
| 307 |
+
sin: Optional[torch.Tensor] = None,
|
| 308 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 309 |
+
# during multi-byte forwarding, we only read caches and do not update them
|
| 310 |
+
assert not output_attentions
|
| 311 |
+
bsz, q_len, _ = hidden_states.size()
|
| 312 |
+
|
| 313 |
+
if use_cache and past_key_value is None:
|
| 314 |
+
raise ValueError
|
| 315 |
+
|
| 316 |
+
assert USE_TRITON_IMPL
|
| 317 |
+
assert isinstance(attention_mask, torch.Tensor) and attention_mask.dtype == torch.bool
|
| 318 |
+
|
| 319 |
+
assert use_cache and past_key_value.get_seq_length(self.layer_idx) > 0
|
| 320 |
+
|
| 321 |
+
############################################
|
| 322 |
+
# compute q, k, v from hidden states
|
| 323 |
+
############################################
|
| 324 |
+
# [b, h, q_len, d]
|
| 325 |
+
q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 326 |
+
# [b, h, kv_len, d]
|
| 327 |
+
k = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 328 |
+
# [b, h, kv_len, d]
|
| 329 |
+
v = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 330 |
+
|
| 331 |
+
############################################
|
| 332 |
+
# apply rotary positional embeddings to q, k
|
| 333 |
+
############################################
|
| 334 |
+
q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
|
| 335 |
+
|
| 336 |
+
############################################
|
| 337 |
+
# update and get cached singleton tokens
|
| 338 |
+
############################################
|
| 339 |
+
input_len = k.shape[-2]
|
| 340 |
+
window_pos = past_key_value.past_window_pos[self.layer_idx]
|
| 341 |
+
new_window_pos = window_pos + input_len
|
| 342 |
+
|
| 343 |
+
past_key_value.past_window_k[self.layer_idx][:, :, window_pos : new_window_pos, :] = k
|
| 344 |
+
past_key_value.past_window_v[self.layer_idx][:, :, window_pos : new_window_pos, :] = v
|
| 345 |
+
s_k = past_key_value.past_window_k[self.layer_idx][:, :, : new_window_pos, :]
|
| 346 |
+
s_v = past_key_value.past_window_v[self.layer_idx][:, :, : new_window_pos, :]
|
| 347 |
+
|
| 348 |
+
rfa_k, rfa_v = past_key_value.get_chunk_rfas(self.layer_idx)
|
| 349 |
+
|
| 350 |
+
############################################
|
| 351 |
+
# compute the full attention output
|
| 352 |
+
############################################
|
| 353 |
+
# 2. in decoding, the input shape is
|
| 354 |
+
# q: [b, h, 1, d] or [b, h, z, d] (for multi-byte prediction)
|
| 355 |
+
# k/v: [b, h, 1 + s, d]
|
| 356 |
+
# rfa_k/rfa_v: [b, h, n // c, d]
|
| 357 |
+
if rfa_k is not None and rfa_v is not None:
|
| 358 |
+
# NOTE the cat order should be taken care of;
|
| 359 |
+
# should align with the order based on which
|
| 360 |
+
# the attention mask is constructed
|
| 361 |
+
# agg_k = torch.cat([s_k, rfa_k], dim=-2)
|
| 362 |
+
# agg_v = torch.cat([s_v, rfa_v], dim=-2)
|
| 363 |
+
agg_k = torch.cat([rfa_k, s_k], dim=-2)
|
| 364 |
+
agg_v = torch.cat([rfa_v, s_v], dim=-2)
|
| 365 |
+
else:
|
| 366 |
+
agg_k = s_k
|
| 367 |
+
agg_v = s_v
|
| 368 |
+
attn_output = F.scaled_dot_product_attention(
|
| 369 |
+
q, agg_k, agg_v,
|
| 370 |
+
attn_mask=attention_mask,
|
| 371 |
+
is_causal=False,
|
| 372 |
+
dropout_p=0.0,
|
| 373 |
+
scale=self.head_dim_scaling
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
+
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
| 377 |
+
raise ValueError(
|
| 378 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
| 379 |
+
f" {attn_output.size()}"
|
| 380 |
+
)
|
| 381 |
+
attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
|
| 382 |
+
attn_output = self.o_proj(attn_output)
|
| 383 |
+
attn_weights = None
|
| 384 |
+
return attn_output, attn_weights, past_key_value
|
| 385 |
+
|
| 386 |
+
def forward(
|
| 387 |
+
self,
|
| 388 |
+
hidden_states: torch.Tensor,
|
| 389 |
+
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
|
| 390 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 391 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
| 392 |
+
output_attentions: bool = False,
|
| 393 |
+
use_cache: bool = False,
|
| 394 |
+
cos: Optional[torch.Tensor] = None,
|
| 395 |
+
sin: Optional[torch.Tensor] = None,
|
| 396 |
+
multibyte_decoding: Optional[bool] = False,
|
| 397 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 398 |
+
assert not output_attentions
|
| 399 |
+
if use_cache and past_key_value is None:
|
| 400 |
+
raise ValueError
|
| 401 |
+
|
| 402 |
+
assert USE_TRITON_IMPL
|
| 403 |
+
if use_cache and multibyte_decoding:
|
| 404 |
+
return self._multibyte_decoding_forward(
|
| 405 |
+
hidden_states,
|
| 406 |
+
attention_mask=attention_mask,
|
| 407 |
+
position_ids=position_ids,
|
| 408 |
+
past_key_value=past_key_value,
|
| 409 |
+
output_attentions=output_attentions,
|
| 410 |
+
use_cache=use_cache,
|
| 411 |
+
cos=cos,
|
| 412 |
+
sin=sin,
|
| 413 |
+
)
|
| 414 |
+
else:
|
| 415 |
+
return self._triton_forward(
|
| 416 |
+
hidden_states,
|
| 417 |
+
attention_mask=attention_mask,
|
| 418 |
+
position_ids=position_ids,
|
| 419 |
+
past_key_value=past_key_value,
|
| 420 |
+
output_attentions=output_attentions,
|
| 421 |
+
use_cache=use_cache,
|
| 422 |
+
cos=cos,
|
| 423 |
+
sin=sin,
|
| 424 |
+
)
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/eva_agg_kernel.py
ADDED
|
@@ -0,0 +1,1766 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import math
|
| 3 |
+
import torch
|
| 4 |
+
import triton
|
| 5 |
+
import triton.language as tl
|
| 6 |
+
|
| 7 |
+
@triton.heuristics(
|
| 8 |
+
{
|
| 9 |
+
"EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
|
| 10 |
+
"EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
|
| 11 |
+
"EVEN_W": lambda args: args["WINDOW_SIZE"] % args["BLOCK_N"] == 0,
|
| 12 |
+
"EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
|
| 13 |
+
}
|
| 14 |
+
)
|
| 15 |
+
@triton.jit
|
| 16 |
+
def _bwd_eva_agg_kernel_dkdv(
|
| 17 |
+
Q,
|
| 18 |
+
K,
|
| 19 |
+
V,
|
| 20 |
+
WindowMask,
|
| 21 |
+
DO,
|
| 22 |
+
LSE,
|
| 23 |
+
DO_T_O,
|
| 24 |
+
DK,
|
| 25 |
+
DV,
|
| 26 |
+
softmax_scale,
|
| 27 |
+
stride_qb, stride_qh, stride_qm,
|
| 28 |
+
stride_kb, stride_kh, stride_kn,
|
| 29 |
+
stride_vb, stride_vh, stride_vn,
|
| 30 |
+
stride_window_mask_b, stride_window_mask_m,
|
| 31 |
+
stride_do_b, stride_do_h, stride_do_m,
|
| 32 |
+
stride_lse_b, stride_lse_h,
|
| 33 |
+
stride_do_t_o_b, stride_do_t_o_h,
|
| 34 |
+
stride_dk_b, stride_dk_h, stride_dk_n,
|
| 35 |
+
stride_dv_b, stride_dv_h, stride_dv_n,
|
| 36 |
+
nheads,
|
| 37 |
+
seqlen_q,
|
| 38 |
+
seqlen_k,
|
| 39 |
+
headdim,
|
| 40 |
+
WINDOW_SIZE: tl.constexpr,
|
| 41 |
+
MASK_TYPE: tl.constexpr,
|
| 42 |
+
BLOCK_HEADDIM: tl.constexpr,
|
| 43 |
+
EVEN_M: tl.constexpr,
|
| 44 |
+
EVEN_N: tl.constexpr,
|
| 45 |
+
EVEN_W: tl.constexpr,
|
| 46 |
+
EVEN_HEADDIM: tl.constexpr,
|
| 47 |
+
BLOCK_M: tl.constexpr,
|
| 48 |
+
BLOCK_N: tl.constexpr,
|
| 49 |
+
):
|
| 50 |
+
off_bh = tl.program_id(1)
|
| 51 |
+
off_h = off_bh % nheads
|
| 52 |
+
off_b = off_bh // nheads
|
| 53 |
+
|
| 54 |
+
start_n = tl.program_id(0)
|
| 55 |
+
# determine which window the current KV block belongs to
|
| 56 |
+
offs_w = (start_n * BLOCK_N) // WINDOW_SIZE
|
| 57 |
+
offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
| 58 |
+
offs_m = tl.arange(0, BLOCK_M)
|
| 59 |
+
offs_d = tl.arange(0, BLOCK_HEADDIM)
|
| 60 |
+
|
| 61 |
+
# initialize pointers
|
| 62 |
+
q_ptrs = (
|
| 63 |
+
Q +
|
| 64 |
+
off_b * stride_qb +
|
| 65 |
+
off_h * stride_qh +
|
| 66 |
+
offs_m[:, None] * stride_qm + offs_d[None, :]
|
| 67 |
+
)
|
| 68 |
+
k_ptrs = (
|
| 69 |
+
K +
|
| 70 |
+
off_b * stride_kb +
|
| 71 |
+
off_h * stride_kh +
|
| 72 |
+
offs_n[:, None] * stride_kn + offs_d[None, :]
|
| 73 |
+
)
|
| 74 |
+
v_ptrs = (
|
| 75 |
+
V +
|
| 76 |
+
off_b * stride_vb +
|
| 77 |
+
off_h * stride_vh +
|
| 78 |
+
offs_n[:, None] * stride_vn + offs_d[None, :]
|
| 79 |
+
)
|
| 80 |
+
do_ptrs = (
|
| 81 |
+
DO +
|
| 82 |
+
off_b * stride_do_b +
|
| 83 |
+
off_h * stride_do_h +
|
| 84 |
+
offs_m[:, None] * stride_do_m + offs_d[None, :]
|
| 85 |
+
)
|
| 86 |
+
do_t_o_ptrs = (
|
| 87 |
+
DO_T_O +
|
| 88 |
+
off_b * stride_do_t_o_b +
|
| 89 |
+
off_h * stride_do_t_o_h +
|
| 90 |
+
offs_m[:, None]
|
| 91 |
+
)
|
| 92 |
+
lse_ptrs = (
|
| 93 |
+
LSE +
|
| 94 |
+
off_b * stride_lse_b +
|
| 95 |
+
off_h * stride_lse_h +
|
| 96 |
+
offs_m[:, None]
|
| 97 |
+
)
|
| 98 |
+
if MASK_TYPE == 1:
|
| 99 |
+
m_ptrs = (
|
| 100 |
+
WindowMask +
|
| 101 |
+
off_b * stride_window_mask_b +
|
| 102 |
+
(offs_m[:, None] * stride_window_mask_m + offs_n[None, :])
|
| 103 |
+
)
|
| 104 |
+
dk_ptrs = (
|
| 105 |
+
DK +
|
| 106 |
+
off_b * stride_dk_b +
|
| 107 |
+
off_h * stride_dk_h +
|
| 108 |
+
offs_n[:, None] * stride_dk_n + offs_d[None, :]
|
| 109 |
+
)
|
| 110 |
+
dv_ptrs = (
|
| 111 |
+
DV +
|
| 112 |
+
off_b * stride_dv_b +
|
| 113 |
+
off_h * stride_dv_h +
|
| 114 |
+
offs_n[:, None] * stride_dv_n + offs_d[None, :]
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
# 1. for singletons
|
| 118 |
+
# determine start and end of query block
|
| 119 |
+
begin_m = ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M
|
| 120 |
+
end_m = tl.minimum((offs_w + 1) * WINDOW_SIZE, seqlen_q)
|
| 121 |
+
|
| 122 |
+
dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
|
| 123 |
+
dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
|
| 124 |
+
if EVEN_N & EVEN_M:
|
| 125 |
+
if EVEN_HEADDIM:
|
| 126 |
+
k = tl.load(k_ptrs)
|
| 127 |
+
v = tl.load(v_ptrs)
|
| 128 |
+
else:
|
| 129 |
+
k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
|
| 130 |
+
v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
|
| 131 |
+
else:
|
| 132 |
+
if EVEN_HEADDIM:
|
| 133 |
+
k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
|
| 134 |
+
v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
|
| 135 |
+
else:
|
| 136 |
+
k = tl.load(
|
| 137 |
+
k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0
|
| 138 |
+
)
|
| 139 |
+
v = tl.load(
|
| 140 |
+
v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0
|
| 141 |
+
)
|
| 142 |
+
for start_m in range(begin_m, end_m, BLOCK_M):
|
| 143 |
+
start_m = tl.multiple_of(start_m, BLOCK_M)
|
| 144 |
+
# load q, do, and lse
|
| 145 |
+
if EVEN_M & EVEN_N:
|
| 146 |
+
if EVEN_HEADDIM:
|
| 147 |
+
q = tl.load(
|
| 148 |
+
q_ptrs + start_m * stride_qm
|
| 149 |
+
)
|
| 150 |
+
do = tl.load(
|
| 151 |
+
do_ptrs + start_m * stride_do_m
|
| 152 |
+
)
|
| 153 |
+
else:
|
| 154 |
+
q = tl.load(
|
| 155 |
+
q_ptrs + start_m * stride_qm,
|
| 156 |
+
mask=offs_d[None, :] < headdim,
|
| 157 |
+
other=0.0
|
| 158 |
+
)
|
| 159 |
+
do = tl.load(
|
| 160 |
+
do_ptrs + start_m * stride_do_m,
|
| 161 |
+
mask=offs_d[None, :] < headdim,
|
| 162 |
+
other=0.0
|
| 163 |
+
)
|
| 164 |
+
do_t_o = tl.load(
|
| 165 |
+
do_t_o_ptrs + start_m
|
| 166 |
+
)
|
| 167 |
+
lse = tl.load(
|
| 168 |
+
lse_ptrs + start_m
|
| 169 |
+
)
|
| 170 |
+
else:
|
| 171 |
+
if EVEN_HEADDIM:
|
| 172 |
+
q = tl.load(
|
| 173 |
+
q_ptrs + start_m * stride_qm,
|
| 174 |
+
mask=(start_m + offs_m)[:, None] < seqlen_q,
|
| 175 |
+
other=0.0
|
| 176 |
+
)
|
| 177 |
+
do = tl.load(
|
| 178 |
+
do_ptrs + start_m * stride_do_m,
|
| 179 |
+
mask=(start_m + offs_m)[:, None] < seqlen_q,
|
| 180 |
+
other=0.0
|
| 181 |
+
)
|
| 182 |
+
else:
|
| 183 |
+
q = tl.load(
|
| 184 |
+
q_ptrs + start_m * stride_qm,
|
| 185 |
+
mask=((start_m + offs_m)[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
|
| 186 |
+
other=0.0
|
| 187 |
+
)
|
| 188 |
+
do = tl.load(
|
| 189 |
+
do_ptrs + start_m * stride_do_m,
|
| 190 |
+
mask=((start_m + offs_m)[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
|
| 191 |
+
other=0.0
|
| 192 |
+
)
|
| 193 |
+
do_t_o = tl.load(
|
| 194 |
+
do_t_o_ptrs + start_m,
|
| 195 |
+
mask=(start_m + offs_m)[:, None] < seqlen_q,
|
| 196 |
+
other=0.0
|
| 197 |
+
)
|
| 198 |
+
lse = tl.load(
|
| 199 |
+
lse_ptrs + start_m,
|
| 200 |
+
mask=(start_m + offs_m)[:, None] < seqlen_q,
|
| 201 |
+
other=0.0
|
| 202 |
+
)
|
| 203 |
+
lse = tl.where(lse == float("-inf"), 0.0, lse)
|
| 204 |
+
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
| 205 |
+
qk += tl.dot(q, tl.trans(k))
|
| 206 |
+
if not EVEN_M:
|
| 207 |
+
qk += tl.where((start_m + offs_m)[:, None] < seqlen_q, 0, float("-inf"))
|
| 208 |
+
|
| 209 |
+
if MASK_TYPE == 1:
|
| 210 |
+
if EVEN_M & EVEN_W:
|
| 211 |
+
mask = tl.load(
|
| 212 |
+
m_ptrs + (start_m * stride_window_mask_m) - (offs_w * WINDOW_SIZE)
|
| 213 |
+
)
|
| 214 |
+
else:
|
| 215 |
+
mask = tl.load(
|
| 216 |
+
m_ptrs + (start_m * stride_window_mask_m) - (offs_w * WINDOW_SIZE),
|
| 217 |
+
mask=((start_m + offs_m)[:, None] < seqlen_q)
|
| 218 |
+
& (((start_m * stride_window_mask_m) - (offs_w * WINDOW_SIZE) + offs_n)[None, :] < WINDOW_SIZE),
|
| 219 |
+
other=1,
|
| 220 |
+
)
|
| 221 |
+
# Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
|
| 222 |
+
# can then fuse the mult and add into an fma instruction. But if we have bias we need to
|
| 223 |
+
# to multiply with softmax_scale here.
|
| 224 |
+
# we assume mask already implies the causal masking
|
| 225 |
+
qk = qk * softmax_scale
|
| 226 |
+
qk = tl.where(mask, float("-inf"), qk)
|
| 227 |
+
p = tl.exp(qk - lse)
|
| 228 |
+
else:
|
| 229 |
+
qk += tl.where((start_m + offs_m)[:, None] >= offs_n[None, :], 0, float("-inf"))
|
| 230 |
+
p = tl.exp(qk * softmax_scale - lse)
|
| 231 |
+
|
| 232 |
+
# dp [M, N]
|
| 233 |
+
dp = tl.dot(do, tl.trans(v))
|
| 234 |
+
# p [M, N], dp [M, N], do_t_o [M, 1] -> ds [M, N]
|
| 235 |
+
ds = (p * (dp - do_t_o) * softmax_scale).to(q.dtype)
|
| 236 |
+
# p is fp32 and [M, N], convert to q.dtype
|
| 237 |
+
# do [M, D] -> dv [N, D]
|
| 238 |
+
dv += tl.dot(tl.trans(p.to(do.dtype)), do)
|
| 239 |
+
# dk [N, D]
|
| 240 |
+
dk += tl.dot(tl.trans(ds), q)
|
| 241 |
+
if EVEN_N & EVEN_M:
|
| 242 |
+
if EVEN_HEADDIM:
|
| 243 |
+
tl.store(dv_ptrs, dv)
|
| 244 |
+
tl.store(dk_ptrs, dk)
|
| 245 |
+
else:
|
| 246 |
+
tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
|
| 247 |
+
tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
|
| 248 |
+
else:
|
| 249 |
+
if EVEN_HEADDIM:
|
| 250 |
+
tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
|
| 251 |
+
tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
|
| 252 |
+
else:
|
| 253 |
+
tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
|
| 254 |
+
tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
|
| 255 |
+
|
| 256 |
+
@triton.heuristics(
|
| 257 |
+
{
|
| 258 |
+
"EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
|
| 259 |
+
"EVEN_C": lambda args: args["nchunks"] % args["BLOCK_N"] == 0,
|
| 260 |
+
"EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
|
| 261 |
+
}
|
| 262 |
+
)
|
| 263 |
+
@triton.jit
|
| 264 |
+
def _bwd_eva_agg_kernel_drfa_kv(
|
| 265 |
+
Q,
|
| 266 |
+
RFA_K,
|
| 267 |
+
RFA_V,
|
| 268 |
+
ChunkMask,
|
| 269 |
+
DO,
|
| 270 |
+
LSE,
|
| 271 |
+
DO_T_O,
|
| 272 |
+
D_RFA_K,
|
| 273 |
+
D_RFA_V,
|
| 274 |
+
softmax_scale,
|
| 275 |
+
stride_qb, stride_qh, stride_qm,
|
| 276 |
+
stride_rfa_kb, stride_rfa_kh, stride_rfa_kc,
|
| 277 |
+
stride_rfa_vb, stride_rfa_vh, stride_rfa_vc,
|
| 278 |
+
stride_chunk_mask_b, stride_chunk_mask_m,
|
| 279 |
+
stride_do_b, stride_do_h, stride_do_m,
|
| 280 |
+
stride_lse_b, stride_lse_h,
|
| 281 |
+
stride_do_t_o_b, stride_do_t_o_h,
|
| 282 |
+
stride_d_rfa_k_b, stride_d_rfa_k_h, stride_d_rfa_k_c,
|
| 283 |
+
stride_d_rfa_v_b, stride_d_rfa_v_h, stride_d_rfa_v_c,
|
| 284 |
+
nheads,
|
| 285 |
+
seqlen_q,
|
| 286 |
+
nchunks,
|
| 287 |
+
headdim,
|
| 288 |
+
CHUNKS_PER_WINDOW: tl.constexpr,
|
| 289 |
+
WINDOW_SIZE: tl.constexpr,
|
| 290 |
+
MASK_TYPE: tl.constexpr,
|
| 291 |
+
BLOCK_HEADDIM: tl.constexpr,
|
| 292 |
+
EVEN_M: tl.constexpr,
|
| 293 |
+
EVEN_C: tl.constexpr,
|
| 294 |
+
EVEN_HEADDIM: tl.constexpr,
|
| 295 |
+
BLOCK_M: tl.constexpr,
|
| 296 |
+
BLOCK_N: tl.constexpr,
|
| 297 |
+
):
|
| 298 |
+
off_bh = tl.program_id(1)
|
| 299 |
+
off_h = off_bh % nheads
|
| 300 |
+
off_b = off_bh // nheads
|
| 301 |
+
start_c = tl.program_id(0)
|
| 302 |
+
# there are 128 chunks per window
|
| 303 |
+
offs_c = start_c * BLOCK_N + tl.arange(0, BLOCK_N)
|
| 304 |
+
# determine which window the current KV block belongs to
|
| 305 |
+
offs_w = (start_c * BLOCK_N) // CHUNKS_PER_WINDOW
|
| 306 |
+
offs_m = tl.arange(0, BLOCK_M)
|
| 307 |
+
offs_d = tl.arange(0, BLOCK_HEADDIM)
|
| 308 |
+
|
| 309 |
+
# initialize pointers
|
| 310 |
+
q_ptrs = (
|
| 311 |
+
Q +
|
| 312 |
+
off_b * stride_qb +
|
| 313 |
+
off_h * stride_qh +
|
| 314 |
+
(offs_m[:, None] * stride_qm + offs_d[None, :])
|
| 315 |
+
)
|
| 316 |
+
do_ptrs = (
|
| 317 |
+
DO +
|
| 318 |
+
off_b * stride_do_b +
|
| 319 |
+
off_h * stride_do_h +
|
| 320 |
+
(offs_m[:, None] * stride_do_m + offs_d[None, :])
|
| 321 |
+
)
|
| 322 |
+
do_t_o_ptrs = (
|
| 323 |
+
DO_T_O +
|
| 324 |
+
off_b * stride_do_t_o_b +
|
| 325 |
+
off_h * stride_do_t_o_h +
|
| 326 |
+
(offs_m[:, None])
|
| 327 |
+
)
|
| 328 |
+
lse_ptrs = (
|
| 329 |
+
LSE +
|
| 330 |
+
off_b * stride_lse_b +
|
| 331 |
+
off_h * stride_lse_h +
|
| 332 |
+
(offs_m[:, None])
|
| 333 |
+
)
|
| 334 |
+
rfa_k_ptrs = (
|
| 335 |
+
RFA_K +
|
| 336 |
+
off_b * stride_rfa_kb +
|
| 337 |
+
off_h * stride_rfa_kh +
|
| 338 |
+
(offs_c[:, None] * stride_rfa_kc + offs_d[None, :])
|
| 339 |
+
)
|
| 340 |
+
rfa_v_ptrs = (
|
| 341 |
+
RFA_V +
|
| 342 |
+
off_b * stride_rfa_vb +
|
| 343 |
+
off_h * stride_rfa_vh +
|
| 344 |
+
(offs_c[:, None] * stride_rfa_vc + offs_d[None, :])
|
| 345 |
+
)
|
| 346 |
+
if MASK_TYPE == 1:
|
| 347 |
+
rfa_m_ptrs = (
|
| 348 |
+
ChunkMask +
|
| 349 |
+
off_b * stride_chunk_mask_b +
|
| 350 |
+
(offs_m[:, None] * stride_chunk_mask_m + offs_c[None, :])
|
| 351 |
+
)
|
| 352 |
+
d_rfa_k_ptrs = (
|
| 353 |
+
D_RFA_K +
|
| 354 |
+
off_b * stride_d_rfa_k_b +
|
| 355 |
+
off_h * stride_d_rfa_k_h +
|
| 356 |
+
(offs_c[:, None] * stride_d_rfa_k_c + offs_d[None, :])
|
| 357 |
+
)
|
| 358 |
+
d_rfa_v_ptrs = (
|
| 359 |
+
D_RFA_V +
|
| 360 |
+
off_b * stride_d_rfa_v_b +
|
| 361 |
+
off_h * stride_d_rfa_v_h +
|
| 362 |
+
(offs_c[:, None] * stride_d_rfa_v_c + offs_d[None, :])
|
| 363 |
+
)
|
| 364 |
+
|
| 365 |
+
d_rfa_k = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
|
| 366 |
+
d_rfa_v = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
|
| 367 |
+
if EVEN_C & EVEN_M:
|
| 368 |
+
if EVEN_HEADDIM:
|
| 369 |
+
rfa_k = tl.load(rfa_k_ptrs)
|
| 370 |
+
rfa_v = tl.load(rfa_v_ptrs)
|
| 371 |
+
else:
|
| 372 |
+
rfa_k = tl.load(rfa_k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
|
| 373 |
+
rfa_v = tl.load(rfa_v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
|
| 374 |
+
else:
|
| 375 |
+
if EVEN_HEADDIM:
|
| 376 |
+
rfa_k = tl.load(rfa_k_ptrs, mask=offs_c[:, None] < nchunks, other=0.0)
|
| 377 |
+
rfa_v = tl.load(rfa_v_ptrs, mask=offs_c[:, None] < nchunks, other=0.0)
|
| 378 |
+
else:
|
| 379 |
+
rfa_k = tl.load(
|
| 380 |
+
rfa_k_ptrs, mask=(offs_c[:, None] < nchunks) & (offs_d[None, :] < headdim), other=0.0
|
| 381 |
+
)
|
| 382 |
+
rfa_v = tl.load(
|
| 383 |
+
rfa_v_ptrs, mask=(offs_c[:, None] < nchunks) & (offs_d[None, :] < headdim), other=0.0
|
| 384 |
+
)
|
| 385 |
+
begin_m = tl.minimum((offs_w + 1) * WINDOW_SIZE, seqlen_q)
|
| 386 |
+
end_m = seqlen_q
|
| 387 |
+
for start_m in range(begin_m, end_m, BLOCK_M):
|
| 388 |
+
start_m = tl.multiple_of(start_m, BLOCK_M)
|
| 389 |
+
# load q, do, and lse
|
| 390 |
+
if EVEN_M:
|
| 391 |
+
if EVEN_HEADDIM:
|
| 392 |
+
q = tl.load(
|
| 393 |
+
q_ptrs + start_m * stride_qm
|
| 394 |
+
)
|
| 395 |
+
do = tl.load(
|
| 396 |
+
do_ptrs + start_m * stride_do_m
|
| 397 |
+
)
|
| 398 |
+
else:
|
| 399 |
+
q = tl.load(
|
| 400 |
+
q_ptrs + start_m * stride_qm,
|
| 401 |
+
mask=offs_d[None, :] < headdim,
|
| 402 |
+
other=0.0
|
| 403 |
+
)
|
| 404 |
+
do = tl.load(
|
| 405 |
+
do_ptrs + start_m * stride_do_m,
|
| 406 |
+
mask=offs_d[None, :] < headdim,
|
| 407 |
+
other=0.0
|
| 408 |
+
)
|
| 409 |
+
do_t_o = tl.load(
|
| 410 |
+
do_t_o_ptrs + start_m
|
| 411 |
+
)
|
| 412 |
+
lse = tl.load(
|
| 413 |
+
lse_ptrs + start_m
|
| 414 |
+
)
|
| 415 |
+
else:
|
| 416 |
+
if EVEN_HEADDIM:
|
| 417 |
+
q = tl.load(
|
| 418 |
+
q_ptrs + start_m * stride_qm,
|
| 419 |
+
mask=(start_m + offs_m)[:, None] < seqlen_q,
|
| 420 |
+
other=0.0
|
| 421 |
+
)
|
| 422 |
+
do = tl.load(
|
| 423 |
+
do_ptrs + start_m * stride_do_m,
|
| 424 |
+
mask=(start_m + offs_m)[:, None] < seqlen_q,
|
| 425 |
+
other=0.0
|
| 426 |
+
)
|
| 427 |
+
else:
|
| 428 |
+
q = tl.load(
|
| 429 |
+
q_ptrs + start_m * stride_qm,
|
| 430 |
+
mask=((start_m + offs_m)[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
|
| 431 |
+
other=0.0
|
| 432 |
+
)
|
| 433 |
+
do = tl.load(
|
| 434 |
+
do_ptrs + start_m * stride_do_m,
|
| 435 |
+
mask=((start_m + offs_m)[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
|
| 436 |
+
other=0.0
|
| 437 |
+
)
|
| 438 |
+
do_t_o = tl.load(
|
| 439 |
+
do_t_o_ptrs + start_m,
|
| 440 |
+
mask=(start_m + offs_m)[:, None] < seqlen_q,
|
| 441 |
+
other=0.0
|
| 442 |
+
)
|
| 443 |
+
lse = tl.load(
|
| 444 |
+
lse_ptrs + start_m,
|
| 445 |
+
mask=(start_m + offs_m)[:, None] < seqlen_q,
|
| 446 |
+
other=0.0
|
| 447 |
+
)
|
| 448 |
+
lse = tl.where(lse == float("-inf"), 0.0, lse)
|
| 449 |
+
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
| 450 |
+
qk += tl.dot(q, tl.trans(rfa_k))
|
| 451 |
+
if not EVEN_M:
|
| 452 |
+
qk += tl.where((start_m + offs_m)[:, None] < seqlen_q, 0, float("-inf"))
|
| 453 |
+
|
| 454 |
+
if MASK_TYPE == 1:
|
| 455 |
+
if EVEN_M & EVEN_C:
|
| 456 |
+
mask = tl.load(
|
| 457 |
+
rfa_m_ptrs + (start_m * stride_chunk_mask_m)
|
| 458 |
+
)
|
| 459 |
+
else:
|
| 460 |
+
mask = tl.load(
|
| 461 |
+
rfa_m_ptrs + (start_m * stride_chunk_mask_m),
|
| 462 |
+
mask=((start_m + offs_m)[:, None] < seqlen_q)
|
| 463 |
+
& (offs_c[None, :] < nchunks),
|
| 464 |
+
other=1,
|
| 465 |
+
)
|
| 466 |
+
# Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
|
| 467 |
+
# can then fuse the mult and add into an fma instruction. But if we have bias we need to
|
| 468 |
+
# to multiply with softmax_scale here.
|
| 469 |
+
# we assume mask already implies the causal masking
|
| 470 |
+
qk = qk * softmax_scale
|
| 471 |
+
qk = tl.where(mask, float("-inf"), qk)
|
| 472 |
+
p = tl.exp(qk - lse)
|
| 473 |
+
else:
|
| 474 |
+
p = tl.exp(qk * softmax_scale - lse)
|
| 475 |
+
|
| 476 |
+
dp = tl.dot(do, tl.trans(rfa_v))
|
| 477 |
+
ds = (p * (dp - do_t_o) * softmax_scale).to(q.dtype)
|
| 478 |
+
# p is fp32, convert to q.dtype
|
| 479 |
+
d_rfa_v += tl.dot(tl.trans(p.to(do.dtype)), do)
|
| 480 |
+
# move softmax_scale to ds to save computation
|
| 481 |
+
d_rfa_k += tl.dot(tl.trans(ds), q)
|
| 482 |
+
if EVEN_C & EVEN_M:
|
| 483 |
+
if EVEN_HEADDIM:
|
| 484 |
+
tl.store(d_rfa_v_ptrs, d_rfa_v)
|
| 485 |
+
tl.store(d_rfa_k_ptrs, d_rfa_k)
|
| 486 |
+
else:
|
| 487 |
+
tl.store(d_rfa_v_ptrs, d_rfa_v, mask=offs_d[None, :] < headdim)
|
| 488 |
+
tl.store(d_rfa_k_ptrs, d_rfa_k, mask=offs_d[None, :] < headdim)
|
| 489 |
+
else:
|
| 490 |
+
if EVEN_HEADDIM:
|
| 491 |
+
tl.store(d_rfa_v_ptrs, d_rfa_v, mask=offs_c[:, None] < nchunks)
|
| 492 |
+
tl.store(d_rfa_k_ptrs, d_rfa_k, mask=offs_c[:, None] < nchunks)
|
| 493 |
+
else:
|
| 494 |
+
tl.store(d_rfa_v_ptrs, d_rfa_v, mask=(offs_c[:, None] < nchunks) & (offs_d[None, :] < headdim))
|
| 495 |
+
tl.store(d_rfa_k_ptrs, d_rfa_k, mask=(offs_c[:, None] < nchunks) & (offs_d[None, :] < headdim))
|
| 496 |
+
|
| 497 |
+
@triton.heuristics(
|
| 498 |
+
{
|
| 499 |
+
"EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
|
| 500 |
+
"EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
|
| 501 |
+
"EVEN_C": lambda args: args["nchunks"] % args["BLOCK_N"] == 0,
|
| 502 |
+
"EVEN_W": lambda args: args["WINDOW_SIZE"] % args["BLOCK_N"] == 0,
|
| 503 |
+
"EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
|
| 504 |
+
}
|
| 505 |
+
)
|
| 506 |
+
@triton.jit
|
| 507 |
+
def _bwd_eva_agg_kernel_dq(
|
| 508 |
+
Q,
|
| 509 |
+
K,
|
| 510 |
+
V,
|
| 511 |
+
RFA_K,
|
| 512 |
+
RFA_V,
|
| 513 |
+
WindowMask,
|
| 514 |
+
ChunkMask,
|
| 515 |
+
DO,
|
| 516 |
+
LSE,
|
| 517 |
+
DO_T_O,
|
| 518 |
+
DQ,
|
| 519 |
+
softmax_scale,
|
| 520 |
+
stride_qb, stride_qh, stride_qm,
|
| 521 |
+
stride_kb, stride_kh, stride_kn,
|
| 522 |
+
stride_vb, stride_vh, stride_vn,
|
| 523 |
+
stride_rfa_kb, stride_rfa_kh, stride_rfa_kc,
|
| 524 |
+
stride_rfa_vb, stride_rfa_vh, stride_rfa_vc,
|
| 525 |
+
stride_window_mask_b, stride_window_mask_m,
|
| 526 |
+
stride_chunk_mask_b, stride_chunk_mask_m,
|
| 527 |
+
stride_do_b, stride_do_h, stride_do_m,
|
| 528 |
+
stride_lse_b, stride_lse_h,
|
| 529 |
+
stride_do_t_o_b, stride_do_t_o_h,
|
| 530 |
+
stride_dq_b, stride_dq_h, stride_dq_m,
|
| 531 |
+
nheads,
|
| 532 |
+
seqlen_q,
|
| 533 |
+
seqlen_k,
|
| 534 |
+
nchunks,
|
| 535 |
+
headdim,
|
| 536 |
+
CHUNKS_PER_WINDOW: tl.constexpr,
|
| 537 |
+
WINDOW_SIZE: tl.constexpr,
|
| 538 |
+
MASK_TYPE: tl.constexpr,
|
| 539 |
+
EMPTY_RFA_KV: tl.constexpr,
|
| 540 |
+
BLOCK_HEADDIM: tl.constexpr,
|
| 541 |
+
EVEN_M: tl.constexpr,
|
| 542 |
+
EVEN_N: tl.constexpr,
|
| 543 |
+
EVEN_W: tl.constexpr,
|
| 544 |
+
EVEN_C: tl.constexpr,
|
| 545 |
+
EVEN_HEADDIM: tl.constexpr,
|
| 546 |
+
BLOCK_M: tl.constexpr,
|
| 547 |
+
BLOCK_N: tl.constexpr,
|
| 548 |
+
):
|
| 549 |
+
start_m = tl.program_id(0)
|
| 550 |
+
off_bh = tl.program_id(1)
|
| 551 |
+
off_h = off_bh % nheads
|
| 552 |
+
off_b = off_bh // nheads
|
| 553 |
+
# initialize offsets
|
| 554 |
+
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 555 |
+
offs_w = (start_m * BLOCK_M) // WINDOW_SIZE
|
| 556 |
+
offs_n = tl.arange(0, BLOCK_N)
|
| 557 |
+
offs_c = tl.arange(0, BLOCK_N)
|
| 558 |
+
offs_d = tl.arange(0, BLOCK_HEADDIM)
|
| 559 |
+
# TODO: add paratheses or not
|
| 560 |
+
q_ptrs = (
|
| 561 |
+
Q +
|
| 562 |
+
off_b * stride_qb +
|
| 563 |
+
off_h * stride_qh +
|
| 564 |
+
(offs_m[:, None] * stride_qm + offs_d[None, :])
|
| 565 |
+
)
|
| 566 |
+
k_ptrs = (
|
| 567 |
+
K +
|
| 568 |
+
off_b * stride_kb +
|
| 569 |
+
off_h * stride_kh +
|
| 570 |
+
(offs_n[:, None] * stride_kn + offs_d[None, :])
|
| 571 |
+
)
|
| 572 |
+
v_ptrs = (
|
| 573 |
+
V +
|
| 574 |
+
off_b * stride_vb +
|
| 575 |
+
off_h * stride_vh +
|
| 576 |
+
(offs_n[:, None] * stride_vn + offs_d[None, :])
|
| 577 |
+
)
|
| 578 |
+
if EMPTY_RFA_KV == 0:
|
| 579 |
+
rfa_k_ptrs = (
|
| 580 |
+
RFA_K +
|
| 581 |
+
off_b * stride_rfa_kb +
|
| 582 |
+
off_h * stride_rfa_kh +
|
| 583 |
+
(offs_c[:, None] * stride_rfa_kc + offs_d[None, :])
|
| 584 |
+
)
|
| 585 |
+
rfa_v_ptrs = (
|
| 586 |
+
RFA_V +
|
| 587 |
+
off_b * stride_rfa_vb +
|
| 588 |
+
off_h * stride_rfa_vh +
|
| 589 |
+
(offs_c[:, None] * stride_rfa_vc + offs_d[None, :])
|
| 590 |
+
)
|
| 591 |
+
dq_ptrs = (
|
| 592 |
+
DQ +
|
| 593 |
+
off_b * stride_dq_b +
|
| 594 |
+
off_h * stride_dq_h +
|
| 595 |
+
(offs_m[:, None] * stride_dq_m + offs_d[None, :])
|
| 596 |
+
)
|
| 597 |
+
do_ptrs = (
|
| 598 |
+
DO +
|
| 599 |
+
off_b * stride_do_b +
|
| 600 |
+
off_h * stride_do_h +
|
| 601 |
+
(offs_m[:, None] * stride_do_m + offs_d[None, :])
|
| 602 |
+
)
|
| 603 |
+
do_t_o_ptrs = (
|
| 604 |
+
DO_T_O +
|
| 605 |
+
off_b * stride_do_t_o_b +
|
| 606 |
+
off_h * stride_do_t_o_h +
|
| 607 |
+
offs_m[:, None]
|
| 608 |
+
)
|
| 609 |
+
lse_ptrs = (
|
| 610 |
+
LSE +
|
| 611 |
+
off_b * stride_lse_b +
|
| 612 |
+
off_h * stride_lse_h +
|
| 613 |
+
offs_m[:, None]
|
| 614 |
+
)
|
| 615 |
+
### load q, do, do_t_o, lse ####
|
| 616 |
+
if EVEN_M:
|
| 617 |
+
if EVEN_HEADDIM:
|
| 618 |
+
q = tl.load(
|
| 619 |
+
q_ptrs
|
| 620 |
+
)
|
| 621 |
+
do = tl.load(
|
| 622 |
+
do_ptrs
|
| 623 |
+
)
|
| 624 |
+
else:
|
| 625 |
+
q = tl.load(
|
| 626 |
+
q_ptrs,
|
| 627 |
+
mask=offs_d[None, :] < headdim,
|
| 628 |
+
other=0.0
|
| 629 |
+
)
|
| 630 |
+
do = tl.load(
|
| 631 |
+
do_ptrs,
|
| 632 |
+
mask=offs_d[None, :] < headdim,
|
| 633 |
+
other=0.0
|
| 634 |
+
)
|
| 635 |
+
do_t_o = tl.load(
|
| 636 |
+
do_t_o_ptrs
|
| 637 |
+
)
|
| 638 |
+
lse = tl.load(
|
| 639 |
+
lse_ptrs
|
| 640 |
+
)
|
| 641 |
+
else:
|
| 642 |
+
if EVEN_HEADDIM:
|
| 643 |
+
q = tl.load(
|
| 644 |
+
q_ptrs,
|
| 645 |
+
mask=offs_m[:, None] < seqlen_q,
|
| 646 |
+
other=0.0
|
| 647 |
+
)
|
| 648 |
+
do = tl.load(
|
| 649 |
+
do_ptrs,
|
| 650 |
+
mask=offs_m[:, None] < seqlen_q,
|
| 651 |
+
other=0.0
|
| 652 |
+
)
|
| 653 |
+
else:
|
| 654 |
+
q = tl.load(
|
| 655 |
+
q_ptrs,
|
| 656 |
+
mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
|
| 657 |
+
other=0.0
|
| 658 |
+
)
|
| 659 |
+
do = tl.load(
|
| 660 |
+
do_ptrs,
|
| 661 |
+
mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
|
| 662 |
+
other=0.0
|
| 663 |
+
)
|
| 664 |
+
do_t_o = tl.load(
|
| 665 |
+
do_t_o_ptrs,
|
| 666 |
+
mask=offs_m[:, None] < seqlen_q,
|
| 667 |
+
other=0.0
|
| 668 |
+
)
|
| 669 |
+
lse = tl.load(
|
| 670 |
+
lse_ptrs,
|
| 671 |
+
mask=offs_m[:, None] < seqlen_q,
|
| 672 |
+
other=0.0
|
| 673 |
+
)
|
| 674 |
+
lse = tl.where(lse == float("-inf"), 0.0, lse)
|
| 675 |
+
lse *= 1.4426950408889634 # log2(e)
|
| 676 |
+
qk_scale = softmax_scale
|
| 677 |
+
qk_scale *= 1.4426950408889634 # log2(e)
|
| 678 |
+
if MASK_TYPE == 1:
|
| 679 |
+
window_mask_ptrs = (
|
| 680 |
+
WindowMask +
|
| 681 |
+
off_b * stride_window_mask_b +
|
| 682 |
+
(offs_m[:, None] * stride_window_mask_m + offs_n[None, :])
|
| 683 |
+
)
|
| 684 |
+
if EMPTY_RFA_KV == 0:
|
| 685 |
+
chunk_mask_ptrs = (
|
| 686 |
+
ChunkMask +
|
| 687 |
+
off_b * stride_chunk_mask_b +
|
| 688 |
+
(offs_m[:, None] * stride_chunk_mask_m + offs_c[None, :])
|
| 689 |
+
)
|
| 690 |
+
|
| 691 |
+
dq = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
|
| 692 |
+
# loop over k, v and update accumulator
|
| 693 |
+
# Iterate over local singletons;
|
| 694 |
+
# so we only iterate over blocks within the current window
|
| 695 |
+
start_idx_n = offs_w * WINDOW_SIZE
|
| 696 |
+
end_idx_n = tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
|
| 697 |
+
for start_n in range(start_idx_n, end_idx_n, BLOCK_N):
|
| 698 |
+
start_n = tl.multiple_of(start_n, BLOCK_N)
|
| 699 |
+
if EVEN_N & EVEN_M:
|
| 700 |
+
if EVEN_HEADDIM:
|
| 701 |
+
k = tl.load(
|
| 702 |
+
k_ptrs + start_n * stride_kn
|
| 703 |
+
)
|
| 704 |
+
else:
|
| 705 |
+
k = tl.load(
|
| 706 |
+
k_ptrs + start_n * stride_kn,
|
| 707 |
+
mask=offs_d[None, :] < headdim,
|
| 708 |
+
other=0.0
|
| 709 |
+
)
|
| 710 |
+
else:
|
| 711 |
+
if EVEN_HEADDIM:
|
| 712 |
+
k = tl.load(
|
| 713 |
+
k_ptrs + start_n * stride_kn,
|
| 714 |
+
mask=(start_n + offs_n)[:, None] < seqlen_k,
|
| 715 |
+
other=0.0,
|
| 716 |
+
)
|
| 717 |
+
else:
|
| 718 |
+
k = tl.load(
|
| 719 |
+
k_ptrs + start_n * stride_kn,
|
| 720 |
+
mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
|
| 721 |
+
other=0.0,
|
| 722 |
+
)
|
| 723 |
+
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
| 724 |
+
qk += tl.dot(q, tl.trans(k))
|
| 725 |
+
# Trying to combine the two masks seem to make the result wrong
|
| 726 |
+
if not EVEN_N: # Need to mask out otherwise the softmax is wrong
|
| 727 |
+
qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
|
| 728 |
+
|
| 729 |
+
if MASK_TYPE == 1:
|
| 730 |
+
if EVEN_M & EVEN_W:
|
| 731 |
+
window_mask = tl.load(
|
| 732 |
+
window_mask_ptrs + start_n - start_idx_n
|
| 733 |
+
)
|
| 734 |
+
else:
|
| 735 |
+
window_mask = tl.load(
|
| 736 |
+
window_mask_ptrs + start_n - start_idx_n,
|
| 737 |
+
mask=(offs_m[:, None] < seqlen_q)
|
| 738 |
+
& ((start_n - start_idx_n + offs_n)[None, :] < WINDOW_SIZE),
|
| 739 |
+
other=1,
|
| 740 |
+
)
|
| 741 |
+
# Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
|
| 742 |
+
# can then fuse the mult and add into an fma instruction. But if we have bias we need to
|
| 743 |
+
# to multiply with softmax_scale here.
|
| 744 |
+
# we assume mask already implies the causal masking
|
| 745 |
+
qk = qk * qk_scale
|
| 746 |
+
qk = tl.where(window_mask, float("-inf"), qk)
|
| 747 |
+
p = tl.exp2(qk - lse)
|
| 748 |
+
else:
|
| 749 |
+
qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
|
| 750 |
+
p = tl.exp2(qk * qk_scale - lse)
|
| 751 |
+
|
| 752 |
+
if EVEN_N & EVEN_M:
|
| 753 |
+
if EVEN_HEADDIM:
|
| 754 |
+
v = tl.load(
|
| 755 |
+
v_ptrs + start_n * stride_vn
|
| 756 |
+
)
|
| 757 |
+
else:
|
| 758 |
+
v = tl.load(
|
| 759 |
+
v_ptrs + start_n * stride_vn,
|
| 760 |
+
mask=offs_d[None, :] < headdim,
|
| 761 |
+
other=0.0
|
| 762 |
+
)
|
| 763 |
+
else:
|
| 764 |
+
if EVEN_HEADDIM:
|
| 765 |
+
v = tl.load(
|
| 766 |
+
v_ptrs + start_n * stride_vn,
|
| 767 |
+
mask=(start_n + offs_n)[:, None] < seqlen_k,
|
| 768 |
+
other=0.0,
|
| 769 |
+
)
|
| 770 |
+
else:
|
| 771 |
+
v = tl.load(
|
| 772 |
+
v_ptrs + start_n * stride_vn,
|
| 773 |
+
mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
|
| 774 |
+
other=0.0,
|
| 775 |
+
)
|
| 776 |
+
dp = tl.dot(do, tl.trans(v))
|
| 777 |
+
ds = (p * (dp - do_t_o) * softmax_scale).to(q.dtype)
|
| 778 |
+
dq += tl.dot(ds, k)
|
| 779 |
+
|
| 780 |
+
if EMPTY_RFA_KV == 0:
|
| 781 |
+
# Iterate over RFA chunks
|
| 782 |
+
# we only iterate over chunks before the current local singleton window
|
| 783 |
+
end_idx_c = tl.minimum(offs_w * CHUNKS_PER_WINDOW, nchunks)
|
| 784 |
+
for start_c in range(0, end_idx_c, BLOCK_N):
|
| 785 |
+
start_c = tl.multiple_of(start_c, BLOCK_N)
|
| 786 |
+
# -- compute qk ----
|
| 787 |
+
if EVEN_C & EVEN_M:
|
| 788 |
+
if EVEN_HEADDIM:
|
| 789 |
+
rfa_k = tl.load(
|
| 790 |
+
rfa_k_ptrs + start_c * stride_rfa_kc
|
| 791 |
+
)
|
| 792 |
+
else:
|
| 793 |
+
rfa_k = tl.load(
|
| 794 |
+
rfa_k_ptrs + start_c * stride_rfa_kc,
|
| 795 |
+
mask=offs_d[None, :] < headdim,
|
| 796 |
+
other=0.0
|
| 797 |
+
)
|
| 798 |
+
else:
|
| 799 |
+
if EVEN_HEADDIM:
|
| 800 |
+
rfa_k = tl.load(
|
| 801 |
+
rfa_k_ptrs + start_c * stride_rfa_kc,
|
| 802 |
+
mask=(start_c + offs_c)[:, None] < nchunks,
|
| 803 |
+
other=0.0,
|
| 804 |
+
)
|
| 805 |
+
else:
|
| 806 |
+
rfa_k = tl.load(
|
| 807 |
+
rfa_k_ptrs + start_c * stride_rfa_kc,
|
| 808 |
+
mask=((start_c + offs_c)[:, None] < nchunks) & (offs_d[None, :] < headdim),
|
| 809 |
+
other=0.0,
|
| 810 |
+
)
|
| 811 |
+
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
| 812 |
+
qk += tl.dot(q, tl.trans(rfa_k))
|
| 813 |
+
# Trying to combine the two masks seem to make the result wrong
|
| 814 |
+
if not EVEN_C: # Need to mask out otherwise the softmax is wrong
|
| 815 |
+
qk += tl.where((start_c + offs_c)[None, :] < nchunks, 0, float("-inf"))
|
| 816 |
+
|
| 817 |
+
if MASK_TYPE == 1:
|
| 818 |
+
if EVEN_C & EVEN_M:
|
| 819 |
+
chunk_mask = tl.load(
|
| 820 |
+
chunk_mask_ptrs + start_c
|
| 821 |
+
)
|
| 822 |
+
else:
|
| 823 |
+
chunk_mask = tl.load(
|
| 824 |
+
chunk_mask_ptrs + start_c,
|
| 825 |
+
mask=(offs_m[:, None] < seqlen_q) & ((start_c + offs_c)[None, :] < nchunks),
|
| 826 |
+
other=1,
|
| 827 |
+
)
|
| 828 |
+
# Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
|
| 829 |
+
# can then fuse the mult and add into an fma instruction. But if we have bias we need to
|
| 830 |
+
# to multiply with softmax_scale here.
|
| 831 |
+
# we assume mask already implies the causal masking
|
| 832 |
+
qk = qk * qk_scale
|
| 833 |
+
qk = tl.where(chunk_mask, float("-inf"), qk)
|
| 834 |
+
p = tl.exp2(qk - lse)
|
| 835 |
+
else:
|
| 836 |
+
p = tl.exp2(qk * qk_scale - lse)
|
| 837 |
+
|
| 838 |
+
if EVEN_C & EVEN_M:
|
| 839 |
+
if EVEN_HEADDIM:
|
| 840 |
+
rfa_v = tl.load(
|
| 841 |
+
rfa_v_ptrs + start_c * stride_rfa_vc
|
| 842 |
+
)
|
| 843 |
+
else:
|
| 844 |
+
rfa_v = tl.load(
|
| 845 |
+
rfa_v_ptrs + start_c * stride_rfa_vc,
|
| 846 |
+
mask=offs_d[None, :] < headdim,
|
| 847 |
+
other=0.0
|
| 848 |
+
)
|
| 849 |
+
else:
|
| 850 |
+
if EVEN_HEADDIM:
|
| 851 |
+
rfa_v = tl.load(
|
| 852 |
+
rfa_v_ptrs + start_c * stride_rfa_vc,
|
| 853 |
+
mask=(start_c + offs_n)[:, None] < nchunks,
|
| 854 |
+
other=0.0,
|
| 855 |
+
)
|
| 856 |
+
else:
|
| 857 |
+
rfa_v = tl.load(
|
| 858 |
+
rfa_v_ptrs + start_c * stride_rfa_vc,
|
| 859 |
+
mask=((start_c + offs_n)[:, None] < nchunks) & (offs_d[None, :] < headdim),
|
| 860 |
+
other=0.0,
|
| 861 |
+
)
|
| 862 |
+
dp = tl.dot(do, tl.trans(rfa_v))
|
| 863 |
+
ds = (p * (dp - do_t_o) * softmax_scale).to(q.dtype)
|
| 864 |
+
dq += tl.dot(ds, rfa_k)
|
| 865 |
+
|
| 866 |
+
start_m = tl.program_id(0)
|
| 867 |
+
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 868 |
+
offs_d = tl.arange(0, BLOCK_HEADDIM)
|
| 869 |
+
dq_ptrs = (
|
| 870 |
+
DQ +
|
| 871 |
+
off_b * stride_dq_b +
|
| 872 |
+
off_h * stride_dq_h +
|
| 873 |
+
(offs_m[:, None] * stride_dq_m + offs_d[None, :])
|
| 874 |
+
)
|
| 875 |
+
if EVEN_M:
|
| 876 |
+
if EVEN_HEADDIM:
|
| 877 |
+
tl.store(
|
| 878 |
+
dq_ptrs, dq
|
| 879 |
+
)
|
| 880 |
+
else:
|
| 881 |
+
tl.store(
|
| 882 |
+
dq_ptrs, dq,
|
| 883 |
+
mask=offs_d[None, :] < headdim
|
| 884 |
+
)
|
| 885 |
+
else:
|
| 886 |
+
if EVEN_HEADDIM:
|
| 887 |
+
tl.store(
|
| 888 |
+
dq_ptrs, dq,
|
| 889 |
+
mask=offs_m[:, None] < seqlen_q
|
| 890 |
+
)
|
| 891 |
+
else:
|
| 892 |
+
tl.store(
|
| 893 |
+
dq_ptrs, dq,
|
| 894 |
+
mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
|
| 895 |
+
)
|
| 896 |
+
|
| 897 |
+
_capability_90_config = {
|
| 898 |
+
"fwd": {
|
| 899 |
+
(torch.bfloat16, 64): (128, 128, 4, 3),
|
| 900 |
+
(torch.bfloat16, 128): (128, 128, 8, 3),
|
| 901 |
+
(torch.float32, 64): (128, 64, 8, 3),
|
| 902 |
+
(torch.float32, 128): (64, 32, 4, 3),
|
| 903 |
+
},
|
| 904 |
+
"bwd_dq": {
|
| 905 |
+
(torch.bfloat16, 64): (128, 64, 4, 3),
|
| 906 |
+
(torch.bfloat16, 128): (128, 64, 8, 3),
|
| 907 |
+
(torch.float32, 64): (128, 64, 8, 2),
|
| 908 |
+
(torch.float32, 128): (32, 32, 4, 2),
|
| 909 |
+
},
|
| 910 |
+
"bwd_dkdv": {
|
| 911 |
+
(torch.bfloat16, 64): (128, 64, 4, 2),
|
| 912 |
+
(torch.bfloat16, 128): (128, 64, 8, 2),
|
| 913 |
+
(torch.float32, 64): (128, 64, 8, 2),
|
| 914 |
+
(torch.float32, 128): (32, 32, 4, 1),
|
| 915 |
+
},
|
| 916 |
+
"bwd_drfa_kv": {
|
| 917 |
+
(torch.bfloat16, 64): (128, 64, 4, 2),
|
| 918 |
+
(torch.bfloat16, 128): (128, 64, 8, 2),
|
| 919 |
+
(torch.float32, 64): (128, 64, 8, 2),
|
| 920 |
+
(torch.float32, 128): (32, 32, 4, 1),
|
| 921 |
+
}
|
| 922 |
+
}
|
| 923 |
+
|
| 924 |
+
_capability_80_config = {
|
| 925 |
+
"fwd": {
|
| 926 |
+
(torch.bfloat16, 64): (64, 64, 4, 3),
|
| 927 |
+
(torch.bfloat16, 128): (64, 64, 8, 3),
|
| 928 |
+
(torch.float32, 64): (64, 32, 4, 2),
|
| 929 |
+
(torch.float32, 128): (64, 32, 8, 1),
|
| 930 |
+
},
|
| 931 |
+
"bwd_dq": {
|
| 932 |
+
(torch.bfloat16, 64): (64, 64, 4, 3),
|
| 933 |
+
(torch.bfloat16, 128): (64, 32, 4, 2),
|
| 934 |
+
(torch.float32, 64): (32, 32, 4, 2),
|
| 935 |
+
(torch.float32, 128): (32, 32, 4, 2),
|
| 936 |
+
},
|
| 937 |
+
"bwd_dkdv": {
|
| 938 |
+
(torch.bfloat16, 64): (64, 64, 4, 3),
|
| 939 |
+
(torch.bfloat16, 128): (32, 32, 4, 2),
|
| 940 |
+
(torch.float32, 64): (32, 32, 4, 1),
|
| 941 |
+
(torch.float32, 128): (16, 64, 8, 1),
|
| 942 |
+
},
|
| 943 |
+
"bwd_drfa_kv": {
|
| 944 |
+
(torch.bfloat16, 64): (64, 64, 4, 3),
|
| 945 |
+
(torch.bfloat16, 128): (64, 32, 4, 3),
|
| 946 |
+
(torch.float32, 64): (32, 32, 4, 1),
|
| 947 |
+
(torch.float32, 128): (32, 32, 4, 1),
|
| 948 |
+
}
|
| 949 |
+
}
|
| 950 |
+
|
| 951 |
+
def _get_config(dtype, head_dim, mode) -> tuple[int, int, int, int]:
|
| 952 |
+
capability = torch.cuda.get_device_capability()
|
| 953 |
+
if capability >= (9, 0):
|
| 954 |
+
kernel_config = _capability_90_config[mode].get((dtype, head_dim), (32, 32, 4, 1))
|
| 955 |
+
elif capability >= (8, 0):
|
| 956 |
+
kernel_config = _capability_80_config[mode].get((dtype, head_dim), (16, 16, 4, 1))
|
| 957 |
+
else:
|
| 958 |
+
if mode == "fwd":
|
| 959 |
+
if dtype == torch.float32:
|
| 960 |
+
kernel_config = (32, 16, 4, 2)
|
| 961 |
+
else:
|
| 962 |
+
kernel_config = (64, 32, 4, 2)
|
| 963 |
+
else:
|
| 964 |
+
if dtype == torch.float32:
|
| 965 |
+
kernel_config = (16, 16, 4, 1)
|
| 966 |
+
else:
|
| 967 |
+
kernel_config = (32, 32, 4, 1)
|
| 968 |
+
return kernel_config
|
| 969 |
+
|
| 970 |
+
@triton.heuristics(
|
| 971 |
+
{
|
| 972 |
+
"EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
|
| 973 |
+
"EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
|
| 974 |
+
"EVEN_C": lambda args: args["nchunks"] % args["BLOCK_N"] == 0,
|
| 975 |
+
"EVEN_W": lambda args: args["WINDOW_SIZE"] % args["BLOCK_N"] == 0,
|
| 976 |
+
"EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
|
| 977 |
+
}
|
| 978 |
+
)
|
| 979 |
+
@triton.jit
|
| 980 |
+
def _fwd_eva_agg_kernel(
|
| 981 |
+
Q,
|
| 982 |
+
K,
|
| 983 |
+
V,
|
| 984 |
+
RFA_K,
|
| 985 |
+
RFA_V,
|
| 986 |
+
WindowMask,
|
| 987 |
+
ChunkMask,
|
| 988 |
+
Out,
|
| 989 |
+
LSE,
|
| 990 |
+
softmax_scale,
|
| 991 |
+
stride_qb, stride_qh, stride_qm,
|
| 992 |
+
stride_kb, stride_kh, stride_kn,
|
| 993 |
+
stride_vb, stride_vh, stride_vn,
|
| 994 |
+
stride_rfa_kb, stride_rfa_kh, stride_rfa_kc,
|
| 995 |
+
stride_rfa_vb, stride_rfa_vh, stride_rfa_vc,
|
| 996 |
+
stride_window_mask_b, stride_window_mask_m,
|
| 997 |
+
stride_chunk_mask_b, stride_chunk_mask_m,
|
| 998 |
+
stride_ob, stride_oh, stride_om,
|
| 999 |
+
stride_lse_b, stride_lse_h,
|
| 1000 |
+
nheads,
|
| 1001 |
+
seqlen_q,
|
| 1002 |
+
seqlen_k,
|
| 1003 |
+
nchunks,
|
| 1004 |
+
headdim,
|
| 1005 |
+
CHUNKS_PER_WINDOW: tl.constexpr,
|
| 1006 |
+
WINDOW_SIZE: tl.constexpr,
|
| 1007 |
+
MASK_TYPE: tl.constexpr,
|
| 1008 |
+
EMPTY_RFA_KV: tl.constexpr,
|
| 1009 |
+
BLOCK_HEADDIM: tl.constexpr,
|
| 1010 |
+
EVEN_M: tl.constexpr,
|
| 1011 |
+
EVEN_N: tl.constexpr,
|
| 1012 |
+
EVEN_W: tl.constexpr,
|
| 1013 |
+
EVEN_C: tl.constexpr,
|
| 1014 |
+
EVEN_HEADDIM: tl.constexpr,
|
| 1015 |
+
BLOCK_M: tl.constexpr,
|
| 1016 |
+
BLOCK_N: tl.constexpr,
|
| 1017 |
+
):
|
| 1018 |
+
start_m = tl.program_id(0)
|
| 1019 |
+
off_bh = tl.program_id(1)
|
| 1020 |
+
off_h = off_bh % nheads
|
| 1021 |
+
off_b = off_bh // nheads
|
| 1022 |
+
# initialize offsets
|
| 1023 |
+
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 1024 |
+
offs_w = (start_m * BLOCK_M) // WINDOW_SIZE
|
| 1025 |
+
offs_n = tl.arange(0, BLOCK_N)
|
| 1026 |
+
offs_c = tl.arange(0, BLOCK_N)
|
| 1027 |
+
offs_d = tl.arange(0, BLOCK_HEADDIM)
|
| 1028 |
+
# TODO: add paratheses or not
|
| 1029 |
+
q_ptrs = (
|
| 1030 |
+
Q +
|
| 1031 |
+
off_b * stride_qb +
|
| 1032 |
+
off_h * stride_qh +
|
| 1033 |
+
(offs_m[:, None] * stride_qm + offs_d[None, :])
|
| 1034 |
+
)
|
| 1035 |
+
k_ptrs = (
|
| 1036 |
+
K +
|
| 1037 |
+
off_b * stride_kb +
|
| 1038 |
+
off_h * stride_kh +
|
| 1039 |
+
(offs_n[:, None] * stride_kn + offs_d[None, :])
|
| 1040 |
+
)
|
| 1041 |
+
v_ptrs = (
|
| 1042 |
+
V +
|
| 1043 |
+
off_b * stride_vb +
|
| 1044 |
+
off_h * stride_vh +
|
| 1045 |
+
(offs_n[:, None] * stride_vn + offs_d[None, :])
|
| 1046 |
+
)
|
| 1047 |
+
if EMPTY_RFA_KV == 0:
|
| 1048 |
+
rfa_k_ptrs = (
|
| 1049 |
+
RFA_K +
|
| 1050 |
+
off_b * stride_rfa_kb +
|
| 1051 |
+
off_h * stride_rfa_kh +
|
| 1052 |
+
(offs_c[:, None] * stride_rfa_kc + offs_d[None, :])
|
| 1053 |
+
)
|
| 1054 |
+
rfa_v_ptrs = (
|
| 1055 |
+
RFA_V +
|
| 1056 |
+
off_b * stride_rfa_vb +
|
| 1057 |
+
off_h * stride_rfa_vh +
|
| 1058 |
+
(offs_c[:, None] * stride_rfa_vc + offs_d[None, :])
|
| 1059 |
+
)
|
| 1060 |
+
|
| 1061 |
+
qk_scale = softmax_scale
|
| 1062 |
+
qk_scale *= 1.4426950408889634 # log2(e)
|
| 1063 |
+
if MASK_TYPE == 1:
|
| 1064 |
+
window_mask_ptrs = (
|
| 1065 |
+
WindowMask +
|
| 1066 |
+
off_b * stride_window_mask_b +
|
| 1067 |
+
(offs_m[:, None] * stride_window_mask_m + offs_n[None, :])
|
| 1068 |
+
)
|
| 1069 |
+
if EMPTY_RFA_KV == 0:
|
| 1070 |
+
chunk_mask_ptrs = (
|
| 1071 |
+
ChunkMask +
|
| 1072 |
+
off_b * stride_chunk_mask_b +
|
| 1073 |
+
(offs_m[:, None] * stride_chunk_mask_m + offs_c[None, :])
|
| 1074 |
+
)
|
| 1075 |
+
|
| 1076 |
+
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
| 1077 |
+
d_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
| 1078 |
+
acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
|
| 1079 |
+
# load q: it will stay in SRAM throughout
|
| 1080 |
+
# [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
|
| 1081 |
+
# tl.load(q_ptrs), we get the wrong output!
|
| 1082 |
+
if EVEN_M & EVEN_N:
|
| 1083 |
+
if EVEN_HEADDIM:
|
| 1084 |
+
q = tl.load(
|
| 1085 |
+
q_ptrs
|
| 1086 |
+
)
|
| 1087 |
+
else:
|
| 1088 |
+
q = tl.load(
|
| 1089 |
+
q_ptrs,
|
| 1090 |
+
mask=offs_d[None, :] < headdim,
|
| 1091 |
+
other=0.0
|
| 1092 |
+
)
|
| 1093 |
+
else:
|
| 1094 |
+
if EVEN_HEADDIM:
|
| 1095 |
+
q = tl.load(
|
| 1096 |
+
q_ptrs,
|
| 1097 |
+
mask=offs_m[:, None] < seqlen_q,
|
| 1098 |
+
other=0.0
|
| 1099 |
+
)
|
| 1100 |
+
else:
|
| 1101 |
+
q = tl.load(
|
| 1102 |
+
q_ptrs,
|
| 1103 |
+
mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
|
| 1104 |
+
other=0.0
|
| 1105 |
+
)
|
| 1106 |
+
# loop over k, v and update accumulator
|
| 1107 |
+
# Iterate over local singletons;
|
| 1108 |
+
# so we only iterate over blocks within the current window
|
| 1109 |
+
start_idx_n = offs_w * WINDOW_SIZE
|
| 1110 |
+
end_idx_n = tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
|
| 1111 |
+
for start_n in range(start_idx_n, end_idx_n, BLOCK_N):
|
| 1112 |
+
start_n = tl.multiple_of(start_n, BLOCK_N)
|
| 1113 |
+
# -- compute qk ----
|
| 1114 |
+
if EVEN_N & EVEN_M:
|
| 1115 |
+
if EVEN_HEADDIM:
|
| 1116 |
+
k = tl.load(
|
| 1117 |
+
k_ptrs + start_n * stride_kn
|
| 1118 |
+
)
|
| 1119 |
+
else:
|
| 1120 |
+
k = tl.load(
|
| 1121 |
+
k_ptrs + start_n * stride_kn,
|
| 1122 |
+
mask=offs_d[None, :] < headdim,
|
| 1123 |
+
other=0.0
|
| 1124 |
+
)
|
| 1125 |
+
else:
|
| 1126 |
+
if EVEN_HEADDIM:
|
| 1127 |
+
k = tl.load(
|
| 1128 |
+
k_ptrs + start_n * stride_kn,
|
| 1129 |
+
mask=(start_n + offs_n)[:, None] < seqlen_k,
|
| 1130 |
+
other=0.0,
|
| 1131 |
+
)
|
| 1132 |
+
else:
|
| 1133 |
+
k = tl.load(
|
| 1134 |
+
k_ptrs + start_n * stride_kn,
|
| 1135 |
+
mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
|
| 1136 |
+
other=0.0,
|
| 1137 |
+
)
|
| 1138 |
+
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
| 1139 |
+
qk += tl.dot(q, tl.trans(k))
|
| 1140 |
+
# Trying to combine the two masks seem to make the result wrong
|
| 1141 |
+
if not EVEN_N: # Need to mask out otherwise the softmax is wrong
|
| 1142 |
+
qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
|
| 1143 |
+
|
| 1144 |
+
if MASK_TYPE == 1:
|
| 1145 |
+
if EVEN_M & EVEN_W:
|
| 1146 |
+
window_mask = tl.load(
|
| 1147 |
+
window_mask_ptrs + start_n - start_idx_n
|
| 1148 |
+
)
|
| 1149 |
+
else:
|
| 1150 |
+
window_mask = tl.load(
|
| 1151 |
+
window_mask_ptrs + start_n - start_idx_n,
|
| 1152 |
+
mask=(offs_m[:, None] < seqlen_q)
|
| 1153 |
+
& ((start_n - start_idx_n + offs_n)[None, :] < WINDOW_SIZE),
|
| 1154 |
+
other=1,
|
| 1155 |
+
)
|
| 1156 |
+
# Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
|
| 1157 |
+
# can then fuse the mult and add into an fma instruction. But if we have bias we need to
|
| 1158 |
+
# to multiply with softmax_scale here.
|
| 1159 |
+
# we assume mask already implies the causal masking
|
| 1160 |
+
qk = qk * qk_scale
|
| 1161 |
+
qk = tl.where(window_mask, float("-inf"), qk)
|
| 1162 |
+
m_ij = tl.maximum(tl.max(qk, 1), m_i)
|
| 1163 |
+
masked_out_rows = (m_ij == float("-inf"))
|
| 1164 |
+
m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
|
| 1165 |
+
p = tl.exp2(qk - m_ij_masked[:, None])
|
| 1166 |
+
else:
|
| 1167 |
+
qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
|
| 1168 |
+
m_ij = tl.maximum(tl.max(qk, 1) * qk_scale, m_i)
|
| 1169 |
+
masked_out_rows = (m_ij == float("-inf"))
|
| 1170 |
+
m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
|
| 1171 |
+
p = tl.exp2(qk * qk_scale - m_ij_masked[:, None])
|
| 1172 |
+
|
| 1173 |
+
d_ij = tl.sum(p, 1)
|
| 1174 |
+
|
| 1175 |
+
# scale acc_o
|
| 1176 |
+
prev_scale = tl.exp2(m_i - m_ij_masked)
|
| 1177 |
+
# # -- update output accumulator --
|
| 1178 |
+
acc_o = acc_o * prev_scale[:, None]
|
| 1179 |
+
# update acc_o
|
| 1180 |
+
if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition
|
| 1181 |
+
if EVEN_HEADDIM:
|
| 1182 |
+
v = tl.load(
|
| 1183 |
+
v_ptrs + start_n * stride_vn
|
| 1184 |
+
)
|
| 1185 |
+
else:
|
| 1186 |
+
v = tl.load(
|
| 1187 |
+
v_ptrs + start_n * stride_vn,
|
| 1188 |
+
mask=offs_d[None, :] < headdim,
|
| 1189 |
+
other=0.0
|
| 1190 |
+
)
|
| 1191 |
+
else:
|
| 1192 |
+
if EVEN_HEADDIM:
|
| 1193 |
+
v = tl.load(
|
| 1194 |
+
v_ptrs + start_n * stride_vn,
|
| 1195 |
+
mask=(start_n + offs_n)[:, None] < seqlen_k,
|
| 1196 |
+
other=0.0,
|
| 1197 |
+
)
|
| 1198 |
+
else:
|
| 1199 |
+
v = tl.load(
|
| 1200 |
+
v_ptrs + start_n * stride_vn,
|
| 1201 |
+
mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
|
| 1202 |
+
other=0.0,
|
| 1203 |
+
)
|
| 1204 |
+
p = p.to(v.dtype)
|
| 1205 |
+
acc_o = tl.dot(p, v, acc_o)
|
| 1206 |
+
|
| 1207 |
+
# -- update statistics
|
| 1208 |
+
d_i = d_i * prev_scale + d_ij
|
| 1209 |
+
m_i = m_ij
|
| 1210 |
+
|
| 1211 |
+
if EMPTY_RFA_KV == 0:
|
| 1212 |
+
# Iterate over RFA chunks
|
| 1213 |
+
# we only iterate over chunks before the current local singleton window
|
| 1214 |
+
end_idx_c = tl.minimum(offs_w * CHUNKS_PER_WINDOW, nchunks)
|
| 1215 |
+
for start_c in range(0, end_idx_c, BLOCK_N):
|
| 1216 |
+
start_c = tl.multiple_of(start_c, BLOCK_N)
|
| 1217 |
+
# -- compute qk ----
|
| 1218 |
+
if EVEN_C & EVEN_M:
|
| 1219 |
+
if EVEN_HEADDIM:
|
| 1220 |
+
rfa_k = tl.load(
|
| 1221 |
+
rfa_k_ptrs + start_c * stride_rfa_kc
|
| 1222 |
+
)
|
| 1223 |
+
else:
|
| 1224 |
+
rfa_k = tl.load(
|
| 1225 |
+
rfa_k_ptrs + start_c * stride_rfa_kc,
|
| 1226 |
+
mask=offs_d[None, :] < headdim,
|
| 1227 |
+
other=0.0
|
| 1228 |
+
)
|
| 1229 |
+
else:
|
| 1230 |
+
if EVEN_HEADDIM:
|
| 1231 |
+
rfa_k = tl.load(
|
| 1232 |
+
rfa_k_ptrs + start_c * stride_rfa_kc,
|
| 1233 |
+
mask=(start_c + offs_c)[:, None] < nchunks,
|
| 1234 |
+
other=0.0,
|
| 1235 |
+
)
|
| 1236 |
+
else:
|
| 1237 |
+
rfa_k = tl.load(
|
| 1238 |
+
rfa_k_ptrs + start_c * stride_rfa_kc,
|
| 1239 |
+
mask=((start_c + offs_c)[:, None] < nchunks) & (offs_d[None, :] < headdim),
|
| 1240 |
+
other=0.0,
|
| 1241 |
+
)
|
| 1242 |
+
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
| 1243 |
+
qk += tl.dot(q, tl.trans(rfa_k))
|
| 1244 |
+
# Trying to combine the two masks seem to make the result wrong
|
| 1245 |
+
if not EVEN_C: # Need to mask out otherwise the softmax is wrong
|
| 1246 |
+
qk += tl.where((start_c + offs_c)[None, :] < nchunks, 0, float("-inf"))
|
| 1247 |
+
|
| 1248 |
+
if MASK_TYPE == 1:
|
| 1249 |
+
if EVEN_C & EVEN_M:
|
| 1250 |
+
chunk_mask = tl.load(
|
| 1251 |
+
chunk_mask_ptrs + start_c
|
| 1252 |
+
)
|
| 1253 |
+
else:
|
| 1254 |
+
chunk_mask = tl.load(
|
| 1255 |
+
chunk_mask_ptrs + start_c,
|
| 1256 |
+
mask=(offs_m[:, None] < seqlen_q) & ((start_c + offs_c)[None, :] < nchunks),
|
| 1257 |
+
other=1,
|
| 1258 |
+
)
|
| 1259 |
+
# Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
|
| 1260 |
+
# can then fuse the mult and add into an fma instruction. But if we have bias we need to
|
| 1261 |
+
# to multiply with softmax_scale here.
|
| 1262 |
+
# we assume mask already implies the causal masking
|
| 1263 |
+
qk = qk * qk_scale
|
| 1264 |
+
qk = tl.where(chunk_mask, float("-inf"), qk)
|
| 1265 |
+
m_ij = tl.maximum(tl.max(qk, 1), m_i)
|
| 1266 |
+
masked_out_rows = (m_ij == float("-inf"))
|
| 1267 |
+
m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
|
| 1268 |
+
p = tl.exp2(qk - m_ij_masked[:, None])
|
| 1269 |
+
else:
|
| 1270 |
+
m_ij = tl.maximum(tl.max(qk, 1) * qk_scale, m_i)
|
| 1271 |
+
masked_out_rows = (m_ij == float("-inf"))
|
| 1272 |
+
m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
|
| 1273 |
+
p = tl.exp2(qk * qk_scale - m_ij_masked[:, None])
|
| 1274 |
+
|
| 1275 |
+
d_ij = tl.sum(p, 1)
|
| 1276 |
+
|
| 1277 |
+
# scale acc_o
|
| 1278 |
+
prev_scale = tl.exp2(m_i - m_ij_masked)
|
| 1279 |
+
# # -- update output accumulator --
|
| 1280 |
+
acc_o = acc_o * prev_scale[:, None]
|
| 1281 |
+
# update acc_o
|
| 1282 |
+
# TODO: If we just do "if EVEN_N", there seems to be some race condition ?
|
| 1283 |
+
if EVEN_C & EVEN_M:
|
| 1284 |
+
if EVEN_HEADDIM:
|
| 1285 |
+
rfa_v = tl.load(
|
| 1286 |
+
rfa_v_ptrs + start_c * stride_rfa_vc
|
| 1287 |
+
)
|
| 1288 |
+
else:
|
| 1289 |
+
rfa_v = tl.load(
|
| 1290 |
+
rfa_v_ptrs + start_c * stride_rfa_vc,
|
| 1291 |
+
mask=offs_d[None, :] < headdim,
|
| 1292 |
+
other=0.0
|
| 1293 |
+
)
|
| 1294 |
+
else:
|
| 1295 |
+
if EVEN_HEADDIM:
|
| 1296 |
+
rfa_v = tl.load(
|
| 1297 |
+
rfa_v_ptrs + start_c * stride_rfa_vc,
|
| 1298 |
+
mask=(start_c + offs_n)[:, None] < nchunks,
|
| 1299 |
+
other=0.0,
|
| 1300 |
+
)
|
| 1301 |
+
else:
|
| 1302 |
+
rfa_v = tl.load(
|
| 1303 |
+
rfa_v_ptrs + start_c * stride_rfa_vc,
|
| 1304 |
+
mask=((start_c + offs_n)[:, None] < nchunks) & (offs_d[None, :] < headdim),
|
| 1305 |
+
other=0.0,
|
| 1306 |
+
)
|
| 1307 |
+
p = p.to(rfa_v.dtype)
|
| 1308 |
+
acc_o = tl.dot(p, rfa_v, acc_o)
|
| 1309 |
+
|
| 1310 |
+
# -- update statistics
|
| 1311 |
+
d_i = d_i * prev_scale + d_ij
|
| 1312 |
+
m_i = m_ij
|
| 1313 |
+
|
| 1314 |
+
# for rows that are all -inf, set d_i to 1.0
|
| 1315 |
+
d_i = tl.where(d_i == 0.0, 1.0, d_i)
|
| 1316 |
+
# multiply by log(2)
|
| 1317 |
+
lse_m = (m_i + tl.math.log2(d_i)) * 0.6931471805599453
|
| 1318 |
+
acc_o = acc_o / d_i[:, None]
|
| 1319 |
+
# TODO: understand why rematerialize offsets to save registers?
|
| 1320 |
+
start_m = tl.program_id(0)
|
| 1321 |
+
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 1322 |
+
offs_d = tl.arange(0, BLOCK_HEADDIM)
|
| 1323 |
+
out_ptrs = (
|
| 1324 |
+
Out +
|
| 1325 |
+
off_b * stride_ob +
|
| 1326 |
+
off_h * stride_oh +
|
| 1327 |
+
(offs_m[:, None] * stride_om + offs_d[None, :])
|
| 1328 |
+
)
|
| 1329 |
+
if EVEN_M:
|
| 1330 |
+
if EVEN_HEADDIM:
|
| 1331 |
+
tl.store(
|
| 1332 |
+
out_ptrs, acc_o
|
| 1333 |
+
)
|
| 1334 |
+
else:
|
| 1335 |
+
tl.store(
|
| 1336 |
+
out_ptrs, acc_o,
|
| 1337 |
+
mask=offs_d[None, :] < headdim
|
| 1338 |
+
)
|
| 1339 |
+
else:
|
| 1340 |
+
if EVEN_HEADDIM:
|
| 1341 |
+
tl.store(
|
| 1342 |
+
out_ptrs, acc_o,
|
| 1343 |
+
mask=offs_m[:, None] < seqlen_q
|
| 1344 |
+
)
|
| 1345 |
+
else:
|
| 1346 |
+
tl.store(
|
| 1347 |
+
out_ptrs, acc_o,
|
| 1348 |
+
mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
|
| 1349 |
+
)
|
| 1350 |
+
lse_ptrs = (
|
| 1351 |
+
LSE +
|
| 1352 |
+
off_b * stride_lse_b +
|
| 1353 |
+
off_h * stride_lse_h +
|
| 1354 |
+
offs_m
|
| 1355 |
+
)
|
| 1356 |
+
if EVEN_M:
|
| 1357 |
+
tl.store(
|
| 1358 |
+
lse_ptrs, lse_m,
|
| 1359 |
+
)
|
| 1360 |
+
else:
|
| 1361 |
+
tl.store(
|
| 1362 |
+
lse_ptrs, lse_m,
|
| 1363 |
+
mask=offs_m < seqlen_q
|
| 1364 |
+
)
|
| 1365 |
+
|
| 1366 |
+
def triton_eva_agg_fwd(
|
| 1367 |
+
q, k, v, rfa_k, rfa_v,
|
| 1368 |
+
window_mask,
|
| 1369 |
+
chunk_mask,
|
| 1370 |
+
softmax_scale,
|
| 1371 |
+
window_size,
|
| 1372 |
+
chunks_per_window
|
| 1373 |
+
):
|
| 1374 |
+
if rfa_k is None and rfa_v is None:
|
| 1375 |
+
empty_rfa_kv = 1
|
| 1376 |
+
|
| 1377 |
+
q, k, v = [
|
| 1378 |
+
x if x.stride(-1) == 1 else x.contiguous()
|
| 1379 |
+
for x in [q, k, v]
|
| 1380 |
+
]
|
| 1381 |
+
else:
|
| 1382 |
+
assert rfa_k is not None and rfa_v is not None, "Both rfa_k and rfa_v must either be None or have values at the same time."
|
| 1383 |
+
empty_rfa_kv = 0
|
| 1384 |
+
|
| 1385 |
+
q, k, v, rfa_k, rfa_v = [
|
| 1386 |
+
x if x.stride(-1) == 1 else x.contiguous()
|
| 1387 |
+
for x in [q, k, v, rfa_k, rfa_v]
|
| 1388 |
+
]
|
| 1389 |
+
|
| 1390 |
+
# shape constraints
|
| 1391 |
+
batch, nheads, seqlen_q, head_dim = q.shape
|
| 1392 |
+
_, _, seqlen_k, _ = k.shape
|
| 1393 |
+
if empty_rfa_kv == 0:
|
| 1394 |
+
nchunks = rfa_k.shape[-2]
|
| 1395 |
+
assert rfa_k.shape == (batch, nheads, nchunks, head_dim)
|
| 1396 |
+
assert rfa_v.shape == (batch, nheads, nchunks, head_dim)
|
| 1397 |
+
assert q.dtype == k.dtype == v.dtype == rfa_k.dtype == rfa_v.dtype
|
| 1398 |
+
else:
|
| 1399 |
+
nchunks = 0
|
| 1400 |
+
assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
|
| 1401 |
+
assert k.shape == (batch, nheads, seqlen_k, head_dim)
|
| 1402 |
+
assert v.shape == (batch, nheads, seqlen_k, head_dim)
|
| 1403 |
+
|
| 1404 |
+
assert head_dim <= 128, "We only test head dimensions up to 128"
|
| 1405 |
+
# assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
|
| 1406 |
+
assert q.dtype in [torch.bfloat16, torch.float], "Only support bf16 and fp32 for now"
|
| 1407 |
+
assert q.is_cuda and k.is_cuda and v.is_cuda
|
| 1408 |
+
softmax_scale = softmax_scale or 1.0 / math.sqrt(head_dim)
|
| 1409 |
+
|
| 1410 |
+
mask_type = 0
|
| 1411 |
+
if window_mask is not None:
|
| 1412 |
+
mask_type = 1
|
| 1413 |
+
assert window_mask.dtype == torch.bool
|
| 1414 |
+
assert window_mask.is_cuda
|
| 1415 |
+
assert window_mask.dim() == 4
|
| 1416 |
+
assert window_mask.shape == (batch, 1, seqlen_q, window_size)
|
| 1417 |
+
if window_mask.stride(-1) != 1:
|
| 1418 |
+
window_mask = window_mask.contiguous()
|
| 1419 |
+
|
| 1420 |
+
assert chunk_mask is not None
|
| 1421 |
+
assert chunk_mask.dtype == torch.bool
|
| 1422 |
+
assert chunk_mask.is_cuda
|
| 1423 |
+
assert chunk_mask.dim() == 4
|
| 1424 |
+
assert chunk_mask.shape == (batch, 1, seqlen_q, nchunks)
|
| 1425 |
+
if chunk_mask.stride(-1) != 1:
|
| 1426 |
+
chunk_mask = chunk_mask.contiguous()
|
| 1427 |
+
|
| 1428 |
+
chunk_mask_strides = (
|
| 1429 |
+
(chunk_mask.stride(0), chunk_mask.stride(2))
|
| 1430 |
+
if mask_type == 1 else
|
| 1431 |
+
(0, 0)
|
| 1432 |
+
)
|
| 1433 |
+
window_mask_strides = (
|
| 1434 |
+
(window_mask.stride(0), window_mask.stride(2))
|
| 1435 |
+
if mask_type == 1 else
|
| 1436 |
+
(0, 0)
|
| 1437 |
+
)
|
| 1438 |
+
|
| 1439 |
+
rfa_k_strides = (
|
| 1440 |
+
(rfa_k.stride(0), rfa_k.stride(1), rfa_k.stride(2))
|
| 1441 |
+
if empty_rfa_kv == 0 else
|
| 1442 |
+
(0, 0, 0)
|
| 1443 |
+
)
|
| 1444 |
+
rfa_v_strides = (
|
| 1445 |
+
(rfa_v.stride(0), rfa_v.stride(1), rfa_v.stride(2))
|
| 1446 |
+
if empty_rfa_kv == 0 else
|
| 1447 |
+
(0, 0, 0)
|
| 1448 |
+
)
|
| 1449 |
+
|
| 1450 |
+
o = torch.empty_like(q)
|
| 1451 |
+
lse = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
| 1452 |
+
|
| 1453 |
+
BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16)
|
| 1454 |
+
|
| 1455 |
+
BLOCK_M, BLOCK_N, num_warps, num_stages = _get_config(q.dtype, head_dim, "fwd")
|
| 1456 |
+
|
| 1457 |
+
assert chunks_per_window >= BLOCK_N, "chunks_per_window must be greater than BLOCK"
|
| 1458 |
+
assert chunks_per_window % BLOCK_N == 0, "chunks_per_window must be a multiple of BLOCK_N"
|
| 1459 |
+
|
| 1460 |
+
grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
|
| 1461 |
+
_fwd_eva_agg_kernel[grid](
|
| 1462 |
+
q,
|
| 1463 |
+
k,
|
| 1464 |
+
v,
|
| 1465 |
+
rfa_k,
|
| 1466 |
+
rfa_v,
|
| 1467 |
+
window_mask,
|
| 1468 |
+
chunk_mask,
|
| 1469 |
+
o,
|
| 1470 |
+
lse,
|
| 1471 |
+
softmax_scale,
|
| 1472 |
+
q.stride(0), q.stride(1), q.stride(2),
|
| 1473 |
+
k.stride(0), k.stride(1), k.stride(2),
|
| 1474 |
+
v.stride(0), v.stride(1), v.stride(2),
|
| 1475 |
+
rfa_k_strides[0], rfa_k_strides[1], rfa_k_strides[2],
|
| 1476 |
+
rfa_v_strides[0], rfa_v_strides[1], rfa_v_strides[2],
|
| 1477 |
+
window_mask_strides[0], window_mask_strides[1],
|
| 1478 |
+
chunk_mask_strides[0], chunk_mask_strides[1],
|
| 1479 |
+
o.stride(0), o.stride(1), o.stride(2),
|
| 1480 |
+
lse.stride(0), lse.stride(1),
|
| 1481 |
+
nheads,
|
| 1482 |
+
seqlen_q,
|
| 1483 |
+
seqlen_k,
|
| 1484 |
+
nchunks,
|
| 1485 |
+
head_dim,
|
| 1486 |
+
chunks_per_window,
|
| 1487 |
+
window_size,
|
| 1488 |
+
mask_type,
|
| 1489 |
+
empty_rfa_kv,
|
| 1490 |
+
BLOCK_HEADDIM,
|
| 1491 |
+
BLOCK_M=BLOCK_M,
|
| 1492 |
+
BLOCK_N=BLOCK_N,
|
| 1493 |
+
num_warps=num_warps,
|
| 1494 |
+
num_stages=num_stages,
|
| 1495 |
+
)
|
| 1496 |
+
return o, lse
|
| 1497 |
+
|
| 1498 |
+
def triton_eva_agg_bwd(
|
| 1499 |
+
do,
|
| 1500 |
+
q, k, v, rfa_k, rfa_v,
|
| 1501 |
+
window_mask, chunk_mask,
|
| 1502 |
+
o, lse,
|
| 1503 |
+
dq, dk, dv, d_rfa_k, d_rfa_v,
|
| 1504 |
+
softmax_scale,
|
| 1505 |
+
window_size,
|
| 1506 |
+
chunks_per_window,
|
| 1507 |
+
empty_rfa_kv,
|
| 1508 |
+
mask_type,
|
| 1509 |
+
):
|
| 1510 |
+
if do.stride(-1) != 1:
|
| 1511 |
+
do = do.contiguous()
|
| 1512 |
+
|
| 1513 |
+
# shape constraints
|
| 1514 |
+
batch, nheads, seqlen_q, head_dim = q.shape
|
| 1515 |
+
_, _, seqlen_k, _ = k.shape
|
| 1516 |
+
if empty_rfa_kv == 0:
|
| 1517 |
+
nchunks = rfa_k.shape[-2]
|
| 1518 |
+
assert rfa_k.shape == (batch, nheads, nchunks, head_dim)
|
| 1519 |
+
assert rfa_v.shape == (batch, nheads, nchunks, head_dim)
|
| 1520 |
+
assert d_rfa_k.stride(-1) == d_rfa_v.stride(-1) == 1
|
| 1521 |
+
assert q.dtype == k.dtype == v.dtype == rfa_k.dtype == rfa_v.dtype
|
| 1522 |
+
else:
|
| 1523 |
+
nchunks = 0
|
| 1524 |
+
assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
|
| 1525 |
+
|
| 1526 |
+
assert lse.shape == (batch, nheads, seqlen_q)
|
| 1527 |
+
assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == rfa_k.stride(-1) == rfa_v.stride(-1) == 1
|
| 1528 |
+
assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1
|
| 1529 |
+
softmax_scale = softmax_scale or 1.0 / math.sqrt(head_dim)
|
| 1530 |
+
|
| 1531 |
+
assert head_dim <= 128, "We only test head dimensions up to 128"
|
| 1532 |
+
|
| 1533 |
+
window_mask_strides = (
|
| 1534 |
+
(window_mask.stride(0), window_mask.stride(2))
|
| 1535 |
+
if mask_type == 1 else
|
| 1536 |
+
(0, 0)
|
| 1537 |
+
)
|
| 1538 |
+
chunk_mask_strides = (
|
| 1539 |
+
(chunk_mask.stride(0), chunk_mask.stride(2))
|
| 1540 |
+
if mask_type == 1 else
|
| 1541 |
+
(0, 0)
|
| 1542 |
+
)
|
| 1543 |
+
|
| 1544 |
+
rfa_k_strides = (
|
| 1545 |
+
(rfa_k.stride(0), rfa_k.stride(1), rfa_k.stride(2))
|
| 1546 |
+
if empty_rfa_kv == 0 else
|
| 1547 |
+
(0, 0, 0)
|
| 1548 |
+
)
|
| 1549 |
+
rfa_v_strides = (
|
| 1550 |
+
(rfa_v.stride(0), rfa_v.stride(1), rfa_v.stride(2))
|
| 1551 |
+
if empty_rfa_kv == 0 else
|
| 1552 |
+
(0, 0, 0)
|
| 1553 |
+
)
|
| 1554 |
+
|
| 1555 |
+
d_rfa_k_strides = (
|
| 1556 |
+
(d_rfa_k.stride(0), d_rfa_k.stride(1), d_rfa_k.stride(2))
|
| 1557 |
+
if empty_rfa_kv == 0 else
|
| 1558 |
+
(0, 0, 0)
|
| 1559 |
+
)
|
| 1560 |
+
d_rfa_v_strides = (
|
| 1561 |
+
(d_rfa_v.stride(0), d_rfa_v.stride(1), d_rfa_v.stride(2))
|
| 1562 |
+
if empty_rfa_kv == 0 else
|
| 1563 |
+
(0, 0, 0)
|
| 1564 |
+
)
|
| 1565 |
+
|
| 1566 |
+
BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16)
|
| 1567 |
+
|
| 1568 |
+
do_t_o = torch.sum(do.to(torch.float32) * o.to(torch.float32), dim=-1).to(do.dtype)
|
| 1569 |
+
|
| 1570 |
+
BLOCK_M, BLOCK_N, num_warps, num_stages = _get_config(q.dtype, head_dim, "bwd_dq")
|
| 1571 |
+
|
| 1572 |
+
assert chunks_per_window >= BLOCK_N, "chunks_per_window must be greater than BLOCK"
|
| 1573 |
+
assert chunks_per_window % BLOCK_N == 0, "chunks_per_window must be a multiple of BLOCK"
|
| 1574 |
+
grid = lambda META: (
|
| 1575 |
+
triton.cdiv(seqlen_q, META["BLOCK_M"]),
|
| 1576 |
+
batch * nheads,
|
| 1577 |
+
)
|
| 1578 |
+
_bwd_eva_agg_kernel_dq[grid](
|
| 1579 |
+
q,
|
| 1580 |
+
k,
|
| 1581 |
+
v,
|
| 1582 |
+
rfa_k,
|
| 1583 |
+
rfa_v,
|
| 1584 |
+
window_mask,
|
| 1585 |
+
chunk_mask,
|
| 1586 |
+
do,
|
| 1587 |
+
lse,
|
| 1588 |
+
do_t_o,
|
| 1589 |
+
dq,
|
| 1590 |
+
softmax_scale,
|
| 1591 |
+
q.stride(0), q.stride(1), q.stride(2),
|
| 1592 |
+
k.stride(0), k.stride(1), k.stride(2),
|
| 1593 |
+
v.stride(0), v.stride(1), v.stride(2),
|
| 1594 |
+
rfa_k_strides[0], rfa_k_strides[1], rfa_k_strides[2],
|
| 1595 |
+
rfa_v_strides[0], rfa_v_strides[1], rfa_v_strides[2],
|
| 1596 |
+
window_mask_strides[0], window_mask_strides[1],
|
| 1597 |
+
chunk_mask_strides[0], chunk_mask_strides[1],
|
| 1598 |
+
do.stride(0), do.stride(1), do.stride(2),
|
| 1599 |
+
lse.stride(0), lse.stride(1),
|
| 1600 |
+
do_t_o.stride(0), do_t_o.stride(1),
|
| 1601 |
+
dq.stride(0), dq.stride(1), dq.stride(2),
|
| 1602 |
+
nheads,
|
| 1603 |
+
seqlen_q,
|
| 1604 |
+
seqlen_k,
|
| 1605 |
+
nchunks,
|
| 1606 |
+
head_dim,
|
| 1607 |
+
chunks_per_window,
|
| 1608 |
+
window_size,
|
| 1609 |
+
mask_type,
|
| 1610 |
+
empty_rfa_kv,
|
| 1611 |
+
BLOCK_HEADDIM,
|
| 1612 |
+
BLOCK_M=BLOCK_M,
|
| 1613 |
+
BLOCK_N=BLOCK_N,
|
| 1614 |
+
num_warps=num_warps,
|
| 1615 |
+
num_stages=num_stages,
|
| 1616 |
+
)
|
| 1617 |
+
|
| 1618 |
+
BLOCK_M, BLOCK_N, num_warps, num_stages = _get_config(q.dtype, head_dim, "bwd_dkdv")
|
| 1619 |
+
grid = lambda META: (
|
| 1620 |
+
triton.cdiv(seqlen_k, META["BLOCK_N"]),
|
| 1621 |
+
batch * nheads,
|
| 1622 |
+
)
|
| 1623 |
+
_bwd_eva_agg_kernel_dkdv[grid](
|
| 1624 |
+
q,
|
| 1625 |
+
k,
|
| 1626 |
+
v,
|
| 1627 |
+
window_mask,
|
| 1628 |
+
do,
|
| 1629 |
+
lse,
|
| 1630 |
+
do_t_o,
|
| 1631 |
+
dk,
|
| 1632 |
+
dv,
|
| 1633 |
+
softmax_scale,
|
| 1634 |
+
q.stride(0), q.stride(1), q.stride(2),
|
| 1635 |
+
k.stride(0), k.stride(1), k.stride(2),
|
| 1636 |
+
v.stride(0), v.stride(1), v.stride(2),
|
| 1637 |
+
window_mask_strides[0], window_mask_strides[1],
|
| 1638 |
+
do.stride(0), do.stride(1), do.stride(2),
|
| 1639 |
+
lse.stride(0), lse.stride(1),
|
| 1640 |
+
do_t_o.stride(0), do_t_o.stride(1),
|
| 1641 |
+
dk.stride(0), dk.stride(1), dk.stride(2),
|
| 1642 |
+
dv.stride(0), dv.stride(1), dv.stride(2),
|
| 1643 |
+
nheads,
|
| 1644 |
+
seqlen_q,
|
| 1645 |
+
seqlen_k,
|
| 1646 |
+
head_dim,
|
| 1647 |
+
window_size,
|
| 1648 |
+
mask_type,
|
| 1649 |
+
BLOCK_HEADDIM,
|
| 1650 |
+
BLOCK_M=BLOCK_M,
|
| 1651 |
+
BLOCK_N=BLOCK_N,
|
| 1652 |
+
num_warps=num_warps,
|
| 1653 |
+
num_stages=num_stages,
|
| 1654 |
+
)
|
| 1655 |
+
if empty_rfa_kv == 0:
|
| 1656 |
+
BLOCK_M, BLOCK_N, num_warps, num_stages = _get_config(q.dtype, head_dim, "bwd_drfa_kv")
|
| 1657 |
+
grid = lambda META: (
|
| 1658 |
+
triton.cdiv(nchunks, META["BLOCK_N"]),
|
| 1659 |
+
batch * nheads,
|
| 1660 |
+
)
|
| 1661 |
+
_bwd_eva_agg_kernel_drfa_kv[grid](
|
| 1662 |
+
q,
|
| 1663 |
+
rfa_k,
|
| 1664 |
+
rfa_v,
|
| 1665 |
+
chunk_mask,
|
| 1666 |
+
do,
|
| 1667 |
+
lse,
|
| 1668 |
+
do_t_o,
|
| 1669 |
+
d_rfa_k,
|
| 1670 |
+
d_rfa_v,
|
| 1671 |
+
softmax_scale,
|
| 1672 |
+
q.stride(0), q.stride(1), q.stride(2),
|
| 1673 |
+
rfa_k_strides[0], rfa_k_strides[1], rfa_k_strides[2],
|
| 1674 |
+
rfa_v_strides[0], rfa_v_strides[1], rfa_v_strides[2],
|
| 1675 |
+
chunk_mask_strides[0], chunk_mask_strides[1],
|
| 1676 |
+
do.stride(0), do.stride(1), do.stride(2),
|
| 1677 |
+
lse.stride(0), lse.stride(1),
|
| 1678 |
+
do_t_o.stride(0), do_t_o.stride(1),
|
| 1679 |
+
d_rfa_k_strides[0], d_rfa_k_strides[1], d_rfa_k_strides[2],
|
| 1680 |
+
d_rfa_v_strides[0], d_rfa_v_strides[1], d_rfa_v_strides[2],
|
| 1681 |
+
nheads,
|
| 1682 |
+
seqlen_q,
|
| 1683 |
+
nchunks,
|
| 1684 |
+
head_dim,
|
| 1685 |
+
chunks_per_window,
|
| 1686 |
+
window_size,
|
| 1687 |
+
mask_type,
|
| 1688 |
+
BLOCK_HEADDIM,
|
| 1689 |
+
BLOCK_M=BLOCK_M,
|
| 1690 |
+
BLOCK_N=BLOCK_N,
|
| 1691 |
+
num_warps=num_warps,
|
| 1692 |
+
num_stages=num_stages,
|
| 1693 |
+
)
|
| 1694 |
+
|
| 1695 |
+
|
| 1696 |
+
class EvaAggFunc(torch.autograd.Function):
|
| 1697 |
+
@staticmethod
|
| 1698 |
+
def forward(ctx, q, k, v, rfa_k, rfa_v, window_mask, chunk_mask, softmax_scale=None, window_size=None, chunks_per_window=None):
|
| 1699 |
+
if rfa_k is None and rfa_v is None:
|
| 1700 |
+
empty_rfa_kv = 1
|
| 1701 |
+
else:
|
| 1702 |
+
assert rfa_k is not None and rfa_v is not None, "Both rfa_k and rfa_v must either be None or have values at the same time."
|
| 1703 |
+
empty_rfa_kv = 0
|
| 1704 |
+
|
| 1705 |
+
if window_mask is not None:
|
| 1706 |
+
mask_type = 1
|
| 1707 |
+
else:
|
| 1708 |
+
mask_type = 0
|
| 1709 |
+
o, lse = triton_eva_agg_fwd(
|
| 1710 |
+
q, k, v, rfa_k, rfa_v, window_mask, chunk_mask, softmax_scale, window_size, chunks_per_window
|
| 1711 |
+
)
|
| 1712 |
+
ctx.save_for_backward(q, k, v, o, lse, rfa_k, rfa_v, window_mask, chunk_mask)
|
| 1713 |
+
ctx.softmax_scale = softmax_scale
|
| 1714 |
+
ctx.window_size = window_size
|
| 1715 |
+
ctx.chunks_per_window = chunks_per_window
|
| 1716 |
+
ctx.empty_rfa_kv = empty_rfa_kv
|
| 1717 |
+
ctx.mask_type = mask_type
|
| 1718 |
+
return o
|
| 1719 |
+
|
| 1720 |
+
@staticmethod
|
| 1721 |
+
def backward(ctx, do):
|
| 1722 |
+
q, k, v, o, lse, rfa_k, rfa_v, window_mask, chunk_mask = ctx.saved_tensors
|
| 1723 |
+
dq = torch.empty_like(q)
|
| 1724 |
+
dk = torch.empty_like(k)
|
| 1725 |
+
dv = torch.empty_like(v)
|
| 1726 |
+
if ctx.empty_rfa_kv == 0:
|
| 1727 |
+
d_rfa_k = torch.empty_like(rfa_k)
|
| 1728 |
+
d_rfa_v = torch.empty_like(rfa_v)
|
| 1729 |
+
else:
|
| 1730 |
+
d_rfa_k = None
|
| 1731 |
+
d_rfa_v = None
|
| 1732 |
+
triton_eva_agg_bwd(
|
| 1733 |
+
do,
|
| 1734 |
+
q,
|
| 1735 |
+
k,
|
| 1736 |
+
v,
|
| 1737 |
+
rfa_k,
|
| 1738 |
+
rfa_v,
|
| 1739 |
+
window_mask,
|
| 1740 |
+
chunk_mask,
|
| 1741 |
+
o,
|
| 1742 |
+
lse,
|
| 1743 |
+
dq,
|
| 1744 |
+
dk,
|
| 1745 |
+
dv,
|
| 1746 |
+
d_rfa_k,
|
| 1747 |
+
d_rfa_v,
|
| 1748 |
+
softmax_scale=ctx.softmax_scale,
|
| 1749 |
+
window_size=ctx.window_size,
|
| 1750 |
+
chunks_per_window=ctx.chunks_per_window,
|
| 1751 |
+
empty_rfa_kv=ctx.empty_rfa_kv,
|
| 1752 |
+
mask_type=ctx.mask_type,
|
| 1753 |
+
)
|
| 1754 |
+
return dq, dk, dv, d_rfa_k, d_rfa_v, None, None, None, None, None
|
| 1755 |
+
|
| 1756 |
+
|
| 1757 |
+
def eva_agg_func_triton(
|
| 1758 |
+
q, k, v, rfa_k, rfa_v,
|
| 1759 |
+
window_mask, chunk_mask,
|
| 1760 |
+
softmax_scale=None, window_size=None, chunks_per_window=None,
|
| 1761 |
+
):
|
| 1762 |
+
return EvaAggFunc.apply(
|
| 1763 |
+
q, k, v, rfa_k, rfa_v,
|
| 1764 |
+
window_mask, chunk_mask,
|
| 1765 |
+
softmax_scale, window_size, chunks_per_window,
|
| 1766 |
+
)
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/eva_cache.py
ADDED
|
@@ -0,0 +1,761 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Optional, Tuple, List, Any, Union
|
| 2 |
+
import torch
|
| 3 |
+
from transformers.cache_utils import Cache
|
| 4 |
+
|
| 5 |
+
class EvaCache(Cache):
|
| 6 |
+
"""
|
| 7 |
+
A cache that grows dynamically as more tokens are generated. This is the default for generative models.
|
| 8 |
+
|
| 9 |
+
It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
|
| 10 |
+
`[batch_size, num_heads, seq_len, head_dim]`.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
def __init__(self) -> None:
|
| 14 |
+
self.w_k: List[torch.Tensor] = []
|
| 15 |
+
self.w_v: List[torch.Tensor] = []
|
| 16 |
+
|
| 17 |
+
self.rf_q: List[torch.Tensor] = []
|
| 18 |
+
self.rf_k: List[torch.Tensor] = []
|
| 19 |
+
self.rf_v: List[torch.Tensor] = []
|
| 20 |
+
|
| 21 |
+
self.softmax_phi_k_v: List[torch.Tensor] = []
|
| 22 |
+
self.log_sum_phi_k: List[torch.Tensor] = []
|
| 23 |
+
self.rf_k_bar: List[torch.Tensor] = []
|
| 24 |
+
self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
|
| 25 |
+
|
| 26 |
+
# attention masks temporary buffer
|
| 27 |
+
self.rf_mask: List[Optional[torch.Tensor]] = []
|
| 28 |
+
self.s_mask: List[torch.Tensor] = []
|
| 29 |
+
self.chunk_mask: List[torch.Tensor] = []
|
| 30 |
+
|
| 31 |
+
def __len__(self):
|
| 32 |
+
"""
|
| 33 |
+
Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
|
| 34 |
+
to the number of layers in the model.
|
| 35 |
+
"""
|
| 36 |
+
return len(self.w_k)
|
| 37 |
+
|
| 38 |
+
def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
|
| 39 |
+
"""Given the sequence length of the new inputs, returns the usable length of the cache."""
|
| 40 |
+
# Cache without size limit -> all cache is usable
|
| 41 |
+
# Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
|
| 42 |
+
# length, we will need to evict part of the cache (and thus not all cache is usable)
|
| 43 |
+
max_length = self.get_max_length()
|
| 44 |
+
previous_seq_length = self.get_seq_length(layer_idx)
|
| 45 |
+
if max_length is not None and previous_seq_length + new_seq_length > max_length:
|
| 46 |
+
return max_length - new_seq_length
|
| 47 |
+
return previous_seq_length
|
| 48 |
+
|
| 49 |
+
def reorder_cache(self, beam_idx: torch.LongTensor):
|
| 50 |
+
"""Reorders the cache for beam search, given the selected beam indices."""
|
| 51 |
+
for layer_idx in range(len(self.w_k)):
|
| 52 |
+
device = self.w_k[layer_idx].device
|
| 53 |
+
self.w_k[layer_idx] = self.w_k[layer_idx].index_select(0, beam_idx.to(device))
|
| 54 |
+
|
| 55 |
+
device = self.w_v[layer_idx].device
|
| 56 |
+
self.w_v[layer_idx] = self.w_v[layer_idx].index_select(0, beam_idx.to(device))
|
| 57 |
+
|
| 58 |
+
device = self.rf_q[layer_idx].device
|
| 59 |
+
self.rf_q[layer_idx] = self.rf_q[layer_idx].index_select(0, beam_idx.to(device))
|
| 60 |
+
|
| 61 |
+
device = self.rf_k[layer_idx].device
|
| 62 |
+
self.rf_k[layer_idx] = self.rf_k[layer_idx].index_select(0, beam_idx.to(device))
|
| 63 |
+
|
| 64 |
+
device = self.rf_v[layer_idx].device
|
| 65 |
+
self.rf_v[layer_idx] = self.rf_v[layer_idx].index_select(0, beam_idx.to(device))
|
| 66 |
+
|
| 67 |
+
device = self.softmax_phi_k_v[layer_idx].device
|
| 68 |
+
self.softmax_phi_k_v[layer_idx] = self.softmax_phi_k_v[layer_idx].index_select(0, beam_idx.to(device))
|
| 69 |
+
|
| 70 |
+
device = self.log_sum_phi_k[layer_idx].device
|
| 71 |
+
self.log_sum_phi_k[layer_idx] = self.log_sum_phi_k[layer_idx].index_select(0, beam_idx.to(device))
|
| 72 |
+
|
| 73 |
+
device = self.rf_k_bar[layer_idx].device
|
| 74 |
+
self.rf_k_bar[layer_idx] = self.rf_k_bar[layer_idx].index_select(0, beam_idx.to(device))
|
| 75 |
+
|
| 76 |
+
device = self.rf_mask[layer_idx].device
|
| 77 |
+
self.rf_mask[layer_idx] = self.rf_mask[layer_idx].index_select(0, beam_idx.to(device))
|
| 78 |
+
|
| 79 |
+
device = self.s_mask[layer_idx].device
|
| 80 |
+
self.s_mask[layer_idx] = self.s_mask[layer_idx].index_select(0, beam_idx.to(device))
|
| 81 |
+
|
| 82 |
+
device = self.chunk_mask[layer_idx].device
|
| 83 |
+
self.chunk_mask[layer_idx] = self.chunk_mask[layer_idx].index_select(0, beam_idx.to(device))
|
| 84 |
+
@property
|
| 85 |
+
def seen_tokens(self):
|
| 86 |
+
if hasattr(self, "_seen_tokens"):
|
| 87 |
+
return self._seen_tokens
|
| 88 |
+
else:
|
| 89 |
+
return None
|
| 90 |
+
|
| 91 |
+
def update_past_len(
|
| 92 |
+
self,
|
| 93 |
+
cur_q_len: int,
|
| 94 |
+
layer_idx: int
|
| 95 |
+
):
|
| 96 |
+
# Update the number of seen tokens
|
| 97 |
+
if layer_idx == 0:
|
| 98 |
+
self._seen_tokens += cur_q_len
|
| 99 |
+
return self._seen_tokens
|
| 100 |
+
|
| 101 |
+
def update_mask(
|
| 102 |
+
self,
|
| 103 |
+
prev_s_mask,
|
| 104 |
+
cur_s_mask,
|
| 105 |
+
chunk_mask,
|
| 106 |
+
rf_mask,
|
| 107 |
+
layer_idx,
|
| 108 |
+
window_size,
|
| 109 |
+
chunk_size,
|
| 110 |
+
):
|
| 111 |
+
############################################
|
| 112 |
+
# compute masks for singletons
|
| 113 |
+
############################################
|
| 114 |
+
q_len = None
|
| 115 |
+
if len(self.s_mask) <= layer_idx:
|
| 116 |
+
q_len = chunk_mask.shape[-2]
|
| 117 |
+
# prefill stage
|
| 118 |
+
# q is of shape [b, h, n, d]
|
| 119 |
+
if q_len < window_size:
|
| 120 |
+
assert prev_s_mask is None
|
| 121 |
+
|
| 122 |
+
# w_v = # [b, h, 1, j, d]
|
| 123 |
+
# store the past window-wise key-value pairs
|
| 124 |
+
self.s_mask.append(cur_s_mask[..., -1:, :] if cur_s_mask is not None else prev_s_mask[..., -1, -1:, :])
|
| 125 |
+
else:
|
| 126 |
+
# decoding stage
|
| 127 |
+
prev_s_mask = None
|
| 128 |
+
|
| 129 |
+
cached_s_mask = self.s_mask[layer_idx]
|
| 130 |
+
assert cached_s_mask is not None
|
| 131 |
+
if cached_s_mask.shape[-1] == window_size:
|
| 132 |
+
cur_s_mask = cur_s_mask
|
| 133 |
+
else:
|
| 134 |
+
cur_s_mask = torch.cat([cached_s_mask, cur_s_mask], dim=-1)
|
| 135 |
+
|
| 136 |
+
# store the past window-wise key-value pairs
|
| 137 |
+
self.s_mask[layer_idx] = cur_s_mask
|
| 138 |
+
|
| 139 |
+
############################################
|
| 140 |
+
# compute masks for intra-chunks
|
| 141 |
+
############################################
|
| 142 |
+
dump_rf_mask = None
|
| 143 |
+
if len(self.rf_mask) <= layer_idx:
|
| 144 |
+
# initialize chunk stats
|
| 145 |
+
# prefill stage
|
| 146 |
+
if q_len < chunk_size:
|
| 147 |
+
cur_rf_mask = rf_mask
|
| 148 |
+
else:
|
| 149 |
+
if q_len % chunk_size == 0:
|
| 150 |
+
dump_rf_mask = rf_mask
|
| 151 |
+
cur_rf_mask = None
|
| 152 |
+
else:
|
| 153 |
+
remainder_tokens = q_len % chunk_size
|
| 154 |
+
if rf_mask is not None:
|
| 155 |
+
dump_rf_mask, cur_rf_mask = torch.split(rf_mask, [q_len - remainder_tokens, remainder_tokens], dim=-2)
|
| 156 |
+
else:
|
| 157 |
+
dump_rf_mask = None
|
| 158 |
+
cur_rf_mask = None
|
| 159 |
+
self.rf_mask.append(cur_rf_mask)
|
| 160 |
+
else:
|
| 161 |
+
past_rf_mask = self.rf_mask[layer_idx]
|
| 162 |
+
if past_rf_mask is not None:
|
| 163 |
+
# when decoding tokens, we always assume the
|
| 164 |
+
# incoming token mask is 0 (not masked)
|
| 165 |
+
cur_rf_mask = torch.cat([past_rf_mask, rf_mask], dim=-2)
|
| 166 |
+
else:
|
| 167 |
+
# we do not need to use rf_mask anymore after we receive generated tokens
|
| 168 |
+
cur_rf_mask = None
|
| 169 |
+
# We need to store rf_k_bar and RFA-results that
|
| 170 |
+
# compute the per-chunk RFA.
|
| 171 |
+
|
| 172 |
+
# Dump the chunk if the len of current chunk reaches <chunk_size>.
|
| 173 |
+
if cur_rf_mask is not None and cur_rf_mask.shape[-2] == chunk_size:
|
| 174 |
+
dump_rf_mask = cur_rf_mask
|
| 175 |
+
cur_rf_mask = None
|
| 176 |
+
|
| 177 |
+
self.rf_mask[layer_idx] = cur_rf_mask
|
| 178 |
+
|
| 179 |
+
############################################
|
| 180 |
+
# compute masks for inter chunks
|
| 181 |
+
############################################
|
| 182 |
+
if len(self.chunk_mask) <= layer_idx:
|
| 183 |
+
# prefill stage
|
| 184 |
+
# q is of shape [b, h, n, d]
|
| 185 |
+
if q_len < window_size:
|
| 186 |
+
cur_chunk_mask = chunk_mask
|
| 187 |
+
prev_chunk_mask = None
|
| 188 |
+
else:
|
| 189 |
+
if q_len % window_size == 0:
|
| 190 |
+
cur_chunk_mask = None
|
| 191 |
+
prev_chunk_mask = chunk_mask
|
| 192 |
+
else:
|
| 193 |
+
remainder_tokens = q_len % window_size
|
| 194 |
+
# [b, h, n-r, d] [b, h, r, d]
|
| 195 |
+
prev_chunk_mask, cur_chunk_mask = torch.split(chunk_mask, [q_len - remainder_tokens, remainder_tokens], dim=-2)
|
| 196 |
+
bsz, num_heads, _, head_dim = prev_chunk_mask.shape
|
| 197 |
+
prev_chunk_mask = prev_chunk_mask.reshape(bsz, num_heads, -1, window_size, head_dim)
|
| 198 |
+
|
| 199 |
+
assert prev_s_mask is not None
|
| 200 |
+
if prev_s_mask.shape[-3] == 1 and prev_chunk_mask.shape[-3] > 1:
|
| 201 |
+
# need to expand
|
| 202 |
+
prev_s_mask = prev_s_mask.expand(-1, -1, prev_chunk_mask.shape[-3], -1, -1)
|
| 203 |
+
# w_v = # [b, h, 1, j, d]
|
| 204 |
+
# store the past window-wise key-value pairs
|
| 205 |
+
self.chunk_mask.append(cur_chunk_mask[..., -1:, :] if cur_chunk_mask is not None else prev_chunk_mask[..., -1, -1:, :])
|
| 206 |
+
else:
|
| 207 |
+
# decoding stage
|
| 208 |
+
prev_chunk_mask = None
|
| 209 |
+
cur_chunk_mask = self.chunk_mask[layer_idx]
|
| 210 |
+
|
| 211 |
+
# if the current sequence length reaches <chunk_size>,
|
| 212 |
+
# we append a new 1 to the end of chunk_mask
|
| 213 |
+
seen_seq_len = self.get_seq_length(layer_idx)
|
| 214 |
+
if seen_seq_len > 0 and seen_seq_len % chunk_size == 0:
|
| 215 |
+
past_chunk_mask = self.chunk_mask[layer_idx]
|
| 216 |
+
if past_chunk_mask is not None:
|
| 217 |
+
# when decoding tokens, we always assume the
|
| 218 |
+
# incoming token mask is 0 (not masked)
|
| 219 |
+
cur_chunk_mask = torch.cat([past_chunk_mask, chunk_mask], dim=-1)
|
| 220 |
+
else:
|
| 221 |
+
cur_chunk_mask = chunk_mask
|
| 222 |
+
self.chunk_mask[layer_idx] = cur_chunk_mask
|
| 223 |
+
|
| 224 |
+
# if the len of current sequence reaches <window_size> + 1,
|
| 225 |
+
# we turn on the mask for most recent chunks
|
| 226 |
+
if seen_seq_len > 0 and seen_seq_len % window_size == 1:
|
| 227 |
+
cur_chunk_mask = self.chunk_mask[layer_idx]
|
| 228 |
+
# we do not need to use rf_mask anymore after we receive generated tokens
|
| 229 |
+
num_chunks_per_window = window_size // chunk_size
|
| 230 |
+
cur_chunk_mask[..., -num_chunks_per_window:] = False
|
| 231 |
+
self.chunk_mask[layer_idx] = cur_chunk_mask
|
| 232 |
+
|
| 233 |
+
return (prev_s_mask, cur_s_mask, prev_chunk_mask, cur_chunk_mask, dump_rf_mask)
|
| 234 |
+
|
| 235 |
+
def update_singletons(
|
| 236 |
+
self,
|
| 237 |
+
q,
|
| 238 |
+
k,
|
| 239 |
+
v,
|
| 240 |
+
layer_idx,
|
| 241 |
+
window_size,
|
| 242 |
+
):
|
| 243 |
+
if len(self.w_k) <= layer_idx:
|
| 244 |
+
# prefill stage
|
| 245 |
+
# q is of shape [b, h, n, d]
|
| 246 |
+
q_len = q.shape[-2]
|
| 247 |
+
if q_len < window_size:
|
| 248 |
+
w_q = q
|
| 249 |
+
w_k = k
|
| 250 |
+
w_v = v
|
| 251 |
+
past_w_q = past_w_k = past_w_v = None
|
| 252 |
+
else:
|
| 253 |
+
if q_len % window_size == 0:
|
| 254 |
+
w_q = None
|
| 255 |
+
w_k = None
|
| 256 |
+
w_v = None
|
| 257 |
+
past_w_q = q
|
| 258 |
+
past_w_k = k
|
| 259 |
+
past_w_v = v
|
| 260 |
+
else:
|
| 261 |
+
remainder_tokens = q_len % window_size
|
| 262 |
+
# [b, h, n-r, d] [b, h, r, d]
|
| 263 |
+
past_w_q, w_q = torch.split(q, [q_len - remainder_tokens, remainder_tokens], dim=-2)
|
| 264 |
+
past_w_k, w_k = torch.split(k, [q_len - remainder_tokens, remainder_tokens], dim=-2)
|
| 265 |
+
past_w_v, w_v = torch.split(v, [q_len - remainder_tokens, remainder_tokens], dim=-2)
|
| 266 |
+
bsz, num_heads, _, head_dim = past_w_q.shape
|
| 267 |
+
past_w_q = past_w_q.reshape(bsz, num_heads, -1, window_size, head_dim)
|
| 268 |
+
past_w_k = past_w_k.reshape(bsz, num_heads, -1, window_size, head_dim)
|
| 269 |
+
past_w_v = past_w_v.reshape(bsz, num_heads, -1, window_size, head_dim)
|
| 270 |
+
# w_q = q[..., None, -window_size:, :] # [b, h, 1, j, d]
|
| 271 |
+
# w_k = # [b, h, 1, j, d]
|
| 272 |
+
# w_v = # [b, h, 1, j, d]
|
| 273 |
+
# store the past window-wise key-value pairs
|
| 274 |
+
# if w_k is None, it means we happen to pass in a sqeuence that is divisible by window_size
|
| 275 |
+
# we leave the cache with window_size-sized kv pairs to be cleared next iteration
|
| 276 |
+
self.w_k.append(w_k if w_k is not None else past_w_k[..., -1, :, :])
|
| 277 |
+
self.w_v.append(w_v if w_v is not None else past_w_v[..., -1, :, :])
|
| 278 |
+
else:
|
| 279 |
+
# decoding stage
|
| 280 |
+
past_w_q = past_w_k = past_w_v = None
|
| 281 |
+
# this is implemented as either a sliding window or fixed window
|
| 282 |
+
w_q = q # [b, h, 1, d]
|
| 283 |
+
w_k = k # [b, h, 1, d]
|
| 284 |
+
w_v = v # [b, h, 1, d]
|
| 285 |
+
|
| 286 |
+
cached_w_k = self.w_k[layer_idx]
|
| 287 |
+
assert cached_w_k is not None # [b, h, j, d]
|
| 288 |
+
if cached_w_k.shape[-2] == window_size:
|
| 289 |
+
w_k = w_k
|
| 290 |
+
else:
|
| 291 |
+
w_k = torch.cat([cached_w_k, w_k], dim=-2)
|
| 292 |
+
|
| 293 |
+
cached_w_v = self.w_v[layer_idx]
|
| 294 |
+
assert cached_w_v is not None
|
| 295 |
+
if cached_w_v.shape[-2] == window_size:
|
| 296 |
+
w_v = w_v
|
| 297 |
+
else:
|
| 298 |
+
w_v = torch.cat([cached_w_v, w_v], dim=-2)
|
| 299 |
+
|
| 300 |
+
# store the past window-wise key-value pairs
|
| 301 |
+
self.w_k[layer_idx] = w_k
|
| 302 |
+
self.w_v[layer_idx] = w_v
|
| 303 |
+
return (past_w_q, past_w_k, past_w_v), (w_q, w_k, w_v)
|
| 304 |
+
|
| 305 |
+
def update_chunks(
|
| 306 |
+
self,
|
| 307 |
+
q,
|
| 308 |
+
k,
|
| 309 |
+
v,
|
| 310 |
+
layer_idx,
|
| 311 |
+
chunk_size
|
| 312 |
+
):
|
| 313 |
+
q_len = q.shape[-2]
|
| 314 |
+
dump_q = None
|
| 315 |
+
dump_k = None
|
| 316 |
+
dump_v = None
|
| 317 |
+
if len(self.rf_q) <= layer_idx:
|
| 318 |
+
# initialize chunk stats
|
| 319 |
+
# prefill stage
|
| 320 |
+
if q_len < chunk_size:
|
| 321 |
+
rf_q = q
|
| 322 |
+
rf_k = k
|
| 323 |
+
rf_v = v
|
| 324 |
+
else:
|
| 325 |
+
if q_len % chunk_size == 0:
|
| 326 |
+
rf_q = None
|
| 327 |
+
rf_k = None
|
| 328 |
+
rf_v = None
|
| 329 |
+
dump_q = q
|
| 330 |
+
dump_k = k
|
| 331 |
+
dump_v = v
|
| 332 |
+
else:
|
| 333 |
+
remainder_tokens = q_len % chunk_size
|
| 334 |
+
# [b, h, n-r, d] [b, h, r, d]
|
| 335 |
+
dump_q, rf_q = torch.split(q, [q_len - remainder_tokens, remainder_tokens], dim=-2)
|
| 336 |
+
dump_k, rf_k = torch.split(k, [q_len - remainder_tokens, remainder_tokens], dim=-2)
|
| 337 |
+
dump_v, rf_v = torch.split(v, [q_len - remainder_tokens, remainder_tokens], dim=-2)
|
| 338 |
+
self.rf_q.append(rf_q)
|
| 339 |
+
self.rf_k.append(rf_k)
|
| 340 |
+
self.rf_v.append(rf_v)
|
| 341 |
+
else:
|
| 342 |
+
# decode tokens
|
| 343 |
+
# add query, key & value to the current chunk.
|
| 344 |
+
past_rf_q = self.rf_q[layer_idx]
|
| 345 |
+
if past_rf_q is not None:
|
| 346 |
+
rf_q = torch.cat([past_rf_q, q], dim=-2)
|
| 347 |
+
else:
|
| 348 |
+
rf_q = q
|
| 349 |
+
|
| 350 |
+
past_rf_k = self.rf_k[layer_idx]
|
| 351 |
+
if past_rf_k is not None:
|
| 352 |
+
rf_k = torch.cat([past_rf_k, k], dim=-2)
|
| 353 |
+
else:
|
| 354 |
+
rf_k = k
|
| 355 |
+
|
| 356 |
+
past_rf_v = self.rf_v[layer_idx]
|
| 357 |
+
if past_rf_v is not None:
|
| 358 |
+
rf_v = torch.cat([past_rf_v, v], dim=-2)
|
| 359 |
+
else:
|
| 360 |
+
rf_v = v
|
| 361 |
+
|
| 362 |
+
# We need to store rf_k_bar and RFA-results that
|
| 363 |
+
# compute the per-chunk RFA.
|
| 364 |
+
|
| 365 |
+
# Dump the chunk if the len of current chunk reaches <chunk_size>.
|
| 366 |
+
if rf_q.shape[-2] == chunk_size:
|
| 367 |
+
dump_q = rf_q
|
| 368 |
+
dump_k = rf_k
|
| 369 |
+
dump_v = rf_v
|
| 370 |
+
# clear the chunk
|
| 371 |
+
rf_q = None
|
| 372 |
+
rf_k = None
|
| 373 |
+
rf_v = None
|
| 374 |
+
|
| 375 |
+
self.rf_q[layer_idx] = rf_q
|
| 376 |
+
self.rf_k[layer_idx] = rf_k
|
| 377 |
+
self.rf_v[layer_idx] = rf_v
|
| 378 |
+
|
| 379 |
+
return dump_q, dump_k, dump_v
|
| 380 |
+
|
| 381 |
+
def update_chunk_rfas(
|
| 382 |
+
self,
|
| 383 |
+
softmax_phi_k_v,
|
| 384 |
+
log_sum_phi_k,
|
| 385 |
+
rf_k_bar,
|
| 386 |
+
layer_idx,
|
| 387 |
+
random_feature_dim
|
| 388 |
+
):
|
| 389 |
+
if len(self.softmax_phi_k_v) <= layer_idx:
|
| 390 |
+
# prefill stage
|
| 391 |
+
self.softmax_phi_k_v.append(softmax_phi_k_v)
|
| 392 |
+
self.log_sum_phi_k.append(log_sum_phi_k)
|
| 393 |
+
self.rf_k_bar.append(rf_k_bar)
|
| 394 |
+
else:
|
| 395 |
+
# token decoding
|
| 396 |
+
past_softmax_phi_k_v = self.softmax_phi_k_v[layer_idx]
|
| 397 |
+
past_log_sum_phi_k = self.log_sum_phi_k[layer_idx]
|
| 398 |
+
past_rf_k_bar = self.rf_k_bar[layer_idx]
|
| 399 |
+
|
| 400 |
+
if past_softmax_phi_k_v is not None:
|
| 401 |
+
if random_feature_dim == 1:
|
| 402 |
+
dim = -2
|
| 403 |
+
else:
|
| 404 |
+
dim = -3
|
| 405 |
+
softmax_phi_k_v = torch.cat([past_softmax_phi_k_v, softmax_phi_k_v], dim=dim)
|
| 406 |
+
|
| 407 |
+
if past_log_sum_phi_k is not None:
|
| 408 |
+
if random_feature_dim == 1:
|
| 409 |
+
dim = -2
|
| 410 |
+
else:
|
| 411 |
+
dim = -3
|
| 412 |
+
log_sum_phi_k = torch.cat([past_log_sum_phi_k, log_sum_phi_k], dim=dim)
|
| 413 |
+
|
| 414 |
+
if past_rf_k_bar is not None:
|
| 415 |
+
rf_k_bar = torch.cat([past_rf_k_bar, rf_k_bar], dim=-2)
|
| 416 |
+
|
| 417 |
+
self.softmax_phi_k_v[layer_idx] = softmax_phi_k_v
|
| 418 |
+
self.log_sum_phi_k[layer_idx] = log_sum_phi_k
|
| 419 |
+
self.rf_k_bar[layer_idx] = rf_k_bar
|
| 420 |
+
|
| 421 |
+
return softmax_phi_k_v, log_sum_phi_k, rf_k_bar
|
| 422 |
+
|
| 423 |
+
def get_chunk_rfas(self, layer_idx):
|
| 424 |
+
if len(self.softmax_phi_k_v) <= layer_idx:
|
| 425 |
+
return (
|
| 426 |
+
None,
|
| 427 |
+
None,
|
| 428 |
+
None
|
| 429 |
+
)
|
| 430 |
+
else:
|
| 431 |
+
return (
|
| 432 |
+
self.softmax_phi_k_v[layer_idx],
|
| 433 |
+
self.log_sum_phi_k[layer_idx],
|
| 434 |
+
self.rf_k_bar[layer_idx]
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
+
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
|
| 438 |
+
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
|
| 439 |
+
if len(self.w_k) <= layer_idx:
|
| 440 |
+
return 0
|
| 441 |
+
return self._seen_tokens
|
| 442 |
+
|
| 443 |
+
def get_max_length(self) -> Optional[int]:
|
| 444 |
+
"""Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
|
| 445 |
+
return None
|
| 446 |
+
|
| 447 |
+
def update(
|
| 448 |
+
self,
|
| 449 |
+
layer_idx: int,
|
| 450 |
+
cache_kwargs: Optional[Dict[str, Any]] = None,
|
| 451 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 452 |
+
raise NotImplementedError("`update` is not used in Eva Cache.")
|
| 453 |
+
|
| 454 |
+
class EvaStaticCacheForTriton(Cache):
|
| 455 |
+
"""
|
| 456 |
+
A variant of EvaCache for eva's triton kernels
|
| 457 |
+
"""
|
| 458 |
+
|
| 459 |
+
def __init__(
|
| 460 |
+
self,
|
| 461 |
+
batch_size,
|
| 462 |
+
num_key_value_heads,
|
| 463 |
+
window_size,
|
| 464 |
+
head_dim,
|
| 465 |
+
num_layers,
|
| 466 |
+
dtype,
|
| 467 |
+
device
|
| 468 |
+
) -> None:
|
| 469 |
+
self.past_window_k: List[torch.Tensor] = []
|
| 470 |
+
self.past_window_v: List[torch.Tensor] = []
|
| 471 |
+
|
| 472 |
+
cache_shape = (batch_size, num_key_value_heads, window_size, head_dim)
|
| 473 |
+
for idx in range(num_layers):
|
| 474 |
+
new_window_k = torch.zeros(cache_shape, dtype=dtype, device=device)
|
| 475 |
+
new_window_v = torch.zeros(cache_shape, dtype=dtype, device=device)
|
| 476 |
+
self.past_window_k.append(new_window_k)
|
| 477 |
+
self.past_window_v.append(new_window_v)
|
| 478 |
+
|
| 479 |
+
self.past_window_pos: List[int] = []
|
| 480 |
+
|
| 481 |
+
self.rfa_k: List[torch.Tensor] = []
|
| 482 |
+
self.rfa_v: List[torch.Tensor] = []
|
| 483 |
+
# self.rfa_mask: List[torch.Tensor] = []
|
| 484 |
+
|
| 485 |
+
self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
|
| 486 |
+
|
| 487 |
+
# attention masks temporary buffer
|
| 488 |
+
self.rf_mask: List[Optional[torch.Tensor]] = []
|
| 489 |
+
self.s_mask: List[torch.Tensor] = []
|
| 490 |
+
|
| 491 |
+
def __len__(self):
|
| 492 |
+
"""
|
| 493 |
+
Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
|
| 494 |
+
to the number of layers in the model.
|
| 495 |
+
"""
|
| 496 |
+
return len(self.past_window_pos)
|
| 497 |
+
|
| 498 |
+
def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
|
| 499 |
+
"""Given the sequence length of the new inputs, returns the usable length of the cache."""
|
| 500 |
+
# Cache without size limit -> all cache is usable
|
| 501 |
+
# Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
|
| 502 |
+
# length, we will need to evict part of the cache (and thus not all cache is usable)
|
| 503 |
+
max_length = self.get_max_length()
|
| 504 |
+
previous_seq_length = self.get_seq_length(layer_idx)
|
| 505 |
+
if max_length is not None and previous_seq_length + new_seq_length > max_length:
|
| 506 |
+
return max_length - new_seq_length
|
| 507 |
+
return previous_seq_length
|
| 508 |
+
|
| 509 |
+
def reorder_cache(self, beam_idx: torch.LongTensor):
|
| 510 |
+
"""Reorders the cache for beam search, given the selected beam indices."""
|
| 511 |
+
for layer_idx in range(len(self.past_window_k)):
|
| 512 |
+
device = self.past_window_k[layer_idx].device
|
| 513 |
+
self.past_window_k[layer_idx] = self.past_window_k[layer_idx].index_select(0, beam_idx.to(device))
|
| 514 |
+
|
| 515 |
+
device = self.past_window_v[layer_idx].device
|
| 516 |
+
self.past_window_v[layer_idx] = self.past_window_v[layer_idx].index_select(0, beam_idx.to(device))
|
| 517 |
+
|
| 518 |
+
device = self.rfa_k[layer_idx].device
|
| 519 |
+
self.rfa_k[layer_idx] = self.rfa_k[layer_idx].index_select(0, beam_idx.to(device))
|
| 520 |
+
|
| 521 |
+
device = self.rfa_v[layer_idx].device
|
| 522 |
+
self.rfa_v[layer_idx] = self.rfa_v[layer_idx].index_select(0, beam_idx.to(device))
|
| 523 |
+
|
| 524 |
+
# device = self.rfa_mask[layer_idx].device
|
| 525 |
+
# self.rfa_mask[layer_idx] = self.rfa_mask[layer_idx].index_select(0, beam_idx.to(device))
|
| 526 |
+
|
| 527 |
+
device = self.rf_mask[layer_idx].device
|
| 528 |
+
self.rf_mask[layer_idx] = self.rf_mask[layer_idx].index_select(0, beam_idx.to(device))
|
| 529 |
+
|
| 530 |
+
device = self.s_mask[layer_idx].device
|
| 531 |
+
self.s_mask[layer_idx] = self.s_mask[layer_idx].index_select(0, beam_idx.to(device))
|
| 532 |
+
|
| 533 |
+
@property
|
| 534 |
+
def seen_tokens(self):
|
| 535 |
+
if hasattr(self, "_seen_tokens"):
|
| 536 |
+
return self._seen_tokens
|
| 537 |
+
else:
|
| 538 |
+
return None
|
| 539 |
+
|
| 540 |
+
def update_past_len(
|
| 541 |
+
self,
|
| 542 |
+
cur_q_len: int,
|
| 543 |
+
layer_idx: int
|
| 544 |
+
):
|
| 545 |
+
# Update the number of seen tokens
|
| 546 |
+
if layer_idx == 0:
|
| 547 |
+
self._seen_tokens += cur_q_len
|
| 548 |
+
return self._seen_tokens
|
| 549 |
+
|
| 550 |
+
def update_mask(
|
| 551 |
+
self,
|
| 552 |
+
s_mask,
|
| 553 |
+
rf_mask,
|
| 554 |
+
layer_idx,
|
| 555 |
+
window_size,
|
| 556 |
+
):
|
| 557 |
+
############################################
|
| 558 |
+
# compute masks for singletons
|
| 559 |
+
############################################
|
| 560 |
+
if len(self.s_mask) <= layer_idx:
|
| 561 |
+
# prefill stage
|
| 562 |
+
# q is of shape [b, h, n, d]
|
| 563 |
+
# s_v = # [b, h, 1, j, d]
|
| 564 |
+
# store the past window-wise key-value pairs
|
| 565 |
+
if s_mask is None:
|
| 566 |
+
cur_s_mask = None
|
| 567 |
+
else:
|
| 568 |
+
q_len = s_mask.shape[-2]
|
| 569 |
+
# s_mask is of shape [b, h, n, w]
|
| 570 |
+
# let r = q_len % window_size
|
| 571 |
+
# if r == 0, the mask to be appended is of shape [b, h, 1, w]
|
| 572 |
+
# otherwise, r < w, the mask to be appended is of shape [b, h, 1, r]
|
| 573 |
+
remainder_tokens = q_len % window_size
|
| 574 |
+
if remainder_tokens == 0:
|
| 575 |
+
cur_s_mask = None
|
| 576 |
+
else:
|
| 577 |
+
cur_s_mask = s_mask[..., -1:, :remainder_tokens]
|
| 578 |
+
self.s_mask.append(cur_s_mask)
|
| 579 |
+
# we use the passed s_mask for subsequent computations
|
| 580 |
+
dump_s_mask = s_mask
|
| 581 |
+
else:
|
| 582 |
+
# decoding stage
|
| 583 |
+
past_s_mask = self.s_mask[layer_idx]
|
| 584 |
+
if past_s_mask is None:
|
| 585 |
+
assert s_mask is None
|
| 586 |
+
cur_s_mask = None
|
| 587 |
+
else:
|
| 588 |
+
assert s_mask is not None
|
| 589 |
+
cur_s_mask = torch.cat([past_s_mask, s_mask], dim=-1)
|
| 590 |
+
|
| 591 |
+
dump_s_mask = cur_s_mask
|
| 592 |
+
if cur_s_mask is not None and cur_s_mask.shape[-1] == window_size:
|
| 593 |
+
cur_s_mask = None
|
| 594 |
+
# store the past window-wise key-value pairs
|
| 595 |
+
self.s_mask[layer_idx] = cur_s_mask
|
| 596 |
+
|
| 597 |
+
############################################
|
| 598 |
+
# compute masks for intra-chunks
|
| 599 |
+
############################################
|
| 600 |
+
dump_rf_mask = None
|
| 601 |
+
if len(self.rf_mask) <= layer_idx:
|
| 602 |
+
# initialize chunk stats
|
| 603 |
+
# prefill stage
|
| 604 |
+
if rf_mask is None:
|
| 605 |
+
cur_rf_mask = None
|
| 606 |
+
else:
|
| 607 |
+
q_len = rf_mask.shape[-2]
|
| 608 |
+
if q_len < window_size:
|
| 609 |
+
dump_rf_mask = None
|
| 610 |
+
cur_rf_mask = rf_mask
|
| 611 |
+
else:
|
| 612 |
+
if q_len % window_size == 0:
|
| 613 |
+
dump_rf_mask = rf_mask
|
| 614 |
+
cur_rf_mask = None
|
| 615 |
+
else:
|
| 616 |
+
remainder_tokens = q_len % window_size
|
| 617 |
+
dump_rf_mask, cur_rf_mask = torch.split(rf_mask, [q_len - remainder_tokens, remainder_tokens], dim=-2)
|
| 618 |
+
self.rf_mask.append(cur_rf_mask)
|
| 619 |
+
else:
|
| 620 |
+
past_rf_mask = self.rf_mask[layer_idx]
|
| 621 |
+
if past_rf_mask is not None:
|
| 622 |
+
# when decoding tokens, we always assume the
|
| 623 |
+
# incoming token mask is 0 (not masked)
|
| 624 |
+
cur_rf_mask = torch.cat([past_rf_mask, rf_mask], dim=-2)
|
| 625 |
+
else:
|
| 626 |
+
cur_rf_mask = None
|
| 627 |
+
|
| 628 |
+
if cur_rf_mask is not None and cur_rf_mask.shape[-2] == window_size:
|
| 629 |
+
dump_rf_mask = cur_rf_mask
|
| 630 |
+
cur_rf_mask = None
|
| 631 |
+
|
| 632 |
+
self.rf_mask[layer_idx] = cur_rf_mask
|
| 633 |
+
|
| 634 |
+
return dump_s_mask, dump_rf_mask
|
| 635 |
+
|
| 636 |
+
def update_singletons_and_chunks(
|
| 637 |
+
self,
|
| 638 |
+
k,
|
| 639 |
+
v,
|
| 640 |
+
layer_idx,
|
| 641 |
+
window_size,
|
| 642 |
+
):
|
| 643 |
+
if len(self.past_window_pos) <= layer_idx:
|
| 644 |
+
# prefill stage
|
| 645 |
+
s_k = k
|
| 646 |
+
s_v = v
|
| 647 |
+
input_len = k.shape[-2]
|
| 648 |
+
window_pos = 0
|
| 649 |
+
if input_len <= window_size:
|
| 650 |
+
new_window_pos = window_pos + input_len
|
| 651 |
+
|
| 652 |
+
cached_window_k = k
|
| 653 |
+
cached_window_v = v
|
| 654 |
+
dump_k = None
|
| 655 |
+
dump_v = None
|
| 656 |
+
else:
|
| 657 |
+
remainder_tokens = input_len % window_size
|
| 658 |
+
if remainder_tokens == 0:
|
| 659 |
+
remainder_tokens = window_size
|
| 660 |
+
new_window_pos = window_pos + remainder_tokens
|
| 661 |
+
|
| 662 |
+
# [b, h, n-r, d] [b, h, r, d]
|
| 663 |
+
cached_window_k = k[..., -remainder_tokens:, :]
|
| 664 |
+
cached_window_v = v[..., -remainder_tokens:, :]
|
| 665 |
+
dump_k = k[..., :-remainder_tokens, :]
|
| 666 |
+
dump_v = v[..., :-remainder_tokens, :]
|
| 667 |
+
# store the past window-wise key-value pairs
|
| 668 |
+
self.past_window_k[layer_idx][:, :, window_pos : new_window_pos, :] = cached_window_k
|
| 669 |
+
self.past_window_v[layer_idx][:, :, window_pos : new_window_pos, :] = cached_window_v
|
| 670 |
+
self.past_window_pos.append(new_window_pos)
|
| 671 |
+
else:
|
| 672 |
+
# decoding stage
|
| 673 |
+
# if the previous cache has full tokens,
|
| 674 |
+
# roll back to the first elements
|
| 675 |
+
if self.past_window_pos[layer_idx] == window_size:
|
| 676 |
+
self.past_window_pos[layer_idx] = 0
|
| 677 |
+
dump_k = self.past_window_k[layer_idx].clone()
|
| 678 |
+
dump_v = self.past_window_v[layer_idx].clone()
|
| 679 |
+
else:
|
| 680 |
+
dump_k = None
|
| 681 |
+
dump_v = None
|
| 682 |
+
|
| 683 |
+
input_len = k.shape[-2]
|
| 684 |
+
window_pos = self.past_window_pos[layer_idx]
|
| 685 |
+
new_window_pos = window_pos + input_len
|
| 686 |
+
|
| 687 |
+
self.past_window_k[layer_idx][:, :, window_pos : new_window_pos, :] = k
|
| 688 |
+
self.past_window_v[layer_idx][:, :, window_pos : new_window_pos, :] = v
|
| 689 |
+
|
| 690 |
+
s_k = self.past_window_k[layer_idx][:, :, : new_window_pos, :]
|
| 691 |
+
s_v = self.past_window_v[layer_idx][:, :, : new_window_pos, :]
|
| 692 |
+
|
| 693 |
+
self.past_window_pos[layer_idx] = new_window_pos
|
| 694 |
+
|
| 695 |
+
return s_k, s_v, dump_k, dump_v
|
| 696 |
+
|
| 697 |
+
def update_chunk_rfas(
|
| 698 |
+
self,
|
| 699 |
+
rfa_k,
|
| 700 |
+
rfa_v,
|
| 701 |
+
layer_idx,
|
| 702 |
+
):
|
| 703 |
+
if len(self.rfa_k) <= layer_idx:
|
| 704 |
+
# prefill stage
|
| 705 |
+
self.rfa_k.append(rfa_k)
|
| 706 |
+
self.rfa_v.append(rfa_v)
|
| 707 |
+
else:
|
| 708 |
+
# token decoding
|
| 709 |
+
past_rfa_k = self.rfa_k[layer_idx]
|
| 710 |
+
past_rfa_v = self.rfa_v[layer_idx]
|
| 711 |
+
|
| 712 |
+
if past_rfa_k is not None:
|
| 713 |
+
rfa_k = torch.cat([past_rfa_k, rfa_k], dim=-2)
|
| 714 |
+
|
| 715 |
+
if past_rfa_v is not None:
|
| 716 |
+
rfa_v = torch.cat([past_rfa_v, rfa_v], dim=-2)
|
| 717 |
+
|
| 718 |
+
self.rfa_k[layer_idx] = rfa_k
|
| 719 |
+
self.rfa_v[layer_idx] = rfa_v
|
| 720 |
+
|
| 721 |
+
return rfa_k, rfa_v
|
| 722 |
+
|
| 723 |
+
def get_past_window_pos(self, layer_idx):
|
| 724 |
+
if len(self.past_window_pos) <= layer_idx:
|
| 725 |
+
return None
|
| 726 |
+
else:
|
| 727 |
+
return self.past_window_pos[layer_idx]
|
| 728 |
+
|
| 729 |
+
def get_past_window_kv(self, layer_idx):
|
| 730 |
+
if len(self.past_window_pos) <= layer_idx:
|
| 731 |
+
return None, None
|
| 732 |
+
else:
|
| 733 |
+
return (
|
| 734 |
+
self.past_window_k[layer_idx][:, :, : self.past_window_pos[layer_idx], :],
|
| 735 |
+
self.past_window_v[layer_idx][:, :, : self.past_window_pos[layer_idx], :]
|
| 736 |
+
)
|
| 737 |
+
|
| 738 |
+
def get_chunk_rfas(self, layer_idx):
|
| 739 |
+
if len(self.rfa_k) <= layer_idx:
|
| 740 |
+
return None, None
|
| 741 |
+
else:
|
| 742 |
+
return self.rfa_k[layer_idx], self.rfa_v[layer_idx]
|
| 743 |
+
|
| 744 |
+
def get_seq_length(self, layer_idx = 0) -> int:
|
| 745 |
+
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
|
| 746 |
+
# layer_idx must be provided since otherwise
|
| 747 |
+
# any layer > 0 can only get the updated _seen_tokens
|
| 748 |
+
if len(self.past_window_pos) <= layer_idx:
|
| 749 |
+
return 0
|
| 750 |
+
return self._seen_tokens
|
| 751 |
+
|
| 752 |
+
def get_max_length(self) -> Optional[int]:
|
| 753 |
+
"""Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
|
| 754 |
+
return None
|
| 755 |
+
|
| 756 |
+
def update(
|
| 757 |
+
self,
|
| 758 |
+
layer_idx: int,
|
| 759 |
+
cache_kwargs: Optional[Dict[str, Any]] = None,
|
| 760 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 761 |
+
raise NotImplementedError("`update` is not used in Eva Cache.")
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/eva_prep_kv_kernel.py
ADDED
|
@@ -0,0 +1,1017 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import math
|
| 3 |
+
import torch
|
| 4 |
+
import triton
|
| 5 |
+
import triton.language as tl
|
| 6 |
+
|
| 7 |
+
@triton.heuristics(
|
| 8 |
+
{
|
| 9 |
+
"EVEN_N": lambda args: args["seqlen"] % args["BLOCK_N"] == 0,
|
| 10 |
+
"EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
|
| 11 |
+
}
|
| 12 |
+
)
|
| 13 |
+
@triton.jit
|
| 14 |
+
def _fwd_eva_prep_kv_kernel(
|
| 15 |
+
K, # [b, h, n, d]
|
| 16 |
+
V, # [b, h, n, d]
|
| 17 |
+
PARAM_MU, # [1, h, 1, 1, d]
|
| 18 |
+
PARAM_PHI, # [1, h, 1, 1, d]
|
| 19 |
+
Mask, # [b, h, n, 1]
|
| 20 |
+
Out_RFA_K, # [b, h, c, d]
|
| 21 |
+
Out_RFA_V, # [b, h, c, d]
|
| 22 |
+
softmax_scale,
|
| 23 |
+
stride_kb, stride_kh, stride_kn,
|
| 24 |
+
stride_vb, stride_vh, stride_vn,
|
| 25 |
+
stride_mu_h,
|
| 26 |
+
stride_phi_h,
|
| 27 |
+
stride_mb, stride_mn,
|
| 28 |
+
stride_ok_b, stride_ok_h, stride_ok_c,
|
| 29 |
+
stride_ov_b, stride_ov_h, stride_ov_c,
|
| 30 |
+
nheads,
|
| 31 |
+
seqlen,
|
| 32 |
+
nchunks,
|
| 33 |
+
headdim,
|
| 34 |
+
CHUNKS_PER_BLOCK: tl.constexpr,
|
| 35 |
+
CHUNK_SIZE: tl.constexpr,
|
| 36 |
+
MASK_TYPE: tl.constexpr,
|
| 37 |
+
BLOCK_HEADDIM: tl.constexpr,
|
| 38 |
+
EVEN_N: tl.constexpr,
|
| 39 |
+
EVEN_HEADDIM: tl.constexpr,
|
| 40 |
+
BLOCK_N: tl.constexpr,
|
| 41 |
+
):
|
| 42 |
+
start_n = tl.program_id(0)
|
| 43 |
+
offs_bh = tl.program_id(1)
|
| 44 |
+
offs_h = offs_bh % nheads
|
| 45 |
+
offs_b = offs_bh // nheads
|
| 46 |
+
# initialize offsets
|
| 47 |
+
# we load BLOCK_N keys and values each time, and
|
| 48 |
+
# reshape it to [CHUNKS_PER_BLOCK, CHUNK_SIZE]
|
| 49 |
+
offs_c = tl.arange(0, CHUNKS_PER_BLOCK)
|
| 50 |
+
offs_m = tl.arange(0, CHUNK_SIZE)
|
| 51 |
+
offs_d = tl.arange(0, BLOCK_HEADDIM)
|
| 52 |
+
|
| 53 |
+
k_ptrs = (
|
| 54 |
+
K +
|
| 55 |
+
offs_b * stride_kb +
|
| 56 |
+
offs_h * stride_kh +
|
| 57 |
+
(
|
| 58 |
+
(
|
| 59 |
+
start_n * BLOCK_N +
|
| 60 |
+
offs_c[:, None, None] * CHUNK_SIZE +
|
| 61 |
+
offs_m[None, :, None]
|
| 62 |
+
) * stride_kn +
|
| 63 |
+
offs_d[None, None, :]
|
| 64 |
+
)
|
| 65 |
+
)
|
| 66 |
+
v_ptrs = (
|
| 67 |
+
V +
|
| 68 |
+
offs_b * stride_vb +
|
| 69 |
+
offs_h * stride_vh +
|
| 70 |
+
(
|
| 71 |
+
(
|
| 72 |
+
start_n * BLOCK_N +
|
| 73 |
+
offs_c[:, None, None] * CHUNK_SIZE +
|
| 74 |
+
offs_m[None, :, None]
|
| 75 |
+
) * stride_vn +
|
| 76 |
+
offs_d[None, None, :]
|
| 77 |
+
)
|
| 78 |
+
)
|
| 79 |
+
param_mu_ptrs = (
|
| 80 |
+
PARAM_MU +
|
| 81 |
+
offs_h * stride_mu_h +
|
| 82 |
+
offs_d[None, None, :]
|
| 83 |
+
)
|
| 84 |
+
param_phi_ptrs = (
|
| 85 |
+
PARAM_PHI +
|
| 86 |
+
offs_h * stride_phi_h +
|
| 87 |
+
offs_d[None, None, :]
|
| 88 |
+
)
|
| 89 |
+
log2e = 1.4426950408889634
|
| 90 |
+
if MASK_TYPE == 1:
|
| 91 |
+
m_ptrs = (
|
| 92 |
+
Mask +
|
| 93 |
+
offs_b * stride_mb +
|
| 94 |
+
(
|
| 95 |
+
(
|
| 96 |
+
start_n * BLOCK_N +
|
| 97 |
+
offs_c[:, None] * CHUNK_SIZE +
|
| 98 |
+
offs_m[None, :]
|
| 99 |
+
) * stride_mn
|
| 100 |
+
)
|
| 101 |
+
)
|
| 102 |
+
if EVEN_N:
|
| 103 |
+
if EVEN_HEADDIM:
|
| 104 |
+
k = tl.load(
|
| 105 |
+
k_ptrs
|
| 106 |
+
)
|
| 107 |
+
else:
|
| 108 |
+
k = tl.load(
|
| 109 |
+
k_ptrs,
|
| 110 |
+
mask=offs_d[None, None, :] < headdim,
|
| 111 |
+
other=0.0
|
| 112 |
+
)
|
| 113 |
+
else:
|
| 114 |
+
if EVEN_HEADDIM:
|
| 115 |
+
k = tl.load(
|
| 116 |
+
k_ptrs,
|
| 117 |
+
mask=(
|
| 118 |
+
start_n * BLOCK_N +
|
| 119 |
+
offs_c[:, None, None] * CHUNK_SIZE +
|
| 120 |
+
offs_m[None, :, None]
|
| 121 |
+
) < seqlen,
|
| 122 |
+
other=0.0
|
| 123 |
+
)
|
| 124 |
+
else:
|
| 125 |
+
k = tl.load(
|
| 126 |
+
k_ptrs,
|
| 127 |
+
mask=(
|
| 128 |
+
(
|
| 129 |
+
start_n * BLOCK_N +
|
| 130 |
+
offs_c[:, None, None] * CHUNK_SIZE +
|
| 131 |
+
offs_m[None, :, None]
|
| 132 |
+
) < seqlen
|
| 133 |
+
) & (offs_d[None, None, :] < headdim),
|
| 134 |
+
other=0.0
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
param_mu = tl.load(param_mu_ptrs).to(k.dtype)
|
| 138 |
+
rfa_k_c_w = tl.zeros([CHUNKS_PER_BLOCK, CHUNK_SIZE], dtype=tl.float32)
|
| 139 |
+
rfa_k_c_w += tl.sum(k * param_mu, axis=-1)
|
| 140 |
+
rfa_k_c_w *= log2e
|
| 141 |
+
if MASK_TYPE == 1:
|
| 142 |
+
if EVEN_N:
|
| 143 |
+
mask = tl.load(
|
| 144 |
+
m_ptrs
|
| 145 |
+
)
|
| 146 |
+
else:
|
| 147 |
+
mask = tl.load(
|
| 148 |
+
m_ptrs,
|
| 149 |
+
mask=(
|
| 150 |
+
start_n * BLOCK_N +
|
| 151 |
+
offs_c[:, None] * CHUNK_SIZE +
|
| 152 |
+
offs_m[None, :]
|
| 153 |
+
) < seqlen,
|
| 154 |
+
other=1,
|
| 155 |
+
)
|
| 156 |
+
rfa_k_c_w = tl.where(mask, float("-inf"), rfa_k_c_w)
|
| 157 |
+
|
| 158 |
+
m_rfa_k_c_w = tl.max(rfa_k_c_w, axis=-1)
|
| 159 |
+
masked_out_rows_rfa_k = (m_rfa_k_c_w == float("-inf"))
|
| 160 |
+
m_rfa_k_c_w_masked = tl.where(masked_out_rows_rfa_k, 0, m_rfa_k_c_w)
|
| 161 |
+
rfa_k_c_w = tl.exp2(rfa_k_c_w - m_rfa_k_c_w_masked[:, None])
|
| 162 |
+
denom_k = tl.sum(rfa_k_c_w, axis=-1)
|
| 163 |
+
denom_k = tl.where(denom_k == 0.0, 1.0, denom_k)
|
| 164 |
+
rfa_k_c_w = rfa_k_c_w / denom_k[:, None]
|
| 165 |
+
rfa_k_c = tl.sum(k * rfa_k_c_w[:, :, None].to(k.dtype), axis=-2)
|
| 166 |
+
# TODO: understand why rematerialize offsets to save registers?
|
| 167 |
+
offs_out_c = start_n * CHUNKS_PER_BLOCK + tl.arange(0, CHUNKS_PER_BLOCK)
|
| 168 |
+
out_rfa_k_ptrs = (
|
| 169 |
+
Out_RFA_K +
|
| 170 |
+
offs_b * stride_ok_b +
|
| 171 |
+
offs_h * stride_ok_h +
|
| 172 |
+
(offs_out_c[:, None] * stride_ok_c + offs_d[None, :])
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
if EVEN_N:
|
| 176 |
+
if EVEN_HEADDIM:
|
| 177 |
+
tl.store(
|
| 178 |
+
out_rfa_k_ptrs, rfa_k_c
|
| 179 |
+
)
|
| 180 |
+
else:
|
| 181 |
+
tl.store(
|
| 182 |
+
out_rfa_k_ptrs, rfa_k_c,
|
| 183 |
+
mask=offs_d[None, :] < headdim
|
| 184 |
+
)
|
| 185 |
+
else:
|
| 186 |
+
if EVEN_HEADDIM:
|
| 187 |
+
tl.store(
|
| 188 |
+
out_rfa_k_ptrs, rfa_k_c,
|
| 189 |
+
mask=offs_out_c[:, None] < nchunks
|
| 190 |
+
)
|
| 191 |
+
else:
|
| 192 |
+
tl.store(
|
| 193 |
+
out_rfa_k_ptrs, rfa_k_c,
|
| 194 |
+
mask=(offs_out_c[:, None] < nchunks) & (offs_d[None, :] < headdim)
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
param_phi = tl.load(param_phi_ptrs).to(k.dtype)
|
| 199 |
+
rfa_v_c_w = tl.zeros([CHUNKS_PER_BLOCK, CHUNK_SIZE], dtype=tl.float32)
|
| 200 |
+
rfa_v_c_w += tl.sum(k * param_phi, axis=-1)
|
| 201 |
+
rfa_v_c_w -= (0.5 * tl.sum(k * k, axis=-1))
|
| 202 |
+
rfa_v_c_w *= log2e * softmax_scale
|
| 203 |
+
if not EVEN_N: # Need to mask out otherwise the softmax is wrong
|
| 204 |
+
rfa_v_c_w += tl.where(
|
| 205 |
+
(
|
| 206 |
+
start_n * BLOCK_N +
|
| 207 |
+
offs_c[:, None] * CHUNK_SIZE +
|
| 208 |
+
offs_m[None, :]
|
| 209 |
+
) < seqlen,
|
| 210 |
+
0,
|
| 211 |
+
float("-inf")
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
if MASK_TYPE == 1:
|
| 215 |
+
rfa_v_c_w = tl.where(mask, float("-inf"), rfa_v_c_w)
|
| 216 |
+
|
| 217 |
+
if EVEN_N:
|
| 218 |
+
if EVEN_HEADDIM:
|
| 219 |
+
v = tl.load(
|
| 220 |
+
v_ptrs
|
| 221 |
+
)
|
| 222 |
+
else:
|
| 223 |
+
v = tl.load(
|
| 224 |
+
v_ptrs,
|
| 225 |
+
mask=offs_d[None, None, :] < headdim,
|
| 226 |
+
other=0.0
|
| 227 |
+
)
|
| 228 |
+
else:
|
| 229 |
+
if EVEN_HEADDIM:
|
| 230 |
+
v = tl.load(
|
| 231 |
+
v_ptrs,
|
| 232 |
+
mask=(
|
| 233 |
+
start_n * BLOCK_N +
|
| 234 |
+
offs_c[:, None, None] * CHUNK_SIZE +
|
| 235 |
+
offs_m[None, :, None]
|
| 236 |
+
) < seqlen,
|
| 237 |
+
other=0.0
|
| 238 |
+
)
|
| 239 |
+
else:
|
| 240 |
+
v = tl.load(
|
| 241 |
+
v_ptrs,
|
| 242 |
+
mask=(
|
| 243 |
+
(
|
| 244 |
+
start_n * BLOCK_N +
|
| 245 |
+
offs_c[:, None, None] * CHUNK_SIZE +
|
| 246 |
+
offs_m[None, :, None]
|
| 247 |
+
) < seqlen
|
| 248 |
+
) & (offs_d[None, None, :] < headdim),
|
| 249 |
+
other=0.0
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
m_rfa_v_c_w = tl.max(rfa_v_c_w, axis=-1)
|
| 254 |
+
masked_out_rows_rfa_v = (m_rfa_v_c_w == float("-inf"))
|
| 255 |
+
m_rfa_v_c_w_masked = tl.where(masked_out_rows_rfa_v, 0, m_rfa_v_c_w)
|
| 256 |
+
rfa_v_c_w = tl.exp2(rfa_v_c_w - m_rfa_v_c_w_masked[:, None])
|
| 257 |
+
denom_v = tl.sum(rfa_v_c_w, axis=-1)
|
| 258 |
+
denom_v = tl.where(denom_v == 0.0, 1.0, denom_v)
|
| 259 |
+
rfa_v_c_w = rfa_v_c_w / denom_v[:, None]
|
| 260 |
+
rfa_v_c = tl.sum(v * rfa_v_c_w[:, :, None].to(v.dtype), axis=-2)
|
| 261 |
+
|
| 262 |
+
offs_out_c = start_n * CHUNKS_PER_BLOCK + tl.arange(0, CHUNKS_PER_BLOCK)
|
| 263 |
+
out_rfa_v_ptrs = (
|
| 264 |
+
Out_RFA_V +
|
| 265 |
+
offs_b * stride_ov_b +
|
| 266 |
+
offs_h * stride_ov_h +
|
| 267 |
+
(offs_out_c[:, None] * stride_ov_c + offs_d[None, :])
|
| 268 |
+
)
|
| 269 |
+
if EVEN_N:
|
| 270 |
+
if EVEN_HEADDIM:
|
| 271 |
+
tl.store(
|
| 272 |
+
out_rfa_v_ptrs, rfa_v_c
|
| 273 |
+
)
|
| 274 |
+
else:
|
| 275 |
+
tl.store(
|
| 276 |
+
out_rfa_v_ptrs, rfa_v_c,
|
| 277 |
+
mask=offs_d[None, :] < headdim
|
| 278 |
+
)
|
| 279 |
+
else:
|
| 280 |
+
if EVEN_HEADDIM:
|
| 281 |
+
tl.store(
|
| 282 |
+
out_rfa_v_ptrs, rfa_v_c,
|
| 283 |
+
mask=offs_out_c[:, None] < nchunks
|
| 284 |
+
)
|
| 285 |
+
else:
|
| 286 |
+
tl.store(
|
| 287 |
+
out_rfa_v_ptrs, rfa_v_c,
|
| 288 |
+
mask=(offs_out_c[:, None] < nchunks) & (offs_d[None, :] < headdim)
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
@triton.heuristics(
|
| 294 |
+
{
|
| 295 |
+
"EVEN_N": lambda args: args["seqlen"] % args["BLOCK_N"] == 0,
|
| 296 |
+
"EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
|
| 297 |
+
}
|
| 298 |
+
)
|
| 299 |
+
@triton.jit
|
| 300 |
+
def _bwd_eva_prep_kv_kernel(
|
| 301 |
+
RFA_K, # [b, h, c, d]
|
| 302 |
+
RFA_V, # [b, h, c, d]
|
| 303 |
+
K, # [b, h, n, d]
|
| 304 |
+
V, # [b, h, n, d]
|
| 305 |
+
PARAM_MU, # [1, h, 1, 1, d]
|
| 306 |
+
PARAM_PHI, # [1, h, 1, 1, d]
|
| 307 |
+
Mask, # [b, h, n, 1]
|
| 308 |
+
D_RFA_K, # [b, h, c, d]
|
| 309 |
+
D_RFA_V, # [b, h, c, d]
|
| 310 |
+
D_K, # [b, h, n, d]
|
| 311 |
+
D_V, # [b, h, n, d]
|
| 312 |
+
D_PARAM_MU_PARTIAL, # [b, h, g, d]
|
| 313 |
+
D_PARAM_PHI_PARTIAL, # [b, h, g, d]
|
| 314 |
+
softmax_scale,
|
| 315 |
+
stride_rfa_k_b, stride_rfa_k_h, stride_rfa_k_c,
|
| 316 |
+
stride_rfa_v_b, stride_rfa_v_h, stride_rfa_v_c,
|
| 317 |
+
stride_kb, stride_kh, stride_kn,
|
| 318 |
+
stride_vb, stride_vh, stride_vn,
|
| 319 |
+
stride_mu_h,
|
| 320 |
+
stride_phi_h,
|
| 321 |
+
stride_mb, stride_mn,
|
| 322 |
+
stride_d_rfa_k_b, stride_d_rfa_k_h, stride_d_rfa_k_c,
|
| 323 |
+
stride_d_rfa_v_b, stride_d_rfa_v_h, stride_d_rfa_v_c,
|
| 324 |
+
stride_d_k_b, stride_d_k_h, stride_d_k_n,
|
| 325 |
+
stride_d_v_b, stride_d_v_h, stride_d_v_n,
|
| 326 |
+
stride_d_mu_b, stride_d_mu_h, stride_d_mu_g,
|
| 327 |
+
stride_d_phi_b, stride_d_phi_h, stride_d_phi_g,
|
| 328 |
+
nheads,
|
| 329 |
+
seqlen,
|
| 330 |
+
nchunks,
|
| 331 |
+
headdim,
|
| 332 |
+
CHUNKS_PER_BLOCK: tl.constexpr,
|
| 333 |
+
CHUNK_SIZE: tl.constexpr,
|
| 334 |
+
MASK_TYPE: tl.constexpr,
|
| 335 |
+
BLOCK_HEADDIM: tl.constexpr,
|
| 336 |
+
EVEN_N: tl.constexpr,
|
| 337 |
+
EVEN_HEADDIM: tl.constexpr,
|
| 338 |
+
BLOCK_N: tl.constexpr,
|
| 339 |
+
):
|
| 340 |
+
start_n = tl.program_id(0)
|
| 341 |
+
offs_bh = tl.program_id(1)
|
| 342 |
+
offs_h = offs_bh % nheads
|
| 343 |
+
offs_b = offs_bh // nheads
|
| 344 |
+
# initialize offsets
|
| 345 |
+
# we load BLOCK_N keys and values each time, and
|
| 346 |
+
# reshape it to [CHUNKS_PER_BLOCK, CHUNK_SIZE]
|
| 347 |
+
offs_c = tl.arange(0, CHUNKS_PER_BLOCK)
|
| 348 |
+
offs_m = tl.arange(0, CHUNK_SIZE)
|
| 349 |
+
offs_d = tl.arange(0, BLOCK_HEADDIM)
|
| 350 |
+
|
| 351 |
+
offs_rfa_c = start_n * CHUNKS_PER_BLOCK + offs_c
|
| 352 |
+
|
| 353 |
+
k_ptrs = (
|
| 354 |
+
K +
|
| 355 |
+
offs_b * stride_kb +
|
| 356 |
+
offs_h * stride_kh +
|
| 357 |
+
(
|
| 358 |
+
(
|
| 359 |
+
start_n * BLOCK_N +
|
| 360 |
+
offs_c[:, None, None] * CHUNK_SIZE +
|
| 361 |
+
offs_m[None, :, None]
|
| 362 |
+
) * stride_kn +
|
| 363 |
+
offs_d[None, None, :]
|
| 364 |
+
)
|
| 365 |
+
)
|
| 366 |
+
rfa_k_ptrs = (
|
| 367 |
+
RFA_K +
|
| 368 |
+
offs_b * stride_rfa_k_b +
|
| 369 |
+
offs_h * stride_rfa_k_h +
|
| 370 |
+
(offs_rfa_c[:, None] * stride_rfa_k_c + offs_d[None, :])
|
| 371 |
+
)
|
| 372 |
+
rfa_v_ptrs = (
|
| 373 |
+
RFA_V +
|
| 374 |
+
offs_b * stride_rfa_v_b +
|
| 375 |
+
offs_h * stride_rfa_v_h +
|
| 376 |
+
(offs_rfa_c[:, None] * stride_rfa_v_c + offs_d[None, :])
|
| 377 |
+
)
|
| 378 |
+
|
| 379 |
+
d_rfa_k_ptrs = (
|
| 380 |
+
D_RFA_K +
|
| 381 |
+
offs_b * stride_d_rfa_k_b +
|
| 382 |
+
offs_h * stride_d_rfa_k_h +
|
| 383 |
+
(offs_rfa_c[:, None] * stride_d_rfa_k_c + offs_d[None, :])
|
| 384 |
+
)
|
| 385 |
+
d_rfa_v_ptrs = (
|
| 386 |
+
D_RFA_V +
|
| 387 |
+
offs_b * stride_d_rfa_v_b +
|
| 388 |
+
offs_h * stride_d_rfa_v_h +
|
| 389 |
+
(offs_rfa_c[:, None] * stride_d_rfa_v_c + offs_d[None, :])
|
| 390 |
+
)
|
| 391 |
+
|
| 392 |
+
param_mu_ptrs = (
|
| 393 |
+
PARAM_MU +
|
| 394 |
+
offs_h * stride_mu_h +
|
| 395 |
+
offs_d[None, None, :]
|
| 396 |
+
)
|
| 397 |
+
param_phi_ptrs = (
|
| 398 |
+
PARAM_PHI +
|
| 399 |
+
offs_h * stride_phi_h +
|
| 400 |
+
offs_d[None, None, :]
|
| 401 |
+
)
|
| 402 |
+
|
| 403 |
+
log2e = 1.4426950408889634
|
| 404 |
+
if MASK_TYPE == 1:
|
| 405 |
+
m_ptrs = (
|
| 406 |
+
Mask +
|
| 407 |
+
offs_b * stride_mb +
|
| 408 |
+
(
|
| 409 |
+
(
|
| 410 |
+
start_n * BLOCK_N +
|
| 411 |
+
offs_c[:, None] * CHUNK_SIZE +
|
| 412 |
+
offs_m[None, :]
|
| 413 |
+
) * stride_mn
|
| 414 |
+
)
|
| 415 |
+
)
|
| 416 |
+
if EVEN_N:
|
| 417 |
+
if EVEN_HEADDIM:
|
| 418 |
+
k = tl.load(
|
| 419 |
+
k_ptrs
|
| 420 |
+
)
|
| 421 |
+
else:
|
| 422 |
+
k = tl.load(
|
| 423 |
+
k_ptrs,
|
| 424 |
+
mask=offs_d[None, None, :] < headdim,
|
| 425 |
+
other=0.0
|
| 426 |
+
)
|
| 427 |
+
else:
|
| 428 |
+
if EVEN_HEADDIM:
|
| 429 |
+
k = tl.load(
|
| 430 |
+
k_ptrs,
|
| 431 |
+
mask=(
|
| 432 |
+
start_n * BLOCK_N +
|
| 433 |
+
offs_c[:, None, None] * CHUNK_SIZE +
|
| 434 |
+
offs_m[None, :, None]
|
| 435 |
+
) < seqlen,
|
| 436 |
+
other=0.0
|
| 437 |
+
)
|
| 438 |
+
else:
|
| 439 |
+
k = tl.load(
|
| 440 |
+
k_ptrs,
|
| 441 |
+
mask=(
|
| 442 |
+
(
|
| 443 |
+
start_n * BLOCK_N +
|
| 444 |
+
offs_c[:, None, None] * CHUNK_SIZE +
|
| 445 |
+
offs_m[None, :, None]
|
| 446 |
+
) < seqlen
|
| 447 |
+
) & (offs_d[None, None, :] < headdim),
|
| 448 |
+
other=0.0
|
| 449 |
+
)
|
| 450 |
+
|
| 451 |
+
if EVEN_N:
|
| 452 |
+
if EVEN_HEADDIM:
|
| 453 |
+
rfa_k = tl.load(
|
| 454 |
+
rfa_k_ptrs
|
| 455 |
+
)
|
| 456 |
+
else:
|
| 457 |
+
rfa_k = tl.load(
|
| 458 |
+
rfa_k_ptrs,
|
| 459 |
+
mask=offs_d[None, :] < headdim,
|
| 460 |
+
other=0.0
|
| 461 |
+
)
|
| 462 |
+
else:
|
| 463 |
+
if EVEN_HEADDIM:
|
| 464 |
+
rfa_k = tl.load(
|
| 465 |
+
rfa_k_ptrs,
|
| 466 |
+
mask=offs_rfa_c[:, None] < nchunks,
|
| 467 |
+
other=0.0
|
| 468 |
+
)
|
| 469 |
+
else:
|
| 470 |
+
rfa_k = tl.load(
|
| 471 |
+
rfa_k_ptrs,
|
| 472 |
+
mask=(offs_rfa_c[:, None] < nchunks) & (offs_d[None, :] < headdim),
|
| 473 |
+
other=0.0
|
| 474 |
+
)
|
| 475 |
+
|
| 476 |
+
if EVEN_N:
|
| 477 |
+
if EVEN_HEADDIM:
|
| 478 |
+
d_rfa_k = tl.load(
|
| 479 |
+
d_rfa_k_ptrs
|
| 480 |
+
)
|
| 481 |
+
else:
|
| 482 |
+
d_rfa_k = tl.load(
|
| 483 |
+
d_rfa_k_ptrs,
|
| 484 |
+
mask=offs_d[None, :] < headdim,
|
| 485 |
+
other=0.0
|
| 486 |
+
)
|
| 487 |
+
else:
|
| 488 |
+
if EVEN_HEADDIM:
|
| 489 |
+
d_rfa_k = tl.load(
|
| 490 |
+
d_rfa_k_ptrs,
|
| 491 |
+
mask=offs_rfa_c[:, None] < nchunks,
|
| 492 |
+
other=0.0
|
| 493 |
+
)
|
| 494 |
+
else:
|
| 495 |
+
d_rfa_k = tl.load(
|
| 496 |
+
d_rfa_k_ptrs,
|
| 497 |
+
mask=(offs_rfa_c[:, None] < nchunks) & (offs_d[None, :] < headdim),
|
| 498 |
+
other=0.0
|
| 499 |
+
)
|
| 500 |
+
|
| 501 |
+
param_mu = tl.load(param_mu_ptrs).to(k.dtype)
|
| 502 |
+
mu_c_w = tl.zeros([CHUNKS_PER_BLOCK, CHUNK_SIZE], dtype=tl.float32)
|
| 503 |
+
mu_c_w += tl.sum(k * param_mu, axis=-1)
|
| 504 |
+
mu_c_w *= log2e
|
| 505 |
+
|
| 506 |
+
if not EVEN_N: # Need to mask out otherwise the softmax is wrong
|
| 507 |
+
mu_c_w += tl.where(
|
| 508 |
+
(
|
| 509 |
+
start_n * BLOCK_N +
|
| 510 |
+
offs_c[:, None] * CHUNK_SIZE +
|
| 511 |
+
offs_m[None, :]
|
| 512 |
+
) < seqlen,
|
| 513 |
+
0,
|
| 514 |
+
float("-inf")
|
| 515 |
+
)
|
| 516 |
+
|
| 517 |
+
if MASK_TYPE == 1:
|
| 518 |
+
if EVEN_N:
|
| 519 |
+
mask = tl.load(
|
| 520 |
+
m_ptrs
|
| 521 |
+
)
|
| 522 |
+
else:
|
| 523 |
+
mask = tl.load(
|
| 524 |
+
m_ptrs,
|
| 525 |
+
mask=(
|
| 526 |
+
start_n * BLOCK_N +
|
| 527 |
+
offs_c[:, None] * CHUNK_SIZE +
|
| 528 |
+
offs_m[None, :]
|
| 529 |
+
) < seqlen,
|
| 530 |
+
other=1,
|
| 531 |
+
)
|
| 532 |
+
mu_c_w = tl.where(mask, float("-inf"), mu_c_w)
|
| 533 |
+
|
| 534 |
+
# [c, w]
|
| 535 |
+
m_mu_c_w = tl.max(mu_c_w, axis=-1)
|
| 536 |
+
masked_out_rows_mu = (m_mu_c_w == float("-inf"))
|
| 537 |
+
m_mu_c_w_masked = tl.where(masked_out_rows_mu, 0, m_mu_c_w)
|
| 538 |
+
mu_c_w = tl.exp2(mu_c_w - m_mu_c_w_masked[:, None])
|
| 539 |
+
denom_mu = tl.sum(mu_c_w, axis=-1)
|
| 540 |
+
denom_mu = tl.where(denom_mu == 0.0, 1.0, denom_mu)
|
| 541 |
+
mu_tilde_c_w = mu_c_w / denom_mu[:, None]
|
| 542 |
+
mu_tilde_c_w = mu_tilde_c_w.to(k.dtype)
|
| 543 |
+
# [c, d] [c, w, d] -> [c, w]
|
| 544 |
+
d_mu_tilde_c_w = tl.sum(d_rfa_k[:, None, :] * k, axis=-1)
|
| 545 |
+
# [c, d] [c, d] -> [c]
|
| 546 |
+
d_out_rfa_k_t_rfa_k = tl.sum(d_rfa_k * rfa_k, axis=-1)[:, None]
|
| 547 |
+
d_mu_c_w = (d_mu_tilde_c_w - d_out_rfa_k_t_rfa_k) * mu_tilde_c_w
|
| 548 |
+
|
| 549 |
+
# [c, w] [c, w, d] -> [d]
|
| 550 |
+
d_param_mu = tl.sum(tl.sum(d_mu_c_w[:, :, None] * k, axis=0), axis=0)
|
| 551 |
+
# [c, w] [c, d] + [c, w] [1, 1, d] -> [c, w, d]
|
| 552 |
+
d_k = mu_tilde_c_w[:, :, None] * d_rfa_k[:, None, :] + d_mu_c_w[:, :, None] * param_mu
|
| 553 |
+
|
| 554 |
+
d_param_mu_partial_ptrs = (
|
| 555 |
+
D_PARAM_MU_PARTIAL +
|
| 556 |
+
offs_b * stride_d_mu_b +
|
| 557 |
+
offs_h * stride_d_mu_h +
|
| 558 |
+
start_n * stride_d_mu_g +
|
| 559 |
+
offs_d
|
| 560 |
+
)
|
| 561 |
+
if EVEN_HEADDIM:
|
| 562 |
+
tl.store(
|
| 563 |
+
d_param_mu_partial_ptrs, d_param_mu
|
| 564 |
+
)
|
| 565 |
+
else:
|
| 566 |
+
tl.store(
|
| 567 |
+
d_param_mu_partial_ptrs, d_param_mu,
|
| 568 |
+
mask=offs_d < headdim
|
| 569 |
+
)
|
| 570 |
+
|
| 571 |
+
|
| 572 |
+
v_ptrs = (
|
| 573 |
+
V +
|
| 574 |
+
offs_b * stride_vb +
|
| 575 |
+
offs_h * stride_vh +
|
| 576 |
+
(
|
| 577 |
+
(
|
| 578 |
+
start_n * BLOCK_N +
|
| 579 |
+
offs_c[:, None, None] * CHUNK_SIZE +
|
| 580 |
+
offs_m[None, :, None]
|
| 581 |
+
) * stride_vn +
|
| 582 |
+
offs_d[None, None, :]
|
| 583 |
+
)
|
| 584 |
+
)
|
| 585 |
+
if EVEN_N:
|
| 586 |
+
if EVEN_HEADDIM:
|
| 587 |
+
v = tl.load(
|
| 588 |
+
v_ptrs
|
| 589 |
+
)
|
| 590 |
+
else:
|
| 591 |
+
v = tl.load(
|
| 592 |
+
v_ptrs,
|
| 593 |
+
mask=offs_d[None, None, :] < headdim,
|
| 594 |
+
other=0.0
|
| 595 |
+
)
|
| 596 |
+
else:
|
| 597 |
+
if EVEN_HEADDIM:
|
| 598 |
+
v = tl.load(
|
| 599 |
+
v_ptrs,
|
| 600 |
+
mask=(
|
| 601 |
+
start_n * BLOCK_N +
|
| 602 |
+
offs_c[:, None, None] * CHUNK_SIZE +
|
| 603 |
+
offs_m[None, :, None]
|
| 604 |
+
) < seqlen,
|
| 605 |
+
other=0.0
|
| 606 |
+
)
|
| 607 |
+
else:
|
| 608 |
+
v = tl.load(
|
| 609 |
+
v_ptrs,
|
| 610 |
+
mask=(
|
| 611 |
+
(
|
| 612 |
+
start_n * BLOCK_N +
|
| 613 |
+
offs_c[:, None, None] * CHUNK_SIZE +
|
| 614 |
+
offs_m[None, :, None]
|
| 615 |
+
) < seqlen
|
| 616 |
+
) & (offs_d[None, None, :] < headdim),
|
| 617 |
+
other=0.0
|
| 618 |
+
)
|
| 619 |
+
|
| 620 |
+
|
| 621 |
+
if EVEN_N:
|
| 622 |
+
if EVEN_HEADDIM:
|
| 623 |
+
rfa_v = tl.load(
|
| 624 |
+
rfa_v_ptrs
|
| 625 |
+
)
|
| 626 |
+
else:
|
| 627 |
+
rfa_v = tl.load(
|
| 628 |
+
rfa_v_ptrs,
|
| 629 |
+
mask=offs_d[None, :] < headdim,
|
| 630 |
+
other=0.0
|
| 631 |
+
)
|
| 632 |
+
else:
|
| 633 |
+
if EVEN_HEADDIM:
|
| 634 |
+
rfa_v = tl.load(
|
| 635 |
+
rfa_v_ptrs,
|
| 636 |
+
mask=offs_rfa_c[:, None] < nchunks,
|
| 637 |
+
other=0.0
|
| 638 |
+
)
|
| 639 |
+
else:
|
| 640 |
+
rfa_v = tl.load(
|
| 641 |
+
rfa_v_ptrs,
|
| 642 |
+
mask=(offs_rfa_c[:, None] < nchunks) & (offs_d[None, :] < headdim),
|
| 643 |
+
other=0.0
|
| 644 |
+
)
|
| 645 |
+
|
| 646 |
+
if EVEN_N:
|
| 647 |
+
if EVEN_HEADDIM:
|
| 648 |
+
d_rfa_v = tl.load(
|
| 649 |
+
d_rfa_v_ptrs
|
| 650 |
+
)
|
| 651 |
+
else:
|
| 652 |
+
d_rfa_v = tl.load(
|
| 653 |
+
d_rfa_v_ptrs,
|
| 654 |
+
mask=offs_d[None, :] < headdim,
|
| 655 |
+
other=0.0
|
| 656 |
+
)
|
| 657 |
+
else:
|
| 658 |
+
if EVEN_HEADDIM:
|
| 659 |
+
d_rfa_v = tl.load(
|
| 660 |
+
d_rfa_v_ptrs,
|
| 661 |
+
mask=offs_rfa_c[:, None] < nchunks,
|
| 662 |
+
other=0.0
|
| 663 |
+
)
|
| 664 |
+
else:
|
| 665 |
+
d_rfa_v = tl.load(
|
| 666 |
+
d_rfa_v_ptrs,
|
| 667 |
+
mask=(offs_rfa_c[:, None] < nchunks) & (offs_d[None, :] < headdim),
|
| 668 |
+
other=0.0
|
| 669 |
+
)
|
| 670 |
+
|
| 671 |
+
param_phi = tl.load(param_phi_ptrs).to(k.dtype)
|
| 672 |
+
phi_c_w = tl.zeros([CHUNKS_PER_BLOCK, CHUNK_SIZE], dtype=tl.float32)
|
| 673 |
+
phi_c_w += tl.sum(k * param_phi, axis=-1)
|
| 674 |
+
phi_c_w -= (0.5 * tl.sum(k * k, axis=-1))
|
| 675 |
+
phi_c_w *= log2e * softmax_scale
|
| 676 |
+
if not EVEN_N: # Need to mask out otherwise the softmax is wrong
|
| 677 |
+
phi_c_w += tl.where(
|
| 678 |
+
(
|
| 679 |
+
start_n * BLOCK_N +
|
| 680 |
+
offs_c[:, None] * CHUNK_SIZE +
|
| 681 |
+
offs_m[None, :]
|
| 682 |
+
) < seqlen,
|
| 683 |
+
0,
|
| 684 |
+
float("-inf")
|
| 685 |
+
)
|
| 686 |
+
|
| 687 |
+
if MASK_TYPE == 1:
|
| 688 |
+
phi_c_w = tl.where(mask, float("-inf"), phi_c_w)
|
| 689 |
+
|
| 690 |
+
|
| 691 |
+
m_phi_c_w = tl.max(phi_c_w, axis=-1)
|
| 692 |
+
masked_out_rows_phi = (m_phi_c_w == float("-inf"))
|
| 693 |
+
m_phi_c_w_masked = tl.where(masked_out_rows_phi, 0, m_phi_c_w)
|
| 694 |
+
phi_c_w = tl.exp2(phi_c_w - m_phi_c_w_masked[:, None])
|
| 695 |
+
denom_phi = tl.sum(phi_c_w, axis=-1)
|
| 696 |
+
denom_phi = tl.where(denom_phi == 0.0, 1.0, denom_phi)
|
| 697 |
+
phi_tilde_c_w = phi_c_w / denom_phi[:, None]
|
| 698 |
+
# phi_c_w = tl.exp2(phi_c_w - tl.max(phi_c_w, axis=-1)[:, None])
|
| 699 |
+
# phi_tilde_c_w = phi_c_w / tl.sum(phi_c_w, axis=-1)[:, None]
|
| 700 |
+
phi_tilde_c_w = phi_tilde_c_w.to(k.dtype)
|
| 701 |
+
d_phi_tilde_c_w = tl.sum(d_rfa_v[:, None, :] * v, axis=-1)
|
| 702 |
+
d_out_rfa_v_t_rfa_v = tl.sum(d_rfa_v * rfa_v, axis=-1)[:, None]
|
| 703 |
+
d_phi_c_w = (d_phi_tilde_c_w.to(tl.float32) - d_out_rfa_v_t_rfa_v.to(tl.float32)) * phi_tilde_c_w
|
| 704 |
+
|
| 705 |
+
d_param_phi = tl.sum(tl.sum(d_phi_c_w[:, :, None] * k * softmax_scale, axis=0), axis=0)
|
| 706 |
+
d_v = phi_tilde_c_w[:, :, None] * d_rfa_v[:, None, :]
|
| 707 |
+
# [c, w, d] + [c, w] * [1, 1, d] - [c, w, d]
|
| 708 |
+
d_k = d_k + softmax_scale * d_phi_c_w[:, :, None] * (param_phi - k)
|
| 709 |
+
|
| 710 |
+
d_k_ptrs = (
|
| 711 |
+
D_K +
|
| 712 |
+
offs_b * stride_d_k_b +
|
| 713 |
+
offs_h * stride_d_k_h +
|
| 714 |
+
(
|
| 715 |
+
(
|
| 716 |
+
start_n * BLOCK_N +
|
| 717 |
+
offs_c[:, None, None] * CHUNK_SIZE +
|
| 718 |
+
offs_m[None, :, None]
|
| 719 |
+
) * stride_d_k_n +
|
| 720 |
+
offs_d[None, None, :]
|
| 721 |
+
)
|
| 722 |
+
)
|
| 723 |
+
d_v_ptrs = (
|
| 724 |
+
D_V +
|
| 725 |
+
offs_b * stride_d_v_b +
|
| 726 |
+
offs_h * stride_d_v_h +
|
| 727 |
+
(
|
| 728 |
+
(
|
| 729 |
+
start_n * BLOCK_N +
|
| 730 |
+
offs_c[:, None, None] * CHUNK_SIZE +
|
| 731 |
+
offs_m[None, :, None]
|
| 732 |
+
) * stride_d_v_n +
|
| 733 |
+
offs_d[None, None, :]
|
| 734 |
+
)
|
| 735 |
+
)
|
| 736 |
+
if EVEN_N:
|
| 737 |
+
if EVEN_HEADDIM:
|
| 738 |
+
tl.store(
|
| 739 |
+
d_k_ptrs, d_k
|
| 740 |
+
)
|
| 741 |
+
tl.store(
|
| 742 |
+
d_v_ptrs, d_v
|
| 743 |
+
)
|
| 744 |
+
else:
|
| 745 |
+
tl.store(
|
| 746 |
+
d_k_ptrs, d_k,
|
| 747 |
+
mask=offs_d[None, None, :] < headdim
|
| 748 |
+
)
|
| 749 |
+
tl.store(
|
| 750 |
+
d_v_ptrs, d_v,
|
| 751 |
+
mask=offs_d[None, None, :] < headdim
|
| 752 |
+
)
|
| 753 |
+
else:
|
| 754 |
+
if EVEN_HEADDIM:
|
| 755 |
+
tl.store(
|
| 756 |
+
d_k_ptrs, d_k,
|
| 757 |
+
mask=(
|
| 758 |
+
(
|
| 759 |
+
start_n * BLOCK_N +
|
| 760 |
+
offs_c[:, None, None] * CHUNK_SIZE +
|
| 761 |
+
offs_m[None, :, None]
|
| 762 |
+
) < seqlen
|
| 763 |
+
),
|
| 764 |
+
)
|
| 765 |
+
tl.store(
|
| 766 |
+
d_v_ptrs, d_v,
|
| 767 |
+
mask=(
|
| 768 |
+
(
|
| 769 |
+
start_n * BLOCK_N +
|
| 770 |
+
offs_c[:, None, None] * CHUNK_SIZE +
|
| 771 |
+
offs_m[None, :, None]
|
| 772 |
+
) < seqlen
|
| 773 |
+
),
|
| 774 |
+
)
|
| 775 |
+
else:
|
| 776 |
+
tl.store(
|
| 777 |
+
d_k_ptrs, d_k,
|
| 778 |
+
mask=(
|
| 779 |
+
(
|
| 780 |
+
start_n * BLOCK_N +
|
| 781 |
+
offs_c[:, None, None] * CHUNK_SIZE +
|
| 782 |
+
offs_m[None, :, None]
|
| 783 |
+
) < seqlen
|
| 784 |
+
) & (offs_d[None, None, :] < headdim),
|
| 785 |
+
)
|
| 786 |
+
tl.store(
|
| 787 |
+
d_v_ptrs, d_v,
|
| 788 |
+
mask=(
|
| 789 |
+
(
|
| 790 |
+
start_n * BLOCK_N +
|
| 791 |
+
offs_c[:, None, None] * CHUNK_SIZE +
|
| 792 |
+
offs_m[None, :, None]
|
| 793 |
+
) < seqlen
|
| 794 |
+
) & (offs_d[None, None, :] < headdim),
|
| 795 |
+
)
|
| 796 |
+
d_param_phi_partial_ptrs = (
|
| 797 |
+
D_PARAM_PHI_PARTIAL +
|
| 798 |
+
offs_b * stride_d_phi_b +
|
| 799 |
+
offs_h * stride_d_phi_h +
|
| 800 |
+
start_n * stride_d_phi_g +
|
| 801 |
+
offs_d
|
| 802 |
+
)
|
| 803 |
+
if EVEN_HEADDIM:
|
| 804 |
+
tl.store(
|
| 805 |
+
d_param_phi_partial_ptrs, d_param_phi
|
| 806 |
+
)
|
| 807 |
+
else:
|
| 808 |
+
tl.store(
|
| 809 |
+
d_param_phi_partial_ptrs, d_param_phi,
|
| 810 |
+
mask=offs_d < headdim
|
| 811 |
+
)
|
| 812 |
+
|
| 813 |
+
def triton_eva_prep_kv_fwd(k, v, param_mu, param_phi, mask, softmax_scale, chunksize):
|
| 814 |
+
k, v, param_mu, param_phi = [
|
| 815 |
+
x if x.stride(-1) == 1 else x.contiguous()
|
| 816 |
+
for x in [k, v, param_mu, param_phi]
|
| 817 |
+
]
|
| 818 |
+
|
| 819 |
+
# shape constraints
|
| 820 |
+
batch, nheads, seqlen, head_dim = k.shape
|
| 821 |
+
assert seqlen % chunksize == 0, "seqlen must be divisible by chunksize"
|
| 822 |
+
nchunks = seqlen // chunksize
|
| 823 |
+
assert k.shape == (batch, nheads, seqlen, head_dim)
|
| 824 |
+
assert v.shape == (batch, nheads, seqlen, head_dim)
|
| 825 |
+
assert param_mu.shape == (1, nheads, 1, 1, head_dim)
|
| 826 |
+
assert param_phi.shape == (1, nheads, 1, 1, head_dim)
|
| 827 |
+
assert head_dim <= 128, "We only test head dimensions up to 128"
|
| 828 |
+
assert k.dtype == v.dtype == param_mu.dtype == param_phi.dtype, "All tensors must have the same type"
|
| 829 |
+
assert k.dtype in [torch.bfloat16, torch.float], "Only support bf16 and fp32 for now"
|
| 830 |
+
assert k.is_cuda and v.is_cuda
|
| 831 |
+
softmax_scale = softmax_scale or 1.0 / math.sqrt(head_dim)
|
| 832 |
+
|
| 833 |
+
mask_type = 0
|
| 834 |
+
if mask is not None:
|
| 835 |
+
mask_type = 1
|
| 836 |
+
assert mask.dtype == torch.bool
|
| 837 |
+
assert mask.is_cuda
|
| 838 |
+
assert mask.dim() == 4
|
| 839 |
+
assert mask.shape == (batch, 1, seqlen, 1)
|
| 840 |
+
if mask.stride(-1) != 1:
|
| 841 |
+
mask = mask.contiguous()
|
| 842 |
+
mask_strides = (
|
| 843 |
+
(mask.stride(0), mask.stride(2))
|
| 844 |
+
if mask_type == 1 else
|
| 845 |
+
(0, 0)
|
| 846 |
+
)
|
| 847 |
+
out_rfa_k = torch.empty((batch, nheads, nchunks, head_dim), dtype=k.dtype, device=k.device)
|
| 848 |
+
out_rfa_v = torch.empty((batch, nheads, nchunks, head_dim), dtype=v.dtype, device=v.device)
|
| 849 |
+
|
| 850 |
+
BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16)
|
| 851 |
+
BLOCK = 128
|
| 852 |
+
num_warps = 4 if head_dim <= 64 else 8
|
| 853 |
+
|
| 854 |
+
assert (BLOCK > chunksize) & (BLOCK % chunksize) == 0, "BLOCK must be divisible by chunksize"
|
| 855 |
+
chunks_per_block = BLOCK // chunksize
|
| 856 |
+
|
| 857 |
+
grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_N"]), batch * nheads)
|
| 858 |
+
_fwd_eva_prep_kv_kernel[grid](
|
| 859 |
+
k,
|
| 860 |
+
v,
|
| 861 |
+
param_mu,
|
| 862 |
+
param_phi,
|
| 863 |
+
mask,
|
| 864 |
+
out_rfa_k,
|
| 865 |
+
out_rfa_v,
|
| 866 |
+
softmax_scale,
|
| 867 |
+
k.stride(0), k.stride(1), k.stride(2),
|
| 868 |
+
v.stride(0), v.stride(1), v.stride(2),
|
| 869 |
+
param_mu.stride(1),
|
| 870 |
+
param_phi.stride(1),
|
| 871 |
+
mask_strides[0], mask_strides[1],
|
| 872 |
+
out_rfa_k.stride(0), out_rfa_k.stride(1), out_rfa_k.stride(2),
|
| 873 |
+
out_rfa_v.stride(0), out_rfa_v.stride(1), out_rfa_v.stride(2),
|
| 874 |
+
nheads,
|
| 875 |
+
seqlen,
|
| 876 |
+
nchunks,
|
| 877 |
+
head_dim,
|
| 878 |
+
chunks_per_block,
|
| 879 |
+
chunksize,
|
| 880 |
+
mask_type,
|
| 881 |
+
BLOCK_HEADDIM,
|
| 882 |
+
BLOCK_N=BLOCK,
|
| 883 |
+
num_warps=num_warps,
|
| 884 |
+
num_stages=1,
|
| 885 |
+
)
|
| 886 |
+
return out_rfa_k, out_rfa_v
|
| 887 |
+
|
| 888 |
+
def triton_eva_prep_kv_bwd(
|
| 889 |
+
d_rfa_k, d_rfa_v,
|
| 890 |
+
k, v, param_mu, param_phi,
|
| 891 |
+
mask,
|
| 892 |
+
rfa_k, rfa_v,
|
| 893 |
+
d_k, d_v, d_param_mu, d_param_phi,
|
| 894 |
+
softmax_scale,
|
| 895 |
+
mask_type,
|
| 896 |
+
chunksize
|
| 897 |
+
):
|
| 898 |
+
d_rfa_k, d_rfa_v = [
|
| 899 |
+
x if x.stride(-1) == 1 else x.contiguous()
|
| 900 |
+
for x in [d_rfa_k, d_rfa_v]
|
| 901 |
+
]
|
| 902 |
+
|
| 903 |
+
# shape constraints
|
| 904 |
+
batch, nheads, seqlen, head_dim = k.shape
|
| 905 |
+
assert seqlen % chunksize == 0, "seqlen must be divisible by chunksize"
|
| 906 |
+
nchunks = seqlen // chunksize
|
| 907 |
+
softmax_scale = softmax_scale or 1.0 / math.sqrt(head_dim)
|
| 908 |
+
|
| 909 |
+
mask_strides = (
|
| 910 |
+
(mask.stride(0), mask.stride(2))
|
| 911 |
+
if mask_type == 1 else
|
| 912 |
+
(0, 0)
|
| 913 |
+
)
|
| 914 |
+
|
| 915 |
+
BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16)
|
| 916 |
+
BLOCK = 128
|
| 917 |
+
num_warps = 4 if head_dim <= 64 else 8
|
| 918 |
+
|
| 919 |
+
assert (BLOCK > chunksize) & (BLOCK % chunksize) == 0, "BLOCK must be divisible by chunksize"
|
| 920 |
+
chunks_per_block = BLOCK // chunksize
|
| 921 |
+
|
| 922 |
+
partial_groups = triton.cdiv(seqlen, BLOCK)
|
| 923 |
+
d_param_mu_partial = torch.zeros((batch, nheads, partial_groups, head_dim), dtype=torch.float32, device=d_rfa_k.device)
|
| 924 |
+
d_param_phi_partial = torch.zeros((batch, nheads, partial_groups, head_dim), dtype=torch.float32, device=d_rfa_k.device)
|
| 925 |
+
grid = lambda META: (partial_groups, batch * nheads)
|
| 926 |
+
_bwd_eva_prep_kv_kernel[grid](
|
| 927 |
+
rfa_k, # [b, h, c, d]
|
| 928 |
+
rfa_v, # [b, h, c, d]
|
| 929 |
+
k, # [b, h, n, d]
|
| 930 |
+
v, # [b, h, n, d]
|
| 931 |
+
param_mu, # [1, h, 1, 1, d]
|
| 932 |
+
param_phi, # [1, h, 1, 1, d]
|
| 933 |
+
mask, # [b, h, n, 1]
|
| 934 |
+
d_rfa_k, # [b, h, c, d]
|
| 935 |
+
d_rfa_v, # [b, h, c, d]
|
| 936 |
+
d_k, # [b, h, n, d]
|
| 937 |
+
d_v, # [b, h, n, d]
|
| 938 |
+
d_param_mu_partial, # [b, h, g, d]
|
| 939 |
+
d_param_phi_partial, # [b, h, g, d]
|
| 940 |
+
softmax_scale,
|
| 941 |
+
rfa_k.stride(0), rfa_k.stride(1), rfa_k.stride(2),
|
| 942 |
+
rfa_v.stride(0), rfa_v.stride(1), rfa_v.stride(2),
|
| 943 |
+
k.stride(0), k.stride(1), k.stride(2),
|
| 944 |
+
v.stride(0), v.stride(1), v.stride(2),
|
| 945 |
+
param_mu.stride(1),
|
| 946 |
+
param_phi.stride(1),
|
| 947 |
+
mask_strides[0], mask_strides[1],
|
| 948 |
+
d_rfa_k.stride(0), d_rfa_k.stride(1), d_rfa_k.stride(2),
|
| 949 |
+
d_rfa_v.stride(0), d_rfa_v.stride(1), d_rfa_v.stride(2),
|
| 950 |
+
d_k.stride(0), d_k.stride(1), d_k.stride(2),
|
| 951 |
+
d_v.stride(0), d_v.stride(1), d_v.stride(2),
|
| 952 |
+
d_param_mu_partial.stride(0), d_param_mu_partial.stride(1), d_param_mu_partial.stride(2),
|
| 953 |
+
d_param_phi_partial.stride(0), d_param_phi_partial.stride(1), d_param_phi_partial.stride(2),
|
| 954 |
+
nheads,
|
| 955 |
+
seqlen,
|
| 956 |
+
nchunks,
|
| 957 |
+
head_dim,
|
| 958 |
+
chunks_per_block,
|
| 959 |
+
chunksize,
|
| 960 |
+
mask_type,
|
| 961 |
+
BLOCK_HEADDIM,
|
| 962 |
+
BLOCK_N=BLOCK,
|
| 963 |
+
num_warps=num_warps,
|
| 964 |
+
num_stages=1,
|
| 965 |
+
)
|
| 966 |
+
d_param_mu.copy_(d_param_mu_partial.sum(dim=(0, -2), keepdim=True).unsqueeze(-2).to(d_param_mu.dtype))
|
| 967 |
+
d_param_phi.copy_(d_param_phi_partial.sum(dim=(0, -2), keepdim=True).unsqueeze(-2).to(d_param_phi.dtype))
|
| 968 |
+
|
| 969 |
+
|
| 970 |
+
|
| 971 |
+
class EvaPrepKVFunc(torch.autograd.Function):
|
| 972 |
+
@staticmethod
|
| 973 |
+
def forward(ctx, k, v, param_mu, param_phi, mask, softmax_scale=None, chunksize=None):
|
| 974 |
+
if mask is not None:
|
| 975 |
+
mask_type = 1
|
| 976 |
+
else:
|
| 977 |
+
mask_type = 0
|
| 978 |
+
rfa_k, rfa_v = triton_eva_prep_kv_fwd(
|
| 979 |
+
k, v, param_mu, param_phi, mask, softmax_scale, chunksize
|
| 980 |
+
)
|
| 981 |
+
ctx.save_for_backward(k, v, param_mu, param_phi, mask, rfa_k, rfa_v)
|
| 982 |
+
ctx.softmax_scale = softmax_scale
|
| 983 |
+
ctx.chunksize = chunksize
|
| 984 |
+
ctx.mask_type = mask_type
|
| 985 |
+
return rfa_k, rfa_v
|
| 986 |
+
|
| 987 |
+
@staticmethod
|
| 988 |
+
def backward(ctx, d_rfa_k, d_rfa_v):
|
| 989 |
+
k, v, param_mu, param_phi, mask, rfa_k, rfa_v = ctx.saved_tensors
|
| 990 |
+
d_k = torch.empty_like(k)
|
| 991 |
+
d_v = torch.empty_like(v)
|
| 992 |
+
d_param_mu = torch.empty_like(param_mu)
|
| 993 |
+
d_param_phi = torch.empty_like(param_phi)
|
| 994 |
+
triton_eva_prep_kv_bwd(
|
| 995 |
+
d_rfa_k, d_rfa_v,
|
| 996 |
+
k, v, param_mu, param_phi,
|
| 997 |
+
mask,
|
| 998 |
+
rfa_k, rfa_v,
|
| 999 |
+
d_k, d_v, d_param_mu, d_param_phi,
|
| 1000 |
+
ctx.softmax_scale,
|
| 1001 |
+
ctx.mask_type,
|
| 1002 |
+
ctx.chunksize
|
| 1003 |
+
)
|
| 1004 |
+
return d_k, d_v, d_param_mu, d_param_phi, None, None, None
|
| 1005 |
+
|
| 1006 |
+
def eva_prep_kv_func_triton(
|
| 1007 |
+
k, v,
|
| 1008 |
+
param_mu, param_phi,
|
| 1009 |
+
mask,
|
| 1010 |
+
softmax_scale=None, chunksize=None
|
| 1011 |
+
):
|
| 1012 |
+
return EvaPrepKVFunc.apply(
|
| 1013 |
+
k, v,
|
| 1014 |
+
param_mu, param_phi,
|
| 1015 |
+
mask,
|
| 1016 |
+
softmax_scale, chunksize
|
| 1017 |
+
)
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/eva_pt_ref.py
ADDED
|
@@ -0,0 +1,420 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional, Tuple, Union
|
| 2 |
+
import torch
|
| 3 |
+
from torch import nn
|
| 4 |
+
|
| 5 |
+
MASK_MIN_VALUE = -10e10
|
| 6 |
+
|
| 7 |
+
def rotate_half(x: torch.Tensor) -> torch.Tensor:
|
| 8 |
+
"""
|
| 9 |
+
Rotates half the hidden dims (last dim) of the input.
|
| 10 |
+
Args:
|
| 11 |
+
x: Rotary embedded tensor
|
| 12 |
+
Return:
|
| 13 |
+
Tensor with half of last dim negated and rotated to the front.
|
| 14 |
+
"""
|
| 15 |
+
x1, x2 = x.split(x.shape[-1] // 2, dim=-1)
|
| 16 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 17 |
+
|
| 18 |
+
def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
|
| 19 |
+
position_ids: torch.Tensor) -> torch.Tensor:
|
| 20 |
+
"""
|
| 21 |
+
Apply rotary embedding (cos, sin) to the query and key tensor on the sequence dimension.
|
| 22 |
+
|
| 23 |
+
The legends for dimensions are defined as:
|
| 24 |
+
num_heads: number of attention heads
|
| 25 |
+
current_seq_len: the current batch's sequence length, should be either 1 or max_seq_len
|
| 26 |
+
max_seq_len: the static sequence length, different from current_seq_len in cached inference case where it is always
|
| 27 |
+
maximum lenghth, e.g. the length of static sequence length of KV cache
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
q: Query tensor, of size (batch_size, num_heads, current_seq_len, head_dim)
|
| 32 |
+
k: Key tensor, of size (batch_size, num_key_value_heads, current_seq_len, head_dim)
|
| 33 |
+
cos: Cosine base of rotary embedding, of size (max_seq_len, head_dim)
|
| 34 |
+
sin: Sine base of rotary embedding, of size (max_seq_len, head_dim)
|
| 35 |
+
position_ids: The position indices of the tokens corresponding to the query and key tensors. It has a size of
|
| 36 |
+
(batch_size, current_seq_len).
|
| 37 |
+
|
| 38 |
+
Returns:
|
| 39 |
+
Embedded query and key tensor of same size as input.
|
| 40 |
+
|
| 41 |
+
"""
|
| 42 |
+
bs, nheads, cur_seq_len, head_dim = q.shape
|
| 43 |
+
assert len(
|
| 44 |
+
k.shape) == 4, f"k should be of shape (batch_size, num_heads, current_seq_len, head_dim), got {k.shape} instead"
|
| 45 |
+
assert k.shape[0] == bs, f"k has a different batch_size {k.shape[0]} compared to q {bs}"
|
| 46 |
+
assert list(k.shape[2:]) == [cur_seq_len,
|
| 47 |
+
head_dim], f"k has different current_seq_len and/or head_dim compared to q"
|
| 48 |
+
assert cos.shape[3] == head_dim, f"cos should have dim of head dim {head_dim}, got {cos.shape[3]} instead"
|
| 49 |
+
assert list(position_ids.shape) in [[bs, cur_seq_len], [1, cur_seq_len]],\
|
| 50 |
+
f"position_ids should be of shape {[bs, cur_seq_len]} or {[1, cur_seq_len]}, got {position_ids.shape} instead"
|
| 51 |
+
|
| 52 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 53 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 54 |
+
return q_embed, k_embed
|
| 55 |
+
|
| 56 |
+
def attention_op(
|
| 57 |
+
q,
|
| 58 |
+
k,
|
| 59 |
+
v,
|
| 60 |
+
attn_mask,
|
| 61 |
+
mixedp_attn,
|
| 62 |
+
head_dim_scaling
|
| 63 |
+
):
|
| 64 |
+
attn = torch.matmul(q, k.transpose(-2, -1))
|
| 65 |
+
if mixedp_attn:
|
| 66 |
+
attn = attn.to(torch.float)
|
| 67 |
+
attn = attn * head_dim_scaling
|
| 68 |
+
if attn_mask is not None:
|
| 69 |
+
attn = attn.masked_fill(attn_mask, MASK_MIN_VALUE)
|
| 70 |
+
|
| 71 |
+
attn_weights = torch.softmax(attn, dim=-1).to(q.dtype)
|
| 72 |
+
attn_output = torch.matmul(attn_weights, v)
|
| 73 |
+
return attn_output
|
| 74 |
+
|
| 75 |
+
def prm_projection(
|
| 76 |
+
x: torch.Tensor,
|
| 77 |
+
projection_matrix: torch.Tensor,
|
| 78 |
+
mixedp_attn: bool = False
|
| 79 |
+
):
|
| 80 |
+
"""
|
| 81 |
+
Constructs nonnegative kernel features for fast softmax attention.
|
| 82 |
+
Args:
|
| 83 |
+
x: input for which features are computed
|
| 84 |
+
projection_matrix: random matrix used to compute features
|
| 85 |
+
Returns:
|
| 86 |
+
Random features for fast attention.
|
| 87 |
+
"""
|
| 88 |
+
# x : [..., m, d]
|
| 89 |
+
# proj : [..., r, d]
|
| 90 |
+
scaling_factor = (x.shape[-1] ** -0.5)
|
| 91 |
+
proj_x = torch.matmul(projection_matrix, x.transpose(-1, -2)) # [..., r, m]
|
| 92 |
+
norm = torch.sum(x ** 2, dim=-1).unsqueeze(-2) * 0.5 # [..., 1]
|
| 93 |
+
if mixedp_attn:
|
| 94 |
+
proj_x = proj_x.to(torch.float)
|
| 95 |
+
norm = norm.to(torch.float)
|
| 96 |
+
phi_x = scaling_factor * (proj_x - norm)
|
| 97 |
+
return phi_x
|
| 98 |
+
|
| 99 |
+
class EvaAttention(nn.Module):
|
| 100 |
+
def __init__(self, config, layer_idx: Optional[int] = None):
|
| 101 |
+
super().__init__()
|
| 102 |
+
self.config = config
|
| 103 |
+
self.layer_idx = layer_idx
|
| 104 |
+
self.hidden_size = config.hidden_size
|
| 105 |
+
self.num_heads = config.num_attention_heads
|
| 106 |
+
self.head_dim = self.hidden_size // self.num_heads
|
| 107 |
+
self.head_dim_scaling = self.head_dim ** -0.5
|
| 108 |
+
|
| 109 |
+
self.max_position_embeddings = config.max_position_embeddings
|
| 110 |
+
|
| 111 |
+
if (self.head_dim * self.num_heads) != self.hidden_size:
|
| 112 |
+
raise ValueError(
|
| 113 |
+
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
| 114 |
+
f" and `num_heads`: {self.num_heads})."
|
| 115 |
+
)
|
| 116 |
+
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
| 117 |
+
self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
| 118 |
+
self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
| 119 |
+
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
| 120 |
+
|
| 121 |
+
self.window_size = config.window_size
|
| 122 |
+
|
| 123 |
+
self.num_chunks = config.num_chunks
|
| 124 |
+
self.chunk_size = config.chunk_size
|
| 125 |
+
if self.chunk_size is not None:
|
| 126 |
+
assert self.window_size >= self.chunk_size and self.window_size % self.chunk_size == 0
|
| 127 |
+
# chunk_size overrides the number of landmarks
|
| 128 |
+
self.num_chunks = None
|
| 129 |
+
|
| 130 |
+
self.chunks_per_window = int(self.window_size // self.chunk_size)
|
| 131 |
+
self.random_feature_dim = 1
|
| 132 |
+
self.adaptive_phi = nn.Parameter(
|
| 133 |
+
torch.randn(
|
| 134 |
+
1,
|
| 135 |
+
self.num_heads,
|
| 136 |
+
1,
|
| 137 |
+
1,
|
| 138 |
+
self.head_dim
|
| 139 |
+
).clamp(-1., 1.) * self.head_dim_scaling
|
| 140 |
+
)
|
| 141 |
+
self.adaptive_mu_k = nn.Parameter(
|
| 142 |
+
torch.randn(
|
| 143 |
+
1,
|
| 144 |
+
self.num_heads,
|
| 145 |
+
1,
|
| 146 |
+
1,
|
| 147 |
+
self.head_dim
|
| 148 |
+
).clamp(-1., 1.) * self.head_dim_scaling
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
def _generate_feature_map(self, rf_q, rf_k, rf_v):
|
| 152 |
+
rf_k_logits = torch.sum(self.adaptive_mu_k.to(rf_k.dtype) * rf_k, dim=-1, keepdim=True) # b h c m 1
|
| 153 |
+
if self.config.mixedp_attn:
|
| 154 |
+
rf_k_logits = rf_k_logits.to(torch.float)
|
| 155 |
+
rf_k_weights = torch.softmax(rf_k_logits, dim=-2).to(rf_k.dtype)
|
| 156 |
+
rf_k_bar = torch.sum(rf_k_weights * rf_k, dim=-2)
|
| 157 |
+
weights = self.adaptive_phi.to(rf_k.dtype)
|
| 158 |
+
return weights, rf_k_bar
|
| 159 |
+
|
| 160 |
+
def _calculate_chunk_rfa_cache(self, rf_q, rf_k, rf_v, weights, rf_mask=None):
|
| 161 |
+
proj_x = torch.sum(weights * rf_k, dim=-1, keepdim=True)
|
| 162 |
+
norm = torch.sum(rf_k ** 2, dim=-1, keepdim=True) * 0.5 # [..., 1]
|
| 163 |
+
if self.config.mixedp_attn:
|
| 164 |
+
proj_x = proj_x.to(torch.float)
|
| 165 |
+
norm = norm.to(torch.float)
|
| 166 |
+
log_phi_k = self.head_dim_scaling * (proj_x - norm)
|
| 167 |
+
|
| 168 |
+
if rf_mask is not None:
|
| 169 |
+
log_phi_k = log_phi_k.masked_fill(rf_mask, MASK_MIN_VALUE)
|
| 170 |
+
|
| 171 |
+
# [b, h, c, m, r]
|
| 172 |
+
softmax_phi_k = torch.softmax(log_phi_k, dim=-2).to(rf_k.dtype)
|
| 173 |
+
softmax_phi_k_v = torch.sum(softmax_phi_k * rf_v, dim=-2)
|
| 174 |
+
# [b, h, c, r, m] [b, h, c, m, d] -> [b, h, c, r, d]
|
| 175 |
+
# softmax_phi_k_v = torch.matmul(softmax_phi_k.transpose(-1, -2), rf_v).squeeze(-2)
|
| 176 |
+
log_sum_phi_k = None
|
| 177 |
+
return softmax_phi_k_v, log_sum_phi_k
|
| 178 |
+
|
| 179 |
+
def _calculate_chunk_rfa(self, q, softmax_phi_k_v, log_sum_phi_k, weights):
|
| 180 |
+
if self.random_feature_dim == 1:
|
| 181 |
+
# when r = 1, the snis weights becomes 1, so this takes no effect
|
| 182 |
+
# [b, h, c, r, d] -> [b, h, c, d]
|
| 183 |
+
return softmax_phi_k_v
|
| 184 |
+
else:
|
| 185 |
+
# [b, h, c, r, d] [b, h, 1, s, d] -> [b, h, c, r, s]
|
| 186 |
+
log_phi_q = prm_projection(q.unsqueeze(-3), weights, self.config.mixedp_attn)
|
| 187 |
+
# [b, h, c, r, s] [b, h, c, r, 1] -> [b, h, c, r, s]
|
| 188 |
+
sniw = torch.softmax(log_phi_q + log_sum_phi_k, dim=-1).to(q.dtype)
|
| 189 |
+
# [b, h, c, r, s] [b, h, c, r, d] -> [b, h, c, s, d] -> [b, h, s, c, d]
|
| 190 |
+
rfa_per_chunk = torch.matmul(sniw.transpose(-1, -2), softmax_phi_k_v).transpose(-3, -2)
|
| 191 |
+
return rfa_per_chunk
|
| 192 |
+
|
| 193 |
+
def window_partition(self, x, window_size=None):
|
| 194 |
+
window_size = window_size if window_size is not None else self.window_size
|
| 195 |
+
|
| 196 |
+
gw, d = x.shape[-2:]
|
| 197 |
+
leading_dims = x.shape[:-2]
|
| 198 |
+
n_groups = gw // window_size
|
| 199 |
+
return x.reshape(*leading_dims, n_groups, window_size, d)
|
| 200 |
+
|
| 201 |
+
def window_merge(self, x, window_size=None):
|
| 202 |
+
g, w, d = x.shape[-3:]
|
| 203 |
+
leading_dims = x.shape[:-3]
|
| 204 |
+
return x.reshape(*leading_dims, g * w, d)
|
| 205 |
+
|
| 206 |
+
def forward(
|
| 207 |
+
self,
|
| 208 |
+
hidden_states: torch.Tensor,
|
| 209 |
+
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
|
| 210 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 211 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
| 212 |
+
output_attentions: bool = False,
|
| 213 |
+
use_cache: bool = False,
|
| 214 |
+
cos: Optional[torch.Tensor] = None,
|
| 215 |
+
sin: Optional[torch.Tensor] = None,
|
| 216 |
+
multibyte_decoding: Optional[bool] = False,
|
| 217 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 218 |
+
assert not output_attentions
|
| 219 |
+
bsz, q_len, _ = hidden_states.size()
|
| 220 |
+
|
| 221 |
+
############################################
|
| 222 |
+
# initialize past states if not provided
|
| 223 |
+
############################################
|
| 224 |
+
if use_cache and past_key_value is None:
|
| 225 |
+
raise ValueError
|
| 226 |
+
if use_cache and multibyte_decoding:
|
| 227 |
+
raise NotImplementedError("Multibyte decoding is not supported for PyTorch native implementation")
|
| 228 |
+
# assert isinstance(attention_mask, tuple)
|
| 229 |
+
if len(attention_mask) == 4:
|
| 230 |
+
assert use_cache
|
| 231 |
+
prev_causal_mask, cur_causal_mask, chunk_causal_mask, intra_chunk_mask = attention_mask
|
| 232 |
+
elif len(attention_mask) == 3:
|
| 233 |
+
assert not use_cache
|
| 234 |
+
window_causal_mask, chunk_causal_mask, intra_chunk_mask = attention_mask
|
| 235 |
+
else:
|
| 236 |
+
raise NotImplementedError("Only attention-mask tuple with length 2 or 3 is supported")
|
| 237 |
+
|
| 238 |
+
############################################
|
| 239 |
+
# compute q, k, v from hidden states
|
| 240 |
+
############################################
|
| 241 |
+
# [b, h, q_len, d]
|
| 242 |
+
q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 243 |
+
# [b, h, kv_len, d]
|
| 244 |
+
k = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 245 |
+
# [b, h, kv_len, d]
|
| 246 |
+
v = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 247 |
+
|
| 248 |
+
if use_cache:
|
| 249 |
+
past_key_value.update_past_len(q.shape[-2], self.layer_idx)
|
| 250 |
+
|
| 251 |
+
############################################
|
| 252 |
+
# apply rotary positional embeddings to q, k
|
| 253 |
+
############################################
|
| 254 |
+
q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
|
| 255 |
+
|
| 256 |
+
############################################
|
| 257 |
+
# compute q, k, v stats for the local window
|
| 258 |
+
############################################
|
| 259 |
+
if use_cache:
|
| 260 |
+
(prev_w_q, prev_w_k, prev_w_v), (cur_w_q, cur_w_k, cur_w_v) = past_key_value.update_singletons(
|
| 261 |
+
q,
|
| 262 |
+
k,
|
| 263 |
+
v,
|
| 264 |
+
self.layer_idx,
|
| 265 |
+
self.window_size,
|
| 266 |
+
)
|
| 267 |
+
else:
|
| 268 |
+
prev_w_q = self.window_partition(q) # [b, h, w, i, d]
|
| 269 |
+
prev_w_k = self.window_partition(k) # [b, h, w, j, d]
|
| 270 |
+
prev_w_v = self.window_partition(v) # [b, h, w, j, d]
|
| 271 |
+
# during training, we assume window_size divides seq_len so no remainders
|
| 272 |
+
cur_w_q = cur_w_k = cur_w_v = None
|
| 273 |
+
|
| 274 |
+
############################################
|
| 275 |
+
# compute q, k, v stats for chunk-level RFAs
|
| 276 |
+
############################################
|
| 277 |
+
if use_cache:
|
| 278 |
+
dump_q, dump_k, dump_v = past_key_value.update_chunks(q, k, v, self.layer_idx, self.chunk_size)
|
| 279 |
+
else:
|
| 280 |
+
dump_q, dump_k, dump_v = q, k, v
|
| 281 |
+
|
| 282 |
+
if use_cache:
|
| 283 |
+
prev_s_mask, cur_s_mask, prev_chunk_mask, cur_chunk_mask, dump_rf_mask = past_key_value.update_mask(
|
| 284 |
+
prev_s_mask=prev_causal_mask,
|
| 285 |
+
cur_s_mask=cur_causal_mask,
|
| 286 |
+
chunk_mask=chunk_causal_mask,
|
| 287 |
+
rf_mask=intra_chunk_mask,
|
| 288 |
+
layer_idx=self.layer_idx,
|
| 289 |
+
window_size=self.window_size,
|
| 290 |
+
chunk_size=self.chunk_size,
|
| 291 |
+
)
|
| 292 |
+
else:
|
| 293 |
+
prev_s_mask = self.window_partition(prev_causal_mask) # [1, 1, w, i, j]
|
| 294 |
+
cur_s_mask = None
|
| 295 |
+
prev_chunk_mask = self.window_partition(chunk_causal_mask)
|
| 296 |
+
cur_chunk_mask = None
|
| 297 |
+
dump_rf_mask = intra_chunk_mask
|
| 298 |
+
if prev_s_mask.shape[-3] == 1:
|
| 299 |
+
# need to expand
|
| 300 |
+
prev_s_mask = prev_s_mask.expand(-1, -1, prev_chunk_mask.shape[-3], -1, -1)
|
| 301 |
+
|
| 302 |
+
if (
|
| 303 |
+
dump_q is not None and
|
| 304 |
+
dump_k is not None and
|
| 305 |
+
dump_v is not None
|
| 306 |
+
):
|
| 307 |
+
# [b, h, c, j, d]
|
| 308 |
+
rf_q = self.window_partition(dump_q, window_size=self.chunk_size)
|
| 309 |
+
# [b, h, c, j, d]
|
| 310 |
+
rf_k = self.window_partition(dump_k, window_size=self.chunk_size)
|
| 311 |
+
# [b, h, c, j, d]
|
| 312 |
+
rf_v = self.window_partition(dump_v, window_size=self.chunk_size)
|
| 313 |
+
|
| 314 |
+
if dump_rf_mask is not None:
|
| 315 |
+
rf_mask = self.window_partition(dump_rf_mask, window_size=self.chunk_size)
|
| 316 |
+
rf_q = rf_q.masked_fill(rf_mask, 0.)
|
| 317 |
+
rf_k = rf_k.masked_fill(rf_mask, 0.)
|
| 318 |
+
rf_v = rf_v.masked_fill(rf_mask, 0.)
|
| 319 |
+
else:
|
| 320 |
+
rf_mask = None
|
| 321 |
+
else:
|
| 322 |
+
rf_q = None
|
| 323 |
+
rf_k = None
|
| 324 |
+
rf_v = None
|
| 325 |
+
rf_mask = None
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
if rf_q is not None:
|
| 329 |
+
# import pdb; pdb.set_trace()
|
| 330 |
+
weights, rf_k_bar = self._generate_feature_map(rf_q, rf_k, rf_v)
|
| 331 |
+
softmax_phi_k_v, log_sum_phi_k = self._calculate_chunk_rfa_cache(rf_q, rf_k, rf_v, weights, rf_mask=rf_mask)
|
| 332 |
+
if use_cache:
|
| 333 |
+
softmax_phi_k_v, log_sum_phi_k, rf_k_bar = past_key_value.update_chunk_rfas(
|
| 334 |
+
softmax_phi_k_v, log_sum_phi_k, rf_k_bar, self.layer_idx, 1
|
| 335 |
+
)
|
| 336 |
+
elif use_cache:
|
| 337 |
+
weights = None
|
| 338 |
+
softmax_phi_k_v, log_sum_phi_k, rf_k_bar = past_key_value.get_chunk_rfas(self.layer_idx)
|
| 339 |
+
else:
|
| 340 |
+
weights = None
|
| 341 |
+
softmax_phi_k_v = None
|
| 342 |
+
log_sum_phi_k = None
|
| 343 |
+
rf_k_bar = None
|
| 344 |
+
|
| 345 |
+
if rf_k_bar is not None:
|
| 346 |
+
rfa_per_chunk = self._calculate_chunk_rfa(q, softmax_phi_k_v, log_sum_phi_k, weights)
|
| 347 |
+
############################################
|
| 348 |
+
# compute meta-attention weights for
|
| 349 |
+
# - group-wise RFAs and
|
| 350 |
+
# - singletons (equivalent to exact local attention)
|
| 351 |
+
############################################
|
| 352 |
+
if prev_w_k is not None:
|
| 353 |
+
if rf_k_bar is not None:
|
| 354 |
+
num_windows = prev_w_k.shape[-3]
|
| 355 |
+
# rf_k_bar and rfa_per_chunk take the shape [b, h, c, d]
|
| 356 |
+
# -> [b, h, 1, c, d] -> [b, h, w, c, d]
|
| 357 |
+
prev_rf_k_bar = rf_k_bar.unsqueeze(-3).expand(-1, -1, num_windows, -1, -1)
|
| 358 |
+
prev_rfa_per_chunk = rfa_per_chunk.unsqueeze(-3).expand(-1, -1, num_windows, -1, -1)
|
| 359 |
+
prev_agg_k = torch.cat([prev_w_k, prev_rf_k_bar], dim=-2)
|
| 360 |
+
prev_agg_v = torch.cat([prev_w_v, prev_rfa_per_chunk], dim=-2)
|
| 361 |
+
|
| 362 |
+
prev_attn_mask = torch.cat([prev_s_mask, prev_chunk_mask], dim=-1)
|
| 363 |
+
else:
|
| 364 |
+
prev_agg_k = prev_w_k
|
| 365 |
+
prev_agg_v = prev_w_v
|
| 366 |
+
prev_attn_mask = prev_s_mask
|
| 367 |
+
|
| 368 |
+
prev_attn_output = attention_op(
|
| 369 |
+
q=prev_w_q,
|
| 370 |
+
k=prev_agg_k,
|
| 371 |
+
v=prev_agg_v,
|
| 372 |
+
attn_mask=prev_attn_mask,
|
| 373 |
+
mixedp_attn=self.config.mixedp_attn,
|
| 374 |
+
head_dim_scaling=self.head_dim_scaling
|
| 375 |
+
)
|
| 376 |
+
prev_attn_output = self.window_merge(prev_attn_output)
|
| 377 |
+
|
| 378 |
+
if cur_w_k is not None:
|
| 379 |
+
if rf_k_bar is not None:
|
| 380 |
+
# rf_k_bar and rfa_per_chunk take the shape [b, h, c, d]
|
| 381 |
+
# cur_w_k and cur_w_v also has shape [b, h, m, d]
|
| 382 |
+
cur_agg_k = torch.cat([cur_w_k, rf_k_bar], dim=-2)
|
| 383 |
+
cur_agg_v = torch.cat([cur_w_v, rfa_per_chunk], dim=-2)
|
| 384 |
+
|
| 385 |
+
cur_attn_mask = torch.cat([cur_s_mask, cur_chunk_mask], dim=-1)
|
| 386 |
+
else:
|
| 387 |
+
cur_agg_k = cur_w_k
|
| 388 |
+
cur_agg_v = cur_w_v
|
| 389 |
+
cur_attn_mask = cur_s_mask
|
| 390 |
+
|
| 391 |
+
cur_attn_output = attention_op(
|
| 392 |
+
q=cur_w_q,
|
| 393 |
+
k=cur_agg_k,
|
| 394 |
+
v=cur_agg_v,
|
| 395 |
+
attn_mask=cur_attn_mask,
|
| 396 |
+
mixedp_attn=self.config.mixedp_attn,
|
| 397 |
+
head_dim_scaling=self.head_dim_scaling
|
| 398 |
+
)
|
| 399 |
+
|
| 400 |
+
if prev_w_k is not None and cur_w_k is not None:
|
| 401 |
+
attn_output = torch.cat([prev_attn_output, cur_attn_output], dim=-2)
|
| 402 |
+
elif prev_w_k is not None:
|
| 403 |
+
attn_output = prev_attn_output
|
| 404 |
+
elif cur_w_k is not None:
|
| 405 |
+
attn_output = cur_attn_output
|
| 406 |
+
else:
|
| 407 |
+
raise ValueError("There must be some bug")
|
| 408 |
+
|
| 409 |
+
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
| 410 |
+
raise ValueError(
|
| 411 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
| 412 |
+
f" {attn_output.size()}"
|
| 413 |
+
)
|
| 414 |
+
|
| 415 |
+
attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
|
| 416 |
+
attn_output = self.o_proj(attn_output)
|
| 417 |
+
|
| 418 |
+
attn_weights = None
|
| 419 |
+
|
| 420 |
+
return attn_output, attn_weights, past_key_value
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/generation_config.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_from_model_config": true,
|
| 3 |
+
"bos_token_id": 1,
|
| 4 |
+
"eos_token_id": 2,
|
| 5 |
+
"pad_token_id": 0,
|
| 6 |
+
"transformers_version": "4.47.1"
|
| 7 |
+
}
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/image_processing_evabyte.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
"""Image processor class for EvaByte."""
|
| 3 |
+
|
| 4 |
+
from typing import Dict, List, Optional, Union, Tuple
|
| 5 |
+
|
| 6 |
+
import io
|
| 7 |
+
from transformers.image_processing_utils import BaseImageProcessor
|
| 8 |
+
from transformers.image_utils import (
|
| 9 |
+
ImageInput,
|
| 10 |
+
PILImageResampling,
|
| 11 |
+
valid_images,
|
| 12 |
+
validate_preprocess_arguments,
|
| 13 |
+
)
|
| 14 |
+
from PIL import Image
|
| 15 |
+
|
| 16 |
+
def _get_qtable_bytes():
|
| 17 |
+
return {
|
| 18 |
+
5: b'\xff\xd8\xff\xdb\x00C\x00\xa0nx\x8cxd\xa0\x8c\x82\x8c\xb4\xaa\xa0\xbe\xf0\xff\xff\xf0\xdc\xdc\xf0\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xdb\x00C\x01\xa0\xb4\xb4\xf0\xd2\xf0\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xd9',
|
| 19 |
+
10: b'\xff\xd8\xff\xdb\x00C\x00P7<F<2PFAFZUP_x\xc8\x82xnnx\xf5\xaf\xb9\x91\xc8\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xdb\x00C\x01PZZxix\xeb\x82\x82\xeb\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xd9',
|
| 20 |
+
15: b'\xff\xd8\xff\xdb\x00C\x005%(/(!5/+/<95?P\x85WPIIP\xa3u{a\x85\xc1\xaa\xcb\xc8\xbe\xaa\xba\xb7\xd5\xf0\xff\xff\xd5\xe2\xff\xe6\xb7\xba\xff\xff\xff\xff\xff\xff\xff\xff\xff\xce\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xdb\x00C\x015<<PFP\x9dWW\x9d\xff\xdc\xba\xdc\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xd9',
|
| 21 |
+
20: b'\xff\xd8\xff\xdb\x00C\x00(\x1c\x1e#\x1e\x19(#!#-+(0<dA<77<{X]Id\x91\x80\x99\x96\x8f\x80\x8c\x8a\xa0\xb4\xe6\xc3\xa0\xaa\xda\xad\x8a\x8c\xc8\xff\xcb\xda\xee\xf5\xff\xff\xff\x9b\xc1\xff\xff\xff\xfa\xff\xe6\xfd\xff\xf8\xff\xdb\x00C\x01(--<5<vAAv\xf8\xa5\x8c\xa5\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xff\xd9',
|
| 22 |
+
25: b'\xff\xd8\xff\xdb\x00C\x00 \x16\x18\x1c\x18\x14 \x1c\x1a\x1c$" &0P40,,0bFJ:Ptfzxrfpn\x80\x90\xb8\x9c\x80\x88\xae\x8anp\xa0\xda\xa2\xae\xbe\xc4\xce\xd0\xce|\x9a\xe2\xf2\xe0\xc8\xf0\xb8\xca\xce\xc6\xff\xdb\x00C\x01 $$0*0^44^\xc6\x84p\x84\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xff\xd9',
|
| 23 |
+
30: b'\xff\xd8\xff\xdb\x00C\x00\x1b\x12\x14\x17\x14\x11\x1b\x17\x16\x17\x1e\x1c\x1b (B+(%%(Q:=0B`Ued_U][jx\x99\x81jq\x90s[]\x85\xb5\x86\x90\x9e\xa3\xab\xad\xabg\x80\xbc\xc9\xba\xa6\xc7\x99\xa8\xab\xa4\xff\xdb\x00C\x01\x1b\x1e\x1e(#(N++N\xa4n]n\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xff\xd9',
|
| 24 |
+
50: b'\xff\xd8\xff\xdb\x00C\x00\x10\x0b\x0c\x0e\x0c\n\x10\x0e\r\x0e\x12\x11\x10\x13\x18(\x1a\x18\x16\x16\x181#%\x1d(:3=<9387@H\\N@DWE78PmQW_bghg>Mqypdx\\egc\xff\xdb\x00C\x01\x10\x12\x12\x18\x15\x18/\x1a\x1a/cB8Bcccccccccccccccccccccccccccccccccccccccccccccccccc\xff\xd9',
|
| 25 |
+
75: b'\xff\xd8\xff\xdb\x00C\x00\x08\x06\x06\x07\x06\x05\x08\x07\x07\x07\t\t\x08\n\x0c\x14\r\x0c\x0b\x0b\x0c\x19\x12\x13\x0f\x14\x1d\x1a\x1f\x1e\x1d\x1a\x1c\x1c $.\' ",#\x1c\x1c(7),01444\x1f\'9=82<.342\xff\xdb\x00C\x01\x08\t\t\x0c\x0b\x0c\x18\r\r\x182!\x1c!22222222222222222222222222222222222222222222222222\xff\xd9',
|
| 26 |
+
95: b'\xff\xd8\xff\xdb\x00C\x00\x02\x01\x01\x01\x01\x01\x02\x01\x01\x01\x02\x02\x02\x02\x02\x04\x03\x02\x02\x02\x02\x05\x04\x04\x03\x04\x06\x05\x06\x06\x06\x05\x06\x06\x06\x07\t\x08\x06\x07\t\x07\x06\x06\x08\x0b\x08\t\n\n\n\n\n\x06\x08\x0b\x0c\x0b\n\x0c\t\n\n\n\xff\xdb\x00C\x01\x02\x02\x02\x02\x02\x02\x05\x03\x03\x05\n\x07\x06\x07\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\xff\xd9',
|
| 27 |
+
100: b'\xff\xd8\xff\xdb\x00C\x00\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\xff\xdb\x00C\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\xff\xd9',
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _resize_if_exceeding_max_len(
|
| 32 |
+
width: int, height: int, min_len: Optional[int] = 16, max_len: Optional[int] = None
|
| 33 |
+
) -> Tuple[int, int]:
|
| 34 |
+
"""
|
| 35 |
+
Get the output size of the image after resizing given a dictionary specifying the max and min sizes.
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
height (`int`):
|
| 39 |
+
Height of the input image.
|
| 40 |
+
width (`int`):
|
| 41 |
+
Width of the input image.
|
| 42 |
+
max_len (`Dict[str, int]`, *optional*, defaults to the maximum size of the image):
|
| 43 |
+
Defines the maximum dimensions of the image.
|
| 44 |
+
|
| 45 |
+
Returns:
|
| 46 |
+
The output size of the image after resizing.
|
| 47 |
+
"""
|
| 48 |
+
max_len = max(height, width) if max_len is None else max_len
|
| 49 |
+
aspect_ratio = width / height
|
| 50 |
+
|
| 51 |
+
if width >= height and width > max_len:
|
| 52 |
+
width = max_len
|
| 53 |
+
height = int(width / aspect_ratio)
|
| 54 |
+
if height % 2 != 0:
|
| 55 |
+
height += 1
|
| 56 |
+
elif height > width and height > max_len:
|
| 57 |
+
height = max_len
|
| 58 |
+
width = int(height * aspect_ratio)
|
| 59 |
+
if width % 2 != 0:
|
| 60 |
+
width += 1
|
| 61 |
+
|
| 62 |
+
# Avoid resizing to a size smaller than 1
|
| 63 |
+
height = max(height, min_len)
|
| 64 |
+
width = max(width, min_len)
|
| 65 |
+
return width, height
|
| 66 |
+
|
| 67 |
+
class EvaByteImageProcessor(BaseImageProcessor):
|
| 68 |
+
|
| 69 |
+
model_input_names = []
|
| 70 |
+
|
| 71 |
+
def __init__(
|
| 72 |
+
self,
|
| 73 |
+
do_resize: bool = True,
|
| 74 |
+
resample: PILImageResampling = PILImageResampling.LANCZOS,
|
| 75 |
+
size: Dict[str, int] = None,
|
| 76 |
+
do_convert_rgb: bool = True,
|
| 77 |
+
jpeg_quality: int = 25,
|
| 78 |
+
jpeg_subsampling: str = "4:2:0",
|
| 79 |
+
jpeg_streamtype: str = 2,
|
| 80 |
+
jpeg_restart_marker_blocks: int = 1,
|
| 81 |
+
**kwargs,
|
| 82 |
+
) -> None:
|
| 83 |
+
super().__init__(**kwargs)
|
| 84 |
+
self.do_resize = do_resize
|
| 85 |
+
self.resample = resample
|
| 86 |
+
self.size = size if size is not None else {"longest_edge": 384}
|
| 87 |
+
self.do_convert_rgb = do_convert_rgb
|
| 88 |
+
self.jpeg_quality = jpeg_quality
|
| 89 |
+
self.jpeg_subsampling = jpeg_subsampling
|
| 90 |
+
self.jpeg_streamtype = jpeg_streamtype
|
| 91 |
+
self.jpeg_restart_marker_blocks = jpeg_restart_marker_blocks
|
| 92 |
+
|
| 93 |
+
def jpeg_encode(
|
| 94 |
+
self,
|
| 95 |
+
image,
|
| 96 |
+
jpeg_quality,
|
| 97 |
+
jpeg_subsampling,
|
| 98 |
+
jpeg_streamtype,
|
| 99 |
+
jpeg_restart_marker_blocks,
|
| 100 |
+
):
|
| 101 |
+
with io.BytesIO() as output:
|
| 102 |
+
image.save(
|
| 103 |
+
output,
|
| 104 |
+
format="JPEG",
|
| 105 |
+
quality=jpeg_quality,
|
| 106 |
+
subsampling=jpeg_subsampling,
|
| 107 |
+
streamtype=jpeg_streamtype,
|
| 108 |
+
restart_marker_blocks=jpeg_restart_marker_blocks
|
| 109 |
+
)
|
| 110 |
+
jpeg_bytes = output.getvalue()
|
| 111 |
+
return jpeg_bytes
|
| 112 |
+
|
| 113 |
+
def jpeg_merge_qtables(
|
| 114 |
+
self,
|
| 115 |
+
image_bytes,
|
| 116 |
+
jpeg_quality=None,
|
| 117 |
+
):
|
| 118 |
+
if jpeg_quality is None:
|
| 119 |
+
jpeg_quality = self.jpeg_quality
|
| 120 |
+
qtable_bytes = _get_qtable_bytes()[jpeg_quality]
|
| 121 |
+
return image_bytes[:2] + qtable_bytes[2:-2] + image_bytes[2:]
|
| 122 |
+
|
| 123 |
+
def resize(
|
| 124 |
+
self,
|
| 125 |
+
image: Image,
|
| 126 |
+
size: Dict[str, int],
|
| 127 |
+
resample: PILImageResampling = PILImageResampling.LANCZOS,
|
| 128 |
+
) -> Image:
|
| 129 |
+
if "longest_edge" in size:
|
| 130 |
+
width, height = image.size
|
| 131 |
+
# Find the output size, when rescaling the longest edge to max_len and preserving the aspect ratio
|
| 132 |
+
width, height = _resize_if_exceeding_max_len(width, height, max_len=size["longest_edge"])
|
| 133 |
+
size = (width, height)
|
| 134 |
+
elif "width" in size and "height" in size:
|
| 135 |
+
size = (size["width"], size["height"])
|
| 136 |
+
else:
|
| 137 |
+
raise ValueError("size must be a dictionary with key 'longest_edge' or 'height' and 'width'.")
|
| 138 |
+
resized_image = image.resize(size, resample=resample)
|
| 139 |
+
return resized_image
|
| 140 |
+
|
| 141 |
+
def preprocess(
|
| 142 |
+
self,
|
| 143 |
+
images: ImageInput,
|
| 144 |
+
do_resize: bool = None,
|
| 145 |
+
resample = None,
|
| 146 |
+
size: Dict[str, int] = None,
|
| 147 |
+
do_convert_rgb: bool = None,
|
| 148 |
+
jpeg_quality: int = None,
|
| 149 |
+
jpeg_subsampling: str = None,
|
| 150 |
+
jpeg_streamtype: str = None,
|
| 151 |
+
jpeg_restart_marker_blocks: int = None,
|
| 152 |
+
):
|
| 153 |
+
do_resize = do_resize if do_resize is not None else self.do_resize
|
| 154 |
+
size = size if size is not None else self.size
|
| 155 |
+
resample = resample if resample is not None else self.resample
|
| 156 |
+
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
| 157 |
+
|
| 158 |
+
jpeg_quality = jpeg_quality if jpeg_quality is not None else self.jpeg_quality
|
| 159 |
+
jpeg_subsampling = jpeg_subsampling if jpeg_subsampling is not None else self.jpeg_subsampling
|
| 160 |
+
jpeg_streamtype = jpeg_streamtype if jpeg_streamtype is not None else self.jpeg_streamtype
|
| 161 |
+
jpeg_restart_marker_blocks = jpeg_restart_marker_blocks if jpeg_restart_marker_blocks is not None else self.jpeg_restart_marker_blocks
|
| 162 |
+
|
| 163 |
+
if images is not None and not valid_images(images):
|
| 164 |
+
raise ValueError(
|
| 165 |
+
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
| 166 |
+
"torch.Tensor, tf.Tensor or jax.ndarray."
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
validate_preprocess_arguments(
|
| 170 |
+
do_resize=do_resize,
|
| 171 |
+
size=size,
|
| 172 |
+
resample=resample,
|
| 173 |
+
)
|
| 174 |
+
images_list = images
|
| 175 |
+
if do_convert_rgb:
|
| 176 |
+
images_list = [
|
| 177 |
+
[
|
| 178 |
+
image.convert("RGB") for image in images
|
| 179 |
+
]
|
| 180 |
+
for images in images_list
|
| 181 |
+
]
|
| 182 |
+
|
| 183 |
+
if do_resize:
|
| 184 |
+
images_list = [
|
| 185 |
+
[
|
| 186 |
+
self.resize(image=image, size=size, resample=resample)
|
| 187 |
+
for image in images
|
| 188 |
+
]
|
| 189 |
+
for images in images_list
|
| 190 |
+
]
|
| 191 |
+
|
| 192 |
+
jpeg_bytes = [
|
| 193 |
+
[
|
| 194 |
+
self.jpeg_encode(
|
| 195 |
+
image,
|
| 196 |
+
jpeg_quality,
|
| 197 |
+
jpeg_subsampling,
|
| 198 |
+
jpeg_streamtype,
|
| 199 |
+
jpeg_restart_marker_blocks
|
| 200 |
+
) for image in images
|
| 201 |
+
]
|
| 202 |
+
for images in images_list
|
| 203 |
+
]
|
| 204 |
+
return jpeg_bytes
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/model.safetensors.index.json
ADDED
|
@@ -0,0 +1,450 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"metadata": {
|
| 3 |
+
"total_size": 57058938880
|
| 4 |
+
},
|
| 5 |
+
"weight_map": {
|
| 6 |
+
"model.embed_tokens.weight": "model-00001-of-00003.safetensors",
|
| 7 |
+
"model.layers.0.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 8 |
+
"model.layers.0.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 9 |
+
"model.layers.1.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
|
| 10 |
+
"model.layers.1.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
|
| 11 |
+
"model.layers.1.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 12 |
+
"model.layers.1.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 13 |
+
"model.layers.3.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 14 |
+
"model.layers.4.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 15 |
+
"model.layers.4.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 16 |
+
"model.layers.5.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 17 |
+
"model.layers.5.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 18 |
+
"model.layers.7.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 19 |
+
"model.layers.8.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 20 |
+
"model.layers.8.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 21 |
+
"model.layers.10.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 22 |
+
"model.layers.12.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 23 |
+
"model.layers.12.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 24 |
+
"model.layers.12.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 25 |
+
"model.layers.13.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
|
| 26 |
+
"model.layers.13.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
|
| 27 |
+
"model.layers.13.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 28 |
+
"model.layers.13.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 29 |
+
"model.layers.15.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 30 |
+
"model.layers.16.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 31 |
+
"model.layers.16.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 32 |
+
"model.layers.18.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 33 |
+
"model.layers.20.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 34 |
+
"model.layers.20.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 35 |
+
"model.layers.20.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 36 |
+
"model.layers.21.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
|
| 37 |
+
"model.layers.21.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
|
| 38 |
+
"model.layers.21.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 39 |
+
"model.layers.21.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 40 |
+
"model.layers.23.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 41 |
+
"model.layers.24.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 42 |
+
"model.layers.24.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 43 |
+
"model.layers.26.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 44 |
+
"model.layers.28.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 45 |
+
"model.layers.28.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 46 |
+
"model.layers.28.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 47 |
+
"model.layers.29.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
|
| 48 |
+
"model.layers.29.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
|
| 49 |
+
"model.layers.29.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 50 |
+
"model.layers.29.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 51 |
+
"model.layers.31.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 52 |
+
"model.layers.32.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 53 |
+
"model.layers.32.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 54 |
+
"model.layers.34.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 55 |
+
"model.layers.36.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 56 |
+
"model.layers.36.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 57 |
+
"model.layers.36.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 58 |
+
"model.layers.37.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
|
| 59 |
+
"model.layers.37.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
|
| 60 |
+
"model.layers.37.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 61 |
+
"model.layers.37.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 62 |
+
"model.layers.39.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 63 |
+
"model.layers.2.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
|
| 64 |
+
"model.layers.26.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
|
| 65 |
+
"model.layers.0.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 66 |
+
"model.layers.1.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 67 |
+
"model.layers.1.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 68 |
+
"model.layers.3.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
|
| 69 |
+
"model.layers.3.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 70 |
+
"model.layers.5.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 71 |
+
"model.layers.7.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 72 |
+
"model.layers.9.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 73 |
+
"model.layers.9.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 74 |
+
"model.layers.9.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 75 |
+
"model.layers.10.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 76 |
+
"model.layers.10.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 77 |
+
"model.layers.12.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 78 |
+
"model.layers.14.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 79 |
+
"model.layers.14.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 80 |
+
"model.layers.15.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 81 |
+
"model.layers.15.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 82 |
+
"model.layers.17.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 83 |
+
"model.layers.18.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 84 |
+
"model.layers.18.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 85 |
+
"model.layers.20.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 86 |
+
"model.layers.21.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 87 |
+
"model.layers.21.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 88 |
+
"model.layers.23.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 89 |
+
"model.layers.25.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 90 |
+
"model.layers.25.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 91 |
+
"model.layers.25.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 92 |
+
"model.layers.26.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 93 |
+
"model.layers.26.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 94 |
+
"model.layers.27.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
|
| 95 |
+
"model.layers.28.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 96 |
+
"model.layers.29.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 97 |
+
"model.layers.29.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 98 |
+
"model.layers.31.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 99 |
+
"model.layers.33.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 100 |
+
"model.layers.33.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 101 |
+
"model.layers.33.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 102 |
+
"model.layers.34.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 103 |
+
"model.layers.34.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 104 |
+
"model.layers.36.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 105 |
+
"model.layers.37.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 106 |
+
"model.layers.37.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 107 |
+
"model.layers.39.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 108 |
+
"model.layers.3.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
|
| 109 |
+
"model.layers.27.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
|
| 110 |
+
"model.layers.0.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 111 |
+
"model.layers.1.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 112 |
+
"model.layers.1.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 113 |
+
"model.layers.2.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 114 |
+
"model.layers.2.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 115 |
+
"model.layers.2.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 116 |
+
"model.layers.3.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 117 |
+
"model.layers.3.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 118 |
+
"model.layers.4.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
|
| 119 |
+
"model.layers.5.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 120 |
+
"model.layers.5.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 121 |
+
"model.layers.6.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 122 |
+
"model.layers.6.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 123 |
+
"model.layers.8.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 124 |
+
"model.layers.9.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 125 |
+
"model.layers.9.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 126 |
+
"model.layers.11.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 127 |
+
"model.layers.13.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 128 |
+
"model.layers.13.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 129 |
+
"model.layers.13.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 130 |
+
"model.layers.14.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 131 |
+
"model.layers.14.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 132 |
+
"model.layers.16.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 133 |
+
"model.layers.17.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 134 |
+
"model.layers.17.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 135 |
+
"model.layers.19.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 136 |
+
"model.layers.21.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 137 |
+
"model.layers.21.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 138 |
+
"model.layers.21.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 139 |
+
"model.layers.22.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 140 |
+
"model.layers.22.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 141 |
+
"model.layers.24.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 142 |
+
"model.layers.25.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 143 |
+
"model.layers.25.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 144 |
+
"model.layers.27.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 145 |
+
"model.layers.28.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
|
| 146 |
+
"model.layers.29.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 147 |
+
"model.layers.29.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 148 |
+
"model.layers.29.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 149 |
+
"model.layers.30.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 150 |
+
"model.layers.30.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 151 |
+
"model.layers.32.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 152 |
+
"model.layers.33.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 153 |
+
"model.layers.33.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 154 |
+
"model.layers.35.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 155 |
+
"model.layers.37.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 156 |
+
"model.layers.37.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 157 |
+
"model.layers.37.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 158 |
+
"model.layers.38.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 159 |
+
"model.layers.38.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 160 |
+
"model.layers.4.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
|
| 161 |
+
"model.layers.28.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
|
| 162 |
+
"model.layers.5.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
|
| 163 |
+
"model.layers.0.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
|
| 164 |
+
"model.layers.0.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 165 |
+
"model.layers.2.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 166 |
+
"model.layers.4.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 167 |
+
"model.layers.6.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 168 |
+
"model.layers.8.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 169 |
+
"model.layers.8.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 170 |
+
"model.layers.8.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 171 |
+
"model.layers.9.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
|
| 172 |
+
"model.layers.9.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
|
| 173 |
+
"model.layers.9.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 174 |
+
"model.layers.9.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 175 |
+
"model.layers.11.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 176 |
+
"model.layers.12.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
|
| 177 |
+
"model.layers.12.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 178 |
+
"model.layers.12.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 179 |
+
"model.layers.14.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 180 |
+
"model.layers.16.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 181 |
+
"model.layers.16.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 182 |
+
"model.layers.16.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 183 |
+
"model.layers.17.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
|
| 184 |
+
"model.layers.17.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
|
| 185 |
+
"model.layers.17.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 186 |
+
"model.layers.17.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 187 |
+
"model.layers.19.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 188 |
+
"model.layers.20.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 189 |
+
"model.layers.20.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 190 |
+
"model.layers.22.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 191 |
+
"model.layers.23.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
|
| 192 |
+
"model.layers.23.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 193 |
+
"model.layers.23.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 194 |
+
"model.layers.25.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 195 |
+
"model.layers.27.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 196 |
+
"model.layers.27.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 197 |
+
"model.layers.27.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 198 |
+
"model.layers.28.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 199 |
+
"model.layers.28.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 200 |
+
"model.layers.30.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 201 |
+
"model.layers.32.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 202 |
+
"model.layers.32.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 203 |
+
"model.layers.32.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 204 |
+
"model.layers.33.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
|
| 205 |
+
"model.layers.33.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
|
| 206 |
+
"model.layers.33.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 207 |
+
"model.layers.33.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 208 |
+
"model.layers.35.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 209 |
+
"model.layers.36.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 210 |
+
"model.layers.36.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 211 |
+
"model.layers.38.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 212 |
+
"model.layers.0.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 213 |
+
"model.layers.1.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 214 |
+
"model.layers.1.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 215 |
+
"model.layers.2.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 216 |
+
"model.layers.2.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 217 |
+
"model.layers.2.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 218 |
+
"model.layers.3.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 219 |
+
"model.layers.3.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 220 |
+
"model.layers.5.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
|
| 221 |
+
"model.layers.5.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 222 |
+
"model.layers.5.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 223 |
+
"model.layers.6.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 224 |
+
"model.layers.6.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 225 |
+
"model.layers.8.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 226 |
+
"model.layers.10.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 227 |
+
"model.layers.10.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 228 |
+
"model.layers.10.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 229 |
+
"model.layers.11.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 230 |
+
"model.layers.11.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 231 |
+
"model.layers.13.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 232 |
+
"model.layers.14.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 233 |
+
"model.layers.14.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 234 |
+
"model.layers.16.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 235 |
+
"model.layers.18.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 236 |
+
"model.layers.18.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 237 |
+
"model.layers.18.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 238 |
+
"model.layers.19.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 239 |
+
"model.layers.19.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 240 |
+
"model.layers.21.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 241 |
+
"model.layers.22.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 242 |
+
"model.layers.22.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 243 |
+
"model.layers.24.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 244 |
+
"model.layers.26.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 245 |
+
"model.layers.26.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 246 |
+
"model.layers.26.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 247 |
+
"model.layers.27.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 248 |
+
"model.layers.27.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 249 |
+
"model.layers.29.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 250 |
+
"model.layers.30.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 251 |
+
"model.layers.30.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 252 |
+
"model.layers.32.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 253 |
+
"model.layers.34.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 254 |
+
"model.layers.34.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 255 |
+
"model.layers.34.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 256 |
+
"model.layers.35.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 257 |
+
"model.layers.35.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 258 |
+
"model.layers.37.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 259 |
+
"model.layers.38.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 260 |
+
"model.layers.38.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 261 |
+
"model.layers.6.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
|
| 262 |
+
"model.layers.30.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
|
| 263 |
+
"model.layers.6.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
|
| 264 |
+
"model.layers.30.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
|
| 265 |
+
"model.layers.7.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
|
| 266 |
+
"model.layers.31.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
|
| 267 |
+
"model.layers.7.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
|
| 268 |
+
"model.layers.31.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
|
| 269 |
+
"model.layers.8.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
|
| 270 |
+
"model.layers.32.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
|
| 271 |
+
"model.layers.2.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
|
| 272 |
+
"model.layers.14.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
|
| 273 |
+
"model.layers.14.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
|
| 274 |
+
"model.layers.22.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
|
| 275 |
+
"model.layers.22.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
|
| 276 |
+
"model.layers.38.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
|
| 277 |
+
"model.layers.38.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
|
| 278 |
+
"model.layers.0.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 279 |
+
"model.layers.1.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 280 |
+
"model.layers.3.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 281 |
+
"model.layers.3.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 282 |
+
"model.layers.3.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 283 |
+
"model.layers.4.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 284 |
+
"model.layers.4.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 285 |
+
"model.layers.4.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 286 |
+
"model.layers.4.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 287 |
+
"model.layers.5.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 288 |
+
"model.layers.5.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 289 |
+
"model.layers.7.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 290 |
+
"model.layers.7.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 291 |
+
"model.layers.7.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 292 |
+
"model.layers.8.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 293 |
+
"model.layers.8.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 294 |
+
"model.layers.10.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 295 |
+
"model.layers.11.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 296 |
+
"model.layers.11.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 297 |
+
"model.layers.13.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 298 |
+
"model.layers.15.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 299 |
+
"model.layers.15.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 300 |
+
"model.layers.15.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 301 |
+
"model.layers.16.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 302 |
+
"model.layers.16.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 303 |
+
"model.layers.18.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 304 |
+
"model.layers.19.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 305 |
+
"model.layers.19.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 306 |
+
"model.layers.21.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 307 |
+
"model.layers.23.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 308 |
+
"model.layers.23.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 309 |
+
"model.layers.23.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 310 |
+
"model.layers.24.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 311 |
+
"model.layers.24.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 312 |
+
"model.layers.26.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 313 |
+
"model.layers.27.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 314 |
+
"model.layers.27.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 315 |
+
"model.layers.29.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 316 |
+
"model.layers.31.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 317 |
+
"model.layers.31.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 318 |
+
"model.layers.31.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 319 |
+
"model.layers.32.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 320 |
+
"model.layers.32.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 321 |
+
"model.layers.34.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 322 |
+
"model.layers.35.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 323 |
+
"model.layers.35.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 324 |
+
"model.layers.37.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 325 |
+
"model.layers.39.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 326 |
+
"model.layers.39.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 327 |
+
"model.layers.39.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 328 |
+
"model.norm.weight": "model-00003-of-00003.safetensors",
|
| 329 |
+
"lm_head.weight": "model-00003-of-00003.safetensors",
|
| 330 |
+
"model.layers.0.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 331 |
+
"model.layers.2.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 332 |
+
"model.layers.4.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 333 |
+
"model.layers.6.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 334 |
+
"model.layers.6.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 335 |
+
"model.layers.7.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 336 |
+
"model.layers.7.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 337 |
+
"model.layers.8.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
|
| 338 |
+
"model.layers.9.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 339 |
+
"model.layers.10.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 340 |
+
"model.layers.10.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 341 |
+
"model.layers.12.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 342 |
+
"model.layers.13.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 343 |
+
"model.layers.13.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 344 |
+
"model.layers.15.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 345 |
+
"model.layers.17.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 346 |
+
"model.layers.17.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 347 |
+
"model.layers.17.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 348 |
+
"model.layers.18.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 349 |
+
"model.layers.18.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 350 |
+
"model.layers.20.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 351 |
+
"model.layers.22.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 352 |
+
"model.layers.22.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 353 |
+
"model.layers.22.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 354 |
+
"model.layers.23.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 355 |
+
"model.layers.23.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 356 |
+
"model.layers.25.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 357 |
+
"model.layers.26.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 358 |
+
"model.layers.26.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 359 |
+
"model.layers.28.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 360 |
+
"model.layers.30.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 361 |
+
"model.layers.30.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 362 |
+
"model.layers.30.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 363 |
+
"model.layers.31.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 364 |
+
"model.layers.31.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 365 |
+
"model.layers.32.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
|
| 366 |
+
"model.layers.33.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 367 |
+
"model.layers.34.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 368 |
+
"model.layers.34.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 369 |
+
"model.layers.36.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 370 |
+
"model.layers.38.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 371 |
+
"model.layers.38.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 372 |
+
"model.layers.38.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 373 |
+
"model.layers.39.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 374 |
+
"model.layers.39.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 375 |
+
"model.layers.10.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
|
| 376 |
+
"model.layers.34.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
|
| 377 |
+
"model.layers.10.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
|
| 378 |
+
"model.layers.34.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
|
| 379 |
+
"model.layers.0.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 380 |
+
"model.layers.2.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 381 |
+
"model.layers.4.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 382 |
+
"model.layers.6.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 383 |
+
"model.layers.6.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 384 |
+
"model.layers.7.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 385 |
+
"model.layers.7.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 386 |
+
"model.layers.9.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 387 |
+
"model.layers.11.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
|
| 388 |
+
"model.layers.11.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 389 |
+
"model.layers.11.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 390 |
+
"model.layers.11.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 391 |
+
"model.layers.28.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 392 |
+
"model.layers.28.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 393 |
+
"model.layers.30.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 394 |
+
"model.layers.31.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 395 |
+
"model.layers.31.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 396 |
+
"model.layers.33.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 397 |
+
"model.layers.35.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
|
| 398 |
+
"model.layers.35.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 399 |
+
"model.layers.35.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 400 |
+
"model.layers.35.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 401 |
+
"model.layers.36.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 402 |
+
"model.layers.36.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 403 |
+
"model.layers.38.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 404 |
+
"model.layers.39.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 405 |
+
"model.layers.39.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 406 |
+
"model.layers.16.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
|
| 407 |
+
"model.layers.16.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
|
| 408 |
+
"model.layers.24.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
|
| 409 |
+
"model.layers.24.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
|
| 410 |
+
"model.layers.11.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
|
| 411 |
+
"model.layers.12.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 412 |
+
"model.layers.12.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 413 |
+
"model.layers.14.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 414 |
+
"model.layers.14.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 415 |
+
"model.layers.15.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 416 |
+
"model.layers.15.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 417 |
+
"model.layers.17.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 418 |
+
"model.layers.19.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 419 |
+
"model.layers.19.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 420 |
+
"model.layers.19.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 421 |
+
"model.layers.20.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 422 |
+
"model.layers.20.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 423 |
+
"model.layers.22.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 424 |
+
"model.layers.24.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 425 |
+
"model.layers.24.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 426 |
+
"model.layers.24.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 427 |
+
"model.layers.25.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 428 |
+
"model.layers.25.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 429 |
+
"model.layers.27.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 430 |
+
"model.layers.35.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
|
| 431 |
+
"model.layers.12.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
|
| 432 |
+
"model.layers.36.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
|
| 433 |
+
"model.layers.36.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
|
| 434 |
+
"model.layers.0.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
|
| 435 |
+
"model.layers.15.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
|
| 436 |
+
"model.layers.20.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
|
| 437 |
+
"model.layers.20.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
|
| 438 |
+
"model.layers.25.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
|
| 439 |
+
"model.layers.25.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
|
| 440 |
+
"model.layers.15.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
|
| 441 |
+
"model.layers.39.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
|
| 442 |
+
"model.layers.39.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
|
| 443 |
+
"model.layers.18.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
|
| 444 |
+
"model.layers.18.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
|
| 445 |
+
"model.layers.23.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
|
| 446 |
+
"model.layers.19.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
|
| 447 |
+
"model.layers.19.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
|
| 448 |
+
"model.layers.26.self_attn.adaptive_phi": "model-00003-of-00003.safetensors"
|
| 449 |
+
}
|
| 450 |
+
}
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/modeling_evabyte.py
ADDED
|
@@ -0,0 +1,912 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional, Tuple, Union
|
| 2 |
+
import math
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
import torch.utils.checkpoint
|
| 6 |
+
from torch import nn
|
| 7 |
+
from torch.nn import CrossEntropyLoss
|
| 8 |
+
from transformers.activations import ACT2FN
|
| 9 |
+
from transformers.cache_utils import Cache
|
| 10 |
+
from transformers.modeling_outputs import (
|
| 11 |
+
BaseModelOutputWithPast,
|
| 12 |
+
CausalLMOutputWithPast,
|
| 13 |
+
)
|
| 14 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 15 |
+
|
| 16 |
+
from .configuration_evabyte import EvaByteConfig
|
| 17 |
+
from .multibyte_decoding_evabyte import MultiByteDecodingMixin
|
| 18 |
+
try:
|
| 19 |
+
import triton
|
| 20 |
+
USE_TRITON_IMPL = True
|
| 21 |
+
from .eva import EvaAttention
|
| 22 |
+
from .eva_agg_kernel import triton_eva_agg_fwd
|
| 23 |
+
from .eva_prep_kv_kernel import triton_eva_prep_kv_fwd
|
| 24 |
+
except ImportError:
|
| 25 |
+
USE_TRITON_IMPL = False
|
| 26 |
+
print("WARNING: triton is not installed, using fallback EVA which might be slow and throw errors")
|
| 27 |
+
from .eva_pt_ref import EvaAttention
|
| 28 |
+
from .eva_cache import EvaCache, EvaStaticCacheForTriton
|
| 29 |
+
|
| 30 |
+
MASK_MIN_VALUE = -10e10
|
| 31 |
+
|
| 32 |
+
def prepare_eva_attention_mask(
|
| 33 |
+
seq_len,
|
| 34 |
+
device,
|
| 35 |
+
chunk_size,
|
| 36 |
+
window_size,
|
| 37 |
+
use_cache=False,
|
| 38 |
+
cache=None
|
| 39 |
+
):
|
| 40 |
+
"""
|
| 41 |
+
Prepare attention masks for EVA.
|
| 42 |
+
|
| 43 |
+
"""
|
| 44 |
+
chunk_causal_mask = None
|
| 45 |
+
window_causal_mask = None
|
| 46 |
+
if use_cache:
|
| 47 |
+
cached_seq_len = cache.get_seq_length()
|
| 48 |
+
total_seq_len = seq_len + cached_seq_len
|
| 49 |
+
# cached_seq_len will be 0 during prefilling
|
| 50 |
+
# padded_seq_len = chunk_size * math.ceil(total_seq_len / chunk_size)
|
| 51 |
+
padded_seq_len = window_size * math.ceil(total_seq_len / window_size)
|
| 52 |
+
num_chunks = padded_seq_len // chunk_size
|
| 53 |
+
else:
|
| 54 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
| 55 |
+
assert seq_len % chunk_size == 0
|
| 56 |
+
num_chunks = seq_len // chunk_size
|
| 57 |
+
|
| 58 |
+
assert seq_len % window_size == 0
|
| 59 |
+
|
| 60 |
+
# create causal mask
|
| 61 |
+
################################
|
| 62 |
+
# generate chunked causal masks
|
| 63 |
+
################################
|
| 64 |
+
# [b, h, j, c, c]
|
| 65 |
+
chunks_per_window = window_size // chunk_size
|
| 66 |
+
if num_chunks >= chunks_per_window:
|
| 67 |
+
chunk_causal_mask = torch.ones(
|
| 68 |
+
(chunk_size, num_chunks, num_chunks),
|
| 69 |
+
device=device,
|
| 70 |
+
dtype=torch.bool
|
| 71 |
+
).triu(0)
|
| 72 |
+
|
| 73 |
+
num_blocks = num_chunks // chunks_per_window
|
| 74 |
+
chunk_causal_mask = chunk_causal_mask.reshape(
|
| 75 |
+
chunk_size,
|
| 76 |
+
num_blocks,
|
| 77 |
+
chunks_per_window,
|
| 78 |
+
num_blocks,
|
| 79 |
+
chunks_per_window
|
| 80 |
+
).transpose(-2, -3)
|
| 81 |
+
|
| 82 |
+
block_diag_zero = (
|
| 83 |
+
torch.eye(num_blocks, device=device, dtype=torch.bool)
|
| 84 |
+
.unsqueeze(-1)
|
| 85 |
+
.unsqueeze(-1)
|
| 86 |
+
.unsqueeze(0)
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
# Set diagonal blocks to zero
|
| 90 |
+
chunk_causal_mask = chunk_causal_mask.masked_fill(block_diag_zero, True)
|
| 91 |
+
|
| 92 |
+
# Reshape back to original size
|
| 93 |
+
chunk_causal_mask = (
|
| 94 |
+
chunk_causal_mask
|
| 95 |
+
.transpose(-2, -3)
|
| 96 |
+
.reshape(chunk_size, num_chunks, num_chunks)
|
| 97 |
+
.transpose(-2, -3)
|
| 98 |
+
.reshape(chunk_size * num_chunks, num_chunks)
|
| 99 |
+
.unsqueeze(0)
|
| 100 |
+
.unsqueeze(0)
|
| 101 |
+
)
|
| 102 |
+
else:
|
| 103 |
+
chunk_causal_mask = torch.ones(
|
| 104 |
+
(1, 1, chunk_size, num_chunks, num_chunks),
|
| 105 |
+
device=device,
|
| 106 |
+
dtype=torch.bool,
|
| 107 |
+
).triu(0).transpose(-2, -3) # [1, 1, c, j, c]
|
| 108 |
+
chunk_causal_mask = chunk_causal_mask.reshape(
|
| 109 |
+
1, 1, chunk_size * num_chunks, num_chunks
|
| 110 |
+
) # [1, 1, n, c]
|
| 111 |
+
|
| 112 |
+
if use_cache:
|
| 113 |
+
chunk_causal_mask = chunk_causal_mask[..., cached_seq_len : cached_seq_len + seq_len, :]
|
| 114 |
+
|
| 115 |
+
window_causal_mask = torch.ones(
|
| 116 |
+
(1, 1, 1, window_size, window_size),
|
| 117 |
+
device=device
|
| 118 |
+
).triu(1).to(torch.bool)
|
| 119 |
+
return (chunk_causal_mask, window_causal_mask)
|
| 120 |
+
|
| 121 |
+
def pad_to_multiple(tensor, multiple, dim=-2, value=0, create_mask=False, left_padding=False):
|
| 122 |
+
assert dim < 0 # only accept ``dim'' index in a reverse manner
|
| 123 |
+
seqlen = int(tensor.shape[dim])
|
| 124 |
+
m = seqlen / multiple
|
| 125 |
+
if m.is_integer():
|
| 126 |
+
if create_mask:
|
| 127 |
+
return tensor, torch.ones(size=(tensor.shape[0], tensor.shape[dim]), dtype=torch.bool, device=tensor.device)
|
| 128 |
+
else:
|
| 129 |
+
return tensor
|
| 130 |
+
remainder = math.ceil(m) * multiple - seqlen
|
| 131 |
+
pad_offset = (0,) * (-1 - dim) * 2
|
| 132 |
+
if left_padding:
|
| 133 |
+
padded_res = F.pad(tensor, (*pad_offset, remainder, 0), value=value)
|
| 134 |
+
else:
|
| 135 |
+
padded_res = F.pad(tensor, (*pad_offset, 0, remainder), value=value)
|
| 136 |
+
if create_mask:
|
| 137 |
+
# assume dim 0 is the batch size
|
| 138 |
+
padding_mask = torch.ones(size=(padded_res.shape[0], padded_res.shape[dim]), dtype=torch.bool, device=padded_res.device)
|
| 139 |
+
if left_padding:
|
| 140 |
+
padding_mask[:, :remainder] = False
|
| 141 |
+
else:
|
| 142 |
+
padding_mask[:, -remainder:] = False
|
| 143 |
+
return padded_res, padding_mask
|
| 144 |
+
else:
|
| 145 |
+
return padded_res
|
| 146 |
+
|
| 147 |
+
class EvaByteRMSNorm(nn.Module):
|
| 148 |
+
def __init__(self, config):
|
| 149 |
+
super().__init__()
|
| 150 |
+
self.config = config
|
| 151 |
+
self.fp32_ln = True
|
| 152 |
+
self.variance_epsilon = config.rms_norm_eps
|
| 153 |
+
self.add_unit_offset = config.norm_add_unit_offset
|
| 154 |
+
if self.add_unit_offset:
|
| 155 |
+
self.weight = nn.Parameter(torch.zeros(config.hidden_size))
|
| 156 |
+
else:
|
| 157 |
+
self.weight = nn.Parameter(torch.ones(config.hidden_size))
|
| 158 |
+
|
| 159 |
+
def forward(self, hidden_states):
|
| 160 |
+
_hidden_states = hidden_states.to(torch.float32 if self.fp32_ln else torch.bfloat16)
|
| 161 |
+
|
| 162 |
+
variance = _hidden_states.pow(2).mean(-1, keepdim=True)
|
| 163 |
+
_hidden_states = _hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
| 164 |
+
if self.add_unit_offset:
|
| 165 |
+
return ((1 + self.weight) * _hidden_states).type_as(hidden_states)
|
| 166 |
+
else:
|
| 167 |
+
return (self.weight * _hidden_states).type_as(hidden_states)
|
| 168 |
+
|
| 169 |
+
class EvaByteRotaryEmbedding(torch.nn.Module):
|
| 170 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
| 171 |
+
super().__init__()
|
| 172 |
+
|
| 173 |
+
self.dim = dim
|
| 174 |
+
self.max_position_embeddings = max_position_embeddings
|
| 175 |
+
self.base = base
|
| 176 |
+
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
|
| 177 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 178 |
+
|
| 179 |
+
self._set_cos_sin_cache(seq_len=max_position_embeddings,
|
| 180 |
+
device=self.inv_freq.device,
|
| 181 |
+
dtype=torch.get_default_dtype())
|
| 182 |
+
|
| 183 |
+
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
| 184 |
+
self.max_seq_len_cached = seq_len
|
| 185 |
+
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
| 186 |
+
|
| 187 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
| 188 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 189 |
+
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
| 190 |
+
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def forward(self, x, seq_len=None):
|
| 194 |
+
# x: [bs, num_attention_heads, seq_len, head_size]
|
| 195 |
+
if seq_len > self.max_seq_len_cached:
|
| 196 |
+
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
| 197 |
+
|
| 198 |
+
# return (
|
| 199 |
+
# self.cos_cached[:seq_len].to(dtype=x.dtype),
|
| 200 |
+
# self.sin_cached[:seq_len].to(dtype=x.dtype),
|
| 201 |
+
# )
|
| 202 |
+
if seq_len < self.max_seq_len_cached:
|
| 203 |
+
cos_slice = self.cos_cached.split(seq_len, dim=0)[0]
|
| 204 |
+
sin_slice = self.sin_cached.split(seq_len, dim=0)[0]
|
| 205 |
+
else:
|
| 206 |
+
cos_slice = self.cos_cached
|
| 207 |
+
sin_slice = self.sin_cached
|
| 208 |
+
|
| 209 |
+
return (
|
| 210 |
+
cos_slice.to(dtype=x.dtype),
|
| 211 |
+
sin_slice.to(dtype=x.dtype),
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
class EvaByteLinearScalingRotaryEmbedding(EvaByteRotaryEmbedding):
|
| 217 |
+
"""EvaByteRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
| 218 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
| 219 |
+
self.scaling_factor = scaling_factor
|
| 220 |
+
super().__init__(dim, max_position_embeddings, base, device)
|
| 221 |
+
|
| 222 |
+
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
| 223 |
+
self.max_seq_len_cached = seq_len
|
| 224 |
+
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
| 225 |
+
t = t / self.scaling_factor
|
| 226 |
+
|
| 227 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
| 228 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
| 229 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 230 |
+
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
| 231 |
+
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
class EvaByteDynamicNTKScalingRotaryEmbedding(EvaByteRotaryEmbedding):
|
| 235 |
+
"""EvaByteRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
| 236 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
| 237 |
+
self.scaling_factor = scaling_factor
|
| 238 |
+
super().__init__(dim, max_position_embeddings, base, device)
|
| 239 |
+
|
| 240 |
+
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
| 241 |
+
self.max_seq_len_cached = seq_len
|
| 242 |
+
|
| 243 |
+
if seq_len > self.max_position_embeddings:
|
| 244 |
+
base = self.base * ((self.scaling_factor * seq_len / self.max_position_embeddings) -
|
| 245 |
+
(self.scaling_factor - 1))**(self.dim / (self.dim - 2))
|
| 246 |
+
inv_freq = 1.0 / (base**(torch.arange(0, self.dim, 2).float().to(device) / self.dim))
|
| 247 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 248 |
+
|
| 249 |
+
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
| 250 |
+
|
| 251 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
| 252 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
| 253 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 254 |
+
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
| 255 |
+
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
class EvaByteMLP(nn.Module):
|
| 259 |
+
def __init__(self, config, layer_idx: int = None):
|
| 260 |
+
super().__init__()
|
| 261 |
+
self.hidden_size = config.hidden_size
|
| 262 |
+
self.intermediate_size = config.intermediate_size
|
| 263 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 264 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 265 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
| 266 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
| 267 |
+
self.layer_idx = layer_idx
|
| 268 |
+
self.config = config
|
| 269 |
+
|
| 270 |
+
def forward(self, x):
|
| 271 |
+
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
| 272 |
+
return down_proj
|
| 273 |
+
|
| 274 |
+
class EvaByteDecoderLayer(nn.Module):
|
| 275 |
+
def __init__(self, config: EvaByteConfig, layer_idx: int = None):
|
| 276 |
+
super().__init__()
|
| 277 |
+
self.config = config
|
| 278 |
+
self.hidden_size = config.hidden_size
|
| 279 |
+
self.self_attn = EvaAttention(config=config, layer_idx=layer_idx)
|
| 280 |
+
self.mlp = EvaByteMLP(config, layer_idx=layer_idx)
|
| 281 |
+
self.input_layernorm = EvaByteRMSNorm(config)
|
| 282 |
+
self.post_attention_layernorm = EvaByteRMSNorm(config)
|
| 283 |
+
|
| 284 |
+
def forward(
|
| 285 |
+
self,
|
| 286 |
+
hidden_states: torch.Tensor,
|
| 287 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 288 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 289 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
| 290 |
+
output_attentions: Optional[bool] = False,
|
| 291 |
+
use_cache: Optional[bool] = False,
|
| 292 |
+
cos: Optional[torch.Tensor] = None,
|
| 293 |
+
sin: Optional[torch.Tensor] = None,
|
| 294 |
+
multibyte_decoding: Optional[bool] = False,
|
| 295 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
| 296 |
+
residual = hidden_states
|
| 297 |
+
if self.config.fp32_skip_add:
|
| 298 |
+
residual = residual.float()
|
| 299 |
+
|
| 300 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 301 |
+
|
| 302 |
+
# Self Attention
|
| 303 |
+
hidden_states, self_attn_weights, present_key_value = self.self_attn(hidden_states=hidden_states,
|
| 304 |
+
attention_mask=attention_mask,
|
| 305 |
+
position_ids=position_ids,
|
| 306 |
+
past_key_value=past_key_value,
|
| 307 |
+
output_attentions=output_attentions,
|
| 308 |
+
use_cache=use_cache,
|
| 309 |
+
cos=cos,
|
| 310 |
+
sin=sin,
|
| 311 |
+
multibyte_decoding=multibyte_decoding)
|
| 312 |
+
hidden_states = (residual + hidden_states).to(hidden_states.dtype)
|
| 313 |
+
|
| 314 |
+
# Fully Connected
|
| 315 |
+
residual = hidden_states
|
| 316 |
+
if self.config.fp32_skip_add:
|
| 317 |
+
residual = residual.float()
|
| 318 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 319 |
+
hidden_states = self.mlp(hidden_states)
|
| 320 |
+
hidden_states = (residual + hidden_states).to(hidden_states.dtype)
|
| 321 |
+
|
| 322 |
+
outputs = (hidden_states, )
|
| 323 |
+
|
| 324 |
+
if output_attentions:
|
| 325 |
+
outputs += (self_attn_weights, )
|
| 326 |
+
|
| 327 |
+
if use_cache:
|
| 328 |
+
outputs += (present_key_value, )
|
| 329 |
+
return outputs
|
| 330 |
+
|
| 331 |
+
class EvaBytePreTrainedModel(PreTrainedModel):
|
| 332 |
+
config_class = EvaByteConfig
|
| 333 |
+
base_model_prefix = "model"
|
| 334 |
+
supports_gradient_checkpointing = True
|
| 335 |
+
_no_split_modules = ["EvaByteDecoderLayer"]
|
| 336 |
+
_skip_keys_device_placement = "past_key_values"
|
| 337 |
+
|
| 338 |
+
def _init_weights(self, module):
|
| 339 |
+
std = getattr(self.config, "initializer_range", 0.02)
|
| 340 |
+
if isinstance(module, nn.Linear):
|
| 341 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 342 |
+
if module.bias is not None:
|
| 343 |
+
module.bias.data.zero_()
|
| 344 |
+
elif isinstance(module, nn.Embedding):
|
| 345 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 346 |
+
if module.padding_idx is not None:
|
| 347 |
+
module.weight.data[module.padding_idx].zero_()
|
| 348 |
+
|
| 349 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
| 350 |
+
if isinstance(module, EvaByteModel):
|
| 351 |
+
module.gradient_checkpointing = value
|
| 352 |
+
|
| 353 |
+
class EvaByteModel(EvaBytePreTrainedModel):
|
| 354 |
+
"""
|
| 355 |
+
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`EvaByteDecoderLayer`]
|
| 356 |
+
|
| 357 |
+
Args:
|
| 358 |
+
config: EvaByteConfig
|
| 359 |
+
"""
|
| 360 |
+
def __init__(self, config: EvaByteConfig):
|
| 361 |
+
super().__init__(config)
|
| 362 |
+
self.padding_idx = config.pad_token_id
|
| 363 |
+
self.vocab_size = config.vocab_size
|
| 364 |
+
self.hidden_size = config.hidden_size
|
| 365 |
+
self.num_heads = config.num_attention_heads
|
| 366 |
+
self.head_dim = self.hidden_size // self.num_heads
|
| 367 |
+
self.max_position_embeddings = self.config.max_position_embeddings
|
| 368 |
+
|
| 369 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 370 |
+
self.layers = nn.ModuleList([EvaByteDecoderLayer(config, layer_idx=layer_idx) for layer_idx in range(config.num_hidden_layers)])
|
| 371 |
+
self.norm = EvaByteRMSNorm(config)
|
| 372 |
+
|
| 373 |
+
self.gradient_checkpointing = False
|
| 374 |
+
self.rope = config.rope_theta
|
| 375 |
+
# Initialize weights and apply final processing
|
| 376 |
+
self.post_init()
|
| 377 |
+
self._init_rope()
|
| 378 |
+
|
| 379 |
+
def _init_rope(self):
|
| 380 |
+
if self.config.rope_scaling is None:
|
| 381 |
+
self.rotary_emb = EvaByteRotaryEmbedding(self.head_dim,
|
| 382 |
+
max_position_embeddings=self.max_position_embeddings,
|
| 383 |
+
base=self.rope)
|
| 384 |
+
else:
|
| 385 |
+
scaling_type = self.config.rope_scaling["type"]
|
| 386 |
+
scaling_factor = self.config.rope_scaling["factor"]
|
| 387 |
+
if scaling_type == "linear":
|
| 388 |
+
self.rotary_emb = EvaByteLinearScalingRotaryEmbedding(
|
| 389 |
+
self.head_dim,
|
| 390 |
+
max_position_embeddings=self.max_position_embeddings,
|
| 391 |
+
scaling_factor=scaling_factor,
|
| 392 |
+
base=self.rope)
|
| 393 |
+
elif scaling_type == "dynamic":
|
| 394 |
+
self.rotary_emb = EvaByteDynamicNTKScalingRotaryEmbedding(
|
| 395 |
+
self.head_dim,
|
| 396 |
+
max_position_embeddings=self.max_position_embeddings,
|
| 397 |
+
scaling_factor=scaling_factor,
|
| 398 |
+
base=self.rope)
|
| 399 |
+
else:
|
| 400 |
+
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
| 401 |
+
|
| 402 |
+
def get_input_embeddings(self):
|
| 403 |
+
return self.embed_tokens
|
| 404 |
+
|
| 405 |
+
def set_input_embeddings(self, value):
|
| 406 |
+
self.embed_tokens = value
|
| 407 |
+
|
| 408 |
+
def _helper_padding_mask(
|
| 409 |
+
self,
|
| 410 |
+
padding_mask,
|
| 411 |
+
causal_mask
|
| 412 |
+
):
|
| 413 |
+
padding_mask = torch.logical_or(padding_mask, padding_mask.transpose(-1, -2))
|
| 414 |
+
return torch.logical_or(padding_mask, causal_mask)
|
| 415 |
+
|
| 416 |
+
def _prepare_eva_generation_attn_mask_triton(
|
| 417 |
+
self,
|
| 418 |
+
attention_mask,
|
| 419 |
+
input_ids,
|
| 420 |
+
use_cache,
|
| 421 |
+
past_key_values
|
| 422 |
+
):
|
| 423 |
+
batch_size, seq_len = input_ids.shape
|
| 424 |
+
if use_cache and past_key_values.get_seq_length() > 0:
|
| 425 |
+
# decoding phase
|
| 426 |
+
if past_key_values.rf_mask[0] is not None:
|
| 427 |
+
cur_rf_mask = torch.zeros(
|
| 428 |
+
(batch_size, 1, seq_len, 1),
|
| 429 |
+
dtype=past_key_values.rf_mask[0].dtype,
|
| 430 |
+
device=past_key_values.rf_mask[0].device
|
| 431 |
+
)
|
| 432 |
+
else:
|
| 433 |
+
cur_rf_mask = None
|
| 434 |
+
|
| 435 |
+
if past_key_values.s_mask[0] is not None:
|
| 436 |
+
cur_s_mask = torch.zeros(
|
| 437 |
+
(batch_size, 1, seq_len, 1),
|
| 438 |
+
dtype=past_key_values.s_mask[0].dtype,
|
| 439 |
+
device=past_key_values.s_mask[0].device
|
| 440 |
+
)
|
| 441 |
+
else:
|
| 442 |
+
cur_s_mask = None
|
| 443 |
+
|
| 444 |
+
seen_tokens = past_key_values.get_seq_length()
|
| 445 |
+
if seen_tokens <= self.config.window_size:
|
| 446 |
+
rfa_chunks_dummy_mask = None
|
| 447 |
+
else:
|
| 448 |
+
if cur_s_mask is not None:
|
| 449 |
+
chunks_per_window = int(self.config.window_size // self.config.chunk_size)
|
| 450 |
+
# the ongoing decoding step would be (seen_seq_len + 1)-th token
|
| 451 |
+
num_windows_seen_so_far = seen_tokens // self.config.window_size
|
| 452 |
+
rfa_chunks_dummy_mask = torch.zeros(
|
| 453 |
+
(batch_size, 1, seq_len, num_windows_seen_so_far * chunks_per_window),
|
| 454 |
+
dtype=past_key_values.s_mask[0].dtype,
|
| 455 |
+
device=past_key_values.s_mask[0].device
|
| 456 |
+
)
|
| 457 |
+
else:
|
| 458 |
+
rfa_chunks_dummy_mask = None
|
| 459 |
+
# rf_mask and cur_mask are 0s because we do not want to mask them
|
| 460 |
+
return (cur_s_mask, cur_rf_mask, rfa_chunks_dummy_mask)
|
| 461 |
+
|
| 462 |
+
if attention_mask is not None and torch.any(attention_mask == 0.0):
|
| 463 |
+
# convert 0 -> padding to 1 -> padding
|
| 464 |
+
padded_attention_mask = pad_to_multiple(
|
| 465 |
+
attention_mask,
|
| 466 |
+
self.config.window_size,
|
| 467 |
+
dim=-1,
|
| 468 |
+
value=0,
|
| 469 |
+
create_mask=False,
|
| 470 |
+
left_padding=False
|
| 471 |
+
)
|
| 472 |
+
# convert 0 -> padding to 1 -> padding
|
| 473 |
+
padded_rf_mask = ~padded_attention_mask.unsqueeze(1).unsqueeze(-1).to(torch.bool) # [b, 1, n, 1]
|
| 474 |
+
# [b, 1, w, j, 1]
|
| 475 |
+
padded_w_attn_mask = padded_rf_mask.reshape(batch_size, 1, -1, self.config.window_size, 1).to(torch.bool)
|
| 476 |
+
# [b, 1, w, j, 1] [b, 1, w, 1, j] -> [b, 1, w, j, j]
|
| 477 |
+
w_padding_mask = torch.logical_or(padded_w_attn_mask, padded_w_attn_mask.transpose(-1, -2))
|
| 478 |
+
w_causal_mask = torch.ones(
|
| 479 |
+
(1, 1, 1, self.config.window_size, self.config.window_size),
|
| 480 |
+
device=input_ids.device
|
| 481 |
+
).triu(1).to(torch.bool)
|
| 482 |
+
s_mask = torch.logical_or(w_padding_mask, w_causal_mask)
|
| 483 |
+
s_mask = s_mask.reshape(batch_size, 1, -1, self.config.window_size)
|
| 484 |
+
s_mask = s_mask[..., :seq_len, :]
|
| 485 |
+
# negate the attention mask to get the padding mask
|
| 486 |
+
rf_mask = ~attention_mask.unsqueeze(1).unsqueeze(-1).to(torch.bool) # [b, 1, n, 1]
|
| 487 |
+
return (s_mask, rf_mask)
|
| 488 |
+
else:
|
| 489 |
+
return (None, None)
|
| 490 |
+
|
| 491 |
+
def _prepare_eva_generation_attn_mask(
|
| 492 |
+
self,
|
| 493 |
+
attention_mask,
|
| 494 |
+
input_ids,
|
| 495 |
+
use_cache,
|
| 496 |
+
past_key_values
|
| 497 |
+
):
|
| 498 |
+
batch_size, seq_len = input_ids.shape
|
| 499 |
+
if use_cache and past_key_values.get_seq_length() > 0:
|
| 500 |
+
# decoding phase
|
| 501 |
+
if past_key_values.rf_mask[0] is not None:
|
| 502 |
+
rf_mask = torch.zeros(
|
| 503 |
+
(batch_size, 1, seq_len, 1),
|
| 504 |
+
dtype=past_key_values.rf_mask[0].dtype,
|
| 505 |
+
device=past_key_values.rf_mask[0].device
|
| 506 |
+
)
|
| 507 |
+
else:
|
| 508 |
+
rf_mask = None
|
| 509 |
+
|
| 510 |
+
cur_causal_mask = torch.zeros(
|
| 511 |
+
(batch_size, 1, seq_len, 1),
|
| 512 |
+
dtype=torch.bool,
|
| 513 |
+
device=input_ids.device
|
| 514 |
+
)
|
| 515 |
+
|
| 516 |
+
chunk_causal_mask = torch.ones(
|
| 517 |
+
(batch_size, 1, seq_len, 1),
|
| 518 |
+
dtype=torch.bool,
|
| 519 |
+
device=input_ids.device
|
| 520 |
+
)
|
| 521 |
+
# chunk_causal_mask are 1s because we will mask them by default and
|
| 522 |
+
# will be unmasked when the current singleton attention is processed over
|
| 523 |
+
return (None, cur_causal_mask, chunk_causal_mask, rf_mask)
|
| 524 |
+
|
| 525 |
+
true_num_chunks = seq_len // self.config.chunk_size
|
| 526 |
+
chunk_causal_mask, _ = prepare_eva_attention_mask(
|
| 527 |
+
seq_len,
|
| 528 |
+
input_ids.device,
|
| 529 |
+
self.config.chunk_size,
|
| 530 |
+
self.config.window_size,
|
| 531 |
+
use_cache=use_cache,
|
| 532 |
+
cache=past_key_values
|
| 533 |
+
)
|
| 534 |
+
chunk_causal_mask = chunk_causal_mask[..., :seq_len, :true_num_chunks]
|
| 535 |
+
if attention_mask is not None and torch.any(attention_mask == 0.0):
|
| 536 |
+
# convert 0 -> padding to 1 -> padding
|
| 537 |
+
rf_mask = ~attention_mask.unsqueeze(1).unsqueeze(-1).to(torch.bool) # [b, 1, n, 1]
|
| 538 |
+
else:
|
| 539 |
+
rf_mask = None
|
| 540 |
+
|
| 541 |
+
if seq_len < self.config.window_size:
|
| 542 |
+
cur_window_mask = torch.ones(
|
| 543 |
+
(1, 1, seq_len, seq_len),
|
| 544 |
+
device=input_ids.device
|
| 545 |
+
).triu(1).to(torch.bool)
|
| 546 |
+
if rf_mask is not None:
|
| 547 |
+
cur_window_mask = self._helper_padding_mask(rf_mask, cur_window_mask)
|
| 548 |
+
prev_window_mask = None
|
| 549 |
+
else:
|
| 550 |
+
if seq_len % self.config.window_size == 0:
|
| 551 |
+
num_windows = seq_len // self.config.window_size
|
| 552 |
+
cur_window_mask = None
|
| 553 |
+
prev_window_mask = torch.ones(
|
| 554 |
+
(1, 1, num_windows, self.config.window_size, self.config.window_size),
|
| 555 |
+
device=input_ids.device
|
| 556 |
+
).triu(1).to(torch.bool)
|
| 557 |
+
if rf_mask is not None:
|
| 558 |
+
prev_rf_mask = rf_mask.reshape(batch_size, 1, -1, self.config.window_size, 1)
|
| 559 |
+
prev_window_mask = self._helper_padding_mask(prev_rf_mask, prev_window_mask)
|
| 560 |
+
else:
|
| 561 |
+
num_windows = seq_len // self.config.window_size
|
| 562 |
+
remainder_tokens = seq_len % self.config.window_size
|
| 563 |
+
cur_window_mask = torch.ones(
|
| 564 |
+
(1, 1, remainder_tokens, remainder_tokens),
|
| 565 |
+
device=input_ids.device
|
| 566 |
+
).triu(1).to(torch.bool)
|
| 567 |
+
prev_window_mask = torch.ones(
|
| 568 |
+
(1, 1, num_windows, self.config.window_size, self.config.window_size),
|
| 569 |
+
device=input_ids.device
|
| 570 |
+
).triu(1).to(torch.bool)
|
| 571 |
+
if rf_mask is not None:
|
| 572 |
+
prev_rf_mask, cur_rf_mask = torch.split(rf_mask, [seq_len - remainder_tokens, remainder_tokens], dim=-2)
|
| 573 |
+
cur_window_mask = self._helper_padding_mask(cur_rf_mask, cur_window_mask)
|
| 574 |
+
prev_rf_mask = prev_rf_mask.reshape(batch_size, 1, -1, self.config.window_size, 1)
|
| 575 |
+
prev_window_mask = self._helper_padding_mask(prev_rf_mask, prev_window_mask)
|
| 576 |
+
|
| 577 |
+
return (prev_window_mask, cur_window_mask, chunk_causal_mask, rf_mask)
|
| 578 |
+
|
| 579 |
+
def forward(
|
| 580 |
+
self,
|
| 581 |
+
input_ids: torch.LongTensor = None,
|
| 582 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 583 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 584 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 585 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 586 |
+
use_cache: Optional[bool] = None,
|
| 587 |
+
output_attentions: Optional[bool] = None,
|
| 588 |
+
output_hidden_states: Optional[bool] = None,
|
| 589 |
+
return_dict: Optional[bool] = None,
|
| 590 |
+
multibyte_decoding: Optional[bool] = None,
|
| 591 |
+
) -> Tuple:
|
| 592 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 593 |
+
output_hidden_states = (output_hidden_states
|
| 594 |
+
if output_hidden_states is not None else self.config.output_hidden_states)
|
| 595 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 596 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 597 |
+
|
| 598 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 599 |
+
raise ValueError(
|
| 600 |
+
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
|
| 601 |
+
)
|
| 602 |
+
|
| 603 |
+
if self.gradient_checkpointing and self.training and use_cache:
|
| 604 |
+
raise ValueError("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
|
| 605 |
+
|
| 606 |
+
batch_size, seq_len = input_ids.shape
|
| 607 |
+
#### Step 0. Hack
|
| 608 |
+
if (not self.training) and (not use_cache) and (not multibyte_decoding):
|
| 609 |
+
# forward-only inference mode.
|
| 610 |
+
# We tweak use_cache to be True to reuse code for generation
|
| 611 |
+
use_cache = True
|
| 612 |
+
device = input_ids.device if input_ids is not None else None
|
| 613 |
+
if position_ids is None:
|
| 614 |
+
position_ids = torch.arange(0, seq_len, device=device, dtype=int).reshape(1, -1).expand(batch_size, -1)
|
| 615 |
+
|
| 616 |
+
#### Step 1. Prepare caches if in inference mode
|
| 617 |
+
if use_cache:
|
| 618 |
+
if past_key_values is not None:
|
| 619 |
+
assert isinstance(past_key_values, Cache)
|
| 620 |
+
else:
|
| 621 |
+
if not USE_TRITON_IMPL:
|
| 622 |
+
past_key_values = EvaCache()
|
| 623 |
+
else:
|
| 624 |
+
past_key_values = EvaStaticCacheForTriton(
|
| 625 |
+
input_ids.shape[0],
|
| 626 |
+
self.config.num_attention_heads,
|
| 627 |
+
self.config.window_size,
|
| 628 |
+
self.config.hidden_size // self.config.num_attention_heads,
|
| 629 |
+
self.config.num_hidden_layers,
|
| 630 |
+
self.embed_tokens.weight.dtype,
|
| 631 |
+
self.embed_tokens.weight.device,
|
| 632 |
+
)
|
| 633 |
+
|
| 634 |
+
if not multibyte_decoding:
|
| 635 |
+
if use_cache:
|
| 636 |
+
if USE_TRITON_IMPL:
|
| 637 |
+
causal_mask = self._prepare_eva_generation_attn_mask_triton(
|
| 638 |
+
attention_mask,
|
| 639 |
+
input_ids,
|
| 640 |
+
use_cache,
|
| 641 |
+
past_key_values
|
| 642 |
+
)
|
| 643 |
+
else:
|
| 644 |
+
causal_mask = self._prepare_eva_generation_attn_mask(
|
| 645 |
+
attention_mask,
|
| 646 |
+
input_ids,
|
| 647 |
+
use_cache,
|
| 648 |
+
past_key_values
|
| 649 |
+
)
|
| 650 |
+
else:
|
| 651 |
+
assert self.training
|
| 652 |
+
assert seq_len % self.config.window_size == 0, "Training is only tested for sequences that are a multiple of window_size"
|
| 653 |
+
# for training, we need to pass in the attention mask
|
| 654 |
+
# usually calculated by _prepare_training_attn_mask()
|
| 655 |
+
causal_mask = attention_mask
|
| 656 |
+
else:
|
| 657 |
+
assert use_cache
|
| 658 |
+
causal_mask = attention_mask
|
| 659 |
+
|
| 660 |
+
if inputs_embeds is None:
|
| 661 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
| 662 |
+
|
| 663 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 664 |
+
max_seq_length = past_seen_tokens + inputs_embeds.shape[1]
|
| 665 |
+
|
| 666 |
+
hidden_states = inputs_embeds
|
| 667 |
+
|
| 668 |
+
if position_ids is None:
|
| 669 |
+
assert not use_cache, "during decoding we must explicitly pass position_ids to the model call"
|
| 670 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
| 671 |
+
position_ids = torch.arange(past_seen_tokens, max_seq_length, device=device, dtype=int).reshape(1, -1).expand(batch_size, -1)
|
| 672 |
+
|
| 673 |
+
cos, sin = self.rotary_emb(hidden_states, seq_len=max_seq_length)
|
| 674 |
+
assert len(cos.shape) == 2, f"cos should be of shape (max_seq_len, head_dim), got {cos.shape} instead"
|
| 675 |
+
assert sin.shape == cos.shape, f"sin should be of shape (max_seq_len, head_dim), got {sin.shape} instead"
|
| 676 |
+
assert len(position_ids.shape) == 2, f"position_ids should be of 2D, got {position_ids.shape} instead"
|
| 677 |
+
cos = cos[position_ids, :]
|
| 678 |
+
sin = sin[position_ids, :]
|
| 679 |
+
cos = cos.unsqueeze(1)
|
| 680 |
+
sin = sin.unsqueeze(1)
|
| 681 |
+
|
| 682 |
+
# decoder layers
|
| 683 |
+
all_hidden_states = () if output_hidden_states else None
|
| 684 |
+
all_self_attns = () if output_attentions else None
|
| 685 |
+
next_decoder_cache = None
|
| 686 |
+
|
| 687 |
+
for decoder_layer in self.layers:
|
| 688 |
+
if output_hidden_states:
|
| 689 |
+
all_hidden_states += (hidden_states, )
|
| 690 |
+
|
| 691 |
+
if self.gradient_checkpointing and self.training:
|
| 692 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
| 693 |
+
decoder_layer.__call__,
|
| 694 |
+
hidden_states,
|
| 695 |
+
causal_mask,
|
| 696 |
+
position_ids,
|
| 697 |
+
past_key_values,
|
| 698 |
+
output_attentions,
|
| 699 |
+
use_cache,
|
| 700 |
+
cos,
|
| 701 |
+
sin,
|
| 702 |
+
multibyte_decoding,
|
| 703 |
+
)
|
| 704 |
+
else:
|
| 705 |
+
layer_outputs = decoder_layer(
|
| 706 |
+
hidden_states,
|
| 707 |
+
attention_mask=causal_mask,
|
| 708 |
+
position_ids=position_ids,
|
| 709 |
+
past_key_value=past_key_values,
|
| 710 |
+
output_attentions=output_attentions,
|
| 711 |
+
use_cache=use_cache,
|
| 712 |
+
cos=cos,
|
| 713 |
+
sin=sin,
|
| 714 |
+
multibyte_decoding=multibyte_decoding,
|
| 715 |
+
)
|
| 716 |
+
|
| 717 |
+
hidden_states = layer_outputs[0]
|
| 718 |
+
|
| 719 |
+
if use_cache:
|
| 720 |
+
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
| 721 |
+
|
| 722 |
+
if output_attentions:
|
| 723 |
+
all_self_attns += (layer_outputs[1], )
|
| 724 |
+
|
| 725 |
+
hidden_states = self.norm(hidden_states)
|
| 726 |
+
|
| 727 |
+
# add hidden states from the last decoder layer
|
| 728 |
+
if output_hidden_states:
|
| 729 |
+
all_hidden_states += (hidden_states, )
|
| 730 |
+
|
| 731 |
+
next_cache = next_decoder_cache if use_cache else None
|
| 732 |
+
if not return_dict:
|
| 733 |
+
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
| 734 |
+
|
| 735 |
+
return BaseModelOutputWithPast(
|
| 736 |
+
last_hidden_state=hidden_states,
|
| 737 |
+
past_key_values=next_cache,
|
| 738 |
+
hidden_states=all_hidden_states,
|
| 739 |
+
attentions=all_self_attns,
|
| 740 |
+
)
|
| 741 |
+
|
| 742 |
+
|
| 743 |
+
class EvaByteForCausalLM(EvaBytePreTrainedModel, MultiByteDecodingMixin):
|
| 744 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 745 |
+
|
| 746 |
+
def __init__(self, config):
|
| 747 |
+
EvaBytePreTrainedModel.__init__(self, config)
|
| 748 |
+
|
| 749 |
+
self.model = EvaByteModel(config)
|
| 750 |
+
self.vocab_size = config.vocab_size
|
| 751 |
+
# define multibyte prediction heads
|
| 752 |
+
if hasattr(config, "num_pred_heads") and config.num_pred_heads > 1:
|
| 753 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size * config.num_pred_heads, bias=False)
|
| 754 |
+
else:
|
| 755 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 756 |
+
|
| 757 |
+
self.post_init()
|
| 758 |
+
|
| 759 |
+
def get_input_embeddings(self):
|
| 760 |
+
return self.model.embed_tokens
|
| 761 |
+
|
| 762 |
+
def set_input_embeddings(self, value):
|
| 763 |
+
self.model.embed_tokens = value
|
| 764 |
+
|
| 765 |
+
def get_output_embeddings(self):
|
| 766 |
+
return self.lm_head
|
| 767 |
+
|
| 768 |
+
def set_output_embeddings(self, new_embeddings):
|
| 769 |
+
self.lm_head = new_embeddings
|
| 770 |
+
|
| 771 |
+
def set_decoder(self, decoder):
|
| 772 |
+
self.model = decoder
|
| 773 |
+
|
| 774 |
+
def get_decoder(self):
|
| 775 |
+
return self.model
|
| 776 |
+
|
| 777 |
+
def forward(
|
| 778 |
+
self,
|
| 779 |
+
input_ids: torch.LongTensor = None,
|
| 780 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 781 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 782 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 783 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 784 |
+
labels: Optional[torch.LongTensor] = None,
|
| 785 |
+
use_cache: Optional[bool] = None,
|
| 786 |
+
output_attentions: Optional[bool] = None,
|
| 787 |
+
output_hidden_states: Optional[bool] = None,
|
| 788 |
+
return_dict: Optional[bool] = None,
|
| 789 |
+
return_all_pred_logits: Optional[bool] = None,
|
| 790 |
+
multibyte_decoding: Optional[bool] = None) -> Union[Tuple, CausalLMOutputWithPast]:
|
| 791 |
+
|
| 792 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 793 |
+
output_hidden_states = (output_hidden_states
|
| 794 |
+
if output_hidden_states is not None else self.config.output_hidden_states)
|
| 795 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 796 |
+
|
| 797 |
+
if input_ids is None:
|
| 798 |
+
assert past_key_values is None
|
| 799 |
+
|
| 800 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
| 801 |
+
outputs = self.model(
|
| 802 |
+
input_ids=input_ids,
|
| 803 |
+
attention_mask=attention_mask,
|
| 804 |
+
position_ids=position_ids,
|
| 805 |
+
past_key_values=past_key_values,
|
| 806 |
+
inputs_embeds=inputs_embeds,
|
| 807 |
+
use_cache=use_cache,
|
| 808 |
+
output_attentions=output_attentions,
|
| 809 |
+
output_hidden_states=output_hidden_states,
|
| 810 |
+
return_dict=return_dict,
|
| 811 |
+
multibyte_decoding=multibyte_decoding,
|
| 812 |
+
)
|
| 813 |
+
|
| 814 |
+
hidden_states = outputs[0]
|
| 815 |
+
|
| 816 |
+
logits = self.lm_head(hidden_states)
|
| 817 |
+
if self.config.fp32_logits:
|
| 818 |
+
logits = logits.float()
|
| 819 |
+
|
| 820 |
+
loss = None
|
| 821 |
+
if labels is not None:
|
| 822 |
+
loss_fct = CrossEntropyLoss(reduction="none")
|
| 823 |
+
if hasattr(self.config, "num_pred_heads") and self.config.num_pred_heads > 1:
|
| 824 |
+
shift_logits = logits.view(logits.shape[0], logits.shape[1], self.config.num_pred_heads, self.config.vocab_size)
|
| 825 |
+
# shift_logits = shift_logits.view(-1, logits.shape[1] * self.config.num_pred_heads, self.config.vocab_size)
|
| 826 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
| 827 |
+
else:
|
| 828 |
+
shift_logits = logits.view(-1, self.config.vocab_size)
|
| 829 |
+
shift_labels = labels.view(-1)
|
| 830 |
+
# Enable model parallelism
|
| 831 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
| 832 |
+
loss = loss_fct(shift_logits, shift_labels)
|
| 833 |
+
|
| 834 |
+
if hasattr(self.config, "num_pred_heads") and self.config.num_pred_heads > 1:
|
| 835 |
+
all_pred_logits = logits.reshape(logits.shape[0], logits.shape[1], self.config.num_pred_heads, self.config.vocab_size)
|
| 836 |
+
|
| 837 |
+
if return_all_pred_logits:
|
| 838 |
+
logits = all_pred_logits
|
| 839 |
+
else:
|
| 840 |
+
logits = all_pred_logits[..., 0, :]
|
| 841 |
+
|
| 842 |
+
if not return_dict:
|
| 843 |
+
output = (logits, ) + outputs[1:]
|
| 844 |
+
return (loss, ) + output if loss is not None else output
|
| 845 |
+
|
| 846 |
+
return CausalLMOutputWithPast(
|
| 847 |
+
loss=loss,
|
| 848 |
+
logits=logits,
|
| 849 |
+
past_key_values=outputs.past_key_values,
|
| 850 |
+
hidden_states=outputs.hidden_states,
|
| 851 |
+
attentions=outputs.attentions,
|
| 852 |
+
)
|
| 853 |
+
|
| 854 |
+
|
| 855 |
+
def prepare_inputs_for_generation(self,
|
| 856 |
+
input_ids,
|
| 857 |
+
past_key_values=None,
|
| 858 |
+
attention_mask=None,
|
| 859 |
+
inputs_embeds=None,
|
| 860 |
+
use_cache=True,
|
| 861 |
+
**kwargs):
|
| 862 |
+
# prefill phase:
|
| 863 |
+
# input_ids: b x s
|
| 864 |
+
# attention_mask: None if no padding or b x s
|
| 865 |
+
# position_ids : b x s
|
| 866 |
+
|
| 867 |
+
# token gen phase:
|
| 868 |
+
# input_ids : b x 1
|
| 869 |
+
# attention_mask: b x 1 x s
|
| 870 |
+
# position_ids: b x 1
|
| 871 |
+
past_length = 0
|
| 872 |
+
if past_key_values is not None:
|
| 873 |
+
assert isinstance(past_key_values, Cache)
|
| 874 |
+
past_length = past_key_values.get_seq_length()
|
| 875 |
+
|
| 876 |
+
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
| 877 |
+
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
|
| 878 |
+
elif past_length < input_ids.shape[1]:
|
| 879 |
+
input_ids = input_ids[:, past_length:]
|
| 880 |
+
|
| 881 |
+
position_ids = kwargs.get("position_ids", None)
|
| 882 |
+
if attention_mask is not None and position_ids is None:
|
| 883 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
| 884 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
| 885 |
+
if past_key_values:
|
| 886 |
+
position_ids = position_ids[:, -input_ids.shape[1]:]
|
| 887 |
+
|
| 888 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
| 889 |
+
if inputs_embeds is not None and past_key_values is None:
|
| 890 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
| 891 |
+
else:
|
| 892 |
+
model_inputs = {"input_ids": input_ids}
|
| 893 |
+
|
| 894 |
+
# must initialize position_ids at each step during GPU inference
|
| 895 |
+
assert position_ids is not None
|
| 896 |
+
model_inputs.update(
|
| 897 |
+
{
|
| 898 |
+
"position_ids": position_ids,
|
| 899 |
+
"past_key_values": past_key_values,
|
| 900 |
+
"use_cache": use_cache,
|
| 901 |
+
"attention_mask": attention_mask,
|
| 902 |
+
}
|
| 903 |
+
)
|
| 904 |
+
return model_inputs
|
| 905 |
+
|
| 906 |
+
@staticmethod
|
| 907 |
+
def _reorder_cache(past_key_values, beam_idx):
|
| 908 |
+
reordered_past = ()
|
| 909 |
+
for layer_past in past_key_values:
|
| 910 |
+
reordered_past += (tuple(
|
| 911 |
+
past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), )
|
| 912 |
+
return reordered_past
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/multibyte_decoding_evabyte.py
ADDED
|
@@ -0,0 +1,881 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# The implementation of multibyte deocidng is largely adapted from
|
| 3 |
+
# Medusa decoding: https://github.com/FasterDecoding/Medusa
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from transformers.generation.stopping_criteria import (
|
| 7 |
+
MaxLengthCriteria,
|
| 8 |
+
StoppingCriteriaList,
|
| 9 |
+
)
|
| 10 |
+
from typing import Union, List
|
| 11 |
+
from .eva_cache import EvaStaticCacheForTriton
|
| 12 |
+
from .eva_prep_kv_kernel import triton_eva_prep_kv_fwd
|
| 13 |
+
|
| 14 |
+
class MultibyteEosTokenCriteria:
|
| 15 |
+
"""
|
| 16 |
+
This class implements a simple stopping criteria to stop generation whenever
|
| 17 |
+
the "end-of-sequence" token is generated in the last `new_tokens` tokens.
|
| 18 |
+
|
| 19 |
+
Adapted from
|
| 20 |
+
https://github.com/huggingface/transformers/blob/main/src/transformers/generation/stopping_criteria.py#L446
|
| 21 |
+
By default, it uses the `model.generation_config.eos_token_id`.
|
| 22 |
+
|
| 23 |
+
Args:
|
| 24 |
+
eos_token_id (`Union[int, List[int]]`):
|
| 25 |
+
The id(s) of the *end-of-sequence* token.
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
def __init__(self, eos_token_ids: Union[int, List[int]]):
|
| 29 |
+
if isinstance(eos_token_ids, int):
|
| 30 |
+
eos_token_ids = [eos_token_ids]
|
| 31 |
+
self.eos_token_ids = eos_token_ids
|
| 32 |
+
|
| 33 |
+
def __call__(self, input_ids: torch.LongTensor, new_tokens: int) -> bool:
|
| 34 |
+
current_input_len = input_ids.shape[-1]
|
| 35 |
+
new_token_ids = input_ids[:, current_input_len - new_tokens:]
|
| 36 |
+
for eos_token_id in self.eos_token_ids:
|
| 37 |
+
if torch.any(new_token_ids == eos_token_id):
|
| 38 |
+
return True
|
| 39 |
+
return False
|
| 40 |
+
|
| 41 |
+
def build_tree(spec):
|
| 42 |
+
nodes_at_depth = []
|
| 43 |
+
nodes_at_depth.append([()]) # Root at depth 1
|
| 44 |
+
|
| 45 |
+
for d in range(1, len(spec) + 1):
|
| 46 |
+
prev_nodes = nodes_at_depth[d - 1]
|
| 47 |
+
spec_list = spec[d - 1]
|
| 48 |
+
current_nodes = []
|
| 49 |
+
for node_idx, node in enumerate(prev_nodes):
|
| 50 |
+
if node_idx < len(spec_list):
|
| 51 |
+
num_children = spec_list[node_idx]
|
| 52 |
+
else:
|
| 53 |
+
num_children = 0
|
| 54 |
+
for child_idx in range(num_children):
|
| 55 |
+
new_node = node + (child_idx,)
|
| 56 |
+
current_nodes.append(new_node)
|
| 57 |
+
nodes_at_depth.append(current_nodes)
|
| 58 |
+
|
| 59 |
+
# Flatten the list of nodes, excluding the root node if desired
|
| 60 |
+
all_nodes = [node for depth_nodes in nodes_at_depth for node in depth_nodes if node]
|
| 61 |
+
return all_nodes
|
| 62 |
+
|
| 63 |
+
evabyte_7b_95 = build_tree(
|
| 64 |
+
[
|
| 65 |
+
[10],
|
| 66 |
+
[10, 8, 2, 2, 1, 1],
|
| 67 |
+
[10, 4, 2, 1, 0, 0, 0, 0, 0, 0, 2, 1, 1, 0, 0, 0, 0, 0, 1],
|
| 68 |
+
[8, 2, 2, 1, 0, 0, 0, 0, 0, 0, 1],
|
| 69 |
+
[6, 2, 1, 1],
|
| 70 |
+
[4, 2, 1, 1],
|
| 71 |
+
[4, 2, 1],
|
| 72 |
+
]
|
| 73 |
+
)
|
| 74 |
+
evabyte_7b_31 = build_tree(
|
| 75 |
+
[
|
| 76 |
+
[4],
|
| 77 |
+
[3, 2, 1, 1],
|
| 78 |
+
[3, 2, 1, 1],
|
| 79 |
+
[2, 1, 1],
|
| 80 |
+
[2, 1],
|
| 81 |
+
[2, 1],
|
| 82 |
+
[2, 1],
|
| 83 |
+
]
|
| 84 |
+
)
|
| 85 |
+
TOPK = 10 # topk for sparse tree (10 is a placeholder and it is sufficient)
|
| 86 |
+
|
| 87 |
+
def pad_path(path, length, pad_value=-2):
|
| 88 |
+
"""
|
| 89 |
+
Pad the given path list with a specific value up to a specified length.
|
| 90 |
+
|
| 91 |
+
Parameters:
|
| 92 |
+
- path (list): The original list that needs padding.
|
| 93 |
+
- length (int): The desired length of the padded list.
|
| 94 |
+
- pad_value (optional, default=-2): The value to use for padding.
|
| 95 |
+
|
| 96 |
+
Returns:
|
| 97 |
+
- list: A new list based on the original path but padded to the desired length.
|
| 98 |
+
|
| 99 |
+
Example:
|
| 100 |
+
>>> pad_path([1,2,3], 5)
|
| 101 |
+
[1, 2, 3, -2, -2]
|
| 102 |
+
|
| 103 |
+
Note:
|
| 104 |
+
If the given path is already longer than the specified length,
|
| 105 |
+
then no padding occurs, and the original path is returned.
|
| 106 |
+
"""
|
| 107 |
+
return path + [pad_value] * (length - len(path))
|
| 108 |
+
|
| 109 |
+
def reset_past_key_values(passed_key_values):
|
| 110 |
+
"""
|
| 111 |
+
Resets the current lengths in the passed key-values to zero.
|
| 112 |
+
|
| 113 |
+
This function is designed to be used during the evaluation of a baseline model.
|
| 114 |
+
It iterates through each layer's key-values and sets their current lengths to zero,
|
| 115 |
+
effectively resetting their state.
|
| 116 |
+
|
| 117 |
+
Args:
|
| 118 |
+
- passed_key_values (list of torch.Tensor): Contains past hidden states and past attention values for each layer.
|
| 119 |
+
|
| 120 |
+
Returns:
|
| 121 |
+
- passed_key_values (list of torch.Tensor): Updated past hidden states and past attention values with reset lengths.
|
| 122 |
+
"""
|
| 123 |
+
for i in range(len(passed_key_values)):
|
| 124 |
+
for j in range(2):
|
| 125 |
+
passed_key_values[i][j].current_length.fill_(0)
|
| 126 |
+
return passed_key_values
|
| 127 |
+
|
| 128 |
+
def get_nucleus_one_token(logit, temperature, top_p):
|
| 129 |
+
"""
|
| 130 |
+
Performs token sampling based on the nucleus (top-p) sampling method.
|
| 131 |
+
|
| 132 |
+
This function selects a token from a given logit distribution using the nucleus sampling strategy.
|
| 133 |
+
It allows for more controlled and diverse generation compared to traditional top-k sampling.
|
| 134 |
+
|
| 135 |
+
Args:
|
| 136 |
+
logit (torch.Tensor): The logits from a language model output, expected to be a 2D tensor (BxC).
|
| 137 |
+
temperature (float): A temperature parameter to control the randomness in sampling.
|
| 138 |
+
Higher values increase diversity, lower values make selections more deterministic.
|
| 139 |
+
top_p (float): The cumulative probability threshold for nucleus sampling.
|
| 140 |
+
It controls the size of the set of high-probability tokens to consider for sampling.
|
| 141 |
+
|
| 142 |
+
Returns:
|
| 143 |
+
torch.Tensor: A tensor containing the indices of the sampled tokens.
|
| 144 |
+
"""
|
| 145 |
+
if top_p >= 1:
|
| 146 |
+
return torch.multinomial(F.softmax(logit / temperature, dim=-1), 1)
|
| 147 |
+
logit = logit / temperature
|
| 148 |
+
probs = torch.softmax(logit, dim=-1)
|
| 149 |
+
sorted_logits, sorted_indices = torch.sort(probs, descending=True)
|
| 150 |
+
cum_probs = torch.cumsum(sorted_logits, dim=-1)
|
| 151 |
+
sorted_indices_to_remove = cum_probs > top_p
|
| 152 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
| 153 |
+
sorted_indices_to_remove[..., 0] = 0
|
| 154 |
+
indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
|
| 155 |
+
logit[indices_to_remove] = float('-inf')
|
| 156 |
+
sampled_tokens = torch.multinomial(F.softmax(logit, dim=-1), 1)
|
| 157 |
+
return sampled_tokens
|
| 158 |
+
|
| 159 |
+
def get_typical_one_token(logit, temperature, posterior_threshold, posterior_alpha):
|
| 160 |
+
"""
|
| 161 |
+
Implements token sampling based on the typical sampling method.
|
| 162 |
+
|
| 163 |
+
This function selects a token from a given logit distribution using the typical sampling strategy,
|
| 164 |
+
aiming to balance between diversity and likelihood in a more nuanced way compared to traditional methods.
|
| 165 |
+
|
| 166 |
+
Args:
|
| 167 |
+
logit (torch.Tensor): The logits from a language model output, expected to be a 2D tensor.
|
| 168 |
+
temperature (float): A parameter to control the randomness in sampling.
|
| 169 |
+
Higher values increase diversity, lower values make selections more deterministic.
|
| 170 |
+
posterior_threshold (float): A threshold to decide the lower bound of probabilities to be considered for sampling.
|
| 171 |
+
posterior_alpha (float): A scaling factor applied to the entropy-based adaptive threshold.
|
| 172 |
+
|
| 173 |
+
Returns:
|
| 174 |
+
torch.Tensor: A tensor containing the indices of the sampled tokens.
|
| 175 |
+
"""
|
| 176 |
+
logit = logit / temperature
|
| 177 |
+
probs = torch.softmax(logit, dim=-1)
|
| 178 |
+
entropy = -torch.sum(
|
| 179 |
+
probs * torch.log(probs + 1e-5), dim=-1
|
| 180 |
+
)
|
| 181 |
+
threshold = torch.minimum(
|
| 182 |
+
torch.ones_like(entropy) * posterior_threshold,
|
| 183 |
+
torch.exp(-entropy) * posterior_alpha,
|
| 184 |
+
)
|
| 185 |
+
indices_to_remove = probs < threshold.unsqueeze(-1)
|
| 186 |
+
logit[indices_to_remove] = float('-inf')
|
| 187 |
+
sampled_tokens = torch.multinomial(F.softmax(logit, dim=-1), 1)
|
| 188 |
+
return sampled_tokens
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def generate_medusa_buffers(medusa_choices, device="cuda"):
|
| 193 |
+
"""
|
| 194 |
+
Generate buffers for the Medusa structure based on the provided choices.
|
| 195 |
+
|
| 196 |
+
Parameters:
|
| 197 |
+
- medusa_choices (list): A nested list representing tree in the Medusa structure.
|
| 198 |
+
- device (str): Device to which the tensors should be moved. Default is "cuda".
|
| 199 |
+
|
| 200 |
+
Returns:
|
| 201 |
+
- dict: A dictionary containing buffers related to the Medusa structure.
|
| 202 |
+
"""
|
| 203 |
+
|
| 204 |
+
# Sort the medusa_choices based on their lengths and then their values
|
| 205 |
+
sorted_medusa_choices = sorted(medusa_choices, key=lambda x: (len(x), x))
|
| 206 |
+
medusa_len = len(sorted_medusa_choices) + 1
|
| 207 |
+
|
| 208 |
+
# Initialize depth_counts to keep track of how many choices have a particular depth
|
| 209 |
+
depth_counts = [0] * max([len(path) for path in sorted_medusa_choices])
|
| 210 |
+
for path in sorted_medusa_choices:
|
| 211 |
+
depth_counts[len(path) - 1] += 1
|
| 212 |
+
|
| 213 |
+
# Create the attention mask for Medusa
|
| 214 |
+
medusa_attn_mask = torch.eye(medusa_len, medusa_len)
|
| 215 |
+
medusa_attn_mask[:, 0] = 1
|
| 216 |
+
start = 0
|
| 217 |
+
for i in range(len(depth_counts)):
|
| 218 |
+
for j in range(depth_counts[i]):
|
| 219 |
+
cur_medusa_choice = sorted_medusa_choices[start + j]
|
| 220 |
+
# retrieve ancestor position
|
| 221 |
+
if len(cur_medusa_choice) == 1:
|
| 222 |
+
continue
|
| 223 |
+
ancestor_idx = []
|
| 224 |
+
for c in range(len(cur_medusa_choice) - 1):
|
| 225 |
+
ancestor_idx.append(sorted_medusa_choices.index(cur_medusa_choice[:c+1]) + 1)
|
| 226 |
+
medusa_attn_mask[j + start + 1, ancestor_idx] = 1
|
| 227 |
+
start += depth_counts[i]
|
| 228 |
+
|
| 229 |
+
# Generate tree indices for the Medusa structure
|
| 230 |
+
medusa_tree_indices = torch.zeros(medusa_len, dtype=torch.long)
|
| 231 |
+
medusa_tree_indices[0] = 0
|
| 232 |
+
start = 0
|
| 233 |
+
for i in range(len(depth_counts)):
|
| 234 |
+
for j in range(depth_counts[i]):
|
| 235 |
+
cur_medusa_choice = sorted_medusa_choices[start + j]
|
| 236 |
+
medusa_tree_indices[start + j + 1] = cur_medusa_choice[-1] + TOPK * i + 1
|
| 237 |
+
start += depth_counts[i]
|
| 238 |
+
|
| 239 |
+
# Generate position IDs for the Medusa structure
|
| 240 |
+
medusa_position_ids = torch.zeros(medusa_len, dtype=torch.long)
|
| 241 |
+
start = 0
|
| 242 |
+
for i in range(len(depth_counts)):
|
| 243 |
+
medusa_position_ids[start + 1: start + depth_counts[i] + 1] = i + 1
|
| 244 |
+
start += depth_counts[i]
|
| 245 |
+
|
| 246 |
+
# Generate retrieval indices for Medusa structure verification
|
| 247 |
+
retrieve_indices_nest = []
|
| 248 |
+
retrieve_paths = []
|
| 249 |
+
for i in range(len(sorted_medusa_choices)):
|
| 250 |
+
cur_medusa_choice = sorted_medusa_choices[-i-1]
|
| 251 |
+
retrieve_indice = []
|
| 252 |
+
if cur_medusa_choice in retrieve_paths:
|
| 253 |
+
continue
|
| 254 |
+
else:
|
| 255 |
+
for c in range(len(cur_medusa_choice)):
|
| 256 |
+
retrieve_indice.append(sorted_medusa_choices.index(cur_medusa_choice[:c+1]))
|
| 257 |
+
retrieve_paths.append(cur_medusa_choice[:c+1])
|
| 258 |
+
retrieve_indices_nest.append(retrieve_indice)
|
| 259 |
+
max_length = max([len(x) for x in retrieve_indices_nest])
|
| 260 |
+
retrieve_indices = [pad_path(path, max_length) for path in retrieve_indices_nest]
|
| 261 |
+
retrieve_indices = torch.tensor(retrieve_indices, dtype=torch.long)
|
| 262 |
+
retrieve_indices = retrieve_indices + 1
|
| 263 |
+
retrieve_indices = torch.cat([torch.zeros((retrieve_indices.shape[0], 1), dtype=torch.long), retrieve_indices], dim=1)
|
| 264 |
+
|
| 265 |
+
# Aggregate the generated buffers into a dictionary
|
| 266 |
+
medusa_buffers = {
|
| 267 |
+
"medusa_attn_mask": medusa_attn_mask.unsqueeze(0).unsqueeze(0),
|
| 268 |
+
"tree_indices": medusa_tree_indices,
|
| 269 |
+
"medusa_position_ids": medusa_position_ids.unsqueeze(0),
|
| 270 |
+
"retrieve_indices": retrieve_indices,
|
| 271 |
+
}
|
| 272 |
+
|
| 273 |
+
# Move the tensors in the dictionary to the specified device
|
| 274 |
+
medusa_buffers = {
|
| 275 |
+
k: v.clone().to(device)
|
| 276 |
+
if isinstance(v, torch.Tensor)
|
| 277 |
+
else torch.tensor(v, device=device)
|
| 278 |
+
for k, v in medusa_buffers.items()
|
| 279 |
+
}
|
| 280 |
+
return medusa_buffers
|
| 281 |
+
|
| 282 |
+
def generate_candidates(
|
| 283 |
+
medusa_logits,
|
| 284 |
+
logits,
|
| 285 |
+
tree_indices,
|
| 286 |
+
retrieve_indices,
|
| 287 |
+
temperature = 0,
|
| 288 |
+
posterior_threshold=0.3,
|
| 289 |
+
posterior_alpha = 0.09,
|
| 290 |
+
top_p=0.8,
|
| 291 |
+
sampling = 'typical',
|
| 292 |
+
fast = False
|
| 293 |
+
):
|
| 294 |
+
# Say we have 3 heads, and the top-4 for each head are:
|
| 295 |
+
# [10, 3, 8, 4]
|
| 296 |
+
# [9, 5, 1, 6]
|
| 297 |
+
# [7, 16, 3, 2]
|
| 298 |
+
|
| 299 |
+
# candidates_id = 10
|
| 300 |
+
if temperature == 0 or fast:
|
| 301 |
+
candidates_ids = torch.argmax(logits[:, -1]).unsqueeze(0)
|
| 302 |
+
else:
|
| 303 |
+
if sampling == 'typical':
|
| 304 |
+
candidates_ids = get_typical_one_token(logits[:, -1], temperature, posterior_threshold, posterior_alpha).squeeze(0)
|
| 305 |
+
elif sampling == 'nucleus':
|
| 306 |
+
candidates_ids = get_nucleus_one_token(logits[:, -1], temperature, top_p).squeeze(0)
|
| 307 |
+
else:
|
| 308 |
+
raise NotImplementedError
|
| 309 |
+
|
| 310 |
+
# this calculates the top-k medusa logits
|
| 311 |
+
# candidates_medusa_id = [
|
| 312 |
+
# [9, 5, 1, 6]
|
| 313 |
+
# [7, 16, 3, 2]
|
| 314 |
+
# ]
|
| 315 |
+
candidates_medusa_ids = torch.topk(medusa_logits[:, 0, -1], TOPK, dim=-1).indices
|
| 316 |
+
|
| 317 |
+
# [10, 9, 5, 1, 6, 7, 16, 3, 2]
|
| 318 |
+
candidate_ids = torch.cat([candidates_ids, candidates_medusa_ids.view(-1)], dim=-1)
|
| 319 |
+
|
| 320 |
+
# based on the pre-defined tree_indices, select the corresponding candidates
|
| 321 |
+
# if we select top-2 and top-3 for the two heads (we select top-1 for the first head):
|
| 322 |
+
# tree_candidates = [10, 9, 5, 7, 16, 3, 7, 16, 3]
|
| 323 |
+
tree_candidate_ids = candidate_ids[tree_indices]
|
| 324 |
+
|
| 325 |
+
# tree_candidate_ids = [10, 9, 5, 7, 16, 3, 7, 16, 3, 0]
|
| 326 |
+
# Sometimes the tree_indices are padded, so we append a zero here
|
| 327 |
+
# so that all padded indices select the appended zero.
|
| 328 |
+
tree_candidate_ids_ext = torch.cat(
|
| 329 |
+
[
|
| 330 |
+
tree_candidate_ids,
|
| 331 |
+
torch.zeros((1), dtype=torch.long, device=tree_candidate_ids.device)
|
| 332 |
+
],
|
| 333 |
+
dim=0
|
| 334 |
+
)
|
| 335 |
+
# [[10, 9, 7], [10, 9, 16], [10, 9, 3], [10, 5, 7], [10, 5, 16], [10, 5, 3]]
|
| 336 |
+
unflattened_candidate_ids = tree_candidate_ids_ext[retrieve_indices]
|
| 337 |
+
|
| 338 |
+
tree_candidate_ids = tree_candidate_ids.unsqueeze(0)
|
| 339 |
+
|
| 340 |
+
return tree_candidate_ids, unflattened_candidate_ids
|
| 341 |
+
|
| 342 |
+
def get_nucleus_posterior_mask(logits, candidates, temperature, top_p):
|
| 343 |
+
"""
|
| 344 |
+
Generates a posterior mask for token candidates using nucleus (top-p) sampling.
|
| 345 |
+
|
| 346 |
+
This function applies nucleus sampling to a set of logits, and then generates a mask indicating
|
| 347 |
+
which candidate tokens are selected. It adapts the sampling strategy to accommodate for
|
| 348 |
+
temperature scaling and cumulative probability thresholding.
|
| 349 |
+
|
| 350 |
+
Args:
|
| 351 |
+
logits (torch.Tensor): A tensor of logits from a language model output.
|
| 352 |
+
candidates (torch.Tensor): A tensor of candidate tokens to compare against sampled tokens.
|
| 353 |
+
temperature (float): A parameter to scale the logits, controlling randomness in sampling.
|
| 354 |
+
top_p (float): The cumulative probability threshold for nucleus sampling.
|
| 355 |
+
|
| 356 |
+
Returns:
|
| 357 |
+
torch.Tensor: A posterior mask indicating which candidate tokens match the sampled tokens.
|
| 358 |
+
"""
|
| 359 |
+
# adapted from https://github.com/huggingface/transformers/blob/18a879f47576822aa1a5c49aecb27d89bfa5fa69/examples/run_generation.py#L79
|
| 360 |
+
|
| 361 |
+
# Apply temperature
|
| 362 |
+
logits = logits[:, :-1] / temperature
|
| 363 |
+
n_samples, n_tokens = logits.shape[0], logits.shape[1]
|
| 364 |
+
logits = logits.view(n_samples*n_tokens, -1)
|
| 365 |
+
if top_p >= 1:
|
| 366 |
+
sampled_tokens = torch.multinomial(F.softmax(logits, dim=-1), 1)
|
| 367 |
+
sampled_tokens = sampled_tokens.view(n_samples, n_tokens)
|
| 368 |
+
posterior_mask = (candidates[:, 1:] == sampled_tokens).int()
|
| 369 |
+
return posterior_mask
|
| 370 |
+
# Convert to probabilities (softmax)
|
| 371 |
+
probs = F.softmax(logits, dim=-1)
|
| 372 |
+
# Sort the probabilities
|
| 373 |
+
sorted_logits, sorted_indices = torch.sort(probs, descending=True)
|
| 374 |
+
|
| 375 |
+
# Compute cumulative probabilities
|
| 376 |
+
cum_probs = torch.cumsum(sorted_logits, dim=-1)
|
| 377 |
+
|
| 378 |
+
# Create mask for the top-p nucleus
|
| 379 |
+
sorted_indices_to_remove = cum_probs > top_p
|
| 380 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
| 381 |
+
sorted_indices_to_remove[..., 0] = 0
|
| 382 |
+
|
| 383 |
+
indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
# Remove low-probability tokens
|
| 387 |
+
logits[indices_to_remove] = float('-inf')
|
| 388 |
+
# Sample from the remaining tokens
|
| 389 |
+
sampled_tokens = torch.multinomial(F.softmax(logits, dim=-1), 1)
|
| 390 |
+
sampled_tokens = sampled_tokens.view(n_samples, n_tokens)
|
| 391 |
+
# Create a mask for selected tokens
|
| 392 |
+
posterior_mask = (candidates[:, 1:] == sampled_tokens).int()
|
| 393 |
+
|
| 394 |
+
return posterior_mask
|
| 395 |
+
|
| 396 |
+
def get_typical_posterior_mask(logits, candidates, temperature, posterior_threshold, posterior_alpha):
|
| 397 |
+
"""
|
| 398 |
+
Args:
|
| 399 |
+
logits (torch.Tensor): A tensor of logits from a language model output.
|
| 400 |
+
candidates (torch.Tensor): A tensor of candidate tokens to compare against sampled tokens.
|
| 401 |
+
temperature (float): A parameter to scale the logits, controlling randomness in sampling.
|
| 402 |
+
posterior_threshold (float): The minimum threshold for probabilities to be considered in sampling.
|
| 403 |
+
posterior_alpha (float): A scaling factor applied to the entropy-based adaptive threshold.
|
| 404 |
+
|
| 405 |
+
Returns:
|
| 406 |
+
torch.Tensor: A posterior mask indicating which candidate tokens match the sampled tokens.
|
| 407 |
+
"""
|
| 408 |
+
logits = logits[:, :-1] / temperature
|
| 409 |
+
n_samples, n_tokens = logits.shape[0], logits.shape[1]
|
| 410 |
+
logits = logits.view(n_samples*n_tokens, -1)
|
| 411 |
+
probs = F.softmax(logits, dim=-1)
|
| 412 |
+
entropy = -torch.sum(
|
| 413 |
+
probs * torch.log(probs + 1e-5), dim=-1
|
| 414 |
+
)
|
| 415 |
+
threshold = torch.minimum(
|
| 416 |
+
torch.ones_like(entropy) * posterior_threshold,
|
| 417 |
+
torch.exp(-entropy) * posterior_alpha,
|
| 418 |
+
)
|
| 419 |
+
indices_to_remove = probs < threshold.unsqueeze(-1)
|
| 420 |
+
logits[indices_to_remove] = float('-inf')
|
| 421 |
+
sampled_tokens = torch.multinomial(F.softmax(logits, dim=-1), 1)
|
| 422 |
+
sampled_tokens = sampled_tokens.view(n_samples, n_tokens)
|
| 423 |
+
posterior_mask = (candidates[:, 1:] == sampled_tokens).int()
|
| 424 |
+
return posterior_mask
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
def evaluate_posterior(
|
| 429 |
+
logits,
|
| 430 |
+
candidates,
|
| 431 |
+
temperature,
|
| 432 |
+
posterior_threshold=0.3,
|
| 433 |
+
posterior_alpha = 0.09,
|
| 434 |
+
top_p=0.8,
|
| 435 |
+
sampling = 'typical',
|
| 436 |
+
fast = True
|
| 437 |
+
):
|
| 438 |
+
if logits.shape[1] <= 1:
|
| 439 |
+
return torch.tensor(0, dtype=torch.long, device=candidates.device), 0
|
| 440 |
+
# Greedy decoding based on temperature value
|
| 441 |
+
if temperature == 0:
|
| 442 |
+
# Find the tokens that match the maximum logits for each position in the sequence
|
| 443 |
+
posterior_mask = (
|
| 444 |
+
candidates[:, 1:] == torch.argmax(logits[:, :-1], dim=-1)
|
| 445 |
+
).int()
|
| 446 |
+
candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
|
| 447 |
+
accept_length = candidates_accept_length.max().item()
|
| 448 |
+
# Choose the best candidate
|
| 449 |
+
if accept_length == 0:
|
| 450 |
+
# Default to the first candidate if none are accepted
|
| 451 |
+
best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
|
| 452 |
+
else:
|
| 453 |
+
best_candidate = torch.argmax(candidates_accept_length).to(torch.long)
|
| 454 |
+
return best_candidate, accept_length
|
| 455 |
+
elif sampling == 'typical':
|
| 456 |
+
if fast:
|
| 457 |
+
posterior_prob = torch.softmax(logits[:, :-1] / temperature, dim=-1)
|
| 458 |
+
candidates_prob = torch.gather(
|
| 459 |
+
posterior_prob, dim=-1, index=candidates[:, 1:].unsqueeze(-1)
|
| 460 |
+
).squeeze(-1)
|
| 461 |
+
posterior_entropy = -torch.sum(
|
| 462 |
+
posterior_prob * torch.log(posterior_prob + 1e-5), dim=-1
|
| 463 |
+
) # torch.sum(torch.log(*)) is faster than torch.prod
|
| 464 |
+
threshold = torch.minimum(
|
| 465 |
+
torch.ones_like(posterior_entropy) * posterior_threshold,
|
| 466 |
+
torch.exp(-posterior_entropy) * posterior_alpha,
|
| 467 |
+
)
|
| 468 |
+
posterior_mask = candidates_prob > threshold
|
| 469 |
+
candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
|
| 470 |
+
|
| 471 |
+
# Choose the best candidate based on the evaluated posterior probabilities
|
| 472 |
+
accept_length = candidates_accept_length.max().item()
|
| 473 |
+
if accept_length == 0:
|
| 474 |
+
# If no candidates are accepted, just choose the first one
|
| 475 |
+
best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
|
| 476 |
+
else:
|
| 477 |
+
best_candidates = torch.where(candidates_accept_length == accept_length)[0]
|
| 478 |
+
# Accept the best one according to likelihood
|
| 479 |
+
likelihood = torch.sum(
|
| 480 |
+
torch.log(candidates_prob[best_candidates, :accept_length]), dim=-1
|
| 481 |
+
)
|
| 482 |
+
best_candidate = best_candidates[torch.argmax(likelihood)]
|
| 483 |
+
return best_candidate, accept_length
|
| 484 |
+
# Calculate posterior probabilities and thresholds for candidate selection
|
| 485 |
+
posterior_mask = get_typical_posterior_mask(logits, candidates, temperature, posterior_threshold, posterior_alpha)
|
| 486 |
+
candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
|
| 487 |
+
# Choose the best candidate based on the evaluated posterior probabilities
|
| 488 |
+
accept_length = candidates_accept_length.max().item()
|
| 489 |
+
|
| 490 |
+
if accept_length == 0:
|
| 491 |
+
# If no candidates are accepted, just choose the first one
|
| 492 |
+
best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
|
| 493 |
+
else:
|
| 494 |
+
best_candidate = torch.argmax(candidates_accept_length).to(torch.long)
|
| 495 |
+
# Accept the best one according to likelihood
|
| 496 |
+
return best_candidate, accept_length
|
| 497 |
+
elif sampling == 'nucleus':
|
| 498 |
+
assert top_p < 1.0 + 1e-6, "top_p should between 0 and 1"
|
| 499 |
+
posterior_mask = get_nucleus_posterior_mask(logits, candidates, temperature, top_p)
|
| 500 |
+
candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
|
| 501 |
+
accept_length = candidates_accept_length.max().item()
|
| 502 |
+
# Choose the best candidate
|
| 503 |
+
if accept_length == 0:
|
| 504 |
+
# Default to the first candidate if none are accepted
|
| 505 |
+
best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
|
| 506 |
+
else:
|
| 507 |
+
best_candidate = torch.argmax(candidates_accept_length).to(torch.long)
|
| 508 |
+
return best_candidate, accept_length
|
| 509 |
+
else:
|
| 510 |
+
raise NotImplementedError
|
| 511 |
+
|
| 512 |
+
def update_inference_inputs(
|
| 513 |
+
input_ids,
|
| 514 |
+
medusa_logits,
|
| 515 |
+
logits,
|
| 516 |
+
candidate_ids,
|
| 517 |
+
best_candidate,
|
| 518 |
+
accept_length,
|
| 519 |
+
):
|
| 520 |
+
input_ids = torch.cat(
|
| 521 |
+
[
|
| 522 |
+
input_ids,
|
| 523 |
+
candidate_ids[None, best_candidate, : accept_length + 1]
|
| 524 |
+
],
|
| 525 |
+
dim=-1
|
| 526 |
+
)
|
| 527 |
+
logits = logits[
|
| 528 |
+
None, best_candidate, accept_length : accept_length + 1
|
| 529 |
+
]
|
| 530 |
+
medusa_logits = medusa_logits[
|
| 531 |
+
:, None, best_candidate, accept_length : accept_length + 1
|
| 532 |
+
]
|
| 533 |
+
# Update the new token counter
|
| 534 |
+
new_token = accept_length + 1
|
| 535 |
+
return input_ids, medusa_logits, logits, new_token
|
| 536 |
+
|
| 537 |
+
def split_logits(full_logits):
|
| 538 |
+
# logits has shape [b, n, heads, vocab_size]
|
| 539 |
+
logits = full_logits[..., 0, :]
|
| 540 |
+
medusa_logits = full_logits[..., 1:, :].permute(2, 0, 1, 3)
|
| 541 |
+
return medusa_logits, logits
|
| 542 |
+
|
| 543 |
+
class MultiByteDecodingMixin:
|
| 544 |
+
def multi_byte_pred_update_cache(
|
| 545 |
+
self,
|
| 546 |
+
past_key_values,
|
| 547 |
+
retrieve_indices,
|
| 548 |
+
best_candidate,
|
| 549 |
+
new_tokens,
|
| 550 |
+
):
|
| 551 |
+
prev_window_len = past_key_values.get_past_window_pos(0)
|
| 552 |
+
select_indices = (
|
| 553 |
+
retrieve_indices[best_candidate, : new_tokens] + prev_window_len
|
| 554 |
+
)
|
| 555 |
+
for layer_idx in range(self.config.num_hidden_layers):
|
| 556 |
+
|
| 557 |
+
past_key_values.update_past_len(new_tokens, layer_idx)
|
| 558 |
+
|
| 559 |
+
past_window_k = past_key_values.past_window_k[layer_idx]
|
| 560 |
+
past_window_v = past_key_values.past_window_v[layer_idx]
|
| 561 |
+
|
| 562 |
+
tgt_window_k = past_window_k[..., select_indices, :]
|
| 563 |
+
tgt_window_v = past_window_v[..., select_indices, :]
|
| 564 |
+
|
| 565 |
+
dst_window_k = past_window_k[..., prev_window_len : prev_window_len + new_tokens, :]
|
| 566 |
+
dst_window_v = past_window_v[..., prev_window_len : prev_window_len + new_tokens, :]
|
| 567 |
+
|
| 568 |
+
dst_window_k.copy_(tgt_window_k, non_blocking=True)
|
| 569 |
+
dst_window_v.copy_(tgt_window_v, non_blocking=True)
|
| 570 |
+
|
| 571 |
+
new_window_len = prev_window_len + new_tokens
|
| 572 |
+
if new_window_len >= self.config.window_size:
|
| 573 |
+
assert new_window_len < 2 * self.config.window_size
|
| 574 |
+
|
| 575 |
+
dump_k = past_window_k[..., :self.config.window_size, :].clone()
|
| 576 |
+
dump_v = past_window_v[..., :self.config.window_size, :].clone()
|
| 577 |
+
|
| 578 |
+
_window_len = new_window_len - self.config.window_size
|
| 579 |
+
|
| 580 |
+
if _window_len > 0:
|
| 581 |
+
new_window_k = past_window_k[..., self.config.window_size : new_window_len, :]
|
| 582 |
+
new_window_v = past_window_v[..., self.config.window_size : new_window_len, :]
|
| 583 |
+
|
| 584 |
+
_dst_window_k = past_window_k[..., : _window_len, :]
|
| 585 |
+
_dst_window_v = past_window_v[..., : _window_len, :]
|
| 586 |
+
|
| 587 |
+
_dst_window_k.copy_(new_window_k, non_blocking=True)
|
| 588 |
+
_dst_window_v.copy_(new_window_v, non_blocking=True)
|
| 589 |
+
|
| 590 |
+
past_key_values.past_window_pos[layer_idx] = _window_len
|
| 591 |
+
else:
|
| 592 |
+
dump_k = None
|
| 593 |
+
dump_v = None
|
| 594 |
+
past_key_values.past_window_pos[layer_idx] = new_window_len
|
| 595 |
+
|
| 596 |
+
if dump_k is not None and dump_v is not None:
|
| 597 |
+
rfa_k, rfa_v = triton_eva_prep_kv_fwd(
|
| 598 |
+
dump_k, dump_v,
|
| 599 |
+
self.model.layers[layer_idx].self_attn.adaptive_mu_k,
|
| 600 |
+
self.model.layers[layer_idx].self_attn.adaptive_phi,
|
| 601 |
+
None,
|
| 602 |
+
self.model.layers[layer_idx].self_attn.head_dim_scaling,
|
| 603 |
+
self.model.layers[layer_idx].self_attn.chunk_size
|
| 604 |
+
)
|
| 605 |
+
rfa_k, rfa_v = past_key_values.update_chunk_rfas(
|
| 606 |
+
rfa_k, rfa_v, layer_idx
|
| 607 |
+
)
|
| 608 |
+
return past_key_values
|
| 609 |
+
|
| 610 |
+
def _multi_byte_pred_update_cache_when_prefil_len_eq_window_size(
|
| 611 |
+
self,
|
| 612 |
+
past_key_values,
|
| 613 |
+
):
|
| 614 |
+
prev_window_len = past_key_values.get_past_window_pos(0)
|
| 615 |
+
for layer_idx in range(self.config.num_hidden_layers):
|
| 616 |
+
|
| 617 |
+
past_window_k = past_key_values.past_window_k[layer_idx]
|
| 618 |
+
past_window_v = past_key_values.past_window_v[layer_idx]
|
| 619 |
+
|
| 620 |
+
new_window_len = prev_window_len
|
| 621 |
+
if new_window_len == self.config.window_size:
|
| 622 |
+
dump_k = past_window_k[..., :self.config.window_size, :].clone()
|
| 623 |
+
dump_v = past_window_v[..., :self.config.window_size, :].clone()
|
| 624 |
+
past_key_values.past_window_pos[layer_idx] = 0
|
| 625 |
+
|
| 626 |
+
if dump_k is not None and dump_v is not None:
|
| 627 |
+
rfa_k, rfa_v = triton_eva_prep_kv_fwd(
|
| 628 |
+
dump_k, dump_v,
|
| 629 |
+
self.model.layers[layer_idx].self_attn.adaptive_mu_k,
|
| 630 |
+
self.model.layers[layer_idx].self_attn.adaptive_phi,
|
| 631 |
+
None,
|
| 632 |
+
self.model.layers[layer_idx].self_attn.head_dim_scaling,
|
| 633 |
+
self.model.layers[layer_idx].self_attn.chunk_size
|
| 634 |
+
)
|
| 635 |
+
rfa_k, rfa_v = past_key_values.update_chunk_rfas(
|
| 636 |
+
rfa_k, rfa_v, layer_idx
|
| 637 |
+
)
|
| 638 |
+
return past_key_values
|
| 639 |
+
|
| 640 |
+
def multi_byte_pred_update_attn_mask(
|
| 641 |
+
self,
|
| 642 |
+
last_iter_new_tokens,
|
| 643 |
+
tree_candidate_ids,
|
| 644 |
+
past_attn_mask,
|
| 645 |
+
medusa_attn_mask,
|
| 646 |
+
past_key_values,
|
| 647 |
+
):
|
| 648 |
+
batch_size, tree_candidate_len = tree_candidate_ids.shape
|
| 649 |
+
seen_tokens = past_key_values.get_seq_length()
|
| 650 |
+
# NOTE: past_key_values has been updated so now
|
| 651 |
+
# seen_tokens incldues new tokens from the last tree iteration
|
| 652 |
+
assert seen_tokens > 0
|
| 653 |
+
# so one iteration would not cross two windows
|
| 654 |
+
assert last_iter_new_tokens < self.config.window_size
|
| 655 |
+
|
| 656 |
+
if past_attn_mask is not None and seen_tokens < self.config.window_size:
|
| 657 |
+
past_attn_mask = torch.cat(
|
| 658 |
+
[
|
| 659 |
+
past_attn_mask,
|
| 660 |
+
torch.ones(
|
| 661 |
+
[batch_size, 1, tree_candidate_len, last_iter_new_tokens],
|
| 662 |
+
dtype=torch.bool,
|
| 663 |
+
device=self.device
|
| 664 |
+
)
|
| 665 |
+
],
|
| 666 |
+
dim=-1
|
| 667 |
+
)
|
| 668 |
+
else:
|
| 669 |
+
# we initialize attn mask each time when
|
| 670 |
+
# 1. the model crosses the window bounary, or
|
| 671 |
+
# 2. after prefilling
|
| 672 |
+
chunks_per_window = int(self.config.window_size // self.config.chunk_size)
|
| 673 |
+
|
| 674 |
+
window_tokens = seen_tokens % self.config.window_size
|
| 675 |
+
num_windows_seen_so_far = seen_tokens // self.config.window_size
|
| 676 |
+
attn_mask_len = num_windows_seen_so_far * chunks_per_window + window_tokens
|
| 677 |
+
past_attn_mask = torch.ones(
|
| 678 |
+
(batch_size, 1, tree_candidate_len, attn_mask_len),
|
| 679 |
+
dtype=torch.bool,
|
| 680 |
+
device=self.device
|
| 681 |
+
)
|
| 682 |
+
|
| 683 |
+
# note that 1 indicates the position is not masked
|
| 684 |
+
tree_attn_mask = torch.cat(
|
| 685 |
+
[
|
| 686 |
+
past_attn_mask,
|
| 687 |
+
medusa_attn_mask.to(torch.bool)
|
| 688 |
+
],
|
| 689 |
+
dim=-1
|
| 690 |
+
)
|
| 691 |
+
return tree_attn_mask, past_attn_mask
|
| 692 |
+
|
| 693 |
+
@torch.no_grad()
|
| 694 |
+
def multi_byte_generate(
|
| 695 |
+
self,
|
| 696 |
+
input_ids,
|
| 697 |
+
attention_mask=None,
|
| 698 |
+
temperature=0.0,
|
| 699 |
+
max_length=None,
|
| 700 |
+
max_new_tokens=None,
|
| 701 |
+
stopping_criteria=None,
|
| 702 |
+
posterior_threshold=0.09,
|
| 703 |
+
posterior_alpha=0.3,
|
| 704 |
+
top_p=0.8,
|
| 705 |
+
sampling='typical',
|
| 706 |
+
fast=True,
|
| 707 |
+
do_sample=False,
|
| 708 |
+
medusa_choices=None,
|
| 709 |
+
return_acc_lengths=False
|
| 710 |
+
):
|
| 711 |
+
if do_sample or temperature > 0.0:
|
| 712 |
+
fast = False
|
| 713 |
+
|
| 714 |
+
### Prepare `max_length` depending on other stopping criteria.
|
| 715 |
+
if max_new_tokens is not None:
|
| 716 |
+
max_length = max_new_tokens + input_ids.shape[-1]
|
| 717 |
+
elif max_new_tokens is None and max_length is None:
|
| 718 |
+
max_length = getattr(self.config, "max_position_embeddings", 32768)
|
| 719 |
+
|
| 720 |
+
### Set up stopping criteria
|
| 721 |
+
eos_stop_criteria = MultibyteEosTokenCriteria(self.generation_config.eos_token_id)
|
| 722 |
+
stop_criteria = StoppingCriteriaList()
|
| 723 |
+
if max_length is not None:
|
| 724 |
+
max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
|
| 725 |
+
stop_criteria.append(
|
| 726 |
+
MaxLengthCriteria(
|
| 727 |
+
max_length=max_length,
|
| 728 |
+
max_position_embeddings=max_position_embeddings,
|
| 729 |
+
)
|
| 730 |
+
)
|
| 731 |
+
if stopping_criteria is not None and len(stopping_criteria) > 0:
|
| 732 |
+
stop_criteria.extend(stopping_criteria)
|
| 733 |
+
|
| 734 |
+
assert input_ids.shape[0] == 1, "Only support batch size 1 for now"
|
| 735 |
+
assert attention_mask is None, "Only support attention mask None for now"
|
| 736 |
+
# Avoid modifying the input_ids in-place
|
| 737 |
+
input_ids = input_ids.clone()
|
| 738 |
+
position_ids = torch.arange(0, input_ids.shape[1], device=self.device, dtype=int).reshape(1, -1)
|
| 739 |
+
|
| 740 |
+
####################################################
|
| 741 |
+
# 0. initialize the medusa buffers
|
| 742 |
+
####################################################
|
| 743 |
+
if medusa_choices is None:
|
| 744 |
+
medusa_choices = evabyte_7b_95
|
| 745 |
+
medusa_buffers = generate_medusa_buffers(
|
| 746 |
+
medusa_choices, device=self.device
|
| 747 |
+
)
|
| 748 |
+
|
| 749 |
+
past_key_values = EvaStaticCacheForTriton(
|
| 750 |
+
input_ids.shape[0],
|
| 751 |
+
self.config.num_attention_heads,
|
| 752 |
+
# we add 256 to allow tree ids
|
| 753 |
+
self.config.window_size + 256,
|
| 754 |
+
self.config.hidden_size // self.config.num_attention_heads,
|
| 755 |
+
self.config.num_hidden_layers,
|
| 756 |
+
self.lm_head.weight.dtype,
|
| 757 |
+
self.lm_head.weight.device,
|
| 758 |
+
)
|
| 759 |
+
# prefill to get medusa logits and logits
|
| 760 |
+
full_logits, past_key_values = self.forward(
|
| 761 |
+
input_ids,
|
| 762 |
+
attention_mask=attention_mask,
|
| 763 |
+
position_ids=position_ids,
|
| 764 |
+
use_cache=True,
|
| 765 |
+
past_key_values=past_key_values,
|
| 766 |
+
return_all_pred_logits=True,
|
| 767 |
+
multibyte_decoding=False,
|
| 768 |
+
)
|
| 769 |
+
# handles an edge case where the prefill length == window_size
|
| 770 |
+
# we force the previous window to be dumped into RFA chunks
|
| 771 |
+
past_key_values = self._multi_byte_pred_update_cache_when_prefil_len_eq_window_size(
|
| 772 |
+
past_key_values
|
| 773 |
+
)
|
| 774 |
+
medusa_logits, logits = split_logits(full_logits)
|
| 775 |
+
|
| 776 |
+
past_attn_mask = None
|
| 777 |
+
last_iter_new_tokens = 0
|
| 778 |
+
max_iters = 32768
|
| 779 |
+
if return_acc_lengths:
|
| 780 |
+
acc_lengths = []
|
| 781 |
+
for _ in range(max_iters):
|
| 782 |
+
####################################################
|
| 783 |
+
# 1. generate candidate_ids with topk predictions from Medusa heads
|
| 784 |
+
####################################################
|
| 785 |
+
tree_candidate_ids, unflattened_candidate_ids = generate_candidates(
|
| 786 |
+
medusa_logits,
|
| 787 |
+
logits,
|
| 788 |
+
medusa_buffers["tree_indices"],
|
| 789 |
+
medusa_buffers["retrieve_indices"],
|
| 790 |
+
temperature=temperature,
|
| 791 |
+
posterior_alpha=posterior_alpha,
|
| 792 |
+
posterior_threshold=posterior_threshold,
|
| 793 |
+
top_p=top_p,
|
| 794 |
+
sampling=sampling,
|
| 795 |
+
fast=fast,
|
| 796 |
+
)
|
| 797 |
+
|
| 798 |
+
####################################################
|
| 799 |
+
# 2. Build the medusa attention mask and position ids
|
| 800 |
+
####################################################
|
| 801 |
+
# NOTE: 1 indicates the position is not masked
|
| 802 |
+
medusa_attn_mask, past_attn_mask = self.multi_byte_pred_update_attn_mask(
|
| 803 |
+
last_iter_new_tokens,
|
| 804 |
+
tree_candidate_ids,
|
| 805 |
+
past_attn_mask,
|
| 806 |
+
medusa_buffers["medusa_attn_mask"],
|
| 807 |
+
past_key_values,
|
| 808 |
+
)
|
| 809 |
+
medusa_position_ids = medusa_buffers["medusa_position_ids"] + input_ids.shape[1]
|
| 810 |
+
|
| 811 |
+
####################################################
|
| 812 |
+
# 3. tree decoding
|
| 813 |
+
####################################################
|
| 814 |
+
tree_full_logits, past_key_values = self.forward(
|
| 815 |
+
tree_candidate_ids,
|
| 816 |
+
past_key_values=past_key_values,
|
| 817 |
+
attention_mask=medusa_attn_mask,
|
| 818 |
+
position_ids=medusa_position_ids,
|
| 819 |
+
return_all_pred_logits=True,
|
| 820 |
+
multibyte_decoding=True,
|
| 821 |
+
)
|
| 822 |
+
_medusa_logits, _logits = split_logits(tree_full_logits)
|
| 823 |
+
medusa_logits = _medusa_logits[..., 0, medusa_buffers["retrieve_indices"], :]
|
| 824 |
+
logits = _logits[..., 0, medusa_buffers["retrieve_indices"], :]
|
| 825 |
+
|
| 826 |
+
####################################################
|
| 827 |
+
# 4. candidate selection
|
| 828 |
+
####################################################
|
| 829 |
+
# if the current iteration, with tree tokens, crosses window
|
| 830 |
+
# boundaries, trim the condidate_ids to be within the window
|
| 831 |
+
# so that those exceeded tokens (which will be inaccurate)
|
| 832 |
+
# will not be considered
|
| 833 |
+
tree_depth = unflattened_candidate_ids.shape[-1]
|
| 834 |
+
if tree_depth + past_key_values.get_past_window_pos(0) > self.config.window_size:
|
| 835 |
+
max_acc_len = self.config.window_size - past_key_values.get_past_window_pos(0)
|
| 836 |
+
_trimmed_unflattened_candidate_ids = unflattened_candidate_ids[:, :max_acc_len]
|
| 837 |
+
_trimmed_logits = logits[:, :max_acc_len]
|
| 838 |
+
else:
|
| 839 |
+
_trimmed_unflattened_candidate_ids = unflattened_candidate_ids
|
| 840 |
+
_trimmed_logits = logits
|
| 841 |
+
best_candidate, accept_length = evaluate_posterior(
|
| 842 |
+
_trimmed_logits,
|
| 843 |
+
_trimmed_unflattened_candidate_ids,
|
| 844 |
+
temperature,
|
| 845 |
+
posterior_threshold,
|
| 846 |
+
posterior_alpha,
|
| 847 |
+
top_p=top_p,
|
| 848 |
+
sampling=sampling,
|
| 849 |
+
fast=fast
|
| 850 |
+
)
|
| 851 |
+
|
| 852 |
+
####################################################
|
| 853 |
+
# 5. update model inputs and caches
|
| 854 |
+
####################################################
|
| 855 |
+
input_ids, medusa_logits, logits, last_iter_new_tokens = update_inference_inputs(
|
| 856 |
+
input_ids,
|
| 857 |
+
medusa_logits,
|
| 858 |
+
logits,
|
| 859 |
+
unflattened_candidate_ids,
|
| 860 |
+
best_candidate,
|
| 861 |
+
accept_length,
|
| 862 |
+
)
|
| 863 |
+
|
| 864 |
+
past_key_values = self.multi_byte_pred_update_cache(
|
| 865 |
+
past_key_values,
|
| 866 |
+
medusa_buffers["retrieve_indices"],
|
| 867 |
+
best_candidate,
|
| 868 |
+
last_iter_new_tokens,
|
| 869 |
+
)
|
| 870 |
+
|
| 871 |
+
if return_acc_lengths:
|
| 872 |
+
acc_lengths.append(last_iter_new_tokens)
|
| 873 |
+
if stop_criteria(input_ids, None) or eos_stop_criteria(input_ids, last_iter_new_tokens):
|
| 874 |
+
if return_acc_lengths:
|
| 875 |
+
return input_ids, acc_lengths
|
| 876 |
+
else:
|
| 877 |
+
return input_ids
|
| 878 |
+
if return_acc_lengths:
|
| 879 |
+
return input_ids, acc_lengths
|
| 880 |
+
else:
|
| 881 |
+
return input_ids
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/preprocessor_config.json
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"auto_map": {
|
| 3 |
+
"AutoImageProcessor": "image_processing_evabyte.EvaByteImageProcessor",
|
| 4 |
+
"AutoProcessor": "processing_evabyte.EvaByteProcessor"
|
| 5 |
+
},
|
| 6 |
+
"do_convert_rgb": true,
|
| 7 |
+
"do_resize": true,
|
| 8 |
+
"image_processor_type": "EvaByteImageProcessor",
|
| 9 |
+
"jpeg_quality": 25,
|
| 10 |
+
"jpeg_restart_marker_blocks": 1,
|
| 11 |
+
"jpeg_streamtype": 2,
|
| 12 |
+
"jpeg_subsampling": "4:2:0",
|
| 13 |
+
"processor_class": "EvaByteProcessor",
|
| 14 |
+
"resample": 1,
|
| 15 |
+
"size": {
|
| 16 |
+
"longest_edge": 384
|
| 17 |
+
}
|
| 18 |
+
}
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/processing_evabyte.py
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
"""
|
| 3 |
+
Processor class for EvaByte.
|
| 4 |
+
"""
|
| 5 |
+
import base64
|
| 6 |
+
from io import BytesIO
|
| 7 |
+
|
| 8 |
+
import requests
|
| 9 |
+
import os
|
| 10 |
+
import PIL
|
| 11 |
+
from PIL import Image
|
| 12 |
+
|
| 13 |
+
from typing import List, Optional, Union
|
| 14 |
+
|
| 15 |
+
from transformers.feature_extraction_utils import BatchFeature
|
| 16 |
+
from transformers.image_utils import ImageInput, is_valid_image
|
| 17 |
+
from transformers.processing_utils import ProcessorMixin
|
| 18 |
+
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
| 19 |
+
from transformers.utils import TensorType, to_py_obj
|
| 20 |
+
|
| 21 |
+
def fetch_image(image: Union[str, "PIL.Image.Image"]) -> Image.Image:
|
| 22 |
+
image_obj = None
|
| 23 |
+
if isinstance(image, Image.Image):
|
| 24 |
+
image_obj = image
|
| 25 |
+
elif image.startswith("http://") or image.startswith("https://"):
|
| 26 |
+
image_obj = Image.open(BytesIO(requests.get(image, timeout=None).content))
|
| 27 |
+
elif os.path.isfile(image):
|
| 28 |
+
image_obj = Image.open(image)
|
| 29 |
+
elif image.startswith("data:image/"):
|
| 30 |
+
image = image.split(",")[1]
|
| 31 |
+
# Try to load as base64
|
| 32 |
+
try:
|
| 33 |
+
b64 = base64.decodebytes(image.encode())
|
| 34 |
+
image = PIL.Image.open(BytesIO(b64))
|
| 35 |
+
except Exception as e:
|
| 36 |
+
raise ValueError(
|
| 37 |
+
f"Incorrect image source. Must be a valid URL starting with `http://` or `https://`, a valid path to an image file, or a base64 encoded string. Got {image}. Failed with {e}"
|
| 38 |
+
)
|
| 39 |
+
else:
|
| 40 |
+
image_obj = Image.open(image)
|
| 41 |
+
if image_obj is None:
|
| 42 |
+
raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
|
| 43 |
+
|
| 44 |
+
return image_obj
|
| 45 |
+
|
| 46 |
+
def is_url(val) -> bool:
|
| 47 |
+
return isinstance(val, str) and val.startswith("http")
|
| 48 |
+
|
| 49 |
+
def is_file(val) -> bool:
|
| 50 |
+
return isinstance(val, str) and os.path.isfile(val)
|
| 51 |
+
|
| 52 |
+
def is_image_or_image_url(elem):
|
| 53 |
+
return is_url(elem) or is_valid_image(elem) or is_file(elem)
|
| 54 |
+
|
| 55 |
+
vl_chat_template = """
|
| 56 |
+
{{- bos_token }}
|
| 57 |
+
{%- if messages[0]['role'] == 'system' %}
|
| 58 |
+
{%- set system_message = messages[0]['content'] %}
|
| 59 |
+
{%- set messages = messages[1:] %}
|
| 60 |
+
{%- else %}
|
| 61 |
+
{%- set system_message = "" %}
|
| 62 |
+
{%- endif %}
|
| 63 |
+
|
| 64 |
+
{{- '<|start_header_id|>system<|end_header_id|>\n\n' + system_message + '<|eot_id|>'}}
|
| 65 |
+
|
| 66 |
+
{%- for message in messages %}
|
| 67 |
+
{%- if (message['role'] != 'user') and (message['role'] != 'assistant') %}
|
| 68 |
+
{{- raise_exception('Conversation roles must be user or assistant') }}
|
| 69 |
+
{%- endif %}
|
| 70 |
+
|
| 71 |
+
{%- if message['content'] is string %}
|
| 72 |
+
{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] + '<|eot_id|>' }}
|
| 73 |
+
{%- else %}
|
| 74 |
+
{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' }}
|
| 75 |
+
{%- for content in message['content'] %}
|
| 76 |
+
{%- if content['type'] == 'image' %}
|
| 77 |
+
{{- '<image_placeholder>\n' }}
|
| 78 |
+
{%- elif content['type'] == 'text' %}
|
| 79 |
+
{{- content['text'] }}
|
| 80 |
+
{%- endif %}
|
| 81 |
+
{%- endfor %}
|
| 82 |
+
{{- '<|eot_id|>' }}
|
| 83 |
+
{%- endif %}
|
| 84 |
+
{%- endfor %}
|
| 85 |
+
|
| 86 |
+
{%- if add_generation_prompt %}
|
| 87 |
+
{{- '<|start_header_id|>' + 'assistant' + '<|end_header_id|>\n\n' }}
|
| 88 |
+
{%- endif %}
|
| 89 |
+
"""
|
| 90 |
+
|
| 91 |
+
class EvaByteProcessor(ProcessorMixin):
|
| 92 |
+
r"""
|
| 93 |
+
Constructs a EvaByte processor which wraps a EvaByte image processor and a EvaByte tokenizer into a single processor.
|
| 94 |
+
|
| 95 |
+
[`EvaByteProcessor`] offers all the functionalities of [`EvaByteImageProcessor`] and [`EvaByteTokenizer`]. See the
|
| 96 |
+
[`~EvaByteProcessor.__call__`] and [`~EvaByteProcessor.decode`] for more information.
|
| 97 |
+
|
| 98 |
+
Args:
|
| 99 |
+
image_processor ([`EvaByteImageProcessor`], *optional*):
|
| 100 |
+
The image processor is a required input.
|
| 101 |
+
tokenizer ([`EvaByteTokenizer`], *optional*):
|
| 102 |
+
The tokenizer is a required input.
|
| 103 |
+
"""
|
| 104 |
+
|
| 105 |
+
attributes = ["image_processor", "tokenizer"]
|
| 106 |
+
image_processor_class = "AutoImageProcessor"
|
| 107 |
+
tokenizer_class = "AutoTokenizer"
|
| 108 |
+
|
| 109 |
+
def __init__(self, image_processor=None, tokenizer=None, **kwargs):
|
| 110 |
+
if image_processor is None:
|
| 111 |
+
raise ValueError("You need to specify an `image_processor`.")
|
| 112 |
+
if tokenizer is None:
|
| 113 |
+
raise ValueError("You need to specify a `tokenizer`.")
|
| 114 |
+
|
| 115 |
+
super().__init__(image_processor, tokenizer)
|
| 116 |
+
self.t2v_token_id = self.tokenizer.convert_tokens_to_ids("<t2v_token>")
|
| 117 |
+
self.v2t_token_id = self.tokenizer.convert_tokens_to_ids("<v2t_token>")
|
| 118 |
+
self.image_placeholder = "<image_placeholder>"
|
| 119 |
+
self.vl_chat_template = vl_chat_template
|
| 120 |
+
|
| 121 |
+
def __call__(
|
| 122 |
+
self,
|
| 123 |
+
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
|
| 124 |
+
images: ImageInput = None,
|
| 125 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
| 126 |
+
strip_ending_sentinel: bool = False,
|
| 127 |
+
encode_only: bool = False,
|
| 128 |
+
**kwargs
|
| 129 |
+
) -> Union[BatchFeature, List[List[int]]]:
|
| 130 |
+
# processing pipeline:
|
| 131 |
+
# 1. read images or videos from paths
|
| 132 |
+
# 2. use image_processor to convert images / videos to byte streams
|
| 133 |
+
if images is not None:
|
| 134 |
+
if isinstance(images, bytes):
|
| 135 |
+
image_bytes_list = [[images]]
|
| 136 |
+
elif isinstance(images, list) and isinstance(images[0], bytes):
|
| 137 |
+
image_bytes_list = [images]
|
| 138 |
+
elif isinstance(images, list) and isinstance(images[0], list) and isinstance(images[0][0], bytes):
|
| 139 |
+
image_bytes_list = images
|
| 140 |
+
else:
|
| 141 |
+
if is_image_or_image_url(images):
|
| 142 |
+
images = [[images]]
|
| 143 |
+
elif isinstance(images, list) and is_image_or_image_url(images[0]):
|
| 144 |
+
images = [images]
|
| 145 |
+
elif (
|
| 146 |
+
not isinstance(images, list)
|
| 147 |
+
and not isinstance(images[0], list)
|
| 148 |
+
and not is_image_or_image_url(images[0][0])
|
| 149 |
+
):
|
| 150 |
+
raise ValueError(
|
| 151 |
+
"Invalid input images. Please provide a single image or a list of images or a list of list of images."
|
| 152 |
+
)
|
| 153 |
+
# Load images if they are URLs
|
| 154 |
+
images = [[fetch_image(im) if is_url(im) or is_file(im) else im for im in sample] for sample in images]
|
| 155 |
+
image_bytes_list = self.image_processor(images=images, **kwargs)
|
| 156 |
+
|
| 157 |
+
if not isinstance(text, list):
|
| 158 |
+
text = [text]
|
| 159 |
+
assert len(text) == 1, "Only support batch size 1 for now"
|
| 160 |
+
assert len(text) == len(image_bytes_list), "text and image_bytes_list must have the same length"
|
| 161 |
+
# TODO: invoke SequenceFeatureExtractor to get batched inputs
|
| 162 |
+
|
| 163 |
+
# 3. tokenize the text and put images / videos byte streams into the placeholders
|
| 164 |
+
# surrounded by special tokens like "<image>" and "</image>"
|
| 165 |
+
batch_input_ids = []
|
| 166 |
+
if not encode_only:
|
| 167 |
+
batch_attention_mask = []
|
| 168 |
+
else:
|
| 169 |
+
batch_attention_mask = None
|
| 170 |
+
|
| 171 |
+
for t, image_bytes in zip(text, image_bytes_list):
|
| 172 |
+
text_splits = t.split(self.image_placeholder)
|
| 173 |
+
if len(text_splits) != len(image_bytes) + 1:
|
| 174 |
+
raise ValueError(
|
| 175 |
+
f"The number of image tokens should be equal to the number of images, "
|
| 176 |
+
f"but got {len(text_splits)} and {len(image_bytes) + 1}"
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
input_ids = [self.tokenizer.bos_token_id]
|
| 180 |
+
for i, text_part in enumerate(text_splits):
|
| 181 |
+
# each text part must be non-empty because we added markers around placeholders
|
| 182 |
+
split_tokens = self.tokenizer.encode(text_part, add_special_tokens=False)
|
| 183 |
+
input_ids.extend(split_tokens)
|
| 184 |
+
# Add image bytes after each text part except the last one
|
| 185 |
+
if i < len(image_bytes):
|
| 186 |
+
input_ids.append(self.t2v_token_id)
|
| 187 |
+
input_ids.extend([b + self.tokenizer.offset for b in image_bytes[i]])
|
| 188 |
+
input_ids.append(self.v2t_token_id)
|
| 189 |
+
|
| 190 |
+
if strip_ending_sentinel and (input_ids[-1] in [self.t2v_token_id, self.v2t_token_id]):
|
| 191 |
+
input_ids = input_ids[:-1]
|
| 192 |
+
|
| 193 |
+
batch_input_ids.append(input_ids)
|
| 194 |
+
if not encode_only:
|
| 195 |
+
batch_attention_mask.append([1] * len(input_ids))
|
| 196 |
+
|
| 197 |
+
if not encode_only:
|
| 198 |
+
# 4. return batch of features
|
| 199 |
+
inputs = BatchFeature({
|
| 200 |
+
"input_ids": batch_input_ids,
|
| 201 |
+
"attention_mask": batch_attention_mask
|
| 202 |
+
}, tensor_type=return_tensors)
|
| 203 |
+
return inputs
|
| 204 |
+
# # Pad sequences
|
| 205 |
+
# padded_inputs = self.tokenizer.pad(
|
| 206 |
+
# {"input_ids": batch_input_ids},
|
| 207 |
+
# padding=True,
|
| 208 |
+
# return_attention_mask=True,
|
| 209 |
+
# return_tensors=return_tensors,
|
| 210 |
+
# )
|
| 211 |
+
# return BatchFeature(data=padded_inputs)
|
| 212 |
+
else:
|
| 213 |
+
return batch_input_ids
|
| 214 |
+
|
| 215 |
+
def image_tokens_to_bytes(self, image_token_ids, jpeg_quality=None):
|
| 216 |
+
image_bytes = bytes([token_id - self.tokenizer.offset for token_id in image_token_ids])
|
| 217 |
+
image_bytes = self.image_processor.jpeg_merge_qtables(image_bytes, jpeg_quality)
|
| 218 |
+
return image_bytes
|
| 219 |
+
|
| 220 |
+
def batch_decode(self, sequences, **kwargs):
|
| 221 |
+
"""
|
| 222 |
+
This method forwards all its arguments to EvaByteTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
|
| 223 |
+
refer to the docstring of this method for more information.
|
| 224 |
+
"""
|
| 225 |
+
rets = [self.decode(seq, **kwargs) for seq in sequences]
|
| 226 |
+
return tuple(map(list, zip(*rets)))
|
| 227 |
+
|
| 228 |
+
def decode(self, token_ids, **kwargs):
|
| 229 |
+
"""
|
| 230 |
+
Decodes a sequence of input_ids, handling image tokens separately.
|
| 231 |
+
Returns a tuple of (decoded_text, images), where images is a list of bytes.
|
| 232 |
+
"""
|
| 233 |
+
if kwargs and "jpeg_quality" in kwargs:
|
| 234 |
+
kwargs = kwargs.copy()
|
| 235 |
+
jpeg_quality = kwargs.pop("jpeg_quality")
|
| 236 |
+
else:
|
| 237 |
+
jpeg_quality = None
|
| 238 |
+
|
| 239 |
+
token_ids = to_py_obj(token_ids)
|
| 240 |
+
# Find indices of t2v_token_id and v2t_token_id
|
| 241 |
+
t2v_indices = [i for i, token_id in enumerate(token_ids) if token_id == self.t2v_token_id]
|
| 242 |
+
v2t_indices = [i for i, token_id in enumerate(token_ids) if token_id == self.v2t_token_id]
|
| 243 |
+
|
| 244 |
+
# Check for correct pairing of t2v and v2t tokens
|
| 245 |
+
if len(t2v_indices) != len(v2t_indices):
|
| 246 |
+
raise ValueError("Mismatched number of t2v and v2t tokens in token_ids: {} and {}".format(t2v_indices, v2t_indices))
|
| 247 |
+
|
| 248 |
+
# Ensure t2v and v2t tokens are in the correct order
|
| 249 |
+
for t2v_idx, v2t_idx in zip(t2v_indices, v2t_indices):
|
| 250 |
+
if t2v_idx >= v2t_idx:
|
| 251 |
+
raise ValueError("Found t2v_token_id after v2t_token_id in token_ids")
|
| 252 |
+
|
| 253 |
+
# Initialize the start index
|
| 254 |
+
images = []
|
| 255 |
+
decoded_text = ""
|
| 256 |
+
|
| 257 |
+
start = 0
|
| 258 |
+
# Iterate over pairs of t2v and v2t indices
|
| 259 |
+
for t2v_idx, v2t_idx in zip(t2v_indices, v2t_indices):
|
| 260 |
+
# Decode text tokens before the image
|
| 261 |
+
text_token_ids = token_ids[start:t2v_idx]
|
| 262 |
+
if len(text_token_ids) > 0:
|
| 263 |
+
decoded_text += self.tokenizer.decode(text_token_ids, **kwargs)
|
| 264 |
+
|
| 265 |
+
# Insert image placeholder
|
| 266 |
+
decoded_text += self.image_placeholder
|
| 267 |
+
|
| 268 |
+
# Extract image tokens and convert them to bytes
|
| 269 |
+
image_token_ids = token_ids[t2v_idx + 1 : v2t_idx]
|
| 270 |
+
image_bytes = self.image_tokens_to_bytes(image_token_ids, jpeg_quality)
|
| 271 |
+
images.append(image_bytes)
|
| 272 |
+
|
| 273 |
+
# Update the start index to the token after v2t_token_id
|
| 274 |
+
start = v2t_idx + 1
|
| 275 |
+
|
| 276 |
+
# Decode any remaining text tokens after the last image
|
| 277 |
+
if start < len(token_ids):
|
| 278 |
+
text_token_ids = token_ids[start:]
|
| 279 |
+
decoded_text += self.tokenizer.decode(text_token_ids, **kwargs)
|
| 280 |
+
|
| 281 |
+
return decoded_text, images
|
| 282 |
+
|
| 283 |
+
@property
|
| 284 |
+
def model_input_names(self):
|
| 285 |
+
tokenizer_input_names = self.tokenizer.model_input_names
|
| 286 |
+
image_processor_input_names = self.image_processor.model_input_names
|
| 287 |
+
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/processor_config.json
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"auto_map": {
|
| 3 |
+
"AutoProcessor": "processing_evabyte.EvaByteProcessor"
|
| 4 |
+
},
|
| 5 |
+
"processor_class": "EvaByteProcessor"
|
| 6 |
+
}
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/special_tokens_map.json
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"additional_special_tokens": [
|
| 3 |
+
"<repo_name>",
|
| 4 |
+
"<file_sep>",
|
| 5 |
+
"<t2v_token>",
|
| 6 |
+
"<v2t_token>",
|
| 7 |
+
"<|start_header_id|>",
|
| 8 |
+
"<|end_header_id|>",
|
| 9 |
+
"<|eot_id|>",
|
| 10 |
+
"<extra_id_12>",
|
| 11 |
+
"<extra_id_13>",
|
| 12 |
+
"<extra_id_14>",
|
| 13 |
+
"<extra_id_15>",
|
| 14 |
+
"<extra_id_16>",
|
| 15 |
+
"<extra_id_17>",
|
| 16 |
+
"<extra_id_18>",
|
| 17 |
+
"<extra_id_19>",
|
| 18 |
+
"<extra_id_20>",
|
| 19 |
+
"<extra_id_21>",
|
| 20 |
+
"<extra_id_22>",
|
| 21 |
+
"<extra_id_23>",
|
| 22 |
+
"<extra_id_24>",
|
| 23 |
+
"<extra_id_25>",
|
| 24 |
+
"<extra_id_26>",
|
| 25 |
+
"<extra_id_27>",
|
| 26 |
+
"<extra_id_28>",
|
| 27 |
+
"<extra_id_29>",
|
| 28 |
+
"<extra_id_30>",
|
| 29 |
+
"<extra_id_31>",
|
| 30 |
+
"<extra_id_32>",
|
| 31 |
+
"<extra_id_33>",
|
| 32 |
+
"<extra_id_34>",
|
| 33 |
+
"<extra_id_35>",
|
| 34 |
+
"<extra_id_36>",
|
| 35 |
+
"<extra_id_37>",
|
| 36 |
+
"<extra_id_38>",
|
| 37 |
+
"<extra_id_39>",
|
| 38 |
+
"<extra_id_40>",
|
| 39 |
+
"<extra_id_41>",
|
| 40 |
+
"<extra_id_42>",
|
| 41 |
+
"<extra_id_43>",
|
| 42 |
+
"<extra_id_44>",
|
| 43 |
+
"<extra_id_45>",
|
| 44 |
+
"<extra_id_46>",
|
| 45 |
+
"<extra_id_47>",
|
| 46 |
+
"<extra_id_48>",
|
| 47 |
+
"<extra_id_49>",
|
| 48 |
+
"<extra_id_50>",
|
| 49 |
+
"<extra_id_51>",
|
| 50 |
+
"<extra_id_52>",
|
| 51 |
+
"<extra_id_53>",
|
| 52 |
+
"<extra_id_54>",
|
| 53 |
+
"<extra_id_55>",
|
| 54 |
+
"<extra_id_56>",
|
| 55 |
+
"<extra_id_57>",
|
| 56 |
+
"<extra_id_58>",
|
| 57 |
+
"<extra_id_59>",
|
| 58 |
+
"<extra_id_60>",
|
| 59 |
+
"<extra_id_61>",
|
| 60 |
+
"<extra_id_62>",
|
| 61 |
+
"<extra_id_63>"
|
| 62 |
+
],
|
| 63 |
+
"bos_token": {
|
| 64 |
+
"content": "<bos>",
|
| 65 |
+
"lstrip": false,
|
| 66 |
+
"normalized": true,
|
| 67 |
+
"rstrip": false,
|
| 68 |
+
"single_word": false
|
| 69 |
+
},
|
| 70 |
+
"eos_token": {
|
| 71 |
+
"content": "<eos>",
|
| 72 |
+
"lstrip": false,
|
| 73 |
+
"normalized": true,
|
| 74 |
+
"rstrip": false,
|
| 75 |
+
"single_word": false
|
| 76 |
+
},
|
| 77 |
+
"pad_token": {
|
| 78 |
+
"content": "<pad>",
|
| 79 |
+
"lstrip": false,
|
| 80 |
+
"normalized": true,
|
| 81 |
+
"rstrip": false,
|
| 82 |
+
"single_word": false
|
| 83 |
+
},
|
| 84 |
+
"sep_token": {
|
| 85 |
+
"content": "<sep>",
|
| 86 |
+
"lstrip": false,
|
| 87 |
+
"normalized": true,
|
| 88 |
+
"rstrip": false,
|
| 89 |
+
"single_word": false
|
| 90 |
+
},
|
| 91 |
+
"unk_token": {
|
| 92 |
+
"content": "<unk>",
|
| 93 |
+
"lstrip": false,
|
| 94 |
+
"normalized": true,
|
| 95 |
+
"rstrip": false,
|
| 96 |
+
"single_word": false
|
| 97 |
+
}
|
| 98 |
+
}
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/tokenization_evabyte.py
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
|
| 3 |
+
""" Tokenization class for model EvaByte."""
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
from typing import List, Optional, Tuple
|
| 7 |
+
|
| 8 |
+
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
|
| 9 |
+
from transformers.utils import logging
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
logger = logging.get_logger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
chat_template = """
|
| 16 |
+
{{- bos_token }}
|
| 17 |
+
{%- if messages[0]['role'] == 'system' %}
|
| 18 |
+
{%- set system_message = messages[0]['content'] %}
|
| 19 |
+
{%- set messages = messages[1:] %}
|
| 20 |
+
{%- else %}
|
| 21 |
+
{%- set system_message = "" %}
|
| 22 |
+
{%- endif %}
|
| 23 |
+
|
| 24 |
+
{{- '<|start_header_id|>system<|end_header_id|>\n\n' + system_message + '<|eot_id|>'}}
|
| 25 |
+
|
| 26 |
+
{%- for message in messages %}
|
| 27 |
+
{%- if (message['role'] != 'user') and (message['role'] != 'assistant') %}
|
| 28 |
+
{{- raise_exception('Conversation roles must be user or assistant') }}
|
| 29 |
+
{%- endif %}
|
| 30 |
+
|
| 31 |
+
{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] + '<|eot_id|>' }}
|
| 32 |
+
{%- endfor %}
|
| 33 |
+
|
| 34 |
+
{%- if add_generation_prompt %}
|
| 35 |
+
{{- '<|start_header_id|>' + 'assistant' + '<|end_header_id|>\n\n' }}
|
| 36 |
+
{%- endif %}
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
class EvaByteTokenizer(PreTrainedTokenizer):
|
| 40 |
+
def __init__(
|
| 41 |
+
self,
|
| 42 |
+
bos_token="<bos>",
|
| 43 |
+
eos_token="<eos>",
|
| 44 |
+
unk_token="<unk>",
|
| 45 |
+
sep_token="<sep>",
|
| 46 |
+
pad_token="<pad>",
|
| 47 |
+
extra_ids=59,
|
| 48 |
+
additional_special_tokens=None,
|
| 49 |
+
clean_up_tokenization_spaces=False,
|
| 50 |
+
**kwargs,
|
| 51 |
+
) -> None:
|
| 52 |
+
num_base_special_tokens = 5
|
| 53 |
+
# Add extra_ids to the special token list
|
| 54 |
+
if extra_ids > 0 and additional_special_tokens is None:
|
| 55 |
+
additional_special_tokens = [f"<extra_id_{i}>" for i in range(num_base_special_tokens, extra_ids + num_base_special_tokens)]
|
| 56 |
+
elif extra_ids > 0 and additional_special_tokens is not None and len(additional_special_tokens) > 0:
|
| 57 |
+
# Check that we have the right number of extra_id special tokens
|
| 58 |
+
extra_tokens = len(set(filter(lambda x: bool("extra_id" in str(x)), additional_special_tokens)))
|
| 59 |
+
if extra_tokens != extra_ids:
|
| 60 |
+
raise ValueError(
|
| 61 |
+
f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are"
|
| 62 |
+
" provided to EvaByteTokenizer. In this case the additional_special_tokens must include the"
|
| 63 |
+
" extra_ids tokens"
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
#### override some reserved tokens to support chat template
|
| 67 |
+
for i, token in enumerate(additional_special_tokens):
|
| 68 |
+
if token == "<extra_id_5>":
|
| 69 |
+
token = "<repo_name>"
|
| 70 |
+
elif token == "<extra_id_6>":
|
| 71 |
+
token = "<file_sep>"
|
| 72 |
+
elif token == "<extra_id_7>":
|
| 73 |
+
token = "<t2v_token>"
|
| 74 |
+
elif token == "<extra_id_8>":
|
| 75 |
+
token = "<v2t_token>"
|
| 76 |
+
elif token == "<extra_id_9>":
|
| 77 |
+
token = "<|start_header_id|>"
|
| 78 |
+
elif token == "<extra_id_10>":
|
| 79 |
+
token = "<|end_header_id|>"
|
| 80 |
+
elif token == "<extra_id_11>":
|
| 81 |
+
token = "<|eot_id|>"
|
| 82 |
+
additional_special_tokens[i] = token
|
| 83 |
+
|
| 84 |
+
# lstrip and rstrip are set to False because we don't want to strip the whitespace from the special tokens
|
| 85 |
+
# this would be important for the byte tokenizer
|
| 86 |
+
pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
|
| 87 |
+
bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
|
| 88 |
+
eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
|
| 89 |
+
unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
|
| 90 |
+
sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token
|
| 91 |
+
|
| 92 |
+
self._added_tokens_decoder = {
|
| 93 |
+
0: pad_token,
|
| 94 |
+
1: bos_token,
|
| 95 |
+
2: eos_token,
|
| 96 |
+
3: unk_token, # unk_token is a placeholder
|
| 97 |
+
4: sep_token,
|
| 98 |
+
**{i: AddedToken(t, lstrip=False, rstrip=False) for i, t in enumerate(additional_special_tokens, start=num_base_special_tokens)},
|
| 99 |
+
}
|
| 100 |
+
self.offset = len(self._added_tokens_decoder)
|
| 101 |
+
self._utf_vocab_size = 2**8 # utf is 8 bits
|
| 102 |
+
self.add_bos_token = True
|
| 103 |
+
self.add_eos_token = False
|
| 104 |
+
super().__init__(
|
| 105 |
+
pad_token=pad_token,
|
| 106 |
+
bos_token=bos_token,
|
| 107 |
+
eos_token=eos_token,
|
| 108 |
+
unk_token=unk_token,
|
| 109 |
+
sep_token=sep_token,
|
| 110 |
+
extra_ids=0,
|
| 111 |
+
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
| 112 |
+
additional_special_tokens=additional_special_tokens,
|
| 113 |
+
**kwargs,
|
| 114 |
+
)
|
| 115 |
+
self.chat_template = chat_template
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
@property
|
| 119 |
+
def vocab_size(self):
|
| 120 |
+
return self._utf_vocab_size
|
| 121 |
+
|
| 122 |
+
def get_vocab(self):
|
| 123 |
+
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size + self.offset)}
|
| 124 |
+
vocab.update(self.added_tokens_encoder)
|
| 125 |
+
return vocab
|
| 126 |
+
|
| 127 |
+
# Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.build_inputs_with_special_tokens
|
| 128 |
+
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
|
| 129 |
+
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
|
| 130 |
+
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
|
| 131 |
+
|
| 132 |
+
output = bos_token_id + token_ids_0 + eos_token_id
|
| 133 |
+
|
| 134 |
+
if token_ids_1 is not None:
|
| 135 |
+
output = output + bos_token_id + token_ids_1 + eos_token_id
|
| 136 |
+
|
| 137 |
+
return output
|
| 138 |
+
|
| 139 |
+
# Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.get_special_tokens_mask
|
| 140 |
+
def get_special_tokens_mask(
|
| 141 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
| 142 |
+
) -> List[int]:
|
| 143 |
+
"""
|
| 144 |
+
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
|
| 145 |
+
special tokens using the tokenizer `prepare_for_model` method.
|
| 146 |
+
|
| 147 |
+
Args:
|
| 148 |
+
token_ids_0 (`List[int]`):
|
| 149 |
+
List of IDs.
|
| 150 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 151 |
+
Optional second list of IDs for sequence pairs.
|
| 152 |
+
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
|
| 153 |
+
Whether or not the token list is already formatted with special tokens for the model.
|
| 154 |
+
|
| 155 |
+
Returns:
|
| 156 |
+
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
| 157 |
+
"""
|
| 158 |
+
if already_has_special_tokens:
|
| 159 |
+
return super().get_special_tokens_mask(
|
| 160 |
+
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
bos_token_id = [1] if self.add_bos_token else []
|
| 164 |
+
eos_token_id = [1] if self.add_eos_token else []
|
| 165 |
+
|
| 166 |
+
if token_ids_1 is None:
|
| 167 |
+
return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
|
| 168 |
+
return (
|
| 169 |
+
bos_token_id
|
| 170 |
+
+ ([0] * len(token_ids_0))
|
| 171 |
+
+ eos_token_id
|
| 172 |
+
+ bos_token_id
|
| 173 |
+
+ ([0] * len(token_ids_1))
|
| 174 |
+
+ eos_token_id
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
# Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.create_token_type_ids_from_sequences
|
| 178 |
+
def create_token_type_ids_from_sequences(
|
| 179 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
| 180 |
+
) -> List[int]:
|
| 181 |
+
"""
|
| 182 |
+
Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
|
| 183 |
+
sequence pair mask has the following format:
|
| 184 |
+
|
| 185 |
+
```
|
| 186 |
+
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
|
| 187 |
+
| first sequence | second sequence |
|
| 188 |
+
```
|
| 189 |
+
|
| 190 |
+
if token_ids_1 is None, only returns the first portion of the mask (0s).
|
| 191 |
+
|
| 192 |
+
Args:
|
| 193 |
+
token_ids_0 (`List[int]`):
|
| 194 |
+
List of ids.
|
| 195 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 196 |
+
Optional second list of IDs for sequence pairs.
|
| 197 |
+
|
| 198 |
+
Returns:
|
| 199 |
+
`List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
|
| 200 |
+
"""
|
| 201 |
+
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
|
| 202 |
+
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
|
| 203 |
+
|
| 204 |
+
output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
|
| 205 |
+
|
| 206 |
+
if token_ids_1 is not None:
|
| 207 |
+
output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
|
| 208 |
+
|
| 209 |
+
return output
|
| 210 |
+
|
| 211 |
+
def _tokenize(self, text: str) -> List[str]:
|
| 212 |
+
"""Take as input a string and return a list of strings (tokens) for words/sub-words"""
|
| 213 |
+
tokens = [chr(i) for i in text.encode("utf-8")]
|
| 214 |
+
return tokens
|
| 215 |
+
|
| 216 |
+
def _convert_token_to_id(self, token):
|
| 217 |
+
"""Converts a token (str) in an id using the vocab."""
|
| 218 |
+
|
| 219 |
+
if len(token) != 1:
|
| 220 |
+
token_id = None
|
| 221 |
+
else:
|
| 222 |
+
token_id = ord(token) + self.offset
|
| 223 |
+
|
| 224 |
+
return token_id
|
| 225 |
+
|
| 226 |
+
def _convert_id_to_token(self, index):
|
| 227 |
+
"""Converts an index (integer) to a byte (str) using the vocab."""
|
| 228 |
+
token = chr(index - self.offset)
|
| 229 |
+
return token
|
| 230 |
+
|
| 231 |
+
def convert_tokens_to_string(self, tokens):
|
| 232 |
+
"""Converts a sequence of bytes (string) to a single string."""
|
| 233 |
+
bstring = b""
|
| 234 |
+
for token in tokens:
|
| 235 |
+
if token in self.added_tokens_decoder:
|
| 236 |
+
tok_string = self.added_tokens_decoder[token].encode("utf-8")
|
| 237 |
+
elif token in self.added_tokens_encoder:
|
| 238 |
+
tok_string = token.encode("utf-8")
|
| 239 |
+
else:
|
| 240 |
+
tok_string = bytes([ord(token)])
|
| 241 |
+
bstring += tok_string
|
| 242 |
+
string = bstring.decode("utf-8", errors="ignore")
|
| 243 |
+
return string
|
| 244 |
+
|
| 245 |
+
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
| 246 |
+
return ()
|
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/tokenizer_config.json
ADDED
|
@@ -0,0 +1,596 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"added_tokens_decoder": {
|
| 3 |
+
"0": {
|
| 4 |
+
"content": "<pad>",
|
| 5 |
+
"lstrip": false,
|
| 6 |
+
"normalized": true,
|
| 7 |
+
"rstrip": false,
|
| 8 |
+
"single_word": false,
|
| 9 |
+
"special": true
|
| 10 |
+
},
|
| 11 |
+
"1": {
|
| 12 |
+
"content": "<bos>",
|
| 13 |
+
"lstrip": false,
|
| 14 |
+
"normalized": true,
|
| 15 |
+
"rstrip": false,
|
| 16 |
+
"single_word": false,
|
| 17 |
+
"special": true
|
| 18 |
+
},
|
| 19 |
+
"2": {
|
| 20 |
+
"content": "<eos>",
|
| 21 |
+
"lstrip": false,
|
| 22 |
+
"normalized": true,
|
| 23 |
+
"rstrip": false,
|
| 24 |
+
"single_word": false,
|
| 25 |
+
"special": true
|
| 26 |
+
},
|
| 27 |
+
"3": {
|
| 28 |
+
"content": "<unk>",
|
| 29 |
+
"lstrip": false,
|
| 30 |
+
"normalized": true,
|
| 31 |
+
"rstrip": false,
|
| 32 |
+
"single_word": false,
|
| 33 |
+
"special": true
|
| 34 |
+
},
|
| 35 |
+
"4": {
|
| 36 |
+
"content": "<sep>",
|
| 37 |
+
"lstrip": false,
|
| 38 |
+
"normalized": true,
|
| 39 |
+
"rstrip": false,
|
| 40 |
+
"single_word": false,
|
| 41 |
+
"special": true
|
| 42 |
+
},
|
| 43 |
+
"5": {
|
| 44 |
+
"content": "<repo_name>",
|
| 45 |
+
"lstrip": false,
|
| 46 |
+
"normalized": true,
|
| 47 |
+
"rstrip": false,
|
| 48 |
+
"single_word": false,
|
| 49 |
+
"special": false
|
| 50 |
+
},
|
| 51 |
+
"6": {
|
| 52 |
+
"content": "<file_sep>",
|
| 53 |
+
"lstrip": false,
|
| 54 |
+
"normalized": true,
|
| 55 |
+
"rstrip": false,
|
| 56 |
+
"single_word": false,
|
| 57 |
+
"special": false
|
| 58 |
+
},
|
| 59 |
+
"7": {
|
| 60 |
+
"content": "<t2v_token>",
|
| 61 |
+
"lstrip": false,
|
| 62 |
+
"normalized": true,
|
| 63 |
+
"rstrip": false,
|
| 64 |
+
"single_word": false,
|
| 65 |
+
"special": false
|
| 66 |
+
},
|
| 67 |
+
"8": {
|
| 68 |
+
"content": "<v2t_token>",
|
| 69 |
+
"lstrip": false,
|
| 70 |
+
"normalized": true,
|
| 71 |
+
"rstrip": false,
|
| 72 |
+
"single_word": false,
|
| 73 |
+
"special": false
|
| 74 |
+
},
|
| 75 |
+
"9": {
|
| 76 |
+
"content": "<|start_header_id|>",
|
| 77 |
+
"lstrip": false,
|
| 78 |
+
"normalized": true,
|
| 79 |
+
"rstrip": false,
|
| 80 |
+
"single_word": false,
|
| 81 |
+
"special": false
|
| 82 |
+
},
|
| 83 |
+
"10": {
|
| 84 |
+
"content": "<|end_header_id|>",
|
| 85 |
+
"lstrip": false,
|
| 86 |
+
"normalized": true,
|
| 87 |
+
"rstrip": false,
|
| 88 |
+
"single_word": false,
|
| 89 |
+
"special": false
|
| 90 |
+
},
|
| 91 |
+
"11": {
|
| 92 |
+
"content": "<|eot_id|>",
|
| 93 |
+
"lstrip": false,
|
| 94 |
+
"normalized": true,
|
| 95 |
+
"rstrip": false,
|
| 96 |
+
"single_word": false,
|
| 97 |
+
"special": false
|
| 98 |
+
},
|
| 99 |
+
"12": {
|
| 100 |
+
"content": "<extra_id_12>",
|
| 101 |
+
"lstrip": false,
|
| 102 |
+
"normalized": true,
|
| 103 |
+
"rstrip": false,
|
| 104 |
+
"single_word": false,
|
| 105 |
+
"special": false
|
| 106 |
+
},
|
| 107 |
+
"13": {
|
| 108 |
+
"content": "<extra_id_13>",
|
| 109 |
+
"lstrip": false,
|
| 110 |
+
"normalized": true,
|
| 111 |
+
"rstrip": false,
|
| 112 |
+
"single_word": false,
|
| 113 |
+
"special": false
|
| 114 |
+
},
|
| 115 |
+
"14": {
|
| 116 |
+
"content": "<extra_id_14>",
|
| 117 |
+
"lstrip": false,
|
| 118 |
+
"normalized": true,
|
| 119 |
+
"rstrip": false,
|
| 120 |
+
"single_word": false,
|
| 121 |
+
"special": false
|
| 122 |
+
},
|
| 123 |
+
"15": {
|
| 124 |
+
"content": "<extra_id_15>",
|
| 125 |
+
"lstrip": false,
|
| 126 |
+
"normalized": true,
|
| 127 |
+
"rstrip": false,
|
| 128 |
+
"single_word": false,
|
| 129 |
+
"special": false
|
| 130 |
+
},
|
| 131 |
+
"16": {
|
| 132 |
+
"content": "<extra_id_16>",
|
| 133 |
+
"lstrip": false,
|
| 134 |
+
"normalized": true,
|
| 135 |
+
"rstrip": false,
|
| 136 |
+
"single_word": false,
|
| 137 |
+
"special": false
|
| 138 |
+
},
|
| 139 |
+
"17": {
|
| 140 |
+
"content": "<extra_id_17>",
|
| 141 |
+
"lstrip": false,
|
| 142 |
+
"normalized": true,
|
| 143 |
+
"rstrip": false,
|
| 144 |
+
"single_word": false,
|
| 145 |
+
"special": false
|
| 146 |
+
},
|
| 147 |
+
"18": {
|
| 148 |
+
"content": "<extra_id_18>",
|
| 149 |
+
"lstrip": false,
|
| 150 |
+
"normalized": true,
|
| 151 |
+
"rstrip": false,
|
| 152 |
+
"single_word": false,
|
| 153 |
+
"special": false
|
| 154 |
+
},
|
| 155 |
+
"19": {
|
| 156 |
+
"content": "<extra_id_19>",
|
| 157 |
+
"lstrip": false,
|
| 158 |
+
"normalized": true,
|
| 159 |
+
"rstrip": false,
|
| 160 |
+
"single_word": false,
|
| 161 |
+
"special": false
|
| 162 |
+
},
|
| 163 |
+
"20": {
|
| 164 |
+
"content": "<extra_id_20>",
|
| 165 |
+
"lstrip": false,
|
| 166 |
+
"normalized": true,
|
| 167 |
+
"rstrip": false,
|
| 168 |
+
"single_word": false,
|
| 169 |
+
"special": false
|
| 170 |
+
},
|
| 171 |
+
"21": {
|
| 172 |
+
"content": "<extra_id_21>",
|
| 173 |
+
"lstrip": false,
|
| 174 |
+
"normalized": true,
|
| 175 |
+
"rstrip": false,
|
| 176 |
+
"single_word": false,
|
| 177 |
+
"special": false
|
| 178 |
+
},
|
| 179 |
+
"22": {
|
| 180 |
+
"content": "<extra_id_22>",
|
| 181 |
+
"lstrip": false,
|
| 182 |
+
"normalized": true,
|
| 183 |
+
"rstrip": false,
|
| 184 |
+
"single_word": false,
|
| 185 |
+
"special": false
|
| 186 |
+
},
|
| 187 |
+
"23": {
|
| 188 |
+
"content": "<extra_id_23>",
|
| 189 |
+
"lstrip": false,
|
| 190 |
+
"normalized": true,
|
| 191 |
+
"rstrip": false,
|
| 192 |
+
"single_word": false,
|
| 193 |
+
"special": false
|
| 194 |
+
},
|
| 195 |
+
"24": {
|
| 196 |
+
"content": "<extra_id_24>",
|
| 197 |
+
"lstrip": false,
|
| 198 |
+
"normalized": true,
|
| 199 |
+
"rstrip": false,
|
| 200 |
+
"single_word": false,
|
| 201 |
+
"special": false
|
| 202 |
+
},
|
| 203 |
+
"25": {
|
| 204 |
+
"content": "<extra_id_25>",
|
| 205 |
+
"lstrip": false,
|
| 206 |
+
"normalized": true,
|
| 207 |
+
"rstrip": false,
|
| 208 |
+
"single_word": false,
|
| 209 |
+
"special": false
|
| 210 |
+
},
|
| 211 |
+
"26": {
|
| 212 |
+
"content": "<extra_id_26>",
|
| 213 |
+
"lstrip": false,
|
| 214 |
+
"normalized": true,
|
| 215 |
+
"rstrip": false,
|
| 216 |
+
"single_word": false,
|
| 217 |
+
"special": false
|
| 218 |
+
},
|
| 219 |
+
"27": {
|
| 220 |
+
"content": "<extra_id_27>",
|
| 221 |
+
"lstrip": false,
|
| 222 |
+
"normalized": true,
|
| 223 |
+
"rstrip": false,
|
| 224 |
+
"single_word": false,
|
| 225 |
+
"special": false
|
| 226 |
+
},
|
| 227 |
+
"28": {
|
| 228 |
+
"content": "<extra_id_28>",
|
| 229 |
+
"lstrip": false,
|
| 230 |
+
"normalized": true,
|
| 231 |
+
"rstrip": false,
|
| 232 |
+
"single_word": false,
|
| 233 |
+
"special": false
|
| 234 |
+
},
|
| 235 |
+
"29": {
|
| 236 |
+
"content": "<extra_id_29>",
|
| 237 |
+
"lstrip": false,
|
| 238 |
+
"normalized": true,
|
| 239 |
+
"rstrip": false,
|
| 240 |
+
"single_word": false,
|
| 241 |
+
"special": false
|
| 242 |
+
},
|
| 243 |
+
"30": {
|
| 244 |
+
"content": "<extra_id_30>",
|
| 245 |
+
"lstrip": false,
|
| 246 |
+
"normalized": true,
|
| 247 |
+
"rstrip": false,
|
| 248 |
+
"single_word": false,
|
| 249 |
+
"special": false
|
| 250 |
+
},
|
| 251 |
+
"31": {
|
| 252 |
+
"content": "<extra_id_31>",
|
| 253 |
+
"lstrip": false,
|
| 254 |
+
"normalized": true,
|
| 255 |
+
"rstrip": false,
|
| 256 |
+
"single_word": false,
|
| 257 |
+
"special": false
|
| 258 |
+
},
|
| 259 |
+
"32": {
|
| 260 |
+
"content": "<extra_id_32>",
|
| 261 |
+
"lstrip": false,
|
| 262 |
+
"normalized": true,
|
| 263 |
+
"rstrip": false,
|
| 264 |
+
"single_word": false,
|
| 265 |
+
"special": false
|
| 266 |
+
},
|
| 267 |
+
"33": {
|
| 268 |
+
"content": "<extra_id_33>",
|
| 269 |
+
"lstrip": false,
|
| 270 |
+
"normalized": true,
|
| 271 |
+
"rstrip": false,
|
| 272 |
+
"single_word": false,
|
| 273 |
+
"special": false
|
| 274 |
+
},
|
| 275 |
+
"34": {
|
| 276 |
+
"content": "<extra_id_34>",
|
| 277 |
+
"lstrip": false,
|
| 278 |
+
"normalized": true,
|
| 279 |
+
"rstrip": false,
|
| 280 |
+
"single_word": false,
|
| 281 |
+
"special": false
|
| 282 |
+
},
|
| 283 |
+
"35": {
|
| 284 |
+
"content": "<extra_id_35>",
|
| 285 |
+
"lstrip": false,
|
| 286 |
+
"normalized": true,
|
| 287 |
+
"rstrip": false,
|
| 288 |
+
"single_word": false,
|
| 289 |
+
"special": false
|
| 290 |
+
},
|
| 291 |
+
"36": {
|
| 292 |
+
"content": "<extra_id_36>",
|
| 293 |
+
"lstrip": false,
|
| 294 |
+
"normalized": true,
|
| 295 |
+
"rstrip": false,
|
| 296 |
+
"single_word": false,
|
| 297 |
+
"special": false
|
| 298 |
+
},
|
| 299 |
+
"37": {
|
| 300 |
+
"content": "<extra_id_37>",
|
| 301 |
+
"lstrip": false,
|
| 302 |
+
"normalized": true,
|
| 303 |
+
"rstrip": false,
|
| 304 |
+
"single_word": false,
|
| 305 |
+
"special": false
|
| 306 |
+
},
|
| 307 |
+
"38": {
|
| 308 |
+
"content": "<extra_id_38>",
|
| 309 |
+
"lstrip": false,
|
| 310 |
+
"normalized": true,
|
| 311 |
+
"rstrip": false,
|
| 312 |
+
"single_word": false,
|
| 313 |
+
"special": false
|
| 314 |
+
},
|
| 315 |
+
"39": {
|
| 316 |
+
"content": "<extra_id_39>",
|
| 317 |
+
"lstrip": false,
|
| 318 |
+
"normalized": true,
|
| 319 |
+
"rstrip": false,
|
| 320 |
+
"single_word": false,
|
| 321 |
+
"special": false
|
| 322 |
+
},
|
| 323 |
+
"40": {
|
| 324 |
+
"content": "<extra_id_40>",
|
| 325 |
+
"lstrip": false,
|
| 326 |
+
"normalized": true,
|
| 327 |
+
"rstrip": false,
|
| 328 |
+
"single_word": false,
|
| 329 |
+
"special": false
|
| 330 |
+
},
|
| 331 |
+
"41": {
|
| 332 |
+
"content": "<extra_id_41>",
|
| 333 |
+
"lstrip": false,
|
| 334 |
+
"normalized": true,
|
| 335 |
+
"rstrip": false,
|
| 336 |
+
"single_word": false,
|
| 337 |
+
"special": false
|
| 338 |
+
},
|
| 339 |
+
"42": {
|
| 340 |
+
"content": "<extra_id_42>",
|
| 341 |
+
"lstrip": false,
|
| 342 |
+
"normalized": true,
|
| 343 |
+
"rstrip": false,
|
| 344 |
+
"single_word": false,
|
| 345 |
+
"special": false
|
| 346 |
+
},
|
| 347 |
+
"43": {
|
| 348 |
+
"content": "<extra_id_43>",
|
| 349 |
+
"lstrip": false,
|
| 350 |
+
"normalized": true,
|
| 351 |
+
"rstrip": false,
|
| 352 |
+
"single_word": false,
|
| 353 |
+
"special": false
|
| 354 |
+
},
|
| 355 |
+
"44": {
|
| 356 |
+
"content": "<extra_id_44>",
|
| 357 |
+
"lstrip": false,
|
| 358 |
+
"normalized": true,
|
| 359 |
+
"rstrip": false,
|
| 360 |
+
"single_word": false,
|
| 361 |
+
"special": false
|
| 362 |
+
},
|
| 363 |
+
"45": {
|
| 364 |
+
"content": "<extra_id_45>",
|
| 365 |
+
"lstrip": false,
|
| 366 |
+
"normalized": true,
|
| 367 |
+
"rstrip": false,
|
| 368 |
+
"single_word": false,
|
| 369 |
+
"special": false
|
| 370 |
+
},
|
| 371 |
+
"46": {
|
| 372 |
+
"content": "<extra_id_46>",
|
| 373 |
+
"lstrip": false,
|
| 374 |
+
"normalized": true,
|
| 375 |
+
"rstrip": false,
|
| 376 |
+
"single_word": false,
|
| 377 |
+
"special": false
|
| 378 |
+
},
|
| 379 |
+
"47": {
|
| 380 |
+
"content": "<extra_id_47>",
|
| 381 |
+
"lstrip": false,
|
| 382 |
+
"normalized": true,
|
| 383 |
+
"rstrip": false,
|
| 384 |
+
"single_word": false,
|
| 385 |
+
"special": false
|
| 386 |
+
},
|
| 387 |
+
"48": {
|
| 388 |
+
"content": "<extra_id_48>",
|
| 389 |
+
"lstrip": false,
|
| 390 |
+
"normalized": true,
|
| 391 |
+
"rstrip": false,
|
| 392 |
+
"single_word": false,
|
| 393 |
+
"special": false
|
| 394 |
+
},
|
| 395 |
+
"49": {
|
| 396 |
+
"content": "<extra_id_49>",
|
| 397 |
+
"lstrip": false,
|
| 398 |
+
"normalized": true,
|
| 399 |
+
"rstrip": false,
|
| 400 |
+
"single_word": false,
|
| 401 |
+
"special": false
|
| 402 |
+
},
|
| 403 |
+
"50": {
|
| 404 |
+
"content": "<extra_id_50>",
|
| 405 |
+
"lstrip": false,
|
| 406 |
+
"normalized": true,
|
| 407 |
+
"rstrip": false,
|
| 408 |
+
"single_word": false,
|
| 409 |
+
"special": false
|
| 410 |
+
},
|
| 411 |
+
"51": {
|
| 412 |
+
"content": "<extra_id_51>",
|
| 413 |
+
"lstrip": false,
|
| 414 |
+
"normalized": true,
|
| 415 |
+
"rstrip": false,
|
| 416 |
+
"single_word": false,
|
| 417 |
+
"special": false
|
| 418 |
+
},
|
| 419 |
+
"52": {
|
| 420 |
+
"content": "<extra_id_52>",
|
| 421 |
+
"lstrip": false,
|
| 422 |
+
"normalized": true,
|
| 423 |
+
"rstrip": false,
|
| 424 |
+
"single_word": false,
|
| 425 |
+
"special": false
|
| 426 |
+
},
|
| 427 |
+
"53": {
|
| 428 |
+
"content": "<extra_id_53>",
|
| 429 |
+
"lstrip": false,
|
| 430 |
+
"normalized": true,
|
| 431 |
+
"rstrip": false,
|
| 432 |
+
"single_word": false,
|
| 433 |
+
"special": false
|
| 434 |
+
},
|
| 435 |
+
"54": {
|
| 436 |
+
"content": "<extra_id_54>",
|
| 437 |
+
"lstrip": false,
|
| 438 |
+
"normalized": true,
|
| 439 |
+
"rstrip": false,
|
| 440 |
+
"single_word": false,
|
| 441 |
+
"special": false
|
| 442 |
+
},
|
| 443 |
+
"55": {
|
| 444 |
+
"content": "<extra_id_55>",
|
| 445 |
+
"lstrip": false,
|
| 446 |
+
"normalized": true,
|
| 447 |
+
"rstrip": false,
|
| 448 |
+
"single_word": false,
|
| 449 |
+
"special": false
|
| 450 |
+
},
|
| 451 |
+
"56": {
|
| 452 |
+
"content": "<extra_id_56>",
|
| 453 |
+
"lstrip": false,
|
| 454 |
+
"normalized": true,
|
| 455 |
+
"rstrip": false,
|
| 456 |
+
"single_word": false,
|
| 457 |
+
"special": false
|
| 458 |
+
},
|
| 459 |
+
"57": {
|
| 460 |
+
"content": "<extra_id_57>",
|
| 461 |
+
"lstrip": false,
|
| 462 |
+
"normalized": true,
|
| 463 |
+
"rstrip": false,
|
| 464 |
+
"single_word": false,
|
| 465 |
+
"special": false
|
| 466 |
+
},
|
| 467 |
+
"58": {
|
| 468 |
+
"content": "<extra_id_58>",
|
| 469 |
+
"lstrip": false,
|
| 470 |
+
"normalized": true,
|
| 471 |
+
"rstrip": false,
|
| 472 |
+
"single_word": false,
|
| 473 |
+
"special": false
|
| 474 |
+
},
|
| 475 |
+
"59": {
|
| 476 |
+
"content": "<extra_id_59>",
|
| 477 |
+
"lstrip": false,
|
| 478 |
+
"normalized": true,
|
| 479 |
+
"rstrip": false,
|
| 480 |
+
"single_word": false,
|
| 481 |
+
"special": false
|
| 482 |
+
},
|
| 483 |
+
"60": {
|
| 484 |
+
"content": "<extra_id_60>",
|
| 485 |
+
"lstrip": false,
|
| 486 |
+
"normalized": true,
|
| 487 |
+
"rstrip": false,
|
| 488 |
+
"single_word": false,
|
| 489 |
+
"special": false
|
| 490 |
+
},
|
| 491 |
+
"61": {
|
| 492 |
+
"content": "<extra_id_61>",
|
| 493 |
+
"lstrip": false,
|
| 494 |
+
"normalized": true,
|
| 495 |
+
"rstrip": false,
|
| 496 |
+
"single_word": false,
|
| 497 |
+
"special": false
|
| 498 |
+
},
|
| 499 |
+
"62": {
|
| 500 |
+
"content": "<extra_id_62>",
|
| 501 |
+
"lstrip": false,
|
| 502 |
+
"normalized": true,
|
| 503 |
+
"rstrip": false,
|
| 504 |
+
"single_word": false,
|
| 505 |
+
"special": false
|
| 506 |
+
},
|
| 507 |
+
"63": {
|
| 508 |
+
"content": "<extra_id_63>",
|
| 509 |
+
"lstrip": false,
|
| 510 |
+
"normalized": true,
|
| 511 |
+
"rstrip": false,
|
| 512 |
+
"single_word": false,
|
| 513 |
+
"special": false
|
| 514 |
+
}
|
| 515 |
+
},
|
| 516 |
+
"additional_special_tokens": [
|
| 517 |
+
"<repo_name>",
|
| 518 |
+
"<file_sep>",
|
| 519 |
+
"<t2v_token>",
|
| 520 |
+
"<v2t_token>",
|
| 521 |
+
"<|start_header_id|>",
|
| 522 |
+
"<|end_header_id|>",
|
| 523 |
+
"<|eot_id|>",
|
| 524 |
+
"<extra_id_12>",
|
| 525 |
+
"<extra_id_13>",
|
| 526 |
+
"<extra_id_14>",
|
| 527 |
+
"<extra_id_15>",
|
| 528 |
+
"<extra_id_16>",
|
| 529 |
+
"<extra_id_17>",
|
| 530 |
+
"<extra_id_18>",
|
| 531 |
+
"<extra_id_19>",
|
| 532 |
+
"<extra_id_20>",
|
| 533 |
+
"<extra_id_21>",
|
| 534 |
+
"<extra_id_22>",
|
| 535 |
+
"<extra_id_23>",
|
| 536 |
+
"<extra_id_24>",
|
| 537 |
+
"<extra_id_25>",
|
| 538 |
+
"<extra_id_26>",
|
| 539 |
+
"<extra_id_27>",
|
| 540 |
+
"<extra_id_28>",
|
| 541 |
+
"<extra_id_29>",
|
| 542 |
+
"<extra_id_30>",
|
| 543 |
+
"<extra_id_31>",
|
| 544 |
+
"<extra_id_32>",
|
| 545 |
+
"<extra_id_33>",
|
| 546 |
+
"<extra_id_34>",
|
| 547 |
+
"<extra_id_35>",
|
| 548 |
+
"<extra_id_36>",
|
| 549 |
+
"<extra_id_37>",
|
| 550 |
+
"<extra_id_38>",
|
| 551 |
+
"<extra_id_39>",
|
| 552 |
+
"<extra_id_40>",
|
| 553 |
+
"<extra_id_41>",
|
| 554 |
+
"<extra_id_42>",
|
| 555 |
+
"<extra_id_43>",
|
| 556 |
+
"<extra_id_44>",
|
| 557 |
+
"<extra_id_45>",
|
| 558 |
+
"<extra_id_46>",
|
| 559 |
+
"<extra_id_47>",
|
| 560 |
+
"<extra_id_48>",
|
| 561 |
+
"<extra_id_49>",
|
| 562 |
+
"<extra_id_50>",
|
| 563 |
+
"<extra_id_51>",
|
| 564 |
+
"<extra_id_52>",
|
| 565 |
+
"<extra_id_53>",
|
| 566 |
+
"<extra_id_54>",
|
| 567 |
+
"<extra_id_55>",
|
| 568 |
+
"<extra_id_56>",
|
| 569 |
+
"<extra_id_57>",
|
| 570 |
+
"<extra_id_58>",
|
| 571 |
+
"<extra_id_59>",
|
| 572 |
+
"<extra_id_60>",
|
| 573 |
+
"<extra_id_61>",
|
| 574 |
+
"<extra_id_62>",
|
| 575 |
+
"<extra_id_63>"
|
| 576 |
+
],
|
| 577 |
+
"auto_map": {
|
| 578 |
+
"AutoProcessor": "processing_evabyte.EvaByteProcessor",
|
| 579 |
+
"AutoTokenizer": [
|
| 580 |
+
"tokenization_evabyte.EvaByteTokenizer",
|
| 581 |
+
null
|
| 582 |
+
]
|
| 583 |
+
},
|
| 584 |
+
"bos_token": "<bos>",
|
| 585 |
+
"chat_template": "\n{{- bos_token }}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content'] %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{{- '<|start_header_id|>system<|end_header_id|>\n\n' + system_message + '<|eot_id|>'}}\n\n{%- for message in messages %}\n {%- if (message['role'] != 'user') and (message['role'] != 'assistant') %}\n {{- raise_exception('Conversation roles must be user or assistant') }}\n {%- endif %}\n\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] + '<|eot_id|>' }}\n{%- endfor %}\n\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>' + 'assistant' + '<|end_header_id|>\n\n' }}\n{%- endif %}\n",
|
| 586 |
+
"clean_up_tokenization_spaces": false,
|
| 587 |
+
"eos_token": "<eos>",
|
| 588 |
+
"extra_ids": 0,
|
| 589 |
+
"extra_special_tokens": {},
|
| 590 |
+
"model_max_length": 1000000000000000019884624838656,
|
| 591 |
+
"pad_token": "<pad>",
|
| 592 |
+
"processor_class": "EvaByteProcessor",
|
| 593 |
+
"sep_token": "<sep>",
|
| 594 |
+
"tokenizer_class": "EvaByteTokenizer",
|
| 595 |
+
"unk_token": "<unk>"
|
| 596 |
+
}
|
ckpts/ocpython_14b_bsz-2m_seq16k_docmask_multipredc2r8_90dynamic-10raw_transsentinel_minsize0ent98line16ow16pack_100B_2m_new_2_step-10000/README.md
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
---
|
| 4 |
+
# EvaByte Model Card
|
| 5 |
+
|
| 6 |
+
**EvaByte** is a 6.5B **byte-level language model** built upon an improved architecture with multibyte prediction and EVA -- an efficient attention mechanism designed for scalability and performance. Trained on 1.5T bytes spanning natural language text, math, and code, EvaByte demonstrates the viability of efficient byte-level processing at scale -- rivaling top open-source tokenizer-based LMs using 5x less training data, excelling in coding tasks, and decoding up to 2x faster.
|
| 7 |
+
|
| 8 |
+
## Model Resources
|
| 9 |
+
|
| 10 |
+
- **Repository:** https://github.com/openevabyte/evabyte
|
| 11 |
+
- **Blog:** https://hkunlp.github.io/blog/2025/evabyte and https://sambanova.ai/blog/evabyte-efficient-byte-level-language-models-at-scale
|
| 12 |
+
- **Paper:** Coming soon
|
| 13 |
+
|
| 14 |
+
## Model Details
|
| 15 |
+
|
| 16 |
+
EvaByte is trained using the performant SambaNova SN30 RDU system with a batch size of 8M bytes and 32K context length. The training process consists of 3 phases: after pre-training on 1.2T bytes (yielding **EvaByte-Phase1**), two independent annealing runs (100B and 200B bytes respectively) are conducted with learning rate linearly decayed from 1e-4 to 0. The resulting checkpoints are merged via model soup (**EvaByte**), which then undergoes supervised fine-tuning (**EvaByte-SFT**).
|
| 17 |
+
|
| 18 |
+
| Stage | Model |
|
| 19 |
+
|:----- |:-----|
|
| 20 |
+
| Base (before annealing) | [EvaByte-Phase1](https://huggingface.co/evabyte/EvaByte-Phase1) |
|
| 21 |
+
| Base | [EvaByte](https://huggingface.co/evabyte/EvaByte) <-- you are here |
|
| 22 |
+
| SFT | [EvaByte-SFT](https://huggingface.co/evabyte/EvaByte-SFT) |
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
## Usage
|
| 26 |
+
|
| 27 |
+
**Note:** Make sure to set `trust_remote_code=True` when loading the model (or tokenizer), as our implementation includes custom code.
|
| 28 |
+
|
| 29 |
+
The code snippet below demonstrates EvaByte-6.5B for completion:
|
| 30 |
+
|
| 31 |
+
```python
|
| 32 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 33 |
+
import torch
|
| 34 |
+
|
| 35 |
+
# Load model and tokenizer
|
| 36 |
+
tokenizer = AutoTokenizer.from_pretrained("evabyte/EvaByte", trust_remote_code=True)
|
| 37 |
+
model = AutoModelForCausalLM.from_pretrained("evabyte/EvaByte", torch_dtype=torch.bfloat16, trust_remote_code=True).eval().to("cuda")
|
| 38 |
+
|
| 39 |
+
prompt = "The quick brown fox jumps "
|
| 40 |
+
|
| 41 |
+
# Tokenize input
|
| 42 |
+
# Option 1: standard HF tokenizer interface
|
| 43 |
+
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
|
| 44 |
+
|
| 45 |
+
# Option 2: Direct UTF-8 byte encoding with offset
|
| 46 |
+
# Note: Each byte is offset by 64 with <bos> prepended.
|
| 47 |
+
input_ids = torch.tensor([[1] + [b + 64 for b in prompt.encode("utf-8")]]).to("cuda")
|
| 48 |
+
|
| 49 |
+
# byte-by-byte generation (default)
|
| 50 |
+
generation_output = model.generate(
|
| 51 |
+
input_ids=input_ids,
|
| 52 |
+
max_new_tokens=32
|
| 53 |
+
)
|
| 54 |
+
# alternatively, use faster multibyte generation
|
| 55 |
+
generation_output = model.multi_byte_generate(
|
| 56 |
+
input_ids=input_ids,
|
| 57 |
+
max_new_tokens=32
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
# Decode and print the output
|
| 61 |
+
response = tokenizer.decode(
|
| 62 |
+
generation_output[0][input_ids.shape[1]:],
|
| 63 |
+
skip_special_tokens=False,
|
| 64 |
+
clean_up_tokenization_spaces=False
|
| 65 |
+
)
|
| 66 |
+
print(response)
|
| 67 |
+
# Sample output:
|
| 68 |
+
# over the lazy dog.\n\nThe quick
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
### ⚙️ Generation Modes
|
| 72 |
+
|
| 73 |
+
EvaByte supports two generation interfaces:
|
| 74 |
+
- `model.generate()`: The default generation method compatible with Huggingface `transformers` library. This approach generates one byte at a time and might be slow.
|
| 75 |
+
- `model.multi_byte_generate()`: A faster alternative that generates multiple bytes per step and usually yields the same result as `model.generate()` under greedy decoding, with the implementation adapted from [Medusa](https://github.com/FasterDecoding/Medusa). `model.multi_byte_generate()` supports a subset of arguments in `model.generate()`:
|
| 76 |
+
- `input_ids`: the input byte ids.
|
| 77 |
+
- `temperature`: the temperature for sampling.
|
| 78 |
+
- `max_length`: the maximum length of the generated sequence.
|
| 79 |
+
- `max_new_tokens`: the maximum number of new bytes to generate.
|
| 80 |
+
- `stopping_criteria`: the [stopping criteria](https://huggingface.co/docs/transformers/v4.47.1/en/internal/generation_utils#transformers.StoppingCriteria) for generation.
|
| 81 |
+
- `top_p`: the top-p parameter for sampling.
|
| 82 |
+
- `do_sample`: greedy decoding or sampling.
|
| 83 |
+
|
| 84 |
+
**Notes and Limitations:**
|
| 85 |
+
- `device_map="auto"` is not supported for >2 GPUs.
|
| 86 |
+
- Only batch size of 1 (with `attention_mask=None`) is supported for decoding.
|
| 87 |
+
- `torch_dtype=torch.bfloat16` is required.
|
| 88 |
+
- The multibyte generation `model.multi_byte_generate()` might return extra bytes after the end-of-sequence sentinel, due to the nature of the multibyte decoding. Manual truncation or cleaning may be needed.
|
| 89 |
+
|
| 90 |
+
## Bias, Risks, and Limitations
|
| 91 |
+
As a pretrained base model, **EvaByte** has not been fine-tuned for chat or instruction following, so users should not expect reliable performance in conversational or instruction-based tasks. Like other base models, it does not incorporate any moderation mechanisms, making it possible to generate potentially harmful or inappropriate content.
|
| 92 |
+
|
| 93 |
+
## Evaluation
|
| 94 |
+
|
| 95 |
+
For detailed evaluation results, check out our blog post at [SambaNova](https://sambanova.ai/blog/evabyte-efficient-byte-level-language-models-at-scale) or [HKUNLP](https://hkunlp.github.io/blog/2025/evabyte).
|
| 96 |
+
|
| 97 |
+
## Citation
|
| 98 |
+
```bibtex
|
| 99 |
+
@misc{evabyte,
|
| 100 |
+
title = {EvaByte: Efficient Byte-level Language Models at Scale},
|
| 101 |
+
url = {https://hkunlp.github.io/blog/2025/evabyte},
|
| 102 |
+
author = {Lin Zheng and Xueliang Zhao and Guangtao Wang and Chen Wu and David Dong and Angela Wang and Mingran Wang and Yun Du and Haige Bo and Amol Sharma and Bo Li and Kejie Zhang and Changran Hu and Urmish Thakker and Lingpeng Kong},
|
| 103 |
+
year = {2025}
|
| 104 |
+
}
|
| 105 |
+
```
|