Spaces:
Sleeping
Sleeping
Bugfix
Browse files- .gitignore +2 -2
- app.py +6 -4
- evals.py +3 -5
- models/blocks.py +1 -1
- utils.py +47 -30
.gitignore
CHANGED
|
@@ -1,2 +1,2 @@
|
|
| 1 |
-
.
|
| 2 |
-
__pycache__
|
|
|
|
| 1 |
+
.ipynb_checkpoints/
|
| 2 |
+
__pycache__/
|
app.py
CHANGED
|
@@ -14,6 +14,9 @@ from torchvision import transforms
|
|
| 14 |
from evals import PhysicsWithGenerator, EvalModel, BaselineModel, EvalDataset, Metric
|
| 15 |
|
| 16 |
|
|
|
|
|
|
|
|
|
|
| 17 |
### Gradio Utils
|
| 18 |
def generate_imgs(dataset: EvalDataset, idx: int,
|
| 19 |
model: EvalModel, baseline: BaselineModel,
|
|
@@ -152,8 +155,8 @@ with gr.Blocks(title=title, css=custom_css) as interface:
|
|
| 152 |
# Loading things
|
| 153 |
model_a_placeholder = gr.State(lambda: get_eval_model_on_DEVICE_STR("unext_emb_physics_config_C", "")) # lambda expression to instanciate a callable in a gr.State
|
| 154 |
model_b_placeholder = gr.State(lambda: get_baseline_model_on_DEVICE_STR("DRUNET")) # lambda expression to instanciate a callable in a gr.State
|
| 155 |
-
dataset_placeholder = gr.State(get_dataset_on_DEVICE_STR("
|
| 156 |
-
physics_placeholder = gr.State(lambda: get_physics_generator_on_DEVICE_STR("
|
| 157 |
metrics_placeholder = gr.State(get_list_metrics_on_DEVICE_STR(["PSNR"]))
|
| 158 |
|
| 159 |
@gr.render(inputs=[model_a_placeholder, model_b_placeholder, dataset_placeholder, physics_placeholder, metrics_placeholder])
|
|
@@ -265,5 +268,4 @@ with gr.Blocks(title=title, css=custom_css) as interface:
|
|
| 265 |
metrics_placeholder],
|
| 266 |
outputs=[idx_slider, clean, y_image, model_a_out, model_b_out, physics_params, y_metrics, out_a_metric, out_b_metric])
|
| 267 |
|
| 268 |
-
|
| 269 |
-
interface.launch()
|
|
|
|
| 14 |
from evals import PhysicsWithGenerator, EvalModel, BaselineModel, EvalDataset, Metric
|
| 15 |
|
| 16 |
|
| 17 |
+
DEVICE_STR = 'cuda'
|
| 18 |
+
|
| 19 |
+
|
| 20 |
### Gradio Utils
|
| 21 |
def generate_imgs(dataset: EvalDataset, idx: int,
|
| 22 |
model: EvalModel, baseline: BaselineModel,
|
|
|
|
| 155 |
# Loading things
|
| 156 |
model_a_placeholder = gr.State(lambda: get_eval_model_on_DEVICE_STR("unext_emb_physics_config_C", "")) # lambda expression to instanciate a callable in a gr.State
|
| 157 |
model_b_placeholder = gr.State(lambda: get_baseline_model_on_DEVICE_STR("DRUNET")) # lambda expression to instanciate a callable in a gr.State
|
| 158 |
+
dataset_placeholder = gr.State(get_dataset_on_DEVICE_STR("Natural"))
|
| 159 |
+
physics_placeholder = gr.State(lambda: get_physics_generator_on_DEVICE_STR("MotionBlur_easy")) # lambda expression to instanciate a callable in a gr.State
|
| 160 |
metrics_placeholder = gr.State(get_list_metrics_on_DEVICE_STR(["PSNR"]))
|
| 161 |
|
| 162 |
@gr.render(inputs=[model_a_placeholder, model_b_placeholder, dataset_placeholder, physics_placeholder, metrics_placeholder])
|
|
|
|
| 268 |
metrics_placeholder],
|
| 269 |
outputs=[idx_slider, clean, y_image, model_a_out, model_b_out, physics_params, y_metrics, out_a_metric, out_b_metric])
|
| 270 |
|
| 271 |
+
interface.launch()
|
|
|
evals.py
CHANGED
|
@@ -486,8 +486,6 @@ class BaselineModel(torch.nn.Module):
|
|
| 486 |
x_adj = physics.A_adjoint(y)
|
| 487 |
output = output[..., :x_adj.size(-2), :x_adj.size(-1)]
|
| 488 |
return output
|
| 489 |
-
elif 'UNROLLED_DPIR' in self.name:
|
| 490 |
-
return self.model(y, physics=physics)
|
| 491 |
else:
|
| 492 |
return self.model(y)
|
| 493 |
|
|
@@ -504,19 +502,19 @@ class EvalDataset(torch.utils.data.Dataset):
|
|
| 504 |
if self.name not in self.all_datasets:
|
| 505 |
raise ValueError(f"{self.name} is unavailable.")
|
| 506 |
if self.name == 'Natural':
|
| 507 |
-
self.root = '
|
| 508 |
self.transform = transforms.Compose([transforms.ToTensor()])
|
| 509 |
self.dataset = dinv.datasets.LsdirHR(root=self.root,
|
| 510 |
download=False,
|
| 511 |
transform=self.transform)
|
| 512 |
elif self.name == 'MRI':
|
| 513 |
-
self.root = '
|
| 514 |
self.transform = transforms.CenterCrop((640, 320)) # , pad_if_needed=True)
|
| 515 |
self.dataset = Preprocessed_fastMRI(root=self.root,
|
| 516 |
transform=self.transform,
|
| 517 |
preprocess=False)
|
| 518 |
elif self.name == "CT":
|
| 519 |
-
self.root = '
|
| 520 |
self.transform = None
|
| 521 |
self.dataset = Preprocessed_LIDCIDRI(root=self.root,
|
| 522 |
transform=self.transform)
|
|
|
|
| 486 |
x_adj = physics.A_adjoint(y)
|
| 487 |
output = output[..., :x_adj.size(-2), :x_adj.size(-1)]
|
| 488 |
return output
|
|
|
|
|
|
|
| 489 |
else:
|
| 490 |
return self.model(y)
|
| 491 |
|
|
|
|
| 502 |
if self.name not in self.all_datasets:
|
| 503 |
raise ValueError(f"{self.name} is unavailable.")
|
| 504 |
if self.name == 'Natural':
|
| 505 |
+
self.root = 'img_samples/LSDIR_samples'
|
| 506 |
self.transform = transforms.Compose([transforms.ToTensor()])
|
| 507 |
self.dataset = dinv.datasets.LsdirHR(root=self.root,
|
| 508 |
download=False,
|
| 509 |
transform=self.transform)
|
| 510 |
elif self.name == 'MRI':
|
| 511 |
+
self.root = 'img_samples/FastMRI_samples'
|
| 512 |
self.transform = transforms.CenterCrop((640, 320)) # , pad_if_needed=True)
|
| 513 |
self.dataset = Preprocessed_fastMRI(root=self.root,
|
| 514 |
transform=self.transform,
|
| 515 |
preprocess=False)
|
| 516 |
elif self.name == "CT":
|
| 517 |
+
self.root = 'img_samples/LIDC_IDRI_samples'
|
| 518 |
self.transform = None
|
| 519 |
self.dataset = Preprocessed_LIDCIDRI(root=self.root,
|
| 520 |
transform=self.transform)
|
models/blocks.py
CHANGED
|
@@ -7,7 +7,7 @@ import torch.nn.functional as F
|
|
| 7 |
from deepinv.models.unet import BFBatchNorm2d
|
| 8 |
from deepinv.physics.blur import gaussian_blur
|
| 9 |
from deepinv.physics.functional import conv2d
|
| 10 |
-
from deepinv.utils
|
| 11 |
|
| 12 |
from timm.models.layers import trunc_normal_, DropPath
|
| 13 |
|
|
|
|
| 7 |
from deepinv.models.unet import BFBatchNorm2d
|
| 8 |
from deepinv.physics.blur import gaussian_blur
|
| 9 |
from deepinv.physics.functional import conv2d
|
| 10 |
+
from deepinv.utils import TensorList
|
| 11 |
|
| 12 |
from timm.models.layers import trunc_normal_, DropPath
|
| 13 |
|
utils.py
CHANGED
|
@@ -9,8 +9,16 @@ from physics.multiscale import Pad
|
|
| 9 |
|
| 10 |
|
| 11 |
class ArtifactRemoval(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
def __init__(self, backbone_net, pinv=False, ckpt_path=None, device=None, fm_mode=False):
|
| 13 |
-
super().__init__()
|
| 14 |
self.pinv = pinv
|
| 15 |
self.backbone_net = backbone_net
|
| 16 |
self.fm_mode = fm_mode
|
|
@@ -24,7 +32,14 @@ class ArtifactRemoval(nn.Module):
|
|
| 24 |
v.requires_grad = False
|
| 25 |
self.backbone_net = self.backbone_net.to(device)
|
| 26 |
|
|
|
|
| 27 |
def forward_basic(self, y=None, physics=None, x_in=None, t=None, **kwargs):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
if physics is None:
|
| 29 |
physics = dinv.physics.Denoising(noise_model=dinv.physics.GaussianNoise(sigma=0.), device=y.device)
|
| 30 |
|
|
@@ -35,8 +50,15 @@ class ArtifactRemoval(nn.Module):
|
|
| 35 |
|
| 36 |
x_in = physics.A_adjoint(y) if not self.pinv else physics.A_dagger(y)
|
| 37 |
|
| 38 |
-
|
| 39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
out = self.backbone_net(x_in, physics=physics, y=y, sigma=sigma, gamma=gamma, t=t)
|
| 42 |
|
|
@@ -45,14 +67,18 @@ class ArtifactRemoval(nn.Module):
|
|
| 45 |
|
| 46 |
return out
|
| 47 |
|
| 48 |
-
def forward(self,
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
|
| 52 |
def get_model(
|
| 53 |
model_name="unext_emb_physics_config_C",
|
| 54 |
device="cpu",
|
| 55 |
in_channels=[1, 2, 3],
|
|
|
|
| 56 |
conv_type="base",
|
| 57 |
pool_type="base",
|
| 58 |
layer_scale_init_value=1e-6,
|
|
@@ -65,6 +91,7 @@ def get_model(
|
|
| 65 |
antialias="gaussian",
|
| 66 |
nc_base=64,
|
| 67 |
cond_type="base",
|
|
|
|
| 68 |
pretrained_pth=None,
|
| 69 |
weight_tied=True,
|
| 70 |
N=4,
|
|
@@ -73,41 +100,31 @@ def get_model(
|
|
| 73 |
relu_in_encoding=False,
|
| 74 |
skip_in_encoding=True,
|
| 75 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
model_name = model_name.lower()
|
| 77 |
-
nc = [nc_base * 2**i for i in range(4)]
|
| 78 |
|
| 79 |
if model_name == "pdnet":
|
| 80 |
return get_PDNet_architecture(in_channels=in_channels, out_channels=in_channels, device=device)
|
| 81 |
|
| 82 |
-
elif model_name == "unrolled_dpir":
|
| 83 |
-
model = UNeXt(
|
| 84 |
-
in_channels=in_channels,
|
| 85 |
-
out_channels=in_channels,
|
| 86 |
-
device=device,
|
| 87 |
-
conv_type=conv_type,
|
| 88 |
-
pool_type=pool_type,
|
| 89 |
-
layer_scale_init_value=layer_scale_init_value,
|
| 90 |
-
init_type=init_type,
|
| 91 |
-
gain_init_conv=gain_init_conv,
|
| 92 |
-
gain_init_linear=gain_init_linear,
|
| 93 |
-
drop_prob=drop_prob,
|
| 94 |
-
replk=replk,
|
| 95 |
-
mult_fact=mult_fact,
|
| 96 |
-
antialias=antialias,
|
| 97 |
-
nc=nc,
|
| 98 |
-
cond_type=cond_type,
|
| 99 |
-
emb_physics=False,
|
| 100 |
-
config=None,
|
| 101 |
-
pretrained_pth=pretrained_pth,
|
| 102 |
-
).to(device)
|
| 103 |
-
model = get_unrolled_architecture(model=model, weight_tied=weight_tied, device=device)
|
| 104 |
-
return ArtifactRemoval(model, pinv=True, device=device)
|
| 105 |
-
|
| 106 |
elif model_name == "unext_emb_physics_config_c":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
model = UNeXt(
|
| 108 |
in_channels=in_channels,
|
| 109 |
out_channels=in_channels,
|
| 110 |
device=device,
|
|
|
|
| 111 |
conv_type=conv_type,
|
| 112 |
pool_type=pool_type,
|
| 113 |
layer_scale_init_value=layer_scale_init_value,
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
class ArtifactRemoval(nn.Module):
|
| 12 |
+
r"""
|
| 13 |
+
Artifact removal architecture :math:`\phi(A^{\top}y)`.
|
| 14 |
+
|
| 15 |
+
This differs from the dinv.models.ArtifactRemoval in that it allows to forward the physics.
|
| 16 |
+
|
| 17 |
+
In the end we should not use this for unext !!
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
def __init__(self, backbone_net, pinv=False, ckpt_path=None, device=None, fm_mode=False):
|
| 21 |
+
super(ArtifactRemoval, self).__init__()
|
| 22 |
self.pinv = pinv
|
| 23 |
self.backbone_net = backbone_net
|
| 24 |
self.fm_mode = fm_mode
|
|
|
|
| 32 |
v.requires_grad = False
|
| 33 |
self.backbone_net = self.backbone_net.to(device)
|
| 34 |
|
| 35 |
+
|
| 36 |
def forward_basic(self, y=None, physics=None, x_in=None, t=None, **kwargs):
|
| 37 |
+
r"""
|
| 38 |
+
Reconstructs a signal estimate from measurements y
|
| 39 |
+
|
| 40 |
+
:param torch.tensor y: measurements
|
| 41 |
+
:param deepinv.physics.Physics physics: forward operator
|
| 42 |
+
"""
|
| 43 |
if physics is None:
|
| 44 |
physics = dinv.physics.Denoising(noise_model=dinv.physics.GaussianNoise(sigma=0.), device=y.device)
|
| 45 |
|
|
|
|
| 50 |
|
| 51 |
x_in = physics.A_adjoint(y) if not self.pinv else physics.A_dagger(y)
|
| 52 |
|
| 53 |
+
if hasattr(physics.noise_model, "sigma"):
|
| 54 |
+
sigma = physics.noise_model.sigma
|
| 55 |
+
else:
|
| 56 |
+
sigma = 1e-3 # WARNING: this is a default value that we may not want to use?
|
| 57 |
+
|
| 58 |
+
if hasattr(physics.noise_model, "gain"):
|
| 59 |
+
gamma = physics.noise_model.gain
|
| 60 |
+
else:
|
| 61 |
+
gamma = 1e-3 # WARNING: this is a default value that we may not want to use?
|
| 62 |
|
| 63 |
out = self.backbone_net(x_in, physics=physics, y=y, sigma=sigma, gamma=gamma, t=t)
|
| 64 |
|
|
|
|
| 67 |
|
| 68 |
return out
|
| 69 |
|
| 70 |
+
def forward(self, y=None, physics=None, x_in=None, **kwargs):
|
| 71 |
+
if 'unext' in type(self.backbone_net).__name__.lower():
|
| 72 |
+
return self.forward_basic(physics=physics, y=y, x_in=x_in, **kwargs)
|
| 73 |
+
else:
|
| 74 |
+
return self.backbone_net(physics=physics, y=y, **kwargs)
|
| 75 |
|
| 76 |
|
| 77 |
def get_model(
|
| 78 |
model_name="unext_emb_physics_config_C",
|
| 79 |
device="cpu",
|
| 80 |
in_channels=[1, 2, 3],
|
| 81 |
+
grayscale=False,
|
| 82 |
conv_type="base",
|
| 83 |
pool_type="base",
|
| 84 |
layer_scale_init_value=1e-6,
|
|
|
|
| 91 |
antialias="gaussian",
|
| 92 |
nc_base=64,
|
| 93 |
cond_type="base",
|
| 94 |
+
blind=False,
|
| 95 |
pretrained_pth=None,
|
| 96 |
weight_tied=True,
|
| 97 |
N=4,
|
|
|
|
| 100 |
relu_in_encoding=False,
|
| 101 |
skip_in_encoding=True,
|
| 102 |
):
|
| 103 |
+
"""
|
| 104 |
+
Load the model.
|
| 105 |
+
|
| 106 |
+
:param str model_name: name of the model
|
| 107 |
+
:param str device: device
|
| 108 |
+
:param bool grayscale: if True, the model is trained on grayscale images
|
| 109 |
+
:param bool train: if True, the model is trained
|
| 110 |
+
:return: model
|
| 111 |
+
"""
|
| 112 |
model_name = model_name.lower()
|
|
|
|
| 113 |
|
| 114 |
if model_name == "pdnet":
|
| 115 |
return get_PDNet_architecture(in_channels=in_channels, out_channels=in_channels, device=device)
|
| 116 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
elif model_name == "unext_emb_physics_config_c":
|
| 118 |
+
n_chan = [1, 2, 3] # 6 for old head grayscale, complex and color = 1 + 2 + 3
|
| 119 |
+
residual = True if "residual" in model_name else False
|
| 120 |
+
nc = [nc_base * 2**i for i in range(4)]
|
| 121 |
+
|
| 122 |
+
|
| 123 |
model = UNeXt(
|
| 124 |
in_channels=in_channels,
|
| 125 |
out_channels=in_channels,
|
| 126 |
device=device,
|
| 127 |
+
residual=residual,
|
| 128 |
conv_type=conv_type,
|
| 129 |
pool_type=pool_type,
|
| 130 |
layer_scale_init_value=layer_scale_init_value,
|