import torch
import torch.nn as nn
import torch.nn.functional as F
from types import SimpleNamespace

# Model hyperparameters for the 236B configuration.
VOCAB_SIZE = 128256        # Llama-3-sized vocabulary
MODEL_DIM = 12288
NUM_HEADS = 96
NUM_KV_HEADS = 8           # grouped-query attention: 96 query heads share 8 KV heads
NUM_LAYERS = 160
MAX_SEQ_LEN = 2048
FFN_HIDDEN_DIM = 32768
HEAD_DIM = MODEL_DIM // NUM_HEADS  # 128
EPSILON = 1e-5
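
# Sanity checks on the derived dimensions; a small addition, assuming the usual
# GQA divisibility constraints rather than anything stated by the original file.
assert MODEL_DIM % NUM_HEADS == 0, "MODEL_DIM must be divisible by NUM_HEADS"
assert NUM_HEADS % NUM_KV_HEADS == 0, "query heads must be a multiple of KV heads"
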
class JiRackTernaryLinear(nn.Module):
    """
    CLAIM 1: Ternary-Quantized Optimization.
    Weights are restricted to {-1, 0, +1} at forward time, with a learnable
    scalar gamma restoring the output scale. Latent full-precision weights are
    kept for training and updated through a straight-through estimator (STE).
    """
    def __init__(self, in_features, out_features, bias=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.gamma = nn.Parameter(torch.ones(1))
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None

    def forward(self, x):
        # Center the latent weights so the sign split is roughly balanced.
        w_centered = self.weight - self.weight.mean()
        # torch.sign maps to {-1, 0, +1}; exact zeros stay zero.
        w_quant = torch.sign(w_centered)
        # STE: the forward pass uses the ternary weights, while gradients flow
        # straight through to the latent full-precision weights.
        w_ternary = (w_quant - self.weight).detach() + self.weight
        return F.linear(x, w_ternary, self.bias) * self.gamma
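

# A minimal sketch (an addition, not from the original file): sanity-checks the
# ternary layer on toy dimensions; the shapes here are illustrative only.
def _demo_ternary_linear():
    layer = JiRackTernaryLinear(8, 4)
    # Reconstruct the effective forward-time weights and confirm they are ternary.
    w_ternary = torch.sign(layer.weight - layer.weight.mean())
    assert set(w_ternary.unique().tolist()) <= {-1.0, 0.0, 1.0}
    y = layer(torch.randn(2, 8))
    return y.shape  # torch.Size([2, 4])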
class RMSNorm(nn.Module):
    """Stable normalization for ultra-deep networks (100+ layers)."""
    def __init__(self, dim, eps=EPSILON):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
def precompute_freqs_cis(dim, seq_len, theta=500000.0):
    """Precompute complex rotary embeddings (RoPE) with a Llama-3-style base."""
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(seq_len)
    freqs = torch.outer(t, freqs).float()
    return torch.polar(torch.ones_like(freqs), freqs)  # unit-magnitude e^{i*t*freq}
def apply_rotary_emb(xq, xk, freqs_cis):
    """Apply rotary position embeddings to query and key tensors."""
    # View the last dimension as complex pairs: (..., head_dim) -> (..., head_dim/2).
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    # Broadcast over batch and heads: (seq, head_dim/2) -> (1, seq, 1, head_dim/2).
    freqs_cis = freqs_cis.view(1, xq_.size(1), 1, xq_.size(3))
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
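

# A minimal sketch (an addition, not from the original file): checks the RoPE
# helpers on toy shapes; dimensions are illustrative, not the model's own.
def _demo_rotary():
    b, t, n_heads, n_kv_heads, head_dim = 2, 16, 4, 2, 8
    freqs = precompute_freqs_cis(head_dim, t)
    q = torch.randn(b, t, n_heads, head_dim)
    k = torch.randn(b, t, n_kv_heads, head_dim)
    q_rot, k_rot = apply_rotary_emb(q, k, freqs)
    assert q_rot.shape == q.shape and k_rot.shape == k.shape
    return q_rot.shape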
class SWA_Fusion_Block(nn.Module):
    """
    CLAIM 3: SwiGLU-Attention (SWA) Fusion.
    A unified compute block in which attention and the SwiGLU FFN both read the
    same normalized input and their outputs are summed (a parallel residual
    layout), intended to improve HBM throughput and reduce thermal throttling.
    """
    def __init__(self):
        super().__init__()
        self.n_rep = NUM_HEADS // NUM_KV_HEADS

        # Grouped-query attention projections.
        self.wq = JiRackTernaryLinear(MODEL_DIM, NUM_HEADS * HEAD_DIM)
        self.wk = JiRackTernaryLinear(MODEL_DIM, NUM_KV_HEADS * HEAD_DIM)
        self.wv = JiRackTernaryLinear(MODEL_DIM, NUM_KV_HEADS * HEAD_DIM)
        self.wo = JiRackTernaryLinear(NUM_HEADS * HEAD_DIM, MODEL_DIM)

        # SwiGLU feed-forward projections.
        self.w1 = JiRackTernaryLinear(MODEL_DIM, FFN_HIDDEN_DIM)
        self.w2 = JiRackTernaryLinear(FFN_HIDDEN_DIM, MODEL_DIM)
        self.w3 = JiRackTernaryLinear(MODEL_DIM, FFN_HIDDEN_DIM)

    def forward(self, x, freqs_cis):
        b, t, _ = x.shape

        # QKV projections, then rotary position embedding on q and k.
        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        q, k = apply_rotary_emb(q.view(b, t, NUM_HEADS, HEAD_DIM),
                                k.view(b, t, NUM_KV_HEADS, HEAD_DIM),
                                freqs_cis[:t])
        v = v.view(b, t, NUM_KV_HEADS, HEAD_DIM)

        # GQA: repeat each KV head n_rep times to match the query head count.
        k = k[:, :, :, None, :].expand(b, t, NUM_KV_HEADS, self.n_rep, HEAD_DIM).reshape(b, t, NUM_HEADS, HEAD_DIM)
        v = v[:, :, :, None, :].expand(b, t, NUM_KV_HEADS, self.n_rep, HEAD_DIM).reshape(b, t, NUM_HEADS, HEAD_DIM)

        attn_out = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True)
        attn_out = self.wo(attn_out.transpose(1, 2).contiguous().view(b, t, MODEL_DIM))

        # Parallel SwiGLU branch on the same (pre-normalized) input.
        ffn_out = self.w2(F.silu(self.w1(x)) * self.w3(x))

        return attn_out + ffn_out
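

# A hedged back-of-envelope (an addition, not from the original file): counts the
# ternary weights in one fused block under the constants above. Illustrative only.
def _estimate_block_params():
    attn = 2 * MODEL_DIM * NUM_HEADS * HEAD_DIM + 2 * MODEL_DIM * NUM_KV_HEADS * HEAD_DIM
    ffn = 3 * MODEL_DIM * FFN_HIDDEN_DIM
    # ~1.54e9 weights per block; ~2.46e11 across all 160 layers.
    return attn + ffn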
class JiRackTernary236B(nn.Module):
    """
    Main Engine: JiRack 236B (Ternary Extreme Edition)
    Inventor/Architect: Konstantin Vladimirovich Grabko
    """
    def __init__(self, config=None):
        super().__init__()
        num_layers = config.num_hidden_layers if config is not None else NUM_LAYERS

        self.token_emb = nn.Embedding(VOCAB_SIZE, MODEL_DIM)

        # Pre-norm layout: a single shared RMSNorm feeds the fused parallel block.
        self.layers = nn.ModuleList([
            nn.ModuleDict({
                'norm1': RMSNorm(MODEL_DIM),
                'swa': SWA_Fusion_Block()
            }) for _ in range(num_layers)
        ])

        self.norm_f = RMSNorm(MODEL_DIM)
        self.head = JiRackTernaryLinear(MODEL_DIM, VOCAB_SIZE)

        self.register_buffer("freqs_cis", precompute_freqs_cis(HEAD_DIM, MAX_SEQ_LEN))

        # Authorship watermark stored as a persistent uint8 buffer.
        signature = "AUTHOR: KONSTANTIN VLADIMIROVICH GRABKO | CMS MANHATTAN 2025"
        self.register_buffer("proof", torch.tensor([ord(c) for c in signature], dtype=torch.uint8))

    def forward(self, idx, targets=None):
        x = self.token_emb(idx)

        for layer in self.layers:
            # One residual hop per layer: the fused block returns attn + ffn.
            x = x + layer['swa'](layer['norm1'](x), self.freqs_cis)

        x = self.norm_f(x)
        logits = self.head(x)

        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, VOCAB_SIZE), targets.view(-1))
            return SimpleNamespace(logits=logits, loss=loss)
        return logits

    def get_author_info(self):
        """Decode the authorship signature stored in the `proof` buffer."""
        return "".join(chr(c) for c in self.proof.tolist())


class JiRackTernaryConfig:
    """Minimal config; `num_hidden_layers` overrides NUM_LAYERS at construction time."""
    def __init__(self, num_hidden_layers=NUM_LAYERS):
        self.num_hidden_layers = num_hidden_layers
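

# Hedged smoke test (an addition, not part of the original module): exercises only
# the small self-contained helpers above. Instantiating JiRackTernary236B itself
# would allocate hundreds of gigabytes of parameters, so it is deliberately skipped.
if __name__ == "__main__":
    print("ternary linear output shape:", _demo_ternary_linear())
    print("rotary embedding output shape:", _demo_rotary())
    print("weights per fused block:", _estimate_block_params())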