JasonYinnnn commited on
Commit
a0fbf94
·
1 Parent(s): a08e831

use gr.State

Browse files
Files changed (1) hide show
  1. app.py +58 -53
app.py CHANGED
@@ -39,8 +39,7 @@ DTYPE = torch.float16
39
  DEVICE = "cpu"
40
  VALID_RATIO_THRESHOLD = 0.005
41
  CROP_SIZE = 518
42
- work_space = None
43
- dpt_pack = None
44
  generated_object_map = {}
45
 
46
  ############## 3D-Fixer model
@@ -148,7 +147,6 @@ def run_segmentation(
148
  ) -> Image.Image:
149
  rgb_image = image_prompts["image"].convert("RGB")
150
 
151
- global work_space
152
  global sam_segmentator
153
 
154
  device = "cpu"
@@ -176,16 +174,19 @@ def run_segmentation(
176
  seg_map_pil = plot_segmentation(rgb_image, detections)
177
 
178
  cleanup_tmp(TMP_DIR, expire_seconds=3600)
179
- work_space = os.path.join(TMP_DIR, f"work_space_{uuid.uuid4()}")
180
- os.makedirs(work_space, exist_ok=True)
181
- seg_map_pil.save(os.path.join(work_space, 'mask.png'))
 
 
182
 
183
- return seg_map_pil
184
 
185
  @spaces.GPU
186
  def run_depth_estimation(
187
  image_prompts: Any,
188
  seg_image: Union[str, Image.Image],
 
189
  ) -> Image.Image:
190
  rgb_image = image_prompts["image"].convert("RGB")
191
 
@@ -201,14 +202,11 @@ def run_depth_estimation(
201
  dtype = torch.float16 if device == 'cuda' else torch.float32
202
  moge_v2_dpt_model = moge_v2_dpt_model.to(device=device, dtype=dtype)
203
 
204
- global dpt_pack
205
- global work_space
206
  if work_space is None:
207
- work_space = os.path.join(TMP_DIR, f"work_space_{uuid.uuid4()}")
208
- os.makedirs(work_space, exist_ok=True)
209
- global generated_object_map
210
-
211
- generated_object_map = {}
212
 
213
  origin_W, origin_H = rgb_image.size
214
  if max(origin_H, origin_W) > 1024:
@@ -238,12 +236,12 @@ def run_depth_estimation(
238
  ])
239
  ).to(dtype=torch.float32, device=device)
240
 
241
- dpt_pack = {
242
- 'c2w': c2w,
243
- 'K': K,
244
- 'depth_mask': depth_mask,
245
- 'depth': depth
246
- }
247
 
248
  instance_labels = np.unique(np.array(seg_image).reshape(-1, 3), axis=0)
249
  seg_image = seg_image.resize((W, H), Image.Resampling.LANCZOS)
@@ -260,7 +258,7 @@ def run_depth_estimation(
260
 
261
  scene_est_depth_pts, scene_est_depth_pts_colors = \
262
  project2ply(depth_mask.to(device), depth.to(device), input_image.to(device), K.to(device), c2w.to(device))
263
- save_ply_path = os.path.join(work_space, "scene_pcd.glb")
264
 
265
  fg_depth_pts, _ = \
266
  project2ply(fg_mask.to(device), depth.to(device), input_image.to(device), K.to(device), c2w.to(device))
@@ -269,22 +267,22 @@ def run_depth_estimation(
269
  if trans.shape[0] == 1:
270
  trans = trans[0]
271
 
272
- dpt_pack.update(
273
  {
274
  "trans": trans,
275
  "scale": scale,
276
  }
277
  )
278
 
279
- for k, v in dpt_pack.items():
280
  if isinstance(v, torch.Tensor):
281
- dpt_pack[k] = v.to('cpu')
282
 
283
  trimesh.PointCloud(scene_est_depth_pts.reshape(-1, 3), scene_est_depth_pts_colors.reshape(-1, 3)).\
284
  apply_translation(-trans).apply_scale(1. / (scale + 1e-6)).\
285
  apply_transform(rot).export(save_ply_path)
286
 
287
- return save_ply_path
288
 
289
 
290
  def save_image(img, save_path):
@@ -307,7 +305,7 @@ def export_scene_glb(trimeshes, work_space, scene_name):
307
 
308
  def get_duration(rgb_image, seg_image, seed, randomize_seed,
309
  num_inference_steps, guidance_scale, cfg_interval_start,
310
- cfg_interval_end, t_rescale):
311
  instance_labels = np.unique(np.array(seg_image).reshape(-1, 3), axis=0)
312
  step_duration = 15.0
313
  return instance_labels.shape[0] * step_duration + 60
@@ -323,8 +321,16 @@ def run_generation(
323
  cfg_interval_start: float = 0.5,
324
  cfg_interval_end: float = 1.0,
325
  t_rescale: float = 3.0,
 
326
  ):
327
 
 
 
 
 
 
 
 
328
  from threeDFixer.pipelines import ThreeDFixerPipeline
329
  from threeDFixer.datasets.utils import (
330
  edge_mask_morph_gradient,
@@ -377,9 +383,6 @@ def run_generation(
377
 
378
  return instance_glb_path, glb
379
 
380
- global dpt_pack
381
- global work_space
382
- global generated_object_map
383
  generated_object_map = {}
384
  run_id = str(uuid.uuid4())
385
 
@@ -397,11 +400,11 @@ def run_generation(
397
  seed = random.randint(0, MAX_SEED)
398
  set_random_seed(seed)
399
 
400
- H, W = dpt_pack['depth_mask'].shape
401
  rgb_image = rgb_image.resize((W, H), Image.Resampling.LANCZOS)
402
  seg_image = seg_image.resize((W, H), Image.Resampling.LANCZOS)
403
 
404
- depth_mask = dpt_pack['depth_mask'].detach().cpu().numpy() > 0
405
  seg_image = np.array(seg_image)
406
 
407
  mask_pack = []
@@ -416,8 +419,8 @@ def run_generation(
416
  results = []
417
  trimeshes = []
418
 
419
- trans = dpt_pack['trans']
420
- scale = dpt_pack['scale']
421
 
422
  current_scene_path = None
423
  pending_exports = []
@@ -534,7 +537,7 @@ def run_generation(
534
  trimeshes.append(glb)
535
  current_scene_path = export_scene_glb(
536
  trimeshes=trimeshes,
537
- work_space=work_space,
538
  scene_name=f"{run_id}_scene_step_{len(trimeshes)}.glb",
539
  )
540
  any_update = True
@@ -566,12 +569,12 @@ def run_generation(
566
  if flushed is not None:
567
  yield flushed
568
 
569
- est_depth = dpt_pack['depth'].to('cpu')
570
- c2w = dpt_pack['c2w'].to('cpu')
571
- K = dpt_pack['K'].to('cpu')
572
 
573
- intrinsics = dpt_pack['K'].float().to(DEVICE)
574
- extrinsics = copy.deepcopy(dpt_pack['c2w']).float().to(DEVICE)
575
  extrinsics[:3, 1:3] *= -1
576
 
577
  object_mask = object_mask > 0
@@ -590,12 +593,12 @@ def run_generation(
590
  instance_image, instance_mask, instance_rays_o, instance_rays_d, instance_rays_c, \
591
  instance_rays_t = process_instance_image(image, instance_mask, color_mask, est_depth, K, c2w, CROP_SIZE)
592
 
593
- save_image(scene_image, os.path.join(work_space, f'input_scene_image_{instance_name}.png'))
594
- save_image(scene_image_masked, os.path.join(work_space, f'input_scene_image_masked_{instance_name}.png'))
595
- save_image(instance_image, os.path.join(work_space, f'input_instance_image_{instance_name}.png'))
596
  save_image(
597
  torch.cat([instance_image, instance_mask]),
598
- os.path.join(work_space, f'input_instance_image_masked_{instance_name}.png')
599
  )
600
 
601
  pcd_points = (
@@ -607,7 +610,7 @@ def run_generation(
607
  save_projected_colored_pcd(
608
  pcd_points,
609
  repeat(pcd_colors, 'n -> n c', c=3),
610
- f"{work_space}/instance_est_depth_{instance_name}.ply"
611
  )
612
 
613
  with torch.no_grad():
@@ -634,10 +637,10 @@ def run_generation(
634
  )
635
 
636
  mp4_path = os.path.abspath(
637
- os.path.join(work_space, f"{run_id}_instance_gs_fine_{instance_name}.mp4")
638
  )
639
  poster_path = os.path.abspath(
640
- os.path.join(work_space, f"{run_id}_instance_gs_fine_{instance_name}.png")
641
  )
642
 
643
  video = render_utils.render_video(
@@ -678,7 +681,7 @@ def run_generation(
678
  trans=trans,
679
  scale=scale,
680
  rot=rot,
681
- work_space=work_space,
682
  instance_name=instance_name,
683
  run_id=run_id,
684
  )
@@ -708,7 +711,7 @@ def run_generation(
708
  if len(ready_items) > 0:
709
  final_scene_path = export_scene_glb(
710
  trimeshes=trimeshes,
711
- work_space=work_space,
712
  scene_name=f"{run_id}_scene_final.glb",
713
  )
714
 
@@ -740,6 +743,7 @@ def update_single_download(selected_name):
740
 
741
  # Demo
742
  with gr.Blocks() as demo:
 
743
  gr.Markdown(MARKDOWN)
744
 
745
  with gr.Column():
@@ -812,7 +816,6 @@ with gr.Blocks() as demo:
812
  with gr.Row():
813
  gr.Examples(
814
  examples=EXAMPLES,
815
- fn=run_generation,
816
  inputs=[image_prompts, seg_image, seed, randomize_seed, num_inference_steps, guidance_scale, cfg_interval_start, cfg_interval_end, t_rescale],
817
  outputs=[model_output, download_glb, seed],
818
  cache_examples=False,
@@ -824,16 +827,17 @@ with gr.Blocks() as demo:
824
  image_prompts,
825
  polygon_refinement,
826
  ],
827
- outputs=[seg_image],
828
  ).then(lambda: gr.Button(interactive=True), outputs=[dpt_button])
829
 
830
  dpt_button.click(
831
  run_depth_estimation,
832
  inputs=[
833
  image_prompts,
834
- seg_image
 
835
  ],
836
- outputs=[dpt_model_output],
837
  ).then(lambda: gr.Button(interactive=True), outputs=[gen_button])
838
 
839
  gen_button.click(
@@ -847,7 +851,8 @@ with gr.Blocks() as demo:
847
  guidance_scale,
848
  cfg_interval_start,
849
  cfg_interval_end,
850
- t_rescale
 
851
  ],
852
  outputs=[model_output,
853
  stream_output,
 
39
  DEVICE = "cpu"
40
  VALID_RATIO_THRESHOLD = 0.005
41
  CROP_SIZE = 518
42
+ work_space = None
 
43
  generated_object_map = {}
44
 
45
  ############## 3D-Fixer model
 
147
  ) -> Image.Image:
148
  rgb_image = image_prompts["image"].convert("RGB")
149
 
 
150
  global sam_segmentator
151
 
152
  device = "cpu"
 
174
  seg_map_pil = plot_segmentation(rgb_image, detections)
175
 
176
  cleanup_tmp(TMP_DIR, expire_seconds=3600)
177
+ work_space = {
178
+ "dir": os.path.join(TMP_DIR, f"work_space_{uuid.uuid4()}"),
179
+ }
180
+ os.makedirs(work_space["dir"], exist_ok=True)
181
+ seg_map_pil.save(os.path.join(work_space["dir"], "mask.png"))
182
 
183
+ return seg_map_pil, work_space
184
 
185
  @spaces.GPU
186
  def run_depth_estimation(
187
  image_prompts: Any,
188
  seg_image: Union[str, Image.Image],
189
+ work_space: dict,
190
  ) -> Image.Image:
191
  rgb_image = image_prompts["image"].convert("RGB")
192
 
 
202
  dtype = torch.float16 if device == 'cuda' else torch.float32
203
  moge_v2_dpt_model = moge_v2_dpt_model.to(device=device, dtype=dtype)
204
 
 
 
205
  if work_space is None:
206
+ work_space = {
207
+ "dir": os.path.join(TMP_DIR, f"work_space_{uuid.uuid4()}"),
208
+ }
209
+ os.makedirs(work_space["dir"], exist_ok=True)
 
210
 
211
  origin_W, origin_H = rgb_image.size
212
  if max(origin_H, origin_W) > 1024:
 
236
  ])
237
  ).to(dtype=torch.float32, device=device)
238
 
239
+ work_space.update({
240
+ "c2w": c2w,
241
+ "K": K,
242
+ "depth_mask": depth_mask,
243
+ "depth": depth,
244
+ })
245
 
246
  instance_labels = np.unique(np.array(seg_image).reshape(-1, 3), axis=0)
247
  seg_image = seg_image.resize((W, H), Image.Resampling.LANCZOS)
 
258
 
259
  scene_est_depth_pts, scene_est_depth_pts_colors = \
260
  project2ply(depth_mask.to(device), depth.to(device), input_image.to(device), K.to(device), c2w.to(device))
261
+ save_ply_path = os.path.join(work_space["dir"], "scene_pcd.glb")
262
 
263
  fg_depth_pts, _ = \
264
  project2ply(fg_mask.to(device), depth.to(device), input_image.to(device), K.to(device), c2w.to(device))
 
267
  if trans.shape[0] == 1:
268
  trans = trans[0]
269
 
270
+ work_space.update(
271
  {
272
  "trans": trans,
273
  "scale": scale,
274
  }
275
  )
276
 
277
+ for k, v in work_space.items():
278
  if isinstance(v, torch.Tensor):
279
+ work_space[k] = v.to('cpu')
280
 
281
  trimesh.PointCloud(scene_est_depth_pts.reshape(-1, 3), scene_est_depth_pts_colors.reshape(-1, 3)).\
282
  apply_translation(-trans).apply_scale(1. / (scale + 1e-6)).\
283
  apply_transform(rot).export(save_ply_path)
284
 
285
+ return save_ply_path, work_space
286
 
287
 
288
  def save_image(img, save_path):
 
305
 
306
  def get_duration(rgb_image, seg_image, seed, randomize_seed,
307
  num_inference_steps, guidance_scale, cfg_interval_start,
308
+ cfg_interval_end, t_rescale, work_space):
309
  instance_labels = np.unique(np.array(seg_image).reshape(-1, 3), axis=0)
310
  step_duration = 15.0
311
  return instance_labels.shape[0] * step_duration + 60
 
321
  cfg_interval_start: float = 0.5,
322
  cfg_interval_end: float = 1.0,
323
  t_rescale: float = 3.0,
324
+ work_space: dict = None,
325
  ):
326
 
327
+ if work_space is None:
328
+ raise gr.Error("Please run step 1 and step 2 first.")
329
+ required_keys = ["dir", "depth_mask", "depth", "K", "c2w", "trans", "scale"]
330
+ missing = [k for k in required_keys if k not in work_space]
331
+ if missing:
332
+ raise gr.Error(f"Missing workspace fields: {missing}. Please run depth estimation (step 2) first.")
333
+
334
  from threeDFixer.pipelines import ThreeDFixerPipeline
335
  from threeDFixer.datasets.utils import (
336
  edge_mask_morph_gradient,
 
383
 
384
  return instance_glb_path, glb
385
 
 
 
 
386
  generated_object_map = {}
387
  run_id = str(uuid.uuid4())
388
 
 
400
  seed = random.randint(0, MAX_SEED)
401
  set_random_seed(seed)
402
 
403
+ H, W = work_space['depth_mask'].shape
404
  rgb_image = rgb_image.resize((W, H), Image.Resampling.LANCZOS)
405
  seg_image = seg_image.resize((W, H), Image.Resampling.LANCZOS)
406
 
407
+ depth_mask = work_space['depth_mask'].detach().cpu().numpy() > 0
408
  seg_image = np.array(seg_image)
409
 
410
  mask_pack = []
 
419
  results = []
420
  trimeshes = []
421
 
422
+ trans = work_space['trans']
423
+ scale = work_space['scale']
424
 
425
  current_scene_path = None
426
  pending_exports = []
 
537
  trimeshes.append(glb)
538
  current_scene_path = export_scene_glb(
539
  trimeshes=trimeshes,
540
+ work_space=work_space['dir'],
541
  scene_name=f"{run_id}_scene_step_{len(trimeshes)}.glb",
542
  )
543
  any_update = True
 
569
  if flushed is not None:
570
  yield flushed
571
 
572
+ est_depth = work_space['depth'].to('cpu')
573
+ c2w = work_space['c2w'].to('cpu')
574
+ K = work_space['K'].to('cpu')
575
 
576
+ intrinsics = work_space['K'].float().to(DEVICE)
577
+ extrinsics = copy.deepcopy(work_space['c2w']).float().to(DEVICE)
578
  extrinsics[:3, 1:3] *= -1
579
 
580
  object_mask = object_mask > 0
 
593
  instance_image, instance_mask, instance_rays_o, instance_rays_d, instance_rays_c, \
594
  instance_rays_t = process_instance_image(image, instance_mask, color_mask, est_depth, K, c2w, CROP_SIZE)
595
 
596
+ save_image(scene_image, os.path.join(work_space['dir'], f'input_scene_image_{instance_name}.png'))
597
+ save_image(scene_image_masked, os.path.join(work_space['dir'], f'input_scene_image_masked_{instance_name}.png'))
598
+ save_image(instance_image, os.path.join(work_space['dir'], f'input_instance_image_{instance_name}.png'))
599
  save_image(
600
  torch.cat([instance_image, instance_mask]),
601
+ os.path.join(work_space['dir'], f'input_instance_image_masked_{instance_name}.png')
602
  )
603
 
604
  pcd_points = (
 
610
  save_projected_colored_pcd(
611
  pcd_points,
612
  repeat(pcd_colors, 'n -> n c', c=3),
613
+ f"{work_space['dir']}/instance_est_depth_{instance_name}.ply"
614
  )
615
 
616
  with torch.no_grad():
 
637
  )
638
 
639
  mp4_path = os.path.abspath(
640
+ os.path.join(work_space['dir'], f"{run_id}_instance_gs_fine_{instance_name}.mp4")
641
  )
642
  poster_path = os.path.abspath(
643
+ os.path.join(work_space['dir'], f"{run_id}_instance_gs_fine_{instance_name}.png")
644
  )
645
 
646
  video = render_utils.render_video(
 
681
  trans=trans,
682
  scale=scale,
683
  rot=rot,
684
+ work_space=work_space['dir'],
685
  instance_name=instance_name,
686
  run_id=run_id,
687
  )
 
711
  if len(ready_items) > 0:
712
  final_scene_path = export_scene_glb(
713
  trimeshes=trimeshes,
714
+ work_space=work_space['dir'],
715
  scene_name=f"{run_id}_scene_final.glb",
716
  )
717
 
 
743
 
744
  # Demo
745
  with gr.Blocks() as demo:
746
+ gr_work_space = gr.State(value=None)
747
  gr.Markdown(MARKDOWN)
748
 
749
  with gr.Column():
 
816
  with gr.Row():
817
  gr.Examples(
818
  examples=EXAMPLES,
 
819
  inputs=[image_prompts, seg_image, seed, randomize_seed, num_inference_steps, guidance_scale, cfg_interval_start, cfg_interval_end, t_rescale],
820
  outputs=[model_output, download_glb, seed],
821
  cache_examples=False,
 
827
  image_prompts,
828
  polygon_refinement,
829
  ],
830
+ outputs=[seg_image, gr_work_space],
831
  ).then(lambda: gr.Button(interactive=True), outputs=[dpt_button])
832
 
833
  dpt_button.click(
834
  run_depth_estimation,
835
  inputs=[
836
  image_prompts,
837
+ seg_image,
838
+ gr_work_space
839
  ],
840
+ outputs=[dpt_model_output, gr_work_space],
841
  ).then(lambda: gr.Button(interactive=True), outputs=[gen_button])
842
 
843
  gen_button.click(
 
851
  guidance_scale,
852
  cfg_interval_start,
853
  cfg_interval_end,
854
+ t_rescale,
855
+ gr_work_space
856
  ],
857
  outputs=[model_output,
858
  stream_output,