suchirsalhan commited on
Commit
878a2a7
Β·
verified Β·
1 Parent(s): dc52d13

Fix: vocab_size=32000 (BPE base from model.vocab); top-level weights; all compat fixes

Browse files
Files changed (2) hide show
  1. config.json +1 -1
  2. pico_decoder.py +41 -152
config.json CHANGED
@@ -19,5 +19,5 @@
19
  "dropout": 0.1,
20
  "torch_dtype": "float32",
21
  "transformers_version": "4.48.3",
22
- "vocab_size": 32768
23
  }
 
19
  "dropout": 0.1,
20
  "torch_dtype": "float32",
21
  "transformers_version": "4.48.3",
22
+ "vocab_size": 32000
23
  }
pico_decoder.py CHANGED
@@ -1,17 +1,11 @@
1
- # pico_decoder.py β€” BeetleLM HuggingFace wrapper
2
- # Source: pico-lm/pico-train (Apache 2.0)
3
- # Load: AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)
4
-
5
  from dataclasses import asdict
6
  from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
7
-
8
  import torch
9
  import torch.nn as nn
10
  import torch.nn.functional as F
11
  from torch.nn.attention import SDPBackend, sdpa_kernel
12
  from transformers import PretrainedConfig, PreTrainedModel
13
  from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast
14
-
15
  try:
16
  if TYPE_CHECKING:
17
  from src.config import ModelConfig
@@ -19,62 +13,44 @@ except ImportError:
19
  pass
20
 
21
 
22
- # ── RMSNorm ──────────────────────────────────────────────────────────────────
23
-
24
  class RMSNorm(torch.nn.Module):
25
  def __init__(self, config):
26
  super().__init__()
27
  self.eps = config.norm_eps
28
  self.weight = nn.Parameter(torch.ones(config.d_model))
29
-
30
  def _norm(self, x):
31
  return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
32
-
33
  def forward(self, x):
34
  return self._norm(x.float()).type_as(x) * self.weight
35
 
36
 
37
- # ── RoPE ─────────────────────────────────────────────────────────────────────
38
-
39
  class RoPE(nn.Module):
40
  _freqs_cis_tensor = None
41
-
42
  def __init__(self, config):
43
  super().__init__()
44
  self.theta = config.position_emb_theta
45
  self.dim = config.d_model // config.attention_n_heads
46
  if RoPE._freqs_cis_tensor is None:
47
  RoPE._freqs_cis_tensor = self._setup_freqs_cis(
48
- config.max_seq_len, self.theta, self.dim
49
- )
50
  self.register_buffer("_freqs_cis", RoPE._freqs_cis_tensor, persistent=False)
51
-
52
  @classmethod
53
  def _setup_freqs_cis(cls, seq_len, theta, dim):
54
  _freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
55
  freqs = torch.outer(torch.arange(seq_len), _freqs)
56
  return torch.polar(torch.ones_like(freqs), freqs)
57
-
58
  def get_freqs_cis(self, input_shape, start_pos, end_pos):
59
- _f = self._freqs_cis[start_pos:end_pos]
60
  ndim = len(input_shape)
61
- assert 0 <= 1 < ndim
62
- assert _f.shape == (input_shape[1], input_shape[-1])
63
- shape = [d if i == 1 or i == ndim - 1 else 1
64
- for i, d in enumerate(input_shape)]
65
- return _f.view(*shape)
66
-
67
  def forward(self, queries, keys, start_pos=0):
68
  q_ = torch.view_as_complex(queries.float().reshape(*queries.shape[:-1], -1, 2))
69
  k_ = torch.view_as_complex(keys.float().reshape(*keys.shape[:-1], -1, 2))
70
- input_shape = q_.shape
71
- fc = self.get_freqs_cis(input_shape, start_pos, start_pos + q_.shape[1])
72
- q_rot = torch.view_as_real(q_ * fc).flatten(3)
73
- k_rot = torch.view_as_real(k_ * fc).flatten(3)
74
- return q_rot.type_as(queries), k_rot.type_as(keys)
75
-
76
 
77
- # ── Attention ────────────────────────────────────────────────────────────────
78
 
79
  class Attention(nn.Module):
80
  def __init__(self, config):
@@ -83,72 +59,54 @@ class Attention(nn.Module):
83
  self.n_kv_heads = config.attention_n_kv_heads
84
  self.batch_size = config.batch_size
85
  self.max_seq_len = config.max_seq_len
86
- d = config.d_model
87
  self.head_dim = d // self.n_heads
88
  self.n_rep = self.n_heads // self.n_kv_heads
89
- self.q_proj = nn.Linear(d, self.n_heads * self.head_dim, bias=False)
90
- self.k_proj = nn.Linear(d, self.n_kv_heads * self.head_dim, bias=False)
91
- self.v_proj = nn.Linear(d, self.n_kv_heads * self.head_dim, bias=False)
92
- self.o_proj = nn.Linear(self.n_heads * self.head_dim, d, bias=False)
93
- self.rope = RoPE(config)
94
-
95
  def forward(self, input, mask=None, past_key_values=None, use_cache=False):
96
  bsz, seq_len, _ = input.shape
97
  queries = self.q_proj(input).view(bsz, seq_len, self.n_heads, self.head_dim)
98
  keys = self.k_proj(input).view(bsz, seq_len, self.n_kv_heads, self.head_dim)
99
  values = self.v_proj(input).view(bsz, seq_len, self.n_kv_heads, self.head_dim)
100
-
101
  start_pos = past_key_values[0].shape[1] if past_key_values is not None else 0
102
  queries, keys = self.rope(queries, keys, start_pos)
103
-
104
  if past_key_values is not None:
105
  keys = torch.cat([past_key_values[0], keys], dim=1)
106
  values = torch.cat([past_key_values[1], values], dim=1)
107
-
108
  cached_keys = keys if use_cache else None
109
  cached_values = values if use_cache else None
110
-
111
  queries = queries.transpose(1, 2)
112
  keys = keys.transpose(1, 2)
113
  values = values.transpose(1, 2)
114
-
115
  apply_gqa = self.n_rep > 1
116
  if apply_gqa and queries.device.type == "mps":
117
  keys = keys.repeat_interleave(self.n_rep, dim=-3)
118
  values = values.repeat_interleave(self.n_rep, dim=-3)
119
  apply_gqa = False
120
-
121
- # FIX: guard mask against None (happens during generation when seq_len==1)
122
  attn_mask = mask.to(queries.dtype) if mask is not None else None
123
-
124
  with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH]):
125
  attn_output = F.scaled_dot_product_attention(
126
- queries.contiguous(),
127
- keys.contiguous(),
128
- values.contiguous(),
129
- attn_mask=attn_mask,
130
- enable_gqa=apply_gqa,
131
  )
132
-
133
  attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
134
  return self.o_proj(attn_output), (cached_keys, cached_values)
135
 
136
 
137
- # ── SwiGLU ───────────────────────────────────────────────────────────────────
138
-
139
  class SwiGLU(nn.Module):
140
  def __init__(self, config):
141
  super().__init__()
142
  self.w_0 = nn.Linear(config.d_model, config.activation_hidden_dim, bias=False)
143
  self.w_1 = nn.Linear(config.d_model, config.activation_hidden_dim, bias=False)
144
  self.w_2 = nn.Linear(config.activation_hidden_dim, config.d_model, bias=False)
145
-
146
  def forward(self, x):
147
  return self.w_2(F.silu(self.w_0(x)) * self.w_1(x))
148
 
149
 
150
- # ── PicoDecoderBlock ─────────────────────────────────────────────────────────
151
-
152
  class PicoDecoderBlock(nn.Module):
153
  def __init__(self, config):
154
  super().__init__()
@@ -156,20 +114,13 @@ class PicoDecoderBlock(nn.Module):
156
  self.swiglu = SwiGLU(config)
157
  self.attention_norm = RMSNorm(config)
158
  self.swiglu_norm = RMSNorm(config)
159
-
160
  def forward(self, input, mask=None, past_key_values=None, use_cache=False):
161
  attention_output, cached_key_values = self.attention(
162
- self.attention_norm(input),
163
- mask=mask,
164
- past_key_values=past_key_values,
165
- use_cache=use_cache,
166
- )
167
- h = input + attention_output
168
- out = h + self.swiglu(self.swiglu_norm(h))
169
- return out, cached_key_values
170
-
171
 
172
- # ── PicoDecoder (standalone, used during training) ───────────────────────────
173
 
174
  class PicoDecoder(nn.Module):
175
  def __init__(self, model_config):
@@ -181,18 +132,14 @@ class PicoDecoder(nn.Module):
181
  self.output_norm = RMSNorm(model_config)
182
  self.de_embedding_proj = nn.Linear(
183
  model_config.d_model, model_config.vocab_size, bias=False)
184
-
185
  def convert_to_hf_model(self):
186
- hf_config = PicoDecoderHFConfig.from_dataclass(self.config)
187
- hf_model = PicoDecoderHF(hf_config)
188
- hf_model.load_state_dict(self.state_dict())
189
- return hf_model
190
-
191
  def forward(self, input_ids, past_key_values=None, use_cache=False):
192
  seq_len = input_ids.shape[-1]
193
  h = self.embedding_proj(input_ids)
194
  start_pos = 0 if past_key_values is None else past_key_values[0][0].shape[1]
195
-
196
  mask = None
197
  if seq_len > 1:
198
  mask = torch.full((seq_len, seq_len), float("-inf"))
@@ -200,7 +147,6 @@ class PicoDecoder(nn.Module):
200
  if past_key_values is not None:
201
  mask = torch.hstack([torch.zeros((seq_len, start_pos)), mask])
202
  mask = mask.to(h.device)
203
-
204
  cached_key_values = () if use_cache else None
205
  for idx, layer in enumerate(self.layers):
206
  layer_past = past_key_values[idx] if past_key_values is not None else None
@@ -208,43 +154,19 @@ class PicoDecoder(nn.Module):
208
  h, mask=mask, past_key_values=layer_past, use_cache=use_cache)
209
  if use_cache:
210
  cached_key_values += (layer_cached,)
 
211
 
212
- h = self.output_norm(h)
213
- logits = self.de_embedding_proj(h).float()
214
- return logits, cached_key_values
215
-
216
-
217
- # ── PicoDecoderHFConfig ───────────────────────────────────────────────────────
218
 
219
  class PicoDecoderHFConfig(PretrainedConfig):
220
- """
221
- HuggingFace config for BeetleLM PicoDecoder.
222
-
223
- Defaults match generate_configs.py MODEL_BASE. vocab_size is overridden
224
- per-repo in config.json because the trainer sets it from the tokenizer.
225
- """
226
  model_type = "pico_decoder"
227
-
228
- def __init__(
229
- self,
230
- # Architecture β€” defaults from generate_configs.py MODEL_BASE
231
- n_layers = 14,
232
- d_model = 768,
233
- vocab_size = 32768, # overridden per-repo in config.json
234
- attention_n_heads = 12,
235
- attention_n_kv_heads = 1, # MQA
236
- max_seq_len = 512,
237
- batch_size = 64,
238
- position_emb_theta = 10000.0,
239
- activation_hidden_dim = 3072,
240
- norm_eps = 1e-5,
241
- dropout = 0.1,
242
- **kwargs,
243
- ):
244
- # FIX: guard against None/0/missing attention_n_kv_heads in old config.json
245
  if not attention_n_kv_heads:
246
  attention_n_kv_heads = attention_n_heads
247
-
248
  super().__init__(**kwargs)
249
  self.n_layers = n_layers
250
  self.d_model = d_model
@@ -257,72 +179,48 @@ class PicoDecoderHFConfig(PretrainedConfig):
257
  self.activation_hidden_dim = activation_hidden_dim
258
  self.norm_eps = norm_eps
259
  self.dropout = dropout
260
-
261
  @classmethod
262
  def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PicoDecoderHFConfig":
263
  pico_config = cls(**config_dict)
264
  return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
265
- unused_kwargs = {
266
- k: v for k, v in kwargs.items() if not hasattr(pico_config, k)
267
- }
268
  if return_unused_kwargs:
269
  return pico_config, unused_kwargs
270
  return pico_config
271
-
272
  @classmethod
273
  def from_dataclass(cls, model_config):
274
  return cls.from_dict(asdict(model_config))
275
 
276
 
277
- # ── PicoDecoderHF ─────────────────────────────────────────────────────────────
278
-
279
  class PicoDecoderHF(PreTrainedModel):
280
  """
281
  HuggingFace wrapper for BeetleLM PicoDecoder.
282
-
283
- IMPORTANT β€” weights live at the TOP LEVEL (not under self.pico_decoder)
284
- because the trainer saves raw PicoDecoder state dicts:
285
- checkpoint keys: embedding_proj.weight, layers.0.attention.q_proj.weight ...
286
- A self.pico_decoder wrapper would expect pico_decoder.embedding_proj.weight
287
- which does not exist in any saved checkpoint.
288
  """
289
  config_class = PicoDecoderHFConfig
290
  _no_split_modules = ["PicoDecoderBlock", "Attention", "SwiGLU", "RMSNorm"]
291
  _tied_weights_keys = []
292
 
293
- # FIX: transformers >= 4.38 calls .keys() on this β€” must return a dict
294
  @property
295
  def all_tied_weights_keys(self):
296
  return {}
297
 
298
  def __init__(self, config: PicoDecoderHFConfig):
299
  super().__init__(config)
300
- # FIX: top-level storage β€” matches raw PicoDecoder checkpoint keys
301
  self.embedding_proj = nn.Embedding(config.vocab_size, config.d_model)
302
  self.layers = nn.ModuleList(
303
  [PicoDecoderBlock(config) for _ in range(config.n_layers)])
304
  self.output_norm = RMSNorm(config)
305
- self.de_embedding_proj = nn.Linear(
306
- config.d_model, config.vocab_size, bias=False)
307
 
308
- def get_input_embeddings(self):
309
- return self.embedding_proj
310
 
311
- def set_input_embeddings(self, value):
312
- self.embedding_proj = value
313
-
314
- def forward(
315
- self,
316
- input_ids = None,
317
- past_key_values = None,
318
- use_cache = False,
319
- labels = None,
320
- **kwargs,
321
- ):
322
  seq_len = input_ids.shape[-1]
323
  h = self.embedding_proj(input_ids)
324
  start_pos = 0 if past_key_values is None else past_key_values[0][0].shape[1]
325
-
326
  mask = None
327
  if seq_len > 1:
328
  mask = torch.full((seq_len, seq_len), float("-inf"))
@@ -330,7 +228,6 @@ class PicoDecoderHF(PreTrainedModel):
330
  if past_key_values is not None:
331
  mask = torch.hstack([torch.zeros((seq_len, start_pos)), mask])
332
  mask = mask.to(h.device)
333
-
334
  cached_key_values = () if use_cache else None
335
  for idx, layer in enumerate(self.layers):
336
  layer_past = past_key_values[idx] if past_key_values is not None else None
@@ -338,32 +235,24 @@ class PicoDecoderHF(PreTrainedModel):
338
  h, mask=mask, past_key_values=layer_past, use_cache=use_cache)
339
  if use_cache:
340
  cached_key_values += (layer_cached,)
341
-
342
  logits = self.de_embedding_proj(self.output_norm(h)).float()
343
-
344
  loss = None
345
  if labels is not None:
346
- shift_logits = logits[:, :-1].contiguous()
347
- shift_labels = labels[:, 1:].contiguous().clamp(0, self.config.vocab_size - 1)
348
  loss = F.cross_entropy(
349
- shift_logits.view(-1, self.config.vocab_size),
350
- shift_labels.view(-1),
351
  )
352
-
353
  if use_cache:
354
  return CausalLMOutputWithPast(
355
  loss=loss, logits=logits, past_key_values=cached_key_values)
356
  return CausalLMOutput(loss=loss, logits=logits)
357
 
358
  def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
359
- return {
360
- "input_ids": input_ids,
361
- "past_key_values": past_key_values,
362
- "use_cache": True,
363
- }
364
 
365
 
366
- # Auto-class registration
367
  PicoDecoderHFConfig.register_for_auto_class()
368
  PicoDecoderHF.register_for_auto_class("AutoModel")
369
  PicoDecoderHF.register_for_auto_class("AutoModelForCausalLM")
 
 
 
 
 
1
  from dataclasses import asdict
2
  from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
 
3
  import torch
4
  import torch.nn as nn
5
  import torch.nn.functional as F
6
  from torch.nn.attention import SDPBackend, sdpa_kernel
7
  from transformers import PretrainedConfig, PreTrainedModel
8
  from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast
 
9
  try:
10
  if TYPE_CHECKING:
11
  from src.config import ModelConfig
 
13
  pass
14
 
15
 
 
 
16
  class RMSNorm(torch.nn.Module):
17
  def __init__(self, config):
18
  super().__init__()
19
  self.eps = config.norm_eps
20
  self.weight = nn.Parameter(torch.ones(config.d_model))
 
21
  def _norm(self, x):
22
  return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
 
23
  def forward(self, x):
24
  return self._norm(x.float()).type_as(x) * self.weight
25
 
26
 
 
 
27
  class RoPE(nn.Module):
28
  _freqs_cis_tensor = None
 
29
  def __init__(self, config):
30
  super().__init__()
31
  self.theta = config.position_emb_theta
32
  self.dim = config.d_model // config.attention_n_heads
33
  if RoPE._freqs_cis_tensor is None:
34
  RoPE._freqs_cis_tensor = self._setup_freqs_cis(
35
+ config.max_seq_len, self.theta, self.dim)
 
36
  self.register_buffer("_freqs_cis", RoPE._freqs_cis_tensor, persistent=False)
 
37
  @classmethod
38
  def _setup_freqs_cis(cls, seq_len, theta, dim):
39
  _freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
40
  freqs = torch.outer(torch.arange(seq_len), _freqs)
41
  return torch.polar(torch.ones_like(freqs), freqs)
 
42
  def get_freqs_cis(self, input_shape, start_pos, end_pos):
43
+ _f = self._freqs_cis[start_pos:end_pos]
44
  ndim = len(input_shape)
45
+ assert 0 <= 1 < ndim and _f.shape == (input_shape[1], input_shape[-1])
46
+ return _f.view(*[d if i==1 or i==ndim-1 else 1 for i,d in enumerate(input_shape)])
 
 
 
 
47
  def forward(self, queries, keys, start_pos=0):
48
  q_ = torch.view_as_complex(queries.float().reshape(*queries.shape[:-1], -1, 2))
49
  k_ = torch.view_as_complex(keys.float().reshape(*keys.shape[:-1], -1, 2))
50
+ fc = self.get_freqs_cis(q_.shape, start_pos, start_pos + q_.shape[1])
51
+ return (torch.view_as_real(q_ * fc).flatten(3).type_as(queries),
52
+ torch.view_as_real(k_ * fc).flatten(3).type_as(keys))
 
 
 
53
 
 
54
 
55
  class Attention(nn.Module):
56
  def __init__(self, config):
 
59
  self.n_kv_heads = config.attention_n_kv_heads
60
  self.batch_size = config.batch_size
61
  self.max_seq_len = config.max_seq_len
62
+ d = config.d_model
63
  self.head_dim = d // self.n_heads
64
  self.n_rep = self.n_heads // self.n_kv_heads
65
+ self.q_proj = nn.Linear(d, self.n_heads * self.head_dim, bias=False)
66
+ self.k_proj = nn.Linear(d, self.n_kv_heads * self.head_dim, bias=False)
67
+ self.v_proj = nn.Linear(d, self.n_kv_heads * self.head_dim, bias=False)
68
+ self.o_proj = nn.Linear(self.n_heads * self.head_dim, d, bias=False)
69
+ self.rope = RoPE(config)
 
70
  def forward(self, input, mask=None, past_key_values=None, use_cache=False):
71
  bsz, seq_len, _ = input.shape
72
  queries = self.q_proj(input).view(bsz, seq_len, self.n_heads, self.head_dim)
73
  keys = self.k_proj(input).view(bsz, seq_len, self.n_kv_heads, self.head_dim)
74
  values = self.v_proj(input).view(bsz, seq_len, self.n_kv_heads, self.head_dim)
 
75
  start_pos = past_key_values[0].shape[1] if past_key_values is not None else 0
76
  queries, keys = self.rope(queries, keys, start_pos)
 
77
  if past_key_values is not None:
78
  keys = torch.cat([past_key_values[0], keys], dim=1)
79
  values = torch.cat([past_key_values[1], values], dim=1)
 
80
  cached_keys = keys if use_cache else None
81
  cached_values = values if use_cache else None
 
82
  queries = queries.transpose(1, 2)
83
  keys = keys.transpose(1, 2)
84
  values = values.transpose(1, 2)
 
85
  apply_gqa = self.n_rep > 1
86
  if apply_gqa and queries.device.type == "mps":
87
  keys = keys.repeat_interleave(self.n_rep, dim=-3)
88
  values = values.repeat_interleave(self.n_rep, dim=-3)
89
  apply_gqa = False
 
 
90
  attn_mask = mask.to(queries.dtype) if mask is not None else None
 
91
  with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH]):
92
  attn_output = F.scaled_dot_product_attention(
93
+ queries.contiguous(), keys.contiguous(), values.contiguous(),
94
+ attn_mask=attn_mask, enable_gqa=apply_gqa,
 
 
 
95
  )
 
96
  attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
97
  return self.o_proj(attn_output), (cached_keys, cached_values)
98
 
99
 
 
 
100
  class SwiGLU(nn.Module):
101
  def __init__(self, config):
102
  super().__init__()
103
  self.w_0 = nn.Linear(config.d_model, config.activation_hidden_dim, bias=False)
104
  self.w_1 = nn.Linear(config.d_model, config.activation_hidden_dim, bias=False)
105
  self.w_2 = nn.Linear(config.activation_hidden_dim, config.d_model, bias=False)
 
106
  def forward(self, x):
107
  return self.w_2(F.silu(self.w_0(x)) * self.w_1(x))
108
 
109
 
 
 
110
  class PicoDecoderBlock(nn.Module):
111
  def __init__(self, config):
112
  super().__init__()
 
114
  self.swiglu = SwiGLU(config)
115
  self.attention_norm = RMSNorm(config)
116
  self.swiglu_norm = RMSNorm(config)
 
117
  def forward(self, input, mask=None, past_key_values=None, use_cache=False):
118
  attention_output, cached_key_values = self.attention(
119
+ self.attention_norm(input), mask=mask,
120
+ past_key_values=past_key_values, use_cache=use_cache)
121
+ h = input + attention_output
122
+ return h + self.swiglu(self.swiglu_norm(h)), cached_key_values
 
 
 
 
 
123
 
 
124
 
125
  class PicoDecoder(nn.Module):
126
  def __init__(self, model_config):
 
132
  self.output_norm = RMSNorm(model_config)
133
  self.de_embedding_proj = nn.Linear(
134
  model_config.d_model, model_config.vocab_size, bias=False)
 
135
  def convert_to_hf_model(self):
136
+ hf = PicoDecoderHF(PicoDecoderHFConfig.from_dataclass(self.config))
137
+ hf.load_state_dict(self.state_dict())
138
+ return hf
 
 
139
  def forward(self, input_ids, past_key_values=None, use_cache=False):
140
  seq_len = input_ids.shape[-1]
141
  h = self.embedding_proj(input_ids)
142
  start_pos = 0 if past_key_values is None else past_key_values[0][0].shape[1]
 
143
  mask = None
144
  if seq_len > 1:
145
  mask = torch.full((seq_len, seq_len), float("-inf"))
 
147
  if past_key_values is not None:
148
  mask = torch.hstack([torch.zeros((seq_len, start_pos)), mask])
149
  mask = mask.to(h.device)
 
150
  cached_key_values = () if use_cache else None
151
  for idx, layer in enumerate(self.layers):
152
  layer_past = past_key_values[idx] if past_key_values is not None else None
 
154
  h, mask=mask, past_key_values=layer_past, use_cache=use_cache)
155
  if use_cache:
156
  cached_key_values += (layer_cached,)
157
+ return self.de_embedding_proj(self.output_norm(h)).float(), cached_key_values
158
 
 
 
 
 
 
 
159
 
160
  class PicoDecoderHFConfig(PretrainedConfig):
 
 
 
 
 
 
161
  model_type = "pico_decoder"
162
+ def __init__(self,
163
+ n_layers=14, d_model=768, vocab_size=32768,
164
+ attention_n_heads=12, attention_n_kv_heads=1,
165
+ max_seq_len=512, batch_size=64, position_emb_theta=10000.0,
166
+ activation_hidden_dim=3072, norm_eps=1e-5, dropout=0.1,
167
+ **kwargs):
 
 
 
 
 
 
 
 
 
 
 
 
168
  if not attention_n_kv_heads:
169
  attention_n_kv_heads = attention_n_heads
 
170
  super().__init__(**kwargs)
171
  self.n_layers = n_layers
172
  self.d_model = d_model
 
179
  self.activation_hidden_dim = activation_hidden_dim
180
  self.norm_eps = norm_eps
181
  self.dropout = dropout
 
182
  @classmethod
183
  def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PicoDecoderHFConfig":
184
  pico_config = cls(**config_dict)
185
  return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
186
+ unused_kwargs = {k: v for k, v in kwargs.items() if not hasattr(pico_config, k)}
 
 
187
  if return_unused_kwargs:
188
  return pico_config, unused_kwargs
189
  return pico_config
 
190
  @classmethod
191
  def from_dataclass(cls, model_config):
192
  return cls.from_dict(asdict(model_config))
193
 
194
 
 
 
195
  class PicoDecoderHF(PreTrainedModel):
196
  """
197
  HuggingFace wrapper for BeetleLM PicoDecoder.
198
+ Usage: AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)
 
 
 
 
 
199
  """
200
  config_class = PicoDecoderHFConfig
201
  _no_split_modules = ["PicoDecoderBlock", "Attention", "SwiGLU", "RMSNorm"]
202
  _tied_weights_keys = []
203
 
 
204
  @property
205
  def all_tied_weights_keys(self):
206
  return {}
207
 
208
  def __init__(self, config: PicoDecoderHFConfig):
209
  super().__init__(config)
 
210
  self.embedding_proj = nn.Embedding(config.vocab_size, config.d_model)
211
  self.layers = nn.ModuleList(
212
  [PicoDecoderBlock(config) for _ in range(config.n_layers)])
213
  self.output_norm = RMSNorm(config)
214
+ self.de_embedding_proj = nn.Linear(config.d_model, config.vocab_size, bias=False)
 
215
 
216
+ def get_input_embeddings(self): return self.embedding_proj
217
+ def set_input_embeddings(self, value): self.embedding_proj = value
218
 
219
+ def forward(self, input_ids=None, past_key_values=None,
220
+ use_cache=False, labels=None, **kwargs):
 
 
 
 
 
 
 
 
 
221
  seq_len = input_ids.shape[-1]
222
  h = self.embedding_proj(input_ids)
223
  start_pos = 0 if past_key_values is None else past_key_values[0][0].shape[1]
 
224
  mask = None
225
  if seq_len > 1:
226
  mask = torch.full((seq_len, seq_len), float("-inf"))
 
228
  if past_key_values is not None:
229
  mask = torch.hstack([torch.zeros((seq_len, start_pos)), mask])
230
  mask = mask.to(h.device)
 
231
  cached_key_values = () if use_cache else None
232
  for idx, layer in enumerate(self.layers):
233
  layer_past = past_key_values[idx] if past_key_values is not None else None
 
235
  h, mask=mask, past_key_values=layer_past, use_cache=use_cache)
236
  if use_cache:
237
  cached_key_values += (layer_cached,)
 
238
  logits = self.de_embedding_proj(self.output_norm(h)).float()
 
239
  loss = None
240
  if labels is not None:
 
 
241
  loss = F.cross_entropy(
242
+ logits[:, :-1].contiguous().view(-1, self.config.vocab_size),
243
+ labels[:, 1:].contiguous().clamp(0, self.config.vocab_size - 1).view(-1),
244
  )
 
245
  if use_cache:
246
  return CausalLMOutputWithPast(
247
  loss=loss, logits=logits, past_key_values=cached_key_values)
248
  return CausalLMOutput(loss=loss, logits=logits)
249
 
250
  def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
251
+ return {"input_ids": input_ids,
252
+ "past_key_values": past_key_values,
253
+ "use_cache": True}
 
 
254
 
255
 
 
256
  PicoDecoderHFConfig.register_for_auto_class()
257
  PicoDecoderHF.register_for_auto_class("AutoModel")
258
  PicoDecoderHF.register_for_auto_class("AutoModelForCausalLM")