import argparse
import os

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from omegaconf import OmegaConf
from safetensors.torch import load_file
from tqdm import tqdm

from diffsynth import save_video
from examples.wanvideo.model_training.WanTrainingModule import (
    WanTrainingModule,
)


# =============================
# Helper: Math & Alignment
# =============================
def compute_scale_and_shift(curr_frames, ref_frames, mask=None):
    """Computes the least-squares scale and shift aligning curr_frames to ref_frames."""
    if mask is None:
        mask = np.ones_like(ref_frames)
    a_00 = np.sum(mask * curr_frames * curr_frames)
    a_01 = np.sum(mask * curr_frames)
    a_11 = np.sum(mask)
    b_0 = np.sum(mask * curr_frames * ref_frames)
    b_1 = np.sum(mask * ref_frames)
    det = a_00 * a_11 - a_01 * a_01
    if det != 0:
        scale = (a_11 * b_0 - a_01 * b_1) / det
        shift = (-a_01 * b_0 + a_00 * b_1) / det
    else:
        scale, shift = 1.0, 0.0
    return scale, shift


# =============================
# Helper: Video Processing
# =============================
def read_video(video_path):
    """Reads a video into a float tensor in [0, 1] with shape [1, T, C, H, W]."""
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Cannot open video: {video_path}")
    fps = cap.get(cv2.CAP_PROP_FPS)
    frames = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frames.append(frame)
    cap.release()
    if not frames:
        raise ValueError(f"No frames decoded from video: {video_path}")
    video_np = np.stack(frames)
    video_tensor = torch.from_numpy(video_np).permute(0, 3, 1, 2).float() / 255.0
    return video_tensor.unsqueeze(0), fps  # [1, T, C, H, W], fps


def resize_for_training_scale(video_tensor, target_h=480, target_w=640):
    """Upscales so both sides cover the target size, snapped to multiples of 16."""
    B, T, C, H, W = video_tensor.shape
    ratio = max(target_h / H, target_w / W)
    new_H = int(np.ceil(H * ratio))
    new_W = int(np.ceil(W * ratio))
    # Align to multiples of 16
    new_H = (new_H + 15) // 16 * 16
    new_W = (new_W + 15) // 16 * 16
    if new_H == H and new_W == W:
        return video_tensor, (H, W)
    video_reshape = video_tensor.view(B * T, C, H, W)
    resized = F.interpolate(
        video_reshape, size=(new_H, new_W), mode="bilinear", align_corners=False)
    resized = resized.view(B, T, C, new_H, new_W)
    return resized, (H, W)


def resize_depth_back(depth_np, orig_size):
    """Resizes depth maps of shape [T, H, W, C] back to the original resolution."""
    orig_H, orig_W = orig_size
    depth_tensor = torch.from_numpy(depth_np).permute(0, 3, 1, 2).float()
    depth_tensor = F.interpolate(
        depth_tensor, size=(orig_H, orig_W), mode="bilinear", align_corners=False)
    return depth_tensor.permute(0, 2, 3, 1).cpu().numpy()


def pad_time_mod4(video_tensor):
    """Pads the temporal dimension to satisfy the 4n + 1 frame-count requirement."""
    B, T, C, H, W = video_tensor.shape
    remainder = T % 4
    if remainder != 1:
        pad_len = (4 - remainder + 1) % 4
        pad_frames = video_tensor[:, -1:, :, :, :].repeat(1, pad_len, 1, 1, 1)
        video_tensor = torch.cat([video_tensor, pad_frames], dim=1)
    return video_tensor, T


def get_window_index(T, window_size, overlap):
    """Splits [0, T) into overlapping windows of at most window_size frames."""
    if T <= window_size:
        return [(0, T)]
    res = [(0, window_size)]
    start = window_size - overlap
    while start < T:
        end = start + window_size
        if end < T:
            res.append((start, end))
            start += window_size - overlap
        else:
            # The last window is shifted back so it still spans a full
            # window_size where possible
            start = max(0, T - window_size)
            res.append((start, T))
            break
    return res
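

# A minimal sanity check of get_window_index (illustrative values only;
# this helper is not part of the inference path): with T=100 frames,
# window_size=45, and overlap=9, consecutive windows advance by
# window_size - overlap = 36 frames, and the final window is pulled back
# so it still spans a full window_size.
def _demo_window_index():
    windows = get_window_index(T=100, window_size=45, overlap=9)
    assert windows == [(0, 45), (36, 81), (55, 100)]
    return windows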


# =============================
# Core Inference
# =============================
def generate_depth_sliced(model, input_rgb, window_size=45, overlap=9,
                          scale_only=False):
    B, T, C, H, W = input_rgb.shape
    depth_windows = get_window_index(T, window_size, overlap)
    print(f"depth_windows {depth_windows}")
    depth_res_list = []

    # 1. Inference per window
    for start, end in tqdm(depth_windows, desc="Inferencing Slices"):
        _input_rgb_slice = input_rgb[:, start:end]
        # Ensure 4n + 1 padding
        _input_rgb_slice, origin_T = pad_time_mod4(_input_rgb_slice)
        _input_frame = _input_rgb_slice.shape[1]
        _input_height, _input_width = _input_rgb_slice.shape[-2:]

        outputs = model.pipe(
            prompt=[""] * B,
            negative_prompt=[""] * B,
            mode=model.args.mode,
            height=_input_height,
            width=_input_width,
            num_frames=_input_frame,
            batch_size=B,
            input_image=_input_rgb_slice[:, 0],
            extra_images=_input_rgb_slice,
            extra_image_frame_index=torch.ones(
                [B, _input_frame]).to(model.pipe.device),
            input_video=_input_rgb_slice,
            cfg_scale=1,
            seed=0,
            tiled=False,
            denoise_step=model.args.denoise_step,
        )
        # Drop the padded frames
        depth_res_list.append(outputs['depth'][:, :origin_T])

    # 2. Overlap Alignment
    depth_list_aligned = None
    prev_end = None
    for i, (t, (start, end)) in enumerate(zip(depth_res_list, depth_windows)):
        print(f"Handling window {i} start: {start}, end: {end}")
        if i == 0:
            depth_list_aligned = t
            prev_end = end
            continue

        curr_start = start
        real_overlap = prev_end - curr_start
        if real_overlap > 0:
            ref_frames = depth_list_aligned[:, -real_overlap:]
            curr_frames = t[:, :real_overlap]
            if scale_only:
                scale = np.sum(curr_frames * ref_frames) / (
                    np.sum(curr_frames * curr_frames) + 1e-6)
                shift = 0.0
            else:
                scale, shift = compute_scale_and_shift(curr_frames, ref_frames)
            scale = np.clip(scale, 0.7, 1.5)

            aligned_t = t * scale + shift
            aligned_t[aligned_t < 0] = 0

            # Debugging output
            curr_overlap_aligned = aligned_t[:, :real_overlap]
            diff = np.abs(curr_overlap_aligned - ref_frames)
            mae_scalar = float(
                diff.mean(axis=tuple(range(1, diff.ndim))).mean())
            print(f"\n[Overlap {i}]")
            print(f"real_overlap = {real_overlap}")
            print(f"scale = {scale:.8f}, shift = {shift:.8f}")
            print(
                f"aligned curr range = {aligned_t.min():.6f} ~ {aligned_t.max():.6f}")
            print(f"overlap MAE (after align) = {mae_scalar:.6f}")

            # Smooth blending
            alpha = np.linspace(0, 1, real_overlap, dtype=np.float32).reshape(
                1, real_overlap, 1, 1, 1)
            smooth_overlap = (1 - alpha) * ref_frames + \
                alpha * aligned_t[:, :real_overlap]
            depth_list_aligned = np.concatenate(
                [depth_list_aligned[:, :-real_overlap],
                 smooth_overlap,
                 aligned_t[:, real_overlap:]],
                axis=1,
            )
        else:
            # Fallback if no overlap exists
            depth_list_aligned = np.concatenate([depth_list_aligned, t], axis=1)

        print(
            f"Total depth range after concat = "
            f"{depth_list_aligned.min():.6f} ~ {depth_list_aligned.max():.6f}")
        prev_end = end

    # Crop to original length
    return depth_list_aligned[:, :T]
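

# A minimal check of the overlap-alignment math (illustrative; this helper
# is not called during inference): compute_scale_and_shift solves
# argmin_{s, t} sum(mask * (s * curr + t - ref) ** 2) in closed form via the
# 2x2 normal equations, so a known scale and shift should be recovered.
def _demo_scale_shift():
    rng = np.random.default_rng(0)
    curr = rng.random((1, 9, 1, 16, 16))  # stand-in for overlap depth frames
    ref = 1.3 * curr + 0.2                # ground-truth scale/shift to recover
    scale, shift = compute_scale_and_shift(curr, ref)
    assert np.isclose(scale, 1.3) and np.isclose(shift, 0.2)
    return scale, shift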


# =============================
# Pipeline Components
# =============================
def load_model(ckpt_dir, yaml_args):
    """Initializes and loads the model checkpoint."""
    accelerator = Accelerator()
    model = WanTrainingModule(
        accelerator=accelerator,
        model_id_with_origin_paths=yaml_args.model_id_with_origin_paths,
        trainable_models=None,
        use_gradient_checkpointing=False,
        lora_rank=yaml_args.lora_rank,
        lora_base_model=yaml_args.lora_base_model,
        args=yaml_args,
    )
    ckpt_path = os.path.join(ckpt_dir, "model.safetensors")
    state_dict = load_file(ckpt_path, device="cpu")
    dit_state_dict = {k.replace("pipe.dit.", ""): v
                      for k, v in state_dict.items() if "pipe.dit." in k}
    model.pipe.dit.load_state_dict(dit_state_dict, strict=True)
    model.merge_lora_layer()
    model = model.to("cuda")
    return model


def load_video_data(args):
    """Loads and resizes the input video."""
    input_tensor, origin_fps = read_video(args.input_video)
    print("Original shape:", input_tensor.shape)
    input_tensor, orig_size = resize_for_training_scale(
        input_tensor, args.height, args.width)
    print("Resized shape:", input_tensor.shape)
    print(f"input range {input_tensor.min()} - {input_tensor.max()}")
    return input_tensor, orig_size, origin_fps


def predict_depth(model, input_tensor, orig_size, args):
    """Runs depth prediction and post-processes the output to the original size."""
    depth = generate_depth_sliced(
        model, input_tensor, args.window_size, args.overlap)[0]
    print(f"depth range {depth.min()} - {depth.max()}, shape {depth.shape}")
    # Post-process: resize back to the original resolution
    depth = resize_depth_back(depth, orig_size)
    print(f"after resizing {depth.min()} - {depth.max()}, {depth.shape}")
    return depth


def save_results(depth, origin_fps, args):
    """Normalizes and saves the depth video to disk."""
    os.makedirs(args.output_dir, exist_ok=True)
    base_name = os.path.splitext(os.path.basename(args.input_video))[0]
    gray_scale = 'gray' if args.grayscale else 'color'
    out_prefix = os.path.join(args.output_dir, f"{base_name}_{gray_scale}")
    output_path = f"{out_prefix}_depth_vis.mp4"
    print(f"Saving to {output_path}")
    d_min, d_max = depth.min(), depth.max()
    vis_depth = (depth - d_min) / (d_max - d_min + 1e-8)
    save_video(vis_depth, output_path, fps=origin_fps, quality=6,
               grayscale=args.grayscale)
    return output_path


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt", type=str, required=True)
    parser.add_argument("--input_video", type=str, required=True)
    parser.add_argument("--output_dir", type=str, default="./inference_results")
    parser.add_argument("--model_config", default="ckpt/model_config.yaml")
    parser.add_argument("--window_size", type=int, default=81)
    parser.add_argument("--height", type=int, default=480)
    parser.add_argument("--width", type=int, default=640)
    parser.add_argument("--overlap", type=int, default=9)
    parser.add_argument("--grayscale", action="store_true")
    return parser.parse_args()


# =============================
# Main Script
# =============================
def main():
    args = parse_args()
    yaml_args = OmegaConf.load(args.model_config)

    # 1. Load Model
    model = load_model(args.ckpt, yaml_args)

    # 2. Load Video
    input_tensor, orig_size, origin_fps = load_video_data(args)

    # 3. Predict Depth
    depth = predict_depth(model, input_tensor, orig_size, args)

    # 4. Save Results
    save_results(depth, origin_fps, args)
    print("Inference completed successfully!")


if __name__ == "__main__":
    main()
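
# Example invocation (script name and paths are illustrative):
#   python inference_depth.py \
#       --ckpt ./checkpoints/step_1000 \
#       --input_video ./assets/demo.mp4 \
#       --output_dir ./inference_results \
#       --window_size 81 --overlap 9 --grayscale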