Xsmos committed on
Commit
bd91d07
·
verified ·
1 Parent(s): e1c7ad1

Upload updated mosaic-light model

Browse files
Files changed (4) hide show
  1. config.json +48 -0
  2. foundation_bert.py +396 -0
  3. model.safetensors +3 -0
  4. train_config.yaml +247 -0
config.json ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_auto_class": "FoundationBert",
3
+ "auto_map": {
4
+ "AutoModel": "foundation_bert.FoundationBert"
5
+ },
6
+ "architectures": [
7
+ "FoundationBert"
8
+ ],
9
+ "attention_bias": false,
10
+ "attention_dropout": 0.0,
11
+ "attention_probs_dropout_prob": 0.1,
12
+ "bos_token_id": 50281,
13
+ "classifier_activation": "gelu",
14
+ "classifier_bias": false,
15
+ "classifier_dropout": 0.0,
16
+ "classifier_pooling": "cls",
17
+ "cls_token_id": 50281,
18
+ "decoder_bias": true,
19
+ "deterministic_flash_attn": false,
20
+ "dtype": "float32",
21
+ "embedding_dropout": 0.0,
22
+ "eos_token_id": 50282,
23
+ "global_attn_every_n_layers": 3,
24
+ "global_rope_theta": 160000.0,
25
+ "hidden_activation": "gelu",
26
+ "hidden_dropout_prob": 0.1,
27
+ "hidden_size": 384,
28
+ "initializer_cutoff_factor": 2.0,
29
+ "initializer_range": 0.02,
30
+ "intermediate_size": 3072,
31
+ "local_attention": 128,
32
+ "local_rope_theta": 10000.0,
33
+ "max_position_embeddings": 1149,
34
+ "mlp_bias": false,
35
+ "mlp_dropout": 0.0,
36
+ "model_type": "modernbert",
37
+ "norm_bias": false,
38
+ "norm_eps": 1e-05,
39
+ "num_attention_heads": 12,
40
+ "num_hidden_layers": 8,
41
+ "pad_token_id": -1,
42
+ "repad_logits_with_grad": false,
43
+ "sep_token_id": 50282,
44
+ "sparse_pred_ignore_index": -100,
45
+ "sparse_prediction": false,
46
+ "transformers_version": "4.57.1",
47
+ "vocab_size": 2048
48
+ }
foundation_bert.py ADDED
@@ -0,0 +1,396 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ import yaml
4
+ from pathlib import Path
5
+ # from ..utils.masked_data_modeling_loss import MaskedDataLossWithSoftmax
6
+ # from ..utils.contrastive_loss import ContrastiveLoss
7
+ # from ..utils.yaml_util import MyLoader
8
+ from dataclasses import dataclass
9
+ from transformers import ModernBertModel, ModernBertConfig, PretrainedConfig
10
+ from typing import Optional, Union
11
+
12
+ # import yaml
13
class MyLoader(yaml.SafeLoader):
    """``yaml.SafeLoader`` that wraps nodes with unknown tags instead of failing.

    Any node whose tag has no registered constructor is handed to
    ``construct_undefined`` (defined in this module), which returns a
    ``Tagged(tag, value)`` record rather than instantiating anything.
    """

    def __init__(self, stream):
        """Create the loader and register the catch-all tag constructor.

        Registering here (instead of lazily inside ``construct_mapping``)
        means the fallback is already in place when the very first node is
        constructed, so the result no longer depends on whether an earlier
        load happened to visit a mapping first.
        """
        super().__init__(stream)
        # tag=None is PyYAML's "no specific constructor found" slot.
        # add_constructor is a classmethod; it stores the entry on
        # type(self), exactly as the previous super().add_constructor call did.
        type(self).add_constructor(None, construct_undefined)
21
+
22
+ import typing
23
class Tagged(typing.NamedTuple):
    """Pair of a YAML tag and the plain value that was parsed underneath it."""

    # Tag string taken from the YAML node (``node.tag``).
    tag: str
    # Scalar, sequence or mapping built from the node's content.
    value: object
26
+
27
def construct_undefined(self, node):
    """Build a ``Tagged`` record for a YAML node that has no known constructor.

    Intended to be registered on a loader with ``add_constructor(None, ...)``;
    ``self`` is the loader/constructor instance doing the parsing.

    Parameters
    ----------
    self :
        Object providing ``construct_scalar``, ``construct_sequence`` and
        ``construct_mapping``.
    node :
        The YAML node to convert.

    Returns
    -------
    Tagged
        ``Tagged(node.tag, value)`` where ``value`` is the node's content
        built with the matching ``construct_*`` method.

    Raises
    ------
    AssertionError
        If ``node`` is not a scalar, sequence or mapping node.
    """
    if isinstance(node, yaml.nodes.ScalarNode):
        value = self.construct_scalar(node)
    elif isinstance(node, yaml.nodes.SequenceNode):
        value = self.construct_sequence(node)
    elif isinstance(node, yaml.nodes.MappingNode):
        value = self.construct_mapping(node)
    else:
        # Raised explicitly: a bare ``assert`` disappears under ``python -O``
        # and would fall through to an unbound ``value`` below.
        raise AssertionError(f"unexpected node: {node!r}")
    return Tagged(node.tag, value)
37
+
38
@dataclass
class FoundationOutput:
    """Container of optional tensors produced by the model; every field defaults to ``None``.

    NOTE(review): the meaning of each field is implied by its name only; none
    of them is populated by code in this file, so confirm against the training code.
    """

    # Loss terms.
    loss: Optional[torch.Tensor] = None
    # Model outputs.
    logits: Optional[torch.Tensor] = None
    num_output: Optional[torch.Tensor] = None
    est_err_output: Optional[torch.Tensor] = None
    hidden_states: Optional[torch.Tensor] = None
    # Per-objective loss terms.
    masked_loss: Optional[torch.Tensor] = None
    num_loss: Optional[torch.Tensor] = None
    est_err_loss: Optional[torch.Tensor] = None
48
+
49
+
50
@dataclass
class FoundationBertConfig:
    """Hyper-parameters and loss switches consumed by ``FoundationBert``.

    The first twelve fields are required; the remaining flags have defaults.
    """

    # --- required architecture / training values ---
    vocab_size: int
    hidden_size: int
    num_hidden_layers: int
    num_attention_heads: int
    intermediate_size: int
    hidden_dropout_prob: float
    attention_probs_dropout_prob: float
    pad_token_id: int
    classifier_dropout: float
    max_position_embeddings: int
    contrastive_temperature: float
    loss_weights: dict
    # --- optional switches ---
    use_xval_loss: bool = True
    use_mlm_loss: bool = True
    use_regression_loss: bool = False
    use_contrastive_loss: bool = False
    transform_numeric: bool = False
    use_sdpa_attention: bool = True

    def to_dict(self):
        """Return ``{field_name: value}`` for every dataclass field.

        The copy is shallow: mutable values such as ``loss_weights`` are the
        same objects held by the config.
        """
        snapshot = {}
        for field_name in self.__dataclass_fields__:
            snapshot[field_name] = getattr(self, field_name)
        return snapshot
73
+
74
+ class FoundationBert(ModernBertModel):
75
def __init__(self,
             config: FoundationBertConfig = None,
             use_mlm_loss: bool = False,
             use_regression_loss: bool = True,
             use_contrastive_loss: bool = False,
             use_xval_loss: bool = False,
             transform_numeric: bool = False,
             *args,
             **kwargs):
    """Build the ModernBERT backbone plus one embedding and one regression head per modality.

    Parameters
    ----------
    config :
        Object exposing ``vocab_size``, ``hidden_size``, ``num_hidden_layers``,
        ``num_attention_heads``, ``intermediate_size``, ``hidden_dropout_prob``,
        ``attention_probs_dropout_prob``, ``pad_token_id`` and
        ``max_position_embeddings``. It may also carry the four ``use_*_loss``
        flags; when any of them is missing (or ``None``) the keyword arguments
        below are written onto it instead.
    use_mlm_loss, use_regression_loss, use_contrastive_loss, use_xval_loss : bool
        Fallback loss switches, used only when ``config`` does not provide
        all four flags itself.
    transform_numeric : bool
        Always stored on ``config.transform_numeric``.
        NOTE(review): this overrides any value already present on ``config``.
    *args
        Accepted for signature compatibility and ignored.
    **kwargs
        Must contain ``vector_shape``, ``scalar_shape`` and ``mask_token``;
        ``dataset_path`` is optional. Other keys are ignored.

    Raises
    ------
    ValueError
        If ``config`` is ``None``, or if ``config`` carries all four loss
        flags and every one of them is falsy.
    KeyError
        If ``vector_shape``, ``scalar_shape`` or ``mask_token`` is missing
        from ``kwargs``.
    """
    if config is None:
        raise ValueError("FoundationBert requires a config object, got None")

    self.gconfig = config
    bert_conf = ModernBertConfig(
        vocab_size=config.vocab_size,
        hidden_size=config.hidden_size,
        num_hidden_layers=config.num_hidden_layers,
        num_attention_heads=config.num_attention_heads,
        intermediate_size=config.intermediate_size,
        hidden_dropout_prob=config.hidden_dropout_prob,
        attention_probs_dropout_prob=config.attention_probs_dropout_prob,
        pad_token_id=config.pad_token_id,
        max_position_embeddings=config.max_position_embeddings,
        _attn_implementation='sdpa'
    )
    self.gconfig.transform_numeric = transform_numeric
    super().__init__(bert_conf)

    # Loss switches: prefer the ones carried by the config, otherwise fall
    # back to the constructor arguments. The previous bare ``except`` also
    # swallowed the ValueError below and silently replaced the config's flags.
    fallback_flags = {
        'use_mlm_loss': use_mlm_loss,
        'use_regression_loss': use_regression_loss,
        'use_contrastive_loss': use_contrastive_loss,
        'use_xval_loss': use_xval_loss,
    }
    config_has_flags = all(
        getattr(self.gconfig, name, None) is not None for name in fallback_flags
    )
    if config_has_flags:
        # use_xval_loss is counted in loss_mod, so it also counts as "a loss".
        if not any(getattr(self.gconfig, name) for name in fallback_flags):
            raise ValueError("At least one loss must be enabled")
    else:
        for name, value in fallback_flags.items():
            setattr(self.gconfig, name, value)
    # Number of enabled losses, kept as a float as before.
    self.loss_mod = sum(float(getattr(self.gconfig, name)) for name in fallback_flags)

    self.dataset_path = kwargs.get('dataset_path', None)

    self.vector_shape = kwargs['vector_shape']
    self.scalar_shape = kwargs['scalar_shape']
    self.mask_token = kwargs['mask_token']

    # Modalities listed in vector_shape keep their own name; every other
    # modality is folded into a single shared 'scalars' entry.
    self.modalscalars = [m if m in self.vector_shape else 'scalars' for m in self.modalities]
    # Remove duplicates while preserving order.
    self.modalscalars = list(dict.fromkeys(self.modalscalars))

    print(f"✅ FoundationBert.__init__ is called with {kwargs=}, {self.modalscalars=}, {self.dataset_path=} ✅")

    self.embedding = torch.nn.ModuleDict()  # modality specific embedding layers
    self.num_head = torch.nn.ModuleDict()  # modality specific regression heads
    for modality in self.modalscalars:
        # Each value is embedded on its own: (B, L, 1) -> (B, L, H).
        self.embedding[modality] = torch.nn.Linear(1, config.hidden_size)
        # Regression head: (B, L, H) -> (B, L, 1).
        self.num_head[modality] = torch.nn.Sequential(
            torch.nn.Linear(config.hidden_size, config.hidden_size),
            torch.nn.LayerNorm(config.hidden_size),
            torch.nn.GELU(),
            torch.nn.Linear(config.hidden_size, config.hidden_size // 2),
            torch.nn.GELU(),
            torch.nn.Linear(config.hidden_size // 2, 1)
        )

    self.embed_dropout = torch.nn.Dropout(config.hidden_dropout_prob)
    self.distributed_loss = False
157
+
158
@property
def modalities(self):
    """Union of ``vector_shape`` and ``scalar_shape`` (vector entries first).

    Built with ``|``, so on a key clash the ``scalar_shape`` entry wins.
    """
    merged = self.vector_shape | self.scalar_shape
    return merged
161
+
162
@classmethod
def from_pretrained(cls,
                    pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
                    *model_args,
                    config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
                    cache_dir: Optional[Union[str, os.PathLike]] = None,
                    ignore_mismatched_sizes: bool = False,
                    force_download: bool = False,
                    local_files_only: bool = False,
                    token: Optional[Union[str, bool]] = None,
                    revision: str = "main",
                    use_safetensors: bool = None,
                    **kwargs,
                    ):
    """Load a model together with the ``train_config.yaml`` saved next to it.

    The YAML file supplies the extra constructor inputs (``modalities``,
    ``dataset_path``, ``mask_token``, ``vector_shape``, ``scalar_shape``)
    and the ``model_config`` mapping, all of which are forwarded as keyword
    arguments to the parent ``from_pretrained``.

    Parameters
    ----------
    pretrained_model_name_or_path :
        Local directory of the saved model. If the path contains the
        substring ``checkpoint``, ``train_config.yaml`` is read from the
        parent directory; otherwise from the directory itself.
    *model_args :
        Accepted but not forwarded (unchanged from the previous behaviour).
    config, cache_dir, ignore_mismatched_sizes, force_download,
    local_files_only, token, revision, use_safetensors :
        Forwarded unchanged to the parent ``from_pretrained``.
    **kwargs :
        Extra keyword arguments. ``modalities``, ``dataset_path`` and
        ``mask_token`` are always overwritten from the YAML file;
        ``vector_shape`` / ``scalar_shape`` are taken from the YAML file only
        when not given. On a name clash with ``model_config`` the caller's
        value wins.

    Raises
    ------
    ValueError
        If no ``train_config.yaml`` can be located for the given path.
    """
    # os.fspath makes the substring test work for PathLike inputs too.
    model_path = os.fspath(pretrained_model_name_or_path)
    if 'checkpoint' in model_path:
        model_config = Path(model_path).parent / 'train_config.yaml'
    elif (Path(model_path) / 'train_config.yaml').is_file():
        model_config = Path(model_path) / 'train_config.yaml'
    else:
        raise ValueError(f"Could not find train_config.yaml in {pretrained_model_name_or_path}")

    # MyLoader derives from yaml.SafeLoader, so unknown tags are wrapped
    # rather than executed. A separate name keeps the caller's ``config``
    # argument intact.
    with open(model_config, 'r') as f:
        train_config = yaml.load(f, Loader=MyLoader)

    kwargs['modalities'] = train_config['modalities']
    kwargs['dataset_path'] = train_config['dataset_path']
    kwargs['mask_token'] = train_config['mask_token']

    if 'vector_shape' not in kwargs and 'vector_shape' in train_config:
        kwargs['vector_shape'] = train_config['vector_shape']
    if 'scalar_shape' not in kwargs and 'scalar_shape' in train_config:
        kwargs['scalar_shape'] = train_config['scalar_shape']

    print(f"✅ Foundationbert.from_pretrained is called with {model_config=} and {kwargs=} ✅")

    # Merge into one dict so a key present in both places no longer raises
    # "got multiple values for keyword argument"; caller kwargs take priority.
    merged_kwargs = {**train_config['model_config'], **kwargs}
    return super().from_pretrained(
        pretrained_model_name_or_path,
        config=config,
        cache_dir=cache_dir,
        ignore_mismatched_sizes=ignore_mismatched_sizes,
        force_download=force_download,
        local_files_only=local_files_only,
        token=token,
        revision=revision,
        use_safetensors=use_safetensors,
        **merged_kwargs,
    )
203
+
204
def pool_output(self,
                embeddings: torch.Tensor,
                attention_mask: torch.Tensor,
                use_last: bool = False
                ) -> torch.Tensor:
    """Pool the hidden states over the sequence using the attention mask.

    Position 0 and one position near the end of each row are removed from
    the mask before summing; the sum is then divided by the row's full
    attended length.

    Parameters
    ----------
    embeddings : torch.Tensor
        The hidden states to pool (B, SeqLen, HiddenDim).
    attention_mask : torch.Tensor
        The attention mask for the hidden states (B, SeqLen).
    use_last : bool
        Selects which trailing position is dropped per row:
        ``length - 1`` when True, ``length - 2`` when False, where
        ``length`` is the row's mask sum.

    Returns
    -------
    torch.Tensor
        The pooled embeddings (B, HiddenDim).
    """
    sl_mod = 1 if use_last else 2
    seq_lengths = attention_mask.sum(dim=1)

    # Work on a copy so the caller's mask is left untouched.
    new_attention = attention_mask.clone()
    new_attention[:, 0] = 0

    # Drop exactly one trailing position per row. Indexing with
    # ``[:, positions]`` would instead clear those columns for every row
    # in the batch, which is wrong when rows have different lengths.
    # Clamping keeps very short rows from wrapping around to the far end.
    rows = torch.arange(attention_mask.shape[0], device=attention_mask.device)
    drop_positions = torch.clamp(seq_lengths.long() - sl_mod, min=0)
    new_attention[rows, drop_positions] = 0

    # Broadcast the mask over the hidden dimension: (B, SeqLen, HiddenDim).
    pool_mask = new_attention.unsqueeze(-1).expand(embeddings.shape).to(embeddings.device)
    sum_embeds = torch.sum(embeddings * pool_mask, 1)

    # NOTE(review): the divisor is the full attended length, including the
    # positions removed above, so this is not a strict mean of the kept
    # tokens. Left unchanged to keep output scale stable.
    seq_lengths = torch.clamp(seq_lengths, min=1).unsqueeze(-1)  # (B, 1)
    return sum_embeds / seq_lengths
241
+
242
+
243
def last_token_pool(
    self,
    embeddings: torch.Tensor,
    attention_mask: torch.Tensor,
) -> torch.Tensor:
    """Pick the hidden state of the last attended token of every sequence.

    Parameters
    ----------
    embeddings : torch.Tensor
        The last hidden states to pool (B, SeqLen, HiddenDim).
    attention_mask : torch.Tensor
        The attention mask for the hidden states (B, SeqLen).

    Returns
    -------
    torch.Tensor
        The pooled embeddings (B, HiddenDim).
    """
    n_rows = attention_mask.shape[0]
    # If every row attends its final column, the batch is treated as
    # left-padded and the final position is taken for all rows.
    if attention_mask[:, -1].sum() == n_rows:
        return embeddings[:, -1]

    # Otherwise the last attended index of each row is (mask sum - 1).
    last_positions = attention_mask.sum(dim=1) - 1
    row_index = torch.arange(embeddings.shape[0], device=embeddings.device)
    return embeddings[row_index, last_positions]
272
+
273
def forward(self, inputs, return_input_label_mapping=False):
    """Embed every modality, run the encoder layers and apply the per-modality heads.

    Args:
        inputs (dict): Must provide ``input_<name>`` and ``labels_<name>`` for
            every ``<name>`` in ``self.modalscalars``. ``labels_<name>`` is used
            as the condition of ``torch.where``: where it is true the input value
            is replaced by ``self.mask_token``.
        return_input_label_mapping (bool): When true, also return the
            ``{name: {'input': ..., 'labels': ...}}`` dict built from ``inputs``
            (holding the original, un-masked inputs).

    Returns:
        dict: ``{"<name>_logits": tensor}`` for every modality, or the tuple
        ``(outputs, input_label_mapping)`` when requested.
    """
    input_label_mapping = {}
    embedded_chunks = []
    for name in self.modalscalars:
        entry = {
            'input': inputs[f"input_{name}"],
            'labels': inputs[f"labels_{name}"],
        }
        input_label_mapping[name] = entry

        # Replace the selected positions by the mask token, then embed each
        # value on its own: (B, L) -> (B, L, 1) -> (B, L, H).
        masked_values = torch.where(entry['labels'], self.mask_token, entry['input'])
        embedded = self.embedding[name](masked_values.unsqueeze(-1))
        embedded_chunks.append(torch.nn.functional.silu(embedded))

    # All modalities are laid out back to back along the sequence axis.
    sequence = torch.cat(embedded_chunks, dim=1)
    position_ids = torch.arange(sequence.size(1), device=sequence.device).unsqueeze(0)  # (1, L)

    hidden_states = self.embed_dropout(sequence)
    # The encoder layers are called directly with position ids only.
    for encoder_layer in self.layers:
        hidden_states = encoder_layer(hidden_states, position_ids=position_ids)[0]
    encoded = self.final_norm(hidden_states)

    outputs = {}
    offset = 0
    for name in self.modalscalars:
        # Cut the encoded sequence back into the span of this modality.
        span = input_label_mapping[name]['input'].shape[1]
        modality_slice = encoded[:, offset:offset + span, :]
        outputs[f"{name}_logits"] = self.num_head[name](modality_slice)
        offset += span

    if getattr(self, 'save_umap_for', None):
        # NOTE(review): only the slice of the LAST modality is pooled here —
        # confirm this is intended rather than pooling the whole sequence.
        self.save_pooled_embedding(modality_slice.mean(dim=1))

    if return_input_label_mapping:
        return outputs, input_label_mapping
    return outputs
329
+
330
def save_pooled_embedding(self, features):
    """Append ``features`` to the ``'features'`` dataset of the HDF5 file at ``self.save_umap_for``.

    The file (and its parent directories) is created on first use with a
    dataset that can grow along its first axis; later calls extend it.
    """
    import h5py

    target = Path(self.save_umap_for)
    target.parent.mkdir(parents=True, exist_ok=True)

    batch = features.detach().cpu().numpy()
    width = batch.shape[-1]

    if not target.exists():
        # First write: unlimited first axis so later batches can be appended.
        with h5py.File(target, 'w') as handle:
            handle.create_dataset('features', data=batch, maxshape=(None, width), chunks=True)
        return

    with h5py.File(target, 'r+') as handle:
        dataset = handle['features']
        offset = dataset.shape[0]                        # rows already stored
        dataset.resize((offset + batch.shape[0], width))  # grow to fit the new rows
        dataset[offset:] = batch                         # fill the new tail
351
+
352
def get_retrieval_embedding(
    self,
    inputs,
    pooling: str = "mean",
    normalize: bool = True,
) -> torch.Tensor:
    """
    Build a single embedding per sample for kNN-style retrieval.

    Parameters
    ----------
    inputs : dict
        Batch dict with `input_<modality>` and `labels_<modality>` entries for
        every modality in `self.modalscalars`; positions where the labels are
        true are replaced by `self.mask_token` before embedding.
    pooling : str
        `last` takes the final sequence position; any other value
        (default `mean`) averages over the sequence axis.
    normalize : bool
        L2-normalize output embeddings for cosine/inner-product search.

    Returns
    -------
    torch.Tensor
        One vector per sample, shape (B, HiddenDim).
    """
    embedded_chunks = []
    for name in self.modalscalars:
        values = inputs[f"input_{name}"]
        mask_positions = inputs[f"labels_{name}"]
        masked_values = torch.where(mask_positions, self.mask_token, values)
        embedded = self.embedding[name](masked_values.unsqueeze(-1))
        embedded_chunks.append(torch.nn.functional.silu(embedded))

    # Concatenate the modalities along the sequence axis and encode them.
    sequence = torch.cat(embedded_chunks, dim=1)
    position_ids = torch.arange(sequence.size(1), device=sequence.device).unsqueeze(0)

    hidden_states = self.embed_dropout(sequence)
    for encoder_layer in self.layers:
        hidden_states = encoder_layer(hidden_states, position_ids=position_ids)[0]
    hidden_states = self.final_norm(hidden_states)

    embedding = hidden_states[:, -1, :] if pooling == "last" else hidden_states.mean(dim=1)

    if not normalize:
        return embedding
    return torch.nn.functional.normalize(embedding, p=2, dim=-1)
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:80913251a26a74a8704f02417ae8abf0f2691fc94c4ef3ebf3bc62be8154659a
3
+ size 139771732
train_config.yaml ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset_path: /pscratch/sd/b/binxia/supermock_dataset_11.2-14.json
2
+ input_errors:
3
+ - 0
4
+ - 0
5
+ - 0
6
+ - 0
7
+ - 0
8
+ - 0
9
+ - 0
10
+ mask_token: 0
11
+ masked_generation: false
12
+ masking_prob:
13
+ - 0.2
14
+ - 0.2
15
+ - 0.2
16
+ - 0.2
17
+ - 0.5
18
+ - 0.5
19
+ - 0.5
20
+ modalities:
21
+ - SFH
22
+ - SED
23
+ - mag_{band}_spherex
24
+ - mag_{band}_lsst
25
+ - redshift
26
+ - halo_mass
27
+ - stellar_mass
28
+ scalar_shape:
29
+ redshift:
30
+ - 20000
31
+ - 1
32
+ halo_mass:
33
+ - 20000
34
+ - 1
35
+ stellar_mass:
36
+ - 20000
37
+ - 1
38
+ vector_shape:
39
+ SFH:
40
+ - 20000
41
+ - 117
42
+ SED:
43
+ - 20000
44
+ - 921
45
+ mag_{band}_spherex:
46
+ - 20000
47
+ - 102
48
+ mag_{band}_lsst:
49
+ - 20000
50
+ - 6
51
+ model_config:
52
+ attention_probs_dropout_prob: 0.1
53
+ classifier_dropout: 0.0
54
+ contrastive_temperature: 0.05
55
+ hidden_dropout_prob: 0.1
56
+ hidden_size: 384
57
+ intermediate_size: 3072
58
+ loss_weights:
59
+ contrastive:
60
+ rounds: 0
61
+ w0T:
62
+ - 0
63
+ - 0
64
+ masked:
65
+ rounds: 0
66
+ w0T:
67
+ - 0.8
68
+ - 3
69
+ smooth:
70
+ rounds: 0
71
+ w0T:
72
+ - 0
73
+ - 0.3
74
+ unmasked:
75
+ rounds: 0
76
+ w0T:
77
+ - 0.2
78
+ - 0.3
79
+ max_position_embeddings: 1149
80
+ num_attention_heads: 12
81
+ num_hidden_layers: 8
82
+ pad_token_id: -1
83
+ transform_numeric: false
84
+ use_contrastive_loss: false
85
+ use_mlm_loss: true
86
+ use_regression_loss: false
87
+ use_sdpa_attention: true
88
+ use_xval_loss: false
89
+ vocab_size: 2048
90
+ model_name_or_path: galaxybert
91
+ num_total_samples: -1
92
+ tokenizer_name_or_path: Salesforce/SFR-Embedding-Mistral
93
+ training_args:
94
+ _n_gpu: 1
95
+ accelerator_config:
96
+ dispatch_batches: null
97
+ even_batches: true
98
+ gradient_accumulation_kwargs: null
99
+ non_blocking: false
100
+ split_batches: false
101
+ use_configured_state: false
102
+ use_seedable_sampler: true
103
+ adafactor: false
104
+ adam_beta1: 0.9
105
+ adam_beta2: 0.999
106
+ adam_epsilon: 1.0e-08
107
+ auto_find_batch_size: false
108
+ average_tokens_across_devices: true
109
+ batch_eval_metrics: false
110
+ bf16: true
111
+ bf16_full_eval: false
112
+ data_seed: null
113
+ dataloader_drop_last: false
114
+ dataloader_num_workers: 16
115
+ dataloader_persistent_workers: false
116
+ dataloader_pin_memory: true
117
+ dataloader_prefetch_factor: 8
118
+ ddp_backend: null
119
+ ddp_broadcast_buffers: null
120
+ ddp_bucket_cap_mb: null
121
+ ddp_find_unused_parameters: null
122
+ ddp_timeout: 1800
123
+ debug: []
124
+ deepspeed: null
125
+ disable_tqdm: false
126
+ do_eval: true
127
+ do_predict: false
128
+ do_train: false
129
+ eval_accumulation_steps: 5
130
+ eval_delay: 0
131
+ eval_do_concat_batches: true
132
+ eval_on_start: false
133
+ eval_steps: 20
134
+ eval_strategy: !!python/object/apply:transformers.trainer_utils.IntervalStrategy
135
+ - steps
136
+ eval_use_gather_object: false
137
+ fp16: false
138
+ fp16_backend: auto
139
+ fp16_full_eval: false
140
+ fp16_opt_level: O1
141
+ fsdp: []
142
+ fsdp_config:
143
+ min_num_params: 0
144
+ xla: false
145
+ xla_fsdp_grad_ckpt: false
146
+ xla_fsdp_v2: false
147
+ fsdp_min_num_params: 0
148
+ fsdp_transformer_layer_cls_to_wrap: null
149
+ full_determinism: false
150
+ gradient_accumulation_steps: 5
151
+ gradient_checkpointing: false
152
+ gradient_checkpointing_kwargs: null
153
+ greater_is_better: null
154
+ group_by_length: false
155
+ half_precision_backend: auto
156
+ hub_always_push: false
157
+ hub_model_id: null
158
+ hub_private_repo: null
159
+ hub_revision: null
160
+ hub_strategy: !!python/object/apply:transformers.trainer_utils.HubStrategy
161
+ - every_save
162
+ hub_token: null
163
+ ignore_data_skip: false
164
+ include_for_metrics: []
165
+ include_inputs_for_metrics: false
166
+ include_num_input_tokens_seen: 'no'
167
+ include_tokens_per_second: false
168
+ jit_mode_eval: false
169
+ label_names: null
170
+ label_smoothing_factor: 0.0
171
+ learning_rate: 0.0001
172
+ length_column_name: length
173
+ liger_kernel_config: null
174
+ load_best_model_at_end: false
175
+ local_rank: 3
176
+ log_level: passive
177
+ log_level_replica: warning
178
+ log_on_each_node: true
179
+ logging_dir: sm_foundation_lg_gmm_nomasklab
180
+ logging_first_step: true
181
+ logging_nan_inf_filter: true
182
+ logging_steps: 1
183
+ logging_strategy: !!python/object/apply:transformers.trainer_utils.IntervalStrategy
184
+ - steps
185
+ lr_scheduler_kwargs: {}
186
+ lr_scheduler_type: !!python/object/apply:transformers.trainer_utils.SchedulerType
187
+ - cosine
188
+ max_grad_norm: 1.0
189
+ max_steps: -1
190
+ metric_for_best_model: null
191
+ mp_parameters: ''
192
+ neftune_noise_alpha: null
193
+ no_cuda: false
194
+ num_train_epochs: 120
195
+ optim: !!python/object/apply:transformers.training_args.OptimizerNames
196
+ - adamw_torch
197
+ optim_args: null
198
+ optim_target_modules: null
199
+ output_dir: supermock_light_nte120_nts-1
200
+ overwrite_output_dir: true
201
+ parallelism_config: null
202
+ past_index: -1
203
+ per_device_eval_batch_size: 40
204
+ per_device_train_batch_size: 40
205
+ per_gpu_eval_batch_size: null
206
+ per_gpu_train_batch_size: null
207
+ prediction_loss_only: false
208
+ project: huggingface
209
+ push_to_hub: false
210
+ push_to_hub_model_id: null
211
+ push_to_hub_organization: null
212
+ push_to_hub_token: null
213
+ ray_scope: last
214
+ remove_unused_columns: false
215
+ report_to:
216
+ - wandb
217
+ restore_callback_states_from_checkpoint: false
218
+ resume_from_checkpoint: null
219
+ run_name: NO_SHARD_b50
220
+ save_on_each_node: false
221
+ save_only_model: false
222
+ save_safetensors: true
223
+ save_steps: 30
224
+ save_strategy: !!python/object/apply:transformers.trainer_utils.SaveStrategy
225
+ - steps
226
+ save_total_limit: 360
227
+ seed: 42
228
+ skip_memory_metrics: true
229
+ tf32: null
230
+ torch_compile: false
231
+ torch_compile_backend: null
232
+ torch_compile_mode: null
233
+ torch_empty_cache_steps: null
234
+ torchdynamo: null
235
+ tpu_metrics_debug: false
236
+ tpu_num_cores: null
237
+ trackio_space_id: trackio
238
+ use_cpu: false
239
+ use_legacy_prediction_loop: false
240
+ use_liger_kernel: false
241
+ use_mps_device: false
242
+ warmup_ratio: 0.0
243
+ warmup_steps: 0
244
+ weight_decay: 0.1
245
+ transform_numeric: false
246
+ wandb_project: supermock-foundation-perl
247
+ wandb_run_name: ''