abreza committed on
Commit
0d0c2b2
·
1 Parent(s): aa4c733
Files changed (1) hide show
  1. app.py +64 -23
app.py CHANGED
@@ -87,11 +87,13 @@ def create_user_temp_dir():
87
  # Global model initialization for Spatial Tracker
88
  print("🚀 Initializing tracking models...")
89
 
90
- vggt4track_model = VGGT4Track.from_pretrained(
91
- "Yuxihenry/SpatialTrackerV2_Front")
92
  vggt4track_model.eval()
93
  vggt4track_model = vggt4track_model.to("cuda")
94
 
 
 
 
95
  tracker_model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
96
  tracker_model.eval()
97
 
@@ -208,39 +210,78 @@ def render_from_pointcloud(rgb_frames, depth_frames, intrinsics, original_extrin
208
  return {'rendered': output_path, 'motion_signal': motion_signal_path, 'mask': mask_path}
209
 
210
 
 
211
  @spaces.GPU
212
- def run_spatial_tracker(video_tensor):
213
- if not hasattr(vggt4track_model, "infer"):
214
- vggt4track_model.infer = lambda x: vggt4track_model(x)
215
- if tracker_model.spatrack.base_model is None:
216
- tracker_model.spatrack.base_model = vggt4track_model
217
 
 
 
 
 
 
 
 
218
  video_input = preprocess_image(video_tensor)[None].cuda()
 
219
  with torch.no_grad():
220
- with torch.amp.autocast('cuda', dtype=torch.bfloat16):
221
  predictions = vggt4track_model(video_input / 255)
222
- extrinsic, intrinsic = predictions["poses_pred"], predictions["intrs"]
223
- depth_map, depth_conf = predictions["points_map"][...,
224
- 2], predictions["unc_metric"]
225
-
226
- depth_tensor, extrs, intrs = depth_map.squeeze().cpu().numpy(
227
- ), extrinsic.squeeze().cpu().numpy(), intrinsic.squeeze().cpu().numpy()
 
 
 
228
  unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5
 
 
229
  tracker_model.spatrack.track_num = 512
230
  tracker_model.to("cuda")
231
- grid_pts = get_points_on_a_grid(
232
- 30, (video_input.shape[3], video_input.shape[4]), device="cpu")
233
- query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[
234
- 0].numpy()
235
 
 
 
 
 
 
 
236
  with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
237
- c2w_traj, intrs_out, point_map, conf_depth, _, _, _, _, video_out = tracker_model.forward(
238
- video_input.squeeze(), depth=depth_tensor, intrs=intrs, extrs=extrs, queries=query_xyt,
239
- fps=1, unc_metric=unc_metric, support_frame=len(video_input.squeeze())-1
 
 
 
 
 
 
 
 
240
  )
241
- return {'video_out': video_out.cpu(), 'point_map': point_map.cpu(), 'conf_depth': conf_depth.cpu(), 'intrs_out': intrs_out.cpu(), 'c2w_traj': c2w_traj.cpu()}
242
 
243
- # --- TTM WAN INFERENCE FUNCTION ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
 
245
 
246
  @spaces.GPU
 
87
  # Global model initialization for Spatial Tracker
88
  print("🚀 Initializing tracking models...")
89
 
90
+ vggt4track_model = VGGT4Track.from_pretrained("Yuxihenry/SpatialTrackerV2_Front")
 
91
  vggt4track_model.eval()
92
  vggt4track_model = vggt4track_model.to("cuda")
93
 
94
+ if not hasattr(vggt4track_model, 'infer'):
95
+ vggt4track_model.infer = vggt4track_model.forward
96
+
97
  tracker_model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
98
  tracker_model.eval()
99
 
 
210
  return {'rendered': output_path, 'motion_signal': motion_signal_path, 'mask': mask_path}
211
 
212
 
213
+
214
  @spaces.GPU
215
+ def run_spatial_tracker(video_tensor: torch.Tensor):
216
+ """
217
+ GPU-intensive spatial tracking function.
 
 
218
 
219
+ Args:
220
+ video_tensor: Preprocessed video tensor (T, C, H, W)
221
+
222
+ Returns:
223
+ Dictionary containing tracking results
224
+ """
225
+ # Run VGGT to get depth and camera poses
226
  video_input = preprocess_image(video_tensor)[None].cuda()
227
+
228
  with torch.no_grad():
229
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
230
  predictions = vggt4track_model(video_input / 255)
231
+ extrinsic = predictions["poses_pred"]
232
+ intrinsic = predictions["intrs"]
233
+ depth_map = predictions["points_map"][..., 2]
234
+ depth_conf = predictions["unc_metric"]
235
+
236
+ depth_tensor = depth_map.squeeze().cpu().numpy()
237
+ extrs = extrinsic.squeeze().cpu().numpy()
238
+ intrs = intrinsic.squeeze().cpu().numpy()
239
+ video_tensor_gpu = video_input.squeeze()
240
  unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5
241
+
242
+ # Setup tracker
243
  tracker_model.spatrack.track_num = 512
244
  tracker_model.to("cuda")
 
 
 
 
245
 
246
+ # Get grid points for tracking
247
+ frame_H, frame_W = video_tensor_gpu.shape[2:]
248
+ grid_pts = get_points_on_a_grid(30, (frame_H, frame_W), device="cpu")
249
+ query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[0].numpy()
250
+
251
+ # Run tracker
252
  with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
253
+ (
254
+ c2w_traj, intrs_out, point_map, conf_depth,
255
+ track3d_pred, track2d_pred, vis_pred, conf_pred, video_out
256
+ ) = tracker_model.forward(
257
+ video_tensor_gpu, depth=depth_tensor,
258
+ intrs=intrs, extrs=extrs,
259
+ queries=query_xyt,
260
+ fps=1, full_point=False, iters_track=4,
261
+ query_no_BA=True, fixed_cam=False, stage=1,
262
+ unc_metric=unc_metric,
263
+ support_frame=len(video_tensor_gpu)-1, replace_ratio=0.2
264
  )
 
265
 
266
+ # Resize outputs for rendering
267
+ max_size = 384
268
+ h, w = video_out.shape[2:]
269
+ scale = min(max_size / h, max_size / w)
270
+ if scale < 1:
271
+ new_h, new_w = int(h * scale), int(w * scale)
272
+ video_out = T.Resize((new_h, new_w))(video_out)
273
+ point_map = T.Resize((new_h, new_w))(point_map)
274
+ conf_depth = T.Resize((new_h, new_w))(conf_depth)
275
+ intrs_out[:, :2, :] = intrs_out[:, :2, :] * scale
276
+
277
+ # Move results to CPU and return
278
+ return {
279
+ 'video_out': video_out.cpu(),
280
+ 'point_map': point_map.cpu(),
281
+ 'conf_depth': conf_depth.cpu(),
282
+ 'intrs_out': intrs_out.cpu(),
283
+ 'c2w_traj': c2w_traj.cpu(),
284
+ }
285
 
286
 
287
  @spaces.GPU