|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import math |
|
|
import importlib |
|
|
|
|
|
try: |
|
|
xm = importlib.import_module('torch_xla.core.xla_model') |
|
|
xs = importlib.import_module('torch_xla.distributed.spmd.xla_sharding') |
|
|
except ImportError: |
|
|
xm = None |
|
|
xs = None |
|
|
|
|
|
|
|
|
class Rotary3D(nn.Module):
    """Rotary position embedding (RoPE) driven by 3-D positions (x, y, t).

    The ``dim // 2`` rotation frequencies of a head are split 6:6:4 across the
    x, y and t axes, which is why ``dim`` must be a multiple of 16. Each axis
    uses the geometric frequency ladder ``base ** (-2i / axis_dim)``.

    Args:
        dim: per-head feature size the embedding is applied to.
        base: RoPE frequency base (default 100).
    """

    def __init__(self, dim, base=100):
        super().__init__()
        assert dim % 16 == 0, "Embedding dim must be divisible by 16"

        # 6/16 of the features go to x, 6/16 to y, the remaining 4/16 to t.
        self.x_dim = (6 * dim) // 16
        self.y_dim = (6 * dim) // 16
        self.t_dim = dim - self.x_dim - self.y_dim

        # One inverse frequency per rotated feature pair on each axis.
        self.register_buffer('inv_freq_x', self._inv_freq(self.x_dim, base))
        self.register_buffer('inv_freq_y', self._inv_freq(self.y_dim, base))
        self.register_buffer('inv_freq_t', self._inv_freq(self.t_dim, base))

    @staticmethod
    def _inv_freq(axis_dim, base):
        """Return the ``axis_dim // 2`` float32 inverse frequencies for one axis."""
        return 1.0 / (base ** (torch.arange(0, axis_dim, 2).float() / axis_dim))

    def forward(self, x, pos):
        """Rotate ``x`` by angles derived from ``pos``.

        Args:
            x: [batch, nh, seq_len, head_dim]; ``head_dim`` should equal the
                ``dim`` this module was constructed with.
            pos: [batch, seq_len, 3] positions along (x, y, t).

        Returns:
            Tensor with the same shape and dtype as ``x``.
        """
        B, nh, T, hs = x.shape
        assert pos.shape[-1] == 3, "Position tensor must have shape [batch, seq_len, 3]"
        assert hs % 2 == 0, "head_dim (hs) must be divisible by 2 for rotary embedding."

        # Compute angles in at least float32. If the module was cast to a
        # low-precision dtype (e.g. bfloat16), the buffers follow it, and
        # positions above a few hundred would no longer be exact.
        freq_dtype = torch.promote_types(self.inv_freq_x.dtype, torch.float32)
        inv_freqs = (self.inv_freq_x, self.inv_freq_y, self.inv_freq_t)
        freq_combined = torch.cat([
            torch.einsum('bt,f->btf', pos[..., axis].to(freq_dtype), inv_freq.to(freq_dtype))
            for axis, inv_freq in enumerate(inv_freqs)
        ], dim=-1)

        # [B, T, hs/2] -> [B, 1, T, hs/2] so the angles broadcast over heads.
        cos_emb = freq_combined.cos().unsqueeze(1)
        sin_emb = freq_combined.sin().unsqueeze(1)

        # "Rotate-half" layout: feature i is paired with feature i + hs/2.
        half = hs // 2
        x1, x2 = x[..., :half], x[..., half:]
        x_rotated = torch.cat([
            x1 * cos_emb - x2 * sin_emb,
            x1 * sin_emb + x2 * cos_emb,
        ], dim=-1)

        # Return in the caller's dtype so rotated q/k keep the same dtype as v.
        return x_rotated.to(x.dtype)
|
|
|
|
|
|
|
|
class PSIAttentionLayer(nn.Module):
    """Multi-head self-attention with 3-D rotary position embeddings.

    One of three attention back-ends is chosen per call:

    * the torch_xla flash-attention kernel, when an XLA device was detected
      at construction time, no explicit ``mask`` is given and query/key
      lengths match (the kernel takes no additive mask and no KV offset);
    * ``F.scaled_dot_product_attention`` when it is available;
    * an explicit einsum/softmax fallback otherwise.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0

        # Fused projection producing q, k and v with one matmul.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # Output projection back to the model width.
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout

        self.rope = Rotary3D(config.n_embd // config.n_head)

        self.is_causal = config.attention_mask == "causal"

        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')

        # Use the XLA kernel only when an XLA device is actually reported,
        # not merely because torch_xla happens to be importable.
        self.tpu = self._xla_device_available()

    @staticmethod
    def _xla_device_available():
        """Return True when torch_xla was imported and reports a device kind."""
        if xm is None:
            return False
        try:
            return xm.xla_device_kind() is not None
        except RuntimeError:
            return False

    @staticmethod
    def _future_mask(L, S, device):
        """Boolean [L, S] mask, True where query i must not see key j.

        Queries are aligned with the last L of the S keys (the first S - L
        keys are cached history), so key j is hidden when j > i + (S - L).
        """
        return torch.ones(L, S, dtype=torch.bool, device=device).triu(S - L + 1)

    @torch.compiler.disable
    def emplace_kv(self, T, k_cache, v_cache, k, v):
        """Copy k/v into the last T slots of the caches in place; return the caches.

        Kept out of torch.compile graphs by the decorator.
        """
        k_cache[:,:,-T:].copy_(k)
        v_cache[:,:,-T:].copy_(v)
        return k_cache, v_cache

    def _project_qkv(self, x, pos):
        """Project x [B, T, C] to per-head q, k, v [B, nh, T, hs]; apply RoPE to q and k."""
        B, T, C = x.size()
        hs = C // self.n_head
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, hs).transpose(1, 2)
        q = q.view(B, T, self.n_head, hs).transpose(1, 2)
        v = v.view(B, T, self.n_head, hs).transpose(1, 2)
        k = self.rope(k, pos)
        q = self.rope(q, pos)
        return q, k, v

    def _xla_flash_attention(self, q, k, v):
        """Attention via torch_xla's flash-attention kernel (no dropout is passed)."""
        # flash_attention is a function inside the custom_kernel module; the
        # module object itself is not callable.
        flash_attention = importlib.import_module(
            'torch_xla.experimental.custom_kernel').flash_attention
        # Scale q up front; the call passes no softmax-scale argument.
        q_norm = q / math.sqrt(k.size(-1))
        return flash_attention(
            q_norm, k, v,
            causal=self.is_causal, partition_spec=('fsdp', None, None, None))

    def _sdpa_attention(self, q, k, v, mask):
        """Attention via F.scaled_dot_product_attention."""
        L, S = q.size(-2), k.size(-2)
        is_causal = self.is_causal and mask is None
        if is_causal and L < S:
            # SDPA's built-in causal mask is not offset for cached keys, so
            # build an explicit one. A single query may see every key.
            if L > 1:
                blocked = self._future_mask(L, S, q.device)
                mask = torch.zeros(L, S, dtype=q.dtype, device=q.device)
                mask.masked_fill_(blocked, float('-inf'))
            is_causal = False
        return F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.dropout if self.training else 0,
            attn_mask=None if mask is None else mask.to(q.dtype),
            is_causal=is_causal,
        )

    def _manual_attention(self, q, k, v, mask):
        """Reference einsum/softmax attention used when SDPA is unavailable."""
        att = torch.einsum('bnsh,bnkh->bnsk', q, k) * (1.0 / math.sqrt(k.size(-1)))
        if mask is not None:
            # mask is used as an additive bias on the attention logits.
            att = att + mask
        elif self.is_causal:
            L, S = q.size(-2), k.size(-2)
            att.masked_fill_(self._future_mask(L, S, att.device), float('-inf'))
        # Softmax in float32 for stability, then back to the activation dtype.
        att = F.softmax(att, dim=-1, dtype=torch.float32).to(q.dtype)
        att = self.attn_dropout(att)
        return torch.einsum('bnsk,bnkh->bnsh', att, v)

    def forward(self, x, pos, k_cache=None, v_cache=None, return_kv=False, inplace_kv=False, mask=None):
        """Self-attention over x, optionally extending a key/value cache.

        Args:
            x: [B, T, C] input activations.
            pos: positions forwarded to ``Rotary3D`` ([B, T, 3]).
            k_cache, v_cache: optional cached keys/values [B, nh, S_cache, hs].
            return_kv: also return the full (cache + new) keys and values.
            inplace_kv: when both caches are given, write the new k/v into the
                last T cache slots instead of concatenating.
            mask: optional additive attention mask broadcastable to
                [B, nh, T, S]; when given, it replaces the causal mask.

        Returns:
            y of shape [B, T, C], or (y, k, v) when ``return_kv`` is set.
        """
        B, T, C = x.size()
        q, k, v = self._project_qkv(x, pos)

        if inplace_kv and k_cache is not None and v_cache is not None:
            k, v = self.emplace_kv(T, k_cache, v_cache, k, v)
        else:
            if k_cache is not None:
                k = torch.cat((k_cache, k), dim=2)
            if v_cache is not None:
                v = torch.cat((v_cache, v), dim=2)

        L, S = q.size(-2), k.size(-2)
        if self.tpu and mask is None and L == S:
            y = self._xla_flash_attention(q, k, v)
        elif self.flash:
            y = self._sdpa_attention(q, k, v, mask)
        else:
            y = self._manual_attention(q, k, v, mask)

        # Re-assemble heads: [B, nh, T, hs] -> [B, T, C].
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_dropout(self.c_proj(y))

        if return_kv:
            return y, k, v
        return y

    def kv_cache_forward(self, x, pos, k_cache=None, v_cache=None):
        """Explicit-attention forward that appends to and returns the KV cache.

        Returns:
            (y, k, v) where k/v include any cached entries.
        """
        B, T, C = x.size()
        q, k, v = self._project_qkv(x, pos)

        if k_cache is not None:
            k = torch.cat((k_cache, k), dim=2)
        if v_cache is not None:
            v = torch.cat((v_cache, v), dim=2)

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        if self.is_causal:
            # Hide future keys for multi-token prefill. For a single new token
            # the mask is all-False, so decoding is unaffected.
            L, S = q.size(-2), k.size(-2)
            att = att.masked_fill(self._future_mask(L, S, att.device), float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        y = att @ v

        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_dropout(self.c_proj(y))

        return y, k, v
|
|
|
|
|
|
|
|
class MLP(nn.Module):
    """Feed-forward block: Linear(C -> 4C) -> GELU -> Linear(4C -> C) -> Dropout.

    When torch_xla reports a device and a global SPMD mesh exists whose second
    axis spans more than one device, both projections are annotated for
    model-parallel sharding at construction time.
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        hidden = 4 * width
        self.c_fc = nn.Linear(width, hidden, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden, width, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

        # Guard clauses: bail out unless SPMD model-parallel sharding applies.
        if not self._xla_device_available():
            return
        if xs is None or xs.global_mesh() is None:
            return
        mesh = xs.global_mesh()
        if mesh.mesh_shape[1] <= 1:
            return

        fc, proj = self.c_fc, self.c_proj
        xs.mark_sharding(fc.weight, mesh, (1, 0))
        if fc.bias is not None:
            xs.mark_sharding(fc.bias, mesh, (1,))
        print(f"MLP: Applied MP sharding to c_fc {mesh.mesh_shape} spec weight(1,0), bias(1,)")

        xs.mark_sharding(proj.weight, mesh, (0, 1))
        if proj.bias is not None:
            xs.mark_sharding(proj.bias, mesh, (0,))
        print(f"MLP: Applied MP sharding to c_proj {mesh.mesh_shape} spec weight(0,1), bias(0,)")

    @staticmethod
    def _xla_device_available():
        """Return True when torch_xla was imported and reports a device kind."""
        if xm is None:
            return False
        try:
            return xm.xla_device_kind() is not None
        except RuntimeError:
            return False

    def forward(self, x, spmd_mesh=None):
        """Apply the MLP; if ``spmd_mesh`` is given, annotate activations for SPMD."""
        activation_spec = (('dcn', 'data'), None, 'model')

        hidden = self.gelu(self.c_fc(x))
        if spmd_mesh is not None:
            xs.mark_sharding(hidden, spmd_mesh, activation_spec)

        out = self.dropout(self.c_proj(hidden))
        if spmd_mesh is not None:
            xs.mark_sharding(out, spmd_mesh, activation_spec)

        return out
|
|
|
|
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    Computes ``x / sqrt(mean(x**2) + eps)`` in float32, casts back to the
    input dtype and, if enabled, multiplies by a learnable per-feature weight.

    Note: ``bias`` is accepted for signature compatibility, but no bias
    parameter is created or applied.
    """

    def __init__(self, dim: int, weight: bool = True, bias: bool = False, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        if weight:
            self.weight = nn.Parameter(torch.ones(dim))
        else:
            self.weight = None

    def _norm(self, x):
        """Scale x so its root-mean-square over the last dim is ~1."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        """Normalize x in float32, return in x's dtype (times weight if present)."""
        normed = self._norm(x.float()).type_as(x)
        return normed if self.weight is None else normed * self.weight
|
|
|
|
|
|
|
|
class PSIBlock(nn.Module):
    """Pre-norm transformer block: x + attn(ln_1(x)), then x + mlp(ln_2(x))."""

    def __init__(self, config):
        super().__init__()
        self.ln_1 = RMSNorm(config.n_embd, bias=config.bias)
        self.attn = PSIAttentionLayer(config)
        self.ln_2 = RMSNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x, pos, k_cache=None, v_cache=None, return_kv=False, inplace_kv=False, spmd_mesh=None, mask=None):
        """Run attention and MLP sub-layers with residual connections.

        Args:
            x: block input, forwarded to the attention layer.
            pos: positions forwarded to the attention layer's rotary embedding.
            k_cache, v_cache, inplace_kv, mask: forwarded to ``self.attn``.
            return_kv: also return the attention layer's keys and values.
            spmd_mesh: optional SPMD mesh, forwarded to ``self.mlp`` so its
                activations get sharding annotations (previously dropped).

        Returns:
            The block output, or (output, k, v) when ``return_kv`` is set.
        """
        attn_out = self.attn(self.ln_1(x), pos, k_cache=k_cache, v_cache=v_cache,
                             return_kv=return_kv, inplace_kv=inplace_kv, mask=mask)
        if return_kv:
            attn_out, k, v = attn_out

        x = x + attn_out
        x = x + self.mlp(self.ln_2(x), spmd_mesh=spmd_mesh)

        if return_kv:
            return x, k, v
        return x
|
|
|
|
|
|
|
|
class PartitionedEmbedding(nn.Module):
    """Embedding table split row-wise into tables of at most ``partition_size`` rows.

    Token id ``i`` is looked up in table ``i // partition_size`` at row
    ``i % partition_size``. Ids whose table index falls outside
    ``[0, num_partitions)`` are left as all-zero vectors.
    """

    def __init__(self, num_embeddings, embedding_dim, partition_size=65536):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.partition_size = partition_size
        # Ceiling division: the last table may be smaller than partition_size.
        self.num_partitions = (num_embeddings + partition_size - 1) // partition_size

        starts = (i * partition_size for i in range(self.num_partitions))
        self.embedding_layers = nn.ModuleList(
            nn.Embedding(min(start + partition_size, num_embeddings) - start, embedding_dim)
            for start in starts
        )

    def forward(self, input_ids):
        """Look up each id in its partition; returns [*input_ids.shape, embedding_dim]."""
        table_idx = input_ids // self.partition_size
        local_ids = input_ids % self.partition_size

        result = torch.zeros(
            *input_ids.shape, self.embedding_dim,
            device=input_ids.device,
            dtype=self.embedding_layers[0].weight.dtype,
        )

        for idx, table in enumerate(self.embedding_layers):
            selected = table_idx == idx
            if not selected.any():
                continue
            result[selected] = table(local_ids[selected])

        return result
|
|
|
|
|
|
|
|
class PartitionedLinear(nn.Module):
    """Linear layer whose output features are split across several nn.Linear shards.

    Shard ``i`` produces output features
    ``[i * partition_size, min((i + 1) * partition_size, out_features))``;
    the shard outputs are concatenated along the last dimension.
    """

    def __init__(self, in_features, out_features, partition_size=65536, bias=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.partition_size = partition_size
        # Ceiling division: the last shard may be narrower than partition_size.
        self.num_partitions = (out_features + partition_size - 1) // partition_size

        shard_widths = [
            min(start + partition_size, out_features) - start
            for start in (i * partition_size for i in range(self.num_partitions))
        ]
        self.linear_layers = nn.ModuleList(
            nn.Linear(in_features, width, bias=bias) for width in shard_widths
        )

    def forward(self, input):
        """Apply every shard to ``input`` and concatenate the results on the last dim."""
        return torch.cat([shard(input) for shard in self.linear_layers], dim=-1)
|
|
|
|
|
|