|
|
import json |
|
|
import os |
|
|
import random |
|
|
import traceback |
|
|
from PIL import Image |
|
|
|
|
|
import torch |
|
|
import numpy as np |
|
|
from decord import VideoReader |
|
|
from easydict import EasyDict |
|
|
from einops import rearrange |
|
|
from torch.utils.data import Dataset |
|
|
from torchvision import transforms |
|
|
|
|
|
from ..utils.data_utils import align_floor_to, align_ceil_to |
|
|
from ..constants import NULL_DIR |
|
|
|
|
|
|
|
|
class Image2VideoTrainDataset(Dataset):
    """Training dataset that reads pre-computed arrays referenced by JSON records.

    Every entry of ``meta_paths`` is the path of a JSON file. Each loader
    reads that JSON and ``np.load``s the array files it points to.
    ``__getitem__`` picks the loader from ``dataset_type``:

    * ``"refl"``          -> :meth:`get_batch_lrm_refl`
    * ``"lrm_ce"``        -> :meth:`get_batch_lrm_ce`
    * ``"lrm_bt_online"`` -> :meth:`get_batch_lrm_bt_online`

    :meth:`get_batch_refl` is not dispatched by ``__getitem__``; it is only
    reachable by calling it directly.

    Args:
        task: Task name. Only the substring ``"flf2v"`` is inspected here, to
            choose which unconditional text-state file is loaded.
        dataset_type: Loader selector, see the list above.
        meta_file_list: Text files listing one JSON record path per line.
        meta_file_lose_list: Same format; fills ``meta_paths_lose``, which is
            used by :meth:`get_batch_lrm_bt_online` only.
        uncond_prob: Sequence whose first item becomes ``uncond_prompt_prob``
            and whose last item becomes ``uncond_image_prob``.
        sp_size: Stored as ``self.sp_size``; not used inside this class.
        patch_size: Stored as ``self.patch_size`` (defaults to ``[1, 2, 2]``);
            not used inside this class.
    """

    # How many samples ``__getitem__`` tries before giving up.
    MAX_RETRIES = 100
    # Probability of picking the long caption when both captions are available.
    LONG_CAPTION_PROB = 0.7

    def __init__(
        self,
        task="i2v-14b-480p",
        dataset_type="wanx",
        meta_file_list=None,
        meta_file_lose_list=None,
        uncond_prob=(0.0, 0.0),
        sp_size=1,
        patch_size=None,
    ):
        self.task = task
        self.dataset_type = dataset_type
        self.uncond_prompt_prob = uncond_prob[0]
        self.uncond_image_prob = uncond_prob[-1]
        self.sp_size = sp_size
        # A fresh list per instance avoids sharing one mutable default.
        self.patch_size = [1, 2, 2] if patch_size is None else patch_size
        self.meta_paths = self._read_meta_files(meta_file_list or [])
        # Always defined (possibly empty) so get_batch_lrm_bt_online can
        # report a clear error instead of failing with AttributeError.
        self.meta_paths_lose = self._read_meta_files(meta_file_lose_list or [])

    def __len__(self):
        """Return the number of record paths in ``meta_paths``."""
        return len(self.meta_paths)

    def __getitem__(self, idx):
        """Load sample ``idx``, falling back to random samples on failure.

        Any exception raised by the loader is printed together with its
        traceback and a random replacement index is tried, up to
        ``MAX_RETRIES`` attempts in total.

        Raises:
            NotImplementedError: ``dataset_type`` has no loader. This is a
                ``RuntimeError`` subclass, so handlers written for the
                previous "Too many bad data." error keep working.
            RuntimeError: Every attempt failed.
        """
        loaders = {
            "refl": self.get_batch_lrm_refl,
            "lrm_ce": self.get_batch_lrm_ce,
            "lrm_bt_online": self.get_batch_lrm_bt_online,
        }
        loader = loaders.get(self.dataset_type)
        if loader is None:
            # Fail fast: retrying cannot fix a configuration problem.
            raise NotImplementedError(
                f"Unsupported dataset_type {self.dataset_type!r}; "
                f"expected one of {sorted(loaders)}."
            )

        for _ in range(self.MAX_RETRIES):
            try:
                return loader(idx)
            except Exception as e:
                print(
                    f"Error details: {str(e)}-{idx}-{self.meta_paths[idx]}-{traceback.format_exc()}\n"
                )
                idx = np.random.randint(len(self.meta_paths))

        raise RuntimeError("Too many bad data.")

    # ------------------------------------------------------------------
    # Private helpers shared by the loaders
    # ------------------------------------------------------------------

    @staticmethod
    def _read_meta_files(meta_files):
        """Return the stripped, non-blank lines of every file in ``meta_files``."""
        paths = []
        for meta_file in meta_files:
            # ``with`` closes the handle; blank lines can never be record paths.
            with open(meta_file, "r") as f:
                paths.extend(line.strip() for line in f if line.strip())
        return paths

    @staticmethod
    def _load_json(path):
        """Parse the JSON file at ``path``."""
        with open(path, "r") as f:
            return json.load(f)

    @staticmethod
    def _load_first(path):
        """Return element ``[0]`` of the array stored at ``path`` as a tensor."""
        return torch.from_numpy(np.load(path)[0])

    @staticmethod
    def _load_image_embeds(path):
        """Load the array at ``path`` and merge its first two axes."""
        return rearrange(torch.from_numpy(np.load(path)), "b s d -> (b s) d")

    @staticmethod
    def _first_present_key(data_dict, candidates):
        """Return the first candidate key found in ``data_dict``.

        The last candidate is returned unconditionally, so a record that has
        none of the keys still fails with ``KeyError`` when it is indexed.
        """
        for key in candidates[:-1]:
            if key in data_dict:
                return key
        return candidates[-1]

    @staticmethod
    def _label_to_score(value):
        """Map ``"poor"``/``None`` to 0 and ``"good"`` to 1; pass anything else through."""
        if value is None or value == "poor":
            return 0
        if value == "good":
            return 1
        return value

    def _load_uncond_text_states(self):
        """Load the unconditional text states matching ``self.task``."""
        if "flf2v" in self.task:
            filename = "wanx/uncond_flf2v.npy"
        else:
            filename = "wanx/uncond.npy"
        return self._load_first(os.path.join(NULL_DIR, filename))

    def _select_caption(self, data_dict, allow_fallback):
        """Return ``(text_states_path, prompt)`` for one record.

        With both ``textshort_path`` and ``textlong_path`` present, the long
        caption is chosen with probability ``LONG_CAPTION_PROB``. Otherwise
        ``text_en_path``/``prompt`` are used when ``allow_fallback`` is true,
        and the missing key raises ``KeyError`` when it is false.
        """
        has_both = "textshort_path" in data_dict and "textlong_path" in data_dict
        if allow_fallback and not has_both:
            return data_dict["text_en_path"], data_dict["prompt"]

        text_states_path = data_dict["textshort_path"]
        text_states_path_long = data_dict["textlong_path"]
        prompt = data_dict["short_caption"]
        if random.random() <= self.LONG_CAPTION_PROB:
            text_states_path = text_states_path_long
            prompt = data_dict["long_caption"]
        return text_states_path, prompt

    def _load_refl_sample(self, idx, allow_caption_fallback):
        """Shared implementation of the two ``refl`` loaders."""
        data_dict = self._load_json(self.meta_paths[idx])

        latents_key = self._first_present_key(
            data_dict, ("video_vae_latent_path", "vae_latent_path", "latents_path")
        )
        latents = self._load_first(data_dict[latents_key])

        text_states_path, prompt = self._select_caption(
            data_dict, allow_caption_fallback
        )
        text_states = self._load_first(text_states_path)

        image_embeds = self._load_image_embeds(data_dict["imgclip_path"])

        condition_key = self._first_present_key(
            data_dict, ("f1_black_path", "latents_condition_path")
        )
        latents_condition = self._load_first(data_dict[condition_key])

        uncond_text_states = self._load_uncond_text_states()

        # With probability uncond_prompt_prob the text states are replaced by
        # the null ones; the returned prompt string is left untouched.
        if random.random() < self.uncond_prompt_prob:
            text_states = self._load_first(os.path.join(NULL_DIR, "wanx/null.npy"))

        return (
            latents,
            text_states,
            uncond_text_states,
            image_embeds,
            latents_condition,
            prompt,
        )

    # ------------------------------------------------------------------
    # Loaders
    # ------------------------------------------------------------------

    def get_batch_lrm_refl(self, idx):
        """Load one sample; used by ``__getitem__`` for ``dataset_type == "refl"``.

        Records that lack ``textshort_path``/``textlong_path`` fall back to
        ``text_en_path`` and ``prompt``.

        Returns:
            ``(latents, text_states, uncond_text_states, image_embeds,
            latents_condition, prompt)``
        """
        return self._load_refl_sample(idx, allow_caption_fallback=True)

    def get_batch_refl(self, idx):
        """Load one sample; short and long caption keys are mandatory.

        Same return value as :meth:`get_batch_lrm_refl`, but a record without
        ``textshort_path``/``textlong_path`` raises ``KeyError``.
        """
        return self._load_refl_sample(idx, allow_caption_fallback=False)

    def get_batch_lrm_ce(self, idx):
        """Load one sample together with its model name and quality labels.

        The labels ``text_alignment``, ``blur_quality``, ``physics_quality``
        and ``human_quality`` default to 0 when absent; ``"poor"``/``None``
        become 0 and ``"good"`` becomes 1.

        Returns:
            ``(latents, text_states, uncond_text_states, image_embeds,
            latents_condition, data_from_model, text_alignment, blur_quality,
            physics_quality, human_quality)``
        """
        data_dict = self._load_json(self.meta_paths[idx])

        # Unused, but the lookup is kept so that records without "source_id"
        # are rejected exactly as before.
        source_id = data_dict["source_id"]  # noqa: F841

        latents_key = self._first_present_key(
            data_dict, ("video_vae_latent_path", "vae_latent_path")
        )
        latents = self._load_first(data_dict[latents_key])

        text_key = self._first_present_key(
            data_dict, ("save_textshort_path", "textshort_path", "text_en_path")
        )
        text_states = self._load_first(data_dict[text_key])

        embeds_key = self._first_present_key(
            data_dict, ("image_embeds", "imgclip_path")
        )
        image_embeds = self._load_image_embeds(data_dict[embeds_key])

        condition_key = self._first_present_key(
            data_dict, ("f1_black_path", "latents_condition_path")
        )
        latents_condition = self._load_first(data_dict[condition_key])

        uncond_text_states = self._load_uncond_text_states()

        data_from_model = data_dict.get("model", "")
        text_alignment = self._label_to_score(data_dict.get("text_alignment", 0))
        blur_quality = self._label_to_score(data_dict.get("blur_quality", 0))
        physics_quality = self._label_to_score(data_dict.get("physics_quality", 0))
        human_quality = self._label_to_score(data_dict.get("human_quality", 0))

        return (
            latents,
            text_states,
            uncond_text_states,
            image_embeds,
            latents_condition,
            data_from_model,
            text_alignment,
            blur_quality,
            physics_quality,
            human_quality,
        )

    def get_batch_lrm_bt_online(self, idx):
        """Load sample ``idx`` plus a random sample from ``meta_paths_lose``.

        The key names are resolved on the ``meta_paths`` record and the same
        keys are then read from the "lose" record. (Presumably the pair is a
        preferred / rejected pair for preference training - confirm with the
        training code.)

        Returns:
            ``(latents, text_states, uncond_text_states, image_embeds,
            latents_condition, latents_lose, text_states_lose,
            uncond_text_states_lose, image_embeds_lose,
            latents_condition_lose)``

        Raises:
            ValueError: ``meta_paths_lose`` is empty, or the two latents have
                different shapes.
        """
        data_json_path = self.meta_paths[idx]

        if not self.meta_paths_lose:
            raise ValueError(
                "meta_paths_lose is empty. Please ensure a non-empty "
                "meta_file_lose_list is provided."
            )

        data_json_path_lose = self.meta_paths_lose[
            random.randint(0, len(self.meta_paths_lose) - 1)
        ]

        data_dict = self._load_json(data_json_path)
        data_dict_lose = self._load_json(data_json_path_lose)

        def paired_paths(candidates):
            # Resolve the key once on the main record, then read both records.
            key = self._first_present_key(data_dict, candidates)
            return data_dict[key], data_dict_lose[key]

        latents_path, latents_path_lose = paired_paths(
            ("video_vae_latent_path", "vae_latent_path")
        )
        latents = self._load_first(latents_path)
        latents_lose = self._load_first(latents_path_lose)
        # Explicit raise instead of assert: asserts vanish under ``python -O``.
        if latents.shape != latents_lose.shape:
            raise ValueError(
                f"latents.shape {latents.shape} != latents_lose.shape {latents_lose.shape}"
            )

        text_states_path, text_states_path_lose = paired_paths(
            ("save_textshort_path", "textshort_path", "text_en_path")
        )
        text_states = self._load_first(text_states_path)
        text_states_lose = self._load_first(text_states_path_lose)

        image_embeds_path, image_embeds_path_lose = paired_paths(
            ("image_embeds", "imgclip_path")
        )
        image_embeds = self._load_image_embeds(image_embeds_path)
        image_embeds_lose = self._load_image_embeds(image_embeds_path_lose)

        latents_condition_path, latents_condition_path_lose = paired_paths(
            ("f1_black_path", "latents_condition_path")
        )
        latents_condition = self._load_first(latents_condition_path)
        latents_condition_lose = self._load_first(latents_condition_path_lose)

        # Both halves use the same file; read it once and hand out an
        # independent copy for the second half.
        uncond_text_states = self._load_uncond_text_states()
        uncond_text_states_lose = uncond_text_states.clone()

        return (
            latents,
            text_states,
            uncond_text_states,
            image_embeds,
            latents_condition,
            latents_lose,
            text_states_lose,
            uncond_text_states_lose,
            image_embeds_lose,
            latents_condition_lose,
        )
|
|
|
|
|
|
|
|
class Image2VideoEvalDataset(Dataset):
    """Evaluation dataset of prompts with optional images, ids and seeds.

    Args:
        file_path: Either a ``.txt`` file with one prompt per line, or a
            ``.json`` file holding a list of records. Every JSON record must
            have ``"caption"`` and may have ``"image_id"``, ``"image_path"``,
            ``"last_image_path"`` and ``"seed"``.
        resolution: Pair of target sizes used to compute the resize scale.
        alignment: Value handed to ``align_ceil_to`` for the scaled width and
            height.
        do_scale: When false, images keep their original width and height.

    Raises:
        ValueError: An optional JSON key is present in only some records, so
            its values could not be matched to the prompts by index.
    """

    def __init__(self, file_path, resolution=(512, 512), alignment=16, do_scale=True):
        self.prompts = []
        self.image_ids = []
        self.image_paths = []
        self.last_image_paths = []
        self.seeds = []

        if file_path.endswith(".txt"):
            with open(file_path, "r") as file:
                self.prompts.extend(line.strip() for line in file)
        elif file_path.endswith(".json"):
            with open(file_path, "r") as f:
                datas = json.load(f)
            optional_fields = (
                ("image_id", self.image_ids),
                ("image_path", self.image_paths),
                ("last_image_path", self.last_image_paths),
                ("seed", self.seeds),
            )
            for data in datas:
                self.prompts.append(data["caption"].strip())
                for key, values in optional_fields:
                    if key in data:
                        values.append(data[key])
        else:
            # The dataset stays empty, as before, but no longer silently.
            print(
                f"[WARN] Unsupported prompt file {file_path!r}: "
                "expected a .txt or .json file, dataset is empty"
            )

        self._check_field_lengths()

        self.resolution = resolution
        self.alignment = alignment
        self.do_scale = do_scale

        print(f"[INFO] Load text and image dataset done, total len {len(self.prompts)}")

    def _check_field_lengths(self):
        """Reject optional fields that are present for only part of the records.

        ``__getitem__`` indexes every non-empty list with the prompt index, so
        a shorter list would pair values with the wrong prompt or run out of
        range.
        """
        expected = len(self.prompts)
        fields = (
            ("image_id", self.image_ids),
            ("image_path", self.image_paths),
            ("last_image_path", self.last_image_paths),
            ("seed", self.seeds),
        )
        for key, values in fields:
            if values and len(values) != expected:
                raise ValueError(
                    f"'{key}' is present in {len(values)} of {expected} records; "
                    "it must be given for every record or for none."
                )

    def _build_transform(self, image):
        """Build the resize + ``ToTensor`` pipeline sized from ``image``."""
        width, height = image.size
        if self.do_scale:
            scale = min(
                min(self.resolution) / min(width, height),
                max(self.resolution) / max(width, height),
            )
            target_width = align_ceil_to(int(width * scale), self.alignment)
            target_height = align_ceil_to(int(height * scale), self.alignment)
        else:
            target_width, target_height = width, height

        return transforms.Compose(
            [
                transforms.Resize((target_height, target_width)),
                transforms.ToTensor(),
            ]
        )

    def __len__(self):
        """Return the number of prompts."""
        return len(self.prompts)

    def __getitem__(self, index):
        """Return the evaluation record at ``index``.

        Returns:
            Dict with keys ``prompt``, ``image``, ``last_image``, ``image_id``,
            ``image_path`` and ``seed``. ``image``, ``last_image`` and
            ``image_path`` are ``""`` when the dataset has no such entries, and
            ``seed`` defaults to 42. When seeds are available, ``image_id``
            gets the suffix ``_seed_<seed>``.
        """
        prompt = self.prompts[index]
        transform = None

        if self.image_paths:
            image_path = self.image_paths[index]
            # File name up to its first dot.
            image_id = image_path.split("/")[-1].split(".")[0]
            image = Image.open(image_path).convert("RGB")
            transform = self._build_transform(image)
            image = transform(image)
        else:
            image_path = ""
            image = ""
            image_id = str(index)

        # An explicit id from the JSON overrides the derived one.
        if self.image_ids:
            image_id = self.image_ids[index]

        if self.last_image_paths:
            last_image = Image.open(self.last_image_paths[index]).convert("RGB")
            if transform is None:
                # No first image: size the transform from the last image
                # itself (this path used to raise NameError).
                transform = self._build_transform(last_image)
            # Otherwise the first image's transform is reused, so both images
            # end up with the same size.
            last_image = transform(last_image)
        else:
            last_image = ""

        if self.seeds:
            seed = self.seeds[index]
            # Formatting instead of ``+=`` also works for numeric JSON ids.
            image_id = f"{image_id}_seed_{seed}"
        else:
            seed = 42

        return {
            "prompt": prompt,
            "image": image,
            "last_image": last_image,
            "image_id": image_id,
            "image_path": image_path,
            "seed": seed,
        }
|
|
|
|
|
|