Spaces:
Sleeping
Sleeping
added inpainting
Browse files- app.py +3 -3
- factories.py +21 -5
app.py
CHANGED
|
@@ -159,7 +159,7 @@ def get_dataset(dataset_name):
|
|
| 159 |
physics_name = 'CT'
|
| 160 |
baseline_name = 'DPIR_CT'
|
| 161 |
else:
|
| 162 |
-
available_physics = ['MotionBlur_medium', 'MotionBlur_hard',
|
| 163 |
'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard']
|
| 164 |
physics_name = 'MotionBlur_hard'
|
| 165 |
baseline_name = 'DPIR'
|
|
@@ -192,11 +192,11 @@ with gr.Blocks(title=title, theme=gr.themes.Glass()) as interface:
|
|
| 192 |
|
| 193 |
### USER-SPECIFIC VARIABLES
|
| 194 |
dataset_placeholder = gr.State(get_dataset_on_DEVICE_STR("Natural"))
|
| 195 |
-
available_physics_placeholder = gr.State(['MotionBlur_medium', 'MotionBlur_hard',
|
| 196 |
'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard'])
|
| 197 |
# Issue giving directly a `torch.nn.module` to `gr.State(...)` since it has __call__ method
|
| 198 |
# Solution: using lambda expression
|
| 199 |
-
physics_placeholder = gr.State(lambda: get_physics_on_DEVICE_STR("
|
| 200 |
model_b_placeholder = gr.State(lambda: get_baseline_model_on_DEVICE_STR("DPIR"))
|
| 201 |
|
| 202 |
print(f"[Render] CUDA max allocated: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
|
|
|
|
| 159 |
physics_name = 'CT'
|
| 160 |
baseline_name = 'DPIR_CT'
|
| 161 |
else:
|
| 162 |
+
available_physics = ['Inpainting', 'MotionBlur_medium', 'MotionBlur_hard',
|
| 163 |
'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard']
|
| 164 |
physics_name = 'MotionBlur_hard'
|
| 165 |
baseline_name = 'DPIR'
|
|
|
|
| 192 |
|
| 193 |
### USER-SPECIFIC VARIABLES
|
| 194 |
dataset_placeholder = gr.State(get_dataset_on_DEVICE_STR("Natural"))
|
| 195 |
+
available_physics_placeholder = gr.State(['Inpainting', 'MotionBlur_medium', 'MotionBlur_hard',
|
| 196 |
'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard'])
|
| 197 |
# Issue giving directly a `torch.nn.module` to `gr.State(...)` since it has __call__ method
|
| 198 |
# Solution: using lambda expression
|
| 199 |
+
physics_placeholder = gr.State(lambda: get_physics_on_DEVICE_STR("MotionBlur_hard"))
|
| 200 |
model_b_placeholder = gr.State(lambda: get_baseline_model_on_DEVICE_STR("DPIR"))
|
| 201 |
|
| 202 |
print(f"[Render] CUDA max allocated: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
|
factories.py
CHANGED
|
@@ -13,7 +13,7 @@ from physics.blur_generator import GaussianBlurGenerator
|
|
| 13 |
|
| 14 |
class PhysicsWithGenerator(torch.nn.Module):
|
| 15 |
"""Interface between Physics, Generator and Gradio."""
|
| 16 |
-
all_physics = ["MotionBlur_medium", "MotionBlur_hard",
|
| 17 |
"GaussianBlur_easy", "GaussianBlur_medium", "GaussianBlur_hard",
|
| 18 |
"MRI", "CT"]
|
| 19 |
|
|
@@ -83,6 +83,22 @@ class PhysicsWithGenerator(torch.nn.Module):
|
|
| 83 |
"updatable_params_converter": {"sigma": float},
|
| 84 |
"fixed_params": {"noise_sigma_min": 0.05, "noise_sigma_max": 0.05,
|
| 85 |
"blur_sigma": 4.0, "psf_size": 31, "num_channels": 1}}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
elif self.name == "MRI":
|
| 87 |
self.physics = dinv.physics.MRI(noise_model=dinv.physics.GaussianNoise(sigma=.01).to(device_str),
|
| 88 |
img_size=(640, 320), device=device_str)
|
|
@@ -101,14 +117,14 @@ class PhysicsWithGenerator(torch.nn.Module):
|
|
| 101 |
circle=False,
|
| 102 |
normalize=True,
|
| 103 |
device=device_str,
|
| 104 |
-
noise_model=dinv.physics.GaussianNoise(sigma=1e-
|
| 105 |
max_iter=10,
|
| 106 |
)
|
| 107 |
-
self.physics_generator = SigmaGenerator(sigma_min=1e-
|
| 108 |
-
self.generator = SigmaGenerator(sigma_min=1e-
|
| 109 |
self.saved_params = {"updatable_params": {"sigma": 0.1},
|
| 110 |
"updatable_params_converter": {"sigma": float},
|
| 111 |
-
"fixed_params": {"noise_sigma_min": 1e-
|
| 112 |
"angles": angles, "max_iter": 10}}
|
| 113 |
|
| 114 |
def display_saved_params(self) -> str:
|
|
|
|
| 13 |
|
| 14 |
class PhysicsWithGenerator(torch.nn.Module):
|
| 15 |
"""Interface between Physics, Generator and Gradio."""
|
| 16 |
+
all_physics = ["Inpainting", "MotionBlur_medium", "MotionBlur_hard",
|
| 17 |
"GaussianBlur_easy", "GaussianBlur_medium", "GaussianBlur_hard",
|
| 18 |
"MRI", "CT"]
|
| 19 |
|
|
|
|
| 83 |
"updatable_params_converter": {"sigma": float},
|
| 84 |
"fixed_params": {"noise_sigma_min": 0.05, "noise_sigma_max": 0.05,
|
| 85 |
"blur_sigma": 4.0, "psf_size": 31, "num_channels": 1}}
|
| 86 |
+
elif self.name == "Inpainting":
|
| 87 |
+
self.physics = dinv.physics.Inpainting(tensor_size=(256, 256), mask=split_ratio,
|
| 88 |
+
noise_model=dinv.physics.GaussianNoise(sigma=sigma),
|
| 89 |
+
device=device_str)
|
| 90 |
+
self.physics_generator = dinv.physics.generator.BernoulliSplittingMaskGenerator((3, 256, 256),
|
| 91 |
+
split_ratio=split_ratio, pixelwise=pixelwise,
|
| 92 |
+
random_split_ratio=True, min_split_ratio=split_ratio,
|
| 93 |
+
max_split_ratio=split_ratio, device=device_str)
|
| 94 |
+
self.generator = dinv.physics.generator.BernoulliSplittingMaskGenerator((3, 256, 256),
|
| 95 |
+
split_ratio=split_ratio, pixelwise=pixelwise,
|
| 96 |
+
random_split_ratio=True, min_split_ratio=split_ratio,
|
| 97 |
+
max_split_ratio=split_ratio, device=device_str)
|
| 98 |
+
|
| 99 |
+
self.saved_params = {"updatable_params": {},
|
| 100 |
+
"updatable_params_converter": {"sigma": float},
|
| 101 |
+
"fixed_params": {"sigma": sigma}}
|
| 102 |
elif self.name == "MRI":
|
| 103 |
self.physics = dinv.physics.MRI(noise_model=dinv.physics.GaussianNoise(sigma=.01).to(device_str),
|
| 104 |
img_size=(640, 320), device=device_str)
|
|
|
|
| 117 |
circle=False,
|
| 118 |
normalize=True,
|
| 119 |
device=device_str,
|
| 120 |
+
noise_model=dinv.physics.GaussianNoise(sigma=1e-3).to(device_str),
|
| 121 |
max_iter=10,
|
| 122 |
)
|
| 123 |
+
self.physics_generator = SigmaGenerator(sigma_min=1e-3, sigma_max=1e-3, device=device_str)
|
| 124 |
+
self.generator = SigmaGenerator(sigma_min=1e-3, sigma_max=1e-3, device=device_str)
|
| 125 |
self.saved_params = {"updatable_params": {"sigma": 0.1},
|
| 126 |
"updatable_params_converter": {"sigma": float},
|
| 127 |
+
"fixed_params": {"noise_sigma_min": 1e-3, "noise_sigma_max": 1e-3,
|
| 128 |
"angles": angles, "max_iter": 10}}
|
| 129 |
|
| 130 |
def display_saved_params(self) -> str:
|