shubham212 commited on
Commit
2db8e8e
·
verified ·
1 Parent(s): daeebfa

Upload folder using huggingface_hub

Browse files
__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
"""Register the custom GPT architecture with the transformers Auto* factories.

Importing this package makes ``AutoConfig`` and ``AutoModelForCausalLM``
aware of the ``custom_gpt`` model type, so the model can be loaded with the
usual ``from_pretrained`` entry points.
"""

from transformers import AutoConfig, AutoModelForCausalLM

from .configuration_my_model import GPTConfig
from .modeling_my_model import GPT

# Map the "custom_gpt" model_type to our config, and the config to the model.
AutoConfig.register("custom_gpt", GPTConfig)
AutoModelForCausalLM.register(GPTConfig, GPT)
__pycache__/__init__.cpython-312.pyc ADDED
Binary file (253 Bytes). View file
 
__pycache__/modeling_my_model.cpython-312.pyc ADDED
Binary file (34.3 kB). View file
 
config.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "custom_gpt",
3
+ "architectures": ["GPT"],
4
+
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_my_model.GPTConfig",
7
+ "AutoModelForCausalLM": "modeling_my_model.GPT"
8
+ },
9
+
10
+ "block_size": 1024,
11
+ "vocab_size": 50304,
12
+ "n_layer": 24,
13
+ "n_head": 16,
14
+ "n_embd": 1024,
15
+ "dropout": 0.0,
16
+ "bias": false,
17
+
18
+ "hc_num_streams": 1,
19
+ "hc_num_fracs": 1,
20
+ "hc_disable": true,
21
+ "mhc": false,
22
+ "sinkhorn_iters": 10,
23
+ "sinkhorn_tau": 0.05,
24
+ "mhc_h_res_proj": "sinkhorn",
25
+ "ns_steps": 5,
26
+ "ns_eps": 1e-7,
27
+ "ns_coeffs": [3.0, -3.2, 1.2]
28
+ }
configuration_my_model.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
class GPTConfig(PretrainedConfig):
    """Configuration for the custom GPT with hyper-connection residual streams.

    Plain nanoGPT-style hyperparameters (``block_size`` .. ``bias``) plus the
    hyper-/frac-connection and manifold-hyper-connection (mhc) knobs consumed
    by ``modeling_my_model``.
    """

    model_type = "custom_gpt"

    def __init__(
        self,
        block_size=1024,      # maximum sequence length (size of the position table)
        vocab_size=50304,
        n_layer=12,
        n_head=12,
        n_embd=768,
        dropout=0.0,
        bias=True,
        hc_num_streams=1,     # number of hyper-connection residual streams
        hc_num_fracs=1,       # frac-connections: divisions of the feature dim
        hc_disable=False,     # True -> plain residuals instead of hyper-connections
        mhc=False,            # manifold hyper-connections (doubly-stochastic mixing)
        sinkhorn_iters=10,
        sinkhorn_tau=0.05,
        mhc_h_res_proj="sinkhorn",   # "sinkhorn" or "orthostochastic"
        ns_steps=5,           # Newton-Schulz iteration steps
        ns_eps=1e-7,
        ns_coeffs=(3.0, -3.2, 1.2),
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.block_size = block_size
        self.vocab_size = vocab_size
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_embd = n_embd
        self.dropout = dropout
        self.bias = bias

        self.hc_num_streams = hc_num_streams
        self.hc_num_fracs = hc_num_fracs
        self.hc_disable = hc_disable
        self.mhc = mhc
        self.sinkhorn_iters = sinkhorn_iters
        self.sinkhorn_tau = sinkhorn_tau
        self.mhc_h_res_proj = mhc_h_res_proj
        self.ns_steps = ns_steps
        self.ns_eps = ns_eps
        # Store as a list so a freshly-built config compares equal to one
        # round-tripped through JSON (json has no tuple type).
        self.ns_coeffs = list(ns_coeffs)

        # HF compatibility aliases — keep in sync with the duplicate
        # GPTConfig inside modeling_my_model.py.
        self.num_hidden_layers = n_layer
        self.num_attention_heads = n_head
        self.hidden_size = n_embd
        self.max_position_embeddings = block_size
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
modeling_my_model.py ADDED
@@ -0,0 +1,908 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Tuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from transformers import PreTrainedModel, PretrainedConfig
9
+ from transformers.modeling_outputs import CausalLMOutput
10
+ from typing import Callable
11
+ from transformers.generation.utils import GenerationMixin
12
+ from functools import partial
13
+ from random import randrange
14
+ import math
15
+
16
+ import torch
17
+ from torch import nn, cat
18
+ import torch.nn.functional as F
19
+ from torch.nn import Module, Sequential
20
+ from torch.utils._pytree import tree_flatten, tree_unflatten
21
+
22
+ from einops import rearrange, repeat, reduce, einsum
23
+ from einops.layers.torch import Rearrange, Reduce
24
+ from .configuration_my_model import GPTConfig
25
+
26
+ """
27
+ ein notation:
28
+ b - batch
29
+ d - feature dimension
30
+ s - residual streams
31
+ t - residual streams + num branch inputs
32
+ f - number of fractions (division of feature dimension space)
33
+ v - number of views for branch input
34
+ """
35
+
36
+ # helper functions
37
+
38
+
39
def exists(v):
    """Return True when *v* carries a value (i.e. is not None)."""
    return not (v is None)
41
+
42
+
43
def divisible_by(num, den):
    """True when *num* divides evenly by *den*."""
    remainder = num % den
    return remainder == 0
45
+
46
+
47
def default(v, d):
    """Return *v* if it is not None, otherwise the fallback *d*."""
    if v is None:
        return d
    return v
49
+
50
+
51
def identity(t):
    """Identity function: hand back the argument unchanged."""
    return t
53
+
54
+
55
def add(x, y):
    """Functional form of ``x + y`` (default depth-residual combiner)."""
    total = x + y
    return total
57
+
58
+
59
def sinkhorn_log(logits, num_iters=10, tau=0.05):
    """Log-domain Sinkhorn normalization of a square logits matrix.

    Alternately normalizes rows and columns in log space so the result is
    approximately doubly stochastic with uniform marginals 1/n, then scales
    by n so each row/column sums to ~1.
    """
    size = logits.shape[-1]
    scaled = logits / tau

    # target log-marginal: uniform distribution over `size` entries
    target = torch.full(
        (size,), -math.log(size), device=logits.device, dtype=logits.dtype
    )

    row_pot = torch.zeros(size, device=scaled.device, dtype=scaled.dtype)
    col_pot = torch.zeros(size, device=scaled.device, dtype=scaled.dtype)

    for _ in range(num_iters):
        # row update, then column update — same order as the classic scheme
        row_pot = target - torch.logsumexp(scaled + col_pot.unsqueeze(0), dim=1)
        col_pot = target - torch.logsumexp(scaled + row_pot.unsqueeze(1), dim=0)

    return torch.exp(scaled + row_pot.unsqueeze(1) + col_pot.unsqueeze(0)) * size
74
+
75
+
76
+ def zeropower_via_newtonschulz(X, steps=5, eps=1e-7, coeffs=(3.0, -3.2, 1.2)):
77
+ a, b, c = coeffs
78
+
79
+ X = X / (X.norm() + eps)
80
+
81
+ transpose = False
82
+ if X.shape[0] > X.shape[1]:
83
+ X = X.T
84
+ transpose = True
85
+
86
+ for _ in range(steps):
87
+ A = X @ X.T
88
+ B = b * A + c * A @ A
89
+ X = a * X + B @ X
90
+
91
+ if transpose:
92
+ X = X.T
93
+
94
+ return X
95
+
96
+
97
def orthostochastic_project(
    logits, ns_steps=5, ns_eps=1e-7, ns_coeffs=(3.0, -3.2, 1.2)
):
    """Map *logits* to an (approximately) orthostochastic matrix.

    First orthogonalizes via Newton-Schulz, then squares elementwise — the
    elementwise square of an orthogonal matrix is doubly stochastic.
    """
    ortho = zeropower_via_newtonschulz(
        logits, steps=ns_steps, eps=ns_eps, coeffs=ns_coeffs
    )
    return ortho.square()
102
+
103
+
104
+ # main functions
105
+
106
+
107
def get_expand_reduce_stream_functions(
    num_streams, add_stream_embed=False, dim=None, disable=False
):
    """Build the (expand, reduce) pair that maps batch <-> residual streams.

    Expansion repeats the batch `num_streams` times (optionally adding learned
    per-stream embeddings); reduction sums the streams back into one batch.
    With a single stream, or when disabled, both are identity.
    """
    if disable or num_streams == 1:
        return (nn.Identity(), nn.Identity())

    if not add_stream_embed:
        expand_fn = Reduce(
            pattern="b ... -> (b s) ...", reduction="repeat", s=num_streams
        )
    else:
        assert dim is not None, (
            "`dim` must be passed into get_init_and_expand_reduce_stream_functions for returning an expansion function with stream embeddings added"
        )
        expand_fn = StreamEmbed(num_streams, dim, expand_to_streams=True)

    reduce_fn = Reduce(pattern="(b s) ... -> b ...", reduction="sum", s=num_streams)

    return expand_fn, reduce_fn
127
+
128
+
129
def get_init_and_expand_reduce_stream_functions(
    num_streams, num_fracs=1, dim=None, add_stream_embed=False, disable=None
):
    """Return (constructor, expand_fn, reduce_fn) for residual-stream wiring.

    The constructor builds either a HyperConnections module or a plain
    Residual when disabled; `disable` defaults to True only in the trivial
    single-stream, single-fraction case.
    """
    if disable is None:
        disable = num_streams == 1 and num_fracs == 1

    klass = Residual if disable else HyperConnections

    init_fn = partial(klass, num_streams, num_fracs=num_fracs)
    if exists(dim):
        # pre-bind the feature dimension when it is already known
        init_fn = partial(init_fn, dim=dim)

    expand_fn, reduce_fn = get_expand_reduce_stream_functions(
        num_streams, add_stream_embed=add_stream_embed, dim=dim, disable=disable
    )

    return (init_fn, expand_fn, reduce_fn)
145
+
146
+
147
+ # norms
148
+
149
+
150
class RMSNorm(Module):
    """RMS normalization with a learned (1 + gamma) gain.

    Normalizes the last dimension to unit L2 norm, rescales by sqrt(dim), and
    applies a per-feature gain initialized to exactly 1 (gamma starts at 0).
    """

    def __init__(self, dim):
        super().__init__()
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        unit = F.normalize(x, dim=-1)
        gain = self.gamma + 1
        return unit * self.scale * gain
158
+
159
+
160
+ # main classes
161
+
162
+ # residual base class
163
+
164
+
165
class Residual(Module):
    """Plain residual wrapper sharing the HyperConnections call protocol.

    Drop-in stand-in when hyper-connections are disabled: the "width" step is
    a no-op and the "depth" step is branch_output + transform(residuals).
    """

    def __init__(
        self,
        *args,
        branch: Module | None = None,
        residual_transform: Module | None = None,
        **kwargs,
    ):
        # extra positional/keyword args (num_streams, num_fracs, ...) are
        # accepted and ignored so the signature matches HyperConnections
        super().__init__()
        self.branch = branch
        if residual_transform is None:
            residual_transform = nn.Identity()
        self.residual_transform = residual_transform

    def width_connection(self, residuals):
        """No stream mixing: branch input and residuals are the input itself."""
        return residuals, residuals, dict()

    def depth_connection(
        self,
        branch_output,
        residuals,
    ):
        """Standard skip connection: branch output plus (transformed) residuals."""
        return branch_output + self.residual_transform(residuals)

    def decorate_branch(self, branch: Callable):
        """Wrap *branch* so calling the wrapper performs the full residual step."""
        assert self.branch is None, "branch was already wrapped on init"

        def forward_and_add_residual(residual, *args, **kwargs):
            branch_input, add_residual = self.forward(residual)
            return add_residual(branch(branch_input, *args, **kwargs))

        return forward_and_add_residual

    def forward(self, residuals, *branch_args, **branch_kwargs):
        branch_input, residuals, residual_kwargs = self.width_connection(residuals)

        def add_residual_fn(branch_out):
            # branches may return (tensor, *extras); only the first leaf gets
            # the residual added, extras pass through untouched
            (main_out, *rest), tree_spec = tree_flatten(branch_out)
            main_out = self.depth_connection(main_out, residuals, **residual_kwargs)
            return tree_unflatten((main_out, *rest), tree_spec)

        if self.branch is None:
            # caller supplies the branch output later via the returned closure
            return branch_input, add_residual_fn

        return add_residual_fn(self.branch(branch_input, *branch_args, **branch_kwargs))
217
+
218
+
219
+ # hyper connection residual streams
220
+
221
+
222
class HyperConnections(Module):
    """Hyper-connections residual module (https://arxiv.org/abs/2409.19606).

    Maintains `num_residual_streams` parallel residual streams packed into the
    batch dimension as (b s). Each call mixes the streams into a single branch
    input (the "width" connection, weights alpha) and adds the branch output
    back into the streams (the "depth" connection, weights beta). Optionally
    generalizes with frac-connections (num_fracs, https://arxiv.org/abs/2503.14125)
    and a doubly-stochastic "mhc" mixing mode (Sinkhorn / orthostochastic).
    """

    def __init__(
        self,
        num_residual_streams,
        *,
        dim,
        branch: Module | None = None,
        layer_index=None,
        tanh=True,
        channel_first=False,
        dropout=0.0,
        residual_transform: Module | None = None,  # to support resnet blocks where dimension in not equal to dimension out - usually a residual conv
        add_branch_out_to_residual=True,  # will disable depth connections (weighted residual sum with beta) if set False
        num_input_views=1,  # allow for the branch module to receive multiple input views, dimension placed on the very left (before batch)
        depth_residual_fn=add,
        num_fracs=1,  # https://arxiv.org/abs/2503.14125
        mhc=False,
        sinkhorn_iters=10,
        sinkhorn_tau=0.05,
        mhc_h_res_proj="sinkhorn",
        ns_steps=5,
        ns_eps=1e-7,
        ns_coeffs=(3.0, -3.2, 1.2),
    ):
        """
        Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606
        """
        super().__init__()

        self.branch = branch

        self.act = nn.Tanh() if tanh else nn.Identity()

        # frac-connections paper - num_fracs > 1 will be the `m` in their paper https://arxiv.org/abs/2503.14125

        assert num_fracs >= 1

        self.num_fracs = num_fracs
        self.has_fracs = num_fracs > 1

        self.split_fracs = Rearrange("b ... (f d) -> b ... f d", f=num_fracs)
        self.merge_fracs = Rearrange("b ... f d -> b ... (f d)")

        assert divisible_by(dim, num_fracs), (
            f"feature dimension ({dim}) must be divisible by the `num_fracs` ({num_fracs})"
        )

        dim //= num_fracs  # effective dim handled in dimension is feature dimension divided by num fractions

        # they used layernorm in paper, but rmsnorm is fine given what we know now

        self.norm = RMSNorm(dim)

        assert num_residual_streams > 0, "`num_residual_streams` must be greater than 0"

        self.num_residual_streams = num_residual_streams
        init_residual_index = (
            default(layer_index, randrange(num_residual_streams)) % num_residual_streams
        )  # just choose one random residual stream if layer index not given

        # handle the parameter dimensions, which may require (num_residuals x num_fractions) - generalizing hyper + frac connections

        num_residual_streams_fracs = num_residual_streams * num_fracs
        num_input_views_fracs = num_input_views * num_fracs

        # width num residual streams

        assert num_input_views >= 1
        self.num_input_views = num_input_views

        # width connection: static alpha starts as [one-hot branch pick | identity stream mix]

        init_alpha0 = torch.zeros((num_residual_streams_fracs, num_input_views_fracs))
        init_alpha0[init_residual_index, :] = 1.0

        self.static_alpha = nn.Parameter(
            cat((init_alpha0, torch.eye(num_residual_streams_fracs)), dim=1)
        )

        self.dynamic_alpha_fn = nn.Parameter(
            torch.zeros(dim, num_residual_streams_fracs + num_input_views_fracs)
        )
        self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)

        # depth connection related (beta)

        self.add_branch_out_to_residual = add_branch_out_to_residual

        if add_branch_out_to_residual:
            self.static_beta = nn.Parameter(torch.ones(num_residual_streams_fracs))

            dynamic_beta_shape = (
                (dim,) if num_fracs == 1 else (dim, num_fracs)
            )  # preserve backwards compat
            self.dynamic_beta_fn = nn.Parameter(torch.zeros(dynamic_beta_shape))

            self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)

        # dropouts

        self.dropout = nn.Dropout(dropout)

        # channel first option

        self.channel_first = channel_first

        # maybe residual transform

        self.residual_transform = default(residual_transform, nn.Identity())

        # maybe custom depth connection residual function
        # this is to prepare for gating the addition of the branch outputs to the residual streams
        # needed for memory lanes a la RMT / LMM

        self.depth_residual_fn = depth_residual_fn

        self.mhc = mhc
        self.sinkhorn_iters = sinkhorn_iters
        self.sinkhorn_tau = sinkhorn_tau
        self.mhc_h_res_proj = mhc_h_res_proj
        self.ns_steps = ns_steps
        self.ns_eps = ns_eps
        self.ns_coeffs = ns_coeffs

        if mhc:
            assert num_fracs == 1, "mhc currently requires num_fracs = 1"
            assert num_input_views == 1, "mhc currently requires num_input_views = 1"
            assert mhc_h_res_proj in (
                "sinkhorn",
                "orthostochastic",
            ), "mhc_h_res_proj must be 'sinkhorn' or 'orthostochastic'"

            # H_res logits start near the identity permutation (diagonal 0, off-diagonal -8)
            H_res_init = torch.full((num_residual_streams, num_residual_streams), -8.0)
            H_res_init.fill_diagonal_(0.0)
            self.H_res_logits = nn.Parameter(H_res_init)

            # H_pre starts as (soft) one-hot on the chosen init stream
            H_pre_init = torch.full((num_residual_streams,), -8.0)
            H_pre_init[init_residual_index] = 0.0
            self.H_pre_logits = nn.Parameter(H_pre_init)

            if add_branch_out_to_residual:
                self.H_post_logits = nn.Parameter(torch.zeros(num_residual_streams))

    def width_connection(self, residuals):
        """Mix the packed (b s) streams into a single branch input.

        Returns (branch_input, residuals_for_depth, kwargs_for_depth_connection).
        """
        streams = self.num_residual_streams

        maybe_transformed_residuals = self.residual_transform(residuals)

        # width connection

        # handle channel first

        if self.channel_first:
            residuals = rearrange(residuals, "b d ... -> b ... d")

        # split out fractions

        residuals = self.split_fracs(residuals)

        # split out streams

        residuals = rearrange(residuals, "(b s) ... d -> b ... s d", s=streams)

        if self.mhc:
            # mhc path: residual streams are mixed with a (near-)doubly-stochastic
            # matrix H_res, and the branch input is a convex combination via H_pre
            residuals_mixed_source = maybe_transformed_residuals

            if self.channel_first:
                residuals_mixed_source = rearrange(
                    residuals_mixed_source, "b d ... -> b ... d"
                )

            residuals_mixed_source = self.split_fracs(residuals_mixed_source)
            residuals_mixed_source = rearrange(
                residuals_mixed_source, "(b s) ... d -> b ... s d", s=streams
            )

            if self.mhc_h_res_proj == "orthostochastic":
                H_res = orthostochastic_project(
                    self.H_res_logits,
                    ns_steps=self.ns_steps,
                    ns_eps=self.ns_eps,
                    ns_coeffs=self.ns_coeffs,
                )
            else:
                H_res = sinkhorn_log(
                    self.H_res_logits, self.sinkhorn_iters, self.sinkhorn_tau
                )
            H_pre = F.softmax(self.H_pre_logits, dim=-1)

            H_post = None
            if self.add_branch_out_to_residual:
                H_post = F.softmax(self.H_post_logits, dim=-1)

            residuals_mixed = einsum(
                H_res, residuals_mixed_source, "s t, ... s d -> ... t d"
            )
            branch_input = einsum(H_pre, residuals, "s, ... s d -> ... d")

            # optional lightweight diagnostics, enabled by setting
            # `self.collect_stats = True` from outside
            if getattr(self, "collect_stats", False):
                with torch.no_grad():
                    stats = dict(
                        h_res_min=H_res.min(),
                        h_res_row_sum=H_res.sum(dim=-1).mean(),
                        h_res_col_sum=H_res.sum(dim=-2).mean(),
                        h_pre_min=H_pre.min(),
                    )
                    if H_post is not None:
                        stats["h_post_min"] = H_post.min()
                    self.last_stats = {k: v.detach() for k, v in stats.items()}

            if self.channel_first:
                branch_input = rearrange(branch_input, "b ... d -> b d ...")

            branch_input = self.merge_fracs(branch_input)

            return (
                branch_input,
                maybe_transformed_residuals,
                dict(beta=H_post, residuals_mixed=residuals_mixed),
            )

        # norm

        normed = self.norm(residuals)

        # alpha for weighted sum of residuals going into branch

        wc_weight = self.act(normed @ self.dynamic_alpha_fn)
        dynamic_alpha = wc_weight * self.dynamic_alpha_scale

        static_alpha = rearrange(self.static_alpha, "(f s) d -> f s d", s=streams)

        alpha = dynamic_alpha + static_alpha

        alpha = self.split_fracs(
            alpha
        )  # (batch, seq, fracs1, streams, fracs2, input + residual streams)

        # beta for weights from branch output back to residual streams

        beta = None

        if self.add_branch_out_to_residual:
            dc_weight = self.act(normed @ self.dynamic_beta_fn)

            if not self.has_fracs:
                dc_weight = rearrange(dc_weight, "... -> ... 1")

            dynamic_beta = dc_weight * self.dynamic_beta_scale

            static_beta = rearrange(self.static_beta, "... (s f) -> ... s f", s=streams)

            beta = dynamic_beta + static_beta

        if getattr(self, "collect_stats", False):
            with torch.no_grad():
                num_input_views_fracs = self.num_input_views * self.num_fracs
                alpha_branch = alpha[..., :num_input_views_fracs]
                alpha_residual = alpha[..., num_input_views_fracs:]
                alpha_branch_abs_mean = alpha_branch.abs().mean()
                alpha_residual_abs_mean = alpha_residual.abs().mean()
                stats = dict(
                    alpha_branch_mean=alpha_branch.mean(),
                    alpha_branch_abs_mean=alpha_branch_abs_mean,
                    alpha_residual_mean=alpha_residual.mean(),
                    alpha_residual_abs_mean=alpha_residual_abs_mean,
                    alpha_branch_residual_ratio=alpha_branch_abs_mean
                    / (alpha_residual_abs_mean + 1e-8),
                )
                if beta is not None:
                    stats.update(
                        beta_mean=beta.mean(),
                        beta_abs_mean=beta.abs().mean(),
                        beta_min=beta.min(),
                        beta_max=beta.max(),
                    )
                self.last_stats = {k: v.detach() for k, v in stats.items()}

        mix_h = einsum(alpha, residuals, "... f1 s f2 t, ... f1 s d -> ... f2 t d")

        if self.num_input_views == 1:
            branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
        else:
            branch_input, residuals = (
                mix_h[..., : self.num_input_views, :],
                mix_h[..., self.num_input_views :, :],
            )
            branch_input = rearrange(branch_input, "b ... v d -> v b ... d")

        if self.channel_first:
            branch_input = rearrange(branch_input, "b ... d -> b d ...")

        # maybe merge fractions back

        branch_input = self.merge_fracs(branch_input)

        # NOTE(review): the mixed residual-stream slice of `mix_h` (rebound to
        # `residuals` above) is never used — the untransformed
        # `maybe_transformed_residuals` is returned instead, so the alpha
        # stream-mixing only affects the branch input, not the carried
        # residuals. The upstream hyper-connections reference returns the
        # mixed residuals here; confirm this divergence is intentional before
        # changing it (a checkpoint was trained with this behavior).
        return branch_input, maybe_transformed_residuals, dict(beta=beta)

    def depth_connection(self, branch_output, residuals, *, beta, residuals_mixed=None):
        """Add the branch output back into the residual streams (weighted by beta)."""
        assert self.add_branch_out_to_residual

        # maybe split fractions

        branch_output = self.split_fracs(branch_output)

        # 'depth' connection

        if self.channel_first:
            branch_output = rearrange(branch_output, "b d ... -> b ... d")

        if self.mhc:
            assert residuals_mixed is not None
            assert beta is not None

            # broadcast the single branch output across streams, scaled by H_post
            branch_to_streams = einsum(branch_output, beta, "b ... d, s -> b ... s d")
            output = residuals_mixed + branch_to_streams
            output = rearrange(output, "b ... s d -> (b s) ... d")

            output = self.merge_fracs(output)

            if self.channel_first:
                output = rearrange(output, "b ... d -> b d ...")

            return self.dropout(output)

        output = einsum(
            branch_output, beta, "b ... f1 d, b ... f1 s f2 -> b ... f2 s d"
        )

        output = rearrange(output, "b ... s d -> (b s) ... d")

        # merge merge back fractions

        output = self.merge_fracs(output)

        # channel first

        if self.channel_first:
            output = rearrange(output, "b ... d -> b d ...")

        residuals = self.depth_residual_fn(output, residuals)

        return self.dropout(residuals)

    def decorate_branch(self, branch: Callable):
        """Wrap *branch* so calling the wrapper performs width + branch + depth."""
        assert not exists(self.branch), "branch was already wrapped on init"

        def forward_and_add_residual(residual, *args, **kwargs):
            branch_input, add_residual = self.forward(residual)

            branch_output = branch(branch_input, *args, **kwargs)

            residual = add_residual(branch_output)

            return residual

        return forward_and_add_residual

    def forward(self, residuals, *branch_args, **branch_kwargs):
        branch_input, residuals, residual_kwargs = self.width_connection(residuals)

        def add_residual_fn(branch_out):
            if not self.add_branch_out_to_residual:
                return branch_out

            # branches may return (tensor, *extras); only the first leaf gets
            # the residual treatment, extras pass through untouched
            (branch_out, *rest), tree_spec = tree_flatten(branch_out)

            branch_out = self.depth_connection(branch_out, residuals, **residual_kwargs)

            return tree_unflatten((branch_out, *rest), tree_spec)

        if not exists(self.branch):
            # no branch bound: caller completes the step via the closure
            return branch_input, add_residual_fn

        branch_output = self.branch(branch_input, *branch_args, **branch_kwargs)

        return add_residual_fn(branch_output)
600
+
601
+
602
# Expose the module-level factory helpers as staticmethods on the class, so
# callers can reach them via `HyperConnections.<name>(...)` without importing
# the free functions directly.
HyperConnections.get_expand_reduce_stream_functions = staticmethod(
    get_expand_reduce_stream_functions
)
HyperConnections.get_init_and_expand_reduce_stream_functions = staticmethod(
    get_init_and_expand_reduce_stream_functions
)
608
+
609
+ # stream embed
610
+
611
+
612
class StreamEmbed(Module):
    """Adds a learned per-stream embedding to residual streams packed as (b s).

    Optionally first repeats a plain batch into `num_streams` copies
    (expand_to_streams=True), so it can serve as the expansion function.
    """

    def __init__(self, num_streams, dim, channel_first=False, expand_to_streams=False):
        super().__init__()
        self.channel_first = channel_first
        self.num_streams = num_streams
        self.expand_to_streams = expand_to_streams
        # embeddings start at zero so the module is initially a no-op
        self.stream_embed = nn.Parameter(torch.zeros(num_streams, dim))

    def forward(self, residuals):
        s = self.num_streams

        if self.expand_to_streams:
            residuals = repeat(residuals, "b ... -> (b s) ...", s=s)

        # pick patterns once depending on channel layout
        if self.channel_first:
            unpack, pack = "(b s) d ... -> b ... s d", "b ... s d -> (b s) d ..."
        else:
            unpack, pack = "(b s) ... d -> b ... s d", "b ... s d -> (b s) ... d"

        residuals = rearrange(residuals, unpack, s=s)
        residuals = residuals + self.stream_embed
        return rearrange(residuals, pack, s=s)
646
+
647
+
648
+ # attention pool - taken from Enformer https://www.nature.com/articles/s41592-021-01252-x , in turn taken from somewhere else
649
+
650
+
651
class AttentionPoolReduceStream(Module):
    """Attention-pool residual streams back into one batch.

    Per-feature logits are produced by a linear layer initialized to the
    identity, softmaxed over the stream axis, and used as weights for a
    weighted sum across streams. (Pooling taken from Enformer,
    https://www.nature.com/articles/s41592-021-01252-x)
    """

    def __init__(self, num_streams, dim, channel_first=False):
        super().__init__()
        self.num_streams = num_streams
        self.channel_first = channel_first

        # identity init -> initial pooling weights follow the residual values
        self.to_attn_logits = nn.Linear(dim, dim, bias=False)
        self.to_attn_logits.weight.data.copy_(torch.eye(dim))

    def forward(self, residuals):
        s = self.num_streams

        unpack = (
            "(b s) d ... -> b ... s d" if self.channel_first else "(b s) ... d -> b ... s d"
        )
        residuals = rearrange(residuals, unpack, s=s)

        # softmax over the stream axis, then weighted-sum the streams away
        attn = self.to_attn_logits(residuals).softmax(dim=-2)
        pooled = reduce(residuals * attn, "b ... s d -> b ... d", "sum")

        if self.channel_first:
            pooled = rearrange(pooled, "b ... d -> b d ...")

        return pooled
679
+
680
+
681
class CausalSelfAttention(nn.Module):
    """Multi-head masked self-attention (nanoGPT style, explicit softmax path).

    NOTE(review): `config.bias` and `config.dropout` are not consulted here —
    the projections always carry biases and no attention dropout is applied.
    The shipped checkpoint was presumably trained this way; confirm before
    changing.
    """

    def __init__(self, config):
        super().__init__()

        # fused qkv projection and output projection
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.c_proj.NANOGPT_SCALE_INIT = 1  # flag read by external init code

        self.n_head = config.n_head
        self.n_embd = config.n_embd

        # lower-triangular causal mask, broadcast over (batch, head)
        causal = torch.tril(torch.ones(config.block_size, config.block_size))
        self.register_buffer(
            "bias",
            causal.view(1, 1, config.block_size, config.block_size),
            persistent=False,
        )

    def forward(self, x):
        batch, seq, width = x.size()
        head_dim = width // self.n_head

        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)

        def to_heads(t):
            # (B, T, C) -> (B, heads, T, head_dim)
            return t.view(batch, seq, self.n_head, head_dim).transpose(1, 2)

        q, k, v = to_heads(q), to_heads(k), to_heads(v)

        # scaled dot-product scores with future positions masked out
        scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        scores = scores.masked_fill(self.bias[:, :, :seq, :seq] == 0, float("-inf"))
        weights = F.softmax(scores, dim=-1)

        out = (weights @ v).transpose(1, 2).contiguous().view(batch, seq, width)
        return self.c_proj(out)
714
+ class MLP(nn.Module):
715
+ def __init__(self, config):
716
+ super().__init__()
717
+ self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
718
+ self.gelu = nn.GELU(approximate="tanh")
719
+ self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
720
+ self.c_proj.NANOGPT_SCALE_INIT = 1
721
+
722
+ def forward(self, x):
723
+ return self.c_proj(self.gelu(self.c_fc(x)))
724
class AttnBranch(nn.Module):
    """Pre-norm attention branch: applies norm first, then attention."""

    def __init__(self, norm, attn):
        super().__init__()
        self.norm = norm
        self.attn = attn

    def forward(self, x):
        normed = self.norm(x)
        return self.attn(normed)
732
+
733
+
734
class Block(nn.Module):
    """One transformer layer: causal self-attention then an MLP, each wrapped
    in a residual / hyper-connection module produced by `init_hc`."""

    def __init__(self, config, layer_idx, init_hc):
        super().__init__()

        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.mlp = MLP(config)

        # pre-norm attention branch packaged for the residual wrapper
        self.attn_branch = AttnBranch(self.ln_1, self.attn)

        shared_hc_kwargs = dict(
            mhc=config.mhc,
            sinkhorn_iters=config.sinkhorn_iters,
            sinkhorn_tau=config.sinkhorn_tau,
            mhc_h_res_proj=config.mhc_h_res_proj,
            ns_steps=config.ns_steps,
            ns_eps=config.ns_eps,
            ns_coeffs=config.ns_coeffs,
        )

        # two residual steps per layer -> indices 2k and 2k+1 pick the
        # initial residual stream deterministically per sub-layer
        self.hc_attn = init_hc(
            dim=config.n_embd,
            branch=self.attn_branch,
            layer_index=2 * layer_idx,
            **shared_hc_kwargs,
        )
        self.hc_mlp = init_hc(
            dim=config.n_embd,
            branch=nn.Sequential(self.ln_2, self.mlp),
            layer_index=2 * layer_idx + 1,
            **shared_hc_kwargs,
        )

    def forward(self, x):
        return self.hc_mlp(self.hc_attn(x))
774
class GPTConfig(PretrainedConfig):
    """Configuration for the custom GPT with hyper-connection residual streams.

    NOTE(review): this duplicates ``configuration_my_model.GPTConfig`` (and,
    being defined after the import at the top of this module, shadows it) so
    the modeling file stays self-contained for ``trust_remote_code`` loading.
    Keep the two definitions in sync.
    """

    model_type = "custom_gpt"

    def __init__(
        self,
        block_size=1024,      # maximum sequence length (size of the position table)
        vocab_size=50304,
        n_layer=12,
        n_head=12,
        n_embd=768,
        dropout=0.0,
        bias=True,
        hc_num_streams=1,     # number of hyper-connection residual streams
        hc_num_fracs=1,       # frac-connections: divisions of the feature dim
        hc_disable=False,     # True -> plain residuals instead of hyper-connections
        mhc=False,            # manifold hyper-connections (doubly-stochastic mixing)
        sinkhorn_iters=10,
        sinkhorn_tau=0.05,
        mhc_h_res_proj="sinkhorn",   # "sinkhorn" or "orthostochastic"
        ns_steps=5,           # Newton-Schulz iteration steps
        ns_eps=1e-7,
        ns_coeffs=(3.0, -3.2, 1.2),
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.block_size = block_size
        self.vocab_size = vocab_size
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_embd = n_embd
        self.dropout = dropout
        self.bias = bias

        self.hc_num_streams = hc_num_streams
        self.hc_num_fracs = hc_num_fracs
        self.hc_disable = hc_disable
        self.mhc = mhc
        self.sinkhorn_iters = sinkhorn_iters
        self.sinkhorn_tau = sinkhorn_tau
        self.mhc_h_res_proj = mhc_h_res_proj
        self.ns_steps = ns_steps
        self.ns_eps = ns_eps
        # Store as a list so a freshly-built config compares equal to one
        # round-tripped through JSON (json has no tuple type).
        self.ns_coeffs = list(ns_coeffs)

        # HF compatibility aliases
        self.num_hidden_layers = n_layer
        self.num_attention_heads = n_head
        self.hidden_size = n_embd
        self.max_position_embeddings = block_size
824
+
825
class GPT(PreTrainedModel, GenerationMixin):
    """GPT-style causal language model with optional hyper-connection streams.

    Token + learned positional embeddings feed a stack of `Block`s; the token
    embedding is weight-tied to the LM head. No KV cache is implemented, so
    generation re-runs the full sequence each step.
    """

    config_class = GPTConfig

    def __init__(self, config):
        super().__init__(config)

        # expand/reduce map the batch into and out of the residual streams;
        # identities when hc is disabled or there is a single stream
        init_hc, expand_stream, reduce_stream = (
            get_init_and_expand_reduce_stream_functions(
                config.hc_num_streams,
                num_fracs=config.hc_num_fracs,
                disable=config.hc_disable,
            )
        )

        self.expand_stream = expand_stream
        self.reduce_stream = reduce_stream

        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.vocab_size, config.n_embd),
                wpe=nn.Embedding(config.block_size, config.n_embd),
                h=nn.ModuleList(
                    [Block(config, i, init_hc) for i in range(config.n_layer)]
                ),
                ln_f=nn.LayerNorm(config.n_embd),
            )
        )

        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # weight tying: token embedding shares the LM head matrix
        self.transformer.wte.weight = self.lm_head.weight

        self.post_init()

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        **kwargs,
    ):
        # We do NOT use KV cache yet, so always feed full sequence
        return {
            "input_ids": input_ids,
            "past_key_values": None,
        }

    def forward(
        self,
        input_ids=None,
        attention_mask=None,  # accepted for HF compatibility; NOTE(review): currently ignored — padding is not masked
        labels=None,
        past_key_values=None,  # unused: no KV cache
        use_cache=None,  # unused: no KV cache
        **kwargs,
    ):
        """Run the model; returns CausalLMOutput with logits (and loss if labels given).

        NOTE(review): labels are compared to logits at the SAME positions (no
        internal shift, unlike most HF causal LMs) — callers must pass
        pre-shifted labels. F.cross_entropy's default ignore_index of -100
        still applies.
        """

        b, t = input_ids.size()
        # guard: positions beyond block_size have no learned embedding
        assert t <= self.config.block_size

        pos = torch.arange(0, t, device=input_ids.device).unsqueeze(0)

        x = self.transformer.wte(input_ids) + self.transformer.wpe(pos)
        x = self.expand_stream(x)

        for block in self.transformer.h:
            x = block(x)

        x = self.transformer.ln_f(x)
        x = self.reduce_stream(x)

        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                labels.view(-1),
            )

        return CausalLMOutput(
            loss=loss,
            logits=logits,
        )
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:df198621f8d5ee6f68a316ddd5730b5816c022ad1ed6daaa309ec09bc0b79e7c
3
+ size 1520347603
special_tokens_map.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "<|endoftext|>",
3
+ "eos_token": "<|endoftext|>",
4
+ "unk_token": "<|endoftext|>"
5
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "50256": {
5
+ "content": "<|endoftext|>",
6
+ "lstrip": false,
7
+ "normalized": true,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ }
12
+ },
13
+ "bos_token": "<|endoftext|>",
14
+ "clean_up_tokenization_spaces": false,
15
+ "eos_token": "<|endoftext|>",
16
+ "extra_special_tokens": {},
17
+ "model_max_length": 1024,
18
+ "tokenizer_class": "GPT2Tokenizer",
19
+ "unk_token": "<|endoftext|>"
20
+ }
vocab.json ADDED
The diff for this file is too large to render. See raw diff