File size: 8,691 Bytes
bc0c3f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
# ==============================================================================
# COPYRIGHT (C) 2025 KONSTANTIN VLADIMIROVICH GRABKO. ALL RIGHTS RESERVED.
# PATENT PENDING | CMS MANHATTAN JIRACK TECHNOLOGY
#
# This software is licensed under the Commercial License Agreement V.1.2.
# Any use, modification, or distribution of this code requires compliance with 
# the terms found in the LICENSE.md file in the root directory.
#
# NO PATENTING RIGHTS: Users are strictly prohibited from filing patent claims
# based on the BRE or SWA architectures disclosed herein.
# Contact: grabko@cmsmanhattan.com | +1 (516) 777-0945
# ==============================================================================


import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, List, Tuple, Union
import math
import torch.utils.checkpoint
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast

class TernaryConfig(PretrainedConfig):
    """HuggingFace configuration for the ternary-weight transformer.

    Defaults describe a GPT-2-vocabulary model: 24 layers, hidden size 3072,
    32 attention heads, and a sliding attention window of 512 positions.
    """

    model_type = "ternary_transformer"

    def __init__(
        self,
        vocab_size=50257,
        hidden_size=3072,
        num_hidden_layers=24,
        num_attention_heads=32,
        intermediate_size=12288,
        max_position_embeddings=2048,
        rms_norm_eps=1e-6,
        dropout_rate=0.1,
        window_size=512,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.max_position_embeddings = max_position_embeddings
        self.rms_norm_eps = rms_norm_eps
        self.dropout_rate = dropout_rate
        self.window_size = window_size

class BitLinear(nn.Linear):
    """Linear layer with ternary ({-1, 0, +1} * scale) weight quantization.

    Quantization is applied on-the-fly in ``forward`` via a straight-through
    estimator: the forward pass sees quantized weights while gradients flow
    to the underlying full-precision weights. Activations are mean-centered
    and clipped to [-1.5, 1.5], also with an STE.
    """

    def __init__(self, in_features, out_features, bias=False, num_layers=24):
        super().__init__(in_features, out_features, bias)
        # Depth-scaled init: std shrinks with the number of layers.
        nn.init.normal_(self.weight, mean=0.0, std=0.02 / math.sqrt(2 * num_layers))

    def forward(self, x):
        weight = self.weight
        # Per-tensor scale = mean absolute weight (epsilon avoids div-by-zero).
        scale = weight.abs().mean() + 1e-9
        ternary = torch.clamp(torch.round(weight / scale), -1, 1)
        # STE: value of (ternary * scale), gradient of the raw weight.
        effective_weight = weight + (ternary * scale - weight).detach()
        centered = x - x.mean(dim=-1, keepdim=True)
        # STE clip of activations to [-1.5, 1.5].
        clipped = centered + (torch.clamp(centered, -1.5, 1.5) - centered).detach()
        return F.linear(clipped, effective_weight, self.bias)

class RMSNorm(nn.Module):
    """Root-mean-square layer normalization.

    Rescales each vector by 1/RMS along the last dimension, then applies a
    learned per-dimension gain. No mean subtraction and no bias term.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        normalized = x * torch.rsqrt(mean_square + self.eps)
        return normalized * self.weight

def precompute_freqs_cis(dim: int, seq_len: int, theta: float = 10000.0):
    """Build the (seq_len, dim // 2) table of unit complex rotations for RoPE.

    Entry (t, j) is e^{i * t * theta^(-2j/dim)}: position t rotated by the
    j-th inverse frequency.
    """
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    angles = torch.arange(seq_len).float()[:, None] * inv_freq[None, :]
    # polar(1, angle) yields unit-magnitude complex numbers e^{i * angle}.
    return torch.polar(torch.ones_like(angles), angles)

def apply_rotary_emb(xq, xk, freqs_cis):
    """Apply rotary position embeddings to query and key tensors.

    The last dimension is viewed as head_dim // 2 complex pairs; each pair is
    rotated by the matching entry of ``freqs_cis``. Outputs are cast back to
    the input dtypes.
    """
    # (..., head_dim) -> complex view (..., head_dim // 2)
    q_complex = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    k_complex = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    # Broadcast the table over batch and head dims, trimmed to the query length.
    table = freqs_cis[None, None, :q_complex.shape[2], :]
    q_rotated = torch.view_as_real(q_complex * table).flatten(3)
    k_rotated = torch.view_as_real(k_complex * table).flatten(3)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)

class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with RoPE and a sliding-window KV cache.

    The key/value cache is trimmed to the most recent `window_size` positions
    on every step, bounding memory during incremental decoding.
    """

    def __init__(self, config: TernaryConfig):
        super().__init__()
        self.n_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        # All four projections use ternary-quantized BitLinear layers.
        self.q_proj = BitLinear(config.hidden_size, config.hidden_size, num_layers=config.num_hidden_layers)
        self.k_proj = BitLinear(config.hidden_size, config.hidden_size, num_layers=config.num_hidden_layers)
        self.v_proj = BitLinear(config.hidden_size, config.hidden_size, num_layers=config.num_hidden_layers)
        self.out_proj = BitLinear(config.hidden_size, config.hidden_size, num_layers=config.num_hidden_layers)
        # Standard 1/sqrt(head_dim) attention scaling.
        self.scale = self.head_dim ** -0.5
        self.window_size = config.window_size

    def forward(self, x, freqs_cis, pos_offset, past_kv=None):
        """Compute attention over `x` plus any cached keys/values.

        Args:
            x: (batch, seq_len, hidden) input activations.
            freqs_cis: full rotary table; sliced here by absolute position.
            pos_offset: number of already-cached positions, so new tokens get
                rotary angles for their absolute (not chunk-local) positions.
            past_kv: optional (k, v) cache from the previous step, each
                shaped (batch, heads, cached_len, head_dim).

        Returns:
            (output, new_kv) where new_kv is the detached, window-trimmed
            (k, v) pair to pass back on the next step.
        """
        B, T, D = x.shape
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        # Rotate only the new tokens; cached k were rotated at their own positions.
        q, k = apply_rotary_emb(q, k, freqs_cis[pos_offset : pos_offset + T])
        if past_kv is not None:
            pk, pv = past_kv
            # Prepend cache, then keep only the last `window_size` positions.
            k = torch.cat([pk, k], dim=2)[:, :, -self.window_size:]
            v = torch.cat([pv, v], dim=2)[:, :, -self.window_size:]
        # Cache is detached so gradients never flow into past steps.
        new_kv = (k.detach(), v.detach())
        attn = (torch.matmul(q, k.transpose(-2, -1)) * self.scale)
        # Causal mask aligned to the right edge of the key axis: query i (the
        # (k_len - T + i)-th absolute position) may attend keys 0..k_len-T+i.
        mask = torch.triu(torch.full((T, k.size(2)), float('-inf'), device=x.device), diagonal=k.size(2)-T+1).unsqueeze(0).unsqueeze(0)
        # Softmax in float32 for stability, then cast back to the input dtype.
        attn = F.softmax((attn + mask).float(), dim=-1).type_as(x)
        out = torch.matmul(attn, v).transpose(1, 2).reshape(B, T, D)
        return self.out_proj(out), new_kv

class SwiGLUFeedForward(nn.Module):
    """SwiGLU feed-forward block: w2(silu(w1(x)) * w3(x)), built on BitLinear."""

    def __init__(self, config: TernaryConfig):
        super().__init__()
        self.w1 = BitLinear(config.hidden_size, config.intermediate_size, num_layers=config.num_hidden_layers)
        self.w3 = BitLinear(config.hidden_size, config.intermediate_size, num_layers=config.num_hidden_layers)
        self.w2 = BitLinear(config.intermediate_size, config.hidden_size, num_layers=config.num_hidden_layers)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)

class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention then SwiGLU FFN, each wrapped in
    a residual connection with dropout."""

    def __init__(self, config: TernaryConfig):
        super().__init__()
        self.attn = MultiHeadAttention(config)
        self.ffn = SwiGLUFeedForward(config)
        self.norm1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.norm2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, x, freqs_cis, pos_offset, past_kv=None):
        attn_out, updated_kv = self.attn(self.norm1(x), freqs_cis, pos_offset, past_kv)
        residual = x + self.dropout(attn_out)
        out = residual + self.dropout(self.ffn(self.norm2(residual)))
        return out, updated_kv

class TernaryTransformer(PreTrainedModel):
    """Causal LM built from ternary-quantized transformer blocks.

    Integrates with HuggingFace via `PreTrainedModel`: supports gradient
    checkpointing and returns `CausalLMOutputWithPast` with an optional
    KV cache for incremental decoding. The LM head is weight-tied to the
    token embedding.
    """

    config_class = TernaryConfig
    supports_gradient_checkpointing = True

    def __init__(self, config: TernaryConfig):
        super().__init__(config)
        self.token_emb = nn.Embedding(config.vocab_size, config.hidden_size)
        self.blocks = nn.ModuleList([TransformerBlock(config) for _ in range(config.num_hidden_layers)])
        self.ln_f = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Rotary table for all positions; non-persistent so it is rebuilt on
        # load rather than stored in checkpoints.
        self.register_buffer("freqs_cis", precompute_freqs_cis(config.hidden_size // config.num_attention_heads, config.max_position_embeddings), persistent=False)
        self.post_init()
        # Tie output head to the embedding AFTER post_init so HF's weight
        # initialization cannot untie them.
        # NOTE(review): this bypasses HF's tie_weights machinery — confirm
        # save/load round-trips keep the tie.
        self.lm_head.weight = self.token_emb.weight
        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        # Hook called by HF's gradient_checkpointing_enable/disable.
        if isinstance(module, (TernaryTransformer, TransformerBlock)):
            self.gradient_checkpointing = value

    def forward(self, input_ids, labels=None, past_key_values=None, return_dict=True, **kwargs):
        """Run the model; if `labels` is given, also compute the LM loss.

        Labels are shifted internally (logits[:, :-1] vs labels[:, 1:]), so
        callers may pass labels identical to input_ids.
        """
        x = self.token_emb(input_ids)
        # Offset = number of positions already in the cache; keeps rotary
        # angles absolute across incremental decoding steps.
        pos_offset = past_key_values[0][0].size(2) if past_key_values and past_key_values[0] is not None else 0
        new_kvs = []
        for i, block in enumerate(self.blocks):
            if self.gradient_checkpointing and self.training:
                # Checkpointed path passes past_kv=None: checkpointing is not
                # combined with cached generation.
                x, kv = torch.utils.checkpoint.checkpoint(block, x, self.freqs_cis, pos_offset, None, use_reentrant=False)
            else:
                x, kv = block(x, self.freqs_cis, pos_offset, past_key_values[i] if past_key_values else None)
            # Collect caches only in eval mode or when a cache was supplied,
            # avoiding KV retention during plain training.
            if not self.training or past_key_values: new_kvs.append(kv)
        logits = self.lm_head(self.ln_f(x))
        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits[:, :-1, :].reshape(-1, self.config.vocab_size), labels[:, 1:].reshape(-1))
        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=new_kvs if new_kvs else None)