Spaces:
Sleeping
Sleeping
Commit
·
4dc3e99
0
Parent(s):
gradio demo
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +36 -0
- .gitignore +2 -0
- README.md +13 -0
- app.py +269 -0
- datasets.py +84 -0
- evals.py +564 -0
- img_samples/LSDIR_samples/0001000/0000007_s005.png +0 -0
- img_samples/LSDIR_samples/0001000/0000030_s003.png +0 -0
- img_samples/LSDIR_samples/0001000/0000067_s005.png +0 -0
- img_samples/LSDIR_samples/0001000/0000082_s003.png +0 -0
- img_samples/LSDIR_samples/0001000/0000110_s002.png +0 -0
- img_samples/LSDIR_samples/0001000/0000125_s003.png +0 -0
- img_samples/LSDIR_samples/0001000/0000154_s007.png +0 -0
- img_samples/LSDIR_samples/0001000/0000247_s007.png +0 -0
- img_samples/LSDIR_samples/0001000/0000259_s003.png +0 -0
- img_samples/LSDIR_samples/0001000/0000405_s008.png +0 -0
- img_samples/LSDIR_samples/0001000/0000578_s002.png +0 -0
- img_samples/LSDIR_samples/0001000/0000669_s010.png +0 -0
- img_samples/LSDIR_samples/0001000/0000689_s006.png +0 -0
- img_samples/LSDIR_samples/0001000/0000715_s011.png +0 -0
- img_samples/LSDIR_samples/0001000/0000752_s010.png +0 -0
- img_samples/LSDIR_samples/0001000/0000803_s012.png +0 -0
- img_samples/LSDIR_samples/0001000/0000825_s012.png +0 -0
- img_samples/LSDIR_samples/0001000/0000921_s012.png +0 -0
- img_samples/LSDIR_samples/0001000/0000958_s004.png +0 -0
- img_samples/LSDIR_samples/0001000/0000994_s021.png +0 -0
- img_samples/LSDIR_samples/0009000/0008033_s006.png +0 -0
- img_samples/LSDIR_samples/0009000/0008068_s005.png +0 -0
- img_samples/LSDIR_samples/0009000/0008115_s004.png +0 -0
- img_samples/LSDIR_samples/0009000/0008217_s002.png +0 -0
- img_samples/LSDIR_samples/0009000/0008294_s010.png +0 -0
- img_samples/LSDIR_samples/0009000/0008315_s053.png +0 -0
- img_samples/LSDIR_samples/0009000/0008340_s015.png +0 -0
- img_samples/LSDIR_samples/0009000/0008361_s009.png +0 -0
- img_samples/LSDIR_samples/0009000/0008386_s007.png +0 -0
- img_samples/LSDIR_samples/0009000/0008491_s006.png +0 -0
- img_samples/LSDIR_samples/0009000/0008528_s007.png +0 -0
- img_samples/LSDIR_samples/0009000/0008571_s007.png +0 -0
- img_samples/LSDIR_samples/0009000/0008573_s012.png +0 -0
- img_samples/LSDIR_samples/0009000/0008605_s007.png +0 -0
- img_samples/LSDIR_samples/0009000/0008611_s002.png +0 -0
- img_samples/LSDIR_samples/0009000/0008631_s005.png +0 -0
- img_samples/LSDIR_samples/0009000/0008681_s008.png +0 -0
- img_samples/LSDIR_samples/0009000/0008703_s013.png +0 -0
- img_samples/LSDIR_samples/0009000/0008714_s010.png +0 -0
- img_samples/LSDIR_samples/0009000/0008774_s004.png +0 -0
- img_samples/LSDIR_samples/0023000/0022020_s005.png +0 -0
- img_samples/LSDIR_samples/0023000/0022037_s011.png +0 -0
- img_samples/LSDIR_samples/0023000/0022059_s008.png +0 -0
- img_samples/LSDIR_samples/0023000/0022135_s002.png +0 -0
.gitattributes
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.pth.tar filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.ipynb
|
| 2 |
+
__pycache__
|
README.md
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Denoising
|
| 3 |
+
emoji: 💻
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: yellow
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 4.19.2
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
license: bsd-3-clause
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
|
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
from functools import partial
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import List
|
| 7 |
+
|
| 8 |
+
import deepinv as dinv
|
| 9 |
+
import gradio as gr
|
| 10 |
+
import torch
|
| 11 |
+
from PIL import Image
|
| 12 |
+
from torchvision import transforms
|
| 13 |
+
|
| 14 |
+
from evals import PhysicsWithGenerator, EvalModel, BaselineModel, EvalDataset, Metric
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
### Gradio Utils
|
| 18 |
+
def generate_imgs(dataset: EvalDataset, idx: int,
|
| 19 |
+
model: EvalModel, baseline: BaselineModel,
|
| 20 |
+
physics: PhysicsWithGenerator, use_gen: bool,
|
| 21 |
+
metrics: List[Metric]):
|
| 22 |
+
### Load 1 image
|
| 23 |
+
x = dataset[idx] # shape : (3, 256, 256)
|
| 24 |
+
x = x.unsqueeze(0) # shape : (1, 3, 256, 256)
|
| 25 |
+
|
| 26 |
+
with torch.no_grad():
|
| 27 |
+
### Compute y
|
| 28 |
+
y = physics(x, use_gen) # possible reduction in img shape due to Blurring
|
| 29 |
+
|
| 30 |
+
### Compute x_hat
|
| 31 |
+
out = model(y=y, physics=physics.physics)
|
| 32 |
+
out_baseline = baseline(y=y, physics=physics.physics)
|
| 33 |
+
|
| 34 |
+
### Process tensors before metric computation
|
| 35 |
+
if "Blur" in physics.name:
|
| 36 |
+
w_1, w_2 = (x.shape[2] - y.shape[2]) // 2, (x.shape[2] + y.shape[2]) // 2
|
| 37 |
+
h_1, h_2 = (x.shape[3] - y.shape[3]) // 2, (x.shape[3] + y.shape[3]) // 2
|
| 38 |
+
|
| 39 |
+
x = x[..., w_1:w_2, h_1:h_2]
|
| 40 |
+
out = out[..., w_1:w_2, h_1:h_2]
|
| 41 |
+
if out_baseline.shape != out.shape:
|
| 42 |
+
out_baseline = out_baseline[..., w_1:w_2, h_1:h_2]
|
| 43 |
+
|
| 44 |
+
### Metrics
|
| 45 |
+
metrics_y = ""
|
| 46 |
+
metrics_out = ""
|
| 47 |
+
metrics_out_baseline = ""
|
| 48 |
+
for metric in metrics:
|
| 49 |
+
if y.shape == x.shape:
|
| 50 |
+
metrics_y += f"{metric.name} = {metric(y, x).item():.4f}" + "\n"
|
| 51 |
+
metrics_out += f"{metric.name} = {metric(out, x).item():.4f}" + "\n"
|
| 52 |
+
metrics_out_baseline += f"{metric.name} = {metric(out_baseline, x).item():.4f}" + "\n"
|
| 53 |
+
|
| 54 |
+
### Process y when y shape is different from x shape
|
| 55 |
+
if physics.name == "MRI" or "SR" in physics.name:
|
| 56 |
+
y_plot = physics.physics.prox_l2(physics.physics.A_adjoint(y), y, 1e4)
|
| 57 |
+
else:
|
| 58 |
+
y_plot = y.clone()
|
| 59 |
+
|
| 60 |
+
### Processing images for plotting :
|
| 61 |
+
# - clip value outside of [0,1]
|
| 62 |
+
# - shape (1, C, H, W) -> (C, H, W)
|
| 63 |
+
# - torch.Tensor object -> Pil object
|
| 64 |
+
process_img = partial(dinv.utils.plotting.preprocess_img, rescale_mode="clip")
|
| 65 |
+
to_pil = transforms.ToPILImage()
|
| 66 |
+
x = to_pil(process_img(x)[0].to('cpu'))
|
| 67 |
+
y = to_pil(process_img(y_plot)[0].to('cpu'))
|
| 68 |
+
out = to_pil(process_img(out)[0].to('cpu'))
|
| 69 |
+
out_baseline = to_pil(process_img(out_baseline)[0].to('cpu'))
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
return x, y, out, out_baseline, physics.display_saved_params(), metrics_y, metrics_out, metrics_out_baseline
|
| 73 |
+
|
| 74 |
+
def update_random_idx_and_generate_imgs(dataset: EvalDataset,
|
| 75 |
+
model: EvalModel,
|
| 76 |
+
baseline: BaselineModel,
|
| 77 |
+
physics: PhysicsWithGenerator,
|
| 78 |
+
use_gen: bool,
|
| 79 |
+
metrics: List[Metric]):
|
| 80 |
+
idx = random.randint(0, len(dataset))
|
| 81 |
+
x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline = generate_imgs(dataset,
|
| 82 |
+
idx,
|
| 83 |
+
model,
|
| 84 |
+
baseline,
|
| 85 |
+
physics,
|
| 86 |
+
use_gen,
|
| 87 |
+
metrics)
|
| 88 |
+
return idx, x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline
|
| 89 |
+
|
| 90 |
+
def save_imgs(dataset: EvalDataset, idx: int, physics: PhysicsWithGenerator,
|
| 91 |
+
model_a: EvalModel | BaselineModel, model_b: EvalModel | BaselineModel,
|
| 92 |
+
x: Image.Image, y: Image.Image,
|
| 93 |
+
out_a: Image.Image, out_b: Image.Image,
|
| 94 |
+
y_metrics_str: str,
|
| 95 |
+
out_a_metric_str : str, out_b_metric_str: str) -> None:
|
| 96 |
+
|
| 97 |
+
### PROCESSES STR
|
| 98 |
+
physics_params_str = ""
|
| 99 |
+
for param_name, param_value in physics.saved_params["updatable_params"].items():
|
| 100 |
+
physics_params_str += f"{param_name}_{param_value}-"
|
| 101 |
+
physics_params_str = physics_params_str[:-1] if physics_params_str.endswith("-") else physics_params_str
|
| 102 |
+
y_metrics_str = y_metrics_str.replace(" = ", "_").replace("\n", "-")
|
| 103 |
+
y_metrics_str = y_metrics_str[:-1] if y_metrics_str.endswith("-") else y_metrics_str
|
| 104 |
+
out_a_metric_str = out_a_metric_str.replace(" = ", "_").replace("\n", "-")
|
| 105 |
+
out_a_metric_str = out_a_metric_str[:-1] if out_a_metric_str.endswith("-") else out_a_metric_str
|
| 106 |
+
out_b_metric_str = out_b_metric_str.replace(" = ", "_").replace("\n", "-")
|
| 107 |
+
out_b_metric_str = out_b_metric_str[:-1] if out_b_metric_str.endswith("-") else out_b_metric_str
|
| 108 |
+
|
| 109 |
+
save_path = SAVE_IMG_DIR / f"{dataset.name}+{idx}+{physics.name}+{physics_params_str}+{y_metrics_str}+{model_a.name}+{out_a_metric_str}+{model_b.name}+{out_b_metric_str}.png"
|
| 110 |
+
titles = [f"{dataset.name}[{idx}]",
|
| 111 |
+
f"y = {physics.name}(x)",
|
| 112 |
+
f"{model_a.name}",
|
| 113 |
+
f"{model_b.name}"]
|
| 114 |
+
|
| 115 |
+
# Pil object -> torch.Tensor
|
| 116 |
+
to_tensor = transforms.ToTensor()
|
| 117 |
+
x = to_tensor(x)
|
| 118 |
+
y = to_tensor(y)
|
| 119 |
+
out_a = to_tensor(out_a)
|
| 120 |
+
out_b = to_tensor(out_b)
|
| 121 |
+
|
| 122 |
+
dinv.utils.plot([x, y, out_a, out_b], titles=titles, show=False, save_fn=save_path)
|
| 123 |
+
|
| 124 |
+
get_list_metrics_on_DEVICE_STR = partial(Metric.get_list_metrics, device_str=DEVICE_STR)
|
| 125 |
+
get_baseline_model_on_DEVICE_STR = partial(BaselineModel, device_str=DEVICE_STR)
|
| 126 |
+
get_eval_model_on_DEVICE_STR = partial(EvalModel, device_str=DEVICE_STR)
|
| 127 |
+
get_dataset_on_DEVICE_STR = partial(EvalDataset, device_str=DEVICE_STR)
|
| 128 |
+
get_physics_generator_on_DEVICE_STR = partial(PhysicsWithGenerator, device_str=DEVICE_STR)
|
| 129 |
+
|
| 130 |
+
def get_model(model_name, ckpt_pth):
|
| 131 |
+
if model_name in BaselineModel.all_baselines:
|
| 132 |
+
return get_baseline_model_on_DEVICE_STR(model_name)
|
| 133 |
+
else:
|
| 134 |
+
return get_eval_model_on_DEVICE_STR(model_name, ckpt_pth)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
### Gradio Blocks interface
|
| 138 |
+
|
| 139 |
+
# Define custom CSS
|
| 140 |
+
custom_css = """
|
| 141 |
+
.fixed-textbox textarea {
|
| 142 |
+
height: 90px !important; /* Adjust height to fit exactly 4 lines */
|
| 143 |
+
overflow: scroll; /* Add a scroll bar if necessary */
|
| 144 |
+
resize: none; /* User can resize vertically the textbox */
|
| 145 |
+
}
|
| 146 |
+
"""
|
| 147 |
+
|
| 148 |
+
title = "Inverse problem playground" # displayed on gradio tab and in the gradio page
|
| 149 |
+
with gr.Blocks(title=title, css=custom_css) as interface:
|
| 150 |
+
gr.Markdown("## " + title)
|
| 151 |
+
|
| 152 |
+
# Loading things
|
| 153 |
+
model_a_placeholder = gr.State(lambda: get_eval_model_on_DEVICE_STR("unext_emb_physics_config_C", "")) # lambda expression to instanciate a callable in a gr.State
|
| 154 |
+
model_b_placeholder = gr.State(lambda: get_baseline_model_on_DEVICE_STR("DRUNET")) # lambda expression to instanciate a callable in a gr.State
|
| 155 |
+
dataset_placeholder = gr.State(get_dataset_on_DEVICE_STR("DIV2K_valid_HR"))
|
| 156 |
+
physics_placeholder = gr.State(lambda: get_physics_generator_on_DEVICE_STR("Denoising")) # lambda expression to instanciate a callable in a gr.State
|
| 157 |
+
metrics_placeholder = gr.State(get_list_metrics_on_DEVICE_STR(["PSNR"]))
|
| 158 |
+
|
| 159 |
+
@gr.render(inputs=[model_a_placeholder, model_b_placeholder, dataset_placeholder, physics_placeholder, metrics_placeholder])
|
| 160 |
+
def dynamic_layout(model_a, model_b, dataset, physics, metrics):
|
| 161 |
+
### LAYOUT
|
| 162 |
+
model_a_name = model_a.base_name
|
| 163 |
+
model_a_full_name = model_a.name
|
| 164 |
+
model_b_name = model_b.base_name
|
| 165 |
+
model_b_full_name = model_b.name
|
| 166 |
+
dataset_name = dataset.name
|
| 167 |
+
physics_name = physics.name
|
| 168 |
+
metric_names = [metric.name for metric in metrics]
|
| 169 |
+
|
| 170 |
+
# Components: Inputs/Outputs + Load EvalDataset/PhysicsWithGenerator/EvalModel/BaselineModel
|
| 171 |
+
with gr.Row():
|
| 172 |
+
with gr.Column():
|
| 173 |
+
with gr.Row():
|
| 174 |
+
with gr.Column():
|
| 175 |
+
clean = gr.Image(label=f"{dataset_name} IMAGE", interactive=False)
|
| 176 |
+
physics_params = gr.Textbox(label="Physics parameters", elem_classes=["fixed-textbox"], value=physics.display_saved_params())
|
| 177 |
+
with gr.Column():
|
| 178 |
+
y_image = gr.Image(label=f"{physics_name} IMAGE", interactive=False)
|
| 179 |
+
y_metrics = gr.Textbox(label="Metrics(y, x)", elem_classes=["fixed-textbox"],)
|
| 180 |
+
with gr.Row():
|
| 181 |
+
choose_dataset = gr.Radio(choices=EvalDataset.all_datasets,
|
| 182 |
+
label="List of EvalDataset",
|
| 183 |
+
value=dataset_name,
|
| 184 |
+
scale=2)
|
| 185 |
+
idx_slider = gr.Slider(minimum=0, maximum=len(dataset)-1, step=1, label="Sample index", scale=1)
|
| 186 |
+
|
| 187 |
+
choose_physics = gr.Radio(choices=PhysicsWithGenerator.all_physics,
|
| 188 |
+
label="List of PhysicsWithGenerator",
|
| 189 |
+
value=physics_name)
|
| 190 |
+
with gr.Row():
|
| 191 |
+
key_selector = gr.Dropdown(choices=list(physics.saved_params["updatable_params"].keys()),
|
| 192 |
+
label="Updatable Parameter Key",
|
| 193 |
+
scale=2)
|
| 194 |
+
value_text = gr.Textbox(label="Update Value", scale=2)
|
| 195 |
+
with gr.Column(scale=1):
|
| 196 |
+
update_button = gr.Button("Update Param")
|
| 197 |
+
use_generator_button = gr.Checkbox(label="Use param generator")
|
| 198 |
+
|
| 199 |
+
with gr.Column():
|
| 200 |
+
with gr.Row():
|
| 201 |
+
with gr.Column():
|
| 202 |
+
model_a_out = gr.Image(label=f"{model_a_full_name} OUTPUT", interactive=False)
|
| 203 |
+
out_a_metric = gr.Textbox(label="Metrics(model_a(y), x)", elem_classes=["fixed-textbox"])
|
| 204 |
+
load_model_a = gr.Button("Load model A...", scale=1)
|
| 205 |
+
with gr.Column():
|
| 206 |
+
model_b_out = gr.Image(label=f"{model_b_full_name} OUTPUT", interactive=False)
|
| 207 |
+
out_b_metric = gr.Textbox(label="Metrics(model_b(y), x)", elem_classes=["fixed-textbox"])
|
| 208 |
+
load_model_b = gr.Button("Load model B...", scale=1)
|
| 209 |
+
with gr.Row():
|
| 210 |
+
choose_model_a = gr.Dropdown(choices=EvalModel.all_models + BaselineModel.all_baselines,
|
| 211 |
+
label="List of Model A",
|
| 212 |
+
value=model_a_name,
|
| 213 |
+
scale=2)
|
| 214 |
+
path_a_str = gr.Textbox(value=model_a.ckpt_pth, label="Checkpoint path", scale=3)
|
| 215 |
+
with gr.Row():
|
| 216 |
+
choose_model_b = gr.Dropdown(choices=EvalModel.all_models + BaselineModel.all_baselines,
|
| 217 |
+
label="List of Model B",
|
| 218 |
+
value=model_b_name,
|
| 219 |
+
scale=2)
|
| 220 |
+
path_b_str = gr.Textbox(value=model_b.ckpt_pth, label="Checkpoint path", scale=3)
|
| 221 |
+
|
| 222 |
+
# Components: Load Metric + Load/Save Buttons
|
| 223 |
+
with gr.Row():
|
| 224 |
+
with gr.Column():
|
| 225 |
+
choose_metrics = gr.CheckboxGroup(choices=Metric.all_metrics,
|
| 226 |
+
value=metric_names,
|
| 227 |
+
label="Choose metrics you are interested")
|
| 228 |
+
with gr.Column():
|
| 229 |
+
load_button = gr.Button("Load images...")
|
| 230 |
+
load_random_button = gr.Button("Load randomly...")
|
| 231 |
+
save_button = gr.Button("Save images...")
|
| 232 |
+
|
| 233 |
+
### Event listeners
|
| 234 |
+
choose_dataset.change(fn=get_dataset_on_DEVICE_STR,
|
| 235 |
+
inputs=choose_dataset,
|
| 236 |
+
outputs=dataset_placeholder)
|
| 237 |
+
choose_physics.change(fn=get_physics_generator_on_DEVICE_STR,
|
| 238 |
+
inputs=choose_physics,
|
| 239 |
+
outputs=physics_placeholder)
|
| 240 |
+
update_button.click(fn=physics.update_and_display_params, inputs=[key_selector, value_text], outputs=physics_params)
|
| 241 |
+
load_model_a.click(fn=get_model,
|
| 242 |
+
inputs=[choose_model_a, path_a_str],
|
| 243 |
+
outputs=model_a_placeholder)
|
| 244 |
+
load_model_b.click(fn=get_model,
|
| 245 |
+
inputs=[choose_model_b, path_b_str],
|
| 246 |
+
outputs=model_b_placeholder)
|
| 247 |
+
choose_metrics.change(fn=get_list_metrics_on_DEVICE_STR,
|
| 248 |
+
inputs=choose_metrics,
|
| 249 |
+
outputs=metrics_placeholder)
|
| 250 |
+
load_button.click(fn=generate_imgs,
|
| 251 |
+
inputs=[dataset_placeholder,
|
| 252 |
+
idx_slider,
|
| 253 |
+
model_a_placeholder,
|
| 254 |
+
model_b_placeholder,
|
| 255 |
+
physics_placeholder,
|
| 256 |
+
use_generator_button,
|
| 257 |
+
metrics_placeholder],
|
| 258 |
+
outputs=[clean, y_image, model_a_out, model_b_out, physics_params, y_metrics, out_a_metric, out_b_metric])
|
| 259 |
+
load_random_button.click(fn=update_random_idx_and_generate_imgs,
|
| 260 |
+
inputs=[dataset_placeholder,
|
| 261 |
+
model_a_placeholder,
|
| 262 |
+
model_b_placeholder,
|
| 263 |
+
physics_placeholder,
|
| 264 |
+
use_generator_button,
|
| 265 |
+
metrics_placeholder],
|
| 266 |
+
outputs=[idx_slider, clean, y_image, model_a_out, model_b_out, physics_params, y_metrics, out_a_metric, out_b_metric])
|
| 267 |
+
|
| 268 |
+
if __name__ == "__main__":
|
| 269 |
+
interface.launch()
|
datasets.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
from typing import Callable, Optional
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch.utils.data import Dataset
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class Preprocessed_fastMRI(torch.utils.data.Dataset):
|
| 10 |
+
"""FastMRI from preprocessed data for faster lading."""
|
| 11 |
+
|
| 12 |
+
def __init__(
|
| 13 |
+
self,
|
| 14 |
+
root: str,
|
| 15 |
+
transform: Optional[Callable] = None,
|
| 16 |
+
preprocess: bool = False,
|
| 17 |
+
) -> None:
|
| 18 |
+
self.root = root
|
| 19 |
+
self.transform = transform
|
| 20 |
+
self.preprocess = preprocess
|
| 21 |
+
|
| 22 |
+
# should contain all the information to load a data sample from the storage
|
| 23 |
+
self.sample_identifiers = []
|
| 24 |
+
|
| 25 |
+
# append all filenames in self.root ending with .pt
|
| 26 |
+
for root, _, files in os.walk(self.root):
|
| 27 |
+
for file in files:
|
| 28 |
+
if file.endswith(".pt"):
|
| 29 |
+
self.sample_identifiers.append(file)
|
| 30 |
+
|
| 31 |
+
def __len__(self) -> int:
|
| 32 |
+
return len(self.sample_identifiers)
|
| 33 |
+
|
| 34 |
+
def __getitem__(self, idx: int):
|
| 35 |
+
fname = self.sample_identifiers[idx]
|
| 36 |
+
|
| 37 |
+
tensor = torch.load(os.path.join(self.root, fname), weights_only=True)
|
| 38 |
+
img = tensor['data'].float()
|
| 39 |
+
|
| 40 |
+
if self.transform is not None:
|
| 41 |
+
img = self.transform(img)
|
| 42 |
+
|
| 43 |
+
if not self.preprocess:
|
| 44 |
+
return img
|
| 45 |
+
|
| 46 |
+
else:
|
| 47 |
+
# remove extension and prefix from filename
|
| 48 |
+
fname = Path(fname).stem
|
| 49 |
+
return img, fname
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class Preprocessed_LIDCIDRI(torch.utils.data.Dataset):
|
| 53 |
+
"""FastMRI from preprocessed data for faster lading."""
|
| 54 |
+
|
| 55 |
+
def __init__(
|
| 56 |
+
self,
|
| 57 |
+
root: str,
|
| 58 |
+
transform: Optional[Callable] = None,
|
| 59 |
+
) -> None:
|
| 60 |
+
self.root = root
|
| 61 |
+
self.transform = transform
|
| 62 |
+
|
| 63 |
+
# should contain all the information to load a data sample from the storage
|
| 64 |
+
self.sample_identifiers = []
|
| 65 |
+
|
| 66 |
+
# append all filenames in self.root ending with .pt
|
| 67 |
+
for root, _, files in os.walk(self.root):
|
| 68 |
+
for file in files:
|
| 69 |
+
if file.endswith(".pt"):
|
| 70 |
+
self.sample_identifiers.append(file)
|
| 71 |
+
|
| 72 |
+
def __len__(self) -> int:
|
| 73 |
+
return len(self.sample_identifiers)
|
| 74 |
+
|
| 75 |
+
def __getitem__(self, idx: int):
|
| 76 |
+
fname = self.sample_identifiers[idx]
|
| 77 |
+
|
| 78 |
+
tensor = torch.load(os.path.join(self.root, fname), weights_only=True)
|
| 79 |
+
img = tensor['data'].float()
|
| 80 |
+
|
| 81 |
+
if self.transform is not None:
|
| 82 |
+
img = self.transform(img)
|
| 83 |
+
|
| 84 |
+
img = img.unsqueeze(0) # add channel dim
|
evals.py
ADDED
|
@@ -0,0 +1,564 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, List
|
| 2 |
+
|
| 3 |
+
import deepinv as dinv
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
from deepinv.physics.generator import MotionBlurGenerator, SigmaGenerator
|
| 7 |
+
from torchvision import transforms
|
| 8 |
+
|
| 9 |
+
from datasets import Preprocessed_fastMRI, Preprocessed_LIDCIDRI
|
| 10 |
+
from utils import get_model
|
| 11 |
+
|
| 12 |
+
DEFAULT_MODEL_PARAMS = {
|
| 13 |
+
"in_channels": [1, 2, 3],
|
| 14 |
+
"grayscale": False,
|
| 15 |
+
"conv_type": "base",
|
| 16 |
+
"pool_type": "base",
|
| 17 |
+
"layer_scale_init_value": 1e-6,
|
| 18 |
+
"init_type": "ortho",
|
| 19 |
+
"gain_init_conv": 1.0,
|
| 20 |
+
"gain_init_linear": 1.0,
|
| 21 |
+
"drop_prob": 0.0,
|
| 22 |
+
"replk": False,
|
| 23 |
+
"mult_fact": 4,
|
| 24 |
+
"antialias": "gaussian",
|
| 25 |
+
"nc_base": 64,
|
| 26 |
+
"cond_type": "base",
|
| 27 |
+
"blind": False,
|
| 28 |
+
"pretrained_pth": None,
|
| 29 |
+
"N": 2,
|
| 30 |
+
"c_mult": 2,
|
| 31 |
+
"depth_encoding": 2,
|
| 32 |
+
"relu_in_encoding": False,
|
| 33 |
+
"skip_in_encoding": True
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class PhysicsWithGenerator(torch.nn.Module):
|
| 38 |
+
"""Interface between Physics, Generator and Gradio."""
|
| 39 |
+
all_physics = ["MotionBlur_easy", "MotionBlur_medium", "MotionBlur_hard", "GaussianBlur",
|
| 40 |
+
"MRI", "CT"]
|
| 41 |
+
|
| 42 |
+
def __init__(self, physics_name: str, device_str: str = "cpu") -> None:
|
| 43 |
+
super().__init__()
|
| 44 |
+
|
| 45 |
+
self.name = physics_name
|
| 46 |
+
if self.name not in self.all_physics:
|
| 47 |
+
raise ValueError(f"{self.name} is unavailable.")
|
| 48 |
+
|
| 49 |
+
self.sigma_generator = SigmaGenerator(sigma_min=0.001, sigma_max=0.2, device=device_str)
|
| 50 |
+
if self.name == "MotionBlur_easy":
|
| 51 |
+
psf_size = 31
|
| 52 |
+
self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=.01), padding="valid",
|
| 53 |
+
device=device_str)
|
| 54 |
+
self.physics_generator = MotionBlurGenerator((psf_size, psf_size), l=0.1, sigma=0.1, device=device_str) + SigmaGenerator(sigma_min=0.01, sigma_max=0.01, device=device_str)
|
| 55 |
+
self.generator = self.physics_generator + SigmaGenerator(sigma_min=0.01, sigma_max=0.01, device=device_str)
|
| 56 |
+
self.saved_params = {"updatable_params": {"sigma": 0.05},
|
| 57 |
+
"updatable_params_converter": {"sigma": float},
|
| 58 |
+
"fixed_params": {"noise_sigma_min": 0.01, "noise_sigma_max": 0.01,
|
| 59 |
+
"psf_size": 31, "motion_gen_l": 0.1, "motion_gen_s": 0.1}}
|
| 60 |
+
elif self.name == "MotionBlur_medium":
|
| 61 |
+
psf_size = 31
|
| 62 |
+
self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=.05), padding="valid",
|
| 63 |
+
device=device_str)
|
| 64 |
+
self.physics_generator = MotionBlurGenerator((psf_size, psf_size), l=0.6, sigma=0.5, device=device_str) + SigmaGenerator(sigma_min=0.05, sigma_max=0.05, device=device_str)
|
| 65 |
+
self.generator = self.physics_generator + SigmaGenerator(sigma_min=0.05, sigma_max=0.05, device=device_str)
|
| 66 |
+
self.saved_params = {"updatable_params": {"sigma": 0.05},
|
| 67 |
+
"updatable_params_converter": {"sigma": float},
|
| 68 |
+
"fixed_params": {"noise_sigma_min": 0.05, "noise_sigma_max": 0.05,
|
| 69 |
+
"psf_size": 31, "motion_gen_l": 0.6, "motion_gen_s": 0.5}}
|
| 70 |
+
elif self.name == "MotionBlur_hard":
|
| 71 |
+
psf_size = 31
|
| 72 |
+
self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=.1), padding="valid",
|
| 73 |
+
device=device_str)
|
| 74 |
+
self.physics_generator = MotionBlurGenerator((psf_size, psf_size), l=1.2, sigma=1.0, device=device_str) + SigmaGenerator(sigma_min=0.1, sigma_max=0.1, device=device_str)
|
| 75 |
+
self.generator = self.physics_generator + SigmaGenerator(sigma_min=0.1, sigma_max=0.1, device=device_str)
|
| 76 |
+
self.saved_params = {"updatable_params": {"sigma": 0.05},
|
| 77 |
+
"updatable_params_converter": {"sigma": float},
|
| 78 |
+
"fixed_params": {"noise_sigma_min": 0.1, "noise_sigma_max": 0.1,
|
| 79 |
+
"psf_size": 31, "motion_gen_l": 1.2, "motion_gen_s": 1.0}}
|
| 80 |
+
elif self.name == "GaussianBlur":
|
| 81 |
+
psf_size = 31
|
| 82 |
+
self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=0.05), padding="valid",
|
| 83 |
+
device=device_str)
|
| 84 |
+
self.physics_generator = GaussianBlurGenerator(psf_size=(psf_size, psf_size), num_channels=1,
|
| 85 |
+
device=device_str)
|
| 86 |
+
self.generator = self.physics_generator + self.sigma_generator
|
| 87 |
+
self.saved_params = {"updatable_params": {"sigma": 0.05},
|
| 88 |
+
"updatable_params_converter": {"sigma": float},
|
| 89 |
+
"fixed_params": {"noise_sigma_min": 0.001, "noise_sigma_max": 0.2,
|
| 90 |
+
"psf_size": 31, "num_channels": 1}}
|
| 91 |
+
elif self.name == "MRI":
|
| 92 |
+
self.physics = dinv.physics.MRI(img_size=(640, 320), noise_model=dinv.physics.GaussianNoise(sigma=.01),
|
| 93 |
+
device=device_str)
|
| 94 |
+
self.physics_generator = dinv.physics.generator.RandomMaskGenerator((2, 640, 320), acceleration_factor=4)
|
| 95 |
+
self.generator = self.physics_generator # + self.sigma_generator
|
| 96 |
+
self.saved_params = {"updatable_params": {"sigma": 0.05},
|
| 97 |
+
"updatable_params_converter": {"sigma": float},
|
| 98 |
+
"fixed_params": {"noise_sigma_min": 0.001, "noise_sigma_max": 0.2,
|
| 99 |
+
"acceleration_factor": 4}}
|
| 100 |
+
elif self.name == "CT":
|
| 101 |
+
acceleration_factor = 10
|
| 102 |
+
img_h = 480
|
| 103 |
+
angles = int(img_h / acceleration_factor)
|
| 104 |
+
# angles = torch.linspace(0, 180, steps=10)
|
| 105 |
+
self.physics = dinv.physics.Tomography(
|
| 106 |
+
img_width=img_h,
|
| 107 |
+
angles=angles,
|
| 108 |
+
circle=False,
|
| 109 |
+
normalize=True,
|
| 110 |
+
device=device_str,
|
| 111 |
+
noise_model=dinv.physics.GaussianNoise(sigma=1e-4),
|
| 112 |
+
max_iter=10,
|
| 113 |
+
)
|
| 114 |
+
self.physics_generator = None
|
| 115 |
+
self.generator = self.sigma_generator
|
| 116 |
+
self.saved_params = {"updatable_params": {"sigma": 0.1},
|
| 117 |
+
"updatable_params_converter": {"sigma": float},
|
| 118 |
+
"fixed_params": {"noise_sigma_min": 0.001, "noise_sigma_max": 0.,
|
| 119 |
+
"angles": angles, "max_iter": 10}}
|
| 120 |
+
|
| 121 |
+
def display_saved_params(self) -> str:
|
| 122 |
+
"""Printable version of saved_params."""
|
| 123 |
+
updatable_params_str = "Updatable parameters:\n"
|
| 124 |
+
for param_name, param_value in self.saved_params["updatable_params"].items():
|
| 125 |
+
updatable_params_str += f"\t\t{param_name} = {param_value}" + "\n"
|
| 126 |
+
|
| 127 |
+
fixed_params_str = "Fixed parameters:\n"
|
| 128 |
+
for param_name, param_value in self.saved_params["fixed_params"].items():
|
| 129 |
+
fixed_params_str += f"\t\t{param_name} = {param_value}" + "\n"
|
| 130 |
+
|
| 131 |
+
return updatable_params_str + fixed_params_str
|
| 132 |
+
|
| 133 |
+
def _update_save_params(self, key: str, value: Any) -> None:
|
| 134 |
+
"""Update value of an existing key in save_params."""
|
| 135 |
+
if key in list(self.saved_params["updatable_params"].keys()):
|
| 136 |
+
if type(value) == str: # it may be only a str representation
|
| 137 |
+
# type: str -> ???
|
| 138 |
+
value = self.saved_params["updatable_params_converter"][key](value)
|
| 139 |
+
elif isinstance(value, torch.Tensor):
|
| 140 |
+
value = value.item() # type: torch.Tensor -> float
|
| 141 |
+
value = float(f"{value:.4f}") # keeps only 4 significant digits
|
| 142 |
+
self.saved_params["updatable_params"][key] = value
|
| 143 |
+
|
| 144 |
+
def update_and_display_params(self, key, value) -> str:
|
| 145 |
+
"""_update_save_params + update physics with saved_params + display_saved_params"""
|
| 146 |
+
self._update_save_params(key, value)
|
| 147 |
+
|
| 148 |
+
if self.name == "Denoising":
|
| 149 |
+
self.physics.noise_model.update_parameters(**self.saved_params["updatable_params"])
|
| 150 |
+
else:
|
| 151 |
+
self.physics.update_parameters(**self.saved_params["updatable_params"])
|
| 152 |
+
|
| 153 |
+
return self.display_saved_params()
|
| 154 |
+
|
| 155 |
+
def update_saved_params_and_physics(self, **kwargs) -> None:
|
| 156 |
+
"""Update save_params and update physics."""
|
| 157 |
+
for key, value in kwargs.items():
|
| 158 |
+
self._update_save_params(key, value)
|
| 159 |
+
|
| 160 |
+
self.physics.update(**kwargs)
|
| 161 |
+
|
| 162 |
+
def forward(self, x: torch.Tensor, use_gen: bool) -> torch.Tensor:
|
| 163 |
+
if self.name in ["MotionBlur_easy", "MotionBlur_medium", "MotionBlur_hard", "GaussianBlur"] and not hasattr(self.physics, "filter"):
|
| 164 |
+
use_gen = True
|
| 165 |
+
elif self.name in ["MRI"] and not hasattr(self.physics, "mask"):
|
| 166 |
+
use_gen = True
|
| 167 |
+
|
| 168 |
+
if use_gen:
|
| 169 |
+
kwargs = self.generator.step(batch_size=x.shape[0]) # generate a set of params for each sample
|
| 170 |
+
self.update_saved_params_and_physics(**kwargs)
|
| 171 |
+
|
| 172 |
+
return self.physics(x)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
class EvalModel(torch.nn.Module):
|
| 176 |
+
"""Eval model.
|
| 177 |
+
|
| 178 |
+
Is there a difference with BaselineModel ?
|
| 179 |
+
-> BaselineModel should be models that are already trained and will have fixed weights.
|
| 180 |
+
-> Eval model will change depending on differents checkpoints.
|
| 181 |
+
"""
|
| 182 |
+
all_models = ["unext_emb_physics_config_C"]
|
| 183 |
+
|
| 184 |
+
def __init__(self, model_name: str, ckpt_pth: str = "", device_str: str = "cpu") -> None:
|
| 185 |
+
"""Load the model we want to evaluate."""
|
| 186 |
+
super().__init__()
|
| 187 |
+
self.base_name = model_name
|
| 188 |
+
self.ckpt_pth = ckpt_pth
|
| 189 |
+
self.name = self.base_name
|
| 190 |
+
if self.base_name not in self.all_models:
|
| 191 |
+
raise ValueError(f"{self.base_name} is unavailable.")
|
| 192 |
+
if self.base_name == "unext_emb_physics_config_C":
|
| 193 |
+
if self.ckpt_pth == "":
|
| 194 |
+
self.ckpt_pth = "ckpt/ram_ckp_10.pth.tar"
|
| 195 |
+
self.model = get_model(model_name=self.base_name,
|
| 196 |
+
device='cpu',
|
| 197 |
+
**DEFAULT_MODEL_PARAMS)
|
| 198 |
+
|
| 199 |
+
# load model checkpoint
|
| 200 |
+
state_dict = torch.load(self.ckpt_pth, map_location=lambda storage, loc: storage)[
|
| 201 |
+
'state_dict'] # load on cpu
|
| 202 |
+
self.model.load_state_dict(state_dict)
|
| 203 |
+
self.model.to(device_str)
|
| 204 |
+
self.model.eval()
|
| 205 |
+
|
| 206 |
+
# add epoch in the model name
|
| 207 |
+
epoch = torch.load(self.ckpt_pth, map_location=lambda storage, loc: storage)['epoch']
|
| 208 |
+
self.name = self.name + f"+{epoch}"
|
| 209 |
+
|
| 210 |
+
def forward(self, y: torch.Tensor, physics: torch.nn.Module) -> torch.Tensor:
|
| 211 |
+
return self.model(y, physics=physics)
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
class BaselineModel(torch.nn.Module):
|
| 215 |
+
"""Baseline model.
|
| 216 |
+
|
| 217 |
+
Is there a difference with EvalModel ?
|
| 218 |
+
-> BaselineModel should be models that are already trained and will have fixed weights.
|
| 219 |
+
-> Eval model will change depending on differents checkpoints.
|
| 220 |
+
"""
|
| 221 |
+
all_baselines = ["DRUNET", "PnP-PGD-DRUNET", "SWINIRx2", "SWINIRx4", "DPIR",
|
| 222 |
+
"DPIR_MRI", "DPIR_CT", "PDNET"]
|
| 223 |
+
|
| 224 |
+
def __init__(self, model_name: str, device_str: str = "cpu") -> None:
|
| 225 |
+
super().__init__()
|
| 226 |
+
self.base_name = model_name
|
| 227 |
+
self.ckpt_pth = ""
|
| 228 |
+
self.name = self.base_name
|
| 229 |
+
if self.name not in self.all_baselines:
|
| 230 |
+
raise ValueError(f"{self.name} is unavailable.")
|
| 231 |
+
elif self.name == "DRUNET":
|
| 232 |
+
n_channels = 3
|
| 233 |
+
ckpt_pth = "ckpt/drunet_deepinv_color_finetune_22k.pth"
|
| 234 |
+
self.model = dinv.models.DRUNet(in_channels=n_channels,
|
| 235 |
+
out_channels=n_channels,
|
| 236 |
+
device=device_str,
|
| 237 |
+
pretrained=ckpt_pth)
|
| 238 |
+
self.model.eval() # Set the model to evaluation mode
|
| 239 |
+
elif self.name == 'PDNET':
|
| 240 |
+
ckpt_pth = "ckpt/pdnet.pth.tar"
|
| 241 |
+
self.model = get_model(model_name='pdnet',
|
| 242 |
+
device=device_str)
|
| 243 |
+
self.model.eval()
|
| 244 |
+
self.model.load_state_dict(torch.load(ckpt_pth, map_location=lambda storage, loc: storage)['state_dict'])
|
| 245 |
+
elif self.name == "SWINIRx2":
|
| 246 |
+
n_channels = 3
|
| 247 |
+
scale = 2
|
| 248 |
+
ckpt_pth = "ckpt/001_classicalSR_DF2K_s64w8_SwinIR-M_x2.pth"
|
| 249 |
+
upsampler = 'nearest+conv' if 'realSR' in ckpt_pth else 'pixelshuffle'
|
| 250 |
+
self.model = dinv.models.SwinIR(upscale=scale, in_chans=n_channels, img_size=64, window_size=8,
|
| 251 |
+
img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180,
|
| 252 |
+
num_heads=[6, 6, 6, 6, 6, 6],
|
| 253 |
+
mlp_ratio=2, upsampler=upsampler, resi_connection='1conv',
|
| 254 |
+
pretrained=ckpt_pth)
|
| 255 |
+
self.model.to(device_str)
|
| 256 |
+
self.model.eval() # Set the model to evaluation mode
|
| 257 |
+
elif self.name == "SWINIRx4":
|
| 258 |
+
n_channels = 3
|
| 259 |
+
scale = 4
|
| 260 |
+
ckpt_pth = "ckpt/001_classicalSR_DF2K_s64w8_SwinIR-M_x4.pth"
|
| 261 |
+
upsampler = 'nearest+conv' if 'realSR' in ckpt_pth else 'pixelshuffle'
|
| 262 |
+
self.model = dinv.models.SwinIR(upscale=scale, in_chans=n_channels, img_size=64, window_size=8,
|
| 263 |
+
img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180,
|
| 264 |
+
num_heads=[6, 6, 6, 6, 6, 6],
|
| 265 |
+
mlp_ratio=2, upsampler=upsampler, resi_connection='1conv',
|
| 266 |
+
pretrained=ckpt_pth)
|
| 267 |
+
self.model.to(device_str)
|
| 268 |
+
self.model.eval() # Set the model to evaluation mode
|
| 269 |
+
|
| 270 |
+
elif self.name == "PnP-PGD-DRUNET":
|
| 271 |
+
n_channels = 3
|
| 272 |
+
ckpt_pth = "ckpt/drunet_deepinv_color_finetune_22k.pth"
|
| 273 |
+
drunet = dinv.models.DRUNet(in_channels=n_channels,
|
| 274 |
+
out_channels=n_channels,
|
| 275 |
+
device=device_str,
|
| 276 |
+
pretrained=ckpt_pth)
|
| 277 |
+
drunet.eval() # Set the model to evaluation mode
|
| 278 |
+
self.model = dinv.optim.optim_builder(iteration="PGD",
|
| 279 |
+
prior=dinv.optim.PnP(drunet).to(device_str),
|
| 280 |
+
data_fidelity=dinv.optim.L2(),
|
| 281 |
+
max_iter=20,
|
| 282 |
+
params_algo={'stepsize': 1., 'g_param': .05})
|
| 283 |
+
elif self.name == "DPIR":
|
| 284 |
+
n_channels = 3
|
| 285 |
+
ckpt_pth = "ckpt/drunet_deepinv_color_finetune_22k.pth"
|
| 286 |
+
drunet = dinv.models.DRUNet(in_channels=n_channels,
|
| 287 |
+
out_channels=n_channels,
|
| 288 |
+
device=device_str,
|
| 289 |
+
pretrained=ckpt_pth)
|
| 290 |
+
drunet.eval() # Set the model to evaluation mode
|
| 291 |
+
|
| 292 |
+
# Specify the denoising prior
|
| 293 |
+
self.prior = dinv.optim.prior.PnP(denoiser=drunet)
|
| 294 |
+
elif self.name == "DPIR_MRI":
|
| 295 |
+
class ComplexDenoiser(torch.nn.Module):
|
| 296 |
+
def __init__(self, denoiser):
|
| 297 |
+
super().__init__()
|
| 298 |
+
self.denoiser = denoiser
|
| 299 |
+
|
| 300 |
+
def forward(self, x, sigma):
|
| 301 |
+
noisy_batch = torch.cat((x[:, 0:1, ...], x[:, 1:2, ...]), 0)
|
| 302 |
+
input_min = noisy_batch.min()
|
| 303 |
+
denoised_batch = self.denoiser(noisy_batch - input_min, sigma)
|
| 304 |
+
denoised_batch = denoised_batch + input_min
|
| 305 |
+
denoised = torch.cat((denoised_batch[0:1, ...], denoised_batch[1:2, ...]), 1)
|
| 306 |
+
return denoised
|
| 307 |
+
|
| 308 |
+
# Load PnP denoiser backbone
|
| 309 |
+
n_channels = 1
|
| 310 |
+
ckpt_pth = "ckpt/drunet_gray.pth"
|
| 311 |
+
drunet = dinv.models.DRUNet(in_channels=n_channels, out_channels=n_channels, device=device_str,
|
| 312 |
+
pretrained=ckpt_pth)
|
| 313 |
+
complex_drunet = ComplexDenoiser(drunet)
|
| 314 |
+
complex_drunet.eval()
|
| 315 |
+
|
| 316 |
+
# Specify the denoising prior
|
| 317 |
+
self.prior = dinv.optim.prior.PnP(denoiser=complex_drunet)
|
| 318 |
+
elif self.name == "DPIR_CT":
|
| 319 |
+
class CTDenoiser(torch.nn.Module):
|
| 320 |
+
def __init__(self, denoiser):
|
| 321 |
+
super().__init__()
|
| 322 |
+
self.denoiser = denoiser
|
| 323 |
+
|
| 324 |
+
def forward(self, x, sigma):
|
| 325 |
+
x = x - x.min()
|
| 326 |
+
denoised = self.denoiser(x, sigma)
|
| 327 |
+
denoised = denoised + x.min()
|
| 328 |
+
return denoised
|
| 329 |
+
|
| 330 |
+
# Load PnP denoiser backbone
|
| 331 |
+
n_channels = 1
|
| 332 |
+
ckpt_pth = "ckpt/drunet_gray.pth"
|
| 333 |
+
drunet = dinv.models.DRUNet(in_channels=n_channels, out_channels=n_channels, device=device_str,
|
| 334 |
+
pretrained=ckpt_pth)
|
| 335 |
+
ct_drunet = CTDenoiser(drunet)
|
| 336 |
+
ct_drunet.eval()
|
| 337 |
+
|
| 338 |
+
# Specify the denoising prior
|
| 339 |
+
self.prior = dinv.optim.prior.PnP(denoiser=ct_drunet)
|
| 340 |
+
|
| 341 |
+
def circular_roll(self, tensor, p_h, p_w):
|
| 342 |
+
return tensor.roll(shifts=(p_h, p_w), dims=(-2, -1))
|
| 343 |
+
|
| 344 |
+
def get_DPIR_params(self, noise_level_img, max_iter=8):
|
| 345 |
+
r"""
|
| 346 |
+
Default parameters for the DPIR Plug-and-Play algorithm.
|
| 347 |
+
|
| 348 |
+
:param float noise_level_img: Noise level of the input image.
|
| 349 |
+
:return: tuple(list with denoiser noise level per iteration, list with stepsize per iteration, iterations).
|
| 350 |
+
"""
|
| 351 |
+
max_iter = 8
|
| 352 |
+
s1 = 49.0 / 255.0
|
| 353 |
+
s2 = max(noise_level_img, 0.01)
|
| 354 |
+
sigma_denoiser = np.logspace(np.log10(s1), np.log10(s2), max_iter).astype(
|
| 355 |
+
np.float32
|
| 356 |
+
)
|
| 357 |
+
stepsize = (sigma_denoiser / max(0.01, noise_level_img)) ** 2
|
| 358 |
+
lamb = 1 / 0.23
|
| 359 |
+
return list(sigma_denoiser), list(lamb * stepsize)
|
| 360 |
+
|
| 361 |
+
def get_DPIR_MRI_params(self, noise_level_img: float, max_iter: int = 8):
|
| 362 |
+
r"""
|
| 363 |
+
Default parameters for the DPIR Plug-and-Play algorithm.
|
| 364 |
+
|
| 365 |
+
:param float noise_level_img: Noise level of the input image.
|
| 366 |
+
"""
|
| 367 |
+
s1 = 49.0 / 255.0
|
| 368 |
+
s2 = noise_level_img
|
| 369 |
+
sigma_denoiser = np.logspace(np.log10(s1), np.log10(s2), max_iter).astype(
|
| 370 |
+
np.float32
|
| 371 |
+
)
|
| 372 |
+
stepsize = (sigma_denoiser / max(0.01, noise_level_img)) ** 2
|
| 373 |
+
lamb = 1.
|
| 374 |
+
return lamb, list(sigma_denoiser), list(stepsize), max_iter
|
| 375 |
+
|
| 376 |
+
def get_DPIR_CT_params(self, noise_level_img: float, max_iter: int = 8, lip_cons: float = 1.0):
|
| 377 |
+
r"""
|
| 378 |
+
Default parameters for the DPIR Plug-and-Play algorithm.
|
| 379 |
+
|
| 380 |
+
:param float noise_level_img: Noise level of the input image.
|
| 381 |
+
"""
|
| 382 |
+
s1 = 49.0 / 255.0 * lip_cons
|
| 383 |
+
s2 = noise_level_img
|
| 384 |
+
sigma_denoiser = np.logspace(np.log10(s1), np.log10(s2), max_iter).astype(
|
| 385 |
+
np.float32
|
| 386 |
+
)
|
| 387 |
+
stepsize = (sigma_denoiser / max(0.01, noise_level_img)) ** 2 #
|
| 388 |
+
lamb = 1.
|
| 389 |
+
return lamb, list(sigma_denoiser), list(stepsize), max_iter
|
| 390 |
+
|
| 391 |
+
def forward(self, y: torch.Tensor, physics: torch.nn.Module) -> torch.Tensor:
|
| 392 |
+
if self.name == "DRUNET":
|
| 393 |
+
return self.model(y, sigma=physics.noise_model.sigma)
|
| 394 |
+
elif self.name == "PnP-PGD-DRUNET":
|
| 395 |
+
return self.model(y, physics=physics)
|
| 396 |
+
elif self.name == "DPIR":
|
| 397 |
+
# Set the DPIR algorithm parameters
|
| 398 |
+
sigma_float = physics.noise_model.sigma.item() # sigma should be a single value
|
| 399 |
+
max_iter = 8
|
| 400 |
+
|
| 401 |
+
sigma_denoiser, stepsize = self.get_DPIR_params(sigma_float, max_iter=max_iter)
|
| 402 |
+
params_algo = {"stepsize": stepsize, "g_param": sigma_denoiser}
|
| 403 |
+
early_stop = False # Do not stop algorithm with convergence criteria
|
| 404 |
+
|
| 405 |
+
# instantiate DPIR
|
| 406 |
+
model = dinv.optim.optim_builder(
|
| 407 |
+
iteration="HQS",
|
| 408 |
+
prior=self.prior,
|
| 409 |
+
data_fidelity=dinv.optim.data_fidelity.L2(),
|
| 410 |
+
early_stop=early_stop,
|
| 411 |
+
max_iter=max_iter,
|
| 412 |
+
verbose=True,
|
| 413 |
+
params_algo=params_algo,
|
| 414 |
+
)
|
| 415 |
+
return model(y, physics=physics)
|
| 416 |
+
elif self.name == "DPIR_MRI":
|
| 417 |
+
sigma_float = max(physics.noise_model.sigma.item(), 0.015) # sigma should be a single value
|
| 418 |
+
lamb, sigma_denoiser, stepsize, max_iter = self.get_DPIR_MRI_params(sigma_float, max_iter=16)
|
| 419 |
+
stepsize = [stepsize[0]] * max_iter
|
| 420 |
+
params_algo = {"stepsize": stepsize, "g_param": sigma_denoiser, "lambda": lamb}
|
| 421 |
+
early_stop = False # Do not stop algorithm with convergence criteria
|
| 422 |
+
|
| 423 |
+
# Instantiate the algorithm class to solve the IP
|
| 424 |
+
model = dinv.optim.optim_builder(
|
| 425 |
+
iteration="HQS",
|
| 426 |
+
prior=self.prior,
|
| 427 |
+
data_fidelity=dinv.optim.data_fidelity.L2(),
|
| 428 |
+
early_stop=early_stop,
|
| 429 |
+
max_iter=max_iter,
|
| 430 |
+
verbose=True,
|
| 431 |
+
params_algo=params_algo,
|
| 432 |
+
)
|
| 433 |
+
return model(y, physics=physics)
|
| 434 |
+
elif self.name == "DPIR_CT":
|
| 435 |
+
# Set the DPIR algorithm parameters
|
| 436 |
+
sigma_float = physics.noise_model.sigma.item() # sigma should be a single value
|
| 437 |
+
lip_const = physics.compute_norm(physics.A_adjoint(y))
|
| 438 |
+
lamb, sigma_denoiser, stepsize, max_iter = self.get_DPIR_CT_params(sigma_float, max_iter=8,
|
| 439 |
+
lip_cons=lip_const.item())
|
| 440 |
+
params_algo = {"stepsize": stepsize, "g_param": sigma_denoiser, "lambda": lamb}
|
| 441 |
+
early_stop = False # Do not stop algorithm with convergence criteria
|
| 442 |
+
|
| 443 |
+
def custom_init(y, physic_op):
|
| 444 |
+
x_init = physic_op.prox_l2(physic_op.A_adjoint(y), y, gamma=1e4)
|
| 445 |
+
return {"est": (x_init, x_init)}
|
| 446 |
+
|
| 447 |
+
# Instantiate the algorithm class to solve the IP
|
| 448 |
+
algo = dinv.optim.optim_builder(
|
| 449 |
+
iteration="HQS",
|
| 450 |
+
prior=self.prior,
|
| 451 |
+
data_fidelity=dinv.optim.data_fidelity.L2(),
|
| 452 |
+
early_stop=early_stop,
|
| 453 |
+
max_iter=max_iter,
|
| 454 |
+
verbose=True,
|
| 455 |
+
params_algo=params_algo,
|
| 456 |
+
custom_init=custom_init
|
| 457 |
+
)
|
| 458 |
+
return algo(y, physics=physics)
|
| 459 |
+
elif self.name == 'SWINIRx4':
|
| 460 |
+
window_size = 8
|
| 461 |
+
scale = 4
|
| 462 |
+
_, _, h_old, w_old = y.size()
|
| 463 |
+
h_pad = (h_old // window_size + 1) * window_size - h_old
|
| 464 |
+
w_pad = (w_old // window_size + 1) * window_size - w_old
|
| 465 |
+
img_lq = torch.cat([y, torch.flip(y, [2])], 2)[:, :, :h_old + h_pad, :]
|
| 466 |
+
img_lq = torch.cat([img_lq, torch.flip(img_lq, [3])], 3)[:, :, :, :w_old + w_pad]
|
| 467 |
+
output = self.model(img_lq)
|
| 468 |
+
output = output[..., :h_old * scale, :w_old * scale]
|
| 469 |
+
output = self.circular_roll(output, -2, -2)
|
| 470 |
+
# check shape of adjoint
|
| 471 |
+
x_adj = physics.A_adjoint(y)
|
| 472 |
+
output = output[..., :x_adj.size(-2), :x_adj.size(-1)]
|
| 473 |
+
return output
|
| 474 |
+
elif self.name == 'SWINIRx2':
|
| 475 |
+
window_size = 8
|
| 476 |
+
scale = 2
|
| 477 |
+
_, _, h_old, w_old = y.size()
|
| 478 |
+
h_pad = (h_old // window_size + 1) * window_size - h_old
|
| 479 |
+
w_pad = (w_old // window_size + 1) * window_size - w_old
|
| 480 |
+
img_lq = torch.cat([y, torch.flip(y, [2])], 2)[:, :, :h_old + h_pad, :]
|
| 481 |
+
img_lq = torch.cat([img_lq, torch.flip(img_lq, [3])], 3)[:, :, :, :w_old + w_pad]
|
| 482 |
+
output = self.model(img_lq)
|
| 483 |
+
output = output[..., :h_old * scale, :w_old * scale]
|
| 484 |
+
output = self.circular_roll(output, -1, -1)
|
| 485 |
+
# check shape of adjoint
|
| 486 |
+
x_adj = physics.A_adjoint(y)
|
| 487 |
+
output = output[..., :x_adj.size(-2), :x_adj.size(-1)]
|
| 488 |
+
return output
|
| 489 |
+
elif 'UNROLLED_DPIR' in self.name:
|
| 490 |
+
return self.model(y, physics=physics)
|
| 491 |
+
else:
|
| 492 |
+
return self.model(y)
|
| 493 |
+
|
| 494 |
+
|
| 495 |
+
class EvalDataset(torch.utils.data.Dataset):
|
| 496 |
+
"""
|
| 497 |
+
We expect that images are 480x480.
|
| 498 |
+
"""
|
| 499 |
+
all_datasets = ["Natural", "MRI", "CT"]
|
| 500 |
+
|
| 501 |
+
def __init__(self, dataset_name: str, device_str: str = "cpu") -> None:
|
| 502 |
+
self.name = dataset_name
|
| 503 |
+
self.device_str = device_str
|
| 504 |
+
if self.name not in self.all_datasets:
|
| 505 |
+
raise ValueError(f"{self.name} is unavailable.")
|
| 506 |
+
if self.name == 'Natural':
|
| 507 |
+
self.root = 'datasets/LSDIR_samples'
|
| 508 |
+
self.transform = transforms.Compose([transforms.ToTensor()])
|
| 509 |
+
self.dataset = dinv.datasets.LsdirHR(root=self.root,
|
| 510 |
+
download=False,
|
| 511 |
+
transform=self.transform)
|
| 512 |
+
elif self.name == 'MRI':
|
| 513 |
+
self.root = 'datasets/FastMRI_samples'
|
| 514 |
+
self.transform = transforms.CenterCrop((640, 320)) # , pad_if_needed=True)
|
| 515 |
+
self.dataset = Preprocessed_fastMRI(root=self.root,
|
| 516 |
+
transform=self.transform,
|
| 517 |
+
preprocess=False)
|
| 518 |
+
elif self.name == "CT":
|
| 519 |
+
self.root = 'datasets/LIDC_IDRI_samples'
|
| 520 |
+
self.transform = None
|
| 521 |
+
self.dataset = Preprocessed_LIDCIDRI(root=self.root,
|
| 522 |
+
transform=self.transform)
|
| 523 |
+
|
| 524 |
+
def __len__(self) -> int:
|
| 525 |
+
return len(self.dataset)
|
| 526 |
+
|
| 527 |
+
def __getitem__(self, idx: int) -> torch.Tensor:
|
| 528 |
+
return self.dataset[idx].to(self.device_str)
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
class Metric():
|
| 532 |
+
"""Metrics and utilities."""
|
| 533 |
+
all_metrics = ["PSNR", "SSIM", "LPIPS"]
|
| 534 |
+
|
| 535 |
+
def __init__(self, metric_name: str, device_str: str = "cpu") -> None:
|
| 536 |
+
self.name = metric_name
|
| 537 |
+
if self.name not in self.all_metrics:
|
| 538 |
+
raise ValueError(f"{self.name} is unavailable.")
|
| 539 |
+
elif self.name == "PSNR":
|
| 540 |
+
self.metric = dinv.loss.metric.PSNR()
|
| 541 |
+
elif self.name == "SSIM":
|
| 542 |
+
self.metric = dinv.loss.metric.SSIM()
|
| 543 |
+
elif self.name == "LPIPS":
|
| 544 |
+
self.metric = dinv.loss.metric.LPIPS(device=device_str)
|
| 545 |
+
|
| 546 |
+
def __call__(self, x_net: torch.Tensor, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
| 547 |
+
# it may happen that x_net and x do not have the same size, in which case we take the minimum size of both
|
| 548 |
+
if x_net.shape[-1] != x.shape[-1]:
|
| 549 |
+
min_size = min(x_net.shape[-1], x.shape[-1])
|
| 550 |
+
x_net_crop = x_net[..., x_net.shape[-2] // 2 - min_size // 2: x_net.shape[-2] // 2 + min_size // 2,
|
| 551 |
+
x_net.shape[-1] // 2 - min_size // 2: x_net.shape[-1] // 2 + min_size // 2]
|
| 552 |
+
x_crop = x[..., x_net.shape[-2] // 2 - min_size // 2: x_net.shape[-2] // 2 + min_size // 2,
|
| 553 |
+
x_net.shape[-1] // 2 - min_size // 2: x_net.shape[-1] // 2 + min_size // 2]
|
| 554 |
+
else:
|
| 555 |
+
x_net_crop = x_net
|
| 556 |
+
x_crop = x
|
| 557 |
+
return self.metric(x_net_crop, x_crop)
|
| 558 |
+
|
| 559 |
+
@classmethod
|
| 560 |
+
def get_list_metrics(cls, metric_names: List[str], device_str: str = "cpu") -> List["Metric"]:
|
| 561 |
+
l = []
|
| 562 |
+
for metric_name in metric_names:
|
| 563 |
+
l.append(cls(metric_name, device_str=device_str))
|
| 564 |
+
return l
|
img_samples/LSDIR_samples/0001000/0000007_s005.png
ADDED
|
img_samples/LSDIR_samples/0001000/0000030_s003.png
ADDED
|
img_samples/LSDIR_samples/0001000/0000067_s005.png
ADDED
|
img_samples/LSDIR_samples/0001000/0000082_s003.png
ADDED
|
img_samples/LSDIR_samples/0001000/0000110_s002.png
ADDED
|
img_samples/LSDIR_samples/0001000/0000125_s003.png
ADDED
|
img_samples/LSDIR_samples/0001000/0000154_s007.png
ADDED
|
img_samples/LSDIR_samples/0001000/0000247_s007.png
ADDED
|
img_samples/LSDIR_samples/0001000/0000259_s003.png
ADDED
|
img_samples/LSDIR_samples/0001000/0000405_s008.png
ADDED
|
img_samples/LSDIR_samples/0001000/0000578_s002.png
ADDED
|
img_samples/LSDIR_samples/0001000/0000669_s010.png
ADDED
|
img_samples/LSDIR_samples/0001000/0000689_s006.png
ADDED
|
img_samples/LSDIR_samples/0001000/0000715_s011.png
ADDED
|
img_samples/LSDIR_samples/0001000/0000752_s010.png
ADDED
|
img_samples/LSDIR_samples/0001000/0000803_s012.png
ADDED
|
img_samples/LSDIR_samples/0001000/0000825_s012.png
ADDED
|
img_samples/LSDIR_samples/0001000/0000921_s012.png
ADDED
|
img_samples/LSDIR_samples/0001000/0000958_s004.png
ADDED
|
img_samples/LSDIR_samples/0001000/0000994_s021.png
ADDED
|
img_samples/LSDIR_samples/0009000/0008033_s006.png
ADDED
|
img_samples/LSDIR_samples/0009000/0008068_s005.png
ADDED
|
img_samples/LSDIR_samples/0009000/0008115_s004.png
ADDED
|
img_samples/LSDIR_samples/0009000/0008217_s002.png
ADDED
|
img_samples/LSDIR_samples/0009000/0008294_s010.png
ADDED
|
img_samples/LSDIR_samples/0009000/0008315_s053.png
ADDED
|
img_samples/LSDIR_samples/0009000/0008340_s015.png
ADDED
|
img_samples/LSDIR_samples/0009000/0008361_s009.png
ADDED
|
img_samples/LSDIR_samples/0009000/0008386_s007.png
ADDED
|
img_samples/LSDIR_samples/0009000/0008491_s006.png
ADDED
|
img_samples/LSDIR_samples/0009000/0008528_s007.png
ADDED
|
img_samples/LSDIR_samples/0009000/0008571_s007.png
ADDED
|
img_samples/LSDIR_samples/0009000/0008573_s012.png
ADDED
|
img_samples/LSDIR_samples/0009000/0008605_s007.png
ADDED
|
img_samples/LSDIR_samples/0009000/0008611_s002.png
ADDED
|
img_samples/LSDIR_samples/0009000/0008631_s005.png
ADDED
|
img_samples/LSDIR_samples/0009000/0008681_s008.png
ADDED
|
img_samples/LSDIR_samples/0009000/0008703_s013.png
ADDED
|
img_samples/LSDIR_samples/0009000/0008714_s010.png
ADDED
|
img_samples/LSDIR_samples/0009000/0008774_s004.png
ADDED
|
img_samples/LSDIR_samples/0023000/0022020_s005.png
ADDED
|
img_samples/LSDIR_samples/0023000/0022037_s011.png
ADDED
|
img_samples/LSDIR_samples/0023000/0022059_s008.png
ADDED
|
img_samples/LSDIR_samples/0023000/0022135_s002.png
ADDED
|