behretj commited on
Commit
8062ce3
·
1 Parent(s): 00a8267

caching-version + variable baseline control

Browse files
Files changed (1) hide show
  1. app.py +82 -49
app.py CHANGED
@@ -11,6 +11,8 @@ import sys
11
  import tempfile
12
  from PIL import Image
13
  from gradio_patches.radio import Radio
 
 
14
 
15
  REPO_URL = "https://github.com/prs-eth/stereospace.git"
16
  COMMIT_SHA = "d7bbae6"
@@ -80,9 +82,35 @@ def find_all_output_files(output_dir, base_name):
80
  pass
81
  return outputs
82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
  @spaces.GPU
85
- def process_all_modes(input_image):
 
 
 
 
 
 
 
86
  if input_image is None:
87
  raise gr.Error("Please upload an image or select an example.")
88
 
@@ -96,27 +124,16 @@ def process_all_modes(input_image):
96
  try:
97
  with tempfile.TemporaryDirectory() as tmp_output:
98
  base_name = os.path.splitext(os.path.basename(input_path))[0]
99
-
100
- inference_script = "inference.py"
101
- cmd = [
102
- "python", inference_script,
103
- "--input", input_path,
104
- "--output", tmp_output
105
- ]
106
-
107
- print(f"Running command: {' '.join(cmd)}")
108
- print(f"Working directory: {REPO_DIR}")
109
- result = subprocess.run(
110
- cmd,
111
- cwd=REPO_DIR,
112
- capture_output=True,
113
- text=True
114
  )
 
 
115
 
116
- if result.returncode != 0:
117
- error_msg = f"Inference failed:\nSTDOUT:\n{result.stdout}\nSTDERR:\n{result.stderr}"
118
- print(error_msg)
119
- raise gr.Error(error_msg)
120
 
121
  try:
122
  output_files = find_all_output_files(tmp_output, base_name)
@@ -198,12 +215,24 @@ with gr.Blocks(
198
  )
199
 
200
  with gr.Row():
201
- image = gr.Image(
 
 
 
 
 
202
  type="filepath",
203
  label="Input/Output Image",
204
  elem_classes="result-image",
205
  height=480,
206
  )
 
 
 
 
 
 
 
207
 
208
  outputs_gallery = gr.Gallery(
209
  visible=False,
@@ -243,12 +272,12 @@ with gr.Blocks(
243
 
244
  return result
245
 
246
- def process_new_image(img, current_mode):
247
  if img is None:
248
  placeholder = create_placeholder_image()
249
  return [], placeholder, gr.update()
250
 
251
- all_outputs = process_all_modes(img)
252
  gallery_data = []
253
  available_modes = []
254
  for mode_name, mode_image in all_outputs.items():
@@ -259,12 +288,12 @@ with gr.Blocks(
259
  radio_value = current_mode if current_mode in available_modes else (available_modes[0] if available_modes else "Anaglyph")
260
  return gallery_data, placeholder, gr.update(choices=available_modes, value=radio_value)
261
 
262
- def process_example_simple(img):
263
  if img is None:
264
  placeholder = create_placeholder_image()
265
  return [], placeholder, gr.update()
266
 
267
- all_outputs = process_all_modes(img)
268
 
269
  gallery_data = []
270
  available_modes = []
@@ -277,45 +306,43 @@ with gr.Blocks(
277
  return gallery_data, placeholder, gr.update(choices=available_modes, value=radio_value)
278
 
279
  def clear_image():
280
- return None, []
281
 
282
  examples_list = get_example_images()
283
  if examples_list:
284
- def process_example_wrapper(img):
285
- gallery_data, placeholder_image, radio_update = process_example_simple(img)
286
- return gallery_data, placeholder_image, radio_update
287
-
288
  examples_component = gr.Examples(
289
  examples=examples_list,
290
- inputs=[image],
291
- outputs=[outputs_gallery, image, output_mode],
292
- fn=process_example_wrapper,
293
- cache_examples=True,
294
- cache_mode="lazy",
295
  label="Example Images",
296
  elem_id="example-images-gallery",
297
  )
298
 
299
- def process_upload_wrapper(img, current_mode):
300
- gallery_data, blocked_image, radio_update = process_new_image(img, current_mode)
 
 
 
 
 
 
301
  return gallery_data, blocked_image, radio_update
302
 
303
- upload_event = image.upload(
304
- fn=process_upload_wrapper,
305
- inputs=[image, output_mode],
306
- outputs=[outputs_gallery, image, output_mode]
307
  )
308
 
309
  def on_gallery_change(gallery_data, current_mode, current_image):
310
  if not gallery_data or len(gallery_data) == 0:
311
- return current_image, gr.update(interactive=True)
312
  updated_image = update_image_from_gallery(gallery_data, current_mode, current_image)
313
- return updated_image, gr.update(interactive=True)
314
 
315
  gallery_change_event = outputs_gallery.change(
316
  fn=on_gallery_change,
317
- inputs=[outputs_gallery, output_mode, image],
318
- outputs=[image, image]
319
  )
320
 
321
  def switch_mode_handler(current_mode, gallery_data, current_image):
@@ -324,13 +351,19 @@ with gr.Blocks(
324
 
325
  output_mode.change(
326
  fn=switch_mode_handler,
327
- inputs=[output_mode, outputs_gallery, image],
328
- outputs=image
 
 
 
 
 
 
329
  )
330
 
331
- image.clear(
332
  fn=clear_image,
333
- outputs=[image, outputs_gallery]
334
  )
335
 
336
  if __name__ == "__main__":
 
11
  import tempfile
12
  from PIL import Image
13
  from gradio_patches.radio import Radio
14
+ import torch
15
+ from omegaconf import OmegaConf
16
 
17
  REPO_URL = "https://github.com/prs-eth/stereospace.git"
18
  COMMIT_SHA = "d7bbae6"
 
82
  pass
83
  return outputs
84
 
85
+ class Args:
86
+ def __init__(
87
+ self,
88
+ input,
89
+ output,
90
+ baseline=0.15,
91
+ batch_size=1,
92
+ src_intrinsics=None,
93
+ tgt_intrinsics=None,
94
+ ):
95
+ self.input = input
96
+ self.output = output
97
+ self.baseline = baseline
98
+ self.batch_size = batch_size
99
+ self.src_intrinsics = src_intrinsics
100
+ self.tgt_intrinsics = tgt_intrinsics
101
+
102
+ stereo_nvs = None
103
+ config = OmegaConf.load(os.path.join(REPO_DIR, "configs/stereospace.yaml"))
104
 
105
  @spaces.GPU
106
+ def process_all_modes(input_image, baseline=0.15):
107
+ global stereo_nvs
108
+
109
+ device = "cuda" if torch.cuda.is_available() else "cpu"
110
+
111
+ if stereo_nvs is None:
112
+ stereo_nvs = StereoSpace(config, device=device)
113
+
114
  if input_image is None:
115
  raise gr.Error("Please upload an image or select an example.")
116
 
 
124
  try:
125
  with tempfile.TemporaryDirectory() as tmp_output:
126
  base_name = os.path.splitext(os.path.basename(input_path))[0]
127
+
128
+ args = Args(
129
+ input=input_path,
130
+ output=tmp_output,
131
+ baseline=float(baseline),
132
+ batch_size=1,
 
 
 
 
 
 
 
 
 
133
  )
134
+
135
+ generate_novel_view(args=args, config=config, stereo_nvs=stereo_nvs)
136
 
 
 
 
 
137
 
138
  try:
139
  output_files = find_all_output_files(tmp_output, base_name)
 
215
  )
216
 
217
  with gr.Row():
218
+ image_in = gr.Image(
219
+ type="filepath",
220
+ height=480,
221
+ visible=False,
222
+ )
223
+ image_out = gr.Image(
224
  type="filepath",
225
  label="Input/Output Image",
226
  elem_classes="result-image",
227
  height=480,
228
  )
229
+
230
+ with gr.Row():
231
+ baseline_slider = gr.Slider(
232
+ minimum=0.025, maximum=0.4, value=0.15, step=0.005,
233
+ label="Baseline length (meters)"
234
+ )
235
+ run_btn = gr.Button("Run")
236
 
237
  outputs_gallery = gr.Gallery(
238
  visible=False,
 
272
 
273
  return result
274
 
275
+ def process_new_image(img, current_mode, baseline=0.15):
276
  if img is None:
277
  placeholder = create_placeholder_image()
278
  return [], placeholder, gr.update()
279
 
280
+ all_outputs = process_all_modes(img, baseline=baseline)
281
  gallery_data = []
282
  available_modes = []
283
  for mode_name, mode_image in all_outputs.items():
 
288
  radio_value = current_mode if current_mode in available_modes else (available_modes[0] if available_modes else "Anaglyph")
289
  return gallery_data, placeholder, gr.update(choices=available_modes, value=radio_value)
290
 
291
+ def process_example_simple(img, baseline=0.15):
292
  if img is None:
293
  placeholder = create_placeholder_image()
294
  return [], placeholder, gr.update()
295
 
296
+ all_outputs = process_all_modes(img, baseline=baseline)
297
 
298
  gallery_data = []
299
  available_modes = []
 
306
  return gallery_data, placeholder, gr.update(choices=available_modes, value=radio_value)
307
 
308
  def clear_image():
309
+ return None, None, []
310
 
311
  examples_list = get_example_images()
312
  if examples_list:
 
 
 
 
313
  examples_component = gr.Examples(
314
  examples=examples_list,
315
+ inputs=[image_in],
 
 
 
 
316
  label="Example Images",
317
  elem_id="example-images-gallery",
318
  )
319
 
320
+ image_in.change(
321
+ fn=lambda img: (img, []),
322
+ inputs=[image_in],
323
+ outputs=[image_out, outputs_gallery],
324
+ )
325
+
326
+ def run_wrapper(img, current_mode, baseline=0.15):
327
+ gallery_data, blocked_image, radio_update = process_new_image(img, current_mode, baseline=baseline)
328
  return gallery_data, blocked_image, radio_update
329
 
330
+ run_btn.click(
331
+ fn=run_wrapper,
332
+ inputs=[image_in, output_mode, baseline_slider],
333
+ outputs=[outputs_gallery, image_out, output_mode],
334
  )
335
 
336
  def on_gallery_change(gallery_data, current_mode, current_image):
337
  if not gallery_data or len(gallery_data) == 0:
338
+ return current_image
339
  updated_image = update_image_from_gallery(gallery_data, current_mode, current_image)
340
+ return updated_image
341
 
342
  gallery_change_event = outputs_gallery.change(
343
  fn=on_gallery_change,
344
+ inputs=[outputs_gallery, output_mode, image_in],
345
+ outputs=image_out
346
  )
347
 
348
  def switch_mode_handler(current_mode, gallery_data, current_image):
 
351
 
352
  output_mode.change(
353
  fn=switch_mode_handler,
354
+ inputs=[output_mode, outputs_gallery, image_out],
355
+ outputs=image_out
356
+ )
357
+
358
+ image_out.upload(
359
+ fn=lambda img: [img, img],
360
+ inputs=[image_out],
361
+ outputs=[image_in, image_out]
362
  )
363
 
364
+ image_out.clear(
365
  fn=clear_image,
366
+ outputs=[image_in, image_out, outputs_gallery]
367
  )
368
 
369
  if __name__ == "__main__":