# File size: 8,691 Bytes
# bc0c3f5
# ==============================================================================
# COPYRIGHT (C) 2025 KONSTANTIN VLADIMIROVICH GRABKO. ALL RIGHTS RESERVED.
# PATENT PENDING | CMS MANHATTAN JIRACK TECHNOLOGY
#
# This software is licensed under the Commercial License Agreement V.1.2.
# Any use, modification, or distribution of this code requires compliance with
# the terms found in the LICENSE.md file in the root directory.
#
# NO PATENTING RIGHTS: Users are strictly prohibited from filing patent claims
# based on the BRE or SWA architectures disclosed herein.
# Contact: grabko@cmsmanhattan.com | +1 (516) 777-0945
# ==============================================================================
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, List, Tuple, Union
import math
import torch.utils.checkpoint
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
class TernaryConfig(PretrainedConfig):
    """Hyper-parameters for the ternary-quantized transformer language model.

    Mirrors the usual GPT-style knobs plus `window_size`, the sliding-window
    KV-cache length used by the attention layers during cached decoding.
    """

    model_type = "ternary_transformer"

    def __init__(
        self,
        vocab_size=50257,
        hidden_size=3072,
        num_hidden_layers=24,
        num_attention_heads=32,
        intermediate_size=12288,
        max_position_embeddings=2048,
        rms_norm_eps=1e-6,
        dropout_rate=0.1,
        window_size=512,
        **kwargs
    ):
        super().__init__(**kwargs)
        # Store every model hyper-parameter verbatim on the config instance.
        for key, value in (
            ("vocab_size", vocab_size),
            ("hidden_size", hidden_size),
            ("num_hidden_layers", num_hidden_layers),
            ("num_attention_heads", num_attention_heads),
            ("intermediate_size", intermediate_size),
            ("max_position_embeddings", max_position_embeddings),
            ("rms_norm_eps", rms_norm_eps),
            ("dropout_rate", dropout_rate),
            ("window_size", window_size),
        ):
            setattr(self, key, value)
class BitLinear(nn.Linear):
    """Linear layer with ternary {-1, 0, +1} weights via a straight-through estimator.

    Forward pass uses the quantized weights `round(w / gamma).clamp(-1, 1) * gamma`
    (gamma = mean |w|); the backward pass sees the full-precision weights because
    the quantization delta is wrapped in `.detach()`. Activations are mean-centered
    per feature vector and clamped to [-1.5, 1.5], also with an STE.
    """

    def __init__(self, in_features, out_features, bias=False, num_layers=24):
        super().__init__(in_features, out_features, bias)
        # Depth-scaled init: shrink std as the network gets deeper.
        nn.init.normal_(self.weight, mean=0.0, std=0.02 / math.sqrt(2 * num_layers))

    def forward(self, x):
        weight = self.weight
        # Per-tensor scale; epsilon guards against an all-zero weight tensor.
        scale = weight.abs().mean() + 1e-9
        ternary = torch.clamp(torch.round(weight / scale), -1, 1)
        # STE: forward value is ternary * scale, gradient flows to `weight`.
        quant_weight = weight + (ternary * scale - weight).detach()
        centered = x - x.mean(dim=-1, keepdim=True)
        # STE activation clamp to [-1.5, 1.5] (forward clamps, backward is identity).
        quant_act = centered + (torch.clamp(centered, -1.5, 1.5) - centered).detach()
        return F.linear(quant_act, quant_weight, self.bias)
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization: no mean subtraction, no bias.

    Scales each feature vector by 1 / rms(x) and a learned per-dim gain.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (x * inv_rms) * self.weight
def precompute_freqs_cis(dim: int, seq_len: int, theta: float = 10000.0):
    """Precompute RoPE rotation factors as unit complex numbers.

    Returns a (seq_len, dim // 2) complex tensor whose entry [t, j] is
    exp(i * t * theta^(-2j/dim)).
    """
    exponents = torch.arange(0, dim, 2).float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(seq_len).float()
    angles = torch.outer(positions, inv_freq)
    # Unit magnitude, varying phase -> pure rotations.
    return torch.polar(torch.ones_like(angles), angles)
def apply_rotary_emb(xq, xk, freqs_cis):
    """Apply rotary position embeddings to query and key tensors.

    xq, xk: (batch, heads, seq, head_dim) with even head_dim;
    freqs_cis: (seq, head_dim // 2) complex rotation factors.
    Returns rotated tensors in the input dtypes.
    """
    def as_complex(t):
        # Pair up the last dimension and view it as complex numbers.
        return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

    q_c = as_complex(xq)
    k_c = as_complex(xk)
    # Broadcast the rotation table over batch and head dimensions.
    rot = freqs_cis[None, None, : q_c.shape[2], :]
    q_out = torch.view_as_real(q_c * rot).flatten(3)
    k_out = torch.view_as_real(k_c * rot).flatten(3)
    return q_out.type_as(xq), k_out.type_as(xk)
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with rotary embeddings and a
    sliding-window KV cache.

    All four projections are ternary-quantized BitLinear layers. When a
    `past_kv` cache is supplied, the concatenated key/value sequences are
    truncated to the most recent `window_size` positions.
    """
    def __init__(self, config: TernaryConfig):
        super().__init__()
        self.n_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.q_proj = BitLinear(config.hidden_size, config.hidden_size, num_layers=config.num_hidden_layers)
        self.k_proj = BitLinear(config.hidden_size, config.hidden_size, num_layers=config.num_hidden_layers)
        self.v_proj = BitLinear(config.hidden_size, config.hidden_size, num_layers=config.num_hidden_layers)
        self.out_proj = BitLinear(config.hidden_size, config.hidden_size, num_layers=config.num_hidden_layers)
        # Standard 1/sqrt(head_dim) attention scaling.
        self.scale = self.head_dim ** -0.5
        self.window_size = config.window_size
    def forward(self, x, freqs_cis, pos_offset, past_kv=None):
        # x: (B, T, D). freqs_cis: complex RoPE table indexed by absolute
        # position; pos_offset is the absolute position of x's first token
        # (nonzero during cached decoding). past_kv: optional (k, v) tensors
        # of shape (B, n_heads, T_past, head_dim).
        B, T, D = x.shape
        # Project and reshape to (B, n_heads, T, head_dim).
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        # Rotate q/k by the RoPE factors for absolute positions [pos_offset, pos_offset + T).
        q, k = apply_rotary_emb(q, k, freqs_cis[pos_offset : pos_offset + T])
        if past_kv is not None:
            # Append the new keys/values to the cache, then keep only the most
            # recent `window_size` positions (sliding-window attention).
            # NOTE(review): the window is applied only on the cached path; an
            # uncached forward with T > window_size attends to everything.
            pk, pv = past_kv
            k = torch.cat([pk, k], dim=2)[:, :, -self.window_size:]
            v = torch.cat([pv, v], dim=2)[:, :, -self.window_size:]
        # Detach so the returned cache carries no autograd graph.
        new_kv = (k.detach(), v.detach())
        attn = (torch.matmul(q, k.transpose(-2, -1)) * self.scale)
        # Causal mask over a (T, K) score matrix where K = k.size(2) >= T: the
        # T queries occupy the LAST T key positions, so diagonal = K - T + 1
        # masks exactly the keys strictly after each query. For T == 1 the
        # mask is all zeros (a single new token sees the whole window).
        mask = torch.triu(torch.full((T, k.size(2)), float('-inf'), device=x.device), diagonal=k.size(2)-T+1).unsqueeze(0).unsqueeze(0)
        # Softmax in float32 for numerical stability, then cast back to x's dtype.
        attn = F.softmax((attn + mask).float(), dim=-1).type_as(x)
        # Weighted sum of values, merged back to (B, T, D).
        out = torch.matmul(attn, v).transpose(1, 2).reshape(B, T, D)
        return self.out_proj(out), new_kv
class SwiGLUFeedForward(nn.Module):
    """SwiGLU feed-forward block: w2(silu(w1(x)) * w3(x)).

    All three projections are ternary-quantized BitLinear layers.
    """

    def __init__(self, config: TernaryConfig):
        super().__init__()
        width = config.hidden_size
        inner = config.intermediate_size
        depth = config.num_hidden_layers
        self.w1 = BitLinear(width, inner, num_layers=depth)  # gate projection
        self.w3 = BitLinear(width, inner, num_layers=depth)  # value projection
        self.w2 = BitLinear(inner, width, num_layers=depth)  # down projection

    def forward(self, x):
        gate = F.silu(self.w1(x))
        return self.w2(gate * self.w3(x))
class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: attention then SwiGLU FFN, each wrapped in
    a dropout + residual connection."""

    def __init__(self, config: TernaryConfig):
        super().__init__()
        self.attn = MultiHeadAttention(config)
        self.ffn = SwiGLUFeedForward(config)
        self.norm1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.norm2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, x, freqs_cis, pos_offset, past_kv=None):
        # Attention sub-layer (pre-norm) with residual; also yields the updated KV cache.
        attn_out, new_kv = self.attn(self.norm1(x), freqs_cis, pos_offset, past_kv)
        x = x + self.dropout(attn_out)
        # Feed-forward sub-layer (pre-norm) with residual.
        ffn_out = self.ffn(self.norm2(x))
        return x + self.dropout(ffn_out), new_kv
class TernaryTransformer(PreTrainedModel):
    """Causal language model built from ternary-quantized transformer blocks.

    Hugging Face-compatible: returns CausalLMOutputWithPast and supports a
    per-layer KV cache plus (non-reentrant) gradient checkpointing.
    """
    config_class = TernaryConfig
    supports_gradient_checkpointing = True
    def __init__(self, config: TernaryConfig):
        super().__init__(config)
        self.token_emb = nn.Embedding(config.vocab_size, config.hidden_size)
        self.blocks = nn.ModuleList([TransformerBlock(config) for _ in range(config.num_hidden_layers)])
        self.ln_f = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Complex RoPE table for the full context length; persistent=False
        # keeps it out of the state_dict.
        self.register_buffer("freqs_cis", precompute_freqs_cis(config.hidden_size // config.num_attention_heads, config.max_position_embeddings), persistent=False)
        self.post_init()
        # Tie the output head to the input embedding AFTER post_init so the HF
        # weight-initialization pass cannot overwrite/untie it.
        self.lm_head.weight = self.token_emb.weight
        self.gradient_checkpointing = False
    def _set_gradient_checkpointing(self, module, value=False):
        # Hook invoked by transformers' gradient_checkpointing_enable/disable.
        if isinstance(module, (TernaryTransformer, TransformerBlock)):
            self.gradient_checkpointing = value
    def forward(self, input_ids, labels=None, past_key_values=None, return_dict=True, **kwargs):
        """Run the model.

        input_ids: (B, T) token ids. labels: optional (B, T) targets; loss is
        computed with a one-token shift. past_key_values: optional list of
        per-layer (k, v) caches. `return_dict` is accepted for HF-API
        compatibility but not consulted — the output is always a
        CausalLMOutputWithPast.
        """
        x = self.token_emb(input_ids)
        # Absolute position of the first new token == cached sequence length.
        pos_offset = past_key_values[0][0].size(2) if past_key_values and past_key_values[0] is not None else 0
        new_kvs = []
        for i, block in enumerate(self.blocks):
            if self.gradient_checkpointing and self.training:
                # Checkpointed path recomputes activations on backward; it
                # passes past_kv=None (no cache while checkpointing).
                x, kv = torch.utils.checkpoint.checkpoint(block, x, self.freqs_cis, pos_offset, None, use_reentrant=False)
            else:
                x, kv = block(x, self.freqs_cis, pos_offset, past_key_values[i] if past_key_values else None)
            # Collect the cache only for inference, or when a cache was passed in.
            if not self.training or past_key_values: new_kvs.append(kv)
        logits = self.lm_head(self.ln_f(x))
        loss = None
        if labels is not None:
            # Shift so position t predicts token t+1 (F.cross_entropy's default
            # ignore_index of -100 matches the HF label-padding convention).
            loss = F.cross_entropy(logits[:, :-1, :].reshape(-1, self.config.vocab_size), labels[:, 1:].reshape(-1))
        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=new_kvs if new_kvs else None)