Upload JiRackTernaryPyTorch_70b.py

JiRackTernaryPyTorch_70b.py (ADDED, +193 -0)
@@ -0,0 +1,193 @@
# ==============================================================================
# COPYRIGHT (C) 2025 KONSTANTIN VLADIMIROVICH GRABKO. ALL RIGHTS RESERVED.
# PATENT PENDING | CMS MANHATTAN JIRACK TECHNOLOGY
#
# This software is licensed under the Commercial License Agreement V.1.2.
# Any use, modification, or distribution of this code requires compliance with
# the terms found in the LICENSE.md file in the root directory.
#
# NO PATENTING RIGHTS: Users are strictly prohibited from filing patent claims
# based on the BRE or SWA architectures disclosed herein.
# Contact: grabko@cmsmanhattan.com | +1 (516) 777-0945
# ==============================================================================

import math
from typing import Optional, List, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast

class JiRackTernaryConfig(PretrainedConfig):
    model_type = "jirack_ternary_transformer"

    def __init__(
        self,
        vocab_size=128256,
        hidden_size=8192,
        num_hidden_layers=80,
        num_attention_heads=64,
        intermediate_size=28672,
        max_position_embeddings=4096,
        rms_norm_eps=1e-5,
        dropout_rate=0.0,
        window_size=2048,  # sliding-window KV-cache length
        author="Author: Konstantin Vladimirovich Grabko (CMS Manhattan) 2025",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.max_position_embeddings = max_position_embeddings
        self.rms_norm_eps = rms_norm_eps
        self.dropout_rate = dropout_rate
        self.window_size = window_size
        self.author = author

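# Note: the defaults above mirror a 70B-class Llama-style stack (8192 hidden,
# 80 layers, 64 heads, 28672 intermediate) with a 128,256-token vocabulary;
# see the __main__ sketch at the end of the file for a scaled-down
# configuration suitable for smoke testing.
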
class SignatureLayer(nn.Module):
    """Authorship watermark: mixes a fixed, author-seeded projection into x."""

    def __init__(self, dim, author_name):
        super().__init__()
        self.gate = nn.Parameter(torch.ones(dim))
        # Seed a local generator from the author string instead of calling
        # torch.manual_seed(), which would clobber global RNG state mid-init.
        seed = sum(ord(c) for c in author_name)
        gen = torch.Generator().manual_seed(seed)
        self.signage_cms = nn.Parameter(torch.randn(dim, dim, generator=gen) * 0.005)

    def forward(self, x):
        sig = torch.tanh(F.linear(x, self.signage_cms))
        return x * torch.sigmoid(self.gate) + sig

class PhaserizationLayer(nn.Module):
    """Re-expresses features as magnitude * cos(phase) with a learned phase shift."""

    def __init__(self, dim):
        super().__init__()
        self.phase_shift = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        magnitude = torch.norm(x, dim=-1, keepdim=True)
        # Pairwise phase between each feature and its rolled neighbour; the
        # 1e-6 keeps atan2 away from the degenerate (0, 0) corner.
        phase = torch.atan2(x, x.roll(1, -1) + 1e-6) + self.phase_shift
        return magnitude * torch.cos(phase)

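# A minimal sketch (not used by the model) showing that PhaserizationLayer is
# shape-preserving: the L2 magnitude over the last dim broadcasts against the
# per-feature cos(phase) term. The sizes below are arbitrary assumptions.
def _demo_phaserization():
    layer = PhaserizationLayer(dim=8)
    x = torch.randn(2, 4, 8)
    y = layer(x)
    assert y.shape == x.shape  # (2, 4, 8)
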
class JiRackBitLinear(nn.Linear):
    """BitNet-style linear layer with ternary {-1, 0, +1} weight quantization."""

    def __init__(self, in_features, out_features, bias=False, num_layers=80):
        super().__init__(in_features, out_features, bias)
        # Depth-scaled init keeps residual-stream variance stable across layers.
        std = 0.02 / math.sqrt(2 * num_layers)
        nn.init.normal_(self.weight, mean=0.0, std=std)

    def forward(self, x):
        w = self.weight
        # Ternary quantization with a straight-through estimator: the forward
        # pass sees clamp(round(w / gamma), -1, 1) * gamma, while gradients
        # flow to the latent full-precision weights through the detach() trick.
        gamma = w.abs().mean() + 1e-9
        w_quant = torch.clamp(torch.round(w / gamma), -1, 1)
        w_final = w + (w_quant * gamma - w).detach()
        # Activations: zero-center, then clamp to [-1.5, 1.5] with the same STE.
        x_norm = x - x.mean(dim=-1, keepdim=True)
        x_quant = x_norm + (torch.clamp(x_norm, -1.5, 1.5) - x_norm).detach()
        return F.linear(x_quant, w_final, self.bias)

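# A minimal sketch (not used by the model) of the ternary mapping that
# JiRackBitLinear applies on the forward pass. For w = [[0.30, -0.05],
# [0.01, -0.40]], gamma = mean|w| = 0.19, so round(w / gamma) clamps to
# [[1, 0], [0, -1]] and the effective weights are [[0.19, 0], [0, -0.19]];
# gradients still reach the latent full-precision w via the detach() trick.
def _demo_ternary_quantization():
    w = torch.tensor([[0.30, -0.05], [0.01, -0.40]])
    gamma = w.abs().mean() + 1e-9
    w_quant = torch.clamp(torch.round(w / gamma), -1, 1)
    print(w_quant)          # tensor([[ 1.,  0.], [ 0., -1.]])
    print(w_quant * gamma)  # effective forward-pass weights
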
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # x / sqrt(mean(x^2) + eps), followed by a learned per-channel gain.
        return (x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)) * self.weight

def precompute_freqs_cis(dim: int, seq_len: int, theta: float = 10000.0):
    # Rotary-embedding table: complex e^{i * t * freq}, shape (seq_len, dim // 2).
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(seq_len).float()
    freqs = torch.outer(t, freqs)
    return torch.polar(torch.ones_like(freqs), freqs)


def apply_rotary_emb(xq, xk, freqs_cis):
    # View (..., head_dim) as (..., head_dim // 2) complex pairs and rotate.
    xq_f = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_f = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = freqs_cis[None, None, :xq_f.shape[2], :]
    xq_out = torch.view_as_real(xq_f * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_f * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

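# A minimal sketch (not used by the model) of the rotary-embedding helpers:
# precompute_freqs_cis(head_dim, seq_len) yields a complex table of shape
# (seq_len, head_dim // 2), and apply_rotary_emb rotates q/k of shape
# (batch, heads, seq, head_dim) without changing their shapes. The sizes
# below are arbitrary assumptions.
def _demo_rotary():
    freqs = precompute_freqs_cis(dim=8, seq_len=16)  # complex, (16, 4)
    q = torch.randn(1, 2, 16, 8)
    k = torch.randn(1, 2, 16, 8)
    q_rot, k_rot = apply_rotary_emb(q, k, freqs)
    assert q_rot.shape == q.shape and k_rot.shape == k.shape
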
class JiRackAttention(nn.Module):
    def __init__(self, config: JiRackTernaryConfig):
        super().__init__()
        self.n_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.q_proj = JiRackBitLinear(config.hidden_size, config.hidden_size, num_layers=config.num_hidden_layers)
        self.k_proj = JiRackBitLinear(config.hidden_size, config.hidden_size, num_layers=config.num_hidden_layers)
        self.v_proj = JiRackBitLinear(config.hidden_size, config.hidden_size, num_layers=config.num_hidden_layers)
        self.out_proj = JiRackBitLinear(config.hidden_size, config.hidden_size, num_layers=config.num_hidden_layers)
        self.phaser = PhaserizationLayer(config.hidden_size)
        self.scale = self.head_dim ** -0.5
        self.window_size = config.window_size

    def forward(self, x, freqs_cis, pos_offset, past_kv=None):
        B, T, D = x.shape
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        q, k = apply_rotary_emb(q, k, freqs_cis[pos_offset : pos_offset + T])
        if past_kv is not None:
            # Sliding-window KV cache: keep only the most recent window_size keys.
            pk, pv = past_kv
            k = torch.cat([pk, k], dim=2)[:, :, -self.window_size:]
            v = torch.cat([pv, v], dim=2)[:, :, -self.window_size:]
        new_kv = (k.detach(), v.detach())
        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        # Cache-aware causal mask over the (T, S) score matrix, S = cached + new
        # keys: query i may attend keys 0 .. (S - T + i).
        S = k.size(2)
        mask = torch.triu(
            torch.full((T, S), float('-inf'), device=x.device),
            diagonal=S - T + 1,
        ).unsqueeze(0).unsqueeze(0)
        attn = F.softmax((attn + mask).float(), dim=-1).type_as(x)
        out = torch.matmul(attn, v).transpose(1, 2).reshape(B, T, D)
        return self.phaser(self.out_proj(out)), new_kv

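# A minimal sketch (not used by the model) of the cache-aware causal mask built
# in JiRackAttention.forward: with T = 3 fresh queries against S = 5 keys
# (2 cached + 3 new), diagonal = S - T + 1 = 3 masks exactly the future keys,
# so query i may attend keys 0 .. (S - T + i).
def _demo_causal_mask():
    T, S = 3, 5
    mask = torch.triu(torch.full((T, S), float('-inf')), diagonal=S - T + 1)
    print(mask)  # row 0 masks keys 3-4, row 1 masks key 4, row 2 masks nothing
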
class JiRackSwiGLU(nn.Module):
    def __init__(self, config: JiRackTernaryConfig):
        super().__init__()
        self.w1 = JiRackBitLinear(config.hidden_size, config.intermediate_size, num_layers=config.num_hidden_layers)
        self.w3 = JiRackBitLinear(config.hidden_size, config.intermediate_size, num_layers=config.num_hidden_layers)
        self.w2 = JiRackBitLinear(config.intermediate_size, config.hidden_size, num_layers=config.num_hidden_layers)

    def forward(self, x):
        # SwiGLU: w2(silu(w1 x) * w3 x).
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

class JiRackBlock(nn.Module):
    def __init__(self, config: JiRackTernaryConfig):
        super().__init__()
        self.attn = JiRackAttention(config)
        self.ffn = JiRackSwiGLU(config)
        self.norm1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.norm2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.signature = SignatureLayer(config.hidden_size, author_name=config.author)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, x, freqs_cis, pos_offset, past_kv=None):
        # Pre-norm residual block: attention, then FFN, then the signature mix.
        h, new_kv = self.attn(self.norm1(x), freqs_cis, pos_offset, past_kv)
        x = x + self.dropout(h)
        x = self.signature(x + self.dropout(self.ffn(self.norm2(x))))
        return x, new_kv

class JiRackTernary70B(PreTrainedModel):
    config_class = JiRackTernaryConfig

    def __init__(self, config: JiRackTernaryConfig):
        super().__init__(config)
        self.token_emb = nn.Embedding(config.vocab_size, config.hidden_size)
        self.blocks = nn.ModuleList([JiRackBlock(config) for _ in range(config.num_hidden_layers)])
        self.ln_f = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis(config.hidden_size // config.num_attention_heads, config.max_position_embeddings),
            persistent=False,
        )
        self.register_buffer("proof_of_authorship", torch.tensor([ord(c) for c in config.author], dtype=torch.uint8))
        self.post_init()
        # Tie input embeddings and output head after post_init so weight
        # initialization does not untie them.
        self.lm_head.weight = self.token_emb.weight
        self.gradient_checkpointing = False

    def get_author_info(self):
        return "".join([chr(c) for c in self.proof_of_authorship.tolist()])

    def forward(self, input_ids, labels=None, past_key_values=None, return_dict=True, **kwargs):
        x = self.token_emb(input_ids)
        pos_offset = past_key_values[0][0].size(2) if past_key_values else 0
        new_kvs = []
        for i, block in enumerate(self.blocks):
            if self.gradient_checkpointing and self.training:
                x, kv = torch.utils.checkpoint.checkpoint(block, x, self.freqs_cis, pos_offset, None, use_reentrant=False)
            else:
                x, kv = block(x, self.freqs_cis, pos_offset, past_key_values[i] if past_key_values else None)
            if not self.training or past_key_values:
                new_kvs.append(kv)
        logits = self.lm_head(self.ln_f(x))
        loss = None
        if labels is not None:
            # Shifted next-token loss; ignore_index=-100 follows the Hugging
            # Face convention for positions excluded from the loss.
            loss = F.cross_entropy(
                logits[:, :-1, :].reshape(-1, self.config.vocab_size),
                labels[:, 1:].reshape(-1),
                ignore_index=-100,
            )
        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=new_kvs if new_kvs else None)
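

# A minimal smoke-test sketch, not the released 70B configuration: the small
# hyperparameters below are illustrative assumptions chosen so the forward
# pass runs on CPU.
if __name__ == "__main__":
    cfg = JiRackTernaryConfig(
        vocab_size=1000, hidden_size=128, num_hidden_layers=2,
        num_attention_heads=4, intermediate_size=256,
        max_position_embeddings=128, window_size=64,
    )
    model = JiRackTernary70B(cfg)
    ids = torch.randint(0, cfg.vocab_size, (1, 16))
    out = model(ids, labels=ids)
    print("loss:", out.loss.item(), "logits shape:", tuple(out.logits.shape))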