Fix: vocab_size=32000 (BPE base from model.vocab); top-level weights; all compat fixes
Browse files — pico_decoder.py: +47, −20
pico_decoder.py
CHANGED
|
@@ -25,29 +25,43 @@ class RMSNorm(torch.nn.Module):
|
|
| 25 |
|
| 26 |
|
| 27 |
class RoPE(nn.Module):
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
def __init__(self, config):
|
| 30 |
super().__init__()
|
| 31 |
-
self.theta
|
| 32 |
-
self.dim
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
ndim = len(input_shape)
|
| 45 |
assert 0 <= 1 < ndim and _f.shape == (input_shape[1], input_shape[-1])
|
| 46 |
-
return _f.view(*[d if i==1 or i==ndim-1 else 1
|
|
|
|
|
|
|
| 47 |
def forward(self, queries, keys, start_pos=0):
|
|
|
|
| 48 |
q_ = torch.view_as_complex(queries.float().reshape(*queries.shape[:-1], -1, 2))
|
| 49 |
k_ = torch.view_as_complex(keys.float().reshape(*keys.shape[:-1], -1, 2))
|
| 50 |
-
fc
|
| 51 |
return (torch.view_as_real(q_ * fc).flatten(3).type_as(queries),
|
| 52 |
torch.view_as_real(k_ * fc).flatten(3).type_as(keys))
|
| 53 |
|
|
@@ -196,9 +210,10 @@ class PicoDecoderHF(PreTrainedModel):
|
|
| 196 |
"""
|
| 197 |
HuggingFace wrapper for BeetleLM PicoDecoder.
|
| 198 |
Usage: AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)
|
|
|
|
| 199 |
"""
|
| 200 |
config_class = PicoDecoderHFConfig
|
| 201 |
-
_no_split_modules = ["PicoDecoderBlock"
|
| 202 |
_tied_weights_keys = []
|
| 203 |
|
| 204 |
@property
|
|
@@ -212,6 +227,19 @@ class PicoDecoderHF(PreTrainedModel):
|
|
| 212 |
[PicoDecoderBlock(config) for _ in range(config.n_layers)])
|
| 213 |
self.output_norm = RMSNorm(config)
|
| 214 |
self.de_embedding_proj = nn.Linear(config.d_model, config.vocab_size, bias=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
|
| 216 |
def get_input_embeddings(self): return self.embedding_proj
|
| 217 |
def set_input_embeddings(self, value): self.embedding_proj = value
|
|
@@ -223,11 +251,10 @@ class PicoDecoderHF(PreTrainedModel):
|
|
| 223 |
start_pos = 0 if past_key_values is None else past_key_values[0][0].shape[1]
|
| 224 |
mask = None
|
| 225 |
if seq_len > 1:
|
| 226 |
-
mask = torch.full((seq_len, seq_len), float("-inf"))
|
| 227 |
mask = torch.triu(mask, diagonal=1)
|
| 228 |
if past_key_values is not None:
|
| 229 |
-
mask = torch.hstack([torch.zeros((seq_len, start_pos)), mask])
|
| 230 |
-
mask = mask.to(h.device)
|
| 231 |
cached_key_values = () if use_cache else None
|
| 232 |
for idx, layer in enumerate(self.layers):
|
| 233 |
layer_past = past_key_values[idx] if past_key_values is not None else None
|
|
|
|
| 25 |
|
| 26 |
|
| 27 |
class RoPE(nn.Module):
    """Rotary Position Embedding (RoPE).

    The complex rotation table (``freqs_cis``) is computed lazily on first
    use and cached per device. The cache is a plain dict — NOT a registered
    buffer — so it never touches the meta device when HF loads the model
    with ``low_cpu_mem_usage=True``.
    """

    def __init__(self, config):
        super().__init__()
        # Base of the geometric frequency progression (e.g. 10000.0).
        self.theta = config.position_emb_theta
        # Per-head dimension; rotation acts on pairs, so this must be even.
        self.dim = config.d_model // config.attention_n_heads
        self.max_seq = config.max_seq_len
        # Per-device cache of the complex rotation table.
        self._cache: dict[torch.device, torch.Tensor] = {}

    def _get_freqs_cis(self, device: torch.device, min_len: int = 0) -> torch.Tensor:
        """Return the cached complex rotation table on *device*.

        The table covers at least ``max(self.max_seq, min_len)`` positions.
        Previously the table was fixed at ``max_seq_len``, so any longer
        request was silently truncated by the slice in ``get_freqs_cis`` and
        then failed the shape assert; now the cache grows on demand.
        """
        cached = self._cache.get(device)
        need = max(self.max_seq, min_len)
        if cached is None or cached.shape[0] < need:
            freqs = 1.0 / (
                self.theta
                ** (torch.arange(0, self.dim, 2, device=device).float() / self.dim)
            )
            t = torch.arange(need, device=device)
            angles = torch.outer(t, freqs)
            cached = torch.polar(torch.ones_like(angles), angles)
            self._cache[device] = cached
        return cached

    def get_freqs_cis(self, input_shape, start_pos, end_pos, device):
        """Slice positions ``[start_pos, end_pos)`` and reshape for broadcast.

        Returns the slice viewed with every dim of *input_shape* collapsed
        to 1 except the sequence dim (index 1) and the last dim.
        """
        _f = self._get_freqs_cis(device, end_pos)[start_pos:end_pos]
        ndim = len(input_shape)
        # Same condition as the original `0 <= 1 < ndim`, spelled readably,
        # with a diagnostic message instead of a bare assert.
        assert ndim > 1 and _f.shape == (input_shape[1], input_shape[-1]), (
            f"freqs_cis shape {tuple(_f.shape)} incompatible with input shape "
            f"{tuple(input_shape)}"
        )
        return _f.view(*[d if i == 1 or i == ndim - 1 else 1
                         for i, d in enumerate(input_shape)])

    def forward(self, queries, keys, start_pos=0):
        """Apply rotary embeddings to *queries* and *keys*.

        Args:
            queries: 4-D tensor (batch, seq, n_heads, head_dim), head_dim even
                (the trailing ``flatten(3)`` and pairwise reshape require this).
            keys: same layout as *queries*.
            start_pos: absolute position of the first token — nonzero during
                KV-cached decoding when only the new suffix is passed in.

        Returns:
            Tuple of rotated ``(queries, keys)`` with the input shapes/dtypes.
        """
        device = queries.device
        # Pair up the last dim and view as complex: rotating a pair is a
        # single complex multiplication.
        q_ = torch.view_as_complex(queries.float().reshape(*queries.shape[:-1], -1, 2))
        k_ = torch.view_as_complex(keys.float().reshape(*keys.shape[:-1], -1, 2))
        fc = self.get_freqs_cis(q_.shape, start_pos, start_pos + q_.shape[1], device)
        return (torch.view_as_real(q_ * fc).flatten(3).type_as(queries),
                torch.view_as_real(k_ * fc).flatten(3).type_as(keys))
|
| 67 |
|
|
|
|
| 210 |
"""
|
| 211 |
HuggingFace wrapper for BeetleLM PicoDecoder.
|
| 212 |
Usage: AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)
|
| 213 |
+
Works with CPU, CUDA (A100, etc.), and MPS out of the box.
|
| 214 |
"""
|
| 215 |
config_class = PicoDecoderHFConfig
|
| 216 |
+
_no_split_modules = ["PicoDecoderBlock"]
|
| 217 |
_tied_weights_keys = []
|
| 218 |
|
| 219 |
@property
|
|
|
|
| 227 |
[PicoDecoderBlock(config) for _ in range(config.n_layers)])
|
| 228 |
self.output_norm = RMSNorm(config)
|
| 229 |
self.de_embedding_proj = nn.Linear(config.d_model, config.vocab_size, bias=False)
|
| 230 |
+
# Required: lets HF finalize weight init and meta-device materialization
|
| 231 |
+
self.post_init()
|
| 232 |
+
|
| 233 |
+
# Required for low_cpu_mem_usage / Accelerate device-dispatch to work
|
| 234 |
+
def _init_weights(self, module):
|
| 235 |
+
if isinstance(module, nn.Linear):
|
| 236 |
+
nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| 237 |
+
if module.bias is not None:
|
| 238 |
+
nn.init.zeros_(module.bias)
|
| 239 |
+
elif isinstance(module, nn.Embedding):
|
| 240 |
+
nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| 241 |
+
elif isinstance(module, RMSNorm):
|
| 242 |
+
nn.init.ones_(module.weight)
|
| 243 |
|
| 244 |
def get_input_embeddings(self): return self.embedding_proj
|
| 245 |
def set_input_embeddings(self, value): self.embedding_proj = value
|
|
|
|
| 251 |
start_pos = 0 if past_key_values is None else past_key_values[0][0].shape[1]
|
| 252 |
mask = None
|
| 253 |
if seq_len > 1:
|
| 254 |
+
mask = torch.full((seq_len, seq_len), float("-inf"), device=h.device)
|
| 255 |
mask = torch.triu(mask, diagonal=1)
|
| 256 |
if past_key_values is not None:
|
| 257 |
+
mask = torch.hstack([torch.zeros((seq_len, start_pos), device=h.device), mask])
|
|
|
|
| 258 |
cached_key_values = () if use_cache else None
|
| 259 |
for idx, layer in enumerate(self.layers):
|
| 260 |
layer_past = past_key_values[idx] if past_key_values is not None else None
|