kgrabko committed
Commit 32254d6 · verified · 1 Parent(s): e291dcf

Create JiRackTernaryPyTorch_1b.py

prepared_sft_data/JiRackTernaryPyTorch_1b.py ADDED
@@ -0,0 +1,143 @@
+ # ==============================================================================
+ # COPYRIGHT (C) 2026 KONSTANTIN VLADIMIROVICH GRABKO. ALL RIGHTS RESERVED.
+ # PATENT PENDING | CMS MANHATTAN JIRACK TECHNOLOGY
+ # ==============================================================================
+ #
+ # fixed RoPE
+ #
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ # --- JIRACK 1B ARCHITECTURE CONSTANTS ---
+ VOCAB_SIZE = 128256
+ HIDDEN_SIZE = 2048
+ NUM_LAYERS = 16
+ NUM_HEADS = 32
+ NUM_KV_HEADS = 8
+ INTERMEDIATE_SIZE = 8192
+ MAX_SEQ_LEN = 4096
+ RMS_EPS = 1e-6
+
+ # --- QUANTIZATION PARAMETERS ---
+ STABILITY_EPS = 1e-9
+ INT8_SCALE_TARGET = 127.0
+
+ class TernaryConfig:
+     def __init__(self):
+         self.vocab_size = VOCAB_SIZE
+         self.hidden_size = HIDDEN_SIZE
+         self.num_hidden_layers = NUM_LAYERS
+         self.num_attention_heads = NUM_HEADS
+         self.num_key_value_heads = NUM_KV_HEADS
+         self.intermediate_size = INTERMEDIATE_SIZE
+         self.max_position_embeddings = MAX_SEQ_LEN
+         self.rms_norm_eps = RMS_EPS
+
+ class BitLinear(nn.Linear):
+     def __init__(self, in_features, out_features, bias=False):
+         super().__init__(in_features, out_features, bias)
+
+     def forward(self, x):
+         # Weight quantization: ternary {-1, 0, +1} scaled by gamma (mean |w|),
+         # with a straight-through estimator so gradients reach the full-precision weights
+         w = self.weight
+         gamma = w.abs().mean().clamp(min=STABILITY_EPS)
+         w_quant = torch.clamp(torch.round(w / gamma), -1, 1)
+         w_final = w + (w_quant * gamma - w).detach()
+
+         # Activation quantization (absmax): center, scale into the int8 range,
+         # round, rescale back; straight-through estimator again
+         x_norm = x - x.mean(dim=-1, keepdim=True)
+         x_max = x_norm.abs().max(dim=-1, keepdim=True).values.clamp(min=STABILITY_EPS)
+         scale = INT8_SCALE_TARGET / x_max
+         x_quant = (x_norm * scale).round().clamp(-128, 127) / scale
+         x_final = x + (x_quant - x).detach()
+
+         return F.linear(x_final, w_final, self.bias)
+
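+ # Illustrative sketch of the ternary weight path above (exposition only;
+ # the layer size is an arbitrary assumption): after dividing by gamma and
+ # rounding, every weight lands in {-1, 0, +1} before being rescaled.
+ #
+ #     layer = BitLinear(8, 8)
+ #     g = layer.weight.abs().mean().clamp(min=STABILITY_EPS)
+ #     w_q = torch.clamp(torch.round(layer.weight / g), -1, 1)
+ #     print(w_q.unique())  # a subset of tensor([-1., 0., 1.])
+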
+ class RMSNorm(nn.Module):
+     def __init__(self, dim, eps=RMS_EPS):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(dim))
+
+     def forward(self, x):
+         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
+
+ # --- ROPE WITHOUT COMPLEX NUMBERS ---
+ def precompute_freqs_cis(dim, seq_len, theta=500000.0):
+     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+     t = torch.arange(seq_len).float()
+     freqs = torch.outer(t, freqs)
+     return torch.cos(freqs), torch.sin(freqs)
+
+ def apply_rotary_emb(xq, xk, freqs_cos, freqs_sin):
+     def rotate_half(x):
+         # Split 64 into two 32s
+         x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
+         return torch.cat((-x2, x1), dim=-1)
+
+     T = xq.shape[2]
+     # FIX: Repeat frequencies (32 -> 64) to match head_dim
+     f_cos = freqs_cos[:T].to(xq.device).view(1, 1, T, -1).repeat(1, 1, 1, 2)
+     f_sin = freqs_sin[:T].to(xq.device).view(1, 1, T, -1).repeat(1, 1, 1, 2)
+
+     xq_out = (xq * f_cos) + (rotate_half(xq) * f_sin)
+     xk_out = (xk * f_cos) + (rotate_half(xk) * f_sin)
+     return xq_out, xk_out
+
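+ # Shape sketch for the RoPE helpers above (exposition only; the tensor sizes
+ # are assumptions matching the 1B settings, head_dim = 2048 // 32 = 64):
+ #
+ #     cos, sin = precompute_freqs_cis(64, MAX_SEQ_LEN)  # each (4096, 32)
+ #     q = torch.randn(1, 32, 10, 64)                    # (B, n_heads, T, head_dim)
+ #     k = torch.randn(1, 8, 10, 64)                     # (B, n_kv_heads, T, head_dim)
+ #     q_rot, k_rot = apply_rotary_emb(q, k, cos, sin)   # shapes are unchanged
+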
+ def repeat_kv(x, n_rep):
+     # Expand KV heads for grouped-query attention
+     if n_rep == 1:
+         return x
+     bs, n_kv_heads, seqlen, head_dim = x.shape
+     return x[:, :, None, :, :].expand(bs, n_kv_heads, n_rep, seqlen, head_dim).reshape(bs, n_kv_heads * n_rep, seqlen, head_dim)
+
+ class TransformerBlock(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.n_heads = config.num_attention_heads
+         self.n_kv_heads = config.num_key_value_heads
+         self.n_rep = self.n_heads // self.n_kv_heads
+         self.head_dim = config.hidden_size // self.n_heads
+         self.q_proj = BitLinear(config.hidden_size, config.hidden_size)
+         self.k_proj = BitLinear(config.hidden_size, self.n_kv_heads * self.head_dim)
+         self.v_proj = BitLinear(config.hidden_size, self.n_kv_heads * self.head_dim)
+         self.out_proj = BitLinear(config.hidden_size, config.hidden_size)
+         self.ffn_w1 = BitLinear(config.hidden_size, config.intermediate_size)
+         self.ffn_w3 = BitLinear(config.hidden_size, config.intermediate_size)
+         self.ffn_w2 = BitLinear(config.intermediate_size, config.hidden_size)
+         self.norm1 = RMSNorm(config.hidden_size)
+         self.norm2 = RMSNorm(config.hidden_size)
+
+     def forward(self, x, freqs_cos, freqs_sin):
+         # Pre-norm attention with grouped-query KV heads and RoPE
+         h = self.norm1(x)
+         B, T, D = x.shape
+         q = self.q_proj(h).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
+         k = self.k_proj(h).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
+         v = self.v_proj(h).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
+
+         q, k = apply_rotary_emb(q, k, freqs_cos, freqs_sin)
+
+         k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
+         attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
+
+         x = x + self.out_proj(attn_out.transpose(1, 2).reshape(B, T, D))
+
+         # Pre-norm SwiGLU feed-forward
+         m = self.norm2(x)
+         x = x + self.ffn_w2(F.silu(self.ffn_w1(m)) * self.ffn_w3(m))
+         return x
+
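+ # Shape sketch for a single block (exposition only; the batch and sequence
+ # sizes below are arbitrary assumptions):
+ #
+ #     cos, sin = precompute_freqs_cis(HIDDEN_SIZE // NUM_HEADS, MAX_SEQ_LEN)
+ #     block = TransformerBlock(TernaryConfig())
+ #     out = block(torch.randn(2, 128, HIDDEN_SIZE), cos, sin)  # -> (2, 128, 2048)
+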
+ class TernaryTransformer1B(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.token_emb = nn.Embedding(config.vocab_size, config.hidden_size)
+         self.blocks = nn.ModuleList([TransformerBlock(config) for _ in range(config.num_hidden_layers)])
+         self.ln_f = RMSNorm(config.hidden_size)
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+         # RoPE frequencies (64 head_dim -> 32 pairs)
+         cos, sin = precompute_freqs_cis(config.hidden_size // config.num_attention_heads, MAX_SEQ_LEN)
+         self.register_buffer("freqs_cos", cos)
+         self.register_buffer("freqs_sin", sin)
+
+     def forward(self, input_ids):
+         x = self.token_emb(input_ids)
+         for block in self.blocks:
+             x = block(x, self.freqs_cos, self.freqs_sin)
+         return self.lm_head(self.ln_f(x)), None
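+
+ # Illustrative smoke test (exposition only; the scaled-down sizes are
+ # assumptions chosen so the forward pass stays small, not the 1B settings):
+ if __name__ == "__main__":
+     cfg = TernaryConfig()
+     cfg.hidden_size = 256
+     cfg.num_hidden_layers = 2
+     cfg.num_attention_heads = 8
+     cfg.num_key_value_heads = 2
+     cfg.intermediate_size = 512
+     model = TernaryTransformer1B(cfg)
+     ids = torch.randint(0, cfg.vocab_size, (1, 16))
+     logits, _ = model(ids)
+     print(logits.shape)  # expected: torch.Size([1, 16, 128256])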