suchirsalhan committed
Commit c3ad4c9 · verified · 1 Parent(s): 3850015

Fix pico_decoder.py: __init__ defaults, ZeroDivisionError, all_tied_weights_keys

Files changed (1):
  1. pico_decoder.py (+66, -105)
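The PicoDecoderHF docstring updated in this commit documents loading through the Transformers remote-code path. A minimal loading sketch, assuming the repository's config.json maps the auto classes onto PicoDecoderHFConfig / PicoDecoderHF in pico_decoder.py; the repo id below is a placeholder, not the actual repository:

    from transformers import AutoModelForCausalLM

    repo_id = "<namespace>/<model-repo>"  # placeholder repo id

    # trust_remote_code=True imports PicoDecoderHFConfig / PicoDecoderHF from the
    # repository's pico_decoder.py. With the explicit __init__ defaults added below,
    # a config.json that omits or nulls attention_n_kv_heads no longer breaks loading.
    model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)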
pico_decoder.py CHANGED
@@ -1,14 +1,11 @@
-
 from dataclasses import asdict
 from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
-
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.attention import SDPBackend, sdpa_kernel
 from transformers import PretrainedConfig, PreTrainedModel
 from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast
-
 try:
     if TYPE_CHECKING:
         from src.config import ModelConfig
@@ -21,49 +18,36 @@ class RMSNorm(torch.nn.Module):
         super().__init__()
         self.eps = config.norm_eps
         self.weight = nn.Parameter(torch.ones(config.d_model))
-
     def _norm(self, x):
         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-
     def forward(self, x):
         return self._norm(x.float()).type_as(x) * self.weight
 
 
 class RoPE(nn.Module):
     _freqs_cis_tensor = None
-
     def __init__(self, config):
         super().__init__()
         self.theta = config.position_emb_theta
-        self.dim = config.d_model // config.attention_n_heads
+        self.dim = config.d_model // config.attention_n_heads
         if RoPE._freqs_cis_tensor is None:
-            RoPE._freqs_cis_tensor = self._setup_freqs_cis(
-                config.max_seq_len, self.theta, self.dim
-            )
+            RoPE._freqs_cis_tensor = self._setup_freqs_cis(config.max_seq_len, self.theta, self.dim)
         self.register_buffer("_freqs_cis", RoPE._freqs_cis_tensor, persistent=False)
-
     @classmethod
     def _setup_freqs_cis(cls, seq_len, theta, dim):
         _freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
-        freqs = torch.outer(torch.arange(seq_len), _freqs)
-        return torch.polar(torch.ones_like(freqs), freqs)
-
+        return torch.polar(torch.ones_like(f := torch.outer(torch.arange(seq_len), _freqs)), f)
     def get_freqs_cis(self, input_shape, start_pos, end_pos):
         _f = self._freqs_cis[start_pos:end_pos]
         ndim = len(input_shape)
-        assert 0 <= 1 < ndim
-        assert _f.shape == (input_shape[1], input_shape[-1])
-        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(input_shape)]
-        return _f.view(*shape)
-
+        assert 0 <= 1 < ndim and _f.shape == (input_shape[1], input_shape[-1])
+        return _f.view(*[d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(input_shape)])
     def forward(self, queries, keys, start_pos=0):
         q_ = torch.view_as_complex(queries.float().reshape(*queries.shape[:-1], -1, 2))
         k_ = torch.view_as_complex(keys.float().reshape(*keys.shape[:-1], -1, 2))
         fc = self.get_freqs_cis(q_.shape, start_pos, start_pos + q_.shape[1])
-        return (
-            torch.view_as_real(q_ * fc).flatten(3).type_as(queries),
-            torch.view_as_real(k_ * fc).flatten(3).type_as(keys),
-        )
+        return (torch.view_as_real(q_ * fc).flatten(3).type_as(queries),
+                torch.view_as_real(k_ * fc).flatten(3).type_as(keys))
 
 
 class Attention(nn.Module):
@@ -81,35 +65,30 @@ class Attention(nn.Module):
         self.v_proj = nn.Linear(d, self.n_kv_heads * self.head_dim, bias=False)
         self.o_proj = nn.Linear(self.n_heads * self.head_dim, d, bias=False)
         self.rope = RoPE(config)
-
     def forward(self, input, mask=None, past_key_values=None, use_cache=False):
         bsz, seq_len, _ = input.shape
-        queries = self.q_proj(input).view(bsz, seq_len, self.n_heads, self.head_dim)
-        keys = self.k_proj(input).view(bsz, seq_len, self.n_kv_heads, self.head_dim)
-        values = self.v_proj(input).view(bsz, seq_len, self.n_kv_heads, self.head_dim)
-        start_pos = past_key_values[0].shape[1] if past_key_values is not None else 0
-        queries, keys = self.rope(queries, keys, start_pos)
+        q = self.q_proj(input).view(bsz, seq_len, self.n_heads, self.head_dim)
+        k = self.k_proj(input).view(bsz, seq_len, self.n_kv_heads, self.head_dim)
+        v = self.v_proj(input).view(bsz, seq_len, self.n_kv_heads, self.head_dim)
+        sp = past_key_values[0].shape[1] if past_key_values is not None else 0
+        q, k = self.rope(q, k, sp)
         if past_key_values is not None:
-            keys = torch.cat([past_key_values[0], keys], dim=1)
-            values = torch.cat([past_key_values[1], values], dim=1)
-        cached_keys = keys if use_cache else None
-        cached_values = values if use_cache else None
-        queries = queries.transpose(1, 2)
-        keys = keys.transpose(1, 2)
-        values = values.transpose(1, 2)
-        apply_gqa = self.n_rep > 1
-        if apply_gqa and queries.device.type == "mps":
-            keys = keys.repeat_interleave(self.n_rep, dim=-3)
-            values = values.repeat_interleave(self.n_rep, dim=-3)
-            apply_gqa = False
+            k = torch.cat([past_key_values[0], k], dim=1)
+            v = torch.cat([past_key_values[1], v], dim=1)
+        ck, cv = (k, v) if use_cache else (None, None)
+        q, k, v = q.transpose(1,2), k.transpose(1,2), v.transpose(1,2)
+        gqa = self.n_rep > 1
+        if gqa and q.device.type == "mps":
+            k = k.repeat_interleave(self.n_rep, dim=-3)
+            v = v.repeat_interleave(self.n_rep, dim=-3)
+            gqa = False
         with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH]):
             out = F.scaled_dot_product_attention(
-                queries.contiguous(), keys.contiguous(), values.contiguous(),
-                attn_mask=mask.to(queries.dtype) if mask is not None else None,
-                enable_gqa=apply_gqa,
+                q.contiguous(), k.contiguous(), v.contiguous(),
+                attn_mask=mask.to(q.dtype) if mask is not None else None,
+                enable_gqa=gqa,
             )
-        out = out.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
-        return self.o_proj(out), (cached_keys, cached_values)
+        return self.o_proj(out.transpose(1,2).contiguous().view(bsz, seq_len, -1)), (ck, cv)
 
 
 class SwiGLU(nn.Module):
@@ -118,7 +97,6 @@ class SwiGLU(nn.Module):
         self.w_0 = nn.Linear(config.d_model, config.activation_hidden_dim, bias=False)
         self.w_1 = nn.Linear(config.d_model, config.activation_hidden_dim, bias=False)
         self.w_2 = nn.Linear(config.activation_hidden_dim, config.d_model, bias=False)
-
     def forward(self, x):
         return self.w_2(F.silu(self.w_0(x)) * self.w_1(x))
 
@@ -130,14 +108,11 @@ class PicoDecoderBlock(nn.Module):
         self.swiglu = SwiGLU(config)
         self.attention_norm = RMSNorm(config)
         self.swiglu_norm = RMSNorm(config)
-
     def forward(self, input, mask=None, past_key_values=None, use_cache=False):
-        attn_out, cached = self.attention(
-            self.attention_norm(input), mask=mask,
-            past_key_values=past_key_values, use_cache=use_cache,
-        )
-        h = input + attn_out
-        return h + self.swiglu(self.swiglu_norm(h)), cached
+        a, c = self.attention(self.attention_norm(input), mask=mask,
+                              past_key_values=past_key_values, use_cache=use_cache)
+        h = input + a
+        return h + self.swiglu(self.swiglu_norm(h)), c
 
 
 class PicoDecoder(nn.Module):
@@ -145,60 +120,45 @@ class PicoDecoder(nn.Module):
         super().__init__()
         self.config = model_config
         self.embedding_proj = nn.Embedding(model_config.vocab_size, model_config.d_model)
-        self.layers = nn.ModuleList(
-            [PicoDecoderBlock(model_config) for _ in range(model_config.n_layers)]
-        )
+        self.layers = nn.ModuleList([PicoDecoderBlock(model_config) for _ in range(model_config.n_layers)])
         self.output_norm = RMSNorm(model_config)
         self.de_embedding_proj = nn.Linear(model_config.d_model, model_config.vocab_size, bias=False)
-
     def convert_to_hf_model(self):
-        hf_config = PicoDecoderHFConfig.from_dataclass(self.config)
-        hf_model = PicoDecoderHF(hf_config)
-        hf_model.load_state_dict(self.state_dict(prefix="pico_decoder."))
-        return hf_model
-
+        hf = PicoDecoderHF(PicoDecoderHFConfig.from_dataclass(self.config))
+        hf.load_state_dict(self.state_dict(prefix="pico_decoder."))
+        return hf
     def forward(self, input_ids, past_key_values=None, use_cache=False):
-        seq_len = input_ids.shape[-1]
-        h = self.embedding_proj(input_ids)
-        start_pos = 0 if past_key_values is None else past_key_values[0][0].shape[1]
+        sl = input_ids.shape[-1]
+        h = self.embedding_proj(input_ids)
+        sp = 0 if past_key_values is None else past_key_values[0][0].shape[1]
         mask = None
-        if seq_len > 1:
-            mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
+        if sl > 1:
+            mask = torch.triu(torch.full((sl, sl), float("-inf")), diagonal=1)
             if past_key_values is not None:
-                mask = torch.hstack([torch.zeros((seq_len, start_pos)), mask])
+                mask = torch.hstack([torch.zeros((sl, sp)), mask])
             mask = mask.to(h.device)
-        cached_key_values = () if use_cache else None
-        for idx, layer in enumerate(self.layers):
-            layer_past = past_key_values[idx] if past_key_values is not None else None
-            h, layer_cached = layer(h, mask=mask, past_key_values=layer_past, use_cache=use_cache)
+        ckv = () if use_cache else None
+        for i, layer in enumerate(self.layers):
+            lp = past_key_values[i] if past_key_values is not None else None
+            h, lc = layer(h, mask=mask, past_key_values=lp, use_cache=use_cache)
             if use_cache:
-                cached_key_values += (layer_cached,)
-        return self.de_embedding_proj(self.output_norm(h)).float(), cached_key_values
+                ckv += (lc,)
+        return self.de_embedding_proj(self.output_norm(h)).float(), ckv
 
 
+# ── HuggingFace wrappers ──────────────────────────────────────────────────────
+
 class PicoDecoderHFConfig(PretrainedConfig):
-    """HuggingFace config for BeetleLM PicoDecoder.
-    Defaults match generate_configs.py MODEL_BASE exactly.
-    """
     model_type = "pico_decoder"
 
-    def __init__(
-        self,
-        n_layers: int = 14,
-        d_model: int = 768,
-        vocab_size: int = 32768,
-        attention_n_heads: int = 12,
-        attention_n_kv_heads: int = 1,
-        max_seq_len: int = 512,
-        batch_size: int = 64,
-        position_emb_theta: float = 10000.0,
-        activation_hidden_dim: int = 3072,
-        norm_eps: float = 1e-5,
-        dropout: float = 0.1,
-        **kwargs,
-    ):
-        # Fix: guard against None/0/missing attention_n_kv_heads in old config.json
-        if not attention_n_kv_heads:
+    # FIX 1 + 2: explicit __init__ with MODEL_BASE defaults; guards None/0 kv_heads
+    def __init__(self,
+                 n_layers=14, d_model=768, vocab_size=32768,
+                 attention_n_heads=12, attention_n_kv_heads=1,
+                 max_seq_len=512, batch_size=64, position_emb_theta=10000.0,
+                 activation_hidden_dim=3072, norm_eps=1e-5, dropout=0.1,
+                 **kwargs):
+        if not attention_n_kv_heads:  # catches None, 0, missing
            attention_n_kv_heads = attention_n_heads
         super().__init__(**kwargs)
         self.n_layers = n_layers
@@ -228,12 +188,15 @@ class PicoDecoderHFConfig(PretrainedConfig):
 
 
 class PicoDecoderHF(PreTrainedModel):
-    """HuggingFace wrapper for BeetleLM PicoDecoder.
-    Load with: AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)
-    """
+    """Load with: AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)"""
     config_class = PicoDecoderHFConfig
     _no_split_modules = ["PicoDecoderBlock", "Attention", "SwiGLU", "RMSNorm"]
-    _tied_weights_keys = []  # Fix: required by transformers >= 4.38
+    _tied_weights_keys = []  # FIX 3
+
+    # FIX 4: explicit property — transformers >= 4.38 calls this directly
+    @property
+    def all_tied_weights_keys(self):
+        return self._tied_weights_keys
 
     def __init__(self, config: PicoDecoderHFConfig):
         super().__init__(config)
@@ -245,16 +208,14 @@ class PicoDecoderHF(PreTrainedModel):
     def set_input_embeddings(self, value):
         self.pico_decoder.embedding_proj = value
 
-    def forward(self, input_ids=None, past_key_values=None, use_cache=False,
-                labels=None, **kwargs):
+    def forward(self, input_ids=None, past_key_values=None,
+                use_cache=False, labels=None, **kwargs):
         logits, new_past = self.pico_decoder(input_ids, past_key_values, use_cache)
         loss = None
         if labels is not None:
-            shift_logits = logits[:, :-1].contiguous()
-            shift_labels = labels[:, 1:].contiguous().clamp(0, self.config.vocab_size - 1)
             loss = F.cross_entropy(
-                shift_logits.view(-1, self.config.vocab_size),
-                shift_labels.view(-1),
+                logits[:, :-1].contiguous().view(-1, self.config.vocab_size),
+                labels[:, 1:].contiguous().clamp(0, self.config.vocab_size - 1).view(-1),
             )
         if use_cache:
             return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=new_past)
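The three fixes can be exercised without the Hub. A minimal sketch, assuming pico_decoder.py is importable from the working directory and that the ZeroDivisionError named in the commit message came from a 0/None attention_n_kv_heads being used as a divisor when the attention heads are set up:

    from pico_decoder import PicoDecoderHF, PicoDecoderHFConfig

    # FIX 1 + 2: a missing, None, or 0 attention_n_kv_heads now falls back to
    # attention_n_heads, so constructing the model no longer divides by zero.
    cfg = PicoDecoderHFConfig(attention_n_kv_heads=None)
    model = PicoDecoderHF(cfg)

    # FIX 3 + 4: transformers >= 4.38 can query the tied-weight keys directly.
    print(model.all_tied_weights_keys)  # -> []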
 
 