TheTrueJard committed

Commit 6f09125 · verified · parent: 0642acb

Upload folder using huggingface_hub

Files changed (5):
  1. __init__.py +3 -0
  2. config.json +9 -0
  3. config.py +39 -0
  4. modeling.py +381 -0
  5. psi.py +788 -0
__init__.py ADDED
@@ -0,0 +1,3 @@

from .psi import PSI
from .config import PSIConfig
config.json ADDED
@@ -0,0 +1,9 @@
{
  "_name_or_path": "StanfordNeuroAILab/PSI",
  "architectures": ["PSI"],
  "auto_map": {
    "AutoConfig": "config.PSIConfig",
    "AutoModel": "psi.PSI"
  },
  "model_type": "PSI"
}
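The `auto_map` entries let `transformers` resolve these custom classes directly from the repo. A minimal loading sketch (the repo id comes from `_name_or_path` above; `trust_remote_code=True` is required because the classes live in the repo rather than in `transformers`):

```
from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained("StanfordNeuroAILab/PSI", trust_remote_code=True)
model = AutoModel.from_pretrained("StanfordNeuroAILab/PSI", trust_remote_code=True)
```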
config.py ADDED
@@ -0,0 +1,39 @@

from typing import Optional
from transformers import PretrainedConfig


class PSIConfig(PretrainedConfig):
    model_type: str = "PSI"

    def __init__(self,
                 vocab_size: int = 96256,
                 channel_size: int = 12,
                 n_layer: int = 12,
                 n_head: int = 12,
                 n_embd: int = 768,
                 dropout: float = 0.0,
                 bias: bool = False,
                 attention_mask: str = "causal",
                 tie_weights: bool = False,
                 partition_embedding: bool = False,
                 n_lm_vocab: Optional[int] = None,
                 **kwargs
                 ):
        self.vocab_size = vocab_size
        self.channel_size = channel_size
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_embd = n_embd
        self.dropout = dropout
        self.bias = bias
        self.attention_mask = attention_mask
        self.tie_weights = tie_weights
        self.partition_embedding = partition_embedding
        self.n_lm_vocab = n_lm_vocab

        # Aside from the HuggingFace default config attributes,
        # all extra kwargs are assigned using setattr. For the HuggingFace attrs, see:
        # https://github.com/huggingface/transformers/blob/v4.53.3/src/transformers/configuration_utils.py#L45

        # Since token ranges are checkpoint-specific, we don't include them
        # in this config and let them be assigned from kwargs.
        super().__init__(**kwargs)
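Because `super().__init__(**kwargs)` stores unrecognized kwargs on the config via `setattr`, checkpoint-specific fields such as token ranges can ride along without being declared. A small sketch (the `token_ranges` field here is hypothetical):

```
cfg = PSIConfig(n_layer=6, token_ranges={"rgb": (0, 65536)})   # hypothetical extra field
print(cfg.n_layer)       # 6
print(cfg.token_ranges)  # {'rgb': (0, 65536)} (stored by PretrainedConfig via setattr)
```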
modeling.py ADDED
@@ -0,0 +1,381 @@

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

try:
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.spmd.xla_sharding as xs
except ImportError:
    xm = None
    xs = None


class Rotary3D(nn.Module):
    def __init__(self, dim, base=100):
        super().__init__()
        assert dim % 16 == 0, "Embedding dim must be divisible by 16"

        # Per-axis embedding dims must sum exactly to dim (= n_embd // n_head)
        self.x_dim = (6 * dim) // 16
        self.y_dim = (6 * dim) // 16
        self.t_dim = dim - self.x_dim - self.y_dim

        # Precompute inverse frequencies
        self.register_buffer('inv_freq_x', 1.0 / (base ** (torch.arange(0, self.x_dim, 2).float() / self.x_dim)))
        self.register_buffer('inv_freq_y', 1.0 / (base ** (torch.arange(0, self.y_dim, 2).float() / self.y_dim)))
        self.register_buffer('inv_freq_t', 1.0 / (base ** (torch.arange(0, self.t_dim, 2).float() / self.t_dim)))

    def forward(self, x, pos):
        """
        x: [batch, nh, seq_len, head_dim]
        pos: [batch, seq_len, 3] integer positions along (x, y, t)
        """
        B, nh, T, hs = x.shape
        assert pos.shape[-1] == 3, "Position tensor must have shape [batch, seq_len, 3]"
        assert hs % 2 == 0, "head_dim (hs) must be divisible by 2 for rotary embedding."

        # Cast positions to the frequency dtype
        dtype = self.inv_freq_x.dtype
        pos_x = pos[..., 0].to(dtype)  # [B, T]
        pos_y = pos[..., 1].to(dtype)  # [B, T]
        pos_t = pos[..., 2].to(dtype)  # [B, T]

        # Generate per-axis angle tables for x, y, t
        freqs_x = torch.einsum('bt,f->btf', pos_x, self.inv_freq_x)
        freqs_y = torch.einsum('bt,f->btf', pos_y, self.inv_freq_y)
        freqs_t = torch.einsum('bt,f->btf', pos_t, self.inv_freq_t)

        # Concatenate along the frequency dim; total size is hs // 2
        freq_combined = torch.cat([freqs_x, freqs_y, freqs_t], dim=-1)

        # Cos and sin embeddings, broadcast over the head dim
        cos_emb = freq_combined.cos().unsqueeze(1)  # [B, 1, T, hs/2]
        sin_emb = freq_combined.sin().unsqueeze(1)  # [B, 1, T, hs/2]

        # Split the embedding dimension in half and rotate
        x1, x2 = x[..., :hs // 2], x[..., hs // 2:]
        x_rotated = torch.cat([
            x1 * cos_emb - x2 * sin_emb,
            x1 * sin_emb + x2 * cos_emb
        ], dim=-1)

        return x_rotated

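A quick shape sanity check for `Rotary3D` (a sketch; the sizes are arbitrary, but `dim` must be divisible by 16):

```
import torch

rope = Rotary3D(dim=64)                  # x_dim = 24, y_dim = 24, t_dim = 16
x = torch.randn(2, 12, 10, 64)           # [B, nh, T, hs]
pos = torch.randint(0, 16, (2, 10, 3))   # integer (x, y, t) positions
assert rope(x, pos).shape == x.shape     # the rotation preserves the shape
```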
class PSIAttentionLayer(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0

        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        # positional embedding
        self.rope = Rotary3D(config.n_embd // config.n_head)

        # check if we are using causal attention
        self.is_causal = config.attention_mask == "causal"

        # check if GPU Flash Attention is available
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')

        # check if we are running on TPU
        try:
            # Use a local import to avoid conflict if the global xm is None, and to check TPU specifically for this flag
            import torch_xla.core.xla_model as xm_local  # noqa: F401
            self.tpu = True
        except ImportError:
            self.tpu = False

        # Probe for an XLA device. NOTE: unlike MLP below, this layer does not
        # currently apply any sharding with the result.
        xla_device_available = False
        if xm is not None:
            try:
                device_kind = xm.xla_device_kind()
                if device_kind is not None:
                    xla_device_available = True
            except RuntimeError:
                pass

    @torch.compiler.disable
    def emplace_kv(self, T, k_cache, v_cache, k, v):
        # torch.compile doesn't play well with this op (5x slowdown),
        # so we insert a graph break and copy eagerly
        k_cache[:, :, -T:].copy_(k)
        v_cache[:, :, -T:].copy_(v)
        return k_cache, v_cache

    def forward(self, x, pos, k_cache=None, v_cache=None, return_kv=False, inplace_kv=False, mask=None):
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)

        # Apply rotary positional embedding
        k = self.rope(k, pos)
        q = self.rope(q, pos)

        if inplace_kv and k_cache is not None and v_cache is not None:
            # assign into the kv cache in-place
            k, v = self.emplace_kv(T, k_cache, v_cache, k, v)
        else:
            # append cached keys and values with the new keys and values
            if k_cache is not None:
                k = torch.cat((k_cache, k), dim=2)
            if v_cache is not None:
                v = torch.cat((v_cache, v), dim=2)

        # Apply attention
        if self.tpu:
            # (1) TPU flash-attention kernel
            from torch_xla.experimental.custom_kernel import flash_attention
            q_norm = q / math.sqrt(k.size(-1))
            y = flash_attention(
                q_norm, k, v,
                causal=True, partition_spec=('fsdp', None, None, None))
            # (2) SDPA alternative, kept for reference:
            # y = torch.nn.functional.scaled_dot_product_attention(
            #     q, k, v,
            #     # dropout_p=self.dropout if self.training else 0,
            #     # attn_mask=None if mask is None else mask.to(q.dtype),
            #     is_causal=True
            # )
        elif self.flash:
            # efficient attention using Flash Attention CUDA kernels
            L, S = q.size(-2), k.size(-2)
            is_causal = self.is_causal and mask is None
            # is_causal doesn't work when not square, so replace with a manual mask if needed
            if is_causal and L < S:
                if L > 1:  # if L=1, attend to everything (no mask needed)
                    mask = torch.zeros(L, S, dtype=q.dtype, device=q.device)
                    mask.masked_fill_(torch.ones(L, S, dtype=torch.bool, device=q.device).triu(S - L + 1), float('-inf'))
                # NOTE: must also be cleared when L == 1, otherwise sdpa top-left-aligns
                # the causal mask and the query would only see the first key
                is_causal = False

            y = torch.nn.functional.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.dropout if self.training else 0,
                attn_mask=None if mask is None else mask.to(q.dtype),
                is_causal=is_causal
            )
        else:
            # manual implementation of attention
            att = torch.einsum('bnsh,bnkh->bnsk', q, k) * (1.0 / math.sqrt(k.size(-1)))
            # apply the mask, or use causal if default
            if mask is not None:
                att = att + mask
            elif self.is_causal:
                L, S = q.size(-2), k.size(-2)
                mask = torch.ones(1, 1, L, S).triu(S - L + 1).to(dtype=torch.bool).to(x.device)
                att.masked_fill_(mask, float('-inf'))
            # upcast to float32 for numerical stability, as per the llama implementation
            att = F.softmax(att, dim=-1, dtype=torch.float32).to(q.dtype)
            att = self.attn_dropout(att)
            # multiply attention weights with values to get the output
            y = torch.einsum('bnsk,bnkh->bnsh', att, v)

        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
        # output projection
        y = self.resid_dropout(self.c_proj(y))
        # return key and value caches if requested
        if return_kv:
            return y, k, v
        return y

    def kv_cache_forward(self, x, pos, k_cache=None, v_cache=None):
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)

        # Apply rotary positional embedding (before concat)
        k = self.rope(k, pos)
        q = self.rope(q, pos)

        # append cached keys and values with the new keys and values
        if k_cache is not None:
            k = torch.cat((k_cache, k), dim=2)
        if v_cache is not None:
            v = torch.cat((v_cache, v), dim=2)

        # manual implementation of attention (note: no causal mask is applied here)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        y = att @ v  # (B, nh, T, S) x (B, nh, S, hs) -> (B, nh, T, hs)

        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))

        return y, k, v

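As an aside, a minimal CPU sanity check that cached decoding matches the full causal pass (a sketch, assuming `PSIConfig` from config.py; dropout defaults to 0, so both passes are deterministic):

```
import torch
from config import PSIConfig          # the repo itself uses relative imports

cfg = PSIConfig(n_embd=64, n_head=4)  # head_dim = 16, divisible by 16 as Rotary3D requires
attn = PSIAttentionLayer(cfg).eval()
x = torch.randn(1, 8, 64)
pos = torch.randint(0, 4, (1, 8, 3))

y_full = attn(x, pos)                                        # full causal pass
_, k, v = attn(x[:, :7], pos[:, :7], return_kv=True)         # cache the first 7 tokens
y_last = attn(x[:, 7:], pos[:, 7:], k_cache=k, v_cache=v)    # decode token 8 against the cache
torch.testing.assert_close(y_full[:, -1:], y_last, rtol=1e-4, atol=1e-4)
```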
class MLP(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

        # Apply XLA sharding for model parallelism
        xla_device_available = False
        if xm is not None:
            try:
                device_kind = xm.xla_device_kind()
                if device_kind is not None:
                    xla_device_available = True
            except RuntimeError:
                pass

        if xla_device_available and xs is not None and xs.global_mesh() is not None:
            mesh = xs.global_mesh()
            if mesh.mesh_shape[1] > 1:  # If the 'model' axis has size > 1
                xs.mark_sharding(self.c_fc.weight, mesh, (1, 0))
                if self.c_fc.bias is not None:
                    xs.mark_sharding(self.c_fc.bias, mesh, (1,))
                print(f"MLP: Applied MP sharding to c_fc {mesh.mesh_shape} spec weight(1,0), bias(1,)")

                xs.mark_sharding(self.c_proj.weight, mesh, (0, 1))
                if self.c_proj.bias is not None:
                    xs.mark_sharding(self.c_proj.bias, mesh, (0,))
                print(f"MLP: Applied MP sharding to c_proj {mesh.mesh_shape} spec weight(0,1), bias(0,)")

    def forward(self, x, spmd_mesh=None):
        x = self.c_fc(x)
        x = self.gelu(x)

        if spmd_mesh is not None:
            # The local import shadows the module-level xs for the rest of this function
            import torch_xla.distributed.spmd.xla_sharding as xs
            xs.mark_sharding(x, spmd_mesh, (('dcn', 'data'), None, 'model'))

        x = self.c_proj(x)
        x = self.dropout(x)

        if spmd_mesh is not None:
            xs.mark_sharding(x, spmd_mesh, (('dcn', 'data'), None, 'model'))

        return x

class RMSNorm(nn.Module):
    """ Root Mean Square Normalization """
    def __init__(self, dim: int, weight: bool = True, bias: bool = False, eps: float = 1e-5):
        super().__init__()
        # NOTE: `bias` is accepted for API compatibility but unused; this RMSNorm has no bias term
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim)) if weight else None

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        if self.weight is not None:
            return output * self.weight
        return output

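For reference, this computes y = x / sqrt(mean(x^2) + eps), optionally scaled by a learned per-dimension weight. A quick numerical check (sketch):

```
x = torch.randn(2, 5, 8)
norm = RMSNorm(8)
ref = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-5)
torch.testing.assert_close(norm(x), ref)   # the weight initializes to ones
```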
class PSIBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = RMSNorm(config.n_embd, bias=config.bias)
        self.attn = PSIAttentionLayer(config)
        self.ln_2 = RMSNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x, pos, k_cache=None, v_cache=None, return_kv=False, inplace_kv=False, spmd_mesh=None, mask=None):
        # If we are given a key and value cache, use the pre-computed values to
        # minimize the computation cost
        if return_kv:
            # Pass the key and value cache to the attention layer, obtaining new key and value caches
            x_attn, k, v = self.attn(self.ln_1(x), pos, k_cache=k_cache, v_cache=v_cache,
                                     return_kv=True, inplace_kv=inplace_kv, mask=mask)
            x = x + x_attn
            x = x + self.mlp(self.ln_2(x))
            return x, k, v
        # Otherwise proceed with the regular forward pass
        x = x + self.attn(self.ln_1(x), pos, k_cache=k_cache, v_cache=v_cache, inplace_kv=inplace_kv, mask=mask)
        x = x + self.mlp(self.ln_2(x))
        return x

class PartitionedEmbedding(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, partition_size=65536):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.partition_size = partition_size
        self.num_partitions = (num_embeddings + partition_size - 1) // partition_size

        self.embedding_layers = nn.ModuleList()
        for i in range(self.num_partitions):
            start_idx = i * self.partition_size
            end_idx = min(start_idx + self.partition_size, num_embeddings)
            vocab_size = end_idx - start_idx
            self.embedding_layers.append(nn.Embedding(vocab_size, embedding_dim))

    def forward(self, input_ids):
        # Route each id to its partition; ids become relative within each partition
        partition_ids = input_ids // self.partition_size
        relative_ids = input_ids % self.partition_size

        output = torch.zeros(*input_ids.shape, self.embedding_dim, device=input_ids.device, dtype=self.embedding_layers[0].weight.dtype)

        for i in range(self.num_partitions):
            mask = (partition_ids == i)
            if mask.any():
                partition_input_ids = relative_ids[mask]
                embedded = self.embedding_layers[i](partition_input_ids)
                output[mask] = embedded

        return output

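A small routing sketch (hypothetical sizes, chosen so the ids span three partitions):

```
emb = PartitionedEmbedding(num_embeddings=10, embedding_dim=4, partition_size=4)
out = emb(torch.tensor([[0, 3, 4, 9]]))   # [1, 4, 4]
# ids 0 and 3 hit partition 0 (relative ids 0 and 3), 4 hits partition 1 (relative 0),
# and 9 hits partition 2 (relative 1); partition 2 holds only 10 - 2*4 = 2 rows
```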
class PartitionedLinear(nn.Module):
    def __init__(self, in_features, out_features, partition_size=65536, bias=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.partition_size = partition_size
        self.num_partitions = (out_features + partition_size - 1) // partition_size

        self.linear_layers = nn.ModuleList()
        for i in range(self.num_partitions):
            start_idx = i * self.partition_size
            end_idx = min(start_idx + self.partition_size, out_features)
            output_partition_size = end_idx - start_idx
            self.linear_layers.append(nn.Linear(in_features, output_partition_size, bias=bias))

    def forward(self, input):
        # Each partition produces a slice of the output features; concatenate them back
        outputs = [layer(input) for layer in self.linear_layers]
        return torch.cat(outputs, dim=-1)
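`PartitionedLinear` splits its output features with the same default `partition_size` that `PartitionedEmbedding` uses to split the vocabulary, so the per-partition weights line up shape-for-shape; this is what lets `PSI.__init__` (in psi.py) tie the partitioned lm_head to the partitioned token embedding. A quick shape check (sketch, hypothetical sizes):

```
emb = PartitionedEmbedding(num_embeddings=10, embedding_dim=4, partition_size=4)
head = PartitionedLinear(in_features=4, out_features=10, partition_size=4)
for e, l in zip(emb.embedding_layers, head.linear_layers):
    assert e.weight.shape == l.weight.shape   # [partition_vocab, embedding_dim]
```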
psi.py ADDED
@@ -0,0 +1,788 @@
"""
PSI Model Definition
"""


import math
from typing import Tuple, Union, List, Optional, Callable, Dict
from transformers import PreTrainedModel
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import tqdm

from .config import PSIConfig
from .modeling import (
    RMSNorm, PSIBlock, PartitionedEmbedding, PartitionedLinear
)

try:
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.spmd.xla_sharding as xs
except ImportError:
    xm = None
    xs = None


class PSI(PreTrainedModel):
    config_class = PSIConfig

    ### Initialization Functions ###

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        if hasattr(config, "partition_embedding") and config.partition_embedding:
            token_embedding = PartitionedEmbedding(config.vocab_size, config.n_embd)
            lm_head = PartitionedLinear(config.n_embd, config.vocab_size, bias=False)
        else:
            token_embedding = nn.Embedding(config.vocab_size, config.n_embd)
            if hasattr(config, "n_lm_vocab") and config.n_lm_vocab is not None:
                n_lm_vocab = config.n_lm_vocab
            else:
                n_lm_vocab = config.vocab_size
            lm_head = nn.Linear(config.n_embd, n_lm_vocab, bias=False)

        self.transformer = nn.ModuleDict(dict(
            token_embedding=token_embedding,
            channel_embedding=nn.Embedding(config.channel_size, config.n_embd),
            drop=nn.Dropout(config.dropout),
            h=nn.ModuleList([PSIBlock(config) for _ in range(config.n_layer)]),
            ln_f=RMSNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = lm_head

        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per the GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))

        if hasattr(config, "tie_weights") and config.tie_weights:
            if hasattr(config, "partition_embedding") and config.partition_embedding:
                # Tie each lm_head partition to the matching embedding partition
                for i in range(len(self.transformer.token_embedding.embedding_layers)):
                    self.lm_head.linear_layers[i].weight = self.transformer.token_embedding.embedding_layers[i].weight
            else:
                self.lm_head.weight = self.transformer.token_embedding.weight

        # Probe for an XLA device (model-parallel sharding itself is applied inside MLP)
        xla_device_available = False
        if xm is not None:
            try:
                device_kind = xm.xla_device_kind()
                if device_kind is not None:
                    xla_device_available = True
            except RuntimeError:
                pass

        self.unsharded_param_count = self.get_num_params()

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def get_num_params(self):
        """Return the number of parameters in the model."""
        return sum(p.numel() for p in self.parameters())

    ### Training Functions ###

    def forward(
        self,
        seq: torch.Tensor,
        pos: torch.Tensor,
        tgt: torch.Tensor = None,
        mask: torch.Tensor = None,
        k_cache: torch.Tensor = None,
        v_cache: torch.Tensor = None,
        return_kv: bool = False,
        inplace_kv: bool = False,
        output_hidden_states: bool = False,
    ) -> torch.Tensor:
        """
        Forward pass of the model

        Parameters:
            seq (torch.Tensor) of size (b, t): The input sequence
            pos (torch.Tensor) of size (b, t, d): The positional indices of the sequence of shape (batch, tokens, dimensions).
                They consist of x, y, t and c coordinates, where x, y are the spatial coordinates of the patch,
                t is the time index and c is the channel index
            tgt (torch.Tensor) of size (b, t_tgt): The target sequence
            mask (torch.Tensor) of size (b, t, t): The mask of the sequence
            k_cache (torch.Tensor) of size (n_layer, b, n_head, n, n_embd//n_head): A k-cache to prepend
            v_cache (torch.Tensor) of size (n_layer, b, n_head, n, n_embd//n_head): A v-cache to prepend
            return_kv (bool): If True, returns (logits, k, v). Ignored if tgt is not None
            inplace_kv (bool): If True, k_cache/v_cache are modified in-place. They must be sufficiently large to store
                the new tokens, and the last N tokens will be overwritten. If False (default), the input kv will not be
                modified, and a concat operation will be used instead. No effect if k_cache/v_cache are None.
            output_hidden_states (bool): If True, logits are returned as a dict with keys "logits" and "hidden_states"

        Returns:
            torch.Tensor: The logits of the model, of size (b, t, vocab) if tgt is None, else (b, t_tgt, vocab)
            if tgt is not None:
                torch.Tensor: The cross entropy loss between the logits and tgt
            elif return_kv:
                torch.Tensor: the k-cache
                torch.Tensor: the v-cache
        """

        st_pos = pos[:, :, :-1]
        channel_pos = pos[:, :, -1]

        # forward the GPT model itself
        tok_emb = self.transformer.token_embedding(seq)  # token embeddings of shape (b, t, n_embd)
        channel_emb = self.transformer.channel_embedding(channel_pos)  # channel embeddings of shape (b, t, n_embd)
        x = self.transformer.drop(tok_emb + channel_emb)

        if output_hidden_states:
            hidden_states = [x]

        k_list, v_list = [], []
        for i, block in enumerate(self.transformer.h):
            x = block(x, pos=st_pos, mask=mask,
                      k_cache=None if k_cache is None else k_cache[i],
                      v_cache=None if v_cache is None else v_cache[i],
                      return_kv=return_kv, inplace_kv=inplace_kv)
            if return_kv:
                x, k, v = x
                k_list.append(k)
                v_list.append(v)
            if output_hidden_states:
                hidden_states.append(x)

        x = self.transformer.ln_f(x)
        if output_hidden_states:
            hidden_states.append(x)

        # if tgt is None, compute the logits for the entire sequence (no loss)
        if tgt is None:
            logits = self.lm_head(x)
            if output_hidden_states:
                logits = {"logits": logits, "hidden_states": hidden_states}
            if return_kv:
                if inplace_kv:
                    # We modified in-place; avoid allocating a new tensor with torch.stack
                    return logits, k_cache, v_cache
                else:
                    return logits, torch.stack(k_list), torch.stack(v_list)
            return logits, None

        # otherwise, compute the logits and the loss for the target sequence
        # NOTE: the loss is computed before (optionally) wrapping the logits in a dict
        logits = self.lm_head(x[:, -tgt.size(1):])
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), tgt.reshape(-1), ignore_index=-1)
        if output_hidden_states:
            logits = {"logits": logits, "hidden_states": hidden_states}
        return logits, loss

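A minimal end-to-end sketch of the training-style call (hypothetical tiny sizes; `n_embd // n_head` must stay divisible by 16 for `Rotary3D`, and the last position column is the channel index, which must be < `channel_size`):

```
import torch
from config import PSIConfig   # the repo itself uses relative imports
from psi import PSI

cfg = PSIConfig(vocab_size=128, channel_size=4, n_layer=2, n_head=2, n_embd=32)
model = PSI(cfg)

seq = torch.randint(0, 128, (1, 10))      # [b, t] token ids
xyt = torch.randint(0, 8, (1, 10, 3))     # spatial (x, y) and time t
chan = torch.randint(0, 4, (1, 10, 1))    # channel index
pos = torch.cat([xyt, chan], dim=-1)      # [b, t, 4]
tgt = torch.randint(0, 128, (1, 10))

logits, loss = model(seq, pos, tgt=tgt)   # logits: [1, 10, 128]
loss.backward()
```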
    ### Rollout Functions ###

    @torch.no_grad()
    def sample_logits(self,
                      logits: torch.FloatTensor,
                      temp: Optional[float] = None,
                      post_temp: Optional[float] = None,
                      top_k: Optional[int] = None,
                      top_p: Optional[float] = None,
                      min_p: Optional[float] = None,
                      sample_range: Optional[Tuple[int, int]] = None,
                      blacklist: Optional[Union[List[int], torch.LongTensor]] = None
                      ) -> torch.LongTensor:
        """
        Samples an integer from the distribution of logits.
        NOTE: `logits` may be modified in place.

        Parameters:
            logits (torch.FloatTensor): The logits of the distribution
            temp (float): The temperature of the sampling; if 0.0, argmax is used
            post_temp (float): An optional second temperature, applied after top-k/top-p/min-p filtering
            top_k (int): The number of top k tokens to consider during sampling
            top_p (float): The cumulative probability threshold for nucleus (top-p) sampling
            min_p (float): The minimum probability threshold factor for min-p sampling
            sample_range (Tuple[int, int]): Restricts sampling to token ids in [start, end)
            blacklist (Union[List[int], torch.LongTensor]): The list of tokens to blacklist during sampling
        Returns:
            torch.LongTensor: The sampled integers
        """
        # If schedules (lists) are passed, use their first entry
        if isinstance(temp, list):
            temp = temp[0]
        if isinstance(post_temp, list):
            post_temp = post_temp[0]
        if isinstance(top_k, list):
            top_k = top_k[0]
        if isinstance(top_p, list):
            top_p = top_p[0]
        assert temp is None or temp >= 0.0
        assert post_temp is None or post_temp >= 0.0
        assert top_k is None or top_k > 0
        assert top_p is None or top_p >= 0.0
        assert min_p is None or 0.0 <= min_p <= 1.0
        assert sample_range is None or (
            sample_range[0] < sample_range[1] and
            sample_range[0] >= 0 and
            sample_range[1] <= logits.shape[-1]
        )

        # Apply blacklist & sample range
        if blacklist is not None:
            logits[..., blacklist] = float('-inf')
        if sample_range is not None:
            logits = logits[..., sample_range[0]:sample_range[1]]
            token_offset = sample_range[0]
        else:
            token_offset = 0

        # Apply temperature, or use argmax if 0.0
        if (temp is not None and temp == 0.0) or (post_temp is not None and post_temp == 0.0):
            return token_offset + torch.argmax(logits, dim=-1)
        if temp is not None and temp != 1.0:
            logits.div_(temp)

        # Sort the logits once. More efficient when using top-k and top-p together (min-p doesn't require sorting).
        # We sample in sorted order, then re-order before returning.
        if top_k is not None or top_p is not None:
            logits, order = torch.sort(logits, dim=-1, descending=True)
        else:
            order = None  # Don't sort

        # Apply top-k filtering if specified
        if top_k is not None:
            logits = logits[..., :top_k]

        # Apply top-p (nucleus) filtering if specified
        if top_p is not None:
            probs = F.softmax(logits, dim=-1)  # Already sorted
            cumulative_probs = probs.cumsum_(dim=-1)
            idxs_to_remove = cumulative_probs > top_p
            # Shift the mask right to keep at least one token
            logits[..., 1:][idxs_to_remove[..., :-1]] = float('-inf')
            del probs, cumulative_probs, idxs_to_remove

        # Apply min-p filtering if specified
        if min_p is not None:
            probs = F.softmax(logits, dim=-1)
            maxprob = probs[..., [0]] if order is not None else torch.max(probs, dim=-1, keepdim=True).values
            logits[probs < maxprob * min_p] = float('-inf')
            del probs, maxprob

        # Apply the optional post-temperature
        if post_temp is not None and post_temp != 1.0:
            logits.div_(post_temp)

        # Compute softmax probabilities
        orig_shape = logits.shape
        probs = torch.softmax(logits, dim=-1, out=logits)
        # Flatten probabilities to (batch_size * sequence_length, vocab_size)
        flat_probs = probs.view(-1, probs.size(-1))
        # Sample from the distribution
        sampled = torch.multinomial(flat_probs, num_samples=1)
        # Reshape to the original shape except for the last dimension
        sampled = sampled.view(*orig_shape[:-1])

        # If we sorted, unsort to collect the actual token values
        if order is not None:
            sampled = torch.gather(order, dim=-1, index=sampled.unsqueeze(-1)).squeeze(-1)
        return token_offset + sampled

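A usage sketch, continuing the `model` from the example above (note the `clone()`, since `sample_logits` filters and normalizes its input in place):

```
logits = torch.randn(4, 128)   # 4 independent distributions over the vocab
tokens = model.sample_logits(logits.clone(), temp=0.8, top_k=20, top_p=0.9)
print(tokens.shape)            # torch.Size([4])
```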
    @torch.no_grad()
    def rollout_patches(self,
                        seq: Union[Optional[torch.LongTensor], List[Optional[torch.LongTensor]]],
                        pos: Union[torch.LongTensor, List[torch.LongTensor]],
                        idx: torch.LongTensor,
                        n_tokens_per_patch: int = 5,
                        n_seq_patches: int = -1,
                        weights: Optional[Union[List[float], torch.Tensor]] = None,
                        k_cache: Optional[torch.Tensor] = None,
                        v_cache: Optional[torch.Tensor] = None,
                        cache_mask: Optional[torch.Tensor] = None,
                        policy: Optional[Callable[..., torch.LongTensor]] = None,
                        *,
                        unmask_parallel: bool = False,
                        return_logits: bool = False,
                        return_idx_logits: bool = True,
                        return_kv: bool = False,
                        verbose: bool = True,
                        temp: Optional[Union[float, List[float]]] = None,
                        post_temp: Optional[Union[float, List[float]]] = None,
                        top_k: Optional[Union[int, List[int]]] = None,
                        top_p: Optional[Union[float, List[float]]] = None,
                        min_p: Optional[Union[float, List[float]]] = None,
                        sample_range: Optional[Tuple[int, int]] = None,
                        blacklist: Optional[Union[List[int], torch.LongTensor]] = None
                        ) -> Union[
                            torch.LongTensor,                                                 # seq
                            Tuple[torch.LongTensor, torch.Tensor],                            # seq, logits
                            Tuple[torch.LongTensor, Dict[str, torch.Tensor]],                 # seq, kvcache
                            Tuple[torch.LongTensor, torch.Tensor, Dict[str, torch.Tensor]],   # seq, logits, kvcache
                        ]:
        """
        K = number of given sequences (1 if seq is not a list)
        T = length of a conditioning sequence (per sequence)
        N = total length of conditioning + generated tokens (per sequence)
        I = number of index tokens to roll out
        max(T) = maximum of T across sequences
        num_new_tokens = len(idx) * n_tokens_per_patch

        ***Tips for long rollouts***:
        1. Use `gc.collect()` and then `torch.cuda.empty_cache()` (in that order)
        2. Avoid fragmentation wherever possible. This method needs to allocate the entire KV cache
           contiguously in memory. If you create tensors between rollouts, consider moving them to
           CPU or cloning them to defragment VRAM.
        3. Even if a single N-token rollout fits in VRAM, running two consecutive rollouts (e.g. running
           n1 then giving its KV cache to n2, with n1 + n2 = N) may not fit, because reallocating the
           KV cache will duplicate memory. To avoid this, try moving the cache to CPU first:
           ```
           seq1, kvcache = predictor.rollout_patches(..., return_kv=True)
           kvcache = { k: v.cpu() for k, v in kvcache.items() }
           gc.collect(); torch.cuda.empty_cache()
           seq2 = predictor.rollout_patches(..., **kvcache)
           ```
           With this, the new KV cache will be allocated on GPU, and the CPU cache will be copied into it.

        TODO: Support temp/top_k/top_p/min_p scheduling with parallel. Currently uses index -1 for all parallel tokens

        Parameters:
            seq (Union[Optional[torch.LongTensor], List[Optional[torch.LongTensor]]]):
                [T], [K T], or list of [T] / sequence(s) to condition the generation on. None means an empty sequence
            pos (Union[torch.LongTensor, List[torch.LongTensor]]):
                [N 4], [K N 4], or list of [N 4] / 4D position(s) corresponding to each sequence. If multiple, must have length K=len(seq)
            idx (torch.LongTensor):
                [I] / the patch indices to use in the rollout (same for all sequences)
            n_tokens_per_patch (int):
                number of tokens per patch, including the patch index
            n_seq_patches (int):
                number of patches to roll out sequentially (-1 for all). The remaining patches will be parallel
            weights (Optional[Union[List[float], torch.Tensor]]):
                float weights for the logits produced by each sequence. If None and multiple sequences are given, defaults to all ones.
                If a list or 1D tensor, must have size K. If 2D, must have shape [K num_new_tokens].
                If 2D, the first n_seq_patches patches will use the weights in order as expected, but the remaining (parallel) patches may be arbitrarily reordered.
                When using parallel and a 2D weight schedule, it is recommended to make the weights for parallel patches uniform for consistency
            k_cache (Optional[torch.Tensor]):
                optional k_cache to prepend to all seqs, broadcastable to shape [n_layer K n_head n_tok n_embd//n_head]. May be on a different device
            v_cache (Optional[torch.Tensor]):
                optional v_cache to prepend to all seqs, broadcastable to shape [n_layer K n_head n_tok n_embd//n_head]. May be on a different device
            cache_mask (Optional[torch.Tensor]):
                optional mask to be applied to the provided KV cache with shape [K 1 1 n_tok], where n_tok matches k_cache/v_cache. Useful when a KV cache is supplied
                for multiple conditioning sequences of different lengths, where the cache_mask indicates which elements of the cache should be attended to for each sequence.
                If k_cache/v_cache are given and cache_mask is None, the cache will be fully unmasked. May be on a different device
            policy (Optional[Callable[..., torch.LongTensor]]):
                optional callback defining the policy for rollout order. Must accept an argument `idx` (torch.LongTensor of shape [I]) with the candidate indices to generate next.
                Must return the *index* into the `idx` tensor to generate next (*not* the value of the patch index itself), either an int or a 0-dimensional torch.LongTensor.
                For example, given candidate indices [4,12,1023], return 2 to generate the patch with index 1023 next. Only used for the sequential part of the generation.
                The following kwargs are given:
                - `idx` (torch.LongTensor of shape [I]): the candidate patch indices
                - `pos` (torch.LongTensor of shape [K N 4]): the remaining positions for all yet-ungenerated tokens, in the same order as `idx`
                - `weights` (torch.Tensor of shape [K N]): the weights for all yet-ungenerated tokens, or None
                - `k_cache` (torch.Tensor)
                - `v_cache` (torch.Tensor)
                - `cache_mask` (torch.Tensor of shape [K 1 1 n_tok])
                - `kvcache` (Dict[str, torch.Tensor]): the kvcache dict with keys 'k_cache', 'v_cache', and 'cache_mask'
                - `all_k_cache` (torch.Tensor): the entire preallocated k-cache, including uninitialized tokens
                - `all_v_cache` (torch.Tensor): the entire preallocated v-cache, including uninitialized tokens
                - `n_tokens_per_patch` (int)
                - `sample_range` (Optional[Tuple[int,int]])
                - `idx_pos` (torch.LongTensor of shape [K I 4]): same as `pos`, but only for the candidate index tokens
                - `idx_weights` (torch.Tensor of shape [K I]): same as `weights`, but only for the candidate index tokens
                - `device` (torch.device)
            unmask_parallel (bool):
                if True, all parallel patches can attend to each other. If False (default), parallel patches can only attend to themselves
            return_logits (bool):
                return the logits of the sequence
            return_idx_logits (bool):
                if True (default), returns logits that would predict index tokens, so there is one set of logits for every returned token. If False, only returns
                logits used to sample content tokens, e.g. returns (n_tokens_per_patch - 1) sets of logits per patch. The latter may not need to compute logits
                for all tokens, so it may be more efficient for some computations (such as patchwise entropy). Ignored if return_logits=False
            return_kv (bool):
                return the KV cache(s) as a dict with keys 'k_cache', 'v_cache', and 'cache_mask', useful for downstream operations. If True and return_logits=False,
                returns (new_tokens, kvcache). If True and return_logits=True, returns (new_tokens, logits, kvcache). **All KVs are returned in patch-major order,** even if
                the rollout is partially or fully parallel. Note that KVs from parallel prediction are not computed causally

        Returns:
            torch.LongTensor:
                [num_new_tokens] the generated tokens only (the input sequence is not prepended)
            torch.Tensor:
                (optional) [n_tokens vocab_size] the logits of the sequence, where n_tokens depends on return_idx_logits
            Dict[str, torch.Tensor]:
                (optional) the KV cache, with the following key/value pairs:
                - `k_cache` (torch.Tensor) [n_layer K n_head n_tok n_embd//n_head]
                - `v_cache` (torch.Tensor) [n_layer K n_head n_tok n_embd//n_head]
                - `cache_mask` (torch.Tensor) [K 1 1 n_tok]
        """

        #########################
        # === Preprocessing === #
        #########################

        if not isinstance(seq, list):
            seq = [seq] if seq is None or seq.ndim == 1 else list(seq)
        if not isinstance(pos, list):
            pos = [pos] if pos.ndim == 2 else list(pos)

        nnt = idx.numel() * n_tokens_per_patch  # num new tokens
        device = pos[0].device
        idtype = pos[0].dtype
        dtype = self.lm_head.weight.dtype

        if weights is not None:
            if isinstance(weights, list):
                weights = torch.tensor(weights, dtype=dtype, device=device)
            if weights.ndim != 2:
                weights = weights.unsqueeze(-1).expand(-1, nnt)
            weights = weights.to(dtype).to(device)
        elif len(seq) > 1:
            weights = torch.ones(len(seq), nnt, dtype=dtype, device=device)

        if n_seq_patches < 0:
            n_seq_patches = idx.shape[0]
        if temp is not None and not isinstance(temp, list):
            temp = [temp] * nnt
        if post_temp is not None and not isinstance(post_temp, list):
            post_temp = [post_temp] * nnt
        if top_k is not None and not isinstance(top_k, list):
            top_k = [top_k] * nnt
        if top_p is not None and not isinstance(top_p, list):
            top_p = [top_p] * nnt
        if min_p is not None and not isinstance(min_p, list):
            min_p = [min_p] * nnt

        K = len(seq)
        I = idx.shape[0]
        T = [0 if s is None else s.shape[0] for s in seq]
        maxT = max(1, max(T))

        tpp = n_tokens_per_patch
        in_cache_size = 0 if k_cache is None else k_cache.shape[3]
        n_rollout_tokens = tpp * n_seq_patches
        n_par_patches = I - n_seq_patches
        return_idx_logits = return_logits and return_idx_logits
        run_last_parallel_tokens = return_idx_logits or return_kv

        # Validate inputs as best we can
        assert len(pos) == K, f'Expected seq and pos lists to have the same length, but got {K} and {len(pos)}'
        assert idx.ndim == 1
        if weights is not None:
            assert weights.ndim == 2 and weights.shape == (K, nnt)
        assert I * tpp == nnt, f'Requested {nnt} new tokens, but ({I} idx tokens) * ({tpp} tok per patch) = {I*tpp} != {nnt}'
        assert 0 <= n_seq_patches <= I
        assert k_cache is None or k_cache.ndim == 5
        assert v_cache is None or v_cache.ndim == 5
        assert (k_cache is None) == (v_cache is None)
        assert k_cache is None or k_cache.shape[3] == v_cache.shape[3]
        if cache_mask is not None:
            assert cache_mask.ndim == 4 and cache_mask.shape[1] == 1 and cache_mask.shape[2] == 1
            assert cache_mask.shape[-1] == in_cache_size, f'cache_mask ({cache_mask.shape[-1]} tokens) does not match the size of k_cache/v_cache ({in_cache_size} tokens)'
        for i, (s, p) in enumerate(zip(seq, pos)):
            if s is not None:
                assert s.ndim == 1, f'Expected all sequence tensors to be 1D, but got seq[{i}].ndim={s.ndim}'
            assert p.ndim == 2, f'Expected all position tensors to be 2D, but got pos[{i}].ndim={p.ndim}'
            assert p.shape[1] == 4, f'Expected all position tensors to have shape (*, 4), but got pos[{i}].shape[1]={p.shape[1]}'
            assert p.shape[0] == T[i] + nnt, f'Sequence {i}: With {T[i]} conditioning and {nnt} new tokens, expected pos[{i}].shape[0]={T[i]+nnt}, but got {p.shape[0]}'


        #########################
        # === Preallocation === #
        #########################

        # Preallocate the KV cache so we don't need to constantly resize it.
        # If we won't run the last parallel pass, we don't need to cache those toks
        n_kvcache = in_cache_size + maxT + nnt - (0 if run_last_parallel_tokens else n_par_patches)
        # [n_layer K n_head n_tok n_embd//n_head]
        kv_shape = (
            self.config.n_layer, K, self.config.n_head,
            n_kvcache, self.config.n_embd // self.config.n_head
        )
        all_v_cache = torch.empty(kv_shape, dtype=dtype, device=device)
        all_k_cache = torch.empty(kv_shape, dtype=dtype, device=device)
        if in_cache_size > 0:
            all_k_cache[..., :in_cache_size, :].copy_(k_cache, non_blocking=True)
            all_v_cache[..., :in_cache_size, :].copy_(v_cache, non_blocking=True)

        # Also preallocate the output logits tensor, if requested
        if return_logits:
            n_logits = n_rollout_tokens + (I - n_seq_patches) * (tpp if return_idx_logits else (tpp - 1))
            all_logits = torch.empty((n_logits, self.config.vocab_size), dtype=dtype, device=device)


        ################################
        # === Initial Forward Pass === #
        ################################

        # Stack seq/pos into a batch, left-padded
        # [K maxT]
        seq = torch.stack([(
            torch.zeros(maxT, dtype=idtype, device=device) if s is None else
            F.pad(s, (maxT - s.shape[0], 0))
        ) for s in seq])
        # [K maxN 4]
        pos = torch.stack([F.pad(p, (0, 0, maxT - t, 0)) for t, p in zip(T, pos)])

        # Build the attention mask for the initial forward pass [K 1 maxT maxT]:
        # batch size K, each mask in the batch fully causal except for the first (maxT - T) pad tokens, which are masked
        mask = torch.zeros(K, 1, maxT, maxT, device=device)
        mask.masked_fill_(torch.ones_like(mask, dtype=torch.bool).triu(1), float('-inf'))
        for i, t in enumerate(T):
            mask[i, ..., :maxT - t] = float('-inf')
            # Unmask the diagonal so pad tokens can self-attend.
            # This doesn't matter with torch sdpa, but prevents NaNs with manual attention.
            # NOTE: If t==0, the diagonal is *only* pad tokens, so this will unmask the last pad token
            # in the last row (which we use for rollouts). We re-mask this pad token in the rollout mask below
            mask[i, 0].fill_diagonal_(0.0)
        if k_cache is not None:
            if cache_mask is not None:
                mask = torch.cat([cache_mask.to(mask.device).expand((K, 1, maxT, -1)), mask], dim=-1)
            else:
                mask = F.pad(mask, (in_cache_size, 0, 0, 0, 0, 0, 0, 0))
        # The above mask[:,0,-1,:] might look something like this:
        #            kv cache    | sequences
        # T[0]==3 [ T T T T T T | F F F F T T T ]
        # T[1]==6 [ T T T T T T | F T T T T T T ]
        # T[2]==7 [ T T T T T T | T T T T T T T ]
        # T[3]==4 [ T T T T T T | F F F T T T T ]
        # For one element in the batch, mask[0,0,:,:] with T[0]==3 might look like:
        #   kv cache    | sequences
        # [ T T T T T T | T F F F F F F ]
        # [ T T T T T T | F T F F F F F ]
        # [ T T T T T T | F F T F F F F ]
        # [ T T T T T T | F F F T F F F ]
        # [ T T T T T T | F F F F T F F ]
        # [ T T T T T T | F F F F T T F ]
        # [ T T T T T T | F F F F T T T ]
        # If a custom cache_mask is given, the kv cache part above may be different

        # Initial forward pass (conditioning sequences only)
        k_cache = all_k_cache[..., :in_cache_size + maxT, :]
        v_cache = all_v_cache[..., :in_cache_size + maxT, :]
        self.forward(
            seq=seq, pos=pos[:, :maxT], mask=mask,
            k_cache=k_cache, v_cache=v_cache, inplace_kv=True
        )
        pos = pos[:, maxT:]


        ##############################
        # === Sequential Rollout === #
        ##############################

        # Build the attention mask for rollout [K 1 1 maxT+n_rollout_tokens], clone to free memory
        mask = F.pad(mask[..., [-1], :].clone(), (0, n_rollout_tokens, 0, 0, 0, 0, 0, 0))
        for i, t in enumerate(T):
            # If t==0, the fill_diagonal call above unmasked the last pad token,
            # so we need to re-mask it before we start rolling out
            if t == 0:
                mask[i, ..., in_cache_size + maxT - 1] = float('-inf')
        # The above mask[:,0,0,:] might look something like this:
        #            kv cache    | sequences     | rollout
        # T[0]==3 [ T T T T T T | F F F F T T T | T T T T ... T T T T ]
        # T[1]==6 [ T T T T T T | F T T T T T T | T T T T ... T T T T ]
        # T[2]==7 [ T T T T T T | T T T T T T T | T T T T ... T T T T ]
        # T[3]==4 [ T T T T T T | F F F T T T T | T T T T ... T T T T ]
        # We construct this mask once, then slice off part of the right side at each rollout step

        rollout_seq = []

        # Rollout
        for i in tqdm.tqdm(range(n_rollout_tokens), desc='Rollout', unit='tok', disable=(not verbose or n_rollout_tokens == 0)):
            if (i % tpp) == 0:
                patch_number = i // tpp
                if policy is None:
                    # Use the provided order
                    next_token = idx[patch_number]
                else:
                    # Use the callback to select the next patch
                    policy_cache_mask = mask[..., :in_cache_size + maxT + i]
                    idx_of_next_idx = policy(
                        idx=idx[patch_number:],                                       # [N]
                        pos=pos[:, i:],                                               # [K N 4]
                        weights=None if weights is None else weights[:, i:],          # [K N]
                        k_cache=k_cache,
                        v_cache=v_cache,
                        cache_mask=policy_cache_mask,                                 # [K 1 1 n_tok]
                        kvcache=dict(k_cache=k_cache, v_cache=v_cache, cache_mask=policy_cache_mask),
                        all_k_cache=all_k_cache,
                        all_v_cache=all_v_cache,
                        n_tokens_per_patch=n_tokens_per_patch,
                        sample_range=sample_range,
                        idx_pos=pos[:, i::tpp],                                       # [K I 4]
                        idx_weights=None if weights is None else weights[:, i::tpp],  # [K I]
                        device=idx.device,
                    )
                    del policy_cache_mask
                    # Move patch patch_number+idx_of_next_idx to the next position by swapping
                    if idx_of_next_idx != 0:  # Don't bother if it's already next
                        i1, i2 = patch_number, patch_number + int(idx_of_next_idx)
                        idx[[i1, i2]] = idx[[i2, i1]]
                        # [K N 4] -> [K I tpp 4] -swap-idxs-> [K I tpp 4] -> [K N 4]
                        pos = pos.reshape(K, I, tpp, 4)
                        pos[:, [i1, i2]] = pos[:, [i2, i1]]
                        pos = pos.reshape(K, -1, 4)
                    next_token = idx[patch_number]
                rollout_seq.append(next_token)

            # Forward call
            k_cache = all_k_cache[..., :in_cache_size + maxT + i + 1, :]
            v_cache = all_v_cache[..., :in_cache_size + maxT + i + 1, :]
            logits, _ = self.forward(
                seq=next_token.expand(K, 1),                    # [K 1]
                pos=pos[:, [i]],                                # [K 1 4]
                mask=mask[..., :in_cache_size + maxT + i + 1],  # [K 1 1 n_prev+1]
                k_cache=k_cache,
                v_cache=v_cache,
                inplace_kv=True
            )

            # Weighted sum of logits [K 1 V] -> [1 V]
            if weights is not None:
                w = weights[:, [i]]               # [K nnt] -> [K 1]
                logits = w.T @ logits.squeeze(1)  # [1 V]
            else:
                logits = logits.squeeze(0)        # [1 1 V] -> [1 V]

            # If the next token is an index token, we just needed to cache the previous token
            if ((i + 1) % tpp) == 0:
                if return_idx_logits:
                    all_logits[i].copy_(logits.squeeze())
                continue
            if return_logits:
                all_logits[i].copy_(logits.squeeze())

            # Sample from the logits
            next_token = self.sample_logits(
                logits,
                temp=temp[i] if temp else None,
                post_temp=post_temp[i] if post_temp else None,
                top_k=top_k[i] if top_k else None,
                top_p=top_p[i] if top_p else None,
                min_p=min_p[i] if min_p else None,
                sample_range=sample_range,
                blacklist=blacklist
            ).squeeze(0)
            rollout_seq.append(next_token)

        if n_rollout_tokens > 0:
            rollout_seq = torch.stack(rollout_seq)
        if n_rollout_tokens == nnt:
            ret = (rollout_seq,)
            if return_logits:
                ret = (*ret, all_logits)
            if return_kv:
                ret = (*ret, dict(k_cache=all_k_cache, v_cache=all_v_cache, cache_mask=mask))
            return ret if len(ret) > 1 else ret[0]

        ############################
        # === Parallel Rollout === #
        ############################

        npp = n_par_patches
        npt = npp * tpp  # num parallel tokens
        idx = idx[-npp:]

        # Build the attention mask for the parallel part [K 1 npp n_past+npp]
        if unmask_parallel:
            par_mask = torch.zeros(1, dtype=mask.dtype, device=device).expand(K, 1, npp, npp)
        else:
            par_mask = torch.full((npp, npp), float('-inf'), dtype=mask.dtype, device=device)
            par_mask.fill_diagonal_(0.0)
            par_mask = par_mask.expand(K, 1, npp, npp)
        mask = mask.expand(K, 1, npp, -1)
        # Initial shape is [K 1 npp n_past]. Before each iter, we append par_mask [K 1 npp npp] along the last dim

        # Reshape positions for the parallel passes
        # [K nnt 4] -trim-> [K npt 4] -> [K npp tpp 4] -> [K tpp npp 4] -> [K npt 4]
        pos = pos[:, -npt:].reshape(K, npp, tpp, 4).transpose(1, 2).reshape(K, npt, 4)

        # Reshape scheduled properties (transpose similarly to positions)
        # NOTE: numpy's transpose takes an axis order, so (1, 0) swaps the two axes
        # [nnt] -trim-> [npt] -> [npp tpp] -> [tpp npp] -> [npt]
        if temp is not None:
            temp = np.array(temp[-npt:]).reshape(npp, tpp).transpose(1, 0).flatten().tolist()
        if post_temp is not None:
            post_temp = np.array(post_temp[-npt:]).reshape(npp, tpp).transpose(1, 0).flatten().tolist()
        if top_k is not None:
            top_k = np.array(top_k[-npt:]).reshape(npp, tpp).transpose(1, 0).flatten().tolist()
        if top_p is not None:
            top_p = np.array(top_p[-npt:]).reshape(npp, tpp).transpose(1, 0).flatten().tolist()
        if min_p is not None:
            min_p = np.array(min_p[-npt:]).reshape(npp, tpp).transpose(1, 0).flatten().tolist()
        if weights is not None:
            # [K nnt] -trim-> [K npt] -> [K npp tpp] -> [K tpp npp] -> [K npt]
            weights = weights[:, -npt:].reshape(K, npp, tpp).transpose(1, 2).reshape(K, npt)

        next_tokens = idx
        parallel_seq = [next_tokens]

        # Run the parallel passes
        for i in range(tpp if run_last_parallel_tokens else (tpp - 1)):
            mask = torch.cat([mask, par_mask], dim=-1)
            parallel_slice = slice(i * npp, (i + 1) * npp)
            k_cache = all_k_cache[..., :in_cache_size + maxT + n_rollout_tokens + (i + 1) * npp, :]
            v_cache = all_v_cache[..., :in_cache_size + maxT + n_rollout_tokens + (i + 1) * npp, :]
            logits, _ = self.forward(
                seq=next_tokens.expand(K, npp),  # [K npp]
                pos=pos[:, parallel_slice],      # [K npp 4]
                mask=mask,                       # [K 1 npp n_past+npp]
                k_cache=k_cache,
                v_cache=v_cache,
                inplace_kv=True
            )

            # Weighted sum of logits [K npp V] -> [npp V]
            if weights is not None:
                w = weights[:, parallel_slice]              # [K npt] -> [K npp]
                logits = (logits * w.unsqueeze(-1)).sum(0)  # [K npp V] -> [npp V]
            else:
                logits = logits.squeeze(0)                  # [1 npp V] -> [npp V]

            if return_logits and (i < tpp - 1 or return_idx_logits):
                # Store the logits with a stride so we don't need to transpose later
                # NOTE: If return_idx_logits=False, we have tpp-1 instead of tpp
                stride = tpp if return_idx_logits else (tpp - 1)
                all_logits[n_rollout_tokens + i::stride].copy_(logits)
            if i == (tpp - 1):
                # We just needed to compute logits and/or KV to return them; no need to predict
                break

            # Sample from the logits
            # TODO: Index using parallel_slice instead of -1 to support scheduled parameters
            next_tokens = self.sample_logits(
                logits,
                temp=temp[-1] if temp else None,
                post_temp=post_temp[-1] if post_temp else None,
                top_k=top_k[-1] if top_k else None,
                top_p=top_p[-1] if top_p else None,
                min_p=min_p[-1] if min_p else None,
                sample_range=sample_range,
                blacklist=blacklist
            )
            parallel_seq.append(next_tokens)

        # [tpp npp] -> [npp tpp] -> [npt]
        parallel_seq = torch.stack(parallel_seq).transpose(0, 1).flatten()
        if return_kv:
            # Transpose only the last npp (num parallel patches) patches
            # [... npt E] -> [... tpp npp E] -> [... npp tpp E] -> [... npt E]
            edims = all_k_cache.shape[:-2]
            par_k = all_k_cache[..., -npt:, :].reshape(*edims, tpp, npp, -1).transpose(-2, -3).reshape(*edims, npt, -1)
            par_v = all_v_cache[..., -npt:, :].reshape(*edims, tpp, npp, -1).transpose(-2, -3).reshape(*edims, npt, -1)
            all_k_cache[..., -npt:, :].copy_(par_k.clone())
            all_v_cache[..., -npt:, :].copy_(par_v.clone())
            del par_k, par_v
            # [K 1 npp N] -> [K 1 1 N], clone to free memory; it doesn't matter which row we take (rows only differ in the last npt cols)
            # Unmask the last npt cols (if necessary) to make the parallel part fully unmasked
            mask = mask[..., [-1], :].clone()
            if not unmask_parallel:
                mask[..., -npt:] = 0.0

        ret = (torch.cat([rollout_seq, parallel_seq]),) if n_rollout_tokens > 0 else (parallel_seq,)
        if return_logits:
            ret = (*ret, all_logits)
        if return_kv:
            ret = (*ret, dict(k_cache=all_k_cache, v_cache=all_v_cache, cache_mask=mask))
        return ret if len(ret) > 1 else ret[0]
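Finally, a minimal single-sequence rollout sketch (continuing the tiny `model` from the training example above: 2 patches of 5 tokens each, fully sequential, with no conditioning sequence):

```
idx = torch.tensor([3, 7])            # patch-index tokens, arbitrary ids < vocab_size
xyt = torch.randint(0, 8, (10, 3))
chan = torch.randint(0, 4, (10, 1))   # channel index < channel_size
pos = torch.cat([xyt, chan], dim=-1)  # one 4D position per generated token
new_tokens = model.rollout_patches(seq=None, pos=pos, idx=idx,
                                   n_tokens_per_patch=5, verbose=False)
print(new_tokens.shape)               # torch.Size([10]), patch-major order
```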