import math
from functools import partial
from random import randrange
from typing import Callable

import torch
import torch.nn.functional as F
from torch import cat, nn
from torch.nn import Module
from torch.utils._pytree import tree_flatten, tree_unflatten

from einops import rearrange, repeat, reduce, einsum
from einops.layers.torch import Rearrange, Reduce
from fast_hadamard_transform import hadamard_transform

from transformers import PreTrainedModel, PretrainedConfig
from transformers.generation.utils import GenerationMixin
from transformers.modeling_outputs import CausalLMOutput
|
|
| """ |
| ein notation: |
| b - batch |
| d - feature dimension |
| s - residual streams |
| t - residual streams + num branch inputs |
| f - number of fractions (division of feature dimension space) |
| v - number of views for branch input |
| """ |
|
|
| |
|
|
|
|
| def exists(v): |
| return v is not None |
|
|
|
|
| def divisible_by(num, den): |
| return (num % den) == 0 |
|
|
|
|
| def default(v, d): |
| return v if exists(v) else d |
|
|
|
|
| def identity(t): |
| return t |
|
|
|
|
| def add(x, y): |
| return x + y |
|
|
|
|
| def sinkhorn_log(logits, num_iters=10, tau=0.05): |
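    """Approximately project `logits` onto doubly stochastic matrices, in the log domain.

    Runs `num_iters` Sinkhorn updates against uniform marginals 1/n at temperature
    `tau`, then rescales by n so rows and columns each sum to ~1.
    """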
| n = logits.shape[-1] |
| Z = logits / tau |
| log_marginal = torch.full( |
| (n,), -math.log(n), device=logits.device, dtype=logits.dtype |
| ) |
|
|
| u = torch.zeros(n, device=Z.device, dtype=Z.dtype) |
| v = torch.zeros(n, device=Z.device, dtype=Z.dtype) |
|
|
| for _ in range(num_iters): |
| u = log_marginal - torch.logsumexp(Z + v.unsqueeze(0), dim=1) |
| v = log_marginal - torch.logsumexp(Z + u.unsqueeze(1), dim=0) |
|
|
| return torch.exp(Z + u.unsqueeze(1) + v.unsqueeze(0)) * n |
|
|
|
|
| def zeropower_via_newtonschulz(X, steps=5, eps=1e-7, coeffs=(3.0, -3.2, 1.2)): |
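    """Approximately orthogonalize `X` with a Newton-Schulz iteration.

    `X` is first normalized by its Frobenius norm; each step then applies the odd
    polynomial a*s + b*s^3 + c*s^5 to its singular values, pushing them toward 1
    (the default coefficients sum to 1). The matrix is kept in its wide orientation
    so `X @ X.T` stays as small as possible.
    """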
| a, b, c = coeffs |
|
|
| X = X / (X.norm() + eps) |
|
|
| transpose = False |
| if X.shape[0] > X.shape[1]: |
| X = X.T |
| transpose = True |
|
|
| for _ in range(steps): |
| A = X @ X.T |
| B = b * A + c * A @ A |
| X = a * X + B @ X |
|
|
| if transpose: |
| X = X.T |
|
|
| return X |
|
|
|
|
| def orthostochastic_project( |
| logits, ns_steps=5, ns_eps=1e-7, ns_coeffs=(3.0, -3.2, 1.2) |
| ): |
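    """Map logits to an (approximately) orthostochastic matrix.

    The Newton-Schulz output is close to orthogonal, so each row and column has
    unit L2 norm; squaring entrywise therefore gives rows and columns that each
    sum to ~1.
    """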
| O = zeropower_via_newtonschulz(logits, steps=ns_steps, eps=ns_eps, coeffs=ns_coeffs) |
| return O.square() |
|
|
|
|
| |
|
|
|
|
| def get_expand_reduce_stream_functions( |
| num_streams, add_stream_embed=False, dim=None, disable=False |
| ): |
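    """Return `(expand_fn, reduce_fn)` for moving between plain and stream-expanded batches.

    `expand_fn` repeats the batch across `num_streams` residual streams (optionally
    adding a learned stream embedding) and `reduce_fn` sums the streams back; both
    are identities when there is a single stream or `disable=True`.
    """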
| if num_streams == 1 or disable: |
| return (nn.Identity(), nn.Identity()) |
|
|
| if add_stream_embed: |
| assert exists(dim), ( |
| "`dim` must be passed into get_init_and_expand_reduce_stream_functions for returning an expansion function with stream embeddings added" |
| ) |
|
|
| expand_fn = StreamEmbed(num_streams, dim, expand_to_streams=True) |
| else: |
| expand_fn = Reduce( |
| pattern="b ... -> (b s) ...", reduction="repeat", s=num_streams |
| ) |
|
|
| reduce_fn = Reduce(pattern="(b s) ... -> b ...", reduction="sum", s=num_streams) |
|
|
| return expand_fn, reduce_fn |
|
|
|
|
| def get_init_and_expand_reduce_stream_functions( |
| num_streams, num_fracs=1, dim=None, add_stream_embed=False, disable=None |
| ): |
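    """Return a hyper-connection constructor plus the stream expand/reduce functions.

    The constructor is a `partial` over `HyperConnections` (or plain `Residual` when
    disabled). Illustrative sketch (`my_block` stands in for any module you want to wrap):

        init_hc, expand, reduce_ = get_init_and_expand_reduce_stream_functions(4, dim=512)
        hc = init_hc(branch=my_block, layer_index=0)
        x = expand(x)    # (b, t, d) -> (b * 4, t, d)
        x = hc(x)
        x = reduce_(x)   # (b * 4, t, d) -> (b, t, d)
    """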
| disable = default(disable, num_streams == 1 and num_fracs == 1) |
|
|
| hyper_conn_klass = HyperConnections if not disable else Residual |
|
|
| init_hyper_conn_fn = partial(hyper_conn_klass, num_streams, num_fracs=num_fracs) |
| expand_reduce_fns = get_expand_reduce_stream_functions( |
| num_streams, add_stream_embed=add_stream_embed, dim=dim, disable=disable |
| ) |
|
|
| if exists(dim): |
| init_hyper_conn_fn = partial(init_hyper_conn_fn, dim=dim) |
|
|
| return (init_hyper_conn_fn, *expand_reduce_fns) |
|
|
|
|
| |
|
|
|
|
| class RMSNorm(Module): |
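    """RMS-style norm: L2-normalize the last dim and rescale by sqrt(dim) * (gamma + 1).

    `gamma` is zero-initialized, so the layer starts out as a pure normalization.
    """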
| def __init__(self, dim): |
| super().__init__() |
| self.scale = dim**0.5 |
| self.gamma = nn.Parameter(torch.zeros(dim)) |
|
|
| def forward(self, x): |
| return F.normalize(x, dim=-1) * self.scale * (self.gamma + 1) |
|
|
|
|
| |
|
|
| |
|
|
|
|
| class Residual(Module): |
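    """Single-stream residual connection.

    Exposes the same `width_connection` / `depth_connection` / `decorate_branch`
    interface as `HyperConnections`, so it can serve as the drop-in replacement
    when hyper-connections are disabled.
    """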
| def __init__( |
| self, |
| *args, |
| branch: Module | None = None, |
| residual_transform: Module | None = None, |
| **kwargs, |
| ): |
| super().__init__() |
| self.branch = branch |
| self.residual_transform = default(residual_transform, nn.Identity()) |
|
|
| def width_connection(self, residuals): |
| return residuals, residuals, dict() |
|
|
| def depth_connection( |
| self, |
| branch_output, |
| residuals, |
| ): |
| return branch_output + self.residual_transform(residuals) |
|
|
| def decorate_branch(self, branch: Callable): |
| assert not exists(self.branch), "branch was already wrapped on init" |
|
|
| def forward_and_add_residual(residual, *args, **kwargs): |
| branch_input, add_residual = self.forward(residual) |
|
|
| branch_output = branch(branch_input, *args, **kwargs) |
|
|
| residual = add_residual(branch_output) |
|
|
| return residual |
|
|
| return forward_and_add_residual |
|
|
| def forward(self, residuals, *branch_args, **branch_kwargs): |
| branch_input, residuals, residual_kwargs = self.width_connection(residuals) |
|
|
| def add_residual_fn(branch_out): |
| (branch_out, *rest), tree_spec = tree_flatten(branch_out) |
|
|
| branch_out = self.depth_connection(branch_out, residuals, **residual_kwargs) |
|
|
| return tree_unflatten((branch_out, *rest), tree_spec) |
|
|
| if not exists(self.branch): |
| return branch_input, add_residual_fn |
|
|
| branch_output = self.branch(branch_input, *branch_args, **branch_kwargs) |
|
|
| return add_residual_fn(branch_output) |
|
|
|
|
| |
|
|
|
|
| class HyperConnections(Module): |
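    """Multi-stream residual ("hyper") connection.

    Residuals are kept as `num_residual_streams` copies stacked along the batch
    dimension. The width connection mixes streams into the branch input with a
    static matrix plus a small input-dependent (dynamic) term; the depth
    connection scatters the branch output back across streams via beta weights.
    With `mhc=True`, mixing instead uses input-independent doubly stochastic
    weights (Sinkhorn or orthostochastic projection of learned logits) together
    with softmax-normalized read/write vectors (`H_pre`, `H_post`).
    """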
| def __init__( |
| self, |
| num_residual_streams, |
| *, |
| dim, |
| branch: Module | None = None, |
| layer_index=None, |
| tanh=True, |
| channel_first=False, |
| dropout=0.0, |
        residual_transform: Module | None = None,
| add_branch_out_to_residual=True, |
| num_input_views=1, |
| depth_residual_fn=add, |
| num_fracs=1, |
| mhc=False, |
| sinkhorn_iters=10, |
| sinkhorn_tau=0.05, |
| mhc_h_res_proj="sinkhorn", |
| ns_steps=5, |
| ns_eps=1e-7, |
| ns_coeffs=(3.0, -3.2, 1.2), |
| ): |
| """ |
| Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606 |
| """ |
| super().__init__() |
|
|
| self.branch = branch |
|
|
| self.act = nn.Tanh() if tanh else nn.Identity() |
|
|
| |
|
|
| assert num_fracs >= 1 |
|
|
| self.num_fracs = num_fracs |
| self.has_fracs = num_fracs > 1 |
|
|
| self.split_fracs = Rearrange("b ... (f d) -> b ... f d", f=num_fracs) |
| self.merge_fracs = Rearrange("b ... f d -> b ... (f d)") |
|
|
| assert divisible_by(dim, num_fracs), ( |
| f"feature dimension ({dim}) must be divisible by the `num_fracs` ({num_fracs})" |
| ) |
|
|
| dim //= num_fracs |
|
|
| |
|
|
| self.norm = RMSNorm(dim) |
|
|
| assert num_residual_streams > 0, "`num_residual_streams` must be greater than 0" |
|
|
| self.num_residual_streams = num_residual_streams |
| init_residual_index = ( |
| default(layer_index, randrange(num_residual_streams)) % num_residual_streams |
| ) |
|
|
| |
|
|
| num_residual_streams_fracs = num_residual_streams * num_fracs |
| num_input_views_fracs = num_input_views * num_fracs |
|
|
| |
|
|
| assert num_input_views >= 1 |
| self.num_input_views = num_input_views |
|
|
| |
|
|
| init_alpha0 = torch.zeros((num_residual_streams_fracs, num_input_views_fracs)) |
| init_alpha0[init_residual_index, :] = 1.0 |
|
|
| self.static_alpha = nn.Parameter( |
| cat((init_alpha0, torch.eye(num_residual_streams_fracs)), dim=1) |
| ) |
|
|
| self.dynamic_alpha_fn = nn.Parameter( |
| torch.zeros(dim, num_residual_streams_fracs + num_input_views_fracs) |
| ) |
| self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2) |
|
|
| |
|
|
| self.add_branch_out_to_residual = add_branch_out_to_residual |
|
|
| if add_branch_out_to_residual: |
| self.static_beta = nn.Parameter(torch.ones(num_residual_streams_fracs)) |
|
|
| dynamic_beta_shape = ( |
| (dim,) if num_fracs == 1 else (dim, num_fracs) |
| ) |
| self.dynamic_beta_fn = nn.Parameter(torch.zeros(dynamic_beta_shape)) |
|
|
| self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2) |
|
|
| |
|
|
| self.dropout = nn.Dropout(dropout) |
|
|
| |
|
|
| self.channel_first = channel_first |
|
|
| |
|
|
| self.residual_transform = default(residual_transform, nn.Identity()) |
|
|
| |
| |
| |
|
|
| self.depth_residual_fn = depth_residual_fn |
|
|
| self.mhc = mhc |
| self.sinkhorn_iters = sinkhorn_iters |
| self.sinkhorn_tau = sinkhorn_tau |
| self.mhc_h_res_proj = mhc_h_res_proj |
| self.ns_steps = ns_steps |
| self.ns_eps = ns_eps |
| self.ns_coeffs = ns_coeffs |
|
|
| if mhc: |
| assert num_fracs == 1, "mhc currently requires num_fracs = 1" |
| assert num_input_views == 1, "mhc currently requires num_input_views = 1" |
| assert mhc_h_res_proj in ( |
| "sinkhorn", |
| "orthostochastic", |
| ), "mhc_h_res_proj must be 'sinkhorn' or 'orthostochastic'" |
|
|
| H_res_init = torch.full((num_residual_streams, num_residual_streams), -8.0) |
| H_res_init.fill_diagonal_(0.0) |
| self.H_res_logits = nn.Parameter(H_res_init) |
|
|
| H_pre_init = torch.full((num_residual_streams,), -8.0) |
| H_pre_init[init_residual_index] = 0.0 |
| self.H_pre_logits = nn.Parameter(H_pre_init) |
|
|
| if add_branch_out_to_residual: |
| self.H_post_logits = nn.Parameter(torch.zeros(num_residual_streams)) |
|
|
| def width_connection(self, residuals): |
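        """Mix the residual streams into the branch input (the "width" connection).

        Returns `(branch_input, residuals, kwargs)`: the mixed branch input, the
        (optionally transformed) residual streams, and the keyword arguments that
        `depth_connection` needs (`beta`, plus `residuals_mixed` for mhc).
        """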
| streams = self.num_residual_streams |
|
|
| maybe_transformed_residuals = self.residual_transform(residuals) |
|
|
| |
|
|
| |
|
|
| if self.channel_first: |
| residuals = rearrange(residuals, "b d ... -> b ... d") |
|
|
| |
|
|
| residuals = self.split_fracs(residuals) |
|
|
| |
|
|
| residuals = rearrange(residuals, "(b s) ... d -> b ... s d", s=streams) |
|
|
| if self.mhc: |
| residuals_mixed_source = maybe_transformed_residuals |
|
|
| if self.channel_first: |
| residuals_mixed_source = rearrange( |
| residuals_mixed_source, "b d ... -> b ... d" |
| ) |
|
|
| residuals_mixed_source = self.split_fracs(residuals_mixed_source) |
| residuals_mixed_source = rearrange( |
| residuals_mixed_source, "(b s) ... d -> b ... s d", s=streams |
| ) |
|
|
| if self.mhc_h_res_proj == "orthostochastic": |
| H_res = orthostochastic_project( |
| self.H_res_logits, |
| ns_steps=self.ns_steps, |
| ns_eps=self.ns_eps, |
| ns_coeffs=self.ns_coeffs, |
| ) |
| else: |
| H_res = sinkhorn_log( |
| self.H_res_logits, self.sinkhorn_iters, self.sinkhorn_tau |
| ) |
| H_pre = F.softmax(self.H_pre_logits, dim=-1) |
|
|
| H_post = None |
| if self.add_branch_out_to_residual: |
| H_post = F.softmax(self.H_post_logits, dim=-1) |
|
|
| residuals_mixed = einsum( |
| H_res, residuals_mixed_source, "s t, ... s d -> ... t d" |
| ) |
| branch_input = einsum(H_pre, residuals, "s, ... s d -> ... d") |
|
|
| if getattr(self, "collect_stats", False): |
| with torch.no_grad(): |
| stats = dict( |
| h_res_min=H_res.min(), |
| h_res_row_sum=H_res.sum(dim=-1).mean(), |
| h_res_col_sum=H_res.sum(dim=-2).mean(), |
| h_pre_min=H_pre.min(), |
| ) |
| if H_post is not None: |
| stats["h_post_min"] = H_post.min() |
| self.last_stats = {k: v.detach() for k, v in stats.items()} |
|
|
| if self.channel_first: |
| branch_input = rearrange(branch_input, "b ... d -> b d ...") |
|
|
| branch_input = self.merge_fracs(branch_input) |
|
|
| return ( |
| branch_input, |
| maybe_transformed_residuals, |
| dict(beta=H_post, residuals_mixed=residuals_mixed), |
| ) |
|
|
| |
|
|
| normed = self.norm(residuals) |
|
|
| |
|
|
| wc_weight = self.act(normed @ self.dynamic_alpha_fn) |
| dynamic_alpha = wc_weight * self.dynamic_alpha_scale |
|
|
| static_alpha = rearrange(self.static_alpha, "(f s) d -> f s d", s=streams) |
|
|
| alpha = dynamic_alpha + static_alpha |
|
|
        alpha = self.split_fracs(alpha)
|
|
| |
|
|
| beta = None |
|
|
| if self.add_branch_out_to_residual: |
| dc_weight = self.act(normed @ self.dynamic_beta_fn) |
|
|
| if not self.has_fracs: |
| dc_weight = rearrange(dc_weight, "... -> ... 1") |
|
|
| dynamic_beta = dc_weight * self.dynamic_beta_scale |
|
|
| static_beta = rearrange(self.static_beta, "... (s f) -> ... s f", s=streams) |
|
|
| beta = dynamic_beta + static_beta |
|
|
| if getattr(self, "collect_stats", False): |
| with torch.no_grad(): |
| num_input_views_fracs = self.num_input_views * self.num_fracs |
| alpha_branch = alpha[..., :num_input_views_fracs] |
| alpha_residual = alpha[..., num_input_views_fracs:] |
| alpha_branch_abs_mean = alpha_branch.abs().mean() |
| alpha_residual_abs_mean = alpha_residual.abs().mean() |
| stats = dict( |
| alpha_branch_mean=alpha_branch.mean(), |
| alpha_branch_abs_mean=alpha_branch_abs_mean, |
| alpha_residual_mean=alpha_residual.mean(), |
| alpha_residual_abs_mean=alpha_residual_abs_mean, |
| alpha_branch_residual_ratio=alpha_branch_abs_mean |
| / (alpha_residual_abs_mean + 1e-8), |
| ) |
| if beta is not None: |
| stats.update( |
| beta_mean=beta.mean(), |
| beta_abs_mean=beta.abs().mean(), |
| beta_min=beta.min(), |
| beta_max=beta.max(), |
| ) |
| self.last_stats = {k: v.detach() for k, v in stats.items()} |
|
|
| mix_h = einsum(alpha, residuals, "... f1 s f2 t, ... f1 s d -> ... f2 t d") |
|
|
| if self.num_input_views == 1: |
| branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :] |
| else: |
| branch_input, residuals = ( |
| mix_h[..., : self.num_input_views, :], |
| mix_h[..., self.num_input_views :, :], |
| ) |
| branch_input = rearrange(branch_input, "b ... v d -> v b ... d") |
|
|
| if self.channel_first: |
| branch_input = rearrange(branch_input, "b ... d -> b d ...") |
|
|
| |
|
|
| branch_input = self.merge_fracs(branch_input) |
|
|
| return branch_input, maybe_transformed_residuals, dict(beta=beta) |
|
|
| def depth_connection(self, branch_output, residuals, *, beta, residuals_mixed=None): |
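        """Write the branch output back into the residual streams (the "depth" connection).

        The output is weighted per stream by `beta`, folded back into the
        `(batch * streams)` layout, and combined with the residuals (the pre-mixed
        streams when mhc is enabled) before dropout.
        """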
| assert self.add_branch_out_to_residual |
|
|
| |
|
|
| branch_output = self.split_fracs(branch_output) |
|
|
| |
|
|
| if self.channel_first: |
| branch_output = rearrange(branch_output, "b d ... -> b ... d") |
|
|
| if self.mhc: |
| assert residuals_mixed is not None |
| assert beta is not None |
|
|
| branch_to_streams = einsum(branch_output, beta, "b ... d, s -> b ... s d") |
| output = residuals_mixed + branch_to_streams |
| output = rearrange(output, "b ... s d -> (b s) ... d") |
|
|
| output = self.merge_fracs(output) |
|
|
| if self.channel_first: |
| output = rearrange(output, "b ... d -> b d ...") |
|
|
| return self.dropout(output) |
|
|
| output = einsum( |
| branch_output, beta, "b ... f1 d, b ... f1 s f2 -> b ... f2 s d" |
| ) |
|
|
| output = rearrange(output, "b ... s d -> (b s) ... d") |
|
|
| |
|
|
| output = self.merge_fracs(output) |
|
|
| |
|
|
| if self.channel_first: |
| output = rearrange(output, "b ... d -> b d ...") |
|
|
| residuals = self.depth_residual_fn(output, residuals) |
|
|
| return self.dropout(residuals) |
|
|
| def decorate_branch(self, branch: Callable): |
| assert not exists(self.branch), "branch was already wrapped on init" |
|
|
| def forward_and_add_residual(residual, *args, **kwargs): |
| branch_input, add_residual = self.forward(residual) |
|
|
| branch_output = branch(branch_input, *args, **kwargs) |
|
|
| residual = add_residual(branch_output) |
|
|
| return residual |
|
|
| return forward_and_add_residual |
|
|
| def forward(self, residuals, *branch_args, **branch_kwargs): |
| branch_input, residuals, residual_kwargs = self.width_connection(residuals) |
|
|
| def add_residual_fn(branch_out): |
| if not self.add_branch_out_to_residual: |
| return branch_out |
|
|
| (branch_out, *rest), tree_spec = tree_flatten(branch_out) |
|
|
| branch_out = self.depth_connection(branch_out, residuals, **residual_kwargs) |
|
|
| return tree_unflatten((branch_out, *rest), tree_spec) |
|
|
| if not exists(self.branch): |
| return branch_input, add_residual_fn |
|
|
| branch_output = self.branch(branch_input, *branch_args, **branch_kwargs) |
|
|
| return add_residual_fn(branch_output) |
|
|
|
|
| HyperConnections.get_expand_reduce_stream_functions = staticmethod( |
| get_expand_reduce_stream_functions |
| ) |
| HyperConnections.get_init_and_expand_reduce_stream_functions = staticmethod( |
| get_init_and_expand_reduce_stream_functions |
| ) |
|
|
| |
|
|
|
|
| class StreamEmbed(Module): |
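    """Add a learned per-stream embedding, optionally repeating a plain batch across streams first."""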
| def __init__(self, num_streams, dim, channel_first=False, expand_to_streams=False): |
| super().__init__() |
| self.channel_first = channel_first |
| self.num_streams = num_streams |
|
|
| self.expand_to_streams = expand_to_streams |
| self.stream_embed = nn.Parameter(torch.zeros(num_streams, dim)) |
|
|
| def forward(self, residuals): |
| if self.expand_to_streams: |
| residuals = repeat(residuals, "b ... -> (b s) ...", s=self.num_streams) |
|
|
| if self.channel_first: |
| residuals = rearrange( |
| residuals, "(b s) d ... -> b ... s d", s=self.num_streams |
| ) |
| else: |
| residuals = rearrange( |
| residuals, "(b s) ... d -> b ... s d", s=self.num_streams |
| ) |
|
|
| residuals = residuals + self.stream_embed |
|
|
| if self.channel_first: |
| residuals = rearrange( |
| residuals, "b ... s d -> (b s) d ...", s=self.num_streams |
| ) |
| else: |
| residuals = rearrange( |
| residuals, "b ... s d -> (b s) ... d", s=self.num_streams |
| ) |
|
|
| return residuals |
|
|
|
|
| |
|
|
|
|
| class AttentionPoolReduceStream(Module): |
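    """Collapse the stream-expanded batch back to one stream via attention pooling over
    streams; the logit projection is initialized to the identity."""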
| def __init__(self, num_streams, dim, channel_first=False): |
| super().__init__() |
| self.num_streams = num_streams |
| self.channel_first = channel_first |
|
|
| self.to_attn_logits = nn.Linear(dim, dim, bias=False) |
| self.to_attn_logits.weight.data.copy_(torch.eye(dim)) |
|
|
| def forward(self, residuals): |
| if self.channel_first: |
| residuals = rearrange( |
| residuals, "(b s) d ... -> b ... s d", s=self.num_streams |
| ) |
| else: |
| residuals = rearrange( |
| residuals, "(b s) ... d -> b ... s d", s=self.num_streams |
| ) |
|
|
| attn_logits = self.to_attn_logits(residuals) |
| attn = attn_logits.softmax(dim=-2) |
|
|
| residuals = reduce(residuals * attn, "b ... s d -> b ... d", "sum") |
|
|
| if self.channel_first: |
| residuals = rearrange(residuals, "b ... d -> b d ...") |
|
|
| return residuals |
|
|
|
|
| class Rotary(torch.nn.Module): |
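    """Rotary position embedding; cos/sin tables are cached per sequence length."""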
|
|
| def __init__(self, dim, base=10000): |
| super().__init__() |
| inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) |
| self.register_buffer("inv_freq", inv_freq) |
| self.seq_len_cached = None |
| self.cos_cached = None |
| self.sin_cached = None |
|
|
| def forward(self, x): |
| seq_len = x.shape[1] |
| if seq_len != self.seq_len_cached: |
| self.seq_len_cached = seq_len |
| t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq) |
| freqs = torch.outer(t, self.inv_freq).to(x.device) |
| self.cos_cached = freqs.cos() |
| self.sin_cached = freqs.sin() |
| return self.cos_cached[None, :, None, :], self.sin_cached[None, :, None, :] |
|
|
| def apply_rotary_emb(x, cos, sin): |
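    """Rotate (first-half, second-half) feature pairs of `x` by the given cos/sin tables."""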
| assert x.ndim == 4 |
    d = x.shape[3] // 2
| x1 = x[..., :d] |
| x2 = x[..., d:] |
| y1 = x1 * cos + x2 * sin |
| y2 = x1 * (-sin) + x2 * cos |
| return torch.cat([y1, y2], 3) |
|
|
|
|
class CausalSelfAttention(nn.Module):
|
|
| def __init__(self, config): |
| super().__init__() |
|
|
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
| |
| |
|
|
| self.n_head = config.n_head |
| self.n_embd = config.n_embd |
        self.hada_scale = 1.0 / math.sqrt(config.n_embd)
        self.rotary = Rotary(config.n_embd // config.n_head)
| self.out_scale = nn.Parameter( |
| torch.ones(config.n_embd) / math.sqrt(2 * config.n_layer) |
| ) |
|
|
| self.out_bias = nn.Parameter(torch.zeros(config.n_embd)) |
|
|
|
|
| |
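        # Full causal mask kept for reference; the forward pass uses
        # F.scaled_dot_product_attention(..., is_causal=True) and never reads it.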
| self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) |
| .view(1, 1, config.block_size, config.block_size)) |
|
|
|
|
| |
| def fused_hadamard_output(self, y, B, T, C): |
| """Fuse reshape, hadamard, scaling operations""" |
| y = y.reshape(-1, C) |
        y = hadamard_transform(y, scale=self.hada_scale)
| y = y * self.out_scale |
| y = y + self.out_bias |
| return y.view(B, T, C) |
|
|
|
|
|
|
| def forward(self, x): |
| B, T, C = x.size() |
|
|
| |
| q, k, v = self.c_attn(x).split(self.n_embd, dim=2) |
| k = k.view(B, T, self.n_head, C // self.n_head) |
| q = q.view(B, T, self.n_head, C // self.n_head) |
| v = v.view(B, T, self.n_head, C // self.n_head) |
|
|
| cos, sin = self.rotary(q) |
| q = apply_rotary_emb(q, cos, sin) |
| k = apply_rotary_emb(k, cos, sin) |
|
|
| y = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True) |
| y = y.transpose(1, 2).contiguous().view(B, T, C) |
| |
| |
|
|
        y = self.fused_hadamard_output(y, B, T, C)
| |
| |
| |
|
|
| |
|
|
| return y |
|
|
| class MLP(nn.Module): |
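    """SwiGLU feed-forward block.

    `c_fc` produces gate and up halves, the hidden state is `silu(gate) * up`, and
    `c_proj` maps back to the model width. `inner_dim = 4 * n_embd * 2 / 3` keeps the
    parameter count close to a standard 4x MLP despite the extra gate projection.
    """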
| def __init__(self, config): |
| super().__init__() |
|
|
| |
| inner_dim = int(4 * config.n_embd * 2 / 3) |
| |
|
|
| |
| |
| self.inner_dim = inner_dim |
| |
| self.c_fc = nn.Linear(config.n_embd, 2 * inner_dim) |
| self.c_proj = nn.Linear(inner_dim, config.n_embd) |
| self.c_proj.NANOGPT_SCALE_INIT = 1 |
|
|
| def forward(self, x): |
| |
| x_in = self.c_fc(x) |
| x_gate, x_up = x_in.chunk(2, dim=-1) |
| x = F.silu(x_gate) * x_up |
| x = self.c_proj(x) |
|
|
| return x |
|
|
|
|
|
|
| |
| class AttnBranch(nn.Module): |
| def __init__(self, norm, attn): |
| super().__init__() |
| self.norm = norm |
| self.attn = attn |
|
|
| def forward(self, x): |
| return self.attn(self.norm(x)) |
|
|
|
|
| class Block(nn.Module): |
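    """Transformer block whose attention and MLP branches are each wrapped in a hyper-connection.

    `layer_index` values of 2 * layer_idx and 2 * layer_idx + 1 stagger which
    residual stream each sub-block initially reads from.
    """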
| def __init__(self, config, layer_idx, init_hc): |
| super().__init__() |
        self.ln_1 = nn.RMSNorm(config.n_embd, eps=1e-6)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.RMSNorm(config.n_embd, eps=1e-6)
| self.mlp = MLP(config) |
| self.attn_branch = AttnBranch(self.ln_1, self.attn) |
|
|
| hc_kwargs = dict( |
| mhc=config.mhc, |
| sinkhorn_iters=config.sinkhorn_iters, |
| sinkhorn_tau=config.sinkhorn_tau, |
| mhc_h_res_proj=config.mhc_h_res_proj, |
| ns_steps=config.ns_steps, |
| ns_eps=config.ns_eps, |
| ns_coeffs=config.ns_coeffs, |
| ) |
|
|
| self.hc_attn = init_hc( |
| dim=config.n_embd, |
| branch=self.attn_branch, |
| layer_index=layer_idx * 2, |
| **hc_kwargs, |
| ) |
|
|
| self.hc_mlp = init_hc( |
| dim=config.n_embd, |
| branch=nn.Sequential(self.ln_2, self.mlp), |
| layer_index=layer_idx * 2 + 1, |
| **hc_kwargs, |
| ) |
|
|
| def forward(self, x): |
| x = self.hc_attn(x) |
| x = self.hc_mlp(x) |
        return x


| class GPTConfig(PretrainedConfig): |
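    """Configuration for the custom GPT, including hyper-connection (hc_*) and mhc options."""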
| model_type = "custom_gpt" |
|
|
| def __init__( |
| self, |
| block_size=1024, |
| vocab_size=50304, |
| n_layer=24, |
| n_head=16, |
| n_embd=1024, |
| dropout=0.0, |
| bias=True, |
| hc_num_streams=1, |
| hc_num_fracs=1, |
| hc_disable=False, |
| mhc=False, |
| sinkhorn_iters=10, |
| sinkhorn_tau=0.05, |
| mhc_h_res_proj="sinkhorn", |
| ns_steps=5, |
| ns_eps=1e-7, |
| ns_coeffs=(3.0, -3.2, 1.2), |
| **kwargs, |
| ): |
| super().__init__(**kwargs) |
|
|
| self.block_size = block_size |
| self.vocab_size = vocab_size |
| self.n_layer = n_layer |
| self.n_head = n_head |
| self.n_embd = n_embd |
| self.dropout = dropout |
| self.bias = bias |
|
|
| self.hc_num_streams = hc_num_streams |
| self.hc_num_fracs = hc_num_fracs |
| self.hc_disable = hc_disable |
| self.mhc = mhc |
| self.sinkhorn_iters = sinkhorn_iters |
| self.sinkhorn_tau = sinkhorn_tau |
| self.mhc_h_res_proj = mhc_h_res_proj |
| self.ns_steps = ns_steps |
| self.ns_eps = ns_eps |
| self.ns_coeffs = ns_coeffs |
|
|
| |
| self.num_hidden_layers = n_layer |
| self.num_attention_heads = n_head |
| self.hidden_size = n_embd |
| self.max_position_embeddings = block_size |
|
|
| class GPT(PreTrainedModel, GenerationMixin): |
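    """Minimal Hugging Face-compatible causal LM built on hyper-connections.

    Token embeddings are expanded across residual streams, run through the blocks,
    reduced back to a single stream, normalized, and projected by a weight-tied
    `lm_head`. No KV cache is implemented.
    """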
| config_class = GPTConfig |
| |
|
|
| def __init__(self, config): |
| super().__init__(config) |
|
|
| init_hc, expand_stream, reduce_stream = ( |
| get_init_and_expand_reduce_stream_functions( |
| config.hc_num_streams, |
| num_fracs=config.hc_num_fracs, |
| disable=config.hc_disable, |
| ) |
| ) |
|
|
| self.expand_stream = expand_stream |
| self.reduce_stream = reduce_stream |
|
|
| self.transformer = nn.ModuleDict( |
| dict( |
| wte=nn.Embedding(config.vocab_size, config.n_embd), |
| h=nn.ModuleList( |
| [Block(config, i, init_hc) for i in range(config.n_layer)] |
| ), |
                ln_f=nn.RMSNorm(config.n_embd, eps=1e-6),
| ) |
| ) |
|
|
| self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) |
| self.transformer.wte.weight = self.lm_head.weight |
|
|
| self.post_init() |
|
|
| def prepare_inputs_for_generation( |
| self, |
| input_ids, |
| past_key_values=None, |
| **kwargs, |
| ): |
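        # No KV cache is implemented, so generation recomputes the full prefix on
        # every step; `past_key_values` is ignored and always returned as None.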
| |
| return { |
| "input_ids": input_ids, |
| "past_key_values": None, |
| } |
|
|
| def forward( |
| self, |
| input_ids=None, |
| attention_mask=None, |
| labels=None, |
| past_key_values=None, |
| use_cache=None, |
| **kwargs, |
| ): |
|
|
|
|
| b, t = input_ids.size() |
        assert t <= self.config.block_size, (
            f"sequence length {t} exceeds block size {self.config.block_size}"
        )
|
|
| x = self.transformer.wte(input_ids) |
| x = self.expand_stream(x) |
|
|
| for block in self.transformer.h: |
| x = block(x) |
|
|
| x = self.transformer.ln_f(x) |
| x = self.reduce_stream(x) |
|
|
| logits = self.lm_head(x) |
|
|
| loss = None |
| if labels is not None: |
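            # Labels are consumed as-is (no internal shift), so they are expected to
            # already be next-token targets aligned with `input_ids`; -100 entries
            # are ignored via F.cross_entropy's default ignore_index.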
| loss = F.cross_entropy( |
| logits.view(-1, logits.size(-1)), |
| labels.view(-1), |
| ) |
|
|
| return CausalLMOutput( |
| loss=loss, |
| logits=logits, |
| ) |
|
|