# DVD / test_script / test_single_video.py
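# Single-video depth inference: run the Wan-based depth model over a video in
# sliding temporal windows, align and blend the overlapping windows, and save a
# normalized depth visualization.
#
# Example invocation (checkpoint and video paths are illustrative):
#   python test_script/test_single_video.py \
#       --ckpt ./ckpt/your_checkpoint_dir \
#       --input_video ./demo.mp4 \
#       --output_dir ./inference_results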
import argparse
import os
from datetime import datetime
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from omegaconf import OmegaConf
from safetensors.torch import load_file
from tqdm import tqdm
from diffsynth import save_video
from examples.wanvideo.model_training.WanTrainingModule import WanTrainingModule
# =============================
# Helper: Math & Alignment
# =============================
def compute_scale_and_shift(curr_frames, ref_frames, mask=None):
"""Computes scale and shift for overlap alignment."""
if mask is None:
mask = np.ones_like(ref_frames)
a_00 = np.sum(mask * curr_frames * curr_frames)
a_01 = np.sum(mask * curr_frames)
a_11 = np.sum(mask)
b_0 = np.sum(mask * curr_frames * ref_frames)
b_1 = np.sum(mask * ref_frames)
det = a_00 * a_11 - a_01 * a_01
if det != 0:
scale = (a_11 * b_0 - a_01 * b_1) / det
shift = (-a_01 * b_0 + a_00 * b_1) / det
else:
scale, shift = 1.0, 0.0
return scale, shift
# =============================
# Helper: Video Processing
# =============================
def read_video(video_path):
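    """Reads a video into a float tensor of shape [1, T, C, H, W] in [0, 1]
    and returns it together with the source FPS."""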
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
raise ValueError(f"Cannot open video: {video_path}")
fps = cap.get(cv2.CAP_PROP_FPS)
frames = []
while True:
ret, frame = cap.read()
if not ret:
break
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frames.append(frame)
cap.release()
video_np = np.stack(frames)
video_tensor = torch.from_numpy(
video_np).permute(0, 3, 1, 2).float() / 255.0
return video_tensor.unsqueeze(0), fps # [1, T, C, H, W], fps
def resize_for_training_scale(video_tensor, target_h=480, target_w=640):
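    """Rescales the video (aspect ratio preserved) so both spatial sides are at
    least target_h x target_w, rounds each side up to a multiple of 16, and
    returns the resized tensor together with the original (H, W)."""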
B, T, C, H, W = video_tensor.shape
ratio = max(target_h / H, target_w / W)
new_H = int(np.ceil(H * ratio))
new_W = int(np.ceil(W * ratio))
# Align to 16
new_H = (new_H + 15) // 16 * 16
new_W = (new_W + 15) // 16 * 16
if new_H == H and new_W == W:
return video_tensor, (H, W)
video_reshape = video_tensor.view(B * T, C, H, W)
resized = F.interpolate(video_reshape, size=(
new_H, new_W), mode="bilinear", align_corners=False)
resized = resized.view(B, T, C, new_H, new_W)
return resized, (H, W)
def resize_depth_back(depth_np, orig_size):
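    """Bilinearly resizes a [T, H, W, C] depth array back to the original
    (orig_H, orig_W) resolution and returns it as a numpy array."""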
orig_H, orig_W = orig_size
depth_tensor = torch.from_numpy(depth_np).permute(0, 3, 1, 2).float()
depth_tensor = F.interpolate(depth_tensor, size=(
orig_H, orig_W), mode='bilinear', align_corners=False)
return depth_tensor.permute(0, 2, 3, 1).cpu().numpy()
def pad_time_mod4(video_tensor):
"""Pads the temporal dimension to satisfy 4n+1 requirement."""
B, T, C, H, W = video_tensor.shape
remainder = T % 4
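    # Repeat the last frame until the padded length satisfies T % 4 == 1.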
if remainder != 1:
pad_len = (4 - remainder + 1) % 4
pad_frames = video_tensor[:, -1:, :, :, :].repeat(1, pad_len, 1, 1, 1)
video_tensor = torch.cat([video_tensor, pad_frames], dim=1)
return video_tensor, T
def get_window_index(T, window_size, overlap):
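    """Returns a list of (start, end) frame-index pairs covering [0, T) with
    windows of length window_size that overlap by `overlap` frames; the last
    window is shifted back so it still spans a full window when possible."""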
if T <= window_size:
return [(0, T)]
res = [(0, window_size)]
start = window_size - overlap
while start < T:
end = start + window_size
if end < T:
res.append((start, end))
start += window_size - overlap
else:
# Last window ensures full window_size length if possible
start = max(0, T - window_size)
res.append((start, T))
break
return res
# =============================
# Core Inference
# =============================
def generate_depth_sliced(model, input_rgb, window_size=45, overlap=9, scale_only=False):
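    """Runs depth inference on sliding temporal windows, then stitches the
    per-window predictions into one sequence by affinely aligning and
    cross-fading the overlapping frames."""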
B, T, C, H, W = input_rgb.shape
depth_windows = get_window_index(T, window_size, overlap)
print(f"depth_windows {depth_windows}")
depth_res_list = []
# 1. Inference per window
for start, end in tqdm(depth_windows, desc="Inferencing Slices"):
_input_rgb_slice = input_rgb[:, start:end]
# Ensure 4n+1 padding
_input_rgb_slice, origin_T = pad_time_mod4(_input_rgb_slice)
_input_frame = _input_rgb_slice.shape[1]
_input_height, _input_width = _input_rgb_slice.shape[-2:]
outputs = model.pipe(
prompt=[""] * B,
negative_prompt=[""] * B,
mode=model.args.mode,
height=_input_height,
width=_input_width,
num_frames=_input_frame,
batch_size=B,
input_image=_input_rgb_slice[:, 0],
extra_images=_input_rgb_slice,
extra_image_frame_index=torch.ones(
[B, _input_frame]).to(model.pipe.device),
input_video=_input_rgb_slice,
cfg_scale=1,
seed=0,
tiled=False,
denoise_step=model.args.denoise_step,
)
# Drop the padded frames
depth_res_list.append(outputs['depth'][:, :origin_T])
# 2. Overlap Alignment
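    # Each window after the first is mapped by scale/shift (least squares on the
    # overlapping frames) onto the already-aligned result, then blended in with
    # a linear cross-fade over the overlap region.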
depth_list_aligned = None
prev_end = None
for i, (t, (start, end)) in enumerate(zip(depth_res_list, depth_windows)):
print(f"Handling window {i} start: {start}, end: {end}")
if i == 0:
depth_list_aligned = t
prev_end = end
continue
curr_start = start
real_overlap = prev_end - curr_start
if real_overlap > 0:
ref_frames = depth_list_aligned[:, -real_overlap:]
curr_frames = t[:, :real_overlap]
if scale_only:
scale = np.sum(curr_frames * ref_frames) / \
(np.sum(curr_frames * curr_frames) + 1e-6)
shift = 0.0
else:
scale, shift = compute_scale_and_shift(curr_frames, ref_frames)
scale = np.clip(scale, 0.7, 1.5)
aligned_t = t * scale + shift
aligned_t[aligned_t < 0] = 0
# Debugging Output
curr_overlap_aligned = aligned_t[:, :real_overlap]
diff = np.abs(curr_overlap_aligned - ref_frames)
mae_scalar = float(
diff.mean(axis=tuple(range(1, diff.ndim))).mean())
print(f"\n[Overlap {i}]")
print(f"real_overlap = {real_overlap}")
print(f"scale = {scale:.8f}, shift = {shift:.8f}")
print(
f"aligned curr range = {aligned_t.min():.6f} ~ {aligned_t.max():.6f}")
print(f"overlap MAE(after align) = {mae_scalar:.6f}")
# Smooth blending
alpha = np.linspace(0, 1, real_overlap, dtype=np.float32).reshape(
1, real_overlap, 1, 1, 1)
smooth_overlap = (1 - alpha) * ref_frames + \
alpha * aligned_t[:, :real_overlap]
depth_list_aligned = np.concatenate(
[depth_list_aligned[:, :-real_overlap], smooth_overlap,
aligned_t[:, real_overlap:]], axis=1
)
else:
# Fallback if no overlap exists
depth_list_aligned = np.concatenate(
[depth_list_aligned, t], axis=1)
print(
f"Total depth range after concat = {depth_list_aligned.min():.6f} ~ {depth_list_aligned.max():.6f}")
prev_end = end
# Crop to original length
return depth_list_aligned[:, :T]
# =============================
# Pipeline Components
# =============================
def load_model(ckpt_dir, yaml_args):
"""Initializes and loads the model checkpoint."""
accelerator = Accelerator()
model = WanTrainingModule(
accelerator=accelerator,
model_id_with_origin_paths=yaml_args.model_id_with_origin_paths,
trainable_models=None,
use_gradient_checkpointing=False,
lora_rank=yaml_args.lora_rank,
lora_base_model=yaml_args.lora_base_model,
args=yaml_args,
)
ckpt_path = os.path.join(ckpt_dir, "model.safetensors")
state_dict = load_file(ckpt_path, device="cpu")
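    # Keep only the DiT weights and strip the "pipe.dit." prefix from the keys.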
dit_state_dict = {k.replace("pipe.dit.", ""): v for k,
v in state_dict.items() if "pipe.dit." in k}
model.pipe.dit.load_state_dict(dit_state_dict, strict=True)
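    # Merge the trained LoRA weights into the base model before inference.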
model.merge_lora_layer()
model = model.to("cuda")
return model
def load_video_data(args):
"""Loads and resizes the input video."""
input_tensor, origin_fps = read_video(args.input_video)
print("Original shape:", input_tensor.shape)
input_tensor, orig_size = resize_for_training_scale(
input_tensor, args.height, args.width)
print("Resized shape:", input_tensor.shape)
print(f"input range {input_tensor.min()} - {input_tensor.max()}")
return input_tensor, orig_size, origin_fps
def predict_depth(model, input_tensor, orig_size, args):
"""Runs depth prediction and post-processes the output to original size."""
depth = generate_depth_sliced(
model, input_tensor, args.window_size, args.overlap)[0]
print(f"depth range shape {depth.min()} - {depth.max()}, shape {depth.shape}")
# Post Process: resize back to original
depth = resize_depth_back(depth, orig_size)
print(f"after resizing {depth.min()} - {depth.max()}, {depth.shape}")
return depth
def save_results(depth, origin_fps, args):
"""Normalizes and saves the depth video to disk."""
os.makedirs(args.output_dir, exist_ok=True)
    base_name = os.path.splitext(os.path.basename(args.input_video))[0]
gray_scale = 'gray' if args.grayscale else 'color'
out_prefix = os.path.join(
args.output_dir, f"{base_name}_{gray_scale}")
output_path = f"{out_prefix}_depth_vis.mp4"
print(f"Saving to {output_path}")
d_min, d_max = depth.min(), depth.max()
vis_depth = (depth - d_min) / (d_max - d_min + 1e-8)
save_video(vis_depth, output_path,
fps=origin_fps, quality=6, grayscale=args.grayscale)
return output_path
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--ckpt", type=str, required=True)
parser.add_argument("--input_video", type=str, required=True)
parser.add_argument("--output_dir", type=str,
default="./inference_results")
parser.add_argument('--model_config', default='ckpt/model_config.yaml')
parser.add_argument("--window_size", type=int, default=81)
parser.add_argument('--height', type=int, default=480)
parser.add_argument('--width', type=int, default=640)
parser.add_argument("--overlap", type=int, default=9)
parser.add_argument('--grayscale', action='store_true')
return parser.parse_args()
# =============================
# Main Script
# =============================
def main():
args = parse_args()
yaml_args = OmegaConf.load(args.model_config)
# 1. Load Model
model = load_model(args.ckpt, yaml_args)
# 2. Load Video
input_tensor, orig_size, origin_fps = load_video_data(args)
# 3. Predict Depth
depth = predict_depth(model, input_tensor, orig_size, args)
# 4. Save Results
save_results(depth, origin_fps, args)
print("Inference completed successfully!")
if __name__ == "__main__":
main()