Spaces:
Sleeping
Sleeping
Cleanup
Browse files- evals.py +1 -1
- utils.py → model_factory.py +0 -1
- models/PDNet.py +0 -322
- physics/inpainting_generator.py +0 -107
evals.py
CHANGED
|
@@ -7,7 +7,7 @@ from deepinv.physics.generator import MotionBlurGenerator, SigmaGenerator
|
|
| 7 |
from torchvision import transforms
|
| 8 |
|
| 9 |
from datasets import Preprocessed_fastMRI, Preprocessed_LIDCIDRI, LsdirMiniDataset
|
| 10 |
-
from
|
| 11 |
|
| 12 |
DEFAULT_MODEL_PARAMS = {
|
| 13 |
"in_channels": [1, 2, 3],
|
|
|
|
| 7 |
from torchvision import transforms
|
| 8 |
|
| 9 |
from datasets import Preprocessed_fastMRI, Preprocessed_LIDCIDRI, LsdirMiniDataset
|
| 10 |
+
from model_factory import get_model
|
| 11 |
|
| 12 |
DEFAULT_MODEL_PARAMS = {
|
| 13 |
"in_channels": [1, 2, 3],
|
utils.py → model_factory.py
RENAMED
|
@@ -3,7 +3,6 @@ import torch.nn as nn
|
|
| 3 |
import deepinv as dinv
|
| 4 |
|
| 5 |
from models.unext_wip import UNeXt
|
| 6 |
-
from models.unrolled_dpir import get_unrolled_architecture
|
| 7 |
from physics.multiscale import Pad
|
| 8 |
|
| 9 |
|
|
|
|
| 3 |
import deepinv as dinv
|
| 4 |
|
| 5 |
from models.unext_wip import UNeXt
|
|
|
|
| 6 |
from physics.multiscale import Pad
|
| 7 |
|
| 8 |
|
models/PDNet.py
DELETED
|
@@ -1,322 +0,0 @@
|
|
| 1 |
-
from pathlib import Path
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
from torch.func import vmap
|
| 5 |
-
from torch.utils.data import DataLoader
|
| 6 |
-
import deepinv as dinv
|
| 7 |
-
from deepinv.unfolded import unfolded_builder
|
| 8 |
-
from deepinv.utils.phantoms import RandomPhantomDataset, SheppLoganDataset
|
| 9 |
-
from deepinv.optim.optim_iterators import CPIteration, fStep, gStep
|
| 10 |
-
from deepinv.optim import Prior, DataFidelity
|
| 11 |
-
from deepinv.utils import TensorList
|
| 12 |
-
|
| 13 |
-
from physics.multiscale import MultiScaleLinearPhysics
|
| 14 |
-
from models.heads import Heads, Tails, InHead, OutTail, ConvChannels, SNRModule, EquivConvModule, EquivHeads
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
def get_PDNet_architecture(in_channels=[1, 2, 3], out_channels=[1, 2, 3], n_primal=3, n_dual=3, device='cuda'):
|
| 18 |
-
class PDNetIteration(CPIteration):
|
| 19 |
-
r"""Single iteration of learned primal dual.
|
| 20 |
-
We only redefine the fStep and gStep classes.
|
| 21 |
-
The forward method is inherited from the CPIteration class.
|
| 22 |
-
"""
|
| 23 |
-
|
| 24 |
-
def __init__(self, **kwargs):
|
| 25 |
-
super().__init__(**kwargs)
|
| 26 |
-
self.g_step = gStepPDNet(**kwargs)
|
| 27 |
-
self.f_step = fStepPDNet(**kwargs)
|
| 28 |
-
|
| 29 |
-
def forward(
|
| 30 |
-
self, X, cur_data_fidelity, cur_prior, cur_params, y, physics, *args, **kwargs
|
| 31 |
-
):
|
| 32 |
-
r"""
|
| 33 |
-
Single iteration of the Chambolle-Pock algorithm.
|
| 34 |
-
|
| 35 |
-
:param dict X: Dictionary containing the current iterate and the estimated cost.
|
| 36 |
-
:param deepinv.optim.DataFidelity cur_data_fidelity: Instance of the DataFidelity class defining the current data_fidelity.
|
| 37 |
-
:param deepinv.optim.Prior cur_prior: Instance of the Prior class defining the current prior.
|
| 38 |
-
:param dict cur_params: dictionary containing the current parameters of the algorithm.
|
| 39 |
-
:param torch.Tensor y: Input data.
|
| 40 |
-
:param deepinv.physics.Physics physics: Instance of the physics modeling the data-fidelity term.
|
| 41 |
-
:return: Dictionary `{"est": (x, ), "cost": F}` containing the updated current iterate and the estimated current cost.
|
| 42 |
-
"""
|
| 43 |
-
x_prev, z_prev, u_prev = X["est"] # x : primal, z : relaxed primal, u : dual
|
| 44 |
-
BS, C_primal, H_primal, W_primal = x_prev.shape
|
| 45 |
-
_, C_dual, H_dual, W_dual = u_prev.shape
|
| 46 |
-
n_channels = C_primal // n_primal
|
| 47 |
-
K = lambda x: torch.cat(
|
| 48 |
-
[physics.A(x[:, i * n_channels:(i + 1) * n_channels, :, :]) for i in range(n_primal)], dim=1)
|
| 49 |
-
K_adjoint = lambda x: torch.cat(
|
| 50 |
-
[physics.A_adjoint(x[:, i * n_channels:(i + 1) * n_channels, :, :]) for i in range(n_dual)], dim=1)
|
| 51 |
-
u = self.f_step(u_prev, K(z_prev), cur_data_fidelity, y, physics, n_channels,
|
| 52 |
-
cur_params) # dual update (data_fid)
|
| 53 |
-
x = self.g_step(x_prev, K_adjoint(u), cur_prior, n_channels, cur_params) # primal update (prior)
|
| 54 |
-
z = x + cur_params["beta"] * (x - x_prev)
|
| 55 |
-
F = (
|
| 56 |
-
self.F_fn(x, cur_data_fidelity, cur_prior, cur_params, y, physics)
|
| 57 |
-
if self.has_cost
|
| 58 |
-
else None
|
| 59 |
-
)
|
| 60 |
-
return {"est": (x, z, u), "cost": F}
|
| 61 |
-
|
| 62 |
-
class fStepPDNet(fStep):
|
| 63 |
-
r"""
|
| 64 |
-
Dual update of the PDNet algorithm.
|
| 65 |
-
We write it as a proximal operator of the data fidelity term.
|
| 66 |
-
This proximal mapping is to be replaced by a trainable model.
|
| 67 |
-
"""
|
| 68 |
-
|
| 69 |
-
def __init__(self, **kwargs):
|
| 70 |
-
super().__init__(**kwargs)
|
| 71 |
-
|
| 72 |
-
def forward(self, x, w, cur_data_fidelity, y, physics, n_channels, *args):
|
| 73 |
-
r"""
|
| 74 |
-
:param torch.Tensor x: Current first variable :math:`u`.
|
| 75 |
-
:param torch.Tensor w: Current second variable :math:`A z`.
|
| 76 |
-
:param deepinv.optim.data_fidelity cur_data_fidelity: Instance of the DataFidelity class defining the current data fidelity term.
|
| 77 |
-
:param torch.Tensor y: Input data.
|
| 78 |
-
"""
|
| 79 |
-
return cur_data_fidelity.prox(x, w, y, n_channels)
|
| 80 |
-
|
| 81 |
-
class gStepPDNet(gStep):
|
| 82 |
-
r"""
|
| 83 |
-
Primal update of the PDNet algorithm.
|
| 84 |
-
We write it as a proximal operator of the prior term.
|
| 85 |
-
This proximal mapping is to be replaced by a trainable model.
|
| 86 |
-
"""
|
| 87 |
-
|
| 88 |
-
def __init__(self, **kwargs):
|
| 89 |
-
super().__init__(**kwargs)
|
| 90 |
-
|
| 91 |
-
def forward(self, x, w, cur_prior, n_channels, *args):
|
| 92 |
-
r"""
|
| 93 |
-
:param torch.Tensor x: Current first variable :math:`x`.
|
| 94 |
-
:param torch.Tensor w: Current second variable :math:`A^\top u`.
|
| 95 |
-
:param deepinv.optim.prior cur_prior: Instance of the Prior class defining the current prior.
|
| 96 |
-
"""
|
| 97 |
-
return cur_prior.prox(x, w, n_channels)
|
| 98 |
-
|
| 99 |
-
# %%
|
| 100 |
-
# Define the trainable prior and data fidelity terms.
|
| 101 |
-
# ---------------------------------------------------
|
| 102 |
-
# Prior and data-fidelity are respectively defined as subclass of :class:`deepinv.optim.Prior` and :class:`deepinv.optim.DataFidelity`.
|
| 103 |
-
# Their proximal operators are replaced by trainable models.
|
| 104 |
-
|
| 105 |
-
class PDNetPrior(Prior):
|
| 106 |
-
def __init__(self, model, *args, **kwargs):
|
| 107 |
-
super().__init__(*args, **kwargs)
|
| 108 |
-
self.model = model
|
| 109 |
-
|
| 110 |
-
def prox(self, x, w, n_channels):
|
| 111 |
-
# give to the model : full primal + premier de dual
|
| 112 |
-
dual_cond = w[:, 0:n_channels, :, :]
|
| 113 |
-
return self.model(x, dual_cond)
|
| 114 |
-
|
| 115 |
-
class PDNetDataFid(DataFidelity):
|
| 116 |
-
def __init__(self, model, *args, **kwargs):
|
| 117 |
-
super().__init__(*args, **kwargs)
|
| 118 |
-
self.model = model
|
| 119 |
-
|
| 120 |
-
def prox(self, x, w, y, n_channels):
|
| 121 |
-
# give to the model : full dual + deuxieme de primal + y = n_channel*n_dual + n_channel + n_channel
|
| 122 |
-
if n_primal > 1:
|
| 123 |
-
primal_cond = w[:, n_channels:(2 * n_channels), :, :]
|
| 124 |
-
else:
|
| 125 |
-
primal_cond = w[:, 0:n_channels, :, :]
|
| 126 |
-
return self.model(x, primal_cond, y)
|
| 127 |
-
|
| 128 |
-
# Unrolled optimization algorithm parameters
|
| 129 |
-
max_iter = 10
|
| 130 |
-
|
| 131 |
-
# Set up the data fidelity term. Each layer has its own data fidelity module.
|
| 132 |
-
in_channels_dual = [in_channel * n_dual + in_channel + in_channel for in_channel in in_channels]
|
| 133 |
-
out_channels_dual = [in_channel * n_dual for in_channel in in_channels]
|
| 134 |
-
in_channels_primal = [in_channel * n_primal + in_channel for in_channel in in_channels]
|
| 135 |
-
out_channels_primal = [in_channel * n_primal for in_channel in in_channels]
|
| 136 |
-
|
| 137 |
-
data_fidelity = [
|
| 138 |
-
PDNetDataFid(model=PDNet_DualBlock(in_channels=in_channels_dual, out_channels=out_channels_dual).to(device)) for
|
| 139 |
-
i in range(max_iter)
|
| 140 |
-
]
|
| 141 |
-
|
| 142 |
-
# Set up the trainable prior. Each layer has its own prior module.
|
| 143 |
-
prior = [
|
| 144 |
-
PDNetPrior(model=PDNet_PrimalBlock(in_channels=in_channels_primal, out_channels=out_channels_primal).to(device))
|
| 145 |
-
for i in range(max_iter)]
|
| 146 |
-
|
| 147 |
-
# %%
|
| 148 |
-
# Define the model.
|
| 149 |
-
# -------------------------------
|
| 150 |
-
|
| 151 |
-
def custom_init(y, physics):
|
| 152 |
-
x0 = physics.A_dagger(y).repeat(1, n_primal, 1, 1)
|
| 153 |
-
u0 = (0 * y).repeat(1, n_dual, 1, 1)
|
| 154 |
-
return {"est": (x0, x0, u0)}
|
| 155 |
-
|
| 156 |
-
def custom_output(X):
|
| 157 |
-
x = X["est"][0]
|
| 158 |
-
n_channels = x.shape[1] // n_primal
|
| 159 |
-
if n_primal > 1:
|
| 160 |
-
return X["est"][0][:, n_channels:(2 * n_channels), :, :]
|
| 161 |
-
else:
|
| 162 |
-
return X["est"][0][:, 0:n_channels, :, :]
|
| 163 |
-
|
| 164 |
-
# %%
|
| 165 |
-
# Define the unfolded trainable model.
|
| 166 |
-
# -------------------------------------
|
| 167 |
-
# The original paper of the learned primal dual algorithm the authors used the adjoint operator
|
| 168 |
-
# in the primal update. However, the same authors (among others) find in the paper
|
| 169 |
-
#
|
| 170 |
-
# A. Hauptmann, J. Adler, S. Arridge, O. Öktem,
|
| 171 |
-
# Multi-scale learned iterative reconstruction,
|
| 172 |
-
# IEEE Transactions on Computational Imaging 6, 843-856, 2020.
|
| 173 |
-
#
|
| 174 |
-
# that using a filtered gradient can improve both the training speed and reconstruction quality significantly.
|
| 175 |
-
# Following this approach, we use the filtered backprojection instead of the adjoint operator in the primal step.
|
| 176 |
-
|
| 177 |
-
model = unfolded_builder(
|
| 178 |
-
iteration=PDNetIteration(),
|
| 179 |
-
params_algo={"beta": 0.0},
|
| 180 |
-
data_fidelity=data_fidelity,
|
| 181 |
-
prior=prior,
|
| 182 |
-
max_iter=max_iter,
|
| 183 |
-
custom_init=custom_init,
|
| 184 |
-
get_output=custom_output,
|
| 185 |
-
)
|
| 186 |
-
|
| 187 |
-
return model.to(device)
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
def init_weights(m):
|
| 191 |
-
if isinstance(m, torch.nn.Linear):
|
| 192 |
-
torch.torch.nn.init.xavier_uniform(m.weight)
|
| 193 |
-
m.bias.data.fill_(0.0)
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
class PDNet_PrimalBlock(torch.nn.Module):
|
| 197 |
-
r"""
|
| 198 |
-
Primal block for the Primal-Dual unfolding model.
|
| 199 |
-
|
| 200 |
-
From https://arxiv.org/abs/1707.06474.
|
| 201 |
-
|
| 202 |
-
Primal variables are images of shape (batch_size, in_channels, height, width). The input of each
|
| 203 |
-
primal block is the concatenation of the current primal variable and the backprojected dual variable along
|
| 204 |
-
the channel dimension. The output of each primal block is the current primal variable.
|
| 205 |
-
|
| 206 |
-
:param int in_channels: number of input channels. Default: 6.
|
| 207 |
-
:param int out_channels: number of output channels. Default: 5.
|
| 208 |
-
:param int depth: number of convolutional layers in the block. Default: 3.
|
| 209 |
-
:param bool bias: whether to use bias in convolutional layers. Default: True.
|
| 210 |
-
:param int nf: number of features in the convolutional layers. Default: 32.
|
| 211 |
-
"""
|
| 212 |
-
|
| 213 |
-
def __init__(self, in_channels=[1, 2, 3], out_channels=[1, 2, 3], depth=3, bias=True, nf=32):
|
| 214 |
-
super(PDNet_PrimalBlock, self).__init__()
|
| 215 |
-
|
| 216 |
-
self.separate_head = isinstance(in_channels, list)
|
| 217 |
-
self.depth = depth
|
| 218 |
-
|
| 219 |
-
self.in_conv = InHead(in_channels, nf, bias=bias)
|
| 220 |
-
# self.m_head.apply(init_weights)
|
| 221 |
-
|
| 222 |
-
# self.in_conv = torch.nn.Conv2d(
|
| 223 |
-
# in_channels, nf, kernel_size=3, stride=1, padding=1, bias=bias
|
| 224 |
-
# )
|
| 225 |
-
|
| 226 |
-
self.in_conv.apply(init_weights)
|
| 227 |
-
self.conv_list = torch.nn.ModuleList(
|
| 228 |
-
[
|
| 229 |
-
torch.nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=bias)
|
| 230 |
-
for _ in range(self.depth - 2)
|
| 231 |
-
]
|
| 232 |
-
)
|
| 233 |
-
self.conv_list.apply(init_weights)
|
| 234 |
-
# self.out_conv = torch.nn.Conv2d(
|
| 235 |
-
# nf, out_channels, kernel_size=3, stride=1, padding=1, bias=bias
|
| 236 |
-
# )
|
| 237 |
-
self.out_conv = OutTail(nf, out_channels, bias=bias)
|
| 238 |
-
self.out_conv.apply(init_weights)
|
| 239 |
-
|
| 240 |
-
self.nl_list = torch.nn.ModuleList([torch.nn.PReLU() for _ in range(self.depth - 1)])
|
| 241 |
-
|
| 242 |
-
def forward(self, x, Atu):
|
| 243 |
-
r"""
|
| 244 |
-
Forward pass of the primal block.
|
| 245 |
-
|
| 246 |
-
:param torch.Tensor x: current primal variable.
|
| 247 |
-
:param torch.Tensor Atu: backprojected dual variable.
|
| 248 |
-
:return: (:class:`torch.Tensor`) the current primal variable.
|
| 249 |
-
"""
|
| 250 |
-
primal_channels = x.shape[1]
|
| 251 |
-
x_in = torch.cat((x, Atu), dim=1)
|
| 252 |
-
|
| 253 |
-
x_ = self.in_conv(x_in)
|
| 254 |
-
x_ = self.nl_list[0](x_)
|
| 255 |
-
|
| 256 |
-
for i in range(self.depth - 2):
|
| 257 |
-
x_l = self.conv_list[i](x_)
|
| 258 |
-
x_ = self.nl_list[i + 1](x_l)
|
| 259 |
-
|
| 260 |
-
return self.out_conv(x_, primal_channels) + x
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
class PDNet_DualBlock(torch.nn.Module):
|
| 264 |
-
r"""
|
| 265 |
-
Dual block for the Primal-Dual unfolding model.
|
| 266 |
-
|
| 267 |
-
From https://arxiv.org/abs/1707.06474.
|
| 268 |
-
|
| 269 |
-
Dual variables are images of shape (batch_size, in_channels, height, width). The input of each
|
| 270 |
-
primal block is the concatenation of the current dual variable with the projected primal variable and
|
| 271 |
-
the measurements. The output of each dual block is the current primal variable.
|
| 272 |
-
|
| 273 |
-
:param int in_channels: number of input channels. Default: 7.
|
| 274 |
-
:param int out_channels: number of output channels. Default: 5.
|
| 275 |
-
:param int depth: number of convolutional layers in the block. Default: 3.
|
| 276 |
-
:param bool bias: whether to use bias in convolutional layers. Default: True.
|
| 277 |
-
:param int nf: number of features in the convolutional layers. Default: 32.
|
| 278 |
-
"""
|
| 279 |
-
|
| 280 |
-
def __init__(self, in_channels=[1, 2, 3], out_channels=[6, 2, 3], depth=3, bias=True, nf=32):
|
| 281 |
-
super(PDNet_DualBlock, self).__init__()
|
| 282 |
-
|
| 283 |
-
self.depth = depth
|
| 284 |
-
self.in_conv = InHead(in_channels, nf, bias=bias)
|
| 285 |
-
# self.in_conv = torch.nn.Conv2d(
|
| 286 |
-
# in_channels, nf, kernel_size=3, stride=1, padding=1, bias=bias
|
| 287 |
-
# )
|
| 288 |
-
self.in_conv.apply(init_weights)
|
| 289 |
-
self.conv_list = torch.nn.ModuleList(
|
| 290 |
-
[
|
| 291 |
-
torch.nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=bias)
|
| 292 |
-
for _ in range(self.depth - 2)
|
| 293 |
-
]
|
| 294 |
-
)
|
| 295 |
-
self.conv_list.apply(init_weights)
|
| 296 |
-
self.out_conv = OutTail(nf, out_channels, bias=bias)
|
| 297 |
-
# self.out_conv = torch.nn.Conv2d(
|
| 298 |
-
# nf, out_channels, kernel_size=3, stride=1, padding=1, bias=bias
|
| 299 |
-
# )
|
| 300 |
-
self.out_conv.apply(init_weights)
|
| 301 |
-
|
| 302 |
-
self.nl_list = torch.nn.ModuleList([torch.nn.PReLU() for _ in range(self.depth - 1)])
|
| 303 |
-
|
| 304 |
-
def forward(self, u, Ax_cur, y):
|
| 305 |
-
r"""
|
| 306 |
-
Forward pass of the dual block.
|
| 307 |
-
|
| 308 |
-
:param torch.Tensor u: current dual variable.
|
| 309 |
-
:param torch.Tensor Ax_cur: projection of the primal variable.
|
| 310 |
-
:param torch.Tensor y: measurements.
|
| 311 |
-
"""
|
| 312 |
-
dual_channels = u.shape[1]
|
| 313 |
-
x_in = torch.cat((u, Ax_cur, y), dim=1)
|
| 314 |
-
|
| 315 |
-
x_ = self.in_conv(x_in)
|
| 316 |
-
x_ = self.nl_list[0](x_)
|
| 317 |
-
|
| 318 |
-
for i in range(self.depth - 2):
|
| 319 |
-
x_l = self.conv_list[i](x_)
|
| 320 |
-
x_ = self.nl_list[i + 1](x_l)
|
| 321 |
-
|
| 322 |
-
return self.out_conv(x_, dual_channels) + u
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
physics/inpainting_generator.py
DELETED
|
@@ -1,107 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
from deepinv.physics.generator import PhysicsGenerator
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
class InpaintingMaskGenerator(PhysicsGenerator):
|
| 6 |
-
|
| 7 |
-
def __init__(
|
| 8 |
-
self,
|
| 9 |
-
mask_shape: tuple,
|
| 10 |
-
num_channels: int = 1,
|
| 11 |
-
device: str = "cpu",
|
| 12 |
-
dtype: type = torch.float32,
|
| 13 |
-
block_size_ratio=0.1,
|
| 14 |
-
num_blocks=5,
|
| 15 |
-
) -> None:
|
| 16 |
-
kwargs = {
|
| 17 |
-
"mask_shape": mask_shape,
|
| 18 |
-
"block_size_ratio": block_size_ratio,
|
| 19 |
-
"num_blocks": num_blocks,
|
| 20 |
-
}
|
| 21 |
-
if len(mask_shape) != 2:
|
| 22 |
-
raise ValueError(
|
| 23 |
-
"mask_shape must 2D. Add channels via num_channels parameter"
|
| 24 |
-
)
|
| 25 |
-
super().__init__(
|
| 26 |
-
num_channels=num_channels,
|
| 27 |
-
device=device,
|
| 28 |
-
dtype=dtype,
|
| 29 |
-
**kwargs,
|
| 30 |
-
)
|
| 31 |
-
|
| 32 |
-
def generate_mask(self, image_shape, block_size_ratio, num_blocks):
|
| 33 |
-
# Create an all-ones tensor which will serve as the initial mask
|
| 34 |
-
mask = torch.ones(image_shape)
|
| 35 |
-
batch_size = mask.shape[0]
|
| 36 |
-
|
| 37 |
-
# Calculate block size based on the image dimensions and block_size_ratio
|
| 38 |
-
block_width = int(image_shape[-2] * block_size_ratio)
|
| 39 |
-
block_height = int(image_shape[-1] * block_size_ratio)
|
| 40 |
-
|
| 41 |
-
# Generate random coordinates for each block in each batch
|
| 42 |
-
x_coords = torch.randint(
|
| 43 |
-
0, image_shape[-1] - block_width, (batch_size, num_blocks)
|
| 44 |
-
)
|
| 45 |
-
y_coords = torch.randint(
|
| 46 |
-
0, image_shape[-2] - block_height, (batch_size, num_blocks)
|
| 47 |
-
)
|
| 48 |
-
|
| 49 |
-
# Create grids of indices for the block dimensions
|
| 50 |
-
x_range = torch.arange(block_width).view(1, 1, -1)
|
| 51 |
-
y_range = torch.arange(block_height).view(1, 1, -1)
|
| 52 |
-
|
| 53 |
-
# Expand ranges to match the batch and num_blocks dimensions
|
| 54 |
-
x_indices = x_coords.unsqueeze(-1) + x_range
|
| 55 |
-
y_indices = y_coords.unsqueeze(-1) + y_range
|
| 56 |
-
|
| 57 |
-
# Expand and flatten the indices for advanced indexing
|
| 58 |
-
x_indices = x_indices.unsqueeze(2).expand(-1, -1, block_height, -1).reshape(-1)
|
| 59 |
-
y_indices = y_indices.unsqueeze(3).expand(-1, -1, -1, block_width).reshape(-1)
|
| 60 |
-
|
| 61 |
-
# Create batch indices for advanced indexing
|
| 62 |
-
batch_indices = (
|
| 63 |
-
torch.arange(batch_size)
|
| 64 |
-
.view(-1, 1, 1)
|
| 65 |
-
.expand(-1, num_blocks, block_width * block_height)
|
| 66 |
-
.reshape(-1)
|
| 67 |
-
)
|
| 68 |
-
channel_indices = (
|
| 69 |
-
torch.arange(3)
|
| 70 |
-
.view(1, 1, 1, -1)
|
| 71 |
-
.expand(batch_size, num_blocks, block_width * block_height, -1)
|
| 72 |
-
.reshape(-1)
|
| 73 |
-
)
|
| 74 |
-
|
| 75 |
-
# Apply the blocks using advanced indexing
|
| 76 |
-
mask[batch_indices, :, y_indices, x_indices] = 0
|
| 77 |
-
|
| 78 |
-
return mask
|
| 79 |
-
|
| 80 |
-
def step(
|
| 81 |
-
self, batch_size: int = 1, block_size_ratio: float = None, num_blocks=None
|
| 82 |
-
):
|
| 83 |
-
r"""
|
| 84 |
-
Generate a random motion blur PSF with parameters :math:`\sigma` and :math:`l`
|
| 85 |
-
|
| 86 |
-
:param int batch_size: batch_size.
|
| 87 |
-
:param float sigma: the standard deviation of the Gaussian Process
|
| 88 |
-
:param float l: the length scale of the trajectory
|
| 89 |
-
|
| 90 |
-
:return: dictionary with key **'filter'**: the generated PSF of shape `(batch_size, 1, psf_size[0], psf_size[1])`
|
| 91 |
-
"""
|
| 92 |
-
|
| 93 |
-
# TODO: add randomness
|
| 94 |
-
block_size_ratio = (
|
| 95 |
-
self.block_size_ratio if block_size_ratio is None else block_size_ratio
|
| 96 |
-
)
|
| 97 |
-
num_blocks = self.num_blocks if num_blocks is None else num_blocks
|
| 98 |
-
batch_shape = (
|
| 99 |
-
batch_size,
|
| 100 |
-
self.num_channels,
|
| 101 |
-
self.mask_shape[-2],
|
| 102 |
-
self.mask_shape[-1],
|
| 103 |
-
)
|
| 104 |
-
|
| 105 |
-
mask = self.generate_mask(batch_shape, block_size_ratio, num_blocks)
|
| 106 |
-
|
| 107 |
-
return {"mask": mask.to(self.factory_kwargs["device"])}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|