suchirsalhan committed (verified)
Commit ab43fee · Parent(s): 3816bbf

Fix: vocab_size=32000 (BPE base from model.vocab); top-level weights; all compat fixes
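The three fixes in the message can be sanity-checked after loading. Below is a minimal sketch that is not part of the commit: the repo id is a placeholder, and it only asserts what the message claims, i.e. that the config now reports the 32000-entry BPE vocabulary and that the output projection agrees with it.

# Hypothetical post-commit sanity check; "user/beetlelm-pico" is a placeholder repo id.
from transformers import AutoConfig, AutoModelForCausalLM

repo = "user/beetlelm-pico"  # placeholder

config = AutoConfig.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)

# vocab_size should now come from the BPE base (model.vocab), i.e. 32000 ...
assert config.vocab_size == 32000
# ... and the output head should match it.
assert model.de_embedding_proj.weight.shape[0] == config.vocab_size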

Files changed (1): pico_decoder.py (+47 -20)
pico_decoder.py CHANGED
@@ -25,29 +25,43 @@ class RMSNorm(torch.nn.Module):


class RoPE(nn.Module):
-    _freqs_cis_tensor = None
+    """
+    Rotary Position Embedding.
+    freqs_cis is computed lazily on first use and cached per-device,
+    avoiding meta-tensor issues when HF loads with low_cpu_mem_usage=True.
+    """
    def __init__(self, config):
        super().__init__()
-        self.theta = config.position_emb_theta
-        self.dim = config.d_model // config.attention_n_heads
-        if RoPE._freqs_cis_tensor is None:
-            RoPE._freqs_cis_tensor = self._setup_freqs_cis(
-                config.max_seq_len, self.theta, self.dim)
-        self.register_buffer("_freqs_cis", RoPE._freqs_cis_tensor, persistent=False)
-    @classmethod
-    def _setup_freqs_cis(cls, seq_len, theta, dim):
-        _freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
-        freqs = torch.outer(torch.arange(seq_len), _freqs)
-        return torch.polar(torch.ones_like(freqs), freqs)
-    def get_freqs_cis(self, input_shape, start_pos, end_pos):
-        _f = self._freqs_cis[start_pos:end_pos]
+        self.theta = config.position_emb_theta
+        self.dim = config.d_model // config.attention_n_heads
+        self.max_seq = config.max_seq_len
+        # NOT a buffer — plain dict so it never touches the meta device
+        self._cache: Dict[torch.device, torch.Tensor] = {}
+
+    def _get_freqs_cis(self, device: torch.device) -> torch.Tensor:
+        if device not in self._cache:
+            freqs = 1.0 / (
+                self.theta ** (
+                    torch.arange(0, self.dim, 2, device=device).float() / self.dim
+                )
+            )
+            t = torch.arange(self.max_seq, device=device)
+            freqs = torch.outer(t, freqs)
+            self._cache[device] = torch.polar(torch.ones_like(freqs), freqs)
+        return self._cache[device]
+
+    def get_freqs_cis(self, input_shape, start_pos, end_pos, device):
+        _f = self._get_freqs_cis(device)[start_pos:end_pos]
        ndim = len(input_shape)
        assert 0 <= 1 < ndim and _f.shape == (input_shape[1], input_shape[-1])
-        return _f.view(*[d if i==1 or i==ndim-1 else 1 for i,d in enumerate(input_shape)])
+        return _f.view(*[d if i == 1 or i == ndim - 1 else 1
+                         for i, d in enumerate(input_shape)])
+
    def forward(self, queries, keys, start_pos=0):
+        device = queries.device
        q_ = torch.view_as_complex(queries.float().reshape(*queries.shape[:-1], -1, 2))
        k_ = torch.view_as_complex(keys.float().reshape(*keys.shape[:-1], -1, 2))
-        fc = self.get_freqs_cis(q_.shape, start_pos, start_pos + q_.shape[1])
+        fc = self.get_freqs_cis(q_.shape, start_pos, start_pos + q_.shape[1], device)
        return (torch.view_as_real(q_ * fc).flatten(3).type_as(queries),
                torch.view_as_real(k_ * fc).flatten(3).type_as(keys))
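The only numerical content of this hunk is the freqs_cis table; the refactor changes where and when it is built (lazily, per device, in a plain dict instead of a registered buffer), not what it contains. A standalone sketch with toy sizes, not the model's real config values, showing the old and new computations agree:

# Illustrative check (standalone, not part of pico_decoder.py) that the new
# per-device computation reproduces the old class-level table.
import torch

theta, dim, max_seq = 10000.0, 64, 128  # assumed toy sizes, not the model's

# Old path: computed once at __init__ and stored as a buffer.
_freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
old_angles = torch.outer(torch.arange(max_seq), _freqs)
old = torch.polar(torch.ones_like(old_angles), old_angles)

# New path: computed lazily per device inside _get_freqs_cis.
device = torch.device("cpu")
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device).float() / dim))
t = torch.arange(max_seq, device=device)
new_angles = torch.outer(t, freqs)
new = torch.polar(torch.ones_like(new_angles), new_angles)

assert torch.allclose(old, new)  # same rotary table, just built on demand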
 
@@ -196,9 +210,10 @@ class PicoDecoderHF(PreTrainedModel):
    """
    HuggingFace wrapper for BeetleLM PicoDecoder.
    Usage: AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)
+    Works with CPU, CUDA (A100, etc.), and MPS out of the box.
    """
    config_class = PicoDecoderHFConfig
-    _no_split_modules = ["PicoDecoderBlock", "Attention", "SwiGLU", "RMSNorm"]
+    _no_split_modules = ["PicoDecoderBlock"]
    _tied_weights_keys = []

    @property
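Trimming _no_split_modules to just the transformer block matters when the model is dispatched across devices: Accelerate's device_map="auto" keeps every module named here on a single device, so the block becomes the unit of placement. A usage sketch, with a placeholder repo id and assuming the accelerate package is installed:

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "user/beetlelm-pico",   # placeholder
    trust_remote_code=True,
    device_map="auto",      # each PicoDecoderBlock stays whole on one device
    torch_dtype="auto",
)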
@@ -212,6 +227,19 @@ class PicoDecoderHF(PreTrainedModel):
            [PicoDecoderBlock(config) for _ in range(config.n_layers)])
        self.output_norm = RMSNorm(config)
        self.de_embedding_proj = nn.Linear(config.d_model, config.vocab_size, bias=False)
+        # Required: lets HF finalize weight init and meta-device materialization
+        self.post_init()
+
+    # Required for low_cpu_mem_usage / Accelerate device-dispatch to work
+    def _init_weights(self, module):
+        if isinstance(module, nn.Linear):
+            nn.init.normal_(module.weight, mean=0.0, std=0.02)
+            if module.bias is not None:
+                nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Embedding):
+            nn.init.normal_(module.weight, mean=0.0, std=0.02)
+        elif isinstance(module, RMSNorm):
+            nn.init.ones_(module.weight)

    def get_input_embeddings(self): return self.embedding_proj
    def set_input_embeddings(self, value): self.embedding_proj = value
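post_init() and _init_weights are what the low_cpu_mem_usage=True path relies on: the model skeleton is first created on the meta device, checkpoint tensors are copied in, and _init_weights fills anything the checkpoint does not cover. A minimal sketch, again with a placeholder repo id, verifying nothing is left on the meta device after loading:

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "user/beetlelm-pico",    # placeholder
    trust_remote_code=True,
    low_cpu_mem_usage=True,  # builds on the meta device first, then materializes
)

# Every parameter should be a real tensor now, not a meta tensor.
assert all(p.device != torch.device("meta") for p in model.parameters())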
@@ -223,11 +251,10 @@ class PicoDecoderHF(PreTrainedModel):
        start_pos = 0 if past_key_values is None else past_key_values[0][0].shape[1]
        mask = None
        if seq_len > 1:
-            mask = torch.full((seq_len, seq_len), float("-inf"))
+            mask = torch.full((seq_len, seq_len), float("-inf"), device=h.device)
            mask = torch.triu(mask, diagonal=1)
            if past_key_values is not None:
-                mask = torch.hstack([torch.zeros((seq_len, start_pos)), mask])
-            mask = mask.to(h.device)
+                mask = torch.hstack([torch.zeros((seq_len, start_pos), device=h.device), mask])
        cached_key_values = () if use_cache else None
        for idx, layer in enumerate(self.layers):
            layer_past = past_key_values[idx] if past_key_values is not None else None
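The mask hunk only moves tensor creation onto h.device; the shape logic is untouched. A toy trace of what the mask looks like during incremental decoding, with made-up sizes (4 new tokens on top of 6 cached positions):

import torch

device = torch.device("cpu")   # stands in for h.device
seq_len, start_pos = 4, 6      # toy sizes, not from the model

mask = torch.full((seq_len, seq_len), float("-inf"), device=device)
mask = torch.triu(mask, diagonal=1)  # causal among the new tokens
# With a KV cache, prepend zero columns so cached positions stay fully visible.
mask = torch.hstack([torch.zeros((seq_len, start_pos), device=device), mask])
print(mask.shape)  # torch.Size([4, 10])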
 