Spaces:
Sleeping
Sleeping
move
Browse filesvggt4track_model to cuda inside hugging face zero gpu inference
app.py
CHANGED
|
@@ -87,9 +87,9 @@ def create_user_temp_dir():
|
|
| 87 |
# Global model initialization for Spatial Tracker
|
| 88 |
print("🚀 Initializing tracking models...")
|
| 89 |
|
| 90 |
-
vggt4track_model = VGGT4Track.from_pretrained(
|
|
|
|
| 91 |
vggt4track_model.eval()
|
| 92 |
-
vggt4track_model = vggt4track_model.to("cuda")
|
| 93 |
|
| 94 |
if not hasattr(vggt4track_model, 'infer'):
|
| 95 |
vggt4track_model.infer = vggt4track_model.forward
|
|
@@ -105,7 +105,6 @@ wan_pipeline.vae.enable_tiling()
|
|
| 105 |
wan_pipeline.vae.enable_slicing()
|
| 106 |
|
| 107 |
|
| 108 |
-
|
| 109 |
print("✅ Tracking models loaded successfully!")
|
| 110 |
|
| 111 |
gr.set_static_paths(paths=[Path.cwd().absolute()/"_viz"])
|
|
@@ -201,7 +200,6 @@ def render_from_pointcloud(rgb_frames, depth_frames, intrinsics, original_extrin
|
|
| 201 |
return {'rendered': output_path, 'motion_signal': motion_signal_path, 'mask': mask_path}
|
| 202 |
|
| 203 |
|
| 204 |
-
|
| 205 |
@spaces.GPU
|
| 206 |
def run_spatial_tracker(video_tensor: torch.Tensor):
|
| 207 |
"""
|
|
@@ -216,6 +214,8 @@ def run_spatial_tracker(video_tensor: torch.Tensor):
|
|
| 216 |
# Run VGGT to get depth and camera poses
|
| 217 |
video_input = preprocess_image(video_tensor)[None].cuda()
|
| 218 |
|
|
|
|
|
|
|
| 219 |
with torch.no_grad():
|
| 220 |
with torch.amp.autocast('cuda', dtype=torch.bfloat16):
|
| 221 |
predictions = vggt4track_model(video_input / 255)
|
|
@@ -237,7 +237,8 @@ def run_spatial_tracker(video_tensor: torch.Tensor):
|
|
| 237 |
# Get grid points for tracking
|
| 238 |
frame_H, frame_W = video_tensor_gpu.shape[2:]
|
| 239 |
grid_pts = get_points_on_a_grid(30, (frame_H, frame_W), device="cpu")
|
| 240 |
-
query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[
|
|
|
|
| 241 |
|
| 242 |
# Run tracker
|
| 243 |
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
|
|
@@ -292,7 +293,6 @@ def run_wan_ttm_generation(prompt, tweak_index, tstrong_index, first_frame_path,
|
|
| 292 |
"毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
|
| 293 |
)
|
| 294 |
|
| 295 |
-
|
| 296 |
wan_pipeline.to("cuda")
|
| 297 |
|
| 298 |
# Match resolution logic from run_wan.py
|
|
@@ -390,7 +390,8 @@ def process_video(video_path, camera_movement, generate_ttm=True, progress=gr.Pr
|
|
| 390 |
# --- GRADIO INTERFACE ---
|
| 391 |
with gr.Blocks(theme=gr.themes.Soft(), title="🎬 TTM Wan Video Generator") as demo:
|
| 392 |
gr.Markdown("# 🎬 Video to Point Cloud & TTM Wan Generator")
|
| 393 |
-
gr.Markdown(
|
|
|
|
| 394 |
|
| 395 |
# Shared state for TTM files - initialized as empty strings
|
| 396 |
first_frame_file = gr.State("")
|
|
@@ -437,7 +438,8 @@ with gr.Blocks(theme=gr.themes.Soft(), title="🎬 TTM Wan Video Generator") as
|
|
| 437 |
# the path string instead of the raw pixel array.
|
| 438 |
motion_signal_output = gr.Video(label="motion_signal.mp4")
|
| 439 |
mask_output = gr.Video(label="mask.mp4")
|
| 440 |
-
first_frame_output = gr.Image(
|
|
|
|
| 441 |
|
| 442 |
# --- Event Handlers ---
|
| 443 |
|
|
|
|
| 87 |
# Global model initialization for Spatial Tracker
|
| 88 |
print("🚀 Initializing tracking models...")
|
| 89 |
|
| 90 |
+
vggt4track_model = VGGT4Track.from_pretrained(
|
| 91 |
+
"Yuxihenry/SpatialTrackerV2_Front")
|
| 92 |
vggt4track_model.eval()
|
|
|
|
| 93 |
|
| 94 |
if not hasattr(vggt4track_model, 'infer'):
|
| 95 |
vggt4track_model.infer = vggt4track_model.forward
|
|
|
|
| 105 |
wan_pipeline.vae.enable_slicing()
|
| 106 |
|
| 107 |
|
|
|
|
| 108 |
print("✅ Tracking models loaded successfully!")
|
| 109 |
|
| 110 |
gr.set_static_paths(paths=[Path.cwd().absolute()/"_viz"])
|
|
|
|
| 200 |
return {'rendered': output_path, 'motion_signal': motion_signal_path, 'mask': mask_path}
|
| 201 |
|
| 202 |
|
|
|
|
| 203 |
@spaces.GPU
|
| 204 |
def run_spatial_tracker(video_tensor: torch.Tensor):
|
| 205 |
"""
|
|
|
|
| 214 |
# Run VGGT to get depth and camera poses
|
| 215 |
video_input = preprocess_image(video_tensor)[None].cuda()
|
| 216 |
|
| 217 |
+
vggt4track_model = vggt4track_model.to("cuda")
|
| 218 |
+
|
| 219 |
with torch.no_grad():
|
| 220 |
with torch.amp.autocast('cuda', dtype=torch.bfloat16):
|
| 221 |
predictions = vggt4track_model(video_input / 255)
|
|
|
|
| 237 |
# Get grid points for tracking
|
| 238 |
frame_H, frame_W = video_tensor_gpu.shape[2:]
|
| 239 |
grid_pts = get_points_on_a_grid(30, (frame_H, frame_W), device="cpu")
|
| 240 |
+
query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[
|
| 241 |
+
0].numpy()
|
| 242 |
|
| 243 |
# Run tracker
|
| 244 |
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
|
|
|
|
| 293 |
"毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
|
| 294 |
)
|
| 295 |
|
|
|
|
| 296 |
wan_pipeline.to("cuda")
|
| 297 |
|
| 298 |
# Match resolution logic from run_wan.py
|
|
|
|
| 390 |
# --- GRADIO INTERFACE ---
|
| 391 |
with gr.Blocks(theme=gr.themes.Soft(), title="🎬 TTM Wan Video Generator") as demo:
|
| 392 |
gr.Markdown("# 🎬 Video to Point Cloud & TTM Wan Generator")
|
| 393 |
+
gr.Markdown(
|
| 394 |
+
"Transform standard videos into 3D-aware motion signals for Time-to-Move (TTM) generation.")
|
| 395 |
|
| 396 |
# Shared state for TTM files - initialized as empty strings
|
| 397 |
first_frame_file = gr.State("")
|
|
|
|
| 438 |
# the path string instead of the raw pixel array.
|
| 439 |
motion_signal_output = gr.Video(label="motion_signal.mp4")
|
| 440 |
mask_output = gr.Video(label="mask.mp4")
|
| 441 |
+
first_frame_output = gr.Image(
|
| 442 |
+
label="first_frame.png", type="filepath")
|
| 443 |
|
| 444 |
# --- Event Handlers ---
|
| 445 |
|