suchirsalhan committed on
Commit 29ecc5e · verified · 1 Parent(s): 7f2c8c5

Add pico_decoder.py + auto_map config (main)

Files changed (2)
  1. config.json +15 -2
  2. pico_decoder.py +250 -0
config.json CHANGED
@@ -1,6 +1,6 @@
 {
   "architectures": [
-    "PicoDecoderModel"
+    "PicoDecoderHF"
   ],
   "model_type": "pico_decoder",
   "vocab_size": 32000,
@@ -14,5 +14,18 @@
   "rms_norm_eps": 1e-05,
   "tie_word_embeddings": false,
   "torch_dtype": "float32",
-  "transformers_version": "4.40.0"
+  "transformers_version": "4.48.3",
+  "auto_map": {
+    "AutoConfig": "pico_decoder.PicoDecoderHFConfig",
+    "AutoModelForCausalLM": "pico_decoder.PicoDecoderHF"
+  },
+  "d_model": 768,
+  "n_layers": 14,
+  "attention_n_heads": 6,
+  "attention_n_kv_heads": 0,
+  "activation_hidden_dim": 3072,
+  "max_seq_len": 2048,
+  "norm_eps": 1e-06,
+  "position_emb_theta": 10000.0,
+  "batch_size": 1
 }
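
With the new auto_map entries in place, transformers resolves the custom classes from pico_decoder.py in this repository when trust_remote_code=True is passed, as the module's own docstring describes. A minimal loading sketch (the repo id below is a placeholder, not confirmed by this commit):

from transformers import AutoConfig, AutoModelForCausalLM

repo_id = "<namespace>/<repo>"  # placeholder; substitute the actual Hub repo id
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)           # resolves pico_decoder.PicoDecoderHFConfig
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)  # resolves pico_decoder.PicoDecoderHF
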
pico_decoder.py ADDED
@@ -0,0 +1,250 @@
+
+"""
+Pico Decoder - BeetleLM
+Adapted from pico-lm/pico-decoder-tiny (Apache 2.0).
+Load with trust_remote_code=True.
+"""
+
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.attention import SDPBackend, sdpa_kernel
+from transformers import PretrainedConfig, PreTrainedModel
+from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast
+
+
+# ── RMSNorm ──────────────────────────────────────────────────────────────────
+
+class RMSNorm(torch.nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.eps = config.norm_eps
+        self.weight = nn.Parameter(torch.ones(config.d_model))
+
+    def _norm(self, x):
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x):
+        return self._norm(x.float()).type_as(x) * self.weight
+
+
+# ── RoPE ─────────────────────────────────────────────────────────────────────
+
+class RoPE(nn.Module):
+    _freqs_cis_tensor = None
+
+    def __init__(self, config):
+        super().__init__()
+        self.theta = config.position_emb_theta
+        self.dim = config.d_model // config.attention_n_heads
+        RoPE._freqs_cis_tensor = self._setup_freqs_cis(
+            config.max_seq_len, self.theta, self.dim
+        )
+        self.register_buffer("_freqs_cis", RoPE._freqs_cis_tensor, persistent=False)
+
+    @classmethod
+    def _setup_freqs_cis(cls, seq_len, theta, dim):
+        _freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
+        freqs = torch.outer(torch.arange(seq_len), _freqs)
+        return torch.polar(torch.ones_like(freqs), freqs)
+
+    def get_freqs_cis(self, input_shape, start_pos, end_pos):
+        _f = self._freqs_cis[start_pos:end_pos]
+        ndim = len(input_shape)
+        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(input_shape)]
+        return _f.view(*shape)
+
+    def forward(self, queries, keys, start_pos=0):
+        q_ = torch.view_as_complex(queries.float().reshape(*queries.shape[:-1], -1, 2))
+        k_ = torch.view_as_complex(keys.float().reshape(*keys.shape[:-1], -1, 2))
+        fc = self.get_freqs_cis(q_.shape, start_pos, start_pos + q_.shape[1])
+        return (
+            torch.view_as_real(q_ * fc).flatten(3).type_as(queries),
+            torch.view_as_real(k_ * fc).flatten(3).type_as(keys),
+        )
+
+
+# ── Attention ────────────────────────────────────────────────────────────────
+
+class Attention(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.n_heads = config.attention_n_heads
+        self.n_kv_heads = config.attention_n_kv_heads
+        self.n_rep = self.n_heads // self.n_kv_heads
+        self.max_seq_len = config.max_seq_len
+        self.batch_size = config.batch_size
+        d = config.d_model
+        self.head_dim = d // self.n_heads
+        self.q_proj = nn.Linear(d, self.n_heads * self.head_dim, bias=False)
+        self.k_proj = nn.Linear(d, self.n_kv_heads * self.head_dim, bias=False)
+        self.v_proj = nn.Linear(d, self.n_kv_heads * self.head_dim, bias=False)
+        self.o_proj = nn.Linear(self.n_heads * self.head_dim, d, bias=False)
+        self.rope = RoPE(config)
+
+    def forward(self, x, mask=None, past_key_values=None, use_cache=False):
+        bsz, seq_len, _ = x.shape
+        q = self.q_proj(x).view(bsz, seq_len, self.n_heads, self.head_dim)
+        k = self.k_proj(x).view(bsz, seq_len, self.n_kv_heads, self.head_dim)
+        v = self.v_proj(x).view(bsz, seq_len, self.n_kv_heads, self.head_dim)
+        start_pos = past_key_values[0].shape[1] if past_key_values is not None else 0
+        q, k = self.rope(q, k, start_pos)
+        if past_key_values is not None:
+            k = torch.cat([past_key_values[0], k], dim=1)
+            v = torch.cat([past_key_values[1], v], dim=1)
+        ck, cv = (k, v) if use_cache else (None, None)
+        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
+        apply_gqa = self.n_rep > 1
+        if apply_gqa and q.device.type == "mps":
+            k = k.repeat_interleave(self.n_rep, dim=-3)
+            v = v.repeat_interleave(self.n_rep, dim=-3)
+            apply_gqa = False
+        with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH]):
+            out = F.scaled_dot_product_attention(
+                q.contiguous(), k.contiguous(), v.contiguous(),
+                attn_mask=mask.to(q.dtype) if mask is not None else None,
+                enable_gqa=apply_gqa,
+            )
+        out = out.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
+        return self.o_proj(out), (ck, cv)
+
+
+# ── SwiGLU ───────────────────────────────────────────────────────────────────
+
+class SwiGLU(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.w_0 = nn.Linear(config.d_model, config.activation_hidden_dim, bias=False)
+        self.w_1 = nn.Linear(config.d_model, config.activation_hidden_dim, bias=False)
+        self.w_2 = nn.Linear(config.activation_hidden_dim, config.d_model, bias=False)
+
+    def forward(self, x):
+        return self.w_2(F.silu(self.w_0(x)) * self.w_1(x))
+
+
+# ── PicoDecoderBlock ─────────────────────────────────────────────────────────
+
+class PicoDecoderBlock(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.attention = Attention(config)
+        self.swiglu = SwiGLU(config)
+        self.attention_norm = RMSNorm(config)
+        self.swiglu_norm = RMSNorm(config)
+
+    def forward(self, x, mask=None, past_key_values=None, use_cache=False):
+        attn_out, cached = self.attention(
+            self.attention_norm(x), mask=mask,
+            past_key_values=past_key_values, use_cache=use_cache,
+        )
+        h = x + attn_out
+        return h + self.swiglu(self.swiglu_norm(h)), cached
+
+
+# ── PicoDecoder ──────────────────────────────────────────────────────────────
+
+class PicoDecoder(nn.Module):
+    def __init__(self, model_config):
+        super().__init__()
+        self.config = model_config
+        self.embedding_proj = nn.Embedding(model_config.vocab_size, model_config.d_model)
+        self.layers = nn.ModuleList([PicoDecoderBlock(model_config) for _ in range(model_config.n_layers)])
+        self.output_norm = RMSNorm(model_config)
+        self.de_embedding_proj = nn.Linear(model_config.d_model, model_config.vocab_size, bias=False)
+
+    def forward(self, input_ids, past_key_values=None, use_cache=False):
+        seq_len = input_ids.shape[-1]
+        h = self.embedding_proj(input_ids)
+        start_pos = 0 if past_key_values is None else past_key_values[0][0].shape[1]
+        mask = None
+        if seq_len > 1:
+            mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
+            if past_key_values is not None:
+                mask = torch.hstack([torch.zeros((seq_len, start_pos)), mask])
+            mask = mask.to(h.device)
+        cached_kvs = () if use_cache else None
+        for idx, layer in enumerate(self.layers):
+            layer_past = past_key_values[idx] if past_key_values is not None else None
+            h, layer_cached = layer(h, mask=mask, past_key_values=layer_past, use_cache=use_cache)
+            if use_cache:
+                cached_kvs += (layer_cached,)
+        return self.de_embedding_proj(self.output_norm(h)).float(), cached_kvs
+
+
+# ── HuggingFace Config ───────────────────────────────────────────────────────
+
+class PicoDecoderHFConfig(PretrainedConfig):
+    model_type = "pico_decoder"
+
+    def __init__(
+        self,
+        vocab_size=32000,
+        d_model=256,
+        n_layers=6,
+        attention_n_heads=8,
+        attention_n_kv_heads=4,
+        activation_hidden_dim=1024,
+        max_seq_len=2048,
+        norm_eps=1e-6,
+        position_emb_theta=10000.0,
+        batch_size=1,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.vocab_size = vocab_size
+        self.d_model = d_model
+        self.n_layers = n_layers
+        self.attention_n_heads = attention_n_heads
+        self.attention_n_kv_heads = attention_n_kv_heads
+        self.activation_hidden_dim = activation_hidden_dim
+        self.max_seq_len = max_seq_len
+        self.norm_eps = norm_eps
+        self.position_emb_theta = position_emb_theta
+        self.batch_size = batch_size
+
+
+# ── HuggingFace Model ────────────────────────────────────────────────────────
+
+class PicoDecoderHF(PreTrainedModel):
+    """
+    HuggingFace wrapper for BeetleLM PicoDecoder.
+    Usage: AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)
+    """
+    config_class = PicoDecoderHFConfig
+    _no_split_modules = ["PicoDecoderBlock", "Attention", "SwiGLU", "RMSNorm"]
+
+    def __init__(self, config: PicoDecoderHFConfig):
+        super().__init__(config)
+        self.pico_decoder = PicoDecoder(config)
+
+    def get_input_embeddings(self):
+        return self.pico_decoder.embedding_proj
+
+    def set_input_embeddings(self, value):
+        self.pico_decoder.embedding_proj = value
+
+    def forward(self, input_ids=None, past_key_values=None, use_cache=False, labels=None, **kwargs):
+        input_ids = input_ids.clamp(0, self.config.vocab_size - 1)
+        logits, new_past = self.pico_decoder(input_ids, past_key_values, use_cache)
+        loss = None
+        if labels is not None:
+            shift_logits = logits[:, :-1].contiguous()
+            shift_labels = labels[:, 1:].contiguous().clamp(0, self.config.vocab_size - 1)
+            loss = F.cross_entropy(
+                shift_logits.view(-1, self.config.vocab_size),
+                shift_labels.view(-1),
+            )
+        if use_cache:
+            return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=new_past)
+        return CausalLMOutput(loss=loss, logits=logits)
+
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
+        return {"input_ids": input_ids, "past_key_values": past_key_values, "use_cache": True}
+
+
+# Auto-class registration (runs on trust_remote_code import)
+PicoDecoderHFConfig.register_for_auto_class()
+PicoDecoderHF.register_for_auto_class("AutoModel")
+PicoDecoderHF.register_for_auto_class("AutoModelForCausalLM")
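
For a quick local sanity check of the added module, a minimal sketch (not part of the commit): it assumes pico_decoder.py is importable from the working directory, uses the class defaults rather than this repo's config.json values, and needs a PyTorch build whose scaled_dot_product_attention accepts enable_gqa (2.5 or newer).

import torch
from pico_decoder import PicoDecoderHF, PicoDecoderHFConfig

config = PicoDecoderHFConfig()        # defaults: d_model=256, n_layers=6, 8 heads / 4 KV heads
model = PicoDecoderHF(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 16))
with torch.no_grad():
    out = model(input_ids=input_ids, use_cache=True)

print(out.logits.shape)               # torch.Size([1, 16, 32000])
print(len(out.past_key_values))       # 6, one (k, v) pair per layer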