Upload 32 files
Browse files- .gitattributes +16 -0
- GANFilling/.gitattributes +1 -0
- GANFilling/.gitignore +3 -0
- GANFilling/GANFilling_percept_results.png +3 -0
- GANFilling/README.md +62 -0
- GANFilling/data/landcover_types.csv +15 -0
- GANFilling/data/test_data/context_29TQF_2017-08-16_2018-01-12_5049_5177_441_569_78_158_6_86.npz +3 -0
- GANFilling/data/test_data/context_30UYV_2018-06-19_2018-11-15_1977_2105_2617_2745_30_110_40_120.npz +3 -0
- GANFilling/data/test_data/context_33SXD_2019-05-16_2019-10-12_4025_4153_3129_3257_62_142_48_128.npz +3 -0
- GANFilling/data/test_data/context_34TDP_2018-04-28_2018-09-24_2233_2361_5049_5177_34_114_78_158.npz +3 -0
- GANFilling/data/test_data/landcover/29TQF_2017-08-16_2018-01-12_5049_5177_441_569_78_158_6_86.npz +0 -0
- GANFilling/data/test_data/landcover/30UYV_2018-06-19_2018-11-15_1977_2105_2617_2745_30_110_40_120.npz +0 -0
- GANFilling/data/test_data/landcover/33SXD_2019-05-16_2019-10-12_4025_4153_3129_3257_62_142_48_128.npz +0 -0
- GANFilling/data/test_data/landcover/34TDP_2018-04-28_2018-09-24_2233_2361_5049_5177_34_114_78_158.npz +0 -0
- GANFilling/data/train_data/33UXQ/33UXQ_2018-03-30_2018-08-26_2873_3001_3513_3641_44_124_54_134.npz +3 -0
- GANFilling/data/train_data/33UXQ/33UXQ_2018-06-18_2018-11-14_2105_2233_3769_3897_32_112_58_138.npz +3 -0
- GANFilling/data/train_data/33UXQ/33UXQ_2018-06-18_2018-11-14_2745_2873_3897_4025_42_122_60_140.npz +3 -0
- GANFilling/data/train_data/33UXQ/33UXQ_2018-07-08_2018-12-04_2617_2745_3513_3641_40_120_54_134.npz +3 -0
- GANFilling/data/train_data/33UXQ/33UXQ_2018-07-08_2018-12-04_697_825_4793_4921_10_90_74_154.npz +3 -0
- GANFilling/data/train_data/33UXQ/33UXQ_2018-07-18_2018-12-14_825_953_4537_4665_12_92_70_150.npz +3 -0
- GANFilling/data/train_data/33UXQ/33UXQ_2018-07-28_2018-12-24_4025_4153_2105_2233_62_142_32_112.npz +3 -0
- GANFilling/results/context_29TQF_2017-08-16_2018-01-12_5049_5177_441_569_78_158_6_86.npz.png +3 -0
- GANFilling/results/context_30UYV_2018-06-19_2018-11-15_1977_2105_2617_2745_30_110_40_120.npz.png +3 -0
- GANFilling/results/context_33SXD_2019-05-16_2019-10-12_4025_4153_3129_3257_62_142_48_128.npz.png +3 -0
- GANFilling/results/context_34TDP_2018-04-28_2018-09-24_2233_2361_5049_5177_34_114_78_158.npz.png +3 -0
- GANFilling/src/iterator.py +111 -0
- GANFilling/src/models/convlstm.py +200 -0
- GANFilling/src/models/discriminator.py +60 -0
- GANFilling/src/models/generator.py +123 -0
- GANFilling/src/test.py +138 -0
- GANFilling/src/train.py +258 -0
- GANFilling/src/utils/example_clean_data.json +83 -0
- GANFilling/src/utils/generate_cleanData_file.py +84 -0
.gitattributes
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
GANFilling/data/test_data/context_29TQF_2017-08-16_2018-01-12_5049_5177_441_569_78_158_6_86.npz filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
GANFilling/data/test_data/context_30UYV_2018-06-19_2018-11-15_1977_2105_2617_2745_30_110_40_120.npz filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
GANFilling/data/test_data/context_33SXD_2019-05-16_2019-10-12_4025_4153_3129_3257_62_142_48_128.npz filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
GANFilling/data/test_data/context_34TDP_2018-04-28_2018-09-24_2233_2361_5049_5177_34_114_78_158.npz filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
GANFilling/data/train_data/33UXQ/33UXQ_2018-03-30_2018-08-26_2873_3001_3513_3641_44_124_54_134.npz filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
GANFilling/data/train_data/33UXQ/33UXQ_2018-06-18_2018-11-14_2105_2233_3769_3897_32_112_58_138.npz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
GANFilling/data/train_data/33UXQ/33UXQ_2018-06-18_2018-11-14_2745_2873_3897_4025_42_122_60_140.npz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
GANFilling/data/train_data/33UXQ/33UXQ_2018-07-08_2018-12-04_2617_2745_3513_3641_40_120_54_134.npz filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
GANFilling/data/train_data/33UXQ/33UXQ_2018-07-08_2018-12-04_697_825_4793_4921_10_90_74_154.npz filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
GANFilling/data/train_data/33UXQ/33UXQ_2018-07-18_2018-12-14_825_953_4537_4665_12_92_70_150.npz filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
GANFilling/data/train_data/33UXQ/33UXQ_2018-07-28_2018-12-24_4025_4153_2105_2233_62_142_32_112.npz filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
GANFilling/GANFilling_percept_results.png filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
GANFilling/results/context_29TQF_2017-08-16_2018-01-12_5049_5177_441_569_78_158_6_86.npz.png filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
GANFilling/results/context_30UYV_2018-06-19_2018-11-15_1977_2105_2617_2745_30_110_40_120.npz.png filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
GANFilling/results/context_33SXD_2019-05-16_2019-10-12_4025_4153_3129_3257_62_142_48_128.npz.png filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
GANFilling/results/context_34TDP_2018-04-28_2018-09-24_2233_2361_5049_5177_34_114_78_158.npz.png filter=lfs diff=lfs merge=lfs -text
|
GANFilling/.gitattributes
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
GANFilling/.gitignore
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# pycache
|
| 2 |
+
__pycache__
|
| 3 |
+
*/__pycache__
|
GANFilling/GANFilling_percept_results.png
ADDED
|
Git LFS Details
|
GANFilling/README.md
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generative Networks for Spatio-Temporal Gap Filling of Sentinel-2 Reflectances
|
| 2 |
+
---------------
|
| 3 |
+
| [Journal ISPRS paper](https://doi.org/10.1016/j.isprsjprs.2025.01.016) |
|
| 4 |
+
|
| 5 |
+
# Abstract
|
| 6 |
+
|
| 7 |
+
Earth observation from satellite sensors offers the possibility to monitor natural ecosystems by deriving spatially explicit and temporally resolved biogeophysical parameters. Optical remote sensing, however, suffers from missing data mainly due to the presence of clouds, sensor malfunctioning, and atmospheric conditions. This study proposes a novel deep learning architecture to address gap filling of satellite reflectances, more precisely the visible and near-infrared bands, and illustrates its performance at high-resolution Sentinel-2 data. We introduce GANFilling, a generative adversarial network capable of sequence-to-sequence translation, which comprises convolutional long short-term memory layers to effectively exploit complete dependencies in space-time series data. We focus on Europe and evaluate the method's performance quantitatively (through distortion and perceptual metrics) and qualitatively (via visual inspection and visual quality metrics). Quantitatively, our model offers the best trade-off between denoising corrupted data and preserving noise-free information, underscoring the importance of considering multiple metrics jointly when assessing gap filling tasks. Qualitatively, it successfully deals with various noise sources, such as clouds and missing data, constituting a robust solution to multiple scenarios and settings. We also illustrate and quantify the quality of the generated product in the relevant downstream application of vegetation greenness forecasting, where using GANFilling enhances forecasting in approximately 70% of the considered regions in Europe. This research contributes to underlining the utility of deep learning for Earth observation data, which allows for improved spatially and temporally resolved monitoring of the Earth surface.
|
| 8 |
+
|
| 9 |
+
# Usage and Requirements
|
| 10 |
+
|
| 11 |
+
## Train
|
| 12 |
+
|
| 13 |
+
`
|
| 14 |
+
python src/train.py --dataPath data/train_data --cleanDataPath src/utils/example_clean_data.json --name model_1
|
| 15 |
+
`
|
| 16 |
+
|
| 17 |
+
Models are saved to ./trained_models/ (can be changed by passing --modelsPath=your_dir in train.py).
|
| 18 |
+
|
| 19 |
+
For training the discriminator, a json file with the noiseless (real) data has to be generated. This can be done with ./utils/generate_cleanData_file.py. An example of the structure of this file can be found in ./utils/example_clean_data.json.
|
| 20 |
+
|
| 21 |
+
The folder with the training data is specified with the `--dataPath` argument. This repository just includes few samples in data/train_data/ as a reference for structure and behaviour checking.
|
| 22 |
+
|
| 23 |
+
## Test
|
| 24 |
+
|
| 25 |
+
`
|
| 26 |
+
python src/test.py --model_path trained_models/GANFilling.pt --data_path data --results_path results
|
| 27 |
+
`
|
| 28 |
+
|
| 29 |
+
This will run the GANFilling trained model on a small set of examples and generate the corresponding gap filled time series.
|
| 30 |
+
|
| 31 |
+
To test your own model modify the parameter `--model_path` in test.py.
|
| 32 |
+
To test on your own data modify the parameter `--data_path` in test.py.
|
| 33 |
+
|
| 34 |
+
# Results
|
| 35 |
+
|
| 36 |
+

|
| 37 |
+
|
| 38 |
+
<b>Example images showing the GANFilling reconstruction on different land covers.</b> For each example, the first row shows the land cover map followed by ten original time steps of a visible (RGB) sequence, while the second row corresponds to its noise-free version. All images are noted with the day-of-year (DOY). The third row illustrates the NDVI maps for the noise-free images. Different types of noise are outlined in red. (A) Complexe scene with predominant herbaceous vegetation and multiple frames with complete loss of information. (B) Sequence with mostly cultivated areas to show the performance on fast changes in the Earth’s surface with heavily occluded frames. (C) Sequence with widespread vines characterized by a rapid evolution of the land cover. (D) Predominant coniferous tree cover with a water body nearby. (E) Sequence with predominant broadleaf tree cover and several consecutive time steps with dense occlusions. Land cover’s legend: <span style="color:rgb(255,255,255)">◻</span> Cloud or No data, <span style="color:rgb(210,0,0)">◼</span> Artificial surfaces and constructions, <span style="color:rgb(253,211,39)">◼</span> Cultivated areas, <span style="color:rgb(176,91,16)">◼</span> Vineyards, <span style="color:rgb(35,152,0)">◼</span> Broadleaf tree cover, <span style="color:rgb(8,98,0)">◼</span> Coniferous tree cover, <span style="color:rgb(249,150,39)">◼</span> Herbaceous vegetation, <span style="color:rgb(141,139,0)">◼</span> Moors and Heathland, <span style="color:rgb(95,53,6)">◼</span> Sclerophyllous vegetation, <span style="color:rgb(149,107,196)">◼</span> Marshes, <span style="color:rgb(77,37,106)">◼</span> Peatbogs, <span style="color:rgb(154,154,154)">◼</span> Natural material surfaces, <span style="color:rgb(106,255,255)">◼</span> Permanent snow covered surfaces, <span style="color:rgb(20,69,249)">◼</span> Water bodies.
|
| 39 |
+
|
| 40 |
+
# How to cite
|
| 41 |
+
|
| 42 |
+
If you use this code for your research, please cite our paper Generative Networks for Spatio-Temporal Gap Filling of Sentinel-2 Reflectances:
|
| 43 |
+
|
| 44 |
+
```
|
| 45 |
+
@article{GonzalezCalabuig2025,
|
| 46 |
+
title = {Generative networks for spatio-temporal gap filling of Sentinel-2 reflectances},
|
| 47 |
+
journal = {ISPRS Journal of Photogrammetry and Remote Sensing},
|
| 48 |
+
volume = {220},
|
| 49 |
+
pages = {637-648},
|
| 50 |
+
year = {2025},
|
| 51 |
+
issn = {0924-2716},
|
| 52 |
+
doi = {https://doi.org/10.1016/j.isprsjprs.2025.01.016},
|
| 53 |
+
author = {Maria Gonzalez-Calabuig and Miguel-Ángel Fernández-Torres and Gustau Camps-Valls}}
|
| 54 |
+
```
|
| 55 |
+
|
| 56 |
+
# Acknowledgments
|
| 57 |
+
|
| 58 |
+
Authors acknowledge the support from the European Research Council (ERC) under the ERC Synergy Grant USMILE (grant agreement 855187), the European Union’s Horizon 2020 research and innovation program within the projects ‘XAIDA: Extreme Events - Artificial Intelligence for Detection and Attribution’ (grant agreement 101003469), ‘DeepCube: Explainable AI pipelines for big Copernicus data’ (grant agreement 101004188), the ESA AI4Science project ‘MultiHazards, Compounds and Cascade events: DeepExtremes’, 2022-2024, the computer resources provided by the Jülich Supercomputing Centre (JSC) (Project No. PRACE-DEV-2022D01-048), the computer resources provided by Artemisa (funded by the European Union ERDF and Comunitat Valenciana), as well as the technical support provided by the Instituto de Física Corpuscular, IFIC (CSIC-UV).
|
| 59 |
+
|
| 60 |
+
# License
|
| 61 |
+
|
| 62 |
+
[MIT]()
|
GANFilling/data/landcover_types.csv
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
0,255,255,255,255,Clouds or No data
|
| 2 |
+
62,210,0,0,255,Artificial surfaces and constructions
|
| 3 |
+
73,253,211,39,255,Cultivated areas
|
| 4 |
+
75,176,91,16,255,Vineyards
|
| 5 |
+
82,35,152,0,255,Broadleaf tree cover
|
| 6 |
+
83,8,98,0,255,Coniferous tree cover
|
| 7 |
+
102,249,150,39,255,Herbaceous vegetation
|
| 8 |
+
103,141,139,0,255,Moors and Heathland
|
| 9 |
+
104,95,53,6,255,Sclerophyllous vegetation
|
| 10 |
+
105,149,107,196,255,Marshes
|
| 11 |
+
106,77,37,106,255,Peatbogs
|
| 12 |
+
121,154,154,154,255,Natural material surfaces
|
| 13 |
+
123,106,255,255,255,Permanent snow covered surfaces
|
| 14 |
+
162,20,69,249,255,Water bodies
|
| 15 |
+
255,255,255,255,255,No data
|
GANFilling/data/test_data/context_29TQF_2017-08-16_2018-01-12_5049_5177_441_569_78_158_6_86.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8ff96d1e04dc8f23e1caa09dfb3b65adc4111393729ed3b963a87b1b5185c048
|
| 3 |
+
size 4735175
|
GANFilling/data/test_data/context_30UYV_2018-06-19_2018-11-15_1977_2105_2617_2745_30_110_40_120.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:603f455b64f8044910fa78b6dee26f335b924f2fb6f9cbf03539bfbae6c4f037
|
| 3 |
+
size 4198884
|
GANFilling/data/test_data/context_33SXD_2019-05-16_2019-10-12_4025_4153_3129_3257_62_142_48_128.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e15904f0feee4a86d1ef931d97d34a3b5b7d4e855aed218fad0bfd8dd50ca839
|
| 3 |
+
size 3328672
|
GANFilling/data/test_data/context_34TDP_2018-04-28_2018-09-24_2233_2361_5049_5177_34_114_78_158.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:716193f3fea3641ce5c94b4169fb047f6ac224ae5d13720cf0a413b2520322fe
|
| 3 |
+
size 5594634
|
GANFilling/data/test_data/landcover/29TQF_2017-08-16_2018-01-12_5049_5177_441_569_78_158_6_86.npz
ADDED
|
Binary file (4.23 kB). View file
|
|
|
GANFilling/data/test_data/landcover/30UYV_2018-06-19_2018-11-15_1977_2105_2617_2745_30_110_40_120.npz
ADDED
|
Binary file (2.7 kB). View file
|
|
|
GANFilling/data/test_data/landcover/33SXD_2019-05-16_2019-10-12_4025_4153_3129_3257_62_142_48_128.npz
ADDED
|
Binary file (3.6 kB). View file
|
|
|
GANFilling/data/test_data/landcover/34TDP_2018-04-28_2018-09-24_2233_2361_5049_5177_34_114_78_158.npz
ADDED
|
Binary file (3.32 kB). View file
|
|
|
GANFilling/data/train_data/33UXQ/33UXQ_2018-03-30_2018-08-26_2873_3001_3513_3641_44_124_54_134.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:372e0bf759d6856b57dd9a5295b9cb03d1ad1507940d4d8b6be582d6d09f4939
|
| 3 |
+
size 8178516
|
GANFilling/data/train_data/33UXQ/33UXQ_2018-06-18_2018-11-14_2105_2233_3769_3897_32_112_58_138.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f7127004d5dd01befad7f786fea024f522a75fb8942a49a220b7843a36e1fb7f
|
| 3 |
+
size 8292551
|
GANFilling/data/train_data/33UXQ/33UXQ_2018-06-18_2018-11-14_2745_2873_3897_4025_42_122_60_140.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a9bcebacc479854a4715c8096a7ff6d90de208f2e6df0259a3b594390902b28d
|
| 3 |
+
size 8150040
|
GANFilling/data/train_data/33UXQ/33UXQ_2018-07-08_2018-12-04_2617_2745_3513_3641_40_120_54_134.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c9ff886ee2e13a9201a3ec2341bea317ba6eb28bc944dd458e5c51b2c94603d2
|
| 3 |
+
size 8096535
|
GANFilling/data/train_data/33UXQ/33UXQ_2018-07-08_2018-12-04_697_825_4793_4921_10_90_74_154.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bd8ce22607591b366ba2b4ac61a4df19b0fd58c2f7cfaa94504f6d6599134e2f
|
| 3 |
+
size 8291712
|
GANFilling/data/train_data/33UXQ/33UXQ_2018-07-18_2018-12-14_825_953_4537_4665_12_92_70_150.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1d522b678b06e37d7ce62b4d898bed0fec187a72711efe3c7fc4d7dada43b083
|
| 3 |
+
size 8400326
|
GANFilling/data/train_data/33UXQ/33UXQ_2018-07-28_2018-12-24_4025_4153_2105_2233_62_142_32_112.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:143075cf0da3467a2d25ff4a69f290cf497f43b846147bb9b13167b5c4710068
|
| 3 |
+
size 7970874
|
GANFilling/results/context_29TQF_2017-08-16_2018-01-12_5049_5177_441_569_78_158_6_86.npz.png
ADDED
|
Git LFS Details
|
GANFilling/results/context_30UYV_2018-06-19_2018-11-15_1977_2105_2617_2745_30_110_40_120.npz.png
ADDED
|
Git LFS Details
|
GANFilling/results/context_33SXD_2019-05-16_2019-10-12_4025_4153_3129_3257_62_142_48_128.npz.png
ADDED
|
Git LFS Details
|
GANFilling/results/context_34TDP_2018-04-28_2018-09-24_2233_2361_5049_5177_34_114_78_158.npz.png
ADDED
|
Git LFS Details
|
GANFilling/src/iterator.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import json
|
| 3 |
+
import torch
|
| 4 |
+
import random
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
from kornia.augmentation import RandomHorizontalFlip, RandomVerticalFlip
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Iterator():
    """Dataset-style iterator yielding (noisy sequence, clean sequence, masks) triples.

    Samples are .npz files whose "highresdynamic" array is indexed as
    (H, W, C, T) -- assumed from the slicing below; confirm against the data
    generation pipeline. The last channel is taken as the noise/cloud mask.
    """

    def __init__(self, dataPath, cleanDataPath, mode):
        self.mode = mode
        # Fixed seed so the train/val file split is reproducible across runs.
        self.random_seed_data = 42
        self.__initData(dataPath, cleanDataPath)

    def __initData(self, dataPath, cleanDataPath):
        """Load the training file list and the clean-data index (train mode only)."""
        self.data = {}

        if self.mode == 'train':
            self.__loadTrainData(dataPath)
            # Close the handle explicitly instead of leaking it
            # (was: json.load(open(cleanDataPath))).
            with open(cleanDataPath) as f:
                self.cleanData = json.load(f)

    def __loadTrainData(self, dataPath):
        """Collect the first 80% (after a seeded shuffle) of every tile's files."""
        tiles = os.listdir(dataPath)
        if 'LICENSE' in tiles:
            tiles.remove('LICENSE')
        tiles.sort()

        loadedData = []
        for tile in tiles:
            # os.path.join accepts both str and pathlib.Path; the previous
            # `dataPath / tile` crashed when dataPath was a plain string
            # (as passed on the command line per the README).
            in_tile_path = os.path.join(dataPath, tile)
            in_files = [os.path.join(in_tile_path, f) for f in sorted(os.listdir(in_tile_path))]
            # Deterministic shuffle, then keep 80% as the training split.
            random.Random(self.random_seed_data).shuffle(in_files)
            loadedData.extend(in_files[:int(len(in_files) * 0.8)])
        self.data = loadedData

    def __len__(self):
        return len(self.data)

    def __getMasks(self, sample, index):
        # Last channel holds the per-pixel mask for time steps 0..index.
        return sample["highresdynamic"][:, :, -1, :index + 1]

    def __find_closest_area(self, condition_1):
        """Return a key of self.cleanData matching `condition_1` (e.g. '33U'),
        falling back to the neighbouring UTM zone number (+/- 1).

        Raises KeyError with a clear message when no match exists (previously
        this silently returned None, surfacing later as `self.cleanData[None]`).
        """
        if condition_1 in self.cleanData:
            return condition_1
        for offset in (1, -1):
            candidate = str(int(condition_1[:2]) + offset) + condition_1[-1]
            if candidate in self.cleanData:
                return candidate
        raise KeyError(f"No clean data available for area '{condition_1}' or its neighbours")

    def __getCleanSequence(self, condition_1, condition_2):
        """Sample a random noise-free crop of `time_length` steps for the discriminator.

        `condition_2` is currently unused; kept for interface compatibility.
        """
        time_length = 4
        condition_1 = self.__find_closest_area(condition_1)
        data = self.cleanData[condition_1]

        # Rejection-sample until a sequence has enough time steps; bounded so a
        # degenerate clean-data file cannot hang training forever.
        for _ in range(10000):
            attributes = data[torch.randint(len(data), (1,)).item()]
            if len(attributes['time steps']) >= time_length:
                break
        else:
            raise RuntimeError(
                f"No clean sequence with >= {time_length} time steps for area '{condition_1}'")

        kernelSize = attributes['kernel size']
        sample = np.load(attributes['path'])
        # bbox is (y, x) of the crop's top-left corner; crop is kernelSize square.
        x_min = attributes['bbox'][1]
        x_max = x_min + kernelSize
        y_min = attributes['bbox'][0]
        y_max = y_min + kernelSize

        t0 = attributes['time steps'][0]
        discriminator_sample = sample["highresdynamic"][y_min:y_max, x_min:x_max, 0:4, t0:t0 + time_length]
        discriminator_sample = torch.from_numpy(discriminator_sample)

        # Data augmentation: horizontal flip. The random decision drawn for the
        # first time step is reused (via _params) for every other step so the
        # whole sequence is flipped consistently.
        transformation_1 = RandomHorizontalFlip(p=0.4)
        discriminator_sample[:, :, :, 0] = transformation_1(discriminator_sample[:, :, :, 0])
        for t in range(1, discriminator_sample.shape[3]):
            discriminator_sample[:, :, :, t] = transformation_1(
                discriminator_sample[:, :, :, t], params=transformation_1._params)

        # Data augmentation: vertical flip (same shared-params scheme).
        transformation_2 = RandomVerticalFlip(p=0.4)
        discriminator_sample[:, :, :, 0] = transformation_2(discriminator_sample[:, :, :, 0])
        for t in range(1, discriminator_sample.shape[3]):
            discriminator_sample[:, :, :, t] = transformation_2(
                discriminator_sample[:, :, :, t], params=transformation_2._params)

        return discriminator_sample.numpy()

    def __getitem__(self, index):
        self.index = index
        sample = np.load(self.data[index])

        context = 10  # number of time steps fed to the generator
        noisyImg = sample["highresdynamic"][:, :, 0:4, :context]

        masks = self.__getMasks(sample, context - 1)
        # Area id = first 3 chars of the file name (tile prefix, e.g. '33U').
        # basename() is robust to the depth of dataPath and to the OS path
        # separator (was: self.data[index].split('/')[3][:3]).
        cleanImg = self.__getCleanSequence(os.path.basename(self.data[index])[:3], None)

        # Clip reflectances to [0, 1]; NaNs (clouds / missing data) become 1.
        noisyImg = np.nan_to_num(np.clip(noisyImg, 0, 1), nan=1.0)
        cleanImg = np.nan_to_num(np.clip(cleanImg, 0, 1), nan=1.0)

        # (H, W, C, T) -> (C, H, W, T)
        return np.transpose(noisyImg, (2, 0, 1, 3)), np.transpose(cleanImg, (2, 0, 1, 3)), masks
GANFilling/src/models/convlstm.py
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn.functional as F
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class ConvLSTMCell(nn.Module):
    """Single ConvLSTM cell: one convolution computes all four gate pre-activations."""

    def __init__(self, input_dim, hidden_dim, kernel_size, bias, device=None):
        """
        Initialize ConvLSTM cell.

        Parameters
        ----------
        input_dim: int
            Number of channels of input tensor.
        hidden_dim: int
            Number of channels of hidden state.
        kernel_size: (int, int)
            Size of the convolutional kernel.
        bias: bool
            Whether or not to add the bias.
        device: torch.device or str, optional
            Device for lazily created hidden states. Defaults to the device of
            the cell's own weights, so callers (e.g. ConvLSTM, which builds
            cells without a device argument) may omit it.
        """
        super(ConvLSTMCell, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.kernel_size = kernel_size
        # "Same" padding so the spatial size is preserved.
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias
        self.device = device

        # Single conv producing the i, f, o, g gates stacked on the channel axis.
        self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
                              out_channels=4 * self.hidden_dim,
                              kernel_size=self.kernel_size,
                              padding=self.padding,
                              bias=self.bias)

    def __initStates(self, size):
        """Zero (h, c) states, on self.device or the weights' device as fallback."""
        device = self.device if self.device is not None else self.conv.weight.device
        return torch.zeros(size, device=device), torch.zeros(size, device=device)

    def forward(self, input_tensor, cur_state):
        """One step: returns (h_next, c_next). `cur_state` may be None (zero init)."""
        # `is None` instead of `== None`: identity is the correct check here.
        if cur_state is None:
            h_cur, c_cur = self.__initStates(
                [input_tensor.shape[0], self.hidden_dim, input_tensor.shape[2], input_tensor.shape[3]])
        else:
            h_cur, c_cur = cur_state

        combined = torch.cat([input_tensor, h_cur], dim=1)  # concatenate along channel axis
        combined_conv = self.conv(combined)
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)

        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)

        c_next = f * c_cur + i * g
        h_next = o * torch.tanh(c_next)

        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        """Zero (h, c) states on the same device as the cell's weights."""
        height, width = image_size
        return (torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device),
                torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device))
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class ConvLSTM(nn.Module):

    """
    Multi-layer convolutional LSTM.

    Parameters:
        input_dim: Number of channels in input
        hidden_dim: Number of hidden channels
        kernel_size: Size of kernel in convolutions
        num_layers: Number of LSTM layers stacked on each other
        batch_first: Whether or not dimension 0 is the batch or not
        bias: Bias or no bias in Convolution
        return_all_layers: Return the list of computations for all layers
        device: Optional device forwarded to each ConvLSTMCell (may stay None;
            states created in forward() follow the cells' weight device)
    Note: Will do same padding.

    Input:
        A tensor of size B, T, C, H, W or T, B, C, H, W
    Output:
        A tuple of two lists of length num_layers (or length 1 if return_all_layers is False).
        0 - layer_output_list is the list of lists of length T of each output
        1 - last_state_list is the list of last states
                each element of the list is a tuple (h, c) for hidden state and memory
    Example:
        >> x = torch.rand((32, 10, 64, 128, 128))
        >> convlstm = ConvLSTM(64, 16, 3, 1, True, True, False)
        >> _, last_states = convlstm(x)
        >> h = last_states[0][0]  # 0 for layer index, 0 for h index
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, num_layers,
                 batch_first=False, bias=True, return_all_layers=False, device=None):
        super(ConvLSTM, self).__init__()

        self._check_kernel_size_consistency(kernel_size)

        # Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers
        kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
        hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
        if not len(kernel_size) == len(hidden_dim) == num_layers:
            raise ValueError('Inconsistent list length.')

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.bias = bias
        self.return_all_layers = return_all_layers

        cell_list = []
        for i in range(0, self.num_layers):
            # Layer 0 reads the input channels; deeper layers read the previous
            # layer's hidden channels.
            cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]

            # BUG FIX: ConvLSTMCell requires a `device` argument; it was not
            # being passed here, which made construction raise a TypeError.
            cell_list.append(ConvLSTMCell(input_dim=cur_input_dim,
                                          hidden_dim=self.hidden_dim[i],
                                          kernel_size=self.kernel_size[i],
                                          bias=self.bias,
                                          device=device))

        self.cell_list = nn.ModuleList(cell_list)

    def forward(self, input_tensor, hidden_state=None):
        """
        Run the stacked ConvLSTM over a full sequence.

        Parameters
        ----------
        input_tensor:
            5-D Tensor either of shape (t, b, c, h, w) or (b, t, c, h, w)
        hidden_state:
            Must be None (stateful operation is not implemented).

        Returns
        -------
        layer_output_list, last_state_list
        """
        if not self.batch_first:
            # (t, b, c, h, w) -> (b, t, c, h, w)
            input_tensor = input_tensor.permute(1, 0, 2, 3, 4)

        b, _, _, h, w = input_tensor.size()

        # Stateful ConvLSTM is not implemented; always start from zero states.
        if hidden_state is not None:
            raise NotImplementedError()
        else:
            # Since the init is done in forward, the image size can be sent here.
            hidden_state = self._init_hidden(batch_size=b,
                                             image_size=(h, w))

        layer_output_list = []
        last_state_list = []

        seq_len = input_tensor.size(1)
        cur_layer_input = input_tensor

        for layer_idx in range(self.num_layers):

            h, c = hidden_state[layer_idx]
            output_inner = []
            for t in range(seq_len):
                h, c = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :, :, :],
                                                 cur_state=[h, c])
                output_inner.append(h)

            layer_output = torch.stack(output_inner, dim=1)
            cur_layer_input = layer_output

            layer_output_list.append(layer_output)
            last_state_list.append([h, c])

        if not self.return_all_layers:
            layer_output_list = layer_output_list[-1:]
            last_state_list = last_state_list[-1:]

        return layer_output_list, last_state_list

    def _init_hidden(self, batch_size, image_size):
        """Zero (h, c) pairs for every layer, created on each cell's weight device."""
        init_states = []
        for i in range(self.num_layers):
            init_states.append(self.cell_list[i].init_hidden(batch_size, image_size))
        return init_states

    @staticmethod
    def _check_kernel_size_consistency(kernel_size):
        if not (isinstance(kernel_size, tuple) or
                (isinstance(kernel_size, list) and all([isinstance(elem, tuple) for elem in kernel_size]))):
            raise ValueError('`kernel_size` must be tuple or list of tuples')

    @staticmethod
    def _extend_for_multilayer(param, num_layers):
        """Broadcast a scalar hyperparameter to a per-layer list."""
        if not isinstance(param, list):
            param = [param] * num_layers
        return param
GANFilling/src/models/discriminator.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from torch.nn.utils.parametrizations import spectral_norm
|
| 5 |
+
|
| 6 |
+
from convlstm import ConvLSTMCell
|
| 7 |
+
|
| 8 |
+
class Discriminator(nn.Module):  # PatchGAN
    """PatchGAN-style discriminator with ConvLSTM temporal memory.

    Scores each (patch, time step) of a (b, c, h, w, t) input; two
    ConvLSTM cells carry recurrent state across the time loop so the
    score of a frame depends on the frames before it.
    """

    def __init__(self, device, inputChannels=6, d=64):
        super().__init__()
        self.device = device
        # Three stride-2 convs downsample by 8 in total.
        self.conv1 = nn.Conv2d(inputChannels, d, 4, 2, 1)
        self.conv2 = nn.Conv2d(d, d * 2, 4, 2, 1)
        self.conv2_bn = nn.BatchNorm2d(d * 2, track_running_stats=False)
        self.conv3 = nn.Conv2d(d * 2, d * 4, 4, 2, 1)
        self.conv3_bn = nn.BatchNorm2d(d * 4, track_running_stats=False)
        self.conv4_lstm = ConvLSTMCell(d * 4, d * 4, (3, 3), False, self.device)
        self.conv4 = nn.Conv2d(d * 4, d * 8, 4, 1, 1)
        self.conv4_bn = nn.BatchNorm2d(d * 8, track_running_stats=False)
        self.conv5_lstm = ConvLSTMCell(d * 8, d * 8, (3, 3), False, self.device)
        self.conv5 = nn.Conv2d(d * 8, 1, 4, 1, 1)

        # NOTE(review): not used by forward(); kept so existing checkpoints
        # still load with strict state_dict matching.
        self.conv_fusion = nn.Conv2d(4, 1, 1)

        torch.backends.cudnn.deterministic = True

    def weight_init(self, mean, std):
        """Apply normal_init(mean, std) to every direct submodule."""
        for m in self._modules:
            normal_init(self._modules[m], mean, std)

    def forward_step(self, input, states):
        """Score one (1, c, h, w) frame.

        Returns the raw (pre-sigmoid) score map and the updated pair of
        ConvLSTM states for the next time step.
        """
        x = F.leaky_relu(self.conv1(input), 0.2)
        x = F.leaky_relu(self.conv2_bn(self.conv2(x)), 0.2)
        x = F.leaky_relu(self.conv3_bn(self.conv3(x)), 0.2)
        states1 = self.conv4_lstm(x, states[0])
        x = F.leaky_relu(self.conv4_bn(self.conv4(states1[0])), 0.2)
        states2 = self.conv5_lstm(x, states[1])
        x = F.leaky_relu(self.conv5(states2[0]), 0.2)
        return x.squeeze(dim=1), [states1, states2]

    def forward(self, tensor):
        """Score every patch/time step of a (b, c, h, w, t) tensor.

        Returns sigmoid score maps of shape (b, h/8-2, w/8-2, t) together
        with the final LSTM states of the last patch processed.
        """
        output = torch.empty((tensor.shape[0], int(tensor.shape[2] / 8) - 2, int(tensor.shape[3] / 8) - 2, tensor.shape[4])).to(self.device)
        for patch in range(tensor.shape[0]):
            # Fresh recurrent state per patch; forward_step only reads
            # indices 0 and 1 (the original allocated four slots).
            states = (None, None)
            for timeStep in range(tensor.shape[4]):
                output[patch, :, :, timeStep], states = self.forward_step(tensor[patch, :, :, :, timeStep].unsqueeze(dim=0), states)

        # BUG FIX: F.sigmoid is deprecated (removed in recent torch);
        # torch.sigmoid is numerically identical.
        return torch.sigmoid(output), states
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def normal_init(m, mean, std):
    """Initialise conv weights in-place from N(mean, std) and zero biases.

    Modules other than Conv2d / ConvTranspose2d are left untouched.
    Uses torch.nn.init (the supported API) instead of mutating ``.data``,
    and tolerates convolutions built with ``bias=False``.
    """
    if isinstance(m, (nn.ConvTranspose2d, nn.Conv2d)):
        nn.init.normal_(m.weight, mean, std)
        if m.bias is not None:  # bias can be disabled on conv layers
            nn.init.zeros_(m.bias)
|
| 55 |
+
|
| 56 |
+
if __name__ == '__main__':
    # Smoke test: 16 patches, 3 channels, 32x32 pixels, 4 time steps.
    dummy_batch = torch.zeros((16, 3, 32, 32, 4), dtype=torch.float32)
    discriminator = Discriminator('cpu', inputChannels=3)
    scores, _states = discriminator(dummy_batch)
    print(scores.size())
|
GANFilling/src/models/generator.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
sys.path.append('src/models')
|
| 7 |
+
from convlstm import ConvLSTMCell
|
| 8 |
+
|
| 9 |
+
class Generator(nn.Module):
    """Recurrent U-Net generator for temporal gap filling.

    A 7-level strided-conv encoder and a ConvLSTM decoder with skip
    connections; one ConvLSTM cell per scale carries state across time
    steps, so a (b, c, h, w, t) sequence is processed frame by frame
    through forward_step().
    """

    def __init__(self, device, inputChannels=4, outputChannels=3, d=64):
        super().__init__()
        self.d = d
        self.device = device

        # Encoder: each conv halves the spatial resolution (stride 2).
        self.conv1 = nn.Conv2d(inputChannels, d, 3, 2, 1)
        self.conv2 = nn.Conv2d(d, d * 2, 3, 2, 1)
        self.conv3 = nn.Conv2d(d * 2, d * 4, 3, 2, 1)
        self.conv4 = nn.Conv2d(d * 4, d * 8, 3, 2, 1)
        self.conv5 = nn.Conv2d(d * 8, d * 8, 3, 2, 1)
        self.conv6 = nn.Conv2d(d * 8, d * 8, 3, 2, 1)
        self.conv7 = nn.Conv2d(d * 8, d * 8, 3, 2, 1)

        # Decoder ConvLSTM cells; input channels double wherever a skip
        # connection is concatenated before the cell.
        self.conv_lstm_d1 = ConvLSTMCell(d * 8, d * 8, (3, 3), False, device)
        self.conv_lstm_d2 = ConvLSTMCell(d * 8 * 2, d * 8, (3, 3), False, device)
        self.conv_lstm_d3 = ConvLSTMCell(d * 8 * 2, d * 8, (3, 3), False, device)
        self.conv_lstm_d4 = ConvLSTMCell(d * 8 * 2, d * 4, (3, 3), False, device)
        self.conv_lstm_d5 = ConvLSTMCell(d * 4 * 2, d * 2, (3, 3), False, device)
        self.conv_lstm_d6 = ConvLSTMCell(d * 2 * 2, d, (3, 3), False, device)
        self.conv_lstm_d7 = ConvLSTMCell(d * 2, d, (3, 3), False, device)

        # Encoder ConvLSTM cells. NOTE(review): conv_lstm_e7 is never used
        # by forward_step; kept so existing checkpoints still load.
        self.conv_lstm_e1 = ConvLSTMCell(d, d, (3, 3), False, device)
        self.conv_lstm_e2 = ConvLSTMCell(d * 2, d * 2, (3, 3), False, device)
        self.conv_lstm_e3 = ConvLSTMCell(d * 4, d * 4, (3, 3), False, device)
        self.conv_lstm_e4 = ConvLSTMCell(d * 8, d * 8, (3, 3), False, device)
        self.conv_lstm_e5 = ConvLSTMCell(d * 8, d * 8, (3, 3), False, device)
        self.conv_lstm_e6 = ConvLSTMCell(d * 8, d * 8, (3, 3), False, device)
        self.conv_lstm_e7 = ConvLSTMCell(d * 8, d * 8, (3, 3), False, device)

        self.up = nn.Upsample(scale_factor=2)
        self.conv_out = nn.Conv2d(d, outputChannels, 3, 1, 1)

        self.slope = 0.2  # negative slope of the encoder leaky ReLUs

    def weight_init(self, mean, std):
        """Apply normal_init(mean, std) to every direct submodule."""
        for m in self._modules:
            normal_init(self._modules[m], mean, std)

    def forward_step(self, input, states_encoder, states_decoder):
        """Process one time step.

        Returns the generated frame and the updated
        (encoder states, decoder states) to feed into the next step.
        """
        e1 = self.conv1(input)
        states_e1 = self.conv_lstm_e1(e1, states_encoder[0])
        e2 = self.conv2(F.leaky_relu(states_e1[0], self.slope))
        states_e2 = self.conv_lstm_e2(e2, states_encoder[1])
        e3 = self.conv3(F.leaky_relu(states_e2[0], self.slope))
        states_e3 = self.conv_lstm_e3(e3, states_encoder[2])
        e4 = self.conv4(F.leaky_relu(states_e3[0], self.slope))
        states_e4 = self.conv_lstm_e4(e4, states_encoder[3])
        e5 = self.conv5(F.leaky_relu(states_e4[0], self.slope))
        states_e5 = self.conv_lstm_e5(e5, states_encoder[4])
        e6 = self.conv6(F.leaky_relu(states_e5[0], self.slope))
        states_e6 = self.conv_lstm_e6(e6, states_encoder[5])
        e7 = self.conv7(F.leaky_relu(states_e6[0], self.slope))

        # Decoder: upsample, then concatenate the matching encoder skip.
        states1 = self.conv_lstm_d1(F.relu(e7), states_decoder[0])
        d1 = self.up(states1[0])
        d1 = torch.cat([d1, e6], 1)

        states2 = self.conv_lstm_d2(F.relu(d1), states_decoder[1])
        d2 = self.up(states2[0])
        d2 = torch.cat([d2, e5], 1)

        states3 = self.conv_lstm_d3(F.relu(d2), states_decoder[2])
        d3 = self.up(states3[0])
        d3 = torch.cat([d3, e4], 1)

        states4 = self.conv_lstm_d4(F.relu(d3), states_decoder[3])
        d4 = self.up(states4[0])
        d4 = torch.cat([d4, e3], 1)

        states5 = self.conv_lstm_d5(F.relu(d4), states_decoder[4])
        d5 = self.up(states5[0])
        d5 = torch.cat([d5, e2], 1)

        states6 = self.conv_lstm_d6(F.relu(d5), states_decoder[5])
        d6 = self.up(states6[0])
        d6 = torch.cat([d6, e1], 1)

        states7 = self.conv_lstm_d7(F.relu(d6), states_decoder[6])
        d7 = self.up(states7[0])

        # tanh then clip to [0, 1] (the original's min=-0.0 equals 0.0).
        o = torch.clip(torch.tanh(self.conv_out(d7)), min=0.0, max=1)

        states_e = [states_e1, states_e2, states_e3, states_e4, states_e5, states_e6]
        states_d = [states1, states2, states3, states4, states5, states6, states7]

        return o, (states_e, states_d)

    def forward(self, tensor):
        """Run the whole (b, c, h, w, t) sequence; returns (output, last states)."""
        states_encoder = (None, None, None, None, None, None, None)
        states_decoder = (None, None, None, None, None, None, None)
        # BUG FIX: the output buffer must use the *output* channel count.
        # torch.empty_like(tensor) sized it with the input channels, which
        # raised a shape-mismatch whenever outputChannels != inputChannels
        # (e.g. the default inputChannels=4 / outputChannels=3 pairing).
        b, _, h, w, t = tensor.shape
        output = tensor.new_empty((b, self.conv_out.out_channels, h, w, t))
        for timeStep in range(t):
            output[:, :, :, :, timeStep], states = self.forward_step(tensor[:, :, :, :, timeStep], states_encoder, states_decoder)
            states_encoder, states_decoder = states[0], states[1]
        return output, states
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def normal_init(m, mean, std):
    """Initialise conv weights in-place from N(mean, std) and zero biases.

    Modules other than Conv2d / ConvTranspose2d are left untouched.
    Uses torch.nn.init (the supported API) instead of mutating ``.data``,
    and tolerates convolutions built with ``bias=False``.
    """
    if isinstance(m, (nn.ConvTranspose2d, nn.Conv2d)):
        nn.init.normal_(m.weight, mean, std)
        if m.bias is not None:  # bias can be disabled on conv layers
            nn.init.zeros_(m.bias)
|
| 113 |
+
|
| 114 |
+
if __name__ == '__main__':
    # Smoke test for the generator.
    # batch_size = number of 3D patches
    # num_channels = BGR+NIR
    # h, w = spatial resolution
    # (The original also built unused states_encoder/states_decoder tuples
    # here; forward() creates its own, so the dead locals were removed.)
    x = torch.zeros((2, 4, 128, 128, 10), dtype=torch.float32)
    model = Generator('cpu', inputChannels=4)
    y, states = model(x)
    print(y.size())
|
GANFilling/src/test.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import torch
|
| 4 |
+
import csv
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
import argparse
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from PIL import Image
|
| 10 |
+
|
| 11 |
+
import matplotlib as mpl
|
| 12 |
+
import matplotlib.pyplot as plt
|
| 13 |
+
import matplotlib.patches as mpatches
|
| 14 |
+
|
| 15 |
+
sys.path.append('src/models')
|
| 16 |
+
from generator import *
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def main(config):
    """Gap-fill every .npz sample under config.data_path and save, per
    sample, a 3-row figure (landcover + cloudy context / gap-filled RGB /
    NDVI) into config.results_path.
    """
    import datetime

    gap_filling_model_path = config.model_path
    results_path = Path(config.results_path)
    data_path = Path(config.data_path)
    list_samples = [x for x in os.listdir(config.data_path) if x.endswith('.npz')]

    # landcover class id -> [R, G, B] strings from the CSV palette.
    landcover_cmap = {}
    with open('data/landcover_types.csv', 'r') as fd:
        reader = csv.reader(fd)
        for row in reader:
            landcover_cmap[row[0]] = [row[1], row[2], row[3]]

    # PERF FIX: load the generator once; the original re-read the
    # checkpoint from disk for every sample inside the loop.
    model = __load_gapFill_model(gap_filling_model_path)

    for sample in list_samples:

        sample_lc = data_path / "landcover" / sample.replace("context_", "")

        landcover = np.load(sample_lc)['landcover'][0]
        # First 4 channels, first 10 time steps of the context cube.
        context = np.load(data_path / sample)['highresdynamic'][:, :, :4, :10]

        gap_filled_sample = __gapFill_context(context, model)

        # NDVI = (NIR - red) / (NIR + red); assumes channel order is
        # BGR+NIR (last channel NIR, index 2 red) -- TODO confirm.
        n = gap_filled_sample[0, -1]
        r = gap_filled_sample[0, 2]
        ndvi = (n - r) / (n + r)

        fig, axs = plt.subplots(3, 11, figsize=(15, 4))

        # Paint each landcover class with its palette colour.
        landcover_rgb = np.empty((128, 128, 3))
        for c in np.unique(landcover):
            for ch in range(3):
                landcover_rgb[landcover == c, ch] = landcover_cmap[str(c)][ch]

        landcover_rgb = Image.fromarray(landcover_rgb.astype('uint8'), 'RGB')
        axs[0, 0].imshow(landcover_rgb)
        axs[0, 0].axis('off')
        axs[1, 0].axis('off')
        axs[2, 0].axis('off')

        # Acquisition start date is encoded in the sample file name.
        date = datetime.datetime.strptime(sample.split('_')[2], '%Y-%m-%d')

        # Row 0: cloudy context frames, brightened x3, BGR -> RGB.
        for i in range(10):
            img_rgb = np.clip(context[:, :, :3, i] * 3, 0, 1)[:, :, [2, 1, 0]]
            img_rgb = (img_rgb * 255).astype('uint8')
            img_rgb = Image.fromarray(img_rgb)
            axs[0, i + 1].imshow(img_rgb)
            axs[0, i + 1].text(94, 122, (date + datetime.timedelta(days=5 * i)).strftime('%j'), color='black', fontsize=10, bbox=dict(boxstyle='square,pad=0.1', fc='white', ec='none'))
            axs[0, i + 1].axis('off')

        # Row 1: gap-filled frames.
        for i in range(10):
            img_rgb = np.transpose(gap_filled_sample.cpu().detach().numpy()[0, [2, 1, 0], :, :, i], (1, 2, 0))
            img_rgb = np.clip(img_rgb * 3, 0, 1)
            img_rgb = (img_rgb * 255).astype('uint8')
            img_rgb = Image.fromarray(img_rgb)
            axs[1, i + 1].imshow(img_rgb)
            axs[1, i + 1].axis('off')

        cmap = mpl.colormaps.get_cmap('jet')
        cmap.set_bad(color='black')

        # Row 2: NDVI maps, shared [0, 1] scale.
        for i in range(10):
            im = axs[2, i + 1].imshow(ndvi[:, :, i].cpu().detach().numpy(), cmap=cmap, vmin=0., vmax=1.)
            axs[2, i + 1].axis('off')

        cbar_ax = fig.add_axes([0.2, 0.95, 0.7, 0.03])
        cbar = fig.colorbar(im, cax=cbar_ax, orientation='horizontal')
        cbar.outline.set_linewidth(0)

        plt.subplots_adjust(wspace=0., hspace=0.05)

        # NOTE(review): this legend is built but never attached to the
        # figure; kept for parity (re-reading the CSV is its only effect).
        landcover_cmap_legend = {}
        with open('data/landcover_types.csv', 'r') as fd:
            reader = csv.reader(fd)
            for row in reader:
                landcover_cmap_legend[row[5]] = [int(row[1]) / 255, int(row[2]) / 255, int(row[3]) / 255, 1]

        landcover_cmap_legend.pop('No data')
        legend_items = [mpatches.Patch(color=color, label=lc) for lc, color in landcover_cmap_legend.items()]

        path = results_path / sample
        plt.savefig(f"{path}.png", bbox_inches='tight')
        plt.close()
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
## GAP FILLING FUNCTIONS
|
| 110 |
+
|
| 111 |
+
def __load_gapFill_model(gapFilling_model):
    """Restore a trained Generator checkpoint on CPU in eval mode.

    The checkpoint dict carries both the weights ('model') and the
    training-time argparse namespace ('configuration').
    """
    checkpoint = torch.load(gapFilling_model, map_location='cpu')
    train_config = checkpoint['configuration']
    model = Generator('cpu', train_config.inputChannels, train_config.outputChannels)
    model.load_state_dict(checkpoint['model'])
    model.eval()
    return model
|
| 118 |
+
|
| 119 |
+
def __prepare_data(context):
    """Clamp reflectances into [0, 1], replace NaNs with 1.0, and move the
    channel axis first: (h, w, c, t) -> (c, h, w, t)."""
    cleaned = np.nan_to_num(np.clip(context, 0, 1), nan=1.0)
    return np.transpose(cleaned, (2, 0, 1, 3))
|
| 123 |
+
|
| 124 |
+
def __gapFill_context(context, model):
    """Run the generator on one (h, w, c, t) context cube and return the
    gap-filled (1, c, h, w, t) tensor."""
    prepared = __prepare_data(context)
    batch = torch.from_numpy(prepared).unsqueeze(dim=0).float()
    filled, _ = model(batch)
    return filled
|
| 129 |
+
|
| 130 |
+
if __name__ == '__main__':
    cli = argparse.ArgumentParser(description="ArgParse")
    cli.add_argument('--model_path', type=str)
    cli.add_argument('--data_path', type=str)
    cli.add_argument('--results_path', type=str)

    main(cli.parse_args())
|
GANFilling/src/train.py
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import argparse
|
| 3 |
+
import torch
|
| 4 |
+
from torch.utils.data import DataLoader
|
| 5 |
+
import logging
|
| 6 |
+
import numpy as np
|
| 7 |
+
import random
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
from torch.autograd import Variable
|
| 12 |
+
from torchvision import transforms
|
| 13 |
+
from skimage.metrics import structural_similarity as ssim
|
| 14 |
+
from skimage.metrics import peak_signal_noise_ratio as psnr
|
| 15 |
+
|
| 16 |
+
from iterator import Iterator
|
| 17 |
+
from models.generator import Generator
|
| 18 |
+
from models.discriminator import Discriminator
|
| 19 |
+
|
| 20 |
+
class Trainer():
    """Adversarial trainer for the gap-filling Generator/Discriminator pair.

    Lifecycle: Trainer(cfg) -> setData() -> setModels() -> train().
    """

    def __init__(self, configuration):
        self.configuration = configuration
        self.__createDirectories()
        self.__initCriterions()
        self.__setLogger()
        self.__logParser()
        self.__resetAccumulatedMetrics()

        # Per-step loss history; stored as plain floats (see train()).
        self.generatorHistLoss = []
        self.generatorHistLoss_adv = []
        self.generatorHistLoss_rec = []
        self.discriminatorHistLoss = []

        self.metricSSIM = []
        self.metricPSNR = []

        # Fixed seeds for torch / python / numpy RNGs (reproducibility).
        torch.manual_seed(100)
        random.seed(100)
        np.random.seed(100)

    def __logParser(self):
        """Log every CLI argument once at start-up."""
        for arg, value in vars(self.configuration).items():
            logging.info("Argument %s: %r", arg, value)

    def __createDirectories(self):
        """Create <modelsPath>/<name>/Images/ for checkpoints and previews."""
        self.modelPath = Path(self.configuration.modelsPath + self.configuration.name)
        Path.mkdir(self.modelPath / 'Images/', parents=True, exist_ok=True)

    def __initCriterions(self):
        # reduction='none' keeps per-element losses; means are taken later.
        self.discriminatorCriterion = nn.BCELoss(reduction='none')
        self.generatorCriterion = [nn.BCELoss(reduction='none'), nn.L1Loss(reduction='none')]

    def train(self):
        """Alternate discriminator/generator updates over the training set."""
        logging.info('Starting training')
        self.step = 0
        self.mode = 'train'
        for self.epoch in range(self.configuration.maxEpochs):
            for noisyTensor, cleanTensor, masksTensor in self.trainDataLoader:

                noisyTensor = noisyTensor.float().to(self.device)
                cleanTensor = cleanTensor.float().to(self.device)
                masksTensor = masksTensor.float().to(self.device)

                ### Update Discriminator
                self.__setRequiresGrad(self.discriminator, True)
                self.discriminator.zero_grad()

                # Real (clean) patches should be classified as 1.
                outputDiscriminator, _ = self.discriminator(cleanTensor)
                realLossDiscriminator = self.discriminatorCriterion(outputDiscriminator, torch.ones_like(outputDiscriminator))

                # Generated patches should be classified as 0.
                outputGenerator, _ = self.generator(noisyTensor)
                outputGenerator_patch, _ = self.__divideInPatches(outputGenerator, cleanTensor.shape[4])
                outputDiscriminator, _ = self.discriminator(outputGenerator_patch)
                fakeLossDiscriminator = self.discriminatorCriterion(outputDiscriminator, torch.zeros_like(outputDiscriminator))

                trainLossDiscriminator = (torch.mean(realLossDiscriminator) + torch.mean(fakeLossDiscriminator)) * 0.5
                trainLossDiscriminator.backward()
                self.discriminatorOptimizer.step()
                self.discriminator.zero_grad()
                # BUG FIX: append .item(), not the live tensor -- appending
                # the tensor kept every step's autograd graph alive (leak).
                self.discriminatorHistLoss.append(trainLossDiscriminator.item())

                ### Update Generator
                self.__setRequiresGrad(self.discriminator, False)
                self.generator.zero_grad()

                outputGenerator, _ = self.generator(noisyTensor)
                outputGenerator_patch, startTime = self.__divideInPatches(outputGenerator, 4)
                outputDiscriminator, _ = self.discriminator(outputGenerator_patch)

                # Generator wants its patches classified as real (1).
                generatorAdversarialLoss = torch.mean(self.generatorCriterion[0](outputDiscriminator, torch.ones_like(outputDiscriminator)))
                # L1 reconstruction only on observed (non-masked) pixels.
                generatorReconstructionLoss = self.generatorCriterion[1](outputGenerator, noisyTensor) * (1 - masksTensor).unsqueeze(dim=1)
                generatorReconstructionLoss = self.configuration.lambdaL1 * torch.mean(generatorReconstructionLoss)

                trainLossGenerator = torch.mean(generatorAdversarialLoss) + generatorReconstructionLoss
                trainLossGenerator.backward()
                self.generatorOptimizer.step()
                self.generator.zero_grad()
                self.generatorHistLoss.append(trainLossGenerator.item())
                self.generatorHistLoss_adv.append(generatorAdversarialLoss.item())
                self.generatorHistLoss_rec.append(generatorReconstructionLoss.item())
                self.__evaluate(outputGenerator, noisyTensor, masksTensor)
                self.__logMetrics()
                self.__plotGeneratorOutput(outputGenerator, noisyTensor, self.modelPath)

                if self.step % int(self.configuration.validateEvery) == 0:
                    self.__saveModelFreq()

    def __saveModelFreq(self):
        """Checkpoint the generator tagged with the current step."""
        self.__saveModel(self.generator, 'step_generator_' + str(self.step))
        logging.info('Saved model')

    def __divideInPatches(self, tensor, timeSteps, kernelSize=64):
        """Cut a random (kernelSize x kernelSize x timeSteps) crop out of
        the first sample of a (b, c, h, w, t) tensor.

        Returns the crop and the randomly chosen start time step.
        """
        new_tensor = torch.zeros((1, tensor.shape[1], kernelSize, kernelSize, timeSteps)).float().to(self.device)
        startTime = random.sample([x + 1 for x in range(-1, tensor.shape[4] - timeSteps)], 1)[0]
        x = random.sample([x + 1 for x in range(-1, 64)], 1)[0]
        y = random.sample([x + 1 for x in range(-1, 64)], 1)[0]

        new_tensor[0] = tensor[0, :, x:x + kernelSize, y:y + kernelSize, startTime:startTime + timeSteps]

        return new_tensor, startTime

    def __unnormalize(self, tensor):
        """Map [-1, 1]-normalised tensors back to [0, 1]; no-op otherwise."""
        if self.configuration.normalize:
            return tensor / 2 + .5
        else:
            return tensor

    def __setRequiresGrad(self, net, requires_grad=False):
        """Freeze/unfreeze every parameter of *net*."""
        for param in net.parameters():
            param.requires_grad = requires_grad

    def __evaluate(self, outputTensor, noisyTensor, masksTensor):
        """Accumulate SSIM/PSNR over the observed (non-masked) pixels.

        NOTE(review): self.metricSSIM/metricPSNR start as lists and are
        rebound to batch-mean floats below; __resetMetrics() (called from
        __logMetrics) restores the lists every step.
        """
        outputImages = outputTensor.detach().permute(0, 2, 3, 1, 4).cpu().numpy()
        noisyImages = noisyTensor.detach().permute(0, 2, 3, 1, 4).cpu().numpy()
        maskImages = masksTensor.detach().cpu().numpy()

        for imgBatch in range(outputTensor.shape[0]):
            maskImg = np.expand_dims((1 - maskImages[imgBatch]), 2)
            outputImg = outputImages[imgBatch] * maskImg
            noisyImg = noisyImages[imgBatch] * maskImg

            ssim_aux = []
            psnr_aux = []
            for timeStep in range(noisyImg.shape[3]):
                # multichannel=True is deprecated in newer scikit-image
                # (channel_axis=-1); kept for the pinned version.
                metric_1 = ssim(outputImg[:, :, :, timeStep], noisyImg[:, :, :, timeStep], multichannel=True)
                metric_2 = psnr(outputImg[:, :, :, timeStep], noisyImg[:, :, :, timeStep])
                if np.isfinite(metric_1):
                    ssim_aux.append(metric_1)
                if np.isfinite(metric_2):
                    psnr_aux.append(metric_2)

            self.metricSSIM.append(sum(ssim_aux) / len(ssim_aux))
            self.metricPSNR.append(sum(psnr_aux) / len(psnr_aux))
        self.metricSSIM = sum(self.metricSSIM) / self.configuration.batchSize
        self.metricPSNR = sum(self.metricPSNR) / self.configuration.batchSize
        self.accumulatedMetricSSIM.append(self.metricSSIM)
        self.accumulatedMetricPSNR.append(self.metricPSNR)

    def __plotGeneratorOutput(self, outputGenerator, noisyImg, path):
        """Every plotEvery steps, dump generated and noisy frames as PNGs."""
        if self.step % int(self.configuration.plotEvery) == 0:
            for sampleNumber in range(outputGenerator.shape[0]):
                for timeStep in range(outputGenerator.shape[4]):
                    self.__saveImage(self.__unnormalize(outputGenerator[sampleNumber, :, :, :, timeStep])[[2, 1, 0], :, :], sampleNumber, path, '_' + str(timeStep))
                    self.__saveImage(self.__unnormalize(noisyImg[sampleNumber, :, :, :, timeStep])[[2, 1, 0], :, :], sampleNumber, path, 'noisy_' + str(timeStep))

    def __saveImage(self, tensor, sampleNumber, path, label):
        """Brighten x3, clamp, convert to PIL RGB, and save under Images/."""
        tensor = torch.clamp(3 * tensor, -1, 1)
        img = transforms.ToPILImage()(tensor).convert("RGB")
        img.save(str(path / 'Images') + '/e' + str(self.epoch) + '_s' + str(self.step) + '_img' + str(sampleNumber) + '_' + label + '.png')

    def __logMetrics(self):
        """Every printEvery steps log rolling means; always advance step."""
        if self.step % int(self.configuration.printEvery) == 0:
            logging.info('Epoch {epoch} - Step {step} ---> GLoss: {gloss} - adv: {gloss_adv} - rec: {gloss_rec}, DLoss: {dloss}, SSIM: {ssim} PSNR: {psnr}'.format(\
            epoch=self.epoch, step=self.step, gloss=torch.mean(torch.FloatTensor(self.generatorHistLoss[-self.configuration.printEvery:])), gloss_adv=torch.mean(torch.FloatTensor(self.generatorHistLoss_adv[-self.configuration.printEvery:])), gloss_rec=torch.mean(torch.FloatTensor(self.generatorHistLoss_rec[-self.configuration.printEvery:])),\
            dloss=torch.mean(torch.FloatTensor(self.discriminatorHistLoss[-self.configuration.printEvery:])),\
            ssim=torch.mean(torch.FloatTensor(self.accumulatedMetricSSIM)), psnr=torch.mean(torch.FloatTensor(self.accumulatedMetricPSNR))))
            self.__resetAccumulatedMetrics()
        self.step += 1
        self.__resetMetrics()

    def __resetMetrics(self):
        self.metricSSIM = []
        self.metricPSNR = []

    def __resetAccumulatedMetrics(self):
        self.accumulatedMetricSSIM = []
        self.accumulatedMetricPSNR = []

    def __saveModel(self, model, tag):
        """Save weights together with the training configuration."""
        modelDictionary = {'model': model.state_dict(), 'configuration': self.configuration}
        torch.save(modelDictionary, str(self.modelPath / tag) + '.pt')

    def setData(self):
        """Build train/valid iterators and loaders.

        BUG FIX: reads self.configuration instead of the module-level
        'configuration' global, which only existed when run as a script.
        """
        self.trainIterator = Iterator(Path(self.configuration.dataPath), self.configuration.cleanDataPath, 'train')
        self.trainDataLoader = DataLoader(self.trainIterator, batch_size=self.configuration.batchSize, shuffle=True)

        self.validIterator = Iterator(Path(self.configuration.dataPath), self.configuration.cleanDataPath, 'valid')
        self.validDataLoader = DataLoader(self.validIterator, batch_size=1)

    def setModels(self):
        """Pick the device, then build generator and discriminator on it."""
        self.__setDevice()
        self.__setGenerator()
        self.__setDiscriminator()

    def __setGenerator(self):
        self.generator = Generator(self.device, self.configuration.inputChannels, self.configuration.outputChannels)
        self.generator.weight_init(mean=0.0, std=0.02)
        self.generator.to(self.device)
        self.generatorOptimizer = torch.optim.Adam(self.generator.parameters(), lr=self.configuration.lrG, betas=(0.5, 0.999))

    def __setDiscriminator(self):
        self.discriminator = Discriminator(self.device, self.configuration.inputChannels)
        self.discriminator.weight_init(mean=0.0, std=0.02)
        self.discriminator.to(self.device)
        self.discriminatorOptimizer = torch.optim.Adam(self.discriminator.parameters(), lr=self.configuration.lrD, betas=(0.5, 0.999))

    def __setDevice(self):
        print('TORCH AVAILABLE: ', torch.cuda.is_available())
        if torch.cuda.is_available():
            self.device = 'cuda:0'
        else:
            self.device = 'cpu'

    def __setLogger(self):
        """Write an INFO-level log file named after the model."""
        Path.mkdir(Path(self.configuration.logsPath), exist_ok=True)
        logging.basicConfig(filename=self.configuration.logsPath + self.configuration.name + '.log', format='%(asctime)s - %(message)s', datefmt='%d-%b-%y %H:%M:%S', level=logging.INFO, filemode='w')
|
| 229 |
+
|
| 230 |
+
def main(configuration):
    """Wire up data and models, then run the training loop."""
    gan_trainer = Trainer(configuration)
    gan_trainer.setData()
    gan_trainer.setModels()
    gan_trainer.train()
|
| 235 |
+
|
| 236 |
+
if __name__=='__main__':
    parser = argparse.ArgumentParser()
    # I/O locations.
    parser.add_argument('--dataPath', type=str, default="train_data/", help='Path to the .npz files')
    parser.add_argument('--cleanDataPath', type=str, default="utils/example_clean_data.json", help='file with the noiseless data paths for the discriminator')
    parser.add_argument('--logsPath', type=str, default='logs/')
    parser.add_argument('--modelsPath', type=str, default='trained_models/', help='Path to save the trained models')
    parser.add_argument('--name', type=str, help='Model name')

    # Training schedule: epoch count, batch size, and step intervals for
    # logging, preview plotting, and checkpointing.
    parser.add_argument('--maxEpochs', type=int, default=1)
    parser.add_argument('--batchSize', type=int, default=1)
    parser.add_argument('--printEvery',type=int, default=1)
    parser.add_argument('--plotEvery', type=int, default=5)
    parser.add_argument('--validateEvery', type=int, default=5)

    # Model / optimisation hyper-parameters.
    parser.add_argument('--inputChannels', type=int, default=4)
    parser.add_argument('--outputChannels', type=int, default=4)
    parser.add_argument('--lrD', type=float, default=0.0002, help='learning rate, default=0.0002')
    parser.add_argument('--lrG', type=float, default=0.0002, help='learning rate, default=0.0002')
    parser.add_argument('--lambdaL1', type=float, default=100, help='lambda for L1 loss')
    parser.add_argument('--normalize', action='store_true')

    configuration = parser.parse_args()
    main(configuration)
|
GANFilling/src/utils/example_clean_data.json
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"33U": [
|
| 3 |
+
{
|
| 4 |
+
"path": "data/train_data/33UXQ/33UXQ_2018-06-18_2018-11-14_2105_2233_3769_3897_32_112_58_138.npz",
|
| 5 |
+
"kernel size": 64,
|
| 6 |
+
"bbox": [
|
| 7 |
+
11,
|
| 8 |
+
60
|
| 9 |
+
],
|
| 10 |
+
"time steps": [
|
| 11 |
+
7,
|
| 12 |
+
8,
|
| 13 |
+
9
|
| 14 |
+
]
|
| 15 |
+
},
|
| 16 |
+
{
|
| 17 |
+
"path": "data/train_data/33UXQ/33UXQ_2018-07-28_2018-12-24_4025_4153_2105_2233_62_142_32_112.npz",
|
| 18 |
+
"kernel size": 64,
|
| 19 |
+
"bbox": [
|
| 20 |
+
0,
|
| 21 |
+
44
|
| 22 |
+
],
|
| 23 |
+
"time steps": [
|
| 24 |
+
6,
|
| 25 |
+
7,
|
| 26 |
+
8,
|
| 27 |
+
9
|
| 28 |
+
]
|
| 29 |
+
},
|
| 30 |
+
{
|
| 31 |
+
"path": "data/train_data/33UXQ/33UXQ_2018-07-08_2018-12-04_2617_2745_3513_3641_40_120_54_134.npz",
|
| 32 |
+
"kernel size": 64,
|
| 33 |
+
"bbox": [
|
| 34 |
+
0,
|
| 35 |
+
29
|
| 36 |
+
],
|
| 37 |
+
"time steps": [
|
| 38 |
+
3,
|
| 39 |
+
4,
|
| 40 |
+
5
|
| 41 |
+
]
|
| 42 |
+
},
|
| 43 |
+
{
|
| 44 |
+
"path": "data/train_data/33UXQ/33UXQ_2018-07-18_2018-12-14_825_953_4537_4665_12_92_70_150.npz",
|
| 45 |
+
"kernel size": 64,
|
| 46 |
+
"bbox": [
|
| 47 |
+
0,
|
| 48 |
+
35
|
| 49 |
+
],
|
| 50 |
+
"time steps": [
|
| 51 |
+
1,
|
| 52 |
+
2,
|
| 53 |
+
3
|
| 54 |
+
]
|
| 55 |
+
},
|
| 56 |
+
{
|
| 57 |
+
"path": "data/train_data/33UXQ/33UXQ_2018-06-18_2018-11-14_2745_2873_3897_4025_42_122_60_140.npz",
|
| 58 |
+
"kernel size": 64,
|
| 59 |
+
"bbox": [
|
| 60 |
+
0,
|
| 61 |
+
0
|
| 62 |
+
],
|
| 63 |
+
"time steps": [
|
| 64 |
+
7,
|
| 65 |
+
8,
|
| 66 |
+
9
|
| 67 |
+
]
|
| 68 |
+
},
|
| 69 |
+
{
|
| 70 |
+
"path": "data/train_data/33UXQ/33UXQ_2018-07-08_2018-12-04_697_825_4793_4921_10_90_74_154.npz",
|
| 71 |
+
"kernel size": 64,
|
| 72 |
+
"bbox": [
|
| 73 |
+
0,
|
| 74 |
+
49
|
| 75 |
+
],
|
| 76 |
+
"time steps": [
|
| 77 |
+
3,
|
| 78 |
+
4,
|
| 79 |
+
5
|
| 80 |
+
]
|
| 81 |
+
}
|
| 82 |
+
]
|
| 83 |
+
}
|
GANFilling/src/utils/generate_cleanData_file.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
from itertools import groupby
|
| 8 |
+
from operator import itemgetter
|
| 9 |
+
|
| 10 |
+
import matplotlib.pyplot as plt
|
| 11 |
+
import json
|
| 12 |
+
|
| 13 |
+
# Scan every Sentinel-2 data cube under `folderName` and record, per cube, one
# fully cloud-free spatial window of size kernelSize x kernelSize that stays
# clean for at least `minRunLength` consecutive time steps.  The records are
# written to cleanData.json in the format consumed as --cleanDataPath by the
# training script (path, kernel size, bbox, time steps).
folderName = '/train_data/context/'
split = 'IID'
data = {split: []}
# Kernel sizes to try, largest first; with 128/127 only size 128 is tested.
maxSizeKernel = 128
minSizeKernel = 127
minRunLength = 3  # minimum number of consecutive clean time steps required

for tileName in os.listdir(folderName):
    # The dataset folder ships a LICENSE file alongside the tile directories.
    if tileName == 'LICENSE':
        continue
    print(tileName)
    for dataCube in os.listdir(folderName + '/' + tileName):
        cubePath = folderName + '/' + tileName + '/' + dataCube
        print('Evaluating data cube ', cubePath)
        sample = np.load(cubePath)
        # Last channel of 'highresdynamic' is the cloud/noise mask; keep the
        # first 10 time steps -> (H, W, T).
        masks = sample['highresdynamic'][:, :, -1, :10]
        # Add batch and channel dims for conv2d -> (1, 1, H, W, T).
        masks = np.expand_dims(masks, axis=(0, 1))

        for kernelSize in range(maxSizeKernel, minSizeKernel, -1):
            # Convolving the binary mask with an all-ones kernel yields 0
            # exactly where a kernelSize x kernelSize window is fully clean.
            kernel = torch.ones((1, 1, kernelSize, kernelSize))
            # (row, col) of each clean window -> list of clean time steps.
            cleanSteps = {}
            for timeStep in range(masks.shape[4]):
                conv = F.conv2d(torch.from_numpy(masks[:, :, :, :, timeStep]).float(), kernel)
                for cleanKernel in (conv == 0).nonzero():
                    pos = (cleanKernel[2].item(), cleanKernel[3].item())
                    cleanSteps.setdefault(pos, []).append(timeStep)

            # Keep the single longest run of consecutive clean time steps.
            # (Bug fix: the original ratcheted the minimum length upward and
            # appended every qualifying run, but then always saved the FIRST
            # one; we now save the longest, as the ratcheting intended.)
            bestBbox = None
            bestRun = []
            for pos, steps in cleanSteps.items():
                # Group consecutive integers: index - value is constant on a
                # run of consecutive time steps.
                for _, g in groupby(enumerate(steps), lambda ix: ix[0] - ix[1]):
                    run = list(map(itemgetter(1), g))
                    if len(run) >= minRunLength and len(run) > len(bestRun):
                        bestBbox = pos
                        bestRun = run

            if bestBbox is not None:
                print('Found valid kernel ', kernelSize, ' with ', len(bestRun), ' consecutive time steps')
                data[split].append({
                    'path': str(cubePath),
                    'kernel size': kernelSize,
                    'bbox': bestBbox,
                    'time steps': bestRun,
                })
                break  # one record per cube; stop trying smaller kernels

print('Saving data ...')
with open('cleanData.json', 'w') as outfile:
    json.dump(data, outfile)
|
| 83 |
+
|
| 84 |
+
|