smithblack-0 committed on
Commit
e6fbdc8
·
verified ·
1 Parent(s): 0b88b09

Update architecture and tokenizer

Browse files
Files changed (4) hide show
  1. README.md +64 -64
  2. config.json +22 -22
  3. huggingface.py +256 -256
  4. tokenizer_config.json +12 -12
README.md CHANGED
@@ -1,64 +1,64 @@
1
- ---
2
- language:
3
- - en
4
- license: mit
5
- library_name: transformers
6
- pipeline_tag: text-generation
7
- tags:
8
- - pytorch
9
- - research
10
- - llama
11
- ---
12
-
13
- # advanced-transformers-lib -- Llama 3 Baseline
14
-
15
- A Llama 3-style decoder-only transformer architecture for research. No pretrained
16
- weights -- pull the architecture from the Hub and instantiate a freshly initialised
17
- model from config. Override any parameter at instantiation time.
18
-
19
- > **Important:** `trust_remote_code=True` is required. It downloads the architecture
20
- > source files from the Hub and imports them into your Python process. Review the
21
- > source at [smithblack-0/llama3_baseline](https://huggingface.co/smithblack-0/llama3_baseline) before use.
22
-
23
- ## Usage
24
-
25
- ```python
26
- from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
27
-
28
- # Pull architecture config -- override any parameter at instantiation time
29
- config = AutoConfig.from_pretrained(
30
- "smithblack-0/llama3_baseline",
31
- trust_remote_code=True,
32
- num_hidden_layers=16, # example override
33
- )
34
-
35
- # Instantiate with fresh random weights -- no checkpoint required
36
- model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
37
-
38
- # Load tokenizer
39
- tokenizer = AutoTokenizer.from_pretrained("smithblack-0/llama3_baseline")
40
-
41
- # Save and reload after training
42
- model.save_pretrained("./checkpoint")
43
- model = AutoModelForCausalLM.from_pretrained("./checkpoint", trust_remote_code=True)
44
- ```
45
-
46
- ## Default Configuration
47
-
48
- | Parameter | Default |
49
- |-----------|---------|
50
- | `vocab_size` | 50277 |
51
- | `hidden_size` | 768 |
52
- | `intermediate_size` | 1568 |
53
- | `num_hidden_layers` | 24 |
54
- | `num_attention_heads` | 16 |
55
- | `num_key_value_heads` | 4 |
56
- | `head_dim` | 48 |
57
- | `max_position_embeddings` | 8192 |
58
- | `rope_theta` | 500000.0 |
59
-
60
- ## License
61
-
62
- MIT. Clean-room synthesis: the human author has not read the Llama source code.
63
- Architectural decisions derive from the published paper. Tokenizer is GPT-NeoX
64
- (`EleutherAI/gpt-neox-20b`, Apache 2.0).
 
1
+ ---
2
+ language:
3
+ - en
4
+ license: mit
5
+ library_name: transformers
6
+ pipeline_tag: text-generation
7
+ tags:
8
+ - pytorch
9
+ - research
10
+ - llama
11
+ ---
12
+
13
+ # advanced-transformers-lib -- Llama 3 Baseline
14
+
15
+ A Llama 3-style decoder-only transformer architecture for research. No pretrained
16
+ weights -- pull the architecture from the Hub and instantiate a freshly initialised
17
+ model from config. Override any parameter at instantiation time.
18
+
19
+ > **Important:** `trust_remote_code=True` is required. It downloads the architecture
20
+ > source files from the Hub and imports them into your Python process. Review the
21
+ > source at [smithblack-0/llama3_baseline](https://huggingface.co/smithblack-0/llama3_baseline) before use.
22
+
23
+ ## Usage
24
+
25
+ ```python
26
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
27
+
28
+ # Pull architecture config -- override any parameter at instantiation time
29
+ config = AutoConfig.from_pretrained(
30
+ "smithblack-0/llama3_baseline",
31
+ trust_remote_code=True,
32
+ num_hidden_layers=16, # example override
33
+ )
34
+
35
+ # Instantiate with fresh random weights -- no checkpoint required
36
+ model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
37
+
38
+ # Load tokenizer
39
+ tokenizer = AutoTokenizer.from_pretrained("smithblack-0/llama3_baseline")
40
+
41
+ # Save and reload after training
42
+ model.save_pretrained("./checkpoint")
43
+ model = AutoModelForCausalLM.from_pretrained("./checkpoint", trust_remote_code=True)
44
+ ```
45
+
46
+ ## Default Configuration
47
+
48
+ | Parameter | Default |
49
+ |-----------|---------|
50
+ | `vocab_size` | 50277 |
51
+ | `hidden_size` | 768 |
52
+ | `intermediate_size` | 1568 |
53
+ | `num_hidden_layers` | 24 |
54
+ | `num_attention_heads` | 16 |
55
+ | `num_key_value_heads` | 4 |
56
+ | `head_dim` | 48 |
57
+ | `max_position_embeddings` | 8192 |
58
+ | `rope_theta` | 500000.0 |
59
+
60
+ ## License
61
+
62
+ MIT. Clean-room synthesis: the human author has not read the Llama source code.
63
+ Architectural decisions derive from the published paper. Tokenizer is GPT-NeoX
64
+ (`EleutherAI/gpt-neox-20b`, Apache 2.0).
config.json CHANGED
@@ -1,22 +1,22 @@
1
- {
2
- "attention_dropout": 0.0,
3
- "auto_map": {
4
- "AutoConfig": "configuration.Llama3Config",
5
- "AutoModelForCausalLM": "huggingface.Llama3ForCausalLM"
6
- },
7
- "head_dim": 48,
8
- "hidden_size": 768,
9
- "intermediate_size": 1568,
10
- "max_position_embeddings": 8192,
11
- "model_type": "llama3_baseline",
12
- "num_attention_heads": 16,
13
- "num_hidden_layers": 24,
14
- "num_key_value_heads": 4,
15
- "rms_norm_eps": 1e-05,
16
- "rope_parameters": null,
17
- "rope_theta": 500000.0,
18
- "tie_word_embeddings": false,
19
- "transformers_version": "5.3.0",
20
- "use_cache": true,
21
- "vocab_size": 50277
22
- }
 
1
+ {
2
+ "attention_dropout": 0.0,
3
+ "auto_map": {
4
+ "AutoConfig": "configuration.Llama3Config",
5
+ "AutoModelForCausalLM": "huggingface.Llama3ForCausalLM"
6
+ },
7
+ "head_dim": 48,
8
+ "hidden_size": 768,
9
+ "intermediate_size": 1568,
10
+ "max_position_embeddings": 8192,
11
+ "model_type": "llama3_baseline",
12
+ "num_attention_heads": 16,
13
+ "num_hidden_layers": 24,
14
+ "num_key_value_heads": 4,
15
+ "rms_norm_eps": 1e-05,
16
+ "rope_parameters": null,
17
+ "rope_theta": 500000.0,
18
+ "tie_word_embeddings": false,
19
+ "transformers_version": "5.3.0",
20
+ "use_cache": true,
21
+ "vocab_size": 50277
22
+ }
huggingface.py CHANGED
@@ -1,256 +1,256 @@
1
- """HuggingFace wrapper for the Llama 3 baseline.
2
-
3
- Llama3ForCausalLM wraps Llama3Model with everything a researcher needs to
4
- train, evaluate, and generate from it through the HuggingFace ecosystem:
5
- token embedding, vocabulary projection, next-token loss, weight tying, and
6
- the full AutoClass and GenerationMixin contracts.
7
-
8
- The token embedding lives here, not on the backbone. Llama3Model is a pure
9
- transformer stack that accepts pre-embedded hidden states — it has no knowledge
10
- of tokens or vocabulary. This is the correct HF convention: the backbone is
11
- modality-agnostic; the token interface belongs on the task wrapper.
12
-
13
- The LM head projects the backbone's (batch, seq, hidden_size) output to
14
- (batch, seq, vocab_size) logits. When labels are provided, cross-entropy loss
15
- is computed with a one-position shift: token i predicts token i+1. The shift
16
- is applied here rather than expected from the caller — a causal LM always
17
- trains this way and there is no use case for an unshifted loss.
18
-
19
- Weight tying: when config.tie_word_embeddings is True, lm_head.weight is
20
- directly assigned to embed_tokens.weight after post_init(). Both matrices are
21
- shape (vocab_size, hidden_size) — same shape, no transpose needed.
22
-
23
- KV caching uses HuggingFace's Cache protocol. GenerationMixin creates and
24
- manages the DynamicCache for generate() calls, passing it as past_key_values
25
- on every forward call. The backbone updates the cache in place and returns the
26
- same object. _reorder_cache delegates to DynamicCache.reorder_cache() for beam
27
- search, keeping all beam-reordering logic inside the cache implementation.
28
-
29
- Returns a CausalLMOutputWithPast. ModelOutput subclasses support both attribute
30
- access (output.logits) and dict-style access (output["logits"]), satisfying
31
- GenerationMixin's attribute access requirements while keeping existing code unchanged.
32
- """
33
-
34
- import torch
35
- import torch.nn as nn
36
- from transformers import PreTrainedModel, GenerationMixin
37
- from transformers.cache_utils import Cache, DynamicCache
38
- from transformers.modeling_outputs import CausalLMOutputWithPast
39
-
40
- from .configuration import Llama3Config
41
- from .model import Llama3Model
42
-
43
-
44
- class Llama3ForCausalLM(PreTrainedModel, GenerationMixin):
45
- """Llama 3 causal language model: token embedding, backbone, LM head, HF contract.
46
-
47
- Owns the token embedding and LM head. Delegates all transformer computation
48
- to Llama3Model. Adds loss computation for training, weight tying between the
49
- LM head and the input embedding, and the full HuggingFace AutoClass and
50
- GenerationMixin contracts.
51
-
52
- Args:
53
- config: Model configuration. Must be a ``Llama3Config`` instance.
54
- """
55
-
56
- config_class = Llama3Config
57
- base_model_prefix = "model"
58
- _no_split_modules = ["DecoderLayer"]
59
- supports_gradient_checkpointing = True
60
-
61
- def __init__(self, config: Llama3Config) -> None:
62
- super().__init__(config)
63
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
64
- self.model = Llama3Model(config)
65
-
66
- # No bias — consistent with all other projections in this architecture.
67
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
68
- self.post_init()
69
-
70
- # Direct weight tying: both matrices are (vocab_size, hidden_size) — same shape,
71
- # no transpose. Explicit here for visibility; post_init() → tie_weights() also
72
- # performs this via get_input/output_embeddings(), but that is less readable.
73
- if config.tie_word_embeddings:
74
- self.lm_head.weight = self.embed_tokens.weight
75
-
76
- def _init_weights(self, module: nn.Module) -> None:
77
- # Suppress HF's default reinitialisation pass. HF's _init_weights overwrites
78
- # all Linear and Embedding weights with normal(0, 0.02) after construction,
79
- # silently replacing PyTorch's own defaults (kaiming_uniform_ for Linear,
80
- # normal(0,1) for Embedding). PyTorch's reset_parameters() already ran at
81
- # construction time and those initialisations should stand.
82
- pass
83
-
84
- def get_input_embeddings(self) -> nn.Embedding:
85
- """Return the token embedding matrix. Required by PreTrainedModel for weight tying and resize_token_embeddings."""
86
- return self.embed_tokens
87
-
88
- def set_input_embeddings(self, value: nn.Embedding) -> None:
89
- """Replace the token embedding matrix. Required by PreTrainedModel for weight tying and resize_token_embeddings."""
90
- self.embed_tokens = value
91
-
92
- def get_output_embeddings(self) -> nn.Linear:
93
- """Return the LM head. Required by PreTrainedModel for weight tying and resize_token_embeddings."""
94
- return self.lm_head
95
-
96
- def set_output_embeddings(self, value: nn.Linear) -> None:
97
- """Replace the LM head. Required by PreTrainedModel for weight tying and resize_token_embeddings."""
98
- self.lm_head = value
99
-
100
- def _reorder_cache(
101
- self, past_key_values: Cache, beam_idx: torch.Tensor
102
- ) -> Cache:
103
- """Reorder the KV cache to match beam reordering during beam search.
104
-
105
- GenerationMixin calls this after pruning and reordering beams at each
106
- step. beam_idx[i] is the old batch position whose cache should move to
107
- position i. DynamicCache.reorder_cache() handles the index-select on
108
- every stored tensor's batch dimension, keeping the cache consistent with
109
- the reordered beam hypotheses.
110
-
111
- Args:
112
- past_key_values: The active Cache object.
113
- beam_idx: 1-D tensor of shape (batch * num_beams,) mapping new batch
114
- positions to old ones.
115
-
116
- Returns:
117
- The same Cache object, reordered in place.
118
- """
119
- past_key_values.reorder_cache(beam_idx)
120
- return past_key_values
121
-
122
- def forward(
123
- self,
124
- input_ids: torch.Tensor,
125
- position_ids: torch.Tensor | None = None,
126
- past_key_values: Cache | None = None,
127
- use_cache: bool | None = None,
128
- output_hidden_states: bool | None = None,
129
- labels: torch.Tensor | None = None,
130
- cache_position: torch.Tensor | None = None,
131
- **kwargs,
132
- ) -> CausalLMOutputWithPast:
133
- """Run the causal language model.
134
-
135
- Args:
136
- input_ids: Token indices of shape (batch, seq_len).
137
- position_ids: Absolute positions of shape (batch, seq_len). Passed
138
- through to the backbone. When use_cache=True and this is None,
139
- derived from cache_position.
140
- past_key_values: A HuggingFace Cache object from a prior step, or
141
- None. When use_cache=True and this is None, a fresh DynamicCache
142
- is created here before calling the backbone.
143
- use_cache: Whether to accumulate and return a KV cache. When True
144
- and no cache is provided, a DynamicCache is created. When False,
145
- None is passed to the backbone regardless of what was provided.
146
- Defaults to config.use_cache when None.
147
- output_hidden_states: Whether to return per-layer hidden states.
148
- Passed through to the backbone.
149
- labels: Target token indices of shape (batch, seq_len) for computing
150
- next-token prediction loss. The loss is computed over positions
151
- 1..seq_len predicting from positions 0..seq_len-1 — the shift
152
- is applied internally. Positions with label value -100 are
153
- ignored by cross-entropy, following the HuggingFace convention
154
- for padding and masked positions.
155
- cache_position: 1-D integer tensor of shape (seq_len,) giving the
156
- absolute position of each input token in the full sequence.
157
- Provided by GenerationMixin during generate(). When use_cache=True
158
- and this is None, it is derived from the current cache length.
159
- **kwargs: Additional keyword arguments passed by GenerationMixin
160
- (e.g. return_dict). Accepted and ignored for forward compatibility.
161
- We always return CausalLMOutputWithPast regardless of return_dict.
162
-
163
- Returns:
164
- CausalLMOutputWithPast with fields:
165
- - ``logits``: vocabulary scores of shape (batch, seq_len, vocab_size).
166
- Always present.
167
- - ``loss``: scalar cross-entropy loss, or None if labels not provided.
168
- - ``past_key_values``: the updated Cache object, or None.
169
- - ``hidden_states``: per-layer hidden states, or None.
170
- """
171
- if kwargs.get("attention_mask") is not None:
172
- raise ValueError(
173
- "attention_mask is not supported. This model does not support padding masks. "
174
- "For training on variable-length sequences, use right-padding with -100 labels."
175
- )
176
-
177
- # Resolve both flags against config defaults. Config sets the default;
178
- # per-call arguments override it. Both fields in Llama3Config remain live.
179
- use_cache = use_cache if use_cache is not None else self.config.use_cache
180
- output_hidden_states = (
181
- output_hidden_states
182
- if output_hidden_states is not None
183
- else self.config.output_hidden_states
184
- )
185
-
186
- # Cache lifecycle is owned here — the backbone only receives a cache or None
187
- # and never decides whether to create one.
188
- if use_cache:
189
- if past_key_values is None:
190
- past_key_values = DynamicCache()
191
- else:
192
- past_key_values = None
193
-
194
- inputs_embeds = self.embed_tokens(input_ids)
195
- batch, seq_len, _ = inputs_embeds.shape
196
-
197
- # For training (use_cache=False), positions are always 0..seq_len-1.
198
- # This is not inference from state — it is a trivial fact about a
199
- # non-cached forward pass. The backbone requires explicit position_ids.
200
- if not use_cache and position_ids is None:
201
- position_ids = torch.arange(seq_len, device=inputs_embeds.device).unsqueeze(0).expand(batch, -1)
202
-
203
- causal_mask = None
204
- if use_cache:
205
- # cache_position is GenerationMixin's responsibility. If it is absent,
206
- # positions are unknown and any mask or RoPE encoding we produce would be
207
- # silently wrong — potentially corrupting a checkpoint. Crash immediately.
208
- if cache_position is None:
209
- raise ValueError(
210
- "cache_position must be provided when use_cache=True. "
211
- "GenerationMixin supplies this automatically during generate(). "
212
- "If calling forward() directly with use_cache=True, pass cache_position explicitly."
213
- )
214
-
215
- # Derive position_ids for RoPE from cache_position when not provided.
216
- # This is a valid computation: cache_position is the authoritative source
217
- # of absolute sequence positions, and position_ids is its batch-expanded form.
218
- if position_ids is None:
219
- position_ids = cache_position.unsqueeze(0).expand(batch, -1)
220
-
221
- # Build the causal attention mask. For each query at absolute position p,
222
- # it may attend to all keys at positions 0..p. k_len is the full sequence
223
- # length after this step: one past the last query position.
224
- k_len = int(cache_position[-1].item()) + 1
225
- k_positions = torch.arange(k_len, device=inputs_embeds.device)
226
- # mask[q, k] = True when key position k is within the causal horizon of query q.
227
- # Shape: (1, 1, seq_len, k_len) — broadcast over batch and head dimensions.
228
- causal_mask = (k_positions[None, :] <= cache_position[:, None]).unsqueeze(0).unsqueeze(0)
229
-
230
- backbone_out = self.model(
231
- inputs_embeds,
232
- position_ids=position_ids,
233
- past_key_values=past_key_values,
234
- output_hidden_states=output_hidden_states,
235
- causal_mask=causal_mask,
236
- )
237
-
238
- logits = self.lm_head(backbone_out["last_hidden_state"])
239
-
240
- loss = None
241
- if labels is not None:
242
- # Shift so that each position predicts the next token. The final
243
- # logit has no target; the first label has no corresponding input.
244
- shift_logits = logits[:, :-1, :].contiguous()
245
- shift_labels = labels[:, 1:].contiguous()
246
- loss = nn.functional.cross_entropy(
247
- shift_logits.view(-1, self.config.vocab_size),
248
- shift_labels.view(-1),
249
- )
250
-
251
- return CausalLMOutputWithPast(
252
- logits=logits,
253
- loss=loss,
254
- past_key_values=backbone_out["past_key_values"],
255
- hidden_states=backbone_out["hidden_states"],
256
- )
 
1
+ """HuggingFace wrapper for the Llama 3 baseline.
2
+
3
+ Llama3ForCausalLM wraps Llama3Model with everything a researcher needs to
4
+ train, evaluate, and generate from it through the HuggingFace ecosystem:
5
+ token embedding, vocabulary projection, next-token loss, weight tying, and
6
+ the full AutoClass and GenerationMixin contracts.
7
+
8
+ The token embedding lives here, not on the backbone. Llama3Model is a pure
9
+ transformer stack that accepts pre-embedded hidden states — it has no knowledge
10
+ of tokens or vocabulary. This is the correct HF convention: the backbone is
11
+ modality-agnostic; the token interface belongs on the task wrapper.
12
+
13
+ The LM head projects the backbone's (batch, seq, hidden_size) output to
14
+ (batch, seq, vocab_size) logits. When labels are provided, cross-entropy loss
15
+ is computed with a one-position shift: token i predicts token i+1. The shift
16
+ is applied here rather than expected from the caller — a causal LM always
17
+ trains this way and there is no use case for an unshifted loss.
18
+
19
+ Weight tying: when config.tie_word_embeddings is True, lm_head.weight is
20
+ directly assigned to embed_tokens.weight after post_init(). Both matrices are
21
+ shape (vocab_size, hidden_size) — same shape, no transpose needed.
22
+
23
+ KV caching uses HuggingFace's Cache protocol. GenerationMixin creates and
24
+ manages the DynamicCache for generate() calls, passing it as past_key_values
25
+ on every forward call. The backbone updates the cache in place and returns the
26
+ same object. _reorder_cache delegates to DynamicCache.reorder_cache() for beam
27
+ search, keeping all beam-reordering logic inside the cache implementation.
28
+
29
+ Returns a CausalLMOutputWithPast. ModelOutput subclasses support both attribute
30
+ access (output.logits) and dict-style access (output["logits"]), satisfying
31
+ GenerationMixin's attribute access requirements while keeping existing code unchanged.
32
+ """
33
+
34
+ import torch
35
+ import torch.nn as nn
36
+ from transformers import PreTrainedModel, GenerationMixin
37
+ from transformers.cache_utils import Cache, DynamicCache
38
+ from transformers.modeling_outputs import CausalLMOutputWithPast
39
+
40
+ from .configuration import Llama3Config
41
+ from .model import Llama3Model
42
+
43
+
44
+ class Llama3ForCausalLM(PreTrainedModel, GenerationMixin):
45
+ """Llama 3 causal language model: token embedding, backbone, LM head, HF contract.
46
+
47
+ Owns the token embedding and LM head. Delegates all transformer computation
48
+ to Llama3Model. Adds loss computation for training, weight tying between the
49
+ LM head and the input embedding, and the full HuggingFace AutoClass and
50
+ GenerationMixin contracts.
51
+
52
+ Args:
53
+ config: Model configuration. Must be a ``Llama3Config`` instance.
54
+ """
55
+
56
+ config_class = Llama3Config
57
+ base_model_prefix = "model"
58
+ _no_split_modules = ["DecoderLayer"]
59
+ supports_gradient_checkpointing = True
60
+
61
+ def __init__(self, config: Llama3Config) -> None:
62
+ super().__init__(config)
63
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
64
+ self.model = Llama3Model(config)
65
+
66
+ # No bias — consistent with all other projections in this architecture.
67
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
68
+ self.post_init()
69
+
70
+ # Direct weight tying: both matrices are (vocab_size, hidden_size) — same shape,
71
+ # no transpose. Explicit here for visibility; post_init() → tie_weights() also
72
+ # performs this via get_input/output_embeddings(), but that is less readable.
73
+ if config.tie_word_embeddings:
74
+ self.lm_head.weight = self.embed_tokens.weight
75
+
76
+ def _init_weights(self, module: nn.Module) -> None:
77
+ # Suppress HF's default reinitialisation pass. HF's _init_weights overwrites
78
+ # all Linear and Embedding weights with normal(0, 0.02) after construction,
79
+ # silently replacing PyTorch's own defaults (kaiming_uniform_ for Linear,
80
+ # normal(0,1) for Embedding). PyTorch's reset_parameters() already ran at
81
+ # construction time and those initialisations should stand.
82
+ pass
83
+
84
+ def get_input_embeddings(self) -> nn.Embedding:
85
+ """Return the token embedding matrix. Required by PreTrainedModel for weight tying and resize_token_embeddings."""
86
+ return self.embed_tokens
87
+
88
+ def set_input_embeddings(self, value: nn.Embedding) -> None:
89
+ """Replace the token embedding matrix. Required by PreTrainedModel for weight tying and resize_token_embeddings."""
90
+ self.embed_tokens = value
91
+
92
+ def get_output_embeddings(self) -> nn.Linear:
93
+ """Return the LM head. Required by PreTrainedModel for weight tying and resize_token_embeddings."""
94
+ return self.lm_head
95
+
96
+ def set_output_embeddings(self, value: nn.Linear) -> None:
97
+ """Replace the LM head. Required by PreTrainedModel for weight tying and resize_token_embeddings."""
98
+ self.lm_head = value
99
+
100
+ def _reorder_cache(
101
+ self, past_key_values: Cache, beam_idx: torch.Tensor
102
+ ) -> Cache:
103
+ """Reorder the KV cache to match beam reordering during beam search.
104
+
105
+ GenerationMixin calls this after pruning and reordering beams at each
106
+ step. beam_idx[i] is the old batch position whose cache should move to
107
+ position i. DynamicCache.reorder_cache() handles the index-select on
108
+ every stored tensor's batch dimension, keeping the cache consistent with
109
+ the reordered beam hypotheses.
110
+
111
+ Args:
112
+ past_key_values: The active Cache object.
113
+ beam_idx: 1-D tensor of shape (batch * num_beams,) mapping new batch
114
+ positions to old ones.
115
+
116
+ Returns:
117
+ The same Cache object, reordered in place.
118
+ """
119
+ past_key_values.reorder_cache(beam_idx)
120
+ return past_key_values
121
+
122
+ def forward(
123
+ self,
124
+ input_ids: torch.Tensor,
125
+ position_ids: torch.Tensor | None = None,
126
+ past_key_values: Cache | None = None,
127
+ use_cache: bool | None = None,
128
+ output_hidden_states: bool | None = None,
129
+ labels: torch.Tensor | None = None,
130
+ cache_position: torch.Tensor | None = None,
131
+ **kwargs,
132
+ ) -> CausalLMOutputWithPast:
133
+ """Run the causal language model.
134
+
135
+ Args:
136
+ input_ids: Token indices of shape (batch, seq_len).
137
+ position_ids: Absolute positions of shape (batch, seq_len). Passed
138
+ through to the backbone. When use_cache=True and this is None,
139
+ derived from cache_position.
140
+ past_key_values: A HuggingFace Cache object from a prior step, or
141
+ None. When use_cache=True and this is None, a fresh DynamicCache
142
+ is created here before calling the backbone.
143
+ use_cache: Whether to accumulate and return a KV cache. When True
144
+ and no cache is provided, a DynamicCache is created. When False,
145
+ None is passed to the backbone regardless of what was provided.
146
+ Defaults to config.use_cache when None.
147
+ output_hidden_states: Whether to return per-layer hidden states.
148
+ Passed through to the backbone.
149
+ labels: Target token indices of shape (batch, seq_len) for computing
150
+ next-token prediction loss. Labels at positions 1..seq_len-1 are
151
+ predicted from logits at positions 0..seq_len-2 — the shift
152
+ is applied internally. Positions with label value -100 are
153
+ ignored by cross-entropy, following the HuggingFace convention
154
+ for padding and masked positions.
155
+ cache_position: 1-D integer tensor of shape (seq_len,) giving the
156
+ absolute position of each input token in the full sequence.
157
+ Provided by GenerationMixin during generate(). Required when
158
+ use_cache=True: if it is None, forward() raises ValueError.
159
+ **kwargs: Additional keyword arguments passed by GenerationMixin
160
+ (e.g. return_dict). Ignored, except that a non-None attention_mask
161
+ raises ValueError. The return type is always CausalLMOutputWithPast.
162
+
163
+ Returns:
164
+ CausalLMOutputWithPast with fields:
165
+ - ``logits``: vocabulary scores of shape (batch, seq_len, vocab_size).
166
+ Always present.
167
+ - ``loss``: scalar cross-entropy loss, or None if labels not provided.
168
+ - ``past_key_values``: the updated Cache object, or None.
169
+ - ``hidden_states``: per-layer hidden states, or None.
170
+ """
171
+ if kwargs.get("attention_mask") is not None:
172
+ raise ValueError(
173
+ "attention_mask is not supported. This model does not support padding masks. "
174
+ "For training on variable-length sequences, use right-padding with -100 labels."
175
+ )
176
+
177
+ # Resolve both flags against config defaults. Config sets the default;
178
+ # per-call arguments override it. Both fields in Llama3Config remain live.
179
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
180
+ output_hidden_states = (
181
+ output_hidden_states
182
+ if output_hidden_states is not None
183
+ else self.config.output_hidden_states
184
+ )
185
+
186
+ # Cache lifecycle is owned here — the backbone only receives a cache or None
187
+ # and never decides whether to create one.
188
+ if use_cache:
189
+ if past_key_values is None:
190
+ past_key_values = DynamicCache()
191
+ else:
192
+ past_key_values = None
193
+
194
+ inputs_embeds = self.embed_tokens(input_ids)
195
+ batch, seq_len, _ = inputs_embeds.shape
196
+
197
+ # For training (use_cache=False), positions are always 0..seq_len-1.
198
+ # This is not inference from state — it is a trivial fact about a
199
+ # non-cached forward pass. The backbone requires explicit position_ids.
200
+ if not use_cache and position_ids is None:
201
+ position_ids = torch.arange(seq_len, device=inputs_embeds.device).unsqueeze(0).expand(batch, -1)
202
+
203
+ causal_mask = None
204
+ if use_cache:
205
+ # cache_position is GenerationMixin's responsibility. If it is absent,
206
+ # positions are unknown and any mask or RoPE encoding we produce would be
207
+ # silently wrong — potentially corrupting a checkpoint. Crash immediately.
208
+ if cache_position is None:
209
+ raise ValueError(
210
+ "cache_position must be provided when use_cache=True. "
211
+ "GenerationMixin supplies this automatically during generate(). "
212
+ "If calling forward() directly with use_cache=True, pass cache_position explicitly."
213
+ )
214
+
215
+ # Derive position_ids for RoPE from cache_position when not provided.
216
+ # This is a valid computation: cache_position is the authoritative source
217
+ # of absolute sequence positions, and position_ids is its batch-expanded form.
218
+ if position_ids is None:
219
+ position_ids = cache_position.unsqueeze(0).expand(batch, -1)
220
+
221
+ # Build the causal attention mask. For each query at absolute position p,
222
+ # it may attend to all keys at positions 0..p. k_len is the full sequence
223
+ # length after this step: one past the last query position.
224
+ k_len = int(cache_position[-1].item()) + 1
225
+ k_positions = torch.arange(k_len, device=inputs_embeds.device)
226
+ # mask[q, k] = True when key position k is within the causal horizon of query q.
227
+ # Shape: (1, 1, seq_len, k_len) — broadcast over batch and head dimensions.
228
+ causal_mask = (k_positions[None, :] <= cache_position[:, None]).unsqueeze(0).unsqueeze(0)
229
+
230
+ backbone_out = self.model(
231
+ inputs_embeds,
232
+ position_ids=position_ids,
233
+ past_key_values=past_key_values,
234
+ output_hidden_states=output_hidden_states,
235
+ causal_mask=causal_mask,
236
+ )
237
+
238
+ logits = self.lm_head(backbone_out["last_hidden_state"])
239
+
240
+ loss = None
241
+ if labels is not None:
242
+ # Shift so that each position predicts the next token. The final
243
+ # logit has no target; the first label has no corresponding input.
244
+ shift_logits = logits[:, :-1, :].contiguous()
245
+ shift_labels = labels[:, 1:].contiguous()
246
+ loss = nn.functional.cross_entropy(
247
+ shift_logits.view(-1, self.config.vocab_size),
248
+ shift_labels.view(-1),
249
+ )
250
+
251
+ return CausalLMOutputWithPast(
252
+ logits=logits,
253
+ loss=loss,
254
+ past_key_values=backbone_out["past_key_values"],
255
+ hidden_states=backbone_out["hidden_states"],
256
+ )
tokenizer_config.json CHANGED
@@ -1,13 +1,13 @@
1
- {
2
- "add_prefix_space": false,
3
- "backend": "tokenizers",
4
- "bos_token": "<|endoftext|>",
5
- "eos_token": "<|endoftext|>",
6
- "errors": "replace",
7
- "is_local": false,
8
- "model_max_length": 1000000000000000019884624838656,
9
- "pad_token": "<|padding|>",
10
- "tokenizer_class": "GPTNeoXTokenizerFast",
11
- "trim_offsets": true,
12
- "unk_token": "<|endoftext|>"
13
  }
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "backend": "tokenizers",
4
+ "bos_token": "<|endoftext|>",
5
+ "eos_token": "<|endoftext|>",
6
+ "errors": "replace",
7
+ "is_local": false,
8
+ "model_max_length": 1000000000000000019884624838656,
9
+ "pad_token": "<|padding|>",
10
+ "tokenizer_class": "GPTNeoXTokenizerFast",
11
+ "trim_offsets": true,
12
+ "unk_token": "<|endoftext|>"
13
  }