suchirsalhan committed on
Commit cc7e2c0 · verified · 1 Parent(s): 97ebc3c

Fix: vocab_size=32768 from tokenizer; top-level weights; ZeroDivisionError guard; all_tied_weights_keys->dict

Files changed (2)
  1. config.json +2 -2
  2. pico_decoder.py +220 -81
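
The fixes are easiest to verify by reloading the repo with trust_remote_code=True and checking that the config now carries the tokenizer's vocabulary size and that no weights are reported as tied. A minimal sketch, assuming a placeholder repo id (not part of this commit):

# Hypothetical sanity check for this commit; the repo id is a placeholder.
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

repo = "suchirsalhan/pico-decoder"  # placeholder id

tokenizer = AutoTokenizer.from_pretrained(repo)
config = AutoConfig.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)

# vocab_size is now taken from the tokenizer (32768, per the commit message).
assert config.vocab_size == len(tokenizer) == 32768

# all_tied_weights_keys is now a dict, so callers that do .keys() on it work
# and no weights are reported as tied.
assert model.all_tied_weights_keys == {}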
config.json CHANGED
@@ -9,7 +9,6 @@
   },
   "n_layers": 14,
   "d_model": 768,
-  "vocab_size": 32768,
   "attention_n_heads": 12,
   "attention_n_kv_heads": 1,
   "max_seq_len": 512,
@@ -19,5 +18,6 @@
   "norm_eps": 1e-05,
   "dropout": 0.1,
   "torch_dtype": "float32",
-  "transformers_version": "4.48.3"
+  "transformers_version": "4.48.3",
+  "vocab_size": 32768
 }
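
config.json keeps attention_n_kv_heads at 1 (multi-query attention); the ZeroDivisionError guard from the commit message protects against older configs where that key is missing, None, or 0, since the model computes n_rep = n_heads // n_kv_heads. A standalone sketch of the guard's behaviour, for illustration only (not the actual config class):

# Standalone illustration of the kv-heads guard; not the actual config class.
def resolve_kv_heads(attention_n_heads, attention_n_kv_heads):
    # A falsy value (None, 0, or a missing key passed through as None)
    # falls back to attention_n_heads, i.e. plain multi-head attention.
    if not attention_n_kv_heads:
        return attention_n_heads
    return attention_n_kv_heads

n_heads = 12
for raw in (1, 0, None):
    n_kv_heads = resolve_kv_heads(n_heads, raw)
    n_rep = n_heads // n_kv_heads  # would raise ZeroDivisionError without the guard
    print(f"{raw!r:>4} -> n_kv_heads={n_kv_heads}, n_rep={n_rep}")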
pico_decoder.py CHANGED
@@ -1,11 +1,17 @@
+# pico_decoder.py — BeetleLM HuggingFace wrapper
+# Source: pico-lm/pico-train (Apache 2.0)
+# Load: AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)
+
 from dataclasses import asdict
 from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.attention import SDPBackend, sdpa_kernel
 from transformers import PretrainedConfig, PreTrainedModel
 from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast
+
 try:
     if TYPE_CHECKING:
         from src.config import ModelConfig
@@ -13,43 +19,63 @@ except ImportError:
     pass
 
 
+# ── RMSNorm ──────────────────────────────────────────────────────────────────
+
 class RMSNorm(torch.nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.eps = config.norm_eps
+        self.eps = config.norm_eps
         self.weight = nn.Parameter(torch.ones(config.d_model))
+
     def _norm(self, x):
         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
     def forward(self, x):
         return self._norm(x.float()).type_as(x) * self.weight
 
 
+# ── RoPE ─────────────────────────────────────────────────────────────────────
+
 class RoPE(nn.Module):
     _freqs_cis_tensor = None
+
     def __init__(self, config):
         super().__init__()
         self.theta = config.position_emb_theta
         self.dim = config.d_model // config.attention_n_heads
         if RoPE._freqs_cis_tensor is None:
-            RoPE._freqs_cis_tensor = self._setup_freqs_cis(config.max_seq_len, self.theta, self.dim)
+            RoPE._freqs_cis_tensor = self._setup_freqs_cis(
+                config.max_seq_len, self.theta, self.dim
+            )
         self.register_buffer("_freqs_cis", RoPE._freqs_cis_tensor, persistent=False)
+
     @classmethod
     def _setup_freqs_cis(cls, seq_len, theta, dim):
-        _freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
-        return torch.polar(torch.ones_like(f := torch.outer(torch.arange(seq_len), _freqs)), f)
+        _freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
+        freqs = torch.outer(torch.arange(seq_len), _freqs)
+        return torch.polar(torch.ones_like(freqs), freqs)
+
     def get_freqs_cis(self, input_shape, start_pos, end_pos):
-        _f = self._freqs_cis[start_pos:end_pos]
+        _f = self._freqs_cis[start_pos:end_pos]
         ndim = len(input_shape)
-        assert 0 <= 1 < ndim and _f.shape == (input_shape[1], input_shape[-1])
-        return _f.view(*[d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(input_shape)])
+        assert 0 <= 1 < ndim
+        assert _f.shape == (input_shape[1], input_shape[-1])
+        shape = [d if i == 1 or i == ndim - 1 else 1
+                 for i, d in enumerate(input_shape)]
+        return _f.view(*shape)
+
     def forward(self, queries, keys, start_pos=0):
         q_ = torch.view_as_complex(queries.float().reshape(*queries.shape[:-1], -1, 2))
         k_ = torch.view_as_complex(keys.float().reshape(*keys.shape[:-1], -1, 2))
-        fc = self.get_freqs_cis(q_.shape, start_pos, start_pos + q_.shape[1])
-        return (torch.view_as_real(q_ * fc).flatten(3).type_as(queries),
-                torch.view_as_real(k_ * fc).flatten(3).type_as(keys))
+        input_shape = q_.shape
+        fc = self.get_freqs_cis(input_shape, start_pos, start_pos + q_.shape[1])
+        q_rot = torch.view_as_real(q_ * fc).flatten(3)
+        k_rot = torch.view_as_real(k_ * fc).flatten(3)
+        return q_rot.type_as(queries), k_rot.type_as(keys)
 
 
+# ── Attention ────────────────────────────────────────────────────────────────
+
 class Attention(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -57,39 +83,58 @@ class Attention(nn.Module):
         self.n_kv_heads = config.attention_n_kv_heads
         self.batch_size = config.batch_size
         self.max_seq_len = config.max_seq_len
-        d = config.d_model
+        d = config.d_model
         self.head_dim = d // self.n_heads
         self.n_rep = self.n_heads // self.n_kv_heads
-        self.q_proj = nn.Linear(d, self.n_heads * self.head_dim, bias=False)
-        self.k_proj = nn.Linear(d, self.n_kv_heads * self.head_dim, bias=False)
-        self.v_proj = nn.Linear(d, self.n_kv_heads * self.head_dim, bias=False)
-        self.o_proj = nn.Linear(self.n_heads * self.head_dim, d, bias=False)
-        self.rope = RoPE(config)
+        self.q_proj = nn.Linear(d, self.n_heads * self.head_dim, bias=False)
+        self.k_proj = nn.Linear(d, self.n_kv_heads * self.head_dim, bias=False)
+        self.v_proj = nn.Linear(d, self.n_kv_heads * self.head_dim, bias=False)
+        self.o_proj = nn.Linear(self.n_heads * self.head_dim, d, bias=False)
+        self.rope = RoPE(config)
+
     def forward(self, input, mask=None, past_key_values=None, use_cache=False):
         bsz, seq_len, _ = input.shape
-        q = self.q_proj(input).view(bsz, seq_len, self.n_heads, self.head_dim)
-        k = self.k_proj(input).view(bsz, seq_len, self.n_kv_heads, self.head_dim)
-        v = self.v_proj(input).view(bsz, seq_len, self.n_kv_heads, self.head_dim)
-        sp = past_key_values[0].shape[1] if past_key_values is not None else 0
-        q, k = self.rope(q, k, sp)
+        queries = self.q_proj(input).view(bsz, seq_len, self.n_heads, self.head_dim)
+        keys = self.k_proj(input).view(bsz, seq_len, self.n_kv_heads, self.head_dim)
+        values = self.v_proj(input).view(bsz, seq_len, self.n_kv_heads, self.head_dim)
+
+        start_pos = past_key_values[0].shape[1] if past_key_values is not None else 0
+        queries, keys = self.rope(queries, keys, start_pos)
+
         if past_key_values is not None:
-            k = torch.cat([past_key_values[0], k], dim=1)
-            v = torch.cat([past_key_values[1], v], dim=1)
-        ck, cv = (k, v) if use_cache else (None, None)
-        q, k, v = q.transpose(1,2), k.transpose(1,2), v.transpose(1,2)
-        gqa = self.n_rep > 1
-        if gqa and q.device.type == "mps":
-            k = k.repeat_interleave(self.n_rep, dim=-3)
-            v = v.repeat_interleave(self.n_rep, dim=-3)
-            gqa = False
+            keys = torch.cat([past_key_values[0], keys], dim=1)
+            values = torch.cat([past_key_values[1], values], dim=1)
+
+        cached_keys = keys if use_cache else None
+        cached_values = values if use_cache else None
+
+        queries = queries.transpose(1, 2)
+        keys = keys.transpose(1, 2)
+        values = values.transpose(1, 2)
+
+        apply_gqa = self.n_rep > 1
+        if apply_gqa and queries.device.type == "mps":
+            keys = keys.repeat_interleave(self.n_rep, dim=-3)
+            values = values.repeat_interleave(self.n_rep, dim=-3)
+            apply_gqa = False
+
+        # FIX: guard mask against None (happens during generation when seq_len==1)
+        attn_mask = mask.to(queries.dtype) if mask is not None else None
+
         with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH]):
-            out = F.scaled_dot_product_attention(
-                q.contiguous(), k.contiguous(), v.contiguous(),
-                attn_mask=mask.to(q.dtype) if mask is not None else None,
-                enable_gqa=gqa,
+            attn_output = F.scaled_dot_product_attention(
+                queries.contiguous(),
+                keys.contiguous(),
+                values.contiguous(),
+                attn_mask=attn_mask,
+                enable_gqa=apply_gqa,
             )
-        return self.o_proj(out.transpose(1,2).contiguous().view(bsz, seq_len, -1)), (ck, cv)
 
+        attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
+        return self.o_proj(attn_output), (cached_keys, cached_values)
+
+
+# ── SwiGLU ───────────────────────────────────────────────────────────────────
 
 class SwiGLU(nn.Module):
     def __init__(self, config):
@@ -97,10 +142,13 @@ class SwiGLU(nn.Module):
         self.w_0 = nn.Linear(config.d_model, config.activation_hidden_dim, bias=False)
         self.w_1 = nn.Linear(config.d_model, config.activation_hidden_dim, bias=False)
         self.w_2 = nn.Linear(config.activation_hidden_dim, config.d_model, bias=False)
+
     def forward(self, x):
         return self.w_2(F.silu(self.w_0(x)) * self.w_1(x))
 
 
+# ── PicoDecoderBlock ─────────────────────────────────────────────────────────
+
 class PicoDecoderBlock(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -108,58 +156,95 @@ class PicoDecoderBlock(nn.Module):
         self.swiglu = SwiGLU(config)
         self.attention_norm = RMSNorm(config)
         self.swiglu_norm = RMSNorm(config)
+
     def forward(self, input, mask=None, past_key_values=None, use_cache=False):
-        a, c = self.attention(self.attention_norm(input), mask=mask,
-                              past_key_values=past_key_values, use_cache=use_cache)
-        h = input + a
-        return h + self.swiglu(self.swiglu_norm(h)), c
+        attention_output, cached_key_values = self.attention(
+            self.attention_norm(input),
+            mask=mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+        )
+        h = input + attention_output
+        out = h + self.swiglu(self.swiglu_norm(h))
+        return out, cached_key_values
 
 
+# ── PicoDecoder (standalone, used during training) ───────────────────────────
+
 class PicoDecoder(nn.Module):
     def __init__(self, model_config):
         super().__init__()
-        self.config = model_config
+        self.config = model_config
         self.embedding_proj = nn.Embedding(model_config.vocab_size, model_config.d_model)
-        self.layers = nn.ModuleList([PicoDecoderBlock(model_config) for _ in range(model_config.n_layers)])
+        self.layers = nn.ModuleList(
+            [PicoDecoderBlock(model_config) for _ in range(model_config.n_layers)])
        self.output_norm = RMSNorm(model_config)
-        self.de_embedding_proj = nn.Linear(model_config.d_model, model_config.vocab_size, bias=False)
+        self.de_embedding_proj = nn.Linear(
+            model_config.d_model, model_config.vocab_size, bias=False)
+
     def convert_to_hf_model(self):
-        hf = PicoDecoderHF(PicoDecoderHFConfig.from_dataclass(self.config))
-        hf.load_state_dict(self.state_dict(prefix="pico_decoder."))
-        return hf
+        hf_config = PicoDecoderHFConfig.from_dataclass(self.config)
+        hf_model = PicoDecoderHF(hf_config)
+        hf_model.load_state_dict(self.state_dict())
+        return hf_model
+
     def forward(self, input_ids, past_key_values=None, use_cache=False):
-        sl = input_ids.shape[-1]
-        h = self.embedding_proj(input_ids)
-        sp = 0 if past_key_values is None else past_key_values[0][0].shape[1]
+        seq_len = input_ids.shape[-1]
+        h = self.embedding_proj(input_ids)
+        start_pos = 0 if past_key_values is None else past_key_values[0][0].shape[1]
+
         mask = None
-        if sl > 1:
-            mask = torch.triu(torch.full((sl, sl), float("-inf")), diagonal=1)
+        if seq_len > 1:
+            mask = torch.full((seq_len, seq_len), float("-inf"))
+            mask = torch.triu(mask, diagonal=1)
             if past_key_values is not None:
-                mask = torch.hstack([torch.zeros((sl, sp)), mask])
+                mask = torch.hstack([torch.zeros((seq_len, start_pos)), mask])
             mask = mask.to(h.device)
-        ckv = () if use_cache else None
-        for i, layer in enumerate(self.layers):
-            lp = past_key_values[i] if past_key_values is not None else None
-            h, lc = layer(h, mask=mask, past_key_values=lp, use_cache=use_cache)
+
+        cached_key_values = () if use_cache else None
+        for idx, layer in enumerate(self.layers):
+            layer_past = past_key_values[idx] if past_key_values is not None else None
+            h, layer_cached = layer(
+                h, mask=mask, past_key_values=layer_past, use_cache=use_cache)
            if use_cache:
-                ckv += (lc,)
-        return self.de_embedding_proj(self.output_norm(h)).float(), ckv
+                cached_key_values += (layer_cached,)
+
+        h = self.output_norm(h)
+        logits = self.de_embedding_proj(h).float()
+        return logits, cached_key_values
 
 
-# ── HuggingFace wrappers ──────────────────────────────────────────────────────
+# ── PicoDecoderHFConfig ───────────────────────────────────────────────────────
 
 class PicoDecoderHFConfig(PretrainedConfig):
+    """
+    HuggingFace config for BeetleLM PicoDecoder.
+
+    Defaults match generate_configs.py MODEL_BASE. vocab_size is overridden
+    per-repo in config.json because the trainer sets it from the tokenizer.
+    """
     model_type = "pico_decoder"
 
-    # FIX 1 + 2: explicit __init__ with MODEL_BASE defaults; guards None/0 kv_heads
-    def __init__(self,
-                 n_layers=14, d_model=768, vocab_size=32768,
-                 attention_n_heads=12, attention_n_kv_heads=1,
-                 max_seq_len=512, batch_size=64, position_emb_theta=10000.0,
-                 activation_hidden_dim=3072, norm_eps=1e-5, dropout=0.1,
-                 **kwargs):
-        if not attention_n_kv_heads:  # catches None, 0, missing
+    def __init__(
+        self,
+        # Architecture — defaults from generate_configs.py MODEL_BASE
+        n_layers = 14,
+        d_model = 768,
+        vocab_size = 32768,  # overridden per-repo in config.json
+        attention_n_heads = 12,
+        attention_n_kv_heads = 1,  # MQA
+        max_seq_len = 512,
+        batch_size = 64,
+        position_emb_theta = 10000.0,
+        activation_hidden_dim = 3072,
+        norm_eps = 1e-5,
+        dropout = 0.1,
+        **kwargs,
+    ):
+        # FIX: guard against None/0/missing attention_n_kv_heads in old config.json
+        if not attention_n_kv_heads:
             attention_n_kv_heads = attention_n_heads
+
         super().__init__(**kwargs)
         self.n_layers = n_layers
         self.d_model = d_model
@@ -177,7 +262,9 @@ class PicoDecoderHFConfig(PretrainedConfig):
     def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PicoDecoderHFConfig":
         pico_config = cls(**config_dict)
         return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
-        unused_kwargs = {k: v for k, v in kwargs.items() if not hasattr(pico_config, k)}
+        unused_kwargs = {
+            k: v for k, v in kwargs.items() if not hasattr(pico_config, k)
+        }
         if return_unused_kwargs:
             return pico_config, unused_kwargs
         return pico_config
@@ -187,44 +274,96 @@ class PicoDecoderHFConfig(PretrainedConfig):
         return cls.from_dict(asdict(model_config))
 
 
+# ── PicoDecoderHF ─────────────────────────────────────────────────────────────
+
 class PicoDecoderHF(PreTrainedModel):
-    """Load with: AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)"""
+    """
+    HuggingFace wrapper for BeetleLM PicoDecoder.
+
+    IMPORTANT — weights live at the TOP LEVEL (not under self.pico_decoder)
+    because the trainer saves raw PicoDecoder state dicts:
+        checkpoint keys: embedding_proj.weight, layers.0.attention.q_proj.weight ...
+    A self.pico_decoder wrapper would expect pico_decoder.embedding_proj.weight,
+    which does not exist in any saved checkpoint.
+    """
     config_class = PicoDecoderHFConfig
     _no_split_modules = ["PicoDecoderBlock", "Attention", "SwiGLU", "RMSNorm"]
-    _tied_weights_keys = []  # FIX 3
+    _tied_weights_keys = []
 
-    # FIX 4: explicit property — transformers >= 4.38 calls this directly
+    # FIX: transformers >= 4.38 calls .keys() on this — must return a dict
     @property
     def all_tied_weights_keys(self):
-        return self._tied_weights_keys
+        return {}
 
     def __init__(self, config: PicoDecoderHFConfig):
         super().__init__(config)
-        self.pico_decoder = PicoDecoder(config)
+        # FIX: top-level storage — matches raw PicoDecoder checkpoint keys
+        self.embedding_proj = nn.Embedding(config.vocab_size, config.d_model)
+        self.layers = nn.ModuleList(
+            [PicoDecoderBlock(config) for _ in range(config.n_layers)])
+        self.output_norm = RMSNorm(config)
+        self.de_embedding_proj = nn.Linear(
+            config.d_model, config.vocab_size, bias=False)
 
     def get_input_embeddings(self):
-        return self.pico_decoder.embedding_proj
+        return self.embedding_proj
 
     def set_input_embeddings(self, value):
-        self.pico_decoder.embedding_proj = value
+        self.embedding_proj = value
+
+    def forward(
+        self,
+        input_ids = None,
+        past_key_values = None,
+        use_cache = False,
+        labels = None,
+        **kwargs,
+    ):
+        seq_len = input_ids.shape[-1]
+        h = self.embedding_proj(input_ids)
+        start_pos = 0 if past_key_values is None else past_key_values[0][0].shape[1]
+
+        mask = None
+        if seq_len > 1:
+            mask = torch.full((seq_len, seq_len), float("-inf"))
+            mask = torch.triu(mask, diagonal=1)
+            if past_key_values is not None:
+                mask = torch.hstack([torch.zeros((seq_len, start_pos)), mask])
+            mask = mask.to(h.device)
+
+        cached_key_values = () if use_cache else None
+        for idx, layer in enumerate(self.layers):
+            layer_past = past_key_values[idx] if past_key_values is not None else None
+            h, layer_cached = layer(
+                h, mask=mask, past_key_values=layer_past, use_cache=use_cache)
+            if use_cache:
+                cached_key_values += (layer_cached,)
+
+        logits = self.de_embedding_proj(self.output_norm(h)).float()
 
-    def forward(self, input_ids=None, past_key_values=None,
-                use_cache=False, labels=None, **kwargs):
-        logits, new_past = self.pico_decoder(input_ids, past_key_values, use_cache)
         loss = None
         if labels is not None:
+            shift_logits = logits[:, :-1].contiguous()
+            shift_labels = labels[:, 1:].contiguous().clamp(0, self.config.vocab_size - 1)
             loss = F.cross_entropy(
-                logits[:, :-1].contiguous().view(-1, self.config.vocab_size),
-                labels[:, 1:].contiguous().clamp(0, self.config.vocab_size - 1).view(-1),
+                shift_logits.view(-1, self.config.vocab_size),
+                shift_labels.view(-1),
            )
+
         if use_cache:
-            return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=new_past)
+            return CausalLMOutputWithPast(
+                loss=loss, logits=logits, past_key_values=cached_key_values)
         return CausalLMOutput(loss=loss, logits=logits)
 
     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
-        return {"input_ids": input_ids, "past_key_values": past_key_values, "use_cache": True}
+        return {
+            "input_ids": input_ids,
+            "past_key_values": past_key_values,
+            "use_cache": True,
+        }
 
 
+# Auto-class registration
 PicoDecoderHFConfig.register_for_auto_class()
 PicoDecoderHF.register_for_auto_class("AutoModel")
 PicoDecoderHF.register_for_auto_class("AutoModelForCausalLM")
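
With the wrapper's modules registered at the top level, a raw PicoDecoder checkpoint (keys such as embedding_proj.weight and layers.0.attention.q_proj.weight) should load without any key remapping, and generation goes through the stock transformers API. A hedged usage sketch; the repo id and prompt are placeholders:

# Hypothetical usage after this commit; repo id and prompt are placeholders.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "suchirsalhan/pico-decoder"  # placeholder id
tokenizer = AutoTokenizer.from_pretrained(repo)
model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)
model.eval()

# State-dict keys map one-to-one onto the module tree (no pico_decoder. prefix).
print(next(iter(model.state_dict())))  # e.g. embedding_proj.weight

inputs = tokenizer("The beetle crawled", return_tensors="pt")
with torch.no_grad():
    out = model.generate(inputs["input_ids"], max_new_tokens=20, do_sample=False)
print(tokenizer.decode(out[0], skip_special_tokens=True))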