# GPT2-hadamard-medium / modeling_my_model.py
import math
from functools import partial
from random import randrange
from typing import Callable

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import cat
from torch.nn import Module
from torch.utils._pytree import tree_flatten, tree_unflatten

from einops import rearrange, repeat, reduce, einsum
from einops.layers.torch import Rearrange, Reduce

from fast_hadamard_transform import hadamard_transform

from transformers import PreTrainedModel, PretrainedConfig
from transformers.generation.utils import GenerationMixin
from transformers.modeling_outputs import CausalLMOutput
"""
ein notation:
b - batch
d - feature dimension
s - residual streams
t - residual streams + num branch inputs
f - number of fractions (division of feature dimension space)
v - number of views for branch input
"""
# helper functions
def exists(v):
return v is not None
def divisible_by(num, den):
return (num % den) == 0
def default(v, d):
return v if exists(v) else d
def identity(t):
return t
def add(x, y):
return x + y
def sinkhorn_log(logits, num_iters=10, tau=0.05):
n = logits.shape[-1]
Z = logits / tau
log_marginal = torch.full(
(n,), -math.log(n), device=logits.device, dtype=logits.dtype
)
u = torch.zeros(n, device=Z.device, dtype=Z.dtype)
v = torch.zeros(n, device=Z.device, dtype=Z.dtype)
for _ in range(num_iters):
u = log_marginal - torch.logsumexp(Z + v.unsqueeze(0), dim=1)
v = log_marginal - torch.logsumexp(Z + u.unsqueeze(1), dim=0)
return torch.exp(Z + u.unsqueeze(1) + v.unsqueeze(0)) * n
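# Usage sketch (illustrative only): sinkhorn_log turns an (n, n) logit matrix into an
# approximately doubly stochastic mixing matrix, e.g.
#
#   logits = torch.randn(4, 4)
#   P = sinkhorn_log(logits, num_iters=10, tau=0.05)
#   P.sum(dim=-1)  # rows each sum to ~1
#   P.sum(dim=-2)  # columns each sum to ~1 (exact up to float error after the final column update)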
def zeropower_via_newtonschulz(X, steps=5, eps=1e-7, coeffs=(3.0, -3.2, 1.2)):
a, b, c = coeffs
X = X / (X.norm() + eps)
transpose = False
if X.shape[0] > X.shape[1]:
X = X.T
transpose = True
for _ in range(steps):
A = X @ X.T
B = b * A + c * A @ A
X = a * X + B @ X
if transpose:
X = X.T
return X
def orthostochastic_project(
logits, ns_steps=5, ns_eps=1e-7, ns_coeffs=(3.0, -3.2, 1.2)
):
O = zeropower_via_newtonschulz(logits, steps=ns_steps, eps=ns_eps, coeffs=ns_coeffs)
return O.square()
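# Sketch of the intent (assuming a square logits matrix, as used for H_res below):
# Newton-Schulz drives `logits` toward an approximately orthogonal matrix O, and
# squaring its entries gives an approximately doubly stochastic ("orthostochastic")
# matrix, since every row and column of an orthogonal matrix has unit L2 norm:
#
#   logits = torch.randn(4, 4)
#   P = orthostochastic_project(logits)
#   P.sum(dim=-1)  # rows sum to ~1; accuracy depends on ns_steps / ns_coeffs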
# main functions
def get_expand_reduce_stream_functions(
num_streams, add_stream_embed=False, dim=None, disable=False
):
if num_streams == 1 or disable:
return (nn.Identity(), nn.Identity())
if add_stream_embed:
assert exists(dim), (
"`dim` must be passed into get_init_and_expand_reduce_stream_functions for returning an expansion function with stream embeddings added"
)
expand_fn = StreamEmbed(num_streams, dim, expand_to_streams=True)
else:
expand_fn = Reduce(
pattern="b ... -> (b s) ...", reduction="repeat", s=num_streams
)
reduce_fn = Reduce(pattern="(b s) ... -> b ...", reduction="sum", s=num_streams)
return expand_fn, reduce_fn
def get_init_and_expand_reduce_stream_functions(
num_streams, num_fracs=1, dim=None, add_stream_embed=False, disable=None
):
disable = default(disable, num_streams == 1 and num_fracs == 1)
hyper_conn_klass = HyperConnections if not disable else Residual
init_hyper_conn_fn = partial(hyper_conn_klass, num_streams, num_fracs=num_fracs)
expand_reduce_fns = get_expand_reduce_stream_functions(
num_streams, add_stream_embed=add_stream_embed, dim=dim, disable=disable
)
if exists(dim):
init_hyper_conn_fn = partial(init_hyper_conn_fn, dim=dim)
return (init_hyper_conn_fn, *expand_reduce_fns)
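# Usage sketch (mirrors how GPT below wires things up): `init_hc` builds a
# HyperConnections (or plain Residual) wrapper around a branch module,
# `expand_stream` repeats the batch across residual streams, and `reduce_stream`
# sums the streams back together:
#
#   init_hc, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4)
#   hc = init_hc(dim=64, branch=nn.Linear(64, 64), layer_index=0)
#   x = expand_stream(torch.randn(2, 16, 64))  # -> (2 * 4, 16, 64)
#   x = hc(x)                                  # width + depth connections around the branch
#   x = reduce_stream(x)                       # -> (2, 16, 64)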
# norms
class RMSNorm(Module):
def __init__(self, dim):
super().__init__()
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.zeros(dim))
def forward(self, x):
return F.normalize(x, dim=-1) * self.scale * (self.gamma + 1)
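# note: F.normalize divides by the L2 norm, which equals sqrt(dim) * rms(x), so
# multiplying by `scale = sqrt(dim)` recovers the usual RMSNorm x / rms(x) * (1 + gamma)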
# main classes
# residual base class
class Residual(Module):
def __init__(
self,
*args,
branch: Module | None = None,
residual_transform: Module | None = None,
**kwargs,
):
super().__init__()
self.branch = branch
self.residual_transform = default(residual_transform, nn.Identity())
def width_connection(self, residuals):
return residuals, residuals, dict()
def depth_connection(
self,
branch_output,
residuals,
):
return branch_output + self.residual_transform(residuals)
def decorate_branch(self, branch: Callable):
assert not exists(self.branch), "branch was already wrapped on init"
def forward_and_add_residual(residual, *args, **kwargs):
branch_input, add_residual = self.forward(residual)
branch_output = branch(branch_input, *args, **kwargs)
residual = add_residual(branch_output)
return residual
return forward_and_add_residual
def forward(self, residuals, *branch_args, **branch_kwargs):
branch_input, residuals, residual_kwargs = self.width_connection(residuals)
def add_residual_fn(branch_out):
(branch_out, *rest), tree_spec = tree_flatten(branch_out)
branch_out = self.depth_connection(branch_out, residuals, **residual_kwargs)
return tree_unflatten((branch_out, *rest), tree_spec)
if not exists(self.branch):
return branch_input, add_residual_fn
branch_output = self.branch(branch_input, *branch_args, **branch_kwargs)
return add_residual_fn(branch_output)
# hyper connection residual streams
class HyperConnections(Module):
def __init__(
self,
num_residual_streams,
*,
dim,
branch: Module | None = None,
layer_index=None,
tanh=True,
channel_first=False,
dropout=0.0,
residual_transform: Module
        | None = None,  # to support resnet blocks where the input dimension is not equal to the output dimension - usually a residual conv
add_branch_out_to_residual=True, # will disable depth connections (weighted residual sum with beta) if set False
num_input_views=1, # allow for the branch module to receive multiple input views, dimension placed on the very left (before batch)
depth_residual_fn=add,
num_fracs=1, # https://arxiv.org/abs/2503.14125
mhc=False,
sinkhorn_iters=10,
sinkhorn_tau=0.05,
mhc_h_res_proj="sinkhorn",
ns_steps=5,
ns_eps=1e-7,
ns_coeffs=(3.0, -3.2, 1.2),
):
"""
        Appendix J, Algorithm 2 in https://arxiv.org/abs/2409.19606
"""
super().__init__()
self.branch = branch
self.act = nn.Tanh() if tanh else nn.Identity()
# frac-connections paper - num_fracs > 1 will be the `m` in their paper https://arxiv.org/abs/2503.14125
assert num_fracs >= 1
self.num_fracs = num_fracs
self.has_fracs = num_fracs > 1
self.split_fracs = Rearrange("b ... (f d) -> b ... f d", f=num_fracs)
self.merge_fracs = Rearrange("b ... f d -> b ... (f d)")
assert divisible_by(dim, num_fracs), (
f"feature dimension ({dim}) must be divisible by the `num_fracs` ({num_fracs})"
)
        dim //= num_fracs  # the effective feature dimension handled per fraction is dim / num_fracs
# they used layernorm in paper, but rmsnorm is fine given what we know now
self.norm = RMSNorm(dim)
assert num_residual_streams > 0, "`num_residual_streams` must be greater than 0"
self.num_residual_streams = num_residual_streams
init_residual_index = (
default(layer_index, randrange(num_residual_streams)) % num_residual_streams
) # just choose one random residual stream if layer index not given
# handle the parameter dimensions, which may require (num_residuals x num_fractions) - generalizing hyper + frac connections
num_residual_streams_fracs = num_residual_streams * num_fracs
num_input_views_fracs = num_input_views * num_fracs
# width num residual streams
assert num_input_views >= 1
self.num_input_views = num_input_views
# width connection
init_alpha0 = torch.zeros((num_residual_streams_fracs, num_input_views_fracs))
init_alpha0[init_residual_index, :] = 1.0
self.static_alpha = nn.Parameter(
cat((init_alpha0, torch.eye(num_residual_streams_fracs)), dim=1)
)
self.dynamic_alpha_fn = nn.Parameter(
torch.zeros(dim, num_residual_streams_fracs + num_input_views_fracs)
)
self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)
# depth connection related (beta)
self.add_branch_out_to_residual = add_branch_out_to_residual
if add_branch_out_to_residual:
self.static_beta = nn.Parameter(torch.ones(num_residual_streams_fracs))
dynamic_beta_shape = (
(dim,) if num_fracs == 1 else (dim, num_fracs)
) # preserve backwards compat
self.dynamic_beta_fn = nn.Parameter(torch.zeros(dynamic_beta_shape))
self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)
# dropouts
self.dropout = nn.Dropout(dropout)
# channel first option
self.channel_first = channel_first
# maybe residual transform
self.residual_transform = default(residual_transform, nn.Identity())
# maybe custom depth connection residual function
# this is to prepare for gating the addition of the branch outputs to the residual streams
# needed for memory lanes a la RMT / LMM
self.depth_residual_fn = depth_residual_fn
self.mhc = mhc
self.sinkhorn_iters = sinkhorn_iters
self.sinkhorn_tau = sinkhorn_tau
self.mhc_h_res_proj = mhc_h_res_proj
self.ns_steps = ns_steps
self.ns_eps = ns_eps
self.ns_coeffs = ns_coeffs
if mhc:
assert num_fracs == 1, "mhc currently requires num_fracs = 1"
assert num_input_views == 1, "mhc currently requires num_input_views = 1"
assert mhc_h_res_proj in (
"sinkhorn",
"orthostochastic",
), "mhc_h_res_proj must be 'sinkhorn' or 'orthostochastic'"
H_res_init = torch.full((num_residual_streams, num_residual_streams), -8.0)
H_res_init.fill_diagonal_(0.0)
self.H_res_logits = nn.Parameter(H_res_init)
H_pre_init = torch.full((num_residual_streams,), -8.0)
H_pre_init[init_residual_index] = 0.0
self.H_pre_logits = nn.Parameter(H_pre_init)
if add_branch_out_to_residual:
self.H_post_logits = nn.Parameter(torch.zeros(num_residual_streams))
def width_connection(self, residuals):
streams = self.num_residual_streams
maybe_transformed_residuals = self.residual_transform(residuals)
# width connection
# handle channel first
if self.channel_first:
residuals = rearrange(residuals, "b d ... -> b ... d")
# split out fractions
residuals = self.split_fracs(residuals)
# split out streams
residuals = rearrange(residuals, "(b s) ... d -> b ... s d", s=streams)
if self.mhc:
residuals_mixed_source = maybe_transformed_residuals
if self.channel_first:
residuals_mixed_source = rearrange(
residuals_mixed_source, "b d ... -> b ... d"
)
residuals_mixed_source = self.split_fracs(residuals_mixed_source)
residuals_mixed_source = rearrange(
residuals_mixed_source, "(b s) ... d -> b ... s d", s=streams
)
if self.mhc_h_res_proj == "orthostochastic":
H_res = orthostochastic_project(
self.H_res_logits,
ns_steps=self.ns_steps,
ns_eps=self.ns_eps,
ns_coeffs=self.ns_coeffs,
)
else:
H_res = sinkhorn_log(
self.H_res_logits, self.sinkhorn_iters, self.sinkhorn_tau
)
H_pre = F.softmax(self.H_pre_logits, dim=-1)
H_post = None
if self.add_branch_out_to_residual:
H_post = F.softmax(self.H_post_logits, dim=-1)
residuals_mixed = einsum(
H_res, residuals_mixed_source, "s t, ... s d -> ... t d"
)
branch_input = einsum(H_pre, residuals, "s, ... s d -> ... d")
if getattr(self, "collect_stats", False):
with torch.no_grad():
stats = dict(
h_res_min=H_res.min(),
h_res_row_sum=H_res.sum(dim=-1).mean(),
h_res_col_sum=H_res.sum(dim=-2).mean(),
h_pre_min=H_pre.min(),
)
if H_post is not None:
stats["h_post_min"] = H_post.min()
self.last_stats = {k: v.detach() for k, v in stats.items()}
if self.channel_first:
branch_input = rearrange(branch_input, "b ... d -> b d ...")
branch_input = self.merge_fracs(branch_input)
return (
branch_input,
maybe_transformed_residuals,
dict(beta=H_post, residuals_mixed=residuals_mixed),
)
# norm
normed = self.norm(residuals)
# alpha for weighted sum of residuals going into branch
wc_weight = self.act(normed @ self.dynamic_alpha_fn)
dynamic_alpha = wc_weight * self.dynamic_alpha_scale
static_alpha = rearrange(self.static_alpha, "(f s) d -> f s d", s=streams)
alpha = dynamic_alpha + static_alpha
alpha = self.split_fracs(
alpha
) # (batch, seq, fracs1, streams, fracs2, input + residual streams)
# beta for weights from branch output back to residual streams
beta = None
if self.add_branch_out_to_residual:
dc_weight = self.act(normed @ self.dynamic_beta_fn)
if not self.has_fracs:
dc_weight = rearrange(dc_weight, "... -> ... 1")
dynamic_beta = dc_weight * self.dynamic_beta_scale
static_beta = rearrange(self.static_beta, "... (s f) -> ... s f", s=streams)
beta = dynamic_beta + static_beta
if getattr(self, "collect_stats", False):
with torch.no_grad():
num_input_views_fracs = self.num_input_views * self.num_fracs
alpha_branch = alpha[..., :num_input_views_fracs]
alpha_residual = alpha[..., num_input_views_fracs:]
alpha_branch_abs_mean = alpha_branch.abs().mean()
alpha_residual_abs_mean = alpha_residual.abs().mean()
stats = dict(
alpha_branch_mean=alpha_branch.mean(),
alpha_branch_abs_mean=alpha_branch_abs_mean,
alpha_residual_mean=alpha_residual.mean(),
alpha_residual_abs_mean=alpha_residual_abs_mean,
alpha_branch_residual_ratio=alpha_branch_abs_mean
/ (alpha_residual_abs_mean + 1e-8),
)
if beta is not None:
stats.update(
beta_mean=beta.mean(),
beta_abs_mean=beta.abs().mean(),
beta_min=beta.min(),
beta_max=beta.max(),
)
self.last_stats = {k: v.detach() for k, v in stats.items()}
mix_h = einsum(alpha, residuals, "... f1 s f2 t, ... f1 s d -> ... f2 t d")
if self.num_input_views == 1:
branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
else:
branch_input, residuals = (
mix_h[..., : self.num_input_views, :],
mix_h[..., self.num_input_views :, :],
)
branch_input = rearrange(branch_input, "b ... v d -> v b ... d")
if self.channel_first:
branch_input = rearrange(branch_input, "b ... d -> b d ...")
# maybe merge fractions back
branch_input = self.merge_fracs(branch_input)
return branch_input, maybe_transformed_residuals, dict(beta=beta)
def depth_connection(self, branch_output, residuals, *, beta, residuals_mixed=None):
assert self.add_branch_out_to_residual
# maybe split fractions
branch_output = self.split_fracs(branch_output)
# 'depth' connection
if self.channel_first:
branch_output = rearrange(branch_output, "b d ... -> b ... d")
if self.mhc:
assert residuals_mixed is not None
assert beta is not None
branch_to_streams = einsum(branch_output, beta, "b ... d, s -> b ... s d")
output = residuals_mixed + branch_to_streams
output = rearrange(output, "b ... s d -> (b s) ... d")
output = self.merge_fracs(output)
if self.channel_first:
output = rearrange(output, "b ... d -> b d ...")
return self.dropout(output)
output = einsum(
branch_output, beta, "b ... f1 d, b ... f1 s f2 -> b ... f2 s d"
)
output = rearrange(output, "b ... s d -> (b s) ... d")
        # merge fractions back
output = self.merge_fracs(output)
# channel first
if self.channel_first:
output = rearrange(output, "b ... d -> b d ...")
residuals = self.depth_residual_fn(output, residuals)
return self.dropout(residuals)
def decorate_branch(self, branch: Callable):
assert not exists(self.branch), "branch was already wrapped on init"
def forward_and_add_residual(residual, *args, **kwargs):
branch_input, add_residual = self.forward(residual)
branch_output = branch(branch_input, *args, **kwargs)
residual = add_residual(branch_output)
return residual
return forward_and_add_residual
def forward(self, residuals, *branch_args, **branch_kwargs):
branch_input, residuals, residual_kwargs = self.width_connection(residuals)
def add_residual_fn(branch_out):
if not self.add_branch_out_to_residual:
return branch_out
(branch_out, *rest), tree_spec = tree_flatten(branch_out)
branch_out = self.depth_connection(branch_out, residuals, **residual_kwargs)
return tree_unflatten((branch_out, *rest), tree_spec)
if not exists(self.branch):
return branch_input, add_residual_fn
branch_output = self.branch(branch_input, *branch_args, **branch_kwargs)
return add_residual_fn(branch_output)
HyperConnections.get_expand_reduce_stream_functions = staticmethod(
get_expand_reduce_stream_functions
)
HyperConnections.get_init_and_expand_reduce_stream_functions = staticmethod(
get_init_and_expand_reduce_stream_functions
)
# stream embed
class StreamEmbed(Module):
def __init__(self, num_streams, dim, channel_first=False, expand_to_streams=False):
super().__init__()
self.channel_first = channel_first
self.num_streams = num_streams
self.expand_to_streams = expand_to_streams
self.stream_embed = nn.Parameter(torch.zeros(num_streams, dim))
def forward(self, residuals):
if self.expand_to_streams:
residuals = repeat(residuals, "b ... -> (b s) ...", s=self.num_streams)
if self.channel_first:
residuals = rearrange(
residuals, "(b s) d ... -> b ... s d", s=self.num_streams
)
else:
residuals = rearrange(
residuals, "(b s) ... d -> b ... s d", s=self.num_streams
)
residuals = residuals + self.stream_embed
if self.channel_first:
residuals = rearrange(
residuals, "b ... s d -> (b s) d ...", s=self.num_streams
)
else:
residuals = rearrange(
residuals, "b ... s d -> (b s) ... d", s=self.num_streams
)
return residuals
# attention pool - taken from Enformer https://www.nature.com/articles/s41592-021-01252-x , in turn taken from somewhere else
class AttentionPoolReduceStream(Module):
def __init__(self, num_streams, dim, channel_first=False):
super().__init__()
self.num_streams = num_streams
self.channel_first = channel_first
self.to_attn_logits = nn.Linear(dim, dim, bias=False)
self.to_attn_logits.weight.data.copy_(torch.eye(dim))
def forward(self, residuals):
if self.channel_first:
residuals = rearrange(
residuals, "(b s) d ... -> b ... s d", s=self.num_streams
)
else:
residuals = rearrange(
residuals, "(b s) ... d -> b ... s d", s=self.num_streams
)
attn_logits = self.to_attn_logits(residuals)
attn = attn_logits.softmax(dim=-2)
residuals = reduce(residuals * attn, "b ... s d -> b ... d", "sum")
if self.channel_first:
residuals = rearrange(residuals, "b ... d -> b d ...")
return residuals
class Rotary(torch.nn.Module):
def __init__(self, dim, base=10000):
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)
self.seq_len_cached = None
self.cos_cached = None
self.sin_cached = None
def forward(self, x):
seq_len = x.shape[1]
if seq_len != self.seq_len_cached:
self.seq_len_cached = seq_len
t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
freqs = torch.outer(t, self.inv_freq).to(x.device)
self.cos_cached = freqs.cos()
self.sin_cached = freqs.sin()
return self.cos_cached[None, :, None, :], self.sin_cached[None, :, None, :]
def apply_rotary_emb(x, cos, sin):
assert x.ndim == 4 # multihead attention
d = x.shape[3]//2
x1 = x[..., :d]
x2 = x[..., d:]
y1 = x1 * cos + x2 * sin
y2 = x1 * (-sin) + x2 * cos
return torch.cat([y1, y2], 3)
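# Shape sketch (illustrative): with q of shape (B, T, n_head, head_dim), Rotary
# returns cos/sin of shape (1, T, 1, head_dim // 2), and apply_rotary_emb rotates
# the two halves of the head dimension against each other:
#
#   rotary = Rotary(dim=64)
#   q = torch.randn(2, 16, 8, 64)
#   cos, sin = rotary(q)               # each (1, 16, 1, 32)
#   q = apply_rotary_emb(q, cos, sin)  # (2, 16, 8, 64)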
class CausalSelfAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.c_attn = nn.Linear(config.n_embd, 3* config.n_embd)
# self.c_proj = nn.Linear(config.n_embd, config.n_embd)
# self.c_proj.NANOGPT_SCALE_INIT = 1 # Attaching a flag/attribute for initialization
self.n_head = config.n_head
self.n_embd = config.n_embd
self.hada_scale = 1.0/ math.sqrt(config.n_embd)
self.rotary = Rotary(config.n_embd//config.n_head)
self.out_scale = nn.Parameter(
torch.ones(config.n_embd) / math.sqrt(2 * config.n_layer)
)
self.out_bias = nn.Parameter(torch.zeros(config.n_embd))
        # causal mask buffer kept from the original nanoGPT attention; it is not used in
        # forward, since F.scaled_dot_product_attention(..., is_causal=True) masks internally
self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
.view(1, 1, config.block_size, config.block_size))
# @torch.compile
def fused_hadamard_output(self, y, B, T, C):
"""Fuse reshape, hadamard, scaling operations"""
y = y.reshape(-1, C)
        y = hadamard_transform(y, scale=self.hada_scale)  # batched Hadamard transform on the last dim
        y = y * self.out_scale  # learned per-channel scale
        y = y + self.out_bias  # learned per-channel shift
return y.view(B, T, C)
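    # The fused Hadamard output above stands in for the usual learned `c_proj`
    # (left commented out in __init__): `hadamard_transform` applies a fixed
    # Hadamard matrix along the last dimension (n_embd, a power of two here), and
    # `out_scale` / `out_bias` supply the learned per-channel scale and shift.
    # Roughly, for y of shape (..., C):
    #
    #   y = hadamard_transform(y, scale=1 / math.sqrt(C)) * out_scale + out_bias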
def forward(self, x):
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head)  # (B, T, nh, hs)
        q = q.view(B, T, self.n_head, C // self.n_head)  # (B, T, nh, hs)
        v = v.view(B, T, self.n_head, C // self.n_head)  # (B, T, nh, hs)
cos, sin = self.rotary(q)
q = apply_rotary_emb(q, cos, sin)
k = apply_rotary_emb(k, cos, sin)
y = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True)
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
# output projection
# y: (B, T, C)
        y = self.fused_hadamard_output(y, B, T, C)
return y
class MLP(nn.Module):
def __init__(self, config):
super().__init__()
# Set inner dim ~ 8/3 * d so that parameter count matches 4d-GELU
inner_dim = int(4 * config.n_embd * 2 / 3)
# inner_dim = 2048
# Round up to multiple of 256 for efficiency
# inner_dim = ((inner_dim + 255) // 256) * 256
self.inner_dim = inner_dim
# value and gate
self.c_fc = nn.Linear(config.n_embd, 2 * inner_dim)
self.c_proj = nn.Linear(inner_dim, config.n_embd)
self.c_proj.NANOGPT_SCALE_INIT = 1
def forward(self, x):
x_in = self.c_fc(x)
x_gate, x_up = x_in.chunk(2, dim=-1)
x = F.silu(x_gate) * x_up
x = self.c_proj(x)
return x
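# The MLP above is a SwiGLU block: c_fc produces gate and value halves and the
# output is c_proj(silu(gate) * value). With inner_dim = (8/3) * n_embd, the
# three weight matrices hold roughly 8 * n_embd^2 parameters, matching the
# standard 4x GELU MLP (n_embd -> 4 * n_embd -> n_embd).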
class AttnBranch(nn.Module):
def __init__(self, norm, attn):
super().__init__()
self.norm = norm
self.attn = attn
def forward(self, x):
return self.attn(self.norm(x))
class Block(nn.Module):
def __init__(self, config, layer_idx, init_hc):
super().__init__()
        self.ln_1 = nn.RMSNorm(config.n_embd, eps=1e-6)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.RMSNorm(config.n_embd, eps=1e-6)
self.mlp = MLP(config)
self.attn_branch = AttnBranch(self.ln_1, self.attn)
hc_kwargs = dict(
mhc=config.mhc,
sinkhorn_iters=config.sinkhorn_iters,
sinkhorn_tau=config.sinkhorn_tau,
mhc_h_res_proj=config.mhc_h_res_proj,
ns_steps=config.ns_steps,
ns_eps=config.ns_eps,
ns_coeffs=config.ns_coeffs,
)
self.hc_attn = init_hc(
dim=config.n_embd,
branch=self.attn_branch,
layer_index=layer_idx * 2,
**hc_kwargs,
)
self.hc_mlp = init_hc(
dim=config.n_embd,
branch=nn.Sequential(self.ln_2, self.mlp),
layer_index=layer_idx * 2 + 1,
**hc_kwargs,
)
def forward(self, x):
x = self.hc_attn(x)
x = self.hc_mlp(x)
return x
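# Each Block consumes two hyper-connection layer indices (2 * layer_idx for the
# attention branch, 2 * layer_idx + 1 for the MLP branch), so at initialization
# successive branches read from different residual streams (layer_index % num_streams).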
class GPTConfig(PretrainedConfig):
model_type = "custom_gpt"
def __init__(
self,
block_size=1024,
vocab_size=50304,
n_layer=24,
n_head=16,
n_embd=1024,
dropout=0.0,
bias=True,
hc_num_streams=1,
hc_num_fracs=1,
hc_disable=False,
mhc=False,
sinkhorn_iters=10,
sinkhorn_tau=0.05,
mhc_h_res_proj="sinkhorn",
ns_steps=5,
ns_eps=1e-7,
ns_coeffs=(3.0, -3.2, 1.2),
**kwargs,
):
super().__init__(**kwargs)
self.block_size = block_size
self.vocab_size = vocab_size
self.n_layer = n_layer
self.n_head = n_head
self.n_embd = n_embd
self.dropout = dropout
self.bias = bias
self.hc_num_streams = hc_num_streams
self.hc_num_fracs = hc_num_fracs
self.hc_disable = hc_disable
self.mhc = mhc
self.sinkhorn_iters = sinkhorn_iters
self.sinkhorn_tau = sinkhorn_tau
self.mhc_h_res_proj = mhc_h_res_proj
self.ns_steps = ns_steps
self.ns_eps = ns_eps
self.ns_coeffs = ns_coeffs
# 🔑 HF compatibility aliases
self.num_hidden_layers = n_layer
self.num_attention_heads = n_head
self.hidden_size = n_embd
self.max_position_embeddings = block_size
class GPT(PreTrainedModel, GenerationMixin):
config_class = GPTConfig
def __init__(self, config):
super().__init__(config)
init_hc, expand_stream, reduce_stream = (
get_init_and_expand_reduce_stream_functions(
config.hc_num_streams,
num_fracs=config.hc_num_fracs,
disable=config.hc_disable,
)
)
self.expand_stream = expand_stream
self.reduce_stream = reduce_stream
self.transformer = nn.ModuleDict(
dict(
wte=nn.Embedding(config.vocab_size, config.n_embd),
h=nn.ModuleList(
[Block(config, i, init_hc) for i in range(config.n_layer)]
),
                ln_f=nn.RMSNorm(config.n_embd, eps=1e-6)
)
)
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
self.transformer.wte.weight = self.lm_head.weight
self.post_init()
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
**kwargs,
):
# We do NOT use KV cache yet, so always feed full sequence
return {
"input_ids": input_ids,
"past_key_values": None,
}
def forward(
self,
input_ids=None,
        attention_mask=None,  # accepted for HF generate() compatibility; not used (attention is always causal, no padding mask is applied)
labels=None,
past_key_values=None,
use_cache=None,
**kwargs,
):
        b, t = input_ids.size()
        assert t <= self.config.block_size, (
            f"sequence length {t} exceeds block size {self.config.block_size}"
        )
        x = self.transformer.wte(input_ids)
x = self.expand_stream(x)
for block in self.transformer.h:
x = block(x)
x = self.transformer.ln_f(x)
x = self.reduce_stream(x)
logits = self.lm_head(x)
loss = None
        if labels is not None:
            # no shift is applied here; labels are assumed to already be aligned with
            # the next-token targets by the caller (nanoGPT-style data preparation)
loss = F.cross_entropy(
logits.view(-1, logits.size(-1)),
labels.view(-1),
)
return CausalLMOutput(
loss=loss,
logits=logits,
)
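# Minimal smoke test (a sketch with illustrative hyperparameters, not the shipped
# checkpoint's configuration). The Hadamard output projection assumes n_embd is a
# power of two, and the fast_hadamard_transform kernel runs on CUDA tensors, so the
# forward pass below is only attempted when a GPU is available.
if __name__ == "__main__":
    config = GPTConfig(
        block_size=128,
        vocab_size=50304,
        n_layer=2,
        n_head=4,
        n_embd=256,
        hc_num_streams=2,
    )
    model = GPT(config)
    print(f"parameters: {sum(p.numel() for p in model.parameters()):,}")
    if torch.cuda.is_available():
        model = model.cuda()
        input_ids = torch.randint(0, config.vocab_size, (2, 16), device="cuda")
        out = model(input_ids=input_ids)
        print(out.logits.shape)  # expected: (2, 16, vocab_size)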