Xsmos commited on
Commit
719e5d5
·
verified ·
1 Parent(s): b281600

Initial upload of MOSAIC FoundationBert model, v1.0. Final successful local test.

Browse files
__pycache__/foundation_bert.cpython-312.pyc ADDED
Binary file (15.7 kB). View file
 
config.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_auto_class": "FoundationBert",
3
+ "auto_map": {
4
+ "AutoModel": "foundation_bert.FoundationBert"
5
+ },
6
+
7
+ "architectures": [
8
+ "FoundationBert"
9
+ ],
10
+ "attention_probs_dropout_prob": 0.1,
11
+ "classifier_dropout": null,
12
+ "hidden_act": "gelu",
13
+ "hidden_dropout_prob": 0.1,
14
+ "hidden_size": 768,
15
+ "initializer_range": 0.02,
16
+ "intermediate_size": 3072,
17
+ "layer_norm_eps": 1e-12,
18
+ "max_position_embeddings": 1149,
19
+ "model_type": "bert",
20
+ "num_attention_heads": 12,
21
+ "num_hidden_layers": 18,
22
+ "pad_token_id": -1,
23
+ "position_embedding_type": "absolute",
24
+ "torch_dtype": "float32",
25
+ "transformers_version": "4.46.3",
26
+ "type_vocab_size": 2,
27
+ "use_cache": true,
28
+ "vocab_size": 2048
29
+ }
environment.yml ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: fsdp
2
+ channels:
3
+ - conda-forge
4
+ dependencies:
5
+ - _libgcc_mutex=0.1=conda_forge
6
+ - _openmp_mutex=4.5=2_gnu
7
+ - anyio=4.9.0=pyh29332c3_0
8
+ - archspec=0.2.3=pyhd8ed1ab_0
9
+ - argon2-cffi=23.1.0=pyhd8ed1ab_1
10
+ - argon2-cffi-bindings=21.2.0=py312h66e93f0_5
11
+ - arrow=1.3.0=pyhd8ed1ab_1
12
+ - asttokens=3.0.0=pyhd8ed1ab_1
13
+ - attrs=25.3.0=pyh71513ae_0
14
+ - beautifulsoup4=4.13.3=pyha770c72_0
15
+ - bleach=6.2.0=pyh29332c3_4
16
+ - bleach-with-css=6.2.0=h82add2a_4
17
+ - boltons=24.0.0=pyhd8ed1ab_0
18
+ - brotli-python=1.1.0=py312h2ec8cdc_2
19
+ - bzip2=1.0.8=h4bc722e_7
20
+ - c-ares=1.32.3=h4bc722e_0
21
+ - ca-certificates=2025.4.26=hbd8a1cb_0
22
+ - cached-property=1.5.2=hd8ed1ab_1
23
+ - cached_property=1.5.2=pyha770c72_1
24
+ - certifi=2025.4.26=pyhd8ed1ab_0
25
+ - cffi=1.17.0=py312h06ac9bb_1
26
+ - charset-normalizer=3.3.2=pyhd8ed1ab_0
27
+ - colorama=0.4.6=pyhd8ed1ab_0
28
+ - comm=0.2.2=pyhd8ed1ab_1
29
+ - conda-package-handling=2.3.0=pyh7900ff3_0
30
+ - conda-package-streaming=0.10.0=pyhd8ed1ab_0
31
+ - debugpy=1.8.13=py312h2ec8cdc_0
32
+ - decorator=5.2.1=pyhd8ed1ab_0
33
+ - defusedxml=0.7.1=pyhd8ed1ab_0
34
+ - distro=1.9.0=pyhd8ed1ab_0
35
+ - exceptiongroup=1.2.2=pyhd8ed1ab_1
36
+ - fmt=11.0.2=h434a139_0
37
+ - fqdn=1.5.1=pyhd8ed1ab_1
38
+ - frozendict=2.4.4=py312h9a8786e_0
39
+ - gitdb=4.0.12=pyhd8ed1ab_0
40
+ - gitpython=3.1.44=pyhff2d567_0
41
+ - h2=4.1.0=pyhd8ed1ab_0
42
+ - hpack=4.0.0=pyh9f0ad1d_0
43
+ - hyperframe=6.0.1=pyhd8ed1ab_0
44
+ - icu=75.1=he02047a_0
45
+ - idna=3.8=pyhd8ed1ab_0
46
+ - importlib-metadata=8.6.1=pyha770c72_0
47
+ - importlib_resources=6.5.2=pyhd8ed1ab_0
48
+ - ipykernel=6.29.5=pyh3099207_0
49
+ - ipython_pygments_lexers=1.1.1=pyhd8ed1ab_0
50
+ - isoduration=20.11.0=pyhd8ed1ab_1
51
+ - jedi=0.19.2=pyhd8ed1ab_1
52
+ - jq=1.7.1=hd590300_0
53
+ - jsonpatch=1.33=pyhd8ed1ab_0
54
+ - jsonpointer=3.0.0=py312h7900ff3_1
55
+ - jsonschema=4.23.0=pyhd8ed1ab_1
56
+ - jsonschema-specifications=2024.10.1=pyhd8ed1ab_1
57
+ - jsonschema-with-format-nongpl=4.23.0=hd8ed1ab_1
58
+ - jupyter-server-mathjax=0.2.6=pyhbbac1ac_2
59
+ - jupyter_client=8.6.3=pyhd8ed1ab_1
60
+ - jupyter_core=5.7.2=pyh31011fe_1
61
+ - jupyter_events=0.12.0=pyh29332c3_0
62
+ - jupyter_server=2.15.0=pyhd8ed1ab_0
63
+ - jupyter_server_terminals=0.5.3=pyhd8ed1ab_1
64
+ - jupyterlab_pygments=0.3.0=pyhd8ed1ab_2
65
+ - keyutils=1.6.1=h166bdaf_0
66
+ - krb5=1.21.3=h659f571_0
67
+ - ld_impl_linux-64=2.40=hf3520f5_7
68
+ - libarchive=3.7.4=hfca40fe_0
69
+ - libcurl=8.10.1=hbbe4b11_0
70
+ - libedit=3.1.20191231=he28a2e2_2
71
+ - libev=4.33=hd590300_2
72
+ - libexpat=2.6.2=h59595ed_0
73
+ - libffi=3.4.2=h7f98852_5
74
+ - libgcc=14.1.0=h77fa898_1
75
+ - libgcc-ng=14.1.0=h69a702a_1
76
+ - libgomp=14.1.0=h77fa898_1
77
+ - libiconv=1.17=hd590300_2
78
+ - libmamba=1.5.10=hf72d635_1
79
+ - libmambapy=1.5.10=py312hf3f0a4e_1
80
+ - libnghttp2=1.58.0=h47da74e_1
81
+ - libnsl=2.0.1=hd590300_0
82
+ - libsodium=1.0.20=h4ab18f5_0
83
+ - libsolv=0.7.30=h3509ff9_0
84
+ - libsqlite=3.46.1=hadc24fc_0
85
+ - libssh2=1.11.0=h0841786_0
86
+ - libstdcxx=14.1.0=hc0a3c3a_1
87
+ - libstdcxx-ng=14.1.0=h4852527_1
88
+ - libuuid=2.38.1=h0b41bf4_0
89
+ - libxcrypt=4.4.36=hd590300_1
90
+ - libxml2=2.12.7=he7c6b58_4
91
+ - libzlib=1.3.1=h4ab18f5_1
92
+ - lz4-c=1.9.4=hcb278e6_0
93
+ - lzo=2.10=hd590300_1001
94
+ - markupsafe=3.0.2=py312h178313f_1
95
+ - matplotlib-inline=0.1.7=pyhd8ed1ab_1
96
+ - menuinst=2.1.2=py312h7900ff3_1
97
+ - mistune=3.1.3=pyh29332c3_0
98
+ - nbclient=0.10.2=pyhd8ed1ab_0
99
+ - nbconvert-core=7.16.6=pyh29332c3_0
100
+ - nbdime=4.0.2=pyhd8ed1ab_1
101
+ - nbformat=5.10.4=pyhd8ed1ab_1
102
+ - ncurses=6.5=he02047a_1
103
+ - nest-asyncio=1.6.0=pyhd8ed1ab_1
104
+ - oniguruma=6.9.10=hb9d3cd8_0
105
+ - openssl=3.5.0=h7b32b05_1
106
+ - overrides=7.7.0=pyhd8ed1ab_1
107
+ - packaging=24.1=pyhd8ed1ab_0
108
+ - parso=0.8.4=pyhd8ed1ab_1
109
+ - pexpect=4.9.0=pyhd8ed1ab_1
110
+ - pickleshare=0.7.5=pyhd8ed1ab_1004
111
+ - pip=24.2=pyh8b19718_1
112
+ - pkgutil-resolve-name=1.3.10=pyhd8ed1ab_2
113
+ - platformdirs=4.2.2=pyhd8ed1ab_0
114
+ - pluggy=1.5.0=pyhd8ed1ab_0
115
+ - prometheus_client=0.21.1=pyhd8ed1ab_0
116
+ - prompt-toolkit=3.0.50=pyha770c72_0
117
+ - psutil=7.0.0=py312h66e93f0_0
118
+ - ptyprocess=0.7.0=pyhd8ed1ab_1
119
+ - pure_eval=0.2.3=pyhd8ed1ab_1
120
+ - pybind11-abi=4=hd8ed1ab_3
121
+ - pycosat=0.6.6=py312h98912ed_0
122
+ - pycparser=2.22=pyhd8ed1ab_0
123
+ - pygments=2.19.1=pyhd8ed1ab_0
124
+ - pysocks=1.7.1=pyha2e5f31_6
125
+ - python=3.12.5=h2ad013b_0_cpython
126
+ - python-dateutil=2.9.0.post0=pyhff2d567_1
127
+ - python-fastjsonschema=2.21.1=pyhd8ed1ab_0
128
+ - python_abi=3.12=5_cp312
129
+ - pyyaml=6.0.2=py312h178313f_2
130
+ - pyzmq=26.2.1=py312hbf22597_0
131
+ - readline=8.2=h8228510_1
132
+ - referencing=0.36.2=pyh29332c3_0
133
+ - reproc=14.2.4.post0=hd590300_1
134
+ - reproc-cpp=14.2.4.post0=h59595ed_1
135
+ - requests=2.32.3=pyhd8ed1ab_0
136
+ - rfc3339-validator=0.1.4=pyhd8ed1ab_1
137
+ - rfc3986-validator=0.1.1=pyh9f0ad1d_0
138
+ - ruamel.yaml=0.18.6=py312h98912ed_0
139
+ - ruamel.yaml.clib=0.2.8=py312h98912ed_0
140
+ - send2trash=1.8.3=pyh0d859eb_1
141
+ - six=1.17.0=pyhd8ed1ab_0
142
+ - smmap=5.0.2=pyhd8ed1ab_0
143
+ - sniffio=1.3.1=pyhd8ed1ab_1
144
+ - stack_data=0.6.3=pyhd8ed1ab_1
145
+ - terminado=0.18.1=pyh0d859eb_0
146
+ - tinycss2=1.4.0=pyhd8ed1ab_0
147
+ - tk=8.6.13=noxft_h4845f30_101
148
+ - tornado=6.4.2=py312h66e93f0_0
149
+ - tqdm=4.66.5=pyhd8ed1ab_0
150
+ - traitlets=5.14.3=pyhd8ed1ab_1
151
+ - truststore=0.9.2=pyhd8ed1ab_0
152
+ - types-python-dateutil=2.9.0.20241206=pyhd8ed1ab_0
153
+ - typing-extensions=4.12.2=hd8ed1ab_1
154
+ - typing_extensions=4.12.2=pyha770c72_1
155
+ - typing_utils=0.1.0=pyhd8ed1ab_1
156
+ - uri-template=1.3.0=pyhd8ed1ab_1
157
+ - urllib3=2.2.2=pyhd8ed1ab_1
158
+ - wcwidth=0.2.13=pyhd8ed1ab_1
159
+ - webcolors=24.11.1=pyhd8ed1ab_0
160
+ - webencodings=0.5.1=pyhd8ed1ab_3
161
+ - websocket-client=1.8.0=pyhd8ed1ab_1
162
+ - wheel=0.44.0=pyhd8ed1ab_0
163
+ - xz=5.2.6=h166bdaf_0
164
+ - yaml=0.2.5=h7f98852_2
165
+ - yaml-cpp=0.8.0=h59595ed_0
166
+ - zeromq=4.3.5=h3b0a872_7
167
+ - zipp=3.21.0=pyhd8ed1ab_1
168
+ - zstandard=0.23.0=py312hef9b889_1
169
+ - zstd=1.5.6=ha6fb4c9_0
170
+ - pip:
171
+ - accelerate==1.1.1
172
+ - alphashape==1.3.1
173
+ - annotated-types==0.7.0
174
+ - async-lru==2.0.5
175
+ - babel==2.17.0
176
+ - brokenaxes==0.6.2
177
+ - click==8.1.8
178
+ - click-log==0.4.0
179
+ - contourpy==1.3.1
180
+ - cycler==0.12.1
181
+ - descartes==1.1.0
182
+ - docker-pycreds==0.4.0
183
+ - einops==0.8.0
184
+ - executing==2.2.0
185
+ - filelock==3.17.0
186
+ - fonttools==4.56.0
187
+ - fsspec==2025.2.0
188
+ - h11==0.14.0
189
+ - hjson==3.1.0
190
+ - httpcore==1.0.7
191
+ - httpx==0.28.1
192
+ - huggingface-hub==0.29.1
193
+ - ijson==3.4.0
194
+ - ipympl==0.9.7
195
+ - ipython==9.0.2
196
+ - ipywidgets==8.1.5
197
+ - jinja2==3.1.5
198
+ - joblib==1.4.2
199
+ - json5==0.10.0
200
+ - jupyter-lsp==2.2.5
201
+ - jupyterlab==4.3.6
202
+ - jupyterlab-server==2.27.3
203
+ - jupyterlab-widgets==3.0.13
204
+ - kiwisolver==1.4.8
205
+ - llvmlite==0.44.0
206
+ - matplotlib==3.10.1
207
+ - mpi4py==4.0.3
208
+ - mpmath==1.3.0
209
+ - msgpack==1.1.0
210
+ - narwhals==1.42.1
211
+ - networkx==3.4.2
212
+ - ninja==1.11.1.3
213
+ - notebook==7.3.3
214
+ - notebook-shim==0.2.4
215
+ - numba==0.61.2
216
+ - numpy==2.2.3
217
+ - nvidia-cublas-cu12==12.4.5.8
218
+ - nvidia-cuda-cupti-cu12==12.4.127
219
+ - nvidia-cuda-nvrtc-cu12==12.4.127
220
+ - nvidia-cuda-runtime-cu12==12.4.127
221
+ - nvidia-cudnn-cu12==9.1.0.70
222
+ - nvidia-cufft-cu12==11.2.1.3
223
+ - nvidia-curand-cu12==10.3.5.147
224
+ - nvidia-cusolver-cu12==11.6.1.9
225
+ - nvidia-cusparse-cu12==12.3.1.170
226
+ - nvidia-cusparselt-cu12==0.6.2
227
+ - nvidia-ml-py==12.570.86
228
+ - nvidia-nccl-cu12==2.21.5
229
+ - nvidia-nvjitlink-cu12==12.4.127
230
+ - nvidia-nvtx-cu12==12.4.127
231
+ - pandas==2.3.0
232
+ - pandocfilters==1.5.1
233
+ - pillow==11.1.0
234
+ - plotly==6.1.2
235
+ - protobuf==5.29.3
236
+ - py-cpuinfo==9.0.0
237
+ - pydantic==2.10.6
238
+ - pydantic-core==2.27.2
239
+ - pynndescent==0.5.13
240
+ - pyparsing==3.2.1
241
+ - python-json-logger==3.3.0
242
+ - pytz==2025.2
243
+ - regex==2024.11.6
244
+ - rpds-py==0.23.1
245
+ - rtree==1.4.1
246
+ - safetensors==0.5.3
247
+ - scikit-learn==1.6.1
248
+ - scipy==1.15.2
249
+ - sentry-sdk==2.22.0
250
+ - setproctitle==1.3.5
251
+ - setuptools==75.8.2
252
+ - shapely==2.1.2
253
+ - soupsieve==2.6
254
+ - sympy==1.13.1
255
+ - threadpoolctl==3.6.0
256
+ - tokenizers==0.20.3
257
+ - torch==2.6.0
258
+ - torchaudio==2.6.0
259
+ - torchvision==0.21.0
260
+ - transformers==4.46.3
261
+ - trimesh==4.8.3
262
+ - triton==3.2.0
263
+ - tzdata==2025.2
264
+ - umap-learn==0.5.7
265
+ - wandb==0.18.3
266
+ - widgetsnbextension==4.0.13
267
+ prefix: /global/homes/b/binxia/.conda/envs/fsdp
example.ipynb ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "07604227",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": []
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": null,
14
+ "id": "9882fd75",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": []
18
+ }
19
+ ],
20
+ "metadata": {
21
+ "kernelspec": {
22
+ "display_name": "fsdp",
23
+ "language": "python",
24
+ "name": "python3"
25
+ },
26
+ "language_info": {
27
+ "codemirror_mode": {
28
+ "name": "ipython",
29
+ "version": 3
30
+ },
31
+ "file_extension": ".py",
32
+ "mimetype": "text/x-python",
33
+ "name": "python",
34
+ "nbconvert_exporter": "python",
35
+ "pygments_lexer": "ipython3",
36
+ "version": "3.12.5"
37
+ }
38
+ },
39
+ "nbformat": 4,
40
+ "nbformat_minor": 5
41
+ }
foundation_bert.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ import yaml
4
+ from pathlib import Path
5
+ from utils.masked_data_modeling_loss import MaskedDataLossWithSoftmax
6
+ # from ..utils.contrastive_loss import ContrastiveLoss
7
+ from utils.yaml_util import MyLoader
8
+ from dataclasses import dataclass
9
+ from transformers import BertModel, BertConfig, PretrainedConfig
10
+ from typing import Optional, Union
11
+
12
+
13
@dataclass
class FoundationOutput:
    """Bundle of tensors produced by a FoundationBert forward pass.

    Every field defaults to ``None`` so a caller may fill in only the
    outputs/losses relevant to the loss terms that are enabled.
    """

    loss: torch.Tensor = None            # combined loss (presumably the training objective — TODO confirm)
    logits: torch.Tensor = None          # token-level logits
    num_output: torch.Tensor = None      # numeric-head predictions
    est_err_output: torch.Tensor = None  # estimated-error predictions
    hidden_states: torch.Tensor = None   # encoder hidden states
    masked_loss: torch.Tensor = None     # masked-modeling loss component
    num_loss: torch.Tensor = None        # numeric-regression loss component
    est_err_loss: torch.Tensor = None    # estimated-error loss component
23
+
24
+
25
@dataclass
class FoundationBertConfig:
    """Configuration dataclass for :class:`FoundationBert`.

    Carries the standard BERT hyper-parameters plus the foundation-model
    extras (loss toggles, loss weights, contrastive temperature).
    """

    vocab_size: int
    hidden_size: int
    num_hidden_layers: int
    num_attention_heads: int
    intermediate_size: int
    hidden_dropout_prob: float
    attention_probs_dropout_prob: float
    pad_token_id: int
    classifier_dropout: float
    max_position_embeddings: int
    contrastive_temperature: float
    loss_weights: dict
    # Loss toggles; which combination is active may also be overridden by
    # the FoundationBert constructor when these attributes are absent.
    use_xval_loss: bool = True
    use_mlm_loss: bool = True
    use_regression_loss: bool = False
    use_contrastive_loss: bool = False
    transform_numeric: bool = False

    def to_dict(self):
        """Return a plain ``dict`` mapping every field name to its value."""
        # __dataclass_fields__ iterates its keys directly; values are not
        # deep-copied (matching the original shallow behavior).
        return {name: getattr(self, name) for name in self.__dataclass_fields__}
47
+
48
class FoundationBert(BertModel):
    """BERT backbone with per-modality embedding and regression heads.

    Each input modality (vector modalities such as ``SED``/``SFH``, or the
    pooled ``scalars`` group) is projected into the hidden space by its own
    ``Linear(1, H)`` embedding, the modalities are concatenated along the
    sequence axis, encoded jointly by the BERT encoder, and the encoded
    slices are decoded by per-modality regression heads.
    """

    def __init__(self,
                 config: FoundationBertConfig = None,
                 use_mlm_loss: bool = False,
                 use_regression_loss: bool = True,
                 use_contrastive_loss: bool = False,
                 use_xval_loss: bool = False,
                 transform_numeric: bool = False,
                 *args,
                 **kwargs):
        """Build the model.

        Parameters
        ----------
        config : FoundationBertConfig
            Foundation config; a subset is copied into the ``BertConfig``
            handed to the HF ``BertModel`` base class.
        use_mlm_loss, use_regression_loss, use_contrastive_loss, use_xval_loss : bool
            Fallback loss toggles, used only when ``config`` does not carry
            usable ``use_*`` attributes (see try/except below).
        transform_numeric : bool
            Stored onto ``config.transform_numeric``.
        **kwargs
            Must contain ``modalities`` and ``mask_token``; may contain
            ``dataset_path``.
        """
        # Keep the full foundation config; bert_conf only carries the
        # subset that transformers' BertModel understands.
        self.gconfig = config
        bert_conf = BertConfig(
            vocab_size=config.vocab_size,
            hidden_size=config.hidden_size,
            num_hidden_layers=config.num_hidden_layers,
            num_attention_heads=config.num_attention_heads,
            intermediate_size=config.intermediate_size,
            hidden_dropout_prob=config.hidden_dropout_prob,
            attention_probs_dropout_prob=config.attention_probs_dropout_prob,
            pad_token_id=config.pad_token_id,
            max_position_embeddings=config.max_position_embeddings,
            _attn_implementation='sdpa'
        )
        self.gconfig.transform_numeric = transform_numeric
        super().__init__(bert_conf,)
        # Prefer the loss flags on the config; fall back to the constructor
        # arguments when the config lacks them (AttributeError) or enables
        # no loss at all (the ValueError raised below).  The original code
        # used a bare ``except:``, which also swallowed KeyboardInterrupt
        # and unrelated bugs — narrowed here.
        try:
            if not self.gconfig.use_mlm_loss and not self.gconfig.use_regression_loss and not self.gconfig.use_contrastive_loss:
                raise ValueError("At least one loss must be enabled")
            self.loss_mod = float(self.gconfig.use_mlm_loss) + float(self.gconfig.use_regression_loss) + float(self.gconfig.use_contrastive_loss) + float(self.gconfig.use_xval_loss)
        except (AttributeError, ValueError):
            self.gconfig.use_mlm_loss = use_mlm_loss
            self.gconfig.use_regression_loss = use_regression_loss
            self.gconfig.use_contrastive_loss = use_contrastive_loss
            self.gconfig.use_xval_loss = use_xval_loss
            self.loss_mod = float(self.gconfig.use_mlm_loss) + float(self.gconfig.use_regression_loss) + float(self.gconfig.use_contrastive_loss) + float(self.gconfig.use_xval_loss)

        self.dataset_path = kwargs.get('dataset_path', None)

        # Required kwargs — raise KeyError early if the caller omits them.
        self.modalities = kwargs['modalities']
        self.mask_token = kwargs['mask_token']

        self.scalar_keys = [
            'redshift',
            'halo_mass',
            'stellar_mass',
        ]
        self.vector_keys = [
            'SED',
            'SFH',
            'mag_{band}_spherex',
            'mag_{band}_lsst',
        ]
        # Collapse every non-vector modality into one shared 'scalars'
        # entry, keeping first-seen order (dict.fromkeys dedupes stably).
        self.modalscalars = [m if m in self.vector_keys else 'scalars' for m in self.modalities]
        self.modalscalars = list(dict.fromkeys(self.modalscalars))

        self.embedding = torch.nn.ModuleDict()  # modality specific embedding layers
        self.num_head = torch.nn.ModuleDict()   # modality specific regression heads
        for modality in self.modalscalars:
            # input.shape -> output.shape: (B, L, 1) -> (B, L, H)
            self.embedding[modality] = torch.nn.Linear(1, config.hidden_size)
            self.num_head[modality] = torch.nn.Sequential(
                torch.nn.Linear(config.hidden_size, config.hidden_size),
                torch.nn.LayerNorm(config.hidden_size),
                torch.nn.GELU(),
                torch.nn.Linear(config.hidden_size, config.hidden_size // 2),
                torch.nn.GELU(),
                torch.nn.Linear(config.hidden_size // 2, 1)
            )

        self.position_embeddings = torch.nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.embed_dropout = torch.nn.Dropout(config.hidden_dropout_prob)

        self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)  # isn't used currently
        self.xval_loss = torch.nn.MSELoss(reduction='none')  # isn't used currently
        self.mlm_loss = MaskedDataLossWithSoftmax(ignore=-100, reduction='none')  # isn't used currently
        self.distributed_loss = False

    @classmethod
    def from_pretrained(cls,
                        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
                        *model_args,
                        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
                        cache_dir: Optional[Union[str, os.PathLike]] = None,
                        ignore_mismatched_sizes: bool = False,
                        force_download: bool = False,
                        local_files_only: bool = False,
                        token: Optional[Union[str, bool]] = None,
                        revision: str = "main",
                        use_safetensors: bool = None,
                        **kwargs,
                        ):
        """Load a checkpoint plus its YAML training config.

        Modification to correctly handle loading extraneous parameters for
        GBert: reads ``train_config.yaml`` inside the checkpoint directory
        to recover the constructor kwargs (``modalities``, ``dataset_path``,
        ``mask_token``) that stock ``BertModel.from_pretrained`` does not
        know about, then delegates to the parent implementation.

        Note: the ``config`` parameter is accepted for signature
        compatibility but is overwritten by the YAML contents.
        """
        model_config = Path(pretrained_model_name_or_path) / 'train_config.yaml'
        with open(model_config, 'r') as f:
            config = yaml.load(f, Loader=MyLoader)
        kwargs['modalities'] = config['modalities']
        kwargs['dataset_path'] = config['dataset_path']
        kwargs['mask_token'] = config['mask_token']
        return super().from_pretrained(
            pretrained_model_name_or_path,
            **config['model_config'],
            **kwargs
        )

    def pool_output(self,
                    embeddings: torch.Tensor,
                    attention_mask: torch.Tensor,
                    use_last: bool = False
                    ) -> torch.Tensor:
        """Average pool the hidden states using the attention mask.

        Parameters
        ----------
        embeddings : torch.Tensor
            The hidden states to pool (B, SeqLen, HiddenDim).
        attention_mask : torch.Tensor
            The attention mask for the hidden states (B, SeqLen).
        use_last : bool
            If True only the start token is masked out; otherwise both the
            start token and the final valid token are excluded.

        Returns
        -------
        torch.Tensor
            The pooled embeddings (B, HiddenDim).
        """
        sl_mod = 1 if use_last else 2
        # Per-sequence valid-token counts.
        seq_lengths = attention_mask.sum(axis=1)
        # Mask out the start token and (unless use_last) the end token.
        new_attention = attention_mask.clone()
        new_attention[:, 0] = 0
        # BUGFIX: the original wrote ``new_attention[:, seq_lengths - sl_mod] = 0``,
        # which indexes dim 1 with a length-B tensor and therefore zeroes
        # those columns for EVERY row in the batch.  Index per-row instead.
        batch_idx = torch.arange(attention_mask.shape[0], device=attention_mask.device)
        new_attention[batch_idx, seq_lengths - sl_mod] = 0

        # Broadcast the mask over the hidden dimension: (B, SeqLen, HiddenDim).
        pool_mask = new_attention.unsqueeze(-1).expand(embeddings.shape).to(embeddings.device)
        # Sum over the sequence length, skipping pad/start/stop tokens.
        sum_embeds = torch.sum(embeddings * pool_mask, 1)
        # Clamp to avoid division by zero for zero-length sequences.
        # NOTE(review): the denominator is the full valid length even though
        # the start/end tokens were excluded from the sum — confirm intended.
        seq_lengths = torch.clamp(seq_lengths, min=1).unsqueeze(-1)  # (B, 1) to broadcast
        return sum_embeds / seq_lengths

    def last_token_pool(
        self,
        embeddings: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Pool the last hidden states using the attention mask.

        Parameters
        ----------
        embeddings : torch.Tensor
            The last hidden states to pool (B, SeqLen, HiddenDim).
        attention_mask : torch.Tensor
            The attention mask for the hidden states (B, SeqLen).

        Returns
        -------
        torch.Tensor
            The pooled embeddings (B, HiddenDim).
        """
        # If every sequence attends at the final position, padding must be
        # on the left, so the last column is the last real token for all rows.
        left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
        if left_padding:
            return embeddings[:, -1]
        # Right padding: pick each row's final valid position individually.
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = embeddings.shape[0]
        return embeddings[
            torch.arange(batch_size, device=embeddings.device),
            sequence_lengths,
        ]

    def forward(self, inputs, return_input_label_mapping=False):
        """
        Forward pass that computes predictions for each modality.

        Args:
            inputs (dict): Holds per-modality tensors keyed
                ``input_{modality}`` and ``labels_{modality}`` for every
                entry of ``self.modalscalars``.
            return_input_label_mapping (bool): When True also return the
                per-modality input/label dict built here.

        Returns:
            outputs (dict): ``{modality}_logits`` regression-head outputs
            per modality (plus the mapping when requested).
        """
        input_label_mapping = {}
        combined = []
        for src_modality in self.modalscalars:
            # Record the modality's raw input and its mask labels.
            input_label_mapping[src_modality] = {
                'input': inputs[f"input_{src_modality}"],   # Input data
                'labels': inputs[f"labels_{src_modality}"]  # Corresponding labels
            }

            input_data = input_label_mapping[src_modality]['input']
            label = input_label_mapping[src_modality]['labels']
            # Replace masked positions (labels truthy) with the mask token.
            input_data = torch.where(label, self.mask_token, input_data)

            x = self.embedding[src_modality](input_data.unsqueeze(-1))  # (B, L, H)
            x = torch.nn.functional.silu(x)
            combined.append(x)

        # Concatenate all modalities along the sequence-length dimension.
        combined = torch.cat(combined, dim=1)

        # Learned absolute position embeddings over the combined sequence.
        # (Stored on self to preserve the original attribute side effect.)
        self.position_ids = torch.arange(combined.size(1)).unsqueeze(0).to(combined.device)  # (1, L)
        combined += self.position_embeddings(self.position_ids)
        combined = self.embed_dropout(combined)

        # Jointly encode the combined multi-modality sequence.
        x = self.encoder(combined, output_hidden_states=True).last_hidden_state

        start = 0
        outputs = {}
        # Slice the encoded sequence back into modality segments and decode.
        for tgt_modality in self.modalscalars:
            length = input_label_mapping[tgt_modality]['input'].shape[1]
            x_t = x[:, start:start + length, :]
            outputs[f"{tgt_modality}_logits"] = self.num_head[tgt_modality](x_t)
            start += length

        if getattr(self, 'save_umap_for', None):
            # NOTE(review): x_t is the slice from the LAST loop iteration, so
            # only the final modality is pooled for UMAP — confirm intended.
            pooled = x_t.mean(dim=1)  # mean pooling over sequence length
            self.save_pooled_embedding(pooled)  # saved for UMAP visualization

        return (outputs, input_label_mapping) if return_input_label_mapping else outputs

    def save_pooled_embedding(self, features):
        """
        Append pooled embeddings to the HDF5 file at ``self.save_umap_for``.

        Creates the file (with a resizable 'features' dataset) on first
        call; subsequent calls grow the dataset and append rows.
        """
        import h5py
        fname = Path(self.save_umap_for)
        fname.parent.mkdir(parents=True, exist_ok=True)

        features = features.detach().cpu().numpy()

        if fname.exists():
            with h5py.File(fname, 'r+') as f:
                old_size = f['features'].shape[0]            # current row count
                new_size = old_size + features.shape[0]      # after append
                f['features'].resize((new_size, features.shape[-1]))
                f['features'][old_size:] = features
        else:
            with h5py.File(fname, 'w') as f:
                # maxshape=(None, ...) keeps the row axis growable.
                f.create_dataset('features', data=features, maxshape=(None, features.shape[-1]), chunks=True)
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d2acc57f8c67e0f2b358632241243752031c23a7ed7030ba95c33b7f81e06c62
3
+ size 550172096
requirements.txt ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==1.1.1
2
+ alphashape==1.3.1
3
+ annotated-types==0.7.0
4
+ anyio @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_anyio_1742243108/work
5
+ appdirs==1.4.4
6
+ archspec @ file:///home/conda/feedstock_root/build_artifacts/archspec_1708969572489/work
7
+ argon2-cffi @ file:///home/conda/feedstock_root/build_artifacts/argon2-cffi_1733311059102/work
8
+ argon2-cffi-bindings @ file:///home/conda/feedstock_root/build_artifacts/argon2-cffi-bindings_1725356585055/work
9
+ arrow @ file:///home/conda/feedstock_root/build_artifacts/arrow_1733584251875/work
10
+ asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1733250440834/work
11
+ async-lru==2.0.5
12
+ attrs @ file:///home/conda/feedstock_root/build_artifacts/attrs_1741918516150/work
13
+ babel==2.17.0
14
+ beautifulsoup4 @ file:///home/conda/feedstock_root/build_artifacts/beautifulsoup4_1738740337718/work
15
+ bleach @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_bleach_1737382993/work
16
+ boltons @ file:///home/conda/feedstock_root/build_artifacts/boltons_1711936407380/work
17
+ brokenaxes==0.6.2
18
+ Brotli @ file:///home/conda/feedstock_root/build_artifacts/brotli-split_1725267488082/work
19
+ cached-property @ file:///home/conda/feedstock_root/build_artifacts/cached_property_1615209429212/work
20
+ certifi @ file:///home/conda/feedstock_root/build_artifacts/certifi_1746569525376/work/certifi
21
+ cffi @ file:///home/conda/feedstock_root/build_artifacts/cffi_1724956320552/work
22
+ charset-normalizer @ file:///home/conda/feedstock_root/build_artifacts/charset-normalizer_1698833585322/work
23
+ # Editable install with no version control (chatarena==0.1.8)
24
+ -e /pscratch/sd/b/binxia/Werewolf
25
+ click==8.1.8
26
+ click-log==0.4.0
27
+ colorama @ file:///home/conda/feedstock_root/build_artifacts/colorama_1666700638685/work
28
+ comm @ file:///home/conda/feedstock_root/build_artifacts/comm_1733502965406/work
29
+ conda-package-handling @ file:///home/conda/feedstock_root/build_artifacts/conda-package-handling_1717678605937/work
30
+ conda_package_streaming @ file:///home/conda/feedstock_root/build_artifacts/conda-package-streaming_1717678526951/work
31
+ configparser==7.2.0
32
+ contourpy==1.3.1
33
+ cycler==0.12.1
34
+ debugpy @ file:///home/conda/feedstock_root/build_artifacts/debugpy_1741148399929/work
35
+ decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1740384970518/work
36
+ defusedxml @ file:///home/conda/feedstock_root/build_artifacts/defusedxml_1615232257335/work
37
+ descartes==1.1.0
38
+ distro @ file:///home/conda/feedstock_root/build_artifacts/distro_1704321475663/work
39
+ docker-pycreds==0.4.0
40
+ einops==0.8.0
41
+ exceptiongroup @ file:///home/conda/feedstock_root/build_artifacts/exceptiongroup_1733208806608/work
42
+ executing==2.2.0
43
+ fastjsonschema @ file:///home/conda/feedstock_root/build_artifacts/python-fastjsonschema_1733235979760/work/dist
44
+ filelock==3.17.0
45
+ fonttools==4.56.0
46
+ fqdn @ file:///home/conda/feedstock_root/build_artifacts/fqdn_1733327382592/work/dist
47
+ frozendict @ file:///home/conda/feedstock_root/build_artifacts/frozendict_1715092752354/work
48
+ fsspec==2025.2.0
49
+ gitdb @ file:///home/conda/feedstock_root/build_artifacts/gitdb_1735887193964/work
50
+ GitPython @ file:///home/conda/feedstock_root/build_artifacts/gitpython_1735929639977/work
51
+ h11==0.14.0
52
+ h2 @ file:///home/conda/feedstock_root/build_artifacts/h2_1634280454336/work
53
+ hjson==3.1.0
54
+ hpack==4.0.0
55
+ httpcore==1.0.7
56
+ httpx==0.28.1
57
+ huggingface-hub==0.29.1
58
+ hyperframe @ file:///home/conda/feedstock_root/build_artifacts/hyperframe_1619110129307/work
59
+ idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1724450538981/work
60
+ ijson==3.4.0
61
+ importlib_metadata @ file:///home/conda/feedstock_root/build_artifacts/importlib-metadata_1737420181517/work
62
+ importlib_resources @ file:///home/conda/feedstock_root/build_artifacts/importlib_resources_1736252299705/work
63
+ ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1719845459717/work
64
+ ipympl==0.9.7
65
+ ipython==9.0.2
66
+ ipython_pygments_lexers @ file:///home/conda/feedstock_root/build_artifacts/ipython_pygments_lexers_1737123620466/work
67
+ ipywidgets==8.1.5
68
+ isoduration @ file:///home/conda/feedstock_root/build_artifacts/isoduration_1733493628631/work/dist
69
+ jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1733300866624/work
70
+ Jinja2==3.1.5
71
+ joblib==1.4.2
72
+ json5==0.10.0
73
+ jsonpatch @ file:///home/conda/feedstock_root/build_artifacts/jsonpatch_1695536281965/work
74
+ jsonpointer @ file:///home/conda/feedstock_root/build_artifacts/jsonpointer_1725302935093/work
75
+ jsonschema @ file:///home/conda/feedstock_root/build_artifacts/jsonschema_1733472696581/work
76
+ jsonschema-specifications @ file:///tmp/tmpk0f344m9/src
77
+ jupyter-events @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_jupyter_events_1738765986/work
78
+ jupyter-lsp==2.2.5
79
+ jupyter_client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1733440914442/work
80
+ jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1727163409502/work
81
+ jupyter_server @ file:///home/conda/feedstock_root/build_artifacts/jupyter_server_1734702637701/work
82
+ jupyter_server_mathjax @ file:///home/conda/feedstock_root/build_artifacts/jupyter-server-mathjax_1734509714511/work
83
+ jupyter_server_terminals @ file:///home/conda/feedstock_root/build_artifacts/jupyter_server_terminals_1733427956852/work
84
+ jupyterlab==4.3.6
85
+ jupyterlab_pygments @ file:///home/conda/feedstock_root/build_artifacts/jupyterlab_pygments_1733328101776/work
86
+ jupyterlab_server==2.27.3
87
+ jupyterlab_widgets==3.0.13
88
+ kiwisolver==1.4.8
89
+ kymatio==0.3.0
90
+ libmambapy @ file:///home/conda/feedstock_root/build_artifacts/mamba-split_1727883551957/work/libmambapy
91
+ llvmlite==0.44.0
92
+ MarkupSafe @ file:///home/conda/feedstock_root/build_artifacts/markupsafe_1733219680183/work
93
+ matplotlib==3.10.1
94
+ matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1733416936468/work
95
+ menuinst @ file:///home/conda/feedstock_root/build_artifacts/menuinst_1725359038078/work
96
+ mistune @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_mistune_1742402716/work
97
+ mpi4py==4.0.3
98
+ mpmath==1.3.0
99
+ msgpack==1.1.0
100
+ narwhals==1.42.1
101
+ nbclient @ file:///home/conda/feedstock_root/build_artifacts/nbclient_1734628800805/work
102
+ nbconvert @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_nbconvert-core_1738067871/work
103
+ nbdime @ file:///home/conda/feedstock_root/build_artifacts/nbdime_1734533951497/work
104
+ nbformat @ file:///home/conda/feedstock_root/build_artifacts/nbformat_1733402752141/work
105
+ nersc-pymon==0.3.0
106
+ nest_asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1733325553580/work
107
+ networkx==3.4.2
108
+ ninja==1.11.1.3
109
+ notebook==7.3.3
110
+ notebook_shim==0.2.4
111
+ numba==0.61.2
112
+ numpy==2.2.3
113
+ nvidia-cublas-cu12==12.4.5.8
114
+ nvidia-cuda-cupti-cu12==12.4.127
115
+ nvidia-cuda-nvrtc-cu12==12.4.127
116
+ nvidia-cuda-runtime-cu12==12.4.127
117
+ nvidia-cudnn-cu12==9.1.0.70
118
+ nvidia-cufft-cu12==11.2.1.3
119
+ nvidia-curand-cu12==10.3.5.147
120
+ nvidia-cusolver-cu12==11.6.1.9
121
+ nvidia-cusparse-cu12==12.3.1.170
122
+ nvidia-cusparselt-cu12==0.6.2
123
+ nvidia-ml-py==12.570.86
124
+ nvidia-nccl-cu12==2.21.5
125
+ nvidia-nvjitlink-cu12==12.4.127
126
+ nvidia-nvtx-cu12==12.4.127
127
+ overrides @ file:///home/conda/feedstock_root/build_artifacts/overrides_1734587627321/work
128
+ packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1718189413536/work
129
+ pandas==2.3.0
130
+ pandocfilters==1.5.1
131
+ parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1733271261340/work
132
+ pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1733301927746/work
133
+ pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1733327343728/work
134
+ pillow==11.1.0
135
+ pkgutil_resolve_name @ file:///home/conda/feedstock_root/build_artifacts/pkgutil-resolve-name_1733344503739/work
136
+ platformdirs @ file:///home/conda/feedstock_root/build_artifacts/platformdirs_1715777629804/work
137
+ plotly==6.1.2
138
+ pluggy @ file:///home/conda/feedstock_root/build_artifacts/pluggy_1713667077545/work
139
+ prometheus_client @ file:///home/conda/feedstock_root/build_artifacts/prometheus_client_1733327310477/work
140
+ prompt_toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1737453357274/work
141
+ protobuf==5.29.3
142
+ psutil @ file:///home/conda/feedstock_root/build_artifacts/psutil_1740663123172/work
143
+ ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1733302279685/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl#sha256=92c32ff62b5fd8cf325bec5ab90d7be3d2a8ca8c8a3813ff487a8d2002630d1f
144
+ pure_eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1733569405015/work
145
+ py-cpuinfo==9.0.0
146
+ pycosat @ file:///home/conda/feedstock_root/build_artifacts/pycosat_1696355774225/work
147
+ pycparser @ file:///home/conda/feedstock_root/build_artifacts/pycparser_1711811537435/work
148
+ pydantic==2.10.6
149
+ pydantic_core==2.27.2
150
+ Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1736243443484/work
151
+ pynndescent==0.5.13
152
+ pyparsing==3.2.1
153
+ PySocks @ file:///home/conda/feedstock_root/build_artifacts/pysocks_1661604839144/work
154
+ python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1733215673016/work
155
+ python-json-logger @ file:///home/conda/feedstock_root/build_artifacts/python-json-logger_1677079630776/work
156
+ pytz==2025.2
157
+ PyYAML @ file:///home/conda/feedstock_root/build_artifacts/pyyaml_1737454647378/work
158
+ pyzmq @ file:///home/conda/feedstock_root/build_artifacts/pyzmq_1738270962252/work
159
+ referencing @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_referencing_1737836872/work
160
+ regex==2024.11.6
161
+ requests @ file:///home/conda/feedstock_root/build_artifacts/requests_1717057054362/work
162
+ rfc3339_validator @ file:///home/conda/feedstock_root/build_artifacts/rfc3339-validator_1733599910982/work
163
+ rfc3986-validator @ file:///home/conda/feedstock_root/build_artifacts/rfc3986-validator_1598024191506/work
164
+ rpds-py==0.23.1
165
+ rtree==1.4.1
166
+ ruamel.yaml @ file:///home/conda/feedstock_root/build_artifacts/ruamel.yaml_1707298132558/work
167
+ ruamel.yaml.clib @ file:///home/conda/feedstock_root/build_artifacts/ruamel.yaml.clib_1707314473810/work
168
+ safetensors==0.5.3
169
+ scikit-learn==1.6.1
170
+ scipy==1.15.2
171
+ Send2Trash @ file:///home/conda/feedstock_root/build_artifacts/send2trash_1733322040660/work
172
+ sentry-sdk==2.22.0
173
+ setproctitle==1.3.5
174
+ setuptools==73.0.1
175
+ shapely==2.1.2
176
+ six @ file:///home/conda/feedstock_root/build_artifacts/six_1733380938961/work
177
+ smmap @ file:///home/conda/feedstock_root/build_artifacts/smmap_1739781697784/work
178
+ sniffio @ file:///home/conda/feedstock_root/build_artifacts/sniffio_1733244044561/work
179
+ soupsieve==2.6
180
+ stack_data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1733569443808/work
181
+ sympy==1.13.1
182
+ terminado @ file:///home/conda/feedstock_root/build_artifacts/terminado_1710262609923/work
183
+ threadpoolctl==3.6.0
184
+ tinycss2 @ file:///home/conda/feedstock_root/build_artifacts/tinycss2_1729802851396/work
185
+ tokenizers==0.20.3
186
+ torch==2.6.0
187
+ torchaudio==2.6.0
188
+ torchvision==0.21.0
189
+ tornado @ file:///home/conda/feedstock_root/build_artifacts/tornado_1732615905931/work
190
+ tqdm @ file:///home/conda/feedstock_root/build_artifacts/tqdm_1722737464726/work
191
+ traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1733367359838/work
192
+ transformers==4.46.3
193
+ trimesh==4.8.3
194
+ triton==3.2.0
195
+ truststore @ file:///home/conda/feedstock_root/build_artifacts/truststore_1724770958874/work
196
+ types-python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/types-python-dateutil_1733612335562/work
197
+ typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1733188668063/work
198
+ typing_utils @ file:///home/conda/feedstock_root/build_artifacts/typing_utils_1733331286120/work
199
+ tzdata==2025.2
200
+ umap-learn==0.5.7
201
+ uri-template @ file:///home/conda/feedstock_root/build_artifacts/uri-template_1733323593477/work/dist
202
+ urllib3 @ file:///home/conda/feedstock_root/build_artifacts/urllib3_1719391292974/work
203
+ wandb==0.18.3
204
+ wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1733231326287/work
205
+ webcolors @ file:///home/conda/feedstock_root/build_artifacts/webcolors_1733359735138/work
206
+ webencodings @ file:///home/conda/feedstock_root/build_artifacts/webencodings_1733236011802/work
207
+ websocket-client @ file:///home/conda/feedstock_root/build_artifacts/websocket-client_1733157342724/work
208
+ wheel==0.44.0
209
+ widgetsnbextension==4.0.13
210
+ zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1732827521216/work
211
+ zstandard==0.23.0
setup.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from setuptools import find_packages, setup

# Runtime dependencies are tracked one-per-line in requirements.txt.
with open('requirements.txt', 'r') as f:
    requires = f.read().splitlines()

setup(
    name='object_foundations',
    # Version must be a string; the original passed the float 0.0.
    version='0.0',
    # NOTE(review): the original used `install_requirements=`, which is not a
    # setuptools keyword and is silently ignored — the dependencies were never
    # actually declared. The correct keyword is `install_requires`.
    install_requires=requires,
    packages=find_packages(exclude=["tests", "*.tests", "*.tests.*", "tests.*"]),
)
train_config.yaml ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset_path: /global/cfs/cdirs/m4717/azton/galaxy-foundations/object_foundation/utils/supermock_dataset_11.2-14.json
2
+ mask_token: 0
3
+ masked_generation: false
4
+ masking_prob:
5
+ - 0.2
6
+ - 0.2
7
+ - 0.2
8
+ - 0.2
9
+ - 0.5
10
+ - 0.5
11
+ - 0.5
12
+ modalities:
13
+ - SFH
14
+ - SED
15
+ - mag_{band}_lsst
16
+ - mag_{band}_spherex
17
+ - redshift
18
+ - halo_mass
19
+ - stellar_mass
20
+ model_config:
21
+ attention_probs_dropout_prob: 0.1
22
+ classifier_dropout: 0.0
23
+ contrastive_temperature: 0.05
24
+ hidden_dropout_prob: 0.1
25
+ hidden_size: 768
26
+ intermediate_size: 3072
27
+ loss_weights:
28
+ contrastive:
29
+ rounds: 0
30
+ w0T:
31
+ - 0
32
+ - 0
33
+ masked:
34
+ rounds: 0
35
+ w0T:
36
+ - 0.8
37
+ - 3
38
+ smooth:
39
+ rounds: 0
40
+ w0T:
41
+ - 0
42
+ - 0.3
43
+ unmasked:
44
+ rounds: 0
45
+ w0T:
46
+ - 0.2
47
+ - 0.3
48
+ max_position_embeddings: 1149
49
+ num_attention_heads: 12
50
+ num_hidden_layers: 18
51
+ pad_token_id: -1
52
+ transform_numeric: false
53
+ use_contrastive_loss: false
54
+ use_mlm_loss: true
55
+ use_regression_loss: false
56
+ use_xval_loss: false
57
+ vocab_size: 2048
58
+ model_name_or_path: galaxybert
59
+ tokenizer_name_or_path: Salesforce/SFR-Embedding-Mistral
60
+ training_args:
61
+ _n_gpu: 1
62
+ accelerator_config:
63
+ dispatch_batches: null
64
+ even_batches: true
65
+ gradient_accumulation_kwargs: null
66
+ non_blocking: false
67
+ split_batches: false
68
+ use_configured_state: false
69
+ use_seedable_sampler: true
70
+ adafactor: false
71
+ adam_beta1: 0.9
72
+ adam_beta2: 0.999
73
+ adam_epsilon: 1.0e-08
74
+ auto_find_batch_size: false
75
+ average_tokens_across_devices: false
76
+ batch_eval_metrics: false
77
+ bf16: true
78
+ bf16_full_eval: false
79
+ data_seed: null
80
+ dataloader_drop_last: false
81
+ dataloader_num_workers: 16
82
+ dataloader_persistent_workers: false
83
+ dataloader_pin_memory: true
84
+ dataloader_prefetch_factor: 8
85
+ ddp_backend: null
86
+ ddp_broadcast_buffers: null
87
+ ddp_bucket_cap_mb: null
88
+ ddp_find_unused_parameters: null
89
+ ddp_timeout: 1800
90
+ debug: []
91
+ deepspeed: null
92
+ disable_tqdm: false
93
+ dispatch_batches: null
94
+ do_eval: true
95
+ do_predict: false
96
+ do_train: false
97
+ eval_accumulation_steps: 5
98
+ eval_delay: 0
99
+ eval_do_concat_batches: true
100
+ eval_on_start: false
101
+ eval_steps: 20
102
+ eval_strategy: !!python/object/apply:transformers.trainer_utils.IntervalStrategy
103
+ - steps
104
+ eval_use_gather_object: false
105
+ evaluation_strategy: null
106
+ fp16: false
107
+ fp16_backend: auto
108
+ fp16_full_eval: false
109
+ fp16_opt_level: O1
110
+ fsdp: []
111
+ fsdp_config:
112
+ min_num_params: 0
113
+ xla: false
114
+ xla_fsdp_grad_ckpt: false
115
+ xla_fsdp_v2: false
116
+ fsdp_min_num_params: 0
117
+ fsdp_transformer_layer_cls_to_wrap: null
118
+ full_determinism: false
119
+ gradient_accumulation_steps: 5
120
+ gradient_checkpointing: false
121
+ gradient_checkpointing_kwargs: null
122
+ greater_is_better: null
123
+ group_by_length: false
124
+ half_precision_backend: auto
125
+ hub_always_push: false
126
+ hub_model_id: null
127
+ hub_private_repo: false
128
+ hub_strategy: !!python/object/apply:transformers.trainer_utils.HubStrategy
129
+ - every_save
130
+ hub_token: null
131
+ ignore_data_skip: false
132
+ include_for_metrics: []
133
+ include_inputs_for_metrics: false
134
+ include_num_input_tokens_seen: false
135
+ include_tokens_per_second: false
136
+ jit_mode_eval: false
137
+ label_names: null
138
+ label_smoothing_factor: 0.0
139
+ learning_rate: 0.0001
140
+ length_column_name: length
141
+ load_best_model_at_end: false
142
+ local_rank: 0
143
+ log_level: passive
144
+ log_level_replica: warning
145
+ log_on_each_node: true
146
+ logging_dir: sm_foundation_lg_gmm_nomasklab
147
+ logging_first_step: true
148
+ logging_nan_inf_filter: true
149
+ logging_steps: 1
150
+ logging_strategy: !!python/object/apply:transformers.trainer_utils.IntervalStrategy
151
+ - steps
152
+ lr_scheduler_kwargs: {}
153
+ lr_scheduler_type: !!python/object/apply:transformers.trainer_utils.SchedulerType
154
+ - cosine
155
+ max_grad_norm: 1.0
156
+ max_steps: -1
157
+ metric_for_best_model: null
158
+ mp_parameters: ''
159
+ neftune_noise_alpha: null
160
+ no_cuda: false
161
+ num_train_epochs: 60
162
+ optim: !!python/object/apply:transformers.training_args.OptimizerNames
163
+ - adamw_torch
164
+ optim_args: null
165
+ optim_target_modules: null
166
+ output_dir: supermock_te60_
167
+ overwrite_output_dir: true
168
+ past_index: -1
169
+ per_device_eval_batch_size: 100
170
+ per_device_train_batch_size: 100
171
+ per_gpu_eval_batch_size: null
172
+ per_gpu_train_batch_size: null
173
+ prediction_loss_only: false
174
+ push_to_hub: false
175
+ push_to_hub_model_id: null
176
+ push_to_hub_organization: null
177
+ push_to_hub_token: null
178
+ ray_scope: last
179
+ remove_unused_columns: false
180
+ report_to:
181
+ - wandb
182
+ restore_callback_states_from_checkpoint: false
183
+ resume_from_checkpoint: null
184
+ run_name: NO_SHARD_b50
185
+ save_on_each_node: false
186
+ save_only_model: false
187
+ save_safetensors: true
188
+ save_steps: 30
189
+ save_strategy: !!python/object/apply:transformers.trainer_utils.IntervalStrategy
190
+ - steps
191
+ save_total_limit: 360
192
+ seed: 42
193
+ skip_memory_metrics: true
194
+ split_batches: null
195
+ tf32: null
196
+ torch_compile: false
197
+ torch_compile_backend: null
198
+ torch_compile_mode: null
199
+ torch_empty_cache_steps: null
200
+ torchdynamo: null
201
+ tpu_metrics_debug: false
202
+ tpu_num_cores: null
203
+ use_cpu: false
204
+ use_ipex: false
205
+ use_legacy_prediction_loop: false
206
+ use_liger_kernel: false
207
+ use_mps_device: false
208
+ warmup_ratio: 0.0
209
+ warmup_steps: 0
210
+ weight_decay: 0.1
211
+ transform_numeric: false
212
+ wandb_project: supermock-foundation-perl
213
+ wandb_run_name: ''
utils/__init__.py ADDED
File without changes
utils/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (175 Bytes). View file
 
utils/__pycache__/masked_data_modeling_loss.cpython-312.pyc ADDED
Binary file (1.62 kB). View file
 
utils/__pycache__/yaml_util.cpython-312.pyc ADDED
Binary file (1.83 kB). View file
 
utils/masked_data_modeling_loss.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from einops import rearrange
3
+ '''
4
+ Simple class to do all MLM sort of loss operations in one place
5
+ '''
6
class MaskedDataLossWithSoftmax(torch.nn.Module):
    """Cross-entropy loss for masked-data (MLM-style) modeling.

    Wraps ``torch.nn.CrossEntropyLoss`` so callers can pass logits in the
    natural ``[batch, seq, vocab]`` layout; positions whose label equals
    ``ignore`` contribute nothing to the loss.
    """

    def __init__(self, ignore: int = -100, reduction: str = 'mean', weight=None):
        """
        ignore:    label value marking positions excluded from the loss.
                   (The original hard-coded -100 and silently discarded
                   this parameter — it is now passed through.)
        reduction: 'mean' | 'sum' | 'none', forwarded to CrossEntropyLoss.
        weight:    optional per-class rescaling tensor of shape [vocab_size].
        """
        super().__init__()
        self.loss = torch.nn.CrossEntropyLoss(ignore_index=ignore,
                                              reduction=reduction,
                                              weight=weight)

    # Define `forward` (not `__call__`) so nn.Module's call machinery and
    # hooks work normally; `instance(logits, labels)` still dispatches here.
    def forward(self, logits: torch.Tensor,
                labels: torch.Tensor) -> torch.Tensor:
        """
        logits: [batch_size, seq_len, vocab_size], pre-softmax.
        labels: [batch_size, seq_len]; entries equal to `ignore` are skipped.
        Returns the (reduced) cross-entropy loss tensor.
        """
        # CrossEntropyLoss expects the class dimension second: [b, v, s].
        # A plain transpose replaces the einops rearrange('b s v -> b v s').
        return self.loss(logits.transpose(1, 2), labels)
utils/yaml_util.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import yaml
2
+ class MyLoader(yaml.SafeLoader):
3
+ # returns
4
+ def construct_mapping(self, *args, **kwargs):
5
+ super().add_constructor(None, construct_undefined)
6
+ # when loading we want to skip keys that require construction,
7
+ mapping = super().construct_mapping(*args, **kwargs)
8
+
9
+ return mapping
10
+ import typing
11
+ class Tagged(typing.NamedTuple):
12
+ tag: str
13
+ value: object
14
+
15
+ def construct_undefined(self, node):
16
+ if isinstance(node, yaml.nodes.ScalarNode):
17
+ value = self.construct_scalar(node)
18
+ elif isinstance(node, yaml.nodes.SequenceNode):
19
+ value = self.construct_sequence(node)
20
+ elif isinstance(node, yaml.nodes.MappingNode):
21
+ value = self.construct_mapping(node)
22
+ else:
23
+ assert False, f"unexpected node: {node!r}"
24
+ return Tagged(node.tag, value)