Upload folder using huggingface_hub

- __init__.py +3 -0
- config.json +9 -0
- config.py +39 -0
- modeling.py +381 -0
- psi.py +788 -0
__init__.py
ADDED
@@ -0,0 +1,3 @@

from .psi import PSI
from .config import PSIConfig
config.json
ADDED
@@ -0,0 +1,9 @@
{
    "_name_or_path": "StanfordNeuroAILab/PSI",
    "architectures": ["PSI"],
    "auto_map": {
        "AutoConfig": "config.PSIConfig",
        "AutoModel": "psi.PSI"
    },
    "model_type": "PSI"
}
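With this `auto_map`, the repo's custom classes resolve through the Transformers Auto API, and `trust_remote_code=True` is required because the config and model code live in the repo itself rather than in `transformers`. A minimal loading sketch (the repo id is taken from `_name_or_path` above; whether the checkpoint is published under that exact id is an assumption):

```python
from transformers import AutoConfig, AutoModel

# repo id from "_name_or_path"; assumed to be the published location
repo = "StanfordNeuroAILab/PSI"
config = AutoConfig.from_pretrained(repo, trust_remote_code=True)  # -> config.PSIConfig
model = AutoModel.from_pretrained(repo, trust_remote_code=True)    # -> psi.PSI
```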
config.py
ADDED
@@ -0,0 +1,39 @@

from typing import Tuple, List, Optional
from transformers import PretrainedConfig

class PSIConfig(PretrainedConfig):
    model_type: str = "PSI"
    def __init__(self,
                 vocab_size: int = 96256,
                 channel_size: int = 12,
                 n_layer: int = 12,
                 n_head: int = 12,
                 n_embd: int = 768,
                 dropout: float = 0.0,
                 bias: bool = False,
                 attention_mask: str = "causal",
                 tie_weights: bool = False,
                 partition_embedding: bool = False,
                 n_lm_vocab: Optional[int] = None,
                 **kwargs
                 ):
        self.vocab_size = vocab_size
        self.channel_size = channel_size
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_embd = n_embd
        self.dropout = dropout
        self.bias = bias
        self.attention_mask = attention_mask
        self.tie_weights = tie_weights
        self.partition_embedding = partition_embedding
        self.n_lm_vocab = n_lm_vocab

        # Aside from HuggingFace default config attributes,
        # all extra kwargs are assigned using setattr. For HuggingFace attrs, see:
        # https://github.com/huggingface/transformers/blob/v4.53.3/src/transformers/configuration_utils.py#L45

        # Since token ranges are checkpoint-specific, we don't include them
        # in this config and let them be assigned from kwargs.
        super().__init__(**kwargs)
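Because `PSIConfig` forwards unrecognized kwargs to `PretrainedConfig.__init__`, which assigns them with `setattr`, checkpoint-specific values such as token ranges can ride along at construction time. A small sketch (`frame_token_range` is a hypothetical kwarg, not part of this config):

```python
from config import PSIConfig  # assumes this file is importable as `config`

cfg = PSIConfig(n_layer=12, n_head=12, n_embd=768,
                frame_token_range=(0, 65536))  # hypothetical extra kwarg
assert cfg.frame_token_range == (0, 65536)     # stored via setattr on the config
```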
modeling.py
ADDED
@@ -0,0 +1,381 @@

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

try:
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.spmd.xla_sharding as xs
except ImportError:
    xm = None
    xs = None


class Rotary3D(nn.Module):
    def __init__(self, dim, base=100):
        super().__init__()
        assert dim % 16 == 0, "Embedding dim must be divisible by 16"

        # Embedding dimensions must align precisely with dim // num_heads
        self.x_dim = (6 * dim) // 16
        self.y_dim = (6 * dim) // 16
        self.t_dim = dim - self.x_dim - self.y_dim

        # Precompute inverse frequencies
        self.register_buffer('inv_freq_x', 1.0 / (base ** (torch.arange(0, self.x_dim, 2).float() / self.x_dim)))
        self.register_buffer('inv_freq_y', 1.0 / (base ** (torch.arange(0, self.y_dim, 2).float() / self.y_dim)))
        self.register_buffer('inv_freq_t', 1.0 / (base ** (torch.arange(0, self.t_dim, 2).float() / self.t_dim)))

    def forward(self, x, pos):
        """
        x: [batch, nh, seq_len, head_dim]
        pos: [batch, seq_len, 3] integer positions along (x, y, t)
        """
        B, nh, T, hs = x.shape
        assert pos.shape[-1] == 3, "Position tensor must have shape [batch, seq_len, 3]"

        # Compute embeddings directly to match `hs`
        dim_total = hs
        assert dim_total % 2 == 0, "head_dim (hs) must be divisible by 2 for rotary embedding."

        # Positional dimensions expanded explicitly
        dtype = self.inv_freq_x.dtype
        pos_x = pos[..., 0].to(dtype)  # [B, T]
        pos_y = pos[..., 1].to(dtype)  # [B, T]
        pos_t = pos[..., 2].to(dtype)  # [B, T]

        # Generate embeddings for x, y, t and combine
        freqs_x = torch.einsum('bt,f -> btf', pos_x, self.inv_freq_x)
        freqs_y = torch.einsum('bt,f -> btf', pos_y, self.inv_freq_y)
        freqs_t = torch.einsum('bt,f -> btf', pos_t, self.inv_freq_t)

        # Concatenate embeddings and match dimensions exactly
        freq_combined = torch.cat([freqs_x, freqs_y, freqs_t], dim=-1)

        # Cos and Sin embedding, reshape to match x exactly
        cos_emb = freq_combined.cos().unsqueeze(1)  # [B, 1, T, hs/2]
        sin_emb = freq_combined.sin().unsqueeze(1)  # [B, 1, T, hs/2]

        # Split embedding dimension for rotation
        x1, x2 = x[..., :hs//2], x[..., hs//2:]

        # Ensure exact dimensional matching
        x_rotated = torch.cat([
            x1 * cos_emb - x2 * sin_emb,
            x1 * sin_emb + x2 * cos_emb
        ], dim=-1)

        return x_rotated


class PSIAttentionLayer(nn.Module):

    def __init__(self, config):

        super().__init__()
        assert config.n_embd % config.n_head == 0

        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        # positional embedding
        self.rope = Rotary3D(config.n_embd // config.n_head)

        # check if we are using causal attention
        if config.attention_mask == "causal":
            self.is_causal = True
        else:
            self.is_causal = False

        # check if GPU Flash Attention is available
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')

        # check if we are running on TPU
        try:
            # Use local import to avoid conflict if global xm is None and to check TPU specifically for this flag
            import torch_xla.core.xla_model as xm_local
            self.tpu = True
        except ImportError:
            self.tpu = False

        # Apply XLA sharding for model parallelism
        xla_device_available = False
        if xm is not None:
            try:
                device_kind = xm.xla_device_kind()
                if device_kind is not None:
                    xla_device_available = True
            except RuntimeError:
                pass

    @torch.compiler.disable
    def emplace_kv(self, T, k_cache, v_cache, k, v):
        # torch.compile doesn't play well with this op (5x slowdown)
        # so we insert a graph break and copy eagerly
        k_cache[:,:,-T:].copy_(k)
        v_cache[:,:,-T:].copy_(v)
        return k_cache, v_cache

    def forward(self, x, pos, k_cache=None, v_cache=None, return_kv=False, inplace_kv=False, mask=None):
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)

        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)

        # Apply rotary positional embedding
        k = self.rope(k, pos)
        q = self.rope(q, pos)

        if inplace_kv and k_cache is not None and v_cache is not None:
            # assign into kv cache in-place
            k, v = self.emplace_kv(T, k_cache, v_cache, k, v)
        else:
            # append cached keys and values with new keys and values
            if k_cache is not None:
                k = torch.cat((k_cache, k), dim=2)
            if v_cache is not None:
                v = torch.cat((v_cache, v), dim=2)

        # Apply attention
        if self.tpu:
            # (1)
            from torch_xla.experimental.custom_kernel import flash_attention
            q_norm = q / math.sqrt(k.size(-1))
            y = flash_attention(
                q_norm, k, v,
                causal=True, partition_spec=('fsdp', None, None, None))
            # (2)
            # y = torch.nn.functional.scaled_dot_product_attention(
            #     q, k, v,
            #     # dropout_p=self.dropout if self.training else 0,
            #     # attn_mask=None if mask is None else mask.to(q.dtype),
            #     is_causal=True
            # )
        elif self.flash:
            # efficient attention using Flash Attention CUDA kernels
            L, S = q.size(-2), k.size(-2)
            is_causal = self.is_causal and mask is None
            # is_causal doesn't work when not square, so replace with a manual mask if needed
            if is_causal and L < S:
                if L > 1:  # if L=1, just use no mask
                    mask = torch.ones(L, S, dtype=q.dtype, device=q.device)
                    mask.masked_fill_(mask.to(torch.bool).triu(S-L+1), float('-inf'))
                is_causal = False

            y = torch.nn.functional.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.dropout if self.training else 0,
                attn_mask=None if mask is None else mask.to(q.dtype),
                is_causal=is_causal
            )
        else:
            # manual implementation of attention
            att = torch.einsum('bnsh,bnkh->bnsk', q, k) * (1.0 / math.sqrt(k.size(-1)))
            # apply mask, or use causal if default
            if mask is not None:
                att = att + mask
            elif self.is_causal:
                L, S = q.size(-2), k.size(-2)
                mask = torch.ones(1, 1, L, S).triu(S-L+1).to(dtype=torch.bool).to(x.device)
                att.masked_fill_(mask, float('-inf'))
            # upcast to float32 for numerical stability, as per llama implementation
            att = F.softmax(att, dim=-1, dtype=torch.float32).to(q.dtype)
            att = self.attn_dropout(att)
            # multiply attention weights with values to get output
            y = torch.einsum('bnsk,bnkh->bnsh', att, v)

        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
        # output projection
        y = self.resid_dropout(self.c_proj(y))
        # return key and value caches if requested
        if return_kv:
            return y, k, v

        return y

    def kv_cache_forward(self, x, pos, k_cache=None, v_cache=None):
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)

        # Apply rotary positional embedding (before concat)
        k = self.rope(k, pos)
        q = self.rope(q, pos)

        # append cached keys and values with new keys and values
        if k_cache is not None:
            k = torch.cat((k_cache, k), dim=2)
        if v_cache is not None:
            v = torch.cat((v_cache, v), dim=2)

        # manual implementation of attention
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)

        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))

        return y, k, v


class MLP(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

        # Apply XLA sharding for model parallelism
        xla_device_available = False
        if xm is not None:
            try:
                device_kind = xm.xla_device_kind()
                if device_kind is not None:
                    xla_device_available = True
            except RuntimeError:
                pass

        if xla_device_available and xs is not None and xs.global_mesh() is not None:
            mesh = xs.global_mesh()
            if mesh.mesh_shape[1] > 1:  # If the 'model' axis has size > 1
                xs.mark_sharding(self.c_fc.weight, mesh, (1, 0))
                if self.c_fc.bias is not None:
                    xs.mark_sharding(self.c_fc.bias, mesh, (1,))
                print(f"MLP: Applied MP sharding to c_fc {mesh.mesh_shape} spec weight(1,0), bias(1,)")

                xs.mark_sharding(self.c_proj.weight, mesh, (0, 1))
                if self.c_proj.bias is not None:
                    xs.mark_sharding(self.c_proj.bias, mesh, (0,))
                print(f"MLP: Applied MP sharding to c_proj {mesh.mesh_shape} spec weight(0,1), bias(0,)")

    def forward(self, x, spmd_mesh=None):

        x = self.c_fc(x)
        x = self.gelu(x)

        if spmd_mesh is not None:
            import torch_xla.distributed.spmd.xla_sharding as xs
            xs.mark_sharding(x, spmd_mesh, (('dcn', 'data'), None, 'model'))

        x = self.c_proj(x)
        x = self.dropout(x)

        if spmd_mesh is not None:
            xs.mark_sharding(x, spmd_mesh, (('dcn', 'data'), None, 'model'))

        return x


class RMSNorm(nn.Module):
    """ Root Mean Square Normalization """
    def __init__(self, dim: int, weight: bool = True, bias: bool = False, eps: float = 1e-5):  # whl
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim)) if weight else None

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        if self.weight is not None:
            return output * self.weight
        return output


class PSIBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = RMSNorm(config.n_embd, bias=config.bias)
        self.attn = PSIAttentionLayer(config)
        self.ln_2 = RMSNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x, pos, k_cache=None, v_cache=None, return_kv=False, inplace_kv=False, spmd_mesh=None, mask=None):
        # If we are given a key and value cache, we will use the pre-computed values to minimize
        # the computation cost
        if return_kv:
            # Pass the key and value cache to the attention layer, obtain new key and value caches
            x_attn, k, v = self.attn(self.ln_1(x), pos, k_cache=k_cache, v_cache=v_cache,
                                     return_kv=True, inplace_kv=inplace_kv, mask=mask)
            x = x + x_attn
            x = x + self.mlp(self.ln_2(x))
            return x, k, v
        # Else we proceed with the regular forward pass
        x = x + self.attn(self.ln_1(x), pos, k_cache=k_cache, v_cache=v_cache, inplace_kv=inplace_kv, mask=mask)
        x = x + self.mlp(self.ln_2(x))
        return x


class PartitionedEmbedding(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, partition_size=65536):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.partition_size = partition_size
        self.num_partitions = (num_embeddings + partition_size - 1) // partition_size

        self.embedding_layers = nn.ModuleList()
        for i in range(self.num_partitions):
            start_idx = i * self.partition_size
            end_idx = min(start_idx + self.partition_size, num_embeddings)
            vocab_size = end_idx - start_idx
            self.embedding_layers.append(nn.Embedding(vocab_size, embedding_dim))

    def forward(self, input_ids):
        partition_ids = input_ids // self.partition_size
        relative_ids = input_ids % self.partition_size

        output = torch.zeros(*input_ids.shape, self.embedding_dim, device=input_ids.device, dtype=self.embedding_layers[0].weight.dtype)

        for i in range(self.num_partitions):
            mask = (partition_ids == i)
            if mask.any():
                partition_input_ids = relative_ids[mask]
                embedded = self.embedding_layers[i](partition_input_ids)
                output[mask] = embedded

        return output


class PartitionedLinear(nn.Module):
    def __init__(self, in_features, out_features, partition_size=65536, bias=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.partition_size = partition_size
        self.num_partitions = (out_features + partition_size - 1) // partition_size

        self.linear_layers = nn.ModuleList()
        for i in range(self.num_partitions):
            start_idx = i * self.partition_size
            end_idx = min(start_idx + self.partition_size, out_features)
            output_partition_size = end_idx - start_idx
            self.linear_layers.append(nn.Linear(in_features, output_partition_size, bias=bias))

    def forward(self, input):
        outputs = [layer(input) for layer in self.linear_layers]
        return torch.cat(outputs, dim=-1)
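`Rotary3D` splits each head dimension 6/16 : 6/16 : 4/16 across the x, y, and t axes, so for `head_dim=64` the concatenated frequency table has 12 + 12 + 8 = 32 = `head_dim/2` entries, matching the two rotation halves. A quick shape sanity check, assuming `modeling.py` is importable as a module:

```python
import torch
from modeling import Rotary3D

B, nh, T, hs = 2, 12, 10, 64            # batch, heads, tokens, head_dim
rope = Rotary3D(hs)
q = torch.randn(B, nh, T, hs)
pos = torch.randint(0, 16, (B, T, 3))   # integer (x, y, t) positions per token
q_rot = rope(q, pos)
assert q_rot.shape == (B, nh, T, hs)    # the rotation preserves the head shape
```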
psi.py
ADDED
|
@@ -0,0 +1,788 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PSI Model Definition
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
from typing import Tuple, Union, List, Optional, Callable, Dict
|
| 8 |
+
from transformers import PreTrainedModel
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
import numpy as np
|
| 13 |
+
import tqdm
|
| 14 |
+
|
| 15 |
+
from .config import PSIConfig
|
| 16 |
+
from .modeling import (
|
| 17 |
+
RMSNorm, PSIBlock, PartitionedEmbedding, PartitionedLinear
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
try:
|
| 21 |
+
import torch_xla.core.xla_model as xm
|
| 22 |
+
import torch_xla.distributed.spmd.xla_sharding as xs
|
| 23 |
+
except ImportError:
|
| 24 |
+
xm = None
|
| 25 |
+
xs = None
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class PSI(PreTrainedModel):
|
| 30 |
+
config_class = PSIConfig
|
| 31 |
+
|
| 32 |
+
### Initialization Functions ###
|
| 33 |
+
|
| 34 |
+
def __init__(self, config):
|
| 35 |
+
super().__init__(config)
|
| 36 |
+
self.config = config
|
| 37 |
+
|
| 38 |
+
if hasattr(config, "partition_embedding") and config.partition_embedding:
|
| 39 |
+
token_embedding = PartitionedEmbedding(config.vocab_size, config.n_embd)
|
| 40 |
+
lm_head = PartitionedLinear(config.n_embd, config.vocab_size, bias=False)
|
| 41 |
+
else:
|
| 42 |
+
token_embedding = nn.Embedding(config.vocab_size, config.n_embd)
|
| 43 |
+
if hasattr(config, "n_lm_vocab") and config.n_lm_vocab is not None:
|
| 44 |
+
n_lm_vocab = config.n_lm_vocab
|
| 45 |
+
else:
|
| 46 |
+
n_lm_vocab = config.vocab_size
|
| 47 |
+
lm_head = nn.Linear(config.n_embd, n_lm_vocab, bias=False)
|
| 48 |
+
|
| 49 |
+
self.transformer = nn.ModuleDict(dict(
|
| 50 |
+
token_embedding = token_embedding,
|
| 51 |
+
channel_embedding = nn.Embedding(config.channel_size, config.n_embd),
|
| 52 |
+
drop = nn.Dropout(config.dropout),
|
| 53 |
+
h = nn.ModuleList([PSIBlock(config) for _ in range(config.n_layer)]),
|
| 54 |
+
ln_f = RMSNorm(config.n_embd, bias=config.bias),
|
| 55 |
+
))
|
| 56 |
+
self.lm_head = lm_head
|
| 57 |
+
|
| 58 |
+
# init all weights
|
| 59 |
+
self.apply(self._init_weights)
|
| 60 |
+
# apply special scaled init to the residual projections, per GPT-2 paper
|
| 61 |
+
for pn, p in self.named_parameters():
|
| 62 |
+
if pn.endswith('c_proj.weight'):
|
| 63 |
+
torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
|
| 64 |
+
|
| 65 |
+
if hasattr(config, "tie_weights") and config.tie_weights:
|
| 66 |
+
if hasattr(config, "partition_embedding") and config.partition_embedding:
|
| 67 |
+
for i in range(len(self.transformer.token_embedding.embedding_layers)):
|
| 68 |
+
self.lm_head.linear_layers[i].weight = self.transformer.token_embedding.embedding_layers[i].weight
|
| 69 |
+
else:
|
| 70 |
+
self.lm_head.weight = self.transformer.token_embedding.weight
|
| 71 |
+
|
| 72 |
+
# Apply XLA sharding for model parallelism if on XLA and model axis > 1
|
| 73 |
+
xla_device_available = False
|
| 74 |
+
if xm is not None:
|
| 75 |
+
try:
|
| 76 |
+
device_kind = xm.xla_device_kind()
|
| 77 |
+
if device_kind is not None:
|
| 78 |
+
xla_device_available = True
|
| 79 |
+
except RuntimeError:
|
| 80 |
+
pass
|
| 81 |
+
|
| 82 |
+
self.unsharded_param_count = self.get_num_params()
|
| 83 |
+
|
| 84 |
+
def _init_weights(self, module):
|
| 85 |
+
if isinstance(module, nn.Linear):
|
| 86 |
+
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| 87 |
+
if module.bias is not None:
|
| 88 |
+
torch.nn.init.zeros_(module.bias)
|
| 89 |
+
elif isinstance(module, nn.Embedding):
|
| 90 |
+
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| 91 |
+
|
| 92 |
+
def get_num_params(self):
|
| 93 |
+
"""Return the number of parameters in the model."""
|
| 94 |
+
return sum(p.numel() for p in self.parameters())
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
### Training Functions ###
|
| 98 |
+
|
| 99 |
+
def forward(
|
| 100 |
+
self,
|
| 101 |
+
seq: torch.Tensor,
|
| 102 |
+
pos: torch.Tensor,
|
| 103 |
+
tgt: torch.Tensor = None,
|
| 104 |
+
mask: torch.Tensor = None,
|
| 105 |
+
k_cache: torch.Tensor = None,
|
| 106 |
+
v_cache: torch.Tensor = None,
|
| 107 |
+
return_kv: bool = False,
|
| 108 |
+
inplace_kv: bool = False,
|
| 109 |
+
output_hidden_states: bool = False,
|
| 110 |
+
) -> torch.Tensor:
|
| 111 |
+
"""
|
| 112 |
+
Forward pass of the model
|
| 113 |
+
|
| 114 |
+
Parameters:
|
| 115 |
+
seq (torch.Tensor) of size b, t: The input sequence
|
| 116 |
+
pos (torch.Tensor) of size b, t, d: The positional indices of the sequence of shape (batch, tokens, dimensions)
|
| 117 |
+
They consist of x, y, t and c coordinates, where x, y are the spatial coordinates of the patch,
|
| 118 |
+
t is the time index and c is the channel index
|
| 119 |
+
tgt (torch.Tensor) of size b, t_tgt: The target sequence
|
| 120 |
+
mask (torch.Tensor) of size b, t, t: The mask of the sequence
|
| 121 |
+
k_cache (torch.Tensor) of size n_layer, b, n_head, n, n_embd//n_head: A k-cache to prepend
|
| 122 |
+
v_cache (torch.Tensor) of size n_layer, b, n_head, n, n_embd//n_head: A v-cache to prepend
|
| 123 |
+
return_kv (bool): If True, returns (logits, k, v). Ignored if tgt != None
|
| 124 |
+
inplace_kv (bool): If True, k_cache/v_cache are modified in-place. They must be sufficiently large to store
|
| 125 |
+
the new tokens, and the last N tokens will be overwritten. If False (default), the input kv will not be
|
| 126 |
+
modified, and a concat operation will be used instead. No effect if k_cache/v_cache are None.
|
| 127 |
+
|
| 128 |
+
Returns:
|
| 129 |
+
torch.Tensor: The logits of the model. Size b, t if tgt is None, else b, t_tgt
|
| 130 |
+
if tgt != None:
|
| 131 |
+
torch.Tensor: The cross entropy loss between the logits and tgt
|
| 132 |
+
elif return_k:
|
| 133 |
+
torch.Tensor: the k-cache
|
| 134 |
+
torch.Tensor: the v-cache
|
| 135 |
+
"""
|
| 136 |
+
|
| 137 |
+
st_pos = pos[:, :, :-1]
|
| 138 |
+
channel_pos = pos[:, :, -1]
|
| 139 |
+
|
| 140 |
+
# forward the GPT model itself
|
| 141 |
+
tok_emb = self.transformer.token_embedding(seq) # token embeddings of shape (b, t, n_embd)
|
| 142 |
+
channel_emb = self.transformer.channel_embedding(channel_pos) # position embeddings of shape (t, n_embd)
|
| 143 |
+
x = self.transformer.drop(tok_emb + channel_emb)
|
| 144 |
+
|
| 145 |
+
if output_hidden_states:
|
| 146 |
+
hidden_states = [x]
|
| 147 |
+
|
| 148 |
+
k_list, v_list = [], []
|
| 149 |
+
for i, block in enumerate(self.transformer.h):
|
| 150 |
+
x = block(x, pos=st_pos, mask=mask,
|
| 151 |
+
k_cache=None if k_cache is None else k_cache[i],
|
| 152 |
+
v_cache=None if v_cache is None else v_cache[i],
|
| 153 |
+
return_kv=return_kv, inplace_kv=inplace_kv)
|
| 154 |
+
if return_kv:
|
| 155 |
+
x, k, v = x
|
| 156 |
+
k_list.append(k)
|
| 157 |
+
v_list.append(v)
|
| 158 |
+
if output_hidden_states:
|
| 159 |
+
hidden_states.append(x)
|
| 160 |
+
|
| 161 |
+
x = self.transformer.ln_f(x)
|
| 162 |
+
if output_hidden_states:
|
| 163 |
+
hidden_states.append(x)
|
| 164 |
+
|
| 165 |
+
# if tgt is not none, compute the logits for the entire sequence
|
| 166 |
+
if tgt is None:
|
| 167 |
+
logits = self.lm_head(x)
|
| 168 |
+
if output_hidden_states:
|
| 169 |
+
logits = {"logits": logits, "hidden_states": hidden_states}
|
| 170 |
+
if return_kv:
|
| 171 |
+
if inplace_kv:
|
| 172 |
+
# We modified in-place; avoid allocating a new tensor with torch.stack
|
| 173 |
+
return logits, k_cache, v_cache
|
| 174 |
+
else:
|
| 175 |
+
return logits, torch.stack(k_list), torch.stack(v_list)
|
| 176 |
+
return logits, None
|
| 177 |
+
|
| 178 |
+
# if tgt is not none, compute the logits and the loss for the target sequence
|
| 179 |
+
logits = self.lm_head(x[:, -tgt.size(1):])
|
| 180 |
+
if output_hidden_states:
|
| 181 |
+
logits = {"logits": logits, "hidden_states": hidden_states}
|
| 182 |
+
|
| 183 |
+
loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), tgt.reshape(-1), ignore_index=-1)
|
| 184 |
+
return logits, loss
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
### Rollout Functions ###
|
| 188 |
+
|
| 189 |
+
@torch.no_grad()
|
| 190 |
+
def sample_logits(self,
|
| 191 |
+
logits: torch.FloatTensor,
|
| 192 |
+
temp: Optional[float] = None,
|
| 193 |
+
post_temp: Optional[float] = None,
|
| 194 |
+
top_k: Optional[int] = None,
|
| 195 |
+
top_p: Optional[float] = None,
|
| 196 |
+
min_p: Optional[float] = None,
|
| 197 |
+
sample_range: Optional[Tuple[int,int]] = None,
|
| 198 |
+
blacklist: Optional[Union[List[int], torch.LongTensor]] = None
|
| 199 |
+
) -> torch.LongTensor:
|
| 200 |
+
"""
|
| 201 |
+
Samples an integer from the distribution of logits
|
| 202 |
+
|
| 203 |
+
Parameters:
|
| 204 |
+
logits (torch.FloatTensor): The logits of the distribution
|
| 205 |
+
temp (float): The temperature of the sampling, if 0.0, then argmax is used
|
| 206 |
+
top_k (int): The number of top k tokens to consider during sampling
|
| 207 |
+
top_p (float): The cumulative probability threshold for nucleus (top-p) sampling
|
| 208 |
+
min_p (float): The minimum probability threshold factor for min-p sampling
|
| 209 |
+
blacklist (Union[List[int], torch.LongTensor]): The list of tokens to blacklist during sampling
|
| 210 |
+
Returns:
|
| 211 |
+
torch.LongTensor: The sampled integers
|
| 212 |
+
"""
|
| 213 |
+
if isinstance(temp, list):
|
| 214 |
+
temp = temp[0]
|
| 215 |
+
if isinstance(post_temp, list):
|
| 216 |
+
post_temp = post_temp[0]
|
| 217 |
+
if isinstance(top_k, list):
|
| 218 |
+
top_k = top_k[0]
|
| 219 |
+
if isinstance(top_p, list):
|
| 220 |
+
top_p = top_p[0]
|
| 221 |
+
assert temp is None or temp >= 0.0
|
| 222 |
+
assert post_temp is None or post_temp >= 0.0
|
| 223 |
+
assert top_k is None or top_k > 0
|
| 224 |
+
assert top_p is None or top_p >= 0.0
|
| 225 |
+
assert min_p is None or 0.0 <= min_p <= 1.0
|
| 226 |
+
assert sample_range is None or (
|
| 227 |
+
sample_range[0] < sample_range[1] and
|
| 228 |
+
sample_range[0] >= 0 and
|
| 229 |
+
sample_range[1] <= logits.shape[-1]
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
# Apply blacklist & sample range
|
| 233 |
+
if blacklist is not None:
|
| 234 |
+
logits[...,blacklist] = float('-inf')
|
| 235 |
+
if sample_range is not None:
|
| 236 |
+
logits = logits[...,sample_range[0]:sample_range[1]]
|
| 237 |
+
token_offset = sample_range[0]
|
| 238 |
+
else:
|
| 239 |
+
token_offset = 0
|
| 240 |
+
|
| 241 |
+
# Apply temperature, or use argmax if 0.0
|
| 242 |
+
if (temp is not None and temp == 0.0) or (post_temp is not None and post_temp == 0.0):
|
| 243 |
+
return token_offset + torch.argmax(logits, dim=-1)
|
| 244 |
+
if temp is not None and temp != 1.0:
|
| 245 |
+
logits.div_(temp)
|
| 246 |
+
|
| 247 |
+
# Sort the logits once. More efficient when using top-k and top-p together (min-p doesn't require sorting).
|
| 248 |
+
# We sample in sorted order then re-order before returning.
|
| 249 |
+
if top_k is not None or top_p is not None:
|
| 250 |
+
logits, order = torch.sort(logits, dim=-1, descending=True)
|
| 251 |
+
else:
|
| 252 |
+
order = None # Don't sort
|
| 253 |
+
|
| 254 |
+
# Apply top-k filtering if specified
|
| 255 |
+
if top_k is not None:
|
| 256 |
+
logits = logits[...,:top_k]
|
| 257 |
+
|
| 258 |
+
# Apply top-p (nucleus) filtering if specified
|
| 259 |
+
if top_p is not None:
|
| 260 |
+
probs = F.softmax(logits, dim=-1) # Already sorted
|
| 261 |
+
cumulative_probs = probs.cumsum_(dim=-1)
|
| 262 |
+
idxs_to_remove = cumulative_probs > top_p
|
| 263 |
+
# Shift the mask right to keep at least one token
|
| 264 |
+
logits[...,1:][idxs_to_remove[...,:-1]] = float('-inf')
|
| 265 |
+
del probs, cumulative_probs, idxs_to_remove
|
| 266 |
+
|
| 267 |
+
# Apply min-p filtering if specified
|
| 268 |
+
if min_p is not None:
|
| 269 |
+
probs = F.softmax(logits, dim=-1)
|
| 270 |
+
maxprob = probs[...,[0]] if order is not None else torch.max(probs, dim=-1, keepdim=True).values
|
| 271 |
+
logits[probs < maxprob * min_p] = float('-inf')
|
| 272 |
+
del probs, maxprob
|
| 273 |
+
|
| 274 |
+
# Apply optional post-temperature
|
| 275 |
+
if post_temp is not None and post_temp != 1.0:
|
| 276 |
+
logits.div_(post_temp)
|
| 277 |
+
|
| 278 |
+
# Compute softmax probabilities
|
| 279 |
+
orig_shape = logits.shape
|
| 280 |
+
probs = torch.softmax(logits, dim=-1, out=logits)
|
| 281 |
+
# Flatten probabilities to (batch_size * sequence_length, vocab_size)
|
| 282 |
+
flat_probs = probs.view(-1, probs.size(-1))
|
| 283 |
+
# Sample from the distribution
|
| 284 |
+
sampled = torch.multinomial(flat_probs, num_samples=1)
|
| 285 |
+
# Reshape to original shape except for the last dimension
|
| 286 |
+
sampled = sampled.view(*orig_shape[:-1])
|
| 287 |
+
|
| 288 |
+
# If we sorted, unsort to collect the actual token values
|
| 289 |
+
if order is not None:
|
| 290 |
+
sampled = torch.gather(order, dim=-1, index=sampled.unsqueeze(-1)).squeeze(-1)
|
| 291 |
+
return token_offset + sampled
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
@torch.no_grad()
|
| 295 |
+
def rollout_patches(self,
|
| 296 |
+
seq: Union[Optional[torch.LongTensor], List[Optional[torch.LongTensor]]],
|
| 297 |
+
pos: Union[torch.LongTensor, List[torch.LongTensor]],
|
| 298 |
+
idx: torch.LongTensor,
|
| 299 |
+
n_tokens_per_patch: int = 5,
|
| 300 |
+
n_seq_patches: int = -1,
|
| 301 |
+
weights: Optional[Union[List[float], torch.Tensor]] = None,
|
| 302 |
+
k_cache: Optional[torch.Tensor] = None,
|
| 303 |
+
v_cache: Optional[torch.Tensor] = None,
|
| 304 |
+
cache_mask: Optional[torch.Tensor] = None,
|
| 305 |
+
policy: Callable[..., torch.LongTensor] = None,
|
| 306 |
+
*,
|
| 307 |
+
unmask_parallel: bool = False,
|
| 308 |
+
return_logits: bool = False,
|
| 309 |
+
return_idx_logits: bool = True,
|
| 310 |
+
return_kv: bool = False,
|
| 311 |
+
verbose: bool = True,
|
| 312 |
+
temp: Optional[Union[float, List[float]]] = None,
|
| 313 |
+
post_temp: Optional[Union[float, List[float]]] = None,
|
| 314 |
+
top_k: Optional[Union[int, List[int]]] = None,
|
| 315 |
+
top_p: Optional[Union[float, List[float]]] = None,
|
| 316 |
+
min_p: Optional[Union[float, List[float]]] = None,
|
| 317 |
+
sample_range: Optional[Tuple[int, int]] = None,
|
| 318 |
+
blacklist: Optional[Union[List[int], torch.LongTensor]] = None
|
| 319 |
+
) -> Union[
|
| 320 |
+
torch.LongTensor, # seq
|
| 321 |
+
Tuple[torch.LongTensor, torch.Tensor], # seq, logits
|
| 322 |
+
Tuple[torch.LongTensor, Dict[str, torch.Tensor]], # seq, kvcache
|
| 323 |
+
Tuple[torch.LongTensor, torch.Tensor, Dict[str, torch.Tensor]], # seq, logits, kvcache
|
| 324 |
+
]:
|
| 325 |
+
"""
|
| 326 |
+
K = number of given sequences (1 if seq is not a list)
|
| 327 |
+
T = length of a conditioning sequence (per sequence)
|
| 328 |
+
N = total length of conditioning + generated tokens (per sequence)
|
| 329 |
+
I = number of index tokens to roll out
|
| 330 |
+
max(T) = maximum of T across sequences
|
| 331 |
+
num_new_tokens = len(idx) * n_tokens_per_patch
|
| 332 |
+
|
| 333 |
+
***Tips for long rollouts***:
|
| 334 |
+
1. Use `gc.collect()` and then `torch.cuda.empty_cache()` (in that order)
|
| 335 |
+
2. Avoid fragmentation wherever possible. This method needs to allocate the entire KV cache
|
| 336 |
+
contiguously in memory. If you create tensors between rollouts, consider moving them to
|
| 337 |
+
CPU or cloning them to defragment VRAM.
|
| 338 |
+
3. Even if a single N-token rollout fits in VRAM, running two consecutive rollouts (e.g. running
|
| 339 |
+
n1 then giving its KV cache to n2, with n1 + n2 = N) may not fit, because reallocating the
|
| 340 |
+
KV cache will duplicate memory. To avoid this, try moving the cache to CPU first:
|
| 341 |
+
```
|
| 342 |
+
seq1, kvcache = predictor.rollout_patches(..., return_kv=True)
|
| 343 |
+
kvcache = { k: v.cpu() for k, v in kvcache.items() }
|
| 344 |
+
gc.collect(); torch.cuda.empty_cache()
|
| 345 |
+
seq2 = predictor.rollout_patches(..., **kvcache)
|
| 346 |
+
```
|
| 347 |
+
With this, the new KV cache will be allocated on GPU, and the CPU cache will be copied into it.
|
| 348 |
+
|
| 349 |
+
TODO: Support temp/top_k/top_p/min_p scheduling with parallel. Currently uses index -1 for all parallel tokens
|
| 350 |
+
|
| 351 |
+
Parameters:
|
| 352 |
+
seq (Union[Optional[torch.LongTensor], List[Optional[torch.LongTensor]]]):
|
| 353 |
+
[T], [K T], or list of [T] / sequence(s) to condition the generation on. None means empty sequence
|
| 354 |
+
pos (Union[torch.LongTensor, List[torch.LongTensor]]):
|
| 355 |
+
[N 4], [K N 4], or list of [N 4] / 4D position(s) corresponding to each sequence. If multiple, must have length K=len(seq)
|
| 356 |
+
idx (torch.LongTensor):
|
| 357 |
+
[I] / the patch indices to use in the rollout (same for all sequences)
|
| 358 |
+
n_tokens_per_patch (int):
|
| 359 |
+
number of tokens per patch, including patch index
|
| 360 |
+
n_seq_patches (int):
|
| 361 |
+
number of patches to roll out sequentially (-1 for all). The remaining patches will be parallel
|
| 362 |
+
weights (Optional[Union[List[float], torch.Tensor]]):
|
| 363 |
+
float weights for the logits produced by each sequence. If None and multiple sequences are given, defaults to all ones.
|
| 364 |
+
If a list or 1D tensor, must have size K. If 2D, must have shape [K num_new_tokens].
|
| 365 |
+
If 2D, the first n_seq_patches patches will use the weights in order as expected, but the remaining (parallel) patches may be arbitrarily reordered.
|
| 366 |
+
When using parallel and a 2D weight schedule, it is recommended to make the weights for parallel patches uniform for consistency
|
| 367 |
+
k_cache (Optional[torch.Tensor]):
|
| 368 |
+
optional k_cache to prepend to all seqs, broadcastable to shape [n_layer K n_head n_tok n_embd//n_head]. May be on a different device
|
| 369 |
+
v_cache (Optional[torch.Tensor]):
|
| 370 |
+
optional v_cache to prepend to all seqs, broadcastable to shape [n_layer K n_head n_tok n_embd//n_head]. May be on a different device
|
| 371 |
+
cache_mask (Optional[torch.Tensor]):
|
| 372 |
+
optional mask to be applied to the provided KV cache with shape [K 1 1 n_tok], where n_tok matches k_cache/v_cache. Useful when a KV cache is supplied
|
| 373 |
+
for multiple conditioning sequences of different lengths, where the cache_mask indicates which elements of the cache should be attended to for each sequence.
|
| 374 |
+
If k_cache/v_cache are given and cache_mask is None, the cache will be fully unmasked. May be on a different device
|
| 375 |
+
policy (Callable[..., torch.LongTensor]):
|
| 376 |
+
optional callback defining the policy for rollout order. Must accept an argument `idx` (torch.LongTensor of shape [I]) with the candidate indices to generate next.
|
| 377 |
+
Must return the *index* into the `idx` tensor to generate next, either an int or 0-dimensional torch.LongTensor. For example, given candidate indices [4,12,1023],
|
| 378 |
+
return 2 to generate the patch with index 1023 next. Only used for the sequential part of the generation. The following kwargs are given
|
| 379 |
+
- `idx` (torch.LongTensor of shape [I]) the candidate patch indices
|
| 380 |
+
- `pos` (torch.LongTensor of shape [K N 4]): the remaining poses for all yet-ungenerated tokens in the same order as `idx`
|
| 381 |
+
- `weights` (torch.Tensor of shape [K N]): the weights for all yet-ungenerated tokens, or None
|
| 382 |
+
- `k_cache` (torch.Tensor)
|
| 383 |
+
- `v_cache` (torch.Tensor)
|
| 384 |
+
- `cache_mask` (torch.Tensor of shape [K 1 1 n_tok])
|
| 385 |
+
- `kvcache` (Dict[str, torch.Tensor]): the kvcache dict with keys 'k_cache', 'v_cache', and 'cache_mask'
|
| 386 |
+
- `all_k_cache` (torch.Tensor): the entire preallocated k-cache, including uninitialized tokens
|
| 387 |
+
- `all_v_cache` (torch.Tensor): the entire preallocated v-cache, including uninitialized tokens
|
| 388 |
+
- `n_tokens_per_patch` (int)
|
| 389 |
+
- `sample_range` (Optional[Tuple[int,int]])
|
| 390 |
+
- `idx_pos` (torch.LongTensor of shape [K I 4]): same as `pos`, but only for the candidate index tokens
|
| 391 |
+
- `idx_weights` (torch.Tensor of shape [K I]): same as `weights`, but only for the candidate index tokens
|
| 392 |
+
- `device` (torch.device)\n
|
| 393 |
+
The callback must return the following
|
| 394 |
+
- (Union[int, torch.LongTensor]): the *index* into the `idx` tensor with the patch token to generate next (*not* the value of the patch index itself)
|
| 395 |
+
unmask_parallel (bool):
|
| 396 |
+
if True, all parallel patches can attend to each other. If False (default), parallel patches can only attend to themselves
|
| 397 |
+
return_logits (bool):
|
| 398 |
+
return the logits of the sequence
|
| 399 |
+
return_idx_logits (bool):
|
| 400 |
+
if True (default), returns logits that would predict index tokens, so there is one set of logits for every returned token. If False, only returns
|
| 401 |
+
logits used to sample content tokens, e.g. returns (n_tokens_per_patch - 1) sets of logits per patch. The latter may not need to compute logits
|
| 402 |
+
for all tokens, so it may be more efficient for some computations (such as patchwise entropy). Ignored if return_logits=False
|
| 403 |
+
return_kv (bool):
|
| 404 |
+
return the KV cache(s) as a dict with keys 'k_cache', 'v_cache', and 'cache_mask', useful for downstream operations. If True and return_logits=False,
|
| 405 |
+
returns (new_tokens, kvcache). If True and return_logits=True, returns (new_tokens, logits, kvcache). **All KVs are returned in patch-major order,** even if
|
| 406 |
+
the rollout is partially or fully parallel. Note that KVs from parallel prediction are not computed causally
|
| 407 |
+
|
| 408 |
+
Returns:
|
| 409 |
+
torch.LongTensor:
|
| 410 |
+
[num_new_tokens] the generated tokens only (the input sequence is not prepended)
|
| 411 |
+
torch.Tensor:
|
| 412 |
+
(optional) [n_tokens vocab_size] the logits of the sequence, where n_tokens depends on return_idx_logits
|
| 413 |
+
Dict[str, torch.Tensor]:
|
| 414 |
+
(optional) the KV cache, with the following key/value pairs
|
| 415 |
+
- `k_cache` (torch.Tensor) [n_layer K n_head n_tok n_embd//n_head]
|
| 416 |
+
- `v_cache` (torch.Tensor) [n_layer K n_head n_tok n_embd//n_head]
|
| 417 |
+
- `cache_mask` (torch.Tensor) [K 1 1 n_tok]
|
| 418 |
+
"""
|
| 419 |
+
|
| 420 |
+
#########################
|
| 421 |
+
# === Preprocessing === #
|
| 422 |
+
#########################
|
| 423 |
+
|
| 424 |
+
if not isinstance(seq, list):
|
| 425 |
+
seq = [seq] if seq is None or seq.ndim == 1 else list(seq)
|
| 426 |
+
if not isinstance(pos, list):
|
| 427 |
+
pos = [pos] if pos.ndim == 2 else list(pos)
|
| 428 |
+
|
| 429 |
+
nnt = idx.numel() * n_tokens_per_patch # num new tokens
|
| 430 |
+
device = pos[0].device
|
| 431 |
+
idtype = pos[0].dtype
|
| 432 |
+
dtype = self.lm_head.weight.dtype
|
| 433 |
+
|
| 434 |
+
if weights is not None:
|
| 435 |
+
if isinstance(weights, list):
|
| 436 |
+
weights = torch.tensor(weights, dtype=dtype, device=device)
|
| 437 |
+
if weights.ndim != 2:
|
| 438 |
+
weights = weights.unsqueeze(-1).expand(-1, nnt)
|
| 439 |
+
weights = weights.to(dtype).to(device)
|
| 440 |
+
elif len(seq) > 1:
|
| 441 |
+
weights = torch.ones(len(seq), nnt, dtype=dtype, device=device)
|
| 442 |
+
|
| 443 |
+
if n_seq_patches < 0:
|
| 444 |
+
n_seq_patches = idx.shape[0]
|
| 445 |
+
if temp is not None and not isinstance(temp, list):
|
| 446 |
+
temp = [temp] * nnt
|
| 447 |
+
if post_temp is not None and not isinstance(post_temp, list):
|
| 448 |
+
post_temp = [post_temp] * nnt
|
| 449 |
+
if top_k is not None and not isinstance(top_k, list):
|
| 450 |
+
top_k = [top_k] * nnt
|
| 451 |
+
if top_p is not None and not isinstance(top_p, list):
|
| 452 |
+
top_p = [top_p] * nnt
|
| 453 |
+
if min_p is not None and not isinstance(min_p, list):
|
| 454 |
+
min_p = [min_p] * nnt
|
| 455 |
+
|
| 456 |
+
K = len(seq)
|
| 457 |
+
I = idx.shape[0]
|
| 458 |
+
T = [0 if s is None else s.shape[0] for s in seq]
|
| 459 |
+
maxT = max(1, max(T))
|
| 460 |
+
|
| 461 |
+
tpp = n_tokens_per_patch
|
| 462 |
+
in_cache_size = 0 if k_cache is None else k_cache.shape[3]
|
| 463 |
+
n_rollout_tokens = tpp * n_seq_patches
|
| 464 |
+
n_par_patches = I - n_seq_patches
|
| 465 |
+
return_idx_logits = return_logits and return_idx_logits
|
| 466 |
+
run_last_parallel_tokens = return_idx_logits or return_kv
|
| 467 |
+
|
| 468 |
+
# Validate inputs as best we can
|
| 469 |
+
assert len(pos) == K, f'Expected seq and pos lists to have the same length, but got {K} and {len(pos)}'
|
| 470 |
+
assert idx.ndim == 1
|
| 471 |
+
if weights is not None:
|
| 472 |
+
assert weights.ndim == 2 and weights.shape == (K, nnt)
|
| 473 |
+
assert I * tpp == nnt, f'Requested {nnt} new tokens, but ({I} idx tokens) * ({tpp} tok per patch) = {I*tpp} != {nnt}'
|
| 474 |
+
assert 0 <= n_seq_patches <= I
|
| 475 |
+
assert k_cache is None or k_cache.ndim == 5
|
| 476 |
+
assert v_cache is None or v_cache.ndim == 5
|
| 477 |
+
assert (k_cache is None) == (v_cache is None)
|
| 478 |
+
assert k_cache is None or k_cache.shape[3] == v_cache.shape[3]
|
| 479 |
+
if cache_mask is not None:
|
| 480 |
+
assert cache_mask.ndim == 4 and cache_mask.shape[1] == 1 and cache_mask.shape[2] == 1
|
| 481 |
+
assert cache_mask.shape[-1] == in_cache_size, f'cache_mask ({cache_mask.shape[-1]} tokens) does not match the size of k_cache/v_cache ({in_cache_size} tokens)'
|
| 482 |
+
for i, (s, p) in enumerate(zip(seq, pos)):
|
| 483 |
+
if s is not None:
|
| 484 |
+
assert s.ndim == 1, f'Expected all sequence tensors to be 1D, but got seq[{i}].ndim={s.ndim}'
|
| 485 |
+
assert p.ndim == 2, f'Expected all position tensors to be 2D, but got pos[{i}].ndim={p.ndim}'
|
| 486 |
+
assert p.shape[1] == 4, f'Expected all position tensors have shape (*,4), but got pos[{i}].shape[1]={p.shape[1]}'
|
| 487 |
+
assert p.shape[0] == T[i] + nnt, f'Sequence {i}: With {T[i]} conditioning and {nnt} new tokens, expected pos[{i}].shape[0]={T[i]+nnt}, but got {p.shape[0]}'
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
#########################
|
| 491 |
+
# === Preallocation === #
|
| 492 |
+
#########################
|
| 493 |
+
|
| 494 |
+
# Preallocate the KV cache so we don't need to constantly resize it
|
| 495 |
+
# If we won't run the last parallel pass, we don't need to cache those toks
|
| 496 |
+
n_kvcache = in_cache_size + maxT + nnt - (0 if run_last_parallel_tokens else n_par_patches)
|
| 497 |
+
# [n_layer K n_head n_tok n_embd//n_head]
|
| 498 |
+
kv_shape = (
|
| 499 |
+
self.config.n_layer, K, self.config.n_head,
|
| 500 |
+
n_kvcache, self.config.n_embd // self.config.n_head
|
| 501 |
+
)
|
| 502 |
+
all_v_cache = torch.empty(kv_shape, dtype=dtype, device=device)
|
| 503 |
+
all_k_cache = torch.empty(kv_shape, dtype=dtype, device=device)
|
| 504 |
+
if in_cache_size > 0:
|
| 505 |
+
all_k_cache[...,:in_cache_size,:].copy_(k_cache, non_blocking=True)
|
| 506 |
+
all_v_cache[...,:in_cache_size,:].copy_(v_cache, non_blocking=True)
|
| 507 |
+
|
| 508 |
+
# Also preallocate the output logits tensor, if requested
|
| 509 |
+
if return_logits:
|
| 510 |
+
n_logits = n_rollout_tokens + (I - n_seq_patches) * (tpp if return_idx_logits else (tpp - 1))
|
| 511 |
+
all_logits = torch.empty((n_logits, self.config.vocab_size), dtype=dtype, device=device)
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
################################
|
| 515 |
+
# === Initial Forward Pass === #
|
| 516 |
+
################################
|
| 517 |
+
|
| 518 |
+
# Stack seq/pos into a batch, left-padded
|
| 519 |
+
# [K maxT]
|
| 520 |
+
seq = torch.stack([(
|
| 521 |
+
torch.zeros(maxT, dtype=idtype, device=device) if s is None else
|
| 522 |
+
F.pad(s, (maxT - s.shape[0], 0))
|
| 523 |
+
) for s in seq])
|
| 524 |
+
# [K maxN 4]
|
| 525 |
+
pos = torch.stack([F.pad(p, (0, 0, maxT - t, 0)) for t, p in zip(T, pos)])
|
| 526 |
+
|
| 527 |
+
# Build attention mask for initial forward pass [K 1 maxT maxT]
|
| 528 |
+
# Batch size K, each mask in the batch is fully causal except for the first (maxT - T) tokens, which are masked
|
| 529 |
+
mask = torch.zeros(K, 1, maxT, maxT, device=device)
|
| 530 |
+
mask.masked_fill_(torch.ones_like(mask, dtype=torch.bool).triu(1), float('-inf'))
|
| 531 |
+
for i, t in enumerate(T):
|
| 532 |
+
mask[i, ..., :maxT-t] = float('-inf')
|
| 533 |
+
# Unmask the diagonal so pad tokens can self-attend
|
| 534 |
+
# This doesn't matter with torch sdpa, but prevents NaNs with manual attention
|
| 535 |
+
# NOTE: If t==0, the diagonal is *only* pad tokens, so this will unmask the last pad token
|
| 536 |
+
# in the last row (which we use for rollouts). We re-mask this pad token in the rollout mask below
|
| 537 |
+
mask[i, 0].fill_diagonal_(0.0)
|
| 538 |
+
if k_cache is not None:
|
| 539 |
+
if cache_mask is not None:
|
| 540 |
+
mask = torch.cat([cache_mask.to(mask.device).expand((K, 1, maxT, -1)), mask], dim=-1)
|
| 541 |
+
else:
|
| 542 |
+
mask = F.pad(mask, (in_cache_size, 0, 0, 0, 0, 0, 0, 0))
|
| 543 |
+
# The above mask[:,0,-1,:] might look something like this:
|
| 544 |
+
# kv cache | sequences
|
| 545 |
+
# T[0]==3 [ T T T T T T | F F F F T T T ]
|
| 546 |
+
# T[1]==6 [ T T T T T T | F T T T T T T ]
|
| 547 |
+
# T[2]==7 [ T T T T T T | T T T T T T T ]
|
| 548 |
+
# T[3]==4 [ T T T T T T | F F F T T T T ]
|
| 549 |
+
# For one element in the batch, mask[0,0,:,:] with T[0]==3 might look like:
|
| 550 |
+
# kv cache | sequences
|
| 551 |
+
# [ T T T T T T | T F F F F F F ]
|
| 552 |
+
# [ T T T T T T | F T F F F F F ]
|
| 553 |
+
# [ T T T T T T | F F T F F F F ]
|
| 554 |
+
# [ T T T T T T | F F F T F F F ]
|
| 555 |
+
# [ T T T T T T | F F F F T F F ]
|
| 556 |
+
# [ T T T T T T | F F F F T T F ]
|
| 557 |
+
# [ T T T T T T | F F F F T T T ]
|
| 558 |
+
# If a custom cache_mask is given, the kv cache part above may be different
|
| 559 |
+
|
| 560 |
+
# Initial forward pass (conditioning sequences only)
|
| 561 |
+
k_cache = all_k_cache[...,:in_cache_size+maxT,:]
|
| 562 |
+
v_cache = all_v_cache[...,:in_cache_size+maxT,:]
|
| 563 |
+
self.forward(
|
| 564 |
+
seq=seq, pos=pos[:,:maxT], mask=mask,
|
| 565 |
+
k_cache=k_cache, v_cache=v_cache, inplace_kv=True
|
| 566 |
+
)
|
| 567 |
+
pos = pos[:,maxT:]
|
| 568 |
+
|
| 569 |
+
|
| 570 |
+
##############################
# === Sequential Rollout === #
##############################

# Build attention mask for rollout [K 1 1 in_cache_size+maxT+n_rollout_tokens], clone to free memory
mask = F.pad(mask[...,[-1],:].clone(), (0, n_rollout_tokens, 0, 0, 0, 0, 0, 0))
for i, t in enumerate(T):
    # If t==0, the fill_diagonal call above unmasked the last pad token,
    # so we need to re-mask it before we start rolling out
    if t == 0:
        mask[i, ..., in_cache_size+maxT-1] = float('-inf')
# The above mask[:,0,0,:] might look something like this:
#              kv cache   |   sequences   |        rollout
# T[0]==3  [ T T T T T T | F F F F T T T | T T T T ... T T T T ]
# T[1]==6  [ T T T T T T | F T T T T T T | T T T T ... T T T T ]
# T[2]==7  [ T T T T T T | T T T T T T T | T T T T ... T T T T ]
# T[3]==4  [ T T T T T T | F F F T T T T | T T T T ... T T T T ]
# We construct this mask once, then slice off part of the right side at each rollout step

rollout_seq = []

# Rollout
for i in tqdm.tqdm(range(n_rollout_tokens), desc='Rollout', unit='tok', disable=(not verbose or n_rollout_tokens==0)):
    if (i % tpp) == 0:
        patch_number = i // tpp
        if policy is None:
            # Use provided order
            next_token = idx[patch_number]
        else:
            # Use callback to select the next patch
            policy_cache_mask = mask[..., :in_cache_size+maxT+i]
            idx_of_next_idx = policy(
                idx=idx[patch_number:],                                       # [N]
                pos=pos[:, i:],                                               # [K N 4]
                weights=None if weights is None else weights[:, i:],          # [K N]
                k_cache=k_cache,
                v_cache=v_cache,
                cache_mask=policy_cache_mask,                                 # [K 1 1 n_tok]
                kvcache=dict(k_cache=k_cache, v_cache=v_cache, cache_mask=policy_cache_mask),
                all_k_cache=all_k_cache,
                all_v_cache=all_v_cache,
                n_tokens_per_patch=n_tokens_per_patch,
                sample_range=sample_range,
                idx_pos=pos[:, i::tpp],                                       # [K I 4]
                idx_weights=None if weights is None else weights[:, i::tpp],  # [K I]
                device=idx.device,
            )
            del policy_cache_mask
            # Move the patch patch_number+idx_of_next_idx to the next position by swapping
            if idx_of_next_idx != 0:  # Don't bother if it's already next
                i1, i2 = patch_number, patch_number + int(idx_of_next_idx)
                idx[[i1,i2]] = idx[[i2,i1]]
                # [K N 4] -> [K I tpp 4] -swap-idxs-> [K I tpp 4] -> [K N 4]
                pos = pos.reshape(K, I, tpp, 4)
                pos[:,[i1,i2]] = pos[:,[i2,i1]]
                pos = pos.reshape(K, -1, 4)
            next_token = idx[patch_number]
        rollout_seq.append(next_token)

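    # A minimal sketch (illustrative only, not part of this file) of a policy
    # callback compatible with the call above: it receives the remaining patch
    # index tokens and cache state as keyword arguments and returns an offset
    # into idx[patch_number:], where 0 keeps the provided order:
    #
    #     def keep_order_policy(idx, **kwargs):
    #         return 0
    #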
    # Forward call
    k_cache = all_k_cache[...,:in_cache_size+maxT+i+1,:]
    v_cache = all_v_cache[...,:in_cache_size+maxT+i+1,:]
    logits, _ = self.forward(
        seq=next_token.expand(K, 1),              # [K 1]
        pos=pos[:, [i]],                          # [K 1 4]
        mask=mask[..., :in_cache_size+maxT+i+1],  # [K 1 1 n_prev+1]
        k_cache=k_cache,
        v_cache=v_cache,
        inplace_kv=True
    )

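    # NOTE: The cache slices above are views into the preallocated buffers, so
    # with inplace_kv=True each step writes the new token's K/V in place and
    # the usable cache grows by one position per token, with no reallocation.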
    # Weighted sum of logits [K 1 V] -> [1 V]
    if weights is not None:
        w = weights[:,[i]]  # [K nnt] -> [K 1]
        logits = w.T @ logits.squeeze(1)  # [1 V]
    else:
        logits = logits.squeeze(0)  # [1 1 V] -> [1 V]

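    # The weighted branch is a logit-space ensemble over the K conditioning
    # sequences. A standalone shape sketch (values illustrative only):
    #
    #     import torch
    #     K, V = 4, 96256                  # V matches config.vocab_size here
    #     logits = torch.randn(K, 1, V)
    #     w = torch.full((K, 1), 1.0 / K)  # e.g. uniform weights
    #     mixed = w.T @ logits.squeeze(1)  # [1 K] @ [K V] -> [1 V]
    #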
    # If next is an index token, we just needed to cache the previous token
    if ((i + 1) % tpp) == 0:
        if return_idx_logits:
            all_logits[i].copy_(logits.squeeze())
        continue
    if return_logits:
        all_logits[i].copy_(logits.squeeze())

    # Sample from logits
    next_token = self.sample_logits(
        logits,
        temp=temp[i] if temp else None,
        post_temp=post_temp[i] if post_temp else None,
        top_k=top_k[i] if top_k else None,
        top_p=top_p[i] if top_p else None,
        min_p=min_p[i] if min_p else None,
        sample_range=sample_range,
        blacklist=blacklist
    ).squeeze(0)
    rollout_seq.append(next_token)

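# Since temp/post_temp/top_k/top_p/min_p are indexed per token above, each may
# be a schedule of length nnt. A hypothetical schedule (illustrative only):
#
#     n = 64  # e.g. number of rollout tokens
#     temp = [1.0 - 0.5 * (i / max(n - 1, 1)) for i in range(n)]  # 1.0 -> 0.5
#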
if n_rollout_tokens > 0:
    rollout_seq = torch.stack(rollout_seq)
if n_rollout_tokens == nnt:
    ret = (rollout_seq,)
    if return_logits:
        ret = (*ret, all_logits)
    if return_kv:
        ret = (*ret, dict(k_cache=all_k_cache, v_cache=all_v_cache, cache_mask=mask))
    return ret if len(ret) > 1 else ret[0]

############################
# === Parallel Rollout === #
############################

npp = n_par_patches
npt = npp * tpp  # num parallel tokens
idx = idx[-npp:]

# Build attention mask for parallel part [K 1 npp n_past+npp]
if unmask_parallel:
    par_mask = torch.zeros(1, dtype=mask.dtype, device=device).expand(K, 1, npp, npp)
else:
    par_mask = torch.full((npp, npp), float('-inf'), dtype=mask.dtype, device=device)
    par_mask.fill_diagonal_(0.0)
    par_mask = par_mask.expand(K, 1, npp, npp)
mask = mask.expand(K, 1, npp, -1)
# Initial shape is [K 1 npp n_past]. Before each iter, we append par_mask [K 1 npp npp] along the last dim

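# For npp=3 the diagonal-only par_mask (unmask_parallel=False) looks like:
#
#     [   0. -inf -inf ]
#     [ -inf   0. -inf ]
#     [ -inf -inf   0. ]
#
# i.e. each parallel token attends only to itself within the newly appended
# block; earlier positions remain governed by the expanded past mask.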
# Reshape positions for parallel passes
# [K nnt 4] -trim-> [K npt 4] -> [K npp tpp 4] -> [K tpp npp 4] -> [K npt 4]
pos = pos[:,-npt:].reshape(K, npp, tpp, 4).transpose(1, 2).reshape(K, npt, 4)

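# The transpose interleaves positions token-major across patches, so each pass
# of the parallel loop consumes one token from every patch. A quick sketch:
#
#     import torch
#     x = torch.arange(6).reshape(2, 3)   # [npp=2, tpp=3], patch-major
#     print(x.transpose(0, 1).flatten())  # tensor([0, 3, 1, 4, 2, 5])
#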
# Reshape scheduled properties (transpose similarly to positions)
# [nnt] -trim-> [npt] -> [npp tpp] -> [tpp npp] -> [npt]
# NOTE: numpy's ndarray.transpose takes an axis *order*, so the
# [npp tpp] -> [tpp npp] swap needs transpose(1, 0); (0, 1) would be a no-op
if temp is not None:
    temp = np.array(temp[-npt:]).reshape(npp, tpp).transpose(1, 0).flatten().tolist()
if post_temp is not None:
    post_temp = np.array(post_temp[-npt:]).reshape(npp, tpp).transpose(1, 0).flatten().tolist()
if top_k is not None:
    top_k = np.array(top_k[-npt:]).reshape(npp, tpp).transpose(1, 0).flatten().tolist()
if top_p is not None:
    top_p = np.array(top_p[-npt:]).reshape(npp, tpp).transpose(1, 0).flatten().tolist()
if min_p is not None:
    min_p = np.array(min_p[-npt:]).reshape(npp, tpp).transpose(1, 0).flatten().tolist()
if weights is not None:
    # [K nnt] -trim-> [K npt] -> [K npp tpp] -> [K tpp npp] -> [K npt]
    weights = weights[:,-npt:].reshape(K, npp, tpp).transpose(1, 2).reshape(K, npt)

next_tokens = idx
parallel_seq = [next_tokens]

# Run parallel passes
for i in range(tpp if run_last_parallel_tokens else (tpp - 1)):
    mask = torch.cat([mask, par_mask], dim=-1)
    parallel_slice = slice(i*npp, (i+1)*npp)
    k_cache = all_k_cache[...,:in_cache_size+maxT+n_rollout_tokens+(i+1)*npp,:]
    v_cache = all_v_cache[...,:in_cache_size+maxT+n_rollout_tokens+(i+1)*npp,:]
    logits, _ = self.forward(
        seq=next_tokens.expand(K, npp),  # [K npp]
        pos=pos[:,parallel_slice],       # [K npp 4]
        mask=mask,                       # [K 1 npp n_past+npp]
        k_cache=k_cache,
        v_cache=v_cache,
        inplace_kv=True
    )

    # Weighted sum of logits [K npp V] -> [npp V]
    if weights is not None:
        w = weights[:,parallel_slice]  # [K npt] -> [K npp]
        logits = (logits * w.unsqueeze(-1)).sum(0)  # [K npp V] -> [npp V]
    else:
        logits = logits.squeeze(0)  # [1 npp V] -> [npp V]

    if return_logits and (i < tpp - 1 or return_idx_logits):
        # Store the logits with a stride so we don't need to transpose later
        # NOTE: If return_idx_logits=False, we have tpp-1 instead of tpp
        stride = tpp if return_idx_logits else (tpp - 1)
        all_logits[n_rollout_tokens+i::stride].copy_(logits)
    if i == (tpp - 1):
        # We just needed to compute logits and/or KV to return them; no need to predict
        break

    # Sample from logits
    # TODO: Index using parallel_slice instead of -1 to support scheduled parameters
    next_tokens = self.sample_logits(
        logits,
        temp=temp[-1] if temp else None,
        post_temp=post_temp[-1] if post_temp else None,
        top_k=top_k[-1] if top_k else None,
        top_p=top_p[-1] if top_p else None,
        min_p=min_p[-1] if min_p else None,
        sample_range=sample_range,
        blacklist=blacklist
    )
    parallel_seq.append(next_tokens)

# [tpp npp] -> [npp tpp] -> [npt]
parallel_seq = torch.stack(parallel_seq).transpose(0, 1).flatten()
if return_kv:
    # Transpose only the last npt (num parallel tokens) cache entries back to patch-major order
    # [... npt E] -> [... tpp npp E] -> [... npp tpp E] -> [... npt E]
    edims = all_k_cache.shape[:-2]
    par_k = all_k_cache[...,-npt:,:].reshape(*edims, tpp, npp, -1).transpose(-2, -3).reshape(*edims, npt, -1)
    par_v = all_v_cache[...,-npt:,:].reshape(*edims, tpp, npp, -1).transpose(-2, -3).reshape(*edims, npt, -1)
    all_k_cache[...,-npt:,:].copy_(par_k.clone())
    all_v_cache[...,-npt:,:].copy_(par_v.clone())
    del par_k, par_v
    # [K 1 npp N] -> [K 1 1 N], clone to free memory; it doesn't matter which
    # row we take (rows only differ in the last npt cols)
    # Unmask the last npt cols (if necessary) to make the parallel part fully unmasked
    mask = mask[...,[-1],:].clone()
    if not unmask_parallel:
        mask[...,-npt:] = 0.0

ret = (torch.cat([rollout_seq, parallel_seq]),) if n_rollout_tokens > 0 else (parallel_seq,)
if return_logits:
    ret = (*ret, all_logits)
if return_kv:
    ret = (*ret, dict(k_cache=all_k_cache, v_cache=all_v_cache, cache_mask=mask))
return ret if len(ret) > 1 else ret[0]
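# A hypothetical caller-side sketch of the return convention above (the entry
# point name is assumed; the tuple grows with return_logits / return_kv, and a
# single element is returned unwrapped):
#
#     out = model.generate(..., return_logits=True, return_kv=True)
#     seq, logits, kv = out
#     k, v, cm = kv['k_cache'], kv['v_cache'], kv['cache_mask']
#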