Spaces:
Sleeping
Sleeping
Load automatically the right baseline and default physics when choosing a dataset
Browse files
app.py
CHANGED
|
@@ -112,25 +112,23 @@ get_baseline_model_on_DEVICE_STR = partial(BaselineModel, device_str=DEVICE_STR)
|
|
| 112 |
get_dataset_on_DEVICE_STR = partial(EvalDataset, device_str=DEVICE_STR)
|
| 113 |
get_physics_on_DEVICE_STR = partial(PhysicsWithGenerator, device_str=DEVICE_STR)
|
| 114 |
|
| 115 |
-
def get_physics(physics_name):
|
| 116 |
-
if physics_name == 'MRI':
|
| 117 |
-
baseline = get_baseline_model_on_DEVICE_STR('DPIR_MRI')
|
| 118 |
-
elif physics_name == 'CT':
|
| 119 |
-
baseline = get_baseline_model_on_DEVICE_STR('DPIR_CT')
|
| 120 |
-
else:
|
| 121 |
-
baseline = get_baseline_model_on_DEVICE_STR('DPIR')
|
| 122 |
-
return get_physics_on_DEVICE_STR(physics_name), baseline
|
| 123 |
-
|
| 124 |
AVAILABLE_PHYSICS = PhysicsWithGenerator.all_physics
|
| 125 |
def get_dataset(dataset_name):
|
| 126 |
global AVAILABLE_PHYSICS
|
| 127 |
if dataset_name == 'MRI':
|
| 128 |
AVAILABLE_PHYSICS = ['MRI']
|
|
|
|
|
|
|
| 129 |
elif dataset_name == 'CT':
|
| 130 |
AVAILABLE_PHYSICS = ['CT']
|
|
|
|
|
|
|
| 131 |
else:
|
| 132 |
AVAILABLE_PHYSICS = ['MotionBlur_easy', 'MotionBlur_medium', 'MotionBlur_hard', 'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard']
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
### Gradio Blocks interface
|
| 136 |
|
|
@@ -212,10 +210,10 @@ with gr.Blocks(title=title, css=custom_css) as interface:
|
|
| 212 |
### Event listeners
|
| 213 |
choose_dataset.change(fn=get_dataset,
|
| 214 |
inputs=choose_dataset,
|
| 215 |
-
outputs=dataset_placeholder)
|
| 216 |
-
choose_physics.change(fn=
|
| 217 |
inputs=choose_physics,
|
| 218 |
-
outputs=[physics_placeholder
|
| 219 |
update_button.click(fn=physics.update_and_display_params, inputs=[key_selector, value_text], outputs=physics_params)
|
| 220 |
choose_metrics.change(fn=get_list_metrics_on_DEVICE_STR,
|
| 221 |
inputs=choose_metrics,
|
|
|
|
| 112 |
get_dataset_on_DEVICE_STR = partial(EvalDataset, device_str=DEVICE_STR)
|
| 113 |
get_physics_on_DEVICE_STR = partial(PhysicsWithGenerator, device_str=DEVICE_STR)
|
| 114 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
AVAILABLE_PHYSICS = PhysicsWithGenerator.all_physics
|
| 116 |
def get_dataset(dataset_name):
|
| 117 |
global AVAILABLE_PHYSICS
|
| 118 |
if dataset_name == 'MRI':
|
| 119 |
AVAILABLE_PHYSICS = ['MRI']
|
| 120 |
+
baseline_name = 'DPIR_MRI'
|
| 121 |
+
physics_name = 'MRI'
|
| 122 |
elif dataset_name == 'CT':
|
| 123 |
AVAILABLE_PHYSICS = ['CT']
|
| 124 |
+
baseline_name = 'DPIR_CT'
|
| 125 |
+
physics_name = 'CT'
|
| 126 |
else:
|
| 127 |
AVAILABLE_PHYSICS = ['MotionBlur_easy', 'MotionBlur_medium', 'MotionBlur_hard', 'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard']
|
| 128 |
+
baseline_name = 'DPIR'
|
| 129 |
+
physics_name = 'MotionBlur_easy'
|
| 130 |
+
return get_dataset_on_DEVICE_STR(dataset_name), get_physics_on_DEVICE_STR(physics_name), get_baseline_model_on_DEVICE_STR(baseline_name)
|
| 131 |
+
|
| 132 |
|
| 133 |
### Gradio Blocks interface
|
| 134 |
|
|
|
|
| 210 |
### Event listeners
|
| 211 |
choose_dataset.change(fn=get_dataset,
|
| 212 |
inputs=choose_dataset,
|
| 213 |
+
outputs=[dataset_placeholder, physics_placeholder, model_b_placeholder])
|
| 214 |
+
choose_physics.change(fn=get_physics_on_DEVICE_STR,
|
| 215 |
inputs=choose_physics,
|
| 216 |
+
outputs=[physics_placeholder])
|
| 217 |
update_button.click(fn=physics.update_and_display_params, inputs=[key_selector, value_text], outputs=physics_params)
|
| 218 |
choose_metrics.change(fn=get_list_metrics_on_DEVICE_STR,
|
| 219 |
inputs=choose_metrics,
|