# JiRackTernary_7b / JiRackTernaryPyTorch_7b.py
# kgrabko's picture
# Upload 8 files
# 2b24c37 verified
# ==============================================================================
# COPYRIGHT (C) 2025 KONSTANTIN VLADIMIROVICH GRABKO. ALL RIGHTS RESERVED.
# PATENT PENDING | CMS MANHATTAN JIRACK TECHNOLOGY
#
# This software is licensed under the Commercial License Agreement V.1.2.
# Any use, modification, or distribution of this code requires compliance with
# the terms found in the LICENSE.md file in the root directory.
#
# NO PATENTING RIGHTS: Users are strictly prohibited from filing patent claims
# based on the BRE or SWA architectures disclosed herein.
# Contact: grabko@cmsmanhattan.com | +1 (516) 777-0945
# ==============================================================================
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
import math
import torch.utils.checkpoint
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
class JiRack7BConfig(PretrainedConfig):
    """Hyper-parameter container for the JiRack ternary model.

    Every constructor argument is stored unchanged as an attribute of the same
    name, except ``rope_scaling`` which is stored as a private copy (see
    below). Extra keyword arguments are forwarded to ``PretrainedConfig``.

    Args:
        vocab_size: Number of rows of the token embedding table.
        hidden_size: Width of the residual stream.
        num_hidden_layers: Number of transformer blocks.
        num_attention_heads: Number of attention heads.
        intermediate_size: Inner width of the feed-forward network.
        max_position_embeddings: Number of positions for which rotary
            frequencies are precomputed.
        rope_theta: Base of the rotary frequency schedule.
        rope_scaling: Optional mapping; its ``"factor"`` entry multiplies
            ``rope_theta`` when the frequencies are precomputed. ``None`` (or
            an empty mapping) disables scaling.
        rms_norm_eps: Epsilon used by the RMSNorm layers.
        dropout_rate: Stored on the config; not read by the layers in this file.
        window_size: Sliding-window length used by the attention layers.
        author: Authorship string embedded into the model.
    """

    model_type = "jirack_ternary_7b_full"

    def __init__(
        self,
        vocab_size=128256,
        hidden_size=4096,
        num_hidden_layers=32,
        num_attention_heads=32,
        intermediate_size=11008,
        max_position_embeddings=8192,  # Increased for RoPE scaling
        rope_theta=10000.0,
        rope_scaling={"type": "dynamic", "factor": 2.0},
        rms_norm_eps=1e-5,
        dropout_rate=0.0,
        window_size=512,  # SWA window
        author="Author: Konstantin Vladimirovich Grabko (CMS Manhattan) 2025",
        **kwargs
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta
        # The default above is a single dict object created once at function
        # definition time. Storing it by reference would let a mutation of one
        # config's ``rope_scaling`` leak into every other default-constructed
        # config, so each instance keeps its own copy. An explicit ``None``
        # is preserved because it means "no scaling".
        self.rope_scaling = dict(rope_scaling) if rope_scaling is not None else None
        self.rms_norm_eps = rms_norm_eps
        self.dropout_rate = dropout_rate
        self.window_size = window_size
        self.author = author
class BitLinear(nn.Linear):
    """Linear layer that ternarises its weights on every forward pass.

    The full-precision weight is kept as the trainable parameter. In the
    forward pass it is replaced by ``round(w / scale)`` clamped to
    ``{-1, 0, 1}`` and multiplied back by ``scale`` (the mean absolute weight).
    Activations are mean-centred over the last dimension and clipped to
    ``[-1.5, 1.5]``. Both substitutions are written as
    ``value + (substitute - value).detach()`` so the backward pass sees the
    identity (straight-through estimator).

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        bias: Whether to learn an additive bias.
        num_layers: Depth used to scale the initial weight standard deviation,
            ``0.02 / sqrt(2 * num_layers)``.
    """

    def __init__(self, in_features, out_features, bias=False, num_layers=32):
        super().__init__(in_features, out_features, bias)
        init_std = 0.02 / math.sqrt(2 * num_layers)
        nn.init.normal_(self.weight, mean=0.0, std=init_std)

    def forward(self, x):
        """Apply the layer with ternarised weights and clipped activations."""
        weight = self.weight

        # Per-tensor scale; the small constant avoids a division by zero.
        scale = weight.abs().mean() + 1e-9
        ternary = torch.round(weight / scale).clamp(-1, 1)
        weight_ste = weight + (ternary * scale - weight).detach()

        centred = x - x.mean(dim=-1, keepdim=True)
        clipped = centred.clamp(-1.5, 1.5)
        act_ste = centred + (clipped - centred).detach()

        return F.linear(act_ste, weight_ste, self.bias)
class PhaserizationLayer(nn.Module):
    """Re-express each feature as ``magnitude * cos(phase)``.

    The phase of a feature is ``atan2`` of the feature and its neighbour
    (the input rolled by one along the last dimension), offset by a learned
    per-feature ``phase_shift``. The magnitude is the norm of the input over
    the last dimension.

    Args:
        dim: Number of features; size of the learned phase offset.
    """

    def __init__(self, dim):
        super().__init__()
        self.phase_shift = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        """Return the phase-encoded tensor, same shape as ``x``."""
        # Small offset on the rolled neighbour, as in the reference formula.
        neighbour = torch.roll(x, shifts=1, dims=-1) + 1e-6
        angle = torch.atan2(x, neighbour) + self.phase_shift
        radius = torch.norm(x, dim=-1, keepdim=True)
        return radius * torch.cos(angle)
class SignatureLayer(nn.Module):
    """Gated residual layer whose initial weights are derived from a name.

    The ``signage`` matrix is initialised from a random stream seeded with the
    sum of the code points of ``author_name``, so the same name always yields
    the same initial matrix.

    The seed is applied to a private ``torch.Generator`` rather than through
    ``torch.manual_seed``. Reseeding the global generator here would override
    whatever seed the caller had chosen and would hand every module
    constructed afterwards the same random stream.

    Args:
        dim: Feature dimension of the input.
        author_name: String the initial ``signage`` values are derived from.
    """

    def __init__(self, dim, author_name):
        super().__init__()
        self.gate = nn.Parameter(torch.ones(dim))
        seed = sum(ord(c) for c in author_name)
        # Sample on the CPU with a local generator: deterministic per name and
        # without touching the process-wide RNG state.
        generator = torch.Generator(device="cpu")
        generator.manual_seed(seed)
        signage = torch.randn(dim, dim, generator=generator, device="cpu") * 0.01
        self.signage = nn.Parameter(signage)

    def forward(self, x):
        """Return ``x * sigmoid(gate) + tanh(x @ signage.T)``."""
        sig = torch.tanh(F.linear(x, self.signage))
        return x * torch.sigmoid(self.gate) + sig
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    Args:
        dim: Size of the last dimension; length of the learned scale vector.
        eps: Constant added to the mean square before the reciprocal root.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Return ``x`` divided by its RMS and multiplied by ``weight``."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.weight
def apply_rope_scaling(xq, xk, freqs_cis):
    """Rotate queries and keys by complex rotary factors.

    The last dimension of each input is read as consecutive (real, imaginary)
    pairs, multiplied by ``freqs_cis`` and flattened back to real values.

    Args:
        xq: Query tensor; its last dimension must be even.
        xk: Key tensor; its last dimension must be even.
        freqs_cis: Complex rotation factors indexed ``[position, pair]``. Only
            the first ``n`` rows are used, where ``n`` is the size of
            dimension 2 of the complex view of ``xq``.

    Returns:
        Tuple ``(xq_rotated, xk_rotated)`` cast back to the input dtypes.
    """

    def _as_complex(t):
        # Compute in float32 and pair up the trailing dimension.
        return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

    q_complex = _as_complex(xq)
    k_complex = _as_complex(xk)

    # Add two leading broadcast dimensions and trim to the needed positions.
    rotation = freqs_cis[None, None, : q_complex.shape[2], :]

    q_rotated = torch.view_as_real(q_complex * rotation).flatten(3)
    k_rotated = torch.view_as_real(k_complex * rotation).flatten(3)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)
class JiRackAttention7B(nn.Module):
    """Multi-head causal self-attention with a sliding window and a KV cache.

    Projections are ``BitLinear`` layers; the projected output is passed
    through a ``PhaserizationLayer`` before being returned.

    Args:
        config: Model configuration. Reads ``num_attention_heads``,
            ``hidden_size``, ``num_hidden_layers`` and ``window_size``.
    """

    def __init__(self, config: JiRack7BConfig):
        super().__init__()
        self.n_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        # BitLinear scales its initial std by the network depth; pass the
        # configured depth instead of relying on its hard-coded default of 32.
        depth = config.num_hidden_layers
        self.q_proj = BitLinear(config.hidden_size, config.hidden_size, num_layers=depth)
        self.k_proj = BitLinear(config.hidden_size, config.hidden_size, num_layers=depth)
        self.v_proj = BitLinear(config.hidden_size, config.hidden_size, num_layers=depth)
        self.out_proj = BitLinear(config.hidden_size, config.hidden_size, num_layers=depth)
        self.phaser = PhaserizationLayer(config.hidden_size)
        self.window_size = config.window_size  # Sliding-window length

    def forward(self, x, freqs_cis, pos_offset, past_kv=None):
        """Attend over the current tokens plus any cached keys/values.

        Args:
            x: Input of shape ``(B, T, D)``.
            freqs_cis: Table of complex rotary factors indexed by position.
            pos_offset: Row of ``freqs_cis`` used for the first token of ``x``.
            past_kv: Optional ``(keys, values)`` pair from a previous call,
                each laid out ``(B, n_heads, cached_len, head_dim)``.

        Returns:
            Tuple ``(output, new_kv)``. ``output`` has shape ``(B, T, D)``.
            ``new_kv`` holds detached keys and values; when ``past_kv`` was
            given they are trimmed to the last ``window_size`` positions.
        """
        B, T, D = x.shape
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        q, k = apply_rope_scaling(q, k, freqs_cis[pos_offset : pos_offset + T])

        if past_kv is not None:
            pk, pv = past_kv
            # Attend over the untrimmed concatenation and trim only what is
            # stored. Trimming first would drop keys that the earliest queries
            # of a multi-token chunk still need; for T > window_size it would
            # leave whole rows masked and make softmax return NaN.
            k = torch.cat([pk, k], dim=2)
            v = torch.cat([pv, v], dim=2)
            new_kv = (
                k[:, :, -self.window_size:].detach(),
                v[:, :, -self.window_size:].detach(),
            )
        else:
            new_kv = (k.detach(), v.detach())

        S = k.size(2)
        scores = torch.matmul(q, k.transpose(-2, -1)) * (self.head_dim ** -0.5)

        # The T queries are the last T of the S key positions. A query at
        # index p may see keys in (p - window_size, p]: causal, and exactly
        # window_size keys including itself. This matches what a trimmed cache
        # exposes during one-token decoding.
        q_idx = torch.arange(S - T, S, device=x.device).unsqueeze(1)
        k_idx = torch.arange(S, device=x.device).unsqueeze(0)
        visible = (k_idx <= q_idx) & (k_idx > q_idx - self.window_size)
        mask = torch.zeros(T, S, device=x.device).masked_fill(~visible, float("-inf"))

        # Softmax in float32, then cast back to the input dtype.
        probs = F.softmax((scores + mask).float(), dim=-1).type_as(x)
        out = torch.matmul(probs, v).transpose(1, 2).reshape(B, T, D)
        return self.phaser(self.out_proj(out)), new_kv
class SwiGLU7B(nn.Module):
    """Gated feed-forward network: ``w2(silu(w1(x)) * w3(x))``.

    Args:
        config: Model configuration. Reads ``hidden_size``,
            ``intermediate_size`` and ``num_hidden_layers``.
    """

    def __init__(self, config: JiRack7BConfig):
        super().__init__()
        # BitLinear scales its initial std by the network depth; pass the
        # configured depth instead of relying on its hard-coded default of 32.
        depth = config.num_hidden_layers
        self.w1 = BitLinear(config.hidden_size, config.intermediate_size, num_layers=depth)
        self.w3 = BitLinear(config.hidden_size, config.intermediate_size, num_layers=depth)
        self.w2 = BitLinear(config.intermediate_size, config.hidden_size, num_layers=depth)

    def forward(self, x):
        """Return the gated projection of ``x``, mapped back to ``hidden_size``."""
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
class TransformerBlock7B(nn.Module):
    """One pre-norm transformer block followed by a signature layer.

    Computes ``h = x + attn(norm1(x))`` and then
    ``signature(h + ffn(norm2(h)))``.

    Args:
        config: Model configuration. Reads ``hidden_size``, ``rms_norm_eps``
            and ``author`` here; the attention and feed-forward sub-modules
            read further fields.
    """

    def __init__(self, config: JiRack7BConfig):
        super().__init__()
        width = config.hidden_size
        eps = config.rms_norm_eps
        # Construction order is kept as-is: the sub-modules draw from the RNG
        # when they initialise their weights.
        self.attn = JiRackAttention7B(config)
        self.ffn = SwiGLU7B(config)
        self.norm1 = RMSNorm(width, eps=eps)
        self.norm2 = RMSNorm(width, eps=eps)
        self.signature = SignatureLayer(width, author_name=config.author)

    def forward(self, x, freqs_cis, pos_offset, past_kv=None):
        """Run the block.

        Args:
            x: Hidden states entering the block.
            freqs_cis: Rotary factor table, forwarded to the attention layer.
            pos_offset: Position offset, forwarded to the attention layer.
            past_kv: Optional cached keys/values for this block.

        Returns:
            Tuple ``(hidden_states, new_kv)``.
        """
        attn_out, new_kv = self.attn(self.norm1(x), freqs_cis, pos_offset, past_kv)
        residual = x + attn_out
        ffn_out = self.ffn(self.norm2(residual))
        return self.signature(residual + ffn_out), new_kv
class JiRackTernary7B(PreTrainedModel):
    """Causal language model built from ``TransformerBlock7B`` layers.

    The output projection shares its weight with the token embedding. Two
    buffers are registered: ``freqs_cis`` (rotary factors, not saved with the
    state dict) and ``proof_of_authorship`` (the bytes of ``config.author``).
    """

    config_class = JiRack7BConfig

    def __init__(self, config: JiRack7BConfig):
        super().__init__(config)
        self.token_emb = nn.Embedding(config.vocab_size, config.hidden_size)
        self.blocks = nn.ModuleList(
            [TransformerBlock7B(config) for _ in range(config.num_hidden_layers)]
        )
        self.ln_f = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # RoPE precomputation with the scaling factor applied.
        self.register_buffer("freqs_cis", self._precompute_freqs(config), persistent=False)
        # Store the author string as UTF-8 bytes. Storing raw code points in a
        # uint8 buffer cannot represent characters above 255; for ASCII text
        # the bytes are identical to the code points.
        author_bytes = list(config.author.encode("utf-8"))
        self.register_buffer(
            "proof_of_authorship", torch.tensor(author_bytes, dtype=torch.uint8)
        )
        self.post_init()
        # Tie the output projection to the embedding table.
        self.lm_head.weight = self.token_emb.weight

    def _precompute_freqs(self, config):
        """Build the table of complex rotary factors.

        Returns a complex tensor of shape
        ``(max_position_embeddings, head_dim // 2)`` with unit magnitude. When
        ``config.rope_scaling`` is set, its ``"factor"`` entry (default 1.0)
        multiplies the base ``rope_theta``.
        """
        dim = config.hidden_size // config.num_attention_heads
        theta = config.rope_theta
        if config.rope_scaling:
            theta *= config.rope_scaling.get("factor", 1.0)
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(config.max_position_embeddings).float()
        freqs = torch.outer(t, freqs)
        return torch.polar(torch.ones_like(freqs), freqs)

    def get_author_info(self):
        """Return the author string stored in ``proof_of_authorship``."""
        raw = bytes(self.proof_of_authorship.tolist())
        try:
            return raw.decode("utf-8")
        except UnicodeDecodeError:
            # Buffers written as one code point per byte decode as latin-1,
            # which maps every byte to the character with that code point.
            return raw.decode("latin-1")

    def forward(self, input_ids, labels=None, past_key_values=None, **kwargs):
        """Run the model and optionally compute the next-token loss.

        Args:
            input_ids: Token ids; embedded to a ``(B, T, hidden)`` tensor.
            labels: Optional target ids aligned with ``input_ids``. The loss
                compares the logits at position ``i`` with ``labels[:, i + 1]``.
            past_key_values: Optional per-block ``(keys, values)`` pairs
                returned by a previous call.
            **kwargs: Accepted and ignored.

        Returns:
            ``CausalLMOutputWithPast`` with ``loss`` (``None`` without labels),
            ``logits`` and the new per-block cache as ``past_key_values``.

        Raises:
            ValueError: If the requested positions exceed the precomputed
                rotary table.
        """
        x = self.token_emb(input_ids)
        # NOTE(review): the offset is the cached key length. The attention
        # layers trim their cache to ``window_size``, so once the cache is
        # full this value stops growing and no longer tracks the absolute
        # position of the new tokens - confirm whether that is intended.
        pos_offset = past_key_values[0][0].size(2) if past_key_values else 0

        # Fail early: slicing the table past its end returns a shorter slice,
        # which either breaks broadcasting further down or (for a one-row
        # slice) silently applies the same rotation to every token.
        seq_len = x.size(1)
        table_len = self.freqs_cis.size(0)
        if pos_offset + seq_len > table_len:
            raise ValueError(
                f"Positions {pos_offset}..{pos_offset + seq_len - 1} exceed the "
                f"{table_len} precomputed rotary positions "
                "(config.max_position_embeddings)."
            )

        new_kvs = []
        for i, block in enumerate(self.blocks):
            layer_past = past_key_values[i] if past_key_values else None
            x, kv = block(x, self.freqs_cis, pos_offset, layer_past)
            new_kvs.append(kv)

        logits = self.lm_head(self.ln_f(x))

        loss = None
        if labels is not None:
            # Shift by one so each position predicts the following token.
            loss = F.cross_entropy(
                logits[:, :-1, :].reshape(-1, self.config.vocab_size),
                labels[:, 1:].reshape(-1),
            )
        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=new_kvs)