suchirsalhan committed
Commit d141bde · verified · 1 parent: 3bef99a

Fix: vocab_size=32000 (BPE base from model.vocab); top-level weights; all compat fixes

Files changed (2):
  1. config.json +23 -0
  2. pico_decoder.py +288 -0
config.json ADDED
@@ -0,0 +1,23 @@
+ {
+   "architectures": [
+     "PicoDecoderHF"
+   ],
+   "model_type": "pico_decoder",
+   "auto_map": {
+     "AutoConfig": "pico_decoder.PicoDecoderHFConfig",
+     "AutoModelForCausalLM": "pico_decoder.PicoDecoderHF"
+   },
+   "n_layers": 14,
+   "d_model": 768,
+   "attention_n_heads": 12,
+   "attention_n_kv_heads": 1,
+   "max_seq_len": 512,
+   "batch_size": 64,
+   "position_emb_theta": 10000.0,
+   "activation_hidden_dim": 3072,
+   "norm_eps": 1e-05,
+   "dropout": 0.1,
+   "torch_dtype": "float32",
+   "transformers_version": "4.48.3",
+   "vocab_size": 32000
+ }
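
The auto_map block above wires AutoConfig and AutoModelForCausalLM to the classes defined in pico_decoder.py, so the config resolves through trust_remote_code. A minimal sketch of loading it; the repo id is a placeholder, not part of this commit:

    from transformers import AutoConfig

    # "<user>/<repo>" is a placeholder for the actual Hub repo id
    config = AutoConfig.from_pretrained("<user>/<repo>", trust_remote_code=True)
    print(config.vocab_size)  # 32000, the value this commit fixes
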
pico_decoder.py ADDED
@@ -0,0 +1,288 @@
+ from dataclasses import asdict
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn.attention import SDPBackend, sdpa_kernel
+ from transformers import PretrainedConfig, PreTrainedModel
+ from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast
+ try:
+     if TYPE_CHECKING:
+         from src.config import ModelConfig
+ except ImportError:
+     pass
+
+
+ class RMSNorm(torch.nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.eps = config.norm_eps
+         self.weight = nn.Parameter(torch.ones(config.d_model))
+     def _norm(self, x):
+         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+     def forward(self, x):
+         return self._norm(x.float()).type_as(x) * self.weight
+
+
+ class RoPE(nn.Module):
+     """
+     Rotary Position Embedding.
+     freqs_cis is computed lazily on first use and cached per-device,
+     avoiding meta-tensor issues when HF loads with low_cpu_mem_usage=True.
+     """
+     def __init__(self, config):
+         super().__init__()
+         self.theta = config.position_emb_theta
+         self.dim = config.d_model // config.attention_n_heads
+         self.max_seq = config.max_seq_len
+         # NOT a buffer: a plain dict, so it never touches the meta device
+         self._cache: Dict[torch.device, torch.Tensor] = {}
+
+     def _get_freqs_cis(self, device: torch.device) -> torch.Tensor:
+         if device not in self._cache:
+             freqs = 1.0 / (
+                 self.theta ** (
+                     torch.arange(0, self.dim, 2, device=device).float() / self.dim
+                 )
+             )
+             t = torch.arange(self.max_seq, device=device)
+             freqs = torch.outer(t, freqs)
+             self._cache[device] = torch.polar(torch.ones_like(freqs), freqs)
+         return self._cache[device]
+
+     def get_freqs_cis(self, input_shape, start_pos, end_pos, device):
+         _f = self._get_freqs_cis(device)[start_pos:end_pos]
+         ndim = len(input_shape)
+         assert 0 <= 1 < ndim and _f.shape == (input_shape[1], input_shape[-1])
+         return _f.view(*[d if i == 1 or i == ndim - 1 else 1
+                          for i, d in enumerate(input_shape)])
+
+     def forward(self, queries, keys, start_pos=0):
+         device = queries.device
+         q_ = torch.view_as_complex(queries.float().reshape(*queries.shape[:-1], -1, 2))
+         k_ = torch.view_as_complex(keys.float().reshape(*keys.shape[:-1], -1, 2))
+         fc = self.get_freqs_cis(q_.shape, start_pos, start_pos + q_.shape[1], device)
+         return (torch.view_as_real(q_ * fc).flatten(3).type_as(queries),
+                 torch.view_as_real(k_ * fc).flatten(3).type_as(keys))
+
+
+ class Attention(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.n_heads = config.attention_n_heads
+         self.n_kv_heads = config.attention_n_kv_heads
+         self.batch_size = config.batch_size
+         self.max_seq_len = config.max_seq_len
+         d = config.d_model
+         self.head_dim = d // self.n_heads
+         self.n_rep = self.n_heads // self.n_kv_heads
+         self.q_proj = nn.Linear(d, self.n_heads * self.head_dim, bias=False)
+         self.k_proj = nn.Linear(d, self.n_kv_heads * self.head_dim, bias=False)
+         self.v_proj = nn.Linear(d, self.n_kv_heads * self.head_dim, bias=False)
+         self.o_proj = nn.Linear(self.n_heads * self.head_dim, d, bias=False)
+         self.rope = RoPE(config)
+     def forward(self, input, mask=None, past_key_values=None, use_cache=False):
+         bsz, seq_len, _ = input.shape
+         queries = self.q_proj(input).view(bsz, seq_len, self.n_heads, self.head_dim)
+         keys = self.k_proj(input).view(bsz, seq_len, self.n_kv_heads, self.head_dim)
+         values = self.v_proj(input).view(bsz, seq_len, self.n_kv_heads, self.head_dim)
+         start_pos = past_key_values[0].shape[1] if past_key_values is not None else 0
+         queries, keys = self.rope(queries, keys, start_pos)
+         if past_key_values is not None:
+             keys = torch.cat([past_key_values[0], keys], dim=1)
+             values = torch.cat([past_key_values[1], values], dim=1)
+         cached_keys = keys if use_cache else None
+         cached_values = values if use_cache else None
+         queries = queries.transpose(1, 2)
+         keys = keys.transpose(1, 2)
+         values = values.transpose(1, 2)
+         apply_gqa = self.n_rep > 1
+         # enable_gqa is unsupported on MPS, so expand the KV heads manually there
+         if apply_gqa and queries.device.type == "mps":
+             keys = keys.repeat_interleave(self.n_rep, dim=-3)
+             values = values.repeat_interleave(self.n_rep, dim=-3)
+             apply_gqa = False
+         attn_mask = mask.to(queries.dtype) if mask is not None else None
+         # restrict SDPA to the cuDNN and math backends (math is the portable fallback)
+         with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH]):
+             attn_output = F.scaled_dot_product_attention(
+                 queries.contiguous(), keys.contiguous(), values.contiguous(),
+                 attn_mask=attn_mask, enable_gqa=apply_gqa,
+             )
+         attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
+         return self.o_proj(attn_output), (cached_keys, cached_values)
+
+
+ class SwiGLU(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.w_0 = nn.Linear(config.d_model, config.activation_hidden_dim, bias=False)
+         self.w_1 = nn.Linear(config.d_model, config.activation_hidden_dim, bias=False)
+         self.w_2 = nn.Linear(config.activation_hidden_dim, config.d_model, bias=False)
+     def forward(self, x):
+         return self.w_2(F.silu(self.w_0(x)) * self.w_1(x))
+
+
+ class PicoDecoderBlock(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.attention = Attention(config)
+         self.swiglu = SwiGLU(config)
+         self.attention_norm = RMSNorm(config)
+         self.swiglu_norm = RMSNorm(config)
+     def forward(self, input, mask=None, past_key_values=None, use_cache=False):
+         attention_output, cached_key_values = self.attention(
+             self.attention_norm(input), mask=mask,
+             past_key_values=past_key_values, use_cache=use_cache)
+         h = input + attention_output
+         return h + self.swiglu(self.swiglu_norm(h)), cached_key_values
+
+
+ class PicoDecoder(nn.Module):
+     def __init__(self, model_config):
+         super().__init__()
+         self.config = model_config
+         self.embedding_proj = nn.Embedding(model_config.vocab_size, model_config.d_model)
+         self.layers = nn.ModuleList(
+             [PicoDecoderBlock(model_config) for _ in range(model_config.n_layers)])
+         self.output_norm = RMSNorm(model_config)
+         self.de_embedding_proj = nn.Linear(
+             model_config.d_model, model_config.vocab_size, bias=False)
+     def convert_to_hf_model(self):
+         hf = PicoDecoderHF(PicoDecoderHFConfig.from_dataclass(self.config))
+         hf.load_state_dict(self.state_dict())
+         return hf
+     def forward(self, input_ids, past_key_values=None, use_cache=False):
+         seq_len = input_ids.shape[-1]
+         h = self.embedding_proj(input_ids)
+         start_pos = 0 if past_key_values is None else past_key_values[0][0].shape[1]
+         mask = None
+         if seq_len > 1:
+             mask = torch.full((seq_len, seq_len), float("-inf"))
+             mask = torch.triu(mask, diagonal=1)
+             if past_key_values is not None:
+                 mask = torch.hstack([torch.zeros((seq_len, start_pos)), mask])
+             mask = mask.to(h.device)
+         cached_key_values = () if use_cache else None
+         for idx, layer in enumerate(self.layers):
+             layer_past = past_key_values[idx] if past_key_values is not None else None
+             h, layer_cached = layer(
+                 h, mask=mask, past_key_values=layer_past, use_cache=use_cache)
+             if use_cache:
+                 cached_key_values += (layer_cached,)
+         return self.de_embedding_proj(self.output_norm(h)).float(), cached_key_values
+
+
+ class PicoDecoderHFConfig(PretrainedConfig):
+     model_type = "pico_decoder"
+     def __init__(self,
+                  n_layers=14, d_model=768, vocab_size=32768,
+                  attention_n_heads=12, attention_n_kv_heads=1,
+                  max_seq_len=512, batch_size=64, position_emb_theta=10000.0,
+                  activation_hidden_dim=3072, norm_eps=1e-5, dropout=0.1,
+                  **kwargs):
+         if not attention_n_kv_heads:
+             attention_n_kv_heads = attention_n_heads
+         super().__init__(**kwargs)
+         self.n_layers = n_layers
+         self.d_model = d_model
+         self.vocab_size = vocab_size
+         self.attention_n_heads = attention_n_heads
+         self.attention_n_kv_heads = attention_n_kv_heads
+         self.max_seq_len = max_seq_len
+         self.batch_size = batch_size
+         self.position_emb_theta = position_emb_theta
+         self.activation_hidden_dim = activation_hidden_dim
+         self.norm_eps = norm_eps
+         self.dropout = dropout
+     @classmethod
+     def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PicoDecoderHFConfig":
+         pico_config = cls(**config_dict)
+         return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
+         unused_kwargs = {k: v for k, v in kwargs.items() if not hasattr(pico_config, k)}
+         if return_unused_kwargs:
+             return pico_config, unused_kwargs
+         return pico_config
+     @classmethod
+     def from_dataclass(cls, model_config):
+         return cls.from_dict(asdict(model_config))
+
+
+ class PicoDecoderHF(PreTrainedModel):
+     """
+     HuggingFace wrapper for BeetleLM PicoDecoder.
+     Usage: AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)
+     Works with CPU, CUDA (A100, etc.), and MPS out of the box.
+     """
+     config_class = PicoDecoderHFConfig
+     _no_split_modules = ["PicoDecoderBlock"]
+     _tied_weights_keys = []
+
+     def __init__(self, config: PicoDecoderHFConfig):
+         super().__init__(config)
+         self.embedding_proj = nn.Embedding(config.vocab_size, config.d_model)
+         self.layers = nn.ModuleList(
+             [PicoDecoderBlock(config) for _ in range(config.n_layers)])
+         self.output_norm = RMSNorm(config)
+         self.de_embedding_proj = nn.Linear(config.d_model, config.vocab_size, bias=False)
+         # Required: lets HF finalize weight init and meta-device materialization
+         self.post_init()
+
+     # Required for low_cpu_mem_usage / Accelerate device-dispatch to work
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             nn.init.normal_(module.weight, mean=0.0, std=0.02)
+             if module.bias is not None:
+                 nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             nn.init.normal_(module.weight, mean=0.0, std=0.02)
+         elif isinstance(module, RMSNorm):
+             nn.init.ones_(module.weight)
+
+     def get_input_embeddings(self): return self.embedding_proj
+     def set_input_embeddings(self, value): self.embedding_proj = value
+
+     def forward(self, input_ids=None, past_key_values=None,
+                 use_cache=False, labels=None, **kwargs):
+         seq_len = input_ids.shape[-1]
+         h = self.embedding_proj(input_ids)
+         start_pos = 0 if past_key_values is None else past_key_values[0][0].shape[1]
+         mask = None
+         if seq_len > 1:
+             mask = torch.full((seq_len, seq_len), float("-inf"), device=h.device)
+             mask = torch.triu(mask, diagonal=1)
+             if past_key_values is not None:
+                 mask = torch.hstack([torch.zeros((seq_len, start_pos), device=h.device), mask])
+         cached_key_values = () if use_cache else None
+         for idx, layer in enumerate(self.layers):
+             layer_past = past_key_values[idx] if past_key_values is not None else None
+             h, layer_cached = layer(
+                 h, mask=mask, past_key_values=layer_past, use_cache=use_cache)
+             if use_cache:
+                 cached_key_values += (layer_cached,)
+         logits = self.de_embedding_proj(self.output_norm(h)).float()
+         loss = None
+         if labels is not None:
+             # next-token loss; out-of-range labels are clamped into the vocab rather than masked
+             loss = F.cross_entropy(
+                 logits[:, :-1].contiguous().view(-1, self.config.vocab_size),
+                 labels[:, 1:].contiguous().clamp(0, self.config.vocab_size - 1).view(-1),
+             )
+         if use_cache:
+             return CausalLMOutputWithPast(
+                 loss=loss, logits=logits, past_key_values=cached_key_values)
+         return CausalLMOutput(loss=loss, logits=logits)
+
+     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
+         # once the KV cache is populated, only the newest token must be re-encoded;
+         # feeding the full sequence back would re-append every position to the cache
+         if past_key_values is not None:
+             input_ids = input_ids[:, -1:]
+         return {"input_ids": input_ids,
+                 "past_key_values": past_key_values,
+                 "use_cache": True}
+
+
+ PicoDecoderHFConfig.register_for_auto_class()
+ PicoDecoderHF.register_for_auto_class("AutoModel")
+ PicoDecoderHF.register_for_auto_class("AutoModelForCausalLM")