mgonzc commited on
Commit
ee69210
·
verified ·
1 Parent(s): b6ae318

Upload 32 files

Browse files
Files changed (33) hide show
  1. .gitattributes +16 -0
  2. GANFilling/.gitattributes +1 -0
  3. GANFilling/.gitignore +3 -0
  4. GANFilling/GANFilling_percept_results.png +3 -0
  5. GANFilling/README.md +62 -0
  6. GANFilling/data/landcover_types.csv +15 -0
  7. GANFilling/data/test_data/context_29TQF_2017-08-16_2018-01-12_5049_5177_441_569_78_158_6_86.npz +3 -0
  8. GANFilling/data/test_data/context_30UYV_2018-06-19_2018-11-15_1977_2105_2617_2745_30_110_40_120.npz +3 -0
  9. GANFilling/data/test_data/context_33SXD_2019-05-16_2019-10-12_4025_4153_3129_3257_62_142_48_128.npz +3 -0
  10. GANFilling/data/test_data/context_34TDP_2018-04-28_2018-09-24_2233_2361_5049_5177_34_114_78_158.npz +3 -0
  11. GANFilling/data/test_data/landcover/29TQF_2017-08-16_2018-01-12_5049_5177_441_569_78_158_6_86.npz +0 -0
  12. GANFilling/data/test_data/landcover/30UYV_2018-06-19_2018-11-15_1977_2105_2617_2745_30_110_40_120.npz +0 -0
  13. GANFilling/data/test_data/landcover/33SXD_2019-05-16_2019-10-12_4025_4153_3129_3257_62_142_48_128.npz +0 -0
  14. GANFilling/data/test_data/landcover/34TDP_2018-04-28_2018-09-24_2233_2361_5049_5177_34_114_78_158.npz +0 -0
  15. GANFilling/data/train_data/33UXQ/33UXQ_2018-03-30_2018-08-26_2873_3001_3513_3641_44_124_54_134.npz +3 -0
  16. GANFilling/data/train_data/33UXQ/33UXQ_2018-06-18_2018-11-14_2105_2233_3769_3897_32_112_58_138.npz +3 -0
  17. GANFilling/data/train_data/33UXQ/33UXQ_2018-06-18_2018-11-14_2745_2873_3897_4025_42_122_60_140.npz +3 -0
  18. GANFilling/data/train_data/33UXQ/33UXQ_2018-07-08_2018-12-04_2617_2745_3513_3641_40_120_54_134.npz +3 -0
  19. GANFilling/data/train_data/33UXQ/33UXQ_2018-07-08_2018-12-04_697_825_4793_4921_10_90_74_154.npz +3 -0
  20. GANFilling/data/train_data/33UXQ/33UXQ_2018-07-18_2018-12-14_825_953_4537_4665_12_92_70_150.npz +3 -0
  21. GANFilling/data/train_data/33UXQ/33UXQ_2018-07-28_2018-12-24_4025_4153_2105_2233_62_142_32_112.npz +3 -0
  22. GANFilling/results/context_29TQF_2017-08-16_2018-01-12_5049_5177_441_569_78_158_6_86.npz.png +3 -0
  23. GANFilling/results/context_30UYV_2018-06-19_2018-11-15_1977_2105_2617_2745_30_110_40_120.npz.png +3 -0
  24. GANFilling/results/context_33SXD_2019-05-16_2019-10-12_4025_4153_3129_3257_62_142_48_128.npz.png +3 -0
  25. GANFilling/results/context_34TDP_2018-04-28_2018-09-24_2233_2361_5049_5177_34_114_78_158.npz.png +3 -0
  26. GANFilling/src/iterator.py +111 -0
  27. GANFilling/src/models/convlstm.py +200 -0
  28. GANFilling/src/models/discriminator.py +60 -0
  29. GANFilling/src/models/generator.py +123 -0
  30. GANFilling/src/test.py +138 -0
  31. GANFilling/src/train.py +258 -0
  32. GANFilling/src/utils/example_clean_data.json +83 -0
  33. GANFilling/src/utils/generate_cleanData_file.py +84 -0
.gitattributes ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ GANFilling/data/test_data/context_29TQF_2017-08-16_2018-01-12_5049_5177_441_569_78_158_6_86.npz filter=lfs diff=lfs merge=lfs -text
2
+ GANFilling/data/test_data/context_30UYV_2018-06-19_2018-11-15_1977_2105_2617_2745_30_110_40_120.npz filter=lfs diff=lfs merge=lfs -text
3
+ GANFilling/data/test_data/context_33SXD_2019-05-16_2019-10-12_4025_4153_3129_3257_62_142_48_128.npz filter=lfs diff=lfs merge=lfs -text
4
+ GANFilling/data/test_data/context_34TDP_2018-04-28_2018-09-24_2233_2361_5049_5177_34_114_78_158.npz filter=lfs diff=lfs merge=lfs -text
5
+ GANFilling/data/train_data/33UXQ/33UXQ_2018-03-30_2018-08-26_2873_3001_3513_3641_44_124_54_134.npz filter=lfs diff=lfs merge=lfs -text
6
+ GANFilling/data/train_data/33UXQ/33UXQ_2018-06-18_2018-11-14_2105_2233_3769_3897_32_112_58_138.npz filter=lfs diff=lfs merge=lfs -text
7
+ GANFilling/data/train_data/33UXQ/33UXQ_2018-06-18_2018-11-14_2745_2873_3897_4025_42_122_60_140.npz filter=lfs diff=lfs merge=lfs -text
8
+ GANFilling/data/train_data/33UXQ/33UXQ_2018-07-08_2018-12-04_2617_2745_3513_3641_40_120_54_134.npz filter=lfs diff=lfs merge=lfs -text
9
+ GANFilling/data/train_data/33UXQ/33UXQ_2018-07-08_2018-12-04_697_825_4793_4921_10_90_74_154.npz filter=lfs diff=lfs merge=lfs -text
10
+ GANFilling/data/train_data/33UXQ/33UXQ_2018-07-18_2018-12-14_825_953_4537_4665_12_92_70_150.npz filter=lfs diff=lfs merge=lfs -text
11
+ GANFilling/data/train_data/33UXQ/33UXQ_2018-07-28_2018-12-24_4025_4153_2105_2233_62_142_32_112.npz filter=lfs diff=lfs merge=lfs -text
12
+ GANFilling/GANFilling_percept_results.png filter=lfs diff=lfs merge=lfs -text
13
+ GANFilling/results/context_29TQF_2017-08-16_2018-01-12_5049_5177_441_569_78_158_6_86.npz.png filter=lfs diff=lfs merge=lfs -text
14
+ GANFilling/results/context_30UYV_2018-06-19_2018-11-15_1977_2105_2617_2745_30_110_40_120.npz.png filter=lfs diff=lfs merge=lfs -text
15
+ GANFilling/results/context_33SXD_2019-05-16_2019-10-12_4025_4153_3129_3257_62_142_48_128.npz.png filter=lfs diff=lfs merge=lfs -text
16
+ GANFilling/results/context_34TDP_2018-04-28_2018-09-24_2233_2361_5049_5177_34_114_78_158.npz.png filter=lfs diff=lfs merge=lfs -text
GANFilling/.gitattributes ADDED
@@ -0,0 +1 @@
 
 
1
+ *.pt filter=lfs diff=lfs merge=lfs -text
GANFilling/.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # pycache
2
+ __pycache__
3
+ */__pycache__
GANFilling/GANFilling_percept_results.png ADDED

Git LFS Details

  • SHA256: 9c4ba1119d2ec1f9e06701f4bdb5dce36fa09246ff9e09eb6aa9cfd4e84220e5
  • Pointer size: 132 Bytes
  • Size of remote file: 4.8 MB
GANFilling/README.md ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Generative Networks for Spatio-Temporal Gap Filling of Sentinel-2 Reflectances
2
+ ---------------
3
+ | [Journal ISPRS paper](https://doi.org/10.1016/j.isprsjprs.2025.01.016) |
4
+
5
+ # Abstract
6
+
7
+ Earth observation from satellite sensors offers the possibility to monitor natural ecosystems by deriving spatially explicit and temporally resolved biogeophysical parameters. Optical remote sensing, however, suffers from missing data mainly due to the presence of clouds, sensor malfunctioning, and atmospheric conditions. This study proposes a novel deep learning architecture to address gap filling of satellite reflectances, more precisely the visible and near-infrared bands, and illustrates its performance at high-resolution Sentinel-2 data. We introduce GANFilling, a generative adversarial network capable of sequence-to-sequence translation, which comprises convolutional long short-term memory layers to effectively exploit complete dependencies in space-time series data. We focus on Europe and evaluate the method's performance quantitatively (through distortion and perceptual metrics) and qualitatively (via visual inspection and visual quality metrics). Quantitatively, our model offers the best trade-off between denoising corrupted data and preserving noise-free information, underscoring the importance of considering multiple metrics jointly when assessing gap filling tasks. Qualitatively, it successfully deals with various noise sources, such as clouds and missing data, constituting a robust solution to multiple scenarios and settings. We also illustrate and quantify the quality of the generated product in the relevant downstream application of vegetation greenness forecasting, where using GANFilling enhances forecasting in approximately 70% of the considered regions in Europe. This research contributes to underlining the utility of deep learning for Earth observation data, which allows for improved spatially and temporally resolved monitoring of the Earth surface.
8
+
9
+ # Usage and Requirements
10
+
11
+ ## Train
12
+
13
+ ```
14
+ python src/train.py --dataPath data/train_data --cleanDataPath src/utils/example_clean_data.json --name model_1
15
+ ```
16
+
17
+ Models are saved to ./trained_models/ (can be changed by passing --modelsPath=your_dir in train.py).
18
+
19
+ For training the discriminator, a json file with the noiseless (real) data has to be generated. This can be done with ./utils/generate_cleanData_file.py. An example of the structure of this file can be found in ./utils/example_clean_data.json.
20
+
21
+ The folder with the training data is specified with the `--dataPath` argument. This repository just includes few samples in data/train_data/ as a reference for structure and behaviour checking.
22
+
23
+ ## Test
24
+
25
+ ```
26
+ python src/test.py --model_path trained_models/GANFilling.pt --data_path data --results_path results
27
+ ```
28
+
29
+ This will run the GANFilling trained model on a small set of examples and generate the corresponding gap filled time series.
30
+
31
+ To test your own model modify the parameter `--model_path` in test.py.
32
+ To test on your own data modify the parameter `--data_path` in test.py.
33
+
34
+ # Results
35
+
36
+ ![image](GANFilling_percept_results.png)
37
+
38
+ <b>Example images showing the GANFilling reconstruction on different land covers.</b> For each example, the first row shows the land cover map followed by ten original time steps of a visible (RGB) sequence, while the second row corresponds to its noise-free version. All images are noted with the day-of-year (DOY). The third row illustrates the NDVI maps for the noise-free images. Different types of noise are outlined in red. (A) Complexe scene with predominant herbaceous vegetation and multiple frames with complete loss of information. (B) Sequence with mostly cultivated areas to show the performance on fast changes in the Earth’s surface with heavily occluded frames. (C) Sequence with widespread vines characterized by a rapid evolution of the land cover. (D) Predominant coniferous tree cover with a water body nearby. (E) Sequence with predominant broadleaf tree cover and several consecutive time steps with dense occlusions. Land cover’s legend: <span style="color:rgb(255,255,255)">&#9723;</span> Cloud or No data, <span style="color:rgb(210,0,0)">&#9724;</span> Artificial surfaces and constructions, <span style="color:rgb(253,211,39)">&#9724;</span> Cultivated areas, <span style="color:rgb(176,91,16)">&#9724;</span> Vineyards, <span style="color:rgb(35,152,0)">&#9724;</span> Broadleaf tree cover, <span style="color:rgb(8,98,0)">&#9724;</span> Coniferous tree cover, <span style="color:rgb(249,150,39)">&#9724;</span> Herbaceous vegetation, <span style="color:rgb(141,139,0)">&#9724;</span> Moors and Heathland, <span style="color:rgb(95,53,6)">&#9724;</span> Sclerophyllous vegetation, <span style="color:rgb(149,107,196)">&#9724;</span> Marshes, <span style="color:rgb(77,37,106)">&#9724;</span> Peatbogs, <span style="color:rgb(154,154,154)">&#9724;</span> Natural material surfaces, <span style="color:rgb(106,255,255)">&#9724;</span> Permanent snow covered surfaces, <span style="color:rgb(20,69,249)">&#9724;</span> Water bodies.
39
+
40
+ # How to cite
41
+
42
+ If you use this code for your research, please cite our paper Generative Networks for Spatio-Temporal Gap Filling of Sentinel-2 Reflectances:
43
+
44
+ ```
45
+ @article{GonzalezCalabuig2025,
46
+ title = {Generative networks for spatio-temporal gap filling of Sentinel-2 reflectances},
47
+ journal = {ISPRS Journal of Photogrammetry and Remote Sensing},
48
+ volume = {220},
49
+ pages = {637-648},
50
+ year = {2025},
51
+ issn = {0924-2716},
52
+ doi = {https://doi.org/10.1016/j.isprsjprs.2025.01.016},
53
+ author = {Maria Gonzalez-Calabuig and Miguel-Ángel Fernández-Torres and Gustau Camps-Valls}}
54
+ ```
55
+
56
+ # Acknowledgments
57
+
58
+ Authors acknowledge the support from the European Research Council (ERC) under the ERC Synergy Grant USMILE (grant agreement 855187), the European Union's Horizon 2020 research and innovation program within the projects 'XAIDA: Extreme Events - Artificial Intelligence for Detection and Attribution' (grant agreement 101003469), 'DeepCube: Explainable AI pipelines for big Copernicus data' (grant agreement 101004188), the ESA AI4Science project 'MultiHazards, Compounds and Cascade events: DeepExtremes', 2022-2024, the computer resources provided by the Jülich Supercomputing Centre (JSC) (Project No. PRACE-DEV-2022D01-048), the computer resources provided by Artemisa (funded by the European Union ERDF and Comunitat Valenciana), as well as the technical support provided by the Instituto de Física Corpuscular, IFIC (CSIC-UV).
59
+
60
+ # License
61
+
62
+ [MIT]()
GANFilling/data/landcover_types.csv ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 0,255,255,255,255,Clouds or No data
2
+ 62,210,0,0,255,Artificial surfaces and constructions
3
+ 73,253,211,39,255,Cultivated areas
4
+ 75,176,91,16,255,Vineyards
5
+ 82,35,152,0,255,Broadleaf tree cover
6
+ 83,8,98,0,255,Coniferous tree cover
7
+ 102,249,150,39,255,Herbaceous vegetation
8
+ 103,141,139,0,255,Moors and Heathland
9
+ 104,95,53,6,255,Sclerophyllous vegetation
10
+ 105,149,107,196,255,Marshes
11
+ 106,77,37,106,255,Peatbogs
12
+ 121,154,154,154,255,Natural material surfaces
13
+ 123,106,255,255,255,Permanent snow covered surfaces
14
+ 162,20,69,249,255,Water bodies
15
+ 255,255,255,255,255,No data
GANFilling/data/test_data/context_29TQF_2017-08-16_2018-01-12_5049_5177_441_569_78_158_6_86.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8ff96d1e04dc8f23e1caa09dfb3b65adc4111393729ed3b963a87b1b5185c048
3
+ size 4735175
GANFilling/data/test_data/context_30UYV_2018-06-19_2018-11-15_1977_2105_2617_2745_30_110_40_120.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:603f455b64f8044910fa78b6dee26f335b924f2fb6f9cbf03539bfbae6c4f037
3
+ size 4198884
GANFilling/data/test_data/context_33SXD_2019-05-16_2019-10-12_4025_4153_3129_3257_62_142_48_128.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e15904f0feee4a86d1ef931d97d34a3b5b7d4e855aed218fad0bfd8dd50ca839
3
+ size 3328672
GANFilling/data/test_data/context_34TDP_2018-04-28_2018-09-24_2233_2361_5049_5177_34_114_78_158.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:716193f3fea3641ce5c94b4169fb047f6ac224ae5d13720cf0a413b2520322fe
3
+ size 5594634
GANFilling/data/test_data/landcover/29TQF_2017-08-16_2018-01-12_5049_5177_441_569_78_158_6_86.npz ADDED
Binary file (4.23 kB). View file
 
GANFilling/data/test_data/landcover/30UYV_2018-06-19_2018-11-15_1977_2105_2617_2745_30_110_40_120.npz ADDED
Binary file (2.7 kB). View file
 
GANFilling/data/test_data/landcover/33SXD_2019-05-16_2019-10-12_4025_4153_3129_3257_62_142_48_128.npz ADDED
Binary file (3.6 kB). View file
 
GANFilling/data/test_data/landcover/34TDP_2018-04-28_2018-09-24_2233_2361_5049_5177_34_114_78_158.npz ADDED
Binary file (3.32 kB). View file
 
GANFilling/data/train_data/33UXQ/33UXQ_2018-03-30_2018-08-26_2873_3001_3513_3641_44_124_54_134.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:372e0bf759d6856b57dd9a5295b9cb03d1ad1507940d4d8b6be582d6d09f4939
3
+ size 8178516
GANFilling/data/train_data/33UXQ/33UXQ_2018-06-18_2018-11-14_2105_2233_3769_3897_32_112_58_138.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f7127004d5dd01befad7f786fea024f522a75fb8942a49a220b7843a36e1fb7f
3
+ size 8292551
GANFilling/data/train_data/33UXQ/33UXQ_2018-06-18_2018-11-14_2745_2873_3897_4025_42_122_60_140.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a9bcebacc479854a4715c8096a7ff6d90de208f2e6df0259a3b594390902b28d
3
+ size 8150040
GANFilling/data/train_data/33UXQ/33UXQ_2018-07-08_2018-12-04_2617_2745_3513_3641_40_120_54_134.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c9ff886ee2e13a9201a3ec2341bea317ba6eb28bc944dd458e5c51b2c94603d2
3
+ size 8096535
GANFilling/data/train_data/33UXQ/33UXQ_2018-07-08_2018-12-04_697_825_4793_4921_10_90_74_154.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bd8ce22607591b366ba2b4ac61a4df19b0fd58c2f7cfaa94504f6d6599134e2f
3
+ size 8291712
GANFilling/data/train_data/33UXQ/33UXQ_2018-07-18_2018-12-14_825_953_4537_4665_12_92_70_150.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1d522b678b06e37d7ce62b4d898bed0fec187a72711efe3c7fc4d7dada43b083
3
+ size 8400326
GANFilling/data/train_data/33UXQ/33UXQ_2018-07-28_2018-12-24_4025_4153_2105_2233_62_142_32_112.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:143075cf0da3467a2d25ff4a69f290cf497f43b846147bb9b13167b5c4710068
3
+ size 7970874
GANFilling/results/context_29TQF_2017-08-16_2018-01-12_5049_5177_441_569_78_158_6_86.npz.png ADDED

Git LFS Details

  • SHA256: 09fcd016fd5d10d900be0b6d5dbc2d31bce9e98899e1ff3b78f11e8141fab8f4
  • Pointer size: 131 Bytes
  • Size of remote file: 641 kB
GANFilling/results/context_30UYV_2018-06-19_2018-11-15_1977_2105_2617_2745_30_110_40_120.npz.png ADDED

Git LFS Details

  • SHA256: fc099683575f4ea0b6945d3adec761be8168127b0a57ebdbeb452eae03398ee9
  • Pointer size: 131 Bytes
  • Size of remote file: 617 kB
GANFilling/results/context_33SXD_2019-05-16_2019-10-12_4025_4153_3129_3257_62_142_48_128.npz.png ADDED

Git LFS Details

  • SHA256: cf35299a6eba5567134813f810ccafd038b39e7b9a2e67098dd96d27976b69f3
  • Pointer size: 131 Bytes
  • Size of remote file: 670 kB
GANFilling/results/context_34TDP_2018-04-28_2018-09-24_2233_2361_5049_5177_34_114_78_158.npz.png ADDED

Git LFS Details

  • SHA256: cbf824604b5c5b5c0aeed9efd8743167f4f0fb9d5bacd9db481c0a0e0d7c0bd0
  • Pointer size: 131 Bytes
  • Size of remote file: 541 kB
GANFilling/src/iterator.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import json
3
+ import torch
4
+ import random
5
+ import os
6
+
7
+ from kornia.augmentation import RandomHorizontalFlip, RandomVerticalFlip
8
+
9
+
10
+ class Iterator():
11
+ def __init__(self, dataPath, cleanDataPath, mode):
12
+ self.mode = mode
13
+ self.random_seed_data = 42
14
+ self.__initData(dataPath, cleanDataPath)
15
+
16
+ def __initData(self, dataPath, cleanDataPath):
17
+ self.data = {}
18
+
19
+ if self.mode == 'train':
20
+ self.__loadTrainData(dataPath)
21
+ self.cleanData = json.load(open(cleanDataPath))
22
+
23
+ def __loadTrainData(self, dataPath):
24
+ tiles = os.listdir(dataPath)
25
+ if 'LICENSE' in tiles:
26
+ tiles.remove('LICENSE')
27
+ tiles.sort()
28
+
29
+ loadedData = []
30
+ for tile in tiles:
31
+ in_tile_path = dataPath / tile
32
+ files = os.listdir(in_tile_path)
33
+ files.sort()
34
+ in_files = []
35
+ for file in files:
36
+ in_files.append(os.path.join(in_tile_path, file))
37
+ random.Random(self.random_seed_data).shuffle(in_files)
38
+ for f in in_files[:int(len(in_files)*0.8)]:
39
+ loadedData.append(f)
40
+ self.data = loadedData
41
+
42
+ def __len__(self):
43
+ return len(self.data)
44
+
45
+ def __getMasks(self, sample, index):
46
+ return sample["highresdynamic"][:,:,-1,:index+1]
47
+
48
+ def __find_closest_area(self, condition_1):
49
+ self.cleanData.keys()
50
+ if condition_1 in self.cleanData.keys():
51
+ return condition_1
52
+ else:
53
+ if str(int(condition_1[:2])+1) + condition_1[-1] in self.cleanData.keys():
54
+ return str(int(condition_1[:2])+1) + condition_1[-1]
55
+ elif str(int(condition_1[:2])-1) + condition_1[-1] in self.cleanData.keys():
56
+ return str(int(condition_1[:2])-1) + condition_1[-1]
57
+
58
+ def __getCleanSequence(self, condition_1, condition_2):
59
+ time_lenght = 4
60
+ count_stuck = 0
61
+ condition_1 = self.__find_closest_area(condition_1)
62
+ data = self.cleanData[condition_1]
63
+ rand_number = torch.randint(len(data), (1,)).item()
64
+ attributes = data[rand_number]
65
+ while len(attributes['time steps']) < time_lenght:
66
+ rand_number = torch.randint(len(data), (1,)).item()
67
+ attributes = data[rand_number]
68
+ # Check end of the loop:
69
+ '''count_stuck += 1
70
+ if count_stuck >= 2:
71
+ print('Stuck ', count_stuck)'''
72
+
73
+ kernelSize = attributes['kernel size']
74
+ sample = np.load(attributes['path'])
75
+ x_min = attributes['bbox'][1]
76
+ x_max = attributes['bbox'][1]+kernelSize
77
+ y_min = attributes['bbox'][0]
78
+ y_max = attributes['bbox'][0]+kernelSize
79
+
80
+ discriminator_sample = sample["highresdynamic"][y_min:y_max,x_min:x_max, 0:4, attributes['time steps'][0]:attributes['time steps'][0]+time_lenght]
81
+ discriminator_sample = torch.from_numpy(discriminator_sample)
82
+
83
+ ## Data augmentation: Horitzontal Flip
84
+ transformation_1 = RandomHorizontalFlip(p=0.4)
85
+ discriminator_sample[:,:,:,0] = transformation_1(discriminator_sample[:,:,:,0])
86
+ for t in range(1, discriminator_sample.shape[3]):
87
+ discriminator_sample[:,:,:,t] = transformation_1(discriminator_sample[:,:,:,t], params=transformation_1._params)
88
+
89
+ ## Data augmentation: Vertical Flip
90
+ transformation_2 = RandomVerticalFlip(p=0.4)
91
+ discriminator_sample[:,:,:,0] = transformation_2(discriminator_sample[:,:,:,0])
92
+ for t in range(1, discriminator_sample.shape[3]):
93
+ discriminator_sample[:,:,:,t] = transformation_2(discriminator_sample[:,:,:,t], params=transformation_2._params)
94
+
95
+ return discriminator_sample.numpy()
96
+
97
+
98
+ def __getitem__(self, index):
99
+ self.index = index
100
+ sample = np.load(self.data[index])
101
+
102
+ context = 10
103
+ noisyImg = sample["highresdynamic"][:,:,0:4,:context]
104
+
105
+ masks = self.__getMasks(sample, context-1)
106
+ cleanImg = self.__getCleanSequence(self.data[index].split('/')[3][:3], None)
107
+
108
+ noisyImg = np.nan_to_num(np.clip(noisyImg, 0, 1), nan=1.0)
109
+ cleanImg = np.nan_to_num(np.clip(cleanImg, 0, 1), nan=1.0)
110
+
111
+ return np.transpose(noisyImg, (2,0,1,3)), np.transpose(cleanImg, (2,0,1,3)), masks
GANFilling/src/models/convlstm.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn.functional as F
2
+ import torch.nn as nn
3
+ import torch
4
+
5
+
6
class ConvLSTMCell(nn.Module):
    """A single convolutional LSTM cell: one step of a ConvLSTM, with all
    four gates computed by a single convolution over [input, hidden]."""

    def __init__(self, input_dim, hidden_dim, kernel_size, bias, device=None):
        """
        Initialize ConvLSTM cell.

        Parameters
        ----------
        input_dim: int
            Number of channels of input tensor.
        hidden_dim: int
            Number of channels of hidden state.
        kernel_size: (int, int)
            Size of the convolutional kernel.
        bias: bool
            Whether or not to add the bias.
        device: torch.device or str, optional
            Device for the lazily-created zero states.  When omitted, the
            device of the cell's convolution weights is used.  (Previously
            `device` was required, so `ConvLSTM.__init__`, which builds
            cells without it, raised a TypeError.)
        """

        super(ConvLSTMCell, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.kernel_size = kernel_size
        # "same" padding so the gate convolution preserves H and W.
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias
        self.device = device

        # One convolution producing all four gates (i, f, o, g) at once.
        self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
                              out_channels=4 * self.hidden_dim,
                              kernel_size=self.kernel_size,
                              padding=self.padding,
                              bias=self.bias)

    def __initStates(self, size):
        # Fall back to the conv weights' device when no device was given.
        device = self.device if self.device is not None else self.conv.weight.device
        return torch.zeros(size, device=device), torch.zeros(size, device=device)

    def forward(self, input_tensor, cur_state):
        """One ConvLSTM step; returns (h_next, c_next).

        `cur_state` is a (h, c) pair, or None to start from zero states
        matching the input's batch and spatial size.
        """
        if cur_state is None:  # was `== None`
            h_cur, c_cur = self.__initStates([input_tensor.shape[0], self.hidden_dim, input_tensor.shape[2], input_tensor.shape[3]])
        else:
            h_cur, c_cur = cur_state

        combined = torch.cat([input_tensor, h_cur], dim=1)  # concatenate along channel axis
        combined_conv = self.conv(combined)
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)

        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)

        c_next = f * c_cur + i * g
        h_next = o * torch.tanh(c_next)

        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        """Return zero (h, c) states on the conv weights' device."""
        height, width = image_size
        return (torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device),
                torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device))
68
+
69
+
70
class ConvLSTM(nn.Module):

    """
    Multi-layer convolutional LSTM.

    Parameters:
        input_dim: Number of channels in input
        hidden_dim: Number of hidden channels
        kernel_size: Size of kernel in convolutions
        num_layers: Number of LSTM layers stacked on each other
        batch_first: Whether or not dimension 0 is the batch or not
        bias: Bias or no bias in Convolution
        return_all_layers: Return the list of computations for all layers
        device: Optional device forwarded to each ConvLSTMCell for its
            lazily-created states (default None).  Previously the cells were
            constructed without the required `device` argument, which raised
            a TypeError on instantiation.
    Note: Will do same padding.

    Input:
        A tensor of size B, T, C, H, W or T, B, C, H, W
    Output:
        A tuple of two lists of length num_layers (or length 1 if return_all_layers is False).
            0 - layer_output_list is the list of lists of length T of each output
            1 - last_state_list is the list of last states
                    each element of the list is a tuple (h, c) for hidden state and memory
    Example:
        >> x = torch.rand((32, 10, 64, 128, 128))
        >> convlstm = ConvLSTM(64, 16, 3, 1, True, True, False)
        >> _, last_states = convlstm(x)
        >> h = last_states[0][0]  # 0 for layer index, 0 for h index
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, num_layers,
                 batch_first=False, bias=True, return_all_layers=False, device=None):
        super(ConvLSTM, self).__init__()

        self._check_kernel_size_consistency(kernel_size)

        # Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers
        kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
        hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
        if not len(kernel_size) == len(hidden_dim) == num_layers:
            raise ValueError('Inconsistent list length.')

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.bias = bias
        self.return_all_layers = return_all_layers

        cell_list = []
        for i in range(0, self.num_layers):
            # Layer 0 consumes the input; deeper layers consume the previous
            # layer's hidden state.
            cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]

            # `device=device` fixes the missing required argument of
            # ConvLSTMCell.__init__ (None lets the cell pick its own device).
            cell_list.append(ConvLSTMCell(input_dim=cur_input_dim,
                                          hidden_dim=self.hidden_dim[i],
                                          kernel_size=self.kernel_size[i],
                                          bias=self.bias,
                                          device=device))

        self.cell_list = nn.ModuleList(cell_list)

    def forward(self, input_tensor, hidden_state=None):
        """
        Run the stacked ConvLSTM over a full sequence.

        Parameters
        ----------
        input_tensor:
            5-D Tensor either of shape (t, b, c, h, w) or (b, t, c, h, w)
        hidden_state:
            Must be None (stateful operation is not implemented).

        Returns
        -------
        layer_output_list, last_state_list
        """
        if not self.batch_first:
            # (t, b, c, h, w) -> (b, t, c, h, w)
            input_tensor = input_tensor.permute(1, 0, 2, 3, 4)

        b, _, _, h, w = input_tensor.size()

        # Implement stateful ConvLSTM
        if hidden_state is not None:
            raise NotImplementedError()
        else:
            # Since the init is done in forward. Can send image size here
            hidden_state = self._init_hidden(batch_size=b,
                                             image_size=(h, w))

        layer_output_list = []
        last_state_list = []

        seq_len = input_tensor.size(1)
        cur_layer_input = input_tensor

        for layer_idx in range(self.num_layers):

            h, c = hidden_state[layer_idx]
            output_inner = []
            for t in range(seq_len):
                h, c = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :, :, :],
                                                 cur_state=[h, c])
                output_inner.append(h)

            layer_output = torch.stack(output_inner, dim=1)
            cur_layer_input = layer_output

            layer_output_list.append(layer_output)
            last_state_list.append([h, c])

        if not self.return_all_layers:
            # Keep only the deepest layer's outputs/states.
            layer_output_list = layer_output_list[-1:]
            last_state_list = last_state_list[-1:]

        return layer_output_list, last_state_list

    def _init_hidden(self, batch_size, image_size):
        # One zero (h, c) pair per layer, sized from the input image.
        init_states = []
        for i in range(self.num_layers):
            init_states.append(self.cell_list[i].init_hidden(batch_size, image_size))
        return init_states

    @staticmethod
    def _check_kernel_size_consistency(kernel_size):
        if not (isinstance(kernel_size, tuple) or
                (isinstance(kernel_size, list) and all([isinstance(elem, tuple) for elem in kernel_size]))):
            raise ValueError('`kernel_size` must be tuple or list of tuples')

    @staticmethod
    def _extend_for_multilayer(param, num_layers):
        # Broadcast a single value to one entry per layer.
        if not isinstance(param, list):
            param = [param] * num_layers
        return param
GANFilling/src/models/discriminator.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.nn.utils.parametrizations import spectral_norm
5
+
6
+ from convlstm import ConvLSTMCell
7
+
8
class Discriminator(nn.Module):  # PatchGAN
    """Recurrent PatchGAN discriminator.

    Each time step of the input sequence is encoded by stride-2 convolutions
    (8x spatial downsampling) and two ConvLSTM cells carry temporal state
    across the sequence; the output is a per-patch probability map per step.
    """

    def __init__(self, device, inputChannels=6, d=64):
        """
        Parameters
        ----------
        device: torch.device or str
            Device on which LSTM states and outputs are allocated.
        inputChannels: int
            Channels of each input frame.
        d: int
            Base channel width of the convolutional stack.
        """
        super().__init__()
        self.device = device
        self.conv1 = nn.Conv2d(inputChannels, d, 4, 2, 1)
        self.conv2 = nn.Conv2d(d, d * 2, 4, 2, 1)
        self.conv2_bn = nn.BatchNorm2d(d * 2, track_running_stats=False)
        self.conv3 = nn.Conv2d(d * 2, d * 4, 4, 2, 1)
        self.conv3_bn = nn.BatchNorm2d(d * 4, track_running_stats=False)
        self.conv4_lstm = ConvLSTMCell(d * 4, d * 4, (3, 3), False, self.device)
        self.conv4 = nn.Conv2d(d * 4, d * 8, 4, 1, 1)
        self.conv4_bn = nn.BatchNorm2d(d * 8, track_running_stats=False)
        self.conv5_lstm = ConvLSTMCell(d * 8, d * 8, (3, 3), False, self.device)
        self.conv5 = nn.Conv2d(d * 8, 1, 4, 1, 1)

        # NOTE(review): not used in forward(); kept so existing checkpoints
        # (state_dicts) still load without missing/unexpected keys.
        self.conv_fusion = nn.Conv2d(4, 1, 1)

        torch.backends.cudnn.deterministic = True

    def weight_init(self, mean, std):
        """Apply normal_init(mean, std) to every sub-module."""
        for m in self._modules:
            normal_init(self._modules[m], mean, std)

    def forward_step(self, input, states):
        """Process one time step; returns (patch logits map, updated LSTM states)."""
        x = F.leaky_relu(self.conv1(input), 0.2)
        x = F.leaky_relu(self.conv2_bn(self.conv2(x)), 0.2)
        x = F.leaky_relu(self.conv3_bn(self.conv3(x)), 0.2)
        states1 = self.conv4_lstm(x, states[0])
        x = F.leaky_relu(self.conv4_bn(self.conv4(states1[0])), 0.2)
        states2 = self.conv5_lstm(x, states[1])
        x = F.leaky_relu(self.conv5(states2[0]), 0.2)
        return x.squeeze(dim=1), [states1, states2]

    def forward(self, tensor):
        """tensor: (B, C, H, W, T) -> (per-patch probabilities, last LSTM states)."""
        # Output spatial size: 8x downsampling, then two valid-ish 4x4 convs
        # each shaving one pixel per side.
        output = torch.empty((tensor.shape[0], int(tensor.shape[2] / 8) - 2, int(tensor.shape[3] / 8) - 2, tensor.shape[4])).to(self.device)
        for patch in range(tensor.shape[0]):
            # Only two LSTM states are read by forward_step (the original
            # seeded a 4-tuple; the trailing Nones were never used).
            states = (None, None)
            for timeStep in range(tensor.shape[4]):
                output[patch, :, :, timeStep], states = self.forward_step(tensor[patch, :, :, :, timeStep].unsqueeze(dim=0), states)

        # torch.sigmoid replaces the deprecated F.sigmoid (identical result).
        return torch.sigmoid(output), states
49
+
50
+
51
def normal_init(m, mean, std):
    """Initialise conv/deconv weights from N(mean, std) and zero the bias;
    other module types are left untouched."""
    if isinstance(m, (nn.ConvTranspose2d, nn.Conv2d)):
        m.weight.data.normal_(mean, std)
        m.bias.data.zero_()
55
+
56
if __name__=='__main__':
    # Smoke test: 16 patches, 3 channels, 32x32 pixels, 4 time steps.
    dummy = torch.zeros((16, 3, 32, 32, 4), dtype=torch.float32)
    disc = Discriminator('cpu', inputChannels=3)
    out = disc(dummy)
    print(out[0].size())
GANFilling/src/models/generator.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ sys.path.append('src/models')
7
+ from convlstm import ConvLSTMCell
8
+
9
class Generator(nn.Module):
    """Recurrent U-Net generator for gap filling of satellite time series.

    The encoder is seven stride-2 convolutions, each level paired with a
    ConvLSTM cell that carries hidden state across time steps. The decoder
    mirrors it with seven ConvLSTM cells, 2x nearest upsampling and skip
    connections concatenated from the encoder features.
    """

    def __init__(self, device, inputChannels = 4, outputChannels=3, d=64):
        super().__init__()
        self.d = d
        self.device = device

        # Encoder: each conv halves the spatial resolution (k=3, s=2, p=1).
        self.conv1 = nn.Conv2d(inputChannels, d, 3, 2, 1)
        self.conv2 = nn.Conv2d(d, d * 2, 3, 2, 1)
        self.conv3 = nn.Conv2d(d * 2, d * 4, 3, 2, 1)
        self.conv4 = nn.Conv2d(d * 4, d * 8, 3, 2, 1)
        self.conv5 = nn.Conv2d(d * 8, d * 8, 3, 2, 1)
        self.conv6 = nn.Conv2d(d * 8, d * 8, 3, 2, 1)
        self.conv7 = nn.Conv2d(d * 8, d * 8, 3, 2, 1)

        # Decoder ConvLSTM cells. The "* 2" input widths account for the
        # skip-connection concatenation performed in forward_step.
        self.conv_lstm_d1 = ConvLSTMCell(d * 8, d * 8, (3,3), False, device)
        self.conv_lstm_d2 = ConvLSTMCell(d * 8 * 2 , d * 8, (3,3), False, device)
        self.conv_lstm_d3 = ConvLSTMCell(d * 8 * 2 , d * 8, (3,3), False, device)
        self.conv_lstm_d4 = ConvLSTMCell(d * 8 * 2 , d * 4, (3,3), False, device)
        self.conv_lstm_d5 = ConvLSTMCell(d * 4 * 2 , d * 2, (3,3), False, device)
        self.conv_lstm_d6 = ConvLSTMCell(d * 2 * 2 , d, (3,3), False, device)
        self.conv_lstm_d7 = ConvLSTMCell(d * 2 , d, (3,3), False, device)

        # Encoder ConvLSTM cells, one per level.
        # NOTE(review): conv_lstm_e7 is declared but never used by
        # forward_step; kept so existing checkpoints' state_dicts still load.
        self.conv_lstm_e1 = ConvLSTMCell(d, d, (3,3), False, device)
        self.conv_lstm_e2 = ConvLSTMCell(d * 2 , d * 2, (3,3), False, device)
        self.conv_lstm_e3 = ConvLSTMCell(d * 4 , d * 4, (3,3), False, device)
        self.conv_lstm_e4 = ConvLSTMCell(d * 8 , d * 8, (3,3), False, device)
        self.conv_lstm_e5 = ConvLSTMCell(d * 8 , d * 8, (3,3), False, device)
        self.conv_lstm_e6 = ConvLSTMCell(d * 8 , d * 8, (3,3), False, device)
        self.conv_lstm_e7 = ConvLSTMCell(d * 8 , d * 8, (3,3), False, device)

        self.up = nn.Upsample(scale_factor=2)
        self.conv_out = nn.Conv2d(d, outputChannels, 3, 1, 1)

        # Negative slope of the encoder's leaky ReLUs.
        self.slope = 0.2

    def weight_init(self, mean, std):
        """Re-initialise every direct child module with N(mean, std) weights."""
        for m in self._modules:
            normal_init(self._modules[m], mean, std)

    def forward_step(self, input, states_encoder, states_decoder):
        """Process one time step through the recurrent U-Net.

        Returns the gap-filled frame plus the updated (encoder, decoder)
        ConvLSTM state lists to be threaded into the next time step.
        """
        # Encoder: conv, then ConvLSTM at levels 1-6 (level 7 has no cell).
        e1 = self.conv1(input)
        states_e1 = self.conv_lstm_e1(e1, states_encoder[0])
        e2 = self.conv2(F.leaky_relu(states_e1[0], self.slope))
        states_e2 = self.conv_lstm_e2(e2, states_encoder[1])
        e3 = self.conv3(F.leaky_relu(states_e2[0], self.slope))
        states_e3 = self.conv_lstm_e3(e3, states_encoder[2])
        e4 = self.conv4(F.leaky_relu(states_e3[0], self.slope))
        states_e4 = self.conv_lstm_e4(e4, states_encoder[3])
        e5 = self.conv5(F.leaky_relu(states_e4[0], self.slope))
        states_e5 = self.conv_lstm_e5(e5, states_encoder[4])
        e6 = self.conv6(F.leaky_relu(states_e5[0], self.slope))
        states_e6 = self.conv_lstm_e6(e6, states_encoder[5])
        e7 = self.conv7(F.leaky_relu(states_e6[0], self.slope))

        # Decoder: ConvLSTM, 2x upsample, concat skip from the encoder.
        states1 = self.conv_lstm_d1(F.relu(e7), states_decoder[0])
        d1 = self.up(states1[0])
        d1 = torch.cat([d1, e6], 1)

        states2 = self.conv_lstm_d2(F.relu(d1), states_decoder[1])
        d2 = self.up(states2[0])
        d2 = torch.cat([d2, e5], 1)

        states3 = self.conv_lstm_d3(F.relu(d2), states_decoder[2])
        d3 = self.up(states3[0])
        d3 = torch.cat([d3, e4], 1)

        states4 = self.conv_lstm_d4(F.relu(d3), states_decoder[3])
        d4 = self.up(states4[0])
        d4 = torch.cat([d4, e3], 1)

        states5 = self.conv_lstm_d5(F.relu(d4), states_decoder[4])
        d5 = self.up(states5[0])
        d5 = torch.cat([d5, e2], 1)

        states6 = self.conv_lstm_d6(F.relu(d5), states_decoder[5])
        d6 = self.up(states6[0])
        d6 = torch.cat([d6, e1], 1)

        states7 = self.conv_lstm_d7(F.relu(d6), states_decoder[6])
        d7 = self.up(states7[0])

        # tanh clamped to [0, 1] (min=-0.0 == 0.0, kept for parity).
        o = torch.clip(torch.tanh(self.conv_out(d7)), min=-0.0, max = 1)

        states_e = [states_e1, states_e2, states_e3,states_e4, states_e5, states_e6]
        states_d = [states1, states2, states3,states4, states5, states6, states7]

        return o, (states_e, states_d)

    def forward(self, tensor):
        """Gap-fill a (B, C_in, H, W, T) sequence.

        Returns a (B, C_out, H, W, T) tensor and the final LSTM states.

        BUG FIX: the output buffer is sized with conv_out's channel count
        instead of torch.empty_like(tensor); the original crashed whenever
        outputChannels != inputChannels (e.g. the default out=3 with in=4,
        exercised by this file's own __main__ demo).
        """
        states_encoder = (None, None, None, None, None, None, None)
        states_decoder = (None, None, None, None, None, None, None)
        batch, _, height, width, timeSteps = tensor.shape
        output = torch.empty((batch, self.conv_out.out_channels, height, width, timeSteps),
                             dtype=tensor.dtype, device=tensor.device)
        for timeStep in range(timeSteps):
            output[:, :, :, :, timeStep], states = self.forward_step(
                tensor[:, :, :, :, timeStep], states_encoder, states_decoder)
            states_encoder, states_decoder = states[0], states[1]
        return output, states
107
+
108
+
109
def normal_init(m, mean, std):
    """Initialise a (transposed) conv layer: weights ~ N(mean, std), bias = 0.

    Any other module type is left untouched. Runs under no_grad so the
    re-initialisation is not recorded by autograd, and tolerates layers
    built with ``bias=False`` (the original crashed on ``m.bias = None``).
    """
    if isinstance(m, (nn.ConvTranspose2d, nn.Conv2d)):
        with torch.no_grad():
            m.weight.normal_(mean, std)
            if m.bias is not None:
                m.bias.zero_()
113
+
114
if __name__=='__main__':
    # Smoke test.
    # batch size = number of 3D patches; channels = B, G, R + NIR;
    # 128x128 spatial resolution; 10 time steps.
    # (Removed two unused state tuples — Generator.forward builds its own.)
    x = torch.zeros((2, 4, 128, 128, 10), dtype=torch.float32)
    model = Generator('cpu', inputChannels=4)
    y, states = model(x)
    print(y.size())
GANFilling/src/test.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import torch
4
+ import csv
5
+ import numpy as np
6
+
7
+ import argparse
8
+ from pathlib import Path
9
+ from PIL import Image
10
+
11
+ import matplotlib as mpl
12
+ import matplotlib.pyplot as plt
13
+ import matplotlib.patches as mpatches
14
+
15
+ sys.path.append('src/models')
16
+ from generator import *
17
+
18
+
19
def main(config):
    """Gap-fill each test sample and save a 3-row figure per sample:
    landcover + cloudy context (row 0), gap-filled RGB (row 1), NDVI (row 2).

    config: argparse namespace with model_path, data_path and results_path.
    """
    import datetime  # hoisted: the original re-imported this inside the loop

    gap_filling_model_path = config.model_path
    results_path = Path(config.results_path)
    data_path = Path(config.data_path)
    list_samples = [x for x in os.listdir(config.data_path) if x.endswith('.npz')]

    # Landcover class id -> RGB triplet (kept as CSV strings; numpy casts
    # them on assignment below).
    landcover_cmap = {}
    with open('data/landcover_types.csv', 'r') as fd:
        reader = csv.reader(fd)
        for row in reader:
            landcover_cmap[row[0]] = [row[1], row[2], row[3]]

    # Load the trained generator ONCE — the original reloaded the checkpoint
    # from disk on every sample iteration.
    model = __load_gapFill_model(gap_filling_model_path)

    for sample in list_samples:

        sample_lc = data_path / "landcover" / sample.replace("context_", "")

        landcover = np.load(sample_lc)['landcover'][0]
        context = np.load(data_path / sample)['highresdynamic'][:,:,:4, :10]

        gap_filled_sample = __gapFill_context(context, model)

        # NDVI = (NIR - Red) / (NIR + Red); channel order is B, G, R, NIR.
        n = gap_filled_sample[0,-1]
        r = gap_filled_sample[0,2]
        ndvi = (n-r)/(n+r)

        fig, axs = plt.subplots(3, 11, figsize=(15, 4))

        # Paint the landcover map with the CSV palette.
        landcover_rgb = np.empty((128,128, 3))
        for c in np.unique(landcover):
            for ch in range(3):
                landcover_rgb[landcover==c, ch] = landcover_cmap[str(c)][ch]

        landcover_rgb = Image.fromarray(landcover_rgb.astype('uint8'), 'RGB')
        axs[0, 0].imshow(landcover_rgb)
        axs[0, 0].axis('off')
        axs[1, 0].axis('off')
        axs[2, 0].axis('off')

        # Acquisition start date is encoded in the file name.
        date = datetime.datetime.strptime(sample.split('_')[2], '%Y-%m-%d')

        # Row 0: cloudy context frames (BGR -> RGB, brightness x3), labelled
        # with the day-of-year assuming a 5-day revisit.
        for i in range(10):
            img_rgb = np.clip(context[:,:,:3, i]*3,0,1)[:,:,[2,1,0]]
            img_rgb = (img_rgb * 255).astype('uint8')
            img_rgb = Image.fromarray(img_rgb)
            axs[0, i+1].imshow(img_rgb)
            axs[0, i+1].text(94, 122, (date+datetime.timedelta(days=5*i)).strftime('%j'), color='black', fontsize=10, bbox=dict(boxstyle='square,pad=0.1', fc='white', ec='none'))
            axs[0, i+1].axis('off')

        # Row 1: gap-filled frames.
        for i in range(10):
            img_rgb = np.transpose(gap_filled_sample.cpu().detach().numpy()[0, [2,1,0],:,:, i], (1, 2, 0))
            img_rgb = np.clip(img_rgb*3, 0, 1)
            img_rgb = (img_rgb * 255).astype('uint8')
            img_rgb = Image.fromarray(img_rgb)
            axs[1, i+1].imshow(img_rgb)
            axs[1, i+1].axis('off')

        cmap = mpl.colormaps.get_cmap('jet')
        cmap.set_bad(color='black')

        # Row 2: NDVI maps on a shared [0, 1] scale.
        for i in range(10):
            im = axs[2, i+1].imshow(ndvi[:,:,i].cpu().detach().numpy(), cmap=cmap, vmin=0., vmax=1.)
            axs[2, i+1].axis('off')

        cbar_ax = fig.add_axes([0.2, 0.95, 0.7, 0.03])
        cbar = fig.colorbar(im, cax=cbar_ax, orientation='horizontal')
        cbar.outline.set_linewidth(0)

        plt.subplots_adjust(wspace=0., hspace=0.05)

        # NOTE(review): these legend patches are built but never attached to
        # the figure — presumably meant for fig.legend(handles=legend_items);
        # kept to preserve behaviour, confirm intent before wiring it up.
        landcover_cmap_legend = {}
        with open('data/landcover_types.csv', 'r') as fd:
            reader = csv.reader(fd)
            for row in reader:
                landcover_cmap_legend[row[5]] = [int(row[1])/255, int(row[2])/255, int(row[3])/255, 1]

        landcover_cmap_legend.pop('No data')
        legend_items =[mpatches.Patch(color=color,label=lc) for lc, color in landcover_cmap_legend.items()]

        path = results_path / sample
        plt.savefig(f"{path}.png", bbox_inches='tight')
        plt.close()
+
108
+
109
+ ## GAP FILLING FUNCTIONS
110
+
111
def __load_gapFill_model(gapFilling_model):
    """Load a generator checkpoint saved by train.py and put it in eval mode.

    The checkpoint is a dict holding the state_dict under 'model' and the
    training argparse namespace under 'configuration'.
    """
    checkpoint = torch.load(gapFilling_model, map_location='cpu')
    train_config = checkpoint['configuration']
    generator = Generator('cpu', train_config.inputChannels, train_config.outputChannels)
    generator.load_state_dict(checkpoint['model'])
    generator.eval()
    return generator
118
+
119
def __prepare_data(context):
    """Clamp reflectances to [0, 1], replace NaNs with 1.0 and reorder
    axes (H, W, C, T) -> (C, H, W, T) for the generator."""
    clipped = np.clip(context, 0, 1)            # NaNs survive np.clip ...
    cleaned = np.nan_to_num(clipped, nan=1.0)   # ... and are replaced here
    return cleaned.transpose(2, 0, 1, 3)
123
+
124
def __gapFill_context(context, model):
    """Run the generator over one (H, W, C, T) context cube and return the
    gap-filled (1, C, H, W, T) tensor."""
    prepared = __prepare_data(context)
    batch = torch.from_numpy(prepared).float().unsqueeze(dim=0)
    gap_filled, _ = model(batch)
    return gap_filled
129
+
130
if __name__=='__main__':
    # CLI: paths to the trained checkpoint, the .npz test data and the
    # directory where the result figures are written.
    arg_parser = argparse.ArgumentParser(description = "ArgParse")
    arg_parser.add_argument('--model_path', type=str)
    arg_parser.add_argument('--data_path', type=str)
    arg_parser.add_argument('--results_path', type=str)
    main(arg_parser.parse_args())
+ main(config)
GANFilling/src/train.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ import torch
4
+ from torch.utils.data import DataLoader
5
+ import logging
6
+ import numpy as np
7
+ import random
8
+ import torch.nn as nn
9
+ from pathlib import Path
10
+
11
+ from torch.autograd import Variable
12
+ from torchvision import transforms
13
+ from skimage.metrics import structural_similarity as ssim
14
+ from skimage.metrics import peak_signal_noise_ratio as psnr
15
+
16
+ from iterator import Iterator
17
+ from models.generator import Generator
18
+ from models.discriminator import Discriminator
19
+
20
class Trainer():
    """GAN trainer: alternating discriminator/generator optimisation for
    satellite time-series gap filling.

    Usage: construct with an argparse namespace, then call setData(),
    setModels() and train().
    """

    def __init__(self, configuration):
        # configuration: argparse.Namespace with paths and hyper-parameters.
        self.configuration = configuration
        self.__createDirectories()
        self.__initCriterions()
        self.__setLogger()
        self.__logParser()
        self.__resetAccumulatedMetrics()

        # Per-step loss histories.
        # NOTE(review): losses are appended as graph-attached tensors, which
        # keeps autograd graphs alive and grows memory over training;
        # .item()/.detach() would be safer — confirm before changing.
        self.generatorHistLoss = []
        self.generatorHistLoss_adv = []
        self.generatorHistLoss_rec = []
        self.discriminatorHistLoss = []

        self.metricSSIM = []
        self.metricPSNR = []

        # Fix all three RNGs for reproducibility.
        torch.manual_seed(100)
        # Seed python RNG
        random.seed(100)
        # Seed numpy RNG
        np.random.seed(100)

    def __logParser(self):
        """Log every CLI argument and its value."""
        for arg, value in vars(self.configuration).items():
            logging.info("Argument %s: %r", arg, value)

    def __createDirectories(self):
        """Create <modelsPath><name>/Images/ for checkpoints and previews."""
        self.modelPath = Path(self.configuration.modelsPath + self.configuration.name )
        Path.mkdir(self.modelPath / 'Images/', parents=True, exist_ok=True)

    def __initCriterions(self):
        """BCE for both adversarial losses, L1 for reconstruction.

        reduction='none' keeps element-wise losses; means are taken later.
        """
        self.discriminatorCriterion = nn.BCELoss(reduction='none')
        self.generatorCriterion = [nn.BCELoss(reduction='none'), nn.L1Loss(reduction='none')]

    def train(self):
        """Main loop: one discriminator update then one generator update per
        batch, with periodic metric logging, previews and checkpoints.

        NOTE(review): torch.autograd.Variable is a deprecated no-op wrapper
        in modern PyTorch; plain tensors behave identically here.
        """
        logging.info('Starting training')
        self.step = 0
        self.mode = 'train'
        for self.epoch in range(self.configuration.maxEpochs):
            for noisyTensor, cleanTensor, masksTensor in self.trainDataLoader:

                noisyTensor = noisyTensor.float().to(self.device)
                cleanTensor = cleanTensor.float().to(self.device)
                masksTensor = masksTensor.float().to(self.device)

                ### Update Discriminator
                self.__setRequiresGrad(self.discriminator, True)
                self.discriminator.zero_grad()

                # Real branch: clean (noiseless) crops labelled 1.
                outputDiscriminator,_ = self.discriminator(cleanTensor)
                realLossDiscriminator = self.discriminatorCriterion(outputDiscriminator, Variable(torch.ones(outputDiscriminator.size()).to(self.device)))

                # Fake branch: a random crop of the generator output labelled 0.
                outputGenerator, _ = self.generator(noisyTensor)
                outputGenerator_patch, _ = self.__divideInPatches(outputGenerator, cleanTensor.shape[4])
                outputDiscriminator,_ = self.discriminator(outputGenerator_patch)
                fakeLossDiscriminator = self.discriminatorCriterion(outputDiscriminator, Variable(torch.zeros(outputDiscriminator.size()).to(self.device)))

                trainLossDiscriminator = (torch.mean(realLossDiscriminator) + torch.mean(fakeLossDiscriminator))*0.5
                trainLossDiscriminator.backward()
                self.discriminatorOptimizer.step()
                self.discriminator.zero_grad()
                self.discriminatorHistLoss.append(trainLossDiscriminator)

                ### Update Generator
                # Freeze D so only G's parameters receive gradients.
                self.__setRequiresGrad(self.discriminator, False)
                self.generator.zero_grad()

                outputGenerator, _ = self.generator(noisyTensor)
                outputGenerator_patch, startTime = self.__divideInPatches(outputGenerator, 4)
                outputDiscriminator,_ = self.discriminator(outputGenerator_patch)

                # Adversarial term (fool D) + masked L1 reconstruction term
                # computed only on non-gap pixels (mask==0).
                generatorAdversarialLoss = torch.mean(self.generatorCriterion[0](outputDiscriminator, Variable(torch.ones(outputDiscriminator.size()).to(self.device))))
                generatorReconstructionLoss = self.generatorCriterion[1](outputGenerator, noisyTensor) * (1-masksTensor).unsqueeze(dim=1)
                generatorReconstructionLoss = self.configuration.lambdaL1 * torch.mean(generatorReconstructionLoss)

                trainLossGenerator = torch.mean(generatorAdversarialLoss) + generatorReconstructionLoss
                trainLossGenerator.backward()
                self.generatorOptimizer.step()
                self.generator.zero_grad()
                self.generatorHistLoss.append(trainLossGenerator)
                self.generatorHistLoss_adv.append(torch.mean(generatorAdversarialLoss))
                self.generatorHistLoss_rec.append(torch.mean(generatorReconstructionLoss))
                self.__evaluate(outputGenerator, noisyTensor, masksTensor)
                self.__logMetrics()
                self.__plotGeneratorOutput(outputGenerator, noisyTensor, self.modelPath)

                if self.step % int(self.configuration.validateEvery) == 0:
                    self.__saveModelFreq()

    def __saveModelFreq(self):
        """Checkpoint the generator, tagged with the current step."""
        self.__saveModel(self.generator, 'step_generator_'+str(self.step))
        logging.info('Saved model')

    def __divideInPatches(self, tensor, timeSteps, kernelSize = 64):
        """Crop a random 64x64 x timeSteps sub-cube from the first sample.

        NOTE(review): only tensor[0] is cropped, so this assumes
        batchSize == 1 — confirm before raising the batch size.
        Returns the crop and the sampled start time index.
        """
        new_tensor = torch.zeros((1,tensor.shape[1], kernelSize, kernelSize, timeSteps)).float().to(self.device)
        startTime = random.sample([x+1 for x in range(-1,tensor.shape[4]-timeSteps)],1)[0]
        x = random.sample([x+1 for x in range(-1,64)],1)[0]
        y = random.sample([x+1 for x in range(-1,64)],1)[0]

        new_tensor[0] = tensor[0, :, x:x+kernelSize, y:y+kernelSize, startTime:startTime+timeSteps]

        return new_tensor, startTime

    def __unnormalize(self, tensor):
        """Map [-1, 1] back to [0, 1] when --normalize was used."""
        if self.configuration.normalize:
            return tensor/2 + .5
        else:
            return tensor

    def __setRequiresGrad(self, net, requires_grad=False):
        """Enable/disable gradient computation for all of net's parameters."""
        for param in net.parameters():
            param.requires_grad = requires_grad

    def __evaluate(self, outputTensor, noisyTensor, masksTensor):
        """Accumulate SSIM/PSNR between output and input on non-gap pixels.

        NOTE(review): ssim(..., multichannel=True) was removed in
        scikit-image >= 0.19 (channel_axis=-1 replaces it) — confirm the
        pinned version. Also note self.metricSSIM/PSNR flip from list to
        scalar here; __resetMetrics (called by __logMetrics) restores them.
        """
        outputImages = outputTensor.detach().permute(0,2,3,1,4).cpu().numpy()
        noisyImages = noisyTensor.detach().permute(0,2,3,1,4).cpu().numpy()
        maskImages = masksTensor.detach().cpu().numpy()

        for imgBatch in range(outputTensor.shape[0]):
            # Zero out the gap pixels (mask==1) before comparing.
            maskImg = np.expand_dims( (1-maskImages[imgBatch]), 2)
            outputImg = outputImages[imgBatch]*maskImg
            noisyImg = noisyImages[imgBatch]*maskImg

            ssim_aux = []
            psnr_aux = []
            for timeStep in range(noisyImg.shape[3]):
                metric_1 = ssim(outputImg[:,:,:,timeStep], noisyImg[:,:,:,timeStep], multichannel=True)
                metric_2 = psnr(outputImg[:,:,:,timeStep], noisyImg[:,:,:,timeStep])
                if np.isfinite(metric_1):
                    ssim_aux.append(metric_1)
                if np.isfinite(metric_2):
                    psnr_aux.append(metric_2)

            self.metricSSIM.append(sum(ssim_aux)/len(ssim_aux))
            self.metricPSNR.append(sum(psnr_aux)/len(psnr_aux))
        self.metricSSIM = sum(self.metricSSIM)/self.configuration.batchSize
        self.metricPSNR = sum(self.metricPSNR)/self.configuration.batchSize
        self.accumulatedMetricSSIM.append(self.metricSSIM)
        self.accumulatedMetricPSNR.append(self.metricPSNR)

    def __plotGeneratorOutput(self, outputGenerator, noisyImg, path):
        """Every plotEvery steps, dump output/input frames as PNG previews."""
        if self.step % int(self.configuration.plotEvery) == 0:
            for sampleNumber in range(outputGenerator.shape[0]):
                for timeStep in range(outputGenerator.shape[4]):
                    # [[2,1,0],:,:] reorders BGR channels to RGB.
                    self.__saveImage(self.__unnormalize(outputGenerator[sampleNumber,:,:,:,timeStep])[[2,1,0],:,:], sampleNumber, path, '_'+str(timeStep))
                    self.__saveImage(self.__unnormalize(noisyImg[sampleNumber,:,:,:,timeStep])[[2,1,0],:,:], sampleNumber, path, 'noisy_'+str(timeStep))

    def __saveImage(self, tensor, sampleNumber, path, label):
        """Brighten (x3), clamp and save one frame under Images/."""
        tensor = torch.clamp(3*tensor, -1, 1)
        img = transforms.ToPILImage()(tensor).convert("RGB")
        img.save(str(path /'Images') + '/e'+str(self.epoch)+'_s'+str(self.step)+'_img'+str(sampleNumber)+'_'+label+'.png')

    def __logMetrics(self):
        """Every printEvery steps, log mean losses/metrics; always advance
        the step counter and reset the per-step metric buffers."""
        if self.step % int(self.configuration.printEvery) == 0:
            logging.info('Epoch {epoch} - Step {step} ---> GLoss: {gloss} - adv: {gloss_adv} - rec: {gloss_rec}, DLoss: {dloss}, SSIM: {ssim} PSNR: {psnr}'.format(\
                epoch=self.epoch, step=self.step, gloss=torch.mean(torch.FloatTensor(self.generatorHistLoss[-self.configuration.printEvery:])), gloss_adv=torch.mean(torch.FloatTensor(self.generatorHistLoss_adv[-self.configuration.printEvery:])), gloss_rec=torch.mean(torch.FloatTensor(self.generatorHistLoss_rec[-self.configuration.printEvery:])),\
                dloss= torch.mean(torch.FloatTensor(self.discriminatorHistLoss[-self.configuration.printEvery:])),\
                ssim=torch.mean(torch.FloatTensor(self.accumulatedMetricSSIM)), psnr=torch.mean(torch.FloatTensor(self.accumulatedMetricPSNR))))
            self.__resetAccumulatedMetrics()
        self.step += 1
        self.__resetMetrics()

    def __resetMetrics(self):
        """Clear the per-step SSIM/PSNR buffers."""
        self.metricSSIM = []
        self.metricPSNR = []

    def __resetAccumulatedMetrics(self):
        """Clear the since-last-log SSIM/PSNR accumulators."""
        self.accumulatedMetricSSIM = []
        self.accumulatedMetricPSNR = []

    def __saveModel(self, model, tag):
        """Save state_dict + training configuration as <modelPath>/<tag>.pt."""
        modelDictionary = {'model':model.state_dict(), 'configuration':self.configuration}
        torch.save(modelDictionary, str(self.modelPath / tag)+'.pt')

    def setData(self):
        """Build train/valid iterators and loaders.

        NOTE(review): this reads the module-level `configuration` global
        instead of self.configuration — it only works when the script is run
        via __main__; confirm before reusing Trainer programmatically.
        """
        self.trainIterator = Iterator(Path(configuration.dataPath), configuration.cleanDataPath, 'train')
        self.trainDataLoader = DataLoader(self.trainIterator, batch_size=self.configuration.batchSize, shuffle=True)

        self.validIterator = Iterator(Path(configuration.dataPath), configuration.cleanDataPath,'valid')
        self.validDataLoader = DataLoader(self.validIterator, batch_size=1)

    def setModels(self):
        """Pick the device and instantiate both networks with optimizers."""
        self.__setDevice()
        self.__setGenerator()
        self.__setDiscriminator()

    def __setGenerator(self):
        self.generator = Generator(self.device, self.configuration.inputChannels, self.configuration.outputChannels)
        self.generator.weight_init(mean=0.0, std=0.02)
        self.generator.to(self.device)
        self.generatorOptimizer = torch.optim.Adam(self.generator.parameters(),lr=self.configuration.lrG, betas=(0.5,0.999))

    def __setDiscriminator(self):
        self.discriminator = Discriminator(self.device, self.configuration.inputChannels)
        self.discriminator.weight_init(mean=0.0, std=0.02)
        self.discriminator.to(self.device)
        self.discriminatorOptimizer = torch.optim.Adam(self.discriminator.parameters(),lr=self.configuration.lrD, betas=(0.5,0.999))

    def __setDevice(self):
        """Use the first CUDA device when available, else the CPU."""
        print('TORCH AVAILABLE: ', torch.cuda.is_available())
        if torch.cuda.is_available():
            self.device = 'cuda:0'
        else:
            self.device = 'cpu'

    def __setLogger(self):
        """Write INFO logs to <logsPath>/<name>.log (overwritten each run)."""
        Path.mkdir(Path(self.configuration.logsPath), exist_ok=True)
        logging.basicConfig(filename=self.configuration.logsPath+self.configuration.name+'.log', format='%(asctime)s - %(message)s', datefmt='%d-%b-%y %H:%M:%S', level=logging.INFO, filemode='w')
229
+
230
def main(configuration):
    """Wire up data and models, then run the training loop."""
    gan_trainer = Trainer(configuration)
    gan_trainer.setData()
    gan_trainer.setModels()
    gan_trainer.train()
235
+
236
if __name__=='__main__':
    parser = argparse.ArgumentParser()
    # Paths and run identification.
    parser.add_argument('--dataPath', type=str, default="train_data/", help='Path to the .npz files')
    parser.add_argument('--cleanDataPath', type=str, default="utils/example_clean_data.json", help='file with the noiseless data paths for the discriminator')
    parser.add_argument('--logsPath', type=str, default='logs/')
    parser.add_argument('--modelsPath', type=str, default='trained_models/', help='Path to save the trained models')
    parser.add_argument('--name', type=str, help='Model name')

    # Schedule: epochs, batching and logging/checkpoint cadence (in steps).
    parser.add_argument('--maxEpochs', type=int, default=1)
    parser.add_argument('--batchSize', type=int, default=1)
    parser.add_argument('--printEvery',type=int, default=1)
    parser.add_argument('--plotEvery', type=int, default=5)
    parser.add_argument('--validateEvery', type=int, default=5)

    # Model/optimisation hyper-parameters. Channels default to 4 (B, G, R, NIR).
    parser.add_argument('--inputChannels', type=int, default=4)
    parser.add_argument('--outputChannels', type=int, default=4)
    parser.add_argument('--lrD', type=float, default=0.0002, help='learning rate, default=0.0002')
    parser.add_argument('--lrG', type=float, default=0.0002, help='learning rate, default=0.0002')
    parser.add_argument('--lambdaL1', type=float, default=100, help='lambda for L1 loss')
    # --normalize assumes data in [-1, 1] and rescales previews accordingly.
    parser.add_argument('--normalize', action='store_true')

    configuration = parser.parse_args()
    main(configuration)
GANFilling/src/utils/example_clean_data.json ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "33U": [
3
+ {
4
+ "path": "data/train_data/33UXQ/33UXQ_2018-06-18_2018-11-14_2105_2233_3769_3897_32_112_58_138.npz",
5
+ "kernel size": 64,
6
+ "bbox": [
7
+ 11,
8
+ 60
9
+ ],
10
+ "time steps": [
11
+ 7,
12
+ 8,
13
+ 9
14
+ ]
15
+ },
16
+ {
17
+ "path": "data/train_data/33UXQ/33UXQ_2018-07-28_2018-12-24_4025_4153_2105_2233_62_142_32_112.npz",
18
+ "kernel size": 64,
19
+ "bbox": [
20
+ 0,
21
+ 44
22
+ ],
23
+ "time steps": [
24
+ 6,
25
+ 7,
26
+ 8,
27
+ 9
28
+ ]
29
+ },
30
+ {
31
+ "path": "data/train_data/33UXQ/33UXQ_2018-07-08_2018-12-04_2617_2745_3513_3641_40_120_54_134.npz",
32
+ "kernel size": 64,
33
+ "bbox": [
34
+ 0,
35
+ 29
36
+ ],
37
+ "time steps": [
38
+ 3,
39
+ 4,
40
+ 5
41
+ ]
42
+ },
43
+ {
44
+ "path": "data/train_data/33UXQ/33UXQ_2018-07-18_2018-12-14_825_953_4537_4665_12_92_70_150.npz",
45
+ "kernel size": 64,
46
+ "bbox": [
47
+ 0,
48
+ 35
49
+ ],
50
+ "time steps": [
51
+ 1,
52
+ 2,
53
+ 3
54
+ ]
55
+ },
56
+ {
57
+ "path": "data/train_data/33UXQ/33UXQ_2018-06-18_2018-11-14_2745_2873_3897_4025_42_122_60_140.npz",
58
+ "kernel size": 64,
59
+ "bbox": [
60
+ 0,
61
+ 0
62
+ ],
63
+ "time steps": [
64
+ 7,
65
+ 8,
66
+ 9
67
+ ]
68
+ },
69
+ {
70
+ "path": "data/train_data/33UXQ/33UXQ_2018-07-08_2018-12-04_697_825_4793_4921_10_90_74_154.npz",
71
+ "kernel size": 64,
72
+ "bbox": [
73
+ 0,
74
+ 49
75
+ ],
76
+ "time steps": [
77
+ 3,
78
+ 4,
79
+ 5
80
+ ]
81
+ }
82
+ ]
83
+ }
GANFilling/src/utils/generate_cleanData_file.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn.functional as F
6
+
7
+ from itertools import groupby
8
+ from operator import itemgetter
9
+
10
+ import matplotlib.pyplot as plt
11
+ import json
12
+
13
+ folderName = '/train_data/context/'
14
+ split = 'IID'
15
+ data = {split:[]}
16
+ maxSizeKernel = 128
17
+ minSizeKernel = 127
18
+
19
+ i = 0
20
+ for fileName in os.listdir(folderName):
21
+ if fileName == 'LICENSE':
22
+ pass
23
+ else:
24
+ print(fileName)
25
+ for dataCube in os.listdir(folderName+'/'+fileName):
26
+ cubePath = folderName + '/' + fileName + '/' +dataCube
27
+ print('Evaluating data cube ', cubePath)
28
+ sample = np.load(cubePath)
29
+ masks = sample['highresdynamic'][:,:,-1,:10]
30
+
31
+ clean = {}
32
+ total_kernels = {}
33
+ max_lenght_kernels = {}
34
+ for kernelSize in range(maxSizeKernel,minSizeKernel, -1):
35
+ clean[kernelSize] = {}
36
+ total_kernels[kernelSize] = {}
37
+ max_lenght_kernels[kernelSize] = {}
38
+
39
+ ########################################################################
40
+ masks = np.expand_dims(masks, axis=(0, 1))
41
+ validKernel = False
42
+
43
+ for kernelSize in range(maxSizeKernel,minSizeKernel, -1):
44
+ kernel = torch.ones((1,1,kernelSize, kernelSize))
45
+ for timeStep in range(masks.shape[4]):
46
+ conv = F.conv2d(torch.from_numpy(masks[:,:,:,:,timeStep]).float(), kernel)
47
+ clean[kernelSize][timeStep]= (conv == 0).nonzero()
48
+
49
+ if len(clean[kernelSize][timeStep]) != 0:
50
+ for clean_kernel in clean[kernelSize][timeStep]:
51
+ if (clean_kernel[2].item(),clean_kernel[3].item()) not in total_kernels[kernelSize].keys():
52
+ total_kernels[kernelSize][(clean_kernel[2].item(),clean_kernel[3].item())] = [timeStep]
53
+ else:
54
+ total_kernels[kernelSize][(clean_kernel[2].item(),clean_kernel[3].item())].append(timeStep)
55
+ #print('KS: ', kernelSize)
56
+ #print(total_kernels[kernelSize])
57
+ lenght = 3
58
+ kernel_to_save = []
59
+ for key, value in total_kernels[kernelSize].items():
60
+ for k, g in groupby(enumerate(value), lambda ix : ix[0] - ix[1]):
61
+ consecutives = list(map(itemgetter(1), g))
62
+ #print(consecutives)
63
+ #print(len(consecutives) >= lenght)
64
+ if len(consecutives) >= lenght:
65
+ kernel_to_save.append(key)
66
+ kernel_to_save.append(consecutives)
67
+ lenght = len(consecutives)
68
+ validKernel = True
69
+ if validKernel:
70
+ print('Found valid kernel ', kernelSize, ' with ', len(kernel_to_save[1]), ' consecutive time steps')
71
+ data[split].append({
72
+ 'path':str(cubePath),
73
+ 'kernel size': kernelSize,
74
+ 'bbox': kernel_to_save[0],
75
+ 'time steps': kernel_to_save[1]
76
+ })
77
+ validKernel = False
78
+ break
79
+
80
+ print('Saving data ...')
81
+ with open('cleanData.json', 'w') as outfile:
82
+ json.dump(data, outfile)
83
+
84
+