# PSI / modeling.py
# Uploaded by TheTrueJard ("Upload folder using huggingface_hub"), commit e3d287d (verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import importlib
try:
xm = importlib.import_module('torch_xla.core.xla_model')
xs = importlib.import_module('torch_xla.distributed.spmd.xla_sharding')
except ImportError:
xm = None
xs = None
class Rotary3D(nn.Module):
    """Rotary positional embedding driven by three position axes (x, y, t).

    The rotated channel width ``dim`` is divided into three frequency bands:
    6/16 of the channels follow the x position, 6/16 follow y, and the
    remaining channels follow t. Each band owns an inverse-frequency table
    ``1 / base ** (2i / band_width)``, stored as a buffer.
    """

    def __init__(self, dim, base=100):
        super().__init__()
        assert dim % 16 == 0, "Embedding dim must be divisible by 16"
        # Band widths must add up exactly to dim, which must equal the
        # head_dim of the tensors passed to forward().
        self.x_dim = (6 * dim) // 16
        self.y_dim = (6 * dim) // 16
        self.t_dim = dim - self.x_dim - self.y_dim

        def inverse_frequencies(width):
            # One frequency per rotated channel pair inside a band.
            exponents = torch.arange(0, width, 2).float() / width
            return 1.0 / (base ** exponents)

        self.register_buffer('inv_freq_x', inverse_frequencies(self.x_dim))
        self.register_buffer('inv_freq_y', inverse_frequencies(self.y_dim))
        self.register_buffer('inv_freq_t', inverse_frequencies(self.t_dim))

    def forward(self, x, pos):
        """Rotate ``x`` by angles derived from its per-token 3D positions.

        Args:
            x: [batch, nh, seq_len, head_dim]; head_dim must equal the ``dim``
                this module was built with.
            pos: [batch, seq_len, 3] integer positions along (x, y, t).

        Returns:
            A tensor shaped like ``x``. Channel ``i`` of the first half and
            channel ``i`` of the second half are rotated together by angle
            ``pos_axis * inv_freq``. The angles are computed in the dtype of
            the ``inv_freq_*`` buffers.
        """
        _, _, _, head_dim = x.shape
        assert pos.shape[-1] == 3, "Position tensor must have shape [batch, seq_len, 3]"
        assert head_dim % 2 == 0, "head_dim (hs) must be divisible by 2 for rotary embedding."

        work_dtype = self.inv_freq_x.dtype
        tables = (self.inv_freq_x, self.inv_freq_y, self.inv_freq_t)
        # Angles per axis: [B, T, band_width/2]; concatenated they span head_dim/2.
        angles = torch.cat(
            [
                torch.einsum('bt,f -> btf', pos[..., axis].to(work_dtype), table)
                for axis, table in enumerate(tables)
            ],
            dim=-1,
        )

        cos = angles.cos().unsqueeze(1)  # [B, 1, T, head_dim/2], broadcast over heads
        sin = angles.sin().unsqueeze(1)

        half = head_dim // 2
        first, second = x[..., :half], x[..., half:]
        return torch.cat((first * cos - second * sin, first * sin + second * cos), dim=-1)
class PSIAttentionLayer(nn.Module):
def __init__(self, config):
super().__init__()
assert config.n_embd % config.n_head == 0
# key, query, value projections for all heads, but in a batch
self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
# output projection
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
# regularization
self.attn_dropout = nn.Dropout(config.dropout)
self.resid_dropout = nn.Dropout(config.dropout)
self.n_head = config.n_head
self.n_embd = config.n_embd
self.dropout = config.dropout
# positional embedding
self.rope = Rotary3D(config.n_embd // config.n_head)
# check if we are using causal attention
if config.attention_mask == "causal":
self.is_causal = True
else:
self.is_causal = False
# check if GPU Flash Attention is available
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
# check if we are running on TPU
try:
# Use local import to avoid conflict if global xm is None and to check TPU specifically for this flag
xm_local = importlib.import_module('torch_xla.core.xla_model')
self.tpu = True
except ImportError:
self.tpu = False
# Apply XLA sharding for model parallelism
xla_device_available = False
if xm is not None:
try:
device_kind = xm.xla_device_kind()
if device_kind is not None:
xla_device_available = True
except RuntimeError:
pass
@torch.compiler.disable
def emplace_kv(self, T, k_cache, v_cache, k, v):
# torch.compile doesn't play well with this op (5x slowdown)
# so we insert a graph break and copy eagerly
k_cache[:,:,-T:].copy_(k)
v_cache[:,:,-T:].copy_(v)
return k_cache, v_cache
def forward(self, x, pos, k_cache=None, v_cache=None, return_kv=False, inplace_kv=False, mask=None):
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
# Apply rotary positional embedding
k = self.rope(k, pos)
q = self.rope(q, pos)
if inplace_kv and k_cache is not None and v_cache is not None:
# assign into kv cache in-place
k, v = self.emplace_kv(T, k_cache, v_cache, k, v)
else:
# append cached keys and values with new keys and values
if k_cache is not None:
k = torch.cat((k_cache, k), dim=2)
if v_cache is not None:
v = torch.cat((v_cache, v), dim=2)
# Apply attention
if self.tpu:
# (1)
flash_attention = importlib.import_module('torch_xla.experimental.custom_kernel.flash_attention')
q_norm = q / math.sqrt(k.size(-1))
y = flash_attention(
q_norm, k, v,
causal=True, partition_spec=('fsdp', None, None, None))
# (2)
# y = torch.nn.functional.scaled_dot_product_attention(
# q, k, v,
# # dropout_p=self.dropout if self.training else 0,
# # attn_mask=None if mask is None else mask.to(q.dtype),
# is_causal=True
# )
elif self.flash:
# efficient attention using Flash Attention CUDA kernels
L, S = q.size(-2), k.size(-2)
is_causal = self.is_causal and mask is None
# is_causal doesn't work when not square, so replace with a manual mask if needed
if is_causal and L < S:
if L > 1: # if L=1, just use no mask
mask = torch.ones(L, S, dtype=q.dtype, device=q.device)
mask.masked_fill_(mask.to(torch.bool).triu(S-L+1), float('-inf'))
is_causal = False
y = torch.nn.functional.scaled_dot_product_attention(
q, k, v,
dropout_p=self.dropout if self.training else 0,
attn_mask=None if mask is None else mask.to(q.dtype),
is_causal=is_causal
)
else:
# manual implementation of attention
att = torch.einsum('bnsh,bnkh->bnsk', q, k) * (1.0 / math.sqrt(k.size(-1)))
# apply mask, or use causal if default
if mask is not None:
att = att + mask
elif self.is_causal:
L, S = q.size(-2), k.size(-2)
mask = torch.ones(1, 1, L, S).triu(S-L+1).to(dtype=torch.bool).to(x.device)
att.masked_fill_(mask, float('-inf'))
# upcast to float32 for numerical stability, as per llama implementation
att = F.softmax(att, dim=-1, dtype=torch.float32).to(q.dtype)
att = self.attn_dropout(att)
# multiply attention weights with values to get output
y = torch.einsum('bnsk,bnkh->bnsh', att, v)
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
# output projection
y = self.resid_dropout(self.c_proj(y))
# return key and value caches if requested
if return_kv:
return y, k, v
return y
def kv_cache_forward(self, x, pos, k_cache=None, v_cache=None):
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
# Apply rotary positional embedding (before concat)
k = self.rope(k, pos)
q = self.rope(q, pos)
# append cached keys and values with new keys and values
if k_cache is not None:
k = torch.cat((k_cache, k), dim=2)
if v_cache is not None:
v = torch.cat((v_cache, v), dim=2)
# manual implementation of attention
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = F.softmax(att, dim=-1)
att = self.attn_dropout(att)
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
# output projection
y = self.resid_dropout(self.c_proj(y))
return y, k, v
class MLP(nn.Module):
    """Feed-forward sublayer: Linear(n_embd -> 4*n_embd), GELU, Linear(4*n_embd -> n_embd), dropout.

    At construction time, the projection weights (and biases, if present) are
    annotated with XLA sharding specs when all of the following hold:
    - torch_xla reports a device kind;
    - the sharding module is importable;
    - a global SPMD mesh is set;
    - that mesh's second axis has size > 1.
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        hidden = 4 * width
        self.c_fc = nn.Linear(width, hidden, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden, width, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

        # Apply XLA sharding for model parallelism (best effort: any RuntimeError
        # from the device query just means "no XLA device").
        on_xla = False
        if xm is not None:
            try:
                on_xla = xm.xla_device_kind() is not None
            except RuntimeError:
                pass
        if not on_xla or xs is None:
            return
        mesh = xs.global_mesh()
        # Mesh axis 1 is treated as the 'model' axis; skip when it is trivial.
        if mesh is None or mesh.mesh_shape[1] <= 1:
            return

        xs.mark_sharding(self.c_fc.weight, mesh, (1, 0))
        if self.c_fc.bias is not None:
            xs.mark_sharding(self.c_fc.bias, mesh, (1,))
        print(f"MLP: Applied MP sharding to c_fc {mesh.mesh_shape} spec weight(1,0), bias(1,)")

        xs.mark_sharding(self.c_proj.weight, mesh, (0, 1))
        if self.c_proj.bias is not None:
            xs.mark_sharding(self.c_proj.bias, mesh, (0,))
        print(f"MLP: Applied MP sharding to c_proj {mesh.mesh_shape} spec weight(0,1), bias(0,)")

    def forward(self, x, spmd_mesh=None):
        """Apply the feed-forward network to ``x``.

        When ``spmd_mesh`` is given, both the hidden activation and the output
        are annotated with the sharding spec (('dcn', 'data'), None, 'model').
        """
        spec = (('dcn', 'data'), None, 'model')
        hidden = self.gelu(self.c_fc(x))
        if spmd_mesh is not None:
            xs.mark_sharding(hidden, spmd_mesh, spec)
        out = self.dropout(self.c_proj(hidden))
        if spmd_mesh is not None:
            xs.mark_sharding(out, spmd_mesh, spec)
        return out
class RMSNorm(nn.Module):
    """ Root Mean Square Normalization

    Scales each vector along the last axis by 1 / sqrt(mean(x**2) + eps). The
    statistics are computed in float32 and the result is cast back to the
    input dtype. The result is then multiplied by a learnable per-channel
    ``weight`` (initialised to ones) unless ``weight=False``.

    Note: ``bias`` is accepted for signature compatibility, but no bias
    parameter is created.
    """

    def __init__(self, dim: int, weight: bool = True, bias: bool = False, eps: float = 1e-5):  # whl
        super().__init__()
        self.eps = eps
        if weight:
            self.weight = nn.Parameter(torch.ones(dim))
        else:
            self.weight = None

    def _norm(self, x):
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        normed = self._norm(x.float()).type_as(x)
        return normed if self.weight is None else normed * self.weight
class PSIBlock(nn.Module):
    """Pre-norm transformer block: ``x + attn(ln_1(x))`` followed by ``x + mlp(ln_2(x))``."""

    def __init__(self, config):
        super().__init__()
        self.ln_1 = RMSNorm(config.n_embd, bias=config.bias)
        self.attn = PSIAttentionLayer(config)
        self.ln_2 = RMSNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x, pos, k_cache=None, v_cache=None, return_kv=False, inplace_kv=False, spmd_mesh=None, mask=None):
        """Run one residual attention + MLP step.

        ``pos``, ``k_cache``, ``v_cache``, ``inplace_kv`` and ``mask`` are passed
        straight to the attention layer. ``spmd_mesh`` is passed to the MLP, which
        uses it for sharding annotations. Previously it was accepted but
        silently dropped.

        Returns:
            ``x`` or, when ``return_kv`` is true, ``(x, k, v)``, where k and v are
            the keys/values returned by the attention layer.
        """
        # With a key/value cache, the attention layer reuses pre-computed
        # keys/values to minimize computation.
        attn_out = self.attn(self.ln_1(x), pos, k_cache=k_cache, v_cache=v_cache,
                             return_kv=return_kv, inplace_kv=inplace_kv, mask=mask)
        if return_kv:
            attn_out, k, v = attn_out
        x = x + attn_out
        x = x + self.mlp(self.ln_2(x), spmd_mesh=spmd_mesh)
        if return_kv:
            return x, k, v
        return x
class PartitionedEmbedding(nn.Module):
    """Embedding table split row-wise across several ``nn.Embedding`` shards.

    Token ``i`` is looked up in shard ``i // partition_size`` at row
    ``i % partition_size``. Every shard holds ``partition_size`` rows except
    possibly the last. Ids whose shard index lies outside
    ``[0, num_partitions)`` match no shard and come out as all-zero vectors.
    """

    def __init__(self, num_embeddings, embedding_dim, partition_size=65536):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.partition_size = partition_size
        self.num_partitions = (num_embeddings + partition_size - 1) // partition_size

        def shard_rows(shard):
            lo = shard * self.partition_size
            return min(lo + self.partition_size, num_embeddings) - lo

        self.embedding_layers = nn.ModuleList(
            nn.Embedding(shard_rows(shard), embedding_dim)
            for shard in range(self.num_partitions)
        )

    def forward(self, input_ids):
        """Return embeddings of shape ``(*input_ids.shape, embedding_dim)``."""
        shard_of = input_ids // self.partition_size
        row_of = input_ids % self.partition_size
        out_dtype = self.embedding_layers[0].weight.dtype
        output = torch.zeros(
            (*input_ids.shape, self.embedding_dim),
            device=input_ids.device,
            dtype=out_dtype,
        )
        for shard, table in enumerate(self.embedding_layers):
            hit = shard_of == shard
            if not hit.any():
                continue
            output[hit] = table(row_of[hit])
        return output
class PartitionedLinear(nn.Module):
    """Linear layer whose output features are split across several ``nn.Linear`` shards.

    Each shard maps ``in_features`` to at most ``partition_size`` outputs. The
    shard outputs are concatenated along the last axis in shard order, so the
    result has ``out_features`` channels.
    """

    def __init__(self, in_features, out_features, partition_size=65536, bias=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.partition_size = partition_size
        self.num_partitions = (out_features + partition_size - 1) // partition_size
        self.linear_layers = nn.ModuleList()
        for shard in range(self.num_partitions):
            first = shard * self.partition_size
            width = min(first + self.partition_size, out_features) - first
            self.linear_layers.append(nn.Linear(in_features, width, bias=bias))

    def forward(self, input):
        return torch.cat(tuple(layer(input) for layer in self.linear_layers), dim=-1)