Yujia-Zhang0913 committed on
Commit
d45351d
·
1 Parent(s): 9e30547

space gpu func

Browse files
Files changed (1) hide show
  1. app.py +110 -87
app.py CHANGED
@@ -5,6 +5,7 @@ import shutil
5
  from datetime import datetime
6
  import glob
7
  import gc
 
8
  import gradio as gr
9
  import numpy as np
10
  import open3d as o3d
@@ -30,59 +31,59 @@ from vggt.utils.load_fn import load_and_preprocess_images
30
  from vggt.utils.pose_enc import pose_encoding_to_extri_intri
31
  from vggt.utils.geometry import unproject_depth_map_to_point_map
32
 
33
- device = "cuda" if torch.cuda.is_available() else "cpu"
34
-
35
- def run_model(target_dir, model) -> dict:
36
  """
37
- Run the VGGT model on images in the 'target_dir/images' folder and return predictions.
38
  """
39
- print(f"Processing images from {target_dir}")
40
-
41
- # if not torch.cuda.is_available():
42
- # raise ValueError("CUDA is not available. Check your environment.")
43
-
44
- # Move model to device
45
- model = model.to(device)
46
  model.eval()
47
 
48
- # Load and preprocess images
49
- image_names = glob.glob(os.path.join(target_dir, "images", "*"))
50
- image_names = sorted(image_names)
51
- print(f"Found {len(image_names)} images")
52
- if len(image_names) == 0:
53
- raise ValueError("No images found. Check your upload.")
54
-
55
- images = load_and_preprocess_images(image_names).to(device)
56
- print(f"Preprocessed images shape: {images.shape}")
57
-
58
- # Run inference
59
  print("Running inference...")
60
  with torch.no_grad():
61
  if device == "cuda":
62
  with torch.cuda.amp.autocast(dtype=torch.bfloat16):
63
- predictions = model(images)
64
  else:
65
- predictions = model(images)
66
 
67
- # Convert pose encoding to extrinsic and intrinsic matrices
68
  print("Converting pose encoding to extrinsic and intrinsic matrices...")
69
- extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], images.shape[-2:])
70
  predictions["extrinsic"] = extrinsic
71
  predictions["intrinsic"] = intrinsic
72
 
73
- # Convert tensors to numpy
74
  for key in predictions.keys():
75
  if isinstance(predictions[key], torch.Tensor):
76
- predictions[key] = predictions[key].cpu().numpy().squeeze(0) # remove batch dimension
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
- # Generate world points from depth map
79
  print("Computing world points from depth map...")
80
  depth_map = predictions["depth"] # (S, H, W, 1)
81
  world_points = unproject_depth_map_to_point_map(depth_map, predictions["extrinsic"], predictions["intrinsic"])
82
  predictions["world_points_from_depth"] = world_points
83
 
84
- # Clean up
85
- torch.cuda.empty_cache()
86
  return predictions
87
 
88
  def handle_uploads(input_file,input_video,conf_thres,frame_slider,prediction_mode,if_TSDF):
@@ -92,7 +93,6 @@ def handle_uploads(input_file,input_video,conf_thres,frame_slider,prediction_mod
92
  """
93
  start_time = time.time()
94
  gc.collect()
95
- torch.cuda.empty_cache()
96
 
97
  # Create a unique folder name
98
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
@@ -203,7 +203,6 @@ def parse_frames(
203
 
204
  start_time = time.time()
205
  gc.collect()
206
- torch.cuda.empty_cache()
207
 
208
  # Prepare frame_filter dropdown
209
  target_dir_images = os.path.join(target_dir, "images")
@@ -213,8 +212,7 @@ def parse_frames(
213
  frame_filter_choices = ["All"] + all_files
214
 
215
  print("Running run_model...")
216
- with torch.no_grad():
217
- predictions = run_model(target_dir, VGGT_model)
218
 
219
  # Save predictions
220
  prediction_save_path = os.path.join(target_dir, "predictions.npz")
@@ -363,7 +361,6 @@ def parse_frames(
363
  # Cleanup
364
  del predictions
365
  gc.collect()
366
- torch.cuda.empty_cache()
367
  end_time = time.time()
368
  print(f"Total time: {end_time - start_time:.2f} seconds")
369
  return original_points, original_colors, original_normals
@@ -570,29 +567,31 @@ def get_pca_color(feat, start = 0, brightness=1.25, center=True):
570
  color = color.clamp(0.0, 1.0)
571
  return color
572
 
573
- def Concerto_process(target_dir, original_points, original_colors, original_normals, slider_value, bright_value, model_type):
574
- gc.collect()
575
- torch.cuda.empty_cache()
576
- target_dir_pcds = os.path.join(target_dir, "pcds")
 
 
 
577
 
578
- point = {"coord": original_points, "color": original_colors, "normal":original_normals}
579
- original_coord = point["coord"].copy()
580
- original_color = point["color"].copy()
581
- point = transform(point)
 
 
 
 
 
582
 
583
  with torch.inference_mode():
584
- for key in point.keys():
585
- if isinstance(point[key], torch.Tensor) and device=="cuda":
586
- point[key] = point[key].cuda(non_blocking=True)
587
- # model forward:
588
  concerto_start_time = time.time()
589
- if model_type =="Concerto":
590
- point = concerto_model(point)
591
- elif model_type == "Sonata":
592
- point = sonata_model(point)
593
  concerto_end_time = time.time()
 
594
  # upcast point feature
595
- # Point is a structure contains all the information during forward
596
  for _ in range(2):
597
  assert "pooling_parent" in point.keys()
598
  assert "pooling_inverse" in point.keys()
@@ -607,27 +606,38 @@ def Concerto_process(target_dir, original_points, original_colors, original_norm
607
  parent.feat = point.feat[inverse]
608
  point = parent
609
 
610
- # here point is down-sampled by GridSampling in default transform pipeline
611
- # feature of point cloud in original scale can be acquired by:
612
- _ = point.feat[point.inverse]
613
-
614
- # PCA
615
- point_feat = point.feat.cpu().detach().numpy()
616
- np.save(os.path.join(target_dir_pcds,"feat.npy"),point_feat)
617
  pca_start_time = time.time()
618
- pca_color = get_pca_color(point.feat,start = slider_value, brightness=bright_value, center=True)
619
  pca_end_time = time.time()
620
 
621
- # inverse back to original scale before grid sampling
622
- # point.inverse is acquired from the GirdSampling transform
 
 
623
  point_inverse = point.inverse.cpu().detach().numpy()
624
- np.save(os.path.join(target_dir_pcds,"inverse.npy"),point_inverse)
625
- original_pca_color = pca_color[point.inverse]
626
- points = original_coord
627
- colors = original_pca_color.cpu().detach().numpy()
628
 
629
- end_time = time.time()
630
- return points, colors, concerto_end_time - concerto_start_time, pca_end_time - pca_start_time
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
631
 
632
  def gradio_demo(target_dir,pca_slider,bright_slider, model_type, if_color=True, if_normal=True):
633
  target_dir_pcds = os.path.join(target_dir, "pcds")
@@ -651,21 +661,35 @@ def gradio_demo(target_dir,pca_slider,bright_slider, model_type, if_color=True,
651
 
652
  return processed_temp, f"Feature visualization process finished with {concerto_time:.3f} seconds using Concerto inference and {pca_time:.3f} seconds using PCA. Updating visualization."
653
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
654
  def concerto_slider_update(target_dir,pca_slider,bright_slider,is_example,log_output):
655
  if is_example == "True":
656
  return None, log_output
657
  else:
658
  target_dir_pcds = os.path.join(target_dir, "pcds")
659
  if os.path.isfile(os.path.join(target_dir_pcds,"feat.npy")):
 
660
  feat = np.load(os.path.join(target_dir_pcds,"feat.npy"))
661
  inverse = np.load(os.path.join(target_dir_pcds,"inverse.npy"))
662
- feat = torch.tensor(feat, device = device)
663
- inverse = torch.tensor(inverse, device = device)
664
- pca_start_time = time.time()
665
- pca_colors = get_pca_color(feat,start = pca_slider, brightness=bright_slider, center=True)
666
- processed_colors = pca_colors[inverse].cpu().detach().numpy()
667
- pca_end_time = time.time()
668
- pca_time = pca_end_time - pca_start_time
669
  processed_points = np.load(os.path.join(target_dir_pcds,"points.npy"))
670
  processed_normals = np.load(os.path.join(target_dir_pcds,"normals.npy"))
671
  processed_temp = (os.path.join(target_dir_pcds,"processed.glb"))
@@ -673,36 +697,35 @@ def concerto_slider_update(target_dir,pca_slider,bright_slider,is_example,log_ou
673
  feat_data = trimesh.PointCloud(vertices=processed_points, colors=processed_colors, vertex_normals=processed_normals)
674
  feat_3d.add_geometry(feat_data)
675
  feat_3d.export(processed_temp)
676
- log_output = f"Feature visualization process finished with{pca_time:.3f} seconds using PCA. Updating visualization."
677
  else:
678
  processed_temp = None
679
  log_output = "No representations saved, please click PCA generate first."
680
- # processed_temp, log_output = gradio_demo(target_dir,pca_slider,bright_slider)
681
  return processed_temp, log_output
682
 
683
  # set random seed
684
  # (random seed affect pca color, yet change random seed need manual adjustment kmeans)
685
  # (the pca prevent in paper is with another version of cuda and pytorch environment)
686
  concerto.utils.set_seed(53124)
687
- # Load model
688
- if device == 'cuda' and flash_attn is not None:
689
- print("Loading model with Flash Attention on GPU.")
690
- concerto_model = concerto.load("concerto_large", repo_id="Pointcept/Concerto").to(device)
691
- sonata_model = concerto.model.load("sonata", repo_id="facebook/sonata").to(device)
692
  else:
693
- print("Loading model on CPU or without Flash Attention.")
694
  custom_config = dict(
695
  # enc_patch_size=[1024 for _ in range(5)], # reduce patch size if necessary
696
  enable_flash=False,
697
  )
698
  concerto_model = concerto.load(
699
  "concerto_large", repo_id="Pointcept/Concerto", custom_config=custom_config
700
- ).to(device)
701
- sonata_model = concerto.load("sonata", repo_id="facebook/sonata", custom_config=custom_config).to(device)
702
 
703
  transform = concerto.transform.default()
704
 
705
- VGGT_model = VGGT().to(device)
706
  _URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"
707
  VGGT_model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
708
  # VGGT_model.load_state_dict(torch.load("vggt/ckpt/model.pt",weights_only=True))
 
5
  from datetime import datetime
6
  import glob
7
  import gc
8
+ import spaces
9
  import gradio as gr
10
  import numpy as np
11
  import open3d as o3d
 
31
  from vggt.utils.pose_enc import pose_encoding_to_extri_intri
32
  from vggt.utils.geometry import unproject_depth_map_to_point_map
33
 
34
@spaces.GPU
def _gpu_run_vggt_inference(images_tensor):
    """
    GPU-only stage: run VGGT model inference on preprocessed images.

    Decorated with @spaces.GPU so Hugging Face Spaces allocates a GPU only
    for the duration of this call; all file I/O stays in run_model().

    Args:
        images_tensor: preprocessed image batch from load_and_preprocess_images.

    Returns:
        dict of numpy arrays with the batch dimension squeezed out, including
        "pose_enc", "depth", "extrinsic" and "intrinsic".
    """
    # No `global` statement needed: VGGT_model is only read, never rebound.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    images_tensor = images_tensor.to(device)
    model = VGGT_model.to(device)
    model.eval()

    print("Running inference...")
    with torch.no_grad():
        if device == "cuda":
            # torch.autocast replaces the deprecated torch.cuda.amp.autocast;
            # bfloat16 keeps memory low on the transient Spaces GPU.
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                predictions = model(images_tensor)
        else:
            predictions = model(images_tensor)

    print("Converting pose encoding to extrinsic and intrinsic matrices...")
    extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], images_tensor.shape[-2:])
    predictions["extrinsic"] = extrinsic
    predictions["intrinsic"] = intrinsic

    # Move every tensor to CPU numpy and drop the batch dimension so the
    # results can safely leave the GPU worker.
    for key, value in predictions.items():
        if isinstance(value, torch.Tensor):
            predictions[key] = value.cpu().numpy().squeeze(0)

    # Only meaningful when CUDA was actually used.
    if device == "cuda":
        torch.cuda.empty_cache()
    return predictions
64
+
65
def run_model(target_dir) -> dict:
    """
    CPU-GPU hybrid: handle CPU-side file I/O and delegate inference to the GPU.

    Globs every file under ``target_dir/images``, preprocesses them, hands the
    batch to _gpu_run_vggt_inference, then unprojects the predicted depth maps
    into world-space points on the CPU.

    Raises:
        ValueError: if ``target_dir/images`` contains no files.
    """
    print(f"Processing images from {target_dir}")

    image_names = sorted(glob.glob(os.path.join(target_dir, "images", "*")))
    print(f"Found {len(image_names)} images")
    if not image_names:
        raise ValueError("No images found. Check your upload.")

    images = load_and_preprocess_images(image_names)
    print(f"Preprocessed images shape: {images.shape}")

    predictions = _gpu_run_vggt_inference(images)

    print("Computing world points from depth map...")
    depth_map = predictions["depth"]  # (S, H, W, 1)
    world_points = unproject_depth_map_to_point_map(
        depth_map, predictions["extrinsic"], predictions["intrinsic"]
    )
    predictions["world_points_from_depth"] = world_points

    return predictions
88
 
89
  def handle_uploads(input_file,input_video,conf_thres,frame_slider,prediction_mode,if_TSDF):
 
93
  """
94
  start_time = time.time()
95
  gc.collect()
 
96
 
97
  # Create a unique folder name
98
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
 
203
 
204
  start_time = time.time()
205
  gc.collect()
 
206
 
207
  # Prepare frame_filter dropdown
208
  target_dir_images = os.path.join(target_dir, "images")
 
212
  frame_filter_choices = ["All"] + all_files
213
 
214
  print("Running run_model...")
215
+ predictions = run_model(target_dir)
 
216
 
217
  # Save predictions
218
  prediction_save_path = os.path.join(target_dir, "predictions.npz")
 
361
  # Cleanup
362
  del predictions
363
  gc.collect()
 
364
  end_time = time.time()
365
  print(f"Total time: {end_time - start_time:.2f} seconds")
366
  return original_points, original_colors, original_normals
 
567
  color = color.clamp(0.0, 1.0)
568
  return color
569
 
570
+ @spaces.GPU
571
+ def _gpu_concerto_forward_pca(point, model_type, pca_slider, bright_slider):
572
+ """
573
+ GPU-only function: Run Concerto/Sonata model forward pass and PCA.
574
+ """
575
+ global concerto_model, sonata_model
576
+ device = "cuda" if torch.cuda.is_available() else "cpu"
577
 
578
+ for key in point.keys():
579
+ if isinstance(point[key], torch.Tensor):
580
+ point[key] = point[key].to(device, non_blocking=True)
581
+
582
+ if model_type == "Concerto":
583
+ model = concerto_model.to(device)
584
+ elif model_type == "Sonata":
585
+ model = sonata_model.to(device)
586
+ model.eval()
587
 
588
  with torch.inference_mode():
 
 
 
 
589
  concerto_start_time = time.time()
590
+ with torch.inference_mode(False):
591
+ point = model(point)
 
 
592
  concerto_end_time = time.time()
593
+
594
  # upcast point feature
 
595
  for _ in range(2):
596
  assert "pooling_parent" in point.keys()
597
  assert "pooling_inverse" in point.keys()
 
606
  parent.feat = point.feat[inverse]
607
  point = parent
608
 
 
 
 
 
 
 
 
609
  pca_start_time = time.time()
610
+ pca_color = get_pca_color(point.feat, start=pca_slider, brightness=bright_slider, center=True)
611
  pca_end_time = time.time()
612
 
613
+ original_pca_color = pca_color[point.inverse]
614
+
615
+ processed_colors = original_pca_color.cpu().detach().numpy()
616
+ point_feat = point.feat.cpu().detach().numpy()
617
  point_inverse = point.inverse.cpu().detach().numpy()
618
+ concerto_time = concerto_end_time - concerto_start_time
619
+ pca_time = pca_end_time - pca_start_time
 
 
620
 
621
+ torch.cuda.empty_cache()
622
+ return processed_colors, point_feat, point_inverse, concerto_time, pca_time
623
+
624
def Concerto_process(target_dir, original_points, original_colors, original_normals, slider_value, bright_value, model_type):
    """
    Build the transform-pipeline input, run the GPU forward + PCA stage, and
    cache the resulting features so slider moves can recolor without a new
    model forward pass.

    Returns:
        (points, colors, model_forward_seconds, pca_seconds)
    """
    pcds_dir = os.path.join(target_dir, "pcds")

    raw_point = {"coord": original_points, "color": original_colors, "normal": original_normals}
    coords_backup = raw_point["coord"].copy()
    sampled_point = transform(raw_point)

    # GPU: model forward pass + PCA coloring
    colors, feat_np, inverse_np, forward_secs, pca_secs = _gpu_concerto_forward_pca(
        sampled_point, model_type, slider_value, bright_value
    )

    # CPU: persist features so concerto_slider_update can skip the forward pass
    np.save(os.path.join(pcds_dir, "feat.npy"), feat_np)
    np.save(os.path.join(pcds_dir, "inverse.npy"), inverse_np)

    return coords_backup, colors, forward_secs, pca_secs
641
 
642
  def gradio_demo(target_dir,pca_slider,bright_slider, model_type, if_color=True, if_normal=True):
643
  target_dir_pcds = os.path.join(target_dir, "pcds")
 
661
 
662
  return processed_temp, f"Feature visualization process finished with {concerto_time:.3f} seconds using Concerto inference and {pca_time:.3f} seconds using PCA. Updating visualization."
663
 
664
@spaces.GPU
def _gpu_pca_slider_compute(feat_array, inverse_array, pca_slider, bright_slider):
    """
    GPU-only stage: recompute PCA colors from cached features when the
    PCA/brightness sliders move, without re-running the model forward pass.

    Returns:
        (colors, elapsed_seconds) where colors is a numpy array mapped back
        to the original (pre-grid-sampling) point order via the inverse index.
    """
    target = "cuda" if torch.cuda.is_available() else "cpu"
    feats = torch.tensor(feat_array, device=target)
    inverse_idx = torch.tensor(inverse_array, device=target)

    t0 = time.time()
    sampled_colors = get_pca_color(feats, start=pca_slider, brightness=bright_slider, center=True)
    full_colors = sampled_colors[inverse_idx].cpu().detach().numpy()
    elapsed = time.time() - t0
    return full_colors, elapsed
678
+
679
  def concerto_slider_update(target_dir,pca_slider,bright_slider,is_example,log_output):
680
  if is_example == "True":
681
  return None, log_output
682
  else:
683
  target_dir_pcds = os.path.join(target_dir, "pcds")
684
  if os.path.isfile(os.path.join(target_dir_pcds,"feat.npy")):
685
+ # CPU: Load data from disk
686
  feat = np.load(os.path.join(target_dir_pcds,"feat.npy"))
687
  inverse = np.load(os.path.join(target_dir_pcds,"inverse.npy"))
688
+
689
+ # GPU: Compute PCA colors
690
+ processed_colors, pca_time = _gpu_pca_slider_compute(feat, inverse, pca_slider, bright_slider)
691
+
692
+ # CPU: Build mesh
 
 
693
  processed_points = np.load(os.path.join(target_dir_pcds,"points.npy"))
694
  processed_normals = np.load(os.path.join(target_dir_pcds,"normals.npy"))
695
  processed_temp = (os.path.join(target_dir_pcds,"processed.glb"))
 
697
  feat_data = trimesh.PointCloud(vertices=processed_points, colors=processed_colors, vertex_normals=processed_normals)
698
  feat_3d.add_geometry(feat_data)
699
  feat_3d.export(processed_temp)
700
+ log_output = f"Feature visualization process finished with {pca_time:.3f} seconds using PCA. Updating visualization."
701
  else:
702
  processed_temp = None
703
  log_output = "No representations saved, please click PCA generate first."
 
704
  return processed_temp, log_output
705
 
706
# Set random seed.
# (The random seed affects PCA colors; changing it needs a manual kmeans re-adjustment.)
# (The PCA shown in the paper used a different CUDA/PyTorch environment.)
concerto.utils.set_seed(53124)
# Load models onto CPU; they are moved to the GPU on demand inside the
# @spaces.GPU-decorated functions.
if flash_attn is not None:
    print("Loading model with Flash Attention.")
    concerto_model = concerto.load("concerto_large", repo_id="Pointcept/Concerto")
    # NOTE(review): this branch loads sonata via concerto.model.load while the
    # fallback branch uses concerto.load — confirm both entry points are equivalent.
    sonata_model = concerto.model.load("sonata", repo_id="facebook/sonata")
else:
    print("Loading model without Flash Attention.")
    custom_config = dict(
        # enc_patch_size=[1024 for _ in range(5)], # reduce patch size if necessary
        enable_flash=False,
    )
    concerto_model = concerto.load(
        "concerto_large", repo_id="Pointcept/Concerto", custom_config=custom_config
    )
    sonata_model = concerto.load("sonata", repo_id="facebook/sonata", custom_config=custom_config)

# Default preprocessing pipeline; it grid-samples the cloud, and its `inverse`
# indices are later used to map features back to the original point order.
transform = concerto.transform.default()

# VGGT is instantiated on CPU here; _gpu_run_vggt_inference moves it to GPU.
VGGT_model = VGGT()
_URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"
VGGT_model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
# VGGT_model.load_state_dict(torch.load("vggt/ckpt/model.pt",weights_only=True))