Upload 107 files
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +6 -0
- data/.DS_Store +0 -0
- data/climate_mean_s_t_ssh.npy +3 -0
- data/land_mask.h5 +3 -0
- data/mean_s_t_ssh.npy +3 -0
- data/std_s_t_ssh.npy +3 -0
- data/test/.ipynb_checkpoints/test_data-checkpoint.txt +1 -0
- data/test/test_data.txt +1 -0
- data/train/.ipynb_checkpoints/train_data-checkpoint.txt +1 -0
- data/train/train_data.txt +1 -0
- data/valid/.ipynb_checkpoints/valid_data-checkpoint.txt +1 -0
- data/valid/valid_data.txt +1 -0
- environment.yml +248 -0
- exp/.DS_Store +0 -0
- exp/NeuralOM/.DS_Store +0 -0
- exp/NeuralOM/20250309-195251/.DS_Store +0 -0
- exp/NeuralOM/20250309-195251/.ipynb_checkpoints/config-checkpoint.yaml +70 -0
- exp/NeuralOM/20250309-195251/6_steps_finetune/.DS_Store +0 -0
- exp/NeuralOM/20250309-195251/6_steps_finetune/model2/.DS_Store +0 -0
- exp/NeuralOM/20250309-195251/6_steps_finetune/model2/10_steps_finetune/.DS_Store +0 -0
- exp/NeuralOM/20250309-195251/6_steps_finetune/model2/10_steps_finetune/training_checkpoints/best_ckpt.tar +3 -0
- exp/NeuralOM/20250309-195251/6_steps_finetune/model2/10_steps_finetune/training_checkpoints/readme.txt +1 -0
- exp/NeuralOM/20250309-195251/6_steps_finetune/training_checkpoints/best_ckpt.tar +3 -0
- exp/NeuralOM/20250309-195251/6_steps_finetune/training_checkpoints/readme.txt +1 -0
- exp/NeuralOM/20250309-195251/config.yaml +70 -0
- exp/baselines/readme.txt +1 -0
- img/.DS_Store +0 -0
- img/fig_NeuralOM.jpg +3 -0
- img/fig_csi.jpg +3 -0
- img/fig_rmse_acc.jpg +3 -0
- img/fig_visual.jpg +3 -0
- img/tab_acc_rmse.jpg +3 -0
- inference.py +312 -0
- inference.sh +13 -0
- my_utils/.ipynb_checkpoints/YParams-checkpoint.py +55 -0
- my_utils/.ipynb_checkpoints/data_loader-checkpoint.py +205 -0
- my_utils/.ipynb_checkpoints/logging_utils-checkpoint.py +26 -0
- my_utils/.ipynb_checkpoints/norm-checkpoint.py +114 -0
- my_utils/YParams.py +55 -0
- my_utils/__pycache__/YParams.cpython-310.pyc +0 -0
- my_utils/__pycache__/YParams.cpython-37.pyc +0 -0
- my_utils/__pycache__/YParams.cpython-39.pyc +0 -0
- my_utils/__pycache__/bicubic.cpython-310.pyc +0 -0
- my_utils/__pycache__/bicubic.cpython-39.pyc +0 -0
- my_utils/__pycache__/darcy_loss.cpython-310.pyc +0 -0
- my_utils/__pycache__/darcy_loss.cpython-310.pyc.70370790180304 +0 -0
- my_utils/__pycache__/darcy_loss.cpython-310.pyc.70373230085584 +0 -0
- my_utils/__pycache__/darcy_loss.cpython-310.pyc.70384414393808 +0 -0
- my_utils/__pycache__/darcy_loss.cpython-37.pyc +0 -0
- my_utils/__pycache__/darcy_loss.cpython-39.pyc +0 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
img/fig_csi.jpg filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
img/fig_NeuralOM.jpg filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
img/fig_rmse_acc.jpg filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
img/fig_visual.jpg filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
img/tab_acc_rmse.jpg filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
networks/__pycache__/GraphCast.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text
|
data/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
data/climate_mean_s_t_ssh.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d8cdfd60af19771408e952a2214593593452e82ccb45d877ee1310bfb5f13d85
|
| 3 |
+
size 36809870528
|
data/land_mask.h5
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6b21a4f8d43afff2391d95f4b7afa3242da33b6680ce93607193a3a106c50927
|
| 3 |
+
size 96424448
|
data/mean_s_t_ssh.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2ed37c0e9f106ad2d403af1d6ceff5d41630a90207764200b4d6dd7aee9c9beb
|
| 3 |
+
size 904
|
data/std_s_t_ssh.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fecbc3deb38ae6ed8f98ac46126959361c59ebfb33ea46e3de0774c97291325f
|
| 3 |
+
size 904
|
data/test/.ipynb_checkpoints/test_data-checkpoint.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
Please put the test data in this folder
|
data/test/test_data.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
Please put the test data in this folder, the intact project is available at the Hugging Face.
|
data/train/.ipynb_checkpoints/train_data-checkpoint.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
Please put the train data in this folder
|
data/train/train_data.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
Please put the train data in this folder, the intact project is available at the Hugging Face.
|
data/valid/.ipynb_checkpoints/valid_data-checkpoint.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
Please put the valid data in this folder
|
data/valid/valid_data.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
Please put the valid data in this folder, the intact project is available at the Hugging Face.
|
environment.yml
ADDED
|
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: neuralom
|
| 2 |
+
channels:
|
| 3 |
+
- pytorch
|
| 4 |
+
- dglteam/label/th24_cu118
|
| 5 |
+
- nvidia
|
| 6 |
+
- defaults
|
| 7 |
+
dependencies:
|
| 8 |
+
- _libgcc_mutex=0.1=main
|
| 9 |
+
- _openmp_mutex=5.1=1_gnu
|
| 10 |
+
- blas=1.0=mkl
|
| 11 |
+
- brotli-python=1.0.9=py310h6a678d5_8
|
| 12 |
+
- bzip2=1.0.8=h5eee18b_6
|
| 13 |
+
- ca-certificates=2024.9.24=h06a4308_0
|
| 14 |
+
- certifi=2024.8.30=py310h06a4308_0
|
| 15 |
+
- charset-normalizer=3.3.2=pyhd3eb1b0_0
|
| 16 |
+
- cuda-cudart=11.8.89=0
|
| 17 |
+
- cuda-cupti=11.8.87=0
|
| 18 |
+
- cuda-libraries=11.8.0=0
|
| 19 |
+
- cuda-nvrtc=11.8.89=0
|
| 20 |
+
- cuda-nvtx=11.8.86=0
|
| 21 |
+
- cuda-runtime=11.8.0=0
|
| 22 |
+
- dgl=2.4.0.th24.cu118=py310_0
|
| 23 |
+
- ffmpeg=4.3=hf484d3e_0
|
| 24 |
+
- filelock=3.13.1=py310h06a4308_0
|
| 25 |
+
- freetype=2.12.1=h4a9f257_0
|
| 26 |
+
- gmp=6.2.1=h295c915_3
|
| 27 |
+
- gmpy2=2.1.2=py310heeb90bb_0
|
| 28 |
+
- gnutls=3.6.15=he1e5248_0
|
| 29 |
+
- idna=3.7=py310h06a4308_0
|
| 30 |
+
- intel-openmp=2023.1.0=hdb19cb5_46306
|
| 31 |
+
- jinja2=3.1.4=py310h06a4308_0
|
| 32 |
+
- jpeg=9e=h5eee18b_3
|
| 33 |
+
- lame=3.100=h7b6447c_0
|
| 34 |
+
- lcms2=2.12=h3be6417_0
|
| 35 |
+
- ld_impl_linux-64=2.38=h1181459_1
|
| 36 |
+
- lerc=3.0=h295c915_0
|
| 37 |
+
- libcublas=11.11.3.6=0
|
| 38 |
+
- libcufft=10.9.0.58=0
|
| 39 |
+
- libcufile=1.9.1.3=0
|
| 40 |
+
- libcurand=10.3.5.147=0
|
| 41 |
+
- libcusolver=11.4.1.48=0
|
| 42 |
+
- libcusparse=11.7.5.86=0
|
| 43 |
+
- libdeflate=1.17=h5eee18b_1
|
| 44 |
+
- libffi=3.4.4=h6a678d5_1
|
| 45 |
+
- libgcc-ng=11.2.0=h1234567_1
|
| 46 |
+
- libgfortran-ng=11.2.0=h00389a5_1
|
| 47 |
+
- libgfortran5=11.2.0=h1234567_1
|
| 48 |
+
- libgomp=11.2.0=h1234567_1
|
| 49 |
+
- libiconv=1.16=h5eee18b_3
|
| 50 |
+
- libidn2=2.3.4=h5eee18b_0
|
| 51 |
+
- libjpeg-turbo=2.0.0=h9bf148f_0
|
| 52 |
+
- libnpp=11.8.0.86=0
|
| 53 |
+
- libnvjpeg=11.9.0.86=0
|
| 54 |
+
- libpng=1.6.39=h5eee18b_0
|
| 55 |
+
- libstdcxx-ng=11.2.0=h1234567_1
|
| 56 |
+
- libtasn1=4.19.0=h5eee18b_0
|
| 57 |
+
- libtiff=4.5.1=h6a678d5_0
|
| 58 |
+
- libunistring=0.9.10=h27cfd23_0
|
| 59 |
+
- libuuid=1.41.5=h5eee18b_0
|
| 60 |
+
- libwebp-base=1.3.2=h5eee18b_0
|
| 61 |
+
- llvm-openmp=14.0.6=h9e868ea_0
|
| 62 |
+
- lz4-c=1.9.4=h6a678d5_1
|
| 63 |
+
- markupsafe=2.1.3=py310h5eee18b_0
|
| 64 |
+
- mkl=2023.1.0=h213fc3f_46344
|
| 65 |
+
- mkl-service=2.4.0=py310h5eee18b_1
|
| 66 |
+
- mkl_fft=1.3.10=py310h5eee18b_0
|
| 67 |
+
- mkl_random=1.2.7=py310h1128e8f_0
|
| 68 |
+
- mpc=1.1.0=h10f8cd9_1
|
| 69 |
+
- mpfr=4.0.2=hb69a4c5_1
|
| 70 |
+
- mpmath=1.3.0=py310h06a4308_0
|
| 71 |
+
- ncurses=6.4=h6a678d5_0
|
| 72 |
+
- nettle=3.7.3=hbbd107a_1
|
| 73 |
+
- networkx=3.3=py310h06a4308_0
|
| 74 |
+
- numpy=1.26.4=py310h5f9d8c6_0
|
| 75 |
+
- numpy-base=1.26.4=py310hb5e798b_0
|
| 76 |
+
- openh264=2.1.1=h4ff587b_0
|
| 77 |
+
- openjpeg=2.5.2=he7f1fd0_0
|
| 78 |
+
- openssl=3.0.15=h5eee18b_0
|
| 79 |
+
- pillow=10.4.0=py310h5eee18b_0
|
| 80 |
+
- pip=24.2=py310h06a4308_0
|
| 81 |
+
- psutil=5.9.0=py310h5eee18b_0
|
| 82 |
+
- pybind11-abi=4=hd3eb1b0_1
|
| 83 |
+
- pysocks=1.7.1=py310h06a4308_0
|
| 84 |
+
- python=3.10.15=he870216_1
|
| 85 |
+
- pytorch=2.4.0=py3.10_cuda11.8_cudnn9.1.0_0
|
| 86 |
+
- pytorch-cuda=11.8=h7e8668a_5
|
| 87 |
+
- pytorch-mutex=1.0=cuda
|
| 88 |
+
- pyyaml=6.0.1=py310h5eee18b_0
|
| 89 |
+
- readline=8.2=h5eee18b_0
|
| 90 |
+
- requests=2.32.3=py310h06a4308_0
|
| 91 |
+
- scipy=1.13.1=py310h5f9d8c6_0
|
| 92 |
+
- setuptools=75.1.0=py310h06a4308_0
|
| 93 |
+
- sqlite=3.45.3=h5eee18b_0
|
| 94 |
+
- sympy=1.13.2=py310h06a4308_0
|
| 95 |
+
- tbb=2021.8.0=hdb19cb5_0
|
| 96 |
+
- tk=8.6.14=h39e8969_0
|
| 97 |
+
- torchaudio=2.4.0=py310_cu118
|
| 98 |
+
- torchtriton=3.0.0=py310
|
| 99 |
+
- torchvision=0.19.0=py310_cu118
|
| 100 |
+
- tqdm=4.66.5=py310h2f386ee_0
|
| 101 |
+
- typing_extensions=4.11.0=py310h06a4308_0
|
| 102 |
+
- urllib3=2.2.3=py310h06a4308_0
|
| 103 |
+
- wheel=0.44.0=py310h06a4308_0
|
| 104 |
+
- xz=5.4.6=h5eee18b_1
|
| 105 |
+
- yaml=0.2.5=h7b6447c_0
|
| 106 |
+
- zlib=1.2.13=h5eee18b_1
|
| 107 |
+
- zstd=1.5.5=hc292b87_2
|
| 108 |
+
- pip:
|
| 109 |
+
- aiobotocore==2.15.1
|
| 110 |
+
- aiohappyeyeballs==2.4.3
|
| 111 |
+
- aiohttp==3.10.8
|
| 112 |
+
- aioitertools==0.12.0
|
| 113 |
+
- aiosignal==1.3.1
|
| 114 |
+
- anyio==4.6.0
|
| 115 |
+
- argon2-cffi==23.1.0
|
| 116 |
+
- argon2-cffi-bindings==21.2.0
|
| 117 |
+
- arrow==1.3.0
|
| 118 |
+
- asttokens==2.4.1
|
| 119 |
+
- async-lru==2.0.4
|
| 120 |
+
- async-timeout==4.0.3
|
| 121 |
+
- attrs==24.2.0
|
| 122 |
+
- babel==2.16.0
|
| 123 |
+
- beautifulsoup4==4.12.3
|
| 124 |
+
- bleach==6.1.0
|
| 125 |
+
- blessed==1.20.0
|
| 126 |
+
- botocore==1.35.23
|
| 127 |
+
- cartopy==0.24.1
|
| 128 |
+
- cffi==1.17.1
|
| 129 |
+
- cftime==1.6.4.post1
|
| 130 |
+
- cmocean==4.0.3
|
| 131 |
+
- colorama==0.4.6
|
| 132 |
+
- comm==0.2.2
|
| 133 |
+
- contourpy==1.3.0
|
| 134 |
+
- cycler==0.12.1
|
| 135 |
+
- debugpy==1.8.6
|
| 136 |
+
- decorator==5.1.1
|
| 137 |
+
- defusedxml==0.7.1
|
| 138 |
+
- einops==0.8.0
|
| 139 |
+
- exceptiongroup==1.2.2
|
| 140 |
+
- executing==2.1.0
|
| 141 |
+
- fastjsonschema==2.20.0
|
| 142 |
+
- fonttools==4.54.1
|
| 143 |
+
- fqdn==1.5.1
|
| 144 |
+
- frozenlist==1.4.1
|
| 145 |
+
- fsspec==2024.9.0
|
| 146 |
+
- gpustat==1.1.1
|
| 147 |
+
- h11==0.14.0
|
| 148 |
+
- h5netcdf==1.4.0
|
| 149 |
+
- h5py==3.12.1
|
| 150 |
+
- httpcore==1.0.6
|
| 151 |
+
- httpx==0.27.2
|
| 152 |
+
- huggingface-hub==0.25.1
|
| 153 |
+
- icecream==2.1.3
|
| 154 |
+
- importlib-metadata==8.5.0
|
| 155 |
+
- ipykernel==6.29.5
|
| 156 |
+
- ipython==8.28.0
|
| 157 |
+
- isoduration==20.11.0
|
| 158 |
+
- jedi==0.19.1
|
| 159 |
+
- jmespath==1.0.1
|
| 160 |
+
- joblib==1.4.2
|
| 161 |
+
- json5==0.9.25
|
| 162 |
+
- jsonpointer==3.0.0
|
| 163 |
+
- jsonschema==4.23.0
|
| 164 |
+
- jsonschema-specifications==2024.10.1
|
| 165 |
+
- jupyter-client==8.6.3
|
| 166 |
+
- jupyter-core==5.7.2
|
| 167 |
+
- jupyter-events==0.10.0
|
| 168 |
+
- jupyter-lsp==2.2.5
|
| 169 |
+
- jupyter-server==2.14.2
|
| 170 |
+
- jupyter-server-terminals==0.5.3
|
| 171 |
+
- jupyterlab==4.2.5
|
| 172 |
+
- jupyterlab-pygments==0.3.0
|
| 173 |
+
- jupyterlab-server==2.27.3
|
| 174 |
+
- kiwisolver==1.4.7
|
| 175 |
+
- matplotlib==3.9.2
|
| 176 |
+
- matplotlib-inline==0.1.7
|
| 177 |
+
- mistune==3.0.2
|
| 178 |
+
- multidict==6.1.0
|
| 179 |
+
- nbclient==0.10.0
|
| 180 |
+
- nbconvert==7.16.4
|
| 181 |
+
- nbformat==5.10.4
|
| 182 |
+
- nest-asyncio==1.6.0
|
| 183 |
+
- netcdf4==1.7.2
|
| 184 |
+
- notebook==7.2.2
|
| 185 |
+
- notebook-shim==0.2.4
|
| 186 |
+
- nvfuser-cu118-torch24==0.2.9.dev20240808
|
| 187 |
+
- nvidia-cuda-cupti-cu11==11.8.87
|
| 188 |
+
- nvidia-cuda-nvrtc-cu11==11.8.89
|
| 189 |
+
- nvidia-cuda-runtime-cu11==11.8.89
|
| 190 |
+
- nvidia-ml-py==12.560.30
|
| 191 |
+
- nvidia-nvtx-cu11==11.8.86
|
| 192 |
+
- overrides==7.7.0
|
| 193 |
+
- packaging==24.1
|
| 194 |
+
- pandas==2.2.3
|
| 195 |
+
- pandocfilters==1.5.1
|
| 196 |
+
- parso==0.8.4
|
| 197 |
+
- pexpect==4.9.0
|
| 198 |
+
- platformdirs==4.3.6
|
| 199 |
+
- prometheus-client==0.21.0
|
| 200 |
+
- prompt-toolkit==3.0.48
|
| 201 |
+
- ptyprocess==0.7.0
|
| 202 |
+
- pure-eval==0.2.3
|
| 203 |
+
- pycparser==2.22
|
| 204 |
+
- pygments==2.18.0
|
| 205 |
+
- pyparsing==3.2.0
|
| 206 |
+
- pyproj==3.7.0
|
| 207 |
+
- pyshp==2.3.1
|
| 208 |
+
- python-dateutil==2.9.0.post0
|
| 209 |
+
- python-json-logger==2.0.7
|
| 210 |
+
- pytz==2024.2
|
| 211 |
+
- pyzmq==26.2.0
|
| 212 |
+
- referencing==0.35.1
|
| 213 |
+
- rfc3339-validator==0.1.4
|
| 214 |
+
- rfc3986-validator==0.1.1
|
| 215 |
+
- rpds-py==0.20.0
|
| 216 |
+
- ruamel-yaml==0.18.6
|
| 217 |
+
- ruamel-yaml-clib==0.2.8
|
| 218 |
+
- s3fs==2024.9.0
|
| 219 |
+
- safetensors==0.4.5
|
| 220 |
+
- scikit-learn==1.5.2
|
| 221 |
+
- send2trash==1.8.3
|
| 222 |
+
- shapely==2.0.6
|
| 223 |
+
- six==1.16.0
|
| 224 |
+
- sniffio==1.3.1
|
| 225 |
+
- soupsieve==2.6
|
| 226 |
+
- stack-data==0.6.3
|
| 227 |
+
- terminado==0.18.1
|
| 228 |
+
- thop==0.1.1-2209072238
|
| 229 |
+
- threadpoolctl==3.5.0
|
| 230 |
+
- timm==1.0.9
|
| 231 |
+
- tinycss2==1.3.0
|
| 232 |
+
- tomli==2.0.2
|
| 233 |
+
- torchsummary==1.5.1
|
| 234 |
+
- tornado==6.4.1
|
| 235 |
+
- traitlets==5.14.3
|
| 236 |
+
- treelib==1.7.0
|
| 237 |
+
- types-python-dateutil==2.9.0.20241003
|
| 238 |
+
- tzdata==2024.2
|
| 239 |
+
- uri-template==1.3.0
|
| 240 |
+
- wcwidth==0.2.13
|
| 241 |
+
- webcolors==24.8.0
|
| 242 |
+
- webencodings==0.5.1
|
| 243 |
+
- websocket-client==1.8.0
|
| 244 |
+
- wrapt==1.16.0
|
| 245 |
+
- xarray==2024.9.0
|
| 246 |
+
- yarl==1.13.1
|
| 247 |
+
- zipp==3.20.2
|
| 248 |
+
prefix: /miniconda3/envs/neuralom
|
exp/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
exp/NeuralOM/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
exp/NeuralOM/20250309-195251/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
exp/NeuralOM/20250309-195251/.ipynb_checkpoints/config-checkpoint.yaml
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
### base config ###
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
full_field: &FULL_FIELD
|
| 4 |
+
num_data_workers: 4
|
| 5 |
+
dt: 1
|
| 6 |
+
n_history: 0
|
| 7 |
+
prediction_length: 41
|
| 8 |
+
ics_type: "default"
|
| 9 |
+
|
| 10 |
+
exp_dir: './exp'
|
| 11 |
+
|
| 12 |
+
# data
|
| 13 |
+
train_data_path: './data/train'
|
| 14 |
+
valid_data_path: './data/valid'
|
| 15 |
+
test_data_path: './data/test'
|
| 16 |
+
|
| 17 |
+
# land mask
|
| 18 |
+
land_mask: !!bool True
|
| 19 |
+
land_mask_path: './data/land_mask.h5'
|
| 20 |
+
|
| 21 |
+
# normalization
|
| 22 |
+
normalize: !!bool True
|
| 23 |
+
normalization: 'zscore'
|
| 24 |
+
global_means_path: './data/mean_s_t_ssh.npy'
|
| 25 |
+
global_stds_path: './data/std_s_t_ssh.npy'
|
| 26 |
+
|
| 27 |
+
# orography
|
| 28 |
+
orography: !!bool False
|
| 29 |
+
|
| 30 |
+
# noise
|
| 31 |
+
add_noise: !!bool False
|
| 32 |
+
noise_std: 0
|
| 33 |
+
|
| 34 |
+
# crop
|
| 35 |
+
crop_size_x: None
|
| 36 |
+
crop_size_y: None
|
| 37 |
+
|
| 38 |
+
log_to_screen: !!bool True
|
| 39 |
+
log_to_wandb: !!bool False
|
| 40 |
+
save_checkpoint: !!bool True
|
| 41 |
+
plot_animations: !!bool False
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
#############################################
|
| 45 |
+
NeuralOM: &NeuralOM
|
| 46 |
+
<<: *FULL_FIELD
|
| 47 |
+
nettype: 'NeuralOM'
|
| 48 |
+
log_to_wandb: !!bool False
|
| 49 |
+
|
| 50 |
+
# Train params
|
| 51 |
+
lr: 1e-3
|
| 52 |
+
batch_size: 32
|
| 53 |
+
scheduler: 'CosineAnnealingLR'
|
| 54 |
+
|
| 55 |
+
loss_channel_wise: True
|
| 56 |
+
loss_scale: False
|
| 57 |
+
use_loss_scaler_from_metnet3: True
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
atmos_channels: [93, 94, 95, 96]
|
| 61 |
+
|
| 62 |
+
ocean_channels: [0, 1 ,2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 ,21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92]
|
| 63 |
+
|
| 64 |
+
in_channels: [0, 1 ,2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 ,21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96]
|
| 65 |
+
|
| 66 |
+
out_channels: [0, 1 ,2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 ,21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92]
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
out_variables: ["S0", "S2", "S5", "S7", "S11", "S15", "S21", "S29", "S40", "S55", "S77", "S92", "S109", "S130", "S155", "S186", "S222", "S266", "S318", "S380", "S453", "S541", "S643", "U0", "U2", "U5", "U7", "U11", "U15", "U21", "U29", "U40", "U55", "U77", "U92", "U109", "U130", "U155", "U186", "U222", "U266", "U318", "U380", "U453", "U541", "U643", "V0", "V2", "V5", "V7", "V11", "V15", "V21", "V29", "V40", "V55", "V77", "V92", "V109", "V130", "V155", "V186", "V222", "V266", "V318", "V380", "V453", "V541", "V643", "T0", "T2", "T5", "T7", "T11", "T15", "T21", "T29", "T40", "T55", "T77", "T92", "T109", "T130", "T155", "T186", "T222", "T266", "T318", "T380", "T453", "T541", "T643", "SSH"]
|
| 70 |
+
|
exp/NeuralOM/20250309-195251/6_steps_finetune/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
exp/NeuralOM/20250309-195251/6_steps_finetune/model2/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
exp/NeuralOM/20250309-195251/6_steps_finetune/model2/10_steps_finetune/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
exp/NeuralOM/20250309-195251/6_steps_finetune/model2/10_steps_finetune/training_checkpoints/best_ckpt.tar
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f9fe78a12419997b9deaf0dd2ec912c1c936e96cc418bc507acdd2baecf908a2
|
| 3 |
+
size 661771939
|
exp/NeuralOM/20250309-195251/6_steps_finetune/model2/10_steps_finetune/training_checkpoints/readme.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
The intact project is available at the Hugging Face.
|
exp/NeuralOM/20250309-195251/6_steps_finetune/training_checkpoints/best_ckpt.tar
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8fb1c1827478608c96b490662f9b59351a52b1ee423883158c0a9e7c09b7d919
|
| 3 |
+
size 661813411
|
exp/NeuralOM/20250309-195251/6_steps_finetune/training_checkpoints/readme.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
The intact project is available at the Hugging Face.
|
exp/NeuralOM/20250309-195251/config.yaml
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
### base config ###
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
full_field: &FULL_FIELD
|
| 4 |
+
num_data_workers: 4
|
| 5 |
+
dt: 1
|
| 6 |
+
n_history: 0
|
| 7 |
+
prediction_length: 41
|
| 8 |
+
ics_type: "default"
|
| 9 |
+
|
| 10 |
+
exp_dir: './exp'
|
| 11 |
+
|
| 12 |
+
# data
|
| 13 |
+
train_data_path: './data/train'
|
| 14 |
+
valid_data_path: './data/valid'
|
| 15 |
+
test_data_path: './data/test'
|
| 16 |
+
|
| 17 |
+
# land mask
|
| 18 |
+
land_mask: !!bool True
|
| 19 |
+
land_mask_path: './data/land_mask.h5'
|
| 20 |
+
|
| 21 |
+
# normalization
|
| 22 |
+
normalize: !!bool True
|
| 23 |
+
normalization: 'zscore'
|
| 24 |
+
global_means_path: './data/mean_s_t_ssh.npy'
|
| 25 |
+
global_stds_path: './data/std_s_t_ssh.npy'
|
| 26 |
+
|
| 27 |
+
# orography
|
| 28 |
+
orography: !!bool False
|
| 29 |
+
|
| 30 |
+
# noise
|
| 31 |
+
add_noise: !!bool False
|
| 32 |
+
noise_std: 0
|
| 33 |
+
|
| 34 |
+
# crop
|
| 35 |
+
crop_size_x: None
|
| 36 |
+
crop_size_y: None
|
| 37 |
+
|
| 38 |
+
log_to_screen: !!bool True
|
| 39 |
+
log_to_wandb: !!bool False
|
| 40 |
+
save_checkpoint: !!bool True
|
| 41 |
+
plot_animations: !!bool False
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
#############################################
|
| 45 |
+
NeuralOM: &NeuralOM
|
| 46 |
+
<<: *FULL_FIELD
|
| 47 |
+
nettype: 'NeuralOM'
|
| 48 |
+
log_to_wandb: !!bool False
|
| 49 |
+
|
| 50 |
+
# Train params
|
| 51 |
+
lr: 1e-3
|
| 52 |
+
batch_size: 32
|
| 53 |
+
scheduler: 'CosineAnnealingLR'
|
| 54 |
+
|
| 55 |
+
loss_channel_wise: True
|
| 56 |
+
loss_scale: False
|
| 57 |
+
use_loss_scaler_from_metnet3: True
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
atmos_channels: [93, 94, 95, 96]
|
| 61 |
+
|
| 62 |
+
ocean_channels: [0, 1 ,2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 ,21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92]
|
| 63 |
+
|
| 64 |
+
in_channels: [0, 1 ,2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 ,21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96]
|
| 65 |
+
|
| 66 |
+
out_channels: [0, 1 ,2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 ,21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92]
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
out_variables: ["S0", "S2", "S5", "S7", "S11", "S15", "S21", "S29", "S40", "S55", "S77", "S92", "S109", "S130", "S155", "S186", "S222", "S266", "S318", "S380", "S453", "S541", "S643", "U0", "U2", "U5", "U7", "U11", "U15", "U21", "U29", "U40", "U55", "U77", "U92", "U109", "U130", "U155", "U186", "U222", "U266", "U318", "U380", "U453", "U541", "U643", "V0", "V2", "V5", "V7", "V11", "V15", "V21", "V29", "V40", "V55", "V77", "V92", "V109", "V130", "V155", "V186", "V222", "V266", "V318", "V380", "V453", "V541", "V643", "T0", "T2", "T5", "T7", "T11", "T15", "T21", "T29", "T40", "T55", "T77", "T92", "T109", "T130", "T155", "T186", "T222", "T266", "T318", "T380", "T453", "T541", "T643", "SSH"]
|
| 70 |
+
|
exp/baselines/readme.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
The intact project is available at the Hugging Face.
|
img/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
img/fig_NeuralOM.jpg
ADDED
|
Git LFS Details
|
img/fig_csi.jpg
ADDED
|
Git LFS Details
|
img/fig_rmse_acc.jpg
ADDED
|
Git LFS Details
|
img/fig_visual.jpg
ADDED
|
Git LFS Details
|
img/tab_acc_rmse.jpg
ADDED
|
Git LFS Details
|
inference.py
ADDED
|
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import time
|
| 4 |
+
import glob
|
| 5 |
+
import h5py
|
| 6 |
+
import logging
|
| 7 |
+
import argparse
|
| 8 |
+
import numpy as np
|
| 9 |
+
from icecream import ic
|
| 10 |
+
from datetime import datetime
|
| 11 |
+
from collections import OrderedDict
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import torch.cuda.amp as amp
|
| 15 |
+
import torch.distributed as dist
|
| 16 |
+
from torch.nn.parallel import DistributedDataParallel
|
| 17 |
+
|
| 18 |
+
sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../')
|
| 19 |
+
from my_utils.YParams import YParams
|
| 20 |
+
from my_utils.data_loader import get_data_loader
|
| 21 |
+
from my_utils import logging_utils
|
| 22 |
+
logging_utils.config_logger()
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def load_model(model, params, checkpoint_file):
|
| 26 |
+
model.zero_grad()
|
| 27 |
+
checkpoint_fname = checkpoint_file
|
| 28 |
+
checkpoint = torch.load(checkpoint_fname)
|
| 29 |
+
try:
|
| 30 |
+
new_state_dict = OrderedDict()
|
| 31 |
+
for key, val in checkpoint['model_state'].items():
|
| 32 |
+
name = key[7:]
|
| 33 |
+
if name != 'ged':
|
| 34 |
+
new_state_dict[name] = val
|
| 35 |
+
model.load_state_dict(new_state_dict)
|
| 36 |
+
except:
|
| 37 |
+
model.load_state_dict(checkpoint['model_state'])
|
| 38 |
+
model.eval()
|
| 39 |
+
return model
|
| 40 |
+
|
| 41 |
+
def setup(params):
|
| 42 |
+
device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
|
| 43 |
+
|
| 44 |
+
# get data loader
|
| 45 |
+
valid_data_loader, valid_dataset = get_data_loader(params, params.test_data_path, dist.is_initialized(), train=False)
|
| 46 |
+
|
| 47 |
+
img_shape_x = valid_dataset.img_shape_x
|
| 48 |
+
img_shape_y = valid_dataset.img_shape_y
|
| 49 |
+
params.img_shape_x = img_shape_x
|
| 50 |
+
params.img_shape_y = img_shape_y
|
| 51 |
+
|
| 52 |
+
in_channels = np.array(params.in_channels)
|
| 53 |
+
out_channels = np.array(params.out_channels)
|
| 54 |
+
n_in_channels = len(in_channels)
|
| 55 |
+
n_out_channels = len(out_channels)
|
| 56 |
+
|
| 57 |
+
params['N_in_channels'] = n_in_channels
|
| 58 |
+
params['N_out_channels'] = n_out_channels
|
| 59 |
+
|
| 60 |
+
if params.normalization == 'zscore':
|
| 61 |
+
params.means = np.load(params.global_means_path)
|
| 62 |
+
params.stds = np.load(params.global_stds_path)
|
| 63 |
+
|
| 64 |
+
if params.nettype == 'NeuralOM':
|
| 65 |
+
from networks.MIGNN1 import MIGraph as model
|
| 66 |
+
from networks.MIGNN2 import MIGraph_stage2 as model2
|
| 67 |
+
else:
|
| 68 |
+
raise Exception("not implemented")
|
| 69 |
+
|
| 70 |
+
checkpoint_file = params['best_checkpoint_path']
|
| 71 |
+
checkpoint_file2 = params['best_checkpoint_path2']
|
| 72 |
+
logging.info('Loading trained model checkpoint from {}'.format(checkpoint_file))
|
| 73 |
+
logging.info('Loading trained model2 checkpoint from {}'.format(checkpoint_file2))
|
| 74 |
+
|
| 75 |
+
model = model(params).to(device)
|
| 76 |
+
model = load_model(model, params, checkpoint_file)
|
| 77 |
+
model = model.to(device)
|
| 78 |
+
|
| 79 |
+
print('model is ok')
|
| 80 |
+
|
| 81 |
+
model2 = model2(params).to(device)
|
| 82 |
+
model2 = load_model(model2, params, checkpoint_file2)
|
| 83 |
+
model2 = model2.to(device)
|
| 84 |
+
|
| 85 |
+
print('model2 is ok')
|
| 86 |
+
|
| 87 |
+
files_paths = glob.glob(params.test_data_path + "/*.h5")
|
| 88 |
+
files_paths.sort()
|
| 89 |
+
|
| 90 |
+
# which year
|
| 91 |
+
yr = 0
|
| 92 |
+
logging.info('Loading inference data')
|
| 93 |
+
logging.info('Inference data from {}'.format(files_paths[yr]))
|
| 94 |
+
climate_mean = np.load('./data/climate_mean_s_t_ssh.npy')
|
| 95 |
+
valid_data_full = h5py.File(files_paths[yr], 'r')['fields'][:365, :, :, :]
|
| 96 |
+
valid_data_full = valid_data_full - climate_mean
|
| 97 |
+
|
| 98 |
+
return valid_data_full, model, model2
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def autoregressive_inference(params, init_condition, valid_data_full, model, model2):
    """Run a two-stage autoregressive rollout from one initial condition.

    Starting at day index ``init_condition`` of ``valid_data_full`` (anomaly
    fields — climatology already subtracted by the caller), the stage-1
    ``model`` is stepped forward, feeding each prediction back in together
    with the atmospheric forcing channels taken from ground truth; ``model2``
    adds a learned residual correction to every stage-1 output. Land points
    are zeroed with the land/sea mask after every step.

    Args:
        params: experiment configuration (channel lists, shapes, paths, stats).
        init_condition: index of the initial day within ``valid_data_full``.
        valid_data_full: array laid out (time, channel, lat, lon).
        model: stage-1 network mapping (ocean state + forcings) -> next state.
        model2: stage-2 residual-correction network.

    Returns:
        Tuple ``(ground_truth, prediction)``, each of shape
        ``(1, prediction_length - n_history, n_out_channels, H, W)``,
        de-normalized back to physical anomaly units.
    """
    device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'

    icd = int(init_condition)

    dt = int(params.dt)
    prediction_length = int(params.prediction_length/dt)
    n_history = params.n_history
    img_shape_x = params.img_shape_x
    img_shape_y = params.img_shape_y
    in_channels = np.array(params.in_channels)
    out_channels = np.array(params.out_channels)
    n_in_channels = len(in_channels)
    n_out_channels = len(out_channels)

    seq_real = torch.zeros((prediction_length, n_out_channels, img_shape_x, img_shape_y))
    seq_pred = torch.zeros((prediction_length, n_out_channels, img_shape_x, img_shape_y))

    # Slice out the rollout window and keep only the first 360 latitude rows.
    valid_data = valid_data_full[icd:(icd+prediction_length*dt+n_history*dt):dt][:, params.in_channels][:,:,0:360]
    logging.info(f'valid_data_full: {valid_data_full.shape}')
    logging.info(f'valid_data: {valid_data.shape}')

    # normalize
    if params.normalization == 'zscore':
        valid_data = (valid_data - params.means[:,params.in_channels])/params.stds[:,params.in_channels]
    valid_data = np.nan_to_num(valid_data, nan=0)

    valid_data = torch.as_tensor(valid_data)

    # FIX: the land/sea mask is constant, but the original re-opened and
    # re-read the HDF5 mask file up to twice per rollout step; load it once.
    with h5py.File(params.land_mask_path, 'r') as _f:
        mask_cpu = torch.as_tensor(_f['fields'][:, out_channels, :360, :720], dtype=torch.bool)
    mask_dev = mask_cpu.to(device)

    # autoregressive inference
    logging.info('Begin autoregressive inference')

    with torch.no_grad():
        for i in range(valid_data.shape[0]):
            if i == 0:  # start of sequence, t0 --> t0'
                first = valid_data[0:n_history+1]
                ic(valid_data.shape, first.shape)
                future = valid_data[n_history+1]
                ic(future.shape)

                for h in range(n_history+1):
                    # NOTE(review): slicing dim 0 by channel-sized blocks only
                    # lines up with seq_real's (channel, H, W) layout when
                    # n_history == 0 — confirm if longer histories are used.
                    seq_real[h] = first[h*n_in_channels : (h+1)*n_in_channels, :93]
                    seq_pred[h] = seq_real[h]

                first = first.to(device, dtype=torch.float)
                first_ocean = first[:, params.ocean_channels, :, :]
                ic(first_ocean.shape)
                future_force0 = first[:, params.atmos_channels, :, :]

                future_force = future[params.atmos_channels, :360, :720]
                future_force = torch.unsqueeze(future_force, dim=0).to(device, dtype=torch.float)
                # FIX: use the tensor already moved to `device` rather than
                # .cuda(), so the CPU fallback declared above actually works.
                model_input = torch.cat((first_ocean, future_force0, future_force), axis=1)
                ic(model_input.shape)
                model1_future_pred = model(model_input)
                # Zero out land points, then apply stage-2 residual correction.
                model1_future_pred = torch.masked_fill(input=model1_future_pred, mask=~mask_dev, value=0)
                future_pred = model2(model1_future_pred) + model1_future_pred

            else:
                if i < prediction_length-1:
                    future0 = valid_data[n_history+i]
                    future = valid_data[n_history+i+1]

                    inf_one_step_start = time.time()
                    future_force0 = future0[params.atmos_channels, :360, :720]
                    future_force = future[params.atmos_channels, :360, :720]
                    future_force0 = torch.unsqueeze(future_force0, dim=0).to(device, dtype=torch.float)
                    future_force = torch.unsqueeze(future_force, dim=0).to(device, dtype=torch.float)
                    # Autoregressive step: feed the previous prediction back in.
                    model1_future_pred = model(torch.cat((future_pred.to(device), future_force0, future_force), axis=1))
                    model1_future_pred = torch.masked_fill(input=model1_future_pred, mask=~mask_dev, value=0)
                    future_pred = model2(model1_future_pred) + model1_future_pred
                    inf_one_step_time = time.time() - inf_one_step_start

                    logging.info(f'inference one step time: {inf_one_step_time}')

            if i < prediction_length - 1:  # not on the last step
                seq_pred[n_history+i+1] = torch.masked_fill(input=future_pred.cpu(), mask=~mask_cpu, value=0)
                seq_real[n_history+i+1] = future[:93]
                history_stack = seq_pred[i+1:i+2+n_history]

            future_pred = history_stack

            pred = torch.unsqueeze(seq_pred[i], 0)
            tar = torch.unsqueeze(seq_real[i], 0)

            ic(mask_cpu.shape, pred.shape, tar.shape)
            pred = torch.masked_fill(input=pred, mask=~mask_cpu, value=0)
            tar = torch.masked_fill(input=tar, mask=~mask_cpu, value=0)

            # Per-step masked MSE, for quick monitoring on stdout.
            print(torch.mean((pred-tar)**2))

    # De-normalize back to physical anomaly units.
    seq_real = seq_real * params.stds[:,params.out_channels] + params.means[:,params.out_channels]
    seq_real = seq_real.numpy()
    seq_pred = seq_pred * params.stds[:,params.out_channels] + params.means[:,params.out_channels]
    seq_pred = seq_pred.numpy()

    return (np.expand_dims(seq_real[n_history:], 0),
            np.expand_dims(seq_pred[n_history:], 0),
            )
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--exp_dir", default='../exp_15_levels', type=str)
    parser.add_argument("--config", default='full_field', type=str)
    parser.add_argument("--run_num", default='00', type=str)
    parser.add_argument("--prediction_length", default=61, type=int)
    parser.add_argument("--finetune_dir", default='', type=str)
    parser.add_argument("--ics_type", default='default', type=str)
    args = parser.parse_args()

    config_path = os.path.join(args.exp_dir, args.config, args.run_num, 'config.yaml')
    params = YParams(config_path, args.config)

    # Inference-time overrides of the training configuration.
    params['resuming'] = False
    params['interp'] = 0
    params['world_size'] = 1
    params['local_rank'] = 0
    params['global_batch_size'] = params.batch_size
    params['prediction_length'] = args.prediction_length
    params['multi_steps_finetune'] = 1
    # FIX: --ics_type was parsed but never stored, yet params["ics_type"] is
    # read below — that raised KeyError unless the YAML happened to define it.
    params['ics_type'] = args.ics_type

    torch.cuda.set_device(0)
    torch.backends.cudnn.benchmark = True

    # Set up directory
    if args.finetune_dir == '':
        expDir = os.path.join(params.exp_dir, args.config, str(args.run_num))
    else:
        expDir = os.path.join(params.exp_dir, args.config, str(args.run_num), args.finetune_dir)
    logging.info(f'expDir: {expDir}')
    params['experiment_dir'] = expDir
    params['best_checkpoint_path'] = os.path.join(expDir, 'training_checkpoints/best_ckpt.tar')
    params['best_checkpoint_path2'] = os.path.join(expDir, 'model2/10_steps_finetune/training_checkpoints/best_ckpt.tar')

    # set up logging
    logging_utils.log_to_file(logger_name=None, log_filename=os.path.join(expDir, 'inference.log'))
    logging_utils.log_versions()
    params.log()

    # FIX: `ics` was only assigned in this branch; any other --ics_type fell
    # through silently and crashed later with a NameError. Fail fast instead.
    if params["ics_type"] == 'default':
        ics = np.arange(0, 240, 1)
        n_ics = len(ics)
        print('init_condition:', ics)
    else:
        raise ValueError("unsupported ics_type: {}".format(params["ics_type"]))

    logging.info("Inference for {} initial conditions".format(n_ics))

    # FIX: narrowed the bare `except:` to the KeyError a missing key raises.
    try:
        autoregressive_inference_filetag = params["inference_file_tag"]
    except KeyError:
        autoregressive_inference_filetag = ""
    if params.interp > 0:
        autoregressive_inference_filetag = "_coarse"

    valid_data_full, model, model2 = setup(params)

    seq_pred = []
    seq_real = []

    # run autoregressive inference for multiple initial conditions
    for i, ic_ in enumerate(ics):
        logging.info("Initial condition {} of {}".format(i+1, n_ics))
        seq_real, seq_pred = autoregressive_inference(params, ic_, valid_data_full, model, model2)

        prediction_length = seq_real[0].shape[0]
        n_out_channels = seq_real[0].shape[1]
        img_shape_x = seq_real[0].shape[2]
        img_shape_y = seq_real[0].shape[3]

        # save predictions and loss
        save_path = os.path.join(params['experiment_dir'], 'results.h5')
        logging.info("Saving to {}".format(save_path))
        print(f'saving to {save_path}')
        if i == 0:
            # First initial condition: create resizable datasets.
            f = h5py.File(save_path, 'w')
            f.create_dataset(
                "ground_truth",
                data=seq_real,
                maxshape=[None, prediction_length, n_out_channels, img_shape_x, img_shape_y],
                dtype=np.float32)
            f.create_dataset(
                "predicted",
                data=seq_pred,
                maxshape=[None, prediction_length, n_out_channels, img_shape_x, img_shape_y],
                dtype=np.float32)
            f.close()
        else:
            # Subsequent initial conditions: append along the leading axis.
            f = h5py.File(save_path, 'a')

            f["ground_truth"].resize((f["ground_truth"].shape[0] + 1), axis = 0)
            f["ground_truth"][-1:] = seq_real

            f["predicted"].resize((f["predicted"].shape[0] + 1), axis = 0)
            f["predicted"][-1:] = seq_pred
            f.close()
|
| 312 |
+
|
inference.sh
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Rollout length in days (31 was used for a shorter evaluation run).
prediction_length=61 # 31

# Experiment location: exp/<config>/<run_num>/<finetune_dir>
exp_dir='./exp'
config='NeuralOM'
run_num='20250309-195251'
finetune_dir='6_steps_finetune'

# Initial-condition selection strategy understood by inference.py.
ics_type='default'

# Run single-GPU inference on device 0.
CUDA_VISIBLE_DEVICES=0 python inference.py --exp_dir=${exp_dir} --config=${config} --run_num=${run_num} --finetune_dir=$finetune_dir --prediction_length=${prediction_length} --ics_type=${ics_type}
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
|
my_utils/.ipynb_checkpoints/YParams-checkpoint.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding: utf-8
|
| 2 |
+
|
| 3 |
+
import importlib
|
| 4 |
+
import sys
|
| 5 |
+
import os
|
| 6 |
+
importlib.reload(sys)
|
| 7 |
+
|
| 8 |
+
from ruamel.yaml import YAML
|
| 9 |
+
import logging
|
| 10 |
+
|
| 11 |
+
class YParams():
    """Thin YAML configuration wrapper.

    Loads one named section of a YAML file and exposes every entry both as an
    instance attribute and through dict-style indexing.
    """
    def __init__(self, yaml_filename, config_name, print_params=False):
        self._yaml_filename = yaml_filename
        self._config_name = config_name
        self.params = {}

        if print_params:
            print(os.system('hostname'))
            print("------------------ Configuration ------------------ ", yaml_filename)

        with open(yaml_filename, 'rb') as _file:
            section = YAML().load(_file)[config_name]
            for key, val in section.items():
                if print_params: print(key, val)
                # The literal string 'None' in the YAML means Python None.
                if val == 'None':
                    val = None
                self[key] = val

        if print_params:
            print("---------------------------------------------------")

    def __getitem__(self, key):
        return self.params[key]

    def __setitem__(self, key, val):
        # Keep the dict entry and the mirror attribute in sync.
        self.params[key] = val
        self.__setattr__(key, val)

    def __contains__(self, key):
        return key in self.params

    def update_params(self, config):
        """Merge another mapping of settings, overwriting existing keys."""
        for key, val in config.items():
            self[key] = val

    def log(self):
        """Dump the full configuration through the logging module."""
        logging.info("------------------ Configuration ------------------")
        logging.info("Configuration file: "+str(self._yaml_filename))
        logging.info("Configuration name: "+str(self._config_name))
        for key, val in self.params.items():
            logging.info(str(key) + ' ' + str(val))
        logging.info("---------------------------------------------------")
|
my_utils/.ipynb_checkpoints/data_loader-checkpoint.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import glob
|
| 3 |
+
import torch
|
| 4 |
+
import random
|
| 5 |
+
import numpy as np
|
| 6 |
+
from torch.utils.data import DataLoader, Dataset
|
| 7 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
import h5py
|
| 10 |
+
import math
|
| 11 |
+
from my_utils.norm import reshape_fields
|
| 12 |
+
import os
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
| 16 |
+
parent_dir = os.path.dirname(current_dir)
|
| 17 |
+
climate_mean_path = os.path.join(parent_dir, 'data/climate_mean_s_t_ssh.npy')
|
| 18 |
+
|
| 19 |
+
def get_data_loader(params, files_pattern, distributed, train):
    """Build a DataLoader over the HDF5 files matching *files_pattern*.

    Returns (loader, dataset, sampler) when *train* is true, otherwise
    (loader, dataset). A DistributedSampler is created only in distributed
    mode, and is attached to the loader only for training.
    """
    dataset = GetDataset(params, files_pattern, train)

    sampler = None
    if distributed:
        sampler = DistributedSampler(dataset, shuffle=train)

    loader = DataLoader(
        dataset,
        batch_size=int(params.batch_size),
        num_workers=params.num_data_workers,
        shuffle=False,
        sampler=(sampler if train else None),
        drop_last=True,
        pin_memory=True,
    )

    if train:
        return loader, dataset, sampler
    return loader, dataset
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class GetDataset(Dataset):
    """Map-style dataset of (input, target) pairs of ocean anomaly fields.

    Inputs stack the ocean channels (daily climatology subtracted) with the
    atmospheric forcing channels at the current and next step; targets are
    the output channels at the next step, or a multi-step stack when
    ``multi_steps_finetune > 1``.
    """
    def __init__(self, params, location, train):
        self.params = params
        self.location = location          # directory containing yearly .h5 files
        self.train = train
        self.orography = params.orography
        self.normalize = params.normalize
        self.dt = params.dt               # step stride, in samples (days)
        self.n_history = params.n_history # number of past steps fed as input
        self.in_channels = np.array(params.in_channels)
        self.out_channels = np.array(params.out_channels)
        self.ocean_channels = np.array(params.ocean_channels)
        self.atmos_channels = np.array(params.atmos_channels)
        self.n_in_channels = len(self.in_channels)
        self.n_out_channels = len(self.out_channels)

        self._get_files_stats()
        # Noise augmentation is a training-only option.
        self.add_noise = params.add_noise if train else False
        # Daily climatology; memory-mapped so each worker avoids a full load.
        self.climate_mean = np.load(climate_mean_path, mmap_mode='r')


    def _get_files_stats(self):
        # Scan the data directory and cache per-year sample counts and shapes.
        self.files_paths = glob.glob(self.location + "/*.h5")
        self.files_paths.sort()
        self.n_years = len(self.files_paths)

        with h5py.File(self.files_paths[0], 'r') as _f:
            logging.info("Getting file stats from {}".format(self.files_paths[0]))

            # Reserve the trailing step(s) of each year as targets.
            self.n_samples_per_year = _f['fields'].shape[0] - self.params.multi_steps_finetune

            # NOTE(review): the x dimension is truncated by one pixel here —
            # presumably to drop a redundant grid row; confirm against the data.
            self.img_shape_x = _f['fields'].shape[2] - 1
            self.img_shape_y = _f['fields'].shape[3]

        self.n_samples_total = self.n_years * self.n_samples_per_year
        # File handles are opened lazily, once per worker process.
        self.files = [None for _ in range(self.n_years)]

        logging.info("Number of samples per year: {}".format(self.n_samples_per_year))
        logging.info("Found data at path {}. Number of examples: {}. Image Shape: {} x {} x {}".format(self.location,
                                                                                                       self.n_samples_total,
                                                                                                       self.img_shape_x,
                                                                                                       self.img_shape_y,
                                                                                                       self.n_in_channels))
        logging.info("Delta t: {} days".format(1 * self.dt))
        logging.info("Including {} days of past history in training at a frequency of {} days".format(
            1 * self.dt * self.n_history, 1 * self.dt))

    def _open_file(self, year_idx):
        # Open one year file and keep its 'fields' dataset handle.
        _file = h5py.File(self.files_paths[year_idx], 'r')
        self.files[year_idx] = _file['fields']

        # NOTE(review): the orography file handle below is opened but never
        # stored or used — looks like dead/incomplete code; confirm intent.
        if self.orography and self.params.normalization == 'zscore':
            _orog_file = h5py.File(self.params.orography_norm_zscore_path, 'r')
        if self.orography and self.params.normalization == 'maxmin':
            _orog_file = h5py.File(self.params.orography_norm_maxmin_path, 'r')

    def __len__(self):
        return self.n_samples_total

    def __getitem__(self, global_idx):
        year_idx = int(global_idx / self.n_samples_per_year) # which year
        local_idx = int(global_idx % self.n_samples_per_year) # which sample in a year

        if self.files[year_idx] is None:
            self._open_file(year_idx)

        # Shift early indices forward so the history window stays in range.
        if local_idx < self.dt * self.n_history:
            local_idx += self.dt * self.n_history

        # Near the end of a year there is no next step: predict in place.
        step = 0 if local_idx >= self.n_samples_per_year - self.dt else self.dt

        orog = None


        if self.params.multi_steps_finetune == 1:
            # Clamp the final day of a 366-entry year into range.
            if local_idx == 365:
                local_idx = 364

            # Ocean input: history window of ocean channels, anomalies w.r.t. climatology.
            climate_mean_ocean = self.climate_mean[(local_idx-self.dt*self.n_history):(local_idx+1):self.dt, self.ocean_channels, :360, :720]
            ocean = reshape_fields(
                self.files[year_idx][(local_idx-self.dt*self.n_history):(local_idx+1):self.dt, self.ocean_channels, :360, :720] - climate_mean_ocean,
                'ocean',
                self.params,
                self.train,
                self.normalize,
                orog,
                self.add_noise
            )

            # Atmospheric forcing at the current step (raw, no climatology removal).
            force_future0 = reshape_fields(
                self.files[year_idx][local_idx, self.atmos_channels, :360, :720],
                'force',
                self.params,
                self.train,
                self.normalize,
                orog,
                self.add_noise
            )

            # Atmospheric forcing at the next step.
            force_future1 = reshape_fields(
                self.files[year_idx][local_idx+step, self.atmos_channels, :360, :720],
                'force',
                self.params,
                self.train,
                self.normalize,
                orog,
                self.add_noise
            )

            # Target: next-step output channels, as anomalies.
            climate_mean_tar = self.climate_mean[local_idx+step, self.out_channels, :360, :720]
            tar = reshape_fields(
                self.files[year_idx][local_idx+step, self.out_channels, :360, :720] - climate_mean_tar,
                'tar',
                self.params,
                self.train,
                self.normalize,
                orog
            )
        else:
            # Multi-step fine-tuning: target is a stack of the next
            # `multi_steps_finetune` steps over the full input channel set.
            climate_mean_ocean = self.climate_mean[(local_idx-self.dt*self.n_history):(local_idx+1):self.dt, self.ocean_channels, :360, :720]
            ocean = reshape_fields(
                self.files[year_idx][(local_idx-self.dt*self.n_history):(local_idx+1):self.dt, self.ocean_channels, :360, :720] - climate_mean_ocean,
                'ocean',
                self.params,
                self.train,
                self.normalize,
                orog,
                self.add_noise
            )

            force_future0 = reshape_fields(
                self.files[year_idx][local_idx, self.atmos_channels, :360, :720],
                'force',
                self.params,
                self.train,
                self.normalize,
                orog,
                self.add_noise
            )

            force_future1 = reshape_fields(
                self.files[year_idx][local_idx+step, self.atmos_channels, :360, :720],
                'force',
                self.params,
                self.train,
                self.normalize,
                orog,
                self.add_noise
            )

            climate_mean_tar = self.climate_mean[local_idx+step:local_idx+step+self.params.multi_steps_finetune, self.in_channels, :360, :720]
            tar_data = self.files[year_idx][local_idx+step:local_idx+step+self.params.multi_steps_finetune, self.in_channels, :360, :720]
            tar = reshape_fields(
                tar_data - climate_mean_tar,
                'inp',
                self.params,
                self.train,
                self.normalize,
                orog
            )

        # Land / missing values come through as NaN; zero them for the model.
        ocean = np.nan_to_num(ocean, nan=0)
        force_future0 = np.nan_to_num(force_future0, nan=0)
        force_future1 = np.nan_to_num(force_future1, nan=0)
        tar = np.nan_to_num(tar, nan=0)


        return np.concatenate((ocean, force_future0, force_future1), axis=0), tar
|
my_utils/.ipynb_checkpoints/logging_utils-checkpoint.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import logging
|
| 3 |
+
|
| 4 |
+
_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
| 5 |
+
|
| 6 |
+
def config_logger(log_level=logging.INFO):
    """Configure the root logger with the module-wide format string."""
    logging.basicConfig(level=log_level, format=_format)
|
| 8 |
+
|
| 9 |
+
def log_to_file(logger_name=None, log_level=logging.INFO, log_filename='tensorflow.log'):
    """Attach a FileHandler writing to *log_filename* to the named logger.

    Args:
        logger_name: logger to attach to; None means the root logger.
        log_level: level set on the file handler.
        log_filename: destination path; parent directories are created.

    Fix: the original called ``os.makedirs(os.path.dirname(log_filename))``
    unguarded, which raises when the filename has no directory component
    (dirname is ''), and the exists()/makedirs() pair was racy.
    """
    log_dir = os.path.dirname(log_filename)
    if log_dir:
        # exist_ok avoids the TOCTOU race between exists() and makedirs().
        os.makedirs(log_dir, exist_ok=True)

    if logger_name is not None:
        log = logging.getLogger(logger_name)
    else:
        log = logging.getLogger()

    fh = logging.FileHandler(log_filename)
    fh.setLevel(log_level)
    fh.setFormatter(logging.Formatter(_format))
    log.addHandler(fh)
|
| 23 |
+
|
| 24 |
+
def log_versions():
    # NOTE(review): this body appears truncated — it only imports torch and
    # subprocess and logs nothing. Presumably it should log torch.__version__
    # and the current git revision; confirm against the project history.
    import torch
    import subprocess
|
my_utils/.ipynb_checkpoints/norm-checkpoint.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import glob
|
| 3 |
+
from types import new_class
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
import random
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
from torch.utils.data import DataLoader, Dataset
|
| 11 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 12 |
+
from torch import Tensor
|
| 13 |
+
import h5py
|
| 14 |
+
import math
|
| 15 |
+
import torchvision.transforms.functional as TF
|
| 16 |
+
# import matplotlib
|
| 17 |
+
# import matplotlib.pyplot as plt
|
| 18 |
+
|
| 19 |
+
class PeriodicPad2d(nn.Module):
    """Pad the longitude (last) dimension circularly and the latitude
    (second-to-last) dimension with zeros, by the same width on all sides."""

    def __init__(self, pad_width):
        super(PeriodicPad2d, self).__init__()
        self.pad_width = pad_width

    def forward(self, x):
        w = self.pad_width
        # Wrap-around padding on the left/right (longitudinal) edges.
        wrapped = F.pad(x, (w, w, 0, 0), mode="circular")
        # Zero padding on the top/bottom (latitudinal) edges.
        return F.pad(wrapped, (0, 0, w, w), mode="constant", value=0)
|
| 34 |
+
|
| 35 |
+
def reshape_fields(img, inp_or_tar, params, train, normalize=True, orog=None, add_noise=False):
    """Normalize one field sample and return it as a torch tensor.

    Takes a numpy array of shape (n_history+1, c, h, w) — a single 3-D
    (c, h, w) array is accepted and gets a leading axis — and returns a torch
    tensor with singleton axes squeezed out.

    Args:
        img: field data; normalized in place when a z-score scheme is active.
        inp_or_tar: which channel list in *params* applies: 'inp', 'ocean',
            'force', or anything else for the output channels.
        params: configuration holding channel lists and statistics paths.
        train: unused here beyond the call signature shared with callers.
        normalize: apply the normalization named by params.normalization.
        orog: orography field appended as an extra channel for 'inp' inputs.
        add_noise: add Gaussian noise with std params.noise_std (training aug).
    """
    if len(np.shape(img)) == 3:
        img = np.expand_dims(img, 0)

    # Drop the redundant 721st latitude row when present.
    if np.shape(img)[2] == 721:
        img = img[:,:, 0:720, :] # remove last pixel

    n_history = np.shape(img)[0] - 1
    img_shape_x = np.shape(img)[-2]
    img_shape_y = np.shape(img)[-1]
    n_channels = np.shape(img)[1] # this will either be N_in_channels or N_out_channels

    # Pick the channel subset matching the role of this array.
    if inp_or_tar == 'inp':
        channels = params.in_channels
    elif inp_or_tar == 'ocean':
        channels = params.ocean_channels
    elif inp_or_tar == 'force':
        channels = params.atmos_channels
    else:
        channels = params.out_channels

    if normalize and params.normalization == 'minmax':
        maxs = np.load(params.global_maxs_path)[:, channels]
        mins = np.load(params.global_mins_path)[:, channels]
        img = (img - mins) / (maxs - mins)

    # NOTE(review): the z-score branches modify `img` in place — this assumes
    # a writable float array (it would fail on a read-only memmap slice and
    # mutates the caller's buffer); confirm all call sites pass fresh arrays.
    if normalize and params.normalization == 'zscore':
        means = np.load(params.global_means_path)[:, channels]
        stds = np.load(params.global_stds_path)[:, channels]
        img -=means
        img /=stds

    if normalize and params.normalization == 'zscore_lat':
        # Latitude-dependent statistics, truncated to the 720-row grid.
        means = np.load(params.global_lat_means_path)[:, channels,:720]
        stds = np.load(params.global_lat_stds_path)[:, channels,:720]
        img -=means
        img /=stds

    if params.orography and inp_or_tar == 'inp':
        # Append orography as one extra channel, repeated over the time axis.
        orog = np.expand_dims(orog, axis = (0,1))
        orog = np.repeat(orog, repeats=img.shape[0], axis=0)
        img = np.concatenate((img, orog), axis = 1)
        n_channels += 1

    # Collapse singleton axes (e.g. the time axis when n_history == 0).
    img = np.squeeze(img)

    if add_noise:
        img = img + np.random.normal(0, scale=params.noise_std, size=img.shape)

    return torch.as_tensor(img)
|
| 94 |
+
|
| 95 |
+
def vis_precip(fields):
    """Plot predicted vs. target precipitation fields side by side.

    Args:
        fields: pair ``(pred, tar)`` of 2-D arrays to render.

    Returns:
        The matplotlib Figure containing the two panels.

    Fix: the module-level matplotlib import is commented out, so ``plt`` was
    an unresolved name and this function always raised NameError. Import it
    locally instead, keeping matplotlib optional for the rest of the module.
    """
    import matplotlib.pyplot as plt

    pred, tar = fields
    fig, ax = plt.subplots(1, 2, figsize=(24,12))
    ax[0].imshow(pred, cmap="coolwarm")
    ax[0].set_title("tp pred")
    ax[1].imshow(tar, cmap="coolwarm")
    ax[1].set_title("tp tar")
    fig.tight_layout()
    return fig
|
| 104 |
+
|
| 105 |
+
def read_max_min_value(min_max_val_file_path):
    """Read per-channel max/min arrays from an HDF5 file.

    Args:
        min_max_val_file_path: path to a file with 'max_values'/'min_values'.

    Returns:
        Tuple ``(max_values, min_values)`` as in-memory numpy arrays.

    Fix: the original returned h5py Dataset handles whose backing file had
    already been closed by the ``with`` block, so any later access raised.
    Materialize the data with ``[:]`` before the file closes.
    """
    with h5py.File(min_max_val_file_path, 'r') as f:
        max_values = f['max_values'][:]
        min_values = f['min_values'][:]
    return max_values, min_values
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
|
my_utils/YParams.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding: utf-8
|
| 2 |
+
|
| 3 |
+
import importlib
|
| 4 |
+
import sys
|
| 5 |
+
import os
|
| 6 |
+
importlib.reload(sys)
|
| 7 |
+
|
| 8 |
+
from ruamel.yaml import YAML
|
| 9 |
+
import logging
|
| 10 |
+
|
| 11 |
+
class YParams():
    """Yaml file parser.

    Loads one named section of a YAML configuration file and exposes every
    entry both as an instance attribute and via dict-style indexing.
    """
    def __init__(self, yaml_filename, config_name, print_params=False):
        # yaml_filename: path to the YAML config; config_name: section to load.
        self._yaml_filename = yaml_filename
        self._config_name = config_name
        self.params = {}

        if print_params:
            print(os.system('hostname'))
            print("------------------ Configuration ------------------ ", yaml_filename)

        with open(yaml_filename, 'rb') as _file:
            yaml = YAML().load(_file)
            for key, val in yaml[config_name].items():
                if print_params: print(key, val)
                # The literal string 'None' in the YAML means Python None.
                if val =='None': val = None

                self.params[key] = val
                self.__setattr__(key, val)

        if print_params:
            print("---------------------------------------------------")

    def __getitem__(self, key):
        return self.params[key]

    def __setitem__(self, key, val):
        # Keep the dict entry and the mirror attribute in sync.
        self.params[key] = val
        self.__setattr__(key, val)

    def __contains__(self, key):
        return (key in self.params)

    def update_params(self, config):
        # Merge another mapping of settings, overwriting existing keys.
        for key, val in config.items():
            self.params[key] = val
            self.__setattr__(key, val)

    def log(self):
        # Dump the full configuration through the logging module.
        logging.info("------------------ Configuration ------------------")
        logging.info("Configuration file: "+str(self._yaml_filename))
        logging.info("Configuration name: "+str(self._config_name))
        for key, val in self.params.items():
            logging.info(str(key) + ' ' + str(val))
        logging.info("---------------------------------------------------")
|
my_utils/__pycache__/YParams.cpython-310.pyc
ADDED
|
Binary file (2.12 kB). View file
|
|
|
my_utils/__pycache__/YParams.cpython-37.pyc
ADDED
|
Binary file (2.11 kB). View file
|
|
|
my_utils/__pycache__/YParams.cpython-39.pyc
ADDED
|
Binary file (2.08 kB). View file
|
|
|
my_utils/__pycache__/bicubic.cpython-310.pyc
ADDED
|
Binary file (9.24 kB). View file
|
|
|
my_utils/__pycache__/bicubic.cpython-39.pyc
ADDED
|
Binary file (9.2 kB). View file
|
|
|
my_utils/__pycache__/darcy_loss.cpython-310.pyc
ADDED
|
Binary file (13.7 kB). View file
|
|
|
my_utils/__pycache__/darcy_loss.cpython-310.pyc.70370790180304
ADDED
|
Binary file (13.5 kB). View file
|
|
|
my_utils/__pycache__/darcy_loss.cpython-310.pyc.70373230085584
ADDED
|
Binary file (13.5 kB). View file
|
|
|
my_utils/__pycache__/darcy_loss.cpython-310.pyc.70384414393808
ADDED
|
Binary file (13.5 kB). View file
|
|
|
my_utils/__pycache__/darcy_loss.cpython-37.pyc
ADDED
|
Binary file (9.02 kB). View file
|
|
|
my_utils/__pycache__/darcy_loss.cpython-39.pyc
ADDED
|
Binary file (14.2 kB). View file
|
|
|