"""Precompute Wan features for training: VAE video latents, first-frame-conditioned
latents, CLIP image embeddings, and T5 text states, saved as .npy plus a per-clip
meta JSON."""
import os
import sys
import logging
import torch
import numpy as np
import argparse
import math
import random
import traceback
import json
import glob
import io
import urllib.request  # `import urllib` alone does not expose `urllib.request`
import requests
import cv2
import time
from decord import VideoReader, cpu
from easydict import EasyDict
from einops import rearrange
from tqdm import tqdm
from torchvision import transforms
from transformers import AutoProcessor
from diffusers_lite.arguments import args_init
from diffusers_lite.constants import PRECISION_TO_TYPE
from diffusers_lite.wan.modules.vae import WanVAE
from diffusers_lite.wan.modules.t5 import T5EncoderModel
from diffusers_lite.wan.modules.clip import CLIPModel
from diffusers_lite.utils.data_utils import split_list, align_ceil_to, align_floor_to
from diffusers_lite.utils.diffusion_utils import (
    vae_encode,
    image_encode,
    prompt2states,
)
from omegaconf import OmegaConf

DEVICE = "cuda"
DTYPE = torch.float16


def read_json(json_path):
    with open(json_path, 'r', encoding='utf-8') as file:
        data = json.load(file)
    return data


def write_json(json_data, json_file, encoding='utf-8'):
    # Pass `encoding` through to open(); the original accepted it but ignored it.
    with open(json_file, 'w', encoding=encoding) as file:
        json.dump(json_data, file, indent=4)


def seed_everything(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    datefmt='%Y-%m-%d %H:%M:%S',
    format='%(asctime)s.%(msecs)03d %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
)
logger = logging.getLogger('default')
logFormatter = logging.Formatter(
    "%(asctime)s.%(msecs)03d %(filename)s[line:%(lineno)d] %(levelname)s %(message)s",
    datefmt='%Y-%m-%d %H:%M:%S',
)


def load_and_analyze_video(video_path, args):
    """Open a local or HTTP video and derive frame-extraction parameters."""
    if video_path.startswith('http'):
        req = urllib.request.Request(video_path)
        with urllib.request.urlopen(req, timeout=20) as resp:
            video_reader = VideoReader(io.BytesIO(resp.read()), ctx=cpu(0))
    else:
        video_reader = VideoReader(video_path)
    video_fps = video_reader.get_avg_fps()
    total_frames = len(video_reader)
    # Number of source frames per extracted frame at the target sampling rate.
    frame_interval = video_fps / args.extract_fps
    extract_frames = min(
        int(math.ceil((total_frames * args.extract_fps) / video_fps)),
        args.num_frames,
    )
    return video_reader, video_fps, total_frames, frame_interval, extract_frames


def get_common_video_params(win_video_path, lose_video_path, args):
    """For a DPO (win/lose) video pair, pick a shared frame count of the form 4k + 1."""
    win_reader, win_fps, win_total, win_interval, win_frames = load_and_analyze_video(win_video_path, args)
    lose_reader, lose_fps, lose_total, lose_interval, lose_frames = load_and_analyze_video(lose_video_path, args)
    common_frames = min(win_frames, lose_frames)
    common_frames = align_floor_to(common_frames - 1, alignment=4) + 1
    print(f"Win video - fps:{win_fps}, total_frames:{win_total}, extract_frames:{win_frames}")
    print(f"Lose video - fps:{lose_fps}, total_frames:{lose_total}, extract_frames:{lose_frames}")
    print(f"Common extract_frames: {common_frames}")
    return win_reader, lose_reader, common_frames


def extract_video_frames(video_reader, common_frames, args, video_path):
    """Sample up to `common_frames` indices starting at args.start_idx, spaced at args.extract_fps."""
    total_frames = len(video_reader)
    video_fps = video_reader.get_avg_fps()
    frame_interval = video_fps / args.extract_fps
    frame_indices = []
    current_position = args.start_idx
    while len(frame_indices) < common_frames and current_position < total_frames:
        frame_indices.append(int(current_position))
        current_position += frame_interval
    frame_indices = np.array(frame_indices[:common_frames])
    print(f"Frame indices: {frame_indices}, count: {len(frame_indices)}")
    frames = video_reader.get_batch(frame_indices).asnumpy()
    return frames
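# `align_ceil_to` / `align_floor_to` come from diffusers_lite.utils.data_utils and
# are not defined in this file. The reference implementations below are an
# assumption (round to the nearest multiple of `alignment`), kept under private
# names so they do not shadow the real imports; they only document the semantics
# this script relies on.
def _align_ceil_to_ref(value, alignment):
    # Round `value` up to the nearest multiple of `alignment`.
    return int(math.ceil(value / alignment)) * alignment


def _align_floor_to_ref(value, alignment):
    # Round `value` down to the nearest multiple of `alignment`.
    return int(math.floor(value / alignment)) * alignment


# Under that reading, `align_floor_to(n - 1, alignment=4) + 1` maps a frame count
# n to the largest value of the form 4k + 1 that is <= n, the frame-count family
# the Wan temporal VAE expects (e.g. 77 -> 77, 80 -> 77, 81 -> 81).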
def height_width_scale(frames, args):
    """Compute the target (height, width) and the resize/center-crop transform.

    The short side is scaled to args.resolution[0], both sides are aligned up to
    multiples of 32, and the long side is capped at resolution[0] * aspect_ratio.
    """
    height, width = frames.shape[1], frames.shape[2]
    scale = args.resolution[0] / min(height, width)
    resize_height_scale = align_ceil_to(int(height * scale), 32)
    resize_width_scale = align_ceil_to(int(width * scale), 32)
    max_resolution = args.resolution[0] * args.aspect_ratio
    max_resolution = align_ceil_to(max_resolution, 32)
    height_scale = resize_height_scale
    width_scale = resize_width_scale
    if resize_height_scale > max_resolution:
        height_scale = max_resolution
    if resize_width_scale > max_resolution:
        width_scale = max_resolution
    # Enlarge the resize scale if either side would end up smaller than the crop
    # target, so CenterCrop never has to pad.
    if int(width * scale) < width_scale:
        scale_new = width_scale / width
    else:
        scale_new = scale
    if int(height * scale_new) < height_scale:
        scale_new = height_scale / height
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((int(height * scale_new), int(width * scale_new))),
        transforms.CenterCrop((height_scale, width_scale)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # map to [-1, 1]
    ])
    return height_scale, width_scale, transform


def process_video_frames(frames, args, save_first_frame_path, height_scale, width_scale, transform):
    processed_frames = []
    for i, frame in enumerate(frames):
        processed_frame = transform(frame)
        processed_frames.append(processed_frame)
        if i == 0 and save_first_frame_path:
            # Undo the [-1, 1] normalization before saving a preview JPEG.
            denormalized_frame = processed_frame * 0.5 + 0.5
            denormalized_frame = denormalized_frame.clamp(0, 1)
            first_frame = transforms.ToPILImage()(denormalized_frame)
            first_frame.save(save_first_frame_path)
    print(f"Processed video scale height {height_scale} width {width_scale}")
    return torch.stack(processed_frames)


def encode_single_video(video_tensor, basic_kwargs, model_kwargs):
    vae = model_kwargs.vae
    image_encoder = model_kwargs.image_encoder
    video = video_tensor.unsqueeze(0).to(basic_kwargs.device)  # (b, t, c, h, w)
    video = rearrange(video, "b t c h w -> b c t h w")
    batch_size, _, num_frames, height, width = video.shape
    image = video[:, :, 0:1, :, :]
    # Conditioning video: the first frame followed by black (zero) frames.
    video_condition = torch.cat([
        image,
        image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width),
    ], dim=2).to(basic_kwargs.device)
    with torch.autocast(device_type="cuda", dtype=basic_kwargs.dtype):
        latents = vae_encode(vae, video, vae_type="wanx")
        latents_condition = vae_encode(vae, video_condition, vae_type="wanx")
        image_embeds = image_encode(image_encoder, image, image_encoder_type="wanx")
    return {
        "latents": latents,
        "image_embeds": image_embeds,
        "latents_condition": latents_condition,
    }


def encode_video(args, video_path, basic_kwargs, model_kwargs, save_first_frame_path):
    video_reader, video_fps, total_frames, frame_interval, extract_frames = load_and_analyze_video(video_path, args)
    extract_frames = align_floor_to(extract_frames - 1, alignment=4) + 1
    frames = extract_video_frames(video_reader, extract_frames, args, video_path)
    height_scale, width_scale, transform = height_width_scale(frames, args)
    video_tensor = process_video_frames(frames, args, save_first_frame_path, height_scale, width_scale, transform)
    encode_kwargs = encode_single_video(video_tensor, basic_kwargs, model_kwargs)
    print(f"Encoded shapes - latents: {encode_kwargs['latents'].shape}, "
          f"latents_condition: {encode_kwargs['latents_condition'].shape}")
    return encode_kwargs
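# Shape sanity check, as an assumption based on the published Wan 2.1 VAE design
# (16 latent channels, 4x temporal / 8x spatial compression); the authoritative
# shape is whatever `vae_encode` returns. For a 77-frame clip at 480x832, the
# expected latent shape would be:
#
#   (1, 16, (77 - 1) // 4 + 1, 480 // 8, 832 // 8) == (1, 16, 20, 60, 104)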
def basic_init(args):
    device = torch.device("cuda", 0)
    dtype = PRECISION_TO_TYPE[args.precision]
    basic_kwargs = EasyDict({
        "device": device,
        "dtype": dtype,
    })
    return basic_kwargs


def model_init(args, basic_kwargs):
    vae = WanVAE(
        vae_pth=args.vae_path,
        device=basic_kwargs.device,
    )
    image_encoder = CLIPModel(
        checkpoint_path=args.image_encoder_path,
        tokenizer_path=args.image_processor_path,
        dtype=basic_kwargs.dtype,
        device=basic_kwargs.device,
    )
    text_encoder = T5EncoderModel(
        checkpoint_path=args.text_encoder_path,
        tokenizer_path=args.tokenizer_path,
        text_len=args.max_sequence_length,
        dtype=basic_kwargs.dtype,
        device=basic_kwargs.device,
        shard_fn=None,
    )
    model_kwargs = EasyDict({
        "vae": vae,
        "image_encoder": image_encoder,
        "text_encoder": text_encoder,
    })
    return model_kwargs


def encode_caption(args, caption, basic_kwargs, model_kwargs):
    text_encoder = model_kwargs.text_encoder
    text_states = prompt2states(
        caption,
        text_encoder,
        device=basic_kwargs.device,
        text_encoder_type=args.model_type,
    )
    return text_states


@torch.no_grad()
def main_wan(config):
    seed_everything(config.seed)
    start = time.time()
    basic_kwargs = basic_init(config)
    model_kwargs = model_init(config, basic_kwargs)
    print(f"Load models: {time.time() - start:.2f}s")

    output_base_dir = config.save_dir
    save_latents_dir = os.path.join(output_base_dir, 'latents')
    save_first_frame_dir = os.path.join(output_base_dir, 'first_frame')
    save_clip_dir = os.path.join(output_base_dir, 'meta_v1')
    for dir_path in [save_latents_dir, save_clip_dir, save_first_frame_dir]:
        os.makedirs(dir_path, exist_ok=True)

    data = read_json(config.json_path)
    for clip_data in data:
        caption_short = clip_data['short_caption']
        caption_long = clip_data['long_caption']
        if "video_path" in clip_data and clip_data['video_path']:
            video_path = clip_data["video_path"]
            base_name = clip_data["source_id"]
            refl_metafile_path = os.path.join(save_clip_dir, base_name + '_meta_v1.json')
            if not os.path.isfile(refl_metafile_path):
                vae_latent_path = os.path.join(save_latents_dir, base_name + '.npy')
                f1_black_path = os.path.join(save_latents_dir, base_name + '_f1_black.npy')
                imgclip_path = os.path.join(save_latents_dir, base_name + '_img_clip.npy')
                first_frame_path = os.path.join(save_first_frame_dir, base_name + '.jpg')
                textshort_path = os.path.join(save_latents_dir, base_name + '_textshort.npy')
                textlong_path = os.path.join(save_latents_dir, base_name + '_textlong.npy')
                try:
                    encode_kwargs = encode_video(
                        config, video_path, basic_kwargs, model_kwargs, first_frame_path
                    )
                    text_states_short = encode_caption(config, caption_short, basic_kwargs, model_kwargs)
                    text_states_long = encode_caption(config, caption_long, basic_kwargs, model_kwargs)
                    np.save(vae_latent_path, encode_kwargs["latents"].to(torch.float32).cpu().numpy())
                    np.save(f1_black_path, encode_kwargs["latents_condition"].to(torch.float32).cpu().numpy())
                    np.save(imgclip_path, encode_kwargs["image_embeds"].to(torch.float32).cpu().numpy())
                    np.save(textshort_path, text_states_short.to(torch.float32).cpu().numpy())
                    np.save(textlong_path, text_states_long.to(torch.float32).cpu().numpy())

                    meta_data = clip_data.copy()
                    meta_data.update({
                        'vae_latent_path': vae_latent_path,
                        'f1_black_path': f1_black_path,
                        'imgclip_path': imgclip_path,
                        'latent_shape': list(encode_kwargs["latents"].shape),
                        'textshort_path': textshort_path,
                        'text_states_short_shape': list(text_states_short.shape),
                        'textlong_path': textlong_path,
                        'text_states_long_shape': list(text_states_long.shape),
                    })
                    with open(refl_metafile_path, 'w', encoding='utf-8') as file:
                        json.dump(meta_data, file, indent=4, ensure_ascii=False)
                    print(f'Data processed successfully: {refl_metafile_path}')
                except Exception as e:
                    print(f'Error processing clip {base_name}: {e}')
                    traceback.print_exc()
                    continue
            else:
                print(f'Data already processed: {refl_metafile_path}')
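# Expected shape of one record in `config.json_path`, inferred from the keys read
# in `main_wan` above (values are hypothetical; any extra keys are carried over
# into the per-clip meta file unchanged):
#
#   {
#       "source_id": "clip_000001",
#       "short_caption": "A dog runs along a beach.",
#       "long_caption": "A golden retriever sprints along a sunlit beach, ...",
#       "video_path": "/data/videos/clip_000001.mp4"
#   }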
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default='', type=str)
    args = parser.parse_args()
    config = OmegaConf.load(args.config)
    main_wan(config)
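# Usage sketch (the script filename and all values below are illustrative
# assumptions, not shipped defaults). The script is driven entirely by an
# OmegaConf YAML file:
#
#   python encode_wan_latents.py --config configs/encode.yaml
#
# A minimal config must provide every field the script reads:
#
#   seed: 42
#   precision: fp16              # key into PRECISION_TO_TYPE
#   save_dir: ./output
#   json_path: ./clips.json
#   vae_path: ./ckpts/wan_vae.pth
#   image_encoder_path: ./ckpts/clip_vision.pth
#   image_processor_path: ./ckpts/clip_tokenizer
#   text_encoder_path: ./ckpts/umt5_xxl.pth
#   tokenizer_path: ./ckpts/umt5_tokenizer
#   max_sequence_length: 512
#   model_type: wanx             # forwarded to prompt2states as text_encoder_type
#   extract_fps: 16
#   num_frames: 81
#   start_idx: 0                 # first source-frame index to sample
#   resolution: [480]            # short-side target, indexed as resolution[0]
#   aspect_ratio: 2.0            # long side capped at resolution[0] * aspect_ratio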