ThomasTheMaker commited on
Commit
80a9a5e
·
verified ·
1 Parent(s): 0a38bb7

Delete pico-decoder-tiny

Browse files
Files changed (41) hide show
  1. pico-decoder-tiny/checkpoints/step_0/config.json +0 -22
  2. pico-decoder-tiny/checkpoints/step_0/fabric_state/checkpoint.pt +0 -3
  3. pico-decoder-tiny/checkpoints/step_0/generation_config.json +0 -4
  4. pico-decoder-tiny/checkpoints/step_0/learning_dynamics/train_activations.pt +0 -3
  5. pico-decoder-tiny/checkpoints/step_0/learning_dynamics/train_data/data-00000-of-00001.arrow +0 -3
  6. pico-decoder-tiny/checkpoints/step_0/learning_dynamics/train_data/dataset_info.json +0 -19
  7. pico-decoder-tiny/checkpoints/step_0/learning_dynamics/train_data/state.json +0 -13
  8. pico-decoder-tiny/checkpoints/step_0/learning_dynamics/train_gradients.pt +0 -3
  9. pico-decoder-tiny/checkpoints/step_0/learning_dynamics/train_weights.pt +0 -3
  10. pico-decoder-tiny/checkpoints/step_0/model.safetensors +0 -3
  11. pico-decoder-tiny/checkpoints/step_0/pico_decoder.py +0 -856
  12. pico-decoder-tiny/checkpoints/step_0/special_tokens_map.json +0 -16
  13. pico-decoder-tiny/checkpoints/step_0/tokenizer.json +0 -0
  14. pico-decoder-tiny/checkpoints/step_0/tokenizer_config.json +0 -239
  15. pico-decoder-tiny/checkpoints/step_1000/config.json +0 -22
  16. pico-decoder-tiny/checkpoints/step_1000/fabric_state/checkpoint.pt +0 -3
  17. pico-decoder-tiny/checkpoints/step_1000/generation_config.json +0 -4
  18. pico-decoder-tiny/checkpoints/step_1000/learning_dynamics/train_activations.pt +0 -3
  19. pico-decoder-tiny/checkpoints/step_1000/learning_dynamics/train_data/data-00000-of-00001.arrow +0 -3
  20. pico-decoder-tiny/checkpoints/step_1000/learning_dynamics/train_data/dataset_info.json +0 -19
  21. pico-decoder-tiny/checkpoints/step_1000/learning_dynamics/train_data/state.json +0 -13
  22. pico-decoder-tiny/checkpoints/step_1000/learning_dynamics/train_gradients.pt +0 -3
  23. pico-decoder-tiny/checkpoints/step_1000/learning_dynamics/train_weights.pt +0 -3
  24. pico-decoder-tiny/checkpoints/step_1000/model.safetensors +0 -3
  25. pico-decoder-tiny/checkpoints/step_1000/pico_decoder.py +0 -871
  26. pico-decoder-tiny/checkpoints/step_1000/special_tokens_map.json +0 -16
  27. pico-decoder-tiny/checkpoints/step_1000/tokenizer.json +0 -0
  28. pico-decoder-tiny/checkpoints/step_1000/tokenizer_config.json +0 -239
  29. pico-decoder-tiny/checkpoints/step_1755/config.json +0 -22
  30. pico-decoder-tiny/checkpoints/step_1755/fabric_state/checkpoint.pt +0 -3
  31. pico-decoder-tiny/checkpoints/step_1755/generation_config.json +0 -4
  32. pico-decoder-tiny/checkpoints/step_1755/model.safetensors +0 -3
  33. pico-decoder-tiny/checkpoints/step_1755/pico_decoder.py +0 -871
  34. pico-decoder-tiny/checkpoints/step_1755/special_tokens_map.json +0 -16
  35. pico-decoder-tiny/checkpoints/step_1755/tokenizer.json +0 -0
  36. pico-decoder-tiny/checkpoints/step_1755/tokenizer_config.json +0 -239
  37. pico-decoder-tiny/eval_results/step_0.json +0 -1
  38. pico-decoder-tiny/eval_results/step_1000.json +0 -1
  39. pico-decoder-tiny/eval_results/step_1755.json +0 -1
  40. pico-decoder-tiny/logs/log_20250828_220514.log +0 -185
  41. pico-decoder-tiny/training_config.yaml +0 -74
pico-decoder-tiny/checkpoints/step_0/config.json DELETED
@@ -1,22 +0,0 @@
1
- {
2
- "activation_hidden_dim": 384,
3
- "architectures": [
4
- "PicoDecoderHF"
5
- ],
6
- "attention_n_heads": 12,
7
- "attention_n_kv_heads": 4,
8
- "auto_map": {
9
- "AutoConfig": "pico_decoder.PicoDecoderHFConfig",
10
- "AutoModelForCausalLM": "pico_decoder.PicoDecoderHF"
11
- },
12
- "batch_size": 1024,
13
- "d_model": 96,
14
- "max_seq_len": 2048,
15
- "model_type": "pico_decoder",
16
- "n_layers": 12,
17
- "norm_eps": 1e-06,
18
- "position_emb_theta": 10000.0,
19
- "torch_dtype": "float32",
20
- "transformers_version": "4.48.3",
21
- "vocab_size": 50304
22
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pico-decoder-tiny/checkpoints/step_0/fabric_state/checkpoint.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:b50a50fd67e7a1dfa214a074549428c03047ccc26357734db80084015a538b90
3
- size 45187997
 
 
 
 
pico-decoder-tiny/checkpoints/step_0/generation_config.json DELETED
@@ -1,4 +0,0 @@
1
- {
2
- "transformers_version": "4.48.3",
3
- "vocab_size": 50304
4
- }
 
 
 
 
 
pico-decoder-tiny/checkpoints/step_0/learning_dynamics/train_activations.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:33fda803f83cb9653b125b70cf8386e39812fa3e30e4746b52db22f5a248be93
3
- size 33819
 
 
 
 
pico-decoder-tiny/checkpoints/step_0/learning_dynamics/train_data/data-00000-of-00001.arrow DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:85fcf259ee523f219f5133a952ded67c5a339f05dc40df33188f33a1838bb3e0
3
- size 65384
 
 
 
 
pico-decoder-tiny/checkpoints/step_0/learning_dynamics/train_data/dataset_info.json DELETED
@@ -1,19 +0,0 @@
1
- {
2
- "citation": "",
3
- "description": "",
4
- "features": {
5
- "input_ids": {
6
- "feature": {
7
- "dtype": "int32",
8
- "_type": "Value"
9
- },
10
- "_type": "Sequence"
11
- },
12
- "text": {
13
- "dtype": "string",
14
- "_type": "Value"
15
- }
16
- },
17
- "homepage": "",
18
- "license": ""
19
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pico-decoder-tiny/checkpoints/step_0/learning_dynamics/train_data/state.json DELETED
@@ -1,13 +0,0 @@
1
- {
2
- "_data_files": [
3
- {
4
- "filename": "data-00000-of-00001.arrow"
5
- }
6
- ],
7
- "_fingerprint": "d0a54608fc979d10",
8
- "_format_columns": null,
9
- "_format_kwargs": {},
10
- "_format_type": null,
11
- "_output_all_columns": false,
12
- "_split": null
13
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pico-decoder-tiny/checkpoints/step_0/learning_dynamics/train_gradients.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:b48826ca39adc92f370b9c3aa0ed42dce5dbf1ffe4fcfe1c320df08c344016bb
3
- size 2371527
 
 
 
 
pico-decoder-tiny/checkpoints/step_0/learning_dynamics/train_weights.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:c029ef92a6494ae121c847e432e52e6a8ff3bf7d9fef3e61bef871c1e9a9aa02
3
- size 2371443
 
 
 
 
pico-decoder-tiny/checkpoints/step_0/model.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:1852515eb5c8556533445f22edf523884b9f8cc44812379a6a951668a4ffa3a3
3
- size 45143592
 
 
 
 
pico-decoder-tiny/checkpoints/step_0/pico_decoder.py DELETED
@@ -1,856 +0,0 @@
1
- """
2
- Pico Decoder: A Lightweight Causal Transformer Language Model
3
-
4
- Pico Decoder uses a simple LLAMA-style transformer architecture, written for clarity and educational purposes.
5
-
6
- Everything is written with a modular design for easy modification and experimentation.
7
-
8
- Key features:
9
- - RMSNorm for layer normalization
10
- - Rotary Positional Embeddings (RoPE)
11
- - Multi-head attention with KV-cache support
12
- - SwiGLU activation function
13
- - Residual connections throughout
14
-
15
- - KV-cache for faster autoregressive generation
16
-
17
- References:
18
- - RoPE: https://arxiv.org/abs/2104.09864
19
- - SwiGLU: https://arxiv.org/abs/2002.05202
20
- - LLAMA: https://arxiv.org/abs/2302.13971
21
-
22
- Adapted from:
23
- - OLMO: https://github.com/allenai/OLMo
24
- - LLAMA: https://github.com/meta/llama
25
- """
26
-
27
- from dataclasses import asdict
28
- from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
29
-
30
- import torch
31
- import torch.nn as nn
32
- import torch.nn.functional as F
33
- from torch.nn.attention import SDPBackend, sdpa_kernel
34
- from transformers import PretrainedConfig, PreTrainedModel, GenerationMixin
35
- from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast
36
- from transformers.generation import GenerationConfig
37
-
38
- try:
39
- if TYPE_CHECKING:
40
- # We need to do this to avoid importing these when creating the HF-compatible models
41
- from src.config import ModelConfig
42
- except ImportError:
43
- pass
44
-
45
- ########################################################
46
- #
47
- # Layer Normalization
48
- #
49
- ########################################################
50
-
51
-
52
- class RMSNorm(torch.nn.Module):
53
- """Root Mean Square Layer Normalization.
54
-
55
- A variant of Layer Normalization that uses RMS statistics instead of mean/variance,
56
- resulting in improved stability and performance.
57
-
58
- Args:
59
- config (Union[ModelConfig, PicoHFConfig]): Configuration object containing normalization parameters
60
- - config.norm_eps: Small constant for numerical stability
61
- - config.d_model: Model dimension for the weight parameter
62
-
63
- References:
64
- https://arxiv.org/abs/1910.07467
65
- """
66
-
67
- def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
68
- super().__init__()
69
- self.eps = config.norm_eps
70
- self.weight = nn.Parameter(torch.ones(config.d_model))
71
-
72
- def _norm(self, x: torch.Tensor) -> torch.Tensor:
73
- """
74
- Normalizes the input tensor by its RMS value.
75
- """
76
- return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
77
-
78
- def forward(self, x: torch.Tensor) -> torch.Tensor:
79
- """
80
- Applies RMS normalization to the input tensor and scales it by the weight parameter.
81
- """
82
- output = self._norm(x.float()).type_as(x)
83
- return output * self.weight
84
-
85
-
86
- ########################################################
87
- #
88
- # Positional Embedding
89
- #
90
- ########################################################
91
-
92
-
93
- class RoPE(nn.Module):
94
- """Rotary Positional Embeddings (RoPE).
95
-
96
- Implements position-dependent rotation of keys and queries in attention mechanism,
97
- allowing better modeling of relative positions in sequences. Uses complex number
98
- operations for efficient rotation.
99
-
100
- Args:
101
- config (Union[ModelConfig, PicoHFConfig]): Model configuration containing:
102
- - config.position_emb_theta: Base for frequency computation
103
- - config.d_model: Model dimension
104
- - config.attention_n_heads: Number of attention heads
105
- - config.max_seq_len: Maximum sequence length
106
-
107
- References:
108
- https://arxiv.org/abs/2104.09864
109
- """
110
-
111
- _freqs_cis_tensor: torch.Tensor | None = None
112
-
113
- def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
114
- super().__init__()
115
-
116
- self.theta = config.position_emb_theta
117
- self.dim = config.d_model // config.attention_n_heads
118
-
119
- max_seq_len = config.max_seq_len
120
-
121
- # only gets set once, and then reused for all RoPE instances
122
- if RoPE._freqs_cis_tensor is None:
123
- RoPE._freqs_cis_tensor = self._setup_freqs_cis(
124
- max_seq_len, self.theta, self.dim
125
- )
126
-
127
- # register _freqs_cis buffer
128
- # can be easily recomputed so persistent=False
129
- self.register_buffer("_freqs_cis", self._freqs_cis_tensor, persistent=False)
130
-
131
- @classmethod
132
- def _setup_freqs_cis(cls, seq_len: int, theta: float, dim: int) -> torch.Tensor:
133
- """Setup Frequency Tensor for RoPE Embeddings
134
-
135
- Initializes the complex frequency tensor that is used to compute the RoPE embeddings.
136
-
137
- Note other implementations will use cos and sin directly, but using the complex
138
- number representation is (probably) more efficient:
139
-
140
- e^(theta * i * t) = cos(theta * t) + i * sin(theta * t) [Euler's formula]
141
- """
142
- _freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
143
- positions = torch.arange(seq_len)
144
- freqs = torch.outer(positions, _freqs)
145
- return torch.polar(torch.ones_like(freqs), freqs) # complex64
146
-
147
- def get_freqs_cis(
148
- self, input_shape: torch.Size, start_pos: int, end_pos: int
149
- ) -> torch.Tensor:
150
- """Reshape Frequency Tensor for RoPE Embeddings
151
-
152
- Makes the frequency tensor broadcastable with the input tensor.
153
- """
154
- _freqs_cis = self._freqs_cis[start_pos:end_pos]
155
- ndim = len(input_shape)
156
- assert 0 <= 1 < ndim
157
- assert _freqs_cis.shape == (input_shape[1], input_shape[-1])
158
-
159
- # TODO: Check whether this is correct (might be able to remove this)
160
- shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(input_shape)]
161
- return _freqs_cis.view(*shape)
162
-
163
- def forward(
164
- self,
165
- queries: torch.Tensor,
166
- keys: torch.Tensor,
167
- start_pos: int = 0,
168
- ) -> Tuple[torch.Tensor, torch.Tensor]:
169
- """Apply RoPE Embeddings to Queries and Keys
170
-
171
- Applies the rotary positional embeddings to the input tensors via complex num multiplication
172
-
173
- NOTE: The start_pos is used if we want to use the kv_cache in the attention mechanism.
174
- """
175
- queries_ = torch.view_as_complex(
176
- queries.float().reshape(*queries.shape[:-1], -1, 2)
177
- )
178
- keys_ = torch.view_as_complex(keys.float().reshape(*keys.shape[:-1], -1, 2))
179
-
180
- input_shape = (
181
- queries_.shape
182
- ) # same as keys: (batch_size, seq_len, n_heads, head_dim/2)
183
- freqs_start_pos = start_pos
184
- freqs_end_pos = freqs_start_pos + queries_.shape[1]
185
-
186
- freqs_cis = self.get_freqs_cis(input_shape, freqs_start_pos, freqs_end_pos)
187
-
188
- queries_rotated = torch.view_as_real(queries_ * freqs_cis).flatten(3)
189
- keys_rotated = torch.view_as_real(keys_ * freqs_cis).flatten(3)
190
- return queries_rotated.type_as(queries), keys_rotated.type_as(keys)
191
-
192
-
193
- ########################################################
194
- #
195
- # Attention
196
- #
197
- ########################################################
198
-
199
-
200
- class Attention(nn.Module):
201
- """Multi-head Attention with Group Query Attention support.
202
-
203
- Implements scaled dot-product attention and supports:
204
- - Grouped Query Attention (GQA)
205
- - Key-Value caching for efficient inference
206
- - RoPE integration
207
-
208
- Args:
209
- config (Union[ModelConfig, PretrainedConfig]): Configuration containing:
210
- - config.attention_n_heads: Number of attention heads
211
- - config.attention_n_kv_heads: Number of key/value heads
212
- - config.d_model: Model dimension
213
- - config.batch_size: Maximum batch size
214
- - config.max_seq_len: Maximum sequence length
215
-
216
- Shape:
217
- - Input: (batch_size, seq_len, d_model)
218
- - Output: (batch_size, seq_len, d_model)
219
- """
220
-
221
- def __init__(
222
- self,
223
- config: Union["ModelConfig", "PicoDecoderHFConfig"],
224
- ):
225
- super().__init__()
226
-
227
- self.n_heads = config.attention_n_heads
228
- self.n_kv_heads = config.attention_n_kv_heads
229
-
230
- self.batch_size = config.batch_size
231
- self.max_seq_len = config.max_seq_len
232
-
233
- d_model = config.d_model
234
- self.head_dim = d_model // self.n_heads
235
-
236
- self.n_rep = self.n_heads // self.n_kv_heads
237
-
238
- self.q_proj = nn.Linear(d_model, self.n_heads * self.head_dim, bias=False)
239
- self.k_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
240
- self.v_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
241
- self.o_proj = nn.Linear(self.n_heads * self.head_dim, d_model, bias=False)
242
-
243
- self.rope = RoPE(config)
244
-
245
- def forward(
246
- self,
247
- input: torch.Tensor,
248
- mask: Optional[torch.Tensor] = None,
249
- past_key_values: Optional[Tuple[torch.Tensor, ...]] = None,
250
- use_cache: bool = False,
251
- ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
252
- """Forward pass for the attention mechanism.
253
-
254
- Computes queries, keys, and values for the attention mechanism. Applies rotary positional
255
- embeddings to the queries and keys, and then computes attention scores and outputs.
256
-
257
- For an introduction to the attention mechanism, see:
258
- https://arxiv.org/abs/1706.03762
259
-
260
- A few things to note:
261
- - The past_key_values is used to implement the KV cache, which is used to speed up
262
- generation by caching the KV pairs from previous forward passes. This is useful when doing
263
- tasks that require generating multiple tokens conditioned on previous tokens (e.g. language
264
- modeling, text generation, etc.). The way the KV cache is implemented is that each layer has
265
- its own KV cache - this KV cache is implemented as a tuple.
266
- """
267
- bsz, seq_len, _ = input.shape
268
- _queries, _keys, _values = (
269
- self.q_proj(input),
270
- self.k_proj(input),
271
- self.v_proj(input),
272
- )
273
-
274
- # Reshaping for multi-head attention
275
- queries = _queries.view(bsz, seq_len, self.n_heads, self.head_dim)
276
- keys = _keys.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
277
- values = _values.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
278
-
279
- # The start position is used to apply the RoPE embeddings to only the new tokens
280
- # when using the kv_cache in the attention mechanism.
281
- # We want to start from the last position in the cache.
282
- start_pos = past_key_values[0].shape[1] if past_key_values is not None else 0
283
-
284
- # apply rotary positional embeddings
285
- queries, keys = self.rope(queries, keys, start_pos)
286
-
287
- if past_key_values is not None:
288
- keys = torch.cat([past_key_values[0], keys], dim=1)
289
- values = torch.cat([past_key_values[1], values], dim=1)
290
-
291
- if use_cache:
292
- cached_keys = keys
293
- cached_values = values
294
- else:
295
- cached_keys = None
296
- cached_values = None
297
-
298
- queries = queries.transpose(1, 2)
299
- keys = keys.transpose(1, 2)
300
- values = values.transpose(1, 2)
301
-
302
- apply_gqa = self.n_rep > 1
303
- if apply_gqa and queries.device.type == "mps":
304
- # NOTE: MPS does not support GQA in the SDPA kernel, but we can repeat the keys and values
305
- # outside of the kernel to get the same effect.
306
- # See: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
307
- keys = keys.repeat_interleave(self.n_rep, dim=-3)
308
- values = values.repeat_interleave(self.n_rep, dim=-3)
309
- apply_gqa = False
310
-
311
- backends = [SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH]
312
-
313
- with sdpa_kernel(backends=backends):
314
- attn_output = F.scaled_dot_product_attention(
315
- queries.contiguous(),
316
- keys.contiguous(),
317
- values.contiguous(),
318
- attn_mask=mask.to(queries.dtype) if mask is not None else None,
319
- enable_gqa=apply_gqa,
320
- )
321
-
322
- attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
323
- output = self.o_proj(attn_output)
324
-
325
- return output, (cached_keys, cached_values)
326
-
327
-
328
- ########################################################
329
- #
330
- # SwiGLU (Combines MLP and Activation)
331
- #
332
- ########################################################
333
-
334
-
335
- class SwiGLU(nn.Module):
336
- """SwiGLU Activation Function with Linear Projections.
337
-
338
- Implements the SwiGLU activation function combined with linear transformations,
339
- serving as the feed-forward network in transformer blocks.
340
-
341
- Args:
342
- config (Union[ModelConfig, PicoDecoderHFConfig]): Configuration containing:
343
- - config.d_model: Model dimension
344
- - config.activation_hidden_dim: Hidden dimension (typically 4 * d_model)
345
-
346
- References:
347
- https://arxiv.org/abs/2002.05202
348
- """
349
-
350
- def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
351
- super().__init__()
352
-
353
- model_dim = config.d_model
354
- act_hidden_dim = config.activation_hidden_dim # usually 4 * d_model
355
-
356
- self.w_0 = nn.Linear(model_dim, act_hidden_dim, bias=False)
357
- self.w_1 = nn.Linear(model_dim, act_hidden_dim, bias=False)
358
- self.w_2 = nn.Linear(act_hidden_dim, model_dim, bias=False)
359
-
360
- def forward(self, x: torch.Tensor) -> torch.Tensor:
361
- return self.w_2(F.silu(self.w_0(x)) * self.w_1(x))
362
-
363
-
364
- ########################################################
365
- #
366
- # PicoDecoderBlock
367
- #
368
- ########################################################
369
-
370
-
371
- class PicoDecoderBlock(nn.Module):
372
- """Single Transformer Block with Attention and Feed-forward layers.
373
-
374
- Implements a standard transformer block with:
375
- - Multi-head attention with normalization and residual connection
376
- - SwiGLU feed-forward network with normalization and residual connection
377
-
378
- Args:
379
- config (Union[ModelConfig, PicoDecoderHFConfig]): Model configuration; either a dataclass or
380
- a HuggingFace PicoDecoderHFConfig
381
- """
382
-
383
- def __init__(
384
- self,
385
- config: Union["ModelConfig", "PicoDecoderHFConfig"],
386
- ):
387
- super().__init__()
388
-
389
- self.attention = Attention(config)
390
- self.swiglu = SwiGLU(config)
391
- self.attention_norm = RMSNorm(config)
392
- self.swiglu_norm = RMSNorm(config)
393
-
394
- def forward(
395
- self,
396
- input: torch.Tensor,
397
- mask: Optional[torch.Tensor] = None,
398
- past_key_values: Optional[Tuple[torch.Tensor]] = None,
399
- use_cache: bool = False,
400
- ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
401
- attention_output, cached_key_values = self.attention(
402
- self.attention_norm(input),
403
- mask=mask,
404
- past_key_values=past_key_values,
405
- use_cache=use_cache,
406
- )
407
- # NOTE: cached_key_values is None if use_cache is False
408
-
409
- h = input + attention_output
410
- out = h + self.swiglu(self.swiglu_norm(h))
411
- return out, cached_key_values
412
-
413
-
414
- ########################################################
415
- #
416
- # Pico Decoder (Causal Transformer Model)
417
- #
418
- ########################################################
419
-
420
-
421
- class PicoDecoder(nn.Module):
422
- """
423
- Pico Decoder: combines the embedding, causal decoder blocks, and output projection into a
424
- single autoregressive model.
425
-
426
- For more information on the model, see the classes for the modules that make up the model.
427
- """
428
-
429
- def __init__(
430
- self,
431
- model_config: Union["ModelConfig", "PicoDecoderHFConfig"],
432
- ):
433
- super().__init__()
434
- self.config = model_config
435
-
436
- self.embedding_proj = nn.Embedding(self.config.vocab_size, self.config.d_model)
437
- self.layers = nn.ModuleList(
438
- [PicoDecoderBlock(self.config) for _ in range(self.config.n_layers)]
439
- )
440
- self.output_norm = RMSNorm(self.config)
441
- self.de_embedding_proj = nn.Linear(
442
- self.config.d_model, self.config.vocab_size, bias=False
443
- )
444
-
445
- def convert_to_hf_model(self) -> "PicoDecoderHF":
446
- """Convert the Lightning model to a HuggingFace model."""
447
- # Create HF config without fabric-specific settings
448
- hf_config = PicoDecoderHFConfig.from_dataclass(self.config)
449
-
450
- # Create new HF model
451
- hf_model = PicoDecoderHF(hf_config)
452
-
453
- # Copy state dict, excluding fabric-specific keys
454
- hf_model.load_state_dict(self.state_dict(prefix="pico_decoder."))
455
-
456
- return hf_model
457
-
458
- def forward(
459
- self,
460
- input_ids: torch.Tensor,
461
- past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
462
- use_cache: bool = False,
463
- ) -> Tuple[torch.Tensor, Optional[Tuple[Tuple[torch.Tensor, torch.Tensor]]]]:
464
- """
465
- This is the forward pass for the entire Pico model. It boils down to:
466
- - Embedding the input ids
467
- - Creating a causal mask
468
- - Processing through the pico layers
469
- - Projecting the output to logits
470
-
471
- NOTE: One feature that might be confusing is the KV cache. The KV cache is used to speed up
472
- generation by caching the KV pairs from previous forward passes. This is useful when doing
473
- tasks that require generating multiple tokens conditioned on previous tokens (e.g. language
474
- modeling, text generation, etc.). The way the KV cache is implemented is that each layer has
475
- its own KV cache which is stored as a tuple. The whole model then stores a tuple of these
476
- KV caches (so a tuple of tuples).
477
- """
478
-
479
- seq_len = input_ids.shape[-1]
480
- h = self.embedding_proj(input_ids)
481
-
482
- # Calculate start position from past cached KV pairs. Remember that each layer has its
483
- # own KV Cache. So when we index past_key_values, we need to index into the KV pairs for the
484
- # correct layer and then for either the keys or values.
485
- start_pos = 0 if past_key_values is None else past_key_values[0][0].shape[1]
486
-
487
- # Create causal mask for current sequence
488
- mask = None
489
- if seq_len > 1:
490
- mask = torch.full((seq_len, seq_len), float("-inf"))
491
- mask = torch.triu(mask, diagonal=1)
492
-
493
- # If using KV cache, extend mask to cover cached sequence length
494
- if past_key_values is not None:
495
- # Add zeros for cached tokens (we can attend to all of them)
496
- mask = torch.hstack([torch.zeros((seq_len, start_pos)), mask])
497
-
498
- mask = mask.to(h.device)
499
-
500
- # NOTE: If we are using the cache, we need to store the cached KV pairs for each layer
501
- # in a tuple. Each layer will have its own cached KV pair which we aggregate in a tuple.
502
- cached_key_values = () if use_cache else None
503
-
504
- # Process through transformer blocks
505
- for idx, layer in enumerate(self.layers):
506
- layer_past_key_values = (
507
- past_key_values[idx] if past_key_values is not None else None
508
- )
509
-
510
- h, layer_cached_key_values = layer(
511
- h, mask=mask, past_key_values=layer_past_key_values, use_cache=use_cache
512
- )
513
-
514
- if use_cache:
515
- cached_key_values += (layer_cached_key_values,)
516
-
517
- # Final norm and projection
518
- h = self.output_norm(h)
519
- logits = self.de_embedding_proj(h).float()
520
-
521
- return logits, cached_key_values
522
-
523
-
524
- ########################################################
525
- #
526
- # HuggingFace Wrapper for the Pico Decoder model.
527
- #
528
- ########################################################
529
-
530
-
531
- class PicoDecoderHFConfig(PretrainedConfig):
532
- """Config class for the Pico Decoder HuggingFace wrapper."""
533
-
534
- model_type = "pico_decoder"
535
-
536
- @classmethod
537
- def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PicoDecoderHFConfig":
538
- """
539
- Initialize config from a dictionary. Note that no kwargs are passed to the constructor --
540
- this is because with some kwargs special handling is required and can make this class
541
- brittle.
542
- """
543
- pico_config = cls(**config_dict)
544
-
545
- return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
546
- unused_kwargs = {
547
- key: value for key, value in kwargs.items() if not hasattr(pico_config, key)
548
- }
549
-
550
- if return_unused_kwargs:
551
- return pico_config, unused_kwargs
552
- return pico_config
553
-
554
- @classmethod
555
- def from_dataclass(cls, model_config: "ModelConfig"):
556
- """Initialise from our custom config dataclass."""
557
- return cls.from_dict(asdict(model_config))
558
-
559
-
560
- class PicoDecoderHF(PreTrainedModel, GenerationMixin):
561
- """
562
- HuggingFace wrapper for the Pico model with generation support.
563
-
564
- Many evaluation frameworks require a model be setup as a HuggingFace model, so we provide a simple
565
- wrapper that does just that. When we save checkpoints of the Pico model, we save both the normal
566
- Pico model as well as the model wrapped in this HuggingFace class.
567
-
568
- This also lets you do cool things like:
569
-
570
- `model = AutoModelForCausalLM.from_pretrained("path/to/checkpoint")`
571
- """
572
-
573
- config_class = PicoDecoderHFConfig
574
- _no_split_modules = ["PicoBlock", "Attention", "SwiGLU", "RMSNorm"]
575
- main_input_name = "input_ids"
576
-
577
- def __init__(self, config: PicoDecoderHFConfig):
578
- super().__init__(config)
579
- self.pico_decoder = PicoDecoder(config)
580
- # Initialize generation config with defaults
581
- self.generation_config = GenerationConfig()
582
- # Set some reasonable defaults for the model
583
- if hasattr(config, 'max_position_embeddings'):
584
- self.generation_config.max_length = config.max_position_embeddings
585
- if hasattr(config, 'vocab_size'):
586
- self.generation_config.vocab_size = config.vocab_size
587
-
588
- def forward(
589
- self,
590
- input_ids: torch.Tensor,
591
- past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
592
- use_cache: bool = False,
593
- **kwargs,
594
- ) -> Union[CausalLMOutput, CausalLMOutputWithPast]:
595
- """HuggingFace forward pass wrapper.
596
-
597
- Forwards pass for the HuggingFace version of the Pico Model. Basic wrapper around the
598
- Pico model's forward pass, and returns the output as a HuggingFace CausalLMOutput.
599
- """
600
- logits, past_key_values = self.pico_decoder(
601
- input_ids, past_key_values, use_cache
602
- )
603
- if use_cache:
604
- return CausalLMOutputWithPast(
605
- logits=logits,
606
- past_key_values=past_key_values,
607
- )
608
- else:
609
- return CausalLMOutput(
610
- logits=logits,
611
- )
612
-
613
- def prepare_inputs_for_generation(
614
- self,
615
- input_ids: torch.LongTensor,
616
- past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
617
- attention_mask: Optional[torch.LongTensor] = None,
618
- **kwargs
619
- ) -> Dict[str, Any]:
620
- """
621
- Prepare inputs for generation.
622
-
623
- Args:
624
- input_ids: Input token IDs
625
- past_key_values: Cached key-value pairs from previous forward passes
626
- attention_mask: Attention mask for the input
627
- **kwargs: Additional arguments
628
-
629
- Returns:
630
- Dictionary containing prepared inputs
631
- """
632
- # If we have past_key_values, we only need the last token
633
- if past_key_values is not None:
634
- input_ids = input_ids[:, -1:]
635
-
636
- return {
637
- "input_ids": input_ids,
638
- "past_key_values": past_key_values,
639
- "use_cache": True,
640
- }
641
-
642
- def get_input_embeddings(self):
643
- """Get the input embeddings layer."""
644
- return self.pico_decoder.embedding_proj
645
-
646
- def set_input_embeddings(self, value):
647
- """Set the input embeddings layer."""
648
- self.pico_decoder.embedding_proj = value
649
-
650
- def get_output_embeddings(self):
651
- """Get the output embeddings layer."""
652
- return self.pico_decoder.de_embedding_proj
653
-
654
- def set_output_embeddings(self, value):
655
- """Set the output embeddings layer."""
656
- self.pico_decoder.de_embedding_proj = value
657
-
658
- def get_lm_head(self):
659
- """Get the language model head."""
660
- return self.pico_decoder.de_embedding_proj
661
-
662
- def can_generate(self) -> bool:
663
- """Check if the model can generate text."""
664
- return True
665
-
666
- @property
667
- def is_encoder_decoder(self) -> bool:
668
- """Check if the model is an encoder-decoder model."""
669
- return False
670
-
671
- @property
672
- def can_use_cache(self) -> bool:
673
- """Check if the model can use KV cache."""
674
- return True
675
-
676
- def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> torch.nn.Embedding:
677
- """Resize token embeddings."""
678
- old_embeddings = self.get_input_embeddings()
679
- if new_num_tokens is None:
680
- new_num_tokens = old_embeddings.num_embeddings
681
-
682
- new_embeddings = torch.nn.Embedding(new_num_tokens, old_embeddings.embedding_dim)
683
- new_embeddings.weight.data[:old_embeddings.num_embeddings] = old_embeddings.weight.data
684
-
685
- self.pico_decoder.embedding_proj = new_embeddings
686
- self.pico_decoder.de_embedding_proj = torch.nn.Linear(
687
- old_embeddings.embedding_dim, new_num_tokens, bias=False
688
- )
689
-
690
- return new_embeddings
691
-
692
-
693
- # Register for auto classes
694
- PicoDecoderHFConfig.register_for_auto_class()
695
- PicoDecoderHF.register_for_auto_class("AutoModel")
696
- PicoDecoderHF.register_for_auto_class("AutoModelForCausalLM")
697
-
698
-
699
- ########################################################
700
- #
701
- # New PicoDecoderForCausalLM class for generation support
702
- #
703
- ########################################################
704
-
705
-
706
- class PicoDecoderForCausalLM(PreTrainedModel, GenerationMixin):
707
- """
708
- PicoDecoderForCausalLM: A HuggingFace-compatible model that properly supports generation.
709
-
710
- This class is designed to work with existing checkpoints and provides full generation support.
711
- It inherits from the right base classes that HuggingFace expects for text generation.
712
- """
713
-
714
- config_class = PicoDecoderHFConfig
715
- _no_split_modules = ["PicoBlock", "Attention", "SwiGLU", "RMSNorm"]
716
- main_input_name = "input_ids"
717
-
718
- def __init__(self, config: PicoDecoderHFConfig):
719
- super().__init__(config)
720
- self.pico_decoder = PicoDecoder(config)
721
- # Initialize generation config with defaults
722
- self.generation_config = GenerationConfig()
723
- # Set some reasonable defaults for the model
724
- if hasattr(config, 'max_position_embeddings'):
725
- self.generation_config.max_length = config.max_position_embeddings
726
- if hasattr(config, 'vocab_size'):
727
- self.generation_config.vocab_size = config.vocab_size
728
-
729
- def forward(
730
- self,
731
- input_ids: torch.Tensor,
732
- past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
733
- use_cache: bool = False,
734
- **kwargs,
735
- ) -> Union[CausalLMOutput, CausalLMOutputWithPast]:
736
- """Forward pass for text generation."""
737
- logits, past_key_values = self.pico_decoder(
738
- input_ids, past_key_values, use_cache
739
- )
740
- if use_cache:
741
- return CausalLMOutputWithPast(
742
- logits=logits,
743
- past_key_values=past_key_values,
744
- )
745
- else:
746
- return CausalLMOutput(
747
- logits=logits,
748
- )
749
-
750
- def prepare_inputs_for_generation(
751
- self,
752
- input_ids: torch.LongTensor,
753
- past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
754
- attention_mask: Optional[torch.LongTensor] = None,
755
- **kwargs
756
- ) -> Dict[str, Any]:
757
- """Prepare inputs for generation."""
758
- # If we have past_key_values, we only need the last token
759
- if past_key_values is not None:
760
- input_ids = input_ids[:, -1:]
761
-
762
- return {
763
- "input_ids": input_ids,
764
- "past_key_values": past_key_values,
765
- "use_cache": True,
766
- }
767
-
768
- def get_input_embeddings(self):
769
- """Get the input embeddings layer."""
770
- return self.pico_decoder.embedding_proj
771
-
772
- def set_input_embeddings(self, value):
773
- """Set the input embeddings layer."""
774
- self.pico_decoder.embedding_proj = value
775
-
776
- def get_output_embeddings(self):
777
- """Get the output embeddings layer."""
778
- return self.pico_decoder.de_embedding_proj
779
-
780
- def set_output_embeddings(self, value):
781
- """Set the output embeddings layer."""
782
- self.pico_decoder.de_embedding_proj = value
783
-
784
- def get_lm_head(self):
785
- """Get the language model head."""
786
- return self.pico_decoder.de_embedding_proj
787
-
788
- def can_generate(self) -> bool:
789
- """Check if the model can generate text."""
790
- return True
791
-
792
- @property
793
- def is_encoder_decoder(self) -> bool:
794
- """Check if the model is an encoder-decoder model."""
795
- return False
796
-
797
- @property
798
- def can_use_cache(self) -> bool:
799
- """Check if the model can use KV cache."""
800
- return True
801
-
802
- def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> torch.nn.Embedding:
803
- """Resize token embeddings."""
804
- old_embeddings = self.get_input_embeddings()
805
- if new_num_tokens is None:
806
- new_num_tokens = old_embeddings.num_embeddings
807
-
808
- new_embeddings = torch.nn.Embedding(new_num_tokens, old_embeddings.embedding_dim)
809
- new_embeddings.weight.data[:old_embeddings.num_embeddings] = old_embeddings.weight.data
810
-
811
- self.pico_decoder.embedding_proj = new_embeddings
812
- self.pico_decoder.de_embedding_proj = torch.nn.Linear(
813
- old_embeddings.embedding_dim, new_num_tokens, bias=False
814
- )
815
-
816
- return new_embeddings
817
-
818
- @classmethod
819
- def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
820
- """
821
- Load a pretrained model from a checkpoint.
822
-
823
- This method handles loading from both the old PicoDecoderHF format and the new format.
824
- """
825
- # First try to load with the new class
826
- try:
827
- return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
828
- except Exception as e:
829
- print(f"Failed to load with new class: {e}")
830
- print("Attempting to load with legacy class and convert...")
831
-
832
- # Try to load with the old class and convert
833
- try:
834
- from transformers import AutoModel
835
- old_model = AutoModel.from_pretrained(
836
- pretrained_model_name_or_path,
837
- trust_remote_code=True,
838
- *model_args,
839
- **kwargs
840
- )
841
-
842
- # Create new model instance
843
- new_model = cls(old_model.config)
844
-
845
- # Copy state dict
846
- new_model.load_state_dict(old_model.state_dict(), strict=False)
847
-
848
- return new_model
849
-
850
- except Exception as e2:
851
- print(f"Failed to convert from legacy format: {e2}")
852
- raise e
853
-
854
-
855
- # Register the new class
856
- PicoDecoderForCausalLM.register_for_auto_class("AutoModelForCausalLM")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pico-decoder-tiny/checkpoints/step_0/special_tokens_map.json DELETED
@@ -1,16 +0,0 @@
1
- {
2
- "eos_token": {
3
- "content": "<|endoftext|>",
4
- "lstrip": false,
5
- "normalized": false,
6
- "rstrip": false,
7
- "single_word": false
8
- },
9
- "pad_token": {
10
- "content": "<|padding|>",
11
- "lstrip": false,
12
- "normalized": false,
13
- "rstrip": false,
14
- "single_word": false
15
- }
16
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pico-decoder-tiny/checkpoints/step_0/tokenizer.json DELETED
The diff for this file is too large to render. See raw diff
 
pico-decoder-tiny/checkpoints/step_0/tokenizer_config.json DELETED
@@ -1,239 +0,0 @@
1
- {
2
- "add_bos_token": false,
3
- "add_eos_token": false,
4
- "add_prefix_space": false,
5
- "added_tokens_decoder": {
6
- "0": {
7
- "content": "|||IP_ADDRESS|||",
8
- "lstrip": false,
9
- "normalized": true,
10
- "rstrip": false,
11
- "single_word": false,
12
- "special": false
13
- },
14
- "1": {
15
- "content": "<|padding|>",
16
- "lstrip": false,
17
- "normalized": false,
18
- "rstrip": false,
19
- "single_word": false,
20
- "special": true
21
- },
22
- "50254": {
23
- "content": " ",
24
- "lstrip": false,
25
- "normalized": true,
26
- "rstrip": false,
27
- "single_word": false,
28
- "special": false
29
- },
30
- "50255": {
31
- "content": " ",
32
- "lstrip": false,
33
- "normalized": true,
34
- "rstrip": false,
35
- "single_word": false,
36
- "special": false
37
- },
38
- "50256": {
39
- "content": " ",
40
- "lstrip": false,
41
- "normalized": true,
42
- "rstrip": false,
43
- "single_word": false,
44
- "special": false
45
- },
46
- "50257": {
47
- "content": " ",
48
- "lstrip": false,
49
- "normalized": true,
50
- "rstrip": false,
51
- "single_word": false,
52
- "special": false
53
- },
54
- "50258": {
55
- "content": " ",
56
- "lstrip": false,
57
- "normalized": true,
58
- "rstrip": false,
59
- "single_word": false,
60
- "special": false
61
- },
62
- "50259": {
63
- "content": " ",
64
- "lstrip": false,
65
- "normalized": true,
66
- "rstrip": false,
67
- "single_word": false,
68
- "special": false
69
- },
70
- "50260": {
71
- "content": " ",
72
- "lstrip": false,
73
- "normalized": true,
74
- "rstrip": false,
75
- "single_word": false,
76
- "special": false
77
- },
78
- "50261": {
79
- "content": " ",
80
- "lstrip": false,
81
- "normalized": true,
82
- "rstrip": false,
83
- "single_word": false,
84
- "special": false
85
- },
86
- "50262": {
87
- "content": " ",
88
- "lstrip": false,
89
- "normalized": true,
90
- "rstrip": false,
91
- "single_word": false,
92
- "special": false
93
- },
94
- "50263": {
95
- "content": " ",
96
- "lstrip": false,
97
- "normalized": true,
98
- "rstrip": false,
99
- "single_word": false,
100
- "special": false
101
- },
102
- "50264": {
103
- "content": " ",
104
- "lstrip": false,
105
- "normalized": true,
106
- "rstrip": false,
107
- "single_word": false,
108
- "special": false
109
- },
110
- "50265": {
111
- "content": " ",
112
- "lstrip": false,
113
- "normalized": true,
114
- "rstrip": false,
115
- "single_word": false,
116
- "special": false
117
- },
118
- "50266": {
119
- "content": " ",
120
- "lstrip": false,
121
- "normalized": true,
122
- "rstrip": false,
123
- "single_word": false,
124
- "special": false
125
- },
126
- "50267": {
127
- "content": " ",
128
- "lstrip": false,
129
- "normalized": true,
130
- "rstrip": false,
131
- "single_word": false,
132
- "special": false
133
- },
134
- "50268": {
135
- "content": " ",
136
- "lstrip": false,
137
- "normalized": true,
138
- "rstrip": false,
139
- "single_word": false,
140
- "special": false
141
- },
142
- "50269": {
143
- "content": " ",
144
- "lstrip": false,
145
- "normalized": true,
146
- "rstrip": false,
147
- "single_word": false,
148
- "special": false
149
- },
150
- "50270": {
151
- "content": " ",
152
- "lstrip": false,
153
- "normalized": true,
154
- "rstrip": false,
155
- "single_word": false,
156
- "special": false
157
- },
158
- "50271": {
159
- "content": " ",
160
- "lstrip": false,
161
- "normalized": true,
162
- "rstrip": false,
163
- "single_word": false,
164
- "special": false
165
- },
166
- "50272": {
167
- "content": " ",
168
- "lstrip": false,
169
- "normalized": true,
170
- "rstrip": false,
171
- "single_word": false,
172
- "special": false
173
- },
174
- "50273": {
175
- "content": " ",
176
- "lstrip": false,
177
- "normalized": true,
178
- "rstrip": false,
179
- "single_word": false,
180
- "special": false
181
- },
182
- "50274": {
183
- "content": " ",
184
- "lstrip": false,
185
- "normalized": true,
186
- "rstrip": false,
187
- "single_word": false,
188
- "special": false
189
- },
190
- "50275": {
191
- "content": " ",
192
- "lstrip": false,
193
- "normalized": true,
194
- "rstrip": false,
195
- "single_word": false,
196
- "special": false
197
- },
198
- "50276": {
199
- "content": " ",
200
- "lstrip": false,
201
- "normalized": true,
202
- "rstrip": false,
203
- "single_word": false,
204
- "special": false
205
- },
206
- "50277": {
207
- "content": "|||EMAIL_ADDRESS|||",
208
- "lstrip": false,
209
- "normalized": true,
210
- "rstrip": false,
211
- "single_word": false,
212
- "special": false
213
- },
214
- "50278": {
215
- "content": "|||PHONE_NUMBER|||",
216
- "lstrip": false,
217
- "normalized": true,
218
- "rstrip": false,
219
- "single_word": false,
220
- "special": false
221
- },
222
- "50279": {
223
- "content": "<|endoftext|>",
224
- "lstrip": false,
225
- "normalized": false,
226
- "rstrip": false,
227
- "single_word": false,
228
- "special": true
229
- }
230
- },
231
- "bos_token": null,
232
- "clean_up_tokenization_spaces": true,
233
- "eos_token": "<|endoftext|>",
234
- "extra_special_tokens": {},
235
- "model_max_length": 1000000000000000019884624838656,
236
- "pad_token": "<|padding|>",
237
- "tokenizer_class": "GPTNeoXTokenizer",
238
- "unk_token": null
239
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pico-decoder-tiny/checkpoints/step_1000/config.json DELETED
@@ -1,22 +0,0 @@
1
- {
2
- "activation_hidden_dim": 384,
3
- "architectures": [
4
- "PicoDecoderHF"
5
- ],
6
- "attention_n_heads": 12,
7
- "attention_n_kv_heads": 4,
8
- "auto_map": {
9
- "AutoConfig": "pico_decoder.PicoDecoderHFConfig",
10
- "AutoModelForCausalLM": "pico_decoder.PicoDecoderHF"
11
- },
12
- "batch_size": 1024,
13
- "d_model": 96,
14
- "max_seq_len": 2048,
15
- "model_type": "pico_decoder",
16
- "n_layers": 12,
17
- "norm_eps": 1e-06,
18
- "position_emb_theta": 10000.0,
19
- "torch_dtype": "float32",
20
- "transformers_version": "4.48.3",
21
- "vocab_size": 50304
22
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pico-decoder-tiny/checkpoints/step_1000/fabric_state/checkpoint.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:a3e80057e3aeee9020555bb47c5510dffa49aa7bb95aa28626c755fd1bcd84c6
3
- size 135543171
 
 
 
 
pico-decoder-tiny/checkpoints/step_1000/generation_config.json DELETED
@@ -1,4 +0,0 @@
1
- {
2
- "transformers_version": "4.48.3",
3
- "vocab_size": 50304
4
- }
 
 
 
 
 
pico-decoder-tiny/checkpoints/step_1000/learning_dynamics/train_activations.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:7e89d5ed24dce96b1c7926d0525d09f6fc80cd7ce982fdf4cb66817dcfbaeba9
3
- size 33819
 
 
 
 
pico-decoder-tiny/checkpoints/step_1000/learning_dynamics/train_data/data-00000-of-00001.arrow DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:2b91f61ac8aae1d61544370a4c754bda11f9501a2a8b2bdae615ead87385d6d0
3
- size 64520
 
 
 
 
pico-decoder-tiny/checkpoints/step_1000/learning_dynamics/train_data/dataset_info.json DELETED
@@ -1,19 +0,0 @@
1
- {
2
- "citation": "",
3
- "description": "",
4
- "features": {
5
- "input_ids": {
6
- "feature": {
7
- "dtype": "int32",
8
- "_type": "Value"
9
- },
10
- "_type": "Sequence"
11
- },
12
- "text": {
13
- "dtype": "string",
14
- "_type": "Value"
15
- }
16
- },
17
- "homepage": "",
18
- "license": ""
19
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pico-decoder-tiny/checkpoints/step_1000/learning_dynamics/train_data/state.json DELETED
@@ -1,13 +0,0 @@
1
- {
2
- "_data_files": [
3
- {
4
- "filename": "data-00000-of-00001.arrow"
5
- }
6
- ],
7
- "_fingerprint": "6c5c8added4701f3",
8
- "_format_columns": null,
9
- "_format_kwargs": {},
10
- "_format_type": null,
11
- "_output_all_columns": false,
12
- "_split": null
13
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pico-decoder-tiny/checkpoints/step_1000/learning_dynamics/train_gradients.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:7cfeda57ef54a270df54117da31b7fa317f97f60870cd825f2f063392e85c1ad
3
- size 2371527
 
 
 
 
pico-decoder-tiny/checkpoints/step_1000/learning_dynamics/train_weights.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:e3b41e2ffb9726f463b88554fa3adf500de7cd5e7700cd3c15a0052d19e80ed3
3
- size 2371443
 
 
 
 
pico-decoder-tiny/checkpoints/step_1000/model.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:1584cdf11592f978f9dd63c44fd15eec69dbc665b9b4c7a45d89a8f736931968
3
- size 45143592
 
 
 
 
pico-decoder-tiny/checkpoints/step_1000/pico_decoder.py DELETED
@@ -1,871 +0,0 @@
1
- """
2
- Pico Decoder: A Lightweight Causal Transformer Language Model
3
-
4
- Pico Decoder uses a simple LLAMA-style transformer architecture, written for clarity and educational purposes.
5
-
6
- Everything is written with a modular design for easy modification and experimentation.
7
-
8
- Key features:
9
- - RMSNorm for layer normalization
10
- - Rotary Positional Embeddings (RoPE)
11
- - Multi-head attention with KV-cache support
12
- - SwiGLU activation function
13
- - Residual connections throughout
14
-
15
- - KV-cache for faster autoregressive generation
16
-
17
- References:
18
- - RoPE: https://arxiv.org/abs/2104.09864
19
- - SwiGLU: https://arxiv.org/abs/2002.05202
20
- - LLAMA: https://arxiv.org/abs/2302.13971
21
-
22
- Adapted from:
23
- - OLMO: https://github.com/allenai/OLMo
24
- - LLAMA: https://github.com/meta/llama
25
- """
26
-
27
- from dataclasses import asdict
28
- from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
29
-
30
- import torch
31
- import torch.nn as nn
32
- import torch.nn.functional as F
33
- from torch.nn.attention import SDPBackend, sdpa_kernel
34
- from transformers import GenerationMixin, PretrainedConfig, PreTrainedModel
35
- from transformers.generation import GenerationConfig
36
- from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast
37
-
38
- try:
39
- if TYPE_CHECKING:
40
- # We need to do this to avoid importing these when creating the HF-compatible models
41
- from src.config import ModelConfig
42
- except ImportError:
43
- pass
44
-
45
- ########################################################
46
- #
47
- # Layer Normalization
48
- #
49
- ########################################################
50
-
51
-
52
- class RMSNorm(torch.nn.Module):
53
- """Root Mean Square Layer Normalization.
54
-
55
- A variant of Layer Normalization that uses RMS statistics instead of mean/variance,
56
- resulting in improved stability and performance.
57
-
58
- Args:
59
- config (Union[ModelConfig, PicoHFConfig]): Configuration object containing normalization parameters
60
- - config.norm_eps: Small constant for numerical stability
61
- - config.d_model: Model dimension for the weight parameter
62
-
63
- References:
64
- https://arxiv.org/abs/1910.07467
65
- """
66
-
67
- def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
68
- super().__init__()
69
- self.eps = config.norm_eps
70
- self.weight = nn.Parameter(torch.ones(config.d_model))
71
-
72
- def _norm(self, x: torch.Tensor) -> torch.Tensor:
73
- """
74
- Normalizes the input tensor by its RMS value.
75
- """
76
- return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
77
-
78
- def forward(self, x: torch.Tensor) -> torch.Tensor:
79
- """
80
- Applies RMS normalization to the input tensor and scales it by the weight parameter.
81
- """
82
- output = self._norm(x.float()).type_as(x)
83
- return output * self.weight
84
-
85
-
86
- ########################################################
87
- #
88
- # Positional Embedding
89
- #
90
- ########################################################
91
-
92
-
93
- class RoPE(nn.Module):
94
- """Rotary Positional Embeddings (RoPE).
95
-
96
- Implements position-dependent rotation of keys and queries in attention mechanism,
97
- allowing better modeling of relative positions in sequences. Uses complex number
98
- operations for efficient rotation.
99
-
100
- Args:
101
- config (Union[ModelConfig, PicoHFConfig]): Model configuration containing:
102
- - config.position_emb_theta: Base for frequency computation
103
- - config.d_model: Model dimension
104
- - config.attention_n_heads: Number of attention heads
105
- - config.max_seq_len: Maximum sequence length
106
-
107
- References:
108
- https://arxiv.org/abs/2104.09864
109
- """
110
-
111
- _freqs_cis_tensor: torch.Tensor | None = None
112
-
113
- def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
114
- super().__init__()
115
-
116
- self.theta = config.position_emb_theta
117
- self.dim = config.d_model // config.attention_n_heads
118
-
119
- max_seq_len = config.max_seq_len
120
-
121
- # only gets set once, and then reused for all RoPE instances
122
- if RoPE._freqs_cis_tensor is None:
123
- RoPE._freqs_cis_tensor = self._setup_freqs_cis(
124
- max_seq_len, self.theta, self.dim
125
- )
126
-
127
- # register _freqs_cis buffer
128
- # can be easily recomputed so persistent=False
129
- self.register_buffer("_freqs_cis", self._freqs_cis_tensor, persistent=False)
130
-
131
- @classmethod
132
- def _setup_freqs_cis(cls, seq_len: int, theta: float, dim: int) -> torch.Tensor:
133
- """Setup Frequency Tensor for RoPE Embeddings
134
-
135
- Initializes the complex frequency tensor that is used to compute the RoPE embeddings.
136
-
137
- Note other implementations will use cos and sin directly, but using the complex
138
- number representation is (probably) more efficient:
139
-
140
- e^(theta * i * t) = cos(theta * t) + i * sin(theta * t) [Euler's formula]
141
- """
142
- _freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
143
- positions = torch.arange(seq_len)
144
- freqs = torch.outer(positions, _freqs)
145
- return torch.polar(torch.ones_like(freqs), freqs) # complex64
146
-
147
- def get_freqs_cis(
148
- self, input_shape: torch.Size, start_pos: int, end_pos: int
149
- ) -> torch.Tensor:
150
- """Reshape Frequency Tensor for RoPE Embeddings
151
-
152
- Makes the frequency tensor broadcastable with the input tensor.
153
- """
154
- _freqs_cis = self._freqs_cis[start_pos:end_pos]
155
- ndim = len(input_shape)
156
- assert 0 <= 1 < ndim
157
- assert _freqs_cis.shape == (input_shape[1], input_shape[-1])
158
-
159
- # TODO: Check whether this is correct (might be able to remove this)
160
- shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(input_shape)]
161
- return _freqs_cis.view(*shape)
162
-
163
- def forward(
164
- self,
165
- queries: torch.Tensor,
166
- keys: torch.Tensor,
167
- start_pos: int = 0,
168
- ) -> Tuple[torch.Tensor, torch.Tensor]:
169
- """Apply RoPE Embeddings to Queries and Keys
170
-
171
- Applies the rotary positional embeddings to the input tensors via complex num multiplication
172
-
173
- NOTE: The start_pos is used if we want to use the kv_cache in the attention mechanism.
174
- """
175
- queries_ = torch.view_as_complex(
176
- queries.float().reshape(*queries.shape[:-1], -1, 2)
177
- )
178
- keys_ = torch.view_as_complex(keys.float().reshape(*keys.shape[:-1], -1, 2))
179
-
180
- input_shape = (
181
- queries_.shape
182
- ) # same as keys: (batch_size, seq_len, n_heads, head_dim/2)
183
- freqs_start_pos = start_pos
184
- freqs_end_pos = freqs_start_pos + queries_.shape[1]
185
-
186
- freqs_cis = self.get_freqs_cis(input_shape, freqs_start_pos, freqs_end_pos)
187
-
188
- queries_rotated = torch.view_as_real(queries_ * freqs_cis).flatten(3)
189
- keys_rotated = torch.view_as_real(keys_ * freqs_cis).flatten(3)
190
- return queries_rotated.type_as(queries), keys_rotated.type_as(keys)
191
-
192
-
193
- ########################################################
194
- #
195
- # Attention
196
- #
197
- ########################################################
198
-
199
-
200
- class Attention(nn.Module):
201
- """Multi-head Attention with Group Query Attention support.
202
-
203
- Implements scaled dot-product attention and supports:
204
- - Grouped Query Attention (GQA)
205
- - Key-Value caching for efficient inference
206
- - RoPE integration
207
-
208
- Args:
209
- config (Union[ModelConfig, PretrainedConfig]): Configuration containing:
210
- - config.attention_n_heads: Number of attention heads
211
- - config.attention_n_kv_heads: Number of key/value heads
212
- - config.d_model: Model dimension
213
- - config.batch_size: Maximum batch size
214
- - config.max_seq_len: Maximum sequence length
215
-
216
- Shape:
217
- - Input: (batch_size, seq_len, d_model)
218
- - Output: (batch_size, seq_len, d_model)
219
- """
220
-
221
- def __init__(
222
- self,
223
- config: Union["ModelConfig", "PicoDecoderHFConfig"],
224
- ):
225
- super().__init__()
226
-
227
- self.n_heads = config.attention_n_heads
228
- self.n_kv_heads = config.attention_n_kv_heads
229
-
230
- self.batch_size = config.batch_size
231
- self.max_seq_len = config.max_seq_len
232
-
233
- d_model = config.d_model
234
- self.head_dim = d_model // self.n_heads
235
-
236
- self.n_rep = self.n_heads // self.n_kv_heads
237
-
238
- self.q_proj = nn.Linear(d_model, self.n_heads * self.head_dim, bias=False)
239
- self.k_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
240
- self.v_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
241
- self.o_proj = nn.Linear(self.n_heads * self.head_dim, d_model, bias=False)
242
-
243
- self.rope = RoPE(config)
244
-
245
- def forward(
246
- self,
247
- input: torch.Tensor,
248
- mask: Optional[torch.Tensor] = None,
249
- past_key_values: Optional[Tuple[torch.Tensor, ...]] = None,
250
- use_cache: bool = False,
251
- ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
252
- """Forward pass for the attention mechanism.
253
-
254
- Computes queries, keys, and values for the attention mechanism. Applies rotary positional
255
- embeddings to the queries and keys, and then computes attention scores and outputs.
256
-
257
- For an introduction to the attention mechanism, see:
258
- https://arxiv.org/abs/1706.03762
259
-
260
- A few things to note:
261
- - The past_key_values is used to implement the KV cache, which is used to speed up
262
- generation by caching the KV pairs from previous forward passes. This is useful when doing
263
- tasks that require generating multiple tokens conditioned on previous tokens (e.g. language
264
- modeling, text generation, etc.). The way the KV cache is implemented is that each layer has
265
- its own KV cache - this KV cache is implemented as a tuple.
266
- """
267
- bsz, seq_len, _ = input.shape
268
- _queries, _keys, _values = (
269
- self.q_proj(input),
270
- self.k_proj(input),
271
- self.v_proj(input),
272
- )
273
-
274
- # Reshaping for multi-head attention
275
- queries = _queries.view(bsz, seq_len, self.n_heads, self.head_dim)
276
- keys = _keys.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
277
- values = _values.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
278
-
279
- # The start position is used to apply the RoPE embeddings to only the new tokens
280
- # when using the kv_cache in the attention mechanism.
281
- # We want to start from the last position in the cache.
282
- start_pos = past_key_values[0].shape[1] if past_key_values is not None else 0
283
-
284
- # apply rotary positional embeddings
285
- queries, keys = self.rope(queries, keys, start_pos)
286
-
287
- if past_key_values is not None:
288
- keys = torch.cat([past_key_values[0], keys], dim=1)
289
- values = torch.cat([past_key_values[1], values], dim=1)
290
-
291
- if use_cache:
292
- cached_keys = keys
293
- cached_values = values
294
- else:
295
- cached_keys = None
296
- cached_values = None
297
-
298
- queries = queries.transpose(1, 2)
299
- keys = keys.transpose(1, 2)
300
- values = values.transpose(1, 2)
301
-
302
- apply_gqa = self.n_rep > 1
303
- if apply_gqa and queries.device.type == "mps":
304
- # NOTE: MPS does not support GQA in the SDPA kernel, but we can repeat the keys and values
305
- # outside of the kernel to get the same effect.
306
- # See: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
307
- keys = keys.repeat_interleave(self.n_rep, dim=-3)
308
- values = values.repeat_interleave(self.n_rep, dim=-3)
309
- apply_gqa = False
310
-
311
- backends = [SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH]
312
-
313
- with sdpa_kernel(backends=backends):
314
- attn_output = F.scaled_dot_product_attention(
315
- queries.contiguous(),
316
- keys.contiguous(),
317
- values.contiguous(),
318
- attn_mask=mask.to(queries.dtype) if mask is not None else None,
319
- enable_gqa=apply_gqa,
320
- )
321
-
322
- attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
323
- output = self.o_proj(attn_output)
324
-
325
- return output, (cached_keys, cached_values)
326
-
327
-
328
- ########################################################
329
- #
330
- # SwiGLU (Combines MLP and Activation)
331
- #
332
- ########################################################
333
-
334
-
335
- class SwiGLU(nn.Module):
336
- """SwiGLU Activation Function with Linear Projections.
337
-
338
- Implements the SwiGLU activation function combined with linear transformations,
339
- serving as the feed-forward network in transformer blocks.
340
-
341
- Args:
342
- config (Union[ModelConfig, PicoDecoderHFConfig]): Configuration containing:
343
- - config.d_model: Model dimension
344
- - config.activation_hidden_dim: Hidden dimension (typically 4 * d_model)
345
-
346
- References:
347
- https://arxiv.org/abs/2002.05202
348
- """
349
-
350
- def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
351
- super().__init__()
352
-
353
- model_dim = config.d_model
354
- act_hidden_dim = config.activation_hidden_dim # usually 4 * d_model
355
-
356
- self.w_0 = nn.Linear(model_dim, act_hidden_dim, bias=False)
357
- self.w_1 = nn.Linear(model_dim, act_hidden_dim, bias=False)
358
- self.w_2 = nn.Linear(act_hidden_dim, model_dim, bias=False)
359
-
360
- def forward(self, x: torch.Tensor) -> torch.Tensor:
361
- return self.w_2(F.silu(self.w_0(x)) * self.w_1(x))
362
-
363
-
364
- ########################################################
365
- #
366
- # PicoDecoderBlock
367
- #
368
- ########################################################
369
-
370
-
371
- class PicoDecoderBlock(nn.Module):
372
- """Single Transformer Block with Attention and Feed-forward layers.
373
-
374
- Implements a standard transformer block with:
375
- - Multi-head attention with normalization and residual connection
376
- - SwiGLU feed-forward network with normalization and residual connection
377
-
378
- Args:
379
- config (Union[ModelConfig, PicoDecoderHFConfig]): Model configuration; either a dataclass or
380
- a HuggingFace PicoDecoderHFConfig
381
- """
382
-
383
- def __init__(
384
- self,
385
- config: Union["ModelConfig", "PicoDecoderHFConfig"],
386
- ):
387
- super().__init__()
388
-
389
- self.attention = Attention(config)
390
- self.swiglu = SwiGLU(config)
391
- self.attention_norm = RMSNorm(config)
392
- self.swiglu_norm = RMSNorm(config)
393
-
394
- def forward(
395
- self,
396
- input: torch.Tensor,
397
- mask: Optional[torch.Tensor] = None,
398
- past_key_values: Optional[Tuple[torch.Tensor]] = None,
399
- use_cache: bool = False,
400
- ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
401
- attention_output, cached_key_values = self.attention(
402
- self.attention_norm(input),
403
- mask=mask,
404
- past_key_values=past_key_values,
405
- use_cache=use_cache,
406
- )
407
- # NOTE: cached_key_values is None if use_cache is False
408
-
409
- h = input + attention_output
410
- out = h + self.swiglu(self.swiglu_norm(h))
411
- return out, cached_key_values
412
-
413
-
414
- ########################################################
415
- #
416
- # Pico Decoder (Causal Transformer Model)
417
- #
418
- ########################################################
419
-
420
-
421
- class PicoDecoder(nn.Module):
422
- """
423
- Pico Decoder: combines the embedding, causal decoder blocks, and output projection into a
424
- single autoregressive model.
425
-
426
- For more information on the model, see the classes for the modules that make up the model.
427
- """
428
-
429
- def __init__(
430
- self,
431
- model_config: Union["ModelConfig", "PicoDecoderHFConfig"],
432
- ):
433
- super().__init__()
434
- self.config = model_config
435
-
436
- self.embedding_proj = nn.Embedding(self.config.vocab_size, self.config.d_model)
437
- self.layers = nn.ModuleList(
438
- [PicoDecoderBlock(self.config) for _ in range(self.config.n_layers)]
439
- )
440
- self.output_norm = RMSNorm(self.config)
441
- self.de_embedding_proj = nn.Linear(
442
- self.config.d_model, self.config.vocab_size, bias=False
443
- )
444
-
445
- def convert_to_hf_model(self) -> "PicoDecoderHF":
446
- """Convert the Lightning model to a HuggingFace model."""
447
- # Create HF config without fabric-specific settings
448
- hf_config = PicoDecoderHFConfig.from_dataclass(self.config)
449
-
450
- # Create new HF model
451
- hf_model = PicoDecoderHF(hf_config)
452
-
453
- # Copy state dict, excluding fabric-specific keys
454
- hf_model.load_state_dict(self.state_dict(prefix="pico_decoder."))
455
-
456
- return hf_model
457
-
458
- def forward(
459
- self,
460
- input_ids: torch.Tensor,
461
- past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
462
- use_cache: bool = False,
463
- ) -> Tuple[torch.Tensor, Optional[Tuple[Tuple[torch.Tensor, torch.Tensor]]]]:
464
- """
465
- This is the forward pass for the entire Pico model. It boils down to:
466
- - Embedding the input ids
467
- - Creating a causal mask
468
- - Processing through the pico layers
469
- - Projecting the output to logits
470
-
471
- NOTE: One feature that might be confusing is the KV cache. The KV cache is used to speed up
472
- generation by caching the KV pairs from previous forward passes. This is useful when doing
473
- tasks that require generating multiple tokens conditioned on previous tokens (e.g. language
474
- modeling, text generation, etc.). The way the KV cache is implemented is that each layer has
475
- its own KV cache which is stored as a tuple. The whole model then stores a tuple of these
476
- KV caches (so a tuple of tuples).
477
- """
478
-
479
- seq_len = input_ids.shape[-1]
480
- h = self.embedding_proj(input_ids)
481
-
482
- # Calculate start position from past cached KV pairs. Remember that each layer has its
483
- # own KV Cache. So when we index past_key_values, we need to index into the KV pairs for the
484
- # correct layer and then for either the keys or values.
485
- start_pos = 0 if past_key_values is None else past_key_values[0][0].shape[1]
486
-
487
- # Create causal mask for current sequence
488
- mask = None
489
- if seq_len > 1:
490
- mask = torch.full((seq_len, seq_len), float("-inf"))
491
- mask = torch.triu(mask, diagonal=1)
492
-
493
- # If using KV cache, extend mask to cover cached sequence length
494
- if past_key_values is not None:
495
- # Add zeros for cached tokens (we can attend to all of them)
496
- mask = torch.hstack([torch.zeros((seq_len, start_pos)), mask])
497
-
498
- mask = mask.to(h.device)
499
-
500
- # NOTE: If we are using the cache, we need to store the cached KV pairs for each layer
501
- # in a tuple. Each layer will have its own cached KV pair which we aggregate in a tuple.
502
- cached_key_values = () if use_cache else None
503
-
504
- # Process through transformer blocks
505
- for idx, layer in enumerate(self.layers):
506
- layer_past_key_values = (
507
- past_key_values[idx] if past_key_values is not None else None
508
- )
509
-
510
- h, layer_cached_key_values = layer(
511
- h, mask=mask, past_key_values=layer_past_key_values, use_cache=use_cache
512
- )
513
-
514
- if use_cache:
515
- cached_key_values += (layer_cached_key_values,)
516
-
517
- # Final norm and projection
518
- h = self.output_norm(h)
519
- logits = self.de_embedding_proj(h).float()
520
-
521
- return logits, cached_key_values
522
-
523
-
524
- ########################################################
525
- #
526
- # HuggingFace Wrapper for the Pico Decoder model.
527
- #
528
- ########################################################
529
-
530
-
531
- class PicoDecoderHFConfig(PretrainedConfig):
532
- """Config class for the Pico Decoder HuggingFace wrapper."""
533
-
534
- model_type = "pico_decoder"
535
-
536
- @classmethod
537
- def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PicoDecoderHFConfig":
538
- """
539
- Initialize config from a dictionary. Note that no kwargs are passed to the constructor --
540
- this is because with some kwargs special handling is required and can make this class
541
- brittle.
542
- """
543
- pico_config = cls(**config_dict)
544
-
545
- return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
546
- unused_kwargs = {
547
- key: value for key, value in kwargs.items() if not hasattr(pico_config, key)
548
- }
549
-
550
- if return_unused_kwargs:
551
- return pico_config, unused_kwargs
552
- return pico_config
553
-
554
- @classmethod
555
- def from_dataclass(cls, model_config: "ModelConfig"):
556
- """Initialise from our custom config dataclass."""
557
- return cls.from_dict(asdict(model_config))
558
-
559
-
560
- class PicoDecoderHF(PreTrainedModel, GenerationMixin):
561
- """
562
- HuggingFace wrapper for the Pico model with generation support.
563
-
564
- Many evaluation frameworks require a model be setup as a HuggingFace model, so we provide a simple
565
- wrapper that does just that. When we save checkpoints of the Pico model, we save both the normal
566
- Pico model as well as the model wrapped in this HuggingFace class.
567
-
568
- This also lets you do cool things like:
569
-
570
- `model = AutoModelForCausalLM.from_pretrained("path/to/checkpoint")`
571
- """
572
-
573
- config_class = PicoDecoderHFConfig
574
- _no_split_modules = ["PicoBlock", "Attention", "SwiGLU", "RMSNorm"]
575
- main_input_name = "input_ids"
576
-
577
- def __init__(self, config: PicoDecoderHFConfig):
578
- super().__init__(config)
579
- self.pico_decoder = PicoDecoder(config)
580
- # Initialize generation config with defaults
581
- self.generation_config = GenerationConfig()
582
- # Set some reasonable defaults for the model
583
- if hasattr(config, "max_position_embeddings"):
584
- self.generation_config.max_length = config.max_position_embeddings
585
- if hasattr(config, "vocab_size"):
586
- self.generation_config.vocab_size = config.vocab_size
587
-
588
- def forward(
589
- self,
590
- input_ids: torch.Tensor,
591
- past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
592
- use_cache: bool = False,
593
- **kwargs,
594
- ) -> Union[CausalLMOutput, CausalLMOutputWithPast]:
595
- """HuggingFace forward pass wrapper.
596
-
597
- Forwards pass for the HuggingFace version of the Pico Model. Basic wrapper around the
598
- Pico model's forward pass, and returns the output as a HuggingFace CausalLMOutput.
599
- """
600
- logits, past_key_values = self.pico_decoder(
601
- input_ids, past_key_values, use_cache
602
- )
603
- if use_cache:
604
- return CausalLMOutputWithPast(
605
- logits=logits,
606
- past_key_values=past_key_values,
607
- )
608
- else:
609
- return CausalLMOutput(
610
- logits=logits,
611
- )
612
-
613
- def prepare_inputs_for_generation(
614
- self,
615
- input_ids: torch.LongTensor,
616
- past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
617
- attention_mask: Optional[torch.LongTensor] = None,
618
- **kwargs,
619
- ) -> Dict[str, Any]:
620
- """
621
- Prepare inputs for generation.
622
-
623
- Args:
624
- input_ids: Input token IDs
625
- past_key_values: Cached key-value pairs from previous forward passes
626
- attention_mask: Attention mask for the input
627
- **kwargs: Additional arguments
628
-
629
- Returns:
630
- Dictionary containing prepared inputs
631
- """
632
- # If we have past_key_values, we only need the last token
633
- if past_key_values is not None:
634
- input_ids = input_ids[:, -1:]
635
-
636
- return {
637
- "input_ids": input_ids,
638
- "past_key_values": past_key_values,
639
- "use_cache": True,
640
- }
641
-
642
- def get_input_embeddings(self):
643
- """Get the input embeddings layer."""
644
- return self.pico_decoder.embedding_proj
645
-
646
- def set_input_embeddings(self, value):
647
- """Set the input embeddings layer."""
648
- self.pico_decoder.embedding_proj = value
649
-
650
- def get_output_embeddings(self):
651
- """Get the output embeddings layer."""
652
- return self.pico_decoder.de_embedding_proj
653
-
654
- def set_output_embeddings(self, value):
655
- """Set the output embeddings layer."""
656
- self.pico_decoder.de_embedding_proj = value
657
-
658
- def get_lm_head(self):
659
- """Get the language model head."""
660
- return self.pico_decoder.de_embedding_proj
661
-
662
- def can_generate(self) -> bool:
663
- """Check if the model can generate text."""
664
- return True
665
-
666
- @property
667
- def is_encoder_decoder(self) -> bool:
668
- """Check if the model is an encoder-decoder model."""
669
- return False
670
-
671
- @property
672
- def can_use_cache(self) -> bool:
673
- """Check if the model can use KV cache."""
674
- return True
675
-
676
- def resize_token_embeddings(
677
- self, new_num_tokens: Optional[int] = None
678
- ) -> torch.nn.Embedding:
679
- """Resize token embeddings."""
680
- old_embeddings = self.get_input_embeddings()
681
- if new_num_tokens is None:
682
- new_num_tokens = old_embeddings.num_embeddings
683
-
684
- new_embeddings = torch.nn.Embedding(
685
- new_num_tokens, old_embeddings.embedding_dim
686
- )
687
- new_embeddings.weight.data[: old_embeddings.num_embeddings] = (
688
- old_embeddings.weight.data
689
- )
690
-
691
- self.pico_decoder.embedding_proj = new_embeddings
692
- self.pico_decoder.de_embedding_proj = torch.nn.Linear(
693
- old_embeddings.embedding_dim, new_num_tokens, bias=False
694
- )
695
-
696
- return new_embeddings
697
-
698
-
699
- # Register for auto classes
700
- PicoDecoderHFConfig.register_for_auto_class()
701
- PicoDecoderHF.register_for_auto_class("AutoModel")
702
- PicoDecoderHF.register_for_auto_class("AutoModelForCausalLM")
703
-
704
-
705
- ########################################################
706
- #
707
- # New PicoDecoderForCausalLM class for generation support
708
- #
709
- ########################################################
710
-
711
-
712
- class PicoDecoderForCausalLM(PreTrainedModel, GenerationMixin):
713
- """
714
- PicoDecoderForCausalLM: A HuggingFace-compatible model that properly supports generation.
715
-
716
- This class is designed to work with existing checkpoints and provides full generation support.
717
- It inherits from the right base classes that HuggingFace expects for text generation.
718
- """
719
-
720
- config_class = PicoDecoderHFConfig
721
- _no_split_modules = ["PicoBlock", "Attention", "SwiGLU", "RMSNorm"]
722
- main_input_name = "input_ids"
723
-
724
- def __init__(self, config: PicoDecoderHFConfig):
725
- super().__init__(config)
726
- self.pico_decoder = PicoDecoder(config)
727
- # Initialize generation config with defaults
728
- self.generation_config = GenerationConfig()
729
- # Set some reasonable defaults for the model
730
- if hasattr(config, "max_position_embeddings"):
731
- self.generation_config.max_length = config.max_position_embeddings
732
- if hasattr(config, "vocab_size"):
733
- self.generation_config.vocab_size = config.vocab_size
734
-
735
- def forward(
736
- self,
737
- input_ids: torch.Tensor,
738
- past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
739
- use_cache: bool = False,
740
- **kwargs,
741
- ) -> Union[CausalLMOutput, CausalLMOutputWithPast]:
742
- """Forward pass for text generation."""
743
- logits, past_key_values = self.pico_decoder(
744
- input_ids, past_key_values, use_cache
745
- )
746
- if use_cache:
747
- return CausalLMOutputWithPast(
748
- logits=logits,
749
- past_key_values=past_key_values,
750
- )
751
- else:
752
- return CausalLMOutput(
753
- logits=logits,
754
- )
755
-
756
- def prepare_inputs_for_generation(
757
- self,
758
- input_ids: torch.LongTensor,
759
- past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
760
- attention_mask: Optional[torch.LongTensor] = None,
761
- **kwargs,
762
- ) -> Dict[str, Any]:
763
- """Prepare inputs for generation."""
764
- # If we have past_key_values, we only need the last token
765
- if past_key_values is not None:
766
- input_ids = input_ids[:, -1:]
767
-
768
- return {
769
- "input_ids": input_ids,
770
- "past_key_values": past_key_values,
771
- "use_cache": True,
772
- }
773
-
774
- def get_input_embeddings(self):
775
- """Get the input embeddings layer."""
776
- return self.pico_decoder.embedding_proj
777
-
778
- def set_input_embeddings(self, value):
779
- """Set the input embeddings layer."""
780
- self.pico_decoder.embedding_proj = value
781
-
782
- def get_output_embeddings(self):
783
- """Get the output embeddings layer."""
784
- return self.pico_decoder.de_embedding_proj
785
-
786
- def set_output_embeddings(self, value):
787
- """Set the output embeddings layer."""
788
- self.pico_decoder.de_embedding_proj = value
789
-
790
- def get_lm_head(self):
791
- """Get the language model head."""
792
- return self.pico_decoder.de_embedding_proj
793
-
794
- def can_generate(self) -> bool:
795
- """Check if the model can generate text."""
796
- return True
797
-
798
- @property
799
- def is_encoder_decoder(self) -> bool:
800
- """Check if the model is an encoder-decoder model."""
801
- return False
802
-
803
- @property
804
- def can_use_cache(self) -> bool:
805
- """Check if the model can use KV cache."""
806
- return True
807
-
808
- def resize_token_embeddings(
809
- self, new_num_tokens: Optional[int] = None
810
- ) -> torch.nn.Embedding:
811
- """Resize token embeddings."""
812
- old_embeddings = self.get_input_embeddings()
813
- if new_num_tokens is None:
814
- new_num_tokens = old_embeddings.num_embeddings
815
-
816
- new_embeddings = torch.nn.Embedding(
817
- new_num_tokens, old_embeddings.embedding_dim
818
- )
819
- new_embeddings.weight.data[: old_embeddings.num_embeddings] = (
820
- old_embeddings.weight.data
821
- )
822
-
823
- self.pico_decoder.embedding_proj = new_embeddings
824
- self.pico_decoder.de_embedding_proj = torch.nn.Linear(
825
- old_embeddings.embedding_dim, new_num_tokens, bias=False
826
- )
827
-
828
- return new_embeddings
829
-
830
- @classmethod
831
- def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
832
- """
833
- Load a pretrained model from a checkpoint.
834
-
835
- This method handles loading from both the old PicoDecoderHF format and the new format.
836
- """
837
- # First try to load with the new class
838
- try:
839
- return super().from_pretrained(
840
- pretrained_model_name_or_path, *model_args, **kwargs
841
- )
842
- except Exception as e:
843
- print(f"Failed to load with new class: {e}")
844
- print("Attempting to load with legacy class and convert...")
845
-
846
- # Try to load with the old class and convert
847
- try:
848
- from transformers import AutoModel
849
-
850
- old_model = AutoModel.from_pretrained(
851
- pretrained_model_name_or_path,
852
- trust_remote_code=True,
853
- *model_args,
854
- **kwargs,
855
- )
856
-
857
- # Create new model instance
858
- new_model = cls(old_model.config)
859
-
860
- # Copy state dict
861
- new_model.load_state_dict(old_model.state_dict(), strict=False)
862
-
863
- return new_model
864
-
865
- except Exception as e2:
866
- print(f"Failed to convert from legacy format: {e2}")
867
- raise e
868
-
869
-
870
- # Register the new class
871
- PicoDecoderForCausalLM.register_for_auto_class("AutoModelForCausalLM")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pico-decoder-tiny/checkpoints/step_1000/special_tokens_map.json DELETED
@@ -1,16 +0,0 @@
1
- {
2
- "eos_token": {
3
- "content": "<|endoftext|>",
4
- "lstrip": false,
5
- "normalized": false,
6
- "rstrip": false,
7
- "single_word": false
8
- },
9
- "pad_token": {
10
- "content": "<|padding|>",
11
- "lstrip": false,
12
- "normalized": false,
13
- "rstrip": false,
14
- "single_word": false
15
- }
16
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pico-decoder-tiny/checkpoints/step_1000/tokenizer.json DELETED
The diff for this file is too large to render. See raw diff
 
pico-decoder-tiny/checkpoints/step_1000/tokenizer_config.json DELETED
@@ -1,239 +0,0 @@
1
- {
2
- "add_bos_token": false,
3
- "add_eos_token": false,
4
- "add_prefix_space": false,
5
- "added_tokens_decoder": {
6
- "0": {
7
- "content": "|||IP_ADDRESS|||",
8
- "lstrip": false,
9
- "normalized": true,
10
- "rstrip": false,
11
- "single_word": false,
12
- "special": false
13
- },
14
- "1": {
15
- "content": "<|padding|>",
16
- "lstrip": false,
17
- "normalized": false,
18
- "rstrip": false,
19
- "single_word": false,
20
- "special": true
21
- },
22
- "50254": {
23
- "content": " ",
24
- "lstrip": false,
25
- "normalized": true,
26
- "rstrip": false,
27
- "single_word": false,
28
- "special": false
29
- },
30
- "50255": {
31
- "content": " ",
32
- "lstrip": false,
33
- "normalized": true,
34
- "rstrip": false,
35
- "single_word": false,
36
- "special": false
37
- },
38
- "50256": {
39
- "content": " ",
40
- "lstrip": false,
41
- "normalized": true,
42
- "rstrip": false,
43
- "single_word": false,
44
- "special": false
45
- },
46
- "50257": {
47
- "content": " ",
48
- "lstrip": false,
49
- "normalized": true,
50
- "rstrip": false,
51
- "single_word": false,
52
- "special": false
53
- },
54
- "50258": {
55
- "content": " ",
56
- "lstrip": false,
57
- "normalized": true,
58
- "rstrip": false,
59
- "single_word": false,
60
- "special": false
61
- },
62
- "50259": {
63
- "content": " ",
64
- "lstrip": false,
65
- "normalized": true,
66
- "rstrip": false,
67
- "single_word": false,
68
- "special": false
69
- },
70
- "50260": {
71
- "content": " ",
72
- "lstrip": false,
73
- "normalized": true,
74
- "rstrip": false,
75
- "single_word": false,
76
- "special": false
77
- },
78
- "50261": {
79
- "content": " ",
80
- "lstrip": false,
81
- "normalized": true,
82
- "rstrip": false,
83
- "single_word": false,
84
- "special": false
85
- },
86
- "50262": {
87
- "content": " ",
88
- "lstrip": false,
89
- "normalized": true,
90
- "rstrip": false,
91
- "single_word": false,
92
- "special": false
93
- },
94
- "50263": {
95
- "content": " ",
96
- "lstrip": false,
97
- "normalized": true,
98
- "rstrip": false,
99
- "single_word": false,
100
- "special": false
101
- },
102
- "50264": {
103
- "content": " ",
104
- "lstrip": false,
105
- "normalized": true,
106
- "rstrip": false,
107
- "single_word": false,
108
- "special": false
109
- },
110
- "50265": {
111
- "content": " ",
112
- "lstrip": false,
113
- "normalized": true,
114
- "rstrip": false,
115
- "single_word": false,
116
- "special": false
117
- },
118
- "50266": {
119
- "content": " ",
120
- "lstrip": false,
121
- "normalized": true,
122
- "rstrip": false,
123
- "single_word": false,
124
- "special": false
125
- },
126
- "50267": {
127
- "content": " ",
128
- "lstrip": false,
129
- "normalized": true,
130
- "rstrip": false,
131
- "single_word": false,
132
- "special": false
133
- },
134
- "50268": {
135
- "content": " ",
136
- "lstrip": false,
137
- "normalized": true,
138
- "rstrip": false,
139
- "single_word": false,
140
- "special": false
141
- },
142
- "50269": {
143
- "content": " ",
144
- "lstrip": false,
145
- "normalized": true,
146
- "rstrip": false,
147
- "single_word": false,
148
- "special": false
149
- },
150
- "50270": {
151
- "content": " ",
152
- "lstrip": false,
153
- "normalized": true,
154
- "rstrip": false,
155
- "single_word": false,
156
- "special": false
157
- },
158
- "50271": {
159
- "content": " ",
160
- "lstrip": false,
161
- "normalized": true,
162
- "rstrip": false,
163
- "single_word": false,
164
- "special": false
165
- },
166
- "50272": {
167
- "content": " ",
168
- "lstrip": false,
169
- "normalized": true,
170
- "rstrip": false,
171
- "single_word": false,
172
- "special": false
173
- },
174
- "50273": {
175
- "content": " ",
176
- "lstrip": false,
177
- "normalized": true,
178
- "rstrip": false,
179
- "single_word": false,
180
- "special": false
181
- },
182
- "50274": {
183
- "content": " ",
184
- "lstrip": false,
185
- "normalized": true,
186
- "rstrip": false,
187
- "single_word": false,
188
- "special": false
189
- },
190
- "50275": {
191
- "content": " ",
192
- "lstrip": false,
193
- "normalized": true,
194
- "rstrip": false,
195
- "single_word": false,
196
- "special": false
197
- },
198
- "50276": {
199
- "content": " ",
200
- "lstrip": false,
201
- "normalized": true,
202
- "rstrip": false,
203
- "single_word": false,
204
- "special": false
205
- },
206
- "50277": {
207
- "content": "|||EMAIL_ADDRESS|||",
208
- "lstrip": false,
209
- "normalized": true,
210
- "rstrip": false,
211
- "single_word": false,
212
- "special": false
213
- },
214
- "50278": {
215
- "content": "|||PHONE_NUMBER|||",
216
- "lstrip": false,
217
- "normalized": true,
218
- "rstrip": false,
219
- "single_word": false,
220
- "special": false
221
- },
222
- "50279": {
223
- "content": "<|endoftext|>",
224
- "lstrip": false,
225
- "normalized": false,
226
- "rstrip": false,
227
- "single_word": false,
228
- "special": true
229
- }
230
- },
231
- "bos_token": null,
232
- "clean_up_tokenization_spaces": true,
233
- "eos_token": "<|endoftext|>",
234
- "extra_special_tokens": {},
235
- "model_max_length": 1000000000000000019884624838656,
236
- "pad_token": "<|padding|>",
237
- "tokenizer_class": "GPTNeoXTokenizer",
238
- "unk_token": null
239
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pico-decoder-tiny/checkpoints/step_1755/config.json DELETED
@@ -1,22 +0,0 @@
1
- {
2
- "activation_hidden_dim": 384,
3
- "architectures": [
4
- "PicoDecoderHF"
5
- ],
6
- "attention_n_heads": 12,
7
- "attention_n_kv_heads": 4,
8
- "auto_map": {
9
- "AutoConfig": "pico_decoder.PicoDecoderHFConfig",
10
- "AutoModelForCausalLM": "pico_decoder.PicoDecoderHF"
11
- },
12
- "batch_size": 1024,
13
- "d_model": 96,
14
- "max_seq_len": 2048,
15
- "model_type": "pico_decoder",
16
- "n_layers": 12,
17
- "norm_eps": 1e-06,
18
- "position_emb_theta": 10000.0,
19
- "torch_dtype": "float32",
20
- "transformers_version": "4.48.3",
21
- "vocab_size": 50304
22
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pico-decoder-tiny/checkpoints/step_1755/fabric_state/checkpoint.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:f78990840c8b2a26e89eea5f5414a84a9e8a0c76b9637d3cac17ec22e5486678
3
- size 135543171
 
 
 
 
pico-decoder-tiny/checkpoints/step_1755/generation_config.json DELETED
@@ -1,4 +0,0 @@
1
- {
2
- "transformers_version": "4.48.3",
3
- "vocab_size": 50304
4
- }
 
 
 
 
 
pico-decoder-tiny/checkpoints/step_1755/model.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:3084d44929c019203e308a3f500b8792ca69ff273c69edb7eb6a433268e540f9
3
- size 45143592
 
 
 
 
pico-decoder-tiny/checkpoints/step_1755/pico_decoder.py DELETED
@@ -1,871 +0,0 @@
1
- """
2
- Pico Decoder: A Lightweight Causal Transformer Language Model
3
-
4
- Pico Decoder uses a simple LLAMA-style transformer architecture, written for clarity and educational purposes.
5
-
6
- Everything is written with a modular design for easy modification and experimentation.
7
-
8
- Key features:
9
- - RMSNorm for layer normalization
10
- - Rotary Positional Embeddings (RoPE)
11
- - Multi-head attention with KV-cache support
12
- - SwiGLU activation function
13
- - Residual connections throughout
14
-
15
- - KV-cache for faster autoregressive generation
16
-
17
- References:
18
- - RoPE: https://arxiv.org/abs/2104.09864
19
- - SwiGLU: https://arxiv.org/abs/2002.05202
20
- - LLAMA: https://arxiv.org/abs/2302.13971
21
-
22
- Adapted from:
23
- - OLMO: https://github.com/allenai/OLMo
24
- - LLAMA: https://github.com/meta/llama
25
- """
26
-
27
- from dataclasses import asdict
28
- from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
29
-
30
- import torch
31
- import torch.nn as nn
32
- import torch.nn.functional as F
33
- from torch.nn.attention import SDPBackend, sdpa_kernel
34
- from transformers import GenerationMixin, PretrainedConfig, PreTrainedModel
35
- from transformers.generation import GenerationConfig
36
- from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast
37
-
38
- try:
39
- if TYPE_CHECKING:
40
- # We need to do this to avoid importing these when creating the HF-compatible models
41
- from src.config import ModelConfig
42
- except ImportError:
43
- pass
44
-
45
- ########################################################
46
- #
47
- # Layer Normalization
48
- #
49
- ########################################################
50
-
51
-
52
- class RMSNorm(torch.nn.Module):
53
- """Root Mean Square Layer Normalization.
54
-
55
- A variant of Layer Normalization that uses RMS statistics instead of mean/variance,
56
- resulting in improved stability and performance.
57
-
58
- Args:
59
- config (Union[ModelConfig, PicoHFConfig]): Configuration object containing normalization parameters
60
- - config.norm_eps: Small constant for numerical stability
61
- - config.d_model: Model dimension for the weight parameter
62
-
63
- References:
64
- https://arxiv.org/abs/1910.07467
65
- """
66
-
67
- def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
68
- super().__init__()
69
- self.eps = config.norm_eps
70
- self.weight = nn.Parameter(torch.ones(config.d_model))
71
-
72
- def _norm(self, x: torch.Tensor) -> torch.Tensor:
73
- """
74
- Normalizes the input tensor by its RMS value.
75
- """
76
- return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
77
-
78
- def forward(self, x: torch.Tensor) -> torch.Tensor:
79
- """
80
- Applies RMS normalization to the input tensor and scales it by the weight parameter.
81
- """
82
- output = self._norm(x.float()).type_as(x)
83
- return output * self.weight
84
-
85
-
86
- ########################################################
87
- #
88
- # Positional Embedding
89
- #
90
- ########################################################
91
-
92
-
93
- class RoPE(nn.Module):
94
- """Rotary Positional Embeddings (RoPE).
95
-
96
- Implements position-dependent rotation of keys and queries in attention mechanism,
97
- allowing better modeling of relative positions in sequences. Uses complex number
98
- operations for efficient rotation.
99
-
100
- Args:
101
- config (Union[ModelConfig, PicoHFConfig]): Model configuration containing:
102
- - config.position_emb_theta: Base for frequency computation
103
- - config.d_model: Model dimension
104
- - config.attention_n_heads: Number of attention heads
105
- - config.max_seq_len: Maximum sequence length
106
-
107
- References:
108
- https://arxiv.org/abs/2104.09864
109
- """
110
-
111
- _freqs_cis_tensor: torch.Tensor | None = None
112
-
113
- def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
114
- super().__init__()
115
-
116
- self.theta = config.position_emb_theta
117
- self.dim = config.d_model // config.attention_n_heads
118
-
119
- max_seq_len = config.max_seq_len
120
-
121
- # only gets set once, and then reused for all RoPE instances
122
- if RoPE._freqs_cis_tensor is None:
123
- RoPE._freqs_cis_tensor = self._setup_freqs_cis(
124
- max_seq_len, self.theta, self.dim
125
- )
126
-
127
- # register _freqs_cis buffer
128
- # can be easily recomputed so persistent=False
129
- self.register_buffer("_freqs_cis", self._freqs_cis_tensor, persistent=False)
130
-
131
- @classmethod
132
- def _setup_freqs_cis(cls, seq_len: int, theta: float, dim: int) -> torch.Tensor:
133
- """Setup Frequency Tensor for RoPE Embeddings
134
-
135
- Initializes the complex frequency tensor that is used to compute the RoPE embeddings.
136
-
137
- Note other implementations will use cos and sin directly, but using the complex
138
- number representation is (probably) more efficient:
139
-
140
- e^(theta * i * t) = cos(theta * t) + i * sin(theta * t) [Euler's formula]
141
- """
142
- _freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
143
- positions = torch.arange(seq_len)
144
- freqs = torch.outer(positions, _freqs)
145
- return torch.polar(torch.ones_like(freqs), freqs) # complex64
146
-
147
- def get_freqs_cis(
148
- self, input_shape: torch.Size, start_pos: int, end_pos: int
149
- ) -> torch.Tensor:
150
- """Reshape Frequency Tensor for RoPE Embeddings
151
-
152
- Makes the frequency tensor broadcastable with the input tensor.
153
- """
154
- _freqs_cis = self._freqs_cis[start_pos:end_pos]
155
- ndim = len(input_shape)
156
- assert 0 <= 1 < ndim
157
- assert _freqs_cis.shape == (input_shape[1], input_shape[-1])
158
-
159
- # TODO: Check whether this is correct (might be able to remove this)
160
- shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(input_shape)]
161
- return _freqs_cis.view(*shape)
162
-
163
- def forward(
164
- self,
165
- queries: torch.Tensor,
166
- keys: torch.Tensor,
167
- start_pos: int = 0,
168
- ) -> Tuple[torch.Tensor, torch.Tensor]:
169
- """Apply RoPE Embeddings to Queries and Keys
170
-
171
- Applies the rotary positional embeddings to the input tensors via complex num multiplication
172
-
173
- NOTE: The start_pos is used if we want to use the kv_cache in the attention mechanism.
174
- """
175
- queries_ = torch.view_as_complex(
176
- queries.float().reshape(*queries.shape[:-1], -1, 2)
177
- )
178
- keys_ = torch.view_as_complex(keys.float().reshape(*keys.shape[:-1], -1, 2))
179
-
180
- input_shape = (
181
- queries_.shape
182
- ) # same as keys: (batch_size, seq_len, n_heads, head_dim/2)
183
- freqs_start_pos = start_pos
184
- freqs_end_pos = freqs_start_pos + queries_.shape[1]
185
-
186
- freqs_cis = self.get_freqs_cis(input_shape, freqs_start_pos, freqs_end_pos)
187
-
188
- queries_rotated = torch.view_as_real(queries_ * freqs_cis).flatten(3)
189
- keys_rotated = torch.view_as_real(keys_ * freqs_cis).flatten(3)
190
- return queries_rotated.type_as(queries), keys_rotated.type_as(keys)
191
-
192
-
193
- ########################################################
194
- #
195
- # Attention
196
- #
197
- ########################################################
198
-
199
-
200
- class Attention(nn.Module):
201
- """Multi-head Attention with Group Query Attention support.
202
-
203
- Implements scaled dot-product attention and supports:
204
- - Grouped Query Attention (GQA)
205
- - Key-Value caching for efficient inference
206
- - RoPE integration
207
-
208
- Args:
209
- config (Union[ModelConfig, PretrainedConfig]): Configuration containing:
210
- - config.attention_n_heads: Number of attention heads
211
- - config.attention_n_kv_heads: Number of key/value heads
212
- - config.d_model: Model dimension
213
- - config.batch_size: Maximum batch size
214
- - config.max_seq_len: Maximum sequence length
215
-
216
- Shape:
217
- - Input: (batch_size, seq_len, d_model)
218
- - Output: (batch_size, seq_len, d_model)
219
- """
220
-
221
- def __init__(
222
- self,
223
- config: Union["ModelConfig", "PicoDecoderHFConfig"],
224
- ):
225
- super().__init__()
226
-
227
- self.n_heads = config.attention_n_heads
228
- self.n_kv_heads = config.attention_n_kv_heads
229
-
230
- self.batch_size = config.batch_size
231
- self.max_seq_len = config.max_seq_len
232
-
233
- d_model = config.d_model
234
- self.head_dim = d_model // self.n_heads
235
-
236
- self.n_rep = self.n_heads // self.n_kv_heads
237
-
238
- self.q_proj = nn.Linear(d_model, self.n_heads * self.head_dim, bias=False)
239
- self.k_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
240
- self.v_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
241
- self.o_proj = nn.Linear(self.n_heads * self.head_dim, d_model, bias=False)
242
-
243
- self.rope = RoPE(config)
244
-
245
- def forward(
246
- self,
247
- input: torch.Tensor,
248
- mask: Optional[torch.Tensor] = None,
249
- past_key_values: Optional[Tuple[torch.Tensor, ...]] = None,
250
- use_cache: bool = False,
251
- ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
252
- """Forward pass for the attention mechanism.
253
-
254
- Computes queries, keys, and values for the attention mechanism. Applies rotary positional
255
- embeddings to the queries and keys, and then computes attention scores and outputs.
256
-
257
- For an introduction to the attention mechanism, see:
258
- https://arxiv.org/abs/1706.03762
259
-
260
- A few things to note:
261
- - The past_key_values is used to implement the KV cache, which is used to speed up
262
- generation by caching the KV pairs from previous forward passes. This is useful when doing
263
- tasks that require generating multiple tokens conditioned on previous tokens (e.g. language
264
- modeling, text generation, etc.). The way the KV cache is implemented is that each layer has
265
- its own KV cache - this KV cache is implemented as a tuple.
266
- """
267
- bsz, seq_len, _ = input.shape
268
- _queries, _keys, _values = (
269
- self.q_proj(input),
270
- self.k_proj(input),
271
- self.v_proj(input),
272
- )
273
-
274
- # Reshaping for multi-head attention
275
- queries = _queries.view(bsz, seq_len, self.n_heads, self.head_dim)
276
- keys = _keys.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
277
- values = _values.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
278
-
279
- # The start position is used to apply the RoPE embeddings to only the new tokens
280
- # when using the kv_cache in the attention mechanism.
281
- # We want to start from the last position in the cache.
282
- start_pos = past_key_values[0].shape[1] if past_key_values is not None else 0
283
-
284
- # apply rotary positional embeddings
285
- queries, keys = self.rope(queries, keys, start_pos)
286
-
287
- if past_key_values is not None:
288
- keys = torch.cat([past_key_values[0], keys], dim=1)
289
- values = torch.cat([past_key_values[1], values], dim=1)
290
-
291
- if use_cache:
292
- cached_keys = keys
293
- cached_values = values
294
- else:
295
- cached_keys = None
296
- cached_values = None
297
-
298
- queries = queries.transpose(1, 2)
299
- keys = keys.transpose(1, 2)
300
- values = values.transpose(1, 2)
301
-
302
- apply_gqa = self.n_rep > 1
303
- if apply_gqa and queries.device.type == "mps":
304
- # NOTE: MPS does not support GQA in the SDPA kernel, but we can repeat the keys and values
305
- # outside of the kernel to get the same effect.
306
- # See: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
307
- keys = keys.repeat_interleave(self.n_rep, dim=-3)
308
- values = values.repeat_interleave(self.n_rep, dim=-3)
309
- apply_gqa = False
310
-
311
- backends = [SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH]
312
-
313
- with sdpa_kernel(backends=backends):
314
- attn_output = F.scaled_dot_product_attention(
315
- queries.contiguous(),
316
- keys.contiguous(),
317
- values.contiguous(),
318
- attn_mask=mask.to(queries.dtype) if mask is not None else None,
319
- enable_gqa=apply_gqa,
320
- )
321
-
322
- attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
323
- output = self.o_proj(attn_output)
324
-
325
- return output, (cached_keys, cached_values)
326
-
327
-
328
- ########################################################
329
- #
330
- # SwiGLU (Combines MLP and Activation)
331
- #
332
- ########################################################
333
-
334
-
335
- class SwiGLU(nn.Module):
336
- """SwiGLU Activation Function with Linear Projections.
337
-
338
- Implements the SwiGLU activation function combined with linear transformations,
339
- serving as the feed-forward network in transformer blocks.
340
-
341
- Args:
342
- config (Union[ModelConfig, PicoDecoderHFConfig]): Configuration containing:
343
- - config.d_model: Model dimension
344
- - config.activation_hidden_dim: Hidden dimension (typically 4 * d_model)
345
-
346
- References:
347
- https://arxiv.org/abs/2002.05202
348
- """
349
-
350
- def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
351
- super().__init__()
352
-
353
- model_dim = config.d_model
354
- act_hidden_dim = config.activation_hidden_dim # usually 4 * d_model
355
-
356
- self.w_0 = nn.Linear(model_dim, act_hidden_dim, bias=False)
357
- self.w_1 = nn.Linear(model_dim, act_hidden_dim, bias=False)
358
- self.w_2 = nn.Linear(act_hidden_dim, model_dim, bias=False)
359
-
360
- def forward(self, x: torch.Tensor) -> torch.Tensor:
361
- return self.w_2(F.silu(self.w_0(x)) * self.w_1(x))
362
-
363
-
364
- ########################################################
365
- #
366
- # PicoDecoderBlock
367
- #
368
- ########################################################
369
-
370
-
371
- class PicoDecoderBlock(nn.Module):
372
- """Single Transformer Block with Attention and Feed-forward layers.
373
-
374
- Implements a standard transformer block with:
375
- - Multi-head attention with normalization and residual connection
376
- - SwiGLU feed-forward network with normalization and residual connection
377
-
378
- Args:
379
- config (Union[ModelConfig, PicoDecoderHFConfig]): Model configuration; either a dataclass or
380
- a HuggingFace PicoDecoderHFConfig
381
- """
382
-
383
- def __init__(
384
- self,
385
- config: Union["ModelConfig", "PicoDecoderHFConfig"],
386
- ):
387
- super().__init__()
388
-
389
- self.attention = Attention(config)
390
- self.swiglu = SwiGLU(config)
391
- self.attention_norm = RMSNorm(config)
392
- self.swiglu_norm = RMSNorm(config)
393
-
394
- def forward(
395
- self,
396
- input: torch.Tensor,
397
- mask: Optional[torch.Tensor] = None,
398
- past_key_values: Optional[Tuple[torch.Tensor]] = None,
399
- use_cache: bool = False,
400
- ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
401
- attention_output, cached_key_values = self.attention(
402
- self.attention_norm(input),
403
- mask=mask,
404
- past_key_values=past_key_values,
405
- use_cache=use_cache,
406
- )
407
- # NOTE: cached_key_values is None if use_cache is False
408
-
409
- h = input + attention_output
410
- out = h + self.swiglu(self.swiglu_norm(h))
411
- return out, cached_key_values
412
-
413
-
414
- ########################################################
415
- #
416
- # Pico Decoder (Causal Transformer Model)
417
- #
418
- ########################################################
419
-
420
-
421
- class PicoDecoder(nn.Module):
422
- """
423
- Pico Decoder: combines the embedding, causal decoder blocks, and output projection into a
424
- single autoregressive model.
425
-
426
- For more information on the model, see the classes for the modules that make up the model.
427
- """
428
-
429
- def __init__(
430
- self,
431
- model_config: Union["ModelConfig", "PicoDecoderHFConfig"],
432
- ):
433
- super().__init__()
434
- self.config = model_config
435
-
436
- self.embedding_proj = nn.Embedding(self.config.vocab_size, self.config.d_model)
437
- self.layers = nn.ModuleList(
438
- [PicoDecoderBlock(self.config) for _ in range(self.config.n_layers)]
439
- )
440
- self.output_norm = RMSNorm(self.config)
441
- self.de_embedding_proj = nn.Linear(
442
- self.config.d_model, self.config.vocab_size, bias=False
443
- )
444
-
445
- def convert_to_hf_model(self) -> "PicoDecoderHF":
446
- """Convert the Lightning model to a HuggingFace model."""
447
- # Create HF config without fabric-specific settings
448
- hf_config = PicoDecoderHFConfig.from_dataclass(self.config)
449
-
450
- # Create new HF model
451
- hf_model = PicoDecoderHF(hf_config)
452
-
453
- # Copy state dict, excluding fabric-specific keys
454
- hf_model.load_state_dict(self.state_dict(prefix="pico_decoder."))
455
-
456
- return hf_model
457
-
458
- def forward(
459
- self,
460
- input_ids: torch.Tensor,
461
- past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
462
- use_cache: bool = False,
463
- ) -> Tuple[torch.Tensor, Optional[Tuple[Tuple[torch.Tensor, torch.Tensor]]]]:
464
- """
465
- This is the forward pass for the entire Pico model. It boils down to:
466
- - Embedding the input ids
467
- - Creating a causal mask
468
- - Processing through the pico layers
469
- - Projecting the output to logits
470
-
471
- NOTE: One feature that might be confusing is the KV cache. The KV cache is used to speed up
472
- generation by caching the KV pairs from previous forward passes. This is useful when doing
473
- tasks that require generating multiple tokens conditioned on previous tokens (e.g. language
474
- modeling, text generation, etc.). The way the KV cache is implemented is that each layer has
475
- its own KV cache which is stored as a tuple. The whole model then stores a tuple of these
476
- KV caches (so a tuple of tuples).
477
- """
478
-
479
- seq_len = input_ids.shape[-1]
480
- h = self.embedding_proj(input_ids)
481
-
482
- # Calculate start position from past cached KV pairs. Remember that each layer has its
483
- # own KV Cache. So when we index past_key_values, we need to index into the KV pairs for the
484
- # correct layer and then for either the keys or values.
485
- start_pos = 0 if past_key_values is None else past_key_values[0][0].shape[1]
486
-
487
- # Create causal mask for current sequence
488
- mask = None
489
- if seq_len > 1:
490
- mask = torch.full((seq_len, seq_len), float("-inf"))
491
- mask = torch.triu(mask, diagonal=1)
492
-
493
- # If using KV cache, extend mask to cover cached sequence length
494
- if past_key_values is not None:
495
- # Add zeros for cached tokens (we can attend to all of them)
496
- mask = torch.hstack([torch.zeros((seq_len, start_pos)), mask])
497
-
498
- mask = mask.to(h.device)
499
-
500
- # NOTE: If we are using the cache, we need to store the cached KV pairs for each layer
501
- # in a tuple. Each layer will have its own cached KV pair which we aggregate in a tuple.
502
- cached_key_values = () if use_cache else None
503
-
504
- # Process through transformer blocks
505
- for idx, layer in enumerate(self.layers):
506
- layer_past_key_values = (
507
- past_key_values[idx] if past_key_values is not None else None
508
- )
509
-
510
- h, layer_cached_key_values = layer(
511
- h, mask=mask, past_key_values=layer_past_key_values, use_cache=use_cache
512
- )
513
-
514
- if use_cache:
515
- cached_key_values += (layer_cached_key_values,)
516
-
517
- # Final norm and projection
518
- h = self.output_norm(h)
519
- logits = self.de_embedding_proj(h).float()
520
-
521
- return logits, cached_key_values
522
-
523
-
524
- ########################################################
525
- #
526
- # HuggingFace Wrapper for the Pico Decoder model.
527
- #
528
- ########################################################
529
-
530
-
531
- class PicoDecoderHFConfig(PretrainedConfig):
532
- """Config class for the Pico Decoder HuggingFace wrapper."""
533
-
534
- model_type = "pico_decoder"
535
-
536
- @classmethod
537
- def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PicoDecoderHFConfig":
538
- """
539
- Initialize config from a dictionary. Note that no kwargs are passed to the constructor --
540
- this is because with some kwargs special handling is required and can make this class
541
- brittle.
542
- """
543
- pico_config = cls(**config_dict)
544
-
545
- return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
546
- unused_kwargs = {
547
- key: value for key, value in kwargs.items() if not hasattr(pico_config, key)
548
- }
549
-
550
- if return_unused_kwargs:
551
- return pico_config, unused_kwargs
552
- return pico_config
553
-
554
- @classmethod
555
- def from_dataclass(cls, model_config: "ModelConfig"):
556
- """Initialise from our custom config dataclass."""
557
- return cls.from_dict(asdict(model_config))
558
-
559
-
560
- class PicoDecoderHF(PreTrainedModel, GenerationMixin):
561
- """
562
- HuggingFace wrapper for the Pico model with generation support.
563
-
564
- Many evaluation frameworks require a model be setup as a HuggingFace model, so we provide a simple
565
- wrapper that does just that. When we save checkpoints of the Pico model, we save both the normal
566
- Pico model as well as the model wrapped in this HuggingFace class.
567
-
568
- This also lets you do cool things like:
569
-
570
- `model = AutoModelForCausalLM.from_pretrained("path/to/checkpoint")`
571
- """
572
-
573
- config_class = PicoDecoderHFConfig
574
- _no_split_modules = ["PicoBlock", "Attention", "SwiGLU", "RMSNorm"]
575
- main_input_name = "input_ids"
576
-
577
- def __init__(self, config: PicoDecoderHFConfig):
578
- super().__init__(config)
579
- self.pico_decoder = PicoDecoder(config)
580
- # Initialize generation config with defaults
581
- self.generation_config = GenerationConfig()
582
- # Set some reasonable defaults for the model
583
- if hasattr(config, "max_position_embeddings"):
584
- self.generation_config.max_length = config.max_position_embeddings
585
- if hasattr(config, "vocab_size"):
586
- self.generation_config.vocab_size = config.vocab_size
587
-
588
- def forward(
589
- self,
590
- input_ids: torch.Tensor,
591
- past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
592
- use_cache: bool = False,
593
- **kwargs,
594
- ) -> Union[CausalLMOutput, CausalLMOutputWithPast]:
595
- """HuggingFace forward pass wrapper.
596
-
597
- Forwards pass for the HuggingFace version of the Pico Model. Basic wrapper around the
598
- Pico model's forward pass, and returns the output as a HuggingFace CausalLMOutput.
599
- """
600
- logits, past_key_values = self.pico_decoder(
601
- input_ids, past_key_values, use_cache
602
- )
603
- if use_cache:
604
- return CausalLMOutputWithPast(
605
- logits=logits,
606
- past_key_values=past_key_values,
607
- )
608
- else:
609
- return CausalLMOutput(
610
- logits=logits,
611
- )
612
-
613
- def prepare_inputs_for_generation(
614
- self,
615
- input_ids: torch.LongTensor,
616
- past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
617
- attention_mask: Optional[torch.LongTensor] = None,
618
- **kwargs,
619
- ) -> Dict[str, Any]:
620
- """
621
- Prepare inputs for generation.
622
-
623
- Args:
624
- input_ids: Input token IDs
625
- past_key_values: Cached key-value pairs from previous forward passes
626
- attention_mask: Attention mask for the input
627
- **kwargs: Additional arguments
628
-
629
- Returns:
630
- Dictionary containing prepared inputs
631
- """
632
- # If we have past_key_values, we only need the last token
633
- if past_key_values is not None:
634
- input_ids = input_ids[:, -1:]
635
-
636
- return {
637
- "input_ids": input_ids,
638
- "past_key_values": past_key_values,
639
- "use_cache": True,
640
- }
641
-
642
- def get_input_embeddings(self):
643
- """Get the input embeddings layer."""
644
- return self.pico_decoder.embedding_proj
645
-
646
- def set_input_embeddings(self, value):
647
- """Set the input embeddings layer."""
648
- self.pico_decoder.embedding_proj = value
649
-
650
- def get_output_embeddings(self):
651
- """Get the output embeddings layer."""
652
- return self.pico_decoder.de_embedding_proj
653
-
654
- def set_output_embeddings(self, value):
655
- """Set the output embeddings layer."""
656
- self.pico_decoder.de_embedding_proj = value
657
-
658
- def get_lm_head(self):
659
- """Get the language model head."""
660
- return self.pico_decoder.de_embedding_proj
661
-
662
- def can_generate(self) -> bool:
663
- """Check if the model can generate text."""
664
- return True
665
-
666
- @property
667
- def is_encoder_decoder(self) -> bool:
668
- """Check if the model is an encoder-decoder model."""
669
- return False
670
-
671
- @property
672
- def can_use_cache(self) -> bool:
673
- """Check if the model can use KV cache."""
674
- return True
675
-
676
- def resize_token_embeddings(
677
- self, new_num_tokens: Optional[int] = None
678
- ) -> torch.nn.Embedding:
679
- """Resize token embeddings."""
680
- old_embeddings = self.get_input_embeddings()
681
- if new_num_tokens is None:
682
- new_num_tokens = old_embeddings.num_embeddings
683
-
684
- new_embeddings = torch.nn.Embedding(
685
- new_num_tokens, old_embeddings.embedding_dim
686
- )
687
- new_embeddings.weight.data[: old_embeddings.num_embeddings] = (
688
- old_embeddings.weight.data
689
- )
690
-
691
- self.pico_decoder.embedding_proj = new_embeddings
692
- self.pico_decoder.de_embedding_proj = torch.nn.Linear(
693
- old_embeddings.embedding_dim, new_num_tokens, bias=False
694
- )
695
-
696
- return new_embeddings
697
-
698
-
699
- # Register for auto classes
700
- PicoDecoderHFConfig.register_for_auto_class()
701
- PicoDecoderHF.register_for_auto_class("AutoModel")
702
- PicoDecoderHF.register_for_auto_class("AutoModelForCausalLM")
703
-
704
-
705
- ########################################################
706
- #
707
- # New PicoDecoderForCausalLM class for generation support
708
- #
709
- ########################################################
710
-
711
-
712
- class PicoDecoderForCausalLM(PreTrainedModel, GenerationMixin):
713
- """
714
- PicoDecoderForCausalLM: A HuggingFace-compatible model that properly supports generation.
715
-
716
- This class is designed to work with existing checkpoints and provides full generation support.
717
- It inherits from the right base classes that HuggingFace expects for text generation.
718
- """
719
-
720
- config_class = PicoDecoderHFConfig
721
- _no_split_modules = ["PicoBlock", "Attention", "SwiGLU", "RMSNorm"]
722
- main_input_name = "input_ids"
723
-
724
- def __init__(self, config: PicoDecoderHFConfig):
725
- super().__init__(config)
726
- self.pico_decoder = PicoDecoder(config)
727
- # Initialize generation config with defaults
728
- self.generation_config = GenerationConfig()
729
- # Set some reasonable defaults for the model
730
- if hasattr(config, "max_position_embeddings"):
731
- self.generation_config.max_length = config.max_position_embeddings
732
- if hasattr(config, "vocab_size"):
733
- self.generation_config.vocab_size = config.vocab_size
734
-
735
- def forward(
736
- self,
737
- input_ids: torch.Tensor,
738
- past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
739
- use_cache: bool = False,
740
- **kwargs,
741
- ) -> Union[CausalLMOutput, CausalLMOutputWithPast]:
742
- """Forward pass for text generation."""
743
- logits, past_key_values = self.pico_decoder(
744
- input_ids, past_key_values, use_cache
745
- )
746
- if use_cache:
747
- return CausalLMOutputWithPast(
748
- logits=logits,
749
- past_key_values=past_key_values,
750
- )
751
- else:
752
- return CausalLMOutput(
753
- logits=logits,
754
- )
755
-
756
- def prepare_inputs_for_generation(
757
- self,
758
- input_ids: torch.LongTensor,
759
- past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
760
- attention_mask: Optional[torch.LongTensor] = None,
761
- **kwargs,
762
- ) -> Dict[str, Any]:
763
- """Prepare inputs for generation."""
764
- # If we have past_key_values, we only need the last token
765
- if past_key_values is not None:
766
- input_ids = input_ids[:, -1:]
767
-
768
- return {
769
- "input_ids": input_ids,
770
- "past_key_values": past_key_values,
771
- "use_cache": True,
772
- }
773
-
774
- def get_input_embeddings(self):
775
- """Get the input embeddings layer."""
776
- return self.pico_decoder.embedding_proj
777
-
778
- def set_input_embeddings(self, value):
779
- """Set the input embeddings layer."""
780
- self.pico_decoder.embedding_proj = value
781
-
782
- def get_output_embeddings(self):
783
- """Get the output embeddings layer."""
784
- return self.pico_decoder.de_embedding_proj
785
-
786
- def set_output_embeddings(self, value):
787
- """Set the output embeddings layer."""
788
- self.pico_decoder.de_embedding_proj = value
789
-
790
- def get_lm_head(self):
791
- """Get the language model head."""
792
- return self.pico_decoder.de_embedding_proj
793
-
794
- def can_generate(self) -> bool:
795
- """Check if the model can generate text."""
796
- return True
797
-
798
- @property
799
- def is_encoder_decoder(self) -> bool:
800
- """Check if the model is an encoder-decoder model."""
801
- return False
802
-
803
- @property
804
- def can_use_cache(self) -> bool:
805
- """Check if the model can use KV cache."""
806
- return True
807
-
808
- def resize_token_embeddings(
809
- self, new_num_tokens: Optional[int] = None
810
- ) -> torch.nn.Embedding:
811
- """Resize token embeddings."""
812
- old_embeddings = self.get_input_embeddings()
813
- if new_num_tokens is None:
814
- new_num_tokens = old_embeddings.num_embeddings
815
-
816
- new_embeddings = torch.nn.Embedding(
817
- new_num_tokens, old_embeddings.embedding_dim
818
- )
819
- new_embeddings.weight.data[: old_embeddings.num_embeddings] = (
820
- old_embeddings.weight.data
821
- )
822
-
823
- self.pico_decoder.embedding_proj = new_embeddings
824
- self.pico_decoder.de_embedding_proj = torch.nn.Linear(
825
- old_embeddings.embedding_dim, new_num_tokens, bias=False
826
- )
827
-
828
- return new_embeddings
829
-
830
- @classmethod
831
- def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
832
- """
833
- Load a pretrained model from a checkpoint.
834
-
835
- This method handles loading from both the old PicoDecoderHF format and the new format.
836
- """
837
- # First try to load with the new class
838
- try:
839
- return super().from_pretrained(
840
- pretrained_model_name_or_path, *model_args, **kwargs
841
- )
842
- except Exception as e:
843
- print(f"Failed to load with new class: {e}")
844
- print("Attempting to load with legacy class and convert...")
845
-
846
- # Try to load with the old class and convert
847
- try:
848
- from transformers import AutoModel
849
-
850
- old_model = AutoModel.from_pretrained(
851
- pretrained_model_name_or_path,
852
- trust_remote_code=True,
853
- *model_args,
854
- **kwargs,
855
- )
856
-
857
- # Create new model instance
858
- new_model = cls(old_model.config)
859
-
860
- # Copy state dict
861
- new_model.load_state_dict(old_model.state_dict(), strict=False)
862
-
863
- return new_model
864
-
865
- except Exception as e2:
866
- print(f"Failed to convert from legacy format: {e2}")
867
- raise e
868
-
869
-
870
- # Register the new class
871
- PicoDecoderForCausalLM.register_for_auto_class("AutoModelForCausalLM")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pico-decoder-tiny/checkpoints/step_1755/special_tokens_map.json DELETED
@@ -1,16 +0,0 @@
1
- {
2
- "eos_token": {
3
- "content": "<|endoftext|>",
4
- "lstrip": false,
5
- "normalized": false,
6
- "rstrip": false,
7
- "single_word": false
8
- },
9
- "pad_token": {
10
- "content": "<|padding|>",
11
- "lstrip": false,
12
- "normalized": false,
13
- "rstrip": false,
14
- "single_word": false
15
- }
16
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pico-decoder-tiny/checkpoints/step_1755/tokenizer.json DELETED
The diff for this file is too large to render. See raw diff
 
pico-decoder-tiny/checkpoints/step_1755/tokenizer_config.json DELETED
@@ -1,239 +0,0 @@
1
- {
2
- "add_bos_token": false,
3
- "add_eos_token": false,
4
- "add_prefix_space": false,
5
- "added_tokens_decoder": {
6
- "0": {
7
- "content": "|||IP_ADDRESS|||",
8
- "lstrip": false,
9
- "normalized": true,
10
- "rstrip": false,
11
- "single_word": false,
12
- "special": false
13
- },
14
- "1": {
15
- "content": "<|padding|>",
16
- "lstrip": false,
17
- "normalized": false,
18
- "rstrip": false,
19
- "single_word": false,
20
- "special": true
21
- },
22
- "50254": {
23
- "content": " ",
24
- "lstrip": false,
25
- "normalized": true,
26
- "rstrip": false,
27
- "single_word": false,
28
- "special": false
29
- },
30
- "50255": {
31
- "content": " ",
32
- "lstrip": false,
33
- "normalized": true,
34
- "rstrip": false,
35
- "single_word": false,
36
- "special": false
37
- },
38
- "50256": {
39
- "content": " ",
40
- "lstrip": false,
41
- "normalized": true,
42
- "rstrip": false,
43
- "single_word": false,
44
- "special": false
45
- },
46
- "50257": {
47
- "content": " ",
48
- "lstrip": false,
49
- "normalized": true,
50
- "rstrip": false,
51
- "single_word": false,
52
- "special": false
53
- },
54
- "50258": {
55
- "content": " ",
56
- "lstrip": false,
57
- "normalized": true,
58
- "rstrip": false,
59
- "single_word": false,
60
- "special": false
61
- },
62
- "50259": {
63
- "content": " ",
64
- "lstrip": false,
65
- "normalized": true,
66
- "rstrip": false,
67
- "single_word": false,
68
- "special": false
69
- },
70
- "50260": {
71
- "content": " ",
72
- "lstrip": false,
73
- "normalized": true,
74
- "rstrip": false,
75
- "single_word": false,
76
- "special": false
77
- },
78
- "50261": {
79
- "content": " ",
80
- "lstrip": false,
81
- "normalized": true,
82
- "rstrip": false,
83
- "single_word": false,
84
- "special": false
85
- },
86
- "50262": {
87
- "content": " ",
88
- "lstrip": false,
89
- "normalized": true,
90
- "rstrip": false,
91
- "single_word": false,
92
- "special": false
93
- },
94
- "50263": {
95
- "content": " ",
96
- "lstrip": false,
97
- "normalized": true,
98
- "rstrip": false,
99
- "single_word": false,
100
- "special": false
101
- },
102
- "50264": {
103
- "content": " ",
104
- "lstrip": false,
105
- "normalized": true,
106
- "rstrip": false,
107
- "single_word": false,
108
- "special": false
109
- },
110
- "50265": {
111
- "content": " ",
112
- "lstrip": false,
113
- "normalized": true,
114
- "rstrip": false,
115
- "single_word": false,
116
- "special": false
117
- },
118
- "50266": {
119
- "content": " ",
120
- "lstrip": false,
121
- "normalized": true,
122
- "rstrip": false,
123
- "single_word": false,
124
- "special": false
125
- },
126
- "50267": {
127
- "content": " ",
128
- "lstrip": false,
129
- "normalized": true,
130
- "rstrip": false,
131
- "single_word": false,
132
- "special": false
133
- },
134
- "50268": {
135
- "content": " ",
136
- "lstrip": false,
137
- "normalized": true,
138
- "rstrip": false,
139
- "single_word": false,
140
- "special": false
141
- },
142
- "50269": {
143
- "content": " ",
144
- "lstrip": false,
145
- "normalized": true,
146
- "rstrip": false,
147
- "single_word": false,
148
- "special": false
149
- },
150
- "50270": {
151
- "content": " ",
152
- "lstrip": false,
153
- "normalized": true,
154
- "rstrip": false,
155
- "single_word": false,
156
- "special": false
157
- },
158
- "50271": {
159
- "content": " ",
160
- "lstrip": false,
161
- "normalized": true,
162
- "rstrip": false,
163
- "single_word": false,
164
- "special": false
165
- },
166
- "50272": {
167
- "content": " ",
168
- "lstrip": false,
169
- "normalized": true,
170
- "rstrip": false,
171
- "single_word": false,
172
- "special": false
173
- },
174
- "50273": {
175
- "content": " ",
176
- "lstrip": false,
177
- "normalized": true,
178
- "rstrip": false,
179
- "single_word": false,
180
- "special": false
181
- },
182
- "50274": {
183
- "content": " ",
184
- "lstrip": false,
185
- "normalized": true,
186
- "rstrip": false,
187
- "single_word": false,
188
- "special": false
189
- },
190
- "50275": {
191
- "content": " ",
192
- "lstrip": false,
193
- "normalized": true,
194
- "rstrip": false,
195
- "single_word": false,
196
- "special": false
197
- },
198
- "50276": {
199
- "content": " ",
200
- "lstrip": false,
201
- "normalized": true,
202
- "rstrip": false,
203
- "single_word": false,
204
- "special": false
205
- },
206
- "50277": {
207
- "content": "|||EMAIL_ADDRESS|||",
208
- "lstrip": false,
209
- "normalized": true,
210
- "rstrip": false,
211
- "single_word": false,
212
- "special": false
213
- },
214
- "50278": {
215
- "content": "|||PHONE_NUMBER|||",
216
- "lstrip": false,
217
- "normalized": true,
218
- "rstrip": false,
219
- "single_word": false,
220
- "special": false
221
- },
222
- "50279": {
223
- "content": "<|endoftext|>",
224
- "lstrip": false,
225
- "normalized": false,
226
- "rstrip": false,
227
- "single_word": false,
228
- "special": true
229
- }
230
- },
231
- "bos_token": null,
232
- "clean_up_tokenization_spaces": true,
233
- "eos_token": "<|endoftext|>",
234
- "extra_special_tokens": {},
235
- "model_max_length": 1000000000000000019884624838656,
236
- "pad_token": "<|padding|>",
237
- "tokenizer_class": "GPTNeoXTokenizer",
238
- "unk_token": null
239
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pico-decoder-tiny/eval_results/step_0.json DELETED
@@ -1 +0,0 @@
1
- {"paloma": Infinity}
 
 
pico-decoder-tiny/eval_results/step_1000.json DELETED
@@ -1 +0,0 @@
1
- {"paloma": 9.54583880403771e+19}
 
 
pico-decoder-tiny/eval_results/step_1755.json DELETED
@@ -1 +0,0 @@
1
- {"paloma": 2.945795672816324e+21}
 
 
pico-decoder-tiny/logs/log_20250828_220514.log DELETED
@@ -1,185 +0,0 @@
1
- 2025-08-28 22:07:06 - pico-train - INFO - Step 0 -- 📊 Evaluation Results
2
- 2025-08-28 22:07:06 - pico-train - INFO - └── paloma: inf
3
- 2025-08-28 22:07:06 - pico-train - INFO - ==================================================
4
- 2025-08-28 22:07:06 - pico-train - INFO - ✨ Training Configuration
5
- 2025-08-28 22:07:06 - pico-train - INFO - ==================================================
6
- 2025-08-28 22:07:06 - pico-train - INFO - ╭─────────────────────────────────────────────────────╮
7
- 2025-08-28 22:07:06 - pico-train - INFO - │ checkpointing: │
8
- 2025-08-28 22:07:06 - pico-train - INFO - │ checkpoints_dir: checkpoints │
9
- 2025-08-28 22:07:06 - pico-train - INFO - │ evaluation: │
10
- 2025-08-28 22:07:06 - pico-train - INFO - │ eval_results_dir: eval_results │
11
- 2025-08-28 22:07:06 - pico-train - INFO - │ fabric_checkpoint_dir: fabric_state │
12
- 2025-08-28 22:07:06 - pico-train - INFO - │ fabric_checkpoint_filename: checkpoint.pt │
13
- 2025-08-28 22:07:06 - pico-train - INFO - │ hf_checkpoint: │
14
- 2025-08-28 22:07:06 - pico-train - INFO - │ collection_slug: null │
15
- 2025-08-28 22:07:06 - pico-train - INFO - │ repo_id: ThomasTheMaker/pico-decoder-tiny │
16
- 2025-08-28 22:07:06 - pico-train - INFO - │ learning_dynamics: │
17
- 2025-08-28 22:07:06 - pico-train - INFO - │ batch_size: 4 │
18
- 2025-08-28 22:07:06 - pico-train - INFO - │ eval_data: null │
19
- 2025-08-28 22:07:06 - pico-train - INFO - │ layer_suffixes: │
20
- 2025-08-28 22:07:06 - pico-train - INFO - │ - attention.v_proj │
21
- 2025-08-28 22:07:06 - pico-train - INFO - │ - attention.o_proj │
22
- 2025-08-28 22:07:06 - pico-train - INFO - │ - swiglu.w_2 │
23
- 2025-08-28 22:07:06 - pico-train - INFO - │ sequence_idx: -1 │
24
- 2025-08-28 22:07:06 - pico-train - INFO - │ learning_dynamics_dir: learning_dynamics │
25
- 2025-08-28 22:07:06 - pico-train - INFO - │ logs_dir: logs │
26
- 2025-08-28 22:07:06 - pico-train - INFO - │ run_name: pico-decoder-tiny │
27
- 2025-08-28 22:07:06 - pico-train - INFO - │ runs_dir: runs │
28
- 2025-08-28 22:07:06 - pico-train - INFO - │ save_every_n_steps: 1000 │
29
- 2025-08-28 22:07:06 - pico-train - INFO - │ save_to_hf: true │
30
- 2025-08-28 22:07:06 - pico-train - INFO - │ training: │
31
- 2025-08-28 22:07:06 - pico-train - INFO - │ auto_resume: true │
32
- 2025-08-28 22:07:06 - pico-train - INFO - │ data: │
33
- 2025-08-28 22:07:06 - pico-train - INFO - │ dataloader: │
34
- 2025-08-28 22:07:06 - pico-train - INFO - │ batch_size: 256 │
35
- 2025-08-28 22:07:06 - pico-train - INFO - │ dataset: │
36
- 2025-08-28 22:07:06 - pico-train - INFO - │ name: pico-lm/pretokenized-dolma-tinsy │
37
- 2025-08-28 22:07:06 - pico-train - INFO - │ tokenizer: │
38
- 2025-08-28 22:07:06 - pico-train - INFO - │ name: allenai/OLMo-7B-0724-hf │
39
- 2025-08-28 22:07:06 - pico-train - INFO - │ vocab_size: 50304 │
40
- 2025-08-28 22:07:06 - pico-train - INFO - │ evaluation: │
41
- 2025-08-28 22:07:06 - pico-train - INFO - │ metrics: │
42
- 2025-08-28 22:07:06 - pico-train - INFO - │ - paloma │
43
- 2025-08-28 22:07:06 - pico-train - INFO - │ paloma: │
44
- 2025-08-28 22:07:06 - pico-train - INFO - │ batch_size: 1 │
45
- 2025-08-28 22:07:06 - pico-train - INFO - │ dataset_name: pico-lm/pretokenized-paloma-tinsy │
46
- 2025-08-28 22:07:06 - pico-train - INFO - │ dataset_split: val │
47
- 2025-08-28 22:07:06 - pico-train - INFO - │ max_length: 2048 │
48
- 2025-08-28 22:07:06 - pico-train - INFO - │ model: │
49
- 2025-08-28 22:07:06 - pico-train - INFO - │ activation_hidden_dim: 384 │
50
- 2025-08-28 22:07:06 - pico-train - INFO - │ attention_n_heads: 12 │
51
- 2025-08-28 22:07:06 - pico-train - INFO - │ attention_n_kv_heads: 4 │
52
- 2025-08-28 22:07:06 - pico-train - INFO - │ batch_size: 1024 │
53
- 2025-08-28 22:07:06 - pico-train - INFO - │ d_model: 96 │
54
- 2025-08-28 22:07:06 - pico-train - INFO - │ max_seq_len: 2048 │
55
- 2025-08-28 22:07:06 - pico-train - INFO - │ model_type: pico_decoder │
56
- 2025-08-28 22:07:06 - pico-train - INFO - │ n_layers: 12 │
57
- 2025-08-28 22:07:06 - pico-train - INFO - │ norm_eps: 1.0e-06 │
58
- 2025-08-28 22:07:06 - pico-train - INFO - │ position_emb_theta: 10000.0 │
59
- 2025-08-28 22:07:06 - pico-train - INFO - │ vocab_size: 50304 │
60
- 2025-08-28 22:07:06 - pico-train - INFO - │ monitoring: │
61
- 2025-08-28 22:07:06 - pico-train - INFO - │ logging: │
62
- 2025-08-28 22:07:06 - pico-train - INFO - │ log_every_n_steps: 100 │
63
- 2025-08-28 22:07:06 - pico-train - INFO - │ log_level: INFO │
64
- 2025-08-28 22:07:06 - pico-train - INFO - │ save_to_wandb: false │
65
- 2025-08-28 22:07:06 - pico-train - INFO - │ wandb: │
66
- 2025-08-28 22:07:06 - pico-train - INFO - │ entity: boymyc │
67
- 2025-08-28 22:07:06 - pico-train - INFO - │ project: pico-decoder-tiny │
68
- 2025-08-28 22:07:06 - pico-train - INFO - │ training: │
69
- 2025-08-28 22:07:06 - pico-train - INFO - │ fabric: │
70
- 2025-08-28 22:07:06 - pico-train - INFO - │ accelerator: cuda │
71
- 2025-08-28 22:07:06 - pico-train - INFO - │ num_devices: 1 │
72
- 2025-08-28 22:07:06 - pico-train - INFO - │ num_nodes: 1 │
73
- 2025-08-28 22:07:06 - pico-train - INFO - │ precision: bf16-mixed │
74
- 2025-08-28 22:07:06 - pico-train - INFO - │ max_steps: 200000 │
75
- 2025-08-28 22:07:06 - pico-train - INFO - │ optimization: │
76
- 2025-08-28 22:07:06 - pico-train - INFO - │ gradient_accumulation_steps: 4 │
77
- 2025-08-28 22:07:06 - pico-train - INFO - │ lr: 0.0003 │
78
- 2025-08-28 22:07:06 - pico-train - INFO - │ lr_scheduler: linear_with_warmup │
79
- 2025-08-28 22:07:06 - pico-train - INFO - │ lr_warmup_steps: 2500 │
80
- 2025-08-28 22:07:06 - pico-train - INFO - │ optimizer: adamw │
81
- 2025-08-28 22:07:06 - pico-train - INFO - │ │
82
- 2025-08-28 22:07:06 - pico-train - INFO - ╰─────────────────────────────────────────────────────╯
83
- 2025-08-28 22:07:06 - pico-train - INFO - ==================================================
84
- 2025-08-28 22:07:06 - pico-train - INFO - ⛭ Runtime Summary:
85
- 2025-08-28 22:07:06 - pico-train - INFO - ==================================================
86
- 2025-08-28 22:07:06 - pico-train - INFO - Starting from step: 0
87
- 2025-08-28 22:07:06 - pico-train - INFO - Model Setup:
88
- 2025-08-28 22:07:06 - pico-train - INFO - └─ Total Parameters: 11,282,784
89
- 2025-08-28 22:07:06 - pico-train - INFO - └─ Trainable Parameters: 11,282,784
90
- 2025-08-28 22:07:06 - pico-train - INFO - Distributed Setup:
91
- 2025-08-28 22:07:06 - pico-train - INFO - └─ Number of Devices: 1
92
- 2025-08-28 22:07:06 - pico-train - INFO - └─ Device Type: NVIDIA GeForce RTX 5090
93
- 2025-08-28 22:07:06 - pico-train - INFO - └─ Available Memory: 33.68 GB
94
- 2025-08-28 22:07:06 - pico-train - INFO - Software Setup:
95
- 2025-08-28 22:07:06 - pico-train - INFO - └─ Python Version: 3.10.12
96
- 2025-08-28 22:07:06 - pico-train - INFO - └─ PyTorch Version: 2.8.0+cu128
97
- 2025-08-28 22:07:06 - pico-train - INFO - └─ CUDA Version: 12.8
98
- 2025-08-28 22:07:06 - pico-train - INFO - └─ Operating System: Linux 6.8.0-63-generic
99
- 2025-08-28 22:07:06 - pico-train - INFO - Batch Size Configuration:
100
- 2025-08-28 22:07:06 - pico-train - INFO - └─ Global Batch Size: 4
101
- 2025-08-28 22:07:06 - pico-train - INFO - └─ Per Device Batch Size: 1
102
- 2025-08-28 22:07:06 - pico-train - INFO - └─ Gradient Accumulation Steps: 4
103
- 2025-08-28 22:07:06 - pico-train - INFO - ==================================================
104
- 2025-08-28 22:07:07 - pico-train - INFO - Step 0 -- 🔄 Training Metrics
105
- 2025-08-28 22:07:07 - pico-train - INFO - ├── Loss: 10.9886
106
- 2025-08-28 22:07:07 - pico-train - INFO - ├── Learning Rate: 0.00e+00
107
- 2025-08-28 22:07:07 - pico-train - INFO - └── Inf/NaN count: 0
108
- 2025-08-28 22:07:07 - pico-train - INFO - Step 0 -- 📈 Saving Learning Dynamics
109
- 2025-08-28 22:08:00 - pico-train - INFO - Step 100 -- 🔄 Training Metrics
110
- 2025-08-28 22:08:00 - pico-train - INFO - ├── Loss: 10.9373
111
- 2025-08-28 22:08:00 - pico-train - INFO - ├── Learning Rate: 1.20e-05
112
- 2025-08-28 22:08:00 - pico-train - INFO - └── Inf/NaN count: 0
113
- 2025-08-28 22:08:51 - pico-train - INFO - Step 200 -- 🔄 Training Metrics
114
- 2025-08-28 22:08:51 - pico-train - INFO - ├── Loss: 10.5423
115
- 2025-08-28 22:08:51 - pico-train - INFO - ├── Learning Rate: 2.40e-05
116
- 2025-08-28 22:08:51 - pico-train - INFO - └── Inf/NaN count: 0
117
- 2025-08-28 22:09:43 - pico-train - INFO - Step 300 -- 🔄 Training Metrics
118
- 2025-08-28 22:09:43 - pico-train - INFO - ├── Loss: 9.9452
119
- 2025-08-28 22:09:43 - pico-train - INFO - ├── Learning Rate: 3.60e-05
120
- 2025-08-28 22:09:43 - pico-train - INFO - └── Inf/NaN count: 0
121
- 2025-08-28 22:10:34 - pico-train - INFO - Step 400 -- 🔄 Training Metrics
122
- 2025-08-28 22:10:34 - pico-train - INFO - ├── Loss: 9.4490
123
- 2025-08-28 22:10:34 - pico-train - INFO - ├── Learning Rate: 4.80e-05
124
- 2025-08-28 22:10:34 - pico-train - INFO - └── Inf/NaN count: 0
125
- 2025-08-28 22:11:25 - pico-train - INFO - Step 500 -- 🔄 Training Metrics
126
- 2025-08-28 22:11:25 - pico-train - INFO - ├── Loss: 8.8455
127
- 2025-08-28 22:11:25 - pico-train - INFO - ├── Learning Rate: 6.00e-05
128
- 2025-08-28 22:11:25 - pico-train - INFO - └── Inf/NaN count: 0
129
- 2025-08-28 22:12:16 - pico-train - INFO - Step 600 -- 🔄 Training Metrics
130
- 2025-08-28 22:12:16 - pico-train - INFO - ├── Loss: 8.1482
131
- 2025-08-28 22:12:16 - pico-train - INFO - ├── Learning Rate: 7.20e-05
132
- 2025-08-28 22:12:16 - pico-train - INFO - └── Inf/NaN count: 0
133
- 2025-08-28 22:13:08 - pico-train - INFO - Step 700 -- 🔄 Training Metrics
134
- 2025-08-28 22:13:08 - pico-train - INFO - ├── Loss: 7.4303
135
- 2025-08-28 22:13:08 - pico-train - INFO - ├── Learning Rate: 8.40e-05
136
- 2025-08-28 22:13:08 - pico-train - INFO - └── Inf/NaN count: 0
137
- 2025-08-28 22:13:59 - pico-train - INFO - Step 800 -- 🔄 Training Metrics
138
- 2025-08-28 22:13:59 - pico-train - INFO - ├── Loss: 7.0363
139
- 2025-08-28 22:13:59 - pico-train - INFO - ├── Learning Rate: 9.60e-05
140
- 2025-08-28 22:13:59 - pico-train - INFO - └── Inf/NaN count: 0
141
- 2025-08-28 22:14:50 - pico-train - INFO - Step 900 -- 🔄 Training Metrics
142
- 2025-08-28 22:14:50 - pico-train - INFO - ├── Loss: 6.9702
143
- 2025-08-28 22:14:50 - pico-train - INFO - ├── Learning Rate: 1.08e-04
144
- 2025-08-28 22:14:50 - pico-train - INFO - └── Inf/NaN count: 0
145
- 2025-08-28 22:15:40 - pico-train - INFO - Step 1000 -- 💾 Saving Checkpoint
146
- 2025-08-28 22:17:41 - pico-train - INFO - Step 1000 -- 📊 Evaluation Results
147
- 2025-08-28 22:17:41 - pico-train - INFO - └── paloma: 9.54583880403771e+19
148
- 2025-08-28 22:17:43 - pico-train - INFO - Step 1000 -- 🔄 Training Metrics
149
- 2025-08-28 22:17:43 - pico-train - INFO - ├── Loss: 6.8975
150
- 2025-08-28 22:17:43 - pico-train - INFO - ├── Learning Rate: 1.20e-04
151
- 2025-08-28 22:17:43 - pico-train - INFO - └── Inf/NaN count: 0
152
- 2025-08-28 22:17:43 - pico-train - INFO - Step 1000 -- 📈 Saving Learning Dynamics
153
- 2025-08-28 22:18:37 - pico-train - INFO - Step 1100 -- 🔄 Training Metrics
154
- 2025-08-28 22:18:37 - pico-train - INFO - ├── Loss: 6.8920
155
- 2025-08-28 22:18:37 - pico-train - INFO - ├── Learning Rate: 1.32e-04
156
- 2025-08-28 22:18:37 - pico-train - INFO - └── Inf/NaN count: 0
157
- 2025-08-28 22:19:28 - pico-train - INFO - Step 1200 -- 🔄 Training Metrics
158
- 2025-08-28 22:19:28 - pico-train - INFO - ├── Loss: 6.6684
159
- 2025-08-28 22:19:28 - pico-train - INFO - ├── Learning Rate: 1.44e-04
160
- 2025-08-28 22:19:28 - pico-train - INFO - └── Inf/NaN count: 0
161
- 2025-08-28 22:20:18 - pico-train - INFO - Step 1300 -- 🔄 Training Metrics
162
- 2025-08-28 22:20:18 - pico-train - INFO - ├── Loss: 6.4754
163
- 2025-08-28 22:20:18 - pico-train - INFO - ├── Learning Rate: 1.56e-04
164
- 2025-08-28 22:20:18 - pico-train - INFO - └── Inf/NaN count: 0
165
- 2025-08-28 22:21:09 - pico-train - INFO - Step 1400 -- 🔄 Training Metrics
166
- 2025-08-28 22:21:09 - pico-train - INFO - ├── Loss: 6.3649
167
- 2025-08-28 22:21:09 - pico-train - INFO - ├── Learning Rate: 1.68e-04
168
- 2025-08-28 22:21:09 - pico-train - INFO - └── Inf/NaN count: 0
169
- 2025-08-28 22:22:00 - pico-train - INFO - Step 1500 -- 🔄 Training Metrics
170
- 2025-08-28 22:22:00 - pico-train - INFO - ├── Loss: 6.2981
171
- 2025-08-28 22:22:00 - pico-train - INFO - ├── Learning Rate: 1.80e-04
172
- 2025-08-28 22:22:00 - pico-train - INFO - └── Inf/NaN count: 0
173
- 2025-08-28 22:22:51 - pico-train - INFO - Step 1600 -- 🔄 Training Metrics
174
- 2025-08-28 22:22:51 - pico-train - INFO - ├── Loss: 6.1551
175
- 2025-08-28 22:22:51 - pico-train - INFO - ├── Learning Rate: 1.92e-04
176
- 2025-08-28 22:22:51 - pico-train - INFO - └── Inf/NaN count: 0
177
- 2025-08-28 22:23:42 - pico-train - INFO - Step 1700 -- 🔄 Training Metrics
178
- 2025-08-28 22:23:42 - pico-train - INFO - ├── Loss: 5.9163
179
- 2025-08-28 22:23:42 - pico-train - INFO - ├── Learning Rate: 2.04e-04
180
- 2025-08-28 22:23:42 - pico-train - INFO - └── Inf/NaN count: 0
181
- 2025-08-28 22:24:09 - pico-train - INFO - Step 1755 -- 💾 Saving Final Checkpoint
182
- 2025-08-28 22:26:24 - pico-train - INFO - Step 1755 -- 📊 Evaluation Results
183
- 2025-08-28 22:26:24 - pico-train - INFO - └── paloma: 2.945795672816324e+21
184
- 2025-08-28 22:26:24 - pico-train - INFO - 🎉 Training complete! Final step: 1755
185
- 2025-08-28 22:26:24 - pico-train - WARNING - Note: Training stopped before max steps (200000)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pico-decoder-tiny/training_config.yaml DELETED
@@ -1,74 +0,0 @@
1
- checkpointing:
2
- checkpoints_dir: checkpoints
3
- evaluation:
4
- eval_results_dir: eval_results
5
- fabric_checkpoint_dir: fabric_state
6
- fabric_checkpoint_filename: checkpoint.pt
7
- hf_checkpoint:
8
- collection_slug: null
9
- repo_id: ThomasTheMaker/pico-decoder-tiny
10
- learning_dynamics:
11
- batch_size: 4
12
- eval_data: null
13
- layer_suffixes:
14
- - attention.v_proj
15
- - attention.o_proj
16
- - swiglu.w_2
17
- sequence_idx: -1
18
- learning_dynamics_dir: learning_dynamics
19
- logs_dir: logs
20
- run_name: pico-decoder-tiny
21
- runs_dir: runs
22
- save_every_n_steps: 1000
23
- save_to_hf: true
24
- training:
25
- auto_resume: true
26
- data:
27
- dataloader:
28
- batch_size: 256
29
- dataset:
30
- name: pico-lm/pretokenized-dolma-tinsy
31
- tokenizer:
32
- name: allenai/OLMo-7B-0724-hf
33
- vocab_size: 50304
34
- evaluation:
35
- metrics:
36
- - paloma
37
- paloma:
38
- batch_size: 1
39
- dataset_name: pico-lm/pretokenized-paloma-tinsy
40
- dataset_split: val
41
- max_length: 2048
42
- model:
43
- activation_hidden_dim: 384
44
- attention_n_heads: 12
45
- attention_n_kv_heads: 4
46
- batch_size: 1024
47
- d_model: 96
48
- max_seq_len: 2048
49
- model_type: pico_decoder
50
- n_layers: 12
51
- norm_eps: 1.0e-06
52
- position_emb_theta: 10000.0
53
- vocab_size: 50304
54
- monitoring:
55
- logging:
56
- log_every_n_steps: 100
57
- log_level: INFO
58
- save_to_wandb: false
59
- wandb:
60
- entity: boymyc
61
- project: pico-decoder-tiny
62
- training:
63
- fabric:
64
- accelerator: cuda
65
- num_devices: 1
66
- num_nodes: 1
67
- precision: bf16-mixed
68
- max_steps: 200000
69
- optimization:
70
- gradient_accumulation_steps: 4
71
- lr: 0.0003
72
- lr_scheduler: linear_with_warmup
73
- lr_warmup_steps: 2500
74
- optimizer: adamw