Spaces:
Sleeping
Sleeping
refactor: enhance track visualization functionality
Browse files
app.py
CHANGED
|
@@ -24,26 +24,22 @@ from models.SpaTrackV2.models.vggt4track.models.vggt_moe import VGGT4Track
|
|
| 24 |
from models.SpaTrackV2.models.vggt4track.utils.load_fn import preprocess_image
|
| 25 |
from models.SpaTrackV2.models.predictor import Predictor
|
| 26 |
from models.SpaTrackV2.models.utils import get_points_on_a_grid
|
|
|
|
| 27 |
|
| 28 |
-
|
| 29 |
-
# --- TTM SPECIFIC IMPORTS ---
|
| 30 |
from diffusers.utils import export_to_video, load_image
|
| 31 |
-
|
| 32 |
from pipelines.wan_pipeline import WanImageToVideoTTMPipeline
|
| 33 |
-
from pipelines.utils import compute_hw_from_area
|
| 34 |
|
| 35 |
-
# Configure logging
|
| 36 |
logging.basicConfig(level=logging.INFO)
|
| 37 |
logger = logging.getLogger(__name__)
|
| 38 |
|
| 39 |
-
# Constants
|
| 40 |
MAX_FRAMES = 81
|
| 41 |
OUTPUT_FPS = 24
|
| 42 |
RENDER_WIDTH = 512
|
| 43 |
RENDER_HEIGHT = 384
|
| 44 |
WAN_MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
|
| 45 |
|
| 46 |
-
# Camera movement types
|
| 47 |
CAMERA_MOVEMENTS = [
|
| 48 |
"static",
|
| 49 |
"move_forward",
|
|
@@ -54,7 +50,6 @@ CAMERA_MOVEMENTS = [
|
|
| 54 |
"move_down"
|
| 55 |
]
|
| 56 |
|
| 57 |
-
# Thread pool for delayed deletion
|
| 58 |
thread_pool_executor = ThreadPoolExecutor(max_workers=2)
|
| 59 |
|
| 60 |
|
|
@@ -84,7 +79,6 @@ def create_user_temp_dir():
|
|
| 84 |
return temp_dir
|
| 85 |
|
| 86 |
|
| 87 |
-
# Global model initialization for Spatial Tracker
|
| 88 |
print("🚀 Initializing tracking models...")
|
| 89 |
|
| 90 |
vggt4track_model = VGGT4Track.from_pretrained(
|
|
@@ -200,6 +194,145 @@ def render_from_pointcloud(rgb_frames, depth_frames, intrinsics, original_extrin
|
|
| 200 |
return {'rendered': output_path, 'motion_signal': motion_signal_path, 'mask': mask_path}
|
| 201 |
|
| 202 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
@spaces.GPU
|
| 204 |
def run_spatial_tracker(video_tensor: torch.Tensor):
|
| 205 |
"""
|
|
@@ -211,7 +344,6 @@ def run_spatial_tracker(video_tensor: torch.Tensor):
|
|
| 211 |
Returns:
|
| 212 |
Dictionary containing tracking results
|
| 213 |
"""
|
| 214 |
-
# Run VGGT to get depth and camera poses
|
| 215 |
global vggt4track_model
|
| 216 |
global tracker_model
|
| 217 |
global wan_pipeline
|
|
@@ -234,17 +366,14 @@ def run_spatial_tracker(video_tensor: torch.Tensor):
|
|
| 234 |
video_tensor_gpu = video_input.squeeze()
|
| 235 |
unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5
|
| 236 |
|
| 237 |
-
# Setup tracker
|
| 238 |
tracker_model.spatrack.track_num = 512
|
| 239 |
tracker_model.to("cuda")
|
| 240 |
|
| 241 |
-
# Get grid points for tracking
|
| 242 |
frame_H, frame_W = video_tensor_gpu.shape[2:]
|
| 243 |
grid_pts = get_points_on_a_grid(30, (frame_H, frame_W), device="cpu")
|
| 244 |
query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[
|
| 245 |
0].numpy()
|
| 246 |
|
| 247 |
-
# Run tracker
|
| 248 |
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
|
| 249 |
(
|
| 250 |
c2w_traj, intrs_out, point_map, conf_depth,
|
|
@@ -259,7 +388,6 @@ def run_spatial_tracker(video_tensor: torch.Tensor):
|
|
| 259 |
support_frame=len(video_tensor_gpu)-1, replace_ratio=0.2
|
| 260 |
)
|
| 261 |
|
| 262 |
-
# Resize outputs for rendering
|
| 263 |
max_size = 384
|
| 264 |
h, w = video_out.shape[2:]
|
| 265 |
scale = min(max_size / h, max_size / w)
|
|
@@ -269,14 +397,19 @@ def run_spatial_tracker(video_tensor: torch.Tensor):
|
|
| 269 |
point_map = T.Resize((new_h, new_w))(point_map)
|
| 270 |
conf_depth = T.Resize((new_h, new_w))(conf_depth)
|
| 271 |
intrs_out[:, :2, :] = intrs_out[:, :2, :] * scale
|
|
|
|
|
|
|
| 272 |
|
| 273 |
-
# Move results to CPU and return
|
| 274 |
return {
|
| 275 |
'video_out': video_out.cpu(),
|
| 276 |
'point_map': point_map.cpu(),
|
| 277 |
'conf_depth': conf_depth.cpu(),
|
| 278 |
'intrs_out': intrs_out.cpu(),
|
| 279 |
'c2w_traj': c2w_traj.cpu(),
|
|
|
|
|
|
|
|
|
|
|
|
|
| 280 |
}
|
| 281 |
|
| 282 |
|
|
@@ -301,7 +434,6 @@ def run_wan_ttm_generation(prompt, tweak_index, tstrong_index, first_frame_path,
|
|
| 301 |
progress(0.2, desc="Preparing inputs...")
|
| 302 |
image = load_image(first_frame_path)
|
| 303 |
|
| 304 |
-
# Standard Wan Negative Prompt
|
| 305 |
negative_prompt = (
|
| 306 |
"色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,"
|
| 307 |
"低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,"
|
|
@@ -310,7 +442,6 @@ def run_wan_ttm_generation(prompt, tweak_index, tstrong_index, first_frame_path,
|
|
| 310 |
|
| 311 |
wan_pipeline.to("cuda")
|
| 312 |
|
| 313 |
-
# Match resolution logic from run_wan.py
|
| 314 |
max_area = 480 * 832
|
| 315 |
mod_value = wan_pipeline.vae_scale_factor_spatial * \
|
| 316 |
wan_pipeline.transformer.config.patch_size[1]
|
|
@@ -349,7 +480,7 @@ def run_wan_ttm_generation(prompt, tweak_index, tstrong_index, first_frame_path,
|
|
| 349 |
|
| 350 |
def process_video(video_path, camera_movement, generate_ttm=True, progress=gr.Progress()):
|
| 351 |
if video_path is None:
|
| 352 |
-
return None, None, None, None, "❌ Please upload a video first"
|
| 353 |
|
| 354 |
progress(0, desc="Initializing...")
|
| 355 |
temp_dir = create_user_temp_dir()
|
|
@@ -383,6 +514,33 @@ def process_video(video_path, camera_movement, generate_ttm=True, progress=gr.Pr
|
|
| 383 |
new_exts = generate_camera_trajectory(len(
|
| 384 |
rgb_frames), camera_movement, tracking_results['intrs_out'].numpy(), scene_scale)
|
| 385 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 386 |
progress(0.8, desc="Rendering viewpoint...")
|
| 387 |
output_video_path = os.path.join(out_dir, "rendered_video.mp4")
|
| 388 |
render_results = render_from_pointcloud(rgb_frames, depth_frames, tracking_results['intrs_out'].numpy(),
|
|
@@ -395,11 +553,13 @@ def process_video(video_path, camera_movement, generate_ttm=True, progress=gr.Pr
|
|
| 395 |
rgb_frames[0], cv2.COLOR_RGB2BGR))
|
| 396 |
|
| 397 |
status_msg = f"✅ 3D results ready! You can now use the prompt below to generate a high-quality TTM video."
|
| 398 |
-
return render_results['rendered'], render_results['motion_signal'], render_results['mask'], first_frame_path, status_msg
|
| 399 |
|
| 400 |
except Exception as e:
|
| 401 |
logger.error(f"Error: {e}")
|
| 402 |
-
|
|
|
|
|
|
|
| 403 |
|
| 404 |
|
| 405 |
# --- GRADIO INTERFACE ---
|
|
@@ -408,7 +568,6 @@ with gr.Blocks(theme=gr.themes.Soft(), title="🎬 TTM Wan Video Generator") as
|
|
| 408 |
gr.Markdown(
|
| 409 |
"Transform standard videos into 3D-aware motion signals for Time-to-Move (TTM) generation.")
|
| 410 |
|
| 411 |
-
# Shared state for TTM files - initialized as empty strings
|
| 412 |
first_frame_file = gr.State("")
|
| 413 |
motion_signal_file = gr.State("")
|
| 414 |
mask_file = gr.State("")
|
|
@@ -426,6 +585,7 @@ with gr.Blocks(theme=gr.themes.Soft(), title="🎬 TTM Wan Video Generator") as
|
|
| 426 |
"🚀 1. Run Spatial Tracker", variant="primary")
|
| 427 |
|
| 428 |
output_video = gr.Video(label="Point Cloud Render (Draft)")
|
|
|
|
| 429 |
status_text = gr.Markdown("Ready...")
|
| 430 |
|
| 431 |
with gr.Column(scale=1):
|
|
@@ -446,19 +606,14 @@ with gr.Blocks(theme=gr.themes.Soft(), title="🎬 TTM Wan Video Generator") as
|
|
| 446 |
wan_output_video = gr.Video(label="Final High-Quality TTM Video")
|
| 447 |
wan_status = gr.Markdown("Awaiting 3D inputs...")
|
| 448 |
|
| 449 |
-
# The Accordion provides a visual check of what TTM is using
|
| 450 |
with gr.Accordion("Debug: TTM Intermediate Inputs", open=False):
|
| 451 |
with gr.Row():
|
| 452 |
-
# IMPORTANT: type="filepath" prevents the ValueError by passing
|
| 453 |
-
# the path string instead of the raw pixel array.
|
| 454 |
motion_signal_output = gr.Video(label="motion_signal.mp4")
|
| 455 |
mask_output = gr.Video(label="mask.mp4")
|
| 456 |
first_frame_output = gr.Image(
|
| 457 |
label="first_frame.png", type="filepath")
|
| 458 |
|
| 459 |
-
# --- Event Handlers ---
|
| 460 |
|
| 461 |
-
# 1. Process 3D Tracking and save results to temporary local files
|
| 462 |
generate_btn.click(
|
| 463 |
fn=process_video,
|
| 464 |
inputs=[video_input, camera_movement],
|
|
@@ -467,23 +622,22 @@ with gr.Blocks(theme=gr.themes.Soft(), title="🎬 TTM Wan Video Generator") as
|
|
| 467 |
motion_signal_output,
|
| 468 |
mask_output,
|
| 469 |
first_frame_output,
|
|
|
|
| 470 |
status_text
|
| 471 |
]
|
| 472 |
).then(
|
| 473 |
-
|
| 474 |
-
# We ignore the 'output_video' (index 0) and 'status_text' (index 4).
|
| 475 |
-
fn=lambda a, b, c, d, e: (b, c, d),
|
| 476 |
inputs=[
|
| 477 |
output_video,
|
| 478 |
motion_signal_output,
|
| 479 |
mask_output,
|
| 480 |
first_frame_output,
|
|
|
|
| 481 |
status_text
|
| 482 |
],
|
| 483 |
outputs=[motion_signal_file, mask_file, first_frame_file]
|
| 484 |
)
|
| 485 |
|
| 486 |
-
# 3. Use the stored paths to run the Wan 2.2 TTM Dual-Clock Denoising loop
|
| 487 |
wan_generate_btn.click(
|
| 488 |
fn=run_wan_ttm_generation,
|
| 489 |
inputs=[
|
|
|
|
| 24 |
from models.SpaTrackV2.models.vggt4track.utils.load_fn import preprocess_image
|
| 25 |
from models.SpaTrackV2.models.predictor import Predictor
|
| 26 |
from models.SpaTrackV2.models.utils import get_points_on_a_grid
|
| 27 |
+
from matplotlib import cm
|
| 28 |
|
|
|
|
|
|
|
| 29 |
from diffusers.utils import export_to_video, load_image
|
| 30 |
+
|
| 31 |
from pipelines.wan_pipeline import WanImageToVideoTTMPipeline
|
| 32 |
+
from pipelines.utils import compute_hw_from_area
|
| 33 |
|
|
|
|
| 34 |
logging.basicConfig(level=logging.INFO)
|
| 35 |
logger = logging.getLogger(__name__)
|
| 36 |
|
|
|
|
| 37 |
MAX_FRAMES = 81
|
| 38 |
OUTPUT_FPS = 24
|
| 39 |
RENDER_WIDTH = 512
|
| 40 |
RENDER_HEIGHT = 384
|
| 41 |
WAN_MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
|
| 42 |
|
|
|
|
| 43 |
CAMERA_MOVEMENTS = [
|
| 44 |
"static",
|
| 45 |
"move_forward",
|
|
|
|
| 50 |
"move_down"
|
| 51 |
]
|
| 52 |
|
|
|
|
| 53 |
thread_pool_executor = ThreadPoolExecutor(max_workers=2)
|
| 54 |
|
| 55 |
|
|
|
|
| 79 |
return temp_dir
|
| 80 |
|
| 81 |
|
|
|
|
| 82 |
print("🚀 Initializing tracking models...")
|
| 83 |
|
| 84 |
vggt4track_model = VGGT4Track.from_pretrained(
|
|
|
|
| 194 |
return {'rendered': output_path, 'motion_signal': motion_signal_path, 'mask': mask_path}
|
| 195 |
|
| 196 |
|
| 197 |
+
def visualize_tracks_with_selective_history(
    video: np.ndarray,                 # (T, H, W, C) uint8
    tracks: np.ndarray,                # (T, N, 2) float - 2D track coordinates
    track_depths: np.ndarray,          # (T, N) float - depth at each track point
    visibility: np.ndarray = None,     # (T, N) float, optional
    output_path: str = None,
    fps: int = 24,
    depth_threshold: float = None,     # Threshold to separate foreground/background
    bg_history_length: int = -1,       # History length for background (-1 = infinite)
    fg_history_length: int = 0,        # History length for foreground (0 = no history)
    linewidth: int = 2,
    point_radius: int = 4,
):
    """
    Visualize tracked points with selective history trails.

    Background points (depth > threshold) keep a history trail; foreground
    points (depth <= threshold) show only their current position by default.

    Args:
        video: Input video frames (T, H, W, C), uint8 RGB.
        tracks: 2D track positions (T, N, 2) as (x, y) pixel coordinates.
        track_depths: Depth value at each track point (T, N).
        visibility: Optional per-point visibility; values below 0.5 are
            treated as hidden. If None, every point is drawn.
        output_path: If given, the annotated frames are also written to this
            path as an mp4.
        fps: Frame rate used when writing ``output_path``.
        depth_threshold: Depth split between FG and BG. If None, the median of
            the positive depths at frame 0 is used (falls back to 1.0 when no
            depth at frame 0 is positive).
        bg_history_length: Trail length for background tracks (-1 = infinite).
        fg_history_length: Trail length for foreground tracks (0 = none).
        linewidth: Width of the trail lines in pixels.
        point_radius: Radius of the current-position marker in pixels.

    Returns:
        Annotated video frames (T, H, W, C), same dtype as ``video``.
    """
    num_frames, H, W, _ = video.shape
    N = tracks.shape[1]

    # Default FG/BG threshold: median of the valid (positive) depths at frame 0.
    if depth_threshold is None:
        valid_depths = track_depths[0][track_depths[0] > 0]
        depth_threshold = float(np.median(valid_depths)) if len(valid_depths) > 0 else 1.0

    # Classify once from the initial depths: larger depth = farther = background.
    is_background = track_depths[0] > depth_threshold

    # Rainbow palette indexed by each track's initial y-coordinate.
    # NOTE(review): cm.get_cmap is deprecated/removed in matplotlib >= 3.9 —
    # confirm the pinned matplotlib version before upgrading.
    color_map = cm.get_cmap("gist_rainbow")
    y0 = tracks[0, :, 1]
    y_min, y_max = y0.min(), y0.max()
    if y_max - y_min < 1e-6:
        y_max = y_min + 1
    # Vectorized: a Matplotlib colormap maps an (N,) array to (N, 4) RGBA.
    norm_y = (y0 - y_min) / (y_max - y_min)
    colors = (np.asarray(color_map(norm_y))[:, :3] * 255).astype(np.uint8)

    def _visible(frame_idx, point_idx):
        # A point is drawable unless the visibility mask marks it hidden.
        return visibility is None or visibility[frame_idx, point_idx] >= 0.5

    def _in_bounds(pt):
        # Reject sentinel (<= 0) and out-of-frame coordinates before drawing.
        return 0 < pt[0] < W and 0 < pt[1] < H

    res_video = video.copy()

    for t in range(num_frames):
        frame = res_video[t].copy()

        # History trails: background gets long/infinite history, foreground
        # gets short/no history.
        for i in range(N):
            history_len = bg_history_length if is_background[i] else fg_history_length

            if history_len < 0:        # infinite history
                start_frame = 0
            elif history_len == 0:     # no history
                start_frame = t
            else:
                start_frame = max(0, t - history_len)

            for j in range(start_frame, t):
                if not (_visible(j, i) and _visible(j + 1, i)):
                    continue

                pt1 = (int(tracks[j, i, 0]), int(tracks[j, i, 1]))
                pt2 = (int(tracks[j + 1, i, 0]), int(tracks[j + 1, i, 1]))
                if not (_in_bounds(pt1) and _in_bounds(pt2)):
                    continue

                # FIX: the original computed this fade factor but never used
                # it, so trails were drawn at full intensity. Older segments
                # are now dimmed toward black ("older = more transparent").
                alpha = (j - start_frame + 1) / (t - start_frame + 1)
                color = [int(c * alpha) for c in colors[i]]
                cv2.line(frame, pt1, pt2, color, linewidth, cv2.LINE_AA)

        # Current positions (filled dots at full intensity).
        for i in range(N):
            if not _visible(t, i):
                continue
            coord = (int(tracks[t, i, 0]), int(tracks[t, i, 1]))
            if not _in_bounds(coord):
                continue
            cv2.circle(frame, coord, point_radius, colors[i].tolist(), -1)

        res_video[t] = frame

    # Optionally persist the annotated frames (RGB -> BGR for OpenCV).
    if output_path:
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (W, H))
        for frame in res_video:
            out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
        out.release()

    return res_video
|
| 334 |
+
|
| 335 |
+
|
| 336 |
@spaces.GPU
|
| 337 |
def run_spatial_tracker(video_tensor: torch.Tensor):
|
| 338 |
"""
|
|
|
|
| 344 |
Returns:
|
| 345 |
Dictionary containing tracking results
|
| 346 |
"""
|
|
|
|
| 347 |
global vggt4track_model
|
| 348 |
global tracker_model
|
| 349 |
global wan_pipeline
|
|
|
|
| 366 |
video_tensor_gpu = video_input.squeeze()
|
| 367 |
unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5
|
| 368 |
|
|
|
|
| 369 |
tracker_model.spatrack.track_num = 512
|
| 370 |
tracker_model.to("cuda")
|
| 371 |
|
|
|
|
| 372 |
frame_H, frame_W = video_tensor_gpu.shape[2:]
|
| 373 |
grid_pts = get_points_on_a_grid(30, (frame_H, frame_W), device="cpu")
|
| 374 |
query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[
|
| 375 |
0].numpy()
|
| 376 |
|
|
|
|
| 377 |
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
|
| 378 |
(
|
| 379 |
c2w_traj, intrs_out, point_map, conf_depth,
|
|
|
|
| 388 |
support_frame=len(video_tensor_gpu)-1, replace_ratio=0.2
|
| 389 |
)
|
| 390 |
|
|
|
|
| 391 |
max_size = 384
|
| 392 |
h, w = video_out.shape[2:]
|
| 393 |
scale = min(max_size / h, max_size / w)
|
|
|
|
| 397 |
point_map = T.Resize((new_h, new_w))(point_map)
|
| 398 |
conf_depth = T.Resize((new_h, new_w))(conf_depth)
|
| 399 |
intrs_out[:, :2, :] = intrs_out[:, :2, :] * scale
|
| 400 |
+
# Scale 2D track coordinates
|
| 401 |
+
track2d_pred[..., :2] = track2d_pred[..., :2] * scale
|
| 402 |
|
|
|
|
| 403 |
return {
|
| 404 |
'video_out': video_out.cpu(),
|
| 405 |
'point_map': point_map.cpu(),
|
| 406 |
'conf_depth': conf_depth.cpu(),
|
| 407 |
'intrs_out': intrs_out.cpu(),
|
| 408 |
'c2w_traj': c2w_traj.cpu(),
|
| 409 |
+
'track2d_pred': track2d_pred.cpu(), # 2D track positions (T, N, 2)
|
| 410 |
+
'track3d_pred': track3d_pred.cpu(), # 3D track positions (T, N, 3)
|
| 411 |
+
'vis_pred': vis_pred.cpu(), # Visibility mask (T, N)
|
| 412 |
+
'conf_pred': conf_pred.cpu(), # Confidence scores (T, N)
|
| 413 |
}
|
| 414 |
|
| 415 |
|
|
|
|
| 434 |
progress(0.2, desc="Preparing inputs...")
|
| 435 |
image = load_image(first_frame_path)
|
| 436 |
|
|
|
|
| 437 |
negative_prompt = (
|
| 438 |
"色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,"
|
| 439 |
"低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,"
|
|
|
|
| 442 |
|
| 443 |
wan_pipeline.to("cuda")
|
| 444 |
|
|
|
|
| 445 |
max_area = 480 * 832
|
| 446 |
mod_value = wan_pipeline.vae_scale_factor_spatial * \
|
| 447 |
wan_pipeline.transformer.config.patch_size[1]
|
|
|
|
| 480 |
|
| 481 |
def process_video(video_path, camera_movement, generate_ttm=True, progress=gr.Progress()):
|
| 482 |
if video_path is None:
|
| 483 |
+
return None, None, None, None, None, "❌ Please upload a video first"
|
| 484 |
|
| 485 |
progress(0, desc="Initializing...")
|
| 486 |
temp_dir = create_user_temp_dir()
|
|
|
|
| 514 |
new_exts = generate_camera_trajectory(len(
|
| 515 |
rgb_frames), camera_movement, tracking_results['intrs_out'].numpy(), scene_scale)
|
| 516 |
|
| 517 |
+
progress(0.7, desc="Visualizing tracks...")
|
| 518 |
+
# Get track data for visualization
|
| 519 |
+
track2d = tracking_results['track2d_pred'].numpy() # (T, N, 2+)
|
| 520 |
+
track3d = tracking_results['track3d_pred'].numpy() # (T, N, 3)
|
| 521 |
+
vis_pred = tracking_results['vis_pred'].numpy() # (T, N)
|
| 522 |
+
|
| 523 |
+
# Get depth at each track point (use Z coordinate from 3D tracks)
|
| 524 |
+
track_depths = track3d[..., 2] # (T, N) - depth is the Z coordinate
|
| 525 |
+
|
| 526 |
+
# Create track visualization with selective history:
|
| 527 |
+
# - Background points (farther): show history trails
|
| 528 |
+
# - Foreground points (closer): no history trails
|
| 529 |
+
track_viz_path = os.path.join(out_dir, "track_visualization.mp4")
|
| 530 |
+
visualize_tracks_with_selective_history(
|
| 531 |
+
video=rgb_frames.copy(),
|
| 532 |
+
tracks=track2d[..., :2], # Use only x, y coordinates
|
| 533 |
+
track_depths=track_depths,
|
| 534 |
+
visibility=vis_pred,
|
| 535 |
+
output_path=track_viz_path,
|
| 536 |
+
fps=OUTPUT_FPS,
|
| 537 |
+
depth_threshold=None, # Auto-compute based on median depth
|
| 538 |
+
bg_history_length=-1, # Infinite history for background
|
| 539 |
+
fg_history_length=0, # No history for foreground
|
| 540 |
+
linewidth=2,
|
| 541 |
+
point_radius=4
|
| 542 |
+
)
|
| 543 |
+
|
| 544 |
progress(0.8, desc="Rendering viewpoint...")
|
| 545 |
output_video_path = os.path.join(out_dir, "rendered_video.mp4")
|
| 546 |
render_results = render_from_pointcloud(rgb_frames, depth_frames, tracking_results['intrs_out'].numpy(),
|
|
|
|
| 553 |
rgb_frames[0], cv2.COLOR_RGB2BGR))
|
| 554 |
|
| 555 |
status_msg = f"✅ 3D results ready! You can now use the prompt below to generate a high-quality TTM video."
|
| 556 |
+
return render_results['rendered'], render_results['motion_signal'], render_results['mask'], first_frame_path, track_viz_path, status_msg
|
| 557 |
|
| 558 |
except Exception as e:
|
| 559 |
logger.error(f"Error: {e}")
|
| 560 |
+
import traceback
|
| 561 |
+
traceback.print_exc()
|
| 562 |
+
return None, None, None, None, None, f"❌ Error: {str(e)}"
|
| 563 |
|
| 564 |
|
| 565 |
# --- GRADIO INTERFACE ---
|
|
|
|
| 568 |
gr.Markdown(
|
| 569 |
"Transform standard videos into 3D-aware motion signals for Time-to-Move (TTM) generation.")
|
| 570 |
|
|
|
|
| 571 |
first_frame_file = gr.State("")
|
| 572 |
motion_signal_file = gr.State("")
|
| 573 |
mask_file = gr.State("")
|
|
|
|
| 585 |
"🚀 1. Run Spatial Tracker", variant="primary")
|
| 586 |
|
| 587 |
output_video = gr.Video(label="Point Cloud Render (Draft)")
|
| 588 |
+
track_viz_output = gr.Video(label="Track Visualization (BG history, FG no history)")
|
| 589 |
status_text = gr.Markdown("Ready...")
|
| 590 |
|
| 591 |
with gr.Column(scale=1):
|
|
|
|
| 606 |
wan_output_video = gr.Video(label="Final High-Quality TTM Video")
|
| 607 |
wan_status = gr.Markdown("Awaiting 3D inputs...")
|
| 608 |
|
|
|
|
| 609 |
with gr.Accordion("Debug: TTM Intermediate Inputs", open=False):
|
| 610 |
with gr.Row():
|
|
|
|
|
|
|
| 611 |
motion_signal_output = gr.Video(label="motion_signal.mp4")
|
| 612 |
mask_output = gr.Video(label="mask.mp4")
|
| 613 |
first_frame_output = gr.Image(
|
| 614 |
label="first_frame.png", type="filepath")
|
| 615 |
|
|
|
|
| 616 |
|
|
|
|
| 617 |
generate_btn.click(
|
| 618 |
fn=process_video,
|
| 619 |
inputs=[video_input, camera_movement],
|
|
|
|
| 622 |
motion_signal_output,
|
| 623 |
mask_output,
|
| 624 |
first_frame_output,
|
| 625 |
+
track_viz_output,
|
| 626 |
status_text
|
| 627 |
]
|
| 628 |
).then(
|
| 629 |
+
fn=lambda a, b, c, d, e, f: (b, c, d),
|
|
|
|
|
|
|
| 630 |
inputs=[
|
| 631 |
output_video,
|
| 632 |
motion_signal_output,
|
| 633 |
mask_output,
|
| 634 |
first_frame_output,
|
| 635 |
+
track_viz_output,
|
| 636 |
status_text
|
| 637 |
],
|
| 638 |
outputs=[motion_signal_file, mask_file, first_frame_file]
|
| 639 |
)
|
| 640 |
|
|
|
|
| 641 |
wan_generate_btn.click(
|
| 642 |
fn=run_wan_ttm_generation,
|
| 643 |
inputs=[
|