notenoughram commited on
Commit
8f5dd0d
ยท
verified ยท
1 Parent(s): 9faeb08

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +103 -94
app.py CHANGED
@@ -1,68 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
- import spaces
3
- # ์œ ๋ฃŒ ํ™˜๊ฒฝ์—์„œ๋Š” spaces import๊ฐ€ ์žˆ์–ด๋„ @spaces.GPU ๋ฐ์ฝ”๋ ˆ์ดํ„ฐ๋งŒ ์•ˆ ์“ฐ๋ฉด ๋ฉ๋‹ˆ๋‹ค.
4
  from gradio_litmodel3d import LitModel3D
5
 
6
- import os
7
- import shutil
8
  os.environ['SPCONV_ALGO'] = 'native'
9
  from typing import *
10
  import torch
11
- import numpy as np
12
  import imageio
 
13
  from easydict import EasyDict as edict
14
  from PIL import Image
 
 
 
 
15
  from trellis.pipelines import TrellisVGGTTo3DPipeline
16
  from trellis.representations import Gaussian, MeshExtractResult
17
  from trellis.utils import render_utils, postprocessing_utils
 
18
 
 
 
19
  from wheels.vggt.vggt.utils.load_fn import load_and_preprocess_images
20
  from wheels.vggt.vggt.utils.pose_enc import pose_encoding_to_extri_intri
21
- import open3d as o3d
22
- from torchvision import transforms as TF
23
- from PIL import Image
24
- import sys
25
- # sys.path.append("wheels") # ํ•„์š”์‹œ ๊ฒฝ๋กœ ์ˆ˜์ •
26
- import cv2 # cv2๊ฐ€ ๋ˆ„๋ฝ๋˜์–ด ์žˆ์„ ์ˆ˜ ์žˆ์–ด ์ถ”๊ฐ€ํ–ˆ์Šต๋‹ˆ๋‹ค.
27
-
28
  from wheels.mast3r.model import AsymmetricMASt3R
29
  from wheels.mast3r.fast_nn import fast_reciprocal_NNs
30
  from wheels.dust3r.dust3r.inference import inference
31
  from wheels.dust3r.dust3r.utils.image import load_images_new
32
- from trellis.utils.general_utils import *
33
- import copy
34
 
35
  MAX_SEED = np.iinfo(np.int32).max
36
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
37
- # TMP_DIR = "tmp/Trellis-demo"
38
- # os.environ['GRADIO_TEMP_DIR'] = 'tmp'
39
  os.makedirs(TMP_DIR, exist_ok=True)
40
  device = "cuda" if torch.cuda.is_available() else "cpu"
41
 
 
42
  def start_session(req: gr.Request):
43
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
44
  os.makedirs(user_dir, exist_ok=True)
45
-
46
-
47
  def end_session(req: gr.Request):
48
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
49
- # [์ˆ˜์ •] ํด๋”๊ฐ€ ์—†์œผ๋ฉด ์‚ญ์ œํ•˜์ง€ ์•Š๋„๋ก ์˜ˆ์™ธ ์ฒ˜๋ฆฌ ์ถ”๊ฐ€
50
  if os.path.exists(user_dir):
51
  shutil.rmtree(user_dir)
52
 
53
- # [์ˆ˜์ •] ์œ ๋ฃŒ 4 GPU ์‚ฌ์šฉ ์‹œ ์ถฉ๋Œ ๋ฐฉ์ง€๋ฅผ ์œ„ํ•ด @spaces.GPU ์ œ๊ฑฐ
54
  def preprocess_image(image: Image.Image) -> Image.Image:
55
- """
56
- Preprocess the input image for 3D generation.
57
- """
58
  processed_image = pipeline.preprocess_image(image)
59
  return processed_image
60
 
61
- # [์ˆ˜์ •] @spaces.GPU ์ œ๊ฑฐ
62
  def preprocess_videos(video: str) -> List[Tuple[Image.Image, str]]:
63
- """
64
- Preprocess the input video for multi-image 3D generation.
65
- """
66
  vid = imageio.get_reader(video, 'ffmpeg')
67
  fps = vid.get_meta_data()['fps']
68
  images = []
@@ -76,16 +76,12 @@ def preprocess_videos(video: str) -> List[Tuple[Image.Image, str]]:
76
  processed_images = [pipeline.preprocess_image(image) for image in images]
77
  return processed_images
78
 
79
- # [์ˆ˜์ •] @spaces.GPU ์ œ๊ฑฐ
80
  def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
81
- """
82
- Preprocess a list of input images for multi-image 3D generation.
83
- """
84
  images = [image[0] for image in images]
85
  processed_images = [pipeline.preprocess_image(image) for image in images]
86
  return processed_images
87
 
88
-
89
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
90
  return {
91
  'gaussian': {
@@ -101,8 +97,7 @@ def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
101
  'faces': mesh.faces.cpu().numpy(),
102
  },
103
  }
104
-
105
-
106
  def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
107
  gs = Gaussian(
108
  aabb=state['gaussian']['aabb'],
@@ -112,7 +107,7 @@ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
112
  opacity_bias=state['gaussian']['opacity_bias'],
113
  scaling_activation=state['gaussian']['scaling_activation'],
114
  )
115
- # [์ˆ˜์ •] ๋ฐ์ดํ„ฐ๋ฅผ ๋กœ๋“œํ•  ๋•Œ ๋ฉ”์ธ GPU(cuda:0)์œผ๋กœ ๋ณด๋ƒ„
116
  gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda:0')
117
  gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda:0')
118
  gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda:0')
@@ -123,18 +118,17 @@ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
123
  vertices=torch.tensor(state['mesh']['vertices'], device='cuda:0'),
124
  faces=torch.tensor(state['mesh']['faces'], device='cuda:0'),
125
  )
126
-
127
  return gs, mesh
128
 
129
-
130
  def get_seed(randomize_seed: bool, seed: int) -> int:
131
- """
132
- Get the random seed for generation.
133
- """
134
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
135
 
136
- def align_camera(num_frames, extrinsic, intrinsic, rend_extrinsics, rend_intrinsics):
 
 
 
137
 
 
138
  extrinsic_tmp = extrinsic.clone()
139
  camera_relative = torch.matmul(extrinsic_tmp[:num_frames,:3,:3].permute(0,2,1), extrinsic_tmp[num_frames:,:3,:3])
140
  camera_relative_angle = torch.acos(((camera_relative[:,0,0] + camera_relative[:,1,1] + camera_relative[:,2,2] - 1) / 2).clamp(-1, 1))
@@ -155,20 +149,17 @@ def align_camera(num_frames, extrinsic, intrinsic, rend_extrinsics, rend_intrins
155
  def refine_pose_mast3r(rend_image_pil, target_image_pil, original_size, fxy, target_extrinsic, rend_depth):
156
  images_mast3r = load_images_new([rend_image_pil, target_image_pil], size=512, square_ok=True)
157
  with torch.no_grad():
158
- # [์ˆ˜์ •] mast3r_model ์ถ”๋ก  ์‹œ cuda:0 ๋ช…์‹œ (๋˜๋Š” ํ• ๋‹น๋œ device)
159
  output = inference([tuple(images_mast3r)], mast3r_model, "cuda:0", batch_size=1, verbose=False)
160
  view1, pred1 = output['view1'], output['pred1']
161
  view2, pred2 = output['view2'], output['pred2']
162
  del output
163
  desc1, desc2 = pred1['desc'].squeeze(0).detach(), pred2['desc'].squeeze(0).detach()
164
 
165
- # find 2D-2D matches between the two images
166
  matches_im0, matches_im1 = fast_reciprocal_NNs(desc1, desc2, subsample_or_initxy1=8,
167
  device="cuda:0", dist='dot', block_size=2**13)
168
 
169
- # ignore small border around the edge
170
  H0, W0 = view1['true_shape'][0]
171
-
172
  valid_matches_im0 = (matches_im0[:, 0] >= 3) & (matches_im0[:, 0] < int(W0) - 3) & (
173
  matches_im0[:, 1] >= 3) & (matches_im0[:, 1] < int(H0) - 3)
174
 
@@ -187,7 +178,7 @@ def refine_pose_mast3r(rend_image_pil, target_image_pil, original_size, fxy, tar
187
  pixel[0] *= scale_x
188
  pixel[1] *= scale_y
189
  depth_map = rend_depth[0]
190
- fx, fy, cx, cy = fxy.item(), fxy.item(), original_size[1]/2, original_size[0]/2 # Example values for focal lengths and principal point
191
  K = np.array([
192
  [fx, 0, cx],
193
  [0, fy, cy],
@@ -242,13 +233,10 @@ def pointcloud_registration(rend_image_pil, target_image_pil, original_size,
242
  del output
243
  desc1, desc2 = pred1['desc'].squeeze(0).detach(), pred2['desc'].squeeze(0).detach()
244
 
245
- # find 2D-2D matches between the two images
246
  matches_im0, matches_im1 = fast_reciprocal_NNs(desc1, desc2, subsample_or_initxy1=8,
247
  device="cuda:0", dist='dot', block_size=2**13)
248
 
249
- # ignore small border around the edge
250
  H0, W0 = view1['true_shape'][0]
251
-
252
  valid_matches_im0 = (matches_im0[:, 0] >= 3) & (matches_im0[:, 0] < int(W0) - 3) & (
253
  matches_im0[:, 1] >= 3) & (matches_im0[:, 1] < int(H0) - 3)
254
 
@@ -267,7 +255,7 @@ def pointcloud_registration(rend_image_pil, target_image_pil, original_size,
267
  pixel[0] *= scale_x
268
  pixel[1] *= scale_y
269
  depth_map = rend_depth[0]
270
- fx, fy, cx, cy = fxy.item(), fxy.item(), original_size[1]/2, original_size[0]/2 # Example values for focal lengths and principal point
271
  K = np.array([
272
  [fx, 0, cx],
273
  [0, fy, cy],
@@ -308,8 +296,7 @@ def pointcloud_registration(rend_image_pil, target_image_pil, original_size,
308
  scale_1 = dist_1[dist_1 < np.percentile(dist_1, 99)].mean()
309
  dist_2 = np.linalg.norm(points_3D_at_pixels_2 - points_3D_at_pixels_2.mean(axis=0), axis=1)
310
  scale_2 = dist_2[dist_2 < np.percentile(dist_2, 99)].mean()
311
- # scale_1 = np.linalg.norm(points_3D_at_pixels - points_3D_at_pixels.mean(axis=0), axis=1).mean()
312
- # scale_2 = np.linalg.norm(points_3D_at_pixels_2 - points_3D_at_pixels_2.mean(axis=0), axis=1).mean()
313
  points_3D_at_pixels_2 = points_3D_at_pixels_2 * (scale_1 / scale_2)
314
  pcd_1 = o3d.geometry.PointCloud()
315
  pcd_1.points = o3d.utility.Vector3dVector(points_3D_at_pixels)
@@ -334,7 +321,7 @@ def pointcloud_registration(rend_image_pil, target_image_pil, original_size,
334
  )
335
  return transformation_matrix, evaluation.fitness
336
 
337
- # [์ˆ˜์ •] @spaces.GPU ์ œ๊ฑฐ
338
  def generate_and_extract_glb(
339
  multiimages: List[Tuple[Image.Image, str]],
340
  seed: int,
@@ -354,9 +341,7 @@ def generate_and_extract_glb(
354
  trellis_stage2_start_t: float,
355
  req: gr.Request,
356
  ) -> Tuple[dict, str, str, str]:
357
- """
358
- Convert an image to a 3D model and extract GLB file.
359
- """
360
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
361
  image_files = [image[0] for image in multiimages]
362
 
@@ -376,23 +361,41 @@ def generate_and_extract_glb(
376
  },
377
  mode=multiimage_algo,
378
  )
 
379
  if refine == "Yes":
380
  try:
381
  images, alphas = load_and_preprocess_images(multiimages)
382
- images, alphas = images.to(device), alphas.to(device)
 
 
383
  with torch.no_grad():
384
  with torch.cuda.amp.autocast(dtype=pipeline.VGGT_dtype):
385
  images = images[None]
386
- # [์ˆ˜์ •] ๋ถ„์‚ฐ ๋ฐฐ์น˜๋œ VGGT_model ์ ‘๊ทผ
387
- vggt = pipeline.VGGT_model if not hasattr(pipeline.VGGT_model, 'module') else pipeline.VGGT_model.module
388
- aggregated_tokens_list, ps_idx = vggt.aggregator(images)
 
 
 
 
 
 
 
 
 
389
  # Predict Cameras
390
  pose_enc = vggt.camera_head(aggregated_tokens_list)[-1]
391
- # Extrinsic and intrinsic matrices, following OpenCV convention (camera from world)
392
  extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc, images.shape[-2:])
393
  # Predict Point Cloud
394
- point_map, point_conf = vggt.point_head(aggregated_tokens_list, images, ps_idx)
395
- del aggregated_tokens_list
 
 
 
 
 
 
 
396
  mask = (alphas[:,0,...][...,None] > 0.8)
397
  conf_threshold = np.percentile(point_conf.cpu().numpy(), 50)
398
  confidence_mask = (point_conf[0] > conf_threshold) & (point_conf[0] > 1e-5)
@@ -442,21 +445,24 @@ def generate_and_extract_glb(
442
  scale = np.linalg.norm(distance, axis=1).max()
443
  voxel_size = 1/64*scale*2
444
  pcd = pcd.voxel_down_sample(voxel_size)
445
- # pcd.points = o3d.utility.Vector3dVector((coords[:,1:].cpu().numpy() + 0.5) / 64 - 0.5)
446
 
447
  for k in range(len(image_files)):
448
  images = torch.stack([TF.ToTensor()(render_image) for render_image in video['color']] + [TF.ToTensor()(image_files[k].convert("RGB"))], dim=0)
449
  # if len(images) == 0:
450
  with torch.no_grad():
451
  with torch.cuda.amp.autocast(dtype=pipeline.VGGT_dtype):
452
- # predictions = vggt_model(images.cuda())
453
- vggt = pipeline.VGGT_model if not hasattr(pipeline.VGGT_model, 'module') else pipeline.VGGT_model.module
454
- aggregated_tokens_list, ps_idx = vggt.aggregator(images[None].cuda())
 
455
  pose_enc = vggt.camera_head(aggregated_tokens_list)[-1]
 
 
 
456
  extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc, images.shape[-2:])
457
  extrinsic, intrinsic = extrinsic[0], intrinsic[0]
458
  extrinsic = torch.cat([extrinsic, torch.tensor([0,0,0,1])[None,None].repeat(extrinsic.shape[0], 1, 1).to(extrinsic.device)], dim=1)
459
- del aggregated_tokens_list, ps_idx
460
 
461
  target_extrinsic, target_intrinsic = align_camera(registration_num_frames, extrinsic, intrinsic, rend_extrinsics, rend_intrinsics)
462
  fxy = target_intrinsic[:,0,0]
@@ -483,9 +489,9 @@ def generate_and_extract_glb(
483
  target_image = images[registration_num_frames:].to(target_extrinsic.device)[j]
484
  original_size = (rend_image.shape[1], rend_image.shape[2])
485
 
486
- import torchvision
487
- torchvision.utils.save_image(rend_image, 'rend_image_{}.png'.format(k))
488
- torchvision.utils.save_image(target_image, 'target_image_{}.png'.format(k))
489
 
490
  mask_rend = (rend_image.detach().cpu() > 0).any(dim=0)
491
  mask_target = (target_image.detach().cpu() > 0).any(dim=0)
@@ -516,7 +522,10 @@ def generate_and_extract_glb(
516
  target_intrinsics = torch.cat(target_intrinsics, dim=0)
517
 
518
  target_fitnesses_filtered = [x for x in target_fitnesses if x <= 1]
519
- idx = target_fitnesses.index(max(target_fitnesses_filtered))
 
 
 
520
  target_transform = target_transforms[idx]
521
  down_pcd_align = copy.deepcopy(down_pcd).transform(target_transform)
522
  # pcd = o3d.geometry.PointCloud()
@@ -526,7 +535,7 @@ def generate_and_extract_glb(
526
  o3d.pipelines.registration.TransformationEstimationPointToPoint(with_scaling=True),
527
  o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration = 10000))
528
  down_pcd_align_2 = copy.deepcopy(down_pcd_align).transform(reg_p2p.transformation)
529
- input_points = torch.tensor(np.asarray(down_pcd_align_2.points)).to(extrinsic.device).float()
530
  input_points = ((input_points + 0.5).clip(0, 1) * 64 - 0.5).to(torch.int32)
531
 
532
  outputs = pipeline.run_refine(
@@ -555,12 +564,10 @@ def generate_and_extract_glb(
555
  )
556
  except Exception as e:
557
  print(f"Error during refinement: {e}")
 
 
 
558
  # Render video
559
- # import uuid
560
- # output_id = str(uuid.uuid4())
561
- # os.makedirs(f"{TMP_DIR}/{output_id}", exist_ok=True)
562
- # video_path = f"{TMP_DIR}/{output_id}/preview.mp4"
563
- # glb_path = f"{TMP_DIR}/{output_id}/mesh.glb"
564
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
565
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
566
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
@@ -581,11 +588,7 @@ def generate_and_extract_glb(
581
  return state, video_path, glb_path, glb_path
582
 
583
 
584
- # [์ˆ˜์ •] @spaces.GPU ์ œ๊ฑฐ
585
  def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
586
- """
587
- Extract a Gaussian splatting file from the generated 3D model.
588
- """
589
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
590
  gs, _ = unpack_state(state)
591
  gaussian_path = os.path.join(user_dir, 'sample.ply')
@@ -595,6 +598,7 @@ def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
595
 
596
 
597
  def prepare_multi_example() -> List[Image.Image]:
 
598
  multi_case = list(set([i.split('_')[0] for i in os.listdir("assets/example_multi_image")]))
599
  images = []
600
  for case in multi_case:
@@ -611,14 +615,12 @@ def prepare_multi_example() -> List[Image.Image]:
611
 
612
 
613
  def split_image(image: Image.Image) -> List[Image.Image]:
614
- """
615
- Split a multi-view image into separate view images.
616
- """
617
  image_np = np.array(image)
618
- # [์•ˆ์ •์„ฑ ์ถ”๊ฐ€] ์ฑ„๋„ 3๊ฐœ์งœ๋ฆฌ(RGB) ์ด๋ฏธ์ง€๊ฐ€ ๋“ค์–ด์˜ฌ ๊ฒฝ์šฐ ์—๋Ÿฌ ๋ฐฉ์ง€
 
619
  if image_np.shape[-1] < 4:
620
  return [preprocess_image(image)]
621
-
622
  alpha = image_np[..., 3]
623
  alpha = np.any(alpha>0, axis=0)
624
  start_pos = np.where(~alpha[:-1] & alpha[1:])[0].tolist()
@@ -775,29 +777,36 @@ with demo:
775
  )
776
 
777
 
778
- # Launch the Gradio app - VRAM 4๊ฐœ ๋ถ„์‚ฐ ์ตœ์ ํ™” ์ ์šฉ
779
  if __name__ == "__main__":
 
780
  pipeline = TrellisVGGTTo3DPipeline.from_pretrained("Stable-X/trellis-vggt-v0-2")
781
 
782
  num_gpus = torch.cuda.device_count()
 
 
783
  if num_gpus >= 4:
784
- # [VRAM ๋ถ„์‚ฐ ํ•ต์‹ฌ] ๊ฐ ๋ชจ๋ธ์„ ๋ฌผ๋ฆฌ์ ์œผ๋กœ ๋‹ค๋ฅธ GPU์— ๋กœ๋“œํ•˜์—ฌ OOM ๋ฐฉ์ง€
785
- pipeline.to("cuda:0") # ๋ฉ”์ธ ํŒŒ์ดํ”„๋ผ์ธ
 
 
786
  if hasattr(pipeline, 'VGGT_model'):
787
  pipeline.VGGT_model.to("cuda:1")
788
  if hasattr(pipeline, 'birefnet_model'):
789
  pipeline.birefnet_model.to("cuda:2")
790
- # ๊ฐ€์žฅ ๋ฌด๊ฑฐ์šด ๋””์ฝ”๋”๋“ค์„ 3๋ฒˆ GPU๋กœ ๊ฒฉ๋ฆฌ
 
791
  if hasattr(pipeline, 'slat_decoder'):
792
  pipeline.slat_decoder.to("cuda:3")
793
  if hasattr(pipeline, 'sparse_structure_decoder'):
794
  pipeline.sparse_structure_decoder.to("cuda:3")
795
 
796
- # Mast3r ๋ชจ๋ธ ๋กœ๋“œ (cuda:0 ์‚ฌ์šฉ)
797
  mast3r_model = AsymmetricMASt3R.from_pretrained("naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric").to("cuda:0").eval()
798
- print("Success: 4 GPU VRAM Sharding Activated.")
799
  else:
 
800
  pipeline.cuda()
801
  mast3r_model = AsymmetricMASt3R.from_pretrained("naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric").cuda().eval()
802
-
803
  demo.launch()
 
1
+ import os
2
+ import sys
3
+ import subprocess
4
+ import shutil
5
+ import numpy as np
6
+
7
+ # [1] Open3D ์—†์œผ๋ฉด ์ž๋™ ์„ค์น˜ (ModuleNotFoundError ํ•ด๊ฒฐ)
8
+ try:
9
+ import open3d as o3d
10
+ except ImportError:
11
+ print("Open3D not found. Installing...")
12
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "open3d"])
13
+ import open3d as o3d
14
+
15
  import gradio as gr
16
+ # @spaces.GPU ์ œ๊ฑฐ๋ฅผ ์œ„ํ•ด spaces๋Š” import ํ•˜๋˜ ๋ฐ์ฝ”๋ ˆ์ดํ„ฐ๋Š” ์•ˆ ์”๋‹ˆ๋‹ค.
17
+ import spaces
18
  from gradio_litmodel3d import LitModel3D
19
 
 
 
20
  os.environ['SPCONV_ALGO'] = 'native'
21
  from typing import *
22
  import torch
 
23
  import imageio
24
+ import cv2
25
  from easydict import EasyDict as edict
26
  from PIL import Image
27
+ from torchvision import transforms as TF
28
+ import copy
29
+
30
+ # Trellis ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ
31
  from trellis.pipelines import TrellisVGGTTo3DPipeline
32
  from trellis.representations import Gaussian, MeshExtractResult
33
  from trellis.utils import render_utils, postprocessing_utils
34
+ from trellis.utils.general_utils import *
35
 
36
+ # ์ปค์Šคํ…€ ํœ  ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ (์›๋ณธ ๋กœ์ง์šฉ)
37
+ sys.path.append("wheels")
38
  from wheels.vggt.vggt.utils.load_fn import load_and_preprocess_images
39
  from wheels.vggt.vggt.utils.pose_enc import pose_encoding_to_extri_intri
 
 
 
 
 
 
 
40
  from wheels.mast3r.model import AsymmetricMASt3R
41
  from wheels.mast3r.fast_nn import fast_reciprocal_NNs
42
  from wheels.dust3r.dust3r.inference import inference
43
  from wheels.dust3r.dust3r.utils.image import load_images_new
 
 
44
 
45
  MAX_SEED = np.iinfo(np.int32).max
46
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
 
 
47
  os.makedirs(TMP_DIR, exist_ok=True)
48
  device = "cuda" if torch.cuda.is_available() else "cpu"
49
 
50
+ # --- ์„ธ์…˜ ๊ด€๋ฆฌ ---
51
  def start_session(req: gr.Request):
52
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
53
  os.makedirs(user_dir, exist_ok=True)
54
+
 
55
  def end_session(req: gr.Request):
56
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
 
57
  if os.path.exists(user_dir):
58
  shutil.rmtree(user_dir)
59
 
60
+ # --- ์ „์ฒ˜๋ฆฌ ํ•จ์ˆ˜๋“ค (@spaces.GPU ์ œ๊ฑฐ๋จ) ---
61
  def preprocess_image(image: Image.Image) -> Image.Image:
 
 
 
62
  processed_image = pipeline.preprocess_image(image)
63
  return processed_image
64
 
 
65
  def preprocess_videos(video: str) -> List[Tuple[Image.Image, str]]:
 
 
 
66
  vid = imageio.get_reader(video, 'ffmpeg')
67
  fps = vid.get_meta_data()['fps']
68
  images = []
 
76
  processed_images = [pipeline.preprocess_image(image) for image in images]
77
  return processed_images
78
 
 
79
  def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
 
 
 
80
  images = [image[0] for image in images]
81
  processed_images = [pipeline.preprocess_image(image) for image in images]
82
  return processed_images
83
 
84
+ # --- State ๊ด€๋ฆฌ ---
85
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
86
  return {
87
  'gaussian': {
 
97
  'faces': mesh.faces.cpu().numpy(),
98
  },
99
  }
100
+
 
101
  def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
102
  gs = Gaussian(
103
  aabb=state['gaussian']['aabb'],
 
107
  opacity_bias=state['gaussian']['opacity_bias'],
108
  scaling_activation=state['gaussian']['scaling_activation'],
109
  )
110
+ # ๋กœ๋“œ ์‹œ ๋ฉ”์ธ GPU(cuda:0)๋กœ ๋ณต๊ท€
111
  gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda:0')
112
  gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda:0')
113
  gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda:0')
 
118
  vertices=torch.tensor(state['mesh']['vertices'], device='cuda:0'),
119
  faces=torch.tensor(state['mesh']['faces'], device='cuda:0'),
120
  )
 
121
  return gs, mesh
122
 
 
123
  def get_seed(randomize_seed: bool, seed: int) -> int:
 
 
 
124
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
125
 
126
+ # --- [์›๋ณธ ๋ณต๊ตฌ] ์ •๋ฐ€ ์ˆ˜ํ•™/ํฌ์ฆˆ ํ•จ์ˆ˜๋“ค ---
127
+ def perform_rodrigues_transformation(rvec):
128
+ R, _ = cv2.Rodrigues(rvec)
129
+ return R
130
 
131
+ def align_camera(num_frames, extrinsic, intrinsic, rend_extrinsics, rend_intrinsics):
132
  extrinsic_tmp = extrinsic.clone()
133
  camera_relative = torch.matmul(extrinsic_tmp[:num_frames,:3,:3].permute(0,2,1), extrinsic_tmp[num_frames:,:3,:3])
134
  camera_relative_angle = torch.acos(((camera_relative[:,0,0] + camera_relative[:,1,1] + camera_relative[:,2,2] - 1) / 2).clamp(-1, 1))
 
149
  def refine_pose_mast3r(rend_image_pil, target_image_pil, original_size, fxy, target_extrinsic, rend_depth):
150
  images_mast3r = load_images_new([rend_image_pil, target_image_pil], size=512, square_ok=True)
151
  with torch.no_grad():
152
+ # [GPU ์ˆ˜์ •] mast3r ๋ชจ๋ธ ์ถ”๋ก  ์‹œ cuda:0 ๋ช…์‹œ
153
  output = inference([tuple(images_mast3r)], mast3r_model, "cuda:0", batch_size=1, verbose=False)
154
  view1, pred1 = output['view1'], output['pred1']
155
  view2, pred2 = output['view2'], output['pred2']
156
  del output
157
  desc1, desc2 = pred1['desc'].squeeze(0).detach(), pred2['desc'].squeeze(0).detach()
158
 
 
159
  matches_im0, matches_im1 = fast_reciprocal_NNs(desc1, desc2, subsample_or_initxy1=8,
160
  device="cuda:0", dist='dot', block_size=2**13)
161
 
 
162
  H0, W0 = view1['true_shape'][0]
 
163
  valid_matches_im0 = (matches_im0[:, 0] >= 3) & (matches_im0[:, 0] < int(W0) - 3) & (
164
  matches_im0[:, 1] >= 3) & (matches_im0[:, 1] < int(H0) - 3)
165
 
 
178
  pixel[0] *= scale_x
179
  pixel[1] *= scale_y
180
  depth_map = rend_depth[0]
181
+ fx, fy, cx, cy = fxy.item(), fxy.item(), original_size[1]/2, original_size[0]/2
182
  K = np.array([
183
  [fx, 0, cx],
184
  [0, fy, cy],
 
233
  del output
234
  desc1, desc2 = pred1['desc'].squeeze(0).detach(), pred2['desc'].squeeze(0).detach()
235
 
 
236
  matches_im0, matches_im1 = fast_reciprocal_NNs(desc1, desc2, subsample_or_initxy1=8,
237
  device="cuda:0", dist='dot', block_size=2**13)
238
 
 
239
  H0, W0 = view1['true_shape'][0]
 
240
  valid_matches_im0 = (matches_im0[:, 0] >= 3) & (matches_im0[:, 0] < int(W0) - 3) & (
241
  matches_im0[:, 1] >= 3) & (matches_im0[:, 1] < int(H0) - 3)
242
 
 
255
  pixel[0] *= scale_x
256
  pixel[1] *= scale_y
257
  depth_map = rend_depth[0]
258
+ fx, fy, cx, cy = fxy.item(), fxy.item(), original_size[1]/2, original_size[0]/2
259
  K = np.array([
260
  [fx, 0, cx],
261
  [0, fy, cy],
 
296
  scale_1 = dist_1[dist_1 < np.percentile(dist_1, 99)].mean()
297
  dist_2 = np.linalg.norm(points_3D_at_pixels_2 - points_3D_at_pixels_2.mean(axis=0), axis=1)
298
  scale_2 = dist_2[dist_2 < np.percentile(dist_2, 99)].mean()
299
+
 
300
  points_3D_at_pixels_2 = points_3D_at_pixels_2 * (scale_1 / scale_2)
301
  pcd_1 = o3d.geometry.PointCloud()
302
  pcd_1.points = o3d.utility.Vector3dVector(points_3D_at_pixels)
 
321
  )
322
  return transformation_matrix, evaluation.fitness
323
 
324
+ # [์ˆ˜์ •] ๋ฉ”์ธ ์ƒ์„ฑ ํ•จ์ˆ˜ (Refine ๋กœ์ง 100% ๋ณต๊ตฌ + VRAM ๋ถ„์‚ฐ ์ ‘๊ทผ)
325
  def generate_and_extract_glb(
326
  multiimages: List[Tuple[Image.Image, str]],
327
  seed: int,
 
341
  trellis_stage2_start_t: float,
342
  req: gr.Request,
343
  ) -> Tuple[dict, str, str, str]:
344
+
 
 
345
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
346
  image_files = [image[0] for image in multiimages]
347
 
 
361
  },
362
  mode=multiimage_algo,
363
  )
364
+
365
  if refine == "Yes":
366
  try:
367
  images, alphas = load_and_preprocess_images(multiimages)
368
+ # ์ด๋ฏธ์ง€๋ฅผ cuda:0 (๋˜๋Š” ์‚ฌ์šฉ ๊ฐ€๋Šฅํ•œ GPU)์œผ๋กœ ์ด๋™
369
+ images, alphas = images.to("cuda:0"), alphas.to("cuda:0")
370
+
371
  with torch.no_grad():
372
  with torch.cuda.amp.autocast(dtype=pipeline.VGGT_dtype):
373
  images = images[None]
374
+ # [VRAM ๋ถ„์‚ฐ ๋Œ€์‘] VGGT_model์ด ๋‹ค๋ฅธ GPU์— ์žˆ์–ด๋„ ํ˜ธ์ถœ ๊ฐ€๋Šฅํ•˜๋„๋ก ์ฒ˜๋ฆฌ
375
+ if hasattr(pipeline.VGGT_model, 'module'):
376
+ vggt = pipeline.VGGT_model.module
377
+ else:
378
+ vggt = pipeline.VGGT_model
379
+
380
+ # ์ž…๋ ฅ ์ด๋ฏธ์ง€๋ฅผ VGGT ๋ชจ๋ธ์ด ์žˆ๋Š” GPU๋กœ ์ž„์‹œ ์ด๋™
381
+ target_device = next(vggt.parameters()).device
382
+ images_in = images.to(target_device)
383
+
384
+ aggregated_tokens_list, ps_idx = vggt.aggregator(images_in)
385
+
386
  # Predict Cameras
387
  pose_enc = vggt.camera_head(aggregated_tokens_list)[-1]
 
388
  extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc, images.shape[-2:])
389
  # Predict Point Cloud
390
+ point_map, point_conf = vggt.point_head(aggregated_tokens_list, images_in, ps_idx)
391
+
392
+ # ๊ฒฐ๊ณผ๋ฌผ์„ ๋‹ค์‹œ CPU/๋ฉ”์ธ GPU๋กœ ๊ฐ€์ ธ์™€์„œ ์ฒ˜๋ฆฌ
393
+ point_map = point_map.to("cuda:0")
394
+ point_conf = point_conf.to("cuda:0")
395
+ extrinsic = extrinsic.to("cuda:0")
396
+ intrinsic = intrinsic.to("cuda:0")
397
+ del aggregated_tokens_list, images_in
398
+
399
  mask = (alphas[:,0,...][...,None] > 0.8)
400
  conf_threshold = np.percentile(point_conf.cpu().numpy(), 50)
401
  confidence_mask = (point_conf[0] > conf_threshold) & (point_conf[0] > 1e-5)
 
445
  scale = np.linalg.norm(distance, axis=1).max()
446
  voxel_size = 1/64*scale*2
447
  pcd = pcd.voxel_down_sample(voxel_size)
 
448
 
449
  for k in range(len(image_files)):
450
  images = torch.stack([TF.ToTensor()(render_image) for render_image in video['color']] + [TF.ToTensor()(image_files[k].convert("RGB"))], dim=0)
451
  # if len(images) == 0:
452
  with torch.no_grad():
453
  with torch.cuda.amp.autocast(dtype=pipeline.VGGT_dtype):
454
+ # [VRAM ๋ถ„์‚ฐ ๋Œ€์‘] VGGT_model ํ˜ธ์ถœ
455
+ target_device = next(vggt.parameters()).device
456
+ images_in = images[None].to(target_device)
457
+ aggregated_tokens_list, ps_idx = vggt.aggregator(images_in)
458
  pose_enc = vggt.camera_head(aggregated_tokens_list)[-1]
459
+
460
+ # ๊ฒฐ๊ณผ ํšŒ์ˆ˜
461
+ pose_enc = pose_enc.to("cuda:0")
462
  extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc, images.shape[-2:])
463
  extrinsic, intrinsic = extrinsic[0], intrinsic[0]
464
  extrinsic = torch.cat([extrinsic, torch.tensor([0,0,0,1])[None,None].repeat(extrinsic.shape[0], 1, 1).to(extrinsic.device)], dim=1)
465
+ del aggregated_tokens_list, ps_idx, images_in
466
 
467
  target_extrinsic, target_intrinsic = align_camera(registration_num_frames, extrinsic, intrinsic, rend_extrinsics, rend_intrinsics)
468
  fxy = target_intrinsic[:,0,0]
 
489
  target_image = images[registration_num_frames:].to(target_extrinsic.device)[j]
490
  original_size = (rend_image.shape[1], rend_image.shape[2])
491
 
492
+ # import torchvision
493
+ # torchvision.utils.save_image(rend_image, 'rend_image_{}.png'.format(k))
494
+ # torchvision.utils.save_image(target_image, 'target_image_{}.png'.format(k))
495
 
496
  mask_rend = (rend_image.detach().cpu() > 0).any(dim=0)
497
  mask_target = (target_image.detach().cpu() > 0).any(dim=0)
 
522
  target_intrinsics = torch.cat(target_intrinsics, dim=0)
523
 
524
  target_fitnesses_filtered = [x for x in target_fitnesses if x <= 1]
525
+ if len(target_fitnesses_filtered) > 0:
526
+ idx = target_fitnesses.index(max(target_fitnesses_filtered))
527
+ else:
528
+ idx = 0
529
  target_transform = target_transforms[idx]
530
  down_pcd_align = copy.deepcopy(down_pcd).transform(target_transform)
531
  # pcd = o3d.geometry.PointCloud()
 
535
  o3d.pipelines.registration.TransformationEstimationPointToPoint(with_scaling=True),
536
  o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration = 10000))
537
  down_pcd_align_2 = copy.deepcopy(down_pcd_align).transform(reg_p2p.transformation)
538
+ input_points = torch.tensor(np.asarray(down_pcd_align_2.points)).to("cuda:0").float()
539
  input_points = ((input_points + 0.5).clip(0, 1) * 64 - 0.5).to(torch.int32)
540
 
541
  outputs = pipeline.run_refine(
 
564
  )
565
  except Exception as e:
566
  print(f"Error during refinement: {e}")
567
+ import traceback
568
+ traceback.print_exc()
569
+
570
  # Render video
 
 
 
 
 
571
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
572
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
573
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
 
588
  return state, video_path, glb_path, glb_path
589
 
590
 
 
591
  def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
 
 
 
592
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
593
  gs, _ = unpack_state(state)
594
  gaussian_path = os.path.join(user_dir, 'sample.ply')
 
598
 
599
 
600
  def prepare_multi_example() -> List[Image.Image]:
601
+ if not os.path.exists("assets/example_multi_image"): return []
602
  multi_case = list(set([i.split('_')[0] for i in os.listdir("assets/example_multi_image")]))
603
  images = []
604
  for case in multi_case:
 
615
 
616
 
617
  def split_image(image: Image.Image) -> List[Image.Image]:
 
 
 
618
  image_np = np.array(image)
619
+
620
+ # [์ˆ˜์ •] ์ฑ„๋„ ์ฒดํฌ: RGBA(4)๊ฐ€ ์•„๋‹ ๊ฒฝ์šฐ ๋‹จ์ผ ์ฒ˜๋ฆฌ
621
  if image_np.shape[-1] < 4:
622
  return [preprocess_image(image)]
623
+
624
  alpha = image_np[..., 3]
625
  alpha = np.any(alpha>0, axis=0)
626
  start_pos = np.where(~alpha[:-1] & alpha[1:])[0].tolist()
 
777
  )
778
 
779
 
780
+ # [์ˆ˜์ •] 4 GPU VRAM ๋ถ„์‚ฐ ๋ฐฐ์น˜ ๋กœ์ง (๋ฉ”์ธ ์‹คํ–‰๋ถ€)
781
  if __name__ == "__main__":
782
+ # 1. ํŒŒ์ดํ”„๋ผ์ธ ๋กœ๋“œ (๋ฉ”์ธ GPU)
783
  pipeline = TrellisVGGTTo3DPipeline.from_pretrained("Stable-X/trellis-vggt-v0-2")
784
 
785
  num_gpus = torch.cuda.device_count()
786
+ print(f"์‹œ์Šคํ…œ์—์„œ ๊ฐ์ง€๋œ GPU ๊ฐœ์ˆ˜: {num_gpus}")
787
+
788
  if num_gpus >= 4:
789
+ # [ํ•ต์‹ฌ] VRAM OOM ๋ฐฉ์ง€๋ฅผ ์œ„ํ•ด ๋ชจ๋ธ์„ 4๊ฐœ GPU์— ์ˆ˜๋™์œผ๋กœ ๋ถ„์‚ฐ
790
+ pipeline.to("cuda:0") # ์‰˜์€ 0๋ฒˆ
791
+
792
+ # ๋ชจ๋ธ ์ด๋™
793
  if hasattr(pipeline, 'VGGT_model'):
794
  pipeline.VGGT_model.to("cuda:1")
795
  if hasattr(pipeline, 'birefnet_model'):
796
  pipeline.birefnet_model.to("cuda:2")
797
+
798
+ # ๊ฐ€์žฅ ๋ฌด๊ฑฐ์šด ๋””์ฝ”๋”๋“ค์„ 3๋ฒˆ์œผ๋กœ ๊ฒฉ๋ฆฌ
799
  if hasattr(pipeline, 'slat_decoder'):
800
  pipeline.slat_decoder.to("cuda:3")
801
  if hasattr(pipeline, 'sparse_structure_decoder'):
802
  pipeline.sparse_structure_decoder.to("cuda:3")
803
 
804
+ # Refine์šฉ Mast3r ๋ชจ๋ธ์€ 0๋ฒˆ (ํ˜น์€ ๋ฉ”๋ชจ๋ฆฌ ์—ฌ์œ  ์žˆ๋Š” ๊ณณ)
805
  mast3r_model = AsymmetricMASt3R.from_pretrained("naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric").to("cuda:0").eval()
806
+ print("--- 4 GPU VRAM ๋ถ„์‚ฐ ๋ฐฐ์น˜ ์™„๋ฃŒ ---")
807
  else:
808
+ # GPU๊ฐ€ ๋ถ€์กฑํ•˜๋ฉด ์ผ๋ฐ˜ ๋กœ๋“œ
809
  pipeline.cuda()
810
  mast3r_model = AsymmetricMASt3R.from_pretrained("naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric").cuda().eval()
811
+
812
  demo.launch()