sqfoo committed on
Commit
6021dd1
·
verified ·
1 Parent(s): ecc0bbb

Upload 99 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. .ipynb_checkpoints/demo-checkpoint.ipynb +37 -0
  3. LICENSE +21 -0
  4. README.md +155 -3
  5. __pycache__/utilspp.cpython-38.pyc +0 -0
  6. __pycache__/utilspp.cpython-39.pyc +0 -0
  7. ckpts/hko-7/.DS_Store +0 -0
  8. ckpts/hko-7/stldm/n2n_timesteps-50_sampling_timesteps-20_objective-pred_v_final.pt +3 -0
  9. data/.DS_Store +0 -0
  10. data/HKO-7/samplers/hko7_cloudy_days_t20_test.txt +199 -0
  11. data/HKO-7/samplers/hko7_cloudy_days_t20_test.txt.pkl +3 -0
  12. data/HKO-7/samplers/hko7_cloudy_days_t20_train.txt +1224 -0
  13. data/HKO-7/samplers/hko7_cloudy_days_t20_train.txt.pkl +3 -0
  14. data/SEVIR/CATALOG.csv +3 -0
  15. data/__pycache__/config.cpython-39.pyc +0 -0
  16. data/__pycache__/dutils.cpython-38.pyc +0 -0
  17. data/__pycache__/dutils.cpython-39.pyc +0 -0
  18. data/__pycache__/loader.cpython-38.pyc +0 -0
  19. data/__pycache__/loader.cpython-39.pyc +0 -0
  20. data/config.py +62 -0
  21. data/dutils.py +1212 -0
  22. data/loader.py +49 -0
  23. data/sample_data.npy +3 -0
  24. demo.ipynb +0 -0
  25. ens_eval.py +164 -0
  26. ens_gen.py +124 -0
  27. nowcasting/__init__.py +0 -0
  28. nowcasting/__pycache__/__init__.cpython-310.pyc +0 -0
  29. nowcasting/__pycache__/__init__.cpython-38.pyc +0 -0
  30. nowcasting/__pycache__/__init__.cpython-39.pyc +0 -0
  31. nowcasting/__pycache__/config.cpython-310.pyc +0 -0
  32. nowcasting/__pycache__/config.cpython-38.pyc +0 -0
  33. nowcasting/__pycache__/config.cpython-39.pyc +0 -0
  34. nowcasting/__pycache__/hko_iterator.cpython-310.pyc +0 -0
  35. nowcasting/__pycache__/hko_iterator.cpython-38.pyc +0 -0
  36. nowcasting/__pycache__/hko_iterator.cpython-39.pyc +0 -0
  37. nowcasting/__pycache__/image.cpython-310.pyc +0 -0
  38. nowcasting/__pycache__/image.cpython-38.pyc +0 -0
  39. nowcasting/__pycache__/image.cpython-39.pyc +0 -0
  40. nowcasting/__pycache__/mask.cpython-310.pyc +0 -0
  41. nowcasting/__pycache__/mask.cpython-38.pyc +0 -0
  42. nowcasting/__pycache__/mask.cpython-39.pyc +0 -0
  43. nowcasting/__pycache__/utils.cpython-310.pyc +0 -0
  44. nowcasting/__pycache__/utils.cpython-38.pyc +0 -0
  45. nowcasting/__pycache__/utils.cpython-39.pyc +0 -0
  46. nowcasting/config.py +302 -0
  47. nowcasting/encoder_forecaster.py +556 -0
  48. nowcasting/helpers/__init__.py +0 -0
  49. nowcasting/helpers/__pycache__/__init__.cpython-310.pyc +0 -0
  50. nowcasting/helpers/__pycache__/__init__.cpython-38.pyc +0 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ data/SEVIR/CATALOG.csv filter=lfs diff=lfs merge=lfs -text
.ipynb_checkpoints/demo-checkpoint.ipynb ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "6e716a85-c1aa-4b14-90e5-181840d54cc6",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import torch\n",
11
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
12
+ "import numpy as np\n"
13
+ ]
14
+ }
15
+ ],
16
+ "metadata": {
17
+ "kernelspec": {
18
+ "display_name": "Python 3 (ipykernel)",
19
+ "language": "python",
20
+ "name": "python3"
21
+ },
22
+ "language_info": {
23
+ "codemirror_mode": {
24
+ "name": "ipython",
25
+ "version": 3
26
+ },
27
+ "file_extension": ".py",
28
+ "mimetype": "text/x-python",
29
+ "name": "python",
30
+ "nbconvert_exporter": "python",
31
+ "pygments_lexer": "ipython3",
32
+ "version": "3.8.18"
33
+ }
34
+ },
35
+ "nbformat": 4,
36
+ "nbformat_minor": 5
37
+ }
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2025 sqfoo
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,3 +1,155 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # STLDM: Spatio-Temporal Latent Diffusion Model for Precipitation Nowcasting
2
+
3
+ This is the official code implementation of the paper [**STLDM: Spatio-Temporal Latent Diffusion Model for Precipitation Nowcasting**](https://openreview.net/forum?id=f4oJwXn3qg) submitted to TMLR.
4
+
5
+ ## Setup Environment
6
+
7
+ Create a new conda environment:
8
+
9
+ ```bash
10
+ conda create -n stldm python=3.9
11
+ conda activate stldm
12
+ ```
13
+
14
+ Install related packages:
15
+
16
+ ```bash
17
+ pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 -f https://download.pytorch.org/whl/torch_stable.html
18
+ pip install -r requirements.txt
19
+ ```
20
+
21
+ ## Data Preparation
22
+
23
+ There are three radar reflectivity datasets being evaluated with **STLDM** and other baselines: SEVIR, HKO-7, and MeteoNet.
24
+
25
+ ### SEVIR
26
+
27
+ For the SEVIR dataset, please refer to [https://github.com/amazon-science/earth-forecasting-transformer](https://github.com/amazon-science/earth-forecasting-transformer) for downloading the SEVIR dataset. Please make sure the downloaded files are stored in the following way:
28
+
29
+ ```
30
+ data/
31
+ ├─ SEVIR/
32
+ │ ├─ data/
33
+ │ │ └─ vil/
34
+ | | ├─ 2017/
35
+ | | ├─ 2018/
36
+ | | └─ 2019/
37
+ │ └─ CATALOG.csv
38
+ ├─ ...
39
+ ```
40
+
41
+ ### HKO-7
42
+
43
+ For the HKO-7 dataset, please refer to [https://github.com/sxjscience/HKO-7](https://github.com/sxjscience/HKO-7) for downloading the HKO-7 dataset. Please make sure the downloaded files are stored in the following way:
44
+
45
+ ```
46
+ data/
47
+ ├─ HKO-7/
48
+ │ ├─ hko_data/
49
+ │ │ └─ mask_dat.npz
50
+ │ ├─ radarPNG/
51
+ │ │ ├─ 2009/
52
+ │ │ └─ ...
53
+ │ ├─ radarPNG_mask/
54
+ │ │ ├─ 2009/
55
+ │ │ └─ ...
56
+ │ └─ samplers/
57
+ ├─ ...
58
+ ```
59
+
60
+ The samplers files have been saved inside ```data/HKO-7/samplers/``` for you.
61
+
62
+ ### MeteoNet
63
+
64
+ For the MeteoNet dataset, please refer to [https://github.com/DeminYu98/DiffCast](https://github.com/DeminYu98/DiffCast) to download the pre-processed h5 file, or you can follow the provided instruction to pre-process the raw dataset found on [the official MeteoNet website](https://meteofrance.github.io/meteonet/english/data/rain-radar/). Please make sure the file is named as ```meteo.h5``` and stored in the following way:
65
+
66
+ ```
67
+ data/
68
+ ├─ meteonet/
69
+ │ ├─ meteo.h5
70
+ ├─ ...
71
+ ```
72
+
73
+ ### FYI 💡
74
+
75
+ If you really want to change the suggested way to save those data files above, remember to update the corresponding file directories as well in the following ways:
76
+
77
+ - The SEVIR dataset: ```SEVIR_ROOT_DIR``` in ```data/dutils.py```
78
+ - The HKO-7 dataset: ```__C.ROOT_DIR```, ```possible_hko_png_paths``` and ```possible_hko_mask_paths``` in ```nowcasting/config.py```
79
+ - The MeteoNet dataset: ```METEO_FILE_DIR``` in ```data/dutils.py```
80
+
81
+ ## Training
82
+
83
+ You can train the **STLDM** with the script ```train.py``` with the following command:
84
+
85
+ ``` bash
86
+ python train.py -d HKO7_5_20 --seq_len 5 --out_len 20 -m STLDM_HKO --type "3D"
87
+ ```
88
+
89
+ In particular, there are a few arguments to set:
90
+
91
+ - ```-d``` / ```--dataset``` : The dataset config found in ```data/config.py```. Please set the corresponding ```--seq_len``` (input sequence length) and ```--out_len``` (output sequence length) as well.
92
+ - ```-m``` / ```--model``` : The STLDM config found in ```stldm/__init__.py```
93
+ - ```--type```: specifies whether to use ```"3D"``` (**Spatiotemporal**) or ```"2S"``` (**Spatial**) Visual Enhancement.
94
+
95
+ ## Evaluation and Sampling
96
+
97
+ ### Evaluation
98
+
99
+ For the evaluation of **STLDM**, we generate ten ensemble predictions of **STLDM** and evaluate them.
100
+
101
+ First, let's run the ensemble generation script, ```ens_gen.py``` to generate the ensemble prediction and save it as an npy file, with the following command:
102
+
103
+ ```bash
104
+ python ens_gen.py -d HKO7_5_20 -m STLDM_HKO --type "3D" -f "model_checkpoint" --c_str 1.0 --e_id 0
105
+ ```
106
+
107
+ Other than the arguments above, there are still a few parameters to set:
108
+
109
+ - ```-f```: the relative/absolute path to **STLDM** checkpoint
110
+ - ```--c_str```: Classifier-Free Guidance strength, it is disabled when set to 0.0
111
+ - ```--e_id```: Represents the $e\_id$-th ensemble prediction, starting from 0
112
+
113
+ Then, we run the evaluation script, ```ens_eval.py``` to evaluate the generated ensemble predictions (labeled starting from 0) with the following command:
114
+
115
+ ```bash
116
+ python ens_eval.py -d HKO7_5_20 --out_len 20 --e_file "filepath_{}.npy" --ens_no 10
117
+ ```
118
+
119
+ Again, other than the arguments specified above, there are still a few parameters to set:
120
+
121
+ - ```--e_file```: The format of the ensemble predictions, replace the $e\_id$ by $\{ \}$. Make sure the predictions are labelled, starting from 0.
122
+ - ```--ens_no```: Total number of ensemble predictions, i.e., 10
123
+
124
+ ### Sampling
125
+
126
+ Other than the evaluation process, we also provide a demo file, ```demo.ipynb```, to show you how to set up and call the **STLDM** to generate samples for your own implementation. In this demo, we include three different configurations:
127
+
128
+ - **SpatioTemporal** Visual Enhancement with image size of *128*
129
+ - **Spatial** Visual Enhancement with image size of *128*
130
+ - **SpatioTemporal** Visual Enhancement with image size of *256*
131
+
132
+ You can download their corresponding model checkpoints from [this link](https://hkustconnect-my.sharepoint.com/:f:/g/personal/sqfoo_connect_ust_hk/IgATefXlByydRaKlqYnC3hIyAUNk5ftNZBXJz0yKa7d89yE?e=BLk0V3) ([Alternative link](https://drive.google.com/drive/folders/1bCQBt5JPQ-JzHSy8ruYhj6p32Q5uEECM?usp=sharing)).
133
+
134
+ ## Credits and Acknowledgment
135
+
136
+ We would like to thank these developers and credit their code.
137
+
138
+ - [FACL](https://github.com/argenycw/FACL)
139
+ - [OpenSTL](https://github.com/chengtan9907/OpenSTL/blob/OpenSTL-Lightning/README.md)
140
+ - [DiffCast](https://github.com/DeminYu98/DiffCast)
141
+ - [denoising_diffusion_pytorch](https://github.com/lucidrains/denoising-diffusion-pytorch/tree/main/denoising_diffusion_pytorch)
142
+
143
+ ## Citation
144
+
145
+ If you find this work helpful, please cite the following:
146
+
147
+ ```bib
148
+ @article{foo2025stldm,
149
+ author = {Foo, Shi Quan and Wong, Chi-Ho and Gao, Zhihan and Yeung, Dit-Yan and Wong, Ka-Hing and Wong, Wai-Kin},
150
+ title = {STLDM: Spatio-Temporal Latent Diffusion Model for Precipitation Nowcasting},
151
+ journal = {Transactions on Machine Learning Research},
152
+ year = {2025},
153
+ url = {https://openreview.net/forum?id=f4oJwXn3qg},
154
+ }
155
+ ```
__pycache__/utilspp.cpython-38.pyc ADDED
Binary file (12.7 kB). View file
 
__pycache__/utilspp.cpython-39.pyc ADDED
Binary file (12.8 kB). View file
 
ckpts/hko-7/.DS_Store ADDED
Binary file (6.15 kB). View file
 
ckpts/hko-7/stldm/n2n_timesteps-50_sampling_timesteps-20_objective-pred_v_final.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b258368931d2bf66843e3e0d763e6437e68ce50152b05ccc9a12cf3b0ecac6ff
3
+ size 324258726
data/.DS_Store ADDED
Binary file (8.2 kB). View file
 
data/HKO-7/samplers/hko7_cloudy_days_t20_test.txt ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 20150110, 43.75619097948626
2
+ 20150111, 149.95275616718968
3
+ 20150112, 374.587865563982
4
+ 20150113, 216.4265648608205
5
+ 20150215, 23.97378254300325
6
+ 20150221, 47.148593604108314
7
+ 20150222, 109.70054189911669
8
+ 20150223, 61.28945820984426
9
+ 20150312, 98.79946209611806
10
+ 20150319, 31.994431187529052
11
+ 20150320, 60.65151818921431
12
+ 20150323, 28.129440031220827
13
+ 20150325, 22.872230575894932
14
+ 20150331, 21.820746346176193
15
+ 20150403, 34.57228669078336
16
+ 20150404, 24.82365106345886
17
+ 20150405, 31.665026804393307
18
+ 20150407, 33.26660849235528
19
+ 20150408, 94.84768094781496
20
+ 20150409, 22.203001174428096
21
+ 20150410, 82.69853265922829
22
+ 20150411, 353.58096869915147
23
+ 20150412, 31.1800565144119
24
+ 20150417, 44.918818645397494
25
+ 20150418, 24.080770753428627
26
+ 20150419, 64.81862887174572
27
+ 20150420, 134.63497737244305
28
+ 20150421, 21.996475658124133
29
+ 20150423, 32.6925866602743
30
+ 20150428, 23.7152796014239
31
+ 20150429, 54.57041110965831
32
+ 20150430, 44.70511117648767
33
+ 20150501, 46.1183677286727
34
+ 20150502, 44.90144663819153
35
+ 20150503, 49.670905610762425
36
+ 20150504, 67.67321630927475
37
+ 20150505, 67.95217779820261
38
+ 20150506, 92.13506781003022
39
+ 20150507, 66.49604560574231
40
+ 20150508, 41.54096078132263
41
+ 20150509, 141.6098750437675
42
+ 20150510, 118.69222291230821
43
+ 20150511, 123.33292957808122
44
+ 20150512, 22.11332483437936
45
+ 20150515, 31.402990251627152
46
+ 20150516, 190.73493669368895
47
+ 20150517, 123.26633275706115
48
+ 20150518, 55.93378915330079
49
+ 20150519, 75.29142622617388
50
+ 20150520, 286.1066578771502
51
+ 20150521, 327.0232328350187
52
+ 20150522, 197.9260324049861
53
+ 20150523, 492.82381178085785
54
+ 20150524, 266.1958120859484
55
+ 20150525, 172.35522310486695
56
+ 20150526, 170.2591961224431
57
+ 20150527, 55.43129521879359
58
+ 20150530, 87.61113998576245
59
+ 20150531, 101.77555679044627
60
+ 20150601, 90.33209807938168
61
+ 20150602, 38.3629325688609
62
+ 20150603, 26.60613086936309
63
+ 20150604, 26.28305530131335
64
+ 20150605, 76.74834212330765
65
+ 20150606, 61.74615059129474
66
+ 20150607, 32.53982813226406
67
+ 20150609, 41.16127709495584
68
+ 20150610, 41.9677021588796
69
+ 20150611, 51.02534211601307
70
+ 20150612, 47.988766671025104
71
+ 20150613, 38.858003544862854
72
+ 20150614, 26.969955794817928
73
+ 20150615, 22.181291719049337
74
+ 20150621, 53.58603615925537
75
+ 20150622, 98.10795233321713
76
+ 20150623, 178.18817356026273
77
+ 20150624, 129.40208224372384
78
+ 20150625, 38.546225592747554
79
+ 20150702, 28.222513147954437
80
+ 20150703, 21.610705049976755
81
+ 20150704, 26.45230561366806
82
+ 20150705, 55.110333856345875
83
+ 20150706, 25.856551631069088
84
+ 20150709, 235.03884457428808
85
+ 20150710, 102.79465580503033
86
+ 20150714, 24.60319509821014
87
+ 20150715, 46.689866363211955
88
+ 20150716, 63.028152240237105
89
+ 20150717, 108.734602365179
90
+ 20150718, 73.07054603963272
91
+ 20150719, 70.11114648984594
92
+ 20150720, 287.5367683272315
93
+ 20150721, 322.20042222367505
94
+ 20150722, 186.14272948762843
95
+ 20150723, 118.66465379474661
96
+ 20150724, 127.66547082025804
97
+ 20150725, 48.57482203045095
98
+ 20150726, 44.40593104429272
99
+ 20150727, 32.61881900860065
100
+ 20150728, 32.448406627731295
101
+ 20150729, 32.201990898128784
102
+ 20150730, 30.077929415097636
103
+ 20150809, 59.961966454556034
104
+ 20150810, 126.86668727386788
105
+ 20150811, 76.57628991602742
106
+ 20150812, 96.5194674223856
107
+ 20150813, 252.00741877480158
108
+ 20150814, 179.07121160942586
109
+ 20150815, 152.54204565793245
110
+ 20150816, 46.21115141369047
111
+ 20150817, 36.35543464781746
112
+ 20150820, 48.975007445664815
113
+ 20150825, 29.65178480246265
114
+ 20150826, 102.09460952028128
115
+ 20150827, 120.86394627498836
116
+ 20150828, 104.83123928550674
117
+ 20150829, 276.51425990091815
118
+ 20150830, 289.6720114087301
119
+ 20150831, 225.63438807531384
120
+ 20150901, 231.3144142622618
121
+ 20150902, 319.48056572524405
122
+ 20150903, 135.36383892098507
123
+ 20150904, 21.66859854428173
124
+ 20150905, 24.239027087691777
125
+ 20150906, 23.229757345646593
126
+ 20150907, 159.19597334815205
127
+ 20150908, 91.84391777544349
128
+ 20150909, 48.96262528448881
129
+ 20150910, 43.52720809361923
130
+ 20150911, 41.92730688487913
131
+ 20150912, 43.19016518479776
132
+ 20150913, 34.703902617968396
133
+ 20150914, 29.62716214841933
134
+ 20150915, 51.113737251569034
135
+ 20150916, 100.1013191538819
136
+ 20150917, 51.71855559932305
137
+ 20150918, 59.26835556136683
138
+ 20150919, 40.77983042044398
139
+ 20150920, 78.79494620961182
140
+ 20150921, 161.67126445548584
141
+ 20150922, 49.63578440986749
142
+ 20150923, 36.10682937237395
143
+ 20150924, 50.429000532504666
144
+ 20150925, 59.50495138159431
145
+ 20150926, 59.904413770716616
146
+ 20150927, 58.42647587604603
147
+ 20150928, 78.84617601551604
148
+ 20150929, 41.22325771443514
149
+ 20150930, 44.00618280741514
150
+ 20151001, 50.48162457691409
151
+ 20151002, 25.243249226774044
152
+ 20151003, 184.0588648811599
153
+ 20151004, 518.6718274203859
154
+ 20151005, 258.02547016649237
155
+ 20151006, 107.7635876118666
156
+ 20151007, 84.55949267782428
157
+ 20151008, 53.157816495234776
158
+ 20151009, 26.70240894496746
159
+ 20151010, 83.35799482798699
160
+ 20151011, 110.94956088737794
161
+ 20151012, 60.71514775104601
162
+ 20151013, 88.16609571129707
163
+ 20151014, 87.18272267512077
164
+ 20151015, 52.31939377760342
165
+ 20151016, 49.53965161552765
166
+ 20151017, 35.443420211529514
167
+ 20151018, 31.41580968880754
168
+ 20151019, 35.210431376394695
169
+ 20151023, 23.18438644089958
170
+ 20151024, 28.495018850244072
171
+ 20151025, 23.631041521385402
172
+ 20151027, 32.73702674628079
173
+ 20151028, 20.834829817343607
174
+ 20151030, 23.063971500058358
175
+ 20151102, 21.893503748256617
176
+ 20151109, 40.70959491951417
177
+ 20151110, 23.001540157641813
178
+ 20151113, 121.34067551499768
179
+ 20151114, 30.08454733990005
180
+ 20151116, 22.963587075455177
181
+ 20151118, 65.64878399581589
182
+ 20151119, 86.3667786001462
183
+ 20151120, 48.5797208727241
184
+ 20151121, 40.62472124157369
185
+ 20151122, 31.929468670095304
186
+ 20151123, 46.070741987738245
187
+ 20151124, 43.33614906584146
188
+ 20151125, 38.5326454265458
189
+ 20151203, 44.5021005854835
190
+ 20151205, 156.30654074413056
191
+ 20151206, 31.48730096466759
192
+ 20151208, 33.42894220711297
193
+ 20151209, 473.5921466396443
194
+ 20151220, 49.35841541724779
195
+ 20151221, 30.52428394496746
196
+ 20151226, 55.58330554829149
197
+ 20151227, 28.596221597512784
198
+ 20151229, 42.04237763685495
199
+ 20151230, 49.52976704149231
data/HKO-7/samplers/hko7_cloudy_days_t20_test.txt.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6427753cb9aef289f621cd97ffed4cccae10996faebafada5f4819232d45f246
3
+ size 477407
data/HKO-7/samplers/hko7_cloudy_days_t20_train.txt ADDED
@@ -0,0 +1,1224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 20090305, 136.62595858667146
2
+ 20090306, 228.80309496601032
3
+ 20090307, 61.28797688575081
4
+ 20090313, 50.83967177330312
5
+ 20090323, 27.882427867852165
6
+ 20090324, 98.90460069444444
7
+ 20090325, 186.85773640893774
8
+ 20090327, 83.00352533110642
9
+ 20090328, 85.88496538673874
10
+ 20090329, 95.76494835251044
11
+ 20090404, 55.43760260489309
12
+ 20090405, 51.40799718880753
13
+ 20090406, 51.97289596408646
14
+ 20090407, 41.94900518654117
15
+ 20090408, 22.849034969200364
16
+ 20090409, 96.84607150151723
17
+ 20090410, 34.88548858089261
18
+ 20090411, 26.66028119188749
19
+ 20090412, 73.90951701243608
20
+ 20090413, 83.42903845595072
21
+ 20090414, 57.75939424976757
22
+ 20090415, 163.3683099793701
23
+ 20090416, 180.4777461791027
24
+ 20090417, 38.150535906264544
25
+ 20090418, 136.80413615760114
26
+ 20090419, 88.94150068282197
27
+ 20090420, 37.34591705166202
28
+ 20090421, 35.38847592689446
29
+ 20090422, 46.74223053957462
30
+ 20090424, 21.330452405857738
31
+ 20090425, 251.52358005724082
32
+ 20090426, 30.056307386099487
33
+ 20090429, 41.453953426932365
34
+ 20090430, 27.32378381421432
35
+ 20090505, 21.583478977801022
36
+ 20090509, 22.712410470420732
37
+ 20090513, 32.6148927469136
38
+ 20090516, 85.55672561453974
39
+ 20090517, 109.60553775860065
40
+ 20090518, 116.95029219694324
41
+ 20090519, 89.23390664952346
42
+ 20090520, 106.5634751889944
43
+ 20090521, 82.6532027254765
44
+ 20090522, 181.06651902458165
45
+ 20090523, 501.6702046286611
46
+ 20090524, 400.99202769060906
47
+ 20090525, 173.28997614670865
48
+ 20090526, 101.19955307851002
49
+ 20090527, 53.12012853759879
50
+ 20090528, 60.35818260402138
51
+ 20090529, 25.776541979021392
52
+ 20090603, 160.16725414778008
53
+ 20090604, 97.28280759675732
54
+ 20090607, 71.03364441684101
55
+ 20090608, 99.24742143915621
56
+ 20090609, 116.8017359295095
57
+ 20090610, 82.46225544905461
58
+ 20090611, 152.62945577638308
59
+ 20090612, 90.96381378506264
60
+ 20090613, 88.4151001714319
61
+ 20090614, 118.5259904550209
62
+ 20090615, 130.086386600709
63
+ 20090616, 87.38530679771037
64
+ 20090617, 25.00878032302665
65
+ 20090620, 30.45724553986518
66
+ 20090621, 40.03228839783821
67
+ 20090622, 106.4602314330544
68
+ 20090623, 83.79152683344954
69
+ 20090624, 70.1403104660623
70
+ 20090625, 85.02154085309161
71
+ 20090626, 75.30984571129706
72
+ 20090627, 335.9534569676896
73
+ 20090628, 186.45501920803784
74
+ 20090629, 42.00070443253138
75
+ 20090703, 63.63226896646908
76
+ 20090704, 237.3798625639238
77
+ 20090705, 193.21849521298233
78
+ 20090706, 48.60273110617154
79
+ 20090710, 32.801129743433286
80
+ 20090711, 73.77359440376569
81
+ 20090712, 27.58495957548814
82
+ 20090714, 44.6260489307299
83
+ 20090715, 54.747222222222206
84
+ 20090716, 22.874410779936973
85
+ 20090718, 235.99509966294747
86
+ 20090719, 341.8592338955717
87
+ 20090720, 65.83892666201766
88
+ 20090723, 22.81097400191771
89
+ 20090724, 34.50459739803094
90
+ 20090725, 35.20270513714551
91
+ 20090726, 51.135496026557405
92
+ 20090727, 37.10169459541063
93
+ 20090728, 42.00904757235007
94
+ 20090729, 44.38810836529521
95
+ 20090730, 63.88379169572293
96
+ 20090731, 33.26692928094232
97
+ 20090803, 57.20850566250586
98
+ 20090804, 195.11130405287113
99
+ 20090805, 320.39658806950257
100
+ 20090806, 177.06629202260575
101
+ 20090807, 23.55537268276384
102
+ 20090808, 30.765311555671776
103
+ 20090810, 35.96150210866714
104
+ 20090811, 253.41603796199436
105
+ 20090812, 341.82729526237796
106
+ 20090813, 191.62245008680554
107
+ 20090814, 53.51327997344771
108
+ 20090823, 35.0751107769642
109
+ 20090824, 51.215284315434666
110
+ 20090825, 45.89880397198978
111
+ 20090828, 24.214795262377965
112
+ 20090829, 26.32466514861783
113
+ 20090830, 54.47782354137608
114
+ 20090831, 32.23210352742911
115
+ 20090901, 62.516496505985586
116
+ 20090902, 26.58246055519574
117
+ 20090905, 20.926650939969775
118
+ 20090907, 22.240835448762837
119
+ 20090908, 34.00065395950047
120
+ 20090909, 57.947631552185015
121
+ 20090910, 91.215858820922
122
+ 20090911, 82.86404510226207
123
+ 20090912, 70.03017207278481
124
+ 20090913, 56.37304160855416
125
+ 20090914, 272.6651468369081
126
+ 20090915, 285.3876847353525
127
+ 20090916, 49.86654958858544
128
+ 20090918, 151.93731350368924
129
+ 20090919, 69.07707661407484
130
+ 20090920, 36.43623096815435
131
+ 20090921, 43.216834880541164
132
+ 20090922, 46.76212733477011
133
+ 20090923, 49.67186882554626
134
+ 20090924, 52.48864079419726
135
+ 20090925, 61.88516623808692
136
+ 20090926, 37.474039327638295
137
+ 20090927, 69.27637835599721
138
+ 20090928, 463.9436457606926
139
+ 20090929, 332.59933334059735
140
+ 20090930, 66.80611588150651
141
+ 20091001, 54.379151775337064
142
+ 20091002, 28.19407234280568
143
+ 20091003, 44.98317751917714
144
+ 20091004, 37.289169463331014
145
+ 20091007, 37.08707105706647
146
+ 20091008, 61.00016235181311
147
+ 20091009, 44.652421475476515
148
+ 20091010, 23.698430417538354
149
+ 20091011, 112.88366802504649
150
+ 20091012, 35.83612582083915
151
+ 20091013, 36.11428387232682
152
+ 20091016, 60.58141761825895
153
+ 20091017, 51.14802181398186
154
+ 20091018, 71.10693136913063
155
+ 20091019, 101.7541401528359
156
+ 20091020, 201.2627237331474
157
+ 20091021, 64.73917236983361
158
+ 20091022, 63.29760327994324
159
+ 20091023, 39.95842286291259
160
+ 20091024, 35.976269213447246
161
+ 20091025, 31.75665043148535
162
+ 20091026, 55.861363900511385
163
+ 20091027, 46.790897765574144
164
+ 20091028, 30.22591944880288
165
+ 20091029, 21.87970993084734
166
+ 20091030, 23.724973122966055
167
+ 20091031, 40.22858427039748
168
+ 20091101, 33.49137192730126
169
+ 20091102, 24.3667207839377
170
+ 20091109, 22.49475020151289
171
+ 20091110, 28.50735213999301
172
+ 20091111, 66.11944244682707
173
+ 20091112, 154.01058955137148
174
+ 20091113, 59.36137043816828
175
+ 20091115, 66.11488043351932
176
+ 20091116, 102.36392502760343
177
+ 20091123, 24.5346406831125
178
+ 20091124, 34.07892132293119
179
+ 20091125, 46.277792245370364
180
+ 20091126, 39.02615244363087
181
+ 20091128, 35.47976431746862
182
+ 20091206, 24.131424337517426
183
+ 20091207, 98.9516879866922
184
+ 20091208, 66.62441923814505
185
+ 20091209, 34.3988533676197
186
+ 20091215, 56.45255150220826
187
+ 20091216, 24.40855307124593
188
+ 20091227, 81.33299646240121
189
+ 20091228, 23.943436737273355
190
+ 20091229, 88.94060556863089
191
+ 20091230, 96.75108851987449
192
+ 20100102, 134.41120754155045
193
+ 20100103, 26.636634668468155
194
+ 20100105, 37.79852993520455
195
+ 20100106, 40.25822146675964
196
+ 20100107, 57.98259385169688
197
+ 20100110, 49.90426454991865
198
+ 20100111, 95.87596648361225
199
+ 20100121, 52.24734371367966
200
+ 20100122, 48.799663673872615
201
+ 20100206, 32.67162693224081
202
+ 20100207, 209.8687973980125
203
+ 20100208, 60.44115473181079
204
+ 20100212, 27.867008077638303
205
+ 20100213, 27.533020433809863
206
+ 20100214, 66.34328727917247
207
+ 20100216, 32.234679545850774
208
+ 20100217, 25.200238987680148
209
+ 20100218, 51.988591970304505
210
+ 20100219, 88.0366902385518
211
+ 20100222, 78.82401481142492
212
+ 20100306, 26.108932618549513
213
+ 20100307, 35.84544125552069
214
+ 20100308, 20.593918344665273
215
+ 20100309, 50.63448759298
216
+ 20100324, 55.8357261157601
217
+ 20100325, 39.645025751104136
218
+ 20100401, 20.778464050151097
219
+ 20100402, 102.23865208042773
220
+ 20100405, 24.484351391794515
221
+ 20100406, 59.879344636215706
222
+ 20100407, 75.98970573678116
223
+ 20100408, 141.23548440405628
224
+ 20100409, 25.15233466992097
225
+ 20100410, 34.2181490803696
226
+ 20100411, 78.74429970798465
227
+ 20100412, 73.08766689185262
228
+ 20100413, 53.82095809361923
229
+ 20100415, 155.0404605071195
230
+ 20100417, 127.06235825421942
231
+ 20100418, 102.04714594955833
232
+ 20100420, 44.37650511390051
233
+ 20100421, 40.503602975360295
234
+ 20100422, 189.74148597309394
235
+ 20100426, 85.08088334640868
236
+ 20100427, 25.063565274872154
237
+ 20100428, 59.77439016350212
238
+ 20100429, 324.0794886633429
239
+ 20100430, 34.49387021427267
240
+ 20100503, 29.15010260489308
241
+ 20100506, 124.59078898622735
242
+ 20100507, 103.47258360936776
243
+ 20100508, 24.618598544281735
244
+ 20100509, 21.73658647486772
245
+ 20100510, 157.08738735945764
246
+ 20100512, 22.094342621894043
247
+ 20100514, 84.33014786000699
248
+ 20100515, 88.7563665882148
249
+ 20100516, 20.627662460774058
250
+ 20100519, 204.0931331834312
251
+ 20100520, 148.64851395426544
252
+ 20100521, 37.800374825662495
253
+ 20100522, 117.86880774930266
254
+ 20100523, 181.03243712953275
255
+ 20100526, 41.56335881635037
256
+ 20100527, 27.768272204715217
257
+ 20100528, 75.1496062748016
258
+ 20100529, 206.38358703074147
259
+ 20100530, 175.6673069938401
260
+ 20100531, 118.4293410041841
261
+ 20100601, 237.94262170937938
262
+ 20100602, 334.39367899378186
263
+ 20100608, 55.2173480357973
264
+ 20100609, 207.89246825604368
265
+ 20100610, 319.07378980706653
266
+ 20100611, 168.00780432792885
267
+ 20100612, 72.57029488464669
268
+ 20100613, 126.20548273332172
269
+ 20100614, 118.64496182734774
270
+ 20100615, 210.03491617271035
271
+ 20100616, 181.45767230357976
272
+ 20100617, 73.54427446536496
273
+ 20100618, 25.24199954236402
274
+ 20100621, 120.53997976958391
275
+ 20100622, 82.61656079294514
276
+ 20100623, 104.60782877604167
277
+ 20100624, 120.09928467137378
278
+ 20100625, 218.53039047971882
279
+ 20100626, 407.84667342224554
280
+ 20100627, 376.6243358830195
281
+ 20100628, 250.77930849750115
282
+ 20100629, 35.55333202580195
283
+ 20100713, 20.72497693659926
284
+ 20100714, 24.57621055613668
285
+ 20100715, 64.12604257467457
286
+ 20100716, 94.36938324471177
287
+ 20100717, 118.22245939388658
288
+ 20100718, 25.131903584088786
289
+ 20100720, 24.465838200255696
290
+ 20100721, 213.62310389789633
291
+ 20100722, 206.2271132873631
292
+ 20100723, 127.32937133164812
293
+ 20100724, 41.27161966817759
294
+ 20100725, 41.94855426981636
295
+ 20100726, 114.66956012465128
296
+ 20100727, 358.45817969839607
297
+ 20100728, 234.21029375291786
298
+ 20100729, 115.44730793816831
299
+ 20100730, 55.26047169194562
300
+ 20100801, 23.52279190638075
301
+ 20100802, 51.18088879445607
302
+ 20100803, 22.69003860849605
303
+ 20100804, 25.242931712962964
304
+ 20100805, 131.22104725999537
305
+ 20100806, 47.94726707781264
306
+ 20100807, 92.85841741486517
307
+ 20100808, 55.812995772315205
308
+ 20100810, 21.316553347280333
309
+ 20100811, 79.84416296199443
310
+ 20100812, 43.735309703335666
311
+ 20100813, 20.47297060233613
312
+ 20100814, 39.483855801080885
313
+ 20100815, 72.57462335832172
314
+ 20100816, 119.99333086355185
315
+ 20100817, 79.66040751804195
316
+ 20100818, 52.19738548204323
317
+ 20100819, 91.94117597919576
318
+ 20100820, 67.82326788412367
319
+ 20100822, 59.08391282397721
320
+ 20100823, 107.90557934536264
321
+ 20100824, 271.26208825231475
322
+ 20100825, 158.71470170454546
323
+ 20100826, 52.67590156909284
324
+ 20100827, 45.05917614626917
325
+ 20100828, 21.632718430381214
326
+ 20100829, 111.65726660564856
327
+ 20100830, 124.24277933955138
328
+ 20100831, 43.97353320403301
329
+ 20100901, 40.16028300790331
330
+ 20100902, 142.00870325575315
331
+ 20100903, 498.5411797928289
332
+ 20100904, 266.90315673144255
333
+ 20100905, 51.12990215306834
334
+ 20100906, 32.760605169107386
335
+ 20100907, 31.888638823512327
336
+ 20100908, 55.89353366427952
337
+ 20100909, 152.4942677767997
338
+ 20100910, 142.45224347461686
339
+ 20100911, 421.70778871019286
340
+ 20100912, 234.28636390051145
341
+ 20100913, 52.64439450400976
342
+ 20100914, 84.66343270571826
343
+ 20100915, 76.15386468114241
344
+ 20100916, 46.37933737215248
345
+ 20100917, 61.118001983089236
346
+ 20100918, 68.72911563662251
347
+ 20100919, 97.44578212168761
348
+ 20100920, 332.5900812848675
349
+ 20100921, 579.4747775380636
350
+ 20100922, 378.79682342515105
351
+ 20100923, 124.51707036698046
352
+ 20100924, 85.51370801371456
353
+ 20100925, 78.88520437441889
354
+ 20100926, 62.14001699790795
355
+ 20100927, 59.77074071652721
356
+ 20100928, 56.54303221450616
357
+ 20100929, 68.64558376736112
358
+ 20100930, 38.8137911872385
359
+ 20101001, 26.032749665853093
360
+ 20101002, 24.983342413412363
361
+ 20101003, 74.97003192555788
362
+ 20101004, 147.84034751278477
363
+ 20101005, 320.81967598645986
364
+ 20101006, 282.0738953175848
365
+ 20101007, 191.48288677504644
366
+ 20101008, 72.47399828568108
367
+ 20101009, 144.1631623009647
368
+ 20101010, 155.19021930206875
369
+ 20101011, 75.0252271835774
370
+ 20101012, 58.9400700255695
371
+ 20101013, 58.17412940202231
372
+ 20101014, 46.33378225244073
373
+ 20101015, 75.42381559449093
374
+ 20101016, 32.24196231403998
375
+ 20101017, 39.06391013627381
376
+ 20101018, 41.11644093590189
377
+ 20101020, 27.779980742296928
378
+ 20101021, 89.352832258252
379
+ 20101022, 169.78000384995352
380
+ 20101023, 69.3806056775918
381
+ 20101027, 40.52576958038521
382
+ 20101028, 29.311661182298934
383
+ 20101104, 167.72977993520453
384
+ 20101105, 279.8310407949792
385
+ 20101106, 158.45435916434218
386
+ 20101107, 40.77464769293352
387
+ 20101108, 30.161046824151555
388
+ 20101109, 44.06837121212121
389
+ 20101110, 65.25919721205254
390
+ 20101111, 42.84965041695723
391
+ 20101112, 25.77128070049694
392
+ 20101113, 86.23978596090822
393
+ 20101114, 61.09089703916783
394
+ 20101116, 27.438181732333796
395
+ 20101215, 83.16066727684799
396
+ 20101216, 76.18976547971873
397
+ 20101225, 52.48511448163644
398
+ 20110103, 35.99297964609484
399
+ 20110104, 91.52568191393539
400
+ 20110111, 71.14335537976523
401
+ 20110112, 99.04646240120874
402
+ 20110120, 22.689244283182234
403
+ 20110213, 221.73691360849605
404
+ 20110214, 40.80066320897257
405
+ 20110215, 149.97527676080898
406
+ 20110216, 22.64409540620641
407
+ 20110219, 33.252693393836445
408
+ 20110309, 91.90697621641252
409
+ 20110317, 55.72166615818224
410
+ 20110318, 142.97228550274278
411
+ 20110319, 94.70543569851233
412
+ 20110327, 96.05199209515366
413
+ 20110329, 91.72684224731557
414
+ 20110330, 259.44745757787075
415
+ 20110331, 54.02866871513249
416
+ 20110415, 22.540139760576473
417
+ 20110416, 24.26677653562297
418
+ 20110417, 125.34206364772196
419
+ 20110418, 44.91978966904928
420
+ 20110419, 22.963660288665253
421
+ 20110420, 23.317585633497426
422
+ 20110422, 57.15267880491631
423
+ 20110423, 61.766474895397494
424
+ 20110425, 21.202812645281266
425
+ 20110426, 24.233773898768014
426
+ 20110427, 41.24985853236867
427
+ 20110428, 137.31416801051836
428
+ 20110429, 295.45362149872153
429
+ 20110430, 95.75766830834493
430
+ 20110501, 152.57700361023944
431
+ 20110502, 117.76413259821015
432
+ 20110503, 177.39995587081592
433
+ 20110504, 75.2182224474082
434
+ 20110505, 73.69405418264763
435
+ 20110506, 127.82177475592748
436
+ 20110507, 104.55865658414693
437
+ 20110508, 84.0641158908647
438
+ 20110509, 20.580801902090975
439
+ 20110511, 21.399591941248257
440
+ 20110512, 79.2632547361692
441
+ 20110513, 191.61483376191308
442
+ 20110514, 146.87530164022547
443
+ 20110515, 91.66247693659926
444
+ 20110516, 203.88988951359835
445
+ 20110517, 48.00782158007903
446
+ 20110518, 76.23233597745235
447
+ 20110520, 25.131326999070197
448
+ 20110521, 161.98638423988842
449
+ 20110522, 306.4063389847746
450
+ 20110523, 131.7929106374942
451
+ 20110524, 55.43122512058344
452
+ 20110531, 32.16935289740896
453
+ 20110605, 23.848160194386328
454
+ 20110607, 33.29129329381683
455
+ 20110608, 38.98019979805906
456
+ 20110609, 21.67068333042771
457
+ 20110611, 128.9752891097164
458
+ 20110612, 167.42113388394932
459
+ 20110613, 41.01110094142259
460
+ 20110614, 27.571970340829843
461
+ 20110615, 47.03335839435147
462
+ 20110616, 177.92250715510227
463
+ 20110617, 81.67146349081827
464
+ 20110618, 34.76633651499303
465
+ 20110621, 63.823974495874005
466
+ 20110622, 182.5606137043817
467
+ 20110623, 218.6694170226058
468
+ 20110624, 34.418169964551375
469
+ 20110625, 39.781262893712224
470
+ 20110626, 123.53867786785216
471
+ 20110627, 125.84756871803812
472
+ 20110628, 210.20931743229892
473
+ 20110629, 323.0898435683983
474
+ 20110630, 188.5001558141562
475
+ 20110701, 37.033898477452354
476
+ 20110709, 60.17592090161552
477
+ 20110710, 43.41258789516504
478
+ 20110711, 161.83613054248025
479
+ 20110712, 183.02187009675737
480
+ 20110713, 92.7071393029405
481
+ 20110714, 136.99119668177593
482
+ 20110715, 383.15455819967457
483
+ 20110716, 381.922978411204
484
+ 20110717, 198.67304197175736
485
+ 20110718, 164.5868838258368
486
+ 20110719, 144.3219961282543
487
+ 20110720, 203.32798353236865
488
+ 20110721, 49.25801589377032
489
+ 20110722, 25.731379481926997
490
+ 20110728, 107.40300260053466
491
+ 20110729, 142.66062387407015
492
+ 20110730, 100.56940757932357
493
+ 20110731, 26.589137501452807
494
+ 20110805, 29.346624396135265
495
+ 20110806, 20.18941389905857
496
+ 20110807, 43.96820283443747
497
+ 20110808, 139.8291853716295
498
+ 20110809, 208.90599285216177
499
+ 20110810, 175.34678056717806
500
+ 20110811, 56.27505284605997
501
+ 20110812, 20.006154114365422
502
+ 20110817, 25.62793449994189
503
+ 20110818, 21.587661080602047
504
+ 20110819, 30.14290101261041
505
+ 20110820, 35.011025395165035
506
+ 20110821, 20.12696020746165
507
+ 20110824, 40.05301749186426
508
+ 20110825, 22.119369624593205
509
+ 20110826, 30.547850018886557
510
+ 20110830, 34.09250966120409
511
+ 20110831, 52.65844156787542
512
+ 20110901, 130.54113203161316
513
+ 20110902, 128.1669594084147
514
+ 20110903, 116.22911418380986
515
+ 20110904, 37.470217522373325
516
+ 20110905, 47.890565979486276
517
+ 20110906, 30.946702841701526
518
+ 20110907, 34.402428194735016
519
+ 20110908, 31.603840510227805
520
+ 20110909, 78.73658291201767
521
+ 20110910, 100.50139034170155
522
+ 20110911, 156.5110798756392
523
+ 20110912, 66.90991326708507
524
+ 20110913, 46.516962314039965
525
+ 20110914, 41.780873234710555
526
+ 20110915, 64.64849579410738
527
+ 20110916, 77.78759879125988
528
+ 20110917, 58.228567563052074
529
+ 20110918, 112.29172586878197
530
+ 20110919, 60.13979453597164
531
+ 20110920, 27.00396381770107
532
+ 20110922, 29.113824601929327
533
+ 20110923, 84.00498296577173
534
+ 20110924, 141.03301316974665
535
+ 20110925, 230.6655854108554
536
+ 20110926, 78.97429810989075
537
+ 20110927, 35.68688818427476
538
+ 20110928, 136.85599775540447
539
+ 20110929, 259.23085592456994
540
+ 20110930, 75.60321144235238
541
+ 20111001, 61.53155200342862
542
+ 20111002, 214.85392967660394
543
+ 20111003, 176.42546925848444
544
+ 20111004, 61.928290802243126
545
+ 20111005, 42.20670182473268
546
+ 20111006, 90.13077529346815
547
+ 20111007, 68.09544016591119
548
+ 20111008, 29.100231360413762
549
+ 20111009, 51.94596844490933
550
+ 20111010, 137.91833086355183
551
+ 20111011, 127.01731498430965
552
+ 20111012, 170.94246189998842
553
+ 20111013, 286.50326265399815
554
+ 20111014, 114.70196547390748
555
+ 20111015, 47.84846256101814
556
+ 20111016, 39.638181913935384
557
+ 20111017, 28.59887061976987
558
+ 20111018, 32.294380157484895
559
+ 20111019, 31.440034613261272
560
+ 20111020, 39.14295803550673
561
+ 20111024, 28.146944553405397
562
+ 20111104, 35.26615709263133
563
+ 20111107, 71.93888816248257
564
+ 20111108, 331.386634305265
565
+ 20111109, 379.73886037744074
566
+ 20111112, 39.01522039167829
567
+ 20111113, 34.4868455079033
568
+ 20111114, 35.933232726057646
569
+ 20111115, 26.571033639876806
570
+ 20111116, 26.743759080079037
571
+ 20111117, 198.26452394961643
572
+ 20111118, 133.26355946362156
573
+ 20111119, 86.71600527370987
574
+ 20111120, 35.035476050383544
575
+ 20111121, 57.03569506188985
576
+ 20111122, 28.57397849110879
577
+ 20111123, 35.00330605677591
578
+ 20111125, 23.649023346699206
579
+ 20111126, 67.68473711355185
580
+ 20111127, 30.021795095304512
581
+ 20111205, 48.25002124738492
582
+ 20120105, 104.29902189388655
583
+ 20120106, 21.676583565783353
584
+ 20120111, 46.484685720304505
585
+ 20120112, 155.01304698396095
586
+ 20120113, 148.11024614278244
587
+ 20120115, 265.70077870757785
588
+ 20120116, 157.90245852219903
589
+ 20120122, 34.969506188981875
590
+ 20120123, 64.36847218559541
591
+ 20120124, 35.67474159780824
592
+ 20120125, 30.020862571187816
593
+ 20120207, 66.90006538120565
594
+ 20120208, 23.766170894351465
595
+ 20120222, 21.957772184449095
596
+ 20120224, 56.05528164390757
597
+ 20120225, 29.70407423146211
598
+ 20120227, 43.75812358350767
599
+ 20120228, 146.18389430061598
600
+ 20120304, 26.300959764353788
601
+ 20120305, 34.86704457955603
602
+ 20120306, 59.563980416085535
603
+ 20120307, 38.043944676894476
604
+ 20120308, 54.86796003312412
605
+ 20120311, 73.48692305177823
606
+ 20120312, 42.601446274988376
607
+ 20120318, 20.28321583711065
608
+ 20120327, 20.15923425877499
609
+ 20120403, 33.63068885111576
610
+ 20120405, 199.76914952347744
611
+ 20120406, 181.56111873837747
612
+ 20120407, 22.05863079672245
613
+ 20120408, 243.07268748547187
614
+ 20120409, 141.99248950342863
615
+ 20120410, 24.95706648070665
616
+ 20120411, 34.93628381421433
617
+ 20120412, 46.378410114481625
618
+ 20120413, 73.36980219414893
619
+ 20120415, 27.01078804189911
620
+ 20120416, 86.76931859454905
621
+ 20120417, 253.06885242329153
622
+ 20120418, 158.07843917073458
623
+ 20120419, 216.44427755259184
624
+ 20120420, 356.705582614191
625
+ 20120421, 69.50133889328411
626
+ 20120422, 25.895080049976745
627
+ 20120423, 62.498505418991165
628
+ 20120424, 35.07947938458856
629
+ 20120425, 79.38964326185493
630
+ 20120426, 46.88026808025337
631
+ 20120427, 246.83180188720362
632
+ 20120428, 177.62698944531613
633
+ 20120429, 98.48469970362622
634
+ 20120430, 35.45415776818921
635
+ 20120501, 38.946295146152956
636
+ 20120502, 62.265384922710346
637
+ 20120503, 98.89222781555094
638
+ 20120504, 59.85654473936542
639
+ 20120505, 115.06919256305207
640
+ 20120506, 69.89122809884935
641
+ 20120510, 52.98434067730125
642
+ 20120511, 82.38079926487681
643
+ 20120512, 33.287890080195254
644
+ 20120513, 100.77997188807528
645
+ 20120514, 39.74556837662715
646
+ 20120515, 93.45692682909112
647
+ 20120516, 179.34509784693168
648
+ 20120517, 150.1514281148303
649
+ 20120518, 283.1472258542538
650
+ 20120519, 162.8450588592945
651
+ 20120520, 84.38526884298003
652
+ 20120521, 55.78281831125058
653
+ 20120523, 30.494975629067877
654
+ 20120525, 42.9129234948861
655
+ 20120526, 87.15840252353556
656
+ 20120527, 111.9839186352278
657
+ 20120528, 155.79253399581592
658
+ 20120529, 81.49536661727102
659
+ 20120530, 73.16466087720828
660
+ 20120531, 66.30515494246862
661
+ 20120601, 65.95614049424687
662
+ 20120607, 20.85536124186425
663
+ 20120608, 30.526771886622516
664
+ 20120609, 125.87894311512086
665
+ 20120610, 60.65825760547421
666
+ 20120611, 112.87685033850535
667
+ 20120612, 176.38053267375636
668
+ 20120613, 225.37360402864948
669
+ 20120614, 95.9233399799512
670
+ 20120615, 112.87394634762902
671
+ 20120616, 199.91418707868434
672
+ 20120617, 123.39643516097166
673
+ 20120618, 82.30777291085542
674
+ 20120619, 52.6865451388889
675
+ 20120620, 73.52798008193864
676
+ 20120621, 227.8351153896444
677
+ 20120622, 200.99316033966764
678
+ 20120623, 130.71126093241514
679
+ 20120624, 145.42582483437937
680
+ 20120625, 45.958252520629955
681
+ 20120626, 45.63854874912831
682
+ 20120627, 22.53596165301022
683
+ 20120629, 76.21593517549977
684
+ 20120630, 67.28973624186425
685
+ 20120701, 30.23978636390051
686
+ 20120705, 97.22105180003486
687
+ 20120706, 67.02140174628079
688
+ 20120707, 36.87964391562063
689
+ 20120708, 22.167915359135282
690
+ 20120713, 31.47329403475128
691
+ 20120714, 24.112315311192468
692
+ 20120715, 23.03197931194793
693
+ 20120716, 29.393501205834497
694
+ 20120717, 31.363913223500706
695
+ 20120718, 42.11399966585309
696
+ 20120719, 30.423471459495584
697
+ 20120720, 32.19373801429568
698
+ 20120721, 38.418141271501625
699
+ 20120722, 111.534527000523
700
+ 20120723, 305.4445363348443
701
+ 20120725, 425.3922848384472
702
+ 20120726, 270.14713341904934
703
+ 20120727, 143.0784675005811
704
+ 20120728, 39.630659794862844
705
+ 20120729, 23.80491940521851
706
+ 20120803, 65.87163310669456
707
+ 20120804, 174.61665994740815
708
+ 20120805, 127.87644718299626
709
+ 20120806, 52.96096837226872
710
+ 20120808, 24.279802998605298
711
+ 20120809, 21.723397729253836
712
+ 20120810, 56.92519594810553
713
+ 20120811, 148.85347222222217
714
+ 20120812, 204.3983773898768
715
+ 20120813, 160.65132732595308
716
+ 20120814, 28.701588832229195
717
+ 20120816, 98.85505270077871
718
+ 20120817, 141.65484167974196
719
+ 20120818, 73.6276789865179
720
+ 20120820, 22.215848369944215
721
+ 20120821, 48.31004220420735
722
+ 20120822, 114.96252433461177
723
+ 20120823, 21.592981280509065
724
+ 20120825, 58.30701000261506
725
+ 20120826, 92.99211268014876
726
+ 20120829, 26.10597381736705
727
+ 20120830, 30.602471597512785
728
+ 20120831, 48.01793242968387
729
+ 20120901, 41.164819778591344
730
+ 20120902, 36.74191092079264
731
+ 20120903, 52.148103534693156
732
+ 20120904, 77.8733388903417
733
+ 20120905, 33.99965459379359
734
+ 20120906, 60.53588192991633
735
+ 20120907, 74.14039636360997
736
+ 20120909, 33.5636689693747
737
+ 20120910, 45.419486394409574
738
+ 20120911, 41.957370844955825
739
+ 20120912, 34.730700836820084
740
+ 20120913, 61.21184986634123
741
+ 20120914, 33.420998045966996
742
+ 20120915, 26.844679618491394
743
+ 20120916, 54.41844618055556
744
+ 20120917, 71.693800666841
745
+ 20120919, 24.288151404869826
746
+ 20120920, 30.290509319793124
747
+ 20120921, 44.52357624360763
748
+ 20120922, 46.52463007758021
749
+ 20120923, 60.30840070751975
750
+ 20120924, 102.84257503777312
751
+ 20120925, 116.06859727307068
752
+ 20120926, 77.39527763249652
753
+ 20120927, 46.295262196362145
754
+ 20120930, 39.05014368603801
755
+ 20121001, 43.8218346844491
756
+ 20121002, 45.0548095362622
757
+ 20121003, 55.270770135983255
758
+ 20121004, 64.44542382176896
759
+ 20121005, 43.024875602917255
760
+ 20121006, 21.134233532368672
761
+ 20121008, 32.69247788092748
762
+ 20121010, 30.112397758310088
763
+ 20121011, 29.853273731694564
764
+ 20121012, 64.561951744828
765
+ 20121013, 100.94356912482566
766
+ 20121014, 71.15270459234077
767
+ 20121015, 29.164262988145044
768
+ 20121016, 50.05106436686424
769
+ 20121017, 32.32784406235472
770
+ 20121018, 35.83185237244305
771
+ 20121019, 23.56877088418177
772
+ 20121020, 41.49348014005114
773
+ 20121021, 40.56814653794745
774
+ 20121022, 34.88908138656439
775
+ 20121023, 24.15241021617852
776
+ 20121025, 25.634765988203167
777
+ 20121026, 42.98020288528591
778
+ 20121027, 51.68258350040679
779
+ 20121028, 34.085144990702005
780
+ 20121029, 70.69760158792423
781
+ 20121030, 414.6291510489307
782
+ 20121031, 246.88588501714315
783
+ 20121101, 88.46325292015341
784
+ 20121102, 60.19221546664342
785
+ 20121103, 58.46229333740121
786
+ 20121104, 28.681218946129714
787
+ 20121105, 41.65760765341701
788
+ 20121106, 22.620909787598798
789
+ 20121110, 52.8720482463819
790
+ 20121111, 47.08687238493724
791
+ 20121113, 44.9120341919456
792
+ 20121114, 39.10897965335891
793
+ 20121116, 90.44041960861227
794
+ 20121117, 238.69832726784048
795
+ 20121118, 22.83449976028591
796
+ 20121119, 23.235088875813577
797
+ 20121120, 23.16098144758252
798
+ 20121121, 34.50703597164109
799
+ 20121123, 104.08733910099954
800
+ 20121124, 36.144737912598785
801
+ 20121125, 58.245344643479775
802
+ 20121126, 138.49525910913528
803
+ 20121127, 116.1277260213273
804
+ 20121128, 153.1060002978266
805
+ 20121129, 62.34876438284517
806
+ 20121130, 106.80491123314735
807
+ 20121201, 29.832474357856817
808
+ 20121202, 51.42532379561831
809
+ 20121203, 43.87085330950721
810
+ 20121204, 37.574779354079496
811
+ 20121205, 59.6601429930846
812
+ 20121207, 24.548445490469543
813
+ 20121208, 38.5434260227801
814
+ 20121218, 33.644407034518835
815
+ 20121226, 64.43589990847282
816
+ 20121227, 77.76107515399815
817
+ 20121229, 55.823466919456074
818
+ 20121230, 22.691671025104604
819
+ 20130126, 68.32953658908647
820
+ 20130203, 40.21736147431425
821
+ 20130204, 69.91238250377732
822
+ 20130205, 57.8372849837285
823
+ 20130313, 37.18495957548814
824
+ 20130319, 46.00838563458857
825
+ 20130320, 61.9663423262436
826
+ 20130323, 33.65348184710601
827
+ 20130324, 64.3982873154928
828
+ 20130326, 168.771519424105
829
+ 20130327, 41.56376340219664
830
+ 20130328, 265.6667685451534
831
+ 20130329, 79.51368476871221
832
+ 20130330, 355.0431574340425
833
+ 20130331, 67.06829345362623
834
+ 20130402, 149.4377527894003
835
+ 20130403, 110.74030156758481
836
+ 20130404, 66.24243357014181
837
+ 20130405, 156.6795004503719
838
+ 20130406, 91.90465898855184
839
+ 20130408, 23.98437699761739
840
+ 20130409, 136.47248481810786
841
+ 20130410, 99.25893407136216
842
+ 20130411, 145.10497624651325
843
+ 20130412, 38.984963752324504
844
+ 20130413, 25.025353759879135
845
+ 20130414, 44.8889865905393
846
+ 20130415, 43.28978545589259
847
+ 20130416, 27.412897707461646
848
+ 20130417, 118.70350817933519
849
+ 20130418, 98.94480619479309
850
+ 20130419, 78.98868458711065
851
+ 20130420, 81.90855979050443
852
+ 20130421, 22.120508448105525
853
+ 20130425, 121.91539109716413
854
+ 20130426, 252.95637327847808
855
+ 20130428, 25.334676458623893
856
+ 20130429, 62.346051824763585
857
+ 20130430, 142.26964329817525
858
+ 20130501, 33.8277409126569
859
+ 20130502, 119.36171929480476
860
+ 20130503, 68.33329737622034
861
+ 20130504, 54.23608768450721
862
+ 20130506, 21.33145303056718
863
+ 20130507, 22.99018280015109
864
+ 20130508, 178.57049848552012
865
+ 20130509, 137.87810556863087
866
+ 20130510, 156.389030901325
867
+ 20130511, 46.95578782397722
868
+ 20130512, 48.34755909315434
869
+ 20130514, 31.646399022257093
870
+ 20130515, 141.49384643043936
871
+ 20130516, 231.89553605154583
872
+ 20130517, 108.04281911029754
873
+ 20130518, 32.8831909576941
874
+ 20130519, 113.01127255491636
875
+ 20130520, 148.83095816625988
876
+ 20130521, 243.68406282688287
877
+ 20130522, 447.63819553405403
878
+ 20130523, 96.4477764702464
879
+ 20130524, 42.805950902196656
880
+ 20130525, 163.88438226406328
881
+ 20130526, 80.59039618200836
882
+ 20130527, 89.96674293933053
883
+ 20130528, 52.900444197466285
884
+ 20130529, 49.58047056601582
885
+ 20130530, 26.962475302185027
886
+ 20130603, 54.61715099808228
887
+ 20130604, 153.63807785622967
888
+ 20130605, 186.80010859774524
889
+ 20130606, 61.38567726493492
890
+ 20130608, 35.266644511273824
891
+ 20130609, 58.68488003399581
892
+ 20130610, 139.65129427446539
893
+ 20130611, 309.9337880636913
894
+ 20130612, 120.38993564039978
895
+ 20130613, 116.76242281932822
896
+ 20130614, 233.6716854079498
897
+ 20130615, 320.22559765080206
898
+ 20130616, 139.76226337314037
899
+ 20130617, 30.093121840132504
900
+ 20130621, 48.08229983873781
901
+ 20130622, 82.57217791143655
902
+ 20130623, 146.94995877644118
903
+ 20130624, 282.88758952957926
904
+ 20130625, 95.0396178376337
905
+ 20130701, 115.02560437006045
906
+ 20130702, 61.928661814272424
907
+ 20130706, 37.02248790533473
908
+ 20130707, 97.14188077493026
909
+ 20130708, 33.69977971728266
910
+ 20130709, 55.034494857043235
911
+ 20130710, 35.16722363871455
912
+ 20130714, 127.9919087415737
913
+ 20130715, 180.73405374680374
914
+ 20130716, 199.5831769743724
915
+ 20130717, 191.53194281003022
916
+ 20130718, 65.19927994973267
917
+ 20130719, 133.9287567918991
918
+ 20130720, 240.20309303812172
919
+ 20130721, 76.10018487040911
920
+ 20130722, 50.93545008135751
921
+ 20130723, 55.84513471205254
922
+ 20130724, 179.46736546954904
923
+ 20130725, 261.0974208943515
924
+ 20130726, 271.54983855619474
925
+ 20130727, 136.99305464754767
926
+ 20130728, 92.09601166608553
927
+ 20130731, 24.211531881973492
928
+ 20130801, 64.45616210483497
929
+ 20130802, 185.41478363987684
930
+ 20130803, 120.62649948425155
931
+ 20130807, 37.23730023826127
932
+ 20130810, 33.29708638423989
933
+ 20130811, 21.927062630753134
934
+ 20130813, 322.93491290388187
935
+ 20130814, 307.2372183359484
936
+ 20130815, 220.49118614888434
937
+ 20130816, 297.62268022140853
938
+ 20130817, 267.12960251046024
939
+ 20130818, 111.35105564998838
940
+ 20130819, 45.49491134210832
941
+ 20130820, 57.66521966527196
942
+ 20130821, 71.53396149320083
943
+ 20130822, 101.2599978934217
944
+ 20130823, 140.78798832665038
945
+ 20130824, 89.85586246222688
946
+ 20130825, 40.11625461268016
947
+ 20130829, 51.15046163121804
948
+ 20130830, 147.82179708672084
949
+ 20130831, 246.754305955079
950
+ 20130901, 113.76087139702464
951
+ 20130902, 79.31063386215712
952
+ 20130903, 168.32009875493958
953
+ 20130904, 272.31384984454905
954
+ 20130905, 91.41031605939098
955
+ 20130906, 69.47553917509299
956
+ 20130907, 39.205069952928866
957
+ 20130908, 21.31527142172246
958
+ 20130909, 31.588420720013954
959
+ 20130910, 44.12261992968386
960
+ 20130911, 52.30735867764992
961
+ 20130912, 79.7153144612971
962
+ 20130913, 103.39279368607623
963
+ 20130914, 72.49962335832173
964
+ 20130915, 48.97323882787076
965
+ 20130916, 76.44151357653416
966
+ 20130917, 53.45604969345652
967
+ 20130918, 54.43395295792656
968
+ 20130919, 29.099646603324036
969
+ 20130922, 367.8701579570549
970
+ 20130923, 249.235258854893
971
+ 20130924, 32.61869243229893
972
+ 20130925, 35.36432091904928
973
+ 20130928, 73.08063255462574
974
+ 20130929, 243.00749633164807
975
+ 20130930, 306.4345184652487
976
+ 20131001, 40.19865033705253
977
+ 20131002, 34.479447059507216
978
+ 20131004, 21.415825669746635
979
+ 20131009, 25.32022790136535
980
+ 20131010, 32.18095126539982
981
+ 20131011, 31.99367899378197
982
+ 20131012, 30.35447956619015
983
+ 20131013, 30.706253450430033
984
+ 20131014, 142.5235437369247
985
+ 20131015, 43.649596299686195
986
+ 20131016, 25.725661756159933
987
+ 20131017, 21.21159925615992
988
+ 20131018, 28.128024392724317
989
+ 20131022, 32.32290086732914
990
+ 20131023, 35.4515408894119
991
+ 20131029, 38.405682676662
992
+ 20131030, 31.816779986053007
993
+ 20131031, 31.90093615614831
994
+ 20131101, 27.61662816713157
995
+ 20131102, 115.37988308490236
996
+ 20131103, 149.52505230125522
997
+ 20131104, 78.23612200720596
998
+ 20131105, 41.96016533007903
999
+ 20131106, 26.114996295327753
1000
+ 20131107, 42.38168293816829
1001
+ 20131110, 136.3468579294514
1002
+ 20131112, 94.41665213854021
1003
+ 20131115, 39.82262229050443
1004
+ 20131116, 78.61473497065319
1005
+ 20131117, 48.18344810553231
1006
+ 20131118, 29.973857907659227
1007
+ 20131120, 20.604020114191066
1008
+ 20131121, 51.136010503835415
1009
+ 20131122, 66.37020844229428
1010
+ 20131124, 76.96372798988841
1011
+ 20131127, 50.039894416841015
1012
+ 20131128, 75.84964823773826
1013
+ 20131213, 79.16618905450954
1014
+ 20131214, 211.67638525685723
1015
+ 20131215, 420.3117846713737
1016
+ 20131216, 342.96863105096475
1017
+ 20131217, 478.720224241632
1018
+ 20140208, 23.630335454439784
1019
+ 20140209, 75.14158494595539
1020
+ 20140210, 27.744407034518833
1021
+ 20140213, 94.30788241660855
1022
+ 20140219, 172.0230234483961
1023
+ 20140222, 69.79637305323105
1024
+ 20140309, 65.87043417305905
1025
+ 20140310, 64.25986641387725
1026
+ 20140311, 24.917366922361698
1027
+ 20140312, 29.19463512610414
1028
+ 20140313, 25.458290293758722
1029
+ 20140314, 31.69204403475128
1030
+ 20140318, 23.423604755055795
1031
+ 20140319, 39.21903130084844
1032
+ 20140327, 23.41754380230125
1033
+ 20140329, 91.91434470885638
1034
+ 20140330, 292.9196107914923
1035
+ 20140331, 253.80143901092512
1036
+ 20140401, 106.97886139440962
1037
+ 20140402, 181.600617990179
1038
+ 20140403, 252.41761190289398
1039
+ 20140404, 20.795628849953516
1040
+ 20140405, 21.720432357043236
1041
+ 20140406, 300.43859778155513
1042
+ 20140408, 78.17030487273361
1043
+ 20140409, 32.890573061947926
1044
+ 20140410, 22.470932850999535
1045
+ 20140413, 25.260595907426776
1046
+ 20140418, 20.41067762813808
1047
+ 20140426, 113.97868723849369
1048
+ 20140427, 44.45263358612272
1049
+ 20140429, 52.85484658298466
1050
+ 20140430, 126.49033480067413
1051
+ 20140501, 128.88813433432125
1052
+ 20140502, 35.50418646123896
1053
+ 20140503, 68.03873616195956
1054
+ 20140504, 124.95661828800556
1055
+ 20140505, 174.92232119508373
1056
+ 20140506, 88.36784852975359
1057
+ 20140507, 253.31558577405852
1058
+ 20140508, 281.3919763217787
1059
+ 20140509, 480.06120300151093
1060
+ 20140510, 230.49119141533015
1061
+ 20140511, 205.4820327028126
1062
+ 20140512, 156.76944699754904
1063
+ 20140513, 81.30908298465832
1064
+ 20140514, 21.89773052504649
1065
+ 20140515, 39.947487178928405
1066
+ 20140516, 88.09127694967457
1067
+ 20140517, 183.43590390370755
1068
+ 20140518, 105.40262759327061
1069
+ 20140519, 112.52293954846583
1070
+ 20140520, 143.1907039966876
1071
+ 20140521, 123.768075713331
1072
+ 20140522, 146.64557309826824
1073
+ 20140523, 194.98976075807767
1074
+ 20140524, 40.74365030073222
1075
+ 20140526, 36.89792792596466
1076
+ 20140527, 25.41117921170386
1077
+ 20140528, 32.128269373256636
1078
+ 20140529, 42.571904237854476
1079
+ 20140530, 30.894538332461654
1080
+ 20140603, 119.3601629692585
1081
+ 20140604, 28.660918250232445
1082
+ 20140605, 106.6490854544398
1083
+ 20140606, 169.37256871803808
1084
+ 20140607, 154.40959328509996
1085
+ 20140608, 53.36430057967224
1086
+ 20140609, 67.63162573367039
1087
+ 20140610, 89.99786563662249
1088
+ 20140611, 37.91046715190608
1089
+ 20140612, 36.04271668700604
1090
+ 20140613, 22.212055620932123
1091
+ 20140614, 60.21817922623198
1092
+ 20140615, 160.63765908298467
1093
+ 20140616, 54.423535564853566
1094
+ 20140617, 58.11623173088098
1095
+ 20140618, 38.36407049046953
1096
+ 20140619, 74.36847196798001
1097
+ 20140620, 77.7738283066016
1098
+ 20140621, 119.29349666579498
1099
+ 20140622, 187.1327440362041
1100
+ 20140623, 111.1517991268596
1101
+ 20140624, 65.32679131799162
1102
+ 20140625, 71.67983804771039
1103
+ 20140626, 23.243867677824277
1104
+ 20140629, 50.502375711878194
1105
+ 20140630, 44.06397496803812
1106
+ 20140701, 70.78839747501164
1107
+ 20140702, 21.4579365338796
1108
+ 20140705, 20.564100636331936
1109
+ 20140706, 59.28443547332637
1110
+ 20140707, 55.36132866980474
1111
+ 20140708, 94.93222883251977
1112
+ 20140709, 24.205018559681545
1113
+ 20140710, 65.32519667451186
1114
+ 20140711, 57.94388838040446
1115
+ 20140712, 61.87451693979545
1116
+ 20140713, 36.246024196594604
1117
+ 20140714, 24.11177268566946
1118
+ 20140717, 86.05393040301021
1119
+ 20140718, 170.54987887174568
1120
+ 20140719, 70.18380277341934
1121
+ 20140720, 21.389537206531852
1122
+ 20140721, 24.904192817294284
1123
+ 20140723, 33.28208209844259
1124
+ 20140724, 98.2812189461297
1125
+ 20140725, 66.55174991283123
1126
+ 20140726, 89.153669804742
1127
+ 20140727, 55.76670643741284
1128
+ 20140801, 77.50097955892608
1129
+ 20140802, 47.23127941945607
1130
+ 20140803, 24.28849445025569
1131
+ 20140804, 27.641382823396093
1132
+ 20140805, 59.530193986808456
1133
+ 20140806, 107.75092253602975
1134
+ 20140807, 94.35307978120642
1135
+ 20140808, 57.08875014528127
1136
+ 20140809, 76.8537945650279
1137
+ 20140810, 34.28849717427941
1138
+ 20140811, 51.66818957752208
1139
+ 20140812, 197.83517677097862
1140
+ 20140813, 357.7546878632031
1141
+ 20140814, 122.80910241602739
1142
+ 20140818, 40.27715887959089
1143
+ 20140819, 327.54424032426783
1144
+ 20140820, 286.2331065129592
1145
+ 20140821, 63.01878777312877
1146
+ 20140822, 32.8901639135867
1147
+ 20140825, 21.239476006799162
1148
+ 20140826, 28.87913172375236
1149
+ 20140827, 124.34361852045556
1150
+ 20140828, 78.05018559681542
1151
+ 20140829, 23.721888438516967
1152
+ 20140830, 34.31858492416317
1153
+ 20140831, 62.38237102655741
1154
+ 20140903, 25.001808933344957
1155
+ 20140904, 34.299766278765695
1156
+ 20140905, 64.1351253777313
1157
+ 20140906, 40.08666572233845
1158
+ 20140907, 86.35310429741982
1159
+ 20140908, 49.52347545473036
1160
+ 20140909, 29.29712361256393
1161
+ 20140910, 37.02630753138075
1162
+ 20140911, 40.57987055439331
1163
+ 20140912, 80.21033512755693
1164
+ 20140913, 45.57955347803347
1165
+ 20140914, 56.27964918206646
1166
+ 20140915, 131.16330304218968
1167
+ 20140916, 303.1881555817061
1168
+ 20140917, 76.85736303608788
1169
+ 20140918, 51.148375755462574
1170
+ 20140919, 67.16068652661554
1171
+ 20140920, 24.000558606462118
1172
+ 20140923, 21.921376866864247
1173
+ 20140924, 25.080225948686653
1174
+ 20140925, 35.09464892782425
1175
+ 20140926, 39.96193594549049
1176
+ 20140927, 43.513512792015334
1177
+ 20140928, 48.600577129823336
1178
+ 20140929, 65.18053576098326
1179
+ 20140930, 78.84402938258636
1180
+ 20141001, 72.51337659082985
1181
+ 20141002, 50.72223438952814
1182
+ 20141003, 76.47042702497669
1183
+ 20141004, 126.25225493843371
1184
+ 20141005, 64.90787206531846
1185
+ 20141011, 20.986212263191543
1186
+ 20141012, 20.71796657078103
1187
+ 20141016, 23.756614474372387
1188
+ 20141017, 35.27394997966062
1189
+ 20141018, 23.492493317061832
1190
+ 20141019, 20.595106107887954
1191
+ 20141021, 40.95922790271968
1192
+ 20141022, 23.438181187529064
1193
+ 20141023, 43.47170483931891
1194
+ 20141024, 46.727642645599914
1195
+ 20141025, 51.37816004910506
1196
+ 20141026, 94.35662713001871
1197
+ 20141027, 68.03675961035566
1198
+ 20141028, 21.424196957810324
1199
+ 20141030, 24.947406410480855
1200
+ 20141031, 34.73852623053231
1201
+ 20141101, 42.339237927126916
1202
+ 20141104, 21.830416267143192
1203
+ 20141106, 43.89685048378662
1204
+ 20141107, 69.33576352568573
1205
+ 20141108, 87.40360715219663
1206
+ 20141110, 29.08102771966528
1207
+ 20141111, 37.43220394928804
1208
+ 20141112, 29.107394453161323
1209
+ 20141113, 20.995226239249185
1210
+ 20141114, 23.49454341730591
1211
+ 20141115, 26.82312696129707
1212
+ 20141116, 29.49105357973036
1213
+ 20141201, 33.785750450371914
1214
+ 20141202, 43.8162371980042
1215
+ 20141203, 21.638188269990703
1216
+ 20141204, 52.66889273884239
1217
+ 20141205, 20.34541219926778
1218
+ 20141207, 22.521633833100882
1219
+ 20141218, 25.065987839958158
1220
+ 20141219, 117.93202650852008
1221
+ 20141225, 87.132361220072
1222
+ 20141226, 20.54271959263134
1223
+ 20141227, 48.39155815972224
1224
+ 20141228, 34.44478494740817
data/HKO-7/samplers/hko7_cloudy_days_t20_train.txt.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9cdd5bf19545b3138be01509547b0360457d3a4a9273e16a333fe62a4a212d16
3
+ size 2916002
data/SEVIR/CATALOG.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3209386cde96ffa80ccec3c1919090ffc601333cc651116f9fc48f224a5c2f57
3
+ size 33838047
data/__pycache__/config.cpython-39.pyc ADDED
Binary file (1.42 kB). View file
 
data/__pycache__/dutils.cpython-38.pyc ADDED
Binary file (35.2 kB). View file
 
data/__pycache__/dutils.cpython-39.pyc ADDED
Binary file (35.3 kB). View file
 
data/__pycache__/loader.cpython-38.pyc ADDED
Binary file (1.59 kB). View file
 
data/__pycache__/loader.cpython-39.pyc ADDED
Binary file (1.61 kB). View file
 
data/config.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from data import dutils
2
+
3
# Evaluation presets. Each preset has three parts:
#   'meta'   : what the evaluation script reports (dataset name, context length
#              `seq_len`, forecast horizon `out_len`, and metric names; the
#              `csi-<t>` / `hss-<t>` suffixes are pixel thresholds).
#   'param'  : keyword arguments forwarded to the corresponding data loader
#              in `data.dutils`.
#   'savedir': subdirectory used for outputs of this configuration.
SEVIR_13_12 = {
    'meta': {
        'dataset': 'SEVIR',
        'seq_len': 13,
        'out_len': 12,
        'metrics': ['mae', 'mse', 'ssim', 'psnr', 'lpips64',
                    'csi-16', 'csi-74', 'csi-133', 'csi-160', 'csi-181', 'csi-219',
                    'csi_4-16', 'csi_4-74', 'csi_4-133', 'csi_4-160', 'csi_4-181', 'csi_4-219',
                    'csi_16-16', 'csi_16-74', 'csi_16-133', 'csi_16-160', 'csi_16-181', 'csi_16-219',
                    'hss-16', 'hss-74', 'hss-133', 'hss-160', 'hss-181', 'hss-219' ],
    },
    'param': {
        'seq_len': 25,
        'data_types': ['vil'],
        'sample_mode': 'sequent',
        'layout': 'NTCHW',
        'raw_seq_len': 25,
        # Events after this date form the test split (see dutils).
        'start_date': dutils.SEVIR_TRAIN_TEST_SPLIT_DATE,
        'end_date': None,
    },
    'savedir': 'sevir'
}

HKO7_5_20 = {
    'meta': {
        'dataset': 'HKO-7',
        'seq_len': 5,
        'out_len': 20,
        'metrics': ['mae', 'mse', 'ssim', 'psnr', 'lpips64',
                    'csi-84', 'csi-117', 'csi-140', 'csi-158', 'csi-185',
                    'csi_4-84', 'csi_4-117', 'csi_4-140', 'csi_4-158', 'csi_4-185',
                    'csi_16-84', 'csi_16-117', 'csi_16-140', 'csi_16-158', 'csi_16-185',
                    'hss-84', 'hss-117', 'hss-140', 'hss-158', 'hss-185'],
    },
    'param': {
        # Pre-built sampler of rainy test days (pickled day list shipped in repo).
        'pd_path': 'data/HKO-7/samplers/hko7_cloudy_days_t20_test.txt.pkl',
        'sample_mode': 'sequent',
        'seq_len': 25,
        'stride': 13,
    },
    'savedir': 'hko-7'
}

METEONET_5_20 = {
    'meta': {
        'dataset': 'meteonet',
        'seq_len': 5,
        'out_len': 20,
        'metrics': ['mae', 'mse', 'ssim', 'psnr', 'lpips64',
                    'csi-44', 'csi-64', 'csi-87', 'csi-117',
                    'csi_4-44', 'csi_4-64', 'csi_4-87', 'csi_4-117',
                    'csi_16-44', 'csi_16-64', 'csi_16-87', 'csi_16-117',
                    'hss-44', 'hss-64', 'hss-87', 'hss-117']
    },
    'param': {
        'img_size': 128,
        'in_len': 5,
    },
    'savedir': 'meteonet'
}
data/dutils.py ADDED
@@ -0,0 +1,1212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ dutils.py
3
+ A utility library for customized data loading functions
4
+ '''
5
+ import os
6
+ import gzip
7
+ import numpy as np
8
+ import pandas as pd
9
+
10
+ import os
11
+ import cv2
12
+ from typing import List, Union, Dict, Sequence
13
+ import numpy as np
14
+ import numpy.random as nprand
15
+ import datetime
16
+ import pandas as pd
17
+ import h5py
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from torch.nn.functional import avg_pool2d
21
+ import random
22
+ from torchvision import transforms as T
23
+ from torchvision import datasets
24
+ from torch.utils.data import Dataset, DataLoader
25
+ from PIL import Image
26
+
27
+ SEVIR_ROOT_DIR = "data/SEVIR"
28
+ METEO_FILE_DIR = "data/meteonet"
29
+
30
def resize(seq, size):
    """Bilinearly resize a (B, T, 1, H, W) sequence to (B, T, 1, size, size).

    The singleton channel axis is dropped so the batch/time axes can be fed to
    ``F.interpolate`` as (N, C, H, W), then restored afterwards. Output values
    are clamped to [0, 1].
    """
    frames = seq.squeeze(dim=2)  # (B, T, H, W): time acts as the channel dim
    frames = F.interpolate(frames, size=size, mode='bilinear', align_corners=False)
    return frames.clamp(0, 1).unsqueeze(2)  # back to (B, T, 1, H, W)
35
+
36
+ # =====================================================================================
37
+ # HKO-7 data
38
+ # =====================================================================================
39
def pixel_to_dBZ_nonlinear(img):
    '''
    [0, 255] OR [0, 1] pixel => dBZ via the nonlinear (tan-based) HKO-7 curve.

    A mean above 1.0 is taken as evidence that the input is 8-bit encoded,
    in which case it is first rescaled to [0, 1].
    '''
    if img.mean() > 1.0:  # heuristic: looks like a [0, 255] image
        img = img / 255.0
    atan_lo, atan_hi = -1.482, 1.412
    shift, factor = 31.0, 4.0
    stretched = img * (atan_hi - atan_lo) + atan_lo
    return np.tan(stretched) * factor + shift
51
+
52
def dbZ_to_pixel_nonlinear(dbZ):
    '''
    dBZ => [0, 1] pixel via the nonlinear (arctan-based) HKO-7 curve.

    Exact inverse of `pixel_to_dBZ_nonlinear` (without the 8-bit rescale step).
    '''
    atan_lo, atan_hi = -1.482, 1.412
    shift, factor = 31.0, 4.0
    return (np.arctan((dbZ - shift) / factor) - atan_lo) / (atan_hi - atan_lo)
62
+
63
def dbZ_to_pixel(dbZ):
    '''
    Linear dBZ => [0, 1] pixel.

    The affine map sends dBZ = -10 to 0.0 and dBZ = 60 to 1.0 (the original
    docstring's "[0, 80] dBZ" did not match this formula). The result is
    quantized to 256 levels: np.floor(x + 0.5) is round-half-up on the
    intermediate 8-bit value, so the output is always a multiple of 1/255.
    '''
    return np.floor((dbZ + 10) * 255 / 70 + 0.5) / 255.0
68
+
69
def pixel_to_dBZ(pixel):
    '''
    [0, 255] (or [0, 1]) pixel => linear dBZ.

    Inverse of `dbZ_to_pixel` up to quantization: pixel 0.0 maps to -10 dBZ
    and pixel 1.0 maps to 60 dBZ (the original docstring's "[0, 80] dBZ" did
    not match this formula). A mean above 1.0 is taken as evidence of an
    8-bit-encoded input, which is first rescaled to [0, 1].
    '''
    if pixel.mean() > 1.0:  # heuristic: looks like a [0, 255] image
        pixel = pixel / 255.0
    return (70 * pixel) - 10
76
+
77
def nonlinear_to_linear(im):
    """Re-encode nonlinearly encoded pixels as linear pixels, going through dBZ."""
    dbz = pixel_to_dBZ_nonlinear(im)
    return dbZ_to_pixel(dbz)
79
+
80
def nonlinear_to_linear_batched(seq, datetime):
    """Per-sample linearization of a batch of HKO-7 sequences.

    Samples whose first timestamp is in 2016 or later are converted from the
    nonlinear encoding to the linear one; earlier samples are copied as-is.
    The whole result is clipped to [0, 1].
    """
    out = np.zeros_like(seq)
    for idx, (frames, stamps) in enumerate(zip(seq, datetime)):
        out[idx] = nonlinear_to_linear(frames) if stamps[0].year >= 2016 else frames
    return np.clip(out, 0.0, 1.0)
89
+
90
def linear_to_nonlinear(im):
    """Re-encode linearly encoded pixels as nonlinear pixels, going through dBZ."""
    dbz = pixel_to_dBZ(im)
    return dbZ_to_pixel_nonlinear(dbz)
92
+
93
def linear_to_nonlinear_batched(seq, datetime):
    """Per-sample de-linearization of a batch of HKO-7 sequences.

    Samples whose first timestamp predates 2016 are converted from the linear
    encoding to the nonlinear one; later samples are copied as-is. The whole
    result is clipped to [0, 1].
    """
    out = np.zeros_like(seq)
    for idx, (frames, stamps) in enumerate(zip(seq, datetime)):
        out[idx] = linear_to_nonlinear(frames) if stamps[0].year < 2016 else frames
    return np.clip(out, 0.0, 1.0)
102
+
103
+
104
+ # =====================================================================================
105
+ # SEVIR data
106
+ # Code is adapted from https://github.com/MIT-AI-Accelerator/neurips-2020-sevir. Their license is MIT License.
107
+ # (From Earthformer's implementation)
108
+ # =====================================================================================
109
+
110
+
111
+ # SEVIR Dataset constants
112
+ SEVIR_DATA_TYPES = ['vis', 'ir069', 'ir107', 'vil', 'lght']
113
+ SEVIR_RAW_DTYPES = {'vis': np.int16,
114
+ 'ir069': np.int16,
115
+ 'ir107': np.int16,
116
+ 'vil': np.uint8,
117
+ 'lght': np.int16}
118
+ LIGHTING_FRAME_TIMES = np.arange(- 120.0, 125.0, 5) * 60
119
+ SEVIR_DATA_SHAPE = {'lght': (48, 48), }
120
+ PREPROCESS_SCALE_SEVIR = {'vis': 1, # Not utilized in original paper
121
+ 'ir069': 1 / 1174.68,
122
+ 'ir107': 1 / 2562.43,
123
+ 'vil': 1 / 47.54,
124
+ 'lght': 1 / 0.60517}
125
+ PREPROCESS_OFFSET_SEVIR = {'vis': 0, # Not utilized in original paper
126
+ 'ir069': 3683.58,
127
+ 'ir107': 1552.80,
128
+ 'vil': - 33.44,
129
+ 'lght': - 0.02990}
130
+ PREPROCESS_SCALE_01 = {'vis': 1,
131
+ 'ir069': 1,
132
+ 'ir107': 1,
133
+ 'vil': 1 / 255, # currently the only one implemented
134
+ 'lght': 1}
135
+ PREPROCESS_OFFSET_01 = {'vis': 0,
136
+ 'ir069': 0,
137
+ 'ir107': 0,
138
+ 'vil': 0, # currently the only one implemented
139
+ 'lght': 0}
140
+
141
+ # sevir
142
+ SEVIR_CATALOG = os.path.join(SEVIR_ROOT_DIR, "CATALOG.csv")
143
+ SEVIR_DATA_DIR = os.path.join(SEVIR_ROOT_DIR, "data")
144
+ SEVIR_RAW_SEQ_LEN = 49
145
+
146
+ SEVIR_TRAIN_VAL_SPLIT_DATE = datetime.datetime(2019, 1, 1)
147
+ SEVIR_TRAIN_TEST_SPLIT_DATE = datetime.datetime(2019, 6, 1)
148
+
149
def change_layout_np(data,
                     in_layout='NHWT', out_layout='NHWT',
                     ret_contiguous=False):
    """Convert a numpy batch between axis layouts.

    Letters: N batch, T time, C channel (singleton), H height, W width.
    The array is first normalized to 'NHWT', then transposed to `out_layout`.
    5-D input layouts (with 'C') drop the channel axis; 5-D output layouts
    re-insert a singleton channel axis.

    Bug fix: `ret_contiguous=True` previously called the non-existent method
    `data.ascontiguousarray()` and raised AttributeError; it now uses
    `np.ascontiguousarray(data)`.

    Raises NotImplementedError for unsupported layout strings.
    """
    # first convert to 'NHWT'
    if in_layout == 'NHWT':
        pass
    elif in_layout == 'NTHW':
        data = np.transpose(data,
                            axes=(0, 2, 3, 1))
    elif in_layout == 'NWHT':
        data = np.transpose(data,
                            axes=(0, 2, 1, 3))
    elif in_layout == 'NTCHW':
        data = data[:, :, 0, :, :]  # drop singleton channel
        data = np.transpose(data,
                            axes=(0, 2, 3, 1))
    elif in_layout == 'NTHWC':
        data = data[:, :, :, :, 0]
        data = np.transpose(data,
                            axes=(0, 2, 3, 1))
    elif in_layout == 'NTWHC':
        data = data[:, :, :, :, 0]
        data = np.transpose(data,
                            axes=(0, 3, 2, 1))
    elif in_layout == 'TNHW':
        data = np.transpose(data,
                            axes=(1, 2, 3, 0))
    elif in_layout == 'TNCHW':
        data = data[:, :, 0, :, :]
        data = np.transpose(data,
                            axes=(1, 2, 3, 0))
    else:
        raise NotImplementedError

    if out_layout == 'NHWT':
        pass
    elif out_layout == 'NTHW':
        data = np.transpose(data,
                            axes=(0, 3, 1, 2))
    elif out_layout == 'NWHT':
        data = np.transpose(data,
                            axes=(0, 2, 1, 3))
    elif out_layout == 'NTCHW':
        data = np.transpose(data,
                            axes=(0, 3, 1, 2))
        data = np.expand_dims(data, axis=2)  # re-insert singleton channel
    elif out_layout == 'NTHWC':
        data = np.transpose(data,
                            axes=(0, 3, 1, 2))
        data = np.expand_dims(data, axis=-1)
    elif out_layout == 'NTWHC':
        data = np.transpose(data,
                            axes=(0, 3, 2, 1))
        data = np.expand_dims(data, axis=-1)
    elif out_layout == 'TNHW':
        data = np.transpose(data,
                            axes=(3, 0, 1, 2))
    elif out_layout == 'TNCHW':
        data = np.transpose(data,
                            axes=(3, 0, 1, 2))
        data = np.expand_dims(data, axis=2)
    else:
        raise NotImplementedError
    if ret_contiguous:
        # np.transpose returns a view; force a C-contiguous copy when asked.
        data = np.ascontiguousarray(data)
    return data
215
+
216
def change_layout_torch(data,
                        in_layout='NHWT', out_layout='NHWT',
                        ret_contiguous=False):
    """Convert a torch batch between axis layouts.

    Letters: N batch, T time, C channel (singleton), H height, W width.
    The tensor is first normalized to 'NHWT', then permuted to `out_layout`.
    5-D input layouts (with 'C') drop the channel axis; 5-D output layouts
    re-insert a singleton channel axis.

    Raises NotImplementedError for unsupported layout strings.
    """
    # normalize to 'NHWT'
    if in_layout == 'NHWT':
        pass
    elif in_layout == 'NTHW':
        data = data.permute(0, 2, 3, 1)
    elif in_layout == 'NTCHW':
        data = data[:, :, 0, :, :].permute(0, 2, 3, 1)
    elif in_layout == 'NTHWC':
        data = data[:, :, :, :, 0].permute(0, 2, 3, 1)
    elif in_layout == 'TNHW':
        data = data.permute(1, 2, 3, 0)
    elif in_layout == 'TNCHW':
        data = data[:, :, 0, :, :].permute(1, 2, 3, 0)
    else:
        raise NotImplementedError

    # permute 'NHWT' into the requested output layout
    if out_layout == 'NHWT':
        pass
    elif out_layout == 'NTHW':
        data = data.permute(0, 3, 1, 2)
    elif out_layout == 'NTCHW':
        data = data.permute(0, 3, 1, 2).unsqueeze(dim=2)
    elif out_layout == 'NTHWC':
        data = data.permute(0, 3, 1, 2).unsqueeze(dim=-1)
    elif out_layout == 'TNHW':
        data = data.permute(3, 0, 1, 2)
    elif out_layout == 'TNCHW':
        data = data.permute(3, 0, 1, 2).unsqueeze(dim=2)
    else:
        raise NotImplementedError

    if ret_contiguous:
        data = data.contiguous()
    return data
258
+
259
+ class SEVIRDataLoader:
260
+ r"""
261
+ DataLoader that loads SEVIR sequences, and spilts each event
262
+ into segments according to specified sequence length.
263
+
264
+ Event Frames:
265
+ [-----------------------raw_seq_len----------------------]
266
+ [-----seq_len-----]
267
+ <--stride-->[-----seq_len-----]
268
+ <--stride-->[-----seq_len-----]
269
+ ...
270
+ """
271
    def __init__(self,
                 data_types: Sequence[str] = None,
                 seq_len: int = 49,
                 raw_seq_len: int = 49,
                 sample_mode: str = 'sequent',
                 stride: int = 12,
                 batch_size: int = 1,
                 layout: str = 'NHWT',
                 num_shard: int = 1,
                 rank: int = 0,
                 split_mode: str = "uneven",
                 sevir_catalog: Union[str, pd.DataFrame] = None,
                 sevir_data_dir: str = None,
                 start_date: datetime.datetime = None,
                 end_date: datetime.datetime = None,
                 datetime_filter=None,
                 catalog_filter='default',
                 shuffle: bool = False,
                 shuffle_seed: int = 1,
                 output_type=np.float32,
                 preprocess: bool = True,
                 rescale_method: str = '01',
                 downsample_dict: Dict[str, Sequence[int]] = None,
                 verbose: bool = False):
        r"""
        Parameters
        ----------
        data_types
            A subset of SEVIR_DATA_TYPES.
        seq_len
            The length of the data sequences. Should be smaller than the max length raw_seq_len.
        raw_seq_len
            The length of the raw data sequences.
        sample_mode
            'random' or 'sequent'
        stride
            Useful when sample_mode == 'sequent'
            stride must not be smaller than out_len to prevent data leakage in testing.
        batch_size
            Number of sequences in one batch.
        layout
            str: consists of batch_size 'N', seq_len 'T', channel 'C', height 'H', width 'W'
            The layout of sampled data. Raw data layout is 'NHWT'.
            valid layout: 'NHWT', 'NTHW', 'NTCHW', 'TNHW', 'TNCHW'.
        num_shard
            Split the whole dataset into num_shard parts for distributed training.
        rank
            Rank of the current process within num_shard.
        split_mode: str
            if 'ceil', all `num_shard` dataloaders have the same length = ceil(total_len / num_shard).
            Different dataloaders may have some duplicated data batches, if the total size of datasets is not divided by num_shard.
            if 'floor', all `num_shard` dataloaders have the same length = floor(total_len / num_shard).
            The last several data batches may be wasted, if the total size of datasets is not divided by num_shard.
            if 'uneven', the last datasets has larger length when the total length is not divided by num_shard.
            The uneven split leads to synchronization error in dist.all_reduce() or dist.barrier().
            See related issue: https://github.com/pytorch/pytorch/issues/33148
            Notice: this also affects the behavior of `self.use_up`.
        sevir_catalog
            Name of SEVIR catalog CSV file.
        sevir_data_dir
            Directory path to SEVIR data.
        start_date
            Start time of SEVIR samples to generate.
        end_date
            End time of SEVIR samples to generate.
        datetime_filter
            function
            Mask function applied to time_utc column of catalog (return true to keep the row).
            Pass function of the form lambda t : COND(t)
            Example: lambda t: np.logical_and(t.dt.hour>=13,t.dt.hour<=21)  # Generate only day-time events
        catalog_filter
            function or None or 'default'
            Mask function applied to entire catalog dataframe (return true to keep row).
            Pass function of the form lambda catalog: COND(catalog)
            Example: lambda c: [s[0]=='S' for s in c.id]  # Generate only the 'S' events
        shuffle
            bool, If True, data samples are shuffled before each epoch.
        shuffle_seed
            int, Seed to use for shuffling.
        output_type
            np.dtype, dtype of generated tensors
        preprocess
            bool, If True, self.preprocess_data_dict(data_dict) is called before each sample generated
        downsample_dict:
            dict, downsample_dict.keys() == data_types. downsample_dict[key] is a Sequence of (t_factor, h_factor, w_factor),
            representing the downsampling factors of all dimensions.
        verbose
            bool, verbose when opening raw data files
        """
        super(SEVIRDataLoader, self).__init__()
        # Fall back to the module-level default paths / type list when omitted.
        if sevir_catalog is None:
            sevir_catalog = SEVIR_CATALOG
        if sevir_data_dir is None:
            sevir_data_dir = SEVIR_DATA_DIR
        if data_types is None:
            data_types = SEVIR_DATA_TYPES
        else:
            assert set(data_types).issubset(SEVIR_DATA_TYPES)

        # configs which should not be modified
        self._dtypes = SEVIR_RAW_DTYPES
        self.lght_frame_times = LIGHTING_FRAME_TIMES
        self.data_shape = SEVIR_DATA_SHAPE

        self.raw_seq_len = raw_seq_len
        assert seq_len <= self.raw_seq_len, f'seq_len must not be larger than raw_seq_len = {raw_seq_len}, got {seq_len}.'
        self.seq_len = seq_len
        assert sample_mode in ['random', 'sequent'], f'Invalid sample_mode = {sample_mode}, must be \'random\' or \'sequent\'.'
        self.sample_mode = sample_mode
        self.stride = stride
        self.batch_size = batch_size
        valid_layout = ('NHWT', 'NTHW', 'NTCHW', 'NTHWC', 'TNHW', 'TNCHW')
        if layout not in valid_layout:
            raise ValueError(f'Invalid layout = {layout}! Must be one of {valid_layout}.')
        self.layout = layout
        self.num_shard = num_shard
        self.rank = rank
        valid_split_mode = ('ceil', 'floor', 'uneven')
        if split_mode not in valid_split_mode:
            raise ValueError(f'Invalid split_mode: {split_mode}! Must be one of {valid_split_mode}.')
        self.split_mode = split_mode
        self._samples = None
        self._hdf_files = {}
        self.data_types = data_types
        # Accept either a path to the catalog CSV or a pre-loaded DataFrame.
        if isinstance(sevir_catalog, str):
            self.catalog = pd.read_csv(sevir_catalog, parse_dates=['time_utc'], low_memory=False)
        else:
            self.catalog = sevir_catalog
        self.sevir_data_dir = sevir_data_dir
        self.datetime_filter = datetime_filter
        self.catalog_filter = catalog_filter
        self.start_date = start_date
        self.end_date = end_date
        self.shuffle = shuffle
        self.shuffle_seed = int(shuffle_seed)
        self.output_type = output_type
        self.preprocess = preprocess
        self.downsample_dict = downsample_dict
        self.rescale_method = rescale_method
        self.verbose = verbose

        # Narrow the catalog: date window is (start_date, end_date] on time_utc.
        if self.start_date is not None:
            self.catalog = self.catalog[self.catalog.time_utc > self.start_date]
        if self.end_date is not None:
            self.catalog = self.catalog[self.catalog.time_utc <= self.end_date]
        if self.datetime_filter:
            self.catalog = self.catalog[self.datetime_filter(self.catalog.time_utc)]

        if self.catalog_filter is not None:
            # 'default' keeps only events with no missing frames.
            if self.catalog_filter == 'default':
                self.catalog_filter = lambda c: c.pct_missing == 0
            self.catalog = self.catalog[self.catalog_filter(self.catalog)]

        self._compute_samples()
        self._open_files(verbose=self.verbose)
        self.reset()
427
+
428
    def _compute_samples(self):
        """
        Computes the list of samples in catalog to be used. This sets self._samples
        """
        # locate all events containing colocated data_types
        imgt = self.data_types
        imgts = set(imgt)
        filtcat = self.catalog[ np.logical_or.reduce([self.catalog.img_type==i for i in imgt]) ]
        # remove rows missing one or more requested img_types
        filtcat = filtcat.groupby('id').filter(lambda x: imgts.issubset(set(x['img_type'])))
        # If there are repeated IDs, remove them (this is a bug in SEVIR)
        # TODO: is it necessary to keep one of them instead of deleting them all
        filtcat = filtcat.groupby('id').filter(lambda x: x.shape[0]==len(imgt))
        # one row per event id, with per-type filename/index columns
        self._samples = filtcat.groupby('id').apply(lambda df: self._df_to_series(df,imgt) )
        if self.shuffle:
            self.shuffle_samples()

    def shuffle_samples(self):
        # Deterministic reshuffle: same shuffle_seed => same event order.
        self._samples = self._samples.sample(frac=1, random_state=self.shuffle_seed)

    def _df_to_series(self, df, imgt):
        """Collapse one event's catalog rows (one per img_type) into a one-row frame
        with '<type>_filename' and '<type>_index' columns."""
        d = {}
        df = df.set_index('img_type')
        for i in imgt:
            s = df.loc[i]
            # lightning ('lght') data is keyed by event id, not by file index
            idx = s.file_index if i != 'lght' else s.id
            d.update({f'{i}_filename': [s.file_name],
                      f'{i}_index': [idx]})

        return pd.DataFrame(d)
458
+
459
    def _open_files(self, verbose=True):
        """
        Opens HDF files
        """
        imgt = self.data_types
        hdf_filenames = []
        # collect the distinct HDF5 files referenced by the sample table
        for t in imgt:
            hdf_filenames += list(np.unique( self._samples[f'{t}_filename'].values ))
        self._hdf_files = {}
        for f in hdf_filenames:
            if verbose:
                print('Opening HDF5 file for reading', f)
            # handles stay open for the loader's lifetime; see close()
            self._hdf_files[f] = h5py.File(self.sevir_data_dir + '/' + f, 'r')

    def close(self):
        """
        Closes all open file handles
        """
        for f in self._hdf_files:
            self._hdf_files[f].close()
        self._hdf_files = {}
480
+
481
    @property
    def num_seq_per_event(self):
        # Number of seq_len-windows obtainable from one raw event at this stride.
        return 1 + (self.raw_seq_len - self.seq_len) // self.stride

    @property
    def total_num_seq(self):
        """
        The total number of sequences within each shard.
        Notice that it is not the product of `self.num_seq_per_event` and `self.total_num_event`.
        """
        return int(self.num_seq_per_event * self.num_event)

    @property
    def total_num_event(self):
        """
        The total number of events in the whole dataset, before split into different shards.
        """
        return int(self._samples.shape[0])

    @property
    def start_event_idx(self):
        """
        The event idx used in certain rank should satisfy event_idx >= start_event_idx
        """
        return self.total_num_event // self.num_shard * self.rank

    @property
    def end_event_idx(self):
        """
        The event idx used in certain rank should satisfy event_idx < end_event_idx

        """
        if self.split_mode == 'ceil':
            # every shard gets the size of the largest (last) shard
            _last_start_event_idx = self.total_num_event // self.num_shard * (self.num_shard - 1)
            _num_event = self.total_num_event - _last_start_event_idx
            return self.start_event_idx + _num_event
        elif self.split_mode == 'floor':
            return self.total_num_event // self.num_shard * (self.rank + 1)
        else:  # self.split_mode == 'uneven':
            if self.rank == self.num_shard - 1:  # the last process
                return self.total_num_event
            else:
                return self.total_num_event // self.num_shard * (self.rank + 1)

    @property
    def num_event(self):
        """
        The number of events split into each rank
        """
        return self.end_event_idx - self.start_event_idx
531
+
532
    def _read_data(self, row, data):
        """
        Iteratively read data into data dict. Finally data[imgt] gets shape (batch_size, height, width, raw_seq_len).

        Parameters
        ----------
        row
            A series with fields IMGTYPE_filename, IMGTYPE_index, IMGTYPE_time_index.
        data
            Dict, data[imgt] is a data tensor with shape = (tmp_batch_size, height, width, raw_seq_len).

        Returns
        -------
        data
            Updated data. Updated shape = (tmp_batch_size + 1, height, width, raw_seq_len).
        """
        # recover the data types from the column names ('vil_filename' -> 'vil')
        imgtyps = np.unique([x.split('_')[0] for x in list(row.keys())])
        for t in imgtyps:
            fname = row[f'{t}_filename']
            idx = row[f'{t}_index']
            t_slice = slice(0, None)
            # Need to bin lght counts into grid
            if t == 'lght':
                lght_data = self._hdf_files[fname][idx][:]
                data_i = self._lght_to_grid(lght_data, t_slice)
            else:
                # idx:idx+1 keeps a leading batch axis of size 1
                data_i = self._hdf_files[fname][t][idx:idx + 1, :, :, t_slice]
            data[t] = np.concatenate((data[t], data_i), axis=0) if (t in data) else data_i

        return data
562
+
563
    def _lght_to_grid(self, data, t_slice=slice(0, None)):
        """
        Converts Nx5 lightning data matrix into a 2D grid of pixel counts
        """
        # out_size = (48,48,len(self.lght_frame_times)-1) if isinstance(t_slice,(slice,)) else (48,48)
        out_size = (*self.data_shape['lght'], len(self.lght_frame_times)) if t_slice.stop is None else (*self.data_shape['lght'], 1)
        if data.shape[0] == 0:
            # no strikes at all -> empty grid
            return np.zeros((1,) + out_size, dtype=np.float32)

        # filter out points outside the grid
        x, y = data[:, 3], data[:, 4]
        m = np.logical_and.reduce([x >= 0, x < out_size[0], y >= 0, y < out_size[1]])
        data = data[m, :]
        if data.shape[0] == 0:
            return np.zeros((1,) + out_size, dtype=np.float32)

        # Filter/separate times
        t = data[:, 0]
        if t_slice.stop is not None:  # select only one time bin
            if t_slice.stop > 0:
                if t_slice.stop < len(self.lght_frame_times):
                    tm = np.logical_and(t >= self.lght_frame_times[t_slice.stop - 1],
                                        t < self.lght_frame_times[t_slice.stop])
                else:
                    tm = t >= self.lght_frame_times[-1]
            else:  # special case: frame 0 uses lght from frame 1
                tm = np.logical_and(t >= self.lght_frame_times[0], t < self.lght_frame_times[1])
            # tm=np.logical_and( (t>=FRAME_TIMES[t_slice],t<FRAME_TIMES[t_slice+1]) )

            data = data[tm, :]
            z = np.zeros(data.shape[0], dtype=np.int64)
        else:  # compute z coordinate based on bin location times
            z = np.digitize(t, self.lght_frame_times) - 1
            z[z == -1] = 0  # special case: frame 0 uses lght from frame 1

        x = data[:, 3].astype(np.int64)
        y = data[:, 4].astype(np.int64)

        # histogram strikes per (y, x, time-bin) cell via flat-index bincount
        k = np.ravel_multi_index(np.array([y, x, z]), out_size)
        n = np.bincount(k, minlength=np.prod(out_size))
        return np.reshape(n, out_size).astype(np.int16)[np.newaxis, :]
604
+
605
    def _old_save_downsampled_dataset(self, save_dir, downsample_dict, verbose=True):
        """
        This method does not save .h5 dataset correctly. There are some batches missed due to unknown error.
        E.g., the first converted .h5 file `SEVIR_VIL_RANDOMEVENTS_2017_0501_0831.h5` only has batch_dim = 1414,
        while it should be 1440 in the original .h5 file.

        NOTE(review): kept for reference only; use `save_downsampled_dataset` instead.
        """
        import os
        from skimage.measure import block_reduce
        assert not os.path.exists(save_dir), f"save_dir {save_dir} already exists!"
        os.makedirs(save_dir)
        sample_counter = 0
        for index, row in self._samples.iterrows():
            if verbose:
                print(f"Downsampling {sample_counter}-th data item.", end='\r')
            for data_type in self.data_types:
                fname = row[f'{data_type}_filename']
                idx = row[f'{data_type}_index']
                t_slice = slice(0, None)
                if data_type == 'lght':
                    lght_data = self._hdf_files[fname][idx][:]
                    data_i = self._lght_to_grid(lght_data, t_slice)
                else:
                    data_i = self._hdf_files[fname][data_type][idx:idx + 1, :, :, t_slice]
                # Downsample t
                t_slice = [slice(None, None), ] * 4
                t_slice[-1] = slice(None, None, downsample_dict[data_type][0])  # layout = 'NHWT'
                data_i = data_i[tuple(t_slice)]
                # Downsample h, w  (max-pooling via block_reduce)
                data_i = block_reduce(data_i,
                                      block_size=(1, *downsample_dict[data_type][1:], 1),
                                      func=np.max)
                # Save as new .h5 file
                new_file_path = os.path.join(save_dir, fname)
                if not os.path.exists(new_file_path):
                    if not os.path.exists(os.path.dirname(new_file_path)):
                        os.makedirs(os.path.dirname(new_file_path))
                    # Create dataset
                    with h5py.File(new_file_path, 'w') as hf:
                        hf.create_dataset(
                            data_type, data=data_i,
                            maxshape=(None, *data_i.shape[1:]))
                else:
                    # Append
                    with h5py.File(new_file_path, 'a') as hf:
                        hf[data_type].resize((hf[data_type].shape[0] + data_i.shape[0]), axis=0)
                        hf[data_type][-data_i.shape[0]:] = data_i

            sample_counter += 1
653
+
654
    def save_downsampled_dataset(self, save_dir, downsample_dict, verbose=True):
        """
        Save a spatio-temporally downsampled copy of all opened HDF5 files.

        Parameters
        ----------
        save_dir
        downsample_dict: Dict[Sequence[int]]
            Notice that this is different from `self.downsample_dict`, which is used during runtime.
        """
        import os
        from skimage.measure import block_reduce
        # NOTE(review): this 3-level relative import assumes a deeper package layout
        # than the flat `data/` directory seen here — verify it resolves.
        from ...utils.utils import path_splitall
        assert not os.path.exists(save_dir), f"save_dir {save_dir} already exists!"
        os.makedirs(save_dir)
        for fname, hdf_file in self._hdf_files.items():
            if verbose:
                print(f"Downsampling data in {fname}.")
            # first path component encodes the data type, e.g. 'vil/...'
            data_type = path_splitall(fname)[0]
            if data_type == 'lght':
                # TODO: how to get idx?
                raise NotImplementedError
                # lght_data = self._hdf_files[fname][idx][:]
                # t_slice = slice(0, None)
                # data_i = self._lght_to_grid(lght_data, t_slice)
            else:
                data_i = self._hdf_files[fname][data_type]
            # Downsample t
            t_slice = [slice(None, None), ] * 4
            t_slice[-1] = slice(None, None, downsample_dict[data_type][0])  # layout = 'NHWT'
            data_i = data_i[tuple(t_slice)]
            # Downsample h, w  (max-pooling via block_reduce)
            data_i = block_reduce(data_i,
                                  block_size=(1, *downsample_dict[data_type][1:], 1),
                                  func=np.max)
            # Save as new .h5 file
            new_file_path = os.path.join(save_dir, fname)
            if not os.path.exists(os.path.dirname(new_file_path)):
                os.makedirs(os.path.dirname(new_file_path))
            # Create dataset
            with h5py.File(new_file_path, 'w') as hf:
                hf.create_dataset(
                    data_type, data=data_i,
                    maxshape=(None, *data_i.shape[1:]))
696
+
697
    @property
    def sample_count(self):
        """
        Record how many times self.__next__() is called.
        """
        return self._sample_count

    def inc_sample_count(self):
        # bump the __next__() call counter
        self._sample_count += 1

    @property
    def curr_event_idx(self):
        # index of the event currently being consumed
        return self._curr_event_idx

    @property
    def curr_seq_idx(self):
        """
        Used only when self.sample_mode == 'sequent'
        """
        return self._curr_seq_idx

    def set_curr_event_idx(self, val):
        self._curr_event_idx = val

    def set_curr_seq_idx(self, val):
        """
        Used only when self.sample_mode == 'sequent'
        """
        self._curr_seq_idx = val
726
+
727
    def reset(self, shuffle: bool = None):
        # Rewind cursors to this shard's start; optionally reshuffle events.
        # shuffle=None means "use the shuffle flag given at construction time".
        self.set_curr_event_idx(val=self.start_event_idx)
        self.set_curr_seq_idx(0)
        self._sample_count = 0
        if shuffle is None:
            shuffle = self.shuffle
        if shuffle:
            self.shuffle_samples()

    def __len__(self):
        """
        Used only when self.sample_mode == 'sequent'
        """
        return self.total_num_seq // self.batch_size
741
+
742
    @property
    def use_up(self):
        """
        Check if dataset is used up in 'sequent' mode.
        """
        if self.sample_mode == 'random':
            # random sampling never exhausts the dataset
            return False
        else:  # self.sample_mode == 'sequent'
            # compute the remaining number of sequences in current event
            curr_event_remain_seq = self.num_seq_per_event - self.curr_seq_idx
            all_remain_seq = curr_event_remain_seq + (
                self.end_event_idx - self.curr_event_idx - 1) * self.num_seq_per_event
            if self.split_mode == "floor":
                # This approach does not cover all available data, but avoid dealing with masks
                return all_remain_seq < self.batch_size
            else:
                return all_remain_seq <= 0
759
+
760
    def _load_event_batch(self, event_idx, event_batch_size):
        """
        Loads a selected batch of events (not batch of sequences) into memory.

        Parameters
        ----------
        event_idx
            Start index into ``self._samples``.
        event_batch_size
            Number of consecutive events to load:
            event_batch[i] = all_type_i_available_events[event_idx:event_idx + event_batch_size]
        Returns
        -------
        event_batch
            list of event batches.
            event_batch[i] is the event batch of the i-th data type.
            Each event_batch[i] is a np.ndarray with shape = (event_batch_size, height, width, raw_seq_len)
        """
        event_idx_slice_end = event_idx + event_batch_size
        pad_size = 0
        # Requests running past the last usable event are zero-padded so the
        # caller always receives `event_batch_size` entries per data type.
        if event_idx_slice_end > self.end_event_idx:
            pad_size = event_idx_slice_end - self.end_event_idx
            event_idx_slice_end = self.end_event_idx
        pd_batch = self._samples.iloc[event_idx:event_idx_slice_end]
        data = {}
        for index, row in pd_batch.iterrows():
            # Accumulate each catalog row's arrays into `data`, keyed by data type.
            data = self._read_data(row, data)
        if pad_size > 0:
            event_batch = []
            for t in self.data_types:
                pad_shape = [pad_size, ] + list(data[t].shape[1:])
                data_pad = np.concatenate((data[t].astype(self.output_type),
                                           np.zeros(pad_shape, dtype=self.output_type)),
                                          axis=0)
                event_batch.append(data_pad)
        else:
            event_batch = [data[t].astype(self.output_type) for t in self.data_types]
        return event_batch
797
    def __iter__(self):
        # The loader is its own iterator; state lives in curr_event_idx/curr_seq_idx.
        return self
800
+ def __next__(self):
801
+ if self.sample_mode == 'random':
802
+ self.inc_sample_count()
803
+ ret_dict = self._random_sample()
804
+ else:
805
+ if self.use_up:
806
+ raise StopIteration
807
+ else:
808
+ self.inc_sample_count()
809
+ ret_dict = self._sequent_sample()
810
+ ret_dict = self.data_dict_to_tensor(data_dict=ret_dict,
811
+ data_types=self.data_types)
812
+ if self.preprocess:
813
+ ret_dict = self.preprocess_data_dict(data_dict=ret_dict,
814
+ data_types=self.data_types,
815
+ layout=self.layout,
816
+ rescale=self.rescale_method)
817
+ if self.downsample_dict is not None:
818
+ ret_dict = self.downsample_data_dict(data_dict=ret_dict,
819
+ data_types=self.data_types,
820
+ factors_dict=self.downsample_dict,
821
+ layout=self.layout)
822
+ return ret_dict
823
+
824
+ def __getitem__(self, index):
825
+ data_dict = self._idx_sample(index=index)
826
+ return data_dict
827
+
828
    @staticmethod
    def preprocess_data_dict(data_dict, data_types=None, layout='NHWT', rescale='01'):
        """
        Rescale the selected entries of `data_dict` and convert them from the
        raw 'NHWT' layout into `layout`, in place.

        Parameters
        ----------
        data_dict: Dict[str, Union[np.ndarray, torch.Tensor]]
        data_types: Sequence[str]
            The data types that we want to rescale. This mainly excludes "mask" from preprocessing.
        layout: str
            consists of batch_size 'N', seq_len 'T', channel 'C', height 'H', width 'W'
        rescale: str
            'sevir': use the offsets and scale factors in original implementation.
            '01': scale all values to range 0 to 1, currently only supports 'vil'
        Returns
        -------
        data_dict: Dict[str, Union[np.ndarray, torch.Tensor]]
            preprocessed data
        """
        if rescale == 'sevir':
            scale_dict = PREPROCESS_SCALE_SEVIR
            offset_dict = PREPROCESS_OFFSET_SEVIR
        elif rescale == '01':
            scale_dict = PREPROCESS_SCALE_01
            offset_dict = PREPROCESS_OFFSET_01
        else:
            raise ValueError(f'Invalid rescale option: {rescale}.')
        if data_types is None:
            data_types = data_dict.keys()
        for key, data in data_dict.items():
            if key in data_types:
                # Rescale as scale * (data + offset), then re-layout.
                if isinstance(data, np.ndarray):
                    data = scale_dict[key] * (
                            data.astype(np.float32) +
                            offset_dict[key])
                    data = change_layout_np(data=data,
                                            in_layout='NHWT',
                                            out_layout=layout)
                elif isinstance(data, torch.Tensor):
                    data = scale_dict[key] * (
                            data.float() +
                            offset_dict[key])
                    data = change_layout_torch(data=data,
                                               in_layout='NHWT',
                                               out_layout=layout)
                data_dict[key] = data
        return data_dict
875
+ @staticmethod
876
+ def process_data_dict_back(data_dict, data_types=None, rescale='01'):
877
+ """
878
+ Parameters
879
+ ----------
880
+ data_dict
881
+ each data_dict[key] is a torch.Tensor.
882
+ rescale
883
+ str:
884
+ 'sevir': data are scaled using the offsets and scale factors in original implementation.
885
+ '01': data are all scaled to range 0 to 1, currently only supports 'vil'
886
+ Returns
887
+ -------
888
+ data_dict
889
+ each data_dict[key] is the data processed back in torch.Tensor.
890
+ """
891
+ if rescale == 'sevir':
892
+ scale_dict = PREPROCESS_SCALE_SEVIR
893
+ offset_dict = PREPROCESS_OFFSET_SEVIR
894
+ elif rescale == '01':
895
+ scale_dict = PREPROCESS_SCALE_01
896
+ offset_dict = PREPROCESS_OFFSET_01
897
+ else:
898
+ raise ValueError(f'Invalid rescale option: {rescale}.')
899
+ if data_types is None:
900
+ data_types = data_dict.keys()
901
+ for key in data_types:
902
+ data = data_dict[key]
903
+ data = data.float() / scale_dict[key] - offset_dict[key]
904
+ data_dict[key] = data
905
+ return data_dict
906
+
907
+ @staticmethod
908
+ def data_dict_to_tensor(data_dict, data_types=None):
909
+ """
910
+ Convert each element in data_dict to torch.Tensor (copy without grad).
911
+ """
912
+ ret_dict = {}
913
+ if data_types is None:
914
+ data_types = data_dict.keys()
915
+ for key, data in data_dict.items():
916
+ if key in data_types:
917
+ if isinstance(data, torch.Tensor):
918
+ ret_dict[key] = data.detach().clone()
919
+ elif isinstance(data, np.ndarray):
920
+ ret_dict[key] = torch.from_numpy(data)
921
+ else:
922
+ raise ValueError(f"Invalid data type: {type(data)}. Should be torch.Tensor or np.ndarray")
923
+ else: # key == "mask"
924
+ ret_dict[key] = data
925
+ return ret_dict
926
+
927
    @staticmethod
    def downsample_data_dict(data_dict, data_types=None, factors_dict=None, layout='NHWT'):
        """
        Temporally stride and spatially average-pool selected entries.

        Parameters
        ----------
        data_dict: Dict[str, Union[np.array, torch.Tensor]]
        factors_dict: Optional[Dict[str, Sequence[int]]]
            each element `factors` is a Sequence of int, representing (t_factor, h_factor, w_factor)

        Returns
        -------
        downsampled_data_dict: Dict[str, torch.Tensor]
            Modify on a deep copy of data_dict instead of directly modifying the original data_dict
        """
        if factors_dict is None:
            factors_dict = {}
        if data_types is None:
            data_types = data_dict.keys()
        downsampled_data_dict = SEVIRDataLoader.data_dict_to_tensor(
            data_dict=data_dict,
            data_types=data_types)  # make a copy
        for key, data in data_dict.items():
            factors = factors_dict.get(key, None)
            if factors is not None:
                # Work in 'NTHW' so t is dim 1 and (h, w) are the last two dims.
                downsampled_data_dict[key] = change_layout_torch(
                    data=downsampled_data_dict[key],
                    in_layout=layout,
                    out_layout='NTHW')
                # downsample t dimension by striding
                t_slice = [slice(None, None), ] * 4
                t_slice[1] = slice(None, None, factors[0])
                downsampled_data_dict[key] = downsampled_data_dict[key][tuple(t_slice)]
                # downsample spatial dimensions by average pooling
                downsampled_data_dict[key] = avg_pool2d(
                    input=downsampled_data_dict[key],
                    kernel_size=(factors[1], factors[2]))

                # restore the caller's layout
                downsampled_data_dict[key] = change_layout_torch(
                    data=downsampled_data_dict[key],
                    in_layout='NTHW',
                    out_layout=layout)

        return downsampled_data_dict
971
+ def _random_sample(self):
972
+ """
973
+ Returns
974
+ -------
975
+ ret_dict
976
+ dict. ret_dict.keys() == self.data_types.
977
+ If self.preprocess == False:
978
+ ret_dict[imgt].shape == (batch_size, height, width, seq_len)
979
+ """
980
+ num_sampled = 0
981
+ event_idx_list = nprand.randint(low=self.start_event_idx,
982
+ high=self.end_event_idx,
983
+ size=self.batch_size)
984
+ seq_idx_list = nprand.randint(low=0,
985
+ high=self.num_seq_per_event,
986
+ size=self.batch_size)
987
+ seq_slice_list = [slice(seq_idx * self.stride,
988
+ seq_idx * self.stride + self.seq_len)
989
+ for seq_idx in seq_idx_list]
990
+ ret_dict = {}
991
+ while num_sampled < self.batch_size:
992
+ event = self._load_event_batch(event_idx=event_idx_list[num_sampled],
993
+ event_batch_size=1)
994
+ for imgt_idx, imgt in enumerate(self.data_types):
995
+ sampled_seq = event[imgt_idx][[0, ], :, :, seq_slice_list[num_sampled]] # keep the dim of batch_size for concatenation
996
+ if imgt in ret_dict:
997
+ ret_dict[imgt] = np.concatenate((ret_dict[imgt], sampled_seq),
998
+ axis=0)
999
+ else:
1000
+ ret_dict.update({imgt: sampled_seq})
1001
+ return ret_dict
1002
+
1003
    def _sequent_sample(self):
        """
        Take the next `self.batch_size` sequences in order, advancing the
        loader's (event, sequence) cursor.

        Returns
        -------
        ret_dict: Dict
            `ret_dict.keys()` contains `self.data_types`.
            `ret_dict["mask"]` is a list of bool, indicating if the data entry is real or padded.
            If self.preprocess == False:
                ret_dict[imgt].shape == (batch_size, height, width, seq_len)
        """
        assert not self.use_up, 'Data loader used up! Reset it to reuse.'
        event_idx = self.curr_event_idx
        seq_idx = self.curr_seq_idx
        num_sampled = 0
        sampled_idx_list = []  # list of (event_idx, seq_idx) records
        while num_sampled < self.batch_size:
            sampled_idx_list.append({'event_idx': event_idx,
                                     'seq_idx': seq_idx})
            seq_idx += 1
            # roll over to the next event once its sequences are exhausted
            if seq_idx >= self.num_seq_per_event:
                event_idx += 1
                seq_idx = 0
            num_sampled += 1

        start_event_idx = sampled_idx_list[0]['event_idx']
        event_batch_size = sampled_idx_list[-1]['event_idx'] - start_event_idx + 1

        event_batch = self._load_event_batch(event_idx=start_event_idx,
                                             event_batch_size=event_batch_size)
        ret_dict = {"mask": []}
        all_no_pad_flag = True
        for sampled_idx in sampled_idx_list:
            batch_slice = [sampled_idx['event_idx'] - start_event_idx, ]  # use [] to keepdim
            seq_slice = slice(sampled_idx['seq_idx'] * self.stride,
                              sampled_idx['seq_idx'] * self.stride + self.seq_len)
            for imgt_idx, imgt in enumerate(self.data_types):
                sampled_seq = event_batch[imgt_idx][batch_slice, :, :, seq_slice]
                if imgt in ret_dict:
                    ret_dict[imgt] = np.concatenate((ret_dict[imgt], sampled_seq),
                                                    axis=0)
                else:
                    ret_dict.update({imgt: sampled_seq})
            # add mask: entries drawn from beyond end_event_idx are zero-padding
            no_pad_flag = sampled_idx['event_idx'] < self.end_event_idx
            if not no_pad_flag:
                all_no_pad_flag = False
            ret_dict["mask"].append(no_pad_flag)
        if all_no_pad_flag:
            # if there is no padded data items at all, set `ret_dict["mask"] = None` for convenience.
            ret_dict["mask"] = None
        # update current idx
        self.set_curr_event_idx(event_idx)
        self.set_curr_seq_idx(seq_idx)
        return ret_dict
1058
    def _idx_sample(self, index):
        """
        Stateless random-access sampling: fetch the `index`-th batch of
        sequences, then apply the same tensor-conversion / preprocessing /
        downsampling pipeline as __next__. Unlike _sequent_sample, no padding
        mask is produced.

        Parameters
        ----------
        index
            The index of the batch to sample.
        Returns
        -------
        ret_dict
            dict. ret_dict.keys() == self.data_types.
            If self.preprocess == False:
                ret_dict[imgt].shape == (batch_size, height, width, seq_len)
        """
        # Map a flat batch index to a starting (event, sequence) cursor.
        event_idx = (index * self.batch_size) // self.num_seq_per_event
        seq_idx = (index * self.batch_size) % self.num_seq_per_event
        num_sampled = 0
        sampled_idx_list = []  # list of (event_idx, seq_idx) records
        while num_sampled < self.batch_size:
            sampled_idx_list.append({'event_idx': event_idx,
                                     'seq_idx': seq_idx})
            seq_idx += 1
            if seq_idx >= self.num_seq_per_event:
                event_idx += 1
                seq_idx = 0
            num_sampled += 1

        start_event_idx = sampled_idx_list[0]['event_idx']
        event_batch_size = sampled_idx_list[-1]['event_idx'] - start_event_idx + 1

        event_batch = self._load_event_batch(event_idx=start_event_idx,
                                             event_batch_size=event_batch_size)
        ret_dict = {}
        for sampled_idx in sampled_idx_list:
            batch_slice = [sampled_idx['event_idx'] - start_event_idx, ]  # use [] to keepdim
            seq_slice = slice(sampled_idx['seq_idx'] * self.stride,
                              sampled_idx['seq_idx'] * self.stride + self.seq_len)
            for imgt_idx, imgt in enumerate(self.data_types):
                sampled_seq = event_batch[imgt_idx][batch_slice, :, :, seq_slice]
                if imgt in ret_dict:
                    ret_dict[imgt] = np.concatenate((ret_dict[imgt], sampled_seq),
                                                    axis=0)
                else:
                    ret_dict.update({imgt: sampled_seq})

        ret_dict = self.data_dict_to_tensor(data_dict=ret_dict,
                                            data_types=self.data_types)
        if self.preprocess:
            ret_dict = self.preprocess_data_dict(data_dict=ret_dict,
                                                 data_types=self.data_types,
                                                 layout=self.layout,
                                                 rescale=self.rescale_method)

        if self.downsample_dict is not None:
            ret_dict = self.downsample_data_dict(data_dict=ret_dict,
                                                 data_types=self.data_types,
                                                 factors_dict=self.downsample_dict,
                                                 layout=self.layout)
        return ret_dict
1117
+
1118
class SEVIRDataIterator():
    '''
    A wrapper s.t. it implements the function sample().
    Every arguments in this class will be redirected to the inner SEVIRDataLoader object.
    If you expect a pythonic iterator, use SEVIRDataLoader instead.
    '''
    def __init__(self, **kwargs):
        self.loader = SEVIRDataLoader(**kwargs)
        # Mirror the loader's sampling mode; defaults to 'random'.
        self.sample_mode = kwargs.get('sample_mode', 'random')

    def reset(self):
        """Rewind the wrapped loader."""
        self.loader.reset()

    def sample(self, batch_size=None):
        '''
        Return the next batch, or None once a 'sequent' loader is exhausted.
        The input param batch_size here is not used (the loader's own batch
        size applies).
        '''
        batch = next(self.loader, None)
        if batch is None and self.sample_mode == 'random':
            # A random loader should never run dry: restart it and retry once.
            self.loader.reset()
            batch = next(self.loader, None)
        return batch

    def __len__(self):
        """
        Number of batches. Used only when self.sample_mode == 'sequent'.
        """
        return len(self.loader)
1147
+ # =====================================================================================
1148
+ # MeteoNet data
1149
+ # Reshape it to 256x256, with in_len=4, out_len=10
1150
+ # https://meteofrance.github.io/meteonet/
1151
+ # download from https://meteonet.umr-cnrm.fr/dataset/data/NW/radar/reflectivity_old_product/
1152
+ # =====================================================================================
1153
+
1154
class Meteo(Dataset):
    """
    MeteoNet radar reflectivity dataset backed by a single HDF5 file.

    Each item is a frame sequence (stored as uint8 reflectivity, value range
    0-70 per the inline note below), resized to (img_size, img_size),
    normalized to [0, 1], and split into (first `in_len` frames, rest).
    """
    def __init__(self, data_path, img_size, type='train', trans=None, in_len=-1):
        super().__init__()

        # Maximum reflectivity value; frames are divided by this to land in [0, 1].
        self.pixel_scale = 70.0

        self.data_path = data_path
        self.img_size = img_size
        # NOTE(review): the default in_len=-1 would split at the last frame in
        # __getitem__ — callers are expected to pass an explicit value. Confirm.
        self.in_len = in_len

        assert type in ['train', 'test', 'val']
        # 'val' items are served from the 'test' group of the h5 file.
        self.type = type if type!='val' else 'test'
        with h5py.File(data_path,'r') as f:
            self.all_len = int(f[f'{self.type}_len'][()]) # 10000-3000 for train, 2000 for test, 1000 for val
        if trans is not None:
            self.transform = trans
        else:
            self.transform = T.Compose([
                T.Resize((img_size, img_size)),
                # transforms.ToTensor(),
                # trans.Lambda(lambda x: x/255.0),
                # transforms.Normalize(mean=[0.5], std=[0.5]),
                # trans.RandomCrop(data_config["img_size"]),

            ])

    def __len__(self):
        return self.all_len

    def sample(self):
        # Draw one random item; mirrors the sample() API of the other loaders.
        index = np.random.randint(0, self.all_len)
        return self.__getitem__(index)


    def __getitem__(self, index):
        # The file is opened per item rather than cached — presumably so the
        # dataset is safe with multi-worker DataLoaders; verify before changing.
        with h5py.File(self.data_path,'r') as f:
            imgs = f[self.type][str(index)][()] # numpy array: (25, 565, 784), dtype=uint8, range(0,70)

        frames = torch.from_numpy(imgs).float().squeeze()
        frames = frames / self.pixel_scale
        frames = self.transform(frames).unsqueeze(1)

        # return frames.unsqueeze(1) # (25,1,128,128
        return frames[:self.in_len], frames[self.in_len:]
1200
+
1201
def load_meteonet(batch_size, val_batch_size, in_len, train=False, num_workers=0, img_size=128):
    """
    Build MeteoNet dataloaders from the shared "meteo.h5" file.

    Returns
    -------
    (train_loader, val_loader) when `train` is True,
    (None, test_loader) otherwise.
    """
    meteo_path = os.path.join(METEO_FILE_DIR, "meteo.h5")
    if not train:
        test_set = Meteo(meteo_path, img_size, 'test', in_len=in_len)
        test_dl = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
        return None, test_dl
    train_set = Meteo(meteo_path, img_size, 'train', in_len=in_len)
    valid_set = Meteo(meteo_path, img_size, 'val', in_len=in_len)
    # drop_last keeps every batch full-sized during training/validation
    train_dl = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=num_workers)
    valid_dl = torch.utils.data.DataLoader(valid_set, batch_size=val_batch_size, shuffle=False, drop_last=True, num_workers=num_workers)
    return train_dl, valid_dl
data/loader.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn.functional as F
2
+
3
+ from data import dutils
4
+ from nowcasting.hko_iterator import HKOIterator
5
+
6
def GET_TrainLoader(meta, param, batch_size, in_len, out_len):
    """
    Build a (train_loader, test_loader) pair for the dataset in meta['dataset'].

    Parameters
    ----------
    meta : dict
        Dataset metadata; meta['dataset'] selects the branch ('SEVIR',
        anything starting with 'HKO', or 'meteonet').
    param : dict
        Dataset-specific parameters ('pd_path' for HKO; meteonet loader kwargs).
    batch_size : int
        Training batch size; test/val batch size is capped at 8.
    in_len, out_len : int
        Input / output sequence lengths; their sum is the total sequence length.

    Raises
    ------
    Exception
        If meta['dataset'] matches no known dataset.
    """
    if meta['dataset'] == 'SEVIR':
        total_seq_len = in_len + out_len
        train_config = {
            'data_types': ['vil'],
            'layout': 'NTCHW',
            'seq_len': total_seq_len,
            'raw_seq_len': total_seq_len,
            'end_date': dutils.SEVIR_TRAIN_TEST_SPLIT_DATE,
            'start_date': None
        }
        test_config = {
            'data_types': ['vil'],
            'layout': 'NTCHW',
            'seq_len': total_seq_len,
            'raw_seq_len': total_seq_len,
            'end_date': None,
            'start_date': dutils.SEVIR_TRAIN_TEST_SPLIT_DATE
        }
        train_loader = dutils.SEVIRDataIterator(**train_config, batch_size=batch_size)
        test_loader = dutils.SEVIRDataIterator(**test_config, batch_size=8 if batch_size > 8 else batch_size)
        return train_loader, test_loader
    elif meta['dataset'].startswith('HKO'):
        total_seq_len = in_len + out_len
        pkl_path = param['pd_path']
        # The train split shares the test path with 'test' swapped for 'train'.
        train_loader = HKOIterator(pd_path=pkl_path.replace('test', 'train'), sample_mode="random", seq_len=total_seq_len, stride=1)
        test_loader = HKOIterator(pd_path=pkl_path, sample_mode="sequent", seq_len=total_seq_len, stride=in_len)
        return train_loader, test_loader
    elif meta['dataset'] == 'meteonet':
        train_loader, test_loader = dutils.load_meteonet(batch_size=batch_size, val_batch_size=8 if batch_size > 8 else batch_size, train=True, **param)
        return train_loader, test_loader
    else:
        # BUGFIX: the original f-string referenced the undefined name
        # `dataset_config`, so unknown datasets raised NameError instead of
        # this diagnostic.
        raise Exception(f"Undefined dataset config name: {meta['dataset']}")
39
+
40
def GET_TestLoader(meta, param, batch_size):
    """
    Build a test-only loader for the dataset named by meta['dataset'].

    Parameters
    ----------
    meta : dict
        Dataset metadata; meta['dataset'] selects the branch.
    param : dict
        Loader kwargs forwarded to the dataset's iterator constructor.
    batch_size : int
        Batch size (ignored for HKO, whose iterator takes it via `param`).

    Raises
    ------
    Exception
        If meta['dataset'] matches no known dataset.
    """
    if meta['dataset'] == 'SEVIR':
        return dutils.SEVIRDataIterator(**param, batch_size=batch_size)
    elif meta['dataset'].startswith('HKO'):
        return HKOIterator(**param)
    elif meta['dataset'] == 'meteonet':
        _, test_iter = dutils.load_meteonet(batch_size=batch_size, val_batch_size=8, train=False, **param)
        return iter(test_iter)
    else:
        # BUGFIX: the original f-string referenced the undefined name
        # `dataset_config`, raising NameError instead of this diagnostic.
        raise Exception(f"Undefined dataset config name: {meta['dataset']}")
data/sample_data.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0b298f6d4adfd084653888891b8d32e8bbc84fd1226a6d206e92191abe57ba8f
3
+ size 46080128
demo.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
ens_eval.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys, logging, argparse
2
+ import torch
3
+ from torch import nn
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+
7
+ import utilspp as utpp
8
+ from utilspp import mae, mse, ssim, psnr, lpips64, csi, hss
9
+ from data.config import SEVIR_13_12, HKO7_5_20, METEONET_5_20
10
+ from data.loader import GET_TestLoader
11
+ from data.dutils import resize
12
+
13
class MetricListEvaluator():
    '''
    To evaluate a list of metrics. Supported metrics:
    - CSI, HSS (Eg. `csi-84, hss-84`)
    - CSI-pooled (Eg. `csi_4-84`)
    - MAE
    - MSE
    - SSIM
    - PSNR
    '''
    def __init__(self, metric_list):
        # metric_name -> [callable, running accumulator, kwargs for the callable]
        self.metric_holder = {}
        self.batch_count = 0
        for metric_name in metric_list:
            threshold = ''
            radius = ''
            if '-' in metric_name:
                # e.g. 'csi-84' -> metric='csi', threshold='84'
                metric, threshold = metric_name.split('-')
                if '_' in metric:
                    # e.g. 'csi_4' -> metric='csi', pooling radius 4
                    metric, radius = metric.split('_')
                    radius = int(radius)
            # initialize metrics; thresholds arrive on the 0-255 pixel scale
            # and are mapped to [0, 1]
            threshold = float(threshold) / 255 if threshold.isdigit() else threshold
            self.metric_holder[metric_name] = self.init_metric(metric_name, threshold=threshold, radius=radius)

    def init_metric(self, metric_name, **kwarg):
        '''
        return a tuple of three items in order:
        - the function to call during eval
        - the value(s) to keep track of
        - a dict of any additional item to pass into the function
        '''
        if metric_name.split('-')[0] in ['csi', 'hss']:
            # accumulate raw confusion counts; the score is derived in get_results
            return [utpp.tfpn, np.array([0, 0, 0, 0], dtype=np.float32), {'threshold': kwarg['threshold']}]  # tp,
        elif '_' in metric_name.split('-')[0]:  # Indicate Pooling
            return [utpp.tfpn_pool, np.array([0, 0, 0, 0], dtype=np.float32), {'threshold': kwarg['threshold'], 'radius': kwarg['radius']}]
        else:
            # directly convert the string name into function call
            # NOTE(review): eval() resolves the metric name to a function in
            # scope — metric names must come from trusted CLI/config input only.
            return [eval(metric_name), 0, {}]

    def eval(self, y_pred, y):
        """Accumulate every registered metric over one (y_pred, y) batch."""
        self.batch_count += 1
        for _, metric in self.metric_holder.items():
            temp = metric[0](y_pred, y, **metric[-1])
            # BUGFIX: the original `temp is list` compared the value against the
            # `list` type object and was always False, so list results were
            # never converted; isinstance performs the intended check.
            if isinstance(temp, list):
                temp = np.array(temp)
            elif type(temp) == torch.Tensor:
                temp = temp.detach().cpu().numpy()
            metric[1] += temp

    def get_results(self):
        """Finalize and return {metric_name: score} over all evaluated batches."""
        output_holder = {}
        for key, metric in self.metric_holder.items():
            val = metric[1]
            # special handle of tfpn => compute the final score now
            if metric[0] is utpp.tfpn:
                metric_name, threshold = key.split('-')
                val = eval(metric_name)(*list(metric[1]))
            elif metric[0] is utpp.tfpn_pool:
                metric_name, info = key.split('_')
                val = eval(metric_name)(*list(metric[1]))
            else:
                # plain metrics are averaged over batches
                val /= self.batch_count
            output_holder[key] = val
        return output_holder
79
+
80
+
81
if __name__ == '__main__':
    # Evaluate pre-generated ensemble predictions (saved as .npy files) against
    # ground-truth batches re-drawn from the matching dataset loader.
    parser = argparse.ArgumentParser()
    # Dataset related
    parser.add_argument('-d', '--dataset', type=str, default='', help='the dataset definition to be set')
    parser.add_argument('--out_len',type=int, required=True, help='The actual prediction length')
    # ensemble npy filename with {}
    parser.add_argument('--e_file', default='', type=str, help='Ensemble npy filename with included \{ \}')
    parser.add_argument('--ens_no', default=1, type=int, help='Total ensemble number')
    # hyperparams
    parser.add_argument('-s', '--step', type=int, default=-1, help='The number of steps to run. -1: the entire dataloader')
    parser.add_argument('-b', '--batch_size', type=int, default=16, help='The batch size')
    # config override
    parser.add_argument('--metrics', type=str, default=None, help='A list of metrics to be evaluated, separated by character /')
    # logging related
    parser.add_argument('--print_every', type=int, default=100, help='The number of steps to log the training loss')
    args = parser.parse_args()

    # Prepare logger: log next to the ensemble files.
    path_list = args.e_file.split("/")
    logfile_name = os.path.join(*path_list[:-1], 'ensemble_eval.log')
    logging.basicConfig(level=logging.NOTSET, handlers=[logging.FileHandler(logfile_name), logging.StreamHandler()], format='%(message)s')
    logging.info(f'Steps: {args.step}')

    # Resolve the dataset config by name from the imported config constants.
    dataset_config = globals()[args.dataset]
    dataset_param, dataset_meta = dataset_config['param'], dataset_config['meta']
    loader = GET_TestLoader(dataset_meta, dataset_param, args.batch_size)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # prepare metrics
    metric_list = dataset_meta['metrics']
    if args.metrics is not None:
        metric_list = args.metrics.lower().split('/')
        logging.info(f'Overwriting metrics list with: {metric_list}')
    evaluator = MetricListEvaluator(metric_list)

    for e in range(args.ens_no):
        # One .npy per ensemble member; '{}' in e_file is filled with the index.
        prediction = np.load(args.e_file.format(str(e)))
        prediction = torch.tensor(prediction, device=device)

        step = 1
        if dataset_meta['dataset'] in ['SEVIR', 'HKO-7']:
            loader.reset() # Reset it, otherwise alignment error
        else:
            pass

        while args.step < 0 or step <= args.step:
            if dataset_meta['dataset'] == 'SEVIR':
                data = loader.sample(batch_size=args.batch_size)
                if data is None:
                    break
                y = data['vil'][:, -args.out_len:] # Expected to be same as prediction
            elif dataset_meta['dataset'] == 'HKO-7':
                setattr(args, 'seq_len', dataset_meta['seq_len'])
                try:
                    data = loader.sample(batch_size=args.batch_size)
                # NOTE(review): `as e` shadows the outer ensemble index `e` and
                # Python deletes it when the except block ends — rename advisable.
                except Exception as e:
                    logging.error(e)
                    break
                x_seq, x_mask, dt_clip, _ = data
                x, y = utpp.hko7_preprocess(x_seq, x_mask, dt_clip, args)
            elif dataset_meta['dataset'].startswith('meteo'):
                try:
                    x, y = next(loader)
                except Exception as e:
                    logging.error(e)
                    break

            with torch.no_grad():
                y = y.to(device)
                # Slice this step's batch out of the flat prediction tensor.
                y_pred = prediction[(step-1)*args.batch_size:step*args.batch_size]

                if y.shape[-1] != y_pred.shape[-1]:
                    y = resize(y, y_pred.shape[-1])

                y, y_pred = y.clamp(0,1), y_pred.clamp(0,1) # B T C H W
                evaluator.eval(y_pred, y)
            # log/print every
            if step == 1 or step % args.print_every == 0:
                logging.info(f'E_ID:{e} -> {step} Steps evaluated')
            step += 1
    # log the final scores
    final_results = evaluator.get_results()
    for k, v in final_results.items():
        logging.info(f'{k}: {v}')
ens_gen.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Sample Command
3
+ """
4
+ import os, sys, logging, argparse
5
+ import torch
6
+ from torch import nn
7
+ import numpy as np
8
+ import torch.nn.functional as F
9
+
10
+ from stldm import *
11
+ import utilspp as utpp
12
+ from data.config import SEVIR_13_12, HKO7_5_20, METEONET_5_20
13
+ from data.loader import GET_TestLoader
14
+ from data.dutils import resize
15
+
16
if __name__ == '__main__':
    # Generate ensemble-member predictions with a trained model and save them
    # as one .npy (layout B T C H W) for later evaluation by ens_eval.py.
    parser = argparse.ArgumentParser()
    # Dataset related
    parser.add_argument('-d', '--dataset', type=str, default='', help='the dataset definition to be set')
    # model related
    parser.add_argument('-f', dest='checkpt', type=str, default='', help='model checkpoint to be loaded from (Empty = not loading)')
    parser.add_argument('-m', '--model', type=str, default='', help='the model definition to be created')
    parser.add_argument('--type', type=str, default='3D', help='Determine which kind of model to use, 2D or 3D')
    parser.add_argument('--c_str', type=float, default=0.0, help='CFG strength')
    parser.add_argument('--e_id', type=int, default=0, help='Ensemble ID')
    # hyperparams
    parser.add_argument('-s', '--step', type=int, default=-1, help='The number of steps to run. -1: the entire dataloader')
    parser.add_argument('-b', '--batch_size', type=int, default=16, help='The batch size')
    # logging related
    parser.add_argument('--print_every', type=int, default=100, help='The number of steps to log the training loss')
    parser.add_argument('-o', '--output', default=None, help='The path to save the log files')
    args = parser.parse_args()

    # prepare logger: default location is derived from the checkpoint path
    if args.output is None:
        path_list = args.checkpt.split("/")
        logfile_name = os.path.join(*path_list[:-1], 'logs', f'{path_list[-1][:-3]}.log')
    else:
        logfile_name = f'{args.output}.log'
    logging.basicConfig(level=logging.NOTSET, handlers=[logging.FileHandler(logfile_name), logging.StreamHandler()], format='%(message)s')
    logging.info(f'Model checkpoint: {args.checkpt}')
    logging.info(f'Steps: {args.step}')

    # Samples are grouped per CFG strength.
    sampler_dir = os.path.join(*logfile_name.split("/")[:-2], f'CFG={args.c_str}_samples')
    os.makedirs(sampler_dir, exist_ok=True)

    # Prepare Dataloader
    dataset_config = globals()[args.dataset]
    dataset_param, dataset_meta = dataset_config['param'], dataset_config['meta']
    loader = GET_TestLoader(dataset_meta, dataset_param, args.batch_size)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Prepare Model
    assert args.type in ['2D', '3D'], 'Please specify either 2D or 3D'
    model_config = globals()[args.model]
    # c_str == 0.0 disables classifier-free guidance (cfg_str=None).
    model = n2n_setup[args.type](model_config, print_info=True, cfg_str=args.c_str if args.c_str != 0.0 else None).to(device)
    logging.info(f'CFG Scheduler: Const-{args.c_str}')

    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    data = torch.load(args.checkpt, map_location=device)
    if 'model' in data.keys():
        model.load_state_dict(data['model'])
    else:
        model.load_state_dict(data)


    in_len, out_len = model_config['vp_param']['shape_in'][0], model_config['vp_param']['shape_out'][0]
    img_size = model_config['vp_param']['shape_in'][-1]

    step = 0
    out = []
    while args.step < 0 or step <=args.step:
        model.eval()

        if dataset_meta['dataset'] == 'HKO-7':
            setattr(args, 'seq_len', in_len)
            try:
                data = loader.sample(batch_size=args.batch_size)
            except Exception as e:
                logging.error(e)
                break
            x_seq, x_mask, dt_clip, _ = data
            x, y = utpp.hko7_preprocess(x_seq, x_mask, dt_clip, args)
        elif dataset_meta['dataset'] == 'SEVIR':
            data = loader.sample(batch_size=args.batch_size)
            if data is None:
                break
            x, y = data['vil'][:, :in_len], data['vil'][:, in_len:]
        elif dataset_meta['dataset'].startswith('meteo'):
            try:
                x, y = next(loader)
            except Exception as e:
                logging.error(e)
                break

        x, y = x.to(device), y.to(device)

        with torch.no_grad():
            if x.shape[-1] != img_size:
                x = resize(x, img_size)
                y = resize(y, img_size) # TO compare with DiffCast paper
            if model_config['pre'] is not None:
                x = model_config['pre'](x)

            y_pred = model(x)

            if model_config['post'] is not None:
                x = model_config['post'](x)
                y_pred = model_config['post'](y_pred)
            y_pred = y_pred.clamp(0,1)

        out.append(y_pred.detach().cpu())

        step += 1
        # log/print every
        if step == 1 or step % args.print_every == 0:
            logging.info(f'{step} Steps Generated, {len(out)} in out array')

    logging.info(f'{step} Steps Generated, {len(out)} in out array')
    out = torch.cat(out, dim=0)
    out = out.numpy()
    save_path = os.path.join(sampler_dir, f'BTCHW_total-no:{len(out)}_e={args.e_id}.npy')
    np.save(save_path, out)
    print('Output saved in', save_path)
+ print('Output saved in', save_path)
nowcasting/__init__.py ADDED
File without changes
nowcasting/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (159 Bytes). View file
 
nowcasting/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (159 Bytes). View file
 
nowcasting/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (154 Bytes). View file
 
nowcasting/__pycache__/config.cpython-310.pyc ADDED
Binary file (8.53 kB). View file
 
nowcasting/__pycache__/config.cpython-38.pyc ADDED
Binary file (8.34 kB). View file
 
nowcasting/__pycache__/config.cpython-39.pyc ADDED
Binary file (8.6 kB). View file
 
nowcasting/__pycache__/hko_iterator.cpython-310.pyc ADDED
Binary file (15.5 kB). View file
 
nowcasting/__pycache__/hko_iterator.cpython-38.pyc ADDED
Binary file (15.6 kB). View file
 
nowcasting/__pycache__/hko_iterator.cpython-39.pyc ADDED
Binary file (15.6 kB). View file
 
nowcasting/__pycache__/image.cpython-310.pyc ADDED
Binary file (3.82 kB). View file
 
nowcasting/__pycache__/image.cpython-38.pyc ADDED
Binary file (3.75 kB). View file
 
nowcasting/__pycache__/image.cpython-39.pyc ADDED
Binary file (3.76 kB). View file
 
nowcasting/__pycache__/mask.cpython-310.pyc ADDED
Binary file (1.34 kB). View file
 
nowcasting/__pycache__/mask.cpython-38.pyc ADDED
Binary file (1.33 kB). View file
 
nowcasting/__pycache__/mask.cpython-39.pyc ADDED
Binary file (1.33 kB). View file
 
nowcasting/__pycache__/utils.cpython-310.pyc ADDED
Binary file (2.62 kB). View file
 
nowcasting/__pycache__/utils.cpython-38.pyc ADDED
Binary file (2.63 kB). View file
 
nowcasting/__pycache__/utils.cpython-39.pyc ADDED
Binary file (2.62 kB). View file
 
nowcasting/config.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import os
3
+ import yaml
4
+ import logging
5
+ from collections import OrderedDict
6
+ from .helpers.ordered_easydict import OrderedEasyDict as edict
7
+
8
# ---------------------------------------------------------------------------
# Global default configuration. Importing this module has side effects:
# directories are created under ROOT_DIR and a RuntimeError is raised if the
# radarPNG / radarPNG_mask data directories cannot be located.
# ---------------------------------------------------------------------------
__C = edict()
cfg = __C  # type: edict()

# Random seed
__C.SEED = None

# Dataset name
# Used by symbols factories who need to adjust for different
# inputs based on dataset used. Should be set by the script.
__C.DATASET = None

# Project directory, since config.py is supposed to be in $ROOT_DIR/nowcasting
#__C.ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
# ====== Changed the root dir to a fixed path =======
__C.ROOT_DIR = 'data/HKO-7'

__C.MNIST_PATH = os.path.join(__C.ROOT_DIR, 'mnist_data')
if not os.path.exists(__C.MNIST_PATH):
    os.makedirs(__C.MNIST_PATH)
__C.HKO_DATA_BASE_PATH = os.path.join(__C.ROOT_DIR, 'hko_data')

# # Append your path to the possible paths
possible_hko_png_paths = [os.path.join('E:\\datasets\\HKO-data\\radarPNG\\radarPNG'),
                          os.path.join(__C.HKO_DATA_BASE_PATH, 'radarPNG'),
                          'data/HKO-7/radarPNG']
possible_hko_mask_paths = [os.path.join('E:\\datasets\\HKO-data\\radarPNG\\radarPNG_mask'),
                           os.path.join(__C.HKO_DATA_BASE_PATH, 'radarPNG_mask'),
                           'data/HKO-7/radarPNG_mask']

# Search for the radarPNG
find_hko_png_path = False
for ele in possible_hko_png_paths:
    if os.path.exists(ele):
        find_hko_png_path = True
        __C.HKO_PNG_PATH = ele
        break
if not find_hko_png_path:
    raise RuntimeError("radarPNG is not found! You can download the radarPNG using"
                       " `bash download_radar_png.bash`")
# Search for the radarPNG_mask
find_hko_mask_path = False
for ele in possible_hko_mask_paths:
    if os.path.exists(ele):
        find_hko_mask_path = True
        __C.HKO_MASK_PATH = ele
        break
if not find_hko_mask_path:
    raise RuntimeError("radarPNG_mask is not found! You can download the radarPNG_mask using"
                       " `bash download_radar_png.bash`")
if not os.path.exists(__C.HKO_DATA_BASE_PATH):
    os.makedirs(__C.HKO_DATA_BASE_PATH)
__C.HKO_PD_BASE_PATH = os.path.join(__C.HKO_DATA_BASE_PATH, 'pd')
if not os.path.exists(__C.HKO_PD_BASE_PATH):
    os.makedirs(__C.HKO_PD_BASE_PATH)
__C.HKO_VALID_DATETIME_PATH = os.path.join(__C.HKO_DATA_BASE_PATH, 'valid_datetime.pkl')
__C.HKO_SORTED_DAYS_PATH = os.path.join(__C.HKO_DATA_BASE_PATH, 'sorted_day.pkl')
__C.HKO_RAINY_TRAIN_DAYS_PATH = os.path.join(__C.HKO_DATA_BASE_PATH, 'hko7_rainy_train_days.txt')
__C.HKO_RAINY_VALID_DAYS_PATH = os.path.join(__C.HKO_DATA_BASE_PATH, 'hko7_rainy_valid_days.txt')
__C.HKO_RAINY_TEST_DAYS_PATH = os.path.join(__C.HKO_DATA_BASE_PATH, 'hko7_rainy_test_days.txt')

# Paths of the cached pandas dataframes for the HKO-7 splits
__C.HKO_PD = edict()
__C.HKO_PD.ALL = os.path.join(__C.HKO_PD_BASE_PATH, 'hko7_all.pkl')
__C.HKO_PD.ALL_09_14 = os.path.join(__C.HKO_PD_BASE_PATH, 'hko7_all_09_14.pkl')
__C.HKO_PD.ALL_15 = os.path.join(__C.HKO_PD_BASE_PATH, 'hko7_all_15.pkl')
__C.HKO_PD.RAINY_TRAIN = os.path.join(__C.HKO_PD_BASE_PATH, 'hko7_rainy_train.pkl')
__C.HKO_PD.RAINY_VALID = os.path.join(__C.HKO_PD_BASE_PATH, 'hko7_rainy_valid.pkl')
__C.HKO_PD.RAINY_TEST = os.path.join(__C.HKO_PD_BASE_PATH, 'hko7_rainy_test.pkl')

__C.HKO = edict()
__C.HKO.ITERATOR = edict()
__C.HKO.ITERATOR.WIDTH = 480
__C.HKO.ITERATOR.HEIGHT = 480
__C.HKO.ITERATOR.FILTER_RAINFALL = True  # Whether to discard part of the rainfall, has a denoising effect
__C.HKO.ITERATOR.FILTER_RAINFALL_THRESHOLD = 0.28  # All the pixel values that are smaller than round(threshold * 255) will be discarded


# The Benchmark parameters
__C.HKO.BENCHMARK = edict()
__C.HKO.BENCHMARK.STAT_PATH = os.path.join(__C.HKO_DATA_BASE_PATH, 'benchmark_stat')
if not os.path.exists(__C.HKO.BENCHMARK.STAT_PATH):
    os.makedirs(__C.HKO.BENCHMARK.STAT_PATH)
__C.HKO.BENCHMARK.VISUALIZE_SEQ_NUM = 10  # Number of sequences that will be plotted and saved to the benchmark directory
__C.HKO.BENCHMARK.IN_LEN = 5   # The maximum input length to ensure that all models are tested on the same set of input data
__C.HKO.BENCHMARK.OUT_LEN = 20  # The maximum output length to ensure that all models are tested on the same set of input data
__C.HKO.BENCHMARK.STRIDE = 5   # The stride


__C.HKO.EVALUATION = edict()
__C.HKO.EVALUATION.ZR = edict()
__C.HKO.EVALUATION.ZR.a = 58.53  # The a factor in the Z-R relationship
__C.HKO.EVALUATION.ZR.b = 1.56   # The b factor in the Z-R relationship
__C.HKO.EVALUATION.THRESHOLDS = (0.5, 2, 5, 10, 30)
__C.HKO.EVALUATION.BALANCING_WEIGHTS = (1, 1, 2, 5, 10, 30)  # The corresponding balancing weights
__C.HKO.EVALUATION.CENTRAL_REGION = (120, 120, 360, 360)

# Synthetic MovingMNIST(++) generation parameters
__C.MOVINGMNIST = edict()
__C.MOVINGMNIST.DISTRACTOR_NUM = 0
__C.MOVINGMNIST.VELOCITY_LOWER = 0.0
__C.MOVINGMNIST.VELOCITY_UPPER = 3.6
__C.MOVINGMNIST.SCALE_VARIATION_LOWER = 1/1.1
__C.MOVINGMNIST.SCALE_VARIATION_UPPER = 1.1
__C.MOVINGMNIST.ROTATION_LOWER = -30
__C.MOVINGMNIST.ROTATION_UPPER = 30
__C.MOVINGMNIST.ILLUMINATION_LOWER = 0.6
__C.MOVINGMNIST.ILLUMINATION_UPPER = 1.0
__C.MOVINGMNIST.DIGIT_NUM = 3
__C.MOVINGMNIST.IN_LEN = 10
__C.MOVINGMNIST.OUT_LEN = 10
__C.MOVINGMNIST.TESTING_LEN = 20
__C.MOVINGMNIST.IMG_SIZE = 64
__C.MOVINGMNIST.TEST_FILE = os.path.join(__C.MNIST_PATH, "movingmnist_10000_nodistr.npz")

__C.MODEL = edict()
__C.MODEL.RESUME = False   # If True, load LOAD_ITER parameters from LOAD_DIR
__C.MODEL.TESTING = False  # If True, run in Testing mode
__C.MODEL.LOAD_DIR = ""    # The directory to load the pre-trained parameters
# Could be like `D:\\HKUST\\3-2\\NIPS2017\\hko_0502\\bal_loss_direct`
__C.MODEL.LOAD_ITER = 79999  # Only applicable when LOAD_DIR is non-empty
__C.MODEL.SAVE_DIR = ""
__C.MODEL.CNN_ACT_TYPE = "leaky"
__C.MODEL.RNN_ACT_TYPE = "leaky"
__C.MODEL.FRAME_STACK = 1  # Stack multiple frames as the input
__C.MODEL.FRAME_SKIP = 1   # The frame skip size
__C.MODEL.IN_LEN = 5       # Size of the input
__C.MODEL.OUT_LEN = 20     # Size of the output
__C.MODEL.OUT_TYPE = "direct"  # Can be "direct", or "DFN"
__C.MODEL.NORMAL_LOSS_GLOBAL_SCALE = 0.00005
__C.MODEL.USE_BALANCED_LOSS = True
__C.MODEL.TEMPORAL_WEIGHT_TYPE = "same"  # Can be "same", "linear" or "exponential"
__C.MODEL.TEMPORAL_WEIGHT_UPPER = 5  # Only applicable when temporal_weights_type is "linear" or "exponential"
# If linear
#    the weights will be increased following (1 + i * (upper - 1) / (T - 1))
# If exponential
#    the weights will be increased following exp^{i * \ln(upper) / (T-1)}
__C.MODEL.L1_LAMBDA = 1.0
__C.MODEL.L2_LAMBDA = 1.0
__C.MODEL.GDL_LAMBDA = 0.0
__C.MODEL.USE_SEASONALITY = False  # Whether to use seasonality

__C.MODEL.TRAJRNN = edict()
__C.MODEL.TRAJRNN.INIT_GRID = True
__C.MODEL.TRAJRNN.FLOW_LR_MULT = 1.0
__C.MODEL.TRAJRNN.SAVE_MID_RESULTS = False

__C.MODEL.ENCODER_FORECASTER = edict()
__C.MODEL.ENCODER_FORECASTER.HAS_MASK = True
__C.MODEL.ENCODER_FORECASTER.FEATMAP_SIZE = [96, 32, 16]
__C.MODEL.ENCODER_FORECASTER.FIRST_CONV = (8, 7, 5, 1)    # Num filter, kernel, stride, pad
__C.MODEL.ENCODER_FORECASTER.LAST_DECONV = (8, 7, 5, 1)   # Num filter, kernel, stride, pad
__C.MODEL.ENCODER_FORECASTER.DOWNSAMPLE = [(5, 3, 1),
                                           (3, 2, 1)]  # (kernel, stride, pad) for conv2d
__C.MODEL.ENCODER_FORECASTER.UPSAMPLE = [(5, 3, 1),
                                         (4, 2, 1)]  # (kernel, stride, pad) for deconv2d

__C.MODEL.ENCODER_FORECASTER.RNN_BLOCKS = edict()  # Define the RNN blocks for the encoder RNN
# In our network, the forecaster RNN will always have the reverse structure of encoder RNN
__C.MODEL.ENCODER_FORECASTER.RNN_BLOCKS.RES_CONNECTION = True
__C.MODEL.ENCODER_FORECASTER.RNN_BLOCKS.LAYER_TYPE = ["ConvGRU", "ConvGRU", "ConvGRU"]
__C.MODEL.ENCODER_FORECASTER.RNN_BLOCKS.STACK_NUM = [2, 3, 3]
# These features are used for both ConvGRU
__C.MODEL.ENCODER_FORECASTER.RNN_BLOCKS.NUM_FILTER = [32, 64, 64]
__C.MODEL.ENCODER_FORECASTER.RNN_BLOCKS.H2H_KERNEL = [(5, 5), (5, 5), (3, 3)]
__C.MODEL.ENCODER_FORECASTER.RNN_BLOCKS.H2H_DILATE = [(1, 1), (1, 1), (1, 1)]
__C.MODEL.ENCODER_FORECASTER.RNN_BLOCKS.I2H_KERNEL = [(3, 3), (3, 3), (3, 3)]
__C.MODEL.ENCODER_FORECASTER.RNN_BLOCKS.I2H_PAD = [(1, 1), (1, 1), (1, 1)]
# These features are only used in TrajGRU
__C.MODEL.ENCODER_FORECASTER.RNN_BLOCKS.L = [5, 5, 5]

__C.MODEL.DECONVBASELINE = edict()
__C.MODEL.DECONVBASELINE.BASE_NUM_FILTER = 16
__C.MODEL.DECONVBASELINE.USE_3D = True
__C.MODEL.DECONVBASELINE.ENCODER = "separate"
__C.MODEL.DECONVBASELINE.BN = True
__C.MODEL.DECONVBASELINE.BN_GLOBAL_STATS = False
__C.MODEL.DECONVBASELINE.COMPAT = edict()  # Compatibility flags to recover behavior of previous versions
__C.MODEL.DECONVBASELINE.COMPAT.CONV_INSTEADOF_FC_IN_ENCODER = False  # Until 6th May 2017
__C.MODEL.DECONVBASELINE.FC_BETWEEN_ENCDEC = 0

__C.MODEL.TRAIN = edict()
__C.MODEL.TRAIN.BATCH_SIZE = 3
__C.MODEL.TRAIN.TBPTT = False
__C.MODEL.TRAIN.OPTIMIZER = "adam"
__C.MODEL.TRAIN.LR = 1E-4
__C.MODEL.TRAIN.GAMMA1 = 0.9  # Used in RMSProp
__C.MODEL.TRAIN.BETA1 = 0.5   # When using ADAM, momentum is called beta1
__C.MODEL.TRAIN.EPS = 1E-8
__C.MODEL.TRAIN.MIN_LR = 1E-6
__C.MODEL.TRAIN.GRAD_CLIP = 50.0
__C.MODEL.TRAIN.WD = 0
__C.MODEL.TRAIN.MAX_ITER = 180000
# NOTE(review): VALID_ITER/SAVE_ITER live under MODEL (not MODEL.TRAIN) —
# keep this layout, YAML overrides depend on it.
__C.MODEL.VALID_ITER = 5000
__C.MODEL.SAVE_ITER = 15000
__C.MODEL.TRAIN.LR_DECAY_ITER = 10000
__C.MODEL.TRAIN.LR_DECAY_FACTOR = 0.7

__C.MODEL.TEST = edict()
__C.MODEL.TEST.FINETUNE = True
__C.MODEL.TEST.MAX_ITER = 1  # Number of samples to generate in testing mode
__C.MODEL.TEST.MODE = "online"  # Can be `online` or `fixed`
__C.MODEL.TEST.DISABLE_TBPTT = True
__C.MODEL.TEST.ONLINE = edict()
__C.MODEL.TEST.ONLINE.OPTIMIZER = "adagrad"
__C.MODEL.TEST.ONLINE.LR = 1E-4
__C.MODEL.TEST.ONLINE.FINETUNE_MIN_MSE = 0.0
__C.MODEL.TEST.ONLINE.GAMMA1 = 0.9  # Used in RMSProp
__C.MODEL.TEST.ONLINE.BETA1 = 0.5   # Used in ADAM!
__C.MODEL.TEST.ONLINE.EPS = 1E-6
__C.MODEL.TEST.ONLINE.GRAD_CLIP = 50.0
__C.MODEL.TEST.ONLINE.WD = 0
217
+
218
+
219
def _merge_two_config(user_cfg, default_cfg):
    """ Merge user's config into default config dictionary, clobbering the
        options in b whenever they are also specified in a.
        Need to ensure the type of two val under same key are the same
        Do recursive merge when encounter hierarchical dictionary

    Parameters
    ----------
    user_cfg : edict
        Overrides loaded from the user's YAML file (may be a sub-tree of
        default_cfg). Anything that is not an edict is silently ignored.
    default_cfg : edict
        The default tree that is mutated in place.

    Raises
    ------
    KeyError
        If user_cfg contains a key that does not exist in default_cfg.
    ValueError
        If the value types disagree and the default is not a numpy array.
    """
    if type(user_cfg) is not edict:
        return
    for key, val in user_cfg.items():
        # Since user_cfg is a sub-file of default_cfg
        if not key in default_cfg:
            raise KeyError('{} is not a valid config key'.format(key))

        if (type(default_cfg[key]) is not type(val) and
                default_cfg[key] is not None):
            # numpy arrays are the only allowed type mismatch: coerce the
            # user value to the default array's dtype.
            if isinstance(default_cfg[key], np.ndarray):
                val = np.array(val, dtype=default_cfg[key].dtype)
            else:
                raise ValueError(
                    'Type mismatch ({} vs. {}) '
                    'for config key: {}'.format(type(default_cfg[key]),
                                                type(val), key))
        # Recursive merge config
        if type(val) is edict:
            try:
                _merge_two_config(user_cfg[key], default_cfg[key])
            except:
                # Re-raise with context so the offending key path is visible.
                print('Error under config key: {}'.format(key))
                raise
        else:
            default_cfg[key] = val
250
+
251
+
252
def cfg_from_file(file_name, target=__C):
    """ Load a config file and merge it into the default options.

    Parameters
    ----------
    file_name : str
        Path of the YAML config file.
    target : edict
        Configuration tree to merge into (defaults to the global config).
    """
    with open(file_name, 'r') as f:
        # BUG FIX: the original printed the file *object* ('%s' % f); print
        # the path instead.
        print('Loading YAML config file from %s' % file_name)
        # BUG FIX: yaml.load() without an explicit Loader is unsafe and is a
        # hard error in PyYAML >= 6.0; safe_load is sufficient for plain
        # configuration data.
        yaml_cfg = edict(yaml.safe_load(f))

    _merge_two_config(yaml_cfg, target)
261
+
262
+
263
def ordered_dump(data, stream=None, Dumper=yaml.SafeDumper, **kwds):
    """Serialize *data* to YAML while preserving mapping order.

    A throwaway Dumper subclass is registered with representers so that
    OrderedDict/edict are emitted as ordinary (insertion-ordered) YAML
    mappings and numpy arrays as nested lists.

    Returns the YAML string when *stream* is None, otherwise writes to
    *stream* (mirrors yaml.dump semantics).
    """
    class OrderedDumper(Dumper):
        pass

    def _dict_representer(dumper, data):
        # Emit as a plain YAML map tag; items() preserves insertion order.
        return dumper.represent_mapping(
            yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG,
            data.items(), flow_style=False)

    def _ndarray_representer(dumper, data):
        # numpy arrays become nested Python lists in the output.
        return dumper.represent_list(data.tolist())

    OrderedDumper.add_representer(OrderedDict, _dict_representer)
    OrderedDumper.add_representer(edict, _dict_representer)
    OrderedDumper.add_representer(np.ndarray, _ndarray_representer)
    return yaml.dump(data, stream, OrderedDumper, **kwds)
279
+
280
+
281
def save_cfg(dir_path, source=__C):
    """Write *source* to the first unused ``cfg<N>.yml`` under *dir_path*.

    Existing config snapshots are never overwritten: the index N is bumped
    until a free filename is found.
    """
    idx = 0
    while True:
        file_path = os.path.join(dir_path, 'cfg%d.yml' % idx)
        if not os.path.exists(file_path):
            break
        idx += 1
    with open(file_path, 'w') as f:
        logging.info("Save YAML config file to %s" % file_path)
        ordered_dump(source, f, yaml.SafeDumper, default_flow_style=None)
290
+
291
+
292
def load_latest_cfg(dir_path, target=__C):
    """Load the highest-numbered ``cfg<N>.yml`` in *dir_path* into *target*.

    Parameters
    ----------
    dir_path : str
        Directory that contains cfg0.yml, cfg1.yml, ... snapshots written
        by save_cfg.
    target : edict
        Configuration tree to merge into.

    Raises
    ------
    FileNotFoundError
        If no file matching ``cfg<N>.yml`` exists in *dir_path*.
    """
    import re
    cfg_count = None
    source_cfg_path = None
    for fname in os.listdir(dir_path):
        ret = re.search(r'cfg(\d+)\.yml', fname)
        if ret is not None:
            # BUG FIX: the original called re.group(1) on the *module*
            # instead of the match object `ret`, which always raised
            # AttributeError as soon as a cfg file was found.
            if cfg_count is None or int(ret.group(1)) > cfg_count:
                cfg_count = int(ret.group(1))
                source_cfg_path = os.path.join(dir_path, ret.group(0))
    if source_cfg_path is None:
        # Fail loudly instead of passing file_name=None to cfg_from_file.
        raise FileNotFoundError('No cfg<N>.yml file found in %s' % dir_path)
    cfg_from_file(file_name=source_cfg_path, target=target)
nowcasting/encoder_forecaster.py ADDED
@@ -0,0 +1,556 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import mxnet as mx
2
+ import mxnet.ndarray as nd
3
+ import nowcasting.config as cfg
4
+ from nowcasting.ops import reset_regs
5
+ from nowcasting.operators.common import grid_generator
6
+ from nowcasting.operators import *
7
+ from nowcasting.ops import *
8
+ from nowcasting.prediction_base_factory import PredictionBaseFactory
9
+ from nowcasting.operators.transformations import DFN
10
+ from nowcasting.my_module import MyModule
11
+
12
+
13
def get_encoder_forecaster_rnn_blocks(batch_size):
    """Build the stacked-RNN blocks for the encoder, forecaster and GAN nets.

    One block per entry in cfg.MODEL.ENCODER_FORECASTER.RNN_BLOCKS.NUM_FILTER
    is created for each of the three networks, with name prefixes "ebrnn"
    (encoder), "fbrnn" (forecaster) and "dbrnn" (GAN discriminator).

    Parameters
    ----------
    batch_size : int
        Per-device batch size used to size the hidden states.

    Returns
    -------
    (encoder_rnn_blocks, forecaster_rnn_blocks, gan_rnn_blocks) : three lists
        of BaseStackRNN, one element per configured RNN level.

    Raises
    ------
    NotImplementedError
        If a LAYER_TYPE other than "ConvGRU" or "TrajGRU" is configured.
    """
    encoder_rnn_blocks = []
    forecaster_rnn_blocks = []
    gan_rnn_blocks = []
    CONFIG = cfg.MODEL.ENCODER_FORECASTER.RNN_BLOCKS
    for vec, block_prefix in [(encoder_rnn_blocks, "ebrnn"),
                              (forecaster_rnn_blocks, "fbrnn"),
                              (gan_rnn_blocks, "dbrnn")]:
        for i in range(len(CONFIG.NUM_FILTER)):
            name = "%s%d" % (block_prefix, i + 1)
            if CONFIG.LAYER_TYPE[i] == "ConvGRU":
                rnn_block = BaseStackRNN(base_rnn_class=ConvGRU,
                                         stack_num=CONFIG.STACK_NUM[i],
                                         name=name,
                                         residual_connection=CONFIG.RES_CONNECTION,
                                         num_filter=CONFIG.NUM_FILTER[i],
                                         b_h_w=(batch_size,
                                                cfg.MODEL.ENCODER_FORECASTER.FEATMAP_SIZE[i],
                                                cfg.MODEL.ENCODER_FORECASTER.FEATMAP_SIZE[i]),
                                         h2h_kernel=CONFIG.H2H_KERNEL[i],
                                         h2h_dilate=CONFIG.H2H_DILATE[i],
                                         i2h_kernel=CONFIG.I2H_KERNEL[i],
                                         i2h_pad=CONFIG.I2H_PAD[i],
                                         act_type=cfg.MODEL.RNN_ACT_TYPE)
            elif CONFIG.LAYER_TYPE[i] == "TrajGRU":
                # TrajGRU additionally takes L, the number of learned flow links.
                rnn_block = BaseStackRNN(base_rnn_class=TrajGRU,
                                         stack_num=CONFIG.STACK_NUM[i],
                                         name=name,
                                         L=CONFIG.L[i],
                                         residual_connection=CONFIG.RES_CONNECTION,
                                         num_filter=CONFIG.NUM_FILTER[i],
                                         b_h_w=(batch_size,
                                                cfg.MODEL.ENCODER_FORECASTER.FEATMAP_SIZE[i],
                                                cfg.MODEL.ENCODER_FORECASTER.FEATMAP_SIZE[i]),
                                         h2h_kernel=CONFIG.H2H_KERNEL[i],
                                         h2h_dilate=CONFIG.H2H_DILATE[i],
                                         i2h_kernel=CONFIG.I2H_KERNEL[i],
                                         i2h_pad=CONFIG.I2H_PAD[i],
                                         act_type=cfg.MODEL.RNN_ACT_TYPE)
            else:
                raise NotImplementedError
            vec.append(rnn_block)
    return encoder_rnn_blocks, forecaster_rnn_blocks, gan_rnn_blocks
56
+
57
class EncoderForecasterBaseFactory(PredictionBaseFactory):
    """Symbol factory for the encoder-forecaster nowcasting architecture.

    Builds the MXNet symbols for the encoder (input frames -> RNN states),
    the forecaster (RNN states -> predicted frames) and describes the data
    shapes of each sub-network. Loss construction is left to subclasses
    (loss_sym raises NotImplementedError).
    """

    def __init__(self,
                 batch_size,
                 in_seq_len,
                 out_seq_len,
                 height,
                 width,
                 ctx_num=1,
                 name="encoder_forecaster"):
        super(EncoderForecasterBaseFactory, self).__init__(batch_size=batch_size,
                                                           in_seq_len=in_seq_len,
                                                           out_seq_len=out_seq_len,
                                                           height=height,
                                                           width=width,
                                                           name=name)
        # Number of devices; state/data shapes are scaled by this factor.
        self._ctx_num = ctx_num

    def _init_rnn(self):
        # Create encoder/forecaster/GAN RNN stacks and expose them as one list
        # (hook expected by PredictionBaseFactory).
        self._encoder_rnn_blocks, self._forecaster_rnn_blocks, self._gan_rnn_blocks =\
            get_encoder_forecaster_rnn_blocks(batch_size=self._batch_size)
        return self._encoder_rnn_blocks + self._forecaster_rnn_blocks + self._gan_rnn_blocks

    @property
    def init_encoder_state_info(self):
        # Name/shape/layout of every initial encoder RNN state variable.
        init_state_info = []
        for block in self._encoder_rnn_blocks:
            for state in block.init_state_vars():
                init_state_info.append({'name': state.name,
                                        'shape': state.attr('__shape__'),
                                        '__layout__': state.list_attr()['__layout__']})
        return init_state_info

    @property
    def init_forecaster_state_info(self):
        # Name/shape/layout of every initial forecaster RNN state variable.
        init_state_info = []
        for block in self._forecaster_rnn_blocks:
            for state in block.init_state_vars():
                init_state_info.append({'name': state.name,
                                        'shape': state.attr('__shape__'),
                                        '__layout__': state.list_attr()['__layout__']})
        return init_state_info

    @property
    def init_gan_state_info(self):
        # Name/shape/layout of every initial GAN RNN state variable.
        init_gan_state_info = []
        for block in self._gan_rnn_blocks:
            for state in block.init_state_vars():
                init_gan_state_info.append({'name': state.name,
                                            'shape': state.attr('__shape__'),
                                            '__layout__': state.list_attr()['__layout__']})
        return init_gan_state_info

    def stack_rnn_encode(self, data):
        """Run the input sequence through conv + stacked RNN encoder levels.

        Returns a list (one entry per RNN block, coarse last) of the final
        RNN states, to be handed to the forecaster.
        """
        CONFIG = cfg.MODEL.ENCODER_FORECASTER
        pre_encoded_data = self._pre_encode_frame(frame_data=data, seqlen=self._in_seq_len)
        # Merge the time axis into the batch axis for per-frame convolution.
        reshape_data = mx.sym.Reshape(pre_encoded_data, shape=(-1, 0, 0, 0), reverse=True)

        # Encoder Part
        conv1 = conv2d_act(data=reshape_data,
                           num_filter=CONFIG.FIRST_CONV[0],
                           kernel=(CONFIG.FIRST_CONV[1], CONFIG.FIRST_CONV[1]),
                           stride=(CONFIG.FIRST_CONV[2], CONFIG.FIRST_CONV[2]),
                           pad=(CONFIG.FIRST_CONV[3], CONFIG.FIRST_CONV[3]),
                           act_type=cfg.MODEL.CNN_ACT_TYPE,
                           name="econv1")
        rnn_block_num = len(CONFIG.RNN_BLOCKS.NUM_FILTER)
        encoder_rnn_block_states = []
        for i in range(rnn_block_num):
            if i == 0:
                inputs = conv1
            else:
                inputs = downsample
            rnn_out, states = self._encoder_rnn_blocks[i].unroll(
                length=self._in_seq_len,
                inputs=inputs,
                begin_states=None,
                ret_mid=False)
            encoder_rnn_block_states.append(states)
            if i < rnn_block_num - 1:
                # Downsample the last RNN output to feed the next (coarser) level.
                downsample = downsample_module(data=rnn_out[-1],
                                               num_filter=CONFIG.RNN_BLOCKS.NUM_FILTER[i + 1],
                                               kernel=(CONFIG.DOWNSAMPLE[i][0],
                                                       CONFIG.DOWNSAMPLE[i][0]),
                                               stride=(CONFIG.DOWNSAMPLE[i][1],
                                                       CONFIG.DOWNSAMPLE[i][1]),
                                               pad=(CONFIG.DOWNSAMPLE[i][2],
                                                    CONFIG.DOWNSAMPLE[i][2]),
                                               b_h_w=(self._batch_size,
                                                      CONFIG.FEATMAP_SIZE[i + 1],
                                                      CONFIG.FEATMAP_SIZE[i + 1]),
                                               name="edown%d" % (i + 1))
        return encoder_rnn_block_states

    def stack_rnn_forecast(self, block_state_list, last_frame):
        """Unroll the forecaster RNNs from the encoder states and decode frames.

        Returns (pred, flow) where pred has layout TNCHW and flow is the DFN
        dynamic-filter symbol ("DFN" mode) or None ("direct" mode).
        """
        CONFIG = cfg.MODEL.ENCODER_FORECASTER
        block_state_list = [self._forecaster_rnn_blocks[i].to_split(block_state_list[i])
                            for i in range(len(self._forecaster_rnn_blocks))]
        rnn_block_num = len(CONFIG.RNN_BLOCKS.NUM_FILTER)
        rnn_block_outputs = []
        # RNN Forecaster Part: walk the blocks from coarsest to finest.
        curr_inputs = None
        for i in range(rnn_block_num - 1, -1, -1):
            rnn_out, rnn_state = self._forecaster_rnn_blocks[i].unroll(
                length=self._out_seq_len, inputs=curr_inputs,
                begin_states=block_state_list[i][::-1],  # Reverse the order of states for the forecaster
                ret_mid=False)
            rnn_block_outputs.append(rnn_out)
            if i > 0:
                # NOTE(review): b_h_w here is a 2-tuple (batch, size) while the
                # encoder passes a 3-tuple (batch, h, w) — confirm
                # upsample_module accepts both.
                upsample = upsample_module(data=rnn_out[-1],
                                           num_filter=CONFIG.RNN_BLOCKS.NUM_FILTER[i],
                                           kernel=(CONFIG.UPSAMPLE[i - 1][0],
                                                   CONFIG.UPSAMPLE[i - 1][0]),
                                           stride=(CONFIG.UPSAMPLE[i - 1][1],
                                                   CONFIG.UPSAMPLE[i - 1][1]),
                                           pad=(CONFIG.UPSAMPLE[i - 1][2],
                                                CONFIG.UPSAMPLE[i - 1][2]),
                                           b_h_w=(self._batch_size, CONFIG.FEATMAP_SIZE[i - 1]),
                                           name="fup%d" % i)
                curr_inputs = upsample
        # Output
        if cfg.MODEL.OUT_TYPE == "DFN":
            # Predict per-step 11x11 dynamic filters (121 channels) and apply
            # them autoregressively starting from the last observed frame.
            concat_fbrnn1_out = mx.sym.concat(*rnn_out[-1], dim=0)
            dynamic_filter = deconv2d(data=concat_fbrnn1_out,
                                      num_filter=121,
                                      kernel=(CONFIG.LAST_DECONV[1], CONFIG.LAST_DECONV[1]),
                                      stride=(CONFIG.LAST_DECONV[2], CONFIG.LAST_DECONV[2]),
                                      pad=(CONFIG.LAST_DECONV[3], CONFIG.LAST_DECONV[3]))
            flow = dynamic_filter
            dynamic_filter = mx.sym.SliceChannel(dynamic_filter, axis=0, num_outputs=self._out_seq_len)
            prev_frame = last_frame
            preds = []
            for i in range(self._out_seq_len):
                pred_ele = DFN(data=prev_frame, local_kernels=dynamic_filter[i], K=11, batch_size=self._batch_size)
                preds.append(pred_ele)
                prev_frame = pred_ele
            pred = mx.sym.concat(*preds, dim=0)
        elif cfg.MODEL.OUT_TYPE == "direct":
            # Decode all steps at once with deconv + conv layers.
            flow = None
            deconv1 = deconv2d_act(data=mx.sym.concat(*rnn_out[-1], dim=0),
                                   num_filter=CONFIG.LAST_DECONV[0],
                                   kernel=(CONFIG.LAST_DECONV[1], CONFIG.LAST_DECONV[1]),
                                   stride=(CONFIG.LAST_DECONV[2], CONFIG.LAST_DECONV[2]),
                                   pad=(CONFIG.LAST_DECONV[3], CONFIG.LAST_DECONV[3]),
                                   act_type=cfg.MODEL.CNN_ACT_TYPE,
                                   name="fdeconv1")
            conv_final = conv2d_act(data=deconv1,
                                    num_filter=CONFIG.LAST_DECONV[0],
                                    kernel=(3, 3), stride=(1, 1), pad=(1, 1),
                                    act_type=cfg.MODEL.CNN_ACT_TYPE, name="conv_final")
            pred = conv2d(data=conv_final,
                          num_filter=1, kernel=(1, 1), name="out")
        else:
            raise NotImplementedError
        pred = mx.sym.Reshape(pred,
                              shape=(self._out_seq_len, self._batch_size,
                                     1, self._height, self._width),
                              __layout__="TNCHW")
        return pred, flow

    def encoder_sym(self):
        """Symbol that maps input frames to the flattened encoder RNN states."""
        self.reset_all()
        data = mx.sym.Variable('data')  # Shape: (in_seq_len, batch_size, C, H, W)
        block_state_list = self.stack_rnn_encode(data=data)
        states = []
        for i, rnn_block in enumerate(self._encoder_rnn_blocks):
            states.extend(rnn_block.flatten_add_layout(block_state_list[i]))
        return mx.sym.Group(states)

    def encoder_data_desc(self):
        """DataDesc list for the encoder: input frames + initial states."""
        ret = list()
        ret.append(mx.io.DataDesc(name='data',
                                  shape=(self._in_seq_len,
                                         self._batch_size * self._ctx_num,
                                         1,
                                         self._height,
                                         self._width),
                                  layout="TNCHW"))
        for info in self.init_encoder_state_info:
            state_shape = safe_eval(info['shape'])
            # Batch must be the leading axis so it can be scaled by ctx_num.
            assert info['__layout__'].find('N') == 0,\
                "Layout=%s is not supported!" % info["__layout__"]
            state_shape = (state_shape[0] * self._ctx_num, ) + state_shape[1:]
            ret.append(mx.io.DataDesc(name=info['name'],
                                      shape=state_shape,
                                      layout=info['__layout__']))
        return ret

    def forecaster_sym(self):
        """Symbol that maps forecaster states (+ last frame in DFN mode) to predictions."""
        self.reset_all()
        block_state_list = []
        for block in self._forecaster_rnn_blocks:
            block_state_list.append(block.init_state_vars())

        if cfg.MODEL.OUT_TYPE == "direct":
            pred, _ = self.stack_rnn_forecast(block_state_list=block_state_list,
                                              last_frame=None)
            return mx.sym.Group([pred])
        else:
            last_frame = mx.sym.Variable('last_frame')  # Shape: (batch_size, C, H, W)
            pred, flow = self.stack_rnn_forecast(block_state_list=block_state_list,
                                                 last_frame=last_frame)
            return mx.sym.Group([pred, mx.sym.BlockGrad(flow)])

    def forecaster_data_desc(self):
        """DataDesc list for the forecaster: initial states (+ last frame in DFN mode)."""
        ret = list()
        for info in self.init_forecaster_state_info:
            state_shape = safe_eval(info['shape'])
            assert info['__layout__'].find('N') == 0, \
                "Layout=%s is not supported!" % info["__layout__"]
            state_shape = (state_shape[0] * self._ctx_num,) + state_shape[1:]
            ret.append(mx.io.DataDesc(name=info['name'],
                                      shape=state_shape,
                                      layout=info['__layout__']))
        if cfg.MODEL.OUT_TYPE != "direct":
            ret.append(mx.io.DataDesc(name="last_frame",
                                      shape=(self._ctx_num * self._batch_size,
                                             1, self._height, self._width),
                                      layout="NCHW"))
        return ret

    def loss_sym(self):
        # Subclasses must implement the training loss symbol.
        raise NotImplementedError

    def loss_data_desc(self):
        """DataDesc list for the loss network input (the predictions)."""
        ret = list()
        ret.append(mx.io.DataDesc(name='pred',
                                  shape=(self._out_seq_len,
                                         self._ctx_num * self._batch_size,
                                         1,
                                         self._height,
                                         self._width),
                                  layout="TNCHW"))
        return ret

    def loss_label_desc(self):
        """DataDesc list for the loss labels: target frames and optional mask."""
        ret = list()
        ret.append(mx.io.DataDesc(name='target',
                                  shape=(self._out_seq_len,
                                         self._ctx_num * self._batch_size,
                                         1,
                                         self._height,
                                         self._width),
                                  layout="TNCHW"))
        if cfg.MODEL.ENCODER_FORECASTER.HAS_MASK:
            ret.append(mx.io.DataDesc(name='mask',
                                      shape=(self._out_seq_len,
                                             self._ctx_num * self._batch_size,
                                             1,
                                             self._height,
                                             self._width),
                                      layout="TNCHW"))
        return ret
+ return ret
309
+
310
+
311
+
312
def init_optimizer_using_cfg(net, for_finetune):
    """Attach an optimizer to *net* according to the global config.

    Parameters
    ----------
    net : MyModule
        A bound module on which ``init_optimizer`` is called.
    for_finetune : bool
        If False, use the training settings (cfg.MODEL.TRAIN) with a
        factor-decay LR schedule; if True, use the online finetuning
        settings (cfg.MODEL.TEST.ONLINE) without a schedule.

    Returns
    -------
    net : MyModule
        The same module, returned for chaining.

    Raises
    ------
    NotImplementedError
        If the configured optimizer name is not supported.
    """
    if not for_finetune:
        lr_scheduler = mx.lr_scheduler.FactorScheduler(step=cfg.MODEL.TRAIN.LR_DECAY_ITER,
                                                       factor=cfg.MODEL.TRAIN.LR_DECAY_FACTOR,
                                                       stop_factor_lr=cfg.MODEL.TRAIN.MIN_LR)
        if cfg.MODEL.TRAIN.OPTIMIZER.lower() == "adam":
            net.init_optimizer(optimizer="adam",
                               optimizer_params={'learning_rate': cfg.MODEL.TRAIN.LR,
                                                 'beta1': cfg.MODEL.TRAIN.BETA1,
                                                 'rescale_grad': 1.0,
                                                 'epsilon': cfg.MODEL.TRAIN.EPS,
                                                 'lr_scheduler': lr_scheduler,
                                                 'wd': cfg.MODEL.TRAIN.WD})
        elif cfg.MODEL.TRAIN.OPTIMIZER.lower() == "rmsprop":
            net.init_optimizer(optimizer="rmsprop",
                               optimizer_params={'learning_rate': cfg.MODEL.TRAIN.LR,
                                                 'gamma1': cfg.MODEL.TRAIN.GAMMA1,
                                                 'rescale_grad': 1.0,
                                                 'epsilon': cfg.MODEL.TRAIN.EPS,
                                                 'lr_scheduler': lr_scheduler,
                                                 'wd': cfg.MODEL.TRAIN.WD})
        elif cfg.MODEL.TRAIN.OPTIMIZER.lower() == "sgd":
            net.init_optimizer(optimizer="sgd",
                               optimizer_params={'learning_rate': cfg.MODEL.TRAIN.LR,
                                                 'momentum': 0.0,
                                                 'rescale_grad': 1.0,
                                                 'lr_scheduler': lr_scheduler,
                                                 'wd': cfg.MODEL.TRAIN.WD})
        elif cfg.MODEL.TRAIN.OPTIMIZER.lower() == "adagrad":
            # Note: MXNet's adagrad takes 'eps', not 'epsilon'.
            net.init_optimizer(optimizer="adagrad",
                               optimizer_params={'learning_rate': cfg.MODEL.TRAIN.LR,
                                                 'eps': cfg.MODEL.TRAIN.EPS,
                                                 'rescale_grad': 1.0,
                                                 'wd': cfg.MODEL.TRAIN.WD})
        else:
            raise NotImplementedError
    else:
        if cfg.MODEL.TEST.ONLINE.OPTIMIZER.lower() == "adam":
            net.init_optimizer(optimizer="adam",
                               optimizer_params={'learning_rate': cfg.MODEL.TEST.ONLINE.LR,
                                                 'beta1': cfg.MODEL.TEST.ONLINE.BETA1,
                                                 'rescale_grad': 1.0,
                                                 'epsilon': cfg.MODEL.TEST.ONLINE.EPS,
                                                 'wd': cfg.MODEL.TEST.ONLINE.WD})
        elif cfg.MODEL.TEST.ONLINE.OPTIMIZER.lower() == "rmsprop":
            net.init_optimizer(optimizer="rmsprop",
                               optimizer_params={'learning_rate': cfg.MODEL.TEST.ONLINE.LR,
                                                 'gamma1': cfg.MODEL.TEST.ONLINE.GAMMA1,
                                                 'rescale_grad': 1.0,
                                                 'epsilon': cfg.MODEL.TEST.ONLINE.EPS,
                                                 'wd': cfg.MODEL.TEST.ONLINE.WD})
        elif cfg.MODEL.TEST.ONLINE.OPTIMIZER.lower() == "sgd":
            net.init_optimizer(optimizer="sgd",
                               optimizer_params={'learning_rate': cfg.MODEL.TEST.ONLINE.LR,
                                                 'momentum': 0.0,
                                                 'rescale_grad': 1.0,
                                                 'wd': cfg.MODEL.TEST.ONLINE.WD})
        elif cfg.MODEL.TEST.ONLINE.OPTIMIZER.lower() == "adagrad":
            # BUG FIX: the original read cfg.MODEL.TRAIN.EPS (1e-8) here —
            # every other finetune branch uses the ONLINE settings (1e-6).
            net.init_optimizer(optimizer="adagrad",
                               optimizer_params={'learning_rate': cfg.MODEL.TEST.ONLINE.LR,
                                                 'eps': cfg.MODEL.TEST.ONLINE.EPS,
                                                 'rescale_grad': 1.0,
                                                 'wd': cfg.MODEL.TEST.ONLINE.WD})
        else:
            # CONSISTENCY FIX: the training path raised for unknown optimizers
            # but the finetune path silently did nothing.
            raise NotImplementedError
    return net
+ return net
376
+
377
+
378
def encoder_forecaster_build_networks(factory, context,
                                      shared_encoder_net=None,
                                      shared_forecaster_net=None,
                                      shared_loss_net=None,
                                      for_finetune=False):
    """Bind (and optionally initialize) the encoder, forecaster and loss modules.

    When a shared_* module is given, parameters/memory are shared with it and
    no fresh initialization or optimizer setup is performed for that module.

    Parameters
    ----------
    factory : EncoderForecasterBaseFactory
        Provides the symbols and data descriptions of the three networks.
    context : list
        MXNet device context(s) to bind on.
    shared_encoder_net : MyModule or None
    shared_forecaster_net : MyModule or None
    shared_loss_net : MyModule or None
    for_finetune : bool
        Passed through to init_optimizer_using_cfg to select the online
        finetuning optimizer settings.

    Returns
    -------
    (encoder_net, forecaster_net, loss_net) : tuple of MyModule
    """
    encoder_net = MyModule(factory.encoder_sym(),
                           data_names=[ele.name for ele in factory.encoder_data_desc()],
                           label_names=[],
                           context=context,
                           name="encoder_net")
    encoder_net.bind(data_shapes=factory.encoder_data_desc(),
                     label_shapes=None,
                     inputs_need_grad=True,
                     shared_module=shared_encoder_net)
    if shared_encoder_net is None:
        # MSRA-Prelu init matches the leaky activations used in the network.
        encoder_net.init_params(mx.init.MSRAPrelu(slope=0.2))
        init_optimizer_using_cfg(encoder_net, for_finetune=for_finetune)
    forecaster_net = MyModule(factory.forecaster_sym(),
                              data_names=[ele.name for ele in
                                          factory.forecaster_data_desc()],
                              label_names=[],
                              context=context,
                              name="forecaster_net")
    forecaster_net.bind(data_shapes=factory.forecaster_data_desc(),
                        label_shapes=None,
                        inputs_need_grad=True,
                        shared_module=shared_forecaster_net)
    if shared_forecaster_net is None:
        forecaster_net.init_params(mx.init.MSRAPrelu(slope=0.2))
        init_optimizer_using_cfg(forecaster_net, for_finetune=for_finetune)

    loss_net = MyModule(factory.loss_sym(),
                        data_names=[ele.name for ele in
                                    factory.loss_data_desc()],
                        label_names=[ele.name for ele in
                                     factory.loss_label_desc()],
                        context=context,
                        name="loss_net")
    loss_net.bind(data_shapes=factory.loss_data_desc(),
                  label_shapes=factory.loss_label_desc(),
                  inputs_need_grad=True,
                  shared_module=shared_loss_net)
    if shared_loss_net is None:
        # The loss network has no trainable optimizer state of its own.
        loss_net.init_params()
    return encoder_net, forecaster_net, loss_net
+ return encoder_net, forecaster_net, loss_net
438
+
439
+
440
class EncoderForecasterStates(object):
    """Holds the recurrent state NDArrays shared between encoder and forecaster.

    The same buffers serve both roles: the encoder writes its final states
    here (via update), and get_forecaster_state hands the identical arrays to
    the forecaster.
    """

    def __init__(self, factory, ctx):
        self._factory = factory
        self._ctx = ctx
        self._encoder_state_info = factory.init_encoder_state_info
        self._forecaster_state_info = factory.init_forecaster_state_info
        self._states_nd = []
        for info in self._encoder_state_info:
            state_shape = safe_eval(info['shape'])
            # Leading (batch) axis is replicated across devices.
            state_shape = (state_shape[0] * factory._ctx_num, ) + state_shape[1:]
            self._states_nd.append(mx.nd.zeros(shape=state_shape, ctx=ctx))

    def reset_all(self):
        # Zero every state buffer (start of a new sequence for all samples).
        for ele, info in zip(self._states_nd, self._encoder_state_info):
            ele[:] = 0

    def reset_batch(self, batch_id):
        # Zero the states of a single sample within the batch.
        for ele, info in zip(self._states_nd, self._encoder_state_info):
            ele[batch_id][:] = 0

    def update(self, states_nd):
        # Copy freshly computed encoder states into the persistent buffers.
        for target, src in zip(self._states_nd, states_nd):
            target[:] = src

    def get_encoder_states(self):
        return self._states_nd

    def get_forecaster_state(self):
        # Intentionally the same arrays as the encoder states: the encoder's
        # final states initialize the forecaster.
        return self._states_nd
+ return self._states_nd
469
+
470
+
471
def train_step(batch_size, encoder_net, forecaster_net,
               loss_net, init_states,
               data_nd, gt_nd, mask_nd, iter_id=None):
    """Finetune the encoder, forecaster and GAN for one step

    Runs one optimization step across the three chained modules:
    encoder forward -> forecaster forward -> loss forward/backward ->
    forecaster backward -> encoder backward -> gradient clipping -> update.
    The statement order matters: each backward pass consumes gradients
    produced by the stage after it.

    Parameters
    ----------
    batch_size : int
        Unused inside this function; kept for interface compatibility.
    encoder_net : MyModule
    forecaster_net : MyModule
    loss_net : MyModule
    init_states : EncoderForecasterStates
        Shared state buffers; updated in place with the encoder outputs.
    data_nd : mx.nd.ndarray
        Input sequence. NOTE(review): indexing ``data_nd[data_nd.shape[0] - 1]``
        as the "last frame" implies axis 0 is time — confirm layout with caller.
    gt_nd : mx.nd.ndarray
        Ground-truth target sequence for the loss.
    mask_nd : mx.nd.ndarray
        Validity mask; only used when cfg.MODEL.ENCODER_FORECASTER.HAS_MASK.
    iter_id : int
        If given, per-iteration losses and gradient norms are logged.

    Returns
    -------
    init_states: EncoderForecasterStates
        The same object passed in, now holding the encoder's output states.
    loss_dict: dict
        Mapping from loss name to its mean value as a Python scalar.
    """
    # Forward Encoder: input frames plus the current recurrent states.
    encoder_net.forward(is_train=True,
                        data_batch=mx.io.DataBatch(data=[data_nd] + init_states.get_encoder_states()))
    encoder_states_nd = encoder_net.get_outputs()
    # Persist encoder outputs into the shared buffers the forecaster reads.
    init_states.update(encoder_states_nd)
    # Forward Forecaster: "direct" mode consumes only the states; otherwise the
    # last observed frame is appended as an extra input.
    if cfg.MODEL.OUT_TYPE == "direct":
        forecaster_net.forward(is_train=True,
                               data_batch=mx.io.DataBatch(data=init_states.get_forecaster_state()))
    else:
        last_frame_nd = data_nd[data_nd.shape[0] - 1]
        forecaster_net.forward(is_train=True,
                               data_batch=mx.io.DataBatch(data=init_states.get_forecaster_state() +
                                                               [last_frame_nd]))
    forecaster_outputs = forecaster_net.get_outputs()
    pred_nd = forecaster_outputs[0]

    # Calculate the gradient of the loss functions; the mask label is only
    # supplied when the configuration declares a mask.
    if cfg.MODEL.ENCODER_FORECASTER.HAS_MASK:
        loss_net.forward_backward(data_batch=mx.io.DataBatch(data=[pred_nd],
                                                             label=[gt_nd, mask_nd]))
    else:
        loss_net.forward_backward(data_batch=mx.io.DataBatch(data=[pred_nd],
                                                             label=[gt_nd]))
    # Gradient of the loss w.r.t. the prediction, to be pushed back through
    # the forecaster.
    pred_grad = loss_net.get_input_grads()[0]
    loss_dict = loss_net.get_output_dict()
    # Reduce each loss output to a Python scalar for logging/reporting.
    for k in loss_dict:
        loss_dict[k] = nd.mean(loss_dict[k]).asscalar()
    # Backward Forecaster
    forecaster_net.backward(out_grads=[pred_grad])
    if cfg.MODEL.OUT_TYPE == "direct":
        encoder_states_grad_nd = forecaster_net.get_input_grads()
    else:
        # Drop the gradient of the appended last frame; only the state
        # gradients flow back into the encoder.
        encoder_states_grad_nd = forecaster_net.get_input_grads()[:-1]
    # Backward Encoder
    encoder_net.backward(encoder_states_grad_nd)
    # Update forecaster and encoder: clip by global norm first, then apply the
    # optimizer step for each module.
    forecaster_grad_norm = forecaster_net.clip_by_global_norm(max_norm=cfg.MODEL.TRAIN.GRAD_CLIP)
    encoder_grad_norm = encoder_net.clip_by_global_norm(max_norm=cfg.MODEL.TRAIN.GRAD_CLIP)
    forecaster_net.update()
    encoder_net.update()
    loss_str = ", ".join(["%s=%g" %(k, v) for k, v in loss_dict.items()])
    if iter_id is not None:
        logging.info("Iter:%d, %s, e_gnorm=%g, f_gnorm=%g"
                     % (iter_id, loss_str, encoder_grad_norm, forecaster_grad_norm))
    return init_states, loss_dict
539
+
540
+
541
def load_encoder_forecaster_params(load_dir, load_iter, encoder_net, forecaster_net):
    """Restore encoder and forecaster parameters from checkpoint files.

    Reads the "encoder_net" and "forecaster_net" checkpoints saved under
    ``load_dir`` at iteration ``load_iter`` and force-initializes each module
    with them, requiring every parameter to be present.

    Parameters
    ----------
    load_dir : str
        Directory containing the checkpoint files.
    load_iter : int
        Checkpoint iteration (epoch) number to load.
    encoder_net : MyModule
    forecaster_net : MyModule
    """
    logging.info("Loading parameters from {}, Iter = {}"
                 .format(os.path.realpath(load_dir), load_iter))
    # Encoder and forecaster follow the identical load/init sequence, so
    # drive both from one loop over (checkpoint prefix, module) pairs.
    for prefix_name, net in (("encoder_net", encoder_net),
                             ("forecaster_net", forecaster_net)):
        arg_params, aux_params = load_params(prefix=os.path.join(load_dir, prefix_name),
                                             epoch=load_iter)
        net.init_params(arg_params=arg_params,
                        aux_params=aux_params,
                        allow_missing=False,
                        force_init=True)
    logging.info("Loading Complete!")
nowcasting/helpers/__init__.py ADDED
File without changes
nowcasting/helpers/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (167 Bytes). View file
 
nowcasting/helpers/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (162 Bytes). View file