suchirsalhan committed
Commit 9b894b8 · verified · 1 Parent(s): f0f046e

Fix pico_decoder.py: __init__ defaults, ZeroDivisionError, all_tied_weights_keys

Files changed (2)
  1. config.json +11 -19
  2. pico_decoder.py +99 -119
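
The custom classes are exposed through auto_map, so the checkpoint only loads with trust_remote_code=True (as the wrapper's docstring notes). A minimal loading sketch, with a placeholder repo id:

# Loading sketch; the repo id below is a placeholder, substitute this model's actual repo.
from transformers import AutoConfig, AutoModelForCausalLM

repo = "suchirsalhan/<this-repo>"  # placeholder
config = AutoConfig.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)
print(config.attention_n_kv_heads)  # 1 after this commit (was 0 in the old config.json)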
config.json CHANGED
@@ -3,29 +3,21 @@
     "PicoDecoderHF"
   ],
   "model_type": "pico_decoder",
- "vocab_size": 32000,
- "hidden_size": 768,
- "num_hidden_layers": 14,
- "num_attention_heads": 4,
- "intermediate_size": 3072,
- "max_position_embeddings": 2048,
- "hidden_act": "silu",
- "initializer_range": 0.02,
- "rms_norm_eps": 1e-05,
- "tie_word_embeddings": false,
- "torch_dtype": "float32",
- "transformers_version": "4.48.3",
  "auto_map": {
    "AutoConfig": "pico_decoder.PicoDecoderHFConfig",
    "AutoModelForCausalLM": "pico_decoder.PicoDecoderHF"
  },
- "d_model": 768,
  "n_layers": 14,
- "attention_n_heads": 6,
- "attention_n_kv_heads": 0,
- "activation_hidden_dim": 3072,
- "max_seq_len": 2048,
- "norm_eps": 1e-06,
+ "d_model": 768,
+ "vocab_size": 32768,
+ "attention_n_heads": 12,
+ "attention_n_kv_heads": 1,
+ "max_seq_len": 512,
+ "batch_size": 64,
  "position_emb_theta": 10000.0,
- "batch_size": 1
+ "activation_hidden_dim": 3072,
+ "norm_eps": 1e-05,
+ "dropout": 0.1,
+ "torch_dtype": "float32",
+ "transformers_version": "4.48.3"
  }
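
The old attention_n_kv_heads: 0 is what the ZeroDivisionError in the commit title refers to: Attention.__init__ computes n_rep = n_heads // n_kv_heads. A standalone sketch of the failure mode and of the fallback that the new PicoDecoderHFConfig.__init__ adds (it mirrors the guard in the diff below, not a verbatim copy of the repo code):

# Illustration only: why attention_n_kv_heads = 0 crashed, and the new fallback.
def resolve_kv_heads(n_heads: int, n_kv_heads):
    if not n_kv_heads:           # guards None and 0, as in the new __init__
        n_kv_heads = n_heads     # fall back to one KV head per query head
    return n_kv_heads

# old config: 12 // 0 raised ZeroDivisionError; with the guard it becomes 12 // 12 == 1
assert 12 // resolve_kv_heads(12, 0) == 1
# new config stores 1 KV head explicitly (multi-query attention), so n_rep == 12
assert 12 // resolve_kv_heads(12, 1) == 12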
pico_decoder.py CHANGED
@@ -1,117 +1,95 @@
-
- """
- Pico Decoder — BeetleLM
- Adapted from pico-lm/pico-decoder-tiny (Apache 2.0).
- Load with trust_remote_code=True.
- """
-
- from typing import Optional, Tuple, Union
-
+ from dataclasses import asdict
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
  from torch.nn.attention import SDPBackend, sdpa_kernel
  from transformers import PretrainedConfig, PreTrainedModel
  from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast
+ try:
+     if TYPE_CHECKING:
+         from src.config import ModelConfig
+ except ImportError:
+     pass
 
 
- # ── RMSNorm ───────────────────────────────────────────────────────────────────
-
  class RMSNorm(torch.nn.Module):
      def __init__(self, config):
          super().__init__()
          self.eps = config.norm_eps
          self.weight = nn.Parameter(torch.ones(config.d_model))
-
      def _norm(self, x):
          return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-
      def forward(self, x):
          return self._norm(x.float()).type_as(x) * self.weight
 
 
- # ── RoPE ──────────────────────────────────────────────────────────────────────
-
  class RoPE(nn.Module):
      _freqs_cis_tensor = None
-
      def __init__(self, config):
          super().__init__()
          self.theta = config.position_emb_theta
-         self.dim = config.d_model // config.attention_n_heads
-         RoPE._freqs_cis_tensor = self._setup_freqs_cis(
-             config.max_seq_len, self.theta, self.dim
-         )
+         self.dim = config.d_model // config.attention_n_heads
+         if RoPE._freqs_cis_tensor is None:
+             RoPE._freqs_cis_tensor = self._setup_freqs_cis(config.max_seq_len, self.theta, self.dim)
          self.register_buffer("_freqs_cis", RoPE._freqs_cis_tensor, persistent=False)
-
      @classmethod
      def _setup_freqs_cis(cls, seq_len, theta, dim):
-         _freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
-         freqs = torch.outer(torch.arange(seq_len), _freqs)
-         return torch.polar(torch.ones_like(freqs), freqs)
-
+         _freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+         return torch.polar(torch.ones_like(f := torch.outer(torch.arange(seq_len), _freqs)), f)
      def get_freqs_cis(self, input_shape, start_pos, end_pos):
          _f = self._freqs_cis[start_pos:end_pos]
          ndim = len(input_shape)
-         shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(input_shape)]
-         return _f.view(*shape)
-
+         assert 0 <= 1 < ndim and _f.shape == (input_shape[1], input_shape[-1])
+         return _f.view(*[d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(input_shape)])
      def forward(self, queries, keys, start_pos=0):
          q_ = torch.view_as_complex(queries.float().reshape(*queries.shape[:-1], -1, 2))
          k_ = torch.view_as_complex(keys.float().reshape(*keys.shape[:-1], -1, 2))
          fc = self.get_freqs_cis(q_.shape, start_pos, start_pos + q_.shape[1])
-         return (
-             torch.view_as_real(q_ * fc).flatten(3).type_as(queries),
-             torch.view_as_real(k_ * fc).flatten(3).type_as(keys),
-         )
-
+         return (torch.view_as_real(q_ * fc).flatten(3).type_as(queries),
+                 torch.view_as_real(k_ * fc).flatten(3).type_as(keys))
 
- # ── Attention ─────────────────────────────────────────────────────────────────
 
  class Attention(nn.Module):
      def __init__(self, config):
          super().__init__()
-         self.n_heads = config.attention_n_heads
-         self.n_kv_heads = config.attention_n_kv_heads
-         self.n_rep = self.n_heads // self.n_kv_heads
-         self.max_seq_len = config.max_seq_len
+         self.n_heads = config.attention_n_heads
+         self.n_kv_heads = config.attention_n_kv_heads
          self.batch_size = config.batch_size
+         self.max_seq_len = config.max_seq_len
          d = config.d_model
          self.head_dim = d // self.n_heads
+         self.n_rep = self.n_heads // self.n_kv_heads
          self.q_proj = nn.Linear(d, self.n_heads * self.head_dim, bias=False)
          self.k_proj = nn.Linear(d, self.n_kv_heads * self.head_dim, bias=False)
          self.v_proj = nn.Linear(d, self.n_kv_heads * self.head_dim, bias=False)
          self.o_proj = nn.Linear(self.n_heads * self.head_dim, d, bias=False)
          self.rope = RoPE(config)
-
-     def forward(self, x, mask=None, past_key_values=None, use_cache=False):
-         bsz, seq_len, _ = x.shape
-         q = self.q_proj(x).view(bsz, seq_len, self.n_heads, self.head_dim)
-         k = self.k_proj(x).view(bsz, seq_len, self.n_kv_heads, self.head_dim)
-         v = self.v_proj(x).view(bsz, seq_len, self.n_kv_heads, self.head_dim)
-         start_pos = past_key_values[0].shape[1] if past_key_values is not None else 0
-         q, k = self.rope(q, k, start_pos)
+     def forward(self, input, mask=None, past_key_values=None, use_cache=False):
+         bsz, seq_len, _ = input.shape
+         q = self.q_proj(input).view(bsz, seq_len, self.n_heads, self.head_dim)
+         k = self.k_proj(input).view(bsz, seq_len, self.n_kv_heads, self.head_dim)
+         v = self.v_proj(input).view(bsz, seq_len, self.n_kv_heads, self.head_dim)
+         sp = past_key_values[0].shape[1] if past_key_values is not None else 0
+         q, k = self.rope(q, k, sp)
          if past_key_values is not None:
              k = torch.cat([past_key_values[0], k], dim=1)
              v = torch.cat([past_key_values[1], v], dim=1)
          ck, cv = (k, v) if use_cache else (None, None)
          q, k, v = q.transpose(1,2), k.transpose(1,2), v.transpose(1,2)
-         apply_gqa = self.n_rep > 1
-         if apply_gqa and q.device.type == "mps":
+         gqa = self.n_rep > 1
+         if gqa and q.device.type == "mps":
              k = k.repeat_interleave(self.n_rep, dim=-3)
              v = v.repeat_interleave(self.n_rep, dim=-3)
-             apply_gqa = False
+             gqa = False
          with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH]):
              out = F.scaled_dot_product_attention(
                  q.contiguous(), k.contiguous(), v.contiguous(),
                  attn_mask=mask.to(q.dtype) if mask is not None else None,
-                 enable_gqa=apply_gqa,
+                 enable_gqa=gqa,
              )
-         out = out.transpose(1,2).contiguous().view(bsz, seq_len, -1)
-         return self.o_proj(out), (ck, cv)
-
+         return self.o_proj(out.transpose(1,2).contiguous().view(bsz, seq_len, -1)), (ck, cv)
 
- # ── SwiGLU ────────────────────────────────────────────────────────────────────
 
  class SwiGLU(nn.Module):
      def __init__(self, config):
@@ -119,13 +97,10 @@ class SwiGLU(nn.Module):
          self.w_0 = nn.Linear(config.d_model, config.activation_hidden_dim, bias=False)
          self.w_1 = nn.Linear(config.d_model, config.activation_hidden_dim, bias=False)
          self.w_2 = nn.Linear(config.activation_hidden_dim, config.d_model, bias=False)
-
      def forward(self, x):
          return self.w_2(F.silu(self.w_0(x)) * self.w_1(x))
 
 
- # ── PicoDecoderBlock ──────────────────────────────────────────────────────────
-
  class PicoDecoderBlock(nn.Module):
      def __init__(self, config):
          super().__init__()
@@ -133,18 +108,13 @@ class PicoDecoderBlock(nn.Module):
          self.swiglu = SwiGLU(config)
          self.attention_norm = RMSNorm(config)
          self.swiglu_norm = RMSNorm(config)
-
-     def forward(self, x, mask=None, past_key_values=None, use_cache=False):
-         attn_out, cached = self.attention(
-             self.attention_norm(x), mask=mask,
-             past_key_values=past_key_values, use_cache=use_cache,
-         )
-         h = x + attn_out
-         return h + self.swiglu(self.swiglu_norm(h)), cached
+     def forward(self, input, mask=None, past_key_values=None, use_cache=False):
+         a, c = self.attention(self.attention_norm(input), mask=mask,
+                               past_key_values=past_key_values, use_cache=use_cache)
+         h = input + a
+         return h + self.swiglu(self.swiglu_norm(h)), c
 
 
- # ── PicoDecoder ───────────────────────────────────────────────────────────────
-
  class PicoDecoder(nn.Module):
      def __init__(self, model_config):
          super().__init__()
@@ -153,67 +123,80 @@ class PicoDecoder(nn.Module):
          self.layers = nn.ModuleList([PicoDecoderBlock(model_config) for _ in range(model_config.n_layers)])
          self.output_norm = RMSNorm(model_config)
          self.de_embedding_proj = nn.Linear(model_config.d_model, model_config.vocab_size, bias=False)
-
+     def convert_to_hf_model(self):
+         hf = PicoDecoderHF(PicoDecoderHFConfig.from_dataclass(self.config))
+         hf.load_state_dict(self.state_dict(prefix="pico_decoder."))
+         return hf
      def forward(self, input_ids, past_key_values=None, use_cache=False):
-         seq_len = input_ids.shape[-1]
-         h = self.embedding_proj(input_ids)
-         start_pos = 0 if past_key_values is None else past_key_values[0][0].shape[1]
+         sl = input_ids.shape[-1]
+         h = self.embedding_proj(input_ids)
+         sp = 0 if past_key_values is None else past_key_values[0][0].shape[1]
          mask = None
-         if seq_len > 1:
-             mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
+         if sl > 1:
+             mask = torch.triu(torch.full((sl, sl), float("-inf")), diagonal=1)
              if past_key_values is not None:
-                 mask = torch.hstack([torch.zeros((seq_len, start_pos)), mask])
+                 mask = torch.hstack([torch.zeros((sl, sp)), mask])
              mask = mask.to(h.device)
-         cached_kvs = () if use_cache else None
-         for idx, layer in enumerate(self.layers):
-             layer_past = past_key_values[idx] if past_key_values is not None else None
-             h, layer_cached = layer(h, mask=mask, past_key_values=layer_past, use_cache=use_cache)
+         ckv = () if use_cache else None
+         for i, layer in enumerate(self.layers):
+             lp = past_key_values[i] if past_key_values is not None else None
+             h, lc = layer(h, mask=mask, past_key_values=lp, use_cache=use_cache)
              if use_cache:
-                 cached_kvs += (layer_cached,)
-         return self.de_embedding_proj(self.output_norm(h)).float(), cached_kvs
+                 ckv += (lc,)
+         return self.de_embedding_proj(self.output_norm(h)).float(), ckv
 
 
- # ── HuggingFace Config ────────────────────────────────────────────────────────
+ # ── HuggingFace wrappers ──────────────────────────────────────────────────────
 
  class PicoDecoderHFConfig(PretrainedConfig):
      model_type = "pico_decoder"
 
-     def __init__(
-         self,
-         vocab_size=32000,
-         d_model=256,
-         n_layers=6,
-         attention_n_heads=8,
-         attention_n_kv_heads=4,
-         activation_hidden_dim=1024,
-         max_seq_len=2048,
-         norm_eps=1e-6,
-         position_emb_theta=10000.0,
-         batch_size=1,
-         **kwargs,
-     ):
+     # FIX 1 + 2: explicit __init__ with MODEL_BASE defaults; guards None/0 kv_heads
+     def __init__(self,
+                  n_layers=14, d_model=768, vocab_size=32768,
+                  attention_n_heads=12, attention_n_kv_heads=1,
+                  max_seq_len=512, batch_size=64, position_emb_theta=10000.0,
+                  activation_hidden_dim=3072, norm_eps=1e-5, dropout=0.1,
+                  **kwargs):
+         if not attention_n_kv_heads:  # catches None, 0, missing
+             attention_n_kv_heads = attention_n_heads
          super().__init__(**kwargs)
-         self.vocab_size = vocab_size
-         self.d_model = d_model
-         self.n_layers = n_layers
-         self.attention_n_heads = attention_n_heads
-         self.attention_n_kv_heads = attention_n_kv_heads
-         self.activation_hidden_dim = activation_hidden_dim
-         self.max_seq_len = max_seq_len
-         self.norm_eps = norm_eps
-         self.position_emb_theta = position_emb_theta
-         self.batch_size = batch_size
+         self.n_layers = n_layers
+         self.d_model = d_model
+         self.vocab_size = vocab_size
+         self.attention_n_heads = attention_n_heads
+         self.attention_n_kv_heads = attention_n_kv_heads
+         self.max_seq_len = max_seq_len
+         self.batch_size = batch_size
+         self.position_emb_theta = position_emb_theta
+         self.activation_hidden_dim = activation_hidden_dim
+         self.norm_eps = norm_eps
+         self.dropout = dropout
+
+     @classmethod
+     def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PicoDecoderHFConfig":
+         pico_config = cls(**config_dict)
+         return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
+         unused_kwargs = {k: v for k, v in kwargs.items() if not hasattr(pico_config, k)}
+         if return_unused_kwargs:
+             return pico_config, unused_kwargs
+         return pico_config
 
+     @classmethod
+     def from_dataclass(cls, model_config):
+         return cls.from_dict(asdict(model_config))
 
- # ── HuggingFace Model ─────────────────────────────────────────────────────────
 
  class PicoDecoderHF(PreTrainedModel):
-     """
-     HuggingFace wrapper for BeetleLM PicoDecoder.
-     Usage: AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)
-     """
-     config_class = PicoDecoderHFConfig
-     _no_split_modules = ["PicoDecoderBlock", "Attention", "SwiGLU", "RMSNorm"]
+     """Load with: AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)"""
+     config_class = PicoDecoderHFConfig
+     _no_split_modules = ["PicoDecoderBlock", "Attention", "SwiGLU", "RMSNorm"]
+     _tied_weights_keys = []  # FIX 3
+
+     # FIX 4: explicit property — transformers >= 4.38 calls this directly
+     @property
+     def all_tied_weights_keys(self):
+         return self._tied_weights_keys
 
      def __init__(self, config: PicoDecoderHFConfig):
          super().__init__(config)
@@ -225,16 +208,14 @@ class PicoDecoderHF(PreTrainedModel):
      def set_input_embeddings(self, value):
          self.pico_decoder.embedding_proj = value
 
-     def forward(self, input_ids=None, past_key_values=None, use_cache=False, labels=None, **kwargs):
-         input_ids = input_ids.clamp(0, self.config.vocab_size - 1)
+     def forward(self, input_ids=None, past_key_values=None,
+                 use_cache=False, labels=None, **kwargs):
          logits, new_past = self.pico_decoder(input_ids, past_key_values, use_cache)
          loss = None
          if labels is not None:
-             shift_logits = logits[:, :-1].contiguous()
-             shift_labels = labels[:, 1:].contiguous().clamp(0, self.config.vocab_size - 1)
              loss = F.cross_entropy(
-                 shift_logits.view(-1, self.config.vocab_size),
-                 shift_labels.view(-1),
+                 logits[:, :-1].contiguous().view(-1, self.config.vocab_size),
+                 labels[:, 1:].contiguous().clamp(0, self.config.vocab_size - 1).view(-1),
              )
          if use_cache:
              return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=new_past)
@@ -244,7 +225,6 @@
          return {"input_ids": input_ids, "past_key_values": past_key_values, "use_cache": True}
 
 
- # Auto-class registration (runs on trust_remote_code import)
  PicoDecoderHFConfig.register_for_auto_class()
  PicoDecoderHF.register_for_auto_class("AutoModel")
  PicoDecoderHF.register_for_auto_class("AutoModelForCausalLM")
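
A quick smoke test of the fixes above; this assumes the updated file is importable locally as pico_decoder (otherwise load it through trust_remote_code as sketched earlier):

# Hypothetical local check, not part of the commit.
from pico_decoder import PicoDecoderHF, PicoDecoderHFConfig

cfg = PicoDecoderHFConfig()  # FIX 1: defaults now mirror the new config.json
assert (cfg.d_model, cfg.vocab_size, cfg.attention_n_heads) == (768, 32768, 12)
assert PicoDecoderHFConfig(attention_n_kv_heads=0).attention_n_kv_heads == 12  # FIX 2 guard
model = PicoDecoderHF(cfg)
assert model.all_tied_weights_keys == []  # FIX 3/4: nothing tied; property resolves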