# ==============================================================================
# COPYRIGHT (C) 2025 KONSTANTIN VLADIMIROVICH GRABKO. ALL RIGHTS RESERVED.
# PATENT PENDING | CMS MANHATTAN JIRACK TECHNOLOGY
#
# This software is licensed under the Commercial License Agreement V.1.2.
# Any use, modification, or distribution of this code requires compliance with
# the terms found in the LICENSE.md file in the root directory.
#
# NO PATENTING RIGHTS: Users are strictly prohibited from filing patent claims
# based on the BRE or SWA architectures disclosed herein.
# Contact: grabko@cmsmanhattan.com | +1 (516) 777-0945
# ==============================================================================

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, List, Tuple, Union
import math
import torch.utils.checkpoint
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast


class TernaryConfig(PretrainedConfig):
    """Hyper-parameters for :class:`TernaryTransformer`.

    Every constructor argument is stored unchanged as an attribute of the
    same name; remaining keyword arguments are forwarded to
    ``PretrainedConfig``.

    Args:
        vocab_size: Number of rows in the token embedding / output head.
        hidden_size: Width of the residual stream.
        num_hidden_layers: Number of transformer blocks.
        num_attention_heads: Number of attention heads; ``hidden_size`` is
            split evenly across them with integer division.
        intermediate_size: Inner width of the SwiGLU feed-forward.
        max_position_embeddings: Length of the precomputed rotary table.
        rms_norm_eps: Epsilon added inside every RMSNorm.
        dropout_rate: Probability of the dropout applied to each residual
            branch.
        window_size: Maximum number of key/value positions kept in the
            attention cache.
    """

    model_type = "ternary_transformer"

    def __init__(
        self,
        vocab_size=50257,
        hidden_size=3072,
        num_hidden_layers=24,
        num_attention_heads=32,
        intermediate_size=12288,
        max_position_embeddings=2048,
        rms_norm_eps=1e-6,
        dropout_rate=0.1,
        window_size=512,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.max_position_embeddings = max_position_embeddings
        self.rms_norm_eps = rms_norm_eps
        self.dropout_rate = dropout_rate
        self.window_size = window_size


class BitLinear(nn.Linear):
    """Linear layer that ternarises its weights on the fly.

    The full-precision weight is kept as the trainable parameter. In the
    forward pass it is replaced by ``round(w / gamma)`` clamped to
    ``{-1, 0, 1}`` and rescaled by ``gamma`` (the mean absolute weight).
    The ``w + (q - w).detach()`` form makes the forward value equal to the
    quantised weight while gradients flow to ``w`` as if no quantisation
    had happened (straight-through estimator).
    """

    def __init__(self, in_features, out_features, bias=False, num_layers=24):
        super().__init__(in_features, out_features, bias)
        # Depth-scaled init: the standard deviation shrinks with
        # sqrt(2 * num_layers).
        std = 0.02 / math.sqrt(2 * num_layers)
        nn.init.normal_(self.weight, mean=0.0, std=std)

    def forward(self, x):
        """Apply the layer with ternary weights and clamped activations."""
        w = self.weight
        # Per-tensor scale; the small constant guards against division by
        # zero when all weights are zero.
        gamma = w.abs().mean() + 1e-9
        w_quant = torch.clamp(torch.round(w / gamma), -1, 1)
        w_final = w + (w_quant * gamma - w).detach()

        # Activations are mean-centred over the last dimension, then
        # clamped to [-1.5, 1.5] with the same pass-through gradient trick.
        x_norm = x - x.mean(dim=-1, keepdim=True)
        x_quant = x_norm + (torch.clamp(x_norm, -1.5, 1.5) - x_norm).detach()
        return F.linear(x_quant, w_final, self.bias)


class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a
    learnable per-feature gain (initialised to ones)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        norm = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return norm * self.weight


def precompute_freqs_cis(dim: int, seq_len: int, theta: float = 10000.0):
    """Build the rotary-embedding table.

    Args:
        dim: Per-head feature size; every second index contributes one
            frequency.
        seq_len: Number of positions to precompute.
        theta: Base of the geometric frequency progression.

    Returns:
        Complex tensor with one row per position and one column per
        frequency; each entry has magnitude 1 and angle
        ``position * frequency``.
    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(seq_len).float()
    freqs = torch.outer(t, freqs)
    return torch.polar(torch.ones_like(freqs), freqs)


def apply_rotary_emb(xq, xk, freqs_cis):
    """Rotate queries and keys by the given rotary table.

    ``xq`` and ``xk`` are 4-D with the sequence on dimension 2; their last
    dimension is viewed as consecutive (real, imag) pairs. ``freqs_cis`` is
    indexed by position on dimension 0 and is cut to the query's sequence
    length before broadcasting over the first two dimensions.

    Returns:
        ``(xq_rotated, xk_rotated)`` cast back to the input dtypes.
    """
    xq_f = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_f = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = freqs_cis[None, None, :xq_f.shape[2], :]
    xq_out = torch.view_as_real(xq_f * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_f * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with rotary embeddings, BitLinear
    projections and a key/value cache limited to ``config.window_size``."""

    def __init__(self, config: TernaryConfig):
        super().__init__()
        self.n_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.q_proj = BitLinear(config.hidden_size, config.hidden_size, num_layers=config.num_hidden_layers)
        self.k_proj = BitLinear(config.hidden_size, config.hidden_size, num_layers=config.num_hidden_layers)
        self.v_proj = BitLinear(config.hidden_size, config.hidden_size, num_layers=config.num_hidden_layers)
        self.out_proj = BitLinear(config.hidden_size, config.hidden_size, num_layers=config.num_hidden_layers)
        self.scale = self.head_dim ** -0.5
        self.window_size = config.window_size

    def forward(self, x, freqs_cis, pos_offset, past_kv=None):
        """Run attention over ``x``.

        Args:
            x: Input of shape ``(B, T, D)``.
            freqs_cis: Rotary table indexed by position on dimension 0.
            pos_offset: Index into ``freqs_cis`` of the first token of ``x``.
            past_kv: Optional ``(keys, values)`` from earlier calls, with
                the sequence on dimension 2.

        Returns:
            ``(output, new_kv)`` where ``output`` has shape ``(B, T, D)``
            and ``new_kv`` is the detached ``(keys, values)`` pair to feed
            back as ``past_kv``.

        Raises:
            ValueError: If ``freqs_cis`` does not cover positions
                ``pos_offset .. pos_offset + T - 1``.
        """
        B, T, D = x.shape
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        # A slice that runs past the table is silently shortened, which
        # would otherwise surface as a broadcast error (or, with a single
        # remaining row, as a silent broadcast of the wrong rotation).
        rope = freqs_cis[pos_offset : pos_offset + T]
        if rope.size(0) != T:
            raise ValueError(
                f"Rotary table has {freqs_cis.size(0)} positions but positions "
                f"{pos_offset}..{pos_offset + T - 1} were requested"
            )
        q, k = apply_rotary_emb(q, k, rope)

        if past_kv is not None:
            pk, pv = past_kv
            k = torch.cat([pk, k], dim=2)
            v = torch.cat([pv, v], dim=2)
        S = k.size(2)

        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        # Indices below are offsets inside the concatenated key sequence;
        # the T queries are its last T entries.
        q_idx = torch.arange(S - T, S, device=x.device).unsqueeze(1)
        k_idx = torch.arange(S, device=x.device).unsqueeze(0)
        blocked = k_idx > q_idx  # causal: no attention to later keys
        if past_kv is not None:
            # With a cache, each query sees at most the `window_size` most
            # recent keys (itself included). Masking instead of truncating
            # the keys before the matmul keeps every row non-empty, so a
            # multi-token chunk that overflows the window no longer yields
            # fully masked rows (NaN after softmax) or drops keys that the
            # early queries of the chunk are entitled to. For a single new
            # token the visible key set is the same as with truncation.
            blocked = blocked | (k_idx <= q_idx - self.window_size)
        # NOTE(review): without a cache the mask is purely causal, so the
        # window is not enforced on that path; confirm this is intended.
        mask = torch.zeros((T, S), device=x.device).masked_fill(blocked, float('-inf'))
        mask = mask.unsqueeze(0).unsqueeze(0)

        # Softmax is evaluated in float32 and cast back to the input dtype.
        attn = F.softmax((attn + mask).float(), dim=-1).type_as(x)
        out = torch.matmul(attn, v).transpose(1, 2).reshape(B, T, D)

        if past_kv is not None:
            # Only the stored cache is limited to the window.
            k = k[:, :, -self.window_size:]
            v = v[:, :, -self.window_size:]
        new_kv = (k.detach(), v.detach())
        return self.out_proj(out), new_kv


class SwiGLUFeedForward(nn.Module):
    """Gated feed-forward: ``w2(silu(w1(x)) * w3(x))`` built from BitLinear
    layers."""

    def __init__(self, config: TernaryConfig):
        super().__init__()
        self.w1 = BitLinear(config.hidden_size, config.intermediate_size, num_layers=config.num_hidden_layers)
        self.w3 = BitLinear(config.hidden_size, config.intermediate_size, num_layers=config.num_hidden_layers)
        self.w2 = BitLinear(config.intermediate_size, config.hidden_size, num_layers=config.num_hidden_layers)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class TransformerBlock(nn.Module):
    """Pre-norm block: attention and feed-forward, each applied to an
    RMS-normalised input and added back to the residual through dropout."""

    def __init__(self, config: TernaryConfig):
        super().__init__()
        self.attn = MultiHeadAttention(config)
        self.ffn = SwiGLUFeedForward(config)
        self.norm1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.norm2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, x, freqs_cis, pos_offset, past_kv=None):
        """Return ``(hidden_states, new_kv)`` for this block."""
        h, new_kv = self.attn(self.norm1(x), freqs_cis, pos_offset, past_kv)
        x = x + self.dropout(h)
        x = x + self.dropout(self.ffn(self.norm2(x)))
        return x, new_kv


class TernaryTransformer(PreTrainedModel):
    """Decoder-only language model assembled from :class:`TransformerBlock`
    layers, with the output head sharing the token-embedding weight."""

    config_class = TernaryConfig
    supports_gradient_checkpointing = True

    def __init__(self, config: TernaryConfig):
        super().__init__(config)
        self.token_emb = nn.Embedding(config.vocab_size, config.hidden_size)
        self.blocks = nn.ModuleList([TransformerBlock(config) for _ in range(config.num_hidden_layers)])
        self.ln_f = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Non-persistent: the rotary table is rebuilt from the config and
        # is not written to the state dict.
        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis(config.hidden_size // config.num_attention_heads, config.max_position_embeddings),
            persistent=False,
        )
        # NOTE(review): confirm that post_init() in the installed
        # transformers version does not re-initialise the BitLinear weights
        # set up in BitLinear.__init__.
        self.post_init()
        # Tie the output projection to the embedding matrix.
        self.lm_head.weight = self.token_emb.weight
        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        # The flag lives on the model itself, whichever matching module
        # triggers the call; forward() reads it from `self`.
        if isinstance(module, (TernaryTransformer, TransformerBlock)):
            self.gradient_checkpointing = value

    def forward(self, input_ids, labels=None, past_key_values=None, return_dict=True,
                *, position_offset: Optional[int] = None, **kwargs):
        """Compute logits and, when ``labels`` is given, the shifted LM loss.

        Args:
            input_ids: Token ids; dimension 1 is the sequence.
            labels: Optional targets aligned with ``input_ids``. Position
                ``t`` of the logits is scored against label ``t + 1``.
            past_key_values: Optional per-layer sequence of
                ``(keys, values)`` pairs returned by a previous call.
            return_dict: Accepted for API compatibility; it is not
                consulted and a ``CausalLMOutputWithPast`` is always
                returned.
            position_offset: Rotary position of the first token in
                ``input_ids``. When omitted, the length of the first
                layer's cached keys is used (0 without a cache). Because
                the cache stops growing once it is cut to the attention
                window, that fallback stops advancing at that point;
                callers generating past the window should pass the true
                number of tokens already consumed.
            **kwargs: Ignored.

        Returns:
            ``CausalLMOutputWithPast`` with ``loss`` (or ``None``),
            ``logits`` and ``past_key_values``. The cache is collected
            outside training mode or whenever a cache was passed in;
            otherwise it is ``None``.
        """
        x = self.token_emb(input_ids)

        if position_offset is not None:
            pos_offset = int(position_offset)
        elif past_key_values and past_key_values[0] is not None:
            pos_offset = past_key_values[0][0].size(2)
        else:
            pos_offset = 0

        new_kvs = []
        for i, block in enumerate(self.blocks):
            if self.gradient_checkpointing and self.training:
                # The checkpointed path runs without a cache.
                x, kv = torch.utils.checkpoint.checkpoint(
                    block, x, self.freqs_cis, pos_offset, None, use_reentrant=False
                )
            else:
                x, kv = block(x, self.freqs_cis, pos_offset, past_key_values[i] if past_key_values else None)
            if not self.training or past_key_values:
                new_kvs.append(kv)

        logits = self.lm_head(self.ln_f(x))

        loss = None
        if labels is not None:
            # Next-token objective: drop the last logit and the first label.
            loss = F.cross_entropy(
                logits[:, :-1, :].reshape(-1, self.config.vocab_size),
                labels[:, 1:].reshape(-1),
            )
        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=new_kvs if new_kvs else None)