yagizdevre commited on
Commit
6ff2080
·
1 Parent(s): 9064a3d

added configs

Browse files
.gitattributes copy ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .configuration_minitransformer import MiniTransformerConfig
2
+ from .modeling_minitransformer import MiniTransformer
added_tokens.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "<|endofprompt|>": 200018
3
+ }
attn.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch.nn.attention.flex_attention import flex_attention, create_block_mask
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from .rotary_emb import apply_rotary_emb
7
+ from .utils import nearest_power_of_two
8
+
9
+ try:
10
+ from flash_attn import flash_attn_func as fa2
11
+ except ImportError as e:
12
+ print(
13
+ f"Unable to import Triton-based flash attention: {e}. No alternative currently available."
14
+ )
15
+ # TODO: Add FlexAttention + local attention mask when it's in stable release
16
+
17
class Attention(nn.Module):
    """Causal multi-head self-attention backed by FlashAttention-2 (`fa2`).

    Combines ALiBi positional slopes, sliding-window attention, and logit
    soft-capping. CUDA-only: the constructor asserts CUDA availability and
    the fa2 kernel has no CPU path.
    """

    def __init__(self, config):
        super(Attention, self).__init__()
        # config.torch_dtype may arrive as a string (e.g. "bfloat16") when the
        # config was deserialized from JSON.
        if isinstance(config.torch_dtype, str):
            torch_dtype = getattr(torch, config.torch_dtype)
        else:
            torch_dtype = config.torch_dtype
        assert torch.cuda.is_available(), "CUDA is required."
        assert config.n_embd % config.n_heads == 0
        self.n_heads = config.n_heads

        self.device = torch.device("cuda")
        self.bsz = config.bsz
        # Fused QKV projection: one matmul produces q, k and v.
        self.attn = nn.Linear(
            config.n_embd, 3 * config.n_embd, bias=config.bias, dtype=torch_dtype
        )
        self.o_proj = nn.Linear(
            config.n_embd, config.n_embd, bias=config.bias, dtype=torch_dtype
        )
        self.o_proj.SCALE_INIT = 1  # marker — presumably read by a weight-init scheme elsewhere; confirm
        self.dropout = config.dropout
        self.resid_dropout = nn.Dropout(self.dropout)
        self.alibi_slopes = self._get_alibi_slopes(self.n_heads)
        self.window_size = config.window_size
        self.softcap = config.softcap

    def _generate_slopes(self, n: int):
        # Geometric ALiBi slope schedule for n heads (closed form; n is
        # expected to be a power of two here).
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        return [start * (start**i) for i in range(n)]

    def _get_alibi_slopes(self, n_heads: int, interpolation_factor: float = 0.25):
        """Return one float32 ALiBi slope per head, on self.device.

        Args:
            n_heads: Number of attention heads.
            interpolation_factor: Uniform down-scaling of all slopes
                (https://arxiv.org/pdf/2310.13017).
        """
        # If n_heads is a power of 2, generate slopes directly
        if math.log2(n_heads).is_integer():
            slopes = self._generate_slopes(n_heads)
        else:
            # Get slopes for the nearest power of two
            n = nearest_power_of_two(n_heads, round_up=False)
            slopes_power_of_two = self._generate_slopes(n)

            # Generate extra slopes for the remaining heads by taking every
            # other slope of the 2n schedule.
            extra_slopes = self._generate_slopes(2 * n)
            extra_slopes_trunc = extra_slopes[0::2][: n_heads - n]
            slopes = slopes_power_of_two + extra_slopes_trunc
        slopes = torch.tensor(slopes, device=self.device)
        slopes = slopes * interpolation_factor  # https://arxiv.org/pdf/2310.13017
        return slopes.to(torch.float32)  # Ensure slopes are in float32

    def forward(self, x):
        """Apply causal, windowed flash attention.

        Args:
            x: Input of shape (bsz, seq_len, n_embd).

        Returns:
            Tensor of the same shape as x.
        """
        bsz, seq_len, d_in = x.size()

        qkv = self.attn(x)
        q, k, v = torch.chunk(qkv, 3, dim=2)

        # fa2 consumes (batch, seq, heads, head_dim) — no transpose needed.
        q = q.view(bsz, seq_len, self.n_heads, d_in // self.n_heads)
        k = k.view(bsz, seq_len, self.n_heads, d_in // self.n_heads)
        v = v.view(bsz, seq_len, self.n_heads, d_in // self.n_heads)
        y = fa2(  # https://arxiv.org/pdf/2307.08691
            q,
            k,
            v,
            dropout_p=self.dropout if self.training else 0.0,
            causal=True,
            # (left, right) window: look back window_size tokens, none ahead.
            window_size=(self.window_size, 0),
            alibi_slopes=self.alibi_slopes,  # https://arxiv.org/pdf/2108.12409
            softcap=self.softcap,  # https://arxiv.org/pdf/2408.00118
        )
        y = y.contiguous().view(bsz, seq_len, d_in)
        y = self.resid_dropout(self.o_proj(y))
        return y
87
+
88
class AttentionSDPA(nn.Module):
    """Causal multi-head self-attention via F.scaled_dot_product_attention.

    Portable counterpart of `Attention`: same fused-QKV projections and
    dropout, but without the FlashAttention-only features (ALiBi slopes,
    sliding window, soft-capping). Runs on any device.
    """

    def __init__(self, config):
        # BUG FIX: was `super(Attention, self).__init__()`, which raises
        # TypeError because AttentionSDPA does not subclass Attention.
        super(AttentionSDPA, self).__init__()
        # config.torch_dtype may be a string (e.g. "bfloat16") from JSON.
        if isinstance(config.torch_dtype, str):
            torch_dtype = getattr(torch, config.torch_dtype)
        else:
            torch_dtype = config.torch_dtype
        assert config.n_embd % config.n_heads == 0
        self.n_heads = config.n_heads

        # SDPA does not require CUDA (the original comment acknowledged this),
        # so fall back to CPU instead of hard-asserting CUDA availability.
        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        self.bsz = config.bsz
        # Fused QKV projection.
        self.attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias, dtype=torch_dtype)
        self.o_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias, dtype=torch_dtype)
        self.dropout = config.dropout
        self.resid_dropout = nn.Dropout(self.dropout)

    def forward(self, x):
        """Apply causal SDPA attention.

        Args:
            x: Input of shape (bsz, seq_len, n_embd).

        Returns:
            Tensor of the same shape as x.
        """
        bsz, seq_len, d_in = x.size()

        qkv = self.attn(x)
        q, k, v = torch.chunk(qkv, 3, dim=2)

        # Split heads: (B, L, D) -> (B, H, L, D/H), the layout SDPA expects.
        q = q.view(bsz, seq_len, self.n_heads, d_in // self.n_heads).transpose(1, 2)
        k = k.view(bsz, seq_len, self.n_heads, d_in // self.n_heads).transpose(1, 2)
        v = v.view(bsz, seq_len, self.n_heads, d_in // self.n_heads).transpose(1, 2)

        y = F.scaled_dot_product_attention(
            q, k, v,
            is_causal=True,
            dropout_p=self.dropout if self.training else 0.0
        )

        # Merge heads back: (B, H, L, D/H) -> (B, L, D).
        y = y.transpose(1, 2).contiguous().view(bsz, seq_len, d_in)

        y = self.resid_dropout(self.o_proj(y))
        return y
126
+
127
+
128
class FlexAttention(nn.Module):
    """
    Generalized multi-head attention built on torch FlexAttention.

    Supports arbitrary mask_mods (causal, sliding-window, ...) and optional
    score_mods, plus rotary positional embeddings. The block mask is
    precomputed once for config.seq_len, so inputs are expected to use that
    sequence length.
    """
    def __init__(self, config, mask_mod, score_mod=None):
        """
        Initializes the FlexAttention module.

        Args:
            config: Config providing dim, num_heads, seq_len, device.
            mask_mod (Callable): Mask to modify attention scores, e.g. causal.
            score_mod (Callable, optional): Score modifier, e.g. ALiBi bias.
        """
        super().__init__()
        self.dim, self.num_heads = config.dim, config.num_heads
        assert config.dim % config.num_heads == 0, f"dim ({self.dim}) must be divisible num_heads ({self.num_heads})"
        self.head_dim = config.dim // config.num_heads

        self.wq = nn.Linear(config.dim, config.dim)
        self.wk = nn.Linear(config.dim, config.dim)
        self.wv = nn.Linear(config.dim, config.dim)

        self.mask_mod = mask_mod
        self.score_mod = score_mod
        # Build the block mask once; B/H are None so it broadcasts over
        # batch and heads.
        self.block_mask = create_block_mask(
            mask_mod=self.mask_mod,
            B=None,  # Broadcast
            H=None,  # Broadcast
            Q_LEN=config.seq_len,
            KV_LEN=config.seq_len,
            device=config.device,
        )

        self.o_proj = nn.Linear(config.dim, config.dim)
        self.o_proj.SCALE_INIT = 1  # marker — presumably read by a weight-init scheme elsewhere; confirm

    def forward(
        self,
        x: torch.Tensor = None,
        q: torch.Tensor = None,
        k: torch.Tensor = None,
        v: torch.Tensor = None,
        freqs_cis: torch.Tensor = None,
    ) -> torch.Tensor:
        """Self-attention when `x` is given, cross-attention via explicit q/k/v.

        Returns:
            Tensor of shape (bsz, q_len, dim).
        """
        if x is not None:
            q = k = v = x
        if any(t is None for t in [q, k, v]):
            raise ValueError("Must provide either x for self-attention or q/k/v for cross-attention.")

        bsz, q_len, _ = q.shape
        _, k_len, _ = k.shape
        _, v_len, _ = v.shape

        # BUG FIX: reshaping a (B, L, dim) projection directly to
        # (B, H, L, head_dim) reinterprets memory and scrambles tokens across
        # heads. Split heads with view + transpose to get the (B, H, L, hd)
        # layout flex_attention expects.
        Q = self.wq(q).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.wk(k).view(bsz, k_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.wv(v).view(bsz, v_len, self.num_heads, self.head_dim).transpose(1, 2)

        # NOTE(review): assumes apply_rotary_emb accepts the (B, H, L, head_dim)
        # layout the original code nominally passed — confirm against rotary_emb.py.
        Q, K = apply_rotary_emb(Q, K, freqs_cis=freqs_cis)

        output = flex_attention(Q, K, V, block_mask=self.block_mask, score_mod=self.score_mod)
        # Merge heads back: (B, H, L, head_dim) -> (B, L, dim).
        output = output.transpose(1, 2).contiguous().view(bsz, q_len, self.dim)
        output = self.o_proj(output)
        return output
attn_masks.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.nn.attention.flex_attention import _mask_mod_signature
3
+
4
def causal_mask(
    batch_size: int,
    num_heads: int,
    q_idx: torch.Tensor,
    kv_idx: torch.Tensor
) -> torch.Tensor:
    """Mask mod for causal (autoregressive) attention.

    The batch and head arguments are part of the flex-attention mask_mod
    signature but are ignored: causality depends only on positions.

    Args:
        batch_size: Batch index/size (unused).
        num_heads: Head index/count (unused).
        q_idx: Query position indices.
        kv_idx: Key/value position indices.

    Returns:
        Boolean tensor that is True exactly where the key position does not
        lie in the query's future, i.e. kv_idx <= q_idx.
    """
    return kv_idx <= q_idx
26
+
27
+
28
+ def generate_sliding_window_mask(window_size: int, causal: bool = True) -> _mask_mod_signature:
29
+ """
30
+ Creates a sliding window mask function.
31
+
32
+ If `causal=True`, each query token at position i can attend only to tokens j
33
+ in [i - window_size, i].
34
+ If `causal=False`, each query token i can attend to any token j in
35
+ [i - window_size, i + window_size], i.e. a symmetric window of size `window_size`.
36
+
37
+ Args:
38
+ window_size (int): The maximum distance from i that i can attend to.
39
+ causal (bool): Whether to enforce causal ordering (i >= j). Defaults to True.
40
+
41
+ Returns:
42
+ _mask_mod_signature: A callable mask function that takes
43
+ (batch_size, num_heads, q_idx, kv_idx) and returns a boolean tensor
44
+ indicating allowed attention connections.
45
+ """
46
+ def sliding_window_mask(
47
+ batch_size: int,
48
+ num_heads: int,
49
+ q_idx: torch.Tensor,
50
+ kv_idx: torch.Tensor
51
+ ) -> torch.Tensor:
52
+ """
53
+ If causal is True:
54
+ within_window = (q_idx - kv_idx) <= window_size, and q_idx >= kv_idx.
55
+ If causal is False:
56
+ within_window = abs(q_idx - kv_idx) <= window_size.
57
+ """
58
+ if causal:
59
+ # standard "look back" window
60
+ distance = q_idx - kv_idx
61
+ within_window = (distance >= 0) & (distance <= window_size)
62
+ else:
63
+ # symmetrical window around i
64
+ distance = (q_idx - kv_idx).abs()
65
+ within_window = distance <= window_size
66
+
67
+ return within_window
68
+
69
+ name_ext = "causal" if causal else "noncausal"
70
+ sliding_window_mask.__name__ = f"sliding_window_{window_size}_{name_ext}"
71
+ return sliding_window_mask
72
+
73
+
74
+ def generate_dilated_sliding_window_mask(
75
+ window_size: int,
76
+ dilation: int = 2,
77
+ causal: bool = True
78
+ ) -> _mask_mod_signature:
79
+ """
80
+ Creates a dilated sliding window mask function.
81
+
82
+ If `causal=True`, each query token i can attend tokens j in [i - window_size, i]
83
+ such that (i - j) % dilation == 0.
84
+ If `causal=False`, each query token i can attend tokens j in [i - window_size,
85
+ i + window_size] for which |i - j| % dilation == 0.
86
+
87
+ Args:
88
+ window_size (int): The maximum distance from i to j (backwards if causal=True,
89
+ otherwise symmetric around i).
90
+ dilation (int): The stride for skipping positions.
91
+ causal (bool): Whether to enforce causal ordering (i >= j). Defaults to True.
92
+
93
+ Returns:
94
+ _mask_mod_signature: A callable mask function that takes
95
+ (batch_size, num_heads, q_idx, kv_idx) and returns a boolean tensor
96
+ indicating allowed attention connections.
97
+ """
98
+ def dilated_sliding_window_mask(
99
+ batch_size: int,
100
+ num_heads: int,
101
+ q_idx: torch.Tensor,
102
+ kv_idx: torch.Tensor
103
+ ) -> torch.Tensor:
104
+ """
105
+ If causal is True:
106
+ distance = q_idx - kv_idx
107
+ 0 <= distance <= window_size and distance % dilation == 0.
108
+ If causal is False:
109
+ distance = (q_idx - kv_idx).abs()
110
+ distance <= window_size and distance % dilation == 0.
111
+ """
112
+ if causal:
113
+ distance = q_idx - kv_idx
114
+ within_window = (distance >= 0) & (distance <= window_size)
115
+ else:
116
+ distance = (q_idx - kv_idx).abs()
117
+ within_window = distance <= window_size
118
+
119
+ meets_dilation = (distance % dilation) == 0
120
+ return within_window & meets_dilation
121
+
122
+ mode_str = "causal" if causal else "noncausal"
123
+ dilated_sliding_window_mask.__name__ = (
124
+ f"dilated_sliding_window_{window_size}_dilation_{dilation}_{mode_str}"
125
+ )
126
+ return dilated_sliding_window_mask
127
+
128
+
129
def main():
    """
    Demonstrates usage of each mask by printing attention grids. We include a few
    basic checks to ensure the masks behave as expected. We show both the causal
    and non-causal versions for the sliding window and dilated masks.
    """
    B, H = 1, 1
    Q_LEN, KV_LEN = 8, 8

    # coordinate grids: q_idx[i, j] == i and kv_idx[i, j] == j, so the mask
    # functions evaluate every (query, key) pair at once.
    q_idx = torch.arange(Q_LEN).unsqueeze(-1).expand(Q_LEN, KV_LEN)
    kv_idx = torch.arange(KV_LEN).unsqueeze(0).expand(Q_LEN, KV_LEN)

    print("= Causal Mask =")
    c_mask = causal_mask(B, H, q_idx, kv_idx)
    print(c_mask.int(), "\n")

    print("= Sliding Window (window_size=2, causal=True) =")
    sw_causal_fn = generate_sliding_window_mask(window_size=2, causal=True)
    sw_causal = sw_causal_fn(B, H, q_idx, kv_idx)
    print(sw_causal.int(), "\n")

    print("= Sliding Window (window_size=2, causal=False) =")
    sw_noncausal_fn = generate_sliding_window_mask(window_size=2, causal=False)
    sw_noncausal = sw_noncausal_fn(B, H, q_idx, kv_idx)
    print(sw_noncausal.int(), "\n")

    print("= Dilated Sliding Window (window_size=4, dilation=2, causal=True) =")
    ds_causal_fn = generate_dilated_sliding_window_mask(window_size=4, dilation=2, causal=True)
    ds_causal = ds_causal_fn(B, H, q_idx, kv_idx)
    print(ds_causal.int(), "\n")

    print("= Dilated Sliding Window (window_size=4, dilation=2, causal=False) =")
    ds_noncausal_fn = generate_dilated_sliding_window_mask(window_size=4, dilation=2, causal=False)
    ds_noncausal = ds_noncausal_fn(B, H, q_idx, kv_idx)
    print(ds_noncausal.int(), "\n")

    # Quick checks:
    # (1) Causal means no i < j
    assert torch.all(c_mask == (q_idx >= kv_idx)), "Causal mask mismatch!"
    # (2) For windowed masks with causal=True, check a random row
    i = 5
    row_sw = sw_causal[i]
    allowed_js = torch.where(row_sw)[0]
    if len(allowed_js) > 0:
        # difference i-j <= 2 (earliest allowed key is at most window_size back)
        assert (i - allowed_js.min()) <= 2, "Window mismatch for sliding_window_mask(causal=True)."

    # (3) Dilated mask with causal=True should skip every other position if dilation=2
    i = 6
    row_ds = ds_causal[i]
    allowed_js = torch.where(row_ds)[0]
    for j in allowed_js:
        diff = i - j
        assert diff % 2 == 0, f"Dilation mismatch: got diff={diff}."

    print("All checks passed.")

if __name__ == "__main__":
    main()
attn_mods.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import Tensor
3
+ from torch.nn.attention.flex_attention import _score_mod_signature
4
+ from torch._inductor.lowering import make_pointwise, register_lowering
5
+
6
+ # Some internal torch.compile details
7
+ from torch._inductor.virtualized import ops
8
+ from functools import partial
9
+
10
+
11
@torch.library.custom_op("approx::tanh", mutates_args=())
def _tanh_approx(inp: Tensor) -> Tensor:
    # Eager/reference implementation: exact tanh. The approximate PTX kernel
    # is only substituted when compiled through Inductor (lowering below).
    return torch.tanh(inp)


@_tanh_approx.register_fake
def _(inp: torch.Tensor) -> torch.Tensor:
    # Fake (meta) implementation for tracing: same shape/dtype as exact tanh.
    return torch.tanh(inp)


def _tanh_approx_lowering(inp):
    # Inductor lowering: emit the fast `tanh.approx.f32` PTX instruction
    # instead of calling the eager op.
    fn = partial(ops.inline_asm_elementwise, asm="tanh.approx.f32 $0, $1;")
    return make_pointwise(fn)(inp)


register_lowering(torch.ops.approx.tanh)(_tanh_approx_lowering)


class _TanhApprox(torch.autograd.Function):
    # Autograd wrapper so the custom op is differentiable and vmap-able.

    @staticmethod
    def forward(x):
        return torch.ops.approx.tanh(x)

    @staticmethod
    def setup_context(ctx, inputs, output):
        (x,) = inputs
        result = output
        # Save the output: d/dx tanh(x) = 1 - tanh(x)^2, so the backward pass
        # only needs the forward result.
        ctx.save_for_backward(result)

    @staticmethod
    def backward(ctx, grad_output):
        (result,) = ctx.saved_tensors
        return grad_output * (1 - result * result)

    @staticmethod
    def vmap(info, in_dims, x):
        # Under vmap, fall back to exact tanh.
        return torch.tanh(x), 0


# Rebind the public name to the autograd-aware entry point.
_tanh_approx = _TanhApprox.apply
51
+
52
+
53
+ def generate_tanh_softcap(soft_cap: int, approx: bool = False) -> _score_mod_signature:
54
+ """Returns an tanh bias score_mod given the number of heads H
55
+
56
+ Args:
57
+ soft_cap: The soft cap value to use for normalizing logits
58
+ approx: Whether to use the `tanh.approx.` ptx instruction
59
+
60
+ Returns:
61
+ tanh_softcap: score_mod
62
+ """
63
+ tanh = _tanh_approx if approx else torch.tanh
64
+
65
+ def tanh_softcap(score, b, h, q_idx, kv_idx):
66
+ return soft_cap * tanh(score / soft_cap)
67
+
68
+ prefix = "tanh_softcap_approx" if approx else "tanh_softcap"
69
+ tanh_softcap.__name__ = f"{prefix}_{soft_cap}"
70
+
71
+ return tanh_softcap
72
+
73
+ def generate_alibi_bias(H: int) -> _score_mod_signature:
74
+ """Returns an alibi bias score_mod given the number of heads H
75
+
76
+ Args:
77
+ H: number of heads
78
+
79
+ Returns:
80
+ alibi_bias: alibi bias score_mod
81
+ """
82
+
83
+ def alibi_mod(score, b, h, q_idx, kv_idx):
84
+ scale = torch.exp2(-((h + 1) * 8.0 / H))
85
+ bias = (kv_idx - q_idx) * scale
86
+ return score + bias
87
+
88
+ return alibi_mod
89
+
90
+
91
+ def generate_tanh_softcap_alibi(H: int, soft_cap: float, approx: bool = False) -> _score_mod_signature:
92
+ """Returns a combined ALiBi and tanh softcapping score_mod.
93
+
94
+ Args:
95
+ H (int): number of heads for ALiBi scaling
96
+ soft_cap (float): the soft cap value for normalizing/logit clipping
97
+ approx (bool): Whether to use the 'tanh.approx' PTX-based approximation
98
+
99
+ Returns:
100
+ A combined score_mod function that first applies ALiBi,
101
+ then performs softcap + tanh (optionally approximate).
102
+ """
103
+ tanh_func = _tanh_approx if approx else torch.tanh
104
+
105
+ def alibi_tanh_softcap(score, b, h, q_idx, kv_idx):
106
+ # Compute ALiBi bias
107
+ scale = torch.exp2(-((h + 1) * 8.0 / H))
108
+ bias = (kv_idx - q_idx) * scale
109
+ score = score + bias
110
+
111
+ # Apply softcap
112
+ score = score / soft_cap
113
+
114
+ # Apply tanh
115
+ score = tanh_func(score)
116
+
117
+ # Rescale by soft_cap
118
+ score = score * soft_cap
119
+ return score
120
+
121
+ # Give the score_mod a unique name:
122
+ if approx:
123
+ alibi_tanh_softcap.__name__ = f"tanh_softcap_alibi_approx_{soft_cap}"
124
+ else:
125
+ alibi_tanh_softcap.__name__ = f"tanh_softcap_alibi_{soft_cap}"
126
+
127
+ return alibi_tanh_softcap
config.json ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "minitransformer",
3
+ "_name_or_path": "Transformer_500M",
4
+ "architectures": ["MiniTransformer"],
5
+ "dim": 896,
6
+ "num_heads": 8,
7
+ "num_layers": 12,
8
+ "seq_len": 8192,
9
+ "window_size": 8192,
10
+ "vocab_size": 200064,
11
+ "mlp_scale": 12,
12
+ "bias": false,
13
+ "dropout": 0.0,
14
+ "weight_tying": true,
15
+ "num_epochs": 1,
16
+ "global_bsz": 524288,
17
+ "bsz": 2,
18
+ "warmup_steps": 1907,
19
+ "eval_period": 50,
20
+ "save_period": 500,
21
+ "max_lr": 3.0e-4,
22
+ "min_lr": 3.0e-5,
23
+ "max_norm": 1.0,
24
+ "dilation": 1,
25
+ "fsdp": true,
26
+ "ddp": false,
27
+ "mixed_precision": true,
28
+ "torch_dtype": "bfloat16",
29
+ "cpu_offload": false,
30
+ "sharding_strategy": "full_shard",
31
+ "state_dict_type": "full",
32
+ "auto_wrap_policy": "partial",
33
+ "backward_prefetch": "backward_pre",
34
+ "forward_prefetch": false,
35
+ "sync_module_states": true,
36
+ "use_orig_params": true,
37
+ "device_id": null,
38
+ "precision": {
39
+ "param": "bfloat16",
40
+ "reduce": "bfloat16",
41
+ "buffer": "bfloat16"
42
+ },
43
+ "fsdp_modules": [
44
+ "AttentionLayer"
45
+ ],
46
+ "use_activation_checkpointing": true,
47
+ "softcap": 50.0,
48
+ "theta": 10000.0,
49
+ "use_alibi": false,
50
+ "torch_compile": false
51
+ }
configuration_minitransformer.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import PretrainedConfig, AutoConfig
3
+
4
class MiniTransformerConfig(PretrainedConfig):
    """Configuration for the MiniTransformer model (HF PretrainedConfig)."""

    model_type = "minitransformer"

    def __init__(
        self,
        bsz: int = 1,
        dim: int = 896,
        num_heads: int = 8,
        num_layers: int = 12,
        seq_len: int = 8192,
        window_size: int = 8192,
        vocab_size: int = 200064,
        mlp_scale: int = 12,
        bias: bool = False,
        dropout: float = 0.0,
        softcap: float = 50.0,
        theta: float = 10_000.0,
        use_alibi: bool = False,
        torch_dtype: torch.dtype = torch.bfloat16,
        device: torch.device = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Core architecture.
        self.bsz = bsz
        self.dim = dim
        self.hidden_size = dim  # HF-conventional alias for dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.seq_len = seq_len
        self.window_size = window_size
        self.vocab_size = vocab_size
        # MLP sizing.
        self.mlp_scale = mlp_scale
        self.intermediate_size = dim * mlp_scale
        # Regularization and attention shaping.
        self.bias = bias
        self.dropout = dropout
        self.softcap = softcap
        self.theta = theta
        self.use_alibi = use_alibi
        # Runtime placement.
        self.torch_dtype = torch_dtype
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')  # Store as string
44
+
convolve.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ from .utils import nearest_power_of_two
5
+ from flashfftconv import FlashFFTConv
6
+
7
+
8
def convolve(u: torch.Tensor, v: torch.Tensor, n: int, use_approx: bool = True) -> tuple[torch.Tensor, torch.Tensor]:
    """FFT-based convolution of inputs u with filters v along the sequence dim.

    Computes the convolution (zero-padded to FFT length n) of u with v,
    together with a sign-modulated ("minus") variant; both are trimmed back
    to seq_len.

    Args:
        u: Input of shape (bsz, seq_len, d_in).
        v: Filter bank of shape (seq_len, d_out) when use_approx is True,
           else (seq_len, K).
        n: FFT length (>= seq_len).
        use_approx: Broadcast filters across channels instead of expanding
            the full (K, d_in) combination.

    Returns:
        (U_plus, U_minus) in u's ORIGINAL dtype: (bsz, seq_len, d_in) when
        use_approx, else (bsz, seq_len, K, d_in).
    """
    bsz, seq_len, d_in = u.shape
    # BUG FIX: remember the caller's dtype *before* the float32 cast below.
    # Previously `u.dtype` was read after rebinding u to float32, so outputs
    # were always float32 (sibling flash_convolve captures dtype correctly).
    dtype = u.dtype

    # Alternating-sign vector used to form the "minus" variant.
    sgn = torch.full((1, seq_len, 1), 1, device=u.device, dtype=torch.float32)
    sgn[:, 1::2] *= -1

    # Cast u and v to float32 for FFT accuracy/support.
    u = u.to(torch.float32)
    v = v.to(torch.float32)

    if use_approx:
        _, d_out = v.shape
        v = v.view(1, -1, d_out, 1)
    else:
        _, K = v.shape
        sgn = sgn.unsqueeze(-1)
        v = v.view(1, -1, K, 1, 1)
        u = u.view(bsz, -1, 1, d_in).expand(bsz, -1, K, d_in)

    v = torch.fft.rfft(v, n=n, dim=1)
    # Stack u with its sign-modulated copy so both convolutions share one FFT.
    U = torch.stack([u, u * sgn], dim=-1)
    U = torch.fft.rfft(U, n=n, dim=1)
    U_conv = torch.fft.irfft(v * U, n=n, dim=1)[:, :seq_len]
    U_plus, U_minus = torch.unbind(U_conv, dim=-1)
    # Undo the sign modulation on the "minus" branch.
    U_minus = U_minus * sgn

    # Convert back to the caller's original dtype.
    U_plus = U_plus.to(dtype)
    U_minus = U_minus.to(dtype)

    return U_plus, U_minus
39
+
40
def flash_convolve(
    u: torch.Tensor, v: torch.Tensor, flash_fft: FlashFFTConv, use_approx: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """FlashFFTConv-backed counterpart of `convolve`.

    Args:
        u: Input of shape (bsz, seq_len, d_in).
        v: Filter bank of shape (seq_len, K) (K == d_out when use_approx).
        flash_fft: Pre-built FlashFFTConv instance — presumably constructed
            for `padded_len`; confirm against the caller.
        use_approx: Broadcast filters across channels instead of expanding
            every (filter, channel) combination.

    Returns:
        (U_plus, U_minus) in u's original dtype: (bsz, seq_len, d_in) when
        use_approx, else (bsz, seq_len, K, d_in).
    """
    dtype = u.dtype  # Store the original dtype
    u = u.to(torch.float32)
    v = v.to(torch.float32)

    bsz, seq_len, d_in = u.shape
    _, K = v.shape

    # Zero-pad the sequence up to the next power of two.
    padded_len = nearest_power_of_two(seq_len, round_up=True)
    pad_len = padded_len - seq_len

    # Alternating-sign vector for the "minus" variant (channels-first layout here).
    sgn = torch.full((1, 1, padded_len), 1, device=u.device, dtype=torch.float32)
    sgn[:, :, 1::2] = -1

    if use_approx:
        u_padded = F.pad(u.transpose(1, 2), (0, pad_len)).contiguous()
        v_padded = F.pad(v.transpose(0, 1), (0, pad_len)).contiguous()
        # Batch u and u*sgn together so a single flash_fft call computes both variants.
        u_conv = torch.stack([u_padded, u_padded * sgn], dim=0).reshape(2 * bsz, d_in, padded_len)
    else:
        # Expand each input channel against all K filters.
        u_k_padded = F.pad(u.transpose(1, 2), (0, pad_len)).repeat_interleave(K, dim=1).contiguous()
        v_padded = F.pad(v.transpose(0, 1), (0, pad_len)).repeat(d_in, 1).contiguous()
        u_conv = torch.stack([u_k_padded, u_k_padded * sgn], dim=0).reshape(2 * bsz, K * d_in, padded_len)

    U_conv = flash_fft(u_conv, v_padded)

    # Trim the output back to the original sequence length
    U_conv = U_conv[..., :seq_len]

    # First half of the stacked batch is the "plus" variant, second half "minus".
    u_plus, u_minus = torch.chunk(U_conv, 2, dim=0)

    if use_approx:
        # Undo the sign modulation, then restore (bsz, seq_len, d_in) layout.
        u_minus = u_minus * sgn[:, :, :seq_len]
        U_plus, U_minus = u_plus.transpose(1, 2), u_minus.transpose(1, 2)
    else:
        sgn = sgn[:, :, :seq_len].unsqueeze(-1).transpose(1, 2)
        U_plus = u_plus.view(bsz, d_in, K, seq_len).permute(0, 3, 2, 1).contiguous()
        U_minus = u_minus.view(bsz, d_in, K, seq_len).permute(0, 3, 2, 1).contiguous() * sgn

    # Convert back to original dtype
    U_plus = U_plus.to(dtype)
    U_minus = U_minus.to(dtype)

    return U_plus, U_minus
filters.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+ from .utils import logger
7
+ from .utils import get_hankel
8
+
9
def get_spectral_filters(
    seq_len: int,
    K: int,
    use_hankel_L: bool = False,
    device: torch.device = None,
    dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Return the top-K spectral filters of the Hankel matrix for seq_len.

    The eigendecomposition runs in float32 (low-precision dtypes are not
    suitable for eigh); the results are cast back to `dtype`.

    Args:
        seq_len: Sequence length / Hankel matrix size.
        K: Number of filters to keep.
        use_hankel_L: Use the Hankel_L variant of the matrix.
        device: Target device for the filters.
        dtype: Target dtype for the filters.
    """
    Z = get_hankel(seq_len, use_hankel_L, device=device, dtype=dtype)

    # Decompose in float32, then return to the requested dtype.
    eigvals, eigvecs = torch.linalg.eigh(Z.to(torch.float32))
    eigvals = eigvals.to(dtype=dtype)
    eigvecs = eigvecs.to(dtype=dtype)

    # eigh sorts eigenvalues ascending, so the last K columns are the top K.
    top_vals = eigvecs[:, -K:], eigvals[-K:]
    phi_k, sigma_k = top_vals[0], top_vals[1]

    # Scale each eigenvector by its eigenvalue^(1/4).
    filters = (phi_k * sigma_k ** 0.25).to(device=device, dtype=dtype)
    return filters
39
+
40
+
41
def compute_dimensions(n: int) -> tuple[int, int, int]:
    """Derive tensorization dimensions for a sequence length n.

    Rounds n (minus the 2 boundary entries) up to a perfect square so the
    filter bank can be expressed as a Kronecker product of two equal factors.

    Args:
        n: Sequence length; must be greater than 2.

    Returns:
        (T_prime, sqrt_T_prime, k_max): padded length with T_prime - 2 a
        perfect square, its square root, and the maximum usable filter count
        (equal to sqrt_T_prime).

    Raises:
        ValueError: If n <= 2.
    """
    if n <= 2:
        raise ValueError("n must be greater than 2")

    root = math.ceil(math.sqrt(n - 2))
    T_prime = root * root + 2
    sqrt_T_prime = math.ceil(math.sqrt(T_prime - 2))
    return T_prime, sqrt_T_prime, sqrt_T_prime
49
+
50
def get_tensorized_spectral_filters_explicit(n: int, k: int, device: torch.device) -> torch.Tensor:
    """Explicit (loop-based) construction of the tensorized spectral filters.

    Sums the Kronecker products of every pair of eigenvalue-scaled top-k
    eigenvectors of the Hankel matrix. Serves as a readable reference for
    `get_tensorized_spectral_filters`.

    Args:
        n: Sequence length.
        k: Requested number of filters (clamped to k_max).
        device: Computation device.

    Returns:
        1-D tensor of length sqrt_T_prime**2.
    """
    T_prime, sqrt_T_prime, k_max = compute_dimensions(n)
    k = min(k, k_max)

    Z = get_hankel(sqrt_T_prime).to(device)
    sigma, phi = torch.linalg.eigh(Z)
    # eigh returns eigenvalues ascending; keep the top-k.
    sigma_k = sigma[-k:]
    phi_k = phi[:, -k:]

    result = torch.zeros(sqrt_T_prime * sqrt_T_prime, device=device)

    for i in range(k):
        for j in range(k):
            # Scale each eigenvector by its eigenvalue^(1/4), then accumulate
            # the Kronecker product of the pair.
            phi_i = phi_k[:, i] * (sigma_k[i] ** 0.25)
            phi_j = phi_k[:, j] * (sigma_k[j] ** 0.25)
            kron = torch.kron(phi_i, phi_j)
            result += kron

    return result
69
+
70
+
71
def get_tensorized_spectral_filters(
    n: int = 8192,
    k: int = 24,
    use_hankel_L: bool = False,
    device: torch.device = None,
    dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """
    Compute tensorized spectral filters for given sequence length and filter count.

    The filter bank is the Kronecker product of two eigenvector banks, each
    scaled by eigenvalue**0.25.

    Args:
        n: Sequence length
        k: Number of filters (clamped to k_max from compute_dimensions)
        use_hankel_L: Hankel_main ⊗ Hankel_L? Default is Hankel_main ⊗ Hankel_main.
        device: Computation device
        dtype: Computation dtype
    """
    assert torch.cuda.is_available(), "CUDA is required."

    T_prime, sqrt_T_prime, k_max = compute_dimensions(n)
    k = min(k, k_max)

    Z = get_hankel(sqrt_T_prime)
    sigma, phi = torch.linalg.eigh(Z)
    # Top-k eigenvectors scaled by eigenvalue**0.25 (eigh sorts ascending).
    phi_i = phi[:, -k:] * sigma[-k:] ** 0.25

    if use_hankel_L:  # TODO: We may want to use Hankel_L above too if use_hankel_L is true, make another variable for this (mix != use_hankel_L)
        logger.info("Mixing Hankel_L with Hankel_main to generate tensorized filters.")
        Z_L = get_hankel(sqrt_T_prime, True)
        sigma_L, phi_L = torch.linalg.eigh(Z_L)
        phi_j = phi_L[:, -k:] * sigma_L[-k:] ** 0.25
    else:
        # Default: tensorize the main Hankel filters with themselves.
        phi_j = phi_i

    filters = torch.kron(phi_i, phi_j)
    return filters.to(device=device, dtype=dtype)
layers.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .attn import FlexAttention
5
+ from .modules import MLP
6
+ from .modules import Attention
7
+ try:
8
+ from liger_kernel.transformers.swiglu import LigerSwiGLUMLP as TritonMLP
9
+ triton_mlp = True
10
+ except ImportError as e:
11
+ print(
12
+ f"Unable to import Triton-based MLP: {e}. Falling back to vanilla SwiGLU MLP instead."
13
+ )
14
+ triton_mlp = False
15
+
16
+ try:
17
+ from liger_kernel.transformers.rms_norm import LigerRMSNorm as TritonNorm
18
+ triton_norm = True
19
+ except ImportError as e:
20
+ print(
21
+ f"Unable to import Triton-based RMSNorm: {e}. Falling back to PyTorch implementation."
22
+ )
23
+ from torch.nn import RMSNorm
24
+ triton_norm = False
25
+
26
class AttentionLayer(nn.Module):
    """Pre-norm transformer block: FlexAttention then a gated MLP, each
    wrapped in RMSNorm and a residual connection."""

    def __init__(self, config, mask_mod, score_mod=None) -> None:
        super(AttentionLayer, self).__init__()
        self.attn_norm = nn.RMSNorm(config.dim)
        self.attn = FlexAttention(
            config=config,
            mask_mod=mask_mod,
            score_mod=score_mod,
        )
        self.mlp_norm = nn.RMSNorm(config.dim)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor = None) -> torch.Tensor:
        """Apply the attention and MLP sublayers with pre-norm residuals."""
        attn_out = self.attn(self.attn_norm(x), freqs_cis=freqs_cis)
        hidden = x + attn_out
        mlp_out = self.mlp(self.mlp_norm(hidden))
        return hidden + mlp_out
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
mlp.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from torch.nn import functional as F
3
+ import torch
4
class MLP(nn.Module):
    """Gated feed-forward network (GLU family, https://arxiv.org/pdf/2002.05202):
    a tanh-approximated GELU gate multiplied with an up projection, then
    projected back down and passed through dropout."""

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.dim
        self.intermediate_size = config.dim * config.mlp_scale
        # Construction order matters for RNG-reproducible initialization.
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Return dropout(down_proj(gelu_tanh(gate_proj(x)) * up_proj(x)))."""
        activated_gate = F.gelu(self.gate_proj(x), approximate="tanh")
        fused = activated_gate * self.up_proj(x)
        return self.dropout(self.down_proj(fused))
modeling_minitransformer.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from transformers import PreTrainedModel
6
+ from transformers.modeling_outputs import CausalLMOutput
7
+
8
+ from .modules import Attention
9
+ from .utils import nearest_power_of_two
10
+ from .layers import AttentionLayer
11
+ from .configuration_minitransformer import MiniTransformerConfig
12
+
13
+ from .attn_masks import causal_mask
14
+ from .attn_mods import generate_tanh_softcap
15
+ from .rotary_emb import precompute_freqs_cis
16
+
17
+ try:
18
+ from liger_kernel.transformers.rms_norm import LigerRMSNorm as TritonNorm
19
+ triton_norm = True
20
+ except ImportError as e:
21
+ print(
22
+ f"Unable to import Triton-based RMSNorm: {e}. Falling back to PyTorch implementation."
23
+ )
24
+ from torch.nn import RMSNorm
25
+ triton_norm = False
26
# Load the tokenizer
# NOTE(review): this runs at import time and may hit the network (a Hugging
# Face Hub download with trust_remote_code=True). It is currently used only
# for decoded debug output in generate() — confirm this module-level side
# effect is intended in a library file.
from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "Hazan-Lab/Transformer_500M"
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True
)
34
+
35
class MiniTransformer(PreTrainedModel):
    """Decoder-only Transformer LM with rotary position embeddings and
    FlexAttention blocks, exposing an HF-style forward() and a naive
    sampling generate()."""

    config_class = MiniTransformerConfig

    def __init__(self, config) -> None:
        super(MiniTransformer, self).__init__(config)
        self.num_layers = config.num_layers
        # Bug fix: the original assert message referenced self.dim /
        # self.num_heads, which do not exist and would raise AttributeError
        # instead of the intended AssertionError.
        assert config.dim % config.num_heads == 0, (
            f"dim ({config.dim}) must be divisible by num_heads ({config.num_heads})"
        )
        self.head_dim = config.dim // config.num_heads
        logit_softcap = generate_tanh_softcap(soft_cap=config.softcap)

        # From pytorch/pytorch#123411, we set persistent=True for torch.compile and PP compatibility
        self.register_buffer("freqs_cis", precompute_freqs_cis(
            head_dim=self.head_dim,
            max_seq_len=config.seq_len,
            theta=config.theta,
        ), persistent=True)

        self.tok_emb = nn.Embedding(config.vocab_size, config.dim)
        self.dropout = nn.Dropout(config.dropout)

        self.layers = nn.ModuleList()
        for _ in range(self.num_layers):
            layer = AttentionLayer(config, mask_mod=causal_mask, score_mod=logit_softcap)
            self.layers.append(layer)

        self.norm = nn.RMSNorm(config.dim)
        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=config.bias)
        # self.tok_emb.weight = self.lm_head.weight  # (weight tying disabled)

        self.std = (config.dim) ** -0.5
        self.apply(self._init_weights)
        print("Model Parameter Count: %.2fM\n" % (self._get_num_params() / 1e6,))

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: torch.Tensor = None,
        **kwargs
    ) -> CausalLMOutput:
        """Run the model on input_ids; when labels are given, also compute
        the shifted next-token cross-entropy loss.

        Args:
            input_ids: (batch, seq_len) token ids.
            labels: optional (batch, seq_len) targets for the LM loss.

        Returns:
            CausalLMOutput with logits (batch, seq_len, vocab_size) and
            optional loss.
        """
        # Compute embeddings
        tok_emb = self.tok_emb(input_ids)

        for layer in self.layers:
            tok_emb = layer(tok_emb, self.freqs_cis)

        # Normalize and project to vocabulary
        tok_emb = self.norm(tok_emb)
        logits = self.lm_head(tok_emb)

        loss = None
        if labels is not None:
            # Shift so that tokens predict the next token
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )

        return CausalLMOutput(
            loss=loss,
            logits=logits,
        )

    def _get_num_params(self):
        """Count parameters, excluding positional embeddings (if present)
        and the duplicate of a tied embedding matrix."""
        n_params = sum(p.numel() for p in self.parameters())
        if hasattr(self, "pos_emb") and self.pos_emb is not None:
            n_params -= self.pos_emb.weight.numel()
        if self.tok_emb.weight is self.lm_head.weight:
            n_params -= self.tok_emb.weight.numel()
        return n_params

    def _init_weights(self, module):
        """Normal(0, std) init; Linears flagged with SCALE_INIT use a std
        scaled by (2 * num_layers) ** -0.5 for residual-path projections."""
        if isinstance(module, nn.Linear):
            std = self.std
            if hasattr(module, "SCALE_INIT"):
                # Bug fix: the original mutated self.std in place, so every
                # additional flagged module compounded the shrink factor.
                std = std * (2 * self.num_layers) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.std)

    @staticmethod
    def top_k_top_p_filtering(
        logits: torch.Tensor,
        top_k: int = 50,
        top_p: float = 0.95,
        filter_value: float = float("-inf"),
    ):
        """
        Filters a distribution of logits using top-k and/or nucleus (top-p) filtering.

        Note: mutates `logits` in place and also returns it.
        """
        # top_k
        if top_k > 0:
            top_k = min(top_k, logits.size(-1))
            # Remove all logits that are not in the top k
            indices_to_remove = logits < torch.topk(logits, top_k, dim=-1).values[:, -1, None]
            logits[indices_to_remove] = filter_value

        # top_p (nucleus)
        if 0 < top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

            # Remove tokens with cumulative probability above the threshold
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift the indices to the right to keep also the first token above the threshold
            sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
            sorted_indices_to_remove[:, 0] = False

            # Map the removal mask from sorted order back to vocabulary order.
            indices_to_remove = sorted_indices_to_remove.scatter(
                dim=1, index=sorted_indices, src=sorted_indices_to_remove
            )
            logits[indices_to_remove] = filter_value

        return logits

    def generate(
        self,
        input_ids: torch.LongTensor,
        max_new_tokens: int = 50,
        temperature: float = 0.5,
        top_k: int = 50,
        top_p: float = 0.95,
        eos_token_id: int = None,
        pad_token_id: int = 0,
        **kwargs
    ):
        """
        Naive token-by-token generation loop that uses top-k/top-p filtering and optional temperature.

        Args:
            input_ids (torch.LongTensor): shape (batch_size, sequence_length).
            max_new_tokens (int): max number of tokens to generate (beyond input_ids length).
            temperature (float): sampling temperature (>=0).
            top_k (int): Top-K sampling cutoff.
            top_p (float): Nucleus sampling cutoff.
            eos_token_id (int): If set, stop generation when this token is produced.
            pad_token_id (int): If set, can be used to pad sequences. (Not fully used here.)
            kwargs: Unused arguments (like num_beams) for compatibility.

        Returns:
            torch.LongTensor: shape (batch_size, sequence_length + generated_tokens).
        """
        # Bug fix: removed debug prints that decoded via a module-level
        # tokenizer downloaded at import time; generation itself is unchanged.
        generated_ids = input_ids.clone()

        for _ in range(max_new_tokens):
            # Forward pass to get logits for the last token
            outputs = self.forward(generated_ids)
            logits = outputs.logits[:, -1, :]  # shape: (batch_size, vocab_size)

            # Scale logits by temperature
            if temperature != 1.0:
                logits = logits / temperature

            # Filter logits using top-k and/or top-p
            logits = self.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

            # Convert to probabilities
            probabilities = F.softmax(logits, dim=-1)

            # Sample from the distribution
            next_token = torch.multinomial(probabilities, num_samples=1)  # (batch_size, 1)

            # Append next token
            generated_ids = torch.cat([generated_ids, next_token], dim=1)

            # Stop early only once every sequence in the batch emitted EOS.
            if eos_token_id is not None:
                if (next_token == eos_token_id).all():
                    break

        return generated_ids
modules.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from .attn import Attention
2
+ from .attn import AttentionSDPA
3
+ from .mlp import MLP
4
+ from .stu import STU
5
+
6
+
rotary_emb.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
def precompute_freqs_cis(head_dim: int, max_seq_len: int, theta: float = 10000.0):
    """Precompute the complex rotary-embedding table.

    Returns a [max_seq_len, head_dim // 2] complex tensor whose (t, j)
    entry is exp(i * t / theta^(2j / head_dim)).
    """
    # Inverse frequency for each even dimension index.
    exponents = torch.arange(0, head_dim, 2).float() / head_dim
    inv_freq = 1.0 / (theta ** exponents)

    # One rotation angle per (position, frequency) pair.
    positions = torch.arange(max_seq_len, dtype=torch.float32)
    angles = torch.outer(positions, inv_freq)

    # Unit-magnitude complex exponentials e^{i * angle}.
    return torch.polar(torch.ones_like(angles), angles)
18
+
19
+
20
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """Slice the frequency table to x's sequence length and reshape it to
    [1, 1, seq_len, half_dim] so it broadcasts over batch and heads.

    x is [B, n_heads, seq_len, head_dim_as_complex]; freqs_cis is
    [max_seq_len, half_dim].
    """
    current_len = x.shape[2]
    table = freqs_cis[:current_len]  # trim to the current sequence length
    return table.view(1, 1, current_len, -1)
29
+
30
+
31
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rotate query/key vectors by the precomputed complex frequencies.

    The last dimension is viewed as head_dim // 2 complex numbers
    (adjacent pairs); multiplying by the unit-magnitude freqs_cis applies
    the rotary position rotation. Outputs keep the input dtypes.
    """

    def as_complex(t: torch.Tensor) -> torch.Tensor:
        # [..., head_dim] -> [..., head_dim // 2] complex.
        return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

    q_complex = as_complex(xq)
    k_complex = as_complex(xk)

    # Shape the table to [1, 1, seq_len, half_dim] for broadcasting.
    table = reshape_for_broadcast(freqs_cis, q_complex)

    # Complex multiply == 2D rotation; convert back to interleaved reals.
    q_rot = torch.view_as_real(q_complex * table).reshape(*xq.shape)
    k_rot = torch.view_as_real(k_complex * table).reshape(*xk.shape)
    return q_rot.type_as(xq), k_rot.type_as(xk)
52
+
53
+
54
def main():
    """Self-tests for the rotary-embedding helpers.

    Bug fix: all precompute_freqs_cis calls passed `dim=...`, but the
    function's parameter is named `head_dim`, so every test raised
    TypeError before asserting anything.
    """
    import math
    from torch.testing import assert_close

    # Test 1: No rotation at position 0
    dim = 2
    freqs_cis = precompute_freqs_cis(head_dim=dim, max_seq_len=1, theta=1.0)
    xq = torch.tensor([[[[1.0, 0.0]]]])
    xq_out, _ = apply_rotary_emb(xq, xq.clone(), freqs_cis)
    assert_close(xq_out, xq, msg="Test 1 failed")
    print("Test 1 passed.")

    # Test 2: Verify rotation at positions [0..4] in 2D
    L = 5
    freqs_cis = precompute_freqs_cis(head_dim=dim, max_seq_len=L, theta=1.0)
    xq = torch.tensor([[[[1.0, 0.0] for _ in range(L)]]])
    xq_out, _ = apply_rotary_emb(xq, xq.clone(), freqs_cis)
    expected = torch.tensor([[[[math.cos(p), math.sin(p)] for p in range(L)]]])
    assert_close(xq_out, expected, rtol=1e-6, atol=1e-6, msg="Test 2 failed")
    print("Test 2 passed.")

    # Test 3: Higher dimension at position 0
    xq = torch.tensor([[[[1.0, 0.0, 1.0, 0.0]]]])
    freqs_cis = precompute_freqs_cis(head_dim=4, max_seq_len=1, theta=1.0)
    xq_out, _ = apply_rotary_emb(xq, xq.clone(), freqs_cis)
    assert_close(xq_out, xq, msg="Test 3 failed")
    print("Test 3 passed.")

    # Test 4: Random shape & norm checks (rotation preserves vector norms)
    torch.manual_seed(1337)
    B, H, L, D = 2, 3, 5, 8
    xq = torch.randn(B, H, L, D)
    xk = torch.randn(B, H, L, D)
    freqs_cis = precompute_freqs_cis(head_dim=D, max_seq_len=L, theta=1.0)
    xq_out, xk_out = apply_rotary_emb(xq, xk, freqs_cis)
    assert xq_out.shape == (B, H, L, D), "Test 4 Q shape failed"
    assert xk_out.shape == (B, H, L, D), "Test 4 K shape failed"
    for b in range(B):
        for h in range(H):
            for l in range(L):
                assert torch.allclose(xq[b, h, l].norm(), xq_out[b, h, l].norm(), atol=1e-5), "Test 4 Q norm failed"
                assert torch.allclose(xk[b, h, l].norm(), xk_out[b, h, l].norm(), atol=1e-5), "Test 4 K norm failed"
    print("Test 4 passed.\nAll tests passed successfully!")

if __name__ == "__main__":
    main()
special_tokens_map.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<|endoftext|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "unk_token": {
17
+ "content": "<|endoftext|>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ }
23
+ }
stu.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .convolve import convolve, flash_convolve
5
+
6
+ try:
7
+ from flashfftconv import FlashFFTConv
8
+
9
+ flash_fft_available = True
10
+ except ImportError as e:
11
+ print(
12
+ f"Unable to import FlashFFTConv: {e}. Falling back to PyTorch implementation."
13
+ )
14
+ flash_fft_available = False
15
+
16
+
17
class STU(nn.Module):
    """Spectral Transform Unit: convolves inputs with precomputed spectral
    filters `phi` and contracts with learned projection tensors.

    Two modes: `use_approx` projects inputs/filters first then convolves;
    otherwise it convolves raw inputs with phi and contracts afterwards.
    """

    def __init__(self, config, phi, n) -> None:
        super(STU, self).__init__()
        self.config = config
        # config.torch_dtype may be a string (e.g. "bfloat16") or a dtype.
        if isinstance(config.torch_dtype, str):
            torch_dtype = getattr(torch, config.torch_dtype)
        else:
            torch_dtype = config.torch_dtype
        self.phi = phi.to(device=config.device, dtype=torch_dtype)
        self.n = n
        self.K = config.num_eigh
        self.d_in = config.n_embd
        self.d_out = config.n_embd
        self.use_hankel_L = config.use_hankel_L
        self.use_approx = config.use_approx
        # Optional fused FFT convolution (bf16 only) when available.
        self.flash_fft = (
            FlashFFTConv(self.n, dtype=torch.bfloat16)
            if config.use_flash_fft and flash_fft_available
            else None
        )
        if self.use_approx:
            self.M_inputs = nn.Parameter(
                torch.empty(self.d_in, self.d_out, dtype=torch_dtype)
            )
            self.M_filters = nn.Parameter(
                torch.empty(self.K, self.d_in, dtype=torch_dtype)
            )
        else:
            self.M_phi_plus = nn.Parameter(
                torch.empty(self.K, self.d_in, self.d_out, dtype=torch_dtype)
            )
            if not self.use_hankel_L:
                self.M_phi_minus = nn.Parameter(
                    torch.empty(self.K, self.d_in, self.d_out, dtype=torch_dtype)
                )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the spectral transform to x (cast to the parameter dtype)."""
        # Bug fix: the original read self.M_inputs.dtype unconditionally,
        # but M_inputs only exists when use_approx is True — the non-approx
        # path crashed with AttributeError.
        dtype = self.M_inputs.dtype if self.use_approx else self.M_phi_plus.dtype
        x = x.to(dtype=dtype)
        if self.use_approx:
            # Contract inputs and filters over the K and d_in dimensions, then convolve
            x_proj = x @ self.M_inputs
            phi_proj = self.phi @ self.M_filters
            x_proj = x_proj.to(dtype=dtype)
            phi_proj = phi_proj.to(dtype=dtype)
            if self.flash_fft:
                spectral_plus, spectral_minus = flash_convolve(
                    x_proj, phi_proj, self.flash_fft, self.use_approx
                )
            else:
                spectral_plus, spectral_minus = convolve(
                    x_proj, phi_proj, self.n, self.use_approx
                )
        else:
            # Convolve inputs and filters first,
            if self.flash_fft:
                U_plus, U_minus = flash_convolve(
                    x, self.phi, self.flash_fft, self.use_approx
                )
            else:
                U_plus, U_minus = convolve(x, self.phi, self.n, self.use_approx)
            # Then, contract over the K and d_in dimensions
            spectral_plus = torch.tensordot(
                U_plus, self.M_phi_plus, dims=([2, 3], [0, 1])
            )
            if not self.use_hankel_L:
                spectral_minus = torch.tensordot(
                    U_minus, self.M_phi_minus, dims=([2, 3], [0, 1])
                )

        # With use_hankel_L, only the "plus" branch exists.
        return spectral_plus if self.use_hankel_L else spectral_plus + spectral_minus
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "199999": {
5
+ "content": "<|endoftext|>",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "200018": {
13
+ "content": "<|endofprompt|>",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ }
20
+ },
21
+ "bos_token": "<|endoftext|>",
22
+ "clean_up_tokenization_spaces": false,
23
+ "eos_token": "<|endoftext|>",
24
+ "model_max_length": 128000,
25
+ "tokenizer_class": "GPT2Tokenizer",
26
+ "unk_token": "<|endoftext|>"
27
+ }
utils.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+
5
+ import logging
6
+ import os
7
+ import sys
8
+ from colorama import Fore, Style, init
9
+ from dotenv import load_dotenv
10
+
11
+ load_dotenv()
12
+ init(autoreset=True)
13
+
14
def nearest_power_of_two(x: int, round_up: bool = False) -> int:
    """Return the power of two at or below x (default), or at or above x
    when round_up is True. x must be a positive number."""
    exponent = math.ceil(math.log2(x)) if round_up else math.floor(math.log2(x))
    return 1 << exponent
+
19
def get_hankel(seq_len: int, use_hankel_L: bool = False, device: torch.device = None, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Build the (seq_len x seq_len) spectral-filtering Hankel matrix.

    With 1-based indices i, j and s = i + j:
      - default:        Z[i, j] = 2 / (s^3 - s)
      - use_hankel_L:   Z[i, j] = ((-1)^(s-2) + 1) * 8 / ((s+3)(s-1)(s+1))

    Args:
        seq_len: matrix size.
        use_hankel_L: select the Hankel-L variant.
        device, dtype: placement of the returned tensor.
    """
    entries = torch.arange(1, seq_len + 1, dtype=dtype, device=device)
    i_plus_j = entries[:, None] + entries[None, :]

    if use_hankel_L:
        sgn = (-1.0) ** (i_plus_j - 2.0) + 1.0
        denom = (i_plus_j + 3.0) * (i_plus_j - 1.0) * (i_plus_j + 1.0)
        Z = sgn * (8.0 / denom)
    else:
        # Bug fix: the original `elif not use_hankel_L: ... else: raise`
        # made the raise branch unreachable; a plain else is equivalent.
        Z = 2.0 / (i_plus_j**3 - i_plus_j)

    return Z
+
34
+
35
class ColorFormatter(logging.Formatter):
    """
    A custom log formatter that applies color based on the log level using the Colorama library.

    Attributes:
        LOG_COLORS (dict): A dictionary mapping log levels to their corresponding color codes.
    """

    # Colors for each log level
    LOG_COLORS = {
        logging.DEBUG: Fore.LIGHTMAGENTA_EX + Style.BRIGHT,
        logging.INFO: Fore.CYAN,
        logging.WARNING: Fore.YELLOW + Style.BRIGHT,
        logging.ERROR: Fore.RED + Style.BRIGHT,
        logging.CRITICAL: Fore.RED + Style.BRIGHT + Style.NORMAL,
    }

    # Colors for other parts of the log message
    TIME_COLOR = Fore.GREEN
    FILE_COLOR = Fore.BLUE
    LEVEL_COLOR = Style.BRIGHT

    def __init__(self, fmt=None):
        super().__init__(fmt or "%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s", "%Y-%m-%d %H:%M:%S")

    def format(self, record):
        """
        Formats a log record with the appropriate color based on the log level.

        Args:
            record (logging.LogRecord): The log record to format.

        Returns:
            str: The formatted log message with colors applied.
        """
        # Apply color based on the log level
        level_color = self.LOG_COLORS.get(record.levelno, Fore.WHITE)
        time_str = f"{self.TIME_COLOR}{self.formatTime(record)}{Style.RESET_ALL}"
        levelname_str = f"{level_color}{record.levelname}{Style.RESET_ALL}"
        file_info_str = f"{self.FILE_COLOR}{record.filename}:{record.lineno}{Style.RESET_ALL}"

        # Bug fix: use record.getMessage() instead of record.msg so that
        # %-style lazy arguments (logger.info("x=%s", x)) are interpolated;
        # record.msg is the raw, unformatted template.
        log_msg = f"{time_str} - {levelname_str} - {file_info_str} - {record.getMessage()}"
        return log_msg
79
+
80
def setup_logger():
    """
    Build a logger that writes color-formatted records to standard output.

    Uses ColorFormatter for level-based coloring. The level is DEBUG when
    the DEBUG environment variable is truthy ("true"/"1"/"t", any case),
    otherwise INFO.

    Returns:
        logging.Logger: configured logger for this module.
    """
    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setFormatter(ColorFormatter())

    log = logging.getLogger(__name__)
    debug_enabled = os.environ.get("DEBUG", "False").lower() in ("true", "1", "t")
    log.setLevel(logging.DEBUG if debug_enabled else logging.INFO)
    log.addHandler(stdout_handler)
    log.propagate = False  # Prevents multiple logging if re-initialized

    return log

logger = setup_logger()  # Initialize once to prevent multiple loggers
vocab.json ADDED
The diff for this file is too large to render. See raw diff