import gradio as gr import os import numpy as np import cv2 import time import shutil from pathlib import Path from einops import rearrange from typing import Union try: import spaces except ImportError: def spaces(func): return func import torch import torchvision.transforms as T import logging from concurrent.futures import ThreadPoolExecutor import atexit import uuid import decord from models.SpaTrackV2.models.vggt4track.models.vggt_moe import VGGT4Track from models.SpaTrackV2.models.vggt4track.utils.load_fn import preprocess_image from models.SpaTrackV2.models.predictor import Predictor from models.SpaTrackV2.models.utils import get_points_on_a_grid from diffusers.utils import export_to_video, load_image from pipelines.wan_pipeline import WanImageToVideoTTMPipeline from pipelines.utils import compute_hw_from_area logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) MAX_FRAMES = 81 OUTPUT_FPS = 24 RENDER_WIDTH = 512 RENDER_HEIGHT = 384 WAN_MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers" CAMERA_MOVEMENTS = [ "static", "move_forward", "move_backward", "move_left", "move_right", "move_up", "move_down" ] thread_pool_executor = ThreadPoolExecutor(max_workers=2) def delete_later(path: Union[str, os.PathLike], delay: int = 600): def _delete(): try: if os.path.isfile(path): os.remove(path) elif os.path.isdir(path): shutil.rmtree(path) except Exception as e: logger.warning(f"Failed to delete {path}: {e}") def _wait_and_delete(): time.sleep(delay) _delete() thread_pool_executor.submit(_wait_and_delete) atexit.register(_delete) def create_user_temp_dir(): session_id = str(uuid.uuid4())[:8] temp_dir = os.path.join("temp_local", f"session_{session_id}") os.makedirs(temp_dir, exist_ok=True) delete_later(temp_dir, delay=600) return temp_dir print("🚀 Initializing tracking models...") vggt4track_model = VGGT4Track.from_pretrained( "Yuxihenry/SpatialTrackerV2_Front") vggt4track_model.eval() if not hasattr(vggt4track_model, 'infer'): vggt4track_model.infer = vggt4track_model.forward tracker_model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline") tracker_model.eval() wan_pipeline = WanImageToVideoTTMPipeline.from_pretrained( WAN_MODEL_ID, torch_dtype=torch.bfloat16 ) wan_pipeline.vae.enable_tiling() wan_pipeline.vae.enable_slicing() print("✅ Tracking models loaded successfully!") gr.set_static_paths(paths=[Path.cwd().absolute()/"_viz"]) def generate_camera_trajectory(num_frames: int, movement_type: str, base_intrinsics: np.ndarray, scene_scale: float = 1.0) -> tuple: speed = scene_scale * 0.02 extrinsics = np.zeros((num_frames, 4, 4), dtype=np.float32) for t in range(num_frames): ext = np.eye(4, dtype=np.float32) if movement_type == "move_forward": ext[2, 3] = -speed * t elif movement_type == "move_backward": ext[2, 3] = speed * t elif movement_type == "move_left": ext[0, 3] = -speed * t elif movement_type == "move_right": ext[0, 3] = speed * t elif movement_type == "move_up": ext[1, 3] = -speed * t elif movement_type == "move_down": ext[1, 3] = speed * t extrinsics[t] = ext return extrinsics def render_from_pointcloud(rgb_frames, depth_frames, intrinsics, original_extrinsics, new_extrinsics, output_path, fps=24, generate_ttm_inputs=False): T, H, W, _ = rgb_frames.shape fourcc = cv2.VideoWriter_fourcc(*'mp4v') out = cv2.VideoWriter(output_path, fourcc, fps, (W, H)) motion_signal_path = mask_path = out_motion_signal = out_mask = None if generate_ttm_inputs: base_dir = os.path.dirname(output_path) motion_signal_path = os.path.join(base_dir, "motion_signal.mp4") mask_path = os.path.join(base_dir, "mask.mp4") out_motion_signal = cv2.VideoWriter( motion_signal_path, fourcc, fps, (W, H)) out_mask = cv2.VideoWriter(mask_path, fourcc, fps, (W, H)) u, v = np.meshgrid(np.arange(W), np.arange(H)) for t in range(T): rgb, depth, K = rgb_frames[t], depth_frames[t], intrinsics[t] orig_c2w = np.linalg.inv(original_extrinsics[t]) if t == 0: base_c2w = orig_c2w.copy() new_c2w = base_c2w @ new_extrinsics[t] new_w2c = np.linalg.inv(new_c2w) K_inv = np.linalg.inv(K) pixels = np.stack([u, v, np.ones_like(u)], axis=-1).reshape(-1, 3) rays_cam = (K_inv @ pixels.T).T points_cam = rays_cam * depth.reshape(-1, 1) points_world = (orig_c2w[:3, :3] @ points_cam.T).T + orig_c2w[:3, 3] points_new_cam = (new_w2c[:3, :3] @ points_world.T).T + new_w2c[:3, 3] points_proj = (K @ points_new_cam.T).T z = np.clip(points_proj[:, 2:3], 1e-6, None) uv_new = points_proj[:, :2] / z rendered = np.zeros((H, W, 3), dtype=np.uint8) z_buffer = np.full((H, W), np.inf, dtype=np.float32) colors, depths_new = rgb.reshape(-1, 3), points_new_cam[:, 2] for i in range(len(uv_new)): uu, vv = int(round(uv_new[i, 0])), int(round(uv_new[i, 1])) if 0 <= uu < W and 0 <= vv < H and depths_new[i] > 0: if depths_new[i] < z_buffer[vv, uu]: z_buffer[vv, uu] = depths_new[i] rendered[vv, uu] = colors[i] valid_mask = (rendered.sum(axis=-1) > 0).astype(np.uint8) * 255 motion_signal_frame = rendered.copy() hole_mask = (motion_signal_frame.sum(axis=-1) == 0).astype(np.uint8) if hole_mask.sum() > 0: kernel = np.ones((3, 3), np.uint8) for _ in range(10): # Iterative fill if hole_mask.sum() == 0: break dilated = cv2.dilate(motion_signal_frame, kernel) motion_signal_frame = np.where( hole_mask[:, :, None] > 0, dilated, motion_signal_frame) hole_mask = (motion_signal_frame.sum( axis=-1) == 0).astype(np.uint8) if generate_ttm_inputs: out_motion_signal.write(cv2.cvtColor( motion_signal_frame, cv2.COLOR_RGB2BGR)) out_mask.write(np.stack([valid_mask]*3, axis=-1)) out.write(cv2.cvtColor(motion_signal_frame, cv2.COLOR_RGB2BGR)) out.release() if generate_ttm_inputs: out_motion_signal.release() out_mask.release() return {'rendered': output_path, 'motion_signal': motion_signal_path, 'mask': mask_path} @spaces.GPU def run_spatial_tracker(video_tensor: torch.Tensor): """ GPU-intensive spatial tracking function. Args: video_tensor: Preprocessed video tensor (T, C, H, W) Returns: Dictionary containing tracking results """ global vggt4track_model global tracker_model global wan_pipeline video_input = preprocess_image(video_tensor)[None].cuda() vggt4track_model = vggt4track_model.to("cuda") with torch.no_grad(): with torch.amp.autocast('cuda', dtype=torch.bfloat16): predictions = vggt4track_model(video_input / 255) extrinsic = predictions["poses_pred"] intrinsic = predictions["intrs"] depth_map = predictions["points_map"][..., 2] depth_conf = predictions["unc_metric"] depth_tensor = depth_map.squeeze().cpu().numpy() extrs = extrinsic.squeeze().cpu().numpy() intrs = intrinsic.squeeze().cpu().numpy() video_tensor_gpu = video_input.squeeze() unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5 tracker_model.spatrack.track_num = 512 tracker_model.to("cuda") frame_H, frame_W = video_tensor_gpu.shape[2:] grid_pts = get_points_on_a_grid(30, (frame_H, frame_W), device="cpu") query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[ 0].numpy() with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): ( c2w_traj, intrs_out, point_map, conf_depth, track3d_pred, track2d_pred, vis_pred, conf_pred, video_out ) = tracker_model.forward( video_tensor_gpu, depth=depth_tensor, intrs=intrs, extrs=extrs, queries=query_xyt, fps=1, full_point=False, iters_track=4, query_no_BA=True, fixed_cam=False, stage=1, unc_metric=unc_metric, support_frame=len(video_tensor_gpu)-1, replace_ratio=0.2 ) max_size = 384 h, w = video_out.shape[2:] scale = min(max_size / h, max_size / w) if scale < 1: new_h, new_w = int(h * scale), int(w * scale) video_out = T.Resize((new_h, new_w))(video_out) point_map = T.Resize((new_h, new_w))(point_map) conf_depth = T.Resize((new_h, new_w))(conf_depth) intrs_out[:, :2, :] = intrs_out[:, :2, :] * scale return { 'video_out': video_out.cpu(), 'point_map': point_map.cpu(), 'conf_depth': conf_depth.cpu(), 'intrs_out': intrs_out.cpu(), 'c2w_traj': c2w_traj.cpu(), } @spaces.GPU def run_wan_ttm_generation(prompt, tweak_index, tstrong_index, first_frame_path, motion_video_path, mask_video_path, progress=gr.Progress()): if not first_frame_path or not motion_video_path or not mask_video_path: return None, "❌ TTM Inputs missing. Please run 3D tracking first." progress(0, desc="Loading Wan TTM Pipeline...") import decord vr = decord.VideoReader(motion_video_path) actual_frame_count = len(vr) target_num_frames = ((actual_frame_count - 1) // 4) * 4 + 1 if target_num_frames < 5: return None, f"❌ Video too short. Only {actual_frame_count} frames tracked." logger.info(f"Setting Wan num_frames to {target_num_frames} based on tracking output.") progress(0.2, desc="Preparing inputs...") image = load_image(first_frame_path) negative_prompt = ( "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量," "低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的," "毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" ) wan_pipeline.to("cuda") max_area = 480 * 832 mod_value = wan_pipeline.vae_scale_factor_spatial * \ wan_pipeline.transformer.config.patch_size[1] height, width = compute_hw_from_area( image.height, image.width, max_area, mod_value) image = image.resize((width, height)) progress(0.4, desc=f"Generating {target_num_frames} frames (this may take a few minutes)...") generator = torch.Generator(device="cuda").manual_seed(0) with torch.inference_mode(): result = wan_pipeline( image=image, prompt=prompt, negative_prompt=negative_prompt, height=height, width=width, num_frames=target_num_frames, guidance_scale=3.5, num_inference_steps=50, generator=generator, motion_signal_video_path=motion_video_path, motion_signal_mask_path=mask_video_path, tweak_index=int(tweak_index), tstrong_index=int(tstrong_index), ) output_path = os.path.join(os.path.dirname( first_frame_path), "wan_ttm_output.mp4") export_to_video(result.frames[0], output_path, fps=16) return output_path, f"✅ TTM Video ({target_num_frames} frames) generated successfully!" # --- MODIFIED PROCESS VIDEO TO RETURN FILE PATHS --- def process_video(video_path, camera_movement, generate_ttm=True, progress=gr.Progress()): if video_path is None: return None, None, None, None, "❌ Please upload a video first" progress(0, desc="Initializing...") temp_dir = create_user_temp_dir() out_dir = os.path.join(temp_dir, "results") os.makedirs(out_dir, exist_ok=True) try: progress(0.1, desc="Loading video...") video_reader = decord.VideoReader(video_path) video_tensor = torch.from_numpy(video_reader.get_batch( range(len(video_reader))).asnumpy()).permute(0, 3, 1, 2).float() video_tensor = video_tensor[::max( 1, len(video_tensor)//MAX_FRAMES)][:MAX_FRAMES] h, w = video_tensor.shape[2:] scale = 336 / min(h, w) if scale < 1: video_tensor = T.Resize( (int(h*scale)//2*2, int(w*scale)//2*2))(video_tensor) progress(0.4, desc="Running 3D tracking...") tracking_results = run_spatial_tracker(video_tensor) rgb_frames = rearrange( tracking_results['video_out'].numpy(), "T C H W -> T H W C").astype(np.uint8) depth_frames = tracking_results['point_map'][:, 2].numpy() depth_frames[tracking_results['conf_depth'].numpy() < 0.5] = 0 scene_scale = np.median(depth_frames[depth_frames > 0]) if np.any( depth_frames > 0) else 1.0 new_exts = generate_camera_trajectory(len( rgb_frames), camera_movement, tracking_results['intrs_out'].numpy(), scene_scale) progress(0.8, desc="Rendering viewpoint...") output_video_path = os.path.join(out_dir, "rendered_video.mp4") render_results = render_from_pointcloud(rgb_frames, depth_frames, tracking_results['intrs_out'].numpy(), torch.inverse( tracking_results['c2w_traj']).numpy(), new_exts, output_video_path, fps=OUTPUT_FPS, generate_ttm_inputs=generate_ttm) first_frame_path = os.path.join(out_dir, "first_frame.png") cv2.imwrite(first_frame_path, cv2.cvtColor( rgb_frames[0], cv2.COLOR_RGB2BGR)) status_msg = f"✅ 3D results ready! You can now use the prompt below to generate a high-quality TTM video." return render_results['rendered'], render_results['motion_signal'], render_results['mask'], first_frame_path, status_msg except Exception as e: logger.error(f"Error: {e}") return None, None, None, None, f"❌ Error: {str(e)}" # --- GRADIO INTERFACE --- with gr.Blocks(theme=gr.themes.Soft(), title="🎬 TTM Wan Video Generator") as demo: gr.Markdown("# 🎬 Video to Point Cloud & TTM Wan Generator") gr.Markdown( "Transform standard videos into 3D-aware motion signals for Time-to-Move (TTM) generation.") first_frame_file = gr.State("") motion_signal_file = gr.State("") mask_file = gr.State("") with gr.Row(): with gr.Column(scale=1): gr.Markdown("### 1. Tracking & Viewpoint") video_input = gr.Video(label="Upload Video") camera_movement = gr.Dropdown( choices=CAMERA_MOVEMENTS, value="static", label="Camera Movement" ) generate_btn = gr.Button( "🚀 1. Run Spatial Tracker", variant="primary") output_video = gr.Video(label="Point Cloud Render (Draft)") status_text = gr.Markdown("Ready...") with gr.Column(scale=1): gr.Markdown("### 2. Time-to-Move (Wan 2.2)") ttm_prompt = gr.Textbox( label="Prompt", placeholder="Describe the scene (e.g., 'A monkey walking in the forest, high quality')" ) with gr.Row(): tweak_idx = gr.Number( label="Tweak Index", value=3, precision=0) tstrong_idx = gr.Number( label="Tstrong Index", value=6, precision=0) wan_generate_btn = gr.Button( "✨ 2. Generate TTM Video (Wan)", variant="secondary") wan_output_video = gr.Video(label="Final High-Quality TTM Video") wan_status = gr.Markdown("Awaiting 3D inputs...") with gr.Accordion("Debug: TTM Intermediate Inputs", open=False): with gr.Row(): motion_signal_output = gr.Video(label="motion_signal.mp4") mask_output = gr.Video(label="mask.mp4") first_frame_output = gr.Image( label="first_frame.png", type="filepath") generate_btn.click( fn=process_video, inputs=[video_input, camera_movement], outputs=[ output_video, motion_signal_output, mask_output, first_frame_output, status_text ] ).then( fn=lambda a, b, c, d, e: (b, c, d), inputs=[ output_video, motion_signal_output, mask_output, first_frame_output, status_text ], outputs=[motion_signal_file, mask_file, first_frame_file] ) wan_generate_btn.click( fn=run_wan_ttm_generation, inputs=[ ttm_prompt, tweak_idx, tstrong_idx, first_frame_file, motion_signal_file, mask_file ], outputs=[wan_output_video, wan_status] ) if __name__ == "__main__": demo.launch(share=False)