OpenRAG128 committed on
Commit
e4af6d2
·
verified ·
1 Parent(s): 2eb484e

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. README.md +62 -3
  2. config.json +14 -0
  3. model.py +220 -0
  4. model.safetensors +3 -0
README.md CHANGED
@@ -1,3 +1,62 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ ---
3
+ language: en
4
+ license: apache-2.0
5
+ tags:
6
+ - efficient-llm
7
+ - quantization
8
+ - ternary
9
+ - bitnet
10
+ - pytorch
11
+ - tinystories
12
+ datasets:
13
+ - roneneldan/TinyStories
14
+ arxiv: 2602.07374
15
+ ---
16
+
17
+ # TernaryLM-132M
18
+
19
+ TernaryLM-132M is a 132M parameter Transformer trained natively using ternary weights {-1, 0, +1}.
20
+
21
+ Unlike post-training quantization methods, this model learns quantized representations during training.
22
+
23
+ ## Architecture
24
+
25
+ - Parameters: 132M
26
+ - Layers: 12
27
+ - Hidden Size: 768
28
+ - Attention Heads: 12
29
+ - Context Length: 512
30
+ - Quantization: Native Ternary Training
31
+
32
+ ## Training
33
+
34
+ - Dataset: TinyStories (~60k stories)
35
+ - Optimizer: AdamW (betas=(0.9, 0.98))
36
+ - LR: 3e-4
37
+ - Scheduler: OneCycleLR
38
+ - Epochs: 15
39
+ - Hardware: Multi-GPU T4 setup (Kaggle)
40
+
41
+ ## Intended Use
42
+
43
+ Research on:
44
+ - Efficient Transformers
45
+ - Quantization-aware training
46
+ - Edge deployment
47
+
48
+ ## Limitations
49
+
50
+ - Not instruction-tuned
51
+ - Limited dataset scale
52
+ - Research prototype
53
+
54
+ ## Citation
55
+
56
+ @article{ternarylm2026,
57
+ title={TernaryLM: Native 1-Bit Transformer Training},
58
+ author={Your Name},
59
+ year={2026},
60
+ eprint={2602.07374},
61
+ archivePrefix={arXiv}
62
+ }
config.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "ternarylm",
3
+ "vocab_size": 30522,
4
+ "hidden_size": 768,
5
+ "num_hidden_layers": 12,
6
+ "num_attention_heads": 12,
7
+ "max_position_embeddings": 512,
8
+ "quantization": "native ternary {-1,0,+1}",
9
+ "training_dataset": "roneneldan/TinyStories",
10
+ "epochs": 15,
11
+ "optimizer": "AdamW",
12
+ "learning_rate": 0.0003,
13
+ "scheduler": "OneCycleLR"
14
+ }
model.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+
6
+
7
class RoPEPositionalEncoding(nn.Module):
    """Rotary position embedding (RoPE) helper for attention heads.

    Precomputes the inverse frequencies for ``dim`` channels and lazily
    caches cos/sin tables for the longest sequence length seen so far,
    re-deriving them if the tensors move to a different device.
    """

    def __init__(self, dim, max_len=2048):
        super().__init__()
        self.dim = dim
        # NOTE(review): max_len is accepted but never used; the cache grows
        # on demand instead — confirm whether it was meant to pre-allocate.

        # Standard RoPE frequencies: 10000^(-2i/dim) per channel pair.
        freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", freq)

        # Lazy cos/sin cache. Plain attributes (not buffers) so they are
        # excluded from the state dict.
        self._cached_cos = None
        self._cached_sin = None
        self._cached_len = 0

    def _compute_cache(self, seq_len, device):
        """Return (cos, sin) tables of shape (seq_len, dim) on ``device``."""
        too_short = seq_len > self._cached_len
        moved = self._cached_cos is not None and self._cached_cos.device != device
        if too_short or moved:
            positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
            angles = torch.outer(positions, self.inv_freq.to(device))
            # Duplicate so both halves of the head dim share the same angle.
            table = torch.cat((angles, angles), dim=-1)
            self._cached_cos = table.cos()
            self._cached_sin = table.sin()
            self._cached_len = seq_len
        return (
            self._cached_cos[:seq_len].to(device),
            self._cached_sin[:seq_len].to(device),
        )

    def rotate_half(self, x):
        """Swap halves of the last dim with a sign flip: (x1, x2) -> (-x2, x1)."""
        half = x.shape[-1] // 2
        first, second = x[..., :half], x[..., half:]
        return torch.cat((-second, first), dim=-1)

    def apply_rope(self, q, k, seq_len):
        """Rotate q and k (..., seq_len, dim) by their position-dependent angles."""
        cos, sin = self._compute_cache(seq_len, q.device)
        # Broadcast over (batch, heads) leading dimensions.
        cos = cos[None, None, :, :]
        sin = sin[None, None, :, :]

        rot_q = (q * cos) + (self.rotate_half(q) * sin)
        rot_k = (k * cos) + (self.rotate_half(k) * sin)
        return rot_q, rot_k
51
+
52
+
53
class BitLinearFunction(torch.autograd.Function):
    """Linear layer with int8 activation and ternary weight fake-quantization.

    Forward quantizes activations per-row to the int8 grid (absmax scaling)
    and weights to {-1, 0, +1} * mean(|W|) (BitNet b1.58 style). Backward is
    a straight-through estimator: gradients flow through both quantizers as
    if they were the identity, using the quantized weights for grad_input.
    """

    @staticmethod
    def forward(ctx, input, weight, bias=None):
        # Per-row absmax activation quantization; clamp avoids div-by-zero
        # on all-zero rows.
        scale = 127.0 / input.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
        x_quant = (input * scale).round().clamp(-128, 127) / scale

        # Ternary weight quantization: scale by the mean absolute weight,
        # round, clamp to {-1, 0, +1}, then rescale.
        w_scale = weight.abs().mean().clamp(min=1e-5)
        w_quant = (weight / w_scale).round().clamp(-1, 1) * w_scale

        # Save through save_for_backward rather than as raw ctx attributes so
        # autograd manages the tensors' lifetime. `weight` itself is not
        # needed in backward — only `input` and the quantized weights are.
        ctx.save_for_backward(input, w_quant)

        return F.linear(x_quant, w_quant, bias)

    @staticmethod
    def backward(ctx, grad_output):
        input, w_quant = ctx.saved_tensors

        # STE: propagate through the quantized weights.
        grad_input = grad_output.matmul(w_quant)

        # reshape (not view): grad_output may be non-contiguous.
        grad_output_2d = grad_output.reshape(-1, grad_output.shape[-1])
        input_2d = input.reshape(-1, input.shape[-1])
        grad_weight = grad_output_2d.t().mm(input_2d)

        # needs_input_grad only has a third entry when `apply` was called
        # with an explicit bias argument — guard the index.
        grad_bias = None
        if len(ctx.needs_input_grad) > 2 and ctx.needs_input_grad[2]:
            grad_bias = grad_output_2d.sum(0)

        return grad_input, grad_weight, grad_bias
83
+
84
+
85
class RigorousBitLinear(nn.Module):
    """Drop-in ``nn.Linear`` replacement that runs ``BitLinearFunction``.

    Weights are stored in full precision; ternarization happens on the fly
    inside the autograd function on every forward pass.
    """

    def __init__(self, in_features, out_features, bias=False):
        super().__init__()
        # NOTE(review): plain randn init (std = 1) is unusually large for a
        # linear layer — presumably intentional for ternary training; confirm.
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.bias = None

    def forward(self, x):
        return BitLinearFunction.apply(x, self.weight, self.bias)
93
+
94
+
95
class RMSNorm(nn.Module):
    """RMS layer normalization: x / rms(x) * weight, no mean-centering, no bias."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps  # added under the rsqrt for numerical stability
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        scaled = x * torch.rsqrt(mean_sq + self.eps)
        return scaled * self.weight
104
+
105
+
106
class ImprovedBitAttention(nn.Module):
    """Causal multi-head self-attention with ternary-quantized projections.

    Q/K/V and the output projection are ``RigorousBitLinear`` layers; rotary
    position embeddings are applied to Q and K before the dot product.
    """

    def __init__(self, dim, heads=8, dropout=0.1, max_len=2048):
        super().__init__()
        self.heads = heads
        self.head_dim = dim // heads
        # 1/sqrt(d_head) scaling for the attention logits.
        self.scale = self.head_dim ** -0.5

        self.q_proj = RigorousBitLinear(dim, dim)
        self.k_proj = RigorousBitLinear(dim, dim)
        self.v_proj = RigorousBitLinear(dim, dim)
        self.out_proj = RigorousBitLinear(dim, dim)

        self.rope = RoPEPositionalEncoding(self.head_dim, max_len)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        bsz, seq_len, dim = x.shape

        def to_heads(t):
            # (batch, seq, dim) -> (batch, heads, seq, head_dim)
            return t.view(bsz, seq_len, self.heads, self.head_dim).transpose(1, 2)

        q = to_heads(self.q_proj(x))
        k = to_heads(self.k_proj(x))
        v = to_heads(self.v_proj(x))

        q, k = self.rope.apply_rope(q, k, seq_len)

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        # Causal mask: token i may attend to positions <= i only.
        causal = torch.tril(torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool))
        scores = scores.masked_fill(~causal, float("-inf"))

        weights = self.dropout(F.softmax(scores, dim=-1))

        context = torch.matmul(weights, v)
        context = context.transpose(1, 2).contiguous().view(bsz, seq_len, dim)
        return self.out_proj(context)
140
+
141
+
142
+
143
class SwiGLUMLP(nn.Module):
    """SwiGLU feed-forward block: down(dropout(silu(gate(x)) * up(x))).

    All three projections are ternary-quantized.
    """

    def __init__(self, dim, expansion=2.67, dropout=0.1):
        super().__init__()
        hidden = int(dim * expansion)

        # IMPORTANT: keep original names (checkpoint / state-dict compatibility).
        self.gate_proj = RigorousBitLinear(dim, hidden)
        self.up_proj = RigorousBitLinear(dim, hidden)
        self.down_proj = RigorousBitLinear(hidden, dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        gated = F.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(self.dropout(gated))
159
+
160
+
161
+
162
class ImprovedBitBlock(nn.Module):
    """Pre-norm transformer block: residual attention, then residual SwiGLU MLP."""

    def __init__(self, dim, heads=8, dropout=0.1, max_len=2048):
        super().__init__()
        self.norm1 = RMSNorm(dim)
        self.attn = ImprovedBitAttention(dim, heads, dropout, max_len)
        self.norm2 = RMSNorm(dim)
        self.mlp = SwiGLUMLP(dim, dropout=dropout)

    def forward(self, hidden):
        hidden = hidden + self.attn(self.norm1(hidden))
        hidden = hidden + self.mlp(self.norm2(hidden))
        return hidden
174
+
175
+
176
class ImprovedBitNet(nn.Module):
    """Decoder-only language model built from ternary-quantized blocks.

    The token embedding and the LM head stay full precision; each of the
    ``depth`` transformer blocks uses ternary projections internally.
    """

    def __init__(
        self,
        vocab_size: int = 30522,
        dim: int = 768,
        depth: int = 12,
        heads: int = 12,
        max_len: int = 512,
        dropout: float = 0.05,
    ):
        super().__init__()

        self.vocab_size = vocab_size
        self.dim = dim
        self.depth = depth

        # Full-precision token embedding.
        self.token_emb = nn.Embedding(vocab_size, dim)

        # Stack of pre-norm ternary transformer blocks (ModuleList keeps the
        # "blocks.<i>." state-dict keys stable).
        self.blocks = nn.ModuleList(
            ImprovedBitBlock(dim=dim, heads=heads, dropout=dropout, max_len=max_len)
            for _ in range(depth)
        )

        # Final normalization + full-precision LM head (no weight tying).
        self.norm = RMSNorm(dim)
        self.head = nn.Linear(dim, vocab_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map token ids of shape (batch, seq) to logits (batch, seq, vocab_size)."""
        hidden = self.token_emb(x)
        for layer in self.blocks:
            hidden = layer(hidden)
        return self.head(self.norm(hidden))
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5d501c2a4a2a373bd46722fca989887b6ffa88a387a6dec6d7f325b7fdfde12b
3
+ size 527699616