abreza committed on
Commit
d29eac9
·
1 Parent(s): e05bac0
Files changed (1) hide show
  1. app.py +326 -150
app.py CHANGED
@@ -19,225 +19,401 @@ from concurrent.futures import ThreadPoolExecutor
19
  import atexit
20
  import uuid
21
  import decord
22
- from PIL import Image
23
-
24
- try:
25
- from pipelines.wan_pipeline import WanImageToVideoTTMPipeline
26
- from pipelines.utils import compute_hw_from_area, validate_inputs
27
- from diffusers.utils import export_to_video, load_image
28
- except ImportError:
29
- print("Warning: TTM pipelines not found. Ensure the /pipelines folder is in your path.")
30
 
31
  from models.SpaTrackV2.models.vggt4track.models.vggt_moe import VGGT4Track
32
  from models.SpaTrackV2.models.vggt4track.utils.load_fn import preprocess_image
33
  from models.SpaTrackV2.models.predictor import Predictor
34
  from models.SpaTrackV2.models.utils import get_points_on_a_grid
35
 
 
 
 
 
 
 
 
36
  # Configure logging
37
  logging.basicConfig(level=logging.INFO)
38
  logger = logging.getLogger(__name__)
39
 
40
  # Constants
41
# --- Tunables for the legacy pipeline ---
MAX_FRAMES = 81          # cap on frames sampled from the source video
OUTPUT_FPS = 16          # fps for rendered point-cloud videos
WAN_MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"  # HF hub id of the Wan I2V model
DTYPE = torch.bfloat16   # compute dtype for the Wan pipeline
45
 
46
# --- Global Model Initialization ---
print("🚀 Initializing models...")

# Front-end geometry model (depth / intrinsics / camera poses); kept on GPU.
vggt4track_model = VGGT4Track.from_pretrained("Yuxihenry/SpatialTrackerV2_Front")
vggt4track_model.eval().to("cuda")

# Offline point tracker; moved to GPU only when a request comes in.
tracker_model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
tracker_model.eval()

# Lazy loading for Wan to save VRAM initially
wan_pipe = None

def get_wan_pipeline():
    """Load the Wan 2.2 TTM pipeline once and cache it in the module global."""
    global wan_pipe
    if wan_pipe is None:
        print("🚀 Initializing Wan 2.2 TTM Pipeline...")
        wan_pipe = WanImageToVideoTTMPipeline.from_pretrained(WAN_MODEL_ID, torch_dtype=DTYPE)
        # Tiling/slicing trade a little speed for a much smaller VRAM peak.
        wan_pipe.vae.enable_tiling()
        wan_pipe.vae.enable_slicing()
        wan_pipe.to("cuda")
    return wan_pipe
66
-
67
# --- Utility Functions ---
def delete_later(path, delay=600):
    """Schedule best-effort deletion of *path* (file or directory) after *delay* seconds.

    Fixes two defects of the original: it created a brand-new
    ThreadPoolExecutor per call whose single worker blocked in time.sleep for
    the whole delay (threads accumulate), and its bare ``except: pass``
    swallowed even SystemExit/KeyboardInterrupt.  A daemon ``threading.Timer``
    provides the same delayed cleanup without holding a pool thread, and only
    OS-level failures are suppressed.
    """
    import threading  # local import keeps this a drop-in replacement

    def _wait_and_delete():
        try:
            if os.path.isfile(path):
                os.remove(path)
            elif os.path.isdir(path):
                shutil.rmtree(path)
        except OSError:
            # Best-effort: the path may already have been removed.
            pass

    timer = threading.Timer(delay, _wait_and_delete)
    timer.daemon = True  # never block interpreter shutdown
    timer.start()

def create_user_temp_dir():
    """Create a per-session scratch dir under temp_local/ and schedule its removal."""
    session_id = str(uuid.uuid4())[:8]
    temp_dir = os.path.join("temp_local", f"session_{session_id}")
    os.makedirs(temp_dir, exist_ok=True)
    delete_later(temp_dir)
    return temp_dir
83
 
84
def generate_camera_trajectory(num_frames, movement_type, base_intrinsics, scene_scale=1.0):
    """Build per-frame 4x4 extrinsic offsets for a simple linear camera move.

    The translation grows linearly with the frame index, scaled by
    ``scene_scale`` so the motion magnitude adapts to scene depth.
    ``base_intrinsics`` is accepted for interface compatibility but unused.
    Unknown movement types (e.g. "static") yield identity matrices.
    """
    step = scene_scale * 0.02
    # (translation-axis index, direction sign) per movement type.
    offsets = {
        "move_forward": (2, -1.0),
        "move_backward": (2, 1.0),
        "move_left": (0, -1.0),
        "move_right": (0, 1.0),
        "move_up": (1, -1.0),
        "move_down": (1, 1.0),
    }
    extrinsics = np.zeros((num_frames, 4, 4), dtype=np.float32)
    for frame_idx in range(num_frames):
        pose = np.eye(4, dtype=np.float32)
        if movement_type in offsets:
            axis, sign = offsets[movement_type]
            pose[axis, 3] = sign * step * frame_idx
        extrinsics[frame_idx] = pose
    return extrinsics
97
 
98
def render_from_pointcloud(rgb_frames, depth_frames, intrinsics, original_extrinsics, new_extrinsics, output_path, generate_ttm_inputs=True):
    """Re-render an RGB-D sequence from a new camera trajectory via point splatting.

    Args:
        rgb_frames: (T, H, W, 3) uint8 frames.
        depth_frames: (T, H, W) per-pixel depth aligned with rgb_frames.
        intrinsics: (T, 3, 3) camera intrinsics per frame.
        original_extrinsics: (T, 4, 4) world-to-camera matrices of the source.
        new_extrinsics: (T, 4, 4) offsets applied relative to the first pose.
        output_path: destination mp4 of the hole-filled render.
        generate_ttm_inputs: also write motion_signal.mp4 / mask.mp4 beside
            output_path.  The original accepted this flag but ignored it and
            always wrote the extra files; it is honored now (default keeps
            the old effective behavior).

    Returns:
        dict with 'rendered', 'motion_signal', 'mask' paths (latter two are
        None when generate_ttm_inputs is False).

    The original leaked the cv2.VideoWriter handles if any exception occurred
    mid-render; releases now happen in a ``finally`` block.
    """
    T, H, W, _ = rgb_frames.shape
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, OUTPUT_FPS, (W, H))

    motion_signal_path = mask_path = None
    out_motion = out_mask = None
    if generate_ttm_inputs:
        motion_signal_path = os.path.join(os.path.dirname(output_path), "motion_signal.mp4")
        mask_path = os.path.join(os.path.dirname(output_path), "mask.mp4")
        out_motion = cv2.VideoWriter(motion_signal_path, fourcc, OUTPUT_FPS, (W, H))
        out_mask = cv2.VideoWriter(mask_path, fourcc, OUTPUT_FPS, (W, H))

    u, v = np.meshgrid(np.arange(W), np.arange(H))
    try:
        for t in range(T):
            orig_c2w = np.linalg.inv(original_extrinsics[t])
            if t == 0:
                base_c2w = orig_c2w.copy()  # anchor new trajectory at frame 0
            new_c2w = base_c2w @ new_extrinsics[t]
            new_w2c = np.linalg.inv(new_c2w)

            # Back-project every pixel into world space using the source depth.
            K_inv = np.linalg.inv(intrinsics[t])
            pixels = np.stack([u, v, np.ones_like(u)], axis=-1).reshape(-1, 3)
            rays_cam = (K_inv @ pixels.T).T
            points_cam = rays_cam * depth_frames[t].reshape(-1, 1)
            points_world = (orig_c2w[:3, :3] @ points_cam.T).T + orig_c2w[:3, 3]
            points_new_cam = (new_w2c[:3, :3] @ points_world.T).T + new_w2c[:3, 3]
            points_proj = (intrinsics[t] @ points_new_cam.T).T
            uv_new = points_proj[:, :2] / np.clip(points_proj[:, 2:3], 1e-6, None)

            # Z-buffered splat: nearest point wins each target pixel.
            rendered = np.zeros((H, W, 3), dtype=np.uint8)
            z_buf = np.full((H, W), np.inf)
            colors = rgb_frames[t].reshape(-1, 3)
            for i in range(len(uv_new)):
                uu, vv = int(round(uv_new[i, 0])), int(round(uv_new[i, 1]))
                if 0 <= uu < W and 0 <= vv < H and points_new_cam[i, 2] > 0:
                    if points_new_cam[i, 2] < z_buf[vv, uu]:
                        z_buf[vv, uu] = points_new_cam[i, 2]
                        rendered[vv, uu] = colors[i]

            valid_mask = (rendered.sum(axis=-1) > 0).astype(np.uint8) * 255

            # Hole filling for motion signal: bounded iterative dilation
            # (nearest-neighbour inpainting).
            motion_frame = rendered.copy()
            hole_mask = (motion_frame.sum(axis=-1) == 0).astype(np.uint8)
            if hole_mask.sum() > 0:
                for _ in range(10):
                    dilated = cv2.dilate(motion_frame, np.ones((3, 3), np.uint8))
                    motion_frame = np.where(hole_mask[:, :, None] > 0, dilated, motion_frame)
                    hole_mask = (motion_frame.sum(axis=-1) == 0).astype(np.uint8)
                    if hole_mask.sum() == 0:
                        break

            if generate_ttm_inputs:
                out_motion.write(cv2.cvtColor(motion_frame, cv2.COLOR_RGB2BGR))
                out_mask.write(cv2.merge([valid_mask, valid_mask, valid_mask]))
            out.write(cv2.cvtColor(motion_frame, cv2.COLOR_RGB2BGR))
    finally:
        # Always release so partially written files are flushed to disk.
        out.release()
        if out_motion is not None:
            out_motion.release()
        if out_mask is not None:
            out_mask.release()

    return {'rendered': output_path, 'motion_signal': motion_signal_path, 'mask': mask_path}
151
 
152
# --- Main Processing Logic ---
def run_ttm_wan_inference(image_path, motion_path, mask_path, prompt, tweak_idx, tstrong_idx, guidance_scale, seed=0):
    """Run the Wan 2.2 TTM pipeline conditioned on a first frame plus
    motion-signal / mask guidance videos; return the generated frame list."""
    pipe = get_wan_pipeline()

    # Resize the conditioning image so H*W <= 480*832 and both sides snap to
    # the stride required by the VAE / transformer patching.
    image = load_image(image_path)
    max_area = 480 * 832
    mod_val = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
    h, w = compute_hw_from_area(image.height, image.width, max_area, mod_val)
    image = image.resize((w, h))

    generator = torch.Generator(device="cuda").manual_seed(seed)
    with torch.inference_mode():
        result = pipe(
            image=image,
            prompt=prompt,
            height=h,
            width=w,
            num_frames=81,
            guidance_scale=guidance_scale,
            num_inference_steps=50,
            generator=generator,
            motion_signal_video_path=motion_path,
            motion_signal_mask_path=mask_path,
            tweak_index=tweak_idx,
            tstrong_index=tstrong_idx,
            negative_prompt="blurry, static, low quality",
        )
    return result.frames[0]
170
 
171
def process_video_full_pipeline(video_path, camera_movement, prompt, tweak_idx, tstrong_idx, guidance_scale, progress=gr.Progress()):
    """Full legacy pipeline: 3D tracking -> point-cloud re-render -> Wan TTM.

    Returns (warp_video, wan_video, motion_signal, mask, first_frame, status).
    Failures are caught and reported through the status string instead of
    crashing the Gradio worker (the original let exceptions propagate).
    """
    if not video_path or not prompt:
        return [None]*5 + ["❌ Missing video or prompt"]

    temp_dir = create_user_temp_dir()
    res_dir = os.path.join(temp_dir, "results")
    os.makedirs(res_dir, exist_ok=True)

    try:
        # 1. Spatial Tracking: recover depth, intrinsics and camera poses.
        progress(0.1, desc="3D Analysis...")
        vr = decord.VideoReader(video_path)
        vt = torch.from_numpy(vr.get_batch(range(len(vr))).asnumpy()).permute(0, 3, 1, 2).float()
        vt = vt[::max(1, len(vt)//MAX_FRAMES)][:MAX_FRAMES]  # uniform temporal subsample

        # Preprocess for VGGT
        v_in = preprocess_image(vt)[None].cuda()
        with torch.no_grad():
            preds = vggt4track_model(v_in / 255)

        # Tracker seeded on a 30x30 query grid at frame 0.
        tracker_model.to("cuda")
        grid = get_points_on_a_grid(30, v_in.shape[3:], device="cpu")
        queries = torch.cat([torch.zeros_like(grid[:, :, :1]), grid], dim=2)[0].numpy()

        c2w, intrs, p_map, c_depth, _, _, _, _, v_out = tracker_model.forward(
            v_in.squeeze(), depth=preds["points_map"][..., 2].squeeze().cpu().numpy(),
            intrs=preds["intrs"].squeeze().cpu().numpy(), extrs=preds["poses_pred"].squeeze().cpu().numpy(),
            queries=queries, fps=1, iters_track=4, fixed_cam=False
        )

        # 2. Point-cloud re-rendering under the requested camera move.
        progress(0.6, desc="Rendering Point Cloud...")
        rgb = rearrange(v_out.cpu().numpy(), "T C H W -> T H W C").astype(np.uint8)
        depth = p_map[0, 2].cpu().numpy()  # Simplified for single view context
        new_ext = generate_camera_trajectory(len(rgb), camera_movement, intrs.cpu().numpy(), np.median(depth[depth > 0]))

        rend_path = os.path.join(res_dir, "warp.mp4")
        rend_res = render_from_pointcloud(rgb, p_map[:, 2].cpu().numpy(), intrs.cpu().numpy(),
                                          torch.inverse(c2w).cpu().numpy(), new_ext, rend_path)

        first_frame_path = os.path.join(res_dir, "first.png")
        cv2.imwrite(first_frame_path, cv2.cvtColor(rgb[0], cv2.COLOR_RGB2BGR))

        # 3. Wan TTM Inference on the rendered guidance videos.
        progress(0.8, desc="Wan 2.2 Realistic Generation...")
        wan_video_path = os.path.join(res_dir, "final_wan.mp4")
        wan_frames = run_ttm_wan_inference(first_frame_path, rend_res['motion_signal'], rend_res['mask'],
                                           prompt, tweak_idx, tstrong_idx, guidance_scale)
        export_to_video(wan_frames, wan_video_path, fps=16)

        return rend_path, wan_video_path, rend_res['motion_signal'], rend_res['mask'], first_frame_path, "✅ Generated successfully!"
    except Exception as e:
        # logger.exception records the full traceback for debugging.
        logger.exception("Pipeline failed")
        return [None]*5 + [f"❌ Error: {e}"]
 
 
218
 
219
# --- Gradio UI ---
with gr.Blocks(theme=gr.themes.Soft(), title="Wan 2.2 TTM Video Generator") as demo:
    gr.Markdown("# 🎬 Time-to-Move (TTM) with Wan 2.2")

    with gr.Row():
        with gr.Column():
            # Input controls
            v_in = gr.Video(label="Source Video")
            p_in = gr.Textbox(label="Prompt", placeholder="Describe the action...")
            c_in = gr.Dropdown(
                choices=["move_forward", "move_backward", "move_left", "move_right", "move_up", "move_down", "static"],
                value="move_forward",
                label="Camera Movement",
            )
            with gr.Accordion("TTM Settings", open=False):
                twk = gr.Slider(0, 15, value=3, label="Tweak Index")
                strng = gr.Slider(0, 20, value=7, label="Tstrong Index")
                cfg = gr.Slider(1, 10, value=5.0, label="CFG Scale")
            btn = gr.Button("Generate Realistic Video", variant="primary")

        with gr.Column():
            # Result displays
            v_final = gr.Video(label="Final Realistic Result")
            v_warp = gr.Video(label="Point Cloud Warp (Guide)")
            with gr.Row():
                v_msig = gr.Video(label="Motion Signal")
                v_mask = gr.Video(label="Mask")

    btn.click(
        process_video_full_pipeline,
        [v_in, c_in, p_in, twk, strng, cfg],
        [v_warp, v_final, v_msig, v_mask, gr.Image(visible=False), gr.Markdown()],
    )

demo.launch()
 
 
19
  import atexit
20
  import uuid
21
  import decord
 
 
 
 
 
 
 
 
22
 
23
  from models.SpaTrackV2.models.vggt4track.models.vggt_moe import VGGT4Track
24
  from models.SpaTrackV2.models.vggt4track.utils.load_fn import preprocess_image
25
  from models.SpaTrackV2.models.predictor import Predictor
26
  from models.SpaTrackV2.models.utils import get_points_on_a_grid
27
 
28
+
29
+ # --- TTM SPECIFIC IMPORTS ---
30
+ from diffusers.utils import export_to_video, load_image
31
+ # Note: Ensure pipelines/wan_pipeline.py and pipelines/utils.py are in your working directory
32
+ from pipelines.wan_pipeline import WanImageToVideoTTMPipeline
33
+ from pipelines.utils import compute_hw_from_area, validate_inputs
34
+
35
  # Configure logging
36
  logging.basicConfig(level=logging.INFO)
37
  logger = logging.getLogger(__name__)
38
 
39
  # Constants
40
MAX_FRAMES = 80      # cap on frames sampled from the uploaded video
OUTPUT_FPS = 24      # fps of the rendered point-cloud videos
RENDER_WIDTH = 512   # nominal render resolution (declared for reference)
RENDER_HEIGHT = 384
WAN_MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"

# Camera movement types offered in the UI; "static" keeps the pose fixed.
CAMERA_MOVEMENTS = [
    "static",
    "move_forward",
    "move_backward",
    "move_left",
    "move_right",
    "move_up",
    "move_down",
]

# Shared pool used by delete_later to schedule delayed temp-file cleanup.
thread_pool_executor = ThreadPoolExecutor(max_workers=2)
59
 
 
 
60
 
61
def delete_later(path: Union[str, os.PathLike], delay: int = 600):
    """Best-effort removal of *path* after *delay* seconds, and again at exit.

    A shared pool worker sleeps for the delay and then deletes; the same
    cleanup is also registered with atexit so the path is removed at
    interpreter shutdown even if the timer has not fired yet.  Failures are
    logged, never raised.
    """
    def _delete():
        # Swallow and log any failure: the path may already be gone.
        try:
            if os.path.isfile(path):
                os.remove(path)
            elif os.path.isdir(path):
                shutil.rmtree(path)
        except Exception as e:
            logger.warning(f"Failed to delete {path}: {e}")

    def _wait_and_delete():
        time.sleep(delay)
        _delete()

    thread_pool_executor.submit(_wait_and_delete)
    atexit.register(_delete)
77
+
78
 
79
def create_user_temp_dir():
    """Make a uniquely named per-session directory under temp_local/.

    The directory is scheduled for deletion after ten minutes so abandoned
    sessions do not pile up on disk.
    """
    token = uuid.uuid4()
    temp_dir = os.path.join("temp_local", f"session_{str(token)[:8]}")
    os.makedirs(temp_dir, exist_ok=True)
    delete_later(temp_dir, delay=600)
    return temp_dir
85
 
86
+
87
# Global model initialization for Spatial Tracker
print("🚀 Initializing tracking models...")

# Front-end geometry network: depth, intrinsics and pose priors. Kept on GPU.
vggt4track_model = VGGT4Track.from_pretrained("Yuxihenry/SpatialTrackerV2_Front")
vggt4track_model.eval()
vggt4track_model = vggt4track_model.to("cuda")

# Offline point tracker; moved to GPU on demand inside run_spatial_tracker.
tracker_model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
tracker_model.eval()

# The Wan pipeline is heavy, so it is loaded lazily on first use.
wan_pipeline = None
100
+
101
+
102
def get_wan_pipeline():
    """Return the cached Wan TTM pipeline, loading it on first call.

    Lazy loading keeps GPU memory free until the user actually requests a
    TTM generation; VAE tiling/slicing reduce peak VRAM during decoding.
    """
    global wan_pipeline
    if wan_pipeline is not None:
        return wan_pipeline

    print("🚀 Loading Wan TTM Pipeline (14B)...")
    pipe = WanImageToVideoTTMPipeline.from_pretrained(
        WAN_MODEL_ID,
        torch_dtype=torch.bfloat16,
    )
    pipe.vae.enable_tiling()
    pipe.vae.enable_slicing()
    pipe.to("cuda")
    wan_pipeline = pipe
    return wan_pipeline
114
+
115
+
116
print("✅ Tracking models loaded successfully!")

# Allow Gradio to serve visualization assets straight from the _viz directory.
gr.set_static_paths(paths=[Path.cwd().absolute()/"_viz"])
122
+
123
+
124
def generate_camera_trajectory(num_frames: int, movement_type: str, base_intrinsics: np.ndarray, scene_scale: float = 1.0) -> np.ndarray:
    """Build (num_frames, 4, 4) camera extrinsic offsets for a linear move.

    The translation grows linearly with the frame index, scaled by
    ``scene_scale`` so motion magnitude adapts to scene depth.
    ``base_intrinsics`` is kept for interface compatibility but is unused.
    Unknown movement types (including "static") yield identity matrices.

    Note: the original annotated the return type as ``tuple``; the function
    has always returned a single numpy array — annotation corrected.
    """
    speed = scene_scale * 0.02
    # (translation-axis index, direction sign) of each supported movement.
    directions = {
        "move_forward": (2, -1.0),
        "move_backward": (2, 1.0),
        "move_left": (0, -1.0),
        "move_right": (0, 1.0),
        "move_up": (1, -1.0),
        "move_down": (1, 1.0),
    }
    extrinsics = np.zeros((num_frames, 4, 4), dtype=np.float32)
    for t in range(num_frames):
        ext = np.eye(4, dtype=np.float32)
        direction = directions.get(movement_type)
        if direction is not None:
            axis, sign = direction
            ext[axis, 3] = sign * speed * t
        extrinsics[t] = ext
    return extrinsics
143
 
 
 
 
144
 
145
def render_from_pointcloud(rgb_frames, depth_frames, intrinsics, original_extrinsics, new_extrinsics, output_path, fps=24, generate_ttm_inputs=False):
    """Splat an RGB-D sequence into a new camera trajectory.

    Args:
        rgb_frames: (T, H, W, 3) uint8 frames.
        depth_frames: (T, H, W) depth maps aligned with rgb_frames (0 = invalid).
        intrinsics: (T, 3, 3) per-frame camera intrinsics.
        original_extrinsics: (T, 4, 4) world-to-camera matrices of the source.
        new_extrinsics: (T, 4, 4) offsets applied relative to the first pose.
        output_path: destination mp4 of the hole-filled render.
        fps: frame rate for all written videos.
        generate_ttm_inputs: additionally write motion_signal.mp4 and mask.mp4
            beside output_path for the Wan TTM stage.

    Returns:
        dict with 'rendered', 'motion_signal' and 'mask' paths (the latter two
        are None when generate_ttm_inputs is False).

    Fix: the original leaked the cv2.VideoWriter handles if any exception
    occurred mid-render; releases now happen in a ``finally`` block so
    partially written files are flushed.
    """
    T, H, W, _ = rgb_frames.shape
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (W, H))

    motion_signal_path = mask_path = out_motion_signal = out_mask = None
    if generate_ttm_inputs:
        base_dir = os.path.dirname(output_path)
        motion_signal_path = os.path.join(base_dir, "motion_signal.mp4")
        mask_path = os.path.join(base_dir, "mask.mp4")
        out_motion_signal = cv2.VideoWriter(motion_signal_path, fourcc, fps, (W, H))
        out_mask = cv2.VideoWriter(mask_path, fourcc, fps, (W, H))

    u, v = np.meshgrid(np.arange(W), np.arange(H))
    try:
        for t in range(T):
            rgb, depth, K = rgb_frames[t], depth_frames[t], intrinsics[t]
            orig_c2w = np.linalg.inv(original_extrinsics[t])
            if t == 0:
                base_c2w = orig_c2w.copy()  # anchor the new trajectory at frame 0
            new_c2w = base_c2w @ new_extrinsics[t]
            new_w2c = np.linalg.inv(new_c2w)
            K_inv = np.linalg.inv(K)

            # Back-project each pixel into world space via the source depth.
            pixels = np.stack([u, v, np.ones_like(u)], axis=-1).reshape(-1, 3)
            rays_cam = (K_inv @ pixels.T).T
            points_cam = rays_cam * depth.reshape(-1, 1)
            points_world = (orig_c2w[:3, :3] @ points_cam.T).T + orig_c2w[:3, 3]
            points_new_cam = (new_w2c[:3, :3] @ points_world.T).T + new_w2c[:3, 3]
            points_proj = (K @ points_new_cam.T).T
            z = np.clip(points_proj[:, 2:3], 1e-6, None)  # guard divide-by-zero
            uv_new = points_proj[:, :2] / z

            # Z-buffered splat: the nearest point wins each target pixel.
            rendered = np.zeros((H, W, 3), dtype=np.uint8)
            z_buffer = np.full((H, W), np.inf, dtype=np.float32)
            colors, depths_new = rgb.reshape(-1, 3), points_new_cam[:, 2]
            for i in range(len(uv_new)):
                uu, vv = int(round(uv_new[i, 0])), int(round(uv_new[i, 1]))
                if 0 <= uu < W and 0 <= vv < H and depths_new[i] > 0:
                    if depths_new[i] < z_buffer[vv, uu]:
                        z_buffer[vv, uu] = depths_new[i]
                        rendered[vv, uu] = colors[i]

            valid_mask = (rendered.sum(axis=-1) > 0).astype(np.uint8) * 255

            # Nearest-neighbour hole filling via bounded iterative dilation.
            motion_signal_frame = rendered.copy()
            hole_mask = (motion_signal_frame.sum(axis=-1) == 0).astype(np.uint8)
            if hole_mask.sum() > 0:
                kernel = np.ones((3, 3), np.uint8)
                for _ in range(10):  # Iterative fill
                    if hole_mask.sum() == 0:
                        break
                    dilated = cv2.dilate(motion_signal_frame, kernel)
                    motion_signal_frame = np.where(
                        hole_mask[:, :, None] > 0, dilated, motion_signal_frame)
                    hole_mask = (motion_signal_frame.sum(axis=-1) == 0).astype(np.uint8)

            if generate_ttm_inputs:
                out_motion_signal.write(cv2.cvtColor(motion_signal_frame, cv2.COLOR_RGB2BGR))
                out_mask.write(np.stack([valid_mask]*3, axis=-1))
            out.write(cv2.cvtColor(motion_signal_frame, cv2.COLOR_RGB2BGR))
    finally:
        out.release()
        if out_motion_signal is not None:
            out_motion_signal.release()
        if out_mask is not None:
            out_mask.release()

    return {'rendered': output_path, 'motion_signal': motion_signal_path, 'mask': mask_path}
212
 
 
 
 
 
 
 
 
 
213
 
214
@spaces.GPU
def run_spatial_tracker(video_tensor):
    """Run VGGT + SpatialTracker on a (T, C, H, W) float video tensor.

    Returns a dict of CPU tensors: re-rendered video, 3D point map, depth
    confidence, per-frame intrinsics and the camera-to-world trajectory.

    Fix: the first stage used the deprecated ``torch.cuda.amp.autocast``
    while the second stage already used ``torch.amp.autocast`` — both now
    use the modern, consistent API.
    """
    video_input = preprocess_image(video_tensor)[None].cuda()

    # Stage 1: monocular geometry priors (depth, intrinsics, poses).
    with torch.no_grad():
        with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
            predictions = vggt4track_model(video_input / 255)
            extrinsic, intrinsic = predictions["poses_pred"], predictions["intrs"]
            depth_map = predictions["points_map"][..., 2]
            depth_conf = predictions["unc_metric"]

    depth_tensor = depth_map.squeeze().cpu().numpy()
    extrs = extrinsic.squeeze().cpu().numpy()
    intrs = intrinsic.squeeze().cpu().numpy()
    unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5  # keep confident depth only

    # Stage 2: dense point tracking seeded on a 30x30 query grid at frame 0.
    tracker_model.spatrack.track_num = 512
    tracker_model.to("cuda")
    grid_pts = get_points_on_a_grid(
        30, (video_input.shape[3], video_input.shape[4]), device="cpu")
    query_xyt = torch.cat(
        [torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[0].numpy()

    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        c2w_traj, intrs_out, point_map, conf_depth, _, _, _, _, video_out = tracker_model.forward(
            video_input.squeeze(), depth=depth_tensor, intrs=intrs, extrs=extrs, queries=query_xyt,
            fps=1, unc_metric=unc_metric, support_frame=len(video_input.squeeze())-1
        )
    return {'video_out': video_out.cpu(), 'point_map': point_map.cpu(), 'conf_depth': conf_depth.cpu(), 'intrs_out': intrs_out.cpu(), 'c2w_traj': c2w_traj.cpu()}
240
 
241
+ # --- TTM WAN INFERENCE FUNCTION ---
 
242
 
 
 
243
 
244
@spaces.GPU
def run_wan_ttm_generation(prompt, tweak_index, tstrong_index, first_frame_path, motion_video_path, mask_video_path, seed=0, progress=gr.Progress()):
    """Generate the final TTM video with Wan 2.2.

    Args:
        prompt: positive text prompt.
        tweak_index / tstrong_index: TTM schedule indices (cast to int).
        first_frame_path: conditioning image produced by the tracking stage.
        motion_video_path / mask_video_path: TTM guidance videos.
        seed: RNG seed for reproducible sampling (new parameter; defaults to
            the previously hard-coded 0 so existing callers are unaffected).
        progress: Gradio progress reporter (injected by Gradio).

    Returns:
        (output_video_path, status_message); the path is None on bad input.
    """
    if not first_frame_path or not motion_video_path or not mask_video_path:
        return None, "❌ TTM Inputs missing. Please run 3D tracking first."

    progress(0, desc="Loading Wan TTM Pipeline...")
    pipe = get_wan_pipeline()

    progress(0.2, desc="Preparing inputs...")
    image = load_image(first_frame_path)

    # Standard Wan Negative Prompt
    negative_prompt = (
        "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,"
        "低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,"
        "毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
    )

    # Match resolution logic from run_wan.py: area-bounded resize snapped to
    # the VAE / transformer patch stride.
    max_area = 480 * 832
    mod_value = pipe.vae_scale_factor_spatial * \
        pipe.transformer.config.patch_size[1]
    height, width = compute_hw_from_area(
        image.height, image.width, max_area, mod_value)
    image = image.resize((width, height))

    progress(0.4, desc="Generating Video (this may take a few minutes)...")
    generator = torch.Generator(device="cuda").manual_seed(seed)

    with torch.inference_mode():
        result = pipe(
            image=image,
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            num_frames=81,  # Wan default
            guidance_scale=3.5,
            num_inference_steps=50,
            generator=generator,
            motion_signal_video_path=motion_video_path,
            motion_signal_mask_path=mask_video_path,
            tweak_index=int(tweak_index),
            tstrong_index=int(tstrong_index),
        )

    output_path = os.path.join(os.path.dirname(first_frame_path), "wan_ttm_output.mp4")
    export_to_video(result.frames[0], output_path, fps=16)

    return output_path, "✅ TTM Video generated successfully with Wan 2.2!"
 
295
 
296
+ # --- MODIFIED PROCESS VIDEO TO RETURN FILE PATHS ---
 
297
 
 
 
 
 
 
298
 
299
def process_video(video_path, camera_movement, generate_ttm=True, progress=gr.Progress()):
    """Stage 1 of the app: decode video, run 3D tracking, re-render viewpoint.

    Returns (rendered_video, motion_signal, mask, first_frame, status); all
    file paths except the status string, or Nones plus an error message on
    failure.

    Fix: errors are now logged with ``logger.exception`` so the traceback is
    preserved (``logger.error(f"...")`` dropped it).
    """
    if video_path is None:
        return None, None, None, None, "❌ Please upload a video first"

    progress(0, desc="Initializing...")
    temp_dir = create_user_temp_dir()
    out_dir = os.path.join(temp_dir, "results")
    os.makedirs(out_dir, exist_ok=True)

    try:
        # Decode and uniformly subsample to at most MAX_FRAMES frames.
        progress(0.1, desc="Loading video...")
        video_reader = decord.VideoReader(video_path)
        video_tensor = torch.from_numpy(
            video_reader.get_batch(range(len(video_reader))).asnumpy()
        ).permute(0, 3, 1, 2).float()
        video_tensor = video_tensor[::max(1, len(video_tensor)//MAX_FRAMES)][:MAX_FRAMES]

        # Downscale so the short side is ~336 px; dimensions kept even.
        h, w = video_tensor.shape[2:]
        scale = 336 / min(h, w)
        if scale < 1:
            video_tensor = T.Resize(
                (int(h*scale)//2*2, int(w*scale)//2*2))(video_tensor)

        progress(0.4, desc="Running 3D tracking...")
        tracking_results = run_spatial_tracker(video_tensor)

        rgb_frames = rearrange(
            tracking_results['video_out'].numpy(), "T C H W -> T H W C").astype(np.uint8)
        depth_frames = tracking_results['point_map'][:, 2].numpy()
        # Zero out low-confidence depth so those pixels are never splatted.
        depth_frames[tracking_results['conf_depth'].numpy() < 0.5] = 0

        # Scale camera motion by median scene depth for comparable movement.
        scene_scale = np.median(depth_frames[depth_frames > 0]) if np.any(
            depth_frames > 0) else 1.0
        new_exts = generate_camera_trajectory(
            len(rgb_frames), camera_movement, tracking_results['intrs_out'].numpy(), scene_scale)

        progress(0.8, desc="Rendering viewpoint...")
        output_video_path = os.path.join(out_dir, "rendered_video.mp4")
        render_results = render_from_pointcloud(
            rgb_frames, depth_frames, tracking_results['intrs_out'].numpy(),
            torch.inverse(tracking_results['c2w_traj']).numpy(),
            new_exts, output_video_path, fps=OUTPUT_FPS, generate_ttm_inputs=generate_ttm)

        first_frame_path = os.path.join(out_dir, "first_frame.png")
        cv2.imwrite(first_frame_path, cv2.cvtColor(rgb_frames[0], cv2.COLOR_RGB2BGR))

        status_msg = "✅ 3D results ready! You can now use the prompt below to generate a high-quality TTM video."
        return render_results['rendered'], render_results['motion_signal'], render_results['mask'], first_frame_path, status_msg

    except Exception as e:
        logger.exception("Error: %s", e)
        return None, None, None, None, f"❌ Error: {str(e)}"
352
+
353
+
354
# --- GRADIO INTERFACE ---
with gr.Blocks(theme=gr.themes.Soft(), title="🎬 TTM Wan Video Generator") as demo:
    gr.Markdown("# 🎬 Video to Point Cloud & TTM Wan Generator")

    # Session state carrying intermediate file paths from stage 1 to stage 2.
    first_frame_file = gr.State()
    motion_signal_file = gr.State()
    mask_file = gr.State()

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 1. Tracking & Viewpoint")
            video_input = gr.Video(label="Upload Video")
            camera_movement = gr.Dropdown(
                choices=CAMERA_MOVEMENTS, value="static", label="Camera Movement")
            generate_btn = gr.Button("🚀 1. Run Spatial Tracker", variant="primary")

            output_video = gr.Video(label="Point Cloud Render (Draft)")
            status_text = gr.Markdown("Ready...")

        with gr.Column(scale=1):
            gr.Markdown("### 2. Time-to-Move (Wan 2.2)")
            ttm_prompt = gr.Textbox(
                label="Prompt",
                placeholder="Describe the scene (e.g., 'A monkey walking in the forest, high quality')",
            )

            with gr.Row():
                tweak_idx = gr.Number(label="Tweak Index", value=3, precision=0)
                tstrong_idx = gr.Number(label="Tstrong Index", value=6, precision=0)

            wan_generate_btn = gr.Button(
                "✨ 2. Generate TTM Video (Wan)", variant="secondary")
            wan_output_video = gr.Video(label="Final High-Quality TTM Video")
            wan_status = gr.Markdown("Awaiting 3D inputs...")

    with gr.Accordion("Debug: TTM Intermediate Inputs", open=False):
        with gr.Row():
            motion_signal_output = gr.Video(label="motion_signal.mp4")
            mask_output = gr.Video(label="mask.mp4")
            first_frame_output = gr.Image(label="first_frame.png")

    # Stage 1: tracking + render, then copy the produced file paths into
    # session state so stage 2 can consume them.
    generate_btn.click(
        fn=process_video,
        inputs=[video_input, camera_movement],
        outputs=[output_video, motion_signal_output,
                 mask_output, first_frame_output, status_text],
    ).then(
        fn=lambda a, b, c, d: (b, c, d),
        inputs=[output_video, motion_signal_output,
                mask_output, first_frame_output],
        outputs=[motion_signal_file, mask_file, first_frame_file],
    )

    # Stage 2: Wan TTM generation from the stored intermediate files.
    wan_generate_btn.click(
        fn=run_wan_ttm_generation,
        inputs=[ttm_prompt, tweak_idx, tstrong_idx,
                first_frame_file, motion_signal_file, mask_file],
        outputs=[wan_output_video, wan_status],
    )

if __name__ == "__main__":
    demo.launch(share=False)