kgrabko committed on
Commit 36ea9e0 · verified · 1 Parent(s): f466bb8

Create JiRackTernaryPyTorch_1b_HF_SPEC.py

Files changed (1):
  1. JiRackTernaryPyTorch_1b_HF_SPEC.py +196 -0
JiRackTernaryPyTorch_1b_HF_SPEC.py ADDED
@@ -0,0 +1,196 @@
+ # ==============================================================================
+ # COPYRIGHT (C) 2026 KONSTANTIN VLADIMIROVICH GRABKO. ALL RIGHTS RESERVED.
+ # PATENT PENDING | CMS MANHATTAN JIRACK TECHNOLOGY
+ # ==============================================================================
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers import PreTrainedModel, PretrainedConfig  # added for HF integration
+
+ # Constants for the JiRack Signature Layer
+ VOCAB_SIZE = 128256
+ HIDDEN_SIZE = 2048
+ NUM_LAYERS = 16
+ NUM_HEADS = 32
+ NUM_KV_HEADS = 8
+ INTERMEDIATE_SIZE = 8192
+ MAX_SEQ_LEN = 4096
+ RMS_EPS = 1e-6
+
+ TERNARY_MIN = -1
+ TERNARY_MAX = 1
+ INT8_MIN = -128
+ INT8_MAX = 127
+ INT8_SCALE_TARGET = 127.0
+ STABILITY_EPS = 1e-9
+
+ class JiRackTernaryConfig(PretrainedConfig):
+     """Configuration for registering the model with the Hugging Face ecosystem."""
+     model_type = "jirack_ternary"
+     def __init__(
+         self,
+         vocab_size=VOCAB_SIZE,
+         hidden_size=HIDDEN_SIZE,
+         num_hidden_layers=NUM_LAYERS,
+         num_attention_heads=NUM_HEADS,
+         num_key_value_heads=NUM_KV_HEADS,
+         intermediate_size=INTERMEDIATE_SIZE,
+         max_position_embeddings=MAX_SEQ_LEN,
+         rms_norm_eps=RMS_EPS,
+         **kwargs
+     ):
+         super().__init__(**kwargs)
+         self.vocab_size = vocab_size
+         self.hidden_size = hidden_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.num_key_value_heads = num_key_value_heads
+         self.intermediate_size = intermediate_size
+         self.max_position_embeddings = max_position_embeddings
+         self.rms_norm_eps = rms_norm_eps
+
+ # Legacy TernaryConfig class, kept for compatibility with older ML scripts
+ class TernaryConfig:
+     def __init__(self):
+         self.vocab_size = VOCAB_SIZE
+         self.hidden_size = HIDDEN_SIZE
+         self.num_hidden_layers = NUM_LAYERS
+         self.num_attention_heads = NUM_HEADS
+         self.num_key_value_heads = NUM_KV_HEADS
+         self.intermediate_size = INTERMEDIATE_SIZE
+         self.max_position_embeddings = MAX_SEQ_LEN
+         self.rms_norm_eps = RMS_EPS
+
+ class BitLinear(nn.Linear):
+     def __init__(self, in_features, out_features, bias=False):
+         super().__init__(in_features, out_features, bias)
+
+     def forward(self, x):
+         # Ternary weight quantization: scale by the mean absolute weight (gamma),
+         # round to {-1, 0, +1}, and use a straight-through estimator so gradients
+         # flow to the full-precision weights.
+         w = self.weight
+         gamma = w.abs().mean().clamp(min=STABILITY_EPS)
+         w_quant = torch.clamp(torch.round(w / gamma), TERNARY_MIN, TERNARY_MAX)
+         w_final = w + (w_quant * gamma - w).detach()
+
+         # Activation quantization: center each token, then 8-bit absmax
+         # quantization, again with a straight-through estimator.
+         x_norm = x - x.mean(dim=-1, keepdim=True)
+         x_max = x_norm.abs().max(dim=-1, keepdim=True).values.clamp(min=STABILITY_EPS)
+         scale = INT8_SCALE_TARGET / x_max
+         x_quant = (x_norm * scale).round().clamp(INT8_MIN, INT8_MAX) / scale
+         x_final = x + (x_quant - x).detach()
+
+         return F.linear(x_final, w_final, self.bias)
+
+ class RMSNorm(nn.Module):
+     def __init__(self, dim, eps=RMS_EPS):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(dim))
+
+     def forward(self, x):
+         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
+
+ def precompute_freqs_cis(dim, seq_len):
+     base = 10000.0
+     freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+     t = torch.arange(seq_len).float()
+     freqs = torch.outer(t, freqs)
+     return torch.polar(torch.ones_like(freqs), freqs)
+
+ def apply_rotary_emb(xq, xk, freqs_cis):
+     xq_f = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
+     xk_f = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+     freqs_cis = freqs_cis.to(xq_f.device)[None, None, :xq_f.shape[2], :]
+     xq_out = torch.view_as_real(xq_f * freqs_cis).flatten(3)
+     xk_out = torch.view_as_real(xk_f * freqs_cis).flatten(3)
+     return xq_out.type_as(xq), xk_out.type_as(xk)
+
+ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
+     if n_rep == 1: return x
+     bs, n_kv_heads, seqlen, head_dim = x.shape
+     return (
+         x[:, :, None, :, :]
+         .expand(bs, n_kv_heads, n_rep, seqlen, head_dim)
+         .reshape(bs, n_kv_heads * n_rep, seqlen, head_dim)
+     )
+
+ class TransformerBlock(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.n_heads = config.num_attention_heads
+         self.n_kv_heads = config.num_key_value_heads
+         self.n_rep = self.n_heads // self.n_kv_heads
+         self.head_dim = config.hidden_size // self.n_heads
+
+         self.q_proj = BitLinear(config.hidden_size, config.hidden_size)
+         self.k_proj = BitLinear(config.hidden_size, self.n_kv_heads * self.head_dim)
+         self.v_proj = BitLinear(config.hidden_size, self.n_kv_heads * self.head_dim)
+         self.out_proj = BitLinear(config.hidden_size, config.hidden_size)
+
+         self.ffn_w1 = BitLinear(config.hidden_size, config.intermediate_size)
+         self.ffn_w3 = BitLinear(config.hidden_size, config.intermediate_size)
+         self.ffn_w2 = BitLinear(config.intermediate_size, config.hidden_size)
+
+         self.norm1 = RMSNorm(config.hidden_size)
+         self.norm2 = RMSNorm(config.hidden_size)
+
+     def forward(self, x, freqs_cis):
+         B, T, D = x.shape
+         h = self.norm1(x)
+
+         q = self.q_proj(h).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
+         k = self.k_proj(h).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
+         v = self.v_proj(h).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
+
+         q, k = apply_rotary_emb(q, k, freqs_cis)
+
+         k = repeat_kv(k, self.n_rep)
+         v = repeat_kv(v, self.n_rep)
+
+         attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
+
+         attn_out = attn_out.transpose(1, 2).reshape(B, T, D)
+         x = x + self.out_proj(attn_out)
+
+         m = self.norm2(x)
+         x = x + self.ffn_w2(F.silu(self.ffn_w1(m)) * self.ffn_w3(m))
+         return x
+
+ class TernaryTransformer1B(PreTrainedModel):
+     config_class = JiRackTernaryConfig
+
+     def __init__(self, config):
+         # If a legacy (non-HF) config such as TernaryConfig is passed, fall back to
+         # the default JiRackTernaryConfig, which uses the same constants.
+         if not isinstance(config, PretrainedConfig):
+             config = JiRackTernaryConfig()
+
+         super().__init__(config)
+         self.token_emb = nn.Embedding(config.vocab_size, config.hidden_size)
+         self.blocks = nn.ModuleList([TransformerBlock(config) for _ in range(config.num_hidden_layers)])
+         self.ln_f = RMSNorm(config.hidden_size)
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+         # Register the RoPE frequency buffer
+         self.register_buffer("freqs_cis", precompute_freqs_cis(config.hidden_size // config.num_attention_heads, MAX_SEQ_LEN))
+
+         # Finalize weight initialization for HF
+         self.post_init()
+
+     def forward(self, input_ids, labels=None, **kwargs):
+         x = self.token_emb(input_ids)
+         for block in self.blocks:
+             x = block(x, self.freqs_cis)
+
+         logits = self.lm_head(self.ln_f(x))
+
+         loss = None
+         if labels is not None:
+             # Shift so that tokens < n predict n
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+         return (loss, logits) if loss is not None else (logits, None)
+
+     def prepare_inputs_for_generation(self, input_ids, **kwargs):
+         # Minimal hook so that model.generate() can run
+         return {"input_ids": input_ids}