Spaces:
Runtime error
Runtime error
hfspace gradio demo
Browse files- LICENSE +201 -0
- README.md +2 -13
- app.py +119 -0
- demo_examples/baby.png +0 -0
- demo_examples/bird.png +0 -0
- demo_examples/butterfly.png +0 -0
- demo_examples/head.png +0 -0
- demo_examples/woman.png +0 -0
- ds.py +485 -0
- losses.py +131 -0
- networks_SRGAN.py +347 -0
- networks_T1toT2.py +477 -0
- requirements.txt +334 -0
- src/.gitkeep +0 -0
- src/__pycache__/ds.cpython-310.pyc +0 -0
- src/__pycache__/losses.cpython-310.pyc +0 -0
- src/__pycache__/networks_SRGAN.cpython-310.pyc +0 -0
- src/__pycache__/utils.cpython-310.pyc +0 -0
- src/app.py +115 -0
- src/ds.py +485 -0
- src/flagged/Alpha/0.png +0 -0
- src/flagged/Beta/0.png +0 -0
- src/flagged/Low-res/0.png +0 -0
- src/flagged/Orignal/0.png +0 -0
- src/flagged/Super-res/0.png +0 -0
- src/flagged/Uncertainty/0.png +0 -0
- src/flagged/log.csv +2 -0
- src/losses.py +131 -0
- src/networks_SRGAN.py +347 -0
- src/networks_T1toT2.py +477 -0
- src/utils.py +1273 -0
- utils.py +1304 -0
LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
README.md
CHANGED
|
@@ -1,13 +1,2 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
emoji: 🦀
|
| 4 |
-
colorFrom: purple
|
| 5 |
-
colorTo: green
|
| 6 |
-
sdk: gradio
|
| 7 |
-
sdk_version: 3.0.24
|
| 8 |
-
app_file: app.py
|
| 9 |
-
pinned: false
|
| 10 |
-
license: mit
|
| 11 |
-
---
|
| 12 |
-
|
| 13 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 1 |
+
# BayesCap
|
| 2 |
+
Bayesian Identity Cap for Calibrated Uncertainty in Pretrained Neural Networks
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np
import random
import matplotlib.pyplot as plt
from matplotlib import cm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode as IMode

from PIL import Image

from ds import *
from losses import *
from networks_SRGAN import *
from utils import *

device = 'cuda'

# Build the SRGAN generator and report its total parameter count.
NetG = Generator()
params = sum(np.prod(p.size()) for p in NetG.parameters())
print("Number of Parameters:", params)

# BayesCap head that predicts per-pixel uncertainty parameters for the SR output.
NetC = BayesCap(in_channels=3, out_channels=3)

# Download (if needed) and load the pretrained weights, then freeze for inference.
ensure_checkpoint_exists('BayesCap_SRGAN.pth')
NetG.load_state_dict(torch.load('BayesCap_SRGAN.pth', map_location=device))
NetG.to(device)
NetG.eval()

ensure_checkpoint_exists('BayesCap_ckpt.pth')
NetC.load_state_dict(torch.load('BayesCap_ckpt.pth', map_location=device))
NetC.to(device)
NetC.eval()
| 40 |
+
def tensor01_to_pil(xt):
    """Convert a [0, 1]-valued image tensor to a PIL RGB image.

    Singleton dimensions are squeezed away before conversion.
    """
    to_pil = transforms.ToPILImage(mode='RGB')
    return to_pil(xt.squeeze())
| 45 |
+
def predict(img):
    """Run BayesCap super-resolution with uncertainty on one input image.

    The input is first bicubic-downsampled to 64x64 (256/4), then passed
    through the SRGAN generator and the BayesCap head.

    Args:
        img: input image (PIL image or array-like convertible to one).

    Returns:
        Tuple of five PIL images:
        (low-res input, super-res output, alpha map, beta map, uncertainty map).
    """
    image_size = (256, 256)
    upscale_factor = 4
    lr_transforms = transforms.Resize(
        (image_size[0] // upscale_factor, image_size[1] // upscale_factor),
        interpolation=IMode.BICUBIC, antialias=True)

    img = Image.fromarray(np.array(img))
    img = lr_transforms(img)
    # FIX: call image2tensor directly (in scope via `from utils import *`).
    # The original `utils.image2tensor` only resolved because `from ds import *`
    # happened to re-export the `utils` module.
    lr_tensor = image2tensor(img, range_norm=False, half=False)

    # FIX: fall back to CPU when CUDA is unavailable; the hard-coded
    # `torch.cuda.FloatTensor` raised a runtime error on CPU-only hosts.
    dev = 'cuda' if torch.cuda.is_available() else 'cpu'
    xLR = lr_tensor.to(dev).unsqueeze(0).float()
    # pass them through the network
    with torch.no_grad():
        xSR = NetG(xLR)
        xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)

    # Per-pixel scale/shape maps; uncertainty follows the generalized-Gaussian
    # variance form a^2 * Gamma(3/b) / Gamma(1/b). Small epsilons avoid
    # division by zero.
    a_map = (1 / (xSRC_alpha[0] + 1e-5)).to('cpu').data
    b_map = xSRC_beta[0].to('cpu').data
    u_map = (a_map ** 2) * (
        torch.exp(torch.lgamma(3 / (b_map + 1e-2)))
        / torch.exp(torch.lgamma(1 / (b_map + 1e-2)))
    )

    # CHW -> HWC before converting to PIL.
    x_LR = tensor01_to_pil(xLR.to('cpu').data.clip(0, 1).transpose(0, 2).transpose(0, 1))
    x_mean = tensor01_to_pil(xSR.to('cpu').data.clip(0, 1).transpose(0, 2).transpose(0, 1))

    # Clamp then min-max normalize each map before colorizing with matplotlib.
    a_map = torch.clamp(a_map, min=0, max=0.1)
    a_map = (a_map - a_map.min()) / (a_map.max() - a_map.min())
    x_alpha = Image.fromarray(np.uint8(cm.inferno(a_map.transpose(0, 2).transpose(0, 1).squeeze()) * 255))

    b_map = torch.clamp(b_map, min=0.45, max=0.75)
    b_map = (b_map - b_map.min()) / (b_map.max() - b_map.min())
    x_beta = Image.fromarray(np.uint8(cm.cividis(b_map.transpose(0, 2).transpose(0, 1).squeeze()) * 255))

    u_map = torch.clamp(u_map, min=0, max=0.15)
    u_map = (u_map - u_map.min()) / (u_map.max() - u_map.min())
    x_uncer = Image.fromarray(np.uint8(cm.hot(u_map.transpose(0, 2).transpose(0, 1).squeeze()) * 255))

    return x_LR, x_mean, x_alpha, x_beta, x_uncer
import gradio as gr

title = "BayesCap"
description = "BayesCap: Bayesian Identity Cap for Calibrated Uncertainty in Frozen Neural Networks (ECCV 2022)"
article = "<p style='text-align: center'> BayesCap: Bayesian Identity Cap for Calibrated Uncertainty in Frozen Neural Networks| <a href='https://github.com/ExplainableML/BayesCap'>Github Repo</a></p>"

# One output pane per stage: downsampled input, SR prediction, and the
# three per-pixel uncertainty visualizations produced by `predict`.
output_panes = [
    gr.outputs.Image(type='pil', label="Low-res"),
    gr.outputs.Image(type='pil', label="Super-res"),
    gr.outputs.Image(type='pil', label="Alpha"),
    gr.outputs.Image(type='pil', label="Beta"),
    gr.outputs.Image(type='pil', label="Uncertainty"),
]

demo = gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(type='pil', label="Orignal"),
    outputs=output_panes,
    title=title,
    description=description,
    article=article,
    examples=[
        ["./demo_examples/baby.png"],
        ["./demo_examples/bird.png"],
        ["./demo_examples/butterfly.png"],
        ["./demo_examples/head.png"],
        ["./demo_examples/woman.png"],
    ],
)
demo.launch(share=True)
demo_examples/baby.png
ADDED
|
demo_examples/bird.png
ADDED
|
demo_examples/butterfly.png
ADDED
|
demo_examples/head.png
ADDED
|
demo_examples/woman.png
ADDED
|
ds.py
ADDED
|
@@ -0,0 +1,485 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import absolute_import, division, print_function
|
| 2 |
+
|
| 3 |
+
import random
|
| 4 |
+
import copy
|
| 5 |
+
import io
|
| 6 |
+
import os
|
| 7 |
+
import numpy as np
|
| 8 |
+
from PIL import Image
|
| 9 |
+
import skimage.transform
|
| 10 |
+
from collections import Counter
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.utils.data as data
|
| 15 |
+
from torch import Tensor
|
| 16 |
+
from torch.utils.data import Dataset
|
| 17 |
+
from torchvision import transforms
|
| 18 |
+
from torchvision.transforms.functional import InterpolationMode as IMode
|
| 19 |
+
|
| 20 |
+
import utils
|
| 21 |
+
|
| 22 |
+
class ImgDset(Dataset):
    """Paired low/high-resolution image dataset for super-resolution.

    Each file in `dataroot` yields a (lr, hr) tensor pair: the HR image is
    augmented/resized to `image_size`, and the LR image is produced from it
    by bicubic downsampling with `upscale_factor`.

    Args:
        dataroot (str): directory containing the images.
        image_size: target HR size. Must be a (height, width) pair — the LR
            transform indexes it per-dimension (the original annotated it as
            `int`, which contradicted the code).
        upscale_factor (int): image magnification factor.
        mode (str): "train" applies random crop/rotation/flip augmentation;
            any other value just resizes.
    """

    def __init__(self, dataroot: str, image_size, upscale_factor: int, mode: str) -> None:
        super(ImgDset, self).__init__()
        self.filenames = [os.path.join(dataroot, x) for x in os.listdir(dataroot)]

        if mode == "train":
            # Data augmentation only for the training split.
            self.hr_transforms = transforms.Compose([
                transforms.RandomCrop(image_size),
                transforms.RandomRotation(90),
                transforms.RandomHorizontalFlip(0.5),
            ])
        else:
            self.hr_transforms = transforms.Resize(image_size)

        self.lr_transforms = transforms.Resize((image_size[0]//upscale_factor, image_size[1]//upscale_factor), interpolation=IMode.BICUBIC, antialias=True)

    # FIX: the original annotation `-> [Tensor, Tensor]` was a list literal,
    # not a valid type; use a (lazy) tuple annotation instead.
    def __getitem__(self, batch_index: int) -> "tuple[Tensor, Tensor]":
        # Read one image and derive the (lr, hr) pair from it.
        image = Image.open(self.filenames[batch_index])

        # Transform image
        hr_image = self.hr_transforms(image)
        lr_image = self.lr_transforms(hr_image)

        # Convert image data into Tensor stream format (PyTorch).
        # Note: The range of input and output is between [0, 1]
        lr_tensor = utils.image2tensor(lr_image, range_norm=False, half=False)
        hr_tensor = utils.image2tensor(hr_image, range_norm=False, half=False)

        return lr_tensor, hr_tensor

    def __len__(self) -> int:
        return len(self.filenames)
| 67 |
+
|
| 68 |
+
class PairedImages_w_nameList(Dataset):
    """Paired image dataset driven by two parallel filename lists.

    Can act as a supervised or unsupervised pair source depending on how the
    two lists are built.
    """

    def __init__(self, flist1, flist2, transform1=None, transform2=None, do_aug=False):
        self.flist1 = flist1
        self.flist2 = flist2
        self.transform1 = transform1
        self.transform2 = transform2
        self.do_aug = do_aug

    def __getitem__(self, index):
        # Load both sides as RGB and convert to [0, 1] tensors.
        sample_a = Image.open(self.flist1[index]).convert('RGB')
        sample_b = Image.open(self.flist2[index]).convert('RGB')

        sample_a = utils.image2tensor(sample_a, range_norm=False, half=False)
        sample_b = utils.image2tensor(sample_b, range_norm=False, half=False)

        # Optional per-side transforms run after tensor conversion.
        if self.transform1 is not None:
            sample_a = self.transform1(sample_a)
        if self.transform2 is not None:
            sample_b = self.transform2(sample_b)

        return sample_a, sample_b

    def __len__(self):
        return len(self.flist1)
| 95 |
+
|
| 96 |
+
class PairedImages_w_nameList_npy(Dataset):
    """Paired dataset over two parallel lists of .npy file paths.

    Can act as a supervised or unsupervised pair source depending on how the
    two lists are built.
    """

    def __init__(self, flist1, flist2, transform1=None, transform2=None, do_aug=False):
        self.flist1 = flist1
        self.flist2 = flist2
        self.transform1 = transform1
        self.transform2 = transform2
        self.do_aug = do_aug

    def __getitem__(self, index):
        # Load both arrays straight from disk.
        sample_a = np.load(self.flist1[index])
        sample_b = np.load(self.flist2[index])

        # Apply the optional per-side transforms.
        if self.transform1 is not None:
            sample_a = self.transform1(sample_a)
        if self.transform2 is not None:
            sample_b = self.transform2(sample_b)

        return sample_a, sample_b

    def __len__(self):
        return len(self.flist1)
| 120 |
+
|
| 121 |
+
# def call_paired():
|
| 122 |
+
# root1='./GOPRO_3840FPS_AVG_3-21/train/blur/'
|
| 123 |
+
# root2='./GOPRO_3840FPS_AVG_3-21/train/sharp/'
|
| 124 |
+
|
| 125 |
+
# flist1=glob.glob(root1+'/*/*.png')
|
| 126 |
+
# flist2=glob.glob(root2+'/*/*.png')
|
| 127 |
+
|
| 128 |
+
# dset = PairedImages_w_nameList(root1,root2,flist1,flist2)
|
| 129 |
+
|
| 130 |
+
#### KITTI depth
|
| 131 |
+
|
| 132 |
+
def load_velodyne_points(filename):
    """Load 3D point cloud from KITTI file format
    (adapted from https://github.com/hunse/kitti)
    """
    raw = np.fromfile(filename, dtype=np.float32)
    points = raw.reshape(-1, 4)
    # Overwrite the reflectance column with 1.0 so each row is a
    # homogeneous coordinate (x, y, z, 1).
    points[:, 3] = 1.0
    return points
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def read_calib_file(path):
    """Read KITTI calibration file into a dict
    (from https://github.com/hunse/kitti)
    """
    numeric_chars = set("0123456789.e+- ")
    data = {}
    with open(path, 'r') as fh:
        for raw_line in fh.readlines():
            key, value = raw_line.split(':', 1)
            value = value.strip()
            # Store the raw string first; upgrade to a float array when the
            # value looks numeric.
            data[key] = value
            if numeric_chars.issuperset(value):
                try:
                    data[key] = np.array([float(tok) for tok in value.split(' ')])
                except ValueError:
                    # Not actually numeric (e.g. a lone 'e'); the raw string
                    # is already stored, so nothing to do.
                    pass
    return data
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def sub2ind(matrixSize, rowSub, colSub):
    """Convert row, col matrix subscripts to linear indices."""
    _, n_cols = matrixSize
    # NOTE(review): the (n-1) stride and -1 offset mirror the KITTI-derived
    # code this file adapts; callers rely on these exact values, so do not
    # "correct" the formula without checking them.
    return rowSub * (n_cols - 1) + colSub - 1
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def generate_depth_map(calib_dir, velo_filename, cam=2, vel_depth=False):
    """Generate a depth map from velodyne data.

    Args:
        calib_dir: directory holding calib_cam_to_cam.txt / calib_velo_to_cam.txt.
        velo_filename: path to the velodyne ``.bin`` scan.
        cam: rectified camera index the points are projected onto (default 2).
        vel_depth: if True, use the velodyne forward distance as depth instead
            of the projected z coordinate.

    Returns:
        2-D float array of per-pixel depth; 0 where no lidar return landed.
    """
    # load calibration files
    cam2cam = read_calib_file(os.path.join(calib_dir, 'calib_cam_to_cam.txt'))
    velo2cam = read_calib_file(os.path.join(calib_dir, 'calib_velo_to_cam.txt'))
    velo2cam = np.hstack((velo2cam['R'].reshape(3, 3), velo2cam['T'][..., np.newaxis]))
    velo2cam = np.vstack((velo2cam, np.array([0, 0, 0, 1.0])))

    # get image shape (stored as "width height"; reversed to (rows, cols))
    im_shape = cam2cam["S_rect_02"][::-1].astype(np.int32)

    # compute projection matrix velodyne->image plane
    R_cam2rect = np.eye(4)
    R_cam2rect[:3, :3] = cam2cam['R_rect_00'].reshape(3, 3)
    P_rect = cam2cam['P_rect_0' + str(cam)].reshape(3, 4)
    P_velo2im = np.dot(np.dot(P_rect, R_cam2rect), velo2cam)

    # load velodyne points and remove all behind image plane (approximation)
    # each row of the velodyne data is forward, left, up, reflectance
    velo = load_velodyne_points(velo_filename)
    velo = velo[velo[:, 0] >= 0, :]

    # project the points to the camera
    velo_pts_im = np.dot(P_velo2im, velo.T).T
    velo_pts_im[:, :2] = velo_pts_im[:, :2] / velo_pts_im[:, 2][..., np.newaxis]

    if vel_depth:
        velo_pts_im[:, 2] = velo[:, 0]

    # check if in bounds
    # use minus 1 to get the exact same value as KITTI matlab code
    velo_pts_im[:, 0] = np.round(velo_pts_im[:, 0]) - 1
    velo_pts_im[:, 1] = np.round(velo_pts_im[:, 1]) - 1
    val_inds = (velo_pts_im[:, 0] >= 0) & (velo_pts_im[:, 1] >= 0)
    val_inds = val_inds & (velo_pts_im[:, 0] < im_shape[1]) & (velo_pts_im[:, 1] < im_shape[0])
    velo_pts_im = velo_pts_im[val_inds, :]

    # project to image
    depth = np.zeros(im_shape[:2])
    # BUG FIX: np.int was deprecated in NumPy 1.20 and removed in 1.24;
    # the builtin int gives the same (platform default) integer dtype.
    depth[velo_pts_im[:, 1].astype(int), velo_pts_im[:, 0].astype(int)] = velo_pts_im[:, 2]

    # find the duplicate points and choose the closest depth
    inds = sub2ind(depth.shape, velo_pts_im[:, 1], velo_pts_im[:, 0])
    dupe_inds = [item for item, count in Counter(inds).items() if count > 1]
    for dd in dupe_inds:
        pts = np.where(inds == dd)[0]
        x_loc = int(velo_pts_im[pts[0], 0])
        y_loc = int(velo_pts_im[pts[0], 1])
        depth[y_loc, x_loc] = velo_pts_im[pts, 2].min()
    depth[depth < 0] = 0

    return depth
|
| 223 |
+
|
| 224 |
+
def pil_loader(path):
    """Open an image file and return it as an RGB PIL image.

    The file handle is opened explicitly to avoid a ResourceWarning
    (https://github.com/python-pillow/Pillow/issues/835).
    """
    with open(path, 'rb') as fh:
        with Image.open(fh) as img:
            return img.convert('RGB')
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
class MonoDataset(data.Dataset):
    """Superclass for monocular dataloaders

    Args:
        data_path: root directory of the dataset
        filenames: list of "folder [frame_index] [side]" sample strings
        height: full-scale output image height
        width: full-scale output image width
        frame_idxs: temporal offsets (and/or "s" for the stereo pair) to load
        num_scales: number of image pyramid scales
        is_train: enables flip / colour-jitter augmentation
        img_ext: image file extension
    """
    def __init__(self,
                 data_path,
                 filenames,
                 height,
                 width,
                 frame_idxs,
                 num_scales,
                 is_train=False,
                 img_ext='.jpg'):
        super(MonoDataset, self).__init__()

        self.data_path = data_path
        self.filenames = filenames
        self.height = height
        self.width = width
        self.num_scales = num_scales
        # BUG FIX: Image.ANTIALIAS was removed in Pillow 10; it has been an
        # alias of Image.LANCZOS (same filter) since Pillow 2.7.
        self.interp = Image.LANCZOS

        self.frame_idxs = frame_idxs

        self.is_train = is_train
        self.img_ext = img_ext

        self.loader = pil_loader
        self.to_tensor = transforms.ToTensor()

        # We need to specify augmentations differently in newer versions of torchvision.
        # We first try the newer tuple version; if this fails we fall back to scalars
        try:
            self.brightness = (0.8, 1.2)
            self.contrast = (0.8, 1.2)
            self.saturation = (0.8, 1.2)
            self.hue = (-0.1, 0.1)
            transforms.ColorJitter.get_params(
                self.brightness, self.contrast, self.saturation, self.hue)
        except TypeError:
            self.brightness = 0.2
            self.contrast = 0.2
            self.saturation = 0.2
            self.hue = 0.1

        # one Resize transform per pyramid level (scale i halves i times)
        self.resize = {}
        for i in range(self.num_scales):
            s = 2 ** i
            self.resize[i] = transforms.Resize((self.height // s, self.width // s),
                                               interpolation=self.interp)

        self.load_depth = self.check_depth()

    def preprocess(self, inputs, color_aug):
        """Resize colour images to the required scales and augment if required

        We create the color_aug object in advance and apply the same augmentation to all
        images in this item. This ensures that all images input to the pose network receive the
        same augmentation.
        """
        for k in list(inputs):
            if "color" in k:
                n, im, i = k
                # build the pyramid: each scale is resized from the previous one
                for i in range(self.num_scales):
                    inputs[(n, im, i)] = self.resize[i](inputs[(n, im, i - 1)])

        for k in list(inputs):
            f = inputs[k]
            if "color" in k:
                n, im, i = k
                inputs[(n, im, i)] = self.to_tensor(f)
                inputs[(n + "_aug", im, i)] = self.to_tensor(color_aug(f))

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, index):
        """Returns a single training item from the dataset as a dictionary.

        Values correspond to torch tensors.
        Keys in the dictionary are either strings or tuples:

            ("color", <frame_id>, <scale>)      for raw colour images,
            ("color_aug", <frame_id>, <scale>)  for augmented colour images,
            ("K", scale) or ("inv_K", scale)    for camera intrinsics,
            "stereo_T"                          for camera extrinsics, and
            "depth_gt"                          for ground truth depth maps.

        <frame_id> is either an integer (e.g. 0, -1, or 1) representing the
        temporal step relative to 'index', or "s" for the opposite image in
        the stereo pair.

        <scale> is an integer representing the scale of the image relative to
        the fullsize image: -1 is the native resolution as loaded from disk,
        and scale i is (self.width // 2**i, self.height // 2**i).
        """
        inputs = {}

        do_color_aug = self.is_train and random.random() > 0.5
        do_flip = self.is_train and random.random() > 0.5

        line = self.filenames[index].split()
        folder = line[0]

        if len(line) == 3:
            frame_index = int(line[1])
        else:
            frame_index = 0

        if len(line) == 3:
            side = line[2]
        else:
            side = None

        for i in self.frame_idxs:
            if i == "s":
                other_side = {"r": "l", "l": "r"}[side]
                inputs[("color", i, -1)] = self.get_color(folder, frame_index, other_side, do_flip)
            else:
                inputs[("color", i, -1)] = self.get_color(folder, frame_index + i, side, do_flip)

        # adjusting intrinsics to match each scale in the pyramid
        for scale in range(self.num_scales):
            K = self.K.copy()

            K[0, :] *= self.width // (2 ** scale)
            K[1, :] *= self.height // (2 ** scale)

            inv_K = np.linalg.pinv(K)

            inputs[("K", scale)] = torch.from_numpy(K)
            inputs[("inv_K", scale)] = torch.from_numpy(inv_K)

        if do_color_aug:
            color_aug = transforms.ColorJitter.get_params(
                self.brightness, self.contrast, self.saturation, self.hue)
        else:
            color_aug = (lambda x: x)

        self.preprocess(inputs, color_aug)

        # native-resolution frames are only needed to build the pyramid
        for i in self.frame_idxs:
            del inputs[("color", i, -1)]
            del inputs[("color_aug", i, -1)]

        if self.load_depth:
            depth_gt = self.get_depth(folder, frame_index, side, do_flip)
            inputs["depth_gt"] = np.expand_dims(depth_gt, 0)
            inputs["depth_gt"] = torch.from_numpy(inputs["depth_gt"].astype(np.float32))

        if "s" in self.frame_idxs:
            stereo_T = np.eye(4, dtype=np.float32)
            # flip the baseline sign when the image is mirrored
            baseline_sign = -1 if do_flip else 1
            side_sign = -1 if side == "l" else 1
            stereo_T[0, 3] = side_sign * baseline_sign * 0.1

            inputs["stereo_T"] = torch.from_numpy(stereo_T)

        return inputs

    def get_color(self, folder, frame_index, side, do_flip):
        raise NotImplementedError

    def check_depth(self):
        raise NotImplementedError

    def get_depth(self, folder, frame_index, side, do_flip):
        raise NotImplementedError
|
| 414 |
+
|
| 415 |
+
class KITTIDataset(MonoDataset):
    """Superclass for different types of KITTI dataset loaders."""

    def __init__(self, *args, **kwargs):
        super(KITTIDataset, self).__init__(*args, **kwargs)

        # NOTE: Make sure your intrinsics matrix is *normalized* by the
        # original image size: scale the first row by 1 / image_width and the
        # second row by 1 / image_height.  Monodepth2 assumes a principal
        # point exactly centered; if yours is far from the center you might
        # need to disable the horizontal flip augmentation.
        self.K = np.array(
            [[0.58, 0, 0.5, 0],
             [0, 1.92, 0.5, 0],
             [0, 0, 1, 0],
             [0, 0, 0, 1]],
            dtype=np.float32,
        )

        self.full_res_shape = (1242, 375)
        self.side_map = {"2": 2, "3": 3, "l": 2, "r": 3}

    def check_depth(self):
        """Report whether a velodyne scan exists for the first sample."""
        line = self.filenames[0].split()
        scene_name = line[0]
        frame_index = int(line[1])

        velo_filename = os.path.join(
            self.data_path,
            scene_name,
            "velodyne_points/data/{:010d}.bin".format(int(frame_index)),
        )
        return os.path.isfile(velo_filename)

    def get_color(self, folder, frame_index, side, do_flip):
        """Load one RGB frame, optionally mirrored horizontally."""
        color = self.loader(self.get_image_path(folder, frame_index, side))
        return color.transpose(Image.FLIP_LEFT_RIGHT) if do_flip else color
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
class KITTIDepthDataset(KITTIDataset):
    """KITTI dataset which uses the updated ground truth depth maps."""

    def __init__(self, *args, **kwargs):
        super(KITTIDepthDataset, self).__init__(*args, **kwargs)

    def get_image_path(self, folder, frame_index, side):
        """Build the path to an RGB frame under image_02/ or image_03/."""
        fname = "{:010d}{}".format(frame_index, self.img_ext)
        return os.path.join(
            self.data_path,
            folder,
            "image_0{}/data".format(self.side_map[side]),
            fname,
        )

    def get_depth(self, folder, frame_index, side, do_flip):
        """Load a ground-truth depth map, resized to the full resolution."""
        fname = "{:010d}.png".format(frame_index)
        depth_path = os.path.join(
            self.data_path,
            folder,
            "proj_depth/groundtruth/image_0{}".format(self.side_map[side]),
            fname,
        )

        raw = Image.open(depth_path).resize(self.full_res_shape, Image.NEAREST)
        # depth is stored as integer values scaled by 256
        depth_gt = np.array(raw).astype(np.float32) / 256

        return np.fliplr(depth_gt) if do_flip else depth_gt
|
losses.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import torchvision.models as models
|
| 5 |
+
from torch import Tensor
|
| 6 |
+
|
| 7 |
+
class ContentLoss(nn.Module):
    """Constructs a content loss function based on the VGG19 network.
    Using high-level feature mapping layers from the latter layers will focus more on the texture content of the image.

    Paper reference list:
        -`Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network <https://arxiv.org/pdf/1609.04802.pdf>` paper.
        -`ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks <https://arxiv.org/pdf/1809.00219.pdf>` paper.
        -`Perceptual Extreme Super Resolution Network with Receptive Field Block <https://arxiv.org/pdf/2005.12597.pdf>` paper.
    """

    def __init__(self) -> None:
        super(ContentLoss, self).__init__()
        # Load the VGG19 model trained on the ImageNet dataset.
        vgg19 = models.vgg19(pretrained=True).eval()
        # Keep the first 36 feature layers; their output is the content representation.
        self.feature_extractor = nn.Sequential(*list(vgg19.features.children())[:36])
        # Freeze model parameters so the extractor is never trained.
        for p in self.feature_extractor.parameters():
            p.requires_grad = False

        # ImageNet normalisation constants applied before feature extraction.
        self.register_buffer("mean", torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, sr: Tensor, hr: Tensor) -> Tensor:
        # Normalise both images, then compare their VGG feature maps with L1.
        sr_feats = self.feature_extractor(sr.sub(self.mean).div(self.std))
        hr_feats = self.feature_extractor(hr.sub(self.mean).div(self.std))
        return F.l1_loss(sr_feats, hr_feats)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class GenGaussLoss(nn.Module):
    """Negative log-likelihood style loss for a generalized Gaussian model.

    Given a predicted mean, inverse scale (1/alpha) and shape (beta), the
    per-pixel term is ``resi - log(1/alpha) + lgamma(1/beta) - log(beta)``
    where ``resi = clamp(|mean - target| * (1/alpha) * beta)``.
    Epsilons keep the logs finite; clamping bounds the residual term.
    """

    def __init__(
        self, reduction='mean',
        alpha_eps=1e-4, beta_eps=1e-4,
        resi_min=1e-4, resi_max=1e3
    ) -> None:
        super(GenGaussLoss, self).__init__()
        self.reduction = reduction
        self.alpha_eps = alpha_eps
        self.beta_eps = beta_eps
        self.resi_min = resi_min
        self.resi_max = resi_max

    def forward(
        self,
        mean: Tensor, one_over_alpha: Tensor, beta: Tensor, target: Tensor
    ):
        inv_alpha = one_over_alpha + self.alpha_eps
        shape = beta + self.beta_eps

        # resi = torch.pow(resi*one_over_alpha1, beta1).clamp(min=self.resi_min, max=self.resi_max)
        residual = (torch.abs(mean - target) * inv_alpha * shape).clamp(
            min=self.resi_min, max=self.resi_max)
        # bail out early if the residual term already diverged
        if torch.sum(residual != residual) > 0:
            print('resi has nans!!')
            return None

        log_inv_alpha = torch.log(inv_alpha)
        log_shape = torch.log(shape)
        lgamma_term = torch.lgamma(torch.pow(shape, -1))

        # diagnostic prints for non-finite intermediate terms
        if torch.sum(log_inv_alpha != log_inv_alpha) > 0:
            print('log_one_over_alpha has nan')
        if torch.sum(lgamma_term != lgamma_term) > 0:
            print('lgamma_beta has nan')
        if torch.sum(log_shape != log_shape) > 0:
            print('log_beta has nan')

        nll = residual - log_inv_alpha + lgamma_term - log_shape

        if self.reduction == 'mean':
            return nll.mean()
        if self.reduction == 'sum':
            return nll.sum()
        print('Reduction not supported')
        return None
|
| 91 |
+
|
| 92 |
+
class TempCombLoss(nn.Module):
    """Weighted combination of an L1 term and the generalized-Gaussian NLL.

    ``forward`` returns ``T1 * L1(mean, target) + T2 * GenGaussLoss(...)``;
    the epsilon/clamp hyper-parameters are forwarded to GenGaussLoss.
    """

    def __init__(
        self, reduction='mean',
        alpha_eps=1e-4, beta_eps=1e-4,
        resi_min=1e-4, resi_max=1e3
    ) -> None:
        super(TempCombLoss, self).__init__()
        self.reduction = reduction
        self.alpha_eps = alpha_eps
        self.beta_eps = beta_eps
        self.resi_min = resi_min
        self.resi_max = resi_max

        self.L_GenGauss = GenGaussLoss(
            reduction=self.reduction,
            alpha_eps=self.alpha_eps, beta_eps=self.beta_eps,
            resi_min=self.resi_min, resi_max=self.resi_max
        )
        self.L_l1 = nn.L1Loss(reduction=self.reduction)

    def forward(
        self,
        mean: Tensor, one_over_alpha: Tensor, beta: Tensor, target: Tensor,
        T1: float, T2: float
    ):
        identity_term = self.L_l1(mean, target)
        nll_term = self.L_GenGauss(mean, one_over_alpha, beta, target)
        return T1 * identity_term + T2 * nll_term
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
# x1 = torch.randn(4,3,32,32)
|
| 125 |
+
# x2 = torch.rand(4,3,32,32)
|
| 126 |
+
# x3 = torch.rand(4,3,32,32)
|
| 127 |
+
# x4 = torch.randn(4,3,32,32)
|
| 128 |
+
|
| 129 |
+
# L = GenGaussLoss(alpha_eps=1e-4, beta_eps=1e-4, resi_min=1e-4, resi_max=1e3)
|
| 130 |
+
# L2 = TempCombLoss(alpha_eps=1e-4, beta_eps=1e-4, resi_min=1e-4, resi_max=1e3)
|
| 131 |
+
# print(L(x1, x2, x3, x4), L2(x1, x2, x3, x4, 1e0, 1e-2))
|
networks_SRGAN.py
ADDED
|
@@ -0,0 +1,347 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import torchvision.models as models
|
| 5 |
+
from torch import Tensor
|
| 6 |
+
|
| 7 |
+
# __all__ = [
|
| 8 |
+
# "ResidualConvBlock",
|
| 9 |
+
# "Discriminator", "Generator",
|
| 10 |
+
# ]
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class ResidualConvBlock(nn.Module):
    """Residual block: conv-BN-PReLU-conv-BN with an identity skip connection.

    Args:
        channels (int): Number of input (and output) feature channels.
    """

    def __init__(self, channels: int) -> None:
        super(ResidualConvBlock, self).__init__()
        # attribute name `rcb` kept for checkpoint (state_dict) compatibility
        self.rcb = nn.Sequential(
            nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(channels),
            nn.PReLU(),
            nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(channels),
        )

    def forward(self, x: Tensor) -> Tensor:
        # y = F(x) + x
        return self.rcb(x) + x
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class Discriminator(nn.Module):
    """SRGAN discriminator: a VGG-style conv stack plus a two-layer classifier.

    Expects 3x96x96 inputs and emits one unnormalised logit per image.
    The generated layer sequence is identical to the original hand-written
    one, so state_dict checkpoints remain loadable.
    """

    def __init__(self) -> None:
        super(Discriminator, self).__init__()
        # (in_ch, out_ch, stride, use_batch_norm) for each conv stage;
        # stride-2 stages halve the 96x96 input down to 6x6.
        stages = [
            (3, 64, 1, False),
            (64, 64, 2, True),
            (64, 128, 1, True),
            (128, 128, 2, True),
            (128, 256, 1, True),
            (256, 256, 2, True),
            (256, 512, 1, True),
            (512, 512, 2, True),
        ]
        layers = []
        for in_ch, out_ch, stride, use_bn in stages:
            layers.append(
                nn.Conv2d(in_ch, out_ch, (3, 3), (stride, stride), (1, 1), bias=False))
            if use_bn:
                layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.LeakyReLU(0.2, True))
        self.features = nn.Sequential(*layers)

        self.classifier = nn.Sequential(
            nn.Linear(512 * 6 * 6, 1024),
            nn.LeakyReLU(0.2, True),
            nn.Linear(1024, 1),
        )

    def forward(self, x: Tensor) -> Tensor:
        feats = self.features(x)
        return self.classifier(torch.flatten(feats, 1))
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class Generator(nn.Module):
    """SRGAN generator: 4x super-resolution via a residual trunk and two
    PixelShuffle(2) upsampling stages.

    forward(x, dop=None): when `dop` is given, 2-D dropout with probability
    `dop` is applied after the second conv block (used for MC-dropout style
    uncertainty sampling); otherwise the plain deterministic path runs.
    """

    def __init__(self) -> None:
        super(Generator, self).__init__()
        # First conv layer.
        self.conv_block1 = nn.Sequential(
            nn.Conv2d(3, 64, (9, 9), (1, 1), (4, 4)),
            nn.PReLU(),
        )

        # Features trunk blocks.
        trunk = []
        for _ in range(16):
            trunk.append(ResidualConvBlock(64))
        self.trunk = nn.Sequential(*trunk)

        # Second conv layer.
        self.conv_block2 = nn.Sequential(
            nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(64),
        )

        # Upscale conv block: each Conv2d(64->256) + PixelShuffle(2) doubles
        # the spatial size, so the pair gives 4x upsampling.
        self.upsampling = nn.Sequential(
            nn.Conv2d(64, 256, (3, 3), (1, 1), (1, 1)),
            nn.PixelShuffle(2),
            nn.PReLU(),
            nn.Conv2d(64, 256, (3, 3), (1, 1), (1, 1)),
            nn.PixelShuffle(2),
            nn.PReLU(),
        )

        # Output layer.
        self.conv_block3 = nn.Conv2d(64, 3, (9, 9), (1, 1), (4, 4))

        # Initialize neural network weights.
        self._initialize_weights()

    def forward(self, x: Tensor, dop=None) -> Tensor:
        # BUG FIX: the original tested `if not dop:`, which also routed
        # dop=0.0 (falsy) to the no-dropout path.  Test the None sentinel
        # explicitly; F.dropout2d with p=0 is the identity, so outputs for
        # dop=0.0 are unchanged.
        if dop is None:
            return self._forward_impl(x)
        return self._forward_w_dop_impl(x, dop)

    # Support torch.script function.
    def _forward_impl(self, x: Tensor) -> Tensor:
        out1 = self.conv_block1(x)
        out = self.trunk(out1)
        out2 = self.conv_block2(out)
        out = torch.add(out1, out2)
        out = self.upsampling(out)
        out = self.conv_block3(out)

        return out

    def _forward_w_dop_impl(self, x: Tensor, dop) -> Tensor:
        """Same as `_forward_impl` but with spatial dropout on the trunk output."""
        out1 = self.conv_block1(x)
        out = self.trunk(out1)
        out2 = F.dropout2d(self.conv_block2(out), p=dop)
        out = torch.add(out1, out2)
        out = self.upsampling(out)
        out = self.conv_block3(out)

        return out

    def _initialize_weights(self) -> None:
        """Kaiming-normal init for convs; BN scale set to 1."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
#### BayesCap
|
| 162 |
+
#### BayesCap
class BayesCap(nn.Module):
    """Uncertainty head predicting (mu, 1/alpha, beta) maps from an image.

    Shares the SRGAN-style backbone (conv + 16 residual blocks + conv with a
    skip connection); three heads read the fused features.  Attribute names
    and Sequential layer ordering match the original implementation so
    pretrained state_dicts remain loadable.
    """

    def __init__(self, in_channels=3, out_channels=3) -> None:
        super(BayesCap, self).__init__()
        # First conv layer.
        self.conv_block1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=9, stride=1, padding=4),
            nn.PReLU(),
        )

        # Features trunk blocks.
        self.trunk = nn.Sequential(*[ResidualConvBlock(64) for _ in range(16)])

        # Second conv layer.
        self.conv_block2 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
        )

        # Output heads: mean image plus the two 1-channel distribution maps.
        self.conv_block3_mu = nn.Conv2d(
            64, out_channels=out_channels, kernel_size=9, stride=1, padding=4)
        self.conv_block3_alpha = self._make_param_head()
        self.conv_block3_beta = self._make_param_head()

        # Initialize neural network weights.
        self._initialize_weights()

    @staticmethod
    def _make_param_head() -> nn.Sequential:
        """Three 9x9 convs (PReLU between, ReLU last) producing a 1-channel map."""
        return nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=9, stride=1, padding=4),
            nn.PReLU(),
            nn.Conv2d(64, 64, kernel_size=9, stride=1, padding=4),
            nn.PReLU(),
            nn.Conv2d(64, 1, kernel_size=9, stride=1, padding=4),
            nn.ReLU(),
        )

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)

    # Support torch.script function.
    def _forward_impl(self, x: Tensor) -> Tensor:
        shallow = self.conv_block1(x)
        deep = self.conv_block2(self.trunk(shallow))
        fused = shallow + deep
        return (
            self.conv_block3_mu(fused),
            self.conv_block3_alpha(fused),
            self.conv_block3_beta(fused),
        )

    def _initialize_weights(self) -> None:
        """Kaiming-normal init for convs; BN scale set to 1."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
class BayesCap_noID(nn.Module):
|
| 257 |
+
def __init__(self, in_channels=3, out_channels=3) -> None:
|
| 258 |
+
super(BayesCap_noID, self).__init__()
|
| 259 |
+
# First conv layer.
|
| 260 |
+
self.conv_block1 = nn.Sequential(
|
| 261 |
+
nn.Conv2d(
|
| 262 |
+
in_channels, 64,
|
| 263 |
+
kernel_size=9, stride=1, padding=4
|
| 264 |
+
),
|
| 265 |
+
nn.PReLU(),
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
# Features trunk blocks.
|
| 269 |
+
trunk = []
|
| 270 |
+
for _ in range(16):
|
| 271 |
+
trunk.append(ResidualConvBlock(64))
|
| 272 |
+
self.trunk = nn.Sequential(*trunk)
|
| 273 |
+
|
| 274 |
+
# Second conv layer.
|
| 275 |
+
self.conv_block2 = nn.Sequential(
|
| 276 |
+
nn.Conv2d(
|
| 277 |
+
64, 64,
|
| 278 |
+
kernel_size=3, stride=1, padding=1, bias=False
|
| 279 |
+
),
|
| 280 |
+
nn.BatchNorm2d(64),
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
# Output layer.
|
| 284 |
+
# self.conv_block3_mu = nn.Conv2d(
|
| 285 |
+
# 64, out_channels=out_channels,
|
| 286 |
+
# kernel_size=9, stride=1, padding=4
|
| 287 |
+
# )
|
| 288 |
+
self.conv_block3_alpha = nn.Sequential(
|
| 289 |
+
nn.Conv2d(
|
| 290 |
+
64, 64,
|
| 291 |
+
kernel_size=9, stride=1, padding=4
|
| 292 |
+
),
|
| 293 |
+
nn.PReLU(),
|
| 294 |
+
nn.Conv2d(
|
| 295 |
+
64, 64,
|
| 296 |
+
kernel_size=9, stride=1, padding=4
|
| 297 |
+
),
|
| 298 |
+
nn.PReLU(),
|
| 299 |
+
nn.Conv2d(
|
| 300 |
+
64, 1,
|
| 301 |
+
kernel_size=9, stride=1, padding=4
|
| 302 |
+
),
|
| 303 |
+
nn.ReLU(),
|
| 304 |
+
)
|
| 305 |
+
self.conv_block3_beta = nn.Sequential(
|
| 306 |
+
nn.Conv2d(
|
| 307 |
+
64, 64,
|
| 308 |
+
kernel_size=9, stride=1, padding=4
|
| 309 |
+
),
|
| 310 |
+
nn.PReLU(),
|
| 311 |
+
nn.Conv2d(
|
| 312 |
+
64, 64,
|
| 313 |
+
kernel_size=9, stride=1, padding=4
|
| 314 |
+
),
|
| 315 |
+
nn.PReLU(),
|
| 316 |
+
nn.Conv2d(
|
| 317 |
+
64, 1,
|
| 318 |
+
kernel_size=9, stride=1, padding=4
|
| 319 |
+
),
|
| 320 |
+
nn.ReLU(),
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
# Initialize neural network weights.
|
| 324 |
+
self._initialize_weights()
|
| 325 |
+
|
| 326 |
+
def forward(self, x: Tensor) -> "tuple[Tensor, Tensor]":
    """Return per-pixel (alpha, beta) maps; delegates to `_forward_impl`.

    NOTE(review): the original annotation said `-> Tensor`, but the
    implementation returns a 2-tuple — annotation corrected.
    """
    return self._forward_impl(x)
|
| 328 |
+
|
| 329 |
+
# Support torch.script function.
|
| 330 |
+
def _forward_impl(self, x: Tensor) -> "tuple[Tensor, Tensor]":
    """Compute the per-pixel (alpha, beta) uncertainty maps for `x`.

    Fix: the return annotation previously claimed a single `Tensor`,
    but the method returns a 2-tuple of 1-channel maps.
    """
    # Stem features, kept for the long skip connection.
    out1 = self.conv_block1(x)
    # Residual trunk followed by conv + BatchNorm.
    out = self.trunk(out1)
    out2 = self.conv_block2(out)
    # Long skip: stem features + trunk output.
    out = out1 + out2
    # out_mu = self.conv_block3_mu(out)
    out_alpha = self.conv_block3_alpha(out)
    out_beta = self.conv_block3_beta(out)
    return out_alpha, out_beta
|
| 339 |
+
|
| 340 |
+
def _initialize_weights(self) -> None:
    """Kaiming-initialize conv weights (zero bias); set BatchNorm scales to 1."""
    for m in self.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
|
networks_T1toT2.py
ADDED
|
@@ -0,0 +1,477 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import functools
|
| 5 |
+
|
| 6 |
+
### components
|
| 7 |
+
class ResConv(nn.Module):
    """
    Residual convolutional block.

    Main path: (Conv2d => BatchNorm => ReLU) * 3 mapping
    in_channels -> mid_channels -> mid_channels -> out_channels.
    Skip path: a single (Conv2d => BatchNorm => ReLU) mapping
    in_channels -> out_channels.  The two paths are summed, so the
    output always has `out_channels` channels.
    """
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        # Main (3-conv) path.
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        # Projection path that forms the residual connection.
        self.double_conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        # Bug fix: the original evaluated self.double_conv(x) twice —
        # once into an unused local and again in the return expression —
        # doubling compute and updating BatchNorm running statistics
        # twice per step in training mode.  Each path now runs once.
        x_in = self.double_conv1(x)
        x1 = self.double_conv(x)
        return x1 + x_in
|
| 37 |
+
|
| 38 |
+
class Down(nn.Module):
    """Halve the spatial resolution with max-pooling, then apply a ResConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        conv = ResConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, conv)

    def forward(self, x):
        return self.maxpool_conv(x)
|
| 48 |
+
|
| 49 |
+
class Up(nn.Module):
    """Upscale x1 by 2, pad to match the skip tensor x2, concatenate, ResConv."""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        if bilinear:
            # Parameter-free upsampling; channel reduction happens inside ResConv.
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = ResConv(in_channels, out_channels, in_channels // 2)
        else:
            # Learned upsampling that also halves the channel count.
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = ResConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Tensors are NCHW; odd input sizes can leave x1 one pixel smaller
        # than the skip tensor x2, so pad x1 up to x2's spatial size.
        dy = x2.size()[2] - x1.size()[2]
        dx = x2.size()[3] - x1.size()[3]
        # See the padding discussion in:
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x1 = F.pad(x1, [dx // 2, dx - dx // 2, dy // 2, dy - dy // 2])
        # Channel-wise concatenation of skip features and upsampled features.
        merged = torch.cat([x2, x1], dim=1)
        return self.conv(merged)
|
| 77 |
+
|
| 78 |
+
class OutConv(nn.Module):
    """Final 1x1 convolution mapping feature channels to output channels."""

    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        # A previous variant applied ReLU here; the current contract is linear.
        return self.conv(x)
|
| 85 |
+
|
| 86 |
+
##### The composite networks
|
| 87 |
+
class UNet(nn.Module):
    """Standard 4-level U-Net assembled from ResConv / Down / Up / OutConv."""

    def __init__(self, n_channels, out_channels, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.out_channels = out_channels
        self.bilinear = bilinear
        # Encoder.
        self.inc = ResConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        # Decoder with skip connections.
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, out_channels)

    def forward(self, x):
        # Contracting path: keep each level's features for the skips.
        feats = [self.inc(x)]
        for stage in (self.down1, self.down2, self.down3, self.down4):
            feats.append(stage(feats[-1]))
        # Expanding path: fuse with encoder features in reverse order.
        out = self.up1(feats[4], feats[3])
        out = self.up2(out, feats[2])
        out = self.up3(out, feats[1])
        out = self.up4(out, feats[0])
        return self.outc(out)
|
| 117 |
+
|
| 118 |
+
class CasUNet(nn.Module):
    """Cascade of `n_unet` U-Nets; every stage after the first consumes
    the previous output plus the original input (residual cascading)."""

    def __init__(self, n_unet, io_channels, bilinear=True):
        super(CasUNet, self).__init__()
        self.n_unet = n_unet
        self.io_channels = io_channels
        self.bilinear = bilinear
        self.unet_list = nn.ModuleList(
            UNet(io_channels, io_channels, bilinear) for _ in range(n_unet)
        )

    def forward(self, x, dop=None):
        y = x
        for idx, net in enumerate(self.unet_list):
            if idx == 0:
                y = net(y)
                # Optional 2-D dropout applied only after the first stage.
                if dop is not None:
                    y = F.dropout2d(y, p=dop)
            else:
                y = net(y + x)
        return y
|
| 139 |
+
|
| 140 |
+
class CasUNet_2head(nn.Module):
    """Cascade of U-Nets whose last stage is a UNet_2head producing a
    per-pixel (mean, sigma) pair."""

    def __init__(self, n_unet, io_channels, bilinear=True):
        super(CasUNet_2head, self).__init__()
        self.n_unet = n_unet
        self.io_channels = io_channels
        self.bilinear = bilinear
        self.unet_list = nn.ModuleList()
        for idx in range(n_unet):
            # Only the final stage carries the two output heads.
            net_cls = UNet_2head if idx == n_unet - 1 else UNet
            self.unet_list.append(net_cls(io_channels, io_channels, bilinear))

    def forward(self, x):
        y = x
        for idx, net in enumerate(self.unet_list):
            y = net(y) if idx == 0 else net(y + x)
        # The last stage yields the (mean, sigma) pair.
        y_mean, y_sigma = y[0], y[1]
        return y_mean, y_sigma
|
| 162 |
+
|
| 163 |
+
class CasUNet_3head(nn.Module):
    """Cascade of U-Nets whose last stage is a UNet_3head producing a
    per-pixel (mean, alpha, beta) triple."""

    def __init__(self, n_unet, io_channels, bilinear=True):
        super(CasUNet_3head, self).__init__()
        self.n_unet = n_unet
        self.io_channels = io_channels
        self.bilinear = bilinear
        self.unet_list = nn.ModuleList()
        for idx in range(n_unet):
            # Only the final stage carries the three output heads.
            net_cls = UNet_3head if idx == n_unet - 1 else UNet
            self.unet_list.append(net_cls(io_channels, io_channels, bilinear))

    def forward(self, x):
        y = x
        for idx, net in enumerate(self.unet_list):
            y = net(y) if idx == 0 else net(y + x)
        # The last stage yields the (mean, alpha, beta) triple.
        y_mean, y_alpha, y_beta = y[0], y[1], y[2]
        return y_mean, y_alpha, y_beta
|
| 185 |
+
|
| 186 |
+
class UNet_2head(nn.Module):
    """U-Net with two output heads: a mean map (`out_channels` channels)
    and a single-channel variance map."""

    def __init__(self, n_channels, out_channels, bilinear=True):
        super(UNet_2head, self).__init__()
        self.n_channels = n_channels
        self.out_channels = out_channels
        self.bilinear = bilinear
        # Encoder.
        self.inc = ResConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        # Decoder.
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        # The per-pixel mean may have several channels ...
        self.out_mean = OutConv(64, out_channels)
        # ... while the variance is a single number per pixel.
        self.out_var = nn.Sequential(
            OutConv(64, 128),
            OutConv(128, 1),
        )

    def forward(self, x):
        feats = [self.inc(x)]
        for stage in (self.down1, self.down2, self.down3, self.down4):
            feats.append(stage(feats[-1]))
        out = self.up1(feats[4], feats[3])
        out = self.up2(out, feats[2])
        out = self.up3(out, feats[1])
        out = self.up4(out, feats[0])
        return self.out_mean(out), self.out_var(out)
|
| 222 |
+
|
| 223 |
+
class UNet_3head(nn.Module):
    """U-Net with three output heads: a mean map (`out_channels` channels)
    plus single-channel non-negative `alpha` and `beta` maps."""

    def __init__(self, n_channels, out_channels, bilinear=True):
        super(UNet_3head, self).__init__()
        self.n_channels = n_channels
        self.out_channels = out_channels
        self.bilinear = bilinear
        # Encoder.
        self.inc = ResConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        # Decoder.
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        # The per-pixel mean may have several channels ...
        self.out_mean = OutConv(64, out_channels)
        # ... while alpha/beta are one value per pixel, kept non-negative
        # by the final ReLU.
        self.out_alpha = nn.Sequential(
            OutConv(64, 128),
            OutConv(128, 1),
            nn.ReLU()
        )
        self.out_beta = nn.Sequential(
            OutConv(64, 128),
            OutConv(128, 1),
            nn.ReLU()
        )

    def forward(self, x):
        feats = [self.inc(x)]
        for stage in (self.down1, self.down2, self.down3, self.down4):
            feats.append(stage(feats[-1]))
        out = self.up1(feats[4], feats[3])
        out = self.up2(out, feats[2])
        out = self.up3(out, feats[1])
        out = self.up4(out, feats[0])
        return self.out_mean(out), self.out_alpha(out), self.out_beta(out)
|
| 266 |
+
|
| 267 |
+
class ResidualBlock(nn.Module):
    """Two reflection-padded 3x3 convs with InstanceNorm; identity skip."""

    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
        )

    def forward(self, x):
        # Identity skip connection around the conv block.
        return x + self.conv_block(x)
|
| 282 |
+
|
| 283 |
+
class Generator(nn.Module):
    """CycleGAN-style resnet generator: 7x7 stem, two stride-2
    downsampling convs, `n_residual_blocks` residual blocks, two
    transposed-conv upsampling stages, and a 7x7 Tanh output layer."""

    def __init__(self, input_nc, output_nc, n_residual_blocks=9):
        super(Generator, self).__init__()
        # Initial convolution block.
        layers = [
            nn.ReflectionPad2d(3), nn.Conv2d(input_nc, 64, 7),
            nn.InstanceNorm2d(64), nn.ReLU(inplace=True)
        ]
        # Two downsampling stages: 64 -> 128 -> 256 channels.
        channels = 64
        for _ in range(2):
            layers += [
                nn.Conv2d(channels, channels * 2, 3, stride=2, padding=1),
                nn.InstanceNorm2d(channels * 2),
                nn.ReLU(inplace=True)
            ]
            channels *= 2
        # Residual trunk at the bottleneck resolution.
        layers += [ResidualBlock(channels) for _ in range(n_residual_blocks)]
        # Two upsampling stages: 256 -> 128 -> 64 channels.
        for _ in range(2):
            layers += [
                nn.ConvTranspose2d(channels, channels // 2, 3, stride=2,
                                   padding=1, output_padding=1),
                nn.InstanceNorm2d(channels // 2),
                nn.ReLU(inplace=True)
            ]
            channels //= 2
        # Output head maps back to image space in [-1, 1].
        layers += [nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, 7), nn.Tanh()]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
class ResnetGenerator(nn.Module):
    """Resnet-based generator with Resnet blocks between a few
    downsampling/upsampling operations.

    We adapt Torch code and idea from Justin Johnson's neural style
    transfer project (https://github.com/jcjohnson/fast-neural-style).
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
        """Construct a Resnet-based generator.

        Parameters:
            input_nc (int)     -- number of channels in input images
            output_nc (int)    -- number of channels in output images
            ngf (int)          -- number of filters in the last conv layer
            norm_layer         -- normalization layer
            use_dropout (bool) -- whether to use dropout layers
            n_blocks (int)     -- number of ResNet blocks
            padding_type (str) -- padding in conv layers: reflect | replicate | zero
        """
        assert(n_blocks >= 0)
        super(ResnetGenerator, self).__init__()
        # InstanceNorm has no learnable affine shift by default, so the
        # preceding convolutions need their own bias term.
        norm_fn = norm_layer.func if type(norm_layer) == functools.partial else norm_layer
        use_bias = norm_fn == nn.InstanceNorm2d

        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
            norm_layer(ngf),
            nn.ReLU(True),
        ]

        n_downsampling = 2
        # Downsampling stages: ngf -> 2*ngf -> 4*ngf.
        for i in range(n_downsampling):
            mult = 2 ** i
            layers += [
                nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                norm_layer(ngf * mult * 2),
                nn.ReLU(True),
            ]

        # Residual trunk at the bottleneck resolution.
        mult = 2 ** n_downsampling
        layers += [
            ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer,
                        use_dropout=use_dropout, use_bias=use_bias)
            for _ in range(n_blocks)
        ]

        # Upsampling stages mirroring the downsampling path.
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            layers += [
                nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                   kernel_size=3, stride=2,
                                   padding=1, output_padding=1,
                                   bias=use_bias),
                norm_layer(int(ngf * mult / 2)),
                nn.ReLU(True),
            ]
        layers += [nn.ReflectionPad2d(3)]
        layers += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        layers += [nn.Tanh()]

        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Standard forward."""
        return self.model(input)
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
class ResnetBlock(nn.Module):
    """A Resnet block: a two-conv block with an identity skip connection.

    Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
    """

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Build the conv block; the skip connection is added in forward()."""
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Construct a convolutional block.

        Parameters:
            dim (int)          -- number of channels in the conv layer
            padding_type (str) -- reflect | replicate | zero
            norm_layer         -- normalization layer
            use_dropout (bool) -- whether to add a Dropout layer
            use_bias (bool)    -- whether the conv layers use bias

        Returns a conv block (conv + norm + ReLU, optional dropout,
        conv + norm).
        """
        def pad_spec():
            # Explicit padding module for reflect/replicate; zero padding
            # is folded into the conv itself via its `padding` argument.
            if padding_type == 'reflect':
                return [nn.ReflectionPad2d(1)], 0
            if padding_type == 'replicate':
                return [nn.ReplicationPad2d(1)], 0
            if padding_type == 'zero':
                return [], 1
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        pads, p = pad_spec()
        block = pads + [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
        if use_dropout:
            block += [nn.Dropout(0.5)]
        pads, p = pad_spec()
        block += pads + [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]
        return nn.Sequential(*block)

    def forward(self, x):
        """Forward function (with skip connections)."""
        return x + self.conv_block(x)
|
| 436 |
+
|
| 437 |
+
### discriminator
|
| 438 |
+
class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator: a stack of stride-2 convolutions ending in
    a one-channel real/fake prediction map over image patches."""

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        """Construct a PatchGAN discriminator.

        Parameters:
            input_nc (int) -- number of channels in input images
            ndf (int)      -- number of filters in the last conv layer
            n_layers (int) -- number of conv layers in the discriminator
            norm_layer     -- normalization layer
        """
        super(NLayerDiscriminator, self).__init__()
        # BatchNorm2d has affine parameters, so conv bias is redundant
        # there; InstanceNorm2d does not, hence convs keep their bias.
        norm_fn = norm_layer.func if type(norm_layer) == functools.partial else norm_layer
        use_bias = norm_fn == nn.InstanceNorm2d

        kw, padw = 4, 1
        layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True),
        ]
        # Gradually increase the filter count, capped at 8 * ndf.
        nf_mult = 1
        for n in range(1, n_layers):
            nf_mult_prev, nf_mult = nf_mult, min(2 ** n, 8)
            layers += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True),
            ]
        nf_mult_prev, nf_mult = nf_mult, min(2 ** n_layers, 8)
        layers += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True),
        ]
        # Output: one-channel prediction map.
        layers += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Standard forward."""
        return self.model(input)
|
requirements.txt
ADDED
|
@@ -0,0 +1,334 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This file may be used to create an environment using:
|
| 2 |
+
# $ conda create --name <env> --file <this file>
|
| 3 |
+
# platform: linux-64
|
| 4 |
+
_libgcc_mutex=0.1=conda_forge
|
| 5 |
+
_openmp_mutex=4.5=2_kmp_llvm
|
| 6 |
+
aiohttp=3.8.1=pypi_0
|
| 7 |
+
aiosignal=1.2.0=pypi_0
|
| 8 |
+
albumentations=1.2.0=pyhd8ed1ab_0
|
| 9 |
+
alsa-lib=1.2.6.1=h7f98852_0
|
| 10 |
+
analytics-python=1.4.0=pypi_0
|
| 11 |
+
anyio=3.6.1=pypi_0
|
| 12 |
+
aom=3.3.0=h27087fc_1
|
| 13 |
+
argon2-cffi=21.3.0=pypi_0
|
| 14 |
+
argon2-cffi-bindings=21.2.0=pypi_0
|
| 15 |
+
asttokens=2.0.5=pypi_0
|
| 16 |
+
async-timeout=4.0.2=pypi_0
|
| 17 |
+
attr=2.5.1=h166bdaf_0
|
| 18 |
+
attrs=21.4.0=pypi_0
|
| 19 |
+
babel=2.10.1=pypi_0
|
| 20 |
+
backcall=0.2.0=pypi_0
|
| 21 |
+
backoff=1.10.0=pypi_0
|
| 22 |
+
bcrypt=3.2.2=pypi_0
|
| 23 |
+
beautifulsoup4=4.11.1=pypi_0
|
| 24 |
+
blas=1.0=mkl
|
| 25 |
+
bleach=5.0.0=pypi_0
|
| 26 |
+
blosc=1.21.1=h83bc5f7_3
|
| 27 |
+
brotli=1.0.9=h166bdaf_7
|
| 28 |
+
brotli-bin=1.0.9=h166bdaf_7
|
| 29 |
+
brotlipy=0.7.0=py310h7f8727e_1002
|
| 30 |
+
brunsli=0.1=h9c3ff4c_0
|
| 31 |
+
bzip2=1.0.8=h7b6447c_0
|
| 32 |
+
c-ares=1.18.1=h7f98852_0
|
| 33 |
+
c-blosc2=2.2.0=h7a311fb_0
|
| 34 |
+
ca-certificates=2022.6.15=ha878542_0
|
| 35 |
+
cairo=1.16.0=ha61ee94_1011
|
| 36 |
+
certifi=2022.6.15=py310hff52083_0
|
| 37 |
+
cffi=1.15.0=py310hd667e15_1
|
| 38 |
+
cfitsio=4.1.0=hd9d235c_0
|
| 39 |
+
charls=2.3.4=h9c3ff4c_0
|
| 40 |
+
charset-normalizer=2.0.4=pyhd3eb1b0_0
|
| 41 |
+
click=8.1.3=pypi_0
|
| 42 |
+
cloudpickle=2.1.0=pyhd8ed1ab_0
|
| 43 |
+
cryptography=37.0.1=py310h9ce1e76_0
|
| 44 |
+
cudatoolkit=10.2.89=hfd86e86_1
|
| 45 |
+
cycler=0.11.0=pypi_0
|
| 46 |
+
cytoolz=0.11.2=py310h5764c6d_2
|
| 47 |
+
dask-core=2022.7.0=pyhd8ed1ab_0
|
| 48 |
+
dbus=1.13.6=h5008d03_3
|
| 49 |
+
debugpy=1.6.0=pypi_0
|
| 50 |
+
decorator=5.1.1=pypi_0
|
| 51 |
+
defusedxml=0.7.1=pypi_0
|
| 52 |
+
entrypoints=0.4=pypi_0
|
| 53 |
+
executing=0.8.3=pypi_0
|
| 54 |
+
expat=2.4.8=h27087fc_0
|
| 55 |
+
fastapi=0.78.0=pypi_0
|
| 56 |
+
fastjsonschema=2.15.3=pypi_0
|
| 57 |
+
ffmpeg=4.4.2=habc3f16_0
|
| 58 |
+
ffmpy=0.3.0=pypi_0
|
| 59 |
+
fftw=3.3.10=nompi_h77c792f_102
|
| 60 |
+
fire=0.4.0=pypi_0
|
| 61 |
+
font-ttf-dejavu-sans-mono=2.37=hab24e00_0
|
| 62 |
+
font-ttf-inconsolata=3.000=h77eed37_0
|
| 63 |
+
font-ttf-source-code-pro=2.038=h77eed37_0
|
| 64 |
+
font-ttf-ubuntu=0.83=hab24e00_0
|
| 65 |
+
fontconfig=2.14.0=h8e229c2_0
|
| 66 |
+
fonts-conda-ecosystem=1=0
|
| 67 |
+
fonts-conda-forge=1=0
|
| 68 |
+
fonttools=4.33.3=pypi_0
|
| 69 |
+
freeglut=3.2.2=h9c3ff4c_1
|
| 70 |
+
freetype=2.11.0=h70c0345_0
|
| 71 |
+
frozenlist=1.3.0=pypi_0
|
| 72 |
+
fsspec=2022.5.0=pyhd8ed1ab_0
|
| 73 |
+
ftfy=6.1.1=pypi_0
|
| 74 |
+
gettext=0.19.8.1=h73d1719_1008
|
| 75 |
+
giflib=5.2.1=h7b6447c_0
|
| 76 |
+
glib=2.70.2=h780b84a_4
|
| 77 |
+
glib-tools=2.70.2=h780b84a_4
|
| 78 |
+
gmp=6.2.1=h295c915_3
|
| 79 |
+
gnutls=3.7.6=hb5d6004_1
|
| 80 |
+
gradio=3.0.24=pypi_0
|
| 81 |
+
graphite2=1.3.13=h58526e2_1001
|
| 82 |
+
gst-plugins-base=1.20.3=hf6a322e_0
|
| 83 |
+
gstreamer=1.20.3=hd4edc92_0
|
| 84 |
+
h11=0.12.0=pypi_0
|
| 85 |
+
harfbuzz=4.4.1=hf9f4e7c_0
|
| 86 |
+
hdf5=1.12.1=nompi_h2386368_104
|
| 87 |
+
httpcore=0.15.0=pypi_0
|
| 88 |
+
httpx=0.23.0=pypi_0
|
| 89 |
+
icu=70.1=h27087fc_0
|
| 90 |
+
idna=3.3=pyhd3eb1b0_0
|
| 91 |
+
imagecodecs=2022.2.22=py310h3ac3b6e_6
|
| 92 |
+
imageio=2.19.3=pyhcf75d05_0
|
| 93 |
+
intel-openmp=2021.4.0=h06a4308_3561
|
| 94 |
+
ipykernel=6.13.0=pypi_0
|
| 95 |
+
ipython=8.4.0=pypi_0
|
| 96 |
+
ipython-genutils=0.2.0=pypi_0
|
| 97 |
+
jack=1.9.18=h8c3723f_1002
|
| 98 |
+
jasper=2.0.33=ha77e612_0
|
| 99 |
+
jedi=0.18.1=pypi_0
|
| 100 |
+
jinja2=3.1.2=pypi_0
|
| 101 |
+
joblib=1.1.0=pyhd8ed1ab_0
|
| 102 |
+
jpeg=9e=h7f8727e_0
|
| 103 |
+
json5=0.9.8=pypi_0
|
| 104 |
+
jsonschema=4.6.0=pypi_0
|
| 105 |
+
jupyter-client=7.3.1=pypi_0
|
| 106 |
+
jupyter-core=4.10.0=pypi_0
|
| 107 |
+
jupyter-server=1.17.0=pypi_0
|
| 108 |
+
jupyterlab=3.4.2=pypi_0
|
| 109 |
+
jupyterlab-pygments=0.2.2=pypi_0
|
| 110 |
+
jupyterlab-server=2.14.0=pypi_0
|
| 111 |
+
jxrlib=1.1=h7f98852_2
|
| 112 |
+
keyutils=1.6.1=h166bdaf_0
|
| 113 |
+
kiwisolver=1.4.2=pypi_0
|
| 114 |
+
kornia=0.6.5=pypi_0
|
| 115 |
+
krb5=1.19.3=h3790be6_0
|
| 116 |
+
lame=3.100=h7b6447c_0
|
| 117 |
+
lcms2=2.12=h3be6417_0
|
| 118 |
+
ld_impl_linux-64=2.38=h1181459_1
|
| 119 |
+
lerc=3.0=h9c3ff4c_0
|
| 120 |
+
libaec=1.0.6=h9c3ff4c_0
|
| 121 |
+
libavif=0.10.1=h166bdaf_0
|
| 122 |
+
libblas=3.9.0=12_linux64_mkl
|
| 123 |
+
libbrotlicommon=1.0.9=h166bdaf_7
|
| 124 |
+
libbrotlidec=1.0.9=h166bdaf_7
|
| 125 |
+
libbrotlienc=1.0.9=h166bdaf_7
|
| 126 |
+
libcap=2.64=ha37c62d_0
|
| 127 |
+
libcblas=3.9.0=12_linux64_mkl
|
| 128 |
+
libclang=14.0.6=default_h2e3cab8_0
|
| 129 |
+
libclang13=14.0.6=default_h3a83d3e_0
|
| 130 |
+
libcups=2.3.3=hf5a7f15_1
|
| 131 |
+
libcurl=7.83.1=h7bff187_0
|
| 132 |
+
libdb=6.2.32=h9c3ff4c_0
|
| 133 |
+
libdeflate=1.12=h166bdaf_0
|
| 134 |
+
libdrm=2.4.112=h166bdaf_0
|
| 135 |
+
libedit=3.1.20191231=he28a2e2_2
|
| 136 |
+
libev=4.33=h516909a_1
|
| 137 |
+
libevent=2.1.10=h9b69904_4
|
| 138 |
+
libffi=3.4.2=h7f98852_5
|
| 139 |
+
libflac=1.3.4=h27087fc_0
|
| 140 |
+
libgcc-ng=12.1.0=h8d9b700_16
|
| 141 |
+
libgfortran-ng=12.1.0=h69a702a_16
|
| 142 |
+
libgfortran5=12.1.0=hdcd56e2_16
|
| 143 |
+
libglib=2.70.2=h174f98d_4
|
| 144 |
+
libglu=9.0.0=he1b5a44_1001
|
| 145 |
+
libiconv=1.16=h7f8727e_2
|
| 146 |
+
libidn2=2.3.2=h7f8727e_0
|
| 147 |
+
liblapack=3.9.0=12_linux64_mkl
|
| 148 |
+
liblapacke=3.9.0=12_linux64_mkl
|
| 149 |
+
libllvm14=14.0.6=he0ac6c6_0
|
| 150 |
+
libnghttp2=1.47.0=h727a467_0
|
| 151 |
+
libnsl=2.0.0=h7f98852_0
|
| 152 |
+
libogg=1.3.4=h7f98852_1
|
| 153 |
+
libopencv=4.5.5=py310hcb97b83_13
|
| 154 |
+
libopus=1.3.1=h7f98852_1
|
| 155 |
+
libpciaccess=0.16=h516909a_0
|
| 156 |
+
libpng=1.6.37=hbc83047_0
|
| 157 |
+
libpq=14.4=hd77ab85_0
|
| 158 |
+
libprotobuf=3.20.1=h6239696_0
|
| 159 |
+
libsndfile=1.0.31=h9c3ff4c_1
|
| 160 |
+
libssh2=1.10.0=ha56f1ee_2
|
| 161 |
+
libstdcxx-ng=12.1.0=ha89aaad_16
|
| 162 |
+
libtasn1=4.16.0=h27cfd23_0
|
| 163 |
+
libtiff=4.4.0=hc85c160_1
|
| 164 |
+
libtool=2.4.6=h9c3ff4c_1008
|
| 165 |
+
libudev1=249=h166bdaf_4
|
| 166 |
+
libunistring=0.9.10=h27cfd23_0
|
| 167 |
+
libuuid=2.32.1=h7f98852_1000
|
| 168 |
+
libuv=1.40.0=h7b6447c_0
|
| 169 |
+
libva=2.15.0=h166bdaf_0
|
| 170 |
+
libvorbis=1.3.7=h9c3ff4c_0
|
| 171 |
+
libvpx=1.11.0=h9c3ff4c_3
|
| 172 |
+
libwebp=1.2.2=h55f646e_0
|
| 173 |
+
libwebp-base=1.2.2=h7f8727e_0
|
| 174 |
+
libxcb=1.13=h7f98852_1004
|
| 175 |
+
libxkbcommon=1.0.3=he3ba5ed_0
|
| 176 |
+
libxml2=2.9.14=h22db469_3
|
| 177 |
+
libzlib=1.2.12=h166bdaf_1
|
| 178 |
+
libzopfli=1.0.3=h9c3ff4c_0
|
| 179 |
+
linkify-it-py=1.0.3=pypi_0
|
| 180 |
+
llvm-openmp=14.0.4=he0ac6c6_0
|
| 181 |
+
locket=1.0.0=pyhd8ed1ab_0
|
| 182 |
+
lz4-c=1.9.3=h295c915_1
|
| 183 |
+
markdown-it-py=2.1.0=pypi_0
|
| 184 |
+
markupsafe=2.1.1=pypi_0
|
| 185 |
+
matplotlib=3.5.2=pypi_0
|
| 186 |
+
matplotlib-inline=0.1.3=pypi_0
|
| 187 |
+
mdit-py-plugins=0.3.0=pypi_0
|
| 188 |
+
mdurl=0.1.1=pypi_0
|
| 189 |
+
mistune=0.8.4=pypi_0
|
| 190 |
+
mkl=2021.4.0=h06a4308_640
|
| 191 |
+
mkl-service=2.4.0=py310h7f8727e_0
|
| 192 |
+
mkl_fft=1.3.1=py310hd6ae3a3_0
|
| 193 |
+
mkl_random=1.2.2=py310h00e6091_0
|
| 194 |
+
mltk=0.0.5=pypi_0
|
| 195 |
+
monotonic=1.6=pypi_0
|
| 196 |
+
multidict=6.0.2=pypi_0
|
| 197 |
+
munch=2.5.0=pypi_0
|
| 198 |
+
mysql-common=8.0.29=haf5c9bc_1
|
| 199 |
+
mysql-libs=8.0.29=h28c427c_1
|
| 200 |
+
nbclassic=0.3.7=pypi_0
|
| 201 |
+
nbclient=0.6.4=pypi_0
|
| 202 |
+
nbconvert=6.5.0=pypi_0
|
| 203 |
+
nbformat=5.4.0=pypi_0
|
| 204 |
+
ncurses=6.3=h7f8727e_2
|
| 205 |
+
nest-asyncio=1.5.5=pypi_0
|
| 206 |
+
nettle=3.7.3=hbbd107a_1
|
| 207 |
+
networkx=2.8.4=pyhd8ed1ab_0
|
| 208 |
+
nltk=3.7=pypi_0
|
| 209 |
+
notebook=6.4.11=pypi_0
|
| 210 |
+
notebook-shim=0.1.0=pypi_0
|
| 211 |
+
nspr=4.32=h9c3ff4c_1
|
| 212 |
+
nss=3.78=h2350873_0
|
| 213 |
+
ntk=1.1.3=pypi_0
|
| 214 |
+
numpy=1.22.3=py310hfa59a62_0
|
| 215 |
+
numpy-base=1.22.3=py310h9585f30_0
|
| 216 |
+
opencv=4.5.5=py310hff52083_13
|
| 217 |
+
opencv-python=4.6.0.66=pypi_0
|
| 218 |
+
openh264=2.1.1=h4ff587b_0
|
| 219 |
+
openjpeg=2.4.0=hb52868f_1
|
| 220 |
+
openssl=1.1.1q=h166bdaf_0
|
| 221 |
+
orjson=3.7.7=pypi_0
|
| 222 |
+
packaging=21.3=pyhd8ed1ab_0
|
| 223 |
+
pandas=1.4.2=pypi_0
|
| 224 |
+
pandocfilters=1.5.0=pypi_0
|
| 225 |
+
paramiko=2.11.0=pypi_0
|
| 226 |
+
parso=0.8.3=pypi_0
|
| 227 |
+
partd=1.2.0=pyhd8ed1ab_0
|
| 228 |
+
pcre=8.45=h9c3ff4c_0
|
| 229 |
+
pexpect=4.8.0=pypi_0
|
| 230 |
+
pickleshare=0.7.5=pypi_0
|
| 231 |
+
pillow=9.0.1=py310h22f2fdc_0
|
| 232 |
+
pip=21.2.4=py310h06a4308_0
|
| 233 |
+
pixman=0.40.0=h36c2ea0_0
|
| 234 |
+
portaudio=19.6.0=h57a0ea0_5
|
| 235 |
+
prometheus-client=0.14.1=pypi_0
|
| 236 |
+
prompt-toolkit=3.0.29=pypi_0
|
| 237 |
+
psutil=5.9.1=pypi_0
|
| 238 |
+
pthread-stubs=0.4=h36c2ea0_1001
|
| 239 |
+
ptyprocess=0.7.0=pypi_0
|
| 240 |
+
pulseaudio=14.0=h7f54b18_8
|
| 241 |
+
pure-eval=0.2.2=pypi_0
|
| 242 |
+
py-opencv=4.5.5=py310hfdc917e_13
|
| 243 |
+
pycocotools=2.0.4=pypi_0
|
| 244 |
+
pycparser=2.21=pyhd3eb1b0_0
|
| 245 |
+
pycryptodome=3.15.0=pypi_0
|
| 246 |
+
pydantic=1.9.1=pypi_0
|
| 247 |
+
pydub=0.25.1=pypi_0
|
| 248 |
+
pygments=2.12.0=pypi_0
|
| 249 |
+
pynacl=1.5.0=pypi_0
|
| 250 |
+
pyopenssl=22.0.0=pyhd3eb1b0_0
|
| 251 |
+
pyparsing=3.0.9=pyhd8ed1ab_0
|
| 252 |
+
pyrsistent=0.18.1=pypi_0
|
| 253 |
+
pysocks=1.7.1=py310h06a4308_0
|
| 254 |
+
python=3.10.5=h582c2e5_0_cpython
|
| 255 |
+
python-dateutil=2.8.2=pypi_0
|
| 256 |
+
python-multipart=0.0.5=pypi_0
|
| 257 |
+
python_abi=3.10=2_cp310
|
| 258 |
+
pytorch=1.11.0=py3.10_cuda10.2_cudnn7.6.5_0
|
| 259 |
+
pytorch-mutex=1.0=cuda
|
| 260 |
+
pytz=2022.1=pypi_0
|
| 261 |
+
pywavelets=1.3.0=py310hde88566_1
|
| 262 |
+
pyyaml=6.0=py310h5764c6d_4
|
| 263 |
+
pyzmq=23.1.0=pypi_0
|
| 264 |
+
qt-main=5.15.4=ha5833f6_2
|
| 265 |
+
qudida=0.0.4=pyhd8ed1ab_0
|
| 266 |
+
readline=8.1.2=h7f8727e_1
|
| 267 |
+
regex=2022.6.2=pypi_0
|
| 268 |
+
requests=2.27.1=pyhd3eb1b0_0
|
| 269 |
+
rfc3986=1.5.0=pypi_0
|
| 270 |
+
scikit-image=0.19.3=py310h769672d_0
|
| 271 |
+
scikit-learn=1.1.1=py310hffb9edd_0
|
| 272 |
+
scipy=1.8.1=py310h7612f91_0
|
| 273 |
+
seaborn=0.11.2=pypi_0
|
| 274 |
+
send2trash=1.8.0=pypi_0
|
| 275 |
+
setuptools=61.2.0=py310h06a4308_0
|
| 276 |
+
six=1.16.0=pyhd3eb1b0_1
|
| 277 |
+
snappy=1.1.9=hbd366e4_1
|
| 278 |
+
sniffio=1.2.0=pypi_0
|
| 279 |
+
soupsieve=2.3.2.post1=pypi_0
|
| 280 |
+
sqlite=3.39.0=h4ff8645_0
|
| 281 |
+
stack-data=0.2.0=pypi_0
|
| 282 |
+
starlette=0.19.1=pypi_0
|
| 283 |
+
svt-av1=1.1.0=h27087fc_1
|
| 284 |
+
termcolor=1.1.0=pypi_0
|
| 285 |
+
terminado=0.15.0=pypi_0
|
| 286 |
+
threadpoolctl=3.1.0=pyh8a188c0_0
|
| 287 |
+
tifffile=2022.5.4=pyhd8ed1ab_0
|
| 288 |
+
tinycss2=1.1.1=pypi_0
|
| 289 |
+
tk=8.6.12=h1ccaba5_0
|
| 290 |
+
toolz=0.11.2=pyhd8ed1ab_0
|
| 291 |
+
torchaudio=0.11.0=py310_cu102
|
| 292 |
+
torchvision=0.12.0=py310_cu102
|
| 293 |
+
tornado=6.1=pypi_0
|
| 294 |
+
tqdm=4.64.0=pypi_0
|
| 295 |
+
traitlets=5.2.2.post1=pypi_0
|
| 296 |
+
typing-extensions=4.1.1=hd3eb1b0_0
|
| 297 |
+
typing_extensions=4.1.1=pyh06a4308_0
|
| 298 |
+
tzdata=2022a=hda174b7_0
|
| 299 |
+
uc-micro-py=1.0.1=pypi_0
|
| 300 |
+
urllib3=1.26.9=py310h06a4308_0
|
| 301 |
+
uvicorn=0.18.2=pypi_0
|
| 302 |
+
wcwidth=0.2.5=pypi_0
|
| 303 |
+
webencodings=0.5.1=pypi_0
|
| 304 |
+
websocket-client=1.3.2=pypi_0
|
| 305 |
+
wheel=0.37.1=pyhd3eb1b0_0
|
| 306 |
+
x264=1!161.3030=h7f98852_1
|
| 307 |
+
x265=3.5=h924138e_3
|
| 308 |
+
xcb-util=0.4.0=h166bdaf_0
|
| 309 |
+
xcb-util-image=0.4.0=h166bdaf_0
|
| 310 |
+
xcb-util-keysyms=0.4.0=h166bdaf_0
|
| 311 |
+
xcb-util-renderutil=0.3.9=h166bdaf_0
|
| 312 |
+
xcb-util-wm=0.4.1=h166bdaf_0
|
| 313 |
+
xorg-fixesproto=5.0=h7f98852_1002
|
| 314 |
+
xorg-inputproto=2.3.2=h7f98852_1002
|
| 315 |
+
xorg-kbproto=1.0.7=h7f98852_1002
|
| 316 |
+
xorg-libice=1.0.10=h7f98852_0
|
| 317 |
+
xorg-libsm=1.2.3=hd9c2040_1000
|
| 318 |
+
xorg-libx11=1.7.2=h7f98852_0
|
| 319 |
+
xorg-libxau=1.0.9=h7f98852_0
|
| 320 |
+
xorg-libxdmcp=1.1.3=h7f98852_0
|
| 321 |
+
xorg-libxext=1.3.4=h7f98852_1
|
| 322 |
+
xorg-libxfixes=5.0.3=h7f98852_1004
|
| 323 |
+
xorg-libxi=1.7.10=h7f98852_0
|
| 324 |
+
xorg-libxrender=0.9.10=h7f98852_1003
|
| 325 |
+
xorg-renderproto=0.11.1=h7f98852_1002
|
| 326 |
+
xorg-xextproto=7.3.0=h7f98852_1002
|
| 327 |
+
xorg-xproto=7.0.31=h7f98852_1007
|
| 328 |
+
xz=5.2.5=h7f8727e_1
|
| 329 |
+
yaml=0.2.5=h7f98852_2
|
| 330 |
+
yarl=1.7.2=pypi_0
|
| 331 |
+
zfp=0.5.5=h9c3ff4c_8
|
| 332 |
+
zlib=1.2.12=h166bdaf_1
|
| 333 |
+
zlib-ng=2.0.6=h166bdaf_0
|
| 334 |
+
zstd=1.5.2=ha4553b6_0
|
src/.gitkeep
ADDED
|
File without changes
|
src/__pycache__/ds.cpython-310.pyc
ADDED
|
Binary file (14.6 kB). View file
|
|
|
src/__pycache__/losses.cpython-310.pyc
ADDED
|
Binary file (4.17 kB). View file
|
|
|
src/__pycache__/networks_SRGAN.cpython-310.pyc
ADDED
|
Binary file (6.99 kB). View file
|
|
|
src/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (34 kB). View file
|
|
|
src/app.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import random
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
from matplotlib import cm
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import torchvision.models as models
|
| 10 |
+
from torch.utils.data import Dataset, DataLoader
|
| 11 |
+
from torchvision import transforms
|
| 12 |
+
from torchvision.transforms.functional import InterpolationMode as IMode
|
| 13 |
+
|
| 14 |
+
from PIL import Image
|
| 15 |
+
|
| 16 |
+
from ds import *
|
| 17 |
+
from losses import *
|
| 18 |
+
from networks_SRGAN import *
|
| 19 |
+
from utils import *
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# --- Model setup (runs once at import time) ---
# The original code constructed Generator and BayesCap twice each; build each
# network once, report the parameter count, then load weights and freeze.

NetG = Generator()
# Count all generator parameters (the original filter(lambda p: True, ...)
# was a no-op filter and has been dropped).
params = sum(np.prod(p.size()) for p in NetG.parameters())
print("Number of Parameters:", params)

# Load the pretrained SRGAN generator and put it on the GPU in eval mode.
NetG.load_state_dict(torch.load('../ckpt/srgan-ImageNet-bc347d67.pth', map_location='cuda:0'))
NetG.to('cuda')
NetG.eval()

# BayesCap head that maps the frozen generator's output to per-pixel
# (mu, alpha, beta) parameters used for uncertainty estimation.
NetC = BayesCap(in_channels=3, out_channels=3)
NetC.load_state_dict(torch.load('../ckpt/BayesCap_SRGAN_best.pth', map_location='cuda:0'))
NetC.to('cuda')
NetC.eval()
|
| 38 |
+
|
| 39 |
+
def tensor01_to_pil(xt):
    """Convert a [0, 1]-valued image tensor (any singleton dims squeezed
    away first) into an RGB PIL image."""
    return transforms.ToPILImage(mode='RGB')(xt.squeeze())
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def predict(img):
    """Run the BayesCap super-resolution demo on a single image.

    Args:
        img: input image (PIL image or anything np.array accepts).

    Returns:
        Tuple of five PIL images: (low-res input, super-resolved output,
        alpha map, beta map, uncertainty map).
    """
    image_size = (256, 256)
    upscale_factor = 4
    # Downscale to simulate a low-resolution input (256/4 = 64 per side).
    lr_transforms = transforms.Resize((image_size[0]//upscale_factor, image_size[1]//upscale_factor), interpolation=IMode.BICUBIC, antialias=True)
    # lr_transforms = transforms.Resize((128, 128), interpolation=IMode.BICUBIC, antialias=True)

    img = Image.fromarray(np.array(img))
    img = lr_transforms(img)
    # NOTE(review): `utils` is never imported in this module; this resolves
    # only because `from ds import *` leaks ds's `import utils` — confirm and
    # add an explicit `import utils` when cleaning up.
    lr_tensor = utils.image2tensor(img, range_norm=False, half=False)

    device = 'cuda'
    dtype = torch.cuda.FloatTensor
    xLR = lr_tensor.to(device).unsqueeze(0)
    xLR = xLR.type(dtype)
    # pass them through the network
    with torch.no_grad():
        xSR = NetG(xLR)
        # BayesCap predicts per-pixel distribution parameters for xSR.
        xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)

    # 1/alpha acts as the scale map; 1e-5 guards against division by zero.
    a_map = (1/(xSRC_alpha[0] + 1e-5)).to('cpu').data
    b_map = xSRC_beta[0].to('cpu').data
    # Generalized-Gaussian variance: scale^2 * Gamma(3/beta) / Gamma(1/beta),
    # computed via lgamma for numerical stability (1e-2 avoids beta -> 0).
    u_map = (a_map**2)*(torch.exp(torch.lgamma(3/(b_map + 1e-2)))/torch.exp(torch.lgamma(1/(b_map + 1e-2))))

    # Reorder dims before PIL conversion — assumes the double transpose
    # yields what ToPILImage expects here; TODO confirm for 4-D xLR.
    x_LR = tensor01_to_pil(xLR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))

    x_mean = tensor01_to_pil(xSR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))

    #im = Image.fromarray(np.uint8(cm.gist_earth(myarray)*255))

    # Clamp then min-max normalize each map so the colormap spans the full
    # displayed range; clamp bounds are empirically chosen display limits.
    a_map = torch.clamp(a_map, min=0, max=0.1)
    a_map = (a_map - a_map.min())/(a_map.max() - a_map.min())
    x_alpha = Image.fromarray(np.uint8(cm.inferno(a_map.transpose(0,2).transpose(0,1).squeeze())*255))

    b_map = torch.clamp(b_map, min=0.45, max=0.75)
    b_map = (b_map - b_map.min())/(b_map.max() - b_map.min())
    x_beta = Image.fromarray(np.uint8(cm.cividis(b_map.transpose(0,2).transpose(0,1).squeeze())*255))

    u_map = torch.clamp(u_map, min=0, max=0.15)
    u_map = (u_map - u_map.min())/(u_map.max() - u_map.min())
    x_uncer = Image.fromarray(np.uint8(cm.hot(u_map.transpose(0,2).transpose(0,1).squeeze())*255))

    return x_LR, x_mean, x_alpha, x_beta, x_uncer
|
| 90 |
+
|
| 91 |
+
import gradio as gr

title = "BayesCap"
description = "BayesCap: Bayesian Identity Cap for Calibrated Uncertainty in Frozen Neural Networks (ECCV 2022)"
article = "<p style='text-align: center'> BayesCap: Bayesian Identity Cap for Calibrated Uncertainty in Frozen Neural Networks| <a href='https://github.com/ExplainableML/BayesCap'>Github Repo</a></p>"

# Build and launch the demo UI: one input image, five output images.
# NOTE(review): gr.inputs / gr.outputs are the legacy pre-3.x Gradio API and
# were removed in later releases — migrate to gr.Image(...) when upgrading.
gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(type='pil', label="Orignal"),
    outputs=[
        gr.outputs.Image(type='pil', label="Low-res"),
        gr.outputs.Image(type='pil', label="Super-res"),
        gr.outputs.Image(type='pil', label="Alpha"),
        gr.outputs.Image(type='pil', label="Beta"),
        gr.outputs.Image(type='pil', label="Uncertainty")
    ],
    title=title,
    description=description,
    article=article,
    examples=[
        ["../demo_examples/baby.png"],
        ["../demo_examples/bird.png"]
    ]
).launch(share=True)
|
src/ds.py
ADDED
|
@@ -0,0 +1,485 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import absolute_import, division, print_function
|
| 2 |
+
|
| 3 |
+
import random
|
| 4 |
+
import copy
|
| 5 |
+
import io
|
| 6 |
+
import os
|
| 7 |
+
import numpy as np
|
| 8 |
+
from PIL import Image
|
| 9 |
+
import skimage.transform
|
| 10 |
+
from collections import Counter
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.utils.data as data
|
| 15 |
+
from torch import Tensor
|
| 16 |
+
from torch.utils.data import Dataset
|
| 17 |
+
from torchvision import transforms
|
| 18 |
+
from torchvision.transforms.functional import InterpolationMode as IMode
|
| 19 |
+
|
| 20 |
+
import utils
|
| 21 |
+
|
| 22 |
+
class ImgDset(Dataset):
    """Customize the data set loading function and prepare low/high resolution image data in advance.

    Args:
        dataroot (str): Training data set address
        image_size (int): High resolution image size
        upscale_factor (int): Image magnification
        mode (str): Data set loading method, the training data set is for data enhancement,
            and the verification data set is not for data enhancement
    """

    def __init__(self, dataroot: str, image_size: int, upscale_factor: int, mode: str) -> None:
        super(ImgDset, self).__init__()
        # Every entry directly under `dataroot` is treated as an image file.
        self.filenames = [os.path.join(dataroot, x) for x in os.listdir(dataroot)]

        if mode == "train":
            # Training: random crop/rotation/flip augmentations on the HR image.
            self.hr_transforms = transforms.Compose([
                transforms.RandomCrop(image_size),
                transforms.RandomRotation(90),
                transforms.RandomHorizontalFlip(0.5),
            ])
        else:
            # Validation: deterministic resize only.
            self.hr_transforms = transforms.Resize(image_size)

        # NOTE(review): image_size is indexed here, so callers must actually
        # pass a (height, width) tuple despite the `int` annotation — confirm.
        self.lr_transforms = transforms.Resize((image_size[0]//upscale_factor, image_size[1]//upscale_factor), interpolation=IMode.BICUBIC, antialias=True)

    def __getitem__(self, batch_index: int) -> tuple[Tensor, Tensor]:
        # Read a batch of image data
        image = Image.open(self.filenames[batch_index])

        # Transform image: the LR image is derived from the (augmented) HR one.
        hr_image = self.hr_transforms(image)
        lr_image = self.lr_transforms(hr_image)

        # Convert image data into Tensor stream format (PyTorch).
        # Note: The range of input and output is between [0, 1]
        lr_tensor = utils.image2tensor(lr_image, range_norm=False, half=False)
        hr_tensor = utils.image2tensor(hr_image, range_norm=False, half=False)

        return lr_tensor, hr_tensor

    def __len__(self) -> int:
        return len(self.filenames)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class PairedImages_w_nameList(Dataset):
    """Dataset of image pairs drawn from two parallel path lists.

    Can act as supervised or un-supervised depending on how the two lists
    are built. Images are loaded as RGB, converted to [0, 1] tensors, and
    then passed through the optional per-branch transforms.
    """

    def __init__(self, flist1, flist2, transform1=None, transform2=None, do_aug=False):
        self.flist1 = flist1
        self.flist2 = flist2
        self.transform1 = transform1
        self.transform2 = transform2
        self.do_aug = do_aug

    def __getitem__(self, index):
        # Load and convert both branches identically.
        pair = []
        for flist, transform in ((self.flist1, self.transform1),
                                 (self.flist2, self.transform2)):
            image = Image.open(flist[index]).convert('RGB')
            tensor = utils.image2tensor(image, range_norm=False, half=False)
            if transform is not None:
                tensor = transform(tensor)
            pair.append(tensor)
        return pair[0], pair[1]

    def __len__(self):
        return len(self.flist1)
|
| 95 |
+
|
| 96 |
+
class PairedImages_w_nameList_npy(Dataset):
    """Paired .npy loader: item i is (np.load(flist1[i]), np.load(flist2[i])).

    Can act as supervised or un-supervised depending on the file lists;
    optional transforms are applied independently to each branch.
    """

    def __init__(self, flist1, flist2, transform1=None, transform2=None, do_aug=False):
        self.flist1 = flist1
        self.flist2 = flist2
        self.transform1 = transform1
        self.transform2 = transform2
        self.do_aug = do_aug

    def __getitem__(self, index):
        arr1 = np.load(self.flist1[index])
        arr2 = np.load(self.flist2[index])
        # Branch-specific transforms, when provided.
        arr1 = arr1 if self.transform1 is None else self.transform1(arr1)
        arr2 = arr2 if self.transform2 is None else self.transform2(arr2)
        return arr1, arr2

    def __len__(self):
        return len(self.flist1)
|
| 120 |
+
|
| 121 |
+
# def call_paired():
|
| 122 |
+
# root1='./GOPRO_3840FPS_AVG_3-21/train/blur/'
|
| 123 |
+
# root2='./GOPRO_3840FPS_AVG_3-21/train/sharp/'
|
| 124 |
+
|
| 125 |
+
# flist1=glob.glob(root1+'/*/*.png')
|
| 126 |
+
# flist2=glob.glob(root2+'/*/*.png')
|
| 127 |
+
|
| 128 |
+
# dset = PairedImages_w_nameList(root1,root2,flist1,flist2)
|
| 129 |
+
|
| 130 |
+
#### KITTI depth
|
| 131 |
+
|
| 132 |
+
def load_velodyne_points(filename):
    """Load 3D point cloud from KITTI file format
    (adapted from https://github.com/hunse/kitti)
    """
    pts = np.fromfile(filename, dtype=np.float32).reshape(-1, 4)
    # Overwrite the reflectance column with 1 so every row is a
    # homogeneous coordinate (x, y, z, 1).
    pts[:, -1] = 1.0
    return pts
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def read_calib_file(path):
    """Read KITTI calibration file
    (from https://github.com/hunse/kitti)
    """
    numeric_chars = set("0123456789.e+- ")
    calib = {}
    with open(path, 'r') as fh:
        for raw_line in fh.readlines():
            key, value = raw_line.split(':', 1)
            value = value.strip()
            # Keep the raw string; upgrade to a float array below when the
            # value looks numeric and actually parses.
            calib[key] = value
            if numeric_chars.issuperset(value):
                try:
                    calib[key] = np.array(list(map(float, value.split(' '))))
                except ValueError:
                    pass
    return calib
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def sub2ind(matrixSize, rowSub, colSub):
    """Convert row, col matrix subscripts to linear indices
    """
    _, num_cols = matrixSize
    # Mirrors the KITTI MATLAB code (1-based indexing, hence the offsets).
    return rowSub * (num_cols - 1) + colSub - 1
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def generate_depth_map(calib_dir, velo_filename, cam=2, vel_depth=False):
    """Generate a depth map from velodyne data.

    Args:
        calib_dir: directory holding the KITTI calibration files.
        velo_filename: path to the velodyne .bin point-cloud file.
        cam: camera index (2 = left color camera, 3 = right color camera).
        vel_depth: if True, store the raw forward (x) velodyne coordinate
            as depth instead of the rectified camera z coordinate.

    Returns:
        2D numpy array of per-pixel depth (0 where there is no lidar return).
    """
    # load calibration files
    cam2cam = read_calib_file(os.path.join(calib_dir, 'calib_cam_to_cam.txt'))
    velo2cam = read_calib_file(os.path.join(calib_dir, 'calib_velo_to_cam.txt'))
    velo2cam = np.hstack((velo2cam['R'].reshape(3, 3), velo2cam['T'][..., np.newaxis]))
    velo2cam = np.vstack((velo2cam, np.array([0, 0, 0, 1.0])))

    # get image shape (S_rect_02 is width, height -> reverse to rows, cols)
    im_shape = cam2cam["S_rect_02"][::-1].astype(np.int32)

    # compute projection matrix velodyne->image plane
    R_cam2rect = np.eye(4)
    R_cam2rect[:3, :3] = cam2cam['R_rect_00'].reshape(3, 3)
    P_rect = cam2cam['P_rect_0'+str(cam)].reshape(3, 4)
    P_velo2im = np.dot(np.dot(P_rect, R_cam2rect), velo2cam)

    # load velodyne points and remove all behind image plane (approximation)
    # each row of the velodyne data is forward, left, up, reflectance
    velo = load_velodyne_points(velo_filename)
    velo = velo[velo[:, 0] >= 0, :]

    # project the points to the camera and normalize by depth
    velo_pts_im = np.dot(P_velo2im, velo.T).T
    velo_pts_im[:, :2] = velo_pts_im[:, :2] / velo_pts_im[:, 2][..., np.newaxis]

    if vel_depth:
        velo_pts_im[:, 2] = velo[:, 0]

    # check if in bounds
    # use minus 1 to get the exact same value as KITTI matlab code
    velo_pts_im[:, 0] = np.round(velo_pts_im[:, 0]) - 1
    velo_pts_im[:, 1] = np.round(velo_pts_im[:, 1]) - 1
    val_inds = (velo_pts_im[:, 0] >= 0) & (velo_pts_im[:, 1] >= 0)
    val_inds = val_inds & (velo_pts_im[:, 0] < im_shape[1]) & (velo_pts_im[:, 1] < im_shape[0])
    velo_pts_im = velo_pts_im[val_inds, :]

    # project to image
    depth = np.zeros((im_shape[:2]))
    # BUGFIX: np.int was deprecated in NumPy 1.20 and removed in 1.24;
    # the builtin int is the documented replacement and is equivalent here.
    depth[velo_pts_im[:, 1].astype(int), velo_pts_im[:, 0].astype(int)] = velo_pts_im[:, 2]

    # find the duplicate points and choose the closest depth
    inds = sub2ind(depth.shape, velo_pts_im[:, 1], velo_pts_im[:, 0])
    dupe_inds = [item for item, count in Counter(inds).items() if count > 1]
    for dd in dupe_inds:
        pts = np.where(inds == dd)[0]
        x_loc = int(velo_pts_im[pts[0], 0])
        y_loc = int(velo_pts_im[pts[0], 1])
        depth[y_loc, x_loc] = velo_pts_im[pts, 2].min()
    # negative depths are invalid returns
    depth[depth < 0] = 0

    return depth
|
| 223 |
+
|
| 224 |
+
def pil_loader(path):
    """Open an image file and return it as an RGB PIL image.

    Opening through an explicit file object avoids a ResourceWarning
    (https://github.com/python-pillow/Pillow/issues/835).
    """
    with open(path, 'rb') as fh, Image.open(fh) as img:
        return img.convert('RGB')
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
class MonoDataset(data.Dataset):
|
| 233 |
+
"""Superclass for monocular dataloaders
|
| 234 |
+
|
| 235 |
+
Args:
|
| 236 |
+
data_path
|
| 237 |
+
filenames
|
| 238 |
+
height
|
| 239 |
+
width
|
| 240 |
+
frame_idxs
|
| 241 |
+
num_scales
|
| 242 |
+
is_train
|
| 243 |
+
img_ext
|
| 244 |
+
"""
|
| 245 |
+
def __init__(self,
|
| 246 |
+
data_path,
|
| 247 |
+
filenames,
|
| 248 |
+
height,
|
| 249 |
+
width,
|
| 250 |
+
frame_idxs,
|
| 251 |
+
num_scales,
|
| 252 |
+
is_train=False,
|
| 253 |
+
img_ext='.jpg'):
|
| 254 |
+
super(MonoDataset, self).__init__()
|
| 255 |
+
|
| 256 |
+
self.data_path = data_path
|
| 257 |
+
self.filenames = filenames
|
| 258 |
+
self.height = height
|
| 259 |
+
self.width = width
|
| 260 |
+
self.num_scales = num_scales
|
| 261 |
+
self.interp = Image.ANTIALIAS
|
| 262 |
+
|
| 263 |
+
self.frame_idxs = frame_idxs
|
| 264 |
+
|
| 265 |
+
self.is_train = is_train
|
| 266 |
+
self.img_ext = img_ext
|
| 267 |
+
|
| 268 |
+
self.loader = pil_loader
|
| 269 |
+
self.to_tensor = transforms.ToTensor()
|
| 270 |
+
|
| 271 |
+
# We need to specify augmentations differently in newer versions of torchvision.
|
| 272 |
+
# We first try the newer tuple version; if this fails we fall back to scalars
|
| 273 |
+
try:
|
| 274 |
+
self.brightness = (0.8, 1.2)
|
| 275 |
+
self.contrast = (0.8, 1.2)
|
| 276 |
+
self.saturation = (0.8, 1.2)
|
| 277 |
+
self.hue = (-0.1, 0.1)
|
| 278 |
+
transforms.ColorJitter.get_params(
|
| 279 |
+
self.brightness, self.contrast, self.saturation, self.hue)
|
| 280 |
+
except TypeError:
|
| 281 |
+
self.brightness = 0.2
|
| 282 |
+
self.contrast = 0.2
|
| 283 |
+
self.saturation = 0.2
|
| 284 |
+
self.hue = 0.1
|
| 285 |
+
|
| 286 |
+
self.resize = {}
|
| 287 |
+
for i in range(self.num_scales):
|
| 288 |
+
s = 2 ** i
|
| 289 |
+
self.resize[i] = transforms.Resize((self.height // s, self.width // s),
|
| 290 |
+
interpolation=self.interp)
|
| 291 |
+
|
| 292 |
+
self.load_depth = self.check_depth()
|
| 293 |
+
|
| 294 |
+
def preprocess(self, inputs, color_aug):
|
| 295 |
+
"""Resize colour images to the required scales and augment if required
|
| 296 |
+
|
| 297 |
+
We create the color_aug object in advance and apply the same augmentation to all
|
| 298 |
+
images in this item. This ensures that all images input to the pose network receive the
|
| 299 |
+
same augmentation.
|
| 300 |
+
"""
|
| 301 |
+
for k in list(inputs):
|
| 302 |
+
frame = inputs[k]
|
| 303 |
+
if "color" in k:
|
| 304 |
+
n, im, i = k
|
| 305 |
+
for i in range(self.num_scales):
|
| 306 |
+
inputs[(n, im, i)] = self.resize[i](inputs[(n, im, i - 1)])
|
| 307 |
+
|
| 308 |
+
for k in list(inputs):
|
| 309 |
+
f = inputs[k]
|
| 310 |
+
if "color" in k:
|
| 311 |
+
n, im, i = k
|
| 312 |
+
inputs[(n, im, i)] = self.to_tensor(f)
|
| 313 |
+
inputs[(n + "_aug", im, i)] = self.to_tensor(color_aug(f))
|
| 314 |
+
|
| 315 |
+
def __len__(self):
|
| 316 |
+
return len(self.filenames)
|
| 317 |
+
|
| 318 |
+
def __getitem__(self, index):
|
| 319 |
+
"""Returns a single training item from the dataset as a dictionary.
|
| 320 |
+
|
| 321 |
+
Values correspond to torch tensors.
|
| 322 |
+
Keys in the dictionary are either strings or tuples:
|
| 323 |
+
|
| 324 |
+
("color", <frame_id>, <scale>) for raw colour images,
|
| 325 |
+
("color_aug", <frame_id>, <scale>) for augmented colour images,
|
| 326 |
+
("K", scale) or ("inv_K", scale) for camera intrinsics,
|
| 327 |
+
"stereo_T" for camera extrinsics, and
|
| 328 |
+
"depth_gt" for ground truth depth maps.
|
| 329 |
+
|
| 330 |
+
<frame_id> is either:
|
| 331 |
+
an integer (e.g. 0, -1, or 1) representing the temporal step relative to 'index',
|
| 332 |
+
or
|
| 333 |
+
"s" for the opposite image in the stereo pair.
|
| 334 |
+
|
| 335 |
+
<scale> is an integer representing the scale of the image relative to the fullsize image:
|
| 336 |
+
-1 images at native resolution as loaded from disk
|
| 337 |
+
0 images resized to (self.width, self.height )
|
| 338 |
+
1 images resized to (self.width // 2, self.height // 2)
|
| 339 |
+
2 images resized to (self.width // 4, self.height // 4)
|
| 340 |
+
3 images resized to (self.width // 8, self.height // 8)
|
| 341 |
+
"""
|
| 342 |
+
inputs = {}
|
| 343 |
+
|
| 344 |
+
do_color_aug = self.is_train and random.random() > 0.5
|
| 345 |
+
do_flip = self.is_train and random.random() > 0.5
|
| 346 |
+
|
| 347 |
+
line = self.filenames[index].split()
|
| 348 |
+
folder = line[0]
|
| 349 |
+
|
| 350 |
+
if len(line) == 3:
|
| 351 |
+
frame_index = int(line[1])
|
| 352 |
+
else:
|
| 353 |
+
frame_index = 0
|
| 354 |
+
|
| 355 |
+
if len(line) == 3:
|
| 356 |
+
side = line[2]
|
| 357 |
+
else:
|
| 358 |
+
side = None
|
| 359 |
+
|
| 360 |
+
for i in self.frame_idxs:
|
| 361 |
+
if i == "s":
|
| 362 |
+
other_side = {"r": "l", "l": "r"}[side]
|
| 363 |
+
inputs[("color", i, -1)] = self.get_color(folder, frame_index, other_side, do_flip)
|
| 364 |
+
else:
|
| 365 |
+
inputs[("color", i, -1)] = self.get_color(folder, frame_index + i, side, do_flip)
|
| 366 |
+
|
| 367 |
+
# adjusting intrinsics to match each scale in the pyramid
|
| 368 |
+
for scale in range(self.num_scales):
|
| 369 |
+
K = self.K.copy()
|
| 370 |
+
|
| 371 |
+
K[0, :] *= self.width // (2 ** scale)
|
| 372 |
+
K[1, :] *= self.height // (2 ** scale)
|
| 373 |
+
|
| 374 |
+
inv_K = np.linalg.pinv(K)
|
| 375 |
+
|
| 376 |
+
inputs[("K", scale)] = torch.from_numpy(K)
|
| 377 |
+
inputs[("inv_K", scale)] = torch.from_numpy(inv_K)
|
| 378 |
+
|
| 379 |
+
if do_color_aug:
|
| 380 |
+
color_aug = transforms.ColorJitter.get_params(
|
| 381 |
+
self.brightness, self.contrast, self.saturation, self.hue)
|
| 382 |
+
else:
|
| 383 |
+
color_aug = (lambda x: x)
|
| 384 |
+
|
| 385 |
+
self.preprocess(inputs, color_aug)
|
| 386 |
+
|
| 387 |
+
for i in self.frame_idxs:
|
| 388 |
+
del inputs[("color", i, -1)]
|
| 389 |
+
del inputs[("color_aug", i, -1)]
|
| 390 |
+
|
| 391 |
+
if self.load_depth:
|
| 392 |
+
depth_gt = self.get_depth(folder, frame_index, side, do_flip)
|
| 393 |
+
inputs["depth_gt"] = np.expand_dims(depth_gt, 0)
|
| 394 |
+
inputs["depth_gt"] = torch.from_numpy(inputs["depth_gt"].astype(np.float32))
|
| 395 |
+
|
| 396 |
+
if "s" in self.frame_idxs:
|
| 397 |
+
stereo_T = np.eye(4, dtype=np.float32)
|
| 398 |
+
baseline_sign = -1 if do_flip else 1
|
| 399 |
+
side_sign = -1 if side == "l" else 1
|
| 400 |
+
stereo_T[0, 3] = side_sign * baseline_sign * 0.1
|
| 401 |
+
|
| 402 |
+
inputs["stereo_T"] = torch.from_numpy(stereo_T)
|
| 403 |
+
|
| 404 |
+
return inputs
|
| 405 |
+
|
| 406 |
+
def get_color(self, folder, frame_index, side, do_flip):
    """Return the color image for one frame; subclasses must implement."""
    raise NotImplementedError
|
| 408 |
+
|
| 409 |
+
def check_depth(self):
    """Return True when ground-truth depth is available; subclasses must implement."""
    raise NotImplementedError
|
| 411 |
+
|
| 412 |
+
def get_depth(self, folder, frame_index, side, do_flip):
    """Return the ground-truth depth map for one frame; subclasses must implement."""
    raise NotImplementedError
|
| 414 |
+
|
| 415 |
+
class KITTIDataset(MonoDataset):
    """Base class shared by the KITTI dataset loader variants.

    Holds the normalized camera intrinsics, the native KITTI resolution and
    the stereo-side lookup table, plus generic helpers for loading color
    frames and probing for ground-truth depth availability.
    """

    def __init__(self, *args, **kwargs):
        super(KITTIDataset, self).__init__(*args, **kwargs)

        # The intrinsics matrix is *normalized* by the original image size:
        # row 0 is scaled by 1 / image_width and row 1 by 1 / image_height.
        # Monodepth2 assumes the principal point sits exactly at the image
        # center; if yours is far from the center, consider disabling the
        # horizontal-flip augmentation.
        self.K = np.array(
            [[0.58, 0, 0.5, 0],
             [0, 1.92, 0.5, 0],
             [0, 0, 1, 0],
             [0, 0, 0, 1]],
            dtype=np.float32,
        )

        # Native KITTI resolution as (width, height).
        self.full_res_shape = (1242, 375)
        # Map textual side identifiers onto KITTI camera indices.
        self.side_map = {"2": 2, "3": 3, "l": 2, "r": 3}

    def check_depth(self):
        """Return True when a velodyne file exists for the first split entry."""
        scene_name, frame, *_ = self.filenames[0].split()
        velo_filename = os.path.join(
            self.data_path,
            scene_name,
            "velodyne_points/data/{:010d}.bin".format(int(frame)))
        return os.path.isfile(velo_filename)

    def get_color(self, folder, frame_index, side, do_flip):
        """Load one RGB frame, mirroring it horizontally when requested."""
        color = self.loader(self.get_image_path(folder, frame_index, side))
        if do_flip:
            color = color.transpose(Image.FLIP_LEFT_RIGHT)
        return color
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
class KITTIDepthDataset(KITTIDataset):
    """KITTI loader backed by the updated (improved) ground-truth depth maps."""

    def __init__(self, *args, **kwargs):
        super(KITTIDepthDataset, self).__init__(*args, **kwargs)

    def get_image_path(self, folder, frame_index, side):
        """Build the on-disk path of the RGB frame for folder/index/side."""
        frame_name = "{:010d}{}".format(frame_index, self.img_ext)
        return os.path.join(
            self.data_path,
            folder,
            "image_0{}/data".format(self.side_map[side]),
            frame_name)

    def get_depth(self, folder, frame_index, side, do_flip):
        """Load a ground-truth depth map resized to the full KITTI resolution.

        The stored PNG values are divided by 256 on load; the map is mirrored
        horizontally when ``do_flip`` is set.
        """
        frame_name = "{:010d}.png".format(frame_index)
        depth_path = os.path.join(
            self.data_path,
            folder,
            "proj_depth/groundtruth/image_0{}".format(self.side_map[side]),
            frame_name)

        depth_gt = Image.open(depth_path)
        depth_gt = depth_gt.resize(self.full_res_shape, Image.NEAREST)
        depth_gt = np.array(depth_gt).astype(np.float32) / 256
        if do_flip:
            depth_gt = np.fliplr(depth_gt)
        return depth_gt
|
src/flagged/Alpha/0.png
ADDED
|
src/flagged/Beta/0.png
ADDED
|
src/flagged/Low-res/0.png
ADDED
|
src/flagged/Orignal/0.png
ADDED
|
src/flagged/Super-res/0.png
ADDED
|
src/flagged/Uncertainty/0.png
ADDED
|
src/flagged/log.csv
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'Orignal','Low-res','Super-res','Alpha','Beta','Uncertainty','flag','username','timestamp'
|
| 2 |
+
'Orignal/0.png','Low-res/0.png','Super-res/0.png','Alpha/0.png','Beta/0.png','Uncertainty/0.png','','','2022-07-09 14:01:12.964411'
|
src/losses.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import torchvision.models as models
|
| 5 |
+
from torch import Tensor
|
| 6 |
+
|
| 7 |
+
class ContentLoss(nn.Module):
    """VGG19-based perceptual (content) loss.

    Feature maps are taken from a late layer of an ImageNet-pretrained VGG19,
    so the loss emphasizes texture/content rather than per-pixel differences.

    Paper reference list:
    -`Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network <https://arxiv.org/pdf/1609.04802.pdf>` paper.
    -`ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks <https://arxiv.org/pdf/1809.00219.pdf>` paper.
    -`Perceptual Extreme Super Resolution Network with Receptive Field Block <https://arxiv.org/pdf/2005.12597.pdf>` paper.
    """

    def __init__(self) -> None:
        super(ContentLoss, self).__init__()
        # ImageNet-pretrained VGG19, used frozen and in inference mode.
        vgg19 = models.vgg19(pretrained=True).eval()
        # Keep the feature stack up to (excluding) layer 36 as the extractor.
        self.feature_extractor = nn.Sequential(*list(vgg19.features.children())[:36])
        # Freeze the extractor so no gradients flow into VGG.
        for parameters in self.feature_extractor.parameters():
            parameters.requires_grad = False

        # ImageNet normalization constants expected by torchvision VGG models.
        self.register_buffer("mean", torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, sr: Tensor, hr: Tensor) -> Tensor:
        # Normalize both images the way VGG expects.
        sr_norm = sr.sub(self.mean).div(self.std)
        hr_norm = hr.sub(self.mean).div(self.std)
        # L1 distance between the two feature maps.
        return F.l1_loss(self.feature_extractor(sr_norm), self.feature_extractor(hr_norm))
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class GenGaussLoss(nn.Module):
    """Negative log-likelihood of a generalized Gaussian error model.

    ``one_over_alpha`` (inverse scale) and ``beta`` (shape) are offset by
    small epsilons for numerical stability, and the scaled residual is
    clamped to ``[resi_min, resi_max]``. Returns None (after printing a
    diagnostic) when NaNs appear or the reduction is unsupported.
    """

    def __init__(
        self, reduction='mean',
        alpha_eps=1e-4, beta_eps=1e-4,
        resi_min=1e-4, resi_max=1e3
    ) -> None:
        super(GenGaussLoss, self).__init__()
        self.reduction = reduction
        self.alpha_eps = alpha_eps
        self.beta_eps = beta_eps
        self.resi_min = resi_min
        self.resi_max = resi_max

    def forward(
        self,
        mean: Tensor, one_over_alpha: Tensor, beta: Tensor, target: Tensor
    ):
        # Stabilized inverse-scale and shape parameters.
        inv_alpha = one_over_alpha + self.alpha_eps
        shape = beta + self.beta_eps

        # Scaled absolute residual, clamped for numerical stability.
        residual = torch.abs(mean - target)
        residual = (residual * inv_alpha * shape).clamp(min=self.resi_min, max=self.resi_max)
        # Bail out (diagnostic print, None return) when the residual has NaNs.
        if torch.isnan(residual).any():
            print('resi has nans!!')
            return None

        log_inv_alpha = torch.log(inv_alpha)
        log_shape = torch.log(shape)
        lgamma_term = torch.lgamma(torch.pow(shape, -1))

        # Diagnostics only; these do not abort the computation.
        if torch.isnan(log_inv_alpha).any():
            print('log_one_over_alpha has nan')
        if torch.isnan(lgamma_term).any():
            print('lgamma_beta has nan')
        if torch.isnan(log_shape).any():
            print('log_beta has nan')

        nll = residual - log_inv_alpha + lgamma_term - log_shape

        if self.reduction == 'mean':
            return nll.mean()
        elif self.reduction == 'sum':
            return nll.sum()
        else:
            print('Reduction not supported')
            return None
|
| 91 |
+
|
| 92 |
+
class TempCombLoss(nn.Module):
    """Temperature-weighted combination of an L1 term and a GenGauss NLL term.

    ``forward`` returns ``T1 * L1(mean, target) + T2 * GenGaussLoss(...)``.
    """

    def __init__(
        self, reduction='mean',
        alpha_eps=1e-4, beta_eps=1e-4,
        resi_min=1e-4, resi_max=1e3
    ) -> None:
        super(TempCombLoss, self).__init__()
        self.reduction = reduction
        self.alpha_eps = alpha_eps
        self.beta_eps = beta_eps
        self.resi_min = resi_min
        self.resi_max = resi_max

        # Both loss terms share the reduction and stability settings.
        self.L_GenGauss = GenGaussLoss(
            reduction=self.reduction,
            alpha_eps=self.alpha_eps, beta_eps=self.beta_eps,
            resi_min=self.resi_min, resi_max=self.resi_max
        )
        self.L_l1 = nn.L1Loss(reduction=self.reduction)

    def forward(
        self,
        mean: Tensor, one_over_alpha: Tensor, beta: Tensor, target: Tensor,
        T1: float, T2: float
    ):
        fidelity = self.L_l1(mean, target)
        nll = self.L_GenGauss(mean, one_over_alpha, beta, target)
        return T1 * fidelity + T2 * nll
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
# x1 = torch.randn(4,3,32,32)
|
| 125 |
+
# x2 = torch.rand(4,3,32,32)
|
| 126 |
+
# x3 = torch.rand(4,3,32,32)
|
| 127 |
+
# x4 = torch.randn(4,3,32,32)
|
| 128 |
+
|
| 129 |
+
# L = GenGaussLoss(alpha_eps=1e-4, beta_eps=1e-4, resi_min=1e-4, resi_max=1e3)
|
| 130 |
+
# L2 = TempCombLoss(alpha_eps=1e-4, beta_eps=1e-4, resi_min=1e-4, resi_max=1e3)
|
| 131 |
+
# print(L(x1, x2, x3, x4), L2(x1, x2, x3, x4, 1e0, 1e-2))
|
src/networks_SRGAN.py
ADDED
|
@@ -0,0 +1,347 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import torchvision.models as models
|
| 5 |
+
from torch import Tensor
|
| 6 |
+
|
| 7 |
+
# __all__ = [
|
| 8 |
+
# "ResidualConvBlock",
|
| 9 |
+
# "Discriminator", "Generator",
|
| 10 |
+
# ]
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class ResidualConvBlock(nn.Module):
    """Residual block: conv/BN/PReLU/conv/BN with an identity skip connection.

    Args:
        channels (int): Number of channels in the input image.
    """

    def __init__(self, channels: int) -> None:
        super(ResidualConvBlock, self).__init__()
        self.rcb = nn.Sequential(
            nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(channels),
            nn.PReLU(),
            nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(channels),
        )

    def forward(self, x: Tensor) -> Tensor:
        # Residual connection: add the block input back onto its output.
        return self.rcb(x) + x
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class Discriminator(nn.Module):
    """SRGAN discriminator: VGG-style conv stack followed by an MLP head.

    Designed for 3 x 96 x 96 inputs (the classifier expects 512 x 6 x 6
    features); emits one unbounded logit per image.
    """

    def __init__(self) -> None:
        super(Discriminator, self).__init__()

        def _conv(cin, cout, stride, with_bn=True):
            # conv (+ optional BN) + LeakyReLU building block
            layers = [nn.Conv2d(cin, cout, (3, 3), (stride, stride), (1, 1), bias=False)]
            if with_bn:
                layers.append(nn.BatchNorm2d(cout))
            layers.append(nn.LeakyReLU(0.2, True))
            return layers

        self.features = nn.Sequential(
            # input size. (3) x 96 x 96
            *_conv(3, 64, 1, with_bn=False),
            # state size. (64) x 48 x 48
            *_conv(64, 64, 2),
            *_conv(64, 128, 1),
            # state size. (128) x 24 x 24
            *_conv(128, 128, 2),
            *_conv(128, 256, 1),
            # state size. (256) x 12 x 12
            *_conv(256, 256, 2),
            *_conv(256, 512, 1),
            # state size. (512) x 6 x 6
            *_conv(512, 512, 2),
        )

        self.classifier = nn.Sequential(
            nn.Linear(512 * 6 * 6, 1024),
            nn.LeakyReLU(0.2, True),
            nn.Linear(1024, 1),
        )

    def forward(self, x: Tensor) -> Tensor:
        feats = self.features(x)
        return self.classifier(torch.flatten(feats, 1))
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class Generator(nn.Module):
    """SRGAN generator: 9x9 stem, 16 residual blocks, x4 pixel-shuffle upscale.

    ``forward`` optionally applies 2-D dropout with probability ``dop`` after
    the second conv block (used for Monte-Carlo style stochastic sampling).
    """

    def __init__(self) -> None:
        super(Generator, self).__init__()
        # First conv layer.
        self.conv_block1 = nn.Sequential(
            nn.Conv2d(3, 64, (9, 9), (1, 1), (4, 4)),
            nn.PReLU(),
        )

        # Features trunk blocks.
        trunk = []
        for _ in range(16):
            trunk.append(ResidualConvBlock(64))
        self.trunk = nn.Sequential(*trunk)

        # Second conv layer.
        self.conv_block2 = nn.Sequential(
            nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(64),
        )

        # Upscale conv block: two x2 pixel-shuffle stages -> x4 overall.
        self.upsampling = nn.Sequential(
            nn.Conv2d(64, 256, (3, 3), (1, 1), (1, 1)),
            nn.PixelShuffle(2),
            nn.PReLU(),
            nn.Conv2d(64, 256, (3, 3), (1, 1), (1, 1)),
            nn.PixelShuffle(2),
            nn.PReLU(),
        )

        # Output layer.
        self.conv_block3 = nn.Conv2d(64, 3, (9, 9), (1, 1), (4, 4))

        # Initialize neural network weights.
        self._initialize_weights()

    def forward(self, x: Tensor, dop=None) -> Tensor:
        # BUGFIX: test ``dop`` against None instead of truthiness. Previously
        # an explicit dop=0.0 silently skipped the dropout code path; outputs
        # are unchanged (dropout2d with p=0 is an identity), but the intent of
        # the numeric parameter is now honored.
        if dop is None:
            return self._forward_impl(x)
        else:
            return self._forward_w_dop_impl(x, dop)

    # Support torch.script function.
    def _forward_impl(self, x: Tensor) -> Tensor:
        """Plain forward pass without dropout."""
        out1 = self.conv_block1(x)
        out = self.trunk(out1)
        out2 = self.conv_block2(out)
        out = torch.add(out1, out2)
        out = self.upsampling(out)
        out = self.conv_block3(out)

        return out

    def _forward_w_dop_impl(self, x: Tensor, dop) -> Tensor:
        """Forward pass with 2-D dropout (probability ``dop``) on conv_block2."""
        out1 = self.conv_block1(x)
        out = self.trunk(out1)
        out2 = F.dropout2d(self.conv_block2(out), p=dop)
        out = torch.add(out1, out2)
        out = self.upsampling(out)
        out = self.conv_block3(out)

        return out

    def _initialize_weights(self) -> None:
        """Kaiming-normal conv weights; zero conv biases; unit BN scales."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
#### BayesCap
|
| 162 |
+
#### BayesCap
class BayesCap(nn.Module):
    """Uncertainty head producing three maps per input.

    Outputs a reconstruction (``mu``), an inverse scale (``alpha``) and a
    shape (``beta``) map; alpha/beta heads end in ReLU so both stay
    non-negative.
    """

    def __init__(self, in_channels=3, out_channels=3) -> None:
        super(BayesCap, self).__init__()
        # First conv layer.
        self.conv_block1 = nn.Sequential(
            nn.Conv2d(
                in_channels, 64,
                kernel_size=9, stride=1, padding=4
            ),
            nn.PReLU(),
        )

        # Features trunk blocks.
        self.trunk = nn.Sequential(*[ResidualConvBlock(64) for _ in range(16)])

        # Second conv layer.
        self.conv_block2 = nn.Sequential(
            nn.Conv2d(
                64, 64,
                kernel_size=3, stride=1, padding=1, bias=False
            ),
            nn.BatchNorm2d(64),
        )

        # Output heads.
        self.conv_block3_mu = nn.Conv2d(
            64, out_channels=out_channels,
            kernel_size=9, stride=1, padding=4
        )
        self.conv_block3_alpha = self._make_head()
        self.conv_block3_beta = self._make_head()

        # Initialize neural network weights.
        self._initialize_weights()

    @staticmethod
    def _make_head():
        """Three 9x9 convs (64 -> 64 -> 64 -> 1) with PReLU/PReLU/ReLU."""
        return nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=9, stride=1, padding=4),
            nn.PReLU(),
            nn.Conv2d(64, 64, kernel_size=9, stride=1, padding=4),
            nn.PReLU(),
            nn.Conv2d(64, 1, kernel_size=9, stride=1, padding=4),
            nn.ReLU(),
        )

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)

    # Support torch.script function.
    def _forward_impl(self, x: Tensor) -> Tensor:
        feats = self.conv_block1(x)
        # Residual skip over the trunk before the three heads.
        skip = self.conv_block2(self.trunk(feats)) + feats
        out_mu = self.conv_block3_mu(skip)
        out_alpha = self.conv_block3_alpha(skip)
        out_beta = self.conv_block3_beta(skip)
        return out_mu, out_alpha, out_beta

    def _initialize_weights(self) -> None:
        """Kaiming-normal conv weights; zero conv biases; unit BN scales."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
class BayesCap_noID(nn.Module):
    """BayesCap variant without the identity (``mu``) head.

    Produces only the inverse-scale (``alpha``) and shape (``beta``) maps;
    both heads end in ReLU so the outputs stay non-negative.
    """

    def __init__(self, in_channels=3, out_channels=3) -> None:
        super(BayesCap_noID, self).__init__()
        # First conv layer.
        self.conv_block1 = nn.Sequential(
            nn.Conv2d(
                in_channels, 64,
                kernel_size=9, stride=1, padding=4
            ),
            nn.PReLU(),
        )

        # Features trunk blocks.
        self.trunk = nn.Sequential(*[ResidualConvBlock(64) for _ in range(16)])

        # Second conv layer.
        self.conv_block2 = nn.Sequential(
            nn.Conv2d(
                64, 64,
                kernel_size=3, stride=1, padding=1, bias=False
            ),
            nn.BatchNorm2d(64),
        )

        # Output heads (no mu head in this variant).
        self.conv_block3_alpha = self._make_head()
        self.conv_block3_beta = self._make_head()

        # Initialize neural network weights.
        self._initialize_weights()

    @staticmethod
    def _make_head():
        """Three 9x9 convs (64 -> 64 -> 64 -> 1) with PReLU/PReLU/ReLU."""
        return nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=9, stride=1, padding=4),
            nn.PReLU(),
            nn.Conv2d(64, 64, kernel_size=9, stride=1, padding=4),
            nn.PReLU(),
            nn.Conv2d(64, 1, kernel_size=9, stride=1, padding=4),
            nn.ReLU(),
        )

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)

    # Support torch.script function.
    def _forward_impl(self, x: Tensor) -> Tensor:
        feats = self.conv_block1(x)
        # Residual skip over the trunk before the two heads.
        skip = self.conv_block2(self.trunk(feats)) + feats
        out_alpha = self.conv_block3_alpha(skip)
        out_beta = self.conv_block3_beta(skip)
        return out_alpha, out_beta

    def _initialize_weights(self) -> None:
        """Kaiming-normal conv weights; zero conv biases; unit BN scales."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
|
src/networks_T1toT2.py
ADDED
|
@@ -0,0 +1,477 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import functools
|
| 5 |
+
|
| 6 |
+
### components
|
| 7 |
+
class ResConv(nn.Module):
    """Residual convolutional block.

    The main path is (conv => BN => ReLU) * 3; a parallel single
    (conv => BN => ReLU) projection adapts the input to ``out_channels`` so
    the two paths can be summed.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        # Main three-stage conv path.
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        # Single-stage projection so the skip path matches out_channels.
        self.double_conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        # BUGFIX: the original evaluated self.double_conv(x) twice (once into
        # an unused local and again in the return), doubling the compute and,
        # in training mode, updating the BatchNorm running stats twice per
        # call. Reuse the single evaluation instead.
        x_in = self.double_conv1(x)
        x1 = self.double_conv(x)
        return x1 + x_in
|
| 37 |
+
|
| 38 |
+
class Down(nn.Module):
    """Downscaling stage: 2x2 max-pool followed by a ResConv block."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Halve the spatial resolution, then apply the residual conv block.
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            ResConv(in_channels, out_channels),
        )

    def forward(self, x):
        return self.maxpool_conv(x)
|
| 48 |
+
|
| 49 |
+
class Up(nn.Module):
    """Upscaling stage: upsample, pad to the skip's size, concat, ResConv."""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        if bilinear:
            # Bilinear upsampling keeps the channel count, so ResConv is given
            # a reduced mid-channel width to compensate.
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = ResConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = ResConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Input is CHW: odd encoder sizes can leave the upsampled map smaller
        # than the skip tensor by a pixel, so pad symmetrically to match.
        dh = x2.size()[2] - x1.size()[2]
        dw = x2.size()[3] - x1.size()[3]
        x1 = F.pad(
            x1,
            [dw // 2, dw - dw // 2,
             dh // 2, dh - dh // 2]
        )
        # Concatenate skip features ahead of the residual conv.
        merged = torch.cat([x2, x1], dim=1)
        return self.conv(merged)
|
| 77 |
+
|
| 78 |
+
class OutConv(nn.Module):
    """Final 1x1 projection onto the desired number of output channels."""

    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)
|
| 85 |
+
|
| 86 |
+
##### The composite networks
|
| 87 |
+
class UNet(nn.Module):
    """Four-level U-Net assembled from ResConv / Down / Up / OutConv blocks."""

    def __init__(self, n_channels, out_channels, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.out_channels = out_channels
        self.bilinear = bilinear
        ####
        # Bilinear decoding halves the channel budget of bottleneck/up stages.
        factor = 2 if bilinear else 1
        self.inc = ResConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, out_channels)

    def forward(self, x):
        # Encoder path.
        e1 = self.inc(x)
        e2 = self.down1(e1)
        e3 = self.down2(e2)
        e4 = self.down3(e3)
        bottleneck = self.down4(e4)
        # Decoder path with skip connections.
        d = self.up1(bottleneck, e4)
        d = self.up2(d, e3)
        d = self.up3(d, e2)
        d = self.up4(d, e1)
        return self.outc(d)
|
| 117 |
+
|
| 118 |
+
class CasUNet(nn.Module):
    """Cascade of ``n_unet`` U-Nets.

    The first stage consumes the input directly (optionally followed by 2-D
    dropout); every later stage consumes the running output plus the original
    input.
    """

    def __init__(self, n_unet, io_channels, bilinear=True):
        super(CasUNet, self).__init__()
        self.n_unet = n_unet
        self.io_channels = io_channels
        self.bilinear = bilinear
        ####
        self.unet_list = nn.ModuleList(
            UNet(io_channels, io_channels, bilinear) for _ in range(n_unet)
        )

    def forward(self, x, dop=None):
        y = x
        for idx, stage in enumerate(self.unet_list):
            if idx == 0:
                # Dropout is only applied after the first stage, when asked.
                y = stage(y) if dop is None else F.dropout2d(stage(y), p=dop)
            else:
                y = stage(y + x)
        return y
|
| 139 |
+
|
| 140 |
+
class CasUNet_2head(nn.Module):
    """Cascade of UNets whose final stage is a two-headed UNet_2head.

    Returns a (mean, sigma) pair: per-pixel prediction plus a per-pixel
    uncertainty estimate.
    """
    def __init__(self, n_unet, io_channels, bilinear=True):
        super(CasUNet_2head, self).__init__()
        self.n_unet = n_unet
        self.io_channels = io_channels
        self.bilinear = bilinear
        ####
        self.unet_list = nn.ModuleList()
        for i in range(self.n_unet):
            if i != self.n_unet-1:
                self.unet_list.append(UNet(self.io_channels, self.io_channels, self.bilinear))
            else:
                # Last stage emits (mean, variance-head) instead of one tensor.
                self.unet_list.append(UNet_2head(self.io_channels, self.io_channels, self.bilinear))
    def forward(self, x):
        y = x
        for i in range(self.n_unet):
            if i==0:
                y = self.unet_list[i](y)
            else:
                # Residual re-injection of the original input.
                y = self.unet_list[i](y+x)
        # y is the (mean, sigma) tuple produced by the final UNet_2head.
        y_mean, y_sigma = y[0], y[1]
        return y_mean, y_sigma
|
| 162 |
+
|
| 163 |
+
class CasUNet_3head(nn.Module):
    """Cascade of UNets whose final stage is a three-headed UNet_3head.

    Returns (mean, alpha, beta) — presumably the parameters of a
    generalized-Gaussian likelihood; confirm against the training loss.
    """
    def __init__(self, n_unet, io_channels, bilinear=True):
        super(CasUNet_3head, self).__init__()
        self.n_unet = n_unet
        self.io_channels = io_channels
        self.bilinear = bilinear
        ####
        self.unet_list = nn.ModuleList()
        for i in range(self.n_unet):
            if i != self.n_unet-1:
                self.unet_list.append(UNet(self.io_channels, self.io_channels, self.bilinear))
            else:
                # Last stage emits (mean, alpha, beta) instead of one tensor.
                self.unet_list.append(UNet_3head(self.io_channels, self.io_channels, self.bilinear))
    def forward(self, x):
        y = x
        for i in range(self.n_unet):
            if i==0:
                y = self.unet_list[i](y)
            else:
                # Residual re-injection of the original input.
                y = self.unet_list[i](y+x)
        # y is the (mean, alpha, beta) tuple from the final UNet_3head.
        y_mean, y_alpha, y_beta = y[0], y[1], y[2]
        return y_mean, y_alpha, y_beta
|
| 185 |
+
|
| 186 |
+
class UNet_2head(nn.Module):
    """U-Net with two output heads sharing one decoder.

    ``out_mean`` predicts ``out_channels`` values per pixel; ``out_var``
    predicts a single (unactivated) value per pixel, used as a variance /
    uncertainty estimate by the caller.
    """
    def __init__(self, n_channels, out_channels, bilinear=True):
        super(UNet_2head, self).__init__()
        self.n_channels = n_channels
        self.out_channels = out_channels
        self.bilinear = bilinear
        ####
        # Encoder / decoder identical to the plain UNet above.
        self.inc = ResConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        #per pixel multiple channels may exist
        self.out_mean = OutConv(64, out_channels)
        #variance will always be a single number for a pixel
        self.out_var = nn.Sequential(
            OutConv(64, 128),
            OutConv(128, 1),
        )
    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        # Both heads read the same final decoder features.
        y_mean, y_var = self.out_mean(x), self.out_var(x)
        return y_mean, y_var
|
| 222 |
+
|
| 223 |
+
class UNet_3head(nn.Module):
    """U-Net with three output heads sharing one decoder.

    ``out_mean`` predicts ``out_channels`` values per pixel; ``out_alpha``
    and ``out_beta`` each predict one value per pixel, passed through ReLU
    so they are non-negative (distribution shape/scale parameters —
    presumably a generalized Gaussian; confirm against the loss).
    """
    def __init__(self, n_channels, out_channels, bilinear=True):
        super(UNet_3head, self).__init__()
        self.n_channels = n_channels
        self.out_channels = out_channels
        self.bilinear = bilinear
        ####
        # Encoder / decoder identical to the plain UNet above.
        self.inc = ResConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        #per pixel multiple channels may exist
        self.out_mean = OutConv(64, out_channels)
        #variance will always be a single number for a pixel
        self.out_alpha = nn.Sequential(
            OutConv(64, 128),
            OutConv(128, 1),
            nn.ReLU()
        )
        self.out_beta = nn.Sequential(
            OutConv(64, 128),
            OutConv(128, 1),
            nn.ReLU()
        )
    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        # All three heads read the same final decoder features.
        y_mean, y_alpha, y_beta = self.out_mean(x), \
            self.out_alpha(x), self.out_beta(x)
        return y_mean, y_alpha, y_beta
|
| 266 |
+
|
| 267 |
+
class ResidualBlock(nn.Module):
    """Reflection-padded residual block: returns ``x + F(x)``.

    ``F`` is two 3x3 convolutions with instance normalization, ReLU after
    the first conv only. Channel count and spatial size are preserved.
    """
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        layers = [
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
        ]
        # Attribute name kept as `conv_block` for checkpoint compatibility.
        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.conv_block(x)
        return x + residual
|
| 282 |
+
|
| 283 |
+
class Generator(nn.Module):
    """CycleGAN-style ResNet generator.

    Layout: 7x7 stem conv (reflection padded), two stride-2 downsampling
    convs, ``n_residual_blocks`` ResidualBlocks, two transposed-conv
    upsampling stages, then a 7x7 output conv with Tanh (output in [-1, 1]).
    """
    def __init__(self, input_nc, output_nc, n_residual_blocks=9):
        super(Generator, self).__init__()
        # Initial convolution block
        model = [
            nn.ReflectionPad2d(3), nn.Conv2d(input_nc, 64, 7),
            nn.InstanceNorm2d(64), nn.ReLU(inplace=True)
        ]
        # Downsampling: 64 -> 128 -> 256 channels, spatial size quartered.
        in_features = 64
        out_features = in_features*2
        for _ in range(2):
            model += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features
            out_features = in_features*2
        # Residual blocks at the bottleneck resolution.
        for _ in range(n_residual_blocks):
            model += [ResidualBlock(in_features)]
        # Upsampling: mirrors the downsampling path back to 64 channels.
        out_features = in_features//2
        for _ in range(2):
            model += [
                nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features
            out_features = in_features//2
        # Output layer (in_features has returned to 64 at this point).
        model += [nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, 7), nn.Tanh()]
        self.model = nn.Sequential(*model)
    def forward(self, x):
        return self.model(x)
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
class ResnetGenerator(nn.Module):
    """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.

    We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
        """Construct a Resnet-based generator

        Parameters:
            input_nc (int)      -- the number of channels in input images
            output_nc (int)     -- the number of channels in output images
            ngf (int)           -- the number of filters in the last conv layer
            norm_layer          -- normalization layer
            use_dropout (bool)  -- if use dropout layers
            n_blocks (int)      -- the number of ResNet blocks
            padding_type (str)  -- the name of padding layer in conv layers: reflect | replicate | zero
        """
        assert(n_blocks >= 0)
        super(ResnetGenerator, self).__init__()
        # InstanceNorm has no learnable affine shift, so the conv keeps its
        # bias; BatchNorm's affine parameters make a conv bias redundant.
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        # 7x7 stem with reflection padding (padding=0 on the conv itself).
        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                 norm_layer(ngf),
                 nn.ReLU(True)]

        n_downsampling = 2
        for i in range(n_downsampling):  # add downsampling layers
            mult = 2 ** i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]

        mult = 2 ** n_downsampling
        for i in range(n_blocks):       # add ResNet blocks at bottleneck resolution

            model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]

        for i in range(n_downsampling):  # add upsampling layers (mirror of downsampling)
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=use_bias),
                      norm_layer(int(ngf * mult / 2)),
                      nn.ReLU(True)]
        # 7x7 output head; Tanh maps the output into [-1, 1].
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, input):
        """Standard forward"""
        return self.model(input)
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
class ResnetBlock(nn.Module):
    """Define a Resnet block"""

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Initialize the Resnet block.

        A resnet block is a conv block with skip connections. We construct a
        conv block with the build_conv_block function, and implement skip
        connections in <forward>.
        Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
        """
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Construct a convolutional block.

        Parameters:
            dim (int)           -- the number of channels in the conv layer.
            padding_type (str)  -- the name of padding layer: reflect | replicate | zero
            norm_layer          -- normalization layer
            use_dropout (bool)  -- if use dropout layers.
            use_bias (bool)     -- if the conv layer uses bias or not

        Returns a conv block (with a conv layer, a normalization layer, and a
        non-linearity layer (ReLU)).
        """
        def padding():
            # One-sided helper: explicit pad layers for reflect/replicate,
            # or conv-level padding (p=1) for 'zero'.
            if padding_type == 'reflect':
                return [nn.ReflectionPad2d(1)], 0
            if padding_type == 'replicate':
                return [nn.ReplicationPad2d(1)], 0
            if padding_type == 'zero':
                return [], 1
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        pad_layers, p = padding()
        layers = list(pad_layers)
        layers += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
        if use_dropout:
            layers += [nn.Dropout(0.5)]

        pad_layers, p = padding()
        layers += pad_layers
        layers += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]

        return nn.Sequential(*layers)

    def forward(self, x):
        """Forward function (with skip connections)"""
        return x + self.conv_block(x)  # add skip connections
|
| 436 |
+
|
| 437 |
+
### discriminator
|
| 438 |
+
class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator"""
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        """Construct a PatchGAN discriminator

        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
            norm_layer      -- normalization layer
        """
        super(NLayerDiscriminator, self).__init__()
        # No need to use bias when BatchNorm2d follows (its affine shift
        # subsumes it); InstanceNorm keeps the bias useful.
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        kw, padw = 4, 1
        layers = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
                  nn.LeakyReLU(0.2, True)]
        # Gradually increase the number of filters (capped at 8 * ndf).
        nf_mult = 1
        for n in range(1, n_layers):
            nf_mult_prev, nf_mult = nf_mult, min(2 ** n, 8)
            layers += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True),
            ]
        # One final stride-1 conv stage before the prediction head.
        nf_mult_prev, nf_mult = nf_mult, min(2 ** n_layers, 8)
        layers += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True),
        ]
        # Output a 1-channel prediction map: each pixel scores one patch.
        layers += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Standard forward."""
        return self.model(input)
|
src/utils.py
ADDED
|
@@ -0,0 +1,1273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
from typing import Any, Optional
|
| 3 |
+
import numpy as np
|
| 4 |
+
import os
|
| 5 |
+
import cv2
|
| 6 |
+
from glob import glob
|
| 7 |
+
from PIL import Image, ImageDraw
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
import kornia
|
| 10 |
+
import matplotlib.pyplot as plt
|
| 11 |
+
import seaborn as sns
|
| 12 |
+
import albumentations as albu
|
| 13 |
+
import functools
|
| 14 |
+
import math
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
from torch import Tensor
|
| 19 |
+
import torchvision as tv
|
| 20 |
+
import torchvision.models as models
|
| 21 |
+
from torchvision import transforms
|
| 22 |
+
from torchvision.transforms import functional as F
|
| 23 |
+
from losses import TempCombLoss
|
| 24 |
+
|
| 25 |
+
########### DeblurGAN function
|
| 26 |
+
def get_norm_layer(norm_type='instance'):
    """Return a normalization-layer factory (a ``functools.partial``).

    Args:
        norm_type: 'batch' for affine BatchNorm2d, 'instance' for
            non-affine InstanceNorm2d with running statistics.
    Raises:
        NotImplementedError: for any other ``norm_type``.
    """
    factories = {
        'batch': functools.partial(nn.BatchNorm2d, affine=True),
        'instance': functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True),
    }
    if norm_type not in factories:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return factories[norm_type]
|
| 34 |
+
|
| 35 |
+
def _array_to_batch(x):
|
| 36 |
+
x = np.transpose(x, (2, 0, 1))
|
| 37 |
+
x = np.expand_dims(x, 0)
|
| 38 |
+
return torch.from_numpy(x)
|
| 39 |
+
|
| 40 |
+
def get_normalize():
    """Build a paired normalizer for (image, target) using albumentations.

    Maps uint8 RGB to floats with mean/std 0.5 (i.e. roughly [-1, 1]);
    the ``additional_targets`` entry makes the target undergo the exact
    same transform as the image.

    Returns:
        A function ``process(a, b) -> (normalized a, normalized b)``.
    """
    normalize = albu.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    normalize = albu.Compose([normalize], additional_targets={'target': 'image'})

    def process(a, b):
        r = normalize(image=a, target=b)
        return r['image'], r['target']

    return process
|
| 49 |
+
|
| 50 |
+
def preprocess(x: np.ndarray, mask: Optional[np.ndarray]):
    """Normalize an RGB image (and optional mask) and pad both to a 32-multiple.

    Args:
        x: HWC image array.
        mask: optional mask in 0..255; rounded to a binary float mask.
            Defaults to all-ones (same shape as the normalized image).
    Returns:
        Tuple of (map of 1xCxHxW tensors for (image, mask), original h, original w)
        so the caller can crop back after inference.
    """
    x, _ = get_normalize()(x, x)
    if mask is None:
        mask = np.ones_like(x, dtype=np.float32)
    else:
        mask = np.round(mask.astype('float32') / 255)

    h, w, _ = x.shape
    block_size = 32
    # NOTE(review): this always pads at least one full extra block, even when
    # h/w are already multiples of 32 — presumably harmless; confirm.
    min_height = (h // block_size + 1) * block_size
    min_width = (w // block_size + 1) * block_size

    # Zero-pad only on the bottom/right so the origin is unchanged.
    pad_params = {'mode': 'constant',
                  'constant_values': 0,
                  'pad_width': ((0, min_height - h), (0, min_width - w), (0, 0))
                  }
    x = np.pad(x, **pad_params)
    mask = np.pad(mask, **pad_params)

    return map(_array_to_batch, (x, mask)), h, w
|
| 70 |
+
|
| 71 |
+
def postprocess(x: torch.Tensor) -> np.ndarray:
    """Convert a batch-of-one tensor in [-1, 1] to an HWC uint8 image array."""
    image, = x  # unpack a batch of exactly one image (raises otherwise)
    arr = image.detach().cpu().float().numpy()
    arr = np.transpose(arr, (1, 2, 0))
    arr = (arr + 1) / 2.0 * 255.0
    return arr.astype('uint8')
|
| 76 |
+
|
| 77 |
+
def sorted_glob(pattern):
    """Return glob() matches for ``pattern`` in deterministic (sorted) order."""
    matches = glob(pattern)
    return sorted(matches)
|
| 79 |
+
###########
|
| 80 |
+
|
| 81 |
+
def normalize(image: np.ndarray) -> np.ndarray:
    """Normalize the ``OpenCV.imread`` or ``skimage.io.imread`` data.

    Args:
        image (np.ndarray): The image data read by ``OpenCV.imread`` or ``skimage.io.imread``.
    Returns:
        Normalized image data. Data range [0, 1].
    """
    as_float = image.astype(np.float64)
    return as_float / 255.0
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def unnormalize(image: np.ndarray) -> np.ndarray:
    """Un-normalize the ``OpenCV.imread`` or ``skimage.io.imread`` data.

    Args:
        image (np.ndarray): The image data read by ``OpenCV.imread`` or ``skimage.io.imread``.
    Returns:
        Denormalized image data. Data range [0, 255].
    """
    as_float = image.astype(np.float64)
    return as_float * 255.0
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def image2tensor(image: np.ndarray, range_norm: bool, half: bool) -> torch.Tensor:
    """Convert ``PIL.Image`` to Tensor.

    Args:
        image (np.ndarray): The image data read by ``PIL.Image``
        range_norm (bool): Scale [0, 1] data to between [-1, 1]
        half (bool): Whether to convert torch.float32 similarly to torch.half type.
    Returns:
        Normalized image data
    Examples:
        >>> image = Image.open("image.bmp")
        >>> tensor_image = image2tensor(image, range_norm=False, half=False)
    """
    # F here is torchvision.transforms.functional (module-level import);
    # to_tensor yields a CxHxW float tensor in [0, 1].
    tensor = F.to_tensor(image)

    # mul_/sub_ are in-place but operate on the freshly created tensor,
    # so the caller's image is not mutated.
    if range_norm:
        tensor = tensor.mul_(2.0).sub_(1.0)
    if half:
        tensor = tensor.half()

    return tensor
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def tensor2image(tensor: torch.Tensor, range_norm: bool, half: bool) -> Any:
    """Converts ``torch.Tensor`` to ``PIL.Image``-compatible uint8 HWC array.

    Args:
        tensor (torch.Tensor): The image that needs to be converted to ``PIL.Image``
        range_norm (bool): Scale [-1, 1] data to between [0, 1]
        half (bool): Whether to convert torch.float32 similarly to torch.half type.
    Returns:
        Convert image data to support PIL library
    Examples:
        >>> tensor = torch.randn([1, 3, 128, 128])
        >>> image = tensor2image(tensor, range_norm=False, half=False)
    """
    # NOTE: add_/div_/squeeze_/mul_/clamp_ are in-place, so the caller's
    # tensor IS mutated by this function.
    if range_norm:
        tensor = tensor.add_(1.0).div_(2.0)
    if half:
        tensor = tensor.half()

    # 1xCxHxW -> HxWxC, scaled to 0..255 and truncated to uint8.
    image = tensor.squeeze_(0).permute(1, 2, 0).mul_(255).clamp_(0, 255).cpu().numpy().astype("uint8")

    return image
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def convert_rgb_to_y(image: Any) -> Any:
    """Convert RGB image or tensor image data to YCbCr(Y) format.

    Args:
        image: RGB image data read by ``PIL.Image''. An ``np.ndarray`` is
            assumed HWC; a ``torch.Tensor`` is assumed CHW (or 1xCxHxW).
    Returns:
        Y image array data.
    """
    # Uses exact type() checks, so subclasses of ndarray/Tensor are rejected.
    if type(image) == np.ndarray:
        return 16. + (64.738 * image[:, :, 0] + 129.057 * image[:, :, 1] + 25.064 * image[:, :, 2]) / 256.
    elif type(image) == torch.Tensor:
        # squeeze_ is in-place: a 4-D caller tensor loses its batch dim.
        if len(image.shape) == 4:
            image = image.squeeze_(0)
        return 16. + (64.738 * image[0, :, :] + 129.057 * image[1, :, :] + 25.064 * image[2, :, :]) / 256.
    else:
        raise Exception("Unknown Type", type(image))
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def convert_rgb_to_ycbcr(image: Any) -> Any:
    """Convert RGB image or tensor image data to YCbCr format.

    Args:
        image: RGB image data read by ``PIL.Image''. An ``np.ndarray`` is
            assumed HWC; a ``torch.Tensor`` is assumed CHW (or 1xCxHxW).
    Returns:
        YCbCr image data in HWC layout, same container type as the input.
    Raises:
        Exception: if ``image`` is neither an ndarray nor a Tensor.
    """
    if type(image) == np.ndarray:
        y = 16. + (64.738 * image[:, :, 0] + 129.057 * image[:, :, 1] + 25.064 * image[:, :, 2]) / 256.
        cb = 128. + (-37.945 * image[:, :, 0] - 74.494 * image[:, :, 1] + 112.439 * image[:, :, 2]) / 256.
        cr = 128. + (112.439 * image[:, :, 0] - 94.154 * image[:, :, 1] - 18.285 * image[:, :, 2]) / 256.
        return np.array([y, cb, cr]).transpose([1, 2, 0])
    elif type(image) == torch.Tensor:
        if len(image.shape) == 4:
            image = image.squeeze(0)
        y = 16. + (64.738 * image[0, :, :] + 129.057 * image[1, :, :] + 25.064 * image[2, :, :]) / 256.
        cb = 128. + (-37.945 * image[0, :, :] - 74.494 * image[1, :, :] + 112.439 * image[2, :, :]) / 256.
        cr = 128. + (112.439 * image[0, :, :] - 94.154 * image[1, :, :] - 18.285 * image[2, :, :]) / 256.
        # Bug fix: the channel planes y/cb/cr are 2-D, so ``torch.cat(..., 0)``
        # produced a (3H, W) tensor and the 3-axis permute raised. ``torch.stack``
        # builds the (3, H, W) tensor the permute expects, matching the
        # ndarray branch's HWC output.
        return torch.stack([y, cb, cr], 0).permute(1, 2, 0)
    else:
        raise Exception("Unknown Type", type(image))
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def convert_ycbcr_to_rgb(image: Any) -> Any:
    """Convert YCbCr format image to RGB format.

    Args:
        image: YCbCr image data read by ``PIL.Image''. An ``np.ndarray`` is
            assumed HWC; a ``torch.Tensor`` is assumed CHW (or 1xCxHxW).
    Returns:
        RGB image data in HWC layout, same container type as the input.
    Raises:
        Exception: if ``image`` is neither an ndarray nor a Tensor.
    """
    if type(image) == np.ndarray:
        r = 298.082 * image[:, :, 0] / 256. + 408.583 * image[:, :, 2] / 256. - 222.921
        g = 298.082 * image[:, :, 0] / 256. - 100.291 * image[:, :, 1] / 256. - 208.120 * image[:, :, 2] / 256. + 135.576
        b = 298.082 * image[:, :, 0] / 256. + 516.412 * image[:, :, 1] / 256. - 276.836
        return np.array([r, g, b]).transpose([1, 2, 0])
    elif type(image) == torch.Tensor:
        if len(image.shape) == 4:
            image = image.squeeze(0)
        r = 298.082 * image[0, :, :] / 256. + 408.583 * image[2, :, :] / 256. - 222.921
        g = 298.082 * image[0, :, :] / 256. - 100.291 * image[1, :, :] / 256. - 208.120 * image[2, :, :] / 256. + 135.576
        b = 298.082 * image[0, :, :] / 256. + 516.412 * image[1, :, :] / 256. - 276.836
        # Bug fix: r/g/b are 2-D planes, so ``torch.cat(..., 0)`` produced a
        # (3H, W) tensor and the 3-axis permute raised. ``torch.stack`` builds
        # the (3, H, W) tensor the permute expects, matching the ndarray branch.
        return torch.stack([r, g, b], 0).permute(1, 2, 0)
    else:
        raise Exception("Unknown Type", type(image))
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def center_crop(lr: Any, hr: Any, image_size: int, upscale_factor: int) -> [Any, Any]:
    """Crop matching center patches from an LR/HR ``PIL.Image`` pair.

    Args:
        lr: Low-resolution image (``PIL.Image``-like, exposing ``crop``).
        hr: High-resolution image (``PIL.Image``-like, exposing ``size``/``crop``).
        image_size (int): Side length of the crop, in high-resolution pixels.
        upscale_factor (int): Magnification factor between LR and HR.

    Returns:
        The center-cropped (lr, hr) pair.
    """
    hr_w, hr_h = hr.size

    # Centered crop box in HR pixel coordinates.
    x0 = (hr_w - image_size) // 2
    y0 = (hr_h - image_size) // 2
    x1 = x0 + image_size
    y1 = y0 + image_size

    # The LR box is the HR box scaled down by the upscale factor.
    lr_box = tuple(edge // upscale_factor for edge in (x0, y0, x1, y1))
    lr = lr.crop(lr_box)
    hr = hr.crop((x0, y0, x1, y1))

    return lr, hr
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def random_crop(lr: Any, hr: Any, image_size: int, upscale_factor: int) -> [Any, Any]:
    """Crop a random, spatially-aligned patch pair from an LR/HR image pair.

    Args:
        lr: Low-resolution image (``PIL.Image``-like, exposing ``crop``).
        hr: High-resolution image (``PIL.Image``-like, exposing ``size``/``crop``).
        image_size (int): Side length of the crop, in high-resolution pixels.
        upscale_factor (int): Magnification factor between LR and HR.

    Returns:
        The randomly cropped (lr, hr) pair.
    """
    hr_w, hr_h = hr.size

    # Uniformly sample the top-left corner of the HR patch (torch RNG so the
    # crop participates in torch's global seeding).
    x0 = torch.randint(0, hr_w - image_size + 1, size=(1,)).item()
    y0 = torch.randint(0, hr_h - image_size + 1, size=(1,)).item()
    x1 = x0 + image_size
    y1 = y0 + image_size

    # Same region, scaled down, is taken from the LR image.
    lr = lr.crop((x0 // upscale_factor,
                  y0 // upscale_factor,
                  x1 // upscale_factor,
                  y1 // upscale_factor))
    hr = hr.crop((x0, y0, x1, y1))

    return lr, hr
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def random_rotate(lr: Any, hr: Any, angle: int) -> [Any, Any]:
    """Rotate an LR/HR pair by ``angle`` degrees in a randomly chosen direction.

    Args:
        lr: Low-resolution image data read by ``PIL.Image``.
        hr: High-resolution image data read by ``PIL.Image``.
        angle (int): rotation magnitude; sign (CW/CCW) is picked at random.

    Returns:
        The rotated (lr, hr) pair.
    """
    # The same signed angle is applied to both images so they stay aligned.
    signed_angle = random.choice((+angle, -angle))
    lr = F.rotate(lr, signed_angle)
    hr = F.rotate(hr, signed_angle)
    return lr, hr
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def random_horizontally_flip(lr: Any, hr: Any, p=0.5) -> [Any, Any]:
    """Randomly flip an LR/HR pair horizontally, keeping the pair aligned.

    Args:
        lr: Low-resolution image data read by ``PIL.Image``.
        hr: High-resolution image data read by ``PIL.Image``.
        p (optional, float): threshold for the uniform draw. (Default: 0.5)

    Returns:
        The (possibly flipped) (lr, hr) pair.
    """
    # NOTE: flips when the draw EXCEEDS p, i.e. with probability (1 - p).
    if torch.rand(1).item() > p:
        lr, hr = F.hflip(lr), F.hflip(hr)
    return lr, hr
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
def random_vertically_flip(lr: Any, hr: Any, p=0.5) -> [Any, Any]:
    """Randomly flip an LR/HR pair vertically, keeping the pair aligned.

    Args:
        lr: Low-resolution image data read by ``PIL.Image``.
        hr: High-resolution image data read by ``PIL.Image``.
        p (optional, float): threshold for the uniform draw. (Default: 0.5)

    Returns:
        The (possibly flipped) (lr, hr) pair.
    """
    # NOTE: flips when the draw EXCEEDS p, i.e. with probability (1 - p).
    if torch.rand(1).item() > p:
        lr, hr = F.vflip(lr), F.vflip(hr)
    return lr, hr
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
def random_adjust_brightness(lr: Any, hr: Any) -> [Any, Any]:
    """Apply one random brightness gain to both images of an LR/HR pair.

    Args:
        lr: Low-resolution image data read by ``PIL.Image``.
        hr: High-resolution image data read by ``PIL.Image``.

    Returns:
        The brightness-adjusted (lr, hr) pair.
    """
    # A single gain in [0.5, 2] keeps both images photometrically consistent.
    gain = random.uniform(0.5, 2)
    lr = F.adjust_brightness(lr, gain)
    hr = F.adjust_brightness(hr, gain)
    return lr, hr
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
def random_adjust_contrast(lr: Any, hr: Any) -> [Any, Any]:
    """Apply one random contrast gain to both images of an LR/HR pair.

    Args:
        lr: Low-resolution image data read by ``PIL.Image``.
        hr: High-resolution image data read by ``PIL.Image``.

    Returns:
        The contrast-adjusted (lr, hr) pair.
    """
    # A single gain in [0.5, 2] keeps both images photometrically consistent.
    gain = random.uniform(0.5, 2)
    lr = F.adjust_contrast(lr, gain)
    hr = F.adjust_contrast(hr, gain)
    return lr, hr
|
| 337 |
+
|
| 338 |
+
#### metrics to compute -- assumes single images, i.e., tensor of 3 dims
|
| 339 |
+
def img_mae(x1, x2):
    """Mean absolute error between two (3-dim) image tensors."""
    return (x1 - x2).abs().mean()
|
| 342 |
+
|
| 343 |
+
def img_mse(x1, x2):
    """Mean squared error between two (3-dim) image tensors."""
    diff = torch.abs(x1 - x2)
    return (diff * diff).mean()
|
| 346 |
+
|
| 347 |
+
def img_psnr(x1, x2):
    """PSNR between two images (peak value 1), computed via kornia."""
    return kornia.metrics.psnr(x1, x2, 1)
|
| 350 |
+
|
| 351 |
+
def img_ssim(x1, x2):
    """Mean SSIM (5x5 window) between two single images of shape (C, H, W)."""
    # kornia expects batched input, hence the unsqueeze to (1, C, H, W).
    ssim_map = kornia.metrics.ssim(x1.unsqueeze(0), x2.unsqueeze(0), 5)
    return ssim_map.mean()
|
| 355 |
+
|
| 356 |
+
def show_SR_w_uncer(xLR, xHR, xSR, xSRvar, elim=(0,0.01), ulim=(0,0.15)):
    '''
    Display a 5-panel comparison: LR input | HR target | SR output |
    per-pixel squared-error map | predicted uncertainty map.

    xLR/SR/HR: 3xHxW
    xSRvar: 1xHxW
    elim/ulim: color limits for the error and uncertainty panels.

    NOTE(review): relies on matplotlib's implicit "current axes" state, so the
    subplot/imshow/clim call order must be preserved.
    '''
    plt.figure(figsize=(30,10))

    plt.subplot(1,5,1)
    # CHW -> HWC for imshow; values clipped to the displayable [0, 1] range.
    plt.imshow(xLR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
    plt.axis('off')

    plt.subplot(1,5,2)
    plt.imshow(xHR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
    plt.axis('off')

    plt.subplot(1,5,3)
    plt.imshow(xSR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
    plt.axis('off')

    plt.subplot(1,5,4)
    # Squared error averaged over channels -> 1xHxW map.
    error_map = torch.mean(torch.pow(torch.abs(xSR-xHR),2), dim=0).to('cpu').data.unsqueeze(0)
    print('error', error_map.min(), error_map.max())
    plt.imshow(error_map.transpose(0,2).transpose(0,1), cmap='jet')
    plt.clim(elim[0], elim[1])
    plt.axis('off')

    plt.subplot(1,5,5)
    print('uncer', xSRvar.min(), xSRvar.max())
    plt.imshow(xSRvar.to('cpu').data.transpose(0,2).transpose(0,1), cmap='hot')
    plt.clim(ulim[0], ulim[1])
    plt.axis('off')

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.show()
|
| 390 |
+
|
| 391 |
+
def show_SR_w_err(xLR, xHR, xSR, elim=(0,0.01), task=None, xMask=None):
    '''
    Display a 4-panel comparison: LR input | HR target | SR output |
    per-pixel squared-error map.

    xLR/SR/HR: 3xHxW
    elim: color limits for the error panel.
    task: 'm' renders the first three panels in grayscale; 'inpainting'
        restricts the error map to the masked region via xMask.
    xMask: 1xHxW mask, only used when task == 'inpainting'.

    NOTE(review): relies on matplotlib's implicit "current axes" state, so the
    subplot/imshow/clim call order must be preserved.
    '''
    plt.figure(figsize=(30,10))

    if task != 'm':
        plt.subplot(1,4,1)
        # CHW -> HWC for imshow; clipped to the displayable [0, 1] range.
        plt.imshow(xLR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
        plt.axis('off')

        plt.subplot(1,4,2)
        plt.imshow(xHR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
        plt.axis('off')

        plt.subplot(1,4,3)
        plt.imshow(xSR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
        plt.axis('off')
    else:
        # Grayscale variant with a fixed display range of [0, 0.9].
        plt.subplot(1,4,1)
        plt.imshow(xLR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1), cmap='gray')
        plt.clim(0,0.9)
        plt.axis('off')

        plt.subplot(1,4,2)
        plt.imshow(xHR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1), cmap='gray')
        plt.clim(0,0.9)
        plt.axis('off')

        plt.subplot(1,4,3)
        plt.imshow(xSR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1), cmap='gray')
        plt.clim(0,0.9)
        plt.axis('off')

    plt.subplot(1,4,4)
    if task == 'inpainting':
        # Restrict the error to the masked (inpainted) region.
        error_map = torch.mean(torch.pow(torch.abs(xSR-xHR),2), dim=0).to('cpu').data.unsqueeze(0)*xMask.to('cpu').data
    else:
        error_map = torch.mean(torch.pow(torch.abs(xSR-xHR),2), dim=0).to('cpu').data.unsqueeze(0)
    print('error', error_map.min(), error_map.max())
    plt.imshow(error_map.transpose(0,2).transpose(0,1), cmap='jet')
    plt.clim(elim[0], elim[1])
    plt.axis('off')

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.show()
|
| 437 |
+
|
| 438 |
+
def show_uncer4(xSRvar1, xSRvar2, xSRvar3, xSRvar4, ulim=(0,0.15)):
    '''
    Display four uncertainty maps side by side with a shared color range.

    xSRvar: 1xHxW (each of the four inputs)
    ulim: shared color limits for all panels.

    NOTE(review): relies on matplotlib's implicit "current axes" state, so the
    subplot/imshow/clim call order must be preserved.
    '''
    plt.figure(figsize=(30,10))

    plt.subplot(1,4,1)
    print('uncer', xSRvar1.min(), xSRvar1.max())
    # CHW -> HWC for imshow.
    plt.imshow(xSRvar1.to('cpu').data.transpose(0,2).transpose(0,1), cmap='hot')
    plt.clim(ulim[0], ulim[1])
    plt.axis('off')

    plt.subplot(1,4,2)
    print('uncer', xSRvar2.min(), xSRvar2.max())
    plt.imshow(xSRvar2.to('cpu').data.transpose(0,2).transpose(0,1), cmap='hot')
    plt.clim(ulim[0], ulim[1])
    plt.axis('off')

    plt.subplot(1,4,3)
    print('uncer', xSRvar3.min(), xSRvar3.max())
    plt.imshow(xSRvar3.to('cpu').data.transpose(0,2).transpose(0,1), cmap='hot')
    plt.clim(ulim[0], ulim[1])
    plt.axis('off')

    plt.subplot(1,4,4)
    print('uncer', xSRvar4.min(), xSRvar4.max())
    plt.imshow(xSRvar4.to('cpu').data.transpose(0,2).transpose(0,1), cmap='hot')
    plt.clim(ulim[0], ulim[1])
    plt.axis('off')

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.show()
|
| 470 |
+
|
| 471 |
+
def get_UCE(list_err, list_yout_var, num_bins=100):
    """Uncertainty Calibration Error (UCE) of predicted variances vs. errors.

    The error axis is split into ``num_bins`` equal-width bins. Within every
    bin the mean observed error and the mean predicted variance are compared;
    the absolute differences, weighted by bin occupancy, sum to the UCE.

    Args:
        list_err: per-pixel (or per-sample) squared errors.
        list_yout_var: matching predicted variances/uncertainties.
        num_bins (int): number of equal-width histogram bins.

    Returns:
        (bin_stats, uce): per-bin statistics dict (keyed 0..num_bins-1 with
        'start_idx', 'end_idx', 'num_points', 'mean_err', 'mean_var',
        'uce_bin') and the scalar UCE.
    """
    err_min = np.min(list_err)
    err_max = np.max(list_err)
    err_len = (err_max - err_min) / num_bins
    num_points = len(list_err)

    bin_stats = {}
    for i in range(num_bins):
        bin_stats[i] = {
            'start_idx': err_min + i * err_len,
            'end_idx': err_min + (i + 1) * err_len,
            'num_points': 0,
            'mean_err': 0,
            'mean_var': 0,
        }

    for e, v in zip(list_err, list_yout_var):
        # BUG FIX: the original scanned every bin with an exclusive right
        # edge, so the maximum-error point never landed in any bin. Direct
        # index computation is O(1) per point and clamps e == err_max into
        # the last bin.
        if err_len > 0:
            i = min(int((e - err_min) / err_len), num_bins - 1)
        else:
            i = 0  # all errors identical; everything goes into bin 0
        bin_stats[i]['num_points'] += 1
        bin_stats[i]['mean_err'] += e
        bin_stats[i]['mean_var'] += v

    uce = 0
    eps = 1e-8  # avoids division by zero for empty bins
    for i in range(num_bins):
        bin_stats[i]['mean_err'] /= bin_stats[i]['num_points'] + eps
        bin_stats[i]['mean_var'] /= bin_stats[i]['num_points'] + eps
        bin_stats[i]['uce_bin'] = (bin_stats[i]['num_points'] / num_points) \
            * (np.abs(bin_stats[i]['mean_err'] - bin_stats[i]['mean_var']))
        uce += bin_stats[i]['uce_bin']

    return bin_stats, uce
|
| 520 |
+
|
| 521 |
+
##################### training BayesCap
|
| 522 |
+
def train_BayesCap(
    NetC,
    NetG,
    train_loader,
    eval_loader,
    Cri = None,
    device='cuda',
    dtype=torch.cuda.FloatTensor,
    init_lr=1e-4,
    num_epochs=100,
    eval_every=1,
    ckpt_path='../ckpt/BayesCap',
    T1=1e0,
    T2=5e-2,
    task=None,
):
    """Train the BayesCap error-prediction head ``NetC`` over a frozen ``NetG``.

    Args:
        NetC: BayesCap network to train (outputs mu, alpha, beta).
        NetG: pretrained task network (kept in eval mode, used under no_grad).
        train_loader, eval_loader: (LR, HR) batch loaders.
        Cri: loss criterion; defaults to ``TempCombLoss()`` built lazily.
        device: torch device string.
        dtype: tensor type class used to cast batches (e.g. torch.cuda.FloatTensor).
        init_lr: initial Adam learning rate, cosine-annealed over ``num_epochs``.
        num_epochs: number of training epochs.
        eval_every: run validation every this many epochs.
        ckpt_path: checkpoint prefix; ``_last.pth`` / ``_best.pth`` are written.
        T1, T2: temperature weights forwarded to the criterion.
        task: None, 'inpainting' or 'depth' — selects the NetG call signature.

    Returns:
        None. Checkpoints are written to disk as a side effect.
    """
    # BUG FIX: the original defaults were ``Cri=TempCombLoss()`` and
    # ``dtype=torch.cuda.FloatTensor()``. Both are evaluated once at import
    # time; the latter even allocates a CUDA tensor (crashing CPU-only hosts
    # at import) and a tensor *instance* is not a valid argument for
    # ``Tensor.type``. The criterion is now constructed lazily and ``dtype``
    # is the tensor class, matching the other functions in this file.
    if Cri is None:
        Cri = TempCombLoss()

    NetC.to(device)
    NetC.train()
    NetG.to(device)
    NetG.eval()  # generator stays frozen; only NetC is optimized
    optimizer = torch.optim.Adam(list(NetC.parameters()), lr=init_lr)
    optim_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs)

    score = -1e8  # best validation score seen so far
    all_loss = []
    for eph in range(num_epochs):
        eph_loss = 0
        with tqdm(train_loader, unit='batch') as tepoch:
            for (idx, batch) in enumerate(tepoch):
                if idx>2000:
                    break  # cap iterations per epoch
                tepoch.set_description('Epoch {}'.format(eph))
                ##
                xLR, xHR = batch[0].to(device), batch[1].to(device)
                xLR, xHR = xLR.type(dtype), xHR.type(dtype)
                if task == 'inpainting':
                    xMask = random_mask(xLR.shape[0], (xLR.shape[2], xLR.shape[3]))
                    xMask = xMask.to(device).type(dtype)
                # pass them through the (frozen) generator
                with torch.no_grad():
                    if task == 'inpainting':
                        _, xSR1 = NetG(xLR, xMask)
                    elif task == 'depth':
                        xSR1 = NetG(xLR)[("disp", 0)]
                    else:
                        xSR1 = NetG(xLR)
                xSR = xSR1.clone()
                xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
                optimizer.zero_grad()
                # For depth there is no usable GT here, so NetC is fit to
                # reconstruct the generator output itself.
                if task == 'depth':
                    loss = Cri(xSRC_mu, xSRC_alpha, xSRC_beta, xSR, T1=T1, T2=T2)
                else:
                    loss = Cri(xSRC_mu, xSRC_alpha, xSRC_beta, xHR, T1=T1, T2=T2)
                loss.backward()
                optimizer.step()
                ##
                eph_loss += loss.item()
                tepoch.set_postfix(loss=loss.item())
            eph_loss /= len(train_loader)
            all_loss.append(eph_loss)
            print('Avg. loss: {}'.format(eph_loss))
        # evaluate and save the models
        torch.save(NetC.state_dict(), ckpt_path+'_last.pth')
        if eph%eval_every == 0:
            curr_score = eval_BayesCap(
                NetC,
                NetG,
                eval_loader,
                device=device,
                dtype=dtype,
                task=task,
            )
            print('current score: {} | Last best score: {}'.format(curr_score, score))
            if curr_score >= score:
                score = curr_score
                torch.save(NetC.state_dict(), ckpt_path+'_best.pth')
        optim_scheduler.step()
|
| 602 |
+
|
| 603 |
+
#### get different uncertainty maps
|
| 604 |
+
def get_uncer_BayesCap(
    NetC,
    NetG,
    xin,
    task=None,
    xMask=None,
):
    """Predict a per-pixel uncertainty map with the BayesCap head.

    The generalized-Gaussian parameters (alpha, beta) produced by ``NetC``
    are converted to a variance map:
    var = (1/alpha)^2 * Gamma(3/beta) / Gamma(1/beta).

    Args:
        NetC: BayesCap head returning (mu, alpha, beta).
        NetG: task network; for 'inpainting' it is called as NetG(xin, xMask).
        xin: input batch.
        task: None or 'inpainting'.
        xMask: inpainting mask (only used when task == 'inpainting').

    Returns:
        CPU tensor with the per-pixel variance estimate.
    """
    with torch.no_grad():
        if task == 'inpainting':
            _, xSR = NetG(xin, xMask)
        else:
            xSR = NetG(xin)
        xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
        inv_alpha = (1/(xSRC_alpha + 1e-5)).to('cpu').data
        beta = xSRC_beta.to('cpu').data
        # Gamma ratio evaluated through lgamma for numerical stability;
        # the 1e-2 floor keeps 3/beta and 1/beta finite.
        gamma_ratio = torch.exp(torch.lgamma(3/(beta + 1e-2))) / torch.exp(torch.lgamma(1/(beta + 1e-2)))
        xSRvar = (inv_alpha**2) * gamma_ratio

    return xSRvar
|
| 622 |
+
|
| 623 |
+
def get_uncer_TTDAp(
    NetG,
    xin,
    p_mag=0.05,
    num_runs=50,
    task=None,
    xMask=None,
):
    """Estimate uncertainty via test-time data augmentation (input noise).

    Runs ``NetG`` ``num_runs`` times on noise-perturbed copies of ``xin``
    (noise scale = p_mag * xin.max()) and returns the per-pixel variance of
    the outputs, averaged over channels.

    Args:
        NetG: task network; for 'inpainting' it is called as NetG(x, xMask).
        xin: input batch (assumed batch size 1 for the output reshape).
        p_mag: relative magnitude of the additive Gaussian perturbation.
        num_runs: number of stochastic forward passes.
        task: None or 'inpainting'.
        xMask: inpainting mask (only used when task == 'inpainting').

    Returns:
        Variance map of shape (1, 1, H, W).
    """
    samples = []
    with torch.no_grad():
        for _ in range(num_runs):
            perturbed = xin + p_mag*xin.max()*torch.randn_like(xin)
            if task == 'inpainting':
                _, out = NetG(perturbed, xMask)
            else:
                out = NetG(perturbed)
            samples.append(out)
    stacked = torch.cat(samples, dim=0)
    # Variance across runs, then mean across channels -> (1, 1, H, W).
    xSRvar = torch.var(stacked, dim=0).mean(dim=0).unsqueeze(0).unsqueeze(1)
    return xSRvar
|
| 642 |
+
|
| 643 |
+
def get_uncer_DO(
    NetG,
    xin,
    dop=0.2,
    num_runs=50,
    task=None,
    xMask=None,
):
    """Estimate uncertainty via Monte-Carlo dropout.

    Runs ``NetG`` ``num_runs`` times with dropout probability ``dop`` and
    returns the per-pixel variance of the outputs, averaged over channels.

    Args:
        NetG: task network accepting a ``dop`` keyword; for 'inpainting' it
            is called as NetG(xin, xMask, dop=dop).
        xin: input batch (assumed batch size 1 for the output reshape).
        dop: dropout probability used at inference time.
        num_runs: number of stochastic forward passes.
        task: None or 'inpainting'.
        xMask: inpainting mask (only used when task == 'inpainting').

    Returns:
        Variance map of shape (1, 1, H, W).
    """
    samples = []
    with torch.no_grad():
        for _ in range(num_runs):
            if task == 'inpainting':
                _, out = NetG(xin, xMask, dop=dop)
            else:
                out = NetG(xin, dop=dop)
            samples.append(out)
    stacked = torch.cat(samples, dim=0)
    # Variance across runs, then mean across channels -> (1, 1, H, W).
    xSRvar = torch.var(stacked, dim=0).mean(dim=0).unsqueeze(0).unsqueeze(1)
    return xSRvar
|
| 662 |
+
|
| 663 |
+
################### Different eval functions
|
| 664 |
+
|
| 665 |
+
def eval_BayesCap(
    NetC,
    NetG,
    eval_loader,
    device='cuda',
    dtype=torch.cuda.FloatTensor,
    task=None,
    xMask=None,
):
    """Validate a BayesCap head: prints SSIM/PSNR/MSE/MAE and shows examples.

    Args:
        NetC: BayesCap head returning (mu, alpha, beta).
        NetG: task network (SR / inpainting / depth).
        eval_loader: (LR, HR) batch loader.
        device: torch device string.
        dtype: tensor type class used to cast batches.
        task: None, 'inpainting' or 'depth' — selects the NetG call signature.
        xMask: optional fixed inpainting mask; a random one is drawn if None.

    Returns:
        Mean SSIM between NetC's mu and the target over the whole loader.
    """
    NetC.to(device)
    NetC.eval()
    NetG.to(device)
    NetG.eval()

    mean_ssim = 0
    mean_psnr = 0
    mean_mse = 0
    mean_mae = 0
    num_imgs = 0
    list_error = []
    list_var = []
    with tqdm(eval_loader, unit='batch') as tepoch:
        for (idx, batch) in enumerate(tepoch):
            tepoch.set_description('Validating ...')
            ##
            xLR, xHR = batch[0].to(device), batch[1].to(device)
            xLR, xHR = xLR.type(dtype), xHR.type(dtype)
            if task == 'inpainting':
                # BUG FIX: was ``if xMask==None`` — ``==`` on a tensor is an
                # elementwise comparison, not an identity test; use ``is None``.
                if xMask is None:
                    xMask = random_mask(xLR.shape[0], (xLR.shape[2], xLR.shape[3]))
                    xMask = xMask.to(device).type(dtype)
                else:
                    xMask = xMask.to(device).type(dtype)
            # pass them through the network
            with torch.no_grad():
                if task == 'inpainting':
                    _, xSR = NetG(xLR, xMask)
                elif task == 'depth':
                    xSR = NetG(xLR)[("disp", 0)]
                else:
                    xSR = NetG(xLR)
                xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
            # Generalized-Gaussian variance: (1/alpha)^2 * G(3/beta)/G(1/beta).
            a_map = (1/(xSRC_alpha + 1e-5)).to('cpu').data
            b_map = xSRC_beta.to('cpu').data
            xSRvar = (a_map**2)*(torch.exp(torch.lgamma(3/(b_map + 1e-2)))/torch.exp(torch.lgamma(1/(b_map + 1e-2))))
            n_batch = xSRC_mu.shape[0]
            if task == 'depth':
                # No GT depth available here; compare against NetG's output.
                xHR = xSR
            for j in range(n_batch):
                num_imgs += 1
                mean_ssim += img_ssim(xSRC_mu[j], xHR[j])
                mean_psnr += img_psnr(xSRC_mu[j], xHR[j])
                mean_mse += img_mse(xSRC_mu[j], xHR[j])
                mean_mae += img_mae(xSRC_mu[j], xHR[j])

                show_SR_w_uncer(xLR[j], xHR[j], xSR[j], xSRvar[j])

                error_map = torch.mean(torch.pow(torch.abs(xSR[j]-xHR[j]),2), dim=0).to('cpu').data.reshape(-1)
                var_map = xSRvar[j].to('cpu').data.reshape(-1)
                list_error.extend(list(error_map.numpy()))
                list_var.extend(list(var_map.numpy()))
    ##
    mean_ssim /= num_imgs
    mean_psnr /= num_imgs
    mean_mse /= num_imgs
    mean_mae /= num_imgs
    print(
        'Avg. SSIM: {} | Avg. PSNR: {} | Avg. MSE: {} | Avg. MAE: {}'.format
        (
            mean_ssim, mean_psnr, mean_mse, mean_mae
        )
    )
    # print(len(list_error), len(list_var))
    # print('UCE: ', get_UCE(list_error[::10], list_var[::10], num_bins=500)[1])
    # print('C.Coeff: ', np.corrcoef(np.array(list_error[::10]), np.array(list_var[::10])))
    return mean_ssim
|
| 741 |
+
|
| 742 |
+
def eval_TTDA_p(
    NetG,
    eval_loader,
    device='cuda',
    dtype=torch.cuda.FloatTensor,
    p_mag=0.05,
    num_runs=50,
    task = None,
    xMask = None,
):
    # Validate a plain generator while estimating uncertainty by test-time
    # data augmentation: NetG is run num_runs times on noise-perturbed inputs
    # and the output variance is visualized next to the clean prediction.
    # Prints avg SSIM/PSNR/MSE/MAE and returns the mean SSIM.
    # NOTE(review): xSRvar's reshape to (1, 1, H, W) assumes batch size 1.
    NetG.to(device)
    NetG.eval()

    mean_ssim = 0
    mean_psnr = 0
    mean_mse = 0
    mean_mae = 0
    num_imgs = 0
    with tqdm(eval_loader, unit='batch') as tepoch:
        for (idx, batch) in enumerate(tepoch):
            tepoch.set_description('Validating ...')
            ##
            xLR, xHR = batch[0].to(device), batch[1].to(device)
            xLR, xHR = xLR.type(dtype), xHR.type(dtype)
            # pass them through the network
            list_xSR = []
            with torch.no_grad():
                # Clean forward pass (used for the reported metrics) ...
                if task=='inpainting':
                    _, xSR = NetG(xLR, xMask)
                else:
                    xSR = NetG(xLR)
                # ... plus num_runs perturbed passes for the variance estimate.
                for z in range(num_runs):
                    xSRz = NetG(xLR+p_mag*xLR.max()*torch.randn_like(xLR))
                    list_xSR.append(xSRz)
            xSRmean = torch.mean(torch.cat(list_xSR, dim=0), dim=0).unsqueeze(0)
            # Variance across runs, averaged over channels -> (1, 1, H, W).
            xSRvar = torch.mean(torch.var(torch.cat(list_xSR, dim=0), dim=0), dim=0).unsqueeze(0).unsqueeze(1)
            n_batch = xSR.shape[0]
            for j in range(n_batch):
                num_imgs += 1
                # Metrics are computed on the clean (unperturbed) output.
                mean_ssim += img_ssim(xSR[j], xHR[j])
                mean_psnr += img_psnr(xSR[j], xHR[j])
                mean_mse += img_mse(xSR[j], xHR[j])
                mean_mae += img_mae(xSR[j], xHR[j])

                show_SR_w_uncer(xLR[j], xHR[j], xSR[j], xSRvar[j])

    mean_ssim /= num_imgs
    mean_psnr /= num_imgs
    mean_mse /= num_imgs
    mean_mae /= num_imgs
    print(
        'Avg. SSIM: {} | Avg. PSNR: {} | Avg. MSE: {} | Avg. MAE: {}'.format
        (
            mean_ssim, mean_psnr, mean_mse, mean_mae
        )
    )

    return mean_ssim
|
| 800 |
+
|
| 801 |
+
def eval_DO(
    NetG,
    eval_loader,
    device='cuda',
    dtype=torch.cuda.FloatTensor,
    dop=0.2,
    num_runs=50,
    task=None,
    xMask=None,
):
    # Validate a generator while estimating uncertainty by Monte-Carlo
    # dropout: NetG is run num_runs times with dropout probability dop and
    # the output variance is visualized next to the clean prediction.
    # Prints avg SSIM/PSNR/MSE/MAE and returns the mean SSIM.
    # NOTE(review): xSRvar's reshape to (1, 1, H, W) assumes batch size 1.
    NetG.to(device)
    NetG.eval()

    mean_ssim = 0
    mean_psnr = 0
    mean_mse = 0
    mean_mae = 0
    num_imgs = 0
    with tqdm(eval_loader, unit='batch') as tepoch:
        for (idx, batch) in enumerate(tepoch):
            tepoch.set_description('Validating ...')
            ##
            xLR, xHR = batch[0].to(device), batch[1].to(device)
            xLR, xHR = xLR.type(dtype), xHR.type(dtype)
            # pass them through the network
            list_xSR = []
            with torch.no_grad():
                # Clean forward pass (used for the reported metrics) ...
                if task == 'inpainting':
                    _, xSR = NetG(xLR, xMask)
                else:
                    xSR = NetG(xLR)
                # ... plus num_runs dropout-enabled passes for the variance.
                for z in range(num_runs):
                    xSRz = NetG(xLR, dop=dop)
                    list_xSR.append(xSRz)
            xSRmean = torch.mean(torch.cat(list_xSR, dim=0), dim=0).unsqueeze(0)
            # Variance across runs, averaged over channels -> (1, 1, H, W).
            xSRvar = torch.mean(torch.var(torch.cat(list_xSR, dim=0), dim=0), dim=0).unsqueeze(0).unsqueeze(1)
            n_batch = xSR.shape[0]
            for j in range(n_batch):
                num_imgs += 1
                # Metrics are computed on the clean (no-dropout) output.
                mean_ssim += img_ssim(xSR[j], xHR[j])
                mean_psnr += img_psnr(xSR[j], xHR[j])
                mean_mse += img_mse(xSR[j], xHR[j])
                mean_mae += img_mae(xSR[j], xHR[j])

                show_SR_w_uncer(xLR[j], xHR[j], xSR[j], xSRvar[j])
            ##
    mean_ssim /= num_imgs
    mean_psnr /= num_imgs
    mean_mse /= num_imgs
    mean_mae /= num_imgs
    print(
        'Avg. SSIM: {} | Avg. PSNR: {} | Avg. MSE: {} | Avg. MAE: {}'.format
        (
            mean_ssim, mean_psnr, mean_mse, mean_mae
        )
    )

    return mean_ssim
|
| 859 |
+
|
| 860 |
+
|
| 861 |
+
############### compare all function
|
| 862 |
+
def compare_all(
    NetC,
    NetG,
    eval_loader,
    p_mag = 0.05,
    dop = 0.2,
    num_runs = 100,
    device='cuda',
    dtype=torch.cuda.FloatTensor,
    task=None,
):
    # Side-by-side visual comparison of the three uncertainty estimators
    # (TTDA input-perturbation, MC dropout, BayesCap) on every eval batch.
    # Purely for visualization: prints shapes and shows figures, returns None.
    # task selects both the NetG call signature and the display style:
    # 's' (SR), 'd' (depth), 'inpainting', 'm' (medical/grayscale).
    NetC.to(device)
    NetC.eval()
    NetG.to(device)
    NetG.eval()

    with tqdm(eval_loader, unit='batch') as tepoch:
        for (idx, batch) in enumerate(tepoch):
            tepoch.set_description('Comparing ...')
            ##
            xLR, xHR = batch[0].to(device), batch[1].to(device)
            xLR, xHR = xLR.type(dtype), xHR.type(dtype)
            if task == 'inpainting':
                # Fresh random mask per batch.
                xMask = random_mask(xLR.shape[0], (xLR.shape[2], xLR.shape[3]))
                xMask = xMask.to(device).type(dtype)
            # pass them through the network
            with torch.no_grad():
                if task == 'inpainting':
                    _, xSR = NetG(xLR, xMask)
                else:
                    xSR = NetG(xLR)
                xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)

            # Three uncertainty maps from the three estimators.
            if task == 'inpainting':
                xSRvar1 = get_uncer_TTDAp(NetG, xLR, p_mag=p_mag, num_runs=num_runs, task='inpainting', xMask=xMask)
                xSRvar2 = get_uncer_DO(NetG, xLR, dop=dop, num_runs=num_runs, task='inpainting', xMask=xMask)
                xSRvar3 = get_uncer_BayesCap(NetC, NetG, xLR, task='inpainting', xMask=xMask)
            else:
                xSRvar1 = get_uncer_TTDAp(NetG, xLR, p_mag=p_mag, num_runs=num_runs)
                xSRvar2 = get_uncer_DO(NetG, xLR, dop=dop, num_runs=num_runs)
                xSRvar3 = get_uncer_BayesCap(NetC, NetG, xLR)

            print('bdg', xSRvar1.shape, xSRvar2.shape, xSRvar3.shape)

            n_batch = xSR.shape[0]
            for j in range(n_batch):
                # NOTE(review): the power/scale transforms below look like
                # hand-tuned display gammas for the figures — confirm before
                # reusing them quantitatively.
                if task=='s':
                    show_SR_w_err(xLR[j], xHR[j], xSR[j])
                    show_uncer4(xSRvar1[j], torch.sqrt(xSRvar1[j]), torch.pow(xSRvar1[j], 0.48), torch.pow(xSRvar1[j], 0.42))
                    show_uncer4(xSRvar2[j], torch.sqrt(xSRvar2[j]), torch.pow(xSRvar3[j], 1.5), xSRvar3[j])
                if task=='d':
                    show_SR_w_err(xLR[j], xHR[j], 0.5*xSR[j]+0.5*xHR[j])
                    show_uncer4(xSRvar1[j], torch.sqrt(xSRvar1[j]), torch.pow(xSRvar1[j], 0.48), torch.pow(xSRvar1[j], 0.42))
                    show_uncer4(xSRvar2[j], torch.sqrt(xSRvar2[j]), torch.pow(xSRvar3[j], 0.8), xSRvar3[j])
                if task=='inpainting':
                    # LR is shown with the hole blanked out.
                    show_SR_w_err(xLR[j]*(1-xMask[j]), xHR[j], xSR[j], elim=(0,0.25), task='inpainting', xMask=xMask[j])
                    show_uncer4(xSRvar1[j], torch.sqrt(xSRvar1[j]), torch.pow(xSRvar1[j], 0.45), torch.pow(xSRvar1[j], 0.4))
                    show_uncer4(xSRvar2[j], torch.sqrt(xSRvar2[j]), torch.pow(xSRvar3[j], 0.8), xSRvar3[j])
                if task=='m':
                    show_SR_w_err(xLR[j], xHR[j], xSR[j], elim=(0,0.04), task='m')
                    show_uncer4(0.4*xSRvar1[j]+0.6*xSRvar2[j], torch.sqrt(xSRvar1[j]), torch.pow(xSRvar1[j], 0.48), torch.pow(xSRvar1[j], 0.42), ulim=(0.02,0.15))
                    show_uncer4(xSRvar2[j], torch.sqrt(xSRvar2[j]), torch.pow(xSRvar3[j], 1.5), xSRvar3[j], ulim=(0.02,0.15))
|
| 924 |
+
|
| 925 |
+
|
| 926 |
+
################# Degrading Identity
|
| 927 |
+
def degrage_BayesCap_p(
    NetC,
    NetG,
    eval_loader,
    device='cuda',
    dtype=torch.cuda.FloatTensor,
    num_runs=50,
):
    # "Degrading identity" experiment: feed NetC increasingly noisy copies of
    # the generator output (noise magnitude p_mag) and track how the
    # reconstruction quality (SSIM/PSNR vs. the clean xSR) and the
    # uncertainty calibration (UCE) degrade. Produces summary plots.
    NetC.to(device)
    NetC.eval()
    NetG.to(device)
    NetG.eval()

    # Sweep of input-perturbation magnitudes.
    p_mag_list = [0, 0.05, 0.1, 0.15, 0.2]
    list_s = []
    list_p = []
    list_u1 = []
    list_u2 = []
    list_c = []
    for p_mag in p_mag_list:
        mean_ssim = 0
        mean_psnr = 0
        mean_mse = 0
        mean_mae = 0
        num_imgs = 0
        list_error = []
        list_error2 = []
        list_var = []

        with tqdm(eval_loader, unit='batch') as tepoch:
            for (idx, batch) in enumerate(tepoch):
                tepoch.set_description('Validating ...')
                ##
                xLR, xHR = batch[0].to(device), batch[1].to(device)
                xLR, xHR = xLR.type(dtype), xHR.type(dtype)
                # pass them through the network
                with torch.no_grad():
                    xSR = NetG(xLR)
                    # NetC sees a noise-corrupted version of the SR output.
                    xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR + p_mag*xSR.max()*torch.randn_like(xSR))
                # Generalized-Gaussian variance: (1/alpha)^2 * G(3/b)/G(1/b).
                a_map = (1/(xSRC_alpha + 1e-5)).to('cpu').data
                b_map = xSRC_beta.to('cpu').data
                xSRvar = (a_map**2)*(torch.exp(torch.lgamma(3/(b_map + 1e-2)))/torch.exp(torch.lgamma(1/(b_map + 1e-2))))
                n_batch = xSRC_mu.shape[0]
                for j in range(n_batch):
                    num_imgs += 1
                    # Identity quality: NetC's mu vs. the clean SR output.
                    mean_ssim += img_ssim(xSRC_mu[j], xSR[j])
                    mean_psnr += img_psnr(xSRC_mu[j], xSR[j])
                    mean_mse += img_mse(xSRC_mu[j], xSR[j])
                    mean_mae += img_mae(xSRC_mu[j], xSR[j])

                    # Two error references for UCE: SR-vs-GT and mu-vs-GT.
                    error_map = torch.mean(torch.pow(torch.abs(xSR[j]-xHR[j]),2), dim=0).to('cpu').data.reshape(-1)
                    error_map2 = torch.mean(torch.pow(torch.abs(xSRC_mu[j]-xHR[j]),2), dim=0).to('cpu').data.reshape(-1)
                    var_map = xSRvar[j].to('cpu').data.reshape(-1)
                    list_error.extend(list(error_map.numpy()))
                    list_error2.extend(list(error_map2.numpy()))
                    list_var.extend(list(var_map.numpy()))
            ##
        mean_ssim /= num_imgs
        mean_psnr /= num_imgs
        mean_mse /= num_imgs
        mean_mae /= num_imgs
        print(
            'Avg. SSIM: {} | Avg. PSNR: {} | Avg. MSE: {} | Avg. MAE: {}'.format
            (
                mean_ssim, mean_psnr, mean_mse, mean_mae
            )
        )
        # UCE on a 1/100 subsample of the pixels to keep the binning fast.
        uce1 = get_UCE(list_error[::100], list_var[::100], num_bins=200)[1]
        uce2 = get_UCE(list_error2[::100], list_var[::100], num_bins=200)[1]
        print('UCE1: ', uce1)
        print('UCE2: ', uce2)
        list_s.append(mean_ssim.item())
        list_p.append(mean_psnr.item())
        list_u1.append(uce1)
        list_u2.append(uce2)

    # Quick per-metric curves over the perturbation sweep.
    plt.plot(list_s)
    plt.show()
    plt.plot(list_p)
    plt.show()

    plt.plot(list_u1, label='wrt SR output')
    plt.plot(list_u2, label='wrt BayesCap output')
    plt.legend()
    plt.show()

    # Combined figure: SSIM (left axis) and UCE (right axis) vs. p_mag.
    sns.set_style('darkgrid')
    fig,ax = plt.subplots()
    # make a plot
    ax.plot(p_mag_list, list_s, color="red", marker="o")
    # set x-axis label
    ax.set_xlabel("Reducing faithfulness of BayesCap Reconstruction",fontsize=10)
    # set y-axis label
    ax.set_ylabel("SSIM btwn BayesCap and SRGAN outputs", color="red",fontsize=10)

    # twin object for two different y-axis on the sample plot
    ax2=ax.twinx()
    # make a plot with different y-axis using second axis object
    ax2.plot(p_mag_list, list_u1, color="blue", marker="o", label='UCE wrt to error btwn SRGAN output and GT')
    ax2.plot(p_mag_list, list_u2, color="orange", marker="o", label='UCE wrt to error btwn BayesCap output and GT')
    ax2.set_ylabel("UCE", color="green", fontsize=10)
    plt.legend(fontsize=10)
    plt.tight_layout()
    plt.show()
|
| 1031 |
+
|
| 1032 |
+
################# DeepFill_v2
|
| 1033 |
+
|
| 1034 |
+
# ----------------------------------------
|
| 1035 |
+
# PATH processing
|
| 1036 |
+
# ----------------------------------------
|
| 1037 |
+
def text_readlines(filename):
    """Read a text file and return its lines without trailing newlines.

    Returns an empty list if the file cannot be opened.
    """
    try:
        # 'with' guarantees the handle is closed even if readlines() raises;
        # the original opened the file manually and leaked it on read errors.
        with open(filename, 'r') as file:
            content = file.readlines()
    except IOError:
        return []
    # Strip only the trailing newline.  The original sliced off the last
    # character unconditionally, which corrupted a final line lacking '\n'.
    return [line.rstrip('\n') for line in content]
|
| 1050 |
+
|
| 1051 |
+
def savetxt(name, loss_log):
    """Persist a sequence of loss values to a text file via NumPy."""
    np.savetxt(name, np.array(loss_log))
|
| 1054 |
+
|
| 1055 |
+
def get_files(path):
    """Walk a directory tree and return the full path of every file in it."""
    return [
        os.path.join(root, name)
        for root, _dirs, names in os.walk(path)
        for name in names
    ]
|
| 1062 |
+
|
| 1063 |
+
def get_names(path):
    """Walk a directory tree and return the bare file names (no directories)."""
    return [
        name
        for _root, _dirs, names in os.walk(path)
        for name in names
    ]
|
| 1070 |
+
|
| 1071 |
+
def text_save(content, filename, mode = 'a'):
    """Write each element of `content` to `filename`, one per line.

    Args:
        content: iterable of items; each is str()-converted.
        filename: destination path.
        mode: file open mode (default 'a' appends).
    """
    # 'with' closes the handle even if a write raises; the original opened
    # and closed the file manually and leaked it on error.
    with open(filename, mode) as file:
        for item in content:
            file.write(str(item) + '\n')
|
| 1078 |
+
|
| 1079 |
+
def check_path(path):
    """Create `path` (including parents) if it does not already exist.

    Uses exist_ok=True to avoid the race where another process creates the
    directory between an exists() check and makedirs() (the original's
    check-then-create pattern could raise FileExistsError).
    """
    os.makedirs(path, exist_ok=True)
|
| 1082 |
+
|
| 1083 |
+
# ----------------------------------------
|
| 1084 |
+
# Validation and Sample at training
|
| 1085 |
+
# ----------------------------------------
|
| 1086 |
+
def save_sample_png(sample_folder, sample_name, img_list, name_list, pixel_max_cnt = 255):
    """Save each NCHW tensor in `img_list` as '<sample_name>_<name>.jpg'.

    Tensors are assumed sigmoid-activated in [0, 1]; only the first image of
    each batch is written.  The caller's tensors are left untouched.
    """
    for idx, img in enumerate(img_list):
        # Recover pixel range: sigmoid output in [0, 1] -> [0, 255].
        scaled = img * 255
        # Clone before the NCHW -> HWC conversion so `img` is not mutated.
        array = scaled.clone().data.permute(0, 2, 3, 1)[0, :, :, :].cpu().numpy()
        array = np.clip(array, 0, pixel_max_cnt).astype(np.uint8)
        # OpenCV writes BGR; the tensors are RGB.
        array = cv2.cvtColor(array, cv2.COLOR_RGB2BGR)
        save_img_path = os.path.join(sample_folder, sample_name + '_' + name_list[idx] + '.jpg')
        cv2.imwrite(save_img_path, array)
|
| 1101 |
+
|
| 1102 |
+
def psnr(pred, target, pixel_max_cnt = 255):
    """Peak signal-to-noise ratio (dB) between two tensors."""
    squared_err = (target - pred) ** 2
    rmse = torch.mean(squared_err).item() ** 0.5
    return 20 * np.log10(pixel_max_cnt / rmse)
|
| 1107 |
+
|
| 1108 |
+
def grey_psnr(pred, target, pixel_max_cnt = 255):
    """PSNR (dB) after collapsing the leading (channel) dim by summation."""
    pred_sum = torch.sum(pred, dim=0)
    target_sum = torch.sum(target, dim=0)
    rmse = torch.mean((target_sum - pred_sum) ** 2).item() ** 0.5
    # Peak value is scaled by 3 because three channels were summed.
    return 20 * np.log10(pixel_max_cnt * 3 / rmse)
|
| 1115 |
+
|
| 1116 |
+
def ssim(pred, target):
    """SSIM between the first images of two NCHW tensors (HWC, multichannel)."""
    pred = pred.clone().data.permute(0, 2, 3, 1).cpu().numpy()[0]
    target = target.clone().data.permute(0, 2, 3, 1).cpu().numpy()[0]
    # skimage.measure.compare_ssim was removed in scikit-image 0.18;
    # skimage.metrics.structural_similarity is the maintained replacement.
    try:
        from skimage.metrics import structural_similarity
    except ImportError:  # very old scikit-image
        from skimage.measure import compare_ssim as structural_similarity
    try:
        # channel_axis is the modern spelling (scikit-image >= 0.19).
        return structural_similarity(target, pred, channel_axis=2)
    except TypeError:
        # older scikit-image only knows the deprecated multichannel flag
        return structural_similarity(target, pred, multichannel=True)
|
| 1123 |
+
|
| 1124 |
+
## for contextual attention
|
| 1125 |
+
|
| 1126 |
+
def extract_image_patches(images, ksizes, strides, rates, padding='same'):
    """
    Extract sliding-window patches from images into the C output dimension.
    :param images: [batch, channels, in_rows, in_cols] 4-D tensor
    :param ksizes: [ksize_rows, ksize_cols] window size per spatial dim
    :param strides: [stride_rows, stride_cols]
    :param rates: [dilation_rows, dilation_cols]
    :param padding: 'same' (TF-style zero pad) or 'valid' (no pad)
    :return: [N, C*k*k, L] tensor, L = number of extracted blocks
    """
    assert len(images.size()) == 4
    assert padding in ['same', 'valid']

    if padding == 'same':
        images = same_padding(images, ksizes, strides, rates)
    elif padding != 'valid':
        # Unreachable given the assert above; kept as a defensive guard.
        raise NotImplementedError(
            'Unsupported padding type: {}. Only "same" or "valid" are supported.'.format(padding))

    unfold = torch.nn.Unfold(
        kernel_size=ksizes, dilation=rates, padding=0, stride=strides)
    return unfold(images)
|
| 1155 |
+
|
| 1156 |
+
def same_padding(images, ksizes, strides, rates):
    """Zero-pad NCHW `images` so a subsequent unfold matches TF 'SAME' output."""
    assert len(images.size()) == 4
    _, _, rows, cols = images.size()
    out_rows = (rows + strides[0] - 1) // strides[0]
    out_cols = (cols + strides[1] - 1) // strides[1]
    # Effective kernel extent after dilation.
    eff_k_row = (ksizes[0] - 1) * rates[0] + 1
    eff_k_col = (ksizes[1] - 1) * rates[1] + 1
    pad_rows = max(0, (out_rows - 1) * strides[0] + eff_k_row - rows)
    pad_cols = max(0, (out_cols - 1) * strides[1] + eff_k_col - cols)
    # Split asymmetric padding: the extra pixel goes on the bottom/right.
    top = pad_rows // 2
    left = pad_cols // 2
    paddings = (left, pad_cols - left, top, pad_rows - top)
    return torch.nn.ZeroPad2d(paddings)(images)
|
| 1173 |
+
|
| 1174 |
+
def reduce_mean(x, axis=None, keepdim=False):
    """TF-style reduce_mean: average `x` over the given axes (all if None).

    Fixes the original's `if not axis:` test, which treated axis=0 (falsy)
    as "reduce everything".  An int axis is also accepted now; the original
    crashed in sorted() for non-iterable axes.
    """
    if axis is None:
        axis = range(len(x.shape))
    elif isinstance(axis, int):
        axis = [axis]
    # Reduce highest dim first so remaining indices stay valid.
    for i in sorted(axis, reverse=True):
        x = torch.mean(x, dim=i, keepdim=keepdim)
    return x
|
| 1180 |
+
|
| 1181 |
+
|
| 1182 |
+
def reduce_std(x, axis=None, keepdim=False):
    """TF-style reduce_std: standard deviation over the given axes (all if None).

    Fixes the original's `if not axis:` test, which treated axis=0 (falsy)
    as "reduce everything".  An int axis is also accepted now.
    """
    if axis is None:
        axis = range(len(x.shape))
    elif isinstance(axis, int):
        axis = [axis]
    # Reduce highest dim first so remaining indices stay valid.
    for i in sorted(axis, reverse=True):
        x = torch.std(x, dim=i, keepdim=keepdim)
    return x
|
| 1188 |
+
|
| 1189 |
+
|
| 1190 |
+
def reduce_sum(x, axis=None, keepdim=False):
    """TF-style reduce_sum: sum `x` over the given axes (all if None).

    Fixes the original's `if not axis:` test, which treated axis=0 (falsy)
    as "reduce everything".  An int axis is also accepted now.
    """
    if axis is None:
        axis = range(len(x.shape))
    elif isinstance(axis, int):
        axis = [axis]
    # Reduce highest dim first so remaining indices stay valid.
    for i in sorted(axis, reverse=True):
        x = torch.sum(x, dim=i, keepdim=keepdim)
    return x
|
| 1196 |
+
|
| 1197 |
+
def random_mask(num_batch=1, mask_shape=(256,256)):
    # Build `num_batch` random inpainting masks of shape (N, 1, H, W); each
    # mask is the union of one random rectangle and random brush strokes.
    # Uses both `random` and `np.random` — seed both for reproducibility.
    list_mask = []
    for _ in range(num_batch):
        # rectangle mask
        image_height = mask_shape[0]
        image_width = mask_shape[1]
        max_delta_height = image_height//8
        max_delta_width = image_width//8
        height = image_height//4
        width = image_width//4
        max_t = image_height - height
        max_l = image_width - width
        # (t, l) is the random top-left corner of the rectangle.
        t = random.randint(0, max_t)
        l = random.randint(0, max_l)
        # bbox = (t, l, height, width)
        # (h, w) randomly shrink the rectangle inward on each side.
        h = random.randint(0, max_delta_height//2)
        w = random.randint(0, max_delta_width//2)
        mask = torch.zeros((1, 1, image_height, image_width))
        mask[:, :, t+h:t+height-h, l+w:l+width-w] = 1
        rect_mask = mask

        # brush mask
        min_num_vertex = 4
        max_num_vertex = 12
        mean_angle = 2 * math.pi / 5
        angle_range = 2 * math.pi / 15
        min_width = 12
        max_width = 40
        H, W = image_height, image_width
        average_radius = math.sqrt(H*H+W*W) / 8
        # Greyscale PIL canvas; strokes are drawn with value 255.
        mask = Image.new('L', (W, H), 0)

        # Draw 1–3 polyline strokes.
        for _ in range(np.random.randint(1, 4)):
            num_vertex = np.random.randint(min_num_vertex, max_num_vertex)
            angle_min = mean_angle - np.random.uniform(0, angle_range)
            angle_max = mean_angle + np.random.uniform(0, angle_range)
            angles = []
            vertex = []
            # Alternate turn direction so the stroke zig-zags.
            for i in range(num_vertex):
                if i % 2 == 0:
                    angles.append(2*math.pi - np.random.uniform(angle_min, angle_max))
                else:
                    angles.append(np.random.uniform(angle_min, angle_max))

            # NOTE(review): PIL's .size is (width, height); this binds
            # h = W and w = H — confirm whether the swap is intentional
            # (harmless for square masks, which is the default).
            h, w = mask.size
            # Random starting vertex, then a random walk of segments.
            vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h))))
            for i in range(num_vertex):
                r = np.clip(
                    np.random.normal(loc=average_radius, scale=average_radius//2),
                    0, 2*average_radius)
                new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w)
                new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h)
                vertex.append((int(new_x), int(new_y)))

            draw = ImageDraw.Draw(mask)
            width = int(np.random.uniform(min_width, max_width))
            draw.line(vertex, fill=255, width=width)
            # Round caps: a filled disc at every vertex.
            for v in vertex:
                draw.ellipse((v[0] - width//2,
                              v[1] - width//2,
                              v[0] + width//2,
                              v[1] + width//2),
                             fill=255)

        # NOTE(review): Image.transpose returns a NEW image; the result is
        # discarded here, so these random flips have no effect — likely a bug.
        if np.random.normal() > 0:
            mask.transpose(Image.FLIP_LEFT_RIGHT)
        if np.random.normal() > 0:
            mask.transpose(Image.FLIP_TOP_BOTTOM)

        # PIL 'L' image -> float tensor in [0, 1], shaped (1, 1, H, W).
        mask = transforms.ToTensor()(mask)
        mask = mask.reshape((1, 1, H, W))
        brush_mask = mask

        # Union of the rectangle and brush masks via per-pixel max.
        mask = torch.cat([rect_mask, brush_mask], dim=1).max(dim=1, keepdim=True)[0]
        list_mask.append(mask)
    mask = torch.cat(list_mask, dim=0)
    return mask
|
utils.py
ADDED
|
@@ -0,0 +1,1304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
from typing import Any, Optional
|
| 3 |
+
import numpy as np
|
| 4 |
+
import os
|
| 5 |
+
import cv2
|
| 6 |
+
from glob import glob
|
| 7 |
+
from PIL import Image, ImageDraw
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
import kornia
|
| 10 |
+
import matplotlib.pyplot as plt
|
| 11 |
+
import seaborn as sns
|
| 12 |
+
import albumentations as albu
|
| 13 |
+
import functools
|
| 14 |
+
import math
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
from torch import Tensor
|
| 19 |
+
import torchvision as tv
|
| 20 |
+
import torchvision.models as models
|
| 21 |
+
from torchvision import transforms
|
| 22 |
+
from torchvision.transforms import functional as F
|
| 23 |
+
from losses import TempCombLoss
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
######## for loading checkpoint from googledrive
# Maps a pretrained checkpoint filename to its Google-Drive download URL;
# consumed by ensure_checkpoint_exists() below.
google_drive_paths = {
    "BayesCap_SRGAN.pth": "https://drive.google.com/uc?id=1d_5j1f8-vN79htZTfRUqP1ddHZIYsNvL",
    "BayesCap_ckpt.pth": "https://drive.google.com/uc?id=1Vg1r6gKgQ1J3M51n6BeKXYS8auT9NhA9",
}
|
| 31 |
+
|
| 32 |
+
def ensure_checkpoint_exists(model_weights_filename):
    """Download a known checkpoint via gdown if it is missing locally.

    Prints a hint instead of raising when gdown is unavailable or the
    filename is not a known checkpoint.
    """
    known = model_weights_filename in google_drive_paths

    if known and not os.path.isfile(model_weights_filename):
        gdrive_url = google_drive_paths[model_weights_filename]
        try:
            from gdown import download as drive_download

            drive_download(gdrive_url, model_weights_filename, quiet=False)
        except ModuleNotFoundError:
            print(
                "gdown module not found.",
                "pip3 install gdown or, manually download the checkpoint file:",
                gdrive_url
            )

    if not known and not os.path.isfile(model_weights_filename):
        print(
            model_weights_filename,
            " not found, you may need to manually download the model weights."
        )
|
| 55 |
+
|
| 56 |
+
########### DeblurGAN function
|
| 57 |
+
def get_norm_layer(norm_type='instance'):
    """Return a constructor (functools.partial) for the requested norm layer.

    Raises NotImplementedError for any norm_type other than 'batch' or
    'instance'.
    """
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True)
    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True)
    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
|
| 65 |
+
|
| 66 |
+
def _array_to_batch(x):
|
| 67 |
+
x = np.transpose(x, (2, 0, 1))
|
| 68 |
+
x = np.expand_dims(x, 0)
|
| 69 |
+
return torch.from_numpy(x)
|
| 70 |
+
|
| 71 |
+
def get_normalize():
    """Return a callable mapping an (image, target) pair to [-1, 1] floats."""
    pipeline = albu.Compose(
        [albu.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])],
        additional_targets={'target': 'image'},
    )

    def process(a, b):
        normed = pipeline(image=a, target=b)
        return normed['image'], normed['target']

    return process
|
| 80 |
+
|
| 81 |
+
def preprocess(x: np.ndarray, mask: Optional[np.ndarray]):
    """Normalize image + mask and zero-pad both to a multiple of 32.

    Returns (map of (image, mask) tensors, original height, original width).
    """
    x, _ = get_normalize()(x, x)
    if mask is None:
        # No mask given: treat every pixel as valid.
        mask = np.ones_like(x, dtype=np.float32)
    else:
        # Binarize a 0-255 mask to {0, 1}.
        mask = np.round(mask.astype('float32') / 255)

    h, w, _ = x.shape
    block_size = 32
    # Always pads at least one full block (matches the original behavior).
    target_h = (h // block_size + 1) * block_size
    target_w = (w // block_size + 1) * block_size
    pad_width = ((0, target_h - h), (0, target_w - w), (0, 0))
    x = np.pad(x, pad_width=pad_width, mode='constant', constant_values=0)
    mask = np.pad(mask, pad_width=pad_width, mode='constant', constant_values=0)

    return map(_array_to_batch, (x, mask)), h, w
|
| 101 |
+
|
| 102 |
+
def postprocess(x: torch.Tensor) -> np.ndarray:
    """Convert a 1-image CHW batch in [-1, 1] to a uint8 HWC array."""
    image, = x  # unpack the single image from the batch dimension
    array = image.detach().cpu().float().numpy()
    array = (np.transpose(array, (1, 2, 0)) + 1) / 2.0 * 255.0
    return array.astype('uint8')
|
| 107 |
+
|
| 108 |
+
def sorted_glob(pattern):
    """Glob `pattern` and return the matches in sorted order."""
    matches = glob(pattern)
    return sorted(matches)
|
| 110 |
+
###########
|
| 111 |
+
|
| 112 |
+
def normalize(image: np.ndarray) -> np.ndarray:
    """Normalize ``OpenCV.imread`` / ``skimage.io.imread`` data.

    Args:
        image (np.ndarray): Image data in the [0, 255] range.
    Returns:
        float64 image data scaled to [0, 1].
    """
    return np.divide(image.astype(np.float64), 255.0)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def unnormalize(image: np.ndarray) -> np.ndarray:
    """Undo :func:`normalize`.

    Args:
        image (np.ndarray): Image data in the [0, 1] range.
    Returns:
        float64 image data scaled to [0, 255].
    """
    return np.multiply(image.astype(np.float64), 255.0)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def image2tensor(image: np.ndarray, range_norm: bool, half: bool) -> torch.Tensor:
    """Convert a ``PIL.Image``/array to a torch tensor.

    Args:
        image (np.ndarray): The image data read by ``PIL.Image``.
        range_norm (bool): Rescale [0, 1] data to [-1, 1].
        half (bool): Cast the result to torch.half.
    Returns:
        Normalized image tensor.
    Examples:
        >>> image = Image.open("image.bmp")
        >>> tensor_image = image2tensor(image, range_norm=False, half=False)
    """
    tensor = F.to_tensor(image)

    if range_norm:
        # [0, 1] -> [-1, 1] in place.
        tensor = tensor.mul_(2.0).sub_(1.0)

    return tensor.half() if half else tensor
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def tensor2image(tensor: torch.Tensor, range_norm: bool, half: bool) -> Any:
    """Convert a torch tensor to a uint8 HWC array usable by PIL.

    Args:
        tensor (torch.Tensor): The image to convert (NCHW or CHW).
        range_norm (bool): Rescale [-1, 1] data to [0, 1].
        half (bool): Cast to torch.half before conversion.
    Returns:
        uint8 HWC numpy array.
    Examples:
        >>> tensor = torch.randn([1, 3, 128, 128])
        >>> image = tensor2image(tensor, range_norm=False, half=False)
    Note: the in-place ops (add_/div_/mul_/squeeze_) mutate the input tensor.
    """
    if range_norm:
        tensor = tensor.add_(1.0).div_(2.0)
    if half:
        tensor = tensor.half()

    array = (
        tensor.squeeze_(0)
        .permute(1, 2, 0)
        .mul_(255)
        .clamp_(0, 255)
        .cpu()
        .numpy()
        .astype("uint8")
    )
    return array
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def convert_rgb_to_y(image: Any) -> Any:
    """Convert an RGB ndarray (HWC) or tensor (CHW) to the YCbCr Y channel.

    Args:
        image: RGB image data read by ``PIL.Image``.
    Returns:
        Y image array data.
    Note: a 4-D tensor input is squeezed in place.
    """
    if type(image) == np.ndarray:
        r, g, b = image[:, :, 0], image[:, :, 1], image[:, :, 2]
    elif type(image) == torch.Tensor:
        if len(image.shape) == 4:
            image = image.squeeze_(0)
        r, g, b = image[0, :, :], image[1, :, :], image[2, :, :]
    else:
        raise Exception("Unknown Type", type(image))
    return 16. + (64.738 * r + 129.057 * g + 25.064 * b) / 256.
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def convert_rgb_to_ycbcr(image: Any) -> Any:
    """Convert an RGB ndarray (HWC) or tensor (CHW) to YCbCr (HWC).

    Args:
        image: RGB image data read by ``PIL.Image``.
    Returns:
        YCbCr image array data, shape (H, W, 3).
    """
    if type(image) == np.ndarray:
        y = 16. + (64.738 * image[:, :, 0] + 129.057 * image[:, :, 1] + 25.064 * image[:, :, 2]) / 256.
        cb = 128. + (-37.945 * image[:, :, 0] - 74.494 * image[:, :, 1] + 112.439 * image[:, :, 2]) / 256.
        cr = 128. + (112.439 * image[:, :, 0] - 94.154 * image[:, :, 1] - 18.285 * image[:, :, 2]) / 256.
        return np.array([y, cb, cr]).transpose([1, 2, 0])
    elif type(image) == torch.Tensor:
        if len(image.shape) == 4:
            image = image.squeeze(0)
        y = 16. + (64.738 * image[0, :, :] + 129.057 * image[1, :, :] + 25.064 * image[2, :, :]) / 256.
        cb = 128. + (-37.945 * image[0, :, :] - 74.494 * image[1, :, :] + 112.439 * image[2, :, :]) / 256.
        cr = 128. + (112.439 * image[0, :, :] - 94.154 * image[1, :, :] - 18.285 * image[2, :, :]) / 256.
        # Bug fix: the original used torch.cat on 2-D planes, producing a
        # (3H, W) tensor, and then .permute(1, 2, 0) crashed (only 2 dims).
        # torch.stack builds the intended (3, H, W) -> (H, W, 3) layout.
        return torch.stack([y, cb, cr], 0).permute(1, 2, 0)
    else:
        raise Exception("Unknown Type", type(image))
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def convert_ycbcr_to_rgb(image: Any) -> Any:
    """Convert a YCbCr ndarray (HWC) or tensor (CHW) to RGB (HWC).

    Args:
        image: YCbCr image data read by ``PIL.Image``.
    Returns:
        RGB image array data, shape (H, W, 3).
    """
    if type(image) == np.ndarray:
        r = 298.082 * image[:, :, 0] / 256. + 408.583 * image[:, :, 2] / 256. - 222.921
        g = 298.082 * image[:, :, 0] / 256. - 100.291 * image[:, :, 1] / 256. - 208.120 * image[:, :, 2] / 256. + 135.576
        b = 298.082 * image[:, :, 0] / 256. + 516.412 * image[:, :, 1] / 256. - 276.836
        return np.array([r, g, b]).transpose([1, 2, 0])
    elif type(image) == torch.Tensor:
        if len(image.shape) == 4:
            image = image.squeeze(0)
        r = 298.082 * image[0, :, :] / 256. + 408.583 * image[2, :, :] / 256. - 222.921
        g = 298.082 * image[0, :, :] / 256. - 100.291 * image[1, :, :] / 256. - 208.120 * image[2, :, :] / 256. + 135.576
        b = 298.082 * image[0, :, :] / 256. + 516.412 * image[1, :, :] / 256. - 276.836
        # Bug fix: the original used torch.cat on 2-D planes, producing a
        # (3H, W) tensor, and then .permute(1, 2, 0) crashed (only 2 dims).
        # torch.stack builds the intended (3, H, W) -> (H, W, 3) layout.
        return torch.stack([r, g, b], 0).permute(1, 2, 0)
    else:
        raise Exception("Unknown Type", type(image))
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def center_crop(lr: Any, hr: Any, image_size: int, upscale_factor: int) -> [Any, Any]:
    """Crop the center `image_size` square from hr, and the matching
    (`image_size` / `upscale_factor`) region from lr.

    Args:
        lr: Low-resolution image data read by ``PIL.Image``.
        hr: High-resolution image data read by ``PIL.Image``.
        image_size (int): Side of the cropped high-resolution square.
        upscale_factor (int): Magnification factor between lr and hr.
    Returns:
        Center-cropped low-resolution and high-resolution images.
    """
    w, h = hr.size
    left = (w - image_size) // 2
    top = (h - image_size) // 2
    hr_box = (left, top, left + image_size, top + image_size)
    lr_box = tuple(edge // upscale_factor for edge in hr_box)

    return lr.crop(lr_box), hr.crop(hr_box)
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def random_crop(lr: Any, hr: Any, image_size: int, upscale_factor: int) -> [Any, Any]:
    """Crop a random `image_size` square from hr, and the matching
    (`image_size` / `upscale_factor`) region from lr.

    Args:
        lr: Low-resolution image data read by ``PIL.Image``.
        hr: High-resolution image data read by ``PIL.Image``.
        image_size (int): Side of the cropped high-resolution square.
        upscale_factor (int): Magnification factor between lr and hr.
    Returns:
        Randomly cropped low-resolution and high-resolution images.
    """
    w, h = hr.size
    # Same draw order as before (left, then top) to keep RNG streams aligned.
    left = torch.randint(0, w - image_size + 1, size=(1,)).item()
    top = torch.randint(0, h - image_size + 1, size=(1,)).item()
    hr_box = (left, top, left + image_size, top + image_size)
    lr_box = tuple(edge // upscale_factor for edge in hr_box)

    return lr.crop(lr_box), hr.crop(hr_box)
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def random_rotate(lr: Any, hr: Any, angle: int) -> [Any, Any]:
    """Rotate both images by +angle or -angle, chosen at random.

    Args:
        lr: Low-resolution image data read by ``PIL.Image``.
        hr: High-resolution image data read by ``PIL.Image``.
        angle (int): Rotation magnitude; sign is picked uniformly.
    Returns:
        Identically rotated low-resolution and high-resolution images.
    """
    signed_angle = random.choice((angle, -angle))
    lr = F.rotate(lr, signed_angle)
    hr = F.rotate(hr, signed_angle)

    return lr, hr
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def random_horizontally_flip(lr: Any, hr: Any, p=0.5) -> [Any, Any]:
    """Mirror both images horizontally when a uniform draw exceeds `p`.

    Args:
        lr: Low-resolution image data read by ``PIL.Image``.
        hr: High-resolution image data read by ``PIL.Image``.
        p (optional, float): Threshold; flip happens when rand() > p.
    Returns:
        The (possibly flipped) low-resolution and high-resolution images.
    """
    flip = torch.rand(1).item() > p
    if flip:
        lr, hr = F.hflip(lr), F.hflip(hr)

    return lr, hr
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
def random_vertically_flip(lr: Any, hr: Any, p=0.5) -> [Any, Any]:
    """Vertically flip a paired LR/HR ``PIL.Image`` with probability 1 - p.

    Args:
        lr: Low-resolution image data read by ``PIL.Image``.
        hr: High-resolution image data read by ``PIL.Image``.
        p (optional, float): rollover probability. (Default: 0.5)

    Returns:
        Low-resolution image and high-resolution image after random
        vertical flip (both flipped together, or both untouched).
    """
    flip_now = torch.rand(1).item() > p
    if flip_now:
        lr, hr = F.vflip(lr), F.vflip(hr)
    return lr, hr
def random_adjust_brightness(lr: Any, hr: Any) -> [Any, Any]:
    """Apply one random brightness gain to a paired LR/HR ``PIL.Image``.

    Args:
        lr: Low-resolution image data read by ``PIL.Image``.
        hr: High-resolution image data read by ``PIL.Image``.

    Returns:
        Low-resolution image and high-resolution image with the same
        randomly chosen brightness factor applied.
    """
    # Single gain in [0.5, 2) shared by both images so they stay consistent.
    gain = random.uniform(0.5, 2)
    return F.adjust_brightness(lr, gain), F.adjust_brightness(hr, gain)
def random_adjust_contrast(lr: Any, hr: Any) -> [Any, Any]:
    """Apply one random contrast gain to a paired LR/HR ``PIL.Image``.

    Args:
        lr: Low-resolution image data read by ``PIL.Image``.
        hr: High-resolution image data read by ``PIL.Image``.

    Returns:
        Low-resolution image and high-resolution image with the same
        randomly chosen contrast factor applied.
    """
    # Single gain in [0.5, 2) shared by both images so they stay consistent.
    gain = random.uniform(0.5, 2)
    return F.adjust_contrast(lr, gain), F.adjust_contrast(hr, gain)
#### metrics to compute -- assumes single images, i.e., tensor of 3 dims
|
| 370 |
+
def img_mae(x1, x2):
    """Mean absolute error between two image tensors."""
    return (x1 - x2).abs().mean()
def img_mse(x1, x2):
    """Mean squared error between two image tensors.

    The original squared ``torch.abs(x1-x2)``; the ``abs`` is redundant
    before an even power and is dropped here (identical result for real
    tensors, one fewer kernel launch).
    """
    return torch.pow(x1 - x2, 2).mean()
def img_psnr(x1, x2):
    """PSNR between two images, assuming a maximum pixel value of 1."""
    return kornia.metrics.psnr(x1, x2, 1)
def img_ssim(x1, x2):
    """Mean SSIM (5x5 window) between two CxHxW images.

    kornia expects batched input, so a batch dimension is added before the
    call and the per-pixel SSIM map is reduced to its mean.
    """
    ssim_map = kornia.metrics.ssim(x1.unsqueeze(0), x2.unsqueeze(0), 5)
    return ssim_map.mean()
def show_SR_w_uncer(xLR, xHR, xSR, xSRvar, elim=(0,0.01), ulim=(0,0.15)):
    '''
    Show LR input, HR target, SR output, squared-error map and uncertainty
    map side by side in a single row.

    xLR/SR/HR: 3xHxW
    xSRvar: 1xHxW
    elim/ulim: colour limits for the error and uncertainty panels.
    '''
    def _chw_to_hwc(img):
        # CxHxW tensor -> HxWxC array suitable for plt.imshow.
        return img.to('cpu').data.clip(0, 1).transpose(0, 2).transpose(0, 1)

    plt.figure(figsize=(30, 10))

    # Panels 1-3: the three images, clipped to [0, 1].
    for panel, img in zip((1, 2, 3), (xLR, xHR, xSR)):
        plt.subplot(1, 5, panel)
        plt.imshow(_chw_to_hwc(img))
        plt.axis('off')

    # Panel 4: per-pixel squared error averaged over channels.
    plt.subplot(1, 5, 4)
    error_map = torch.mean(torch.pow(torch.abs(xSR - xHR), 2), dim=0).to('cpu').data.unsqueeze(0)
    print('error', error_map.min(), error_map.max())
    plt.imshow(error_map.transpose(0, 2).transpose(0, 1), cmap='jet')
    plt.clim(elim[0], elim[1])
    plt.axis('off')

    # Panel 5: predicted uncertainty (not clipped, matching the error panel).
    plt.subplot(1, 5, 5)
    print('uncer', xSRvar.min(), xSRvar.max())
    plt.imshow(xSRvar.to('cpu').data.transpose(0, 2).transpose(0, 1), cmap='hot')
    plt.clim(ulim[0], ulim[1])
    plt.axis('off')

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.show()
def show_SR_w_err(xLR, xHR, xSR, elim=(0,0.01), task=None, xMask=None):
    '''
    Show LR input, HR target, SR output and a per-pixel squared-error map.

    xLR/SR/HR: 3xHxW
    elim: colour limits for the error panel.
    task: 'm' switches the image panels to a grayscale colormap with a
          fixed clim; 'inpainting' multiplies the error map by xMask so
          only the inpainted region is scored.
    xMask: 1xHxW mask, only used when task == 'inpainting'.
          NOTE(review): xMask is also expected for task 'inpainting' via
          the error-map branch below — confirm callers always pass it.
    '''
    plt.figure(figsize=(30,10))

    if task != 'm':
        # Colour images: clip to [0,1] and convert CxHxW -> HxWxC.
        plt.subplot(1,4,1)
        plt.imshow(xLR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
        plt.axis('off')

        plt.subplot(1,4,2)
        plt.imshow(xHR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
        plt.axis('off')

        plt.subplot(1,4,3)
        plt.imshow(xSR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
        plt.axis('off')
    else:
        # 'm' (presumably medical/MRI) images: grayscale with fixed clim.
        plt.subplot(1,4,1)
        plt.imshow(xLR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1), cmap='gray')
        plt.clim(0,0.9)
        plt.axis('off')

        plt.subplot(1,4,2)
        plt.imshow(xHR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1), cmap='gray')
        plt.clim(0,0.9)
        plt.axis('off')

        plt.subplot(1,4,3)
        plt.imshow(xSR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1), cmap='gray')
        plt.clim(0,0.9)
        plt.axis('off')

    # Fourth panel: channel-averaged squared error, optionally masked.
    plt.subplot(1,4,4)
    if task == 'inpainting':
        error_map = torch.mean(torch.pow(torch.abs(xSR-xHR),2), dim=0).to('cpu').data.unsqueeze(0)*xMask.to('cpu').data
    else:
        error_map = torch.mean(torch.pow(torch.abs(xSR-xHR),2), dim=0).to('cpu').data.unsqueeze(0)
    print('error', error_map.min(), error_map.max())
    plt.imshow(error_map.transpose(0,2).transpose(0,1), cmap='jet')
    plt.clim(elim[0], elim[1])
    plt.axis('off')

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.show()
def show_uncer4(xSRvar1, xSRvar2, xSRvar3, xSRvar4, ulim=(0,0.15)):
    '''
    Show four uncertainty maps side by side with a shared colour scale.

    xSRvar1..4: 1xHxW
    ulim: colour limits shared by all four panels.
    '''
    plt.figure(figsize=(30, 10))
    for panel, umap in enumerate((xSRvar1, xSRvar2, xSRvar3, xSRvar4), start=1):
        plt.subplot(1, 4, panel)
        print('uncer', umap.min(), umap.max())
        # 1xHxW -> HxWx1 for imshow.
        plt.imshow(umap.to('cpu').data.transpose(0, 2).transpose(0, 1), cmap='hot')
        plt.clim(ulim[0], ulim[1])
        plt.axis('off')
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.show()
def get_UCE(list_err, list_yout_var, num_bins=100):
    """Compute the Uncertainty Calibration Error (UCE).

    Errors are histogrammed into ``num_bins`` equal-width bins over
    [min(err), max(err)]; each bin contributes its point fraction times the
    absolute gap between its mean error and mean predicted variance.

    Args:
        list_err: per-pixel (or per-sample) squared errors.
        list_yout_var: matching predicted variances, same length.
        num_bins (int): number of histogram bins. (Default: 100)

    Returns:
        (bin_stats, uce): dict of per-bin statistics and the scalar UCE.

    Improvements over the original:
    - O(N) direct bin-index computation instead of scanning all bins for
      every point (was O(N * num_bins)).
    - The maximum error is clamped into the last bin; the original
      half-open scan (e >= start and e < end) silently dropped it.
    - Degenerate input where all errors are equal goes to bin 0 instead of
      being dropped entirely.
    """
    err_min = np.min(list_err)
    err_max = np.max(list_err)
    err_len = (err_max - err_min) / num_bins
    num_points = len(list_err)

    bin_stats = {
        i: {
            'start_idx': err_min + i * err_len,
            'end_idx': err_min + (i + 1) * err_len,
            'num_points': 0,
            'mean_err': 0,
            'mean_var': 0,
        }
        for i in range(num_bins)
    }

    for e, v in zip(list_err, list_yout_var):
        if err_len > 0:
            # min(...) clamps e == err_max into the last bin.
            i = min(int((e - err_min) / err_len), num_bins - 1)
        else:
            i = 0
        bin_stats[i]['num_points'] += 1
        bin_stats[i]['mean_err'] += e
        bin_stats[i]['mean_var'] += v

    uce = 0
    eps = 1e-8  # avoids division by zero for empty bins
    for i in range(num_bins):
        bin_stats[i]['mean_err'] /= bin_stats[i]['num_points'] + eps
        bin_stats[i]['mean_var'] /= bin_stats[i]['num_points'] + eps
        bin_stats[i]['uce_bin'] = (bin_stats[i]['num_points'] / num_points) \
            * (np.abs(bin_stats[i]['mean_err'] - bin_stats[i]['mean_var']))
        uce += bin_stats[i]['uce_bin']

    return bin_stats, uce
##################### training BayesCap
|
| 553 |
+
def train_BayesCap(
    NetC,
    NetG,
    train_loader,
    eval_loader,
    Cri = None,
    device='cuda',
    dtype=torch.cuda.FloatTensor,
    init_lr=1e-4,
    num_epochs=100,
    eval_every=1,
    ckpt_path='../ckpt/BayesCap',
    T1=1e0,
    T2=5e-2,
    task=None,
):
    """Train the BayesCap head ``NetC`` on top of a frozen generator ``NetG``.

    Args:
        NetC: BayesCap network (trained); outputs (mu, alpha, beta).
        NetG: pretrained task network, kept in eval mode and frozen.
        train_loader / eval_loader: (LR, HR) batch loaders.
        Cri: loss; defaults to ``TempCombLoss()`` (built lazily, see below).
        device, dtype: target device and tensor type for the batches.
        init_lr: initial Adam learning rate, annealed with cosine schedule.
        num_epochs, eval_every: training length and validation cadence.
        ckpt_path: prefix for '_last.pth' / '_best.pth' checkpoints.
        T1, T2: loss temperature weights forwarded to ``Cri``.
        task: 'inpainting' / 'depth' / None, selects the NetG call shape.

    Fixes vs. the original signature:
    - ``Cri=TempCombLoss()`` and ``dtype=torch.cuda.FloatTensor()`` were
      evaluated at import time; the latter *allocates a CUDA tensor*, so
      merely importing this module crashed on CPU-only machines. The
      defaults are now the type object and a lazily-built loss.
    """
    if Cri is None:
        Cri = TempCombLoss()

    NetC.to(device)
    NetC.train()
    NetG.to(device)
    NetG.eval()  # generator stays frozen; only NetC is optimised
    optimizer = torch.optim.Adam(list(NetC.parameters()), lr=init_lr)
    optim_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs)

    score = -1e8  # best validation SSIM seen so far
    all_loss = []
    for eph in range(num_epochs):
        eph_loss = 0
        with tqdm(train_loader, unit='batch') as tepoch:
            for (idx, batch) in enumerate(tepoch):
                # cap each epoch at 2000 iterations
                if idx > 2000:
                    break
                tepoch.set_description('Epoch {}'.format(eph))
                ##
                xLR, xHR = batch[0].to(device), batch[1].to(device)
                xLR, xHR = xLR.type(dtype), xHR.type(dtype)
                if task == 'inpainting':
                    xMask = random_mask(xLR.shape[0], (xLR.shape[2], xLR.shape[3]))
                    xMask = xMask.to(device).type(dtype)
                # pass them through the (frozen) generator
                with torch.no_grad():
                    if task == 'inpainting':
                        _, xSR1 = NetG(xLR, xMask)
                    elif task == 'depth':
                        xSR1 = NetG(xLR)[("disp", 0)]
                    else:
                        xSR1 = NetG(xLR)
                xSR = xSR1.clone()
                xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
                optimizer.zero_grad()
                # depth has no GT here, so BayesCap regresses NetG's output;
                # otherwise it is supervised with the HR ground truth.
                if task == 'depth':
                    loss = Cri(xSRC_mu, xSRC_alpha, xSRC_beta, xSR, T1=T1, T2=T2)
                else:
                    loss = Cri(xSRC_mu, xSRC_alpha, xSRC_beta, xHR, T1=T1, T2=T2)
                loss.backward()
                optimizer.step()
                ##
                eph_loss += loss.item()
                tepoch.set_postfix(loss=loss.item())
            eph_loss /= len(train_loader)
            all_loss.append(eph_loss)
            print('Avg. loss: {}'.format(eph_loss))
        # always checkpoint the latest weights; keep the best separately
        torch.save(NetC.state_dict(), ckpt_path+'_last.pth')
        if eph % eval_every == 0:
            curr_score = eval_BayesCap(
                NetC,
                NetG,
                eval_loader,
                device=device,
                dtype=dtype,
                task=task,
            )
            print('current score: {} | Last best score: {}'.format(curr_score, score))
            if curr_score >= score:
                score = curr_score
                torch.save(NetC.state_dict(), ckpt_path+'_best.pth')
    optim_scheduler.step()
#### get different uncertainty maps
|
| 635 |
+
def get_uncer_BayesCap(
    NetC,
    NetG,
    xin,
    task=None,
    xMask=None,
):
    """Return BayesCap's per-pixel uncertainty map for input ``xin``.

    Runs the generator (with mask for inpainting), feeds its output to the
    BayesCap head, and converts the (alpha, beta) parameters into a
    variance map: (1/alpha)^2 * Gamma(3/beta) / Gamma(1/beta).
    """
    with torch.no_grad():
        if task == 'inpainting':
            _, xSR = NetG(xin, xMask)
        else:
            xSR = NetG(xin)
        xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
        # small offsets guard against division by zero
        a_map = (1 / (xSRC_alpha + 1e-5)).to('cpu').data
        b_map = xSRC_beta.to('cpu').data
        gamma_ratio = torch.exp(torch.lgamma(3 / (b_map + 1e-2))) \
            / torch.exp(torch.lgamma(1 / (b_map + 1e-2)))
        xSRvar = (a_map ** 2) * gamma_ratio

    return xSRvar
def get_uncer_TTDAp(
    NetG,
    xin,
    p_mag=0.05,
    num_runs=50,
    task=None,
    xMask=None,
):
    """Uncertainty via test-time data augmentation with input perturbation.

    Runs ``NetG`` ``num_runs`` times on ``xin`` plus Gaussian noise scaled
    by ``p_mag * xin.max()`` and returns the per-pixel variance of the
    outputs, averaged over channels (shape 1x1xHxW).
    """
    outputs = []
    with torch.no_grad():
        for _ in range(num_runs):
            perturbed = xin + p_mag * xin.max() * torch.randn_like(xin)
            if task == 'inpainting':
                _, out = NetG(perturbed, xMask)
            else:
                out = NetG(perturbed)
            outputs.append(out)
    stacked = torch.cat(outputs, dim=0)
    xSRmean = torch.mean(stacked, dim=0).unsqueeze(0)  # kept for parity; not returned
    xSRvar = torch.mean(torch.var(stacked, dim=0), dim=0).unsqueeze(0).unsqueeze(1)
    return xSRvar
def get_uncer_DO(
    NetG,
    xin,
    dop=0.2,
    num_runs=50,
    task=None,
    xMask=None,
):
    """Uncertainty via Monte-Carlo dropout.

    Runs ``NetG`` ``num_runs`` times with dropout probability ``dop`` and
    returns the per-pixel variance of the outputs, averaged over channels
    (shape 1x1xHxW). ``NetG`` must accept a ``dop`` keyword.
    """
    outputs = []
    with torch.no_grad():
        for _ in range(num_runs):
            if task == 'inpainting':
                _, out = NetG(xin, xMask, dop=dop)
            else:
                out = NetG(xin, dop=dop)
            outputs.append(out)
    stacked = torch.cat(outputs, dim=0)
    xSRmean = torch.mean(stacked, dim=0).unsqueeze(0)  # kept for parity; not returned
    xSRvar = torch.mean(torch.var(stacked, dim=0), dim=0).unsqueeze(0).unsqueeze(1)
    return xSRvar
################### Different eval functions
|
| 695 |
+
|
| 696 |
+
def eval_BayesCap(
    NetC,
    NetG,
    eval_loader,
    device='cuda',
    dtype=torch.cuda.FloatTensor,
    task=None,
    xMask=None,
):
    """Evaluate BayesCap: SSIM/PSNR/MSE/MAE of NetC's mu vs. the target.

    Args:
        NetC: BayesCap head producing (mu, alpha, beta).
        NetG: frozen task network.
        eval_loader: (LR, HR) batch loader.
        device, dtype: target device and tensor type.
        task: 'inpainting' / 'depth' / None.
        xMask: optional fixed inpainting mask; a random one is drawn when
            None and task == 'inpainting'.

    Returns:
        Mean SSIM over all evaluated images (used as the model-selection
        score by ``train_BayesCap``).

    Fix: the mask check now uses ``xMask is None`` — the original
    ``xMask == None`` triggers element-wise comparison once a tensor is
    passed, which is not an identity test.
    """
    NetC.to(device)
    NetC.eval()
    NetG.to(device)
    NetG.eval()

    mean_ssim = 0
    mean_psnr = 0
    mean_mse = 0
    mean_mae = 0
    num_imgs = 0
    list_error = []
    list_var = []
    with tqdm(eval_loader, unit='batch') as tepoch:
        for (idx, batch) in enumerate(tepoch):
            tepoch.set_description('Validating ...')
            ##
            xLR, xHR = batch[0].to(device), batch[1].to(device)
            xLR, xHR = xLR.type(dtype), xHR.type(dtype)
            if task == 'inpainting':
                if xMask is None:
                    xMask = random_mask(xLR.shape[0], (xLR.shape[2], xLR.shape[3]))
                    xMask = xMask.to(device).type(dtype)
                else:
                    xMask = xMask.to(device).type(dtype)
            # pass them through the network
            with torch.no_grad():
                if task == 'inpainting':
                    _, xSR = NetG(xLR, xMask)
                elif task == 'depth':
                    xSR = NetG(xLR)[("disp", 0)]
                else:
                    xSR = NetG(xLR)
                xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
            # generalized-Gaussian variance map from (alpha, beta)
            a_map = (1/(xSRC_alpha + 1e-5)).to('cpu').data
            b_map = xSRC_beta.to('cpu').data
            xSRvar = (a_map**2)*(torch.exp(torch.lgamma(3/(b_map + 1e-2)))/torch.exp(torch.lgamma(1/(b_map + 1e-2))))
            n_batch = xSRC_mu.shape[0]
            if task == 'depth':
                # no GT depth here: score against NetG's own output
                xHR = xSR
            for j in range(n_batch):
                num_imgs += 1
                mean_ssim += img_ssim(xSRC_mu[j], xHR[j])
                mean_psnr += img_psnr(xSRC_mu[j], xHR[j])
                mean_mse += img_mse(xSRC_mu[j], xHR[j])
                mean_mae += img_mae(xSRC_mu[j], xHR[j])

                show_SR_w_uncer(xLR[j], xHR[j], xSR[j], xSRvar[j])

                error_map = torch.mean(torch.pow(torch.abs(xSR[j]-xHR[j]),2), dim=0).to('cpu').data.reshape(-1)
                var_map = xSRvar[j].to('cpu').data.reshape(-1)
                list_error.extend(list(error_map.numpy()))
                list_var.extend(list(var_map.numpy()))
            ##
    mean_ssim /= num_imgs
    mean_psnr /= num_imgs
    mean_mse /= num_imgs
    mean_mae /= num_imgs
    print(
        'Avg. SSIM: {} | Avg. PSNR: {} | Avg. MSE: {} | Avg. MAE: {}'.format
        (
            mean_ssim, mean_psnr, mean_mse, mean_mae
        )
    )
    return mean_ssim
def eval_TTDA_p(
    NetG,
    eval_loader,
    device='cuda',
    dtype=torch.cuda.FloatTensor,
    p_mag=0.05,
    num_runs=50,
    task = None,
    xMask = None,
):
    """Evaluate NetG and visualise test-time-augmentation uncertainty.

    SSIM/PSNR/MSE/MAE are computed on the clean forward pass; the
    uncertainty map is the variance over ``num_runs`` noisy forward passes
    (input + p_mag * max * gaussian noise).

    NOTE(review): unlike the clean pass, the noisy passes ignore xMask even
    for task == 'inpainting' — confirm that is intended.

    Returns the mean SSIM.
    """
    NetG.to(device)
    NetG.eval()

    mean_ssim = 0
    mean_psnr = 0
    mean_mse = 0
    mean_mae = 0
    num_imgs = 0
    with tqdm(eval_loader, unit='batch') as tepoch:
        for (idx, batch) in enumerate(tepoch):
            tepoch.set_description('Validating ...')
            ##
            xLR, xHR = batch[0].to(device), batch[1].to(device)
            xLR, xHR = xLR.type(dtype), xHR.type(dtype)
            # pass them through the network
            list_xSR = []
            with torch.no_grad():
                if task=='inpainting':
                    _, xSR = NetG(xLR, xMask)
                else:
                    xSR = NetG(xLR)
                for z in range(num_runs):
                    xSRz = NetG(xLR+p_mag*xLR.max()*torch.randn_like(xLR))
                    list_xSR.append(xSRz)
            # mean over runs is unused; variance (channel-averaged) is shown
            xSRmean = torch.mean(torch.cat(list_xSR, dim=0), dim=0).unsqueeze(0)
            xSRvar = torch.mean(torch.var(torch.cat(list_xSR, dim=0), dim=0), dim=0).unsqueeze(0).unsqueeze(1)
            n_batch = xSR.shape[0]
            for j in range(n_batch):
                num_imgs += 1
                mean_ssim += img_ssim(xSR[j], xHR[j])
                mean_psnr += img_psnr(xSR[j], xHR[j])
                mean_mse += img_mse(xSR[j], xHR[j])
                mean_mae += img_mae(xSR[j], xHR[j])

                show_SR_w_uncer(xLR[j], xHR[j], xSR[j], xSRvar[j])

    mean_ssim /= num_imgs
    mean_psnr /= num_imgs
    mean_mse /= num_imgs
    mean_mae /= num_imgs
    print(
        'Avg. SSIM: {} | Avg. PSNR: {} | Avg. MSE: {} | Avg. MAE: {}'.format
        (
            mean_ssim, mean_psnr, mean_mse, mean_mae
        )
    )

    return mean_ssim
def eval_DO(
    NetG,
    eval_loader,
    device='cuda',
    dtype=torch.cuda.FloatTensor,
    dop=0.2,
    num_runs=50,
    task=None,
    xMask=None,
):
    """Evaluate NetG and visualise Monte-Carlo-dropout uncertainty.

    SSIM/PSNR/MSE/MAE are computed on the clean forward pass; the
    uncertainty map is the variance over ``num_runs`` dropout-enabled
    passes (``NetG(..., dop=dop)``).

    NOTE(review): the dropout passes ignore xMask even for
    task == 'inpainting' — confirm that is intended.

    Returns the mean SSIM.
    """
    NetG.to(device)
    NetG.eval()

    mean_ssim = 0
    mean_psnr = 0
    mean_mse = 0
    mean_mae = 0
    num_imgs = 0
    with tqdm(eval_loader, unit='batch') as tepoch:
        for (idx, batch) in enumerate(tepoch):
            tepoch.set_description('Validating ...')
            ##
            xLR, xHR = batch[0].to(device), batch[1].to(device)
            xLR, xHR = xLR.type(dtype), xHR.type(dtype)
            # pass them through the network
            list_xSR = []
            with torch.no_grad():
                if task == 'inpainting':
                    _, xSR = NetG(xLR, xMask)
                else:
                    xSR = NetG(xLR)
                for z in range(num_runs):
                    xSRz = NetG(xLR, dop=dop)
                    list_xSR.append(xSRz)
            # mean over runs is unused; variance (channel-averaged) is shown
            xSRmean = torch.mean(torch.cat(list_xSR, dim=0), dim=0).unsqueeze(0)
            xSRvar = torch.mean(torch.var(torch.cat(list_xSR, dim=0), dim=0), dim=0).unsqueeze(0).unsqueeze(1)
            n_batch = xSR.shape[0]
            for j in range(n_batch):
                num_imgs += 1
                mean_ssim += img_ssim(xSR[j], xHR[j])
                mean_psnr += img_psnr(xSR[j], xHR[j])
                mean_mse += img_mse(xSR[j], xHR[j])
                mean_mae += img_mae(xSR[j], xHR[j])

                show_SR_w_uncer(xLR[j], xHR[j], xSR[j], xSRvar[j])
            ##
    mean_ssim /= num_imgs
    mean_psnr /= num_imgs
    mean_mse /= num_imgs
    mean_mae /= num_imgs
    print(
        'Avg. SSIM: {} | Avg. PSNR: {} | Avg. MSE: {} | Avg. MAE: {}'.format
        (
            mean_ssim, mean_psnr, mean_mse, mean_mae
        )
    )

    return mean_ssim
############### compare all function
|
| 893 |
+
def compare_all(
    NetC,
    NetG,
    eval_loader,
    p_mag = 0.05,
    dop = 0.2,
    num_runs = 100,
    device='cuda',
    dtype=torch.cuda.FloatTensor,
    task=None,
):
    """Visually compare TTDA, dropout and BayesCap uncertainty maps.

    For each batch, computes xSRvar1 (test-time augmentation), xSRvar2
    (MC dropout) and xSRvar3 (BayesCap) and plots them with per-task
    power-law rescalings chosen for display.

    NOTE(review): the task codes 's' / 'd' / 'm' presumably stand for
    super-resolution / depth / medical — confirm against callers. The
    plotted combinations mix the three variance sources deliberately for
    figure generation; nothing is returned.
    """
    NetC.to(device)
    NetC.eval()
    NetG.to(device)
    NetG.eval()

    with tqdm(eval_loader, unit='batch') as tepoch:
        for (idx, batch) in enumerate(tepoch):
            tepoch.set_description('Comparing ...')
            ##
            xLR, xHR = batch[0].to(device), batch[1].to(device)
            xLR, xHR = xLR.type(dtype), xHR.type(dtype)
            if task == 'inpainting':
                xMask = random_mask(xLR.shape[0], (xLR.shape[2], xLR.shape[3]))
                xMask = xMask.to(device).type(dtype)
            # pass them through the network
            with torch.no_grad():
                if task == 'inpainting':
                    _, xSR = NetG(xLR, xMask)
                else:
                    xSR = NetG(xLR)
                xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)

            # three uncertainty estimates for the same input
            if task == 'inpainting':
                xSRvar1 = get_uncer_TTDAp(NetG, xLR, p_mag=p_mag, num_runs=num_runs, task='inpainting', xMask=xMask)
                xSRvar2 = get_uncer_DO(NetG, xLR, dop=dop, num_runs=num_runs, task='inpainting', xMask=xMask)
                xSRvar3 = get_uncer_BayesCap(NetC, NetG, xLR, task='inpainting', xMask=xMask)
            else:
                xSRvar1 = get_uncer_TTDAp(NetG, xLR, p_mag=p_mag, num_runs=num_runs)
                xSRvar2 = get_uncer_DO(NetG, xLR, dop=dop, num_runs=num_runs)
                xSRvar3 = get_uncer_BayesCap(NetC, NetG, xLR)

            print('bdg', xSRvar1.shape, xSRvar2.shape, xSRvar3.shape)

            n_batch = xSR.shape[0]
            for j in range(n_batch):
                if task=='s':
                    show_SR_w_err(xLR[j], xHR[j], xSR[j])
                    show_uncer4(xSRvar1[j], torch.sqrt(xSRvar1[j]), torch.pow(xSRvar1[j], 0.48), torch.pow(xSRvar1[j], 0.42))
                    show_uncer4(xSRvar2[j], torch.sqrt(xSRvar2[j]), torch.pow(xSRvar3[j], 1.5), xSRvar3[j])
                if task=='d':
                    show_SR_w_err(xLR[j], xHR[j], 0.5*xSR[j]+0.5*xHR[j])
                    show_uncer4(xSRvar1[j], torch.sqrt(xSRvar1[j]), torch.pow(xSRvar1[j], 0.48), torch.pow(xSRvar1[j], 0.42))
                    show_uncer4(xSRvar2[j], torch.sqrt(xSRvar2[j]), torch.pow(xSRvar3[j], 0.8), xSRvar3[j])
                if task=='inpainting':
                    show_SR_w_err(xLR[j]*(1-xMask[j]), xHR[j], xSR[j], elim=(0,0.25), task='inpainting', xMask=xMask[j])
                    show_uncer4(xSRvar1[j], torch.sqrt(xSRvar1[j]), torch.pow(xSRvar1[j], 0.45), torch.pow(xSRvar1[j], 0.4))
                    show_uncer4(xSRvar2[j], torch.sqrt(xSRvar2[j]), torch.pow(xSRvar3[j], 0.8), xSRvar3[j])
                if task=='m':
                    show_SR_w_err(xLR[j], xHR[j], xSR[j], elim=(0,0.04), task='m')
                    show_uncer4(0.4*xSRvar1[j]+0.6*xSRvar2[j], torch.sqrt(xSRvar1[j]), torch.pow(xSRvar1[j], 0.48), torch.pow(xSRvar1[j], 0.42), ulim=(0.02,0.15))
                    show_uncer4(xSRvar2[j], torch.sqrt(xSRvar2[j]), torch.pow(xSRvar3[j], 1.5), xSRvar3[j], ulim=(0.02,0.15))
################# Degrading Identity
|
| 958 |
+
def degrage_BayesCap_p(
    NetC,
    NetG,
    eval_loader,
    device='cuda',
    dtype=torch.cuda.FloatTensor,
    num_runs=50,
):
    """Degrade BayesCap's input and track faithfulness vs. calibration.

    (Name kept for compatibility; 'degrage' is presumably a typo for
    'degrade'.) For each perturbation magnitude p_mag, noise is added to
    the generator output before NetC, then SSIM/PSNR of NetC's mu against
    the generator output and two UCE variants are collected and plotted.
    """
    NetC.to(device)
    NetC.eval()
    NetG.to(device)
    NetG.eval()

    # increasing perturbation magnitudes applied to NetC's input
    p_mag_list = [0, 0.05, 0.1, 0.15, 0.2]
    list_s = []
    list_p = []
    list_u1 = []
    list_u2 = []
    list_c = []  # NOTE(review): collected nowhere below — appears unused
    for p_mag in p_mag_list:
        mean_ssim = 0
        mean_psnr = 0
        mean_mse = 0
        mean_mae = 0
        num_imgs = 0
        list_error = []
        list_error2 = []
        list_var = []

        with tqdm(eval_loader, unit='batch') as tepoch:
            for (idx, batch) in enumerate(tepoch):
                tepoch.set_description('Validating ...')
                ##
                xLR, xHR = batch[0].to(device), batch[1].to(device)
                xLR, xHR = xLR.type(dtype), xHR.type(dtype)
                # pass them through the network; NetC sees a noisy xSR
                with torch.no_grad():
                    xSR = NetG(xLR)
                    xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR + p_mag*xSR.max()*torch.randn_like(xSR))
                # generalized-Gaussian variance map from (alpha, beta)
                a_map = (1/(xSRC_alpha + 1e-5)).to('cpu').data
                b_map = xSRC_beta.to('cpu').data
                xSRvar = (a_map**2)*(torch.exp(torch.lgamma(3/(b_map + 1e-2)))/torch.exp(torch.lgamma(1/(b_map + 1e-2))))
                n_batch = xSRC_mu.shape[0]
                for j in range(n_batch):
                    num_imgs += 1
                    # faithfulness: NetC's mu vs. the (clean) generator output
                    mean_ssim += img_ssim(xSRC_mu[j], xSR[j])
                    mean_psnr += img_psnr(xSRC_mu[j], xSR[j])
                    mean_mse += img_mse(xSRC_mu[j], xSR[j])
                    mean_mae += img_mae(xSRC_mu[j], xSR[j])

                    # error vs GT for the generator and for BayesCap's mu
                    error_map = torch.mean(torch.pow(torch.abs(xSR[j]-xHR[j]),2), dim=0).to('cpu').data.reshape(-1)
                    error_map2 = torch.mean(torch.pow(torch.abs(xSRC_mu[j]-xHR[j]),2), dim=0).to('cpu').data.reshape(-1)
                    var_map = xSRvar[j].to('cpu').data.reshape(-1)
                    list_error.extend(list(error_map.numpy()))
                    list_error2.extend(list(error_map2.numpy()))
                    list_var.extend(list(var_map.numpy()))
                ##
        mean_ssim /= num_imgs
        mean_psnr /= num_imgs
        mean_mse /= num_imgs
        mean_mae /= num_imgs
        print(
            'Avg. SSIM: {} | Avg. PSNR: {} | Avg. MSE: {} | Avg. MAE: {}'.format
            (
                mean_ssim, mean_psnr, mean_mse, mean_mae
            )
        )
        # subsample every 100th pixel to keep UCE tractable
        uce1 = get_UCE(list_error[::100], list_var[::100], num_bins=200)[1]
        uce2 = get_UCE(list_error2[::100], list_var[::100], num_bins=200)[1]
        print('UCE1: ', uce1)
        print('UCE2: ', uce2)
        list_s.append(mean_ssim.item())
        list_p.append(mean_psnr.item())
        list_u1.append(uce1)
        list_u2.append(uce2)

    plt.plot(list_s)
    plt.show()
    plt.plot(list_p)
    plt.show()

    plt.plot(list_u1, label='wrt SR output')
    plt.plot(list_u2, label='wrt BayesCap output')
    plt.legend()
    plt.show()

    # summary figure: faithfulness (left axis) vs. calibration (right axis)
    sns.set_style('darkgrid')
    fig,ax = plt.subplots()
    # make a plot
    ax.plot(p_mag_list, list_s, color="red", marker="o")
    # set x-axis label
    ax.set_xlabel("Reducing faithfulness of BayesCap Reconstruction",fontsize=10)
    # set y-axis label
    ax.set_ylabel("SSIM btwn BayesCap and SRGAN outputs", color="red",fontsize=10)

    # twin object for two different y-axis on the sample plot
    ax2=ax.twinx()
    # make a plot with different y-axis using second axis object
    ax2.plot(p_mag_list, list_u1, color="blue", marker="o", label='UCE wrt to error btwn SRGAN output and GT')
    ax2.plot(p_mag_list, list_u2, color="orange", marker="o", label='UCE wrt to error btwn BayesCap output and GT')
    ax2.set_ylabel("UCE", color="green", fontsize=10)
    plt.legend(fontsize=10)
    plt.tight_layout()
    plt.show()
################# DeepFill_v2
|
| 1064 |
+
|
| 1065 |
+
# ----------------------------------------
|
| 1066 |
+
# PATH processing
|
| 1067 |
+
# ----------------------------------------
|
| 1068 |
+
def text_readlines(filename):
    """Read a text file and return its lines without trailing newlines.

    Returns [] if the file cannot be opened (preserving the original
    best-effort contract).

    Fixes vs. the original:
    - the original sliced off the last character of every line
      unconditionally, corrupting a final line that does not end in '\\n';
      ``rstrip('\\n')`` only removes an actual newline.
    - the file is opened with a context manager so the handle is always
      released.
    """
    try:
        with open(filename, 'r') as file:
            content = file.readlines()
    except IOError:
        return []
    return [line.rstrip('\n') for line in content]
def savetxt(name, loss_log):
    """Persist a list of numbers (e.g. a loss log) to a text file via numpy."""
    np.savetxt(name, np.array(loss_log))
def get_files(path):
    """Recursively walk `path` and return the full path of every file."""
    return [
        os.path.join(root, fname)
        for root, _dirs, files in os.walk(path)
        for fname in files
    ]
def get_names(path):
    """Recursively walk `path` and return the bare name of every file."""
    return [
        fname
        for _root, _dirs, files in os.walk(path)
        for fname in files
    ]
def text_save(content, filename, mode = 'a'):
    """Save a list to a text file, one element per line (append by default).

    Args:
        content: iterable of items; each is ``str()``-converted.
        filename: target file path.
        mode: file mode, 'a' to append (default) or 'w' to overwrite.

    Uses a ``with`` block so the handle is closed even if a write raises
    (the original left the file open on error).
    """
    with open(filename, mode) as file:
        for item in content:
            file.write(str(item) + '\n')
def check_path(path):
    """Create `path` (including parents) if it does not already exist.

    ``exist_ok=True`` makes the call atomic with respect to the existence
    check, avoiding the check-then-create race of the original
    ``if not os.path.exists(...): os.makedirs(...)`` pattern.
    """
    os.makedirs(path, exist_ok=True)
# ----------------------------------------
|
| 1115 |
+
# Validation and Sample at training
|
| 1116 |
+
# ----------------------------------------
|
| 1117 |
+
def save_sample_png(sample_folder, sample_name, img_list, name_list, pixel_max_cnt = 255):
    """Save a list of 1xCxHxW tensors as JPEGs named <sample_name>_<name>.jpg.

    Args:
        sample_folder: destination directory (assumed to exist).
        sample_name: filename prefix shared by all images.
        img_list: list of image tensors in [0, 1] (sigmoid outputs).
        name_list: per-image suffixes, same length as img_list.
        pixel_max_cnt: clip ceiling after the *255 rescale.

    NOTE(review): despite the function name, files are written with a
    '.jpg' extension, not '.png'.
    """
    # Save image one-by-one
    for i in range(len(img_list)):
        img = img_list[i]
        # Recover normalization: * 255 because last layer is sigmoid activated
        img = img * 255
        # Process img_copy and do not destroy the data of img
        # (take the first batch element, NCHW -> HWC, to uint8)
        img_copy = img.clone().data.permute(0, 2, 3, 1)[0, :, :, :].cpu().numpy()
        img_copy = np.clip(img_copy, 0, pixel_max_cnt)
        img_copy = img_copy.astype(np.uint8)
        # OpenCV writes BGR, so convert from RGB first
        img_copy = cv2.cvtColor(img_copy, cv2.COLOR_RGB2BGR)
        # Save to certain path
        save_img_name = sample_name + '_' + name_list[i] + '.jpg'
        save_img_path = os.path.join(sample_folder, save_img_name)
        cv2.imwrite(save_img_path, img_copy)
def psnr(pred, target, pixel_max_cnt = 255):
    """Peak signal-to-noise ratio between two tensors, in dB.

    :param pred: predicted tensor.
    :param target: reference tensor (same shape as *pred*).
    :param pixel_max_cnt: maximum possible pixel value (default 255).
    :return: PSNR in dB; float('inf') when pred equals target.

    The original raised ZeroDivisionError on identical inputs because the
    RMSE is a plain Python float; guard that case explicitly.
    """
    mse = torch.mul(target - pred, target - pred)
    rmse_avg = (torch.mean(mse).item()) ** 0.5
    if rmse_avg == 0:
        return float('inf')
    return 20 * np.log10(pixel_max_cnt / rmse_avg)
|
| 1138 |
+
|
| 1139 |
+
def grey_psnr(pred, target, pixel_max_cnt = 255):
    """PSNR after collapsing dim 0 (channel sum), in dB.

    :param pred: predicted tensor, channels first.
    :param target: reference tensor (same shape as *pred*).
    :param pixel_max_cnt: max per-channel pixel value; the peak is scaled
        by 3 because three channels are summed.
    :return: PSNR in dB; float('inf') when pred equals target.

    Same division-by-zero guard as psnr(): identical inputs previously
    raised ZeroDivisionError.
    """
    pred = torch.sum(pred, dim = 0)
    target = torch.sum(target, dim = 0)
    mse = torch.mul(target - pred, target - pred)
    rmse_avg = (torch.mean(mse).item()) ** 0.5
    if rmse_avg == 0:
        return float('inf')
    return 20 * np.log10(pixel_max_cnt * 3 / rmse_avg)
|
| 1146 |
+
|
| 1147 |
+
def ssim(pred, target):
    """SSIM between two NCHW tensors, computed on the first batch element.

    :param pred: predicted tensor [N, C, H, W].
    :param target: reference tensor [N, C, H, W].
    :return: scalar SSIM score.

    `skimage.measure.compare_ssim` was removed in scikit-image 0.18; prefer
    the modern `skimage.metrics.structural_similarity` and fall back to the
    legacy names so old environments keep working.
    """
    pred_np = pred.clone().data.permute(0, 2, 3, 1).cpu().numpy()[0]
    target_np = target.clone().data.permute(0, 2, 3, 1).cpu().numpy()[0]
    try:
        from skimage.metrics import structural_similarity as _ssim_fn
    except ImportError:
        # scikit-image < 0.16: only the deprecated location exists.
        return skimage.measure.compare_ssim(target_np, pred_np, multichannel = True)
    try:
        # scikit-image >= 0.19 uses channel_axis instead of multichannel.
        return _ssim_fn(target_np, pred_np, channel_axis = 2)
    except TypeError:
        return _ssim_fn(target_np, pred_np, multichannel = True)
|
| 1154 |
+
|
| 1155 |
+
## for contextual attention
|
| 1156 |
+
|
| 1157 |
+
def extract_image_patches(images, ksizes, strides, rates, padding='same'):
    """
    Extract patches from images and put them in the C output dimension.

    :param images: [batch, channels, in_rows, in_cols]. A 4-D Tensor.
    :param ksizes: [ksize_rows, ksize_cols]. Sliding window size for each
        spatial dimension of images.
    :param strides: [stride_rows, stride_cols]
    :param rates: [dilation_rows, dilation_cols]
    :param padding: 'same' (zero-pad so output covers every stride position)
        or 'valid' (no padding).
    :return: Tensor of shape [N, C*k*k, L], L = total number of patches.

    Note: the original had an unreachable `else: raise NotImplementedError`
    branch after the padding assert, plus an unused 4-way unpack of
    images.size(); both removed.
    """
    assert len(images.size()) == 4
    assert padding in ['same', 'valid']

    if padding == 'same':
        images = same_padding(images, ksizes, strides, rates)

    unfold = torch.nn.Unfold(kernel_size=ksizes,
                             dilation=rates,
                             padding=0,
                             stride=strides)
    return unfold(images)  # [N, C*k*k, L], L is the total number of such blocks
|
| 1186 |
+
|
| 1187 |
+
def same_padding(images, ksizes, strides, rates):
    """Zero-pad *images* in the style of TensorFlow 'SAME' padding.

    Pads so that a strided, dilated window of size *ksizes* covers every
    output position; odd totals put the extra pixel on the bottom/right.

    :param images: 4-D tensor [batch, channels, rows, cols].
    :param ksizes: [ksize_rows, ksize_cols]
    :param strides: [stride_rows, stride_cols]
    :param rates: [dilation_rows, dilation_cols]
    :return: zero-padded 4-D tensor.
    """
    assert len(images.size()) == 4
    _, _, rows, cols = images.size()
    # Number of output positions per spatial dim (ceil division).
    out_rows = (rows + strides[0] - 1) // strides[0]
    out_cols = (cols + strides[1] - 1) // strides[1]
    # Effective kernel extent once dilation is applied.
    eff_k_row = (ksizes[0] - 1) * rates[0] + 1
    eff_k_col = (ksizes[1] - 1) * rates[1] + 1
    pad_rows = max(0, (out_rows - 1) * strides[0] + eff_k_row - rows)
    pad_cols = max(0, (out_cols - 1) * strides[1] + eff_k_col - cols)
    top = int(pad_rows / 2.)
    left = int(pad_cols / 2.)
    paddings = (left, pad_cols - left, top, pad_rows - top)
    return torch.nn.ZeroPad2d(paddings)(images)
|
| 1204 |
+
|
| 1205 |
+
def reduce_mean(x, axis=None, keepdim=False):
    """Mean of *x* over *axis* (int or iterable of ints; None = all dims).

    :param x: input tensor.
    :param axis: dimension(s) to reduce; None reduces every dimension.
    :param keepdim: keep reduced dims with size 1.
    :return: reduced tensor.

    Bug fix: the original tested `if not axis:`, so `axis=0` (falsy)
    silently reduced over ALL dimensions and any other bare int crashed in
    `sorted()`. Only None now means "all dims"; a bare int is accepted.
    """
    if axis is None:
        axis = range(len(x.shape))
    elif isinstance(axis, int):
        axis = [axis]
    # Reduce highest dims first so remaining indices stay valid.
    for i in sorted(axis, reverse=True):
        x = torch.mean(x, dim=i, keepdim=keepdim)
    return x
|
| 1211 |
+
|
| 1212 |
+
|
| 1213 |
+
def reduce_std(x, axis=None, keepdim=False):
    """Standard deviation of *x* over *axis* (int or iterable; None = all).

    :param x: input tensor.
    :param axis: dimension(s) to reduce; None reduces every dimension.
    :param keepdim: keep reduced dims with size 1.
    :return: reduced tensor (torch.std default: unbiased estimator).

    Bug fix: `if not axis:` treated `axis=0` as "reduce everything" and
    crashed on other bare ints; only None now means all dims, and a bare
    int is wrapped in a list.
    """
    if axis is None:
        axis = range(len(x.shape))
    elif isinstance(axis, int):
        axis = [axis]
    # Reduce highest dims first so remaining indices stay valid.
    for i in sorted(axis, reverse=True):
        x = torch.std(x, dim=i, keepdim=keepdim)
    return x
|
| 1219 |
+
|
| 1220 |
+
|
| 1221 |
+
def reduce_sum(x, axis=None, keepdim=False):
    """Sum of *x* over *axis* (int or iterable of ints; None = all dims).

    :param x: input tensor.
    :param axis: dimension(s) to reduce; None reduces every dimension.
    :param keepdim: keep reduced dims with size 1.
    :return: reduced tensor.

    Bug fix: `if not axis:` made `axis=0` silently sum over ALL dimensions
    (0 is falsy) and crash on other bare ints; only None now means all
    dims, and a bare int is accepted.
    """
    if axis is None:
        axis = range(len(x.shape))
    elif isinstance(axis, int):
        axis = [axis]
    # Reduce highest dims first so remaining indices stay valid.
    for i in sorted(axis, reverse=True):
        x = torch.sum(x, dim=i, keepdim=keepdim)
    return x
|
| 1227 |
+
|
| 1228 |
+
def random_mask(num_batch=1, mask_shape=(256,256)):
    """Generate a batch of random inpainting masks (1 = hole, 0 = keep).

    Each mask is the union of a random rectangle and 1-3 random free-form
    brush strokes, optionally flipped.

    :param num_batch: number of masks to generate.
    :param mask_shape: (height, width) of each mask.
    :return: float tensor of shape (num_batch, 1, H, W).
    """
    list_mask = []
    for _ in range(num_batch):
        # ---- rectangle mask ----
        image_height = mask_shape[0]
        image_width = mask_shape[1]
        max_delta_height = image_height//8
        max_delta_width = image_width//8
        height = image_height//4
        width = image_width//4
        # Random top-left corner for a (height x width) box.
        t = random.randint(0, image_height - height)
        l = random.randint(0, image_width - width)
        # Random shrink of the box on each side.
        h = random.randint(0, max_delta_height//2)
        w = random.randint(0, max_delta_width//2)
        rect_mask = torch.zeros((1, 1, image_height, image_width))
        rect_mask[:, :, t+h:t+height-h, l+w:l+width-w] = 1

        # ---- brush-stroke mask (free-form) ----
        min_num_vertex = 4
        max_num_vertex = 12
        mean_angle = 2 * math.pi / 5
        angle_range = 2 * math.pi / 15
        min_width = 12
        max_width = 40
        H, W = image_height, image_width
        average_radius = math.sqrt(H*H+W*W) / 8
        mask = Image.new('L', (W, H), 0)

        for _ in range(np.random.randint(1, 4)):
            num_vertex = np.random.randint(min_num_vertex, max_num_vertex)
            angle_min = mean_angle - np.random.uniform(0, angle_range)
            angle_max = mean_angle + np.random.uniform(0, angle_range)
            angles = []
            vertex = []
            # Alternate direction so the polyline zig-zags.
            for i in range(num_vertex):
                if i % 2 == 0:
                    angles.append(2*math.pi - np.random.uniform(angle_min, angle_max))
                else:
                    angles.append(np.random.uniform(angle_min, angle_max))

            h, w = mask.size
            vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h))))
            for i in range(num_vertex):
                r = np.clip(
                    np.random.normal(loc=average_radius, scale=average_radius//2),
                    0, 2*average_radius)
                new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w)
                new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h)
                vertex.append((int(new_x), int(new_y)))

            draw = ImageDraw.Draw(mask)
            width = int(np.random.uniform(min_width, max_width))
            draw.line(vertex, fill=255, width=width)
            # Round the stroke joints with filled circles.
            for v in vertex:
                draw.ellipse((v[0] - width//2,
                              v[1] - width//2,
                              v[0] + width//2,
                              v[1] + width//2),
                             fill=255)

        # Bug fix: PIL's transpose() returns a NEW image; the original
        # discarded the result, so the random flips never took effect.
        if np.random.normal() > 0:
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
        if np.random.normal() > 0:
            mask = mask.transpose(Image.FLIP_TOP_BOTTOM)

        mask = transforms.ToTensor()(mask)
        brush_mask = mask.reshape((1, 1, H, W))

        # Union of rectangle and brush masks via element-wise max.
        mask = torch.cat([rect_mask, brush_mask], dim=1).max(dim=1, keepdim=True)[0]
        list_mask.append(mask)
    return torch.cat(list_mask, dim=0)
|