sochasticbackup committed
Commit ea3734f · 1 Parent(s): 2997d61

second init with torch

.gitattributes DELETED
@@ -1,35 +0,0 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore CHANGED
@@ -1,216 +1,46 @@
- # Byte-compiled / optimized / DLL files
__pycache__/
- *.py[codz]
*$py.class
-
- # C extensions
*.so
-
- # Distribution / packaging
.Python
- build/
- develop-eggs/
- dist/
- downloads/
- eggs/
- .eggs/
- lib/
- lib64/
- parts/
- sdist/
- var/
- wheels/
- share/python-wheels/
*.egg-info/
- .installed.cfg
- *.egg
- MANIFEST
-
- # PyInstaller
- # Usually these files are written by a python script from a template
- # before PyInstaller builds the exe, so as to inject date/other infos into it.
- *.manifest
- *.spec
-
- # Installer logs
- pip-log.txt
- pip-delete-this-directory.txt
-
- # Unit test / coverage reports
- htmlcov/
- .tox/
- .nox/
- .coverage
- .coverage.*
- .cache
- nosetests.xml
- coverage.xml
- *.cover
- *.py.cover
- .hypothesis/
- .pytest_cache/
- cover/
-
- # Translations
- *.mo
- *.pot
-
- # Django stuff:
- *.log
- local_settings.py
- db.sqlite3
- db.sqlite3-journal
-
- # Flask stuff:
- instance/
- .webassets-cache
-
- # Scrapy stuff:
- .scrapy
-
- # Sphinx documentation
- docs/_build/
-
- # PyBuilder
- .pybuilder/
- target/
-
- # Jupyter Notebook
- .ipynb_checkpoints
-
- # IPython
- profile_default/
- ipython_config.py
-
- # pyenv
- # For a library or package, you might want to ignore these files since the code is
- # intended to run in multiple environments; otherwise, check them in:
- # .python-version
-
- # pipenv
- # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
- # However, in case of collaboration, if having platform-specific dependencies or dependencies
- # having no cross-platform support, pipenv may install dependencies that don't work, or not
- # install all needed dependencies.
- # Pipfile.lock
-
- # UV
- # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
- # This is especially recommended for binary packages to ensure reproducibility, and is more
- # commonly ignored for libraries.
- # uv.lock
-
- # poetry
- # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
- # This is especially recommended for binary packages to ensure reproducibility, and is more
- # commonly ignored for libraries.
- # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
- # poetry.lock
- # poetry.toml
-
- # pdm
- # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
- # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
- # https://pdm-project.org/en/latest/usage/project/#working-with-version-control
- # pdm.lock
- # pdm.toml
- .pdm-python
- .pdm-build/
-
- # pixi
- # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
- # pixi.lock
- # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
- # in the .venv directory. It is recommended not to include this directory in version control.
- .pixi
-
- # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
- __pypackages__/
-
- # Celery stuff
- celerybeat-schedule
- celerybeat.pid
-
- # Redis
- *.rdb
- *.aof
- *.pid
-
- # RabbitMQ
- mnesia/
- rabbitmq/
- rabbitmq-data/
-
- # ActiveMQ
- activemq-data/
-
- # SageMath parsed files
- *.sage.py

- # Environments
- .env
- .envrc
- .venv
- env/
venv/
ENV/
- env.bak/
- venv.bak/
-
- # Spyder project settings
- .spyderproject
- .spyproject
-
- # Rope project settings
- .ropeproject
-
- # mkdocs documentation
- /site
-
- # mypy
- .mypy_cache/
- .dmypy.json
- dmypy.json

- # Pyre type checker
- .pyre/

- # pytype static type analyzer
- .pytype/

- # Cython debug symbols
- cython_debug/

- # PyCharm
- # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
- # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
- # and can be added to the global gitignore or merged into this file. For a more nuclear
- # option (not recommended) you can uncomment the following to ignore the entire idea folder.
- # .idea/

- # Abstra
- # Abstra is an AI-powered process automation framework.
- # Ignore directories containing user credentials, local state, and settings.
- # Learn more at https://abstra.io/docs
- .abstra/
-
- # Visual Studio Code
- # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
- # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
- # and can be added to the global gitignore or merged into this file. However, if you prefer,
- # you could uncomment the following to ignore the entire vscode folder
- # .vscode/
-
- # Ruff stuff:
- .ruff_cache/
-
- # PyPI configuration file
- .pypirc
-
- # Marimo
- marimo/_static/
- marimo/_lsp/
- __marimo__/

- # Streamlit
- .streamlit/secrets.toml

+ # Python
__pycache__/
+ *.py[cod]
*$py.class
*.so
.Python
*.egg-info/
+ dist/
+ build/

+ # Virtual environments
venv/
+ env/
ENV/

+ # IDE
+ .vscode/
+ .idea/
+ *.swp
+ *.swo
+ *~

+ # OS
+ .DS_Store
+ Thumbs.db

+ # Gradio
+ gradio_cached_examples/
+ flagged/

+ # Model cache (these will download automatically on HF Spaces)
+ *.bin
+ *.safetensors
+ models/
+ checkpoints/

+ # Logs
+ *.log
+ DEPLOY.md
+ DEPLOYMENT_READY.md
+ README_TASK.md
+ SETUP_NOTES.md
+ verify_deployment.py
+ verify_deployment.py
+ install_stripedhyena.py
README.md CHANGED
@@ -8,7 +8,7 @@ sdk_version: 4.44.0
app_file: app.py
pinned: false
license: apache-2.0
- python_version: 3.11
+ python_version: 3.10
---

Check configuration
evo/configs/evo-1-8k-base_inference.yml CHANGED
@@ -2,8 +2,8 @@ vocab_size: 512
hidden_size: 4096
num_filters: 4096
max_sequence_len: 8192
- attn_layer_idxs: [8, 16, 24]
- hyena_layer_idxs: [0, 1, 2, 3, 4, 5, 6, 7, 9, 10, 11, 12, 13, 14, 15, 17, 18, 19, 20, 21, 22, 23, 25, 26, 27, 28, 29, 30, 31]
+ attn_layer_idxs: []
+ hyena_layer_idxs: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]
num_layers: 32
short_filter_length: 3
num_attention_heads: 32
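
This config change drops the three attention layers and runs all 32 layers as Hyena blocks. A minimal sketch (assuming a checkout of this repo, with pyyaml from requirements.txt) of how these two lists drive block dispatch, mirroring get_block in stripedhyena/model.py added below:

import yaml

with open("evo/configs/evo-1-8k-base_inference.yml") as f:
    config = yaml.safe_load(f)

for layer_idx in range(config["num_layers"]):
    if layer_idx in config["attn_layer_idxs"]:
        kind = "attention"  # AttentionBlock in model.py
    elif layer_idx in config["hyena_layer_idxs"]:
        kind = "hyena"      # ParallelGatedConvBlock in model.py
    else:
        raise NotImplementedError
    print(layer_idx, kind)  # with this commit's config: all "hyena"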
requirements.txt CHANGED
@@ -1,7 +1,7 @@
- gradio==4.44.0
torch==2.1.0
numpy==1.24.3
transformers==4.36.0
einops==0.7.0
pyyaml==6.0.1
- git+https://github.com/togethercomputer/stripedhyena.git
+ tokenizers>=0.15.0
+ gradio==4.44.0
stripedhyena/__init__.py ADDED
File without changes
stripedhyena/cache.py ADDED
@@ -0,0 +1,46 @@
+ # Copyright (c) Together
+ # This software is distributed under the terms of the Apache License, Version 2.0
+ # Author: Michael Poli
+
+ from dataclasses import dataclass, field
+ from typing import Optional
+
+ from torch import Tensor
+
+
+ # https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py
+ @dataclass
+ class InferenceParams:
+     """Inference parameters that are passed to the main model in order
+     to efficiently calculate and store the context during inference."""
+
+     max_seqlen: int
+     max_batch_size: int
+     seqlen_offset: int = 0
+     batch_size_offset: int = 0
+     key_value_memory_dict: dict = field(default_factory=dict)
+     lengths_per_sample: Optional[Tensor] = None
+
+     def reset(self, max_seqlen, max_batch_size):
+         self.max_seqlen = max_seqlen
+         self.max_batch_size = max_batch_size
+         self.seqlen_offset = 0
+         if self.lengths_per_sample is not None:
+             self.lengths_per_sample.zero_()
+
+
+ @dataclass
+ class RecurrentInferenceParams:
+     """Inference parameters passed to blocks with recurrent mode."""
+
+     fir_filter_length: int = 3
+     state_dim: int = 16
+     # seqlen_offset not used
+     seqlen_offset: int = 0
+     fir_state_dict: dict = field(default_factory=dict)
+     state_dict: dict = field(default_factory=dict)
+
+     def reset(self):
+         self.fir_filter_length = 3
+         self.state_dim = 16
+         self.seqlen_offset = 0
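
A short usage sketch of the two cache types above, following initialize_inference_params in stripedhyena/model.py added below; the sizes are illustrative:

from stripedhyena.cache import InferenceParams, RecurrentInferenceParams

# One cache per block family: attention blocks keep a KV cache and an offset,
# Hyena blocks keep per-layer FIR/IIR filter states in the two dicts.
inference_params_dict = {
    "mha": InferenceParams(max_seqlen=8192, max_batch_size=1),
    "hyena": RecurrentInferenceParams(fir_filter_length=3, state_dim=16),
}

# During decoding, the offsets advance one token per step (see generation.py below).
inference_params_dict["mha"].seqlen_offset += 1
inference_params_dict["hyena"].seqlen_offset += 1

# reset() restores the default lengths/offsets; the filter states themselves
# live in fir_state_dict and state_dict, keyed by layer index.
inference_params_dict["hyena"].reset()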
stripedhyena/engine.py ADDED
@@ -0,0 +1,388 @@
+ # Copyright (c) Together
+ # This software is distributed under the terms of the Apache License, Version 2.0
+ # Author: Michael Poli
+ import gc
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ try:
+     import conv1d_cpp
+ except ImportError:
+     pass
+ from stripedhyena.utils import column_split
+
+ IIR_PREFILL_MODES = [
+     "recurrence",
+     "modal-fft",
+     "hybrid-modal-recurrence",
+     "modal-scan",
+     "canonical-fft",
+     "iir-fir-caching",
+ ]
+
+
+ def canonicalize_modal_system(poles, residues):
+     """Canonicalize a modal system.
+
+     Args:
+         poles (Tensor): The poles of the system.
+         residues (Tensor): The residues of the system.
+
+     Returns:
+         Tuple[Tensor, Tensor]: The canonicalized poles and residues.
+     """
+     raise NotImplementedError
+
+
+ def list_tensors(idx):
+     for obj in gc.get_objects():
+         try:
+             if torch.is_tensor(obj) and isinstance(obj, torch.Tensor):
+                 # dump to log
+                 print(type(obj), obj.size())
+                 el = obj[0]
+                 with open(f"tensors_{idx}.txt", "a") as f:
+                     f.write(f"{type(obj)} {obj.size()} {el}\n")
+         except Exception:
+             pass
+
+
+ class HyenaInferenceEngine:
+     def __init__(
+         self,
+         fir_fn=None,
+         iir_prefill_style="modal-fft",
+         layer_idx=None,
+     ) -> None:
+         self.fir_fn = fir_fn
+         assert iir_prefill_style in IIR_PREFILL_MODES, f"iir_prefill_style must be one of {IIR_PREFILL_MODES}"
+         self.iir_prefill_style = iir_prefill_style
+         self.layer_idx = layer_idx
+         self.low_mem_mode = False
+
+     def parallel_fir(
+         self,
+         fir_fn,
+         u,
+         weight,
+         bias,
+         L,
+         fir_length=3,
+         inference_params=None,
+         prefill_mode=None,
+         padding_mask=None,
+     ):
+         """Compute the output state of the short convolutional (FIR) filter."""
+         # prepare input layout, dimensions and dispatch to fir kernel
+         if fir_fn != torch.nn.functional.conv1d:
+             z_pre = fir_fn(u)[:, :L]  # B, L, D
+             z_pre = z_pre.permute(0, 2, 1)
+         else:
+             u = u.permute(0, 2, 1)  # B, D, L
+             z_pre = fir_fn(
+                 u,
+                 weight,
+                 bias=None,  # don't pass it here, add manually instead! source of small error
+                 stride=1,
+                 padding=fir_length - 1,
+                 groups=u.shape[1],
+             )[..., :L]
+
+         # add manually instead! source of small error
+         z_pre = z_pre + bias[None, :, None]
+
+         # handle padding post fir, the only place with biases
+         if type(padding_mask) == torch.Tensor:
+             z_pre = z_pre * padding_mask[:, None]
+
+         if inference_params is not None:
+             # handle seqlen last and dim last cases for `u`
+             if fir_fn != torch.nn.functional.conv1d:
+                 fir_state = u[:, -fir_length + 1 :].permute(0, 2, 1)
+             else:
+                 fir_state = u[..., -fir_length + 1 :]
+         else:
+             fir_state = None
+
+         return z_pre, fir_state
+
+     def parallel_iir(
+         self,
+         z_pre,
+         h,
+         D,
+         L,
+         poles,
+         residues,
+         t,
+         dims,
+         layer_idx,
+         inference_params=None,
+         prefill_style="fft",
+         fftconv_fn=None,
+         padding_mask=None,
+         use_flashfft=False,
+         column_split_hyena=False,
+         long_fir_threshold=None,
+     ):
+         """Compute the output state of the long convolutional (IIR) filter."""
+         fft_size = 2 * L
+         hidden_size, num_attention_heads, hidden_size_per_attention_head, _, _ = dims
+         # Compatibility with training infra that column splits the projections
+         if column_split_hyena:
+             z = z_pre.reshape(
+                 z_pre.shape[0],
+                 num_attention_heads,
+                 3 * hidden_size_per_attention_head,
+                 z_pre.shape[2],
+             )
+             x2, x1, v = (
+                 z[:, :, :hidden_size_per_attention_head],
+                 z[
+                     :,
+                     :,
+                     hidden_size_per_attention_head : 2 * hidden_size_per_attention_head,
+                 ],
+                 z[:, :, 2 * hidden_size_per_attention_head :],
+             )
+             x2, x1, v = (
+                 x2.reshape(x2.shape[0], -1, x2.shape[-1]),
+                 x1.reshape(x1.shape[0], -1, x1.shape[-1]),
+                 v.reshape(v.shape[0], -1, v.shape[-1]),
+             )
+         else:
+             x2, x1, v = z_pre.split([hidden_size, hidden_size, hidden_size], dim=1)
+
+         x1v = x1 * v
+
+         if inference_params is not None and prefill_style == "recurrence":
+             y = self.prefill_via_direct_recurrence(
+                 inference_params=inference_params,
+                 x1v=x1v,
+                 L=L,
+                 poles=poles,
+                 residues=residues,
+             )
+
+         else:
+             if use_flashfft and (L % 2) == 0:  # only works with even L
+                 y = fftconv_fn(
+                     x1v.to(dtype=torch.bfloat16).contiguous(),
+                     h.to(dtype=torch.float32),
+                 )
+                 X_s = None
+
+             elif long_fir_threshold is None:
+                 H = torch.fft.rfft(h.to(dtype=torch.float32), n=fft_size) / fft_size
+                 X_s = torch.fft.fft(x1v.to(dtype=torch.float32), n=fft_size)
+                 X = X_s[..., : H.shape[-1]]
+                 if len(z_pre.shape) > 3:
+                     H = H.unsqueeze(1)
+                 y = torch.fft.irfft(X * H, n=fft_size, norm="forward")[..., :L]
+
+             else:
+                 assert h.shape[0] == 1, "batch size must be 1 for long_fir_threshold"
+                 h = h[0][:, None]  # rearrange to d, 1, l for depthwise conv1d
+                 h = h[..., :long_fir_threshold]
+                 y = F.conv1d(
+                     x1v,
+                     h.to(dtype=x1v.dtype),
+                     stride=1,
+                     groups=x1v.shape[1],
+                     padding=h.shape[-1] - 1,
+                 )[..., :L]
+
+         y = y.to(dtype=x1v.dtype)
+         y = (y + x1v * D.unsqueeze(-1)) * x2
+
+         if inference_params is not None:
+             if prefill_style == "fft":
+                 self.prefill_via_modal_fft(
+                     inference_params=inference_params,
+                     x1v=x1v,
+                     X_s=X_s,
+                     L=L,
+                     t=t,
+                     poles=poles,
+                     dims=dims,
+                     layer_idx=layer_idx,
+                     use_flashfft=use_flashfft,
+                     fftconv_fn=fftconv_fn,
+                 )
+
+             elif prefill_style == "recurrence":
+                 # recurrent prefill is done before
+                 pass
+             else:
+                 raise NotImplementedError
+             if self.low_mem_mode:
+                 # TODO: smarter gc
+                 del z_pre, x2, x1, v, x1v, h, poles, residues
+                 torch.cuda.empty_cache()
+
+         return y.permute(0, 2, 1)
+
+     def step_fir(self, u, fir_state, weight, bias=None):
+         """Step the FIR filter.
+
+         Note:
+             `fir_state` contains the last `short_filter_length - 1` elements of `u`: `u_{L-2}, u_{L-1}, ...`
+             We assume dimensions of `short_filter_weight` to be `[d, 1, short_filter_len]` (SISO / multi SISO layout).
+         """
+         h0, h = weight[..., 0, -1], weight[..., 0, :-1]
+         h0, h = h0[None], h[None]
+         y = h0 * u + torch.sum(fir_state * h, dim=-1) + bias
+
+         # update
+         fir_state = torch.roll(fir_state, -1, dims=2)
+         fir_state[..., -1] = u
+         return y, fir_state
+
+     def step_iir(self, x2, x1, v, D, residues, poles, iir_state, iir_groups=1):
+         x1v = x1 * v
+
+         residues, poles = (
+             torch.view_as_complex(residues.to(torch.float32)),
+             torch.view_as_complex(poles.to(torch.float32)),
+         )
+         # squeeze the dummy seqlen dimension
+         # D, state_dim, 1 -> 1, D, state_dim
+         residues, poles = residues[..., 0][None], poles[..., 0][None]
+         iir_state = poles * iir_state + x1v[..., None]
+
+         res_state = torch.sum(residues * iir_state, dim=-1).real
+
+         if iir_groups > 1:
+             raise NotImplementedError
+         y = x2 * (res_state + D * x1v)
+
+         return y, iir_state
+
+     def prefill_via_fir_caching(self, u, inference_params, L, *args, **kwargs):
+         """Turns the IIR filter into a FIR and uses a cache for decoding."""
+         raise NotImplementedError(":)")
+
+     def prefill_via_direct_recurrence(
+         self, inference_params, x1v, L, residues, poles, *args, **kwargs
+     ) -> torch.Tensor:
+         """
+         Compute the IIR state via explicit SSM recurrence (modal form).
+
+         This is the most memory efficient prefilling method for Hyena filters.
+
+         Note:
+             dtypes: [state: float32, poles: float32, x1v: bfloat16, output: bfloat16]
+         """
+         state_dim = poles.shape[1]
+         x1v_ = x1v[..., None, None]  # b, d, l, sdim, reim
+         x1v_ = x1v_.repeat(1, 1, 1, state_dim, 2)  # b, d, l, sdim, reim
+         x1v_[..., 1] = 0
+
+         state = 0 * x1v_[:, :, 0]
+         output = 0 * x1v_[:, :, :, 0, 0]  # b, d, l
+
+         # suppress dummy seqlen dimension
+         poles = poles[:, :, 0][None]
+         residues = residues[:, :, 0][None].repeat(x1v_.shape[0], 1, 1, 1)  # b, d, sdim, reim
+
+         # state: b, d, sdim, reim
+         # poles: 1, d, sdim, reim
+         # x1v_: b, d, l, sdim, reim
+         for i in range(L):
+             state[..., 0] = poles[..., 0] * state[..., 0] - poles[..., 1] * state[..., 1] + x1v_[:, :, i, :, 0]
+             state[..., 1] = poles[..., 0] * state[..., 1] + poles[..., 1] * state[..., 0] + x1v_[:, :, i, :, 1]
+             output[:, :, i] = torch.sum(residues * state, dim=-2)[..., 0]  # .real
+
+         inference_params.state_dict[self.layer_idx] = torch.view_as_complex(state.to(dtype=torch.float32))
+
+         return output
+
+     def prefill_via_hybrid_recurrence(self, inference_params, u, log_poles, x1v_f_a, L, *args, **kwargs):
+         """
+         Compute the IIR state via hybrid recurrence-convolution over blocks.
+         """
+         raise NotImplementedError(":)")
+
+     def prefill_via_scan(self, u, inference_params=None, *args, **kwargs):
+         raise NotImplementedError
+
+     def prefill_via_canonical_fft(self, u, inference_params=None, *args, **kwargs):
+         """
+         Compute the IIR state via a single FFT with the denominator of the SSM in companion form.
+
+         This is the most memory efficient "parallelized" prefilling method for Hyena.
+
+         From: https://arxiv.org/abs/2310.18780
+         """
+         raise NotImplementedError(":)")
+
+     def prefill_via_modal_fft(
+         self,
+         inference_params,
+         x1v,
+         L,
+         poles,
+         t,
+         dims,
+         layer_idx,
+         X_s=None,
+         use_flashfft=False,
+         fftconv_fn=None,
+         state_dtype=torch.complex64,
+         *args,
+         **kwargs,
+     ):
+         """
+         Compute the IIR state via a single FFT, using the poles of the SSM in modal form.
+         """
+         # When the model has a long convolution derived from a SSM in modal form and prefill_style is "fft",
+         # we split the filter into poles and residues and reuse FFT computation on the input.
+         # This optimization is currently not supported when using flashfftconv.
+         hidden_size, _, _, state_size, hyena_filter_groups = dims
+
+         if use_flashfft:
+             # using real states
+             poles = poles.squeeze().reshape(poles.shape[0], -1)[..., None]
+
+             state_s = poles**t
+             if hyena_filter_groups > 1:
+                 raise NotImplementedError
+
+             x1v = x1v[:, :, None].repeat(1, 1, 2 * state_size, 1)
+             x1v = x1v.reshape(x1v.shape[0], -1, x1v.shape[-1])
+             state_s = state_s[None]
+
+             state = fftconv_fn(
+                 x1v.contiguous(),
+                 state_s.to(dtype=torch.float32),
+             )
+             state = state[..., L - 1].reshape(x1v.shape[0], hidden_size, state_size, 2)
+             state = torch.view_as_complex(state.contiguous().to(dtype=torch.float32))
+             inference_params.state_dict[self.layer_idx] = state
+         else:
+             assert X_s is not None
+             bs = x1v.shape[0]
+             fft_size = 2 * L
+             poles = torch.view_as_complex(poles.to(torch.float32))
+             state_s = poles**t
+             state_S = torch.fft.fft(state_s, n=fft_size).repeat(bs, 1, 1, 1)  # B, D, state_dim, 2 * L
+             if hyena_filter_groups > 1:
+                 state_S = state_S.repeat_interleave(hidden_size // hyena_filter_groups, 1)
+             state = torch.fft.ifft(X_s[..., None, :] * state_S, n=fft_size)
+             inference_params.state_dict[layer_idx] = state[..., L - 1].to(dtype=state_dtype)
+
+     def _compute_state(self, log_poles, u, t, L, *args, **kwargs):
+         """
+         Compute the IIR state given an input `u` and log_poles of the modal system.
+         """
+         bs = u.shape[0]
+         fft_size = 2 * L
+         U = torch.fft.rfft(u.to(torch.float32), n=fft_size)
+         x = (log_poles * t).exp()
+         # [batch, hidden_size, state_dim, 2 * seqlen]
+         X = torch.fft.fft(x, n=fft_size).repeat(bs, 1, 1, 1)
+         state = torch.fft.ifft(U[..., None, :] * X, n=fft_size)[..., :L]
+         return state
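
As a sanity check on the FFT branch of parallel_iir above: the long filter is materialized from the modal parameters as h[t] = sum_s (r_s * p_s**t).real, and circular FFT convolution with 2*L zero padding equals causal convolution. A self-contained sketch with toy shapes (not the model's code path, just the identity it relies on):

import torch

torch.manual_seed(0)
B, D, L, S = 2, 4, 32, 8  # batch, channels, length, state dim (toy sizes)

# Modal parameterization: h[t] = sum_s (residues_s * poles_s**t).real
poles = 0.9 * torch.rand(D, S, 1) * torch.exp(1j * torch.rand(D, S, 1))
residues = torch.randn(D, S, 1, dtype=torch.cfloat)
t = torch.arange(L, dtype=torch.float32)[None, None]  # 1, 1, L
h = (residues * poles**t).real.sum(1)                 # D, L

x = torch.randn(B, D, L)

# FFT path (as in parallel_iir): pad to 2L so circular conv == linear conv
fft_size = 2 * L
y_fft = torch.fft.irfft(
    torch.fft.rfft(x, n=fft_size) * torch.fft.rfft(h, n=fft_size),
    n=fft_size,
)[..., :L]

# Reference: explicit causal convolution y[i] = sum_d h[d] * x[i - d]
y_ref = torch.zeros(B, D, L)
for i in range(L):
    for d in range(i + 1):
        y_ref[:, :, i] += h[:, d] * x[:, :, i - d]

assert torch.allclose(y_fft, y_ref, atol=1e-3, rtol=1e-3)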
stripedhyena/generation.py ADDED
@@ -0,0 +1,158 @@
+ # Copyright (c) Together
+ # This software is distributed under the terms of the Apache License, Version 2.0
+ # Author: Michael Poli
+
+ # Barebones generation class for standalone inference.
+
+ import torch
+
+ from stripedhyena.sample import sample
+ from stripedhyena.tokenizer import CharLevelTokenizer
+ from stripedhyena.utils import print_rank_0
+
+
+ class Generator:
+     def __init__(self, model, tokenizer, top_k=50, top_p=0.7, temperature=1):
+         self.model = model
+         self.tokenizer = tokenizer
+         self.top_k = top_k
+         self.top_p = top_p
+         self.temperature = temperature
+         self.untils = ["\n\n"]
+
+     def generate(
+         self,
+         device,
+         input_string=None,
+         input_ids=None,
+         num_tokens=32,
+         cached_generation=False,
+         print_generation=True,
+         verbose=False,
+         skip_special_tokens=False,
+         stop_at_eos=True,
+         max_seqlen=None,
+     ):
+         if isinstance(self.tokenizer.eos, int):
+             eos_token_ids = torch.LongTensor([self.tokenizer.eos]).to(device)
+         else:
+             # is a tensor
+             eos_token_ids = self.tokenizer.tokenize(self.tokenizer.eos).to(device)
+
+         if input_ids is None:
+             input = self.tokenizer.tokenize(input_string)
+             if isinstance(input, list):
+                 input = torch.LongTensor(input).unsqueeze(0).to(device)
+             # is a tensor
+             else:
+                 input = input.unsqueeze(0).to(device)
+         else:
+             input = input_ids
+         x = input
+
+         if max_seqlen is not None:
+             x = x[:, -max_seqlen:]
+
+         prompt_len = x.shape[-1]
+
+         num_tokens = int(num_tokens)
+         tot_length = prompt_len + num_tokens
+         batch_size = x.shape[0]
+
+         generation = torch.empty(
+             x.shape[0],
+             num_tokens,
+             dtype=torch.long,
+             device=x.device,
+         )
+
+         scores = torch.empty(
+             x.shape[0],
+             num_tokens,
+             self.tokenizer.vocab_size,
+             dtype=torch.float,
+             device=x.device,
+         )
+
+         if cached_generation:
+             inference_params_dict_out = self.model.initialize_inference_params()
+             inference_params_dict_out["mha"].max_batch_size = batch_size
+             inference_params_dict_out["hyena"].max_batch_size = batch_size
+         else:
+             inference_params_dict_out = None
+
+         if verbose:
+             mem_after_tok = torch.cuda.memory_allocated(device=x.device) / 1e9
+             print_rank_0(f"Memory after tokenization: {mem_after_tok} GB")
+             print_rank_0("Starting generation...")
+             if input_string is not None:
+                 print_rank_0("Prompt: " + input_string)
+             else:
+                 print_rank_0(f"Prompt ids: {input_ids} {input_ids.shape}")
+
+         for i in range(int(num_tokens)):
+             post_prefill = cached_generation and i > 0
+             # prefill then process only the last token
+             if post_prefill:
+                 x = x[:, -1:]
+                 seqlen_offset = inference_params_dict_out["mha"].seqlen_offset
+
+                 if seqlen_offset == 0:
+                     seqlen_offset = input.shape[-1]
+                     inference_params_dict_out["hyena"].seqlen_offset = seqlen_offset
+                     inference_params_dict_out["mha"].seqlen_offset = seqlen_offset
+                 else:
+                     inference_params_dict_out["mha"].seqlen_offset += 1
+                     inference_params_dict_out["hyena"].seqlen_offset += 1
+
+             # do forward pass with no gradient
+             with torch.no_grad():
+                 logits, inference_params_dict_out = self.model(
+                     x,
+                     inference_params_dict=inference_params_dict_out,
+                 )
+
+             last_logits = logits[:, -1]
+
+             new_idx = sample(
+                 last_logits,
+                 top_k=self.top_k,
+                 top_p=self.top_p,
+                 temperature=self.temperature,
+             )
+
+             if stop_at_eos and (generation[0, -2:] == eos_token_ids).all():
+                 print_rank_0("Stopping generation at EOS")
+
+             if print_generation and verbose and batch_size == 1:
+                 print_rank_0(
+                     f"{self.tokenizer.detokenize([new_idx.item()])}",
+                     end=" ",
+                 )
+
+             scores[:, i] = last_logits
+             generation[:, i] = new_idx
+
+             if post_prefill:
+                 x = new_idx[:, None]
+             else:
+                 x = torch.cat([x, new_idx[:, None]], dim=-1)
+
+         if verbose:
+             kwargs = {}
+             if not isinstance(self.tokenizer, CharLevelTokenizer):
+                 kwargs["skip_special_tokens"] = skip_special_tokens
+             y = self.tokenizer.detokenize_batch(generation[:, : i + 1], **kwargs)
+
+             for until in self.untils:
+                 if until in y:
+                     y = y.split(until)[0]
+                     break
+
+             print_rank_0(f"\nInput: {input_string}, Output: {y}")
+
+             mem_end = torch.cuda.memory_allocated(device=x.device) / 1e9
+             print_rank_0(f"Memory after generation: {mem_end} GB")
+
+         return generation[:, : i + 1], scores[:, : i + 1]
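
For orientation, a sketch of the cached-generation contract the loop above implements: prefill once on the full prompt, then feed only the newest token while advancing seqlen_offset. The model here is a stand-in that returns random logits (vocab size 512 matches the char-level config), not the real StripedHyena:

import torch

class DummyModel(torch.nn.Module):
    vocab_size = 512

    def forward(self, x, inference_params_dict=None):
        logits = torch.randn(x.shape[0], x.shape[1], self.vocab_size)
        return logits, inference_params_dict

model = DummyModel()
x = torch.randint(0, 512, (1, 16))          # token ids for a 16-token "prompt"
seqlen_offset = 0

for i in range(8):                           # decode 8 new tokens
    inp = x if i == 0 else x[:, -1:]         # prefill once, then one token a step
    logits, _ = model(inp)
    new_idx = logits[:, -1].argmax(dim=-1)   # greedy here; Generator uses sample()
    seqlen_offset = x.shape[-1] if i == 0 else seqlen_offset + 1
    x = torch.cat([x, new_idx[:, None]], dim=-1)

print(x.shape)  # torch.Size([1, 24])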
stripedhyena/layers.py ADDED
@@ -0,0 +1,146 @@
+ # Copyright (c) Together
+ # This software is distributed under the terms of the Apache License, Version 2.0
+ # Author: Michael Poli
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from einops import rearrange
+ from torch import Tensor
+
+ from stripedhyena.utils import grab_first_if_tuple
+
+
+ class RMSNorm(torch.nn.Module):
+     def __init__(self, config):
+         super(RMSNorm, self).__init__()
+         self.eps, self.hidden_size = config.eps, config.hidden_size
+         self.scale = torch.nn.Parameter(torch.ones(self.hidden_size))
+         self.register_parameter("scale", self.scale)
+         self.scale = self.scale.to(config.params_dtype)
+         self.use_flash_rmsnorm = config.get("use_flash_rmsnorm", False)
+
+         if self.use_flash_rmsnorm:
+             from flash_attn.ops.rms_norm import rms_norm as rmsnorm_func
+
+             self.rmsnorm_func = rmsnorm_func
+
+     def forward(self, x):
+         if self.use_flash_rmsnorm:
+             return self.rmsnorm_func(x, self.scale, self.eps)
+         else:
+             y = x / (x.norm(2, dim=-1, keepdim=True) * self.hidden_size ** (-1.0 / 2) + self.eps)
+             return self.scale * y
+
+
+ class ParallelGatedMLP(nn.Module):
+     def __init__(
+         self,
+         config,
+     ):
+         super().__init__()
+
+         multiple_of = config.get("inner_size_multiple_of", 64)
+         self.act_type = config.get("mlp_activation", "silu")
+         if self.act_type == "gelu":
+             self.act = F.gelu
+         elif self.act_type == "silu":
+             self.act = F.silu
+         else:
+             raise NotImplementedError
+
+         self.multiple_of = multiple_of * config.model_parallel_size
+
+         inner_size = int(2 * config.hidden_size * 4 / 3)
+         inner_size = self.multiple_of * ((inner_size + self.multiple_of - 1) // self.multiple_of)
+         if config.get("inner_mlp_size", None) is not None:
+             inner_size = config.inner_mlp_size
+
+         self.l1 = nn.Linear(
+             in_features=config.hidden_size,
+             out_features=inner_size,
+             bias=False,
+         )
+         self.l2 = nn.Linear(
+             in_features=config.hidden_size,
+             out_features=inner_size,
+             bias=False,
+         )
+         self.l3 = nn.Linear(
+             in_features=inner_size,
+             out_features=config.hidden_size,
+             bias=False,
+         )
+
+     def forward(self, z):
+         z1, z2 = self.l1(z), self.l2(z)
+         z1, z2 = grab_first_if_tuple(z1), grab_first_if_tuple(z2)
+         y = self.l3(self.act(z1) * z2)
+         return grab_first_if_tuple(y)
+
+
+ class Embedding(nn.Module):
+     _train_dtype = "bf16"
+
+     def __init__(self, config):
+         super().__init__()
+         self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
+
+     def embed(self, input_ids, position_ids=None, tokentype_ids=None):
+         embeddings = self.word_embeddings(input_ids)
+         return embeddings
+
+     def unembed(self, u):
+         weight = self.word_embeddings.weight
+         return torch.matmul(u, weight)
+
+
+ class VocabParallelEmbedding(nn.Embedding):
+     "Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/embedding.py"
+
+     def __init__(self, config):
+         vocab_size, process_group, padding_idx = (
+             config.vocab_size,
+             config.get("process_group", None),
+             config.get("padding_idx", None),
+         )
+         self.process_group = process_group
+         if process_group is not None:
+             world_size = torch.distributed.get_world_size(process_group)
+             if vocab_size % world_size != 0:
+                 raise ValueError(f"vocab_size ({vocab_size}) must be divisible by " f"world_size ({world_size})")
+             if world_size > 1 and padding_idx is not None:
+                 raise RuntimeError("ParallelEmbedding does not support padding_idx")
+         else:
+             world_size = 1
+         super().__init__(
+             vocab_size // world_size,
+             embedding_dim=config.hidden_size,
+             padding_idx=padding_idx,
+         )
+
+     def embed(self, input: Tensor) -> Tensor:
+         if self.process_group is None:
+             return self.forward(input)
+         else:
+             rank = torch.distributed.get_rank(self.process_group)
+             vocab_size = self.num_embeddings
+             vocab_start_index, vocab_end_index = (
+                 rank * vocab_size,
+                 (rank + 1) * vocab_size,
+             )
+             # Create a mask of valid vocab ids (1 means it needs to be masked).
+             input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)
+             input = input - vocab_start_index
+             input[input_ids_mask] = 0
+             embeddings = self.forward(input)
+             embeddings[input_ids_mask] = 0.0
+             # Reduce to the global process group
+             torch.distributed.all_reduce(embeddings, group=self.process_group)
+             return embeddings
+
+     def unembed(self, u: Tensor) -> Tensor:
+         if self.process_group is None:
+             return u @ self.weight.T
+         else:
+             raise NotImplementedError
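
The fallback branch of RMSNorm.forward above divides by the L2 norm scaled by hidden_size ** -0.5, which (up to where eps is added) is the usual x / sqrt(mean(x**2)) RMS normalization. A quick numerical check of that identity, with eps dropped so the two forms match exactly:

import torch

d = 4096
x = torch.randn(2, 8, d)

# Branch from RMSNorm.forward above, eps omitted
y1 = x / (x.norm(2, dim=-1, keepdim=True) * d ** (-1.0 / 2))

# Textbook RMS normalization
y2 = x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True))

assert torch.allclose(y1, y2, atol=1e-5)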
stripedhyena/model.py ADDED
@@ -0,0 +1,445 @@
+ # Copyright (c) Together
+ # This software is distributed under the terms of the Apache License, Version 2.0
+ # Author: Michael Poli
+ # Note: MP and PP utilities are removed for ease of use and editing.
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from stripedhyena.cache import InferenceParams, RecurrentInferenceParams
+ from stripedhyena.engine import HyenaInferenceEngine
+ from stripedhyena.layers import ParallelGatedMLP, RMSNorm, VocabParallelEmbedding
+ from stripedhyena.utils import column_split, print_rank_0
+
+ try:
+     from flash_attn.modules.mha import MHA
+ except ImportError:
+     pass  # flash_attn not installed
+
+ try:
+     from stripedhyena.positional_embeddings import swap_mha_rope
+ except ImportError:
+     pass  # could not import swap_mha_rope from stripedhyena.positional_embeddings
+
+
+ class AttentionBlock(nn.Module):
+     def __init__(self, config, layer_idx) -> None:
+         super().__init__()
+         self.config = config
+         self.pre_norm, self.post_norm = RMSNorm(config), RMSNorm(config)
+         self.layer_idx = layer_idx
+         self.proj_groups = config.get("proj_groups", 1)
+         dtype = config.get("attn_block_dtype", torch.bfloat16)
+         mlp_dtype = config.get("mlp_dtype", torch.bfloat16)
+         self.num_attention_heads = config.num_attention_heads
+         self.hidden_size_per_attention_head = config.hidden_size // config.num_attention_heads
+
+         self.counter = 0
+         self.inner_mha_cls = MHA(
+             embed_dim=config.hidden_size,
+             num_heads=config.num_attention_heads,
+             num_heads_kv=config.num_attention_heads // self.proj_groups,
+             rotary_emb_dim=config.hidden_size // config.num_attention_heads,
+             qkv_proj_bias=config.get("qkv_proj_bias", True),
+             rotary_emb_base=config.get("rotary_emb_base", 10000),
+             causal=True,
+             layer_idx=layer_idx,
+             out_proj_bias=config.get("mha_out_proj_bias", True),
+             use_flash_attn=self.config.use_flash_attn,
+         ).to(dtype=dtype)
+
+         # check if using interpolated rotary pos emb from config, and swap the rope emb
+         if config.get("use_interpolated_rotary_pos_emb", False):
+             swap_mha_rope(
+                 mha=self.inner_mha_cls,
+                 kwargs_new_rope={"scaling_factor": config.get("rotary_emb_scaling_factor", 1.0)},
+             )
+
+         if self.config.get("smeared_gqa", False):
+             self.inner_mha_cls.num_heads_kv = self.inner_mha_cls.num_heads
+         self.inner_mha_cls.rotary_emb.register_buffer("inv_freq", self.inner_mha_cls.rotary_emb.inv_freq)
+
+         self.mlp = ParallelGatedMLP(config).to(dtype=mlp_dtype)
+
+     def forward(self, u, inference_params=None, padding_mask=None, *args, **kwargs):
+         if (
+             type(padding_mask) == torch.Tensor
+         ):  # workaround for masking bug in FA. This works because Wqkv does not have bias
+             # and attention scores will be also automatically zeroed.
+             u = u * padding_mask[..., None]
+
+         u = (
+             self.inner_mha_cls(
+                 self.pre_norm(u),
+                 inference_params=inference_params,
+             )
+             + u
+         )
+         if type(padding_mask) == torch.Tensor:  # guard against bias
+             u = u * padding_mask[..., None]
+         u = self.mlp(self.post_norm(u)) + u
+         return u, None
+
+
+ class ParallelHyenaFilter(nn.Module):
+     def __init__(self, config, layer_idx) -> None:
+         super().__init__()
+         self.config = config
+         self.layer_idx = layer_idx
+         self.hyena_filter_groups = config.get("hyena_filter_groups", self.config.hidden_size)
+
+         self.use_flashfft = config.get("use_flashfft", False)
+         self.state_size = config.state_size
+         self.hidden_size = config.hidden_size
+         self.num_filters = config.num_filters
+         self.inference_mode = config.get("inference_mode", True)
+         self.counter = 0
+         self.column_split_hyena = config.get("column_split_hyena", True)
+
+         assert self.hidden_size % self.num_filters == 0 and self.num_filters <= self.hidden_size
+
+         self.D = nn.Parameter(torch.zeros(self.hidden_size))
+
+         # attention heads are not used except to split post short_filter
+         # projections in the same way as the checkpoint
+         self.num_attention_heads = config.num_attention_heads
+         self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads
+
+         # after preprocessing here we can save the new checkpoint
+         self.short_filter_length = config.short_filter_length
+         self.short_filter_weight = nn.Parameter(torch.randn(3 * config.hidden_size, 1, config.short_filter_length))
+         self.short_filter_bias = (
+             nn.Parameter(torch.randn(3 * config.hidden_size)) if config.short_filter_bias else None
+         )
+
+         self.engine = HyenaInferenceEngine(layer_idx=layer_idx)
+         self.use_flash_depthwise = config.get("use_flash_depthwise", False)
+         self.data_dtype = None
+
+         if self.use_flash_depthwise:
+             try:
+                 from flashfftconv import FlashDepthwiseConv1d
+
+                 self.fir_fn = FlashDepthwiseConv1d(
+                     channels=3 * self.hidden_size,
+                     kernel_size=self.short_filter_length,
+                     padding=self.short_filter_length - 1,
+                     weights=self.short_filter_weight,
+                     bias=self.short_filter_bias,
+                     device=None,
+                     dtype=self.config.get("depthwise_dtype", torch.bfloat16),
+                 )
+             except ImportError:
+                 # flashfftconv not installed; fall back to torch's depthwise conv1d
+                 self.fir_fn = F.conv1d
+         else:
+             self.fir_fn = F.conv1d
+
+         self.fftconv_fn = None
+         self.long_fir_threshold = config.get("long_fir_threshold", None)
+         if self.long_fir_threshold is not None:
+             assert self.use_flashfft is False, "long_fir_threshold not compatible with fused flashfft"
+
+         self.num_systems = self.hidden_size // self.hyena_filter_groups
+
+         poles = torch.randn(self.num_systems, self.state_size, 1, 2)
+
+         # TODO: bring over init from internals
+         poles[..., 0] = 1e-2 * torch.randn(self.num_systems, self.state_size, 1)
+         poles[..., 1] = 1e-3 * torch.randn(self.num_systems, self.state_size, 1)
+
+         self.poles = nn.Parameter(poles)
+
+         self.residues = nn.Parameter(torch.randn(self.num_systems, self.state_size, 1, 2))
+         self.h = None
+
+     def forward(self, u, inference_params=None, padding_mask=None, *args, **kwargs):
+         if inference_params is not None and self.layer_idx in inference_params.fir_state_dict.keys():
+             return self.sequential_forward(u, inference_params)
+         else:
+             return self.parallel_forward(u, inference_params, padding_mask)
+
+     def parallel_forward(self, u, inference_params=None, padding_mask=None):
+         L = u.shape[1]
+         z_pre, fir_state = self.engine.parallel_fir(
+             self.fir_fn,
+             u,
+             self.short_filter_weight,
+             self.short_filter_bias,
+             L,
+             fir_length=self.short_filter_length,
+             inference_params=inference_params,
+             padding_mask=padding_mask,
+         )
+         if inference_params:
+             inference_params.fir_state_dict[self.layer_idx] = fir_state
+
+         if self.h is None:
+             h, filter_dtype, poles, residues = self.compute_filter(L, u.device)
+         else:
+             h = self.h
+             filter_dtype = self.h.dtype
+
+         if self.hyena_filter_groups > 1:
+             h = h.repeat_interleave(self.hidden_size // self.hyena_filter_groups, 1)
+
+         # if inference_params is not None, we plan to perform generation:
+         # prefilling is handled by the engine.
+         dims = (
+             self.hidden_size,
+             self.num_attention_heads,
+             self.hidden_size_per_attention_head,
+             self.state_size,
+             self.hyena_filter_groups,
+         )
+         y = self.engine.parallel_iir(
+             z_pre,
+             h,
+             self.D,
+             L,
+             t=self.t,
+             poles=self.poles,
+             residues=self.residues,
+             dims=dims,
+             inference_params=inference_params,
+             layer_idx=self.layer_idx,
+             prefill_style=self.config.get("prefill_style", "fft"),
+             use_flashfft=self.use_flashfft,
+             fftconv_fn=self.fftconv_fn,
+             column_split_hyena=self.column_split_hyena,
+             long_fir_threshold=self.long_fir_threshold,
+             padding_mask=padding_mask,
+         )
+
+         return y, inference_params
+
+     def sequential_forward(self, u, inference_params):
+         if self.data_dtype is None:
+             self.data_dtype = u.dtype
+         if len(u.shape) > 2:
+             u = u[:, -1]
+
+         fir_state, iir_state = (
+             inference_params.fir_state_dict[self.layer_idx],
+             inference_params.state_dict[self.layer_idx],
+         )
+
+         z_pre, fir_state = self.engine.step_fir(
+             u, fir_state, weight=self.short_filter_weight, bias=self.short_filter_bias
+         )
+         x2, x1, v = (
+             column_split(z_pre, self.num_attention_heads, self.hidden_size_per_attention_head)
+             if self.column_split_hyena
+             else z_pre.split([self.hidden_size, self.hidden_size, self.hidden_size], dim=1)
+         )
+
+         y, iir_state = self.engine.step_iir(
+             x2,
+             x1,
+             v,
+             self.D,
+             self.residues,
+             self.poles,
+             iir_state,
+             iir_groups=self.hyena_filter_groups,
+         )
+
+         inference_params.fir_state_dict[self.layer_idx] = fir_state
+         inference_params.state_dict[self.layer_idx] = iir_state
+         y = y.to(dtype=self.data_dtype)
+         return y[:, None], inference_params
+
+     def update_time(self, L, device):
+         """
+         Set [0, 1, ..., L-1] where L is the length of the current batch of inputs.
+         If L is greater than the length of the previous batch, then the time vector is
+         reinitialized. Otherwise, the time vector is truncated from cache.
+         """
+         if not hasattr(self, "t"):
+             self.t = torch.arange(L, device=device)[None, None]
+         elif self.t.shape[-1] < L:
+             self.t = torch.arange(L, device=device)[None, None]
+         else:
+             self.t = self.t[..., :L]
+
+     def compute_filter(self, L, device):
+         self.update_time(L, device)
+         filter_dtype = torch.float32
+         residues, log_poles = (
+             torch.view_as_complex(self.residues.to(filter_dtype)),
+             torch.view_as_complex(self.poles.to(filter_dtype)).log(),
+         )
+         h = (residues * (log_poles * self.t).exp()).real.sum(1)[None]
+         return h, filter_dtype, log_poles, residues
+
+
+ class ParallelGatedConvBlock(nn.Module):
+     def __init__(self, config, layer_idx) -> None:
+         super().__init__()
+         self.config = config
+         self.layer_idx = layer_idx
+         self.low_mem_mode = config.get("low_mem_mode", False)
+         dtype = config.get("hyena_block_dtype", torch.float32)
+         mlp_dtype = config.get("mlp_dtype", torch.bfloat16)
+         self.pre_norm, self.post_norm = RMSNorm(config).to(dtype=dtype), RMSNorm(config).to(dtype=dtype)
+         self.filter = ParallelHyenaFilter(config, layer_idx).to(dtype=dtype)
+         self.projections = nn.Linear(config.hidden_size, 3 * config.hidden_size)
+         self.out_filter_dense = nn.Linear(config.hidden_size, config.hidden_size).to(dtype)
+         self.mlp = ParallelGatedMLP(config).to(dtype=mlp_dtype)
+
+         self.proj_norm_fn = self.proj_norm
+         self.res_mlp_norm_fn = self.res_mlp_norm
+
+         if self.config.get("compile", False):
+             self.proj_norm_fn = torch.compile(self.proj_norm, fullgraph=True, dynamic=False, mode="reduce-overhead")
+             self.res_mlp_norm_fn = torch.compile(
+                 self.res_mlp_norm, fullgraph=True, dynamic=False, mode="reduce-overhead"
+             )
+
+     def proj_norm(self, x):
+         return self.projections(self.pre_norm(x))
+
+     def res_mlp_norm(self, x):
+         return self.mlp(self.post_norm(x)) + x
+
+     def forward(self, u, inference_params=None, padding_mask=None, *args, **kwargs):
+         z = self.proj_norm_fn(u)
+
+         if type(padding_mask) == torch.Tensor:  # guard against bias
+             z = z * padding_mask[..., None]
+
+         z, inference_params = self.filter(z, inference_params=inference_params, padding_mask=padding_mask)
+
+         z_in = self.out_filter_dense(z) + u
+
+         if type(padding_mask) == torch.Tensor:  # guard against bias
+             z_in = z_in * padding_mask[..., None]
+
+         y = self.res_mlp_norm_fn(z_in)
+
+         return y, inference_params
+
+
+ def get_block(config, layer_idx, flash_fft=None):
+     if layer_idx in config.attn_layer_idxs:
+         return AttentionBlock(config, layer_idx)
+     elif layer_idx in config.hyena_layer_idxs:
+         block = ParallelGatedConvBlock(config, layer_idx)
+         if config.get("use_flashfft", False):
+             block.filter.fftconv_fn = flash_fft
+         return block
+     else:
+         raise NotImplementedError
+
+
+ class StripedHyena(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.embedding_layer = VocabParallelEmbedding(config)
+         self.norm = RMSNorm(config) if config.get("final_norm", True) else None
+         self.unembed = self.embedding_layer if config.tie_embeddings else VocabParallelEmbedding(config)
+
+         if config.get("use_flashfft", True):
+             try:
+                 from flashfftconv import FlashFFTConv
+
+                 self.flash_fft = FlashFFTConv(config.seqlen, dtype=torch.bfloat16)
+             except ImportError:
+                 # flashfftconv not installed
+                 self.flash_fft = None
+         else:
+             self.flash_fft = None
+
+         self.blocks = nn.ModuleList(
+             get_block(config, layer_idx, flash_fft=self.flash_fft) for layer_idx in range(config.num_layers)
+         )
+
+     def forward(self, x, inference_params_dict=None, padding_mask=None):
+         L = x.shape[1]
+         x = self.embedding_layer.embed(x)
+         if inference_params_dict is not None:
+             x, inference_params_dict_out = self.stateful_forward(
+                 x,
+                 inference_params_dict=inference_params_dict,
+             )
+         else:
+             x, inference_params_dict_out = self.stateless_forward(x, padding_mask=padding_mask)
+
+         x = self.norm(x)
+         x = self.unembed.unembed(x)
+         return x, inference_params_dict_out
+
+     def stateful_forward(self, x, inference_params_dict=None):
+         for block_idx, block in enumerate(self.blocks):
+             block_name = "mha" if block_idx in self.config.attn_layer_idxs else "hyena"
+             inference_params = inference_params_dict[block_name]
+             x, _ = block(x, inference_params=inference_params)
+
+         return x, inference_params_dict
+
+     def stateless_forward(self, x, padding_mask=None):
+         if type(padding_mask) == torch.Tensor:
+             x = x * padding_mask[..., None]
+
+         for _, block in enumerate(self.blocks):
+             x, _ = block(x, inference_params=None, padding_mask=padding_mask)
+         return x, None
+
+     def initialize_inference_params(self):
+         inference_params_dict = {
+             "mha": InferenceParams(
+                 max_seqlen=self.config.get("max_seqlen", 8192),
+                 max_batch_size=self.config.get("max_batch_size", 1),
+                 seqlen_offset=0,
+             ),
+             "hyena": RecurrentInferenceParams(
+                 fir_filter_length=self.config.short_filter_length,
+                 state_dim=self.config.state_size,
+                 seqlen_offset=0,
+             ),
+         }
+         return inference_params_dict
+
+     def precompute_filters(self, L, device):
+         for block_idx, block in enumerate(self.blocks):
+             if type(block) == ParallelGatedConvBlock:
+                 if type(block.filter) == ParallelHyenaFilter:
+                     L = block.filter.long_fir_threshold or L
+                     print_rank_0(f"Precomputing filters, L={L}...")
+
+                     filter_dtype = torch.float16 if L >= 2048 else torch.float32
+
+                     block.filter.update_time(L, device)
+                     residues, poles = (
+                         torch.view_as_complex(block.filter.residues.to(torch.float16)),
+                         torch.view_as_complex(block.filter.poles.to(torch.float16)),
+                     )
+
+                     block.filter.h = (residues * poles**block.filter.t).real.sum(1)[None]
+                     block.filter.h = block.filter.h.to(dtype=filter_dtype)
+
+     def load_poles_residues(self, path):
+         "Load different poles and residues for each layer."
+         for block_idx, block in enumerate(self.blocks):
+             if type(block) == ParallelGatedConvBlock:
+                 if type(block.filter) == ParallelHyenaFilter:
+                     print(f"Loading poles and residues for block {block_idx}")
+                     poles = torch.load(path + f"/approx_poles_{block_idx+1}.pt", map_location="cpu")
+                     poles = torch.view_as_real(poles)
+                     residues = torch.load(path + f"/approx_residues_{block_idx+1}.pt", map_location="cpu")
+                     residues = torch.view_as_real(residues)
+                     poles = poles.permute(1, 0, 2).unsqueeze(-2)
+                     residues = residues.permute(1, 0, 2).unsqueeze(-2)
+
+                     block.filter.poles = nn.Parameter(poles)
+                     block.filter.residues = nn.Parameter(residues)
+
+     def to_bfloat16_except_poles_residues(self):
+         """Convert all parameters to bfloat16 except for the poles and residues.
+
+         Particularly important for longer prompts.
+         """
+         for k, p in self.named_parameters():
+             if "poles" not in k and "residues" not in k:
+                 p.data = p.data.to(torch.bfloat16)
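
A minimal end-to-end sketch: building a tiny all-Hyena StripedHyena from a toy config and running one stateless forward pass. The Config class here is a stand-in for the repo's config object (which supports both attribute access and .get); the sizes are toy values, not the evo-1-8k settings:

import torch
from stripedhyena.model import StripedHyena

class Config(dict):
    __getattr__ = dict.__getitem__  # attribute access on top of dict's .get

config = Config(
    vocab_size=512, hidden_size=64, num_filters=64, num_layers=2,
    attn_layer_idxs=[], hyena_layer_idxs=[0, 1], num_attention_heads=4,
    short_filter_length=3, short_filter_bias=True, state_size=8,
    eps=1e-5, params_dtype=torch.float32, mlp_dtype=torch.float32,
    model_parallel_size=1, tie_embeddings=True,
    use_flashfft=False, use_flash_attn=False, use_flash_rmsnorm=False,
)

model = StripedHyena(config)
x = torch.randint(0, 512, (1, 16))  # batch of one 16-token sequence
logits, _ = model(x)                # stateless (non-cached) forward
print(logits.shape)                 # torch.Size([1, 16, 512])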
stripedhyena/positional_embeddings.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Armin Thomas, Jan 2023. Modified by Eric Nguyen.
3
+
4
+ Wrappers for linearly interpolated rope embeddings to use inside of MHA layers of Flash Attn.
5
+
6
+ """
7
+
8
+ import copy
9
+
10
+ import torch
11
+ from einops import rearrange
12
+ from flash_attn.layers.rotary import RotaryEmbedding
13
+ from flash_attn.modules.mha import MHA
14
+
15
+
16
+ # simple wrapper for flash-attn RoPE with linear scaling:
17
+ class LinearlyScaledRotaryEmbedding(RotaryEmbedding):
18
+ def __init__(
19
+ self,
20
+ dim: int,
21
+ scaling_factor: float = 1.0,
22
+ base=10000.0,
23
+ interleaved=False,
24
+ scale_base=None,
25
+ pos_idx_in_fp32=True,
26
+ device=None,
27
+ ):
28
+ super().__init__(
29
+ dim=dim,
30
+ base=base,
31
+ interleaved=interleaved,
32
+ scale_base=scale_base,
33
+ pos_idx_in_fp32=pos_idx_in_fp32,
34
+ device=device,
35
+ )
36
+ self._linear_scaling_factor = scaling_factor
37
+
38
+ # adpated from: https://github.com/Dao-AILab/flash-attention/blob/43ceab630bc6c27712428da5a33fc9cb5c369d91/flash_attn/layers/rotary.py#L368
39
+ def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
40
+ # Reset the tables if the sequence length has changed,
41
+ # if we're on a new device (possibly due to tracing for instance),
42
+ # or if we're switching from inference mode to training
43
+ if (
44
+ seqlen > self._seq_len_cached
45
+ or self._cos_cached is None
46
+ or self._cos_cached.device != device
47
+ or self._cos_cached.dtype != dtype
48
+ or (self.training and self._cos_cached.is_inference())
49
+ ):
50
+ self._seq_len_cached = seqlen
51
+ # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
52
+ # And the output of arange can be quite large, so bf16 would lose a lot of precision.
53
+ # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
54
+ if self.pos_idx_in_fp32:
55
+ t = torch.arange(seqlen, device=device, dtype=torch.float32)
56
+ # linear scaling:
57
+ t = t / self._linear_scaling_factor
58
+ # We want fp32 here as well since inv_freq will be multiplied with t, and the output
59
+ # will be large. Having it in bf16 will lose a lot of precision and cause the
60
+ # cos & sin output to change significantly.
61
+ # We want to recompute self.inv_freq if it was not loaded in fp32
62
+ if self.inv_freq.dtype != torch.float32:
63
+ inv_freq = self._compute_inv_freq(device=device)
64
+ else:
65
+ inv_freq = self.inv_freq
66
+ else:
67
+ t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
68
+ # linear scaling:
69
+ t = t / self._linear_scaling_factor
70
+ inv_freq = self.inv_freq
71
+ # Don't do einsum, it converts fp32 to fp16 under AMP
72
+ # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
73
+ freqs = torch.outer(t, inv_freq)
74
+ if self.scale is None:
75
+ self._cos_cached = torch.cos(freqs).to(dtype)
76
+ self._sin_cached = torch.sin(freqs).to(dtype)
77
+ else:
78
+ power = (
79
+ torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
80
+ ) / self.scale_base
81
+ scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
82
+ # We want the multiplication by scale to happen in fp32
83
+ self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
84
+ self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
85
+ self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
86
+ self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
87
+
88
+
89
+ # swap out RoPE of existing mha:
90
+ def swap_mha_rope(mha, new_rope: torch.nn.Module = LinearlyScaledRotaryEmbedding, kwargs_new_rope: dict = None):
91
+ # determine mha dtype and device:
92
+ dtype = mha.Wq.weight.dtype if mha.cross_attn else mha.Wqkv.weight.dtype
93
+ device = mha.Wq.weight.device if mha.cross_attn else mha.Wqkv.weight.device
94
+ # determine RoPE settings:
95
+ kwargs_old_rope = dict(
96
+ dim=mha.rotary_emb.dim,
97
+ base=mha.rotary_emb.base,
98
+ interleaved=mha.rotary_emb.interleaved,
99
+ scale_base=mha.rotary_emb.scale_base,
100
+ pos_idx_in_fp32=mha.rotary_emb.pos_idx_in_fp32,
101
+ device=mha.rotary_emb.inv_freq.device,
102
+ )
103
+ # delete old RoPE:
104
+ del mha.rotary_emb
105
+ # create new RoPE:
106
+ kwargs_new_rope = kwargs_new_rope or {"scaling_factor": 1.0}
107
+ scaled_rope = new_rope(**kwargs_new_rope, **kwargs_old_rope).to(dtype)
108
+ # attach new RoPE to mha:
109
+ mha.rotary_emb = scaled_rope
110
+ # make sure the new RoPE is correctly registered:
111
+ assert isinstance(mha.rotary_emb, new_rope)
112
+ return mha
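A hedged usage sketch for the helper above, assuming flash-attn is installed; the layer sizes and the 4x factor are illustrative placeholders, and the exact MHA constructor arguments may differ across flash-attn versions:

# Build a rotary-equipped MHA block, then swap its RoPE for the linearly scaled variant.
mha = MHA(embed_dim=1024, num_heads=16, rotary_emb_dim=64)
mha = swap_mha_rope(mha, kwargs_new_rope={"scaling_factor": 4.0})
assert isinstance(mha.rotary_emb, LinearlyScaledRotaryEmbedding)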
stripedhyena/sample.py ADDED
@@ -0,0 +1,59 @@
1
+ import torch
2
+
3
+
4
+ # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
5
+ # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
6
+ def modify_logits_for_top_k_filtering(logits, top_k):
7
+ """Set the logits for none top-k values to -inf. Done in-place."""
8
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
9
+ logits.masked_fill_(indices_to_remove, float("-Inf"))
10
+
11
+
12
+ # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
13
+ # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
14
+ def modify_logits_for_top_p_filtering(logits, top_p):
15
+ """Set the logits for none top-p values to -inf. Done in-place."""
16
+ if top_p <= 0.0 or top_p >= 1.0:
17
+ return
18
+
19
+ # First sort and calculate cumulative sum of probabilities.
20
+ sorted_logits, sorted_indices = torch.sort(logits, descending=False)
21
+ cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
22
+ # Remove the lowest-probability tokens whose cumulative mass falls below 1 - top_p (the top token is always kept)
23
+ sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
24
+ # scatter sorted tensors to original indexing
25
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
26
+ logits.masked_fill_(indices_to_remove, float("-inf"))
27
+
28
+
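As a quick sanity check of the in-place filter (the logit values are chosen for illustration):

# With top_p=0.9, the lowest-probability tokens whose combined mass stays under
# 1 - 0.9 = 0.1 are set to -inf; here only the least likely token is dropped.
logits = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
modify_logits_for_top_p_filtering(logits, top_p=0.9)
# logits[0, 0] is now -inf, the other three entries are untouched.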
29
+ # https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py
30
+ def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
31
+ """Sample from top-k logits.
32
+ Arguments:
33
+ logits: Tensor of shape (batch_size, vocab_size)
34
+ """
35
+ logits = torch.nan_to_num(logits)
36
+ logits = torch.where(logits == float("-inf"), 0, logits)
37
+ logits = torch.where(logits == float("inf"), 0, logits)
38
+
39
+ if top_k == 1: # Short-circuit for greedy decoding
40
+ return logits.argmax(dim=-1)
41
+ else:
42
+ if top_p > 0.0:
43
+ assert top_p <= 1.0, "top-p should be in (0, 1]."
44
+ if top_k > 0:
45
+ top_k = min(top_k, logits.size(-1)) # Safety check
46
+ logits_top, indices = torch.topk(logits, top_k, dim=-1)
47
+ if temperature != 1.0:
48
+ logits_top /= temperature
49
+ modify_logits_for_top_p_filtering(logits_top, top_p)
50
+
51
+ return indices[
52
+ torch.arange(indices.shape[0], device=indices.device),
53
+ torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
54
+ ]
55
+ else:
56
+ # Clone so that when we modify for top_p we don't change the original logits
57
+ logits_top = logits / temperature if temperature != 1.0 else logits.clone()
58
+ modify_logits_for_top_p_filtering(logits_top, top_p)
59
+ return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
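A short usage sketch for `sample`; the random logits below simply stand in for real model output:

torch.manual_seed(0)
logits = torch.randn(2, 32000)                    # (batch_size, vocab_size)
greedy = sample(logits)                           # top_k=1 short-circuits to argmax
nucleus = sample(logits, top_k=50, top_p=0.95, temperature=0.8)
print(greedy.shape, nucleus.shape)                # torch.Size([2]) torch.Size([2])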
stripedhyena/tokenizer.py ADDED
@@ -0,0 +1,184 @@
1
+ # based on https://github.com/EleutherAI/gpt-neox/blob/main/megatron/tokenizer/tokenizer.py
2
+ import json
3
+ import pathlib
4
+ from abc import ABC, abstractmethod
5
+ from typing import List, Union
6
+
7
+ import numpy as np
8
+ import torch
9
+ import tqdm
10
+ from tokenizers import Tokenizer
11
+
12
+
13
+ class HFAutoTokenizer:
14
+ def __init__(self, vocab_file):
15
+ self.tokenizer = Tokenizer.from_file(vocab_file)
16
+ self.eos = "</s>"
17
+ self.bos = "<s>"
18
+ self.eos_id = self.tokenize(self.eos)
19
+ self.bos_id = self.tokenize(self.bos)
20
+ self.vsize = 32000
21
+
22
+ def encode_to_list(self, text):
23
+ return self.tokenizer.encode(text, add_special_tokens=False)
24
+
25
+ def tokenize_file(self, input_file, output_file, verbose=False):
26
+ if verbose:
27
+ print(f"Tokenizing file: {input_file}")
28
+
29
+ if pathlib.Path(output_file).exists():
30
+ print(f"Output file {output_file} already exists, skipping")
31
+ return
32
+ with open(input_file, "r") as fin, open(output_file, "w") as fout:
33
+ for line in tqdm.tqdm(fin):
34
+ if verbose:
35
+ print(f"Tokenizing line: {line[-200:]}")
36
+ data = json.loads(line.strip())
37
+ if "text" not in data.keys():
38
+ break
39
+ tokenized_data = self.tokenize(data["text"])
40
+ fout.write(json.dumps({"tokens": tokenized_data}) + "\n")
41
+
42
+ def tokenize(self, text: str, *args, **kwargs):
43
+ ids = self.tokenizer.encode(text)
44
+ if type(ids) == list:
45
+ return torch.tensor(ids)
46
+ else:
47
+ return torch.tensor(ids.ids)
48
+
49
+ def tokenize_batch(self, text_batch):
50
+ return self.tokenizer.encode_batch(text_batch)
51
+
52
+ def detokenize(self, token_ids, skip_special_tokens=False):
53
+ return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
54
+
55
+ def detokenize_batch(self, token_ids_batch, skip_special_tokens=False):
56
+ out = []
57
+ for token_ids in token_ids_batch:
58
+ out.append(
59
+ self.detokenize(
60
+ [t.item() for t in token_ids],
61
+ skip_special_tokens=skip_special_tokens,
62
+ )
63
+ )
64
+ return out
65
+
66
+ @property
67
+ def eod(self):
68
+ return self.eod_id
69
+
70
+ @property
71
+ def vocab_size(self):
72
+ return 32000
73
+
74
+
75
+ class AbstractTokenizer(ABC):
76
+ """Abstract class for tokenizer."""
77
+
78
+ def __init__(self, name):
79
+ self.name = name
80
+ super().__init__()
81
+
82
+ @property
83
+ @abstractmethod
84
+ def vocab_size(self):
85
+ pass
86
+
87
+ @property
88
+ @abstractmethod
89
+ def vocab(self):
90
+ """Dictionary from vocab text token to id token."""
91
+ pass
92
+
93
+ @property
94
+ @abstractmethod
95
+ def inv_vocab(self):
96
+ """Dictionary from vocab id token to text token."""
97
+ pass
98
+
99
+ @abstractmethod
100
+ def tokenize(self, text):
101
+ pass
102
+
103
+ def detokenize(self, token_ids):
104
+ raise NotImplementedError("detokenizer is not implemented for {} " "tokenizer".format(self.name))
105
+
106
+ @property
107
+ def cls(self):
108
+ raise NotImplementedError("CLS is not provided for {} " "tokenizer".format(self.name))
109
+
110
+ @property
111
+ def sep(self):
112
+ raise NotImplementedError("SEP is not provided for {} " "tokenizer".format(self.name))
113
+
114
+ @property
115
+ def pad(self):
116
+ raise NotImplementedError("PAD is not provided for {} " "tokenizer".format(self.name))
117
+
118
+ @property
119
+ def eod(self):
120
+ raise NotImplementedError("EOD is not provided for {} " "tokenizer".format(self.name))
121
+
122
+ @property
123
+ def mask(self):
124
+ raise NotImplementedError("MASK is not provided for {} " "tokenizer".format(self.name))
125
+
126
+
127
+ class CharLevelTokenizer(AbstractTokenizer):
128
+ """Character Level Tokenizer"""
129
+
130
+ def __init__(self, vocab_size):
131
+ name = "CharLevelTokenizer"
132
+ super().__init__(name)
133
+ self._vocab_size = vocab_size
134
+ self.eod_id = 0
135
+ self.eos_id = 0
136
+ self.pad_id = 1
137
+
138
+ def clamp(self, n):
139
+ return max(32, min(n, self.vocab_size))
140
+
141
+ @property
142
+ def vocab_size(self):
143
+ return self._vocab_size
144
+
145
+ @property
146
+ def vocab(self):
147
+ raise NotImplementedError
148
+
149
+ @property
150
+ def inv_vocab(self):
151
+ raise NotImplementedError
152
+
153
+ def decode_token(self, token: int):
154
+ return str(chr(self.clamp(token)))
155
+
156
+ def tokenize(self, text: str):
157
+ return list(np.frombuffer(text.encode(), dtype=np.uint8))
158
+
159
+ def tokenize_batch(self, text_batch: Union[List[str], str]):
160
+ if isinstance(text_batch, list):
161
+ return [self.tokenize(s) for s in text_batch]
162
+ else:
163
+ return self.tokenize(text_batch)
164
+
165
+ def detokenize(self, token_ids):
166
+ return "".join(list(map(self.decode_token, token_ids)))
167
+
168
+ def detokenize_batch(self, token_ids: Union[List[str], str]):
169
+ if isinstance(token_ids, list):
170
+ return [self.detokenize(s) for s in token_ids]
171
+ # if tensor, convert to list first
172
+ elif isinstance(token_ids, torch.Tensor):
173
+ return [self.detokenize(s) for s in token_ids.tolist()]
174
+ else:
175
+ return self.detokenize(token_ids)
176
+
177
+ @property
178
+ def eod(self):
179
+ return self.eod_id
180
+
181
+ # duplicate to support both names, eos and eod
182
+ @property
183
+ def eos(self):
184
+ return self.eod_id
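For reference, a byte-level round trip through the char tokenizer; the vocab size of 512 is arbitrary:

tok = CharLevelTokenizer(vocab_size=512)
ids = tok.tokenize("ACGT")             # UTF-8 byte values: 65, 67, 71, 84
assert tok.detokenize(ids) == "ACGT"   # detokenize maps the bytes back to characters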
stripedhyena/utils.py ADDED
@@ -0,0 +1,94 @@
1
+ import torch
2
+
3
+
4
+ def grab_first_if_tuple(x):
5
+ if x.__class__.__name__ == "tuple":
6
+ return x[0]
7
+ else:
8
+ return x
9
+
10
+
11
+ def column_split(x, num_heads, head_size):
12
+ """Split a tensor with `num_heads` alongside the head dimension, instead of
13
+ across heads. Fixed to three projections
14
+ """
15
+
16
+ x_reshaped = x.reshape(
17
+ x.shape[0],
18
+ num_heads,
19
+ 3 * head_size,
20
+ )
21
+
22
+ x2, x1, v = (
23
+ x_reshaped[:, :, :head_size],
24
+ x_reshaped[
25
+ :,
26
+ :,
27
+ head_size : 2 * head_size,
28
+ ],
29
+ x_reshaped[:, :, 2 * head_size :],
30
+ )
31
+ x2, x1, v = (
32
+ x2.reshape(x2.shape[0], -1),
33
+ x1.reshape(x1.shape[0], -1),
34
+ v.reshape(v.shape[0], -1),
35
+ )
36
+ return x2, x1, v
37
+
38
+
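A quick shape check for the head-wise split above (the sizes are illustrative):

x = torch.randn(2, 96)                     # 4 heads, 3 fused projections of head_size 8
x2, x1, v = column_split(x, num_heads=4, head_size=8)
print(x2.shape, x1.shape, v.shape)         # each torch.Size([2, 32])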
39
+ def get_init_from_string(init_str):
40
+ if type(init_str) == str:
41
+ if init_str == "torch.nn.init.zeros_":
42
+ return torch.nn.init.zeros_
43
+ elif init_str == "torch.nn.init.xavier_uniform_":
44
+ return torch.nn.init.xavier_uniform_
45
+ elif init_str == "torch.nn.init.xavier_normal_":
46
+ return torch.nn.init.xavier_normal_
47
+ else:
48
+ raise ValueError(f"Unrecognized init {init_str}")
49
+
50
+
51
+ def print_rank_0(message, debug=False, end="\n"):
52
+ """Print from rank 0 only."""
53
+ if torch.distributed.is_initialized():
54
+ if torch.distributed.get_rank() == 0:
55
+ print(message, flush=True, end=end)
56
+ else:
57
+ print(message, flush=True, end=end)
58
+
59
+
60
+ class dotdict(dict):
61
+ """dot.notation access to dictionary attributes"""
62
+
63
+ __getattr__ = dict.get
64
+ __setattr__ = dict.__setitem__
65
+ __delattr__ = dict.__delitem__
66
+
67
+
68
+ def ensure_divisibility(numerator, denominator):
69
+ """Ensure that numerator is divisible by the denominator."""
70
+ assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
71
+
72
+
73
+ def divide(numerator, denominator):
74
+ """Ensure that numerator is divisible by the denominator and return
75
+ the division value."""
76
+ ensure_divisibility(numerator, denominator)
77
+ return numerator // denominator
78
+
79
+
80
+ class VocabUtility:
81
+ """Split the vocabulary into `world_size` chunks amd return the
82
+ first and last index of the vocabulary belonging to the `rank`
83
+ partition. Note that indices are in [first, last)."""
84
+
85
+ @staticmethod
86
+ def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size):
87
+ index_f = rank * per_partition_vocab_size
88
+ index_l = index_f + per_partition_vocab_size
89
+ return index_f, index_l
90
+
91
+ @staticmethod
92
+ def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
93
+ per_partition_vocab_size = divide(global_vocab_size, world_size)
94
+ return VocabUtility.vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size)
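For example, splitting a 32000-token vocabulary across 4 ranks gives each rank a contiguous 8000-index slice:

first, last = VocabUtility.vocab_range_from_global_vocab_size(32000, rank=1, world_size=4)
assert (first, last) == (8000, 16000)      # rank 1 owns indices [8000, 16000)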