# causvid/distillation_data/compute_vae_latent.py
from causvid.models.wan.wan_wrapper import WanVAEWrapper
from causvid.util import launch_distributed_job
import torch.distributed as dist
import imageio.v3 as iio
from tqdm import tqdm
import argparse
import torch
import json
import math
import os
# Inference-only script: disable autograd globally to avoid building graphs.
torch.set_grad_enabled(False)
def video_to_numpy(video_path):
    """
    Load every frame of a video file into a single NumPy array.

    :param video_path: Path to the video file.
    :return: NumPy array of shape (num_frames, height, width, channels)
    """
    # The pyav plugin decodes the entire file in one call.
    frames = iio.imread(video_path, plugin="pyav")
    return frames
def encode(self, videos: torch.Tensor) -> torch.Tensor:
    """Encode a batch of videos into normalized VAE latents.

    Note: defined as a free function taking ``self`` so it can be called
    as ``encode(model, videos)`` on a wrapper object exposing ``mean``,
    ``std`` and ``model``.

    :param videos: batch of video tensors; each element is encoded
        independently through ``self.model``.
    :return: float32 latents stacked along a new leading batch axis.
    """
    reference = videos[0]
    # Normalization pair handed to the VAE: mean offset and reciprocal
    # std, moved to the inputs' device/dtype.
    scale = [
        self.mean.to(device=reference.device, dtype=reference.dtype),
        1.0 / self.std.to(device=reference.device, dtype=reference.dtype),
    ]
    latents = []
    for clip in videos:
        encoded = self.model.encode(clip.unsqueeze(0), scale)
        latents.append(encoded.float().squeeze(0))
    return torch.stack(latents, dim=0)
def main():
    """Encode videos into Wan VAE latents, sharded round-robin across
    distributed ranks, saving one ``.pt`` file per video.

    Each output file ``{global_index:08d}.pt`` contains a dict mapping the
    video's prompt string to its latent tensor (axes 1 and 2 transposed
    after encoding).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_video_folder", type=str,
                        help="Path to the folder containing input videos.")
    parser.add_argument("--output_latent_folder", type=str,
                        help="Path to the folder where output latents will be saved.")
    parser.add_argument("--info_path", type=str,
                        help="Path to the info file containing video metadata.")
    args = parser.parse_args()

    # Step 1: Setup the environment.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.set_grad_enabled(False)

    # Step 2: Initialize the distributed job and the VAE.
    launch_distributed_job()
    device = torch.cuda.current_device()

    # info file: JSON mapping video path -> prompt string.
    with open(args.info_path, "r") as f:
        video_info = json.load(f)

    model = WanVAEWrapper().to(device=device, dtype=torch.bfloat16)

    # Deterministic ordering so every rank agrees on the video -> index map.
    video_paths = sorted(video_info.keys())

    os.makedirs(args.output_latent_folder, exist_ok=True)

    # Hoist the collective-group queries out of the loop; they are
    # loop-invariant.
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    # Round-robin sharding: rank r handles global indices r, r+W, r+2W, ...
    for index in tqdm(range(int(math.ceil(len(video_paths) / world_size))),
                      disable=rank != 0):
        global_index = index * world_size + rank
        if global_index >= len(video_paths):
            break

        video_path = video_paths[global_index]
        prompt = video_info[video_path]

        try:
            array = video_to_numpy(os.path.join(
                args.input_video_folder, video_path))
        except Exception as e:
            # Skip unreadable/corrupt videos instead of aborting the job.
            # (Was a bare `except:`, which also swallowed KeyboardInterrupt
            # and hid the failure reason.)
            print(f"Failed to read video: {video_path} ({e})")
            continue

        # (F, H, W, C) uint8 frames -> (1, C, F, H, W) float in [-1, 1].
        video_tensor = torch.tensor(array, dtype=torch.float32, device=device).unsqueeze(0).permute(
            0, 4, 1, 2, 3
        ) / 255.0
        video_tensor = video_tensor * 2 - 1
        video_tensor = video_tensor.to(torch.bfloat16)

        encoded_latents = encode(model, video_tensor).transpose(2, 1)

        torch.save(
            {prompt: encoded_latents.cpu().detach()},
            os.path.join(args.output_latent_folder, f"{global_index:08d}.pt")
        )

    # Keep ranks in lockstep before the process group tears down.
    dist.barrier()
# Script entry point: run only when executed directly, not on import.
if __name__ == "__main__":
    main()