Spaces:
Sleeping
Sleeping
revert
Browse files- app.py +98 -1226
- src/config.py +0 -44
- src/model_manager.py +0 -62
- src/spatial_pipeline.py +0 -277
- src/ttm_pipeline.py +0 -303
- src/utils.py +0 -57
app.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
import sys
|
| 2 |
import gradio as gr
|
| 3 |
import os
|
| 4 |
import numpy as np
|
|
@@ -8,109 +7,27 @@ import shutil
|
|
| 8 |
from pathlib import Path
|
| 9 |
from einops import rearrange
|
| 10 |
from typing import Union
|
| 11 |
-
|
| 12 |
-
# Force unbuffered output for HF Spaces logs
|
| 13 |
-
os.environ['PYTHONUNBUFFERED'] = '1'
|
| 14 |
-
|
| 15 |
-
# Configure logging FIRST before any other imports
|
| 16 |
-
import logging
|
| 17 |
-
logging.basicConfig(
|
| 18 |
-
level=logging.INFO,
|
| 19 |
-
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
| 20 |
-
handlers=[
|
| 21 |
-
logging.StreamHandler(sys.stdout)
|
| 22 |
-
]
|
| 23 |
-
)
|
| 24 |
-
logger = logging.getLogger(__name__)
|
| 25 |
-
logger.info("=" * 50)
|
| 26 |
-
logger.info("Starting application initialization...")
|
| 27 |
-
logger.info("=" * 50)
|
| 28 |
-
sys.stdout.flush()
|
| 29 |
-
|
| 30 |
try:
|
| 31 |
import spaces
|
| 32 |
-
logger.info("✅ HF Spaces module imported successfully")
|
| 33 |
except ImportError:
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
@staticmethod
|
| 37 |
-
def GPU(func=None, duration=None):
|
| 38 |
-
def decorator(f):
|
| 39 |
-
return f
|
| 40 |
-
return decorator if func is None else func
|
| 41 |
-
sys.stdout.flush()
|
| 42 |
-
|
| 43 |
-
logger.info("Importing torch...")
|
| 44 |
-
sys.stdout.flush()
|
| 45 |
import torch
|
| 46 |
-
logger.info(f"✅ Torch imported. Version: {torch.__version__}, CUDA available: {torch.cuda.is_available()}")
|
| 47 |
-
sys.stdout.flush()
|
| 48 |
-
|
| 49 |
-
import torch.nn.functional as F
|
| 50 |
import torchvision.transforms as T
|
|
|
|
| 51 |
from concurrent.futures import ThreadPoolExecutor
|
| 52 |
import atexit
|
| 53 |
import uuid
|
| 54 |
-
|
| 55 |
-
logger.info("Importing decord...")
|
| 56 |
-
sys.stdout.flush()
|
| 57 |
import decord
|
| 58 |
-
logger.info("✅ Decord imported successfully")
|
| 59 |
-
sys.stdout.flush()
|
| 60 |
|
| 61 |
-
from
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
from models.SpaTrackV2.models.vggt4track.models.vggt_moe import VGGT4Track
|
| 67 |
-
from models.SpaTrackV2.models.vggt4track.utils.load_fn import preprocess_image
|
| 68 |
-
from models.SpaTrackV2.models.predictor import Predictor
|
| 69 |
-
from models.SpaTrackV2.models.utils import get_points_on_a_grid
|
| 70 |
-
logger.info("✅ SpaTrack models imported successfully")
|
| 71 |
-
except Exception as e:
|
| 72 |
-
logger.error(f"❌ Failed to import SpaTrack models: {e}")
|
| 73 |
-
raise
|
| 74 |
-
sys.stdout.flush()
|
| 75 |
-
|
| 76 |
-
# TTM imports (optional - will be loaded on demand)
|
| 77 |
-
logger.info("Checking TTM (diffusers) availability...")
|
| 78 |
-
sys.stdout.flush()
|
| 79 |
-
TTM_COG_AVAILABLE = False
|
| 80 |
-
TTM_WAN_AVAILABLE = False
|
| 81 |
-
try:
|
| 82 |
-
from diffusers import CogVideoXImageToVideoPipeline
|
| 83 |
-
from diffusers.utils import export_to_video, load_image
|
| 84 |
-
from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
|
| 85 |
-
from diffusers.utils.torch_utils import randn_tensor
|
| 86 |
-
from diffusers.video_processor import VideoProcessor
|
| 87 |
-
TTM_COG_AVAILABLE = True
|
| 88 |
-
logger.info("✅ CogVideoX TTM available")
|
| 89 |
-
except ImportError as e:
|
| 90 |
-
logger.info(f"ℹ️ CogVideoX TTM not available: {e}")
|
| 91 |
-
sys.stdout.flush()
|
| 92 |
-
|
| 93 |
-
try:
|
| 94 |
-
from diffusers import AutoencoderKLWan, WanTransformer3DModel
|
| 95 |
-
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
| 96 |
-
from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline, retrieve_latents
|
| 97 |
-
from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
|
| 98 |
-
if not TTM_COG_AVAILABLE:
|
| 99 |
-
from diffusers.utils import export_to_video, load_image
|
| 100 |
-
from diffusers.utils.torch_utils import randn_tensor
|
| 101 |
-
from diffusers.video_processor import VideoProcessor
|
| 102 |
-
TTM_WAN_AVAILABLE = True
|
| 103 |
-
logger.info("✅ Wan TTM available")
|
| 104 |
-
except ImportError as e:
|
| 105 |
-
logger.info(f"ℹ️ Wan TTM not available: {e}")
|
| 106 |
-
sys.stdout.flush()
|
| 107 |
-
|
| 108 |
-
TTM_AVAILABLE = TTM_COG_AVAILABLE or TTM_WAN_AVAILABLE
|
| 109 |
-
if not TTM_AVAILABLE:
|
| 110 |
-
logger.warning("⚠️ Diffusers not available. TTM features will be disabled.")
|
| 111 |
-
else:
|
| 112 |
-
logger.info(f"TTM Status - CogVideoX: {TTM_COG_AVAILABLE}, Wan: {TTM_WAN_AVAILABLE}")
|
| 113 |
-
sys.stdout.flush()
|
| 114 |
|
| 115 |
# Constants
|
| 116 |
MAX_FRAMES = 80
|
|
@@ -129,126 +46,9 @@ CAMERA_MOVEMENTS = [
|
|
| 129 |
"move_down"
|
| 130 |
]
|
| 131 |
|
| 132 |
-
# TTM Constants
|
| 133 |
-
TTM_COG_MODEL_ID = "THUDM/CogVideoX-5b-I2V"
|
| 134 |
-
TTM_WAN_MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
|
| 135 |
-
TTM_DTYPE = torch.bfloat16
|
| 136 |
-
TTM_DEFAULT_NUM_FRAMES = 49
|
| 137 |
-
TTM_DEFAULT_NUM_INFERENCE_STEPS = 50
|
| 138 |
-
|
| 139 |
-
# TTM Model choices
|
| 140 |
-
TTM_MODELS = []
|
| 141 |
-
if TTM_COG_AVAILABLE:
|
| 142 |
-
TTM_MODELS.append("CogVideoX-5B")
|
| 143 |
-
if TTM_WAN_AVAILABLE:
|
| 144 |
-
TTM_MODELS.append("Wan2.2-14B (Recommended)")
|
| 145 |
-
|
| 146 |
-
# Global model instances (lazy loaded for HF Spaces GPU compatibility)
|
| 147 |
-
vggt4track_model = None
|
| 148 |
-
tracker_model = None
|
| 149 |
-
ttm_cog_pipeline = None
|
| 150 |
-
ttm_wan_pipeline = None
|
| 151 |
-
MODELS_LOADED = False
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
def load_video_to_tensor(video_path: str) -> torch.Tensor:
|
| 155 |
-
"""Returns a video tensor from a video file. shape [1, C, T, H, W], [0, 1] range."""
|
| 156 |
-
cap = cv2.VideoCapture(video_path)
|
| 157 |
-
frames = []
|
| 158 |
-
while True:
|
| 159 |
-
ret, frame = cap.read()
|
| 160 |
-
if not ret:
|
| 161 |
-
break
|
| 162 |
-
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 163 |
-
frames.append(frame)
|
| 164 |
-
cap.release()
|
| 165 |
-
|
| 166 |
-
frames = np.array(frames)
|
| 167 |
-
video_tensor = torch.tensor(frames)
|
| 168 |
-
video_tensor = video_tensor.permute(0, 3, 1, 2).float() / 255.0
|
| 169 |
-
video_tensor = video_tensor.unsqueeze(0).permute(0, 2, 1, 3, 4)
|
| 170 |
-
return video_tensor
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
def get_ttm_cog_pipeline():
|
| 174 |
-
"""Lazy load CogVideoX TTM pipeline to save memory."""
|
| 175 |
-
global ttm_cog_pipeline
|
| 176 |
-
if ttm_cog_pipeline is None and TTM_COG_AVAILABLE:
|
| 177 |
-
logger.info("Loading TTM CogVideoX pipeline...")
|
| 178 |
-
ttm_cog_pipeline = CogVideoXImageToVideoPipeline.from_pretrained(
|
| 179 |
-
TTM_COG_MODEL_ID,
|
| 180 |
-
torch_dtype=TTM_DTYPE,
|
| 181 |
-
low_cpu_mem_usage=True,
|
| 182 |
-
)
|
| 183 |
-
ttm_cog_pipeline.vae.enable_tiling()
|
| 184 |
-
ttm_cog_pipeline.vae.enable_slicing()
|
| 185 |
-
logger.info("TTM CogVideoX pipeline loaded successfully!")
|
| 186 |
-
return ttm_cog_pipeline
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
def get_ttm_wan_pipeline():
|
| 190 |
-
"""Lazy load Wan TTM pipeline to save memory."""
|
| 191 |
-
global ttm_wan_pipeline
|
| 192 |
-
if ttm_wan_pipeline is None and TTM_WAN_AVAILABLE:
|
| 193 |
-
logger.info("Loading TTM Wan 2.2 pipeline...")
|
| 194 |
-
ttm_wan_pipeline = WanImageToVideoPipeline.from_pretrained(
|
| 195 |
-
TTM_WAN_MODEL_ID,
|
| 196 |
-
torch_dtype=TTM_DTYPE,
|
| 197 |
-
)
|
| 198 |
-
ttm_wan_pipeline.vae.enable_tiling()
|
| 199 |
-
ttm_wan_pipeline.vae.enable_slicing()
|
| 200 |
-
logger.info("TTM Wan 2.2 pipeline loaded successfully!")
|
| 201 |
-
return ttm_wan_pipeline
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
logger.info("Setting up thread pool and utility functions...")
|
| 205 |
-
sys.stdout.flush()
|
| 206 |
-
|
| 207 |
# Thread pool for delayed deletion
|
| 208 |
thread_pool_executor = ThreadPoolExecutor(max_workers=2)
|
| 209 |
|
| 210 |
-
|
| 211 |
-
def load_models():
|
| 212 |
-
"""Load models lazily when GPU is available (inside @spaces.GPU decorated function)."""
|
| 213 |
-
global vggt4track_model, tracker_model, MODELS_LOADED
|
| 214 |
-
|
| 215 |
-
if MODELS_LOADED:
|
| 216 |
-
logger.info("Models already loaded, skipping...")
|
| 217 |
-
return
|
| 218 |
-
|
| 219 |
-
logger.info("🚀 Starting model loading...")
|
| 220 |
-
sys.stdout.flush()
|
| 221 |
-
|
| 222 |
-
try:
|
| 223 |
-
logger.info("Loading VGGT4Track model from 'Yuxihenry/SpatialTrackerV2_Front'...")
|
| 224 |
-
sys.stdout.flush()
|
| 225 |
-
vggt4track_model = VGGT4Track.from_pretrained("Yuxihenry/SpatialTrackerV2_Front")
|
| 226 |
-
vggt4track_model.eval()
|
| 227 |
-
logger.info("✅ VGGT4Track model loaded, moving to CUDA...")
|
| 228 |
-
sys.stdout.flush()
|
| 229 |
-
vggt4track_model = vggt4track_model.to("cuda")
|
| 230 |
-
logger.info("✅ VGGT4Track model on CUDA")
|
| 231 |
-
sys.stdout.flush()
|
| 232 |
-
|
| 233 |
-
logger.info("Loading Predictor model from 'Yuxihenry/SpatialTrackerV2-Offline'...")
|
| 234 |
-
sys.stdout.flush()
|
| 235 |
-
tracker_model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
|
| 236 |
-
tracker_model.eval()
|
| 237 |
-
logger.info("✅ Predictor model loaded")
|
| 238 |
-
sys.stdout.flush()
|
| 239 |
-
|
| 240 |
-
MODELS_LOADED = True
|
| 241 |
-
logger.info("✅ All models loaded successfully!")
|
| 242 |
-
sys.stdout.flush()
|
| 243 |
-
|
| 244 |
-
except Exception as e:
|
| 245 |
-
logger.error(f"❌ Failed to load models: {e}")
|
| 246 |
-
import traceback
|
| 247 |
-
traceback.print_exc()
|
| 248 |
-
sys.stdout.flush()
|
| 249 |
-
raise
|
| 250 |
-
|
| 251 |
-
|
| 252 |
def delete_later(path: Union[str, os.PathLike], delay: int = 600):
|
| 253 |
"""Delete file or directory after specified delay"""
|
| 254 |
def _delete():
|
|
@@ -267,7 +67,6 @@ def delete_later(path: Union[str, os.PathLike], delay: int = 600):
|
|
| 267 |
thread_pool_executor.submit(_wait_and_delete)
|
| 268 |
atexit.register(_delete)
|
| 269 |
|
| 270 |
-
|
| 271 |
def create_user_temp_dir():
|
| 272 |
"""Create a unique temporary directory for each user session"""
|
| 273 |
session_id = str(uuid.uuid4())[:8]
|
|
@@ -276,16 +75,17 @@ def create_user_temp_dir():
|
|
| 276 |
delete_later(temp_dir, delay=600)
|
| 277 |
return temp_dir
|
| 278 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
sys.stdout.flush()
|
| 284 |
|
| 285 |
-
logger.info("Setting up Gradio static paths...")
|
| 286 |
gr.set_static_paths(paths=[Path.cwd().absolute()/"_viz"])
|
| 287 |
-
logger.info("✅ Static paths configured")
|
| 288 |
-
sys.stdout.flush()
|
| 289 |
|
| 290 |
|
| 291 |
def generate_camera_trajectory(num_frames: int, movement_type: str,
|
|
@@ -311,8 +111,7 @@ def generate_camera_trajectory(num_frames: int, movement_type: str,
|
|
| 311 |
if movement_type == "static":
|
| 312 |
pass # Keep identity
|
| 313 |
elif movement_type == "move_forward":
|
| 314 |
-
# Move along -Z (forward in OpenGL convention)
|
| 315 |
-
ext[2, 3] = -speed * t
|
| 316 |
elif movement_type == "move_backward":
|
| 317 |
ext[2, 3] = speed * t # Move along +Z
|
| 318 |
elif movement_type == "move_left":
|
|
@@ -369,8 +168,7 @@ def render_from_pointcloud(rgb_frames: np.ndarray,
|
|
| 369 |
base_dir = os.path.dirname(output_path)
|
| 370 |
motion_signal_path = os.path.join(base_dir, "motion_signal.mp4")
|
| 371 |
mask_path = os.path.join(base_dir, "mask.mp4")
|
| 372 |
-
out_motion_signal = cv2.VideoWriter(
|
| 373 |
-
motion_signal_path, fourcc, fps, (W, H))
|
| 374 |
out_mask = cv2.VideoWriter(mask_path, fourcc, fps, (W, H))
|
| 375 |
|
| 376 |
# Create meshgrid for pixel coordinates
|
|
@@ -450,21 +248,17 @@ def render_from_pointcloud(rgb_frames: np.ndarray,
|
|
| 450 |
if hole_mask.sum() == 0:
|
| 451 |
break
|
| 452 |
dilated = cv2.dilate(motion_signal_frame, kernel, iterations=1)
|
| 453 |
-
motion_signal_frame = np.where(
|
| 454 |
-
|
| 455 |
-
hole_mask = (motion_signal_frame.sum(
|
| 456 |
-
axis=-1) == 0).astype(np.uint8)
|
| 457 |
|
| 458 |
# Write TTM outputs if enabled
|
| 459 |
if generate_ttm_inputs:
|
| 460 |
# Motion signal: warped frame with NN inpainting
|
| 461 |
-
motion_signal_bgr = cv2.cvtColor(
|
| 462 |
-
motion_signal_frame, cv2.COLOR_RGB2BGR)
|
| 463 |
out_motion_signal.write(motion_signal_bgr)
|
| 464 |
|
| 465 |
# Mask: binary mask of valid (projected) pixels - white where valid, black where holes
|
| 466 |
-
mask_frame = np.stack(
|
| 467 |
-
[valid_mask, valid_mask, valid_mask], axis=-1)
|
| 468 |
out_mask.write(mask_frame)
|
| 469 |
|
| 470 |
# For the rendered output, use the same inpainted result
|
|
@@ -484,7 +278,7 @@ def render_from_pointcloud(rgb_frames: np.ndarray,
|
|
| 484 |
}
|
| 485 |
|
| 486 |
|
| 487 |
-
@spaces.GPU
|
| 488 |
def run_spatial_tracker(video_tensor: torch.Tensor):
|
| 489 |
"""
|
| 490 |
GPU-intensive spatial tracking function.
|
|
@@ -495,23 +289,9 @@ def run_spatial_tracker(video_tensor: torch.Tensor):
|
|
| 495 |
Returns:
|
| 496 |
Dictionary containing tracking results
|
| 497 |
"""
|
| 498 |
-
global vggt4track_model, tracker_model
|
| 499 |
-
|
| 500 |
-
logger.info("run_spatial_tracker: Starting GPU execution...")
|
| 501 |
-
sys.stdout.flush()
|
| 502 |
-
|
| 503 |
-
# Load models if not already loaded (lazy loading for HF Spaces)
|
| 504 |
-
load_models()
|
| 505 |
-
|
| 506 |
-
logger.info("run_spatial_tracker: Preprocessing video input...")
|
| 507 |
-
sys.stdout.flush()
|
| 508 |
-
|
| 509 |
# Run VGGT to get depth and camera poses
|
| 510 |
video_input = preprocess_image(video_tensor)[None].cuda()
|
| 511 |
|
| 512 |
-
logger.info("run_spatial_tracker: Running VGGT inference...")
|
| 513 |
-
sys.stdout.flush()
|
| 514 |
-
|
| 515 |
with torch.no_grad():
|
| 516 |
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 517 |
predictions = vggt4track_model(video_input / 255)
|
|
@@ -520,9 +300,6 @@ def run_spatial_tracker(video_tensor: torch.Tensor):
|
|
| 520 |
depth_map = predictions["points_map"][..., 2]
|
| 521 |
depth_conf = predictions["unc_metric"]
|
| 522 |
|
| 523 |
-
logger.info("run_spatial_tracker: VGGT inference complete")
|
| 524 |
-
sys.stdout.flush()
|
| 525 |
-
|
| 526 |
depth_tensor = depth_map.squeeze().cpu().numpy()
|
| 527 |
extrs = extrinsic.squeeze().cpu().numpy()
|
| 528 |
intrs = intrinsic.squeeze().cpu().numpy()
|
|
@@ -530,20 +307,13 @@ def run_spatial_tracker(video_tensor: torch.Tensor):
|
|
| 530 |
unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5
|
| 531 |
|
| 532 |
# Setup tracker
|
| 533 |
-
logger.info("run_spatial_tracker: Setting up tracker...")
|
| 534 |
-
sys.stdout.flush()
|
| 535 |
-
|
| 536 |
tracker_model.spatrack.track_num = 512
|
| 537 |
tracker_model.to("cuda")
|
| 538 |
|
| 539 |
# Get grid points for tracking
|
| 540 |
frame_H, frame_W = video_tensor_gpu.shape[2:]
|
| 541 |
grid_pts = get_points_on_a_grid(30, (frame_H, frame_W), device="cpu")
|
| 542 |
-
query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[
|
| 543 |
-
0].numpy()
|
| 544 |
-
|
| 545 |
-
logger.info("run_spatial_tracker: Running 3D tracker...")
|
| 546 |
-
sys.stdout.flush()
|
| 547 |
|
| 548 |
# Run tracker
|
| 549 |
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
|
|
@@ -571,11 +341,8 @@ def run_spatial_tracker(video_tensor: torch.Tensor):
|
|
| 571 |
conf_depth = T.Resize((new_h, new_w))(conf_depth)
|
| 572 |
intrs_out[:, :2, :] = intrs_out[:, :2, :] * scale
|
| 573 |
|
| 574 |
-
logger.info("run_spatial_tracker: Moving results to CPU...")
|
| 575 |
-
sys.stdout.flush()
|
| 576 |
-
|
| 577 |
# Move results to CPU and return
|
| 578 |
-
|
| 579 |
'video_out': video_out.cpu(),
|
| 580 |
'point_map': point_map.cpu(),
|
| 581 |
'conf_depth': conf_depth.cpu(),
|
|
@@ -583,11 +350,6 @@ def run_spatial_tracker(video_tensor: torch.Tensor):
|
|
| 583 |
'c2w_traj': c2w_traj.cpu(),
|
| 584 |
}
|
| 585 |
|
| 586 |
-
logger.info("run_spatial_tracker: Complete!")
|
| 587 |
-
sys.stdout.flush()
|
| 588 |
-
|
| 589 |
-
return result
|
| 590 |
-
|
| 591 |
|
| 592 |
def process_video(video_path: str, camera_movement: str, generate_ttm: bool = True, progress=gr.Progress()):
|
| 593 |
"""Main processing function
|
|
@@ -643,8 +405,7 @@ def process_video(video_path: str, camera_movement: str, generate_ttm: bool = Tr
|
|
| 643 |
c2w_traj = tracking_results['c2w_traj']
|
| 644 |
|
| 645 |
# Get RGB frames and depth
|
| 646 |
-
rgb_frames = rearrange(
|
| 647 |
-
video_out.numpy(), "T C H W -> T H W C").astype(np.uint8)
|
| 648 |
depth_frames = point_map[:, 2].numpy()
|
| 649 |
depth_conf_np = conf_depth.numpy()
|
| 650 |
|
|
@@ -655,8 +416,7 @@ def process_video(video_path: str, camera_movement: str, generate_ttm: bool = Tr
|
|
| 655 |
intrs_np = intrs_out.numpy()
|
| 656 |
extrs_np = torch.inverse(c2w_traj).numpy() # world-to-camera
|
| 657 |
|
| 658 |
-
progress(
|
| 659 |
-
0.7, desc=f"Generating {camera_movement} camera trajectory...")
|
| 660 |
|
| 661 |
# Calculate scene scale from depth
|
| 662 |
valid_depth = depth_frames[depth_frames > 0]
|
|
@@ -711,995 +471,107 @@ def process_video(video_path: str, camera_movement: str, generate_ttm: bool = Tr
|
|
| 711 |
return None, None, None, None, f"❌ Error: {str(e)}"
|
| 712 |
|
| 713 |
|
| 714 |
-
# TTM CogVideoX Pipeline Helper Classes and Functions
|
| 715 |
-
class CogVideoXTTMHelper:
|
| 716 |
-
"""Helper class for TTM-style video generation using CogVideoX pipeline."""
|
| 717 |
-
|
| 718 |
-
def __init__(self, pipeline):
|
| 719 |
-
self.pipeline = pipeline
|
| 720 |
-
self.vae = pipeline.vae
|
| 721 |
-
self.transformer = pipeline.transformer
|
| 722 |
-
self.scheduler = pipeline.scheduler
|
| 723 |
-
self.vae_scale_factor_spatial = 2 ** (
|
| 724 |
-
len(self.vae.config.block_out_channels) - 1)
|
| 725 |
-
self.vae_scale_factor_temporal = self.vae.config.temporal_compression_ratio
|
| 726 |
-
self.vae_scaling_factor_image = self.vae.config.scaling_factor
|
| 727 |
-
self.video_processor = pipeline.video_processor
|
| 728 |
-
|
| 729 |
-
@torch.no_grad()
|
| 730 |
-
def encode_frames(self, frames: torch.Tensor) -> torch.Tensor:
|
| 731 |
-
"""Encode video frames into latent space. Input shape (B, C, F, H, W), expected range [-1, 1]."""
|
| 732 |
-
latents = self.vae.encode(frames)[0].sample()
|
| 733 |
-
latents = latents * self.vae_scaling_factor_image
|
| 734 |
-
# (B, C, F, H, W) -> (B, F, C, H, W)
|
| 735 |
-
return latents.permute(0, 2, 1, 3, 4).contiguous()
|
| 736 |
-
|
| 737 |
-
def convert_rgb_mask_to_latent_mask(self, mask: torch.Tensor) -> torch.Tensor:
|
| 738 |
-
"""Convert a per-frame mask [T, 1, H, W] to latent resolution [1, T_latent, 1, H', W']."""
|
| 739 |
-
k = self.vae_scale_factor_temporal
|
| 740 |
-
|
| 741 |
-
mask0 = mask[0:1]
|
| 742 |
-
mask1 = mask[1::k]
|
| 743 |
-
sampled = torch.cat([mask0, mask1], dim=0)
|
| 744 |
-
pooled = sampled.permute(1, 0, 2, 3).unsqueeze(0)
|
| 745 |
-
|
| 746 |
-
s = self.vae_scale_factor_spatial
|
| 747 |
-
H_latent = pooled.shape[-2] // s
|
| 748 |
-
W_latent = pooled.shape[-1] // s
|
| 749 |
-
pooled = F.interpolate(pooled, size=(
|
| 750 |
-
pooled.shape[2], H_latent, W_latent), mode="nearest")
|
| 751 |
-
|
| 752 |
-
latent_mask = pooled.permute(0, 2, 1, 3, 4)
|
| 753 |
-
return latent_mask
|
| 754 |
-
|
| 755 |
-
|
| 756 |
-
# TTM Wan Pipeline Helper Class
|
| 757 |
-
class WanTTMHelper:
|
| 758 |
-
"""Helper class for TTM-style video generation using Wan pipeline."""
|
| 759 |
-
|
| 760 |
-
def __init__(self, pipeline):
|
| 761 |
-
self.pipeline = pipeline
|
| 762 |
-
self.vae = pipeline.vae
|
| 763 |
-
self.transformer = pipeline.transformer
|
| 764 |
-
self.scheduler = pipeline.scheduler
|
| 765 |
-
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal
|
| 766 |
-
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial
|
| 767 |
-
self.video_processor = pipeline.video_processor
|
| 768 |
-
|
| 769 |
-
def convert_rgb_mask_to_latent_mask(self, mask: torch.Tensor) -> torch.Tensor:
|
| 770 |
-
"""Convert a per-frame mask [T, 1, H, W] to latent resolution [1, T_latent, 1, H', W']."""
|
| 771 |
-
k = self.vae_scale_factor_temporal
|
| 772 |
-
|
| 773 |
-
mask0 = mask[0:1]
|
| 774 |
-
mask1 = mask[1::k]
|
| 775 |
-
sampled = torch.cat([mask0, mask1], dim=0)
|
| 776 |
-
pooled = sampled.permute(1, 0, 2, 3).unsqueeze(0)
|
| 777 |
-
|
| 778 |
-
s = self.vae_scale_factor_spatial
|
| 779 |
-
H_latent = pooled.shape[-2] // s
|
| 780 |
-
W_latent = pooled.shape[-1] // s
|
| 781 |
-
pooled = F.interpolate(pooled, size=(
|
| 782 |
-
pooled.shape[2], H_latent, W_latent), mode="nearest")
|
| 783 |
-
|
| 784 |
-
latent_mask = pooled.permute(0, 2, 1, 3, 4)
|
| 785 |
-
return latent_mask
|
| 786 |
-
|
| 787 |
-
|
| 788 |
-
def compute_hw_from_area(image_height: int, image_width: int, max_area: int, mod_value: int) -> tuple:
|
| 789 |
-
"""Compute (height, width) with proper aspect ratio and rounding."""
|
| 790 |
-
aspect_ratio = image_height / image_width
|
| 791 |
-
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
|
| 792 |
-
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
|
| 793 |
-
return int(height), int(width)
|
| 794 |
-
|
| 795 |
-
|
| 796 |
-
@spaces.GPU(duration=300)
|
| 797 |
-
def run_ttm_cog_generation(
|
| 798 |
-
first_frame_path: str,
|
| 799 |
-
motion_signal_path: str,
|
| 800 |
-
mask_path: str,
|
| 801 |
-
prompt: str,
|
| 802 |
-
tweak_index: int = 4,
|
| 803 |
-
tstrong_index: int = 9,
|
| 804 |
-
num_frames: int = 49,
|
| 805 |
-
num_inference_steps: int = 50,
|
| 806 |
-
guidance_scale: float = 6.0,
|
| 807 |
-
seed: int = 0,
|
| 808 |
-
progress=gr.Progress()
|
| 809 |
-
):
|
| 810 |
-
"""
|
| 811 |
-
Run TTM-style video generation using CogVideoX pipeline.
|
| 812 |
-
Uses the generated motion signal and mask to guide video generation.
|
| 813 |
-
"""
|
| 814 |
-
if not TTM_COG_AVAILABLE:
|
| 815 |
-
return None, "❌ CogVideoX TTM is not available. Please install diffusers package."
|
| 816 |
-
|
| 817 |
-
if first_frame_path is None or motion_signal_path is None or mask_path is None:
|
| 818 |
-
return None, "❌ Please generate TTM inputs first (first_frame, motion_signal, mask)"
|
| 819 |
-
|
| 820 |
-
progress(0, desc="Loading CogVideoX TTM pipeline...")
|
| 821 |
-
|
| 822 |
-
try:
|
| 823 |
-
# Get or load the pipeline
|
| 824 |
-
pipe = get_ttm_cog_pipeline()
|
| 825 |
-
if pipe is None:
|
| 826 |
-
return None, "❌ Failed to load CogVideoX TTM pipeline"
|
| 827 |
-
|
| 828 |
-
pipe = pipe.to("cuda")
|
| 829 |
-
|
| 830 |
-
# Create helper
|
| 831 |
-
ttm_helper = CogVideoXTTMHelper(pipe)
|
| 832 |
-
|
| 833 |
-
progress(0.1, desc="Loading inputs...")
|
| 834 |
-
|
| 835 |
-
# Load first frame
|
| 836 |
-
image = load_image(first_frame_path)
|
| 837 |
-
|
| 838 |
-
# Get dimensions
|
| 839 |
-
height = pipe.transformer.config.sample_height * \
|
| 840 |
-
ttm_helper.vae_scale_factor_spatial
|
| 841 |
-
width = pipe.transformer.config.sample_width * \
|
| 842 |
-
ttm_helper.vae_scale_factor_spatial
|
| 843 |
-
|
| 844 |
-
device = "cuda"
|
| 845 |
-
generator = torch.Generator(device=device).manual_seed(seed)
|
| 846 |
-
|
| 847 |
-
progress(0.15, desc="Encoding prompt...")
|
| 848 |
-
|
| 849 |
-
# Encode prompt
|
| 850 |
-
do_classifier_free_guidance = guidance_scale > 1.0
|
| 851 |
-
prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
|
| 852 |
-
prompt=prompt,
|
| 853 |
-
negative_prompt="",
|
| 854 |
-
do_classifier_free_guidance=do_classifier_free_guidance,
|
| 855 |
-
num_videos_per_prompt=1,
|
| 856 |
-
max_sequence_length=226,
|
| 857 |
-
device=device,
|
| 858 |
-
)
|
| 859 |
-
if do_classifier_free_guidance:
|
| 860 |
-
prompt_embeds = torch.cat(
|
| 861 |
-
[negative_prompt_embeds, prompt_embeds], dim=0)
|
| 862 |
-
|
| 863 |
-
progress(0.2, desc="Preparing latents...")
|
| 864 |
-
|
| 865 |
-
# Prepare timesteps
|
| 866 |
-
pipe.scheduler.set_timesteps(num_inference_steps, device=device)
|
| 867 |
-
timesteps = pipe.scheduler.timesteps
|
| 868 |
-
|
| 869 |
-
# Prepare latents
|
| 870 |
-
latent_frames = (
|
| 871 |
-
num_frames - 1) // ttm_helper.vae_scale_factor_temporal + 1
|
| 872 |
-
|
| 873 |
-
# Handle padding for CogVideoX 1.5
|
| 874 |
-
patch_size_t = pipe.transformer.config.patch_size_t
|
| 875 |
-
additional_frames = 0
|
| 876 |
-
if patch_size_t is not None and latent_frames % patch_size_t != 0:
|
| 877 |
-
additional_frames = patch_size_t - latent_frames % patch_size_t
|
| 878 |
-
num_frames += additional_frames * ttm_helper.vae_scale_factor_temporal
|
| 879 |
-
|
| 880 |
-
# Preprocess image
|
| 881 |
-
image_tensor = ttm_helper.video_processor.preprocess(image, height=height, width=width).to(
|
| 882 |
-
device, dtype=prompt_embeds.dtype
|
| 883 |
-
)
|
| 884 |
-
|
| 885 |
-
latent_channels = pipe.transformer.config.in_channels // 2
|
| 886 |
-
latents, image_latents = pipe.prepare_latents(
|
| 887 |
-
image_tensor,
|
| 888 |
-
1, # batch_size
|
| 889 |
-
latent_channels,
|
| 890 |
-
num_frames,
|
| 891 |
-
height,
|
| 892 |
-
width,
|
| 893 |
-
prompt_embeds.dtype,
|
| 894 |
-
device,
|
| 895 |
-
generator,
|
| 896 |
-
None,
|
| 897 |
-
)
|
| 898 |
-
|
| 899 |
-
progress(0.3, desc="Loading motion signal and mask...")
|
| 900 |
-
|
| 901 |
-
# Load motion signal video
|
| 902 |
-
ref_vid = load_video_to_tensor(motion_signal_path).to(device=device)
|
| 903 |
-
refB, refC, refT, refH, refW = ref_vid.shape
|
| 904 |
-
ref_vid = F.interpolate(
|
| 905 |
-
ref_vid.permute(0, 2, 1, 3, 4).reshape(
|
| 906 |
-
refB*refT, refC, refH, refW),
|
| 907 |
-
size=(height, width), mode="bicubic", align_corners=True,
|
| 908 |
-
).reshape(refB, refT, refC, height, width).permute(0, 2, 1, 3, 4)
|
| 909 |
-
|
| 910 |
-
ref_vid = ttm_helper.video_processor.normalize(
|
| 911 |
-
ref_vid.to(dtype=pipe.vae.dtype))
|
| 912 |
-
ref_latents = ttm_helper.encode_frames(ref_vid).float().detach()
|
| 913 |
-
|
| 914 |
-
# Load mask video
|
| 915 |
-
ref_mask = load_video_to_tensor(mask_path).to(device=device)
|
| 916 |
-
mB, mC, mT, mH, mW = ref_mask.shape
|
| 917 |
-
ref_mask = F.interpolate(
|
| 918 |
-
ref_mask.permute(0, 2, 1, 3, 4).reshape(mB*mT, mC, mH, mW),
|
| 919 |
-
size=(height, width), mode="nearest",
|
| 920 |
-
).reshape(mB, mT, mC, height, width).permute(0, 2, 1, 3, 4)
|
| 921 |
-
ref_mask = ref_mask[0].permute(1, 0, 2, 3).contiguous()
|
| 922 |
-
|
| 923 |
-
if len(ref_mask.shape) == 4:
|
| 924 |
-
ref_mask = ref_mask.unsqueeze(0)
|
| 925 |
-
|
| 926 |
-
ref_mask = ref_mask[0, :, :1].contiguous()
|
| 927 |
-
ref_mask = (ref_mask > 0.5).float().max(dim=1, keepdim=True)[0]
|
| 928 |
-
motion_mask = ttm_helper.convert_rgb_mask_to_latent_mask(ref_mask)
|
| 929 |
-
background_mask = 1.0 - motion_mask
|
| 930 |
-
|
| 931 |
-
progress(0.35, desc="Initializing TTM denoising...")
|
| 932 |
-
|
| 933 |
-
# Initialize with noisy reference latents at tweak timestep
|
| 934 |
-
if tweak_index >= 0:
|
| 935 |
-
tweak = timesteps[tweak_index]
|
| 936 |
-
fixed_noise = randn_tensor(
|
| 937 |
-
ref_latents.shape,
|
| 938 |
-
generator=generator,
|
| 939 |
-
device=ref_latents.device,
|
| 940 |
-
dtype=ref_latents.dtype,
|
| 941 |
-
)
|
| 942 |
-
noisy_latents = pipe.scheduler.add_noise(
|
| 943 |
-
ref_latents, fixed_noise, tweak.long())
|
| 944 |
-
latents = noisy_latents.to(
|
| 945 |
-
dtype=latents.dtype, device=latents.device)
|
| 946 |
-
else:
|
| 947 |
-
fixed_noise = randn_tensor(
|
| 948 |
-
ref_latents.shape,
|
| 949 |
-
generator=generator,
|
| 950 |
-
device=ref_latents.device,
|
| 951 |
-
dtype=ref_latents.dtype,
|
| 952 |
-
)
|
| 953 |
-
tweak_index = 0
|
| 954 |
-
|
| 955 |
-
# Prepare extra step kwargs
|
| 956 |
-
extra_step_kwargs = pipe.prepare_extra_step_kwargs(generator, 0.0)
|
| 957 |
-
|
| 958 |
-
# Create rotary embeddings if required
|
| 959 |
-
image_rotary_emb = (
|
| 960 |
-
pipe._prepare_rotary_positional_embeddings(
|
| 961 |
-
height, width, latents.size(1), device)
|
| 962 |
-
if pipe.transformer.config.use_rotary_positional_embeddings
|
| 963 |
-
else None
|
| 964 |
-
)
|
| 965 |
-
|
| 966 |
-
# Create ofs embeddings if required
|
| 967 |
-
ofs_emb = None if pipe.transformer.config.ofs_embed_dim is None else latents.new_full(
|
| 968 |
-
(1,), fill_value=2.0)
|
| 969 |
-
|
| 970 |
-
progress(0.4, desc="Running TTM denoising loop...")
|
| 971 |
-
|
| 972 |
-
# Denoising loop
|
| 973 |
-
total_steps = len(timesteps) - tweak_index
|
| 974 |
-
old_pred_original_sample = None
|
| 975 |
-
|
| 976 |
-
for i, t in enumerate(timesteps[tweak_index:]):
|
| 977 |
-
step_progress = 0.4 + 0.5 * (i / total_steps)
|
| 978 |
-
progress(step_progress,
|
| 979 |
-
desc=f"Denoising step {i+1}/{total_steps}...")
|
| 980 |
-
|
| 981 |
-
latent_model_input = torch.cat(
|
| 982 |
-
[latents] * 2) if do_classifier_free_guidance else latents
|
| 983 |
-
latent_model_input = pipe.scheduler.scale_model_input(
|
| 984 |
-
latent_model_input, t)
|
| 985 |
-
|
| 986 |
-
latent_image_input = torch.cat(
|
| 987 |
-
[image_latents] * 2) if do_classifier_free_guidance else image_latents
|
| 988 |
-
latent_model_input = torch.cat(
|
| 989 |
-
[latent_model_input, latent_image_input], dim=2)
|
| 990 |
-
|
| 991 |
-
timestep = t.expand(latent_model_input.shape[0])
|
| 992 |
-
|
| 993 |
-
# Predict noise
|
| 994 |
-
noise_pred = pipe.transformer(
|
| 995 |
-
hidden_states=latent_model_input,
|
| 996 |
-
encoder_hidden_states=prompt_embeds,
|
| 997 |
-
timestep=timestep,
|
| 998 |
-
ofs=ofs_emb,
|
| 999 |
-
image_rotary_emb=image_rotary_emb,
|
| 1000 |
-
return_dict=False,
|
| 1001 |
-
)[0]
|
| 1002 |
-
noise_pred = noise_pred.float()
|
| 1003 |
-
|
| 1004 |
-
# Perform guidance
|
| 1005 |
-
if do_classifier_free_guidance:
|
| 1006 |
-
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 1007 |
-
noise_pred = noise_pred_uncond + guidance_scale * \
|
| 1008 |
-
(noise_pred_text - noise_pred_uncond)
|
| 1009 |
-
|
| 1010 |
-
# Compute previous noisy sample
|
| 1011 |
-
if not isinstance(pipe.scheduler, CogVideoXDPMScheduler):
|
| 1012 |
-
latents, old_pred_original_sample = pipe.scheduler.step(
|
| 1013 |
-
noise_pred, t, latents, **extra_step_kwargs, return_dict=False
|
| 1014 |
-
)
|
| 1015 |
-
else:
|
| 1016 |
-
latents, old_pred_original_sample = pipe.scheduler.step(
|
| 1017 |
-
noise_pred,
|
| 1018 |
-
old_pred_original_sample,
|
| 1019 |
-
t,
|
| 1020 |
-
timesteps[i - 1] if i > 0 else None,
|
| 1021 |
-
latents,
|
| 1022 |
-
**extra_step_kwargs,
|
| 1023 |
-
return_dict=False,
|
| 1024 |
-
)
|
| 1025 |
-
|
| 1026 |
-
# TTM: In between tweak and tstrong, replace mask with noisy reference latents
|
| 1027 |
-
in_between_tweak_tstrong = (i + tweak_index) < tstrong_index
|
| 1028 |
-
|
| 1029 |
-
if in_between_tweak_tstrong:
|
| 1030 |
-
if i + tweak_index + 1 < len(timesteps):
|
| 1031 |
-
prev_t = timesteps[i + tweak_index + 1]
|
| 1032 |
-
noisy_latents = pipe.scheduler.add_noise(ref_latents, fixed_noise, prev_t.long()).to(
|
| 1033 |
-
dtype=latents.dtype, device=latents.device
|
| 1034 |
-
)
|
| 1035 |
-
latents = latents * background_mask + noisy_latents * motion_mask
|
| 1036 |
-
else:
|
| 1037 |
-
latents = latents * background_mask + ref_latents * motion_mask
|
| 1038 |
-
|
| 1039 |
-
latents = latents.to(prompt_embeds.dtype)
|
| 1040 |
-
|
| 1041 |
-
progress(0.9, desc="Decoding video...")
|
| 1042 |
-
|
| 1043 |
-
# Decode latents
|
| 1044 |
-
latents = latents[:, additional_frames:]
|
| 1045 |
-
frames = pipe.decode_latents(latents)
|
| 1046 |
-
video = ttm_helper.video_processor.postprocess_video(
|
| 1047 |
-
video=frames, output_type="pil")
|
| 1048 |
-
|
| 1049 |
-
progress(0.95, desc="Saving video...")
|
| 1050 |
-
|
| 1051 |
-
# Save video
|
| 1052 |
-
temp_dir = create_user_temp_dir()
|
| 1053 |
-
output_path = os.path.join(temp_dir, "ttm_output.mp4")
|
| 1054 |
-
export_to_video(video[0], output_path, fps=8)
|
| 1055 |
-
|
| 1056 |
-
progress(1.0, desc="Done!")
|
| 1057 |
-
|
| 1058 |
-
return output_path, f"✅ CogVideoX TTM video generated successfully!\n\n**Parameters:**\n- Model: CogVideoX-5B\n- tweak_index: {tweak_index}\n- tstrong_index: {tstrong_index}\n- guidance_scale: {guidance_scale}\n- seed: {seed}"
|
| 1059 |
-
|
| 1060 |
-
except Exception as e:
|
| 1061 |
-
logger.error(f"Error in CogVideoX TTM generation: {e}")
|
| 1062 |
-
import traceback
|
| 1063 |
-
traceback.print_exc()
|
| 1064 |
-
return None, f"❌ Error: {str(e)}"
|
| 1065 |
-
|
| 1066 |
-
|
| 1067 |
-
@spaces.GPU(duration=300)
|
| 1068 |
-
def run_ttm_wan_generation(
|
| 1069 |
-
first_frame_path: str,
|
| 1070 |
-
motion_signal_path: str,
|
| 1071 |
-
mask_path: str,
|
| 1072 |
-
prompt: str,
|
| 1073 |
-
negative_prompt: str = "",
|
| 1074 |
-
tweak_index: int = 3,
|
| 1075 |
-
tstrong_index: int = 7,
|
| 1076 |
-
num_frames: int = 81,
|
| 1077 |
-
num_inference_steps: int = 50,
|
| 1078 |
-
guidance_scale: float = 3.5,
|
| 1079 |
-
seed: int = 0,
|
| 1080 |
-
progress=gr.Progress()
|
| 1081 |
-
):
|
| 1082 |
-
"""
|
| 1083 |
-
Run TTM-style video generation using Wan 2.2 pipeline.
|
| 1084 |
-
This is the recommended model for TTM as it produces higher-quality results.
|
| 1085 |
-
"""
|
| 1086 |
-
if not TTM_WAN_AVAILABLE:
|
| 1087 |
-
return None, "❌ Wan TTM is not available. Please install diffusers with Wan support."
|
| 1088 |
-
|
| 1089 |
-
if first_frame_path is None or motion_signal_path is None or mask_path is None:
|
| 1090 |
-
return None, "❌ Please generate TTM inputs first (first_frame, motion_signal, mask)"
|
| 1091 |
-
|
| 1092 |
-
progress(0, desc="Loading Wan 2.2 TTM pipeline...")
|
| 1093 |
-
|
| 1094 |
-
try:
|
| 1095 |
-
# Get or load the pipeline
|
| 1096 |
-
pipe = get_ttm_wan_pipeline()
|
| 1097 |
-
if pipe is None:
|
| 1098 |
-
return None, "❌ Failed to load Wan TTM pipeline"
|
| 1099 |
-
|
| 1100 |
-
pipe = pipe.to("cuda")
|
| 1101 |
-
|
| 1102 |
-
# Create helper
|
| 1103 |
-
ttm_helper = WanTTMHelper(pipe)
|
| 1104 |
-
|
| 1105 |
-
progress(0.1, desc="Loading inputs...")
|
| 1106 |
-
|
| 1107 |
-
# Load first frame
|
| 1108 |
-
image = load_image(first_frame_path)
|
| 1109 |
-
|
| 1110 |
-
# Get dimensions - compute based on image aspect ratio
|
| 1111 |
-
max_area = 480 * 832
|
| 1112 |
-
mod_value = ttm_helper.vae_scale_factor_spatial * \
|
| 1113 |
-
pipe.transformer.config.patch_size[1]
|
| 1114 |
-
height, width = compute_hw_from_area(
|
| 1115 |
-
image.height, image.width, max_area, mod_value)
|
| 1116 |
-
image = image.resize((width, height))
|
| 1117 |
-
|
| 1118 |
-
device = "cuda"
|
| 1119 |
-
gen_device = device if device.startswith("cuda") else "cpu"
|
| 1120 |
-
generator = torch.Generator(device=gen_device).manual_seed(seed)
|
| 1121 |
-
|
| 1122 |
-
progress(0.15, desc="Encoding prompt...")
|
| 1123 |
-
|
| 1124 |
-
# Encode prompt
|
| 1125 |
-
do_classifier_free_guidance = guidance_scale > 1.0
|
| 1126 |
-
prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
|
| 1127 |
-
prompt=prompt,
|
| 1128 |
-
negative_prompt=negative_prompt if negative_prompt else None,
|
| 1129 |
-
do_classifier_free_guidance=do_classifier_free_guidance,
|
| 1130 |
-
num_videos_per_prompt=1,
|
| 1131 |
-
max_sequence_length=512,
|
| 1132 |
-
device=device,
|
| 1133 |
-
)
|
| 1134 |
-
|
| 1135 |
-
# Get transformer dtype
|
| 1136 |
-
transformer_dtype = pipe.transformer.dtype
|
| 1137 |
-
prompt_embeds = prompt_embeds.to(transformer_dtype)
|
| 1138 |
-
if negative_prompt_embeds is not None:
|
| 1139 |
-
negative_prompt_embeds = negative_prompt_embeds.to(
|
| 1140 |
-
transformer_dtype)
|
| 1141 |
-
|
| 1142 |
-
# Encode image embedding if transformer supports it
|
| 1143 |
-
image_embeds = None
|
| 1144 |
-
if pipe.transformer.config.image_dim is not None:
|
| 1145 |
-
image_embeds = pipe.encode_image(image, device)
|
| 1146 |
-
image_embeds = image_embeds.repeat(1, 1, 1)
|
| 1147 |
-
image_embeds = image_embeds.to(transformer_dtype)
|
| 1148 |
-
|
| 1149 |
-
progress(0.2, desc="Preparing latents...")
|
| 1150 |
-
|
| 1151 |
-
# Prepare timesteps
|
| 1152 |
-
pipe.scheduler.set_timesteps(num_inference_steps, device=device)
|
| 1153 |
-
timesteps = pipe.scheduler.timesteps
|
| 1154 |
-
|
| 1155 |
-
# Adjust num_frames to be valid for VAE
|
| 1156 |
-
if num_frames % ttm_helper.vae_scale_factor_temporal != 1:
|
| 1157 |
-
num_frames = num_frames // ttm_helper.vae_scale_factor_temporal * \
|
| 1158 |
-
ttm_helper.vae_scale_factor_temporal + 1
|
| 1159 |
-
num_frames = max(num_frames, 1)
|
| 1160 |
-
|
| 1161 |
-
# Prepare latent variables
|
| 1162 |
-
num_channels_latents = pipe.vae.config.z_dim
|
| 1163 |
-
image_tensor = ttm_helper.video_processor.preprocess(
|
| 1164 |
-
image, height=height, width=width).to(device, dtype=torch.float32)
|
| 1165 |
-
|
| 1166 |
-
latents_outputs = pipe.prepare_latents(
|
| 1167 |
-
image_tensor,
|
| 1168 |
-
1, # batch_size
|
| 1169 |
-
num_channels_latents,
|
| 1170 |
-
height,
|
| 1171 |
-
width,
|
| 1172 |
-
num_frames,
|
| 1173 |
-
torch.float32,
|
| 1174 |
-
device,
|
| 1175 |
-
generator,
|
| 1176 |
-
None,
|
| 1177 |
-
None, # last_image
|
| 1178 |
-
)
|
| 1179 |
-
|
| 1180 |
-
if hasattr(pipe, 'config') and pipe.config.expand_timesteps:
|
| 1181 |
-
latents, condition, first_frame_mask = latents_outputs
|
| 1182 |
-
else:
|
| 1183 |
-
latents, condition = latents_outputs
|
| 1184 |
-
first_frame_mask = None
|
| 1185 |
-
|
| 1186 |
-
progress(0.3, desc="Loading motion signal and mask...")
|
| 1187 |
-
|
| 1188 |
-
# Load motion signal video
|
| 1189 |
-
ref_vid = load_video_to_tensor(motion_signal_path).to(device=device)
|
| 1190 |
-
refB, refC, refT, refH, refW = ref_vid.shape
|
| 1191 |
-
ref_vid = F.interpolate(
|
| 1192 |
-
ref_vid.permute(0, 2, 1, 3, 4).reshape(
|
| 1193 |
-
refB*refT, refC, refH, refW),
|
| 1194 |
-
size=(height, width), mode="bicubic", align_corners=True,
|
| 1195 |
-
).reshape(refB, refT, refC, height, width).permute(0, 2, 1, 3, 4)
|
| 1196 |
-
|
| 1197 |
-
ref_vid = ttm_helper.video_processor.normalize(
|
| 1198 |
-
ref_vid.to(dtype=pipe.vae.dtype))
|
| 1199 |
-
ref_latents = retrieve_latents(
|
| 1200 |
-
pipe.vae.encode(ref_vid), sample_mode="argmax")
|
| 1201 |
-
|
| 1202 |
-
# Normalize latents
|
| 1203 |
-
latents_mean = torch.tensor(pipe.vae.config.latents_mean).view(
|
| 1204 |
-
1, pipe.vae.config.z_dim, 1, 1, 1).to(ref_latents.device, ref_latents.dtype)
|
| 1205 |
-
latents_std = 1.0 / torch.tensor(pipe.vae.config.latents_std).view(
|
| 1206 |
-
1, pipe.vae.config.z_dim, 1, 1, 1).to(ref_latents.device, ref_latents.dtype)
|
| 1207 |
-
ref_latents = (ref_latents - latents_mean) * latents_std
|
| 1208 |
-
|
| 1209 |
-
# Load mask video
|
| 1210 |
-
ref_mask = load_video_to_tensor(mask_path).to(device=device)
|
| 1211 |
-
mB, mC, mT, mH, mW = ref_mask.shape
|
| 1212 |
-
ref_mask = F.interpolate(
|
| 1213 |
-
ref_mask.permute(0, 2, 1, 3, 4).reshape(mB*mT, mC, mH, mW),
|
| 1214 |
-
size=(height, width), mode="nearest",
|
| 1215 |
-
).reshape(mB, mT, mC, height, width).permute(0, 2, 1, 3, 4)
|
| 1216 |
-
mask_tc_hw = ref_mask[0].permute(1, 0, 2, 3).contiguous()
|
| 1217 |
-
|
| 1218 |
-
# Align time dimension
|
| 1219 |
-
if mask_tc_hw.shape[0] > num_frames:
|
| 1220 |
-
mask_tc_hw = mask_tc_hw[:num_frames]
|
| 1221 |
-
elif mask_tc_hw.shape[0] < num_frames:
|
| 1222 |
-
return None, f"❌ num_frames ({num_frames}) > mask frames ({mask_tc_hw.shape[0]}). Please use more mask frames."
|
| 1223 |
-
|
| 1224 |
-
# Reduce channels if needed
|
| 1225 |
-
if mask_tc_hw.shape[1] > 1:
|
| 1226 |
-
mask_t1_hw = (mask_tc_hw > 0.5).any(dim=1, keepdim=True).float()
|
| 1227 |
-
else:
|
| 1228 |
-
mask_t1_hw = (mask_tc_hw > 0.5).float()
|
| 1229 |
-
|
| 1230 |
-
motion_mask = ttm_helper.convert_rgb_mask_to_latent_mask(
|
| 1231 |
-
mask_t1_hw).permute(0, 2, 1, 3, 4).contiguous()
|
| 1232 |
-
background_mask = 1.0 - motion_mask
|
| 1233 |
-
|
| 1234 |
-
progress(0.35, desc="Initializing TTM denoising...")
|
| 1235 |
-
|
| 1236 |
-
# Initialize with noisy reference latents at tweak timestep
|
| 1237 |
-
if tweak_index >= 0 and tweak_index < len(timesteps):
|
| 1238 |
-
tweak = timesteps[tweak_index]
|
| 1239 |
-
fixed_noise = randn_tensor(
|
| 1240 |
-
ref_latents.shape,
|
| 1241 |
-
generator=generator,
|
| 1242 |
-
device=ref_latents.device,
|
| 1243 |
-
dtype=ref_latents.dtype,
|
| 1244 |
-
)
|
| 1245 |
-
tweak_t = torch.as_tensor(
|
| 1246 |
-
tweak, device=ref_latents.device, dtype=torch.long).view(1)
|
| 1247 |
-
noisy_latents = pipe.scheduler.add_noise(
|
| 1248 |
-
ref_latents, fixed_noise, tweak_t.long())
|
| 1249 |
-
latents = noisy_latents.to(
|
| 1250 |
-
dtype=latents.dtype, device=latents.device)
|
| 1251 |
-
else:
|
| 1252 |
-
fixed_noise = randn_tensor(
|
| 1253 |
-
ref_latents.shape,
|
| 1254 |
-
generator=generator,
|
| 1255 |
-
device=ref_latents.device,
|
| 1256 |
-
dtype=ref_latents.dtype,
|
| 1257 |
-
)
|
| 1258 |
-
tweak_index = 0
|
| 1259 |
-
|
| 1260 |
-
progress(0.4, desc="Running TTM denoising loop...")
|
| 1261 |
-
|
| 1262 |
-
# Denoising loop
|
| 1263 |
-
total_steps = len(timesteps) - tweak_index
|
| 1264 |
-
|
| 1265 |
-
for i, t in enumerate(timesteps[tweak_index:]):
|
| 1266 |
-
step_progress = 0.4 + 0.5 * (i / total_steps)
|
| 1267 |
-
progress(step_progress,
|
| 1268 |
-
desc=f"Denoising step {i+1}/{total_steps}...")
|
| 1269 |
-
|
| 1270 |
-
# Prepare model input
|
| 1271 |
-
if first_frame_mask is not None:
|
| 1272 |
-
latent_model_input = (1 - first_frame_mask) * \
|
| 1273 |
-
condition + first_frame_mask * latents
|
| 1274 |
-
latent_model_input = latent_model_input.to(transformer_dtype)
|
| 1275 |
-
temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten()
|
| 1276 |
-
timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
|
| 1277 |
-
else:
|
| 1278 |
-
latent_model_input = torch.cat(
|
| 1279 |
-
[latents, condition], dim=1).to(transformer_dtype)
|
| 1280 |
-
timestep = t.expand(latents.shape[0])
|
| 1281 |
-
|
| 1282 |
-
# Predict noise (conditional)
|
| 1283 |
-
noise_pred = pipe.transformer(
|
| 1284 |
-
hidden_states=latent_model_input,
|
| 1285 |
-
timestep=timestep,
|
| 1286 |
-
encoder_hidden_states=prompt_embeds,
|
| 1287 |
-
encoder_hidden_states_image=image_embeds,
|
| 1288 |
-
return_dict=False,
|
| 1289 |
-
)[0]
|
| 1290 |
-
|
| 1291 |
-
# CFG
|
| 1292 |
-
if do_classifier_free_guidance:
|
| 1293 |
-
noise_uncond = pipe.transformer(
|
| 1294 |
-
hidden_states=latent_model_input,
|
| 1295 |
-
timestep=timestep,
|
| 1296 |
-
encoder_hidden_states=negative_prompt_embeds,
|
| 1297 |
-
encoder_hidden_states_image=image_embeds,
|
| 1298 |
-
return_dict=False,
|
| 1299 |
-
)[0]
|
| 1300 |
-
noise_pred = noise_uncond + guidance_scale * \
|
| 1301 |
-
(noise_pred - noise_uncond)
|
| 1302 |
-
|
| 1303 |
-
# Scheduler step
|
| 1304 |
-
latents = pipe.scheduler.step(
|
| 1305 |
-
noise_pred, t, latents, return_dict=False)[0]
|
| 1306 |
-
|
| 1307 |
-
# TTM: In between tweak and tstrong, replace mask with noisy reference latents
|
| 1308 |
-
in_between_tweak_tstrong = (i + tweak_index) < tstrong_index
|
| 1309 |
-
|
| 1310 |
-
if in_between_tweak_tstrong:
|
| 1311 |
-
if i + tweak_index + 1 < len(timesteps):
|
| 1312 |
-
prev_t = timesteps[i + tweak_index + 1]
|
| 1313 |
-
prev_t = torch.as_tensor(
|
| 1314 |
-
prev_t, device=ref_latents.device, dtype=torch.long).view(1)
|
| 1315 |
-
noisy_latents = pipe.scheduler.add_noise(ref_latents, fixed_noise, prev_t.long()).to(
|
| 1316 |
-
dtype=latents.dtype, device=latents.device
|
| 1317 |
-
)
|
| 1318 |
-
latents = latents * background_mask + noisy_latents * motion_mask
|
| 1319 |
-
else:
|
| 1320 |
-
latents = latents * background_mask + \
|
| 1321 |
-
ref_latents.to(dtype=latents.dtype,
|
| 1322 |
-
device=latents.device) * motion_mask
|
| 1323 |
-
|
| 1324 |
-
progress(0.9, desc="Decoding video...")
|
| 1325 |
-
|
| 1326 |
-
# Apply first frame mask if used
|
| 1327 |
-
if first_frame_mask is not None:
|
| 1328 |
-
latents = (1 - first_frame_mask) * condition + \
|
| 1329 |
-
first_frame_mask * latents
|
| 1330 |
-
|
| 1331 |
-
# Decode latents
|
| 1332 |
-
latents = latents.to(pipe.vae.dtype)
|
| 1333 |
-
latents_mean = torch.tensor(pipe.vae.config.latents_mean).view(
|
| 1334 |
-
1, pipe.vae.config.z_dim, 1, 1, 1).to(latents.device, latents.dtype)
|
| 1335 |
-
latents_std = 1.0 / torch.tensor(pipe.vae.config.latents_std).view(
|
| 1336 |
-
1, pipe.vae.config.z_dim, 1, 1, 1).to(latents.device, latents.dtype)
|
| 1337 |
-
latents = latents / latents_std + latents_mean
|
| 1338 |
-
video = pipe.vae.decode(latents, return_dict=False)[0]
|
| 1339 |
-
video = ttm_helper.video_processor.postprocess_video(
|
| 1340 |
-
video, output_type="pil")
|
| 1341 |
-
|
| 1342 |
-
progress(0.95, desc="Saving video...")
|
| 1343 |
-
|
| 1344 |
-
# Save video
|
| 1345 |
-
temp_dir = create_user_temp_dir()
|
| 1346 |
-
output_path = os.path.join(temp_dir, "ttm_wan_output.mp4")
|
| 1347 |
-
export_to_video(video[0], output_path, fps=16)
|
| 1348 |
-
|
| 1349 |
-
progress(1.0, desc="Done!")
|
| 1350 |
-
|
| 1351 |
-
return output_path, f"✅ Wan 2.2 TTM video generated successfully!\n\n**Parameters:**\n- Model: Wan2.2-14B\n- tweak_index: {tweak_index}\n- tstrong_index: {tstrong_index}\n- guidance_scale: {guidance_scale}\n- seed: {seed}"
|
| 1352 |
-
|
| 1353 |
-
except Exception as e:
|
| 1354 |
-
logger.error(f"Error in Wan TTM generation: {e}")
|
| 1355 |
-
import traceback
|
| 1356 |
-
traceback.print_exc()
|
| 1357 |
-
return None, f"❌ Error: {str(e)}"
|
| 1358 |
-
|
| 1359 |
-
|
| 1360 |
-
def run_ttm_generation(
|
| 1361 |
-
first_frame_path: str,
|
| 1362 |
-
motion_signal_path: str,
|
| 1363 |
-
mask_path: str,
|
| 1364 |
-
prompt: str,
|
| 1365 |
-
negative_prompt: str,
|
| 1366 |
-
model_choice: str,
|
| 1367 |
-
tweak_index: int,
|
| 1368 |
-
tstrong_index: int,
|
| 1369 |
-
num_frames: int,
|
| 1370 |
-
num_inference_steps: int,
|
| 1371 |
-
guidance_scale: float,
|
| 1372 |
-
seed: int,
|
| 1373 |
-
progress=gr.Progress()
|
| 1374 |
-
):
|
| 1375 |
-
"""
|
| 1376 |
-
Router function that calls the appropriate TTM generation based on model choice.
|
| 1377 |
-
"""
|
| 1378 |
-
if "Wan" in model_choice:
|
| 1379 |
-
return run_ttm_wan_generation(
|
| 1380 |
-
first_frame_path=first_frame_path,
|
| 1381 |
-
motion_signal_path=motion_signal_path,
|
| 1382 |
-
mask_path=mask_path,
|
| 1383 |
-
prompt=prompt,
|
| 1384 |
-
negative_prompt=negative_prompt,
|
| 1385 |
-
tweak_index=tweak_index,
|
| 1386 |
-
tstrong_index=tstrong_index,
|
| 1387 |
-
num_frames=num_frames,
|
| 1388 |
-
num_inference_steps=num_inference_steps,
|
| 1389 |
-
guidance_scale=guidance_scale,
|
| 1390 |
-
seed=seed,
|
| 1391 |
-
progress=progress,
|
| 1392 |
-
)
|
| 1393 |
-
else:
|
| 1394 |
-
return run_ttm_cog_generation(
|
| 1395 |
-
first_frame_path=first_frame_path,
|
| 1396 |
-
motion_signal_path=motion_signal_path,
|
| 1397 |
-
mask_path=mask_path,
|
| 1398 |
-
prompt=prompt,
|
| 1399 |
-
tweak_index=tweak_index,
|
| 1400 |
-
tstrong_index=tstrong_index,
|
| 1401 |
-
num_frames=num_frames,
|
| 1402 |
-
num_inference_steps=num_inference_steps,
|
| 1403 |
-
guidance_scale=guidance_scale,
|
| 1404 |
-
seed=seed,
|
| 1405 |
-
progress=progress,
|
| 1406 |
-
)
|
| 1407 |
-
|
| 1408 |
-
|
| 1409 |
# Create Gradio interface
|
| 1410 |
-
|
| 1411 |
-
sys.stdout.flush()
|
| 1412 |
|
| 1413 |
with gr.Blocks(
|
| 1414 |
theme=gr.themes.Soft(),
|
| 1415 |
title="🎬 Video to Point Cloud Renderer",
|
| 1416 |
css="""
|
| 1417 |
.gradio-container {
|
| 1418 |
-
max-width:
|
| 1419 |
margin: auto !important;
|
| 1420 |
}
|
| 1421 |
"""
|
| 1422 |
) as demo:
|
| 1423 |
gr.Markdown("""
|
| 1424 |
-
# 🎬 Video to Point Cloud Renderer
|
| 1425 |
|
| 1426 |
-
Upload a video to generate a 3D point cloud
|
| 1427 |
-
|
| 1428 |
|
| 1429 |
-
**
|
| 1430 |
-
1.
|
| 1431 |
-
2.
|
|
|
|
| 1432 |
|
| 1433 |
-
**TTM
|
| 1434 |
-
- `first_frame.png`:
|
| 1435 |
-
- `motion_signal.mp4`: Warped video
|
| 1436 |
-
- `mask.mp4`: Binary mask
|
| 1437 |
""")
|
| 1438 |
|
| 1439 |
-
|
| 1440 |
-
|
| 1441 |
-
|
| 1442 |
-
|
| 1443 |
-
|
| 1444 |
-
|
| 1445 |
-
|
| 1446 |
-
|
| 1447 |
-
|
| 1448 |
-
|
| 1449 |
-
|
| 1450 |
-
|
| 1451 |
-
|
| 1452 |
-
|
| 1453 |
-
|
| 1454 |
-
|
| 1455 |
-
|
| 1456 |
-
|
| 1457 |
-
|
| 1458 |
-
|
| 1459 |
-
|
| 1460 |
-
|
| 1461 |
-
|
| 1462 |
-
|
| 1463 |
-
|
| 1464 |
-
|
| 1465 |
-
|
| 1466 |
-
|
| 1467 |
-
|
| 1468 |
-
|
| 1469 |
-
|
| 1470 |
-
|
| 1471 |
-
|
| 1472 |
-
|
| 1473 |
-
|
| 1474 |
-
|
| 1475 |
-
|
| 1476 |
-
|
| 1477 |
-
|
| 1478 |
-
|
| 1479 |
-
|
| 1480 |
-
|
| 1481 |
-
|
| 1482 |
-
|
| 1483 |
-
|
| 1484 |
-
|
| 1485 |
-
|
| 1486 |
-
|
| 1487 |
-
|
| 1488 |
-
|
| 1489 |
-
with gr.Column(scale=1):
|
| 1490 |
-
gr.Markdown("### 🎭 TTM: Mask")
|
| 1491 |
-
mask_output = gr.Video(
|
| 1492 |
-
label="Mask Video (mask.mp4)",
|
| 1493 |
-
height=250
|
| 1494 |
-
)
|
| 1495 |
-
|
| 1496 |
-
status_text = gr.Markdown("Ready to process...")
|
| 1497 |
-
|
| 1498 |
-
with gr.Tab("🎬 Step 2: TTM Video Generation"):
|
| 1499 |
-
cog_available = "✅" if TTM_COG_AVAILABLE else "❌"
|
| 1500 |
-
wan_available = "✅" if TTM_WAN_AVAILABLE else "❌"
|
| 1501 |
-
gr.Markdown(f"""
|
| 1502 |
-
### 🎬 Time-to-Move (TTM) Video Generation
|
| 1503 |
-
|
| 1504 |
-
**Model Availability:**
|
| 1505 |
-
- {cog_available} CogVideoX-5B-I2V
|
| 1506 |
-
- {wan_available} Wan 2.2-14B (Recommended - higher quality)
|
| 1507 |
-
|
| 1508 |
-
**TTM Parameters:**
|
| 1509 |
-
- **tweak_index**: When denoising starts *outside* the mask (lower = more dynamic background)
|
| 1510 |
-
- **tstrong_index**: When denoising starts *inside* the mask (higher = more constrained motion)
|
| 1511 |
-
|
| 1512 |
-
**Recommended values:**
|
| 1513 |
-
- CogVideoX - Cut-and-Drag: `tweak_index=4`, `tstrong_index=9`
|
| 1514 |
-
- CogVideoX - Camera control: `tweak_index=3`, `tstrong_index=7`
|
| 1515 |
-
- **Wan 2.2 (Recommended)**: `tweak_index=3`, `tstrong_index=7`
|
| 1516 |
-
""")
|
| 1517 |
-
|
| 1518 |
-
with gr.Row():
|
| 1519 |
-
with gr.Column(scale=1):
|
| 1520 |
-
gr.Markdown("### ⚙️ TTM Settings")
|
| 1521 |
-
|
| 1522 |
-
ttm_model_choice = gr.Dropdown(
|
| 1523 |
-
choices=TTM_MODELS,
|
| 1524 |
-
value=TTM_MODELS[1] if TTM_WAN_AVAILABLE else TTM_MODELS[0],
|
| 1525 |
-
label="Model",
|
| 1526 |
-
info="Wan 2.2 is recommended for higher quality"
|
| 1527 |
-
)
|
| 1528 |
-
|
| 1529 |
-
ttm_prompt = gr.Textbox(
|
| 1530 |
-
label="Prompt",
|
| 1531 |
-
placeholder="Describe the video content...",
|
| 1532 |
-
value="A high quality video, smooth motion, natural lighting",
|
| 1533 |
-
lines=2
|
| 1534 |
-
)
|
| 1535 |
-
|
| 1536 |
-
ttm_negative_prompt = gr.Textbox(
|
| 1537 |
-
label="Negative Prompt (Wan only)",
|
| 1538 |
-
placeholder="Things to avoid in the video...",
|
| 1539 |
-
value="",
|
| 1540 |
-
lines=1,
|
| 1541 |
-
visible=True
|
| 1542 |
-
)
|
| 1543 |
-
|
| 1544 |
-
with gr.Row():
|
| 1545 |
-
ttm_tweak_index = gr.Slider(
|
| 1546 |
-
minimum=0,
|
| 1547 |
-
maximum=20,
|
| 1548 |
-
value=3,
|
| 1549 |
-
step=1,
|
| 1550 |
-
label="tweak_index",
|
| 1551 |
-
info="When background denoising starts"
|
| 1552 |
-
)
|
| 1553 |
-
ttm_tstrong_index = gr.Slider(
|
| 1554 |
-
minimum=0,
|
| 1555 |
-
maximum=30,
|
| 1556 |
-
value=7,
|
| 1557 |
-
step=1,
|
| 1558 |
-
label="tstrong_index",
|
| 1559 |
-
info="When mask region denoising starts"
|
| 1560 |
-
)
|
| 1561 |
-
|
| 1562 |
-
with gr.Row():
|
| 1563 |
-
ttm_num_frames = gr.Slider(
|
| 1564 |
-
minimum=17,
|
| 1565 |
-
maximum=81,
|
| 1566 |
-
value=49,
|
| 1567 |
-
step=4,
|
| 1568 |
-
label="Number of Frames"
|
| 1569 |
-
)
|
| 1570 |
-
ttm_guidance_scale = gr.Slider(
|
| 1571 |
-
minimum=1.0,
|
| 1572 |
-
maximum=15.0,
|
| 1573 |
-
value=3.5,
|
| 1574 |
-
step=0.5,
|
| 1575 |
-
label="Guidance Scale"
|
| 1576 |
-
)
|
| 1577 |
-
|
| 1578 |
-
with gr.Row():
|
| 1579 |
-
ttm_num_steps = gr.Slider(
|
| 1580 |
-
minimum=20,
|
| 1581 |
-
maximum=100,
|
| 1582 |
-
value=50,
|
| 1583 |
-
step=5,
|
| 1584 |
-
label="Inference Steps"
|
| 1585 |
-
)
|
| 1586 |
-
ttm_seed = gr.Number(
|
| 1587 |
-
value=0,
|
| 1588 |
-
label="Seed",
|
| 1589 |
-
precision=0
|
| 1590 |
-
)
|
| 1591 |
-
|
| 1592 |
-
ttm_generate_btn = gr.Button(
|
| 1593 |
-
"🎬 Generate TTM Video",
|
| 1594 |
-
variant="primary",
|
| 1595 |
-
size="lg",
|
| 1596 |
-
interactive=TTM_AVAILABLE
|
| 1597 |
-
)
|
| 1598 |
-
|
| 1599 |
-
with gr.Column(scale=1):
|
| 1600 |
-
gr.Markdown("### 📤 TTM Output")
|
| 1601 |
-
ttm_output_video = gr.Video(
|
| 1602 |
-
label="TTM Generated Video",
|
| 1603 |
-
height=400
|
| 1604 |
-
)
|
| 1605 |
-
ttm_status_text = gr.Markdown(
|
| 1606 |
-
"Upload a video in Step 1 first, then run TTM here.")
|
| 1607 |
-
|
| 1608 |
-
# TTM Input preview
|
| 1609 |
-
with gr.Accordion("📁 TTM Input Files (from Step 1)", open=False):
|
| 1610 |
-
with gr.Row():
|
| 1611 |
-
ttm_preview_first_frame = gr.Image(
|
| 1612 |
-
label="First Frame",
|
| 1613 |
-
height=150
|
| 1614 |
-
)
|
| 1615 |
-
ttm_preview_motion = gr.Video(
|
| 1616 |
-
label="Motion Signal",
|
| 1617 |
-
height=150
|
| 1618 |
-
)
|
| 1619 |
-
ttm_preview_mask = gr.Video(
|
| 1620 |
-
label="Mask",
|
| 1621 |
-
height=150
|
| 1622 |
-
)
|
| 1623 |
-
|
| 1624 |
-
# Helper function to update states and preview
|
| 1625 |
-
def process_and_update_states(video_path, camera_movement, generate_ttm_flag, progress=gr.Progress()):
|
| 1626 |
-
result = process_video(video_path, camera_movement,
|
| 1627 |
-
generate_ttm_flag, progress)
|
| 1628 |
-
output_vid, motion_sig, mask_vid, first_frame, status = result
|
| 1629 |
-
|
| 1630 |
-
# Return all outputs including state updates and previews
|
| 1631 |
-
return (
|
| 1632 |
-
output_vid, # output_video
|
| 1633 |
-
motion_sig, # motion_signal_output
|
| 1634 |
-
mask_vid, # mask_output
|
| 1635 |
-
first_frame, # first_frame_output
|
| 1636 |
-
status, # status_text
|
| 1637 |
-
first_frame, # first_frame_state
|
| 1638 |
-
motion_sig, # motion_signal_state
|
| 1639 |
-
mask_vid, # mask_state
|
| 1640 |
-
first_frame, # ttm_preview_first_frame
|
| 1641 |
-
motion_sig, # ttm_preview_motion
|
| 1642 |
-
mask_vid, # ttm_preview_mask
|
| 1643 |
-
)
|
| 1644 |
|
| 1645 |
# Event handlers
|
| 1646 |
generate_btn.click(
|
| 1647 |
-
fn=
|
| 1648 |
inputs=[video_input, camera_movement, generate_ttm],
|
| 1649 |
-
outputs=[
|
| 1650 |
-
output_video, motion_signal_output, mask_output, first_frame_output, status_text,
|
| 1651 |
-
first_frame_state, motion_signal_state, mask_state,
|
| 1652 |
-
ttm_preview_first_frame, ttm_preview_motion, ttm_preview_mask
|
| 1653 |
-
]
|
| 1654 |
-
)
|
| 1655 |
-
|
| 1656 |
-
# TTM generation event
|
| 1657 |
-
ttm_generate_btn.click(
|
| 1658 |
-
fn=run_ttm_generation,
|
| 1659 |
-
inputs=[
|
| 1660 |
-
first_frame_state,
|
| 1661 |
-
motion_signal_state,
|
| 1662 |
-
mask_state,
|
| 1663 |
-
ttm_prompt,
|
| 1664 |
-
ttm_negative_prompt,
|
| 1665 |
-
ttm_model_choice,
|
| 1666 |
-
ttm_tweak_index,
|
| 1667 |
-
ttm_tstrong_index,
|
| 1668 |
-
ttm_num_frames,
|
| 1669 |
-
ttm_num_steps,
|
| 1670 |
-
ttm_guidance_scale,
|
| 1671 |
-
ttm_seed
|
| 1672 |
-
],
|
| 1673 |
-
outputs=[ttm_output_video, ttm_status_text]
|
| 1674 |
)
|
| 1675 |
|
| 1676 |
# Examples
|
| 1677 |
gr.Markdown("### 📁 Examples")
|
| 1678 |
if os.path.exists("./examples"):
|
| 1679 |
-
example_videos = [f for f in os.listdir(
|
| 1680 |
-
"./examples") if f.endswith(".mp4")][:4]
|
| 1681 |
if example_videos:
|
| 1682 |
gr.Examples(
|
| 1683 |
-
examples=[[f"./examples/{v}", "move_forward", True]
|
| 1684 |
-
for v in example_videos],
|
| 1685 |
inputs=[video_input, camera_movement, generate_ttm],
|
| 1686 |
-
outputs=[
|
| 1687 |
-
|
| 1688 |
-
first_frame_state, motion_signal_state, mask_state,
|
| 1689 |
-
ttm_preview_first_frame, ttm_preview_motion, ttm_preview_mask
|
| 1690 |
-
],
|
| 1691 |
-
fn=process_and_update_states,
|
| 1692 |
cache_examples=False
|
| 1693 |
)
|
| 1694 |
|
| 1695 |
# Launch
|
| 1696 |
-
logger.info("✅ Gradio interface created successfully!")
|
| 1697 |
-
logger.info("=" * 50)
|
| 1698 |
-
logger.info("Application ready to launch")
|
| 1699 |
-
logger.info("=" * 50)
|
| 1700 |
-
sys.stdout.flush()
|
| 1701 |
-
|
| 1702 |
if __name__ == "__main__":
|
| 1703 |
-
logger.info("Starting Gradio server...")
|
| 1704 |
-
sys.stdout.flush()
|
| 1705 |
demo.launch(share=False)
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import os
|
| 3 |
import numpy as np
|
|
|
|
| 7 |
from pathlib import Path
|
| 8 |
from einops import rearrange
|
| 9 |
from typing import Union
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
try:
|
| 11 |
import spaces
|
|
|
|
| 12 |
except ImportError:
|
| 13 |
+
def spaces(func):
|
| 14 |
+
return func
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
import torchvision.transforms as T
|
| 17 |
+
import logging
|
| 18 |
from concurrent.futures import ThreadPoolExecutor
|
| 19 |
import atexit
|
| 20 |
import uuid
|
|
|
|
|
|
|
|
|
|
| 21 |
import decord
|
|
|
|
|
|
|
| 22 |
|
| 23 |
+
from models.SpaTrackV2.models.vggt4track.models.vggt_moe import VGGT4Track
|
| 24 |
+
from models.SpaTrackV2.models.vggt4track.utils.load_fn import preprocess_image
|
| 25 |
+
from models.SpaTrackV2.models.predictor import Predictor
|
| 26 |
+
from models.SpaTrackV2.models.utils import get_points_on_a_grid
|
| 27 |
|
| 28 |
+
# Configure logging
|
| 29 |
+
logging.basicConfig(level=logging.INFO)
|
| 30 |
+
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
# Constants
|
| 33 |
MAX_FRAMES = 80
|
|
|
|
| 46 |
"move_down"
|
| 47 |
]
|
| 48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
# Thread pool for delayed deletion
|
| 50 |
thread_pool_executor = ThreadPoolExecutor(max_workers=2)
|
| 51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
def delete_later(path: Union[str, os.PathLike], delay: int = 600):
|
| 53 |
"""Delete file or directory after specified delay"""
|
| 54 |
def _delete():
|
|
|
|
| 67 |
thread_pool_executor.submit(_wait_and_delete)
|
| 68 |
atexit.register(_delete)
|
| 69 |
|
|
|
|
| 70 |
def create_user_temp_dir():
|
| 71 |
"""Create a unique temporary directory for each user session"""
|
| 72 |
session_id = str(uuid.uuid4())[:8]
|
|
|
|
| 75 |
delete_later(temp_dir, delay=600)
|
| 76 |
return temp_dir
|
| 77 |
|
| 78 |
+
# Global model initialization
|
| 79 |
+
print("🚀 Initializing models...")
|
| 80 |
+
vggt4track_model = VGGT4Track.from_pretrained("Yuxihenry/SpatialTrackerV2_Front")
|
| 81 |
+
vggt4track_model.eval()
|
| 82 |
+
vggt4track_model = vggt4track_model.to("cuda")
|
| 83 |
|
| 84 |
+
tracker_model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
|
| 85 |
+
tracker_model.eval()
|
| 86 |
+
print("✅ Models loaded successfully!")
|
|
|
|
| 87 |
|
|
|
|
| 88 |
gr.set_static_paths(paths=[Path.cwd().absolute()/"_viz"])
|
|
|
|
|
|
|
| 89 |
|
| 90 |
|
| 91 |
def generate_camera_trajectory(num_frames: int, movement_type: str,
|
|
|
|
| 111 |
if movement_type == "static":
|
| 112 |
pass # Keep identity
|
| 113 |
elif movement_type == "move_forward":
|
| 114 |
+
ext[2, 3] = -speed * t # Move along -Z (forward in OpenGL convention)
|
|
|
|
| 115 |
elif movement_type == "move_backward":
|
| 116 |
ext[2, 3] = speed * t # Move along +Z
|
| 117 |
elif movement_type == "move_left":
|
|
|
|
| 168 |
base_dir = os.path.dirname(output_path)
|
| 169 |
motion_signal_path = os.path.join(base_dir, "motion_signal.mp4")
|
| 170 |
mask_path = os.path.join(base_dir, "mask.mp4")
|
| 171 |
+
out_motion_signal = cv2.VideoWriter(motion_signal_path, fourcc, fps, (W, H))
|
|
|
|
| 172 |
out_mask = cv2.VideoWriter(mask_path, fourcc, fps, (W, H))
|
| 173 |
|
| 174 |
# Create meshgrid for pixel coordinates
|
|
|
|
| 248 |
if hole_mask.sum() == 0:
|
| 249 |
break
|
| 250 |
dilated = cv2.dilate(motion_signal_frame, kernel, iterations=1)
|
| 251 |
+
motion_signal_frame = np.where(hole_mask[:, :, None] > 0, dilated, motion_signal_frame)
|
| 252 |
+
hole_mask = (motion_signal_frame.sum(axis=-1) == 0).astype(np.uint8)
|
|
|
|
|
|
|
| 253 |
|
| 254 |
# Write TTM outputs if enabled
|
| 255 |
if generate_ttm_inputs:
|
| 256 |
# Motion signal: warped frame with NN inpainting
|
| 257 |
+
motion_signal_bgr = cv2.cvtColor(motion_signal_frame, cv2.COLOR_RGB2BGR)
|
|
|
|
| 258 |
out_motion_signal.write(motion_signal_bgr)
|
| 259 |
|
| 260 |
# Mask: binary mask of valid (projected) pixels - white where valid, black where holes
|
| 261 |
+
mask_frame = np.stack([valid_mask, valid_mask, valid_mask], axis=-1)
|
|
|
|
| 262 |
out_mask.write(mask_frame)
|
| 263 |
|
| 264 |
# For the rendered output, use the same inpainted result
|
|
|
|
| 278 |
}
|
| 279 |
|
| 280 |
|
| 281 |
+
@spaces.GPU
|
| 282 |
def run_spatial_tracker(video_tensor: torch.Tensor):
|
| 283 |
"""
|
| 284 |
GPU-intensive spatial tracking function.
|
|
|
|
| 289 |
Returns:
|
| 290 |
Dictionary containing tracking results
|
| 291 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 292 |
# Run VGGT to get depth and camera poses
|
| 293 |
video_input = preprocess_image(video_tensor)[None].cuda()
|
| 294 |
|
|
|
|
|
|
|
|
|
|
| 295 |
with torch.no_grad():
|
| 296 |
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 297 |
predictions = vggt4track_model(video_input / 255)
|
|
|
|
| 300 |
depth_map = predictions["points_map"][..., 2]
|
| 301 |
depth_conf = predictions["unc_metric"]
|
| 302 |
|
|
|
|
|
|
|
|
|
|
| 303 |
depth_tensor = depth_map.squeeze().cpu().numpy()
|
| 304 |
extrs = extrinsic.squeeze().cpu().numpy()
|
| 305 |
intrs = intrinsic.squeeze().cpu().numpy()
|
|
|
|
| 307 |
unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5
|
| 308 |
|
| 309 |
# Setup tracker
|
|
|
|
|
|
|
|
|
|
| 310 |
tracker_model.spatrack.track_num = 512
|
| 311 |
tracker_model.to("cuda")
|
| 312 |
|
| 313 |
# Get grid points for tracking
|
| 314 |
frame_H, frame_W = video_tensor_gpu.shape[2:]
|
| 315 |
grid_pts = get_points_on_a_grid(30, (frame_H, frame_W), device="cpu")
|
| 316 |
+
query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[0].numpy()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 317 |
|
| 318 |
# Run tracker
|
| 319 |
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
|
|
|
|
| 341 |
conf_depth = T.Resize((new_h, new_w))(conf_depth)
|
| 342 |
intrs_out[:, :2, :] = intrs_out[:, :2, :] * scale
|
| 343 |
|
|
|
|
|
|
|
|
|
|
| 344 |
# Move results to CPU and return
|
| 345 |
+
return {
|
| 346 |
'video_out': video_out.cpu(),
|
| 347 |
'point_map': point_map.cpu(),
|
| 348 |
'conf_depth': conf_depth.cpu(),
|
|
|
|
| 350 |
'c2w_traj': c2w_traj.cpu(),
|
| 351 |
}
|
| 352 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 353 |
|
| 354 |
def process_video(video_path: str, camera_movement: str, generate_ttm: bool = True, progress=gr.Progress()):
|
| 355 |
"""Main processing function
|
|
|
|
| 405 |
c2w_traj = tracking_results['c2w_traj']
|
| 406 |
|
| 407 |
# Get RGB frames and depth
|
| 408 |
+
rgb_frames = rearrange(video_out.numpy(), "T C H W -> T H W C").astype(np.uint8)
|
|
|
|
| 409 |
depth_frames = point_map[:, 2].numpy()
|
| 410 |
depth_conf_np = conf_depth.numpy()
|
| 411 |
|
|
|
|
| 416 |
intrs_np = intrs_out.numpy()
|
| 417 |
extrs_np = torch.inverse(c2w_traj).numpy() # world-to-camera
|
| 418 |
|
| 419 |
+
progress(0.7, desc=f"Generating {camera_movement} camera trajectory...")
|
|
|
|
| 420 |
|
| 421 |
# Calculate scene scale from depth
|
| 422 |
valid_depth = depth_frames[depth_frames > 0]
|
|
|
|
| 471 |
return None, None, None, None, f"❌ Error: {str(e)}"
|
| 472 |
|
| 473 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 474 |
# Create Gradio interface
|
| 475 |
+
print("🎨 Creating Gradio interface...")
|
|
|
|
| 476 |
|
| 477 |
with gr.Blocks(
|
| 478 |
theme=gr.themes.Soft(),
|
| 479 |
title="🎬 Video to Point Cloud Renderer",
|
| 480 |
css="""
|
| 481 |
.gradio-container {
|
| 482 |
+
max-width: 1200px !important;
|
| 483 |
margin: auto !important;
|
| 484 |
}
|
| 485 |
"""
|
| 486 |
) as demo:
|
| 487 |
gr.Markdown("""
|
| 488 |
+
# 🎬 Video to Point Cloud Renderer (TTM Compatible)
|
| 489 |
|
| 490 |
+
Upload a video to generate a 3D point cloud and render it from a new camera perspective.
|
| 491 |
+
Generates outputs compatible with **Time-to-Move (TTM)** motion-controlled video generation.
|
| 492 |
|
| 493 |
+
**How it works:**
|
| 494 |
+
1. Upload a video
|
| 495 |
+
2. Select a camera movement type
|
| 496 |
+
3. Click "Generate" to create the rendered video and TTM inputs
|
| 497 |
|
| 498 |
+
**TTM Inputs:**
|
| 499 |
+
- `first_frame.png`: The first frame of the original video
|
| 500 |
+
- `motion_signal.mp4`: Warped video with nearest-neighbor inpainting
|
| 501 |
+
- `mask.mp4`: Binary mask showing valid projected pixels (white) vs holes (black)
|
| 502 |
""")
|
| 503 |
|
| 504 |
+
with gr.Row():
|
| 505 |
+
with gr.Column(scale=1):
|
| 506 |
+
gr.Markdown("### 📥 Input")
|
| 507 |
+
video_input = gr.Video(
|
| 508 |
+
label="Upload Video",
|
| 509 |
+
format="mp4",
|
| 510 |
+
height=300
|
| 511 |
+
)
|
| 512 |
+
|
| 513 |
+
camera_movement = gr.Dropdown(
|
| 514 |
+
choices=CAMERA_MOVEMENTS,
|
| 515 |
+
value="static",
|
| 516 |
+
label="🎥 Camera Movement",
|
| 517 |
+
info="Select how the camera should move in the rendered video"
|
| 518 |
+
)
|
| 519 |
+
|
| 520 |
+
generate_ttm = gr.Checkbox(
|
| 521 |
+
label="🎯 Generate TTM Inputs",
|
| 522 |
+
value=True,
|
| 523 |
+
info="Generate motion_signal.mp4 and mask.mp4 for Time-to-Move"
|
| 524 |
+
)
|
| 525 |
+
|
| 526 |
+
generate_btn = gr.Button("🚀 Generate", variant="primary", size="lg")
|
| 527 |
+
|
| 528 |
+
with gr.Column(scale=1):
|
| 529 |
+
gr.Markdown("### 📤 Rendered Output")
|
| 530 |
+
output_video = gr.Video(
|
| 531 |
+
label="Rendered Video",
|
| 532 |
+
height=250
|
| 533 |
+
)
|
| 534 |
+
first_frame_output = gr.Image(
|
| 535 |
+
label="First Frame (first_frame.png)",
|
| 536 |
+
height=150
|
| 537 |
+
)
|
| 538 |
+
|
| 539 |
+
with gr.Row():
|
| 540 |
+
with gr.Column(scale=1):
|
| 541 |
+
gr.Markdown("### 🎯 TTM: Motion Signal")
|
| 542 |
+
motion_signal_output = gr.Video(
|
| 543 |
+
label="Motion Signal Video (motion_signal.mp4)",
|
| 544 |
+
height=250
|
| 545 |
+
)
|
| 546 |
+
with gr.Column(scale=1):
|
| 547 |
+
gr.Markdown("### 🎭 TTM: Mask")
|
| 548 |
+
mask_output = gr.Video(
|
| 549 |
+
label="Mask Video (mask.mp4)",
|
| 550 |
+
height=250
|
| 551 |
+
)
|
| 552 |
+
|
| 553 |
+
status_text = gr.Markdown("Ready to process...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 554 |
|
| 555 |
# Event handlers
|
| 556 |
generate_btn.click(
|
| 557 |
+
fn=process_video,
|
| 558 |
inputs=[video_input, camera_movement, generate_ttm],
|
| 559 |
+
outputs=[output_video, motion_signal_output, mask_output, first_frame_output, status_text]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 560 |
)
|
| 561 |
|
| 562 |
# Examples
|
| 563 |
gr.Markdown("### 📁 Examples")
|
| 564 |
if os.path.exists("./examples"):
|
| 565 |
+
example_videos = [f for f in os.listdir("./examples") if f.endswith(".mp4")][:4]
|
|
|
|
| 566 |
if example_videos:
|
| 567 |
gr.Examples(
|
| 568 |
+
examples=[[f"./examples/{v}", "move_forward", True] for v in example_videos],
|
|
|
|
| 569 |
inputs=[video_input, camera_movement, generate_ttm],
|
| 570 |
+
outputs=[output_video, motion_signal_output, mask_output, first_frame_output, status_text],
|
| 571 |
+
fn=process_video,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 572 |
cache_examples=False
|
| 573 |
)
|
| 574 |
|
| 575 |
# Launch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 576 |
if __name__ == "__main__":
|
|
|
|
|
|
|
| 577 |
demo.launch(share=False)
|
src/config.py
DELETED
|
@@ -1,44 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
|
| 3 |
-
MAX_FRAMES = 80
|
| 4 |
-
OUTPUT_FPS = 24
|
| 5 |
-
RENDER_WIDTH = 512
|
| 6 |
-
RENDER_HEIGHT = 384
|
| 7 |
-
|
| 8 |
-
CAMERA_MOVEMENTS = [
|
| 9 |
-
"static",
|
| 10 |
-
"move_forward",
|
| 11 |
-
"move_backward",
|
| 12 |
-
"move_left",
|
| 13 |
-
"move_right",
|
| 14 |
-
"move_up",
|
| 15 |
-
"move_down"
|
| 16 |
-
]
|
| 17 |
-
|
| 18 |
-
TTM_COG_MODEL_ID = "THUDM/CogVideoX-5b-I2V"
|
| 19 |
-
TTM_WAN_MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
|
| 20 |
-
TTM_DTYPE = torch.bfloat16
|
| 21 |
-
TTM_DEFAULT_NUM_FRAMES = 49
|
| 22 |
-
TTM_DEFAULT_NUM_INFERENCE_STEPS = 50
|
| 23 |
-
|
| 24 |
-
TTM_COG_AVAILABLE = False
|
| 25 |
-
TTM_WAN_AVAILABLE = False
|
| 26 |
-
try:
|
| 27 |
-
from diffusers import CogVideoXImageToVideoPipeline
|
| 28 |
-
TTM_COG_AVAILABLE = True
|
| 29 |
-
except ImportError:
|
| 30 |
-
pass
|
| 31 |
-
|
| 32 |
-
try:
|
| 33 |
-
from diffusers import AutoencoderKLWan, WanTransformer3DModel
|
| 34 |
-
TTM_WAN_AVAILABLE = True
|
| 35 |
-
except ImportError:
|
| 36 |
-
pass
|
| 37 |
-
|
| 38 |
-
TTM_AVAILABLE = TTM_COG_AVAILABLE or TTM_WAN_AVAILABLE
|
| 39 |
-
|
| 40 |
-
TTM_MODELS = []
|
| 41 |
-
if TTM_COG_AVAILABLE:
|
| 42 |
-
TTM_MODELS.append("CogVideoX-5B")
|
| 43 |
-
if TTM_WAN_AVAILABLE:
|
| 44 |
-
TTM_MODELS.append("Wan2.2-14B (Recommended)")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/model_manager.py
DELETED
|
@@ -1,62 +0,0 @@
|
|
| 1 |
-
from models.SpaTrackV2.models.predictor import Predictor
|
| 2 |
-
from models.SpaTrackV2.models.vggt4track.models.vggt_moe import VGGT4Track
|
| 3 |
-
import logging
|
| 4 |
-
from .config import (
|
| 5 |
-
TTM_COG_AVAILABLE, TTM_WAN_AVAILABLE,
|
| 6 |
-
TTM_COG_MODEL_ID, TTM_WAN_MODEL_ID, TTM_DTYPE
|
| 7 |
-
)
|
| 8 |
-
|
| 9 |
-
logger = logging.getLogger(__name__)
|
| 10 |
-
|
| 11 |
-
vggt4track_model = None
|
| 12 |
-
tracker_model = None
|
| 13 |
-
ttm_cog_pipeline = None
|
| 14 |
-
ttm_wan_pipeline = None
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
def init_spatial_models():
|
| 18 |
-
global vggt4track_model, tracker_model
|
| 19 |
-
print("🚀 Initializing models...")
|
| 20 |
-
vggt4track_model = VGGT4Track.from_pretrained(
|
| 21 |
-
"Yuxihenry/SpatialTrackerV2_Front")
|
| 22 |
-
vggt4track_model.eval()
|
| 23 |
-
vggt4track_model = vggt4track_model.to("cuda")
|
| 24 |
-
|
| 25 |
-
tracker_model = Predictor.from_pretrained(
|
| 26 |
-
"Yuxihenry/SpatialTrackerV2-Offline")
|
| 27 |
-
tracker_model.eval()
|
| 28 |
-
print("✅ Spatial Models loaded successfully!")
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
def get_ttm_cog_pipeline():
|
| 32 |
-
global ttm_cog_pipeline
|
| 33 |
-
if ttm_cog_pipeline is None and TTM_COG_AVAILABLE:
|
| 34 |
-
from diffusers import CogVideoXImageToVideoPipeline
|
| 35 |
-
logger.info("Loading TTM CogVideoX pipeline...")
|
| 36 |
-
ttm_cog_pipeline = CogVideoXImageToVideoPipeline.from_pretrained(
|
| 37 |
-
TTM_COG_MODEL_ID,
|
| 38 |
-
torch_dtype=TTM_DTYPE,
|
| 39 |
-
low_cpu_mem_usage=True,
|
| 40 |
-
)
|
| 41 |
-
ttm_cog_pipeline.vae.enable_tiling()
|
| 42 |
-
ttm_cog_pipeline.vae.enable_slicing()
|
| 43 |
-
logger.info("TTM CogVideoX pipeline loaded successfully!")
|
| 44 |
-
return ttm_cog_pipeline
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
def get_ttm_wan_pipeline():
|
| 48 |
-
global ttm_wan_pipeline
|
| 49 |
-
if ttm_wan_pipeline is None and TTM_WAN_AVAILABLE:
|
| 50 |
-
from diffusers import WanImageToVideoPipeline
|
| 51 |
-
logger.info("Loading TTM Wan 2.2 pipeline...")
|
| 52 |
-
ttm_wan_pipeline = WanImageToVideoPipeline.from_pretrained(
|
| 53 |
-
TTM_WAN_MODEL_ID,
|
| 54 |
-
torch_dtype=TTM_DTYPE,
|
| 55 |
-
)
|
| 56 |
-
ttm_wan_pipeline.vae.enable_tiling()
|
| 57 |
-
ttm_wan_pipeline.vae.enable_slicing()
|
| 58 |
-
logger.info("TTM Wan 2.2 pipeline loaded successfully!")
|
| 59 |
-
return ttm_wan_pipeline
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
init_spatial_models()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/spatial_pipeline.py
DELETED
|
@@ -1,277 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import cv2
|
| 3 |
-
import numpy as np
|
| 4 |
-
import torch
|
| 5 |
-
import decord
|
| 6 |
-
import gradio as gr
|
| 7 |
-
import torchvision.transforms as T
|
| 8 |
-
from einops import rearrange
|
| 9 |
-
|
| 10 |
-
from .config import MAX_FRAMES, OUTPUT_FPS
|
| 11 |
-
from .utils import logger, create_user_temp_dir
|
| 12 |
-
from . import model_manager
|
| 13 |
-
from models.SpaTrackV2.models.vggt4track.utils.load_fn import preprocess_image
|
| 14 |
-
from models.SpaTrackV2.models.utils import get_points_on_a_grid
|
| 15 |
-
|
| 16 |
-
try:
|
| 17 |
-
import spaces
|
| 18 |
-
except ImportError:
|
| 19 |
-
class spaces:
|
| 20 |
-
@staticmethod
|
| 21 |
-
def GPU(func=None, duration=None):
|
| 22 |
-
def decorator(f):
|
| 23 |
-
return f
|
| 24 |
-
return decorator if func is None else func
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
def generate_camera_trajectory(num_frames: int, movement_type: str,
|
| 28 |
-
base_intrinsics: np.ndarray,
|
| 29 |
-
scene_scale: float = 1.0) -> np.ndarray:
|
| 30 |
-
speed = scene_scale * 0.02
|
| 31 |
-
extrinsics = np.zeros((num_frames, 4, 4), dtype=np.float32)
|
| 32 |
-
|
| 33 |
-
for t in range(num_frames):
|
| 34 |
-
ext = np.eye(4, dtype=np.float32)
|
| 35 |
-
if movement_type == "static":
|
| 36 |
-
pass
|
| 37 |
-
elif movement_type == "move_forward":
|
| 38 |
-
ext[2, 3] = -speed * t
|
| 39 |
-
elif movement_type == "move_backward":
|
| 40 |
-
ext[2, 3] = speed * t
|
| 41 |
-
elif movement_type == "move_left":
|
| 42 |
-
ext[0, 3] = -speed * t
|
| 43 |
-
elif movement_type == "move_right":
|
| 44 |
-
ext[0, 3] = speed * t
|
| 45 |
-
elif movement_type == "move_up":
|
| 46 |
-
ext[1, 3] = -speed * t
|
| 47 |
-
elif movement_type == "move_down":
|
| 48 |
-
ext[1, 3] = speed * t
|
| 49 |
-
extrinsics[t] = ext
|
| 50 |
-
return extrinsics
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
def render_from_pointcloud(rgb_frames, depth_frames, intrinsics, original_extrinsics,
|
| 54 |
-
new_extrinsics, output_path, fps=24, generate_ttm_inputs=False):
|
| 55 |
-
T, H, W, _ = rgb_frames.shape
|
| 56 |
-
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
| 57 |
-
out = cv2.VideoWriter(output_path, fourcc, fps, (W, H))
|
| 58 |
-
|
| 59 |
-
motion_signal_path = None
|
| 60 |
-
mask_path = None
|
| 61 |
-
out_motion_signal = None
|
| 62 |
-
out_mask = None
|
| 63 |
-
|
| 64 |
-
if generate_ttm_inputs:
|
| 65 |
-
base_dir = os.path.dirname(output_path)
|
| 66 |
-
motion_signal_path = os.path.join(base_dir, "motion_signal.mp4")
|
| 67 |
-
mask_path = os.path.join(base_dir, "mask.mp4")
|
| 68 |
-
out_motion_signal = cv2.VideoWriter(
|
| 69 |
-
motion_signal_path, fourcc, fps, (W, H))
|
| 70 |
-
out_mask = cv2.VideoWriter(mask_path, fourcc, fps, (W, H))
|
| 71 |
-
|
| 72 |
-
u, v = np.meshgrid(np.arange(W), np.arange(H))
|
| 73 |
-
ones = np.ones_like(u)
|
| 74 |
-
|
| 75 |
-
for t in range(T):
|
| 76 |
-
rgb = rgb_frames[t]
|
| 77 |
-
depth = depth_frames[t]
|
| 78 |
-
K = intrinsics[t]
|
| 79 |
-
orig_c2w = np.linalg.inv(original_extrinsics[t])
|
| 80 |
-
|
| 81 |
-
if t == 0:
|
| 82 |
-
base_c2w = orig_c2w.copy()
|
| 83 |
-
new_c2w = base_c2w @ new_extrinsics[t]
|
| 84 |
-
new_w2c = np.linalg.inv(new_c2w)
|
| 85 |
-
K_inv = np.linalg.inv(K)
|
| 86 |
-
|
| 87 |
-
pixels = np.stack([u, v, ones], axis=-1).reshape(-1, 3)
|
| 88 |
-
rays_cam = (K_inv @ pixels.T).T
|
| 89 |
-
points_cam = rays_cam * depth.reshape(-1, 1)
|
| 90 |
-
points_world = (orig_c2w[:3, :3] @ points_cam.T).T + orig_c2w[:3, 3]
|
| 91 |
-
points_new_cam = (new_w2c[:3, :3] @ points_world.T).T + new_w2c[:3, 3]
|
| 92 |
-
points_proj = (K @ points_new_cam.T).T
|
| 93 |
-
|
| 94 |
-
z = np.clip(points_proj[:, 2:3], 1e-6, None)
|
| 95 |
-
uv_new = points_proj[:, :2] / z
|
| 96 |
-
|
| 97 |
-
rendered = np.zeros((H, W, 3), dtype=np.uint8)
|
| 98 |
-
z_buffer = np.full((H, W), np.inf, dtype=np.float32)
|
| 99 |
-
colors = rgb.reshape(-1, 3)
|
| 100 |
-
depths_new = points_new_cam[:, 2]
|
| 101 |
-
|
| 102 |
-
# Rasterization loop (simplified)
|
| 103 |
-
for i in range(len(uv_new)):
|
| 104 |
-
uu, vv = int(round(uv_new[i, 0])), int(round(uv_new[i, 1]))
|
| 105 |
-
if 0 <= uu < W and 0 <= vv < H and depths_new[i] > 0:
|
| 106 |
-
if depths_new[i] < z_buffer[vv, uu]:
|
| 107 |
-
z_buffer[vv, uu] = depths_new[i]
|
| 108 |
-
rendered[vv, uu] = colors[i]
|
| 109 |
-
|
| 110 |
-
# Inpainting for TTM
|
| 111 |
-
valid_mask = (rendered.sum(axis=-1) > 0).astype(np.uint8) * 255
|
| 112 |
-
motion_signal_frame = rendered.copy()
|
| 113 |
-
hole_mask = (motion_signal_frame.sum(axis=-1) == 0).astype(np.uint8)
|
| 114 |
-
|
| 115 |
-
if hole_mask.sum() > 0:
|
| 116 |
-
kernel = np.ones((3, 3), np.uint8)
|
| 117 |
-
for _ in range(max(H, W)):
|
| 118 |
-
if hole_mask.sum() == 0:
|
| 119 |
-
break
|
| 120 |
-
dilated = cv2.dilate(motion_signal_frame, kernel, iterations=1)
|
| 121 |
-
motion_signal_frame = np.where(
|
| 122 |
-
hole_mask[:, :, None] > 0, dilated, motion_signal_frame)
|
| 123 |
-
hole_mask = (motion_signal_frame.sum(
|
| 124 |
-
axis=-1) == 0).astype(np.uint8)
|
| 125 |
-
|
| 126 |
-
if generate_ttm_inputs:
|
| 127 |
-
out_motion_signal.write(cv2.cvtColor(
|
| 128 |
-
motion_signal_frame, cv2.COLOR_RGB2BGR))
|
| 129 |
-
out_mask.write(np.stack([valid_mask]*3, axis=-1))
|
| 130 |
-
|
| 131 |
-
out.write(cv2.cvtColor(motion_signal_frame, cv2.COLOR_RGB2BGR))
|
| 132 |
-
|
| 133 |
-
out.release()
|
| 134 |
-
if generate_ttm_inputs:
|
| 135 |
-
out_motion_signal.release()
|
| 136 |
-
out_mask.release()
|
| 137 |
-
|
| 138 |
-
return {'rendered': output_path, 'motion_signal': motion_signal_path, 'mask': mask_path}
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
@spaces.GPU
|
| 142 |
-
def run_spatial_tracker(video_tensor: torch.Tensor):
|
| 143 |
-
video_input = preprocess_image(video_tensor)[None].cuda()
|
| 144 |
-
|
| 145 |
-
# Use global models from model_manager
|
| 146 |
-
with torch.no_grad():
|
| 147 |
-
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 148 |
-
predictions = model_manager.vggt4track_model(video_input / 255)
|
| 149 |
-
extrinsic = predictions["poses_pred"]
|
| 150 |
-
intrinsic = predictions["intrs"]
|
| 151 |
-
depth_map = predictions["points_map"][..., 2]
|
| 152 |
-
depth_conf = predictions["unc_metric"]
|
| 153 |
-
|
| 154 |
-
depth_tensor = depth_map.squeeze().cpu().numpy()
|
| 155 |
-
extrs = extrinsic.squeeze().cpu().numpy()
|
| 156 |
-
intrs = intrinsic.squeeze().cpu().numpy()
|
| 157 |
-
video_tensor_gpu = video_input.squeeze()
|
| 158 |
-
unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5
|
| 159 |
-
|
| 160 |
-
model_manager.tracker_model.spatrack.track_num = 512
|
| 161 |
-
model_manager.tracker_model.to("cuda")
|
| 162 |
-
|
| 163 |
-
frame_H, frame_W = video_tensor_gpu.shape[2:]
|
| 164 |
-
grid_pts = get_points_on_a_grid(30, (frame_H, frame_W), device="cpu")
|
| 165 |
-
query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[
|
| 166 |
-
0].numpy()
|
| 167 |
-
|
| 168 |
-
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
|
| 169 |
-
results = model_manager.tracker_model.forward(
|
| 170 |
-
video_tensor_gpu, depth=depth_tensor,
|
| 171 |
-
intrs=intrs, extrs=extrs,
|
| 172 |
-
queries=query_xyt,
|
| 173 |
-
fps=1, full_point=False, iters_track=4,
|
| 174 |
-
query_no_BA=True, fixed_cam=False, stage=1,
|
| 175 |
-
unc_metric=unc_metric,
|
| 176 |
-
support_frame=len(video_tensor_gpu)-1, replace_ratio=0.2
|
| 177 |
-
)
|
| 178 |
-
|
| 179 |
-
# Unpack tuple from tracker
|
| 180 |
-
c2w_traj, intrs_out, point_map, conf_depth, track3d_pred, track2d_pred, vis_pred, conf_pred, video_out = results
|
| 181 |
-
|
| 182 |
-
# Resize logic (abbreviated)
|
| 183 |
-
max_size = 384
|
| 184 |
-
h, w = video_out.shape[2:]
|
| 185 |
-
scale = min(max_size / h, max_size / w)
|
| 186 |
-
if scale < 1:
|
| 187 |
-
new_h, new_w = int(h * scale), int(w * scale)
|
| 188 |
-
video_out = T.Resize((new_h, new_w))(video_out)
|
| 189 |
-
point_map = T.Resize((new_h, new_w))(point_map)
|
| 190 |
-
conf_depth = T.Resize((new_h, new_w))(conf_depth)
|
| 191 |
-
intrs_out[:, :2, :] = intrs_out[:, :2, :] * scale
|
| 192 |
-
|
| 193 |
-
return {
|
| 194 |
-
'video_out': video_out.cpu(),
|
| 195 |
-
'point_map': point_map.cpu(),
|
| 196 |
-
'conf_depth': conf_depth.cpu(),
|
| 197 |
-
'intrs_out': intrs_out.cpu(),
|
| 198 |
-
'c2w_traj': c2w_traj.cpu(),
|
| 199 |
-
}
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
def process_video(video_path: str, camera_movement: str, generate_ttm: bool = True, progress=gr.Progress()):
|
| 203 |
-
if video_path is None:
|
| 204 |
-
return None, None, None, None, "❌ Please upload a video first"
|
| 205 |
-
|
| 206 |
-
progress(0, desc="Initializing...")
|
| 207 |
-
temp_dir = create_user_temp_dir()
|
| 208 |
-
out_dir = os.path.join(temp_dir, "results")
|
| 209 |
-
os.makedirs(out_dir, exist_ok=True)
|
| 210 |
-
|
| 211 |
-
try:
|
| 212 |
-
progress(0.1, desc="Loading video...")
|
| 213 |
-
video_reader = decord.VideoReader(video_path)
|
| 214 |
-
video_tensor = torch.from_numpy(
|
| 215 |
-
video_reader.get_batch(range(len(video_reader))).asnumpy()
|
| 216 |
-
).permute(0, 3, 1, 2).float()
|
| 217 |
-
|
| 218 |
-
fps_skip = max(1, len(video_tensor) // MAX_FRAMES)
|
| 219 |
-
video_tensor = video_tensor[::fps_skip][:MAX_FRAMES]
|
| 220 |
-
|
| 221 |
-
h, w = video_tensor.shape[2:]
|
| 222 |
-
scale = 336 / min(h, w)
|
| 223 |
-
if scale < 1:
|
| 224 |
-
new_h, new_w = int(h * scale) // 2 * 2, int(w * scale) // 2 * 2
|
| 225 |
-
video_tensor = T.Resize((new_h, new_w))(video_tensor)
|
| 226 |
-
|
| 227 |
-
progress(0.4, desc="Running 3D tracking...")
|
| 228 |
-
tracking_results = run_spatial_tracker(video_tensor)
|
| 229 |
-
|
| 230 |
-
video_out = tracking_results['video_out']
|
| 231 |
-
point_map = tracking_results['point_map']
|
| 232 |
-
conf_depth = tracking_results['conf_depth']
|
| 233 |
-
intrs_out = tracking_results['intrs_out']
|
| 234 |
-
c2w_traj = tracking_results['c2w_traj']
|
| 235 |
-
|
| 236 |
-
rgb_frames = rearrange(
|
| 237 |
-
video_out.numpy(), "T C H W -> T H W C").astype(np.uint8)
|
| 238 |
-
depth_frames = point_map[:, 2].numpy()
|
| 239 |
-
depth_frames[conf_depth.numpy() < 0.5] = 0
|
| 240 |
-
|
| 241 |
-
intrs_np = intrs_out.numpy()
|
| 242 |
-
extrs_np = torch.inverse(c2w_traj).numpy()
|
| 243 |
-
|
| 244 |
-
progress(
|
| 245 |
-
0.7, desc=f"Generating {camera_movement} camera trajectory...")
|
| 246 |
-
valid_depth = depth_frames[depth_frames > 0]
|
| 247 |
-
scene_scale = np.median(valid_depth) if len(valid_depth) > 0 else 1.0
|
| 248 |
-
|
| 249 |
-
new_extrinsics = generate_camera_trajectory(
|
| 250 |
-
len(rgb_frames), camera_movement, intrs_np, scene_scale
|
| 251 |
-
)
|
| 252 |
-
|
| 253 |
-
progress(0.8, desc="Rendering video...")
|
| 254 |
-
output_video_path = os.path.join(out_dir, "rendered_video.mp4")
|
| 255 |
-
render_results = render_from_pointcloud(
|
| 256 |
-
rgb_frames, depth_frames, intrs_np, extrs_np,
|
| 257 |
-
new_extrinsics, output_video_path, fps=OUTPUT_FPS,
|
| 258 |
-
generate_ttm_inputs=generate_ttm
|
| 259 |
-
)
|
| 260 |
-
|
| 261 |
-
first_frame_path = None
|
| 262 |
-
if generate_ttm:
|
| 263 |
-
first_frame_path = os.path.join(out_dir, "first_frame.png")
|
| 264 |
-
cv2.imwrite(first_frame_path, cv2.cvtColor(
|
| 265 |
-
rgb_frames[0], cv2.COLOR_RGB2BGR))
|
| 266 |
-
|
| 267 |
-
status_msg = f"✅ Video rendered successfully with '{camera_movement}'!"
|
| 268 |
-
if generate_ttm:
|
| 269 |
-
status_msg += "\n\n📁 **TTM outputs generated**"
|
| 270 |
-
|
| 271 |
-
return render_results['rendered'], render_results.get('motion_signal'), render_results.get('mask'), first_frame_path, status_msg
|
| 272 |
-
|
| 273 |
-
except Exception as e:
|
| 274 |
-
logger.error(f"Error processing video: {e}")
|
| 275 |
-
import traceback
|
| 276 |
-
traceback.print_exc()
|
| 277 |
-
return None, None, None, None, f"❌ Error: {str(e)}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/ttm_pipeline.py
DELETED
|
@@ -1,303 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import torch
|
| 3 |
-
import torch.nn.functional as F
|
| 4 |
-
import gradio as gr
|
| 5 |
-
from diffusers.utils import export_to_video, load_image
|
| 6 |
-
from diffusers.utils.torch_utils import randn_tensor
|
| 7 |
-
from diffusers.pipelines.wan.pipeline_wan_i2v import retrieve_latents
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
try:
|
| 11 |
-
import spaces
|
| 12 |
-
except ImportError:
|
| 13 |
-
class spaces:
|
| 14 |
-
@staticmethod
|
| 15 |
-
def GPU(func=None, duration=None):
|
| 16 |
-
def decorator(f):
|
| 17 |
-
return f
|
| 18 |
-
return decorator if func is None else func
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
from .config import TTM_COG_AVAILABLE, TTM_WAN_AVAILABLE
|
| 22 |
-
from .utils import create_user_temp_dir, load_video_to_tensor
|
| 23 |
-
from . import model_manager
|
| 24 |
-
|
| 25 |
-
# --- Helper Classes ---
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
class CogVideoXTTMHelper:
|
| 29 |
-
def __init__(self, pipeline):
|
| 30 |
-
self.pipeline = pipeline
|
| 31 |
-
self.vae = pipeline.vae
|
| 32 |
-
self.vae_scale_factor_spatial = 2 ** (
|
| 33 |
-
len(self.vae.config.block_out_channels) - 1)
|
| 34 |
-
self.vae_scale_factor_temporal = self.vae.config.temporal_compression_ratio
|
| 35 |
-
self.vae_scaling_factor_image = self.vae.config.scaling_factor
|
| 36 |
-
self.video_processor = pipeline.video_processor
|
| 37 |
-
|
| 38 |
-
@torch.no_grad()
|
| 39 |
-
def encode_frames(self, frames: torch.Tensor) -> torch.Tensor:
|
| 40 |
-
latents = self.vae.encode(
|
| 41 |
-
frames)[0].sample() * self.vae_scaling_factor_image
|
| 42 |
-
return latents.permute(0, 2, 1, 3, 4).contiguous()
|
| 43 |
-
|
| 44 |
-
def convert_rgb_mask_to_latent_mask(self, mask: torch.Tensor) -> torch.Tensor:
|
| 45 |
-
k = self.vae_scale_factor_temporal
|
| 46 |
-
mask_sampled = torch.cat([mask[0:1], mask[1::k]], dim=0)
|
| 47 |
-
pooled = mask_sampled.permute(1, 0, 2, 3).unsqueeze(0)
|
| 48 |
-
s = self.vae_scale_factor_spatial
|
| 49 |
-
H_l, W_l = pooled.shape[-2] // s, pooled.shape[-1] // s
|
| 50 |
-
pooled = F.interpolate(pooled, size=(
|
| 51 |
-
pooled.shape[2], H_l, W_l), mode="nearest")
|
| 52 |
-
return pooled.permute(0, 2, 1, 3, 4)
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
class WanTTMHelper:
|
| 56 |
-
def __init__(self, pipeline):
|
| 57 |
-
self.pipeline = pipeline
|
| 58 |
-
self.vae = pipeline.vae
|
| 59 |
-
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal
|
| 60 |
-
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial
|
| 61 |
-
self.video_processor = pipeline.video_processor
|
| 62 |
-
|
| 63 |
-
def convert_rgb_mask_to_latent_mask(self, mask: torch.Tensor) -> torch.Tensor:
|
| 64 |
-
k = self.vae_scale_factor_temporal
|
| 65 |
-
mask_sampled = torch.cat([mask[0:1], mask[1::k]], dim=0)
|
| 66 |
-
pooled = mask_sampled.permute(1, 0, 2, 3).unsqueeze(0)
|
| 67 |
-
s = self.vae_scale_factor_spatial
|
| 68 |
-
H_l, W_l = pooled.shape[-2] // s, pooled.shape[-1] // s
|
| 69 |
-
pooled = F.interpolate(pooled, size=(
|
| 70 |
-
pooled.shape[2], H_l, W_l), mode="nearest")
|
| 71 |
-
return pooled.permute(0, 2, 1, 3, 4)
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
def compute_hw_from_area(h, w, max_area, mod_value):
|
| 75 |
-
aspect = h / w
|
| 76 |
-
height = round(np.sqrt(max_area * aspect)) // mod_value * mod_value
|
| 77 |
-
width = round(np.sqrt(max_area / aspect)) // mod_value * mod_value
|
| 78 |
-
return int(height), int(width)
|
| 79 |
-
|
| 80 |
-
# --- Generation Functions ---
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
@spaces.GPU(duration=300)
def run_ttm_cog_generation(first_frame_path, motion_signal_path, mask_path, prompt,
                           tweak_index=4, tstrong_index=9, num_frames=49,
                           num_inference_steps=50, guidance_scale=6.0, seed=0, progress=gr.Progress()):
    """Generate a video with CogVideoX using Time-To-Move (TTM) conditioning.

    The motion-signal video is encoded to reference latents; inside the user
    mask the denoising trajectory is repeatedly re-anchored to a noised copy
    of those reference latents (until step ``tstrong_index``), so the masked
    region follows the reference motion while the background is generated
    freely from the first frame and the text prompt.

    Args:
        first_frame_path: Path to the conditioning first-frame image.
        motion_signal_path: Path to the reference (motion-signal) video.
        mask_path: Path to a video of binary RGB masks marking the motion region.
        prompt: Text prompt.
        tweak_index: Timestep index at which denoising starts from the noised
            reference latents (earlier steps are skipped).
        tstrong_index: Timestep index up to which masked latents keep being
            re-anchored to the noised reference.
        num_frames: Requested number of output frames.
        num_inference_steps: Scheduler step count.
        guidance_scale: Classifier-free guidance scale (>1 enables CFG).
        seed: RNG seed for reproducibility.
        progress: Gradio progress callback.

    Returns:
        (output_video_path, status_message) on success, (None, error_message)
        on failure.
    """
    if not TTM_COG_AVAILABLE:
        return None, "❌ CogVideoX TTM not available."

    pipe = model_manager.get_ttm_cog_pipeline()
    if not pipe:
        return None, "❌ Failed to load pipeline"
    pipe = pipe.to("cuda")
    ttm_helper = CogVideoXTTMHelper(pipe)

    device = "cuda"
    generator = torch.Generator(device=device).manual_seed(seed)

    # Output resolution is fixed by the transformer's latent sample size.
    image = load_image(first_frame_path)
    height = pipe.transformer.config.sample_height * \
        ttm_helper.vae_scale_factor_spatial
    width = pipe.transformer.config.sample_width * \
        ttm_helper.vae_scale_factor_spatial

    do_cfg = guidance_scale > 1.0
    # 226 is presumably the model's max text token length — confirm against the pipeline.
    prompt_embeds, neg_embeds = pipe.encode_prompt(
        prompt, "", do_cfg, 1, 226, device)
    if do_cfg:
        # [uncond, cond] ordering to match the chunk(2) split below.
        prompt_embeds = torch.cat([neg_embeds, prompt_embeds], dim=0)

    pipe.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = pipe.scheduler.timesteps

    # NOTE(review): latent_frames is computed but never used below.
    latent_frames = (
        num_frames - 1) // ttm_helper.vae_scale_factor_temporal + 1
    image_tensor = ttm_helper.video_processor.preprocess(
        image, height=height, width=width).to(device, dtype=prompt_embeds.dtype)
    # I2V conditioning concatenates image latents with noise latents, hence in_channels // 2.
    latent_channels = pipe.transformer.config.in_channels // 2
    latents, image_latents = pipe.prepare_latents(
        image_tensor, 1, latent_channels, num_frames, height, width, prompt_embeds.dtype, device, generator, None)

    # Encode the motion-signal video to reference latents at the target size.
    ref_vid = load_video_to_tensor(motion_signal_path).to(device)
    ref_vid = F.interpolate(ref_vid.permute(0, 2, 1, 3, 4).flatten(0, 1), size=(
        height, width), mode="bicubic").view(1, -1, 3, height, width).permute(0, 2, 1, 3, 4)
    ref_vid = ttm_helper.video_processor.normalize(
        ref_vid.to(dtype=pipe.vae.dtype))
    ref_latents = ttm_helper.encode_frames(ref_vid).float().detach()

    # Downsample the mask video (single channel of the RGB mask) to latent resolution.
    ref_mask = load_video_to_tensor(mask_path).to(device)
    ref_mask = F.interpolate(ref_mask.permute(0, 2, 1, 3, 4).flatten(0, 1), size=(
        height, width), mode="nearest").view(1, -1, 3, height, width).permute(0, 2, 1, 3, 4)
    motion_mask = ttm_helper.convert_rgb_mask_to_latent_mask(
        ref_mask[0, :, :1].permute(1, 0, 2, 3).contiguous())
    background_mask = 1.0 - motion_mask

    # One fixed noise tensor so the reference is re-noised consistently across steps.
    fixed_noise = randn_tensor(
        ref_latents.shape, generator=generator, device=device, dtype=ref_latents.dtype)
    if tweak_index >= 0:
        # Start denoising from the reference latents noised to timesteps[tweak_index].
        latents = pipe.scheduler.add_noise(
            ref_latents, fixed_noise, timesteps[tweak_index].long()).to(dtype=latents.dtype)

    extra_step_kwargs = pipe.prepare_extra_step_kwargs(generator, 0.0)

    for i, t in enumerate(timesteps[tweak_index:]):
        progress(0.4 + 0.5 * (i / len(timesteps)), desc="Denoising...")

        # Duplicate the batch for classifier-free guidance.
        latent_input = torch.cat([latents] * 2) if do_cfg else latents
        latent_input = pipe.scheduler.scale_model_input(latent_input, t)
        # Channel-concat the first-frame image latents (I2V conditioning).
        latent_input = torch.cat([latent_input, torch.cat(
            [image_latents]*2) if do_cfg else image_latents], dim=2)

        noise_pred = pipe.transformer(hidden_states=latent_input, encoder_hidden_states=prompt_embeds, timestep=t.expand(
            latent_input.shape[0]), return_dict=False)[0].float()

        if do_cfg:
            uncond, text = noise_pred.chunk(2)
            noise_pred = uncond + guidance_scale * (text - uncond)

        latents = pipe.scheduler.step(
            noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

        # TTM anchoring: until tstrong_index, overwrite the masked region with
        # the reference latents noised to the *next* timestep.
        if (i + tweak_index) < tstrong_index:
            next_t = timesteps[i + tweak_index + 1] if i + \
                tweak_index + 1 < len(timesteps) else None
            if next_t is not None:
                noisy_ref = pipe.scheduler.add_noise(
                    ref_latents, fixed_noise, next_t.long()).to(dtype=latents.dtype)
                latents = latents * background_mask + noisy_ref * motion_mask
            else:
                latents = latents * background_mask + ref_latents * motion_mask

        # The prediction was upcast to float32 above; cast back to the
        # transformer dtype before the next iteration.
        latents = latents.to(prompt_embeds.dtype)

    frames = pipe.decode_latents(latents)
    video = ttm_helper.video_processor.postprocess_video(
        video=frames, output_type="pil")

    out_path = os.path.join(create_user_temp_dir(), "ttm_cog_out.mp4")
    export_to_video(video[0], out_path, fps=8)
    return out_path, "✅ Video Generated"
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
@spaces.GPU(duration=300)
def run_ttm_wan_generation(first_frame_path, motion_signal_path, mask_path, prompt, negative_prompt="",
                           tweak_index=3, tstrong_index=7, num_frames=81, num_inference_steps=50,
                           guidance_scale=3.5, seed=0, progress=gr.Progress()):
    """Generate a video with Wan using Time-To-Move (TTM) conditioning.

    Same scheme as the CogVideoX variant: reference latents from the motion
    signal re-anchor the masked region until step ``tstrong_index``, while the
    background is denoised freely from the first frame and prompt.

    Args:
        first_frame_path: Path to the conditioning first-frame image.
        motion_signal_path: Path to the reference (motion-signal) video.
        mask_path: Path to a video of binary RGB masks marking the motion region.
        prompt: Text prompt.
        negative_prompt: Negative text prompt (Wan supports one; CogVideoX path does not).
        tweak_index: Timestep index at which denoising starts from the noised reference.
        tstrong_index: Timestep index up to which masked latents are re-anchored.
        num_frames: Requested frame count (adjusted below to fit the VAE).
        num_inference_steps: Scheduler step count.
        guidance_scale: Classifier-free guidance scale (>1 enables CFG).
        seed: RNG seed for reproducibility.
        progress: Gradio progress callback.

    Returns:
        (output_video_path, status_message) on success, (None, error_message)
        on failure.
    """
    if not TTM_WAN_AVAILABLE:
        return None, "❌ Wan TTM not available."

    pipe = model_manager.get_ttm_wan_pipeline()
    if not pipe:
        return None, "❌ Failed to load pipeline"
    pipe = pipe.to("cuda")
    ttm_helper = WanTTMHelper(pipe)

    device = "cuda"
    generator = torch.Generator(device=device).manual_seed(seed)

    # Fit the first frame into a ~480x832 pixel budget, keeping its aspect
    # ratio and snapping to the VAE / patch-size grid.
    image = load_image(first_frame_path)
    h, w = compute_hw_from_area(image.height, image.width, 480*832,
                                ttm_helper.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1])
    image = image.resize((w, h))

    do_cfg = guidance_scale > 1.0
    # 512 is presumably the text-encoder max sequence length — confirm against the pipeline.
    prompt_embeds, neg_embeds = pipe.encode_prompt(
        prompt, negative_prompt, do_cfg, 1, 512, device)
    prompt_embeds = prompt_embeds.to(pipe.transformer.dtype)
    if neg_embeds is not None:
        neg_embeds = neg_embeds.to(pipe.transformer.dtype)

    # Image-conditioning embeddings, only when the transformer has an image branch.
    image_embeds = pipe.encode_image(image, device).repeat(1, 1, 1).to(
        pipe.transformer.dtype) if pipe.transformer.config.image_dim else None

    pipe.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = pipe.scheduler.timesteps

    # Force num_frames ≡ 1 (mod temporal scale), presumably required by the
    # VAE's temporal compression.
    if num_frames % ttm_helper.vae_scale_factor_temporal != 1:
        num_frames = num_frames // ttm_helper.vae_scale_factor_temporal * \
            ttm_helper.vae_scale_factor_temporal + 1

    image_tensor = ttm_helper.video_processor.preprocess(
        image, height=h, width=w).to(device, dtype=torch.float32)
    latents, condition = pipe.prepare_latents(
        image_tensor, 1, pipe.vae.config.z_dim, h, w, num_frames, torch.float32, device, generator, None, None)

    # Encode the motion-signal video to reference latents at the target size.
    ref_vid = load_video_to_tensor(motion_signal_path).to(device)
    ref_vid = F.interpolate(ref_vid.permute(0, 2, 1, 3, 4).flatten(0, 1), size=(
        h, w), mode="bicubic").view(1, -1, 3, h, w).permute(0, 2, 1, 3, 4)
    ref_vid = ttm_helper.video_processor.normalize(
        ref_vid.to(dtype=pipe.vae.dtype))
    ref_latents = retrieve_latents(
        pipe.vae.encode(ref_vid), sample_mode="argmax")

    # Per-channel latent statistics; std here is the reciprocal, so
    # (x - mean) * std normalizes and x / std + mean denormalizes.
    mean = torch.tensor(pipe.vae.config.latents_mean).view(
        1, -1, 1, 1, 1).to(device, ref_latents.dtype)
    std = 1.0 / torch.tensor(pipe.vae.config.latents_std).view(1, -
                                                               1, 1, 1, 1).to(device, ref_latents.dtype)
    ref_latents = (ref_latents - mean) * std

    # Binarize the mask video and downsample it to latent layout (B, C, T, H, W).
    ref_mask = load_video_to_tensor(mask_path).to(device)
    ref_mask = F.interpolate(ref_mask.permute(0, 2, 1, 3, 4).flatten(0, 1), size=(
        h, w), mode="nearest").view(1, -1, 3, h, w).permute(0, 2, 1, 3, 4)
    mask_tc_hw = ref_mask[0].permute(1, 0, 2, 3).contiguous()[:num_frames]
    motion_mask = ttm_helper.convert_rgb_mask_to_latent_mask(
        (mask_tc_hw > 0.5).float()).permute(0, 2, 1, 3, 4).contiguous()
    background_mask = 1.0 - motion_mask

    # One fixed noise tensor so the reference is re-noised consistently across steps.
    fixed_noise = randn_tensor(
        ref_latents.shape, generator=generator, device=device, dtype=ref_latents.dtype)
    if tweak_index >= 0:
        # Start denoising from the reference latents noised to timesteps[tweak_index].
        latents = pipe.scheduler.add_noise(ref_latents, fixed_noise, torch.as_tensor(
            timesteps[tweak_index], device=device).long())

    for i, t in enumerate(timesteps[tweak_index:]):
        progress(0.4 + 0.5 * (i / len(timesteps)), desc=f"Step {i}")

        # Channel-concat the image-condition latents (I2V conditioning).
        latent_in = torch.cat([latents, condition], dim=1).to(
            pipe.transformer.dtype)
        ts = t.expand(latents.shape[0])

        noise_pred = pipe.transformer(hidden_states=latent_in, timestep=ts, encoder_hidden_states=prompt_embeds,
                                      encoder_hidden_states_image=image_embeds, return_dict=False)[0]

        if do_cfg:
            # CFG runs as two forward passes here instead of one batched pass.
            noise_uncond = pipe.transformer(hidden_states=latent_in, timestep=ts, encoder_hidden_states=neg_embeds,
                                            encoder_hidden_states_image=image_embeds, return_dict=False)[0]
            noise_pred = noise_uncond + guidance_scale * \
                (noise_pred - noise_uncond)

        latents = pipe.scheduler.step(
            noise_pred, t, latents, return_dict=False)[0]

        # TTM anchoring: until tstrong_index, overwrite the masked region with
        # the reference latents noised to the *next* timestep.
        if (i + tweak_index) < tstrong_index:
            next_t = timesteps[i + tweak_index + 1] if i + \
                tweak_index + 1 < len(timesteps) else None
            if next_t is not None:
                noisy_ref = pipe.scheduler.add_noise(
                    ref_latents, fixed_noise, torch.as_tensor(next_t, device=device).long())
                latents = latents * background_mask + noisy_ref * motion_mask
            else:
                latents = latents * background_mask + \
                    ref_latents.to(latents.dtype) * motion_mask

    # De-normalize and decode back to pixel space.
    latents = latents.to(pipe.vae.dtype)
    latents = latents / std + mean
    video = pipe.vae.decode(latents, return_dict=False)[0]
    video = ttm_helper.video_processor.postprocess_video(
        video, output_type="pil")

    out_path = os.path.join(create_user_temp_dir(), "ttm_wan_out.mp4")
    export_to_video(video[0], out_path, fps=16)
    return out_path, "✅ Video Generated"
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
def run_ttm_generation(first_frame_path, motion_signal_path, mask_path, prompt, negative_prompt,
                       model_choice, tweak_index, tstrong_index, num_frames, num_inference_steps,
                       guidance_scale, seed, progress=gr.Progress()):
    """Dispatch a TTM generation request to the Wan or CogVideoX backend,
    depending on ``model_choice``. Returns (video_path, status_message)."""
    if "Wan" in model_choice:
        # Only the Wan backend accepts a negative prompt.
        return run_ttm_wan_generation(
            first_frame_path, motion_signal_path, mask_path, prompt, negative_prompt,
            tweak_index, tstrong_index, num_frames, num_inference_steps,
            guidance_scale, seed, progress)
    # CogVideoX backend has no negative-prompt parameter.
    return run_ttm_cog_generation(
        first_frame_path, motion_signal_path, mask_path, prompt,
        tweak_index, tstrong_index, num_frames, num_inference_steps,
        guidance_scale, seed, progress)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/utils.py
DELETED
|
@@ -1,57 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import cv2
|
| 3 |
-
import time
|
| 4 |
-
import shutil
|
| 5 |
-
import logging
|
| 6 |
-
import uuid
|
| 7 |
-
import torch
|
| 8 |
-
import numpy as np
|
| 9 |
-
import atexit
|
| 10 |
-
from concurrent.futures import ThreadPoolExecutor
|
| 11 |
-
from typing import Union
|
| 12 |
-
|
| 13 |
-
logging.basicConfig(level=logging.INFO)
|
| 14 |
-
logger = logging.getLogger(__name__)
|
| 15 |
-
|
| 16 |
-
thread_pool_executor = ThreadPoolExecutor(max_workers=2)
|
| 17 |
-
|
| 18 |
-
def delete_later(path: Union[str, os.PathLike], delay: int = 600):
    """Arrange for *path* (file or directory) to be removed after *delay*
    seconds, with an additional best-effort sweep at interpreter exit.

    NOTE(review): each pending deletion occupies a thread-pool worker for the
    full delay, and every call registers a new atexit hook — confirm this is
    acceptable for the expected session rate.
    """
    def _remove_path():
        # Best-effort removal; failures are logged, never raised.
        try:
            if os.path.isfile(path):
                os.remove(path)
            elif os.path.isdir(path):
                shutil.rmtree(path)
        except Exception as e:
            logger.warning(f"Failed to delete {path}: {e}")

    def _delayed_remove():
        time.sleep(delay)
        _remove_path()

    thread_pool_executor.submit(_delayed_remove)
    atexit.register(_remove_path)
|
| 34 |
-
|
| 35 |
-
def create_user_temp_dir():
    """Create a per-session directory under ``temp_local/`` and schedule it
    for deletion in 10 minutes. Returns the directory path."""
    # 8 hex chars of a UUID4 are plenty to avoid collisions between sessions.
    session_id = uuid.uuid4().hex[:8]
    temp_dir = os.path.join("temp_local", f"session_{session_id}")
    os.makedirs(temp_dir, exist_ok=True)
    delete_later(temp_dir, delay=600)
    return temp_dir
|
| 41 |
-
|
| 42 |
-
def load_video_to_tensor(video_path: str) -> torch.Tensor:
    """Load a video file into a float tensor with values in [0, 1].

    Args:
        video_path: Path to a video file readable by OpenCV.

    Returns:
        Tensor of shape (1, 3, num_frames, height, width), dtype float32.

    Raises:
        ValueError: If the file cannot be opened or yields no decodable
            frames (previously this crashed later with an opaque shape error).
    """
    cap = cv2.VideoCapture(video_path)
    frames = []
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # OpenCV decodes to BGR; convert to RGB for the rest of the pipeline.
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Release the capture even if decoding raises.
        cap.release()

    if not frames:
        raise ValueError(f"No frames could be read from video: {video_path}")

    # (T, H, W, C) uint8 -> (T, C, H, W) float in [0, 1].
    # from_numpy avoids the extra copy torch.tensor() would make.
    video_tensor = torch.from_numpy(np.stack(frames))
    video_tensor = video_tensor.permute(0, 3, 1, 2).float() / 255.0
    # Add batch dim and move channels before time: (1, C, T, H, W).
    return video_tensor.unsqueeze(0).permute(0, 2, 1, 3, 4)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|