YuanGao-YG commited on
Commit
0462de6
·
verified ·
1 Parent(s): af56035

Upload 107 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +6 -0
  2. data/.DS_Store +0 -0
  3. data/climate_mean_s_t_ssh.npy +3 -0
  4. data/land_mask.h5 +3 -0
  5. data/mean_s_t_ssh.npy +3 -0
  6. data/std_s_t_ssh.npy +3 -0
  7. data/test/.ipynb_checkpoints/test_data-checkpoint.txt +1 -0
  8. data/test/test_data.txt +1 -0
  9. data/train/.ipynb_checkpoints/train_data-checkpoint.txt +1 -0
  10. data/train/train_data.txt +1 -0
  11. data/valid/.ipynb_checkpoints/valid_data-checkpoint.txt +1 -0
  12. data/valid/valid_data.txt +1 -0
  13. environment.yml +248 -0
  14. exp/.DS_Store +0 -0
  15. exp/NeuralOM/.DS_Store +0 -0
  16. exp/NeuralOM/20250309-195251/.DS_Store +0 -0
  17. exp/NeuralOM/20250309-195251/.ipynb_checkpoints/config-checkpoint.yaml +70 -0
  18. exp/NeuralOM/20250309-195251/6_steps_finetune/.DS_Store +0 -0
  19. exp/NeuralOM/20250309-195251/6_steps_finetune/model2/.DS_Store +0 -0
  20. exp/NeuralOM/20250309-195251/6_steps_finetune/model2/10_steps_finetune/.DS_Store +0 -0
  21. exp/NeuralOM/20250309-195251/6_steps_finetune/model2/10_steps_finetune/training_checkpoints/best_ckpt.tar +3 -0
  22. exp/NeuralOM/20250309-195251/6_steps_finetune/model2/10_steps_finetune/training_checkpoints/readme.txt +1 -0
  23. exp/NeuralOM/20250309-195251/6_steps_finetune/training_checkpoints/best_ckpt.tar +3 -0
  24. exp/NeuralOM/20250309-195251/6_steps_finetune/training_checkpoints/readme.txt +1 -0
  25. exp/NeuralOM/20250309-195251/config.yaml +70 -0
  26. exp/baselines/readme.txt +1 -0
  27. img/.DS_Store +0 -0
  28. img/fig_NeuralOM.jpg +3 -0
  29. img/fig_csi.jpg +3 -0
  30. img/fig_rmse_acc.jpg +3 -0
  31. img/fig_visual.jpg +3 -0
  32. img/tab_acc_rmse.jpg +3 -0
  33. inference.py +312 -0
  34. inference.sh +13 -0
  35. my_utils/.ipynb_checkpoints/YParams-checkpoint.py +55 -0
  36. my_utils/.ipynb_checkpoints/data_loader-checkpoint.py +205 -0
  37. my_utils/.ipynb_checkpoints/logging_utils-checkpoint.py +26 -0
  38. my_utils/.ipynb_checkpoints/norm-checkpoint.py +114 -0
  39. my_utils/YParams.py +55 -0
  40. my_utils/__pycache__/YParams.cpython-310.pyc +0 -0
  41. my_utils/__pycache__/YParams.cpython-37.pyc +0 -0
  42. my_utils/__pycache__/YParams.cpython-39.pyc +0 -0
  43. my_utils/__pycache__/bicubic.cpython-310.pyc +0 -0
  44. my_utils/__pycache__/bicubic.cpython-39.pyc +0 -0
  45. my_utils/__pycache__/darcy_loss.cpython-310.pyc +0 -0
  46. my_utils/__pycache__/darcy_loss.cpython-310.pyc.70370790180304 +0 -0
  47. my_utils/__pycache__/darcy_loss.cpython-310.pyc.70373230085584 +0 -0
  48. my_utils/__pycache__/darcy_loss.cpython-310.pyc.70384414393808 +0 -0
  49. my_utils/__pycache__/darcy_loss.cpython-37.pyc +0 -0
  50. my_utils/__pycache__/darcy_loss.cpython-39.pyc +0 -0
.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ img/fig_csi.jpg filter=lfs diff=lfs merge=lfs -text
37
+ img/fig_NeuralOM.jpg filter=lfs diff=lfs merge=lfs -text
38
+ img/fig_rmse_acc.jpg filter=lfs diff=lfs merge=lfs -text
39
+ img/fig_visual.jpg filter=lfs diff=lfs merge=lfs -text
40
+ img/tab_acc_rmse.jpg filter=lfs diff=lfs merge=lfs -text
41
+ networks/__pycache__/GraphCast.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text
data/.DS_Store ADDED
Binary file (6.15 kB). View file
 
data/climate_mean_s_t_ssh.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d8cdfd60af19771408e952a2214593593452e82ccb45d877ee1310bfb5f13d85
3
+ size 36809870528
data/land_mask.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6b21a4f8d43afff2391d95f4b7afa3242da33b6680ce93607193a3a106c50927
3
+ size 96424448
data/mean_s_t_ssh.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2ed37c0e9f106ad2d403af1d6ceff5d41630a90207764200b4d6dd7aee9c9beb
3
+ size 904
data/std_s_t_ssh.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fecbc3deb38ae6ed8f98ac46126959361c59ebfb33ea46e3de0774c97291325f
3
+ size 904
data/test/.ipynb_checkpoints/test_data-checkpoint.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ Please put the test data in this folder
data/test/test_data.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ Please put the test data in this folder, the intact project is available at the Hugging Face.
data/train/.ipynb_checkpoints/train_data-checkpoint.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ Please put the train data in this folder
data/train/train_data.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ Please put the train data in this folder, the intact project is available at the Hugging Face.
data/valid/.ipynb_checkpoints/valid_data-checkpoint.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ Please put the valid data in this folder
data/valid/valid_data.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ Please put the valid data in this folder, the intact project is available at the Hugging Face.
environment.yml ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: neuralom
2
+ channels:
3
+ - pytorch
4
+ - dglteam/label/th24_cu118
5
+ - nvidia
6
+ - defaults
7
+ dependencies:
8
+ - _libgcc_mutex=0.1=main
9
+ - _openmp_mutex=5.1=1_gnu
10
+ - blas=1.0=mkl
11
+ - brotli-python=1.0.9=py310h6a678d5_8
12
+ - bzip2=1.0.8=h5eee18b_6
13
+ - ca-certificates=2024.9.24=h06a4308_0
14
+ - certifi=2024.8.30=py310h06a4308_0
15
+ - charset-normalizer=3.3.2=pyhd3eb1b0_0
16
+ - cuda-cudart=11.8.89=0
17
+ - cuda-cupti=11.8.87=0
18
+ - cuda-libraries=11.8.0=0
19
+ - cuda-nvrtc=11.8.89=0
20
+ - cuda-nvtx=11.8.86=0
21
+ - cuda-runtime=11.8.0=0
22
+ - dgl=2.4.0.th24.cu118=py310_0
23
+ - ffmpeg=4.3=hf484d3e_0
24
+ - filelock=3.13.1=py310h06a4308_0
25
+ - freetype=2.12.1=h4a9f257_0
26
+ - gmp=6.2.1=h295c915_3
27
+ - gmpy2=2.1.2=py310heeb90bb_0
28
+ - gnutls=3.6.15=he1e5248_0
29
+ - idna=3.7=py310h06a4308_0
30
+ - intel-openmp=2023.1.0=hdb19cb5_46306
31
+ - jinja2=3.1.4=py310h06a4308_0
32
+ - jpeg=9e=h5eee18b_3
33
+ - lame=3.100=h7b6447c_0
34
+ - lcms2=2.12=h3be6417_0
35
+ - ld_impl_linux-64=2.38=h1181459_1
36
+ - lerc=3.0=h295c915_0
37
+ - libcublas=11.11.3.6=0
38
+ - libcufft=10.9.0.58=0
39
+ - libcufile=1.9.1.3=0
40
+ - libcurand=10.3.5.147=0
41
+ - libcusolver=11.4.1.48=0
42
+ - libcusparse=11.7.5.86=0
43
+ - libdeflate=1.17=h5eee18b_1
44
+ - libffi=3.4.4=h6a678d5_1
45
+ - libgcc-ng=11.2.0=h1234567_1
46
+ - libgfortran-ng=11.2.0=h00389a5_1
47
+ - libgfortran5=11.2.0=h1234567_1
48
+ - libgomp=11.2.0=h1234567_1
49
+ - libiconv=1.16=h5eee18b_3
50
+ - libidn2=2.3.4=h5eee18b_0
51
+ - libjpeg-turbo=2.0.0=h9bf148f_0
52
+ - libnpp=11.8.0.86=0
53
+ - libnvjpeg=11.9.0.86=0
54
+ - libpng=1.6.39=h5eee18b_0
55
+ - libstdcxx-ng=11.2.0=h1234567_1
56
+ - libtasn1=4.19.0=h5eee18b_0
57
+ - libtiff=4.5.1=h6a678d5_0
58
+ - libunistring=0.9.10=h27cfd23_0
59
+ - libuuid=1.41.5=h5eee18b_0
60
+ - libwebp-base=1.3.2=h5eee18b_0
61
+ - llvm-openmp=14.0.6=h9e868ea_0
62
+ - lz4-c=1.9.4=h6a678d5_1
63
+ - markupsafe=2.1.3=py310h5eee18b_0
64
+ - mkl=2023.1.0=h213fc3f_46344
65
+ - mkl-service=2.4.0=py310h5eee18b_1
66
+ - mkl_fft=1.3.10=py310h5eee18b_0
67
+ - mkl_random=1.2.7=py310h1128e8f_0
68
+ - mpc=1.1.0=h10f8cd9_1
69
+ - mpfr=4.0.2=hb69a4c5_1
70
+ - mpmath=1.3.0=py310h06a4308_0
71
+ - ncurses=6.4=h6a678d5_0
72
+ - nettle=3.7.3=hbbd107a_1
73
+ - networkx=3.3=py310h06a4308_0
74
+ - numpy=1.26.4=py310h5f9d8c6_0
75
+ - numpy-base=1.26.4=py310hb5e798b_0
76
+ - openh264=2.1.1=h4ff587b_0
77
+ - openjpeg=2.5.2=he7f1fd0_0
78
+ - openssl=3.0.15=h5eee18b_0
79
+ - pillow=10.4.0=py310h5eee18b_0
80
+ - pip=24.2=py310h06a4308_0
81
+ - psutil=5.9.0=py310h5eee18b_0
82
+ - pybind11-abi=4=hd3eb1b0_1
83
+ - pysocks=1.7.1=py310h06a4308_0
84
+ - python=3.10.15=he870216_1
85
+ - pytorch=2.4.0=py3.10_cuda11.8_cudnn9.1.0_0
86
+ - pytorch-cuda=11.8=h7e8668a_5
87
+ - pytorch-mutex=1.0=cuda
88
+ - pyyaml=6.0.1=py310h5eee18b_0
89
+ - readline=8.2=h5eee18b_0
90
+ - requests=2.32.3=py310h06a4308_0
91
+ - scipy=1.13.1=py310h5f9d8c6_0
92
+ - setuptools=75.1.0=py310h06a4308_0
93
+ - sqlite=3.45.3=h5eee18b_0
94
+ - sympy=1.13.2=py310h06a4308_0
95
+ - tbb=2021.8.0=hdb19cb5_0
96
+ - tk=8.6.14=h39e8969_0
97
+ - torchaudio=2.4.0=py310_cu118
98
+ - torchtriton=3.0.0=py310
99
+ - torchvision=0.19.0=py310_cu118
100
+ - tqdm=4.66.5=py310h2f386ee_0
101
+ - typing_extensions=4.11.0=py310h06a4308_0
102
+ - urllib3=2.2.3=py310h06a4308_0
103
+ - wheel=0.44.0=py310h06a4308_0
104
+ - xz=5.4.6=h5eee18b_1
105
+ - yaml=0.2.5=h7b6447c_0
106
+ - zlib=1.2.13=h5eee18b_1
107
+ - zstd=1.5.5=hc292b87_2
108
+ - pip:
109
+ - aiobotocore==2.15.1
110
+ - aiohappyeyeballs==2.4.3
111
+ - aiohttp==3.10.8
112
+ - aioitertools==0.12.0
113
+ - aiosignal==1.3.1
114
+ - anyio==4.6.0
115
+ - argon2-cffi==23.1.0
116
+ - argon2-cffi-bindings==21.2.0
117
+ - arrow==1.3.0
118
+ - asttokens==2.4.1
119
+ - async-lru==2.0.4
120
+ - async-timeout==4.0.3
121
+ - attrs==24.2.0
122
+ - babel==2.16.0
123
+ - beautifulsoup4==4.12.3
124
+ - bleach==6.1.0
125
+ - blessed==1.20.0
126
+ - botocore==1.35.23
127
+ - cartopy==0.24.1
128
+ - cffi==1.17.1
129
+ - cftime==1.6.4.post1
130
+ - cmocean==4.0.3
131
+ - colorama==0.4.6
132
+ - comm==0.2.2
133
+ - contourpy==1.3.0
134
+ - cycler==0.12.1
135
+ - debugpy==1.8.6
136
+ - decorator==5.1.1
137
+ - defusedxml==0.7.1
138
+ - einops==0.8.0
139
+ - exceptiongroup==1.2.2
140
+ - executing==2.1.0
141
+ - fastjsonschema==2.20.0
142
+ - fonttools==4.54.1
143
+ - fqdn==1.5.1
144
+ - frozenlist==1.4.1
145
+ - fsspec==2024.9.0
146
+ - gpustat==1.1.1
147
+ - h11==0.14.0
148
+ - h5netcdf==1.4.0
149
+ - h5py==3.12.1
150
+ - httpcore==1.0.6
151
+ - httpx==0.27.2
152
+ - huggingface-hub==0.25.1
153
+ - icecream==2.1.3
154
+ - importlib-metadata==8.5.0
155
+ - ipykernel==6.29.5
156
+ - ipython==8.28.0
157
+ - isoduration==20.11.0
158
+ - jedi==0.19.1
159
+ - jmespath==1.0.1
160
+ - joblib==1.4.2
161
+ - json5==0.9.25
162
+ - jsonpointer==3.0.0
163
+ - jsonschema==4.23.0
164
+ - jsonschema-specifications==2024.10.1
165
+ - jupyter-client==8.6.3
166
+ - jupyter-core==5.7.2
167
+ - jupyter-events==0.10.0
168
+ - jupyter-lsp==2.2.5
169
+ - jupyter-server==2.14.2
170
+ - jupyter-server-terminals==0.5.3
171
+ - jupyterlab==4.2.5
172
+ - jupyterlab-pygments==0.3.0
173
+ - jupyterlab-server==2.27.3
174
+ - kiwisolver==1.4.7
175
+ - matplotlib==3.9.2
176
+ - matplotlib-inline==0.1.7
177
+ - mistune==3.0.2
178
+ - multidict==6.1.0
179
+ - nbclient==0.10.0
180
+ - nbconvert==7.16.4
181
+ - nbformat==5.10.4
182
+ - nest-asyncio==1.6.0
183
+ - netcdf4==1.7.2
184
+ - notebook==7.2.2
185
+ - notebook-shim==0.2.4
186
+ - nvfuser-cu118-torch24==0.2.9.dev20240808
187
+ - nvidia-cuda-cupti-cu11==11.8.87
188
+ - nvidia-cuda-nvrtc-cu11==11.8.89
189
+ - nvidia-cuda-runtime-cu11==11.8.89
190
+ - nvidia-ml-py==12.560.30
191
+ - nvidia-nvtx-cu11==11.8.86
192
+ - overrides==7.7.0
193
+ - packaging==24.1
194
+ - pandas==2.2.3
195
+ - pandocfilters==1.5.1
196
+ - parso==0.8.4
197
+ - pexpect==4.9.0
198
+ - platformdirs==4.3.6
199
+ - prometheus-client==0.21.0
200
+ - prompt-toolkit==3.0.48
201
+ - ptyprocess==0.7.0
202
+ - pure-eval==0.2.3
203
+ - pycparser==2.22
204
+ - pygments==2.18.0
205
+ - pyparsing==3.2.0
206
+ - pyproj==3.7.0
207
+ - pyshp==2.3.1
208
+ - python-dateutil==2.9.0.post0
209
+ - python-json-logger==2.0.7
210
+ - pytz==2024.2
211
+ - pyzmq==26.2.0
212
+ - referencing==0.35.1
213
+ - rfc3339-validator==0.1.4
214
+ - rfc3986-validator==0.1.1
215
+ - rpds-py==0.20.0
216
+ - ruamel-yaml==0.18.6
217
+ - ruamel-yaml-clib==0.2.8
218
+ - s3fs==2024.9.0
219
+ - safetensors==0.4.5
220
+ - scikit-learn==1.5.2
221
+ - send2trash==1.8.3
222
+ - shapely==2.0.6
223
+ - six==1.16.0
224
+ - sniffio==1.3.1
225
+ - soupsieve==2.6
226
+ - stack-data==0.6.3
227
+ - terminado==0.18.1
228
+ - thop==0.1.1-2209072238
229
+ - threadpoolctl==3.5.0
230
+ - timm==1.0.9
231
+ - tinycss2==1.3.0
232
+ - tomli==2.0.2
233
+ - torchsummary==1.5.1
234
+ - tornado==6.4.1
235
+ - traitlets==5.14.3
236
+ - treelib==1.7.0
237
+ - types-python-dateutil==2.9.0.20241003
238
+ - tzdata==2024.2
239
+ - uri-template==1.3.0
240
+ - wcwidth==0.2.13
241
+ - webcolors==24.8.0
242
+ - webencodings==0.5.1
243
+ - websocket-client==1.8.0
244
+ - wrapt==1.16.0
245
+ - xarray==2024.9.0
246
+ - yarl==1.13.1
247
+ - zipp==3.20.2
248
+ prefix: /miniconda3/envs/neuralom
exp/.DS_Store ADDED
Binary file (6.15 kB). View file
 
exp/NeuralOM/.DS_Store ADDED
Binary file (6.15 kB). View file
 
exp/NeuralOM/20250309-195251/.DS_Store ADDED
Binary file (6.15 kB). View file
 
exp/NeuralOM/20250309-195251/.ipynb_checkpoints/config-checkpoint.yaml ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### base config ###
2
+ # -*- coding: utf-8 -*-
3
+ full_field: &FULL_FIELD
4
+ num_data_workers: 4
5
+ dt: 1
6
+ n_history: 0
7
+ prediction_length: 41
8
+ ics_type: "default"
9
+
10
+ exp_dir: './exp'
11
+
12
+ # data
13
+ train_data_path: './data/train'
14
+ valid_data_path: './data/valid'
15
+ test_data_path: './data/test'
16
+
17
+ # land mask
18
+ land_mask: !!bool True
19
+ land_mask_path: './data/land_mask.h5'
20
+
21
+ # normalization
22
+ normalize: !!bool True
23
+ normalization: 'zscore'
24
+ global_means_path: './data/mean_s_t_ssh.npy'
25
+ global_stds_path: './data/std_s_t_ssh.npy'
26
+
27
+ # orography
28
+ orography: !!bool False
29
+
30
+ # noise
31
+ add_noise: !!bool False
32
+ noise_std: 0
33
+
34
+ # crop
35
+ crop_size_x: None
36
+ crop_size_y: None
37
+
38
+ log_to_screen: !!bool True
39
+ log_to_wandb: !!bool False
40
+ save_checkpoint: !!bool True
41
+ plot_animations: !!bool False
42
+
43
+
44
+ #############################################
45
+ NeuralOM: &NeuralOM
46
+ <<: *FULL_FIELD
47
+ nettype: 'NeuralOM'
48
+ log_to_wandb: !!bool False
49
+
50
+ # Train params
51
+ lr: 1e-3
52
+ batch_size: 32
53
+ scheduler: 'CosineAnnealingLR'
54
+
55
+ loss_channel_wise: True
56
+ loss_scale: False
57
+ use_loss_scaler_from_metnet3: True
58
+
59
+
60
+ atmos_channels: [93, 94, 95, 96]
61
+
62
+ ocean_channels: [0, 1 ,2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 ,21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92]
63
+
64
+ in_channels: [0, 1 ,2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 ,21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96]
65
+
66
+ out_channels: [0, 1 ,2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 ,21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92]
67
+
68
+
69
+ out_variables: ["S0", "S2", "S5", "S7", "S11", "S15", "S21", "S29", "S40", "S55", "S77", "S92", "S109", "S130", "S155", "S186", "S222", "S266", "S318", "S380", "S453", "S541", "S643", "U0", "U2", "U5", "U7", "U11", "U15", "U21", "U29", "U40", "U55", "U77", "U92", "U109", "U130", "U155", "U186", "U222", "U266", "U318", "U380", "U453", "U541", "U643", "V0", "V2", "V5", "V7", "V11", "V15", "V21", "V29", "V40", "V55", "V77", "V92", "V109", "V130", "V155", "V186", "V222", "V266", "V318", "V380", "V453", "V541", "V643", "T0", "T2", "T5", "T7", "T11", "T15", "T21", "T29", "T40", "T55", "T77", "T92", "T109", "T130", "T155", "T186", "T222", "T266", "T318", "T380", "T453", "T541", "T643", "SSH"]
70
+
exp/NeuralOM/20250309-195251/6_steps_finetune/.DS_Store ADDED
Binary file (6.15 kB). View file
 
exp/NeuralOM/20250309-195251/6_steps_finetune/model2/.DS_Store ADDED
Binary file (6.15 kB). View file
 
exp/NeuralOM/20250309-195251/6_steps_finetune/model2/10_steps_finetune/.DS_Store ADDED
Binary file (6.15 kB). View file
 
exp/NeuralOM/20250309-195251/6_steps_finetune/model2/10_steps_finetune/training_checkpoints/best_ckpt.tar ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f9fe78a12419997b9deaf0dd2ec912c1c936e96cc418bc507acdd2baecf908a2
3
+ size 661771939
exp/NeuralOM/20250309-195251/6_steps_finetune/model2/10_steps_finetune/training_checkpoints/readme.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ The intact project is available at the Hugging Face.
exp/NeuralOM/20250309-195251/6_steps_finetune/training_checkpoints/best_ckpt.tar ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8fb1c1827478608c96b490662f9b59351a52b1ee423883158c0a9e7c09b7d919
3
+ size 661813411
exp/NeuralOM/20250309-195251/6_steps_finetune/training_checkpoints/readme.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ The intact project is available at the Hugging Face.
exp/NeuralOM/20250309-195251/config.yaml ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### base config ###
2
+ # -*- coding: utf-8 -*-
3
+ full_field: &FULL_FIELD
4
+ num_data_workers: 4
5
+ dt: 1
6
+ n_history: 0
7
+ prediction_length: 41
8
+ ics_type: "default"
9
+
10
+ exp_dir: './exp'
11
+
12
+ # data
13
+ train_data_path: './data/train'
14
+ valid_data_path: './data/valid'
15
+ test_data_path: './data/test'
16
+
17
+ # land mask
18
+ land_mask: !!bool True
19
+ land_mask_path: './data/land_mask.h5'
20
+
21
+ # normalization
22
+ normalize: !!bool True
23
+ normalization: 'zscore'
24
+ global_means_path: './data/mean_s_t_ssh.npy'
25
+ global_stds_path: './data/std_s_t_ssh.npy'
26
+
27
+ # orography
28
+ orography: !!bool False
29
+
30
+ # noise
31
+ add_noise: !!bool False
32
+ noise_std: 0
33
+
34
+ # crop
35
+ crop_size_x: None
36
+ crop_size_y: None
37
+
38
+ log_to_screen: !!bool True
39
+ log_to_wandb: !!bool False
40
+ save_checkpoint: !!bool True
41
+ plot_animations: !!bool False
42
+
43
+
44
+ #############################################
45
+ NeuralOM: &NeuralOM
46
+ <<: *FULL_FIELD
47
+ nettype: 'NeuralOM'
48
+ log_to_wandb: !!bool False
49
+
50
+ # Train params
51
+ lr: 1e-3
52
+ batch_size: 32
53
+ scheduler: 'CosineAnnealingLR'
54
+
55
+ loss_channel_wise: True
56
+ loss_scale: False
57
+ use_loss_scaler_from_metnet3: True
58
+
59
+
60
+ atmos_channels: [93, 94, 95, 96]
61
+
62
+ ocean_channels: [0, 1 ,2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 ,21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92]
63
+
64
+ in_channels: [0, 1 ,2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 ,21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96]
65
+
66
+ out_channels: [0, 1 ,2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 ,21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92]
67
+
68
+
69
+ out_variables: ["S0", "S2", "S5", "S7", "S11", "S15", "S21", "S29", "S40", "S55", "S77", "S92", "S109", "S130", "S155", "S186", "S222", "S266", "S318", "S380", "S453", "S541", "S643", "U0", "U2", "U5", "U7", "U11", "U15", "U21", "U29", "U40", "U55", "U77", "U92", "U109", "U130", "U155", "U186", "U222", "U266", "U318", "U380", "U453", "U541", "U643", "V0", "V2", "V5", "V7", "V11", "V15", "V21", "V29", "V40", "V55", "V77", "V92", "V109", "V130", "V155", "V186", "V222", "V266", "V318", "V380", "V453", "V541", "V643", "T0", "T2", "T5", "T7", "T11", "T15", "T21", "T29", "T40", "T55", "T77", "T92", "T109", "T130", "T155", "T186", "T222", "T266", "T318", "T380", "T453", "T541", "T643", "SSH"]
70
+
exp/baselines/readme.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ The intact project is available at the Hugging Face.
img/.DS_Store ADDED
Binary file (6.15 kB). View file
 
img/fig_NeuralOM.jpg ADDED

Git LFS Details

  • SHA256: 64ddc7efd09c19612aff6391a9755428c2648aafbaa3878af0b5689f1ca917ee
  • Pointer size: 132 Bytes
  • Size of remote file: 1.28 MB
img/fig_csi.jpg ADDED

Git LFS Details

  • SHA256: 2578e3c09addf3356ecc7182d17ef9a439504bd6b919c18c2358266e45665110
  • Pointer size: 131 Bytes
  • Size of remote file: 644 kB
img/fig_rmse_acc.jpg ADDED

Git LFS Details

  • SHA256: 31273d21a43094e6502d0949f949781d515398714fa9ebdd4d1c8ef9940bac0c
  • Pointer size: 131 Bytes
  • Size of remote file: 914 kB
img/fig_visual.jpg ADDED

Git LFS Details

  • SHA256: 83becb78e355cc1fc06e76fea58df76f127af0f90133616667dbe90494336118
  • Pointer size: 132 Bytes
  • Size of remote file: 3.08 MB
img/tab_acc_rmse.jpg ADDED

Git LFS Details

  • SHA256: de7afbd5f51ef96f4f27ac4fdb3412051ad52c93a97ea72d3504ced1b053b857
  • Pointer size: 131 Bytes
  • Size of remote file: 357 kB
inference.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import time
4
+ import glob
5
+ import h5py
6
+ import logging
7
+ import argparse
8
+ import numpy as np
9
+ from icecream import ic
10
+ from datetime import datetime
11
+ from collections import OrderedDict
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.cuda.amp as amp
15
+ import torch.distributed as dist
16
+ from torch.nn.parallel import DistributedDataParallel
17
+
18
+ sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../')
19
+ from my_utils.YParams import YParams
20
+ from my_utils.data_loader import get_data_loader
21
+ from my_utils import logging_utils
22
+ logging_utils.config_logger()
23
+
24
+
25
+ def load_model(model, params, checkpoint_file):
26
+ model.zero_grad()
27
+ checkpoint_fname = checkpoint_file
28
+ checkpoint = torch.load(checkpoint_fname)
29
+ try:
30
+ new_state_dict = OrderedDict()
31
+ for key, val in checkpoint['model_state'].items():
32
+ name = key[7:]
33
+ if name != 'ged':
34
+ new_state_dict[name] = val
35
+ model.load_state_dict(new_state_dict)
36
+ except:
37
+ model.load_state_dict(checkpoint['model_state'])
38
+ model.eval()
39
+ return model
40
+
41
+ def setup(params):
42
+ device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
43
+
44
+ # get data loader
45
+ valid_data_loader, valid_dataset = get_data_loader(params, params.test_data_path, dist.is_initialized(), train=False)
46
+
47
+ img_shape_x = valid_dataset.img_shape_x
48
+ img_shape_y = valid_dataset.img_shape_y
49
+ params.img_shape_x = img_shape_x
50
+ params.img_shape_y = img_shape_y
51
+
52
+ in_channels = np.array(params.in_channels)
53
+ out_channels = np.array(params.out_channels)
54
+ n_in_channels = len(in_channels)
55
+ n_out_channels = len(out_channels)
56
+
57
+ params['N_in_channels'] = n_in_channels
58
+ params['N_out_channels'] = n_out_channels
59
+
60
+ if params.normalization == 'zscore':
61
+ params.means = np.load(params.global_means_path)
62
+ params.stds = np.load(params.global_stds_path)
63
+
64
+ if params.nettype == 'NeuralOM':
65
+ from networks.MIGNN1 import MIGraph as model
66
+ from networks.MIGNN2 import MIGraph_stage2 as model2
67
+ else:
68
+ raise Exception("not implemented")
69
+
70
+ checkpoint_file = params['best_checkpoint_path']
71
+ checkpoint_file2 = params['best_checkpoint_path2']
72
+ logging.info('Loading trained model checkpoint from {}'.format(checkpoint_file))
73
+ logging.info('Loading trained model2 checkpoint from {}'.format(checkpoint_file2))
74
+
75
+ model = model(params).to(device)
76
+ model = load_model(model, params, checkpoint_file)
77
+ model = model.to(device)
78
+
79
+ print('model is ok')
80
+
81
+ model2 = model2(params).to(device)
82
+ model2 = load_model(model2, params, checkpoint_file2)
83
+ model2 = model2.to(device)
84
+
85
+ print('model2 is ok')
86
+
87
+ files_paths = glob.glob(params.test_data_path + "/*.h5")
88
+ files_paths.sort()
89
+
90
+ # which year
91
+ yr = 0
92
+ logging.info('Loading inference data')
93
+ logging.info('Inference data from {}'.format(files_paths[yr]))
94
+ climate_mean = np.load('./data/climate_mean_s_t_ssh.npy')
95
+ valid_data_full = h5py.File(files_paths[yr], 'r')['fields'][:365, :, :, :]
96
+ valid_data_full = valid_data_full - climate_mean
97
+
98
+ return valid_data_full, model, model2
99
+
100
+
101
+ def autoregressive_inference(params, init_condition, valid_data_full, model, model2):
102
+ device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
103
+
104
+ icd = int(init_condition)
105
+
106
+ exp_dir = params['experiment_dir']
107
+ dt = int(params.dt)
108
+ prediction_length = int(params.prediction_length/dt)
109
+ n_history = params.n_history
110
+ img_shape_x = params.img_shape_x
111
+ img_shape_y = params.img_shape_y
112
+ in_channels = np.array(params.in_channels)
113
+ out_channels = np.array(params.out_channels)
114
+ atmos_channels = np.array(params.atmos_channels)
115
+ n_in_channels = len(in_channels)
116
+ n_out_channels = len(out_channels)
117
+
118
+ seq_real = torch.zeros((prediction_length, n_out_channels, img_shape_x, img_shape_y))
119
+ seq_pred = torch.zeros((prediction_length, n_out_channels, img_shape_x, img_shape_y))
120
+
121
+
122
+ valid_data = valid_data_full[icd:(icd+prediction_length*dt+n_history*dt):dt][:, params.in_channels][:,:,0:360]
123
+ logging.info(f'valid_data_full: {valid_data_full.shape}')
124
+ logging.info(f'valid_data: {valid_data.shape}')
125
+
126
+ # normalize
127
+ if params.normalization == 'zscore':
128
+ valid_data = (valid_data - params.means[:,params.in_channels])/params.stds[:,params.in_channels]
129
+ valid_data = np.nan_to_num(valid_data, nan=0)
130
+
131
+ valid_data = torch.as_tensor(valid_data)
132
+
133
+ # autoregressive inference
134
+ logging.info('Begin autoregressive inference')
135
+
136
+
137
+ with torch.no_grad():
138
+ for i in range(valid_data.shape[0]):
139
+ if i==0: # start of sequence, t0 --> t0'
140
+ first = valid_data[0:n_history+1]
141
+ ic(valid_data.shape, first.shape)
142
+ future = valid_data[n_history+1]
143
+ ic(future.shape)
144
+
145
+ for h in range(n_history+1):
146
+
147
+ seq_real[h] = first[h*n_in_channels : (h+1)*n_in_channels, :93]
148
+
149
+ seq_pred[h] = seq_real[h]
150
+
151
+ first = first.to(device, dtype=torch.float)
152
+ first_ocean = first[:, params.ocean_channels, :, :]
153
+ ic(first_ocean.shape)
154
+ future_force0 = first[:, params.atmos_channels, :, :]
155
+
156
+ future_force = future[params.atmos_channels, :360, :720]
157
+ future_force = torch.unsqueeze(future_force, dim=0).to(device, dtype=torch.float)
158
+ model_input = torch.cat((first_ocean, future_force0, future_force.cuda()), axis=1)
159
+ ic(model_input.shape)
160
+ model1_future_pred = model(model_input)
161
+ with h5py.File(params.land_mask_path, 'r') as _f:
162
+ mask_data = torch.as_tensor(_f['fields'][:,out_channels, :360, :720], dtype=bool).to(device, dtype=torch.bool)
163
+ model1_future_pred = torch.masked_fill(input=model1_future_pred, mask=~mask_data, value=0)
164
+ future_pred = model2(model1_future_pred) + model1_future_pred
165
+
166
+ else:
167
+ if i < prediction_length-1:
168
+ future0 = valid_data[n_history+i]
169
+ future = valid_data[n_history+i+1]
170
+
171
+ inf_one_step_start = time.time()
172
+ future_force0 = future0[params.atmos_channels, :360, :720]
173
+ future_force = future[params.atmos_channels, :360, :720]
174
+ future_force0 = torch.unsqueeze(future_force0, dim=0).to(device, dtype=torch.float)
175
+ future_force = torch.unsqueeze(future_force, dim=0).to(device, dtype=torch.float)
176
+ model1_future_pred = model(torch.cat((future_pred.cuda(), future_force0, future_force), axis=1)) #autoregressive step
177
+ with h5py.File(params.land_mask_path, 'r') as _f:
178
+ mask_data = torch.as_tensor(_f['fields'][:,out_channels, :360, :720], dtype=bool).to(device, dtype=torch.bool)
179
+ model1_future_pred = torch.masked_fill(input=model1_future_pred, mask=~mask_data, value=0)
180
+ future_pred = model2(model1_future_pred) + model1_future_pred
181
+ inf_one_step_time = time.time() - inf_one_step_start
182
+
183
+ logging.info(f'inference one step time: {inf_one_step_time}')
184
+
185
+ if i < prediction_length - 1: # not on the last step
186
+ with h5py.File(params.land_mask_path, 'r') as _f:
187
+ mask_data = torch.as_tensor(_f['fields'][:,out_channels, :360, :720], dtype=bool)
188
+ seq_pred[n_history+i+1] = torch.masked_fill(input=future_pred.cpu(), mask=~mask_data, value=0)
189
+ seq_real[n_history+i+1] = future[:93]
190
+ history_stack = seq_pred[i+1:i+2+n_history]
191
+
192
+ future_pred = history_stack
193
+
194
+ pred = torch.unsqueeze(seq_pred[i], 0)
195
+ tar = torch.unsqueeze(seq_real[i], 0)
196
+
197
+ with h5py.File(params.land_mask_path, 'r') as _f:
198
+ mask_data = torch.as_tensor(_f['fields'][:,out_channels, :360, :720], dtype=bool)
199
+ ic(mask_data.shape, pred.shape, tar.shape)
200
+ pred = torch.masked_fill(input=pred, mask=~mask_data, value=0)
201
+ tar = torch.masked_fill(input=tar, mask=~mask_data, value=0)
202
+
203
+ print(torch.mean((pred-tar)**2))
204
+
205
+
206
+ seq_real = seq_real * params.stds[:,params.out_channels] + params.means[:,params.out_channels]
207
+ seq_real = seq_real.numpy()
208
+ seq_pred = seq_pred * params.stds[:,params.out_channels] + params.means[:,params.out_channels]
209
+ seq_pred = seq_pred.numpy()
210
+
211
+
212
+ return (np.expand_dims(seq_real[n_history:], 0),
213
+ np.expand_dims(seq_pred[n_history:], 0),
214
+ )
215
+
216
+
217
+ if __name__ == '__main__':
218
+ parser = argparse.ArgumentParser()
219
+ parser.add_argument("--exp_dir", default='../exp_15_levels', type=str)
220
+ parser.add_argument("--config", default='full_field', type=str)
221
+ parser.add_argument("--run_num", default='00', type=str)
222
+ parser.add_argument("--prediction_length", default=61, type=int)
223
+ parser.add_argument("--finetune_dir", default='', type=str)
224
+ parser.add_argument("--ics_type", default='default', type=str)
225
+ args = parser.parse_args()
226
+
227
+ config_path = os.path.join(args.exp_dir, args.config, args.run_num, 'config.yaml')
228
+ params = YParams(config_path, args.config)
229
+
230
+ params['resuming'] = False
231
+ params['interp'] = 0
232
+ params['world_size'] = 1
233
+ params['local_rank'] = 0
234
+ params['global_batch_size'] = params.batch_size
235
+ params['prediction_length'] = args.prediction_length
236
+ params['multi_steps_finetune'] = 1
237
+
238
+ torch.cuda.set_device(0)
239
+ torch.backends.cudnn.benchmark = True
240
+
241
+ # Set up directory
242
+ if args.finetune_dir == '':
243
+ expDir = os.path.join(params.exp_dir, args.config, str(args.run_num))
244
+ else:
245
+ expDir = os.path.join(params.exp_dir, args.config, str(args.run_num), args.finetune_dir)
246
+ logging.info(f'expDir: {expDir}')
247
+ params['experiment_dir'] = expDir
248
+ params['best_checkpoint_path'] = os.path.join(expDir, 'training_checkpoints/best_ckpt.tar')
249
+ params['best_checkpoint_path2'] = os.path.join(expDir, 'model2/10_steps_finetune/training_checkpoints/best_ckpt.tar')
250
+
251
+ # set up logging
252
+ logging_utils.log_to_file(logger_name=None, log_filename=os.path.join(expDir, 'inference.log'))
253
+ logging_utils.log_versions()
254
+ params.log()
255
+
256
+ if params["ics_type"] == 'default':
257
+ ics = np.arange(0, 240, 1)
258
+ n_ics = len(ics)
259
+ print('init_condition:', ics)
260
+
261
+ logging.info("Inference for {} initial conditions".format(n_ics))
262
+
263
+ try:
264
+ autoregressive_inference_filetag = params["inference_file_tag"]
265
+ except:
266
+ autoregressive_inference_filetag = ""
267
+ if params.interp > 0:
268
+ autoregressive_inference_filetag = "_coarse"
269
+
270
+ valid_data_full, model, model2 = setup(params)
271
+
272
+
273
+ seq_pred = []
274
+ seq_real = []
275
+
276
+ # run autoregressive inference for multiple initial conditions
277
+ for i, ic_ in enumerate(ics):
278
+ logging.info("Initial condition {} of {}".format(i+1, n_ics))
279
+ seq_real, seq_pred = autoregressive_inference(params, ic_, valid_data_full, model, model2)
280
+
281
+ prediction_length = seq_real[0].shape[0]
282
+ n_out_channels = seq_real[0].shape[1]
283
+ img_shape_x = seq_real[0].shape[2]
284
+ img_shape_y = seq_real[0].shape[3]
285
+
286
+ # save predictions and loss
287
+ save_path = os.path.join(params['experiment_dir'], 'results.h5')
288
+ logging.info("Saving to {}".format(save_path))
289
+ print(f'saving to {save_path}')
290
+ if i==0:
291
+ f = h5py.File(save_path, 'w')
292
+ f.create_dataset(
293
+ "ground_truth",
294
+ data=seq_real,
295
+ maxshape=[None, prediction_length, n_out_channels, img_shape_x, img_shape_y],
296
+ dtype=np.float32)
297
+ f.create_dataset(
298
+ "predicted",
299
+ data=seq_pred,
300
+ maxshape=[None, prediction_length, n_out_channels, img_shape_x, img_shape_y],
301
+ dtype=np.float32)
302
+ f.close()
303
+ else:
304
+ f = h5py.File(save_path, 'a')
305
+
306
+ f["ground_truth"].resize((f["ground_truth"].shape[0] + 1), axis = 0)
307
+ f["ground_truth"][-1:] = seq_real
308
+
309
+ f["predicted"].resize((f["predicted"].shape[0] + 1), axis = 0)
310
+ f["predicted"][-1:] = seq_pred
311
+ f.close()
312
+
inference.sh ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ prediction_length=61 # 31
2
+
3
+ exp_dir='./exp'
4
+ config='NeuralOM'
5
+ run_num='20250309-195251'
6
+ finetune_dir='6_steps_finetune'
7
+
8
+ ics_type='default'
9
+
10
+ CUDA_VISIBLE_DEVICES=0 python inference.py --exp_dir=${exp_dir} --config=${config} --run_num=${run_num} --finetune_dir=$finetune_dir --prediction_length=${prediction_length} --ics_type=${ics_type}
11
+
12
+
13
+
my_utils/.ipynb_checkpoints/YParams-checkpoint.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+
3
+ import importlib
4
+ import sys
5
+ import os
6
+ importlib.reload(sys)
7
+
8
+ from ruamel.yaml import YAML
9
+ import logging
10
+
11
+ class YParams():
12
+ """ Yaml file parser """
13
+ def __init__(self, yaml_filename, config_name, print_params=False):
14
+ self._yaml_filename = yaml_filename
15
+ self._config_name = config_name
16
+ self.params = {}
17
+
18
+ if print_params:
19
+ print(os.system('hostname'))
20
+ print("------------------ Configuration ------------------ ", yaml_filename)
21
+
22
+ with open(yaml_filename, 'rb') as _file:
23
+ yaml = YAML().load(_file)
24
+ for key, val in yaml[config_name].items():
25
+ if print_params: print(key, val)
26
+ if val =='None': val = None
27
+
28
+ self.params[key] = val
29
+ self.__setattr__(key, val)
30
+
31
+ if print_params:
32
+ print("---------------------------------------------------")
33
+
34
+ def __getitem__(self, key):
35
+ return self.params[key]
36
+
37
+ def __setitem__(self, key, val):
38
+ self.params[key] = val
39
+ self.__setattr__(key, val)
40
+
41
+ def __contains__(self, key):
42
+ return (key in self.params)
43
+
44
+ def update_params(self, config):
45
+ for key, val in config.items():
46
+ self.params[key] = val
47
+ self.__setattr__(key, val)
48
+
49
+ def log(self):
50
+ logging.info("------------------ Configuration ------------------")
51
+ logging.info("Configuration file: "+str(self._yaml_filename))
52
+ logging.info("Configuration name: "+str(self._config_name))
53
+ for key, val in self.params.items():
54
+ logging.info(str(key) + ' ' + str(val))
55
+ logging.info("---------------------------------------------------")
my_utils/.ipynb_checkpoints/data_loader-checkpoint.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import glob
3
+ import torch
4
+ import random
5
+ import numpy as np
6
+ from torch.utils.data import DataLoader, Dataset
7
+ from torch.utils.data.distributed import DistributedSampler
8
+ from torch import Tensor
9
+ import h5py
10
+ import math
11
+ from my_utils.norm import reshape_fields
12
+ import os
13
+
14
+
15
+ current_dir = os.path.dirname(os.path.abspath(__file__))
16
+ parent_dir = os.path.dirname(current_dir)
17
+ climate_mean_path = os.path.join(parent_dir, 'data/climate_mean_s_t_ssh.npy')
18
+
19
+ def get_data_loader(params, files_pattern, distributed, train):
20
+ dataset = GetDataset(params, files_pattern, train)
21
+ sampler = DistributedSampler(dataset, shuffle=train) if distributed else None
22
+
23
+
24
+ dataloader = DataLoader(dataset,
25
+ batch_size = int(params.batch_size),
26
+ num_workers = params.num_data_workers,
27
+ shuffle = False,
28
+ sampler = sampler if train else None,
29
+ drop_last = True,
30
+ pin_memory = True)
31
+
32
+ if train:
33
+ return dataloader, dataset, sampler
34
+ else:
35
+ return dataloader, dataset
36
+
37
+
38
+ class GetDataset(Dataset):
39
+ def __init__(self, params, location, train):
40
+ self.params = params
41
+ self.location = location
42
+ self.train = train
43
+ self.orography = params.orography
44
+ self.normalize = params.normalize
45
+ self.dt = params.dt
46
+ self.n_history = params.n_history
47
+ self.in_channels = np.array(params.in_channels)
48
+ self.out_channels = np.array(params.out_channels)
49
+ self.ocean_channels = np.array(params.ocean_channels)
50
+ self.atmos_channels = np.array(params.atmos_channels)
51
+ self.n_in_channels = len(self.in_channels)
52
+ self.n_out_channels = len(self.out_channels)
53
+
54
+ self._get_files_stats()
55
+ self.add_noise = params.add_noise if train else False
56
+ self.climate_mean = np.load(climate_mean_path, mmap_mode='r')
57
+
58
+
59
+ def _get_files_stats(self):
60
+ self.files_paths = glob.glob(self.location + "/*.h5")
61
+ self.files_paths.sort()
62
+ self.n_years = len(self.files_paths)
63
+
64
+ with h5py.File(self.files_paths[0], 'r') as _f:
65
+ logging.info("Getting file stats from {}".format(self.files_paths[0]))
66
+
67
+ self.n_samples_per_year = _f['fields'].shape[0] - self.params.multi_steps_finetune
68
+
69
+ self.img_shape_x = _f['fields'].shape[2] - 1
70
+ self.img_shape_y = _f['fields'].shape[3]
71
+
72
+ self.n_samples_total = self.n_years * self.n_samples_per_year
73
+ self.files = [None for _ in range(self.n_years)]
74
+
75
+ logging.info("Number of samples per year: {}".format(self.n_samples_per_year))
76
+ logging.info("Found data at path {}. Number of examples: {}. Image Shape: {} x {} x {}".format(self.location,
77
+ self.n_samples_total,
78
+ self.img_shape_x,
79
+ self.img_shape_y,
80
+ self.n_in_channels))
81
+ logging.info("Delta t: {} days".format(1 * self.dt))
82
+ logging.info("Including {} days of past history in training at a frequency of {} days".format(
83
+ 1 * self.dt * self.n_history, 1 * self.dt))
84
+
85
+ def _open_file(self, year_idx):
86
+ _file = h5py.File(self.files_paths[year_idx], 'r')
87
+ self.files[year_idx] = _file['fields']
88
+
89
+ if self.orography and self.params.normalization == 'zscore':
90
+ _orog_file = h5py.File(self.params.orography_norm_zscore_path, 'r')
91
+ if self.orography and self.params.normalization == 'maxmin':
92
+ _orog_file = h5py.File(self.params.orography_norm_maxmin_path, 'r')
93
+
94
+ def __len__(self):
95
+ return self.n_samples_total
96
+
97
+ def __getitem__(self, global_idx):
98
+ year_idx = int(global_idx / self.n_samples_per_year) # which year
99
+ local_idx = int(global_idx % self.n_samples_per_year) # which sample in a year
100
+
101
+ if self.files[year_idx] is None:
102
+ self._open_file(year_idx)
103
+
104
+ if local_idx < self.dt * self.n_history:
105
+ local_idx += self.dt * self.n_history
106
+
107
+ step = 0 if local_idx >= self.n_samples_per_year - self.dt else self.dt
108
+
109
+ orog = None
110
+
111
+
112
+ if self.params.multi_steps_finetune == 1:
113
+ if local_idx == 365:
114
+ local_idx = 364
115
+
116
+ climate_mean_ocean = self.climate_mean[(local_idx-self.dt*self.n_history):(local_idx+1):self.dt, self.ocean_channels, :360, :720]
117
+ ocean = reshape_fields(
118
+ self.files[year_idx][(local_idx-self.dt*self.n_history):(local_idx+1):self.dt, self.ocean_channels, :360, :720] - climate_mean_ocean,
119
+ 'ocean',
120
+ self.params,
121
+ self.train,
122
+ self.normalize,
123
+ orog,
124
+ self.add_noise
125
+ )
126
+
127
+ force_future0 = reshape_fields(
128
+ self.files[year_idx][local_idx, self.atmos_channels, :360, :720],
129
+ 'force',
130
+ self.params,
131
+ self.train,
132
+ self.normalize,
133
+ orog,
134
+ self.add_noise
135
+ )
136
+
137
+ force_future1 = reshape_fields(
138
+ self.files[year_idx][local_idx+step, self.atmos_channels, :360, :720],
139
+ 'force',
140
+ self.params,
141
+ self.train,
142
+ self.normalize,
143
+ orog,
144
+ self.add_noise
145
+ )
146
+
147
+ climate_mean_tar = self.climate_mean[local_idx+step, self.out_channels, :360, :720]
148
+ tar = reshape_fields(
149
+ self.files[year_idx][local_idx+step, self.out_channels, :360, :720] - climate_mean_tar,
150
+ 'tar',
151
+ self.params,
152
+ self.train,
153
+ self.normalize,
154
+ orog
155
+ )
156
+ else:
157
+ climate_mean_ocean = self.climate_mean[(local_idx-self.dt*self.n_history):(local_idx+1):self.dt, self.ocean_channels, :360, :720]
158
+ ocean = reshape_fields(
159
+ self.files[year_idx][(local_idx-self.dt*self.n_history):(local_idx+1):self.dt, self.ocean_channels, :360, :720] - climate_mean_ocean,
160
+ 'ocean',
161
+ self.params,
162
+ self.train,
163
+ self.normalize,
164
+ orog,
165
+ self.add_noise
166
+ )
167
+
168
+ force_future0 = reshape_fields(
169
+ self.files[year_idx][local_idx, self.atmos_channels, :360, :720],
170
+ 'force',
171
+ self.params,
172
+ self.train,
173
+ self.normalize,
174
+ orog,
175
+ self.add_noise
176
+ )
177
+
178
+ force_future1 = reshape_fields(
179
+ self.files[year_idx][local_idx+step, self.atmos_channels, :360, :720],
180
+ 'force',
181
+ self.params,
182
+ self.train,
183
+ self.normalize,
184
+ orog,
185
+ self.add_noise
186
+ )
187
+
188
+ climate_mean_tar = self.climate_mean[local_idx+step:local_idx+step+self.params.multi_steps_finetune, self.in_channels, :360, :720]
189
+ tar_data = self.files[year_idx][local_idx+step:local_idx+step+self.params.multi_steps_finetune, self.in_channels, :360, :720]
190
+ tar = reshape_fields(
191
+ tar_data - climate_mean_tar,
192
+ 'inp',
193
+ self.params,
194
+ self.train,
195
+ self.normalize,
196
+ orog
197
+ )
198
+
199
+ ocean = np.nan_to_num(ocean, nan=0)
200
+ force_future0 = np.nan_to_num(force_future0, nan=0)
201
+ force_future1 = np.nan_to_num(force_future1, nan=0)
202
+ tar = np.nan_to_num(tar, nan=0)
203
+
204
+
205
+ return np.concatenate((ocean, force_future0, force_future1), axis=0), tar
my_utils/.ipynb_checkpoints/logging_utils-checkpoint.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+
4
+ _format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
5
+
6
+ def config_logger(log_level=logging.INFO):
7
+ logging.basicConfig(format=_format, level=log_level)
8
+
9
+ def log_to_file(logger_name=None, log_level=logging.INFO, log_filename='tensorflow.log'):
10
+
11
+ if not os.path.exists(os.path.dirname(log_filename)):
12
+ os.makedirs(os.path.dirname(log_filename))
13
+
14
+ if logger_name is not None:
15
+ log = logging.getLogger(logger_name)
16
+ else:
17
+ log = logging.getLogger()
18
+
19
+ fh = logging.FileHandler(log_filename)
20
+ fh.setLevel(log_level)
21
+ fh.setFormatter(logging.Formatter(_format))
22
+ log.addHandler(fh)
23
+
24
+ def log_versions():
25
+ import torch
26
+ import subprocess
my_utils/.ipynb_checkpoints/norm-checkpoint.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import glob
3
+ from types import new_class
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import random
8
+ import numpy as np
9
+ import torch
10
+ from torch.utils.data import DataLoader, Dataset
11
+ from torch.utils.data.distributed import DistributedSampler
12
+ from torch import Tensor
13
+ import h5py
14
+ import math
15
+ import torchvision.transforms.functional as TF
16
+ # import matplotlib
17
+ # import matplotlib.pyplot as plt
18
+
19
+ class PeriodicPad2d(nn.Module):
20
+ """
21
+ pad longitudinal (left-right) circular
22
+ and pad latitude (top-bottom) with zeros
23
+ """
24
+ def __init__(self, pad_width):
25
+ super(PeriodicPad2d, self).__init__()
26
+ self.pad_width = pad_width
27
+
28
+ def forward(self, x):
29
+ # pad left and right circular
30
+ out = F.pad(x, (self.pad_width, self.pad_width, 0, 0), mode="circular")
31
+ # pad top and bottom zeros
32
+ out = F.pad(out, (0, 0, self.pad_width, self.pad_width), mode="constant", value=0)
33
+ return out
34
+
35
+ def reshape_fields(img, inp_or_tar, params, train, normalize=True, orog=None, add_noise=False):
36
+ # Takes in np array of size (n_history+1, c, h, w)
37
+ # returns torch tensor of size ((n_channels*(n_history+1), crop_size_x, crop_size_y)
38
+
39
+ if len(np.shape(img)) == 3:
40
+ img = np.expand_dims(img, 0)
41
+
42
+ if np.shape(img)[2] == 721:
43
+ img = img[:,:, 0:720, :] # remove last pixel
44
+
45
+ n_history = np.shape(img)[0] - 1
46
+ img_shape_x = np.shape(img)[-2]
47
+ img_shape_y = np.shape(img)[-1]
48
+ n_channels = np.shape(img)[1] # this will either be N_in_channels or N_out_channels
49
+
50
+ if inp_or_tar == 'inp':
51
+ channels = params.in_channels
52
+ elif inp_or_tar == 'ocean':
53
+ channels = params.ocean_channels
54
+ elif inp_or_tar == 'force':
55
+ channels = params.atmos_channels
56
+ else:
57
+ channels = params.out_channels
58
+
59
+ if normalize and params.normalization == 'minmax':
60
+ maxs = np.load(params.global_maxs_path)[:, channels]
61
+ mins = np.load(params.global_mins_path)[:, channels]
62
+ img = (img - mins) / (maxs - mins)
63
+
64
+ if normalize and params.normalization == 'zscore':
65
+ means = np.load(params.global_means_path)[:, channels]
66
+ stds = np.load(params.global_stds_path)[:, channels]
67
+ img -=means
68
+ img /=stds
69
+
70
+ if normalize and params.normalization == 'zscore_lat':
71
+ means = np.load(params.global_lat_means_path)[:, channels,:720]
72
+ stds = np.load(params.global_lat_stds_path)[:, channels,:720]
73
+ img -=means
74
+ img /=stds
75
+
76
+ if params.orography and inp_or_tar == 'inp':
77
+ # print('img:', img.shape, 'orog:', orog.shape)
78
+ orog = np.expand_dims(orog, axis = (0,1))
79
+ orog = np.repeat(orog, repeats=img.shape[0], axis=0)
80
+ # print('img:', img.shape, 'orog:', orog.shape)
81
+ img = np.concatenate((img, orog), axis = 1)
82
+ n_channels += 1
83
+
84
+ img = np.squeeze(img)
85
+ # if inp_or_tar == 'inp':
86
+ # img = np.reshape(img, (n_channels*(n_history+1))) # ??
87
+ # elif inp_or_tar == 'tar':
88
+ # img = np.reshape(img, (n_channels, crop_size_x, crop_size_y)) #??
89
+
90
+ if add_noise:
91
+ img = img + np.random.normal(0, scale=params.noise_std, size=img.shape)
92
+
93
+ return torch.as_tensor(img)
94
+
95
+ def vis_precip(fields):
96
+ pred, tar = fields
97
+ fig, ax = plt.subplots(1, 2, figsize=(24,12))
98
+ ax[0].imshow(pred, cmap="coolwarm")
99
+ ax[0].set_title("tp pred")
100
+ ax[1].imshow(tar, cmap="coolwarm")
101
+ ax[1].set_title("tp tar")
102
+ fig.tight_layout()
103
+ return fig
104
+
105
+ def read_max_min_value(min_max_val_file_path):
106
+ with h5py.File(min_max_val_file_path, 'r') as f:
107
+ max_values = f['max_values']
108
+ min_values = f['min_values']
109
+ return max_values, min_values
110
+
111
+
112
+
113
+
114
+
my_utils/YParams.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+
3
+ import importlib
4
+ import sys
5
+ import os
6
+ importlib.reload(sys)
7
+
8
+ from ruamel.yaml import YAML
9
+ import logging
10
+
11
+ class YParams():
12
+ """ Yaml file parser """
13
+ def __init__(self, yaml_filename, config_name, print_params=False):
14
+ self._yaml_filename = yaml_filename
15
+ self._config_name = config_name
16
+ self.params = {}
17
+
18
+ if print_params:
19
+ print(os.system('hostname'))
20
+ print("------------------ Configuration ------------------ ", yaml_filename)
21
+
22
+ with open(yaml_filename, 'rb') as _file:
23
+ yaml = YAML().load(_file)
24
+ for key, val in yaml[config_name].items():
25
+ if print_params: print(key, val)
26
+ if val =='None': val = None
27
+
28
+ self.params[key] = val
29
+ self.__setattr__(key, val)
30
+
31
+ if print_params:
32
+ print("---------------------------------------------------")
33
+
34
+ def __getitem__(self, key):
35
+ return self.params[key]
36
+
37
+ def __setitem__(self, key, val):
38
+ self.params[key] = val
39
+ self.__setattr__(key, val)
40
+
41
+ def __contains__(self, key):
42
+ return (key in self.params)
43
+
44
+ def update_params(self, config):
45
+ for key, val in config.items():
46
+ self.params[key] = val
47
+ self.__setattr__(key, val)
48
+
49
+ def log(self):
50
+ logging.info("------------------ Configuration ------------------")
51
+ logging.info("Configuration file: "+str(self._yaml_filename))
52
+ logging.info("Configuration name: "+str(self._config_name))
53
+ for key, val in self.params.items():
54
+ logging.info(str(key) + ' ' + str(val))
55
+ logging.info("---------------------------------------------------")
my_utils/__pycache__/YParams.cpython-310.pyc ADDED
Binary file (2.12 kB). View file
 
my_utils/__pycache__/YParams.cpython-37.pyc ADDED
Binary file (2.11 kB). View file
 
my_utils/__pycache__/YParams.cpython-39.pyc ADDED
Binary file (2.08 kB). View file
 
my_utils/__pycache__/bicubic.cpython-310.pyc ADDED
Binary file (9.24 kB). View file
 
my_utils/__pycache__/bicubic.cpython-39.pyc ADDED
Binary file (9.2 kB). View file
 
my_utils/__pycache__/darcy_loss.cpython-310.pyc ADDED
Binary file (13.7 kB). View file
 
my_utils/__pycache__/darcy_loss.cpython-310.pyc.70370790180304 ADDED
Binary file (13.5 kB). View file
 
my_utils/__pycache__/darcy_loss.cpython-310.pyc.70373230085584 ADDED
Binary file (13.5 kB). View file
 
my_utils/__pycache__/darcy_loss.cpython-310.pyc.70384414393808 ADDED
Binary file (13.5 kB). View file
 
my_utils/__pycache__/darcy_loss.cpython-37.pyc ADDED
Binary file (9.02 kB). View file
 
my_utils/__pycache__/darcy_loss.cpython-39.pyc ADDED
Binary file (14.2 kB). View file