causvid / tests /wan /test_wan_ode_pairs.py
lyttt's picture
Add files using upload-large-folder tool
8a70e8e verified
from causvid.models.wan.wan_wrapper import WanVAEWrapper
from diffusers.utils import export_to_video
import argparse
import torch
import glob
parser = argparse.ArgumentParser()
parser.add_argument("--data_path", type=str, required=True)
args = parser.parse_args()
torch.set_grad_enabled(False)
model = WanVAEWrapper().to(device="cuda:0", dtype=torch.bfloat16)
for index, video_path in enumerate(sorted(glob.glob(args.data_path + "/*.pt"))):
data = torch.load(video_path)
prompt = list(data.keys())[0]
video_latent = data[prompt][:, -1].cuda().to(torch.bfloat16)
video = model.decode_to_pixel(video_latent)
video = (video * 0.5 + 0.5).clamp(0, 1)[0].permute(0, 2, 3, 1).cpu().numpy()
print(index, prompt)
export_to_video(video, f"ode_output_{index:03d}.mp4", fps=16)