2ira committed on
Commit
6880c8a
·
verified ·
1 Parent(s): a2e2e6c

Add files using upload-large-folder tool

Browse files
Files changed (50) hide show
  1. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-30000/generation_config.json +7 -0
  2. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-30000/image_processing_evabyte.py +204 -0
  3. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-30000/model.safetensors.index.json +450 -0
  4. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-30000/modeling_evabyte.py +912 -0
  5. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-30000/multibyte_decoding_evabyte.py +881 -0
  6. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-30000/preprocessor_config.json +18 -0
  7. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-30000/processing_evabyte.py +287 -0
  8. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-30000/processor_config.json +6 -0
  9. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-30000/special_tokens_map.json +98 -0
  10. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-30000/tokenization_evabyte.py +246 -0
  11. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-30000/tokenizer_config.json +596 -0
  12. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/README.md +105 -0
  13. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/config.json +48 -0
  14. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/configuration_evabyte.py +99 -0
  15. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/eva.py +424 -0
  16. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/eva_agg_kernel.py +1766 -0
  17. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/eva_cache.py +761 -0
  18. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/eva_prep_kv_kernel.py +1017 -0
  19. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/eva_pt_ref.py +420 -0
  20. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/generation_config.json +7 -0
  21. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/image_processing_evabyte.py +204 -0
  22. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/model.safetensors.index.json +450 -0
  23. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/modeling_evabyte.py +912 -0
  24. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/multibyte_decoding_evabyte.py +881 -0
  25. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/preprocessor_config.json +18 -0
  26. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/processing_evabyte.py +287 -0
  27. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/processor_config.json +6 -0
  28. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/special_tokens_map.json +98 -0
  29. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/tokenization_evabyte.py +246 -0
  30. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/tokenizer_config.json +596 -0
  31. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/README.md +105 -0
  32. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/config.json +48 -0
  33. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/configuration_evabyte.py +99 -0
  34. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/eva.py +424 -0
  35. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/eva_agg_kernel.py +1766 -0
  36. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/eva_cache.py +761 -0
  37. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/eva_prep_kv_kernel.py +1017 -0
  38. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/eva_pt_ref.py +420 -0
  39. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/generation_config.json +7 -0
  40. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/image_processing_evabyte.py +204 -0
  41. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/model.safetensors.index.json +450 -0
  42. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/modeling_evabyte.py +912 -0
  43. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/multibyte_decoding_evabyte.py +881 -0
  44. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/preprocessor_config.json +18 -0
  45. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/processing_evabyte.py +287 -0
  46. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/processor_config.json +6 -0
  47. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/special_tokens_map.json +98 -0
  48. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/tokenization_evabyte.py +246 -0
  49. ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/tokenizer_config.json +596 -0
  50. ckpts/ocpython_14b_bsz-2m_seq16k_docmask_multipredc2r8_90dynamic-10raw_transsentinel_minsize0ent98line16ow16pack_100B_2m_new_2_step-10000/README.md +105 -0
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-30000/generation_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "pad_token_id": 0,
6
+ "transformers_version": "4.47.1"
7
+ }
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-30000/image_processing_evabyte.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ """Image processor class for EvaByte."""
3
+
4
+ from typing import Dict, List, Optional, Union, Tuple
5
+
6
+ import io
7
+ from transformers.image_processing_utils import BaseImageProcessor
8
+ from transformers.image_utils import (
9
+ ImageInput,
10
+ PILImageResampling,
11
+ valid_images,
12
+ validate_preprocess_arguments,
13
+ )
14
+ from PIL import Image
15
+
16
def _get_qtable_bytes():
    """Return canned JPEG quantization-table byte streams keyed by quality level.

    Each value is a minimal JPEG byte sequence: it starts with the SOI marker
    (``\\xff\\xd8``), contains two DQT segments (``\\xff\\xdb``) holding the
    quantization tables, and ends with the EOI marker (``\\xff\\xd9``).
    ``EvaByteImageProcessor.jpeg_merge_qtables`` splices the DQT segments from
    one of these streams into a table-less JPEG payload.

    NOTE(review): the table bytes presumably match what PIL emits for each
    ``quality`` setting — not verified here.
    """
    return {
        5: b'\xff\xd8\xff\xdb\x00C\x00\xa0nx\x8cxd\xa0\x8c\x82\x8c\xb4\xaa\xa0\xbe\xf0\xff\xff\xf0\xdc\xdc\xf0\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xdb\x00C\x01\xa0\xb4\xb4\xf0\xd2\xf0\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xd9',
        10: b'\xff\xd8\xff\xdb\x00C\x00P7<F<2PFAFZUP_x\xc8\x82xnnx\xf5\xaf\xb9\x91\xc8\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xdb\x00C\x01PZZxix\xeb\x82\x82\xeb\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xd9',
        15: b'\xff\xd8\xff\xdb\x00C\x005%(/(!5/+/<95?P\x85WPIIP\xa3u{a\x85\xc1\xaa\xcb\xc8\xbe\xaa\xba\xb7\xd5\xf0\xff\xff\xd5\xe2\xff\xe6\xb7\xba\xff\xff\xff\xff\xff\xff\xff\xff\xff\xce\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xdb\x00C\x015<<PFP\x9dWW\x9d\xff\xdc\xba\xdc\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xd9',
        20: b'\xff\xd8\xff\xdb\x00C\x00(\x1c\x1e#\x1e\x19(#!#-+(0<dA<77<{X]Id\x91\x80\x99\x96\x8f\x80\x8c\x8a\xa0\xb4\xe6\xc3\xa0\xaa\xda\xad\x8a\x8c\xc8\xff\xcb\xda\xee\xf5\xff\xff\xff\x9b\xc1\xff\xff\xff\xfa\xff\xe6\xfd\xff\xf8\xff\xdb\x00C\x01(--<5<vAAv\xf8\xa5\x8c\xa5\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xff\xd9',
        25: b'\xff\xd8\xff\xdb\x00C\x00 \x16\x18\x1c\x18\x14 \x1c\x1a\x1c$" &0P40,,0bFJ:Ptfzxrfpn\x80\x90\xb8\x9c\x80\x88\xae\x8anp\xa0\xda\xa2\xae\xbe\xc4\xce\xd0\xce|\x9a\xe2\xf2\xe0\xc8\xf0\xb8\xca\xce\xc6\xff\xdb\x00C\x01 $$0*0^44^\xc6\x84p\x84\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xff\xd9',
        30: b'\xff\xd8\xff\xdb\x00C\x00\x1b\x12\x14\x17\x14\x11\x1b\x17\x16\x17\x1e\x1c\x1b (B+(%%(Q:=0B`Ued_U][jx\x99\x81jq\x90s[]\x85\xb5\x86\x90\x9e\xa3\xab\xad\xabg\x80\xbc\xc9\xba\xa6\xc7\x99\xa8\xab\xa4\xff\xdb\x00C\x01\x1b\x1e\x1e(#(N++N\xa4n]n\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xff\xd9',
        50: b'\xff\xd8\xff\xdb\x00C\x00\x10\x0b\x0c\x0e\x0c\n\x10\x0e\r\x0e\x12\x11\x10\x13\x18(\x1a\x18\x16\x16\x181#%\x1d(:3=<9387@H\\N@DWE78PmQW_bghg>Mqypdx\\egc\xff\xdb\x00C\x01\x10\x12\x12\x18\x15\x18/\x1a\x1a/cB8Bcccccccccccccccccccccccccccccccccccccccccccccccccc\xff\xd9',
        75: b'\xff\xd8\xff\xdb\x00C\x00\x08\x06\x06\x07\x06\x05\x08\x07\x07\x07\t\t\x08\n\x0c\x14\r\x0c\x0b\x0b\x0c\x19\x12\x13\x0f\x14\x1d\x1a\x1f\x1e\x1d\x1a\x1c\x1c $.\' ",#\x1c\x1c(7),01444\x1f\'9=82<.342\xff\xdb\x00C\x01\x08\t\t\x0c\x0b\x0c\x18\r\r\x182!\x1c!22222222222222222222222222222222222222222222222222\xff\xd9',
        95: b'\xff\xd8\xff\xdb\x00C\x00\x02\x01\x01\x01\x01\x01\x02\x01\x01\x01\x02\x02\x02\x02\x02\x04\x03\x02\x02\x02\x02\x05\x04\x04\x03\x04\x06\x05\x06\x06\x06\x05\x06\x06\x06\x07\t\x08\x06\x07\t\x07\x06\x06\x08\x0b\x08\t\n\n\n\n\n\x06\x08\x0b\x0c\x0b\n\x0c\t\n\n\n\xff\xdb\x00C\x01\x02\x02\x02\x02\x02\x02\x05\x03\x03\x05\n\x07\x06\x07\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\xff\xd9',
        100: b'\xff\xd8\xff\xdb\x00C\x00\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\xff\xdb\x00C\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\xff\xd9',
    }
29
+
30
+
31
+ def _resize_if_exceeding_max_len(
32
+ width: int, height: int, min_len: Optional[int] = 16, max_len: Optional[int] = None
33
+ ) -> Tuple[int, int]:
34
+ """
35
+ Get the output size of the image after resizing given a dictionary specifying the max and min sizes.
36
+
37
+ Args:
38
+ height (`int`):
39
+ Height of the input image.
40
+ width (`int`):
41
+ Width of the input image.
42
+ max_len (`Dict[str, int]`, *optional*, defaults to the maximum size of the image):
43
+ Defines the maximum dimensions of the image.
44
+
45
+ Returns:
46
+ The output size of the image after resizing.
47
+ """
48
+ max_len = max(height, width) if max_len is None else max_len
49
+ aspect_ratio = width / height
50
+
51
+ if width >= height and width > max_len:
52
+ width = max_len
53
+ height = int(width / aspect_ratio)
54
+ if height % 2 != 0:
55
+ height += 1
56
+ elif height > width and height > max_len:
57
+ height = max_len
58
+ width = int(height * aspect_ratio)
59
+ if width % 2 != 0:
60
+ width += 1
61
+
62
+ # Avoid resizing to a size smaller than 1
63
+ height = max(height, min_len)
64
+ width = max(width, min_len)
65
+ return width, height
66
+
67
class EvaByteImageProcessor(BaseImageProcessor):
    """Image processor for EvaByte.

    Unlike most HF image processors, it does not produce pixel tensors: images
    are (optionally) converted to RGB and resized, then re-encoded as JPEG and
    returned as raw byte strings for the byte-level model to consume.
    """

    # No tensor inputs are produced, hence no declared model input names.
    model_input_names = []

    def __init__(
        self,
        do_resize: bool = True,
        resample: PILImageResampling = PILImageResampling.LANCZOS,
        size: Optional[Dict[str, int]] = None,
        do_convert_rgb: bool = True,
        jpeg_quality: int = 25,
        jpeg_subsampling: str = "4:2:0",
        jpeg_streamtype: int = 2,
        jpeg_restart_marker_blocks: int = 1,
        **kwargs,
    ) -> None:
        """Store default preprocessing options.

        Args:
            do_resize: Whether to resize images during ``preprocess``.
            resample: PIL resampling filter used when resizing.
            size: Target size spec; either ``{"longest_edge": n}`` or
                ``{"width": w, "height": h}``. Defaults to ``{"longest_edge": 384}``.
            do_convert_rgb: Whether to convert images to RGB first.
            jpeg_quality: JPEG ``quality`` passed to ``PIL.Image.save``; must be
                one of the keys of ``_get_qtable_bytes()`` for
                ``jpeg_merge_qtables`` to work.
            jpeg_subsampling: JPEG chroma subsampling mode (e.g. "4:2:0").
            jpeg_streamtype: PIL JPEG-plugin ``streamtype`` option.
                NOTE(review): 2 presumably emits an abbreviated stream without
                quantization tables — confirm against the PIL JPEG plugin.
            jpeg_restart_marker_blocks: PIL JPEG-plugin restart-marker interval.
        """
        super().__init__(**kwargs)
        self.do_resize = do_resize
        self.resample = resample
        self.size = size if size is not None else {"longest_edge": 384}
        self.do_convert_rgb = do_convert_rgb
        self.jpeg_quality = jpeg_quality
        self.jpeg_subsampling = jpeg_subsampling
        self.jpeg_streamtype = jpeg_streamtype
        self.jpeg_restart_marker_blocks = jpeg_restart_marker_blocks

    def jpeg_encode(
        self,
        image,
        jpeg_quality,
        jpeg_subsampling,
        jpeg_streamtype,
        jpeg_restart_marker_blocks,
    ):
        """Encode a PIL image as JPEG with the given options and return the bytes."""
        with io.BytesIO() as output:
            image.save(
                output,
                format="JPEG",
                quality=jpeg_quality,
                subsampling=jpeg_subsampling,
                streamtype=jpeg_streamtype,
                restart_marker_blocks=jpeg_restart_marker_blocks
            )
            jpeg_bytes = output.getvalue()
        return jpeg_bytes

    def jpeg_merge_qtables(
        self,
        image_bytes,
        jpeg_quality=None,
    ):
        """Splice the canned quantization tables for ``jpeg_quality`` into a JPEG stream.

        Keeps the first 2 bytes of ``image_bytes`` (the SOI marker), inserts the
        canned table stream with its own SOI and EOI markers stripped
        (``qtable_bytes[2:-2]``), then appends the rest of ``image_bytes``.
        """
        if jpeg_quality is None:
            jpeg_quality = self.jpeg_quality
        # Raises KeyError if jpeg_quality is not one of the supported levels.
        qtable_bytes = _get_qtable_bytes()[jpeg_quality]
        return image_bytes[:2] + qtable_bytes[2:-2] + image_bytes[2:]

    def resize(
        self,
        image: Image.Image,
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.LANCZOS,
    ) -> Image.Image:
        """Resize ``image`` according to ``size``.

        ``size`` must contain either ``"longest_edge"`` (aspect-preserving cap
        via ``_resize_if_exceeding_max_len``) or explicit ``"width"``/``"height"``.

        Raises:
            ValueError: If ``size`` contains neither key combination.
        """
        if "longest_edge" in size:
            width, height = image.size
            # Find the output size, when rescaling the longest edge to max_len and preserving the aspect ratio
            width, height = _resize_if_exceeding_max_len(width, height, max_len=size["longest_edge"])
            size = (width, height)
        elif "width" in size and "height" in size:
            size = (size["width"], size["height"])
        else:
            raise ValueError("size must be a dictionary with key 'longest_edge' or 'height' and 'width'.")
        resized_image = image.resize(size, resample=resample)
        return resized_image

    def preprocess(
        self,
        images: ImageInput,
        do_resize: Optional[bool] = None,
        resample: Optional[PILImageResampling] = None,
        size: Optional[Dict[str, int]] = None,
        do_convert_rgb: Optional[bool] = None,
        jpeg_quality: Optional[int] = None,
        jpeg_subsampling: Optional[str] = None,
        jpeg_streamtype: Optional[int] = None,
        jpeg_restart_marker_blocks: Optional[int] = None,
    ):
        """Convert, resize and JPEG-encode a batch of images.

        Args:
            images: A nested list of images (list of lists of PIL images —
                the comprehensions below iterate two levels deep).
            do_resize / resample / size / do_convert_rgb / jpeg_*:
                Per-call overrides; ``None`` falls back to the instance defaults.

        Returns:
            A nested list (same structure as ``images``) of JPEG byte strings.

        Raises:
            ValueError: If ``images`` contains unsupported image types.
        """
        # Resolve per-call overrides against instance defaults.
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size
        resample = resample if resample is not None else self.resample
        do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb

        jpeg_quality = jpeg_quality if jpeg_quality is not None else self.jpeg_quality
        jpeg_subsampling = jpeg_subsampling if jpeg_subsampling is not None else self.jpeg_subsampling
        jpeg_streamtype = jpeg_streamtype if jpeg_streamtype is not None else self.jpeg_streamtype
        jpeg_restart_marker_blocks = jpeg_restart_marker_blocks if jpeg_restart_marker_blocks is not None else self.jpeg_restart_marker_blocks

        if images is not None and not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        validate_preprocess_arguments(
            do_resize=do_resize,
            size=size,
            resample=resample,
        )
        images_list = images
        if do_convert_rgb:
            images_list = [
                [
                    image.convert("RGB") for image in images
                ]
                for images in images_list
            ]

        if do_resize:
            images_list = [
                [
                    self.resize(image=image, size=size, resample=resample)
                    for image in images
                ]
                for images in images_list
            ]

        # Encode every image to JPEG bytes, preserving the nested structure.
        jpeg_bytes = [
            [
                self.jpeg_encode(
                    image,
                    jpeg_quality,
                    jpeg_subsampling,
                    jpeg_streamtype,
                    jpeg_restart_marker_blocks
                ) for image in images
            ]
            for images in images_list
        ]
        return jpeg_bytes
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-30000/model.safetensors.index.json ADDED
@@ -0,0 +1,450 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "total_size": 57058938880
4
+ },
5
+ "weight_map": {
6
+ "model.embed_tokens.weight": "model-00001-of-00003.safetensors",
7
+ "model.layers.0.input_layernorm.weight": "model-00001-of-00003.safetensors",
8
+ "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
9
+ "model.layers.1.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
10
+ "model.layers.1.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
11
+ "model.layers.1.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
12
+ "model.layers.1.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
13
+ "model.layers.3.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
14
+ "model.layers.4.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
15
+ "model.layers.4.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
16
+ "model.layers.5.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
17
+ "model.layers.5.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
18
+ "model.layers.7.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
19
+ "model.layers.8.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
20
+ "model.layers.8.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
21
+ "model.layers.10.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
22
+ "model.layers.12.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
23
+ "model.layers.12.input_layernorm.weight": "model-00002-of-00003.safetensors",
24
+ "model.layers.12.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
25
+ "model.layers.13.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
26
+ "model.layers.13.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
27
+ "model.layers.13.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
28
+ "model.layers.13.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
29
+ "model.layers.15.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
30
+ "model.layers.16.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
31
+ "model.layers.16.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
32
+ "model.layers.18.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
33
+ "model.layers.20.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
34
+ "model.layers.20.input_layernorm.weight": "model-00002-of-00003.safetensors",
35
+ "model.layers.20.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
36
+ "model.layers.21.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
37
+ "model.layers.21.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
38
+ "model.layers.21.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
39
+ "model.layers.21.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
40
+ "model.layers.23.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
41
+ "model.layers.24.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
42
+ "model.layers.24.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
43
+ "model.layers.26.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
44
+ "model.layers.28.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
45
+ "model.layers.28.input_layernorm.weight": "model-00003-of-00003.safetensors",
46
+ "model.layers.28.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
47
+ "model.layers.29.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
48
+ "model.layers.29.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
49
+ "model.layers.29.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
50
+ "model.layers.29.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
51
+ "model.layers.31.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
52
+ "model.layers.32.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
53
+ "model.layers.32.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
54
+ "model.layers.34.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
55
+ "model.layers.36.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
56
+ "model.layers.36.input_layernorm.weight": "model-00003-of-00003.safetensors",
57
+ "model.layers.36.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
58
+ "model.layers.37.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
59
+ "model.layers.37.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
60
+ "model.layers.37.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
61
+ "model.layers.37.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
62
+ "model.layers.39.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
63
+ "model.layers.2.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
64
+ "model.layers.26.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
65
+ "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
66
+ "model.layers.1.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
67
+ "model.layers.1.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
68
+ "model.layers.3.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
69
+ "model.layers.3.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
70
+ "model.layers.5.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
71
+ "model.layers.7.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
72
+ "model.layers.9.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
73
+ "model.layers.9.input_layernorm.weight": "model-00001-of-00003.safetensors",
74
+ "model.layers.9.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
75
+ "model.layers.10.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
76
+ "model.layers.10.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
77
+ "model.layers.12.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
78
+ "model.layers.14.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
79
+ "model.layers.14.input_layernorm.weight": "model-00002-of-00003.safetensors",
80
+ "model.layers.15.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
81
+ "model.layers.15.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
82
+ "model.layers.17.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
83
+ "model.layers.18.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
84
+ "model.layers.18.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
85
+ "model.layers.20.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
86
+ "model.layers.21.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
87
+ "model.layers.21.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
88
+ "model.layers.23.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
89
+ "model.layers.25.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
90
+ "model.layers.25.input_layernorm.weight": "model-00003-of-00003.safetensors",
91
+ "model.layers.25.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
92
+ "model.layers.26.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
93
+ "model.layers.26.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
94
+ "model.layers.27.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
95
+ "model.layers.28.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
96
+ "model.layers.29.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
97
+ "model.layers.29.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
98
+ "model.layers.31.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
99
+ "model.layers.33.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
100
+ "model.layers.33.input_layernorm.weight": "model-00003-of-00003.safetensors",
101
+ "model.layers.33.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
102
+ "model.layers.34.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
103
+ "model.layers.34.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
104
+ "model.layers.36.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
105
+ "model.layers.37.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
106
+ "model.layers.37.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
107
+ "model.layers.39.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
108
+ "model.layers.3.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
109
+ "model.layers.27.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
110
+ "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
111
+ "model.layers.1.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
112
+ "model.layers.1.input_layernorm.weight": "model-00001-of-00003.safetensors",
113
+ "model.layers.2.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
114
+ "model.layers.2.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
115
+ "model.layers.2.input_layernorm.weight": "model-00001-of-00003.safetensors",
116
+ "model.layers.3.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
117
+ "model.layers.3.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
118
+ "model.layers.4.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
119
+ "model.layers.5.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
120
+ "model.layers.5.input_layernorm.weight": "model-00001-of-00003.safetensors",
121
+ "model.layers.6.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
122
+ "model.layers.6.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
123
+ "model.layers.8.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
124
+ "model.layers.9.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
125
+ "model.layers.9.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
126
+ "model.layers.11.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
127
+ "model.layers.13.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
128
+ "model.layers.13.input_layernorm.weight": "model-00002-of-00003.safetensors",
129
+ "model.layers.13.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
130
+ "model.layers.14.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
131
+ "model.layers.14.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
132
+ "model.layers.16.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
133
+ "model.layers.17.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
134
+ "model.layers.17.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
135
+ "model.layers.19.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
136
+ "model.layers.21.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
137
+ "model.layers.21.input_layernorm.weight": "model-00002-of-00003.safetensors",
138
+ "model.layers.21.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
139
+ "model.layers.22.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
140
+ "model.layers.22.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
141
+ "model.layers.24.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
142
+ "model.layers.25.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
143
+ "model.layers.25.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
144
+ "model.layers.27.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
145
+ "model.layers.28.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
146
+ "model.layers.29.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
147
+ "model.layers.29.input_layernorm.weight": "model-00003-of-00003.safetensors",
148
+ "model.layers.29.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
149
+ "model.layers.30.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
150
+ "model.layers.30.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
151
+ "model.layers.32.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
152
+ "model.layers.33.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
153
+ "model.layers.33.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
154
+ "model.layers.35.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
155
+ "model.layers.37.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
156
+ "model.layers.37.input_layernorm.weight": "model-00003-of-00003.safetensors",
157
+ "model.layers.37.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
158
+ "model.layers.38.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
159
+ "model.layers.38.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
160
+ "model.layers.4.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
161
+ "model.layers.28.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
162
+ "model.layers.5.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
163
+ "model.layers.0.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
164
+ "model.layers.0.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
165
+ "model.layers.2.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
166
+ "model.layers.4.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
167
+ "model.layers.6.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
168
+ "model.layers.8.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
169
+ "model.layers.8.input_layernorm.weight": "model-00001-of-00003.safetensors",
170
+ "model.layers.8.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
171
+ "model.layers.9.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
172
+ "model.layers.9.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
173
+ "model.layers.9.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
174
+ "model.layers.9.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
175
+ "model.layers.11.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
176
+ "model.layers.12.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
177
+ "model.layers.12.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
178
+ "model.layers.12.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
179
+ "model.layers.14.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
180
+ "model.layers.16.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
181
+ "model.layers.16.input_layernorm.weight": "model-00002-of-00003.safetensors",
182
+ "model.layers.16.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
183
+ "model.layers.17.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
184
+ "model.layers.17.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
185
+ "model.layers.17.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
186
+ "model.layers.17.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
187
+ "model.layers.19.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
188
+ "model.layers.20.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
189
+ "model.layers.20.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
190
+ "model.layers.22.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
191
+ "model.layers.23.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
192
+ "model.layers.23.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
193
+ "model.layers.23.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
194
+ "model.layers.25.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
195
+ "model.layers.27.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
196
+ "model.layers.27.input_layernorm.weight": "model-00003-of-00003.safetensors",
197
+ "model.layers.27.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
198
+ "model.layers.28.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
199
+ "model.layers.28.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
200
+ "model.layers.30.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
201
+ "model.layers.32.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
202
+ "model.layers.32.input_layernorm.weight": "model-00003-of-00003.safetensors",
203
+ "model.layers.32.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
204
+ "model.layers.33.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
205
+ "model.layers.33.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
206
+ "model.layers.33.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
207
+ "model.layers.33.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
208
+ "model.layers.35.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
209
+ "model.layers.36.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
210
+ "model.layers.36.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
211
+ "model.layers.38.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
212
+ "model.layers.0.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
213
+ "model.layers.1.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
214
+ "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
215
+ "model.layers.2.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
216
+ "model.layers.2.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
217
+ "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
218
+ "model.layers.3.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
219
+ "model.layers.3.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
220
+ "model.layers.5.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
221
+ "model.layers.5.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
222
+ "model.layers.5.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
223
+ "model.layers.6.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
224
+ "model.layers.6.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
225
+ "model.layers.8.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
226
+ "model.layers.10.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
227
+ "model.layers.10.input_layernorm.weight": "model-00001-of-00003.safetensors",
228
+ "model.layers.10.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
229
+ "model.layers.11.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
230
+ "model.layers.11.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
231
+ "model.layers.13.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
232
+ "model.layers.14.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
233
+ "model.layers.14.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
234
+ "model.layers.16.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
235
+ "model.layers.18.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
236
+ "model.layers.18.input_layernorm.weight": "model-00002-of-00003.safetensors",
237
+ "model.layers.18.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
238
+ "model.layers.19.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
239
+ "model.layers.19.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
240
+ "model.layers.21.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
241
+ "model.layers.22.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
242
+ "model.layers.22.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
243
+ "model.layers.24.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
244
+ "model.layers.26.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
245
+ "model.layers.26.input_layernorm.weight": "model-00003-of-00003.safetensors",
246
+ "model.layers.26.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
247
+ "model.layers.27.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
248
+ "model.layers.27.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
249
+ "model.layers.29.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
250
+ "model.layers.30.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
251
+ "model.layers.30.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
252
+ "model.layers.32.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
253
+ "model.layers.34.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
254
+ "model.layers.34.input_layernorm.weight": "model-00003-of-00003.safetensors",
255
+ "model.layers.34.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
256
+ "model.layers.35.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
257
+ "model.layers.35.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
258
+ "model.layers.37.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
259
+ "model.layers.38.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
260
+ "model.layers.38.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
261
+ "model.layers.6.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
262
+ "model.layers.30.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
263
+ "model.layers.6.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
264
+ "model.layers.30.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
265
+ "model.layers.7.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
266
+ "model.layers.31.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
267
+ "model.layers.7.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
268
+ "model.layers.31.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
269
+ "model.layers.8.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
270
+ "model.layers.32.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
271
+ "model.layers.2.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
272
+ "model.layers.14.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
273
+ "model.layers.14.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
274
+ "model.layers.22.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
275
+ "model.layers.22.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
276
+ "model.layers.38.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
277
+ "model.layers.38.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
278
+ "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
279
+ "model.layers.1.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
280
+ "model.layers.3.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
281
+ "model.layers.3.input_layernorm.weight": "model-00001-of-00003.safetensors",
282
+ "model.layers.3.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
283
+ "model.layers.4.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
284
+ "model.layers.4.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
285
+ "model.layers.4.input_layernorm.weight": "model-00001-of-00003.safetensors",
286
+ "model.layers.4.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
287
+ "model.layers.5.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
288
+ "model.layers.5.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
289
+ "model.layers.7.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
290
+ "model.layers.7.input_layernorm.weight": "model-00001-of-00003.safetensors",
291
+ "model.layers.7.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
292
+ "model.layers.8.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
293
+ "model.layers.8.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
294
+ "model.layers.10.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
295
+ "model.layers.11.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
296
+ "model.layers.11.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
297
+ "model.layers.13.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
298
+ "model.layers.15.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
299
+ "model.layers.15.input_layernorm.weight": "model-00002-of-00003.safetensors",
300
+ "model.layers.15.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
301
+ "model.layers.16.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
302
+ "model.layers.16.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
303
+ "model.layers.18.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
304
+ "model.layers.19.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
305
+ "model.layers.19.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
306
+ "model.layers.21.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
307
+ "model.layers.23.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
308
+ "model.layers.23.input_layernorm.weight": "model-00002-of-00003.safetensors",
309
+ "model.layers.23.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
310
+ "model.layers.24.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
311
+ "model.layers.24.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
312
+ "model.layers.26.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
313
+ "model.layers.27.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
314
+ "model.layers.27.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
315
+ "model.layers.29.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
316
+ "model.layers.31.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
317
+ "model.layers.31.input_layernorm.weight": "model-00003-of-00003.safetensors",
318
+ "model.layers.31.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
319
+ "model.layers.32.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
320
+ "model.layers.32.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
321
+ "model.layers.34.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
322
+ "model.layers.35.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
323
+ "model.layers.35.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
324
+ "model.layers.37.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
325
+ "model.layers.39.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
326
+ "model.layers.39.input_layernorm.weight": "model-00003-of-00003.safetensors",
327
+ "model.layers.39.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
328
+ "model.norm.weight": "model-00003-of-00003.safetensors",
329
+ "lm_head.weight": "model-00003-of-00003.safetensors",
330
+ "model.layers.0.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
331
+ "model.layers.2.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
332
+ "model.layers.4.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
333
+ "model.layers.6.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
334
+ "model.layers.6.input_layernorm.weight": "model-00001-of-00003.safetensors",
335
+ "model.layers.7.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
336
+ "model.layers.7.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
337
+ "model.layers.8.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
338
+ "model.layers.9.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
339
+ "model.layers.10.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
340
+ "model.layers.10.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
341
+ "model.layers.12.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
342
+ "model.layers.13.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
343
+ "model.layers.13.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
344
+ "model.layers.15.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
345
+ "model.layers.17.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
346
+ "model.layers.17.input_layernorm.weight": "model-00002-of-00003.safetensors",
347
+ "model.layers.17.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
348
+ "model.layers.18.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
349
+ "model.layers.18.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
350
+ "model.layers.20.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
351
+ "model.layers.22.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
352
+ "model.layers.22.input_layernorm.weight": "model-00002-of-00003.safetensors",
353
+ "model.layers.22.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
354
+ "model.layers.23.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
355
+ "model.layers.23.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
356
+ "model.layers.25.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
357
+ "model.layers.26.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
358
+ "model.layers.26.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
359
+ "model.layers.28.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
360
+ "model.layers.30.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
361
+ "model.layers.30.input_layernorm.weight": "model-00003-of-00003.safetensors",
362
+ "model.layers.30.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
363
+ "model.layers.31.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
364
+ "model.layers.31.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
365
+ "model.layers.32.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
366
+ "model.layers.33.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
367
+ "model.layers.34.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
368
+ "model.layers.34.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
369
+ "model.layers.36.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
370
+ "model.layers.38.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
371
+ "model.layers.38.input_layernorm.weight": "model-00003-of-00003.safetensors",
372
+ "model.layers.38.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
373
+ "model.layers.39.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
374
+ "model.layers.39.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
375
+ "model.layers.10.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
376
+ "model.layers.34.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
377
+ "model.layers.10.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
378
+ "model.layers.34.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
379
+ "model.layers.0.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
380
+ "model.layers.2.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
381
+ "model.layers.4.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
382
+ "model.layers.6.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
383
+ "model.layers.6.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
384
+ "model.layers.7.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
385
+ "model.layers.7.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
386
+ "model.layers.9.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
387
+ "model.layers.11.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
388
+ "model.layers.11.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
389
+ "model.layers.11.input_layernorm.weight": "model-00001-of-00003.safetensors",
390
+ "model.layers.11.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
391
+ "model.layers.28.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
392
+ "model.layers.28.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
393
+ "model.layers.30.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
394
+ "model.layers.31.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
395
+ "model.layers.31.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
396
+ "model.layers.33.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
397
+ "model.layers.35.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
398
+ "model.layers.35.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
399
+ "model.layers.35.input_layernorm.weight": "model-00003-of-00003.safetensors",
400
+ "model.layers.35.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
401
+ "model.layers.36.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
402
+ "model.layers.36.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
403
+ "model.layers.38.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
404
+ "model.layers.39.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
405
+ "model.layers.39.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
406
+ "model.layers.16.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
407
+ "model.layers.16.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
408
+ "model.layers.24.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
409
+ "model.layers.24.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
410
+ "model.layers.11.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
411
+ "model.layers.12.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
412
+ "model.layers.12.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
413
+ "model.layers.14.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
414
+ "model.layers.14.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
415
+ "model.layers.15.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
416
+ "model.layers.15.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
417
+ "model.layers.17.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
418
+ "model.layers.19.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
419
+ "model.layers.19.input_layernorm.weight": "model-00002-of-00003.safetensors",
420
+ "model.layers.19.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
421
+ "model.layers.20.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
422
+ "model.layers.20.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
423
+ "model.layers.22.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
424
+ "model.layers.24.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
425
+ "model.layers.24.input_layernorm.weight": "model-00003-of-00003.safetensors",
426
+ "model.layers.24.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
427
+ "model.layers.25.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
428
+ "model.layers.25.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
429
+ "model.layers.27.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
430
+ "model.layers.35.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
431
+ "model.layers.12.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
432
+ "model.layers.36.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
433
+ "model.layers.36.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
434
+ "model.layers.0.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
435
+ "model.layers.15.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
436
+ "model.layers.20.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
437
+ "model.layers.20.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
438
+ "model.layers.25.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
439
+ "model.layers.25.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
440
+ "model.layers.15.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
441
+ "model.layers.39.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
442
+ "model.layers.39.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
443
+ "model.layers.18.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
444
+ "model.layers.18.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
445
+ "model.layers.23.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
446
+ "model.layers.19.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
447
+ "model.layers.19.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
448
+ "model.layers.26.self_attn.adaptive_phi": "model-00003-of-00003.safetensors"
449
+ }
450
+ }
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-30000/modeling_evabyte.py ADDED
@@ -0,0 +1,912 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Tuple, Union
2
+ import math
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import torch.utils.checkpoint
6
+ from torch import nn
7
+ from torch.nn import CrossEntropyLoss
8
+ from transformers.activations import ACT2FN
9
+ from transformers.cache_utils import Cache
10
+ from transformers.modeling_outputs import (
11
+ BaseModelOutputWithPast,
12
+ CausalLMOutputWithPast,
13
+ )
14
+ from transformers.modeling_utils import PreTrainedModel
15
+
16
+ from .configuration_evabyte import EvaByteConfig
17
+ from .multibyte_decoding_evabyte import MultiByteDecodingMixin
18
+ try:
19
+ import triton
20
+ USE_TRITON_IMPL = True
21
+ from .eva import EvaAttention
22
+ from .eva_agg_kernel import triton_eva_agg_fwd
23
+ from .eva_prep_kv_kernel import triton_eva_prep_kv_fwd
24
+ except ImportError:
25
+ USE_TRITON_IMPL = False
26
+ print("WARNING: triton is not installed, using fallback EVA which might be slow and throw errors")
27
+ from .eva_pt_ref import EvaAttention
28
+ from .eva_cache import EvaCache, EvaStaticCacheForTriton
29
+
30
+ MASK_MIN_VALUE = -10e10
31
+
32
def prepare_eva_attention_mask(
    seq_len,
    device,
    chunk_size,
    window_size,
    use_cache=False,
    cache=None
):
    """
    Prepare attention masks for EVA.

    Builds two boolean masks (True = masked-out position):
      - ``chunk_causal_mask``: a causal mask over chunk summaries, shaped
        ``[1, 1, n, num_chunks]`` where ``n`` covers the (padded) sequence,
        with same-window chunk blocks fully masked (those positions are
        handled by the local window attention instead).
      - ``window_causal_mask``: a strictly-upper-triangular causal mask of
        shape ``[1, 1, 1, window_size, window_size]`` for intra-window
        attention.

    When ``use_cache`` is True, ``cache.get_seq_length()`` gives the number of
    already-processed tokens and the chunk mask is sliced down to the rows
    corresponding to the new ``seq_len`` positions only.
    """
    chunk_causal_mask = None
    window_causal_mask = None
    if use_cache:
        cached_seq_len = cache.get_seq_length()
        total_seq_len = seq_len + cached_seq_len
        # cached_seq_len will be 0 during prefilling
        # padded_seq_len = chunk_size * math.ceil(total_seq_len / chunk_size)
        # NOTE(review): padding is rounded up to a multiple of window_size
        # (not chunk_size); presumably window_size is a multiple of chunk_size
        # so num_chunks stays integral — confirm against the config invariants.
        padded_seq_len = window_size * math.ceil(total_seq_len / window_size)
        num_chunks = padded_seq_len // chunk_size
    else:
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        # Without a cache the sequence must already be aligned to both
        # chunk and window boundaries.
        assert seq_len % chunk_size == 0
        num_chunks = seq_len // chunk_size

        assert seq_len % window_size == 0

    # create causal mask
    ################################
    # generate chunked causal masks
    ################################
    # [b, h, j, c, c]
    chunks_per_window = window_size // chunk_size
    if num_chunks >= chunks_per_window:
        # Upper-triangular (including diagonal) over chunk indices:
        # a position may only attend to summaries of strictly earlier chunks.
        chunk_causal_mask = torch.ones(
            (chunk_size, num_chunks, num_chunks),
            device=device,
            dtype=torch.bool
        ).triu(0)

        # Group chunks into windows ("blocks") so the diagonal window can be
        # masked out wholesale.
        num_blocks = num_chunks // chunks_per_window
        chunk_causal_mask = chunk_causal_mask.reshape(
            chunk_size,
            num_blocks,
            chunks_per_window,
            num_blocks,
            chunks_per_window
        ).transpose(-2, -3)

        # [1, num_blocks, num_blocks, 1, 1] selector for the block diagonal.
        block_diag_zero = (
            torch.eye(num_blocks, device=device, dtype=torch.bool)
            .unsqueeze(-1)
            .unsqueeze(-1)
            .unsqueeze(0)
        )

        # Set diagonal blocks to zero
        # NOTE(review): despite the comment above, this fills the diagonal
        # blocks with True, i.e. MASKS them out — consistent with True
        # meaning "disallowed" here (intra-window chunks are covered by the
        # window attention path instead); confirm the masking convention.
        chunk_causal_mask = chunk_causal_mask.masked_fill(block_diag_zero, True)

        # Reshape back to original size
        chunk_causal_mask = (
            chunk_causal_mask
            .transpose(-2, -3)
            .reshape(chunk_size, num_chunks, num_chunks)
            .transpose(-2, -3)
            .reshape(chunk_size * num_chunks, num_chunks)
            .unsqueeze(0)
            .unsqueeze(0)
        )
    else:
        # Fewer chunks than one window: plain chunk-level causal mask,
        # no block-diagonal carve-out needed.
        chunk_causal_mask = torch.ones(
            (1, 1, chunk_size, num_chunks, num_chunks),
            device=device,
            dtype=torch.bool,
        ).triu(0).transpose(-2, -3) # [1, 1, c, j, c]
        chunk_causal_mask = chunk_causal_mask.reshape(
            1, 1, chunk_size * num_chunks, num_chunks
        ) # [1, 1, n, c]

    if use_cache:
        # Keep only the mask rows for the newly appended positions.
        chunk_causal_mask = chunk_causal_mask[..., cached_seq_len : cached_seq_len + seq_len, :]

    # Strictly causal mask within each window (diagonal allowed: triu(1)).
    window_causal_mask = torch.ones(
        (1, 1, 1, window_size, window_size),
        device=device
    ).triu(1).to(torch.bool)
    return (chunk_causal_mask, window_causal_mask)
120
+
121
def pad_to_multiple(tensor, multiple, dim=-2, value=0, create_mask=False, left_padding=False):
    """Pad ``tensor`` along ``dim`` so that its length is a multiple of ``multiple``.

    Args:
        tensor: input tensor; dim 0 is assumed to be the batch dimension
            when ``create_mask`` is requested.
        multiple: target alignment for the size of ``dim``.
        dim: dimension to pad; must be negative (indexed from the end).
        value: fill value for the padded positions.
        create_mask: if True, also return a ``[batch, padded_len]`` boolean
            mask that is False at padded positions.
        left_padding: if True, pad at the front of ``dim`` instead of the end.

    Returns:
        The padded tensor, or ``(padded_tensor, padding_mask)`` when
        ``create_mask`` is True. If the length is already aligned, the input
        tensor is returned unchanged (mask all-True).
    """
    assert dim < 0  # only accept ``dim`` index in a reverse manner
    seqlen = int(tensor.shape[dim])
    # Exact integer arithmetic: amount needed to reach the next multiple.
    # (The previous float-division + is_integer() check is avoided.)
    remainder = -seqlen % multiple
    if remainder == 0:
        if create_mask:
            return tensor, torch.ones(size=(tensor.shape[0], tensor.shape[dim]), dtype=torch.bool, device=tensor.device)
        else:
            return tensor
    # F.pad takes (before, after) pairs starting from the LAST dimension, so
    # skip over the trailing dims with zero pads until we reach ``dim``.
    pad_offset = (0,) * (-1 - dim) * 2
    if left_padding:
        padded_res = F.pad(tensor, (*pad_offset, remainder, 0), value=value)
    else:
        padded_res = F.pad(tensor, (*pad_offset, 0, remainder), value=value)
    if create_mask:
        # assume dim 0 is the batch size
        padding_mask = torch.ones(size=(padded_res.shape[0], padded_res.shape[dim]), dtype=torch.bool, device=padded_res.device)
        if left_padding:
            padding_mask[:, :remainder] = False
        else:
            padding_mask[:, -remainder:] = False
        return padded_res, padding_mask
    else:
        return padded_res
146
+
147
class EvaByteRMSNorm(nn.Module):
    """Root-mean-square layer norm with an optional unit offset on the scale.

    When ``config.norm_add_unit_offset`` is set, the learned weight starts at
    zero and the effective scale is ``(1 + weight)``; otherwise the weight
    starts at one and is used directly. Normalization is computed in float32
    (``fp32_ln``) and the result is cast back to the input dtype.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.fp32_ln = True
        self.variance_epsilon = config.rms_norm_eps
        self.add_unit_offset = config.norm_add_unit_offset
        init_fn = torch.zeros if self.add_unit_offset else torch.ones
        self.weight = nn.Parameter(init_fn(config.hidden_size))

    def forward(self, hidden_states):
        compute_dtype = torch.float32 if self.fp32_ln else torch.bfloat16
        x = hidden_states.to(compute_dtype)

        # Normalize by the root-mean-square of the last dimension.
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
        x = x * inv_rms

        scale = (1 + self.weight) if self.add_unit_offset else self.weight
        return (scale * x).type_as(hidden_states)
168
+
169
class EvaByteRotaryEmbedding(torch.nn.Module):
    """Rotary position embedding with a lazily grown cos/sin cache.

    The cache is precomputed for ``max_position_embeddings`` positions and is
    rebuilt on the fly whenever a longer sequence is requested.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        exponents = torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim
        self.register_buffer("inv_freq", 1.0 / (self.base ** exponents), persistent=False)

        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        # Rebuild the angle tables for positions [0, seq_len).
        self.max_seq_len_cached = seq_len
        positions = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.outer(positions, self.inv_freq)
        # Duplicate along the feature dim so cos/sin cover the full head size.
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            # Grow the cache to cover the requested length.
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        if seq_len < self.max_seq_len_cached:
            cos_slice = self.cos_cached[:seq_len]
            sin_slice = self.sin_cached[:seq_len]
        else:
            cos_slice = self.cos_cached
            sin_slice = self.sin_cached

        return (
            cos_slice.to(dtype=x.dtype),
            sin_slice.to(dtype=x.dtype),
        )
213
+
214
+
215
+
216
class EvaByteLinearScalingRotaryEmbedding(EvaByteRotaryEmbedding):
    """EvaByteRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        # Must be set before super().__init__, which builds the initial cache.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Rebuild the cos/sin tables with positions compressed by `scaling_factor`."""
        self.max_seq_len_cached = seq_len
        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        positions = positions / self.scaling_factor

        angles = torch.outer(positions, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached", table.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", table.sin().to(dtype), persistent=False)
232
+
233
+
234
class EvaByteDynamicNTKScalingRotaryEmbedding(EvaByteRotaryEmbedding):
    """EvaByteRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        # Must be set before super().__init__, which builds the initial cache.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Rebuild the cos/sin tables, rescaling the RoPE base whenever the
        requested length exceeds the trained context window."""
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            # Stretch the base so frequencies interpolate instead of extrapolate.
            stretch = (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            base = self.base * stretch ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
            self.register_buffer("inv_freq", inv_freq, persistent=False)

        positions = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        angles = torch.outer(positions, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached", table.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", table.sin().to(dtype), persistent=False)
256
+
257
+
258
class EvaByteMLP(nn.Module):
    """Gated feed-forward block: down_proj(act(gate_proj(x)) * up_proj(x))."""
    def __init__(self, config, layer_idx: int = None):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        # Activation is resolved by name from the transformers ACT2FN registry.
        self.act_fn = ACT2FN[config.hidden_act]
        self.layer_idx = layer_idx
        self.config = config

    def forward(self, x):
        """Apply the gated MLP to `x` (shape-preserving on the hidden dim)."""
        gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)
273
+
274
class EvaByteDecoderLayer(nn.Module):
    """One EvaByte transformer block: pre-norm attention followed by a
    pre-norm MLP, each wrapped in a residual connection. When
    `config.fp32_skip_add` is set, the residual accumulation runs in float32
    before being cast back to the sub-layer output dtype."""

    def __init__(self, config: EvaByteConfig, layer_idx: int = None):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.self_attn = EvaAttention(config=config, layer_idx=layer_idx)
        self.mlp = EvaByteMLP(config, layer_idx=layer_idx)
        self.input_layernorm = EvaByteRMSNorm(config)
        self.post_attention_layernorm = EvaByteRMSNorm(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cos: Optional[torch.Tensor] = None,
        sin: Optional[torch.Tensor] = None,
        multibyte_decoding: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Returns (hidden_states[, attn_weights][, present_key_value])."""
        # --- attention sub-block (pre-norm + residual) ---
        skip = hidden_states.float() if self.config.fp32_skip_add else hidden_states
        attn_out, attn_weights, present_key_value = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cos=cos,
            sin=sin,
            multibyte_decoding=multibyte_decoding,
        )
        hidden_states = (skip + attn_out).to(attn_out.dtype)

        # --- feed-forward sub-block (pre-norm + residual) ---
        skip = hidden_states.float() if self.config.fp32_skip_add else hidden_states
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = (skip + mlp_out).to(mlp_out.dtype)

        outputs = (hidden_states, )
        if output_attentions:
            outputs += (attn_weights, )
        if use_cache:
            outputs += (present_key_value, )
        return outputs
330
+
331
class EvaBytePreTrainedModel(PreTrainedModel):
    """Base class hooking EvaByte modules into the HF PreTrainedModel machinery
    (weight init, checkpointing support, device-placement hints)."""
    config_class = EvaByteConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["EvaByteDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights from N(0, initializer_range);
        zero biases and the padding-index embedding row."""
        std = getattr(self.config, "initializer_range", 0.02)
        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        """Toggle gradient checkpointing on the wrapped EvaByteModel."""
        if isinstance(module, EvaByteModel):
            module.gradient_checkpointing = value
352
+
353
class EvaByteModel(EvaBytePreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`EvaByteDecoderLayer`]

    Args:
        config: EvaByteConfig
    """
    def __init__(self, config: EvaByteConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.max_position_embeddings = self.config.max_position_embeddings

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([EvaByteDecoderLayer(config, layer_idx=layer_idx) for layer_idx in range(config.num_hidden_layers)])
        self.norm = EvaByteRMSNorm(config)

        self.gradient_checkpointing = False
        # RoPE base frequency (theta); consumed by _init_rope() below.
        self.rope = config.rope_theta
        # Initialize weights and apply final processing
        self.post_init()
        self._init_rope()

    def _init_rope(self):
        """Select the rotary-embedding flavor from config.rope_scaling:
        None -> plain RoPE; otherwise "linear" or "dynamic" (NTK) scaling."""
        if self.config.rope_scaling is None:
            self.rotary_emb = EvaByteRotaryEmbedding(self.head_dim,
                                                     max_position_embeddings=self.max_position_embeddings,
                                                     base=self.rope)
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = EvaByteLinearScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope)
            elif scaling_type == "dynamic":
                self.rotary_emb = EvaByteDynamicNTKScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope)
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def _helper_padding_mask(
        self,
        padding_mask,
        causal_mask
    ):
        """Symmetrize the padding mask (a position pair is masked if either
        side is padding), then OR with the causal mask. True == masked out."""
        padding_mask = torch.logical_or(padding_mask, padding_mask.transpose(-1, -2))
        return torch.logical_or(padding_mask, causal_mask)

    def _prepare_eva_generation_attn_mask_triton(
        self,
        attention_mask,
        input_ids,
        use_cache,
        past_key_values
    ):
        """Build the EVA attention masks for the Triton kernel path.

        Returns a 3-tuple (cur_s_mask, cur_rf_mask, rfa_chunks_dummy_mask)
        during decoding (cache non-empty), and a 2-tuple (s_mask, rf_mask)
        during prefill. NOTE(review): the two branches return tuples of
        different arity — presumably the downstream attention code
        distinguishes them; confirm against EvaAttention.
        """
        batch_size, seq_len = input_ids.shape
        if use_cache and past_key_values.get_seq_length() > 0:
            # decoding phase
            if past_key_values.rf_mask[0] is not None:
                cur_rf_mask = torch.zeros(
                    (batch_size, 1, seq_len, 1),
                    dtype=past_key_values.rf_mask[0].dtype,
                    device=past_key_values.rf_mask[0].device
                )
            else:
                cur_rf_mask = None

            if past_key_values.s_mask[0] is not None:
                cur_s_mask = torch.zeros(
                    (batch_size, 1, seq_len, 1),
                    dtype=past_key_values.s_mask[0].dtype,
                    device=past_key_values.s_mask[0].device
                )
            else:
                cur_s_mask = None

            seen_tokens = past_key_values.get_seq_length()
            if seen_tokens <= self.config.window_size:
                rfa_chunks_dummy_mask = None
            else:
                if cur_s_mask is not None:
                    chunks_per_window = int(self.config.window_size // self.config.chunk_size)
                    # the ongoing decoding step would be (seen_seq_len + 1)-th token
                    num_windows_seen_so_far = seen_tokens // self.config.window_size
                    rfa_chunks_dummy_mask = torch.zeros(
                        (batch_size, 1, seq_len, num_windows_seen_so_far * chunks_per_window),
                        dtype=past_key_values.s_mask[0].dtype,
                        device=past_key_values.s_mask[0].device
                    )
                else:
                    rfa_chunks_dummy_mask = None
            # rf_mask and cur_mask are 0s because we do not want to mask them
            return (cur_s_mask, cur_rf_mask, rfa_chunks_dummy_mask)

        if attention_mask is not None and torch.any(attention_mask == 0.0):
            # Prefill with real padding: pad the mask up to a multiple of the
            # window size, then build per-window causal+padding masks.
            padded_attention_mask = pad_to_multiple(
                attention_mask,
                self.config.window_size,
                dim=-1,
                value=0,
                create_mask=False,
                left_padding=False
            )
            # convert 0 -> padding to 1 -> padding
            padded_rf_mask = ~padded_attention_mask.unsqueeze(1).unsqueeze(-1).to(torch.bool)  # [b, 1, n, 1]
            # [b, 1, w, j, 1]
            padded_w_attn_mask = padded_rf_mask.reshape(batch_size, 1, -1, self.config.window_size, 1).to(torch.bool)
            # [b, 1, w, j, 1] [b, 1, w, 1, j] -> [b, 1, w, j, j]
            w_padding_mask = torch.logical_or(padded_w_attn_mask, padded_w_attn_mask.transpose(-1, -2))
            w_causal_mask = torch.ones(
                (1, 1, 1, self.config.window_size, self.config.window_size),
                device=input_ids.device
            ).triu(1).to(torch.bool)
            s_mask = torch.logical_or(w_padding_mask, w_causal_mask)
            s_mask = s_mask.reshape(batch_size, 1, -1, self.config.window_size)
            s_mask = s_mask[..., :seq_len, :]
            # negate the attention mask to get the padding mask
            rf_mask = ~attention_mask.unsqueeze(1).unsqueeze(-1).to(torch.bool)  # [b, 1, n, 1]
            return (s_mask, rf_mask)
        else:
            # No padding: kernels can run unmasked.
            return (None, None)

    def _prepare_eva_generation_attn_mask(
        self,
        attention_mask,
        input_ids,
        use_cache,
        past_key_values
    ):
        """Build the EVA attention masks for the PyTorch (non-Triton) path.

        Returns (prev_window_mask, cur_window_mask, chunk_causal_mask, rf_mask).
        During decoding (cache non-empty) the window masks collapse to
        per-step dummies; during prefill they cover complete windows plus an
        optional remainder window.
        """
        batch_size, seq_len = input_ids.shape
        if use_cache and past_key_values.get_seq_length() > 0:
            # decoding phase
            if past_key_values.rf_mask[0] is not None:
                rf_mask = torch.zeros(
                    (batch_size, 1, seq_len, 1),
                    dtype=past_key_values.rf_mask[0].dtype,
                    device=past_key_values.rf_mask[0].device
                )
            else:
                rf_mask = None

            cur_causal_mask = torch.zeros(
                (batch_size, 1, seq_len, 1),
                dtype=torch.bool,
                device=input_ids.device
            )

            chunk_causal_mask = torch.ones(
                (batch_size, 1, seq_len, 1),
                dtype=torch.bool,
                device=input_ids.device
            )
            # chunk_causal_mask are 1s because we will mask them by default and
            # will be unmasked when the current singleton attention is processed over
            return (None, cur_causal_mask, chunk_causal_mask, rf_mask)

        true_num_chunks = seq_len // self.config.chunk_size
        chunk_causal_mask, _ = prepare_eva_attention_mask(
            seq_len,
            input_ids.device,
            self.config.chunk_size,
            self.config.window_size,
            use_cache=use_cache,
            cache=past_key_values
        )
        chunk_causal_mask = chunk_causal_mask[..., :seq_len, :true_num_chunks]
        if attention_mask is not None and torch.any(attention_mask == 0.0):
            # convert 0 -> padding to 1 -> padding
            rf_mask = ~attention_mask.unsqueeze(1).unsqueeze(-1).to(torch.bool)  # [b, 1, n, 1]
        else:
            rf_mask = None

        if seq_len < self.config.window_size:
            # Whole prompt fits in a single (partial) window.
            cur_window_mask = torch.ones(
                (1, 1, seq_len, seq_len),
                device=input_ids.device
            ).triu(1).to(torch.bool)
            if rf_mask is not None:
                cur_window_mask = self._helper_padding_mask(rf_mask, cur_window_mask)
            prev_window_mask = None
        else:
            if seq_len % self.config.window_size == 0:
                # Exact multiple of window_size: only full windows.
                num_windows = seq_len // self.config.window_size
                cur_window_mask = None
                prev_window_mask = torch.ones(
                    (1, 1, num_windows, self.config.window_size, self.config.window_size),
                    device=input_ids.device
                ).triu(1).to(torch.bool)
                if rf_mask is not None:
                    prev_rf_mask = rf_mask.reshape(batch_size, 1, -1, self.config.window_size, 1)
                    prev_window_mask = self._helper_padding_mask(prev_rf_mask, prev_window_mask)
            else:
                # Full windows plus a trailing partial window.
                num_windows = seq_len // self.config.window_size
                remainder_tokens = seq_len % self.config.window_size
                cur_window_mask = torch.ones(
                    (1, 1, remainder_tokens, remainder_tokens),
                    device=input_ids.device
                ).triu(1).to(torch.bool)
                prev_window_mask = torch.ones(
                    (1, 1, num_windows, self.config.window_size, self.config.window_size),
                    device=input_ids.device
                ).triu(1).to(torch.bool)
                if rf_mask is not None:
                    prev_rf_mask, cur_rf_mask = torch.split(rf_mask, [seq_len - remainder_tokens, remainder_tokens], dim=-2)
                    cur_window_mask = self._helper_padding_mask(cur_rf_mask, cur_window_mask)
                    prev_rf_mask = prev_rf_mask.reshape(batch_size, 1, -1, self.config.window_size, 1)
                    prev_window_mask = self._helper_padding_mask(prev_rf_mask, prev_window_mask)

        return (prev_window_mask, cur_window_mask, chunk_causal_mask, rf_mask)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        multibyte_decoding: Optional[bool] = None,
    ) -> Tuple:
        """Run the decoder stack.

        Builds the EVA masks (inference) or takes them via `attention_mask`
        (training / multibyte decoding), gathers RoPE cos/sin by position,
        applies each decoder layer, and final-norms the result.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else self.config.output_hidden_states)
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if self.gradient_checkpointing and self.training and use_cache:
            # NOTE(review): the message says "Setting `use_cache=False`" but this
            # raises instead of downgrading — consider a warning + override.
            raise ValueError("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")

        # NOTE(review): this dereferences input_ids before the embeds-only path;
        # calling with inputs_embeds and no input_ids would fail here.
        batch_size, seq_len = input_ids.shape
        #### Step 0. Hack
        if (not self.training) and (not use_cache) and (not multibyte_decoding):
            # forward-only inference mode.
            # We tweak use_cache to be True to reuse code for generation
            use_cache = True
        device = input_ids.device if input_ids is not None else None
        if position_ids is None:
            position_ids = torch.arange(0, seq_len, device=device, dtype=int).reshape(1, -1).expand(batch_size, -1)

        #### Step 1. Prepare caches if in inference mode
        if use_cache:
            if past_key_values is not None:
                assert isinstance(past_key_values, Cache)
            else:
                # Lazily create the cache matching the active kernel backend.
                if not USE_TRITON_IMPL:
                    past_key_values = EvaCache()
                else:
                    past_key_values = EvaStaticCacheForTriton(
                        input_ids.shape[0],
                        self.config.num_attention_heads,
                        self.config.window_size,
                        self.config.hidden_size // self.config.num_attention_heads,
                        self.config.num_hidden_layers,
                        self.embed_tokens.weight.dtype,
                        self.embed_tokens.weight.device,
                    )

        #### Step 2. Build (or pass through) the attention masks.
        if not multibyte_decoding:
            if use_cache:
                if USE_TRITON_IMPL:
                    causal_mask = self._prepare_eva_generation_attn_mask_triton(
                        attention_mask,
                        input_ids,
                        use_cache,
                        past_key_values
                    )
                else:
                    causal_mask = self._prepare_eva_generation_attn_mask(
                        attention_mask,
                        input_ids,
                        use_cache,
                        past_key_values
                    )
            else:
                assert self.training
                assert seq_len % self.config.window_size == 0, "Training is only tested for sequences that are a multiple of window_size"
                # for training, we need to pass in the attention mask
                # usually calculated by _prepare_training_attn_mask()
                causal_mask = attention_mask
        else:
            # Multibyte (speculative) decoding supplies its own tree mask.
            assert use_cache
            causal_mask = attention_mask

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        max_seq_length = past_seen_tokens + inputs_embeds.shape[1]

        hidden_states = inputs_embeds

        if position_ids is None:
            assert not use_cache, "during decoding we must explicitly pass position_ids to the model call"
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(past_seen_tokens, max_seq_length, device=device, dtype=int).reshape(1, -1).expand(batch_size, -1)

        #### Step 3. Gather RoPE tables for the actual positions.
        cos, sin = self.rotary_emb(hidden_states, seq_len=max_seq_length)
        assert len(cos.shape) == 2, f"cos should be of shape (max_seq_len, head_dim), got {cos.shape} instead"
        assert sin.shape == cos.shape, f"sin should be of shape (max_seq_len, head_dim), got {sin.shape} instead"
        assert len(position_ids.shape) == 2, f"position_ids should be of 2D, got {position_ids.shape} instead"
        cos = cos[position_ids, :]
        sin = sin[position_ids, :]
        # Insert a broadcast head dimension: [b, 1, s, head_dim].
        cos = cos.unsqueeze(1)
        sin = sin.unsqueeze(1)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states, )

            if self.gradient_checkpointing and self.training:
                layer_outputs = torch.utils.checkpoint.checkpoint(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cos,
                    sin,
                    multibyte_decoding,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cos=cos,
                    sin=sin,
                    multibyte_decoding=multibyte_decoding,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                # Cache slot index depends on whether attn weights were emitted.
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1], )

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states, )

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
741
+
742
+
743
class EvaByteForCausalLM(EvaBytePreTrainedModel, MultiByteDecodingMixin):
    """Causal LM head on top of EvaByteModel, with optional multibyte
    prediction heads (config.num_pred_heads > 1) used for speculative
    multibyte decoding via MultiByteDecodingMixin."""
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        EvaBytePreTrainedModel.__init__(self, config)

        self.model = EvaByteModel(config)
        self.vocab_size = config.vocab_size
        # define multibyte prediction heads: one fused Linear emitting
        # num_pred_heads * vocab_size logits per position.
        if hasattr(config, "num_pred_heads") and config.num_pred_heads > 1:
            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size * config.num_pred_heads, bias=False)
        else:
            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        return_all_pred_logits: Optional[bool] = None,
        multibyte_decoding: Optional[bool] = None) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the decoder and LM head.

        With multiple prediction heads, `return_all_pred_logits=True` returns
        logits of shape [b, s, num_pred_heads, vocab]; otherwise only the
        first head's logits are returned.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else self.config.output_hidden_states)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is None:
            assert past_key_values is None

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            multibyte_decoding=multibyte_decoding,
        )

        hidden_states = outputs[0]

        logits = self.lm_head(hidden_states)
        if self.config.fp32_logits:
            logits = logits.float()

        loss = None
        if labels is not None:
            # NOTE(review): reduction="none" — `loss` is a per-token vector,
            # not a scalar; callers are expected to reduce it themselves.
            loss_fct = CrossEntropyLoss(reduction="none")
            if hasattr(self.config, "num_pred_heads") and self.config.num_pred_heads > 1:
                shift_logits = logits.view(logits.shape[0], logits.shape[1], self.config.num_pred_heads, self.config.vocab_size)
                # shift_logits = shift_logits.view(-1, logits.shape[1] * self.config.num_pred_heads, self.config.vocab_size)
                shift_logits = shift_logits.view(-1, self.config.vocab_size)
            else:
                shift_logits = logits.view(-1, self.config.vocab_size)
            # NOTE(review): despite the "shift_" names no shifting happens here —
            # presumably `labels` arrive pre-shifted/aligned; confirm with the
            # training pipeline.
            shift_labels = labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if hasattr(self.config, "num_pred_heads") and self.config.num_pred_heads > 1:
            # Unfold the fused head dimension: [b, s, num_pred_heads, vocab].
            all_pred_logits = logits.reshape(logits.shape[0], logits.shape[1], self.config.num_pred_heads, self.config.vocab_size)

            if return_all_pred_logits:
                logits = all_pred_logits
            else:
                # Default: expose only the primary (next-byte) head.
                logits = all_pred_logits[..., 0, :]

        if not return_dict:
            output = (logits, ) + outputs[1:]
            return (loss, ) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


    def prepare_inputs_for_generation(self,
                                      input_ids,
                                      past_key_values=None,
                                      attention_mask=None,
                                      inputs_embeds=None,
                                      use_cache=True,
                                      **kwargs):
        """Slice inputs down to the not-yet-cached suffix and derive
        position_ids for generation.

        Expected shapes:
        prefill phase:
            input_ids: b x s
            attention_mask: None if no padding or b x s
            position_ids : b x s
        token gen phase:
            input_ids : b x 1
            attention_mask: b x 1 x s
            position_ids: b x 1
        """
        past_length = 0
        if past_key_values is not None:
            assert isinstance(past_key_values, Cache)
            past_length = past_key_values.get_seq_length()

            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                # Only the tokens not yet reflected in the cache are fed in.
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # Derive positions from the padding mask (padding positions get 1).
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1]:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        # must initialize position_ids at each step during GPU inference
        assert position_ids is not None
        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Reorder every layer's cached tensors along the batch dim to follow
        beam-search hypothesis reordering."""
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (tuple(
                past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), )
        return reordered_past
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-30000/multibyte_decoding_evabyte.py ADDED
@@ -0,0 +1,881 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # The implementation of multibyte decoding is largely adapted from
3
+ # Medusa decoding: https://github.com/FasterDecoding/Medusa
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from transformers.generation.stopping_criteria import (
7
+ MaxLengthCriteria,
8
+ StoppingCriteriaList,
9
+ )
10
+ from typing import Union, List
11
+ from .eva_cache import EvaStaticCacheForTriton
12
+ from .eva_prep_kv_kernel import triton_eva_prep_kv_fwd
13
+
14
class MultibyteEosTokenCriteria:
    """
    Stopping criteria that halts generation as soon as any configured
    end-of-sequence token appears among the last `new_tokens` generated ids.

    Adapted from
    https://github.com/huggingface/transformers/blob/main/src/transformers/generation/stopping_criteria.py#L446
    By default, it uses the `model.generation_config.eos_token_id`.

    Args:
        eos_token_ids (`Union[int, List[int]]`):
            The id(s) of the *end-of-sequence* token.
    """

    def __init__(self, eos_token_ids: Union[int, List[int]]):
        self.eos_token_ids = [eos_token_ids] if isinstance(eos_token_ids, int) else eos_token_ids

    def __call__(self, input_ids: torch.LongTensor, new_tokens: int) -> bool:
        # Only the freshly generated tail of the sequence is inspected.
        tail = input_ids[:, input_ids.shape[-1] - new_tokens:]
        return any(bool(torch.any(tail == eos_id)) for eos_id in self.eos_token_ids)
40
+
41
def build_tree(spec):
    """
    Expand a per-depth branching spec into a flat list of tree paths.

    `spec[d][i]` gives the number of children of the i-th node at depth d;
    nodes with no entry in the spec get zero children. Each returned node is
    a tuple of child indices from the root; the root itself is excluded.
    Nodes are listed depth by depth, in creation order.
    """
    levels = [[()]]  # depth 0 holds just the root
    for spec_list in spec:
        children = []
        for parent_idx, parent in enumerate(levels[-1]):
            fanout = spec_list[parent_idx] if parent_idx < len(spec_list) else 0
            children.extend(parent + (child,) for child in range(fanout))
        levels.append(children)

    # Flatten level by level, dropping the (empty-tuple) root.
    return [node for level in levels for node in level if node]
62
+
63
# Candidate-tree layouts for multibyte (Medusa-style) speculative decoding.
# Each inner list gives, per node at the previous depth, how many children to
# expand; build_tree() flattens the spec into explicit child-index paths.
evabyte_7b_95 = build_tree(
    [
        [10],
        [10, 8, 2, 2, 1, 1],
        [10, 4, 2, 1, 0, 0, 0, 0, 0, 0, 2, 1, 1, 0, 0, 0, 0, 0, 1],
        [8, 2, 2, 1, 0, 0, 0, 0, 0, 0, 1],
        [6, 2, 1, 1],
        [4, 2, 1, 1],
        [4, 2, 1],
    ]
)
# Smaller tree variant with fewer candidate nodes.
evabyte_7b_31 = build_tree(
    [
        [4],
        [3, 2, 1, 1],
        [3, 2, 1, 1],
        [2, 1, 1],
        [2, 1],
        [2, 1],
        [2, 1],
    ]
)
TOPK = 10  # topk for sparse tree (10 is a placeholder and it is sufficient)
86
+
87
def pad_path(path, length, pad_value=-2):
    """
    Return a new list equal to `path` right-padded with `pad_value` up to
    `length` entries.

    Parameters:
    - path (list): The original list that needs padding.
    - length (int): The desired length of the padded list.
    - pad_value (optional, default=-2): The value to use for padding.

    Returns:
    - list: A new list based on the original path but padded to the desired length.

    Example:
    >>> pad_path([1,2,3], 5)
    [1, 2, 3, -2, -2]

    Note:
        If the given path is already longer than the specified length,
        no padding occurs and an unpadded copy is returned.
    """
    # max() guards against negative padding when path is already long enough.
    extra = max(0, length - len(path))
    return path + [pad_value] * extra
108
+
109
def reset_past_key_values(passed_key_values):
    """
    Reset the `current_length` counters of cached key/value tensors to zero.

    Intended for baseline-model evaluation: every layer's key cache and value
    cache (indices 0 and 1) has its `current_length` zeroed in place, which
    effectively empties the cache without reallocating it.

    Args:
    - passed_key_values (list of torch.Tensor): Contains past hidden states and past attention values for each layer.

    Returns:
    - passed_key_values (list of torch.Tensor): Updated past hidden states and past attention values with reset lengths.
    """
    for layer_kv in passed_key_values:
        for slot in (layer_kv[0], layer_kv[1]):
            slot.current_length.fill_(0)
    return passed_key_values
127
+
128
def get_nucleus_one_token(logit, temperature, top_p):
    """
    Sample one token per row via nucleus (top-p) sampling.

    Args:
        logit (torch.Tensor): 2D logits from a language model output (BxC).
        temperature (float): Softmax temperature; higher values increase
            diversity, lower values make selections more deterministic.
        top_p (float): Cumulative probability threshold; tokens outside the
            top-p nucleus are masked out before sampling.

    Returns:
        torch.Tensor: Indices of the sampled tokens, shape (B, 1).
    """
    # top_p >= 1 means the nucleus is the full vocabulary: plain sampling.
    if top_p >= 1:
        return torch.multinomial(F.softmax(logit / temperature, dim=-1), 1)
    scaled = logit / temperature
    probs = torch.softmax(scaled, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Mark everything past the top-p mass; shift right by one so the first
    # token crossing the threshold is still kept.
    drop_sorted = cumulative > top_p
    drop_sorted[..., 1:] = drop_sorted[..., :-1].clone()
    drop_sorted[..., 0] = 0
    # Map the removal mask from sorted order back to vocabulary order.
    drop = drop_sorted.scatter(dim=1, index=sorted_idx, src=drop_sorted)
    scaled[drop] = float('-inf')
    return torch.multinomial(F.softmax(scaled, dim=-1), 1)
158
+
159
def get_typical_one_token(logit, temperature, posterior_threshold, posterior_alpha):
    """
    Sample one token per row via typical sampling.

    Tokens whose probability falls below the entropy-adaptive threshold
    min(posterior_threshold, exp(-entropy) * posterior_alpha) are masked out
    before sampling.

    Args:
        logit (torch.Tensor): 2D logits from a language model output.
        temperature (float): Softmax temperature controlling randomness.
        posterior_threshold (float): Upper cap on the acceptance threshold.
        posterior_alpha (float): Scale on the entropy-based adaptive threshold.

    Returns:
        torch.Tensor: Indices of the sampled tokens, shape (B, 1).
    """
    scaled = logit / temperature
    probs = torch.softmax(scaled, dim=-1)
    # Shannon entropy per row; the epsilon guards log(0).
    entropy = -(probs * torch.log(probs + 1e-5)).sum(dim=-1)
    threshold = torch.minimum(
        torch.full_like(entropy, posterior_threshold),
        torch.exp(-entropy) * posterior_alpha,
    )
    scaled[probs < threshold.unsqueeze(-1)] = float('-inf')
    return torch.multinomial(F.softmax(scaled, dim=-1), 1)
189
+
190
+
191
+
192
def generate_medusa_buffers(medusa_choices, device="cuda"):
    """
    Generate buffers for the Medusa structure based on the provided choices.

    Parameters:
    - medusa_choices (list): A nested list of tuple paths (as produced by
      `build_tree`) representing the sparse speculation tree.
    - device (str): Device to which the tensors should be moved. Default is "cuda".

    Returns:
    - dict: A dictionary containing buffers related to the Medusa structure:
      "medusa_attn_mask" (1, 1, tree_len, tree_len) tree attention mask,
      "tree_indices" flattening of top-k picks into tree order,
      "medusa_position_ids" (1, tree_len) depth-based positions,
      "retrieve_indices" (num_paths, max_depth+1) root-to-leaf index paths.
    """

    # Sort the medusa_choices based on their lengths and then their values,
    # so all nodes at a given depth are contiguous and deterministically ordered.
    sorted_medusa_choices = sorted(medusa_choices, key=lambda x: (len(x), x))
    # +1 accounts for the root token prepended at index 0.
    medusa_len = len(sorted_medusa_choices) + 1

    # Initialize depth_counts to keep track of how many choices have a particular depth
    depth_counts = [0] * max([len(path) for path in sorted_medusa_choices])
    for path in sorted_medusa_choices:
        depth_counts[len(path) - 1] += 1

    # Create the attention mask for Medusa: each node attends to itself (eye),
    # the root (column 0), and all of its ancestors in the tree.
    medusa_attn_mask = torch.eye(medusa_len, medusa_len)
    medusa_attn_mask[:, 0] = 1
    start = 0
    for i in range(len(depth_counts)):
        for j in range(depth_counts[i]):
            cur_medusa_choice = sorted_medusa_choices[start + j]
            # retrieve ancestor position; depth-1 nodes only see the root
            if len(cur_medusa_choice) == 1:
                continue
            ancestor_idx = []
            for c in range(len(cur_medusa_choice) - 1):
                # +1 shifts past the root at index 0
                ancestor_idx.append(sorted_medusa_choices.index(cur_medusa_choice[:c+1]) + 1)
            medusa_attn_mask[j + start + 1, ancestor_idx] = 1
        start += depth_counts[i]

    # Generate tree indices for the Medusa structure: maps each tree node to
    # its slot in the flattened [base token, head-0 top-K, head-1 top-K, ...]
    # candidate layout (hence the TOPK * depth offset).
    medusa_tree_indices = torch.zeros(medusa_len, dtype=torch.long)
    medusa_tree_indices[0] = 0
    start = 0
    for i in range(len(depth_counts)):
        for j in range(depth_counts[i]):
            cur_medusa_choice = sorted_medusa_choices[start + j]
            medusa_tree_indices[start + j + 1] = cur_medusa_choice[-1] + TOPK * i + 1
        start += depth_counts[i]

    # Generate position IDs for the Medusa structure: position equals depth.
    medusa_position_ids = torch.zeros(medusa_len, dtype=torch.long)
    start = 0
    for i in range(len(depth_counts)):
        medusa_position_ids[start + 1: start + depth_counts[i] + 1] = i + 1
        start += depth_counts[i]

    # Generate retrieval indices for Medusa structure verification: one
    # root-to-leaf index path per leaf (longest-first so shared prefixes are
    # only emitted once via retrieve_paths deduplication).
    retrieve_indices_nest = []
    retrieve_paths = []
    for i in range(len(sorted_medusa_choices)):
        cur_medusa_choice = sorted_medusa_choices[-i-1]
        retrieve_indice = []
        if cur_medusa_choice in retrieve_paths:
            continue
        else:
            for c in range(len(cur_medusa_choice)):
                retrieve_indice.append(sorted_medusa_choices.index(cur_medusa_choice[:c+1]))
                retrieve_paths.append(cur_medusa_choice[:c+1])
        retrieve_indices_nest.append(retrieve_indice)
    max_length = max([len(x) for x in retrieve_indices_nest])
    # Short paths are padded with -2; after the +1 shift below they become -1,
    # which indexes the appended dummy token in generate_candidates.
    retrieve_indices = [pad_path(path, max_length) for path in retrieve_indices_nest]
    retrieve_indices = torch.tensor(retrieve_indices, dtype=torch.long)
    retrieve_indices = retrieve_indices + 1
    # Prepend the root (index 0) to every path.
    retrieve_indices = torch.cat([torch.zeros((retrieve_indices.shape[0], 1), dtype=torch.long), retrieve_indices], dim=1)

    # Aggregate the generated buffers into a dictionary
    medusa_buffers = {
        "medusa_attn_mask": medusa_attn_mask.unsqueeze(0).unsqueeze(0),
        "tree_indices": medusa_tree_indices,
        "medusa_position_ids": medusa_position_ids.unsqueeze(0),
        "retrieve_indices": retrieve_indices,
    }

    # Move the tensors in the dictionary to the specified device
    medusa_buffers = {
        k: v.clone().to(device)
        if isinstance(v, torch.Tensor)
        else torch.tensor(v, device=device)
        for k, v in medusa_buffers.items()
    }
    return medusa_buffers
281
+
282
def generate_candidates(
    medusa_logits,
    logits,
    tree_indices,
    retrieve_indices,
    temperature = 0,
    posterior_threshold=0.3,
    posterior_alpha = 0.09,
    top_p=0.8,
    sampling = 'typical',
    fast = False
):
    """
    Pick the next base token and expand the Medusa heads' top-k predictions
    into the sparse-tree candidate layout.

    Returns:
        tree_candidate_ids (torch.Tensor): (1, tree_len) ids in tree order,
            fed to the model for tree verification.
        unflattened_candidate_ids (torch.Tensor): (num_paths, depth+1) ids
            per root-to-leaf path, used to score/accept candidates.
    """
    # Say we have 3 heads, and the top-4 for each head are:
    # [10, 3, 8, 4]
    # [9, 5, 1, 6]
    # [7, 16, 3, 2]

    # candidates_id = 10
    # The base token is greedy (argmax) unless sampling is requested.
    if temperature == 0 or fast:
        candidates_ids = torch.argmax(logits[:, -1]).unsqueeze(0)
    else:
        if sampling == 'typical':
            candidates_ids = get_typical_one_token(logits[:, -1], temperature, posterior_threshold, posterior_alpha).squeeze(0)
        elif sampling == 'nucleus':
            candidates_ids = get_nucleus_one_token(logits[:, -1], temperature, top_p).squeeze(0)
        else:
            raise NotImplementedError

    # this calculates the top-k medusa logits
    # candidates_medusa_id = [
    #    [9, 5, 1, 6]
    #    [7, 16, 3, 2]
    # ]
    candidates_medusa_ids = torch.topk(medusa_logits[:, 0, -1], TOPK, dim=-1).indices

    # Flatten base token + per-head top-k into one vector:
    # [10, 9, 5, 1, 6, 7, 16, 3, 2]
    candidate_ids = torch.cat([candidates_ids, candidates_medusa_ids.view(-1)], dim=-1)

    # based on the pre-defined tree_indices, select the corresponding candidates
    # if we select top-2 and top-3 for the two heads (we select top-1 for the first head):
    # tree_candidates = [10, 9, 5, 7, 16, 3, 7, 16, 3]
    tree_candidate_ids = candidate_ids[tree_indices]

    # tree_candidate_ids = [10, 9, 5, 7, 16, 3, 7, 16, 3, 0]
    # Sometimes the tree_indices are padded, so we append a zero here
    # so that all padded indices select the appended zero.
    tree_candidate_ids_ext = torch.cat(
        [
            tree_candidate_ids,
            torch.zeros((1), dtype=torch.long, device=tree_candidate_ids.device)
        ],
        dim=0
    )
    # Gather every root-to-leaf path, e.g.
    # [[10, 9, 7], [10, 9, 16], [10, 9, 3], [10, 5, 7], [10, 5, 16], [10, 5, 3]]
    unflattened_candidate_ids = tree_candidate_ids_ext[retrieve_indices]

    # Add the batch dimension expected by the model forward pass.
    tree_candidate_ids = tree_candidate_ids.unsqueeze(0)

    return tree_candidate_ids, unflattened_candidate_ids
341
+
342
def get_nucleus_posterior_mask(logits, candidates, temperature, top_p):
    """
    Build a 0/1 mask marking which candidate tokens match tokens drawn by
    nucleus (top-p) sampling from the model logits.

    Args:
        logits (torch.Tensor): Logits over each candidate path,
            (n_candidates, seq_len, vocab).
        candidates (torch.Tensor): Candidate token ids, (n_candidates, seq_len).
        temperature (float): Softmax temperature.
        top_p (float): Cumulative probability threshold for the nucleus.

    Returns:
        torch.Tensor: Int mask of shape (n_candidates, seq_len - 1).
    """
    # adapted from https://github.com/huggingface/transformers/blob/18a879f47576822aa1a5c49aecb27d89bfa5fa69/examples/run_generation.py#L79

    # Drop the last position (it has no candidate to verify) and scale.
    scaled = logits[:, :-1] / temperature
    n_samples, n_tokens = scaled.shape[0], scaled.shape[1]
    scaled = scaled.view(n_samples * n_tokens, -1)

    # Full-vocabulary nucleus: sample directly, no masking needed.
    if top_p >= 1:
        drawn = torch.multinomial(F.softmax(scaled, dim=-1), 1).view(n_samples, n_tokens)
        return (candidates[:, 1:] == drawn).int()

    probs = F.softmax(scaled, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)

    # Removal mask over sorted order, shifted right by one so the first
    # token crossing the top-p boundary is retained.
    drop_sorted = cumulative > top_p
    drop_sorted[..., 1:] = drop_sorted[..., :-1].clone()
    drop_sorted[..., 0] = 0

    # Map the removal mask back to vocabulary order.
    drop = drop_sorted.scatter(dim=1, index=sorted_idx, src=drop_sorted)

    # Remove low-probability tokens and sample from what remains.
    scaled[drop] = float('-inf')
    drawn = torch.multinomial(F.softmax(scaled, dim=-1), 1).view(n_samples, n_tokens)

    # 1 where the candidate agrees with the sampled token.
    return (candidates[:, 1:] == drawn).int()
395
+
396
def get_typical_posterior_mask(logits, candidates, temperature, posterior_threshold, posterior_alpha):
    """
    Build a 0/1 mask marking which candidate tokens match tokens drawn by
    typical sampling from the model logits.

    Args:
        logits (torch.Tensor): Logits over each candidate path,
            (n_candidates, seq_len, vocab).
        candidates (torch.Tensor): Candidate token ids, (n_candidates, seq_len).
        temperature (float): Softmax temperature.
        posterior_threshold (float): Upper cap on the acceptance threshold.
        posterior_alpha (float): Scale on the entropy-based adaptive threshold.

    Returns:
        torch.Tensor: Int mask of shape (n_candidates, seq_len - 1).
    """
    # Drop the last position (no candidate follows it) and flatten positions.
    scaled = logits[:, :-1] / temperature
    n_samples, n_tokens = scaled.shape[0], scaled.shape[1]
    scaled = scaled.view(n_samples * n_tokens, -1)
    probs = F.softmax(scaled, dim=-1)
    # Entropy-adaptive cutoff: min(posterior_threshold, exp(-H) * alpha).
    entropy = -(probs * torch.log(probs + 1e-5)).sum(dim=-1)
    threshold = torch.minimum(
        torch.full_like(entropy, posterior_threshold),
        torch.exp(-entropy) * posterior_alpha,
    )
    scaled[probs < threshold.unsqueeze(-1)] = float('-inf')
    drawn = torch.multinomial(F.softmax(scaled, dim=-1), 1).view(n_samples, n_tokens)
    return (candidates[:, 1:] == drawn).int()
425
+
426
+
427
+
428
def evaluate_posterior(
    logits,
    candidates,
    temperature,
    posterior_threshold=0.3,
    posterior_alpha = 0.09,
    top_p=0.8,
    sampling = 'typical',
    fast = True
):
    """
    Select the best speculative candidate path and how many of its tokens to
    accept.

    Args:
        logits (torch.Tensor): (n_candidates, seq_len, vocab) logits over
            each candidate path.
        candidates (torch.Tensor): (n_candidates, seq_len) candidate ids.
        temperature (float): 0 for greedy verification, >0 for sampled.
        posterior_threshold / posterior_alpha (float): typical-sampling params.
        top_p (float): Nucleus cutoff (used when sampling == 'nucleus').
        sampling (str): 'typical' or 'nucleus'.
        fast (bool): Use the fast thresholded path for typical sampling.

    Returns:
        (best_candidate, accept_length): index (LongTensor scalar) of the
        chosen candidate row and the number of accepted tokens along it.
    """
    # Nothing to verify when there is at most one position.
    if logits.shape[1] <= 1:
        return torch.tensor(0, dtype=torch.long, device=candidates.device), 0

    def _pick(mask):
        # Accept the longest all-accepted prefix across candidates; fall back
        # to candidate 0 when nothing is accepted.
        lengths = torch.cumprod(mask, dim=1).sum(dim=1)
        best_len = lengths.max().item()
        if best_len == 0:
            return torch.tensor(0, dtype=torch.long, device=candidates.device), 0
        return torch.argmax(lengths).to(torch.long), best_len

    if temperature == 0:
        # Greedy verification: a token is accepted iff it equals the argmax
        # prediction at its position.
        mask = (candidates[:, 1:] == torch.argmax(logits[:, :-1], dim=-1)).int()
        return _pick(mask)

    if sampling == 'typical':
        if fast:
            posterior_prob = torch.softmax(logits[:, :-1] / temperature, dim=-1)
            candidates_prob = torch.gather(
                posterior_prob, dim=-1, index=candidates[:, 1:].unsqueeze(-1)
            ).squeeze(-1)
            posterior_entropy = -torch.sum(
                posterior_prob * torch.log(posterior_prob + 1e-5), dim=-1
            )  # torch.sum(torch.log(*)) is faster than torch.prod
            threshold = torch.minimum(
                torch.ones_like(posterior_entropy) * posterior_threshold,
                torch.exp(-posterior_entropy) * posterior_alpha,
            )
            accept_mask = candidates_prob > threshold
            lengths = torch.cumprod(accept_mask, dim=1).sum(dim=1)
            accept_length = lengths.max().item()
            if accept_length == 0:
                # If no candidates are accepted, just choose the first one
                best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
            else:
                ties = torch.where(lengths == accept_length)[0]
                # Break ties by the likelihood of the accepted prefix.
                likelihood = torch.sum(
                    torch.log(candidates_prob[ties, :accept_length]), dim=-1
                )
                best_candidate = ties[torch.argmax(likelihood)]
            return best_candidate, accept_length
        # Exact path: draw tokens with typical sampling and compare.
        mask = get_typical_posterior_mask(
            logits, candidates, temperature, posterior_threshold, posterior_alpha
        )
        return _pick(mask)

    if sampling == 'nucleus':
        assert top_p < 1.0 + 1e-6, "top_p should between 0 and 1"
        mask = get_nucleus_posterior_mask(logits, candidates, temperature, top_p)
        return _pick(mask)

    raise NotImplementedError
511
+
512
def update_inference_inputs(
    input_ids,
    medusa_logits,
    logits,
    candidate_ids,
    best_candidate,
    accept_length,
):
    """
    Append the accepted candidate tokens to `input_ids` and slice out the
    logits at the last accepted position for the next decoding iteration.

    Returns:
        (input_ids, medusa_logits, logits, new_token) where `new_token` is
        the number of tokens added this iteration (accept_length + 1).
    """
    # accept_length + 1 accounts for the base token of the accepted path.
    accepted = candidate_ids[None, best_candidate, : accept_length + 1]
    input_ids = torch.cat([input_ids, accepted], dim=-1)
    # Keep only the logits at the final accepted position.
    logits = logits[None, best_candidate, accept_length : accept_length + 1]
    medusa_logits = medusa_logits[:, None, best_candidate, accept_length : accept_length + 1]
    return input_ids, medusa_logits, logits, accept_length + 1
536
+
537
def split_logits(full_logits):
    """
    Split stacked per-head outputs into (medusa_logits, logits).

    `full_logits` has shape [b, n, heads, vocab_size]; head 0 is the base LM
    head and the remaining heads are Medusa heads, returned with the head
    axis moved to the front ([heads-1, b, n, vocab_size]).
    """
    lm_logits = full_logits[..., 0, :]
    medusa = full_logits[..., 1:, :].permute(2, 0, 1, 3)
    return medusa, lm_logits
542
+
543
class MultiByteDecodingMixin:
    """
    Mixin implementing Medusa-style multi-byte speculative decoding for
    EvaByte models with a windowed EVA attention cache.

    NOTE(review): assumes the host model provides `self.config` (with
    `window_size`, `chunk_size`, `num_hidden_layers`, `num_attention_heads`,
    `hidden_size`), `self.model.layers[*].self_attn` EVA parameters,
    `self.lm_head`, and a `forward` accepting `return_all_pred_logits` /
    `multibyte_decoding` — confirm against the model class.
    """
    def multi_byte_pred_update_cache(
        self,
        past_key_values,
        retrieve_indices,
        best_candidate,
        new_tokens,
    ):
        """
        Compact the per-layer window KV cache after a tree-decoding step:
        keep only the `new_tokens` accepted positions along the best
        candidate path, and when the sliding window fills up, dump the
        completed window into chunk-level RFA summaries.
        """
        prev_window_len = past_key_values.get_past_window_pos(0)
        # Cache slots (within the window buffer) of the accepted tree tokens.
        select_indices = (
            retrieve_indices[best_candidate, : new_tokens] + prev_window_len
        )
        for layer_idx in range(self.config.num_hidden_layers):

            past_key_values.update_past_len(new_tokens, layer_idx)

            past_window_k = past_key_values.past_window_k[layer_idx]
            past_window_v = past_key_values.past_window_v[layer_idx]

            # Gather the accepted positions out of the scattered tree slots...
            tgt_window_k = past_window_k[..., select_indices, :]
            tgt_window_v = past_window_v[..., select_indices, :]

            # ...and write them back contiguously right after the previous
            # window content.
            dst_window_k = past_window_k[..., prev_window_len : prev_window_len + new_tokens, :]
            dst_window_v = past_window_v[..., prev_window_len : prev_window_len + new_tokens, :]

            dst_window_k.copy_(tgt_window_k, non_blocking=True)
            dst_window_v.copy_(tgt_window_v, non_blocking=True)

            new_window_len = prev_window_len + new_tokens
            if new_window_len >= self.config.window_size:
                # One decoding step can never span more than two windows.
                assert new_window_len < 2 * self.config.window_size

                # Snapshot the completed window; summarized into RFA chunks below.
                dump_k = past_window_k[..., :self.config.window_size, :].clone()
                dump_v = past_window_v[..., :self.config.window_size, :].clone()

                # Overflow length that seeds the next window.
                _window_len = new_window_len - self.config.window_size

                if _window_len > 0:
                    # Shift the overflow tokens to the front of the window buffer.
                    new_window_k = past_window_k[..., self.config.window_size : new_window_len, :]
                    new_window_v = past_window_v[..., self.config.window_size : new_window_len, :]

                    _dst_window_k = past_window_k[..., : _window_len, :]
                    _dst_window_v = past_window_v[..., : _window_len, :]

                    _dst_window_k.copy_(new_window_k, non_blocking=True)
                    _dst_window_v.copy_(new_window_v, non_blocking=True)

                past_key_values.past_window_pos[layer_idx] = _window_len
            else:
                # Window not full yet: nothing to dump.
                dump_k = None
                dump_v = None
                past_key_values.past_window_pos[layer_idx] = new_window_len

            if dump_k is not None and dump_v is not None:
                # Summarize the dumped window into chunk-level RFA key/values
                # using this layer's EVA attention parameters.
                rfa_k, rfa_v = triton_eva_prep_kv_fwd(
                    dump_k, dump_v,
                    self.model.layers[layer_idx].self_attn.adaptive_mu_k,
                    self.model.layers[layer_idx].self_attn.adaptive_phi,
                    None,
                    self.model.layers[layer_idx].self_attn.head_dim_scaling,
                    self.model.layers[layer_idx].self_attn.chunk_size
                )
                rfa_k, rfa_v = past_key_values.update_chunk_rfas(
                    rfa_k, rfa_v, layer_idx
                )
        return past_key_values

    def _multi_byte_pred_update_cache_when_prefil_len_eq_window_size(
        self,
        past_key_values,
    ):
        """
        Handle the edge case where the prefill length equals `window_size`
        exactly: force the full window to be dumped into RFA chunk summaries
        and reset the window position to 0 for every layer.
        """
        prev_window_len = past_key_values.get_past_window_pos(0)
        for layer_idx in range(self.config.num_hidden_layers):

            past_window_k = past_key_values.past_window_k[layer_idx]
            past_window_v = past_key_values.past_window_v[layer_idx]

            new_window_len = prev_window_len
            if new_window_len == self.config.window_size:
                dump_k = past_window_k[..., :self.config.window_size, :].clone()
                dump_v = past_window_v[..., :self.config.window_size, :].clone()
                past_key_values.past_window_pos[layer_idx] = 0

                # dump_k/dump_v are always set here; the None-check mirrors
                # multi_byte_pred_update_cache for symmetry.
                if dump_k is not None and dump_v is not None:
                    rfa_k, rfa_v = triton_eva_prep_kv_fwd(
                        dump_k, dump_v,
                        self.model.layers[layer_idx].self_attn.adaptive_mu_k,
                        self.model.layers[layer_idx].self_attn.adaptive_phi,
                        None,
                        self.model.layers[layer_idx].self_attn.head_dim_scaling,
                        self.model.layers[layer_idx].self_attn.chunk_size
                    )
                    rfa_k, rfa_v = past_key_values.update_chunk_rfas(
                        rfa_k, rfa_v, layer_idx
                    )
        return past_key_values

    def multi_byte_pred_update_attn_mask(
        self,
        last_iter_new_tokens,
        tree_candidate_ids,
        past_attn_mask,
        medusa_attn_mask,
        past_key_values,
    ):
        """
        Build the attention mask for one tree-decoding step: an all-ones mask
        over past context (RFA chunks + current window tokens) concatenated
        with the sparse tree mask over the candidate tokens.

        Returns (tree_attn_mask, past_attn_mask) so the past portion can be
        extended incrementally on the next iteration.
        """
        batch_size, tree_candidate_len = tree_candidate_ids.shape
        seen_tokens = past_key_values.get_seq_length()
        # NOTE: past_key_values has been updated so now
        # seen_tokens includes new tokens from the last tree iteration
        assert seen_tokens > 0
        # so one iteration would not cross two windows
        assert last_iter_new_tokens < self.config.window_size

        if past_attn_mask is not None and seen_tokens < self.config.window_size:
            # Still inside the first window: just extend the running mask by
            # the tokens accepted last iteration.
            past_attn_mask = torch.cat(
                [
                    past_attn_mask,
                    torch.ones(
                        [batch_size, 1, tree_candidate_len, last_iter_new_tokens],
                        dtype=torch.bool,
                        device=self.device
                    )
                ],
                dim=-1
            )
        else:
            # we initialize attn mask each time when
            # 1. the model crosses the window boundary, or
            # 2. after prefilling
            chunks_per_window = int(self.config.window_size // self.config.chunk_size)

            # Past context = one RFA summary per chunk of each completed
            # window, plus the raw tokens of the current (partial) window.
            window_tokens = seen_tokens % self.config.window_size
            num_windows_seen_so_far = seen_tokens // self.config.window_size
            attn_mask_len = num_windows_seen_so_far * chunks_per_window + window_tokens
            past_attn_mask = torch.ones(
                (batch_size, 1, tree_candidate_len, attn_mask_len),
                dtype=torch.bool,
                device=self.device
            )

        # note that 1 indicates the position is not masked
        tree_attn_mask = torch.cat(
            [
                past_attn_mask,
                medusa_attn_mask.to(torch.bool)
            ],
            dim=-1
        )
        return tree_attn_mask, past_attn_mask

    @torch.no_grad()
    def multi_byte_generate(
        self,
        input_ids,
        attention_mask=None,
        temperature=0.0,
        max_length=None,
        max_new_tokens=None,
        stopping_criteria=None,
        posterior_threshold=0.09,
        posterior_alpha=0.3,
        top_p=0.8,
        sampling='typical',
        fast=True,
        do_sample=False,
        medusa_choices=None,
        return_acc_lengths=False
    ):
        """
        Generate with Medusa-style multi-byte speculative decoding.

        Each iteration: (1) expand candidate tokens from the Medusa heads
        into a sparse tree, (2) verify the whole tree in one forward pass,
        (3) accept the longest agreeing prefix, (4) compact the KV cache.
        Only batch size 1 and `attention_mask=None` are supported.

        Returns `input_ids` (with accepted tokens appended), plus the list
        of per-iteration acceptance lengths if `return_acc_lengths` is True.
        """
        # Sampling paths cannot use the fast greedy verification.
        if do_sample or temperature > 0.0:
            fast = False

        ### Prepare `max_length` depending on other stopping criteria.
        if max_new_tokens is not None:
            max_length = max_new_tokens + input_ids.shape[-1]
        elif max_new_tokens is None and max_length is None:
            max_length = getattr(self.config, "max_position_embeddings", 32768)

        ### Set up stopping criteria
        eos_stop_criteria = MultibyteEosTokenCriteria(self.generation_config.eos_token_id)
        stop_criteria = StoppingCriteriaList()
        if max_length is not None:
            max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
            stop_criteria.append(
                MaxLengthCriteria(
                    max_length=max_length,
                    max_position_embeddings=max_position_embeddings,
                )
            )
        if stopping_criteria is not None and len(stopping_criteria) > 0:
            stop_criteria.extend(stopping_criteria)

        assert input_ids.shape[0] == 1, "Only support batch size 1 for now"
        assert attention_mask is None, "Only support attention mask None for now"
        # Avoid modifying the input_ids in-place
        input_ids = input_ids.clone()
        position_ids = torch.arange(0, input_ids.shape[1], device=self.device, dtype=int).reshape(1, -1)

        ####################################################
        # 0. initialize the medusa buffers
        ####################################################
        if medusa_choices is None:
            medusa_choices = evabyte_7b_95
        medusa_buffers = generate_medusa_buffers(
            medusa_choices, device=self.device
        )

        past_key_values = EvaStaticCacheForTriton(
            input_ids.shape[0],
            self.config.num_attention_heads,
            # we add 256 to allow tree ids
            self.config.window_size + 256,
            self.config.hidden_size // self.config.num_attention_heads,
            self.config.num_hidden_layers,
            self.lm_head.weight.dtype,
            self.lm_head.weight.device,
        )
        # prefill to get medusa logits and logits
        full_logits, past_key_values = self.forward(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            use_cache=True,
            past_key_values=past_key_values,
            return_all_pred_logits=True,
            multibyte_decoding=False,
        )
        # handles an edge case where the prefill length == window_size
        # we force the previous window to be dumped into RFA chunks
        past_key_values = self._multi_byte_pred_update_cache_when_prefil_len_eq_window_size(
            past_key_values
        )
        medusa_logits, logits = split_logits(full_logits)

        past_attn_mask = None
        last_iter_new_tokens = 0
        # Hard upper bound on iterations; stopping criteria normally fire first.
        max_iters = 32768
        if return_acc_lengths:
            acc_lengths = []
        for _ in range(max_iters):
            ####################################################
            # 1. generate candidate_ids with topk predictions from Medusa heads
            ####################################################
            tree_candidate_ids, unflattened_candidate_ids = generate_candidates(
                medusa_logits,
                logits,
                medusa_buffers["tree_indices"],
                medusa_buffers["retrieve_indices"],
                temperature=temperature,
                posterior_alpha=posterior_alpha,
                posterior_threshold=posterior_threshold,
                top_p=top_p,
                sampling=sampling,
                fast=fast,
            )

            ####################################################
            # 2. Build the medusa attention mask and position ids
            ####################################################
            # NOTE: 1 indicates the position is not masked
            medusa_attn_mask, past_attn_mask = self.multi_byte_pred_update_attn_mask(
                last_iter_new_tokens,
                tree_candidate_ids,
                past_attn_mask,
                medusa_buffers["medusa_attn_mask"],
                past_key_values,
            )
            medusa_position_ids = medusa_buffers["medusa_position_ids"] + input_ids.shape[1]

            ####################################################
            # 3. tree decoding
            ####################################################
            tree_full_logits, past_key_values = self.forward(
                tree_candidate_ids,
                past_key_values=past_key_values,
                attention_mask=medusa_attn_mask,
                position_ids=medusa_position_ids,
                return_all_pred_logits=True,
                multibyte_decoding=True,
            )
            _medusa_logits, _logits = split_logits(tree_full_logits)
            # Re-order tree outputs into per-path (root-to-leaf) layout.
            medusa_logits = _medusa_logits[..., 0, medusa_buffers["retrieve_indices"], :]
            logits = _logits[..., 0, medusa_buffers["retrieve_indices"], :]

            ####################################################
            # 4. candidate selection
            ####################################################
            # if the current iteration, with tree tokens, crosses window
            # boundaries, trim the candidate_ids to be within the window
            # so that those exceeded tokens (which will be inaccurate)
            # will not be considered
            tree_depth = unflattened_candidate_ids.shape[-1]
            if tree_depth + past_key_values.get_past_window_pos(0) > self.config.window_size:
                max_acc_len = self.config.window_size - past_key_values.get_past_window_pos(0)
                _trimmed_unflattened_candidate_ids = unflattened_candidate_ids[:, :max_acc_len]
                _trimmed_logits = logits[:, :max_acc_len]
            else:
                _trimmed_unflattened_candidate_ids = unflattened_candidate_ids
                _trimmed_logits = logits
            best_candidate, accept_length = evaluate_posterior(
                _trimmed_logits,
                _trimmed_unflattened_candidate_ids,
                temperature,
                posterior_threshold,
                posterior_alpha,
                top_p=top_p,
                sampling=sampling,
                fast=fast
            )

            ####################################################
            # 5. update model inputs and caches
            ####################################################
            input_ids, medusa_logits, logits, last_iter_new_tokens = update_inference_inputs(
                input_ids,
                medusa_logits,
                logits,
                unflattened_candidate_ids,
                best_candidate,
                accept_length,
            )

            past_key_values = self.multi_byte_pred_update_cache(
                past_key_values,
                medusa_buffers["retrieve_indices"],
                best_candidate,
                last_iter_new_tokens,
            )

            if return_acc_lengths:
                acc_lengths.append(last_iter_new_tokens)
            # Stop on max length (scores unused, hence None) or on EOS among
            # the tokens accepted this iteration.
            if stop_criteria(input_ids, None) or eos_stop_criteria(input_ids, last_iter_new_tokens):
                if return_acc_lengths:
                    return input_ids, acc_lengths
                else:
                    return input_ids
        if return_acc_lengths:
            return input_ids, acc_lengths
        else:
            return input_ids
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-30000/preprocessor_config.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoImageProcessor": "image_processing_evabyte.EvaByteImageProcessor",
4
+ "AutoProcessor": "processing_evabyte.EvaByteProcessor"
5
+ },
6
+ "do_convert_rgb": true,
7
+ "do_resize": true,
8
+ "image_processor_type": "EvaByteImageProcessor",
9
+ "jpeg_quality": 25,
10
+ "jpeg_restart_marker_blocks": 1,
11
+ "jpeg_streamtype": 2,
12
+ "jpeg_subsampling": "4:2:0",
13
+ "processor_class": "EvaByteProcessor",
14
+ "resample": 1,
15
+ "size": {
16
+ "longest_edge": 384
17
+ }
18
+ }
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-30000/processing_evabyte.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ """
3
+ Processor class for EvaByte.
4
+ """
5
+ import base64
6
+ from io import BytesIO
7
+
8
+ import requests
9
+ import os
10
+ import PIL
11
+ from PIL import Image
12
+
13
+ from typing import List, Optional, Union
14
+
15
+ from transformers.feature_extraction_utils import BatchFeature
16
+ from transformers.image_utils import ImageInput, is_valid_image
17
+ from transformers.processing_utils import ProcessorMixin
18
+ from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
19
+ from transformers.utils import TensorType, to_py_obj
20
+
21
def fetch_image(image: Union[str, "PIL.Image.Image"]) -> Image.Image:
    """Load an image into a `PIL.Image.Image`.

    Accepts, in order of precedence:
      - an already-open `PIL.Image.Image` (returned as-is),
      - an http(s) URL (downloaded with `requests`),
      - a local file path,
      - a `data:image/...;base64,...` data URI,
      - any other string, passed straight to `Image.open`.

    Raises:
        ValueError: if a data URI cannot be decoded, or no strategy applies.
    """
    image_obj = None
    if isinstance(image, Image.Image):
        image_obj = image
    elif image.startswith("http://") or image.startswith("https://"):
        image_obj = Image.open(BytesIO(requests.get(image, timeout=None).content))
    elif os.path.isfile(image):
        image_obj = Image.open(image)
    elif image.startswith("data:image/"):
        image = image.split(",")[1]
        # Try to load as base64
        try:
            b64 = base64.decodebytes(image.encode())
            # BUG FIX: the decoded image must be bound to `image_obj`.
            # The previous code assigned it to `image`, leaving `image_obj`
            # as None, so every *valid* data-URI input fell through to the
            # "Unrecognized image input" error below.
            image_obj = Image.open(BytesIO(b64))
        except Exception as e:
            raise ValueError(
                f"Incorrect image source. Must be a valid URL starting with `http://` or `https://`, a valid path to an image file, or a base64 encoded string. Got {image}. Failed with {e}"
            )
    else:
        image_obj = Image.open(image)
    if image_obj is None:
        raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")

    return image_obj
45
+
46
def is_url(val) -> bool:
    """Return True when *val* is a string beginning with "http" (covers http/https URLs)."""
    if not isinstance(val, str):
        return False
    return val.startswith("http")
48
+
49
def is_file(val) -> bool:
    """Return True when *val* is a string naming an existing file on disk."""
    return os.path.isfile(val) if isinstance(val, str) else False
51
+
52
def is_image_or_image_url(elem):
    """Return True for anything usable as an image input: a URL string, a valid
    image object (per transformers' `is_valid_image`), or an existing file path."""
    # Checks run in the same short-circuit order as before: URL, image, file.
    return any(check(elem) for check in (is_url, is_valid_image, is_file))
54
+
55
# Jinja chat template for vision-language conversations. It emits an optional
# system turn, then alternating user/assistant turns; structured content items
# of type "image" are rendered as a literal "<image_placeholder>" marker that
# `EvaByteProcessor.__call__` later replaces with the image's byte tokens.
# NOTE: this is a runtime string consumed by the tokenizer's template engine —
# its contents must not be altered.
vl_chat_template = """
{{- bos_token }}
{%- if messages[0]['role'] == 'system' %}
    {%- set system_message = messages[0]['content'] %}
    {%- set messages = messages[1:] %}
{%- else %}
    {%- set system_message = "" %}
{%- endif %}

{{- '<|start_header_id|>system<|end_header_id|>\n\n' + system_message + '<|eot_id|>'}}

{%- for message in messages %}
    {%- if (message['role'] != 'user') and (message['role'] != 'assistant') %}
        {{- raise_exception('Conversation roles must be user or assistant') }}
    {%- endif %}

    {%- if message['content'] is string %}
        {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] + '<|eot_id|>' }}
    {%- else %}
        {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' }}
        {%- for content in message['content'] %}
            {%- if content['type'] == 'image' %}
                {{- '<image_placeholder>\n' }}
            {%- elif content['type'] == 'text' %}
                {{- content['text'] }}
            {%- endif %}
        {%- endfor %}
        {{- '<|eot_id|>' }}
    {%- endif %}
{%- endfor %}

{%- if add_generation_prompt %}
    {{- '<|start_header_id|>' + 'assistant' + '<|end_header_id|>\n\n' }}
{%- endif %}
"""
90
+
91
class EvaByteProcessor(ProcessorMixin):
    r"""
    Constructs a EvaByte processor which wraps a EvaByte image processor and a EvaByte tokenizer into a single processor.

    [`EvaByteProcessor`] offers all the functionalities of [`EvaByteImageProcessor`] and [`EvaByteTokenizer`]. See the
    [`~EvaByteProcessor.__call__`] and [`~EvaByteProcessor.decode`] for more information.

    Args:
        image_processor ([`EvaByteImageProcessor`], *optional*):
            The image processor is a required input.
        tokenizer ([`EvaByteTokenizer`], *optional*):
            The tokenizer is a required input.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(self, image_processor=None, tokenizer=None, **kwargs):
        if image_processor is None:
            raise ValueError("You need to specify an `image_processor`.")
        if tokenizer is None:
            raise ValueError("You need to specify a `tokenizer`.")

        super().__init__(image_processor, tokenizer)
        # Sentinel token ids that bracket an image byte stream inside the token
        # sequence: <t2v_token> opens an image span, <v2t_token> closes it.
        self.t2v_token_id = self.tokenizer.convert_tokens_to_ids("<t2v_token>")
        self.v2t_token_id = self.tokenizer.convert_tokens_to_ids("<v2t_token>")
        self.image_placeholder = "<image_placeholder>"
        self.vl_chat_template = vl_chat_template

    def __call__(
        self,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        images: ImageInput = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        strip_ending_sentinel: bool = False,
        encode_only: bool = False,
        **kwargs
    ) -> Union[BatchFeature, List[List[int]]]:
        """Tokenize text and splice image byte streams into the token sequence.

        Each "<image_placeholder>" occurrence in *text* is replaced by the
        corresponding image's bytes (offset into the byte-token range),
        bracketed by the <t2v_token>/<v2t_token> sentinels.

        Args:
            text: a single string (or a 1-element list — only batch size 1 is
                supported for now), containing one placeholder per image.
            images: raw image bytes, PIL images, URLs or file paths; a single
                item, a list, or a list of lists (one inner list per sample).
                May be None for text-only input.
            return_tensors: passed to `BatchFeature` for tensor conversion.
            strip_ending_sentinel: if True, drop a trailing image sentinel so
                generation can continue inside an image span.
            encode_only: if True, return raw `List[List[int]]` input ids
                instead of a `BatchFeature`.

        Returns:
            A `BatchFeature` with "input_ids"/"attention_mask", or the raw
            id lists when `encode_only=True`.
        """
        # processing pipeline:
        # 1. read images or videos from paths
        # 2. use image_processor to convert images / videos to byte streams
        if images is not None:
            if isinstance(images, bytes):
                # a single pre-encoded image -> one sample with one image
                image_bytes_list = [[images]]
            elif isinstance(images, list) and isinstance(images[0], bytes):
                # a flat list of pre-encoded images -> one sample
                image_bytes_list = [images]
            elif isinstance(images, list) and isinstance(images[0], list) and isinstance(images[0][0], bytes):
                # already nested per-sample
                image_bytes_list = images
            else:
                if is_image_or_image_url(images):
                    images = [[images]]
                elif isinstance(images, list) and is_image_or_image_url(images[0]):
                    images = [images]
                elif (
                    not isinstance(images, list)
                    and not isinstance(images[0], list)
                    and not is_image_or_image_url(images[0][0])
                ):
                    # NOTE(review): this guard can only trigger for unusual
                    # inputs (non-list that still supports indexing) — confirm
                    # against upstream intent before tightening it.
                    raise ValueError(
                        "Invalid input images. Please provide a single image or a list of images or a list of list of images."
                    )
                # Load images if they are URLs
                images = [[fetch_image(im) if is_url(im) or is_file(im) else im for im in sample] for sample in images]
                image_bytes_list = self.image_processor(images=images, **kwargs)
        else:
            # Text-only call: resolved to empty per-sample image lists below.
            image_bytes_list = None

        if not isinstance(text, list):
            text = [text]
        assert len(text) == 1, "Only support batch size 1 for now"
        if image_bytes_list is None:
            # BUG FIX: with images=None, image_bytes_list was never bound and
            # the assert below raised NameError; text-only input now works as
            # long as the text contains no image placeholders.
            image_bytes_list = [[] for _ in text]
        assert len(text) == len(image_bytes_list), "text and image_bytes_list must have the same length"
        # TODO: invoke SequenceFeatureExtractor to get batched inputs

        # 3. tokenize the text and put images / videos byte streams into the placeholders
        # surrounded by special tokens like "<image>" and "</image>"
        batch_input_ids = []
        if not encode_only:
            batch_attention_mask = []
        else:
            batch_attention_mask = None

        for t, image_bytes in zip(text, image_bytes_list):
            text_splits = t.split(self.image_placeholder)
            # n placeholders split the text into n + 1 parts
            if len(text_splits) != len(image_bytes) + 1:
                raise ValueError(
                    f"The number of image tokens should be equal to the number of images, "
                    f"but got {len(text_splits)} and {len(image_bytes) + 1}"
                )

            input_ids = [self.tokenizer.bos_token_id]
            for i, text_part in enumerate(text_splits):
                # each text part must be non-empty because we added markers around placeholders
                split_tokens = self.tokenizer.encode(text_part, add_special_tokens=False)
                input_ids.extend(split_tokens)
                # Add image bytes after each text part except the last one
                if i < len(image_bytes):
                    input_ids.append(self.t2v_token_id)
                    # shift raw byte values past the special-token id range
                    input_ids.extend([b + self.tokenizer.offset for b in image_bytes[i]])
                    input_ids.append(self.v2t_token_id)

            if strip_ending_sentinel and (input_ids[-1] in [self.t2v_token_id, self.v2t_token_id]):
                input_ids = input_ids[:-1]

            batch_input_ids.append(input_ids)
            if not encode_only:
                batch_attention_mask.append([1] * len(input_ids))

        if not encode_only:
            # 4. return batch of features
            inputs = BatchFeature({
                "input_ids": batch_input_ids,
                "attention_mask": batch_attention_mask
            }, tensor_type=return_tensors)
            return inputs
        else:
            return batch_input_ids

    def image_tokens_to_bytes(self, image_token_ids, jpeg_quality=None):
        """Map byte-token ids back to raw bytes and re-attach JPEG quantization tables."""
        image_bytes = bytes([token_id - self.tokenizer.offset for token_id in image_token_ids])
        image_bytes = self.image_processor.jpeg_merge_qtables(image_bytes, jpeg_quality)
        return image_bytes

    def batch_decode(self, sequences, **kwargs):
        """
        This method forwards all its arguments to EvaByteTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        rets = [self.decode(seq, **kwargs) for seq in sequences]
        # transpose [(text, images), ...] into ([texts...], [images...])
        return tuple(map(list, zip(*rets)))

    def decode(self, token_ids, **kwargs):
        """
        Decodes a sequence of input_ids, handling image tokens separately.
        Returns a tuple of (decoded_text, images), where images is a list of bytes.
        """
        if kwargs and "jpeg_quality" in kwargs:
            # pop on a copy so the caller's kwargs dict is not mutated
            kwargs = kwargs.copy()
            jpeg_quality = kwargs.pop("jpeg_quality")
        else:
            jpeg_quality = None

        token_ids = to_py_obj(token_ids)
        # Find indices of t2v_token_id and v2t_token_id
        t2v_indices = [i for i, token_id in enumerate(token_ids) if token_id == self.t2v_token_id]
        v2t_indices = [i for i, token_id in enumerate(token_ids) if token_id == self.v2t_token_id]

        # Check for correct pairing of t2v and v2t tokens
        if len(t2v_indices) != len(v2t_indices):
            raise ValueError("Mismatched number of t2v and v2t tokens in token_ids: {} and {}".format(t2v_indices, v2t_indices))

        # Ensure t2v and v2t tokens are in the correct order
        for t2v_idx, v2t_idx in zip(t2v_indices, v2t_indices):
            if t2v_idx >= v2t_idx:
                raise ValueError("Found t2v_token_id after v2t_token_id in token_ids")

        # Initialize the start index
        images = []
        decoded_text = ""

        start = 0
        # Iterate over pairs of t2v and v2t indices
        for t2v_idx, v2t_idx in zip(t2v_indices, v2t_indices):
            # Decode text tokens before the image
            text_token_ids = token_ids[start:t2v_idx]
            if len(text_token_ids) > 0:
                decoded_text += self.tokenizer.decode(text_token_ids, **kwargs)

            # Insert image placeholder
            decoded_text += self.image_placeholder

            # Extract image tokens and convert them to bytes
            image_token_ids = token_ids[t2v_idx + 1 : v2t_idx]
            image_bytes = self.image_tokens_to_bytes(image_token_ids, jpeg_quality)
            images.append(image_bytes)

            # Update the start index to the token after v2t_token_id
            start = v2t_idx + 1

        # Decode any remaining text tokens after the last image
        if start < len(token_ids):
            text_token_ids = token_ids[start:]
            decoded_text += self.tokenizer.decode(text_token_ids, **kwargs)

        return decoded_text, images

    @property
    def model_input_names(self):
        """Deduplicated union of tokenizer and image-processor input names, order preserved."""
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-30000/processor_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoProcessor": "processing_evabyte.EvaByteProcessor"
4
+ },
5
+ "processor_class": "EvaByteProcessor"
6
+ }
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-30000/special_tokens_map.json ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "<repo_name>",
4
+ "<file_sep>",
5
+ "<t2v_token>",
6
+ "<v2t_token>",
7
+ "<|start_header_id|>",
8
+ "<|end_header_id|>",
9
+ "<|eot_id|>",
10
+ "<extra_id_12>",
11
+ "<extra_id_13>",
12
+ "<extra_id_14>",
13
+ "<extra_id_15>",
14
+ "<extra_id_16>",
15
+ "<extra_id_17>",
16
+ "<extra_id_18>",
17
+ "<extra_id_19>",
18
+ "<extra_id_20>",
19
+ "<extra_id_21>",
20
+ "<extra_id_22>",
21
+ "<extra_id_23>",
22
+ "<extra_id_24>",
23
+ "<extra_id_25>",
24
+ "<extra_id_26>",
25
+ "<extra_id_27>",
26
+ "<extra_id_28>",
27
+ "<extra_id_29>",
28
+ "<extra_id_30>",
29
+ "<extra_id_31>",
30
+ "<extra_id_32>",
31
+ "<extra_id_33>",
32
+ "<extra_id_34>",
33
+ "<extra_id_35>",
34
+ "<extra_id_36>",
35
+ "<extra_id_37>",
36
+ "<extra_id_38>",
37
+ "<extra_id_39>",
38
+ "<extra_id_40>",
39
+ "<extra_id_41>",
40
+ "<extra_id_42>",
41
+ "<extra_id_43>",
42
+ "<extra_id_44>",
43
+ "<extra_id_45>",
44
+ "<extra_id_46>",
45
+ "<extra_id_47>",
46
+ "<extra_id_48>",
47
+ "<extra_id_49>",
48
+ "<extra_id_50>",
49
+ "<extra_id_51>",
50
+ "<extra_id_52>",
51
+ "<extra_id_53>",
52
+ "<extra_id_54>",
53
+ "<extra_id_55>",
54
+ "<extra_id_56>",
55
+ "<extra_id_57>",
56
+ "<extra_id_58>",
57
+ "<extra_id_59>",
58
+ "<extra_id_60>",
59
+ "<extra_id_61>",
60
+ "<extra_id_62>",
61
+ "<extra_id_63>"
62
+ ],
63
+ "bos_token": {
64
+ "content": "<bos>",
65
+ "lstrip": false,
66
+ "normalized": true,
67
+ "rstrip": false,
68
+ "single_word": false
69
+ },
70
+ "eos_token": {
71
+ "content": "<eos>",
72
+ "lstrip": false,
73
+ "normalized": true,
74
+ "rstrip": false,
75
+ "single_word": false
76
+ },
77
+ "pad_token": {
78
+ "content": "<pad>",
79
+ "lstrip": false,
80
+ "normalized": true,
81
+ "rstrip": false,
82
+ "single_word": false
83
+ },
84
+ "sep_token": {
85
+ "content": "<sep>",
86
+ "lstrip": false,
87
+ "normalized": true,
88
+ "rstrip": false,
89
+ "single_word": false
90
+ },
91
+ "unk_token": {
92
+ "content": "<unk>",
93
+ "lstrip": false,
94
+ "normalized": true,
95
+ "rstrip": false,
96
+ "single_word": false
97
+ }
98
+ }
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-30000/tokenization_evabyte.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+
3
+ """ Tokenization class for model EvaByte."""
4
+
5
+
6
+ from typing import List, Optional, Tuple
7
+
8
+ from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
9
+ from transformers.utils import logging
10
+
11
+
12
+ logger = logging.get_logger(__name__)
13
+
14
+
15
# Default text-only Jinja chat template installed on every EvaByteTokenizer
# instance (see `__init__`): optional system turn followed by alternating
# user/assistant turns, each wrapped in the <|start_header_id|>/<|eot_id|>
# markers. NOTE: this is a runtime string consumed by the template engine —
# its contents must not be altered.
chat_template = """
{{- bos_token }}
{%- if messages[0]['role'] == 'system' %}
    {%- set system_message = messages[0]['content'] %}
    {%- set messages = messages[1:] %}
{%- else %}
    {%- set system_message = "" %}
{%- endif %}

{{- '<|start_header_id|>system<|end_header_id|>\n\n' + system_message + '<|eot_id|>'}}

{%- for message in messages %}
    {%- if (message['role'] != 'user') and (message['role'] != 'assistant') %}
        {{- raise_exception('Conversation roles must be user or assistant') }}
    {%- endif %}

    {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] + '<|eot_id|>' }}
{%- endfor %}

{%- if add_generation_prompt %}
    {{- '<|start_header_id|>' + 'assistant' + '<|end_header_id|>\n\n' }}
{%- endif %}
"""
38
+
39
class EvaByteTokenizer(PreTrainedTokenizer):
    """Byte-level tokenizer: each UTF-8 byte maps to `byte_value + offset`,
    where `offset` is the number of special tokens occupying ids 0..offset-1.

    Ids 0-4 are <pad>/<bos>/<eos>/<unk>/<sep>; ids 5.. are the (possibly
    renamed) extra_id tokens; ids offset..offset+255 are the 256 byte values.
    """

    def __init__(
        self,
        bos_token="<bos>",
        eos_token="<eos>",
        unk_token="<unk>",
        sep_token="<sep>",
        pad_token="<pad>",
        extra_ids=59,
        additional_special_tokens=None,
        clean_up_tokenization_spaces=False,
        **kwargs,
    ) -> None:
        num_base_special_tokens = 5
        # Add extra_ids to the special token list
        if extra_ids > 0 and additional_special_tokens is None:
            additional_special_tokens = [f"<extra_id_{i}>" for i in range(num_base_special_tokens, extra_ids + num_base_special_tokens)]
        elif extra_ids > 0 and additional_special_tokens is not None and len(additional_special_tokens) > 0:
            # Check that we have the right number of extra_id special tokens
            extra_tokens = len(set(filter(lambda x: bool("extra_id" in str(x)), additional_special_tokens)))
            if extra_tokens != extra_ids:
                raise ValueError(
                    f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are"
                    " provided to EvaByteTokenizer. In this case the additional_special_tokens must include the"
                    " extra_ids tokens"
                )
        if additional_special_tokens is None:
            # BUG FIX: with extra_ids == 0 and no explicit special tokens, the
            # override loop below iterated over None and raised TypeError.
            additional_special_tokens = []

        #### override some reserved tokens to support chat template
        for i, token in enumerate(additional_special_tokens):
            if token == "<extra_id_5>":
                token = "<repo_name>"
            elif token == "<extra_id_6>":
                token = "<file_sep>"
            elif token == "<extra_id_7>":
                token = "<t2v_token>"
            elif token == "<extra_id_8>":
                token = "<v2t_token>"
            elif token == "<extra_id_9>":
                token = "<|start_header_id|>"
            elif token == "<extra_id_10>":
                token = "<|end_header_id|>"
            elif token == "<extra_id_11>":
                token = "<|eot_id|>"
            additional_special_tokens[i] = token

        # lstrip and rstrip are set to False because we don't want to strip the whitespace from the special tokens
        # this would be important for the byte tokenizer
        pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
        bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
        eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
        unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
        sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token

        self._added_tokens_decoder = {
            0: pad_token,
            1: bos_token,
            2: eos_token,
            3: unk_token,  # unk_token is a placeholder
            4: sep_token,
            **{i: AddedToken(t, lstrip=False, rstrip=False) for i, t in enumerate(additional_special_tokens, start=num_base_special_tokens)},
        }
        # byte b is encoded as id b + offset
        self.offset = len(self._added_tokens_decoder)
        self._utf_vocab_size = 2**8  # utf is 8 bits
        self.add_bos_token = True
        self.add_eos_token = False
        super().__init__(
            pad_token=pad_token,
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            sep_token=sep_token,
            extra_ids=0,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )
        self.chat_template = chat_template

    @property
    def vocab_size(self):
        """Size of the byte vocabulary only (special tokens excluded)."""
        return self._utf_vocab_size

    def get_vocab(self):
        """Full token->id map: special tokens plus the 256 byte tokens."""
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size + self.offset)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.build_inputs_with_special_tokens
    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
        eos_token_id = [self.eos_token_id] if self.add_eos_token else []

        output = bos_token_id + token_ids_0 + eos_token_id

        if token_ids_1 is not None:
            output = output + bos_token_id + token_ids_1 + eos_token_id

        return output

    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.get_special_tokens_mask
    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        bos_token_id = [1] if self.add_bos_token else []
        eos_token_id = [1] if self.add_eos_token else []

        if token_ids_1 is None:
            return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
        return (
            bos_token_id
            + ([0] * len(token_ids_0))
            + eos_token_id
            + bos_token_id
            + ([0] * len(token_ids_1))
            + eos_token_id
        )

    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.create_token_type_ids_from_sequences
    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
        sequence pair mask has the following format:

        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence | second sequence |
        ```

        if token_ids_1 is None, only returns the first portion of the mask (0s).

        Args:
            token_ids_0 (`List[int]`):
                List of ids.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
        """
        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
        eos_token_id = [self.eos_token_id] if self.add_eos_token else []

        output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)

        if token_ids_1 is not None:
            output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)

        return output

    def _tokenize(self, text: str) -> List[str]:
        """Take as input a string and return a list of strings (tokens) for words/sub-words"""
        # one single-character "token" per UTF-8 byte
        tokens = [chr(i) for i in text.encode("utf-8")]
        return tokens

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        # only single-character byte tokens are mappable here; multi-character
        # (special) tokens are resolved by the base class's added-token tables
        if len(token) != 1:
            token_id = None
        else:
            token_id = ord(token) + self.offset

        return token_id

    def _convert_id_to_token(self, index):
        """Converts an index (integer) to a byte (str) using the vocab."""
        token = chr(index - self.offset)
        return token

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of bytes (string) to a single string."""
        bstring = b""
        for token in tokens:
            if token in self.added_tokens_decoder:
                tok_string = self.added_tokens_decoder[token].encode("utf-8")
            elif token in self.added_tokens_encoder:
                tok_string = token.encode("utf-8")
            else:
                tok_string = bytes([ord(token)])
            bstring += tok_string
        # invalid UTF-8 byte runs are silently dropped
        string = bstring.decode("utf-8", errors="ignore")
        return string

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        # nothing to persist: the byte vocabulary is implicit
        return ()
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-30000/tokenizer_config.json ADDED
@@ -0,0 +1,596 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "<pad>",
5
+ "lstrip": false,
6
+ "normalized": true,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "<bos>",
13
+ "lstrip": false,
14
+ "normalized": true,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "2": {
20
+ "content": "<eos>",
21
+ "lstrip": false,
22
+ "normalized": true,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "3": {
28
+ "content": "<unk>",
29
+ "lstrip": false,
30
+ "normalized": true,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "4": {
36
+ "content": "<sep>",
37
+ "lstrip": false,
38
+ "normalized": true,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ },
43
+ "5": {
44
+ "content": "<repo_name>",
45
+ "lstrip": false,
46
+ "normalized": true,
47
+ "rstrip": false,
48
+ "single_word": false,
49
+ "special": false
50
+ },
51
+ "6": {
52
+ "content": "<file_sep>",
53
+ "lstrip": false,
54
+ "normalized": true,
55
+ "rstrip": false,
56
+ "single_word": false,
57
+ "special": false
58
+ },
59
+ "7": {
60
+ "content": "<t2v_token>",
61
+ "lstrip": false,
62
+ "normalized": true,
63
+ "rstrip": false,
64
+ "single_word": false,
65
+ "special": false
66
+ },
67
+ "8": {
68
+ "content": "<v2t_token>",
69
+ "lstrip": false,
70
+ "normalized": true,
71
+ "rstrip": false,
72
+ "single_word": false,
73
+ "special": false
74
+ },
75
+ "9": {
76
+ "content": "<|start_header_id|>",
77
+ "lstrip": false,
78
+ "normalized": true,
79
+ "rstrip": false,
80
+ "single_word": false,
81
+ "special": false
82
+ },
83
+ "10": {
84
+ "content": "<|end_header_id|>",
85
+ "lstrip": false,
86
+ "normalized": true,
87
+ "rstrip": false,
88
+ "single_word": false,
89
+ "special": false
90
+ },
91
+ "11": {
92
+ "content": "<|eot_id|>",
93
+ "lstrip": false,
94
+ "normalized": true,
95
+ "rstrip": false,
96
+ "single_word": false,
97
+ "special": false
98
+ },
99
+ "12": {
100
+ "content": "<extra_id_12>",
101
+ "lstrip": false,
102
+ "normalized": true,
103
+ "rstrip": false,
104
+ "single_word": false,
105
+ "special": false
106
+ },
107
+ "13": {
108
+ "content": "<extra_id_13>",
109
+ "lstrip": false,
110
+ "normalized": true,
111
+ "rstrip": false,
112
+ "single_word": false,
113
+ "special": false
114
+ },
115
+ "14": {
116
+ "content": "<extra_id_14>",
117
+ "lstrip": false,
118
+ "normalized": true,
119
+ "rstrip": false,
120
+ "single_word": false,
121
+ "special": false
122
+ },
123
+ "15": {
124
+ "content": "<extra_id_15>",
125
+ "lstrip": false,
126
+ "normalized": true,
127
+ "rstrip": false,
128
+ "single_word": false,
129
+ "special": false
130
+ },
131
+ "16": {
132
+ "content": "<extra_id_16>",
133
+ "lstrip": false,
134
+ "normalized": true,
135
+ "rstrip": false,
136
+ "single_word": false,
137
+ "special": false
138
+ },
139
+ "17": {
140
+ "content": "<extra_id_17>",
141
+ "lstrip": false,
142
+ "normalized": true,
143
+ "rstrip": false,
144
+ "single_word": false,
145
+ "special": false
146
+ },
147
+ "18": {
148
+ "content": "<extra_id_18>",
149
+ "lstrip": false,
150
+ "normalized": true,
151
+ "rstrip": false,
152
+ "single_word": false,
153
+ "special": false
154
+ },
155
+ "19": {
156
+ "content": "<extra_id_19>",
157
+ "lstrip": false,
158
+ "normalized": true,
159
+ "rstrip": false,
160
+ "single_word": false,
161
+ "special": false
162
+ },
163
+ "20": {
164
+ "content": "<extra_id_20>",
165
+ "lstrip": false,
166
+ "normalized": true,
167
+ "rstrip": false,
168
+ "single_word": false,
169
+ "special": false
170
+ },
171
+ "21": {
172
+ "content": "<extra_id_21>",
173
+ "lstrip": false,
174
+ "normalized": true,
175
+ "rstrip": false,
176
+ "single_word": false,
177
+ "special": false
178
+ },
179
+ "22": {
180
+ "content": "<extra_id_22>",
181
+ "lstrip": false,
182
+ "normalized": true,
183
+ "rstrip": false,
184
+ "single_word": false,
185
+ "special": false
186
+ },
187
+ "23": {
188
+ "content": "<extra_id_23>",
189
+ "lstrip": false,
190
+ "normalized": true,
191
+ "rstrip": false,
192
+ "single_word": false,
193
+ "special": false
194
+ },
195
+ "24": {
196
+ "content": "<extra_id_24>",
197
+ "lstrip": false,
198
+ "normalized": true,
199
+ "rstrip": false,
200
+ "single_word": false,
201
+ "special": false
202
+ },
203
+ "25": {
204
+ "content": "<extra_id_25>",
205
+ "lstrip": false,
206
+ "normalized": true,
207
+ "rstrip": false,
208
+ "single_word": false,
209
+ "special": false
210
+ },
211
+ "26": {
212
+ "content": "<extra_id_26>",
213
+ "lstrip": false,
214
+ "normalized": true,
215
+ "rstrip": false,
216
+ "single_word": false,
217
+ "special": false
218
+ },
219
+ "27": {
220
+ "content": "<extra_id_27>",
221
+ "lstrip": false,
222
+ "normalized": true,
223
+ "rstrip": false,
224
+ "single_word": false,
225
+ "special": false
226
+ },
227
+ "28": {
228
+ "content": "<extra_id_28>",
229
+ "lstrip": false,
230
+ "normalized": true,
231
+ "rstrip": false,
232
+ "single_word": false,
233
+ "special": false
234
+ },
235
+ "29": {
236
+ "content": "<extra_id_29>",
237
+ "lstrip": false,
238
+ "normalized": true,
239
+ "rstrip": false,
240
+ "single_word": false,
241
+ "special": false
242
+ },
243
+ "30": {
244
+ "content": "<extra_id_30>",
245
+ "lstrip": false,
246
+ "normalized": true,
247
+ "rstrip": false,
248
+ "single_word": false,
249
+ "special": false
250
+ },
251
+ "31": {
252
+ "content": "<extra_id_31>",
253
+ "lstrip": false,
254
+ "normalized": true,
255
+ "rstrip": false,
256
+ "single_word": false,
257
+ "special": false
258
+ },
259
+ "32": {
260
+ "content": "<extra_id_32>",
261
+ "lstrip": false,
262
+ "normalized": true,
263
+ "rstrip": false,
264
+ "single_word": false,
265
+ "special": false
266
+ },
267
+ "33": {
268
+ "content": "<extra_id_33>",
269
+ "lstrip": false,
270
+ "normalized": true,
271
+ "rstrip": false,
272
+ "single_word": false,
273
+ "special": false
274
+ },
275
+ "34": {
276
+ "content": "<extra_id_34>",
277
+ "lstrip": false,
278
+ "normalized": true,
279
+ "rstrip": false,
280
+ "single_word": false,
281
+ "special": false
282
+ },
283
+ "35": {
284
+ "content": "<extra_id_35>",
285
+ "lstrip": false,
286
+ "normalized": true,
287
+ "rstrip": false,
288
+ "single_word": false,
289
+ "special": false
290
+ },
291
+ "36": {
292
+ "content": "<extra_id_36>",
293
+ "lstrip": false,
294
+ "normalized": true,
295
+ "rstrip": false,
296
+ "single_word": false,
297
+ "special": false
298
+ },
299
+ "37": {
300
+ "content": "<extra_id_37>",
301
+ "lstrip": false,
302
+ "normalized": true,
303
+ "rstrip": false,
304
+ "single_word": false,
305
+ "special": false
306
+ },
307
+ "38": {
308
+ "content": "<extra_id_38>",
309
+ "lstrip": false,
310
+ "normalized": true,
311
+ "rstrip": false,
312
+ "single_word": false,
313
+ "special": false
314
+ },
315
+ "39": {
316
+ "content": "<extra_id_39>",
317
+ "lstrip": false,
318
+ "normalized": true,
319
+ "rstrip": false,
320
+ "single_word": false,
321
+ "special": false
322
+ },
323
+ "40": {
324
+ "content": "<extra_id_40>",
325
+ "lstrip": false,
326
+ "normalized": true,
327
+ "rstrip": false,
328
+ "single_word": false,
329
+ "special": false
330
+ },
331
+ "41": {
332
+ "content": "<extra_id_41>",
333
+ "lstrip": false,
334
+ "normalized": true,
335
+ "rstrip": false,
336
+ "single_word": false,
337
+ "special": false
338
+ },
339
+ "42": {
340
+ "content": "<extra_id_42>",
341
+ "lstrip": false,
342
+ "normalized": true,
343
+ "rstrip": false,
344
+ "single_word": false,
345
+ "special": false
346
+ },
347
+ "43": {
348
+ "content": "<extra_id_43>",
349
+ "lstrip": false,
350
+ "normalized": true,
351
+ "rstrip": false,
352
+ "single_word": false,
353
+ "special": false
354
+ },
355
+ "44": {
356
+ "content": "<extra_id_44>",
357
+ "lstrip": false,
358
+ "normalized": true,
359
+ "rstrip": false,
360
+ "single_word": false,
361
+ "special": false
362
+ },
363
+ "45": {
364
+ "content": "<extra_id_45>",
365
+ "lstrip": false,
366
+ "normalized": true,
367
+ "rstrip": false,
368
+ "single_word": false,
369
+ "special": false
370
+ },
371
+ "46": {
372
+ "content": "<extra_id_46>",
373
+ "lstrip": false,
374
+ "normalized": true,
375
+ "rstrip": false,
376
+ "single_word": false,
377
+ "special": false
378
+ },
379
+ "47": {
380
+ "content": "<extra_id_47>",
381
+ "lstrip": false,
382
+ "normalized": true,
383
+ "rstrip": false,
384
+ "single_word": false,
385
+ "special": false
386
+ },
387
+ "48": {
388
+ "content": "<extra_id_48>",
389
+ "lstrip": false,
390
+ "normalized": true,
391
+ "rstrip": false,
392
+ "single_word": false,
393
+ "special": false
394
+ },
395
+ "49": {
396
+ "content": "<extra_id_49>",
397
+ "lstrip": false,
398
+ "normalized": true,
399
+ "rstrip": false,
400
+ "single_word": false,
401
+ "special": false
402
+ },
403
+ "50": {
404
+ "content": "<extra_id_50>",
405
+ "lstrip": false,
406
+ "normalized": true,
407
+ "rstrip": false,
408
+ "single_word": false,
409
+ "special": false
410
+ },
411
+ "51": {
412
+ "content": "<extra_id_51>",
413
+ "lstrip": false,
414
+ "normalized": true,
415
+ "rstrip": false,
416
+ "single_word": false,
417
+ "special": false
418
+ },
419
+ "52": {
420
+ "content": "<extra_id_52>",
421
+ "lstrip": false,
422
+ "normalized": true,
423
+ "rstrip": false,
424
+ "single_word": false,
425
+ "special": false
426
+ },
427
+ "53": {
428
+ "content": "<extra_id_53>",
429
+ "lstrip": false,
430
+ "normalized": true,
431
+ "rstrip": false,
432
+ "single_word": false,
433
+ "special": false
434
+ },
435
+ "54": {
436
+ "content": "<extra_id_54>",
437
+ "lstrip": false,
438
+ "normalized": true,
439
+ "rstrip": false,
440
+ "single_word": false,
441
+ "special": false
442
+ },
443
+ "55": {
444
+ "content": "<extra_id_55>",
445
+ "lstrip": false,
446
+ "normalized": true,
447
+ "rstrip": false,
448
+ "single_word": false,
449
+ "special": false
450
+ },
451
+ "56": {
452
+ "content": "<extra_id_56>",
453
+ "lstrip": false,
454
+ "normalized": true,
455
+ "rstrip": false,
456
+ "single_word": false,
457
+ "special": false
458
+ },
459
+ "57": {
460
+ "content": "<extra_id_57>",
461
+ "lstrip": false,
462
+ "normalized": true,
463
+ "rstrip": false,
464
+ "single_word": false,
465
+ "special": false
466
+ },
467
+ "58": {
468
+ "content": "<extra_id_58>",
469
+ "lstrip": false,
470
+ "normalized": true,
471
+ "rstrip": false,
472
+ "single_word": false,
473
+ "special": false
474
+ },
475
+ "59": {
476
+ "content": "<extra_id_59>",
477
+ "lstrip": false,
478
+ "normalized": true,
479
+ "rstrip": false,
480
+ "single_word": false,
481
+ "special": false
482
+ },
483
+ "60": {
484
+ "content": "<extra_id_60>",
485
+ "lstrip": false,
486
+ "normalized": true,
487
+ "rstrip": false,
488
+ "single_word": false,
489
+ "special": false
490
+ },
491
+ "61": {
492
+ "content": "<extra_id_61>",
493
+ "lstrip": false,
494
+ "normalized": true,
495
+ "rstrip": false,
496
+ "single_word": false,
497
+ "special": false
498
+ },
499
+ "62": {
500
+ "content": "<extra_id_62>",
501
+ "lstrip": false,
502
+ "normalized": true,
503
+ "rstrip": false,
504
+ "single_word": false,
505
+ "special": false
506
+ },
507
+ "63": {
508
+ "content": "<extra_id_63>",
509
+ "lstrip": false,
510
+ "normalized": true,
511
+ "rstrip": false,
512
+ "single_word": false,
513
+ "special": false
514
+ }
515
+ },
516
+ "additional_special_tokens": [
517
+ "<repo_name>",
518
+ "<file_sep>",
519
+ "<t2v_token>",
520
+ "<v2t_token>",
521
+ "<|start_header_id|>",
522
+ "<|end_header_id|>",
523
+ "<|eot_id|>",
524
+ "<extra_id_12>",
525
+ "<extra_id_13>",
526
+ "<extra_id_14>",
527
+ "<extra_id_15>",
528
+ "<extra_id_16>",
529
+ "<extra_id_17>",
530
+ "<extra_id_18>",
531
+ "<extra_id_19>",
532
+ "<extra_id_20>",
533
+ "<extra_id_21>",
534
+ "<extra_id_22>",
535
+ "<extra_id_23>",
536
+ "<extra_id_24>",
537
+ "<extra_id_25>",
538
+ "<extra_id_26>",
539
+ "<extra_id_27>",
540
+ "<extra_id_28>",
541
+ "<extra_id_29>",
542
+ "<extra_id_30>",
543
+ "<extra_id_31>",
544
+ "<extra_id_32>",
545
+ "<extra_id_33>",
546
+ "<extra_id_34>",
547
+ "<extra_id_35>",
548
+ "<extra_id_36>",
549
+ "<extra_id_37>",
550
+ "<extra_id_38>",
551
+ "<extra_id_39>",
552
+ "<extra_id_40>",
553
+ "<extra_id_41>",
554
+ "<extra_id_42>",
555
+ "<extra_id_43>",
556
+ "<extra_id_44>",
557
+ "<extra_id_45>",
558
+ "<extra_id_46>",
559
+ "<extra_id_47>",
560
+ "<extra_id_48>",
561
+ "<extra_id_49>",
562
+ "<extra_id_50>",
563
+ "<extra_id_51>",
564
+ "<extra_id_52>",
565
+ "<extra_id_53>",
566
+ "<extra_id_54>",
567
+ "<extra_id_55>",
568
+ "<extra_id_56>",
569
+ "<extra_id_57>",
570
+ "<extra_id_58>",
571
+ "<extra_id_59>",
572
+ "<extra_id_60>",
573
+ "<extra_id_61>",
574
+ "<extra_id_62>",
575
+ "<extra_id_63>"
576
+ ],
577
+ "auto_map": {
578
+ "AutoProcessor": "processing_evabyte.EvaByteProcessor",
579
+ "AutoTokenizer": [
580
+ "tokenization_evabyte.EvaByteTokenizer",
581
+ null
582
+ ]
583
+ },
584
+ "bos_token": "<bos>",
585
+ "chat_template": "\n{{- bos_token }}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content'] %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{{- '<|start_header_id|>system<|end_header_id|>\n\n' + system_message + '<|eot_id|>'}}\n\n{%- for message in messages %}\n {%- if (message['role'] != 'user') and (message['role'] != 'assistant') %}\n {{- raise_exception('Conversation roles must be user or assistant') }}\n {%- endif %}\n\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] + '<|eot_id|>' }}\n{%- endfor %}\n\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>' + 'assistant' + '<|end_header_id|>\n\n' }}\n{%- endif %}\n",
586
+ "clean_up_tokenization_spaces": false,
587
+ "eos_token": "<eos>",
588
+ "extra_ids": 0,
589
+ "extra_special_tokens": {},
590
+ "model_max_length": 1000000000000000019884624838656,
591
+ "pad_token": "<pad>",
592
+ "processor_class": "EvaByteProcessor",
593
+ "sep_token": "<sep>",
594
+ "tokenizer_class": "EvaByteTokenizer",
595
+ "unk_token": "<unk>"
596
+ }
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/README.md ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ ---
4
+ # EvaByte Model Card
5
+
6
+ **EvaByte** is a 6.5B **byte-level language model** built upon an improved architecture with multibyte prediction and EVA -- an efficient attention mechanism designed for scalability and performance. Trained on 1.5T bytes spanning natural language text, math, and code, EvaByte demonstrates the viability of efficient byte-level processing at scale -- rivaling top open-source tokenizer-based LMs using 5x less training data, excelling in coding tasks, and decoding up to 2x faster.
7
+
8
+ ## Model Resources
9
+
10
+ - **Repository:** https://github.com/openevabyte/evabyte
11
+ - **Blog:** https://hkunlp.github.io/blog/2025/evabyte and https://sambanova.ai/blog/evabyte-efficient-byte-level-language-models-at-scale
12
+ - **Paper:** Coming soon
13
+
14
+ ## Model Details
15
+
16
+ EvaByte is trained using the performant SambaNova SN30 RDU system with a batch size of 8M bytes and 32K context length. The training process consists of 3 phases: after pre-training on 1.2T bytes (yielding **EvaByte-Phase1**), two independent annealing runs (100B and 200B bytes respectively) are conducted with learning rate linearly decayed from 1e-4 to 0. The resulting checkpoints are merged via model soup (**EvaByte**), which then undergoes supervised fine-tuning (**EvaByte-SFT**).
17
+
18
+ | Stage | Model |
19
+ |:----- |:-----|
20
+ | Base (before annealing) | [EvaByte-Phase1](https://huggingface.co/evabyte/EvaByte-Phase1) |
21
+ | Base | [EvaByte](https://huggingface.co/evabyte/EvaByte) <-- you are here |
22
+ | SFT | [EvaByte-SFT](https://huggingface.co/evabyte/EvaByte-SFT) |
23
+
24
+
25
+ ## Usage
26
+
27
+ **Note:** Make sure to set `trust_remote_code=True` when loading the model (or tokenizer), as our implementation includes custom code.
28
+
29
+ The code snippet below demonstrates EvaByte-6.5B for completion:
30
+
31
+ ```python
32
+ from transformers import AutoTokenizer, AutoModelForCausalLM
33
+ import torch
34
+
35
+ # Load model and tokenizer
36
+ tokenizer = AutoTokenizer.from_pretrained("evabyte/EvaByte", trust_remote_code=True)
37
+ model = AutoModelForCausalLM.from_pretrained("evabyte/EvaByte", torch_dtype=torch.bfloat16, trust_remote_code=True).eval().to("cuda")
38
+
39
+ prompt = "The quick brown fox jumps "
40
+
41
+ # Tokenize input
42
+ # Option 1: standard HF tokenizer interface
43
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
44
+
45
+ # Option 2: Direct UTF-8 byte encoding with offset
46
+ # Note: Each byte is offset by 64 with <bos> prepended.
47
+ input_ids = torch.tensor([[1] + [b + 64 for b in prompt.encode("utf-8")]]).to("cuda")
48
+
49
+ # byte-by-byte generation (default)
50
+ generation_output = model.generate(
51
+ input_ids=input_ids,
52
+ max_new_tokens=32
53
+ )
54
+ # alternatively, use faster multibyte generation
55
+ generation_output = model.multi_byte_generate(
56
+ input_ids=input_ids,
57
+ max_new_tokens=32
58
+ )
59
+
60
+ # Decode and print the output
61
+ response = tokenizer.decode(
62
+ generation_output[0][input_ids.shape[1]:],
63
+ skip_special_tokens=False,
64
+ clean_up_tokenization_spaces=False
65
+ )
66
+ print(response)
67
+ # Sample output:
68
+ # over the lazy dog.\n\nThe quick
69
+ ```
70
+
71
+ ### ⚙️ Generation Modes
72
+
73
+ EvaByte supports two generation interfaces:
74
+ - `model.generate()`: The default generation method compatible with Huggingface `transformers` library. This approach generates one byte at a time and might be slow.
75
+ - `model.multi_byte_generate()`: A faster alternative that generates multiple bytes per step and usually yields the same result as `model.generate()` under greedy decoding, with the implementation adapted from [Medusa](https://github.com/FasterDecoding/Medusa). `model.multi_byte_generate()` supports a subset of arguments in `model.generate()`:
76
+ - `input_ids`: the input byte ids.
77
+ - `temperature`: the temperature for sampling.
78
+ - `max_length`: the maximum length of the generated sequence.
79
+ - `max_new_tokens`: the maximum number of new bytes to generate.
80
+ - `stopping_criteria`: the [stopping criteria](https://huggingface.co/docs/transformers/v4.47.1/en/internal/generation_utils#transformers.StoppingCriteria) for generation.
81
+ - `top_p`: the top-p parameter for sampling.
82
+ - `do_sample`: greedy decoding or sampling.
83
+
84
+ **Notes and Limitations:**
85
+ - `device_map="auto"` is not supported for >2 GPUs.
86
+ - Only batch size of 1 (with `attention_mask=None`) is supported for decoding.
87
+ - `torch_dtype=torch.bfloat16` is required.
88
+ - The multibyte generation `model.multi_byte_generate()` might return extra bytes after the end-of-sequence sentinel, due to the nature of the multibyte decoding. Manual truncation or cleaning may be needed.
89
+
90
+ ## Bias, Risks, and Limitations
91
+ As a pretrained base model, **EvaByte** has not been fine-tuned for chat or instruction following, so users should not expect reliable performance in conversational or instruction-based tasks. Like other base models, it does not incorporate any moderation mechanisms, making it possible to generate potentially harmful or inappropriate content.
92
+
93
+ ## Evaluation
94
+
95
+ For detailed evaluation results, check out our blog post at [SambaNova](https://sambanova.ai/blog/evabyte-efficient-byte-level-language-models-at-scale) or [HKUNLP](https://hkunlp.github.io/blog/2025/evabyte).
96
+
97
+ ## Citation
98
+ ```bibtex
99
+ @misc{evabyte,
100
+ title = {EvaByte: Efficient Byte-level Language Models at Scale},
101
+ url = {https://hkunlp.github.io/blog/2025/evabyte},
102
+ author = {Lin Zheng and Xueliang Zhao and Guangtao Wang and Chen Wu and David Dong and Angela Wang and Mingran Wang and Yun Du and Haige Bo and Amol Sharma and Bo Li and Kejie Zhang and Changran Hu and Urmish Thakker and Lingpeng Kong},
103
+ year = {2025}
104
+ }
105
+ ```
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/config.json ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": null,
3
+ "architectures": [
4
+ "EvaByteForCausalLM"
5
+ ],
6
+ "attention_bias": false,
7
+ "attention_class": "eva",
8
+ "attention_dropout": 0.0,
9
+ "auto_map": {
10
+ "AutoConfig": "configuration_evabyte.EvaByteConfig",
11
+ "AutoModelForCausalLM": "modeling_evabyte.EvaByteForCausalLM"
12
+ },
13
+ "bos_token_id": 1,
14
+ "chunk_size": 16,
15
+ "eos_token_id": 2,
16
+ "fp32_ln": true,
17
+ "fp32_logits": true,
18
+ "fp32_skip_add": false,
19
+ "hidden_act": "silu",
20
+ "hidden_size": 5120,
21
+ "init_cutoff_factor": null,
22
+ "init_fn": "v2",
23
+ "init_std": 0.01275,
24
+ "initializer_range": 0.01275,
25
+ "intermediate_size": 16384,
26
+ "lazy_init": true,
27
+ "max_position_embeddings": 16384,
28
+ "max_seq_length": 16384,
29
+ "mixedp_attn": true,
30
+ "model_type": "evabyte",
31
+ "norm_add_unit_offset": true,
32
+ "num_attention_heads": 40,
33
+ "num_chunks": null,
34
+ "num_hidden_layers": 40,
35
+ "num_key_value_heads": 40,
36
+ "num_pred_heads": 1,
37
+ "pad_token_id": 0,
38
+ "return_dict": false,
39
+ "rms_norm_eps": 1e-06,
40
+ "rope_scaling": null,
41
+ "rope_theta": 100000.0,
42
+ "tie_word_embeddings": false,
43
+ "torch_dtype": "bfloat16",
44
+ "transformers_version": "4.47.1",
45
+ "use_cache": true,
46
+ "vocab_size": 320,
47
+ "window_size": 2048
48
+ }
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/configuration_evabyte.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ EvaByte configuration"""
2
+
3
+ from transformers.configuration_utils import PretrainedConfig
4
+
5
+ class EvaByteConfig(PretrainedConfig):
6
+ model_type = "evabyte"
7
+ keys_to_ignore_at_inference = ["past_key_values"]
8
+
9
+ def __init__(
10
+ self,
11
+ vocab_size=320,
12
+ hidden_size=4096,
13
+ intermediate_size=11008,
14
+ num_hidden_layers=32,
15
+ num_attention_heads=32,
16
+ num_key_value_heads=None,
17
+ hidden_act="silu",
18
+ max_position_embeddings=2048,
19
+ initializer_range=0.02,
20
+ rms_norm_eps=1e-6,
21
+ use_cache=True,
22
+ pad_token_id=None,
23
+ bos_token_id=1,
24
+ eos_token_id=2,
25
+ tie_word_embeddings=False,
26
+ rope_theta=10000.0,
27
+ rope_scaling=None,
28
+ attention_bias=False,
29
+ attention_dropout=0.0,
30
+ norm_add_unit_offset=False,
31
+ init_fn="mitchell",
32
+ init_std=0.006,
33
+ init_cutoff_factor=None,
34
+ attention_class="mha",
35
+ window_size=512,
36
+ num_chunks=None,
37
+ chunk_size=256,
38
+ **kwargs,
39
+ ):
40
+ self.vocab_size = vocab_size
41
+ self.max_position_embeddings = max_position_embeddings
42
+ self.hidden_size = hidden_size
43
+ self.intermediate_size = intermediate_size
44
+ self.num_hidden_layers = num_hidden_layers
45
+ self.num_attention_heads = num_attention_heads
46
+
47
+ # for backward compatibility
48
+ if num_key_value_heads is None:
49
+ num_key_value_heads = num_attention_heads
50
+
51
+ self.num_key_value_heads = num_key_value_heads
52
+ self.hidden_act = hidden_act
53
+ self.initializer_range = initializer_range
54
+ self.rms_norm_eps = rms_norm_eps
55
+ self.use_cache = use_cache
56
+ self.rope_theta = rope_theta
57
+ self.rope_scaling = rope_scaling
58
+ self._rope_scaling_validation()
59
+ self.attention_bias = attention_bias
60
+ self.attention_dropout = attention_dropout
61
+
62
+ self.norm_add_unit_offset = norm_add_unit_offset
63
+ self.init_fn = init_fn
64
+ self.init_std = init_std
65
+ self.init_cutoff_factor = init_cutoff_factor
66
+
67
+ # Attention-specific paramters
68
+ self.attention_class = attention_class
69
+ self.window_size = window_size
70
+ self.num_chunks = num_chunks
71
+ self.chunk_size = chunk_size
72
+
73
+ super().__init__(
74
+ pad_token_id=pad_token_id,
75
+ bos_token_id=bos_token_id,
76
+ eos_token_id=eos_token_id,
77
+ tie_word_embeddings=tie_word_embeddings,
78
+ **kwargs,
79
+ )
80
+
81
+ def _rope_scaling_validation(self):
82
+ """
83
+ Validate the `rope_scaling` configuration.
84
+ """
85
+ if self.rope_scaling is None:
86
+ return
87
+
88
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
89
+ raise ValueError(
90
+ "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}"
91
+ )
92
+ rope_scaling_type = self.rope_scaling.get("type", None)
93
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
94
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
95
+ raise ValueError(
96
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
97
+ )
98
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
99
+ raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/eva.py ADDED
@@ -0,0 +1,424 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Optional, Tuple, List, Any, Union
2
+ import torch
3
+ from torch import nn
4
+ import torch.nn.functional as F
5
+ from .eva_agg_kernel import eva_agg_func_triton
6
+ from .eva_prep_kv_kernel import eva_prep_kv_func_triton
7
+ try:
8
+ import triton
9
+ USE_TRITON_IMPL = True
10
+ except ImportError:
11
+ USE_TRITON_IMPL = False
12
+ raise ImportError("Triton is not installed. Please install it by running `pip install triton`.")
13
+
14
+ def rotate_half(x: torch.Tensor) -> torch.Tensor:
15
+ """
16
+ Rotates half the hidden dims (last dim) of the input.
17
+ Args:
18
+ x: Rotary embedded tensor
19
+ Return:
20
+ Tensor with half of last dim negated and rotated to the front.
21
+ """
22
+ x1, x2 = x.split(x.shape[-1] // 2, dim=-1)
23
+ return torch.cat((-x2, x1), dim=-1)
24
+
25
+ def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
26
+ position_ids: torch.Tensor) -> torch.Tensor:
27
+ """
28
+ Apply rotary embedding (cos, sin) to the query and key tensor on the sequence dimension.
29
+
30
+ The legends for dimensions are defined as:
31
+ num_heads: number of attention heads
32
+ current_seq_len: the current batch's sequence length, should be either 1 or max_seq_len
33
+ max_seq_len: the static sequence length, different from current_seq_len in cached inference case where it is always
34
+ maximum lenghth, e.g. the length of static sequence length of KV cache
35
+
36
+
37
+ Args:
38
+ q: Query tensor, of size (batch_size, num_heads, current_seq_len, head_dim)
39
+ k: Key tensor, of size (batch_size, num_key_value_heads, current_seq_len, head_dim)
40
+ cos: Cosine base of rotary embedding, of size (max_seq_len, head_dim)
41
+ sin: Sine base of rotary embedding, of size (max_seq_len, head_dim)
42
+ position_ids: The position indices of the tokens corresponding to the query and key tensors. It has a size of
43
+ (batch_size, current_seq_len).
44
+
45
+ Returns:
46
+ Embedded query and key tensor of same size as input.
47
+
48
+ """
49
+ bs, nheads, cur_seq_len, head_dim = q.shape
50
+ assert len(
51
+ k.shape) == 4, f"k should be of shape (batch_size, num_heads, current_seq_len, head_dim), got {k.shape} instead"
52
+ assert k.shape[0] == bs, f"k has a different batch_size {k.shape[0]} compared to q {bs}"
53
+ assert list(k.shape[2:]) == [cur_seq_len,
54
+ head_dim], f"k has different current_seq_len and/or head_dim compared to q"
55
+ assert cos.shape[3] == head_dim, f"cos should have dim of head dim {head_dim}, got {cos.shape[3]} instead"
56
+ assert list(position_ids.shape) in [[bs, cur_seq_len], [1, cur_seq_len]],\
57
+ f"position_ids should be of shape {[bs, cur_seq_len]} or {[1, cur_seq_len]}, got {position_ids.shape} instead"
58
+
59
+ q_embed = (q * cos) + (rotate_half(q) * sin)
60
+ k_embed = (k * cos) + (rotate_half(k) * sin)
61
+ return q_embed, k_embed
62
+
63
+ class EvaAttention(nn.Module):
64
+ """
65
+ Causal EVA for language modeling.
66
+ """
67
+
68
+ def __init__(self, config, layer_idx: Optional[int] = None):
69
+ super().__init__()
70
+ self.config = config
71
+ self.layer_idx = layer_idx
72
+ self.hidden_size = config.hidden_size
73
+ self.num_heads = config.num_attention_heads
74
+ self.head_dim = self.hidden_size // self.num_heads
75
+ self.head_dim_scaling = self.head_dim ** -0.5
76
+
77
+ self.max_position_embeddings = config.max_position_embeddings
78
+
79
+ if (self.head_dim * self.num_heads) != self.hidden_size:
80
+ raise ValueError(
81
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
82
+ f" and `num_heads`: {self.num_heads})."
83
+ )
84
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
85
+ self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
86
+ self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
87
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
88
+
89
+ self.window_size = config.window_size
90
+
91
+ self.num_chunks = config.num_chunks
92
+ self.chunk_size = config.chunk_size
93
+ if self.chunk_size is not None:
94
+ assert self.window_size >= self.chunk_size and self.window_size % self.chunk_size == 0
95
+ # chunk_size overrides the number of landmarks
96
+ self.num_chunks = None
97
+
98
+ self.chunks_per_window = int(self.window_size // self.chunk_size)
99
+ self.adaptive_phi = nn.Parameter(
100
+ torch.randn(
101
+ 1,
102
+ self.num_heads,
103
+ 1,
104
+ 1,
105
+ self.head_dim
106
+ ).clamp(-1., 1.) * self.head_dim_scaling
107
+ )
108
+ self.adaptive_mu_k = nn.Parameter(
109
+ torch.randn(
110
+ 1,
111
+ self.num_heads,
112
+ 1,
113
+ 1,
114
+ self.head_dim
115
+ ).clamp(-1., 1.) * self.head_dim_scaling
116
+ )
117
+
118
+ def _triton_forward(
119
+ self,
120
+ hidden_states: torch.Tensor,
121
+ attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
122
+ position_ids: Optional[torch.LongTensor] = None,
123
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
124
+ output_attentions: bool = False,
125
+ use_cache: bool = False,
126
+ cos: Optional[torch.Tensor] = None,
127
+ sin: Optional[torch.Tensor] = None,
128
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
129
+ assert not output_attentions
130
+ bsz, q_len, _ = hidden_states.size()
131
+
132
+ if use_cache:
133
+ if past_key_value is None:
134
+ raise ValueError
135
+ assert isinstance(attention_mask, tuple)
136
+
137
+ # infer the model's running mode
138
+ is_prefilling = use_cache and past_key_value.get_seq_length(self.layer_idx) == 0
139
+ is_decoding = use_cache and past_key_value.get_seq_length(self.layer_idx) > 0
140
+
141
+ if is_prefilling:
142
+ assert len(attention_mask) == 2
143
+ window_mask, intra_chunk_mask = attention_mask
144
+ chunk_mask = None
145
+ elif is_decoding:
146
+ assert len(attention_mask) == 3
147
+ window_mask, intra_chunk_mask, chunk_mask = attention_mask
148
+ else:
149
+ if attention_mask is not None:
150
+ assert isinstance(attention_mask, tuple) and len(attention_mask) == 3
151
+ window_mask, chunk_mask, intra_chunk_mask = attention_mask
152
+ else:
153
+ window_mask, chunk_mask, intra_chunk_mask = None, None, None
154
+
155
+ ############################################
156
+ # compute q, k, v from hidden states
157
+ ############################################
158
+ # [b, h, q_len, d]
159
+ q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
160
+ # [b, h, kv_len, d]
161
+ k = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
162
+ # [b, h, kv_len, d]
163
+ v = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
164
+
165
+ if use_cache:
166
+ past_key_value.update_past_len(q.shape[-2], self.layer_idx)
167
+
168
+ ############################################
169
+ # apply rotary positional embeddings to q, k
170
+ ############################################
171
+ q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
172
+
173
+ ############################################
174
+ # update and get cached singleton tokens
175
+ # update and cache k and v for calculating chunk-level RFAs
176
+ ############################################
177
+ if use_cache:
178
+ s_k, s_v, dump_k, dump_v = past_key_value.update_singletons_and_chunks(
179
+ k,
180
+ v,
181
+ self.layer_idx,
182
+ self.window_size,
183
+ )
184
+ else:
185
+ s_k, s_v = k, v
186
+ dump_k, dump_v = k, v
187
+
188
+ if use_cache:
189
+ singleton_mask, dump_rf_mask = past_key_value.update_mask(
190
+ s_mask=window_mask,
191
+ rf_mask=intra_chunk_mask,
192
+ layer_idx=self.layer_idx,
193
+ window_size=self.window_size,
194
+ )
195
+ else:
196
+ singleton_mask = window_mask
197
+ dump_rf_mask = intra_chunk_mask
198
+
199
+ if dump_k is not None and dump_v is not None:
200
+ # 1. in prefilling, the input shape is
201
+ # dump_k/dump_v: [b, h, n, d]
202
+ # rfa_k/rfa_v: [b, h, n // c, d]
203
+ # 2. in decoding, the input shape is
204
+ # k/v: [b, h, w, d]
205
+ # rfa_k/rfa_v: [b, h, w//c, d]
206
+ # 3. in forward inference; the seq_len is already divisible
207
+ rfa_k, rfa_v = eva_prep_kv_func_triton(
208
+ dump_k, dump_v,
209
+ self.adaptive_mu_k, self.adaptive_phi,
210
+ dump_rf_mask, self.head_dim_scaling, self.chunk_size
211
+ )
212
+ # rfa_mask = get_rfa_chunk_mask(dump_rf_mask)
213
+ if use_cache:
214
+ rfa_k, rfa_v = past_key_value.update_chunk_rfas(
215
+ rfa_k, rfa_v, self.layer_idx
216
+ )
217
+ elif use_cache:
218
+ # if there are not enough elements within the last chunk,
219
+ # we will only use the cached chunk-level RFAs
220
+ rfa_k, rfa_v = past_key_value.get_chunk_rfas(self.layer_idx)
221
+ else:
222
+ rfa_k, rfa_v = None, None
223
+
224
+ ############################################
225
+ # compute the full attention output
226
+ ############################################
227
+ if is_prefilling:
228
+ # prefilling
229
+ # 1. in prefilling, the input shape is
230
+ # q: [b, h, n, d]
231
+ # k/v: [b, h, n, d]
232
+ # rfa_k/rfa_v: [b, h, n // c, d]
233
+ attn_output = eva_agg_func_triton(
234
+ q, s_k, s_v,
235
+ rfa_k, rfa_v,
236
+ singleton_mask, chunk_mask,
237
+ self.head_dim_scaling, self.window_size, self.chunks_per_window
238
+ )
239
+ elif is_decoding:
240
+ # 2. in decoding, the input shape is
241
+ # q: [b, h, 1, d] or [b, h, z, d] (for multi-byte prediction)
242
+ # k/v: [b, h, 1 + s, d]
243
+ # rfa_k/rfa_v: [b, h, n // c, d]
244
+ if rfa_k is not None and rfa_v is not None:
245
+ # we only take the chunk-level RFAs not in the current window
246
+ seen_seq_len = past_key_value.get_seq_length(self.layer_idx)
247
+ if seen_seq_len <= self.window_size:
248
+ agg_k = s_k
249
+ agg_v = s_v
250
+ attn_mask = singleton_mask
251
+ else:
252
+ # NOTE: we already updated the cache so the length now
253
+ # includes the current token
254
+ # we subtract 1 from seen_seq_len because we want
255
+ # if seen_seq_len = 2048 -> num_windows_seen_so_far = 0
256
+ # if seen_seq_len = 4096 -> num_windows_seen_so_far = 1
257
+ # if seen_seq_len = 4097 -> num_windows_seen_so_far = 2
258
+ # NOTE the cat order should be taken care of;
259
+ # should align with the order based on which
260
+ # the attention mask is constructed
261
+ num_windows_seen_so_far = (seen_seq_len - 1) // self.window_size
262
+ agg_k = torch.cat([s_k, rfa_k[..., :num_windows_seen_so_far * self.chunks_per_window, :]], dim=-2)
263
+ agg_v = torch.cat([s_v, rfa_v[..., :num_windows_seen_so_far * self.chunks_per_window, :]], dim=-2)
264
+ if singleton_mask is not None:
265
+ assert chunk_mask is not None
266
+ attn_mask = torch.cat([singleton_mask, chunk_mask], dim=-1)
267
+ else:
268
+ attn_mask = singleton_mask
269
+ else:
270
+ agg_k = s_k
271
+ agg_v = s_v
272
+ attn_mask = singleton_mask
273
+ attn_output = F.scaled_dot_product_attention(
274
+ q, agg_k, agg_v,
275
+ attn_mask=attn_mask,
276
+ is_causal=False,
277
+ dropout_p=0.0,
278
+ scale=self.head_dim_scaling
279
+ )
280
+ else:
281
+ # 3. in single-forward inference
282
+ attn_output = eva_agg_func_triton(
283
+ q, s_k, s_v,
284
+ rfa_k, rfa_v,
285
+ singleton_mask, chunk_mask,
286
+ self.head_dim_scaling, self.window_size, self.chunks_per_window
287
+ )
288
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
289
+ raise ValueError(
290
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
291
+ f" {attn_output.size()}"
292
+ )
293
+ attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
294
+ attn_output = self.o_proj(attn_output)
295
+ attn_weights = None
296
+ return attn_output, attn_weights, past_key_value
297
+
298
    def _multibyte_decoding_forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cos: Optional[torch.Tensor] = None,
        sin: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Attention forward pass used during multi-byte (draft) decoding.

        Unlike the regular decoding path, this method does NOT advance the
        cache bookkeeping: it writes the new k/v into the preallocated
        window buffer in place but leaves ``past_window_pos`` untouched, so
        the writes are provisional and can be overwritten/trimmed once the
        drafted bytes are accepted or rejected.

        Args:
            hidden_states: [bsz, q_len, hidden_size] input activations,
                where q_len > 1 covers the drafted byte positions.
            attention_mask: boolean mask (required here); must be laid out
                as [chunk-level RFA slots, window slots] — see the cat
                order below.
            position_ids / cos / sin: rotary embedding inputs.
            past_key_value: EVA cache holding the window k/v buffers and
                chunk-level RFA states (required; must be non-empty).

        Returns:
            (attn_output, None, past_key_value) — attention weights are
            never materialized (``output_attentions`` must be False).
        """
        assert not output_attentions
        bsz, q_len, _ = hidden_states.size()

        if use_cache and past_key_value is None:
            raise ValueError

        assert USE_TRITON_IMPL
        # this path relies on an explicit boolean mask built by the caller
        assert isinstance(attention_mask, torch.Tensor) and attention_mask.dtype == torch.bool

        # multi-byte decoding only runs after at least one token was prefetched
        assert use_cache and past_key_value.get_seq_length(self.layer_idx) > 0

        ############################################
        # compute q, k, v from hidden states
        ############################################
        # [b, h, q_len, d]
        q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        # [b, h, kv_len, d]
        k = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        # [b, h, kv_len, d]
        v = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        ############################################
        # apply rotary positional embeddings to q, k
        ############################################
        q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)

        ############################################
        # write k/v into the window buffer (provisionally)
        ############################################
        input_len = k.shape[-2]
        window_pos = past_key_value.past_window_pos[self.layer_idx]
        new_window_pos = window_pos + input_len

        # in-place slot writes; past_window_pos itself is NOT advanced here,
        # so a later call can overwrite these draft positions
        past_key_value.past_window_k[self.layer_idx][:, :, window_pos : new_window_pos, :] = k
        past_key_value.past_window_v[self.layer_idx][:, :, window_pos : new_window_pos, :] = v
        s_k = past_key_value.past_window_k[self.layer_idx][:, :, : new_window_pos, :]
        s_v = past_key_value.past_window_v[self.layer_idx][:, :, : new_window_pos, :]

        rfa_k, rfa_v = past_key_value.get_chunk_rfas(self.layer_idx)

        ############################################
        # compute the full attention output
        ############################################
        # shapes at this point:
        #   q:           [b, h, z, d]  (z = number of drafted bytes)
        #   s_k/s_v:     [b, h, window_len, d]
        #   rfa_k/rfa_v: [b, h, n // c, d]
        if rfa_k is not None and rfa_v is not None:
            # NOTE: the concatenation order (RFA states first, then window
            # k/v) differs from the regular decoding path and must match
            # the layout of the attention_mask built by the caller.
            agg_k = torch.cat([rfa_k, s_k], dim=-2)
            agg_v = torch.cat([rfa_v, s_v], dim=-2)
        else:
            agg_k = s_k
            agg_v = s_v
        attn_output = F.scaled_dot_product_attention(
            q, agg_k, agg_v,
            attn_mask=attention_mask,
            is_causal=False,
            dropout_p=0.0,
            scale=self.head_dim_scaling
        )

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )
        attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        attn_weights = None
        return attn_output, attn_weights, past_key_value
385
+
386
+ def forward(
387
+ self,
388
+ hidden_states: torch.Tensor,
389
+ attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
390
+ position_ids: Optional[torch.LongTensor] = None,
391
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
392
+ output_attentions: bool = False,
393
+ use_cache: bool = False,
394
+ cos: Optional[torch.Tensor] = None,
395
+ sin: Optional[torch.Tensor] = None,
396
+ multibyte_decoding: Optional[bool] = False,
397
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
398
+ assert not output_attentions
399
+ if use_cache and past_key_value is None:
400
+ raise ValueError
401
+
402
+ assert USE_TRITON_IMPL
403
+ if use_cache and multibyte_decoding:
404
+ return self._multibyte_decoding_forward(
405
+ hidden_states,
406
+ attention_mask=attention_mask,
407
+ position_ids=position_ids,
408
+ past_key_value=past_key_value,
409
+ output_attentions=output_attentions,
410
+ use_cache=use_cache,
411
+ cos=cos,
412
+ sin=sin,
413
+ )
414
+ else:
415
+ return self._triton_forward(
416
+ hidden_states,
417
+ attention_mask=attention_mask,
418
+ position_ids=position_ids,
419
+ past_key_value=past_key_value,
420
+ output_attentions=output_attentions,
421
+ use_cache=use_cache,
422
+ cos=cos,
423
+ sin=sin,
424
+ )
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/eva_agg_kernel.py ADDED
@@ -0,0 +1,1766 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import math
3
+ import torch
4
+ import triton
5
+ import triton.language as tl
6
+
7
@triton.heuristics(
    {
        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
        "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
        "EVEN_W": lambda args: args["WINDOW_SIZE"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def _bwd_eva_agg_kernel_dkdv(
    Q,
    K,
    V,
    WindowMask,
    DO,
    LSE,
    DO_T_O,
    DK,
    DV,
    softmax_scale,
    stride_qb, stride_qh, stride_qm,
    stride_kb, stride_kh, stride_kn,
    stride_vb, stride_vh, stride_vn,
    stride_window_mask_b, stride_window_mask_m,
    stride_do_b, stride_do_h, stride_do_m,
    stride_lse_b, stride_lse_h,
    stride_do_t_o_b, stride_do_t_o_h,
    stride_dk_b, stride_dk_h, stride_dk_n,
    stride_dv_b, stride_dv_h, stride_dv_n,
    nheads,
    seqlen_q,
    seqlen_k,
    headdim,
    WINDOW_SIZE: tl.constexpr,
    MASK_TYPE: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_W: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Backward kernel computing dK/dV for the singleton (in-window) tokens.

    Grid: (ceil(seqlen_k / BLOCK_N), batch * nheads). Each program owns one
    BLOCK_N slice of keys/values and accumulates its gradient contribution
    over all query rows that can attend to it — i.e. queries from the first
    block containing this KV block up to the end of the enclosing window.

    Inputs follow the FlashAttention backward recipe: LSE holds the
    log-sum-exp of the forward softmax, DO_T_O holds rowwise sum(dO * O).
    MASK_TYPE == 1 means an explicit per-window boolean mask is supplied;
    its convention here is True == masked out (mask padding loads use
    other=1). MASK_TYPE == 0 falls back to plain causal masking.
    """
    # map the flat program id onto (batch, head)
    off_bh = tl.program_id(1)
    off_h = off_bh % nheads
    off_b = off_bh // nheads

    start_n = tl.program_id(0)
    # determine which window the current KV block belongs to
    offs_w = (start_n * BLOCK_N) // WINDOW_SIZE
    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_m = tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_HEADDIM)

    # initialize pointers (q/do/lse/do_t_o are offset per iteration below)
    q_ptrs = (
        Q +
        off_b * stride_qb +
        off_h * stride_qh +
        offs_m[:, None] * stride_qm + offs_d[None, :]
    )
    k_ptrs = (
        K +
        off_b * stride_kb +
        off_h * stride_kh +
        offs_n[:, None] * stride_kn + offs_d[None, :]
    )
    v_ptrs = (
        V +
        off_b * stride_vb +
        off_h * stride_vh +
        offs_n[:, None] * stride_vn + offs_d[None, :]
    )
    do_ptrs = (
        DO +
        off_b * stride_do_b +
        off_h * stride_do_h +
        offs_m[:, None] * stride_do_m + offs_d[None, :]
    )
    do_t_o_ptrs = (
        DO_T_O +
        off_b * stride_do_t_o_b +
        off_h * stride_do_t_o_h +
        offs_m[:, None]
    )
    lse_ptrs = (
        LSE +
        off_b * stride_lse_b +
        off_h * stride_lse_h +
        offs_m[:, None]
    )
    if MASK_TYPE == 1:
        # the window mask is indexed by absolute query row and by the
        # key's position *within* its window (offset applied at load time)
        m_ptrs = (
            WindowMask +
            off_b * stride_window_mask_b +
            (offs_m[:, None] * stride_window_mask_m + offs_n[None, :])
        )
    dk_ptrs = (
        DK +
        off_b * stride_dk_b +
        off_h * stride_dk_h +
        offs_n[:, None] * stride_dk_n + offs_d[None, :]
    )
    dv_ptrs = (
        DV +
        off_b * stride_dv_b +
        off_h * stride_dv_h +
        offs_n[:, None] * stride_dv_n + offs_d[None, :]
    )

    # 1. for singletons
    # query rows range from the block containing this KV slice (causality)
    # to the end of its window (singletons are only visible in-window)
    begin_m = ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M
    end_m = tl.minimum((offs_w + 1) * WINDOW_SIZE, seqlen_q)

    dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    # load this program's K/V tile once; masked lanes contribute zeros
    if EVEN_N & EVEN_M:
        if EVEN_HEADDIM:
            k = tl.load(k_ptrs)
            v = tl.load(v_ptrs)
        else:
            k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
            v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
    else:
        if EVEN_HEADDIM:
            k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
            v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
        else:
            k = tl.load(
                k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0
            )
            v = tl.load(
                v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0
            )
    for start_m in range(begin_m, end_m, BLOCK_M):
        start_m = tl.multiple_of(start_m, BLOCK_M)
        # load q, do, and lse
        if EVEN_M & EVEN_N:
            if EVEN_HEADDIM:
                q = tl.load(
                    q_ptrs + start_m * stride_qm
                )
                do = tl.load(
                    do_ptrs + start_m * stride_do_m
                )
            else:
                q = tl.load(
                    q_ptrs + start_m * stride_qm,
                    mask=offs_d[None, :] < headdim,
                    other=0.0
                )
                do = tl.load(
                    do_ptrs + start_m * stride_do_m,
                    mask=offs_d[None, :] < headdim,
                    other=0.0
                )
            do_t_o = tl.load(
                do_t_o_ptrs + start_m
            )
            lse = tl.load(
                lse_ptrs + start_m
            )
        else:
            if EVEN_HEADDIM:
                q = tl.load(
                    q_ptrs + start_m * stride_qm,
                    mask=(start_m + offs_m)[:, None] < seqlen_q,
                    other=0.0
                )
                do = tl.load(
                    do_ptrs + start_m * stride_do_m,
                    mask=(start_m + offs_m)[:, None] < seqlen_q,
                    other=0.0
                )
            else:
                q = tl.load(
                    q_ptrs + start_m * stride_qm,
                    mask=((start_m + offs_m)[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                    other=0.0
                )
                do = tl.load(
                    do_ptrs + start_m * stride_do_m,
                    mask=((start_m + offs_m)[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                    other=0.0
                )
            do_t_o = tl.load(
                do_t_o_ptrs + start_m,
                mask=(start_m + offs_m)[:, None] < seqlen_q,
                other=0.0
            )
            lse = tl.load(
                lse_ptrs + start_m,
                mask=(start_m + offs_m)[:, None] < seqlen_q,
                other=0.0
            )
        # fully-masked rows store -inf in LSE; neutralize so exp() below is 0
        lse = tl.where(lse == float("-inf"), 0.0, lse)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, tl.trans(k))
        if not EVEN_M:
            qk += tl.where((start_m + offs_m)[:, None] < seqlen_q, 0, float("-inf"))

        if MASK_TYPE == 1:
            if EVEN_M & EVEN_W:
                mask = tl.load(
                    m_ptrs + (start_m * stride_window_mask_m) - (offs_w * WINDOW_SIZE)
                )
            else:
                mask = tl.load(
                    m_ptrs + (start_m * stride_window_mask_m) - (offs_w * WINDOW_SIZE),
                    mask=((start_m + offs_m)[:, None] < seqlen_q)
                    & (((start_m * stride_window_mask_m) - (offs_w * WINDOW_SIZE) + offs_n)[None, :] < WINDOW_SIZE),
                    other=1,
                )
            # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
            # can then fuse the mult and add into an fma instruction. But if we have bias we need to
            # to multiply with softmax_scale here.
            # we assume mask already implies the causal masking
            qk = qk * softmax_scale
            qk = tl.where(mask, float("-inf"), qk)
            # recompute forward probabilities from the saved LSE
            p = tl.exp(qk - lse)
        else:
            # no explicit mask: apply plain causal masking
            qk += tl.where((start_m + offs_m)[:, None] >= offs_n[None, :], 0, float("-inf"))
            p = tl.exp(qk * softmax_scale - lse)

        # dp [M, N]
        dp = tl.dot(do, tl.trans(v))
        # p [M, N], dp [M, N], do_t_o [M, 1] -> ds [M, N]
        # (softmax backward: dS = P * (dP - rowsum(dO*O)))
        ds = (p * (dp - do_t_o) * softmax_scale).to(q.dtype)
        # p is fp32 and [M, N], convert to q.dtype
        # do [M, D] -> dv [N, D]
        dv += tl.dot(tl.trans(p.to(do.dtype)), do)
        # dk [N, D]
        dk += tl.dot(tl.trans(ds), q)
    # write back the accumulated fp32 gradients for this KV tile
    if EVEN_N & EVEN_M:
        if EVEN_HEADDIM:
            tl.store(dv_ptrs, dv)
            tl.store(dk_ptrs, dk)
        else:
            tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
            tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
    else:
        if EVEN_HEADDIM:
            tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
            tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
        else:
            tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
            tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
+
256
@triton.heuristics(
    {
        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
        "EVEN_C": lambda args: args["nchunks"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def _bwd_eva_agg_kernel_drfa_kv(
    Q,
    RFA_K,
    RFA_V,
    ChunkMask,
    DO,
    LSE,
    DO_T_O,
    D_RFA_K,
    D_RFA_V,
    softmax_scale,
    stride_qb, stride_qh, stride_qm,
    stride_rfa_kb, stride_rfa_kh, stride_rfa_kc,
    stride_rfa_vb, stride_rfa_vh, stride_rfa_vc,
    stride_chunk_mask_b, stride_chunk_mask_m,
    stride_do_b, stride_do_h, stride_do_m,
    stride_lse_b, stride_lse_h,
    stride_do_t_o_b, stride_do_t_o_h,
    stride_d_rfa_k_b, stride_d_rfa_k_h, stride_d_rfa_k_c,
    stride_d_rfa_v_b, stride_d_rfa_v_h, stride_d_rfa_v_c,
    nheads,
    seqlen_q,
    nchunks,
    headdim,
    CHUNKS_PER_WINDOW: tl.constexpr,
    WINDOW_SIZE: tl.constexpr,
    MASK_TYPE: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr,
    EVEN_C: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Backward kernel computing gradients of the chunk-level RFA K/V states.

    Grid: (ceil(nchunks / BLOCK_N), batch * nheads). Each program owns one
    BLOCK_N slice of RFA chunks and accumulates the gradient contribution
    from every query row that can attend to it — i.e. all rows strictly
    after the window that produced these chunks (queries only attend to
    RFA states of *past* windows).

    Mirrors _bwd_eva_agg_kernel_dkdv: LSE is the forward log-sum-exp,
    DO_T_O is rowwise sum(dO * O). MASK_TYPE == 1 supplies an explicit
    chunk mask with the True == masked-out convention (padding loads use
    other=1); MASK_TYPE == 0 needs no masking here because all visible
    chunks precede the query rows by construction.
    """
    # map the flat program id onto (batch, head)
    off_bh = tl.program_id(1)
    off_h = off_bh % nheads
    off_b = off_bh // nheads
    start_c = tl.program_id(0)
    # chunk indices covered by this program
    offs_c = start_c * BLOCK_N + tl.arange(0, BLOCK_N)
    # determine which window the current chunk block belongs to
    offs_w = (start_c * BLOCK_N) // CHUNKS_PER_WINDOW
    offs_m = tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_HEADDIM)

    # initialize pointers (q/do/lse/do_t_o are offset per iteration below)
    q_ptrs = (
        Q +
        off_b * stride_qb +
        off_h * stride_qh +
        (offs_m[:, None] * stride_qm + offs_d[None, :])
    )
    do_ptrs = (
        DO +
        off_b * stride_do_b +
        off_h * stride_do_h +
        (offs_m[:, None] * stride_do_m + offs_d[None, :])
    )
    do_t_o_ptrs = (
        DO_T_O +
        off_b * stride_do_t_o_b +
        off_h * stride_do_t_o_h +
        (offs_m[:, None])
    )
    lse_ptrs = (
        LSE +
        off_b * stride_lse_b +
        off_h * stride_lse_h +
        (offs_m[:, None])
    )
    rfa_k_ptrs = (
        RFA_K +
        off_b * stride_rfa_kb +
        off_h * stride_rfa_kh +
        (offs_c[:, None] * stride_rfa_kc + offs_d[None, :])
    )
    rfa_v_ptrs = (
        RFA_V +
        off_b * stride_rfa_vb +
        off_h * stride_rfa_vh +
        (offs_c[:, None] * stride_rfa_vc + offs_d[None, :])
    )
    if MASK_TYPE == 1:
        rfa_m_ptrs = (
            ChunkMask +
            off_b * stride_chunk_mask_b +
            (offs_m[:, None] * stride_chunk_mask_m + offs_c[None, :])
        )
    d_rfa_k_ptrs = (
        D_RFA_K +
        off_b * stride_d_rfa_k_b +
        off_h * stride_d_rfa_k_h +
        (offs_c[:, None] * stride_d_rfa_k_c + offs_d[None, :])
    )
    d_rfa_v_ptrs = (
        D_RFA_V +
        off_b * stride_d_rfa_v_b +
        off_h * stride_d_rfa_v_h +
        (offs_c[:, None] * stride_d_rfa_v_c + offs_d[None, :])
    )

    d_rfa_k = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    d_rfa_v = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    # load this program's RFA K/V tile once; masked lanes contribute zeros
    if EVEN_C & EVEN_M:
        if EVEN_HEADDIM:
            rfa_k = tl.load(rfa_k_ptrs)
            rfa_v = tl.load(rfa_v_ptrs)
        else:
            rfa_k = tl.load(rfa_k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
            rfa_v = tl.load(rfa_v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
    else:
        if EVEN_HEADDIM:
            rfa_k = tl.load(rfa_k_ptrs, mask=offs_c[:, None] < nchunks, other=0.0)
            rfa_v = tl.load(rfa_v_ptrs, mask=offs_c[:, None] < nchunks, other=0.0)
        else:
            rfa_k = tl.load(
                rfa_k_ptrs, mask=(offs_c[:, None] < nchunks) & (offs_d[None, :] < headdim), other=0.0
            )
            rfa_v = tl.load(
                rfa_v_ptrs, mask=(offs_c[:, None] < nchunks) & (offs_d[None, :] < headdim), other=0.0
            )
    # query rows start right after this chunk block's window ends
    begin_m = tl.minimum((offs_w + 1) * WINDOW_SIZE, seqlen_q)
    end_m = seqlen_q
    for start_m in range(begin_m, end_m, BLOCK_M):
        start_m = tl.multiple_of(start_m, BLOCK_M)
        # load q, do, and lse
        if EVEN_M:
            if EVEN_HEADDIM:
                q = tl.load(
                    q_ptrs + start_m * stride_qm
                )
                do = tl.load(
                    do_ptrs + start_m * stride_do_m
                )
            else:
                q = tl.load(
                    q_ptrs + start_m * stride_qm,
                    mask=offs_d[None, :] < headdim,
                    other=0.0
                )
                do = tl.load(
                    do_ptrs + start_m * stride_do_m,
                    mask=offs_d[None, :] < headdim,
                    other=0.0
                )
            do_t_o = tl.load(
                do_t_o_ptrs + start_m
            )
            lse = tl.load(
                lse_ptrs + start_m
            )
        else:
            if EVEN_HEADDIM:
                q = tl.load(
                    q_ptrs + start_m * stride_qm,
                    mask=(start_m + offs_m)[:, None] < seqlen_q,
                    other=0.0
                )
                do = tl.load(
                    do_ptrs + start_m * stride_do_m,
                    mask=(start_m + offs_m)[:, None] < seqlen_q,
                    other=0.0
                )
            else:
                q = tl.load(
                    q_ptrs + start_m * stride_qm,
                    mask=((start_m + offs_m)[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                    other=0.0
                )
                do = tl.load(
                    do_ptrs + start_m * stride_do_m,
                    mask=((start_m + offs_m)[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                    other=0.0
                )
            do_t_o = tl.load(
                do_t_o_ptrs + start_m,
                mask=(start_m + offs_m)[:, None] < seqlen_q,
                other=0.0
            )
            lse = tl.load(
                lse_ptrs + start_m,
                mask=(start_m + offs_m)[:, None] < seqlen_q,
                other=0.0
            )
        # fully-masked rows store -inf in LSE; neutralize so exp() below is 0
        lse = tl.where(lse == float("-inf"), 0.0, lse)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, tl.trans(rfa_k))
        if not EVEN_M:
            qk += tl.where((start_m + offs_m)[:, None] < seqlen_q, 0, float("-inf"))

        if MASK_TYPE == 1:
            if EVEN_M & EVEN_C:
                mask = tl.load(
                    rfa_m_ptrs + (start_m * stride_chunk_mask_m)
                )
            else:
                mask = tl.load(
                    rfa_m_ptrs + (start_m * stride_chunk_mask_m),
                    mask=((start_m + offs_m)[:, None] < seqlen_q)
                    & (offs_c[None, :] < nchunks),
                    other=1,
                )
            # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
            # can then fuse the mult and add into an fma instruction. But if we have bias we need to
            # to multiply with softmax_scale here.
            # we assume mask already implies the causal masking
            qk = qk * softmax_scale
            qk = tl.where(mask, float("-inf"), qk)
            # recompute forward probabilities from the saved LSE
            p = tl.exp(qk - lse)
        else:
            # no mask needed: every chunk here strictly precedes the queries
            p = tl.exp(qk * softmax_scale - lse)

        # softmax backward: dS = P * (dP - rowsum(dO*O))
        dp = tl.dot(do, tl.trans(rfa_v))
        ds = (p * (dp - do_t_o) * softmax_scale).to(q.dtype)
        # p is fp32, convert to q.dtype
        d_rfa_v += tl.dot(tl.trans(p.to(do.dtype)), do)
        # move softmax_scale to ds to save computation
        d_rfa_k += tl.dot(tl.trans(ds), q)
    # write back the accumulated fp32 gradients for this chunk tile
    if EVEN_C & EVEN_M:
        if EVEN_HEADDIM:
            tl.store(d_rfa_v_ptrs, d_rfa_v)
            tl.store(d_rfa_k_ptrs, d_rfa_k)
        else:
            tl.store(d_rfa_v_ptrs, d_rfa_v, mask=offs_d[None, :] < headdim)
            tl.store(d_rfa_k_ptrs, d_rfa_k, mask=offs_d[None, :] < headdim)
    else:
        if EVEN_HEADDIM:
            tl.store(d_rfa_v_ptrs, d_rfa_v, mask=offs_c[:, None] < nchunks)
            tl.store(d_rfa_k_ptrs, d_rfa_k, mask=offs_c[:, None] < nchunks)
        else:
            tl.store(d_rfa_v_ptrs, d_rfa_v, mask=(offs_c[:, None] < nchunks) & (offs_d[None, :] < headdim))
            tl.store(d_rfa_k_ptrs, d_rfa_k, mask=(offs_c[:, None] < nchunks) & (offs_d[None, :] < headdim))
+
497
+ @triton.heuristics(
498
+ {
499
+ "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
500
+ "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
501
+ "EVEN_C": lambda args: args["nchunks"] % args["BLOCK_N"] == 0,
502
+ "EVEN_W": lambda args: args["WINDOW_SIZE"] % args["BLOCK_N"] == 0,
503
+ "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
504
+ }
505
+ )
506
+ @triton.jit
507
+ def _bwd_eva_agg_kernel_dq(
508
+ Q,
509
+ K,
510
+ V,
511
+ RFA_K,
512
+ RFA_V,
513
+ WindowMask,
514
+ ChunkMask,
515
+ DO,
516
+ LSE,
517
+ DO_T_O,
518
+ DQ,
519
+ softmax_scale,
520
+ stride_qb, stride_qh, stride_qm,
521
+ stride_kb, stride_kh, stride_kn,
522
+ stride_vb, stride_vh, stride_vn,
523
+ stride_rfa_kb, stride_rfa_kh, stride_rfa_kc,
524
+ stride_rfa_vb, stride_rfa_vh, stride_rfa_vc,
525
+ stride_window_mask_b, stride_window_mask_m,
526
+ stride_chunk_mask_b, stride_chunk_mask_m,
527
+ stride_do_b, stride_do_h, stride_do_m,
528
+ stride_lse_b, stride_lse_h,
529
+ stride_do_t_o_b, stride_do_t_o_h,
530
+ stride_dq_b, stride_dq_h, stride_dq_m,
531
+ nheads,
532
+ seqlen_q,
533
+ seqlen_k,
534
+ nchunks,
535
+ headdim,
536
+ CHUNKS_PER_WINDOW: tl.constexpr,
537
+ WINDOW_SIZE: tl.constexpr,
538
+ MASK_TYPE: tl.constexpr,
539
+ EMPTY_RFA_KV: tl.constexpr,
540
+ BLOCK_HEADDIM: tl.constexpr,
541
+ EVEN_M: tl.constexpr,
542
+ EVEN_N: tl.constexpr,
543
+ EVEN_W: tl.constexpr,
544
+ EVEN_C: tl.constexpr,
545
+ EVEN_HEADDIM: tl.constexpr,
546
+ BLOCK_M: tl.constexpr,
547
+ BLOCK_N: tl.constexpr,
548
+ ):
549
+ start_m = tl.program_id(0)
550
+ off_bh = tl.program_id(1)
551
+ off_h = off_bh % nheads
552
+ off_b = off_bh // nheads
553
+ # initialize offsets
554
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
555
+ offs_w = (start_m * BLOCK_M) // WINDOW_SIZE
556
+ offs_n = tl.arange(0, BLOCK_N)
557
+ offs_c = tl.arange(0, BLOCK_N)
558
+ offs_d = tl.arange(0, BLOCK_HEADDIM)
559
+ # TODO: add paratheses or not
560
+ q_ptrs = (
561
+ Q +
562
+ off_b * stride_qb +
563
+ off_h * stride_qh +
564
+ (offs_m[:, None] * stride_qm + offs_d[None, :])
565
+ )
566
+ k_ptrs = (
567
+ K +
568
+ off_b * stride_kb +
569
+ off_h * stride_kh +
570
+ (offs_n[:, None] * stride_kn + offs_d[None, :])
571
+ )
572
+ v_ptrs = (
573
+ V +
574
+ off_b * stride_vb +
575
+ off_h * stride_vh +
576
+ (offs_n[:, None] * stride_vn + offs_d[None, :])
577
+ )
578
+ if EMPTY_RFA_KV == 0:
579
+ rfa_k_ptrs = (
580
+ RFA_K +
581
+ off_b * stride_rfa_kb +
582
+ off_h * stride_rfa_kh +
583
+ (offs_c[:, None] * stride_rfa_kc + offs_d[None, :])
584
+ )
585
+ rfa_v_ptrs = (
586
+ RFA_V +
587
+ off_b * stride_rfa_vb +
588
+ off_h * stride_rfa_vh +
589
+ (offs_c[:, None] * stride_rfa_vc + offs_d[None, :])
590
+ )
591
+ dq_ptrs = (
592
+ DQ +
593
+ off_b * stride_dq_b +
594
+ off_h * stride_dq_h +
595
+ (offs_m[:, None] * stride_dq_m + offs_d[None, :])
596
+ )
597
+ do_ptrs = (
598
+ DO +
599
+ off_b * stride_do_b +
600
+ off_h * stride_do_h +
601
+ (offs_m[:, None] * stride_do_m + offs_d[None, :])
602
+ )
603
+ do_t_o_ptrs = (
604
+ DO_T_O +
605
+ off_b * stride_do_t_o_b +
606
+ off_h * stride_do_t_o_h +
607
+ offs_m[:, None]
608
+ )
609
+ lse_ptrs = (
610
+ LSE +
611
+ off_b * stride_lse_b +
612
+ off_h * stride_lse_h +
613
+ offs_m[:, None]
614
+ )
615
+ ### load q, do, do_t_o, lse ####
616
+ if EVEN_M:
617
+ if EVEN_HEADDIM:
618
+ q = tl.load(
619
+ q_ptrs
620
+ )
621
+ do = tl.load(
622
+ do_ptrs
623
+ )
624
+ else:
625
+ q = tl.load(
626
+ q_ptrs,
627
+ mask=offs_d[None, :] < headdim,
628
+ other=0.0
629
+ )
630
+ do = tl.load(
631
+ do_ptrs,
632
+ mask=offs_d[None, :] < headdim,
633
+ other=0.0
634
+ )
635
+ do_t_o = tl.load(
636
+ do_t_o_ptrs
637
+ )
638
+ lse = tl.load(
639
+ lse_ptrs
640
+ )
641
+ else:
642
+ if EVEN_HEADDIM:
643
+ q = tl.load(
644
+ q_ptrs,
645
+ mask=offs_m[:, None] < seqlen_q,
646
+ other=0.0
647
+ )
648
+ do = tl.load(
649
+ do_ptrs,
650
+ mask=offs_m[:, None] < seqlen_q,
651
+ other=0.0
652
+ )
653
+ else:
654
+ q = tl.load(
655
+ q_ptrs,
656
+ mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
657
+ other=0.0
658
+ )
659
+ do = tl.load(
660
+ do_ptrs,
661
+ mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
662
+ other=0.0
663
+ )
664
+ do_t_o = tl.load(
665
+ do_t_o_ptrs,
666
+ mask=offs_m[:, None] < seqlen_q,
667
+ other=0.0
668
+ )
669
+ lse = tl.load(
670
+ lse_ptrs,
671
+ mask=offs_m[:, None] < seqlen_q,
672
+ other=0.0
673
+ )
674
+ lse = tl.where(lse == float("-inf"), 0.0, lse)
675
+ lse *= 1.4426950408889634 # log2(e)
676
+ qk_scale = softmax_scale
677
+ qk_scale *= 1.4426950408889634 # log2(e)
678
+ if MASK_TYPE == 1:
679
+ window_mask_ptrs = (
680
+ WindowMask +
681
+ off_b * stride_window_mask_b +
682
+ (offs_m[:, None] * stride_window_mask_m + offs_n[None, :])
683
+ )
684
+ if EMPTY_RFA_KV == 0:
685
+ chunk_mask_ptrs = (
686
+ ChunkMask +
687
+ off_b * stride_chunk_mask_b +
688
+ (offs_m[:, None] * stride_chunk_mask_m + offs_c[None, :])
689
+ )
690
+
691
+ dq = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
692
+ # loop over k, v and update accumulator
693
+ # Iterate over local singletons;
694
+ # so we only iterate over blocks within the current window
695
+ start_idx_n = offs_w * WINDOW_SIZE
696
+ end_idx_n = tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
697
+ for start_n in range(start_idx_n, end_idx_n, BLOCK_N):
698
+ start_n = tl.multiple_of(start_n, BLOCK_N)
699
+ if EVEN_N & EVEN_M:
700
+ if EVEN_HEADDIM:
701
+ k = tl.load(
702
+ k_ptrs + start_n * stride_kn
703
+ )
704
+ else:
705
+ k = tl.load(
706
+ k_ptrs + start_n * stride_kn,
707
+ mask=offs_d[None, :] < headdim,
708
+ other=0.0
709
+ )
710
+ else:
711
+ if EVEN_HEADDIM:
712
+ k = tl.load(
713
+ k_ptrs + start_n * stride_kn,
714
+ mask=(start_n + offs_n)[:, None] < seqlen_k,
715
+ other=0.0,
716
+ )
717
+ else:
718
+ k = tl.load(
719
+ k_ptrs + start_n * stride_kn,
720
+ mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
721
+ other=0.0,
722
+ )
723
+ qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
724
+ qk += tl.dot(q, tl.trans(k))
725
+ # Trying to combine the two masks seem to make the result wrong
726
+ if not EVEN_N: # Need to mask out otherwise the softmax is wrong
727
+ qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
728
+
729
+ if MASK_TYPE == 1:
730
+ if EVEN_M & EVEN_W:
731
+ window_mask = tl.load(
732
+ window_mask_ptrs + start_n - start_idx_n
733
+ )
734
+ else:
735
+ window_mask = tl.load(
736
+ window_mask_ptrs + start_n - start_idx_n,
737
+ mask=(offs_m[:, None] < seqlen_q)
738
+ & ((start_n - start_idx_n + offs_n)[None, :] < WINDOW_SIZE),
739
+ other=1,
740
+ )
741
+ # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
742
+ # can then fuse the mult and add into an fma instruction. But if we have bias we need to
743
+ # to multiply with softmax_scale here.
744
+ # we assume mask already implies the causal masking
745
+ qk = qk * qk_scale
746
+ qk = tl.where(window_mask, float("-inf"), qk)
747
+ p = tl.exp2(qk - lse)
748
+ else:
749
+ qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
750
+ p = tl.exp2(qk * qk_scale - lse)
751
+
752
+ if EVEN_N & EVEN_M:
753
+ if EVEN_HEADDIM:
754
+ v = tl.load(
755
+ v_ptrs + start_n * stride_vn
756
+ )
757
+ else:
758
+ v = tl.load(
759
+ v_ptrs + start_n * stride_vn,
760
+ mask=offs_d[None, :] < headdim,
761
+ other=0.0
762
+ )
763
+ else:
764
+ if EVEN_HEADDIM:
765
+ v = tl.load(
766
+ v_ptrs + start_n * stride_vn,
767
+ mask=(start_n + offs_n)[:, None] < seqlen_k,
768
+ other=0.0,
769
+ )
770
+ else:
771
+ v = tl.load(
772
+ v_ptrs + start_n * stride_vn,
773
+ mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
774
+ other=0.0,
775
+ )
776
+ dp = tl.dot(do, tl.trans(v))
777
+ ds = (p * (dp - do_t_o) * softmax_scale).to(q.dtype)
778
+ dq += tl.dot(ds, k)
779
+
780
+ if EMPTY_RFA_KV == 0:
781
+ # Iterate over RFA chunks
782
+ # we only iterate over chunks before the current local singleton window
783
+ end_idx_c = tl.minimum(offs_w * CHUNKS_PER_WINDOW, nchunks)
784
+ for start_c in range(0, end_idx_c, BLOCK_N):
785
+ start_c = tl.multiple_of(start_c, BLOCK_N)
786
+ # -- compute qk ----
787
+ if EVEN_C & EVEN_M:
788
+ if EVEN_HEADDIM:
789
+ rfa_k = tl.load(
790
+ rfa_k_ptrs + start_c * stride_rfa_kc
791
+ )
792
+ else:
793
+ rfa_k = tl.load(
794
+ rfa_k_ptrs + start_c * stride_rfa_kc,
795
+ mask=offs_d[None, :] < headdim,
796
+ other=0.0
797
+ )
798
+ else:
799
+ if EVEN_HEADDIM:
800
+ rfa_k = tl.load(
801
+ rfa_k_ptrs + start_c * stride_rfa_kc,
802
+ mask=(start_c + offs_c)[:, None] < nchunks,
803
+ other=0.0,
804
+ )
805
+ else:
806
+ rfa_k = tl.load(
807
+ rfa_k_ptrs + start_c * stride_rfa_kc,
808
+ mask=((start_c + offs_c)[:, None] < nchunks) & (offs_d[None, :] < headdim),
809
+ other=0.0,
810
+ )
811
+ qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
812
+ qk += tl.dot(q, tl.trans(rfa_k))
813
+ # Trying to combine the two masks seem to make the result wrong
814
+ if not EVEN_C: # Need to mask out otherwise the softmax is wrong
815
+ qk += tl.where((start_c + offs_c)[None, :] < nchunks, 0, float("-inf"))
816
+
817
+ if MASK_TYPE == 1:
818
+ if EVEN_C & EVEN_M:
819
+ chunk_mask = tl.load(
820
+ chunk_mask_ptrs + start_c
821
+ )
822
+ else:
823
+ chunk_mask = tl.load(
824
+ chunk_mask_ptrs + start_c,
825
+ mask=(offs_m[:, None] < seqlen_q) & ((start_c + offs_c)[None, :] < nchunks),
826
+ other=1,
827
+ )
828
+ # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
829
+ # can then fuse the mult and add into an fma instruction. But if we have bias we need to
830
+ # to multiply with softmax_scale here.
831
+ # we assume mask already implies the causal masking
832
+ qk = qk * qk_scale
833
+ qk = tl.where(chunk_mask, float("-inf"), qk)
834
+ p = tl.exp2(qk - lse)
835
+ else:
836
+ p = tl.exp2(qk * qk_scale - lse)
837
+
838
+ if EVEN_C & EVEN_M:
839
+ if EVEN_HEADDIM:
840
+ rfa_v = tl.load(
841
+ rfa_v_ptrs + start_c * stride_rfa_vc
842
+ )
843
+ else:
844
+ rfa_v = tl.load(
845
+ rfa_v_ptrs + start_c * stride_rfa_vc,
846
+ mask=offs_d[None, :] < headdim,
847
+ other=0.0
848
+ )
849
+ else:
850
+ if EVEN_HEADDIM:
851
+ rfa_v = tl.load(
852
+ rfa_v_ptrs + start_c * stride_rfa_vc,
853
+ mask=(start_c + offs_n)[:, None] < nchunks,
854
+ other=0.0,
855
+ )
856
+ else:
857
+ rfa_v = tl.load(
858
+ rfa_v_ptrs + start_c * stride_rfa_vc,
859
+ mask=((start_c + offs_n)[:, None] < nchunks) & (offs_d[None, :] < headdim),
860
+ other=0.0,
861
+ )
862
+ dp = tl.dot(do, tl.trans(rfa_v))
863
+ ds = (p * (dp - do_t_o) * softmax_scale).to(q.dtype)
864
+ dq += tl.dot(ds, rfa_k)
865
+
866
+ start_m = tl.program_id(0)
867
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
868
+ offs_d = tl.arange(0, BLOCK_HEADDIM)
869
+ dq_ptrs = (
870
+ DQ +
871
+ off_b * stride_dq_b +
872
+ off_h * stride_dq_h +
873
+ (offs_m[:, None] * stride_dq_m + offs_d[None, :])
874
+ )
875
+ if EVEN_M:
876
+ if EVEN_HEADDIM:
877
+ tl.store(
878
+ dq_ptrs, dq
879
+ )
880
+ else:
881
+ tl.store(
882
+ dq_ptrs, dq,
883
+ mask=offs_d[None, :] < headdim
884
+ )
885
+ else:
886
+ if EVEN_HEADDIM:
887
+ tl.store(
888
+ dq_ptrs, dq,
889
+ mask=offs_m[:, None] < seqlen_q
890
+ )
891
+ else:
892
+ tl.store(
893
+ dq_ptrs, dq,
894
+ mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
895
+ )
896
+
897
# Tuned Triton launch configurations for compute-capability 9.0 (Hopper) GPUs.
# Outer keys are the kernel mode passed to _get_config ("fwd", "bwd_dq",
# "bwd_dkdv", "bwd_drfa_kv"); inner keys are (tensor dtype, head dimension).
# Each value is a 4-tuple (BLOCK_M, BLOCK_N, num_warps, num_stages), matching
# how _get_config's result is unpacked at the kernel launch sites.
_capability_90_config = {
    "fwd": {
        (torch.bfloat16, 64): (128, 128, 4, 3),
        (torch.bfloat16, 128): (128, 128, 8, 3),
        (torch.float32, 64): (128, 64, 8, 3),
        (torch.float32, 128): (64, 32, 4, 3),
    },
    "bwd_dq": {
        (torch.bfloat16, 64): (128, 64, 4, 3),
        (torch.bfloat16, 128): (128, 64, 8, 3),
        (torch.float32, 64): (128, 64, 8, 2),
        (torch.float32, 128): (32, 32, 4, 2),
    },
    "bwd_dkdv": {
        (torch.bfloat16, 64): (128, 64, 4, 2),
        (torch.bfloat16, 128): (128, 64, 8, 2),
        (torch.float32, 64): (128, 64, 8, 2),
        (torch.float32, 128): (32, 32, 4, 1),
    },
    "bwd_drfa_kv": {
        (torch.bfloat16, 64): (128, 64, 4, 2),
        (torch.bfloat16, 128): (128, 64, 8, 2),
        (torch.float32, 64): (128, 64, 8, 2),
        (torch.float32, 128): (32, 32, 4, 1),
    }
}
923
+
924
# Tuned Triton launch configurations for compute-capability 8.x (Ampere) GPUs.
# Same layout as _capability_90_config: mode -> (dtype, head_dim) ->
# (BLOCK_M, BLOCK_N, num_warps, num_stages); smaller tiles than the Hopper
# table, presumably to fit Ampere's smaller shared memory — TODO confirm.
_capability_80_config = {
    "fwd": {
        (torch.bfloat16, 64): (64, 64, 4, 3),
        (torch.bfloat16, 128): (64, 64, 8, 3),
        (torch.float32, 64): (64, 32, 4, 2),
        (torch.float32, 128): (64, 32, 8, 1),
    },
    "bwd_dq": {
        (torch.bfloat16, 64): (64, 64, 4, 3),
        (torch.bfloat16, 128): (64, 32, 4, 2),
        (torch.float32, 64): (32, 32, 4, 2),
        (torch.float32, 128): (32, 32, 4, 2),
    },
    "bwd_dkdv": {
        (torch.bfloat16, 64): (64, 64, 4, 3),
        (torch.bfloat16, 128): (32, 32, 4, 2),
        (torch.float32, 64): (32, 32, 4, 1),
        (torch.float32, 128): (16, 64, 8, 1),
    },
    "bwd_drfa_kv": {
        (torch.bfloat16, 64): (64, 64, 4, 3),
        (torch.bfloat16, 128): (64, 32, 4, 3),
        (torch.float32, 64): (32, 32, 4, 1),
        (torch.float32, 128): (32, 32, 4, 1),
    }
}
950
+
951
def _get_config(dtype, head_dim, mode) -> tuple[int, int, int, int]:
    """Select a kernel launch configuration for the current CUDA device.

    Returns a ``(BLOCK_M, BLOCK_N, num_warps, num_stages)`` tuple for the
    given tensor ``dtype``, ``head_dim`` and kernel ``mode`` (one of
    ``"fwd"``, ``"bwd_dq"``, ``"bwd_dkdv"``, ``"bwd_drfa_kv"``).  Hopper
    (sm90+) and Ampere (sm80+) use the tuned lookup tables, with a
    conservative per-architecture default for untabulated (dtype, head_dim)
    pairs; older devices fall back to fixed choices split by mode and dtype.
    """
    capability = torch.cuda.get_device_capability()
    if capability >= (9, 0):
        return _capability_90_config[mode].get((dtype, head_dim), (32, 32, 4, 1))
    if capability >= (8, 0):
        return _capability_80_config[mode].get((dtype, head_dim), (16, 16, 4, 1))
    # Pre-Ampere fallbacks: only mode ("fwd" vs backward) and dtype matter.
    if mode == "fwd":
        return (32, 16, 4, 2) if dtype == torch.float32 else (64, 32, 4, 2)
    return (16, 16, 4, 1) if dtype == torch.float32 else (32, 32, 4, 1)
969
+
970
# -----------------------------------------------------------------------------
# Fused forward kernel for EVA aggregation.
#
# Each program instance handles one BLOCK_M-row tile of queries for a single
# (batch, head) pair and combines two score sources under one streaming
# (online) softmax:
#   1) exact attention over the "local singleton" keys/values (K, V) lying in
#      the query rows' current window, and
#   2) attention over per-chunk RFA summaries (RFA_K, RFA_V) of the chunks
#      strictly before that window (skipped entirely when EMPTY_RFA_KV != 0).
# The running-max (m_i) / running-denominator (d_i) recurrence is carried
# across both loops, so the result equals a softmax over the concatenation of
# both score sets.  The normalized output tile is written to Out and the
# per-row log-sum-exp (natural-log units) to LSE.
#
# MASK_TYPE == 1: boolean masks WindowMask / ChunkMask are supplied, where a
# True entry means "masked out"; the masks are assumed to already encode
# causality.  MASK_TYPE == 0: plain causal masking is computed on the fly.
# Scores are scaled by softmax_scale; exponentials use base 2 (tl.exp2) with
# the log2(e) factor folded into qk_scale.
# -----------------------------------------------------------------------------
@triton.heuristics(
    {
        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
        "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
        "EVEN_C": lambda args: args["nchunks"] % args["BLOCK_N"] == 0,
        "EVEN_W": lambda args: args["WINDOW_SIZE"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def _fwd_eva_agg_kernel(
    Q,
    K,
    V,
    RFA_K,
    RFA_V,
    WindowMask,
    ChunkMask,
    Out,
    LSE,
    softmax_scale,
    stride_qb, stride_qh, stride_qm,
    stride_kb, stride_kh, stride_kn,
    stride_vb, stride_vh, stride_vn,
    stride_rfa_kb, stride_rfa_kh, stride_rfa_kc,
    stride_rfa_vb, stride_rfa_vh, stride_rfa_vc,
    stride_window_mask_b, stride_window_mask_m,
    stride_chunk_mask_b, stride_chunk_mask_m,
    stride_ob, stride_oh, stride_om,
    stride_lse_b, stride_lse_h,
    nheads,
    seqlen_q,
    seqlen_k,
    nchunks,
    headdim,
    CHUNKS_PER_WINDOW: tl.constexpr,
    WINDOW_SIZE: tl.constexpr,
    MASK_TYPE: tl.constexpr,
    EMPTY_RFA_KV: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_W: tl.constexpr,
    EVEN_C: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Grid: axis 0 tiles the query sequence, axis 1 is flattened batch*heads.
    start_m = tl.program_id(0)
    off_bh = tl.program_id(1)
    off_h = off_bh % nheads
    off_b = off_bh // nheads
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # Index of the window containing this query tile (assumes each BLOCK_M
    # tile lies within a single window, i.e. WINDOW_SIZE >= BLOCK_M and the
    # tile does not straddle a window boundary — enforced by the launcher's
    # size constraints; TODO confirm).
    offs_w = (start_m * BLOCK_M) // WINDOW_SIZE
    offs_n = tl.arange(0, BLOCK_N)
    offs_c = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # Per-tile base pointers; the innermost dim is assumed contiguous
    # (stride 1), which the launcher enforces via .contiguous().
    q_ptrs = (
        Q +
        off_b * stride_qb +
        off_h * stride_qh +
        (offs_m[:, None] * stride_qm + offs_d[None, :])
    )
    k_ptrs = (
        K +
        off_b * stride_kb +
        off_h * stride_kh +
        (offs_n[:, None] * stride_kn + offs_d[None, :])
    )
    v_ptrs = (
        V +
        off_b * stride_vb +
        off_h * stride_vh +
        (offs_n[:, None] * stride_vn + offs_d[None, :])
    )
    if EMPTY_RFA_KV == 0:
        rfa_k_ptrs = (
            RFA_K +
            off_b * stride_rfa_kb +
            off_h * stride_rfa_kh +
            (offs_c[:, None] * stride_rfa_kc + offs_d[None, :])
        )
        rfa_v_ptrs = (
            RFA_V +
            off_b * stride_rfa_vb +
            off_h * stride_rfa_vh +
            (offs_c[:, None] * stride_rfa_vc + offs_d[None, :])
        )

    # Fold log2(e) into the scale so tl.exp2 computes e^(softmax_scale * qk).
    qk_scale = softmax_scale
    qk_scale *= 1.4426950408889634  # log2(e)
    if MASK_TYPE == 1:
        # Masks are broadcast over heads (no head stride).
        window_mask_ptrs = (
            WindowMask +
            off_b * stride_window_mask_b +
            (offs_m[:, None] * stride_window_mask_m + offs_n[None, :])
        )
        if EMPTY_RFA_KV == 0:
            chunk_mask_ptrs = (
                ChunkMask +
                off_b * stride_chunk_mask_b +
                (offs_m[:, None] * stride_chunk_mask_m + offs_c[None, :])
            )

    # Online-softmax state: running max (m_i), running denominator (d_i),
    # and the unnormalized output accumulator (acc_o).
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    d_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
    # load q: it will stay in SRAM throughout
    # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
    # tl.load(q_ptrs), we get the wrong output!
    if EVEN_M & EVEN_N:
        if EVEN_HEADDIM:
            q = tl.load(
                q_ptrs
            )
        else:
            q = tl.load(
                q_ptrs,
                mask=offs_d[None, :] < headdim,
                other=0.0
            )
    else:
        if EVEN_HEADDIM:
            q = tl.load(
                q_ptrs,
                mask=offs_m[:, None] < seqlen_q,
                other=0.0
            )
        else:
            q = tl.load(
                q_ptrs,
                mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                other=0.0
            )
    # loop over k, v and update accumulator
    # Iterate over local singletons;
    # so we only iterate over blocks within the current window
    start_idx_n = offs_w * WINDOW_SIZE
    end_idx_n = tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
    for start_n in range(start_idx_n, end_idx_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        if EVEN_N & EVEN_M:
            if EVEN_HEADDIM:
                k = tl.load(
                    k_ptrs + start_n * stride_kn
                )
            else:
                k = tl.load(
                    k_ptrs + start_n * stride_kn,
                    mask=offs_d[None, :] < headdim,
                    other=0.0
                )
        else:
            if EVEN_HEADDIM:
                k = tl.load(
                    k_ptrs + start_n * stride_kn,
                    mask=(start_n + offs_n)[:, None] < seqlen_k,
                    other=0.0,
                )
            else:
                k = tl.load(
                    k_ptrs + start_n * stride_kn,
                    mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                    other=0.0,
                )
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, tl.trans(k))
        # Trying to combine the two masks seem to make the result wrong
        if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
            qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))

        if MASK_TYPE == 1:
            # The mask tile is indexed relative to the window start, hence
            # the (start_n - start_idx_n) column offset.
            if EVEN_M & EVEN_W:
                window_mask = tl.load(
                    window_mask_ptrs + start_n - start_idx_n
                )
            else:
                window_mask = tl.load(
                    window_mask_ptrs + start_n - start_idx_n,
                    mask=(offs_m[:, None] < seqlen_q)
                    & ((start_n - start_idx_n + offs_n)[None, :] < WINDOW_SIZE),
                    other=1,
                )
            # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
            # can then fuse the mult and add into an fma instruction. But if we have bias we need to
            # to multiply with softmax_scale here.
            # we assume mask already implies the causal masking
            qk = qk * qk_scale
            qk = tl.where(window_mask, float("-inf"), qk)
            m_ij = tl.maximum(tl.max(qk, 1), m_i)
            # Rows whose scores are all -inf would yield NaN from exp2(-inf - -inf);
            # substitute 0 for their max so p comes out 0 for those rows.
            masked_out_rows = (m_ij == float("-inf"))
            m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
            p = tl.exp2(qk - m_ij_masked[:, None])
        else:
            # No explicit mask: apply the causal constraint directly.
            qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
            m_ij = tl.maximum(tl.max(qk, 1) * qk_scale, m_i)
            masked_out_rows = (m_ij == float("-inf"))
            m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
            p = tl.exp2(qk * qk_scale - m_ij_masked[:, None])

        d_ij = tl.sum(p, 1)

        # scale acc_o
        prev_scale = tl.exp2(m_i - m_ij_masked)
        # # -- update output accumulator --
        acc_o = acc_o * prev_scale[:, None]
        # update acc_o
        if EVEN_N & EVEN_M:  # If we just do "if EVEN_N", there seems to be some race condition
            if EVEN_HEADDIM:
                v = tl.load(
                    v_ptrs + start_n * stride_vn
                )
            else:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=offs_d[None, :] < headdim,
                    other=0.0
                )
        else:
            if EVEN_HEADDIM:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=(start_n + offs_n)[:, None] < seqlen_k,
                    other=0.0,
                )
            else:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                    other=0.0,
                )
        p = p.to(v.dtype)
        acc_o = tl.dot(p, v, acc_o)

        # -- update statistics
        d_i = d_i * prev_scale + d_ij
        m_i = m_ij

    if EMPTY_RFA_KV == 0:
        # Iterate over RFA chunks
        # we only iterate over chunks before the current local singleton window
        end_idx_c = tl.minimum(offs_w * CHUNKS_PER_WINDOW, nchunks)
        for start_c in range(0, end_idx_c, BLOCK_N):
            start_c = tl.multiple_of(start_c, BLOCK_N)
            # -- compute qk ----
            if EVEN_C & EVEN_M:
                if EVEN_HEADDIM:
                    rfa_k = tl.load(
                        rfa_k_ptrs + start_c * stride_rfa_kc
                    )
                else:
                    rfa_k = tl.load(
                        rfa_k_ptrs + start_c * stride_rfa_kc,
                        mask=offs_d[None, :] < headdim,
                        other=0.0
                    )
            else:
                if EVEN_HEADDIM:
                    rfa_k = tl.load(
                        rfa_k_ptrs + start_c * stride_rfa_kc,
                        mask=(start_c + offs_c)[:, None] < nchunks,
                        other=0.0,
                    )
                else:
                    rfa_k = tl.load(
                        rfa_k_ptrs + start_c * stride_rfa_kc,
                        mask=((start_c + offs_c)[:, None] < nchunks) & (offs_d[None, :] < headdim),
                        other=0.0,
                    )
            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, tl.trans(rfa_k))
            # Trying to combine the two masks seem to make the result wrong
            if not EVEN_C:  # Need to mask out otherwise the softmax is wrong
                qk += tl.where((start_c + offs_c)[None, :] < nchunks, 0, float("-inf"))

            if MASK_TYPE == 1:
                if EVEN_C & EVEN_M:
                    chunk_mask = tl.load(
                        chunk_mask_ptrs + start_c
                    )
                else:
                    chunk_mask = tl.load(
                        chunk_mask_ptrs + start_c,
                        mask=(offs_m[:, None] < seqlen_q) & ((start_c + offs_c)[None, :] < nchunks),
                        other=1,
                    )
                # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
                # can then fuse the mult and add into an fma instruction. But if we have bias we need to
                # to multiply with softmax_scale here.
                # we assume mask already implies the causal masking
                qk = qk * qk_scale
                qk = tl.where(chunk_mask, float("-inf"), qk)
                m_ij = tl.maximum(tl.max(qk, 1), m_i)
                masked_out_rows = (m_ij == float("-inf"))
                m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
                p = tl.exp2(qk - m_ij_masked[:, None])
            else:
                # No causal term here: all chunks iterated are strictly before
                # this window, so they are always visible.
                m_ij = tl.maximum(tl.max(qk, 1) * qk_scale, m_i)
                masked_out_rows = (m_ij == float("-inf"))
                m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
                p = tl.exp2(qk * qk_scale - m_ij_masked[:, None])

            d_ij = tl.sum(p, 1)

            # scale acc_o
            prev_scale = tl.exp2(m_i - m_ij_masked)
            # # -- update output accumulator --
            acc_o = acc_o * prev_scale[:, None]
            # update acc_o
            # TODO: If we just do "if EVEN_N", there seems to be some race condition ?
            if EVEN_C & EVEN_M:
                if EVEN_HEADDIM:
                    rfa_v = tl.load(
                        rfa_v_ptrs + start_c * stride_rfa_vc
                    )
                else:
                    rfa_v = tl.load(
                        rfa_v_ptrs + start_c * stride_rfa_vc,
                        mask=offs_d[None, :] < headdim,
                        other=0.0
                    )
            else:
                if EVEN_HEADDIM:
                    # NOTE(review): these masks use offs_n where the rfa_k
                    # loads use offs_c; both are tl.arange(0, BLOCK_N) so the
                    # values are identical, but offs_c would be clearer.
                    rfa_v = tl.load(
                        rfa_v_ptrs + start_c * stride_rfa_vc,
                        mask=(start_c + offs_n)[:, None] < nchunks,
                        other=0.0,
                    )
                else:
                    rfa_v = tl.load(
                        rfa_v_ptrs + start_c * stride_rfa_vc,
                        mask=((start_c + offs_n)[:, None] < nchunks) & (offs_d[None, :] < headdim),
                        other=0.0,
                    )
            p = p.to(rfa_v.dtype)
            acc_o = tl.dot(p, rfa_v, acc_o)

            # -- update statistics
            d_i = d_i * prev_scale + d_ij
            m_i = m_ij

    # for rows that are all -inf, set d_i to 1.0
    d_i = tl.where(d_i == 0.0, 1.0, d_i)
    # multiply by log(2)
    # Convert the base-2 running stats to a natural-log LSE:
    # ln(sum exp(s)) = (m + log2(d)) * ln(2).
    lse_m = (m_i + tl.math.log2(d_i)) * 0.6931471805599453
    acc_o = acc_o / d_i[:, None]
    # TODO: understand why rematerialize offsets to save registers?
    start_m = tl.program_id(0)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    out_ptrs = (
        Out +
        off_b * stride_ob +
        off_h * stride_oh +
        (offs_m[:, None] * stride_om + offs_d[None, :])
    )
    if EVEN_M:
        if EVEN_HEADDIM:
            tl.store(
                out_ptrs, acc_o
            )
        else:
            tl.store(
                out_ptrs, acc_o,
                mask=offs_d[None, :] < headdim
            )
    else:
        if EVEN_HEADDIM:
            tl.store(
                out_ptrs, acc_o,
                mask=offs_m[:, None] < seqlen_q
            )
        else:
            tl.store(
                out_ptrs, acc_o,
                mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
            )
    lse_ptrs = (
        LSE +
        off_b * stride_lse_b +
        off_h * stride_lse_h +
        offs_m
    )
    if EVEN_M:
        tl.store(
            lse_ptrs, lse_m,
        )
    else:
        tl.store(
            lse_ptrs, lse_m,
            mask=offs_m < seqlen_q
        )
1365
+
1366
def triton_eva_agg_fwd(
    q, k, v, rfa_k, rfa_v,
    window_mask,
    chunk_mask,
    softmax_scale,
    window_size,
    chunks_per_window
):
    """Launch the fused EVA-aggregation forward kernel.

    Args:
        q, k, v: ``(batch, nheads, seqlen, head_dim)`` CUDA tensors
            (bf16 or fp32).
        rfa_k, rfa_v: per-chunk RFA summaries of shape
            ``(batch, nheads, nchunks, head_dim)``, or both ``None`` to
            disable the RFA branch.
        window_mask: optional ``(batch, 1, seqlen_q, window_size)`` bool
            mask (True = masked out); when given, ``chunk_mask`` of shape
            ``(batch, 1, seqlen_q, nchunks)`` is required as well.
        softmax_scale: score scale; defaults to ``1/sqrt(head_dim)``
            when falsy.
        window_size: local-attention window length in tokens.
        chunks_per_window: number of RFA chunks per window; must be a
            positive multiple of the kernel's BLOCK_N.

    Returns:
        ``(o, lse)``: the attention output (same shape/dtype as ``q``) and
        the per-row log-sum-exp as an fp32 ``(batch, nheads, seqlen_q)``
        tensor.
    """
    if rfa_k is None and rfa_v is None:
        empty_rfa_kv = 1

        # The kernel requires a contiguous innermost dimension.
        q, k, v = [
            x if x.stride(-1) == 1 else x.contiguous()
            for x in [q, k, v]
        ]
    else:
        assert rfa_k is not None and rfa_v is not None, "Both rfa_k and rfa_v must either be None or have values at the same time."
        empty_rfa_kv = 0

        q, k, v, rfa_k, rfa_v = [
            x if x.stride(-1) == 1 else x.contiguous()
            for x in [q, k, v, rfa_k, rfa_v]
        ]

    # shape constraints
    batch, nheads, seqlen_q, head_dim = q.shape
    _, _, seqlen_k, _ = k.shape
    if empty_rfa_kv == 0:
        nchunks = rfa_k.shape[-2]
        assert rfa_k.shape == (batch, nheads, nchunks, head_dim)
        assert rfa_v.shape == (batch, nheads, nchunks, head_dim)
        assert q.dtype == k.dtype == v.dtype == rfa_k.dtype == rfa_v.dtype
    else:
        nchunks = 0
        assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
    assert k.shape == (batch, nheads, seqlen_k, head_dim)
    assert v.shape == (batch, nheads, seqlen_k, head_dim)

    assert head_dim <= 128, "We only test head dimensions up to 128"
    # assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
    assert q.dtype in [torch.bfloat16, torch.float], "Only support bf16 and fp32 for now"
    assert q.is_cuda and k.is_cuda and v.is_cuda
    softmax_scale = softmax_scale or 1.0 / math.sqrt(head_dim)

    mask_type = 0
    if window_mask is not None:
        mask_type = 1
        assert window_mask.dtype == torch.bool
        assert window_mask.is_cuda
        assert window_mask.dim() == 4
        assert window_mask.shape == (batch, 1, seqlen_q, window_size)
        if window_mask.stride(-1) != 1:
            window_mask = window_mask.contiguous()

        # A window mask implies a chunk mask for the RFA branch.
        assert chunk_mask is not None
        assert chunk_mask.dtype == torch.bool
        assert chunk_mask.is_cuda
        assert chunk_mask.dim() == 4
        assert chunk_mask.shape == (batch, 1, seqlen_q, nchunks)
        if chunk_mask.stride(-1) != 1:
            chunk_mask = chunk_mask.contiguous()

    # Stride tuples default to zeros when the corresponding tensor is
    # absent so the kernel signature stays fixed.
    chunk_mask_strides = (
        (chunk_mask.stride(0), chunk_mask.stride(2))
        if mask_type == 1 else
        (0, 0)
    )
    window_mask_strides = (
        (window_mask.stride(0), window_mask.stride(2))
        if mask_type == 1 else
        (0, 0)
    )

    rfa_k_strides = (
        (rfa_k.stride(0), rfa_k.stride(1), rfa_k.stride(2))
        if empty_rfa_kv == 0 else
        (0, 0, 0)
    )
    rfa_v_strides = (
        (rfa_v.stride(0), rfa_v.stride(1), rfa_v.stride(2))
        if empty_rfa_kv == 0 else
        (0, 0, 0)
    )

    o = torch.empty_like(q)
    lse = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)

    BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16)

    BLOCK_M, BLOCK_N, num_warps, num_stages = _get_config(q.dtype, head_dim, "fwd")

    # Fixed assertion messages: the first check is >= (not strictly greater)
    # and both refer to BLOCK_N.
    assert chunks_per_window >= BLOCK_N, "chunks_per_window must be at least BLOCK_N"
    assert chunks_per_window % BLOCK_N == 0, "chunks_per_window must be a multiple of BLOCK_N"

    grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
    _fwd_eva_agg_kernel[grid](
        q,
        k,
        v,
        rfa_k,
        rfa_v,
        window_mask,
        chunk_mask,
        o,
        lse,
        softmax_scale,
        q.stride(0), q.stride(1), q.stride(2),
        k.stride(0), k.stride(1), k.stride(2),
        v.stride(0), v.stride(1), v.stride(2),
        rfa_k_strides[0], rfa_k_strides[1], rfa_k_strides[2],
        rfa_v_strides[0], rfa_v_strides[1], rfa_v_strides[2],
        window_mask_strides[0], window_mask_strides[1],
        chunk_mask_strides[0], chunk_mask_strides[1],
        o.stride(0), o.stride(1), o.stride(2),
        lse.stride(0), lse.stride(1),
        nheads,
        seqlen_q,
        seqlen_k,
        nchunks,
        head_dim,
        chunks_per_window,
        window_size,
        mask_type,
        empty_rfa_kv,
        BLOCK_HEADDIM,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return o, lse
1497
+
1498
def triton_eva_agg_bwd(
    do,
    q, k, v, rfa_k, rfa_v,
    window_mask, chunk_mask,
    o, lse,
    dq, dk, dv, d_rfa_k, d_rfa_v,
    softmax_scale,
    window_size,
    chunks_per_window,
    empty_rfa_kv,
    mask_type,
):
    """Launch the EVA-aggregation backward kernels, filling gradient buffers.

    Fills ``dq``, ``dk``, ``dv`` (and ``d_rfa_k``/``d_rfa_v`` when
    ``empty_rfa_kv == 0``) in place via three Triton kernels: one for dq,
    one for dk/dv over the local keys/values, and one for the RFA-chunk
    gradients.

    Args:
        do: upstream gradient w.r.t. the forward output ``o``.
        q, k, v, rfa_k, rfa_v, window_mask, chunk_mask: forward inputs
            (``rfa_k``/``rfa_v`` are ``None`` when ``empty_rfa_kv == 1``).
        o, lse: forward output and per-row log-sum-exp.
        dq, dk, dv, d_rfa_k, d_rfa_v: pre-allocated gradient buffers with
            contiguous innermost dimension.
        softmax_scale: score scale; defaults to ``1/sqrt(head_dim)``
            when falsy.
        window_size, chunks_per_window: same meaning as in the forward.
        empty_rfa_kv: 1 to skip the RFA branch, 0 otherwise.
        mask_type: 1 when boolean masks are supplied, 0 for on-the-fly
            causal masking.

    Returns:
        None; gradients are written into the provided buffers.
    """
    if do.stride(-1) != 1:
        do = do.contiguous()

    # shape constraints
    batch, nheads, seqlen_q, head_dim = q.shape
    _, _, seqlen_k, _ = k.shape
    if empty_rfa_kv == 0:
        nchunks = rfa_k.shape[-2]
        assert rfa_k.shape == (batch, nheads, nchunks, head_dim)
        assert rfa_v.shape == (batch, nheads, nchunks, head_dim)
        # BUGFIX: the rfa_k/rfa_v stride checks used to live in the
        # unconditional assert below, which dereferenced None and raised
        # AttributeError whenever empty_rfa_kv == 1. They belong here,
        # where rfa_k/rfa_v are guaranteed to be tensors.
        assert rfa_k.stride(-1) == rfa_v.stride(-1) == 1
        assert d_rfa_k.stride(-1) == d_rfa_v.stride(-1) == 1
        assert q.dtype == k.dtype == v.dtype == rfa_k.dtype == rfa_v.dtype
    else:
        nchunks = 0
        assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"

    assert lse.shape == (batch, nheads, seqlen_q)
    assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1
    assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1
    softmax_scale = softmax_scale or 1.0 / math.sqrt(head_dim)

    assert head_dim <= 128, "We only test head dimensions up to 128"

    # Stride tuples default to zeros when the corresponding tensor is absent
    # so the kernel signatures stay fixed.
    window_mask_strides = (
        (window_mask.stride(0), window_mask.stride(2))
        if mask_type == 1 else
        (0, 0)
    )
    chunk_mask_strides = (
        (chunk_mask.stride(0), chunk_mask.stride(2))
        if mask_type == 1 else
        (0, 0)
    )

    rfa_k_strides = (
        (rfa_k.stride(0), rfa_k.stride(1), rfa_k.stride(2))
        if empty_rfa_kv == 0 else
        (0, 0, 0)
    )
    rfa_v_strides = (
        (rfa_v.stride(0), rfa_v.stride(1), rfa_v.stride(2))
        if empty_rfa_kv == 0 else
        (0, 0, 0)
    )

    d_rfa_k_strides = (
        (d_rfa_k.stride(0), d_rfa_k.stride(1), d_rfa_k.stride(2))
        if empty_rfa_kv == 0 else
        (0, 0, 0)
    )
    d_rfa_v_strides = (
        (d_rfa_v.stride(0), d_rfa_v.stride(1), d_rfa_v.stride(2))
        if empty_rfa_kv == 0 else
        (0, 0, 0)
    )

    BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16)

    # Row-wise dot(do, o), needed by the softmax backward (the "delta" term);
    # accumulated in fp32 for accuracy, stored back in do's dtype.
    do_t_o = torch.sum(do.to(torch.float32) * o.to(torch.float32), dim=-1).to(do.dtype)

    BLOCK_M, BLOCK_N, num_warps, num_stages = _get_config(q.dtype, head_dim, "bwd_dq")

    # Fixed assertion messages: the first check is >= (not strictly greater)
    # and both refer to BLOCK_N.
    assert chunks_per_window >= BLOCK_N, "chunks_per_window must be at least BLOCK_N"
    assert chunks_per_window % BLOCK_N == 0, "chunks_per_window must be a multiple of BLOCK_N"
    grid = lambda META: (
        triton.cdiv(seqlen_q, META["BLOCK_M"]),
        batch * nheads,
    )
    _bwd_eva_agg_kernel_dq[grid](
        q,
        k,
        v,
        rfa_k,
        rfa_v,
        window_mask,
        chunk_mask,
        do,
        lse,
        do_t_o,
        dq,
        softmax_scale,
        q.stride(0), q.stride(1), q.stride(2),
        k.stride(0), k.stride(1), k.stride(2),
        v.stride(0), v.stride(1), v.stride(2),
        rfa_k_strides[0], rfa_k_strides[1], rfa_k_strides[2],
        rfa_v_strides[0], rfa_v_strides[1], rfa_v_strides[2],
        window_mask_strides[0], window_mask_strides[1],
        chunk_mask_strides[0], chunk_mask_strides[1],
        do.stride(0), do.stride(1), do.stride(2),
        lse.stride(0), lse.stride(1),
        do_t_o.stride(0), do_t_o.stride(1),
        dq.stride(0), dq.stride(1), dq.stride(2),
        nheads,
        seqlen_q,
        seqlen_k,
        nchunks,
        head_dim,
        chunks_per_window,
        window_size,
        mask_type,
        empty_rfa_kv,
        BLOCK_HEADDIM,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    BLOCK_M, BLOCK_N, num_warps, num_stages = _get_config(q.dtype, head_dim, "bwd_dkdv")
    grid = lambda META: (
        triton.cdiv(seqlen_k, META["BLOCK_N"]),
        batch * nheads,
    )
    _bwd_eva_agg_kernel_dkdv[grid](
        q,
        k,
        v,
        window_mask,
        do,
        lse,
        do_t_o,
        dk,
        dv,
        softmax_scale,
        q.stride(0), q.stride(1), q.stride(2),
        k.stride(0), k.stride(1), k.stride(2),
        v.stride(0), v.stride(1), v.stride(2),
        window_mask_strides[0], window_mask_strides[1],
        do.stride(0), do.stride(1), do.stride(2),
        lse.stride(0), lse.stride(1),
        do_t_o.stride(0), do_t_o.stride(1),
        dk.stride(0), dk.stride(1), dk.stride(2),
        dv.stride(0), dv.stride(1), dv.stride(2),
        nheads,
        seqlen_q,
        seqlen_k,
        head_dim,
        window_size,
        mask_type,
        BLOCK_HEADDIM,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    if empty_rfa_kv == 0:
        BLOCK_M, BLOCK_N, num_warps, num_stages = _get_config(q.dtype, head_dim, "bwd_drfa_kv")
        grid = lambda META: (
            triton.cdiv(nchunks, META["BLOCK_N"]),
            batch * nheads,
        )
        _bwd_eva_agg_kernel_drfa_kv[grid](
            q,
            rfa_k,
            rfa_v,
            chunk_mask,
            do,
            lse,
            do_t_o,
            d_rfa_k,
            d_rfa_v,
            softmax_scale,
            q.stride(0), q.stride(1), q.stride(2),
            rfa_k_strides[0], rfa_k_strides[1], rfa_k_strides[2],
            rfa_v_strides[0], rfa_v_strides[1], rfa_v_strides[2],
            chunk_mask_strides[0], chunk_mask_strides[1],
            do.stride(0), do.stride(1), do.stride(2),
            lse.stride(0), lse.stride(1),
            do_t_o.stride(0), do_t_o.stride(1),
            d_rfa_k_strides[0], d_rfa_k_strides[1], d_rfa_k_strides[2],
            d_rfa_v_strides[0], d_rfa_v_strides[1], d_rfa_v_strides[2],
            nheads,
            seqlen_q,
            nchunks,
            head_dim,
            chunks_per_window,
            window_size,
            mask_type,
            BLOCK_HEADDIM,
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
            num_warps=num_warps,
            num_stages=num_stages,
        )
1694
+
1695
+
1696
class EvaAggFunc(torch.autograd.Function):
    """Autograd bridge for the fused EVA-aggregation Triton kernels.

    ``forward`` dispatches to ``triton_eva_agg_fwd`` and stashes everything
    the backward pass needs on the context; ``backward`` allocates gradient
    buffers and fills them via ``triton_eva_agg_bwd``.
    """

    @staticmethod
    def forward(ctx, q, k, v, rfa_k, rfa_v, window_mask, chunk_mask, softmax_scale=None, window_size=None, chunks_per_window=None):
        # rfa_k / rfa_v must be supplied (or omitted) as a pair.
        if rfa_k is None and rfa_v is None:
            empty_rfa_kv = 1
        else:
            assert rfa_k is not None and rfa_v is not None, "Both rfa_k and rfa_v must either be None or have values at the same time."
            empty_rfa_kv = 0

        mask_type = 1 if window_mask is not None else 0

        out, lse = triton_eva_agg_fwd(
            q, k, v, rfa_k, rfa_v, window_mask, chunk_mask, softmax_scale, window_size, chunks_per_window
        )

        ctx.save_for_backward(q, k, v, out, lse, rfa_k, rfa_v, window_mask, chunk_mask)
        ctx.softmax_scale = softmax_scale
        ctx.window_size = window_size
        ctx.chunks_per_window = chunks_per_window
        ctx.empty_rfa_kv = empty_rfa_kv
        ctx.mask_type = mask_type
        return out

    @staticmethod
    def backward(ctx, do):
        q, k, v, out, lse, rfa_k, rfa_v, window_mask, chunk_mask = ctx.saved_tensors

        grad_q = torch.empty_like(q)
        grad_k = torch.empty_like(k)
        grad_v = torch.empty_like(v)
        if ctx.empty_rfa_kv == 0:
            grad_rfa_k = torch.empty_like(rfa_k)
            grad_rfa_v = torch.empty_like(rfa_v)
        else:
            grad_rfa_k = None
            grad_rfa_v = None

        triton_eva_agg_bwd(
            do,
            q,
            k,
            v,
            rfa_k,
            rfa_v,
            window_mask,
            chunk_mask,
            out,
            lse,
            grad_q,
            grad_k,
            grad_v,
            grad_rfa_k,
            grad_rfa_v,
            softmax_scale=ctx.softmax_scale,
            window_size=ctx.window_size,
            chunks_per_window=ctx.chunks_per_window,
            empty_rfa_kv=ctx.empty_rfa_kv,
            mask_type=ctx.mask_type,
        )
        # One gradient slot per forward argument; the non-tensor arguments
        # (masks and hyper-parameters) get None.
        return grad_q, grad_k, grad_v, grad_rfa_k, grad_rfa_v, None, None, None, None, None
1755
+
1756
+
1757
def eva_agg_func_triton(
    q, k, v, rfa_k, rfa_v,
    window_mask, chunk_mask,
    softmax_scale=None, window_size=None, chunks_per_window=None,
):
    """Functional entry point for EVA aggregation attention.

    Thin convenience wrapper that forwards every argument, in order, to
    ``EvaAggFunc.apply`` so the fused Triton kernels participate in
    autograd.
    """
    return EvaAggFunc.apply(
        q, k, v, rfa_k, rfa_v,
        window_mask, chunk_mask,
        softmax_scale, window_size, chunks_per_window,
    )
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/eva_cache.py ADDED
@@ -0,0 +1,761 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Optional, Tuple, List, Any, Union
2
+ import torch
3
+ from transformers.cache_utils import Cache
4
+
5
class EvaCache(Cache):
    """
    A cache that grows dynamically as more tokens are generated. This is the default for generative models.

    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
    `[batch_size, num_heads, seq_len, head_dim]`.
    """

    def __init__(self) -> None:
        # Per-layer key/value states of the current (possibly partial) local window.
        self.w_k: List[torch.Tensor] = []
        self.w_v: List[torch.Tensor] = []

        # Per-layer query/key/value states accumulated for the current (partial) chunk.
        self.rf_q: List[torch.Tensor] = []
        self.rf_k: List[torch.Tensor] = []
        self.rf_v: List[torch.Tensor] = []

        # Per-layer random-feature-attention statistics for completed chunks.
        self.softmax_phi_k_v: List[torch.Tensor] = []
        self.log_sum_phi_k: List[torch.Tensor] = []
        self.rf_k_bar: List[torch.Tensor] = []
        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen

        # attention masks temporary buffer
        self.rf_mask: List[Optional[torch.Tensor]] = []
        self.s_mask: List[torch.Tensor] = []
        self.chunk_mask: List[torch.Tensor] = []

    def __len__(self):
        """
        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
        to the number of layers in the model.
        """
        return len(self.w_k)

    def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
        """Given the sequence length of the new inputs, returns the usable length of the cache."""
        # Cache without size limit -> all cache is usable
        # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
        # length, we will need to evict part of the cache (and thus not all cache is usable)
        max_length = self.get_max_length()
        previous_seq_length = self.get_seq_length(layer_idx)
        if max_length is not None and previous_seq_length + new_seq_length > max_length:
            return max_length - new_seq_length
        return previous_seq_length

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        # NOTE(review): rf_q/rf_k/rf_v/rf_mask entries may be None after a chunk
        # dump (see update_chunks / update_mask); index_select on a None entry
        # would raise here — presumably beam search is only used in states where
        # all entries are tensors. TODO confirm.
        for layer_idx in range(len(self.w_k)):
            device = self.w_k[layer_idx].device
            self.w_k[layer_idx] = self.w_k[layer_idx].index_select(0, beam_idx.to(device))

            device = self.w_v[layer_idx].device
            self.w_v[layer_idx] = self.w_v[layer_idx].index_select(0, beam_idx.to(device))

            device = self.rf_q[layer_idx].device
            self.rf_q[layer_idx] = self.rf_q[layer_idx].index_select(0, beam_idx.to(device))

            device = self.rf_k[layer_idx].device
            self.rf_k[layer_idx] = self.rf_k[layer_idx].index_select(0, beam_idx.to(device))

            device = self.rf_v[layer_idx].device
            self.rf_v[layer_idx] = self.rf_v[layer_idx].index_select(0, beam_idx.to(device))

            device = self.softmax_phi_k_v[layer_idx].device
            self.softmax_phi_k_v[layer_idx] = self.softmax_phi_k_v[layer_idx].index_select(0, beam_idx.to(device))

            device = self.log_sum_phi_k[layer_idx].device
            self.log_sum_phi_k[layer_idx] = self.log_sum_phi_k[layer_idx].index_select(0, beam_idx.to(device))

            device = self.rf_k_bar[layer_idx].device
            self.rf_k_bar[layer_idx] = self.rf_k_bar[layer_idx].index_select(0, beam_idx.to(device))

            device = self.rf_mask[layer_idx].device
            self.rf_mask[layer_idx] = self.rf_mask[layer_idx].index_select(0, beam_idx.to(device))

            device = self.s_mask[layer_idx].device
            self.s_mask[layer_idx] = self.s_mask[layer_idx].index_select(0, beam_idx.to(device))

            device = self.chunk_mask[layer_idx].device
            self.chunk_mask[layer_idx] = self.chunk_mask[layer_idx].index_select(0, beam_idx.to(device))

    @property
    def seen_tokens(self):
        # Backwards-compatible accessor for the private token counter.
        if hasattr(self, "_seen_tokens"):
            return self._seen_tokens
        else:
            return None

    def update_past_len(
        self,
        cur_q_len: int,
        layer_idx: int
    ):
        """Advance the global seen-token counter (only once per step, at layer 0)
        and return the updated total."""
        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += cur_q_len
        return self._seen_tokens

    def update_mask(
        self,
        prev_s_mask,
        cur_s_mask,
        chunk_mask,
        rf_mask,
        layer_idx,
        window_size,
        chunk_size,
    ):
        """Update the cached attention masks for singletons (`s_mask`),
        intra-chunk tokens (`rf_mask`) and inter-chunk attention (`chunk_mask`).

        Returns `(prev_s_mask, cur_s_mask, prev_chunk_mask, cur_chunk_mask, dump_rf_mask)`,
        where `dump_rf_mask` is a full chunk's worth of intra-chunk mask ready
        to be consumed, or None.

        NOTE(review): `q_len` is only assigned in the prefill branch below; the
        code presumes `s_mask`/`rf_mask`/`chunk_mask` lists are populated in
        lockstep per layer so that the decode branches never read `q_len` —
        verify against callers.
        """
        ############################################
        # compute masks for singletons
        ############################################
        q_len = None
        if len(self.s_mask) <= layer_idx:
            q_len = chunk_mask.shape[-2]
            # prefill stage
            # q is of shape [b, h, n, d]
            if q_len < window_size:
                assert prev_s_mask is None

            # w_v = # [b, h, 1, j, d]
            # store the past window-wise key-value pairs
            self.s_mask.append(cur_s_mask[..., -1:, :] if cur_s_mask is not None else prev_s_mask[..., -1, -1:, :])
        else:
            # decoding stage
            prev_s_mask = None

            cached_s_mask = self.s_mask[layer_idx]
            assert cached_s_mask is not None
            if cached_s_mask.shape[-1] == window_size:
                # the cached window is full: start a fresh window with only the
                # incoming token's mask
                cur_s_mask = cur_s_mask
            else:
                cur_s_mask = torch.cat([cached_s_mask, cur_s_mask], dim=-1)

            # store the past window-wise key-value pairs
            self.s_mask[layer_idx] = cur_s_mask

        ############################################
        # compute masks for intra-chunks
        ############################################
        dump_rf_mask = None
        if len(self.rf_mask) <= layer_idx:
            # initialize chunk stats
            # prefill stage
            if q_len < chunk_size:
                cur_rf_mask = rf_mask
            else:
                if q_len % chunk_size == 0:
                    # all tokens form complete chunks: dump everything
                    dump_rf_mask = rf_mask
                    cur_rf_mask = None
                else:
                    # split into complete chunks (dumped) and the trailing
                    # partial chunk (kept in the cache)
                    remainder_tokens = q_len % chunk_size
                    if rf_mask is not None:
                        dump_rf_mask, cur_rf_mask = torch.split(rf_mask, [q_len - remainder_tokens, remainder_tokens], dim=-2)
                    else:
                        dump_rf_mask = None
                        cur_rf_mask = None
            self.rf_mask.append(cur_rf_mask)
        else:
            past_rf_mask = self.rf_mask[layer_idx]
            if past_rf_mask is not None:
                # when decoding tokens, we always assume the
                # incoming token mask is 0 (not masked)
                cur_rf_mask = torch.cat([past_rf_mask, rf_mask], dim=-2)
            else:
                # we do not need to use rf_mask anymore after we receive generated tokens
                cur_rf_mask = None
            # We need to store rf_k_bar and RFA-results that
            # compute the per-chunk RFA.

            # Dump the chunk if the len of current chunk reaches <chunk_size>.
            if cur_rf_mask is not None and cur_rf_mask.shape[-2] == chunk_size:
                dump_rf_mask = cur_rf_mask
                cur_rf_mask = None

            self.rf_mask[layer_idx] = cur_rf_mask

        ############################################
        # compute masks for inter chunks
        ############################################
        if len(self.chunk_mask) <= layer_idx:
            # prefill stage
            # q is of shape [b, h, n, d]
            if q_len < window_size:
                cur_chunk_mask = chunk_mask
                prev_chunk_mask = None
            else:
                if q_len % window_size == 0:
                    cur_chunk_mask = None
                    prev_chunk_mask = chunk_mask
                else:
                    remainder_tokens = q_len % window_size
                    # [b, h, n-r, d] [b, h, r, d]
                    prev_chunk_mask, cur_chunk_mask = torch.split(chunk_mask, [q_len - remainder_tokens, remainder_tokens], dim=-2)
                bsz, num_heads, _, head_dim = prev_chunk_mask.shape
                # fold the sequence axis into [num_windows, window_size]
                prev_chunk_mask = prev_chunk_mask.reshape(bsz, num_heads, -1, window_size, head_dim)

                assert prev_s_mask is not None
                if prev_s_mask.shape[-3] == 1 and prev_chunk_mask.shape[-3] > 1:
                    # need to expand
                    prev_s_mask = prev_s_mask.expand(-1, -1, prev_chunk_mask.shape[-3], -1, -1)
            # w_v = # [b, h, 1, j, d]
            # store the past window-wise key-value pairs
            self.chunk_mask.append(cur_chunk_mask[..., -1:, :] if cur_chunk_mask is not None else prev_chunk_mask[..., -1, -1:, :])
        else:
            # decoding stage
            prev_chunk_mask = None
            cur_chunk_mask = self.chunk_mask[layer_idx]

            # if the current sequence length reaches <chunk_size>,
            # we append a new 1 to the end of chunk_mask
            seen_seq_len = self.get_seq_length(layer_idx)
            if seen_seq_len > 0 and seen_seq_len % chunk_size == 0:
                past_chunk_mask = self.chunk_mask[layer_idx]
                if past_chunk_mask is not None:
                    # when decoding tokens, we always assume the
                    # incoming token mask is 0 (not masked)
                    cur_chunk_mask = torch.cat([past_chunk_mask, chunk_mask], dim=-1)
                else:
                    cur_chunk_mask = chunk_mask
                self.chunk_mask[layer_idx] = cur_chunk_mask

            # if the len of current sequence reaches <window_size> + 1,
            # we turn on the mask for most recent chunks
            if seen_seq_len > 0 and seen_seq_len % window_size == 1:
                cur_chunk_mask = self.chunk_mask[layer_idx]
                # we do not need to use rf_mask anymore after we receive generated tokens
                num_chunks_per_window = window_size // chunk_size
                # NOTE(review): in-place edit of the cached tensor; False marks
                # these chunk slots as attendable in this mask convention —
                # confirm against the attention implementation.
                cur_chunk_mask[..., -num_chunks_per_window:] = False
                self.chunk_mask[layer_idx] = cur_chunk_mask

        return (prev_s_mask, cur_s_mask, prev_chunk_mask, cur_chunk_mask, dump_rf_mask)

    def update_singletons(
        self,
        q,
        k,
        v,
        layer_idx,
        window_size,
    ):
        """Split incoming q/k/v into completed windows ("past") and the current
        partial window, caching the latter per layer.

        Returns `((past_w_q, past_w_k, past_w_v), (w_q, w_k, w_v))`, where the
        "past" tensors are reshaped to `[b, h, num_windows, window_size, d]`
        during prefill and are None while decoding.
        """
        if len(self.w_k) <= layer_idx:
            # prefill stage
            # q is of shape [b, h, n, d]
            q_len = q.shape[-2]
            if q_len < window_size:
                w_q = q
                w_k = k
                w_v = v
                past_w_q = past_w_k = past_w_v = None
            else:
                if q_len % window_size == 0:
                    w_q = None
                    w_k = None
                    w_v = None
                    past_w_q = q
                    past_w_k = k
                    past_w_v = v
                else:
                    remainder_tokens = q_len % window_size
                    # [b, h, n-r, d] [b, h, r, d]
                    past_w_q, w_q = torch.split(q, [q_len - remainder_tokens, remainder_tokens], dim=-2)
                    past_w_k, w_k = torch.split(k, [q_len - remainder_tokens, remainder_tokens], dim=-2)
                    past_w_v, w_v = torch.split(v, [q_len - remainder_tokens, remainder_tokens], dim=-2)
                bsz, num_heads, _, head_dim = past_w_q.shape
                past_w_q = past_w_q.reshape(bsz, num_heads, -1, window_size, head_dim)
                past_w_k = past_w_k.reshape(bsz, num_heads, -1, window_size, head_dim)
                past_w_v = past_w_v.reshape(bsz, num_heads, -1, window_size, head_dim)
            # w_q = q[..., None, -window_size:, :] # [b, h, 1, j, d]
            # w_k = # [b, h, 1, j, d]
            # w_v = # [b, h, 1, j, d]
            # store the past window-wise key-value pairs
            # if w_k is None, it means we happen to pass in a sqeuence that is divisible by window_size
            # we leave the cache with window_size-sized kv pairs to be cleared next iteration
            self.w_k.append(w_k if w_k is not None else past_w_k[..., -1, :, :])
            self.w_v.append(w_v if w_v is not None else past_w_v[..., -1, :, :])
        else:
            # decoding stage
            past_w_q = past_w_k = past_w_v = None
            # this is implemented as either a sliding window or fixed window
            w_q = q  # [b, h, 1, d]
            w_k = k  # [b, h, 1, d]
            w_v = v  # [b, h, 1, d]

            cached_w_k = self.w_k[layer_idx]
            assert cached_w_k is not None  # [b, h, j, d]
            if cached_w_k.shape[-2] == window_size:
                # cached window is full: start a fresh window from this token
                w_k = w_k
            else:
                w_k = torch.cat([cached_w_k, w_k], dim=-2)

            cached_w_v = self.w_v[layer_idx]
            assert cached_w_v is not None
            if cached_w_v.shape[-2] == window_size:
                w_v = w_v
            else:
                w_v = torch.cat([cached_w_v, w_v], dim=-2)

            # store the past window-wise key-value pairs
            self.w_k[layer_idx] = w_k
            self.w_v[layer_idx] = w_v
        return (past_w_q, past_w_k, past_w_v), (w_q, w_k, w_v)

    def update_chunks(
        self,
        q,
        k,
        v,
        layer_idx,
        chunk_size
    ):
        """Accumulate q/k/v into the current chunk buffer; once a full chunk of
        `chunk_size` tokens is available, return it as `(dump_q, dump_k, dump_v)`
        (each None otherwise) and reset the buffer."""
        q_len = q.shape[-2]
        dump_q = None
        dump_k = None
        dump_v = None
        if len(self.rf_q) <= layer_idx:
            # initialize chunk stats
            # prefill stage
            if q_len < chunk_size:
                rf_q = q
                rf_k = k
                rf_v = v
            else:
                if q_len % chunk_size == 0:
                    rf_q = None
                    rf_k = None
                    rf_v = None
                    dump_q = q
                    dump_k = k
                    dump_v = v
                else:
                    remainder_tokens = q_len % chunk_size
                    # [b, h, n-r, d] [b, h, r, d]
                    dump_q, rf_q = torch.split(q, [q_len - remainder_tokens, remainder_tokens], dim=-2)
                    dump_k, rf_k = torch.split(k, [q_len - remainder_tokens, remainder_tokens], dim=-2)
                    dump_v, rf_v = torch.split(v, [q_len - remainder_tokens, remainder_tokens], dim=-2)
            self.rf_q.append(rf_q)
            self.rf_k.append(rf_k)
            self.rf_v.append(rf_v)
        else:
            # decode tokens
            # add query, key & value to the current chunk.
            past_rf_q = self.rf_q[layer_idx]
            if past_rf_q is not None:
                rf_q = torch.cat([past_rf_q, q], dim=-2)
            else:
                rf_q = q

            past_rf_k = self.rf_k[layer_idx]
            if past_rf_k is not None:
                rf_k = torch.cat([past_rf_k, k], dim=-2)
            else:
                rf_k = k

            past_rf_v = self.rf_v[layer_idx]
            if past_rf_v is not None:
                rf_v = torch.cat([past_rf_v, v], dim=-2)
            else:
                rf_v = v

            # We need to store rf_k_bar and RFA-results that
            # compute the per-chunk RFA.

            # Dump the chunk if the len of current chunk reaches <chunk_size>.
            if rf_q.shape[-2] == chunk_size:
                dump_q = rf_q
                dump_k = rf_k
                dump_v = rf_v
                # clear the chunk
                rf_q = None
                rf_k = None
                rf_v = None

            self.rf_q[layer_idx] = rf_q
            self.rf_k[layer_idx] = rf_k
            self.rf_v[layer_idx] = rf_v

        return dump_q, dump_k, dump_v

    def update_chunk_rfas(
        self,
        softmax_phi_k_v,
        log_sum_phi_k,
        rf_k_bar,
        layer_idx,
        random_feature_dim
    ):
        """Append newly computed per-chunk RFA statistics to the per-layer cache
        and return the concatenated tensors.

        The concatenation axis for `softmax_phi_k_v` / `log_sum_phi_k` depends
        on `random_feature_dim` (-2 when it is 1, -3 otherwise); presumably this
        reflects whether the random-feature axis is kept — confirm against the
        attention module.
        """
        if len(self.softmax_phi_k_v) <= layer_idx:
            # prefill stage
            self.softmax_phi_k_v.append(softmax_phi_k_v)
            self.log_sum_phi_k.append(log_sum_phi_k)
            self.rf_k_bar.append(rf_k_bar)
        else:
            # token decoding
            past_softmax_phi_k_v = self.softmax_phi_k_v[layer_idx]
            past_log_sum_phi_k = self.log_sum_phi_k[layer_idx]
            past_rf_k_bar = self.rf_k_bar[layer_idx]

            if past_softmax_phi_k_v is not None:
                if random_feature_dim == 1:
                    dim = -2
                else:
                    dim = -3
                softmax_phi_k_v = torch.cat([past_softmax_phi_k_v, softmax_phi_k_v], dim=dim)

            if past_log_sum_phi_k is not None:
                if random_feature_dim == 1:
                    dim = -2
                else:
                    dim = -3
                log_sum_phi_k = torch.cat([past_log_sum_phi_k, log_sum_phi_k], dim=dim)

            if past_rf_k_bar is not None:
                rf_k_bar = torch.cat([past_rf_k_bar, rf_k_bar], dim=-2)

            self.softmax_phi_k_v[layer_idx] = softmax_phi_k_v
            self.log_sum_phi_k[layer_idx] = log_sum_phi_k
            self.rf_k_bar[layer_idx] = rf_k_bar

        return softmax_phi_k_v, log_sum_phi_k, rf_k_bar

    def get_chunk_rfas(self, layer_idx):
        """Return the cached `(softmax_phi_k_v, log_sum_phi_k, rf_k_bar)` for a
        layer, or a triple of None when the layer has no entries yet."""
        if len(self.softmax_phi_k_v) <= layer_idx:
            return (
                None,
                None,
                None
            )
        else:
            return (
                self.softmax_phi_k_v[layer_idx],
                self.log_sum_phi_k[layer_idx],
                self.rf_k_bar[layer_idx]
            )

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        if len(self.w_k) <= layer_idx:
            return 0
        return self._seen_tokens

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
        return None

    def update(
        self,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # The generic HF `Cache.update` API is intentionally unsupported;
        # callers must use the specialized update_* methods above.
        raise NotImplementedError("`update` is not used in Eva Cache.")
453
+
454
class EvaStaticCacheForTriton(Cache):
    """
    A variant of EvaCache for eva's triton kernels

    Unlike `EvaCache`, the per-layer window key/value buffers are preallocated
    at a fixed `[batch_size, num_key_value_heads, window_size, head_dim]` shape
    and filled in place, tracked by a write cursor (`past_window_pos`).
    """

    def __init__(
        self,
        batch_size,
        num_key_value_heads,
        window_size,
        head_dim,
        num_layers,
        dtype,
        device
    ) -> None:
        # Preallocated per-layer window K/V buffers, written in place.
        self.past_window_k: List[torch.Tensor] = []
        self.past_window_v: List[torch.Tensor] = []

        cache_shape = (batch_size, num_key_value_heads, window_size, head_dim)
        for idx in range(num_layers):
            new_window_k = torch.zeros(cache_shape, dtype=dtype, device=device)
            new_window_v = torch.zeros(cache_shape, dtype=dtype, device=device)
            self.past_window_k.append(new_window_k)
            self.past_window_v.append(new_window_v)

        # Per-layer write cursor into the window buffers (0..window_size).
        self.past_window_pos: List[int] = []

        # Per-layer random-feature-attention key/value summaries of completed chunks.
        self.rfa_k: List[torch.Tensor] = []
        self.rfa_v: List[torch.Tensor] = []
        # self.rfa_mask: List[torch.Tensor] = []

        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen

        # attention masks temporary buffer
        self.rf_mask: List[Optional[torch.Tensor]] = []
        self.s_mask: List[torch.Tensor] = []

    def __len__(self):
        """
        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
        to the number of layers in the model.
        """
        return len(self.past_window_pos)

    def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
        """Given the sequence length of the new inputs, returns the usable length of the cache."""
        # Cache without size limit -> all cache is usable
        # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
        # length, we will need to evict part of the cache (and thus not all cache is usable)
        max_length = self.get_max_length()
        previous_seq_length = self.get_seq_length(layer_idx)
        if max_length is not None and previous_seq_length + new_seq_length > max_length:
            return max_length - new_seq_length
        return previous_seq_length

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        # NOTE(review): rf_mask/s_mask entries may be None (see update_mask);
        # index_select on a None entry would raise — presumably beam search only
        # runs in states where these are tensors. TODO confirm.
        for layer_idx in range(len(self.past_window_k)):
            device = self.past_window_k[layer_idx].device
            self.past_window_k[layer_idx] = self.past_window_k[layer_idx].index_select(0, beam_idx.to(device))

            device = self.past_window_v[layer_idx].device
            self.past_window_v[layer_idx] = self.past_window_v[layer_idx].index_select(0, beam_idx.to(device))

            device = self.rfa_k[layer_idx].device
            self.rfa_k[layer_idx] = self.rfa_k[layer_idx].index_select(0, beam_idx.to(device))

            device = self.rfa_v[layer_idx].device
            self.rfa_v[layer_idx] = self.rfa_v[layer_idx].index_select(0, beam_idx.to(device))

            # device = self.rfa_mask[layer_idx].device
            # self.rfa_mask[layer_idx] = self.rfa_mask[layer_idx].index_select(0, beam_idx.to(device))

            device = self.rf_mask[layer_idx].device
            self.rf_mask[layer_idx] = self.rf_mask[layer_idx].index_select(0, beam_idx.to(device))

            device = self.s_mask[layer_idx].device
            self.s_mask[layer_idx] = self.s_mask[layer_idx].index_select(0, beam_idx.to(device))

    @property
    def seen_tokens(self):
        # Backwards-compatible accessor for the private token counter.
        if hasattr(self, "_seen_tokens"):
            return self._seen_tokens
        else:
            return None

    def update_past_len(
        self,
        cur_q_len: int,
        layer_idx: int
    ):
        """Advance the global seen-token counter (only once per step, at layer 0)
        and return the updated total."""
        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += cur_q_len
        return self._seen_tokens

    def update_mask(
        self,
        s_mask,
        rf_mask,
        layer_idx,
        window_size,
    ):
        """Update cached singleton (`s_mask`) and intra-chunk (`rf_mask`) masks.

        Returns `(dump_s_mask, dump_rf_mask)`: the singleton mask to use for the
        current step, and a full window's worth of intra-chunk mask ready to be
        consumed (or None).
        """
        ############################################
        # compute masks for singletons
        ############################################
        if len(self.s_mask) <= layer_idx:
            # prefill stage
            # q is of shape [b, h, n, d]
            # s_v = # [b, h, 1, j, d]
            # store the past window-wise key-value pairs
            if s_mask is None:
                cur_s_mask = None
            else:
                q_len = s_mask.shape[-2]
                # s_mask is of shape [b, h, n, w]
                # let r = q_len % window_size
                # if r == 0, the mask to be appended is of shape [b, h, 1, w]
                # otherwise, r < w, the mask to be appended is of shape [b, h, 1, r]
                remainder_tokens = q_len % window_size
                if remainder_tokens == 0:
                    cur_s_mask = None
                else:
                    cur_s_mask = s_mask[..., -1:, :remainder_tokens]
            self.s_mask.append(cur_s_mask)
            # we use the passed s_mask for subsequent computations
            dump_s_mask = s_mask
        else:
            # decoding stage
            past_s_mask = self.s_mask[layer_idx]
            if past_s_mask is None:
                assert s_mask is None
                cur_s_mask = None
            else:
                assert s_mask is not None
                cur_s_mask = torch.cat([past_s_mask, s_mask], dim=-1)

            dump_s_mask = cur_s_mask
            if cur_s_mask is not None and cur_s_mask.shape[-1] == window_size:
                # window complete: drop the cached mask so the next step starts fresh
                cur_s_mask = None
            # store the past window-wise key-value pairs
            self.s_mask[layer_idx] = cur_s_mask

        ############################################
        # compute masks for intra-chunks
        ############################################
        dump_rf_mask = None
        if len(self.rf_mask) <= layer_idx:
            # initialize chunk stats
            # prefill stage
            if rf_mask is None:
                cur_rf_mask = None
            else:
                q_len = rf_mask.shape[-2]
                if q_len < window_size:
                    dump_rf_mask = None
                    cur_rf_mask = rf_mask
                else:
                    if q_len % window_size == 0:
                        dump_rf_mask = rf_mask
                        cur_rf_mask = None
                    else:
                        remainder_tokens = q_len % window_size
                        dump_rf_mask, cur_rf_mask = torch.split(rf_mask, [q_len - remainder_tokens, remainder_tokens], dim=-2)
            self.rf_mask.append(cur_rf_mask)
        else:
            past_rf_mask = self.rf_mask[layer_idx]
            if past_rf_mask is not None:
                # when decoding tokens, we always assume the
                # incoming token mask is 0 (not masked)
                cur_rf_mask = torch.cat([past_rf_mask, rf_mask], dim=-2)
            else:
                cur_rf_mask = None

            if cur_rf_mask is not None and cur_rf_mask.shape[-2] == window_size:
                dump_rf_mask = cur_rf_mask
                cur_rf_mask = None

            self.rf_mask[layer_idx] = cur_rf_mask

        return dump_s_mask, dump_rf_mask

    def update_singletons_and_chunks(
        self,
        k,
        v,
        layer_idx,
        window_size,
    ):
        """Write incoming k/v into the fixed-size window buffer and split off
        completed windows.

        Returns `(s_k, s_v, dump_k, dump_v)`: the singleton K/V visible to the
        current step, and the K/V of completed windows to be summarized into
        chunk RFAs (or None when no window completed).
        """
        if len(self.past_window_pos) <= layer_idx:
            # prefill stage
            s_k = k
            s_v = v
            input_len = k.shape[-2]
            window_pos = 0
            if input_len <= window_size:
                new_window_pos = window_pos + input_len

                cached_window_k = k
                cached_window_v = v
                dump_k = None
                dump_v = None
            else:
                remainder_tokens = input_len % window_size
                if remainder_tokens == 0:
                    # keep the last full window in the buffer; it is rolled over
                    # (cursor reset) on the next decode step
                    remainder_tokens = window_size
                new_window_pos = window_pos + remainder_tokens

                # [b, h, n-r, d] [b, h, r, d]
                cached_window_k = k[..., -remainder_tokens:, :]
                cached_window_v = v[..., -remainder_tokens:, :]
                dump_k = k[..., :-remainder_tokens, :]
                dump_v = v[..., :-remainder_tokens, :]
            # store the past window-wise key-value pairs
            self.past_window_k[layer_idx][:, :, window_pos : new_window_pos, :] = cached_window_k
            self.past_window_v[layer_idx][:, :, window_pos : new_window_pos, :] = cached_window_v
            self.past_window_pos.append(new_window_pos)
        else:
            # decoding stage
            # if the previous cache has full tokens,
            # roll back to the first elements
            if self.past_window_pos[layer_idx] == window_size:
                self.past_window_pos[layer_idx] = 0
                # clone: the buffer is overwritten below, but the dumped copy
                # must keep the old window contents
                dump_k = self.past_window_k[layer_idx].clone()
                dump_v = self.past_window_v[layer_idx].clone()
            else:
                dump_k = None
                dump_v = None

            input_len = k.shape[-2]
            window_pos = self.past_window_pos[layer_idx]
            new_window_pos = window_pos + input_len

            self.past_window_k[layer_idx][:, :, window_pos : new_window_pos, :] = k
            self.past_window_v[layer_idx][:, :, window_pos : new_window_pos, :] = v

            s_k = self.past_window_k[layer_idx][:, :, : new_window_pos, :]
            s_v = self.past_window_v[layer_idx][:, :, : new_window_pos, :]

            self.past_window_pos[layer_idx] = new_window_pos

        return s_k, s_v, dump_k, dump_v

    def update_chunk_rfas(
        self,
        rfa_k,
        rfa_v,
        layer_idx,
    ):
        """Append newly computed chunk-level RFA key/value summaries to the
        per-layer cache and return the concatenated tensors."""
        if len(self.rfa_k) <= layer_idx:
            # prefill stage
            self.rfa_k.append(rfa_k)
            self.rfa_v.append(rfa_v)
        else:
            # token decoding
            past_rfa_k = self.rfa_k[layer_idx]
            past_rfa_v = self.rfa_v[layer_idx]

            if past_rfa_k is not None:
                rfa_k = torch.cat([past_rfa_k, rfa_k], dim=-2)

            if past_rfa_v is not None:
                rfa_v = torch.cat([past_rfa_v, rfa_v], dim=-2)

            self.rfa_k[layer_idx] = rfa_k
            self.rfa_v[layer_idx] = rfa_v

        return rfa_k, rfa_v

    def get_past_window_pos(self, layer_idx):
        """Return the window-buffer write cursor for a layer, or None if the
        layer has no entry yet."""
        if len(self.past_window_pos) <= layer_idx:
            return None
        else:
            return self.past_window_pos[layer_idx]

    def get_past_window_kv(self, layer_idx):
        """Return the valid (written) prefix of the window K/V buffers for a
        layer, or `(None, None)` if the layer has no entry yet."""
        if len(self.past_window_pos) <= layer_idx:
            return None, None
        else:
            return (
                self.past_window_k[layer_idx][:, :, : self.past_window_pos[layer_idx], :],
                self.past_window_v[layer_idx][:, :, : self.past_window_pos[layer_idx], :]
            )

    def get_chunk_rfas(self, layer_idx):
        """Return the cached `(rfa_k, rfa_v)` for a layer, or `(None, None)` if
        the layer has no entry yet."""
        if len(self.rfa_k) <= layer_idx:
            return None, None
        else:
            return self.rfa_k[layer_idx], self.rfa_v[layer_idx]

    def get_seq_length(self, layer_idx = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # layer_idx must be provided since otherwise
        # any layer > 0 can only get the updated _seen_tokens
        if len(self.past_window_pos) <= layer_idx:
            return 0
        return self._seen_tokens

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
        return None

    def update(
        self,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # The generic HF `Cache.update` API is intentionally unsupported;
        # callers must use the specialized update_* methods above.
        raise NotImplementedError("`update` is not used in Eva Cache.")
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/eva_prep_kv_kernel.py ADDED
@@ -0,0 +1,1017 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import math
3
+ import torch
4
+ import triton
5
+ import triton.language as tl
6
+
7
# Forward kernel: for every CHUNK_SIZE-long chunk of the key/value sequence,
# compute one aggregated key (Out_RFA_K) and one aggregated value (Out_RFA_V)
# as softmax-weighted averages over the chunk. Each program handles BLOCK_N
# positions = CHUNKS_PER_BLOCK chunks for one (batch, head) pair.
@triton.heuristics(
    {
        # EVEN_N / EVEN_HEADDIM let fully in-bounds programs skip load/store masks.
        "EVEN_N": lambda args: args["seqlen"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def _fwd_eva_prep_kv_kernel(
    K, # [b, h, n, d]
    V, # [b, h, n, d]
    PARAM_MU, # [1, h, 1, 1, d]
    PARAM_PHI, # [1, h, 1, 1, d]
    Mask, # [b, h, n, 1]
    Out_RFA_K, # [b, h, c, d]
    Out_RFA_V, # [b, h, c, d]
    softmax_scale,
    stride_kb, stride_kh, stride_kn,
    stride_vb, stride_vh, stride_vn,
    stride_mu_h,
    stride_phi_h,
    stride_mb, stride_mn,
    stride_ok_b, stride_ok_h, stride_ok_c,
    stride_ov_b, stride_ov_h, stride_ov_c,
    nheads,
    seqlen,
    nchunks,
    headdim,
    CHUNKS_PER_BLOCK: tl.constexpr,
    CHUNK_SIZE: tl.constexpr,
    MASK_TYPE: tl.constexpr,  # 1 -> a boolean padding mask is supplied
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    start_n = tl.program_id(0)
    offs_bh = tl.program_id(1)
    offs_h = offs_bh % nheads
    offs_b = offs_bh // nheads
    # initialize offsets
    # we load BLOCK_N keys and values each time, and
    # reshape it to [CHUNKS_PER_BLOCK, CHUNK_SIZE]
    offs_c = tl.arange(0, CHUNKS_PER_BLOCK)
    offs_m = tl.arange(0, CHUNK_SIZE)
    offs_d = tl.arange(0, BLOCK_HEADDIM)

    # K/V pointers laid out as [CHUNKS_PER_BLOCK, CHUNK_SIZE, BLOCK_HEADDIM].
    k_ptrs = (
        K +
        offs_b * stride_kb +
        offs_h * stride_kh +
        (
            (
                start_n * BLOCK_N +
                offs_c[:, None, None] * CHUNK_SIZE +
                offs_m[None, :, None]
            ) * stride_kn +
            offs_d[None, None, :]
        )
    )
    v_ptrs = (
        V +
        offs_b * stride_vb +
        offs_h * stride_vh +
        (
            (
                start_n * BLOCK_N +
                offs_c[:, None, None] * CHUNK_SIZE +
                offs_m[None, :, None]
            ) * stride_vn +
            offs_d[None, None, :]
        )
    )
    param_mu_ptrs = (
        PARAM_MU +
        offs_h * stride_mu_h +
        offs_d[None, None, :]
    )
    param_phi_ptrs = (
        PARAM_PHI +
        offs_h * stride_phi_h +
        offs_d[None, None, :]
    )
    # Softmaxes below run in base-2 exponent space: logits are scaled by
    # log2(e) so tl.exp2 can be used instead of tl.exp.
    log2e = 1.4426950408889634
    if MASK_TYPE == 1:
        m_ptrs = (
            Mask +
            offs_b * stride_mb +
            (
                (
                    start_n * BLOCK_N +
                    offs_c[:, None] * CHUNK_SIZE +
                    offs_m[None, :]
                ) * stride_mn
            )
        )
    # Load keys, masking out-of-range rows / head dims with zeros as needed.
    if EVEN_N:
        if EVEN_HEADDIM:
            k = tl.load(
                k_ptrs
            )
        else:
            k = tl.load(
                k_ptrs,
                mask=offs_d[None, None, :] < headdim,
                other=0.0
            )
    else:
        if EVEN_HEADDIM:
            k = tl.load(
                k_ptrs,
                mask=(
                    start_n * BLOCK_N +
                    offs_c[:, None, None] * CHUNK_SIZE +
                    offs_m[None, :, None]
                ) < seqlen,
                other=0.0
            )
        else:
            k = tl.load(
                k_ptrs,
                mask=(
                    (
                        start_n * BLOCK_N +
                        offs_c[:, None, None] * CHUNK_SIZE +
                        offs_m[None, :, None]
                    ) < seqlen
                ) & (offs_d[None, None, :] < headdim),
                other=0.0
            )

    # RFA-K weights: softmax over each chunk of <k, mu> logits.
    param_mu = tl.load(param_mu_ptrs).to(k.dtype)
    rfa_k_c_w = tl.zeros([CHUNKS_PER_BLOCK, CHUNK_SIZE], dtype=tl.float32)
    rfa_k_c_w += tl.sum(k * param_mu, axis=-1)
    rfa_k_c_w *= log2e
    if MASK_TYPE == 1:
        if EVEN_N:
            mask = tl.load(
                m_ptrs
            )
        else:
            # other=1 so out-of-range positions count as masked.
            mask = tl.load(
                m_ptrs,
                mask=(
                    start_n * BLOCK_N +
                    offs_c[:, None] * CHUNK_SIZE +
                    offs_m[None, :]
                ) < seqlen,
                other=1,
            )
        # True in Mask means "exclude this position".
        rfa_k_c_w = tl.where(mask, float("-inf"), rfa_k_c_w)

    # Numerically stable softmax per chunk; fully-masked rows would give
    # max == -inf and denom == 0, so both are patched to keep the math finite
    # (their output is garbage but is never meaningful downstream).
    m_rfa_k_c_w = tl.max(rfa_k_c_w, axis=-1)
    masked_out_rows_rfa_k = (m_rfa_k_c_w == float("-inf"))
    m_rfa_k_c_w_masked = tl.where(masked_out_rows_rfa_k, 0, m_rfa_k_c_w)
    rfa_k_c_w = tl.exp2(rfa_k_c_w - m_rfa_k_c_w_masked[:, None])
    denom_k = tl.sum(rfa_k_c_w, axis=-1)
    denom_k = tl.where(denom_k == 0.0, 1.0, denom_k)
    rfa_k_c_w = rfa_k_c_w / denom_k[:, None]
    # Weighted average of the chunk's keys -> one aggregated key per chunk.
    rfa_k_c = tl.sum(k * rfa_k_c_w[:, :, None].to(k.dtype), axis=-2)
    # TODO: understand why rematerialize offsets to save registers?
    offs_out_c = start_n * CHUNKS_PER_BLOCK + tl.arange(0, CHUNKS_PER_BLOCK)
    out_rfa_k_ptrs = (
        Out_RFA_K +
        offs_b * stride_ok_b +
        offs_h * stride_ok_h +
        (offs_out_c[:, None] * stride_ok_c + offs_d[None, :])
    )

    if EVEN_N:
        if EVEN_HEADDIM:
            tl.store(
                out_rfa_k_ptrs, rfa_k_c
            )
        else:
            tl.store(
                out_rfa_k_ptrs, rfa_k_c,
                mask=offs_d[None, :] < headdim
            )
    else:
        if EVEN_HEADDIM:
            # Out-of-range chunks are simply not stored.
            tl.store(
                out_rfa_k_ptrs, rfa_k_c,
                mask=offs_out_c[:, None] < nchunks
            )
        else:
            tl.store(
                out_rfa_k_ptrs, rfa_k_c,
                mask=(offs_out_c[:, None] < nchunks) & (offs_d[None, :] < headdim)
            )

    # RFA-V weights: softmax over scaled (<k, phi> - 0.5*||k||^2) logits.
    param_phi = tl.load(param_phi_ptrs).to(k.dtype)
    rfa_v_c_w = tl.zeros([CHUNKS_PER_BLOCK, CHUNK_SIZE], dtype=tl.float32)
    rfa_v_c_w += tl.sum(k * param_phi, axis=-1)
    rfa_v_c_w -= (0.5 * tl.sum(k * k, axis=-1))
    rfa_v_c_w *= log2e * softmax_scale
    if not EVEN_N: # Need to mask out otherwise the softmax is wrong
        rfa_v_c_w += tl.where(
            (
                start_n * BLOCK_N +
                offs_c[:, None] * CHUNK_SIZE +
                offs_m[None, :]
            ) < seqlen,
            0,
            float("-inf")
        )

    if MASK_TYPE == 1:
        rfa_v_c_w = tl.where(mask, float("-inf"), rfa_v_c_w)

    if EVEN_N:
        if EVEN_HEADDIM:
            v = tl.load(
                v_ptrs
            )
        else:
            v = tl.load(
                v_ptrs,
                mask=offs_d[None, None, :] < headdim,
                other=0.0
            )
    else:
        if EVEN_HEADDIM:
            v = tl.load(
                v_ptrs,
                mask=(
                    start_n * BLOCK_N +
                    offs_c[:, None, None] * CHUNK_SIZE +
                    offs_m[None, :, None]
                ) < seqlen,
                other=0.0
            )
        else:
            v = tl.load(
                v_ptrs,
                mask=(
                    (
                        start_n * BLOCK_N +
                        offs_c[:, None, None] * CHUNK_SIZE +
                        offs_m[None, :, None]
                    ) < seqlen
                ) & (offs_d[None, None, :] < headdim),
                other=0.0
            )

    # Same stable-softmax pattern as for the keys above.
    m_rfa_v_c_w = tl.max(rfa_v_c_w, axis=-1)
    masked_out_rows_rfa_v = (m_rfa_v_c_w == float("-inf"))
    m_rfa_v_c_w_masked = tl.where(masked_out_rows_rfa_v, 0, m_rfa_v_c_w)
    rfa_v_c_w = tl.exp2(rfa_v_c_w - m_rfa_v_c_w_masked[:, None])
    denom_v = tl.sum(rfa_v_c_w, axis=-1)
    denom_v = tl.where(denom_v == 0.0, 1.0, denom_v)
    rfa_v_c_w = rfa_v_c_w / denom_v[:, None]
    rfa_v_c = tl.sum(v * rfa_v_c_w[:, :, None].to(v.dtype), axis=-2)

    offs_out_c = start_n * CHUNKS_PER_BLOCK + tl.arange(0, CHUNKS_PER_BLOCK)
    out_rfa_v_ptrs = (
        Out_RFA_V +
        offs_b * stride_ov_b +
        offs_h * stride_ov_h +
        (offs_out_c[:, None] * stride_ov_c + offs_d[None, :])
    )
    if EVEN_N:
        if EVEN_HEADDIM:
            tl.store(
                out_rfa_v_ptrs, rfa_v_c
            )
        else:
            tl.store(
                out_rfa_v_ptrs, rfa_v_c,
                mask=offs_d[None, :] < headdim
            )
    else:
        if EVEN_HEADDIM:
            tl.store(
                out_rfa_v_ptrs, rfa_v_c,
                mask=offs_out_c[:, None] < nchunks
            )
        else:
            tl.store(
                out_rfa_v_ptrs, rfa_v_c,
                mask=(offs_out_c[:, None] < nchunks) & (offs_d[None, :] < headdim)
            )
293
# Backward kernel: given gradients of the aggregated keys/values
# (D_RFA_K / D_RFA_V), recompute the forward chunk softmaxes and produce
# gradients w.r.t. K, V and the per-head parameters mu / phi. Parameter
# gradients are written per launch-group into the *_PARTIAL buffers and
# reduced on the host.
@triton.heuristics(
    {
        "EVEN_N": lambda args: args["seqlen"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def _bwd_eva_prep_kv_kernel(
    RFA_K, # [b, h, c, d]
    RFA_V, # [b, h, c, d]
    K, # [b, h, n, d]
    V, # [b, h, n, d]
    PARAM_MU, # [1, h, 1, 1, d]
    PARAM_PHI, # [1, h, 1, 1, d]
    Mask, # [b, h, n, 1]
    D_RFA_K, # [b, h, c, d]
    D_RFA_V, # [b, h, c, d]
    D_K, # [b, h, n, d]
    D_V, # [b, h, n, d]
    D_PARAM_MU_PARTIAL, # [b, h, g, d]
    D_PARAM_PHI_PARTIAL, # [b, h, g, d]
    softmax_scale,
    stride_rfa_k_b, stride_rfa_k_h, stride_rfa_k_c,
    stride_rfa_v_b, stride_rfa_v_h, stride_rfa_v_c,
    stride_kb, stride_kh, stride_kn,
    stride_vb, stride_vh, stride_vn,
    stride_mu_h,
    stride_phi_h,
    stride_mb, stride_mn,
    stride_d_rfa_k_b, stride_d_rfa_k_h, stride_d_rfa_k_c,
    stride_d_rfa_v_b, stride_d_rfa_v_h, stride_d_rfa_v_c,
    stride_d_k_b, stride_d_k_h, stride_d_k_n,
    stride_d_v_b, stride_d_v_h, stride_d_v_n,
    stride_d_mu_b, stride_d_mu_h, stride_d_mu_g,
    stride_d_phi_b, stride_d_phi_h, stride_d_phi_g,
    nheads,
    seqlen,
    nchunks,
    headdim,
    CHUNKS_PER_BLOCK: tl.constexpr,
    CHUNK_SIZE: tl.constexpr,
    MASK_TYPE: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    start_n = tl.program_id(0)
    offs_bh = tl.program_id(1)
    offs_h = offs_bh % nheads
    offs_b = offs_bh // nheads
    # initialize offsets
    # we load BLOCK_N keys and values each time, and
    # reshape it to [CHUNKS_PER_BLOCK, CHUNK_SIZE]
    offs_c = tl.arange(0, CHUNKS_PER_BLOCK)
    offs_m = tl.arange(0, CHUNK_SIZE)
    offs_d = tl.arange(0, BLOCK_HEADDIM)

    # Chunk indices handled by this program in the [b, h, c, d] tensors.
    offs_rfa_c = start_n * CHUNKS_PER_BLOCK + offs_c

    k_ptrs = (
        K +
        offs_b * stride_kb +
        offs_h * stride_kh +
        (
            (
                start_n * BLOCK_N +
                offs_c[:, None, None] * CHUNK_SIZE +
                offs_m[None, :, None]
            ) * stride_kn +
            offs_d[None, None, :]
        )
    )
    rfa_k_ptrs = (
        RFA_K +
        offs_b * stride_rfa_k_b +
        offs_h * stride_rfa_k_h +
        (offs_rfa_c[:, None] * stride_rfa_k_c + offs_d[None, :])
    )
    rfa_v_ptrs = (
        RFA_V +
        offs_b * stride_rfa_v_b +
        offs_h * stride_rfa_v_h +
        (offs_rfa_c[:, None] * stride_rfa_v_c + offs_d[None, :])
    )

    d_rfa_k_ptrs = (
        D_RFA_K +
        offs_b * stride_d_rfa_k_b +
        offs_h * stride_d_rfa_k_h +
        (offs_rfa_c[:, None] * stride_d_rfa_k_c + offs_d[None, :])
    )
    d_rfa_v_ptrs = (
        D_RFA_V +
        offs_b * stride_d_rfa_v_b +
        offs_h * stride_d_rfa_v_h +
        (offs_rfa_c[:, None] * stride_d_rfa_v_c + offs_d[None, :])
    )

    param_mu_ptrs = (
        PARAM_MU +
        offs_h * stride_mu_h +
        offs_d[None, None, :]
    )
    param_phi_ptrs = (
        PARAM_PHI +
        offs_h * stride_phi_h +
        offs_d[None, None, :]
    )

    # Logits are kept in base-2 exponent space (see forward kernel).
    log2e = 1.4426950408889634
    if MASK_TYPE == 1:
        m_ptrs = (
            Mask +
            offs_b * stride_mb +
            (
                (
                    start_n * BLOCK_N +
                    offs_c[:, None] * CHUNK_SIZE +
                    offs_m[None, :]
                ) * stride_mn
            )
        )
    if EVEN_N:
        if EVEN_HEADDIM:
            k = tl.load(
                k_ptrs
            )
        else:
            k = tl.load(
                k_ptrs,
                mask=offs_d[None, None, :] < headdim,
                other=0.0
            )
    else:
        if EVEN_HEADDIM:
            k = tl.load(
                k_ptrs,
                mask=(
                    start_n * BLOCK_N +
                    offs_c[:, None, None] * CHUNK_SIZE +
                    offs_m[None, :, None]
                ) < seqlen,
                other=0.0
            )
        else:
            k = tl.load(
                k_ptrs,
                mask=(
                    (
                        start_n * BLOCK_N +
                        offs_c[:, None, None] * CHUNK_SIZE +
                        offs_m[None, :, None]
                    ) < seqlen
                ) & (offs_d[None, None, :] < headdim),
                other=0.0
            )

    if EVEN_N:
        if EVEN_HEADDIM:
            rfa_k = tl.load(
                rfa_k_ptrs
            )
        else:
            rfa_k = tl.load(
                rfa_k_ptrs,
                mask=offs_d[None, :] < headdim,
                other=0.0
            )
    else:
        if EVEN_HEADDIM:
            rfa_k = tl.load(
                rfa_k_ptrs,
                mask=offs_rfa_c[:, None] < nchunks,
                other=0.0
            )
        else:
            rfa_k = tl.load(
                rfa_k_ptrs,
                mask=(offs_rfa_c[:, None] < nchunks) & (offs_d[None, :] < headdim),
                other=0.0
            )

    if EVEN_N:
        if EVEN_HEADDIM:
            d_rfa_k = tl.load(
                d_rfa_k_ptrs
            )
        else:
            d_rfa_k = tl.load(
                d_rfa_k_ptrs,
                mask=offs_d[None, :] < headdim,
                other=0.0
            )
    else:
        if EVEN_HEADDIM:
            d_rfa_k = tl.load(
                d_rfa_k_ptrs,
                mask=offs_rfa_c[:, None] < nchunks,
                other=0.0
            )
        else:
            d_rfa_k = tl.load(
                d_rfa_k_ptrs,
                mask=(offs_rfa_c[:, None] < nchunks) & (offs_d[None, :] < headdim),
                other=0.0
            )

    # Recompute the forward mu softmax weights.
    param_mu = tl.load(param_mu_ptrs).to(k.dtype)
    mu_c_w = tl.zeros([CHUNKS_PER_BLOCK, CHUNK_SIZE], dtype=tl.float32)
    mu_c_w += tl.sum(k * param_mu, axis=-1)
    mu_c_w *= log2e

    if not EVEN_N: # Need to mask out otherwise the softmax is wrong
        mu_c_w += tl.where(
            (
                start_n * BLOCK_N +
                offs_c[:, None] * CHUNK_SIZE +
                offs_m[None, :]
            ) < seqlen,
            0,
            float("-inf")
        )

    if MASK_TYPE == 1:
        if EVEN_N:
            mask = tl.load(
                m_ptrs
            )
        else:
            # other=1 so out-of-range positions count as masked.
            mask = tl.load(
                m_ptrs,
                mask=(
                    start_n * BLOCK_N +
                    offs_c[:, None] * CHUNK_SIZE +
                    offs_m[None, :]
                ) < seqlen,
                other=1,
            )
        mu_c_w = tl.where(mask, float("-inf"), mu_c_w)

    # [c, w]
    # Stable softmax; fully-masked rows patched as in the forward kernel.
    m_mu_c_w = tl.max(mu_c_w, axis=-1)
    masked_out_rows_mu = (m_mu_c_w == float("-inf"))
    m_mu_c_w_masked = tl.where(masked_out_rows_mu, 0, m_mu_c_w)
    mu_c_w = tl.exp2(mu_c_w - m_mu_c_w_masked[:, None])
    denom_mu = tl.sum(mu_c_w, axis=-1)
    denom_mu = tl.where(denom_mu == 0.0, 1.0, denom_mu)
    mu_tilde_c_w = mu_c_w / denom_mu[:, None]
    mu_tilde_c_w = mu_tilde_c_w.to(k.dtype)
    # Softmax backward: d_logit = (d_weight - <d_out, out>) * weight.
    # [c, d] [c, w, d] -> [c, w]
    d_mu_tilde_c_w = tl.sum(d_rfa_k[:, None, :] * k, axis=-1)
    # [c, d] [c, d] -> [c]
    d_out_rfa_k_t_rfa_k = tl.sum(d_rfa_k * rfa_k, axis=-1)[:, None]
    d_mu_c_w = (d_mu_tilde_c_w - d_out_rfa_k_t_rfa_k) * mu_tilde_c_w

    # [c, w] [c, w, d] -> [d]
    d_param_mu = tl.sum(tl.sum(d_mu_c_w[:, :, None] * k, axis=0), axis=0)
    # [c, w] [c, d] + [c, w] [1, 1, d] -> [c, w, d]
    d_k = mu_tilde_c_w[:, :, None] * d_rfa_k[:, None, :] + d_mu_c_w[:, :, None] * param_mu

    # Per-group partial gradient for mu; host sums over batch and groups.
    d_param_mu_partial_ptrs = (
        D_PARAM_MU_PARTIAL +
        offs_b * stride_d_mu_b +
        offs_h * stride_d_mu_h +
        start_n * stride_d_mu_g +
        offs_d
    )
    if EVEN_HEADDIM:
        tl.store(
            d_param_mu_partial_ptrs, d_param_mu
        )
    else:
        tl.store(
            d_param_mu_partial_ptrs, d_param_mu,
            mask=offs_d < headdim
        )

    v_ptrs = (
        V +
        offs_b * stride_vb +
        offs_h * stride_vh +
        (
            (
                start_n * BLOCK_N +
                offs_c[:, None, None] * CHUNK_SIZE +
                offs_m[None, :, None]
            ) * stride_vn +
            offs_d[None, None, :]
        )
    )
    if EVEN_N:
        if EVEN_HEADDIM:
            v = tl.load(
                v_ptrs
            )
        else:
            v = tl.load(
                v_ptrs,
                mask=offs_d[None, None, :] < headdim,
                other=0.0
            )
    else:
        if EVEN_HEADDIM:
            v = tl.load(
                v_ptrs,
                mask=(
                    start_n * BLOCK_N +
                    offs_c[:, None, None] * CHUNK_SIZE +
                    offs_m[None, :, None]
                ) < seqlen,
                other=0.0
            )
        else:
            v = tl.load(
                v_ptrs,
                mask=(
                    (
                        start_n * BLOCK_N +
                        offs_c[:, None, None] * CHUNK_SIZE +
                        offs_m[None, :, None]
                    ) < seqlen
                ) & (offs_d[None, None, :] < headdim),
                other=0.0
            )

    if EVEN_N:
        if EVEN_HEADDIM:
            rfa_v = tl.load(
                rfa_v_ptrs
            )
        else:
            rfa_v = tl.load(
                rfa_v_ptrs,
                mask=offs_d[None, :] < headdim,
                other=0.0
            )
    else:
        if EVEN_HEADDIM:
            rfa_v = tl.load(
                rfa_v_ptrs,
                mask=offs_rfa_c[:, None] < nchunks,
                other=0.0
            )
        else:
            rfa_v = tl.load(
                rfa_v_ptrs,
                mask=(offs_rfa_c[:, None] < nchunks) & (offs_d[None, :] < headdim),
                other=0.0
            )

    if EVEN_N:
        if EVEN_HEADDIM:
            d_rfa_v = tl.load(
                d_rfa_v_ptrs
            )
        else:
            d_rfa_v = tl.load(
                d_rfa_v_ptrs,
                mask=offs_d[None, :] < headdim,
                other=0.0
            )
    else:
        if EVEN_HEADDIM:
            d_rfa_v = tl.load(
                d_rfa_v_ptrs,
                mask=offs_rfa_c[:, None] < nchunks,
                other=0.0
            )
        else:
            d_rfa_v = tl.load(
                d_rfa_v_ptrs,
                mask=(offs_rfa_c[:, None] < nchunks) & (offs_d[None, :] < headdim),
                other=0.0
            )

    # Recompute the forward phi softmax weights.
    param_phi = tl.load(param_phi_ptrs).to(k.dtype)
    phi_c_w = tl.zeros([CHUNKS_PER_BLOCK, CHUNK_SIZE], dtype=tl.float32)
    phi_c_w += tl.sum(k * param_phi, axis=-1)
    phi_c_w -= (0.5 * tl.sum(k * k, axis=-1))
    phi_c_w *= log2e * softmax_scale
    if not EVEN_N: # Need to mask out otherwise the softmax is wrong
        phi_c_w += tl.where(
            (
                start_n * BLOCK_N +
                offs_c[:, None] * CHUNK_SIZE +
                offs_m[None, :]
            ) < seqlen,
            0,
            float("-inf")
        )

    if MASK_TYPE == 1:
        phi_c_w = tl.where(mask, float("-inf"), phi_c_w)

    m_phi_c_w = tl.max(phi_c_w, axis=-1)
    masked_out_rows_phi = (m_phi_c_w == float("-inf"))
    m_phi_c_w_masked = tl.where(masked_out_rows_phi, 0, m_phi_c_w)
    phi_c_w = tl.exp2(phi_c_w - m_phi_c_w_masked[:, None])
    denom_phi = tl.sum(phi_c_w, axis=-1)
    denom_phi = tl.where(denom_phi == 0.0, 1.0, denom_phi)
    phi_tilde_c_w = phi_c_w / denom_phi[:, None]
    phi_tilde_c_w = phi_tilde_c_w.to(k.dtype)
    # Softmax backward through the value aggregation.
    d_phi_tilde_c_w = tl.sum(d_rfa_v[:, None, :] * v, axis=-1)
    d_out_rfa_v_t_rfa_v = tl.sum(d_rfa_v * rfa_v, axis=-1)[:, None]
    d_phi_c_w = (d_phi_tilde_c_w.to(tl.float32) - d_out_rfa_v_t_rfa_v.to(tl.float32)) * phi_tilde_c_w

    # softmax_scale enters the phi logits, so it appears in these gradients.
    d_param_phi = tl.sum(tl.sum(d_phi_c_w[:, :, None] * k * softmax_scale, axis=0), axis=0)
    d_v = phi_tilde_c_w[:, :, None] * d_rfa_v[:, None, :]
    # [c, w, d] + [c, w] * [1, 1, d] - [c, w, d]
    # (param_phi - k) comes from d/dk of (<k, phi> - 0.5*||k||^2).
    d_k = d_k + softmax_scale * d_phi_c_w[:, :, None] * (param_phi - k)

    d_k_ptrs = (
        D_K +
        offs_b * stride_d_k_b +
        offs_h * stride_d_k_h +
        (
            (
                start_n * BLOCK_N +
                offs_c[:, None, None] * CHUNK_SIZE +
                offs_m[None, :, None]
            ) * stride_d_k_n +
            offs_d[None, None, :]
        )
    )
    d_v_ptrs = (
        D_V +
        offs_b * stride_d_v_b +
        offs_h * stride_d_v_h +
        (
            (
                start_n * BLOCK_N +
                offs_c[:, None, None] * CHUNK_SIZE +
                offs_m[None, :, None]
            ) * stride_d_v_n +
            offs_d[None, None, :]
        )
    )
    if EVEN_N:
        if EVEN_HEADDIM:
            tl.store(
                d_k_ptrs, d_k
            )
            tl.store(
                d_v_ptrs, d_v
            )
        else:
            tl.store(
                d_k_ptrs, d_k,
                mask=offs_d[None, None, :] < headdim
            )
            tl.store(
                d_v_ptrs, d_v,
                mask=offs_d[None, None, :] < headdim
            )
    else:
        if EVEN_HEADDIM:
            tl.store(
                d_k_ptrs, d_k,
                mask=(
                    (
                        start_n * BLOCK_N +
                        offs_c[:, None, None] * CHUNK_SIZE +
                        offs_m[None, :, None]
                    ) < seqlen
                ),
            )
            tl.store(
                d_v_ptrs, d_v,
                mask=(
                    (
                        start_n * BLOCK_N +
                        offs_c[:, None, None] * CHUNK_SIZE +
                        offs_m[None, :, None]
                    ) < seqlen
                ),
            )
        else:
            tl.store(
                d_k_ptrs, d_k,
                mask=(
                    (
                        start_n * BLOCK_N +
                        offs_c[:, None, None] * CHUNK_SIZE +
                        offs_m[None, :, None]
                    ) < seqlen
                ) & (offs_d[None, None, :] < headdim),
            )
            tl.store(
                d_v_ptrs, d_v,
                mask=(
                    (
                        start_n * BLOCK_N +
                        offs_c[:, None, None] * CHUNK_SIZE +
                        offs_m[None, :, None]
                    ) < seqlen
                ) & (offs_d[None, None, :] < headdim),
            )
    # Per-group partial gradient for phi; host sums over batch and groups.
    d_param_phi_partial_ptrs = (
        D_PARAM_PHI_PARTIAL +
        offs_b * stride_d_phi_b +
        offs_h * stride_d_phi_h +
        start_n * stride_d_phi_g +
        offs_d
    )
    if EVEN_HEADDIM:
        tl.store(
            d_param_phi_partial_ptrs, d_param_phi
        )
    else:
        tl.store(
            d_param_phi_partial_ptrs, d_param_phi,
            mask=offs_d < headdim
        )
813
def triton_eva_prep_kv_fwd(k, v, param_mu, param_phi, mask, softmax_scale, chunksize):
    """Host-side launcher for the forward RFA K/V preparation kernel.

    Splits the length-``seqlen`` key/value sequences into ``seqlen // chunksize``
    chunks and computes, per chunk, one aggregated key and one aggregated value
    (softmax-weighted averages, see ``_fwd_eva_prep_kv_kernel``).

    Args:
        k, v: ``[batch, nheads, seqlen, head_dim]`` CUDA tensors (bf16 or fp32).
        param_mu, param_phi: ``[1, nheads, 1, 1, head_dim]`` per-head parameters.
        mask: optional ``[batch, 1, seqlen, 1]`` bool tensor; True entries are
            excluded from the chunk softmaxes.
        softmax_scale: scale applied to the phi logits; falsy values default to
            ``1/sqrt(head_dim)``.
        chunksize: chunk length; must evenly divide ``seqlen``.

    Returns:
        ``(out_rfa_k, out_rfa_v)``, each ``[batch, nheads, nchunks, head_dim]``.
    """
    # The kernel assumes the head dimension is contiguous.
    k, v, param_mu, param_phi = [
        x if x.stride(-1) == 1 else x.contiguous()
        for x in [k, v, param_mu, param_phi]
    ]

    # shape constraints
    batch, nheads, seqlen, head_dim = k.shape
    assert seqlen % chunksize == 0, "seqlen must be divisible by chunksize"
    nchunks = seqlen // chunksize
    assert k.shape == (batch, nheads, seqlen, head_dim)
    assert v.shape == (batch, nheads, seqlen, head_dim)
    assert param_mu.shape == (1, nheads, 1, 1, head_dim)
    assert param_phi.shape == (1, nheads, 1, 1, head_dim)
    assert head_dim <= 128, "We only test head dimensions up to 128"
    assert k.dtype == v.dtype == param_mu.dtype == param_phi.dtype, "All tensors must have the same type"
    assert k.dtype in [torch.bfloat16, torch.float], "Only support bf16 and fp32 for now"
    assert k.is_cuda and v.is_cuda
    softmax_scale = softmax_scale or 1.0 / math.sqrt(head_dim)

    mask_type = 0
    if mask is not None:
        mask_type = 1
        assert mask.dtype == torch.bool
        assert mask.is_cuda
        assert mask.dim() == 4
        assert mask.shape == (batch, 1, seqlen, 1)
        if mask.stride(-1) != 1:
            mask = mask.contiguous()
    mask_strides = (
        (mask.stride(0), mask.stride(2))
        if mask_type == 1 else
        (0, 0)
    )
    out_rfa_k = torch.empty((batch, nheads, nchunks, head_dim), dtype=k.dtype, device=k.device)
    out_rfa_v = torch.empty((batch, nheads, nchunks, head_dim), dtype=v.dtype, device=v.device)

    BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16)
    BLOCK = 128
    num_warps = 4 if head_dim <= 64 else 8

    # BUGFIX: the previous check, `(BLOCK > chunksize) & (BLOCK % chunksize) == 0`,
    # parsed as `((BLOCK > chunksize) & (BLOCK % chunksize)) == 0` because `&`
    # binds tighter than `==` in Python, so it only required the low bit of the
    # remainder to be zero (e.g. chunksize=48 passed even though 128 % 48 != 0).
    # Divisibility alone also implies chunksize <= BLOCK.
    assert BLOCK % chunksize == 0, "BLOCK must be divisible by chunksize"
    chunks_per_block = BLOCK // chunksize

    # One program per BLOCK_N sequence positions per (batch, head).
    grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_N"]), batch * nheads)
    _fwd_eva_prep_kv_kernel[grid](
        k,
        v,
        param_mu,
        param_phi,
        mask,
        out_rfa_k,
        out_rfa_v,
        softmax_scale,
        k.stride(0), k.stride(1), k.stride(2),
        v.stride(0), v.stride(1), v.stride(2),
        param_mu.stride(1),
        param_phi.stride(1),
        mask_strides[0], mask_strides[1],
        out_rfa_k.stride(0), out_rfa_k.stride(1), out_rfa_k.stride(2),
        out_rfa_v.stride(0), out_rfa_v.stride(1), out_rfa_v.stride(2),
        nheads,
        seqlen,
        nchunks,
        head_dim,
        chunks_per_block,
        chunksize,
        mask_type,
        BLOCK_HEADDIM,
        BLOCK_N=BLOCK,
        num_warps=num_warps,
        num_stages=1,
    )
    return out_rfa_k, out_rfa_v
888
def triton_eva_prep_kv_bwd(
    d_rfa_k, d_rfa_v,
    k, v, param_mu, param_phi,
    mask,
    rfa_k, rfa_v,
    d_k, d_v, d_param_mu, d_param_phi,
    softmax_scale,
    mask_type,
    chunksize
):
    """Host-side launcher for the backward RFA K/V preparation kernel.

    Fills ``d_k``, ``d_v``, ``d_param_mu`` and ``d_param_phi`` in place with the
    gradients of ``triton_eva_prep_kv_fwd`` given output gradients
    ``d_rfa_k`` / ``d_rfa_v``. Per-head parameter gradients are accumulated in
    fp32 partial buffers (one slot per launch group) and reduced on the host.

    Args:
        d_rfa_k, d_rfa_v: ``[batch, nheads, nchunks, head_dim]`` output grads.
        k, v, param_mu, param_phi, mask, softmax_scale, chunksize: the forward
            inputs (same shapes/semantics as in ``triton_eva_prep_kv_fwd``).
        rfa_k, rfa_v: forward outputs, needed by the softmax backward.
        d_k, d_v, d_param_mu, d_param_phi: preallocated gradient tensors,
            written in place.
        mask_type: 1 when ``mask`` is present, else 0.

    Returns:
        None; results are written into the ``d_*`` arguments.
    """
    # The kernel assumes the last (head) dimension is contiguous.
    d_rfa_k, d_rfa_v = [
        x if x.stride(-1) == 1 else x.contiguous()
        for x in [d_rfa_k, d_rfa_v]
    ]

    # shape constraints
    batch, nheads, seqlen, head_dim = k.shape
    assert seqlen % chunksize == 0, "seqlen must be divisible by chunksize"
    nchunks = seqlen // chunksize
    softmax_scale = softmax_scale or 1.0 / math.sqrt(head_dim)

    mask_strides = (
        (mask.stride(0), mask.stride(2))
        if mask_type == 1 else
        (0, 0)
    )

    BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16)
    BLOCK = 128
    num_warps = 4 if head_dim <= 64 else 8

    # BUGFIX: the previous check, `(BLOCK > chunksize) & (BLOCK % chunksize) == 0`,
    # parsed as `((BLOCK > chunksize) & (BLOCK % chunksize)) == 0` because `&`
    # binds tighter than `==` in Python, so it only required the low bit of the
    # remainder to be zero. Divisibility alone also implies chunksize <= BLOCK.
    assert BLOCK % chunksize == 0, "BLOCK must be divisible by chunksize"
    chunks_per_block = BLOCK // chunksize

    # One fp32 partial-gradient slot per launch group to avoid atomics.
    partial_groups = triton.cdiv(seqlen, BLOCK)
    d_param_mu_partial = torch.zeros((batch, nheads, partial_groups, head_dim), dtype=torch.float32, device=d_rfa_k.device)
    d_param_phi_partial = torch.zeros((batch, nheads, partial_groups, head_dim), dtype=torch.float32, device=d_rfa_k.device)
    grid = lambda META: (partial_groups, batch * nheads)
    _bwd_eva_prep_kv_kernel[grid](
        rfa_k, # [b, h, c, d]
        rfa_v, # [b, h, c, d]
        k, # [b, h, n, d]
        v, # [b, h, n, d]
        param_mu, # [1, h, 1, 1, d]
        param_phi, # [1, h, 1, 1, d]
        mask, # [b, h, n, 1]
        d_rfa_k, # [b, h, c, d]
        d_rfa_v, # [b, h, c, d]
        d_k, # [b, h, n, d]
        d_v, # [b, h, n, d]
        d_param_mu_partial, # [b, h, g, d]
        d_param_phi_partial, # [b, h, g, d]
        softmax_scale,
        rfa_k.stride(0), rfa_k.stride(1), rfa_k.stride(2),
        rfa_v.stride(0), rfa_v.stride(1), rfa_v.stride(2),
        k.stride(0), k.stride(1), k.stride(2),
        v.stride(0), v.stride(1), v.stride(2),
        param_mu.stride(1),
        param_phi.stride(1),
        mask_strides[0], mask_strides[1],
        d_rfa_k.stride(0), d_rfa_k.stride(1), d_rfa_k.stride(2),
        d_rfa_v.stride(0), d_rfa_v.stride(1), d_rfa_v.stride(2),
        d_k.stride(0), d_k.stride(1), d_k.stride(2),
        d_v.stride(0), d_v.stride(1), d_v.stride(2),
        d_param_mu_partial.stride(0), d_param_mu_partial.stride(1), d_param_mu_partial.stride(2),
        d_param_phi_partial.stride(0), d_param_phi_partial.stride(1), d_param_phi_partial.stride(2),
        nheads,
        seqlen,
        nchunks,
        head_dim,
        chunks_per_block,
        chunksize,
        mask_type,
        BLOCK_HEADDIM,
        BLOCK_N=BLOCK,
        num_warps=num_warps,
        num_stages=1,
    )
    # Reduce partials over batch and groups: [b, h, g, d] -> [1, h, 1, 1, d].
    d_param_mu.copy_(d_param_mu_partial.sum(dim=(0, -2), keepdim=True).unsqueeze(-2).to(d_param_mu.dtype))
    d_param_phi.copy_(d_param_phi_partial.sum(dim=(0, -2), keepdim=True).unsqueeze(-2).to(d_param_phi.dtype))
971
class EvaPrepKVFunc(torch.autograd.Function):
    """Autograd bridge around the Triton RFA K/V preparation launchers.

    Forward delegates to ``triton_eva_prep_kv_fwd`` and stashes everything the
    backward pass needs; backward delegates to ``triton_eva_prep_kv_bwd``.
    """

    @staticmethod
    def forward(ctx, k, v, param_mu, param_phi, mask, softmax_scale=None, chunksize=None):
        mask_type = 0 if mask is None else 1
        rfa_k, rfa_v = triton_eva_prep_kv_fwd(
            k, v, param_mu, param_phi, mask, softmax_scale, chunksize
        )
        # Keep forward inputs and outputs; backward recomputes the chunk
        # softmaxes from them.
        ctx.save_for_backward(k, v, param_mu, param_phi, mask, rfa_k, rfa_v)
        ctx.softmax_scale = softmax_scale
        ctx.chunksize = chunksize
        ctx.mask_type = mask_type
        return rfa_k, rfa_v

    @staticmethod
    def backward(ctx, d_rfa_k, d_rfa_v):
        k, v, param_mu, param_phi, mask, rfa_k, rfa_v = ctx.saved_tensors
        # Preallocate gradient buffers; the launcher fills them in place.
        d_k, d_v, d_param_mu, d_param_phi = (
            torch.empty_like(t) for t in (k, v, param_mu, param_phi)
        )
        triton_eva_prep_kv_bwd(
            d_rfa_k, d_rfa_v,
            k, v, param_mu, param_phi,
            mask,
            rfa_k, rfa_v,
            d_k, d_v, d_param_mu, d_param_phi,
            ctx.softmax_scale,
            ctx.mask_type,
            ctx.chunksize,
        )
        # No gradients for mask, softmax_scale, or chunksize.
        return d_k, d_v, d_param_mu, d_param_phi, None, None, None
1006
def eva_prep_kv_func_triton(
    k, v,
    param_mu, param_phi,
    mask,
    softmax_scale=None, chunksize=None
):
    """Functional entry point: differentiable chunked RFA K/V preparation.

    Thin wrapper forwarding all arguments to ``EvaPrepKVFunc.apply``.
    """
    return EvaPrepKVFunc.apply(k, v, param_mu, param_phi, mask, softmax_scale, chunksize)
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/eva_pt_ref.py ADDED
@@ -0,0 +1,420 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple, Union
2
+ import torch
3
+ from torch import nn
4
+
5
+ MASK_MIN_VALUE = -10e10
6
+
7
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """
    Rotate the last dimension of *x* by half its width.

    The first half of the last dim is moved to the back, and the (negated)
    second half is moved to the front: ``[x1, x2] -> [-x2, x1]``.

    Args:
        x: Rotary embedded tensor
    Return:
        Tensor with half of last dim negated and rotated to the front.
    """
    half_width = x.shape[-1] // 2
    first_half, second_half = x.split(half_width, dim=-1)
    return torch.cat((-second_half, first_half), dim=-1)
17
+
18
def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                         position_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embedding (cos, sin) to the query and key tensor on the sequence dimension.

    The legends for dimensions are defined as:
        num_heads: number of attention heads
        current_seq_len: the current batch's sequence length, should be either 1 or max_seq_len
        max_seq_len: the static sequence length, different from current_seq_len in cached inference case where it is always
            maximum length, e.g. the length of static sequence length of KV cache

    Args:
        q: Query tensor, of size (batch_size, num_heads, current_seq_len, head_dim)
        k: Key tensor, of size (batch_size, num_key_value_heads, current_seq_len, head_dim)
        cos: Cosine base of rotary embedding, 4-D and broadcastable against q/k
            with head_dim as its last dimension (the code checks cos.shape[3]).
        sin: Sine base of rotary embedding, same layout as cos.
        position_ids: The position indices of the tokens corresponding to the query and key tensors. It has a size of
            (batch_size, current_seq_len) or (1, current_seq_len). NOTE: only
            shape-validated here; cos/sin are assumed to be pre-gathered for
            these positions.

    Returns:
        Tuple of (embedded query, embedded key), each the same size as its input.
    """
    # BUGFIX (annotation only): the function returns a 2-tuple, but was
    # annotated `-> torch.Tensor`; the docstring also described cos as 2-D
    # while the assertion below requires a 4-D tensor.
    bs, nheads, cur_seq_len, head_dim = q.shape
    assert len(
        k.shape) == 4, f"k should be of shape (batch_size, num_heads, current_seq_len, head_dim), got {k.shape} instead"
    assert k.shape[0] == bs, f"k has a different batch_size {k.shape[0]} compared to q {bs}"
    assert list(k.shape[2:]) == [cur_seq_len,
                                 head_dim], f"k has different current_seq_len and/or head_dim compared to q"
    assert cos.shape[3] == head_dim, f"cos should have dim of head dim {head_dim}, got {cos.shape[3]} instead"
    assert list(position_ids.shape) in [[bs, cur_seq_len], [1, cur_seq_len]],\
        f"position_ids should be of shape {[bs, cur_seq_len]} or {[1, cur_seq_len]}, got {position_ids.shape} instead"

    # Standard rotary transform: x' = x*cos + rotate_half(x)*sin.
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
55
+
56
def attention_op(
    q,
    k,
    v,
    attn_mask,
    mixedp_attn,
    head_dim_scaling
):
    """Plain scaled-dot-product attention: softmax(scale * q k^T + mask) @ v.

    Args:
        q, k, v: attention inputs; the last two dims are (seq, head_dim).
        attn_mask: optional boolean mask; True positions are filled with
            MASK_MIN_VALUE before the softmax.
        mixedp_attn: when True, scores are promoted to fp32 before scaling
            and softmax for numerical stability.
        head_dim_scaling: multiplicative score scale (typically head_dim**-0.5).

    Returns:
        Attention output with the same dtype as q.
    """
    scores = q @ k.transpose(-2, -1)
    if mixedp_attn:
        scores = scores.to(torch.float)
    scores = scores * head_dim_scaling
    if attn_mask is not None:
        scores = scores.masked_fill(attn_mask, MASK_MIN_VALUE)

    probs = torch.softmax(scores, dim=-1).to(q.dtype)
    return probs @ v
74
+
75
def prm_projection(
    x: torch.Tensor,
    projection_matrix: torch.Tensor,
    mixedp_attn: bool = False
):
    """
    Constructs nonnegative kernel features for fast softmax attention.

    Computes ``d**-0.5 * (P x^T - 0.5*||x||^2)`` where P is the projection
    matrix and d is the feature dimension of x.

    Args:
        x: input for which features are computed, shape [..., m, d]
        projection_matrix: random matrix used to compute features, shape [..., r, d]
        mixedp_attn: promote intermediates to fp32 before combining.
    Returns:
        Random features for fast attention, shape [..., r, m].
    """
    inv_sqrt_d = (x.shape[-1] ** -0.5)
    projected = projection_matrix @ x.transpose(-1, -2)  # [..., r, m]
    half_sq_norm = torch.sum(x ** 2, dim=-1).unsqueeze(-2) * 0.5  # [..., 1, m]
    if mixedp_attn:
        projected = projected.to(torch.float)
        half_sq_norm = half_sq_norm.to(torch.float)
    return inv_sqrt_d * (projected - half_sq_norm)
98
+
99
class EvaAttention(nn.Module):
    """EVA attention layer.

    Combines two pathways over the sequence:
      * exact softmax attention within fixed-size local windows
        (the "singleton" path), and
      * compressed chunk-level summaries built via random-feature
        attention (RFA), giving each query access to history outside
        its window.

    Shape legend used in comments: b=batch, h=heads, w=windows, c=chunks,
    i/j/m=token positions, d=head_dim, r=random-feature dim.
    """

    def __init__(self, config, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        # 1/sqrt(d): standard attention temperature, reused as RFA scaling
        self.head_dim_scaling = self.head_dim ** -0.5

        self.max_position_embeddings = config.max_position_embeddings

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.window_size = config.window_size

        self.num_chunks = config.num_chunks
        self.chunk_size = config.chunk_size
        if self.chunk_size is not None:
            # every window must consist of a whole number of chunks
            assert self.window_size >= self.chunk_size and self.window_size % self.chunk_size == 0
            # chunk_size overrides the number of landmarks
            self.num_chunks = None

        # NOTE(review): assumes config.chunk_size is not None; a None
        # chunk_size would raise a TypeError on this division — confirm
        # that all shipped configs set chunk_size.
        self.chunks_per_window = int(self.window_size // self.chunk_size)
        self.random_feature_dim = 1
        # learnable projection used to build the random features (phi)
        self.adaptive_phi = nn.Parameter(
            torch.randn(
                1,
                self.num_heads,
                1,
                1,
                self.head_dim
            ).clamp(-1., 1.) * self.head_dim_scaling
        )
        # learnable query that pools each chunk's keys into a landmark key
        self.adaptive_mu_k = nn.Parameter(
            torch.randn(
                1,
                self.num_heads,
                1,
                1,
                self.head_dim
            ).clamp(-1., 1.) * self.head_dim_scaling
        )

    def _generate_feature_map(self, rf_q, rf_k, rf_v):
        """Return (random-feature weights, per-chunk landmark keys).

        ``rf_k_bar`` is a softmax-weighted average of each chunk's keys,
        scored against the learned query ``adaptive_mu_k``.
        """
        rf_k_logits = torch.sum(self.adaptive_mu_k.to(rf_k.dtype) * rf_k, dim=-1, keepdim=True) # b h c m 1
        if self.config.mixedp_attn:
            # run the softmax in fp32 for numerical stability
            rf_k_logits = rf_k_logits.to(torch.float)
        rf_k_weights = torch.softmax(rf_k_logits, dim=-2).to(rf_k.dtype)
        rf_k_bar = torch.sum(rf_k_weights * rf_k, dim=-2)
        weights = self.adaptive_phi.to(rf_k.dtype)
        return weights, rf_k_bar

    def _calculate_chunk_rfa_cache(self, rf_q, rf_k, rf_v, weights, rf_mask=None):
        """Compute per-chunk softmax(phi(k))-weighted value summaries.

        Returns ``(softmax_phi_k_v, log_sum_phi_k)``; the normalizer
        ``log_sum_phi_k`` is only needed when ``random_feature_dim > 1``
        and is None here (r == 1).
        """
        # log phi(k) = (w . k - ||k||^2 / 2) / sqrt(d), same kernel as prm_projection
        proj_x = torch.sum(weights * rf_k, dim=-1, keepdim=True)
        norm = torch.sum(rf_k ** 2, dim=-1, keepdim=True) * 0.5 # [b, h, c, m, 1]
        if self.config.mixedp_attn:
            proj_x = proj_x.to(torch.float)
            norm = norm.to(torch.float)
        log_phi_k = self.head_dim_scaling * (proj_x - norm)

        if rf_mask is not None:
            # masked positions must not contribute to the chunk summary
            log_phi_k = log_phi_k.masked_fill(rf_mask, MASK_MIN_VALUE)

        # [b, h, c, m, r]: normalize over positions m within each chunk
        softmax_phi_k = torch.softmax(log_phi_k, dim=-2).to(rf_k.dtype)
        softmax_phi_k_v = torch.sum(softmax_phi_k * rf_v, dim=-2)
        log_sum_phi_k = None
        return softmax_phi_k_v, log_sum_phi_k

    def _calculate_chunk_rfa(self, q, softmax_phi_k_v, log_sum_phi_k, weights):
        """Turn cached chunk summaries into per-chunk RFA outputs for ``q``."""
        if self.random_feature_dim == 1:
            # when r = 1, the snis weights becomes 1, so this takes no effect
            # [b, h, c, r, d] -> [b, h, c, d]
            return softmax_phi_k_v
        else:
            # [b, h, c, r, d] [b, h, 1, s, d] -> [b, h, c, r, s]
            log_phi_q = prm_projection(q.unsqueeze(-3), weights, self.config.mixedp_attn)
            # self-normalized importance-sampling weights over random features
            # [b, h, c, r, s] [b, h, c, r, 1] -> [b, h, c, r, s]
            sniw = torch.softmax(log_phi_q + log_sum_phi_k, dim=-1).to(q.dtype)
            # [b, h, c, r, s] [b, h, c, r, d] -> [b, h, c, s, d] -> [b, h, s, c, d]
            rfa_per_chunk = torch.matmul(sniw.transpose(-1, -2), softmax_phi_k_v).transpose(-3, -2)
            return rfa_per_chunk

    def window_partition(self, x, window_size=None):
        """Split the position dim into windows: [..., g*w, d] -> [..., g, w, d]."""
        window_size = window_size if window_size is not None else self.window_size

        gw, d = x.shape[-2:]
        leading_dims = x.shape[:-2]
        n_groups = gw // window_size
        return x.reshape(*leading_dims, n_groups, window_size, d)

    def window_merge(self, x, window_size=None):
        """Inverse of ``window_partition``: [..., g, w, d] -> [..., g*w, d]."""
        g, w, d = x.shape[-3:]
        leading_dims = x.shape[:-3]
        return x.reshape(*leading_dims, g * w, d)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cos: Optional[torch.Tensor] = None,
        sin: Optional[torch.Tensor] = None,
        multibyte_decoding: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Run EVA attention.

        ``attention_mask`` is a tuple: length 4 when decoding with a cache
        (prev/cur singleton masks + chunk + intra-chunk masks), length 3 at
        training time (window + chunk + intra-chunk masks).

        Returns ``(attn_output, None, past_key_value)``; attention weights
        are never materialized (``output_attentions`` must be False).
        """
        assert not output_attentions
        bsz, q_len, _ = hidden_states.size()

        ############################################
        # initialize past states if not provided
        ############################################
        if use_cache and past_key_value is None:
            raise ValueError("`use_cache=True` requires `past_key_value` to be provided")
        if use_cache and multibyte_decoding:
            raise NotImplementedError("Multibyte decoding is not supported for PyTorch native implementation")
        if len(attention_mask) == 4:
            assert use_cache
            prev_causal_mask, cur_causal_mask, chunk_causal_mask, intra_chunk_mask = attention_mask
        elif len(attention_mask) == 3:
            assert not use_cache
            # bug fix: the first entry was previously unpacked as
            # `window_causal_mask` but read below as `prev_causal_mask`,
            # raising NameError in the no-cache (training) path.
            prev_causal_mask, chunk_causal_mask, intra_chunk_mask = attention_mask
        else:
            # bug fix: message previously claimed "length 2 or 3" while the
            # branches above accept lengths 3 and 4.
            raise NotImplementedError("Only attention-mask tuple with length 3 or 4 is supported")

        ############################################
        # compute q, k, v from hidden states
        ############################################
        # [b, h, q_len, d]
        q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        # [b, h, kv_len, d]
        k = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        # [b, h, kv_len, d]
        v = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        if use_cache:
            past_key_value.update_past_len(q.shape[-2], self.layer_idx)

        ############################################
        # apply rotary positional embeddings to q, k
        ############################################
        q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)

        ############################################
        # compute q, k, v stats for the local window
        ############################################
        if use_cache:
            (prev_w_q, prev_w_k, prev_w_v), (cur_w_q, cur_w_k, cur_w_v) = past_key_value.update_singletons(
                q,
                k,
                v,
                self.layer_idx,
                self.window_size,
            )
        else:
            prev_w_q = self.window_partition(q) # [b, h, w, i, d]
            prev_w_k = self.window_partition(k) # [b, h, w, j, d]
            prev_w_v = self.window_partition(v) # [b, h, w, j, d]
            # during training, we assume window_size divides seq_len so no remainders
            cur_w_q = cur_w_k = cur_w_v = None

        ############################################
        # compute q, k, v stats for chunk-level RFAs
        ############################################
        if use_cache:
            dump_q, dump_k, dump_v = past_key_value.update_chunks(q, k, v, self.layer_idx, self.chunk_size)
        else:
            dump_q, dump_k, dump_v = q, k, v

        if use_cache:
            prev_s_mask, cur_s_mask, prev_chunk_mask, cur_chunk_mask, dump_rf_mask = past_key_value.update_mask(
                prev_s_mask=prev_causal_mask,
                cur_s_mask=cur_causal_mask,
                chunk_mask=chunk_causal_mask,
                rf_mask=intra_chunk_mask,
                layer_idx=self.layer_idx,
                window_size=self.window_size,
                chunk_size=self.chunk_size,
            )
        else:
            prev_s_mask = self.window_partition(prev_causal_mask) # [1, 1, w, i, j]
            cur_s_mask = None
            prev_chunk_mask = self.window_partition(chunk_causal_mask)
            cur_chunk_mask = None
            dump_rf_mask = intra_chunk_mask
            if prev_s_mask.shape[-3] == 1:
                # broadcast the singleton-window mask across all windows
                prev_s_mask = prev_s_mask.expand(-1, -1, prev_chunk_mask.shape[-3], -1, -1)

        if (
            dump_q is not None and
            dump_k is not None and
            dump_v is not None
        ):
            # [b, h, c, j, d]
            rf_q = self.window_partition(dump_q, window_size=self.chunk_size)
            # [b, h, c, j, d]
            rf_k = self.window_partition(dump_k, window_size=self.chunk_size)
            # [b, h, c, j, d]
            rf_v = self.window_partition(dump_v, window_size=self.chunk_size)

            if dump_rf_mask is not None:
                rf_mask = self.window_partition(dump_rf_mask, window_size=self.chunk_size)
                # zero out masked positions so they do not leak into summaries
                rf_q = rf_q.masked_fill(rf_mask, 0.)
                rf_k = rf_k.masked_fill(rf_mask, 0.)
                rf_v = rf_v.masked_fill(rf_mask, 0.)
            else:
                rf_mask = None
        else:
            rf_q = None
            rf_k = None
            rf_v = None
            rf_mask = None

        if rf_q is not None:
            weights, rf_k_bar = self._generate_feature_map(rf_q, rf_k, rf_v)
            softmax_phi_k_v, log_sum_phi_k = self._calculate_chunk_rfa_cache(rf_q, rf_k, rf_v, weights, rf_mask=rf_mask)
            if use_cache:
                softmax_phi_k_v, log_sum_phi_k, rf_k_bar = past_key_value.update_chunk_rfas(
                    softmax_phi_k_v, log_sum_phi_k, rf_k_bar, self.layer_idx, 1
                )
        elif use_cache:
            # no new full chunk this step: reuse cached chunk summaries
            weights = None
            softmax_phi_k_v, log_sum_phi_k, rf_k_bar = past_key_value.get_chunk_rfas(self.layer_idx)
        else:
            weights = None
            softmax_phi_k_v = None
            log_sum_phi_k = None
            rf_k_bar = None

        if rf_k_bar is not None:
            rfa_per_chunk = self._calculate_chunk_rfa(q, softmax_phi_k_v, log_sum_phi_k, weights)
        ############################################
        # compute meta-attention weights for
        # - group-wise RFAs and
        # - singletons (equivalent to exact local attention)
        ############################################
        if prev_w_k is not None:
            if rf_k_bar is not None:
                num_windows = prev_w_k.shape[-3]
                # rf_k_bar and rfa_per_chunk take the shape [b, h, c, d]
                # -> [b, h, 1, c, d] -> [b, h, w, c, d]
                prev_rf_k_bar = rf_k_bar.unsqueeze(-3).expand(-1, -1, num_windows, -1, -1)
                prev_rfa_per_chunk = rfa_per_chunk.unsqueeze(-3).expand(-1, -1, num_windows, -1, -1)
                # append landmark keys/values after the window singletons
                prev_agg_k = torch.cat([prev_w_k, prev_rf_k_bar], dim=-2)
                prev_agg_v = torch.cat([prev_w_v, prev_rfa_per_chunk], dim=-2)

                prev_attn_mask = torch.cat([prev_s_mask, prev_chunk_mask], dim=-1)
            else:
                prev_agg_k = prev_w_k
                prev_agg_v = prev_w_v
                prev_attn_mask = prev_s_mask

            prev_attn_output = attention_op(
                q=prev_w_q,
                k=prev_agg_k,
                v=prev_agg_v,
                attn_mask=prev_attn_mask,
                mixedp_attn=self.config.mixedp_attn,
                head_dim_scaling=self.head_dim_scaling
            )
            prev_attn_output = self.window_merge(prev_attn_output)

        if cur_w_k is not None:
            if rf_k_bar is not None:
                # rf_k_bar and rfa_per_chunk take the shape [b, h, c, d]
                # cur_w_k and cur_w_v also has shape [b, h, m, d]
                cur_agg_k = torch.cat([cur_w_k, rf_k_bar], dim=-2)
                cur_agg_v = torch.cat([cur_w_v, rfa_per_chunk], dim=-2)

                cur_attn_mask = torch.cat([cur_s_mask, cur_chunk_mask], dim=-1)
            else:
                cur_agg_k = cur_w_k
                cur_agg_v = cur_w_v
                cur_attn_mask = cur_s_mask

            cur_attn_output = attention_op(
                q=cur_w_q,
                k=cur_agg_k,
                v=cur_agg_v,
                attn_mask=cur_attn_mask,
                mixedp_attn=self.config.mixedp_attn,
                head_dim_scaling=self.head_dim_scaling
            )

        if prev_w_k is not None and cur_w_k is not None:
            attn_output = torch.cat([prev_attn_output, cur_attn_output], dim=-2)
        elif prev_w_k is not None:
            attn_output = prev_attn_output
        elif cur_w_k is not None:
            attn_output = cur_attn_output
        else:
            raise ValueError("There must be some bug")

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        attn_weights = None

        return attn_output, attn_weights, past_key_value
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/generation_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "pad_token_id": 0,
6
+ "transformers_version": "4.47.1"
7
+ }
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/image_processing_evabyte.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ """Image processor class for EvaByte."""
3
+
4
+ from typing import Dict, List, Optional, Union, Tuple
5
+
6
+ import io
7
+ from transformers.image_processing_utils import BaseImageProcessor
8
+ from transformers.image_utils import (
9
+ ImageInput,
10
+ PILImageResampling,
11
+ valid_images,
12
+ validate_preprocess_arguments,
13
+ )
14
+ from PIL import Image
15
+
16
def _get_qtable_bytes():
    """Return hard-coded JPEG quantization-table segments keyed by quality.

    Each value is a minimal JPEG byte stream: SOI (``\\xff\\xd8``) followed by
    two DQT segments (``\\xff\\xdb``, luminance then chrominance tables) and
    EOI (``\\xff\\xd9``). ``jpeg_merge_qtables`` splices these tables into an
    abbreviated JPEG stream that was saved without them.
    NOTE(review): presumably these were captured from Pillow's encoder output
    at the given quality levels — confirm before editing the byte payloads.
    """
    return {
        5: b'\xff\xd8\xff\xdb\x00C\x00\xa0nx\x8cxd\xa0\x8c\x82\x8c\xb4\xaa\xa0\xbe\xf0\xff\xff\xf0\xdc\xdc\xf0\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xdb\x00C\x01\xa0\xb4\xb4\xf0\xd2\xf0\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xd9',
        10: b'\xff\xd8\xff\xdb\x00C\x00P7<F<2PFAFZUP_x\xc8\x82xnnx\xf5\xaf\xb9\x91\xc8\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xdb\x00C\x01PZZxix\xeb\x82\x82\xeb\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xd9',
        15: b'\xff\xd8\xff\xdb\x00C\x005%(/(!5/+/<95?P\x85WPIIP\xa3u{a\x85\xc1\xaa\xcb\xc8\xbe\xaa\xba\xb7\xd5\xf0\xff\xff\xd5\xe2\xff\xe6\xb7\xba\xff\xff\xff\xff\xff\xff\xff\xff\xff\xce\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xdb\x00C\x015<<PFP\x9dWW\x9d\xff\xdc\xba\xdc\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xd9',
        20: b'\xff\xd8\xff\xdb\x00C\x00(\x1c\x1e#\x1e\x19(#!#-+(0<dA<77<{X]Id\x91\x80\x99\x96\x8f\x80\x8c\x8a\xa0\xb4\xe6\xc3\xa0\xaa\xda\xad\x8a\x8c\xc8\xff\xcb\xda\xee\xf5\xff\xff\xff\x9b\xc1\xff\xff\xff\xfa\xff\xe6\xfd\xff\xf8\xff\xdb\x00C\x01(--<5<vAAv\xf8\xa5\x8c\xa5\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xff\xd9',
        25: b'\xff\xd8\xff\xdb\x00C\x00 \x16\x18\x1c\x18\x14 \x1c\x1a\x1c$" &0P40,,0bFJ:Ptfzxrfpn\x80\x90\xb8\x9c\x80\x88\xae\x8anp\xa0\xda\xa2\xae\xbe\xc4\xce\xd0\xce|\x9a\xe2\xf2\xe0\xc8\xf0\xb8\xca\xce\xc6\xff\xdb\x00C\x01 $$0*0^44^\xc6\x84p\x84\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xff\xd9',
        30: b'\xff\xd8\xff\xdb\x00C\x00\x1b\x12\x14\x17\x14\x11\x1b\x17\x16\x17\x1e\x1c\x1b (B+(%%(Q:=0B`Ued_U][jx\x99\x81jq\x90s[]\x85\xb5\x86\x90\x9e\xa3\xab\xad\xabg\x80\xbc\xc9\xba\xa6\xc7\x99\xa8\xab\xa4\xff\xdb\x00C\x01\x1b\x1e\x1e(#(N++N\xa4n]n\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xff\xd9',
        50: b'\xff\xd8\xff\xdb\x00C\x00\x10\x0b\x0c\x0e\x0c\n\x10\x0e\r\x0e\x12\x11\x10\x13\x18(\x1a\x18\x16\x16\x181#%\x1d(:3=<9387@H\\N@DWE78PmQW_bghg>Mqypdx\\egc\xff\xdb\x00C\x01\x10\x12\x12\x18\x15\x18/\x1a\x1a/cB8Bcccccccccccccccccccccccccccccccccccccccccccccccccc\xff\xd9',
        75: b'\xff\xd8\xff\xdb\x00C\x00\x08\x06\x06\x07\x06\x05\x08\x07\x07\x07\t\t\x08\n\x0c\x14\r\x0c\x0b\x0b\x0c\x19\x12\x13\x0f\x14\x1d\x1a\x1f\x1e\x1d\x1a\x1c\x1c $.\' ",#\x1c\x1c(7),01444\x1f\'9=82<.342\xff\xdb\x00C\x01\x08\t\t\x0c\x0b\x0c\x18\r\r\x182!\x1c!22222222222222222222222222222222222222222222222222\xff\xd9',
        95: b'\xff\xd8\xff\xdb\x00C\x00\x02\x01\x01\x01\x01\x01\x02\x01\x01\x01\x02\x02\x02\x02\x02\x04\x03\x02\x02\x02\x02\x05\x04\x04\x03\x04\x06\x05\x06\x06\x06\x05\x06\x06\x06\x07\t\x08\x06\x07\t\x07\x06\x06\x08\x0b\x08\t\n\n\n\n\n\x06\x08\x0b\x0c\x0b\n\x0c\t\n\n\n\xff\xdb\x00C\x01\x02\x02\x02\x02\x02\x02\x05\x03\x03\x05\n\x07\x06\x07\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\xff\xd9',
        100: b'\xff\xd8\xff\xdb\x00C\x00\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\xff\xdb\x00C\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\xff\xd9',
    }
29
+
30
+
31
+ def _resize_if_exceeding_max_len(
32
+ width: int, height: int, min_len: Optional[int] = 16, max_len: Optional[int] = None
33
+ ) -> Tuple[int, int]:
34
+ """
35
+ Get the output size of the image after resizing given a dictionary specifying the max and min sizes.
36
+
37
+ Args:
38
+ height (`int`):
39
+ Height of the input image.
40
+ width (`int`):
41
+ Width of the input image.
42
+ max_len (`Dict[str, int]`, *optional*, defaults to the maximum size of the image):
43
+ Defines the maximum dimensions of the image.
44
+
45
+ Returns:
46
+ The output size of the image after resizing.
47
+ """
48
+ max_len = max(height, width) if max_len is None else max_len
49
+ aspect_ratio = width / height
50
+
51
+ if width >= height and width > max_len:
52
+ width = max_len
53
+ height = int(width / aspect_ratio)
54
+ if height % 2 != 0:
55
+ height += 1
56
+ elif height > width and height > max_len:
57
+ height = max_len
58
+ width = int(height * aspect_ratio)
59
+ if width % 2 != 0:
60
+ width += 1
61
+
62
+ # Avoid resizing to a size smaller than 1
63
+ height = max(height, min_len)
64
+ width = max(width, min_len)
65
+ return width, height
66
+
67
class EvaByteImageProcessor(BaseImageProcessor):
    """Image processor that converts images into JPEG byte streams for EvaByte.

    Unlike tensor-producing processors, ``preprocess`` returns raw JPEG bytes
    (one list of byte strings per sample); the byte-level model consumes the
    JPEG stream directly.
    """

    # no tensor inputs are produced, so no model input names are declared
    model_input_names = []

    def __init__(
        self,
        do_resize: bool = True,
        resample: PILImageResampling = PILImageResampling.LANCZOS,
        size: Optional[Dict[str, int]] = None,
        do_convert_rgb: bool = True,
        jpeg_quality: int = 25,
        jpeg_subsampling: str = "4:2:0",
        jpeg_streamtype: int = 2,
        jpeg_restart_marker_blocks: int = 1,
        **kwargs,
    ) -> None:
        # jpeg_streamtype / jpeg_restart_marker_blocks are passed straight
        # to Pillow's JPEG encoder; NOTE(review): streamtype=2 presumably
        # selects an abbreviated stream without quantization tables (hence
        # `jpeg_merge_qtables`) — confirm against Pillow's JpegEncode options.
        super().__init__(**kwargs)
        self.do_resize = do_resize
        self.resample = resample
        self.size = size if size is not None else {"longest_edge": 384}
        self.do_convert_rgb = do_convert_rgb
        self.jpeg_quality = jpeg_quality
        self.jpeg_subsampling = jpeg_subsampling
        self.jpeg_streamtype = jpeg_streamtype
        self.jpeg_restart_marker_blocks = jpeg_restart_marker_blocks

    def jpeg_encode(
        self,
        image,
        jpeg_quality,
        jpeg_subsampling,
        jpeg_streamtype,
        jpeg_restart_marker_blocks,
    ):
        """Encode a PIL image to JPEG bytes with the given encoder options."""
        with io.BytesIO() as output:
            image.save(
                output,
                format="JPEG",
                quality=jpeg_quality,
                subsampling=jpeg_subsampling,
                streamtype=jpeg_streamtype,
                restart_marker_blocks=jpeg_restart_marker_blocks
            )
            jpeg_bytes = output.getvalue()
        return jpeg_bytes

    def jpeg_merge_qtables(
        self,
        image_bytes,
        jpeg_quality=None,
    ):
        """Splice the quality-matched quantization tables into a JPEG stream.

        The stored table blob starts with SOI (2 bytes) and ends with EOI
        (2 bytes); both are stripped so the DQT segments land right after
        the SOI marker of ``image_bytes``.
        """
        if jpeg_quality is None:
            jpeg_quality = self.jpeg_quality
        qtable_bytes = _get_qtable_bytes()[jpeg_quality]
        return image_bytes[:2] + qtable_bytes[2:-2] + image_bytes[2:]

    def resize(
        self,
        image: Image.Image,
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.LANCZOS,
    ) -> Image.Image:
        """Resize a PIL image according to ``size``.

        ``size`` either caps the longest edge (key ``"longest_edge"``) while
        preserving aspect ratio, or gives explicit ``"width"``/``"height"``.
        """
        if "longest_edge" in size:
            width, height = image.size
            # Find the output size, when rescaling the longest edge to max_len and preserving the aspect ratio
            width, height = _resize_if_exceeding_max_len(width, height, max_len=size["longest_edge"])
            size = (width, height)
        elif "width" in size and "height" in size:
            size = (size["width"], size["height"])
        else:
            raise ValueError("size must be a dictionary with key 'longest_edge' or 'height' and 'width'.")
        resized_image = image.resize(size, resample=resample)
        return resized_image

    def preprocess(
        self,
        images: ImageInput,
        do_resize: Optional[bool] = None,
        resample = None,
        size: Optional[Dict[str, int]] = None,
        do_convert_rgb: Optional[bool] = None,
        jpeg_quality: Optional[int] = None,
        jpeg_subsampling: Optional[str] = None,
        jpeg_streamtype: Optional[int] = None,
        jpeg_restart_marker_blocks: Optional[int] = None,
    ):
        """Convert nested lists of images into nested lists of JPEG bytes.

        ``images`` is expected to be a list of lists (one inner list per
        sample). Every argument falls back to the processor's configured
        default when None. Returns the same nesting with each image replaced
        by its encoded JPEG byte string.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size
        resample = resample if resample is not None else self.resample
        do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb

        jpeg_quality = jpeg_quality if jpeg_quality is not None else self.jpeg_quality
        jpeg_subsampling = jpeg_subsampling if jpeg_subsampling is not None else self.jpeg_subsampling
        jpeg_streamtype = jpeg_streamtype if jpeg_streamtype is not None else self.jpeg_streamtype
        jpeg_restart_marker_blocks = jpeg_restart_marker_blocks if jpeg_restart_marker_blocks is not None else self.jpeg_restart_marker_blocks

        if images is not None and not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        validate_preprocess_arguments(
            do_resize=do_resize,
            size=size,
            resample=resample,
        )
        images_list = images
        if do_convert_rgb:
            # normalize every image to 3-channel RGB before JPEG encoding
            images_list = [
                [
                    image.convert("RGB") for image in images
                ]
                for images in images_list
            ]

        if do_resize:
            images_list = [
                [
                    self.resize(image=image, size=size, resample=resample)
                    for image in images
                ]
                for images in images_list
            ]

        jpeg_bytes = [
            [
                self.jpeg_encode(
                    image,
                    jpeg_quality,
                    jpeg_subsampling,
                    jpeg_streamtype,
                    jpeg_restart_marker_blocks
                ) for image in images
            ]
            for images in images_list
        ]
        return jpeg_bytes
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/model.safetensors.index.json ADDED
@@ -0,0 +1,450 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "total_size": 57058938880
4
+ },
5
+ "weight_map": {
6
+ "model.embed_tokens.weight": "model-00001-of-00003.safetensors",
7
+ "model.layers.0.input_layernorm.weight": "model-00001-of-00003.safetensors",
8
+ "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
9
+ "model.layers.1.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
10
+ "model.layers.1.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
11
+ "model.layers.1.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
12
+ "model.layers.1.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
13
+ "model.layers.3.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
14
+ "model.layers.4.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
15
+ "model.layers.4.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
16
+ "model.layers.5.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
17
+ "model.layers.5.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
18
+ "model.layers.7.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
19
+ "model.layers.8.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
20
+ "model.layers.8.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
21
+ "model.layers.10.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
22
+ "model.layers.12.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
23
+ "model.layers.12.input_layernorm.weight": "model-00002-of-00003.safetensors",
24
+ "model.layers.12.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
25
+ "model.layers.13.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
26
+ "model.layers.13.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
27
+ "model.layers.13.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
28
+ "model.layers.13.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
29
+ "model.layers.15.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
30
+ "model.layers.16.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
31
+ "model.layers.16.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
32
+ "model.layers.18.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
33
+ "model.layers.20.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
34
+ "model.layers.20.input_layernorm.weight": "model-00002-of-00003.safetensors",
35
+ "model.layers.20.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
36
+ "model.layers.21.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
37
+ "model.layers.21.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
38
+ "model.layers.21.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
39
+ "model.layers.21.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
40
+ "model.layers.23.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
41
+ "model.layers.24.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
42
+ "model.layers.24.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
43
+ "model.layers.26.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
44
+ "model.layers.28.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
45
+ "model.layers.28.input_layernorm.weight": "model-00003-of-00003.safetensors",
46
+ "model.layers.28.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
47
+ "model.layers.29.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
48
+ "model.layers.29.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
49
+ "model.layers.29.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
50
+ "model.layers.29.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
51
+ "model.layers.31.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
52
+ "model.layers.32.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
53
+ "model.layers.32.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
54
+ "model.layers.34.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
55
+ "model.layers.36.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
56
+ "model.layers.36.input_layernorm.weight": "model-00003-of-00003.safetensors",
57
+ "model.layers.36.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
58
+ "model.layers.37.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
59
+ "model.layers.37.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
60
+ "model.layers.37.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
61
+ "model.layers.37.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
62
+ "model.layers.39.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
63
+ "model.layers.2.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
64
+ "model.layers.26.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
65
+ "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
66
+ "model.layers.1.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
67
+ "model.layers.1.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
68
+ "model.layers.3.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
69
+ "model.layers.3.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
70
+ "model.layers.5.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
71
+ "model.layers.7.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
72
+ "model.layers.9.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
73
+ "model.layers.9.input_layernorm.weight": "model-00001-of-00003.safetensors",
74
+ "model.layers.9.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
75
+ "model.layers.10.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
76
+ "model.layers.10.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
77
+ "model.layers.12.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
78
+ "model.layers.14.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
79
+ "model.layers.14.input_layernorm.weight": "model-00002-of-00003.safetensors",
80
+ "model.layers.15.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
81
+ "model.layers.15.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
82
+ "model.layers.17.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
83
+ "model.layers.18.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
84
+ "model.layers.18.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
85
+ "model.layers.20.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
86
+ "model.layers.21.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
87
+ "model.layers.21.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
88
+ "model.layers.23.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
89
+ "model.layers.25.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
90
+ "model.layers.25.input_layernorm.weight": "model-00003-of-00003.safetensors",
91
+ "model.layers.25.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
92
+ "model.layers.26.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
93
+ "model.layers.26.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
94
+ "model.layers.27.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
95
+ "model.layers.28.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
96
+ "model.layers.29.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
97
+ "model.layers.29.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
98
+ "model.layers.31.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
99
+ "model.layers.33.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
100
+ "model.layers.33.input_layernorm.weight": "model-00003-of-00003.safetensors",
101
+ "model.layers.33.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
102
+ "model.layers.34.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
103
+ "model.layers.34.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
104
+ "model.layers.36.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
105
+ "model.layers.37.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
106
+ "model.layers.37.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
107
+ "model.layers.39.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
108
+ "model.layers.3.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
109
+ "model.layers.27.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
110
+ "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
111
+ "model.layers.1.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
112
+ "model.layers.1.input_layernorm.weight": "model-00001-of-00003.safetensors",
113
+ "model.layers.2.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
114
+ "model.layers.2.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
115
+ "model.layers.2.input_layernorm.weight": "model-00001-of-00003.safetensors",
116
+ "model.layers.3.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
117
+ "model.layers.3.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
118
+ "model.layers.4.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
119
+ "model.layers.5.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
120
+ "model.layers.5.input_layernorm.weight": "model-00001-of-00003.safetensors",
121
+ "model.layers.6.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
122
+ "model.layers.6.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
123
+ "model.layers.8.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
124
+ "model.layers.9.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
125
+ "model.layers.9.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
126
+ "model.layers.11.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
127
+ "model.layers.13.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
128
+ "model.layers.13.input_layernorm.weight": "model-00002-of-00003.safetensors",
129
+ "model.layers.13.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
130
+ "model.layers.14.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
131
+ "model.layers.14.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
132
+ "model.layers.16.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
133
+ "model.layers.17.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
134
+ "model.layers.17.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
135
+ "model.layers.19.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
136
+ "model.layers.21.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
137
+ "model.layers.21.input_layernorm.weight": "model-00002-of-00003.safetensors",
138
+ "model.layers.21.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
139
+ "model.layers.22.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
140
+ "model.layers.22.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
141
+ "model.layers.24.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
142
+ "model.layers.25.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
143
+ "model.layers.25.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
144
+ "model.layers.27.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
145
+ "model.layers.28.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
146
+ "model.layers.29.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
147
+ "model.layers.29.input_layernorm.weight": "model-00003-of-00003.safetensors",
148
+ "model.layers.29.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
149
+ "model.layers.30.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
150
+ "model.layers.30.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
151
+ "model.layers.32.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
152
+ "model.layers.33.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
153
+ "model.layers.33.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
154
+ "model.layers.35.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
155
+ "model.layers.37.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
156
+ "model.layers.37.input_layernorm.weight": "model-00003-of-00003.safetensors",
157
+ "model.layers.37.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
158
+ "model.layers.38.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
159
+ "model.layers.38.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
160
+ "model.layers.4.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
161
+ "model.layers.28.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
162
+ "model.layers.5.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
163
+ "model.layers.0.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
164
+ "model.layers.0.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
165
+ "model.layers.2.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
166
+ "model.layers.4.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
167
+ "model.layers.6.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
168
+ "model.layers.8.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
169
+ "model.layers.8.input_layernorm.weight": "model-00001-of-00003.safetensors",
170
+ "model.layers.8.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
171
+ "model.layers.9.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
172
+ "model.layers.9.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
173
+ "model.layers.9.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
174
+ "model.layers.9.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
175
+ "model.layers.11.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
176
+ "model.layers.12.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
177
+ "model.layers.12.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
178
+ "model.layers.12.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
179
+ "model.layers.14.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
180
+ "model.layers.16.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
181
+ "model.layers.16.input_layernorm.weight": "model-00002-of-00003.safetensors",
182
+ "model.layers.16.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
183
+ "model.layers.17.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
184
+ "model.layers.17.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
185
+ "model.layers.17.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
186
+ "model.layers.17.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
187
+ "model.layers.19.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
188
+ "model.layers.20.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
189
+ "model.layers.20.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
190
+ "model.layers.22.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
191
+ "model.layers.23.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
192
+ "model.layers.23.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
193
+ "model.layers.23.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
194
+ "model.layers.25.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
195
+ "model.layers.27.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
196
+ "model.layers.27.input_layernorm.weight": "model-00003-of-00003.safetensors",
197
+ "model.layers.27.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
198
+ "model.layers.28.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
199
+ "model.layers.28.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
200
+ "model.layers.30.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
201
+ "model.layers.32.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
202
+ "model.layers.32.input_layernorm.weight": "model-00003-of-00003.safetensors",
203
+ "model.layers.32.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
204
+ "model.layers.33.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
205
+ "model.layers.33.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
206
+ "model.layers.33.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
207
+ "model.layers.33.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
208
+ "model.layers.35.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
209
+ "model.layers.36.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
210
+ "model.layers.36.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
211
+ "model.layers.38.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
212
+ "model.layers.0.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
213
+ "model.layers.1.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
214
+ "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
215
+ "model.layers.2.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
216
+ "model.layers.2.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
217
+ "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
218
+ "model.layers.3.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
219
+ "model.layers.3.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
220
+ "model.layers.5.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
221
+ "model.layers.5.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
222
+ "model.layers.5.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
223
+ "model.layers.6.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
224
+ "model.layers.6.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
225
+ "model.layers.8.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
226
+ "model.layers.10.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
227
+ "model.layers.10.input_layernorm.weight": "model-00001-of-00003.safetensors",
228
+ "model.layers.10.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
229
+ "model.layers.11.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
230
+ "model.layers.11.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
231
+ "model.layers.13.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
232
+ "model.layers.14.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
233
+ "model.layers.14.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
234
+ "model.layers.16.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
235
+ "model.layers.18.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
236
+ "model.layers.18.input_layernorm.weight": "model-00002-of-00003.safetensors",
237
+ "model.layers.18.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
238
+ "model.layers.19.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
239
+ "model.layers.19.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
240
+ "model.layers.21.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
241
+ "model.layers.22.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
242
+ "model.layers.22.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
243
+ "model.layers.24.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
244
+ "model.layers.26.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
245
+ "model.layers.26.input_layernorm.weight": "model-00003-of-00003.safetensors",
246
+ "model.layers.26.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
247
+ "model.layers.27.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
248
+ "model.layers.27.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
249
+ "model.layers.29.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
250
+ "model.layers.30.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
251
+ "model.layers.30.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
252
+ "model.layers.32.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
253
+ "model.layers.34.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
254
+ "model.layers.34.input_layernorm.weight": "model-00003-of-00003.safetensors",
255
+ "model.layers.34.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
256
+ "model.layers.35.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
257
+ "model.layers.35.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
258
+ "model.layers.37.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
259
+ "model.layers.38.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
260
+ "model.layers.38.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
261
+ "model.layers.6.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
262
+ "model.layers.30.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
263
+ "model.layers.6.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
264
+ "model.layers.30.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
265
+ "model.layers.7.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
266
+ "model.layers.31.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
267
+ "model.layers.7.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
268
+ "model.layers.31.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
269
+ "model.layers.8.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
270
+ "model.layers.32.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
271
+ "model.layers.2.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
272
+ "model.layers.14.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
273
+ "model.layers.14.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
274
+ "model.layers.22.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
275
+ "model.layers.22.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
276
+ "model.layers.38.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
277
+ "model.layers.38.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
278
+ "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
279
+ "model.layers.1.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
280
+ "model.layers.3.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
281
+ "model.layers.3.input_layernorm.weight": "model-00001-of-00003.safetensors",
282
+ "model.layers.3.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
283
+ "model.layers.4.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
284
+ "model.layers.4.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
285
+ "model.layers.4.input_layernorm.weight": "model-00001-of-00003.safetensors",
286
+ "model.layers.4.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
287
+ "model.layers.5.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
288
+ "model.layers.5.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
289
+ "model.layers.7.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
290
+ "model.layers.7.input_layernorm.weight": "model-00001-of-00003.safetensors",
291
+ "model.layers.7.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
292
+ "model.layers.8.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
293
+ "model.layers.8.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
294
+ "model.layers.10.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
295
+ "model.layers.11.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
296
+ "model.layers.11.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
297
+ "model.layers.13.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
298
+ "model.layers.15.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
299
+ "model.layers.15.input_layernorm.weight": "model-00002-of-00003.safetensors",
300
+ "model.layers.15.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
301
+ "model.layers.16.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
302
+ "model.layers.16.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
303
+ "model.layers.18.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
304
+ "model.layers.19.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
305
+ "model.layers.19.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
306
+ "model.layers.21.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
307
+ "model.layers.23.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
308
+ "model.layers.23.input_layernorm.weight": "model-00002-of-00003.safetensors",
309
+ "model.layers.23.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
310
+ "model.layers.24.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
311
+ "model.layers.24.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
312
+ "model.layers.26.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
313
+ "model.layers.27.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
314
+ "model.layers.27.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
315
+ "model.layers.29.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
316
+ "model.layers.31.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
317
+ "model.layers.31.input_layernorm.weight": "model-00003-of-00003.safetensors",
318
+ "model.layers.31.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
319
+ "model.layers.32.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
320
+ "model.layers.32.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
321
+ "model.layers.34.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
322
+ "model.layers.35.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
323
+ "model.layers.35.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
324
+ "model.layers.37.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
325
+ "model.layers.39.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
326
+ "model.layers.39.input_layernorm.weight": "model-00003-of-00003.safetensors",
327
+ "model.layers.39.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
328
+ "model.norm.weight": "model-00003-of-00003.safetensors",
329
+ "lm_head.weight": "model-00003-of-00003.safetensors",
330
+ "model.layers.0.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
331
+ "model.layers.2.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
332
+ "model.layers.4.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
333
+ "model.layers.6.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
334
+ "model.layers.6.input_layernorm.weight": "model-00001-of-00003.safetensors",
335
+ "model.layers.7.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
336
+ "model.layers.7.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
337
+ "model.layers.8.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
338
+ "model.layers.9.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
339
+ "model.layers.10.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
340
+ "model.layers.10.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
341
+ "model.layers.12.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
342
+ "model.layers.13.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
343
+ "model.layers.13.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
344
+ "model.layers.15.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
345
+ "model.layers.17.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
346
+ "model.layers.17.input_layernorm.weight": "model-00002-of-00003.safetensors",
347
+ "model.layers.17.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
348
+ "model.layers.18.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
349
+ "model.layers.18.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
350
+ "model.layers.20.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
351
+ "model.layers.22.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
352
+ "model.layers.22.input_layernorm.weight": "model-00002-of-00003.safetensors",
353
+ "model.layers.22.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
354
+ "model.layers.23.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
355
+ "model.layers.23.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
356
+ "model.layers.25.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
357
+ "model.layers.26.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
358
+ "model.layers.26.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
359
+ "model.layers.28.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
360
+ "model.layers.30.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
361
+ "model.layers.30.input_layernorm.weight": "model-00003-of-00003.safetensors",
362
+ "model.layers.30.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
363
+ "model.layers.31.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
364
+ "model.layers.31.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
365
+ "model.layers.32.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
366
+ "model.layers.33.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
367
+ "model.layers.34.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
368
+ "model.layers.34.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
369
+ "model.layers.36.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
370
+ "model.layers.38.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
371
+ "model.layers.38.input_layernorm.weight": "model-00003-of-00003.safetensors",
372
+ "model.layers.38.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
373
+ "model.layers.39.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
374
+ "model.layers.39.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
375
+ "model.layers.10.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
376
+ "model.layers.34.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
377
+ "model.layers.10.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
378
+ "model.layers.34.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
379
+ "model.layers.0.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
380
+ "model.layers.2.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
381
+ "model.layers.4.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
382
+ "model.layers.6.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
383
+ "model.layers.6.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
384
+ "model.layers.7.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
385
+ "model.layers.7.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
386
+ "model.layers.9.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
387
+ "model.layers.11.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
388
+ "model.layers.11.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
389
+ "model.layers.11.input_layernorm.weight": "model-00001-of-00003.safetensors",
390
+ "model.layers.11.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
391
+ "model.layers.28.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
392
+ "model.layers.28.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
393
+ "model.layers.30.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
394
+ "model.layers.31.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
395
+ "model.layers.31.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
396
+ "model.layers.33.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
397
+ "model.layers.35.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
398
+ "model.layers.35.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
399
+ "model.layers.35.input_layernorm.weight": "model-00003-of-00003.safetensors",
400
+ "model.layers.35.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
401
+ "model.layers.36.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
402
+ "model.layers.36.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
403
+ "model.layers.38.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
404
+ "model.layers.39.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
405
+ "model.layers.39.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
406
+ "model.layers.16.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
407
+ "model.layers.16.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
408
+ "model.layers.24.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
409
+ "model.layers.24.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
410
+ "model.layers.11.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
411
+ "model.layers.12.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
412
+ "model.layers.12.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
413
+ "model.layers.14.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
414
+ "model.layers.14.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
415
+ "model.layers.15.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
416
+ "model.layers.15.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
417
+ "model.layers.17.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
418
+ "model.layers.19.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
419
+ "model.layers.19.input_layernorm.weight": "model-00002-of-00003.safetensors",
420
+ "model.layers.19.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
421
+ "model.layers.20.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
422
+ "model.layers.20.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
423
+ "model.layers.22.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
424
+ "model.layers.24.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
425
+ "model.layers.24.input_layernorm.weight": "model-00003-of-00003.safetensors",
426
+ "model.layers.24.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
427
+ "model.layers.25.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
428
+ "model.layers.25.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
429
+ "model.layers.27.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
430
+ "model.layers.35.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
431
+ "model.layers.12.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
432
+ "model.layers.36.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
433
+ "model.layers.36.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
434
+ "model.layers.0.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
435
+ "model.layers.15.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
436
+ "model.layers.20.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
437
+ "model.layers.20.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
438
+ "model.layers.25.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
439
+ "model.layers.25.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
440
+ "model.layers.15.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
441
+ "model.layers.39.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
442
+ "model.layers.39.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
443
+ "model.layers.18.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
444
+ "model.layers.18.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
445
+ "model.layers.23.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
446
+ "model.layers.19.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
447
+ "model.layers.19.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
448
+ "model.layers.26.self_attn.adaptive_phi": "model-00003-of-00003.safetensors"
449
+ }
450
+ }
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/modeling_evabyte.py ADDED
@@ -0,0 +1,912 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Tuple, Union
2
+ import math
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import torch.utils.checkpoint
6
+ from torch import nn
7
+ from torch.nn import CrossEntropyLoss
8
+ from transformers.activations import ACT2FN
9
+ from transformers.cache_utils import Cache
10
+ from transformers.modeling_outputs import (
11
+ BaseModelOutputWithPast,
12
+ CausalLMOutputWithPast,
13
+ )
14
+ from transformers.modeling_utils import PreTrainedModel
15
+
16
+ from .configuration_evabyte import EvaByteConfig
17
+ from .multibyte_decoding_evabyte import MultiByteDecodingMixin
18
+ try:
19
+ import triton
20
+ USE_TRITON_IMPL = True
21
+ from .eva import EvaAttention
22
+ from .eva_agg_kernel import triton_eva_agg_fwd
23
+ from .eva_prep_kv_kernel import triton_eva_prep_kv_fwd
24
+ except ImportError:
25
+ USE_TRITON_IMPL = False
26
+ print("WARNING: triton is not installed, using fallback EVA which might be slow and throw errors")
27
+ from .eva_pt_ref import EvaAttention
28
+ from .eva_cache import EvaCache, EvaStaticCacheForTriton
29
+
30
# Large negative fill value for masked attention logits: effectively -inf while
# remaining finite in fp32. NOTE(review): not referenced anywhere in this chunk —
# presumably consumed by the attention implementations; confirm before removing.
MASK_MIN_VALUE = -10e10
31
+
32
def prepare_eva_attention_mask(
    seq_len,
    device,
    chunk_size,
    window_size,
    use_cache=False,
    cache=None
):
    """
    Prepare attention masks for EVA.

    Returns a pair of boolean masks (True presumably marks *masked* positions —
    confirm against the EvaAttention consumers):
      * ``chunk_causal_mask``: visibility of chunk-level summaries per query row,
        shaped [1, 1, rows, num_chunks] after the final reshape.
      * ``window_causal_mask``: strictly-causal mask within a single window,
        shaped [1, 1, 1, window_size, window_size].

    Args:
        seq_len: number of (new) query positions being processed.
        device: device on which masks are allocated.
        chunk_size: EVA chunk (summary) granularity.
        window_size: EVA window size; assumed a multiple of chunk_size.
        use_cache: when True, account for tokens already held in ``cache``.
        cache: decoding cache exposing ``get_seq_length()``; required if
            ``use_cache`` is True.
    """
    chunk_causal_mask = None
    window_causal_mask = None
    if use_cache:
        cached_seq_len = cache.get_seq_length()
        total_seq_len = seq_len + cached_seq_len
        # cached_seq_len will be 0 during prefilling.
        # Pad the total length up to a whole number of windows.
        # padded_seq_len = chunk_size * math.ceil(total_seq_len / chunk_size)
        padded_seq_len = window_size * math.ceil(total_seq_len / window_size)
        num_chunks = padded_seq_len // chunk_size
    else:
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        assert seq_len % chunk_size == 0
        num_chunks = seq_len // chunk_size

        assert seq_len % window_size == 0

    # create causal mask
    ################################
    # generate chunked causal masks
    ################################
    # [b, h, j, c, c]
    chunks_per_window = window_size // chunk_size
    if num_chunks >= chunks_per_window:
        # Upper triangular over chunk indices: queries may not see summaries of
        # their own or later chunks.
        chunk_causal_mask = torch.ones(
            (chunk_size, num_chunks, num_chunks),
            device=device,
            dtype=torch.bool
        ).triu(0)

        num_blocks = num_chunks // chunks_per_window
        # Regroup chunks into window-sized blocks so the block diagonal (chunks
        # inside the query's own window) can be masked wholesale below.
        chunk_causal_mask = chunk_causal_mask.reshape(
            chunk_size,
            num_blocks,
            chunks_per_window,
            num_blocks,
            chunks_per_window
        ).transpose(-2, -3)

        block_diag_zero = (
            torch.eye(num_blocks, device=device, dtype=torch.bool)
            .unsqueeze(-1)
            .unsqueeze(-1)
            .unsqueeze(0)
        )

        # Mask out (set to True) all block-diagonal entries: intra-window chunk
        # summaries are presumably handled by the window-attention path instead.
        chunk_causal_mask = chunk_causal_mask.masked_fill(block_diag_zero, True)

        # Reshape back to original size
        chunk_causal_mask = (
            chunk_causal_mask
            .transpose(-2, -3)
            .reshape(chunk_size, num_chunks, num_chunks)
            .transpose(-2, -3)
            .reshape(chunk_size * num_chunks, num_chunks)
            .unsqueeze(0)
            .unsqueeze(0)
        )
    else:
        # Fewer chunks than one window: no whole window precedes any query, so a
        # plain causal pattern over chunks suffices.
        chunk_causal_mask = torch.ones(
            (1, 1, chunk_size, num_chunks, num_chunks),
            device=device,
            dtype=torch.bool,
        ).triu(0).transpose(-2, -3)  # [1, 1, c, j, c]
        chunk_causal_mask = chunk_causal_mask.reshape(
            1, 1, chunk_size * num_chunks, num_chunks
        )  # [1, 1, n, c]

    if use_cache:
        # Keep only rows corresponding to the new (uncached) query positions.
        chunk_causal_mask = chunk_causal_mask[..., cached_seq_len : cached_seq_len + seq_len, :]

    # Strictly-causal mask inside one window (True above the diagonal).
    window_causal_mask = torch.ones(
        (1, 1, 1, window_size, window_size),
        device=device
    ).triu(1).to(torch.bool)
    return (chunk_causal_mask, window_causal_mask)
120
+
121
def pad_to_multiple(tensor, multiple, dim=-2, value=0, create_mask=False, left_padding=False):
    """Pad `tensor` along `dim` (negative index only) up to the next multiple of `multiple`.

    Args:
        tensor: input tensor; dim 0 is assumed to be the batch dimension.
        multiple: target divisor for the padded length.
        dim: axis to pad, given as a negative index.
        value: fill value for the padded entries.
        create_mask: when True, also return a [batch, padded_len] bool mask that is
            True on real positions and False on padding.
        left_padding: pad at the front of `dim` instead of the back.

    Returns:
        The (possibly unchanged) tensor, or a `(tensor, mask)` pair when
        `create_mask` is True.
    """
    assert dim < 0  # only accept ``dim'' index in a reverse manner
    length = int(tensor.shape[dim])
    ratio = length / multiple
    if ratio.is_integer():
        # Already aligned: nothing to pad.
        if not create_mask:
            return tensor
        all_valid = torch.ones(
            size=(tensor.shape[0], tensor.shape[dim]), dtype=torch.bool, device=tensor.device
        )
        return tensor, all_valid
    pad_amount = math.ceil(ratio) * multiple - length
    # F.pad consumes (before, after) pairs starting from the LAST dim, so emit
    # zero pairs for every dim after the target one.
    trailing_pairs = (0, 0) * (-1 - dim)
    if left_padding:
        padded = F.pad(tensor, (*trailing_pairs, pad_amount, 0), value=value)
    else:
        padded = F.pad(tensor, (*trailing_pairs, 0, pad_amount), value=value)
    if not create_mask:
        return padded
    # assume dim 0 is the batch size
    mask = torch.ones(
        size=(padded.shape[0], padded.shape[dim]), dtype=torch.bool, device=padded.device
    )
    if left_padding:
        mask[:, :pad_amount] = False
    else:
        mask[:, -pad_amount:] = False
    return padded, mask
146
+
147
class EvaByteRMSNorm(nn.Module):
    """RMS normalization with an optional "unit offset" parameterization.

    When ``config.norm_add_unit_offset`` is set, the scale is stored as a
    zero-initialized delta and applied as ``(1 + weight)``; otherwise the scale
    is a one-initialized weight applied directly.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.fp32_ln = True
        self.variance_epsilon = config.rms_norm_eps
        self.add_unit_offset = config.norm_add_unit_offset
        init_fn = torch.zeros if self.add_unit_offset else torch.ones
        self.weight = nn.Parameter(init_fn(config.hidden_size))

    def forward(self, hidden_states):
        # Normalize in fp32 (or bf16 if fp32_ln is disabled) for stability,
        # then cast back to the input dtype.
        compute_dtype = torch.float32 if self.fp32_ln else torch.bfloat16
        x = hidden_states.to(compute_dtype)
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
        x = x * inv_rms
        scale = (1 + self.weight) if self.add_unit_offset else self.weight
        return (scale * x).type_as(hidden_states)
168
+
169
class EvaByteRotaryEmbedding(torch.nn.Module):
    """Caches cos/sin tables for rotary position embeddings.

    The table covers ``max_position_embeddings`` positions up front and is
    regrown lazily whenever a longer sequence is requested.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        exponents = torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim
        self.register_buffer("inv_freq", 1.0 / (self.base ** exponents), persistent=False)

        # Build the initial cache up to the configured maximum length.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build cached cos/sin tables covering positions [0, seq_len)."""
        self.max_seq_len_cached = seq_len
        positions = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        angles = torch.outer(positions, self.inv_freq)
        # Duplicate along the last dim so the table spans the full head size.
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached", table.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", table.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]; only its dtype and
        # device are consulted here.
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        if seq_len < self.max_seq_len_cached:
            cos_out = self.cos_cached[:seq_len]
            sin_out = self.sin_cached[:seq_len]
        else:
            cos_out = self.cos_cached
            sin_out = self.sin_cached

        return (
            cos_out.to(dtype=x.dtype),
            sin_out.to(dtype=x.dtype),
        )
213
+
214
+
215
+
216
class EvaByteLinearScalingRotaryEmbedding(EvaByteRotaryEmbedding):
    """EvaByteRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        # Must be set before super().__init__ triggers the first cache build.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Rebuild the cos/sin cache with positions shrunk by `scaling_factor`."""
        self.max_seq_len_cached = seq_len
        positions = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        positions = positions / self.scaling_factor

        angles = torch.outer(positions, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached", table.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", table.sin().to(dtype), persistent=False)
232
+
233
+
234
class EvaByteDynamicNTKScalingRotaryEmbedding(EvaByteRotaryEmbedding):
    """EvaByteRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        # Must be set before super().__init__ triggers the first cache build.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Rebuild the cos/sin cache, rescaling the base frequency when the
        requested length exceeds the trained context window."""
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            # NTK-aware base rescaling; note inv_freq is overwritten in place and,
            # matching the original code, is not restored for shorter sequences.
            ntk_ratio = (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            rescaled_base = self.base * ntk_ratio ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (rescaled_base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
            self.register_buffer("inv_freq", inv_freq, persistent=False)

        positions = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        angles = torch.outer(positions, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached", table.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", table.sin().to(dtype), persistent=False)
256
+
257
+
258
class EvaByteMLP(nn.Module):
    """Gated feed-forward block: ``down(act(gate(x)) * up(x))`` (SwiGLU-style)."""

    def __init__(self, config, layer_idx: int = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        # Three bias-free projections of the gated MLP.
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        """Apply the gated feed-forward transform; output has the same hidden size as `x`."""
        gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)
273
+
274
class EvaByteDecoderLayer(nn.Module):
    """One EvaByte transformer block: EVA self-attention then a gated MLP, each
    in a pre-RMSNorm residual branch (residual optionally accumulated in fp32)."""

    def __init__(self, config: EvaByteConfig, layer_idx: int = None):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.self_attn = EvaAttention(config=config, layer_idx=layer_idx)
        self.mlp = EvaByteMLP(config, layer_idx=layer_idx)
        self.input_layernorm = EvaByteRMSNorm(config)
        self.post_attention_layernorm = EvaByteRMSNorm(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cos: Optional[torch.Tensor] = None,
        sin: Optional[torch.Tensor] = None,
        multibyte_decoding: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Returns (hidden_states[, attn_weights][, present_key_value])."""
        # --- self-attention sub-layer (pre-norm residual) ---
        skip = hidden_states.float() if self.config.fp32_skip_add else hidden_states
        normed = self.input_layernorm(hidden_states)
        attn_out, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=normed,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cos=cos,
            sin=sin,
            multibyte_decoding=multibyte_decoding,
        )
        # Cast the (possibly fp32) residual sum back to the branch dtype.
        hidden_states = (skip + attn_out).to(attn_out.dtype)

        # --- feed-forward sub-layer (pre-norm residual) ---
        skip = hidden_states.float() if self.config.fp32_skip_add else hidden_states
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = (skip + mlp_out).to(mlp_out.dtype)

        outputs = (hidden_states, )
        if output_attentions:
            outputs += (self_attn_weights, )
        if use_cache:
            outputs += (present_key_value, )
        return outputs
330
+
331
class EvaBytePreTrainedModel(PreTrainedModel):
    """Shared PreTrainedModel base: config class, weight init, and
    gradient-checkpointing toggling for EvaByte models."""

    config_class = EvaByteConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["EvaByteDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"

    def _init_weights(self, module):
        """Normal(0, initializer_range) init for Linear/Embedding weights;
        zero biases and zero the padding embedding row."""
        std = getattr(self.config, "initializer_range", 0.02)
        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        # Only the decoder stack (EvaByteModel) owns the checkpointing flag.
        if isinstance(module, EvaByteModel):
            module.gradient_checkpointing = value
352
+
353
class EvaByteModel(EvaBytePreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`EvaByteDecoderLayer`]

    Args:
        config: EvaByteConfig
    """
    def __init__(self, config: EvaByteConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.max_position_embeddings = self.config.max_position_embeddings

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([EvaByteDecoderLayer(config, layer_idx=layer_idx) for layer_idx in range(config.num_hidden_layers)])
        self.norm = EvaByteRMSNorm(config)

        self.gradient_checkpointing = False
        # RoPE base frequency (theta), used by _init_rope below.
        self.rope = config.rope_theta
        # Initialize weights and apply final processing
        self.post_init()
        self._init_rope()

    def _init_rope(self):
        # Choose the rotary-embedding variant from config.rope_scaling:
        # None -> vanilla; "linear" / "dynamic" -> the corresponding scaled variant.
        if self.config.rope_scaling is None:
            self.rotary_emb = EvaByteRotaryEmbedding(self.head_dim,
                                                     max_position_embeddings=self.max_position_embeddings,
                                                     base=self.rope)
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = EvaByteLinearScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope)
            elif scaling_type == "dynamic":
                self.rotary_emb = EvaByteDynamicNTKScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope)
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def _helper_padding_mask(
        self,
        padding_mask,
        causal_mask
    ):
        # Symmetrize the padding mask (a padded row masks both its queries and
        # its keys), then OR it into the causal mask. True = masked.
        padding_mask = torch.logical_or(padding_mask, padding_mask.transpose(-1, -2))
        return torch.logical_or(padding_mask, causal_mask)

    def _prepare_eva_generation_attn_mask_triton(
        self,
        attention_mask,
        input_ids,
        use_cache,
        past_key_values
    ):
        """Build masks consumed by the Triton EVA kernels.

        NOTE(review): the decode branch returns a 3-tuple
        (cur_s_mask, cur_rf_mask, rfa_chunks_dummy_mask) while the prefill path
        returns a 2-tuple (s_mask, rf_mask) or (None, None) — presumably the
        Triton EvaAttention consumer handles both arities; confirm.
        """
        batch_size, seq_len = input_ids.shape
        if use_cache and past_key_values.get_seq_length() > 0:
            # decoding phase
            if past_key_values.rf_mask[0] is not None:
                cur_rf_mask = torch.zeros(
                    (batch_size, 1, seq_len, 1),
                    dtype=past_key_values.rf_mask[0].dtype,
                    device=past_key_values.rf_mask[0].device
                )
            else:
                cur_rf_mask = None

            if past_key_values.s_mask[0] is not None:
                cur_s_mask = torch.zeros(
                    (batch_size, 1, seq_len, 1),
                    dtype=past_key_values.s_mask[0].dtype,
                    device=past_key_values.s_mask[0].device
                )
            else:
                cur_s_mask = None

            seen_tokens = past_key_values.get_seq_length()
            if seen_tokens <= self.config.window_size:
                # Still inside the first window: no chunk-level summaries exist yet.
                rfa_chunks_dummy_mask = None
            else:
                if cur_s_mask is not None:
                    chunks_per_window = int(self.config.window_size // self.config.chunk_size)
                    # the ongoing decoding step would be (seen_seq_len + 1)-th token
                    num_windows_seen_so_far = seen_tokens // self.config.window_size
                    rfa_chunks_dummy_mask = torch.zeros(
                        (batch_size, 1, seq_len, num_windows_seen_so_far * chunks_per_window),
                        dtype=past_key_values.s_mask[0].dtype,
                        device=past_key_values.s_mask[0].device
                    )
                else:
                    rfa_chunks_dummy_mask = None
            # rf_mask and cur_mask are 0s because we do not want to mask them
            return (cur_s_mask, cur_rf_mask, rfa_chunks_dummy_mask)

        # Prefill phase: only build explicit masks when real padding is present.
        if attention_mask is not None and torch.any(attention_mask == 0.0):
            # Pad the HF-style mask (1 = keep, 0 = padding) to a whole number of windows.
            padded_attention_mask = pad_to_multiple(
                attention_mask,
                self.config.window_size,
                dim=-1,
                value=0,
                create_mask=False,
                left_padding=False
            )
            # convert 0 -> padding to 1 -> padding
            padded_rf_mask = ~padded_attention_mask.unsqueeze(1).unsqueeze(-1).to(torch.bool) # [b, 1, n, 1]
            # [b, 1, w, j, 1]
            padded_w_attn_mask = padded_rf_mask.reshape(batch_size, 1, -1, self.config.window_size, 1).to(torch.bool)
            # [b, 1, w, j, 1] [b, 1, w, 1, j] -> [b, 1, w, j, j]
            w_padding_mask = torch.logical_or(padded_w_attn_mask, padded_w_attn_mask.transpose(-1, -2))
            w_causal_mask = torch.ones(
                (1, 1, 1, self.config.window_size, self.config.window_size),
                device=input_ids.device
            ).triu(1).to(torch.bool)
            s_mask = torch.logical_or(w_padding_mask, w_causal_mask)
            s_mask = s_mask.reshape(batch_size, 1, -1, self.config.window_size)
            s_mask = s_mask[..., :seq_len, :]
            # negate the attention mask to get the padding mask
            rf_mask = ~attention_mask.unsqueeze(1).unsqueeze(-1).to(torch.bool) # [b, 1, n, 1]
            return (s_mask, rf_mask)
        else:
            return (None, None)

    def _prepare_eva_generation_attn_mask(
        self,
        attention_mask,
        input_ids,
        use_cache,
        past_key_values
    ):
        """Build masks for the PyTorch-reference EVA attention.

        Returns (prev_window_mask, cur_window_mask, chunk_causal_mask, rf_mask);
        True = masked throughout.
        """
        batch_size, seq_len = input_ids.shape
        if use_cache and past_key_values.get_seq_length() > 0:
            # decoding phase
            if past_key_values.rf_mask[0] is not None:
                rf_mask = torch.zeros(
                    (batch_size, 1, seq_len, 1),
                    dtype=past_key_values.rf_mask[0].dtype,
                    device=past_key_values.rf_mask[0].device
                )
            else:
                rf_mask = None

            cur_causal_mask = torch.zeros(
                (batch_size, 1, seq_len, 1),
                dtype=torch.bool,
                device=input_ids.device
            )

            chunk_causal_mask = torch.ones(
                (batch_size, 1, seq_len, 1),
                dtype=torch.bool,
                device=input_ids.device
            )
            # chunk_causal_mask are 1s because we will mask them by default and
            # will be unmasked when the current singleton attention is processed over
            return (None, cur_causal_mask, chunk_causal_mask, rf_mask)

        # Prefill: derive the chunk-level causal mask, then per-window masks.
        true_num_chunks = seq_len // self.config.chunk_size
        chunk_causal_mask, _ = prepare_eva_attention_mask(
            seq_len,
            input_ids.device,
            self.config.chunk_size,
            self.config.window_size,
            use_cache=use_cache,
            cache=past_key_values
        )
        chunk_causal_mask = chunk_causal_mask[..., :seq_len, :true_num_chunks]
        if attention_mask is not None and torch.any(attention_mask == 0.0):
            # convert 0 -> padding to 1 -> padding
            rf_mask = ~attention_mask.unsqueeze(1).unsqueeze(-1).to(torch.bool) # [b, 1, n, 1]
        else:
            rf_mask = None

        if seq_len < self.config.window_size:
            # Whole sequence fits in one (partial) window.
            cur_window_mask = torch.ones(
                (1, 1, seq_len, seq_len),
                device=input_ids.device
            ).triu(1).to(torch.bool)
            if rf_mask is not None:
                cur_window_mask = self._helper_padding_mask(rf_mask, cur_window_mask)
            prev_window_mask = None
        else:
            if seq_len % self.config.window_size == 0:
                # Exact number of full windows; no trailing partial window.
                num_windows = seq_len // self.config.window_size
                cur_window_mask = None
                prev_window_mask = torch.ones(
                    (1, 1, num_windows, self.config.window_size, self.config.window_size),
                    device=input_ids.device
                ).triu(1).to(torch.bool)
                if rf_mask is not None:
                    prev_rf_mask = rf_mask.reshape(batch_size, 1, -1, self.config.window_size, 1)
                    prev_window_mask = self._helper_padding_mask(prev_rf_mask, prev_window_mask)
            else:
                # Full windows plus a trailing partial window of remainder_tokens.
                num_windows = seq_len // self.config.window_size
                remainder_tokens = seq_len % self.config.window_size
                cur_window_mask = torch.ones(
                    (1, 1, remainder_tokens, remainder_tokens),
                    device=input_ids.device
                ).triu(1).to(torch.bool)
                prev_window_mask = torch.ones(
                    (1, 1, num_windows, self.config.window_size, self.config.window_size),
                    device=input_ids.device
                ).triu(1).to(torch.bool)
                if rf_mask is not None:
                    prev_rf_mask, cur_rf_mask = torch.split(rf_mask, [seq_len - remainder_tokens, remainder_tokens], dim=-2)
                    cur_window_mask = self._helper_padding_mask(cur_rf_mask, cur_window_mask)
                    prev_rf_mask = prev_rf_mask.reshape(batch_size, 1, -1, self.config.window_size, 1)
                    prev_window_mask = self._helper_padding_mask(prev_rf_mask, prev_window_mask)

        return (prev_window_mask, cur_window_mask, chunk_causal_mask, rf_mask)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        multibyte_decoding: Optional[bool] = None,
    ) -> Tuple:
        """Run the decoder stack.

        NOTE(review): `input_ids.shape` is read unconditionally below, so the
        embeds-only calling convention (input_ids=None, inputs_embeds given)
        would crash despite passing the XOR check — confirm callers always
        supply input_ids.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else self.config.output_hidden_states)
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if self.gradient_checkpointing and self.training and use_cache:
            raise ValueError("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")

        batch_size, seq_len = input_ids.shape
        #### Step 0. Hack
        if (not self.training) and (not use_cache) and (not multibyte_decoding):
            # forward-only inference mode.
            # We tweak use_cache to be True to reuse code for generation
            use_cache = True
            device = input_ids.device if input_ids is not None else None
            if position_ids is None:
                position_ids = torch.arange(0, seq_len, device=device, dtype=int).reshape(1, -1).expand(batch_size, -1)

        #### Step 1. Prepare caches if in inference mode
        if use_cache:
            if past_key_values is not None:
                assert isinstance(past_key_values, Cache)
            else:
                # Fresh cache: static Triton-backed cache when kernels are
                # available, otherwise the PyTorch reference cache.
                if not USE_TRITON_IMPL:
                    past_key_values = EvaCache()
                else:
                    past_key_values = EvaStaticCacheForTriton(
                        input_ids.shape[0],
                        self.config.num_attention_heads,
                        self.config.window_size,
                        self.config.hidden_size // self.config.num_attention_heads,
                        self.config.num_hidden_layers,
                        self.embed_tokens.weight.dtype,
                        self.embed_tokens.weight.device,
                    )

        #### Step 2. Build attention masks (skipped for multibyte decoding,
        # which supplies its own mask via `attention_mask`).
        if not multibyte_decoding:
            if use_cache:
                if USE_TRITON_IMPL:
                    causal_mask = self._prepare_eva_generation_attn_mask_triton(
                        attention_mask,
                        input_ids,
                        use_cache,
                        past_key_values
                    )
                else:
                    causal_mask = self._prepare_eva_generation_attn_mask(
                        attention_mask,
                        input_ids,
                        use_cache,
                        past_key_values
                    )
            else:
                assert self.training
                assert seq_len % self.config.window_size == 0, "Training is only tested for sequences that are a multiple of window_size"
                # for training, we need to pass in the attention mask
                # usually calculated by _prepare_training_attn_mask()
                causal_mask = attention_mask
        else:
            assert use_cache
            causal_mask = attention_mask

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        max_seq_length = past_seen_tokens + inputs_embeds.shape[1]

        hidden_states = inputs_embeds

        if position_ids is None:
            assert not use_cache, "during decoding we must explicitly pass position_ids to the model call"
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(past_seen_tokens, max_seq_length, device=device, dtype=int).reshape(1, -1).expand(batch_size, -1)

        #### Step 3. Gather per-position RoPE tables once, shared by all layers.
        cos, sin = self.rotary_emb(hidden_states, seq_len=max_seq_length)
        assert len(cos.shape) == 2, f"cos should be of shape (max_seq_len, head_dim), got {cos.shape} instead"
        assert sin.shape == cos.shape, f"sin should be of shape (max_seq_len, head_dim), got {sin.shape} instead"
        assert len(position_ids.shape) == 2, f"position_ids should be of 2D, got {position_ids.shape} instead"
        cos = cos[position_ids, :]
        sin = sin[position_ids, :]
        # Insert a broadcast head dim: [b, 1, s, head_dim].
        cos = cos.unsqueeze(1)
        sin = sin.unsqueeze(1)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states, )

            if self.gradient_checkpointing and self.training:
                layer_outputs = torch.utils.checkpoint.checkpoint(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cos,
                    sin,
                    multibyte_decoding,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cos=cos,
                    sin=sin,
                    multibyte_decoding=multibyte_decoding,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                # Cache position in the layer output shifts by one when
                # attention weights are also returned.
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1], )

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states, )

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
741
+
742
+
743
class EvaByteForCausalLM(EvaBytePreTrainedModel, MultiByteDecodingMixin):
    """Causal byte-level LM: EvaByteModel decoder plus an LM head that may emit
    `num_pred_heads` byte predictions per position (multibyte decoding)."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        EvaBytePreTrainedModel.__init__(self, config)

        self.model = EvaByteModel(config)
        self.vocab_size = config.vocab_size
        # define multibyte prediction heads
        if hasattr(config, "num_pred_heads") and config.num_pred_heads > 1:
            # Single fused projection; reshaped to per-head logits in forward().
            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size * config.num_pred_heads, bias=False)
        else:
            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        return_all_pred_logits: Optional[bool] = None,
        multibyte_decoding: Optional[bool] = None) -> Union[Tuple, CausalLMOutputWithPast]:
        """Compute logits (and an unreduced per-token loss when `labels` given).

        NOTE(review): the loss does NOT shift logits/labels despite the
        `shift_*` variable names — labels are presumably pre-aligned by the
        caller (as is typical for multibyte heads); confirm against training
        code. `reduction="none"` means `loss` is a per-token vector.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else self.config.output_hidden_states)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is None:
            assert past_key_values is None

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            multibyte_decoding=multibyte_decoding,
        )

        hidden_states = outputs[0]

        logits = self.lm_head(hidden_states)
        if self.config.fp32_logits:
            logits = logits.float()

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(reduction="none")
            if hasattr(self.config, "num_pred_heads") and self.config.num_pred_heads > 1:
                # Unfold the fused head dim, then flatten all heads/positions.
                shift_logits = logits.view(logits.shape[0], logits.shape[1], self.config.num_pred_heads, self.config.vocab_size)
                # shift_logits = shift_logits.view(-1, logits.shape[1] * self.config.num_pred_heads, self.config.vocab_size)
                shift_logits = shift_logits.view(-1, self.config.vocab_size)
            else:
                shift_logits = logits.view(-1, self.config.vocab_size)
            shift_labels = labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if hasattr(self.config, "num_pred_heads") and self.config.num_pred_heads > 1:
            # [b, s, heads, vocab]; by default expose only head 0 (next byte).
            all_pred_logits = logits.reshape(logits.shape[0], logits.shape[1], self.config.num_pred_heads, self.config.vocab_size)

            if return_all_pred_logits:
                logits = all_pred_logits
            else:
                logits = all_pred_logits[..., 0, :]

        if not return_dict:
            output = (logits, ) + outputs[1:]
            return (loss, ) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


    def prepare_inputs_for_generation(self,
                                      input_ids,
                                      past_key_values=None,
                                      attention_mask=None,
                                      inputs_embeds=None,
                                      use_cache=True,
                                      **kwargs):
        """Slice inputs and derive position_ids for each generation step.

        Shapes per phase:
          prefill:   input_ids b x s; attention_mask None or b x s; position_ids b x s
          token gen: input_ids b x 1; attention_mask b x 1 x s; position_ids b x 1
        """
        past_length = 0
        if past_key_values is not None:
            assert isinstance(past_key_values, Cache)
            past_length = past_key_values.get_seq_length()

            # Keep only the tokens not yet consumed by the cache.
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # Derive positions from the padding mask (padded slots pinned to 1).
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1]:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        # must initialize position_ids at each step during GPU inference
        assert position_ids is not None
        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        # Beam-search reordering.
        # NOTE(review): iterates legacy tuple-of-tuples caches, while the rest
        # of this file asserts `Cache` objects — presumably only reachable with
        # legacy caches; confirm before relying on beam search.
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (tuple(
                past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), )
        return reordered_past
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/multibyte_decoding_evabyte.py ADDED
@@ -0,0 +1,881 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
# The implementation of multibyte decoding is largely adapted from
3
+ # Medusa decoding: https://github.com/FasterDecoding/Medusa
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from transformers.generation.stopping_criteria import (
7
+ MaxLengthCriteria,
8
+ StoppingCriteriaList,
9
+ )
10
+ from typing import Union, List
11
+ from .eva_cache import EvaStaticCacheForTriton
12
+ from .eva_prep_kv_kernel import triton_eva_prep_kv_fwd
13
+
14
class MultibyteEosTokenCriteria:
    """
    Stopping criterion that halts generation as soon as any configured
    end-of-sequence token shows up among the last `new_tokens` tokens.

    Adapted from
    https://github.com/huggingface/transformers/blob/main/src/transformers/generation/stopping_criteria.py#L446
    By default, it uses the `model.generation_config.eos_token_id`.

    Args:
        eos_token_ids (`Union[int, List[int]]`):
            The id(s) of the *end-of-sequence* token.
    """

    def __init__(self, eos_token_ids: Union[int, List[int]]):
        # Normalize a single id into a list so __call__ can iterate uniformly.
        self.eos_token_ids = [eos_token_ids] if isinstance(eos_token_ids, int) else eos_token_ids

    def __call__(self, input_ids: torch.LongTensor, new_tokens: int) -> bool:
        # Only the trailing `new_tokens` positions are inspected.
        tail = input_ids[:, input_ids.shape[-1] - new_tokens:]
        return any(bool(torch.any(tail == eos_id)) for eos_id in self.eos_token_ids)
40
+
41
def build_tree(spec):
    """Expand a per-depth children spec into a flat list of tree paths.

    ``spec[d][i]`` is the number of children of the i-th node at depth ``d``
    (nodes without an entry get zero children). Each node is encoded as the
    tuple of child indices on the path from the root; the root itself (the
    empty tuple) is excluded from the result. Nodes are listed depth by
    depth, preserving parent order within each depth.
    """
    levels = [[()]]  # depth 0 holds only the root
    for spec_list in spec:
        children = []
        for parent_idx, parent in enumerate(levels[-1]):
            fanout = spec_list[parent_idx] if parent_idx < len(spec_list) else 0
            children.extend(parent + (k,) for k in range(fanout))
        levels.append(children)
    return [path for level in levels for path in level if path]
62
+
63
# Pre-built draft trees (flat lists of root-relative paths) for the
# multibyte speculative-decoding heads; constructed once at import time.
# 95-node tree: wider/deeper, more verification compute per step.
evabyte_7b_95 = build_tree(
    [
        [10],
        [10, 8, 2, 2, 1, 1],
        [10, 4, 2, 1, 0, 0, 0, 0, 0, 0, 2, 1, 1, 0, 0, 0, 0, 0, 1],
        [8, 2, 2, 1, 0, 0, 0, 0, 0, 0, 1],
        [6, 2, 1, 1],
        [4, 2, 1, 1],
        [4, 2, 1],
    ]
)
# 31-node tree: a smaller, cheaper alternative.
evabyte_7b_31 = build_tree(
    [
        [4],
        [3, 2, 1, 1],
        [3, 2, 1, 1],
        [2, 1, 1],
        [2, 1],
        [2, 1],
        [2, 1],
    ]
)
TOPK = 10 # topk for sparse tree (10 is a placeholder and it is sufficient)
86
+
87
def pad_path(path, length, pad_value=-2):
    """Right-pad `path` with `pad_value` until it holds `length` entries.

    A new list is returned; the input is never mutated. If `path` is
    already at least `length` long it is returned unchanged in content
    (no truncation happens).

    Example:
        >>> pad_path([1, 2, 3], 5)
        [1, 2, 3, -2, -2]
    """
    shortfall = length - len(path)
    return path + [pad_value] * shortfall
108
+
109
def reset_past_key_values(passed_key_values):
    """Zero out the current-length counters of a legacy KV cache, in place.

    Intended for baseline-model evaluation: for every layer, both the key
    entry and the value entry have their ``current_length`` tensor filled
    with zero so the cache can be reused from a clean state.

    Args:
        passed_key_values: per-layer key/value pairs, each element carrying
            a ``current_length`` tensor.

    Returns:
        The same cache object with all lengths reset.
    """
    for layer_cache in passed_key_values:
        for kv in (layer_cache[0], layer_cache[1]):
            kv.current_length.fill_(0)
    return passed_key_values
127
+
128
def get_nucleus_one_token(logit, temperature, top_p):
    """
    Performs token sampling based on the nucleus (top-p) sampling method.

    This function selects a token from a given logit distribution using the nucleus sampling strategy.
    It allows for more controlled and diverse generation compared to traditional top-k sampling.

    Args:
        logit (torch.Tensor): The logits from a language model output, expected to be a 2D tensor (BxC).
        temperature (float): A temperature parameter to control the randomness in sampling.
            Higher values increase diversity, lower values make selections more deterministic.
        top_p (float): The cumulative probability threshold for nucleus sampling.
            It controls the size of the set of high-probability tokens to consider for sampling.

    Returns:
        torch.Tensor: A tensor containing the indices of the sampled tokens.
    """
    if top_p >= 1:
        # Degenerate case: the nucleus covers the whole vocabulary, so plain
        # temperature sampling suffices.
        return torch.multinomial(F.softmax(logit / temperature, dim=-1), 1)
    logit = logit / temperature
    probs = torch.softmax(logit, dim=-1)
    # NOTE: despite the name, `sorted_logits` holds sorted *probabilities*.
    sorted_logits, sorted_indices = torch.sort(probs, descending=True)
    cum_probs = torch.cumsum(sorted_logits, dim=-1)
    # Mask tokens outside the nucleus; the shift-by-one keeps the first token
    # that crosses the top-p threshold inside the nucleus.
    sorted_indices_to_remove = cum_probs > top_p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0
    # Map the removal mask back from sorted order to vocabulary order.
    indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
    logit[indices_to_remove] = float('-inf')
    sampled_tokens = torch.multinomial(F.softmax(logit, dim=-1), 1)
    return sampled_tokens
158
+
159
def get_typical_one_token(logit, temperature, posterior_threshold, posterior_alpha):
    """
    Implements token sampling based on the typical sampling method.

    This function selects a token from a given logit distribution using the typical sampling strategy,
    aiming to balance between diversity and likelihood in a more nuanced way compared to traditional methods.

    Args:
        logit (torch.Tensor): The logits from a language model output, expected to be a 2D tensor.
        temperature (float): A parameter to control the randomness in sampling.
            Higher values increase diversity, lower values make selections more deterministic.
        posterior_threshold (float): A threshold to decide the lower bound of probabilities to be considered for sampling.
        posterior_alpha (float): A scaling factor applied to the entropy-based adaptive threshold.

    Returns:
        torch.Tensor: A tensor containing the indices of the sampled tokens.
    """
    logit = logit / temperature
    probs = torch.softmax(logit, dim=-1)
    # Shannon entropy of the scaled distribution (1e-5 avoids log(0)).
    entropy = -torch.sum(
        probs * torch.log(probs + 1e-5), dim=-1
    )
    # Adaptive per-row cutoff: the fixed floor, capped by an entropy-based
    # bound so low-entropy (confident) rows get a tighter cutoff.
    threshold = torch.minimum(
        torch.ones_like(entropy) * posterior_threshold,
        torch.exp(-entropy) * posterior_alpha,
    )
    # Drop tokens whose probability falls below the adaptive threshold.
    indices_to_remove = probs < threshold.unsqueeze(-1)
    logit[indices_to_remove] = float('-inf')
    sampled_tokens = torch.multinomial(F.softmax(logit, dim=-1), 1)
    return sampled_tokens
189
+
190
+
191
+
192
def generate_medusa_buffers(medusa_choices, device="cuda"):
    """
    Generate buffers for the Medusa structure based on the provided choices.

    Parameters:
    - medusa_choices (list): A nested list representing tree in the Medusa structure.
    - device (str): Device to which the tensors should be moved. Default is "cuda".

    Returns:
    - dict: A dictionary containing buffers related to the Medusa structure.
    """

    # Sort the medusa_choices based on their lengths and then their values
    sorted_medusa_choices = sorted(medusa_choices, key=lambda x: (len(x), x))
    # +1 accounts for the root token at flat position 0
    medusa_len = len(sorted_medusa_choices) + 1

    # Initialize depth_counts to keep track of how many choices have a particular depth
    depth_counts = [0] * max([len(path) for path in sorted_medusa_choices])
    for path in sorted_medusa_choices:
        depth_counts[len(path) - 1] += 1

    # Create the attention mask for Medusa: each node may attend only to
    # itself (the eye), the root (column 0), and its ancestors.
    medusa_attn_mask = torch.eye(medusa_len, medusa_len)
    medusa_attn_mask[:, 0] = 1
    start = 0
    for i in range(len(depth_counts)):
        for j in range(depth_counts[i]):
            cur_medusa_choice = sorted_medusa_choices[start + j]
            # retrieve ancestor position
            if len(cur_medusa_choice) == 1:
                # depth-1 nodes only have the root ancestor, already unmasked
                continue
            ancestor_idx = []
            for c in range(len(cur_medusa_choice) - 1):
                ancestor_idx.append(sorted_medusa_choices.index(cur_medusa_choice[:c+1]) + 1)
            medusa_attn_mask[j + start + 1, ancestor_idx] = 1
        start += depth_counts[i]

    # Generate tree indices for the Medusa structure: maps each flat node to
    # its position in the [1 + TOPK * n_heads] concatenated candidate list.
    medusa_tree_indices = torch.zeros(medusa_len, dtype=torch.long)
    medusa_tree_indices[0] = 0
    start = 0
    for i in range(len(depth_counts)):
        for j in range(depth_counts[i]):
            cur_medusa_choice = sorted_medusa_choices[start + j]
            medusa_tree_indices[start + j + 1] = cur_medusa_choice[-1] + TOPK * i + 1
        start += depth_counts[i]

    # Generate position IDs for the Medusa structure (position id == depth;
    # the root stays at 0).
    medusa_position_ids = torch.zeros(medusa_len, dtype=torch.long)
    start = 0
    for i in range(len(depth_counts)):
        medusa_position_ids[start + 1: start + depth_counts[i] + 1] = i + 1
        start += depth_counts[i]

    # Generate retrieval indices for Medusa structure verification: one row
    # per root-to-leaf path. Iterating in reverse (longest paths first) lets
    # shorter prefixes be skipped because they are covered by a longer path.
    retrieve_indices_nest = []
    retrieve_paths = []
    for i in range(len(sorted_medusa_choices)):
        cur_medusa_choice = sorted_medusa_choices[-i-1]
        retrieve_indice = []
        if cur_medusa_choice in retrieve_paths:
            continue
        else:
            for c in range(len(cur_medusa_choice)):
                retrieve_indice.append(sorted_medusa_choices.index(cur_medusa_choice[:c+1]))
                retrieve_paths.append(cur_medusa_choice[:c+1])
        retrieve_indices_nest.append(retrieve_indice)
    max_length = max([len(x) for x in retrieve_indices_nest])
    # pad_path fills with -2; after the +1 shift below padding becomes -1
    retrieve_indices = [pad_path(path, max_length) for path in retrieve_indices_nest]
    retrieve_indices = torch.tensor(retrieve_indices, dtype=torch.long)
    # Shift by one and prepend a zero column so every path starts at the root.
    retrieve_indices = retrieve_indices + 1
    retrieve_indices = torch.cat([torch.zeros((retrieve_indices.shape[0], 1), dtype=torch.long), retrieve_indices], dim=1)

    # Aggregate the generated buffers into a dictionary
    medusa_buffers = {
        "medusa_attn_mask": medusa_attn_mask.unsqueeze(0).unsqueeze(0),
        "tree_indices": medusa_tree_indices,
        "medusa_position_ids": medusa_position_ids.unsqueeze(0),
        "retrieve_indices": retrieve_indices,
    }

    # Move the tensors in the dictionary to the specified device
    medusa_buffers = {
        k: v.clone().to(device)
        if isinstance(v, torch.Tensor)
        else torch.tensor(v, device=device)
        for k, v in medusa_buffers.items()
    }
    return medusa_buffers
281
+
282
def generate_candidates(
    medusa_logits,
    logits,
    tree_indices,
    retrieve_indices,
    temperature = 0,
    posterior_threshold=0.3,
    posterior_alpha = 0.09,
    top_p=0.8,
    sampling = 'typical',
    fast = False
):
    """Assemble the flattened tree of draft-token candidates for one step.

    The next base token comes from `logits` (greedy when `temperature == 0`
    or `fast`, otherwise typical/nucleus sampling); speculative tokens are
    the top-`TOPK` predictions of each medusa head. `tree_indices` arranges
    them into the tree layout and `retrieve_indices` unflattens the tree
    into one row per root-to-leaf path.

    Returns:
        tuple: (tree_candidate_ids with a leading batch dim of 1,
        unflattened_candidate_ids with one row per candidate path).
    """
    # Say we have 3 heads, and the top-4 for each head are:
    # [10, 3, 8, 4]
    # [9, 5, 1, 6]
    # [7, 16, 3, 2]

    # candidates_id = 10
    if temperature == 0 or fast:
        candidates_ids = torch.argmax(logits[:, -1]).unsqueeze(0)
    else:
        if sampling == 'typical':
            candidates_ids = get_typical_one_token(logits[:, -1], temperature, posterior_threshold, posterior_alpha).squeeze(0)
        elif sampling == 'nucleus':
            candidates_ids = get_nucleus_one_token(logits[:, -1], temperature, top_p).squeeze(0)
        else:
            raise NotImplementedError

    # this calculates the top-k medusa logits
    # candidates_medusa_id = [
    #     [9, 5, 1, 6]
    #     [7, 16, 3, 2]
    # ]
    candidates_medusa_ids = torch.topk(medusa_logits[:, 0, -1], TOPK, dim=-1).indices

    # [10, 9, 5, 1, 6, 7, 16, 3, 2]
    candidate_ids = torch.cat([candidates_ids, candidates_medusa_ids.view(-1)], dim=-1)

    # based on the pre-defined tree_indices, select the corresponding candidates
    # if we select top-2 and top-3 for the two heads (we select top-1 for the first head):
    # tree_candidates = [10, 9, 5, 7, 16, 3, 7, 16, 3]
    tree_candidate_ids = candidate_ids[tree_indices]

    # tree_candidate_ids = [10, 9, 5, 7, 16, 3, 7, 16, 3, 0]
    # Sometimes the tree_indices are padded, so we append a zero here
    # so that all padded indices select the appended zero.
    tree_candidate_ids_ext = torch.cat(
        [
            tree_candidate_ids,
            torch.zeros((1), dtype=torch.long, device=tree_candidate_ids.device)
        ],
        dim=0
    )
    # [[10, 9, 7], [10, 9, 16], [10, 9, 3], [10, 5, 7], [10, 5, 16], [10, 5, 3]]
    # (padded retrieve indices are -1 and therefore pick the appended zero)
    unflattened_candidate_ids = tree_candidate_ids_ext[retrieve_indices]

    # restore the batch dimension expected by the forward pass
    tree_candidate_ids = tree_candidate_ids.unsqueeze(0)

    return tree_candidate_ids, unflattened_candidate_ids
341
+
342
def get_nucleus_posterior_mask(logits, candidates, temperature, top_p):
    """
    Generates a posterior mask for token candidates using nucleus (top-p) sampling.

    This function applies nucleus sampling to a set of logits, and then generates a mask indicating
    which candidate tokens are selected. It adapts the sampling strategy to accommodate for
    temperature scaling and cumulative probability thresholding.

    Args:
        logits (torch.Tensor): A tensor of logits from a language model output.
        candidates (torch.Tensor): A tensor of candidate tokens to compare against sampled tokens.
        temperature (float): A parameter to scale the logits, controlling randomness in sampling.
        top_p (float): The cumulative probability threshold for nucleus sampling.

    Returns:
        torch.Tensor: A posterior mask indicating which candidate tokens match the sampled tokens.
    """
    # adapted from https://github.com/huggingface/transformers/blob/18a879f47576822aa1a5c49aecb27d89bfa5fa69/examples/run_generation.py#L79

    # Apply temperature
    # (the last position has no candidate token to verify, hence [:, :-1])
    logits = logits[:, :-1] / temperature
    n_samples, n_tokens = logits.shape[0], logits.shape[1]
    # Flatten (path, position) pairs so each row is sampled independently.
    logits = logits.view(n_samples*n_tokens, -1)
    if top_p >= 1:
        # Nucleus covers the whole vocabulary: plain temperature sampling.
        sampled_tokens = torch.multinomial(F.softmax(logits, dim=-1), 1)
        sampled_tokens = sampled_tokens.view(n_samples, n_tokens)
        posterior_mask = (candidates[:, 1:] == sampled_tokens).int()
        return posterior_mask
    # Convert to probabilities (softmax)
    probs = F.softmax(logits, dim=-1)
    # Sort the probabilities
    # NOTE: despite the name, `sorted_logits` holds sorted probabilities.
    sorted_logits, sorted_indices = torch.sort(probs, descending=True)

    # Compute cumulative probabilities
    cum_probs = torch.cumsum(sorted_logits, dim=-1)

    # Create mask for the top-p nucleus; the shift keeps the first token that
    # crosses the threshold inside the nucleus.
    sorted_indices_to_remove = cum_probs > top_p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0

    # Map the removal mask back from sorted order to vocabulary order.
    indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)


    # Remove low-probability tokens
    logits[indices_to_remove] = float('-inf')
    # Sample from the remaining tokens
    sampled_tokens = torch.multinomial(F.softmax(logits, dim=-1), 1)
    sampled_tokens = sampled_tokens.view(n_samples, n_tokens)
    # Create a mask for selected tokens
    posterior_mask = (candidates[:, 1:] == sampled_tokens).int()

    return posterior_mask
395
+
396
def get_typical_posterior_mask(logits, candidates, temperature, posterior_threshold, posterior_alpha):
    """
    Generates a posterior mask for token candidates using typical sampling.

    Args:
        logits (torch.Tensor): A tensor of logits from a language model output.
        candidates (torch.Tensor): A tensor of candidate tokens to compare against sampled tokens.
        temperature (float): A parameter to scale the logits, controlling randomness in sampling.
        posterior_threshold (float): The minimum threshold for probabilities to be considered in sampling.
        posterior_alpha (float): A scaling factor applied to the entropy-based adaptive threshold.

    Returns:
        torch.Tensor: A posterior mask indicating which candidate tokens match the sampled tokens.
    """
    # The last position has no candidate token to verify, hence [:, :-1].
    logits = logits[:, :-1] / temperature
    n_samples, n_tokens = logits.shape[0], logits.shape[1]
    # Flatten (path, position) pairs so each row is sampled independently.
    logits = logits.view(n_samples*n_tokens, -1)
    probs = F.softmax(logits, dim=-1)
    # Shannon entropy of each row (1e-5 avoids log(0)).
    entropy = -torch.sum(
        probs * torch.log(probs + 1e-5), dim=-1
    )
    # Adaptive cutoff: fixed floor capped by an entropy-based bound.
    threshold = torch.minimum(
        torch.ones_like(entropy) * posterior_threshold,
        torch.exp(-entropy) * posterior_alpha,
    )
    indices_to_remove = probs < threshold.unsqueeze(-1)
    logits[indices_to_remove] = float('-inf')
    sampled_tokens = torch.multinomial(F.softmax(logits, dim=-1), 1)
    sampled_tokens = sampled_tokens.view(n_samples, n_tokens)
    # 1 where the drafted candidate equals the freshly sampled token.
    posterior_mask = (candidates[:, 1:] == sampled_tokens).int()
    return posterior_mask
425
+
426
+
427
+
428
def evaluate_posterior(
    logits,
    candidates,
    temperature,
    posterior_threshold=0.3,
    posterior_alpha = 0.09,
    top_p=0.8,
    sampling = 'typical',
    fast = True
):
    """Verify draft candidates against the model's own predictions.

    Each candidate row is a drafted token path; this scores how many leading
    draft tokens along each path are accepted, and picks the best path.

    Args:
        logits (torch.Tensor): model logits per candidate path,
            shape [n_candidates, path_len, vocab_size].
        candidates (torch.Tensor): drafted token ids, [n_candidates, path_len].
        temperature (float): 0 means greedy verification; otherwise the
            `sampling` strategy ('typical' or 'nucleus') decides acceptance.
        posterior_threshold (float): probability floor (typical sampling).
        posterior_alpha (float): scale on the entropy-based bound (typical).
        top_p (float): nucleus threshold (nucleus sampling).
        sampling (str): 'typical' or 'nucleus'.
        fast (bool): deterministic threshold test instead of re-sampling
            (typical sampling only).

    Returns:
        tuple: (best_candidate as a long tensor index, accept_length int).
    """
    if logits.shape[1] <= 1:
        # Nothing to verify: accept zero draft tokens from candidate 0.
        return torch.tensor(0, dtype=torch.long, device=candidates.device), 0
    # Greedy decoding based on temperature value
    if temperature == 0:
        # Find the tokens that match the maximum logits for each position in the sequence
        posterior_mask = (
            candidates[:, 1:] == torch.argmax(logits[:, :-1], dim=-1)
        ).int()
        # cumprod keeps only the leading run of matches along each path
        candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
        accept_length = candidates_accept_length.max().item()
        # Choose the best candidate
        if accept_length == 0:
            # Default to the first candidate if none are accepted
            best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
        else:
            best_candidate = torch.argmax(candidates_accept_length).to(torch.long)
        return best_candidate, accept_length
    elif sampling == 'typical':
        if fast:
            # Deterministic variant: accept a draft token whenever its model
            # probability clears the adaptive typical-sampling threshold.
            posterior_prob = torch.softmax(logits[:, :-1] / temperature, dim=-1)
            candidates_prob = torch.gather(
                posterior_prob, dim=-1, index=candidates[:, 1:].unsqueeze(-1)
            ).squeeze(-1)
            posterior_entropy = -torch.sum(
                posterior_prob * torch.log(posterior_prob + 1e-5), dim=-1
            )  # torch.sum(torch.log(*)) is faster than torch.prod
            threshold = torch.minimum(
                torch.ones_like(posterior_entropy) * posterior_threshold,
                torch.exp(-posterior_entropy) * posterior_alpha,
            )
            posterior_mask = candidates_prob > threshold
            candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)

            # Choose the best candidate based on the evaluated posterior probabilities
            accept_length = candidates_accept_length.max().item()
            if accept_length == 0:
                # If no candidates are accepted, just choose the first one
                best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
            else:
                best_candidates = torch.where(candidates_accept_length == accept_length)[0]
                # Accept the best one according to likelihood
                likelihood = torch.sum(
                    torch.log(candidates_prob[best_candidates, :accept_length]), dim=-1
                )
                best_candidate = best_candidates[torch.argmax(likelihood)]
            return best_candidate, accept_length
        # Calculate posterior probabilities and thresholds for candidate selection
        posterior_mask = get_typical_posterior_mask(logits, candidates, temperature, posterior_threshold, posterior_alpha)
        candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
        # Choose the best candidate based on the evaluated posterior probabilities
        accept_length = candidates_accept_length.max().item()

        if accept_length == 0:
            # If no candidates are accepted, just choose the first one
            best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
        else:
            best_candidate = torch.argmax(candidates_accept_length).to(torch.long)
            # Accept the best one according to likelihood
        return best_candidate, accept_length
    elif sampling == 'nucleus':
        assert top_p < 1.0 + 1e-6, "top_p should between 0 and 1"
        posterior_mask = get_nucleus_posterior_mask(logits, candidates, temperature, top_p)
        candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
        accept_length = candidates_accept_length.max().item()
        # Choose the best candidate
        if accept_length == 0:
            # Default to the first candidate if none are accepted
            best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
        else:
            best_candidate = torch.argmax(candidates_accept_length).to(torch.long)
        return best_candidate, accept_length
    else:
        raise NotImplementedError
511
+
512
def update_inference_inputs(
    input_ids,
    medusa_logits,
    logits,
    candidate_ids,
    best_candidate,
    accept_length,
):
    """Commit the accepted speculative tokens for the next tree iteration.

    Appends the first ``accept_length + 1`` tokens of the winning candidate
    path to ``input_ids`` and keeps only the logits at the last accepted
    position, re-adding the batch dimension that indexing removes.

    Returns:
        tuple: (input_ids, medusa_logits, logits, new_token) where
        ``new_token`` is the number of tokens appended this step.
    """
    n_accepted = accept_length + 1
    accepted_ids = candidate_ids[None, best_candidate, :n_accepted]
    extended_ids = torch.cat([input_ids, accepted_ids], dim=-1)
    next_logits = logits[None, best_candidate, accept_length:n_accepted]
    next_medusa_logits = medusa_logits[:, None, best_candidate, accept_length:n_accepted]
    # n_accepted doubles as the new-token counter for the caller
    return extended_ids, next_medusa_logits, next_logits, n_accepted
536
+
537
def split_logits(full_logits):
    """Separate the base LM head from the speculative (medusa) heads.

    Args:
        full_logits: tensor of shape [b, n, heads, vocab_size]; head 0 is
            the ordinary next-token head.

    Returns:
        tuple: (medusa_logits of shape [heads - 1, b, n, vocab_size],
        logits of shape [b, n, vocab_size]).
    """
    base_head = full_logits[..., 0, :]
    spec_heads = full_logits[..., 1:, :].permute(2, 0, 1, 3)
    return spec_heads, base_head
542
+
543
+ class MultiByteDecodingMixin:
544
    def multi_byte_pred_update_cache(
        self,
        past_key_values,
        retrieve_indices,
        best_candidate,
        new_tokens,
    ):
        """Compact the KV cache after a tree-decoding step.

        During tree decoding the window cache holds K/V for *all* tree nodes;
        only the accepted path must be kept. For each layer the accepted
        positions are gathered to directly follow the previous window content,
        and when the window overflows, one full window is dumped into
        chunk-level RFA state via `triton_eva_prep_kv_fwd` and the remainder
        is shifted to the front of the window buffer.

        Args:
            past_key_values: the Eva cache (mutated in place).
            retrieve_indices: per-path node indices within the tree block.
            best_candidate: index of the accepted candidate path.
            new_tokens: number of accepted tokens (accept_length + 1).

        Returns:
            The updated `past_key_values`.
        """
        prev_window_len = past_key_values.get_past_window_pos(0)
        # absolute window positions of the accepted tree nodes
        select_indices = (
            retrieve_indices[best_candidate, : new_tokens] + prev_window_len
        )
        for layer_idx in range(self.config.num_hidden_layers):

            past_key_values.update_past_len(new_tokens, layer_idx)

            past_window_k = past_key_values.past_window_k[layer_idx]
            past_window_v = past_key_values.past_window_v[layer_idx]

            tgt_window_k = past_window_k[..., select_indices, :]
            tgt_window_v = past_window_v[..., select_indices, :]

            dst_window_k = past_window_k[..., prev_window_len : prev_window_len + new_tokens, :]
            dst_window_v = past_window_v[..., prev_window_len : prev_window_len + new_tokens, :]

            # move accepted K/V so they immediately follow the previous window content
            dst_window_k.copy_(tgt_window_k, non_blocking=True)
            dst_window_v.copy_(tgt_window_v, non_blocking=True)

            new_window_len = prev_window_len + new_tokens
            if new_window_len >= self.config.window_size:
                # a single step never spans more than one extra window
                assert new_window_len < 2 * self.config.window_size

                # dump the completed window for RFA summarization below
                dump_k = past_window_k[..., :self.config.window_size, :].clone()
                dump_v = past_window_v[..., :self.config.window_size, :].clone()

                _window_len = new_window_len - self.config.window_size

                if _window_len > 0:
                    # shift the overflow back to the start of the window buffer
                    new_window_k = past_window_k[..., self.config.window_size : new_window_len, :]
                    new_window_v = past_window_v[..., self.config.window_size : new_window_len, :]

                    _dst_window_k = past_window_k[..., : _window_len, :]
                    _dst_window_v = past_window_v[..., : _window_len, :]

                    _dst_window_k.copy_(new_window_k, non_blocking=True)
                    _dst_window_v.copy_(new_window_v, non_blocking=True)

                past_key_values.past_window_pos[layer_idx] = _window_len
            else:
                dump_k = None
                dump_v = None
                past_key_values.past_window_pos[layer_idx] = new_window_len

            if dump_k is not None and dump_v is not None:
                # summarize the dumped window into chunk-level RFA state
                rfa_k, rfa_v = triton_eva_prep_kv_fwd(
                    dump_k, dump_v,
                    self.model.layers[layer_idx].self_attn.adaptive_mu_k,
                    self.model.layers[layer_idx].self_attn.adaptive_phi,
                    None,
                    self.model.layers[layer_idx].self_attn.head_dim_scaling,
                    self.model.layers[layer_idx].self_attn.chunk_size
                )
                rfa_k, rfa_v = past_key_values.update_chunk_rfas(
                    rfa_k, rfa_v, layer_idx
                )
        return past_key_values
609
+
610
+ def _multi_byte_pred_update_cache_when_prefil_len_eq_window_size(
611
+ self,
612
+ past_key_values,
613
+ ):
614
+ prev_window_len = past_key_values.get_past_window_pos(0)
615
+ for layer_idx in range(self.config.num_hidden_layers):
616
+
617
+ past_window_k = past_key_values.past_window_k[layer_idx]
618
+ past_window_v = past_key_values.past_window_v[layer_idx]
619
+
620
+ new_window_len = prev_window_len
621
+ if new_window_len == self.config.window_size:
622
+ dump_k = past_window_k[..., :self.config.window_size, :].clone()
623
+ dump_v = past_window_v[..., :self.config.window_size, :].clone()
624
+ past_key_values.past_window_pos[layer_idx] = 0
625
+
626
+ if dump_k is not None and dump_v is not None:
627
+ rfa_k, rfa_v = triton_eva_prep_kv_fwd(
628
+ dump_k, dump_v,
629
+ self.model.layers[layer_idx].self_attn.adaptive_mu_k,
630
+ self.model.layers[layer_idx].self_attn.adaptive_phi,
631
+ None,
632
+ self.model.layers[layer_idx].self_attn.head_dim_scaling,
633
+ self.model.layers[layer_idx].self_attn.chunk_size
634
+ )
635
+ rfa_k, rfa_v = past_key_values.update_chunk_rfas(
636
+ rfa_k, rfa_v, layer_idx
637
+ )
638
+ return past_key_values
639
+
640
    def multi_byte_pred_update_attn_mask(
        self,
        last_iter_new_tokens,
        tree_candidate_ids,
        past_attn_mask,
        medusa_attn_mask,
        past_key_values,
    ):
        """Build the attention mask for the next tree-decoding forward pass.

        Concatenates a dense all-ones block over the already-cached context
        with the sparse tree mask that encodes ancestor relations among the
        candidate nodes.

        Args:
            last_iter_new_tokens: tokens accepted in the previous iteration.
            tree_candidate_ids: [batch, tree_len] candidate ids for this step.
            past_attn_mask: the context part of last iteration's mask, or
                None to (re)build it from scratch.
            medusa_attn_mask: [1, 1, tree_len, tree_len] tree ancestor mask.
            past_key_values: the Eva cache (only read here).

        Returns:
            tuple: (tree_attn_mask for this forward pass, past_attn_mask to
            reuse next iteration); 1 marks unmasked positions.
        """
        batch_size, tree_candidate_len = tree_candidate_ids.shape
        seen_tokens = past_key_values.get_seq_length()
        # NOTE: past_key_values has been updated so now
        # seen_tokens includes new tokens from the last tree iteration
        assert seen_tokens > 0
        # so one iteration would not cross two windows
        assert last_iter_new_tokens < self.config.window_size

        if past_attn_mask is not None and seen_tokens < self.config.window_size:
            # still inside the first window: extend the previous mask over the
            # tokens accepted last iteration
            past_attn_mask = torch.cat(
                [
                    past_attn_mask,
                    torch.ones(
                        [batch_size, 1, tree_candidate_len, last_iter_new_tokens],
                        dtype=torch.bool,
                        device=self.device
                    )
                ],
                dim=-1
            )
        else:
            # we initialize attn mask each time when
            # 1. the model crosses the window boundary, or
            # 2. after prefilling
            chunks_per_window = int(self.config.window_size // self.config.chunk_size)

            # completed windows contribute `chunks_per_window` summary slots
            # each; the current partial window contributes its raw tokens
            window_tokens = seen_tokens % self.config.window_size
            num_windows_seen_so_far = seen_tokens // self.config.window_size
            attn_mask_len = num_windows_seen_so_far * chunks_per_window + window_tokens
            past_attn_mask = torch.ones(
                (batch_size, 1, tree_candidate_len, attn_mask_len),
                dtype=torch.bool,
                device=self.device
            )

        # note that 1 indicates the position is not masked
        tree_attn_mask = torch.cat(
            [
                past_attn_mask,
                medusa_attn_mask.to(torch.bool)
            ],
            dim=-1
        )
        return tree_attn_mask, past_attn_mask
692
+
693
+ @torch.no_grad()
694
+ def multi_byte_generate(
695
+ self,
696
+ input_ids,
697
+ attention_mask=None,
698
+ temperature=0.0,
699
+ max_length=None,
700
+ max_new_tokens=None,
701
+ stopping_criteria=None,
702
+ posterior_threshold=0.09,
703
+ posterior_alpha=0.3,
704
+ top_p=0.8,
705
+ sampling='typical',
706
+ fast=True,
707
+ do_sample=False,
708
+ medusa_choices=None,
709
+ return_acc_lengths=False
710
+ ):
711
+ if do_sample or temperature > 0.0:
712
+ fast = False
713
+
714
+ ### Prepare `max_length` depending on other stopping criteria.
715
+ if max_new_tokens is not None:
716
+ max_length = max_new_tokens + input_ids.shape[-1]
717
+ elif max_new_tokens is None and max_length is None:
718
+ max_length = getattr(self.config, "max_position_embeddings", 32768)
719
+
720
+ ### Set up stopping criteria
721
+ eos_stop_criteria = MultibyteEosTokenCriteria(self.generation_config.eos_token_id)
722
+ stop_criteria = StoppingCriteriaList()
723
+ if max_length is not None:
724
+ max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
725
+ stop_criteria.append(
726
+ MaxLengthCriteria(
727
+ max_length=max_length,
728
+ max_position_embeddings=max_position_embeddings,
729
+ )
730
+ )
731
+ if stopping_criteria is not None and len(stopping_criteria) > 0:
732
+ stop_criteria.extend(stopping_criteria)
733
+
734
+ assert input_ids.shape[0] == 1, "Only support batch size 1 for now"
735
+ assert attention_mask is None, "Only support attention mask None for now"
736
+ # Avoid modifying the input_ids in-place
737
+ input_ids = input_ids.clone()
738
+ position_ids = torch.arange(0, input_ids.shape[1], device=self.device, dtype=int).reshape(1, -1)
739
+
740
+ ####################################################
741
+ # 0. initialize the medusa buffers
742
+ ####################################################
743
+ if medusa_choices is None:
744
+ medusa_choices = evabyte_7b_95
745
+ medusa_buffers = generate_medusa_buffers(
746
+ medusa_choices, device=self.device
747
+ )
748
+
749
+ past_key_values = EvaStaticCacheForTriton(
750
+ input_ids.shape[0],
751
+ self.config.num_attention_heads,
752
+ # we add 256 to allow tree ids
753
+ self.config.window_size + 256,
754
+ self.config.hidden_size // self.config.num_attention_heads,
755
+ self.config.num_hidden_layers,
756
+ self.lm_head.weight.dtype,
757
+ self.lm_head.weight.device,
758
+ )
759
+ # prefill to get medusa logits and logits
760
+ full_logits, past_key_values = self.forward(
761
+ input_ids,
762
+ attention_mask=attention_mask,
763
+ position_ids=position_ids,
764
+ use_cache=True,
765
+ past_key_values=past_key_values,
766
+ return_all_pred_logits=True,
767
+ multibyte_decoding=False,
768
+ )
769
+ # handles an edge case where the prefill length == window_size
770
+ # we force the previous window to be dumped into RFA chunks
771
+ past_key_values = self._multi_byte_pred_update_cache_when_prefil_len_eq_window_size(
772
+ past_key_values
773
+ )
774
+ medusa_logits, logits = split_logits(full_logits)
775
+
776
+ past_attn_mask = None
777
+ last_iter_new_tokens = 0
778
+ max_iters = 32768
779
+ if return_acc_lengths:
780
+ acc_lengths = []
781
+ for _ in range(max_iters):
782
+ ####################################################
783
+ # 1. generate candidate_ids with topk predictions from Medusa heads
784
+ ####################################################
785
+ tree_candidate_ids, unflattened_candidate_ids = generate_candidates(
786
+ medusa_logits,
787
+ logits,
788
+ medusa_buffers["tree_indices"],
789
+ medusa_buffers["retrieve_indices"],
790
+ temperature=temperature,
791
+ posterior_alpha=posterior_alpha,
792
+ posterior_threshold=posterior_threshold,
793
+ top_p=top_p,
794
+ sampling=sampling,
795
+ fast=fast,
796
+ )
797
+
798
+ ####################################################
799
+ # 2. Build the medusa attention mask and position ids
800
+ ####################################################
801
+ # NOTE: 1 indicates the position is not masked
802
+ medusa_attn_mask, past_attn_mask = self.multi_byte_pred_update_attn_mask(
803
+ last_iter_new_tokens,
804
+ tree_candidate_ids,
805
+ past_attn_mask,
806
+ medusa_buffers["medusa_attn_mask"],
807
+ past_key_values,
808
+ )
809
+ medusa_position_ids = medusa_buffers["medusa_position_ids"] + input_ids.shape[1]
810
+
811
+ ####################################################
812
+ # 3. tree decoding
813
+ ####################################################
814
+ tree_full_logits, past_key_values = self.forward(
815
+ tree_candidate_ids,
816
+ past_key_values=past_key_values,
817
+ attention_mask=medusa_attn_mask,
818
+ position_ids=medusa_position_ids,
819
+ return_all_pred_logits=True,
820
+ multibyte_decoding=True,
821
+ )
822
+ _medusa_logits, _logits = split_logits(tree_full_logits)
823
+ medusa_logits = _medusa_logits[..., 0, medusa_buffers["retrieve_indices"], :]
824
+ logits = _logits[..., 0, medusa_buffers["retrieve_indices"], :]
825
+
826
+ ####################################################
827
+ # 4. candidate selection
828
+ ####################################################
829
+ # if the current iteration, with tree tokens, crosses window
830
+ # boundaries, trim the condidate_ids to be within the window
831
+ # so that those exceeded tokens (which will be inaccurate)
832
+ # will not be considered
833
+ tree_depth = unflattened_candidate_ids.shape[-1]
834
+ if tree_depth + past_key_values.get_past_window_pos(0) > self.config.window_size:
835
+ max_acc_len = self.config.window_size - past_key_values.get_past_window_pos(0)
836
+ _trimmed_unflattened_candidate_ids = unflattened_candidate_ids[:, :max_acc_len]
837
+ _trimmed_logits = logits[:, :max_acc_len]
838
+ else:
839
+ _trimmed_unflattened_candidate_ids = unflattened_candidate_ids
840
+ _trimmed_logits = logits
841
+ best_candidate, accept_length = evaluate_posterior(
842
+ _trimmed_logits,
843
+ _trimmed_unflattened_candidate_ids,
844
+ temperature,
845
+ posterior_threshold,
846
+ posterior_alpha,
847
+ top_p=top_p,
848
+ sampling=sampling,
849
+ fast=fast
850
+ )
851
+
852
+ ####################################################
853
+ # 5. update model inputs and caches
854
+ ####################################################
855
+ input_ids, medusa_logits, logits, last_iter_new_tokens = update_inference_inputs(
856
+ input_ids,
857
+ medusa_logits,
858
+ logits,
859
+ unflattened_candidate_ids,
860
+ best_candidate,
861
+ accept_length,
862
+ )
863
+
864
+ past_key_values = self.multi_byte_pred_update_cache(
865
+ past_key_values,
866
+ medusa_buffers["retrieve_indices"],
867
+ best_candidate,
868
+ last_iter_new_tokens,
869
+ )
870
+
871
+ if return_acc_lengths:
872
+ acc_lengths.append(last_iter_new_tokens)
873
+ if stop_criteria(input_ids, None) or eos_stop_criteria(input_ids, last_iter_new_tokens):
874
+ if return_acc_lengths:
875
+ return input_ids, acc_lengths
876
+ else:
877
+ return input_ids
878
+ if return_acc_lengths:
879
+ return input_ids, acc_lengths
880
+ else:
881
+ return input_ids
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/preprocessor_config.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoImageProcessor": "image_processing_evabyte.EvaByteImageProcessor",
4
+ "AutoProcessor": "processing_evabyte.EvaByteProcessor"
5
+ },
6
+ "do_convert_rgb": true,
7
+ "do_resize": true,
8
+ "image_processor_type": "EvaByteImageProcessor",
9
+ "jpeg_quality": 25,
10
+ "jpeg_restart_marker_blocks": 1,
11
+ "jpeg_streamtype": 2,
12
+ "jpeg_subsampling": "4:2:0",
13
+ "processor_class": "EvaByteProcessor",
14
+ "resample": 1,
15
+ "size": {
16
+ "longest_edge": 384
17
+ }
18
+ }
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/processing_evabyte.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ """
3
+ Processor class for EvaByte.
4
+ """
5
+ import base64
6
+ from io import BytesIO
7
+
8
+ import requests
9
+ import os
10
+ import PIL
11
+ from PIL import Image
12
+
13
+ from typing import List, Optional, Union
14
+
15
+ from transformers.feature_extraction_utils import BatchFeature
16
+ from transformers.image_utils import ImageInput, is_valid_image
17
+ from transformers.processing_utils import ProcessorMixin
18
+ from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
19
+ from transformers.utils import TensorType, to_py_obj
20
+
21
def fetch_image(image: Union[str, "PIL.Image.Image"]) -> Image.Image:
    """Resolve an image reference into a ``PIL.Image.Image``.

    Accepted inputs (checked in this order):
      * an already-loaded ``PIL.Image.Image`` (returned as-is),
      * an ``http://`` / ``https://`` URL (downloaded with ``requests``),
      * a path to an existing local file,
      * a ``data:image/...`` URI with a base64-encoded payload,
      * anything else is handed directly to ``Image.open`` (e.g. a
        file-like object).

    Raises:
        ValueError: if a ``data:image/`` URI carries an invalid base64
            payload, or if no branch managed to produce an image.
    """
    image_obj = None
    if isinstance(image, Image.Image):
        image_obj = image
    elif image.startswith("http://") or image.startswith("https://"):
        image_obj = Image.open(BytesIO(requests.get(image, timeout=None).content))
    elif os.path.isfile(image):
        image_obj = Image.open(image)
    elif image.startswith("data:image/"):
        # Keep only the payload after "data:image/...;base64,"
        image = image.split(",")[1]
        # Try to load as base64
        try:
            b64 = base64.decodebytes(image.encode())
            # BUG FIX: bind the decoded image to `image_obj`. The previous
            # code assigned it to `image`, so `image_obj` stayed None and
            # every valid data URI raised "Unrecognized image input" below.
            image_obj = Image.open(BytesIO(b64))
        except Exception as e:
            raise ValueError(
                f"Incorrect image source. Must be a valid URL starting with `http://` or `https://`, a valid path to an image file, or a base64 encoded string. Got {image}. Failed with {e}"
            ) from e
    else:
        image_obj = Image.open(image)
    if image_obj is None:
        raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")

    return image_obj
45
+
46
def is_url(val) -> bool:
    """Return True when *val* is a string that starts with "http"."""
    if not isinstance(val, str):
        return False
    return val.startswith("http")
48
+
49
def is_file(val) -> bool:
    """Return True when *val* is a string naming an existing regular file."""
    if isinstance(val, str):
        return os.path.isfile(val)
    return False
51
+
52
def is_image_or_image_url(elem):
    """True when *elem* is an http-style URL string, a valid image object, or an existing file path."""
    for predicate in (is_url, is_valid_image, is_file):
        if predicate(elem):
            return True
    return False
54
+
55
# Jinja chat template for vision-language conversations. Compared with the
# text-only template in tokenization_evabyte.py, message content may be a
# list of {"type": "image"|"text", ...} parts; image parts render as the
# literal "<image_placeholder>" marker, which EvaByteProcessor.__call__
# later splits on to splice in image byte streams.
vl_chat_template = """
{{- bos_token }}
{%- if messages[0]['role'] == 'system' %}
{%- set system_message = messages[0]['content'] %}
{%- set messages = messages[1:] %}
{%- else %}
{%- set system_message = "" %}
{%- endif %}

{{- '<|start_header_id|>system<|end_header_id|>\n\n' + system_message + '<|eot_id|>'}}

{%- for message in messages %}
{%- if (message['role'] != 'user') and (message['role'] != 'assistant') %}
{{- raise_exception('Conversation roles must be user or assistant') }}
{%- endif %}

{%- if message['content'] is string %}
{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] + '<|eot_id|>' }}
{%- else %}
{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' }}
{%- for content in message['content'] %}
{%- if content['type'] == 'image' %}
{{- '<image_placeholder>\n' }}
{%- elif content['type'] == 'text' %}
{{- content['text'] }}
{%- endif %}
{%- endfor %}
{{- '<|eot_id|>' }}
{%- endif %}
{%- endfor %}

{%- if add_generation_prompt %}
{{- '<|start_header_id|>' + 'assistant' + '<|end_header_id|>\n\n' }}
{%- endif %}
"""
90
+
91
class EvaByteProcessor(ProcessorMixin):
    r"""
    Constructs a EvaByte processor which wraps a EvaByte image processor and a EvaByte tokenizer into a single processor.

    [`EvaByteProcessor`] offers all the functionalities of [`EvaByteImageProcessor`] and [`EvaByteTokenizer`]. See the
    [`~EvaByteProcessor.__call__`] and [`~EvaByteProcessor.decode`] for more information.

    Args:
        image_processor ([`EvaByteImageProcessor`], *optional*):
            The image processor is a required input.
        tokenizer ([`EvaByteTokenizer`], *optional*):
            The tokenizer is a required input.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(self, image_processor=None, tokenizer=None, **kwargs):
        """Wire the two sub-components together and cache the sentinel token
        ids (<t2v_token> = text-to-vision, <v2t_token> = vision-to-text) that
        bracket image byte streams inside a token sequence."""
        if image_processor is None:
            raise ValueError("You need to specify an `image_processor`.")
        if tokenizer is None:
            raise ValueError("You need to specify a `tokenizer`.")

        super().__init__(image_processor, tokenizer)
        self.t2v_token_id = self.tokenizer.convert_tokens_to_ids("<t2v_token>")
        self.v2t_token_id = self.tokenizer.convert_tokens_to_ids("<v2t_token>")
        self.image_placeholder = "<image_placeholder>"
        self.vl_chat_template = vl_chat_template

    def __call__(
        self,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        images: ImageInput = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        strip_ending_sentinel: bool = False,
        encode_only: bool = False,
        **kwargs
    ) -> Union[BatchFeature, List[List[int]]]:
        """Tokenize `text` and splice image byte streams into its
        `<image_placeholder>` slots, bracketed by <t2v_token>/<v2t_token>.

        Args:
            text: a single string (batch size 1 only) containing one
                `<image_placeholder>` marker per image.
            images: raw bytes / list(s) of bytes, or PIL images / URLs /
                file paths (converted to bytes via the image processor).
            strip_ending_sentinel: drop a trailing <t2v>/<v2t> sentinel —
                useful when generation should continue inside an image.
            encode_only: return plain `List[List[int]]` ids instead of a
                `BatchFeature` with an attention mask.

        NOTE(review): if `images is None`, `image_bytes_list` is never
        assigned and the `len(...)` assertion below raises NameError —
        text-only calls appear unsupported here; confirm with callers.
        """
        # processing pipeline:
        # 1. read images or videos from paths
        # 2. use image_processor to convert images / videos to byte streams
        if images is not None:
            # Already-encoded byte payloads are normalized to the nested
            # list-of-lists layout [batch][image] without re-encoding.
            if isinstance(images, bytes):
                image_bytes_list = [[images]]
            elif isinstance(images, list) and isinstance(images[0], bytes):
                image_bytes_list = [images]
            elif isinstance(images, list) and isinstance(images[0], list) and isinstance(images[0][0], bytes):
                image_bytes_list = images
            else:
                if is_image_or_image_url(images):
                    images = [[images]]
                elif isinstance(images, list) and is_image_or_image_url(images[0]):
                    images = [images]
                elif (
                    # NOTE(review): `not isinstance(images, list)` combined
                    # with subscripting `images[0]` below looks inconsistent
                    # (a non-list may not be subscriptable) — verify intent.
                    not isinstance(images, list)
                    and not isinstance(images[0], list)
                    and not is_image_or_image_url(images[0][0])
                ):
                    raise ValueError(
                        "Invalid input images. Please provide a single image or a list of images or a list of list of images."
                    )
                # Load images if they are URLs
                images = [[fetch_image(im) if is_url(im) or is_file(im) else im for im in sample] for sample in images]
                image_bytes_list = self.image_processor(images=images, **kwargs)

        if not isinstance(text, list):
            text = [text]
        assert len(text) == 1, "Only support batch size 1 for now"
        assert len(text) == len(image_bytes_list), "text and image_bytes_list must have the same length"
        # TODO: invoke SequenceFeatureExtractor to get batched inputs

        # 3. tokenize the text and put images / videos byte streams into the placeholders
        # surrounded by special tokens like "<image>" and "</image>"
        batch_input_ids = []
        if not encode_only:
            batch_attention_mask = []
        else:
            batch_attention_mask = None

        for t, image_bytes in zip(text, image_bytes_list):
            # One placeholder per image => len(splits) == len(images) + 1.
            text_splits = t.split(self.image_placeholder)
            if len(text_splits) != len(image_bytes) + 1:
                raise ValueError(
                    f"The number of image tokens should be equal to the number of images, "
                    f"but got {len(text_splits)} and {len(image_bytes) + 1}"
                )

            input_ids = [self.tokenizer.bos_token_id]
            for i, text_part in enumerate(text_splits):
                # each text part must be non-empty because we added markers around placeholders
                split_tokens = self.tokenizer.encode(text_part, add_special_tokens=False)
                input_ids.extend(split_tokens)
                # Add image bytes after each text part except the last one
                if i < len(image_bytes):
                    input_ids.append(self.t2v_token_id)
                    # Shift raw byte values past the special-token range.
                    input_ids.extend([b + self.tokenizer.offset for b in image_bytes[i]])
                    input_ids.append(self.v2t_token_id)

            if strip_ending_sentinel and (input_ids[-1] in [self.t2v_token_id, self.v2t_token_id]):
                input_ids = input_ids[:-1]

            batch_input_ids.append(input_ids)
            if not encode_only:
                batch_attention_mask.append([1] * len(input_ids))

        if not encode_only:
            # 4. return batch of features
            inputs = BatchFeature({
                "input_ids": batch_input_ids,
                "attention_mask": batch_attention_mask
            }, tensor_type=return_tensors)
            return inputs
            # # Pad sequences
            # padded_inputs = self.tokenizer.pad(
            #     {"input_ids": batch_input_ids},
            #     padding=True,
            #     return_attention_mask=True,
            #     return_tensors=return_tensors,
            # )
            # return BatchFeature(data=padded_inputs)
        else:
            return batch_input_ids

    def image_tokens_to_bytes(self, image_token_ids, jpeg_quality=None):
        """Map image token ids back to raw bytes (undoing the tokenizer
        offset) and re-attach JPEG quantization tables via the image
        processor."""
        image_bytes = bytes([token_id - self.tokenizer.offset for token_id in image_token_ids])
        image_bytes = self.image_processor.jpeg_merge_qtables(image_bytes, jpeg_quality)
        return image_bytes

    def batch_decode(self, sequences, **kwargs):
        """
        This method forwards all its arguments to EvaByteTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        # Each self.decode call returns (text, images); transpose the list of
        # pairs into (list_of_texts, list_of_image_lists).
        rets = [self.decode(seq, **kwargs) for seq in sequences]
        return tuple(map(list, zip(*rets)))

    def decode(self, token_ids, **kwargs):
        """
        Decodes a sequence of input_ids, handling image tokens separately.
        Returns a tuple of (decoded_text, images), where images is a list of bytes.
        """
        if kwargs and "jpeg_quality" in kwargs:
            # Copy so the pop does not mutate the caller's kwargs dict.
            kwargs = kwargs.copy()
            jpeg_quality = kwargs.pop("jpeg_quality")
        else:
            jpeg_quality = None

        token_ids = to_py_obj(token_ids)
        # Find indices of t2v_token_id and v2t_token_id
        t2v_indices = [i for i, token_id in enumerate(token_ids) if token_id == self.t2v_token_id]
        v2t_indices = [i for i, token_id in enumerate(token_ids) if token_id == self.v2t_token_id]

        # Check for correct pairing of t2v and v2t tokens
        if len(t2v_indices) != len(v2t_indices):
            raise ValueError("Mismatched number of t2v and v2t tokens in token_ids: {} and {}".format(t2v_indices, v2t_indices))

        # Ensure t2v and v2t tokens are in the correct order
        for t2v_idx, v2t_idx in zip(t2v_indices, v2t_indices):
            if t2v_idx >= v2t_idx:
                raise ValueError("Found t2v_token_id after v2t_token_id in token_ids")

        # Initialize the start index
        images = []
        decoded_text = ""

        start = 0
        # Iterate over pairs of t2v and v2t indices
        for t2v_idx, v2t_idx in zip(t2v_indices, v2t_indices):
            # Decode text tokens before the image
            text_token_ids = token_ids[start:t2v_idx]
            if len(text_token_ids) > 0:
                decoded_text += self.tokenizer.decode(text_token_ids, **kwargs)

            # Insert image placeholder
            decoded_text += self.image_placeholder

            # Extract image tokens and convert them to bytes
            image_token_ids = token_ids[t2v_idx + 1 : v2t_idx]
            image_bytes = self.image_tokens_to_bytes(image_token_ids, jpeg_quality)
            images.append(image_bytes)

            # Update the start index to the token after v2t_token_id
            start = v2t_idx + 1

        # Decode any remaining text tokens after the last image
        if start < len(token_ids):
            text_token_ids = token_ids[start:]
            decoded_text += self.tokenizer.decode(text_token_ids, **kwargs)

        return decoded_text, images

    @property
    def model_input_names(self):
        # Union of both components' input names, deduplicated while
        # preserving order (dict.fromkeys keeps first occurrence).
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/processor_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoProcessor": "processing_evabyte.EvaByteProcessor"
4
+ },
5
+ "processor_class": "EvaByteProcessor"
6
+ }
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/special_tokens_map.json ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "<repo_name>",
4
+ "<file_sep>",
5
+ "<t2v_token>",
6
+ "<v2t_token>",
7
+ "<|start_header_id|>",
8
+ "<|end_header_id|>",
9
+ "<|eot_id|>",
10
+ "<extra_id_12>",
11
+ "<extra_id_13>",
12
+ "<extra_id_14>",
13
+ "<extra_id_15>",
14
+ "<extra_id_16>",
15
+ "<extra_id_17>",
16
+ "<extra_id_18>",
17
+ "<extra_id_19>",
18
+ "<extra_id_20>",
19
+ "<extra_id_21>",
20
+ "<extra_id_22>",
21
+ "<extra_id_23>",
22
+ "<extra_id_24>",
23
+ "<extra_id_25>",
24
+ "<extra_id_26>",
25
+ "<extra_id_27>",
26
+ "<extra_id_28>",
27
+ "<extra_id_29>",
28
+ "<extra_id_30>",
29
+ "<extra_id_31>",
30
+ "<extra_id_32>",
31
+ "<extra_id_33>",
32
+ "<extra_id_34>",
33
+ "<extra_id_35>",
34
+ "<extra_id_36>",
35
+ "<extra_id_37>",
36
+ "<extra_id_38>",
37
+ "<extra_id_39>",
38
+ "<extra_id_40>",
39
+ "<extra_id_41>",
40
+ "<extra_id_42>",
41
+ "<extra_id_43>",
42
+ "<extra_id_44>",
43
+ "<extra_id_45>",
44
+ "<extra_id_46>",
45
+ "<extra_id_47>",
46
+ "<extra_id_48>",
47
+ "<extra_id_49>",
48
+ "<extra_id_50>",
49
+ "<extra_id_51>",
50
+ "<extra_id_52>",
51
+ "<extra_id_53>",
52
+ "<extra_id_54>",
53
+ "<extra_id_55>",
54
+ "<extra_id_56>",
55
+ "<extra_id_57>",
56
+ "<extra_id_58>",
57
+ "<extra_id_59>",
58
+ "<extra_id_60>",
59
+ "<extra_id_61>",
60
+ "<extra_id_62>",
61
+ "<extra_id_63>"
62
+ ],
63
+ "bos_token": {
64
+ "content": "<bos>",
65
+ "lstrip": false,
66
+ "normalized": true,
67
+ "rstrip": false,
68
+ "single_word": false
69
+ },
70
+ "eos_token": {
71
+ "content": "<eos>",
72
+ "lstrip": false,
73
+ "normalized": true,
74
+ "rstrip": false,
75
+ "single_word": false
76
+ },
77
+ "pad_token": {
78
+ "content": "<pad>",
79
+ "lstrip": false,
80
+ "normalized": true,
81
+ "rstrip": false,
82
+ "single_word": false
83
+ },
84
+ "sep_token": {
85
+ "content": "<sep>",
86
+ "lstrip": false,
87
+ "normalized": true,
88
+ "rstrip": false,
89
+ "single_word": false
90
+ },
91
+ "unk_token": {
92
+ "content": "<unk>",
93
+ "lstrip": false,
94
+ "normalized": true,
95
+ "rstrip": false,
96
+ "single_word": false
97
+ }
98
+ }
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/tokenization_evabyte.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+
3
+ """ Tokenization class for model EvaByte."""
4
+
5
+
6
+ from typing import List, Optional, Tuple
7
+
8
+ from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
9
+ from transformers.utils import logging
10
+
11
+
12
+ logger = logging.get_logger(__name__)
13
+
14
+
15
# Default text-only Jinja chat template installed on every EvaByteTokenizer:
# an optional leading system message, then alternating user/assistant turns
# wrapped in <|start_header_id|>/<|end_header_id|>/<|eot_id|> markers.
chat_template = """
{{- bos_token }}
{%- if messages[0]['role'] == 'system' %}
{%- set system_message = messages[0]['content'] %}
{%- set messages = messages[1:] %}
{%- else %}
{%- set system_message = "" %}
{%- endif %}

{{- '<|start_header_id|>system<|end_header_id|>\n\n' + system_message + '<|eot_id|>'}}

{%- for message in messages %}
{%- if (message['role'] != 'user') and (message['role'] != 'assistant') %}
{{- raise_exception('Conversation roles must be user or assistant') }}
{%- endif %}

{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] + '<|eot_id|>' }}
{%- endfor %}

{%- if add_generation_prompt %}
{{- '<|start_header_id|>' + 'assistant' + '<|end_header_id|>\n\n' }}
{%- endif %}
"""
+ """
38
+
39
class EvaByteTokenizer(PreTrainedTokenizer):
    """Byte-level tokenizer: each UTF-8 byte is one token, shifted by
    `self.offset` (the number of special tokens, 64 by default) so that
    ids 0..offset-1 are reserved for pad/bos/eos/unk/sep plus the
    additional special tokens."""

    def __init__(
        self,
        bos_token="<bos>",
        eos_token="<eos>",
        unk_token="<unk>",
        sep_token="<sep>",
        pad_token="<pad>",
        extra_ids=59,
        additional_special_tokens=None,
        clean_up_tokenization_spaces=False,
        **kwargs,
    ) -> None:
        # The five base special tokens occupy ids 0-4; extra/special tokens
        # start at id 5.
        num_base_special_tokens = 5
        # Add extra_ids to the special token list
        if extra_ids > 0 and additional_special_tokens is None:
            additional_special_tokens = [f"<extra_id_{i}>" for i in range(num_base_special_tokens, extra_ids + num_base_special_tokens)]
        elif extra_ids > 0 and additional_special_tokens is not None and len(additional_special_tokens) > 0:
            # Check that we have the right number of extra_id special tokens
            extra_tokens = len(set(filter(lambda x: bool("extra_id" in str(x)), additional_special_tokens)))
            if extra_tokens != extra_ids:
                raise ValueError(
                    f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are"
                    " provided to EvaByteTokenizer. In this case the additional_special_tokens must include the"
                    " extra_ids tokens"
                )

        # NOTE(review): with extra_ids == 0 and additional_special_tokens
        # None, the loop below iterates over None and raises — confirm that
        # configuration is never used.
        #### override some reserved tokens to support chat template
        for i, token in enumerate(additional_special_tokens):
            if token == "<extra_id_5>":
                token = "<repo_name>"
            elif token == "<extra_id_6>":
                token = "<file_sep>"
            elif token == "<extra_id_7>":
                token = "<t2v_token>"
            elif token == "<extra_id_8>":
                token = "<v2t_token>"
            elif token == "<extra_id_9>":
                token = "<|start_header_id|>"
            elif token == "<extra_id_10>":
                token = "<|end_header_id|>"
            elif token == "<extra_id_11>":
                token = "<|eot_id|>"
            additional_special_tokens[i] = token

        # lstrip and rstrip are set to False because we don't want to strip the whitespace from the special tokens
        # this would be important for the byte tokenizer
        pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
        bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
        eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
        unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
        sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token

        self._added_tokens_decoder = {
            0: pad_token,
            1: bos_token,
            2: eos_token,
            3: unk_token, # unk_token is a placeholder
            4: sep_token,
            **{i: AddedToken(t, lstrip=False, rstrip=False) for i, t in enumerate(additional_special_tokens, start=num_base_special_tokens)},
        }
        # Byte values are shifted by this amount when mapped to ids.
        self.offset = len(self._added_tokens_decoder)
        self._utf_vocab_size = 2**8  # utf is 8 bits
        self.add_bos_token = True
        self.add_eos_token = False
        super().__init__(
            pad_token=pad_token,
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            sep_token=sep_token,
            extra_ids=0,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )
        self.chat_template = chat_template


    @property
    def vocab_size(self):
        # Base vocabulary = 256 byte values; special tokens live on top of
        # this via the offset.
        return self._utf_vocab_size

    def get_vocab(self):
        """Return the full token->id mapping (special tokens + 256 bytes)."""
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size + self.offset)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.build_inputs_with_special_tokens
    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """Wrap one or two sequences with BOS/EOS according to
        `add_bos_token` / `add_eos_token`."""
        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
        eos_token_id = [self.eos_token_id] if self.add_eos_token else []

        output = bos_token_id + token_ids_0 + eos_token_id

        if token_ids_1 is not None:
            output = output + bos_token_id + token_ids_1 + eos_token_id

        return output

    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.get_special_tokens_mask
    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        bos_token_id = [1] if self.add_bos_token else []
        eos_token_id = [1] if self.add_eos_token else []

        if token_ids_1 is None:
            return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
        return (
            bos_token_id
            + ([0] * len(token_ids_0))
            + eos_token_id
            + bos_token_id
            + ([0] * len(token_ids_1))
            + eos_token_id
        )

    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.create_token_type_ids_from_sequences
    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
        sequence pair mask has the following format:

        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence | second sequence |
        ```

        if token_ids_1 is None, only returns the first portion of the mask (0s).

        Args:
            token_ids_0 (`List[int]`):
                List of ids.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
        """
        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
        eos_token_id = [self.eos_token_id] if self.add_eos_token else []

        output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)

        if token_ids_1 is not None:
            output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)

        return output

    def _tokenize(self, text: str) -> List[str]:
        """Take as input a string and return a list of strings (tokens) for words/sub-words"""
        # One single-character token per UTF-8 byte (chr(byte_value)).
        tokens = [chr(i) for i in text.encode("utf-8")]
        return tokens

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        # Only single-character byte tokens are valid; anything else maps to
        # None (multi-char strings are handled as added tokens upstream).
        if len(token) != 1:
            token_id = None
        else:
            token_id = ord(token) + self.offset

        return token_id

    def _convert_id_to_token(self, index):
        """Converts an index (integer) to a byte (str) using the vocab."""
        token = chr(index - self.offset)
        return token

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of bytes (string) to a single string."""
        bstring = b""
        for token in tokens:
            # NOTE(review): `added_tokens_decoder` is keyed by int ids while
            # `token` here is a str, so this branch looks unreachable; the
            # `added_tokens_encoder` branch below handles special-token
            # strings — confirm before relying on the first branch.
            if token in self.added_tokens_decoder:
                tok_string = self.added_tokens_decoder[token].encode("utf-8")
            elif token in self.added_tokens_encoder:
                tok_string = token.encode("utf-8")
            else:
                tok_string = bytes([ord(token)])
            bstring += tok_string
        # Invalid UTF-8 byte sequences are silently dropped.
        string = bstring.decode("utf-8", errors="ignore")
        return string

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """No vocabulary file to save for a byte-level tokenizer."""
        return ()
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-40000/tokenizer_config.json ADDED
@@ -0,0 +1,596 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "<pad>",
5
+ "lstrip": false,
6
+ "normalized": true,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "<bos>",
13
+ "lstrip": false,
14
+ "normalized": true,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "2": {
20
+ "content": "<eos>",
21
+ "lstrip": false,
22
+ "normalized": true,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "3": {
28
+ "content": "<unk>",
29
+ "lstrip": false,
30
+ "normalized": true,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "4": {
36
+ "content": "<sep>",
37
+ "lstrip": false,
38
+ "normalized": true,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ },
43
+ "5": {
44
+ "content": "<repo_name>",
45
+ "lstrip": false,
46
+ "normalized": true,
47
+ "rstrip": false,
48
+ "single_word": false,
49
+ "special": false
50
+ },
51
+ "6": {
52
+ "content": "<file_sep>",
53
+ "lstrip": false,
54
+ "normalized": true,
55
+ "rstrip": false,
56
+ "single_word": false,
57
+ "special": false
58
+ },
59
+ "7": {
60
+ "content": "<t2v_token>",
61
+ "lstrip": false,
62
+ "normalized": true,
63
+ "rstrip": false,
64
+ "single_word": false,
65
+ "special": false
66
+ },
67
+ "8": {
68
+ "content": "<v2t_token>",
69
+ "lstrip": false,
70
+ "normalized": true,
71
+ "rstrip": false,
72
+ "single_word": false,
73
+ "special": false
74
+ },
75
+ "9": {
76
+ "content": "<|start_header_id|>",
77
+ "lstrip": false,
78
+ "normalized": true,
79
+ "rstrip": false,
80
+ "single_word": false,
81
+ "special": false
82
+ },
83
+ "10": {
84
+ "content": "<|end_header_id|>",
85
+ "lstrip": false,
86
+ "normalized": true,
87
+ "rstrip": false,
88
+ "single_word": false,
89
+ "special": false
90
+ },
91
+ "11": {
92
+ "content": "<|eot_id|>",
93
+ "lstrip": false,
94
+ "normalized": true,
95
+ "rstrip": false,
96
+ "single_word": false,
97
+ "special": false
98
+ },
99
+ "12": {
100
+ "content": "<extra_id_12>",
101
+ "lstrip": false,
102
+ "normalized": true,
103
+ "rstrip": false,
104
+ "single_word": false,
105
+ "special": false
106
+ },
107
+ "13": {
108
+ "content": "<extra_id_13>",
109
+ "lstrip": false,
110
+ "normalized": true,
111
+ "rstrip": false,
112
+ "single_word": false,
113
+ "special": false
114
+ },
115
+ "14": {
116
+ "content": "<extra_id_14>",
117
+ "lstrip": false,
118
+ "normalized": true,
119
+ "rstrip": false,
120
+ "single_word": false,
121
+ "special": false
122
+ },
123
+ "15": {
124
+ "content": "<extra_id_15>",
125
+ "lstrip": false,
126
+ "normalized": true,
127
+ "rstrip": false,
128
+ "single_word": false,
129
+ "special": false
130
+ },
131
+ "16": {
132
+ "content": "<extra_id_16>",
133
+ "lstrip": false,
134
+ "normalized": true,
135
+ "rstrip": false,
136
+ "single_word": false,
137
+ "special": false
138
+ },
139
+ "17": {
140
+ "content": "<extra_id_17>",
141
+ "lstrip": false,
142
+ "normalized": true,
143
+ "rstrip": false,
144
+ "single_word": false,
145
+ "special": false
146
+ },
147
+ "18": {
148
+ "content": "<extra_id_18>",
149
+ "lstrip": false,
150
+ "normalized": true,
151
+ "rstrip": false,
152
+ "single_word": false,
153
+ "special": false
154
+ },
155
+ "19": {
156
+ "content": "<extra_id_19>",
157
+ "lstrip": false,
158
+ "normalized": true,
159
+ "rstrip": false,
160
+ "single_word": false,
161
+ "special": false
162
+ },
163
+ "20": {
164
+ "content": "<extra_id_20>",
165
+ "lstrip": false,
166
+ "normalized": true,
167
+ "rstrip": false,
168
+ "single_word": false,
169
+ "special": false
170
+ },
171
+ "21": {
172
+ "content": "<extra_id_21>",
173
+ "lstrip": false,
174
+ "normalized": true,
175
+ "rstrip": false,
176
+ "single_word": false,
177
+ "special": false
178
+ },
179
+ "22": {
180
+ "content": "<extra_id_22>",
181
+ "lstrip": false,
182
+ "normalized": true,
183
+ "rstrip": false,
184
+ "single_word": false,
185
+ "special": false
186
+ },
187
+ "23": {
188
+ "content": "<extra_id_23>",
189
+ "lstrip": false,
190
+ "normalized": true,
191
+ "rstrip": false,
192
+ "single_word": false,
193
+ "special": false
194
+ },
195
+ "24": {
196
+ "content": "<extra_id_24>",
197
+ "lstrip": false,
198
+ "normalized": true,
199
+ "rstrip": false,
200
+ "single_word": false,
201
+ "special": false
202
+ },
203
+ "25": {
204
+ "content": "<extra_id_25>",
205
+ "lstrip": false,
206
+ "normalized": true,
207
+ "rstrip": false,
208
+ "single_word": false,
209
+ "special": false
210
+ },
211
+ "26": {
212
+ "content": "<extra_id_26>",
213
+ "lstrip": false,
214
+ "normalized": true,
215
+ "rstrip": false,
216
+ "single_word": false,
217
+ "special": false
218
+ },
219
+ "27": {
220
+ "content": "<extra_id_27>",
221
+ "lstrip": false,
222
+ "normalized": true,
223
+ "rstrip": false,
224
+ "single_word": false,
225
+ "special": false
226
+ },
227
+ "28": {
228
+ "content": "<extra_id_28>",
229
+ "lstrip": false,
230
+ "normalized": true,
231
+ "rstrip": false,
232
+ "single_word": false,
233
+ "special": false
234
+ },
235
+ "29": {
236
+ "content": "<extra_id_29>",
237
+ "lstrip": false,
238
+ "normalized": true,
239
+ "rstrip": false,
240
+ "single_word": false,
241
+ "special": false
242
+ },
243
+ "30": {
244
+ "content": "<extra_id_30>",
245
+ "lstrip": false,
246
+ "normalized": true,
247
+ "rstrip": false,
248
+ "single_word": false,
249
+ "special": false
250
+ },
251
+ "31": {
252
+ "content": "<extra_id_31>",
253
+ "lstrip": false,
254
+ "normalized": true,
255
+ "rstrip": false,
256
+ "single_word": false,
257
+ "special": false
258
+ },
259
+ "32": {
260
+ "content": "<extra_id_32>",
261
+ "lstrip": false,
262
+ "normalized": true,
263
+ "rstrip": false,
264
+ "single_word": false,
265
+ "special": false
266
+ },
267
+ "33": {
268
+ "content": "<extra_id_33>",
269
+ "lstrip": false,
270
+ "normalized": true,
271
+ "rstrip": false,
272
+ "single_word": false,
273
+ "special": false
274
+ },
275
+ "34": {
276
+ "content": "<extra_id_34>",
277
+ "lstrip": false,
278
+ "normalized": true,
279
+ "rstrip": false,
280
+ "single_word": false,
281
+ "special": false
282
+ },
283
+ "35": {
284
+ "content": "<extra_id_35>",
285
+ "lstrip": false,
286
+ "normalized": true,
287
+ "rstrip": false,
288
+ "single_word": false,
289
+ "special": false
290
+ },
291
+ "36": {
292
+ "content": "<extra_id_36>",
293
+ "lstrip": false,
294
+ "normalized": true,
295
+ "rstrip": false,
296
+ "single_word": false,
297
+ "special": false
298
+ },
299
+ "37": {
300
+ "content": "<extra_id_37>",
301
+ "lstrip": false,
302
+ "normalized": true,
303
+ "rstrip": false,
304
+ "single_word": false,
305
+ "special": false
306
+ },
307
+ "38": {
308
+ "content": "<extra_id_38>",
309
+ "lstrip": false,
310
+ "normalized": true,
311
+ "rstrip": false,
312
+ "single_word": false,
313
+ "special": false
314
+ },
315
+ "39": {
316
+ "content": "<extra_id_39>",
317
+ "lstrip": false,
318
+ "normalized": true,
319
+ "rstrip": false,
320
+ "single_word": false,
321
+ "special": false
322
+ },
323
+ "40": {
324
+ "content": "<extra_id_40>",
325
+ "lstrip": false,
326
+ "normalized": true,
327
+ "rstrip": false,
328
+ "single_word": false,
329
+ "special": false
330
+ },
331
+ "41": {
332
+ "content": "<extra_id_41>",
333
+ "lstrip": false,
334
+ "normalized": true,
335
+ "rstrip": false,
336
+ "single_word": false,
337
+ "special": false
338
+ },
339
+ "42": {
340
+ "content": "<extra_id_42>",
341
+ "lstrip": false,
342
+ "normalized": true,
343
+ "rstrip": false,
344
+ "single_word": false,
345
+ "special": false
346
+ },
347
+ "43": {
348
+ "content": "<extra_id_43>",
349
+ "lstrip": false,
350
+ "normalized": true,
351
+ "rstrip": false,
352
+ "single_word": false,
353
+ "special": false
354
+ },
355
+ "44": {
356
+ "content": "<extra_id_44>",
357
+ "lstrip": false,
358
+ "normalized": true,
359
+ "rstrip": false,
360
+ "single_word": false,
361
+ "special": false
362
+ },
363
+ "45": {
364
+ "content": "<extra_id_45>",
365
+ "lstrip": false,
366
+ "normalized": true,
367
+ "rstrip": false,
368
+ "single_word": false,
369
+ "special": false
370
+ },
371
+ "46": {
372
+ "content": "<extra_id_46>",
373
+ "lstrip": false,
374
+ "normalized": true,
375
+ "rstrip": false,
376
+ "single_word": false,
377
+ "special": false
378
+ },
379
+ "47": {
380
+ "content": "<extra_id_47>",
381
+ "lstrip": false,
382
+ "normalized": true,
383
+ "rstrip": false,
384
+ "single_word": false,
385
+ "special": false
386
+ },
387
+ "48": {
388
+ "content": "<extra_id_48>",
389
+ "lstrip": false,
390
+ "normalized": true,
391
+ "rstrip": false,
392
+ "single_word": false,
393
+ "special": false
394
+ },
395
+ "49": {
396
+ "content": "<extra_id_49>",
397
+ "lstrip": false,
398
+ "normalized": true,
399
+ "rstrip": false,
400
+ "single_word": false,
401
+ "special": false
402
+ },
403
+ "50": {
404
+ "content": "<extra_id_50>",
405
+ "lstrip": false,
406
+ "normalized": true,
407
+ "rstrip": false,
408
+ "single_word": false,
409
+ "special": false
410
+ },
411
+ "51": {
412
+ "content": "<extra_id_51>",
413
+ "lstrip": false,
414
+ "normalized": true,
415
+ "rstrip": false,
416
+ "single_word": false,
417
+ "special": false
418
+ },
419
+ "52": {
420
+ "content": "<extra_id_52>",
421
+ "lstrip": false,
422
+ "normalized": true,
423
+ "rstrip": false,
424
+ "single_word": false,
425
+ "special": false
426
+ },
427
+ "53": {
428
+ "content": "<extra_id_53>",
429
+ "lstrip": false,
430
+ "normalized": true,
431
+ "rstrip": false,
432
+ "single_word": false,
433
+ "special": false
434
+ },
435
+ "54": {
436
+ "content": "<extra_id_54>",
437
+ "lstrip": false,
438
+ "normalized": true,
439
+ "rstrip": false,
440
+ "single_word": false,
441
+ "special": false
442
+ },
443
+ "55": {
444
+ "content": "<extra_id_55>",
445
+ "lstrip": false,
446
+ "normalized": true,
447
+ "rstrip": false,
448
+ "single_word": false,
449
+ "special": false
450
+ },
451
+ "56": {
452
+ "content": "<extra_id_56>",
453
+ "lstrip": false,
454
+ "normalized": true,
455
+ "rstrip": false,
456
+ "single_word": false,
457
+ "special": false
458
+ },
459
+ "57": {
460
+ "content": "<extra_id_57>",
461
+ "lstrip": false,
462
+ "normalized": true,
463
+ "rstrip": false,
464
+ "single_word": false,
465
+ "special": false
466
+ },
467
+ "58": {
468
+ "content": "<extra_id_58>",
469
+ "lstrip": false,
470
+ "normalized": true,
471
+ "rstrip": false,
472
+ "single_word": false,
473
+ "special": false
474
+ },
475
+ "59": {
476
+ "content": "<extra_id_59>",
477
+ "lstrip": false,
478
+ "normalized": true,
479
+ "rstrip": false,
480
+ "single_word": false,
481
+ "special": false
482
+ },
483
+ "60": {
484
+ "content": "<extra_id_60>",
485
+ "lstrip": false,
486
+ "normalized": true,
487
+ "rstrip": false,
488
+ "single_word": false,
489
+ "special": false
490
+ },
491
+ "61": {
492
+ "content": "<extra_id_61>",
493
+ "lstrip": false,
494
+ "normalized": true,
495
+ "rstrip": false,
496
+ "single_word": false,
497
+ "special": false
498
+ },
499
+ "62": {
500
+ "content": "<extra_id_62>",
501
+ "lstrip": false,
502
+ "normalized": true,
503
+ "rstrip": false,
504
+ "single_word": false,
505
+ "special": false
506
+ },
507
+ "63": {
508
+ "content": "<extra_id_63>",
509
+ "lstrip": false,
510
+ "normalized": true,
511
+ "rstrip": false,
512
+ "single_word": false,
513
+ "special": false
514
+ }
515
+ },
516
+ "additional_special_tokens": [
517
+ "<repo_name>",
518
+ "<file_sep>",
519
+ "<t2v_token>",
520
+ "<v2t_token>",
521
+ "<|start_header_id|>",
522
+ "<|end_header_id|>",
523
+ "<|eot_id|>",
524
+ "<extra_id_12>",
525
+ "<extra_id_13>",
526
+ "<extra_id_14>",
527
+ "<extra_id_15>",
528
+ "<extra_id_16>",
529
+ "<extra_id_17>",
530
+ "<extra_id_18>",
531
+ "<extra_id_19>",
532
+ "<extra_id_20>",
533
+ "<extra_id_21>",
534
+ "<extra_id_22>",
535
+ "<extra_id_23>",
536
+ "<extra_id_24>",
537
+ "<extra_id_25>",
538
+ "<extra_id_26>",
539
+ "<extra_id_27>",
540
+ "<extra_id_28>",
541
+ "<extra_id_29>",
542
+ "<extra_id_30>",
543
+ "<extra_id_31>",
544
+ "<extra_id_32>",
545
+ "<extra_id_33>",
546
+ "<extra_id_34>",
547
+ "<extra_id_35>",
548
+ "<extra_id_36>",
549
+ "<extra_id_37>",
550
+ "<extra_id_38>",
551
+ "<extra_id_39>",
552
+ "<extra_id_40>",
553
+ "<extra_id_41>",
554
+ "<extra_id_42>",
555
+ "<extra_id_43>",
556
+ "<extra_id_44>",
557
+ "<extra_id_45>",
558
+ "<extra_id_46>",
559
+ "<extra_id_47>",
560
+ "<extra_id_48>",
561
+ "<extra_id_49>",
562
+ "<extra_id_50>",
563
+ "<extra_id_51>",
564
+ "<extra_id_52>",
565
+ "<extra_id_53>",
566
+ "<extra_id_54>",
567
+ "<extra_id_55>",
568
+ "<extra_id_56>",
569
+ "<extra_id_57>",
570
+ "<extra_id_58>",
571
+ "<extra_id_59>",
572
+ "<extra_id_60>",
573
+ "<extra_id_61>",
574
+ "<extra_id_62>",
575
+ "<extra_id_63>"
576
+ ],
577
+ "auto_map": {
578
+ "AutoProcessor": "processing_evabyte.EvaByteProcessor",
579
+ "AutoTokenizer": [
580
+ "tokenization_evabyte.EvaByteTokenizer",
581
+ null
582
+ ]
583
+ },
584
+ "bos_token": "<bos>",
585
+ "chat_template": "\n{{- bos_token }}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content'] %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{{- '<|start_header_id|>system<|end_header_id|>\n\n' + system_message + '<|eot_id|>'}}\n\n{%- for message in messages %}\n {%- if (message['role'] != 'user') and (message['role'] != 'assistant') %}\n {{- raise_exception('Conversation roles must be user or assistant') }}\n {%- endif %}\n\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] + '<|eot_id|>' }}\n{%- endfor %}\n\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>' + 'assistant' + '<|end_header_id|>\n\n' }}\n{%- endif %}\n",
586
+ "clean_up_tokenization_spaces": false,
587
+ "eos_token": "<eos>",
588
+ "extra_ids": 0,
589
+ "extra_special_tokens": {},
590
+ "model_max_length": 1000000000000000019884624838656,
591
+ "pad_token": "<pad>",
592
+ "processor_class": "EvaByteProcessor",
593
+ "sep_token": "<sep>",
594
+ "tokenizer_class": "EvaByteTokenizer",
595
+ "unk_token": "<unk>"
596
+ }
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/README.md ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ ---
4
+ # EvaByte Model Card
5
+
6
+ **EvaByte** is a 6.5B **byte-level language model** built upon an improved architecture with multibyte prediction and EVA -- an efficient attention mechanism designed for scalability and performance. Trained on 1.5T bytes spanning natural language text, math, and code, EvaByte demonstrates the viability of efficient byte-level processing at scale -- rivaling top open-source tokenizer-based LMs using 5x less training data, excelling in coding tasks, and decoding up to 2x faster.
7
+
8
+ ## Model Resources
9
+
10
+ - **Repository:** https://github.com/openevabyte/evabyte
11
+ - **Blog:** https://hkunlp.github.io/blog/2025/evabyte and https://sambanova.ai/blog/evabyte-efficient-byte-level-language-models-at-scale
12
+ - **Paper:** Coming soon
13
+
14
+ ## Model Details
15
+
16
+ EvaByte is trained using the performant SambaNova SN30 RDU system with a batch size of 8M bytes and 32K context length. The training process consists of 3 phases: after pre-training on 1.2T bytes (yielding **EvaByte-Phase1**), two independent annealing runs (100B and 200B bytes respectively) are conducted with learning rate linearly decayed from 1e-4 to 0. The resulting checkpoints are merged via model soup (**EvaByte**), which then undergoes supervised fine-tuning (**EvaByte-SFT**).
17
+
18
+ | Stage | Model |
19
+ |:----- |:-----|
20
+ | Base (before annealing) | [EvaByte-Phase1](https://huggingface.co/evabyte/EvaByte-Phase1) |
21
+ | Base | [EvaByte](https://huggingface.co/evabyte/EvaByte) <-- you are here |
22
+ | SFT | [EvaByte-SFT](https://huggingface.co/evabyte/EvaByte-SFT) |
23
+
24
+
25
+ ## Usage
26
+
27
+ **Note:** Make sure to set `trust_remote_code=True` when loading the model (or tokenizer), as our implementation includes custom code.
28
+
29
+ The code snippet below demonstrates EvaByte-6.5B for completion:
30
+
31
+ ```python
32
+ from transformers import AutoTokenizer, AutoModelForCausalLM
33
+ import torch
34
+
35
+ # Load model and tokenizer
36
+ tokenizer = AutoTokenizer.from_pretrained("evabyte/EvaByte", trust_remote_code=True)
37
+ model = AutoModelForCausalLM.from_pretrained("evabyte/EvaByte", torch_dtype=torch.bfloat16, trust_remote_code=True).eval().to("cuda")
38
+
39
+ prompt = "The quick brown fox jumps "
40
+
41
+ # Tokenize input
42
+ # Option 1: standard HF tokenizer interface
43
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
44
+
45
+ # Option 2: Direct UTF-8 byte encoding with offset
46
+ # Note: Each byte is offset by 64 with <bos> prepended.
47
+ input_ids = torch.tensor([[1] + [b + 64 for b in prompt.encode("utf-8")]]).to("cuda")
48
+
49
+ # byte-by-byte generation (default)
50
+ generation_output = model.generate(
51
+ input_ids=input_ids,
52
+ max_new_tokens=32
53
+ )
54
+ # alternatively, use faster multibyte generation
55
+ generation_output = model.multi_byte_generate(
56
+ input_ids=input_ids,
57
+ max_new_tokens=32
58
+ )
59
+
60
+ # Decode and print the output
61
+ response = tokenizer.decode(
62
+ generation_output[0][input_ids.shape[1]:],
63
+ skip_special_tokens=False,
64
+ clean_up_tokenization_spaces=False
65
+ )
66
+ print(response)
67
+ # Sample output:
68
+ # over the lazy dog.\n\nThe quick
69
+ ```
70
+
71
+ ### ⚙️ Generation Modes
72
+
73
+ EvaByte supports two generation interfaces:
74
+ - `model.generate()`: The default generation method compatible with Huggingface `transformers` library. This approach generates one byte at a time and might be slow.
75
+ - `model.multi_byte_generate()`: A faster alternative that generates multiple bytes per step and usually yields the same result as `model.generate()` under greedy decoding, with the implementation adapted from [Medusa](https://github.com/FasterDecoding/Medusa). `model.multi_byte_generate()` supports a subset of arguments in `model.generate()`:
76
+ - `input_ids`: the input byte ids.
77
+ - `temperature`: the temperature for sampling.
78
+ - `max_length`: the maximum length of the generated sequence.
79
+ - `max_new_tokens`: the maximum number of new bytes to generate.
80
+ - `stopping_criteria`: the [stopping criteria](https://huggingface.co/docs/transformers/v4.47.1/en/internal/generation_utils#transformers.StoppingCriteria) for generation.
81
+ - `top_p`: the top-p parameter for sampling.
82
+ - `do_sample`: greedy decoding or sampling.
83
+
84
+ **Notes and Limitations:**
85
+ - `device_map="auto"` is not supported for >2 GPUs.
86
+ - Only batch size of 1 (with `attention_mask=None`) is supported for decoding.
87
+ - `torch_dtype=torch.bfloat16` is required.
88
+ - The multibyte generation `model.multi_byte_generate()` might return extra bytes after the end-of-sequence sentinel, due to the nature of the multibyte decoding. Manual truncation or cleaning may be needed.
89
+
90
+ ## Bias, Risks, and Limitations
91
+ As a pretrained base model, **EvaByte** has not been fine-tuned for chat or instruction following, so users should not expect reliable performance in conversational or instruction-based tasks. Like other base models, it does not incorporate any moderation mechanisms, making it possible to generate potentially harmful or inappropriate content.
92
+
93
+ ## Evaluation
94
+
95
+ For detailed evaluation results, check out our blog post at [SambaNova](https://sambanova.ai/blog/evabyte-efficient-byte-level-language-models-at-scale) or [HKUNLP](https://hkunlp.github.io/blog/2025/evabyte).
96
+
97
+ ## Citation
98
+ ```bibtex
99
+ @misc{evabyte,
100
+ title = {EvaByte: Efficient Byte-level Language Models at Scale},
101
+ url = {https://hkunlp.github.io/blog/2025/evabyte},
102
+ author = {Lin Zheng and Xueliang Zhao and Guangtao Wang and Chen Wu and David Dong and Angela Wang and Mingran Wang and Yun Du and Haige Bo and Amol Sharma and Bo Li and Kejie Zhang and Changran Hu and Urmish Thakker and Lingpeng Kong},
103
+ year = {2025}
104
+ }
105
+ ```
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/config.json ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": null,
3
+ "architectures": [
4
+ "EvaByteForCausalLM"
5
+ ],
6
+ "attention_bias": false,
7
+ "attention_class": "eva",
8
+ "attention_dropout": 0.0,
9
+ "auto_map": {
10
+ "AutoConfig": "configuration_evabyte.EvaByteConfig",
11
+ "AutoModelForCausalLM": "modeling_evabyte.EvaByteForCausalLM"
12
+ },
13
+ "bos_token_id": 1,
14
+ "chunk_size": 16,
15
+ "eos_token_id": 2,
16
+ "fp32_ln": true,
17
+ "fp32_logits": true,
18
+ "fp32_skip_add": false,
19
+ "hidden_act": "silu",
20
+ "hidden_size": 5120,
21
+ "init_cutoff_factor": null,
22
+ "init_fn": "v2",
23
+ "init_std": 0.01275,
24
+ "initializer_range": 0.01275,
25
+ "intermediate_size": 16384,
26
+ "lazy_init": true,
27
+ "max_position_embeddings": 16384,
28
+ "max_seq_length": 16384,
29
+ "mixedp_attn": true,
30
+ "model_type": "evabyte",
31
+ "norm_add_unit_offset": true,
32
+ "num_attention_heads": 40,
33
+ "num_chunks": null,
34
+ "num_hidden_layers": 40,
35
+ "num_key_value_heads": 40,
36
+ "num_pred_heads": 1,
37
+ "pad_token_id": 0,
38
+ "return_dict": false,
39
+ "rms_norm_eps": 1e-06,
40
+ "rope_scaling": null,
41
+ "rope_theta": 100000.0,
42
+ "tie_word_embeddings": false,
43
+ "torch_dtype": "bfloat16",
44
+ "transformers_version": "4.47.1",
45
+ "use_cache": true,
46
+ "vocab_size": 320,
47
+ "window_size": 2048
48
+ }
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/configuration_evabyte.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ EvaByte configuration"""
2
+
3
+ from transformers.configuration_utils import PretrainedConfig
4
+
5
class EvaByteConfig(PretrainedConfig):
    """Configuration class for EvaByte models.

    Holds the transformer backbone hyperparameters (sizes, activation,
    RMSNorm epsilon, RoPE settings) together with EVA-attention-specific
    settings (local window size and chunk/landmark layout). The attributes
    mirror the arguments accepted by ``__init__``.
    """

    model_type = "evabyte"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=320,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        norm_add_unit_offset=False,
        init_fn="mitchell",
        init_std=0.006,
        init_cutoff_factor=None,
        attention_class="mha",
        window_size=512,
        num_chunks=None,
        chunk_size=256,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # For backward compatibility: configs that predate grouped-query
        # attention use one KV head per attention head.
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        # Fail fast on a malformed rope_scaling dict rather than at runtime.
        self._rope_scaling_validation()
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout

        self.norm_add_unit_offset = norm_add_unit_offset
        self.init_fn = init_fn
        self.init_std = init_std
        self.init_cutoff_factor = init_cutoff_factor

        # Attention-specific parameters (EVA): local window plus either a
        # fixed number of landmark chunks or a fixed chunk size.
        self.attention_class = attention_class
        self.window_size = window_size
        self.num_chunks = num_chunks
        self.chunk_size = chunk_size

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    def _rope_scaling_validation(self):
        """
        Validate the `rope_scaling` configuration.

        Raises:
            ValueError: if `rope_scaling` is not `None` and is not a dict of
                exactly ``{"type": "linear"|"dynamic", "factor": number > 1}``.
        """
        if self.rope_scaling is None:
            return

        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
            raise ValueError(
                "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}"
            )
        rope_scaling_type = self.rope_scaling.get("type", None)
        rope_scaling_factor = self.rope_scaling.get("factor", None)
        if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
            raise ValueError(
                f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
            )
        # Accept ints as well as floats (JSON configs commonly store e.g. 2
        # rather than 2.0); bool is excluded since it subclasses int.
        if (
            rope_scaling_factor is None
            or isinstance(rope_scaling_factor, bool)
            or not isinstance(rope_scaling_factor, (int, float))
            or rope_scaling_factor <= 1.0
        ):
            raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/eva.py ADDED
@@ -0,0 +1,424 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Optional, Tuple, List, Any, Union
2
+ import torch
3
+ from torch import nn
4
+ import torch.nn.functional as F
5
+ from .eva_agg_kernel import eva_agg_func_triton
6
+ from .eva_prep_kv_kernel import eva_prep_kv_func_triton
7
+ try:
8
+ import triton
9
+ USE_TRITON_IMPL = True
10
+ except ImportError:
11
+ USE_TRITON_IMPL = False
12
+ raise ImportError("Triton is not installed. Please install it by running `pip install triton`.")
13
+
14
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """
    Rotate the two halves of the last dimension: (x1, x2) -> (-x2, x1).

    Args:
        x: Rotary embedded tensor.
    Return:
        Tensor of the same shape with the negated second half moved in
        front of the first half along the last dimension.
    """
    half_size = x.shape[-1] // 2
    first_half, second_half = x.split(half_size, dim=-1)
    return torch.cat((-second_half, first_half), dim=-1)
24
+
25
def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                         position_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary position embeddings (cos, sin) to the query and key tensors.

    Dimension legend:
        num_heads: number of attention heads
        current_seq_len: this batch's sequence length (1 during single-step
            cached decoding, otherwise the full sequence length)
        head_dim: per-head feature dimension

    Args:
        q: Query tensor of shape (batch_size, num_heads, current_seq_len, head_dim).
        k: Key tensor of shape (batch_size, num_heads, current_seq_len, head_dim).
        cos: Cosine table; expected 4-D and broadcastable against q/k, with
            cos.shape[3] == head_dim (enforced by the assertion below).
        sin: Sine table with the same layout as `cos`.
        position_ids: Token positions of shape (batch_size, current_seq_len)
            or (1, current_seq_len). NOTE(review): only shape-validated here,
            never indexed -- presumably cos/sin are already gathered per
            position by the caller; confirm.

    Returns:
        Tuple of (q_embed, k_embed), each the same shape as its input.
    """
    bs, nheads, cur_seq_len, head_dim = q.shape
    # Validate k against q's layout before mixing them with the same tables.
    assert len(
        k.shape) == 4, f"k should be of shape (batch_size, num_heads, current_seq_len, head_dim), got {k.shape} instead"
    assert k.shape[0] == bs, f"k has a different batch_size {k.shape[0]} compared to q {bs}"
    assert list(k.shape[2:]) == [cur_seq_len,
                                 head_dim], f"k has different current_seq_len and/or head_dim compared to q"
    assert cos.shape[3] == head_dim, f"cos should have dim of head dim {head_dim}, got {cos.shape[3]} instead"
    assert list(position_ids.shape) in [[bs, cur_seq_len], [1, cur_seq_len]],\
        f"position_ids should be of shape {[bs, cur_seq_len]} or {[1, cur_seq_len]}, got {position_ids.shape} instead"

    # Standard RoPE: pairwise rotation expressed as x*cos + rotate_half(x)*sin.
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
62
+
63
+ class EvaAttention(nn.Module):
64
+ """
65
+ Causal EVA for language modeling.
66
+ """
67
+
68
    def __init__(self, config, layer_idx: Optional[int] = None):
        """Initialize an EVA attention layer.

        Args:
            config: model configuration; this reads hidden_size,
                num_attention_heads, max_position_embeddings, window_size,
                num_chunks and chunk_size.
            layer_idx: index of this layer, used elsewhere to address the
                per-layer entries of the KV cache.
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        # 1/sqrt(head_dim): the usual attention scaling factor, also reused
        # below to scale the adaptive parameters at initialization.
        self.head_dim_scaling = self.head_dim ** -0.5

        self.max_position_embeddings = config.max_position_embeddings

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        # Separate Q/K/V/O projections, all bias-free.
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.window_size = config.window_size

        self.num_chunks = config.num_chunks
        self.chunk_size = config.chunk_size
        if self.chunk_size is not None:
            # Each local window must cover a whole number of chunks.
            assert self.window_size >= self.chunk_size and self.window_size % self.chunk_size == 0
            # chunk_size overrides the number of landmarks
            self.num_chunks = None

        # NOTE(review): this divides by chunk_size unconditionally, so a config
        # with chunk_size=None (num_chunks-only) would raise here -- confirm
        # that chunk_size is always set for this attention class.
        self.chunks_per_window = int(self.window_size // self.chunk_size)
        # Learned per-head vectors of shape [1, num_heads, 1, 1, head_dim],
        # drawn from a standard normal clamped to [-1, 1] and scaled by
        # 1/sqrt(head_dim). NOTE(review): presumably the landmark/random-feature
        # parameters consumed by the EVA Triton kernels -- confirm against
        # eva_agg_kernel / eva_prep_kv_kernel.
        self.adaptive_phi = nn.Parameter(
            torch.randn(
                1,
                self.num_heads,
                1,
                1,
                self.head_dim
            ).clamp(-1., 1.) * self.head_dim_scaling
        )
        self.adaptive_mu_k = nn.Parameter(
            torch.randn(
                1,
                self.num_heads,
                1,
                1,
                self.head_dim
            ).clamp(-1., 1.) * self.head_dim_scaling
        )
117
+
118
    def _triton_forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cos: Optional[torch.Tensor] = None,
        sin: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Triton-backed EVA attention forward.

        Handles three running modes, inferred from `use_cache` and the cache state:
          * prefilling: cache is empty; full-sequence triton aggregation kernel.
          * decoding:   cache already holds tokens; SDPA over cached window keys/values
                        plus chunk-level RFA summaries.
          * plain forward (no cache): full-sequence triton aggregation kernel.

        `attention_mask` is a tuple whose arity/ordering depends on the mode (see the
        per-mode unpacking below); `cos`/`sin` are the rotary embedding tables.
        Returns `(attn_output, None, past_key_value)` — attention weights are never
        materialized (hence the `output_attentions` assert).
        """
        assert not output_attentions
        bsz, q_len, _ = hidden_states.size()

        if use_cache:
            if past_key_value is None:
                raise ValueError
            # cached inference always passes mode-specific mask tuples
            assert isinstance(attention_mask, tuple)

        # infer the model's running mode from how much the cache has seen
        is_prefilling = use_cache and past_key_value.get_seq_length(self.layer_idx) == 0
        is_decoding = use_cache and past_key_value.get_seq_length(self.layer_idx) > 0

        # NOTE(review): the tuple element order differs between the decoding and
        # no-cache branches (intra_chunk vs chunk mask swapped) — presumably this
        # mirrors how callers build the masks per mode; confirm against the caller.
        if is_prefilling:
            assert len(attention_mask) == 2
            window_mask, intra_chunk_mask = attention_mask
            chunk_mask = None
        elif is_decoding:
            assert len(attention_mask) == 3
            window_mask, intra_chunk_mask, chunk_mask = attention_mask
        else:
            if attention_mask is not None:
                assert isinstance(attention_mask, tuple) and len(attention_mask) == 3
                window_mask, chunk_mask, intra_chunk_mask = attention_mask
            else:
                window_mask, chunk_mask, intra_chunk_mask = None, None, None

        ############################################
        # compute q, k, v from hidden states
        ############################################
        # [b, h, q_len, d]
        q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        # [b, h, kv_len, d]
        k = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        # [b, h, kv_len, d]
        v = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        if use_cache:
            # record the number of newly seen positions for this layer
            past_key_value.update_past_len(q.shape[-2], self.layer_idx)

        ############################################
        # apply rotary positional embeddings to q, k
        ############################################
        q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)

        ############################################
        # update and get cached singleton tokens
        # update and cache k and v for calculating chunk-level RFAs
        ############################################
        if use_cache:
            # s_k/s_v: singleton (exact) keys/values to attend to;
            # dump_k/dump_v: completed-chunk keys/values to summarize via RFA
            # (None when the current chunk is not yet full)
            s_k, s_v, dump_k, dump_v = past_key_value.update_singletons_and_chunks(
                k,
                v,
                self.layer_idx,
                self.window_size,
            )
        else:
            s_k, s_v = k, v
            dump_k, dump_v = k, v

        if use_cache:
            singleton_mask, dump_rf_mask = past_key_value.update_mask(
                s_mask=window_mask,
                rf_mask=intra_chunk_mask,
                layer_idx=self.layer_idx,
                window_size=self.window_size,
            )
        else:
            singleton_mask = window_mask
            dump_rf_mask = intra_chunk_mask

        if dump_k is not None and dump_v is not None:
            # 1. in prefilling, the input shape is
            #     dump_k/dump_v: [b, h, n, d]
            #     rfa_k/rfa_v: [b, h, n // c, d]
            # 2. in decoding, the input shape is
            #     k/v: [b, h, w, d]
            #     rfa_k/rfa_v: [b, h, w//c, d]
            # 3. in forward inference; the seq_len is already divisible
            rfa_k, rfa_v = eva_prep_kv_func_triton(
                dump_k, dump_v,
                self.adaptive_mu_k, self.adaptive_phi,
                dump_rf_mask, self.head_dim_scaling, self.chunk_size
            )
            # rfa_mask = get_rfa_chunk_mask(dump_rf_mask)
            if use_cache:
                rfa_k, rfa_v = past_key_value.update_chunk_rfas(
                    rfa_k, rfa_v, self.layer_idx
                )
        elif use_cache:
            # if there are not enough elements within the last chunk,
            # we will only use the cached chunk-level RFAs
            rfa_k, rfa_v = past_key_value.get_chunk_rfas(self.layer_idx)
        else:
            rfa_k, rfa_v = None, None

        ############################################
        # compute the full attention output
        ############################################
        if is_prefilling:
            # prefilling
            # 1. in prefilling, the input shape is
            #     q: [b, h, n, d]
            #     k/v: [b, h, n, d]
            #     rfa_k/rfa_v: [b, h, n // c, d]
            attn_output = eva_agg_func_triton(
                q, s_k, s_v,
                rfa_k, rfa_v,
                singleton_mask, chunk_mask,
                self.head_dim_scaling, self.window_size, self.chunks_per_window
            )
        elif is_decoding:
            # 2. in decoding, the input shape is
            #     q: [b, h, 1, d] or [b, h, z, d] (for multi-byte prediction)
            #     k/v: [b, h, 1 + s, d]
            #     rfa_k/rfa_v: [b, h, n // c, d]
            if rfa_k is not None and rfa_v is not None:
                # we only take the chunk-level RFAs not in the current window
                seen_seq_len = past_key_value.get_seq_length(self.layer_idx)
                if seen_seq_len <= self.window_size:
                    # still within the first window: exact attention only
                    agg_k = s_k
                    agg_v = s_v
                    attn_mask = singleton_mask
                else:
                    # NOTE: we already updated the cache so the length now
                    # includes the current token
                    # we subtract 1 from seen_seq_len because we want
                    # if seen_seq_len = 2048 -> num_windows_seen_so_far = 0
                    # if seen_seq_len = 4096 -> num_windows_seen_so_far = 1
                    # if seen_seq_len = 4097 -> num_windows_seen_so_far = 2
                    # NOTE the cat order should be taken care of;
                    # should align with the order based on which
                    # the attention mask is constructed
                    num_windows_seen_so_far = (seen_seq_len - 1) // self.window_size
                    agg_k = torch.cat([s_k, rfa_k[..., :num_windows_seen_so_far * self.chunks_per_window, :]], dim=-2)
                    agg_v = torch.cat([s_v, rfa_v[..., :num_windows_seen_so_far * self.chunks_per_window, :]], dim=-2)
                    if singleton_mask is not None:
                        assert chunk_mask is not None
                        attn_mask = torch.cat([singleton_mask, chunk_mask], dim=-1)
                    else:
                        attn_mask = singleton_mask
            else:
                agg_k = s_k
                agg_v = s_v
                attn_mask = singleton_mask
            attn_output = F.scaled_dot_product_attention(
                q, agg_k, agg_v,
                attn_mask=attn_mask,
                is_causal=False,
                dropout_p=0.0,
                scale=self.head_dim_scaling
            )
        else:
            # 3. in single-forward inference
            attn_output = eva_agg_func_triton(
                q, s_k, s_v,
                rfa_k, rfa_v,
                singleton_mask, chunk_mask,
                self.head_dim_scaling, self.window_size, self.chunks_per_window
            )
        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )
        # [b, h, q_len, d] -> [b, q_len, h*d] -> output projection
        attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        attn_weights = None
        return attn_output, attn_weights, past_key_value
+
298
    def _multibyte_decoding_forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cos: Optional[torch.Tensor] = None,
        sin: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Speculative multi-byte decoding step.

        Unlike `_triton_forward`, this path writes the new k/v directly into the
        cache's window buffers in place (without advancing the cache's bookkeeping
        via `update_*` calls) and attends with SDPA over cached chunk-level RFAs
        concatenated with the window keys/values. `attention_mask` must be a single
        boolean tensor laid out as [rfa chunks | window tokens] along the key axis.
        Returns `(attn_output, None, past_key_value)`.
        """
        # during multi-byte forwarding, we only read caches and do not update them
        assert not output_attentions
        bsz, q_len, _ = hidden_states.size()

        if use_cache and past_key_value is None:
            raise ValueError

        assert USE_TRITON_IMPL
        assert isinstance(attention_mask, torch.Tensor) and attention_mask.dtype == torch.bool

        # this path is only valid once the cache has been prefilled
        assert use_cache and past_key_value.get_seq_length(self.layer_idx) > 0

        ############################################
        # compute q, k, v from hidden states
        ############################################
        # [b, h, q_len, d]
        q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        # [b, h, kv_len, d]
        k = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        # [b, h, kv_len, d]
        v = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        ############################################
        # apply rotary positional embeddings to q, k
        ############################################
        q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)

        ############################################
        # update and get cached singleton tokens
        ############################################
        # write the new keys/values into the preallocated window buffers at the
        # current write position; note past_window_pos itself is NOT advanced here
        input_len = k.shape[-2]
        window_pos = past_key_value.past_window_pos[self.layer_idx]
        new_window_pos = window_pos + input_len

        past_key_value.past_window_k[self.layer_idx][:, :, window_pos : new_window_pos, :] = k
        past_key_value.past_window_v[self.layer_idx][:, :, window_pos : new_window_pos, :] = v
        s_k = past_key_value.past_window_k[self.layer_idx][:, :, : new_window_pos, :]
        s_v = past_key_value.past_window_v[self.layer_idx][:, :, : new_window_pos, :]

        rfa_k, rfa_v = past_key_value.get_chunk_rfas(self.layer_idx)

        ############################################
        # compute the full attention output
        ############################################
        # 2. in decoding, the input shape is
        #     q: [b, h, 1, d] or [b, h, z, d] (for multi-byte prediction)
        #     k/v: [b, h, 1 + s, d]
        #     rfa_k/rfa_v: [b, h, n // c, d]
        if rfa_k is not None and rfa_v is not None:
            # NOTE the cat order should be taken care of;
            # should align with the order based on which
            # the attention mask is constructed
            # agg_k = torch.cat([s_k, rfa_k], dim=-2)
            # agg_v = torch.cat([s_v, rfa_v], dim=-2)
            agg_k = torch.cat([rfa_k, s_k], dim=-2)
            agg_v = torch.cat([rfa_v, s_v], dim=-2)
        else:
            agg_k = s_k
            agg_v = s_v
        attn_output = F.scaled_dot_product_attention(
            q, agg_k, agg_v,
            attn_mask=attention_mask,
            is_causal=False,
            dropout_p=0.0,
            scale=self.head_dim_scaling
        )

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )
        # [b, h, q_len, d] -> [b, q_len, h*d] -> output projection
        attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        attn_weights = None
        return attn_output, attn_weights, past_key_value
385
+
386
+ def forward(
387
+ self,
388
+ hidden_states: torch.Tensor,
389
+ attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
390
+ position_ids: Optional[torch.LongTensor] = None,
391
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
392
+ output_attentions: bool = False,
393
+ use_cache: bool = False,
394
+ cos: Optional[torch.Tensor] = None,
395
+ sin: Optional[torch.Tensor] = None,
396
+ multibyte_decoding: Optional[bool] = False,
397
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
398
+ assert not output_attentions
399
+ if use_cache and past_key_value is None:
400
+ raise ValueError
401
+
402
+ assert USE_TRITON_IMPL
403
+ if use_cache and multibyte_decoding:
404
+ return self._multibyte_decoding_forward(
405
+ hidden_states,
406
+ attention_mask=attention_mask,
407
+ position_ids=position_ids,
408
+ past_key_value=past_key_value,
409
+ output_attentions=output_attentions,
410
+ use_cache=use_cache,
411
+ cos=cos,
412
+ sin=sin,
413
+ )
414
+ else:
415
+ return self._triton_forward(
416
+ hidden_states,
417
+ attention_mask=attention_mask,
418
+ position_ids=position_ids,
419
+ past_key_value=past_key_value,
420
+ output_attentions=output_attentions,
421
+ use_cache=use_cache,
422
+ cos=cos,
423
+ sin=sin,
424
+ )
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/eva_agg_kernel.py ADDED
@@ -0,0 +1,1766 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import math
3
+ import torch
4
+ import triton
5
+ import triton.language as tl
6
+
7
# Backward kernel for the "singleton" (exact, intra-window) part of EVA
# attention: accumulates dK and dV for one BLOCK_N slice of keys/values.
# Grid: axis 0 = key blocks, axis 1 = batch * heads. Each program loads its
# K/V tile once, then loops over the query blocks that can attend to it
# (from the tile's own diagonal block up to the end of its window).
# Inputs follow the flash-attention backward recipe: LSE is the per-row
# log-sum-exp from the forward pass; DO_T_O is a per-row scalar combined as
# (dp - do_t_o) below — presumably rowwise sum(dO * O), the standard
# flash-attention "delta"; confirm against the forward/preprocess kernel.
# MASK_TYPE == 1 means an explicit per-window boolean mask is supplied
# (nonzero entries are masked OUT); otherwise a plain causal mask is applied.
@triton.heuristics(
    {
        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
        "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
        "EVEN_W": lambda args: args["WINDOW_SIZE"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def _bwd_eva_agg_kernel_dkdv(
    Q,
    K,
    V,
    WindowMask,
    DO,
    LSE,
    DO_T_O,
    DK,
    DV,
    softmax_scale,
    stride_qb, stride_qh, stride_qm,
    stride_kb, stride_kh, stride_kn,
    stride_vb, stride_vh, stride_vn,
    stride_window_mask_b, stride_window_mask_m,
    stride_do_b, stride_do_h, stride_do_m,
    stride_lse_b, stride_lse_h,
    stride_do_t_o_b, stride_do_t_o_h,
    stride_dk_b, stride_dk_h, stride_dk_n,
    stride_dv_b, stride_dv_h, stride_dv_n,
    nheads,
    seqlen_q,
    seqlen_k,
    headdim,
    WINDOW_SIZE: tl.constexpr,
    MASK_TYPE: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_W: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # decompose the flattened (batch * head) program id
    off_bh = tl.program_id(1)
    off_h = off_bh % nheads
    off_b = off_bh // nheads

    start_n = tl.program_id(0)
    # determine which window the current KV block belongs to
    offs_w = (start_n * BLOCK_N) // WINDOW_SIZE
    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_m = tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_HEADDIM)

    # initialize pointers (head-dim is contiguous: offs_d added without a stride)
    q_ptrs = (
        Q +
        off_b * stride_qb +
        off_h * stride_qh +
        offs_m[:, None] * stride_qm + offs_d[None, :]
    )
    k_ptrs = (
        K +
        off_b * stride_kb +
        off_h * stride_kh +
        offs_n[:, None] * stride_kn + offs_d[None, :]
    )
    v_ptrs = (
        V +
        off_b * stride_vb +
        off_h * stride_vh +
        offs_n[:, None] * stride_vn + offs_d[None, :]
    )
    do_ptrs = (
        DO +
        off_b * stride_do_b +
        off_h * stride_do_h +
        offs_m[:, None] * stride_do_m + offs_d[None, :]
    )
    # DO_T_O and LSE are one scalar per query row
    do_t_o_ptrs = (
        DO_T_O +
        off_b * stride_do_t_o_b +
        off_h * stride_do_t_o_h +
        offs_m[:, None]
    )
    lse_ptrs = (
        LSE +
        off_b * stride_lse_b +
        off_h * stride_lse_h +
        offs_m[:, None]
    )
    if MASK_TYPE == 1:
        # WindowMask is shared across heads: no head stride; key offset is
        # taken modulo the window (see the -(offs_w * WINDOW_SIZE) shift below)
        m_ptrs = (
            WindowMask +
            off_b * stride_window_mask_b +
            (offs_m[:, None] * stride_window_mask_m + offs_n[None, :])
        )
    dk_ptrs = (
        DK +
        off_b * stride_dk_b +
        off_h * stride_dk_h +
        offs_n[:, None] * stride_dk_n + offs_d[None, :]
    )
    dv_ptrs = (
        DV +
        off_b * stride_dv_b +
        off_h * stride_dv_h +
        offs_n[:, None] * stride_dv_n + offs_d[None, :]
    )

    # 1. for singletons
    # determine start and end of query block: queries before the diagonal
    # cannot attend to this key tile (causality); queries past the end of
    # this tile's window never see it either.
    begin_m = ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M
    end_m = tl.minimum((offs_w + 1) * WINDOW_SIZE, seqlen_q)

    # fp32 accumulators for the gradients of this key/value tile
    dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    # load K and V once; re-used across all query blocks in the loop below
    if EVEN_N & EVEN_M:
        if EVEN_HEADDIM:
            k = tl.load(k_ptrs)
            v = tl.load(v_ptrs)
        else:
            k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
            v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
    else:
        if EVEN_HEADDIM:
            k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
            v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
        else:
            k = tl.load(
                k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0
            )
            v = tl.load(
                v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0
            )
    for start_m in range(begin_m, end_m, BLOCK_M):
        start_m = tl.multiple_of(start_m, BLOCK_M)
        # load q, do, and lse
        if EVEN_M & EVEN_N:
            if EVEN_HEADDIM:
                q = tl.load(
                    q_ptrs + start_m * stride_qm
                )
                do = tl.load(
                    do_ptrs + start_m * stride_do_m
                )
            else:
                q = tl.load(
                    q_ptrs + start_m * stride_qm,
                    mask=offs_d[None, :] < headdim,
                    other=0.0
                )
                do = tl.load(
                    do_ptrs + start_m * stride_do_m,
                    mask=offs_d[None, :] < headdim,
                    other=0.0
                )
            do_t_o = tl.load(
                do_t_o_ptrs + start_m
            )
            lse = tl.load(
                lse_ptrs + start_m
            )
        else:
            if EVEN_HEADDIM:
                q = tl.load(
                    q_ptrs + start_m * stride_qm,
                    mask=(start_m + offs_m)[:, None] < seqlen_q,
                    other=0.0
                )
                do = tl.load(
                    do_ptrs + start_m * stride_do_m,
                    mask=(start_m + offs_m)[:, None] < seqlen_q,
                    other=0.0
                )
            else:
                q = tl.load(
                    q_ptrs + start_m * stride_qm,
                    mask=((start_m + offs_m)[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                    other=0.0
                )
                do = tl.load(
                    do_ptrs + start_m * stride_do_m,
                    mask=((start_m + offs_m)[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                    other=0.0
                )
            do_t_o = tl.load(
                do_t_o_ptrs + start_m,
                mask=(start_m + offs_m)[:, None] < seqlen_q,
                other=0.0
            )
            lse = tl.load(
                lse_ptrs + start_m,
                mask=(start_m + offs_m)[:, None] < seqlen_q,
                other=0.0
            )
        # rows that were fully masked in the forward pass have lse == -inf;
        # zero them so exp(qk - lse) stays finite (p comes out 0 there)
        lse = tl.where(lse == float("-inf"), 0.0, lse)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, tl.trans(k))
        if not EVEN_M:
            # out-of-range query rows contribute nothing
            qk += tl.where((start_m + offs_m)[:, None] < seqlen_q, 0, float("-inf"))

        if MASK_TYPE == 1:
            # shift the key offset into window-local coordinates
            if EVEN_M & EVEN_W:
                mask = tl.load(
                    m_ptrs + (start_m * stride_window_mask_m) - (offs_w * WINDOW_SIZE)
                )
            else:
                mask = tl.load(
                    m_ptrs + (start_m * stride_window_mask_m) - (offs_w * WINDOW_SIZE),
                    mask=((start_m + offs_m)[:, None] < seqlen_q)
                    & (((start_m * stride_window_mask_m) - (offs_w * WINDOW_SIZE) + offs_n)[None, :] < WINDOW_SIZE),
                    other=1,
                )
            # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
            # can then fuse the mult and add into an fma instruction. But if we have bias we need to
            # to multiply with softmax_scale here.
            # we assume mask already implies the causal masking
            # (nonzero mask entries are masked OUT)
            qk = qk * softmax_scale
            qk = tl.where(mask, float("-inf"), qk)
            p = tl.exp(qk - lse)
        else:
            # no explicit mask: apply plain causal masking (q_pos >= k_pos kept)
            qk += tl.where((start_m + offs_m)[:, None] >= offs_n[None, :], 0, float("-inf"))
            p = tl.exp(qk * softmax_scale - lse)

        # dp [M, N]
        dp = tl.dot(do, tl.trans(v))
        # p [M, N], dp [M, N], do_t_o [M, 1] -> ds [M, N]
        # (softmax backward: dS = P * (dP - delta), folded with the scale)
        ds = (p * (dp - do_t_o) * softmax_scale).to(q.dtype)
        # p is fp32 and [M, N], convert to q.dtype
        # do [M, D] -> dv [N, D]
        dv += tl.dot(tl.trans(p.to(do.dtype)), do)
        # dk [N, D]
        dk += tl.dot(tl.trans(ds), q)
    # write back the accumulated gradients for this key/value tile
    if EVEN_N & EVEN_M:
        if EVEN_HEADDIM:
            tl.store(dv_ptrs, dv)
            tl.store(dk_ptrs, dk)
        else:
            tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
            tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
    else:
        if EVEN_HEADDIM:
            tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
            tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
        else:
            tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
            tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
255
+
256
# Backward kernel for the chunk-level RFA part of EVA attention: accumulates
# gradients d(RFA_K) and d(RFA_V) for one BLOCK_N slice of RFA chunks.
# Grid: axis 0 = chunk blocks, axis 1 = batch * heads. A chunk summary is only
# visible to queries in LATER windows, so the query loop starts at the end of
# the window this chunk block belongs to and runs to seqlen_q.
# LSE and DO_T_O have the same meaning as in _bwd_eva_agg_kernel_dkdv
# (per-row log-sum-exp and the flash-attention-style per-row delta).
# MASK_TYPE == 1 supplies an explicit query-by-chunk boolean mask
# (nonzero entries are masked OUT).
@triton.heuristics(
    {
        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
        "EVEN_C": lambda args: args["nchunks"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def _bwd_eva_agg_kernel_drfa_kv(
    Q,
    RFA_K,
    RFA_V,
    ChunkMask,
    DO,
    LSE,
    DO_T_O,
    D_RFA_K,
    D_RFA_V,
    softmax_scale,
    stride_qb, stride_qh, stride_qm,
    stride_rfa_kb, stride_rfa_kh, stride_rfa_kc,
    stride_rfa_vb, stride_rfa_vh, stride_rfa_vc,
    stride_chunk_mask_b, stride_chunk_mask_m,
    stride_do_b, stride_do_h, stride_do_m,
    stride_lse_b, stride_lse_h,
    stride_do_t_o_b, stride_do_t_o_h,
    stride_d_rfa_k_b, stride_d_rfa_k_h, stride_d_rfa_k_c,
    stride_d_rfa_v_b, stride_d_rfa_v_h, stride_d_rfa_v_c,
    nheads,
    seqlen_q,
    nchunks,
    headdim,
    CHUNKS_PER_WINDOW: tl.constexpr,
    WINDOW_SIZE: tl.constexpr,
    MASK_TYPE: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr,
    EVEN_C: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # decompose the flattened (batch * head) program id
    off_bh = tl.program_id(1)
    off_h = off_bh % nheads
    off_b = off_bh // nheads
    start_c = tl.program_id(0)
    # there are 128 chunks per window
    offs_c = start_c * BLOCK_N + tl.arange(0, BLOCK_N)
    # determine which window the current KV block belongs to
    offs_w = (start_c * BLOCK_N) // CHUNKS_PER_WINDOW
    offs_m = tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_HEADDIM)

    # initialize pointers (head-dim is contiguous: offs_d added without a stride)
    q_ptrs = (
        Q +
        off_b * stride_qb +
        off_h * stride_qh +
        (offs_m[:, None] * stride_qm + offs_d[None, :])
    )
    do_ptrs = (
        DO +
        off_b * stride_do_b +
        off_h * stride_do_h +
        (offs_m[:, None] * stride_do_m + offs_d[None, :])
    )
    # DO_T_O and LSE are one scalar per query row
    do_t_o_ptrs = (
        DO_T_O +
        off_b * stride_do_t_o_b +
        off_h * stride_do_t_o_h +
        (offs_m[:, None])
    )
    lse_ptrs = (
        LSE +
        off_b * stride_lse_b +
        off_h * stride_lse_h +
        (offs_m[:, None])
    )
    rfa_k_ptrs = (
        RFA_K +
        off_b * stride_rfa_kb +
        off_h * stride_rfa_kh +
        (offs_c[:, None] * stride_rfa_kc + offs_d[None, :])
    )
    rfa_v_ptrs = (
        RFA_V +
        off_b * stride_rfa_vb +
        off_h * stride_rfa_vh +
        (offs_c[:, None] * stride_rfa_vc + offs_d[None, :])
    )
    if MASK_TYPE == 1:
        # ChunkMask is shared across heads: no head stride
        rfa_m_ptrs = (
            ChunkMask +
            off_b * stride_chunk_mask_b +
            (offs_m[:, None] * stride_chunk_mask_m + offs_c[None, :])
        )
    d_rfa_k_ptrs = (
        D_RFA_K +
        off_b * stride_d_rfa_k_b +
        off_h * stride_d_rfa_k_h +
        (offs_c[:, None] * stride_d_rfa_k_c + offs_d[None, :])
    )
    d_rfa_v_ptrs = (
        D_RFA_V +
        off_b * stride_d_rfa_v_b +
        off_h * stride_d_rfa_v_h +
        (offs_c[:, None] * stride_d_rfa_v_c + offs_d[None, :])
    )

    # fp32 accumulators for the gradients of this RFA chunk tile
    d_rfa_k = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    d_rfa_v = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    # load the RFA keys/values once; re-used across all query blocks
    if EVEN_C & EVEN_M:
        if EVEN_HEADDIM:
            rfa_k = tl.load(rfa_k_ptrs)
            rfa_v = tl.load(rfa_v_ptrs)
        else:
            rfa_k = tl.load(rfa_k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
            rfa_v = tl.load(rfa_v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
    else:
        if EVEN_HEADDIM:
            rfa_k = tl.load(rfa_k_ptrs, mask=offs_c[:, None] < nchunks, other=0.0)
            rfa_v = tl.load(rfa_v_ptrs, mask=offs_c[:, None] < nchunks, other=0.0)
        else:
            rfa_k = tl.load(
                rfa_k_ptrs, mask=(offs_c[:, None] < nchunks) & (offs_d[None, :] < headdim), other=0.0
            )
            rfa_v = tl.load(
                rfa_v_ptrs, mask=(offs_c[:, None] < nchunks) & (offs_d[None, :] < headdim), other=0.0
            )
    # chunk summaries are attended to only by queries in strictly later windows
    begin_m = tl.minimum((offs_w + 1) * WINDOW_SIZE, seqlen_q)
    end_m = seqlen_q
    for start_m in range(begin_m, end_m, BLOCK_M):
        start_m = tl.multiple_of(start_m, BLOCK_M)
        # load q, do, and lse
        if EVEN_M:
            if EVEN_HEADDIM:
                q = tl.load(
                    q_ptrs + start_m * stride_qm
                )
                do = tl.load(
                    do_ptrs + start_m * stride_do_m
                )
            else:
                q = tl.load(
                    q_ptrs + start_m * stride_qm,
                    mask=offs_d[None, :] < headdim,
                    other=0.0
                )
                do = tl.load(
                    do_ptrs + start_m * stride_do_m,
                    mask=offs_d[None, :] < headdim,
                    other=0.0
                )
            do_t_o = tl.load(
                do_t_o_ptrs + start_m
            )
            lse = tl.load(
                lse_ptrs + start_m
            )
        else:
            if EVEN_HEADDIM:
                q = tl.load(
                    q_ptrs + start_m * stride_qm,
                    mask=(start_m + offs_m)[:, None] < seqlen_q,
                    other=0.0
                )
                do = tl.load(
                    do_ptrs + start_m * stride_do_m,
                    mask=(start_m + offs_m)[:, None] < seqlen_q,
                    other=0.0
                )
            else:
                q = tl.load(
                    q_ptrs + start_m * stride_qm,
                    mask=((start_m + offs_m)[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                    other=0.0
                )
                do = tl.load(
                    do_ptrs + start_m * stride_do_m,
                    mask=((start_m + offs_m)[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                    other=0.0
                )
            do_t_o = tl.load(
                do_t_o_ptrs + start_m,
                mask=(start_m + offs_m)[:, None] < seqlen_q,
                other=0.0
            )
            lse = tl.load(
                lse_ptrs + start_m,
                mask=(start_m + offs_m)[:, None] < seqlen_q,
                other=0.0
            )
        # rows fully masked in the forward pass have lse == -inf; zero them
        # so exp(qk - lse) stays finite (p comes out 0 there)
        lse = tl.where(lse == float("-inf"), 0.0, lse)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, tl.trans(rfa_k))
        if not EVEN_M:
            # out-of-range query rows contribute nothing
            qk += tl.where((start_m + offs_m)[:, None] < seqlen_q, 0, float("-inf"))

        if MASK_TYPE == 1:
            if EVEN_M & EVEN_C:
                mask = tl.load(
                    rfa_m_ptrs + (start_m * stride_chunk_mask_m)
                )
            else:
                mask = tl.load(
                    rfa_m_ptrs + (start_m * stride_chunk_mask_m),
                    mask=((start_m + offs_m)[:, None] < seqlen_q)
                    & (offs_c[None, :] < nchunks),
                    other=1,
                )
            # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
            # can then fuse the mult and add into an fma instruction. But if we have bias we need to
            # to multiply with softmax_scale here.
            # we assume mask already implies the causal masking
            # (nonzero mask entries are masked OUT)
            qk = qk * softmax_scale
            qk = tl.where(mask, float("-inf"), qk)
            p = tl.exp(qk - lse)
        else:
            # no mask needed: the begin_m loop bound already enforces that only
            # later-window queries see this chunk block
            p = tl.exp(qk * softmax_scale - lse)

        dp = tl.dot(do, tl.trans(rfa_v))
        # softmax backward: dS = P * (dP - delta), folded with the scale
        ds = (p * (dp - do_t_o) * softmax_scale).to(q.dtype)
        # p is fp32, convert to q.dtype
        d_rfa_v += tl.dot(tl.trans(p.to(do.dtype)), do)
        # move softmax_scale to ds to save computation
        d_rfa_k += tl.dot(tl.trans(ds), q)
    # write back the accumulated gradients for this chunk tile
    if EVEN_C & EVEN_M:
        if EVEN_HEADDIM:
            tl.store(d_rfa_v_ptrs, d_rfa_v)
            tl.store(d_rfa_k_ptrs, d_rfa_k)
        else:
            tl.store(d_rfa_v_ptrs, d_rfa_v, mask=offs_d[None, :] < headdim)
            tl.store(d_rfa_k_ptrs, d_rfa_k, mask=offs_d[None, :] < headdim)
    else:
        if EVEN_HEADDIM:
            tl.store(d_rfa_v_ptrs, d_rfa_v, mask=offs_c[:, None] < nchunks)
            tl.store(d_rfa_k_ptrs, d_rfa_k, mask=offs_c[:, None] < nchunks)
        else:
            tl.store(d_rfa_v_ptrs, d_rfa_v, mask=(offs_c[:, None] < nchunks) & (offs_d[None, :] < headdim))
            tl.store(d_rfa_k_ptrs, d_rfa_k, mask=(offs_c[:, None] < nchunks) & (offs_d[None, :] < headdim))
496
+
497
# Backward kernel computing dQ for EVA aggregated attention.
# Launch grid: (cdiv(seqlen_q, BLOCK_M), batch * nheads) — one program per
# BLOCK_M rows of queries for one (batch, head) pair.
@triton.heuristics(
    {
        # EVEN_* flags let the kernel skip bounds masking when a dimension
        # divides evenly into its tile size.
        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
        "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
        "EVEN_C": lambda args: args["nchunks"] % args["BLOCK_N"] == 0,
        "EVEN_W": lambda args: args["WINDOW_SIZE"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def _bwd_eva_agg_kernel_dq(
    Q,
    K,
    V,
    RFA_K,
    RFA_V,
    WindowMask,
    ChunkMask,
    DO,
    LSE,
    DO_T_O,
    DQ,
    softmax_scale,
    stride_qb, stride_qh, stride_qm,
    stride_kb, stride_kh, stride_kn,
    stride_vb, stride_vh, stride_vn,
    stride_rfa_kb, stride_rfa_kh, stride_rfa_kc,
    stride_rfa_vb, stride_rfa_vh, stride_rfa_vc,
    stride_window_mask_b, stride_window_mask_m,
    stride_chunk_mask_b, stride_chunk_mask_m,
    stride_do_b, stride_do_h, stride_do_m,
    stride_lse_b, stride_lse_h,
    stride_do_t_o_b, stride_do_t_o_h,
    stride_dq_b, stride_dq_h, stride_dq_m,
    nheads,
    seqlen_q,
    seqlen_k,
    nchunks,
    headdim,
    CHUNKS_PER_WINDOW: tl.constexpr,
    WINDOW_SIZE: tl.constexpr,
    MASK_TYPE: tl.constexpr,
    EMPTY_RFA_KV: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_W: tl.constexpr,
    EVEN_C: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Program indices: axis 0 tiles the query rows, axis 1 enumerates
    # flattened (batch, head) pairs.
    start_m = tl.program_id(0)
    off_bh = tl.program_id(1)
    off_h = off_bh % nheads
    off_b = off_bh // nheads
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # offs_w: index of the local-attention window this query tile lives in.
    offs_w = (start_m * BLOCK_M) // WINDOW_SIZE
    offs_n = tl.arange(0, BLOCK_N)
    # offs_c indexes RFA chunks; it equals offs_n element-wise (same arange).
    offs_c = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # TODO: add paratheses or not
    q_ptrs = (
        Q +
        off_b * stride_qb +
        off_h * stride_qh +
        (offs_m[:, None] * stride_qm + offs_d[None, :])
    )
    k_ptrs = (
        K +
        off_b * stride_kb +
        off_h * stride_kh +
        (offs_n[:, None] * stride_kn + offs_d[None, :])
    )
    v_ptrs = (
        V +
        off_b * stride_vb +
        off_h * stride_vh +
        (offs_n[:, None] * stride_vn + offs_d[None, :])
    )
    if EMPTY_RFA_KV == 0:
        rfa_k_ptrs = (
            RFA_K +
            off_b * stride_rfa_kb +
            off_h * stride_rfa_kh +
            (offs_c[:, None] * stride_rfa_kc + offs_d[None, :])
        )
        rfa_v_ptrs = (
            RFA_V +
            off_b * stride_rfa_vb +
            off_h * stride_rfa_vh +
            (offs_c[:, None] * stride_rfa_vc + offs_d[None, :])
        )
    # NOTE(review): this dq_ptrs is never read before it is recomputed from
    # scratch just before the final store below; the early computation is dead.
    dq_ptrs = (
        DQ +
        off_b * stride_dq_b +
        off_h * stride_dq_h +
        (offs_m[:, None] * stride_dq_m + offs_d[None, :])
    )
    do_ptrs = (
        DO +
        off_b * stride_do_b +
        off_h * stride_do_h +
        (offs_m[:, None] * stride_do_m + offs_d[None, :])
    )
    # DO_T_O holds rowwise sum(dO * O); loaded as a column so it broadcasts
    # against [BLOCK_M, BLOCK_N] tiles.
    do_t_o_ptrs = (
        DO_T_O +
        off_b * stride_do_t_o_b +
        off_h * stride_do_t_o_h +
        offs_m[:, None]
    )
    lse_ptrs = (
        LSE +
        off_b * stride_lse_b +
        off_h * stride_lse_h +
        offs_m[:, None]
    )
    ### load q, do, do_t_o, lse ####
    if EVEN_M:
        if EVEN_HEADDIM:
            q = tl.load(
                q_ptrs
            )
            do = tl.load(
                do_ptrs
            )
        else:
            q = tl.load(
                q_ptrs,
                mask=offs_d[None, :] < headdim,
                other=0.0
            )
            do = tl.load(
                do_ptrs,
                mask=offs_d[None, :] < headdim,
                other=0.0
            )
        do_t_o = tl.load(
            do_t_o_ptrs
        )
        lse = tl.load(
            lse_ptrs
        )
    else:
        if EVEN_HEADDIM:
            q = tl.load(
                q_ptrs,
                mask=offs_m[:, None] < seqlen_q,
                other=0.0
            )
            do = tl.load(
                do_ptrs,
                mask=offs_m[:, None] < seqlen_q,
                other=0.0
            )
        else:
            q = tl.load(
                q_ptrs,
                mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                other=0.0
            )
            do = tl.load(
                do_ptrs,
                mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                other=0.0
            )
        do_t_o = tl.load(
            do_t_o_ptrs,
            mask=offs_m[:, None] < seqlen_q,
            other=0.0
        )
        lse = tl.load(
            lse_ptrs,
            mask=offs_m[:, None] < seqlen_q,
            other=0.0
        )
    # Fully-masked rows store lse == -inf in the forward pass; neutralize
    # them so exp2 below yields finite values.
    lse = tl.where(lse == float("-inf"), 0.0, lse)
    # LSE was stored in natural-log units by the forward kernel; convert it
    # (and the scale) to base-2 so we can use the faster exp2.
    lse *= 1.4426950408889634  # log2(e)
    qk_scale = softmax_scale
    qk_scale *= 1.4426950408889634  # log2(e)
    if MASK_TYPE == 1:
        window_mask_ptrs = (
            WindowMask +
            off_b * stride_window_mask_b +
            (offs_m[:, None] * stride_window_mask_m + offs_n[None, :])
        )
        if EMPTY_RFA_KV == 0:
            chunk_mask_ptrs = (
                ChunkMask +
                off_b * stride_chunk_mask_b +
                (offs_m[:, None] * stride_chunk_mask_m + offs_c[None, :])
            )

    dq = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
    # loop over k, v and update accumulator
    # Iterate over local singletons;
    # so we only iterate over blocks within the current window
    start_idx_n = offs_w * WINDOW_SIZE
    end_idx_n = tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
    for start_n in range(start_idx_n, end_idx_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        if EVEN_N & EVEN_M:
            if EVEN_HEADDIM:
                k = tl.load(
                    k_ptrs + start_n * stride_kn
                )
            else:
                k = tl.load(
                    k_ptrs + start_n * stride_kn,
                    mask=offs_d[None, :] < headdim,
                    other=0.0
                )
        else:
            if EVEN_HEADDIM:
                k = tl.load(
                    k_ptrs + start_n * stride_kn,
                    mask=(start_n + offs_n)[:, None] < seqlen_k,
                    other=0.0,
                )
            else:
                k = tl.load(
                    k_ptrs + start_n * stride_kn,
                    mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                    other=0.0,
                )
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, tl.trans(k))
        # Trying to combine the two masks seem to make the result wrong
        if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
            qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))

        if MASK_TYPE == 1:
            # WindowMask is boolean with True == "masked out" (see tl.where
            # below); out-of-bounds positions load as 1, i.e. masked.
            if EVEN_M & EVEN_W:
                window_mask = tl.load(
                    window_mask_ptrs + start_n - start_idx_n
                )
            else:
                window_mask = tl.load(
                    window_mask_ptrs + start_n - start_idx_n,
                    mask=(offs_m[:, None] < seqlen_q)
                    & ((start_n - start_idx_n + offs_n)[None, :] < WINDOW_SIZE),
                    other=1,
                )
            # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
            # can then fuse the mult and add into an fma instruction. But if we have bias we need to
            # to multiply with softmax_scale here.
            # we assume mask already implies the causal masking
            qk = qk * qk_scale
            qk = tl.where(window_mask, float("-inf"), qk)
            # Recompute softmax probabilities from the saved LSE (base-2 domain).
            p = tl.exp2(qk - lse)
        else:
            # No explicit mask: apply causal masking directly.
            qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
            p = tl.exp2(qk * qk_scale - lse)

        if EVEN_N & EVEN_M:
            if EVEN_HEADDIM:
                v = tl.load(
                    v_ptrs + start_n * stride_vn
                )
            else:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=offs_d[None, :] < headdim,
                    other=0.0
                )
        else:
            if EVEN_HEADDIM:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=(start_n + offs_n)[:, None] < seqlen_k,
                    other=0.0,
                )
            else:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                    other=0.0,
                )
        # dS = P * (dP - rowsum(dO * O)); accumulate dQ += dS @ K.
        dp = tl.dot(do, tl.trans(v))
        ds = (p * (dp - do_t_o) * softmax_scale).to(q.dtype)
        dq += tl.dot(ds, k)

    if EMPTY_RFA_KV == 0:
        # Iterate over RFA chunks
        # we only iterate over chunks before the current local singleton window
        end_idx_c = tl.minimum(offs_w * CHUNKS_PER_WINDOW, nchunks)
        for start_c in range(0, end_idx_c, BLOCK_N):
            start_c = tl.multiple_of(start_c, BLOCK_N)
            # -- compute qk ----
            if EVEN_C & EVEN_M:
                if EVEN_HEADDIM:
                    rfa_k = tl.load(
                        rfa_k_ptrs + start_c * stride_rfa_kc
                    )
                else:
                    rfa_k = tl.load(
                        rfa_k_ptrs + start_c * stride_rfa_kc,
                        mask=offs_d[None, :] < headdim,
                        other=0.0
                    )
            else:
                if EVEN_HEADDIM:
                    rfa_k = tl.load(
                        rfa_k_ptrs + start_c * stride_rfa_kc,
                        mask=(start_c + offs_c)[:, None] < nchunks,
                        other=0.0,
                    )
                else:
                    rfa_k = tl.load(
                        rfa_k_ptrs + start_c * stride_rfa_kc,
                        mask=((start_c + offs_c)[:, None] < nchunks) & (offs_d[None, :] < headdim),
                        other=0.0,
                    )
            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, tl.trans(rfa_k))
            # Trying to combine the two masks seem to make the result wrong
            if not EVEN_C:  # Need to mask out otherwise the softmax is wrong
                qk += tl.where((start_c + offs_c)[None, :] < nchunks, 0, float("-inf"))

            if MASK_TYPE == 1:
                if EVEN_C & EVEN_M:
                    chunk_mask = tl.load(
                        chunk_mask_ptrs + start_c
                    )
                else:
                    chunk_mask = tl.load(
                        chunk_mask_ptrs + start_c,
                        mask=(offs_m[:, None] < seqlen_q) & ((start_c + offs_c)[None, :] < nchunks),
                        other=1,
                    )
                # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
                # can then fuse the mult and add into an fma instruction. But if we have bias we need to
                # to multiply with softmax_scale here.
                # we assume mask already implies the causal masking
                qk = qk * qk_scale
                qk = tl.where(chunk_mask, float("-inf"), qk)
                p = tl.exp2(qk - lse)
            else:
                # No causal term here: the loop bound already restricts us to
                # chunks strictly before the current window.
                p = tl.exp2(qk * qk_scale - lse)

            if EVEN_C & EVEN_M:
                if EVEN_HEADDIM:
                    rfa_v = tl.load(
                        rfa_v_ptrs + start_c * stride_rfa_vc
                    )
                else:
                    rfa_v = tl.load(
                        rfa_v_ptrs + start_c * stride_rfa_vc,
                        mask=offs_d[None, :] < headdim,
                        other=0.0
                    )
            else:
                if EVEN_HEADDIM:
                    # NOTE(review): offs_n is used here instead of offs_c; the
                    # two are the same arange, so this is equivalent.
                    rfa_v = tl.load(
                        rfa_v_ptrs + start_c * stride_rfa_vc,
                        mask=(start_c + offs_n)[:, None] < nchunks,
                        other=0.0,
                    )
                else:
                    rfa_v = tl.load(
                        rfa_v_ptrs + start_c * stride_rfa_vc,
                        mask=((start_c + offs_n)[:, None] < nchunks) & (offs_d[None, :] < headdim),
                        other=0.0,
                    )
            dp = tl.dot(do, tl.trans(rfa_v))
            ds = (p * (dp - do_t_o) * softmax_scale).to(q.dtype)
            dq += tl.dot(ds, rfa_k)

    # Rematerialize offsets/pointers before the store (register-pressure trick).
    start_m = tl.program_id(0)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    dq_ptrs = (
        DQ +
        off_b * stride_dq_b +
        off_h * stride_dq_h +
        (offs_m[:, None] * stride_dq_m + offs_d[None, :])
    )
    if EVEN_M:
        if EVEN_HEADDIM:
            tl.store(
                dq_ptrs, dq
            )
        else:
            tl.store(
                dq_ptrs, dq,
                mask=offs_d[None, :] < headdim
            )
    else:
        if EVEN_HEADDIM:
            tl.store(
                dq_ptrs, dq,
                mask=offs_m[:, None] < seqlen_q
            )
        else:
            tl.store(
                dq_ptrs, dq,
                mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
            )
896
+
897
# Kernel launch configurations per GPU compute capability.
# Keyed first by kernel mode ("fwd", "bwd_dq", "bwd_dkdv", "bwd_drfa_kv"),
# then by (dtype, head_dim). Each value is a tuple
#   (BLOCK_M, BLOCK_N, num_warps, num_stages)
# as unpacked by _get_config() callers.

# Hopper (SM 9.0+) tuning table.
_capability_90_config = {
    "fwd": {
        (torch.bfloat16, 64): (128, 128, 4, 3),
        (torch.bfloat16, 128): (128, 128, 8, 3),
        (torch.float32, 64): (128, 64, 8, 3),
        (torch.float32, 128): (64, 32, 4, 3),
    },
    "bwd_dq": {
        (torch.bfloat16, 64): (128, 64, 4, 3),
        (torch.bfloat16, 128): (128, 64, 8, 3),
        (torch.float32, 64): (128, 64, 8, 2),
        (torch.float32, 128): (32, 32, 4, 2),
    },
    "bwd_dkdv": {
        (torch.bfloat16, 64): (128, 64, 4, 2),
        (torch.bfloat16, 128): (128, 64, 8, 2),
        (torch.float32, 64): (128, 64, 8, 2),
        (torch.float32, 128): (32, 32, 4, 1),
    },
    "bwd_drfa_kv": {
        (torch.bfloat16, 64): (128, 64, 4, 2),
        (torch.bfloat16, 128): (128, 64, 8, 2),
        (torch.float32, 64): (128, 64, 8, 2),
        (torch.float32, 128): (32, 32, 4, 1),
    }
}

# Ampere (SM 8.0+) tuning table; smaller tiles than Hopper.
_capability_80_config = {
    "fwd": {
        (torch.bfloat16, 64): (64, 64, 4, 3),
        (torch.bfloat16, 128): (64, 64, 8, 3),
        (torch.float32, 64): (64, 32, 4, 2),
        (torch.float32, 128): (64, 32, 8, 1),
    },
    "bwd_dq": {
        (torch.bfloat16, 64): (64, 64, 4, 3),
        (torch.bfloat16, 128): (64, 32, 4, 2),
        (torch.float32, 64): (32, 32, 4, 2),
        (torch.float32, 128): (32, 32, 4, 2),
    },
    "bwd_dkdv": {
        (torch.bfloat16, 64): (64, 64, 4, 3),
        (torch.bfloat16, 128): (32, 32, 4, 2),
        (torch.float32, 64): (32, 32, 4, 1),
        (torch.float32, 128): (16, 64, 8, 1),
    },
    "bwd_drfa_kv": {
        (torch.bfloat16, 64): (64, 64, 4, 3),
        (torch.bfloat16, 128): (64, 32, 4, 3),
        (torch.float32, 64): (32, 32, 4, 1),
        (torch.float32, 128): (32, 32, 4, 1),
    }
}
950
+
951
def _get_config(dtype, head_dim, mode) -> tuple[int, int, int, int]:
    """Select a kernel launch configuration for the current GPU.

    Returns ``(BLOCK_M, BLOCK_N, num_warps, num_stages)`` for the given
    tensor ``dtype``, attention ``head_dim`` and kernel ``mode`` (one of
    ``"fwd"``, ``"bwd_dq"``, ``"bwd_dkdv"``, ``"bwd_drfa_kv"``), looked up
    from the capability-specific tuning tables with conservative fallbacks.
    """
    cap = torch.cuda.get_device_capability()
    if cap >= (9, 0):
        # Hopper and newer.
        return _capability_90_config[mode].get((dtype, head_dim), (32, 32, 4, 1))
    if cap >= (8, 0):
        # Ampere.
        return _capability_80_config[mode].get((dtype, head_dim), (16, 16, 4, 1))
    # Pre-Ampere: no tuning tables, use fixed conservative configs.
    if mode == "fwd":
        return (32, 16, 4, 2) if dtype == torch.float32 else (64, 32, 4, 2)
    return (16, 16, 4, 1) if dtype == torch.float32 else (32, 32, 4, 1)
969
+
970
# Forward kernel for EVA aggregated attention: each program computes one
# BLOCK_M tile of output rows for one (batch, head) pair, attending over
# (a) the local singleton keys/values inside the query tile's window and
# (b) the RFA chunk summaries preceding that window, using an online
# (streaming) softmax. Launch grid: (cdiv(seqlen_q, BLOCK_M), batch * nheads).
@triton.heuristics(
    {
        # EVEN_* flags let the kernel skip bounds masking when a dimension
        # divides evenly into its tile size.
        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
        "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
        "EVEN_C": lambda args: args["nchunks"] % args["BLOCK_N"] == 0,
        "EVEN_W": lambda args: args["WINDOW_SIZE"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def _fwd_eva_agg_kernel(
    Q,
    K,
    V,
    RFA_K,
    RFA_V,
    WindowMask,
    ChunkMask,
    Out,
    LSE,
    softmax_scale,
    stride_qb, stride_qh, stride_qm,
    stride_kb, stride_kh, stride_kn,
    stride_vb, stride_vh, stride_vn,
    stride_rfa_kb, stride_rfa_kh, stride_rfa_kc,
    stride_rfa_vb, stride_rfa_vh, stride_rfa_vc,
    stride_window_mask_b, stride_window_mask_m,
    stride_chunk_mask_b, stride_chunk_mask_m,
    stride_ob, stride_oh, stride_om,
    stride_lse_b, stride_lse_h,
    nheads,
    seqlen_q,
    seqlen_k,
    nchunks,
    headdim,
    CHUNKS_PER_WINDOW: tl.constexpr,
    WINDOW_SIZE: tl.constexpr,
    MASK_TYPE: tl.constexpr,
    EMPTY_RFA_KV: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_W: tl.constexpr,
    EVEN_C: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Program indices: axis 0 tiles the query rows, axis 1 enumerates
    # flattened (batch, head) pairs.
    start_m = tl.program_id(0)
    off_bh = tl.program_id(1)
    off_h = off_bh % nheads
    off_b = off_bh // nheads
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # offs_w: index of the local-attention window this query tile lives in.
    offs_w = (start_m * BLOCK_M) // WINDOW_SIZE
    offs_n = tl.arange(0, BLOCK_N)
    # offs_c indexes RFA chunks; it equals offs_n element-wise (same arange).
    offs_c = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # TODO: add paratheses or not
    q_ptrs = (
        Q +
        off_b * stride_qb +
        off_h * stride_qh +
        (offs_m[:, None] * stride_qm + offs_d[None, :])
    )
    k_ptrs = (
        K +
        off_b * stride_kb +
        off_h * stride_kh +
        (offs_n[:, None] * stride_kn + offs_d[None, :])
    )
    v_ptrs = (
        V +
        off_b * stride_vb +
        off_h * stride_vh +
        (offs_n[:, None] * stride_vn + offs_d[None, :])
    )
    if EMPTY_RFA_KV == 0:
        rfa_k_ptrs = (
            RFA_K +
            off_b * stride_rfa_kb +
            off_h * stride_rfa_kh +
            (offs_c[:, None] * stride_rfa_kc + offs_d[None, :])
        )
        rfa_v_ptrs = (
            RFA_V +
            off_b * stride_rfa_vb +
            off_h * stride_rfa_vh +
            (offs_c[:, None] * stride_rfa_vc + offs_d[None, :])
        )

    # Work in base-2 exponent domain (exp2 is faster than exp on GPU).
    qk_scale = softmax_scale
    qk_scale *= 1.4426950408889634  # log2(e)
    if MASK_TYPE == 1:
        window_mask_ptrs = (
            WindowMask +
            off_b * stride_window_mask_b +
            (offs_m[:, None] * stride_window_mask_m + offs_n[None, :])
        )
        if EMPTY_RFA_KV == 0:
            chunk_mask_ptrs = (
                ChunkMask +
                off_b * stride_chunk_mask_b +
                (offs_m[:, None] * stride_chunk_mask_m + offs_c[None, :])
            )

    # Online-softmax state: m_i = running row max (base-2 domain),
    # d_i = running denominator, acc_o = unnormalized output accumulator.
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    d_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
    # load q: it will stay in SRAM throughout
    # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
    # tl.load(q_ptrs), we get the wrong output!
    if EVEN_M & EVEN_N:
        if EVEN_HEADDIM:
            q = tl.load(
                q_ptrs
            )
        else:
            q = tl.load(
                q_ptrs,
                mask=offs_d[None, :] < headdim,
                other=0.0
            )
    else:
        if EVEN_HEADDIM:
            q = tl.load(
                q_ptrs,
                mask=offs_m[:, None] < seqlen_q,
                other=0.0
            )
        else:
            q = tl.load(
                q_ptrs,
                mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                other=0.0
            )
    # loop over k, v and update accumulator
    # Iterate over local singletons;
    # so we only iterate over blocks within the current window
    start_idx_n = offs_w * WINDOW_SIZE
    end_idx_n = tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
    for start_n in range(start_idx_n, end_idx_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        if EVEN_N & EVEN_M:
            if EVEN_HEADDIM:
                k = tl.load(
                    k_ptrs + start_n * stride_kn
                )
            else:
                k = tl.load(
                    k_ptrs + start_n * stride_kn,
                    mask=offs_d[None, :] < headdim,
                    other=0.0
                )
        else:
            if EVEN_HEADDIM:
                k = tl.load(
                    k_ptrs + start_n * stride_kn,
                    mask=(start_n + offs_n)[:, None] < seqlen_k,
                    other=0.0,
                )
            else:
                k = tl.load(
                    k_ptrs + start_n * stride_kn,
                    mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                    other=0.0,
                )
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, tl.trans(k))
        # Trying to combine the two masks seem to make the result wrong
        if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
            qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))

        if MASK_TYPE == 1:
            # WindowMask is boolean with True == "masked out" (see tl.where
            # below); out-of-bounds positions load as 1, i.e. masked.
            if EVEN_M & EVEN_W:
                window_mask = tl.load(
                    window_mask_ptrs + start_n - start_idx_n
                )
            else:
                window_mask = tl.load(
                    window_mask_ptrs + start_n - start_idx_n,
                    mask=(offs_m[:, None] < seqlen_q)
                    & ((start_n - start_idx_n + offs_n)[None, :] < WINDOW_SIZE),
                    other=1,
                )
            # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
            # can then fuse the mult and add into an fma instruction. But if we have bias we need to
            # to multiply with softmax_scale here.
            # we assume mask already implies the causal masking
            qk = qk * qk_scale
            qk = tl.where(window_mask, float("-inf"), qk)
            m_ij = tl.maximum(tl.max(qk, 1), m_i)
            # Rows that are entirely masked would give exp2(-inf - (-inf));
            # clamp their max to 0 so p stays finite (and sums to 0).
            masked_out_rows = (m_ij == float("-inf"))
            m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
            p = tl.exp2(qk - m_ij_masked[:, None])
        else:
            # No explicit mask: apply causal masking directly.
            qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
            m_ij = tl.maximum(tl.max(qk, 1) * qk_scale, m_i)
            masked_out_rows = (m_ij == float("-inf"))
            m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
            p = tl.exp2(qk * qk_scale - m_ij_masked[:, None])

        d_ij = tl.sum(p, 1)

        # scale acc_o
        prev_scale = tl.exp2(m_i - m_ij_masked)
        # # -- update output accumulator --
        acc_o = acc_o * prev_scale[:, None]
        # update acc_o
        if EVEN_N & EVEN_M:  # If we just do "if EVEN_N", there seems to be some race condition
            if EVEN_HEADDIM:
                v = tl.load(
                    v_ptrs + start_n * stride_vn
                )
            else:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=offs_d[None, :] < headdim,
                    other=0.0
                )
        else:
            if EVEN_HEADDIM:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=(start_n + offs_n)[:, None] < seqlen_k,
                    other=0.0,
                )
            else:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                    other=0.0,
                )
        p = p.to(v.dtype)
        acc_o = tl.dot(p, v, acc_o)

        # -- update statistics
        d_i = d_i * prev_scale + d_ij
        m_i = m_ij

    if EMPTY_RFA_KV == 0:
        # Iterate over RFA chunks
        # we only iterate over chunks before the current local singleton window
        end_idx_c = tl.minimum(offs_w * CHUNKS_PER_WINDOW, nchunks)
        for start_c in range(0, end_idx_c, BLOCK_N):
            start_c = tl.multiple_of(start_c, BLOCK_N)
            # -- compute qk ----
            if EVEN_C & EVEN_M:
                if EVEN_HEADDIM:
                    rfa_k = tl.load(
                        rfa_k_ptrs + start_c * stride_rfa_kc
                    )
                else:
                    rfa_k = tl.load(
                        rfa_k_ptrs + start_c * stride_rfa_kc,
                        mask=offs_d[None, :] < headdim,
                        other=0.0
                    )
            else:
                if EVEN_HEADDIM:
                    rfa_k = tl.load(
                        rfa_k_ptrs + start_c * stride_rfa_kc,
                        mask=(start_c + offs_c)[:, None] < nchunks,
                        other=0.0,
                    )
                else:
                    rfa_k = tl.load(
                        rfa_k_ptrs + start_c * stride_rfa_kc,
                        mask=((start_c + offs_c)[:, None] < nchunks) & (offs_d[None, :] < headdim),
                        other=0.0,
                    )
            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, tl.trans(rfa_k))
            # Trying to combine the two masks seem to make the result wrong
            if not EVEN_C:  # Need to mask out otherwise the softmax is wrong
                qk += tl.where((start_c + offs_c)[None, :] < nchunks, 0, float("-inf"))

            if MASK_TYPE == 1:
                if EVEN_C & EVEN_M:
                    chunk_mask = tl.load(
                        chunk_mask_ptrs + start_c
                    )
                else:
                    chunk_mask = tl.load(
                        chunk_mask_ptrs + start_c,
                        mask=(offs_m[:, None] < seqlen_q) & ((start_c + offs_c)[None, :] < nchunks),
                        other=1,
                    )
                # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
                # can then fuse the mult and add into an fma instruction. But if we have bias we need to
                # to multiply with softmax_scale here.
                # we assume mask already implies the causal masking
                qk = qk * qk_scale
                qk = tl.where(chunk_mask, float("-inf"), qk)
                m_ij = tl.maximum(tl.max(qk, 1), m_i)
                masked_out_rows = (m_ij == float("-inf"))
                m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
                p = tl.exp2(qk - m_ij_masked[:, None])
            else:
                m_ij = tl.maximum(tl.max(qk, 1) * qk_scale, m_i)
                masked_out_rows = (m_ij == float("-inf"))
                m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
                p = tl.exp2(qk * qk_scale - m_ij_masked[:, None])

            d_ij = tl.sum(p, 1)

            # scale acc_o
            prev_scale = tl.exp2(m_i - m_ij_masked)
            # # -- update output accumulator --
            acc_o = acc_o * prev_scale[:, None]
            # update acc_o
            # TODO: If we just do "if EVEN_N", there seems to be some race condition ?
            if EVEN_C & EVEN_M:
                if EVEN_HEADDIM:
                    rfa_v = tl.load(
                        rfa_v_ptrs + start_c * stride_rfa_vc
                    )
                else:
                    rfa_v = tl.load(
                        rfa_v_ptrs + start_c * stride_rfa_vc,
                        mask=offs_d[None, :] < headdim,
                        other=0.0
                    )
            else:
                if EVEN_HEADDIM:
                    # NOTE(review): offs_n is used here instead of offs_c; the
                    # two are the same arange, so this is equivalent.
                    rfa_v = tl.load(
                        rfa_v_ptrs + start_c * stride_rfa_vc,
                        mask=(start_c + offs_n)[:, None] < nchunks,
                        other=0.0,
                    )
                else:
                    rfa_v = tl.load(
                        rfa_v_ptrs + start_c * stride_rfa_vc,
                        mask=((start_c + offs_n)[:, None] < nchunks) & (offs_d[None, :] < headdim),
                        other=0.0,
                    )
            p = p.to(rfa_v.dtype)
            acc_o = tl.dot(p, rfa_v, acc_o)

            # -- update statistics
            d_i = d_i * prev_scale + d_ij
            m_i = m_ij

    # for rows that are all -inf, set d_i to 1.0
    d_i = tl.where(d_i == 0.0, 1.0, d_i)
    # multiply by log(2): converts the base-2 logsumexp back to natural-log
    # units before storing (the backward kernels convert it back with log2(e)).
    lse_m = (m_i + tl.math.log2(d_i)) * 0.6931471805599453
    acc_o = acc_o / d_i[:, None]
    # TODO: understand why rematerialize offsets to save registers?
    start_m = tl.program_id(0)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    out_ptrs = (
        Out +
        off_b * stride_ob +
        off_h * stride_oh +
        (offs_m[:, None] * stride_om + offs_d[None, :])
    )
    if EVEN_M:
        if EVEN_HEADDIM:
            tl.store(
                out_ptrs, acc_o
            )
        else:
            tl.store(
                out_ptrs, acc_o,
                mask=offs_d[None, :] < headdim
            )
    else:
        if EVEN_HEADDIM:
            tl.store(
                out_ptrs, acc_o,
                mask=offs_m[:, None] < seqlen_q
            )
        else:
            tl.store(
                out_ptrs, acc_o,
                mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
            )
    lse_ptrs = (
        LSE +
        off_b * stride_lse_b +
        off_h * stride_lse_h +
        offs_m
    )
    if EVEN_M:
        tl.store(
            lse_ptrs, lse_m,
        )
    else:
        tl.store(
            lse_ptrs, lse_m,
            mask=offs_m < seqlen_q
        )
1365
+
1366
def triton_eva_agg_fwd(
    q, k, v, rfa_k, rfa_v,
    window_mask,
    chunk_mask,
    softmax_scale,
    window_size,
    chunks_per_window
):
    """Launch the EVA aggregated-attention forward Triton kernel.

    Args:
        q, k, v: attention tensors of shape (batch, nheads, seqlen, head_dim).
        rfa_k, rfa_v: per-chunk RFA summaries of shape
            (batch, nheads, nchunks, head_dim), or both None when there are
            no chunk-level keys/values.
        window_mask: optional bool mask of shape (batch, 1, seqlen_q,
            window_size); True means "masked out". When given, chunk_mask is
            required as well.
        chunk_mask: bool mask of shape (batch, 1, seqlen_q, nchunks),
            True means "masked out".
        softmax_scale: attention logit scale; defaults to
            1/sqrt(head_dim) when None.
        window_size: length of the local attention window.
        chunks_per_window: number of RFA chunks per window; must be a
            positive multiple of the kernel's BLOCK_N.

    Returns:
        (o, lse): the attention output (same shape as q) and the per-row
        logsumexp of shape (batch, nheads, seqlen_q) in float32.
    """
    if rfa_k is None and rfa_v is None:
        empty_rfa_kv = 1

        # Kernels require unit stride in the head dimension.
        q, k, v = [
            x if x.stride(-1) == 1 else x.contiguous()
            for x in [q, k, v]
        ]
    else:
        assert rfa_k is not None and rfa_v is not None, "Both rfa_k and rfa_v must either be None or have values at the same time."
        empty_rfa_kv = 0

        q, k, v, rfa_k, rfa_v = [
            x if x.stride(-1) == 1 else x.contiguous()
            for x in [q, k, v, rfa_k, rfa_v]
        ]

    # shape constraints
    batch, nheads, seqlen_q, head_dim = q.shape
    _, _, seqlen_k, _ = k.shape
    if empty_rfa_kv == 0:
        nchunks = rfa_k.shape[-2]
        assert rfa_k.shape == (batch, nheads, nchunks, head_dim)
        assert rfa_v.shape == (batch, nheads, nchunks, head_dim)
        assert q.dtype == k.dtype == v.dtype == rfa_k.dtype == rfa_v.dtype
    else:
        nchunks = 0
        assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
    assert k.shape == (batch, nheads, seqlen_k, head_dim)
    assert v.shape == (batch, nheads, seqlen_k, head_dim)

    assert head_dim <= 128, "We only test head dimensions up to 128"
    # assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
    assert q.dtype in [torch.bfloat16, torch.float], "Only support bf16 and fp32 for now"
    assert q.is_cuda and k.is_cuda and v.is_cuda
    # BUGFIX: use an explicit None check; `softmax_scale or default` would
    # also override a caller-provided scale of 0.0.
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(head_dim)

    mask_type = 0
    if window_mask is not None:
        mask_type = 1
        assert window_mask.dtype == torch.bool
        assert window_mask.is_cuda
        assert window_mask.dim() == 4
        assert window_mask.shape == (batch, 1, seqlen_q, window_size)
        if window_mask.stride(-1) != 1:
            window_mask = window_mask.contiguous()

        # A window mask requires a matching chunk mask.
        assert chunk_mask is not None
        assert chunk_mask.dtype == torch.bool
        assert chunk_mask.is_cuda
        assert chunk_mask.dim() == 4
        assert chunk_mask.shape == (batch, 1, seqlen_q, nchunks)
        if chunk_mask.stride(-1) != 1:
            chunk_mask = chunk_mask.contiguous()

    # Zero strides are passed when a tensor is absent; the kernel never
    # dereferences them in that case (guarded by MASK_TYPE / EMPTY_RFA_KV).
    chunk_mask_strides = (
        (chunk_mask.stride(0), chunk_mask.stride(2))
        if mask_type == 1 else
        (0, 0)
    )
    window_mask_strides = (
        (window_mask.stride(0), window_mask.stride(2))
        if mask_type == 1 else
        (0, 0)
    )

    rfa_k_strides = (
        (rfa_k.stride(0), rfa_k.stride(1), rfa_k.stride(2))
        if empty_rfa_kv == 0 else
        (0, 0, 0)
    )
    rfa_v_strides = (
        (rfa_v.stride(0), rfa_v.stride(1), rfa_v.stride(2))
        if empty_rfa_kv == 0 else
        (0, 0, 0)
    )

    o = torch.empty_like(q)
    # LSE is kept in float32 regardless of input dtype (backward needs it).
    lse = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)

    BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16)

    BLOCK_M, BLOCK_N, num_warps, num_stages = _get_config(q.dtype, head_dim, "fwd")

    assert chunks_per_window >= BLOCK_N, "chunks_per_window must be at least BLOCK_N"
    assert chunks_per_window % BLOCK_N == 0, "chunks_per_window must be a multiple of BLOCK_N"

    grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
    _fwd_eva_agg_kernel[grid](
        q,
        k,
        v,
        rfa_k,
        rfa_v,
        window_mask,
        chunk_mask,
        o,
        lse,
        softmax_scale,
        q.stride(0), q.stride(1), q.stride(2),
        k.stride(0), k.stride(1), k.stride(2),
        v.stride(0), v.stride(1), v.stride(2),
        rfa_k_strides[0], rfa_k_strides[1], rfa_k_strides[2],
        rfa_v_strides[0], rfa_v_strides[1], rfa_v_strides[2],
        window_mask_strides[0], window_mask_strides[1],
        chunk_mask_strides[0], chunk_mask_strides[1],
        o.stride(0), o.stride(1), o.stride(2),
        lse.stride(0), lse.stride(1),
        nheads,
        seqlen_q,
        seqlen_k,
        nchunks,
        head_dim,
        chunks_per_window,
        window_size,
        mask_type,
        empty_rfa_kv,
        BLOCK_HEADDIM,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return o, lse
1497
+
1498
+ def triton_eva_agg_bwd(
1499
+ do,
1500
+ q, k, v, rfa_k, rfa_v,
1501
+ window_mask, chunk_mask,
1502
+ o, lse,
1503
+ dq, dk, dv, d_rfa_k, d_rfa_v,
1504
+ softmax_scale,
1505
+ window_size,
1506
+ chunks_per_window,
1507
+ empty_rfa_kv,
1508
+ mask_type,
1509
+ ):
1510
+ if do.stride(-1) != 1:
1511
+ do = do.contiguous()
1512
+
1513
+ # shape constraints
1514
+ batch, nheads, seqlen_q, head_dim = q.shape
1515
+ _, _, seqlen_k, _ = k.shape
1516
+ if empty_rfa_kv == 0:
1517
+ nchunks = rfa_k.shape[-2]
1518
+ assert rfa_k.shape == (batch, nheads, nchunks, head_dim)
1519
+ assert rfa_v.shape == (batch, nheads, nchunks, head_dim)
1520
+ assert d_rfa_k.stride(-1) == d_rfa_v.stride(-1) == 1
1521
+ assert q.dtype == k.dtype == v.dtype == rfa_k.dtype == rfa_v.dtype
1522
+ else:
1523
+ nchunks = 0
1524
+ assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
1525
+
1526
+ assert lse.shape == (batch, nheads, seqlen_q)
1527
+ assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == rfa_k.stride(-1) == rfa_v.stride(-1) == 1
1528
+ assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1
1529
+ softmax_scale = softmax_scale or 1.0 / math.sqrt(head_dim)
1530
+
1531
+ assert head_dim <= 128, "We only test head dimensions up to 128"
1532
+
1533
+ window_mask_strides = (
1534
+ (window_mask.stride(0), window_mask.stride(2))
1535
+ if mask_type == 1 else
1536
+ (0, 0)
1537
+ )
1538
+ chunk_mask_strides = (
1539
+ (chunk_mask.stride(0), chunk_mask.stride(2))
1540
+ if mask_type == 1 else
1541
+ (0, 0)
1542
+ )
1543
+
1544
+ rfa_k_strides = (
1545
+ (rfa_k.stride(0), rfa_k.stride(1), rfa_k.stride(2))
1546
+ if empty_rfa_kv == 0 else
1547
+ (0, 0, 0)
1548
+ )
1549
+ rfa_v_strides = (
1550
+ (rfa_v.stride(0), rfa_v.stride(1), rfa_v.stride(2))
1551
+ if empty_rfa_kv == 0 else
1552
+ (0, 0, 0)
1553
+ )
1554
+
1555
+ d_rfa_k_strides = (
1556
+ (d_rfa_k.stride(0), d_rfa_k.stride(1), d_rfa_k.stride(2))
1557
+ if empty_rfa_kv == 0 else
1558
+ (0, 0, 0)
1559
+ )
1560
+ d_rfa_v_strides = (
1561
+ (d_rfa_v.stride(0), d_rfa_v.stride(1), d_rfa_v.stride(2))
1562
+ if empty_rfa_kv == 0 else
1563
+ (0, 0, 0)
1564
+ )
1565
+
1566
+ BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16)
1567
+
1568
+ do_t_o = torch.sum(do.to(torch.float32) * o.to(torch.float32), dim=-1).to(do.dtype)
1569
+
1570
+ BLOCK_M, BLOCK_N, num_warps, num_stages = _get_config(q.dtype, head_dim, "bwd_dq")
1571
+
1572
+ assert chunks_per_window >= BLOCK_N, "chunks_per_window must be greater than BLOCK"
1573
+ assert chunks_per_window % BLOCK_N == 0, "chunks_per_window must be a multiple of BLOCK"
1574
+ grid = lambda META: (
1575
+ triton.cdiv(seqlen_q, META["BLOCK_M"]),
1576
+ batch * nheads,
1577
+ )
1578
+ _bwd_eva_agg_kernel_dq[grid](
1579
+ q,
1580
+ k,
1581
+ v,
1582
+ rfa_k,
1583
+ rfa_v,
1584
+ window_mask,
1585
+ chunk_mask,
1586
+ do,
1587
+ lse,
1588
+ do_t_o,
1589
+ dq,
1590
+ softmax_scale,
1591
+ q.stride(0), q.stride(1), q.stride(2),
1592
+ k.stride(0), k.stride(1), k.stride(2),
1593
+ v.stride(0), v.stride(1), v.stride(2),
1594
+ rfa_k_strides[0], rfa_k_strides[1], rfa_k_strides[2],
1595
+ rfa_v_strides[0], rfa_v_strides[1], rfa_v_strides[2],
1596
+ window_mask_strides[0], window_mask_strides[1],
1597
+ chunk_mask_strides[0], chunk_mask_strides[1],
1598
+ do.stride(0), do.stride(1), do.stride(2),
1599
+ lse.stride(0), lse.stride(1),
1600
+ do_t_o.stride(0), do_t_o.stride(1),
1601
+ dq.stride(0), dq.stride(1), dq.stride(2),
1602
+ nheads,
1603
+ seqlen_q,
1604
+ seqlen_k,
1605
+ nchunks,
1606
+ head_dim,
1607
+ chunks_per_window,
1608
+ window_size,
1609
+ mask_type,
1610
+ empty_rfa_kv,
1611
+ BLOCK_HEADDIM,
1612
+ BLOCK_M=BLOCK_M,
1613
+ BLOCK_N=BLOCK_N,
1614
+ num_warps=num_warps,
1615
+ num_stages=num_stages,
1616
+ )
1617
+
1618
+ BLOCK_M, BLOCK_N, num_warps, num_stages = _get_config(q.dtype, head_dim, "bwd_dkdv")
1619
+ grid = lambda META: (
1620
+ triton.cdiv(seqlen_k, META["BLOCK_N"]),
1621
+ batch * nheads,
1622
+ )
1623
+ _bwd_eva_agg_kernel_dkdv[grid](
1624
+ q,
1625
+ k,
1626
+ v,
1627
+ window_mask,
1628
+ do,
1629
+ lse,
1630
+ do_t_o,
1631
+ dk,
1632
+ dv,
1633
+ softmax_scale,
1634
+ q.stride(0), q.stride(1), q.stride(2),
1635
+ k.stride(0), k.stride(1), k.stride(2),
1636
+ v.stride(0), v.stride(1), v.stride(2),
1637
+ window_mask_strides[0], window_mask_strides[1],
1638
+ do.stride(0), do.stride(1), do.stride(2),
1639
+ lse.stride(0), lse.stride(1),
1640
+ do_t_o.stride(0), do_t_o.stride(1),
1641
+ dk.stride(0), dk.stride(1), dk.stride(2),
1642
+ dv.stride(0), dv.stride(1), dv.stride(2),
1643
+ nheads,
1644
+ seqlen_q,
1645
+ seqlen_k,
1646
+ head_dim,
1647
+ window_size,
1648
+ mask_type,
1649
+ BLOCK_HEADDIM,
1650
+ BLOCK_M=BLOCK_M,
1651
+ BLOCK_N=BLOCK_N,
1652
+ num_warps=num_warps,
1653
+ num_stages=num_stages,
1654
+ )
1655
+ if empty_rfa_kv == 0:
1656
+ BLOCK_M, BLOCK_N, num_warps, num_stages = _get_config(q.dtype, head_dim, "bwd_drfa_kv")
1657
+ grid = lambda META: (
1658
+ triton.cdiv(nchunks, META["BLOCK_N"]),
1659
+ batch * nheads,
1660
+ )
1661
+ _bwd_eva_agg_kernel_drfa_kv[grid](
1662
+ q,
1663
+ rfa_k,
1664
+ rfa_v,
1665
+ chunk_mask,
1666
+ do,
1667
+ lse,
1668
+ do_t_o,
1669
+ d_rfa_k,
1670
+ d_rfa_v,
1671
+ softmax_scale,
1672
+ q.stride(0), q.stride(1), q.stride(2),
1673
+ rfa_k_strides[0], rfa_k_strides[1], rfa_k_strides[2],
1674
+ rfa_v_strides[0], rfa_v_strides[1], rfa_v_strides[2],
1675
+ chunk_mask_strides[0], chunk_mask_strides[1],
1676
+ do.stride(0), do.stride(1), do.stride(2),
1677
+ lse.stride(0), lse.stride(1),
1678
+ do_t_o.stride(0), do_t_o.stride(1),
1679
+ d_rfa_k_strides[0], d_rfa_k_strides[1], d_rfa_k_strides[2],
1680
+ d_rfa_v_strides[0], d_rfa_v_strides[1], d_rfa_v_strides[2],
1681
+ nheads,
1682
+ seqlen_q,
1683
+ nchunks,
1684
+ head_dim,
1685
+ chunks_per_window,
1686
+ window_size,
1687
+ mask_type,
1688
+ BLOCK_HEADDIM,
1689
+ BLOCK_M=BLOCK_M,
1690
+ BLOCK_N=BLOCK_N,
1691
+ num_warps=num_warps,
1692
+ num_stages=num_stages,
1693
+ )
1694
+
1695
+
1696
class EvaAggFunc(torch.autograd.Function):
    """Autograd bridge for the Triton EVA aggregation attention kernels.

    The forward pass runs `triton_eva_agg_fwd`; the backward pass dispatches
    to `triton_eva_agg_bwd` with pre-allocated gradient buffers. `rfa_k` and
    `rfa_v` (and their gradients) are optional and must be present or absent
    together.
    """

    @staticmethod
    def forward(ctx, q, k, v, rfa_k, rfa_v, window_mask, chunk_mask, softmax_scale=None, window_size=None, chunks_per_window=None):
        # RFA tensors come as a pair: both None, or both given.
        if rfa_k is None and rfa_v is None:
            empty_rfa_kv = 1
        else:
            assert rfa_k is not None and rfa_v is not None, "Both rfa_k and rfa_v must either be None or have values at the same time."
            empty_rfa_kv = 0

        # mask_type flags whether an explicit window mask is supplied.
        mask_type = 1 if window_mask is not None else 0

        o, lse = triton_eva_agg_fwd(
            q, k, v, rfa_k, rfa_v, window_mask, chunk_mask, softmax_scale, window_size, chunks_per_window
        )

        # Stash everything the backward kernels need.
        ctx.save_for_backward(q, k, v, o, lse, rfa_k, rfa_v, window_mask, chunk_mask)
        ctx.softmax_scale = softmax_scale
        ctx.window_size = window_size
        ctx.chunks_per_window = chunks_per_window
        ctx.empty_rfa_kv = empty_rfa_kv
        ctx.mask_type = mask_type
        return o

    @staticmethod
    def backward(ctx, do):
        q, k, v, o, lse, rfa_k, rfa_v, window_mask, chunk_mask = ctx.saved_tensors

        # Output gradient buffers, filled in-place by the Triton kernels.
        dq = torch.empty_like(q)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
        if ctx.empty_rfa_kv == 0:
            d_rfa_k = torch.empty_like(rfa_k)
            d_rfa_v = torch.empty_like(rfa_v)
        else:
            d_rfa_k = None
            d_rfa_v = None

        triton_eva_agg_bwd(
            do,
            q, k, v,
            rfa_k, rfa_v,
            window_mask, chunk_mask,
            o, lse,
            dq, dk, dv,
            d_rfa_k, d_rfa_v,
            softmax_scale=ctx.softmax_scale,
            window_size=ctx.window_size,
            chunks_per_window=ctx.chunks_per_window,
            empty_rfa_kv=ctx.empty_rfa_kv,
            mask_type=ctx.mask_type,
        )
        # One gradient slot per forward argument; the non-tensor args get None.
        return dq, dk, dv, d_rfa_k, d_rfa_v, None, None, None, None, None
1755
+
1756
+
1757
def eva_agg_func_triton(
    q, k, v, rfa_k, rfa_v,
    window_mask, chunk_mask,
    softmax_scale=None, window_size=None, chunks_per_window=None,
):
    """Functional entry point: run EVA aggregation attention via the autograd op."""
    args = (
        q, k, v, rfa_k, rfa_v,
        window_mask, chunk_mask,
        softmax_scale, window_size, chunks_per_window,
    )
    return EvaAggFunc.apply(*args)
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/eva_cache.py ADDED
@@ -0,0 +1,761 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Optional, Tuple, List, Any, Union
2
+ import torch
3
+ from transformers.cache_utils import Cache
4
+
5
class EvaCache(Cache):
    """
    A dynamically growing cache for EVA attention, used during generation.

    Key/value state is held as one tensor per layer; the expected per-tensor
    shape is `[batch_size, num_heads, seq_len, head_dim]`.
    """

    def __init__(self) -> None:
        # Sliding-window (singleton) keys/values, one entry per layer.
        self.w_k: List[torch.Tensor] = []
        self.w_v: List[torch.Tensor] = []

        # Queries/keys/values of the (possibly partial) current chunk.
        self.rf_q: List[torch.Tensor] = []
        self.rf_k: List[torch.Tensor] = []
        self.rf_v: List[torch.Tensor] = []

        # Per-chunk random-feature-attention statistics.
        self.softmax_phi_k_v: List[torch.Tensor] = []
        self.log_sum_phi_k: List[torch.Tensor] = []
        self.rf_k_bar: List[torch.Tensor] = []

        # Running count of tokens seen; `generate` relies on this tally.
        self._seen_tokens = 0

        # Temporary buffers for the attention masks.
        self.rf_mask: List[Optional[torch.Tensor]] = []
        self.s_mask: List[torch.Tensor] = []
        self.chunk_mask: List[torch.Tensor] = []
30
+
31
+ def __len__(self):
32
+ """
33
+ Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
34
+ to the number of layers in the model.
35
+ """
36
+ return len(self.w_k)
37
+
38
+ def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
39
+ """Given the sequence length of the new inputs, returns the usable length of the cache."""
40
+ # Cache without size limit -> all cache is usable
41
+ # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
42
+ # length, we will need to evict part of the cache (and thus not all cache is usable)
43
+ max_length = self.get_max_length()
44
+ previous_seq_length = self.get_seq_length(layer_idx)
45
+ if max_length is not None and previous_seq_length + new_seq_length > max_length:
46
+ return max_length - new_seq_length
47
+ return previous_seq_length
48
+
49
+ def reorder_cache(self, beam_idx: torch.LongTensor):
50
+ """Reorders the cache for beam search, given the selected beam indices."""
51
+ for layer_idx in range(len(self.w_k)):
52
+ device = self.w_k[layer_idx].device
53
+ self.w_k[layer_idx] = self.w_k[layer_idx].index_select(0, beam_idx.to(device))
54
+
55
+ device = self.w_v[layer_idx].device
56
+ self.w_v[layer_idx] = self.w_v[layer_idx].index_select(0, beam_idx.to(device))
57
+
58
+ device = self.rf_q[layer_idx].device
59
+ self.rf_q[layer_idx] = self.rf_q[layer_idx].index_select(0, beam_idx.to(device))
60
+
61
+ device = self.rf_k[layer_idx].device
62
+ self.rf_k[layer_idx] = self.rf_k[layer_idx].index_select(0, beam_idx.to(device))
63
+
64
+ device = self.rf_v[layer_idx].device
65
+ self.rf_v[layer_idx] = self.rf_v[layer_idx].index_select(0, beam_idx.to(device))
66
+
67
+ device = self.softmax_phi_k_v[layer_idx].device
68
+ self.softmax_phi_k_v[layer_idx] = self.softmax_phi_k_v[layer_idx].index_select(0, beam_idx.to(device))
69
+
70
+ device = self.log_sum_phi_k[layer_idx].device
71
+ self.log_sum_phi_k[layer_idx] = self.log_sum_phi_k[layer_idx].index_select(0, beam_idx.to(device))
72
+
73
+ device = self.rf_k_bar[layer_idx].device
74
+ self.rf_k_bar[layer_idx] = self.rf_k_bar[layer_idx].index_select(0, beam_idx.to(device))
75
+
76
+ device = self.rf_mask[layer_idx].device
77
+ self.rf_mask[layer_idx] = self.rf_mask[layer_idx].index_select(0, beam_idx.to(device))
78
+
79
+ device = self.s_mask[layer_idx].device
80
+ self.s_mask[layer_idx] = self.s_mask[layer_idx].index_select(0, beam_idx.to(device))
81
+
82
+ device = self.chunk_mask[layer_idx].device
83
+ self.chunk_mask[layer_idx] = self.chunk_mask[layer_idx].index_select(0, beam_idx.to(device))
84
+ @property
85
+ def seen_tokens(self):
86
+ if hasattr(self, "_seen_tokens"):
87
+ return self._seen_tokens
88
+ else:
89
+ return None
90
+
91
+ def update_past_len(
92
+ self,
93
+ cur_q_len: int,
94
+ layer_idx: int
95
+ ):
96
+ # Update the number of seen tokens
97
+ if layer_idx == 0:
98
+ self._seen_tokens += cur_q_len
99
+ return self._seen_tokens
100
+
101
    def update_mask(
        self,
        prev_s_mask,
        cur_s_mask,
        chunk_mask,
        rf_mask,
        layer_idx,
        window_size,
        chunk_size,
    ):
        """Update the cached singleton / intra-chunk / inter-chunk attention masks.

        Handles both the prefill stage (cache list shorter than ``layer_idx``)
        and the token-decoding stage, and returns
        ``(prev_s_mask, cur_s_mask, prev_chunk_mask, cur_chunk_mask, dump_rf_mask)``
        for the current step.
        """
        ############################################
        # compute masks for singletons
        ############################################
        q_len = None
        if len(self.s_mask) <= layer_idx:
            # prefill stage; q is of shape [b, h, n, d]
            # NOTE(review): q_len is inferred from chunk_mask here — assumes
            # chunk_mask's -2 dim equals the query length; confirm at call site.
            q_len = chunk_mask.shape[-2]
            if q_len < window_size:
                # shorter than one window: everything is "current", no completed windows
                assert prev_s_mask is None

            # store the trailing row of the singleton mask for the next step
            self.s_mask.append(cur_s_mask[..., -1:, :] if cur_s_mask is not None else prev_s_mask[..., -1, -1:, :])
        else:
            # decoding stage
            prev_s_mask = None

            cached_s_mask = self.s_mask[layer_idx]
            assert cached_s_mask is not None
            if cached_s_mask.shape[-1] == window_size:
                # cached window is full: start a fresh window with the new mask
                cur_s_mask = cur_s_mask
            else:
                cur_s_mask = torch.cat([cached_s_mask, cur_s_mask], dim=-1)

            # store the past window-wise mask
            self.s_mask[layer_idx] = cur_s_mask

        ############################################
        # compute masks for intra-chunks
        ############################################
        dump_rf_mask = None
        if len(self.rf_mask) <= layer_idx:
            # initialize chunk stats (prefill stage)
            if q_len < chunk_size:
                cur_rf_mask = rf_mask
            else:
                if q_len % chunk_size == 0:
                    # input length divisible by chunk_size: dump everything
                    dump_rf_mask = rf_mask
                    cur_rf_mask = None
                else:
                    # split completed chunks (dumped) from the trailing partial chunk
                    remainder_tokens = q_len % chunk_size
                    if rf_mask is not None:
                        dump_rf_mask, cur_rf_mask = torch.split(rf_mask, [q_len - remainder_tokens, remainder_tokens], dim=-2)
                    else:
                        dump_rf_mask = None
                        cur_rf_mask = None
            self.rf_mask.append(cur_rf_mask)
        else:
            past_rf_mask = self.rf_mask[layer_idx]
            if past_rf_mask is not None:
                # when decoding tokens, we always assume the
                # incoming token mask is 0 (not masked)
                cur_rf_mask = torch.cat([past_rf_mask, rf_mask], dim=-2)
            else:
                # we do not need to use rf_mask anymore after we receive generated tokens
                cur_rf_mask = None

            # Dump the chunk if the len of current chunk reaches <chunk_size>.
            if cur_rf_mask is not None and cur_rf_mask.shape[-2] == chunk_size:
                dump_rf_mask = cur_rf_mask
                cur_rf_mask = None

            self.rf_mask[layer_idx] = cur_rf_mask

        ############################################
        # compute masks for inter chunks
        ############################################
        if len(self.chunk_mask) <= layer_idx:
            # prefill stage; q is of shape [b, h, n, d]
            if q_len < window_size:
                cur_chunk_mask = chunk_mask
                prev_chunk_mask = None
            else:
                if q_len % window_size == 0:
                    cur_chunk_mask = None
                    prev_chunk_mask = chunk_mask
                else:
                    remainder_tokens = q_len % window_size
                    # [b, h, n-r, d] [b, h, r, d]
                    prev_chunk_mask, cur_chunk_mask = torch.split(chunk_mask, [q_len - remainder_tokens, remainder_tokens], dim=-2)
                # fold completed windows into a dedicated window axis
                bsz, num_heads, _, head_dim = prev_chunk_mask.shape
                prev_chunk_mask = prev_chunk_mask.reshape(bsz, num_heads, -1, window_size, head_dim)

                assert prev_s_mask is not None
                if prev_s_mask.shape[-3] == 1 and prev_chunk_mask.shape[-3] > 1:
                    # need to expand the singleton window axis to match
                    prev_s_mask = prev_s_mask.expand(-1, -1, prev_chunk_mask.shape[-3], -1, -1)
            # store the trailing row for the next decoding step
            self.chunk_mask.append(cur_chunk_mask[..., -1:, :] if cur_chunk_mask is not None else prev_chunk_mask[..., -1, -1:, :])
        else:
            # decoding stage
            prev_chunk_mask = None
            cur_chunk_mask = self.chunk_mask[layer_idx]

            # if the current sequence length reaches a multiple of <chunk_size>,
            # we append a new entry to the end of chunk_mask
            seen_seq_len = self.get_seq_length(layer_idx)
            if seen_seq_len > 0 and seen_seq_len % chunk_size == 0:
                past_chunk_mask = self.chunk_mask[layer_idx]
                if past_chunk_mask is not None:
                    # when decoding tokens, we always assume the
                    # incoming token mask is 0 (not masked)
                    cur_chunk_mask = torch.cat([past_chunk_mask, chunk_mask], dim=-1)
                else:
                    cur_chunk_mask = chunk_mask
                self.chunk_mask[layer_idx] = cur_chunk_mask

            # if the len of current sequence reaches <window_size> + 1,
            # we turn on the mask for the most recent chunks
            if seen_seq_len > 0 and seen_seq_len % window_size == 1:
                cur_chunk_mask = self.chunk_mask[layer_idx]
                num_chunks_per_window = window_size // chunk_size
                cur_chunk_mask[..., -num_chunks_per_window:] = False
                self.chunk_mask[layer_idx] = cur_chunk_mask

        return (prev_s_mask, cur_s_mask, prev_chunk_mask, cur_chunk_mask, dump_rf_mask)
234
+
235
    def update_singletons(
        self,
        q,
        k,
        v,
        layer_idx,
        window_size,
    ):
        """Split q/k/v into completed windows and the current partial window, caching the latter.

        Returns ``(past_w_q, past_w_k, past_w_v), (w_q, w_k, w_v)``: the
        ``past_*`` tensors cover fully completed windows (reshaped to
        ``[b, h, n_windows, window_size, d]``); the ``w_*`` tensors hold the
        trailing partial window.
        """
        if len(self.w_k) <= layer_idx:
            # prefill stage
            # q is of shape [b, h, n, d]
            q_len = q.shape[-2]
            if q_len < window_size:
                # everything fits in one (partial) window
                w_q = q
                w_k = k
                w_v = v
                past_w_q = past_w_k = past_w_v = None
            else:
                if q_len % window_size == 0:
                    # input length divisible by window_size: all windows completed
                    w_q = None
                    w_k = None
                    w_v = None
                    past_w_q = q
                    past_w_k = k
                    past_w_v = v
                else:
                    remainder_tokens = q_len % window_size
                    # [b, h, n-r, d] [b, h, r, d]
                    past_w_q, w_q = torch.split(q, [q_len - remainder_tokens, remainder_tokens], dim=-2)
                    past_w_k, w_k = torch.split(k, [q_len - remainder_tokens, remainder_tokens], dim=-2)
                    past_w_v, w_v = torch.split(v, [q_len - remainder_tokens, remainder_tokens], dim=-2)
                # fold completed windows into a dedicated window axis
                bsz, num_heads, _, head_dim = past_w_q.shape
                past_w_q = past_w_q.reshape(bsz, num_heads, -1, window_size, head_dim)
                past_w_k = past_w_k.reshape(bsz, num_heads, -1, window_size, head_dim)
                past_w_v = past_w_v.reshape(bsz, num_heads, -1, window_size, head_dim)
            # w_q = q[..., None, -window_size:, :] # [b, h, 1, j, d]
            # w_k = # [b, h, 1, j, d]
            # w_v = # [b, h, 1, j, d]
            # store the past window-wise key-value pairs
            # if w_k is None, it means we happen to pass in a sequence that is divisible by window_size
            # we leave the cache with window_size-sized kv pairs to be cleared next iteration
            self.w_k.append(w_k if w_k is not None else past_w_k[..., -1, :, :])
            self.w_v.append(w_v if w_v is not None else past_w_v[..., -1, :, :])
        else:
            # decoding stage
            past_w_q = past_w_k = past_w_v = None
            # this is implemented as either a sliding window or fixed window
            w_q = q # [b, h, 1, d]
            w_k = k # [b, h, 1, d]
            w_v = v # [b, h, 1, d]

            cached_w_k = self.w_k[layer_idx]
            assert cached_w_k is not None # [b, h, j, d]
            # a full cached window means a new window starts at this token
            if cached_w_k.shape[-2] == window_size:
                w_k = w_k
            else:
                w_k = torch.cat([cached_w_k, w_k], dim=-2)

            cached_w_v = self.w_v[layer_idx]
            assert cached_w_v is not None
            if cached_w_v.shape[-2] == window_size:
                w_v = w_v
            else:
                w_v = torch.cat([cached_w_v, w_v], dim=-2)

            # store the past window-wise key-value pairs
            self.w_k[layer_idx] = w_k
            self.w_v[layer_idx] = w_v
        return (past_w_q, past_w_k, past_w_v), (w_q, w_k, w_v)
304
+
305
    def update_chunks(
        self,
        q,
        k,
        v,
        layer_idx,
        chunk_size
    ):
        """Accumulate q/k/v into the current chunk; emit completed chunks for RFA.

        Returns ``(dump_q, dump_k, dump_v)``: tensors covering whole completed
        chunks (a multiple of ``chunk_size`` tokens), or ``None`` while the
        current chunk is still partial.
        """
        q_len = q.shape[-2]
        dump_q = None
        dump_k = None
        dump_v = None
        if len(self.rf_q) <= layer_idx:
            # initialize chunk stats
            # prefill stage
            if q_len < chunk_size:
                # everything stays in the (partial) current chunk
                rf_q = q
                rf_k = k
                rf_v = v
            else:
                if q_len % chunk_size == 0:
                    # input length divisible by chunk_size: dump everything
                    rf_q = None
                    rf_k = None
                    rf_v = None
                    dump_q = q
                    dump_k = k
                    dump_v = v
                else:
                    remainder_tokens = q_len % chunk_size
                    # [b, h, n-r, d] [b, h, r, d]
                    dump_q, rf_q = torch.split(q, [q_len - remainder_tokens, remainder_tokens], dim=-2)
                    dump_k, rf_k = torch.split(k, [q_len - remainder_tokens, remainder_tokens], dim=-2)
                    dump_v, rf_v = torch.split(v, [q_len - remainder_tokens, remainder_tokens], dim=-2)
            self.rf_q.append(rf_q)
            self.rf_k.append(rf_k)
            self.rf_v.append(rf_v)
        else:
            # decode tokens
            # add query, key & value to the current chunk.
            past_rf_q = self.rf_q[layer_idx]
            if past_rf_q is not None:
                rf_q = torch.cat([past_rf_q, q], dim=-2)
            else:
                rf_q = q

            past_rf_k = self.rf_k[layer_idx]
            if past_rf_k is not None:
                rf_k = torch.cat([past_rf_k, k], dim=-2)
            else:
                rf_k = k

            past_rf_v = self.rf_v[layer_idx]
            if past_rf_v is not None:
                rf_v = torch.cat([past_rf_v, v], dim=-2)
            else:
                rf_v = v

            # We need to store rf_k_bar and RFA-results that
            # compute the per-chunk RFA.

            # Dump the chunk if the len of current chunk reaches <chunk_size>.
            if rf_q.shape[-2] == chunk_size:
                dump_q = rf_q
                dump_k = rf_k
                dump_v = rf_v
                # clear the chunk
                rf_q = None
                rf_k = None
                rf_v = None

            self.rf_q[layer_idx] = rf_q
            self.rf_k[layer_idx] = rf_k
            self.rf_v[layer_idx] = rf_v

        return dump_q, dump_k, dump_v
380
+
381
+ def update_chunk_rfas(
382
+ self,
383
+ softmax_phi_k_v,
384
+ log_sum_phi_k,
385
+ rf_k_bar,
386
+ layer_idx,
387
+ random_feature_dim
388
+ ):
389
+ if len(self.softmax_phi_k_v) <= layer_idx:
390
+ # prefill stage
391
+ self.softmax_phi_k_v.append(softmax_phi_k_v)
392
+ self.log_sum_phi_k.append(log_sum_phi_k)
393
+ self.rf_k_bar.append(rf_k_bar)
394
+ else:
395
+ # token decoding
396
+ past_softmax_phi_k_v = self.softmax_phi_k_v[layer_idx]
397
+ past_log_sum_phi_k = self.log_sum_phi_k[layer_idx]
398
+ past_rf_k_bar = self.rf_k_bar[layer_idx]
399
+
400
+ if past_softmax_phi_k_v is not None:
401
+ if random_feature_dim == 1:
402
+ dim = -2
403
+ else:
404
+ dim = -3
405
+ softmax_phi_k_v = torch.cat([past_softmax_phi_k_v, softmax_phi_k_v], dim=dim)
406
+
407
+ if past_log_sum_phi_k is not None:
408
+ if random_feature_dim == 1:
409
+ dim = -2
410
+ else:
411
+ dim = -3
412
+ log_sum_phi_k = torch.cat([past_log_sum_phi_k, log_sum_phi_k], dim=dim)
413
+
414
+ if past_rf_k_bar is not None:
415
+ rf_k_bar = torch.cat([past_rf_k_bar, rf_k_bar], dim=-2)
416
+
417
+ self.softmax_phi_k_v[layer_idx] = softmax_phi_k_v
418
+ self.log_sum_phi_k[layer_idx] = log_sum_phi_k
419
+ self.rf_k_bar[layer_idx] = rf_k_bar
420
+
421
+ return softmax_phi_k_v, log_sum_phi_k, rf_k_bar
422
+
423
+ def get_chunk_rfas(self, layer_idx):
424
+ if len(self.softmax_phi_k_v) <= layer_idx:
425
+ return (
426
+ None,
427
+ None,
428
+ None
429
+ )
430
+ else:
431
+ return (
432
+ self.softmax_phi_k_v[layer_idx],
433
+ self.log_sum_phi_k[layer_idx],
434
+ self.rf_k_bar[layer_idx]
435
+ )
436
+
437
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
438
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
439
+ if len(self.w_k) <= layer_idx:
440
+ return 0
441
+ return self._seen_tokens
442
+
443
+ def get_max_length(self) -> Optional[int]:
444
+ """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
445
+ return None
446
+
447
+ def update(
448
+ self,
449
+ layer_idx: int,
450
+ cache_kwargs: Optional[Dict[str, Any]] = None,
451
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
452
+ raise NotImplementedError("`update` is not used in Eva Cache.")
453
+
454
class EvaStaticCacheForTriton(Cache):
    """
    A variant of EvaCache for eva's triton kernels
    """

    def __init__(
        self,
        batch_size,
        num_key_value_heads,
        window_size,
        head_dim,
        num_layers,
        dtype,
        device
    ) -> None:
        # Pre-allocated per-layer ring buffers for the current attention window.
        cache_shape = (batch_size, num_key_value_heads, window_size, head_dim)
        self.past_window_k: List[torch.Tensor] = [
            torch.zeros(cache_shape, dtype=dtype, device=device) for _ in range(num_layers)
        ]
        self.past_window_v: List[torch.Tensor] = [
            torch.zeros(cache_shape, dtype=dtype, device=device) for _ in range(num_layers)
        ]

        # Current write position inside each layer's window buffer.
        self.past_window_pos: List[int] = []

        # Per-chunk RFA keys/values.
        self.rfa_k: List[torch.Tensor] = []
        self.rfa_v: List[torch.Tensor] = []
        # self.rfa_mask: List[torch.Tensor] = []

        # Running count of tokens seen; `generate` relies on this tally.
        self._seen_tokens = 0

        # Temporary buffers for the attention masks.
        self.rf_mask: List[Optional[torch.Tensor]] = []
        self.s_mask: List[torch.Tensor] = []
490
+
491
+ def __len__(self):
492
+ """
493
+ Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
494
+ to the number of layers in the model.
495
+ """
496
+ return len(self.past_window_pos)
497
+
498
+ def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
499
+ """Given the sequence length of the new inputs, returns the usable length of the cache."""
500
+ # Cache without size limit -> all cache is usable
501
+ # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
502
+ # length, we will need to evict part of the cache (and thus not all cache is usable)
503
+ max_length = self.get_max_length()
504
+ previous_seq_length = self.get_seq_length(layer_idx)
505
+ if max_length is not None and previous_seq_length + new_seq_length > max_length:
506
+ return max_length - new_seq_length
507
+ return previous_seq_length
508
+
509
+ def reorder_cache(self, beam_idx: torch.LongTensor):
510
+ """Reorders the cache for beam search, given the selected beam indices."""
511
+ for layer_idx in range(len(self.past_window_k)):
512
+ device = self.past_window_k[layer_idx].device
513
+ self.past_window_k[layer_idx] = self.past_window_k[layer_idx].index_select(0, beam_idx.to(device))
514
+
515
+ device = self.past_window_v[layer_idx].device
516
+ self.past_window_v[layer_idx] = self.past_window_v[layer_idx].index_select(0, beam_idx.to(device))
517
+
518
+ device = self.rfa_k[layer_idx].device
519
+ self.rfa_k[layer_idx] = self.rfa_k[layer_idx].index_select(0, beam_idx.to(device))
520
+
521
+ device = self.rfa_v[layer_idx].device
522
+ self.rfa_v[layer_idx] = self.rfa_v[layer_idx].index_select(0, beam_idx.to(device))
523
+
524
+ # device = self.rfa_mask[layer_idx].device
525
+ # self.rfa_mask[layer_idx] = self.rfa_mask[layer_idx].index_select(0, beam_idx.to(device))
526
+
527
+ device = self.rf_mask[layer_idx].device
528
+ self.rf_mask[layer_idx] = self.rf_mask[layer_idx].index_select(0, beam_idx.to(device))
529
+
530
+ device = self.s_mask[layer_idx].device
531
+ self.s_mask[layer_idx] = self.s_mask[layer_idx].index_select(0, beam_idx.to(device))
532
+
533
+ @property
534
+ def seen_tokens(self):
535
+ if hasattr(self, "_seen_tokens"):
536
+ return self._seen_tokens
537
+ else:
538
+ return None
539
+
540
+ def update_past_len(
541
+ self,
542
+ cur_q_len: int,
543
+ layer_idx: int
544
+ ):
545
+ # Update the number of seen tokens
546
+ if layer_idx == 0:
547
+ self._seen_tokens += cur_q_len
548
+ return self._seen_tokens
549
+
550
    def update_mask(
        self,
        s_mask,
        rf_mask,
        layer_idx,
        window_size,
    ):
        """Update the cached singleton (`s_mask`) and intra-chunk (`rf_mask`) masks.

        Returns ``(dump_s_mask, dump_rf_mask)``: the mask to use for window
        attention at this step, and the mask for chunk(s) completed by this
        step (``None`` when nothing completed).
        """
        ############################################
        # compute masks for singletons
        ############################################
        if len(self.s_mask) <= layer_idx:
            # prefill stage
            # q is of shape [b, h, n, d]
            if s_mask is None:
                cur_s_mask = None
            else:
                q_len = s_mask.shape[-2]
                # s_mask is of shape [b, h, n, w]
                # let r = q_len % window_size
                # if r == 0, the mask to be appended is of shape [b, h, 1, w]
                # otherwise, r < w, the mask to be appended is of shape [b, h, 1, r]
                remainder_tokens = q_len % window_size
                if remainder_tokens == 0:
                    cur_s_mask = None
                else:
                    # keep only the trailing row, restricted to the partial window
                    cur_s_mask = s_mask[..., -1:, :remainder_tokens]
            self.s_mask.append(cur_s_mask)
            # we use the passed s_mask for subsequent computations
            dump_s_mask = s_mask
        else:
            # decoding stage
            past_s_mask = self.s_mask[layer_idx]
            if past_s_mask is None:
                assert s_mask is None
                cur_s_mask = None
            else:
                assert s_mask is not None
                cur_s_mask = torch.cat([past_s_mask, s_mask], dim=-1)

            dump_s_mask = cur_s_mask
            # a full window resets the cached mask for the next window
            if cur_s_mask is not None and cur_s_mask.shape[-1] == window_size:
                cur_s_mask = None
            # store the past window-wise mask
            self.s_mask[layer_idx] = cur_s_mask

        ############################################
        # compute masks for intra-chunks
        ############################################
        dump_rf_mask = None
        if len(self.rf_mask) <= layer_idx:
            # initialize chunk stats
            # prefill stage
            if rf_mask is None:
                cur_rf_mask = None
            else:
                q_len = rf_mask.shape[-2]
                if q_len < window_size:
                    dump_rf_mask = None
                    cur_rf_mask = rf_mask
                else:
                    if q_len % window_size == 0:
                        # input length divisible by window_size: dump everything
                        dump_rf_mask = rf_mask
                        cur_rf_mask = None
                    else:
                        # split completed portion (dumped) from the trailing partial window
                        remainder_tokens = q_len % window_size
                        dump_rf_mask, cur_rf_mask = torch.split(rf_mask, [q_len - remainder_tokens, remainder_tokens], dim=-2)
            self.rf_mask.append(cur_rf_mask)
        else:
            past_rf_mask = self.rf_mask[layer_idx]
            if past_rf_mask is not None:
                # when decoding tokens, we always assume the
                # incoming token mask is 0 (not masked)
                cur_rf_mask = torch.cat([past_rf_mask, rf_mask], dim=-2)
            else:
                cur_rf_mask = None

            # dump once the accumulated mask spans a full window
            if cur_rf_mask is not None and cur_rf_mask.shape[-2] == window_size:
                dump_rf_mask = cur_rf_mask
                cur_rf_mask = None

            self.rf_mask[layer_idx] = cur_rf_mask

        return dump_s_mask, dump_rf_mask
635
+
636
    def update_singletons_and_chunks(
        self,
        k,
        v,
        layer_idx,
        window_size,
    ):
        """Write k/v into the layer's window ring buffer and emit completed windows.

        Returns ``(s_k, s_v, dump_k, dump_v)``: the key/value tokens visible to
        window attention at this step, plus tokens of completed windows to be
        summarised into chunk RFAs (``None`` when no window completed).
        """
        if len(self.past_window_pos) <= layer_idx:
            # prefill stage
            s_k = k
            s_v = v
            input_len = k.shape[-2]
            window_pos = 0
            if input_len <= window_size:
                new_window_pos = window_pos + input_len

                cached_window_k = k
                cached_window_v = v
                dump_k = None
                dump_v = None
            else:
                remainder_tokens = input_len % window_size
                # a perfectly divisible input leaves one full window resident
                if remainder_tokens == 0:
                    remainder_tokens = window_size
                new_window_pos = window_pos + remainder_tokens

                # [b, h, n-r, d] [b, h, r, d]
                cached_window_k = k[..., -remainder_tokens:, :]
                cached_window_v = v[..., -remainder_tokens:, :]
                dump_k = k[..., :-remainder_tokens, :]
                dump_v = v[..., :-remainder_tokens, :]
            # store the past window-wise key-value pairs
            self.past_window_k[layer_idx][:, :, window_pos : new_window_pos, :] = cached_window_k
            self.past_window_v[layer_idx][:, :, window_pos : new_window_pos, :] = cached_window_v
            self.past_window_pos.append(new_window_pos)
        else:
            # decoding stage
            # if the previous cache has full tokens,
            # roll back to the first elements
            if self.past_window_pos[layer_idx] == window_size:
                self.past_window_pos[layer_idx] = 0
                # clone: these buffer slots are about to be overwritten in place
                dump_k = self.past_window_k[layer_idx].clone()
                dump_v = self.past_window_v[layer_idx].clone()
            else:
                dump_k = None
                dump_v = None

            input_len = k.shape[-2]
            window_pos = self.past_window_pos[layer_idx]
            new_window_pos = window_pos + input_len

            self.past_window_k[layer_idx][:, :, window_pos : new_window_pos, :] = k
            self.past_window_v[layer_idx][:, :, window_pos : new_window_pos, :] = v

            s_k = self.past_window_k[layer_idx][:, :, : new_window_pos, :]
            s_v = self.past_window_v[layer_idx][:, :, : new_window_pos, :]

            self.past_window_pos[layer_idx] = new_window_pos

        return s_k, s_v, dump_k, dump_v
696
+
697
+ def update_chunk_rfas(
698
+ self,
699
+ rfa_k,
700
+ rfa_v,
701
+ layer_idx,
702
+ ):
703
+ if len(self.rfa_k) <= layer_idx:
704
+ # prefill stage
705
+ self.rfa_k.append(rfa_k)
706
+ self.rfa_v.append(rfa_v)
707
+ else:
708
+ # token decoding
709
+ past_rfa_k = self.rfa_k[layer_idx]
710
+ past_rfa_v = self.rfa_v[layer_idx]
711
+
712
+ if past_rfa_k is not None:
713
+ rfa_k = torch.cat([past_rfa_k, rfa_k], dim=-2)
714
+
715
+ if past_rfa_v is not None:
716
+ rfa_v = torch.cat([past_rfa_v, rfa_v], dim=-2)
717
+
718
+ self.rfa_k[layer_idx] = rfa_k
719
+ self.rfa_v[layer_idx] = rfa_v
720
+
721
+ return rfa_k, rfa_v
722
+
723
+ def get_past_window_pos(self, layer_idx):
724
+ if len(self.past_window_pos) <= layer_idx:
725
+ return None
726
+ else:
727
+ return self.past_window_pos[layer_idx]
728
+
729
+ def get_past_window_kv(self, layer_idx):
730
+ if len(self.past_window_pos) <= layer_idx:
731
+ return None, None
732
+ else:
733
+ return (
734
+ self.past_window_k[layer_idx][:, :, : self.past_window_pos[layer_idx], :],
735
+ self.past_window_v[layer_idx][:, :, : self.past_window_pos[layer_idx], :]
736
+ )
737
+
738
+ def get_chunk_rfas(self, layer_idx):
739
+ if len(self.rfa_k) <= layer_idx:
740
+ return None, None
741
+ else:
742
+ return self.rfa_k[layer_idx], self.rfa_v[layer_idx]
743
+
744
+ def get_seq_length(self, layer_idx = 0) -> int:
745
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
746
+ # layer_idx must be provided since otherwise
747
+ # any layer > 0 can only get the updated _seen_tokens
748
+ if len(self.past_window_pos) <= layer_idx:
749
+ return 0
750
+ return self._seen_tokens
751
+
752
+ def get_max_length(self) -> Optional[int]:
753
+ """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
754
+ return None
755
+
756
+ def update(
757
+ self,
758
+ layer_idx: int,
759
+ cache_kwargs: Optional[Dict[str, Any]] = None,
760
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
761
+ raise NotImplementedError("`update` is not used in Eva Cache.")
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/eva_prep_kv_kernel.py ADDED
@@ -0,0 +1,1017 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import math
3
+ import torch
4
+ import triton
5
+ import triton.language as tl
6
+
7
# Forward kernel: pools keys/values into per-chunk "RFA" summaries.
# Each program handles one (BLOCK_N positions) x (batch*head) tile; the BLOCK_N
# positions are viewed as [CHUNKS_PER_BLOCK, CHUNK_SIZE] so each chunk gets an
# independent softmax pooling:
#   Out_RFA_K[c] = softmax_w( k·PARAM_MU ) weighted sum of k over the chunk
#   Out_RFA_V[c] = softmax_w( scale * (k·PARAM_PHI - 0.5*||k||^2) ) weighted sum of v
# Positions where Mask is True (MASK_TYPE == 1) are excluded from both softmaxes.
@triton.heuristics(
    {
        "EVEN_N": lambda args: args["seqlen"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def _fwd_eva_prep_kv_kernel(
    K, # [b, h, n, d]
    V, # [b, h, n, d]
    PARAM_MU, # [1, h, 1, 1, d]
    PARAM_PHI, # [1, h, 1, 1, d]
    Mask, # [b, h, n, 1]
    Out_RFA_K, # [b, h, c, d]
    Out_RFA_V, # [b, h, c, d]
    softmax_scale,
    stride_kb, stride_kh, stride_kn,
    stride_vb, stride_vh, stride_vn,
    stride_mu_h,
    stride_phi_h,
    stride_mb, stride_mn,
    stride_ok_b, stride_ok_h, stride_ok_c,
    stride_ov_b, stride_ov_h, stride_ov_c,
    nheads,
    seqlen,
    nchunks,
    headdim,
    CHUNKS_PER_BLOCK: tl.constexpr,
    CHUNK_SIZE: tl.constexpr,
    MASK_TYPE: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # grid axis 0 walks blocks of positions; axis 1 is flattened (batch, head)
    start_n = tl.program_id(0)
    offs_bh = tl.program_id(1)
    offs_h = offs_bh % nheads
    offs_b = offs_bh // nheads
    # initialize offsets
    # we load BLOCK_N keys and values each time, and
    # reshape it to [CHUNKS_PER_BLOCK, CHUNK_SIZE]
    offs_c = tl.arange(0, CHUNKS_PER_BLOCK)
    offs_m = tl.arange(0, CHUNK_SIZE)
    offs_d = tl.arange(0, BLOCK_HEADDIM)

    # 3-D pointer grids: [chunk, position-in-chunk, head-dim]
    k_ptrs = (
        K +
        offs_b * stride_kb +
        offs_h * stride_kh +
        (
            (
                start_n * BLOCK_N +
                offs_c[:, None, None] * CHUNK_SIZE +
                offs_m[None, :, None]
            ) * stride_kn +
            offs_d[None, None, :]
        )
    )
    v_ptrs = (
        V +
        offs_b * stride_vb +
        offs_h * stride_vh +
        (
            (
                start_n * BLOCK_N +
                offs_c[:, None, None] * CHUNK_SIZE +
                offs_m[None, :, None]
            ) * stride_vn +
            offs_d[None, None, :]
        )
    )
    # per-head pooling parameters (broadcast over chunk/position dims)
    param_mu_ptrs = (
        PARAM_MU +
        offs_h * stride_mu_h +
        offs_d[None, None, :]
    )
    param_phi_ptrs = (
        PARAM_PHI +
        offs_h * stride_phi_h +
        offs_d[None, None, :]
    )
    # log2(e): logits are scaled so exp2 can be used instead of exp
    log2e = 1.4426950408889634
    if MASK_TYPE == 1:
        m_ptrs = (
            Mask +
            offs_b * stride_mb +
            (
                (
                    start_n * BLOCK_N +
                    offs_c[:, None] * CHUNK_SIZE +
                    offs_m[None, :]
                ) * stride_mn
            )
        )
    # load keys; EVEN_N / EVEN_HEADDIM select the cheapest masking variant
    if EVEN_N:
        if EVEN_HEADDIM:
            k = tl.load(
                k_ptrs
            )
        else:
            k = tl.load(
                k_ptrs,
                mask=offs_d[None, None, :] < headdim,
                other=0.0
            )
    else:
        if EVEN_HEADDIM:
            k = tl.load(
                k_ptrs,
                mask=(
                    start_n * BLOCK_N +
                    offs_c[:, None, None] * CHUNK_SIZE +
                    offs_m[None, :, None]
                ) < seqlen,
                other=0.0
            )
        else:
            k = tl.load(
                k_ptrs,
                mask=(
                    (
                        start_n * BLOCK_N +
                        offs_c[:, None, None] * CHUNK_SIZE +
                        offs_m[None, :, None]
                    ) < seqlen
                ) & (offs_d[None, None, :] < headdim),
                other=0.0
            )

    # ---- RFA-K: softmax(k . mu) pooling over each chunk ----
    param_mu = tl.load(param_mu_ptrs).to(k.dtype)
    rfa_k_c_w = tl.zeros([CHUNKS_PER_BLOCK, CHUNK_SIZE], dtype=tl.float32)
    rfa_k_c_w += tl.sum(k * param_mu, axis=-1)
    rfa_k_c_w *= log2e
    if MASK_TYPE == 1:
        if EVEN_N:
            mask = tl.load(
                m_ptrs
            )
        else:
            mask = tl.load(
                m_ptrs,
                mask=(
                    start_n * BLOCK_N +
                    offs_c[:, None] * CHUNK_SIZE +
                    offs_m[None, :]
                ) < seqlen,
                other=1,
            )
        # mask == True means "padding": excluded from the softmax
        rfa_k_c_w = tl.where(mask, float("-inf"), rfa_k_c_w)

    # numerically stable softmax; fully-masked rows get denom forced to 1
    # so they produce zeros instead of NaNs
    m_rfa_k_c_w = tl.max(rfa_k_c_w, axis=-1)
    masked_out_rows_rfa_k = (m_rfa_k_c_w == float("-inf"))
    m_rfa_k_c_w_masked = tl.where(masked_out_rows_rfa_k, 0, m_rfa_k_c_w)
    rfa_k_c_w = tl.exp2(rfa_k_c_w - m_rfa_k_c_w_masked[:, None])
    denom_k = tl.sum(rfa_k_c_w, axis=-1)
    denom_k = tl.where(denom_k == 0.0, 1.0, denom_k)
    rfa_k_c_w = rfa_k_c_w / denom_k[:, None]
    rfa_k_c = tl.sum(k * rfa_k_c_w[:, :, None].to(k.dtype), axis=-2)
    # TODO: understand why rematerialize offsets to save registers?
    offs_out_c = start_n * CHUNKS_PER_BLOCK + tl.arange(0, CHUNKS_PER_BLOCK)
    out_rfa_k_ptrs = (
        Out_RFA_K +
        offs_b * stride_ok_b +
        offs_h * stride_ok_h +
        (offs_out_c[:, None] * stride_ok_c + offs_d[None, :])
    )

    if EVEN_N:
        if EVEN_HEADDIM:
            tl.store(
                out_rfa_k_ptrs, rfa_k_c
            )
        else:
            tl.store(
                out_rfa_k_ptrs, rfa_k_c,
                mask=offs_d[None, :] < headdim
            )
    else:
        if EVEN_HEADDIM:
            tl.store(
                out_rfa_k_ptrs, rfa_k_c,
                mask=offs_out_c[:, None] < nchunks
            )
        else:
            tl.store(
                out_rfa_k_ptrs, rfa_k_c,
                mask=(offs_out_c[:, None] < nchunks) & (offs_d[None, :] < headdim)
            )

    # ---- RFA-V: softmax(scale * (k . phi - 0.5*||k||^2)) pooling of v ----
    param_phi = tl.load(param_phi_ptrs).to(k.dtype)
    rfa_v_c_w = tl.zeros([CHUNKS_PER_BLOCK, CHUNK_SIZE], dtype=tl.float32)
    rfa_v_c_w += tl.sum(k * param_phi, axis=-1)
    rfa_v_c_w -= (0.5 * tl.sum(k * k, axis=-1))
    rfa_v_c_w *= log2e * softmax_scale
    if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
        rfa_v_c_w += tl.where(
            (
                start_n * BLOCK_N +
                offs_c[:, None] * CHUNK_SIZE +
                offs_m[None, :]
            ) < seqlen,
            0,
            float("-inf")
        )

    if MASK_TYPE == 1:
        rfa_v_c_w = tl.where(mask, float("-inf"), rfa_v_c_w)

    if EVEN_N:
        if EVEN_HEADDIM:
            v = tl.load(
                v_ptrs
            )
        else:
            v = tl.load(
                v_ptrs,
                mask=offs_d[None, None, :] < headdim,
                other=0.0
            )
    else:
        if EVEN_HEADDIM:
            v = tl.load(
                v_ptrs,
                mask=(
                    start_n * BLOCK_N +
                    offs_c[:, None, None] * CHUNK_SIZE +
                    offs_m[None, :, None]
                ) < seqlen,
                other=0.0
            )
        else:
            v = tl.load(
                v_ptrs,
                mask=(
                    (
                        start_n * BLOCK_N +
                        offs_c[:, None, None] * CHUNK_SIZE +
                        offs_m[None, :, None]
                    ) < seqlen
                ) & (offs_d[None, None, :] < headdim),
                other=0.0
            )

    # same stable-softmax treatment as the K branch
    m_rfa_v_c_w = tl.max(rfa_v_c_w, axis=-1)
    masked_out_rows_rfa_v = (m_rfa_v_c_w == float("-inf"))
    m_rfa_v_c_w_masked = tl.where(masked_out_rows_rfa_v, 0, m_rfa_v_c_w)
    rfa_v_c_w = tl.exp2(rfa_v_c_w - m_rfa_v_c_w_masked[:, None])
    denom_v = tl.sum(rfa_v_c_w, axis=-1)
    denom_v = tl.where(denom_v == 0.0, 1.0, denom_v)
    rfa_v_c_w = rfa_v_c_w / denom_v[:, None]
    rfa_v_c = tl.sum(v * rfa_v_c_w[:, :, None].to(v.dtype), axis=-2)

    offs_out_c = start_n * CHUNKS_PER_BLOCK + tl.arange(0, CHUNKS_PER_BLOCK)
    out_rfa_v_ptrs = (
        Out_RFA_V +
        offs_b * stride_ov_b +
        offs_h * stride_ov_h +
        (offs_out_c[:, None] * stride_ov_c + offs_d[None, :])
    )
    if EVEN_N:
        if EVEN_HEADDIM:
            tl.store(
                out_rfa_v_ptrs, rfa_v_c
            )
        else:
            tl.store(
                out_rfa_v_ptrs, rfa_v_c,
                mask=offs_d[None, :] < headdim
            )
    else:
        if EVEN_HEADDIM:
            tl.store(
                out_rfa_v_ptrs, rfa_v_c,
                mask=offs_out_c[:, None] < nchunks
            )
        else:
            tl.store(
                out_rfa_v_ptrs, rfa_v_c,
                mask=(offs_out_c[:, None] < nchunks) & (offs_d[None, :] < headdim)
            )
290
+
291
+
292
+
293
# Backward kernel for the RFA KV-prep forward above. Recomputes the per-chunk
# softmax weights (mu for keys, phi for values) from K rather than saving them,
# then backpropagates through "weighted-sum of a softmax":
#   d_w   = (d_out . x  -  d_out . out) * w          (softmax Jacobian)
#   d_x   = w * d_out + d_w * d(logit)/d_x
# Per-parameter grads are accumulated per grid block into the *_PARTIAL buffers
# ([b, h, g, d]); the host-side wrapper reduces them over batch and groups.
@triton.heuristics(
    {
        "EVEN_N": lambda args: args["seqlen"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def _bwd_eva_prep_kv_kernel(
    RFA_K, # [b, h, c, d]
    RFA_V, # [b, h, c, d]
    K, # [b, h, n, d]
    V, # [b, h, n, d]
    PARAM_MU, # [1, h, 1, 1, d]
    PARAM_PHI, # [1, h, 1, 1, d]
    Mask, # [b, h, n, 1]
    D_RFA_K, # [b, h, c, d]
    D_RFA_V, # [b, h, c, d]
    D_K, # [b, h, n, d]
    D_V, # [b, h, n, d]
    D_PARAM_MU_PARTIAL, # [b, h, g, d]
    D_PARAM_PHI_PARTIAL, # [b, h, g, d]
    softmax_scale,
    stride_rfa_k_b, stride_rfa_k_h, stride_rfa_k_c,
    stride_rfa_v_b, stride_rfa_v_h, stride_rfa_v_c,
    stride_kb, stride_kh, stride_kn,
    stride_vb, stride_vh, stride_vn,
    stride_mu_h,
    stride_phi_h,
    stride_mb, stride_mn,
    stride_d_rfa_k_b, stride_d_rfa_k_h, stride_d_rfa_k_c,
    stride_d_rfa_v_b, stride_d_rfa_v_h, stride_d_rfa_v_c,
    stride_d_k_b, stride_d_k_h, stride_d_k_n,
    stride_d_v_b, stride_d_v_h, stride_d_v_n,
    stride_d_mu_b, stride_d_mu_h, stride_d_mu_g,
    stride_d_phi_b, stride_d_phi_h, stride_d_phi_g,
    nheads,
    seqlen,
    nchunks,
    headdim,
    CHUNKS_PER_BLOCK: tl.constexpr,
    CHUNK_SIZE: tl.constexpr,
    MASK_TYPE: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    start_n = tl.program_id(0)
    offs_bh = tl.program_id(1)
    offs_h = offs_bh % nheads
    offs_b = offs_bh // nheads
    # initialize offsets
    # we load BLOCK_N keys and values each time, and
    # reshape it to [CHUNKS_PER_BLOCK, CHUNK_SIZE]
    offs_c = tl.arange(0, CHUNKS_PER_BLOCK)
    offs_m = tl.arange(0, CHUNK_SIZE)
    offs_d = tl.arange(0, BLOCK_HEADDIM)

    # chunk indices in the global [0, nchunks) space handled by this program
    offs_rfa_c = start_n * CHUNKS_PER_BLOCK + offs_c

    k_ptrs = (
        K +
        offs_b * stride_kb +
        offs_h * stride_kh +
        (
            (
                start_n * BLOCK_N +
                offs_c[:, None, None] * CHUNK_SIZE +
                offs_m[None, :, None]
            ) * stride_kn +
            offs_d[None, None, :]
        )
    )
    rfa_k_ptrs = (
        RFA_K +
        offs_b * stride_rfa_k_b +
        offs_h * stride_rfa_k_h +
        (offs_rfa_c[:, None] * stride_rfa_k_c + offs_d[None, :])
    )
    rfa_v_ptrs = (
        RFA_V +
        offs_b * stride_rfa_v_b +
        offs_h * stride_rfa_v_h +
        (offs_rfa_c[:, None] * stride_rfa_v_c + offs_d[None, :])
    )

    d_rfa_k_ptrs = (
        D_RFA_K +
        offs_b * stride_d_rfa_k_b +
        offs_h * stride_d_rfa_k_h +
        (offs_rfa_c[:, None] * stride_d_rfa_k_c + offs_d[None, :])
    )
    d_rfa_v_ptrs = (
        D_RFA_V +
        offs_b * stride_d_rfa_v_b +
        offs_h * stride_d_rfa_v_h +
        (offs_rfa_c[:, None] * stride_d_rfa_v_c + offs_d[None, :])
    )

    param_mu_ptrs = (
        PARAM_MU +
        offs_h * stride_mu_h +
        offs_d[None, None, :]
    )
    param_phi_ptrs = (
        PARAM_PHI +
        offs_h * stride_phi_h +
        offs_d[None, None, :]
    )

    log2e = 1.4426950408889634
    if MASK_TYPE == 1:
        m_ptrs = (
            Mask +
            offs_b * stride_mb +
            (
                (
                    start_n * BLOCK_N +
                    offs_c[:, None] * CHUNK_SIZE +
                    offs_m[None, :]
                ) * stride_mn
            )
        )
    # load K (same 4-way masking pattern as the forward kernel)
    if EVEN_N:
        if EVEN_HEADDIM:
            k = tl.load(
                k_ptrs
            )
        else:
            k = tl.load(
                k_ptrs,
                mask=offs_d[None, None, :] < headdim,
                other=0.0
            )
    else:
        if EVEN_HEADDIM:
            k = tl.load(
                k_ptrs,
                mask=(
                    start_n * BLOCK_N +
                    offs_c[:, None, None] * CHUNK_SIZE +
                    offs_m[None, :, None]
                ) < seqlen,
                other=0.0
            )
        else:
            k = tl.load(
                k_ptrs,
                mask=(
                    (
                        start_n * BLOCK_N +
                        offs_c[:, None, None] * CHUNK_SIZE +
                        offs_m[None, :, None]
                    ) < seqlen
                ) & (offs_d[None, None, :] < headdim),
                other=0.0
            )

    # forward outputs and incoming grads for the K side
    if EVEN_N:
        if EVEN_HEADDIM:
            rfa_k = tl.load(
                rfa_k_ptrs
            )
        else:
            rfa_k = tl.load(
                rfa_k_ptrs,
                mask=offs_d[None, :] < headdim,
                other=0.0
            )
    else:
        if EVEN_HEADDIM:
            rfa_k = tl.load(
                rfa_k_ptrs,
                mask=offs_rfa_c[:, None] < nchunks,
                other=0.0
            )
        else:
            rfa_k = tl.load(
                rfa_k_ptrs,
                mask=(offs_rfa_c[:, None] < nchunks) & (offs_d[None, :] < headdim),
                other=0.0
            )

    if EVEN_N:
        if EVEN_HEADDIM:
            d_rfa_k = tl.load(
                d_rfa_k_ptrs
            )
        else:
            d_rfa_k = tl.load(
                d_rfa_k_ptrs,
                mask=offs_d[None, :] < headdim,
                other=0.0
            )
    else:
        if EVEN_HEADDIM:
            d_rfa_k = tl.load(
                d_rfa_k_ptrs,
                mask=offs_rfa_c[:, None] < nchunks,
                other=0.0
            )
        else:
            d_rfa_k = tl.load(
                d_rfa_k_ptrs,
                mask=(offs_rfa_c[:, None] < nchunks) & (offs_d[None, :] < headdim),
                other=0.0
            )

    # recompute the mu-softmax weights exactly as the forward did
    param_mu = tl.load(param_mu_ptrs).to(k.dtype)
    mu_c_w = tl.zeros([CHUNKS_PER_BLOCK, CHUNK_SIZE], dtype=tl.float32)
    mu_c_w += tl.sum(k * param_mu, axis=-1)
    mu_c_w *= log2e

    if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
        mu_c_w += tl.where(
            (
                start_n * BLOCK_N +
                offs_c[:, None] * CHUNK_SIZE +
                offs_m[None, :]
            ) < seqlen,
            0,
            float("-inf")
        )

    if MASK_TYPE == 1:
        if EVEN_N:
            mask = tl.load(
                m_ptrs
            )
        else:
            mask = tl.load(
                m_ptrs,
                mask=(
                    start_n * BLOCK_N +
                    offs_c[:, None] * CHUNK_SIZE +
                    offs_m[None, :]
                ) < seqlen,
                other=1,
            )
        mu_c_w = tl.where(mask, float("-inf"), mu_c_w)

    # [c, w]
    m_mu_c_w = tl.max(mu_c_w, axis=-1)
    masked_out_rows_mu = (m_mu_c_w == float("-inf"))
    m_mu_c_w_masked = tl.where(masked_out_rows_mu, 0, m_mu_c_w)
    mu_c_w = tl.exp2(mu_c_w - m_mu_c_w_masked[:, None])
    denom_mu = tl.sum(mu_c_w, axis=-1)
    denom_mu = tl.where(denom_mu == 0.0, 1.0, denom_mu)
    mu_tilde_c_w = mu_c_w / denom_mu[:, None]
    mu_tilde_c_w = mu_tilde_c_w.to(k.dtype)
    # [c, d] [c, w, d] -> [c, w]
    d_mu_tilde_c_w = tl.sum(d_rfa_k[:, None, :] * k, axis=-1)
    # [c, d] [c, d] -> [c]
    d_out_rfa_k_t_rfa_k = tl.sum(d_rfa_k * rfa_k, axis=-1)[:, None]
    # softmax backward: (d_softmax_in - <d_out, out>) * softmax_out
    d_mu_c_w = (d_mu_tilde_c_w - d_out_rfa_k_t_rfa_k) * mu_tilde_c_w

    # [c, w] [c, w, d] -> [d]
    d_param_mu = tl.sum(tl.sum(d_mu_c_w[:, :, None] * k, axis=0), axis=0)
    # [c, w] [c, d] + [c, w] [1, 1, d] -> [c, w, d]
    d_k = mu_tilde_c_w[:, :, None] * d_rfa_k[:, None, :] + d_mu_c_w[:, :, None] * param_mu

    # per-block partial grad of mu; reduced over (batch, group) on the host
    d_param_mu_partial_ptrs = (
        D_PARAM_MU_PARTIAL +
        offs_b * stride_d_mu_b +
        offs_h * stride_d_mu_h +
        start_n * stride_d_mu_g +
        offs_d
    )
    if EVEN_HEADDIM:
        tl.store(
            d_param_mu_partial_ptrs, d_param_mu
        )
    else:
        tl.store(
            d_param_mu_partial_ptrs, d_param_mu,
            mask=offs_d < headdim
        )

    v_ptrs = (
        V +
        offs_b * stride_vb +
        offs_h * stride_vh +
        (
            (
                start_n * BLOCK_N +
                offs_c[:, None, None] * CHUNK_SIZE +
                offs_m[None, :, None]
            ) * stride_vn +
            offs_d[None, None, :]
        )
    )
    if EVEN_N:
        if EVEN_HEADDIM:
            v = tl.load(
                v_ptrs
            )
        else:
            v = tl.load(
                v_ptrs,
                mask=offs_d[None, None, :] < headdim,
                other=0.0
            )
    else:
        if EVEN_HEADDIM:
            v = tl.load(
                v_ptrs,
                mask=(
                    start_n * BLOCK_N +
                    offs_c[:, None, None] * CHUNK_SIZE +
                    offs_m[None, :, None]
                ) < seqlen,
                other=0.0
            )
        else:
            v = tl.load(
                v_ptrs,
                mask=(
                    (
                        start_n * BLOCK_N +
                        offs_c[:, None, None] * CHUNK_SIZE +
                        offs_m[None, :, None]
                    ) < seqlen
                ) & (offs_d[None, None, :] < headdim),
                other=0.0
            )

    # forward outputs and incoming grads for the V side
    if EVEN_N:
        if EVEN_HEADDIM:
            rfa_v = tl.load(
                rfa_v_ptrs
            )
        else:
            rfa_v = tl.load(
                rfa_v_ptrs,
                mask=offs_d[None, :] < headdim,
                other=0.0
            )
    else:
        if EVEN_HEADDIM:
            rfa_v = tl.load(
                rfa_v_ptrs,
                mask=offs_rfa_c[:, None] < nchunks,
                other=0.0
            )
        else:
            rfa_v = tl.load(
                rfa_v_ptrs,
                mask=(offs_rfa_c[:, None] < nchunks) & (offs_d[None, :] < headdim),
                other=0.0
            )

    if EVEN_N:
        if EVEN_HEADDIM:
            d_rfa_v = tl.load(
                d_rfa_v_ptrs
            )
        else:
            d_rfa_v = tl.load(
                d_rfa_v_ptrs,
                mask=offs_d[None, :] < headdim,
                other=0.0
            )
    else:
        if EVEN_HEADDIM:
            d_rfa_v = tl.load(
                d_rfa_v_ptrs,
                mask=offs_rfa_c[:, None] < nchunks,
                other=0.0
            )
        else:
            d_rfa_v = tl.load(
                d_rfa_v_ptrs,
                mask=(offs_rfa_c[:, None] < nchunks) & (offs_d[None, :] < headdim),
                other=0.0
            )

    # recompute phi-softmax weights: scale * (k . phi - 0.5 * ||k||^2)
    param_phi = tl.load(param_phi_ptrs).to(k.dtype)
    phi_c_w = tl.zeros([CHUNKS_PER_BLOCK, CHUNK_SIZE], dtype=tl.float32)
    phi_c_w += tl.sum(k * param_phi, axis=-1)
    phi_c_w -= (0.5 * tl.sum(k * k, axis=-1))
    phi_c_w *= log2e * softmax_scale
    if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
        phi_c_w += tl.where(
            (
                start_n * BLOCK_N +
                offs_c[:, None] * CHUNK_SIZE +
                offs_m[None, :]
            ) < seqlen,
            0,
            float("-inf")
        )

    if MASK_TYPE == 1:
        phi_c_w = tl.where(mask, float("-inf"), phi_c_w)

    m_phi_c_w = tl.max(phi_c_w, axis=-1)
    masked_out_rows_phi = (m_phi_c_w == float("-inf"))
    m_phi_c_w_masked = tl.where(masked_out_rows_phi, 0, m_phi_c_w)
    phi_c_w = tl.exp2(phi_c_w - m_phi_c_w_masked[:, None])
    denom_phi = tl.sum(phi_c_w, axis=-1)
    denom_phi = tl.where(denom_phi == 0.0, 1.0, denom_phi)
    phi_tilde_c_w = phi_c_w / denom_phi[:, None]
    # phi_c_w = tl.exp2(phi_c_w - tl.max(phi_c_w, axis=-1)[:, None])
    # phi_tilde_c_w = phi_c_w / tl.sum(phi_c_w, axis=-1)[:, None]
    phi_tilde_c_w = phi_tilde_c_w.to(k.dtype)
    d_phi_tilde_c_w = tl.sum(d_rfa_v[:, None, :] * v, axis=-1)
    d_out_rfa_v_t_rfa_v = tl.sum(d_rfa_v * rfa_v, axis=-1)[:, None]
    d_phi_c_w = (d_phi_tilde_c_w.to(tl.float32) - d_out_rfa_v_t_rfa_v.to(tl.float32)) * phi_tilde_c_w

    d_param_phi = tl.sum(tl.sum(d_phi_c_w[:, :, None] * k * softmax_scale, axis=0), axis=0)
    d_v = phi_tilde_c_w[:, :, None] * d_rfa_v[:, None, :]
    # [c, w, d] + [c, w] * [1, 1, d] - [c, w, d]
    # d(logit)/d_k = scale * (phi - k), accumulated on top of the mu-side grad
    d_k = d_k + softmax_scale * d_phi_c_w[:, :, None] * (param_phi - k)

    d_k_ptrs = (
        D_K +
        offs_b * stride_d_k_b +
        offs_h * stride_d_k_h +
        (
            (
                start_n * BLOCK_N +
                offs_c[:, None, None] * CHUNK_SIZE +
                offs_m[None, :, None]
            ) * stride_d_k_n +
            offs_d[None, None, :]
        )
    )
    d_v_ptrs = (
        D_V +
        offs_b * stride_d_v_b +
        offs_h * stride_d_v_h +
        (
            (
                start_n * BLOCK_N +
                offs_c[:, None, None] * CHUNK_SIZE +
                offs_m[None, :, None]
            ) * stride_d_v_n +
            offs_d[None, None, :]
        )
    )
    if EVEN_N:
        if EVEN_HEADDIM:
            tl.store(
                d_k_ptrs, d_k
            )
            tl.store(
                d_v_ptrs, d_v
            )
        else:
            tl.store(
                d_k_ptrs, d_k,
                mask=offs_d[None, None, :] < headdim
            )
            tl.store(
                d_v_ptrs, d_v,
                mask=offs_d[None, None, :] < headdim
            )
    else:
        if EVEN_HEADDIM:
            tl.store(
                d_k_ptrs, d_k,
                mask=(
                    (
                        start_n * BLOCK_N +
                        offs_c[:, None, None] * CHUNK_SIZE +
                        offs_m[None, :, None]
                    ) < seqlen
                ),
            )
            tl.store(
                d_v_ptrs, d_v,
                mask=(
                    (
                        start_n * BLOCK_N +
                        offs_c[:, None, None] * CHUNK_SIZE +
                        offs_m[None, :, None]
                    ) < seqlen
                ),
            )
        else:
            tl.store(
                d_k_ptrs, d_k,
                mask=(
                    (
                        start_n * BLOCK_N +
                        offs_c[:, None, None] * CHUNK_SIZE +
                        offs_m[None, :, None]
                    ) < seqlen
                ) & (offs_d[None, None, :] < headdim),
            )
            tl.store(
                d_v_ptrs, d_v,
                mask=(
                    (
                        start_n * BLOCK_N +
                        offs_c[:, None, None] * CHUNK_SIZE +
                        offs_m[None, :, None]
                    ) < seqlen
                ) & (offs_d[None, None, :] < headdim),
            )
    d_param_phi_partial_ptrs = (
        D_PARAM_PHI_PARTIAL +
        offs_b * stride_d_phi_b +
        offs_h * stride_d_phi_h +
        start_n * stride_d_phi_g +
        offs_d
    )
    if EVEN_HEADDIM:
        tl.store(
            d_param_phi_partial_ptrs, d_param_phi
        )
    else:
        tl.store(
            d_param_phi_partial_ptrs, d_param_phi,
            mask=offs_d < headdim
        )
812
+
813
def triton_eva_prep_kv_fwd(k, v, param_mu, param_phi, mask, softmax_scale, chunksize):
    """Launch the forward RFA KV-prep kernel.

    Pools keys/values into per-chunk summaries: within each contiguous chunk of
    `chunksize` positions, keys are combined with softmax weights from
    ``k @ param_mu`` and values with softmax weights from
    ``softmax_scale * (k @ param_phi - 0.5 * ||k||^2)``.

    Args:
        k, v: ``[batch, nheads, seqlen, head_dim]`` CUDA tensors (bf16 or fp32).
        param_mu, param_phi: ``[1, nheads, 1, 1, head_dim]`` per-head parameters.
        mask: optional ``[batch, 1, seqlen, 1]`` boolean mask; True marks
            positions excluded from the pooling softmaxes.
        softmax_scale: value-side logit scale; defaults to ``1/sqrt(head_dim)``
            when falsy.
        chunksize: positions pooled per summary; must divide seqlen and BLOCK.

    Returns:
        ``(out_rfa_k, out_rfa_v)``, each ``[batch, nheads, seqlen // chunksize, head_dim]``.
    """
    k, v, param_mu, param_phi = [
        x if x.stride(-1) == 1 else x.contiguous()
        for x in [k, v, param_mu, param_phi]
    ]

    # shape constraints
    batch, nheads, seqlen, head_dim = k.shape
    assert seqlen % chunksize == 0, "seqlen must be divisible by chunksize"
    nchunks = seqlen // chunksize
    assert k.shape == (batch, nheads, seqlen, head_dim)
    assert v.shape == (batch, nheads, seqlen, head_dim)
    assert param_mu.shape == (1, nheads, 1, 1, head_dim)
    assert param_phi.shape == (1, nheads, 1, 1, head_dim)
    assert head_dim <= 128, "We only test head dimensions up to 128"
    assert k.dtype == v.dtype == param_mu.dtype == param_phi.dtype, "All tensors must have the same type"
    assert k.dtype in [torch.bfloat16, torch.float], "Only support bf16 and fp32 for now"
    assert k.is_cuda and v.is_cuda
    softmax_scale = softmax_scale or 1.0 / math.sqrt(head_dim)

    mask_type = 0
    if mask is not None:
        mask_type = 1
        assert mask.dtype == torch.bool
        assert mask.is_cuda
        assert mask.dim() == 4
        assert mask.shape == (batch, 1, seqlen, 1)
        if mask.stride(-1) != 1:
            mask = mask.contiguous()
    mask_strides = (
        (mask.stride(0), mask.stride(2))
        if mask_type == 1 else
        (0, 0)
    )
    out_rfa_k = torch.empty((batch, nheads, nchunks, head_dim), dtype=k.dtype, device=k.device)
    out_rfa_v = torch.empty((batch, nheads, nchunks, head_dim), dtype=v.dtype, device=v.device)

    BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16)
    BLOCK = 128
    num_warps = 4 if head_dim <= 64 else 8

    # FIX: the original `(BLOCK > chunksize) & (BLOCK % chunksize) == 0` parsed
    # as `((BLOCK > chunksize) & (BLOCK % chunksize)) == 0` because `&` binds
    # tighter than `==`; that expression is vacuously True whenever
    # BLOCK <= chunksize (False & x == 0), so the guard never fired.
    assert BLOCK > chunksize and BLOCK % chunksize == 0, "BLOCK must be divisible by chunksize"
    chunks_per_block = BLOCK // chunksize

    grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_N"]), batch * nheads)
    _fwd_eva_prep_kv_kernel[grid](
        k,
        v,
        param_mu,
        param_phi,
        mask,
        out_rfa_k,
        out_rfa_v,
        softmax_scale,
        k.stride(0), k.stride(1), k.stride(2),
        v.stride(0), v.stride(1), v.stride(2),
        param_mu.stride(1),
        param_phi.stride(1),
        mask_strides[0], mask_strides[1],
        out_rfa_k.stride(0), out_rfa_k.stride(1), out_rfa_k.stride(2),
        out_rfa_v.stride(0), out_rfa_v.stride(1), out_rfa_v.stride(2),
        nheads,
        seqlen,
        nchunks,
        head_dim,
        chunks_per_block,
        chunksize,
        mask_type,
        BLOCK_HEADDIM,
        BLOCK_N=BLOCK,
        num_warps=num_warps,
        num_stages=1,
    )
    return out_rfa_k, out_rfa_v
887
+
888
def triton_eva_prep_kv_bwd(
    d_rfa_k, d_rfa_v,
    k, v, param_mu, param_phi,
    mask,
    rfa_k, rfa_v,
    d_k, d_v, d_param_mu, d_param_phi,
    softmax_scale,
    mask_type,
    chunksize
):
    """Launch the backward RFA KV-prep kernel and reduce parameter gradients.

    Writes gradients in place: `d_k`/`d_v` (same shapes as `k`/`v`) via the
    kernel, and `d_param_mu`/`d_param_phi` via `copy_` after summing the
    per-(batch, block-group) partial buffers produced by the kernel.

    Args:
        d_rfa_k, d_rfa_v: upstream grads, ``[batch, nheads, nchunks, head_dim]``.
        k, v, param_mu, param_phi, mask: forward inputs (see the fwd wrapper).
        rfa_k, rfa_v: forward outputs (needed for the softmax backward).
        d_k, d_v, d_param_mu, d_param_phi: preallocated output gradients.
        softmax_scale: value-side logit scale; defaults to ``1/sqrt(head_dim)``
            when falsy.
        mask_type: 1 when `mask` is present, else 0.
        chunksize: positions pooled per summary; must divide seqlen and BLOCK.
    """
    d_rfa_k, d_rfa_v = [
        x if x.stride(-1) == 1 else x.contiguous()
        for x in [d_rfa_k, d_rfa_v]
    ]

    # shape constraints
    batch, nheads, seqlen, head_dim = k.shape
    assert seqlen % chunksize == 0, "seqlen must be divisible by chunksize"
    nchunks = seqlen // chunksize
    softmax_scale = softmax_scale or 1.0 / math.sqrt(head_dim)

    mask_strides = (
        (mask.stride(0), mask.stride(2))
        if mask_type == 1 else
        (0, 0)
    )

    BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16)
    BLOCK = 128
    num_warps = 4 if head_dim <= 64 else 8

    # FIX: the original `(BLOCK > chunksize) & (BLOCK % chunksize) == 0` parsed
    # as `((BLOCK > chunksize) & (BLOCK % chunksize)) == 0` because `&` binds
    # tighter than `==`; that expression is vacuously True whenever
    # BLOCK <= chunksize (False & x == 0), so the guard never fired.
    assert BLOCK > chunksize and BLOCK % chunksize == 0, "BLOCK must be divisible by chunksize"
    chunks_per_block = BLOCK // chunksize

    # per-block partial parameter grads, reduced over (batch, groups) below;
    # accumulated in fp32 for accuracy regardless of the model dtype
    partial_groups = triton.cdiv(seqlen, BLOCK)
    d_param_mu_partial = torch.zeros((batch, nheads, partial_groups, head_dim), dtype=torch.float32, device=d_rfa_k.device)
    d_param_phi_partial = torch.zeros((batch, nheads, partial_groups, head_dim), dtype=torch.float32, device=d_rfa_k.device)
    grid = lambda META: (partial_groups, batch * nheads)
    _bwd_eva_prep_kv_kernel[grid](
        rfa_k, # [b, h, c, d]
        rfa_v, # [b, h, c, d]
        k, # [b, h, n, d]
        v, # [b, h, n, d]
        param_mu, # [1, h, 1, 1, d]
        param_phi, # [1, h, 1, 1, d]
        mask, # [b, h, n, 1]
        d_rfa_k, # [b, h, c, d]
        d_rfa_v, # [b, h, c, d]
        d_k, # [b, h, n, d]
        d_v, # [b, h, n, d]
        d_param_mu_partial, # [b, h, g, d]
        d_param_phi_partial, # [b, h, g, d]
        softmax_scale,
        rfa_k.stride(0), rfa_k.stride(1), rfa_k.stride(2),
        rfa_v.stride(0), rfa_v.stride(1), rfa_v.stride(2),
        k.stride(0), k.stride(1), k.stride(2),
        v.stride(0), v.stride(1), v.stride(2),
        param_mu.stride(1),
        param_phi.stride(1),
        mask_strides[0], mask_strides[1],
        d_rfa_k.stride(0), d_rfa_k.stride(1), d_rfa_k.stride(2),
        d_rfa_v.stride(0), d_rfa_v.stride(1), d_rfa_v.stride(2),
        d_k.stride(0), d_k.stride(1), d_k.stride(2),
        d_v.stride(0), d_v.stride(1), d_v.stride(2),
        d_param_mu_partial.stride(0), d_param_mu_partial.stride(1), d_param_mu_partial.stride(2),
        d_param_phi_partial.stride(0), d_param_phi_partial.stride(1), d_param_phi_partial.stride(2),
        nheads,
        seqlen,
        nchunks,
        head_dim,
        chunks_per_block,
        chunksize,
        mask_type,
        BLOCK_HEADDIM,
        BLOCK_N=BLOCK,
        num_warps=num_warps,
        num_stages=1,
    )
    # [b, h, g, d] -> sum over batch and groups -> [1, h, 1, d] -> [1, h, 1, 1, d]
    d_param_mu.copy_(d_param_mu_partial.sum(dim=(0, -2), keepdim=True).unsqueeze(-2).to(d_param_mu.dtype))
    d_param_phi.copy_(d_param_phi_partial.sum(dim=(0, -2), keepdim=True).unsqueeze(-2).to(d_param_phi.dtype))
968
+
969
+
970
+
971
class EvaPrepKVFunc(torch.autograd.Function):
    """Autograd bridge for the Triton RFA KV-prep kernels.

    Forward pools (k, v) into per-chunk summaries (rfa_k, rfa_v) via
    `triton_eva_prep_kv_fwd`; backward recomputes the pooling softmaxes and
    returns gradients for k, v, param_mu, and param_phi (None for mask,
    softmax_scale, and chunksize, which are non-differentiable).
    """
    @staticmethod
    def forward(ctx, k, v, param_mu, param_phi, mask, softmax_scale=None, chunksize=None):
        # mask_type mirrors the wrappers' convention: 1 = boolean mask present
        if mask is not None:
            mask_type = 1
        else:
            mask_type = 0
        rfa_k, rfa_v = triton_eva_prep_kv_fwd(
            k, v, param_mu, param_phi, mask, softmax_scale, chunksize
        )
        # outputs are saved too: the softmax backward needs rfa_k/rfa_v
        ctx.save_for_backward(k, v, param_mu, param_phi, mask, rfa_k, rfa_v)
        ctx.softmax_scale = softmax_scale
        ctx.chunksize = chunksize
        ctx.mask_type = mask_type
        return rfa_k, rfa_v

    @staticmethod
    def backward(ctx, d_rfa_k, d_rfa_v):
        k, v, param_mu, param_phi, mask, rfa_k, rfa_v = ctx.saved_tensors
        # preallocated buffers filled in place by the bwd wrapper
        d_k = torch.empty_like(k)
        d_v = torch.empty_like(v)
        d_param_mu = torch.empty_like(param_mu)
        d_param_phi = torch.empty_like(param_phi)
        triton_eva_prep_kv_bwd(
            d_rfa_k, d_rfa_v,
            k, v, param_mu, param_phi,
            mask,
            rfa_k, rfa_v,
            d_k, d_v, d_param_mu, d_param_phi,
            ctx.softmax_scale,
            ctx.mask_type,
            ctx.chunksize
        )
        # one grad slot per forward argument; mask/softmax_scale/chunksize get None
        return d_k, d_v, d_param_mu, d_param_phi, None, None, None
1005
+
1006
def eva_prep_kv_func_triton(
    k, v,
    param_mu, param_phi,
    mask,
    softmax_scale=None, chunksize=None
):
    """Functional entry point for RFA KV preparation; delegates to EvaPrepKVFunc."""
    return EvaPrepKVFunc.apply(k, v, param_mu, param_phi, mask, softmax_scale, chunksize)
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/eva_pt_ref.py ADDED
@@ -0,0 +1,420 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple, Union
2
+ import torch
3
+ from torch import nn
4
+
5
+ MASK_MIN_VALUE = -10e10
6
+
7
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """
    Rotates half the hidden dims (last dim) of the input.
    Args:
        x: Rotary embedded tensor
    Return:
        Tensor whose second half is negated and moved in front of the first half.
    """
    lower, upper = torch.split(x, x.shape[-1] // 2, dim=-1)
    return torch.cat((upper.neg(), lower), dim=-1)
17
+
18
def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                         position_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embedding (cos, sin) to the query and key tensor on the sequence dimension.

    The legends for dimensions are defined as:
        num_heads: number of attention heads
        current_seq_len: the current batch's sequence length, should be either 1 or max_seq_len
        max_seq_len: the static sequence length, different from current_seq_len in cached inference
            case where it is always the maximum length, e.g. the length of the static KV cache

    Args:
        q: Query tensor, of size (batch_size, num_heads, current_seq_len, head_dim)
        k: Key tensor, of size (batch_size, num_key_value_heads, current_seq_len, head_dim)
        cos: Cosine base of rotary embedding; a 4-D tensor broadcastable against q/k whose
            last dimension equals head_dim (the code below asserts cos.shape[3] == head_dim)
        sin: Sine base of rotary embedding, same layout as `cos`
        position_ids: The position indices of the tokens corresponding to the query and key
            tensors, of shape (batch_size, current_seq_len) or (1, current_seq_len). Used here
            only for shape validation; cos/sin are assumed to be pre-gathered.

    Returns:
        Tuple (q_embed, k_embed): the rotary-embedded query and key tensors, each the same
        size as its input.
    """
    bsz, nheads, cur_seq_len, head_dim = q.shape
    assert len(
        k.shape) == 4, f"k should be of shape (batch_size, num_heads, current_seq_len, head_dim), got {k.shape} instead"
    assert k.shape[0] == bsz, f"k has a different batch_size {k.shape[0]} compared to q {bsz}"
    assert list(k.shape[2:]) == [cur_seq_len,
                                 head_dim], f"k has different current_seq_len and/or head_dim compared to q"
    assert cos.shape[3] == head_dim, f"cos should have dim of head dim {head_dim}, got {cos.shape[3]} instead"
    assert list(position_ids.shape) in [[bsz, cur_seq_len], [1, cur_seq_len]],\
        f"position_ids should be of shape {[bsz, cur_seq_len]} or {[1, cur_seq_len]}, got {position_ids.shape} instead"

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
55
+
56
def attention_op(
    q,
    k,
    v,
    attn_mask,
    mixedp_attn,
    head_dim_scaling
):
    """Exact scaled-dot-product attention.

    Computes softmax(scale * q k^T) v, optionally upcasting the logits to
    float32 (``mixedp_attn``) and masking positions where ``attn_mask`` is
    True with a large negative value before the softmax.
    """
    scores = q @ k.transpose(-2, -1)
    if mixedp_attn:
        # Upcast logits for numerical stability in mixed precision.
        scores = scores.to(torch.float)
    scores = scores * head_dim_scaling
    if attn_mask is not None:
        scores = scores.masked_fill(attn_mask, MASK_MIN_VALUE)

    probs = torch.softmax(scores, dim=-1).to(q.dtype)
    return probs @ v
74
+
75
def prm_projection(
    x: torch.Tensor,
    projection_matrix: torch.Tensor,
    mixedp_attn: bool = False
):
    """
    Constructs nonnegative kernel features for fast softmax attention.
    Args:
        x: input for which features are computed, shape [..., m, d]
        projection_matrix: random matrix used to compute features, shape [..., r, d]
        mixedp_attn: upcast intermediate results to float32 when True
    Returns:
        Log-domain random features of shape [..., r, m].
    """
    inv_sqrt_d = x.shape[-1] ** -0.5
    # [..., r, d] @ [..., d, m] -> [..., r, m]
    projected = projection_matrix @ x.transpose(-1, -2)
    # Half squared norm per row of x, broadcast over the r dimension.
    half_sq_norm = x.pow(2).sum(dim=-1).unsqueeze(-2) * 0.5
    if mixedp_attn:
        projected = projected.to(torch.float)
        half_sq_norm = half_sq_norm.to(torch.float)
    return inv_sqrt_d * (projected - half_sq_norm)
98
+
99
class EvaAttention(nn.Module):
    """PyTorch reference implementation of EVA attention.

    Combines exact local (windowed) attention over recent tokens with
    chunk-level random-feature attention (RFA) summaries of the remaining
    context; the two are aggregated through a single shared softmax by
    concatenating per-chunk summary keys/values onto the local window.

    Fix vs. previous revision: the fallback error message for an
    unsupported attention-mask tuple said "length 2 or 3" although the
    accepted lengths are 3 (no cache) and 4 (with cache).
    """

    def __init__(self, config, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.head_dim_scaling = self.head_dim ** -0.5

        self.max_position_embeddings = config.max_position_embeddings

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.window_size = config.window_size

        self.num_chunks = config.num_chunks
        self.chunk_size = config.chunk_size
        if self.chunk_size is not None:
            assert self.window_size >= self.chunk_size and self.window_size % self.chunk_size == 0
            # chunk_size overrides the number of landmarks
            self.num_chunks = None

            self.chunks_per_window = int(self.window_size // self.chunk_size)
        # r = 1 random feature per chunk; see _calculate_chunk_rfa for the
        # simplification this enables.
        self.random_feature_dim = 1
        # Learned per-head projection used as the (single) random feature map.
        self.adaptive_phi = nn.Parameter(
            torch.randn(
                1,
                self.num_heads,
                1,
                1,
                self.head_dim
            ).clamp(-1., 1.) * self.head_dim_scaling
        )
        # Learned per-head query vector used to pool each chunk's keys.
        self.adaptive_mu_k = nn.Parameter(
            torch.randn(
                1,
                self.num_heads,
                1,
                1,
                self.head_dim
            ).clamp(-1., 1.) * self.head_dim_scaling
        )

    def _generate_feature_map(self, rf_q, rf_k, rf_v):
        """Pool each chunk's keys into a single landmark key (rf_k_bar) and
        return the learned feature-map weights. rf_q / rf_v are accepted for
        interface symmetry but unused here."""
        rf_k_logits = torch.sum(self.adaptive_mu_k.to(rf_k.dtype) * rf_k, dim=-1, keepdim=True)  # b h c m 1
        if self.config.mixedp_attn:
            rf_k_logits = rf_k_logits.to(torch.float)
        rf_k_weights = torch.softmax(rf_k_logits, dim=-2).to(rf_k.dtype)
        rf_k_bar = torch.sum(rf_k_weights * rf_k, dim=-2)
        weights = self.adaptive_phi.to(rf_k.dtype)
        return weights, rf_k_bar

    def _calculate_chunk_rfa_cache(self, rf_q, rf_k, rf_v, weights, rf_mask=None):
        """Compute the per-chunk softmax-weighted value summary.

        Returns (softmax_phi_k_v, log_sum_phi_k); the latter is None because
        with random_feature_dim == 1 the normalizer cancels out downstream.
        """
        # Log-domain feature of each key under the learned map `weights`.
        proj_x = torch.sum(weights * rf_k, dim=-1, keepdim=True)
        norm = torch.sum(rf_k ** 2, dim=-1, keepdim=True) * 0.5  # [..., 1]
        if self.config.mixedp_attn:
            proj_x = proj_x.to(torch.float)
            norm = norm.to(torch.float)
        log_phi_k = self.head_dim_scaling * (proj_x - norm)

        if rf_mask is not None:
            log_phi_k = log_phi_k.masked_fill(rf_mask, MASK_MIN_VALUE)

        # [b, h, c, m, r]
        softmax_phi_k = torch.softmax(log_phi_k, dim=-2).to(rf_k.dtype)
        softmax_phi_k_v = torch.sum(softmax_phi_k * rf_v, dim=-2)
        # [b, h, c, r, m] [b, h, c, m, d] -> [b, h, c, r, d]
        # softmax_phi_k_v = torch.matmul(softmax_phi_k.transpose(-1, -2), rf_v).squeeze(-2)
        log_sum_phi_k = None
        return softmax_phi_k_v, log_sum_phi_k

    def _calculate_chunk_rfa(self, q, softmax_phi_k_v, log_sum_phi_k, weights):
        """Turn chunk summaries into the per-chunk RFA value estimate."""
        if self.random_feature_dim == 1:
            # when r = 1, the snis weights becomes 1, so this takes no effect
            # [b, h, c, r, d] -> [b, h, c, d]
            return softmax_phi_k_v
        else:
            # Self-normalized importance sampling over r random features.
            # [b, h, c, r, d] [b, h, 1, s, d] -> [b, h, c, r, s]
            log_phi_q = prm_projection(q.unsqueeze(-3), weights, self.config.mixedp_attn)
            # [b, h, c, r, s] [b, h, c, r, 1] -> [b, h, c, r, s]
            sniw = torch.softmax(log_phi_q + log_sum_phi_k, dim=-1).to(q.dtype)
            # [b, h, c, r, s] [b, h, c, r, d] -> [b, h, c, s, d] -> [b, h, s, c, d]
            rfa_per_chunk = torch.matmul(sniw.transpose(-1, -2), softmax_phi_k_v).transpose(-3, -2)
            return rfa_per_chunk

    def window_partition(self, x, window_size=None):
        """Reshape [..., g*w, d] -> [..., g, w, d] (seq len must divide evenly)."""
        window_size = window_size if window_size is not None else self.window_size

        gw, d = x.shape[-2:]
        leading_dims = x.shape[:-2]
        n_groups = gw // window_size
        return x.reshape(*leading_dims, n_groups, window_size, d)

    def window_merge(self, x, window_size=None):
        """Inverse of window_partition: [..., g, w, d] -> [..., g*w, d]."""
        g, w, d = x.shape[-3:]
        leading_dims = x.shape[:-3]
        return x.reshape(*leading_dims, g * w, d)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cos: Optional[torch.Tensor] = None,
        sin: Optional[torch.Tensor] = None,
        multibyte_decoding: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        assert not output_attentions
        bsz, q_len, _ = hidden_states.size()

        ############################################
        # initialize past states if not provided
        ############################################
        if use_cache and past_key_value is None:
            raise ValueError
        if use_cache and multibyte_decoding:
            raise NotImplementedError("Multibyte decoding is not supported for PyTorch native implementation")
        # attention_mask is a tuple of pre-built boolean masks:
        # 4 entries when decoding with a cache, 3 during full-sequence runs.
        if len(attention_mask) == 4:
            assert use_cache
            prev_causal_mask, cur_causal_mask, chunk_causal_mask, intra_chunk_mask = attention_mask
        elif len(attention_mask) == 3:
            assert not use_cache
            window_causal_mask, chunk_causal_mask, intra_chunk_mask = attention_mask
        else:
            raise NotImplementedError("Only attention-mask tuple with length 3 or 4 is supported")

        ############################################
        # compute q, k, v from hidden states
        ############################################
        # [b, h, q_len, d]
        q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        # [b, h, kv_len, d]
        k = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        # [b, h, kv_len, d]
        v = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        if use_cache:
            past_key_value.update_past_len(q.shape[-2], self.layer_idx)

        ############################################
        # apply rotary positional embeddings to q, k
        ############################################
        q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)

        ############################################
        # compute q, k, v stats for the local window
        ############################################
        if use_cache:
            (prev_w_q, prev_w_k, prev_w_v), (cur_w_q, cur_w_k, cur_w_v) = past_key_value.update_singletons(
                q,
                k,
                v,
                self.layer_idx,
                self.window_size,
            )
        else:
            prev_w_q = self.window_partition(q)  # [b, h, w, i, d]
            prev_w_k = self.window_partition(k)  # [b, h, w, j, d]
            prev_w_v = self.window_partition(v)  # [b, h, w, j, d]
            # during training, we assume window_size divides seq_len so no remainders
            cur_w_q = cur_w_k = cur_w_v = None

        ############################################
        # compute q, k, v stats for chunk-level RFAs
        ############################################
        if use_cache:
            dump_q, dump_k, dump_v = past_key_value.update_chunks(q, k, v, self.layer_idx, self.chunk_size)
        else:
            dump_q, dump_k, dump_v = q, k, v

        if use_cache:
            prev_s_mask, cur_s_mask, prev_chunk_mask, cur_chunk_mask, dump_rf_mask = past_key_value.update_mask(
                prev_s_mask=prev_causal_mask,
                cur_s_mask=cur_causal_mask,
                chunk_mask=chunk_causal_mask,
                rf_mask=intra_chunk_mask,
                layer_idx=self.layer_idx,
                window_size=self.window_size,
                chunk_size=self.chunk_size,
            )
        else:
            prev_s_mask = self.window_partition(prev_causal_mask)  # [1, 1, w, i, j]
            cur_s_mask = None
            prev_chunk_mask = self.window_partition(chunk_causal_mask)
            cur_chunk_mask = None
            dump_rf_mask = intra_chunk_mask
            if prev_s_mask.shape[-3] == 1:
                # need to expand
                prev_s_mask = prev_s_mask.expand(-1, -1, prev_chunk_mask.shape[-3], -1, -1)

        if (
            dump_q is not None and
            dump_k is not None and
            dump_v is not None
        ):
            # [b, h, c, j, d]
            rf_q = self.window_partition(dump_q, window_size=self.chunk_size)
            # [b, h, c, j, d]
            rf_k = self.window_partition(dump_k, window_size=self.chunk_size)
            # [b, h, c, j, d]
            rf_v = self.window_partition(dump_v, window_size=self.chunk_size)

            if dump_rf_mask is not None:
                rf_mask = self.window_partition(dump_rf_mask, window_size=self.chunk_size)
                # Zero out padded positions so they do not pollute the pooled stats.
                rf_q = rf_q.masked_fill(rf_mask, 0.)
                rf_k = rf_k.masked_fill(rf_mask, 0.)
                rf_v = rf_v.masked_fill(rf_mask, 0.)
            else:
                rf_mask = None
        else:
            rf_q = None
            rf_k = None
            rf_v = None
            rf_mask = None

        if rf_q is not None:
            weights, rf_k_bar = self._generate_feature_map(rf_q, rf_k, rf_v)
            softmax_phi_k_v, log_sum_phi_k = self._calculate_chunk_rfa_cache(rf_q, rf_k, rf_v, weights, rf_mask=rf_mask)
            if use_cache:
                softmax_phi_k_v, log_sum_phi_k, rf_k_bar = past_key_value.update_chunk_rfas(
                    softmax_phi_k_v, log_sum_phi_k, rf_k_bar, self.layer_idx, 1
                )
        elif use_cache:
            # No new full chunk this step; reuse cached chunk summaries.
            weights = None
            softmax_phi_k_v, log_sum_phi_k, rf_k_bar = past_key_value.get_chunk_rfas(self.layer_idx)
        else:
            weights = None
            softmax_phi_k_v = None
            log_sum_phi_k = None
            rf_k_bar = None

        if rf_k_bar is not None:
            rfa_per_chunk = self._calculate_chunk_rfa(q, softmax_phi_k_v, log_sum_phi_k, weights)
        ############################################
        # compute meta-attention weights for
        #   - group-wise RFAs and
        #   - singletons (equivalent to exact local attention)
        ############################################
        if prev_w_k is not None:
            if rf_k_bar is not None:
                num_windows = prev_w_k.shape[-3]
                # rf_k_bar and rfa_per_chunk take the shape [b, h, c, d]
                # -> [b, h, 1, c, d] -> [b, h, w, c, d]
                prev_rf_k_bar = rf_k_bar.unsqueeze(-3).expand(-1, -1, num_windows, -1, -1)
                prev_rfa_per_chunk = rfa_per_chunk.unsqueeze(-3).expand(-1, -1, num_windows, -1, -1)
                prev_agg_k = torch.cat([prev_w_k, prev_rf_k_bar], dim=-2)
                prev_agg_v = torch.cat([prev_w_v, prev_rfa_per_chunk], dim=-2)

                prev_attn_mask = torch.cat([prev_s_mask, prev_chunk_mask], dim=-1)
            else:
                prev_agg_k = prev_w_k
                prev_agg_v = prev_w_v
                prev_attn_mask = prev_s_mask

            prev_attn_output = attention_op(
                q=prev_w_q,
                k=prev_agg_k,
                v=prev_agg_v,
                attn_mask=prev_attn_mask,
                mixedp_attn=self.config.mixedp_attn,
                head_dim_scaling=self.head_dim_scaling
            )
            prev_attn_output = self.window_merge(prev_attn_output)

        if cur_w_k is not None:
            if rf_k_bar is not None:
                # rf_k_bar and rfa_per_chunk take the shape [b, h, c, d]
                # cur_w_k and cur_w_v also has shape [b, h, m, d]
                cur_agg_k = torch.cat([cur_w_k, rf_k_bar], dim=-2)
                cur_agg_v = torch.cat([cur_w_v, rfa_per_chunk], dim=-2)

                cur_attn_mask = torch.cat([cur_s_mask, cur_chunk_mask], dim=-1)
            else:
                cur_agg_k = cur_w_k
                cur_agg_v = cur_w_v
                cur_attn_mask = cur_s_mask

            cur_attn_output = attention_op(
                q=cur_w_q,
                k=cur_agg_k,
                v=cur_agg_v,
                attn_mask=cur_attn_mask,
                mixedp_attn=self.config.mixedp_attn,
                head_dim_scaling=self.head_dim_scaling
            )

        if prev_w_k is not None and cur_w_k is not None:
            attn_output = torch.cat([prev_attn_output, cur_attn_output], dim=-2)
        elif prev_w_k is not None:
            attn_output = prev_attn_output
        elif cur_w_k is not None:
            attn_output = cur_attn_output
        else:
            raise ValueError("There must be some bug")

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        attn_weights = None

        return attn_output, attn_weights, past_key_value
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/generation_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "pad_token_id": 0,
6
+ "transformers_version": "4.47.1"
7
+ }
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/image_processing_evabyte.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ """Image processor class for EvaByte."""
3
+
4
+ from typing import Dict, List, Optional, Union, Tuple
5
+
6
+ import io
7
+ from transformers.image_processing_utils import BaseImageProcessor
8
+ from transformers.image_utils import (
9
+ ImageInput,
10
+ PILImageResampling,
11
+ valid_images,
12
+ validate_preprocess_arguments,
13
+ )
14
+ from PIL import Image
15
+
16
def _get_qtable_bytes():
    """Return a mapping from JPEG quality level to a pre-serialized pair of
    DQT (quantization-table) segments wrapped in SOI/EOI markers.

    Each value is the raw bytes of a minimal JPEG stream containing only the
    two quantization tables Pillow emits for that quality; callers splice the
    table bytes into table-less JPEG streams (see jpeg_merge_qtables)."""
    return {
        5: b'\xff\xd8\xff\xdb\x00C\x00\xa0nx\x8cxd\xa0\x8c\x82\x8c\xb4\xaa\xa0\xbe\xf0\xff\xff\xf0\xdc\xdc\xf0\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xdb\x00C\x01\xa0\xb4\xb4\xf0\xd2\xf0\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xd9',
        10: b'\xff\xd8\xff\xdb\x00C\x00P7<F<2PFAFZUP_x\xc8\x82xnnx\xf5\xaf\xb9\x91\xc8\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xdb\x00C\x01PZZxix\xeb\x82\x82\xeb\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xd9',
        15: b'\xff\xd8\xff\xdb\x00C\x005%(/(!5/+/<95?P\x85WPIIP\xa3u{a\x85\xc1\xaa\xcb\xc8\xbe\xaa\xba\xb7\xd5\xf0\xff\xff\xd5\xe2\xff\xe6\xb7\xba\xff\xff\xff\xff\xff\xff\xff\xff\xff\xce\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xdb\x00C\x015<<PFP\x9dWW\x9d\xff\xdc\xba\xdc\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xd9',
        20: b'\xff\xd8\xff\xdb\x00C\x00(\x1c\x1e#\x1e\x19(#!#-+(0<dA<77<{X]Id\x91\x80\x99\x96\x8f\x80\x8c\x8a\xa0\xb4\xe6\xc3\xa0\xaa\xda\xad\x8a\x8c\xc8\xff\xcb\xda\xee\xf5\xff\xff\xff\x9b\xc1\xff\xff\xff\xfa\xff\xe6\xfd\xff\xf8\xff\xdb\x00C\x01(--<5<vAAv\xf8\xa5\x8c\xa5\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xf8\xff\xd9',
        25: b'\xff\xd8\xff\xdb\x00C\x00 \x16\x18\x1c\x18\x14 \x1c\x1a\x1c$" &0P40,,0bFJ:Ptfzxrfpn\x80\x90\xb8\x9c\x80\x88\xae\x8anp\xa0\xda\xa2\xae\xbe\xc4\xce\xd0\xce|\x9a\xe2\xf2\xe0\xc8\xf0\xb8\xca\xce\xc6\xff\xdb\x00C\x01 $$0*0^44^\xc6\x84p\x84\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xc6\xff\xd9',
        30: b'\xff\xd8\xff\xdb\x00C\x00\x1b\x12\x14\x17\x14\x11\x1b\x17\x16\x17\x1e\x1c\x1b (B+(%%(Q:=0B`Ued_U][jx\x99\x81jq\x90s[]\x85\xb5\x86\x90\x9e\xa3\xab\xad\xabg\x80\xbc\xc9\xba\xa6\xc7\x99\xa8\xab\xa4\xff\xdb\x00C\x01\x1b\x1e\x1e(#(N++N\xa4n]n\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xa4\xff\xd9',
        50: b'\xff\xd8\xff\xdb\x00C\x00\x10\x0b\x0c\x0e\x0c\n\x10\x0e\r\x0e\x12\x11\x10\x13\x18(\x1a\x18\x16\x16\x181#%\x1d(:3=<9387@H\\N@DWE78PmQW_bghg>Mqypdx\\egc\xff\xdb\x00C\x01\x10\x12\x12\x18\x15\x18/\x1a\x1a/cB8Bcccccccccccccccccccccccccccccccccccccccccccccccccc\xff\xd9',
        75: b'\xff\xd8\xff\xdb\x00C\x00\x08\x06\x06\x07\x06\x05\x08\x07\x07\x07\t\t\x08\n\x0c\x14\r\x0c\x0b\x0b\x0c\x19\x12\x13\x0f\x14\x1d\x1a\x1f\x1e\x1d\x1a\x1c\x1c $.\' ",#\x1c\x1c(7),01444\x1f\'9=82<.342\xff\xdb\x00C\x01\x08\t\t\x0c\x0b\x0c\x18\r\r\x182!\x1c!22222222222222222222222222222222222222222222222222\xff\xd9',
        95: b'\xff\xd8\xff\xdb\x00C\x00\x02\x01\x01\x01\x01\x01\x02\x01\x01\x01\x02\x02\x02\x02\x02\x04\x03\x02\x02\x02\x02\x05\x04\x04\x03\x04\x06\x05\x06\x06\x06\x05\x06\x06\x06\x07\t\x08\x06\x07\t\x07\x06\x06\x08\x0b\x08\t\n\n\n\n\n\x06\x08\x0b\x0c\x0b\n\x0c\t\n\n\n\xff\xdb\x00C\x01\x02\x02\x02\x02\x02\x02\x05\x03\x03\x05\n\x07\x06\x07\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\xff\xd9',
        100: b'\xff\xd8\xff\xdb\x00C\x00\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\xff\xdb\x00C\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\xff\xd9',
    }
29
+
30
+
31
+ def _resize_if_exceeding_max_len(
32
+ width: int, height: int, min_len: Optional[int] = 16, max_len: Optional[int] = None
33
+ ) -> Tuple[int, int]:
34
+ """
35
+ Get the output size of the image after resizing given a dictionary specifying the max and min sizes.
36
+
37
+ Args:
38
+ height (`int`):
39
+ Height of the input image.
40
+ width (`int`):
41
+ Width of the input image.
42
+ max_len (`Dict[str, int]`, *optional*, defaults to the maximum size of the image):
43
+ Defines the maximum dimensions of the image.
44
+
45
+ Returns:
46
+ The output size of the image after resizing.
47
+ """
48
+ max_len = max(height, width) if max_len is None else max_len
49
+ aspect_ratio = width / height
50
+
51
+ if width >= height and width > max_len:
52
+ width = max_len
53
+ height = int(width / aspect_ratio)
54
+ if height % 2 != 0:
55
+ height += 1
56
+ elif height > width and height > max_len:
57
+ height = max_len
58
+ width = int(height * aspect_ratio)
59
+ if width % 2 != 0:
60
+ width += 1
61
+
62
+ # Avoid resizing to a size smaller than 1
63
+ height = max(height, min_len)
64
+ width = max(width, min_len)
65
+ return width, height
66
+
67
class EvaByteImageProcessor(BaseImageProcessor):
    """Image processor for EvaByte: converts PIL images into raw JPEG byte
    streams (rather than pixel tensors), since the model consumes bytes
    directly. Images are optionally converted to RGB, resized so the longest
    edge fits `size`, then JPEG-encoded with fixed quantization settings."""

    # No tensor inputs are produced, hence no model input names.
    model_input_names = []

    def __init__(
        self,
        do_resize: bool = True,
        resample: PILImageResampling = PILImageResampling.LANCZOS,
        size: Dict[str, int] = None,
        do_convert_rgb: bool = True,
        jpeg_quality: int = 25,
        jpeg_subsampling: str = "4:2:0",
        jpeg_streamtype: int = 2,
        jpeg_restart_marker_blocks: int = 1,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.do_resize = do_resize
        self.resample = resample
        self.size = size if size is not None else {"longest_edge": 384}
        self.do_convert_rgb = do_convert_rgb
        self.jpeg_quality = jpeg_quality
        self.jpeg_subsampling = jpeg_subsampling
        # streamtype=2 asks Pillow for an "abbreviated" stream without
        # quantization tables — presumably re-merged later via
        # jpeg_merge_qtables; TODO confirm against Pillow's JPEG plugin docs.
        self.jpeg_streamtype = jpeg_streamtype
        self.jpeg_restart_marker_blocks = jpeg_restart_marker_blocks

    def jpeg_encode(
        self,
        image,
        jpeg_quality,
        jpeg_subsampling,
        jpeg_streamtype,
        jpeg_restart_marker_blocks,
    ):
        """Encode a single PIL image to JPEG bytes with the given settings."""
        with io.BytesIO() as output:
            image.save(
                output,
                format="JPEG",
                quality=jpeg_quality,
                subsampling=jpeg_subsampling,
                streamtype=jpeg_streamtype,
                restart_marker_blocks=jpeg_restart_marker_blocks
            )
            jpeg_bytes = output.getvalue()
        return jpeg_bytes

    def jpeg_merge_qtables(
        self,
        image_bytes: bytes,
        jpeg_quality: Optional[int] = None,
    ) -> bytes:
        """Splice the pre-serialized quantization tables for `jpeg_quality`
        into a table-less JPEG stream.

        Keeps the leading SOI marker (first 2 bytes) of `image_bytes`, inserts
        the DQT segments (qtable bytes minus their own SOI/EOI markers), then
        appends the rest of the original stream.
        """
        if jpeg_quality is None:
            jpeg_quality = self.jpeg_quality
        qtable_bytes = _get_qtable_bytes()[jpeg_quality]
        return image_bytes[:2] + qtable_bytes[2:-2] + image_bytes[2:]

    def resize(
        self,
        image: Image.Image,
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.LANCZOS,
    ) -> Image.Image:
        """Resize a PIL image either by capping its longest edge
        (`{"longest_edge": n}`) or to an explicit `{"width": w, "height": h}`."""
        if "longest_edge" in size:
            width, height = image.size
            # Find the output size, when rescaling the longest edge to max_len and preserving the aspect ratio
            width, height = _resize_if_exceeding_max_len(width, height, max_len=size["longest_edge"])
            size = (width, height)
        elif "width" in size and "height" in size:
            size = (size["width"], size["height"])
        else:
            raise ValueError("size must be a dictionary with key 'longest_edge' or 'height' and 'width'.")
        resized_image = image.resize(size, resample=resample)
        return resized_image

    def preprocess(
        self,
        images: ImageInput,
        do_resize: bool = None,
        resample = None,
        size: Dict[str, int] = None,
        do_convert_rgb: bool = None,
        jpeg_quality: Optional[int] = None,
        jpeg_subsampling: Optional[str] = None,
        jpeg_streamtype: Optional[int] = None,
        jpeg_restart_marker_blocks: Optional[int] = None,
    ):
        """Convert a batch of image lists into nested lists of JPEG bytes.

        `images` is expected to be a list of lists of PIL images (one inner
        list per sample); the return value mirrors that nesting, with each
        image replaced by its encoded JPEG byte string. Any argument left as
        None falls back to the processor's configured default.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size
        resample = resample if resample is not None else self.resample
        do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb

        jpeg_quality = jpeg_quality if jpeg_quality is not None else self.jpeg_quality
        jpeg_subsampling = jpeg_subsampling if jpeg_subsampling is not None else self.jpeg_subsampling
        jpeg_streamtype = jpeg_streamtype if jpeg_streamtype is not None else self.jpeg_streamtype
        jpeg_restart_marker_blocks = jpeg_restart_marker_blocks if jpeg_restart_marker_blocks is not None else self.jpeg_restart_marker_blocks

        if images is not None and not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        validate_preprocess_arguments(
            do_resize=do_resize,
            size=size,
            resample=resample,
        )
        images_list = images
        if do_convert_rgb:
            images_list = [
                [
                    image.convert("RGB") for image in images
                ]
                for images in images_list
            ]

        if do_resize:
            images_list = [
                [
                    self.resize(image=image, size=size, resample=resample)
                    for image in images
                ]
                for images in images_list
            ]

        jpeg_bytes = [
            [
                self.jpeg_encode(
                    image,
                    jpeg_quality,
                    jpeg_subsampling,
                    jpeg_streamtype,
                    jpeg_restart_marker_blocks
                ) for image in images
            ]
            for images in images_list
        ]
        return jpeg_bytes
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/model.safetensors.index.json ADDED
@@ -0,0 +1,450 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "total_size": 57058938880
4
+ },
5
+ "weight_map": {
6
+ "model.embed_tokens.weight": "model-00001-of-00003.safetensors",
7
+ "model.layers.0.input_layernorm.weight": "model-00001-of-00003.safetensors",
8
+ "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
9
+ "model.layers.1.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
10
+ "model.layers.1.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
11
+ "model.layers.1.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
12
+ "model.layers.1.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
13
+ "model.layers.3.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
14
+ "model.layers.4.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
15
+ "model.layers.4.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
16
+ "model.layers.5.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
17
+ "model.layers.5.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
18
+ "model.layers.7.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
19
+ "model.layers.8.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
20
+ "model.layers.8.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
21
+ "model.layers.10.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
22
+ "model.layers.12.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
23
+ "model.layers.12.input_layernorm.weight": "model-00002-of-00003.safetensors",
24
+ "model.layers.12.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
25
+ "model.layers.13.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
26
+ "model.layers.13.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
27
+ "model.layers.13.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
28
+ "model.layers.13.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
29
+ "model.layers.15.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
30
+ "model.layers.16.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
31
+ "model.layers.16.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
32
+ "model.layers.18.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
33
+ "model.layers.20.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
34
+ "model.layers.20.input_layernorm.weight": "model-00002-of-00003.safetensors",
35
+ "model.layers.20.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
36
+ "model.layers.21.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
37
+ "model.layers.21.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
38
+ "model.layers.21.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
39
+ "model.layers.21.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
40
+ "model.layers.23.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
41
+ "model.layers.24.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
42
+ "model.layers.24.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
43
+ "model.layers.26.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
44
+ "model.layers.28.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
45
+ "model.layers.28.input_layernorm.weight": "model-00003-of-00003.safetensors",
46
+ "model.layers.28.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
47
+ "model.layers.29.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
48
+ "model.layers.29.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
49
+ "model.layers.29.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
50
+ "model.layers.29.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
51
+ "model.layers.31.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
52
+ "model.layers.32.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
53
+ "model.layers.32.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
54
+ "model.layers.34.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
55
+ "model.layers.36.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
56
+ "model.layers.36.input_layernorm.weight": "model-00003-of-00003.safetensors",
57
+ "model.layers.36.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
58
+ "model.layers.37.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
59
+ "model.layers.37.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
60
+ "model.layers.37.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
61
+ "model.layers.37.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
62
+ "model.layers.39.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
63
+ "model.layers.2.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
64
+ "model.layers.26.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
65
+ "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
66
+ "model.layers.1.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
67
+ "model.layers.1.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
68
+ "model.layers.3.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
69
+ "model.layers.3.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
70
+ "model.layers.5.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
71
+ "model.layers.7.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
72
+ "model.layers.9.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
73
+ "model.layers.9.input_layernorm.weight": "model-00001-of-00003.safetensors",
74
+ "model.layers.9.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
75
+ "model.layers.10.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
76
+ "model.layers.10.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
77
+ "model.layers.12.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
78
+ "model.layers.14.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
79
+ "model.layers.14.input_layernorm.weight": "model-00002-of-00003.safetensors",
80
+ "model.layers.15.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
81
+ "model.layers.15.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
82
+ "model.layers.17.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
83
+ "model.layers.18.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
84
+ "model.layers.18.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
85
+ "model.layers.20.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
86
+ "model.layers.21.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
87
+ "model.layers.21.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
88
+ "model.layers.23.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
89
+ "model.layers.25.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
90
+ "model.layers.25.input_layernorm.weight": "model-00003-of-00003.safetensors",
91
+ "model.layers.25.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
92
+ "model.layers.26.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
93
+ "model.layers.26.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
94
+ "model.layers.27.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
95
+ "model.layers.28.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
96
+ "model.layers.29.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
97
+ "model.layers.29.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
98
+ "model.layers.31.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
99
+ "model.layers.33.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
100
+ "model.layers.33.input_layernorm.weight": "model-00003-of-00003.safetensors",
101
+ "model.layers.33.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
102
+ "model.layers.34.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
103
+ "model.layers.34.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
104
+ "model.layers.36.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
105
+ "model.layers.37.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
106
+ "model.layers.37.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
107
+ "model.layers.39.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
108
+ "model.layers.3.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
109
+ "model.layers.27.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
110
+ "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
111
+ "model.layers.1.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
112
+ "model.layers.1.input_layernorm.weight": "model-00001-of-00003.safetensors",
113
+ "model.layers.2.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
114
+ "model.layers.2.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
115
+ "model.layers.2.input_layernorm.weight": "model-00001-of-00003.safetensors",
116
+ "model.layers.3.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
117
+ "model.layers.3.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
118
+ "model.layers.4.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
119
+ "model.layers.5.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
120
+ "model.layers.5.input_layernorm.weight": "model-00001-of-00003.safetensors",
121
+ "model.layers.6.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
122
+ "model.layers.6.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
123
+ "model.layers.8.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
124
+ "model.layers.9.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
125
+ "model.layers.9.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
126
+ "model.layers.11.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
127
+ "model.layers.13.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
128
+ "model.layers.13.input_layernorm.weight": "model-00002-of-00003.safetensors",
129
+ "model.layers.13.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
130
+ "model.layers.14.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
131
+ "model.layers.14.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
132
+ "model.layers.16.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
133
+ "model.layers.17.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
134
+ "model.layers.17.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
135
+ "model.layers.19.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
136
+ "model.layers.21.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
137
+ "model.layers.21.input_layernorm.weight": "model-00002-of-00003.safetensors",
138
+ "model.layers.21.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
139
+ "model.layers.22.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
140
+ "model.layers.22.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
141
+ "model.layers.24.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
142
+ "model.layers.25.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
143
+ "model.layers.25.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
144
+ "model.layers.27.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
145
+ "model.layers.28.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
146
+ "model.layers.29.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
147
+ "model.layers.29.input_layernorm.weight": "model-00003-of-00003.safetensors",
148
+ "model.layers.29.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
149
+ "model.layers.30.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
150
+ "model.layers.30.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
151
+ "model.layers.32.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
152
+ "model.layers.33.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
153
+ "model.layers.33.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
154
+ "model.layers.35.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
155
+ "model.layers.37.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
156
+ "model.layers.37.input_layernorm.weight": "model-00003-of-00003.safetensors",
157
+ "model.layers.37.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
158
+ "model.layers.38.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
159
+ "model.layers.38.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
160
+ "model.layers.4.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
161
+ "model.layers.28.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
162
+ "model.layers.5.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
163
+ "model.layers.0.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
164
+ "model.layers.0.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
165
+ "model.layers.2.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
166
+ "model.layers.4.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
167
+ "model.layers.6.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
168
+ "model.layers.8.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
169
+ "model.layers.8.input_layernorm.weight": "model-00001-of-00003.safetensors",
170
+ "model.layers.8.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
171
+ "model.layers.9.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
172
+ "model.layers.9.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
173
+ "model.layers.9.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
174
+ "model.layers.9.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
175
+ "model.layers.11.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
176
+ "model.layers.12.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
177
+ "model.layers.12.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
178
+ "model.layers.12.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
179
+ "model.layers.14.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
180
+ "model.layers.16.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
181
+ "model.layers.16.input_layernorm.weight": "model-00002-of-00003.safetensors",
182
+ "model.layers.16.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
183
+ "model.layers.17.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
184
+ "model.layers.17.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
185
+ "model.layers.17.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
186
+ "model.layers.17.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
187
+ "model.layers.19.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
188
+ "model.layers.20.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
189
+ "model.layers.20.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
190
+ "model.layers.22.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
191
+ "model.layers.23.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
192
+ "model.layers.23.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
193
+ "model.layers.23.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
194
+ "model.layers.25.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
195
+ "model.layers.27.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
196
+ "model.layers.27.input_layernorm.weight": "model-00003-of-00003.safetensors",
197
+ "model.layers.27.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
198
+ "model.layers.28.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
199
+ "model.layers.28.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
200
+ "model.layers.30.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
201
+ "model.layers.32.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
202
+ "model.layers.32.input_layernorm.weight": "model-00003-of-00003.safetensors",
203
+ "model.layers.32.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
204
+ "model.layers.33.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
205
+ "model.layers.33.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
206
+ "model.layers.33.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
207
+ "model.layers.33.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
208
+ "model.layers.35.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
209
+ "model.layers.36.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
210
+ "model.layers.36.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
211
+ "model.layers.38.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
212
+ "model.layers.0.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
213
+ "model.layers.1.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
214
+ "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
215
+ "model.layers.2.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
216
+ "model.layers.2.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
217
+ "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
218
+ "model.layers.3.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
219
+ "model.layers.3.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
220
+ "model.layers.5.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
221
+ "model.layers.5.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
222
+ "model.layers.5.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
223
+ "model.layers.6.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
224
+ "model.layers.6.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
225
+ "model.layers.8.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
226
+ "model.layers.10.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
227
+ "model.layers.10.input_layernorm.weight": "model-00001-of-00003.safetensors",
228
+ "model.layers.10.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
229
+ "model.layers.11.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
230
+ "model.layers.11.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
231
+ "model.layers.13.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
232
+ "model.layers.14.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
233
+ "model.layers.14.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
234
+ "model.layers.16.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
235
+ "model.layers.18.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
236
+ "model.layers.18.input_layernorm.weight": "model-00002-of-00003.safetensors",
237
+ "model.layers.18.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
238
+ "model.layers.19.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
239
+ "model.layers.19.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
240
+ "model.layers.21.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
241
+ "model.layers.22.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
242
+ "model.layers.22.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
243
+ "model.layers.24.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
244
+ "model.layers.26.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
245
+ "model.layers.26.input_layernorm.weight": "model-00003-of-00003.safetensors",
246
+ "model.layers.26.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
247
+ "model.layers.27.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
248
+ "model.layers.27.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
249
+ "model.layers.29.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
250
+ "model.layers.30.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
251
+ "model.layers.30.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
252
+ "model.layers.32.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
253
+ "model.layers.34.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
254
+ "model.layers.34.input_layernorm.weight": "model-00003-of-00003.safetensors",
255
+ "model.layers.34.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
256
+ "model.layers.35.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
257
+ "model.layers.35.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
258
+ "model.layers.37.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
259
+ "model.layers.38.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
260
+ "model.layers.38.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
261
+ "model.layers.6.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
262
+ "model.layers.30.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
263
+ "model.layers.6.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
264
+ "model.layers.30.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
265
+ "model.layers.7.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
266
+ "model.layers.31.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
267
+ "model.layers.7.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
268
+ "model.layers.31.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
269
+ "model.layers.8.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
270
+ "model.layers.32.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
271
+ "model.layers.2.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
272
+ "model.layers.14.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
273
+ "model.layers.14.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
274
+ "model.layers.22.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
275
+ "model.layers.22.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
276
+ "model.layers.38.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
277
+ "model.layers.38.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
278
+ "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
279
+ "model.layers.1.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
280
+ "model.layers.3.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
281
+ "model.layers.3.input_layernorm.weight": "model-00001-of-00003.safetensors",
282
+ "model.layers.3.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
283
+ "model.layers.4.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
284
+ "model.layers.4.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
285
+ "model.layers.4.input_layernorm.weight": "model-00001-of-00003.safetensors",
286
+ "model.layers.4.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
287
+ "model.layers.5.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
288
+ "model.layers.5.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
289
+ "model.layers.7.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
290
+ "model.layers.7.input_layernorm.weight": "model-00001-of-00003.safetensors",
291
+ "model.layers.7.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
292
+ "model.layers.8.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
293
+ "model.layers.8.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
294
+ "model.layers.10.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
295
+ "model.layers.11.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
296
+ "model.layers.11.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
297
+ "model.layers.13.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
298
+ "model.layers.15.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
299
+ "model.layers.15.input_layernorm.weight": "model-00002-of-00003.safetensors",
300
+ "model.layers.15.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
301
+ "model.layers.16.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
302
+ "model.layers.16.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
303
+ "model.layers.18.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
304
+ "model.layers.19.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
305
+ "model.layers.19.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
306
+ "model.layers.21.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
307
+ "model.layers.23.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
308
+ "model.layers.23.input_layernorm.weight": "model-00002-of-00003.safetensors",
309
+ "model.layers.23.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
310
+ "model.layers.24.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
311
+ "model.layers.24.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
312
+ "model.layers.26.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
313
+ "model.layers.27.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
314
+ "model.layers.27.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
315
+ "model.layers.29.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
316
+ "model.layers.31.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
317
+ "model.layers.31.input_layernorm.weight": "model-00003-of-00003.safetensors",
318
+ "model.layers.31.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
319
+ "model.layers.32.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
320
+ "model.layers.32.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
321
+ "model.layers.34.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
322
+ "model.layers.35.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
323
+ "model.layers.35.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
324
+ "model.layers.37.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
325
+ "model.layers.39.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
326
+ "model.layers.39.input_layernorm.weight": "model-00003-of-00003.safetensors",
327
+ "model.layers.39.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
328
+ "model.norm.weight": "model-00003-of-00003.safetensors",
329
+ "lm_head.weight": "model-00003-of-00003.safetensors",
330
+ "model.layers.0.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
331
+ "model.layers.2.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
332
+ "model.layers.4.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
333
+ "model.layers.6.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
334
+ "model.layers.6.input_layernorm.weight": "model-00001-of-00003.safetensors",
335
+ "model.layers.7.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
336
+ "model.layers.7.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
337
+ "model.layers.8.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
338
+ "model.layers.9.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
339
+ "model.layers.10.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
340
+ "model.layers.10.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
341
+ "model.layers.12.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
342
+ "model.layers.13.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
343
+ "model.layers.13.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
344
+ "model.layers.15.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
345
+ "model.layers.17.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
346
+ "model.layers.17.input_layernorm.weight": "model-00002-of-00003.safetensors",
347
+ "model.layers.17.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
348
+ "model.layers.18.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
349
+ "model.layers.18.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
350
+ "model.layers.20.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
351
+ "model.layers.22.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
352
+ "model.layers.22.input_layernorm.weight": "model-00002-of-00003.safetensors",
353
+ "model.layers.22.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
354
+ "model.layers.23.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
355
+ "model.layers.23.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
356
+ "model.layers.25.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
357
+ "model.layers.26.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
358
+ "model.layers.26.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
359
+ "model.layers.28.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
360
+ "model.layers.30.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
361
+ "model.layers.30.input_layernorm.weight": "model-00003-of-00003.safetensors",
362
+ "model.layers.30.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
363
+ "model.layers.31.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
364
+ "model.layers.31.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
365
+ "model.layers.32.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
366
+ "model.layers.33.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
367
+ "model.layers.34.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
368
+ "model.layers.34.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
369
+ "model.layers.36.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
370
+ "model.layers.38.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
371
+ "model.layers.38.input_layernorm.weight": "model-00003-of-00003.safetensors",
372
+ "model.layers.38.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
373
+ "model.layers.39.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
374
+ "model.layers.39.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
375
+ "model.layers.10.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
376
+ "model.layers.34.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
377
+ "model.layers.10.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
378
+ "model.layers.34.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
379
+ "model.layers.0.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
380
+ "model.layers.2.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
381
+ "model.layers.4.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
382
+ "model.layers.6.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
383
+ "model.layers.6.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
384
+ "model.layers.7.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
385
+ "model.layers.7.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
386
+ "model.layers.9.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
387
+ "model.layers.11.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
388
+ "model.layers.11.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
389
+ "model.layers.11.input_layernorm.weight": "model-00001-of-00003.safetensors",
390
+ "model.layers.11.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
391
+ "model.layers.28.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
392
+ "model.layers.28.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
393
+ "model.layers.30.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
394
+ "model.layers.31.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
395
+ "model.layers.31.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
396
+ "model.layers.33.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
397
+ "model.layers.35.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
398
+ "model.layers.35.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
399
+ "model.layers.35.input_layernorm.weight": "model-00003-of-00003.safetensors",
400
+ "model.layers.35.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
401
+ "model.layers.36.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
402
+ "model.layers.36.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
403
+ "model.layers.38.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
404
+ "model.layers.39.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
405
+ "model.layers.39.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
406
+ "model.layers.16.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
407
+ "model.layers.16.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
408
+ "model.layers.24.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
409
+ "model.layers.24.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
410
+ "model.layers.11.self_attn.adaptive_mu_k": "model-00001-of-00003.safetensors",
411
+ "model.layers.12.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
412
+ "model.layers.12.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
413
+ "model.layers.14.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
414
+ "model.layers.14.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
415
+ "model.layers.15.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
416
+ "model.layers.15.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
417
+ "model.layers.17.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
418
+ "model.layers.19.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
419
+ "model.layers.19.input_layernorm.weight": "model-00002-of-00003.safetensors",
420
+ "model.layers.19.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
421
+ "model.layers.20.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
422
+ "model.layers.20.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
423
+ "model.layers.22.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
424
+ "model.layers.24.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
425
+ "model.layers.24.input_layernorm.weight": "model-00003-of-00003.safetensors",
426
+ "model.layers.24.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
427
+ "model.layers.25.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
428
+ "model.layers.25.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
429
+ "model.layers.27.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
430
+ "model.layers.35.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
431
+ "model.layers.12.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
432
+ "model.layers.36.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
433
+ "model.layers.36.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
434
+ "model.layers.0.self_attn.adaptive_phi": "model-00001-of-00003.safetensors",
435
+ "model.layers.15.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
436
+ "model.layers.20.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
437
+ "model.layers.20.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
438
+ "model.layers.25.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
439
+ "model.layers.25.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
440
+ "model.layers.15.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
441
+ "model.layers.39.self_attn.adaptive_phi": "model-00003-of-00003.safetensors",
442
+ "model.layers.39.self_attn.adaptive_mu_k": "model-00003-of-00003.safetensors",
443
+ "model.layers.18.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
444
+ "model.layers.18.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
445
+ "model.layers.23.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
446
+ "model.layers.19.self_attn.adaptive_phi": "model-00002-of-00003.safetensors",
447
+ "model.layers.19.self_attn.adaptive_mu_k": "model-00002-of-00003.safetensors",
448
+ "model.layers.26.self_attn.adaptive_phi": "model-00003-of-00003.safetensors"
449
+ }
450
+ }
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/modeling_evabyte.py ADDED
@@ -0,0 +1,912 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Tuple, Union
2
+ import math
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import torch.utils.checkpoint
6
+ from torch import nn
7
+ from torch.nn import CrossEntropyLoss
8
+ from transformers.activations import ACT2FN
9
+ from transformers.cache_utils import Cache
10
+ from transformers.modeling_outputs import (
11
+ BaseModelOutputWithPast,
12
+ CausalLMOutputWithPast,
13
+ )
14
+ from transformers.modeling_utils import PreTrainedModel
15
+
16
+ from .configuration_evabyte import EvaByteConfig
17
+ from .multibyte_decoding_evabyte import MultiByteDecodingMixin
18
+ try:
19
+ import triton
20
+ USE_TRITON_IMPL = True
21
+ from .eva import EvaAttention
22
+ from .eva_agg_kernel import triton_eva_agg_fwd
23
+ from .eva_prep_kv_kernel import triton_eva_prep_kv_fwd
24
+ except ImportError:
25
+ USE_TRITON_IMPL = False
26
+ print("WARNING: triton is not installed, using fallback EVA which might be slow and throw errors")
27
+ from .eva_pt_ref import EvaAttention
28
+ from .eva_cache import EvaCache, EvaStaticCacheForTriton
29
+
30
+ MASK_MIN_VALUE = -10e10
31
+
32
def prepare_eva_attention_mask(
    seq_len,
    device,
    chunk_size,
    window_size,
    use_cache=False,
    cache=None
):
    """
    Prepare attention masks for EVA.

    Returns ``(chunk_causal_mask, window_causal_mask)`` as boolean masks in
    which ``True`` marks a position that must be masked OUT:

    - ``chunk_causal_mask``: shape [1, 1, n, num_chunks]; inter-chunk causal
      mask for the chunk-level attention. Chunk pairs that fall inside the
      same window are also masked out here, since those interactions are
      covered by the exact intra-window attention instead.
    - ``window_causal_mask``: shape [1, 1, 1, window_size, window_size];
      strictly upper-triangular mask for the intra-window attention.

    When ``use_cache`` is True, ``cache`` must expose ``get_seq_length()``;
    the chunk mask is then sliced down to the query rows of the newly
    appended tokens only.
    """
    chunk_causal_mask = None
    window_causal_mask = None
    if use_cache:
        cached_seq_len = cache.get_seq_length()
        total_seq_len = seq_len + cached_seq_len
        # cached_seq_len will be 0 during prefilling
        # padded_seq_len = chunk_size * math.ceil(total_seq_len / chunk_size)
        padded_seq_len = window_size * math.ceil(total_seq_len / window_size)
        num_chunks = padded_seq_len // chunk_size
    else:
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        assert seq_len % chunk_size == 0
        num_chunks = seq_len // chunk_size

        assert seq_len % window_size == 0

    # create causal mask
    ################################
    # generate chunked causal masks
    ################################
    # [b, h, j, c, c]
    chunks_per_window = window_size // chunk_size
    if num_chunks >= chunks_per_window:
        # Start from a causal pattern over chunks (triu incl. diagonal),
        # replicated for every intra-chunk offset.
        chunk_causal_mask = torch.ones(
            (chunk_size, num_chunks, num_chunks),
            device=device,
            dtype=torch.bool
        ).triu(0)

        # Group chunks into windows: [c, blocks, cpw, blocks, cpw].
        num_blocks = num_chunks // chunks_per_window
        chunk_causal_mask = chunk_causal_mask.reshape(
            chunk_size,
            num_blocks,
            chunks_per_window,
            num_blocks,
            chunks_per_window
        ).transpose(-2, -3)

        block_diag_zero = (
            torch.eye(num_blocks, device=device, dtype=torch.bool)
            .unsqueeze(-1)
            .unsqueeze(-1)
            .unsqueeze(0)
        )

        # Mask out (fill with True) the diagonal window blocks: same-window
        # chunk pairs are handled by the exact intra-window attention.
        chunk_causal_mask = chunk_causal_mask.masked_fill(block_diag_zero, True)

        # Reshape back to original size
        chunk_causal_mask = (
            chunk_causal_mask
            .transpose(-2, -3)
            .reshape(chunk_size, num_chunks, num_chunks)
            .transpose(-2, -3)
            .reshape(chunk_size * num_chunks, num_chunks)
            .unsqueeze(0)
            .unsqueeze(0)
        )
    else:
        # Fewer chunks than one window: a plain chunk-causal mask suffices.
        chunk_causal_mask = torch.ones(
            (1, 1, chunk_size, num_chunks, num_chunks),
            device=device,
            dtype=torch.bool,
        ).triu(0).transpose(-2, -3)  # [1, 1, c, j, c]
        chunk_causal_mask = chunk_causal_mask.reshape(
            1, 1, chunk_size * num_chunks, num_chunks
        )  # [1, 1, n, c]

    if use_cache:
        # Keep only the query rows belonging to the freshly appended tokens.
        chunk_causal_mask = chunk_causal_mask[..., cached_seq_len : cached_seq_len + seq_len, :]

    window_causal_mask = torch.ones(
        (1, 1, 1, window_size, window_size),
        device=device
    ).triu(1).to(torch.bool)
    return (chunk_causal_mask, window_causal_mask)
120
+
121
def pad_to_multiple(tensor, multiple, dim=-2, value=0, create_mask=False, left_padding=False):
    """Pad ``tensor`` along ``dim`` up to the next multiple of ``multiple``.

    Args:
        tensor: input tensor; dim 0 is assumed to be the batch dimension
            when ``create_mask`` is requested.
        multiple: target divisor for the padded length.
        dim: dimension to pad; must be given as a negative index.
        value: fill value for the padded entries.
        create_mask: if True, also return a ``(batch, padded_len)`` bool mask
            that is True on real positions and False on padded ones.
        left_padding: pad at the front of ``dim`` instead of the back.

    Returns:
        The padded tensor, or ``(padded_tensor, mask)`` when ``create_mask``.
    """
    assert dim < 0  # only accept ``dim'' index in a reverse manner
    length = int(tensor.shape[dim])
    remainder = (-length) % multiple
    if remainder == 0:
        # Already a multiple: nothing to do (mask, if any, is all-True).
        if not create_mask:
            return tensor
        full_mask = torch.ones(
            (tensor.shape[0], length), dtype=torch.bool, device=tensor.device
        )
        return tensor, full_mask
    # F.pad takes (before, after) pairs starting from the LAST dimension, so
    # prepend zero-pairs for every dimension that comes after ``dim``.
    leading = (0, 0) * (-1 - dim)
    pad_spec = (*leading, remainder, 0) if left_padding else (*leading, 0, remainder)
    padded = F.pad(tensor, pad_spec, value=value)
    if not create_mask:
        return padded
    mask = torch.ones(
        (padded.shape[0], padded.shape[dim]), dtype=torch.bool, device=padded.device
    )
    if left_padding:
        mask[:, :remainder] = False
    else:
        mask[:, -remainder:] = False
    return padded, mask
146
+
147
class EvaByteRMSNorm(nn.Module):
    """RMS layer norm with an optional "1 + weight" (unit-offset) scale.

    Normalization statistics are always computed in float32 (``fp32_ln``) and
    the result is cast back to the input dtype.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.fp32_ln = True
        self.variance_epsilon = config.rms_norm_eps
        self.add_unit_offset = config.norm_add_unit_offset
        # With the unit offset the effective scale is (1 + weight), so a
        # zero-initialized weight corresponds to an identity scale.
        init = torch.zeros if self.add_unit_offset else torch.ones
        self.weight = nn.Parameter(init(config.hidden_size))

    def forward(self, hidden_states):
        compute_dtype = torch.float32 if self.fp32_ln else torch.bfloat16
        x = hidden_states.to(compute_dtype)
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
        x = x * inv_rms
        scale = (1 + self.weight) if self.add_unit_offset else self.weight
        return (scale * x).type_as(hidden_states)
168
+
169
class EvaByteRotaryEmbedding(torch.nn.Module):
    """Standard rotary position embedding with a lazily-grown cos/sin cache."""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        exponents = torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim
        self.register_buffer("inv_freq", 1.0 / (base ** exponents), persistent=False)
        # Pre-build the cache for the configured maximum context length.
        self._set_cos_sin_cache(seq_len=max_position_embeddings,
                                device=self.inv_freq.device,
                                dtype=torch.get_default_dtype())

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        # Rebuild the (seq_len, dim) cos/sin tables for positions [0, seq_len).
        self.max_seq_len_cached = seq_len
        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", positions, self.inv_freq)
        # Duplicate the half-dim frequencies to cover the full head dimension.
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]; only its dtype and
        # device are consulted here.
        if seq_len > self.max_seq_len_cached:
            # Grow the cache on demand for longer sequences.
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        if seq_len < self.max_seq_len_cached:
            cos_slice = self.cos_cached.split(seq_len, dim=0)[0]
            sin_slice = self.sin_cached.split(seq_len, dim=0)[0]
        else:
            cos_slice = self.cos_cached
            sin_slice = self.sin_cached

        return cos_slice.to(dtype=x.dtype), sin_slice.to(dtype=x.dtype)
213
+
214
+
215
+
216
class EvaByteLinearScalingRotaryEmbedding(EvaByteRotaryEmbedding):
    """EvaByteRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        # scaling_factor must exist before the parent constructor builds the cache.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        # Positions are compressed by the scaling factor before the angles are
        # computed, stretching the usable context linearly.
        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) / self.scaling_factor
        freqs = torch.einsum("i,j->ij", positions, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
232
+
233
+
234
class EvaByteDynamicNTKScalingRotaryEmbedding(EvaByteRotaryEmbedding):
    """EvaByteRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        # scaling_factor must exist before the parent constructor builds the cache.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            # NTK-aware rescaling: enlarge the RoPE base once the sequence
            # exceeds the trained context, then rebuild the inverse frequencies.
            growth = (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            base = self.base * growth ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
            self.register_buffer("inv_freq", inv_freq, persistent=False)

        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", positions, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
256
+
257
+
258
class EvaByteMLP(nn.Module):
    """Gated (SwiGLU-style) feed-forward block: ``down(act(gate(x)) * up(x))``."""

    def __init__(self, config, layer_idx: int = None):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]
        self.layer_idx = layer_idx
        self.config = config

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)
273
+
274
class EvaByteDecoderLayer(nn.Module):
    """One transformer block: pre-norm EVA self-attention followed by a pre-norm MLP."""

    def __init__(self, config: EvaByteConfig, layer_idx: int = None):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.self_attn = EvaAttention(config=config, layer_idx=layer_idx)
        self.mlp = EvaByteMLP(config, layer_idx=layer_idx)
        self.input_layernorm = EvaByteRMSNorm(config)
        self.post_attention_layernorm = EvaByteRMSNorm(config)

    def _residual(self, x):
        # Optionally carry the skip connection in float32 for numerical stability.
        return x.float() if self.config.fp32_skip_add else x

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cos: Optional[torch.Tensor] = None,
        sin: Optional[torch.Tensor] = None,
        multibyte_decoding: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        # --- attention sub-block (pre-norm + residual) ---
        skip = self._residual(hidden_states)
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cos=cos,
            sin=sin,
            multibyte_decoding=multibyte_decoding,
        )
        hidden_states = (skip + hidden_states).to(hidden_states.dtype)

        # --- feed-forward sub-block (pre-norm + residual) ---
        skip = self._residual(hidden_states)
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = (skip + hidden_states).to(hidden_states.dtype)

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)
        if use_cache:
            outputs += (present_key_value,)
        return outputs
330
+
331
class EvaBytePreTrainedModel(PreTrainedModel):
    """Glue between EvaByte modules and HF's PreTrainedModel machinery
    (weight init, checkpoint sharding, gradient-checkpointing toggles)."""
    config_class = EvaByteConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["EvaByteDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights with N(0, initializer_range)."""
        std = getattr(self.config, "initializer_range", 0.02)
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            # keep the padding-token embedding at exactly zero
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        # Invoked by HF's gradient_checkpointing_enable()/disable().
        if isinstance(module, EvaByteModel):
            module.gradient_checkpointing = value
352
+
353
class EvaByteModel(EvaBytePreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`EvaByteDecoderLayer`]

    Args:
        config: EvaByteConfig

    Fix applied: ``forward`` (and the private mask helpers / cache setup)
    previously read ``input_ids.shape`` and ``input_ids.device``
    unconditionally, crashing when only ``inputs_embeds`` was supplied even
    though the input validation explicitly allows that. Shapes and device are
    now derived from whichever input is present, and ``inputs_embeds`` is
    threaded into the mask helpers via a backward-compatible keyword.
    """
    def __init__(self, config: EvaByteConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.max_position_embeddings = self.config.max_position_embeddings

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [EvaByteDecoderLayer(config, layer_idx=layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = EvaByteRMSNorm(config)

        self.gradient_checkpointing = False
        self.rope = config.rope_theta
        # Initialize weights and apply final processing
        self.post_init()
        self._init_rope()

    def _init_rope(self):
        """Instantiate the rotary embedding, honoring optional RoPE scaling config."""
        if self.config.rope_scaling is None:
            self.rotary_emb = EvaByteRotaryEmbedding(self.head_dim,
                                                     max_position_embeddings=self.max_position_embeddings,
                                                     base=self.rope)
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = EvaByteLinearScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope)
            elif scaling_type == "dynamic":
                self.rotary_emb = EvaByteDynamicNTKScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope)
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def _helper_padding_mask(
        self,
        padding_mask,
        causal_mask
    ):
        # Make the padding mask symmetric (mask a pair if EITHER side is
        # padding) and merge with the causal mask. True == masked out.
        padding_mask = torch.logical_or(padding_mask, padding_mask.transpose(-1, -2))
        return torch.logical_or(padding_mask, causal_mask)

    def _prepare_eva_generation_attn_mask_triton(
        self,
        attention_mask,
        input_ids,
        use_cache,
        past_key_values,
        inputs_embeds=None,
    ):
        """Build padding/causal masks for the Triton EVA kernels.

        Returns ``(s_mask, rf_mask, rfa_chunks_dummy_mask)`` during incremental
        decoding and ``(s_mask, rf_mask)`` during prefill; the attention
        implementation distinguishes the two cases. True == masked out.
        """
        # Fix: derive shape/device from whichever input was provided.
        ref = input_ids if input_ids is not None else inputs_embeds
        batch_size, seq_len = ref.shape[:2]
        device = ref.device
        if use_cache and past_key_values.get_seq_length() > 0:
            # decoding phase
            if past_key_values.rf_mask[0] is not None:
                cur_rf_mask = torch.zeros(
                    (batch_size, 1, seq_len, 1),
                    dtype=past_key_values.rf_mask[0].dtype,
                    device=past_key_values.rf_mask[0].device
                )
            else:
                cur_rf_mask = None

            if past_key_values.s_mask[0] is not None:
                cur_s_mask = torch.zeros(
                    (batch_size, 1, seq_len, 1),
                    dtype=past_key_values.s_mask[0].dtype,
                    device=past_key_values.s_mask[0].device
                )
            else:
                cur_s_mask = None

            seen_tokens = past_key_values.get_seq_length()
            if seen_tokens <= self.config.window_size:
                rfa_chunks_dummy_mask = None
            else:
                if cur_s_mask is not None:
                    chunks_per_window = int(self.config.window_size // self.config.chunk_size)
                    # the ongoing decoding step would be (seen_seq_len + 1)-th token
                    num_windows_seen_so_far = seen_tokens // self.config.window_size
                    rfa_chunks_dummy_mask = torch.zeros(
                        (batch_size, 1, seq_len, num_windows_seen_so_far * chunks_per_window),
                        dtype=past_key_values.s_mask[0].dtype,
                        device=past_key_values.s_mask[0].device
                    )
                else:
                    rfa_chunks_dummy_mask = None
            # rf_mask and cur_mask are 0s because we do not want to mask them
            return (cur_s_mask, cur_rf_mask, rfa_chunks_dummy_mask)

        if attention_mask is not None and torch.any(attention_mask == 0.0):
            # Prefill with real padding present: build window-level masks.
            padded_attention_mask = pad_to_multiple(
                attention_mask,
                self.config.window_size,
                dim=-1,
                value=0,
                create_mask=False,
                left_padding=False
            )
            # convert 0 -> padding to 1 -> padding
            padded_rf_mask = ~padded_attention_mask.unsqueeze(1).unsqueeze(-1).to(torch.bool)  # [b, 1, n, 1]
            # [b, 1, w, j, 1]
            padded_w_attn_mask = padded_rf_mask.reshape(batch_size, 1, -1, self.config.window_size, 1).to(torch.bool)
            # [b, 1, w, j, 1] [b, 1, w, 1, j] -> [b, 1, w, j, j]
            w_padding_mask = torch.logical_or(padded_w_attn_mask, padded_w_attn_mask.transpose(-1, -2))
            w_causal_mask = torch.ones(
                (1, 1, 1, self.config.window_size, self.config.window_size),
                device=device
            ).triu(1).to(torch.bool)
            s_mask = torch.logical_or(w_padding_mask, w_causal_mask)
            s_mask = s_mask.reshape(batch_size, 1, -1, self.config.window_size)
            s_mask = s_mask[..., :seq_len, :]
            # negate the attention mask to get the padding mask
            rf_mask = ~attention_mask.unsqueeze(1).unsqueeze(-1).to(torch.bool)  # [b, 1, n, 1]
            return (s_mask, rf_mask)
        else:
            return (None, None)

    def _prepare_eva_generation_attn_mask(
        self,
        attention_mask,
        input_ids,
        use_cache,
        past_key_values,
        inputs_embeds=None,
    ):
        """Build masks for the pure-PyTorch EVA reference attention.

        Returns ``(prev_window_mask, cur_window_mask, chunk_causal_mask,
        rf_mask)``; True == masked out.
        """
        # Fix: derive shape/device from whichever input was provided.
        ref = input_ids if input_ids is not None else inputs_embeds
        batch_size, seq_len = ref.shape[:2]
        device = ref.device
        if use_cache and past_key_values.get_seq_length() > 0:
            # decoding phase
            if past_key_values.rf_mask[0] is not None:
                rf_mask = torch.zeros(
                    (batch_size, 1, seq_len, 1),
                    dtype=past_key_values.rf_mask[0].dtype,
                    device=past_key_values.rf_mask[0].device
                )
            else:
                rf_mask = None

            cur_causal_mask = torch.zeros(
                (batch_size, 1, seq_len, 1),
                dtype=torch.bool,
                device=device
            )

            chunk_causal_mask = torch.ones(
                (batch_size, 1, seq_len, 1),
                dtype=torch.bool,
                device=device
            )
            # chunk_causal_mask are 1s because we will mask them by default and
            # will be unmasked when the current singleton attention is processed over
            return (None, cur_causal_mask, chunk_causal_mask, rf_mask)

        true_num_chunks = seq_len // self.config.chunk_size
        chunk_causal_mask, _ = prepare_eva_attention_mask(
            seq_len,
            device,
            self.config.chunk_size,
            self.config.window_size,
            use_cache=use_cache,
            cache=past_key_values
        )
        chunk_causal_mask = chunk_causal_mask[..., :seq_len, :true_num_chunks]
        if attention_mask is not None and torch.any(attention_mask == 0.0):
            # convert 0 -> padding to 1 -> padding
            rf_mask = ~attention_mask.unsqueeze(1).unsqueeze(-1).to(torch.bool)  # [b, 1, n, 1]
        else:
            rf_mask = None

        if seq_len < self.config.window_size:
            # Whole prompt fits inside one (partial) window.
            cur_window_mask = torch.ones(
                (1, 1, seq_len, seq_len),
                device=device
            ).triu(1).to(torch.bool)
            if rf_mask is not None:
                cur_window_mask = self._helper_padding_mask(rf_mask, cur_window_mask)
            prev_window_mask = None
        else:
            if seq_len % self.config.window_size == 0:
                # Exact number of full windows; no trailing partial window.
                num_windows = seq_len // self.config.window_size
                cur_window_mask = None
                prev_window_mask = torch.ones(
                    (1, 1, num_windows, self.config.window_size, self.config.window_size),
                    device=device
                ).triu(1).to(torch.bool)
                if rf_mask is not None:
                    prev_rf_mask = rf_mask.reshape(batch_size, 1, -1, self.config.window_size, 1)
                    prev_window_mask = self._helper_padding_mask(prev_rf_mask, prev_window_mask)
            else:
                # Full windows plus a trailing partial window.
                num_windows = seq_len // self.config.window_size
                remainder_tokens = seq_len % self.config.window_size
                cur_window_mask = torch.ones(
                    (1, 1, remainder_tokens, remainder_tokens),
                    device=device
                ).triu(1).to(torch.bool)
                prev_window_mask = torch.ones(
                    (1, 1, num_windows, self.config.window_size, self.config.window_size),
                    device=device
                ).triu(1).to(torch.bool)
                if rf_mask is not None:
                    prev_rf_mask, cur_rf_mask = torch.split(rf_mask, [seq_len - remainder_tokens, remainder_tokens], dim=-2)
                    cur_window_mask = self._helper_padding_mask(cur_rf_mask, cur_window_mask)
                    prev_rf_mask = prev_rf_mask.reshape(batch_size, 1, -1, self.config.window_size, 1)
                    prev_window_mask = self._helper_padding_mask(prev_rf_mask, prev_window_mask)

        return (prev_window_mask, cur_window_mask, chunk_causal_mask, rf_mask)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        multibyte_decoding: Optional[bool] = None,
    ) -> Tuple:
        """Run the decoder stack. Exactly one of ``input_ids`` /
        ``inputs_embeds`` must be provided."""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else self.config.output_hidden_states)
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # XOR check: raises when both or neither of the inputs are given.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if self.gradient_checkpointing and self.training and use_cache:
            raise ValueError("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")

        # Fix: previously read input_ids.shape unconditionally, which crashed
        # on the inputs_embeds-only path permitted by the check above.
        if input_ids is not None:
            batch_size, seq_len = input_ids.shape
            device = input_ids.device
        else:
            batch_size, seq_len = inputs_embeds.shape[:2]
            device = inputs_embeds.device

        #### Step 0. Hack
        if (not self.training) and (not use_cache) and (not multibyte_decoding):
            # forward-only inference mode.
            # We tweak use_cache to be True to reuse code for generation
            use_cache = True
        if position_ids is None:
            # NOTE(review): this default (positions 0..seq_len) shadows the
            # offset-aware default further below, so cached decoding must pass
            # explicit position_ids — consistent with the assert down there.
            position_ids = torch.arange(0, seq_len, device=device, dtype=int).reshape(1, -1).expand(batch_size, -1)

        #### Step 1. Prepare caches if in inference mode
        if use_cache:
            if past_key_values is not None:
                assert isinstance(past_key_values, Cache)
            else:
                if not USE_TRITON_IMPL:
                    past_key_values = EvaCache()
                else:
                    past_key_values = EvaStaticCacheForTriton(
                        batch_size,
                        self.config.num_attention_heads,
                        self.config.window_size,
                        self.config.hidden_size // self.config.num_attention_heads,
                        self.config.num_hidden_layers,
                        self.embed_tokens.weight.dtype,
                        self.embed_tokens.weight.device,
                    )

        #### Step 2. Build the attention masks for this call.
        if not multibyte_decoding:
            if use_cache:
                if USE_TRITON_IMPL:
                    causal_mask = self._prepare_eva_generation_attn_mask_triton(
                        attention_mask,
                        input_ids,
                        use_cache,
                        past_key_values,
                        inputs_embeds=inputs_embeds,
                    )
                else:
                    causal_mask = self._prepare_eva_generation_attn_mask(
                        attention_mask,
                        input_ids,
                        use_cache,
                        past_key_values,
                        inputs_embeds=inputs_embeds,
                    )
            else:
                assert self.training
                assert seq_len % self.config.window_size == 0, "Training is only tested for sequences that are a multiple of window_size"
                # for training, we need to pass in the attention mask
                # usually calculated by _prepare_training_attn_mask()
                causal_mask = attention_mask
        else:
            assert use_cache
            causal_mask = attention_mask

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        max_seq_length = past_seen_tokens + inputs_embeds.shape[1]

        hidden_states = inputs_embeds

        if position_ids is None:
            assert not use_cache, "during decoding we must explicitly pass position_ids to the model call"
            position_ids = torch.arange(past_seen_tokens, max_seq_length, device=device, dtype=int).reshape(1, -1).expand(batch_size, -1)

        #### Step 3. Gather rotary tables for the positions actually used.
        cos, sin = self.rotary_emb(hidden_states, seq_len=max_seq_length)
        assert len(cos.shape) == 2, f"cos should be of shape (max_seq_len, head_dim), got {cos.shape} instead"
        assert sin.shape == cos.shape, f"sin should be of shape (max_seq_len, head_dim), got {sin.shape} instead"
        assert len(position_ids.shape) == 2, f"position_ids should be of 2D, got {position_ids.shape} instead"
        cos = cos[position_ids, :]
        sin = sin[position_ids, :]
        cos = cos.unsqueeze(1)
        sin = sin.unsqueeze(1)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states, )

            if self.gradient_checkpointing and self.training:
                layer_outputs = torch.utils.checkpoint.checkpoint(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cos,
                    sin,
                    multibyte_decoding,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cos=cos,
                    sin=sin,
                    multibyte_decoding=multibyte_decoding,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1], )

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states, )

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
741
+
742
+
743
+ class EvaByteForCausalLM(EvaBytePreTrainedModel, MultiByteDecodingMixin):
744
+ _tied_weights_keys = ["lm_head.weight"]
745
+
746
    def __init__(self, config):
        # Explicit base-class call (rather than super()) because of the
        # MultiByteDecodingMixin co-parent in the MRO.
        EvaBytePreTrainedModel.__init__(self, config)

        self.model = EvaByteModel(config)
        self.vocab_size = config.vocab_size
        # define multibyte prediction heads
        # With num_pred_heads > 1 the LM head predicts several future bytes at
        # once; logits are laid out as [..., num_pred_heads * vocab_size].
        if hasattr(config, "num_pred_heads") and config.num_pred_heads > 1:
            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size * config.num_pred_heads, bias=False)
        else:
            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.post_init()
758
+
759
    def get_input_embeddings(self):
        # Delegate to the backbone's byte-embedding table.
        return self.model.embed_tokens
761
+
762
    def set_input_embeddings(self, value):
        # Replace the backbone's byte-embedding table.
        self.model.embed_tokens = value
764
+
765
    def get_output_embeddings(self):
        # The (possibly multi-head) LM projection.
        return self.lm_head
767
+
768
    def set_output_embeddings(self, new_embeddings):
        # Replace the (possibly multi-head) LM projection.
        self.lm_head = new_embeddings
770
+
771
    def set_decoder(self, decoder):
        # Swap out the decoder backbone.
        self.model = decoder
773
+
774
    def get_decoder(self):
        # The decoder backbone (EvaByteModel).
        return self.model
776
+
777
+ def forward(
778
+ self,
779
+ input_ids: torch.LongTensor = None,
780
+ attention_mask: Optional[torch.Tensor] = None,
781
+ position_ids: Optional[torch.LongTensor] = None,
782
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
783
+ inputs_embeds: Optional[torch.FloatTensor] = None,
784
+ labels: Optional[torch.LongTensor] = None,
785
+ use_cache: Optional[bool] = None,
786
+ output_attentions: Optional[bool] = None,
787
+ output_hidden_states: Optional[bool] = None,
788
+ return_dict: Optional[bool] = None,
789
+ return_all_pred_logits: Optional[bool] = None,
790
+ multibyte_decoding: Optional[bool] = None) -> Union[Tuple, CausalLMOutputWithPast]:
791
+
792
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
793
+ output_hidden_states = (output_hidden_states
794
+ if output_hidden_states is not None else self.config.output_hidden_states)
795
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
796
+
797
+ if input_ids is None:
798
+ assert past_key_values is None
799
+
800
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
801
+ outputs = self.model(
802
+ input_ids=input_ids,
803
+ attention_mask=attention_mask,
804
+ position_ids=position_ids,
805
+ past_key_values=past_key_values,
806
+ inputs_embeds=inputs_embeds,
807
+ use_cache=use_cache,
808
+ output_attentions=output_attentions,
809
+ output_hidden_states=output_hidden_states,
810
+ return_dict=return_dict,
811
+ multibyte_decoding=multibyte_decoding,
812
+ )
813
+
814
+ hidden_states = outputs[0]
815
+
816
+ logits = self.lm_head(hidden_states)
817
+ if self.config.fp32_logits:
818
+ logits = logits.float()
819
+
820
+ loss = None
821
+ if labels is not None:
822
+ loss_fct = CrossEntropyLoss(reduction="none")
823
+ if hasattr(self.config, "num_pred_heads") and self.config.num_pred_heads > 1:
824
+ shift_logits = logits.view(logits.shape[0], logits.shape[1], self.config.num_pred_heads, self.config.vocab_size)
825
+ # shift_logits = shift_logits.view(-1, logits.shape[1] * self.config.num_pred_heads, self.config.vocab_size)
826
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
827
+ else:
828
+ shift_logits = logits.view(-1, self.config.vocab_size)
829
+ shift_labels = labels.view(-1)
830
+ # Enable model parallelism
831
+ shift_labels = shift_labels.to(shift_logits.device)
832
+ loss = loss_fct(shift_logits, shift_labels)
833
+
834
+ if hasattr(self.config, "num_pred_heads") and self.config.num_pred_heads > 1:
835
+ all_pred_logits = logits.reshape(logits.shape[0], logits.shape[1], self.config.num_pred_heads, self.config.vocab_size)
836
+
837
+ if return_all_pred_logits:
838
+ logits = all_pred_logits
839
+ else:
840
+ logits = all_pred_logits[..., 0, :]
841
+
842
+ if not return_dict:
843
+ output = (logits, ) + outputs[1:]
844
+ return (loss, ) + output if loss is not None else output
845
+
846
+ return CausalLMOutputWithPast(
847
+ loss=loss,
848
+ logits=logits,
849
+ past_key_values=outputs.past_key_values,
850
+ hidden_states=outputs.hidden_states,
851
+ attentions=outputs.attentions,
852
+ )
853
+
854
+
855
+ def prepare_inputs_for_generation(self,
856
+ input_ids,
857
+ past_key_values=None,
858
+ attention_mask=None,
859
+ inputs_embeds=None,
860
+ use_cache=True,
861
+ **kwargs):
862
+ # prefill phase:
863
+ # input_ids: b x s
864
+ # attention_mask: None if no padding or b x s
865
+ # position_ids : b x s
866
+
867
+ # token gen phase:
868
+ # input_ids : b x 1
869
+ # attention_mask: b x 1 x s
870
+ # position_ids: b x 1
871
+ past_length = 0
872
+ if past_key_values is not None:
873
+ assert isinstance(past_key_values, Cache)
874
+ past_length = past_key_values.get_seq_length()
875
+
876
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
877
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
878
+ elif past_length < input_ids.shape[1]:
879
+ input_ids = input_ids[:, past_length:]
880
+
881
+ position_ids = kwargs.get("position_ids", None)
882
+ if attention_mask is not None and position_ids is None:
883
+ position_ids = attention_mask.long().cumsum(-1) - 1
884
+ position_ids.masked_fill_(attention_mask == 0, 1)
885
+ if past_key_values:
886
+ position_ids = position_ids[:, -input_ids.shape[1]:]
887
+
888
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
889
+ if inputs_embeds is not None and past_key_values is None:
890
+ model_inputs = {"inputs_embeds": inputs_embeds}
891
+ else:
892
+ model_inputs = {"input_ids": input_ids}
893
+
894
+ # must initialize position_ids at each step during GPU inference
895
+ assert position_ids is not None
896
+ model_inputs.update(
897
+ {
898
+ "position_ids": position_ids,
899
+ "past_key_values": past_key_values,
900
+ "use_cache": use_cache,
901
+ "attention_mask": attention_mask,
902
+ }
903
+ )
904
+ return model_inputs
905
+
906
+ @staticmethod
907
+ def _reorder_cache(past_key_values, beam_idx):
908
+ reordered_past = ()
909
+ for layer_past in past_key_values:
910
+ reordered_past += (tuple(
911
+ past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), )
912
+ return reordered_past
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/multibyte_decoding_evabyte.py ADDED
@@ -0,0 +1,881 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # The implementation of multibyte deocidng is largely adapted from
3
+ # Medusa decoding: https://github.com/FasterDecoding/Medusa
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from transformers.generation.stopping_criteria import (
7
+ MaxLengthCriteria,
8
+ StoppingCriteriaList,
9
+ )
10
+ from typing import Union, List
11
+ from .eva_cache import EvaStaticCacheForTriton
12
+ from .eva_prep_kv_kernel import triton_eva_prep_kv_fwd
13
+
14
class MultibyteEosTokenCriteria:
    """Stop generation once an end-of-sequence token shows up among the last
    `new_tokens` generated tokens.

    Adapted from
    https://github.com/huggingface/transformers/blob/main/src/transformers/generation/stopping_criteria.py#L446
    Unlike a per-step check, this scans the whole freshly appended suffix,
    which is needed when several tokens are committed per decoding step.

    Args:
        eos_token_ids (`Union[int, List[int]]`):
            The id(s) of the *end-of-sequence* token.
    """

    def __init__(self, eos_token_ids: Union[int, List[int]]):
        # normalize a single id to a one-element list
        self.eos_token_ids = [eos_token_ids] if isinstance(eos_token_ids, int) else eos_token_ids

    def __call__(self, input_ids: torch.LongTensor, new_tokens: int) -> bool:
        # only the suffix appended in the latest step needs inspection
        suffix_start = input_ids.shape[-1] - new_tokens
        suffix = input_ids[:, suffix_start:]
        return any(bool(torch.any(suffix == eos_id)) for eos_id in self.eos_token_ids)
40
+
41
def build_tree(spec):
    """Expand a per-depth fan-out spec into a flat list of tree-node paths.

    `spec[d][i]` is the number of children given to the i-th node (in creation
    order) at depth d; nodes with no entry in the list get zero children.
    Each node is encoded as the tuple of child indices on the path from the
    root. The root (empty tuple) is excluded from the returned list, which is
    ordered depth by depth.
    """
    levels = [[()]]  # depth 0 holds only the root
    for fanouts in spec:
        frontier = []
        for pos, parent in enumerate(levels[-1]):
            width = fanouts[pos] if pos < len(fanouts) else 0
            frontier.extend(parent + (child,) for child in range(width))
        levels.append(frontier)
    # flatten, dropping the falsy root tuple
    return [node for level in levels for node in level if node]
62
+
63
# Pre-built speculative-decoding trees (lists of node paths from `build_tree`).
# Each inner list gives, per depth, the fan-out of every node created so far.
# The suffixes presumably denote the total node count of each tree (95 / 31
# nodes) — TODO confirm.
evabyte_7b_95 = build_tree(
    [
        [10],
        [10, 8, 2, 2, 1, 1],
        [10, 4, 2, 1, 0, 0, 0, 0, 0, 0, 2, 1, 1, 0, 0, 0, 0, 0, 1],
        [8, 2, 2, 1, 0, 0, 0, 0, 0, 0, 1],
        [6, 2, 1, 1],
        [4, 2, 1, 1],
        [4, 2, 1],
    ]
)
# Smaller, cheaper tree for the same model family.
evabyte_7b_31 = build_tree(
    [
        [4],
        [3, 2, 1, 1],
        [3, 2, 1, 1],
        [2, 1, 1],
        [2, 1],
        [2, 1],
        [2, 1],
    ]
)
TOPK = 10 # topk for sparse tree (10 is a placeholder and it is sufficient)
86
+
87
def pad_path(path, length, pad_value=-2):
    """Return a copy of `path` right-padded with `pad_value` up to `length`.

    Parameters:
    - path (list): The original list that needs padding.
    - length (int): The desired length of the padded list.
    - pad_value (optional, default=-2): The value to use for padding.

    Returns:
    - list: A new list based on the original path, padded to `length`.

    Example:
    >>> pad_path([1,2,3], 5)
    [1, 2, 3, -2, -2]

    Note:
    If the given path is already at least `length` long, no padding occurs
    and an unmodified copy is returned — there is never any truncation.
    """
    padded = list(path)
    padded.extend([pad_value] * (length - len(padded)))
    return padded
108
+
109
def reset_past_key_values(passed_key_values):
    """Zero the `current_length` counter of each layer's cached key and value.

    Used when evaluating a baseline model: rather than reallocating the cache,
    the per-layer length counters are reset in place so the buffers can be
    reused from scratch.

    Args:
    - passed_key_values: per-layer pairs of cache objects, each exposing a
      mutable `current_length` tensor.

    Returns:
    - The same `passed_key_values` object, mutated in place.
    """
    for layer_cache in passed_key_values:
        for kv in (layer_cache[0], layer_cache[1]):
            kv.current_length.fill_(0)
    return passed_key_values
127
+
128
def get_nucleus_one_token(logit, temperature, top_p):
    """Sample one token per row using nucleus (top-p) sampling.

    Args:
        logit (torch.Tensor): 2D logits (B x C); the caller's tensor is not
            modified (scaling creates a copy).
        temperature (float): softmax temperature; higher is more diverse.
        top_p (float): cumulative-probability cutoff; `top_p >= 1` degenerates
            to plain temperature sampling.

    Returns:
        torch.Tensor: sampled token indices, shape (B, 1).
    """
    if top_p >= 1:
        return torch.multinomial(F.softmax(logit / temperature, dim=-1), 1)
    scaled = logit / temperature
    probs = torch.softmax(scaled, dim=-1)
    ranked_probs, ranked_idx = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(ranked_probs, dim=-1)
    # drop everything after the cumulative mass first exceeds top_p,
    # but always keep the single most likely token
    drop_ranked = cumulative > top_p
    drop_ranked[..., 1:] = drop_ranked[..., :-1].clone()
    drop_ranked[..., 0] = 0
    # map the drop mask back from sorted order to vocabulary order
    drop = drop_ranked.scatter(dim=1, index=ranked_idx, src=drop_ranked)
    scaled[drop] = float('-inf')
    return torch.multinomial(F.softmax(scaled, dim=-1), 1)
158
+
159
def get_typical_one_token(logit, temperature, posterior_threshold, posterior_alpha):
    """Sample one token per row using typical sampling.

    Tokens whose probability falls below an entropy-adaptive cutoff
    `min(posterior_threshold, exp(-H) * posterior_alpha)` are masked out
    before sampling, trading off diversity against likelihood.

    Args:
        logit (torch.Tensor): 2D logits; the caller's tensor is left intact.
        temperature (float): softmax temperature.
        posterior_threshold (float): hard lower bound on the cutoff.
        posterior_alpha (float): scale of the entropy-based adaptive cutoff.

    Returns:
        torch.Tensor: sampled token indices, shape (B, 1).
    """
    scaled = logit / temperature
    probs = torch.softmax(scaled, dim=-1)
    # Shannon entropy per row (epsilon keeps log finite at p == 0)
    entropy = -torch.sum(probs * torch.log(probs + 1e-5), dim=-1)
    cutoff = torch.minimum(
        torch.ones_like(entropy) * posterior_threshold,
        torch.exp(-entropy) * posterior_alpha,
    )
    scaled[probs < cutoff.unsqueeze(-1)] = float('-inf')
    return torch.multinomial(F.softmax(scaled, dim=-1), 1)
189
+
190
+
191
+
192
def generate_medusa_buffers(medusa_choices, device="cuda"):
    """
    Generate buffers for the Medusa structure based on the provided choices.

    Parameters:
    - medusa_choices (list): A nested list representing tree in the Medusa structure;
      each element is a path of per-depth child indices (as built by `build_tree`).
    - device (str): Device to which the tensors should be moved. Default is "cuda".

    Returns:
    - dict: A dictionary containing buffers related to the Medusa structure:
      * "medusa_attn_mask": (1, 1, L, L) mask letting each node attend only to
        its ancestors (and the root), L = number of nodes + 1.
      * "tree_indices": flat indices into the [root; head-wise top-K] candidate
        pool for every tree node.
      * "medusa_position_ids": (1, L) depth of each node, used as position ids.
      * "retrieve_indices": per-root-to-leaf-path node indices, padded, with a
        leading 0 column for the root.
    """

    # Sort the medusa_choices based on their lengths and then their values
    sorted_medusa_choices = sorted(medusa_choices, key=lambda x: (len(x), x))
    medusa_len = len(sorted_medusa_choices) + 1  # +1 for the root slot

    # Initialize depth_counts to keep track of how many choices have a particular depth
    depth_counts = [0] * max([len(path) for path in sorted_medusa_choices])
    for path in sorted_medusa_choices:
        depth_counts[len(path) - 1] += 1

    # Create the attention mask for Medusa: identity (self) plus root column,
    # then each node additionally attends to every node on its ancestor path.
    medusa_attn_mask = torch.eye(medusa_len, medusa_len)
    medusa_attn_mask[:, 0] = 1
    start = 0
    for i in range(len(depth_counts)):
        for j in range(depth_counts[i]):
            cur_medusa_choice = sorted_medusa_choices[start + j]
            # retrieve ancestor position
            if len(cur_medusa_choice) == 1:
                # depth-1 nodes have only the root as ancestor (already set)
                continue
            ancestor_idx = []
            for c in range(len(cur_medusa_choice) - 1):
                # +1 shifts past the root slot at index 0
                ancestor_idx.append(sorted_medusa_choices.index(cur_medusa_choice[:c+1]) + 1)
            medusa_attn_mask[j + start + 1, ancestor_idx] = 1
        start += depth_counts[i]

    # Generate tree indices for the Medusa structure: node at depth i+1 with
    # child index k maps to slot k + TOPK*i + 1 of the flattened candidate pool
    # (slot 0 is the base-model token; each head contributes TOPK slots).
    medusa_tree_indices = torch.zeros(medusa_len, dtype=torch.long)
    medusa_tree_indices[0] = 0
    start = 0
    for i in range(len(depth_counts)):
        for j in range(depth_counts[i]):
            cur_medusa_choice = sorted_medusa_choices[start + j]
            medusa_tree_indices[start + j + 1] = cur_medusa_choice[-1] + TOPK * i + 1
        start += depth_counts[i]

    # Generate position IDs for the Medusa structure (position == tree depth)
    medusa_position_ids = torch.zeros(medusa_len, dtype=torch.long)
    start = 0
    for i in range(len(depth_counts)):
        medusa_position_ids[start + 1: start + depth_counts[i] + 1] = i + 1
        start += depth_counts[i]

    # Generate retrieval indices for Medusa structure verification: iterate
    # leaves (longest paths first) and record the node indices along each
    # root-to-leaf path exactly once.
    retrieve_indices_nest = []
    retrieve_paths = []
    for i in range(len(sorted_medusa_choices)):
        cur_medusa_choice = sorted_medusa_choices[-i-1]
        retrieve_indice = []
        if cur_medusa_choice in retrieve_paths:
            continue
        else:
            for c in range(len(cur_medusa_choice)):
                retrieve_indice.append(sorted_medusa_choices.index(cur_medusa_choice[:c+1]))
                retrieve_paths.append(cur_medusa_choice[:c+1])
        retrieve_indices_nest.append(retrieve_indice)
    max_length = max([len(x) for x in retrieve_indices_nest])
    # pad shorter paths with -2, shift by +1 for the root, and prepend the
    # root column (index 0); padded entries become -1, i.e. the last element
    # of the (sentinel-extended) candidate sequence at use time.
    retrieve_indices = [pad_path(path, max_length) for path in retrieve_indices_nest]
    retrieve_indices = torch.tensor(retrieve_indices, dtype=torch.long)
    retrieve_indices = retrieve_indices + 1
    retrieve_indices = torch.cat([torch.zeros((retrieve_indices.shape[0], 1), dtype=torch.long), retrieve_indices], dim=1)

    # Aggregate the generated buffers into a dictionary
    medusa_buffers = {
        "medusa_attn_mask": medusa_attn_mask.unsqueeze(0).unsqueeze(0),
        "tree_indices": medusa_tree_indices,
        "medusa_position_ids": medusa_position_ids.unsqueeze(0),
        "retrieve_indices": retrieve_indices,
    }

    # Move the tensors in the dictionary to the specified device
    medusa_buffers = {
        k: v.clone().to(device)
        if isinstance(v, torch.Tensor)
        else torch.tensor(v, device=device)
        for k, v in medusa_buffers.items()
    }
    return medusa_buffers
281
+
282
def generate_candidates(
    medusa_logits,
    logits,
    tree_indices,
    retrieve_indices,
    temperature = 0,
    posterior_threshold=0.3,
    posterior_alpha = 0.09,
    top_p=0.8,
    sampling = 'typical',
    fast = False
):
    """Build the flattened candidate token tree for one speculative step.

    The base model contributes a single root token (greedy or sampled); each
    extra prediction head contributes its top-K token ids. `tree_indices`
    gathers this pool into the flat tree layout, and `retrieve_indices`
    unflattens it into one row per root-to-leaf path.

    Example with 3 heads whose top-4 ids are
        [10, 3, 8, 4] / [9, 5, 1, 6] / [7, 16, 3, 2]:
    the pool is [10, 9, 5, 1, 6, 7, 16, 3, 2, ...] and, picking top-2 / top-3
    for the two extra heads, the per-path view looks like
        [[10, 9, 7], [10, 9, 16], [10, 9, 3], [10, 5, 7], ...].

    Returns:
        (tree_candidate_ids, unflattened_candidate_ids): the flat tree token
        ids with a leading batch dim, and the per-path candidate matrix.
    """
    last_logits = logits[:, -1]
    # root token from the base head
    if temperature == 0 or fast:
        root_ids = torch.argmax(last_logits).unsqueeze(0)
    elif sampling == 'typical':
        root_ids = get_typical_one_token(last_logits, temperature, posterior_threshold, posterior_alpha).squeeze(0)
    elif sampling == 'nucleus':
        root_ids = get_nucleus_one_token(last_logits, temperature, top_p).squeeze(0)
    else:
        raise NotImplementedError

    # top-K continuations from every extra prediction head
    head_topk_ids = torch.topk(medusa_logits[:, 0, -1], TOPK, dim=-1).indices

    # candidate pool: [root, head-1 top-K, head-2 top-K, ...]
    pooled_ids = torch.cat([root_ids, head_topk_ids.view(-1)], dim=-1)

    # lay the pool out according to the pre-defined tree structure
    flat_tree_ids = pooled_ids[tree_indices]

    # append a sentinel 0 so padded retrieve indices (-1 after shifting) read
    # a harmless dummy token instead of wrapping around
    padded_tree_ids = torch.cat(
        [
            flat_tree_ids,
            torch.zeros((1), dtype=torch.long, device=flat_tree_ids.device)
        ],
        dim=0
    )
    per_path_ids = padded_tree_ids[retrieve_indices]

    return flat_tree_ids.unsqueeze(0), per_path_ids
341
+
342
def get_nucleus_posterior_mask(logits, candidates, temperature, top_p):
    """Verify candidate tokens against nucleus (top-p) samples.

    For every position of every candidate path (the last position is excluded
    — it has no prediction to check against), one token is drawn via nucleus
    sampling and compared with the proposed candidate token.

    Args:
        logits (torch.Tensor): (n_candidates, seq, vocab) logits.
        candidates (torch.Tensor): (n_candidates, seq) proposed token ids.
        temperature (float): softmax temperature.
        top_p (float): cumulative-probability cutoff; `>= 1` disables filtering.

    Returns:
        torch.Tensor: int mask (n_candidates, seq-1); 1 where the candidate
        matches the sampled token.
    """
    # adapted from https://github.com/huggingface/transformers/blob/18a879f47576822aa1a5c49aecb27d89bfa5fa69/examples/run_generation.py#L79

    # drop the last position and flatten to (n_candidates*positions, vocab)
    scaled = logits[:, :-1] / temperature
    n_cand, n_pos = scaled.shape[0], scaled.shape[1]
    flat = scaled.reshape(n_cand * n_pos, -1)

    if top_p >= 1:
        draws = torch.multinomial(F.softmax(flat, dim=-1), 1).view(n_cand, n_pos)
        return (candidates[:, 1:] == draws).int()

    probs = F.softmax(flat, dim=-1)
    ranked_probs, ranked_idx = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(ranked_probs, dim=-1)

    # mask everything outside the top-p nucleus, always keeping the top token
    drop_ranked = cumulative > top_p
    drop_ranked[..., 1:] = drop_ranked[..., :-1].clone()
    drop_ranked[..., 0] = 0
    drop = drop_ranked.scatter(dim=1, index=ranked_idx, src=drop_ranked)

    flat[drop] = float('-inf')
    draws = torch.multinomial(F.softmax(flat, dim=-1), 1).view(n_cand, n_pos)
    return (candidates[:, 1:] == draws).int()
395
+
396
def get_typical_posterior_mask(logits, candidates, temperature, posterior_threshold, posterior_alpha):
    """Verify candidate tokens against typical-sampling draws.

    Args:
        logits (torch.Tensor): (n_candidates, seq, vocab) logits; the final
            position has nothing to verify and is dropped.
        candidates (torch.Tensor): (n_candidates, seq) proposed token ids.
        temperature (float): softmax temperature.
        posterior_threshold (float): hard lower bound on the adaptive cutoff.
        posterior_alpha (float): scale of the entropy-based adaptive cutoff.

    Returns:
        torch.Tensor: int mask (n_candidates, seq-1); 1 where the candidate
        matches the sampled token.
    """
    scaled = logits[:, :-1] / temperature
    n_cand, n_pos = scaled.shape[0], scaled.shape[1]
    flat = scaled.reshape(n_cand * n_pos, -1)
    probs = F.softmax(flat, dim=-1)
    # entropy-adaptive probability cutoff, as in get_typical_one_token
    entropy = -torch.sum(probs * torch.log(probs + 1e-5), dim=-1)
    cutoff = torch.minimum(
        torch.ones_like(entropy) * posterior_threshold,
        torch.exp(-entropy) * posterior_alpha,
    )
    flat[probs < cutoff.unsqueeze(-1)] = float('-inf')
    draws = torch.multinomial(F.softmax(flat, dim=-1), 1).view(n_cand, n_pos)
    return (candidates[:, 1:] == draws).int()
425
+
426
+
427
+
428
def evaluate_posterior(
    logits,
    candidates,
    temperature,
    posterior_threshold=0.3,
    posterior_alpha=0.09,
    top_p=0.8,
    sampling='typical',
    fast=True
):
    """Pick the best candidate path and how many of its tokens to accept.

    A candidate's accept length is the length of its longest verified prefix
    (cumulative product of per-position accept flags). Verification is greedy
    when `temperature == 0`, otherwise typical or nucleus sampling based.

    Returns:
        (best_candidate, accept_length): a long tensor index into `candidates`
        and a Python int; candidate 0 is returned when nothing is accepted.
    """
    # nothing to verify beyond the root token
    if logits.shape[1] <= 1:
        return torch.tensor(0, dtype=torch.long, device=candidates.device), 0

    if temperature == 0:
        # Greedy: accept a token iff it equals the argmax prediction.
        matches = (
            candidates[:, 1:] == torch.argmax(logits[:, :-1], dim=-1)
        ).int()
        accepted = torch.cumprod(matches, dim=1).sum(dim=1)
        longest = accepted.max().item()
        if longest == 0:
            # Default to the first candidate if none are accepted
            winner = torch.tensor(0, dtype=torch.long, device=candidates.device)
        else:
            winner = torch.argmax(accepted).to(torch.long)
        return winner, longest

    if sampling == 'typical':
        if fast:
            # Deterministic approximation: accept tokens whose model
            # probability clears the entropy-adaptive cutoff.
            probs = torch.softmax(logits[:, :-1] / temperature, dim=-1)
            cand_probs = torch.gather(
                probs, dim=-1, index=candidates[:, 1:].unsqueeze(-1)
            ).squeeze(-1)
            entropy = -torch.sum(
                probs * torch.log(probs + 1e-5), dim=-1
            )  # torch.sum(torch.log(*)) is faster than torch.prod
            cutoff = torch.minimum(
                torch.ones_like(entropy) * posterior_threshold,
                torch.exp(-entropy) * posterior_alpha,
            )
            accepted = torch.cumprod(cand_probs > cutoff, dim=1).sum(dim=1)
            longest = accepted.max().item()
            if longest == 0:
                # If no candidates are accepted, just choose the first one
                winner = torch.tensor(0, dtype=torch.long, device=candidates.device)
            else:
                # break ties among the longest paths by total log-likelihood
                tied = torch.where(accepted == longest)[0]
                log_lik = torch.sum(
                    torch.log(cand_probs[tied, :longest]), dim=-1
                )
                winner = tied[torch.argmax(log_lik)]
            return winner, longest
        # Exact variant: sample and compare token by token.
        mask = get_typical_posterior_mask(logits, candidates, temperature, posterior_threshold, posterior_alpha)
        accepted = torch.cumprod(mask, dim=1).sum(dim=1)
        longest = accepted.max().item()
        if longest == 0:
            winner = torch.tensor(0, dtype=torch.long, device=candidates.device)
        else:
            winner = torch.argmax(accepted).to(torch.long)
        return winner, longest

    if sampling == 'nucleus':
        assert top_p < 1.0 + 1e-6, "top_p should between 0 and 1"
        mask = get_nucleus_posterior_mask(logits, candidates, temperature, top_p)
        accepted = torch.cumprod(mask, dim=1).sum(dim=1)
        longest = accepted.max().item()
        if longest == 0:
            winner = torch.tensor(0, dtype=torch.long, device=candidates.device)
        else:
            winner = torch.argmax(accepted).to(torch.long)
        return winner, longest

    raise NotImplementedError
511
+
512
def update_inference_inputs(
    input_ids,
    medusa_logits,
    logits,
    candidate_ids,
    best_candidate,
    accept_length,
):
    """Commit the winning candidate path and slice out the follow-up logits.

    Appends the accepted tokens (plus the root token, hence `accept_length + 1`
    ids) of the best candidate to `input_ids`, and extracts the logits at the
    last accepted position to seed the next speculative step.

    Returns:
        (input_ids, medusa_logits, logits, new_token) where `new_token` is the
        number of tokens committed this step.
    """
    accepted_span = candidate_ids[None, best_candidate, : accept_length + 1]
    input_ids = torch.cat([input_ids, accepted_span], dim=-1)
    # logits at the frontier position of the accepted path
    next_logits = logits[
        None, best_candidate, accept_length : accept_length + 1
    ]
    next_medusa_logits = medusa_logits[
        :, None, best_candidate, accept_length : accept_length + 1
    ]
    # Update the new token counter
    return input_ids, next_medusa_logits, next_logits, accept_length + 1
536
+
537
def split_logits(full_logits):
    """Split fused multi-head logits [b, n, heads, vocab] into two views.

    Returns:
        (medusa_logits, logits): the extra heads rearranged to
        [heads-1, b, n, vocab], and head 0's logits [b, n, vocab].
    """
    base = full_logits[..., 0, :]
    extra = full_logits[..., 1:, :].permute(2, 0, 1, 3)
    return extra, base
542
+
543
+ class MultiByteDecodingMixin:
544
    def multi_byte_pred_update_cache(
        self,
        past_key_values,
        retrieve_indices,
        best_candidate,
        new_tokens,
    ):
        """Compact the KV cache after a tree-decoding step.

        The speculative step wrote K/V for the whole candidate tree into the
        sliding window; this keeps only the entries on the accepted path
        (`retrieve_indices[best_candidate, :new_tokens]`), moves them to the
        window's tail, and — when the window overflows — dumps a full window
        into chunk-level RFA summaries via the triton kernel.

        Mutates and returns `past_key_values`.
        """
        # NOTE(review): assumes every layer shares the same window position
        # (layer 0 is used as the reference) — confirm against EvaStaticCacheForTriton.
        prev_window_len = past_key_values.get_past_window_pos(0)
        # positions (within the window) of the accepted path's K/V entries
        select_indices = (
            retrieve_indices[best_candidate, : new_tokens] + prev_window_len
        )
        for layer_idx in range(self.config.num_hidden_layers):

            past_key_values.update_past_len(new_tokens, layer_idx)

            past_window_k = past_key_values.past_window_k[layer_idx]
            past_window_v = past_key_values.past_window_v[layer_idx]

            # gather the accepted entries ...
            tgt_window_k = past_window_k[..., select_indices, :]
            tgt_window_v = past_window_v[..., select_indices, :]

            # ... and pack them contiguously right after the previous window end
            dst_window_k = past_window_k[..., prev_window_len : prev_window_len + new_tokens, :]
            dst_window_v = past_window_v[..., prev_window_len : prev_window_len + new_tokens, :]

            dst_window_k.copy_(tgt_window_k, non_blocking=True)
            dst_window_v.copy_(tgt_window_v, non_blocking=True)

            new_window_len = prev_window_len + new_tokens
            if new_window_len >= self.config.window_size:
                # a step never spans more than one window boundary
                assert new_window_len < 2 * self.config.window_size

                # snapshot the full window before it gets summarized
                dump_k = past_window_k[..., :self.config.window_size, :].clone()
                dump_v = past_window_v[..., :self.config.window_size, :].clone()

                # tokens spilling past the boundary start the next window
                _window_len = new_window_len - self.config.window_size

                if _window_len > 0:
                    new_window_k = past_window_k[..., self.config.window_size : new_window_len, :]
                    new_window_v = past_window_v[..., self.config.window_size : new_window_len, :]

                    _dst_window_k = past_window_k[..., : _window_len, :]
                    _dst_window_v = past_window_v[..., : _window_len, :]

                    # shift the overflow back to the front of the buffer
                    _dst_window_k.copy_(new_window_k, non_blocking=True)
                    _dst_window_v.copy_(new_window_v, non_blocking=True)

                past_key_values.past_window_pos[layer_idx] = _window_len
            else:
                # window not full yet: nothing to summarize
                dump_k = None
                dump_v = None
                past_key_values.past_window_pos[layer_idx] = new_window_len

            if dump_k is not None and dump_v is not None:
                # fold the completed window into chunk-level RFA summaries
                rfa_k, rfa_v = triton_eva_prep_kv_fwd(
                    dump_k, dump_v,
                    self.model.layers[layer_idx].self_attn.adaptive_mu_k,
                    self.model.layers[layer_idx].self_attn.adaptive_phi,
                    None,
                    self.model.layers[layer_idx].self_attn.head_dim_scaling,
                    self.model.layers[layer_idx].self_attn.chunk_size
                )
                rfa_k, rfa_v = past_key_values.update_chunk_rfas(
                    rfa_k, rfa_v, layer_idx
                )
        return past_key_values
609
+
610
    def _multi_byte_pred_update_cache_when_prefil_len_eq_window_size(
        self,
        past_key_values,
    ):
        """Handle the edge case where prefill exactly fills one window.

        When the prompt length equals `window_size`, the window is already
        complete after prefill: summarize it into chunk-level RFA state and
        reset the window position to 0, so tree decoding starts on a fresh
        window. Mutates and returns `past_key_values`.
        """
        # NOTE(review): layer 0's window position is assumed representative of
        # all layers — TODO confirm.
        prev_window_len = past_key_values.get_past_window_pos(0)
        for layer_idx in range(self.config.num_hidden_layers):

            past_window_k = past_key_values.past_window_k[layer_idx]
            past_window_v = past_key_values.past_window_v[layer_idx]

            new_window_len = prev_window_len
            if new_window_len == self.config.window_size:
                # exactly one full window: snapshot it and start over at 0
                dump_k = past_window_k[..., :self.config.window_size, :].clone()
                dump_v = past_window_v[..., :self.config.window_size, :].clone()
                past_key_values.past_window_pos[layer_idx] = 0

                if dump_k is not None and dump_v is not None:
                    # fold the window into chunk-level RFA summaries
                    rfa_k, rfa_v = triton_eva_prep_kv_fwd(
                        dump_k, dump_v,
                        self.model.layers[layer_idx].self_attn.adaptive_mu_k,
                        self.model.layers[layer_idx].self_attn.adaptive_phi,
                        None,
                        self.model.layers[layer_idx].self_attn.head_dim_scaling,
                        self.model.layers[layer_idx].self_attn.chunk_size
                    )
                    rfa_k, rfa_v = past_key_values.update_chunk_rfas(
                        rfa_k, rfa_v, layer_idx
                    )
        return past_key_values
639
+
640
    def multi_byte_pred_update_attn_mask(
        self,
        last_iter_new_tokens,
        tree_candidate_ids,
        past_attn_mask,
        medusa_attn_mask,
        past_key_values,
    ):
        """Build the attention mask for one tree-decoding step.

        Extends (or re-initializes) the mask over already-cached positions and
        concatenates the Medusa tree mask for the candidate tokens.

        Args:
            last_iter_new_tokens: number of tokens accepted in the previous
                tree iteration (already inserted into the cache).
            tree_candidate_ids: (batch, tree_candidate_len) candidate ids.
            past_attn_mask: running mask over cached positions, or None to
                force re-initialization.
            medusa_attn_mask: tree-structured mask for the candidates
                (converted to bool here; 1 = attend).
            past_key_values: cache; only ``get_seq_length()`` is read.

        Returns:
            (tree_attn_mask, past_attn_mask): the full mask for this forward
            pass and the updated running past mask to reuse next iteration.
        """
        batch_size, tree_candidate_len = tree_candidate_ids.shape
        seen_tokens = past_key_values.get_seq_length()
        # NOTE: past_key_values has already been updated, so seen_tokens
        # includes the new tokens accepted in the last tree iteration.
        assert seen_tokens > 0
        # One iteration must not cross two windows.
        assert last_iter_new_tokens < self.config.window_size

        if past_attn_mask is not None and seen_tokens < self.config.window_size:
            # Still inside the first window: just append all-visible columns
            # for the tokens accepted last iteration.
            past_attn_mask = torch.cat(
                [
                    past_attn_mask,
                    torch.ones(
                        [batch_size, 1, tree_candidate_len, last_iter_new_tokens],
                        dtype=torch.bool,
                        device=self.device
                    )
                ],
                dim=-1
            )
        else:
            # Re-initialize the attention mask each time:
            # 1. the model crosses a window boundary, or
            # 2. right after prefilling (past_attn_mask is None).
            chunks_per_window = int(self.config.window_size // self.config.chunk_size)

            # Fully-seen windows are represented by their RFA chunk summaries
            # (chunks_per_window columns each); the current partial window
            # contributes one column per token.
            window_tokens = seen_tokens % self.config.window_size
            num_windows_seen_so_far = seen_tokens // self.config.window_size
            attn_mask_len = num_windows_seen_so_far * chunks_per_window + window_tokens
            past_attn_mask = torch.ones(
                (batch_size, 1, tree_candidate_len, attn_mask_len),
                dtype=torch.bool,
                device=self.device
            )

        # Note that 1 indicates the position is NOT masked.
        tree_attn_mask = torch.cat(
            [
                past_attn_mask,
                medusa_attn_mask.to(torch.bool)
            ],
            dim=-1
        )
        return tree_attn_mask, past_attn_mask
692
+
693
    @torch.no_grad()
    def multi_byte_generate(
        self,
        input_ids,
        attention_mask=None,
        temperature=0.0,
        max_length=None,
        max_new_tokens=None,
        stopping_criteria=None,
        posterior_threshold=0.09,
        posterior_alpha=0.3,
        top_p=0.8,
        sampling='typical',
        fast=True,
        do_sample=False,
        medusa_choices=None,
        return_acc_lengths=False
    ):
        """Generate with Medusa-style multi-byte (tree) speculative decoding.

        Each iteration: (1) draft a tree of candidate continuations from the
        Medusa heads, (2) score the whole tree in one forward pass, (3) accept
        the longest verified candidate path, (4) update the EVA window/RFA
        cache accordingly.

        Args:
            input_ids: (1, seq_len) prompt ids; batch size 1 only.
            attention_mask: must be None (unsupported here).
            temperature / do_sample: > 0 or True disables the `fast` greedy path.
            max_length / max_new_tokens: stopping lengths; max_new_tokens wins.
            stopping_criteria: optional extra StoppingCriteriaList entries.
            posterior_threshold / posterior_alpha / top_p / sampling / fast:
                candidate acceptance knobs passed through to
                generate_candidates / evaluate_posterior.
            medusa_choices: tree topology; defaults to `evabyte_7b_95`.
            return_acc_lengths: also return per-iteration accepted lengths.

        Returns:
            input_ids with generated tokens appended, or
            (input_ids, acc_lengths) when return_acc_lengths is True.
        """
        if do_sample or temperature > 0.0:
            fast = False

        ### Prepare `max_length` depending on other stopping criteria.
        if max_new_tokens is not None:
            max_length = max_new_tokens + input_ids.shape[-1]
        elif max_new_tokens is None and max_length is None:
            max_length = getattr(self.config, "max_position_embeddings", 32768)

        ### Set up stopping criteria
        eos_stop_criteria = MultibyteEosTokenCriteria(self.generation_config.eos_token_id)
        stop_criteria = StoppingCriteriaList()
        if max_length is not None:
            max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
            stop_criteria.append(
                MaxLengthCriteria(
                    max_length=max_length,
                    max_position_embeddings=max_position_embeddings,
                )
            )
        if stopping_criteria is not None and len(stopping_criteria) > 0:
            stop_criteria.extend(stopping_criteria)

        assert input_ids.shape[0] == 1, "Only support batch size 1 for now"
        assert attention_mask is None, "Only support attention mask None for now"
        # Avoid modifying the input_ids in-place
        input_ids = input_ids.clone()
        position_ids = torch.arange(0, input_ids.shape[1], device=self.device, dtype=int).reshape(1, -1)

        ####################################################
        # 0. initialize the medusa buffers
        ####################################################
        if medusa_choices is None:
            medusa_choices = evabyte_7b_95
        medusa_buffers = generate_medusa_buffers(
            medusa_choices, device=self.device
        )

        past_key_values = EvaStaticCacheForTriton(
            input_ids.shape[0],
            self.config.num_attention_heads,
            # we add 256 to allow tree ids
            self.config.window_size + 256,
            self.config.hidden_size // self.config.num_attention_heads,
            self.config.num_hidden_layers,
            self.lm_head.weight.dtype,
            self.lm_head.weight.device,
        )
        # prefill to get medusa logits and logits
        full_logits, past_key_values = self.forward(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            use_cache=True,
            past_key_values=past_key_values,
            return_all_pred_logits=True,
            multibyte_decoding=False,
        )
        # handles an edge case where the prefill length == window_size:
        # we force the previous window to be dumped into RFA chunks
        past_key_values = self._multi_byte_pred_update_cache_when_prefil_len_eq_window_size(
            past_key_values
        )
        medusa_logits, logits = split_logits(full_logits)

        past_attn_mask = None
        last_iter_new_tokens = 0
        # hard upper bound on tree iterations; stop_criteria normally fires first
        max_iters = 32768
        if return_acc_lengths:
            acc_lengths = []
        for _ in range(max_iters):
            ####################################################
            # 1. generate candidate_ids with topk predictions from Medusa heads
            ####################################################
            tree_candidate_ids, unflattened_candidate_ids = generate_candidates(
                medusa_logits,
                logits,
                medusa_buffers["tree_indices"],
                medusa_buffers["retrieve_indices"],
                temperature=temperature,
                posterior_alpha=posterior_alpha,
                posterior_threshold=posterior_threshold,
                top_p=top_p,
                sampling=sampling,
                fast=fast,
            )

            ####################################################
            # 2. Build the medusa attention mask and position ids
            ####################################################
            # NOTE: 1 indicates the position is not masked
            medusa_attn_mask, past_attn_mask = self.multi_byte_pred_update_attn_mask(
                last_iter_new_tokens,
                tree_candidate_ids,
                past_attn_mask,
                medusa_buffers["medusa_attn_mask"],
                past_key_values,
            )
            medusa_position_ids = medusa_buffers["medusa_position_ids"] + input_ids.shape[1]

            ####################################################
            # 3. tree decoding
            ####################################################
            tree_full_logits, past_key_values = self.forward(
                tree_candidate_ids,
                past_key_values=past_key_values,
                attention_mask=medusa_attn_mask,
                position_ids=medusa_position_ids,
                return_all_pred_logits=True,
                multibyte_decoding=True,
            )
            _medusa_logits, _logits = split_logits(tree_full_logits)
            # gather per-candidate-path logits from the flattened tree outputs
            medusa_logits = _medusa_logits[..., 0, medusa_buffers["retrieve_indices"], :]
            logits = _logits[..., 0, medusa_buffers["retrieve_indices"], :]

            ####################################################
            # 4. candidate selection
            ####################################################
            # if the current iteration, with tree tokens, crosses window
            # boundaries, trim the candidate_ids to be within the window
            # so that those exceeded tokens (which will be inaccurate)
            # will not be considered
            tree_depth = unflattened_candidate_ids.shape[-1]
            if tree_depth + past_key_values.get_past_window_pos(0) > self.config.window_size:
                max_acc_len = self.config.window_size - past_key_values.get_past_window_pos(0)
                _trimmed_unflattened_candidate_ids = unflattened_candidate_ids[:, :max_acc_len]
                _trimmed_logits = logits[:, :max_acc_len]
            else:
                _trimmed_unflattened_candidate_ids = unflattened_candidate_ids
                _trimmed_logits = logits
            best_candidate, accept_length = evaluate_posterior(
                _trimmed_logits,
                _trimmed_unflattened_candidate_ids,
                temperature,
                posterior_threshold,
                posterior_alpha,
                top_p=top_p,
                sampling=sampling,
                fast=fast
            )

            ####################################################
            # 5. update model inputs and caches
            ####################################################
            input_ids, medusa_logits, logits, last_iter_new_tokens = update_inference_inputs(
                input_ids,
                medusa_logits,
                logits,
                unflattened_candidate_ids,
                best_candidate,
                accept_length,
            )

            past_key_values = self.multi_byte_pred_update_cache(
                past_key_values,
                medusa_buffers["retrieve_indices"],
                best_candidate,
                last_iter_new_tokens,
            )

            if return_acc_lengths:
                acc_lengths.append(last_iter_new_tokens)
            if stop_criteria(input_ids, None) or eos_stop_criteria(input_ids, last_iter_new_tokens):
                if return_acc_lengths:
                    return input_ids, acc_lengths
                else:
                    return input_ids
        # max_iters exhausted without a stopping criterion firing
        if return_acc_lengths:
            return input_ids, acc_lengths
        else:
            return input_ids
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/preprocessor_config.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoImageProcessor": "image_processing_evabyte.EvaByteImageProcessor",
4
+ "AutoProcessor": "processing_evabyte.EvaByteProcessor"
5
+ },
6
+ "do_convert_rgb": true,
7
+ "do_resize": true,
8
+ "image_processor_type": "EvaByteImageProcessor",
9
+ "jpeg_quality": 25,
10
+ "jpeg_restart_marker_blocks": 1,
11
+ "jpeg_streamtype": 2,
12
+ "jpeg_subsampling": "4:2:0",
13
+ "processor_class": "EvaByteProcessor",
14
+ "resample": 1,
15
+ "size": {
16
+ "longest_edge": 384
17
+ }
18
+ }
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/processing_evabyte.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ """
3
+ Processor class for EvaByte.
4
+ """
5
+ import base64
6
+ from io import BytesIO
7
+
8
+ import requests
9
+ import os
10
+ import PIL
11
+ from PIL import Image
12
+
13
+ from typing import List, Optional, Union
14
+
15
+ from transformers.feature_extraction_utils import BatchFeature
16
+ from transformers.image_utils import ImageInput, is_valid_image
17
+ from transformers.processing_utils import ProcessorMixin
18
+ from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
19
+ from transformers.utils import TensorType, to_py_obj
20
+
21
+ def fetch_image(image: Union[str, "PIL.Image.Image"]) -> Image.Image:
22
+ image_obj = None
23
+ if isinstance(image, Image.Image):
24
+ image_obj = image
25
+ elif image.startswith("http://") or image.startswith("https://"):
26
+ image_obj = Image.open(BytesIO(requests.get(image, timeout=None).content))
27
+ elif os.path.isfile(image):
28
+ image_obj = Image.open(image)
29
+ elif image.startswith("data:image/"):
30
+ image = image.split(",")[1]
31
+ # Try to load as base64
32
+ try:
33
+ b64 = base64.decodebytes(image.encode())
34
+ image = PIL.Image.open(BytesIO(b64))
35
+ except Exception as e:
36
+ raise ValueError(
37
+ f"Incorrect image source. Must be a valid URL starting with `http://` or `https://`, a valid path to an image file, or a base64 encoded string. Got {image}. Failed with {e}"
38
+ )
39
+ else:
40
+ image_obj = Image.open(image)
41
+ if image_obj is None:
42
+ raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
43
+
44
+ return image_obj
45
+
46
+ def is_url(val) -> bool:
47
+ return isinstance(val, str) and val.startswith("http")
48
+
49
+ def is_file(val) -> bool:
50
+ return isinstance(val, str) and os.path.isfile(val)
51
+
52
+ def is_image_or_image_url(elem):
53
+ return is_url(elem) or is_valid_image(elem) or is_file(elem)
54
+
55
+ vl_chat_template = """
56
+ {{- bos_token }}
57
+ {%- if messages[0]['role'] == 'system' %}
58
+ {%- set system_message = messages[0]['content'] %}
59
+ {%- set messages = messages[1:] %}
60
+ {%- else %}
61
+ {%- set system_message = "" %}
62
+ {%- endif %}
63
+
64
+ {{- '<|start_header_id|>system<|end_header_id|>\n\n' + system_message + '<|eot_id|>'}}
65
+
66
+ {%- for message in messages %}
67
+ {%- if (message['role'] != 'user') and (message['role'] != 'assistant') %}
68
+ {{- raise_exception('Conversation roles must be user or assistant') }}
69
+ {%- endif %}
70
+
71
+ {%- if message['content'] is string %}
72
+ {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] + '<|eot_id|>' }}
73
+ {%- else %}
74
+ {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' }}
75
+ {%- for content in message['content'] %}
76
+ {%- if content['type'] == 'image' %}
77
+ {{- '<image_placeholder>\n' }}
78
+ {%- elif content['type'] == 'text' %}
79
+ {{- content['text'] }}
80
+ {%- endif %}
81
+ {%- endfor %}
82
+ {{- '<|eot_id|>' }}
83
+ {%- endif %}
84
+ {%- endfor %}
85
+
86
+ {%- if add_generation_prompt %}
87
+ {{- '<|start_header_id|>' + 'assistant' + '<|end_header_id|>\n\n' }}
88
+ {%- endif %}
89
+ """
90
+
91
+ class EvaByteProcessor(ProcessorMixin):
92
+ r"""
93
+ Constructs a EvaByte processor which wraps a EvaByte image processor and a EvaByte tokenizer into a single processor.
94
+
95
+ [`EvaByteProcessor`] offers all the functionalities of [`EvaByteImageProcessor`] and [`EvaByteTokenizer`]. See the
96
+ [`~EvaByteProcessor.__call__`] and [`~EvaByteProcessor.decode`] for more information.
97
+
98
+ Args:
99
+ image_processor ([`EvaByteImageProcessor`], *optional*):
100
+ The image processor is a required input.
101
+ tokenizer ([`EvaByteTokenizer`], *optional*):
102
+ The tokenizer is a required input.
103
+ """
104
+
105
+ attributes = ["image_processor", "tokenizer"]
106
+ image_processor_class = "AutoImageProcessor"
107
+ tokenizer_class = "AutoTokenizer"
108
+
109
+ def __init__(self, image_processor=None, tokenizer=None, **kwargs):
110
+ if image_processor is None:
111
+ raise ValueError("You need to specify an `image_processor`.")
112
+ if tokenizer is None:
113
+ raise ValueError("You need to specify a `tokenizer`.")
114
+
115
+ super().__init__(image_processor, tokenizer)
116
+ self.t2v_token_id = self.tokenizer.convert_tokens_to_ids("<t2v_token>")
117
+ self.v2t_token_id = self.tokenizer.convert_tokens_to_ids("<v2t_token>")
118
+ self.image_placeholder = "<image_placeholder>"
119
+ self.vl_chat_template = vl_chat_template
120
+
121
+ def __call__(
122
+ self,
123
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
124
+ images: ImageInput = None,
125
+ return_tensors: Optional[Union[str, TensorType]] = None,
126
+ strip_ending_sentinel: bool = False,
127
+ encode_only: bool = False,
128
+ **kwargs
129
+ ) -> Union[BatchFeature, List[List[int]]]:
130
+ # processing pipeline:
131
+ # 1. read images or videos from paths
132
+ # 2. use image_processor to convert images / videos to byte streams
133
+ if images is not None:
134
+ if isinstance(images, bytes):
135
+ image_bytes_list = [[images]]
136
+ elif isinstance(images, list) and isinstance(images[0], bytes):
137
+ image_bytes_list = [images]
138
+ elif isinstance(images, list) and isinstance(images[0], list) and isinstance(images[0][0], bytes):
139
+ image_bytes_list = images
140
+ else:
141
+ if is_image_or_image_url(images):
142
+ images = [[images]]
143
+ elif isinstance(images, list) and is_image_or_image_url(images[0]):
144
+ images = [images]
145
+ elif (
146
+ not isinstance(images, list)
147
+ and not isinstance(images[0], list)
148
+ and not is_image_or_image_url(images[0][0])
149
+ ):
150
+ raise ValueError(
151
+ "Invalid input images. Please provide a single image or a list of images or a list of list of images."
152
+ )
153
+ # Load images if they are URLs
154
+ images = [[fetch_image(im) if is_url(im) or is_file(im) else im for im in sample] for sample in images]
155
+ image_bytes_list = self.image_processor(images=images, **kwargs)
156
+
157
+ if not isinstance(text, list):
158
+ text = [text]
159
+ assert len(text) == 1, "Only support batch size 1 for now"
160
+ assert len(text) == len(image_bytes_list), "text and image_bytes_list must have the same length"
161
+ # TODO: invoke SequenceFeatureExtractor to get batched inputs
162
+
163
+ # 3. tokenize the text and put images / videos byte streams into the placeholders
164
+ # surrounded by special tokens like "<image>" and "</image>"
165
+ batch_input_ids = []
166
+ if not encode_only:
167
+ batch_attention_mask = []
168
+ else:
169
+ batch_attention_mask = None
170
+
171
+ for t, image_bytes in zip(text, image_bytes_list):
172
+ text_splits = t.split(self.image_placeholder)
173
+ if len(text_splits) != len(image_bytes) + 1:
174
+ raise ValueError(
175
+ f"The number of image tokens should be equal to the number of images, "
176
+ f"but got {len(text_splits)} and {len(image_bytes) + 1}"
177
+ )
178
+
179
+ input_ids = [self.tokenizer.bos_token_id]
180
+ for i, text_part in enumerate(text_splits):
181
+ # each text part must be non-empty because we added markers around placeholders
182
+ split_tokens = self.tokenizer.encode(text_part, add_special_tokens=False)
183
+ input_ids.extend(split_tokens)
184
+ # Add image bytes after each text part except the last one
185
+ if i < len(image_bytes):
186
+ input_ids.append(self.t2v_token_id)
187
+ input_ids.extend([b + self.tokenizer.offset for b in image_bytes[i]])
188
+ input_ids.append(self.v2t_token_id)
189
+
190
+ if strip_ending_sentinel and (input_ids[-1] in [self.t2v_token_id, self.v2t_token_id]):
191
+ input_ids = input_ids[:-1]
192
+
193
+ batch_input_ids.append(input_ids)
194
+ if not encode_only:
195
+ batch_attention_mask.append([1] * len(input_ids))
196
+
197
+ if not encode_only:
198
+ # 4. return batch of features
199
+ inputs = BatchFeature({
200
+ "input_ids": batch_input_ids,
201
+ "attention_mask": batch_attention_mask
202
+ }, tensor_type=return_tensors)
203
+ return inputs
204
+ # # Pad sequences
205
+ # padded_inputs = self.tokenizer.pad(
206
+ # {"input_ids": batch_input_ids},
207
+ # padding=True,
208
+ # return_attention_mask=True,
209
+ # return_tensors=return_tensors,
210
+ # )
211
+ # return BatchFeature(data=padded_inputs)
212
+ else:
213
+ return batch_input_ids
214
+
215
+ def image_tokens_to_bytes(self, image_token_ids, jpeg_quality=None):
216
+ image_bytes = bytes([token_id - self.tokenizer.offset for token_id in image_token_ids])
217
+ image_bytes = self.image_processor.jpeg_merge_qtables(image_bytes, jpeg_quality)
218
+ return image_bytes
219
+
220
+ def batch_decode(self, sequences, **kwargs):
221
+ """
222
+ This method forwards all its arguments to EvaByteTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
223
+ refer to the docstring of this method for more information.
224
+ """
225
+ rets = [self.decode(seq, **kwargs) for seq in sequences]
226
+ return tuple(map(list, zip(*rets)))
227
+
228
+ def decode(self, token_ids, **kwargs):
229
+ """
230
+ Decodes a sequence of input_ids, handling image tokens separately.
231
+ Returns a tuple of (decoded_text, images), where images is a list of bytes.
232
+ """
233
+ if kwargs and "jpeg_quality" in kwargs:
234
+ kwargs = kwargs.copy()
235
+ jpeg_quality = kwargs.pop("jpeg_quality")
236
+ else:
237
+ jpeg_quality = None
238
+
239
+ token_ids = to_py_obj(token_ids)
240
+ # Find indices of t2v_token_id and v2t_token_id
241
+ t2v_indices = [i for i, token_id in enumerate(token_ids) if token_id == self.t2v_token_id]
242
+ v2t_indices = [i for i, token_id in enumerate(token_ids) if token_id == self.v2t_token_id]
243
+
244
+ # Check for correct pairing of t2v and v2t tokens
245
+ if len(t2v_indices) != len(v2t_indices):
246
+ raise ValueError("Mismatched number of t2v and v2t tokens in token_ids: {} and {}".format(t2v_indices, v2t_indices))
247
+
248
+ # Ensure t2v and v2t tokens are in the correct order
249
+ for t2v_idx, v2t_idx in zip(t2v_indices, v2t_indices):
250
+ if t2v_idx >= v2t_idx:
251
+ raise ValueError("Found t2v_token_id after v2t_token_id in token_ids")
252
+
253
+ # Initialize the start index
254
+ images = []
255
+ decoded_text = ""
256
+
257
+ start = 0
258
+ # Iterate over pairs of t2v and v2t indices
259
+ for t2v_idx, v2t_idx in zip(t2v_indices, v2t_indices):
260
+ # Decode text tokens before the image
261
+ text_token_ids = token_ids[start:t2v_idx]
262
+ if len(text_token_ids) > 0:
263
+ decoded_text += self.tokenizer.decode(text_token_ids, **kwargs)
264
+
265
+ # Insert image placeholder
266
+ decoded_text += self.image_placeholder
267
+
268
+ # Extract image tokens and convert them to bytes
269
+ image_token_ids = token_ids[t2v_idx + 1 : v2t_idx]
270
+ image_bytes = self.image_tokens_to_bytes(image_token_ids, jpeg_quality)
271
+ images.append(image_bytes)
272
+
273
+ # Update the start index to the token after v2t_token_id
274
+ start = v2t_idx + 1
275
+
276
+ # Decode any remaining text tokens after the last image
277
+ if start < len(token_ids):
278
+ text_token_ids = token_ids[start:]
279
+ decoded_text += self.tokenizer.decode(text_token_ids, **kwargs)
280
+
281
+ return decoded_text, images
282
+
283
+ @property
284
+ def model_input_names(self):
285
+ tokenizer_input_names = self.tokenizer.model_input_names
286
+ image_processor_input_names = self.image_processor.model_input_names
287
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/processor_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoProcessor": "processing_evabyte.EvaByteProcessor"
4
+ },
5
+ "processor_class": "EvaByteProcessor"
6
+ }
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/special_tokens_map.json ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "<repo_name>",
4
+ "<file_sep>",
5
+ "<t2v_token>",
6
+ "<v2t_token>",
7
+ "<|start_header_id|>",
8
+ "<|end_header_id|>",
9
+ "<|eot_id|>",
10
+ "<extra_id_12>",
11
+ "<extra_id_13>",
12
+ "<extra_id_14>",
13
+ "<extra_id_15>",
14
+ "<extra_id_16>",
15
+ "<extra_id_17>",
16
+ "<extra_id_18>",
17
+ "<extra_id_19>",
18
+ "<extra_id_20>",
19
+ "<extra_id_21>",
20
+ "<extra_id_22>",
21
+ "<extra_id_23>",
22
+ "<extra_id_24>",
23
+ "<extra_id_25>",
24
+ "<extra_id_26>",
25
+ "<extra_id_27>",
26
+ "<extra_id_28>",
27
+ "<extra_id_29>",
28
+ "<extra_id_30>",
29
+ "<extra_id_31>",
30
+ "<extra_id_32>",
31
+ "<extra_id_33>",
32
+ "<extra_id_34>",
33
+ "<extra_id_35>",
34
+ "<extra_id_36>",
35
+ "<extra_id_37>",
36
+ "<extra_id_38>",
37
+ "<extra_id_39>",
38
+ "<extra_id_40>",
39
+ "<extra_id_41>",
40
+ "<extra_id_42>",
41
+ "<extra_id_43>",
42
+ "<extra_id_44>",
43
+ "<extra_id_45>",
44
+ "<extra_id_46>",
45
+ "<extra_id_47>",
46
+ "<extra_id_48>",
47
+ "<extra_id_49>",
48
+ "<extra_id_50>",
49
+ "<extra_id_51>",
50
+ "<extra_id_52>",
51
+ "<extra_id_53>",
52
+ "<extra_id_54>",
53
+ "<extra_id_55>",
54
+ "<extra_id_56>",
55
+ "<extra_id_57>",
56
+ "<extra_id_58>",
57
+ "<extra_id_59>",
58
+ "<extra_id_60>",
59
+ "<extra_id_61>",
60
+ "<extra_id_62>",
61
+ "<extra_id_63>"
62
+ ],
63
+ "bos_token": {
64
+ "content": "<bos>",
65
+ "lstrip": false,
66
+ "normalized": true,
67
+ "rstrip": false,
68
+ "single_word": false
69
+ },
70
+ "eos_token": {
71
+ "content": "<eos>",
72
+ "lstrip": false,
73
+ "normalized": true,
74
+ "rstrip": false,
75
+ "single_word": false
76
+ },
77
+ "pad_token": {
78
+ "content": "<pad>",
79
+ "lstrip": false,
80
+ "normalized": true,
81
+ "rstrip": false,
82
+ "single_word": false
83
+ },
84
+ "sep_token": {
85
+ "content": "<sep>",
86
+ "lstrip": false,
87
+ "normalized": true,
88
+ "rstrip": false,
89
+ "single_word": false
90
+ },
91
+ "unk_token": {
92
+ "content": "<unk>",
93
+ "lstrip": false,
94
+ "normalized": true,
95
+ "rstrip": false,
96
+ "single_word": false
97
+ }
98
+ }
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/tokenization_evabyte.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+
3
+ """ Tokenization class for model EvaByte."""
4
+
5
+
6
+ from typing import List, Optional, Tuple
7
+
8
+ from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
9
+ from transformers.utils import logging
10
+
11
+
12
+ logger = logging.get_logger(__name__)
13
+
14
+
15
# Jinja chat template for text-only conversations: an optional leading
# system message, then alternating user/assistant turns wrapped in
# <|start_header_id|>/<|end_header_id|>/<|eot_id|> markers.
chat_template = """
{{- bos_token }}
{%- if messages[0]['role'] == 'system' %}
    {%- set system_message = messages[0]['content'] %}
    {%- set messages = messages[1:] %}
{%- else %}
    {%- set system_message = "" %}
{%- endif %}

{{- '<|start_header_id|>system<|end_header_id|>\n\n' + system_message + '<|eot_id|>'}}

{%- for message in messages %}
    {%- if (message['role'] != 'user') and (message['role'] != 'assistant') %}
        {{- raise_exception('Conversation roles must be user or assistant') }}
    {%- endif %}

    {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] + '<|eot_id|>' }}
{%- endfor %}

{%- if add_generation_prompt %}
    {{- '<|start_header_id|>' + 'assistant' + '<|end_header_id|>\n\n' }}
{%- endif %}
"""
38
+
39
class EvaByteTokenizer(PreTrainedTokenizer):
    """Byte-level tokenizer: each UTF-8 byte maps to one id, shifted by the
    number of special tokens (``self.offset``). Ids [0, offset) are special
    tokens; ids [offset, offset + 256) are raw byte values."""

    def __init__(
        self,
        bos_token="<bos>",
        eos_token="<eos>",
        unk_token="<unk>",
        sep_token="<sep>",
        pad_token="<pad>",
        extra_ids=59,
        additional_special_tokens=None,
        clean_up_tokenization_spaces=False,
        **kwargs,
    ) -> None:
        # Ids 0-4 are pad/bos/eos/unk/sep; extra ids start right after.
        num_base_special_tokens = 5
        # Add extra_ids to the special token list
        if extra_ids > 0 and additional_special_tokens is None:
            additional_special_tokens = [f"<extra_id_{i}>" for i in range(num_base_special_tokens, extra_ids + num_base_special_tokens)]
        elif extra_ids > 0 and additional_special_tokens is not None and len(additional_special_tokens) > 0:
            # Check that we have the right number of extra_id special tokens
            extra_tokens = len(set(filter(lambda x: bool("extra_id" in str(x)), additional_special_tokens)))
            if extra_tokens != extra_ids:
                raise ValueError(
                    f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are"
                    " provided to EvaByteTokenizer. In this case the additional_special_tokens must include the"
                    " extra_ids tokens"
                )

        #### override some reserved tokens to support chat template
        # <extra_id_5>..<extra_id_11> are repurposed as named special tokens.
        for i, token in enumerate(additional_special_tokens):
            if token == "<extra_id_5>":
                token = "<repo_name>"
            elif token == "<extra_id_6>":
                token = "<file_sep>"
            elif token == "<extra_id_7>":
                token = "<t2v_token>"
            elif token == "<extra_id_8>":
                token = "<v2t_token>"
            elif token == "<extra_id_9>":
                token = "<|start_header_id|>"
            elif token == "<extra_id_10>":
                token = "<|end_header_id|>"
            elif token == "<extra_id_11>":
                token = "<|eot_id|>"
            additional_special_tokens[i] = token

        # lstrip and rstrip are set to False because we don't want to strip the whitespace from the special tokens
        # this would be important for the byte tokenizer
        pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
        bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
        eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
        unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
        sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token

        self._added_tokens_decoder = {
            0: pad_token,
            1: bos_token,
            2: eos_token,
            3: unk_token,  # unk_token is a placeholder
            4: sep_token,
            **{i: AddedToken(t, lstrip=False, rstrip=False) for i, t in enumerate(additional_special_tokens, start=num_base_special_tokens)},
        }
        # Byte ids start after all special tokens.
        self.offset = len(self._added_tokens_decoder)
        self._utf_vocab_size = 2**8  # utf is 8 bits
        self.add_bos_token = True
        self.add_eos_token = False
        super().__init__(
            pad_token=pad_token,
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            sep_token=sep_token,
            extra_ids=0,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )
        self.chat_template = chat_template


    @property
    def vocab_size(self):
        # Only the 256 byte tokens; special tokens are counted separately.
        return self._utf_vocab_size

    def get_vocab(self):
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size + self.offset)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.build_inputs_with_special_tokens
    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
        eos_token_id = [self.eos_token_id] if self.add_eos_token else []

        output = bos_token_id + token_ids_0 + eos_token_id

        if token_ids_1 is not None:
            output = output + bos_token_id + token_ids_1 + eos_token_id

        return output

    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.get_special_tokens_mask
    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        bos_token_id = [1] if self.add_bos_token else []
        eos_token_id = [1] if self.add_eos_token else []

        if token_ids_1 is None:
            return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
        return (
            bos_token_id
            + ([0] * len(token_ids_0))
            + eos_token_id
            + bos_token_id
            + ([0] * len(token_ids_1))
            + eos_token_id
        )

    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.create_token_type_ids_from_sequences
    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
        sequence pair mask has the following format:

        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence | second sequence |
        ```

        if token_ids_1 is None, only returns the first portion of the mask (0s).

        Args:
            token_ids_0 (`List[int]`):
                List of ids.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
        """
        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
        eos_token_id = [self.eos_token_id] if self.add_eos_token else []

        output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)

        if token_ids_1 is not None:
            output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)

        return output

    def _tokenize(self, text: str) -> List[str]:
        """Take as input a string and return a list of strings (tokens) for words/sub-words.

        Here each token is a single UTF-8 byte, represented as chr(byte)."""
        tokens = [chr(i) for i in text.encode("utf-8")]
        return tokens

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        # Byte tokens are single characters; anything longer is unknown here
        # (multi-char special tokens are resolved by the added-tokens machinery).
        if len(token) != 1:
            token_id = None
        else:
            token_id = ord(token) + self.offset

        return token_id

    def _convert_id_to_token(self, index):
        """Converts an index (integer) to a byte (str) using the vocab."""
        token = chr(index - self.offset)
        return token

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of bytes (string) to a single string."""
        bstring = b""
        for token in tokens:
            # NOTE(review): added_tokens_decoder is keyed by int ids while
            # `token` is a str, so this first branch likely never matches;
            # the encoder branch below is what handles special tokens —
            # confirm before relying on it.
            if token in self.added_tokens_decoder:
                tok_string = self.added_tokens_decoder[token].encode("utf-8")
            elif token in self.added_tokens_encoder:
                tok_string = token.encode("utf-8")
            else:
                tok_string = bytes([ord(token)])
            bstring += tok_string
        # Invalid UTF-8 byte sequences are silently dropped.
        string = bstring.decode("utf-8", errors="ignore")
        return string

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        # Byte-level vocab is implicit; there is nothing to save.
        return ()
ckpts/ocpython_14b_bsz-2m_seq16k_100raw_docmask_100B_2m_step-50000/tokenizer_config.json ADDED
@@ -0,0 +1,596 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "<pad>",
5
+ "lstrip": false,
6
+ "normalized": true,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "<bos>",
13
+ "lstrip": false,
14
+ "normalized": true,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "2": {
20
+ "content": "<eos>",
21
+ "lstrip": false,
22
+ "normalized": true,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "3": {
28
+ "content": "<unk>",
29
+ "lstrip": false,
30
+ "normalized": true,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "4": {
36
+ "content": "<sep>",
37
+ "lstrip": false,
38
+ "normalized": true,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ },
43
+ "5": {
44
+ "content": "<repo_name>",
45
+ "lstrip": false,
46
+ "normalized": true,
47
+ "rstrip": false,
48
+ "single_word": false,
49
+ "special": false
50
+ },
51
+ "6": {
52
+ "content": "<file_sep>",
53
+ "lstrip": false,
54
+ "normalized": true,
55
+ "rstrip": false,
56
+ "single_word": false,
57
+ "special": false
58
+ },
59
+ "7": {
60
+ "content": "<t2v_token>",
61
+ "lstrip": false,
62
+ "normalized": true,
63
+ "rstrip": false,
64
+ "single_word": false,
65
+ "special": false
66
+ },
67
+ "8": {
68
+ "content": "<v2t_token>",
69
+ "lstrip": false,
70
+ "normalized": true,
71
+ "rstrip": false,
72
+ "single_word": false,
73
+ "special": false
74
+ },
75
+ "9": {
76
+ "content": "<|start_header_id|>",
77
+ "lstrip": false,
78
+ "normalized": true,
79
+ "rstrip": false,
80
+ "single_word": false,
81
+ "special": false
82
+ },
83
+ "10": {
84
+ "content": "<|end_header_id|>",
85
+ "lstrip": false,
86
+ "normalized": true,
87
+ "rstrip": false,
88
+ "single_word": false,
89
+ "special": false
90
+ },
91
+ "11": {
92
+ "content": "<|eot_id|>",
93
+ "lstrip": false,
94
+ "normalized": true,
95
+ "rstrip": false,
96
+ "single_word": false,
97
+ "special": false
98
+ },
99
+ "12": {
100
+ "content": "<extra_id_12>",
101
+ "lstrip": false,
102
+ "normalized": true,
103
+ "rstrip": false,
104
+ "single_word": false,
105
+ "special": false
106
+ },
107
+ "13": {
108
+ "content": "<extra_id_13>",
109
+ "lstrip": false,
110
+ "normalized": true,
111
+ "rstrip": false,
112
+ "single_word": false,
113
+ "special": false
114
+ },
115
+ "14": {
116
+ "content": "<extra_id_14>",
117
+ "lstrip": false,
118
+ "normalized": true,
119
+ "rstrip": false,
120
+ "single_word": false,
121
+ "special": false
122
+ },
123
+ "15": {
124
+ "content": "<extra_id_15>",
125
+ "lstrip": false,
126
+ "normalized": true,
127
+ "rstrip": false,
128
+ "single_word": false,
129
+ "special": false
130
+ },
131
+ "16": {
132
+ "content": "<extra_id_16>",
133
+ "lstrip": false,
134
+ "normalized": true,
135
+ "rstrip": false,
136
+ "single_word": false,
137
+ "special": false
138
+ },
139
+ "17": {
140
+ "content": "<extra_id_17>",
141
+ "lstrip": false,
142
+ "normalized": true,
143
+ "rstrip": false,
144
+ "single_word": false,
145
+ "special": false
146
+ },
147
+ "18": {
148
+ "content": "<extra_id_18>",
149
+ "lstrip": false,
150
+ "normalized": true,
151
+ "rstrip": false,
152
+ "single_word": false,
153
+ "special": false
154
+ },
155
+ "19": {
156
+ "content": "<extra_id_19>",
157
+ "lstrip": false,
158
+ "normalized": true,
159
+ "rstrip": false,
160
+ "single_word": false,
161
+ "special": false
162
+ },
163
+ "20": {
164
+ "content": "<extra_id_20>",
165
+ "lstrip": false,
166
+ "normalized": true,
167
+ "rstrip": false,
168
+ "single_word": false,
169
+ "special": false
170
+ },
171
+ "21": {
172
+ "content": "<extra_id_21>",
173
+ "lstrip": false,
174
+ "normalized": true,
175
+ "rstrip": false,
176
+ "single_word": false,
177
+ "special": false
178
+ },
179
+ "22": {
180
+ "content": "<extra_id_22>",
181
+ "lstrip": false,
182
+ "normalized": true,
183
+ "rstrip": false,
184
+ "single_word": false,
185
+ "special": false
186
+ },
187
+ "23": {
188
+ "content": "<extra_id_23>",
189
+ "lstrip": false,
190
+ "normalized": true,
191
+ "rstrip": false,
192
+ "single_word": false,
193
+ "special": false
194
+ },
195
+ "24": {
196
+ "content": "<extra_id_24>",
197
+ "lstrip": false,
198
+ "normalized": true,
199
+ "rstrip": false,
200
+ "single_word": false,
201
+ "special": false
202
+ },
203
+ "25": {
204
+ "content": "<extra_id_25>",
205
+ "lstrip": false,
206
+ "normalized": true,
207
+ "rstrip": false,
208
+ "single_word": false,
209
+ "special": false
210
+ },
211
+ "26": {
212
+ "content": "<extra_id_26>",
213
+ "lstrip": false,
214
+ "normalized": true,
215
+ "rstrip": false,
216
+ "single_word": false,
217
+ "special": false
218
+ },
219
+ "27": {
220
+ "content": "<extra_id_27>",
221
+ "lstrip": false,
222
+ "normalized": true,
223
+ "rstrip": false,
224
+ "single_word": false,
225
+ "special": false
226
+ },
227
+ "28": {
228
+ "content": "<extra_id_28>",
229
+ "lstrip": false,
230
+ "normalized": true,
231
+ "rstrip": false,
232
+ "single_word": false,
233
+ "special": false
234
+ },
235
+ "29": {
236
+ "content": "<extra_id_29>",
237
+ "lstrip": false,
238
+ "normalized": true,
239
+ "rstrip": false,
240
+ "single_word": false,
241
+ "special": false
242
+ },
243
+ "30": {
244
+ "content": "<extra_id_30>",
245
+ "lstrip": false,
246
+ "normalized": true,
247
+ "rstrip": false,
248
+ "single_word": false,
249
+ "special": false
250
+ },
251
+ "31": {
252
+ "content": "<extra_id_31>",
253
+ "lstrip": false,
254
+ "normalized": true,
255
+ "rstrip": false,
256
+ "single_word": false,
257
+ "special": false
258
+ },
259
+ "32": {
260
+ "content": "<extra_id_32>",
261
+ "lstrip": false,
262
+ "normalized": true,
263
+ "rstrip": false,
264
+ "single_word": false,
265
+ "special": false
266
+ },
267
+ "33": {
268
+ "content": "<extra_id_33>",
269
+ "lstrip": false,
270
+ "normalized": true,
271
+ "rstrip": false,
272
+ "single_word": false,
273
+ "special": false
274
+ },
275
+ "34": {
276
+ "content": "<extra_id_34>",
277
+ "lstrip": false,
278
+ "normalized": true,
279
+ "rstrip": false,
280
+ "single_word": false,
281
+ "special": false
282
+ },
283
+ "35": {
284
+ "content": "<extra_id_35>",
285
+ "lstrip": false,
286
+ "normalized": true,
287
+ "rstrip": false,
288
+ "single_word": false,
289
+ "special": false
290
+ },
291
+ "36": {
292
+ "content": "<extra_id_36>",
293
+ "lstrip": false,
294
+ "normalized": true,
295
+ "rstrip": false,
296
+ "single_word": false,
297
+ "special": false
298
+ },
299
+ "37": {
300
+ "content": "<extra_id_37>",
301
+ "lstrip": false,
302
+ "normalized": true,
303
+ "rstrip": false,
304
+ "single_word": false,
305
+ "special": false
306
+ },
307
+ "38": {
308
+ "content": "<extra_id_38>",
309
+ "lstrip": false,
310
+ "normalized": true,
311
+ "rstrip": false,
312
+ "single_word": false,
313
+ "special": false
314
+ },
315
+ "39": {
316
+ "content": "<extra_id_39>",
317
+ "lstrip": false,
318
+ "normalized": true,
319
+ "rstrip": false,
320
+ "single_word": false,
321
+ "special": false
322
+ },
323
+ "40": {
324
+ "content": "<extra_id_40>",
325
+ "lstrip": false,
326
+ "normalized": true,
327
+ "rstrip": false,
328
+ "single_word": false,
329
+ "special": false
330
+ },
331
+ "41": {
332
+ "content": "<extra_id_41>",
333
+ "lstrip": false,
334
+ "normalized": true,
335
+ "rstrip": false,
336
+ "single_word": false,
337
+ "special": false
338
+ },
339
+ "42": {
340
+ "content": "<extra_id_42>",
341
+ "lstrip": false,
342
+ "normalized": true,
343
+ "rstrip": false,
344
+ "single_word": false,
345
+ "special": false
346
+ },
347
+ "43": {
348
+ "content": "<extra_id_43>",
349
+ "lstrip": false,
350
+ "normalized": true,
351
+ "rstrip": false,
352
+ "single_word": false,
353
+ "special": false
354
+ },
355
+ "44": {
356
+ "content": "<extra_id_44>",
357
+ "lstrip": false,
358
+ "normalized": true,
359
+ "rstrip": false,
360
+ "single_word": false,
361
+ "special": false
362
+ },
363
+ "45": {
364
+ "content": "<extra_id_45>",
365
+ "lstrip": false,
366
+ "normalized": true,
367
+ "rstrip": false,
368
+ "single_word": false,
369
+ "special": false
370
+ },
371
+ "46": {
372
+ "content": "<extra_id_46>",
373
+ "lstrip": false,
374
+ "normalized": true,
375
+ "rstrip": false,
376
+ "single_word": false,
377
+ "special": false
378
+ },
379
+ "47": {
380
+ "content": "<extra_id_47>",
381
+ "lstrip": false,
382
+ "normalized": true,
383
+ "rstrip": false,
384
+ "single_word": false,
385
+ "special": false
386
+ },
387
+ "48": {
388
+ "content": "<extra_id_48>",
389
+ "lstrip": false,
390
+ "normalized": true,
391
+ "rstrip": false,
392
+ "single_word": false,
393
+ "special": false
394
+ },
395
+ "49": {
396
+ "content": "<extra_id_49>",
397
+ "lstrip": false,
398
+ "normalized": true,
399
+ "rstrip": false,
400
+ "single_word": false,
401
+ "special": false
402
+ },
403
+ "50": {
404
+ "content": "<extra_id_50>",
405
+ "lstrip": false,
406
+ "normalized": true,
407
+ "rstrip": false,
408
+ "single_word": false,
409
+ "special": false
410
+ },
411
+ "51": {
412
+ "content": "<extra_id_51>",
413
+ "lstrip": false,
414
+ "normalized": true,
415
+ "rstrip": false,
416
+ "single_word": false,
417
+ "special": false
418
+ },
419
+ "52": {
420
+ "content": "<extra_id_52>",
421
+ "lstrip": false,
422
+ "normalized": true,
423
+ "rstrip": false,
424
+ "single_word": false,
425
+ "special": false
426
+ },
427
+ "53": {
428
+ "content": "<extra_id_53>",
429
+ "lstrip": false,
430
+ "normalized": true,
431
+ "rstrip": false,
432
+ "single_word": false,
433
+ "special": false
434
+ },
435
+ "54": {
436
+ "content": "<extra_id_54>",
437
+ "lstrip": false,
438
+ "normalized": true,
439
+ "rstrip": false,
440
+ "single_word": false,
441
+ "special": false
442
+ },
443
+ "55": {
444
+ "content": "<extra_id_55>",
445
+ "lstrip": false,
446
+ "normalized": true,
447
+ "rstrip": false,
448
+ "single_word": false,
449
+ "special": false
450
+ },
451
+ "56": {
452
+ "content": "<extra_id_56>",
453
+ "lstrip": false,
454
+ "normalized": true,
455
+ "rstrip": false,
456
+ "single_word": false,
457
+ "special": false
458
+ },
459
+ "57": {
460
+ "content": "<extra_id_57>",
461
+ "lstrip": false,
462
+ "normalized": true,
463
+ "rstrip": false,
464
+ "single_word": false,
465
+ "special": false
466
+ },
467
+ "58": {
468
+ "content": "<extra_id_58>",
469
+ "lstrip": false,
470
+ "normalized": true,
471
+ "rstrip": false,
472
+ "single_word": false,
473
+ "special": false
474
+ },
475
+ "59": {
476
+ "content": "<extra_id_59>",
477
+ "lstrip": false,
478
+ "normalized": true,
479
+ "rstrip": false,
480
+ "single_word": false,
481
+ "special": false
482
+ },
483
+ "60": {
484
+ "content": "<extra_id_60>",
485
+ "lstrip": false,
486
+ "normalized": true,
487
+ "rstrip": false,
488
+ "single_word": false,
489
+ "special": false
490
+ },
491
+ "61": {
492
+ "content": "<extra_id_61>",
493
+ "lstrip": false,
494
+ "normalized": true,
495
+ "rstrip": false,
496
+ "single_word": false,
497
+ "special": false
498
+ },
499
+ "62": {
500
+ "content": "<extra_id_62>",
501
+ "lstrip": false,
502
+ "normalized": true,
503
+ "rstrip": false,
504
+ "single_word": false,
505
+ "special": false
506
+ },
507
+ "63": {
508
+ "content": "<extra_id_63>",
509
+ "lstrip": false,
510
+ "normalized": true,
511
+ "rstrip": false,
512
+ "single_word": false,
513
+ "special": false
514
+ }
515
+ },
516
+ "additional_special_tokens": [
517
+ "<repo_name>",
518
+ "<file_sep>",
519
+ "<t2v_token>",
520
+ "<v2t_token>",
521
+ "<|start_header_id|>",
522
+ "<|end_header_id|>",
523
+ "<|eot_id|>",
524
+ "<extra_id_12>",
525
+ "<extra_id_13>",
526
+ "<extra_id_14>",
527
+ "<extra_id_15>",
528
+ "<extra_id_16>",
529
+ "<extra_id_17>",
530
+ "<extra_id_18>",
531
+ "<extra_id_19>",
532
+ "<extra_id_20>",
533
+ "<extra_id_21>",
534
+ "<extra_id_22>",
535
+ "<extra_id_23>",
536
+ "<extra_id_24>",
537
+ "<extra_id_25>",
538
+ "<extra_id_26>",
539
+ "<extra_id_27>",
540
+ "<extra_id_28>",
541
+ "<extra_id_29>",
542
+ "<extra_id_30>",
543
+ "<extra_id_31>",
544
+ "<extra_id_32>",
545
+ "<extra_id_33>",
546
+ "<extra_id_34>",
547
+ "<extra_id_35>",
548
+ "<extra_id_36>",
549
+ "<extra_id_37>",
550
+ "<extra_id_38>",
551
+ "<extra_id_39>",
552
+ "<extra_id_40>",
553
+ "<extra_id_41>",
554
+ "<extra_id_42>",
555
+ "<extra_id_43>",
556
+ "<extra_id_44>",
557
+ "<extra_id_45>",
558
+ "<extra_id_46>",
559
+ "<extra_id_47>",
560
+ "<extra_id_48>",
561
+ "<extra_id_49>",
562
+ "<extra_id_50>",
563
+ "<extra_id_51>",
564
+ "<extra_id_52>",
565
+ "<extra_id_53>",
566
+ "<extra_id_54>",
567
+ "<extra_id_55>",
568
+ "<extra_id_56>",
569
+ "<extra_id_57>",
570
+ "<extra_id_58>",
571
+ "<extra_id_59>",
572
+ "<extra_id_60>",
573
+ "<extra_id_61>",
574
+ "<extra_id_62>",
575
+ "<extra_id_63>"
576
+ ],
577
+ "auto_map": {
578
+ "AutoProcessor": "processing_evabyte.EvaByteProcessor",
579
+ "AutoTokenizer": [
580
+ "tokenization_evabyte.EvaByteTokenizer",
581
+ null
582
+ ]
583
+ },
584
+ "bos_token": "<bos>",
585
+ "chat_template": "\n{{- bos_token }}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content'] %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{{- '<|start_header_id|>system<|end_header_id|>\n\n' + system_message + '<|eot_id|>'}}\n\n{%- for message in messages %}\n {%- if (message['role'] != 'user') and (message['role'] != 'assistant') %}\n {{- raise_exception('Conversation roles must be user or assistant') }}\n {%- endif %}\n\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] + '<|eot_id|>' }}\n{%- endfor %}\n\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>' + 'assistant' + '<|end_header_id|>\n\n' }}\n{%- endif %}\n",
586
+ "clean_up_tokenization_spaces": false,
587
+ "eos_token": "<eos>",
588
+ "extra_ids": 0,
589
+ "extra_special_tokens": {},
590
+ "model_max_length": 1000000000000000019884624838656,
591
+ "pad_token": "<pad>",
592
+ "processor_class": "EvaByteProcessor",
593
+ "sep_token": "<sep>",
594
+ "tokenizer_class": "EvaByteTokenizer",
595
+ "unk_token": "<unk>"
596
+ }
ckpts/ocpython_14b_bsz-2m_seq16k_docmask_multipredc2r8_90dynamic-10raw_transsentinel_minsize0ent98line16ow16pack_100B_2m_new_2_step-10000/README.md ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ ---
4
+ # EvaByte Model Card
5
+
6
+ **EvaByte** is a 6.5B **byte-level language model** built upon an improved architecture with multibyte prediction and EVA -- an efficient attention mechanism designed for scalability and performance. Trained on 1.5T bytes spanning natural language text, math, and code, EvaByte demonstrates the viability of efficient byte-level processing at scale -- rivaling top open-source tokenizer-based LMs using 5x less training data, excelling in coding tasks, and decoding up to 2x faster.
7
+
8
+ ## Model Resources
9
+
10
+ - **Repository:** https://github.com/openevabyte/evabyte
11
+ - **Blog:** https://hkunlp.github.io/blog/2025/evabyte and https://sambanova.ai/blog/evabyte-efficient-byte-level-language-models-at-scale
12
+ - **Paper:** Coming soon
13
+
14
+ ## Model Details
15
+
16
+ EvaByte is trained using the performant SambaNova SN30 RDU system with a batch size of 8M bytes and 32K context length. The training process consists of 3 phases: after pre-training on 1.2T bytes (yielding **EvaByte-Phase1**), two independent annealing runs (100B and 200B bytes respectively) are conducted with learning rate linearly decayed from 1e-4 to 0. The resulting checkpoints are merged via model soup (**EvaByte**), which then undergoes supervised fine-tuning (**EvaByte-SFT**).
17
+
18
+ | Stage | Model |
19
+ |:----- |:-----|
20
+ | Base (before annealing) | [EvaByte-Phase1](https://huggingface.co/evabyte/EvaByte-Phase1) |
21
+ | Base | [EvaByte](https://huggingface.co/evabyte/EvaByte) <-- you are here |
22
+ | SFT | [EvaByte-SFT](https://huggingface.co/evabyte/EvaByte-SFT) |
23
+
24
+
25
+ ## Usage
26
+
27
+ **Note:** Make sure to set `trust_remote_code=True` when loading the model (or tokenizer), as our implementation includes custom code.
28
+
29
+ The code snippet below demonstrates EvaByte-6.5B for completion:
30
+
31
+ ```python
32
+ from transformers import AutoTokenizer, AutoModelForCausalLM
33
+ import torch
34
+
35
+ # Load model and tokenizer
36
+ tokenizer = AutoTokenizer.from_pretrained("evabyte/EvaByte", trust_remote_code=True)
37
+ model = AutoModelForCausalLM.from_pretrained("evabyte/EvaByte", torch_dtype=torch.bfloat16, trust_remote_code=True).eval().to("cuda")
38
+
39
+ prompt = "The quick brown fox jumps "
40
+
41
+ # Tokenize input
42
+ # Option 1: standard HF tokenizer interface
43
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
44
+
45
+ # Option 2: Direct UTF-8 byte encoding with offset
46
+ # Note: Each byte is offset by 64 with <bos> prepended.
47
+ input_ids = torch.tensor([[1] + [b + 64 for b in prompt.encode("utf-8")]]).to("cuda")
48
+
49
+ # byte-by-byte generation (default)
50
+ generation_output = model.generate(
51
+ input_ids=input_ids,
52
+ max_new_tokens=32
53
+ )
54
+ # alternatively, use faster multibyte generation
55
+ generation_output = model.multi_byte_generate(
56
+ input_ids=input_ids,
57
+ max_new_tokens=32
58
+ )
59
+
60
+ # Decode and print the output
61
+ response = tokenizer.decode(
62
+ generation_output[0][input_ids.shape[1]:],
63
+ skip_special_tokens=False,
64
+ clean_up_tokenization_spaces=False
65
+ )
66
+ print(response)
67
+ # Sample output:
68
+ # over the lazy dog.\n\nThe quick
69
+ ```
70
+
71
+ ### ⚙️ Generation Modes
72
+
73
+ EvaByte supports two generation interfaces:
74
+ - `model.generate()`: The default generation method compatible with Huggingface `transformers` library. This approach generates one byte at a time and might be slow.
75
+ - `model.multi_byte_generate()`: A faster alternative that generates multiple bytes per step and usually yields the same result as `model.generate()` under greedy decoding, with the implementation adapted from [Medusa](https://github.com/FasterDecoding/Medusa). `model.multi_byte_generate()` supports a subset of arguments in `model.generate()`:
76
+ - `input_ids`: the input byte ids.
77
+ - `temperature`: the temperature for sampling.
78
+ - `max_length`: the maximum length of the generated sequence.
79
+ - `max_new_tokens`: the maximum number of new bytes to generate.
80
+ - `stopping_criteria`: the [stopping criteria](https://huggingface.co/docs/transformers/v4.47.1/en/internal/generation_utils#transformers.StoppingCriteria) for generation.
81
+ - `top_p`: the top-p parameter for sampling.
82
+ - `do_sample`: greedy decoding or sampling.
83
+
84
+ **Notes and Limitations:**
85
+ - `device_map="auto"` is not supported for >2 GPUs.
86
+ - Only batch size of 1 (with `attention_mask=None`) is supported for decoding.
87
+ - `torch_dtype=torch.bfloat16` is required.
88
+ - The multibyte generation `model.multi_byte_generate()` might return extra bytes after the end-of-sequence sentinel, due to the nature of the multibyte decoding. Manual truncation or cleaning may be needed.
89
+
90
+ ## Bias, Risks, and Limitations
91
+ As a pretrained base model, **EvaByte** has not been fine-tuned for chat or instruction following, so users should not expect reliable performance in conversational or instruction-based tasks. Like other base models, it does not incorporate any moderation mechanisms, making it possible to generate potentially harmful or inappropriate content.
92
+
93
+ ## Evaluation
94
+
95
+ For detailed evaluation results, check out our blog post at [SambaNova](https://sambanova.ai/blog/evabyte-efficient-byte-level-language-models-at-scale) or [HKUNLP](https://hkunlp.github.io/blog/2025/evabyte).
96
+
97
+ ## Citation
98
+ ```bibtex
99
+ @misc{evabyte,
100
+ title = {EvaByte: Efficient Byte-level Language Models at Scale},
101
+ url = {https://hkunlp.github.io/blog/2025/evabyte},
102
+ author = {Lin Zheng and Xueliang Zhao and Guangtao Wang and Chen Wu and David Dong and Angela Wang and Mingran Wang and Yun Du and Haige Bo and Amol Sharma and Bo Li and Kejie Zhang and Changran Hu and Urmish Thakker and Lingpeng Kong},
103
+ year = {2025}
104
+ }
105
+ ```