suchirsalhan committed on
Commit
3850015
·
verified ·
1 Parent(s): 56a67a7

Fix pico_decoder.py: __init__ defaults, ZeroDivisionError, _tied_weights_keys

Browse files
Files changed (2)
  1. config.json +8 -8
  2. pico_decoder.py +88 -106
config.json CHANGED
@@ -1,4 +1,12 @@
1
  {
 
 
 
 
 
 
 
 
2
  "n_layers": 14,
3
  "d_model": 768,
4
  "vocab_size": 32768,
@@ -10,14 +18,6 @@
10
  "activation_hidden_dim": 3072,
11
  "norm_eps": 1e-05,
12
  "dropout": 0.1,
13
- "architectures": [
14
- "PicoDecoderHF"
15
- ],
16
- "model_type": "pico_decoder",
17
- "auto_map": {
18
- "AutoConfig": "pico_decoder.PicoDecoderHFConfig",
19
- "AutoModelForCausalLM": "pico_decoder.PicoDecoderHF"
20
- },
21
  "torch_dtype": "float32",
22
  "transformers_version": "4.48.3"
23
  }
 
1
  {
2
+ "architectures": [
3
+ "PicoDecoderHF"
4
+ ],
5
+ "model_type": "pico_decoder",
6
+ "auto_map": {
7
+ "AutoConfig": "pico_decoder.PicoDecoderHFConfig",
8
+ "AutoModelForCausalLM": "pico_decoder.PicoDecoderHF"
9
+ },
10
  "n_layers": 14,
11
  "d_model": 768,
12
  "vocab_size": 32768,
 
18
  "activation_hidden_dim": 3072,
19
  "norm_eps": 1e-05,
20
  "dropout": 0.1,
 
 
 
 
 
 
 
 
21
  "torch_dtype": "float32",
22
  "transformers_version": "4.48.3"
23
  }
pico_decoder.py CHANGED
@@ -1,27 +1,3 @@
1
- """
2
- Pico Decoder: A Lightweight Causal Transformer Language Model
3
- Implementation from https://github.com/pico-lm/pico-train/blob/main/src/model/pico_decoder.py
4
- Pico Decoder uses a simple LLAMA-style transformer architecture, written for clarity and educational purposes.
5
-
6
- Everything is written with a modular design for easy modification and experimentation.
7
-
8
- Key features:
9
- - RMSNorm for layer normalization
10
- - Rotary Positional Embeddings (RoPE)
11
- - Multi-head attention with KV-cache support
12
- - SwiGLU activation function
13
- - Residual connections throughout
14
- - KV-cache for faster autoregressive generation
15
-
16
- References:
17
- - RoPE: https://arxiv.org/abs/2104.09864
18
- - SwiGLU: https://arxiv.org/abs/2002.05202
19
- - LLAMA: https://arxiv.org/abs/2302.13971
20
-
21
- Adapted from:
22
- - OLMO: https://github.com/allenai/OLMo
23
- - LLAMA: https://github.com/meta/llama
24
- """
25
 
26
  from dataclasses import asdict
27
  from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
@@ -50,8 +26,7 @@ class RMSNorm(torch.nn.Module):
50
  return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
51
 
52
  def forward(self, x):
53
- output = self._norm(x.float()).type_as(x)
54
- return output * self.weight
55
 
56
 
57
  class RoPE(nn.Module):
@@ -61,69 +36,64 @@ class RoPE(nn.Module):
61
  super().__init__()
62
  self.theta = config.position_emb_theta
63
  self.dim = config.d_model // config.attention_n_heads
64
- max_seq_len = config.max_seq_len
65
  if RoPE._freqs_cis_tensor is None:
66
- RoPE._freqs_cis_tensor = self._setup_freqs_cis(max_seq_len, self.theta, self.dim)
67
- self.register_buffer("_freqs_cis", self._freqs_cis_tensor, persistent=False)
 
 
68
 
69
  @classmethod
70
  def _setup_freqs_cis(cls, seq_len, theta, dim):
71
  _freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
72
- positions = torch.arange(seq_len)
73
- freqs = torch.outer(positions, _freqs)
74
  return torch.polar(torch.ones_like(freqs), freqs)
75
 
76
  def get_freqs_cis(self, input_shape, start_pos, end_pos):
77
- _freqs_cis = self._freqs_cis[start_pos:end_pos]
78
  ndim = len(input_shape)
79
  assert 0 <= 1 < ndim
80
- assert _freqs_cis.shape == (input_shape[1], input_shape[-1])
81
  shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(input_shape)]
82
- return _freqs_cis.view(*shape)
83
 
84
  def forward(self, queries, keys, start_pos=0):
85
- queries_ = torch.view_as_complex(queries.float().reshape(*queries.shape[:-1], -1, 2))
86
- keys_ = torch.view_as_complex(keys.float().reshape(*keys.shape[:-1], -1, 2))
87
- input_shape = queries_.shape
88
- freqs_start_pos = start_pos
89
- freqs_end_pos = freqs_start_pos + queries_.shape[1]
90
- freqs_cis = self.get_freqs_cis(input_shape, freqs_start_pos, freqs_end_pos)
91
- queries_rotated = torch.view_as_real(queries_ * freqs_cis).flatten(3)
92
- keys_rotated = torch.view_as_real(keys_ * freqs_cis).flatten(3)
93
- return queries_rotated.type_as(queries), keys_rotated.type_as(keys)
94
 
95
 
96
  class Attention(nn.Module):
97
  def __init__(self, config):
98
  super().__init__()
99
- self.n_heads = config.attention_n_heads
100
- self.n_kv_heads = config.attention_n_kv_heads
101
  self.batch_size = config.batch_size
102
  self.max_seq_len = config.max_seq_len
103
- d_model = config.d_model
104
- self.head_dim = d_model // self.n_heads
105
- self.n_rep = self.n_heads // self.n_kv_heads
106
- self.q_proj = nn.Linear(d_model, self.n_heads * self.head_dim, bias=False)
107
- self.k_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
108
- self.v_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
109
- self.o_proj = nn.Linear(self.n_heads * self.head_dim, d_model, bias=False)
110
- self.rope = RoPE(config)
111
 
112
  def forward(self, input, mask=None, past_key_values=None, use_cache=False):
113
  bsz, seq_len, _ = input.shape
114
- _queries, _keys, _values = self.q_proj(input), self.k_proj(input), self.v_proj(input)
115
- queries = _queries.view(bsz, seq_len, self.n_heads, self.head_dim)
116
- keys = _keys.view( bsz, seq_len, self.n_kv_heads, self.head_dim)
117
- values = _values.view( bsz, seq_len, self.n_kv_heads, self.head_dim)
118
  start_pos = past_key_values[0].shape[1] if past_key_values is not None else 0
119
  queries, keys = self.rope(queries, keys, start_pos)
120
  if past_key_values is not None:
121
  keys = torch.cat([past_key_values[0], keys], dim=1)
122
  values = torch.cat([past_key_values[1], values], dim=1)
123
- if use_cache:
124
- cached_keys, cached_values = keys, values
125
- else:
126
- cached_keys = cached_values = None
127
  queries = queries.transpose(1, 2)
128
  keys = keys.transpose(1, 2)
129
  values = values.transpose(1, 2)
@@ -132,15 +102,14 @@ class Attention(nn.Module):
132
  keys = keys.repeat_interleave(self.n_rep, dim=-3)
133
  values = values.repeat_interleave(self.n_rep, dim=-3)
134
  apply_gqa = False
135
- backends = [SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH]
136
- with sdpa_kernel(backends=backends):
137
- attn_output = F.scaled_dot_product_attention(
138
  queries.contiguous(), keys.contiguous(), values.contiguous(),
139
- attn_mask=mask.to(queries.dtype),
140
  enable_gqa=apply_gqa,
141
  )
142
- attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
143
- return self.o_proj(attn_output), (cached_keys, cached_values)
144
 
145
 
146
  class SwiGLU(nn.Module):
@@ -163,25 +132,24 @@ class PicoDecoderBlock(nn.Module):
163
  self.swiglu_norm = RMSNorm(config)
164
 
165
  def forward(self, input, mask=None, past_key_values=None, use_cache=False):
166
- attention_output, cached_key_values = self.attention(
167
  self.attention_norm(input), mask=mask,
168
  past_key_values=past_key_values, use_cache=use_cache,
169
  )
170
- h = input + attention_output
171
- out = h + self.swiglu(self.swiglu_norm(h))
172
- return out, cached_key_values
173
 
174
 
175
  class PicoDecoder(nn.Module):
176
  def __init__(self, model_config):
177
  super().__init__()
178
  self.config = model_config
179
- self.embedding_proj = nn.Embedding(self.config.vocab_size, self.config.d_model)
180
  self.layers = nn.ModuleList(
181
- [PicoDecoderBlock(self.config) for _ in range(self.config.n_layers)]
182
  )
183
- self.output_norm = RMSNorm(self.config)
184
- self.de_embedding_proj = nn.Linear(self.config.d_model, self.config.vocab_size, bias=False)
185
 
186
  def convert_to_hf_model(self):
187
  hf_config = PicoDecoderHFConfig.from_dataclass(self.config)
@@ -195,8 +163,7 @@ class PicoDecoder(nn.Module):
195
  start_pos = 0 if past_key_values is None else past_key_values[0][0].shape[1]
196
  mask = None
197
  if seq_len > 1:
198
- mask = torch.full((seq_len, seq_len), float("-inf"))
199
- mask = torch.triu(mask, diagonal=1)
200
  if past_key_values is not None:
201
  mask = torch.hstack([torch.zeros((seq_len, start_pos)), mask])
202
  mask = mask.to(h.device)
@@ -206,19 +173,15 @@ class PicoDecoder(nn.Module):
206
  h, layer_cached = layer(h, mask=mask, past_key_values=layer_past, use_cache=use_cache)
207
  if use_cache:
208
  cached_key_values += (layer_cached,)
209
- h = self.output_norm(h)
210
- logits = self.de_embedding_proj(h).float()
211
- return logits, cached_key_values
212
 
213
 
214
  class PicoDecoderHFConfig(PretrainedConfig):
215
- """Config class for the Pico Decoder HuggingFace wrapper."""
216
-
 
217
  model_type = "pico_decoder"
218
 
219
- # Defaults match generate_configs.py MODEL_BASE exactly.
220
- # The guard on attention_n_kv_heads fixes ZeroDivisionError when the field
221
- # is missing or null in config.json from older checkpoints.
222
  def __init__(
223
  self,
224
  n_layers: int = 14,
@@ -234,20 +197,21 @@ class PicoDecoderHFConfig(PretrainedConfig):
234
  dropout: float = 0.1,
235
  **kwargs,
236
  ):
 
237
  if not attention_n_kv_heads:
238
  attention_n_kv_heads = attention_n_heads
239
  super().__init__(**kwargs)
240
- self.n_layers = n_layers
241
- self.d_model = d_model
242
- self.vocab_size = vocab_size
243
- self.attention_n_heads = attention_n_heads
244
- self.attention_n_kv_heads = attention_n_kv_heads
245
- self.max_seq_len = max_seq_len
246
- self.batch_size = batch_size
247
- self.position_emb_theta = position_emb_theta
248
- self.activation_hidden_dim = activation_hidden_dim
249
- self.norm_eps = norm_eps
250
- self.dropout = dropout
251
 
252
  @classmethod
253
  def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PicoDecoderHFConfig":
@@ -264,22 +228,40 @@ class PicoDecoderHFConfig(PretrainedConfig):
264
 
265
 
266
  class PicoDecoderHF(PreTrainedModel):
 
 
267
  """
268
- HuggingFace wrapper for BeetleLM PicoDecoder.
269
- Usage: AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)
270
- """
271
- config_class = PicoDecoderHFConfig
272
- _no_split_modules = ["PicoBlock", "Attention", "SwiGLU", "RMSNorm"]
273
- _tied_weights_keys = []
274
  def __init__(self, config: PicoDecoderHFConfig):
275
  super().__init__(config)
276
  self.pico_decoder = PicoDecoder(config)
277
 
278
- def forward(self, input_ids, past_key_values=None, use_cache=False, **kwargs):
279
- logits, past_key_values = self.pico_decoder(input_ids, past_key_values, use_cache)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
  if use_cache:
281
- return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
282
- return CausalLMOutput(logits=logits)
 
 
 
283
 
284
 
285
  PicoDecoderHFConfig.register_for_auto_class()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
 
2
  from dataclasses import asdict
3
  from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
 
26
  return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
27
 
28
  def forward(self, x):
29
+ return self._norm(x.float()).type_as(x) * self.weight
 
30
 
31
 
32
  class RoPE(nn.Module):
 
36
  super().__init__()
37
  self.theta = config.position_emb_theta
38
  self.dim = config.d_model // config.attention_n_heads
 
39
  if RoPE._freqs_cis_tensor is None:
40
+ RoPE._freqs_cis_tensor = self._setup_freqs_cis(
41
+ config.max_seq_len, self.theta, self.dim
42
+ )
43
+ self.register_buffer("_freqs_cis", RoPE._freqs_cis_tensor, persistent=False)
44
 
45
  @classmethod
46
  def _setup_freqs_cis(cls, seq_len, theta, dim):
47
  _freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
48
+ freqs = torch.outer(torch.arange(seq_len), _freqs)
 
49
  return torch.polar(torch.ones_like(freqs), freqs)
50
 
51
  def get_freqs_cis(self, input_shape, start_pos, end_pos):
52
+ _f = self._freqs_cis[start_pos:end_pos]
53
  ndim = len(input_shape)
54
  assert 0 <= 1 < ndim
55
+ assert _f.shape == (input_shape[1], input_shape[-1])
56
  shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(input_shape)]
57
+ return _f.view(*shape)
58
 
59
  def forward(self, queries, keys, start_pos=0):
60
+ q_ = torch.view_as_complex(queries.float().reshape(*queries.shape[:-1], -1, 2))
61
+ k_ = torch.view_as_complex(keys.float().reshape(*keys.shape[:-1], -1, 2))
62
+ fc = self.get_freqs_cis(q_.shape, start_pos, start_pos + q_.shape[1])
63
+ return (
64
+ torch.view_as_real(q_ * fc).flatten(3).type_as(queries),
65
+ torch.view_as_real(k_ * fc).flatten(3).type_as(keys),
66
+ )
 
 
67
 
68
 
69
  class Attention(nn.Module):
70
  def __init__(self, config):
71
  super().__init__()
72
+ self.n_heads = config.attention_n_heads
73
+ self.n_kv_heads = config.attention_n_kv_heads
74
  self.batch_size = config.batch_size
75
  self.max_seq_len = config.max_seq_len
76
+ d = config.d_model
77
+ self.head_dim = d // self.n_heads
78
+ self.n_rep = self.n_heads // self.n_kv_heads
79
+ self.q_proj = nn.Linear(d, self.n_heads * self.head_dim, bias=False)
80
+ self.k_proj = nn.Linear(d, self.n_kv_heads * self.head_dim, bias=False)
81
+ self.v_proj = nn.Linear(d, self.n_kv_heads * self.head_dim, bias=False)
82
+ self.o_proj = nn.Linear(self.n_heads * self.head_dim, d, bias=False)
83
+ self.rope = RoPE(config)
84
 
85
  def forward(self, input, mask=None, past_key_values=None, use_cache=False):
86
  bsz, seq_len, _ = input.shape
87
+ queries = self.q_proj(input).view(bsz, seq_len, self.n_heads, self.head_dim)
88
+ keys = self.k_proj(input).view(bsz, seq_len, self.n_kv_heads, self.head_dim)
89
+ values = self.v_proj(input).view(bsz, seq_len, self.n_kv_heads, self.head_dim)
 
90
  start_pos = past_key_values[0].shape[1] if past_key_values is not None else 0
91
  queries, keys = self.rope(queries, keys, start_pos)
92
  if past_key_values is not None:
93
  keys = torch.cat([past_key_values[0], keys], dim=1)
94
  values = torch.cat([past_key_values[1], values], dim=1)
95
+ cached_keys = keys if use_cache else None
96
+ cached_values = values if use_cache else None
 
 
97
  queries = queries.transpose(1, 2)
98
  keys = keys.transpose(1, 2)
99
  values = values.transpose(1, 2)
 
102
  keys = keys.repeat_interleave(self.n_rep, dim=-3)
103
  values = values.repeat_interleave(self.n_rep, dim=-3)
104
  apply_gqa = False
105
+ with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH]):
106
+ out = F.scaled_dot_product_attention(
 
107
  queries.contiguous(), keys.contiguous(), values.contiguous(),
108
+ attn_mask=mask.to(queries.dtype) if mask is not None else None,
109
  enable_gqa=apply_gqa,
110
  )
111
+ out = out.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
112
+ return self.o_proj(out), (cached_keys, cached_values)
113
 
114
 
115
  class SwiGLU(nn.Module):
 
132
  self.swiglu_norm = RMSNorm(config)
133
 
134
  def forward(self, input, mask=None, past_key_values=None, use_cache=False):
135
+ attn_out, cached = self.attention(
136
  self.attention_norm(input), mask=mask,
137
  past_key_values=past_key_values, use_cache=use_cache,
138
  )
139
+ h = input + attn_out
140
+ return h + self.swiglu(self.swiglu_norm(h)), cached
 
141
 
142
 
143
  class PicoDecoder(nn.Module):
144
  def __init__(self, model_config):
145
  super().__init__()
146
  self.config = model_config
147
+ self.embedding_proj = nn.Embedding(model_config.vocab_size, model_config.d_model)
148
  self.layers = nn.ModuleList(
149
+ [PicoDecoderBlock(model_config) for _ in range(model_config.n_layers)]
150
  )
151
+ self.output_norm = RMSNorm(model_config)
152
+ self.de_embedding_proj = nn.Linear(model_config.d_model, model_config.vocab_size, bias=False)
153
 
154
  def convert_to_hf_model(self):
155
  hf_config = PicoDecoderHFConfig.from_dataclass(self.config)
 
163
  start_pos = 0 if past_key_values is None else past_key_values[0][0].shape[1]
164
  mask = None
165
  if seq_len > 1:
166
+ mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
 
167
  if past_key_values is not None:
168
  mask = torch.hstack([torch.zeros((seq_len, start_pos)), mask])
169
  mask = mask.to(h.device)
 
173
  h, layer_cached = layer(h, mask=mask, past_key_values=layer_past, use_cache=use_cache)
174
  if use_cache:
175
  cached_key_values += (layer_cached,)
176
+ return self.de_embedding_proj(self.output_norm(h)).float(), cached_key_values
 
 
177
 
178
 
179
  class PicoDecoderHFConfig(PretrainedConfig):
180
+ """HuggingFace config for BeetleLM PicoDecoder.
181
+ Defaults match generate_configs.py MODEL_BASE exactly.
182
+ """
183
  model_type = "pico_decoder"
184
 
 
 
 
185
  def __init__(
186
  self,
187
  n_layers: int = 14,
 
197
  dropout: float = 0.1,
198
  **kwargs,
199
  ):
200
+ # Fix: guard against None/0/missing attention_n_kv_heads in old config.json
201
  if not attention_n_kv_heads:
202
  attention_n_kv_heads = attention_n_heads
203
  super().__init__(**kwargs)
204
+ self.n_layers = n_layers
205
+ self.d_model = d_model
206
+ self.vocab_size = vocab_size
207
+ self.attention_n_heads = attention_n_heads
208
+ self.attention_n_kv_heads = attention_n_kv_heads
209
+ self.max_seq_len = max_seq_len
210
+ self.batch_size = batch_size
211
+ self.position_emb_theta = position_emb_theta
212
+ self.activation_hidden_dim = activation_hidden_dim
213
+ self.norm_eps = norm_eps
214
+ self.dropout = dropout
215
 
216
  @classmethod
217
  def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PicoDecoderHFConfig":
 
228
 
229
 
230
  class PicoDecoderHF(PreTrainedModel):
231
+ """HuggingFace wrapper for BeetleLM PicoDecoder.
232
+ Load with: AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)
233
  """
234
+ config_class = PicoDecoderHFConfig
235
+ _no_split_modules = ["PicoDecoderBlock", "Attention", "SwiGLU", "RMSNorm"]
236
+ _tied_weights_keys = [] # Fix: required by transformers >= 4.38
237
+
 
 
238
  def __init__(self, config: PicoDecoderHFConfig):
239
  super().__init__(config)
240
  self.pico_decoder = PicoDecoder(config)
241
 
242
+ def get_input_embeddings(self):
243
+ return self.pico_decoder.embedding_proj
244
+
245
+ def set_input_embeddings(self, value):
246
+ self.pico_decoder.embedding_proj = value
247
+
248
+ def forward(self, input_ids=None, past_key_values=None, use_cache=False,
249
+ labels=None, **kwargs):
250
+ logits, new_past = self.pico_decoder(input_ids, past_key_values, use_cache)
251
+ loss = None
252
+ if labels is not None:
253
+ shift_logits = logits[:, :-1].contiguous()
254
+ shift_labels = labels[:, 1:].contiguous().clamp(0, self.config.vocab_size - 1)
255
+ loss = F.cross_entropy(
256
+ shift_logits.view(-1, self.config.vocab_size),
257
+ shift_labels.view(-1),
258
+ )
259
  if use_cache:
260
+ return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=new_past)
261
+ return CausalLMOutput(loss=loss, logits=logits)
262
+
263
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
264
+ return {"input_ids": input_ids, "past_key_values": past_key_values, "use_cache": True}
265
 
266
 
267
  PicoDecoderHFConfig.register_for_auto_class()