Initial upload of MOSAIC FoundationBert model, v1.0. Final successful local test.
Browse files- __pycache__/foundation_bert.cpython-312.pyc +0 -0
- config.json +29 -0
- environment.yml +267 -0
- example.ipynb +41 -0
- foundation_bert.py +300 -0
- model.safetensors +3 -0
- requirements.txt +211 -0
- setup.py +14 -0
- train_config.yaml +213 -0
- utils/__init__.py +0 -0
- utils/__pycache__/__init__.cpython-312.pyc +0 -0
- utils/__pycache__/masked_data_modeling_loss.cpython-312.pyc +0 -0
- utils/__pycache__/yaml_util.cpython-312.pyc +0 -0
- utils/masked_data_modeling_loss.py +24 -0
- utils/yaml_util.py +24 -0
__pycache__/foundation_bert.cpython-312.pyc
ADDED
|
Binary file (15.7 kB). View file
|
|
|
config.json
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_auto_class": "FoundationBert",
|
| 3 |
+
"auto_map": {
|
| 4 |
+
"AutoModel": "foundation_bert.FoundationBert"
|
| 5 |
+
},
|
| 6 |
+
|
| 7 |
+
"architectures": [
|
| 8 |
+
"FoundationBert"
|
| 9 |
+
],
|
| 10 |
+
"attention_probs_dropout_prob": 0.1,
|
| 11 |
+
"classifier_dropout": null,
|
| 12 |
+
"hidden_act": "gelu",
|
| 13 |
+
"hidden_dropout_prob": 0.1,
|
| 14 |
+
"hidden_size": 768,
|
| 15 |
+
"initializer_range": 0.02,
|
| 16 |
+
"intermediate_size": 3072,
|
| 17 |
+
"layer_norm_eps": 1e-12,
|
| 18 |
+
"max_position_embeddings": 1149,
|
| 19 |
+
"model_type": "bert",
|
| 20 |
+
"num_attention_heads": 12,
|
| 21 |
+
"num_hidden_layers": 18,
|
| 22 |
+
"pad_token_id": -1,
|
| 23 |
+
"position_embedding_type": "absolute",
|
| 24 |
+
"torch_dtype": "float32",
|
| 25 |
+
"transformers_version": "4.46.3",
|
| 26 |
+
"type_vocab_size": 2,
|
| 27 |
+
"use_cache": true,
|
| 28 |
+
"vocab_size": 2048
|
| 29 |
+
}
|
environment.yml
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: fsdp
|
| 2 |
+
channels:
|
| 3 |
+
- conda-forge
|
| 4 |
+
dependencies:
|
| 5 |
+
- _libgcc_mutex=0.1=conda_forge
|
| 6 |
+
- _openmp_mutex=4.5=2_gnu
|
| 7 |
+
- anyio=4.9.0=pyh29332c3_0
|
| 8 |
+
- archspec=0.2.3=pyhd8ed1ab_0
|
| 9 |
+
- argon2-cffi=23.1.0=pyhd8ed1ab_1
|
| 10 |
+
- argon2-cffi-bindings=21.2.0=py312h66e93f0_5
|
| 11 |
+
- arrow=1.3.0=pyhd8ed1ab_1
|
| 12 |
+
- asttokens=3.0.0=pyhd8ed1ab_1
|
| 13 |
+
- attrs=25.3.0=pyh71513ae_0
|
| 14 |
+
- beautifulsoup4=4.13.3=pyha770c72_0
|
| 15 |
+
- bleach=6.2.0=pyh29332c3_4
|
| 16 |
+
- bleach-with-css=6.2.0=h82add2a_4
|
| 17 |
+
- boltons=24.0.0=pyhd8ed1ab_0
|
| 18 |
+
- brotli-python=1.1.0=py312h2ec8cdc_2
|
| 19 |
+
- bzip2=1.0.8=h4bc722e_7
|
| 20 |
+
- c-ares=1.32.3=h4bc722e_0
|
| 21 |
+
- ca-certificates=2025.4.26=hbd8a1cb_0
|
| 22 |
+
- cached-property=1.5.2=hd8ed1ab_1
|
| 23 |
+
- cached_property=1.5.2=pyha770c72_1
|
| 24 |
+
- certifi=2025.4.26=pyhd8ed1ab_0
|
| 25 |
+
- cffi=1.17.0=py312h06ac9bb_1
|
| 26 |
+
- charset-normalizer=3.3.2=pyhd8ed1ab_0
|
| 27 |
+
- colorama=0.4.6=pyhd8ed1ab_0
|
| 28 |
+
- comm=0.2.2=pyhd8ed1ab_1
|
| 29 |
+
- conda-package-handling=2.3.0=pyh7900ff3_0
|
| 30 |
+
- conda-package-streaming=0.10.0=pyhd8ed1ab_0
|
| 31 |
+
- debugpy=1.8.13=py312h2ec8cdc_0
|
| 32 |
+
- decorator=5.2.1=pyhd8ed1ab_0
|
| 33 |
+
- defusedxml=0.7.1=pyhd8ed1ab_0
|
| 34 |
+
- distro=1.9.0=pyhd8ed1ab_0
|
| 35 |
+
- exceptiongroup=1.2.2=pyhd8ed1ab_1
|
| 36 |
+
- fmt=11.0.2=h434a139_0
|
| 37 |
+
- fqdn=1.5.1=pyhd8ed1ab_1
|
| 38 |
+
- frozendict=2.4.4=py312h9a8786e_0
|
| 39 |
+
- gitdb=4.0.12=pyhd8ed1ab_0
|
| 40 |
+
- gitpython=3.1.44=pyhff2d567_0
|
| 41 |
+
- h2=4.1.0=pyhd8ed1ab_0
|
| 42 |
+
- hpack=4.0.0=pyh9f0ad1d_0
|
| 43 |
+
- hyperframe=6.0.1=pyhd8ed1ab_0
|
| 44 |
+
- icu=75.1=he02047a_0
|
| 45 |
+
- idna=3.8=pyhd8ed1ab_0
|
| 46 |
+
- importlib-metadata=8.6.1=pyha770c72_0
|
| 47 |
+
- importlib_resources=6.5.2=pyhd8ed1ab_0
|
| 48 |
+
- ipykernel=6.29.5=pyh3099207_0
|
| 49 |
+
- ipython_pygments_lexers=1.1.1=pyhd8ed1ab_0
|
| 50 |
+
- isoduration=20.11.0=pyhd8ed1ab_1
|
| 51 |
+
- jedi=0.19.2=pyhd8ed1ab_1
|
| 52 |
+
- jq=1.7.1=hd590300_0
|
| 53 |
+
- jsonpatch=1.33=pyhd8ed1ab_0
|
| 54 |
+
- jsonpointer=3.0.0=py312h7900ff3_1
|
| 55 |
+
- jsonschema=4.23.0=pyhd8ed1ab_1
|
| 56 |
+
- jsonschema-specifications=2024.10.1=pyhd8ed1ab_1
|
| 57 |
+
- jsonschema-with-format-nongpl=4.23.0=hd8ed1ab_1
|
| 58 |
+
- jupyter-server-mathjax=0.2.6=pyhbbac1ac_2
|
| 59 |
+
- jupyter_client=8.6.3=pyhd8ed1ab_1
|
| 60 |
+
- jupyter_core=5.7.2=pyh31011fe_1
|
| 61 |
+
- jupyter_events=0.12.0=pyh29332c3_0
|
| 62 |
+
- jupyter_server=2.15.0=pyhd8ed1ab_0
|
| 63 |
+
- jupyter_server_terminals=0.5.3=pyhd8ed1ab_1
|
| 64 |
+
- jupyterlab_pygments=0.3.0=pyhd8ed1ab_2
|
| 65 |
+
- keyutils=1.6.1=h166bdaf_0
|
| 66 |
+
- krb5=1.21.3=h659f571_0
|
| 67 |
+
- ld_impl_linux-64=2.40=hf3520f5_7
|
| 68 |
+
- libarchive=3.7.4=hfca40fe_0
|
| 69 |
+
- libcurl=8.10.1=hbbe4b11_0
|
| 70 |
+
- libedit=3.1.20191231=he28a2e2_2
|
| 71 |
+
- libev=4.33=hd590300_2
|
| 72 |
+
- libexpat=2.6.2=h59595ed_0
|
| 73 |
+
- libffi=3.4.2=h7f98852_5
|
| 74 |
+
- libgcc=14.1.0=h77fa898_1
|
| 75 |
+
- libgcc-ng=14.1.0=h69a702a_1
|
| 76 |
+
- libgomp=14.1.0=h77fa898_1
|
| 77 |
+
- libiconv=1.17=hd590300_2
|
| 78 |
+
- libmamba=1.5.10=hf72d635_1
|
| 79 |
+
- libmambapy=1.5.10=py312hf3f0a4e_1
|
| 80 |
+
- libnghttp2=1.58.0=h47da74e_1
|
| 81 |
+
- libnsl=2.0.1=hd590300_0
|
| 82 |
+
- libsodium=1.0.20=h4ab18f5_0
|
| 83 |
+
- libsolv=0.7.30=h3509ff9_0
|
| 84 |
+
- libsqlite=3.46.1=hadc24fc_0
|
| 85 |
+
- libssh2=1.11.0=h0841786_0
|
| 86 |
+
- libstdcxx=14.1.0=hc0a3c3a_1
|
| 87 |
+
- libstdcxx-ng=14.1.0=h4852527_1
|
| 88 |
+
- libuuid=2.38.1=h0b41bf4_0
|
| 89 |
+
- libxcrypt=4.4.36=hd590300_1
|
| 90 |
+
- libxml2=2.12.7=he7c6b58_4
|
| 91 |
+
- libzlib=1.3.1=h4ab18f5_1
|
| 92 |
+
- lz4-c=1.9.4=hcb278e6_0
|
| 93 |
+
- lzo=2.10=hd590300_1001
|
| 94 |
+
- markupsafe=3.0.2=py312h178313f_1
|
| 95 |
+
- matplotlib-inline=0.1.7=pyhd8ed1ab_1
|
| 96 |
+
- menuinst=2.1.2=py312h7900ff3_1
|
| 97 |
+
- mistune=3.1.3=pyh29332c3_0
|
| 98 |
+
- nbclient=0.10.2=pyhd8ed1ab_0
|
| 99 |
+
- nbconvert-core=7.16.6=pyh29332c3_0
|
| 100 |
+
- nbdime=4.0.2=pyhd8ed1ab_1
|
| 101 |
+
- nbformat=5.10.4=pyhd8ed1ab_1
|
| 102 |
+
- ncurses=6.5=he02047a_1
|
| 103 |
+
- nest-asyncio=1.6.0=pyhd8ed1ab_1
|
| 104 |
+
- oniguruma=6.9.10=hb9d3cd8_0
|
| 105 |
+
- openssl=3.5.0=h7b32b05_1
|
| 106 |
+
- overrides=7.7.0=pyhd8ed1ab_1
|
| 107 |
+
- packaging=24.1=pyhd8ed1ab_0
|
| 108 |
+
- parso=0.8.4=pyhd8ed1ab_1
|
| 109 |
+
- pexpect=4.9.0=pyhd8ed1ab_1
|
| 110 |
+
- pickleshare=0.7.5=pyhd8ed1ab_1004
|
| 111 |
+
- pip=24.2=pyh8b19718_1
|
| 112 |
+
- pkgutil-resolve-name=1.3.10=pyhd8ed1ab_2
|
| 113 |
+
- platformdirs=4.2.2=pyhd8ed1ab_0
|
| 114 |
+
- pluggy=1.5.0=pyhd8ed1ab_0
|
| 115 |
+
- prometheus_client=0.21.1=pyhd8ed1ab_0
|
| 116 |
+
- prompt-toolkit=3.0.50=pyha770c72_0
|
| 117 |
+
- psutil=7.0.0=py312h66e93f0_0
|
| 118 |
+
- ptyprocess=0.7.0=pyhd8ed1ab_1
|
| 119 |
+
- pure_eval=0.2.3=pyhd8ed1ab_1
|
| 120 |
+
- pybind11-abi=4=hd8ed1ab_3
|
| 121 |
+
- pycosat=0.6.6=py312h98912ed_0
|
| 122 |
+
- pycparser=2.22=pyhd8ed1ab_0
|
| 123 |
+
- pygments=2.19.1=pyhd8ed1ab_0
|
| 124 |
+
- pysocks=1.7.1=pyha2e5f31_6
|
| 125 |
+
- python=3.12.5=h2ad013b_0_cpython
|
| 126 |
+
- python-dateutil=2.9.0.post0=pyhff2d567_1
|
| 127 |
+
- python-fastjsonschema=2.21.1=pyhd8ed1ab_0
|
| 128 |
+
- python_abi=3.12=5_cp312
|
| 129 |
+
- pyyaml=6.0.2=py312h178313f_2
|
| 130 |
+
- pyzmq=26.2.1=py312hbf22597_0
|
| 131 |
+
- readline=8.2=h8228510_1
|
| 132 |
+
- referencing=0.36.2=pyh29332c3_0
|
| 133 |
+
- reproc=14.2.4.post0=hd590300_1
|
| 134 |
+
- reproc-cpp=14.2.4.post0=h59595ed_1
|
| 135 |
+
- requests=2.32.3=pyhd8ed1ab_0
|
| 136 |
+
- rfc3339-validator=0.1.4=pyhd8ed1ab_1
|
| 137 |
+
- rfc3986-validator=0.1.1=pyh9f0ad1d_0
|
| 138 |
+
- ruamel.yaml=0.18.6=py312h98912ed_0
|
| 139 |
+
- ruamel.yaml.clib=0.2.8=py312h98912ed_0
|
| 140 |
+
- send2trash=1.8.3=pyh0d859eb_1
|
| 141 |
+
- six=1.17.0=pyhd8ed1ab_0
|
| 142 |
+
- smmap=5.0.2=pyhd8ed1ab_0
|
| 143 |
+
- sniffio=1.3.1=pyhd8ed1ab_1
|
| 144 |
+
- stack_data=0.6.3=pyhd8ed1ab_1
|
| 145 |
+
- terminado=0.18.1=pyh0d859eb_0
|
| 146 |
+
- tinycss2=1.4.0=pyhd8ed1ab_0
|
| 147 |
+
- tk=8.6.13=noxft_h4845f30_101
|
| 148 |
+
- tornado=6.4.2=py312h66e93f0_0
|
| 149 |
+
- tqdm=4.66.5=pyhd8ed1ab_0
|
| 150 |
+
- traitlets=5.14.3=pyhd8ed1ab_1
|
| 151 |
+
- truststore=0.9.2=pyhd8ed1ab_0
|
| 152 |
+
- types-python-dateutil=2.9.0.20241206=pyhd8ed1ab_0
|
| 153 |
+
- typing-extensions=4.12.2=hd8ed1ab_1
|
| 154 |
+
- typing_extensions=4.12.2=pyha770c72_1
|
| 155 |
+
- typing_utils=0.1.0=pyhd8ed1ab_1
|
| 156 |
+
- uri-template=1.3.0=pyhd8ed1ab_1
|
| 157 |
+
- urllib3=2.2.2=pyhd8ed1ab_1
|
| 158 |
+
- wcwidth=0.2.13=pyhd8ed1ab_1
|
| 159 |
+
- webcolors=24.11.1=pyhd8ed1ab_0
|
| 160 |
+
- webencodings=0.5.1=pyhd8ed1ab_3
|
| 161 |
+
- websocket-client=1.8.0=pyhd8ed1ab_1
|
| 162 |
+
- wheel=0.44.0=pyhd8ed1ab_0
|
| 163 |
+
- xz=5.2.6=h166bdaf_0
|
| 164 |
+
- yaml=0.2.5=h7f98852_2
|
| 165 |
+
- yaml-cpp=0.8.0=h59595ed_0
|
| 166 |
+
- zeromq=4.3.5=h3b0a872_7
|
| 167 |
+
- zipp=3.21.0=pyhd8ed1ab_1
|
| 168 |
+
- zstandard=0.23.0=py312hef9b889_1
|
| 169 |
+
- zstd=1.5.6=ha6fb4c9_0
|
| 170 |
+
- pip:
|
| 171 |
+
- accelerate==1.1.1
|
| 172 |
+
- alphashape==1.3.1
|
| 173 |
+
- annotated-types==0.7.0
|
| 174 |
+
- async-lru==2.0.5
|
| 175 |
+
- babel==2.17.0
|
| 176 |
+
- brokenaxes==0.6.2
|
| 177 |
+
- click==8.1.8
|
| 178 |
+
- click-log==0.4.0
|
| 179 |
+
- contourpy==1.3.1
|
| 180 |
+
- cycler==0.12.1
|
| 181 |
+
- descartes==1.1.0
|
| 182 |
+
- docker-pycreds==0.4.0
|
| 183 |
+
- einops==0.8.0
|
| 184 |
+
- executing==2.2.0
|
| 185 |
+
- filelock==3.17.0
|
| 186 |
+
- fonttools==4.56.0
|
| 187 |
+
- fsspec==2025.2.0
|
| 188 |
+
- h11==0.14.0
|
| 189 |
+
- hjson==3.1.0
|
| 190 |
+
- httpcore==1.0.7
|
| 191 |
+
- httpx==0.28.1
|
| 192 |
+
- huggingface-hub==0.29.1
|
| 193 |
+
- ijson==3.4.0
|
| 194 |
+
- ipympl==0.9.7
|
| 195 |
+
- ipython==9.0.2
|
| 196 |
+
- ipywidgets==8.1.5
|
| 197 |
+
- jinja2==3.1.5
|
| 198 |
+
- joblib==1.4.2
|
| 199 |
+
- json5==0.10.0
|
| 200 |
+
- jupyter-lsp==2.2.5
|
| 201 |
+
- jupyterlab==4.3.6
|
| 202 |
+
- jupyterlab-server==2.27.3
|
| 203 |
+
- jupyterlab-widgets==3.0.13
|
| 204 |
+
- kiwisolver==1.4.8
|
| 205 |
+
- llvmlite==0.44.0
|
| 206 |
+
- matplotlib==3.10.1
|
| 207 |
+
- mpi4py==4.0.3
|
| 208 |
+
- mpmath==1.3.0
|
| 209 |
+
- msgpack==1.1.0
|
| 210 |
+
- narwhals==1.42.1
|
| 211 |
+
- networkx==3.4.2
|
| 212 |
+
- ninja==1.11.1.3
|
| 213 |
+
- notebook==7.3.3
|
| 214 |
+
- notebook-shim==0.2.4
|
| 215 |
+
- numba==0.61.2
|
| 216 |
+
- numpy==2.2.3
|
| 217 |
+
- nvidia-cublas-cu12==12.4.5.8
|
| 218 |
+
- nvidia-cuda-cupti-cu12==12.4.127
|
| 219 |
+
- nvidia-cuda-nvrtc-cu12==12.4.127
|
| 220 |
+
- nvidia-cuda-runtime-cu12==12.4.127
|
| 221 |
+
- nvidia-cudnn-cu12==9.1.0.70
|
| 222 |
+
- nvidia-cufft-cu12==11.2.1.3
|
| 223 |
+
- nvidia-curand-cu12==10.3.5.147
|
| 224 |
+
- nvidia-cusolver-cu12==11.6.1.9
|
| 225 |
+
- nvidia-cusparse-cu12==12.3.1.170
|
| 226 |
+
- nvidia-cusparselt-cu12==0.6.2
|
| 227 |
+
- nvidia-ml-py==12.570.86
|
| 228 |
+
- nvidia-nccl-cu12==2.21.5
|
| 229 |
+
- nvidia-nvjitlink-cu12==12.4.127
|
| 230 |
+
- nvidia-nvtx-cu12==12.4.127
|
| 231 |
+
- pandas==2.3.0
|
| 232 |
+
- pandocfilters==1.5.1
|
| 233 |
+
- pillow==11.1.0
|
| 234 |
+
- plotly==6.1.2
|
| 235 |
+
- protobuf==5.29.3
|
| 236 |
+
- py-cpuinfo==9.0.0
|
| 237 |
+
- pydantic==2.10.6
|
| 238 |
+
- pydantic-core==2.27.2
|
| 239 |
+
- pynndescent==0.5.13
|
| 240 |
+
- pyparsing==3.2.1
|
| 241 |
+
- python-json-logger==3.3.0
|
| 242 |
+
- pytz==2025.2
|
| 243 |
+
- regex==2024.11.6
|
| 244 |
+
- rpds-py==0.23.1
|
| 245 |
+
- rtree==1.4.1
|
| 246 |
+
- safetensors==0.5.3
|
| 247 |
+
- scikit-learn==1.6.1
|
| 248 |
+
- scipy==1.15.2
|
| 249 |
+
- sentry-sdk==2.22.0
|
| 250 |
+
- setproctitle==1.3.5
|
| 251 |
+
- setuptools==75.8.2
|
| 252 |
+
- shapely==2.1.2
|
| 253 |
+
- soupsieve==2.6
|
| 254 |
+
- sympy==1.13.1
|
| 255 |
+
- threadpoolctl==3.6.0
|
| 256 |
+
- tokenizers==0.20.3
|
| 257 |
+
- torch==2.6.0
|
| 258 |
+
- torchaudio==2.6.0
|
| 259 |
+
- torchvision==0.21.0
|
| 260 |
+
- transformers==4.46.3
|
| 261 |
+
- trimesh==4.8.3
|
| 262 |
+
- triton==3.2.0
|
| 263 |
+
- tzdata==2025.2
|
| 264 |
+
- umap-learn==0.5.7
|
| 265 |
+
- wandb==0.18.3
|
| 266 |
+
- widgetsnbextension==4.0.13
|
| 267 |
+
prefix: /global/homes/b/binxia/.conda/envs/fsdp
|
example.ipynb
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
+
"id": "07604227",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": []
|
| 10 |
+
},
|
| 11 |
+
{
|
| 12 |
+
"cell_type": "code",
|
| 13 |
+
"execution_count": null,
|
| 14 |
+
"id": "9882fd75",
|
| 15 |
+
"metadata": {},
|
| 16 |
+
"outputs": [],
|
| 17 |
+
"source": []
|
| 18 |
+
}
|
| 19 |
+
],
|
| 20 |
+
"metadata": {
|
| 21 |
+
"kernelspec": {
|
| 22 |
+
"display_name": "fsdp",
|
| 23 |
+
"language": "python",
|
| 24 |
+
"name": "python3"
|
| 25 |
+
},
|
| 26 |
+
"language_info": {
|
| 27 |
+
"codemirror_mode": {
|
| 28 |
+
"name": "ipython",
|
| 29 |
+
"version": 3
|
| 30 |
+
},
|
| 31 |
+
"file_extension": ".py",
|
| 32 |
+
"mimetype": "text/x-python",
|
| 33 |
+
"name": "python",
|
| 34 |
+
"nbconvert_exporter": "python",
|
| 35 |
+
"pygments_lexer": "ipython3",
|
| 36 |
+
"version": "3.12.5"
|
| 37 |
+
}
|
| 38 |
+
},
|
| 39 |
+
"nbformat": 4,
|
| 40 |
+
"nbformat_minor": 5
|
| 41 |
+
}
|
foundation_bert.py
ADDED
|
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import os
|
| 3 |
+
import yaml
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from utils.masked_data_modeling_loss import MaskedDataLossWithSoftmax
|
| 6 |
+
# from ..utils.contrastive_loss import ContrastiveLoss
|
| 7 |
+
from utils.yaml_util import MyLoader
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from transformers import BertModel, BertConfig, PretrainedConfig
|
| 10 |
+
from typing import Optional, Union
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass
class FoundationOutput:
    """Container for the outputs of a FoundationBert forward pass.

    Every field defaults to None so callers only populate the outputs
    produced by the active loss configuration. (Fields hold tensors when
    set; `Optional[...]` annotations make the None default type-correct.)
    """

    loss: Optional[torch.Tensor] = None
    logits: Optional[torch.Tensor] = None
    num_output: Optional[torch.Tensor] = None
    est_err_output: Optional[torch.Tensor] = None
    hidden_states: Optional[torch.Tensor] = None
    masked_loss: Optional[torch.Tensor] = None
    num_loss: Optional[torch.Tensor] = None
    est_err_loss: Optional[torch.Tensor] = None
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
class FoundationBertConfig:
    """Plain-dataclass configuration for FoundationBert.

    Mirrors the subset of BertConfig fields the model consumes, plus the
    loss-selection flags read by FoundationBert.__init__.
    """

    vocab_size: int
    hidden_size: int
    num_hidden_layers: int
    num_attention_heads: int
    intermediate_size: int
    hidden_dropout_prob: float
    attention_probs_dropout_prob: float
    pad_token_id: int
    classifier_dropout: float
    max_position_embeddings: int
    contrastive_temperature: float
    loss_weights: dict
    use_xval_loss: bool = True
    use_mlm_loss: bool = True
    use_regression_loss: bool = False
    use_contrastive_loss: bool = False
    transform_numeric: bool = False

    def to_dict(self) -> dict:
        """Return a shallow dict of all declared fields.

        Deliberately shallow (not dataclasses.asdict): mutable values such
        as ``loss_weights`` keep their identity rather than being copied.
        Iterating ``__dataclass_fields__`` directly yields its keys.
        """
        return {k: getattr(self, k) for k in self.__dataclass_fields__}
|
| 47 |
+
|
| 48 |
+
class FoundationBert(BertModel):
    """BERT encoder adapted to multi-modal scalar/vector inputs.

    Each modality group gets its own linear embedding layer and regression
    head; the shared BERT encoder processes the concatenated sequence.
    """

    def __init__(self,
                 config: FoundationBertConfig = None,
                 use_mlm_loss: bool = False,
                 use_regression_loss: bool = True,
                 use_contrastive_loss: bool = False,
                 use_xval_loss: bool = False,
                 transform_numeric: bool = False,
                 *args,
                 **kwargs):
        """Build the encoder plus the modality-specific layers.

        Required kwargs: 'modalities' (list of modality names) and
        'mask_token' (value substituted at masked positions).
        Optional kwarg: 'dataset_path'.

        Raises:
            KeyError: if 'modalities' or 'mask_token' is missing from kwargs.
        """
        self.gconfig = config
        bert_conf = BertConfig(
            vocab_size=config.vocab_size,
            hidden_size=config.hidden_size,
            num_hidden_layers=config.num_hidden_layers,
            num_attention_heads=config.num_attention_heads,
            intermediate_size=config.intermediate_size,
            hidden_dropout_prob=config.hidden_dropout_prob,
            attention_probs_dropout_prob=config.attention_probs_dropout_prob,
            pad_token_id=config.pad_token_id,
            max_position_embeddings=config.max_position_embeddings,
            _attn_implementation='sdpa'
        )
        self.gconfig.transform_numeric = transform_numeric
        super().__init__(bert_conf,)
        try:
            if not self.gconfig.use_mlm_loss and not self.gconfig.use_regression_loss and not self.gconfig.use_contrastive_loss:
                raise ValueError("At least one loss must be enabled")
            self.loss_mod = float(self.gconfig.use_mlm_loss) + float(self.gconfig.use_regression_loss) + float(self.gconfig.use_contrastive_loss) + float(self.gconfig.use_xval_loss)
        except (AttributeError, ValueError):
            # BUGFIX: was a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit. We only intend to fall back to
            # the constructor flags when the config object lacks the loss
            # flags (AttributeError) or has none of them enabled (the
            # ValueError raised above).
            self.gconfig.use_mlm_loss = use_mlm_loss
            self.gconfig.use_regression_loss = use_regression_loss
            self.gconfig.use_contrastive_loss = use_contrastive_loss
            self.gconfig.use_xval_loss = use_xval_loss
            self.loss_mod = float(self.gconfig.use_mlm_loss) + float(self.gconfig.use_regression_loss) + float(self.gconfig.use_contrastive_loss) + float(self.gconfig.use_xval_loss)

        self.dataset_path = kwargs.get('dataset_path', None)

        self.modalities = kwargs['modalities']
        self.mask_token = kwargs['mask_token']

        self.scalar_keys = [
            'redshift',
            'halo_mass',
            'stellar_mass',
        ]
        self.vector_keys = [
            'SED',
            'SFH',
            'mag_{band}_spherex',
            'mag_{band}_lsst',
        ]
        # Collapse every non-vector modality into a single shared 'scalars'
        # slot, keeping vector modalities separate; dict.fromkeys dedupes
        # while preserving first-seen order.
        self.modalscalars = [m if m in self.vector_keys else 'scalars' for m in self.modalities]
        self.modalscalars = list(dict.fromkeys(self.modalscalars))

        self.embedding = torch.nn.ModuleDict()  # modality specific embedding layers
        self.num_head = torch.nn.ModuleDict()  # modality specific regression heads
        # create modality specific layers
        for modality in self.modalscalars:
            self.embedding[modality] = torch.nn.Linear(1, config.hidden_size)  # (B, L, 1) -> (B, L, H)
            self.num_head[modality] = torch.nn.Sequential(
                torch.nn.Linear(config.hidden_size, config.hidden_size),
                torch.nn.LayerNorm(config.hidden_size),
                torch.nn.GELU(),
                torch.nn.Linear(config.hidden_size, config.hidden_size // 2),
                torch.nn.GELU(),
                torch.nn.Linear(config.hidden_size // 2, 1)
            )

        self.position_embeddings = torch.nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.embed_dropout = torch.nn.Dropout(config.hidden_dropout_prob)

        self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)  # isn't used currently
        self.xval_loss = torch.nn.MSELoss(reduction='none')  # isn't used currently
        self.mlm_loss = MaskedDataLossWithSoftmax(ignore=-100, reduction='none')  # isn't used currently
        self.distributed_loss = False
|
| 127 |
+
|
| 128 |
+
@classmethod
def from_pretrained(cls,
                    pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
                    *model_args,
                    config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
                    cache_dir: Optional[Union[str, os.PathLike]] = None,
                    ignore_mismatched_sizes: bool = False,
                    force_download: bool = False,
                    local_files_only: bool = False,
                    token: Optional[Union[str, bool]] = None,
                    revision: str = "main",
                    use_safetensors: bool = None,
                    **kwargs,
                    ):
    """
    Modification to correctly handle loading extraneous parameters for GBert.

    Reads train_config.yaml from the checkpoint directory and injects the
    'modalities', 'dataset_path', and 'mask_token' entries (required by
    __init__) into kwargs before delegating to BertModel.from_pretrained.

    NOTE(review): the `config` argument is ignored and overwritten by the
    YAML file, and only local checkpoint directories containing
    train_config.yaml are supported — confirm this is intended.
    """
    # BUGFIX: the first parameter of a @classmethod is the class, renamed
    # self -> cls to match its actual binding.
    model_config = Path(pretrained_model_name_or_path) / 'train_config.yaml'
    with open(model_config, 'r') as f:
        config = yaml.load(f, Loader=MyLoader)
    kwargs['modalities'] = config['modalities']
    kwargs['dataset_path'] = config['dataset_path']
    kwargs['mask_token'] = config['mask_token']
    # BUGFIX: the explicit loader arguments (cache_dir, force_download, …)
    # and *model_args were previously captured here and silently dropped;
    # forward them so callers' options take effect.
    return super().from_pretrained(
        pretrained_model_name_or_path,
        *model_args,
        cache_dir=cache_dir,
        ignore_mismatched_sizes=ignore_mismatched_sizes,
        force_download=force_download,
        local_files_only=local_files_only,
        token=token,
        revision=revision,
        use_safetensors=use_safetensors,
        **config['model_config'],
        **kwargs
    )
|
| 157 |
+
|
| 158 |
+
def pool_output(self,
                embeddings: torch.Tensor,
                attention_mask: torch.Tensor,
                use_last: bool = False
                ) -> torch.Tensor:
    """Average pool the hidden states using the attention mask.

    The start token (position 0) and a per-sequence end token (offset
    `sl_mod` back from the sequence length) are excluded from the sum.
    NOTE(review): the divisor remains the *full* attended length, so the
    excluded special tokens still count in the denominator — this matches
    the original behavior; confirm whether that is intended.

    Parameters
    ----------
    embeddings : torch.Tensor
        The hidden states to pool (B, SeqLen, HiddenDim).
    attention_mask : torch.Tensor
        The attention mask for the hidden states (B, SeqLen).
    use_last : bool
        If True, exclude the last attended position (offset 1); otherwise
        exclude the penultimate one (offset 2).

    Returns
    -------
    torch.Tensor
        The pooled embeddings (B, HiddenDim).
    """
    # Get the sequence lengths
    sl_mod = 1 if use_last else 2
    seq_lengths = attention_mask.sum(dim=1)
    # Set the attention mask to 0 for start and end tokens.
    new_attention = attention_mask.clone()
    new_attention[:, 0] = 0
    # BUGFIX: `new_attention[:, seq_lengths - sl_mod] = 0` indexed with a
    # *tensor* of column positions, which zeroes those columns for EVERY
    # row — corrupting batches with heterogeneous sequence lengths. Index
    # per row instead.
    batch_idx = torch.arange(attention_mask.shape[0], device=attention_mask.device)
    new_attention[batch_idx, (seq_lengths - sl_mod).long()] = 0

    # Create a mask for the pooling operation (B, SeqLen, HiddenDim)
    pool_mask = new_attention.unsqueeze(-1).expand(embeddings.shape).to(embeddings.device)
    # Sum the embeddings over the sequence length (the mask skips pad,
    # start, and stop tokens)
    sum_embeds = torch.sum(embeddings * pool_mask, 1)
    # Avoid division by zero for zero-length sequences by clamping
    seq_lengths = torch.clamp(seq_lengths, min=1).unsqueeze(-1)  # (B, 1) to broadcast
    # Compute mean pooled embeddings for each sequence
    return sum_embeds / seq_lengths
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def last_token_pool(
    self,
    embeddings: torch.Tensor,
    attention_mask: torch.Tensor,
) -> torch.Tensor:
    """Select one embedding per sequence: its final attended position.

    Parameters
    ----------
    embeddings : torch.Tensor
        The last hidden states to pool (B, SeqLen, HiddenDim).
    attention_mask : torch.Tensor
        The attention mask for the hidden states (B, SeqLen).

    Returns
    -------
    torch.Tensor
        The pooled embeddings (B, HiddenDim).
    """
    # If every row attends to the final column, the batch is left-padded
    # and the last column already holds each sequence's last real token.
    if attention_mask[:, -1].sum() == attention_mask.shape[0]:
        return embeddings[:, -1]
    # Right padding: gather each row's final attended index.
    last_positions = attention_mask.sum(dim=1) - 1
    row_indices = torch.arange(embeddings.shape[0], device=embeddings.device)
    return embeddings[row_indices, last_positions]
|
| 226 |
+
|
| 227 |
+
def forward(self, inputs, return_input_label_mapping=False):
    """
    Forward pass that computes predictions for each modality.

    Args:
        inputs (dict): Must contain "input_<modality>" and
            "labels_<modality>" tensors for every modality in
            ``self.modalscalars``.
            NOTE(review): labels_* is used as the condition of torch.where,
            i.e. a boolean mask selecting positions to replace with
            ``self.mask_token`` — confirm its dtype with the data pipeline.
        return_input_label_mapping (bool): If True, also return the
            per-modality input/label mapping built during the pass.

    Returns:
        outputs (dict): "<modality>_logits" tensors, one per modality,
        produced by the modality-specific regression heads; optionally a
        tuple ``(outputs, input_label_mapping)``.
    """

    # Initialize the dictionary for the dynamic input-label mapping
    input_label_mapping = {}
    combined = []
    for src_modality in self.modalscalars:
        # Add the modality's input and label data to the input_label_mapping
        input_label_mapping[src_modality] = {
            'input': inputs[f"input_{src_modality}"],  # Input data
            'labels': inputs[f"labels_{src_modality}"]  # Corresponding labels
        }

        input_data = input_label_mapping[src_modality]['input']  # get input data
        label = input_label_mapping[src_modality]['labels']  # get label data (for masking)
        input_data = torch.where(label, self.mask_token, input_data)  # apply masking

        x = self.embedding[src_modality](input_data.unsqueeze(-1))  # shape: (B, L, H)
        x = torch.nn.functional.silu(x)
        combined.append(x)  # combine all modalities

    combined = torch.cat(combined, dim=1)  # Concatenate along the sequence length dimension

    # Positions are rebuilt every call; stored on self (presumably for
    # later inspection — note this attribute moves with the module).
    self.position_ids = torch.arange(combined.size(1)).unsqueeze(0).to(combined.device)  # shape: (1, L)
    combined += self.position_embeddings(self.position_ids)  # add position embedding
    combined = self.embed_dropout(combined)

    x = self.encoder(combined, output_hidden_states=True).last_hidden_state  # encode the combined input

    start = 0
    outputs = {}
    # Iterate over each target modality to compute logits
    for tgt_modality in self.modalscalars:
        length = input_label_mapping[tgt_modality]['input'].shape[1]  # get sequence length of the modality
        x_t = x[:, start:start+length, :]  # slice the encoded output for each modality
        outputs[f"{tgt_modality}_logits"] = self.num_head[tgt_modality](x_t)  # modality specific regression head

        start += length  # update start index for next modality

        # NOTE(review): this hook appends the current modality's pooled
        # encoding for offline UMAP visualization; confirm whether it was
        # intended per-modality (inside the loop, as here) or only for the
        # final modality.
        if getattr(self, 'save_umap_for', None):
            pooled = x_t.mean(dim=1)  # Mean pooling over the sequence length dimension
            self.save_pooled_embedding(pooled)  # saved for UMAP visualization

    return (outputs, input_label_mapping) if return_input_label_mapping else outputs
|
| 279 |
+
|
| 280 |
+
def save_pooled_embedding(self, features):
    """
    Append pooled features to the HDF5 file named by ``self.save_umap_for``,
    creating the file (and parent directories) on first use.
    """
    import h5py
    out_path = Path(self.save_umap_for)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    arr = features.detach().cpu().numpy()

    if not out_path.exists():
        # First write: resizable dataset so later calls can append.
        with h5py.File(out_path, 'w') as h5:
            h5.create_dataset('features', data=arr, maxshape=(None, arr.shape[-1]), chunks=True)
        return

    with h5py.File(out_path, 'r+') as h5:
        dset = h5['features']
        n_existing = dset.shape[0]
        dset.resize((n_existing + arr.shape[0], arr.shape[-1]))  # grow in place
        dset[n_existing:] = arr  # append the new rows
|
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d2acc57f8c67e0f2b358632241243752031c23a7ed7030ba95c33b7f81e06c62
|
| 3 |
+
size 550172096
|
requirements.txt
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
accelerate==1.1.1
|
| 2 |
+
alphashape==1.3.1
|
| 3 |
+
annotated-types==0.7.0
|
| 4 |
+
anyio @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_anyio_1742243108/work
|
| 5 |
+
appdirs==1.4.4
|
| 6 |
+
archspec @ file:///home/conda/feedstock_root/build_artifacts/archspec_1708969572489/work
|
| 7 |
+
argon2-cffi @ file:///home/conda/feedstock_root/build_artifacts/argon2-cffi_1733311059102/work
|
| 8 |
+
argon2-cffi-bindings @ file:///home/conda/feedstock_root/build_artifacts/argon2-cffi-bindings_1725356585055/work
|
| 9 |
+
arrow @ file:///home/conda/feedstock_root/build_artifacts/arrow_1733584251875/work
|
| 10 |
+
asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1733250440834/work
|
| 11 |
+
async-lru==2.0.5
|
| 12 |
+
attrs @ file:///home/conda/feedstock_root/build_artifacts/attrs_1741918516150/work
|
| 13 |
+
babel==2.17.0
|
| 14 |
+
beautifulsoup4 @ file:///home/conda/feedstock_root/build_artifacts/beautifulsoup4_1738740337718/work
|
| 15 |
+
bleach @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_bleach_1737382993/work
|
| 16 |
+
boltons @ file:///home/conda/feedstock_root/build_artifacts/boltons_1711936407380/work
|
| 17 |
+
brokenaxes==0.6.2
|
| 18 |
+
Brotli @ file:///home/conda/feedstock_root/build_artifacts/brotli-split_1725267488082/work
|
| 19 |
+
cached-property @ file:///home/conda/feedstock_root/build_artifacts/cached_property_1615209429212/work
|
| 20 |
+
certifi @ file:///home/conda/feedstock_root/build_artifacts/certifi_1746569525376/work/certifi
|
| 21 |
+
cffi @ file:///home/conda/feedstock_root/build_artifacts/cffi_1724956320552/work
|
| 22 |
+
charset-normalizer @ file:///home/conda/feedstock_root/build_artifacts/charset-normalizer_1698833585322/work
|
| 23 |
+
# Editable install with no version control (chatarena==0.1.8)
|
| 24 |
+
-e /pscratch/sd/b/binxia/Werewolf
|
| 25 |
+
click==8.1.8
|
| 26 |
+
click-log==0.4.0
|
| 27 |
+
colorama @ file:///home/conda/feedstock_root/build_artifacts/colorama_1666700638685/work
|
| 28 |
+
comm @ file:///home/conda/feedstock_root/build_artifacts/comm_1733502965406/work
|
| 29 |
+
conda-package-handling @ file:///home/conda/feedstock_root/build_artifacts/conda-package-handling_1717678605937/work
|
| 30 |
+
conda_package_streaming @ file:///home/conda/feedstock_root/build_artifacts/conda-package-streaming_1717678526951/work
|
| 31 |
+
configparser==7.2.0
|
| 32 |
+
contourpy==1.3.1
|
| 33 |
+
cycler==0.12.1
|
| 34 |
+
debugpy @ file:///home/conda/feedstock_root/build_artifacts/debugpy_1741148399929/work
|
| 35 |
+
decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1740384970518/work
|
| 36 |
+
defusedxml @ file:///home/conda/feedstock_root/build_artifacts/defusedxml_1615232257335/work
|
| 37 |
+
descartes==1.1.0
|
| 38 |
+
distro @ file:///home/conda/feedstock_root/build_artifacts/distro_1704321475663/work
|
| 39 |
+
docker-pycreds==0.4.0
|
| 40 |
+
einops==0.8.0
|
| 41 |
+
exceptiongroup @ file:///home/conda/feedstock_root/build_artifacts/exceptiongroup_1733208806608/work
|
| 42 |
+
executing==2.2.0
|
| 43 |
+
fastjsonschema @ file:///home/conda/feedstock_root/build_artifacts/python-fastjsonschema_1733235979760/work/dist
|
| 44 |
+
filelock==3.17.0
|
| 45 |
+
fonttools==4.56.0
|
| 46 |
+
fqdn @ file:///home/conda/feedstock_root/build_artifacts/fqdn_1733327382592/work/dist
|
| 47 |
+
frozendict @ file:///home/conda/feedstock_root/build_artifacts/frozendict_1715092752354/work
|
| 48 |
+
fsspec==2025.2.0
|
| 49 |
+
gitdb @ file:///home/conda/feedstock_root/build_artifacts/gitdb_1735887193964/work
|
| 50 |
+
GitPython @ file:///home/conda/feedstock_root/build_artifacts/gitpython_1735929639977/work
|
| 51 |
+
h11==0.14.0
|
| 52 |
+
h2 @ file:///home/conda/feedstock_root/build_artifacts/h2_1634280454336/work
|
| 53 |
+
hjson==3.1.0
|
| 54 |
+
hpack==4.0.0
|
| 55 |
+
httpcore==1.0.7
|
| 56 |
+
httpx==0.28.1
|
| 57 |
+
huggingface-hub==0.29.1
|
| 58 |
+
hyperframe @ file:///home/conda/feedstock_root/build_artifacts/hyperframe_1619110129307/work
|
| 59 |
+
idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1724450538981/work
|
| 60 |
+
ijson==3.4.0
|
| 61 |
+
importlib_metadata @ file:///home/conda/feedstock_root/build_artifacts/importlib-metadata_1737420181517/work
|
| 62 |
+
importlib_resources @ file:///home/conda/feedstock_root/build_artifacts/importlib_resources_1736252299705/work
|
| 63 |
+
ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1719845459717/work
|
| 64 |
+
ipympl==0.9.7
|
| 65 |
+
ipython==9.0.2
|
| 66 |
+
ipython_pygments_lexers @ file:///home/conda/feedstock_root/build_artifacts/ipython_pygments_lexers_1737123620466/work
|
| 67 |
+
ipywidgets==8.1.5
|
| 68 |
+
isoduration @ file:///home/conda/feedstock_root/build_artifacts/isoduration_1733493628631/work/dist
|
| 69 |
+
jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1733300866624/work
|
| 70 |
+
Jinja2==3.1.5
|
| 71 |
+
joblib==1.4.2
|
| 72 |
+
json5==0.10.0
|
| 73 |
+
jsonpatch @ file:///home/conda/feedstock_root/build_artifacts/jsonpatch_1695536281965/work
|
| 74 |
+
jsonpointer @ file:///home/conda/feedstock_root/build_artifacts/jsonpointer_1725302935093/work
|
| 75 |
+
jsonschema @ file:///home/conda/feedstock_root/build_artifacts/jsonschema_1733472696581/work
|
| 76 |
+
jsonschema-specifications @ file:///tmp/tmpk0f344m9/src
|
| 77 |
+
jupyter-events @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_jupyter_events_1738765986/work
|
| 78 |
+
jupyter-lsp==2.2.5
|
| 79 |
+
jupyter_client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1733440914442/work
|
| 80 |
+
jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1727163409502/work
|
| 81 |
+
jupyter_server @ file:///home/conda/feedstock_root/build_artifacts/jupyter_server_1734702637701/work
|
| 82 |
+
jupyter_server_mathjax @ file:///home/conda/feedstock_root/build_artifacts/jupyter-server-mathjax_1734509714511/work
|
| 83 |
+
jupyter_server_terminals @ file:///home/conda/feedstock_root/build_artifacts/jupyter_server_terminals_1733427956852/work
|
| 84 |
+
jupyterlab==4.3.6
|
| 85 |
+
jupyterlab_pygments @ file:///home/conda/feedstock_root/build_artifacts/jupyterlab_pygments_1733328101776/work
|
| 86 |
+
jupyterlab_server==2.27.3
|
| 87 |
+
jupyterlab_widgets==3.0.13
|
| 88 |
+
kiwisolver==1.4.8
|
| 89 |
+
kymatio==0.3.0
|
| 90 |
+
libmambapy @ file:///home/conda/feedstock_root/build_artifacts/mamba-split_1727883551957/work/libmambapy
|
| 91 |
+
llvmlite==0.44.0
|
| 92 |
+
MarkupSafe @ file:///home/conda/feedstock_root/build_artifacts/markupsafe_1733219680183/work
|
| 93 |
+
matplotlib==3.10.1
|
| 94 |
+
matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1733416936468/work
|
| 95 |
+
menuinst @ file:///home/conda/feedstock_root/build_artifacts/menuinst_1725359038078/work
|
| 96 |
+
mistune @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_mistune_1742402716/work
|
| 97 |
+
mpi4py==4.0.3
|
| 98 |
+
mpmath==1.3.0
|
| 99 |
+
msgpack==1.1.0
|
| 100 |
+
narwhals==1.42.1
|
| 101 |
+
nbclient @ file:///home/conda/feedstock_root/build_artifacts/nbclient_1734628800805/work
|
| 102 |
+
nbconvert @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_nbconvert-core_1738067871/work
|
| 103 |
+
nbdime @ file:///home/conda/feedstock_root/build_artifacts/nbdime_1734533951497/work
|
| 104 |
+
nbformat @ file:///home/conda/feedstock_root/build_artifacts/nbformat_1733402752141/work
|
| 105 |
+
nersc-pymon==0.3.0
|
| 106 |
+
nest_asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1733325553580/work
|
| 107 |
+
networkx==3.4.2
|
| 108 |
+
ninja==1.11.1.3
|
| 109 |
+
notebook==7.3.3
|
| 110 |
+
notebook_shim==0.2.4
|
| 111 |
+
numba==0.61.2
|
| 112 |
+
numpy==2.2.3
|
| 113 |
+
nvidia-cublas-cu12==12.4.5.8
|
| 114 |
+
nvidia-cuda-cupti-cu12==12.4.127
|
| 115 |
+
nvidia-cuda-nvrtc-cu12==12.4.127
|
| 116 |
+
nvidia-cuda-runtime-cu12==12.4.127
|
| 117 |
+
nvidia-cudnn-cu12==9.1.0.70
|
| 118 |
+
nvidia-cufft-cu12==11.2.1.3
|
| 119 |
+
nvidia-curand-cu12==10.3.5.147
|
| 120 |
+
nvidia-cusolver-cu12==11.6.1.9
|
| 121 |
+
nvidia-cusparse-cu12==12.3.1.170
|
| 122 |
+
nvidia-cusparselt-cu12==0.6.2
|
| 123 |
+
nvidia-ml-py==12.570.86
|
| 124 |
+
nvidia-nccl-cu12==2.21.5
|
| 125 |
+
nvidia-nvjitlink-cu12==12.4.127
|
| 126 |
+
nvidia-nvtx-cu12==12.4.127
|
| 127 |
+
overrides @ file:///home/conda/feedstock_root/build_artifacts/overrides_1734587627321/work
|
| 128 |
+
packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1718189413536/work
|
| 129 |
+
pandas==2.3.0
|
| 130 |
+
pandocfilters==1.5.1
|
| 131 |
+
parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1733271261340/work
|
| 132 |
+
pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1733301927746/work
|
| 133 |
+
pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1733327343728/work
|
| 134 |
+
pillow==11.1.0
|
| 135 |
+
pkgutil_resolve_name @ file:///home/conda/feedstock_root/build_artifacts/pkgutil-resolve-name_1733344503739/work
|
| 136 |
+
platformdirs @ file:///home/conda/feedstock_root/build_artifacts/platformdirs_1715777629804/work
|
| 137 |
+
plotly==6.1.2
|
| 138 |
+
pluggy @ file:///home/conda/feedstock_root/build_artifacts/pluggy_1713667077545/work
|
| 139 |
+
prometheus_client @ file:///home/conda/feedstock_root/build_artifacts/prometheus_client_1733327310477/work
|
| 140 |
+
prompt_toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1737453357274/work
|
| 141 |
+
protobuf==5.29.3
|
| 142 |
+
psutil @ file:///home/conda/feedstock_root/build_artifacts/psutil_1740663123172/work
|
| 143 |
+
ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1733302279685/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl#sha256=92c32ff62b5fd8cf325bec5ab90d7be3d2a8ca8c8a3813ff487a8d2002630d1f
|
| 144 |
+
pure_eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1733569405015/work
|
| 145 |
+
py-cpuinfo==9.0.0
|
| 146 |
+
pycosat @ file:///home/conda/feedstock_root/build_artifacts/pycosat_1696355774225/work
|
| 147 |
+
pycparser @ file:///home/conda/feedstock_root/build_artifacts/pycparser_1711811537435/work
|
| 148 |
+
pydantic==2.10.6
|
| 149 |
+
pydantic_core==2.27.2
|
| 150 |
+
Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1736243443484/work
|
| 151 |
+
pynndescent==0.5.13
|
| 152 |
+
pyparsing==3.2.1
|
| 153 |
+
PySocks @ file:///home/conda/feedstock_root/build_artifacts/pysocks_1661604839144/work
|
| 154 |
+
python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1733215673016/work
|
| 155 |
+
python-json-logger @ file:///home/conda/feedstock_root/build_artifacts/python-json-logger_1677079630776/work
|
| 156 |
+
pytz==2025.2
|
| 157 |
+
PyYAML @ file:///home/conda/feedstock_root/build_artifacts/pyyaml_1737454647378/work
|
| 158 |
+
pyzmq @ file:///home/conda/feedstock_root/build_artifacts/pyzmq_1738270962252/work
|
| 159 |
+
referencing @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_referencing_1737836872/work
|
| 160 |
+
regex==2024.11.6
|
| 161 |
+
requests @ file:///home/conda/feedstock_root/build_artifacts/requests_1717057054362/work
|
| 162 |
+
rfc3339_validator @ file:///home/conda/feedstock_root/build_artifacts/rfc3339-validator_1733599910982/work
|
| 163 |
+
rfc3986-validator @ file:///home/conda/feedstock_root/build_artifacts/rfc3986-validator_1598024191506/work
|
| 164 |
+
rpds-py==0.23.1
|
| 165 |
+
rtree==1.4.1
|
| 166 |
+
ruamel.yaml @ file:///home/conda/feedstock_root/build_artifacts/ruamel.yaml_1707298132558/work
|
| 167 |
+
ruamel.yaml.clib @ file:///home/conda/feedstock_root/build_artifacts/ruamel.yaml.clib_1707314473810/work
|
| 168 |
+
safetensors==0.5.3
|
| 169 |
+
scikit-learn==1.6.1
|
| 170 |
+
scipy==1.15.2
|
| 171 |
+
Send2Trash @ file:///home/conda/feedstock_root/build_artifacts/send2trash_1733322040660/work
|
| 172 |
+
sentry-sdk==2.22.0
|
| 173 |
+
setproctitle==1.3.5
|
| 174 |
+
setuptools==73.0.1
|
| 175 |
+
shapely==2.1.2
|
| 176 |
+
six @ file:///home/conda/feedstock_root/build_artifacts/six_1733380938961/work
|
| 177 |
+
smmap @ file:///home/conda/feedstock_root/build_artifacts/smmap_1739781697784/work
|
| 178 |
+
sniffio @ file:///home/conda/feedstock_root/build_artifacts/sniffio_1733244044561/work
|
| 179 |
+
soupsieve==2.6
|
| 180 |
+
stack_data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1733569443808/work
|
| 181 |
+
sympy==1.13.1
|
| 182 |
+
terminado @ file:///home/conda/feedstock_root/build_artifacts/terminado_1710262609923/work
|
| 183 |
+
threadpoolctl==3.6.0
|
| 184 |
+
tinycss2 @ file:///home/conda/feedstock_root/build_artifacts/tinycss2_1729802851396/work
|
| 185 |
+
tokenizers==0.20.3
|
| 186 |
+
torch==2.6.0
|
| 187 |
+
torchaudio==2.6.0
|
| 188 |
+
torchvision==0.21.0
|
| 189 |
+
tornado @ file:///home/conda/feedstock_root/build_artifacts/tornado_1732615905931/work
|
| 190 |
+
tqdm @ file:///home/conda/feedstock_root/build_artifacts/tqdm_1722737464726/work
|
| 191 |
+
traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1733367359838/work
|
| 192 |
+
transformers==4.46.3
|
| 193 |
+
trimesh==4.8.3
|
| 194 |
+
triton==3.2.0
|
| 195 |
+
truststore @ file:///home/conda/feedstock_root/build_artifacts/truststore_1724770958874/work
|
| 196 |
+
types-python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/types-python-dateutil_1733612335562/work
|
| 197 |
+
typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1733188668063/work
|
| 198 |
+
typing_utils @ file:///home/conda/feedstock_root/build_artifacts/typing_utils_1733331286120/work
|
| 199 |
+
tzdata==2025.2
|
| 200 |
+
umap-learn==0.5.7
|
| 201 |
+
uri-template @ file:///home/conda/feedstock_root/build_artifacts/uri-template_1733323593477/work/dist
|
| 202 |
+
urllib3 @ file:///home/conda/feedstock_root/build_artifacts/urllib3_1719391292974/work
|
| 203 |
+
wandb==0.18.3
|
| 204 |
+
wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1733231326287/work
|
| 205 |
+
webcolors @ file:///home/conda/feedstock_root/build_artifacts/webcolors_1733359735138/work
|
| 206 |
+
webencodings @ file:///home/conda/feedstock_root/build_artifacts/webencodings_1733236011802/work
|
| 207 |
+
websocket-client @ file:///home/conda/feedstock_root/build_artifacts/websocket-client_1733157342724/work
|
| 208 |
+
wheel==0.44.0
|
| 209 |
+
widgetsnbextension==4.0.13
|
| 210 |
+
zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1732827521216/work
|
| 211 |
+
zstandard==0.23.0
|
setup.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from setuptools import setup
|
| 2 |
+
from setuptools import find_packages, setup, Command
|
| 3 |
+
|
| 4 |
+
with open('requirements.txt', 'r') as f:
|
| 5 |
+
requires = f.read().splitlines()
|
| 6 |
+
|
| 7 |
+
setup(
|
| 8 |
+
name='object_foundations',
|
| 9 |
+
version = 0.0,
|
| 10 |
+
install_requirements=requires,
|
| 11 |
+
packages=find_packages(exclude=["tests", "*.tests", "*.tests.*", "tests.*"]),
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
)
|
train_config.yaml
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
dataset_path: /global/cfs/cdirs/m4717/azton/galaxy-foundations/object_foundation/utils/supermock_dataset_11.2-14.json
|
| 2 |
+
mask_token: 0
|
| 3 |
+
masked_generation: false
|
| 4 |
+
masking_prob:
|
| 5 |
+
- 0.2
|
| 6 |
+
- 0.2
|
| 7 |
+
- 0.2
|
| 8 |
+
- 0.2
|
| 9 |
+
- 0.5
|
| 10 |
+
- 0.5
|
| 11 |
+
- 0.5
|
| 12 |
+
modalities:
|
| 13 |
+
- SFH
|
| 14 |
+
- SED
|
| 15 |
+
- mag_{band}_lsst
|
| 16 |
+
- mag_{band}_spherex
|
| 17 |
+
- redshift
|
| 18 |
+
- halo_mass
|
| 19 |
+
- stellar_mass
|
| 20 |
+
model_config:
|
| 21 |
+
attention_probs_dropout_prob: 0.1
|
| 22 |
+
classifier_dropout: 0.0
|
| 23 |
+
contrastive_temperature: 0.05
|
| 24 |
+
hidden_dropout_prob: 0.1
|
| 25 |
+
hidden_size: 768
|
| 26 |
+
intermediate_size: 3072
|
| 27 |
+
loss_weights:
|
| 28 |
+
contrastive:
|
| 29 |
+
rounds: 0
|
| 30 |
+
w0T:
|
| 31 |
+
- 0
|
| 32 |
+
- 0
|
| 33 |
+
masked:
|
| 34 |
+
rounds: 0
|
| 35 |
+
w0T:
|
| 36 |
+
- 0.8
|
| 37 |
+
- 3
|
| 38 |
+
smooth:
|
| 39 |
+
rounds: 0
|
| 40 |
+
w0T:
|
| 41 |
+
- 0
|
| 42 |
+
- 0.3
|
| 43 |
+
unmasked:
|
| 44 |
+
rounds: 0
|
| 45 |
+
w0T:
|
| 46 |
+
- 0.2
|
| 47 |
+
- 0.3
|
| 48 |
+
max_position_embeddings: 1149
|
| 49 |
+
num_attention_heads: 12
|
| 50 |
+
num_hidden_layers: 18
|
| 51 |
+
pad_token_id: -1
|
| 52 |
+
transform_numeric: false
|
| 53 |
+
use_contrastive_loss: false
|
| 54 |
+
use_mlm_loss: true
|
| 55 |
+
use_regression_loss: false
|
| 56 |
+
use_xval_loss: false
|
| 57 |
+
vocab_size: 2048
|
| 58 |
+
model_name_or_path: galaxybert
|
| 59 |
+
tokenizer_name_or_path: Salesforce/SFR-Embedding-Mistral
|
| 60 |
+
training_args:
|
| 61 |
+
_n_gpu: 1
|
| 62 |
+
accelerator_config:
|
| 63 |
+
dispatch_batches: null
|
| 64 |
+
even_batches: true
|
| 65 |
+
gradient_accumulation_kwargs: null
|
| 66 |
+
non_blocking: false
|
| 67 |
+
split_batches: false
|
| 68 |
+
use_configured_state: false
|
| 69 |
+
use_seedable_sampler: true
|
| 70 |
+
adafactor: false
|
| 71 |
+
adam_beta1: 0.9
|
| 72 |
+
adam_beta2: 0.999
|
| 73 |
+
adam_epsilon: 1.0e-08
|
| 74 |
+
auto_find_batch_size: false
|
| 75 |
+
average_tokens_across_devices: false
|
| 76 |
+
batch_eval_metrics: false
|
| 77 |
+
bf16: true
|
| 78 |
+
bf16_full_eval: false
|
| 79 |
+
data_seed: null
|
| 80 |
+
dataloader_drop_last: false
|
| 81 |
+
dataloader_num_workers: 16
|
| 82 |
+
dataloader_persistent_workers: false
|
| 83 |
+
dataloader_pin_memory: true
|
| 84 |
+
dataloader_prefetch_factor: 8
|
| 85 |
+
ddp_backend: null
|
| 86 |
+
ddp_broadcast_buffers: null
|
| 87 |
+
ddp_bucket_cap_mb: null
|
| 88 |
+
ddp_find_unused_parameters: null
|
| 89 |
+
ddp_timeout: 1800
|
| 90 |
+
debug: []
|
| 91 |
+
deepspeed: null
|
| 92 |
+
disable_tqdm: false
|
| 93 |
+
dispatch_batches: null
|
| 94 |
+
do_eval: true
|
| 95 |
+
do_predict: false
|
| 96 |
+
do_train: false
|
| 97 |
+
eval_accumulation_steps: 5
|
| 98 |
+
eval_delay: 0
|
| 99 |
+
eval_do_concat_batches: true
|
| 100 |
+
eval_on_start: false
|
| 101 |
+
eval_steps: 20
|
| 102 |
+
eval_strategy: !!python/object/apply:transformers.trainer_utils.IntervalStrategy
|
| 103 |
+
- steps
|
| 104 |
+
eval_use_gather_object: false
|
| 105 |
+
evaluation_strategy: null
|
| 106 |
+
fp16: false
|
| 107 |
+
fp16_backend: auto
|
| 108 |
+
fp16_full_eval: false
|
| 109 |
+
fp16_opt_level: O1
|
| 110 |
+
fsdp: []
|
| 111 |
+
fsdp_config:
|
| 112 |
+
min_num_params: 0
|
| 113 |
+
xla: false
|
| 114 |
+
xla_fsdp_grad_ckpt: false
|
| 115 |
+
xla_fsdp_v2: false
|
| 116 |
+
fsdp_min_num_params: 0
|
| 117 |
+
fsdp_transformer_layer_cls_to_wrap: null
|
| 118 |
+
full_determinism: false
|
| 119 |
+
gradient_accumulation_steps: 5
|
| 120 |
+
gradient_checkpointing: false
|
| 121 |
+
gradient_checkpointing_kwargs: null
|
| 122 |
+
greater_is_better: null
|
| 123 |
+
group_by_length: false
|
| 124 |
+
half_precision_backend: auto
|
| 125 |
+
hub_always_push: false
|
| 126 |
+
hub_model_id: null
|
| 127 |
+
hub_private_repo: false
|
| 128 |
+
hub_strategy: !!python/object/apply:transformers.trainer_utils.HubStrategy
|
| 129 |
+
- every_save
|
| 130 |
+
hub_token: null
|
| 131 |
+
ignore_data_skip: false
|
| 132 |
+
include_for_metrics: []
|
| 133 |
+
include_inputs_for_metrics: false
|
| 134 |
+
include_num_input_tokens_seen: false
|
| 135 |
+
include_tokens_per_second: false
|
| 136 |
+
jit_mode_eval: false
|
| 137 |
+
label_names: null
|
| 138 |
+
label_smoothing_factor: 0.0
|
| 139 |
+
learning_rate: 0.0001
|
| 140 |
+
length_column_name: length
|
| 141 |
+
load_best_model_at_end: false
|
| 142 |
+
local_rank: 0
|
| 143 |
+
log_level: passive
|
| 144 |
+
log_level_replica: warning
|
| 145 |
+
log_on_each_node: true
|
| 146 |
+
logging_dir: sm_foundation_lg_gmm_nomasklab
|
| 147 |
+
logging_first_step: true
|
| 148 |
+
logging_nan_inf_filter: true
|
| 149 |
+
logging_steps: 1
|
| 150 |
+
logging_strategy: !!python/object/apply:transformers.trainer_utils.IntervalStrategy
|
| 151 |
+
- steps
|
| 152 |
+
lr_scheduler_kwargs: {}
|
| 153 |
+
lr_scheduler_type: !!python/object/apply:transformers.trainer_utils.SchedulerType
|
| 154 |
+
- cosine
|
| 155 |
+
max_grad_norm: 1.0
|
| 156 |
+
max_steps: -1
|
| 157 |
+
metric_for_best_model: null
|
| 158 |
+
mp_parameters: ''
|
| 159 |
+
neftune_noise_alpha: null
|
| 160 |
+
no_cuda: false
|
| 161 |
+
num_train_epochs: 60
|
| 162 |
+
optim: !!python/object/apply:transformers.training_args.OptimizerNames
|
| 163 |
+
- adamw_torch
|
| 164 |
+
optim_args: null
|
| 165 |
+
optim_target_modules: null
|
| 166 |
+
output_dir: supermock_te60_
|
| 167 |
+
overwrite_output_dir: true
|
| 168 |
+
past_index: -1
|
| 169 |
+
per_device_eval_batch_size: 100
|
| 170 |
+
per_device_train_batch_size: 100
|
| 171 |
+
per_gpu_eval_batch_size: null
|
| 172 |
+
per_gpu_train_batch_size: null
|
| 173 |
+
prediction_loss_only: false
|
| 174 |
+
push_to_hub: false
|
| 175 |
+
push_to_hub_model_id: null
|
| 176 |
+
push_to_hub_organization: null
|
| 177 |
+
push_to_hub_token: null
|
| 178 |
+
ray_scope: last
|
| 179 |
+
remove_unused_columns: false
|
| 180 |
+
report_to:
|
| 181 |
+
- wandb
|
| 182 |
+
restore_callback_states_from_checkpoint: false
|
| 183 |
+
resume_from_checkpoint: null
|
| 184 |
+
run_name: NO_SHARD_b50
|
| 185 |
+
save_on_each_node: false
|
| 186 |
+
save_only_model: false
|
| 187 |
+
save_safetensors: true
|
| 188 |
+
save_steps: 30
|
| 189 |
+
save_strategy: !!python/object/apply:transformers.trainer_utils.IntervalStrategy
|
| 190 |
+
- steps
|
| 191 |
+
save_total_limit: 360
|
| 192 |
+
seed: 42
|
| 193 |
+
skip_memory_metrics: true
|
| 194 |
+
split_batches: null
|
| 195 |
+
tf32: null
|
| 196 |
+
torch_compile: false
|
| 197 |
+
torch_compile_backend: null
|
| 198 |
+
torch_compile_mode: null
|
| 199 |
+
torch_empty_cache_steps: null
|
| 200 |
+
torchdynamo: null
|
| 201 |
+
tpu_metrics_debug: false
|
| 202 |
+
tpu_num_cores: null
|
| 203 |
+
use_cpu: false
|
| 204 |
+
use_ipex: false
|
| 205 |
+
use_legacy_prediction_loop: false
|
| 206 |
+
use_liger_kernel: false
|
| 207 |
+
use_mps_device: false
|
| 208 |
+
warmup_ratio: 0.0
|
| 209 |
+
warmup_steps: 0
|
| 210 |
+
weight_decay: 0.1
|
| 211 |
+
transform_numeric: false
|
| 212 |
+
wandb_project: supermock-foundation-perl
|
| 213 |
+
wandb_run_name: ''
|
utils/__init__.py
ADDED
|
File without changes
|
utils/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (175 Bytes). View file
|
|
|
utils/__pycache__/masked_data_modeling_loss.cpython-312.pyc
ADDED
|
Binary file (1.62 kB). View file
|
|
|
utils/__pycache__/yaml_util.cpython-312.pyc
ADDED
|
Binary file (1.83 kB). View file
|
|
|
utils/masked_data_modeling_loss.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from einops import rearrange
|
| 3 |
+
'''
|
| 4 |
+
Simple class to do all MLM sort of loss operations in one place
|
| 5 |
+
'''
|
| 6 |
+
class MaskedDataLossWithSoftmax(torch.nn.Module):
    """Cross-entropy loss for masked-data (MLM-style) modeling.

    Wraps ``torch.nn.CrossEntropyLoss`` so that logits shaped
    ``[batch, seq_len, vocab]`` can be scored directly against integer
    labels shaped ``[batch, seq_len]``. Positions whose label equals the
    ignore index contribute nothing to the loss.
    """

    def __init__(self, ignore: int = -100, reduction: str = 'mean', weight=None):
        """
        Args:
            ignore: label value excluded from the loss (default -100).
            reduction: 'mean', 'sum', or 'none' (passed to CrossEntropyLoss).
            weight: optional per-class weight tensor.
        """
        super().__init__()
        # BUG FIX: the `ignore` argument was previously discarded and
        # ignore_index hard-coded to -100; honor the parameter (the
        # default preserves the old behavior).
        self.loss = torch.nn.CrossEntropyLoss(ignore_index=ignore,
                                              reduction=reduction,
                                              weight=weight)

    def forward(self, logits: torch.Tensor,
                labels: torch.Tensor
                ) -> torch.Tensor:
        """
        Args:
            logits: [batch_size, seq_len, vocab_size]; raw scores, softmax
                NOT applied (CrossEntropyLoss applies log-softmax itself).
            labels: [batch_size, seq_len]; entries that are not masked
                targets must hold the ignore index.

        Returns:
            The reduced cross-entropy loss.
        """
        # CrossEntropyLoss expects the class dimension second
        # ([batch, vocab, seq]); transpose does the same axis swap as the
        # previous einops rearrange without the extra dependency.
        # Defining forward() (instead of overriding __call__) keeps
        # nn.Module hooks working; instances are still called the same way.
        return self.loss(logits.transpose(1, 2), labels)
|
utils/yaml_util.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import yaml
|
| 2 |
+
class MyLoader(yaml.SafeLoader):
    """SafeLoader variant that tolerates unknown YAML tags.

    Before building each mapping it installs a catch-all constructor, so
    nodes with unrecognized tags (e.g. ``!!python/object/apply``) are
    wrapped rather than rejected by SafeLoader.
    """

    def construct_mapping(self, *args, **kwargs):
        # Register the fallback for the None (wildcard) tag, then defer
        # to SafeLoader for the actual mapping construction.
        super().add_constructor(None, construct_undefined)
        return super().construct_mapping(*args, **kwargs)
|
| 10 |
+
import typing
|
| 11 |
+
class Tagged(typing.NamedTuple):
    """A YAML node's tag paired with its constructed Python value.

    Used by ``construct_undefined`` to preserve the original tag of nodes
    whose tags have no registered constructor.
    """
    tag: str
    value: object
|
| 14 |
+
|
| 15 |
+
def construct_undefined(self, node):
    """Fallback constructor for YAML nodes whose tag has no handler.

    Builds the node's underlying Python value (scalar, sequence, or
    mapping) and wraps it together with its original tag in a ``Tagged``
    tuple so the tag information is not lost.

    Args:
        self: the active yaml Loader instance.
        node: the yaml node to construct.

    Returns:
        Tagged: (tag, constructed value).

    Raises:
        TypeError: if the node is not one of the three known node kinds.
    """
    if isinstance(node, yaml.nodes.ScalarNode):
        value = self.construct_scalar(node)
    elif isinstance(node, yaml.nodes.SequenceNode):
        value = self.construct_sequence(node)
    elif isinstance(node, yaml.nodes.MappingNode):
        value = self.construct_mapping(node)
    else:
        # BUG FIX: `assert` is stripped under `python -O`, which would let
        # an unexpected node fall through with `value` unbound; raise
        # explicitly instead.
        raise TypeError(f"unexpected node: {node!r}")
    return Tagged(node.tag, value)
|