abreza commited on
Commit
c8dc4de
·
1 Parent(s): 530e09a

add some logs to check run issue

Browse files
Files changed (1) hide show
  1. app.py +284 -89
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import gradio as gr
2
  import os
3
  import numpy as np
@@ -6,32 +7,75 @@ import time
6
  import shutil
7
  from pathlib import Path
8
  from einops import rearrange
9
- from typing import Union, Optional
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  try:
11
  import spaces
 
12
  except ImportError:
 
13
  class spaces:
14
  @staticmethod
15
  def GPU(func=None, duration=None):
16
  def decorator(f):
17
  return f
18
  return decorator if func is None else func
 
 
 
 
19
  import torch
 
 
 
20
  import torch.nn.functional as F
21
  import torchvision.transforms as T
22
- import logging
23
  from concurrent.futures import ThreadPoolExecutor
24
  import atexit
25
  import uuid
 
 
 
26
  import decord
 
 
 
27
  from PIL import Image
28
 
29
- from models.SpaTrackV2.models.vggt4track.models.vggt_moe import VGGT4Track
30
- from models.SpaTrackV2.models.vggt4track.utils.load_fn import preprocess_image
31
- from models.SpaTrackV2.models.predictor import Predictor
32
- from models.SpaTrackV2.models.utils import get_points_on_a_grid
 
 
 
 
 
 
 
 
33
 
34
  # TTM imports (optional - will be loaded on demand)
 
 
35
  TTM_COG_AVAILABLE = False
36
  TTM_WAN_AVAILABLE = False
37
  try:
@@ -41,8 +85,10 @@ try:
41
  from diffusers.utils.torch_utils import randn_tensor
42
  from diffusers.video_processor import VideoProcessor
43
  TTM_COG_AVAILABLE = True
44
- except ImportError:
45
- pass
 
 
46
 
47
  try:
48
  from diffusers import AutoencoderKLWan, WanTransformer3DModel
@@ -54,17 +100,17 @@ try:
54
  from diffusers.utils.torch_utils import randn_tensor
55
  from diffusers.video_processor import VideoProcessor
56
  TTM_WAN_AVAILABLE = True
57
- except ImportError:
58
- pass
 
 
59
 
60
  TTM_AVAILABLE = TTM_COG_AVAILABLE or TTM_WAN_AVAILABLE
61
  if not TTM_AVAILABLE:
62
- logger_init = logging.getLogger(__name__)
63
- logger_init.warning("Diffusers not available. TTM features will be disabled.")
64
-
65
- # Configure logging
66
- logging.basicConfig(level=logging.INFO)
67
- logger = logging.getLogger(__name__)
68
 
69
  # Constants
70
  MAX_FRAMES = 80
@@ -97,9 +143,12 @@ if TTM_COG_AVAILABLE:
97
  if TTM_WAN_AVAILABLE:
98
  TTM_MODELS.append("Wan2.2-14B (Recommended)")
99
 
100
- # Global TTM pipelines (lazy loaded)
 
 
101
  ttm_cog_pipeline = None
102
  ttm_wan_pipeline = None
 
103
 
104
 
105
  def load_video_to_tensor(video_path: str) -> torch.Tensor:
@@ -150,11 +199,56 @@ def get_ttm_wan_pipeline():
150
  ttm_wan_pipeline.vae.enable_slicing()
151
  logger.info("TTM Wan 2.2 pipeline loaded successfully!")
152
  return ttm_wan_pipeline
153
- return ttm_pipeline
 
 
 
154
 
155
  # Thread pool for delayed deletion
156
  thread_pool_executor = ThreadPoolExecutor(max_workers=2)
157
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  def delete_later(path: Union[str, os.PathLike], delay: int = 600):
159
  """Delete file or directory after specified delay"""
160
  def _delete():
@@ -173,6 +267,7 @@ def delete_later(path: Union[str, os.PathLike], delay: int = 600):
173
  thread_pool_executor.submit(_wait_and_delete)
174
  atexit.register(_delete)
175
 
 
176
  def create_user_temp_dir():
177
  """Create a unique temporary directory for each user session"""
178
  session_id = str(uuid.uuid4())[:8]
@@ -181,17 +276,16 @@ def create_user_temp_dir():
181
  delete_later(temp_dir, delay=600)
182
  return temp_dir
183
 
184
- # Global model initialization
185
- print("🚀 Initializing models...")
186
- vggt4track_model = VGGT4Track.from_pretrained("Yuxihenry/SpatialTrackerV2_Front")
187
- vggt4track_model.eval()
188
- vggt4track_model = vggt4track_model.to("cuda")
189
 
190
- tracker_model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
191
- tracker_model.eval()
192
- print("Models loaded successfully!")
 
193
 
 
194
  gr.set_static_paths(paths=[Path.cwd().absolute()/"_viz"])
 
 
195
 
196
 
197
  def generate_camera_trajectory(num_frames: int, movement_type: str,
@@ -217,7 +311,8 @@ def generate_camera_trajectory(num_frames: int, movement_type: str,
217
  if movement_type == "static":
218
  pass # Keep identity
219
  elif movement_type == "move_forward":
220
- ext[2, 3] = -speed * t # Move along -Z (forward in OpenGL convention)
 
221
  elif movement_type == "move_backward":
222
  ext[2, 3] = speed * t # Move along +Z
223
  elif movement_type == "move_left":
@@ -274,7 +369,8 @@ def render_from_pointcloud(rgb_frames: np.ndarray,
274
  base_dir = os.path.dirname(output_path)
275
  motion_signal_path = os.path.join(base_dir, "motion_signal.mp4")
276
  mask_path = os.path.join(base_dir, "mask.mp4")
277
- out_motion_signal = cv2.VideoWriter(motion_signal_path, fourcc, fps, (W, H))
 
278
  out_mask = cv2.VideoWriter(mask_path, fourcc, fps, (W, H))
279
 
280
  # Create meshgrid for pixel coordinates
@@ -354,17 +450,21 @@ def render_from_pointcloud(rgb_frames: np.ndarray,
354
  if hole_mask.sum() == 0:
355
  break
356
  dilated = cv2.dilate(motion_signal_frame, kernel, iterations=1)
357
- motion_signal_frame = np.where(hole_mask[:, :, None] > 0, dilated, motion_signal_frame)
358
- hole_mask = (motion_signal_frame.sum(axis=-1) == 0).astype(np.uint8)
 
 
359
 
360
  # Write TTM outputs if enabled
361
  if generate_ttm_inputs:
362
  # Motion signal: warped frame with NN inpainting
363
- motion_signal_bgr = cv2.cvtColor(motion_signal_frame, cv2.COLOR_RGB2BGR)
 
364
  out_motion_signal.write(motion_signal_bgr)
365
 
366
  # Mask: binary mask of valid (projected) pixels - white where valid, black where holes
367
- mask_frame = np.stack([valid_mask, valid_mask, valid_mask], axis=-1)
 
368
  out_mask.write(mask_frame)
369
 
370
  # For the rendered output, use the same inpainted result
@@ -384,7 +484,7 @@ def render_from_pointcloud(rgb_frames: np.ndarray,
384
  }
385
 
386
 
387
- @spaces.GPU
388
  def run_spatial_tracker(video_tensor: torch.Tensor):
389
  """
390
  GPU-intensive spatial tracking function.
@@ -395,9 +495,23 @@ def run_spatial_tracker(video_tensor: torch.Tensor):
395
  Returns:
396
  Dictionary containing tracking results
397
  """
 
 
 
 
 
 
 
 
 
 
 
398
  # Run VGGT to get depth and camera poses
399
  video_input = preprocess_image(video_tensor)[None].cuda()
400
 
 
 
 
401
  with torch.no_grad():
402
  with torch.cuda.amp.autocast(dtype=torch.bfloat16):
403
  predictions = vggt4track_model(video_input / 255)
@@ -406,6 +520,9 @@ def run_spatial_tracker(video_tensor: torch.Tensor):
406
  depth_map = predictions["points_map"][..., 2]
407
  depth_conf = predictions["unc_metric"]
408
 
 
 
 
409
  depth_tensor = depth_map.squeeze().cpu().numpy()
410
  extrs = extrinsic.squeeze().cpu().numpy()
411
  intrs = intrinsic.squeeze().cpu().numpy()
@@ -413,13 +530,20 @@ def run_spatial_tracker(video_tensor: torch.Tensor):
413
  unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5
414
 
415
  # Setup tracker
 
 
 
416
  tracker_model.spatrack.track_num = 512
417
  tracker_model.to("cuda")
418
 
419
  # Get grid points for tracking
420
  frame_H, frame_W = video_tensor_gpu.shape[2:]
421
  grid_pts = get_points_on_a_grid(30, (frame_H, frame_W), device="cpu")
422
- query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[0].numpy()
 
 
 
 
423
 
424
  # Run tracker
425
  with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
@@ -447,8 +571,11 @@ def run_spatial_tracker(video_tensor: torch.Tensor):
447
  conf_depth = T.Resize((new_h, new_w))(conf_depth)
448
  intrs_out[:, :2, :] = intrs_out[:, :2, :] * scale
449
 
 
 
 
450
  # Move results to CPU and return
451
- return {
452
  'video_out': video_out.cpu(),
453
  'point_map': point_map.cpu(),
454
  'conf_depth': conf_depth.cpu(),
@@ -456,6 +583,11 @@ def run_spatial_tracker(video_tensor: torch.Tensor):
456
  'c2w_traj': c2w_traj.cpu(),
457
  }
458
 
 
 
 
 
 
459
 
460
  def process_video(video_path: str, camera_movement: str, generate_ttm: bool = True, progress=gr.Progress()):
461
  """Main processing function
@@ -511,7 +643,8 @@ def process_video(video_path: str, camera_movement: str, generate_ttm: bool = Tr
511
  c2w_traj = tracking_results['c2w_traj']
512
 
513
  # Get RGB frames and depth
514
- rgb_frames = rearrange(video_out.numpy(), "T C H W -> T H W C").astype(np.uint8)
 
515
  depth_frames = point_map[:, 2].numpy()
516
  depth_conf_np = conf_depth.numpy()
517
 
@@ -522,7 +655,8 @@ def process_video(video_path: str, camera_movement: str, generate_ttm: bool = Tr
522
  intrs_np = intrs_out.numpy()
523
  extrs_np = torch.inverse(c2w_traj).numpy() # world-to-camera
524
 
525
- progress(0.7, desc=f"Generating {camera_movement} camera trajectory...")
 
526
 
527
  # Calculate scene scale from depth
528
  valid_depth = depth_frames[depth_frames > 0]
@@ -586,7 +720,8 @@ class CogVideoXTTMHelper:
586
  self.vae = pipeline.vae
587
  self.transformer = pipeline.transformer
588
  self.scheduler = pipeline.scheduler
589
- self.vae_scale_factor_spatial = 2 ** (len(self.vae.config.block_out_channels) - 1)
 
590
  self.vae_scale_factor_temporal = self.vae.config.temporal_compression_ratio
591
  self.vae_scaling_factor_image = self.vae.config.scaling_factor
592
  self.video_processor = pipeline.video_processor
@@ -596,7 +731,8 @@ class CogVideoXTTMHelper:
596
  """Encode video frames into latent space. Input shape (B, C, F, H, W), expected range [-1, 1]."""
597
  latents = self.vae.encode(frames)[0].sample()
598
  latents = latents * self.vae_scaling_factor_image
599
- return latents.permute(0, 2, 1, 3, 4).contiguous() # (B, C, F, H, W) -> (B, F, C, H, W)
 
600
 
601
  def convert_rgb_mask_to_latent_mask(self, mask: torch.Tensor) -> torch.Tensor:
602
  """Convert a per-frame mask [T, 1, H, W] to latent resolution [1, T_latent, 1, H', W']."""
@@ -610,7 +746,8 @@ class CogVideoXTTMHelper:
610
  s = self.vae_scale_factor_spatial
611
  H_latent = pooled.shape[-2] // s
612
  W_latent = pooled.shape[-1] // s
613
- pooled = F.interpolate(pooled, size=(pooled.shape[2], H_latent, W_latent), mode="nearest")
 
614
 
615
  latent_mask = pooled.permute(0, 2, 1, 3, 4)
616
  return latent_mask
@@ -641,7 +778,8 @@ class WanTTMHelper:
641
  s = self.vae_scale_factor_spatial
642
  H_latent = pooled.shape[-2] // s
643
  W_latent = pooled.shape[-1] // s
644
- pooled = F.interpolate(pooled, size=(pooled.shape[2], H_latent, W_latent), mode="nearest")
 
645
 
646
  latent_mask = pooled.permute(0, 2, 1, 3, 4)
647
  return latent_mask
@@ -698,8 +836,10 @@ def run_ttm_cog_generation(
698
  image = load_image(first_frame_path)
699
 
700
  # Get dimensions
701
- height = pipe.transformer.config.sample_height * ttm_helper.vae_scale_factor_spatial
702
- width = pipe.transformer.config.sample_width * ttm_helper.vae_scale_factor_spatial
 
 
703
 
704
  device = "cuda"
705
  generator = torch.Generator(device=device).manual_seed(seed)
@@ -717,7 +857,8 @@ def run_ttm_cog_generation(
717
  device=device,
718
  )
719
  if do_classifier_free_guidance:
720
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
 
721
 
722
  progress(0.2, desc="Preparing latents...")
723
 
@@ -726,7 +867,8 @@ def run_ttm_cog_generation(
726
  timesteps = pipe.scheduler.timesteps
727
 
728
  # Prepare latents
729
- latent_frames = (num_frames - 1) // ttm_helper.vae_scale_factor_temporal + 1
 
730
 
731
  # Handle padding for CogVideoX 1.5
732
  patch_size_t = pipe.transformer.config.patch_size_t
@@ -760,11 +902,13 @@ def run_ttm_cog_generation(
760
  ref_vid = load_video_to_tensor(motion_signal_path).to(device=device)
761
  refB, refC, refT, refH, refW = ref_vid.shape
762
  ref_vid = F.interpolate(
763
- ref_vid.permute(0, 2, 1, 3, 4).reshape(refB*refT, refC, refH, refW),
 
764
  size=(height, width), mode="bicubic", align_corners=True,
765
  ).reshape(refB, refT, refC, height, width).permute(0, 2, 1, 3, 4)
766
 
767
- ref_vid = ttm_helper.video_processor.normalize(ref_vid.to(dtype=pipe.vae.dtype))
 
768
  ref_latents = ttm_helper.encode_frames(ref_vid).float().detach()
769
 
770
  # Load mask video
@@ -795,8 +939,10 @@ def run_ttm_cog_generation(
795
  device=ref_latents.device,
796
  dtype=ref_latents.dtype,
797
  )
798
- noisy_latents = pipe.scheduler.add_noise(ref_latents, fixed_noise, tweak.long())
799
- latents = noisy_latents.to(dtype=latents.dtype, device=latents.device)
 
 
800
  else:
801
  fixed_noise = randn_tensor(
802
  ref_latents.shape,
@@ -811,13 +957,15 @@ def run_ttm_cog_generation(
811
 
812
  # Create rotary embeddings if required
813
  image_rotary_emb = (
814
- pipe._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
 
815
  if pipe.transformer.config.use_rotary_positional_embeddings
816
  else None
817
  )
818
 
819
  # Create ofs embeddings if required
820
- ofs_emb = None if pipe.transformer.config.ofs_embed_dim is None else latents.new_full((1,), fill_value=2.0)
 
821
 
822
  progress(0.4, desc="Running TTM denoising loop...")
823
 
@@ -827,13 +975,18 @@ def run_ttm_cog_generation(
827
 
828
  for i, t in enumerate(timesteps[tweak_index:]):
829
  step_progress = 0.4 + 0.5 * (i / total_steps)
830
- progress(step_progress, desc=f"Denoising step {i+1}/{total_steps}...")
 
831
 
832
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
833
- latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)
 
 
834
 
835
- latent_image_input = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents
836
- latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2)
 
 
837
 
838
  timestep = t.expand(latent_model_input.shape[0])
839
 
@@ -851,7 +1004,8 @@ def run_ttm_cog_generation(
851
  # Perform guidance
852
  if do_classifier_free_guidance:
853
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
854
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
855
 
856
  # Compute previous noisy sample
857
  if not isinstance(pipe.scheduler, CogVideoXDPMScheduler):
@@ -889,7 +1043,8 @@ def run_ttm_cog_generation(
889
  # Decode latents
890
  latents = latents[:, additional_frames:]
891
  frames = pipe.decode_latents(latents)
892
- video = ttm_helper.video_processor.postprocess_video(video=frames, output_type="pil")
 
893
 
894
  progress(0.95, desc="Saving video...")
895
 
@@ -954,8 +1109,10 @@ def run_ttm_wan_generation(
954
 
955
  # Get dimensions - compute based on image aspect ratio
956
  max_area = 480 * 832
957
- mod_value = ttm_helper.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
958
- height, width = compute_hw_from_area(image.height, image.width, max_area, mod_value)
 
 
959
  image = image.resize((width, height))
960
 
961
  device = "cuda"
@@ -979,7 +1136,8 @@ def run_ttm_wan_generation(
979
  transformer_dtype = pipe.transformer.dtype
980
  prompt_embeds = prompt_embeds.to(transformer_dtype)
981
  if negative_prompt_embeds is not None:
982
- negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
 
983
 
984
  # Encode image embedding if transformer supports it
985
  image_embeds = None
@@ -996,12 +1154,14 @@ def run_ttm_wan_generation(
996
 
997
  # Adjust num_frames to be valid for VAE
998
  if num_frames % ttm_helper.vae_scale_factor_temporal != 1:
999
- num_frames = num_frames // ttm_helper.vae_scale_factor_temporal * ttm_helper.vae_scale_factor_temporal + 1
 
1000
  num_frames = max(num_frames, 1)
1001
 
1002
  # Prepare latent variables
1003
  num_channels_latents = pipe.vae.config.z_dim
1004
- image_tensor = ttm_helper.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32)
 
1005
 
1006
  latents_outputs = pipe.prepare_latents(
1007
  image_tensor,
@@ -1029,16 +1189,21 @@ def run_ttm_wan_generation(
1029
  ref_vid = load_video_to_tensor(motion_signal_path).to(device=device)
1030
  refB, refC, refT, refH, refW = ref_vid.shape
1031
  ref_vid = F.interpolate(
1032
- ref_vid.permute(0, 2, 1, 3, 4).reshape(refB*refT, refC, refH, refW),
 
1033
  size=(height, width), mode="bicubic", align_corners=True,
1034
  ).reshape(refB, refT, refC, height, width).permute(0, 2, 1, 3, 4)
1035
 
1036
- ref_vid = ttm_helper.video_processor.normalize(ref_vid.to(dtype=pipe.vae.dtype))
1037
- ref_latents = retrieve_latents(pipe.vae.encode(ref_vid), sample_mode="argmax")
 
 
1038
 
1039
  # Normalize latents
1040
- latents_mean = torch.tensor(pipe.vae.config.latents_mean).view(1, pipe.vae.config.z_dim, 1, 1, 1).to(ref_latents.device, ref_latents.dtype)
1041
- latents_std = 1.0 / torch.tensor(pipe.vae.config.latents_std).view(1, pipe.vae.config.z_dim, 1, 1, 1).to(ref_latents.device, ref_latents.dtype)
 
 
1042
  ref_latents = (ref_latents - latents_mean) * latents_std
1043
 
1044
  # Load mask video
@@ -1062,7 +1227,8 @@ def run_ttm_wan_generation(
1062
  else:
1063
  mask_t1_hw = (mask_tc_hw > 0.5).float()
1064
 
1065
- motion_mask = ttm_helper.convert_rgb_mask_to_latent_mask(mask_t1_hw).permute(0, 2, 1, 3, 4).contiguous()
 
1066
  background_mask = 1.0 - motion_mask
1067
 
1068
  progress(0.35, desc="Initializing TTM denoising...")
@@ -1076,9 +1242,12 @@ def run_ttm_wan_generation(
1076
  device=ref_latents.device,
1077
  dtype=ref_latents.dtype,
1078
  )
1079
- tweak_t = torch.as_tensor(tweak, device=ref_latents.device, dtype=torch.long).view(1)
1080
- noisy_latents = pipe.scheduler.add_noise(ref_latents, fixed_noise, tweak_t.long())
1081
- latents = noisy_latents.to(dtype=latents.dtype, device=latents.device)
 
 
 
1082
  else:
1083
  fixed_noise = randn_tensor(
1084
  ref_latents.shape,
@@ -1095,16 +1264,19 @@ def run_ttm_wan_generation(
1095
 
1096
  for i, t in enumerate(timesteps[tweak_index:]):
1097
  step_progress = 0.4 + 0.5 * (i / total_steps)
1098
- progress(step_progress, desc=f"Denoising step {i+1}/{total_steps}...")
 
1099
 
1100
  # Prepare model input
1101
  if first_frame_mask is not None:
1102
- latent_model_input = (1 - first_frame_mask) * condition + first_frame_mask * latents
 
1103
  latent_model_input = latent_model_input.to(transformer_dtype)
1104
  temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten()
1105
  timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
1106
  else:
1107
- latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
 
1108
  timestep = t.expand(latents.shape[0])
1109
 
1110
  # Predict noise (conditional)
@@ -1125,10 +1297,12 @@ def run_ttm_wan_generation(
1125
  encoder_hidden_states_image=image_embeds,
1126
  return_dict=False,
1127
  )[0]
1128
- noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
 
1129
 
1130
  # Scheduler step
1131
- latents = pipe.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
 
1132
 
1133
  # TTM: In between tweak and tstrong, replace mask with noisy reference latents
1134
  in_between_tweak_tstrong = (i + tweak_index) < tstrong_index
@@ -1136,27 +1310,34 @@ def run_ttm_wan_generation(
1136
  if in_between_tweak_tstrong:
1137
  if i + tweak_index + 1 < len(timesteps):
1138
  prev_t = timesteps[i + tweak_index + 1]
1139
- prev_t = torch.as_tensor(prev_t, device=ref_latents.device, dtype=torch.long).view(1)
 
1140
  noisy_latents = pipe.scheduler.add_noise(ref_latents, fixed_noise, prev_t.long()).to(
1141
  dtype=latents.dtype, device=latents.device
1142
  )
1143
  latents = latents * background_mask + noisy_latents * motion_mask
1144
  else:
1145
- latents = latents * background_mask + ref_latents.to(dtype=latents.dtype, device=latents.device) * motion_mask
 
 
1146
 
1147
  progress(0.9, desc="Decoding video...")
1148
 
1149
  # Apply first frame mask if used
1150
  if first_frame_mask is not None:
1151
- latents = (1 - first_frame_mask) * condition + first_frame_mask * latents
 
1152
 
1153
  # Decode latents
1154
  latents = latents.to(pipe.vae.dtype)
1155
- latents_mean = torch.tensor(pipe.vae.config.latents_mean).view(1, pipe.vae.config.z_dim, 1, 1, 1).to(latents.device, latents.dtype)
1156
- latents_std = 1.0 / torch.tensor(pipe.vae.config.latents_std).view(1, pipe.vae.config.z_dim, 1, 1, 1).to(latents.device, latents.dtype)
 
 
1157
  latents = latents / latents_std + latents_mean
1158
  video = pipe.vae.decode(latents, return_dict=False)[0]
1159
- video = ttm_helper.video_processor.postprocess_video(video, output_type="pil")
 
1160
 
1161
  progress(0.95, desc="Saving video...")
1162
 
@@ -1226,7 +1407,8 @@ def run_ttm_generation(
1226
 
1227
 
1228
  # Create Gradio interface
1229
- print("🎨 Creating Gradio interface...")
 
1230
 
1231
  with gr.Blocks(
1232
  theme=gr.themes.Soft(),
@@ -1283,7 +1465,8 @@ with gr.Blocks(
1283
  info="Generate motion_signal.mp4 and mask.mp4 for Time-to-Move"
1284
  )
1285
 
1286
- generate_btn = gr.Button("🚀 Generate Motion Signal", variant="primary", size="lg")
 
1287
 
1288
  with gr.Column(scale=1):
1289
  gr.Markdown("### 📤 Rendered Output")
@@ -1419,7 +1602,8 @@ with gr.Blocks(
1419
  label="TTM Generated Video",
1420
  height=400
1421
  )
1422
- ttm_status_text = gr.Markdown("Upload a video in Step 1 first, then run TTM here.")
 
1423
 
1424
  # TTM Input preview
1425
  with gr.Accordion("📁 TTM Input Files (from Step 1)", open=False):
@@ -1439,7 +1623,8 @@ with gr.Blocks(
1439
 
1440
  # Helper function to update states and preview
1441
  def process_and_update_states(video_path, camera_movement, generate_ttm_flag, progress=gr.Progress()):
1442
- result = process_video(video_path, camera_movement, generate_ttm_flag, progress)
 
1443
  output_vid, motion_sig, mask_vid, first_frame, status = result
1444
 
1445
  # Return all outputs including state updates and previews
@@ -1491,10 +1676,12 @@ with gr.Blocks(
1491
  # Examples
1492
  gr.Markdown("### 📁 Examples")
1493
  if os.path.exists("./examples"):
1494
- example_videos = [f for f in os.listdir("./examples") if f.endswith(".mp4")][:4]
 
1495
  if example_videos:
1496
  gr.Examples(
1497
- examples=[[f"./examples/{v}", "move_forward", True] for v in example_videos],
 
1498
  inputs=[video_input, camera_movement, generate_ttm],
1499
  outputs=[
1500
  output_video, motion_signal_output, mask_output, first_frame_output, status_text,
@@ -1506,5 +1693,13 @@ with gr.Blocks(
1506
  )
1507
 
1508
  # Launch
 
 
 
 
 
 
1509
  if __name__ == "__main__":
 
 
1510
  demo.launch(share=False)
 
1
+ import sys
2
  import gradio as gr
3
  import os
4
  import numpy as np
 
7
  import shutil
8
  from pathlib import Path
9
  from einops import rearrange
10
+ from typing import Union
11
+
12
+ # Force unbuffered output for HF Spaces logs
13
+ os.environ['PYTHONUNBUFFERED'] = '1'
14
+
15
+ # Configure logging FIRST before any other imports
16
+ import logging
17
+ logging.basicConfig(
18
+ level=logging.INFO,
19
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
20
+ handlers=[
21
+ logging.StreamHandler(sys.stdout)
22
+ ]
23
+ )
24
+ logger = logging.getLogger(__name__)
25
+ logger.info("=" * 50)
26
+ logger.info("Starting application initialization...")
27
+ logger.info("=" * 50)
28
+ sys.stdout.flush()
29
+
30
  try:
31
  import spaces
32
+ logger.info("✅ HF Spaces module imported successfully")
33
  except ImportError:
34
+ logger.warning("⚠️ HF Spaces module not available, using mock")
35
  class spaces:
36
  @staticmethod
37
  def GPU(func=None, duration=None):
38
  def decorator(f):
39
  return f
40
  return decorator if func is None else func
41
+ sys.stdout.flush()
42
+
43
+ logger.info("Importing torch...")
44
+ sys.stdout.flush()
45
  import torch
46
+ logger.info(f"✅ Torch imported. Version: {torch.__version__}, CUDA available: {torch.cuda.is_available()}")
47
+ sys.stdout.flush()
48
+
49
  import torch.nn.functional as F
50
  import torchvision.transforms as T
 
51
  from concurrent.futures import ThreadPoolExecutor
52
  import atexit
53
  import uuid
54
+
55
+ logger.info("Importing decord...")
56
+ sys.stdout.flush()
57
  import decord
58
+ logger.info("✅ Decord imported successfully")
59
+ sys.stdout.flush()
60
+
61
  from PIL import Image
62
 
63
+ logger.info("Importing SpaTrack models...")
64
+ sys.stdout.flush()
65
+ try:
66
+ from models.SpaTrackV2.models.vggt4track.models.vggt_moe import VGGT4Track
67
+ from models.SpaTrackV2.models.vggt4track.utils.load_fn import preprocess_image
68
+ from models.SpaTrackV2.models.predictor import Predictor
69
+ from models.SpaTrackV2.models.utils import get_points_on_a_grid
70
+ logger.info("✅ SpaTrack models imported successfully")
71
+ except Exception as e:
72
+ logger.error(f"❌ Failed to import SpaTrack models: {e}")
73
+ raise
74
+ sys.stdout.flush()
75
 
76
  # TTM imports (optional - will be loaded on demand)
77
+ logger.info("Checking TTM (diffusers) availability...")
78
+ sys.stdout.flush()
79
  TTM_COG_AVAILABLE = False
80
  TTM_WAN_AVAILABLE = False
81
  try:
 
85
  from diffusers.utils.torch_utils import randn_tensor
86
  from diffusers.video_processor import VideoProcessor
87
  TTM_COG_AVAILABLE = True
88
+ logger.info("✅ CogVideoX TTM available")
89
+ except ImportError as e:
90
+ logger.info(f"ℹ️ CogVideoX TTM not available: {e}")
91
+ sys.stdout.flush()
92
 
93
  try:
94
  from diffusers import AutoencoderKLWan, WanTransformer3DModel
 
100
  from diffusers.utils.torch_utils import randn_tensor
101
  from diffusers.video_processor import VideoProcessor
102
  TTM_WAN_AVAILABLE = True
103
+ logger.info("✅ Wan TTM available")
104
+ except ImportError as e:
105
+ logger.info(f"ℹ️ Wan TTM not available: {e}")
106
+ sys.stdout.flush()
107
 
108
  TTM_AVAILABLE = TTM_COG_AVAILABLE or TTM_WAN_AVAILABLE
109
  if not TTM_AVAILABLE:
110
+ logger.warning("⚠️ Diffusers not available. TTM features will be disabled.")
111
+ else:
112
+ logger.info(f"TTM Status - CogVideoX: {TTM_COG_AVAILABLE}, Wan: {TTM_WAN_AVAILABLE}")
113
+ sys.stdout.flush()
 
 
114
 
115
  # Constants
116
  MAX_FRAMES = 80
 
143
  if TTM_WAN_AVAILABLE:
144
  TTM_MODELS.append("Wan2.2-14B (Recommended)")
145
 
146
+ # Global model instances (lazy loaded for HF Spaces GPU compatibility)
147
+ vggt4track_model = None
148
+ tracker_model = None
149
  ttm_cog_pipeline = None
150
  ttm_wan_pipeline = None
151
+ MODELS_LOADED = False
152
 
153
 
154
  def load_video_to_tensor(video_path: str) -> torch.Tensor:
 
199
  ttm_wan_pipeline.vae.enable_slicing()
200
  logger.info("TTM Wan 2.2 pipeline loaded successfully!")
201
  return ttm_wan_pipeline
202
+
203
+
204
+ logger.info("Setting up thread pool and utility functions...")
205
+ sys.stdout.flush()
206
 
207
  # Thread pool for delayed deletion
208
  thread_pool_executor = ThreadPoolExecutor(max_workers=2)
209
 
210
+
211
+ def load_models():
212
+ """Load models lazily when GPU is available (inside @spaces.GPU decorated function)."""
213
+ global vggt4track_model, tracker_model, MODELS_LOADED
214
+
215
+ if MODELS_LOADED:
216
+ logger.info("Models already loaded, skipping...")
217
+ return
218
+
219
+ logger.info("🚀 Starting model loading...")
220
+ sys.stdout.flush()
221
+
222
+ try:
223
+ logger.info("Loading VGGT4Track model from 'Yuxihenry/SpatialTrackerV2_Front'...")
224
+ sys.stdout.flush()
225
+ vggt4track_model = VGGT4Track.from_pretrained("Yuxihenry/SpatialTrackerV2_Front")
226
+ vggt4track_model.eval()
227
+ logger.info("✅ VGGT4Track model loaded, moving to CUDA...")
228
+ sys.stdout.flush()
229
+ vggt4track_model = vggt4track_model.to("cuda")
230
+ logger.info("✅ VGGT4Track model on CUDA")
231
+ sys.stdout.flush()
232
+
233
+ logger.info("Loading Predictor model from 'Yuxihenry/SpatialTrackerV2-Offline'...")
234
+ sys.stdout.flush()
235
+ tracker_model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
236
+ tracker_model.eval()
237
+ logger.info("✅ Predictor model loaded")
238
+ sys.stdout.flush()
239
+
240
+ MODELS_LOADED = True
241
+ logger.info("✅ All models loaded successfully!")
242
+ sys.stdout.flush()
243
+
244
+ except Exception as e:
245
+ logger.error(f"❌ Failed to load models: {e}")
246
+ import traceback
247
+ traceback.print_exc()
248
+ sys.stdout.flush()
249
+ raise
250
+
251
+
252
  def delete_later(path: Union[str, os.PathLike], delay: int = 600):
253
  """Delete file or directory after specified delay"""
254
  def _delete():
 
267
  thread_pool_executor.submit(_wait_and_delete)
268
  atexit.register(_delete)
269
 
270
+
271
  def create_user_temp_dir():
272
  """Create a unique temporary directory for each user session"""
273
  session_id = str(uuid.uuid4())[:8]
 
276
  delete_later(temp_dir, delay=600)
277
  return temp_dir
278
 
 
 
 
 
 
279
 
280
+ # Note: Models are loaded lazily inside @spaces.GPU decorated functions
281
+ # This is required for HF Spaces ZeroGPU compatibility
282
+ logger.info("Models will be loaded lazily when GPU is available")
283
+ sys.stdout.flush()
284
 
285
+ logger.info("Setting up Gradio static paths...")
286
  gr.set_static_paths(paths=[Path.cwd().absolute()/"_viz"])
287
+ logger.info("✅ Static paths configured")
288
+ sys.stdout.flush()
289
 
290
 
291
  def generate_camera_trajectory(num_frames: int, movement_type: str,
 
311
  if movement_type == "static":
312
  pass # Keep identity
313
  elif movement_type == "move_forward":
314
+ # Move along -Z (forward in OpenGL convention)
315
+ ext[2, 3] = -speed * t
316
  elif movement_type == "move_backward":
317
  ext[2, 3] = speed * t # Move along +Z
318
  elif movement_type == "move_left":
 
369
  base_dir = os.path.dirname(output_path)
370
  motion_signal_path = os.path.join(base_dir, "motion_signal.mp4")
371
  mask_path = os.path.join(base_dir, "mask.mp4")
372
+ out_motion_signal = cv2.VideoWriter(
373
+ motion_signal_path, fourcc, fps, (W, H))
374
  out_mask = cv2.VideoWriter(mask_path, fourcc, fps, (W, H))
375
 
376
  # Create meshgrid for pixel coordinates
 
450
  if hole_mask.sum() == 0:
451
  break
452
  dilated = cv2.dilate(motion_signal_frame, kernel, iterations=1)
453
+ motion_signal_frame = np.where(
454
+ hole_mask[:, :, None] > 0, dilated, motion_signal_frame)
455
+ hole_mask = (motion_signal_frame.sum(
456
+ axis=-1) == 0).astype(np.uint8)
457
 
458
  # Write TTM outputs if enabled
459
  if generate_ttm_inputs:
460
  # Motion signal: warped frame with NN inpainting
461
+ motion_signal_bgr = cv2.cvtColor(
462
+ motion_signal_frame, cv2.COLOR_RGB2BGR)
463
  out_motion_signal.write(motion_signal_bgr)
464
 
465
  # Mask: binary mask of valid (projected) pixels - white where valid, black where holes
466
+ mask_frame = np.stack(
467
+ [valid_mask, valid_mask, valid_mask], axis=-1)
468
  out_mask.write(mask_frame)
469
 
470
  # For the rendered output, use the same inpainted result
 
484
  }
485
 
486
 
487
+ @spaces.GPU(duration=180)
488
  def run_spatial_tracker(video_tensor: torch.Tensor):
489
  """
490
  GPU-intensive spatial tracking function.
 
495
  Returns:
496
  Dictionary containing tracking results
497
  """
498
+ global vggt4track_model, tracker_model
499
+
500
+ logger.info("run_spatial_tracker: Starting GPU execution...")
501
+ sys.stdout.flush()
502
+
503
+ # Load models if not already loaded (lazy loading for HF Spaces)
504
+ load_models()
505
+
506
+ logger.info("run_spatial_tracker: Preprocessing video input...")
507
+ sys.stdout.flush()
508
+
509
  # Run VGGT to get depth and camera poses
510
  video_input = preprocess_image(video_tensor)[None].cuda()
511
 
512
+ logger.info("run_spatial_tracker: Running VGGT inference...")
513
+ sys.stdout.flush()
514
+
515
  with torch.no_grad():
516
  with torch.cuda.amp.autocast(dtype=torch.bfloat16):
517
  predictions = vggt4track_model(video_input / 255)
 
520
  depth_map = predictions["points_map"][..., 2]
521
  depth_conf = predictions["unc_metric"]
522
 
523
+ logger.info("run_spatial_tracker: VGGT inference complete")
524
+ sys.stdout.flush()
525
+
526
  depth_tensor = depth_map.squeeze().cpu().numpy()
527
  extrs = extrinsic.squeeze().cpu().numpy()
528
  intrs = intrinsic.squeeze().cpu().numpy()
 
530
  unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5
531
 
532
  # Setup tracker
533
+ logger.info("run_spatial_tracker: Setting up tracker...")
534
+ sys.stdout.flush()
535
+
536
  tracker_model.spatrack.track_num = 512
537
  tracker_model.to("cuda")
538
 
539
  # Get grid points for tracking
540
  frame_H, frame_W = video_tensor_gpu.shape[2:]
541
  grid_pts = get_points_on_a_grid(30, (frame_H, frame_W), device="cpu")
542
+ query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[
543
+ 0].numpy()
544
+
545
+ logger.info("run_spatial_tracker: Running 3D tracker...")
546
+ sys.stdout.flush()
547
 
548
  # Run tracker
549
  with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
 
571
  conf_depth = T.Resize((new_h, new_w))(conf_depth)
572
  intrs_out[:, :2, :] = intrs_out[:, :2, :] * scale
573
 
574
+ logger.info("run_spatial_tracker: Moving results to CPU...")
575
+ sys.stdout.flush()
576
+
577
  # Move results to CPU and return
578
+ result = {
579
  'video_out': video_out.cpu(),
580
  'point_map': point_map.cpu(),
581
  'conf_depth': conf_depth.cpu(),
 
583
  'c2w_traj': c2w_traj.cpu(),
584
  }
585
 
586
+ logger.info("run_spatial_tracker: Complete!")
587
+ sys.stdout.flush()
588
+
589
+ return result
590
+
591
 
592
  def process_video(video_path: str, camera_movement: str, generate_ttm: bool = True, progress=gr.Progress()):
593
  """Main processing function
 
643
  c2w_traj = tracking_results['c2w_traj']
644
 
645
  # Get RGB frames and depth
646
+ rgb_frames = rearrange(
647
+ video_out.numpy(), "T C H W -> T H W C").astype(np.uint8)
648
  depth_frames = point_map[:, 2].numpy()
649
  depth_conf_np = conf_depth.numpy()
650
 
 
655
  intrs_np = intrs_out.numpy()
656
  extrs_np = torch.inverse(c2w_traj).numpy() # world-to-camera
657
 
658
+ progress(
659
+ 0.7, desc=f"Generating {camera_movement} camera trajectory...")
660
 
661
  # Calculate scene scale from depth
662
  valid_depth = depth_frames[depth_frames > 0]
 
720
  self.vae = pipeline.vae
721
  self.transformer = pipeline.transformer
722
  self.scheduler = pipeline.scheduler
723
+ self.vae_scale_factor_spatial = 2 ** (
724
+ len(self.vae.config.block_out_channels) - 1)
725
  self.vae_scale_factor_temporal = self.vae.config.temporal_compression_ratio
726
  self.vae_scaling_factor_image = self.vae.config.scaling_factor
727
  self.video_processor = pipeline.video_processor
 
731
  """Encode video frames into latent space. Input shape (B, C, F, H, W), expected range [-1, 1]."""
732
  latents = self.vae.encode(frames)[0].sample()
733
  latents = latents * self.vae_scaling_factor_image
734
+ # (B, C, F, H, W) -> (B, F, C, H, W)
735
+ return latents.permute(0, 2, 1, 3, 4).contiguous()
736
 
737
  def convert_rgb_mask_to_latent_mask(self, mask: torch.Tensor) -> torch.Tensor:
738
  """Convert a per-frame mask [T, 1, H, W] to latent resolution [1, T_latent, 1, H', W']."""
 
746
  s = self.vae_scale_factor_spatial
747
  H_latent = pooled.shape[-2] // s
748
  W_latent = pooled.shape[-1] // s
749
+ pooled = F.interpolate(pooled, size=(
750
+ pooled.shape[2], H_latent, W_latent), mode="nearest")
751
 
752
  latent_mask = pooled.permute(0, 2, 1, 3, 4)
753
  return latent_mask
 
778
  s = self.vae_scale_factor_spatial
779
  H_latent = pooled.shape[-2] // s
780
  W_latent = pooled.shape[-1] // s
781
+ pooled = F.interpolate(pooled, size=(
782
+ pooled.shape[2], H_latent, W_latent), mode="nearest")
783
 
784
  latent_mask = pooled.permute(0, 2, 1, 3, 4)
785
  return latent_mask
 
836
  image = load_image(first_frame_path)
837
 
838
  # Get dimensions
839
+ height = pipe.transformer.config.sample_height * \
840
+ ttm_helper.vae_scale_factor_spatial
841
+ width = pipe.transformer.config.sample_width * \
842
+ ttm_helper.vae_scale_factor_spatial
843
 
844
  device = "cuda"
845
  generator = torch.Generator(device=device).manual_seed(seed)
 
857
  device=device,
858
  )
859
  if do_classifier_free_guidance:
860
+ prompt_embeds = torch.cat(
861
+ [negative_prompt_embeds, prompt_embeds], dim=0)
862
 
863
  progress(0.2, desc="Preparing latents...")
864
 
 
867
  timesteps = pipe.scheduler.timesteps
868
 
869
  # Prepare latents
870
+ latent_frames = (
871
+ num_frames - 1) // ttm_helper.vae_scale_factor_temporal + 1
872
 
873
  # Handle padding for CogVideoX 1.5
874
  patch_size_t = pipe.transformer.config.patch_size_t
 
902
  ref_vid = load_video_to_tensor(motion_signal_path).to(device=device)
903
  refB, refC, refT, refH, refW = ref_vid.shape
904
  ref_vid = F.interpolate(
905
+ ref_vid.permute(0, 2, 1, 3, 4).reshape(
906
+ refB*refT, refC, refH, refW),
907
  size=(height, width), mode="bicubic", align_corners=True,
908
  ).reshape(refB, refT, refC, height, width).permute(0, 2, 1, 3, 4)
909
 
910
+ ref_vid = ttm_helper.video_processor.normalize(
911
+ ref_vid.to(dtype=pipe.vae.dtype))
912
  ref_latents = ttm_helper.encode_frames(ref_vid).float().detach()
913
 
914
  # Load mask video
 
939
  device=ref_latents.device,
940
  dtype=ref_latents.dtype,
941
  )
942
+ noisy_latents = pipe.scheduler.add_noise(
943
+ ref_latents, fixed_noise, tweak.long())
944
+ latents = noisy_latents.to(
945
+ dtype=latents.dtype, device=latents.device)
946
  else:
947
  fixed_noise = randn_tensor(
948
  ref_latents.shape,
 
957
 
958
  # Create rotary embeddings if required
959
  image_rotary_emb = (
960
+ pipe._prepare_rotary_positional_embeddings(
961
+ height, width, latents.size(1), device)
962
  if pipe.transformer.config.use_rotary_positional_embeddings
963
  else None
964
  )
965
 
966
  # Create ofs embeddings if required
967
+ ofs_emb = None if pipe.transformer.config.ofs_embed_dim is None else latents.new_full(
968
+ (1,), fill_value=2.0)
969
 
970
  progress(0.4, desc="Running TTM denoising loop...")
971
 
 
975
 
976
  for i, t in enumerate(timesteps[tweak_index:]):
977
  step_progress = 0.4 + 0.5 * (i / total_steps)
978
+ progress(step_progress,
979
+ desc=f"Denoising step {i+1}/{total_steps}...")
980
 
981
+ latent_model_input = torch.cat(
982
+ [latents] * 2) if do_classifier_free_guidance else latents
983
+ latent_model_input = pipe.scheduler.scale_model_input(
984
+ latent_model_input, t)
985
 
986
+ latent_image_input = torch.cat(
987
+ [image_latents] * 2) if do_classifier_free_guidance else image_latents
988
+ latent_model_input = torch.cat(
989
+ [latent_model_input, latent_image_input], dim=2)
990
 
991
  timestep = t.expand(latent_model_input.shape[0])
992
 
 
1004
  # Perform guidance
1005
  if do_classifier_free_guidance:
1006
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1007
+ noise_pred = noise_pred_uncond + guidance_scale * \
1008
+ (noise_pred_text - noise_pred_uncond)
1009
 
1010
  # Compute previous noisy sample
1011
  if not isinstance(pipe.scheduler, CogVideoXDPMScheduler):
 
1043
  # Decode latents
1044
  latents = latents[:, additional_frames:]
1045
  frames = pipe.decode_latents(latents)
1046
+ video = ttm_helper.video_processor.postprocess_video(
1047
+ video=frames, output_type="pil")
1048
 
1049
  progress(0.95, desc="Saving video...")
1050
 
 
1109
 
1110
  # Get dimensions - compute based on image aspect ratio
1111
  max_area = 480 * 832
1112
+ mod_value = ttm_helper.vae_scale_factor_spatial * \
1113
+ pipe.transformer.config.patch_size[1]
1114
+ height, width = compute_hw_from_area(
1115
+ image.height, image.width, max_area, mod_value)
1116
  image = image.resize((width, height))
1117
 
1118
  device = "cuda"
 
1136
  transformer_dtype = pipe.transformer.dtype
1137
  prompt_embeds = prompt_embeds.to(transformer_dtype)
1138
  if negative_prompt_embeds is not None:
1139
+ negative_prompt_embeds = negative_prompt_embeds.to(
1140
+ transformer_dtype)
1141
 
1142
  # Encode image embedding if transformer supports it
1143
  image_embeds = None
 
1154
 
1155
  # Adjust num_frames to be valid for VAE
1156
  if num_frames % ttm_helper.vae_scale_factor_temporal != 1:
1157
+ num_frames = num_frames // ttm_helper.vae_scale_factor_temporal * \
1158
+ ttm_helper.vae_scale_factor_temporal + 1
1159
  num_frames = max(num_frames, 1)
1160
 
1161
  # Prepare latent variables
1162
  num_channels_latents = pipe.vae.config.z_dim
1163
+ image_tensor = ttm_helper.video_processor.preprocess(
1164
+ image, height=height, width=width).to(device, dtype=torch.float32)
1165
 
1166
  latents_outputs = pipe.prepare_latents(
1167
  image_tensor,
 
1189
  ref_vid = load_video_to_tensor(motion_signal_path).to(device=device)
1190
  refB, refC, refT, refH, refW = ref_vid.shape
1191
  ref_vid = F.interpolate(
1192
+ ref_vid.permute(0, 2, 1, 3, 4).reshape(
1193
+ refB*refT, refC, refH, refW),
1194
  size=(height, width), mode="bicubic", align_corners=True,
1195
  ).reshape(refB, refT, refC, height, width).permute(0, 2, 1, 3, 4)
1196
 
1197
+ ref_vid = ttm_helper.video_processor.normalize(
1198
+ ref_vid.to(dtype=pipe.vae.dtype))
1199
+ ref_latents = retrieve_latents(
1200
+ pipe.vae.encode(ref_vid), sample_mode="argmax")
1201
 
1202
  # Normalize latents
1203
+ latents_mean = torch.tensor(pipe.vae.config.latents_mean).view(
1204
+ 1, pipe.vae.config.z_dim, 1, 1, 1).to(ref_latents.device, ref_latents.dtype)
1205
+ latents_std = 1.0 / torch.tensor(pipe.vae.config.latents_std).view(
1206
+ 1, pipe.vae.config.z_dim, 1, 1, 1).to(ref_latents.device, ref_latents.dtype)
1207
  ref_latents = (ref_latents - latents_mean) * latents_std
1208
 
1209
  # Load mask video
 
1227
  else:
1228
  mask_t1_hw = (mask_tc_hw > 0.5).float()
1229
 
1230
+ motion_mask = ttm_helper.convert_rgb_mask_to_latent_mask(
1231
+ mask_t1_hw).permute(0, 2, 1, 3, 4).contiguous()
1232
  background_mask = 1.0 - motion_mask
1233
 
1234
  progress(0.35, desc="Initializing TTM denoising...")
 
1242
  device=ref_latents.device,
1243
  dtype=ref_latents.dtype,
1244
  )
1245
+ tweak_t = torch.as_tensor(
1246
+ tweak, device=ref_latents.device, dtype=torch.long).view(1)
1247
+ noisy_latents = pipe.scheduler.add_noise(
1248
+ ref_latents, fixed_noise, tweak_t.long())
1249
+ latents = noisy_latents.to(
1250
+ dtype=latents.dtype, device=latents.device)
1251
  else:
1252
  fixed_noise = randn_tensor(
1253
  ref_latents.shape,
 
1264
 
1265
  for i, t in enumerate(timesteps[tweak_index:]):
1266
  step_progress = 0.4 + 0.5 * (i / total_steps)
1267
+ progress(step_progress,
1268
+ desc=f"Denoising step {i+1}/{total_steps}...")
1269
 
1270
  # Prepare model input
1271
  if first_frame_mask is not None:
1272
+ latent_model_input = (1 - first_frame_mask) * \
1273
+ condition + first_frame_mask * latents
1274
  latent_model_input = latent_model_input.to(transformer_dtype)
1275
  temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten()
1276
  timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
1277
  else:
1278
+ latent_model_input = torch.cat(
1279
+ [latents, condition], dim=1).to(transformer_dtype)
1280
  timestep = t.expand(latents.shape[0])
1281
 
1282
  # Predict noise (conditional)
 
1297
  encoder_hidden_states_image=image_embeds,
1298
  return_dict=False,
1299
  )[0]
1300
+ noise_pred = noise_uncond + guidance_scale * \
1301
+ (noise_pred - noise_uncond)
1302
 
1303
  # Scheduler step
1304
+ latents = pipe.scheduler.step(
1305
+ noise_pred, t, latents, return_dict=False)[0]
1306
 
1307
  # TTM: In between tweak and tstrong, replace mask with noisy reference latents
1308
  in_between_tweak_tstrong = (i + tweak_index) < tstrong_index
 
1310
  if in_between_tweak_tstrong:
1311
  if i + tweak_index + 1 < len(timesteps):
1312
  prev_t = timesteps[i + tweak_index + 1]
1313
+ prev_t = torch.as_tensor(
1314
+ prev_t, device=ref_latents.device, dtype=torch.long).view(1)
1315
  noisy_latents = pipe.scheduler.add_noise(ref_latents, fixed_noise, prev_t.long()).to(
1316
  dtype=latents.dtype, device=latents.device
1317
  )
1318
  latents = latents * background_mask + noisy_latents * motion_mask
1319
  else:
1320
+ latents = latents * background_mask + \
1321
+ ref_latents.to(dtype=latents.dtype,
1322
+ device=latents.device) * motion_mask
1323
 
1324
  progress(0.9, desc="Decoding video...")
1325
 
1326
  # Apply first frame mask if used
1327
  if first_frame_mask is not None:
1328
+ latents = (1 - first_frame_mask) * condition + \
1329
+ first_frame_mask * latents
1330
 
1331
  # Decode latents
1332
  latents = latents.to(pipe.vae.dtype)
1333
+ latents_mean = torch.tensor(pipe.vae.config.latents_mean).view(
1334
+ 1, pipe.vae.config.z_dim, 1, 1, 1).to(latents.device, latents.dtype)
1335
+ latents_std = 1.0 / torch.tensor(pipe.vae.config.latents_std).view(
1336
+ 1, pipe.vae.config.z_dim, 1, 1, 1).to(latents.device, latents.dtype)
1337
  latents = latents / latents_std + latents_mean
1338
  video = pipe.vae.decode(latents, return_dict=False)[0]
1339
+ video = ttm_helper.video_processor.postprocess_video(
1340
+ video, output_type="pil")
1341
 
1342
  progress(0.95, desc="Saving video...")
1343
 
 
1407
 
1408
 
1409
  # Create Gradio interface
1410
+ logger.info("🎨 Creating Gradio interface...")
1411
+ sys.stdout.flush()
1412
 
1413
  with gr.Blocks(
1414
  theme=gr.themes.Soft(),
 
1465
  info="Generate motion_signal.mp4 and mask.mp4 for Time-to-Move"
1466
  )
1467
 
1468
+ generate_btn = gr.Button(
1469
+ "🚀 Generate Motion Signal", variant="primary", size="lg")
1470
 
1471
  with gr.Column(scale=1):
1472
  gr.Markdown("### 📤 Rendered Output")
 
1602
  label="TTM Generated Video",
1603
  height=400
1604
  )
1605
+ ttm_status_text = gr.Markdown(
1606
+ "Upload a video in Step 1 first, then run TTM here.")
1607
 
1608
  # TTM Input preview
1609
  with gr.Accordion("📁 TTM Input Files (from Step 1)", open=False):
 
1623
 
1624
  # Helper function to update states and preview
1625
  def process_and_update_states(video_path, camera_movement, generate_ttm_flag, progress=gr.Progress()):
1626
+ result = process_video(video_path, camera_movement,
1627
+ generate_ttm_flag, progress)
1628
  output_vid, motion_sig, mask_vid, first_frame, status = result
1629
 
1630
  # Return all outputs including state updates and previews
 
1676
  # Examples
1677
  gr.Markdown("### 📁 Examples")
1678
  if os.path.exists("./examples"):
1679
+ example_videos = [f for f in os.listdir(
1680
+ "./examples") if f.endswith(".mp4")][:4]
1681
  if example_videos:
1682
  gr.Examples(
1683
+ examples=[[f"./examples/{v}", "move_forward", True]
1684
+ for v in example_videos],
1685
  inputs=[video_input, camera_movement, generate_ttm],
1686
  outputs=[
1687
  output_video, motion_signal_output, mask_output, first_frame_output, status_text,
 
1693
  )
1694
 
1695
  # Launch
1696
+ logger.info("✅ Gradio interface created successfully!")
1697
+ logger.info("=" * 50)
1698
+ logger.info("Application ready to launch")
1699
+ logger.info("=" * 50)
1700
+ sys.stdout.flush()
1701
+
1702
  if __name__ == "__main__":
1703
+ logger.info("Starting Gradio server...")
1704
+ sys.stdout.flush()
1705
  demo.launch(share=False)