kgrabko commited on
Commit
ab8854e
·
verified ·
1 Parent(s): aa55f9d

Upload JiRackTernaryPyTorch_70b.py

Browse files
Files changed (1) hide show
  1. JiRackTernaryPyTorch_70b.py +193 -0
JiRackTernaryPyTorch_70b.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ==============================================================================
2
+ # COPYRIGHT (C) 2025 KONSTANTIN VLADIMIROVICH GRABKO. ALL RIGHTS RESERVED.
3
+ # PATENT PENDING | CMS MANHATTAN JIRACK TECHNOLOGY
4
+ #
5
+ # This software is licensed under the Commercial License Agreement V.1.2.
6
+ # Any use, modification, or distribution of this code requires compliance with
7
+ # the terms found in the LICENSE.md file in the root directory.
8
+ #
9
+ # NO PATENTING RIGHTS: Users are strictly prohibited from filing patent claims
10
+ # based on the BRE or SWA architectures disclosed herein.
11
+ # Contact: grabko@cmsmanhattan.com | +1 (516) 777-0945
12
+ # ==============================================================================
13
+
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from typing import Optional, List, Tuple, Union
19
+ import math
20
+ import torch.utils.checkpoint
21
+ from transformers import PreTrainedModel, PretrainedConfig
22
+ from transformers.modeling_outputs import CausalLMOutputWithPast
23
+
24
+ class JiRackTernaryConfig(PretrainedConfig):
25
+ model_type = "jirack_ternary_transformer"
26
+ def __init__(
27
+ self,
28
+ vocab_size=128256,
29
+ hidden_size=8192,
30
+ num_hidden_layers=80,
31
+ num_attention_heads=64,
32
+ intermediate_size=28672,
33
+ max_position_embeddings=4096,
34
+ rms_norm_eps=1e-5,
35
+ dropout_rate=0.0,
36
+ window_size=2048,
37
+ author="Author: Konstantin Vladimirovich Grabko (CMS Manhattan) 2025",
38
+ **kwargs
39
+ ):
40
+ super().__init__(**kwargs)
41
+ self.vocab_size = vocab_size
42
+ self.hidden_size = hidden_size
43
+ self.num_hidden_layers = num_hidden_layers
44
+ self.num_attention_heads = num_attention_heads
45
+ self.intermediate_size = intermediate_size
46
+ self.max_position_embeddings = max_position_embeddings
47
+ self.rms_norm_eps = rms_norm_eps
48
+ self.dropout_rate = dropout_rate
49
+ self.window_size = window_size
50
+ self.author = author
51
+
52
+ class SignatureLayer(nn.Module):
53
+ def __init__(self, dim, author_name):
54
+ super().__init__()
55
+ self.gate = nn.Parameter(torch.ones(dim))
56
+ seed = sum(ord(c) for c in author_name)
57
+ torch.manual_seed(seed)
58
+ self.signage_cms = nn.Parameter(torch.randn(dim, dim) * 0.005)
59
+ def forward(self, x):
60
+ sig = torch.tanh(F.linear(x, self.signage_cms))
61
+ return x * torch.sigmoid(self.gate) + sig
62
+
63
+ class PhaserizationLayer(nn.Module):
64
+ def __init__(self, dim):
65
+ super().__init__()
66
+ self.phase_shift = nn.Parameter(torch.zeros(dim))
67
+ def forward(self, x):
68
+ magnitude = torch.norm(x, dim=-1, keepdim=True)
69
+ phase = torch.atan2(x, x.roll(1, -1) + 1e-6) + self.phase_shift
70
+ return magnitude * torch.cos(phase)
71
+
72
+ class JiRackBitLinear(nn.Linear):
73
+ def __init__(self, in_features, out_features, bias=False, num_layers=80):
74
+ super().__init__(in_features, out_features, bias)
75
+ std = 0.02 / math.sqrt(2 * num_layers)
76
+ nn.init.normal_(self.weight, mean=0.0, std=std)
77
+ def forward(self, x):
78
+ w = self.weight
79
+ gamma = w.abs().mean() + 1e-9
80
+ w_quant = torch.clamp(torch.round(w / gamma), -1, 1)
81
+ w_final = w + (w_quant * gamma - w).detach()
82
+ x_norm = x - x.mean(dim=-1, keepdim=True)
83
+ x_quant = x_norm + (torch.clamp(x_norm, -1.5, 1.5) - x_norm).detach()
84
+ return F.linear(x_quant, w_final, self.bias)
85
+
86
+ class RMSNorm(nn.Module):
87
+ def __init__(self, dim: int, eps: float = 1e-5):
88
+ super().__init__()
89
+ self.eps = eps
90
+ self.weight = nn.Parameter(torch.ones(dim))
91
+ def forward(self, x):
92
+ return (x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)) * self.weight
93
+
94
+ def precompute_freqs_cis(dim: int, seq_len: int, theta: float = 10000.0):
95
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
96
+ t = torch.arange(seq_len).float()
97
+ freqs = torch.outer(t, freqs)
98
+ return torch.polar(torch.ones_like(freqs), freqs)
99
+
100
+ def apply_rotary_emb(xq, xk, freqs_cis):
101
+ xq_f = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
102
+ xk_f = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
103
+ freqs_cis = freqs_cis[None, None, :xq_f.shape[2], :]
104
+ xq_out = torch.view_as_real(xq_f * freqs_cis).flatten(3)
105
+ xk_out = torch.view_as_real(xk_f * freqs_cis).flatten(3)
106
+ return xq_out.type_as(xq), xk_out.type_as(xk)
107
+
108
+ class JiRackAttention(nn.Module):
109
+ def __init__(self, config: JiRackTernaryConfig):
110
+ super().__init__()
111
+ self.n_heads = config.num_attention_heads
112
+ self.head_dim = config.hidden_size // config.num_attention_heads
113
+ self.q_proj = JiRackBitLinear(config.hidden_size, config.hidden_size, num_layers=config.num_hidden_layers)
114
+ self.k_proj = JiRackBitLinear(config.hidden_size, config.hidden_size, num_layers=config.num_hidden_layers)
115
+ self.v_proj = JiRackBitLinear(config.hidden_size, config.hidden_size, num_layers=config.num_hidden_layers)
116
+ self.out_proj = JiRackBitLinear(config.hidden_size, config.hidden_size, num_layers=config.num_hidden_layers)
117
+ self.phaser = PhaserizationLayer(config.hidden_size)
118
+ self.scale = self.head_dim ** -0.5
119
+ self.window_size = config.window_size
120
+
121
+ def forward(self, x, freqs_cis, pos_offset, past_kv=None):
122
+ B, T, D = x.shape
123
+ q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
124
+ k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
125
+ v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
126
+ q, k = apply_rotary_emb(q, k, freqs_cis[pos_offset : pos_offset + T])
127
+ if past_kv is not None:
128
+ pk, pv = past_kv
129
+ k = torch.cat([pk, k], dim=2)[:, :, -self.window_size:]
130
+ v = torch.cat([pv, v], dim=2)[:, :, -self.window_size:]
131
+ new_kv = (k.detach(), v.detach())
132
+ attn = (torch.matmul(q, k.transpose(-2, -1)) * self.scale)
133
+ mask = torch.triu(torch.full((T, k.size(2)), float('-inf'), device=x.device), diagonal=k.size(2)-T+1).unsqueeze(0).unsqueeze(0)
134
+ attn = F.softmax((attn + mask).float(), dim=-1).type_as(x)
135
+ out = torch.matmul(attn, v).transpose(1, 2).reshape(B, T, D)
136
+ return self.phaser(self.out_proj(out)), new_kv
137
+
138
+ class JiRackSwiGLU(nn.Module):
139
+ def __init__(self, config: JiRackTernaryConfig):
140
+ super().__init__()
141
+ self.w1 = JiRackBitLinear(config.hidden_size, config.intermediate_size, num_layers=config.num_hidden_layers)
142
+ self.w3 = JiRackBitLinear(config.hidden_size, config.intermediate_size, num_layers=config.num_hidden_layers)
143
+ self.w2 = JiRackBitLinear(config.intermediate_size, config.hidden_size, num_layers=config.num_hidden_layers)
144
+ def forward(self, x):
145
+ return self.w2(F.silu(self.w1(x)) * self.w3(x))
146
+
147
+ class JiRackBlock(nn.Module):
148
+ def __init__(self, config: JiRackTernaryConfig):
149
+ super().__init__()
150
+ self.attn = JiRackAttention(config)
151
+ self.ffn = JiRackSwiGLU(config)
152
+ self.norm1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
153
+ self.norm2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
154
+ self.signature = SignatureLayer(config.hidden_size, author_name=config.author)
155
+ self.dropout = nn.Dropout(config.dropout_rate)
156
+ def forward(self, x, freqs_cis, pos_offset, past_kv=None):
157
+ h, new_kv = self.attn(self.norm1(x), freqs_cis, pos_offset, past_kv)
158
+ x = x + self.dropout(h)
159
+ x = self.signature(x + self.dropout(self.ffn(self.norm2(x))))
160
+ return x, new_kv
161
+
162
+ class JiRackTernary70B(PreTrainedModel):
163
+ config_class = JiRackTernaryConfig
164
+ def __init__(self, config: JiRackTernaryConfig):
165
+ super().__init__(config)
166
+ self.token_emb = nn.Embedding(config.vocab_size, config.hidden_size)
167
+ self.blocks = nn.ModuleList([JiRackBlock(config) for _ in range(config.num_hidden_layers)])
168
+ self.ln_f = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
169
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
170
+ self.register_buffer("freqs_cis", precompute_freqs_cis(config.hidden_size // config.num_attention_heads, config.max_position_embeddings), persistent=False)
171
+ self.register_buffer("proof_of_authorship", torch.tensor([ord(c) for c in config.author], dtype=torch.uint8))
172
+ self.post_init()
173
+ self.lm_head.weight = self.token_emb.weight
174
+ self.gradient_checkpointing = False
175
+
176
+ def get_author_info(self):
177
+ return "".join([chr(c) for c in self.proof_of_authorship.tolist()])
178
+
179
+ def forward(self, input_ids, labels=None, past_key_values=None, return_dict=True, **kwargs):
180
+ x = self.token_emb(input_ids)
181
+ pos_offset = past_key_values[0][0].size(2) if past_key_values else 0
182
+ new_kvs = []
183
+ for i, block in enumerate(self.blocks):
184
+ if self.gradient_checkpointing and self.training:
185
+ x, kv = torch.utils.checkpoint.checkpoint(block, x, self.freqs_cis, pos_offset, None, use_reentrant=False)
186
+ else:
187
+ x, kv = block(x, self.freqs_cis, pos_offset, past_key_values[i] if past_key_values else None)
188
+ if not self.training or past_key_values: new_kvs.append(kv)
189
+ logits = self.lm_head(self.ln_f(x))
190
+ loss = None
191
+ if labels is not None:
192
+ loss = F.cross_entropy(logits[:, :-1, :].reshape(-1, self.config.vocab_size), labels[:, 1:].reshape(-1))
193
+ return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=new_kvs if new_kvs else None)