easylearning commited on
Commit
fa26901
·
verified ·
1 Parent(s): 1ee2b6d

Upload 205 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +6 -0
  2. Exp3_Kuroshio_forecasting/.DS_Store +0 -0
  3. Exp3_Kuroshio_forecasting/checkpoints/Kuro_ConvLSTM_exp1_20250311_best_model.pth +3 -0
  4. Exp3_Kuroshio_forecasting/checkpoints/Kuro_Dit_exp1_20250221_best_model.pth +3 -0
  5. Exp3_Kuroshio_forecasting/checkpoints/Kuro_Dit_exp1_20250222_best_model.pth +3 -0
  6. Exp3_Kuroshio_forecasting/checkpoints/Kuro_Dit_exp1_20250223_best_model.pth +3 -0
  7. Exp3_Kuroshio_forecasting/checkpoints/Kuro_Dit_exp1_20250224_best_model.pth +3 -0
  8. Exp3_Kuroshio_forecasting/checkpoints/Kuro_Dit_exp2_20250224_best_model.pth +3 -0
  9. Exp3_Kuroshio_forecasting/checkpoints/Kuro_Dit_exp2_20250316_best_model.pth +3 -0
  10. Exp3_Kuroshio_forecasting/checkpoints/Kuro_Kno_exp1_20250226_best_model.pth +3 -0
  11. Exp3_Kuroshio_forecasting/checkpoints/Kuro_Kno_exp2_20250225_best_model.pth +3 -0
  12. Exp3_Kuroshio_forecasting/checkpoints/Kuro_Simvp_exp1_20250224_best_model.pth +3 -0
  13. Exp3_Kuroshio_forecasting/checkpoints/Kuro_Simvp_exp_128_20250324_best_model.pth +3 -0
  14. Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_K_uv_20250218_exp1_best_model.pth +3 -0
  15. Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_K_uv_20250218_exp2_best_model.pth +3 -0
  16. Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_exp1_128_20250322_best_model.pth +3 -0
  17. Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_exp1_20250221_best_model.pth +3 -0
  18. Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_exp1_20250222_best_model.pth +3 -0
  19. Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_exp1_20250224_best_model.pth +3 -0
  20. Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_exp1_20250224_best_model_prediction.h5 +3 -0
  21. Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_exp1_64_20250323_best_model.pth +3 -0
  22. Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_exp2_20241107_best_model.pth +3 -0
  23. Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_exp3_20241107_best_model.pth +3 -0
  24. Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_exp3_20241111_best_model.pth +3 -0
  25. Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_exp_20241107_best_model.pth +3 -0
  26. Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_multi_finetune_20250227_best_model.pth +3 -0
  27. Exp3_Kuroshio_forecasting/checkpoints/Kuro_U_net_exp1_20250225_best_model.pth +3 -0
  28. Exp3_Kuroshio_forecasting/checkpoints/Kuro_U_net_exp1_20250226_best_model.pth +3 -0
  29. Exp3_Kuroshio_forecasting/checkpoints/Kuro_U_net_exp2_20250226_best_model.pth +3 -0
  30. Exp3_Kuroshio_forecasting/checkpoints/Kuro_Unet_exp_128_20250324_best_model.pth +3 -0
  31. Exp3_Kuroshio_forecasting/checkpoints/Triton_Gulf_uv_20250218_exp1_best_model.pth +3 -0
  32. Exp3_Kuroshio_forecasting/checkpoints/Triton_Kuroshio_uv_20250218_exp1_best_model.pth +3 -0
  33. Exp3_Kuroshio_forecasting/checkpoints/dit_kuro_256_20250227_best_model.pth +3 -0
  34. Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/Untitled-checkpoint.ipynb +6 -0
  35. Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/current_animation-checkpoint.gif +3 -0
  36. Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader-checkpoint.ipynb +397 -0
  37. Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader-checkpoint.py +122 -0
  38. Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader_high_kuro-checkpoint.py +82 -0
  39. Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader_kuroshio-checkpoint.ipynb +209 -0
  40. Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader_kuroshio-checkpoint.py +69 -0
  41. Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader_kuroshio_G_uv-checkpoint.py +69 -0
  42. Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader_kuroshio_K_uv-checkpoint.py +69 -0
  43. Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader_kuroshio_ruiqi-checkpoint.py +134 -0
  44. Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader_kuroshio_ruiqi_128-checkpoint.py +134 -0
  45. Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader_kuroshio_ruiqi_64-checkpoint.py +134 -0
  46. Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader_kuroshio_ruiqi_single-checkpoint.py +96 -0
  47. Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader_test-checkpoint.ipynb +6 -0
  48. Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/input_output_animation-checkpoint.gif +3 -0
  49. Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/kuro_vis-checkpoint.ipynb +0 -0
  50. Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/kuroshio_animation-checkpoint.gif +3 -0
.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/current_animation-checkpoint.gif filter=lfs diff=lfs merge=lfs -text
37
+ Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/input_output_animation-checkpoint.gif filter=lfs diff=lfs merge=lfs -text
38
+ Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/kuroshio_animation-checkpoint.gif filter=lfs diff=lfs merge=lfs -text
39
+ Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/ocean_currents_animation-checkpoint.gif filter=lfs diff=lfs merge=lfs -text
40
+ Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/sample_animation-checkpoint.gif filter=lfs diff=lfs merge=lfs -text
41
+ Exp3_Kuroshio_forecasting/plt_triton/nmi_vis.ipynb filter=lfs diff=lfs merge=lfs -text
Exp3_Kuroshio_forecasting/.DS_Store ADDED
Binary file (6.15 kB). View file
 
Exp3_Kuroshio_forecasting/checkpoints/Kuro_ConvLSTM_exp1_20250311_best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:68e12e494d088a481aa995c63c709dc208890abe9da52ac6f72742981a5658cc
3
+ size 11610
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Dit_exp1_20250221_best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:63595e507a023ffaf26d6d1d7d3ec7b4f3dadbdf87e5c7881b8e1c9bc598ee83
3
+ size 75550190
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Dit_exp1_20250222_best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f39ba8c5caaef358aba27d54b8ef392a5a51d4836480753dba4d898565a13a94
3
+ size 75550056
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Dit_exp1_20250223_best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e00a1db14d32761d9f1ea1660dbd64c7c33d23f09358be805292b587c1c71dda
3
+ size 75550200
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Dit_exp1_20250224_best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e15e71e84d50c3bd848dd1bd1a888c7396deade2855f0f334fd7edb19a763b80
3
+ size 75550204
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Dit_exp2_20250224_best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b6bef4b0385afabcbb6a0c677a568e378de0b76a52a06dfbafd071a5bae24591
3
+ size 75550709
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Dit_exp2_20250316_best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f708b01a88610e1517094eb1da50ea24d99d7e41a5cac7638aaecfc5fc0b9cee
3
+ size 63095505
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Kno_exp1_20250226_best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fe9e87f65bd7b0e3ea3f25d826332065c787b3ef8c0479b18bf13701a6ede152
3
+ size 529476
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Kno_exp2_20250225_best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0b0ea4a92f7be963bfb50bb4c6d8976fb98b3f6a2236c351ffcccbe03239909a
3
+ size 99835562
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Simvp_exp1_20250224_best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:489e579df4993f8b7e24606dc7773f6aae4823cc5fbec0ad31d32be4b304ca5c
3
+ size 19040548
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Simvp_exp_128_20250324_best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:613ba5274130915e187780bb4b29586a0dcb2f991a94e7696ec20468bb07f97d
3
+ size 19040464
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_K_uv_20250218_exp1_best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5e3f0a2d53bb67432564819e7e08fa35c15f46a898e0d8056cfbc3fdf78c8703
3
+ size 378552683
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_K_uv_20250218_exp2_best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9502f5eab2ce1a8ce6cc8961e9fafb201e62f0eed638ff1f3c6894dc1103cfdf
3
+ size 378552813
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_exp1_128_20250322_best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d00a458f45ea7f8253f71ec5fb54aee5b4db91bc3b206d7c9c1a6a9f6e61f884
3
+ size 378552684
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_exp1_20250221_best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:671add7c61a1efa5544703e212dd6aa5845107977578ba0333d8772524daa301
3
+ size 397465196
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_exp1_20250222_best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c31f88b68ded17d0abee7bb76f0dfbc73b94173839b866c81b592cbd407c208d
3
+ size 397465203
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_exp1_20250224_best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:940b8236f041d3c847412a62cdac18e7f7cccc5f04059dc632d03167ca781760
3
+ size 378552801
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_exp1_20250224_best_model_prediction.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:37f30d197d020780343ab8c9054d2d5943d2560bd7b655bb22874250217d398f
3
+ size 68163584
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_exp1_64_20250323_best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c29fbe97c5e335b97a2cba5c2d63f8e46b26c7ee8df8e95a001e9d6180496b8e
3
+ size 378552678
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_exp2_20241107_best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7abe16e33f1e63941f77a02f136639d9a928d41c63f935feb40b161e7a468c6b
3
+ size 378552823
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_exp3_20241107_best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8d8e14d8f13ab5530b1129ef408b64260c99c4acd4974d80f6a24925fbdf9c16
3
+ size 378552675
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_exp3_20241111_best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4a7c63b24f362dd6b929f1fc003949f843a8229040c12f94f54c90b22d0c16fe
3
+ size 378556268
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_exp_20241107_best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6e54888a509a62f9805009527d153209b16c9861ea3226619523e92b3b879672
3
+ size 378552694
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Triton_multi_finetune_20250227_best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:58c4db8010b0f292ac4bc93df237cf107f4f21643317a648c47c33e650b216c9
3
+ size 432755108
Exp3_Kuroshio_forecasting/checkpoints/Kuro_U_net_exp1_20250225_best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7d4a08539f0188ea40e21b4b2b189ef3c24dd44ab683bed220525e3c84681927
3
+ size 99835639
Exp3_Kuroshio_forecasting/checkpoints/Kuro_U_net_exp1_20250226_best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:928f8e6b4d14093c2f5a007d3a2b7bcec4e34adbca102b646db884ad19e61e10
3
+ size 124189508
Exp3_Kuroshio_forecasting/checkpoints/Kuro_U_net_exp2_20250226_best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:64bbfd30c0a223c0a7bf002291b34bdf62eb7de532321e0a1305b420ebbff8e6
3
+ size 99901810
Exp3_Kuroshio_forecasting/checkpoints/Kuro_Unet_exp_128_20250324_best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a6fbbf634b2e46b116f94759b6ee669d508f8c03b8caba95b8f4eb713291a0dc
3
+ size 30872161
Exp3_Kuroshio_forecasting/checkpoints/Triton_Gulf_uv_20250218_exp1_best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f94637d268f4bcbb77bda05bd7cf32cce1390ecdf56bbfd75f9f8cc6a2202eee
3
+ size 378552693
Exp3_Kuroshio_forecasting/checkpoints/Triton_Kuroshio_uv_20250218_exp1_best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d9b7d665a91e73d33ecd27bf5dcbfcb07c104cfd6eb33c442726c30b96bd2cae
3
+ size 378552695
Exp3_Kuroshio_forecasting/checkpoints/dit_kuro_256_20250227_best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:08389ce9e69da332168b798f6790544b3e9ff6c1fa8432c320d83a4d973ae1f7
3
+ size 63092615
Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/Untitled-checkpoint.ipynb ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [],
3
+ "metadata": {},
4
+ "nbformat": 4,
5
+ "nbformat_minor": 5
6
+ }
Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/current_animation-checkpoint.gif ADDED

Git LFS Details

  • SHA256: b041bd12464292ae0bcd2a8e5126f023a73c5e950d8a1a184ebcd1bd465dc152
  • Pointer size: 132 Bytes
  • Size of remote file: 3.09 MB
Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader-checkpoint.ipynb ADDED
@@ -0,0 +1,397 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "f3b16ba8-ad82-45c1-8119-b6c61e7311b8",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "Input shape: torch.Size([32, 10, 2, 128, 128])\n",
14
+ "Target shape: torch.Size([32, 5, 2, 128, 128])\n"
15
+ ]
16
+ }
17
+ ],
18
+ "source": [
19
+ "import h5py\n",
20
+ "import numpy as np\n",
21
+ "import torch\n",
22
+ "from torch.utils.data import Dataset, DataLoader\n",
23
+ "\n",
24
+ "class KuroshioDataset(Dataset):\n",
25
+ " def __init__(self, data, input_length, target_length, downsample_factor=1):\n",
26
+ " \"\"\"\n",
27
+ " Args:\n",
28
+ " data: Tensor of shape (num_samples, num_timesteps, C, H, W)\n",
29
+ " input_length: Number of input time steps (T_in)\n",
30
+ " target_length: Number of prediction time steps (T_out)\n",
31
+ " downsample_factor: Spatial downsampling factor\n",
32
+ " \"\"\"\n",
33
+ " super().__init__()\n",
34
+ " self.data = data\n",
35
+ " self.input_length = input_length\n",
36
+ " self.target_length = target_length\n",
37
+ " self.downsample_factor = downsample_factor\n",
38
+ "\n",
39
+ " # Validate time dimensions\n",
40
+ " self.num_samples, self.num_timesteps, self.C, self.H, self.W = data.shape\n",
41
+ " self.max_t_start = self.num_timesteps - self.input_length - self.target_length\n",
42
+ " assert self.max_t_start >= 0, \"Not enough timesteps for input and output\"\n",
43
+ "\n",
44
+ " # Generate sample indices (sample_idx, t_start)\n",
45
+ " self.sample_indices = []\n",
46
+ " for s in range(self.num_samples):\n",
47
+ " for t_start in range(self.max_t_start + 1):\n",
48
+ " self.sample_indices.append((s, t_start))\n",
49
+ "\n",
50
+ " def __len__(self):\n",
51
+ " return len(self.sample_indices)\n",
52
+ "\n",
53
+ " def __getitem__(self, idx):\n",
54
+ " s, t_start = self.sample_indices[idx]\n",
55
+ " \n",
56
+ " # Extract sequences\n",
57
+ " input_end = t_start + self.input_length\n",
58
+ " output_end = input_end + self.target_length\n",
59
+ " \n",
60
+ " input_seq = self.data[s, t_start:input_end] # (T_in, C, H, W)\n",
61
+ " target_seq = self.data[s, input_end:output_end] # (T_out, C, H, W)\n",
62
+ "\n",
63
+ " # Apply downsampling\n",
64
+ " if self.downsample_factor > 1:\n",
65
+ " dsf = self.downsample_factor\n",
66
+ " input_seq = input_seq[..., ::dsf, ::dsf]\n",
67
+ " target_seq = target_seq[..., ::dsf, ::dsf]\n",
68
+ "\n",
69
+ " return input_seq.float(), target_seq.float()\n",
70
+ "\n",
71
+ "def load_datasets(file_path, args):\n",
72
+ " # Load and preprocess data\n",
73
+ " with h5py.File(file_path, 'r') as f:\n",
74
+ " u_k = np.transpose(f['u_k'][:], (0, 3, 1, 2)) # (2046, 50, 128, 128)\n",
75
+ " v_k = np.transpose(f['v_k'][:], (0, 3, 1, 2))\n",
76
+ " \n",
77
+ " # Combine u and v channels\n",
78
+ " combined = np.stack([u_k, v_k], axis=2) # (2046, 50, 2, 128, 128)\n",
79
+ " data_tensor = torch.tensor(combined, dtype=torch.float32)\n",
80
+ "\n",
81
+ " # Split dataset\n",
82
+ " total_samples = 2046\n",
83
+ " train_size = int(0.8 * total_samples)\n",
84
+ " val_size = int(0.1 * total_samples)\n",
85
+ " \n",
86
+ " train_data = data_tensor[:train_size]\n",
87
+ " val_data = data_tensor[train_size:train_size+val_size]\n",
88
+ " test_data = data_tensor[train_size+val_size:]\n",
89
+ "\n",
90
+ " # Create datasets\n",
91
+ " train_dataset = KuroshioDataset(train_data, \n",
92
+ " args['input_length'],\n",
93
+ " args['target_length'],\n",
94
+ " args['downsample_factor'])\n",
95
+ " \n",
96
+ " val_dataset = KuroshioDataset(val_data,\n",
97
+ " args['input_length'],\n",
98
+ " args['target_length'],\n",
99
+ " args['downsample_factor'])\n",
100
+ " \n",
101
+ " test_dataset = KuroshioDataset(test_data,\n",
102
+ " args['input_length'],\n",
103
+ " args['target_length'],\n",
104
+ " args['downsample_factor'])\n",
105
+ "\n",
106
+ " return train_dataset, val_dataset, test_dataset\n",
107
+ "\n",
108
+ "# Example usage\n",
109
+ "if __name__ == \"__main__\":\n",
110
+ " config = {\n",
111
+ " 'input_length': 10, # T_in: 输入时间步数\n",
112
+ " 'target_length': 5, # T_out: 预测时间步数\n",
113
+ " 'downsample_factor': 1 # 空间下采样因子\n",
114
+ " }\n",
115
+ "\n",
116
+ " # 加载数据集\n",
117
+ " train_ds, val_ds, test_ds = load_datasets('./Kuroshio_window_data.h5', config)\n",
118
+ "\n",
119
+ " # 创建DataLoader\n",
120
+ " batch_size = 32\n",
121
+ " train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)\n",
122
+ " val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)\n",
123
+ " test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)\n",
124
+ "\n",
125
+ " # 验证数据形状\n",
126
+ " sample_input, sample_target = next(iter(train_loader))\n",
127
+ " print(f\"Input shape: {sample_input.shape}\") # 应为 (B, T_in, 2, H, W)\n",
128
+ " print(f\"Target shape: {sample_target.shape}\") # 应为 (B, T_out, 2, H, W)"
129
+ ]
130
+ },
131
+ {
132
+ "cell_type": "code",
133
+ "execution_count": 4,
134
+ "id": "9c6b1e5c-7874-49f2-9004-c17470a3ae85",
135
+ "metadata": {},
136
+ "outputs": [
137
+ {
138
+ "name": "stdout",
139
+ "output_type": "stream",
140
+ "text": [
141
+ "可视化已保存为 kuroshio_animation.gif\n"
142
+ ]
143
+ }
144
+ ],
145
+ "source": [
146
+ "import h5py\n",
147
+ "import numpy as np\n",
148
+ "import torch\n",
149
+ "import matplotlib.pyplot as plt\n",
150
+ "import matplotlib.animation as animation\n",
151
+ "from matplotlib import gridspec\n",
152
+ "from torch.utils.data import Dataset, DataLoader\n",
153
+ "\n",
154
+ "\n",
155
+ "# 修正后的可视化函数\n",
156
+ "def create_visualization(input_seq, target_seq, sample_idx=0, downsample=4, fps=10):\n",
157
+ " # 数据准备\n",
158
+ " input_np = input_seq[sample_idx].cpu().numpy()\n",
159
+ " target_np = target_seq[sample_idx].cpu().numpy()\n",
160
+ " full_seq = np.concatenate([input_np, target_np], axis=0)\n",
161
+ " full_seq = np.transpose(full_seq, (0, 2, 3, 1)) # [T, H, W, C]\n",
162
+ " \n",
163
+ " # 创建网格\n",
164
+ " H, W = full_seq.shape[1], full_seq.shape[2]\n",
165
+ " x = np.arange(W)\n",
166
+ " y = np.arange(H)\n",
167
+ " X, Y = np.meshgrid(x, y)\n",
168
+ " X_ds, Y_ds = X[::downsample, ::downsample], Y[::downsample, ::downsample]\n",
169
+ " \n",
170
+ " # 计算速度幅值\n",
171
+ " speed = np.sqrt(full_seq[...,0]**2 + full_seq[...,1]**2)\n",
172
+ " speed_min, speed_max = speed.min(), speed.max()\n",
173
+ " \n",
174
+ " # 创建画布\n",
175
+ " fig = plt.figure(figsize=(15, 5), facecolor='white')\n",
176
+ " gs = gridspec.GridSpec(1, 3, width_ratios=[1, 1, 1])\n",
177
+ " ax1 = plt.subplot(gs[0])\n",
178
+ " ax2 = plt.subplot(gs[1])\n",
179
+ " ax3 = plt.subplot(gs[2])\n",
180
+ " \n",
181
+ " # 初始化子图\n",
182
+ " im1 = ax1.imshow(full_seq[0,...,0], origin='lower', cmap='RdBu_r', vmax=1, vmin=-1)\n",
183
+ " ax1.set_title(\"U Component\")\n",
184
+ " plt.colorbar(im1, ax=ax1)\n",
185
+ " \n",
186
+ " im2 = ax2.imshow(full_seq[0,...,1], origin='lower', cmap='RdBu_r', vmax=1, vmin=-1)\n",
187
+ " ax2.set_title(\"V Component\")\n",
188
+ " plt.colorbar(im2, ax=ax2)\n",
189
+ " \n",
190
+ " # 初始化矢量场\n",
191
+ " U = full_seq[0,...,0][::downsample, ::downsample]\n",
192
+ " V = full_seq[0,...,1][::downsample, ::downsample]\n",
193
+ " speed_initial = np.sqrt(U**2 + V**2)\n",
194
+ " quiver = ax3.quiver(X_ds, Y_ds, U, V, speed_initial, \n",
195
+ " cmap='RdBu_r', \n",
196
+ " scale=50, \n",
197
+ " width=0.003,\n",
198
+ " clim=[speed_min, speed_max])\n",
199
+ " plt.colorbar(quiver, ax=ax3, label='Flow Speed')\n",
200
+ " ax3.set_title(\"Vector Field\")\n",
201
+ " \n",
202
+ " # 统一设置\n",
203
+ " for ax in [ax1, ax2, ax3]:\n",
204
+ " ax.set_xticks([])\n",
205
+ " ax.set_yticks([])\n",
206
+ " ax.set_xlabel(f\"Timestep: 0/{full_seq.shape[0]-1}\")\n",
207
+ " \n",
208
+ " # 动画更新函数\n",
209
+ " def update(frame):\n",
210
+ " # 更新分量图\n",
211
+ " im1.set_data(full_seq[frame,...,0])\n",
212
+ " im2.set_data(full_seq[frame,...,1])\n",
213
+ " \n",
214
+ " # 更新矢量场\n",
215
+ " U = full_seq[frame,...,0][::downsample, ::downsample]\n",
216
+ " V = full_seq[frame,...,1][::downsample, ::downsample]\n",
217
+ " speed = np.sqrt(U**2 + V**2)\n",
218
+ " \n",
219
+ " quiver.set_UVC(U, V)\n",
220
+ " quiver.set_array(speed.flatten())\n",
221
+ " \n",
222
+ " # 更新时间标签\n",
223
+ " for ax in [ax1, ax2, ax3]:\n",
224
+ " ax.set_xlabel(f\"Timestep: {frame}/{full_seq.shape[0]-1}\")\n",
225
+ " \n",
226
+ " return [im1, im2, quiver]\n",
227
+ " \n",
228
+ " # 生成动画\n",
229
+ " ani = animation.FuncAnimation(fig, update, frames=full_seq.shape[0], \n",
230
+ " interval=1000//fps, blit=False)\n",
231
+ " ani.save('kuroshio_animation.gif', writer='pillow', fps=fps)\n",
232
+ " plt.close()\n",
233
+ " print(\"可视化已保存为 kuroshio_animation.gif\")\n",
234
+ "\n",
235
+ "# 完整使用示例\n",
236
+ "if __name__ == \"__main__\":\n",
237
+ " # 配置参数\n",
238
+ " config = {\n",
239
+ " 'input_length': 25,\n",
240
+ " 'target_length': 25,\n",
241
+ " 'downsample_factor': 1\n",
242
+ " }\n",
243
+ " \n",
244
+ " # 加载数据\n",
245
+ " train_ds, val_ds, test_ds = load_datasets('./Kuroshio_window_data.h5', config)\n",
246
+ " train_loader = DataLoader(train_ds, batch_size=10, shuffle=True)\n",
247
+ " \n",
248
+ " # 获取样本数据\n",
249
+ " sample_input, sample_target = next(iter(train_loader))\n",
250
+ " \n",
251
+ " # 生成可视化(关键参数调整)\n",
252
+ " create_visualization(\n",
253
+ " sample_input, \n",
254
+ " sample_target,\n",
255
+ " sample_idx=2, # 选择样本索引\n",
256
+ " downsample=1, # 矢量场密度(值越小越密集)\n",
257
+ " fps=4 # 动画帧率\n",
258
+ " )"
259
+ ]
260
+ },
261
+ {
262
+ "cell_type": "code",
263
+ "execution_count": 2,
264
+ "id": "d0454a79-3e01-49dd-b4c5-0aca0bd76bcf",
265
+ "metadata": {},
266
+ "outputs": [],
267
+ "source": [
268
+ "import os\n",
269
+ "import torch\n",
270
+ "import torch.distributed as dist\n",
271
+ "from torch.utils.data import Dataset, DataLoader\n",
272
+ "from torch.utils.data.distributed import DistributedSampler\n",
273
+ "import h5py\n",
274
+ "import numpy as np\n",
275
+ "from torch.utils.data import Dataset\n",
276
+ "from torch.utils.data import DataLoader\n",
277
+ "import torchvision.transforms as transforms\n",
278
+ "import torch.utils.data as data\n",
279
+ "import h5py\n",
280
+ "import torch\n",
281
+ "import numpy as np\n",
282
+ "import matplotlib.pyplot as plt\n",
283
+ "\n",
284
+ "class WeatherDataset(Dataset):\n",
285
+ " def __init__(self, data_path, horizon, transform=None):\n",
286
+ " with h5py.File(data_path, 'r') as f:\n",
287
+ " self.data_uv_g = f['u_k'][:] \n",
288
+ " self.data_uv_g = torch.from_numpy(self.data_uv_g).to(torch.float32)\n",
289
+ " self.data_uv_g = self.data_uv_g.permute(0, 3, 1, 2).unsqueeze_(2) \n",
290
+ " \n",
291
+ " self.data_uv_k = f['v_k'][:] \n",
292
+ " self.data_uv_k = torch.from_numpy(self.data_uv_k).to(torch.float32)\n",
293
+ " self.data_uv_k = self.data_uv_k.permute(0, 3, 1, 2).unsqueeze_(2) \n",
294
+ " self.data_uv_gk = torch.cat([self.data_uv_g, self.data_uv_k], dim=2)\n",
295
+ " self.transform = transform\n",
296
+ " self.horizon = horizon\n",
297
+ " self.mean = 0\n",
298
+ " self.std = 1\n",
299
+ " \n",
300
+ " def __len__(self):\n",
301
+ " return len(self.data_uv_gk)\n",
302
+ "\n",
303
+ " def __getitem__(self, idx):\n",
304
+ " input_frames = self.data_uv_gk[idx][:self.horizon]\n",
305
+ " output_frames = self.data_uv_gk[idx][self.horizon:2*self.horizon]\n",
306
+ " input_frames = (input_frames - self.mean) / self.std\n",
307
+ " output_frames = (output_frames - self.mean) / self.std\n",
308
+ " return input_frames, output_frames\n",
309
+ "\n",
310
+ "def load_data(data_path, batch_size, val_batch_size, horizon, num_workers):\n",
311
+ " dataset = WeatherDataset(data_path=data_path+'/kg_all_20_mask_latmean.h5', horizon=horizon, transform=None)\n",
312
+ " \n",
313
+ " total_samples = len(dataset)\n",
314
+ " train_size = int(0.8 * total_samples)\n",
315
+ " val_size = int(0.1 * total_samples)\n",
316
+ " \n",
317
+ " train_dataset = dataset[:train_size]\n",
318
+ " val_dataset = dataset[train_size:train_size+val_size]\n",
319
+ " test_dataset = dataset[train_size+val_size:]\n",
320
+ " \n",
321
+ " train_sampler = DistributedSampler(train_dataset)\n",
322
+ " val_sampler = DistributedSampler(val_dataset)\n",
323
+ " test_sampler = DistributedSampler(test_dataset)\n",
324
+ "\n",
325
+ " dataloader_train = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, pin_memory=False,\n",
326
+ " num_workers=num_workers, drop_last=True)\n",
327
+ " dataloader_validation = DataLoader(val_dataset, batch_size=val_batch_size, sampler=val_sampler, pin_memory=False,\n",
328
+ " num_workers=num_workers, drop_last=True)\n",
329
+ " dataloader_test = DataLoader(test_dataset, batch_size=val_batch_size, sampler=test_sampler, pin_memory=False,\n",
330
+ " num_workers=num_workers, drop_last=True)\n",
331
+ " mean, std = 0, 1\n",
332
+ "\n",
333
+ " return dataloader_train, dataloader_validation, dataloader_test, mean, std"
334
+ ]
335
+ },
336
+ {
337
+ "cell_type": "code",
338
+ "execution_count": 3,
339
+ "id": "6b051bb4-492a-4b4a-828a-6099ce9437b4",
340
+ "metadata": {},
341
+ "outputs": [
342
+ {
343
+ "ename": "NameError",
344
+ "evalue": "name 'data_tensor' is not defined",
345
+ "output_type": "error",
346
+ "traceback": [
347
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
348
+ "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
349
+ "Cell \u001b[0;32mIn[3], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;18m__name__\u001b[39m \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m__main__\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[0;32m----> 2\u001b[0m train_loader, val_loader, test_loader, mean, std \u001b[38;5;241m=\u001b[39m \u001b[43mload_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43m/jizhicfs/easyluwu/ocean_project/kuro/ft_local\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m8\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[1;32m 4\u001b[0m \u001b[43m \u001b[49m\u001b[43mval_batch_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m8\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[1;32m 5\u001b[0m \u001b[43m \u001b[49m\u001b[43mhorizon\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m10\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 6\u001b[0m \u001b[43m \u001b[49m\u001b[43mnum_workers\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m8\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m input_frames, output_frames \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28miter\u001b[39m(train_loader):\n\u001b[1;32m 8\u001b[0m \u001b[38;5;28mprint\u001b[39m(input_frames\u001b[38;5;241m.\u001b[39mshape, output_frames\u001b[38;5;241m.\u001b[39mshape) \u001b[38;5;66;03m# [B, T, C, H, W]\u001b[39;00m\n",
350
+ "Cell \u001b[0;32mIn[2], line 50\u001b[0m, in \u001b[0;36mload_data\u001b[0;34m(data_path, batch_size, val_batch_size, horizon, num_workers)\u001b[0m\n\u001b[1;32m 47\u001b[0m train_size \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mint\u001b[39m(\u001b[38;5;241m0.8\u001b[39m \u001b[38;5;241m*\u001b[39m total_samples)\n\u001b[1;32m 48\u001b[0m val_size \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mint\u001b[39m(\u001b[38;5;241m0.1\u001b[39m \u001b[38;5;241m*\u001b[39m total_samples)\n\u001b[0;32m---> 50\u001b[0m train_dataset \u001b[38;5;241m=\u001b[39m \u001b[43mdata_tensor\u001b[49m[:train_size]\n\u001b[1;32m 51\u001b[0m val_dataset \u001b[38;5;241m=\u001b[39m data_tensor[train_size:train_size\u001b[38;5;241m+\u001b[39mval_size]\n\u001b[1;32m 52\u001b[0m test_dataset \u001b[38;5;241m=\u001b[39m data_tensor[train_size\u001b[38;5;241m+\u001b[39mval_size:]\n",
351
+ "\u001b[0;31mNameError\u001b[0m: name 'data_tensor' is not defined"
352
+ ]
353
+ }
354
+ ],
355
+ "source": [
356
+ "if __name__ == '__main__':\n",
357
+ " train_loader, val_loader, test_loader, mean, std = load_data(data_path='/jizhicfs/easyluwu/ocean_project/kuro/ft_local',\n",
358
+ " batch_size=8, \n",
359
+ " val_batch_size=8, \n",
360
+ " horizon=10,\n",
361
+ " num_workers=8)\n",
362
+ " for input_frames, output_frames in iter(train_loader):\n",
363
+ " print(input_frames.shape, output_frames.shape) # [B, T, C, H, W]\n",
364
+ " break"
365
+ ]
366
+ },
367
+ {
368
+ "cell_type": "code",
369
+ "execution_count": null,
370
+ "id": "95e177b0-9b93-42b6-b809-350fadc23a9b",
371
+ "metadata": {},
372
+ "outputs": [],
373
+ "source": []
374
+ }
375
+ ],
376
+ "metadata": {
377
+ "kernelspec": {
378
+ "display_name": "Python 3 (ipykernel)",
379
+ "language": "python",
380
+ "name": "python3"
381
+ },
382
+ "language_info": {
383
+ "codemirror_mode": {
384
+ "name": "ipython",
385
+ "version": 3
386
+ },
387
+ "file_extension": ".py",
388
+ "mimetype": "text/x-python",
389
+ "name": "python",
390
+ "nbconvert_exporter": "python",
391
+ "pygments_lexer": "ipython3",
392
+ "version": "3.8.20"
393
+ }
394
+ },
395
+ "nbformat": 4,
396
+ "nbformat_minor": 5
397
+ }
Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader-checkpoint.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import netCDF4 as nc
3
+ import torch
4
+ import torch.utils.data as data
5
+
6
# Default configuration for the 0.5-degree-resolution MHW dataset.
args = {
    'data_path': '/data/workspace/yancheng/MM/OriSTP/dataset/05res',
    'ocean_lead_time': 10,        # ocean forecast horizon (time steps)
    'atmosphere_lead_time': 10,   # past/future atmosphere window (time steps)
    'shuffle': True,
    'variables_input': [0, 2, 3, 4],   # channels taken from the past window
    'variables_future': [2, 3, 4],     # forcing channels taken from the future window
    'variables_output': [0],           # predicted channel(s)
    'lon_start': 0,
    'lat_start': 0,
    'lon_end': 720,
    'lat_end': 360,
    'ds_factor': 1,               # spatial downsampling stride
}
20
+
21
class train_Dataset(data.Dataset):
    """Training split (years 1993-2017) of the MHW variables dataset.

    Each sample pairs past fields plus future atmospheric forcing (input)
    with the future ocean fields (target), all read from per-year NetCDF
    files named ``025res_<year>.nc``.
    """

    def __init__(self, args):
        super(train_Dataset, self).__init__()
        self.args = args
        self.years = range(1993, 2018)      # training years
        # Day indices with stride 3; margins leave room for the lead-time windows.
        self.dates = range(12, 357, 3)
        self.indices = [(m, n) for m in self.years for n in self.dates]

    def __getitem__(self, index):
        year, date = self.indices[index]
        a = self.args
        lat = slice(a['lat_start'], a['lat_end'], a['ds_factor'])
        lon = slice(a['lon_start'], a['lon_end'], a['ds_factor'])
        # BUGFIX: the netCDF handle was previously never closed, leaking one
        # open file descriptor per __getitem__ call; use the context manager.
        with nc.Dataset(f'{a["data_path"]}/025res_{year}.nc') as ds:
            var = ds.variables['mhws_variables']
            # Past window ending at `date` (inclusive).
            input_now = var[date - a['atmosphere_lead_time'] + 1:date + 1,
                            a['variables_input'], lat, lon]
            # Future atmospheric forcing over the forecast horizon.
            input_future = var[date + 1:date + a['atmosphere_lead_time'] + 1,
                               a['variables_future'], lat, lon]
            # Ocean target over the forecast horizon.
            target = var[date + 1:date + a['ocean_lead_time'] + 1,
                         a['variables_output'], lat, lon]
        inputs = np.concatenate([input_now, input_future], 1)
        # Land/fill values come back as NaN; zero them for the network.
        inputs = torch.nan_to_num(torch.tensor(inputs, dtype=torch.float32), nan=0.0)
        target = torch.nan_to_num(torch.tensor(target, dtype=torch.float32), nan=0.0)
        return inputs, target

    def __len__(self):
        return len(self.indices)
58
+
59
class test_Dataset(data.Dataset):
    """Test split (years 2018-2021) of the MHW variables dataset.

    Mirrors ``train_Dataset`` exactly except for the year range.
    """

    def __init__(self, args):
        super(test_Dataset, self).__init__()
        self.args = args
        self.years = range(2018, 2022)      # held-out years
        # Day indices with stride 3; margins leave room for the lead-time windows.
        self.dates = range(12, 357, 3)
        self.indices = [(m, n) for m in self.years for n in self.dates]

    def __getitem__(self, index):
        year, date = self.indices[index]
        a = self.args
        lat = slice(a['lat_start'], a['lat_end'], a['ds_factor'])
        lon = slice(a['lon_start'], a['lon_end'], a['ds_factor'])
        # BUGFIX: the netCDF handle was previously never closed, leaking one
        # open file descriptor per __getitem__ call; use the context manager.
        with nc.Dataset(f'{a["data_path"]}/025res_{year}.nc') as ds:
            var = ds.variables['mhws_variables']
            # Past window ending at `date` (inclusive).
            input_now = var[date - a['atmosphere_lead_time'] + 1:date + 1,
                            a['variables_input'], lat, lon]
            # Future atmospheric forcing over the forecast horizon.
            input_future = var[date + 1:date + a['atmosphere_lead_time'] + 1,
                               a['variables_future'], lat, lon]
            # Ocean target over the forecast horizon.
            target = var[date + 1:date + a['ocean_lead_time'] + 1,
                         a['variables_output'], lat, lon]
        inputs = np.concatenate([input_now, input_future], 1)
        # Land/fill values come back as NaN; zero them for the network.
        inputs = torch.nan_to_num(torch.tensor(inputs, dtype=torch.float32), nan=0.0)
        target = torch.nan_to_num(torch.tensor(target, dtype=torch.float32), nan=0.0)
        return inputs, target

    def __len__(self):
        return len(self.indices)
96
+
97
if __name__ == '__main__':
    # Smoke test: build both splits at full 0.25-degree extent and print
    # the shapes of one training batch.
    args = {
        'data_path': '/jizhicfs/easyluwu/dataset/ft_local',
        'ocean_lead_time': 10,
        'atmosphere_lead_time': 10,
        'shuffle': True,
        'variables_input': [1, 2, 3, 4],
        'variables_future': [2, 3, 4],
        'variables_output': [1],
        'lon_start': 0,
        'lat_start': 0,
        'lon_end': 1440,
        'lat_end': 720,
        'ds_factor': 1,
    }

    train_dataset = train_Dataset(args)
    test_dataset = test_Dataset(args)

    train_loader = data.DataLoader(train_dataset, batch_size=2)
    test_loader = data.DataLoader(test_dataset, batch_size=2)

    for inputs, targets in train_loader:
        print(inputs.shape, targets.shape)
        break
Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader_high_kuro-checkpoint.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import netCDF4 as nc
3
+ import torch
4
+ import torch.utils.data as data
5
+
6
+
7
+
8
class train_Dataset(data.Dataset):
    """Training split (years 1993-2017) of the normalised Kuroshio MHW dataset.

    Samples are read from per-year NetCDF files named ``KURO_<year>_norm.nc``;
    the input combines a past window with future atmospheric forcing, and the
    target is the future ocean field.
    """

    def __init__(self, args):
        super(train_Dataset, self).__init__()
        self.args = args
        self.years = range(1993, 2018)      # training years
        # Day indices with stride 3; margins leave room for the lead-time windows.
        self.dates = range(12, 357, 3)
        self.indices = [(m, n) for m in self.years for n in self.dates]

    def __getitem__(self, index):
        year, date = self.indices[index]
        a = self.args
        lat = slice(a['lat_start'], a['lat_end'], a['ds_factor'])
        lon = slice(a['lon_start'], a['lon_end'], a['ds_factor'])
        # BUGFIX: the netCDF handle was previously never closed, leaking one
        # open file descriptor per __getitem__ call; use the context manager.
        with nc.Dataset(f'{a["data_path"]}/KURO_{year}_norm.nc') as ds:
            var = ds.variables['mhw_variables']
            # Past window ending at `date` (inclusive).
            input_now = var[date - a['atmosphere_lead_time'] + 1:date + 1,
                            a['variables_input'], lat, lon]
            # Future atmospheric forcing over the forecast horizon.
            input_future = var[date + 1:date + a['atmosphere_lead_time'] + 1,
                               a['variables_future'], lat, lon]
            # Ocean target over the forecast horizon.
            target = var[date + 1:date + a['ocean_lead_time'] + 1,
                         a['variables_output'], lat, lon]
        inputs = np.concatenate([input_now, input_future], 1)
        # Land/fill values come back as NaN; zero them for the network.
        inputs = torch.nan_to_num(torch.tensor(inputs, dtype=torch.float32), nan=0.0)
        target = torch.nan_to_num(torch.tensor(target, dtype=torch.float32), nan=0.0)
        return inputs, target

    def __len__(self):
        return len(self.indices)
45
+
46
class test_Dataset(data.Dataset):
    """Test split (years 2018-2020) of the normalised Kuroshio MHW dataset.

    Mirrors ``train_Dataset`` exactly except for the year range.
    """

    def __init__(self, args):
        super(test_Dataset, self).__init__()
        self.args = args
        self.years = range(2018, 2021)      # held-out years
        # Day indices with stride 3; margins leave room for the lead-time windows.
        self.dates = range(12, 357, 3)
        self.indices = [(m, n) for m in self.years for n in self.dates]

    def __getitem__(self, index):
        year, date = self.indices[index]
        a = self.args
        lat = slice(a['lat_start'], a['lat_end'], a['ds_factor'])
        lon = slice(a['lon_start'], a['lon_end'], a['ds_factor'])
        # BUGFIX: the netCDF handle was previously never closed, leaking one
        # open file descriptor per __getitem__ call; use the context manager.
        with nc.Dataset(f'{a["data_path"]}/KURO_{year}_norm.nc') as ds:
            var = ds.variables['mhw_variables']
            # Past window ending at `date` (inclusive).
            input_now = var[date - a['atmosphere_lead_time'] + 1:date + 1,
                            a['variables_input'], lat, lon]
            # Future atmospheric forcing over the forecast horizon.
            input_future = var[date + 1:date + a['atmosphere_lead_time'] + 1,
                               a['variables_future'], lat, lon]
            # Ocean target over the forecast horizon.
            target = var[date + 1:date + a['ocean_lead_time'] + 1,
                         a['variables_output'], lat, lon]
        inputs = np.concatenate([input_now, input_future], 1)
        # Land/fill values come back as NaN; zero them for the network.
        inputs = torch.nan_to_num(torch.tensor(inputs, dtype=torch.float32), nan=0.0)
        target = torch.nan_to_num(torch.tensor(target, dtype=torch.float32), nan=0.0)
        return inputs, target

    def __len__(self):
        return len(self.indices)
Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader_kuroshio-checkpoint.ipynb ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 4,
6
+ "id": "f7a16b9b-07cb-46af-b891-d225ca8a8b2c",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stderr",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "/miniconda3/envs/haowu/lib/python3.10/site-packages/torch/cuda/__init__.py:129: UserWarning: CUDA initialization: The NVIDIA driver on your system is too old (found version 11000). Please update your GPU driver by downloading and installing a new version from the URL: http://www.nvidia.com/Download/index.aspx Alternatively, go to: https://pytorch.org to install a PyTorch version that has been compiled with your version of the CUDA driver. (Triggered internally at /pytorch/c10/cuda/CUDAFunctions.cpp:109.)\n",
14
+ " return torch._C._cuda_getDeviceCount() > 0\n"
15
+ ]
16
+ },
17
+ {
18
+ "name": "stdout",
19
+ "output_type": "stream",
20
+ "text": [
21
+ "torch.Size([10, 2, 256, 256])torch.Size([10, 2, 256, 256]) torch.Size([10, 2, 256, 256])torch.Size([10, 2, 256, 256])torch.Size([10, 2, 256, 256])torch.Size([10, 2, 256, 256])torch.Size([10, 2, 256, 256])torch.Size([10, 2, 256, 256])torch.Size([10, 2, 256, 256])\n",
22
+ "\n",
23
+ " torch.Size([10, 2, 256, 256]) \n",
24
+ "torch.Size([10, 2, 256, 256])torch.Size([10, 2, 256, 256])torch.Size([10, 2, 256, 256])torch.Size([10, 2, 256, 256])torch.Size([10, 2, 256, 256])\n",
25
+ "\n",
26
+ "\n",
27
+ "\n",
28
+ " torch.Size([10, 2, 256, 256])\n",
29
+ "torch.Size([10, 2, 256, 256])torch.Size([10, 2, 256, 256]) torch.Size([10, 2, 256, 256])torch.Size([10, 2, 256, 256])\n",
30
+ "\n",
31
+ "torch.Size([10, 2, 256, 256])torch.Size([10, 2, 256, 256])torch.Size([10, 2, 256, 256])torch.Size([10, 2, 256, 256]) \n",
32
+ "torch.Size([10, 2, 256, 256]) torch.Size([10, 2, 256, 256]) torch.Size([10, 2, 256, 256])torch.Size([10, 2, 256, 256])torch.Size([10, 2, 256, 256])\n",
33
+ "\n",
34
+ "\n",
35
+ " torch.Size([10, 2, 256, 256])\n",
36
+ "torch.Size([10, 2, 256, 256]) torch.Size([10, 2, 256, 256])\n",
37
+ "torch.Size([1, 10, 2, 256, 256]) torch.Size([1, 10, 2, 256, 256])\n",
38
+ "输入数据范围: [-1.54, 1.66]\n",
39
+ "NaN值存在性: False\n",
40
+ "Inf值存在性: False\n"
41
+ ]
42
+ }
43
+ ],
44
+ "source": [
45
+ "import torch\n",
46
+ "import torch.distributed as dist\n",
47
+ "from torch.utils.data import Dataset, DataLoader\n",
48
+ "from torch.utils.data.distributed import DistributedSampler\n",
49
+ "import netCDF4 as nc\n",
50
+ "import numpy as np\n",
51
+ "\n",
52
+ "class OceanCurrentDataset(Dataset):\n",
53
+ " def __init__(self, data_path, input_steps=10, output_steps=10, transform=None):\n",
54
+ " \"\"\"\n",
55
+ " 海洋流数据集类\n",
56
+ " :param data_path: NetCDF文件路径\n",
57
+ " :param input_steps: 输入时间步数\n",
58
+ " :param output_steps: 预测时间步数\n",
59
+ " :param transform: 数据增强变换\n",
60
+ " \"\"\"\n",
61
+ " self.data_path = data_path\n",
62
+ " self.input_steps = input_steps\n",
63
+ " self.output_steps = output_steps\n",
64
+ " self.transform = transform\n",
65
+ " self.total_steps = input_steps + output_steps\n",
66
+ " \n",
67
+ " # 加载并预处理数据\n",
68
+ " self.data = self._load_and_process_data()\n",
69
+ " self.mean, self.std = 0, 1\n",
70
+ "\n",
71
+ " def _load_and_process_data(self):\n",
72
+ " \"\"\"加载和处理NetCDF数据\"\"\"\n",
73
+ " with nc.Dataset(self.data_path, 'r') as ds:\n",
74
+ " # 处理缺失值\n",
75
+ " def process_var(var):\n",
76
+ " arr = var[:]\n",
77
+ " if '_FillValue' in var.ncattrs():\n",
78
+ " fill_value = var._FillValue\n",
79
+ " arr = np.ma.masked_values(arr, fill_value).filled(np.nan)\n",
80
+ " return torch.nan_to_num(torch.FloatTensor(arr), nan=0.0)\n",
81
+ "\n",
82
+ " # 加载并合并UV分量\n",
83
+ " ugos = process_var(ds['ugos']) # (time, lat, lon)\n",
84
+ " vgos = process_var(ds['vgos'])\n",
85
+ " \n",
86
+ " # 调整维度顺序 [time, channels, lat, lon]\n",
87
+ " return torch.stack([ugos, vgos], dim=1) \n",
88
+ "\n",
89
+ " def _compute_stats(self):\n",
90
+ " \"\"\"计算训练集的统计量\"\"\"\n",
91
+ " return torch.mean(self.data[:10000]), torch.std(self.data[:10000])\n",
92
+ "\n",
93
+ " def __len__(self):\n",
94
+ " return len(self.data) - self.total_steps + 1\n",
95
+ "\n",
96
+ " def __getitem__(self, idx):\n",
97
+ " window = self.data[idx:idx+self.total_steps] # [T_total, C, H, W]\n",
98
+ " \n",
99
+ " window = (window - self.mean) / self.std\n",
100
+ " \n",
101
+ " # 分割输入输出\n",
102
+ " input_seq = window[:self.input_steps]\n",
103
+ " target_seq = window[self.input_steps:]\n",
104
+ " print(input_seq.shape, target_seq.shape)\n",
105
+ " \n",
106
+ " if self.transform:\n",
107
+ " input_seq = self.transform(input_seq)\n",
108
+ " target_seq = self.transform(target_seq)\n",
109
+ " \n",
110
+ " return input_seq[:,:,::2,::2], target_seq[:,:,::2,::2]\n",
111
+ "\n",
112
+ "def create_dataloaders(config):\n",
113
+ " full_dataset = OceanCurrentDataset(\n",
114
+ " data_path=config['data_path'],\n",
115
+ " input_steps=config['input_steps'],\n",
116
+ " output_steps=config['output_steps']\n",
117
+ " )\n",
118
+ " \n",
119
+ " train_size = 10000 - config['input_steps'] - config['output_steps'] + 1\n",
120
+ " val_size = 500\n",
121
+ " test_size = len(full_dataset) - train_size - val_size\n",
122
+ " \n",
123
+ " train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(\n",
124
+ " full_dataset, [train_size, val_size, test_size],\n",
125
+ " generator=torch.Generator().manual_seed(config['seed'])\n",
126
+ " )\n",
127
+ " \n",
128
+ " # train_sampler = DistributedSampler(train_dataset, shuffle=True)\n",
129
+ " # val_sampler = DistributedSampler(val_dataset, shuffle=False)\n",
130
+ " # test_sampler = DistributedSampler(test_dataset, shuffle=False)\n",
131
+ " \n",
132
+ " dataloader_train = DataLoader(\n",
133
+ " train_dataset,\n",
134
+ " batch_size=config['batch_size'],\n",
135
+ " num_workers=config['num_workers'],\n",
136
+ " pin_memory=True,\n",
137
+ " drop_last=True\n",
138
+ " )\n",
139
+ " \n",
140
+ " dataloader_val = DataLoader(\n",
141
+ " val_dataset,\n",
142
+ " batch_size=config['val_batch_size'],\n",
143
+ " num_workers=config['num_workers'],\n",
144
+ " pin_memory=True,\n",
145
+ " drop_last=True\n",
146
+ " )\n",
147
+ " \n",
148
+ " dataloader_test = DataLoader(\n",
149
+ " test_dataset,\n",
150
+ " batch_size=config['val_batch_size'],\n",
151
+ " num_workers=config['num_workers'],\n",
152
+ " pin_memory=True,\n",
153
+ " drop_last=True\n",
154
+ " )\n",
155
+ " \n",
156
+ " return dataloader_train, dataloader_val, dataloader_test, full_dataset.mean, full_dataset.std\n",
157
+ "\n",
158
+ "config = {\n",
159
+ " 'data_path': '/jizhicfs/easyluwu/ocean_project/kuro/KURO.nc',\n",
160
+ " 'input_steps': 10,\n",
161
+ " 'output_steps': 10,\n",
162
+ " 'batch_size': 1,\n",
163
+ " 'val_batch_size': 1,\n",
164
+ " 'num_workers': 8,\n",
165
+ " 'seed': 42\n",
166
+ "}\n",
167
+ "# dist.init_process_group(backend='nccl')\n",
168
+ "\n",
169
+ "train_loader, val_loader, test_loader, data_mean, data_std = create_dataloaders(config)\n",
170
+ "\n",
171
+ "for sample_input, sample_target in train_loader:\n",
172
+ " print(sample_input.shape, sample_target.shape)\n",
173
+ " print(f\"输入数据范围: [{sample_input.min():.2f}, {sample_input.max():.2f}]\")\n",
174
+ " print(f\"NaN值存在性: {torch.isnan(sample_input).any().item()}\")\n",
175
+ " print(f\"Inf值存在性: {torch.isinf(sample_input).any().item()}\")\n",
176
+ " break"
177
+ ]
178
+ },
179
+ {
180
+ "cell_type": "code",
181
+ "execution_count": null,
182
+ "id": "ad0379fc-13ba-48b3-84ad-5356f0e03968",
183
+ "metadata": {},
184
+ "outputs": [],
185
+ "source": []
186
+ }
187
+ ],
188
+ "metadata": {
189
+ "kernelspec": {
190
+ "display_name": "Python 3 (ipykernel)",
191
+ "language": "python",
192
+ "name": "python3"
193
+ },
194
+ "language_info": {
195
+ "codemirror_mode": {
196
+ "name": "ipython",
197
+ "version": 3
198
+ },
199
+ "file_extension": ".py",
200
+ "mimetype": "text/x-python",
201
+ "name": "python",
202
+ "nbconvert_exporter": "python",
203
+ "pygments_lexer": "ipython3",
204
+ "version": "3.10.16"
205
+ }
206
+ },
207
+ "nbformat": 4,
208
+ "nbformat_minor": 5
209
+ }
Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader_kuroshio-checkpoint.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.distributed as dist
4
+ from torch.utils.data import Dataset, DataLoader
5
+ from torch.utils.data.distributed import DistributedSampler
6
+ import h5py
7
+ import numpy as np
8
+ from torch.utils.data import Dataset
9
+ from torch.utils.data import DataLoader
10
+ import torchvision.transforms as transforms
11
+ import torch.utils.data as data
12
+ import h5py
13
+ import torch
14
+ import numpy as np
15
+ import matplotlib.pyplot as plt
16
+
17
class WeatherDataset(Dataset):
    """Kuroshio current dataset read from an HDF5 file.

    Concatenates the 'uv_g' and 'uv_k' fields along the channel axis and
    serves (input, target) sequence pairs of ``horizon`` frames each.
    """

    def __init__(self, data_path, horizon, transform=None):
        def as_sequence(raw):
            # Move time to axis 1 and add a singleton channel axis;
            # assumes raw is laid out (N, H, W, T) -- TODO confirm against the h5 file.
            tensor = torch.from_numpy(raw).to(torch.float32)
            return tensor.permute(0, 3, 1, 2).unsqueeze_(2)

        with h5py.File(data_path, 'r') as f:
            self.data_uv_g = as_sequence(f['uv_g'][:])
            self.data_uv_k = as_sequence(f['uv_k'][:])

        # Join both sources along the channel axis -> 2 channels per frame.
        self.data_uv_gk = torch.cat([self.data_uv_g, self.data_uv_k], dim=2)
        self.transform = transform
        self.horizon = horizon
        # Identity statistics: the normalisation in __getitem__ is a no-op.
        self.mean = 0
        self.std = 1

    def __len__(self):
        return len(self.data_uv_gk)

    def __getitem__(self, idx):
        sequence = self.data_uv_gk[idx]

        def normalize(frames):
            return (frames - self.mean) / self.std

        return (normalize(sequence[:self.horizon]),
                normalize(sequence[self.horizon:2 * self.horizon]))
42
+
43
def load_data(data_path, batch_size, val_batch_size, horizon, num_workers):
    """Create distributed train/val/test loaders over an 80/10/10 random split.

    Requires an initialised torch.distributed process group (DistributedSampler).
    Returns (train_loader, val_loader, test_loader, mean, std) with fixed
    identity statistics (0, 1).
    """
    dataset = WeatherDataset(data_path=data_path + '/kg_all_20_mask_latmean.h5',
                             horizon=horizon, transform=None)

    n_total = len(dataset)
    n_train = int(n_total * 0.8)
    n_val = int(n_total * 0.1)
    n_test = n_total - n_train - n_val
    train_set, val_set, test_set = data.random_split(dataset, [n_train, n_val, n_test])

    def distributed_loader(subset, size):
        # One loader per split, each sharded across ranks by its own sampler.
        return DataLoader(subset, batch_size=size, sampler=DistributedSampler(subset),
                          pin_memory=True, num_workers=num_workers, drop_last=True)

    dataloader_train = distributed_loader(train_set, batch_size)
    dataloader_validation = distributed_loader(val_set, val_batch_size)
    dataloader_test = distributed_loader(test_set, val_batch_size)

    return dataloader_train, dataloader_validation, dataloader_test, 0, 1
65
+
66
+
67
+
68
+
69
+
Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader_kuroshio_G_uv-checkpoint.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.distributed as dist
4
+ from torch.utils.data import Dataset, DataLoader
5
+ from torch.utils.data.distributed import DistributedSampler
6
+ import h5py
7
+ import numpy as np
8
+ from torch.utils.data import Dataset
9
+ from torch.utils.data import DataLoader
10
+ import torchvision.transforms as transforms
11
+ import torch.utils.data as data
12
+ import h5py
13
+ import torch
14
+ import numpy as np
15
+ import matplotlib.pyplot as plt
16
+
17
class WeatherDataset(Dataset):
    """Geostrophic (u_g, v_g) current dataset read from an HDF5 file.

    NOTE: the attribute names ``data_uv_g`` / ``data_uv_k`` are historical;
    here they hold the 'u_g' and 'v_g' components respectively, concatenated
    along the channel axis.
    """

    def __init__(self, data_path, horizon, transform=None):
        def as_sequence(raw):
            # Move time to axis 1 and add a singleton channel axis;
            # assumes raw is laid out (N, H, W, T) -- TODO confirm against the h5 file.
            tensor = torch.from_numpy(raw).to(torch.float32)
            return tensor.permute(0, 3, 1, 2).unsqueeze_(2)

        with h5py.File(data_path, 'r') as f:
            self.data_uv_g = as_sequence(f['u_g'][:])
            self.data_uv_k = as_sequence(f['v_g'][:])

        # Join both components along the channel axis -> 2 channels per frame.
        self.data_uv_gk = torch.cat([self.data_uv_g, self.data_uv_k], dim=2)
        self.transform = transform
        self.horizon = horizon
        # Identity statistics: the normalisation in __getitem__ is a no-op.
        self.mean = 0
        self.std = 1

    def __len__(self):
        return len(self.data_uv_gk)

    def __getitem__(self, idx):
        sequence = self.data_uv_gk[idx]

        def normalize(frames):
            return (frames - self.mean) / self.std

        return (normalize(sequence[:self.horizon]),
                normalize(sequence[self.horizon:2 * self.horizon]))
42
+
43
def load_data(data_path, batch_size, val_batch_size, horizon, num_workers):
    """Create distributed train/val/test loaders over an 80/10/10 random split.

    Requires an initialised torch.distributed process group (DistributedSampler).
    Returns (train_loader, val_loader, test_loader, mean, std) with fixed
    identity statistics (0, 1).
    """
    dataset = WeatherDataset(data_path=data_path + '/kg_all_20_mask_latmean.h5',
                             horizon=horizon, transform=None)

    n_total = len(dataset)
    n_train = int(n_total * 0.8)
    n_val = int(n_total * 0.1)
    n_test = n_total - n_train - n_val
    train_set, val_set, test_set = data.random_split(dataset, [n_train, n_val, n_test])

    def distributed_loader(subset, size):
        # One loader per split, each sharded across ranks by its own sampler.
        return DataLoader(subset, batch_size=size, sampler=DistributedSampler(subset),
                          pin_memory=False, num_workers=num_workers, drop_last=True)

    dataloader_train = distributed_loader(train_set, batch_size)
    dataloader_validation = distributed_loader(val_set, val_batch_size)
    dataloader_test = distributed_loader(test_set, val_batch_size)

    return dataloader_train, dataloader_validation, dataloader_test, 0, 1
65
+
66
+
67
+
68
+
69
+
Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader_kuroshio_K_uv-checkpoint.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.distributed as dist
4
+ from torch.utils.data import Dataset, DataLoader
5
+ from torch.utils.data.distributed import DistributedSampler
6
+ import h5py
7
+ import numpy as np
8
+ from torch.utils.data import Dataset
9
+ from torch.utils.data import DataLoader
10
+ import torchvision.transforms as transforms
11
+ import torch.utils.data as data
12
+ import h5py
13
+ import torch
14
+ import numpy as np
15
+ import matplotlib.pyplot as plt
16
+
17
class WeatherDataset(Dataset):
    """Kuroshio-component (u_k, v_k) current dataset read from an HDF5 file.

    NOTE: the attribute names ``data_uv_g`` / ``data_uv_k`` are historical;
    here they hold the 'u_k' and 'v_k' components respectively, concatenated
    along the channel axis.
    """

    def __init__(self, data_path, horizon, transform=None):
        def as_sequence(raw):
            # Move time to axis 1 and add a singleton channel axis;
            # assumes raw is laid out (N, H, W, T) -- TODO confirm against the h5 file.
            tensor = torch.from_numpy(raw).to(torch.float32)
            return tensor.permute(0, 3, 1, 2).unsqueeze_(2)

        with h5py.File(data_path, 'r') as f:
            self.data_uv_g = as_sequence(f['u_k'][:])
            self.data_uv_k = as_sequence(f['v_k'][:])

        # Join both components along the channel axis -> 2 channels per frame.
        self.data_uv_gk = torch.cat([self.data_uv_g, self.data_uv_k], dim=2)
        self.transform = transform
        self.horizon = horizon
        # Identity statistics: the normalisation in __getitem__ is a no-op.
        self.mean = 0
        self.std = 1

    def __len__(self):
        return len(self.data_uv_gk)

    def __getitem__(self, idx):
        sequence = self.data_uv_gk[idx]

        def normalize(frames):
            return (frames - self.mean) / self.std

        return (normalize(sequence[:self.horizon]),
                normalize(sequence[self.horizon:2 * self.horizon]))
42
+
43
def load_data(data_path, batch_size, val_batch_size, horizon, num_workers):
    """Create distributed train/val/test loaders over an 80/10/10 random split.

    Requires an initialised torch.distributed process group (DistributedSampler).
    Returns (train_loader, val_loader, test_loader, mean, std) with fixed
    identity statistics (0, 1).
    """
    dataset = WeatherDataset(data_path=data_path + '/kg_all_20_mask_latmean.h5',
                             horizon=horizon, transform=None)

    n_total = len(dataset)
    n_train = int(n_total * 0.8)
    n_val = int(n_total * 0.1)
    n_test = n_total - n_train - n_val
    train_set, val_set, test_set = data.random_split(dataset, [n_train, n_val, n_test])

    def distributed_loader(subset, size):
        # One loader per split, each sharded across ranks by its own sampler.
        return DataLoader(subset, batch_size=size, sampler=DistributedSampler(subset),
                          pin_memory=False, num_workers=num_workers, drop_last=True)

    dataloader_train = distributed_loader(train_set, batch_size)
    dataloader_validation = distributed_loader(val_set, val_batch_size)
    dataloader_test = distributed_loader(test_set, val_batch_size)

    return dataloader_train, dataloader_validation, dataloader_test, 0, 1
65
+
66
+
67
+
68
+
69
+
Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader_kuroshio_ruiqi-checkpoint.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.distributed as dist
3
+ from torch.utils.data import Dataset, DataLoader
4
+ from torch.utils.data.distributed import DistributedSampler
5
+ import netCDF4 as nc
6
+ import numpy as np
7
+
8
class OceanCurrentDataset(Dataset):
    """Sliding-window dataset over ocean surface current fields.

    Loads the geostrophic velocity components (ugos/vgos) from a NetCDF
    file, replaces fill values with zeros, and serves consecutive windows
    of ``input_steps`` input frames followed by ``output_steps`` target
    frames along the time axis.
    """

    def __init__(self, data_path, input_steps=10, output_steps=10, transform=None):
        """
        :param data_path: path to the NetCDF file
        :param input_steps: number of input time steps
        :param output_steps: number of predicted time steps
        :param transform: optional augmentation applied to both sequences
        """
        self.data_path = data_path
        self.input_steps = input_steps
        self.output_steps = output_steps
        self.transform = transform
        self.total_steps = input_steps + output_steps

        # Load and preprocess the full series once, up front.
        self.data = self._load_and_process_data()
        # Identity statistics: windows are served un-normalised.
        self.mean, self.std = 0, 1

    def _load_and_process_data(self):
        """Read ugos/vgos, mask fill values, and stack into [time, 2, lat, lon]."""
        with nc.Dataset(self.data_path, 'r') as ds:
            def clean(variable):
                values = variable[:]
                if '_FillValue' in variable.ncattrs():
                    # Mask the declared fill value, then turn it into NaN ...
                    values = np.ma.masked_values(values, variable._FillValue).filled(np.nan)
                # ... and finally replace NaNs with zeros.
                return torch.nan_to_num(torch.FloatTensor(values), nan=0.0)

            ugos = clean(ds['ugos'])   # (time, lat, lon)
            vgos = clean(ds['vgos'])

        # Channel axis carries the two velocity components.
        return torch.stack([ugos, vgos], dim=1)

    def _compute_stats(self):
        """Mean/std over the first 10000 time steps."""
        head = self.data[:10000]
        return torch.mean(head), torch.std(head)

    def __len__(self):
        # Number of complete windows that fit in the series.
        return len(self.data) - self.total_steps + 1

    def __getitem__(self, idx):
        # [T_total, C, H, W] window, normalised (identity by default).
        window = (self.data[idx:idx + self.total_steps] - self.mean) / self.std

        # Leading frames are the input, trailing frames the target.
        input_seq = window[:self.input_steps]
        target_seq = window[self.input_steps:]

        if self.transform:
            input_seq = self.transform(input_seq)
            target_seq = self.transform(target_seq)

        return input_seq, target_seq
66
+
67
def create_dataloaders(config):
    """Split the dataset chronologically-sized (train/val/test) and wrap each
    subset in a distributed DataLoader.

    Train covers 10000 - input_steps - output_steps + 1 windows, val 500,
    test the remainder; the split itself is random with a fixed seed.
    Requires an initialised torch.distributed process group.
    Returns (train_loader, val_loader, test_loader, mean, std).
    """
    full_dataset = OceanCurrentDataset(
        data_path=config['data_path'],
        input_steps=config['input_steps'],
        output_steps=config['output_steps'],
    )

    train_size = 10000 - config['input_steps'] - config['output_steps'] + 1
    val_size = 500
    sizes = [train_size, val_size, len(full_dataset) - train_size - val_size]
    subsets = torch.utils.data.random_split(
        full_dataset, sizes,
        generator=torch.Generator().manual_seed(config['seed']),
    )

    def build(subset, batch_size, shuffle):
        # Each split gets its own rank-sharding sampler.
        return DataLoader(
            subset,
            batch_size=batch_size,
            sampler=DistributedSampler(subset, shuffle=shuffle),
            num_workers=config['num_workers'],
            pin_memory=True,
            drop_last=True,
        )

    dataloader_train = build(subsets[0], config['batch_size'], True)
    dataloader_val = build(subsets[1], config['val_batch_size'], False)
    dataloader_test = build(subsets[2], config['val_batch_size'], False)

    return dataloader_train, dataloader_val, dataloader_test, full_dataset.mean, full_dataset.std
115
+
116
+ # config = {
117
+ # 'data_path': '/jizhicfs/easyluwu/ocean_project/kuro/KURO.nc',
118
+ # 'input_steps': 10,
119
+ # 'output_steps': 10,
120
+ # 'batch_size': 1,
121
+ # 'val_batch_size': 1,
122
+ # 'num_workers': 8,
123
+ # 'seed': 42
124
+ # }
125
+ # dist.init_process_group(backend='nccl')
126
+
127
+ # train_loader, val_loader, test_loader, data_mean, data_std = create_dataloaders(config)
128
+
129
+ # for sample_input, sample_target in train_loader:
130
+ # print(sample_input.shape, sample_target.shape)
131
+ # print(f"输入数据范围: [{sample_input.min():.2f}, {sample_input.max():.2f}]")
132
+ # print(f"NaN值存在性: {torch.isnan(sample_input).any().item()}")
133
+ # print(f"Inf值存在性: {torch.isinf(sample_input).any().item()}")
134
+ # break
Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader_kuroshio_ruiqi_128-checkpoint.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.distributed as dist
3
+ from torch.utils.data import Dataset, DataLoader
4
+ from torch.utils.data.distributed import DistributedSampler
5
+ import netCDF4 as nc
6
+ import numpy as np
7
+
8
class OceanCurrentDataset(Dataset):
    """Sliding-window dataset of geostrophic ocean-current fields (u/v).

    Each sample is a pair of consecutive windows drawn from a single
    NetCDF time series: ``input_steps`` input frames followed by
    ``output_steps`` frames to predict. Frames are spatially subsampled
    by a factor of 2 along both lat and lon before being returned.

    :param data_path: path to the NetCDF file
    :param input_steps: number of input time steps
    :param output_steps: number of time steps to predict
    :param transform: optional augmentation applied to both sequences
    """

    def __init__(self, data_path, input_steps=10, output_steps=10, transform=None):
        self.data_path = data_path
        self.input_steps = input_steps
        self.output_steps = output_steps
        self.transform = transform
        self.total_steps = input_steps + output_steps

        # Load the full time series into memory once.
        self.data = self._load_and_process_data()
        # Identity normalization: with mean=0 / std=1 the
        # (window - mean) / std step in __getitem__ is a no-op.
        # NOTE(review): _compute_stats below is never called — confirm
        # whether skipping real normalization here is intentional.
        self.mean, self.std = 0, 1

    def _load_and_process_data(self):
        """Read the u/v components from NetCDF and stack them as channels."""
        with nc.Dataset(self.data_path, 'r') as ds:

            def process_var(var):
                # Replace the declared fill value with NaN, then zero-fill,
                # so land/missing cells become 0.0 in the tensor.
                arr = var[:]
                if '_FillValue' in var.ncattrs():
                    arr = np.ma.masked_values(arr, var._FillValue).filled(np.nan)
                return torch.nan_to_num(torch.FloatTensor(arr), nan=0.0)

            ugos = process_var(ds['ugos'])  # (time, lat, lon)
            vgos = process_var(ds['vgos'])

        # -> [time, channels=2, lat, lon]
        return torch.stack([ugos, vgos], dim=1)

    def _compute_stats(self):
        """Mean/std over the first 10000 frames (the training portion)."""
        train_part = self.data[:10000]
        return torch.mean(train_part), torch.std(train_part)

    def __len__(self):
        # Number of valid window start positions.
        return len(self.data) - self.total_steps + 1

    def __getitem__(self, idx):
        window = self.data[idx:idx + self.total_steps]  # [T_total, C, H, W]

        window = (window - self.mean) / self.std

        # Split the window into the conditioning and prediction parts.
        input_seq = window[:self.input_steps]
        target_seq = window[self.input_steps:]

        if self.transform:
            input_seq = self.transform(input_seq)
            target_seq = self.transform(target_seq)

        # 2x spatial subsampling on both lat and lon.
        return input_seq[:, :, ::2, ::2], target_seq[:, :, ::2, ::2]
+
67
def create_dataloaders(config):
    """Build distributed train/val/test DataLoaders from one NetCDF file.

    Expected ``config`` keys: data_path, input_steps, output_steps,
    batch_size, val_batch_size, num_workers, seed.

    Returns ``(train_loader, val_loader, test_loader, mean, std)`` where
    mean/std are the dataset's (identity) normalization statistics.

    NOTE(review): torch.distributed must be initialized before this is
    called — DistributedSampler reads the process group.
    NOTE(review): random_split shuffles overlapping sliding windows across
    the splits, so val/test windows share frames with training windows;
    confirm this leakage is acceptable.
    """
    full_dataset = OceanCurrentDataset(
        data_path=config['data_path'],
        input_steps=config['input_steps'],
        output_steps=config['output_steps'],
    )

    # Windows whose frames lie in the first 10000 steps go to training,
    # 500 windows to validation, and the remainder to test.
    n_train = 10000 - config['input_steps'] - config['output_steps'] + 1
    n_val = 500
    n_test = len(full_dataset) - n_train - n_val

    train_set, val_set, test_set = torch.utils.data.random_split(
        full_dataset,
        [n_train, n_val, n_test],
        generator=torch.Generator().manual_seed(config['seed']),
    )

    def make_loader(dataset, batch_size, shuffle):
        # One loader per split, each sharded across ranks by its sampler.
        sampler = DistributedSampler(dataset, shuffle=shuffle)
        return DataLoader(
            dataset,
            batch_size=batch_size,
            sampler=sampler,
            num_workers=config['num_workers'],
            pin_memory=True,
            drop_last=True,
        )

    dataloader_train = make_loader(train_set, config['batch_size'], True)
    dataloader_val = make_loader(val_set, config['val_batch_size'], False)
    dataloader_test = make_loader(test_set, config['val_batch_size'], False)

    return dataloader_train, dataloader_val, dataloader_test, full_dataset.mean, full_dataset.std
115
+
116
+ # config = {
117
+ # 'data_path': '/jizhicfs/easyluwu/ocean_project/kuro/KURO.nc',
118
+ # 'input_steps': 10,
119
+ # 'output_steps': 10,
120
+ # 'batch_size': 1,
121
+ # 'val_batch_size': 1,
122
+ # 'num_workers': 8,
123
+ # 'seed': 42
124
+ # }
125
+ # dist.init_process_group(backend='nccl')
126
+
127
+ # train_loader, val_loader, test_loader, data_mean, data_std = create_dataloaders(config)
128
+
129
+ # for sample_input, sample_target in train_loader:
130
+ # print(sample_input.shape, sample_target.shape)
131
+ # print(f"输入数据范围: [{sample_input.min():.2f}, {sample_input.max():.2f}]")
132
+ # print(f"NaN值存在性: {torch.isnan(sample_input).any().item()}")
133
+ # print(f"Inf值存在性: {torch.isinf(sample_input).any().item()}")
134
+ # break
Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader_kuroshio_ruiqi_64-checkpoint.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.distributed as dist
3
+ from torch.utils.data import Dataset, DataLoader
4
+ from torch.utils.data.distributed import DistributedSampler
5
+ import netCDF4 as nc
6
+ import numpy as np
7
+
8
class OceanCurrentDataset(Dataset):
    """Sliding-window dataset of geostrophic ocean-current fields (u/v).

    Each sample is a pair of consecutive windows drawn from a single
    NetCDF time series: ``input_steps`` input frames followed by
    ``output_steps`` frames to predict. Frames are spatially subsampled
    by a factor of 4 along both lat and lon before being returned.

    :param data_path: path to the NetCDF file
    :param input_steps: number of input time steps
    :param output_steps: number of time steps to predict
    :param transform: optional augmentation applied to both sequences
    """

    def __init__(self, data_path, input_steps=10, output_steps=10, transform=None):
        self.data_path = data_path
        self.input_steps = input_steps
        self.output_steps = output_steps
        self.transform = transform
        self.total_steps = input_steps + output_steps

        # Load the full time series into memory once.
        self.data = self._load_and_process_data()
        # Identity normalization: with mean=0 / std=1 the
        # (window - mean) / std step in __getitem__ is a no-op.
        # NOTE(review): _compute_stats below is never called — confirm
        # whether skipping real normalization here is intentional.
        self.mean, self.std = 0, 1

    def _load_and_process_data(self):
        """Read the u/v components from NetCDF and stack them as channels."""
        with nc.Dataset(self.data_path, 'r') as ds:

            def process_var(var):
                # Replace the declared fill value with NaN, then zero-fill,
                # so land/missing cells become 0.0 in the tensor.
                arr = var[:]
                if '_FillValue' in var.ncattrs():
                    arr = np.ma.masked_values(arr, var._FillValue).filled(np.nan)
                return torch.nan_to_num(torch.FloatTensor(arr), nan=0.0)

            ugos = process_var(ds['ugos'])  # (time, lat, lon)
            vgos = process_var(ds['vgos'])

        # -> [time, channels=2, lat, lon]
        return torch.stack([ugos, vgos], dim=1)

    def _compute_stats(self):
        """Mean/std over the first 10000 frames (the training portion)."""
        train_part = self.data[:10000]
        return torch.mean(train_part), torch.std(train_part)

    def __len__(self):
        # Number of valid window start positions.
        return len(self.data) - self.total_steps + 1

    def __getitem__(self, idx):
        window = self.data[idx:idx + self.total_steps]  # [T_total, C, H, W]

        window = (window - self.mean) / self.std

        # Split the window into the conditioning and prediction parts.
        input_seq = window[:self.input_steps]
        target_seq = window[self.input_steps:]

        if self.transform:
            input_seq = self.transform(input_seq)
            target_seq = self.transform(target_seq)

        # 4x spatial subsampling on both lat and lon.
        return input_seq[:, :, ::4, ::4], target_seq[:, :, ::4, ::4]
+
67
def create_dataloaders(config):
    """Build distributed train/val/test DataLoaders from one NetCDF file.

    Expected ``config`` keys: data_path, input_steps, output_steps,
    batch_size, val_batch_size, num_workers, seed.

    Returns ``(train_loader, val_loader, test_loader, mean, std)`` where
    mean/std are the dataset's (identity) normalization statistics.

    NOTE(review): torch.distributed must be initialized before this is
    called — DistributedSampler reads the process group.
    NOTE(review): random_split shuffles overlapping sliding windows across
    the splits, so val/test windows share frames with training windows;
    confirm this leakage is acceptable.
    """
    full_dataset = OceanCurrentDataset(
        data_path=config['data_path'],
        input_steps=config['input_steps'],
        output_steps=config['output_steps'],
    )

    # Windows whose frames lie in the first 10000 steps go to training,
    # 500 windows to validation, and the remainder to test.
    n_train = 10000 - config['input_steps'] - config['output_steps'] + 1
    n_val = 500
    n_test = len(full_dataset) - n_train - n_val

    train_set, val_set, test_set = torch.utils.data.random_split(
        full_dataset,
        [n_train, n_val, n_test],
        generator=torch.Generator().manual_seed(config['seed']),
    )

    def make_loader(dataset, batch_size, shuffle):
        # One loader per split, each sharded across ranks by its sampler.
        sampler = DistributedSampler(dataset, shuffle=shuffle)
        return DataLoader(
            dataset,
            batch_size=batch_size,
            sampler=sampler,
            num_workers=config['num_workers'],
            pin_memory=True,
            drop_last=True,
        )

    dataloader_train = make_loader(train_set, config['batch_size'], True)
    dataloader_val = make_loader(val_set, config['val_batch_size'], False)
    dataloader_test = make_loader(test_set, config['val_batch_size'], False)

    return dataloader_train, dataloader_val, dataloader_test, full_dataset.mean, full_dataset.std
115
+
116
+ # config = {
117
+ # 'data_path': '/jizhicfs/easyluwu/ocean_project/kuro/KURO.nc',
118
+ # 'input_steps': 10,
119
+ # 'output_steps': 10,
120
+ # 'batch_size': 1,
121
+ # 'val_batch_size': 1,
122
+ # 'num_workers': 8,
123
+ # 'seed': 42
124
+ # }
125
+ # dist.init_process_group(backend='nccl')
126
+
127
+ # train_loader, val_loader, test_loader, data_mean, data_std = create_dataloaders(config)
128
+
129
+ # for sample_input, sample_target in train_loader:
130
+ # print(sample_input.shape, sample_target.shape)
131
+ # print(f"输入数据范围: [{sample_input.min():.2f}, {sample_input.max():.2f}]")
132
+ # print(f"NaN值存在性: {torch.isnan(sample_input).any().item()}")
133
+ # print(f"Inf值存在性: {torch.isinf(sample_input).any().item()}")
134
+ # break
Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader_kuroshio_ruiqi_single-checkpoint.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import Dataset, DataLoader
3
+ import netCDF4 as nc
4
+ import numpy as np
5
+
6
class OceanCurrentDataset(Dataset):
    """Sliding-window dataset of geostrophic ocean-current fields (u/v).

    Single-process variant: windows are returned at full spatial
    resolution. Each sample is ``input_steps`` input frames followed by
    ``output_steps`` frames to predict, cut from one NetCDF time series.

    :param data_path: path to the NetCDF file
    :param input_steps: number of input time steps
    :param output_steps: number of time steps to predict
    :param transform: optional augmentation applied to both sequences
    """

    def __init__(self, data_path, input_steps=10, output_steps=10, transform=None):
        self.data_path = data_path
        self.input_steps = input_steps
        self.output_steps = output_steps
        self.transform = transform
        self.total_steps = input_steps + output_steps

        # Load the full time series into memory once.
        self.data = self._load_and_process_data()
        # Stats are computed here but NOT applied in __getitem__, which
        # normalizes with the literal constants 0 and 1 instead.
        # NOTE(review): confirm whether mean/std were meant to be used.
        self.mean, self.std = self._compute_stats()

    def _load_and_process_data(self):
        """Read the u/v components from NetCDF and stack them as channels."""
        with nc.Dataset(self.data_path, 'r') as ds:

            def process_var(var):
                # Replace the declared fill value with NaN, then zero-fill,
                # so land/missing cells become 0.0 in the tensor.
                arr = var[:]
                if '_FillValue' in var.ncattrs():
                    arr = np.ma.masked_values(arr, var._FillValue).filled(np.nan)
                return torch.nan_to_num(torch.FloatTensor(arr), nan=0.0)

            ugos = process_var(ds['ugos'])  # (time, lat, lon)
            vgos = process_var(ds['vgos'])

        # -> [time, channels=2, lat, lon]
        return torch.stack([ugos, vgos], dim=1)

    def _compute_stats(self):
        """Mean/std over the first 10000 frames (the training portion)."""
        train_part = self.data[:10000]
        return torch.mean(train_part), torch.std(train_part)

    def __len__(self):
        # Number of valid window start positions.
        return len(self.data) - self.total_steps + 1

    def __getitem__(self, idx):
        window = self.data[idx:idx + self.total_steps]  # [T_total, C, H, W]
        # Identity normalization kept as in the original (no-op copy).
        window = (window - 0) / 1

        # Split the window into the conditioning and prediction parts.
        input_seq = window[:self.input_steps]
        target_seq = window[self.input_steps:]

        if self.transform:
            input_seq = self.transform(input_seq)
            target_seq = self.transform(target_seq)

        return input_seq, target_seq
+
53
def create_dataloaders(config):
    """Build single-process train/val/test DataLoaders from one NetCDF file.

    Expected ``config`` keys: data_path, input_steps, output_steps,
    batch_size, val_batch_size, num_workers, seed.

    Returns ``(train_loader, val_loader, test_loader, 0, 1)`` — the
    normalization statistics are hard-coded to the identity, matching the
    no-op normalization inside OceanCurrentDataset.__getitem__.

    NOTE(review): random_split shuffles overlapping sliding windows across
    the splits, so val/test windows share frames with training windows;
    confirm this leakage is acceptable.
    """
    full_dataset = OceanCurrentDataset(
        data_path=config['data_path'],
        input_steps=config['input_steps'],
        output_steps=config['output_steps'],
    )

    # Windows whose frames lie in the first 10000 steps go to training,
    # 500 windows to validation, and the remainder to test.
    n_train = 10000 - config['input_steps'] - config['output_steps'] + 1
    n_val = 500
    n_test = len(full_dataset) - n_train - n_val

    train_set, val_set, test_set = torch.utils.data.random_split(
        full_dataset,
        [n_train, n_val, n_test],
        generator=torch.Generator().manual_seed(config['seed']),
    )

    # Shared loader options; only the training loader shuffles.
    common = dict(
        num_workers=config['num_workers'],
        pin_memory=True,
        drop_last=True,
    )

    dataloader_train = DataLoader(
        train_set,
        batch_size=config['batch_size'],
        shuffle=True,  # plain shuffling — no DistributedSampler here
        **common,
    )
    dataloader_val = DataLoader(
        val_set,
        batch_size=config['val_batch_size'],
        shuffle=False,
        **common,
    )
    dataloader_test = DataLoader(
        test_set,
        batch_size=config['val_batch_size'],
        shuffle=False,
        **common,
    )

    return dataloader_train, dataloader_val, dataloader_test, 0, 1
Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/dataloader_test-checkpoint.ipynb ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [],
3
+ "metadata": {},
4
+ "nbformat": 4,
5
+ "nbformat_minor": 5
6
+ }
Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/input_output_animation-checkpoint.gif ADDED

Git LFS Details

  • SHA256: ad3a31851d266a90b2badc436562f2151145fd07c6b24c04c47bf8276af26bbf
  • Pointer size: 131 Bytes
  • Size of remote file: 171 kB
Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/kuro_vis-checkpoint.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
Exp3_Kuroshio_forecasting/dataloader_api/.ipynb_checkpoints/kuroshio_animation-checkpoint.gif ADDED

Git LFS Details

  • SHA256: 76f4227a35cb70c8ed629c37435beb5f9d437c935fe15ffb70f4c640d34d1675
  • Pointer size: 132 Bytes
  • Size of remote file: 2.28 MB