Yujia-Zhang0913 committed on
Commit
abfd173
·
1 Parent(s): 27fcb94
Files changed (1) hide show
  1. app.py +39 -36
app.py CHANGED
@@ -31,6 +31,32 @@ from vggt.utils.load_fn import load_and_preprocess_images
31
  from vggt.utils.pose_enc import pose_encoding_to_extri_intri
32
  from vggt.utils.geometry import unproject_depth_map_to_point_map
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  @spaces.GPU
35
  def _gpu_run_vggt_inference(images_tensor):
36
  """
@@ -568,27 +594,23 @@ def get_pca_color(feat, start = 0, brightness=1.25, center=True):
568
  return color
569
 
570
  @spaces.GPU
571
- def _gpu_concerto_forward_pca(point, model_type, pca_slider, bright_slider):
572
  """
573
  GPU-only function: Run Concerto/Sonata model forward pass and PCA.
574
  """
575
- global concerto_model, sonata_model
576
  device = "cuda" if torch.cuda.is_available() else "cpu"
577
 
578
  for key in point.keys():
579
  if isinstance(point[key], torch.Tensor):
580
  point[key] = point[key].to(device, non_blocking=True)
581
 
582
- if model_type == "Concerto":
583
- model = concerto_model.to(device)
584
- elif model_type == "Sonata":
585
- model = sonata_model.to(device)
586
- model.eval()
587
 
588
  with torch.inference_mode():
589
  concerto_start_time = time.time()
590
  with torch.inference_mode(False):
591
- point = model(point)
592
  concerto_end_time = time.time()
593
 
594
  # upcast point feature
@@ -628,9 +650,17 @@ def Concerto_process(target_dir, original_points, original_colors, original_norm
628
  original_coord = point["coord"].copy()
629
  point = transform(point)
630
 
 
 
 
 
 
 
 
 
631
  # GPU: Run model forward + PCA
632
  processed_colors, point_feat, point_inverse, concerto_time, pca_time = _gpu_concerto_forward_pca(
633
- point, model_type, slider_value, bright_value
634
  )
635
 
636
  # CPU: Save features
@@ -705,33 +735,6 @@ def concerto_slider_update(target_dir,pca_slider,bright_slider,is_example,log_ou
705
  log_output = "No representations saved, please click PCA generate first."
706
  return processed_temp, log_output
707
 
708
- # set random seed
709
- # (random seed affect pca color, yet change random seed need manual adjustment kmeans)
710
- # (the pca prevent in paper is with another version of cuda and pytorch environment)
711
- concerto.utils.set_seed(53124)
712
- # Load model (to CPU; moved to GPU on-demand via @spaces.GPU)
713
- if flash_attn is not None:
714
- print("Loading model with Flash Attention.")
715
- concerto_model = concerto.load("concerto_large", repo_id="Pointcept/Concerto")
716
- sonata_model = concerto.model.load("sonata", repo_id="facebook/sonata")
717
- else:
718
- print("Loading model without Flash Attention.")
719
- custom_config = dict(
720
- # enc_patch_size=[1024 for _ in range(5)], # reduce patch size if necessary
721
- enable_flash=False,
722
- )
723
- concerto_model = concerto.load(
724
- "concerto_large", repo_id="Pointcept/Concerto", custom_config=custom_config
725
- )
726
- sonata_model = concerto.load("sonata", repo_id="facebook/sonata", custom_config=custom_config)
727
-
728
- transform = concerto.transform.default()
729
-
730
- VGGT_model = VGGT()
731
- _URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"
732
- VGGT_model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
733
- # VGGT_model.load_state_dict(torch.load("vggt/ckpt/model.pt",weights_only=True))
734
-
735
 
736
  BASE_URL = "https://huggingface.co/datasets/pointcept-bot/concerto_huggingface_demo/resolve/main/"
737
  def get_url(path):
 
31
  from vggt.utils.pose_enc import pose_encoding_to_extri_intri
32
  from vggt.utils.geometry import unproject_depth_map_to_point_map
33
 
34
+ # set random seed
35
+ # (the random seed affects the PCA colors; changing the seed requires manually re-adjusting kmeans)
36
+ # (the PCA presented in the paper was produced with a different CUDA and PyTorch environment)
37
+ concerto.utils.set_seed(53124)
38
+ # Load model (to CPU; moved to GPU on-demand via @spaces.GPU)
39
+ if flash_attn is not None:
40
+ print("Loading model with Flash Attention.")
41
+ concerto_model = concerto.load("concerto_large", repo_id="Pointcept/Concerto")
42
+ sonata_model = concerto.model.load("sonata", repo_id="facebook/sonata")
43
+ else:
44
+ print("Loading model without Flash Attention.")
45
+ custom_config = dict(
46
+ # enc_patch_size=[1024 for _ in range(5)], # reduce patch size if necessary
47
+ enable_flash=False,
48
+ )
49
+ concerto_model = concerto.load(
50
+ "concerto_large", repo_id="Pointcept/Concerto", custom_config=custom_config
51
+ )
52
+ sonata_model = concerto.load("sonata", repo_id="facebook/sonata", custom_config=custom_config)
53
+
54
+ transform = concerto.transform.default()
55
+
56
+ VGGT_model = VGGT()
57
+ _URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"
58
+ VGGT_model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
59
+
60
  @spaces.GPU
61
  def _gpu_run_vggt_inference(images_tensor):
62
  """
 
594
  return color
595
 
596
  @spaces.GPU
597
+ def _gpu_concerto_forward_pca(point, concerto_model_, pca_slider, bright_slider):
598
  """
599
  GPU-only function: Run Concerto/Sonata model forward pass and PCA.
600
  """
 
601
  device = "cuda" if torch.cuda.is_available() else "cpu"
602
 
603
  for key in point.keys():
604
  if isinstance(point[key], torch.Tensor):
605
  point[key] = point[key].to(device, non_blocking=True)
606
 
607
+ concerto_model_ = concerto_model_.to(device)
608
+ concerto_model_.eval()
 
 
 
609
 
610
  with torch.inference_mode():
611
  concerto_start_time = time.time()
612
  with torch.inference_mode(False):
613
+ point = concerto_model_(point)
614
  concerto_end_time = time.time()
615
 
616
  # upcast point feature
 
650
  original_coord = point["coord"].copy()
651
  point = transform(point)
652
 
653
+ # Select model based on type
654
+ if model_type == "Concerto":
655
+ selected_model = concerto_model
656
+ elif model_type == "Sonata":
657
+ selected_model = sonata_model
658
+ else:
659
+ selected_model = concerto_model
660
+
661
  # GPU: Run model forward + PCA
662
  processed_colors, point_feat, point_inverse, concerto_time, pca_time = _gpu_concerto_forward_pca(
663
+ point, selected_model, slider_value, bright_value
664
  )
665
 
666
  # CPU: Save features
 
735
  log_output = "No representations saved, please click PCA generate first."
736
  return processed_temp, log_output
737
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
738
 
739
  BASE_URL = "https://huggingface.co/datasets/pointcept-bot/concerto_huggingface_demo/resolve/main/"
740
  def get_url(path):