ChrisMcCormick committed on
Commit bfabd1e · verified · 1 Parent(s): 8841574

Up one level

models/modeling_shared_subspace_decoder.py → modeling_shared_subspace_decoder.py RENAMED
# -*- coding: utf-8 -*-

"""
modeling_shared_subspace_decoder.py

SharedSpaceDecoder model implementation for HuggingFace Transformers.
"""

from typing import Optional

import torch
from torch import nn

from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa

# NOTE: before this file was moved up out of models/, these two were imported as
# `from ..layers.mla ...` and `from ..layers.feedforward ...`.
from layers.mla import MultiheadLatentAttention, RotaryEmbedding
from layers.feedforward import SubspaceFeedForward
from .configuration_shared_subspace_decoder import SharedSpaceDecoderConfig

"""`RMSNorm`

From:
https://huggingface.co/deepseek-ai/DeepSeek-R1/blob/main/modeling_deepseek.py

TODO - May not need?
"""

class DeepseekV3RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        DeepseekV3RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
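
# In LayerNorm terms: RMSNorm skips the mean subtraction and the bias, computing
#     y = weight * x / sqrt(mean(x**2, dim=-1) + eps)
# with the statistics taken in float32 and the result cast back to the input dtype.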

def create_norm_layer(hidden_size: int, config: SharedSpaceDecoderConfig) -> nn.Module:
    """
    Create a normalization layer based on the config norm_type.

    Args:
        hidden_size: The dimension to normalize over
        config: Configuration containing norm_type and epsilon values

    Returns:
        Either a LayerNorm or RMSNorm layer
    """
    if config.norm_type == "layernorm":
        return nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
    elif config.norm_type == "rmsnorm":
        return DeepseekV3RMSNorm(hidden_size, eps=config.rms_norm_eps)
    else:
        # This should be caught by config validation, but being defensive
        raise ValueError(f"Unknown norm_type: {config.norm_type}")
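
# For example, a config with norm_type="rmsnorm" and rms_norm_eps=1e-6 produces
# DeepseekV3RMSNorm(hidden_size, eps=1e-6), while norm_type="layernorm" produces
# a standard nn.LayerNorm using config.layer_norm_eps instead.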

"""#### *PreTrainedModel"""

class SharedSpaceDecoderPreTrainedModel(PreTrainedModel):
    """
    The **PreTrainedModel object:
    - Is the base class that :class:`SharedSpaceDecoderModel` (and any task
      heads) inherit from.
    - Initializes: nothing of its own; it binds this model family to
      :class:`SharedSpaceDecoderConfig` and supplies ``_init_weights``.
    - Provides access to the standard HuggingFace machinery
      (``from_pretrained``, ``save_pretrained``, ``post_init``).
    - Executes weight initialization when a model is built from scratch.
    """

    config_class = SharedSpaceDecoderConfig
    base_model_prefix = "model"
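
    # `config_class` tells `from_pretrained` which config object to instantiate for
    # this architecture; `base_model_prefix` is the attribute name HuggingFace uses
    # to locate the bare decoder when mapping checkpoints between the base model and
    # models with heads on top.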

    def _init_weights(self, module: nn.Module) -> None:
        """Weight initialization hook used by :class:`PreTrainedModel`.

        ``PreTrainedModel.post_init`` will recursively apply this function to
        every submodule right after construction. HuggingFace models override
        it so that creating a model from scratch yields the same initialization
        as ``from_pretrained`` when no checkpoint is supplied.

        The cases handled below:
        - Linear and embedding weights are drawn from N(0, initializer_range**2),
          with biases (and any padding embedding row) zeroed.
        - Configurable normalization layers (LayerNorm or RMSNorm) are reset to
          weight = 1.0, with bias = 0.0 for LayerNorm.
        - The dense and decomposed vocabulary embeddings are plain nn.Embedding /
          nn.Linear modules, so they are covered by the first case.
        """

        if isinstance(module, nn.Linear):
            # Standard linear layer initialization
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()

        elif isinstance(module, nn.Embedding):
            # Initialize embeddings with normal distribution
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

        elif isinstance(module, DeepseekV3RMSNorm):
            # RMSNorm initialization: weight to 1.0, no bias term
            module.weight.data.fill_(1.0)

        elif isinstance(module, nn.LayerNorm):
            # LayerNorm initialization: bias to 0, weight to 1.0
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

"""# ▂▂▂▂▂▂▂▂▂▂▂▂

# Classes
"""

"""#### `*Layer`"""

class SharedSpaceDecoderLayer(nn.Module):
    """
    The **Layer object:
    - Is instantiated by :class:`SharedSpaceDecoderModel` for each
      Transformer block in the decoder.
    - Initializes:
        - ``self_attn`` – multi-head latent attention implementing either
          dense or latent projections depending on the configuration.
        - ``ffn`` – a :class:`SubspaceFeedForward` block.
        - Norm layers (LayerNorm or RMSNorm) for pre-attention and pre-FFN
          normalization.
    - Provides access to the attention and feed-forward submodules via the
      attributes ``self_attn`` and ``ffn``.
    - Executes a single decoder block in :meth:`forward`.
    """

    def __init__(self, config: SharedSpaceDecoderConfig, layer_idx: int) -> None:

        super().__init__()

        # Norm applied prior to attention.
        self.attn_input_norm = create_norm_layer(config.hidden_size, config)

        # Attention block
        self.self_attn = MultiheadLatentAttention(config, layer_idx)

        # Norm applied prior to FFN
        self.ffn_input_norm = create_norm_layer(config.hidden_size, config)

        # Feed-forward network used after attention
        self.ffn = SubspaceFeedForward(config, layer_idx)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],  # RoPE embeddings
        attention_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:

        # ========================
        #     Self Attention
        # ========================
        residual_strm = hidden_states

        # Normalize the hidden states to create the input to attention.
        attn_input = self.attn_input_norm(hidden_states)

        # Evaluate
        attn_output = self.self_attn(
            attn_input,
            position_embeddings,
            attention_mask,
        )

        # Add the attention output (the residual) back to the non-normalized
        # hidden_states.
        hidden_states = residual_strm + attn_output

        # ===========================
        #     Feed-Forward Network
        # ===========================
        residual_strm = hidden_states

        # Normalize the updated hidden states prior to the FFN
        ffn_input = self.ffn_input_norm(hidden_states)

        # Evaluate
        ffn_output = self.ffn(ffn_input)

        # Add the output to the un-normalized hidden states.
        hidden_states = residual_strm + ffn_output

        return hidden_states
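
# Each block is a standard pre-norm residual pair:
#     h = h + SelfAttn(Norm(h))
#     h = h + FFN(Norm(h))
# so the residual stream itself is never normalized in place.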

"""#### *Model"""

class SharedSpaceDecoderModel(SharedSpaceDecoderPreTrainedModel):
    """
    The **Model object:
    - Initializes:
        - The vocabulary embeddings (and optional decomposition)
        - Position embeddings (calculated in RotaryEmbedding)
        - All of the **Layer objects.
    - Provides an interface to the vocab embeddings.
    - Executes the whole decoder model in `forward` with causal attention.

    This is the base decoder without the language modeling head.
    Use SubspaceDecoderForCausalLM for language modeling tasks.
    """

    def __init__(self, config: SharedSpaceDecoderConfig) -> None:
        super().__init__(config)

        # ============================
        #    Vocabulary Embeddings
        # ============================
        # Decomposing the vocabulary (if enabled) defines a shared projection
        # which constrains the model to store semantic information (and
        # whatever other static token knowledge) into a limited set of
        # feature directions.
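
        # Illustrative sizing (hypothetical numbers, not taken from the config
        # defaults): with vocab_size=32,000, hidden_size=768, and vocab_rank=128,
        # a dense table costs 32,000 * 768 ≈ 24.6M parameters, while the factored
        # version costs 32,000 * 128 + 128 * 768 ≈ 4.2M parameters and restricts
        # every token embedding to a 128-dimensional shared subspace.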

        # If we're decomposing the token embeddings,
        if config.vocab_subspace:

            # Create the embedding table. Vocabulary embeddings are learned
            # in a lower dimensional latent space.
            self.vocab_embed = nn.Embedding(
                config.vocab_size,   # Number of tokens
                config.vocab_rank    # Subspace dimension
            )

            # Create a shared projection out of the vocabulary subspace.
            # Selected token latents will be projected up to model size.
            # vocab_proj maps [vocab_rank -> model_size]
            self.vocab_proj = nn.Linear(
                config.vocab_rank,    # Size of latents
                config.hidden_size,   # Model size
                bias=False
            )

        # Otherwise, for a dense vocabulary,
        else:
            # Create the dense embedding table in model space.
            self.vocab_embed = nn.Embedding(
                config.vocab_size,   # Number of tokens
                config.hidden_size   # Model size
            )

            self.vocab_proj = None

        # =====================
        #    RoPE Embeddings
        # =====================

        # Pre-computes the table of RoPE embeddings, leaving them in
        # GPU memory.
        self.rope = RotaryEmbedding(config)

        # ===================
        #    Create Layers
        # ===================

        layers = []

        # For each layer,
        for i in range(config.num_hidden_layers):
            # Create a **Layer, providing the config and indicating its number.
            layers.append(
                SharedSpaceDecoderLayer(
                    config,
                    layer_idx=i
                )
            )

        # Wrap in torch ModuleList
        self.layers = nn.ModuleList(layers)

        # Run the HuggingFace post-construction hook, which (among other
        # bookkeeping) applies `_init_weights` to every submodule.
        self.post_init()

        # Agents: Do not define boilerplate helpers, e.g., get/set_input_embeddings

    def embed(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """
        Return token embeddings for input ids.
        This will perform the up projection to model space if the vocabulary is
        decomposed.

        input_ids have shape [batch_size, seq_len]
        """

        # If the vocabulary is decomposed,
        if self.vocab_proj is not None:

            # Retrieve the latents
            #    input_ids: [batch_size, seq_len]
            #            x: [batch_size, seq_len, latent_dim]
            x = self.vocab_embed(input_ids)

            # Project the latents back to model space and return.
            return self.vocab_proj(x)

        # If the vocabulary is dense,
        else:
            # Just return the embeddings.
            return self.vocab_embed(input_ids)
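
    # Note: when the vocabulary is decomposed, the effective (never materialized)
    # embedding matrix is vocab_embed.weight @ vocab_proj.weight.T, i.e. a rank-
    # `vocab_rank` factorization of a [vocab_size, hidden_size] table.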
310
+
311
+ def forward(
312
+ self,
313
+ input_ids: torch.LongTensor,
314
+ attention_mask: Optional[torch.Tensor] = None,
315
+ **kwargs,
316
+ ) -> torch.Tensor:
317
+ """
318
+ Run the full decoder stack with causal attention.
319
+
320
+ Inputs:
321
+ input_ids [batch_size, seq_len]
322
+ attention_mask [batch_size, seq_len] - 1 for real tokens, 0 for padding
323
+
324
+ Returns:
325
+ Final decoder layer output [batch_size, seq_len, model_size]
326
+ """
327
+
328
+ # Retrieve the token embeddings for this sequence.
329
+ # These are model_size, regardless of whether the vocab is decompd.
330
+ hidden_states = self.embed(input_ids)
331
+
332
+ # Retrieve the rotary position embeddings for all of the positions in
333
+ # our current input sequence.
334
+
335
+ seq_len = hidden_states.size(1)
336
+
337
+ # Retrieves just the ones necessary for the sequence length of the
338
+ # input. These are vectors, two per token. Their length is the
339
+ # number of head dimensions we're applying RoPE to.
340
+ # Input
341
+ # cos: [max_seq_len, rope_dims]
342
+ # sin: [max_seq_len, rope_dims]
343
+ # Outputs:
344
+ # R_cos [seq_len, rope_dims]
345
+ # R_sin [seq_len, rope_dims]
346
+ R_cos = self.rope.cos[:seq_len]
347
+ R_sin = self.rope.sin[:seq_len]
348
+
349
+
350
+ # ===============================
351
+ # Attention Mask Conversion
352
+ # ===============================
353
+
354
+ """
355
+ use_sdpa_attention_masks = (
356
+ self.attn_implementation == "sdpa"
357
+ and self.position_embedding_type == "absolute"
358
+ and head_mask is None
359
+ and not output_attentions
360
+ )
361
+ """
362
+
363
+ # Expand the attention mask
364
+ #if use_sdpa_attention_masks and attention_mask.dim() == 2:
365
+ if True:
366
+ # Expand the attention mask for SDPA.
367
+ # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
368
+ extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
369
+ attention_mask,
370
+ hidden_states.dtype,
371
+ tgt_len = seq_len
372
+ )
373
+ attention_mask = extended_attention_mask
374
+
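
        # The expanded mask is additive: 0.0 at positions that may be attended to
        # and a large negative value at padded key positions (the helper may also
        # return None when nothing is masked). It encodes padding only; the causal
        # structure is expected to come from the attention implementation itself.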

        # Run the model!

        # For each decoder layer,
        for layer_i, layer in enumerate(self.layers):

            # Evaluate the layer
            hidden_states = layer(
                hidden_states,    # Current hidden states
                (R_cos, R_sin),   # RoPE embeddings, passed as a tuple.
                attention_mask,   # Attn mask
            )

        # Return the final output of the decoder stack.
        return hidden_states
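
# Illustrative usage (a sketch, not part of the module): assuming the config's
# remaining fields (attention/FFN sizes, max positions, etc.) have workable
# defaults, a forward pass looks like:
#
#     config = SharedSpaceDecoderConfig(vocab_size=32000, hidden_size=768,
#                                       num_hidden_layers=4, vocab_subspace=True,
#                                       vocab_rank=128)
#     model = SharedSpaceDecoderModel(config)
#     ids = torch.randint(0, config.vocab_size, (2, 16))
#     mask = torch.ones(2, 16, dtype=torch.long)
#     hidden = model(ids, attention_mask=mask)   # [2, 16, 768]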