Spaces:
Sleeping
Sleeping
Free memory after inference to reduce potential memory overflow
Browse files
app.py
CHANGED
|
@@ -109,14 +109,18 @@ def generate_imgs(x: torch.Tensor,
|
|
| 109 |
# - torch.Tensor object -> Pil object
|
| 110 |
process_img = partial(dinv.utils.plotting.preprocess_img, rescale_mode="clip")
|
| 111 |
to_pil = transforms.ToPILImage()
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
print(torch.cuda.memory_allocated() / 1024**2)
|
| 118 |
-
return x, y, out, out_baseline, physics.display_saved_params(), metrics_y, metrics_out, metrics_out_baseline
|
| 119 |
|
|
|
|
| 120 |
|
| 121 |
get_dataset_on_DEVICE_STR = partial(EvalDataset, device_str=DEVICE_STR)
|
| 122 |
get_physics_on_DEVICE_STR = partial(PhysicsWithGenerator, device_str=DEVICE_STR)
|
|
@@ -175,10 +179,10 @@ with gr.Blocks(title=title, theme=gr.themes.Glass()) as interface:
|
|
| 175 |
|
| 176 |
# Display images
|
| 177 |
with gr.Row():
|
| 178 |
-
gt_img = gr.Image(label="Ground-truth image", interactive=True, key=
|
| 179 |
-
observed_img = gr.Image(label="Observed image", interactive=False, key=
|
| 180 |
-
model_a_out = gr.Image(label="RAM output", interactive=False, key=
|
| 181 |
-
model_b_out = gr.Image(label="DPIR output", interactive=False, key=
|
| 182 |
|
| 183 |
# Manage datasets and display metric values
|
| 184 |
with gr.Row():
|
|
@@ -187,16 +191,16 @@ with gr.Blocks(title=title, theme=gr.themes.Glass()) as interface:
|
|
| 187 |
choose_dataset = gr.Radio(choices=EvalDataset.all_datasets,
|
| 188 |
label="Datasets",
|
| 189 |
value=dataset.name)
|
| 190 |
-
idx_slider = gr.Slider(minimum=0, maximum=len(dataset)-1, step=1, label="Sample index", key=
|
| 191 |
with gr.Row():
|
| 192 |
load_button = gr.Button("Run on index image from dataset", size='md')
|
| 193 |
load_random_button = gr.Button("Run on random image from dataset", size='md')
|
| 194 |
with gr.Column(scale=1, min_width=160):
|
| 195 |
-
observed_metrics = gr.Textbox(label="Observed metric", lines=3, key=
|
| 196 |
with gr.Column(scale=1, min_width=160):
|
| 197 |
-
out_a_metric = gr.Textbox(label="RAM output metrics", lines=3, key=
|
| 198 |
with gr.Column(scale=1, min_width=160):
|
| 199 |
-
out_b_metric = gr.Textbox(label="DPIR output metrics", lines=3, key=
|
| 200 |
|
| 201 |
# Manage physics
|
| 202 |
with gr.Row():
|
|
@@ -204,7 +208,7 @@ with gr.Blocks(title=title, theme=gr.themes.Glass()) as interface:
|
|
| 204 |
choose_physics = gr.Radio(choices=available_physics,
|
| 205 |
label="Physics",
|
| 206 |
value=physics.name)
|
| 207 |
-
use_generator_button = gr.Checkbox(label="Generate physics parameters during inference", value=True, key=
|
| 208 |
with gr.Column(scale=1):
|
| 209 |
with gr.Row():
|
| 210 |
key_selector = gr.Dropdown(choices=list(physics.saved_params["updatable_params"].keys()),
|
|
|
|
| 109 |
# - torch.Tensor object -> Pil object
|
| 110 |
process_img = partial(dinv.utils.plotting.preprocess_img, rescale_mode="clip")
|
| 111 |
to_pil = transforms.ToPILImage()
|
| 112 |
+
x_pil = to_pil(process_img(x)[0].to('cpu'))
|
| 113 |
+
y_pil = to_pil(process_img(y_plot)[0].to('cpu'))
|
| 114 |
+
out_pil = to_pil(process_img(out)[0].to('cpu'))
|
| 115 |
+
out_baseline_pil = to_pil(process_img(out_baseline)[0].to('cpu'))
|
| 116 |
|
| 117 |
+
|
| 118 |
+
# Free memory
|
| 119 |
+
del x, y, out, out_baseline, y_plot
|
| 120 |
+
torch.cuda.empty_cache()
|
| 121 |
print(torch.cuda.memory_allocated() / 1024**2)
|
|
|
|
| 122 |
|
| 123 |
+
return x_pil, y_pil, out_pil, out_baseline_pil, physics.display_saved_params(), metrics_y, metrics_out, metrics_out_baseline
|
| 124 |
|
| 125 |
get_dataset_on_DEVICE_STR = partial(EvalDataset, device_str=DEVICE_STR)
|
| 126 |
get_physics_on_DEVICE_STR = partial(PhysicsWithGenerator, device_str=DEVICE_STR)
|
|
|
|
| 179 |
|
| 180 |
# Display images
|
| 181 |
with gr.Row():
|
| 182 |
+
gt_img = gr.Image(label="Ground-truth image", interactive=True, key='gt_img')
|
| 183 |
+
observed_img = gr.Image(label="Observed image", interactive=False, key='observed_img')
|
| 184 |
+
model_a_out = gr.Image(label="RAM output", interactive=False, key='ram_out')
|
| 185 |
+
model_b_out = gr.Image(label="DPIR output", interactive=False, key='dpir_out')
|
| 186 |
|
| 187 |
# Manage datasets and display metric values
|
| 188 |
with gr.Row():
|
|
|
|
| 191 |
choose_dataset = gr.Radio(choices=EvalDataset.all_datasets,
|
| 192 |
label="Datasets",
|
| 193 |
value=dataset.name)
|
| 194 |
+
idx_slider = gr.Slider(minimum=0, maximum=len(dataset)-1, step=1, label="Sample index", key='idx_slider')
|
| 195 |
with gr.Row():
|
| 196 |
load_button = gr.Button("Run on index image from dataset", size='md')
|
| 197 |
load_random_button = gr.Button("Run on random image from dataset", size='md')
|
| 198 |
with gr.Column(scale=1, min_width=160):
|
| 199 |
+
observed_metrics = gr.Textbox(label="Observed metric", lines=3, key='metrics')
|
| 200 |
with gr.Column(scale=1, min_width=160):
|
| 201 |
+
out_a_metric = gr.Textbox(label="RAM output metrics", lines=3, key='ram_metrics')
|
| 202 |
with gr.Column(scale=1, min_width=160):
|
| 203 |
+
out_b_metric = gr.Textbox(label="DPIR output metrics", lines=3, key='dpir_metrics')
|
| 204 |
|
| 205 |
# Manage physics
|
| 206 |
with gr.Row():
|
|
|
|
| 208 |
choose_physics = gr.Radio(choices=available_physics,
|
| 209 |
label="Physics",
|
| 210 |
value=physics.name)
|
| 211 |
+
use_generator_button = gr.Checkbox(label="Generate physics parameters during inference", value=True, key='use_gen')
|
| 212 |
with gr.Column(scale=1):
|
| 213 |
with gr.Row():
|
| 214 |
key_selector = gr.Dropdown(choices=list(physics.saved_params["updatable_params"].keys()),
|