# JiRack_GPT5_3b / JiRackPyTorch_GPT5_class_3b.py
# kgrabko's picture
# Upload 2 files
# 0f01cb6 verified
# Copyright (c) 2025 CMS Manhattan
# All rights reserved.
# Author: Konstantin Vladimirovich Grabko
# Email: grabko@cmsmanhattan.com
# Phone: +1(516)777-0945
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, version 3 of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
# Additional terms:
# Any commercial use or distribution of this software or derivative works
# requires explicit written permission from the copyright holder.
"""
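JiRackPyTorch 3B Model Definition - FINAL AUTHORIZED VERSION
Complete with sliding-window attention (SWA), grouped-query attention,
rotary position embeddings (RoPE), SwiGLU feed-forward layers, and an
embedded authorship signature (see JiRackPyTorch.get_author_info).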
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List
import math
# ==========================================
# JiRack 3B configuration
# ==========================================
# Vocabulary size and model (residual stream) width.
VOCAB_SIZE: int = 50_304
MODEL_DIM: int = 3_072
# Attention layout: NUM_HEADS query heads share NUM_KV_HEADS key/value heads
# (grouped-query attention).
NUM_HEADS: int = 24
NUM_KV_HEADS: int = 8
NUM_LAYERS: int = 32
# Sequence length used to size the RoPE table (the model precomputes 2x this).
MAX_SEQ_LEN: int = 2_048
# Hidden width of the SwiGLU feed-forward layers.
FFN_HIDDEN_DIM: int = 8_192
# Per-head width; MODEL_DIM is expected to be divisible by NUM_HEADS.
HEAD_DIM: int = MODEL_DIM // NUM_HEADS
# Sliding-window attention span, measured in key/value positions.
WINDOW_SIZE: int = 512
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization with a learned per-channel gain.

    Args:
        dim: Size of the last (normalized) dimension of the input.
        eps: Small constant added to the mean square for numerical stability.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # Rescale so the mean of squares over the last dimension is ~1.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` over its last dimension and apply the learned gain.

        The statistics are computed in float32 and the normalized tensor is
        cast back to ``x``'s dtype before the gain is applied. In float16,
        ``x.pow(2)`` overflows to ``inf`` once ``|x|`` exceeds ~256, which
        would make ``rsqrt`` return 0 and wipe out the activation.
        For float32 inputs the result is unchanged.
        """
        return self._norm(x.float()).type_as(x) * self.weight
def precompute_freqs_cis(dim: int, seq_len: int, theta: float = 10000.0):
    """Build the rotary-embedding (RoPE) phase table.

    Args:
        dim: Rotary width (the per-head dimension). One frequency is produced
            per pair of channels, i.e. ``dim // 2`` frequencies.
        seq_len: Number of positions (rows) in the table.
        theta: Base of the geometric frequency progression.

    Returns:
        A complex tensor of shape ``(seq_len, dim // 2)`` whose entry
        ``[p, i]`` equals ``exp(1j * p * theta ** (-2 * i / dim))``.
    """
    # Even channel indices 0, 2, 4, ... truncated to dim // 2 entries.
    even_idx = torch.arange(0, dim, 2)[: dim // 2].float()
    inv_freq = 1.0 / theta ** (even_idx / dim)
    positions = torch.arange(seq_len)
    # Outer product via broadcasting: angle[p, i] = p * inv_freq[i].
    angles = (positions[:, None] * inv_freq[None, :]).float()
    # Unit-magnitude complex numbers carrying those angles.
    return torch.polar(torch.ones_like(angles), angles)
def apply_rotary_emb(xq, xk, freqs_cis):
    """Rotate query and key channel pairs by the given RoPE phases.

    Args:
        xq: Queries laid out as ``(batch, seq, heads, head_dim)``.
        xk: Keys laid out as ``(batch, seq, kv_heads, head_dim)``. The head
            count may differ from ``xq`` (it is broadcast over), but ``seq``
            must match.
        freqs_cis: Complex phases holding ``seq * head_dim // 2`` entries,
            e.g. a ``(seq, head_dim // 2)`` slice of the RoPE table.

    Returns:
        ``(xq_rotated, xk_rotated)`` with the same shapes and dtypes as the
        inputs. The rotation itself is computed in float32.
    """

    def _as_complex(t):
        # Treat adjacent channels (2k, 2k + 1) as the real/imaginary parts.
        return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

    q_complex = _as_complex(xq)
    k_complex = _as_complex(xk)
    # Reshape the phase table so it broadcasts over batch and head axes.
    phases = freqs_cis.view(1, q_complex.size(1), 1, q_complex.size(3))
    q_rotated = torch.view_as_real(q_complex * phases).flatten(3)
    k_rotated = torch.view_as_real(k_complex * phases).flatten(3)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)
class JiRackAttention(nn.Module):
    """Grouped-query self-attention with RoPE and a sliding attention window.

    ``NUM_HEADS`` query heads share ``NUM_KV_HEADS`` key/value heads. Each
    query attends causally to at most ``WINDOW_SIZE`` keys (itself included).
    This holds both within a single forward pass and across calls chained
    through ``past_kv``.

    KV-cache format: a ``(keys, values)`` tuple, each tensor shaped
    ``(batch, cached_len, NUM_KV_HEADS, HEAD_DIM)`` with
    ``cached_len <= WINDOW_SIZE``. Keys are cached *before* the rotary
    embedding is applied. On every call the whole window is rotated with
    positions relative to the window start. RoPE attention scores depend only
    on the query/key position difference, so relative positions stay exact
    even after the window has slid past the start of the sequence.
    """

    def __init__(self):
        super().__init__()
        self.wq = nn.Linear(MODEL_DIM, NUM_HEADS * HEAD_DIM, bias=False)
        self.wk = nn.Linear(MODEL_DIM, NUM_KV_HEADS * HEAD_DIM, bias=False)
        self.wv = nn.Linear(MODEL_DIM, NUM_KV_HEADS * HEAD_DIM, bias=False)
        self.wo = nn.Linear(NUM_HEADS * HEAD_DIM, MODEL_DIM, bias=False)

    @staticmethod
    def _rotate(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        """Apply RoPE phases ``freqs_cis`` (``(seq, HEAD_DIM // 2)``) to ``x`` (``(b, seq, h, HEAD_DIM)``).

        Queries and keys can have different lengths here (cached prefix), so
        they are rotated separately rather than through a shared table.
        """
        x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        phases = freqs_cis.view(1, x_c.size(1), 1, x_c.size(3))
        return torch.view_as_real(x_c * phases).flatten(3).type_as(x)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Attend over the cached window plus the new tokens in ``x``.

        Args:
            x: Input activations ``(batch, seq, MODEL_DIM)``.
            freqs_cis: RoPE table from ``precompute_freqs_cis``. It needs at
                least ``cached_len + seq`` rows.
            past_kv: Optional cache returned by a previous call.

        Returns:
            ``(output, new_kv)``. ``output`` is ``(batch, seq, MODEL_DIM)``.
            ``new_kv`` is the updated cache holding the last ``WINDOW_SIZE``
            un-rotated keys and values.

        Raises:
            ValueError: If ``freqs_cis`` has too few positions.
        """
        b, l, _ = x.shape
        xq = self.wq(x).view(b, l, NUM_HEADS, HEAD_DIM)
        xk = self.wk(x).view(b, l, NUM_KV_HEADS, HEAD_DIM)
        xv = self.wv(x).view(b, l, NUM_KV_HEADS, HEAD_DIM)

        # Prepend the cached (un-rotated) keys/values.
        if past_kv is not None:
            pk, pv = past_kv
            xk = torch.cat([pk, xk], dim=1)
            xv = torch.cat([pv, xv], dim=1)
        kv_len = xk.size(1)
        past_len = kv_len - l
        if kv_len > freqs_cis.size(0):
            raise ValueError(
                f"RoPE table has {freqs_cis.size(0)} positions but {kv_len} are "
                f"needed ({past_len} cached + {l} new)"
            )

        # Only the trailing window is needed by future calls. Truncating here,
        # and not before attention, keeps every key the current queries need.
        new_kv = (xk[:, -WINDOW_SIZE:], xv[:, -WINDOW_SIZE:])

        # Window-relative positions: keys 0..kv_len-1, queries past_len..kv_len-1.
        xq = self._rotate(xq, freqs_cis[past_len:kv_len])
        xk = self._rotate(xk, freqs_cis[:kv_len])

        # Expand KV heads so each group of query heads sees its shared KV head.
        group = NUM_HEADS // NUM_KV_HEADS
        xq = xq.transpose(1, 2)
        xk = xk.repeat_interleave(group, dim=2).transpose(1, 2)
        xv = xv.repeat_interleave(group, dim=2).transpose(1, 2)

        if past_len == 0 and l <= WINDOW_SIZE:
            # Square causal case where the window cannot bind. is_causal lets
            # SDPA choose its fused kernels.
            out = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
        else:
            # SDPA's is_causal mask is aligned to the top-left corner, which is
            # wrong when keys outnumber queries (cached prefix). Build an
            # explicit bottom-right-aligned causal + sliding-window mask.
            q_pos = torch.arange(past_len, kv_len, device=x.device).unsqueeze(1)
            k_pos = torch.arange(kv_len, device=x.device).unsqueeze(0)
            dist = q_pos - k_pos
            allowed = (dist >= 0) & (dist < WINDOW_SIZE)  # True = may attend
            out = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=allowed)

        out = out.transpose(1, 2).contiguous().view(b, l, MODEL_DIM)
        return self.wo(out), new_kv
class SwiGLU(nn.Module):
    """Gated feed-forward block computing ``w2(silu(w1(x)) * w3(x))``.

    Two parallel bias-free projections (MODEL_DIM -> FFN_HIDDEN_DIM) produce a
    SiLU-activated gate and a value branch. Their elementwise product is
    projected back to MODEL_DIM.
    """

    def __init__(self):
        super().__init__()
        # Creation order (w1, w3, w2) is kept: it fixes parameter
        # initialization order and state_dict key order.
        self.w1 = nn.Linear(MODEL_DIM, FFN_HIDDEN_DIM, bias=False)
        self.w3 = nn.Linear(MODEL_DIM, FFN_HIDDEN_DIM, bias=False)
        self.w2 = nn.Linear(FFN_HIDDEN_DIM, MODEL_DIM, bias=False)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)
class JiRackPyTorch(nn.Module):
    """JiRack 3B decoder-only language model.

    The pipeline is: token embedding -> NUM_LAYERS pre-norm blocks -> final
    RMSNorm -> vocabulary projection. Each block applies RMSNorm -> attention
    and then RMSNorm -> SwiGLU, each with a residual add. The vocabulary
    projection shares its weight with the token embedding.
    """

    def __init__(self):
        super().__init__()
        self.token_emb = nn.Embedding(VOCAB_SIZE, MODEL_DIM)
        self.blocks = nn.ModuleList(
            [
                nn.ModuleDict(
                    {
                        'norm1': RMSNorm(MODEL_DIM),
                        'attn': JiRackAttention(),
                        'norm2': RMSNorm(MODEL_DIM),
                        'ffn': SwiGLU(),
                    }
                )
                for _ in range(NUM_LAYERS)
            ]
        )
        self.norm_f = RMSNorm(MODEL_DIM)
        self.head = nn.Linear(MODEL_DIM, VOCAB_SIZE, bias=False)
        # Weight tying: the output projection reuses the embedding matrix.
        self.head.weight = self.token_emb.weight
        # RoPE phase table covering twice MAX_SEQ_LEN positions.
        self.register_buffer("freqs_cis", precompute_freqs_cis(HEAD_DIM, MAX_SEQ_LEN * 2))
        # Author's digital signature, stored as a uint8 buffer of character
        # codes so it travels with the model's buffers.
        signature = "Author: Konstantin Vladimirovich Grabko (CMS Manhattan) 2025"
        self.register_buffer("proof_of_authorship", torch.tensor([ord(c) for c in signature], dtype=torch.uint8))

    def get_author_info(self):
        """Decode the stored authorship signature back into a string."""
        return "".join([chr(c) for c in self.proof_of_authorship.tolist()])

    def forward(self, idx, targets=None, past_kv=None):
        """Run the model on a batch of token ids.

        Args:
            idx: Integer token ids ``(batch, seq)``.
            targets: Optional target ids with the same number of elements as
                ``idx``. Entries equal to -100 are ignored (the
                ``F.cross_entropy`` default). Non-contiguous tensors such as
                ``ids[:, 1:]`` are accepted.
            past_kv: Optional per-layer caches, one entry per block, as
                returned in ``new_kvs`` by a previous call.

        Returns:
            ``(logits, loss, new_kvs)``. ``logits`` is
            ``(batch, seq, VOCAB_SIZE)``. ``loss`` is the mean cross-entropy,
            or ``None`` without targets. ``new_kvs`` is the list of per-layer
            caches.
        """
        x = self.token_emb(idx)
        new_kvs = []
        for i, block in enumerate(self.blocks):
            layer_past = past_kv[i] if past_kv else None
            h, kv = block['attn'](block['norm1'](x), self.freqs_cis, layer_past)
            x = x + h
            x = x + block['ffn'](block['norm2'](x))
            new_kvs.append(kv)
        logits = self.head(self.norm_f(x))
        loss = None
        if targets is not None:
            # reshape rather than view: shifted targets (ids[:, 1:]) are
            # non-contiguous, and .view(-1) raises on them.
            loss = F.cross_entropy(logits.reshape(-1, VOCAB_SIZE), targets.reshape(-1))
        return logits, loss, new_kvs