Spaces:
Sleeping
Sleeping
unfinished work
Browse files
app.py
CHANGED
|
@@ -74,18 +74,18 @@ def generate_imgs(dataset: EvalDataset, idx: int,
|
|
| 74 |
|
| 75 |
return x, y, out, out_baseline, physics.display_saved_params(), metrics_y, metrics_out, metrics_out_baseline
|
| 76 |
|
| 77 |
-
def update_random_idx_and_generate_imgs(dataset: EvalDataset,
|
| 78 |
-
model: EvalModel,
|
| 79 |
-
baseline: BaselineModel,
|
| 80 |
physics: PhysicsWithGenerator,
|
| 81 |
use_gen: bool,
|
| 82 |
metrics: List[Metric]):
|
| 83 |
idx = random.randint(0, len(dataset)-1)
|
| 84 |
-
x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline = generate_imgs(dataset,
|
| 85 |
-
idx,
|
| 86 |
-
model,
|
| 87 |
-
baseline,
|
| 88 |
-
physics,
|
| 89 |
use_gen,
|
| 90 |
metrics)
|
| 91 |
return idx, x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline
|
|
@@ -145,6 +145,17 @@ def get_model(model_name, ckpt_pth):
|
|
| 145 |
else:
|
| 146 |
return get_eval_model_on_DEVICE_STR(model_name, ckpt_pth)
|
| 147 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
|
| 149 |
### Gradio Blocks interface
|
| 150 |
|
|
|
|
| 74 |
|
| 75 |
return x, y, out, out_baseline, physics.display_saved_params(), metrics_y, metrics_out, metrics_out_baseline
|
| 76 |
|
| 77 |
+
def update_random_idx_and_generate_imgs(dataset: EvalDataset,
|
| 78 |
+
model: EvalModel,
|
| 79 |
+
baseline: BaselineModel,
|
| 80 |
physics: PhysicsWithGenerator,
|
| 81 |
use_gen: bool,
|
| 82 |
metrics: List[Metric]):
|
| 83 |
idx = random.randint(0, len(dataset)-1)
|
| 84 |
+
x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline = generate_imgs(dataset,
|
| 85 |
+
idx,
|
| 86 |
+
model,
|
| 87 |
+
baseline,
|
| 88 |
+
physics,
|
| 89 |
use_gen,
|
| 90 |
metrics)
|
| 91 |
return idx, x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline
|
|
|
|
| 145 |
else:
|
| 146 |
return get_eval_model_on_DEVICE_STR(model_name, ckpt_pth)
|
| 147 |
|
| 148 |
+
AVAILABLE_PHYSICS = PhysicsWithGenerator.all_physics
|
| 149 |
+
|
| 150 |
+
def get_dataset(dataset_name):
|
| 151 |
+
global AVAILABLE_PHYSICS
|
| 152 |
+
if dataset_name = 'MRI':
|
| 153 |
+
AVAILABLE_PHYSICS = ['MRI']
|
| 154 |
+
elif dataset_name = 'CT':
|
| 155 |
+
AVAILABLE_PHYSICS = ['CT']
|
| 156 |
+
else:
|
| 157 |
+
AVAILABLE_PHYSICS = ['MotionBlur_easy', 'MotionBlur_medium', 'MotionBlur_hard', 'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard']
|
| 158 |
+
return get_dataset_on_DEVICE_STR(dataset_name)
|
| 159 |
|
| 160 |
### Gradio Blocks interface
|
| 161 |
|