Upload 205 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +6 -0
- Exp3_Kuroshio_forecasting/.DS_Store +0 -0
- Exp3_Kuroshio_forecasting/checkpoints/Kuro_ConvLSTM_exp1_20250311_best_model.pth +3 -0
- Exp3_Kuroshio_forecasting/checkpoints/Kuro_Dit_exp1_20250221_best_model.pth +3 -0
- Exp3_Kuroshio_forecasting/checkpoints/Kuro_Dit_exp1_20250222_best_model.pth +3 -0
- Exp3_Kuroshio_forecasting/checkpoints/Kuro_Dit_exp1_20250223_best_model.pth +3 -0
- Exp3_Kuroshio_forecasting/checkpoints/Kuro_Dit_exp1_20250224_best_model.pth +3 -0
- Exp3_Kuroshio_forecasting/checkpoints/Kuro_Dit_exp2_20250224_best_model.pth +3 -0
- Exp3_Kuroshio_forecasting/checkpoints/Kuro_Dit_exp2_20250316_best_model.pth +3 -0
- Exp3_Kuroshio_forecasting/checkpoints/Kuro_Kno_exp1_20250226_best_model.pth +3 -0
- Exp3_Kuroshio_forecasting/checkpoints/Kuro_Kno_exp2_20250225_best_model.pth +3 -0
- Exp3_Kuroshio_forecasting/checkpoints/Kuro_Simvp_exp1_20250224_best_model.pth +3 -0
- Exp3_Kuroshio_forecasting/checkpoints/Kuro_Simvp_exp_128_20250324_best_model.pth +3 -0
- Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_K_uv_20250218_exp1_best_model.pth +3 -0
- Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_K_uv_20250218_exp2_best_model.pth +3 -0
- Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_exp1_128_20250322_best_model.pth +3 -0
- Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_exp1_20250221_best_model.pth +3 -0
- Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_exp1_20250222_best_model.pth +3 -0
- Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_exp1_20250224_best_model.pth +3 -0
- Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_exp1_20250224_best_model_prediction.h5 +3 -0
- Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_exp1_64_20250323_best_model.pth +3 -0
- Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_exp2_20241107_best_model.pth +3 -0
- Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_exp3_20241107_best_model.pth +3 -0
- Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_exp3_20241111_best_model.pth +3 -0
- Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_exp_20241107_best_model.pth +3 -0
- Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_multi_finetune_20250227_best_model.pth +3 -0
- Exp3_Kuroshio_forecasting/checkpoints/Kuro_U_net_exp1_20250225_best_model.pth +3 -0
- Exp3_Kuroshio_forecasting/checkpoints/Kuro_U_net_exp1_20250226_best_model.pth +3 -0
- Exp3_Kuroshio_forecasting/checkpoints/Kuro_U_net_exp2_20250226_best_model.pth +3 -0
- Exp3_Kuroshio_forecasting/checkpoints/Kuro_Unet_exp_128_20250324_best_model.pth +3 -0
- Exp3_Kuroshio_forecasting/checkpoints/Triton_Gulf_uv_20250218_exp1_best_model.pth +3 -0
- Exp3_Kuroshio_forecasting/checkpoints/Triton_Kuroshio_uv_20250218_exp1_best_model.pth +3 -0
- Exp3_Kuroshio_forecasting/checkpoints/dit_kuro_256_20250227_best_model.pth +3 -0
- Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/Untitled-checkpoint.ipynb +6 -0
- Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/current_animation-checkpoint.gif +3 -0
- Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader-checkpoint.ipynb +397 -0
- Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader-checkpoint.py +122 -0
- Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader_high_kuro-checkpoint.py +82 -0
- Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader_kuroshio-checkpoint.ipynb +209 -0
- Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader_kuroshio-checkpoint.py +69 -0
- Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader_kuroshio_G_uv-checkpoint.py +69 -0
- Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader_kuroshio_K_uv-checkpoint.py +69 -0
- Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader_kuroshio_ruiqi-checkpoint.py +134 -0
- Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader_kuroshio_ruiqi_128-checkpoint.py +134 -0
- Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader_kuroshio_ruiqi_64-checkpoint.py +134 -0
- Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader_kuroshio_ruiqi_single-checkpoint.py +96 -0
- Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader_test-checkpoint.ipynb +6 -0
- Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/input_output_animation-checkpoint.gif +3 -0
- Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/kuro_vis-checkpoint.ipynb +0 -0
- Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/kuroshio_animation-checkpoint.gif +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/current_animation-checkpoint.gif filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/input_output_animation-checkpoint.gif filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/kuroshio_animation-checkpoint.gif filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/ocean_currents_animation-checkpoint.gif filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/sample_animation-checkpoint.gif filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
Exp3_Kuroshio_forecasting/plt_triton/nmi_vis.ipynb filter=lfs diff=lfs merge=lfs -text
|
Exp3_Kuroshio_forecasting/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
Exp3_Kuroshio_forecasting/checkpoints/Kuro_ConvLSTM_exp1_20250311_best_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:68e12e494d088a481aa995c63c709dc208890abe9da52ac6f72742981a5658cc
|
| 3 |
+
size 11610
|
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Dit_exp1_20250221_best_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:63595e507a023ffaf26d6d1d7d3ec7b4f3dadbdf87e5c7881b8e1c9bc598ee83
|
| 3 |
+
size 75550190
|
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Dit_exp1_20250222_best_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f39ba8c5caaef358aba27d54b8ef392a5a51d4836480753dba4d898565a13a94
|
| 3 |
+
size 75550056
|
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Dit_exp1_20250223_best_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e00a1db14d32761d9f1ea1660dbd64c7c33d23f09358be805292b587c1c71dda
|
| 3 |
+
size 75550200
|
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Dit_exp1_20250224_best_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e15e71e84d50c3bd848dd1bd1a888c7396deade2855f0f334fd7edb19a763b80
|
| 3 |
+
size 75550204
|
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Dit_exp2_20250224_best_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b6bef4b0385afabcbb6a0c677a568e378de0b76a52a06dfbafd071a5bae24591
|
| 3 |
+
size 75550709
|
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Dit_exp2_20250316_best_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f708b01a88610e1517094eb1da50ea24d99d7e41a5cac7638aaecfc5fc0b9cee
|
| 3 |
+
size 63095505
|
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Kno_exp1_20250226_best_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fe9e87f65bd7b0e3ea3f25d826332065c787b3ef8c0479b18bf13701a6ede152
|
| 3 |
+
size 529476
|
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Kno_exp2_20250225_best_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0b0ea4a92f7be963bfb50bb4c6d8976fb98b3f6a2236c351ffcccbe03239909a
|
| 3 |
+
size 99835562
|
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Simvp_exp1_20250224_best_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:489e579df4993f8b7e24606dc7773f6aae4823cc5fbec0ad31d32be4b304ca5c
|
| 3 |
+
size 19040548
|
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Simvp_exp_128_20250324_best_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:613ba5274130915e187780bb4b29586a0dcb2f991a94e7696ec20468bb07f97d
|
| 3 |
+
size 19040464
|
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_K_uv_20250218_exp1_best_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5e3f0a2d53bb67432564819e7e08fa35c15f46a898e0d8056cfbc3fdf78c8703
|
| 3 |
+
size 378552683
|
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_K_uv_20250218_exp2_best_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9502f5eab2ce1a8ce6cc8961e9fafb201e62f0eed638ff1f3c6894dc1103cfdf
|
| 3 |
+
size 378552813
|
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_exp1_128_20250322_best_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d00a458f45ea7f8253f71ec5fb54aee5b4db91bc3b206d7c9c1a6a9f6e61f884
|
| 3 |
+
size 378552684
|
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_exp1_20250221_best_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:671add7c61a1efa5544703e212dd6aa5845107977578ba0333d8772524daa301
|
| 3 |
+
size 397465196
|
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_exp1_20250222_best_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c31f88b68ded17d0abee7bb76f0dfbc73b94173839b866c81b592cbd407c208d
|
| 3 |
+
size 397465203
|
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_exp1_20250224_best_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:940b8236f041d3c847412a62cdac18e7f7cccc5f04059dc632d03167ca781760
|
| 3 |
+
size 378552801
|
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_exp1_20250224_best_model_prediction.h5
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:37f30d197d020780343ab8c9054d2d5943d2560bd7b655bb22874250217d398f
|
| 3 |
+
size 68163584
|
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_exp1_64_20250323_best_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c29fbe97c5e335b97a2cba5c2d63f8e46b26c7ee8df8e95a001e9d6180496b8e
|
| 3 |
+
size 378552678
|
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_exp2_20241107_best_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7abe16e33f1e63941f77a02f136639d9a928d41c63f935feb40b161e7a468c6b
|
| 3 |
+
size 378552823
|
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_exp3_20241107_best_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8d8e14d8f13ab5530b1129ef408b64260c99c4acd4974d80f6a24925fbdf9c16
|
| 3 |
+
size 378552675
|
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_exp3_20241111_best_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4a7c63b24f362dd6b929f1fc003949f843a8229040c12f94f54c90b22d0c16fe
|
| 3 |
+
size 378556268
|
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_exp_20241107_best_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6e54888a509a62f9805009527d153209b16c9861ea3226619523e92b3b879672
|
| 3 |
+
size 378552694
|
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_multi_finetune_20250227_best_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:58c4db8010b0f292ac4bc93df237cf107f4f21643317a648c47c33e650b216c9
|
| 3 |
+
size 432755108
|
Exp3_Kuroshio_forecasting/checkpoints/Kuro_U_net_exp1_20250225_best_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7d4a08539f0188ea40e21b4b2b189ef3c24dd44ab683bed220525e3c84681927
|
| 3 |
+
size 99835639
|
Exp3_Kuroshio_forecasting/checkpoints/Kuro_U_net_exp1_20250226_best_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:928f8e6b4d14093c2f5a007d3a2b7bcec4e34adbca102b646db884ad19e61e10
|
| 3 |
+
size 124189508
|
Exp3_Kuroshio_forecasting/checkpoints/Kuro_U_net_exp2_20250226_best_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:64bbfd30c0a223c0a7bf002291b34bdf62eb7de532321e0a1305b420ebbff8e6
|
| 3 |
+
size 99901810
|
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Unet_exp_128_20250324_best_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a6fbbf634b2e46b116f94759b6ee669d508f8c03b8caba95b8f4eb713291a0dc
|
| 3 |
+
size 30872161
|
Exp3_Kuroshio_forecasting/checkpoints/Triton_Gulf_uv_20250218_exp1_best_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f94637d268f4bcbb77bda05bd7cf32cce1390ecdf56bbfd75f9f8cc6a2202eee
|
| 3 |
+
size 378552693
|
Exp3_Kuroshio_forecasting/checkpoints/Triton_Kuroshio_uv_20250218_exp1_best_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d9b7d665a91e73d33ecd27bf5dcbfcb07c104cfd6eb33c442726c30b96bd2cae
|
| 3 |
+
size 378552695
|
Exp3_Kuroshio_forecasting/checkpoints/dit_kuro_256_20250227_best_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:08389ce9e69da332168b798f6790544b3e9ff6c1fa8432c320d83a4d973ae1f7
|
| 3 |
+
size 63092615
|
Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/Untitled-checkpoint.ipynb
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [],
|
| 3 |
+
"metadata": {},
|
| 4 |
+
"nbformat": 4,
|
| 5 |
+
"nbformat_minor": 5
|
| 6 |
+
}
|
Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/current_animation-checkpoint.gif
ADDED
|
Git LFS Details
|
Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader-checkpoint.ipynb
ADDED
|
@@ -0,0 +1,397 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"id": "f3b16ba8-ad82-45c1-8119-b6c61e7311b8",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [
|
| 9 |
+
{
|
| 10 |
+
"name": "stdout",
|
| 11 |
+
"output_type": "stream",
|
| 12 |
+
"text": [
|
| 13 |
+
"Input shape: torch.Size([32, 10, 2, 128, 128])\n",
|
| 14 |
+
"Target shape: torch.Size([32, 5, 2, 128, 128])\n"
|
| 15 |
+
]
|
| 16 |
+
}
|
| 17 |
+
],
|
| 18 |
+
"source": [
|
| 19 |
+
"import h5py\n",
|
| 20 |
+
"import numpy as np\n",
|
| 21 |
+
"import torch\n",
|
| 22 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
| 23 |
+
"\n",
|
| 24 |
+
"class KuroshioDataset(Dataset):\n",
|
| 25 |
+
" def __init__(self, data, input_length, target_length, downsample_factor=1):\n",
|
| 26 |
+
" \"\"\"\n",
|
| 27 |
+
" Args:\n",
|
| 28 |
+
" data: Tensor of shape (num_samples, num_timesteps, C, H, W)\n",
|
| 29 |
+
" input_length: Number of input time steps (T_in)\n",
|
| 30 |
+
" target_length: Number of prediction time steps (T_out)\n",
|
| 31 |
+
" downsample_factor: Spatial downsampling factor\n",
|
| 32 |
+
" \"\"\"\n",
|
| 33 |
+
" super().__init__()\n",
|
| 34 |
+
" self.data = data\n",
|
| 35 |
+
" self.input_length = input_length\n",
|
| 36 |
+
" self.target_length = target_length\n",
|
| 37 |
+
" self.downsample_factor = downsample_factor\n",
|
| 38 |
+
"\n",
|
| 39 |
+
" # Validate time dimensions\n",
|
| 40 |
+
" self.num_samples, self.num_timesteps, self.C, self.H, self.W = data.shape\n",
|
| 41 |
+
" self.max_t_start = self.num_timesteps - self.input_length - self.target_length\n",
|
| 42 |
+
" assert self.max_t_start >= 0, \"Not enough timesteps for input and output\"\n",
|
| 43 |
+
"\n",
|
| 44 |
+
" # Generate sample indices (sample_idx, t_start)\n",
|
| 45 |
+
" self.sample_indices = []\n",
|
| 46 |
+
" for s in range(self.num_samples):\n",
|
| 47 |
+
" for t_start in range(self.max_t_start + 1):\n",
|
| 48 |
+
" self.sample_indices.append((s, t_start))\n",
|
| 49 |
+
"\n",
|
| 50 |
+
" def __len__(self):\n",
|
| 51 |
+
" return len(self.sample_indices)\n",
|
| 52 |
+
"\n",
|
| 53 |
+
" def __getitem__(self, idx):\n",
|
| 54 |
+
" s, t_start = self.sample_indices[idx]\n",
|
| 55 |
+
" \n",
|
| 56 |
+
" # Extract sequences\n",
|
| 57 |
+
" input_end = t_start + self.input_length\n",
|
| 58 |
+
" output_end = input_end + self.target_length\n",
|
| 59 |
+
" \n",
|
| 60 |
+
" input_seq = self.data[s, t_start:input_end] # (T_in, C, H, W)\n",
|
| 61 |
+
" target_seq = self.data[s, input_end:output_end] # (T_out, C, H, W)\n",
|
| 62 |
+
"\n",
|
| 63 |
+
" # Apply downsampling\n",
|
| 64 |
+
" if self.downsample_factor > 1:\n",
|
| 65 |
+
" dsf = self.downsample_factor\n",
|
| 66 |
+
" input_seq = input_seq[..., ::dsf, ::dsf]\n",
|
| 67 |
+
" target_seq = target_seq[..., ::dsf, ::dsf]\n",
|
| 68 |
+
"\n",
|
| 69 |
+
" return input_seq.float(), target_seq.float()\n",
|
| 70 |
+
"\n",
|
| 71 |
+
"def load_datasets(file_path, args):\n",
|
| 72 |
+
" # Load and preprocess data\n",
|
| 73 |
+
" with h5py.File(file_path, 'r') as f:\n",
|
| 74 |
+
" u_k = np.transpose(f['u_k'][:], (0, 3, 1, 2)) # (2046, 50, 128, 128)\n",
|
| 75 |
+
" v_k = np.transpose(f['v_k'][:], (0, 3, 1, 2))\n",
|
| 76 |
+
" \n",
|
| 77 |
+
" # Combine u and v channels\n",
|
| 78 |
+
" combined = np.stack([u_k, v_k], axis=2) # (2046, 50, 2, 128, 128)\n",
|
| 79 |
+
" data_tensor = torch.tensor(combined, dtype=torch.float32)\n",
|
| 80 |
+
"\n",
|
| 81 |
+
" # Split dataset\n",
|
| 82 |
+
" total_samples = 2046\n",
|
| 83 |
+
" train_size = int(0.8 * total_samples)\n",
|
| 84 |
+
" val_size = int(0.1 * total_samples)\n",
|
| 85 |
+
" \n",
|
| 86 |
+
" train_data = data_tensor[:train_size]\n",
|
| 87 |
+
" val_data = data_tensor[train_size:train_size+val_size]\n",
|
| 88 |
+
" test_data = data_tensor[train_size+val_size:]\n",
|
| 89 |
+
"\n",
|
| 90 |
+
" # Create datasets\n",
|
| 91 |
+
" train_dataset = KuroshioDataset(train_data, \n",
|
| 92 |
+
" args['input_length'],\n",
|
| 93 |
+
" args['target_length'],\n",
|
| 94 |
+
" args['downsample_factor'])\n",
|
| 95 |
+
" \n",
|
| 96 |
+
" val_dataset = KuroshioDataset(val_data,\n",
|
| 97 |
+
" args['input_length'],\n",
|
| 98 |
+
" args['target_length'],\n",
|
| 99 |
+
" args['downsample_factor'])\n",
|
| 100 |
+
" \n",
|
| 101 |
+
" test_dataset = KuroshioDataset(test_data,\n",
|
| 102 |
+
" args['input_length'],\n",
|
| 103 |
+
" args['target_length'],\n",
|
| 104 |
+
" args['downsample_factor'])\n",
|
| 105 |
+
"\n",
|
| 106 |
+
" return train_dataset, val_dataset, test_dataset\n",
|
| 107 |
+
"\n",
|
| 108 |
+
"# Example usage\n",
|
| 109 |
+
"if __name__ == \"__main__\":\n",
|
| 110 |
+
" config = {\n",
|
| 111 |
+
" 'input_length': 10, # T_in: 输入时间步数\n",
|
| 112 |
+
" 'target_length': 5, # T_out: 预测时间步数\n",
|
| 113 |
+
" 'downsample_factor': 1 # 空间下采样因子\n",
|
| 114 |
+
" }\n",
|
| 115 |
+
"\n",
|
| 116 |
+
" # 加载数据集\n",
|
| 117 |
+
" train_ds, val_ds, test_ds = load_datasets('./Kuroshio_window_data.h5', config)\n",
|
| 118 |
+
"\n",
|
| 119 |
+
" # 创建DataLoader\n",
|
| 120 |
+
" batch_size = 32\n",
|
| 121 |
+
" train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)\n",
|
| 122 |
+
" val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)\n",
|
| 123 |
+
" test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)\n",
|
| 124 |
+
"\n",
|
| 125 |
+
" # 验证数据形状\n",
|
| 126 |
+
" sample_input, sample_target = next(iter(train_loader))\n",
|
| 127 |
+
" print(f\"Input shape: {sample_input.shape}\") # 应为 (B, T_in, 2, H, W)\n",
|
| 128 |
+
" print(f\"Target shape: {sample_target.shape}\") # 应为 (B, T_out, 2, H, W)"
|
| 129 |
+
]
|
| 130 |
+
},
|
| 131 |
+
{
|
| 132 |
+
"cell_type": "code",
|
| 133 |
+
"execution_count": 4,
|
| 134 |
+
"id": "9c6b1e5c-7874-49f2-9004-c17470a3ae85",
|
| 135 |
+
"metadata": {},
|
| 136 |
+
"outputs": [
|
| 137 |
+
{
|
| 138 |
+
"name": "stdout",
|
| 139 |
+
"output_type": "stream",
|
| 140 |
+
"text": [
|
| 141 |
+
"可视化已保存为 kuroshio_animation.gif\n"
|
| 142 |
+
]
|
| 143 |
+
}
|
| 144 |
+
],
|
| 145 |
+
"source": [
|
| 146 |
+
"import h5py\n",
|
| 147 |
+
"import numpy as np\n",
|
| 148 |
+
"import torch\n",
|
| 149 |
+
"import matplotlib.pyplot as plt\n",
|
| 150 |
+
"import matplotlib.animation as animation\n",
|
| 151 |
+
"from matplotlib import gridspec\n",
|
| 152 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
| 153 |
+
"\n",
|
| 154 |
+
"\n",
|
| 155 |
+
"# 修正后的可视化函数\n",
|
| 156 |
+
"def create_visualization(input_seq, target_seq, sample_idx=0, downsample=4, fps=10):\n",
|
| 157 |
+
" # 数据准备\n",
|
| 158 |
+
" input_np = input_seq[sample_idx].cpu().numpy()\n",
|
| 159 |
+
" target_np = target_seq[sample_idx].cpu().numpy()\n",
|
| 160 |
+
" full_seq = np.concatenate([input_np, target_np], axis=0)\n",
|
| 161 |
+
" full_seq = np.transpose(full_seq, (0, 2, 3, 1)) # [T, H, W, C]\n",
|
| 162 |
+
" \n",
|
| 163 |
+
" # 创建网格\n",
|
| 164 |
+
" H, W = full_seq.shape[1], full_seq.shape[2]\n",
|
| 165 |
+
" x = np.arange(W)\n",
|
| 166 |
+
" y = np.arange(H)\n",
|
| 167 |
+
" X, Y = np.meshgrid(x, y)\n",
|
| 168 |
+
" X_ds, Y_ds = X[::downsample, ::downsample], Y[::downsample, ::downsample]\n",
|
| 169 |
+
" \n",
|
| 170 |
+
" # 计算速度幅值\n",
|
| 171 |
+
" speed = np.sqrt(full_seq[...,0]**2 + full_seq[...,1]**2)\n",
|
| 172 |
+
" speed_min, speed_max = speed.min(), speed.max()\n",
|
| 173 |
+
" \n",
|
| 174 |
+
" # 创建画布\n",
|
| 175 |
+
" fig = plt.figure(figsize=(15, 5), facecolor='white')\n",
|
| 176 |
+
" gs = gridspec.GridSpec(1, 3, width_ratios=[1, 1, 1])\n",
|
| 177 |
+
" ax1 = plt.subplot(gs[0])\n",
|
| 178 |
+
" ax2 = plt.subplot(gs[1])\n",
|
| 179 |
+
" ax3 = plt.subplot(gs[2])\n",
|
| 180 |
+
" \n",
|
| 181 |
+
" # 初始化子图\n",
|
| 182 |
+
" im1 = ax1.imshow(full_seq[0,...,0], origin='lower', cmap='RdBu_r', vmax=1, vmin=-1)\n",
|
| 183 |
+
" ax1.set_title(\"U Component\")\n",
|
| 184 |
+
" plt.colorbar(im1, ax=ax1)\n",
|
| 185 |
+
" \n",
|
| 186 |
+
" im2 = ax2.imshow(full_seq[0,...,1], origin='lower', cmap='RdBu_r', vmax=1, vmin=-1)\n",
|
| 187 |
+
" ax2.set_title(\"V Component\")\n",
|
| 188 |
+
" plt.colorbar(im2, ax=ax2)\n",
|
| 189 |
+
" \n",
|
| 190 |
+
" # 初始化矢量场\n",
|
| 191 |
+
" U = full_seq[0,...,0][::downsample, ::downsample]\n",
|
| 192 |
+
" V = full_seq[0,...,1][::downsample, ::downsample]\n",
|
| 193 |
+
" speed_initial = np.sqrt(U**2 + V**2)\n",
|
| 194 |
+
" quiver = ax3.quiver(X_ds, Y_ds, U, V, speed_initial, \n",
|
| 195 |
+
" cmap='RdBu_r', \n",
|
| 196 |
+
" scale=50, \n",
|
| 197 |
+
" width=0.003,\n",
|
| 198 |
+
" clim=[speed_min, speed_max])\n",
|
| 199 |
+
" plt.colorbar(quiver, ax=ax3, label='Flow Speed')\n",
|
| 200 |
+
" ax3.set_title(\"Vector Field\")\n",
|
| 201 |
+
" \n",
|
| 202 |
+
" # 统一设置\n",
|
| 203 |
+
" for ax in [ax1, ax2, ax3]:\n",
|
| 204 |
+
" ax.set_xticks([])\n",
|
| 205 |
+
" ax.set_yticks([])\n",
|
| 206 |
+
" ax.set_xlabel(f\"Timestep: 0/{full_seq.shape[0]-1}\")\n",
|
| 207 |
+
" \n",
|
| 208 |
+
" # 动画更新函数\n",
|
| 209 |
+
" def update(frame):\n",
|
| 210 |
+
" # 更新分量图\n",
|
| 211 |
+
" im1.set_data(full_seq[frame,...,0])\n",
|
| 212 |
+
" im2.set_data(full_seq[frame,...,1])\n",
|
| 213 |
+
" \n",
|
| 214 |
+
" # 更新矢量场\n",
|
| 215 |
+
" U = full_seq[frame,...,0][::downsample, ::downsample]\n",
|
| 216 |
+
" V = full_seq[frame,...,1][::downsample, ::downsample]\n",
|
| 217 |
+
" speed = np.sqrt(U**2 + V**2)\n",
|
| 218 |
+
" \n",
|
| 219 |
+
" quiver.set_UVC(U, V)\n",
|
| 220 |
+
" quiver.set_array(speed.flatten())\n",
|
| 221 |
+
" \n",
|
| 222 |
+
" # 更新时间标签\n",
|
| 223 |
+
" for ax in [ax1, ax2, ax3]:\n",
|
| 224 |
+
" ax.set_xlabel(f\"Timestep: {frame}/{full_seq.shape[0]-1}\")\n",
|
| 225 |
+
" \n",
|
| 226 |
+
" return [im1, im2, quiver]\n",
|
| 227 |
+
" \n",
|
| 228 |
+
" # 生成动画\n",
|
| 229 |
+
" ani = animation.FuncAnimation(fig, update, frames=full_seq.shape[0], \n",
|
| 230 |
+
" interval=1000//fps, blit=False)\n",
|
| 231 |
+
" ani.save('kuroshio_animation.gif', writer='pillow', fps=fps)\n",
|
| 232 |
+
" plt.close()\n",
|
| 233 |
+
" print(\"可视化已保存为 kuroshio_animation.gif\")\n",
|
| 234 |
+
"\n",
|
| 235 |
+
"# 完整使用示例\n",
|
| 236 |
+
"if __name__ == \"__main__\":\n",
|
| 237 |
+
" # 配置参数\n",
|
| 238 |
+
" config = {\n",
|
| 239 |
+
" 'input_length': 25,\n",
|
| 240 |
+
" 'target_length': 25,\n",
|
| 241 |
+
" 'downsample_factor': 1\n",
|
| 242 |
+
" }\n",
|
| 243 |
+
" \n",
|
| 244 |
+
" # 加载数据\n",
|
| 245 |
+
" train_ds, val_ds, test_ds = load_datasets('./Kuroshio_window_data.h5', config)\n",
|
| 246 |
+
" train_loader = DataLoader(train_ds, batch_size=10, shuffle=True)\n",
|
| 247 |
+
" \n",
|
| 248 |
+
" # 获取样本数据\n",
|
| 249 |
+
" sample_input, sample_target = next(iter(train_loader))\n",
|
| 250 |
+
" \n",
|
| 251 |
+
" # 生成可视化(关键参数调整)\n",
|
| 252 |
+
" create_visualization(\n",
|
| 253 |
+
" sample_input, \n",
|
| 254 |
+
" sample_target,\n",
|
| 255 |
+
" sample_idx=2, # 选择样本索引\n",
|
| 256 |
+
" downsample=1, # 矢量场密度(值越小越密集)\n",
|
| 257 |
+
" fps=4 # 动画帧率\n",
|
| 258 |
+
" )"
|
| 259 |
+
]
|
| 260 |
+
},
|
| 261 |
+
{
|
| 262 |
+
"cell_type": "code",
|
| 263 |
+
"execution_count": 2,
|
| 264 |
+
"id": "d0454a79-3e01-49dd-b4c5-0aca0bd76bcf",
|
| 265 |
+
"metadata": {},
|
| 266 |
+
"outputs": [],
|
| 267 |
+
"source": [
|
| 268 |
+
"import os\n",
|
| 269 |
+
"import torch\n",
|
| 270 |
+
"import torch.distributed as dist\n",
|
| 271 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
| 272 |
+
"from torch.utils.data.distributed import DistributedSampler\n",
|
| 273 |
+
"import h5py\n",
|
| 274 |
+
"import numpy as np\n",
|
| 275 |
+
"from torch.utils.data import Dataset\n",
|
| 276 |
+
"from torch.utils.data import DataLoader\n",
|
| 277 |
+
"import torchvision.transforms as transforms\n",
|
| 278 |
+
"import torch.utils.data as data\n",
|
| 279 |
+
"import h5py\n",
|
| 280 |
+
"import torch\n",
|
| 281 |
+
"import numpy as np\n",
|
| 282 |
+
"import matplotlib.pyplot as plt\n",
|
| 283 |
+
"\n",
|
| 284 |
+
"class WeatherDataset(Dataset):\n",
|
| 285 |
+
" def __init__(self, data_path, horizon, transform=None):\n",
|
| 286 |
+
" with h5py.File(data_path, 'r') as f:\n",
|
| 287 |
+
" self.data_uv_g = f['u_k'][:] \n",
|
| 288 |
+
" self.data_uv_g = torch.from_numpy(self.data_uv_g).to(torch.float32)\n",
|
| 289 |
+
" self.data_uv_g = self.data_uv_g.permute(0, 3, 1, 2).unsqueeze_(2) \n",
|
| 290 |
+
" \n",
|
| 291 |
+
" self.data_uv_k = f['v_k'][:] \n",
|
| 292 |
+
" self.data_uv_k = torch.from_numpy(self.data_uv_k).to(torch.float32)\n",
|
| 293 |
+
" self.data_uv_k = self.data_uv_k.permute(0, 3, 1, 2).unsqueeze_(2) \n",
|
| 294 |
+
" self.data_uv_gk = torch.cat([self.data_uv_g, self.data_uv_k], dim=2)\n",
|
| 295 |
+
" self.transform = transform\n",
|
| 296 |
+
" self.horizon = horizon\n",
|
| 297 |
+
" self.mean = 0\n",
|
| 298 |
+
" self.std = 1\n",
|
| 299 |
+
" \n",
|
| 300 |
+
" def __len__(self):\n",
|
| 301 |
+
" return len(self.data_uv_gk)\n",
|
| 302 |
+
"\n",
|
| 303 |
+
" def __getitem__(self, idx):\n",
|
| 304 |
+
" input_frames = self.data_uv_gk[idx][:self.horizon]\n",
|
| 305 |
+
" output_frames = self.data_uv_gk[idx][self.horizon:2*self.horizon]\n",
|
| 306 |
+
" input_frames = (input_frames - self.mean) / self.std\n",
|
| 307 |
+
" output_frames = (output_frames - self.mean) / self.std\n",
|
| 308 |
+
" return input_frames, output_frames\n",
|
| 309 |
+
"\n",
|
| 310 |
+
"def load_data(data_path, batch_size, val_batch_size, horizon, num_workers):\n",
|
| 311 |
+
" dataset = WeatherDataset(data_path=data_path+'/kg_all_20_mask_latmean.h5', horizon=horizon, transform=None)\n",
|
| 312 |
+
" \n",
|
| 313 |
+
" total_samples = len(dataset)\n",
|
| 314 |
+
" train_size = int(0.8 * total_samples)\n",
|
| 315 |
+
" val_size = int(0.1 * total_samples)\n",
|
| 316 |
+
" \n",
|
| 317 |
+
" train_dataset = dataset[:train_size]\n",
|
| 318 |
+
" val_dataset = dataset[train_size:train_size+val_size]\n",
|
| 319 |
+
" test_dataset = dataset[train_size+val_size:]\n",
|
| 320 |
+
" \n",
|
| 321 |
+
" train_sampler = DistributedSampler(train_dataset)\n",
|
| 322 |
+
" val_sampler = DistributedSampler(val_dataset)\n",
|
| 323 |
+
" test_sampler = DistributedSampler(test_dataset)\n",
|
| 324 |
+
"\n",
|
| 325 |
+
" dataloader_train = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, pin_memory=False,\n",
|
| 326 |
+
" num_workers=num_workers, drop_last=True)\n",
|
| 327 |
+
" dataloader_validation = DataLoader(val_dataset, batch_size=val_batch_size, sampler=val_sampler, pin_memory=False,\n",
|
| 328 |
+
" num_workers=num_workers, drop_last=True)\n",
|
| 329 |
+
" dataloader_test = DataLoader(test_dataset, batch_size=val_batch_size, sampler=test_sampler, pin_memory=False,\n",
|
| 330 |
+
" num_workers=num_workers, drop_last=True)\n",
|
| 331 |
+
" mean, std = 0, 1\n",
|
| 332 |
+
"\n",
|
| 333 |
+
" return dataloader_train, dataloader_validation, dataloader_test, mean, std"
|
| 334 |
+
]
|
| 335 |
+
},
|
| 336 |
+
{
|
| 337 |
+
"cell_type": "code",
|
| 338 |
+
"execution_count": 3,
|
| 339 |
+
"id": "6b051bb4-492a-4b4a-828a-6099ce9437b4",
|
| 340 |
+
"metadata": {},
|
| 341 |
+
"outputs": [
|
| 342 |
+
{
|
| 343 |
+
"ename": "NameError",
|
| 344 |
+
"evalue": "name 'data_tensor' is not defined",
|
| 345 |
+
"output_type": "error",
|
| 346 |
+
"traceback": [
|
| 347 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
| 348 |
+
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
| 349 |
+
"Cell \u001b[0;32mIn[3], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;18m__name__\u001b[39m \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m__main__\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[0;32m----> 2\u001b[0m train_loader, val_loader, test_loader, mean, std \u001b[38;5;241m=\u001b[39m \u001b[43mload_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43m/jizhicfs/easyluwu/ocean_project/kuro/ft_local\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m8\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[1;32m 4\u001b[0m \u001b[43m \u001b[49m\u001b[43mval_batch_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m8\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[1;32m 5\u001b[0m \u001b[43m \u001b[49m\u001b[43mhorizon\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m10\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 6\u001b[0m \u001b[43m \u001b[49m\u001b[43mnum_workers\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m8\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m input_frames, output_frames \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28miter\u001b[39m(train_loader):\n\u001b[1;32m 8\u001b[0m \u001b[38;5;28mprint\u001b[39m(input_frames\u001b[38;5;241m.\u001b[39mshape, output_frames\u001b[38;5;241m.\u001b[39mshape) \u001b[38;5;66;03m# [B, T, C, H, W]\u001b[39;00m\n",
|
| 350 |
+
"Cell \u001b[0;32mIn[2], line 50\u001b[0m, in \u001b[0;36mload_data\u001b[0;34m(data_path, batch_size, val_batch_size, horizon, num_workers)\u001b[0m\n\u001b[1;32m 47\u001b[0m train_size \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mint\u001b[39m(\u001b[38;5;241m0.8\u001b[39m \u001b[38;5;241m*\u001b[39m total_samples)\n\u001b[1;32m 48\u001b[0m val_size \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mint\u001b[39m(\u001b[38;5;241m0.1\u001b[39m \u001b[38;5;241m*\u001b[39m total_samples)\n\u001b[0;32m---> 50\u001b[0m train_dataset \u001b[38;5;241m=\u001b[39m \u001b[43mdata_tensor\u001b[49m[:train_size]\n\u001b[1;32m 51\u001b[0m val_dataset \u001b[38;5;241m=\u001b[39m data_tensor[train_size:train_size\u001b[38;5;241m+\u001b[39mval_size]\n\u001b[1;32m 52\u001b[0m test_dataset \u001b[38;5;241m=\u001b[39m data_tensor[train_size\u001b[38;5;241m+\u001b[39mval_size:]\n",
|
| 351 |
+
"\u001b[0;31mNameError\u001b[0m: name 'data_tensor' is not defined"
|
| 352 |
+
]
|
| 353 |
+
}
|
| 354 |
+
],
|
| 355 |
+
"source": [
|
| 356 |
+
"if __name__ == '__main__':\n",
|
| 357 |
+
" train_loader, val_loader, test_loader, mean, std = load_data(data_path='/jizhicfs/easyluwu/ocean_project/kuro/ft_local',\n",
|
| 358 |
+
" batch_size=8, \n",
|
| 359 |
+
" val_batch_size=8, \n",
|
| 360 |
+
" horizon=10,\n",
|
| 361 |
+
" num_workers=8)\n",
|
| 362 |
+
" for input_frames, output_frames in iter(train_loader):\n",
|
| 363 |
+
" print(input_frames.shape, output_frames.shape) # [B, T, C, H, W]\n",
|
| 364 |
+
" break"
|
| 365 |
+
]
|
| 366 |
+
},
|
| 367 |
+
{
|
| 368 |
+
"cell_type": "code",
|
| 369 |
+
"execution_count": null,
|
| 370 |
+
"id": "95e177b0-9b93-42b6-b809-350fadc23a9b",
|
| 371 |
+
"metadata": {},
|
| 372 |
+
"outputs": [],
|
| 373 |
+
"source": []
|
| 374 |
+
}
|
| 375 |
+
],
|
| 376 |
+
"metadata": {
|
| 377 |
+
"kernelspec": {
|
| 378 |
+
"display_name": "Python 3 (ipykernel)",
|
| 379 |
+
"language": "python",
|
| 380 |
+
"name": "python3"
|
| 381 |
+
},
|
| 382 |
+
"language_info": {
|
| 383 |
+
"codemirror_mode": {
|
| 384 |
+
"name": "ipython",
|
| 385 |
+
"version": 3
|
| 386 |
+
},
|
| 387 |
+
"file_extension": ".py",
|
| 388 |
+
"mimetype": "text/x-python",
|
| 389 |
+
"name": "python",
|
| 390 |
+
"nbconvert_exporter": "python",
|
| 391 |
+
"pygments_lexer": "ipython3",
|
| 392 |
+
"version": "3.8.20"
|
| 393 |
+
}
|
| 394 |
+
},
|
| 395 |
+
"nbformat": 4,
|
| 396 |
+
"nbformat_minor": 5
|
| 397 |
+
}
|
Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader-checkpoint.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import netCDF4 as nc
|
| 3 |
+
import torch
|
| 4 |
+
import torch.utils.data as data
|
| 5 |
+
|
| 6 |
+
args = {
|
| 7 |
+
'data_path': '/data/workspace/yancheng/MM/OriSTP/dataset/05res',
|
| 8 |
+
'ocean_lead_time': 10,
|
| 9 |
+
'atmosphere_lead_time': 10,
|
| 10 |
+
'shuffle': True,
|
| 11 |
+
'variables_input': [0, 2, 3, 4],
|
| 12 |
+
'variables_future': [2, 3, 4],
|
| 13 |
+
'variables_output': [0],
|
| 14 |
+
'lon_start': 0,
|
| 15 |
+
'lat_start': 0,
|
| 16 |
+
'lon_end': 720,
|
| 17 |
+
'lat_end': 360,
|
| 18 |
+
'ds_factor': 1,
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
class train_Dataset(data.Dataset):
    """Training split (years 1993-2017) of the 0.25-degree MHW NetCDF archive.

    Each index maps to a (year, day) pair; a sample concatenates the
    look-back window of the input variables with the future atmospheric
    forcing along the channel axis, and targets the ocean variable over
    the forecast horizon.
    """

    def __init__(self, args):
        # args: dict with 'data_path', lead times, variable index lists,
        # lat/lon crop bounds and a 'ds_factor' downsampling stride.
        super(train_Dataset, self).__init__()
        self.args = args
        self.years = range(1993, 2018)
        # Start at day 12 so the look-back window never reads before day 0;
        # a stride of 3 days separates consecutive samples.
        self.dates = range(12, 357, 3)
        self.indices = [(m, n) for m in self.years for n in self.dates]

    def __getitem__(self, index):
        years, dates = self.indices[index]
        args = self.args
        lat = slice(args['lat_start'], args['lat_end'], args['ds_factor'])
        lon = slice(args['lon_start'], args['lon_end'], args['ds_factor'])

        ds = nc.Dataset(f'{args["data_path"]}/025res_{years}.nc')
        try:
            var = ds.variables['mhws_variables']
            # Look-back window of the input variables, ending at `dates`.
            input_now = var[dates - args['atmosphere_lead_time'] + 1:dates + 1,
                            args['variables_input'], lat, lon]
            # Known future atmospheric forcing over the forecast horizon.
            input_future = var[dates + 1:dates + args['atmosphere_lead_time'] + 1,
                               args['variables_future'], lat, lon]
            # Ocean variable to predict over the forecast horizon.
            target = var[dates + 1:dates + args['ocean_lead_time'] + 1,
                         args['variables_output'], lat, lon]
        finally:
            # The original code leaked one open file handle per sample;
            # slicing above already materialised the data, so close now.
            ds.close()

        inputs = np.concatenate([input_now, input_future], 1)
        inputs = torch.nan_to_num(torch.tensor(inputs, dtype=torch.float32), nan=0.0)
        target = torch.nan_to_num(torch.tensor(target, dtype=torch.float32), nan=0.0)

        return inputs, target

    def __len__(self):
        return len(self.indices)
|
| 58 |
+
|
| 59 |
+
class test_Dataset(data.Dataset):
    """Test split (years 2018-2021) of the 0.25-degree MHW NetCDF archive.

    Sampling and tensor construction mirror the training split exactly;
    only the year range differs.
    """

    def __init__(self, args):
        # args: dict with 'data_path', lead times, variable index lists,
        # lat/lon crop bounds and a 'ds_factor' downsampling stride.
        super(test_Dataset, self).__init__()
        self.args = args
        self.years = range(2018, 2022)
        # Start at day 12 so the look-back window never reads before day 0.
        self.dates = range(12, 357, 3)
        self.indices = [(m, n) for m in self.years for n in self.dates]

    def __getitem__(self, index):
        years, dates = self.indices[index]
        args = self.args
        lat = slice(args['lat_start'], args['lat_end'], args['ds_factor'])
        lon = slice(args['lon_start'], args['lon_end'], args['ds_factor'])

        ds = nc.Dataset(f'{args["data_path"]}/025res_{years}.nc')
        try:
            var = ds.variables['mhws_variables']
            # Look-back window of the input variables, ending at `dates`.
            input_now = var[dates - args['atmosphere_lead_time'] + 1:dates + 1,
                            args['variables_input'], lat, lon]
            # Known future atmospheric forcing over the forecast horizon.
            input_future = var[dates + 1:dates + args['atmosphere_lead_time'] + 1,
                               args['variables_future'], lat, lon]
            # Ocean variable to predict over the forecast horizon.
            target = var[dates + 1:dates + args['ocean_lead_time'] + 1,
                         args['variables_output'], lat, lon]
        finally:
            # Close per sample instead of leaking one handle per __getitem__.
            ds.close()

        inputs = np.concatenate([input_now, input_future], 1)
        inputs = torch.nan_to_num(torch.tensor(inputs, dtype=torch.float32), nan=0.0)
        target = torch.nan_to_num(torch.tensor(target, dtype=torch.float32), nan=0.0)

        return inputs, target

    def __len__(self):
        return len(self.indices)
|
| 96 |
+
|
| 97 |
+
if __name__ == '__main__':
    # Smoke test: build both splits and print the shape of one training batch.
    args = {
        'data_path': '/jizhicfs/easyluwu/dataset/ft_local',
        'ocean_lead_time': 10,
        'atmosphere_lead_time': 10,
        'shuffle': True,
        'variables_input': [1, 2, 3, 4],
        'variables_future': [2, 3, 4],
        'variables_output': [1],
        'lon_start': 0,
        'lat_start': 0,
        'lon_end': 1440,
        'lat_end': 720,
        'ds_factor': 1,
    }

    train_loader = data.DataLoader(train_Dataset(args), batch_size=2)
    test_loader = data.DataLoader(test_Dataset(args), batch_size=2)

    for inputs, targets in train_loader:
        print(inputs.shape, targets.shape)
        break
|
Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader_high_kuro-checkpoint.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import netCDF4 as nc
|
| 3 |
+
import torch
|
| 4 |
+
import torch.utils.data as data
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class train_Dataset(data.Dataset):
    """Training split (years 1993-2017) of the normalised KURO NetCDF files.

    Each index maps to a (year, day) pair; a sample concatenates the
    look-back window of the input variables with the future atmospheric
    forcing along the channel axis, and targets the ocean variable over
    the forecast horizon.
    """

    def __init__(self, args):
        # args: dict with 'data_path', lead times, variable index lists,
        # lat/lon crop bounds and a 'ds_factor' downsampling stride.
        super(train_Dataset, self).__init__()
        self.args = args
        self.years = range(1993, 2018)
        # Start at day 12 so the look-back window never reads before day 0.
        self.dates = range(12, 357, 3)
        self.indices = [(m, n) for m in self.years for n in self.dates]

    def __getitem__(self, index):
        years, dates = self.indices[index]
        args = self.args
        lat = slice(args['lat_start'], args['lat_end'], args['ds_factor'])
        lon = slice(args['lon_start'], args['lon_end'], args['ds_factor'])

        ds = nc.Dataset(f'{args["data_path"]}/KURO_{years}_norm.nc')
        try:
            var = ds.variables['mhw_variables']
            # Look-back window of the input variables, ending at `dates`.
            input_now = var[dates - args['atmosphere_lead_time'] + 1:dates + 1,
                            args['variables_input'], lat, lon]
            # Known future atmospheric forcing over the forecast horizon.
            input_future = var[dates + 1:dates + args['atmosphere_lead_time'] + 1,
                               args['variables_future'], lat, lon]
            # Ocean variable to predict over the forecast horizon.
            target = var[dates + 1:dates + args['ocean_lead_time'] + 1,
                         args['variables_output'], lat, lon]
        finally:
            # Close per sample instead of leaking one handle per __getitem__.
            ds.close()

        inputs = np.concatenate([input_now, input_future], 1)
        inputs = torch.nan_to_num(torch.tensor(inputs, dtype=torch.float32), nan=0.0)
        target = torch.nan_to_num(torch.tensor(target, dtype=torch.float32), nan=0.0)

        return inputs, target

    def __len__(self):
        return len(self.indices)
|
| 45 |
+
|
| 46 |
+
class test_Dataset(data.Dataset):
    """Test split (years 2018-2020) of the normalised KURO NetCDF files.

    Sampling and tensor construction mirror the training split exactly;
    only the year range differs.
    """

    def __init__(self, args):
        # args: dict with 'data_path', lead times, variable index lists,
        # lat/lon crop bounds and a 'ds_factor' downsampling stride.
        super(test_Dataset, self).__init__()
        self.args = args
        self.years = range(2018, 2021)
        # Start at day 12 so the look-back window never reads before day 0.
        self.dates = range(12, 357, 3)
        self.indices = [(m, n) for m in self.years for n in self.dates]

    def __getitem__(self, index):
        years, dates = self.indices[index]
        args = self.args
        lat = slice(args['lat_start'], args['lat_end'], args['ds_factor'])
        lon = slice(args['lon_start'], args['lon_end'], args['ds_factor'])

        ds = nc.Dataset(f'{args["data_path"]}/KURO_{years}_norm.nc')
        try:
            var = ds.variables['mhw_variables']
            # Look-back window of the input variables, ending at `dates`.
            input_now = var[dates - args['atmosphere_lead_time'] + 1:dates + 1,
                            args['variables_input'], lat, lon]
            # Known future atmospheric forcing over the forecast horizon.
            input_future = var[dates + 1:dates + args['atmosphere_lead_time'] + 1,
                               args['variables_future'], lat, lon]
            # Ocean variable to predict over the forecast horizon.
            target = var[dates + 1:dates + args['ocean_lead_time'] + 1,
                         args['variables_output'], lat, lon]
        finally:
            # Close per sample instead of leaking one handle per __getitem__.
            ds.close()

        inputs = np.concatenate([input_now, input_future], 1)
        inputs = torch.nan_to_num(torch.tensor(inputs, dtype=torch.float32), nan=0.0)
        target = torch.nan_to_num(torch.tensor(target, dtype=torch.float32), nan=0.0)

        return inputs, target

    def __len__(self):
        return len(self.indices)
|
Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader_kuroshio-checkpoint.ipynb
ADDED
|
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 4,
|
| 6 |
+
"id": "f7a16b9b-07cb-46af-b891-d225ca8a8b2c",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [
|
| 9 |
+
{
|
| 10 |
+
"name": "stderr",
|
| 11 |
+
"output_type": "stream",
|
| 12 |
+
"text": [
|
| 13 |
+
"/miniconda3/envs/haowu/lib/python3.10/site-packages/torch/cuda/__init__.py:129: UserWarning: CUDA initialization: The NVIDIA driver on your system is too old (found version 11000). Please update your GPU driver by downloading and installing a new version from the URL: http://www.nvidia.com/Download/index.aspx Alternatively, go to: https://pytorch.org to install a PyTorch version that has been compiled with your version of the CUDA driver. (Triggered internally at /pytorch/c10/cuda/CUDAFunctions.cpp:109.)\n",
|
| 14 |
+
" return torch._C._cuda_getDeviceCount() > 0\n"
|
| 15 |
+
]
|
| 16 |
+
},
|
| 17 |
+
{
|
| 18 |
+
"name": "stdout",
|
| 19 |
+
"output_type": "stream",
|
| 20 |
+
"text": [
|
| 21 |
+
"torch.Size([10, 2, 256, 256])torch.Size([10, 2, 256, 256]) torch.Size([10, 2, 256, 256])torch.Size([10, 2, 256, 256])torch.Size([10, 2, 256, 256])torch.Size([10, 2, 256, 256])torch.Size([10, 2, 256, 256])torch.Size([10, 2, 256, 256])torch.Size([10, 2, 256, 256])\n",
|
| 22 |
+
"\n",
|
| 23 |
+
" torch.Size([10, 2, 256, 256]) \n",
|
| 24 |
+
"torch.Size([10, 2, 256, 256])torch.Size([10, 2, 256, 256])torch.Size([10, 2, 256, 256])torch.Size([10, 2, 256, 256])torch.Size([10, 2, 256, 256])\n",
|
| 25 |
+
"\n",
|
| 26 |
+
"\n",
|
| 27 |
+
"\n",
|
| 28 |
+
" torch.Size([10, 2, 256, 256])\n",
|
| 29 |
+
"torch.Size([10, 2, 256, 256])torch.Size([10, 2, 256, 256]) torch.Size([10, 2, 256, 256])torch.Size([10, 2, 256, 256])\n",
|
| 30 |
+
"\n",
|
| 31 |
+
"torch.Size([10, 2, 256, 256])torch.Size([10, 2, 256, 256])torch.Size([10, 2, 256, 256])torch.Size([10, 2, 256, 256]) \n",
|
| 32 |
+
"torch.Size([10, 2, 256, 256]) torch.Size([10, 2, 256, 256]) torch.Size([10, 2, 256, 256])torch.Size([10, 2, 256, 256])torch.Size([10, 2, 256, 256])\n",
|
| 33 |
+
"\n",
|
| 34 |
+
"\n",
|
| 35 |
+
" torch.Size([10, 2, 256, 256])\n",
|
| 36 |
+
"torch.Size([10, 2, 256, 256]) torch.Size([10, 2, 256, 256])\n",
|
| 37 |
+
"torch.Size([1, 10, 2, 256, 256]) torch.Size([1, 10, 2, 256, 256])\n",
|
| 38 |
+
"输入数据范围: [-1.54, 1.66]\n",
|
| 39 |
+
"NaN值存在性: False\n",
|
| 40 |
+
"Inf值存在性: False\n"
|
| 41 |
+
]
|
| 42 |
+
}
|
| 43 |
+
],
|
| 44 |
+
"source": [
|
| 45 |
+
"import torch\n",
|
| 46 |
+
"import torch.distributed as dist\n",
|
| 47 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
| 48 |
+
"from torch.utils.data.distributed import DistributedSampler\n",
|
| 49 |
+
"import netCDF4 as nc\n",
|
| 50 |
+
"import numpy as np\n",
|
| 51 |
+
"\n",
|
| 52 |
+
"class OceanCurrentDataset(Dataset):\n",
|
| 53 |
+
" def __init__(self, data_path, input_steps=10, output_steps=10, transform=None):\n",
|
| 54 |
+
" \"\"\"\n",
|
| 55 |
+
" 海洋流数据集类\n",
|
| 56 |
+
" :param data_path: NetCDF文件路径\n",
|
| 57 |
+
" :param input_steps: 输入时间步数\n",
|
| 58 |
+
" :param output_steps: 预测时间步数\n",
|
| 59 |
+
" :param transform: 数据增强变换\n",
|
| 60 |
+
" \"\"\"\n",
|
| 61 |
+
" self.data_path = data_path\n",
|
| 62 |
+
" self.input_steps = input_steps\n",
|
| 63 |
+
" self.output_steps = output_steps\n",
|
| 64 |
+
" self.transform = transform\n",
|
| 65 |
+
" self.total_steps = input_steps + output_steps\n",
|
| 66 |
+
" \n",
|
| 67 |
+
" # 加载并预处理数据\n",
|
| 68 |
+
" self.data = self._load_and_process_data()\n",
|
| 69 |
+
" self.mean, self.std = 0, 1\n",
|
| 70 |
+
"\n",
|
| 71 |
+
" def _load_and_process_data(self):\n",
|
| 72 |
+
" \"\"\"加载和处理NetCDF数据\"\"\"\n",
|
| 73 |
+
" with nc.Dataset(self.data_path, 'r') as ds:\n",
|
| 74 |
+
" # 处理缺失值\n",
|
| 75 |
+
" def process_var(var):\n",
|
| 76 |
+
" arr = var[:]\n",
|
| 77 |
+
" if '_FillValue' in var.ncattrs():\n",
|
| 78 |
+
" fill_value = var._FillValue\n",
|
| 79 |
+
" arr = np.ma.masked_values(arr, fill_value).filled(np.nan)\n",
|
| 80 |
+
" return torch.nan_to_num(torch.FloatTensor(arr), nan=0.0)\n",
|
| 81 |
+
"\n",
|
| 82 |
+
" # 加载并合并UV分量\n",
|
| 83 |
+
" ugos = process_var(ds['ugos']) # (time, lat, lon)\n",
|
| 84 |
+
" vgos = process_var(ds['vgos'])\n",
|
| 85 |
+
" \n",
|
| 86 |
+
" # 调整维度顺序 [time, channels, lat, lon]\n",
|
| 87 |
+
" return torch.stack([ugos, vgos], dim=1) \n",
|
| 88 |
+
"\n",
|
| 89 |
+
" def _compute_stats(self):\n",
|
| 90 |
+
" \"\"\"计算训练集的统计量\"\"\"\n",
|
| 91 |
+
" return torch.mean(self.data[:10000]), torch.std(self.data[:10000])\n",
|
| 92 |
+
"\n",
|
| 93 |
+
" def __len__(self):\n",
|
| 94 |
+
" return len(self.data) - self.total_steps + 1\n",
|
| 95 |
+
"\n",
|
| 96 |
+
" def __getitem__(self, idx):\n",
|
| 97 |
+
" window = self.data[idx:idx+self.total_steps] # [T_total, C, H, W]\n",
|
| 98 |
+
" \n",
|
| 99 |
+
" window = (window - self.mean) / self.std\n",
|
| 100 |
+
" \n",
|
| 101 |
+
" # 分割输入输出\n",
|
| 102 |
+
" input_seq = window[:self.input_steps]\n",
|
| 103 |
+
" target_seq = window[self.input_steps:]\n",
|
| 104 |
+
" print(input_seq.shape, target_seq.shape)\n",
|
| 105 |
+
" \n",
|
| 106 |
+
" if self.transform:\n",
|
| 107 |
+
" input_seq = self.transform(input_seq)\n",
|
| 108 |
+
" target_seq = self.transform(target_seq)\n",
|
| 109 |
+
" \n",
|
| 110 |
+
" return input_seq[:,:,::2,::2], target_seq[:,:,::2,::2]\n",
|
| 111 |
+
"\n",
|
| 112 |
+
"def create_dataloaders(config):\n",
|
| 113 |
+
" full_dataset = OceanCurrentDataset(\n",
|
| 114 |
+
" data_path=config['data_path'],\n",
|
| 115 |
+
" input_steps=config['input_steps'],\n",
|
| 116 |
+
" output_steps=config['output_steps']\n",
|
| 117 |
+
" )\n",
|
| 118 |
+
" \n",
|
| 119 |
+
" train_size = 10000 - config['input_steps'] - config['output_steps'] + 1\n",
|
| 120 |
+
" val_size = 500\n",
|
| 121 |
+
" test_size = len(full_dataset) - train_size - val_size\n",
|
| 122 |
+
" \n",
|
| 123 |
+
" train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(\n",
|
| 124 |
+
" full_dataset, [train_size, val_size, test_size],\n",
|
| 125 |
+
" generator=torch.Generator().manual_seed(config['seed'])\n",
|
| 126 |
+
" )\n",
|
| 127 |
+
" \n",
|
| 128 |
+
" # train_sampler = DistributedSampler(train_dataset, shuffle=True)\n",
|
| 129 |
+
" # val_sampler = DistributedSampler(val_dataset, shuffle=False)\n",
|
| 130 |
+
" # test_sampler = DistributedSampler(test_dataset, shuffle=False)\n",
|
| 131 |
+
" \n",
|
| 132 |
+
" dataloader_train = DataLoader(\n",
|
| 133 |
+
" train_dataset,\n",
|
| 134 |
+
" batch_size=config['batch_size'],\n",
|
| 135 |
+
" num_workers=config['num_workers'],\n",
|
| 136 |
+
" pin_memory=True,\n",
|
| 137 |
+
" drop_last=True\n",
|
| 138 |
+
" )\n",
|
| 139 |
+
" \n",
|
| 140 |
+
" dataloader_val = DataLoader(\n",
|
| 141 |
+
" val_dataset,\n",
|
| 142 |
+
" batch_size=config['val_batch_size'],\n",
|
| 143 |
+
" num_workers=config['num_workers'],\n",
|
| 144 |
+
" pin_memory=True,\n",
|
| 145 |
+
" drop_last=True\n",
|
| 146 |
+
" )\n",
|
| 147 |
+
" \n",
|
| 148 |
+
" dataloader_test = DataLoader(\n",
|
| 149 |
+
" test_dataset,\n",
|
| 150 |
+
" batch_size=config['val_batch_size'],\n",
|
| 151 |
+
" num_workers=config['num_workers'],\n",
|
| 152 |
+
" pin_memory=True,\n",
|
| 153 |
+
" drop_last=True\n",
|
| 154 |
+
" )\n",
|
| 155 |
+
" \n",
|
| 156 |
+
" return dataloader_train, dataloader_val, dataloader_test, full_dataset.mean, full_dataset.std\n",
|
| 157 |
+
"\n",
|
| 158 |
+
"config = {\n",
|
| 159 |
+
" 'data_path': '/jizhicfs/easyluwu/ocean_project/kuro/KURO.nc',\n",
|
| 160 |
+
" 'input_steps': 10,\n",
|
| 161 |
+
" 'output_steps': 10,\n",
|
| 162 |
+
" 'batch_size': 1,\n",
|
| 163 |
+
" 'val_batch_size': 1,\n",
|
| 164 |
+
" 'num_workers': 8,\n",
|
| 165 |
+
" 'seed': 42\n",
|
| 166 |
+
"}\n",
|
| 167 |
+
"# dist.init_process_group(backend='nccl')\n",
|
| 168 |
+
"\n",
|
| 169 |
+
"train_loader, val_loader, test_loader, data_mean, data_std = create_dataloaders(config)\n",
|
| 170 |
+
"\n",
|
| 171 |
+
"for sample_input, sample_target in train_loader:\n",
|
| 172 |
+
" print(sample_input.shape, sample_target.shape)\n",
|
| 173 |
+
" print(f\"输入数据范围: [{sample_input.min():.2f}, {sample_input.max():.2f}]\")\n",
|
| 174 |
+
" print(f\"NaN值存在性: {torch.isnan(sample_input).any().item()}\")\n",
|
| 175 |
+
" print(f\"Inf值存在性: {torch.isinf(sample_input).any().item()}\")\n",
|
| 176 |
+
" break"
|
| 177 |
+
]
|
| 178 |
+
},
|
| 179 |
+
{
|
| 180 |
+
"cell_type": "code",
|
| 181 |
+
"execution_count": null,
|
| 182 |
+
"id": "ad0379fc-13ba-48b3-84ad-5356f0e03968",
|
| 183 |
+
"metadata": {},
|
| 184 |
+
"outputs": [],
|
| 185 |
+
"source": []
|
| 186 |
+
}
|
| 187 |
+
],
|
| 188 |
+
"metadata": {
|
| 189 |
+
"kernelspec": {
|
| 190 |
+
"display_name": "Python 3 (ipykernel)",
|
| 191 |
+
"language": "python",
|
| 192 |
+
"name": "python3"
|
| 193 |
+
},
|
| 194 |
+
"language_info": {
|
| 195 |
+
"codemirror_mode": {
|
| 196 |
+
"name": "ipython",
|
| 197 |
+
"version": 3
|
| 198 |
+
},
|
| 199 |
+
"file_extension": ".py",
|
| 200 |
+
"mimetype": "text/x-python",
|
| 201 |
+
"name": "python",
|
| 202 |
+
"nbconvert_exporter": "python",
|
| 203 |
+
"pygments_lexer": "ipython3",
|
| 204 |
+
"version": "3.10.16"
|
| 205 |
+
}
|
| 206 |
+
},
|
| 207 |
+
"nbformat": 4,
|
| 208 |
+
"nbformat_minor": 5
|
| 209 |
+
}
|
Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader_kuroshio-checkpoint.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import torch.distributed as dist
|
| 4 |
+
from torch.utils.data import Dataset, DataLoader
|
| 5 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 6 |
+
import h5py
|
| 7 |
+
import numpy as np
|
| 8 |
+
from torch.utils.data import Dataset
|
| 9 |
+
from torch.utils.data import DataLoader
|
| 10 |
+
import torchvision.transforms as transforms
|
| 11 |
+
import torch.utils.data as data
|
| 12 |
+
import h5py
|
| 13 |
+
import torch
|
| 14 |
+
import numpy as np
|
| 15 |
+
import matplotlib.pyplot as plt
|
| 16 |
+
|
| 17 |
+
class WeatherDataset(Dataset):
    """Paired geostrophic ('uv_g') and Kuroshio ('uv_k') current sequences.

    Loads both HDF5 arrays fully into memory, stacks them along a channel
    axis, and serves (input_frames, output_frames) pairs covering the
    first and second `horizon` frames of each sequence.
    """

    def __init__(self, data_path, horizon, transform=None):
        with h5py.File(data_path, 'r') as h5:
            uv_g = torch.from_numpy(h5['uv_g'][:]).to(torch.float32)
            uv_k = torch.from_numpy(h5['uv_k'][:]).to(torch.float32)
        # Reorder each field to [N, T, 1, H, W] (assumes the file stores
        # [N, H, W, T] arrays -- TODO confirm), then stack on the channel axis.
        self.data_uv_g = uv_g.permute(0, 3, 1, 2).unsqueeze(2)
        self.data_uv_k = uv_k.permute(0, 3, 1, 2).unsqueeze(2)
        self.data_uv_gk = torch.cat([self.data_uv_g, self.data_uv_k], dim=2)
        self.transform = transform
        self.horizon = horizon
        # Identity normalisation constants (data assumed pre-normalised).
        self.mean = 0
        self.std = 1

    def __len__(self):
        return len(self.data_uv_gk)

    def __getitem__(self, idx):
        sequence = self.data_uv_gk[idx]
        input_frames = (sequence[:self.horizon] - self.mean) / self.std
        output_frames = (sequence[self.horizon:2 * self.horizon] - self.mean) / self.std
        return input_frames, output_frames
|
| 42 |
+
|
| 43 |
+
def load_data(data_path, batch_size, val_batch_size, horizon, num_workers):
    """Build train/val/test DataLoaders (80/10/10 split) over the Kuroshio HDF5 set.

    Requires torch.distributed to be initialised, since every loader uses a
    DistributedSampler. Returns (train, val, test, mean, std); mean/std are
    the identity constants (0, 1).
    """
    dataset = WeatherDataset(data_path=data_path + '/kg_all_20_mask_latmean.h5',
                             horizon=horizon, transform=None)
    dataset_size = len(dataset)
    train_size = int(dataset_size * 0.8)
    val_size = int(dataset_size * 0.1)
    test_size = dataset_size - train_size - val_size

    # Fixed-seed generator: without it each DDP rank would draw a different
    # random split, leaking samples between train and eval partitions.
    split_gen = torch.Generator().manual_seed(42)
    train_dataset, val_dataset, test_dataset = data.random_split(
        dataset, [train_size, val_size, test_size], generator=split_gen)

    train_sampler = DistributedSampler(train_dataset)
    val_sampler = DistributedSampler(val_dataset)
    test_sampler = DistributedSampler(test_dataset)

    dataloader_train = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler,
                                  pin_memory=True, num_workers=num_workers, drop_last=True)
    dataloader_validation = DataLoader(val_dataset, batch_size=val_batch_size, sampler=val_sampler,
                                       pin_memory=True, num_workers=num_workers, drop_last=True)
    dataloader_test = DataLoader(test_dataset, batch_size=val_batch_size, sampler=test_sampler,
                                 pin_memory=True, num_workers=num_workers, drop_last=True)
    mean, std = 0, 1

    return dataloader_train, dataloader_validation, dataloader_test, mean, std
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
|
Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader_kuroshio_G_uv-checkpoint.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import torch.distributed as dist
|
| 4 |
+
from torch.utils.data import Dataset, DataLoader
|
| 5 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 6 |
+
import h5py
|
| 7 |
+
import numpy as np
|
| 8 |
+
from torch.utils.data import Dataset
|
| 9 |
+
from torch.utils.data import DataLoader
|
| 10 |
+
import torchvision.transforms as transforms
|
| 11 |
+
import torch.utils.data as data
|
| 12 |
+
import h5py
|
| 13 |
+
import torch
|
| 14 |
+
import numpy as np
|
| 15 |
+
import matplotlib.pyplot as plt
|
| 16 |
+
|
| 17 |
+
class WeatherDataset(Dataset):
    """Sliding-window dataset over geostrophic current fields (u_g, v_g).

    Reads the two velocity components from an HDF5 file, stacks them along
    the channel axis, and serves (input window, target window) pairs where
    each window spans ``horizon`` consecutive time steps.
    """

    def __init__(self, data_path, horizon, transform=None):
        with h5py.File(data_path, 'r') as f:
            # assumes raw arrays are laid out (N, H, W, T) — TODO confirm
            u = self._to_sequence(f['u_g'][:])
            v = self._to_sequence(f['v_g'][:])
        self.data_uv_g = u
        self.data_uv_k = v
        # (N, T, 2, H, W): both velocity components stacked on channels.
        self.data_uv_gk = torch.cat([u, v], dim=2)
        self.transform = transform
        self.horizon = horizon
        # Identity normalisation — samples are served unscaled.
        self.mean = 0
        self.std = 1

    @staticmethod
    def _to_sequence(arr):
        """Convert one raw component array to a float32 (N, T, 1, H, W) tensor."""
        t = torch.from_numpy(arr).to(torch.float32)
        return t.permute(0, 3, 1, 2).unsqueeze(2)

    def __len__(self):
        return self.data_uv_gk.shape[0]

    def __getitem__(self, idx):
        sample = self.data_uv_gk[idx]
        h = self.horizon
        past = (sample[:h] - self.mean) / self.std
        future = (sample[h:2 * h] - self.mean) / self.std
        return past, future
|
| 42 |
+
|
| 43 |
+
def load_data(data_path, batch_size, val_batch_size, horizon, num_workers):
    """Build distributed train/val/test DataLoaders for the Kuroshio dataset.

    Args:
        data_path: Directory containing ``kg_all_20_mask_latmean.h5``.
        batch_size: Per-rank batch size for training.
        val_batch_size: Per-rank batch size for validation and test.
        horizon: Number of time steps per input/target window.
        num_workers: DataLoader worker processes per rank.

    Returns:
        (train_loader, val_loader, test_loader, mean, std) — mean/std are the
        identity statistics (0, 1) because the data is not normalised.
    """
    # os.path.join handles a trailing separator in data_path correctly,
    # unlike the previous plain string concatenation.
    dataset = WeatherDataset(data_path=os.path.join(data_path, 'kg_all_20_mask_latmean.h5'),
                             horizon=horizon, transform=None)

    # 80/10/10 split; the test set takes the remainder so the three sizes
    # always sum to len(dataset).
    dataset_size = len(dataset)
    train_size = int(dataset_size * 0.8)
    val_size = int(dataset_size * 0.1)
    test_size = dataset_size - train_size - val_size

    # Seed the split so every DDP rank builds identical subsets; an unseeded
    # split can differ across ranks and leak validation data into training.
    # (Matches the seeded-split convention of the sibling ruiqi loaders.)
    train_dataset, val_dataset, test_dataset = data.random_split(
        dataset, [train_size, val_size, test_size],
        generator=torch.Generator().manual_seed(42))

    train_sampler = DistributedSampler(train_dataset)
    val_sampler = DistributedSampler(val_dataset)
    test_sampler = DistributedSampler(test_dataset)

    dataloader_train = DataLoader(train_dataset, batch_size=batch_size,
                                  sampler=train_sampler, pin_memory=False,
                                  num_workers=num_workers, drop_last=True)
    dataloader_validation = DataLoader(val_dataset, batch_size=val_batch_size,
                                       sampler=val_sampler, pin_memory=False,
                                       num_workers=num_workers, drop_last=True)
    dataloader_test = DataLoader(test_dataset, batch_size=val_batch_size,
                                 sampler=test_sampler, pin_memory=False,
                                 num_workers=num_workers, drop_last=True)

    # Data is served unnormalised; callers receive identity statistics.
    mean, std = 0, 1
    return dataloader_train, dataloader_validation, dataloader_test, mean, std
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
|
Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader_kuroshio_K_uv-checkpoint.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import torch.distributed as dist
|
| 4 |
+
from torch.utils.data import Dataset, DataLoader
|
| 5 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 6 |
+
import h5py
|
| 7 |
+
import numpy as np
|
| 8 |
+
from torch.utils.data import Dataset
|
| 9 |
+
from torch.utils.data import DataLoader
|
| 10 |
+
import torchvision.transforms as transforms
|
| 11 |
+
import torch.utils.data as data
|
| 12 |
+
import h5py
|
| 13 |
+
import torch
|
| 14 |
+
import numpy as np
|
| 15 |
+
import matplotlib.pyplot as plt
|
| 16 |
+
|
| 17 |
+
class WeatherDataset(Dataset):
    """Sliding-window dataset over Kuroshio current fields (u_k, v_k).

    Reads the two velocity components from an HDF5 file, stacks them along
    the channel axis, and serves (input window, target window) pairs where
    each window spans ``horizon`` consecutive time steps.
    """

    def __init__(self, data_path, horizon, transform=None):
        with h5py.File(data_path, 'r') as f:
            # assumes raw arrays are laid out (N, H, W, T) — TODO confirm
            u = self._to_sequence(f['u_k'][:])
            v = self._to_sequence(f['v_k'][:])
        self.data_uv_g = u
        self.data_uv_k = v
        # (N, T, 2, H, W): both velocity components stacked on channels.
        self.data_uv_gk = torch.cat([u, v], dim=2)
        self.transform = transform
        self.horizon = horizon
        # Identity normalisation — samples are served unscaled.
        self.mean = 0
        self.std = 1

    @staticmethod
    def _to_sequence(arr):
        """Convert one raw component array to a float32 (N, T, 1, H, W) tensor."""
        t = torch.from_numpy(arr).to(torch.float32)
        return t.permute(0, 3, 1, 2).unsqueeze(2)

    def __len__(self):
        return self.data_uv_gk.shape[0]

    def __getitem__(self, idx):
        sample = self.data_uv_gk[idx]
        h = self.horizon
        past = (sample[:h] - self.mean) / self.std
        future = (sample[h:2 * h] - self.mean) / self.std
        return past, future
|
| 42 |
+
|
| 43 |
+
def load_data(data_path, batch_size, val_batch_size, horizon, num_workers):
    """Build distributed train/val/test DataLoaders for the Kuroshio dataset.

    Args:
        data_path: Directory containing ``kg_all_20_mask_latmean.h5``.
        batch_size: Per-rank batch size for training.
        val_batch_size: Per-rank batch size for validation and test.
        horizon: Number of time steps per input/target window.
        num_workers: DataLoader worker processes per rank.

    Returns:
        (train_loader, val_loader, test_loader, mean, std) — mean/std are the
        identity statistics (0, 1) because the data is not normalised.
    """
    # os.path.join handles a trailing separator in data_path correctly,
    # unlike the previous plain string concatenation.
    dataset = WeatherDataset(data_path=os.path.join(data_path, 'kg_all_20_mask_latmean.h5'),
                             horizon=horizon, transform=None)

    # 80/10/10 split; the test set takes the remainder so the three sizes
    # always sum to len(dataset).
    dataset_size = len(dataset)
    train_size = int(dataset_size * 0.8)
    val_size = int(dataset_size * 0.1)
    test_size = dataset_size - train_size - val_size

    # Seed the split so every DDP rank builds identical subsets; an unseeded
    # split can differ across ranks and leak validation data into training.
    # (Matches the seeded-split convention of the sibling ruiqi loaders.)
    train_dataset, val_dataset, test_dataset = data.random_split(
        dataset, [train_size, val_size, test_size],
        generator=torch.Generator().manual_seed(42))

    train_sampler = DistributedSampler(train_dataset)
    val_sampler = DistributedSampler(val_dataset)
    test_sampler = DistributedSampler(test_dataset)

    dataloader_train = DataLoader(train_dataset, batch_size=batch_size,
                                  sampler=train_sampler, pin_memory=False,
                                  num_workers=num_workers, drop_last=True)
    dataloader_validation = DataLoader(val_dataset, batch_size=val_batch_size,
                                       sampler=val_sampler, pin_memory=False,
                                       num_workers=num_workers, drop_last=True)
    dataloader_test = DataLoader(test_dataset, batch_size=val_batch_size,
                                 sampler=test_sampler, pin_memory=False,
                                 num_workers=num_workers, drop_last=True)

    # Data is served unnormalised; callers receive identity statistics.
    mean, std = 0, 1
    return dataloader_train, dataloader_validation, dataloader_test, mean, std
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
|
Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader_kuroshio_ruiqi-checkpoint.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.distributed as dist
|
| 3 |
+
from torch.utils.data import Dataset, DataLoader
|
| 4 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 5 |
+
import netCDF4 as nc
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
class OceanCurrentDataset(Dataset):
    """Sliding-window dataset of ocean surface currents from a NetCDF file.

    Serves (input, target) pairs of consecutive time windows built from the
    geostrophic velocity components ``ugos`` and ``vgos``.
    """

    def __init__(self, data_path, input_steps=10, output_steps=10, transform=None):
        """
        :param data_path: path to the NetCDF file
        :param input_steps: number of input time steps
        :param output_steps: number of predicted time steps
        :param transform: optional augmentation applied to both windows
        """
        self.data_path = data_path
        self.input_steps = input_steps
        self.output_steps = output_steps
        self.transform = transform
        self.total_steps = input_steps + output_steps

        # Load the full series once; normalisation is the identity (0, 1).
        self.data = self._load_and_process_data()
        self.mean, self.std = 0, 1

    def _load_and_process_data(self):
        """Read u/v components, fill missing values with 0, stack on channels."""
        def to_tensor(var):
            raw = var[:]
            if '_FillValue' in var.ncattrs():
                raw = np.ma.masked_values(raw, var._FillValue).filled(np.nan)
            return torch.nan_to_num(torch.FloatTensor(raw), nan=0.0)

        with nc.Dataset(self.data_path, 'r') as ds:
            ugos = to_tensor(ds['ugos'])  # presumably (time, lat, lon) — verify
            vgos = to_tensor(ds['vgos'])

        # -> [time, channels, lat, lon]
        return torch.stack([ugos, vgos], dim=1)

    def _compute_stats(self):
        """Mean/std over the first 10000 frames (currently unused)."""
        return torch.mean(self.data[:10000]), torch.std(self.data[:10000])

    def __len__(self):
        return len(self.data) - self.total_steps + 1

    def __getitem__(self, idx):
        clip = self.data[idx:idx + self.total_steps]  # [T_total, C, H, W]
        clip = (clip - self.mean) / self.std

        # Split the window into the observed past and the forecast target.
        past = clip[:self.input_steps]
        future = clip[self.input_steps:]

        if self.transform:
            past = self.transform(past)
            future = self.transform(future)

        return past, future
|
| 66 |
+
|
| 67 |
+
def create_dataloaders(config):
    """Build distributed train/val/test loaders over the ocean-current data.

    ``config`` supplies data_path, input_steps, output_steps, batch_size,
    val_batch_size, num_workers and seed.  Returns the three loaders plus the
    dataset's (identity) mean and std.
    """
    dataset = OceanCurrentDataset(
        data_path=config['data_path'],
        input_steps=config['input_steps'],
        output_steps=config['output_steps'],
    )

    # Windows covering roughly the first 10000 frames go to training,
    # 500 to validation, and the remainder to test.
    # NOTE(review): random_split shuffles windows, so train/val/test windows
    # can overlap in time — confirm this leakage is acceptable.
    n_train = 10000 - config['input_steps'] - config['output_steps'] + 1
    n_val = 500
    n_test = len(dataset) - n_train - n_val

    train_set, val_set, test_set = torch.utils.data.random_split(
        dataset, [n_train, n_val, n_test],
        generator=torch.Generator().manual_seed(config['seed']),
    )

    def make_loader(subset, batch, shuffle):
        # Each subset gets its own DistributedSampler for DDP sharding.
        return DataLoader(
            subset,
            batch_size=batch,
            sampler=DistributedSampler(subset, shuffle=shuffle),
            num_workers=config['num_workers'],
            pin_memory=True,
            drop_last=True,
        )

    loader_train = make_loader(train_set, config['batch_size'], True)
    loader_val = make_loader(val_set, config['val_batch_size'], False)
    loader_test = make_loader(test_set, config['val_batch_size'], False)

    return loader_train, loader_val, loader_test, dataset.mean, dataset.std
|
| 115 |
+
|
| 116 |
+
# config = {
|
| 117 |
+
# 'data_path': '/jizhicfs/easyluwu/ocean_project/kuro/KURO.nc',
|
| 118 |
+
# 'input_steps': 10,
|
| 119 |
+
# 'output_steps': 10,
|
| 120 |
+
# 'batch_size': 1,
|
| 121 |
+
# 'val_batch_size': 1,
|
| 122 |
+
# 'num_workers': 8,
|
| 123 |
+
# 'seed': 42
|
| 124 |
+
# }
|
| 125 |
+
# dist.init_process_group(backend='nccl')
|
| 126 |
+
|
| 127 |
+
# train_loader, val_loader, test_loader, data_mean, data_std = create_dataloaders(config)
|
| 128 |
+
|
| 129 |
+
# for sample_input, sample_target in train_loader:
|
| 130 |
+
# print(sample_input.shape, sample_target.shape)
|
| 131 |
+
# print(f"输入数据范围: [{sample_input.min():.2f}, {sample_input.max():.2f}]")
|
| 132 |
+
# print(f"NaN值存在性: {torch.isnan(sample_input).any().item()}")
|
| 133 |
+
# print(f"Inf值存在性: {torch.isinf(sample_input).any().item()}")
|
| 134 |
+
# break
|
Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader_kuroshio_ruiqi_128-checkpoint.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.distributed as dist
|
| 3 |
+
from torch.utils.data import Dataset, DataLoader
|
| 4 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 5 |
+
import netCDF4 as nc
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
class OceanCurrentDataset(Dataset):
    """Sliding-window dataset of ocean surface currents from a NetCDF file.

    Serves (input, target) pairs of consecutive time windows built from
    ``ugos``/``vgos``, spatially subsampled by taking every 2nd grid point.
    """

    def __init__(self, data_path, input_steps=10, output_steps=10, transform=None):
        """
        :param data_path: path to the NetCDF file
        :param input_steps: number of input time steps
        :param output_steps: number of predicted time steps
        :param transform: optional augmentation applied to both windows
        """
        self.data_path = data_path
        self.input_steps = input_steps
        self.output_steps = output_steps
        self.transform = transform
        self.total_steps = input_steps + output_steps

        # Load the full series once; normalisation is the identity (0, 1).
        self.data = self._load_and_process_data()
        self.mean, self.std = 0, 1

    def _load_and_process_data(self):
        """Read u/v components, fill missing values with 0, stack on channels."""
        def to_tensor(var):
            raw = var[:]
            if '_FillValue' in var.ncattrs():
                raw = np.ma.masked_values(raw, var._FillValue).filled(np.nan)
            return torch.nan_to_num(torch.FloatTensor(raw), nan=0.0)

        with nc.Dataset(self.data_path, 'r') as ds:
            ugos = to_tensor(ds['ugos'])  # presumably (time, lat, lon) — verify
            vgos = to_tensor(ds['vgos'])

        # -> [time, channels, lat, lon]
        return torch.stack([ugos, vgos], dim=1)

    def _compute_stats(self):
        """Mean/std over the first 10000 frames (currently unused)."""
        return torch.mean(self.data[:10000]), torch.std(self.data[:10000])

    def __len__(self):
        return len(self.data) - self.total_steps + 1

    def __getitem__(self, idx):
        clip = self.data[idx:idx + self.total_steps]  # [T_total, C, H, W]
        clip = (clip - self.mean) / self.std

        past = clip[:self.input_steps]
        future = clip[self.input_steps:]

        if self.transform:
            past = self.transform(past)
            future = self.transform(future)

        # Subsample the spatial grid by 2 in both directions.
        return past[:, :, ::2, ::2], future[:, :, ::2, ::2]
|
| 66 |
+
|
| 67 |
+
def create_dataloaders(config):
    """Build distributed train/val/test loaders over the ocean-current data.

    ``config`` supplies data_path, input_steps, output_steps, batch_size,
    val_batch_size, num_workers and seed.  Returns the three loaders plus the
    dataset's (identity) mean and std.
    """
    dataset = OceanCurrentDataset(
        data_path=config['data_path'],
        input_steps=config['input_steps'],
        output_steps=config['output_steps'],
    )

    # Windows covering roughly the first 10000 frames go to training,
    # 500 to validation, and the remainder to test.
    # NOTE(review): random_split shuffles windows, so train/val/test windows
    # can overlap in time — confirm this leakage is acceptable.
    n_train = 10000 - config['input_steps'] - config['output_steps'] + 1
    n_val = 500
    n_test = len(dataset) - n_train - n_val

    train_set, val_set, test_set = torch.utils.data.random_split(
        dataset, [n_train, n_val, n_test],
        generator=torch.Generator().manual_seed(config['seed']),
    )

    def make_loader(subset, batch, shuffle):
        # Each subset gets its own DistributedSampler for DDP sharding.
        return DataLoader(
            subset,
            batch_size=batch,
            sampler=DistributedSampler(subset, shuffle=shuffle),
            num_workers=config['num_workers'],
            pin_memory=True,
            drop_last=True,
        )

    loader_train = make_loader(train_set, config['batch_size'], True)
    loader_val = make_loader(val_set, config['val_batch_size'], False)
    loader_test = make_loader(test_set, config['val_batch_size'], False)

    return loader_train, loader_val, loader_test, dataset.mean, dataset.std
|
| 115 |
+
|
| 116 |
+
# config = {
|
| 117 |
+
# 'data_path': '/jizhicfs/easyluwu/ocean_project/kuro/KURO.nc',
|
| 118 |
+
# 'input_steps': 10,
|
| 119 |
+
# 'output_steps': 10,
|
| 120 |
+
# 'batch_size': 1,
|
| 121 |
+
# 'val_batch_size': 1,
|
| 122 |
+
# 'num_workers': 8,
|
| 123 |
+
# 'seed': 42
|
| 124 |
+
# }
|
| 125 |
+
# dist.init_process_group(backend='nccl')
|
| 126 |
+
|
| 127 |
+
# train_loader, val_loader, test_loader, data_mean, data_std = create_dataloaders(config)
|
| 128 |
+
|
| 129 |
+
# for sample_input, sample_target in train_loader:
|
| 130 |
+
# print(sample_input.shape, sample_target.shape)
|
| 131 |
+
# print(f"输入数据范围: [{sample_input.min():.2f}, {sample_input.max():.2f}]")
|
| 132 |
+
# print(f"NaN值存在性: {torch.isnan(sample_input).any().item()}")
|
| 133 |
+
# print(f"Inf值存在性: {torch.isinf(sample_input).any().item()}")
|
| 134 |
+
# break
|
Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader_kuroshio_ruiqi_64-checkpoint.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.distributed as dist
|
| 3 |
+
from torch.utils.data import Dataset, DataLoader
|
| 4 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 5 |
+
import netCDF4 as nc
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
class OceanCurrentDataset(Dataset):
    """Sliding-window dataset of ocean surface currents from a NetCDF file.

    Serves (input, target) pairs of consecutive time windows built from
    ``ugos``/``vgos``, spatially subsampled by taking every 4th grid point.
    """

    def __init__(self, data_path, input_steps=10, output_steps=10, transform=None):
        """
        :param data_path: path to the NetCDF file
        :param input_steps: number of input time steps
        :param output_steps: number of predicted time steps
        :param transform: optional augmentation applied to both windows
        """
        self.data_path = data_path
        self.input_steps = input_steps
        self.output_steps = output_steps
        self.transform = transform
        self.total_steps = input_steps + output_steps

        # Load the full series once; normalisation is the identity (0, 1).
        self.data = self._load_and_process_data()
        self.mean, self.std = 0, 1

    def _load_and_process_data(self):
        """Read u/v components, fill missing values with 0, stack on channels."""
        def to_tensor(var):
            raw = var[:]
            if '_FillValue' in var.ncattrs():
                raw = np.ma.masked_values(raw, var._FillValue).filled(np.nan)
            return torch.nan_to_num(torch.FloatTensor(raw), nan=0.0)

        with nc.Dataset(self.data_path, 'r') as ds:
            ugos = to_tensor(ds['ugos'])  # presumably (time, lat, lon) — verify
            vgos = to_tensor(ds['vgos'])

        # -> [time, channels, lat, lon]
        return torch.stack([ugos, vgos], dim=1)

    def _compute_stats(self):
        """Mean/std over the first 10000 frames (currently unused)."""
        return torch.mean(self.data[:10000]), torch.std(self.data[:10000])

    def __len__(self):
        return len(self.data) - self.total_steps + 1

    def __getitem__(self, idx):
        clip = self.data[idx:idx + self.total_steps]  # [T_total, C, H, W]
        clip = (clip - self.mean) / self.std

        past = clip[:self.input_steps]
        future = clip[self.input_steps:]

        if self.transform:
            past = self.transform(past)
            future = self.transform(future)

        # Subsample the spatial grid by 4 in both directions.
        return past[:, :, ::4, ::4], future[:, :, ::4, ::4]
|
| 66 |
+
|
| 67 |
+
def create_dataloaders(config):
    """Build distributed train/val/test loaders over the ocean-current data.

    ``config`` supplies data_path, input_steps, output_steps, batch_size,
    val_batch_size, num_workers and seed.  Returns the three loaders plus the
    dataset's (identity) mean and std.
    """
    dataset = OceanCurrentDataset(
        data_path=config['data_path'],
        input_steps=config['input_steps'],
        output_steps=config['output_steps'],
    )

    # Windows covering roughly the first 10000 frames go to training,
    # 500 to validation, and the remainder to test.
    # NOTE(review): random_split shuffles windows, so train/val/test windows
    # can overlap in time — confirm this leakage is acceptable.
    n_train = 10000 - config['input_steps'] - config['output_steps'] + 1
    n_val = 500
    n_test = len(dataset) - n_train - n_val

    train_set, val_set, test_set = torch.utils.data.random_split(
        dataset, [n_train, n_val, n_test],
        generator=torch.Generator().manual_seed(config['seed']),
    )

    def make_loader(subset, batch, shuffle):
        # Each subset gets its own DistributedSampler for DDP sharding.
        return DataLoader(
            subset,
            batch_size=batch,
            sampler=DistributedSampler(subset, shuffle=shuffle),
            num_workers=config['num_workers'],
            pin_memory=True,
            drop_last=True,
        )

    loader_train = make_loader(train_set, config['batch_size'], True)
    loader_val = make_loader(val_set, config['val_batch_size'], False)
    loader_test = make_loader(test_set, config['val_batch_size'], False)

    return loader_train, loader_val, loader_test, dataset.mean, dataset.std
|
| 115 |
+
|
| 116 |
+
# config = {
|
| 117 |
+
# 'data_path': '/jizhicfs/easyluwu/ocean_project/kuro/KURO.nc',
|
| 118 |
+
# 'input_steps': 10,
|
| 119 |
+
# 'output_steps': 10,
|
| 120 |
+
# 'batch_size': 1,
|
| 121 |
+
# 'val_batch_size': 1,
|
| 122 |
+
# 'num_workers': 8,
|
| 123 |
+
# 'seed': 42
|
| 124 |
+
# }
|
| 125 |
+
# dist.init_process_group(backend='nccl')
|
| 126 |
+
|
| 127 |
+
# train_loader, val_loader, test_loader, data_mean, data_std = create_dataloaders(config)
|
| 128 |
+
|
| 129 |
+
# for sample_input, sample_target in train_loader:
|
| 130 |
+
# print(sample_input.shape, sample_target.shape)
|
| 131 |
+
# print(f"输入数据范围: [{sample_input.min():.2f}, {sample_input.max():.2f}]")
|
| 132 |
+
# print(f"NaN值存在性: {torch.isnan(sample_input).any().item()}")
|
| 133 |
+
# print(f"Inf值存在性: {torch.isinf(sample_input).any().item()}")
|
| 134 |
+
# break
|
Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader_kuroshio_ruiqi_single-checkpoint.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.utils.data import Dataset, DataLoader
|
| 3 |
+
import netCDF4 as nc
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
class OceanCurrentDataset(Dataset):
    """Single-process sliding-window dataset of ocean surface currents.

    Serves (input, target) pairs of consecutive time windows built from the
    geostrophic velocity components ``ugos`` and ``vgos``.
    """

    def __init__(self, data_path, input_steps=10, output_steps=10, transform=None):
        self.data_path = data_path
        self.input_steps = input_steps
        self.output_steps = output_steps
        self.transform = transform
        self.total_steps = input_steps + output_steps

        # Load the series and record its statistics (not applied in __getitem__).
        self.data = self._load_and_process_data()
        self.mean, self.std = self._compute_stats()

    def _load_and_process_data(self):
        """Read u/v components, fill missing values with 0, stack on channels."""
        def to_tensor(var):
            raw = var[:]
            if '_FillValue' in var.ncattrs():
                raw = np.ma.masked_values(raw, var._FillValue).filled(np.nan)
            return torch.nan_to_num(torch.FloatTensor(raw), nan=0.0)

        with nc.Dataset(self.data_path, 'r') as ds:
            ugos = to_tensor(ds['ugos'])  # presumably (time, lat, lon) — verify
            vgos = to_tensor(ds['vgos'])

        return torch.stack([ugos, vgos], dim=1)  # [time, channels, lat, lon]

    def _compute_stats(self):
        """Mean/std over the first 10000 frames."""
        return torch.mean(self.data[:10000]), torch.std(self.data[:10000])

    def __len__(self):
        return len(self.data) - self.total_steps + 1

    def __getitem__(self, idx):
        clip = self.data[idx:idx + self.total_steps]  # [T_total, C, H, W]
        # NOTE(review): normalisation here is the identity even though
        # mean/std are computed in __init__ — confirm this is intended.
        clip = (clip - 0) / 1

        past = clip[:self.input_steps]
        future = clip[self.input_steps:]

        if self.transform:
            past = self.transform(past)
            future = self.transform(future)

        return past, future
|
| 52 |
+
|
| 53 |
+
def create_dataloaders(config):
    """Build single-process train/val/test loaders over the current dataset.

    ``config`` supplies data_path, input_steps, output_steps, batch_size,
    val_batch_size, num_workers and seed.  Returns the three loaders plus
    identity statistics (0, 1) — the data is served unnormalised.
    """
    dataset = OceanCurrentDataset(
        data_path=config['data_path'],
        input_steps=config['input_steps'],
        output_steps=config['output_steps'],
    )

    # Windows covering roughly the first 10000 frames go to training,
    # 500 to validation, and the remainder to test.
    n_train = 10000 - config['input_steps'] - config['output_steps'] + 1
    n_val = 500
    n_test = len(dataset) - n_train - n_val

    train_set, val_set, test_set = torch.utils.data.random_split(
        dataset, [n_train, n_val, n_test],
        generator=torch.Generator().manual_seed(config['seed']),
    )

    def make_loader(subset, batch, shuffle):
        # Plain shuffling — no DistributedSampler in the single-process variant.
        return DataLoader(
            subset,
            batch_size=batch,
            shuffle=shuffle,
            num_workers=config['num_workers'],
            pin_memory=True,
            drop_last=True,
        )

    loader_train = make_loader(train_set, config['batch_size'], True)
    loader_val = make_loader(val_set, config['val_batch_size'], False)
    loader_test = make_loader(test_set, config['val_batch_size'], False)

    return loader_train, loader_val, loader_test, 0, 1
|
Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader_test-checkpoint.ipynb
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [],
|
| 3 |
+
"metadata": {},
|
| 4 |
+
"nbformat": 4,
|
| 5 |
+
"nbformat_minor": 5
|
| 6 |
+
}
|
Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/input_output_animation-checkpoint.gif
ADDED
|
Git LFS Details
|
Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/kuro_vis-checkpoint.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/kuroshio_animation-checkpoint.gif
ADDED
|
Git LFS Details
|