"""JiRack Ternary 7B: a causal language model built on the Hugging Face
`PreTrainedModel` API, using ternary-quantized `BitLinear` projections,
rotary position embeddings, and sliding-window attention."""

import os
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint

from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast


class JiRack7BConfig(PretrainedConfig):
    model_type = "jirack_ternary_7b_full"

    def __init__(
        self,
        vocab_size=128256,
        hidden_size=4096,
        num_hidden_layers=32,
        num_attention_heads=32,
        intermediate_size=11008,
        max_position_embeddings=8192,
        rope_theta=10000.0,
        rope_scaling={"type": "dynamic", "factor": 2.0},
        rms_norm_eps=1e-5,
        dropout_rate=0.0,
        window_size=512,
        author="Author: Konstantin Vladimirovich Grabko (CMS Manhattan) 2025",
        **kwargs
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.rms_norm_eps = rms_norm_eps
        self.dropout_rate = dropout_rate
        self.window_size = window_size
        self.author = author


class BitLinear(nn.Linear):
    """Linear layer whose weights are ternarized ({-1, 0, +1} times a scale)
    on the fly, with gradients passed through a straight-through estimator."""

    def __init__(self, in_features, out_features, bias=False, num_layers=32):
        super().__init__(in_features, out_features, bias)
        # Depth-scaled init keeps the residual-stream variance roughly constant.
        std = 0.02 / math.sqrt(2 * num_layers)
        nn.init.normal_(self.weight, mean=0.0, std=std)

    def forward(self, x):
        w = self.weight
        # Per-tensor scale: mean absolute weight.
        gamma = w.abs().mean() + 1e-9
        # Ternarize; the .detach() trick makes the forward pass use the
        # quantized weights while gradients flow to the full-precision ones.
        w_quant = torch.clamp(torch.round(w / gamma), -1, 1)
        w_final = w + (w_quant * gamma - w).detach()
        # Center and clamp activations, again via a straight-through estimator.
        x_norm = x - x.mean(dim=-1, keepdim=True)
        x_quant = x_norm + (torch.clamp(x_norm, -1.5, 1.5) - x_norm).detach()
        return F.linear(x_quant, w_final, self.bias)


class PhaserizationLayer(nn.Module):
    """Rewrites each feature as ||x|| * cos(phase), where the phase is derived
    from the feature and its neighbouring channel plus a learned shift."""

    def __init__(self, dim):
        super().__init__()
        self.phase_shift = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        magnitude = torch.norm(x, dim=-1, keepdim=True)
        # Phase between each channel and its rolled neighbour along the last dim.
        phase = torch.atan2(x, x.roll(1, -1) + 1e-6) + self.phase_shift
        return magnitude * torch.cos(phase)


class SignatureLayer(nn.Module):
    """Gated residual mix-in whose projection is seeded deterministically
    from the author string."""

    def __init__(self, dim, author_name):
        super().__init__()
        self.gate = nn.Parameter(torch.ones(dim))
        # Seed a local generator so the signature is reproducible without
        # resetting the global RNG (which would otherwise make the init of
        # every module constructed after this layer deterministic as well).
        seed = sum(ord(c) for c in author_name)
        generator = torch.Generator().manual_seed(seed)
        self.signage = nn.Parameter(torch.randn(dim, dim, generator=generator) * 0.01)

    def forward(self, x):
        sig = torch.tanh(F.linear(x, self.signage))
        return x * torch.sigmoid(self.gate) + sig


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight


def apply_rope_scaling(xq, xk, freqs_cis):
    """Applies rotary position embeddings to query/key tensors of shape
    (batch, heads, seq, head_dim), given precomputed complex rotations."""
    # View the last dimension as complex pairs, rotate, then flatten back.
    xq_f = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_f = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = freqs_cis[None, None, :xq_f.shape[2], :]
    xq_out = torch.view_as_real(xq_f * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_f * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


class JiRackAttention7B(nn.Module):
    def __init__(self, config: JiRack7BConfig):
        super().__init__()
        self.n_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.q_proj = BitLinear(config.hidden_size, config.hidden_size)
        self.k_proj = BitLinear(config.hidden_size, config.hidden_size)
        self.v_proj = BitLinear(config.hidden_size, config.hidden_size)
        self.out_proj = BitLinear(config.hidden_size, config.hidden_size)
        self.phaser = PhaserizationLayer(config.hidden_size)
        self.window_size = config.window_size

    def forward(self, x, freqs_cis, pos_offset, past_kv=None):
        B, T, D = x.shape
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        q, k = apply_rope_scaling(q, k, freqs_cis[pos_offset : pos_offset + T])

        # Append the cached keys/values and keep only the most recent window.
        if past_kv is not None:
            pk, pv = past_kv
            k = torch.cat([pk, k], dim=2)[:, :, -self.window_size:]
            v = torch.cat([pv, v], dim=2)[:, :, -self.window_size:]

        new_kv = (k.detach(), v.detach())

        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        # Causal mask over the (possibly longer) key axis; the diagonal offset
        # accounts for keys coming from the cache.
        mask = torch.triu(torch.full((T, k.size(2)), float('-inf'), device=x.device), diagonal=k.size(2) - T + 1)

        # Sliding-window mask: also hide keys more than window_size positions
        # behind each query (only needed when the query block exceeds the window).
        if T > self.window_size:
            mask = mask.to(torch.float32)
            for row in range(T):
                mask[row, :max(0, k.size(2) - T + row - self.window_size)] = float('-inf')

        attn_weights = F.softmax((attn_weights + mask.unsqueeze(0).unsqueeze(0)).float(), dim=-1).type_as(x)
        out = torch.matmul(attn_weights, v).transpose(1, 2).reshape(B, T, D)
        return self.phaser(self.out_proj(out)), new_kv


class SwiGLU7B(nn.Module):
    def __init__(self, config: JiRack7BConfig):
        super().__init__()
        self.w1 = BitLinear(config.hidden_size, config.intermediate_size)
        self.w3 = BitLinear(config.hidden_size, config.intermediate_size)
        self.w2 = BitLinear(config.intermediate_size, config.hidden_size)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class TransformerBlock7B(nn.Module):
    def __init__(self, config: JiRack7BConfig):
        super().__init__()
        self.attn = JiRackAttention7B(config)
        self.ffn = SwiGLU7B(config)
        self.norm1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.norm2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.signature = SignatureLayer(config.hidden_size, author_name=config.author)

    def forward(self, x, freqs_cis, pos_offset, past_kv=None):
        # Pre-norm residual attention, then a pre-norm FFN whose residual sum
        # is passed through the signature layer.
        h, new_kv = self.attn(self.norm1(x), freqs_cis, pos_offset, past_kv)
        x = x + h
        x = self.signature(x + self.ffn(self.norm2(x)))
        return x, new_kv


class JiRackTernary7B(PreTrainedModel):
    config_class = JiRack7BConfig

    def __init__(self, config: JiRack7BConfig):
        super().__init__(config)
        self.token_emb = nn.Embedding(config.vocab_size, config.hidden_size)
        self.blocks = nn.ModuleList([TransformerBlock7B(config) for _ in range(config.num_hidden_layers)])
        self.ln_f = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.register_buffer("freqs_cis", self._precompute_freqs(config), persistent=False)
        self.register_buffer("proof_of_authorship", torch.tensor([ord(c) for c in config.author], dtype=torch.uint8))

        self.post_init()
        # Tie the output projection to the input embedding.
        self.lm_head.weight = self.token_emb.weight

    def _precompute_freqs(self, config):
        # Standard RoPE frequencies as complex rotations; the optional scaling
        # factor simply stretches the base period.
        dim = config.hidden_size // config.num_attention_heads
        theta = config.rope_theta
        if config.rope_scaling:
            theta *= config.rope_scaling.get("factor", 1.0)
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(config.max_position_embeddings).float()
        freqs = torch.outer(t, freqs)
        return torch.polar(torch.ones_like(freqs), freqs)

    def get_author_info(self):
        return "".join([chr(c) for c in self.proof_of_authorship.tolist()])

    def forward(self, input_ids, labels=None, past_key_values=None, **kwargs):
        x = self.token_emb(input_ids)
        # Offset positions by the length of the cached keys (if any).
        pos_offset = past_key_values[0][0].size(2) if past_key_values else 0
        new_kvs = []
        for i, block in enumerate(self.blocks):
            x, kv = block(x, self.freqs_cis, pos_offset, past_key_values[i] if past_key_values else None)
            new_kvs.append(kv)

        logits = self.lm_head(self.ln_f(x))
        loss = None
        if labels is not None:
            # Shift logits/labels so each position predicts the next token.
            loss = F.cross_entropy(logits[:, :-1, :].reshape(-1, self.config.vocab_size), labels[:, 1:].reshape(-1))
        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=new_kvs)
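

# --- Usage sketch -----------------------------------------------------------
# A minimal smoke test using the classes defined above. The tiny hyperparameters
# below are hypothetical, chosen only so the example runs quickly on CPU; they
# are not the 7B defaults in JiRack7BConfig.
if __name__ == "__main__":
    demo_config = JiRack7BConfig(
        vocab_size=256,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=128,
        max_position_embeddings=128,
        window_size=32,
    )
    model = JiRackTernary7B(demo_config).eval()

    input_ids = torch.randint(0, demo_config.vocab_size, (1, 16))
    with torch.no_grad():
        out = model(input_ids)  # prefill
        step = model(input_ids[:, -1:], past_key_values=out.past_key_values)  # one cached decode step

    print("prefill logits:", out.logits.shape)   # (1, 16, vocab_size)
    print("decode logits:", step.logits.shape)   # (1, 1, vocab_size)
    print("author:", model.get_author_info())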