abreza commited on
Commit
ae7b7e0
·
1 Parent(s): c8dc4de
Files changed (6) hide show
  1. app.py +98 -1226
  2. src/config.py +0 -44
  3. src/model_manager.py +0 -62
  4. src/spatial_pipeline.py +0 -277
  5. src/ttm_pipeline.py +0 -303
  6. src/utils.py +0 -57
app.py CHANGED
@@ -1,4 +1,3 @@
1
- import sys
2
  import gradio as gr
3
  import os
4
  import numpy as np
@@ -8,109 +7,27 @@ import shutil
8
  from pathlib import Path
9
  from einops import rearrange
10
  from typing import Union
11
-
12
- # Force unbuffered output for HF Spaces logs
13
- os.environ['PYTHONUNBUFFERED'] = '1'
14
-
15
- # Configure logging FIRST before any other imports
16
- import logging
17
- logging.basicConfig(
18
- level=logging.INFO,
19
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
20
- handlers=[
21
- logging.StreamHandler(sys.stdout)
22
- ]
23
- )
24
- logger = logging.getLogger(__name__)
25
- logger.info("=" * 50)
26
- logger.info("Starting application initialization...")
27
- logger.info("=" * 50)
28
- sys.stdout.flush()
29
-
30
  try:
31
  import spaces
32
- logger.info("✅ HF Spaces module imported successfully")
33
  except ImportError:
34
- logger.warning("⚠️ HF Spaces module not available, using mock")
35
- class spaces:
36
- @staticmethod
37
- def GPU(func=None, duration=None):
38
- def decorator(f):
39
- return f
40
- return decorator if func is None else func
41
- sys.stdout.flush()
42
-
43
- logger.info("Importing torch...")
44
- sys.stdout.flush()
45
  import torch
46
- logger.info(f"✅ Torch imported. Version: {torch.__version__}, CUDA available: {torch.cuda.is_available()}")
47
- sys.stdout.flush()
48
-
49
- import torch.nn.functional as F
50
  import torchvision.transforms as T
 
51
  from concurrent.futures import ThreadPoolExecutor
52
  import atexit
53
  import uuid
54
-
55
- logger.info("Importing decord...")
56
- sys.stdout.flush()
57
  import decord
58
- logger.info("✅ Decord imported successfully")
59
- sys.stdout.flush()
60
 
61
- from PIL import Image
 
 
 
62
 
63
- logger.info("Importing SpaTrack models...")
64
- sys.stdout.flush()
65
- try:
66
- from models.SpaTrackV2.models.vggt4track.models.vggt_moe import VGGT4Track
67
- from models.SpaTrackV2.models.vggt4track.utils.load_fn import preprocess_image
68
- from models.SpaTrackV2.models.predictor import Predictor
69
- from models.SpaTrackV2.models.utils import get_points_on_a_grid
70
- logger.info("✅ SpaTrack models imported successfully")
71
- except Exception as e:
72
- logger.error(f"❌ Failed to import SpaTrack models: {e}")
73
- raise
74
- sys.stdout.flush()
75
-
76
- # TTM imports (optional - will be loaded on demand)
77
- logger.info("Checking TTM (diffusers) availability...")
78
- sys.stdout.flush()
79
- TTM_COG_AVAILABLE = False
80
- TTM_WAN_AVAILABLE = False
81
- try:
82
- from diffusers import CogVideoXImageToVideoPipeline
83
- from diffusers.utils import export_to_video, load_image
84
- from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
85
- from diffusers.utils.torch_utils import randn_tensor
86
- from diffusers.video_processor import VideoProcessor
87
- TTM_COG_AVAILABLE = True
88
- logger.info("✅ CogVideoX TTM available")
89
- except ImportError as e:
90
- logger.info(f"ℹ️ CogVideoX TTM not available: {e}")
91
- sys.stdout.flush()
92
-
93
- try:
94
- from diffusers import AutoencoderKLWan, WanTransformer3DModel
95
- from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
96
- from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline, retrieve_latents
97
- from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
98
- if not TTM_COG_AVAILABLE:
99
- from diffusers.utils import export_to_video, load_image
100
- from diffusers.utils.torch_utils import randn_tensor
101
- from diffusers.video_processor import VideoProcessor
102
- TTM_WAN_AVAILABLE = True
103
- logger.info("✅ Wan TTM available")
104
- except ImportError as e:
105
- logger.info(f"ℹ️ Wan TTM not available: {e}")
106
- sys.stdout.flush()
107
-
108
- TTM_AVAILABLE = TTM_COG_AVAILABLE or TTM_WAN_AVAILABLE
109
- if not TTM_AVAILABLE:
110
- logger.warning("⚠️ Diffusers not available. TTM features will be disabled.")
111
- else:
112
- logger.info(f"TTM Status - CogVideoX: {TTM_COG_AVAILABLE}, Wan: {TTM_WAN_AVAILABLE}")
113
- sys.stdout.flush()
114
 
115
  # Constants
116
  MAX_FRAMES = 80
@@ -129,126 +46,9 @@ CAMERA_MOVEMENTS = [
129
  "move_down"
130
  ]
131
 
132
- # TTM Constants
133
- TTM_COG_MODEL_ID = "THUDM/CogVideoX-5b-I2V"
134
- TTM_WAN_MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
135
- TTM_DTYPE = torch.bfloat16
136
- TTM_DEFAULT_NUM_FRAMES = 49
137
- TTM_DEFAULT_NUM_INFERENCE_STEPS = 50
138
-
139
- # TTM Model choices
140
- TTM_MODELS = []
141
- if TTM_COG_AVAILABLE:
142
- TTM_MODELS.append("CogVideoX-5B")
143
- if TTM_WAN_AVAILABLE:
144
- TTM_MODELS.append("Wan2.2-14B (Recommended)")
145
-
146
- # Global model instances (lazy loaded for HF Spaces GPU compatibility)
147
- vggt4track_model = None
148
- tracker_model = None
149
- ttm_cog_pipeline = None
150
- ttm_wan_pipeline = None
151
- MODELS_LOADED = False
152
-
153
-
154
- def load_video_to_tensor(video_path: str) -> torch.Tensor:
155
- """Returns a video tensor from a video file. shape [1, C, T, H, W], [0, 1] range."""
156
- cap = cv2.VideoCapture(video_path)
157
- frames = []
158
- while True:
159
- ret, frame = cap.read()
160
- if not ret:
161
- break
162
- frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
163
- frames.append(frame)
164
- cap.release()
165
-
166
- frames = np.array(frames)
167
- video_tensor = torch.tensor(frames)
168
- video_tensor = video_tensor.permute(0, 3, 1, 2).float() / 255.0
169
- video_tensor = video_tensor.unsqueeze(0).permute(0, 2, 1, 3, 4)
170
- return video_tensor
171
-
172
-
173
- def get_ttm_cog_pipeline():
174
- """Lazy load CogVideoX TTM pipeline to save memory."""
175
- global ttm_cog_pipeline
176
- if ttm_cog_pipeline is None and TTM_COG_AVAILABLE:
177
- logger.info("Loading TTM CogVideoX pipeline...")
178
- ttm_cog_pipeline = CogVideoXImageToVideoPipeline.from_pretrained(
179
- TTM_COG_MODEL_ID,
180
- torch_dtype=TTM_DTYPE,
181
- low_cpu_mem_usage=True,
182
- )
183
- ttm_cog_pipeline.vae.enable_tiling()
184
- ttm_cog_pipeline.vae.enable_slicing()
185
- logger.info("TTM CogVideoX pipeline loaded successfully!")
186
- return ttm_cog_pipeline
187
-
188
-
189
- def get_ttm_wan_pipeline():
190
- """Lazy load Wan TTM pipeline to save memory."""
191
- global ttm_wan_pipeline
192
- if ttm_wan_pipeline is None and TTM_WAN_AVAILABLE:
193
- logger.info("Loading TTM Wan 2.2 pipeline...")
194
- ttm_wan_pipeline = WanImageToVideoPipeline.from_pretrained(
195
- TTM_WAN_MODEL_ID,
196
- torch_dtype=TTM_DTYPE,
197
- )
198
- ttm_wan_pipeline.vae.enable_tiling()
199
- ttm_wan_pipeline.vae.enable_slicing()
200
- logger.info("TTM Wan 2.2 pipeline loaded successfully!")
201
- return ttm_wan_pipeline
202
-
203
-
204
- logger.info("Setting up thread pool and utility functions...")
205
- sys.stdout.flush()
206
-
207
  # Thread pool for delayed deletion
208
  thread_pool_executor = ThreadPoolExecutor(max_workers=2)
209
 
210
-
211
- def load_models():
212
- """Load models lazily when GPU is available (inside @spaces.GPU decorated function)."""
213
- global vggt4track_model, tracker_model, MODELS_LOADED
214
-
215
- if MODELS_LOADED:
216
- logger.info("Models already loaded, skipping...")
217
- return
218
-
219
- logger.info("🚀 Starting model loading...")
220
- sys.stdout.flush()
221
-
222
- try:
223
- logger.info("Loading VGGT4Track model from 'Yuxihenry/SpatialTrackerV2_Front'...")
224
- sys.stdout.flush()
225
- vggt4track_model = VGGT4Track.from_pretrained("Yuxihenry/SpatialTrackerV2_Front")
226
- vggt4track_model.eval()
227
- logger.info("✅ VGGT4Track model loaded, moving to CUDA...")
228
- sys.stdout.flush()
229
- vggt4track_model = vggt4track_model.to("cuda")
230
- logger.info("✅ VGGT4Track model on CUDA")
231
- sys.stdout.flush()
232
-
233
- logger.info("Loading Predictor model from 'Yuxihenry/SpatialTrackerV2-Offline'...")
234
- sys.stdout.flush()
235
- tracker_model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
236
- tracker_model.eval()
237
- logger.info("✅ Predictor model loaded")
238
- sys.stdout.flush()
239
-
240
- MODELS_LOADED = True
241
- logger.info("✅ All models loaded successfully!")
242
- sys.stdout.flush()
243
-
244
- except Exception as e:
245
- logger.error(f"❌ Failed to load models: {e}")
246
- import traceback
247
- traceback.print_exc()
248
- sys.stdout.flush()
249
- raise
250
-
251
-
252
  def delete_later(path: Union[str, os.PathLike], delay: int = 600):
253
  """Delete file or directory after specified delay"""
254
  def _delete():
@@ -267,7 +67,6 @@ def delete_later(path: Union[str, os.PathLike], delay: int = 600):
267
  thread_pool_executor.submit(_wait_and_delete)
268
  atexit.register(_delete)
269
 
270
-
271
  def create_user_temp_dir():
272
  """Create a unique temporary directory for each user session"""
273
  session_id = str(uuid.uuid4())[:8]
@@ -276,16 +75,17 @@ def create_user_temp_dir():
276
  delete_later(temp_dir, delay=600)
277
  return temp_dir
278
 
 
 
 
 
 
279
 
280
- # Note: Models are loaded lazily inside @spaces.GPU decorated functions
281
- # This is required for HF Spaces ZeroGPU compatibility
282
- logger.info("Models will be loaded lazily when GPU is available")
283
- sys.stdout.flush()
284
 
285
- logger.info("Setting up Gradio static paths...")
286
  gr.set_static_paths(paths=[Path.cwd().absolute()/"_viz"])
287
- logger.info("✅ Static paths configured")
288
- sys.stdout.flush()
289
 
290
 
291
  def generate_camera_trajectory(num_frames: int, movement_type: str,
@@ -311,8 +111,7 @@ def generate_camera_trajectory(num_frames: int, movement_type: str,
311
  if movement_type == "static":
312
  pass # Keep identity
313
  elif movement_type == "move_forward":
314
- # Move along -Z (forward in OpenGL convention)
315
- ext[2, 3] = -speed * t
316
  elif movement_type == "move_backward":
317
  ext[2, 3] = speed * t # Move along +Z
318
  elif movement_type == "move_left":
@@ -369,8 +168,7 @@ def render_from_pointcloud(rgb_frames: np.ndarray,
369
  base_dir = os.path.dirname(output_path)
370
  motion_signal_path = os.path.join(base_dir, "motion_signal.mp4")
371
  mask_path = os.path.join(base_dir, "mask.mp4")
372
- out_motion_signal = cv2.VideoWriter(
373
- motion_signal_path, fourcc, fps, (W, H))
374
  out_mask = cv2.VideoWriter(mask_path, fourcc, fps, (W, H))
375
 
376
  # Create meshgrid for pixel coordinates
@@ -450,21 +248,17 @@ def render_from_pointcloud(rgb_frames: np.ndarray,
450
  if hole_mask.sum() == 0:
451
  break
452
  dilated = cv2.dilate(motion_signal_frame, kernel, iterations=1)
453
- motion_signal_frame = np.where(
454
- hole_mask[:, :, None] > 0, dilated, motion_signal_frame)
455
- hole_mask = (motion_signal_frame.sum(
456
- axis=-1) == 0).astype(np.uint8)
457
 
458
  # Write TTM outputs if enabled
459
  if generate_ttm_inputs:
460
  # Motion signal: warped frame with NN inpainting
461
- motion_signal_bgr = cv2.cvtColor(
462
- motion_signal_frame, cv2.COLOR_RGB2BGR)
463
  out_motion_signal.write(motion_signal_bgr)
464
 
465
  # Mask: binary mask of valid (projected) pixels - white where valid, black where holes
466
- mask_frame = np.stack(
467
- [valid_mask, valid_mask, valid_mask], axis=-1)
468
  out_mask.write(mask_frame)
469
 
470
  # For the rendered output, use the same inpainted result
@@ -484,7 +278,7 @@ def render_from_pointcloud(rgb_frames: np.ndarray,
484
  }
485
 
486
 
487
- @spaces.GPU(duration=180)
488
  def run_spatial_tracker(video_tensor: torch.Tensor):
489
  """
490
  GPU-intensive spatial tracking function.
@@ -495,23 +289,9 @@ def run_spatial_tracker(video_tensor: torch.Tensor):
495
  Returns:
496
  Dictionary containing tracking results
497
  """
498
- global vggt4track_model, tracker_model
499
-
500
- logger.info("run_spatial_tracker: Starting GPU execution...")
501
- sys.stdout.flush()
502
-
503
- # Load models if not already loaded (lazy loading for HF Spaces)
504
- load_models()
505
-
506
- logger.info("run_spatial_tracker: Preprocessing video input...")
507
- sys.stdout.flush()
508
-
509
  # Run VGGT to get depth and camera poses
510
  video_input = preprocess_image(video_tensor)[None].cuda()
511
 
512
- logger.info("run_spatial_tracker: Running VGGT inference...")
513
- sys.stdout.flush()
514
-
515
  with torch.no_grad():
516
  with torch.cuda.amp.autocast(dtype=torch.bfloat16):
517
  predictions = vggt4track_model(video_input / 255)
@@ -520,9 +300,6 @@ def run_spatial_tracker(video_tensor: torch.Tensor):
520
  depth_map = predictions["points_map"][..., 2]
521
  depth_conf = predictions["unc_metric"]
522
 
523
- logger.info("run_spatial_tracker: VGGT inference complete")
524
- sys.stdout.flush()
525
-
526
  depth_tensor = depth_map.squeeze().cpu().numpy()
527
  extrs = extrinsic.squeeze().cpu().numpy()
528
  intrs = intrinsic.squeeze().cpu().numpy()
@@ -530,20 +307,13 @@ def run_spatial_tracker(video_tensor: torch.Tensor):
530
  unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5
531
 
532
  # Setup tracker
533
- logger.info("run_spatial_tracker: Setting up tracker...")
534
- sys.stdout.flush()
535
-
536
  tracker_model.spatrack.track_num = 512
537
  tracker_model.to("cuda")
538
 
539
  # Get grid points for tracking
540
  frame_H, frame_W = video_tensor_gpu.shape[2:]
541
  grid_pts = get_points_on_a_grid(30, (frame_H, frame_W), device="cpu")
542
- query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[
543
- 0].numpy()
544
-
545
- logger.info("run_spatial_tracker: Running 3D tracker...")
546
- sys.stdout.flush()
547
 
548
  # Run tracker
549
  with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
@@ -571,11 +341,8 @@ def run_spatial_tracker(video_tensor: torch.Tensor):
571
  conf_depth = T.Resize((new_h, new_w))(conf_depth)
572
  intrs_out[:, :2, :] = intrs_out[:, :2, :] * scale
573
 
574
- logger.info("run_spatial_tracker: Moving results to CPU...")
575
- sys.stdout.flush()
576
-
577
  # Move results to CPU and return
578
- result = {
579
  'video_out': video_out.cpu(),
580
  'point_map': point_map.cpu(),
581
  'conf_depth': conf_depth.cpu(),
@@ -583,11 +350,6 @@ def run_spatial_tracker(video_tensor: torch.Tensor):
583
  'c2w_traj': c2w_traj.cpu(),
584
  }
585
 
586
- logger.info("run_spatial_tracker: Complete!")
587
- sys.stdout.flush()
588
-
589
- return result
590
-
591
 
592
  def process_video(video_path: str, camera_movement: str, generate_ttm: bool = True, progress=gr.Progress()):
593
  """Main processing function
@@ -643,8 +405,7 @@ def process_video(video_path: str, camera_movement: str, generate_ttm: bool = Tr
643
  c2w_traj = tracking_results['c2w_traj']
644
 
645
  # Get RGB frames and depth
646
- rgb_frames = rearrange(
647
- video_out.numpy(), "T C H W -> T H W C").astype(np.uint8)
648
  depth_frames = point_map[:, 2].numpy()
649
  depth_conf_np = conf_depth.numpy()
650
 
@@ -655,8 +416,7 @@ def process_video(video_path: str, camera_movement: str, generate_ttm: bool = Tr
655
  intrs_np = intrs_out.numpy()
656
  extrs_np = torch.inverse(c2w_traj).numpy() # world-to-camera
657
 
658
- progress(
659
- 0.7, desc=f"Generating {camera_movement} camera trajectory...")
660
 
661
  # Calculate scene scale from depth
662
  valid_depth = depth_frames[depth_frames > 0]
@@ -711,995 +471,107 @@ def process_video(video_path: str, camera_movement: str, generate_ttm: bool = Tr
711
  return None, None, None, None, f"❌ Error: {str(e)}"
712
 
713
 
714
- # TTM CogVideoX Pipeline Helper Classes and Functions
715
- class CogVideoXTTMHelper:
716
- """Helper class for TTM-style video generation using CogVideoX pipeline."""
717
-
718
- def __init__(self, pipeline):
719
- self.pipeline = pipeline
720
- self.vae = pipeline.vae
721
- self.transformer = pipeline.transformer
722
- self.scheduler = pipeline.scheduler
723
- self.vae_scale_factor_spatial = 2 ** (
724
- len(self.vae.config.block_out_channels) - 1)
725
- self.vae_scale_factor_temporal = self.vae.config.temporal_compression_ratio
726
- self.vae_scaling_factor_image = self.vae.config.scaling_factor
727
- self.video_processor = pipeline.video_processor
728
-
729
- @torch.no_grad()
730
- def encode_frames(self, frames: torch.Tensor) -> torch.Tensor:
731
- """Encode video frames into latent space. Input shape (B, C, F, H, W), expected range [-1, 1]."""
732
- latents = self.vae.encode(frames)[0].sample()
733
- latents = latents * self.vae_scaling_factor_image
734
- # (B, C, F, H, W) -> (B, F, C, H, W)
735
- return latents.permute(0, 2, 1, 3, 4).contiguous()
736
-
737
- def convert_rgb_mask_to_latent_mask(self, mask: torch.Tensor) -> torch.Tensor:
738
- """Convert a per-frame mask [T, 1, H, W] to latent resolution [1, T_latent, 1, H', W']."""
739
- k = self.vae_scale_factor_temporal
740
-
741
- mask0 = mask[0:1]
742
- mask1 = mask[1::k]
743
- sampled = torch.cat([mask0, mask1], dim=0)
744
- pooled = sampled.permute(1, 0, 2, 3).unsqueeze(0)
745
-
746
- s = self.vae_scale_factor_spatial
747
- H_latent = pooled.shape[-2] // s
748
- W_latent = pooled.shape[-1] // s
749
- pooled = F.interpolate(pooled, size=(
750
- pooled.shape[2], H_latent, W_latent), mode="nearest")
751
-
752
- latent_mask = pooled.permute(0, 2, 1, 3, 4)
753
- return latent_mask
754
-
755
-
756
- # TTM Wan Pipeline Helper Class
757
- class WanTTMHelper:
758
- """Helper class for TTM-style video generation using Wan pipeline."""
759
-
760
- def __init__(self, pipeline):
761
- self.pipeline = pipeline
762
- self.vae = pipeline.vae
763
- self.transformer = pipeline.transformer
764
- self.scheduler = pipeline.scheduler
765
- self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal
766
- self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial
767
- self.video_processor = pipeline.video_processor
768
-
769
- def convert_rgb_mask_to_latent_mask(self, mask: torch.Tensor) -> torch.Tensor:
770
- """Convert a per-frame mask [T, 1, H, W] to latent resolution [1, T_latent, 1, H', W']."""
771
- k = self.vae_scale_factor_temporal
772
-
773
- mask0 = mask[0:1]
774
- mask1 = mask[1::k]
775
- sampled = torch.cat([mask0, mask1], dim=0)
776
- pooled = sampled.permute(1, 0, 2, 3).unsqueeze(0)
777
-
778
- s = self.vae_scale_factor_spatial
779
- H_latent = pooled.shape[-2] // s
780
- W_latent = pooled.shape[-1] // s
781
- pooled = F.interpolate(pooled, size=(
782
- pooled.shape[2], H_latent, W_latent), mode="nearest")
783
-
784
- latent_mask = pooled.permute(0, 2, 1, 3, 4)
785
- return latent_mask
786
-
787
-
788
- def compute_hw_from_area(image_height: int, image_width: int, max_area: int, mod_value: int) -> tuple:
789
- """Compute (height, width) with proper aspect ratio and rounding."""
790
- aspect_ratio = image_height / image_width
791
- height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
792
- width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
793
- return int(height), int(width)
794
-
795
-
796
- @spaces.GPU(duration=300)
797
- def run_ttm_cog_generation(
798
- first_frame_path: str,
799
- motion_signal_path: str,
800
- mask_path: str,
801
- prompt: str,
802
- tweak_index: int = 4,
803
- tstrong_index: int = 9,
804
- num_frames: int = 49,
805
- num_inference_steps: int = 50,
806
- guidance_scale: float = 6.0,
807
- seed: int = 0,
808
- progress=gr.Progress()
809
- ):
810
- """
811
- Run TTM-style video generation using CogVideoX pipeline.
812
- Uses the generated motion signal and mask to guide video generation.
813
- """
814
- if not TTM_COG_AVAILABLE:
815
- return None, "❌ CogVideoX TTM is not available. Please install diffusers package."
816
-
817
- if first_frame_path is None or motion_signal_path is None or mask_path is None:
818
- return None, "❌ Please generate TTM inputs first (first_frame, motion_signal, mask)"
819
-
820
- progress(0, desc="Loading CogVideoX TTM pipeline...")
821
-
822
- try:
823
- # Get or load the pipeline
824
- pipe = get_ttm_cog_pipeline()
825
- if pipe is None:
826
- return None, "❌ Failed to load CogVideoX TTM pipeline"
827
-
828
- pipe = pipe.to("cuda")
829
-
830
- # Create helper
831
- ttm_helper = CogVideoXTTMHelper(pipe)
832
-
833
- progress(0.1, desc="Loading inputs...")
834
-
835
- # Load first frame
836
- image = load_image(first_frame_path)
837
-
838
- # Get dimensions
839
- height = pipe.transformer.config.sample_height * \
840
- ttm_helper.vae_scale_factor_spatial
841
- width = pipe.transformer.config.sample_width * \
842
- ttm_helper.vae_scale_factor_spatial
843
-
844
- device = "cuda"
845
- generator = torch.Generator(device=device).manual_seed(seed)
846
-
847
- progress(0.15, desc="Encoding prompt...")
848
-
849
- # Encode prompt
850
- do_classifier_free_guidance = guidance_scale > 1.0
851
- prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
852
- prompt=prompt,
853
- negative_prompt="",
854
- do_classifier_free_guidance=do_classifier_free_guidance,
855
- num_videos_per_prompt=1,
856
- max_sequence_length=226,
857
- device=device,
858
- )
859
- if do_classifier_free_guidance:
860
- prompt_embeds = torch.cat(
861
- [negative_prompt_embeds, prompt_embeds], dim=0)
862
-
863
- progress(0.2, desc="Preparing latents...")
864
-
865
- # Prepare timesteps
866
- pipe.scheduler.set_timesteps(num_inference_steps, device=device)
867
- timesteps = pipe.scheduler.timesteps
868
-
869
- # Prepare latents
870
- latent_frames = (
871
- num_frames - 1) // ttm_helper.vae_scale_factor_temporal + 1
872
-
873
- # Handle padding for CogVideoX 1.5
874
- patch_size_t = pipe.transformer.config.patch_size_t
875
- additional_frames = 0
876
- if patch_size_t is not None and latent_frames % patch_size_t != 0:
877
- additional_frames = patch_size_t - latent_frames % patch_size_t
878
- num_frames += additional_frames * ttm_helper.vae_scale_factor_temporal
879
-
880
- # Preprocess image
881
- image_tensor = ttm_helper.video_processor.preprocess(image, height=height, width=width).to(
882
- device, dtype=prompt_embeds.dtype
883
- )
884
-
885
- latent_channels = pipe.transformer.config.in_channels // 2
886
- latents, image_latents = pipe.prepare_latents(
887
- image_tensor,
888
- 1, # batch_size
889
- latent_channels,
890
- num_frames,
891
- height,
892
- width,
893
- prompt_embeds.dtype,
894
- device,
895
- generator,
896
- None,
897
- )
898
-
899
- progress(0.3, desc="Loading motion signal and mask...")
900
-
901
- # Load motion signal video
902
- ref_vid = load_video_to_tensor(motion_signal_path).to(device=device)
903
- refB, refC, refT, refH, refW = ref_vid.shape
904
- ref_vid = F.interpolate(
905
- ref_vid.permute(0, 2, 1, 3, 4).reshape(
906
- refB*refT, refC, refH, refW),
907
- size=(height, width), mode="bicubic", align_corners=True,
908
- ).reshape(refB, refT, refC, height, width).permute(0, 2, 1, 3, 4)
909
-
910
- ref_vid = ttm_helper.video_processor.normalize(
911
- ref_vid.to(dtype=pipe.vae.dtype))
912
- ref_latents = ttm_helper.encode_frames(ref_vid).float().detach()
913
-
914
- # Load mask video
915
- ref_mask = load_video_to_tensor(mask_path).to(device=device)
916
- mB, mC, mT, mH, mW = ref_mask.shape
917
- ref_mask = F.interpolate(
918
- ref_mask.permute(0, 2, 1, 3, 4).reshape(mB*mT, mC, mH, mW),
919
- size=(height, width), mode="nearest",
920
- ).reshape(mB, mT, mC, height, width).permute(0, 2, 1, 3, 4)
921
- ref_mask = ref_mask[0].permute(1, 0, 2, 3).contiguous()
922
-
923
- if len(ref_mask.shape) == 4:
924
- ref_mask = ref_mask.unsqueeze(0)
925
-
926
- ref_mask = ref_mask[0, :, :1].contiguous()
927
- ref_mask = (ref_mask > 0.5).float().max(dim=1, keepdim=True)[0]
928
- motion_mask = ttm_helper.convert_rgb_mask_to_latent_mask(ref_mask)
929
- background_mask = 1.0 - motion_mask
930
-
931
- progress(0.35, desc="Initializing TTM denoising...")
932
-
933
- # Initialize with noisy reference latents at tweak timestep
934
- if tweak_index >= 0:
935
- tweak = timesteps[tweak_index]
936
- fixed_noise = randn_tensor(
937
- ref_latents.shape,
938
- generator=generator,
939
- device=ref_latents.device,
940
- dtype=ref_latents.dtype,
941
- )
942
- noisy_latents = pipe.scheduler.add_noise(
943
- ref_latents, fixed_noise, tweak.long())
944
- latents = noisy_latents.to(
945
- dtype=latents.dtype, device=latents.device)
946
- else:
947
- fixed_noise = randn_tensor(
948
- ref_latents.shape,
949
- generator=generator,
950
- device=ref_latents.device,
951
- dtype=ref_latents.dtype,
952
- )
953
- tweak_index = 0
954
-
955
- # Prepare extra step kwargs
956
- extra_step_kwargs = pipe.prepare_extra_step_kwargs(generator, 0.0)
957
-
958
- # Create rotary embeddings if required
959
- image_rotary_emb = (
960
- pipe._prepare_rotary_positional_embeddings(
961
- height, width, latents.size(1), device)
962
- if pipe.transformer.config.use_rotary_positional_embeddings
963
- else None
964
- )
965
-
966
- # Create ofs embeddings if required
967
- ofs_emb = None if pipe.transformer.config.ofs_embed_dim is None else latents.new_full(
968
- (1,), fill_value=2.0)
969
-
970
- progress(0.4, desc="Running TTM denoising loop...")
971
-
972
- # Denoising loop
973
- total_steps = len(timesteps) - tweak_index
974
- old_pred_original_sample = None
975
-
976
- for i, t in enumerate(timesteps[tweak_index:]):
977
- step_progress = 0.4 + 0.5 * (i / total_steps)
978
- progress(step_progress,
979
- desc=f"Denoising step {i+1}/{total_steps}...")
980
-
981
- latent_model_input = torch.cat(
982
- [latents] * 2) if do_classifier_free_guidance else latents
983
- latent_model_input = pipe.scheduler.scale_model_input(
984
- latent_model_input, t)
985
-
986
- latent_image_input = torch.cat(
987
- [image_latents] * 2) if do_classifier_free_guidance else image_latents
988
- latent_model_input = torch.cat(
989
- [latent_model_input, latent_image_input], dim=2)
990
-
991
- timestep = t.expand(latent_model_input.shape[0])
992
-
993
- # Predict noise
994
- noise_pred = pipe.transformer(
995
- hidden_states=latent_model_input,
996
- encoder_hidden_states=prompt_embeds,
997
- timestep=timestep,
998
- ofs=ofs_emb,
999
- image_rotary_emb=image_rotary_emb,
1000
- return_dict=False,
1001
- )[0]
1002
- noise_pred = noise_pred.float()
1003
-
1004
- # Perform guidance
1005
- if do_classifier_free_guidance:
1006
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1007
- noise_pred = noise_pred_uncond + guidance_scale * \
1008
- (noise_pred_text - noise_pred_uncond)
1009
-
1010
- # Compute previous noisy sample
1011
- if not isinstance(pipe.scheduler, CogVideoXDPMScheduler):
1012
- latents, old_pred_original_sample = pipe.scheduler.step(
1013
- noise_pred, t, latents, **extra_step_kwargs, return_dict=False
1014
- )
1015
- else:
1016
- latents, old_pred_original_sample = pipe.scheduler.step(
1017
- noise_pred,
1018
- old_pred_original_sample,
1019
- t,
1020
- timesteps[i - 1] if i > 0 else None,
1021
- latents,
1022
- **extra_step_kwargs,
1023
- return_dict=False,
1024
- )
1025
-
1026
- # TTM: In between tweak and tstrong, replace mask with noisy reference latents
1027
- in_between_tweak_tstrong = (i + tweak_index) < tstrong_index
1028
-
1029
- if in_between_tweak_tstrong:
1030
- if i + tweak_index + 1 < len(timesteps):
1031
- prev_t = timesteps[i + tweak_index + 1]
1032
- noisy_latents = pipe.scheduler.add_noise(ref_latents, fixed_noise, prev_t.long()).to(
1033
- dtype=latents.dtype, device=latents.device
1034
- )
1035
- latents = latents * background_mask + noisy_latents * motion_mask
1036
- else:
1037
- latents = latents * background_mask + ref_latents * motion_mask
1038
-
1039
- latents = latents.to(prompt_embeds.dtype)
1040
-
1041
- progress(0.9, desc="Decoding video...")
1042
-
1043
- # Decode latents
1044
- latents = latents[:, additional_frames:]
1045
- frames = pipe.decode_latents(latents)
1046
- video = ttm_helper.video_processor.postprocess_video(
1047
- video=frames, output_type="pil")
1048
-
1049
- progress(0.95, desc="Saving video...")
1050
-
1051
- # Save video
1052
- temp_dir = create_user_temp_dir()
1053
- output_path = os.path.join(temp_dir, "ttm_output.mp4")
1054
- export_to_video(video[0], output_path, fps=8)
1055
-
1056
- progress(1.0, desc="Done!")
1057
-
1058
- return output_path, f"✅ CogVideoX TTM video generated successfully!\n\n**Parameters:**\n- Model: CogVideoX-5B\n- tweak_index: {tweak_index}\n- tstrong_index: {tstrong_index}\n- guidance_scale: {guidance_scale}\n- seed: {seed}"
1059
-
1060
- except Exception as e:
1061
- logger.error(f"Error in CogVideoX TTM generation: {e}")
1062
- import traceback
1063
- traceback.print_exc()
1064
- return None, f"❌ Error: {str(e)}"
1065
-
1066
-
1067
- @spaces.GPU(duration=300)
1068
- def run_ttm_wan_generation(
1069
- first_frame_path: str,
1070
- motion_signal_path: str,
1071
- mask_path: str,
1072
- prompt: str,
1073
- negative_prompt: str = "",
1074
- tweak_index: int = 3,
1075
- tstrong_index: int = 7,
1076
- num_frames: int = 81,
1077
- num_inference_steps: int = 50,
1078
- guidance_scale: float = 3.5,
1079
- seed: int = 0,
1080
- progress=gr.Progress()
1081
- ):
1082
- """
1083
- Run TTM-style video generation using Wan 2.2 pipeline.
1084
- This is the recommended model for TTM as it produces higher-quality results.
1085
- """
1086
- if not TTM_WAN_AVAILABLE:
1087
- return None, "❌ Wan TTM is not available. Please install diffusers with Wan support."
1088
-
1089
- if first_frame_path is None or motion_signal_path is None or mask_path is None:
1090
- return None, "❌ Please generate TTM inputs first (first_frame, motion_signal, mask)"
1091
-
1092
- progress(0, desc="Loading Wan 2.2 TTM pipeline...")
1093
-
1094
- try:
1095
- # Get or load the pipeline
1096
- pipe = get_ttm_wan_pipeline()
1097
- if pipe is None:
1098
- return None, "❌ Failed to load Wan TTM pipeline"
1099
-
1100
- pipe = pipe.to("cuda")
1101
-
1102
- # Create helper
1103
- ttm_helper = WanTTMHelper(pipe)
1104
-
1105
- progress(0.1, desc="Loading inputs...")
1106
-
1107
- # Load first frame
1108
- image = load_image(first_frame_path)
1109
-
1110
- # Get dimensions - compute based on image aspect ratio
1111
- max_area = 480 * 832
1112
- mod_value = ttm_helper.vae_scale_factor_spatial * \
1113
- pipe.transformer.config.patch_size[1]
1114
- height, width = compute_hw_from_area(
1115
- image.height, image.width, max_area, mod_value)
1116
- image = image.resize((width, height))
1117
-
1118
- device = "cuda"
1119
- gen_device = device if device.startswith("cuda") else "cpu"
1120
- generator = torch.Generator(device=gen_device).manual_seed(seed)
1121
-
1122
- progress(0.15, desc="Encoding prompt...")
1123
-
1124
- # Encode prompt
1125
- do_classifier_free_guidance = guidance_scale > 1.0
1126
- prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
1127
- prompt=prompt,
1128
- negative_prompt=negative_prompt if negative_prompt else None,
1129
- do_classifier_free_guidance=do_classifier_free_guidance,
1130
- num_videos_per_prompt=1,
1131
- max_sequence_length=512,
1132
- device=device,
1133
- )
1134
-
1135
- # Get transformer dtype
1136
- transformer_dtype = pipe.transformer.dtype
1137
- prompt_embeds = prompt_embeds.to(transformer_dtype)
1138
- if negative_prompt_embeds is not None:
1139
- negative_prompt_embeds = negative_prompt_embeds.to(
1140
- transformer_dtype)
1141
-
1142
- # Encode image embedding if transformer supports it
1143
- image_embeds = None
1144
- if pipe.transformer.config.image_dim is not None:
1145
- image_embeds = pipe.encode_image(image, device)
1146
- image_embeds = image_embeds.repeat(1, 1, 1)
1147
- image_embeds = image_embeds.to(transformer_dtype)
1148
-
1149
- progress(0.2, desc="Preparing latents...")
1150
-
1151
- # Prepare timesteps
1152
- pipe.scheduler.set_timesteps(num_inference_steps, device=device)
1153
- timesteps = pipe.scheduler.timesteps
1154
-
1155
- # Adjust num_frames to be valid for VAE
1156
- if num_frames % ttm_helper.vae_scale_factor_temporal != 1:
1157
- num_frames = num_frames // ttm_helper.vae_scale_factor_temporal * \
1158
- ttm_helper.vae_scale_factor_temporal + 1
1159
- num_frames = max(num_frames, 1)
1160
-
1161
- # Prepare latent variables
1162
- num_channels_latents = pipe.vae.config.z_dim
1163
- image_tensor = ttm_helper.video_processor.preprocess(
1164
- image, height=height, width=width).to(device, dtype=torch.float32)
1165
-
1166
- latents_outputs = pipe.prepare_latents(
1167
- image_tensor,
1168
- 1, # batch_size
1169
- num_channels_latents,
1170
- height,
1171
- width,
1172
- num_frames,
1173
- torch.float32,
1174
- device,
1175
- generator,
1176
- None,
1177
- None, # last_image
1178
- )
1179
-
1180
- if hasattr(pipe, 'config') and pipe.config.expand_timesteps:
1181
- latents, condition, first_frame_mask = latents_outputs
1182
- else:
1183
- latents, condition = latents_outputs
1184
- first_frame_mask = None
1185
-
1186
- progress(0.3, desc="Loading motion signal and mask...")
1187
-
1188
- # Load motion signal video
1189
- ref_vid = load_video_to_tensor(motion_signal_path).to(device=device)
1190
- refB, refC, refT, refH, refW = ref_vid.shape
1191
- ref_vid = F.interpolate(
1192
- ref_vid.permute(0, 2, 1, 3, 4).reshape(
1193
- refB*refT, refC, refH, refW),
1194
- size=(height, width), mode="bicubic", align_corners=True,
1195
- ).reshape(refB, refT, refC, height, width).permute(0, 2, 1, 3, 4)
1196
-
1197
- ref_vid = ttm_helper.video_processor.normalize(
1198
- ref_vid.to(dtype=pipe.vae.dtype))
1199
- ref_latents = retrieve_latents(
1200
- pipe.vae.encode(ref_vid), sample_mode="argmax")
1201
-
1202
- # Normalize latents
1203
- latents_mean = torch.tensor(pipe.vae.config.latents_mean).view(
1204
- 1, pipe.vae.config.z_dim, 1, 1, 1).to(ref_latents.device, ref_latents.dtype)
1205
- latents_std = 1.0 / torch.tensor(pipe.vae.config.latents_std).view(
1206
- 1, pipe.vae.config.z_dim, 1, 1, 1).to(ref_latents.device, ref_latents.dtype)
1207
- ref_latents = (ref_latents - latents_mean) * latents_std
1208
-
1209
- # Load mask video
1210
- ref_mask = load_video_to_tensor(mask_path).to(device=device)
1211
- mB, mC, mT, mH, mW = ref_mask.shape
1212
- ref_mask = F.interpolate(
1213
- ref_mask.permute(0, 2, 1, 3, 4).reshape(mB*mT, mC, mH, mW),
1214
- size=(height, width), mode="nearest",
1215
- ).reshape(mB, mT, mC, height, width).permute(0, 2, 1, 3, 4)
1216
- mask_tc_hw = ref_mask[0].permute(1, 0, 2, 3).contiguous()
1217
-
1218
- # Align time dimension
1219
- if mask_tc_hw.shape[0] > num_frames:
1220
- mask_tc_hw = mask_tc_hw[:num_frames]
1221
- elif mask_tc_hw.shape[0] < num_frames:
1222
- return None, f"❌ num_frames ({num_frames}) > mask frames ({mask_tc_hw.shape[0]}). Please use more mask frames."
1223
-
1224
- # Reduce channels if needed
1225
- if mask_tc_hw.shape[1] > 1:
1226
- mask_t1_hw = (mask_tc_hw > 0.5).any(dim=1, keepdim=True).float()
1227
- else:
1228
- mask_t1_hw = (mask_tc_hw > 0.5).float()
1229
-
1230
- motion_mask = ttm_helper.convert_rgb_mask_to_latent_mask(
1231
- mask_t1_hw).permute(0, 2, 1, 3, 4).contiguous()
1232
- background_mask = 1.0 - motion_mask
1233
-
1234
- progress(0.35, desc="Initializing TTM denoising...")
1235
-
1236
- # Initialize with noisy reference latents at tweak timestep
1237
- if tweak_index >= 0 and tweak_index < len(timesteps):
1238
- tweak = timesteps[tweak_index]
1239
- fixed_noise = randn_tensor(
1240
- ref_latents.shape,
1241
- generator=generator,
1242
- device=ref_latents.device,
1243
- dtype=ref_latents.dtype,
1244
- )
1245
- tweak_t = torch.as_tensor(
1246
- tweak, device=ref_latents.device, dtype=torch.long).view(1)
1247
- noisy_latents = pipe.scheduler.add_noise(
1248
- ref_latents, fixed_noise, tweak_t.long())
1249
- latents = noisy_latents.to(
1250
- dtype=latents.dtype, device=latents.device)
1251
- else:
1252
- fixed_noise = randn_tensor(
1253
- ref_latents.shape,
1254
- generator=generator,
1255
- device=ref_latents.device,
1256
- dtype=ref_latents.dtype,
1257
- )
1258
- tweak_index = 0
1259
-
1260
- progress(0.4, desc="Running TTM denoising loop...")
1261
-
1262
- # Denoising loop
1263
- total_steps = len(timesteps) - tweak_index
1264
-
1265
- for i, t in enumerate(timesteps[tweak_index:]):
1266
- step_progress = 0.4 + 0.5 * (i / total_steps)
1267
- progress(step_progress,
1268
- desc=f"Denoising step {i+1}/{total_steps}...")
1269
-
1270
- # Prepare model input
1271
- if first_frame_mask is not None:
1272
- latent_model_input = (1 - first_frame_mask) * \
1273
- condition + first_frame_mask * latents
1274
- latent_model_input = latent_model_input.to(transformer_dtype)
1275
- temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten()
1276
- timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
1277
- else:
1278
- latent_model_input = torch.cat(
1279
- [latents, condition], dim=1).to(transformer_dtype)
1280
- timestep = t.expand(latents.shape[0])
1281
-
1282
- # Predict noise (conditional)
1283
- noise_pred = pipe.transformer(
1284
- hidden_states=latent_model_input,
1285
- timestep=timestep,
1286
- encoder_hidden_states=prompt_embeds,
1287
- encoder_hidden_states_image=image_embeds,
1288
- return_dict=False,
1289
- )[0]
1290
-
1291
- # CFG
1292
- if do_classifier_free_guidance:
1293
- noise_uncond = pipe.transformer(
1294
- hidden_states=latent_model_input,
1295
- timestep=timestep,
1296
- encoder_hidden_states=negative_prompt_embeds,
1297
- encoder_hidden_states_image=image_embeds,
1298
- return_dict=False,
1299
- )[0]
1300
- noise_pred = noise_uncond + guidance_scale * \
1301
- (noise_pred - noise_uncond)
1302
-
1303
- # Scheduler step
1304
- latents = pipe.scheduler.step(
1305
- noise_pred, t, latents, return_dict=False)[0]
1306
-
1307
- # TTM: In between tweak and tstrong, replace mask with noisy reference latents
1308
- in_between_tweak_tstrong = (i + tweak_index) < tstrong_index
1309
-
1310
- if in_between_tweak_tstrong:
1311
- if i + tweak_index + 1 < len(timesteps):
1312
- prev_t = timesteps[i + tweak_index + 1]
1313
- prev_t = torch.as_tensor(
1314
- prev_t, device=ref_latents.device, dtype=torch.long).view(1)
1315
- noisy_latents = pipe.scheduler.add_noise(ref_latents, fixed_noise, prev_t.long()).to(
1316
- dtype=latents.dtype, device=latents.device
1317
- )
1318
- latents = latents * background_mask + noisy_latents * motion_mask
1319
- else:
1320
- latents = latents * background_mask + \
1321
- ref_latents.to(dtype=latents.dtype,
1322
- device=latents.device) * motion_mask
1323
-
1324
- progress(0.9, desc="Decoding video...")
1325
-
1326
- # Apply first frame mask if used
1327
- if first_frame_mask is not None:
1328
- latents = (1 - first_frame_mask) * condition + \
1329
- first_frame_mask * latents
1330
-
1331
- # Decode latents
1332
- latents = latents.to(pipe.vae.dtype)
1333
- latents_mean = torch.tensor(pipe.vae.config.latents_mean).view(
1334
- 1, pipe.vae.config.z_dim, 1, 1, 1).to(latents.device, latents.dtype)
1335
- latents_std = 1.0 / torch.tensor(pipe.vae.config.latents_std).view(
1336
- 1, pipe.vae.config.z_dim, 1, 1, 1).to(latents.device, latents.dtype)
1337
- latents = latents / latents_std + latents_mean
1338
- video = pipe.vae.decode(latents, return_dict=False)[0]
1339
- video = ttm_helper.video_processor.postprocess_video(
1340
- video, output_type="pil")
1341
-
1342
- progress(0.95, desc="Saving video...")
1343
-
1344
- # Save video
1345
- temp_dir = create_user_temp_dir()
1346
- output_path = os.path.join(temp_dir, "ttm_wan_output.mp4")
1347
- export_to_video(video[0], output_path, fps=16)
1348
-
1349
- progress(1.0, desc="Done!")
1350
-
1351
- return output_path, f"✅ Wan 2.2 TTM video generated successfully!\n\n**Parameters:**\n- Model: Wan2.2-14B\n- tweak_index: {tweak_index}\n- tstrong_index: {tstrong_index}\n- guidance_scale: {guidance_scale}\n- seed: {seed}"
1352
-
1353
- except Exception as e:
1354
- logger.error(f"Error in Wan TTM generation: {e}")
1355
- import traceback
1356
- traceback.print_exc()
1357
- return None, f"❌ Error: {str(e)}"
1358
-
1359
-
1360
- def run_ttm_generation(
1361
- first_frame_path: str,
1362
- motion_signal_path: str,
1363
- mask_path: str,
1364
- prompt: str,
1365
- negative_prompt: str,
1366
- model_choice: str,
1367
- tweak_index: int,
1368
- tstrong_index: int,
1369
- num_frames: int,
1370
- num_inference_steps: int,
1371
- guidance_scale: float,
1372
- seed: int,
1373
- progress=gr.Progress()
1374
- ):
1375
- """
1376
- Router function that calls the appropriate TTM generation based on model choice.
1377
- """
1378
- if "Wan" in model_choice:
1379
- return run_ttm_wan_generation(
1380
- first_frame_path=first_frame_path,
1381
- motion_signal_path=motion_signal_path,
1382
- mask_path=mask_path,
1383
- prompt=prompt,
1384
- negative_prompt=negative_prompt,
1385
- tweak_index=tweak_index,
1386
- tstrong_index=tstrong_index,
1387
- num_frames=num_frames,
1388
- num_inference_steps=num_inference_steps,
1389
- guidance_scale=guidance_scale,
1390
- seed=seed,
1391
- progress=progress,
1392
- )
1393
- else:
1394
- return run_ttm_cog_generation(
1395
- first_frame_path=first_frame_path,
1396
- motion_signal_path=motion_signal_path,
1397
- mask_path=mask_path,
1398
- prompt=prompt,
1399
- tweak_index=tweak_index,
1400
- tstrong_index=tstrong_index,
1401
- num_frames=num_frames,
1402
- num_inference_steps=num_inference_steps,
1403
- guidance_scale=guidance_scale,
1404
- seed=seed,
1405
- progress=progress,
1406
- )
1407
-
1408
-
1409
  # Create Gradio interface
1410
- logger.info("🎨 Creating Gradio interface...")
1411
- sys.stdout.flush()
1412
 
1413
  with gr.Blocks(
1414
  theme=gr.themes.Soft(),
1415
  title="🎬 Video to Point Cloud Renderer",
1416
  css="""
1417
  .gradio-container {
1418
- max-width: 1400px !important;
1419
  margin: auto !important;
1420
  }
1421
  """
1422
  ) as demo:
1423
  gr.Markdown("""
1424
- # 🎬 Video to Point Cloud Renderer + TTM Video Generation
1425
 
1426
- Upload a video to generate a 3D point cloud, render it from a new camera perspective,
1427
- and optionally run **Time-to-Move (TTM)** for motion-controlled video generation.
1428
 
1429
- **Workflow:**
1430
- 1. **Step 1**: Upload a video and select camera movement → Generate motion signal & mask
1431
- 2. **Step 2**: (Optional) Run TTM to generate a high-quality video with the motion signal
 
1432
 
1433
- **TTM (Time-to-Move)** uses dual-clock denoising to guide video generation using:
1434
- - `first_frame.png`: Starting frame
1435
- - `motion_signal.mp4`: Warped video showing desired motion
1436
- - `mask.mp4`: Binary mask for motion regions
1437
  """)
1438
 
1439
- # State to store paths for TTM
1440
- first_frame_state = gr.State(None)
1441
- motion_signal_state = gr.State(None)
1442
- mask_state = gr.State(None)
1443
-
1444
- with gr.Tabs():
1445
- with gr.Tab("📥 Step 1: Generate Motion Signal"):
1446
- with gr.Row():
1447
- with gr.Column(scale=1):
1448
- gr.Markdown("### 📥 Input")
1449
- video_input = gr.Video(
1450
- label="Upload Video",
1451
- format="mp4",
1452
- height=300
1453
- )
1454
-
1455
- camera_movement = gr.Dropdown(
1456
- choices=CAMERA_MOVEMENTS,
1457
- value="static",
1458
- label="🎥 Camera Movement",
1459
- info="Select how the camera should move in the rendered video"
1460
- )
1461
-
1462
- generate_ttm = gr.Checkbox(
1463
- label="🎯 Generate TTM Inputs",
1464
- value=True,
1465
- info="Generate motion_signal.mp4 and mask.mp4 for Time-to-Move"
1466
- )
1467
-
1468
- generate_btn = gr.Button(
1469
- "🚀 Generate Motion Signal", variant="primary", size="lg")
1470
-
1471
- with gr.Column(scale=1):
1472
- gr.Markdown("### 📤 Rendered Output")
1473
- output_video = gr.Video(
1474
- label="Rendered Video",
1475
- height=250
1476
- )
1477
- first_frame_output = gr.Image(
1478
- label="First Frame (first_frame.png)",
1479
- height=150
1480
- )
1481
-
1482
- with gr.Row():
1483
- with gr.Column(scale=1):
1484
- gr.Markdown("### 🎯 TTM: Motion Signal")
1485
- motion_signal_output = gr.Video(
1486
- label="Motion Signal Video (motion_signal.mp4)",
1487
- height=250
1488
- )
1489
- with gr.Column(scale=1):
1490
- gr.Markdown("### 🎭 TTM: Mask")
1491
- mask_output = gr.Video(
1492
- label="Mask Video (mask.mp4)",
1493
- height=250
1494
- )
1495
-
1496
- status_text = gr.Markdown("Ready to process...")
1497
-
1498
- with gr.Tab("🎬 Step 2: TTM Video Generation"):
1499
- cog_available = "✅" if TTM_COG_AVAILABLE else "❌"
1500
- wan_available = "✅" if TTM_WAN_AVAILABLE else "❌"
1501
- gr.Markdown(f"""
1502
- ### 🎬 Time-to-Move (TTM) Video Generation
1503
-
1504
- **Model Availability:**
1505
- - {cog_available} CogVideoX-5B-I2V
1506
- - {wan_available} Wan 2.2-14B (Recommended - higher quality)
1507
-
1508
- **TTM Parameters:**
1509
- - **tweak_index**: When denoising starts *outside* the mask (lower = more dynamic background)
1510
- - **tstrong_index**: When denoising starts *inside* the mask (higher = more constrained motion)
1511
-
1512
- **Recommended values:**
1513
- - CogVideoX - Cut-and-Drag: `tweak_index=4`, `tstrong_index=9`
1514
- - CogVideoX - Camera control: `tweak_index=3`, `tstrong_index=7`
1515
- - **Wan 2.2 (Recommended)**: `tweak_index=3`, `tstrong_index=7`
1516
- """)
1517
-
1518
- with gr.Row():
1519
- with gr.Column(scale=1):
1520
- gr.Markdown("### ⚙️ TTM Settings")
1521
-
1522
- ttm_model_choice = gr.Dropdown(
1523
- choices=TTM_MODELS,
1524
- value=TTM_MODELS[1] if TTM_WAN_AVAILABLE else TTM_MODELS[0],
1525
- label="Model",
1526
- info="Wan 2.2 is recommended for higher quality"
1527
- )
1528
-
1529
- ttm_prompt = gr.Textbox(
1530
- label="Prompt",
1531
- placeholder="Describe the video content...",
1532
- value="A high quality video, smooth motion, natural lighting",
1533
- lines=2
1534
- )
1535
-
1536
- ttm_negative_prompt = gr.Textbox(
1537
- label="Negative Prompt (Wan only)",
1538
- placeholder="Things to avoid in the video...",
1539
- value="",
1540
- lines=1,
1541
- visible=True
1542
- )
1543
-
1544
- with gr.Row():
1545
- ttm_tweak_index = gr.Slider(
1546
- minimum=0,
1547
- maximum=20,
1548
- value=3,
1549
- step=1,
1550
- label="tweak_index",
1551
- info="When background denoising starts"
1552
- )
1553
- ttm_tstrong_index = gr.Slider(
1554
- minimum=0,
1555
- maximum=30,
1556
- value=7,
1557
- step=1,
1558
- label="tstrong_index",
1559
- info="When mask region denoising starts"
1560
- )
1561
-
1562
- with gr.Row():
1563
- ttm_num_frames = gr.Slider(
1564
- minimum=17,
1565
- maximum=81,
1566
- value=49,
1567
- step=4,
1568
- label="Number of Frames"
1569
- )
1570
- ttm_guidance_scale = gr.Slider(
1571
- minimum=1.0,
1572
- maximum=15.0,
1573
- value=3.5,
1574
- step=0.5,
1575
- label="Guidance Scale"
1576
- )
1577
-
1578
- with gr.Row():
1579
- ttm_num_steps = gr.Slider(
1580
- minimum=20,
1581
- maximum=100,
1582
- value=50,
1583
- step=5,
1584
- label="Inference Steps"
1585
- )
1586
- ttm_seed = gr.Number(
1587
- value=0,
1588
- label="Seed",
1589
- precision=0
1590
- )
1591
-
1592
- ttm_generate_btn = gr.Button(
1593
- "🎬 Generate TTM Video",
1594
- variant="primary",
1595
- size="lg",
1596
- interactive=TTM_AVAILABLE
1597
- )
1598
-
1599
- with gr.Column(scale=1):
1600
- gr.Markdown("### 📤 TTM Output")
1601
- ttm_output_video = gr.Video(
1602
- label="TTM Generated Video",
1603
- height=400
1604
- )
1605
- ttm_status_text = gr.Markdown(
1606
- "Upload a video in Step 1 first, then run TTM here.")
1607
-
1608
- # TTM Input preview
1609
- with gr.Accordion("📁 TTM Input Files (from Step 1)", open=False):
1610
- with gr.Row():
1611
- ttm_preview_first_frame = gr.Image(
1612
- label="First Frame",
1613
- height=150
1614
- )
1615
- ttm_preview_motion = gr.Video(
1616
- label="Motion Signal",
1617
- height=150
1618
- )
1619
- ttm_preview_mask = gr.Video(
1620
- label="Mask",
1621
- height=150
1622
- )
1623
-
1624
- # Helper function to update states and preview
1625
- def process_and_update_states(video_path, camera_movement, generate_ttm_flag, progress=gr.Progress()):
1626
- result = process_video(video_path, camera_movement,
1627
- generate_ttm_flag, progress)
1628
- output_vid, motion_sig, mask_vid, first_frame, status = result
1629
-
1630
- # Return all outputs including state updates and previews
1631
- return (
1632
- output_vid, # output_video
1633
- motion_sig, # motion_signal_output
1634
- mask_vid, # mask_output
1635
- first_frame, # first_frame_output
1636
- status, # status_text
1637
- first_frame, # first_frame_state
1638
- motion_sig, # motion_signal_state
1639
- mask_vid, # mask_state
1640
- first_frame, # ttm_preview_first_frame
1641
- motion_sig, # ttm_preview_motion
1642
- mask_vid, # ttm_preview_mask
1643
- )
1644
 
1645
  # Event handlers
1646
  generate_btn.click(
1647
- fn=process_and_update_states,
1648
  inputs=[video_input, camera_movement, generate_ttm],
1649
- outputs=[
1650
- output_video, motion_signal_output, mask_output, first_frame_output, status_text,
1651
- first_frame_state, motion_signal_state, mask_state,
1652
- ttm_preview_first_frame, ttm_preview_motion, ttm_preview_mask
1653
- ]
1654
- )
1655
-
1656
- # TTM generation event
1657
- ttm_generate_btn.click(
1658
- fn=run_ttm_generation,
1659
- inputs=[
1660
- first_frame_state,
1661
- motion_signal_state,
1662
- mask_state,
1663
- ttm_prompt,
1664
- ttm_negative_prompt,
1665
- ttm_model_choice,
1666
- ttm_tweak_index,
1667
- ttm_tstrong_index,
1668
- ttm_num_frames,
1669
- ttm_num_steps,
1670
- ttm_guidance_scale,
1671
- ttm_seed
1672
- ],
1673
- outputs=[ttm_output_video, ttm_status_text]
1674
  )
1675
 
1676
  # Examples
1677
  gr.Markdown("### 📁 Examples")
1678
  if os.path.exists("./examples"):
1679
- example_videos = [f for f in os.listdir(
1680
- "./examples") if f.endswith(".mp4")][:4]
1681
  if example_videos:
1682
  gr.Examples(
1683
- examples=[[f"./examples/{v}", "move_forward", True]
1684
- for v in example_videos],
1685
  inputs=[video_input, camera_movement, generate_ttm],
1686
- outputs=[
1687
- output_video, motion_signal_output, mask_output, first_frame_output, status_text,
1688
- first_frame_state, motion_signal_state, mask_state,
1689
- ttm_preview_first_frame, ttm_preview_motion, ttm_preview_mask
1690
- ],
1691
- fn=process_and_update_states,
1692
  cache_examples=False
1693
  )
1694
 
1695
  # Launch
1696
- logger.info("✅ Gradio interface created successfully!")
1697
- logger.info("=" * 50)
1698
- logger.info("Application ready to launch")
1699
- logger.info("=" * 50)
1700
- sys.stdout.flush()
1701
-
1702
  if __name__ == "__main__":
1703
- logger.info("Starting Gradio server...")
1704
- sys.stdout.flush()
1705
  demo.launch(share=False)
 
 
1
  import gradio as gr
2
  import os
3
  import numpy as np
 
7
  from pathlib import Path
8
  from einops import rearrange
9
  from typing import Union
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  try:
11
  import spaces
 
12
  except ImportError:
13
+ def spaces(func):
14
+ return func
 
 
 
 
 
 
 
 
 
15
  import torch
 
 
 
 
16
  import torchvision.transforms as T
17
+ import logging
18
  from concurrent.futures import ThreadPoolExecutor
19
  import atexit
20
  import uuid
 
 
 
21
  import decord
 
 
22
 
23
+ from models.SpaTrackV2.models.vggt4track.models.vggt_moe import VGGT4Track
24
+ from models.SpaTrackV2.models.vggt4track.utils.load_fn import preprocess_image
25
+ from models.SpaTrackV2.models.predictor import Predictor
26
+ from models.SpaTrackV2.models.utils import get_points_on_a_grid
27
 
28
+ # Configure logging
29
+ logging.basicConfig(level=logging.INFO)
30
+ logger = logging.getLogger(__name__)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  # Constants
33
  MAX_FRAMES = 80
 
46
  "move_down"
47
  ]
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  # Thread pool for delayed deletion
50
  thread_pool_executor = ThreadPoolExecutor(max_workers=2)
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  def delete_later(path: Union[str, os.PathLike], delay: int = 600):
53
  """Delete file or directory after specified delay"""
54
  def _delete():
 
67
  thread_pool_executor.submit(_wait_and_delete)
68
  atexit.register(_delete)
69
 
 
70
  def create_user_temp_dir():
71
  """Create a unique temporary directory for each user session"""
72
  session_id = str(uuid.uuid4())[:8]
 
75
  delete_later(temp_dir, delay=600)
76
  return temp_dir
77
 
78
+ # Global model initialization
79
+ print("🚀 Initializing models...")
80
+ vggt4track_model = VGGT4Track.from_pretrained("Yuxihenry/SpatialTrackerV2_Front")
81
+ vggt4track_model.eval()
82
+ vggt4track_model = vggt4track_model.to("cuda")
83
 
84
+ tracker_model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
85
+ tracker_model.eval()
86
+ print("Models loaded successfully!")
 
87
 
 
88
  gr.set_static_paths(paths=[Path.cwd().absolute()/"_viz"])
 
 
89
 
90
 
91
  def generate_camera_trajectory(num_frames: int, movement_type: str,
 
111
  if movement_type == "static":
112
  pass # Keep identity
113
  elif movement_type == "move_forward":
114
+ ext[2, 3] = -speed * t # Move along -Z (forward in OpenGL convention)
 
115
  elif movement_type == "move_backward":
116
  ext[2, 3] = speed * t # Move along +Z
117
  elif movement_type == "move_left":
 
168
  base_dir = os.path.dirname(output_path)
169
  motion_signal_path = os.path.join(base_dir, "motion_signal.mp4")
170
  mask_path = os.path.join(base_dir, "mask.mp4")
171
+ out_motion_signal = cv2.VideoWriter(motion_signal_path, fourcc, fps, (W, H))
 
172
  out_mask = cv2.VideoWriter(mask_path, fourcc, fps, (W, H))
173
 
174
  # Create meshgrid for pixel coordinates
 
248
  if hole_mask.sum() == 0:
249
  break
250
  dilated = cv2.dilate(motion_signal_frame, kernel, iterations=1)
251
+ motion_signal_frame = np.where(hole_mask[:, :, None] > 0, dilated, motion_signal_frame)
252
+ hole_mask = (motion_signal_frame.sum(axis=-1) == 0).astype(np.uint8)
 
 
253
 
254
  # Write TTM outputs if enabled
255
  if generate_ttm_inputs:
256
  # Motion signal: warped frame with NN inpainting
257
+ motion_signal_bgr = cv2.cvtColor(motion_signal_frame, cv2.COLOR_RGB2BGR)
 
258
  out_motion_signal.write(motion_signal_bgr)
259
 
260
  # Mask: binary mask of valid (projected) pixels - white where valid, black where holes
261
+ mask_frame = np.stack([valid_mask, valid_mask, valid_mask], axis=-1)
 
262
  out_mask.write(mask_frame)
263
 
264
  # For the rendered output, use the same inpainted result
 
278
  }
279
 
280
 
281
+ @spaces.GPU
282
  def run_spatial_tracker(video_tensor: torch.Tensor):
283
  """
284
  GPU-intensive spatial tracking function.
 
289
  Returns:
290
  Dictionary containing tracking results
291
  """
 
 
 
 
 
 
 
 
 
 
 
292
  # Run VGGT to get depth and camera poses
293
  video_input = preprocess_image(video_tensor)[None].cuda()
294
 
 
 
 
295
  with torch.no_grad():
296
  with torch.cuda.amp.autocast(dtype=torch.bfloat16):
297
  predictions = vggt4track_model(video_input / 255)
 
300
  depth_map = predictions["points_map"][..., 2]
301
  depth_conf = predictions["unc_metric"]
302
 
 
 
 
303
  depth_tensor = depth_map.squeeze().cpu().numpy()
304
  extrs = extrinsic.squeeze().cpu().numpy()
305
  intrs = intrinsic.squeeze().cpu().numpy()
 
307
  unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5
308
 
309
  # Setup tracker
 
 
 
310
  tracker_model.spatrack.track_num = 512
311
  tracker_model.to("cuda")
312
 
313
  # Get grid points for tracking
314
  frame_H, frame_W = video_tensor_gpu.shape[2:]
315
  grid_pts = get_points_on_a_grid(30, (frame_H, frame_W), device="cpu")
316
+ query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[0].numpy()
 
 
 
 
317
 
318
  # Run tracker
319
  with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
 
341
  conf_depth = T.Resize((new_h, new_w))(conf_depth)
342
  intrs_out[:, :2, :] = intrs_out[:, :2, :] * scale
343
 
 
 
 
344
  # Move results to CPU and return
345
+ return {
346
  'video_out': video_out.cpu(),
347
  'point_map': point_map.cpu(),
348
  'conf_depth': conf_depth.cpu(),
 
350
  'c2w_traj': c2w_traj.cpu(),
351
  }
352
 
 
 
 
 
 
353
 
354
  def process_video(video_path: str, camera_movement: str, generate_ttm: bool = True, progress=gr.Progress()):
355
  """Main processing function
 
405
  c2w_traj = tracking_results['c2w_traj']
406
 
407
  # Get RGB frames and depth
408
+ rgb_frames = rearrange(video_out.numpy(), "T C H W -> T H W C").astype(np.uint8)
 
409
  depth_frames = point_map[:, 2].numpy()
410
  depth_conf_np = conf_depth.numpy()
411
 
 
416
  intrs_np = intrs_out.numpy()
417
  extrs_np = torch.inverse(c2w_traj).numpy() # world-to-camera
418
 
419
+ progress(0.7, desc=f"Generating {camera_movement} camera trajectory...")
 
420
 
421
  # Calculate scene scale from depth
422
  valid_depth = depth_frames[depth_frames > 0]
 
471
  return None, None, None, None, f"❌ Error: {str(e)}"
472
 
473
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
474
  # Create Gradio interface
475
+ print("🎨 Creating Gradio interface...")
 
476
 
477
  with gr.Blocks(
478
  theme=gr.themes.Soft(),
479
  title="🎬 Video to Point Cloud Renderer",
480
  css="""
481
  .gradio-container {
482
+ max-width: 1200px !important;
483
  margin: auto !important;
484
  }
485
  """
486
  ) as demo:
487
  gr.Markdown("""
488
+ # 🎬 Video to Point Cloud Renderer (TTM Compatible)
489
 
490
+ Upload a video to generate a 3D point cloud and render it from a new camera perspective.
491
+ Generates outputs compatible with **Time-to-Move (TTM)** motion-controlled video generation.
492
 
493
+ **How it works:**
494
+ 1. Upload a video
495
+ 2. Select a camera movement type
496
+ 3. Click "Generate" to create the rendered video and TTM inputs
497
 
498
+ **TTM Inputs:**
499
+ - `first_frame.png`: The first frame of the original video
500
+ - `motion_signal.mp4`: Warped video with nearest-neighbor inpainting
501
+ - `mask.mp4`: Binary mask showing valid projected pixels (white) vs holes (black)
502
  """)
503
 
504
+ with gr.Row():
505
+ with gr.Column(scale=1):
506
+ gr.Markdown("### 📥 Input")
507
+ video_input = gr.Video(
508
+ label="Upload Video",
509
+ format="mp4",
510
+ height=300
511
+ )
512
+
513
+ camera_movement = gr.Dropdown(
514
+ choices=CAMERA_MOVEMENTS,
515
+ value="static",
516
+ label="🎥 Camera Movement",
517
+ info="Select how the camera should move in the rendered video"
518
+ )
519
+
520
+ generate_ttm = gr.Checkbox(
521
+ label="🎯 Generate TTM Inputs",
522
+ value=True,
523
+ info="Generate motion_signal.mp4 and mask.mp4 for Time-to-Move"
524
+ )
525
+
526
+ generate_btn = gr.Button("🚀 Generate", variant="primary", size="lg")
527
+
528
+ with gr.Column(scale=1):
529
+ gr.Markdown("### 📤 Rendered Output")
530
+ output_video = gr.Video(
531
+ label="Rendered Video",
532
+ height=250
533
+ )
534
+ first_frame_output = gr.Image(
535
+ label="First Frame (first_frame.png)",
536
+ height=150
537
+ )
538
+
539
+ with gr.Row():
540
+ with gr.Column(scale=1):
541
+ gr.Markdown("### 🎯 TTM: Motion Signal")
542
+ motion_signal_output = gr.Video(
543
+ label="Motion Signal Video (motion_signal.mp4)",
544
+ height=250
545
+ )
546
+ with gr.Column(scale=1):
547
+ gr.Markdown("### 🎭 TTM: Mask")
548
+ mask_output = gr.Video(
549
+ label="Mask Video (mask.mp4)",
550
+ height=250
551
+ )
552
+
553
+ status_text = gr.Markdown("Ready to process...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
554
 
555
  # Event handlers
556
  generate_btn.click(
557
+ fn=process_video,
558
  inputs=[video_input, camera_movement, generate_ttm],
559
+ outputs=[output_video, motion_signal_output, mask_output, first_frame_output, status_text]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
560
  )
561
 
562
  # Examples
563
  gr.Markdown("### 📁 Examples")
564
  if os.path.exists("./examples"):
565
+ example_videos = [f for f in os.listdir("./examples") if f.endswith(".mp4")][:4]
 
566
  if example_videos:
567
  gr.Examples(
568
+ examples=[[f"./examples/{v}", "move_forward", True] for v in example_videos],
 
569
  inputs=[video_input, camera_movement, generate_ttm],
570
+ outputs=[output_video, motion_signal_output, mask_output, first_frame_output, status_text],
571
+ fn=process_video,
 
 
 
 
572
  cache_examples=False
573
  )
574
 
575
  # Launch
 
 
 
 
 
 
576
  if __name__ == "__main__":
 
 
577
  demo.launch(share=False)
src/config.py DELETED
@@ -1,44 +0,0 @@
1
- import torch
2
-
3
- MAX_FRAMES = 80
4
- OUTPUT_FPS = 24
5
- RENDER_WIDTH = 512
6
- RENDER_HEIGHT = 384
7
-
8
- CAMERA_MOVEMENTS = [
9
- "static",
10
- "move_forward",
11
- "move_backward",
12
- "move_left",
13
- "move_right",
14
- "move_up",
15
- "move_down"
16
- ]
17
-
18
- TTM_COG_MODEL_ID = "THUDM/CogVideoX-5b-I2V"
19
- TTM_WAN_MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
20
- TTM_DTYPE = torch.bfloat16
21
- TTM_DEFAULT_NUM_FRAMES = 49
22
- TTM_DEFAULT_NUM_INFERENCE_STEPS = 50
23
-
24
- TTM_COG_AVAILABLE = False
25
- TTM_WAN_AVAILABLE = False
26
- try:
27
- from diffusers import CogVideoXImageToVideoPipeline
28
- TTM_COG_AVAILABLE = True
29
- except ImportError:
30
- pass
31
-
32
- try:
33
- from diffusers import AutoencoderKLWan, WanTransformer3DModel
34
- TTM_WAN_AVAILABLE = True
35
- except ImportError:
36
- pass
37
-
38
- TTM_AVAILABLE = TTM_COG_AVAILABLE or TTM_WAN_AVAILABLE
39
-
40
- TTM_MODELS = []
41
- if TTM_COG_AVAILABLE:
42
- TTM_MODELS.append("CogVideoX-5B")
43
- if TTM_WAN_AVAILABLE:
44
- TTM_MODELS.append("Wan2.2-14B (Recommended)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/model_manager.py DELETED
@@ -1,62 +0,0 @@
1
- from models.SpaTrackV2.models.predictor import Predictor
2
- from models.SpaTrackV2.models.vggt4track.models.vggt_moe import VGGT4Track
3
- import logging
4
- from .config import (
5
- TTM_COG_AVAILABLE, TTM_WAN_AVAILABLE,
6
- TTM_COG_MODEL_ID, TTM_WAN_MODEL_ID, TTM_DTYPE
7
- )
8
-
9
logger = logging.getLogger(__name__)

# Module-level model singletons shared across the app. The spatial pair is
# populated eagerly at import time (init_spatial_models is called at the
# bottom of this module); the two TTM pipelines are loaded lazily on first use.
vggt4track_model = None
tracker_model = None
ttm_cog_pipeline = None
ttm_wan_pipeline = None
15
-
16
-
17
def init_spatial_models():
    """Load the VGGT front-end and SpatialTrackerV2 predictor into module globals.

    Populates ``vggt4track_model`` (moved to CUDA) and ``tracker_model``
    (not moved to CUDA here — the tracking code moves it at call time).
    """
    global vggt4track_model, tracker_model
    print("🚀 Initializing models...")
    vggt4track_model = VGGT4Track.from_pretrained(
        "Yuxihenry/SpatialTrackerV2_Front")
    vggt4track_model.eval()
    vggt4track_model = vggt4track_model.to("cuda")

    tracker_model = Predictor.from_pretrained(
        "Yuxihenry/SpatialTrackerV2-Offline")
    tracker_model.eval()
    print("✅ Spatial Models loaded successfully!")
29
-
30
-
31
def get_ttm_cog_pipeline():
    """Lazily load and cache the CogVideoX I2V pipeline.

    Returns the cached pipeline, or None when the backend is unavailable.
    The pipeline is left on CPU here; callers move it to CUDA. VAE tiling
    and slicing are enabled to lower peak memory during encode/decode.
    """
    global ttm_cog_pipeline
    if ttm_cog_pipeline is None and TTM_COG_AVAILABLE:
        # Import deferred so this module loads even without the backend.
        from diffusers import CogVideoXImageToVideoPipeline
        logger.info("Loading TTM CogVideoX pipeline...")
        ttm_cog_pipeline = CogVideoXImageToVideoPipeline.from_pretrained(
            TTM_COG_MODEL_ID,
            torch_dtype=TTM_DTYPE,
            low_cpu_mem_usage=True,
        )
        ttm_cog_pipeline.vae.enable_tiling()
        ttm_cog_pipeline.vae.enable_slicing()
        logger.info("TTM CogVideoX pipeline loaded successfully!")
    return ttm_cog_pipeline
45
-
46
-
47
def get_ttm_wan_pipeline():
    """Lazily load and cache the Wan 2.2 I2V pipeline.

    Returns the cached pipeline, or None when the backend is unavailable.
    Left on CPU here; callers move it to CUDA. VAE tiling and slicing are
    enabled to lower peak memory during encode/decode.
    """
    global ttm_wan_pipeline
    if ttm_wan_pipeline is None and TTM_WAN_AVAILABLE:
        # Import deferred so this module loads even without the backend.
        from diffusers import WanImageToVideoPipeline
        logger.info("Loading TTM Wan 2.2 pipeline...")
        ttm_wan_pipeline = WanImageToVideoPipeline.from_pretrained(
            TTM_WAN_MODEL_ID,
            torch_dtype=TTM_DTYPE,
        )
        ttm_wan_pipeline.vae.enable_tiling()
        ttm_wan_pipeline.vae.enable_slicing()
        logger.info("TTM Wan 2.2 pipeline loaded successfully!")
    return ttm_wan_pipeline
60
-
61
-
62
- init_spatial_models()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/spatial_pipeline.py DELETED
@@ -1,277 +0,0 @@
1
- import os
2
- import cv2
3
- import numpy as np
4
- import torch
5
- import decord
6
- import gradio as gr
7
- import torchvision.transforms as T
8
- from einops import rearrange
9
-
10
- from .config import MAX_FRAMES, OUTPUT_FPS
11
- from .utils import logger, create_user_temp_dir
12
- from . import model_manager
13
- from models.SpaTrackV2.models.vggt4track.utils.load_fn import preprocess_image
14
- from models.SpaTrackV2.models.utils import get_points_on_a_grid
15
-
16
# Fall back to a no-op GPU decorator when not running on HF Spaces.
try:
    import spaces
except ImportError:
    class spaces:
        @staticmethod
        def GPU(func=None, duration=None):
            # Mirrors spaces.GPU: works bare (@GPU) or with args (@GPU(duration=...)).
            def decorator(f):
                return f
            return decorator if func is None else func
25
-
26
-
27
- def generate_camera_trajectory(num_frames: int, movement_type: str,
28
- base_intrinsics: np.ndarray,
29
- scene_scale: float = 1.0) -> np.ndarray:
30
- speed = scene_scale * 0.02
31
- extrinsics = np.zeros((num_frames, 4, 4), dtype=np.float32)
32
-
33
- for t in range(num_frames):
34
- ext = np.eye(4, dtype=np.float32)
35
- if movement_type == "static":
36
- pass
37
- elif movement_type == "move_forward":
38
- ext[2, 3] = -speed * t
39
- elif movement_type == "move_backward":
40
- ext[2, 3] = speed * t
41
- elif movement_type == "move_left":
42
- ext[0, 3] = -speed * t
43
- elif movement_type == "move_right":
44
- ext[0, 3] = speed * t
45
- elif movement_type == "move_up":
46
- ext[1, 3] = -speed * t
47
- elif movement_type == "move_down":
48
- ext[1, 3] = speed * t
49
- extrinsics[t] = ext
50
- return extrinsics
51
-
52
-
53
def render_from_pointcloud(rgb_frames, depth_frames, intrinsics, original_extrinsics,
                           new_extrinsics, output_path, fps=24, generate_ttm_inputs=False):
    """Re-render each frame from a displaced camera via point-cloud splatting.

    Every frame is unprojected to world space with its depth map and original
    camera, re-projected through the new camera (frame-0 pose composed with
    ``new_extrinsics[t]``), z-buffered, and hole-filled by iterative dilation.

    Args:
        rgb_frames: (T, H, W, 3) uint8 RGB frames.
        depth_frames: (T, H, W) depth maps; 0 marks invalid pixels.
        intrinsics: (T, 3, 3) per-frame camera intrinsics.
        original_extrinsics: (T, 4, 4) world-to-camera matrices of the source.
        new_extrinsics: (T, 4, 4) relative offsets applied to the frame-0 pose.
        output_path: destination mp4 path for the rendered video.
        fps: output frame rate.
        generate_ttm_inputs: additionally write motion_signal.mp4 (inpainted
            render) and mask.mp4 (255 where a splat landed) beside the output.

    Returns:
        dict with keys 'rendered', 'motion_signal', 'mask' (file paths; the
        last two are None unless generate_ttm_inputs is True).
    """
    T, H, W, _ = rgb_frames.shape
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (W, H))

    motion_signal_path = None
    mask_path = None
    out_motion_signal = None
    out_mask = None

    if generate_ttm_inputs:
        base_dir = os.path.dirname(output_path)
        motion_signal_path = os.path.join(base_dir, "motion_signal.mp4")
        mask_path = os.path.join(base_dir, "mask.mp4")
        out_motion_signal = cv2.VideoWriter(
            motion_signal_path, fourcc, fps, (W, H))
        out_mask = cv2.VideoWriter(mask_path, fourcc, fps, (W, H))

    # Pixel grid in homogeneous coordinates, reused for every frame.
    u, v = np.meshgrid(np.arange(W), np.arange(H))
    ones = np.ones_like(u)

    for t in range(T):
        rgb = rgb_frames[t]
        depth = depth_frames[t]
        K = intrinsics[t]
        orig_c2w = np.linalg.inv(original_extrinsics[t])

        # The new trajectory is anchored at the first frame's pose.
        if t == 0:
            base_c2w = orig_c2w.copy()
        new_c2w = base_c2w @ new_extrinsics[t]
        new_w2c = np.linalg.inv(new_c2w)
        K_inv = np.linalg.inv(K)

        # Unproject pixels -> camera rays -> 3D points -> world space,
        # then reproject into the displaced camera.
        pixels = np.stack([u, v, ones], axis=-1).reshape(-1, 3)
        rays_cam = (K_inv @ pixels.T).T
        points_cam = rays_cam * depth.reshape(-1, 1)
        points_world = (orig_c2w[:3, :3] @ points_cam.T).T + orig_c2w[:3, 3]
        points_new_cam = (new_w2c[:3, :3] @ points_world.T).T + new_w2c[:3, 3]
        points_proj = (K @ points_new_cam.T).T

        # Perspective divide; clamp z to avoid division by ~0.
        z = np.clip(points_proj[:, 2:3], 1e-6, None)
        uv_new = points_proj[:, :2] / z

        rendered = np.zeros((H, W, 3), dtype=np.uint8)
        z_buffer = np.full((H, W), np.inf, dtype=np.float32)
        colors = rgb.reshape(-1, 3)
        depths_new = points_new_cam[:, 2]

        # Rasterization loop (simplified): nearest-pixel splat with z-test.
        for i in range(len(uv_new)):
            uu, vv = int(round(uv_new[i, 0])), int(round(uv_new[i, 1]))
            if 0 <= uu < W and 0 <= vv < H and depths_new[i] > 0:
                if depths_new[i] < z_buffer[vv, uu]:
                    z_buffer[vv, uu] = depths_new[i]
                    rendered[vv, uu] = colors[i]

        # Inpainting for TTM: grow valid pixels into holes until none remain.
        # NOTE: a pure-black source pixel also reads as a "hole" here.
        valid_mask = (rendered.sum(axis=-1) > 0).astype(np.uint8) * 255
        motion_signal_frame = rendered.copy()
        hole_mask = (motion_signal_frame.sum(axis=-1) == 0).astype(np.uint8)

        if hole_mask.sum() > 0:
            kernel = np.ones((3, 3), np.uint8)
            # max(H, W) dilation rounds is always enough to close any hole.
            for _ in range(max(H, W)):
                if hole_mask.sum() == 0:
                    break
                dilated = cv2.dilate(motion_signal_frame, kernel, iterations=1)
                motion_signal_frame = np.where(
                    hole_mask[:, :, None] > 0, dilated, motion_signal_frame)
                hole_mask = (motion_signal_frame.sum(
                    axis=-1) == 0).astype(np.uint8)

        if generate_ttm_inputs:
            out_motion_signal.write(cv2.cvtColor(
                motion_signal_frame, cv2.COLOR_RGB2BGR))
            out_mask.write(np.stack([valid_mask]*3, axis=-1))

        # The main output is the inpainted frame as well.
        out.write(cv2.cvtColor(motion_signal_frame, cv2.COLOR_RGB2BGR))

    out.release()
    if generate_ttm_inputs:
        out_motion_signal.release()
        out_mask.release()

    return {'rendered': output_path, 'motion_signal': motion_signal_path, 'mask': mask_path}
139
-
140
-
141
@spaces.GPU
def run_spatial_tracker(video_tensor: torch.Tensor):
    """Run VGGT + SpatialTrackerV2 on a clip and return per-frame 3D geometry.

    Args:
        video_tensor: (T, C, H, W) frames with values in [0, 255].

    Returns:
        dict of CPU tensors:
            'video_out':  RGB frames as used by the tracker (possibly resized),
            'point_map':  per-pixel 3D point map,
            'conf_depth': depth confidence map,
            'intrs_out':  per-frame intrinsics (rescaled if frames were resized),
            'c2w_traj':   estimated camera-to-world trajectory.
    """
    video_input = preprocess_image(video_tensor)[None].cuda()

    # Use global models from model_manager.
    # Stage 1: monocular geometry (poses, intrinsics, depth) from VGGT.
    with torch.no_grad():
        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
            predictions = model_manager.vggt4track_model(video_input / 255)
            extrinsic = predictions["poses_pred"]
            intrinsic = predictions["intrs"]
            depth_map = predictions["points_map"][..., 2]
            depth_conf = predictions["unc_metric"]

    depth_tensor = depth_map.squeeze().cpu().numpy()
    extrs = extrinsic.squeeze().cpu().numpy()
    intrs = intrinsic.squeeze().cpu().numpy()
    video_tensor_gpu = video_input.squeeze()
    # Pixels below 0.5 confidence are treated as unreliable depth.
    unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5

    model_manager.tracker_model.spatrack.track_num = 512
    model_manager.tracker_model.to("cuda")

    # Seed the tracker with a regular 30x30 grid of query points on frame 0
    # (the prepended zero column is the query timestamp).
    frame_H, frame_W = video_tensor_gpu.shape[2:]
    grid_pts = get_points_on_a_grid(30, (frame_H, frame_W), device="cpu")
    query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[
        0].numpy()

    # Stage 2: dense 3D tracking conditioned on the VGGT geometry.
    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        results = model_manager.tracker_model.forward(
            video_tensor_gpu, depth=depth_tensor,
            intrs=intrs, extrs=extrs,
            queries=query_xyt,
            fps=1, full_point=False, iters_track=4,
            query_no_BA=True, fixed_cam=False, stage=1,
            unc_metric=unc_metric,
            support_frame=len(video_tensor_gpu)-1, replace_ratio=0.2
        )

    # Unpack tuple from tracker.
    c2w_traj, intrs_out, point_map, conf_depth, track3d_pred, track2d_pred, vis_pred, conf_pred, video_out = results

    # Resize logic (abbreviated): cap the longer side at 384 px and scale the
    # intrinsics' focal/principal rows to stay consistent with the new size.
    max_size = 384
    h, w = video_out.shape[2:]
    scale = min(max_size / h, max_size / w)
    if scale < 1:
        new_h, new_w = int(h * scale), int(w * scale)
        video_out = T.Resize((new_h, new_w))(video_out)
        point_map = T.Resize((new_h, new_w))(point_map)
        conf_depth = T.Resize((new_h, new_w))(conf_depth)
        intrs_out[:, :2, :] = intrs_out[:, :2, :] * scale

    return {
        'video_out': video_out.cpu(),
        'point_map': point_map.cpu(),
        'conf_depth': conf_depth.cpu(),
        'intrs_out': intrs_out.cpu(),
        'c2w_traj': c2w_traj.cpu(),
    }
200
-
201
-
202
def process_video(video_path: str, camera_movement: str, generate_ttm: bool = True, progress=gr.Progress()):
    """End-to-end pipeline: 3D-track a video, then re-render it from a new camera.

    Args:
        video_path: path to the uploaded video (None yields an error message).
        camera_movement: one of the CAMERA_MOVEMENTS presets.
        generate_ttm: also emit the motion-signal/mask videos and first frame
            consumed by the downstream TTM generator.
        progress: Gradio progress reporter.

    Returns:
        (rendered_video, motion_signal, mask, first_frame, status_message);
        any of the first four may be None.
    """
    if video_path is None:
        return None, None, None, None, "❌ Please upload a video first"

    progress(0, desc="Initializing...")
    temp_dir = create_user_temp_dir()
    out_dir = os.path.join(temp_dir, "results")
    os.makedirs(out_dir, exist_ok=True)

    try:
        progress(0.1, desc="Loading video...")
        video_reader = decord.VideoReader(video_path)
        # (T, H, W, C) uint8 -> (T, C, H, W) float.
        video_tensor = torch.from_numpy(
            video_reader.get_batch(range(len(video_reader))).asnumpy()
        ).permute(0, 3, 1, 2).float()

        # Subsample evenly so at most MAX_FRAMES frames are processed.
        fps_skip = max(1, len(video_tensor) // MAX_FRAMES)
        video_tensor = video_tensor[::fps_skip][:MAX_FRAMES]

        # Downscale so the short side is at most 336 px, keeping dims even.
        h, w = video_tensor.shape[2:]
        scale = 336 / min(h, w)
        if scale < 1:
            new_h, new_w = int(h * scale) // 2 * 2, int(w * scale) // 2 * 2
            video_tensor = T.Resize((new_h, new_w))(video_tensor)

        progress(0.4, desc="Running 3D tracking...")
        tracking_results = run_spatial_tracker(video_tensor)

        video_out = tracking_results['video_out']
        point_map = tracking_results['point_map']
        conf_depth = tracking_results['conf_depth']
        intrs_out = tracking_results['intrs_out']
        c2w_traj = tracking_results['c2w_traj']

        rgb_frames = rearrange(
            video_out.numpy(), "T C H W -> T H W C").astype(np.uint8)
        # Z channel of the point map is the depth; zero out low-confidence pixels.
        depth_frames = point_map[:, 2].numpy()
        depth_frames[conf_depth.numpy() < 0.5] = 0

        intrs_np = intrs_out.numpy()
        # Invert camera-to-world poses to get world-to-camera extrinsics.
        extrs_np = torch.inverse(c2w_traj).numpy()

        progress(
            0.7, desc=f"Generating {camera_movement} camera trajectory...")
        # Median valid depth scales the camera step to the scene size.
        valid_depth = depth_frames[depth_frames > 0]
        scene_scale = np.median(valid_depth) if len(valid_depth) > 0 else 1.0

        new_extrinsics = generate_camera_trajectory(
            len(rgb_frames), camera_movement, intrs_np, scene_scale
        )

        progress(0.8, desc="Rendering video...")
        output_video_path = os.path.join(out_dir, "rendered_video.mp4")
        render_results = render_from_pointcloud(
            rgb_frames, depth_frames, intrs_np, extrs_np,
            new_extrinsics, output_video_path, fps=OUTPUT_FPS,
            generate_ttm_inputs=generate_ttm
        )

        first_frame_path = None
        if generate_ttm:
            # Save frame 0 as the conditioning image for the TTM generator.
            first_frame_path = os.path.join(out_dir, "first_frame.png")
            cv2.imwrite(first_frame_path, cv2.cvtColor(
                rgb_frames[0], cv2.COLOR_RGB2BGR))

        status_msg = f"✅ Video rendered successfully with '{camera_movement}'!"
        if generate_ttm:
            status_msg += "\n\n📁 **TTM outputs generated**"

        return render_results['rendered'], render_results.get('motion_signal'), render_results.get('mask'), first_frame_path, status_msg

    except Exception as e:
        logger.error(f"Error processing video: {e}")
        import traceback
        traceback.print_exc()
        return None, None, None, None, f"❌ Error: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/ttm_pipeline.py DELETED
@@ -1,303 +0,0 @@
1
- import os
2
- import torch
3
- import torch.nn.functional as F
4
- import gradio as gr
5
- from diffusers.utils import export_to_video, load_image
6
- from diffusers.utils.torch_utils import randn_tensor
7
- from diffusers.pipelines.wan.pipeline_wan_i2v import retrieve_latents
8
-
9
-
10
# Fall back to a no-op GPU decorator when not running on HF Spaces.
try:
    import spaces
except ImportError:
    class spaces:
        @staticmethod
        def GPU(func=None, duration=None):
            # Mirrors spaces.GPU: works bare (@GPU) or with args (@GPU(duration=...)).
            def decorator(f):
                return f
            return decorator if func is None else func
19
-
20
-
21
- from .config import TTM_COG_AVAILABLE, TTM_WAN_AVAILABLE
22
- from .utils import create_user_temp_dir, load_video_to_tensor
23
- from . import model_manager
24
-
25
- # --- Helper Classes ---
26
-
27
-
28
class CogVideoXTTMHelper:
    """VAE bookkeeping for TTM with CogVideoX: encodes reference frames and
    downsamples pixel-space masks onto the latent grid."""

    def __init__(self, pipeline):
        self.pipeline = pipeline
        self.vae = pipeline.vae
        # Spatial downsample factor = 2^(VAE levels - 1); temporal factor and
        # latent scaling come straight from the VAE config.
        self.vae_scale_factor_spatial = 2 ** (
            len(self.vae.config.block_out_channels) - 1)
        self.vae_scale_factor_temporal = self.vae.config.temporal_compression_ratio
        self.vae_scaling_factor_image = self.vae.config.scaling_factor
        self.video_processor = pipeline.video_processor

    @torch.no_grad()
    def encode_frames(self, frames: torch.Tensor) -> torch.Tensor:
        """Encode frames (B, C, T, H, W as passed by the caller) to scaled
        latents, returned with channel/time axes swapped to (B, T, C, H, W)."""
        latents = self.vae.encode(
            frames)[0].sample() * self.vae_scaling_factor_image
        return latents.permute(0, 2, 1, 3, 4).contiguous()

    def convert_rgb_mask_to_latent_mask(self, mask: torch.Tensor) -> torch.Tensor:
        """Downsample a (T, C, H, W) pixel mask to the latent grid.

        Keeps frame 0 plus every k-th later frame (k = temporal compression),
        then nearest-neighbour shrinks H and W by the spatial factor.
        Returns (1, T_latent, C, H_latent, W_latent).
        """
        k = self.vae_scale_factor_temporal
        # Frame 0 is encoded on its own; later latent frames cover k frames each.
        mask_sampled = torch.cat([mask[0:1], mask[1::k]], dim=0)
        pooled = mask_sampled.permute(1, 0, 2, 3).unsqueeze(0)
        s = self.vae_scale_factor_spatial
        H_l, W_l = pooled.shape[-2] // s, pooled.shape[-1] // s
        pooled = F.interpolate(pooled, size=(
            pooled.shape[2], H_l, W_l), mode="nearest")
        return pooled.permute(0, 2, 1, 3, 4)
53
-
54
-
55
class WanTTMHelper:
    """VAE bookkeeping for TTM with Wan 2.2: maps pixel-space masks onto the
    latent grid using the VAE's compression factors."""

    def __init__(self, pipeline):
        self.pipeline = pipeline
        self.vae = pipeline.vae
        vae_cfg = self.vae.config
        self.vae_scale_factor_temporal = vae_cfg.scale_factor_temporal
        self.vae_scale_factor_spatial = vae_cfg.scale_factor_spatial
        self.video_processor = pipeline.video_processor

    def convert_rgb_mask_to_latent_mask(self, mask: torch.Tensor) -> torch.Tensor:
        """Downsample a (T, C, H, W) pixel mask to latent resolution.

        Temporally keeps frame 0 plus every k-th following frame (matching the
        VAE's temporal compression); spatially shrinks by the VAE's spatial
        factor using nearest-neighbour sampling.
        Returns a (1, T_latent, C, H_latent, W_latent) tensor.
        """
        stride_t = self.vae_scale_factor_temporal
        stride_s = self.vae_scale_factor_spatial
        # Frame 0 stands alone; each later latent frame covers stride_t frames.
        subsampled = torch.cat((mask[:1], mask[1::stride_t]), dim=0)
        # (T', C, H, W) -> (1, C, T', H, W) so interpolate treats T' as depth.
        volume = subsampled.permute(1, 0, 2, 3).unsqueeze(0)
        target_h = volume.shape[-2] // stride_s
        target_w = volume.shape[-1] // stride_s
        volume = F.interpolate(
            volume,
            size=(volume.shape[2], target_h, target_w),
            mode="nearest",
        )
        # Back to (1, T', C, H_latent, W_latent).
        return volume.permute(0, 2, 1, 3, 4)
72
-
73
-
74
def compute_hw_from_area(h, w, max_area, mod_value):
    """Pick an output (height, width) with roughly `max_area` pixels that
    preserves the h/w aspect ratio.

    Both dimensions are rounded down to a multiple of `mod_value` so they
    satisfy the model's VAE/patch divisibility requirements.

    Args:
        h, w: source image height and width in pixels.
        max_area: target pixel budget (height * width).
        mod_value: required divisor for both output dimensions.

    Returns:
        (height, width) as ints, each a multiple of `mod_value`.
    """
    aspect = h / w
    # Bug fix: the original called np.sqrt, but this module never imports
    # numpy, so the call raised NameError at runtime. Plain exponentiation
    # computes the same value with no extra import.
    height = round((max_area * aspect) ** 0.5) // mod_value * mod_value
    width = round((max_area / aspect) ** 0.5) // mod_value * mod_value
    return int(height), int(width)
79
-
80
- # --- Generation Functions ---
81
-
82
-
83
@spaces.GPU(duration=300)
def run_ttm_cog_generation(first_frame_path, motion_signal_path, mask_path, prompt,
                           tweak_index=4, tstrong_index=9, num_frames=49,
                           num_inference_steps=50, guidance_scale=6.0, seed=0, progress=gr.Progress()):
    """TTM generation with CogVideoX-5B image-to-video.

    Denoising starts from the rendered motion-signal video noised to
    ``timesteps[tweak_index]`` instead of pure noise; until ``tstrong_index``
    the masked motion region is re-anchored to the (re-noised) reference
    latents after each scheduler step, so camera motion follows the render
    while the rest of the frame is synthesized by the model.

    Returns (output_video_path, status_message), or (None, error_message).
    """
    if not TTM_COG_AVAILABLE:
        return None, "❌ CogVideoX TTM not available."

    pipe = model_manager.get_ttm_cog_pipeline()
    if not pipe:
        return None, "❌ Failed to load pipeline"
    pipe = pipe.to("cuda")
    ttm_helper = CogVideoXTTMHelper(pipe)

    device = "cuda"
    generator = torch.Generator(device=device).manual_seed(seed)

    # Output resolution comes from the transformer's native latent sample size.
    image = load_image(first_frame_path)
    height = pipe.transformer.config.sample_height * \
        ttm_helper.vae_scale_factor_spatial
    width = pipe.transformer.config.sample_width * \
        ttm_helper.vae_scale_factor_spatial

    do_cfg = guidance_scale > 1.0
    prompt_embeds, neg_embeds = pipe.encode_prompt(
        prompt, "", do_cfg, 1, 226, device)
    if do_cfg:
        # Batch [uncond, cond] for classifier-free guidance.
        prompt_embeds = torch.cat([neg_embeds, prompt_embeds], dim=0)

    pipe.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = pipe.scheduler.timesteps

    # NOTE(review): latent_frames is computed but never used below.
    latent_frames = (
        num_frames - 1) // ttm_helper.vae_scale_factor_temporal + 1
    image_tensor = ttm_helper.video_processor.preprocess(
        image, height=height, width=width).to(device, dtype=prompt_embeds.dtype)
    latent_channels = pipe.transformer.config.in_channels // 2
    latents, image_latents = pipe.prepare_latents(
        image_tensor, 1, latent_channels, num_frames, height, width, prompt_embeds.dtype, device, generator, None)

    # Encode the rendered motion-signal video as reference latents.
    ref_vid = load_video_to_tensor(motion_signal_path).to(device)
    ref_vid = F.interpolate(ref_vid.permute(0, 2, 1, 3, 4).flatten(0, 1), size=(
        height, width), mode="bicubic").view(1, -1, 3, height, width).permute(0, 2, 1, 3, 4)
    ref_vid = ttm_helper.video_processor.normalize(
        ref_vid.to(dtype=pipe.vae.dtype))
    ref_latents = ttm_helper.encode_frames(ref_vid).float().detach()

    # Latent-space mask separating the moving region from the background.
    ref_mask = load_video_to_tensor(mask_path).to(device)
    ref_mask = F.interpolate(ref_mask.permute(0, 2, 1, 3, 4).flatten(0, 1), size=(
        height, width), mode="nearest").view(1, -1, 3, height, width).permute(0, 2, 1, 3, 4)
    motion_mask = ttm_helper.convert_rgb_mask_to_latent_mask(
        ref_mask[0, :, :1].permute(1, 0, 2, 3).contiguous())
    background_mask = 1.0 - motion_mask

    # Start denoising from the reference latents noised to timesteps[tweak_index].
    fixed_noise = randn_tensor(
        ref_latents.shape, generator=generator, device=device, dtype=ref_latents.dtype)
    if tweak_index >= 0:
        latents = pipe.scheduler.add_noise(
            ref_latents, fixed_noise, timesteps[tweak_index].long()).to(dtype=latents.dtype)

    extra_step_kwargs = pipe.prepare_extra_step_kwargs(generator, 0.0)

    for i, t in enumerate(timesteps[tweak_index:]):
        progress(0.4 + 0.5 * (i / len(timesteps)), desc="Denoising...")

        latent_input = torch.cat([latents] * 2) if do_cfg else latents
        latent_input = pipe.scheduler.scale_model_input(latent_input, t)
        # CogVideoX I2V conditions by concatenating the first-frame latents
        # along dim 2.
        latent_input = torch.cat([latent_input, torch.cat(
            [image_latents]*2) if do_cfg else image_latents], dim=2)

        noise_pred = pipe.transformer(hidden_states=latent_input, encoder_hidden_states=prompt_embeds, timestep=t.expand(
            latent_input.shape[0]), return_dict=False)[0].float()

        if do_cfg:
            uncond, text = noise_pred.chunk(2)
            noise_pred = uncond + guidance_scale * (text - uncond)

        latents = pipe.scheduler.step(
            noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

        # TTM anchoring: until tstrong_index, overwrite the motion region with
        # the reference latents noised to the *next* timestep so the rendered
        # camera motion is preserved through early denoising.
        if (i + tweak_index) < tstrong_index:
            next_t = timesteps[i + tweak_index + 1] if i + \
                tweak_index + 1 < len(timesteps) else None
            if next_t is not None:
                noisy_ref = pipe.scheduler.add_noise(
                    ref_latents, fixed_noise, next_t.long()).to(dtype=latents.dtype)
                latents = latents * background_mask + noisy_ref * motion_mask
            else:
                latents = latents * background_mask + ref_latents * motion_mask

    latents = latents.to(prompt_embeds.dtype)

    frames = pipe.decode_latents(latents)
    video = ttm_helper.video_processor.postprocess_video(
        video=frames, output_type="pil")

    out_path = os.path.join(create_user_temp_dir(), "ttm_cog_out.mp4")
    export_to_video(video[0], out_path, fps=8)
    return out_path, "✅ Video Generated"
181
-
182
-
183
@spaces.GPU(duration=300)
def run_ttm_wan_generation(first_frame_path, motion_signal_path, mask_path, prompt, negative_prompt="",
                           tweak_index=3, tstrong_index=7, num_frames=81, num_inference_steps=50,
                           guidance_scale=3.5, seed=0, progress=gr.Progress()):
    """TTM generation with Wan 2.2 image-to-video.

    Same scheme as the CogVideoX variant: denoising starts from the rendered
    motion-signal video noised to ``timesteps[tweak_index]``; until
    ``tstrong_index`` the masked motion region is re-anchored to the
    (re-noised) reference latents after every scheduler step.

    Returns (output_video_path, status_message), or (None, error_message).
    """
    if not TTM_WAN_AVAILABLE:
        return None, "❌ Wan TTM not available."

    pipe = model_manager.get_ttm_wan_pipeline()
    if not pipe:
        return None, "❌ Failed to load pipeline"
    pipe = pipe.to("cuda")
    ttm_helper = WanTTMHelper(pipe)

    device = "cuda"
    generator = torch.Generator(device=device).manual_seed(seed)

    # Fit the first frame to roughly 480x832 pixels, keeping aspect ratio and
    # the VAE/patch divisibility constraint.
    image = load_image(first_frame_path)
    h, w = compute_hw_from_area(image.height, image.width, 480*832,
                                ttm_helper.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1])
    image = image.resize((w, h))

    do_cfg = guidance_scale > 1.0
    prompt_embeds, neg_embeds = pipe.encode_prompt(
        prompt, negative_prompt, do_cfg, 1, 512, device)
    prompt_embeds = prompt_embeds.to(pipe.transformer.dtype)
    if neg_embeds is not None:
        neg_embeds = neg_embeds.to(pipe.transformer.dtype)

    # Image-conditioning embeds only exist for Wan variants with an image encoder.
    image_embeds = pipe.encode_image(image, device).repeat(1, 1, 1).to(
        pipe.transformer.dtype) if pipe.transformer.config.image_dim else None

    pipe.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = pipe.scheduler.timesteps

    # Wan requires num_frames ≡ 1 (mod temporal scale factor); round down.
    if num_frames % ttm_helper.vae_scale_factor_temporal != 1:
        num_frames = num_frames // ttm_helper.vae_scale_factor_temporal * \
            ttm_helper.vae_scale_factor_temporal + 1

    image_tensor = ttm_helper.video_processor.preprocess(
        image, height=h, width=w).to(device, dtype=torch.float32)
    latents, condition = pipe.prepare_latents(
        image_tensor, 1, pipe.vae.config.z_dim, h, w, num_frames, torch.float32, device, generator, None, None)

    # Encode the rendered motion-signal video to reference latents.
    ref_vid = load_video_to_tensor(motion_signal_path).to(device)
    ref_vid = F.interpolate(ref_vid.permute(0, 2, 1, 3, 4).flatten(0, 1), size=(
        h, w), mode="bicubic").view(1, -1, 3, h, w).permute(0, 2, 1, 3, 4)
    ref_vid = ttm_helper.video_processor.normalize(
        ref_vid.to(dtype=pipe.vae.dtype))
    ref_latents = retrieve_latents(
        pipe.vae.encode(ref_vid), sample_mode="argmax")

    # Normalize latents with the VAE's per-channel statistics.
    mean = torch.tensor(pipe.vae.config.latents_mean).view(
        1, -1, 1, 1, 1).to(device, ref_latents.dtype)
    std = 1.0 / torch.tensor(pipe.vae.config.latents_std).view(1, -
                                                               1, 1, 1, 1).to(device, ref_latents.dtype)
    ref_latents = (ref_latents - mean) * std

    # Binarize the mask video and pool it onto the latent grid.
    ref_mask = load_video_to_tensor(mask_path).to(device)
    ref_mask = F.interpolate(ref_mask.permute(0, 2, 1, 3, 4).flatten(0, 1), size=(
        h, w), mode="nearest").view(1, -1, 3, h, w).permute(0, 2, 1, 3, 4)
    mask_tc_hw = ref_mask[0].permute(1, 0, 2, 3).contiguous()[:num_frames]
    motion_mask = ttm_helper.convert_rgb_mask_to_latent_mask(
        (mask_tc_hw > 0.5).float()).permute(0, 2, 1, 3, 4).contiguous()
    background_mask = 1.0 - motion_mask

    # Start denoising from the reference latents noised to timesteps[tweak_index].
    fixed_noise = randn_tensor(
        ref_latents.shape, generator=generator, device=device, dtype=ref_latents.dtype)
    if tweak_index >= 0:
        latents = pipe.scheduler.add_noise(ref_latents, fixed_noise, torch.as_tensor(
            timesteps[tweak_index], device=device).long())

    for i, t in enumerate(timesteps[tweak_index:]):
        progress(0.4 + 0.5 * (i / len(timesteps)), desc=f"Step {i}")

        # Wan conditions by concatenating the image-condition latents channel-wise.
        latent_in = torch.cat([latents, condition], dim=1).to(
            pipe.transformer.dtype)
        ts = t.expand(latents.shape[0])

        noise_pred = pipe.transformer(hidden_states=latent_in, timestep=ts, encoder_hidden_states=prompt_embeds,
                                      encoder_hidden_states_image=image_embeds, return_dict=False)[0]

        if do_cfg:
            # CFG is run as two forward passes rather than a batched pass.
            noise_uncond = pipe.transformer(hidden_states=latent_in, timestep=ts, encoder_hidden_states=neg_embeds,
                                            encoder_hidden_states_image=image_embeds, return_dict=False)[0]
            noise_pred = noise_uncond + guidance_scale * \
                (noise_pred - noise_uncond)

        latents = pipe.scheduler.step(
            noise_pred, t, latents, return_dict=False)[0]

        # TTM anchoring of the motion region (see docstring).
        if (i + tweak_index) < tstrong_index:
            next_t = timesteps[i + tweak_index + 1] if i + \
                tweak_index + 1 < len(timesteps) else None
            if next_t is not None:
                noisy_ref = pipe.scheduler.add_noise(
                    ref_latents, fixed_noise, torch.as_tensor(next_t, device=device).long())
                latents = latents * background_mask + noisy_ref * motion_mask
            else:
                latents = latents * background_mask + \
                    ref_latents.to(latents.dtype) * motion_mask

    # Undo latent normalization and decode to pixels.
    latents = latents.to(pipe.vae.dtype)
    latents = latents / std + mean
    video = pipe.vae.decode(latents, return_dict=False)[0]
    video = ttm_helper.video_processor.postprocess_video(
        video, output_type="pil")

    out_path = os.path.join(create_user_temp_dir(), "ttm_wan_out.mp4")
    export_to_video(video[0], out_path, fps=16)
    return out_path, "✅ Video Generated"
293
-
294
-
295
def run_ttm_generation(first_frame_path, motion_signal_path, mask_path, prompt, negative_prompt,
                       model_choice, tweak_index, tstrong_index, num_frames, num_inference_steps,
                       guidance_scale, seed, progress=gr.Progress()):
    """Dispatch TTM generation to the backend named in `model_choice`.

    Any choice containing "Wan" routes to the Wan 2.2 pipeline (which also
    consumes `negative_prompt`); every other choice falls back to CogVideoX.
    Returns whatever the chosen backend returns: (video_path, status).
    """
    use_wan = "Wan" in model_choice
    if use_wan:
        return run_ttm_wan_generation(first_frame_path, motion_signal_path, mask_path,
                                      prompt, negative_prompt,
                                      tweak_index, tstrong_index, num_frames,
                                      num_inference_steps, guidance_scale, seed, progress)
    # CogVideoX path has no negative-prompt parameter, so it is dropped here.
    return run_ttm_cog_generation(first_frame_path, motion_signal_path, mask_path,
                                  prompt,
                                  tweak_index, tstrong_index, num_frames,
                                  num_inference_steps, guidance_scale, seed, progress)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/utils.py DELETED
@@ -1,57 +0,0 @@
1
- import os
2
- import cv2
3
- import time
4
- import shutil
5
- import logging
6
- import uuid
7
- import torch
8
- import numpy as np
9
- import atexit
10
- from concurrent.futures import ThreadPoolExecutor
11
- from typing import Union
12
-
13
# basicConfig is a no-op if the application already configured logging.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Small shared pool used to run delayed temp-file cleanup off the main thread.
thread_pool_executor = ThreadPoolExecutor(max_workers=2)
17
-
18
def delete_later(path: Union[str, os.PathLike], delay: int = 600):
    """Schedule best-effort removal of `path` after `delay` seconds.

    The same cleanup is also registered with atexit, so the artifact is
    removed even when the process exits before the timer fires. Failures
    are logged as warnings, never raised.
    """
    def _remove_path():
        try:
            if os.path.isfile(path):
                os.remove(path)
            elif os.path.isdir(path):
                shutil.rmtree(path)
        except Exception as e:
            logger.warning(f"Failed to delete {path}: {e}")

    def _remove_after_delay():
        time.sleep(delay)
        _remove_path()

    thread_pool_executor.submit(_remove_after_delay)
    atexit.register(_remove_path)
34
-
35
def create_user_temp_dir():
    """Create and return a per-session scratch directory under temp_local/.

    A short random id keeps concurrent sessions from colliding; the directory
    is scheduled for deletion after ten minutes via delete_later.
    """
    short_id = str(uuid.uuid4())[:8]
    temp_dir = os.path.join("temp_local", f"session_{short_id}")
    os.makedirs(temp_dir, exist_ok=True)
    delete_later(temp_dir, delay=600)
    return temp_dir
41
-
42
def load_video_to_tensor(video_path: str) -> torch.Tensor:
    """Decode a video file into a float tensor of shape (1, C, T, H, W) in [0, 1].

    Frames are read with OpenCV and converted from BGR to RGB before stacking.
    """
    capture = cv2.VideoCapture(video_path)
    rgb_frames = []
    while True:
        ok, bgr_frame = capture.read()
        if not ok:
            break
        rgb_frames.append(cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB))
    capture.release()

    stacked = torch.tensor(np.array(rgb_frames))           # (T, H, W, C) uint8
    video = stacked.permute(0, 3, 1, 2).float() / 255.0    # (T, C, H, W) in [0, 1]
    return video.unsqueeze(0).permute(0, 2, 1, 3, 4)       # (1, C, T, H, W)