# nifty-lab / data /datasets_custom /validation_dataset.py
# IgorCSIS
# Initial Space deploy, ZeroGPU adapter for Lance
# e18ede0
# Raw
# History Blame Contribute Delete
# 49.5 kB
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# coding: utf-8
import json
import os
from typing import Any, Dict, List
import numpy as np
import torch
from torch.utils.data import Dataset
import decord
from decord import VideoReader
from PIL import Image
from data.video.sampler.utils import FRAME_SAMPLER_TYPES
from data.video.sampler.frames import FrameSamplerOutput
from data.transforms import VideoTransform
from data.data_utils import (
get_flattened_position_ids_extrapolate_video,
len2weight,
patchify_video_with_merge,
)
from data.system_prompt_render import render_qwenvl_prompt, expand_and_index_by_token_ids_new
from data.common import generate_system_prompt
from modeling.qwen2 import Qwen2Tokenizer
from config.config_factory import ModelArguments, DataArguments, TrainingArguments
# Integer code per generation task. `tiv2v_sample` fills the per-token
# "sample_task" tensor with the 'edit' / 'idip' codes from this table.
sample_task_map = dict(
    t2v=0,
    idip=1,
    edit=2,
    refedit=3,
)

# Integer tag stored per token in the "sample_modality" sequence.
modality_map = dict(
    system_prompt=-1,
    text=0,
    noise=1,
    ref_source=2,
    ref_image=3,
    ref_vit=4,
)
class ValidationDataset(Dataset):
def __init__(
    self,
    jsonl_path: str,
    tokenizer: Qwen2Tokenizer,
    data_args: DataArguments,
    model_args: ModelArguments,
    training_args: TrainingArguments,
    new_token_ids: Dict[str, int],
    dataset_config: None,
    local_rank: int = 0,
    world_size: int = 1,
):
    """Initialise the validation dataset.

    Args:
        jsonl_path: Path to the prompt file. It is read as JSONL first; if a
            line is not valid JSON the whole file is parsed as one JSON
            document instead. A top-level JSON object is expanded into
            ``{"index": ..., "data": ...}`` records.
        tokenizer: Tokenizer used to encode prompts.
        data_args: Not read by this constructor.
        model_args: Not read by this constructor.
        training_args: Not read by this constructor.
        new_token_ids: Special-token ids; must contain ``bos_token_id``,
            ``eos_token_id``, ``start_of_image``, ``end_of_image`` and
            ``image_token_id``.
        dataset_config: Config object. ``resolution`` is required;
            ``max_duration`` (default 6.0) and ``system_prompt_type``
            (default ``'SP0'``) are optional.
        local_rank: Index of this process; selects every ``world_size``-th
            record starting at ``local_rank``.
        world_size: Number of processes sharing the file.

    Raises:
        ValueError: If ``dataset_config.resolution`` is not a known preset.
    """
    self.jsonl_path = jsonl_path
    self.tokenizer = tokenizer
    self.new_token_ids = new_token_ids

    try:
        full_data = self._read_jsonl()
    except ValueError:
        # json.JSONDecodeError (and UnicodeDecodeError) are ValueErrors: the
        # file is not line-delimited JSON, so parse it as a single document.
        # Anything else (missing file, Ctrl-C, ...) is allowed to propagate.
        with open(jsonl_path, 'r', encoding='utf-8') as f:
            full_data = json.load(f)
    if isinstance(full_data, dict):
        full_data = [{"index": self.pro_index(index), "data": prompt} for index, prompt in full_data.items()]

    # Strided sharding across ranks.
    if world_size > 1:
        self.data = full_data[local_rank::world_size]
        print(f"Rank {local_rank}/{world_size} will process {len(self.data)} samples")
    else:
        self.data = full_data

    self.data_config = dataset_config
    self.bos_token_id = self.new_token_ids["bos_token_id"]
    self.eos_token_id = self.new_token_ids["eos_token_id"]
    self.start_of_image = self.new_token_ids["start_of_image"]
    self.end_of_image = self.new_token_ids["end_of_image"]
    self.image_token_id = self.new_token_ids["image_token_id"]

    # Optional setting; also falls back to 6.0 when no config object is given.
    max_duration = getattr(self.data_config, "max_duration", 6.0)
    video_frame_sampler_params = {"temporal": 4, "sample_fps": 12, "max_duration": max_duration, "assert_seconds": False, "truncate": False}
    self.frame_sampler = FRAME_SAMPLER_TYPES["multi_clips"](**video_frame_sampler_params)
    self.cpu_count = os.cpu_count() or 1

    # Resolution preset -> (VAE transform resolution, ViT transform resolution).
    resolution_presets = {
        "video_192p": (256, 224),
        "image_256res": (256, 224),
        "image_512res": (512, 448),
        "image_768res": (768, 672),
        "video_360p": (480, 476),
        "video_480p": (640, 616),
    }
    if self.data_config.resolution not in resolution_presets:
        raise ValueError(f"Unknown resolution: {self.data_config.resolution}")
    resolution_vae, resolution_vit = resolution_presets[self.data_config.resolution]

    video_transform_args = {
        "resolution": resolution_vae,
        "mode": "bucket",
        "divisible_crop_size": 16,
        "stride_spatial": 16,
        "stride_temporal": 4,
        "aspect_ratios": ["21:9", "16:9", "4:3", "1:1", "3:4", "9:16"],
        "mean": 0.5,
        "std": 0.5,
    }
    self.transform = VideoTransform(**video_transform_args)

    vit_video_transform_args = {
        "resolution": resolution_vit,
        "mode": "bucket",
        "divisible_crop_size": 28,
        "aspect_ratios": ["21:9", "16:9", "4:3", "1:1", "3:4", "9:16"],
        "mean": [0.48145466, 0.4578275, 0.40821073],
        "std": [0.26862954, 0.26130258, 0.27577711],
    }
    self.vit_transform = VideoTransform(**vit_video_transform_args)

    self.sample = self.set_sequence_status()
    # No conditioning frames by default; read by `process_vae_video`.
    self.frame_condition_idx = []
    self.system_prompt_type = getattr(self.data_config, 'system_prompt_type', 'SP0')
def pro_index(self, index: int):
    """Turn a record key into an integer index.

    String keys have every occurrence of '.mp4', '.jpg', '.png' and '.jpeg'
    removed before conversion; any other value is passed straight to
    ``int``.
    """
    if not isinstance(index, str):
        return int(index)
    stem = index
    for suffix in ('.mp4', '.jpg', '.png', '.jpeg'):
        stem = stem.replace(suffix, "")
    return int(stem)
def set_sequence_status(self):
    """Return a fresh, empty packing state.

    The state is a dict with an integer cursor ``curr`` (0) and one new,
    independent empty list per accumulator key.
    """
    accumulator_keys = (
        "sample_lens",
        "sample_type",
        "sample_N_target",
        "packed_position_ids",
        "nested_attention_masks",
        "split_lens",
        "attn_modes",
        "packed_text_ids",
        "packed_text_indexes",
        "packed_label_ids",
        "ce_loss_indexes",
        "ce_loss_weights",
        "vae_image_tensors",
        "vae_video_tensors",
        "packed_latent_position_ids",
        "vae_latent_shapes",
        "packed_vae_token_indexes",
        "packed_timesteps",
        "mse_loss_indexes",
        "packed_vit_tokens",
        "vit_token_seqlens",
        "packed_vit_position_ids",
        "packed_vit_token_indexes",
        "vit_video_grid_thw",
        "vae_video_grid_thw",
        "video_grid_thw",
        "vit_video_tensors",
        "vae_video_latent",
        "vae_data_mode",
        "vit_data_mode",
        "sample_task",
        "sample_modality",
    )
    sequence_status = {"curr": 0}
    for key in accumulator_keys:
        sequence_status[key] = []
    return sequence_status
def _read_jsonl(self) -> List[Dict[str, Any]]:
    """Read the JSONL file at ``self.jsonl_path``.

    Each non-blank line is parsed as one JSON value. Blank or
    whitespace-only lines (for example trailing newlines at the end of the
    file) are skipped instead of being handed to ``json.loads``, which
    rejects an empty string.

    Returns:
        The parsed values, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    records = []
    with open(self.jsonl_path, "r", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            records.append(json.loads(line))
    return records
def __len__(self) -> int:
    """Number of records held by this dataset (after rank sharding)."""
    record_count = len(self.data)
    return record_count
@staticmethod
def _read_decord(video: VideoReader, frame_idx: List[int]) -> List[Image.Image]:
    """Decode the requested frames and return them as PIL images.

    Frames are fetched in one ``get_batch`` call, converted to a NumPy
    array, and each entry along the first axis is wrapped with
    ``Image.fromarray``.
    """
    decoded_batch = video.get_batch(frame_idx).asnumpy()
    pil_frames = []
    for frame_array in decoded_batch:
        pil_frames.append(Image.fromarray(frame_array))
    return pil_frames
def get_video_tensor_online(self, media_url, vision_stream, worker_id=0, element_dtype="image") -> torch.Tensor:
    """Load an image or video and run it through the requested transform.

    Args:
        media_url: Path (or other source accepted by ``Image.open`` /
            ``VideoReader``) of the media file.
        vision_stream: ``"vae_video"`` applies ``self.transform``;
            ``"vit_video"`` applies ``self.vit_transform``.
        worker_id: Used modulo ``self.cpu_count`` to pick the decord CPU
            context for video decoding.
        element_dtype: ``"image"`` loads a single picture; any other value
            is treated as a video.

    Returns:
        The transformed tensor. For ``"vit_video"`` an image is repeated
        twice along dim 1, and a tensor whose dim 1 is odd gets its last
        slice duplicated so that dim 1 becomes even.

    Raises:
        ValueError: If ``vision_stream`` is neither of the two names above
            (raised only after the media has been loaded).
    """
    self.vision_stream = vision_stream

    if element_dtype != "image":
        reader = VideoReader(media_url, ctx=decord.cpu(worker_id % self.cpu_count))
        # One clip spanning the whole video; fps is a fixed 24 here rather
        # than read from the file.
        sampler_input = {
            "clip_indices": [(0, len(reader))],
            "fps": 24,
        }
        sampled: FrameSamplerOutput = self.frame_sampler(sampler_input)
        frames = self._read_decord(reader, sampled.indices)
    else:
        picture = Image.open(media_url)
        if picture.mode == "P":
            picture = picture.convert("RGBA")
        if picture.mode != "RGBA":
            picture = picture.convert("RGB")
        else:
            # Composite onto white using the alpha channel as mask.
            canvas = Image.new("RGB", picture.size, (255, 255, 255))
            canvas.paste(picture, mask=picture.split()[3])
            picture = canvas
        frames = [picture]

    if vision_stream == "vae_video":
        return self.transform(frames)
    if vision_stream != "vit_video":
        raise ValueError(f"Unknown vision_stream: {vision_stream}")

    clip = self.vit_transform(frames)
    if element_dtype == "image":
        clip = clip.repeat(1, 2, 1, 1)
    if clip.shape[1] % 2 == 1:
        clip = torch.cat([clip, clip[:, -1:, :, :]], dim=1)
    return clip
def process_vit_video(self, video_tensor, curr: int, curr_rope_id: int, curr_split_len: int, curr_video_grid_thw: None, item_loss=0):
    """Pack one ViT-branch clip into ``self.sample``.

    The patch tokens, grid shape and token count are always recorded. When
    ``data_config.text_template`` is falsy the clip is also laid out inline
    as ``<start_of_image> image_token * N <end_of_image>``, with one shared
    position id and a "full" attention split.

    NOTE(review): the source this was recovered from had lost its
    indentation; the extent of the final ``text_template`` guard is inferred
    from the fact that ``process_text_template`` fills the same lists in
    template mode. Please confirm.

    Args:
        video_tensor: Clip tensor; dims 1/2/3 are read as t/h/w.
        curr: Next free index in the packed sequence.
        curr_rope_id: Next position id.
        curr_split_len: Running length of the current split.
        curr_video_grid_thw: List that receives this clip's grid shape.
        item_loss: Unused here.

    Returns:
        ``(self.sample, curr, curr_rope_id, curr_split_len,
        curr_video_grid_thw, num_video_tokens)``.
    """
    state = self.sample
    cfg = self.data_config
    inline_tokens = not cfg.text_template

    if inline_tokens:
        # Opening marker.
        state["packed_text_ids"].append(self.start_of_image)
        state["packed_text_indexes"].append(curr)
        curr += 1
        curr_split_len += 1

    if isinstance(video_tensor, torch.Tensor):
        state["vit_video_tensors"].append(video_tensor)
        patches = patchify_video_with_merge(
            video_tensor, cfg.vit_patch_size, cfg.vit_patch_size_temporal
        )
        # presumably 4 == 2x2 spatial merge of patches — TODO confirm
        num_video_tokens = patches.shape[0] // 4
        t, h, w = video_tensor.size(1), video_tensor.size(2), video_tensor.size(3)
        state["packed_vit_tokens"].append(patches)
        state["vit_data_mode"].append("online")

    if t is not None:
        grid_thw = [
            t // cfg.vit_patch_size_temporal,
            h // cfg.vit_patch_size,
            w // cfg.vit_patch_size,
        ]
        state["vit_video_grid_thw"].append(grid_thw)
        curr_video_grid_thw.append(grid_thw)

    state["vit_token_seqlens"].append(num_video_tokens)
    state["packed_vit_position_ids"].append(
        torch.zeros(num_video_tokens)
    )

    if inline_tokens:
        state["packed_vit_token_indexes"].extend(range(curr, curr + num_video_tokens))
        curr += num_video_tokens
        curr_split_len += num_video_tokens
        state["packed_text_ids"].extend([self.image_token_id] * num_video_tokens)
        # Closing marker.
        state["packed_text_ids"].append(self.end_of_image)
        state["packed_text_indexes"].append(curr)
        curr += 1
        curr_split_len += 1
        # The whole visual split shares a single position id.
        state["packed_position_ids"].extend([curr_rope_id] * curr_split_len)
        curr_rope_id += 1
        state["attn_modes"].append("full")
        state["split_lens"].append(curr_split_len)

    return state, curr, curr_rope_id, curr_split_len, curr_video_grid_thw, num_video_tokens
def process_text(self, caption: str, curr: int, curr_rope_id: int, curr_split_len: int, item_loss=0):
    """Pack a text segment as ``<bos> caption-tokens <eos>``.

    The segment is registered as one "causal" split with consecutive
    position ids. When ``item_loss == 1`` the bos+caption positions are also
    recorded as cross-entropy targets (labels are the caption tokens
    followed by eos).

    Returns:
        ``(self.sample, curr, curr_rope_id, curr_split_len)`` after the
        segment has been appended.
    """
    state = self.sample
    caption_ids = self.tokenizer.encode(caption)
    prefixed_ids = [self.bos_token_id] + caption_ids
    n_prefixed = len(prefixed_ids)

    state["packed_text_ids"].extend(prefixed_ids)
    state["packed_text_indexes"].extend(range(curr, curr + n_prefixed))

    if item_loss == 1:
        state["ce_loss_indexes"].extend(range(curr, curr + n_prefixed))
        state["ce_loss_weights"].extend([len2weight(n_prefixed)] * n_prefixed)
        state["packed_label_ids"].extend(caption_ids + [self.eos_token_id])

    curr += n_prefixed
    curr_split_len += n_prefixed

    # Closing <|im_end|> token.
    state["packed_text_ids"].append(self.eos_token_id)
    state["packed_text_indexes"].append(curr)
    curr += 1
    curr_split_len += 1

    state["attn_modes"].append("causal")
    state["packed_position_ids"].extend(range(curr_rope_id, curr_rope_id + curr_split_len))
    curr_rope_id += curr_split_len
    state["split_lens"].append(curr_split_len)
    return state, curr, curr_rope_id, curr_split_len
def process_vae_video(self, video_tensor, curr: int, curr_rope_id: int, curr_split_len: int, curr_video_grid_thw: None, video_sizes: list, item_loss=0):
    """Pack one VAE-branch clip (condition or generation target).

    Latent shape, grid shape, latent position ids and per-token timesteps
    are always recorded. When ``data_config.text_template`` is falsy the
    clip is also laid out inline as
    ``<start_of_image> image_token * N <end_of_image>``.

    With ``item_loss == 1`` the clip is a target: one random timestep is
    drawn for all tokens and the token positions become MSE-loss indexes,
    except for frames listed in ``self.frame_condition_idx``, which get the
    most negative finite float as timestep and are excluded from the loss.
    Otherwise every token gets timestep ``-inf``.

    NOTE(review): the source this was recovered from had lost its
    indentation. The final ``text_template`` guard is assumed to cover the
    attention-mode/position-id bookkeeping (it uses ``num_special_tokens``,
    which only exists in that mode), and ``video_sizes.append`` is assumed
    to be unconditional. Please confirm.

    Args:
        video_tensor: Clip tensor whose shape unpacks as ``(_, T, H, W)``.
        curr: Next free index in the packed sequence.
        curr_rope_id: Next position id.
        curr_split_len: Running length of the current split.
        curr_video_grid_thw: List that receives this clip's grid shape.
        video_sizes: List that receives ``[T, H, W]``.
        item_loss: 1 for a target clip, anything else for a condition clip.

    Returns:
        ``(self.sample, curr, curr_rope_id, curr_split_len,
        curr_video_grid_thw, video_sizes, num_vid_tokens)``.
    """
    # The module's import block does not import `sys`, so the conditioning
    # branch below used to fail with NameError; import it locally.
    import sys

    state = self.sample
    inline_tokens = not self.data_config.text_template

    if inline_tokens:
        num_special_tokens = 0
        state["packed_text_ids"].append(self.start_of_image)
        state["packed_text_indexes"].append(curr)
        curr += 1
        curr_split_len += 1
        num_special_tokens += 1

    if isinstance(video_tensor, torch.Tensor):
        state["vae_video_tensors"].append(video_tensor)
        _, T, H, W = video_tensor.shape
        _T, _H, _W = self.data_config.vae_downsample
        # Latent extent after VAE downsampling.
        t = (T - 1) // _T + 1
        h = H // _H
        w = W // _W
        state["vae_data_mode"].append("online")

    spatial_merge_size = 2
    vae_video_grid_thw = [
        t,
        h * spatial_merge_size,
        w * spatial_merge_size,
    ]
    state["vae_video_grid_thw"].append(vae_video_grid_thw)
    curr_video_grid_thw.append(vae_video_grid_thw)
    state["vae_latent_shapes"].append((t, h, w))
    packed_latent_position_ids = get_flattened_position_ids_extrapolate_video(t, h, w, max_latent_size=self.data_config.max_latent_size)
    state["packed_latent_position_ids"].append(packed_latent_position_ids)

    num_vid_tokens = t * h * w
    tokens_per_frame = h * w
    if inline_tokens:
        state["packed_vae_token_indexes"].extend(range(curr, curr + num_vid_tokens))

    if item_loss == 1:
        timestep = np.random.randn()
        frame_condition_idx = self.frame_condition_idx
        packed_timesteps = [timestep] * num_vid_tokens
        mse_loss_indexes = list(range(curr, curr + num_vid_tokens))
        frame_condition_indexes = []
        for idx in frame_condition_idx:
            if idx == -1:
                idx = t - 1
            if idx == 1:
                continue
            start, stop = idx * tokens_per_frame, (idx + 1) * tokens_per_frame
            frame_condition_indexes.extend(mse_loss_indexes[start:stop])
            # Conditioning frames are marked with the lowest finite float.
            packed_timesteps[start:stop] = [-sys.float_info.max] * tokens_per_frame
        if frame_condition_idx:
            mse_loss_indexes = sorted(set(mse_loss_indexes) - set(frame_condition_indexes))
        if inline_tokens:
            state["mse_loss_indexes"].extend(mse_loss_indexes)
    else:
        timestep = float("-inf")
        packed_timesteps = [timestep] * num_vid_tokens
    state["packed_timesteps"].extend(packed_timesteps)

    if inline_tokens:
        curr += num_vid_tokens
        curr_split_len += num_vid_tokens
        state["packed_text_ids"].extend([self.image_token_id] * num_vid_tokens)
        # Closing <|endofimage|> token.
        state["packed_text_ids"].append(self.end_of_image)
        state["packed_text_indexes"].append(curr)
        curr += 1
        curr_split_len += 1
        num_special_tokens += 1
        # Update sequence status.
        if item_loss == 1:
            state["attn_modes"].append("noise")
        else:
            state["attn_modes"].append("full_noise")
        state["packed_position_ids"].extend([curr_rope_id] * (num_vid_tokens + num_special_tokens))
        curr_rope_id += 1
        state["split_lens"].append(curr_split_len)

    video_sizes.append([T, H, W])
    return state, curr, curr_rope_id, curr_split_len, curr_video_grid_thw, video_sizes, num_vid_tokens
def process_text_template(
    self,
    text_ids,
    spans_index,
    tgt_index,
    caption_index,
    video_types: list[str],
    curr: int,
    curr_rope_id: int,
    curr_split_len: int,
    item_loss=0,
):
    """Index a fully rendered chat-template token sequence.

    ``text_ids`` is appended to the packed ids as a whole. For every visual
    span in ``spans_index`` the text before it becomes one or more "causal"
    splits and the span itself (plus the token on either side of it)
    becomes a visual split whose attention mode and modality tag depend on
    the matching entry of ``video_types``. Text after the last span forms a
    final "causal" split. Positions listed in ``caption_index`` are
    re-tagged as ``modality_map["text"]`` at the end.

    Note that ``self.sample["sample_lens"]`` is set to an ``int`` here.

    NOTE(review): the source this was recovered from had lost its
    indentation; the ``else: raise`` is assumed to belong to the outer
    vit/vae dispatch rather than the inner cond/target one. Please confirm.

    Returns:
        ``(self.sample, curr, curr_rope_id, curr_split_len)``.

    Raises:
        ValueError: For a ``video_types`` entry that is not handled.
    """
    state = self.sample
    state["packed_text_ids"].extend(text_ids)
    state["sample_lens"] = len(text_ids)
    base = curr  # offset of this template inside the packed sequence

    for video_id, span_index in enumerate(spans_index):
        kind = video_types[video_id]
        vision_start, vision_end = base + span_index[0], base + span_index[-1]
        state["packed_text_indexes"].extend(range(curr, vision_start))

        if (vision_start - 1) - curr != 0:
            # Text run in front of the span (excluding the token right
            # before the span, which is counted with the visual split).
            curr_split_len = (vision_start - 1) - curr
            state["packed_position_ids"].extend(
                range(curr_rope_id, curr_rope_id + curr_split_len)
            )
            curr_rope_id += curr_split_len
            state["sample_modality"].extend([modality_map["system_prompt"]] * curr_split_len)
            if caption_index != [] and caption_index[0] in range(curr, curr + curr_split_len):
                # Cut the run into before-caption / caption / after-caption.
                before_len = caption_index[0] - curr
                caption_len = len(caption_index)
                after_len = curr_split_len - before_len - caption_len
                pieces = [x for x in [before_len, caption_len, after_len] if x != 0]
                state["attn_modes"].extend(["causal"] * len(pieces))
                state["split_lens"].extend(pieces)
            else:
                state["attn_modes"].append("causal")
                state["split_lens"].append(curr_split_len)

        curr_split_len = len(span_index) + 2
        if kind == "vit_video":
            state["packed_vit_token_indexes"].extend(range(vision_start, vision_end + 1))
            state["attn_modes"].append("full")
            state["sample_modality"].extend([modality_map["ref_vit"]] * curr_split_len)
        elif "vae_video" in kind:
            state["packed_vae_token_indexes"].extend(range(vision_start, vision_end + 1))
            if "cond" in kind:
                state["attn_modes"].append("full_noise")
                if self.sample_task == "edit":
                    state["sample_modality"].extend([modality_map["ref_source"]] * curr_split_len)
                elif self.sample_task == "idip":
                    state["sample_modality"].extend([modality_map["ref_image"]] * curr_split_len)
            elif "target" in kind:
                state["mse_loss_indexes"].extend(range(vision_start, vision_end + 1))
                state["attn_modes"].append("noise")
                state["sample_modality"].extend([modality_map["noise"]] * curr_split_len)
        else:
            raise ValueError(f"video_types {video_types[video_id]} not supported")

        # The visual split shares a single position id.
        state["packed_position_ids"].extend([curr_rope_id] * curr_split_len)
        state["split_lens"].append(len(span_index) + 2)
        curr = vision_end + 1
        curr_rope_id += 1
        state["packed_text_indexes"].append(curr)
        curr += 1

    len_split_last = state["sample_lens"] - (curr - base) if spans_index != [] else len(text_ids)
    if len_split_last != 0:
        state["split_lens"].append(len_split_last)
        state["packed_text_indexes"].extend(range(curr, curr + len_split_last))
        state["packed_position_ids"].extend(range(curr_rope_id, curr_rope_id + len_split_last))
        state["attn_modes"].append("causal")
        state["sample_modality"].extend([modality_map["system_prompt"]] * len_split_last)

    if item_loss == 1:
        state["packed_label_ids"].extend(text_ids[tgt_index[0] :])
        label_positions = np.asarray(tgt_index, dtype=np.int64) + base
        # Loss is read one position before each label token.
        state["ce_loss_indexes"].extend((label_positions - 1).tolist())
        state["ce_loss_weights"].extend([len2weight(len(label_positions))] * (len(label_positions)))

    if caption_index != []:
        state["sample_modality"][caption_index[0] : caption_index[-1] + 1] = [modality_map["text"]] * (caption_index[-1] - caption_index[0] + 1)

    base += len(text_ids)
    curr = base
    return state, curr, curr_rope_id, curr_split_len
def process_und_template(self, system_prompt, user_prompt, answer, vit_video_tensor):
    """Pack an understanding sample: system/user prompt, one visual, answer.

    Layout written to ``self.sample`` (four splits):
      1. "causal": ``<|im_start|>system\\n{system}<|im_end|>\\n<|im_start|>user\\n``
      2. "full":   ``<start_of_image> image_token * N <end_of_image>``
      3. "causal": ``{user}<|im_end|>\\n<|im_start|>assistant``
      4. "causal": ``\\n{answer}`` + eos, whose positions are also the
         cross-entropy targets.

    Returns:
        ``(sample_lens, curr_video_grid_thw)`` where ``sample_lens`` is the
        number of packed text ids and the second item lists the ViT grid
        shape of the visual.
    """
    state = self.sample
    curr = 0
    curr_rope_id = 0
    curr_video_grid_thw = []

    def push_causal_split(token_ids):
        """Append ``token_ids`` as one causal split with consecutive positions."""
        nonlocal curr, curr_rope_id
        count = len(token_ids)
        state["packed_text_ids"].extend(token_ids)
        state["packed_text_indexes"].extend(range(curr, curr + count))
        state["attn_modes"].append("causal")
        state["packed_position_ids"].extend(range(curr_rope_id, curr_rope_id + count))
        state["split_lens"].append(count)
        curr += count
        curr_rope_id += count

    # 1) System prompt and the opening of the user turn.
    prompt_prefix = "<|im_start|>system\n" + system_prompt + "<|im_end|>\n<|im_start|>user\n"
    push_causal_split(self.tokenizer.encode(prompt_prefix))

    # 2) Visual split.
    state["packed_text_ids"].append(self.start_of_image)
    state["packed_text_indexes"].append(curr)
    curr += 1
    vision_split_len = 1

    if isinstance(vit_video_tensor, torch.Tensor):
        state["vit_video_tensors"].append(vit_video_tensor)
        patches = patchify_video_with_merge(
            vit_video_tensor, self.data_config.vit_patch_size, self.data_config.vit_patch_size_temporal
        )
        num_video_tokens = patches.shape[0] // 4
        t, h, w = vit_video_tensor.size(1), vit_video_tensor.size(2), vit_video_tensor.size(3)
        state["packed_vit_tokens"].append(patches)
        state["vit_data_mode"].append("online")

    if t is not None:
        grid_thw = [
            t // self.data_config.vit_patch_size_temporal,
            h // self.data_config.vit_patch_size,
            w // self.data_config.vit_patch_size,
        ]
        state["vit_video_grid_thw"].append(grid_thw)
        curr_video_grid_thw.append(grid_thw)

    state["vit_token_seqlens"].append(num_video_tokens)
    state["packed_vit_position_ids"].append(
        torch.zeros(num_video_tokens)
    )
    state["packed_vit_token_indexes"].extend(range(curr, curr + num_video_tokens))
    curr += num_video_tokens
    vision_split_len += num_video_tokens
    # Placeholder ids for the visual tokens.
    state["packed_text_ids"].extend([self.image_token_id] * num_video_tokens)
    # Closing <|endofimage|> token.
    state["packed_text_ids"].append(self.end_of_image)
    state["packed_text_indexes"].append(curr)
    curr += 1
    vision_split_len += 1
    state["attn_modes"].append("full")
    # The whole visual split shares a single position id.
    state["packed_position_ids"].extend([curr_rope_id] * vision_split_len)
    state["split_lens"].append(vision_split_len)
    curr_rope_id += 1

    # 3) Rest of the user turn and the opening of the assistant turn.
    prompt_postfix = user_prompt + "<|im_end|>\n<|im_start|>assistant"
    push_causal_split(self.tokenizer.encode(prompt_postfix))

    # 4) Answer; every answer position (including eos) is a CE target.
    answer_ids = self.tokenizer.encode("\n" + answer) + [self.eos_token_id]
    answer_len = len(answer_ids)
    state["ce_loss_indexes"].extend(range(curr, curr + answer_len))
    state["ce_loss_weights"].extend([len2weight(answer_len)] * answer_len)
    state["packed_label_ids"].extend(answer_ids)
    push_causal_split(answer_ids)

    return len(state["packed_text_ids"]), curr_video_grid_thw
def _finalize_sample(self, sample_lens, curr_video_grid_thw, sample_type, sample=None, additional_fields=None, video_sizes=None):
    """Convert the accumulated lists in ``self.sample`` to tensors.

    Args:
        sample_lens: Total packed length; stored as a one-element list.
        curr_video_grid_thw: Grid shapes of this sample's visuals; stored as
            a tensor with an extra leading dimension.
        sample_type: Stored as a one-element list.
        sample: Optional raw record; its "index", "category", "question"
            and "gt" entries are copied over when present.
        additional_fields: Optional mapping merged into the result last.
        video_sizes: Optional sizes; when ``None`` an existing
            "video_sizes" entry is converted instead.

    Returns:
        ``self.sample`` (mutated in place).
    """
    packed = self.sample
    packed["sample_lens"] = [sample_lens]
    packed["video_grid_thw"] = torch.tensor([curr_video_grid_thw])

    always_tensor = (
        "packed_text_ids",
        "packed_text_indexes",
        "packed_vae_token_indexes",
        "packed_position_ids",
        "vae_video_grid_thw",
        "vit_video_grid_thw",
        "packed_vit_token_indexes",
    )
    for name in always_tensor:
        packed[name] = torch.tensor(packed[name])

    packed["sample_N_target"] = torch.tensor([[1]])
    packed["sample_type"] = [sample_type]
    packed["padded_videos"] = packed["vae_video_tensors"]

    # CE indexes are only converted when there is at least one.
    if len(packed.get("ce_loss_indexes", [])) > 0:
        packed["ce_loss_indexes"] = torch.tensor(packed["ce_loss_indexes"])
    # NOTE(review): assumed unconditional; indentation was lost in the source.
    packed["mse_loss_indexes"] = torch.tensor(packed["mse_loss_indexes"])

    if video_sizes is not None:
        packed["video_sizes"] = torch.tensor(video_sizes)
    elif "video_sizes" in packed:
        packed["video_sizes"] = torch.tensor(packed["video_sizes"])

    if len(packed.get("sample_modality", [])) > 0:
        packed["sample_modality"] = torch.tensor(packed["sample_modality"])

    if sample is not None:
        packed.update({key: sample[key] for key in ("index", "category", "question", "gt") if key in sample})
    if additional_fields is not None:
        for key, value in additional_fields.items():
            packed[key] = value
    return packed
def ti2t_sample(self, idx: int) -> Dict[str, Any]:
    """Build an image-understanding sample (text + image in, text out).

    The record at ``idx`` must provide "system_prompt", "user_prompt",
    "gt" (the reference answer) and "image_path". The image goes through
    the ViT transform, the conversation is packed by
    ``process_und_template``, and the prompts/path are echoed back in the
    result (the user prompt also under "instruction").
    """
    self.sample = self.set_sequence_status()
    record = self.data[idx]
    system_prompt = record["system_prompt"]
    user_prompt = record["user_prompt"]
    reference_answer = record["gt"]
    image_path = record["image_path"]

    image_tensor = self.get_video_tensor_online(image_path, vision_stream="vit_video", element_dtype="image")
    sample_lens, curr_video_grid_thw = self.process_und_template(
        system_prompt=system_prompt,
        user_prompt=user_prompt,
        answer=reference_answer,
        vit_video_tensor=image_tensor,
    )

    self.sample.update(
        system_prompt=system_prompt,
        user_prompt=user_prompt,
        image_path=image_path,
        instruction=user_prompt,
    )
    return self._finalize_sample(
        sample_lens, curr_video_grid_thw,
        sample_type="und",
        sample=record
    )
def t2v_sample(self, idx: int) -> Dict[str, Any]:
    """Build the packed inputs for one text-to-video / text-to-image prompt.

    No media is loaded: the number of target latent tokens is derived from
    ``data_config`` (``H``, ``W``, ``num_frames``, ``vae_downsample``; a
    single frame when ``task == "t2i"``). The prompt is
    ``"".join(sample["prompt_en"][0])`` when "prompt_en" is present,
    otherwise ``sample["data"]``.

    With ``data_config.text_template`` the sequence is rendered through the
    chat template and indexed by ``process_text_template``; otherwise it is
    laid out directly as
    ``<bos> prompt <eos> <start_of_image> image_token * N <end_of_image>``.

    Returns:
        A dict of packed tensors and metadata for this prompt.
    """
    cfg = self.data_config
    use_template = cfg.text_template

    _T, _H, _W = cfg.vae_downsample
    if cfg.task == "t2i":
        t, t_, element_dtype = 1, 1, 'image'
    else:
        t = (cfg.num_frames - 1) // _T + 1
        t_ = cfg.num_frames
        element_dtype = 'video'

    self.sample = self.set_sequence_status()
    packed_text_indexes, packed_position_ids, sample_modality = [], [], []
    sample = self.data[idx]
    if "prompt_en" in sample:
        user_prompt = "".join(sample["prompt_en"][0])
    else:
        user_prompt = sample["data"]

    if use_template:
        caption_instruction = generate_system_prompt(system_prompt_type=cfg.task, vision_type=element_dtype)
        text_template_user, text_template_assistant = [], []
        if self.system_prompt_type == 'SP2':
            # Task instruction moves into the user turn.
            user_prompt = caption_instruction + " " + user_prompt
            caption_instruction = "You are a helpful assistant. "
        elif self.system_prompt_type == 'SP1':
            caption_instruction = "You are a helpful assistant. " + caption_instruction
        text_template_user.append({"type": "text", "text": user_prompt})
    else:
        text_ids = self.tokenizer.encode(user_prompt)
        text_ids = [self.new_token_ids["bos_token_id"]] + text_ids + [self.new_token_ids["eos_token_id"]]
        text_split_len = len(text_ids)
        packed_text_indexes.extend(range(0, text_split_len))
        packed_position_ids.extend(range(0, text_split_len))
        sample_modality.extend([modality_map['text']] * text_split_len)

    h = cfg.H // _H
    w = cfg.W // _W
    spatial_merge_size = 2
    num_vid_tokens = t * h * w
    grid_thw = [t, h * spatial_merge_size, w * spatial_merge_size]

    if use_template:
        text_template_assistant.append({"type": element_dtype})
        all_token_id, spans_index, tgt_index, search_index = self.render_template(caption_instruction, text_template_assistant, text_template_user, [num_vid_tokens], search_text=user_prompt)
        self.sample, curr, curr_rope_id, curr_split_len = self.process_text_template(
            all_token_id,
            spans_index,
            tgt_index,
            search_index,
            video_types=['target_vae_video'],
            curr=0,
            curr_rope_id=0,
            curr_split_len=0,
            item_loss=0,
        )
        state = self.sample
        total_len = state["sample_lens"]  # an int after process_text_template
        out_text_ids = torch.tensor(state["packed_text_ids"])
        out_text_indexes = torch.tensor(state["packed_text_indexes"])
        out_vae_indexes = torch.tensor(state["packed_vae_token_indexes"])
        out_split_lens = state["split_lens"]
        out_attn_modes = state["attn_modes"]
        out_mse_indexes = torch.tensor(state["mse_loss_indexes"])
        out_position_ids = torch.tensor(state["packed_position_ids"])
        out_modality = torch.tensor(state["sample_modality"])
    else:
        text_ids.append(self.new_token_ids["start_of_image"])
        packed_text_indexes.append(text_split_len)
        packed_vae_token_indexes = torch.tensor(range(len(text_ids), len(text_ids) + num_vid_tokens))
        text_ids.extend([self.image_token_id] * num_vid_tokens)
        text_ids.append(self.new_token_ids["end_of_image"])
        packed_text_indexes.append(len(text_ids) - 1)
        video_split_len = num_vid_tokens + 2
        # The visual split shares a single position id.
        packed_position_ids.extend([text_split_len] * video_split_len)
        sample_modality.extend([modality_map['noise']] * video_split_len)

        total_len = text_split_len + video_split_len
        out_text_ids = torch.tensor(text_ids)
        out_text_indexes = torch.tensor(packed_text_indexes)
        out_vae_indexes = packed_vae_token_indexes
        out_split_lens = [text_split_len, video_split_len]
        out_attn_modes = ["causal", "noise"]
        out_mse_indexes = packed_vae_token_indexes
        out_position_ids = torch.tensor(packed_position_ids)
        out_modality = torch.tensor(sample_modality)

    return {
        "packed_text_ids": out_text_ids,
        "packed_text_indexes": out_text_indexes,
        "packed_vae_token_indexes": out_vae_indexes,
        "vae_video_grid_thw": torch.tensor([grid_thw]),
        "video_grid_thw": torch.tensor([[grid_thw]]),
        "sample_N_target": torch.tensor([[1]]),
        "split_lens": out_split_lens,
        "attn_modes": out_attn_modes,
        "sample_lens": [total_len],
        "val_sample_type": ["gen"],
        "padded_latent": None,
        "mse_loss_indexes": out_mse_indexes,
        "video_sizes": torch.tensor([[t_, cfg.H, cfg.W]]),
        "packed_position_ids": out_position_ids,
        "caption": user_prompt,
        "sample_type": ["gen"],
        "index": sample["index"],
        "caption_cn": user_prompt,
        "original_prompt_en": sample.get("original_prompt_en", user_prompt),
        "sample_task": torch.zeros(total_len),
        "sample_modality": out_modality,
        "additional_info": sample.get("additional_info"),
    }
def tv2v_sample(self, idx: int) -> Dict[str, Any]:
    """Build a maze-animation sample from a record with image and video paths.

    The record's "data" entry is overwritten with an interleaved sequence
    (fixed text prompt, the image twice as condition, the video as target),
    ``self.sample_task`` is set to 'edit', and the work is delegated to
    ``tiv2v_sample``. Both caption fields of the result are then set to the
    fixed prompt.
    """
    record = self.data[idx]
    user_prompt = "Create a 2D animation based on the provided image of a maze. The blue star slides smoothly along the white path, stopping perfectly on the red flag and then acquiring a trophy. The blue star never slides or crosses into the black segments of the maze. The camera is a static, top-down view showing the entire maze."
    image_path = record["image_path"]
    record["data"] = dict(
        interleave_array=[user_prompt, image_path, image_path, record["video_path"]],
        element_dtype_array=["text", "image", "image", "video"],
        istarget_in_interleave=[0, 0, 0, 1],
    )
    self.sample_task = 'edit'
    packed = self.tiv2v_sample(idx)
    for caption_key in ("caption", "caption_cn"):
        packed[caption_key] = user_prompt
    return packed
def tiv2v_sample(self, idx: int) -> Dict[str, Any]:
    """Build a generation sample from an interleaved text/image/video record.

    ``self.data[idx]["data"]`` must provide three parallel lists:
    "interleave_array" (text or media path), "element_dtype_array"
    ("text" / "image" / "video") and "istarget_in_interleave" (1 for the
    element to generate, 0 for conditions).

    Condition media is packed twice (ViT branch, then VAE branch); target
    media only through the VAE branch. With ``data_config.text_template``
    the per-element calls only collect tensors and token counts, and the
    token layout is produced afterwards by rendering the chat template.

    If ``self.sample_task`` is 'edit' or 'idip' a per-token "sample_task"
    tensor is attached; when the attribute is unset or has another value it
    is omitted.

    Returns:
        The finalized sample dict from ``_finalize_sample``.
    """
    sample_modality, text_template_user, text_template_assistant, vit_num_tokens, video_types = [], [], [], [], []
    self.sample = self.set_sequence_status()
    sample_lens = 0
    sample = self.data[idx]
    data_sample = sample["data"]
    additional_info = data_sample["additional_info"] if "additional_info" in data_sample else []
    interleave_array, element_dtype_array, istarget_in_interleave = data_sample["interleave_array"], data_sample["element_dtype_array"], data_sample["istarget_in_interleave"]
    curr, curr_rope_id, curr_split_len, curr_video_grid_thw, video_sizes, caption_all = 0, 0, 0, [], [], ''
    # Default matches `render_template`; previously unbound when the record
    # had no text element in template mode.
    search_text = ""

    for element, element_dtype, is_target in zip(interleave_array, element_dtype_array, istarget_in_interleave):
        if element_dtype == "text":
            caption_all += element
            if self.data_config.text_template:
                text_template_user.append({"type": "text", "text": element})
                search_text = element
            else:
                self.sample, curr, curr_rope_id, curr_split_len = self.process_text(element, curr=curr, curr_rope_id=curr_rope_id, curr_split_len=0, item_loss=is_target)
                sample_lens += curr_split_len
                sample_modality.extend([modality_map['text']] * curr_split_len)
        elif element_dtype in ["image", "video"]:
            if is_target == 0:
                # Conditions are additionally encoded through the ViT branch.
                vit_image_tensor = self.get_video_tensor_online(element, vision_stream="vit_video", element_dtype=element_dtype)
                self.sample, curr, curr_rope_id, curr_split_len, curr_video_grid_thw, num_tokens_ = self.process_vit_video(
                    vit_image_tensor, curr=curr, curr_rope_id=curr_rope_id, curr_split_len=0, curr_video_grid_thw=curr_video_grid_thw, item_loss=0
                )
                if self.data_config.text_template:
                    text_template_user.append({"type": element_dtype})
                    vit_num_tokens.append(num_tokens_)
                    video_types.append("vit_video")
                else:
                    sample_lens += curr_split_len
                    sample_modality.extend([modality_map['ref_vit']] * curr_split_len)

            # VAE condition/target processing.
            vae_image_tensor = self.get_video_tensor_online(element, vision_stream="vae_video", element_dtype=element_dtype)
            self.sample, curr, curr_rope_id, curr_split_len, curr_video_grid_thw, video_sizes, num_tokens_ = self.process_vae_video(
                vae_image_tensor, curr=curr, curr_rope_id=curr_rope_id, curr_split_len=0, curr_video_grid_thw=curr_video_grid_thw, video_sizes=video_sizes, item_loss=is_target
            )
            if self.data_config.text_template:
                vit_num_tokens.append(num_tokens_)
                if is_target == 0:
                    text_template_user.append({"type": element_dtype})
                    video_types.append("cond_vae_video")
                else:
                    text_template_assistant.append({"type": element_dtype})
                    video_types.append("target_vae_video")
            else:
                sample_lens += curr_split_len
                if is_target == 0:
                    # NOTE(review): modality_map has 'ref_image' but no
                    # 'ref_video', so a video condition raises KeyError
                    # here; left as is because the intended tag is unclear.
                    sample_modality.extend([modality_map[f'ref_{element_dtype}']] * curr_split_len)
                else:
                    sample_modality.extend([modality_map['noise']] * curr_split_len)

    if self.data_config.text_template:
        # Rotate a leading text item to the end of the user turn so the
        # media come first.
        if text_template_user[0]['type']=='text':
            text_template_user = text_template_user[1:] + text_template_user[:1]
        # `element_dtype` is the dtype of the last element of the record.
        caption_instruction = generate_system_prompt(system_prompt_type=self.data_config.task, vision_type=element_dtype)
        all_token_id, spans_index, tgt_index, search_index = self.render_template(caption_instruction, text_template_assistant, text_template_user, vit_num_tokens, search_text=search_text)
        self.sample, curr, curr_rope_id, curr_split_len = self.process_text_template(
            all_token_id,
            spans_index,
            tgt_index,
            search_index,
            video_types=video_types,
            curr=0,
            curr_rope_id=0,
            curr_split_len=0,
            item_loss=0,
        )
        sample_lens = len(all_token_id)
        sample_modality = self.sample["sample_modality"]

    additional_fields = {
        "caption": caption_all,
        "caption_cn": caption_all,
        "index": sample["index"],
        "additional_info": additional_info
    }

    # `sample_task` is not set by the constructor, so tolerate its absence
    # instead of raising AttributeError.
    sample_task = getattr(self, "sample_task", None)
    if sample_task == 'edit':
        self.sample["sample_task"] = torch.ones(sample_lens) * sample_task_map['edit']
    elif sample_task == 'idip':
        self.sample["sample_task"] = torch.ones(sample_lens) * sample_task_map['idip']

    return self._finalize_sample(
        sample_lens, curr_video_grid_thw,
        sample_type="gen",
        sample=sample,
        additional_fields=additional_fields,
        video_sizes=video_sizes
    )
def render_template(self, instruction, text_template_assistant, text_template_user, vit_num_tokens, search_text=""):
    """Render a user/assistant exchange and index the resulting token sequence.

    The two content lists are wrapped into a two-turn conversation (user
    first, assistant second) and rendered with ``render_qwenvl_prompt``,
    using ``instruction`` as the default system prompt and including the
    assistant content. The stripped text is then handed to
    ``expand_and_index_by_token_ids_new`` together with ``vit_num_tokens``,
    the tokenizer and ``search_text``; the target marker is the literal
    ``"assistant\\n"``.

    Args:
        instruction: System prompt passed as ``default_system``.
        text_template_assistant: Content list for the assistant turn.
        text_template_user: Content list for the user turn.
        vit_num_tokens: Passed through as ``tokens`` to the indexing helper
            (presumably per-visual-placeholder token counts; confirm
            against the helper).
        search_text: Passed through unchanged to the indexing helper.

    Returns:
        The 4-tuple ``(all_token_id, spans_index, tgt_index, search_index)``
        produced by ``expand_and_index_by_token_ids_new``.

    Raises:
        AssertionError: If the number of tokens from ``tgt_index[0]`` to the
            end of ``all_token_id`` differs from ``len(tgt_index)``.
    """
    user_turn = {"role": "user", "content": text_template_user}
    assistant_turn = {"role": "assistant", "content": text_template_assistant}
    rendered = render_qwenvl_prompt(
        [user_turn, assistant_turn],
        default_system=instruction,
        include_assistant_content=True,
    )
    indexed = expand_and_index_by_token_ids_new(
        rendered_text=rendered.strip(),
        tokens=vit_num_tokens,
        target_text="assistant\n",
        tokenizer=self.tokenizer,
        search_text=search_text,
    )
    all_token_id, spans_index, tgt_index, search_index = indexed
    # The target indices must cover everything from their first position
    # through the end of the token sequence.
    tail_length = len(all_token_id[tgt_index[0]:])
    assert tail_length == len(tgt_index)
    return all_token_id, spans_index, tgt_index, search_index
def x2t_sample(self, idx: int) -> Dict[str, Any]:
    """Build one understanding ("und") sample from record ``idx``.

    The record ``self.data[idx]`` must provide ``"index"`` and ``"data"``;
    ``data`` must provide the three parallel lists ``"interleave_array"``,
    ``"element_dtype_array"`` and ``"istarget_in_interleave"``, which are
    walked in lockstep:

    * ``"image"`` / ``"video"`` elements are loaded through
      ``get_video_tensor_online`` (``vision_stream="vit_video"``) and packed
      with ``process_vit_video``.
    * ``"text"`` elements are only consumed when ``is_target == 1``. With
      ``data_config.text_template`` enabled the element is treated as
      ``[instruction, question, answer]`` (a bare string is taken as the
      answer and an instruction is generated), rendered with
      ``render_template`` and packed with ``process_text_template``.
      Otherwise the answer text is packed directly with ``process_text``.

    Args:
        idx: Position of the record in ``self.data``.

    Returns:
        Whatever ``_finalize_sample`` returns for ``sample_type="und"``.

    Raises:
        ValueError: If the record contains no target text element or no
            image/video element, since both are required to fill the
            ``caption_answer`` and ``index`` output fields.
    """
    sample_modality = []
    self.sample = self.set_sequence_status()
    sample_lens = 0
    sample = self.data[idx]
    index = sample["index"]
    data_sample = sample["data"]
    interleave_array = data_sample["interleave_array"]
    element_dtype_array = data_sample["element_dtype_array"]
    istarget_in_interleave = data_sample["istarget_in_interleave"]
    curr, curr_rope_id, curr_split_len = 0, 0, 0
    curr_video_grid_thw, caption_all = [], ""
    # Both are assigned only inside loop branches; start from None so a
    # record that never reaches those branches is reported explicitly
    # instead of failing later with an UnboundLocalError.
    caption_answer = None
    index_video_path_name = None

    use_template = self.data_config.text_template
    if use_template:
        text_template_user, text_template_assistant, vit_num_tokens, video_types = [], [], [], []

    for element, element_dtype, is_target in zip(interleave_array, element_dtype_array, istarget_in_interleave):
        if element_dtype == "text":
            # Non-target text elements are not consumed by this sampler.
            if is_target != 1:
                continue
            if use_template:
                if isinstance(element, str):
                    # A bare string is the answer; synthesize the instruction
                    # from the dtype of the first interleaved element.
                    generated_instruction = generate_system_prompt(
                        system_prompt_type="caption", vision_type=element_dtype_array[0]
                    )
                    element = [generated_instruction, "", element]
                caption_i, caption_q, caption_a = element[0], element[1], element[2]
                if self.system_prompt_type == 'SP2':
                    # SP2: move the instruction into the user question.
                    caption_q = caption_i + " " + caption_q
                    caption_i = "You are a helpful assistant. "
                elif self.system_prompt_type == 'SP1':
                    # SP1: prefix the generic assistant preamble.
                    caption_i = "You are a helpful assistant. " + caption_i
                element = [caption_i, caption_q, caption_a]
                text_template_assistant.append({"type": "text", "text": caption_a})
                if caption_q != "":
                    text_template_user.append({"type": "text", "text": caption_q})
                all_token_id, spans_index, tgt_index, search_index = self.render_template(
                    caption_i, text_template_assistant, text_template_user, vit_num_tokens
                )
                self.sample, curr, curr_rope_id, curr_split_len = self.process_text_template(
                    all_token_id,
                    spans_index,
                    tgt_index,
                    search_index,
                    video_types,
                    curr=curr,
                    curr_rope_id=curr_rope_id,
                    curr_split_len=0,
                    item_loss=is_target,
                )
                sample_lens += curr_split_len
                caption_all += "\n".join(element)
                caption_answer = element[-1]
            else:
                # Without a template only the answer (last entry) is packed.
                if isinstance(element, list):
                    element = element[-1]
                self.sample, curr, curr_rope_id, curr_split_len = self.process_text(
                    element, curr=curr, curr_rope_id=curr_rope_id, curr_split_len=0, item_loss=is_target
                )
                sample_lens += curr_split_len
                sample_modality.extend([modality_map["text"]] * curr_split_len)
                caption_all += element
                caption_answer = element
        elif element_dtype in ["image", "video"]:
            vit_image_tensor = self.get_video_tensor_online(
                element, vision_stream="vit_video", element_dtype=element_dtype
            )
            self.sample, curr, curr_rope_id, curr_split_len, curr_video_grid_thw, num_tokens_ = self.process_vit_video(
                vit_image_tensor,
                curr=curr,
                curr_rope_id=curr_rope_id,
                curr_split_len=0,
                curr_video_grid_thw=curr_video_grid_thw,
                item_loss=0,
            )
            sample_lens += curr_split_len
            sample_modality.extend([modality_map["ref_vit"]] * curr_split_len)
            # Last path component of the visual input; reported as "index".
            index_video_path_name = element.split("/")[-1]
            if use_template:
                text_template_user.append({"type": element_dtype})
                vit_num_tokens.append(num_tokens_)
                video_types.append("vit_video")

    if caption_answer is None:
        raise ValueError(f"x2t sample {index!r} contains no target text element")
    if index_video_path_name is None:
        raise ValueError(f"x2t sample {index!r} contains no image/video element")

    # Values recorded on self.sample by the process_* helpers take
    # precedence over the locally accumulated ones.
    if self.sample["sample_lens"] != []:
        sample_lens = self.sample["sample_lens"]
    if self.sample["sample_modality"] != []:
        sample_modality = self.sample["sample_modality"]
    self.sample["sample_modality"] = sample_modality
    # Understanding samples are tagged with the "t2v" task id.
    # NOTE(review): this sizes the tensor from self.sample["sample_lens"]
    # rather than the local sample_lens fallback computed above - confirm
    # that is intended when the helpers leave it empty.
    self.sample["sample_task"] = torch.ones(self.sample["sample_lens"]) * sample_task_map["t2v"]

    second_item = interleave_array[1]
    is_qa_list = isinstance(second_item, list)
    additional_fields = {
        "caption": caption_all,
        "caption_cn": caption_all,
        "caption_answer": caption_answer,
        "index_item": index,
        "index": index_video_path_name,
        "additional_information": data_sample["additional_information"] if "additional_information" in data_sample else {},
        "visual_path": interleave_array[0],
        "question": second_item[1] if is_qa_list and len(second_item) > 1 else None,
        "answer": second_item[2] if is_qa_list and len(second_item) > 2 else None,
    }
    return self._finalize_sample(
        sample_lens, curr_video_grid_thw,
        sample_type="und",
        additional_fields=additional_fields
    )
def __getitem__(self, idx: int) -> Dict[str, Any]:
    """Return the sample at ``idx``, built by the sampler for the configured task.

    The task name is read from ``self.data_config.task``. For the editing
    and identity-preserving task families ``self.sample_task`` is set to
    ``'edit'`` or ``'idip'`` before ``tiv2v_sample`` is called.

    Args:
        idx: Position of the record to build.

    Returns:
        The dictionary produced by the selected ``*_sample`` method.

    Raises:
        ValueError: If the configured task name is not recognised.
    """
    task = self.data_config.task

    if task == "tv2v":
        return self.tv2v_sample(idx)
    if task in ("t2i", "t2v"):
        return self.t2v_sample(idx)
    if task == "ti2t":
        return self.ti2t_sample(idx)

    if "tiv2v" in task:
        # Composite task names carry the sub-task as a substring; 'edit'
        # wins over 'idip' when both are present, and neither leaves
        # self.sample_task untouched.
        for tag in ("edit", "idip"):
            if tag in task:
                self.sample_task = tag
                break
        return self.tiv2v_sample(idx)

    if task in ("video_edit", "image_edit"):
        self.sample_task = 'edit'
        return self.tiv2v_sample(idx)
    if task in ("video_idip", "video_idip_multiref", "image_idip"):
        self.sample_task = 'idip'
        return self.tiv2v_sample(idx)

    if task in ("x2t", "x2t_image", "x2t_video"):
        return self.x2t_sample(idx)

    raise ValueError(f"Unknown task: {task}")