ChrisMcCormick committed on
Commit
572eda0
·
verified ·
1 Parent(s): 7bca257

Fixed architecture name

Files changed (1)
  1. models/shared_space_config.py +256 -256
models/shared_space_config.py CHANGED
@@ -1,256 +1,256 @@

The file was re-committed in full, but the only content change is the architecture name:

-     model_type = "shared_subspace_decoder"
+     model_type = "shared_space_decoder"

The file after the change:
from typing import Optional

import torch
from torch import nn

from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel


class SharedSpaceDecoderConfig(PretrainedConfig):
    r"""
    Configuration class for the SharedSpaceDecoder model.

    Extends the HuggingFace `PretrainedConfig` to support architectural
    variations including:
      - Multi-Head Latent Attention (MLA)
      - Decomposed MLPs (low-rank FFNs)
      - Flexible attention backends (eager, flash, sdpa)
      - Explicit shared subspaces for Q, K, V, and O projections

    This config does not infer any defaults based on `hidden_size`. All
    dimensions and ranks must be explicitly specified. If required values are
    missing, a `ValueError` is raised during initialization.

    ----------------------
    Core Model Parameters:
    ----------------------
    - vocab_size (`int`) — Vocabulary size.
    - hidden_size (`int`) — Model hidden dimension.
    - num_hidden_layers (`int`) — Number of transformer blocks.
    - intermediate_size (`int`) — Feed-forward hidden dimension.
    - hidden_act (`str`) — Activation function.
    - hidden_dropout_prob (`float`) — Dropout after projections and FFNs.
    - attention_dropout_prob (`float`) — Dropout applied to attention scores.
    - max_position_embeddings (`int`) — Max sequence length.
    - initializer_range (`float`) — Stddev of weight init.

    - layer_norm_eps (`float`) — Epsilon for LayerNorm.
    - rms_norm_eps (`float`) — Epsilon for RMSNorm.

    - classifier_dropout (`float` or None) — Dropout for final classifier.

    - vocab_subspace (`bool`) — Whether to factor the vocabulary embedding
      through a shared low-rank subspace.
    - vocab_rank (`int` or None) — Rank of the vocabulary subspace
      (required if `vocab_subspace=True`).

    ----------------------------------
    Multi-Head Latent Attention (MLA):
    ----------------------------------
    - num_attention_heads (`int`) — Number of attention heads.

    - q_shared_dim (`int`) — Rank of the shared query subspace.
    - kv_shared_dim (`int`) — Rank of the shared key/value subspace.
    - o_shared_dim (`int` or None) — Rank of the shared output subspace;
      if None, no output subspace is used.

    - qk_private_dim (`int`) — Query/key private dimension per head.
    - vo_private_dim (`int`) — Value/output private dimension per head.

    - rope_dims (`int`) — Number of head dimensions carrying RoPE.
    - nope_dims (`int`) — Non-positional encoding dimensions.
    - rope_theta (`float`) — Base frequency used for RoPE.
    - rope_scaling (`dict` or None) — HF-style scaling dict for RoPE.
    - attention_bias (`bool`) — Whether to include bias terms in Q/K/V projections.
    - num_dense_layers (`int`) — Number of leading layers that do not use
      subspaces for attention or FFNs.
    - attention_backend (`str`) — Must be one of `"eager"`, `"flash_attention_2"`, or `"sdpa"`.

    ------------------------------
    Decomposed MLP (Low-Rank FFN):
    ------------------------------
    - ffn_decompose (`bool`) — Whether to enable low-rank FFNs.
    - ffn_rank (`int`) — Rank of the shared FFN latent space (required if `ffn_decompose=True`).

    --------------------
    Validation Behavior:
    --------------------
    Raises `ValueError` at init time if:
      - FFN decomposition is enabled without specifying `ffn_rank`.
      - An unknown `attention_backend` is provided.
    """

    model_type = "shared_space_decoder"
    def __init__(
        self,

        # === Core Model ===
        vocab_size: int = 30522,
        hidden_size: int = 512,
        num_hidden_layers: int = 12,

        intermediate_size: int = 3072,

        hidden_dropout_prob=0.1,
        attention_dropout_prob=0.1,
        max_position_embeddings: int = 2048,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        rms_norm_eps=1e-6,  # Their default, but confirm in config.
        norm_type="layernorm",  # Choice between "layernorm" and "rmsnorm"
        classifier_dropout=None,

        vocab_subspace=False,
        vocab_rank=None,
        tie_word_embeddings=True,

        # === Multi-Head Latent Attention ===
        num_attention_heads: int = 16,
        rope_dims: int = 16,

        q_shared_dim: Optional[int] = None,
        kv_shared_dim: Optional[int] = None,

        o_shared_dim: Optional[int] = None,  # If None, no output subspace is used

        # Private head dimensions
        qk_private_dim: Optional[int] = None,  # Query/key private dimension per head
        vo_private_dim: Optional[int] = None,  # Value/output private dimension per head
        nope_dims: Optional[int] = None,  # Non-positional encoding dimensions

        attention_backend="eager",
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,

        # === MLA Composition ===
        num_dense_layers=12,  # dense MHA layers before MLA starts

        # === Decomposed MLP ===
        ffn_decompose=False,
        ffn_rank=None,
        **kwargs
    ) -> None:
        super().__init__(**kwargs)

        # === Core Model ===
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_dropout_prob = attention_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.rms_norm_eps = rms_norm_eps
        self.norm_type = norm_type
        self.classifier_dropout = classifier_dropout

        self.vocab_subspace = vocab_subspace
        self.vocab_rank = vocab_rank
        self.tie_word_embeddings = tie_word_embeddings

        # === MLA ===
        self.num_attention_heads = num_attention_heads
        self.rope_dims = rope_dims

        self.q_shared_dim = q_shared_dim
        self.kv_shared_dim = kv_shared_dim
        self.o_shared_dim = o_shared_dim

        # Private head dimensions
        self.qk_private_dim = qk_private_dim
        self.vo_private_dim = vo_private_dim
        self.nope_dims = nope_dims
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.num_dense_layers = num_dense_layers

        # === Decomposed FFN ===
        self.ffn_decompose = ffn_decompose
        self.ffn_rank = ffn_rank

        # === Attention backend ===
        self.attention_backend = attention_backend

        # === Validation ===
        # TODO - Somewhere during training these get instantiated with bad
        # values...
        #self._validate()

        #print(f" > SubEnc *Config.init: {make_shorthand(self)}\n")

    def _validate(self):
        # === Model ===
        if self.num_dense_layers > self.num_hidden_layers:
            raise ValueError("`num_dense_layers` must be <= `num_hidden_layers`")
        if self.vocab_subspace and self.vocab_rank is None:
            raise ValueError("`vocab_rank` must be set when `vocab_subspace=True`")

        # === MLA Validation ===
        # At least one of q_shared_dim or kv_shared_dim must be set if we
        # have subspace layers.
        if self.num_dense_layers < self.num_hidden_layers and self.q_shared_dim is None and self.kv_shared_dim is None:
            raise ValueError("At least one of q_shared_dim or kv_shared_dim must be set when there are subspace layers")

        # Validate that private dimensions are set.
        if self.qk_private_dim is None or self.vo_private_dim is None:
            raise ValueError("Must set qk_private_dim and vo_private_dim")
        if self.nope_dims is None:
            raise ValueError("Must set nope_dims")

        # === Decomposed FFN ===
        if self.ffn_decompose and self.ffn_rank is None:
            raise ValueError("`ffn_rank` must be set when `ffn_decompose=True`")
        if self.ffn_decompose and self.num_dense_layers >= self.num_hidden_layers:
            raise ValueError("`ffn_decompose` was set but `num_dense_layers` is >= number of layers")

        # === Attention Backend ===
        valid_backends = ["eager", "flash_attention_2", "sdpa"]
        if self.attention_backend not in valid_backends:
            raise ValueError(f"Unknown attention backend: {self.attention_backend}, options are {valid_backends}")

        # === Norm Type ===
        valid_norm_types = ["layernorm", "rmsnorm"]
        if self.norm_type not in valid_norm_types:
            raise ValueError(f"Unknown norm type: {self.norm_type}, options are {valid_norm_types}")


import json


def get_config(filename):

    # Load the config file.
    with open(filename) as f:
        full_cfg = json.load(f)

    # Strict key check on the model configuration.

    # Get the list of keys allowed / required by `*Config`.
    valid_keys = SharedSpaceDecoderConfig.__init__.__code__.co_varnames
    # Remove `self` and `kwargs`.
    valid_keys = set(valid_keys) - {"self", "kwargs"}

    # Compare the set of keys in the json file vs `*Config`.
    extra_keys = set(full_cfg["model"]) - valid_keys
    missing_keys = valid_keys - set(full_cfg["model"])

    # If there are any keys in the json that aren't in `*Config`,
    if extra_keys:
        # list them for the user.
        raise ValueError(f"Unknown keys in config: {sorted(extra_keys)}")

    # If the json config is missing required keys,
    if missing_keys:
        # list them for the user.
        raise ValueError(f"config json is missing: {sorted(missing_keys)}")

    # Will raise TypeError, by design, if required args are missing.
    # The asterisks unpack the dictionary into a list of keywords as though
    # all of the settings were written out individually.
    model_cfg = SharedSpaceDecoderConfig(**full_cfg["model"])

    return full_cfg, model_cfg
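As a usage sketch (not part of the committed file), the config can be constructed directly in Python. All of the dimension values below are illustrative placeholders rather than values from this repository, and the import assumes the repo root is on `sys.path`:

from models.shared_space_config import SharedSpaceDecoderConfig

# Illustrative ranks -- every shared/private dimension is spelled out explicitly,
# since nothing is inferred from hidden_size.
cfg = SharedSpaceDecoderConfig(
    vocab_size=30522,
    hidden_size=512,
    num_hidden_layers=12,
    num_dense_layers=2,        # first two layers stay dense MHA
    num_attention_heads=16,
    q_shared_dim=256,          # shared query subspace rank
    kv_shared_dim=256,         # shared key/value subspace rank
    o_shared_dim=256,          # shared output subspace rank (None disables it)
    qk_private_dim=32,         # per-head query/key dimension
    vo_private_dim=32,         # per-head value/output dimension
    rope_dims=16,
    nope_dims=16,
    norm_type="rmsnorm",
    attention_backend="sdpa",
)

print(cfg.model_type)  # -> "shared_space_decoder"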
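The rules described in the docstring's "Validation Behavior" section live in `_validate()`, which is currently commented out in `__init__` (see the TODO), so it has to be invoked explicitly. A minimal sketch of a config that trips the decomposed-FFN check:

cfg = SharedSpaceDecoderConfig(
    num_dense_layers=2,
    q_shared_dim=256,
    qk_private_dim=32,
    vo_private_dim=32,
    nope_dims=16,
    ffn_decompose=True,   # decomposition enabled...
    ffn_rank=None,        # ...but no rank supplied
)

try:
    cfg._validate()
except ValueError as err:
    print(err)  # `ffn_rank` must be set when `ffn_decompose=True`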
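Finally, `get_config` expects a JSON file with a top-level "model" section whose keys match the constructor arguments exactly; the strict check rejects both unknown and missing keys, so every argument must appear. The file name and values in this sketch are hypothetical:

import json

from models.shared_space_config import get_config

# Hypothetical config: the "model" section lists every __init__ argument.
model_section = {
    "vocab_size": 30522, "hidden_size": 512, "num_hidden_layers": 12,
    "intermediate_size": 3072, "hidden_dropout_prob": 0.1,
    "attention_dropout_prob": 0.1, "max_position_embeddings": 2048,
    "initializer_range": 0.02, "layer_norm_eps": 1e-12, "rms_norm_eps": 1e-6,
    "norm_type": "rmsnorm", "classifier_dropout": None,
    "vocab_subspace": False, "vocab_rank": None, "tie_word_embeddings": True,
    "num_attention_heads": 16, "rope_dims": 16,
    "q_shared_dim": 256, "kv_shared_dim": 256, "o_shared_dim": 256,
    "qk_private_dim": 32, "vo_private_dim": 32, "nope_dims": 16,
    "attention_backend": "sdpa", "rope_theta": 10000.0, "rope_scaling": None,
    "attention_bias": False, "num_dense_layers": 2,
    "ffn_decompose": False, "ffn_rank": None,
}

with open("example_config.json", "w") as f:
    json.dump({"model": model_section}, f, indent=2)

full_cfg, model_cfg = get_config("example_config.json")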