dgfx committed on
Commit
eeb8299
·
verified ·
1 Parent(s): c45ea0a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -78
app.py CHANGED
@@ -15,20 +15,21 @@ subprocess.run(shlex.split("pip install wheel/curope-0.0.0-cp310-cp310-linux_x86
15
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
16
  os.sys.path.append(os.path.abspath(os.path.join(BASE_DIR, "submodules", "dust3r")))
17
  # os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
 
18
  from dust3r.inference import inference
19
  from dust3r.model import AsymmetricCroCo3DStereo
20
  from dust3r.utils.device import to_numpy
21
  from dust3r.image_pairs import make_pairs
22
  from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
23
  from utils.dust3r_utils import compute_global_alignment, load_images, storePly, save_colmap_cameras, save_colmap_images
24
-
25
  from argparse import ArgumentParser, Namespace
26
  from arguments import ModelParams, PipelineParams, OptimizationParams
27
  from train_joint import training
28
  from render_by_interp import render_sets
 
29
  GRADIO_CACHE_FOLDER = './gradio_cache_folder'
30
- #############################################################################################################################################
31
 
 
32
 
33
  def get_dust3r_args_parser():
34
  parser = argparse.ArgumentParser()
@@ -41,17 +42,15 @@ def get_dust3r_args_parser():
41
  parser.add_argument("--niter", type=int, default=300)
42
  parser.add_argument("--focal_avg", type=bool, default=True)
43
  parser.add_argument("--n_views", type=int, default=3)
44
- parser.add_argument("--base_path", type=str, default=GRADIO_CACHE_FOLDER)
45
  return parser
46
 
47
-
48
- @spaces.GPU(duration=120)
49
  def process(inputfiles, input_path=None):
50
-
51
  if input_path is not None:
52
  imgs_path = './assets/example/' + input_path
53
  imgs_names = sorted(os.listdir(imgs_path))
54
-
55
  inputfiles = []
56
  for imgs_name in imgs_names:
57
  file_path = os.path.join(imgs_path, imgs_name)
@@ -60,45 +59,48 @@ def process(inputfiles, input_path=None):
60
  print(inputfiles)
61
 
62
  # ------ (1) Coarse Geometric Initialization ------
63
- # os.system(f"rm -rf {GRADIO_CACHE_FOLDER}")
64
  parser = get_dust3r_args_parser()
65
  opt = parser.parse_args()
66
-
67
  tmp_user_folder = str(uuid.uuid4()).replace("-", "")
68
  opt.img_base_path = os.path.join(opt.base_path, tmp_user_folder)
69
- img_folder_path = os.path.join(opt.img_base_path, "images")
70
-
71
- img_folder_path = os.path.join(opt.img_base_path, "images")
72
  model = AsymmetricCroCo3DStereo.from_pretrained(opt.model_path).to(opt.device)
73
  os.makedirs(img_folder_path, exist_ok=True)
74
-
75
- opt.n_views = len(inputfiles)
76
  if opt.n_views == 1:
77
  raise gr.Error("The number of input images should be greater than 1.")
 
78
  print("Multiple images: ", inputfiles)
79
  for image_path in inputfiles:
80
  if input_path is not None:
81
  shutil.copy(image_path, img_folder_path)
82
  else:
83
  shutil.move(image_path, img_folder_path)
 
84
  train_img_list = sorted(os.listdir(img_folder_path))
85
  assert len(train_img_list)==opt.n_views, f"Number of images in the folder is not equal to {opt.n_views}"
86
- images, ori_size, imgs_resolution = load_images(img_folder_path, size=512)
 
87
  resolutions_are_equal = len(set(imgs_resolution)) == 1
88
- if resolutions_are_equal == False:
 
89
  raise gr.Error("The resolution of the input image should be the same.")
 
90
  print("ori_size", ori_size)
91
  start_time = time.time()
 
92
  ######################################################
93
  pairs = make_pairs(images, scene_graph='complete', prefilter=None, symmetrize=True)
94
  output = inference(pairs, model, opt.device, batch_size=opt.batch_size)
95
- output_colmap_path=img_folder_path.replace("images", "sparse/0")
96
  os.makedirs(output_colmap_path, exist_ok=True)
97
-
98
  scene = global_aligner(output, device=opt.device, mode=GlobalAlignerMode.PointCloudOptimizer)
99
  loss = compute_global_alignment(scene=scene, init="mst", niter=opt.niter, schedule=opt.schedule, lr=opt.lr, focal_avg=opt.focal_avg)
100
- scene = scene.clean_pointcloud()
101
-
102
  imgs = to_numpy(scene.imgs)
103
  focals = scene.get_focals()
104
  poses = to_numpy(scene.get_im_poses())
@@ -107,29 +109,34 @@ def process(inputfiles, input_path=None):
107
  confidence_masks = to_numpy(scene.get_masks())
108
  intrinsics = to_numpy(scene.get_intrinsics())
109
  ######################################################
 
110
  end_time = time.time()
111
  print(f"Time taken for {opt.n_views} views: {end_time-start_time} seconds")
 
112
  save_colmap_cameras(ori_size, intrinsics, os.path.join(output_colmap_path, 'cameras.txt'))
113
  save_colmap_images(poses, os.path.join(output_colmap_path, 'images.txt'), train_img_list)
 
114
  pts_4_3dgs = np.concatenate([p[m] for p, m in zip(pts3d, confidence_masks)])
115
  color_4_3dgs = np.concatenate([p[m] for p, m in zip(imgs, confidence_masks)])
116
  color_4_3dgs = (color_4_3dgs * 255.0).astype(np.uint8)
 
117
  storePly(os.path.join(output_colmap_path, "points3D.ply"), pts_4_3dgs, color_4_3dgs)
118
  pts_4_3dgs_all = np.array(pts3d).reshape(-1, 3)
119
  np.save(output_colmap_path + "/pts_4_3dgs_all.npy", pts_4_3dgs_all)
120
  np.save(output_colmap_path + "/focal.npy", np.array(focals.cpu()))
121
-
122
  ### save VRAM
123
  del scene
124
  torch.cuda.empty_cache()
125
  gc.collect()
 
126
  ##################################################################################################################################################
127
-
128
  # ------ (2) Fast 3D-Gaussian Optimization ------
129
  parser = ArgumentParser(description="Training script parameters")
130
  lp = ModelParams(parser)
131
  op = OptimizationParams(parser)
132
  pp = PipelineParams(parser)
 
133
  parser.add_argument('--debug_from', type=int, default=-1)
134
  parser.add_argument("--test_iterations", nargs="+", type=int, default=[])
135
  parser.add_argument("--save_iterations", nargs="+", type=int, default=[])
@@ -141,88 +148,49 @@ def process(inputfiles, input_path=None):
141
  parser.add_argument("--optim_pose", type=bool, default=True)
142
  parser.add_argument("--skip_train", action="store_true")
143
  parser.add_argument("--skip_test", action="store_true")
 
144
  args = parser.parse_args(sys.argv[1:])
145
  args.save_iterations.append(args.iterations)
146
- args.model_path = opt.img_base_path + '/output/'
147
  args.source_path = opt.img_base_path
148
- # args.model_path = GRADIO_CACHE_FOLDER + '/output/'
149
- # args.source_path = GRADIO_CACHE_FOLDER
150
  args.iteration = 10000
 
151
  os.makedirs(args.model_path, exist_ok=True)
152
  training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from, args)
 
153
  ##################################################################################################################################################
154
-
155
  # ------ (3) Render video by interpolation ------
156
  parser = ArgumentParser(description="Testing script parameters")
157
  model = ModelParams(parser, sentinel=True)
158
  pipeline = PipelineParams(parser)
 
159
  args.eval = True
160
  args.get_video = True
161
  args.n_views = opt.n_views
162
- # render_sets(
163
- # model.extract(args),
164
- # args.iteration,
165
- # pipeline.extract(args),
166
- # args.skip_train,
167
- # args.skip_test,
168
- # args,
169
- # )
170
  output_ply_path = opt.img_base_path + f'/output/point_cloud/iteration_{args.iteration}/point_cloud.ply'
171
  output_video_path = ""
172
- # output_video_path = opt.img_base_path + f'/output/demo_{opt.n_views}_view.mp4'
173
- # output_ply_path = GRADIO_CACHE_FOLDER+ f'/output/point_cloud/iteration_{args.iteration}/point_cloud.ply'
174
- # output_video_path = GRADIO_CACHE_FOLDER+ f'/output/demo_{opt.n_views}_view.mp4'
175
-
176
- return output_video_path, output_ply_path, output_ply_path
177
- ##################################################################################################################################################
178
-
179
 
 
180
 
181
  _TITLE = '''InstantSplat'''
182
- _DESCRIPTION = '''
183
- <div style="display: flex; justify-content: center; align-items: center;">
184
- <div style="width: 100%; text-align: center; font-size: 30px;">
185
- <strong>InstantSplat: Sparse-view SfM-free Gaussian Splatting in Seconds</strong>
186
- </div>
187
- </div>
188
- <p></p>
189
-
190
- <div align="center">
191
- <a style="display:inline-block" href="https://instantsplat.github.io/"><img src='https://img.shields.io/badge/Project_Page-1c7d45?logo=gumtree'></a>&nbsp;
192
- <a style="display:inline-block" href="https://www.youtube.com/watch?v=fxf_ypd7eD8"><img src='https://img.shields.io/badge/Demo_Video-E33122?logo=Youtube'></a>&nbsp;
193
- <a style="display:inline-block" href="https://arxiv.org/abs/2403.20309"><img src="https://img.shields.io/badge/ArXiv-2403.20309-b31b1b?logo=arxiv" alt='arxiv'></a>
194
- <a title="Social" href="https://x.com/KairunWen" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
195
- <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
196
- </a>
197
- </div>
198
- <p></p>
199
-
200
- * Official demo of: [InstantSplat: Sparse-view SfM-free Gaussian Splatting in Seconds](https://instantsplat.github.io/).
201
- * Sparse-view examples for direct viewing: you can simply click the examples (in the bottom of the page), to quickly view the results on representative data.
202
- * Training speeds may slow if the resolution or number of images is large. To achieve performance comparable to what has been reported, please conduct tests on your own GPU (A100/4090).
203
- '''
204
-
205
-
206
- # <a style="display:inline-block" href="https://github.com/VITA-Group/LightGaussian"><img src="https://img.shields.io/badge/Source_Code-black?logo=Github" alt='Github Source Code'></a>&nbsp;
207
- # &nbsp;
208
- # <a style="display:inline-block" href="https://www.nvidia.com/en-us/"><img src="https://img.shields.io/badge/Nvidia-575757?logo=nvidia" alt='Nvidia'></a>
209
- # * If InstantSplat is helpful, please give us a star ⭐ on Github. Thanks! <a style="display:inline-block; margin-left: .5em" href="https://github.com/VITA-Group/LightGaussian"><img src='https://img.shields.io/github/stars/VITA-Group/LightGaussian?style=social'/></a>
210
 
211
-
212
- # block = gr.Blocks(title=_TITLE).queue()
213
  block = gr.Blocks().queue()
214
  with block:
215
  with gr.Row():
216
  with gr.Column(scale=1):
217
- # gr.Markdown('# ' + _TITLE)
218
  gr.Markdown(_DESCRIPTION)
219
-
220
  with gr.Row(variant='panel'):
221
  with gr.Tab("Input"):
222
  inputfiles = gr.File(file_count="multiple", label="images")
223
  input_path = gr.Textbox(visible=False, label="example_path")
224
  button_gen = gr.Button("RUN")
225
-
226
  with gr.Row(variant='panel'):
227
  with gr.Tab("Output"):
228
  with gr.Column(scale=2):
@@ -230,7 +198,7 @@ with block:
230
  output_model = gr.Model3D(
231
  label="3D Dense Model under Gaussian Splats Formats, need more time to visualize",
232
  interactive=False,
233
- camera_position=[0.5, 0.5, 1], # 稍微偏移一点,以便更好地查看模型
234
  )
235
  gr.Markdown(
236
  """
@@ -238,11 +206,11 @@ with block:
238
  &nbsp;&nbsp;Use the left mouse button to rotate, the scroll wheel to zoom, and the right mouse button to move.
239
  </div>
240
  """
241
- )
242
- output_file = gr.File(label="ply")
243
  with gr.Column(scale=1):
244
  output_video = gr.Video(label="video")
245
 
246
  button_gen.click(process, inputs=[inputfiles], outputs=[output_video, output_file, output_model])
247
-
248
  block.launch(server_name="0.0.0.0", share=False)
 
15
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
16
  os.sys.path.append(os.path.abspath(os.path.join(BASE_DIR, "submodules", "dust3r")))
17
  # os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
18
+
19
  from dust3r.inference import inference
20
  from dust3r.model import AsymmetricCroCo3DStereo
21
  from dust3r.utils.device import to_numpy
22
  from dust3r.image_pairs import make_pairs
23
  from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
24
  from utils.dust3r_utils import compute_global_alignment, load_images, storePly, save_colmap_cameras, save_colmap_images
 
25
  from argparse import ArgumentParser, Namespace
26
  from arguments import ModelParams, PipelineParams, OptimizationParams
27
  from train_joint import training
28
  from render_by_interp import render_sets
29
+
30
  GRADIO_CACHE_FOLDER = './gradio_cache_folder'
 
31
 
32
+ #############################################################################################################################################
33
 
34
  def get_dust3r_args_parser():
35
  parser = argparse.ArgumentParser()
 
42
  parser.add_argument("--niter", type=int, default=300)
43
  parser.add_argument("--focal_avg", type=bool, default=True)
44
  parser.add_argument("--n_views", type=int, default=3)
45
+ parser.add_argument("--base_path", type=str, default=GRADIO_CACHE_FOLDER)
46
  return parser
47
 
48
+ # FIXED: Explicitly set duration to 60 seconds to pass free-tier ZeroGPU checks
49
+ @spaces.GPU(duration=60)
50
  def process(inputfiles, input_path=None):
 
51
  if input_path is not None:
52
  imgs_path = './assets/example/' + input_path
53
  imgs_names = sorted(os.listdir(imgs_path))
 
54
  inputfiles = []
55
  for imgs_name in imgs_names:
56
  file_path = os.path.join(imgs_path, imgs_name)
 
59
  print(inputfiles)
60
 
61
  # ------ (1) Coarse Geometric Initialization ------
 
62
  parser = get_dust3r_args_parser()
63
  opt = parser.parse_args()
 
64
  tmp_user_folder = str(uuid.uuid4()).replace("-", "")
65
  opt.img_base_path = os.path.join(opt.base_path, tmp_user_folder)
66
+ img_folder_path = os.path.join(opt.img_base_path, "images")
67
+
 
68
  model = AsymmetricCroCo3DStereo.from_pretrained(opt.model_path).to(opt.device)
69
  os.makedirs(img_folder_path, exist_ok=True)
70
+ opt.n_views = len(inputfiles)
71
+
72
  if opt.n_views == 1:
73
  raise gr.Error("The number of input images should be greater than 1.")
74
+
75
  print("Multiple images: ", inputfiles)
76
  for image_path in inputfiles:
77
  if input_path is not None:
78
  shutil.copy(image_path, img_folder_path)
79
  else:
80
  shutil.move(image_path, img_folder_path)
81
+
82
  train_img_list = sorted(os.listdir(img_folder_path))
83
  assert len(train_img_list)==opt.n_views, f"Number of images in the folder is not equal to {opt.n_views}"
84
+
85
+ images, ori_size, imgs_resolution = load_images(img_folder_path, size=512)
86
  resolutions_are_equal = len(set(imgs_resolution)) == 1
87
+
88
+ if not resolutions_are_equal:
89
  raise gr.Error("The resolution of the input image should be the same.")
90
+
91
  print("ori_size", ori_size)
92
  start_time = time.time()
93
+
94
  ######################################################
95
  pairs = make_pairs(images, scene_graph='complete', prefilter=None, symmetrize=True)
96
  output = inference(pairs, model, opt.device, batch_size=opt.batch_size)
97
+ output_colmap_path = img_folder_path.replace("images", "sparse/0")
98
  os.makedirs(output_colmap_path, exist_ok=True)
99
+
100
  scene = global_aligner(output, device=opt.device, mode=GlobalAlignerMode.PointCloudOptimizer)
101
  loss = compute_global_alignment(scene=scene, init="mst", niter=opt.niter, schedule=opt.schedule, lr=opt.lr, focal_avg=opt.focal_avg)
102
+ scene = scene.clean_pointcloud()
103
+
104
  imgs = to_numpy(scene.imgs)
105
  focals = scene.get_focals()
106
  poses = to_numpy(scene.get_im_poses())
 
109
  confidence_masks = to_numpy(scene.get_masks())
110
  intrinsics = to_numpy(scene.get_intrinsics())
111
  ######################################################
112
+
113
  end_time = time.time()
114
  print(f"Time taken for {opt.n_views} views: {end_time-start_time} seconds")
115
+
116
  save_colmap_cameras(ori_size, intrinsics, os.path.join(output_colmap_path, 'cameras.txt'))
117
  save_colmap_images(poses, os.path.join(output_colmap_path, 'images.txt'), train_img_list)
118
+
119
  pts_4_3dgs = np.concatenate([p[m] for p, m in zip(pts3d, confidence_masks)])
120
  color_4_3dgs = np.concatenate([p[m] for p, m in zip(imgs, confidence_masks)])
121
  color_4_3dgs = (color_4_3dgs * 255.0).astype(np.uint8)
122
+
123
  storePly(os.path.join(output_colmap_path, "points3D.ply"), pts_4_3dgs, color_4_3dgs)
124
  pts_4_3dgs_all = np.array(pts3d).reshape(-1, 3)
125
  np.save(output_colmap_path + "/pts_4_3dgs_all.npy", pts_4_3dgs_all)
126
  np.save(output_colmap_path + "/focal.npy", np.array(focals.cpu()))
127
+
128
  ### save VRAM
129
  del scene
130
  torch.cuda.empty_cache()
131
  gc.collect()
132
+
133
  ##################################################################################################################################################
 
134
  # ------ (2) Fast 3D-Gaussian Optimization ------
135
  parser = ArgumentParser(description="Training script parameters")
136
  lp = ModelParams(parser)
137
  op = OptimizationParams(parser)
138
  pp = PipelineParams(parser)
139
+
140
  parser.add_argument('--debug_from', type=int, default=-1)
141
  parser.add_argument("--test_iterations", nargs="+", type=int, default=[])
142
  parser.add_argument("--save_iterations", nargs="+", type=int, default=[])
 
148
  parser.add_argument("--optim_pose", type=bool, default=True)
149
  parser.add_argument("--skip_train", action="store_true")
150
  parser.add_argument("--skip_test", action="store_true")
151
+
152
  args = parser.parse_args(sys.argv[1:])
153
  args.save_iterations.append(args.iterations)
154
+ args.model_path = opt.img_base_path + '/output/'
155
  args.source_path = opt.img_base_path
 
 
156
  args.iteration = 10000
157
+
158
  os.makedirs(args.model_path, exist_ok=True)
159
  training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from, args)
160
+
161
  ##################################################################################################################################################
 
162
  # ------ (3) Render video by interpolation ------
163
  parser = ArgumentParser(description="Testing script parameters")
164
  model = ModelParams(parser, sentinel=True)
165
  pipeline = PipelineParams(parser)
166
+
167
  args.eval = True
168
  args.get_video = True
169
  args.n_views = opt.n_views
170
+
 
 
 
 
 
 
 
171
  output_ply_path = opt.img_base_path + f'/output/point_cloud/iteration_{args.iteration}/point_cloud.ply'
172
  output_video_path = ""
173
+
174
+ return output_video_path, output_ply_path, output_ply_path
 
 
 
 
 
175
 
176
+ ##################################################################################################################################################
177
 
178
  _TITLE = '''InstantSplat'''
179
+ _DESCRIPTION = '''<div style="display: flex; justify-content: center; align-items: center;"> <div style="width: 100%; text-align: center; font-size: 30px;"> <strong>InstantSplat: Sparse-view SfM-free Gaussian Splatting in Seconds</strong> </div></div> <p></p><div align="center"> <a style="display:inline-block" href="https://instantsplat.github.io/"><img src='https://img.shields.io/badge/Project_Page-1c7d45?logo=gumtree'></a>&nbsp; <a style="display:inline-block" href="https://www.youtube.com/watch?v=fxf_ypd7eD8"><img src='https://img.shields.io/badge/Demo_Video-E33122?logo=Youtube'></a>&nbsp; <a style="display:inline-block" href="https://arxiv.org/abs/2403.20309"><img src="https://img.shields.io/badge/ArXiv-2403.20309-b31b1b?logo=arxiv" alt='arxiv'></a> <a title="Social" href="https://x.com/KairunWen" target="_blank" rel="noopener noreferrer" style="display: inline-block;"> <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social"> </a></div><p></p>* Official demo of: [InstantSplat: Sparse-view SfM-free Gaussian Splatting in Seconds](https://instantsplat.github.io/).* Sparse-view examples for direct viewing: you can simply click the examples (in the bottom of the page), to quickly view the results on representative data.* Training speeds may slow if the resolution or number of images is large. To achieve performance comparable to what has been reported, please conduct tests on your own GPU (A100/4090).'''
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
 
181
+ # FIXED: Re-aligned the Gradio frontend so the UI mounts correctly
 
182
  block = gr.Blocks().queue()
183
  with block:
184
  with gr.Row():
185
  with gr.Column(scale=1):
 
186
  gr.Markdown(_DESCRIPTION)
187
+
188
  with gr.Row(variant='panel'):
189
  with gr.Tab("Input"):
190
  inputfiles = gr.File(file_count="multiple", label="images")
191
  input_path = gr.Textbox(visible=False, label="example_path")
192
  button_gen = gr.Button("RUN")
193
+
194
  with gr.Row(variant='panel'):
195
  with gr.Tab("Output"):
196
  with gr.Column(scale=2):
 
198
  output_model = gr.Model3D(
199
  label="3D Dense Model under Gaussian Splats Formats, need more time to visualize",
200
  interactive=False,
201
+ camera_position=[0.5, 0.5, 1],
202
  )
203
  gr.Markdown(
204
  """
 
206
  &nbsp;&nbsp;Use the left mouse button to rotate, the scroll wheel to zoom, and the right mouse button to move.
207
  </div>
208
  """
209
+ )
210
+ output_file = gr.File(label="ply")
211
  with gr.Column(scale=1):
212
  output_video = gr.Video(label="video")
213
 
214
  button_gen.click(process, inputs=[inputfiles], outputs=[output_video, output_file, output_model])
215
+
216
  block.launch(server_name="0.0.0.0", share=False)