abreza committed
Commit 9462c17 · 1 Parent(s): 85cb605

Move vggt4track_model to CUDA inside the Hugging Face ZeroGPU inference function
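For context: on a Hugging Face ZeroGPU Space, a GPU is attached to the process only while a function decorated with @spaces.GPU is running, so initializing CUDA at import time (as the removed module-level .to("cuda") did) typically fails at startup. A minimal sketch of the pattern this commit adopts, with nn.Linear standing in for the real VGGT4Track model:

import spaces
import torch
from torch import nn

# Build or load the model on CPU at import time: the main process of a
# ZeroGPU Space has no CUDA device yet.
model = nn.Linear(4, 4)
model.eval()

@spaces.GPU
def infer(x: torch.Tensor) -> torch.Tensor:
    # A GPU is available only for the duration of this call, so device
    # placement happens here, on each invocation.
    model.to("cuda")  # nn.Module.to() moves parameters in place
    with torch.no_grad():
        return model(x.to("cuda")).cpu()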

Files changed (1)
app.py +10 -8
app.py CHANGED
@@ -87,9 +87,9 @@ def create_user_temp_dir():
 # Global model initialization for Spatial Tracker
 print("🚀 Initializing tracking models...")
 
-vggt4track_model = VGGT4Track.from_pretrained("Yuxihenry/SpatialTrackerV2_Front")
+vggt4track_model = VGGT4Track.from_pretrained(
+    "Yuxihenry/SpatialTrackerV2_Front")
 vggt4track_model.eval()
-vggt4track_model = vggt4track_model.to("cuda")
 
 if not hasattr(vggt4track_model, 'infer'):
     vggt4track_model.infer = vggt4track_model.forward
@@ -105,7 +105,6 @@ wan_pipeline.vae.enable_tiling()
 wan_pipeline.vae.enable_slicing()
 
 
-
 print("✅ Tracking models loaded successfully!")
 
 gr.set_static_paths(paths=[Path.cwd().absolute()/"_viz"])
@@ -201,7 +200,6 @@ def render_from_pointcloud(rgb_frames, depth_frames, intrinsics, original_extrin
     return {'rendered': output_path, 'motion_signal': motion_signal_path, 'mask': mask_path}
 
 
-
 @spaces.GPU
 def run_spatial_tracker(video_tensor: torch.Tensor):
     """
@@ -216,6 +214,8 @@ def run_spatial_tracker(video_tensor: torch.Tensor):
     # Run VGGT to get depth and camera poses
     video_input = preprocess_image(video_tensor)[None].cuda()
 
+    vggt4track_model = vggt4track_model.to("cuda")
+
     with torch.no_grad():
         with torch.amp.autocast('cuda', dtype=torch.bfloat16):
             predictions = vggt4track_model(video_input / 255)
@@ -237,7 +237,8 @@ def run_spatial_tracker(video_tensor: torch.Tensor):
     # Get grid points for tracking
     frame_H, frame_W = video_tensor_gpu.shape[2:]
     grid_pts = get_points_on_a_grid(30, (frame_H, frame_W), device="cpu")
-    query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[0].numpy()
+    query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[
+        0].numpy()
 
     # Run tracker
     with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
@@ -292,7 +293,6 @@ def run_wan_ttm_generation(prompt, tweak_index, tstrong_index, first_frame_path,
         "毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
     )
 
-
     wan_pipeline.to("cuda")
 
     # Match resolution logic from run_wan.py
@@ -390,7 +390,8 @@ def process_video(video_path, camera_movement, generate_ttm=True, progress=gr.Pr
 # --- GRADIO INTERFACE ---
 with gr.Blocks(theme=gr.themes.Soft(), title="🎬 TTM Wan Video Generator") as demo:
     gr.Markdown("# 🎬 Video to Point Cloud & TTM Wan Generator")
-    gr.Markdown("Transform standard videos into 3D-aware motion signals for Time-to-Move (TTM) generation.")
+    gr.Markdown(
+        "Transform standard videos into 3D-aware motion signals for Time-to-Move (TTM) generation.")
 
     # Shared state for TTM files - initialized as empty strings
     first_frame_file = gr.State("")
@@ -437,7 +438,8 @@ with gr.Blocks(theme=gr.themes.Soft(), title="🎬 TTM Wan Video Generator") as
     # the path string instead of the raw pixel array.
     motion_signal_output = gr.Video(label="motion_signal.mp4")
     mask_output = gr.Video(label="mask.mp4")
-    first_frame_output = gr.Image(label="first_frame.png", type="filepath")
+    first_frame_output = gr.Image(
+        label="first_frame.png", type="filepath")
 
     # --- Event Handlers ---
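Aside from relocating the .to("cuda") call, the remaining hunks look like autoformatter output: the long from_pretrained, torch.cat, gr.Markdown, and gr.Image calls are wrapped and stray blank lines dropped, which together account for the +10 -8 line count.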
 
 