VideoCoF / videox_fun /data /dataset_image_video.py
XiangpengYang's picture
first commit
42a2bfa
import csv
import gc
import io
import json
import math
import os
import random
import re
from contextlib import contextmanager
from random import shuffle
from threading import Thread
import albumentations
import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from decord import VideoReader
from einops import rearrange
from func_timeout import FunctionTimedOut, func_timeout
from packaging import version as pver
from PIL import Image
from torch.utils.data import BatchSampler, Sampler
from torch.utils.data.dataset import Dataset
VIDEO_READER_TIMEOUT = 20
def get_random_mask(shape, image_start_only=False):
f, c, h, w = shape
mask = torch.zeros((f, 1, h, w), dtype=torch.uint8)
if not image_start_only:
if f != 1:
mask_index = np.random.choice([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], p=[0.05, 0.2, 0.2, 0.2, 0.05, 0.05, 0.05, 0.1, 0.05, 0.05])
else:
mask_index = np.random.choice([0, 1], p = [0.2, 0.8])
if mask_index == 0:
center_x = torch.randint(0, w, (1,)).item()
center_y = torch.randint(0, h, (1,)).item()
block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # 方块的宽度范围
block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # 方块的高度范围
start_x = max(center_x - block_size_x // 2, 0)
end_x = min(center_x + block_size_x // 2, w)
start_y = max(center_y - block_size_y // 2, 0)
end_y = min(center_y + block_size_y // 2, h)
mask[:, :, start_y:end_y, start_x:end_x] = 1
elif mask_index == 1:
mask[:, :, :, :] = 1
elif mask_index == 2:
mask_frame_index = np.random.randint(1, 5)
mask[mask_frame_index:, :, :, :] = 1
elif mask_index == 3:
mask_frame_index = np.random.randint(1, 5)
mask[mask_frame_index:-mask_frame_index, :, :, :] = 1
elif mask_index == 4:
center_x = torch.randint(0, w, (1,)).item()
center_y = torch.randint(0, h, (1,)).item()
block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # 方块的宽度范围
block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # 方块的高度范围
start_x = max(center_x - block_size_x // 2, 0)
end_x = min(center_x + block_size_x // 2, w)
start_y = max(center_y - block_size_y // 2, 0)
end_y = min(center_y + block_size_y // 2, h)
mask_frame_before = np.random.randint(0, f // 2)
mask_frame_after = np.random.randint(f // 2, f)
mask[mask_frame_before:mask_frame_after, :, start_y:end_y, start_x:end_x] = 1
elif mask_index == 5:
mask = torch.randint(0, 2, (f, 1, h, w), dtype=torch.uint8)
elif mask_index == 6:
num_frames_to_mask = random.randint(1, max(f // 2, 1))
frames_to_mask = random.sample(range(f), num_frames_to_mask)
for i in frames_to_mask:
block_height = random.randint(1, h // 4)
block_width = random.randint(1, w // 4)
top_left_y = random.randint(0, h - block_height)
top_left_x = random.randint(0, w - block_width)
mask[i, 0, top_left_y:top_left_y + block_height, top_left_x:top_left_x + block_width] = 1
elif mask_index == 7:
center_x = torch.randint(0, w, (1,)).item()
center_y = torch.randint(0, h, (1,)).item()
a = torch.randint(min(w, h) // 8, min(w, h) // 4, (1,)).item() # 长半轴
b = torch.randint(min(h, w) // 8, min(h, w) // 4, (1,)).item() # 短半轴
for i in range(h):
for j in range(w):
if ((i - center_y) ** 2) / (b ** 2) + ((j - center_x) ** 2) / (a ** 2) < 1:
mask[:, :, i, j] = 1
elif mask_index == 8:
center_x = torch.randint(0, w, (1,)).item()
center_y = torch.randint(0, h, (1,)).item()
radius = torch.randint(min(h, w) // 8, min(h, w) // 4, (1,)).item()
for i in range(h):
for j in range(w):
if (i - center_y) ** 2 + (j - center_x) ** 2 < radius ** 2:
mask[:, :, i, j] = 1
elif mask_index == 9:
for idx in range(f):
if np.random.rand() > 0.5:
mask[idx, :, :, :] = 1
else:
raise ValueError(f"The mask_index {mask_index} is not define")
else:
if f != 1:
mask[1:, :, :, :] = 1
else:
mask[:, :, :, :] = 1
return mask
class Camera(object):
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
"""
def __init__(self, entry):
fx, fy, cx, cy = entry[1:5]
self.fx = fx
self.fy = fy
self.cx = cx
self.cy = cy
w2c_mat = np.array(entry[7:]).reshape(3, 4)
w2c_mat_4x4 = np.eye(4)
w2c_mat_4x4[:3, :] = w2c_mat
self.w2c_mat = w2c_mat_4x4
self.c2w_mat = np.linalg.inv(w2c_mat_4x4)
def custom_meshgrid(*args):
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
"""
# ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
if pver.parse(torch.__version__) < pver.parse('1.10'):
return torch.meshgrid(*args)
else:
return torch.meshgrid(*args, indexing='ij')
def get_relative_pose(cam_params):
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
"""
abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
cam_to_origin = 0
target_cam_c2w = np.array([
[1, 0, 0, 0],
[0, 1, 0, -cam_to_origin],
[0, 0, 1, 0],
[0, 0, 0, 1]
])
abs2rel = target_cam_c2w @ abs_w2cs[0]
ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
ret_poses = np.array(ret_poses, dtype=np.float32)
return ret_poses
def ray_condition(K, c2w, H, W, device):
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
"""
# c2w: B, V, 4, 4
# K: B, V, 4
B = K.shape[0]
j, i = custom_meshgrid(
torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
)
i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1
zs = torch.ones_like(i) # [B, HxW]
xs = (i - cx) / fx * zs
ys = (j - cy) / fy * zs
zs = zs.expand_as(ys)
directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3
directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3
rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, 3, HW
rays_o = c2w[..., :3, 3] # B, V, 3
rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, 3, HW
# c2w @ dirctions
rays_dxo = torch.cross(rays_o, rays_d)
plucker = torch.cat([rays_dxo, rays_d], dim=-1)
plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6
# plucker = plucker.permute(0, 1, 4, 2, 3)
return plucker
def process_pose_file(pose_file_path, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu', return_poses=False):
"""Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
"""
with open(pose_file_path, 'r') as f:
poses = f.readlines()
poses = [pose.strip().split(' ') for pose in poses[1:]]
cam_params = [[float(x) for x in pose] for pose in poses]
if return_poses:
return cam_params
else:
cam_params = [Camera(cam_param) for cam_param in cam_params]
sample_wh_ratio = width / height
pose_wh_ratio = original_pose_width / original_pose_height # Assuming placeholder ratios, change as needed
if pose_wh_ratio > sample_wh_ratio:
resized_ori_w = height * pose_wh_ratio
for cam_param in cam_params:
cam_param.fx = resized_ori_w * cam_param.fx / width
else:
resized_ori_h = width / pose_wh_ratio
for cam_param in cam_params:
cam_param.fy = resized_ori_h * cam_param.fy / height
intrinsic = np.asarray([[cam_param.fx * width,
cam_param.fy * height,
cam_param.cx * width,
cam_param.cy * height]
for cam_param in cam_params], dtype=np.float32)
K = torch.as_tensor(intrinsic)[None] # [1, 1, 4]
c2ws = get_relative_pose(cam_params) # Assuming this function is defined elsewhere
c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4]
plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous() # V, 6, H, W
plucker_embedding = plucker_embedding[None]
plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]
return plucker_embedding
def process_pose_params(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu'):
"""Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
"""
cam_params = [Camera(cam_param) for cam_param in cam_params]
sample_wh_ratio = width / height
pose_wh_ratio = original_pose_width / original_pose_height # Assuming placeholder ratios, change as needed
if pose_wh_ratio > sample_wh_ratio:
resized_ori_w = height * pose_wh_ratio
for cam_param in cam_params:
cam_param.fx = resized_ori_w * cam_param.fx / width
else:
resized_ori_h = width / pose_wh_ratio
for cam_param in cam_params:
cam_param.fy = resized_ori_h * cam_param.fy / height
intrinsic = np.asarray([[cam_param.fx * width,
cam_param.fy * height,
cam_param.cx * width,
cam_param.cy * height]
for cam_param in cam_params], dtype=np.float32)
K = torch.as_tensor(intrinsic)[None] # [1, 1, 4]
c2ws = get_relative_pose(cam_params) # Assuming this function is defined elsewhere
c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4]
plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous() # V, 6, H, W
plucker_embedding = plucker_embedding[None]
plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]
return plucker_embedding
def derive_ground_object_from_instruction(instruction: str) -> str:
s = (instruction or '').strip()
if not s:
return 'the target area'
s = s.rstrip('.').strip()
# swap/replace: capture phrase between "replace/swap" and "with/by"
swap_patterns = [
r"\breplace\s+(.*?)\s+(?:with|by)\b",
r"\bswap\s+(.*?)\s+with\b",
]
for pat in swap_patterns:
m = re.search(pat, s, flags=re.IGNORECASE)
if m:
phrase = m.group(1).strip(' .,:;')
if phrase:
return phrase
# removal: capture object after remove/delete/erase/eliminate up to a preposition or punctuation
m = re.search(r"\b(?:remove|delete|erase|eliminate)\s+(.*?)(?:\s+(?:from|in|at|on|over|under|near|by)\b|[.,;]|$)", s, flags=re.IGNORECASE)
if m:
phrase = m.group(1).strip(' .,:;')
if phrase:
return phrase
# add/insert: generic target area
if re.search(r"^\s*(?:add|insert)\b", s, flags=re.IGNORECASE):
return 'the target area'
# local style (change/make ...): take the immediate noun after determiner
m = re.search(r"\b(?:change|make)\s+(?:(the|a|an)\s+)?([A-Za-z][A-Za-z0-9\-]*)", s, flags=re.IGNORECASE)
if m:
det = m.group(1) or ''
noun = m.group(2)
phrase = (det + ' ' + noun).strip()
return phrase
return 'the target area'
class ImageVideoSampler(BatchSampler):
"""A sampler wrapper for grouping images with similar aspect ratio into a same batch.
Args:
sampler (Sampler): Base sampler.
dataset (Dataset): Dataset providing data information.
batch_size (int): Size of mini-batch.
drop_last (bool): If ``True``, the sampler will drop the last batch if
its size would be less than ``batch_size``.
aspect_ratios (dict): The predefined aspect ratios.
"""
def __init__(self,
sampler: Sampler,
dataset: Dataset,
batch_size: int,
drop_last: bool = False
) -> None:
if not isinstance(sampler, Sampler):
raise TypeError('sampler should be an instance of ``Sampler``, '
f'but got {sampler}')
if not isinstance(batch_size, int) or batch_size <= 0:
raise ValueError('batch_size should be a positive integer value, '
f'but got batch_size={batch_size}')
self.sampler = sampler
self.dataset = dataset
self.batch_size = batch_size
self.drop_last = drop_last
# buckets for each aspect ratio
self.bucket = {'image':[], 'video':[]}
def __iter__(self):
for idx in self.sampler:
content_type = self.dataset.dataset[idx].get('type', 'image')
self.bucket[content_type].append(idx)
# yield a batch of indices in the same aspect ratio group
if len(self.bucket['video']) == self.batch_size:
bucket = self.bucket['video']
yield bucket[:]
del bucket[:]
elif len(self.bucket['image']) == self.batch_size:
bucket = self.bucket['image']
yield bucket[:]
del bucket[:]
@contextmanager
def VideoReader_contextmanager(*args, **kwargs):
vr = VideoReader(*args, **kwargs)
try:
yield vr
finally:
del vr
gc.collect()
def get_video_reader_batch(video_reader, batch_index):
frames = video_reader.get_batch(batch_index).asnumpy()
return frames
def resize_frame(frame, target_short_side):
h, w, _ = frame.shape
if h < w:
if target_short_side > h:
return frame
new_h = target_short_side
new_w = int(target_short_side * w / h)
else:
if target_short_side > w:
return frame
new_w = target_short_side
new_h = int(target_short_side * h / w)
resized_frame = cv2.resize(frame, (new_w, new_h))
return resized_frame
class VideoEditDataset(Dataset):
def __init__(
self,
ann_path,
data_root=None,
video_sample_height: int = None, # 改为None以支持动态分辨率
video_sample_width: int = None,
video_sample_stride=1,
video_sample_n_frames=65, # 9+8=17 for your case
source_frames=33,
edit_frames=32,
text_drop_ratio=0.1,
enable_bucket=False,
enable_inpaint=False,
instruction_template="A video sequence showing two parts: the first half shows the original scene, and the second half shows the same scene but {edit_instruction}",
):
dataset = json.load(open(ann_path))
if isinstance(dataset, dict):
new_dataset = []
for vid_id, info in dataset.items():
text_content = info["edit_instruction"]
new_dataset.append({
"original_video": info["original_video"],
"edited_video": info["edited_video"],
"text": text_content,
"type": info.get("type", "video"),
# 添加分辨率信息到metadata
"resolution": info.get("resolution", None)
})
dataset = new_dataset
self.data_root = data_root
self.dataset = dataset
self.length = len(self.dataset)
self.source_frames = source_frames
self.edit_frames = edit_frames
self.video_sample_n_frames = video_sample_n_frames
self.instruction_template = instruction_template
self.enable_bucket = enable_bucket
self.text_drop_ratio = text_drop_ratio
self.enable_inpaint = enable_inpaint
self.video_sample_stride = video_sample_stride
# 如果启用bucket,不固定分辨率
if enable_bucket:
self.video_sample_height = None
self.video_sample_width = None
else:
self.video_sample_height = video_sample_height
self.video_sample_width = video_sample_width
def load_video_pair(self, original_path, edited_path):
"""加载视频对,保持原始分辨率用于bucket training"""
if self.data_root is not None:
original_path = os.path.join(self.data_root, original_path)
edited_path = os.path.join(self.data_root, edited_path)
with VideoReader_contextmanager(original_path, num_threads=2) as orig_reader, \
VideoReader_contextmanager(edited_path, num_threads=2) as edit_reader:
# 获取视频信息
orig_length = len(orig_reader)
edit_length = len(edit_reader)
min_length = min(orig_length, edit_length)
# 统一采样策略
start_idx = 0 # 从头开始
orig_indices = np.linspace(
start_idx,
min(start_idx + (self.source_frames - 1) * self.video_sample_stride, orig_length - 1),
self.source_frames,
dtype=int
)
edit_indices = np.linspace(
start_idx,
min(start_idx + (self.edit_frames - 1) * self.video_sample_stride, edit_length - 1),
self.edit_frames,
dtype=int
)
# 加载帧
orig_frames = get_video_reader_batch(orig_reader, orig_indices)
edit_frames = get_video_reader_batch(edit_reader, edit_indices)
# 在拼接前对齐两段视频到相同 HxW(缩放后中心裁剪到 min(H1,H2) x min(W1,W2))
def resize_and_center_crop_batch(frames_np, target_h, target_w):
resized = []
for i in range(frames_np.shape[0]):
frame = frames_np[i]
h, w = frame.shape[0], frame.shape[1]
scale = max(target_h / h, target_w / w)
new_h = int(round(h * scale))
new_w = int(round(w * scale))
frame_resized = cv2.resize(frame, (new_w, new_h))
y0 = max((new_h - target_h) // 2, 0)
x0 = max((new_w - target_w) // 2, 0)
frame_cropped = frame_resized[y0:y0 + target_h, x0:x0 + target_w]
resized.append(frame_cropped)
return np.stack(resized, axis=0)
oh, ow = orig_frames.shape[1], orig_frames.shape[2]
eh, ew = edit_frames.shape[1], edit_frames.shape[2]
target_h = min(oh, eh)
target_w = min(ow, ew)
if (oh != target_h or ow != target_w):
orig_frames = resize_and_center_crop_batch(orig_frames, target_h, target_w)
if (eh != target_h or ew != target_w):
edit_frames = resize_and_center_crop_batch(edit_frames, target_h, target_w)
# 如果启用bucket,返回numpy数组
if self.enable_bucket:
return np.concatenate([orig_frames, edit_frames], axis=0)
else:
# 转换为tensor并归一化
orig_frames = torch.from_numpy(orig_frames).permute(0, 3, 1, 2).contiguous() / 255.
edit_frames = torch.from_numpy(edit_frames).permute(0, 3, 1, 2).contiguous() / 255.
return torch.cat([orig_frames, edit_frames], dim=0)
def __len__(self):
return self.length
def __getitem__(self, idx):
data_info = self.dataset[idx % len(self.dataset)]
while True:
try:
# 加载视频对
pixel_values = self.load_video_pair(
data_info['original_video'],
data_info['edited_video']
)
# 准备文本
text = data_info['text']
if self.instruction_template and "{edit_instruction}" in self.instruction_template:
text = self.instruction_template.format(edit_instruction=text)
if random.random() < self.text_drop_ratio:
text = ''
sample = {
"pixel_values": pixel_values,
"text": text,
"data_type": "video",
"idx": idx,
}
# 如果需要inpainting
if self.enable_inpaint and not self.enable_bucket:
# 这里添加inpaint逻辑
pass
return sample
except Exception as e:
try:
print(
f"Error loading video pair: {e}\n"
f" original={os.path.join(self.data_root, data_info.get('original_video','')) if self.data_root else data_info.get('original_video','')}\n"
f" edited ={os.path.join(self.data_root, data_info.get('edited_video','')) if self.data_root else data_info.get('edited_video','')}"
)
except Exception:
print(f"Error loading video pair: {e}")
idx = random.randint(0, self.length-1)
class VideoEditReasoningDataset(Dataset):
def __init__(
self,
ann_path,
data_root=None,
video_sample_height: int = None,
video_sample_width: int = None,
video_sample_stride=1,
video_sample_n_frames=65,
source_frames=33,
reasoning_frames=4,
edit_frames=32,
text_drop_ratio=0.1,
enable_bucket=False,
enable_inpaint=False,
instruction_template="A video sequence showing three parts: first the original scene, then grounded {ground_instrction}, and finally the same scene but {edit_instruction}",
):
dataset = json.load(open(ann_path))
if isinstance(dataset, dict):
new_dataset = []
for vid_id, info in dataset.items():
text_content = info.get("edit_instruction", info.get("text", ""))
# support both 'grounded_video' and 'ground_video'
grounded_key = "grounded_video" if "grounded_video" in info else "ground_video"
new_dataset.append({
"original_video": info["original_video"],
"grounded_video": info[grounded_key],
"edited_video": info["edited_video"],
"text": text_content,
"edit_instruction": text_content,
"type": info.get("type", "video"),
"resolution": info.get("resolution", None),
})
dataset = new_dataset
self.data_root = data_root
self.dataset = dataset
self.length = len(self.dataset)
self.source_frames = source_frames
self.reasoning_frames = reasoning_frames
self.edit_frames = edit_frames
self.video_sample_n_frames = video_sample_n_frames
self.instruction_template = instruction_template
self.enable_bucket = enable_bucket
self.text_drop_ratio = text_drop_ratio
self.enable_inpaint = enable_inpaint
self.video_sample_stride = video_sample_stride
if enable_bucket:
self.video_sample_height = None
self.video_sample_width = None
else:
self.video_sample_height = video_sample_height
self.video_sample_width = video_sample_width
def load_video_pair(self, original_path, grounded_path, edited_path):
if self.data_root is not None:
original_path = os.path.join(self.data_root, original_path)
grounded_path = os.path.join(self.data_root, grounded_path)
edited_path = os.path.join(self.data_root, edited_path)
with VideoReader_contextmanager(original_path, num_threads=2) as orig_reader, \
VideoReader_contextmanager(grounded_path, num_threads=2) as ground_reader, \
VideoReader_contextmanager(edited_path, num_threads=2) as edit_reader:
orig_length = len(orig_reader)
ground_length = len(ground_reader)
edit_length = len(edit_reader)
start_idx = 0
orig_indices = np.linspace(
start_idx,
min(start_idx + (self.source_frames - 1) * self.video_sample_stride, max(orig_length - 1, 0)),
self.source_frames,
dtype=int
)
# reasoning/grounded indices at 8-frame interval (example: 0,7,14,21, ...)
interval = 8
ground_indices_full = np.arange(0, max(ground_length, 1), interval, dtype=int)
if len(ground_indices_full) == 0:
ground_indices = np.array([0] * self.reasoning_frames, dtype=int)
else:
ground_indices = ground_indices_full[: self.reasoning_frames]
if len(ground_indices) < self.reasoning_frames:
pad_value = ground_indices[-1] if len(ground_indices) > 0 else 0
ground_indices = np.pad(
ground_indices, (0, self.reasoning_frames - len(ground_indices)), constant_values=pad_value
)
edit_indices = np.linspace(
start_idx,
min(start_idx + (self.edit_frames - 1) * self.video_sample_stride, max(edit_length - 1, 0)),
self.edit_frames,
dtype=int
)
orig_frames = get_video_reader_batch(orig_reader, orig_indices)
ground_frames = get_video_reader_batch(ground_reader, ground_indices)
edit_frames = get_video_reader_batch(edit_reader, edit_indices)
def resize_and_center_crop_batch(frames_np, target_h, target_w):
resized = []
for i in range(frames_np.shape[0]):
frame = frames_np[i]
h, w = frame.shape[0], frame.shape[1]
scale = max(target_h / h, target_w / w)
new_h = int(round(h * scale))
new_w = int(round(w * scale))
frame_resized = cv2.resize(frame, (new_w, new_h))
y0 = max((new_h - target_h) // 2, 0)
x0 = max((new_w - target_w) // 2, 0)
frame_cropped = frame_resized[y0:y0 + target_h, x0:x0 + target_w]
resized.append(frame_cropped)
return np.stack(resized, axis=0)
oh, ow = orig_frames.shape[1], orig_frames.shape[2]
gh, gw = ground_frames.shape[1], ground_frames.shape[2]
eh, ew = edit_frames.shape[1], edit_frames.shape[2]
target_h = min(oh, gh, eh)
target_w = min(ow, gw, ew)
if (oh != target_h or ow != target_w):
orig_frames = resize_and_center_crop_batch(orig_frames, target_h, target_w)
if (gh != target_h or gw != target_w):
ground_frames = resize_and_center_crop_batch(ground_frames, target_h, target_w)
if (eh != target_h or ew != target_w):
edit_frames = resize_and_center_crop_batch(edit_frames, target_h, target_w)
if self.enable_bucket:
return np.concatenate([orig_frames, ground_frames, edit_frames], axis=0)
else:
orig_frames = torch.from_numpy(orig_frames).permute(0, 3, 1, 2).contiguous() / 255.
ground_frames = torch.from_numpy(ground_frames).permute(0, 3, 1, 2).contiguous() / 255.
edit_frames = torch.from_numpy(edit_frames).permute(0, 3, 1, 2).contiguous() / 255.
return torch.cat([orig_frames, ground_frames, edit_frames], dim=0)
def __len__(self):
return self.length
def __getitem__(self, idx):
data_info = self.dataset[idx % len(self.dataset)]
while True:
try:
pixel_values = self.load_video_pair(
data_info['original_video'],
data_info.get('grounded_video', data_info.get('ground_video')),
data_info['edited_video'],
)
# Prepare instructions
edit_text = data_info.get('edit_instruction', data_info.get('text', ''))
ground_instr = derive_ground_object_from_instruction(edit_text)
text = edit_text
if self.instruction_template:
text = self.instruction_template.format(edit_instruction=edit_text, ground_instrction=ground_instr)
if random.random() < self.text_drop_ratio:
text = ''
sample = {
"pixel_values": pixel_values,
"text": text,
"data_type": "video",
"idx": idx,
}
if self.enable_inpaint and not self.enable_bucket:
pass
return sample
except Exception as e:
print(f"Error loading video triplet: {e}")
idx = random.randint(0, self.length-1)
class ImageVideoDataset(Dataset):
def __init__(
self,
ann_path, data_root=None,
video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
image_sample_size=512,
video_repeat=0,
text_drop_ratio=0.1,
enable_bucket=False,
video_length_drop_start=0.0,
video_length_drop_end=1.0,
enable_inpaint=False,
return_file_name=False,
):
# Loading annotations from files
print(f"loading annotations from {ann_path} ...")
if ann_path.endswith('.csv'):
with open(ann_path, 'r') as csvfile:
dataset = list(csv.DictReader(csvfile))
elif ann_path.endswith('.json'):
dataset = json.load(open(ann_path))
self.data_root = data_root
# It's used to balance num of images and videos.
if video_repeat > 0:
self.dataset = []
for data in dataset:
if data.get('type', 'image') != 'video':
self.dataset.append(data)
for _ in range(video_repeat):
for data in dataset:
if data.get('type', 'image') == 'video':
self.dataset.append(data)
else:
self.dataset = dataset
del dataset
self.length = len(self.dataset)
print(f"data scale: {self.length}")
# TODO: enable bucket training
self.enable_bucket = enable_bucket
self.text_drop_ratio = text_drop_ratio
self.enable_inpaint = enable_inpaint
self.return_file_name = return_file_name
self.video_length_drop_start = video_length_drop_start
self.video_length_drop_end = video_length_drop_end
# Video params
self.video_sample_stride = video_sample_stride
self.video_sample_n_frames = video_sample_n_frames
self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
self.video_transforms = transforms.Compose(
[
transforms.Resize(min(self.video_sample_size)),
transforms.CenterCrop(self.video_sample_size),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
# Image params
self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
self.image_transforms = transforms.Compose([
transforms.Resize(min(self.image_sample_size)),
transforms.CenterCrop(self.image_sample_size),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
])
self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))
def get_batch(self, idx):
data_info = self.dataset[idx % len(self.dataset)]
if data_info.get('type', 'image')=='video':
video_id, text = data_info['file_path'], data_info['text']
if self.data_root is None:
video_dir = video_id
else:
video_dir = os.path.join(self.data_root, video_id)
with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
min_sample_n_frames = min(
self.video_sample_n_frames,
int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
)
if min_sample_n_frames == 0:
raise ValueError(f"No Frames in video.")
video_length = int(self.video_length_drop_end * len(video_reader))
clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
try:
sample_args = (video_reader, batch_index)
pixel_values = func_timeout(
VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
)
resized_frames = []
for i in range(len(pixel_values)):
frame = pixel_values[i]
resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
resized_frames.append(resized_frame)
pixel_values = np.array(resized_frames)
except FunctionTimedOut:
raise ValueError(f"Read {idx} timeout.")
except Exception as e:
raise ValueError(f"Failed to extract frames from video. Error is {e}.")
if not self.enable_bucket:
pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
pixel_values = pixel_values / 255.
del video_reader
else:
pixel_values = pixel_values
if not self.enable_bucket:
pixel_values = self.video_transforms(pixel_values)
# Random use no text generation
if random.random() < self.text_drop_ratio:
text = ''
return pixel_values, text, 'video', video_dir
else:
image_path, text = data_info['file_path'], data_info['text']
if self.data_root is not None:
image_path = os.path.join(self.data_root, image_path)
image = Image.open(image_path).convert('RGB')
if not self.enable_bucket:
image = self.image_transforms(image).unsqueeze(0)
else:
image = np.expand_dims(np.array(image), 0)
if random.random() < self.text_drop_ratio:
text = ''
return image, text, 'image', image_path
def __len__(self):
return self.length
def __getitem__(self, idx):
data_info = self.dataset[idx % len(self.dataset)]
data_type = data_info.get('type', 'image')
while True:
sample = {}
try:
data_info_local = self.dataset[idx % len(self.dataset)]
data_type_local = data_info_local.get('type', 'image')
if data_type_local != data_type:
raise ValueError("data_type_local != data_type")
pixel_values, name, data_type, file_path = self.get_batch(idx)
sample["pixel_values"] = pixel_values
sample["text"] = name
sample["data_type"] = data_type
sample["idx"] = idx
if self.return_file_name:
sample["file_name"] = os.path.basename(file_path)
if len(sample) > 0:
break
except Exception as e:
print(e, self.dataset[idx % len(self.dataset)])
idx = random.randint(0, self.length-1)
class ImageVideoEditDataset(Dataset):
def __init__(
self,
ann_path,
data_root=None,
video_sample_size=512,
video_sample_stride=1,
source_frames=33,
target_frames=32,
text_drop_ratio=0.1,
enable_bucket=False,
enable_inpaint=False,
video_length_drop_start=0.0,
video_length_drop_end=1.0,
instruction_template="A video sequence showing two parts: the first half shows the original scene, and the second half shows the same scene but {edit_instruction}",
):
dataset = json.load(open(ann_path))
if isinstance(dataset, dict):
new_dataset = []
for _, info in dataset.items():
# Keep original keys, just standardize text field
data_type = info.get("type", "video")
entry = dict(info) # Copy original entry
# Standardize text field name and handle None/empty values
if "edit_instruction" in entry:
entry["text"] = entry["edit_instruction"]
elif "instruction" in entry:
entry["text"] = entry["instruction"]
elif "text" not in entry:
entry["text"] = ""
# Ensure text is not None (convert None to empty string)
if entry["text"] is None:
entry["text"] = ""
# Add file_path for bucket sampler compatibility
# Bucket sampler expects 'file_path' to get dimensions
if data_type == "video":
entry["file_path"] = entry.get("original_video", "")
else: # image
entry["file_path"] = entry.get("original_image", "")
new_dataset.append(entry)
dataset = new_dataset
self.data_root = data_root
self.dataset = dataset
self.length = len(self.dataset)
# sampling params
self.video_sample_stride = video_sample_stride
self.source_frames = source_frames
self.target_frames = target_frames
self.video_length_drop_start = video_length_drop_start
self.video_length_drop_end = video_length_drop_end
# transforms params (match ImageVideoDataset)
self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
self.video_transforms = transforms.Compose(
[
transforms.Resize(min(self.video_sample_size)),
transforms.CenterCrop(self.video_sample_size),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
# Image transforms for non-bucket mode
self.image_transforms = transforms.Compose([
transforms.Resize(min(self.video_sample_size)),
transforms.CenterCrop(self.video_sample_size),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
])
self.instruction_template = instruction_template
self.enable_bucket = enable_bucket
self.text_drop_ratio = text_drop_ratio
self.enable_inpaint = enable_inpaint
# For pre-resize like ImageVideoDataset
self.larger_side_of_image_and_video = min(self.video_sample_size)
def _resize_and_center_crop_batch(self, frames_np, target_h, target_w):
resized = []
for i in range(frames_np.shape[0]):
frame = frames_np[i]
h, w = frame.shape[0], frame.shape[1]
scale = max(target_h / h, target_w / w)
new_h = int(round(h * scale))
new_w = int(round(w * scale))
frame_resized = cv2.resize(frame, (new_w, new_h))
y0 = max((new_h - target_h) // 2, 0)
x0 = max((new_w - target_w) // 2, 0)
frame_cropped = frame_resized[y0:y0 + target_h, x0:x0 + target_w]
resized.append(frame_cropped)
return np.stack(resized, axis=0)
def _resize_and_center_crop_image(self, image_np, target_h, target_w):
h, w = image_np.shape[0], image_np.shape[1]
scale = max(target_h / h, target_w / w)
new_h = int(round(h * scale))
new_w = int(round(w * scale))
image_resized = cv2.resize(image_np, (new_w, new_h))
y0 = max((new_h - target_h) // 2, 0)
x0 = max((new_w - target_w) // 2, 0)
image_cropped = image_resized[y0:y0 + target_h, x0:x0 + target_w]
return image_cropped
def get_batch(self, idx):
data_info = self.dataset[idx % len(self.dataset)]
data_type = data_info.get('type', 'video')
# Handle None or empty instruction with safety fallback
raw_text = data_info.get('text', '')
if raw_text is None or (isinstance(raw_text, str) and not raw_text.strip()):
# Use a generic fallback description if instruction is missing
raw_text = "the content has been modified"
# Apply instruction template if available
if self.instruction_template and "{edit_instruction}" in self.instruction_template:
text = self.instruction_template.format(edit_instruction=raw_text)
else:
text = raw_text
if data_type == 'video':
# video pair branch (default)
src_rel, tgt_rel = data_info['original_video'], data_info['edited_video']
if self.data_root is not None:
src_path = os.path.join(self.data_root, src_rel)
tgt_path = os.path.join(self.data_root, tgt_rel)
else:
src_path = src_rel
tgt_path = tgt_rel
# Force use CPU decoder to read all frames instead of just keyframes
from decord import cpu
with VideoReader_contextmanager(src_path, num_threads=2, ctx=cpu(0)) as src_reader, \
VideoReader_contextmanager(tgt_path, num_threads=2, ctx=cpu(0)) as tgt_reader:
# Get video lengths
src_length = len(src_reader)
tgt_length = len(tgt_reader)
# Check if video has enough frames
if src_length < self.source_frames:
raise ValueError(f"Source video only has {src_length} frames, but requested {self.source_frames}")
if tgt_length < self.target_frames:
raise ValueError(f"Target video only has {tgt_length} frames, but requested {self.target_frames}")
# Unified sampling strategy: start from beginning (same as VideoEditDataset)
start_idx = 0
src_indices = np.linspace(
start_idx,
min(start_idx + (self.source_frames - 1) * self.video_sample_stride, src_length - 1),
self.source_frames,
dtype=int
)
tgt_indices = np.linspace(
start_idx,
min(start_idx + (self.target_frames - 1) * self.video_sample_stride, tgt_length - 1),
self.target_frames,
dtype=int
)
# read batches with timeout
try:
src_frames = func_timeout(VIDEO_READER_TIMEOUT, get_video_reader_batch, args=(src_reader, src_indices))
tgt_frames = func_timeout(VIDEO_READER_TIMEOUT, get_video_reader_batch, args=(tgt_reader, tgt_indices))
except FunctionTimedOut:
raise ValueError(f"Read {idx} timeout.")
except Exception as e:
raise ValueError(f"Failed to extract frames from pair. Error is {e}.")
# align HxW between source and target to enable concat
sh, sw = src_frames.shape[1], src_frames.shape[2]
th, tw = tgt_frames.shape[1], tgt_frames.shape[2]
target_h = min(sh, th)
target_w = min(sw, tw)
if (sh != target_h or sw != target_w):
src_frames = self._resize_and_center_crop_batch(src_frames, target_h, target_w)
if (th != target_h or tw != target_w):
tgt_frames = self._resize_and_center_crop_batch(tgt_frames, target_h, target_w)
if not self.enable_bucket:
src_tensor = torch.from_numpy(src_frames).permute(0, 3, 1, 2).contiguous() / 255.
tgt_tensor = torch.from_numpy(tgt_frames).permute(0, 3, 1, 2).contiguous() / 255.
src_tensor = self.video_transforms(src_tensor)
tgt_tensor = self.video_transforms(tgt_tensor)
else:
src_tensor = src_frames
tgt_tensor = tgt_frames
# Random text drop
if random.random() < self.text_drop_ratio:
text = ''
return src_tensor, tgt_tensor, text, 'video'
else:
# image pair branch (simple like ImageVideoDataset image path)
src_img_rel = data_info.get('original_image')
tgt_img_rel = data_info.get('edited_image')
if src_img_rel is None or tgt_img_rel is None:
raise ValueError('Missing original_image/edited_image for image sample')
if self.data_root is not None:
src_img_path = os.path.join(self.data_root, src_img_rel)
tgt_img_path = os.path.join(self.data_root, tgt_img_rel)
else:
src_img_path = src_img_rel
tgt_img_path = tgt_img_rel
src_img = Image.open(src_img_path).convert('RGB')
tgt_img = Image.open(tgt_img_path).convert('RGB')
if not self.enable_bucket:
# Apply transforms and add frame dimension
src_tensor = self.image_transforms(src_img).unsqueeze(0) # (1, C, H, W)
tgt_tensor = self.image_transforms(tgt_img).unsqueeze(0) # (1, C, H, W)
else:
# For bucket mode, keep as numpy and add frame dimension
src_tensor = np.expand_dims(np.array(src_img), axis=0) # (1, H, W, C)
tgt_tensor = np.expand_dims(np.array(tgt_img), axis=0) # (1, H, W, C)
if random.random() < self.text_drop_ratio:
text = ''
return src_tensor, tgt_tensor, text, 'image'
def __len__(self):
return self.length
def __getitem__(self, idx):
data_info = self.dataset[idx % len(self.dataset)]
data_type = data_info.get('type', 'video')
while True:
sample = {}
try:
data_info_local = self.dataset[idx % len(self.dataset)]
data_type_local = data_info_local.get('type', 'video')
if data_type_local != data_type:
raise ValueError("data_type_local != data_type")
src_vals, tgt_vals, name, data_type = self.get_batch(idx)
if data_type == 'video':
sample["pixel_values_src_video"] = src_vals
sample["pixel_values_tgt_video"] = tgt_vals
else:
sample["pixel_values_src_image"] = src_vals
sample["pixel_values_tgt_image"] = tgt_vals
sample["text"] = name
sample["data_type"] = data_type
sample["idx"] = idx
if len(sample) > 0:
break
except Exception as e:
print(e, self.dataset[idx % len(self.dataset)])
idx = random.randint(0, self.length-1)
# Inpaint not applied here to avoid ambiguity across src/tgt branches
return sample
class ImageVideoCoTDataset(Dataset):
"""
Dataset for Chain-of-Thought (CoT) style image/video editing.
- For videos: loads original_video, grounded_video, and edited_video (3-part)
- For images: loads original_image and edited_image (2-part, same as ImageVideoEditDataset)
"""
def __init__(
self,
ann_path,
data_root=None,
video_sample_size=512,
video_sample_stride=1,
source_frames=33,
reasoning_frames=4,
target_frames=33,
text_drop_ratio=0.1,
enable_bucket=False,
enable_inpaint=False,
video_length_drop_start=0.0,
video_length_drop_end=1.0,
instruction_template="A video sequence showing three parts: first the original scene, then grounded {ground_instruction}, and finally the same scene but {edit_instruction}",
enable_gradual_ground=False,
enable_gray_red_mask=False,
enable_gray_black_background=False,
enable_gray_alpha_overlay=False,
gray_alpha=0.5,
gray_intensity_range=(96, 160),
gray_tolerance=12,
):
dataset = json.load(open(ann_path))
if isinstance(dataset, dict):
new_dataset = []
for _, info in dataset.items():
data_type = info.get("type", "video")
entry = dict(info) # Copy original entry
# Standardize text field name and handle None/empty values
if "edit_instruction" in entry:
entry["text"] = entry["edit_instruction"]
elif "instruction" in entry:
entry["text"] = entry["instruction"]
elif "text" not in entry:
entry["text"] = ""
# Ensure text is not None
if entry["text"] is None:
entry["text"] = ""
# Add file_path for bucket sampler compatibility
if data_type == "video":
entry["file_path"] = entry.get("original_video", "")
else: # image
entry["file_path"] = entry.get("original_image", "")
new_dataset.append(entry)
dataset = new_dataset
self.data_root = data_root
self.dataset = dataset
self.length = len(self.dataset)
# sampling params
self.video_sample_stride = video_sample_stride
self.source_frames = source_frames
self.reasoning_frames = reasoning_frames
self.target_frames = target_frames
self.video_length_drop_start = video_length_drop_start
self.video_length_drop_end = video_length_drop_end
# transforms params
self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
self.video_transforms = transforms.Compose(
[
transforms.Resize(min(self.video_sample_size)),
transforms.CenterCrop(self.video_sample_size),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
# Image transforms for non-bucket mode
self.image_transforms = transforms.Compose([
transforms.Resize(min(self.video_sample_size)),
transforms.CenterCrop(self.video_sample_size),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
])
self.instruction_template = instruction_template
self.enable_bucket = enable_bucket
self.text_drop_ratio = text_drop_ratio
self.enable_inpaint = enable_inpaint
self.enable_gradual_ground = enable_gradual_ground
# only one visualization mode at a time
enabled_modes = int(bool(enable_gray_red_mask)) + int(bool(enable_gray_black_background)) + int(bool(enable_gray_alpha_overlay))
if enabled_modes > 1:
raise ValueError("enable_gray_red_mask, enable_gray_black_background and enable_gray_alpha_overlay cannot be enabled simultaneously.")
self.enable_gray_red_mask = enable_gray_red_mask
self.enable_gray_black_background = enable_gray_black_background
self.enable_gray_alpha_overlay = enable_gray_alpha_overlay
self.gray_alpha = float(gray_alpha)
if not (0.0 <= self.gray_alpha <= 1.0):
raise ValueError("gray_alpha must be in [0,1].")
if not isinstance(gray_intensity_range, (list, tuple)) or len(gray_intensity_range) != 2:
raise ValueError("gray_intensity_range must contain exactly two values (min and max intensity).")
self.gray_intensity_range = (int(gray_intensity_range[0]), int(gray_intensity_range[1]))
if self.gray_intensity_range[0] > self.gray_intensity_range[1]:
raise ValueError("gray_intensity_range min value cannot be greater than max value.")
self.gray_tolerance = int(gray_tolerance)
# For pre-resize like ImageVideoDataset
self.larger_side_of_image_and_video = min(self.video_sample_size)
def _resize_and_center_crop_batch(self, frames_np, target_h, target_w):
resized = []
for i in range(frames_np.shape[0]):
frame = frames_np[i]
h, w = frame.shape[0], frame.shape[1]
scale = max(target_h / h, target_w / w)
new_h = int(round(h * scale))
new_w = int(round(w * scale))
frame_resized = cv2.resize(frame, (new_w, new_h))
y0 = max((new_h - target_h) // 2, 0)
x0 = max((new_w - target_w) // 2, 0)
frame_cropped = frame_resized[y0:y0 + target_h, x0:x0 + target_w]
resized.append(frame_cropped)
return np.stack(resized, axis=0)
def _resize_and_center_crop_image(self, image_np, target_h, target_w):
h, w = image_np.shape[0], image_np.shape[1]
scale = max(target_h / h, target_w / w)
new_h = int(round(h * scale))
new_w = int(round(w * scale))
image_resized = cv2.resize(image_np, (new_w, new_h))
y0 = max((new_h - target_h) // 2, 0)
x0 = max((new_w - target_w) // 2, 0)
image_cropped = image_resized[y0:y0 + target_h, x0:x0 + target_w]
return image_cropped
def _derive_ground_instruction(self, edit_instruction_text: str) -> str:
"""Derive grounded object phrase from instruction using shared rules."""
return derive_ground_object_from_instruction(edit_instruction_text)
def _ensure_same_size_pair(self, img_a: np.ndarray, img_b: np.ndarray) -> tuple:
"""Resize img_b to img_a's size if needed to enable per-pixel interpolation."""
ha, wa = img_a.shape[:2]
hb, wb = img_b.shape[:2]
if (ha, wa) == (hb, wb):
return img_a, img_b
resized_b = cv2.resize(img_b, (wa, ha), interpolation=cv2.INTER_LINEAR)
return img_a, resized_b
def _interpolate_ground_frames(self, ground_first: np.ndarray, target_first: np.ndarray,
total_steps: int = 16,
pick_indices: tuple = (0, 4, 8, 12)) -> np.ndarray:
"""
Create grounding frames by linearly interpolating between the first frame of
the grounding video and the first frame of the edited video, then picking
specific indices.
Returns array of shape (len(pick_indices), H, W, 3) in uint8.
"""
a_np, b_np = self._ensure_same_size_pair(ground_first, target_first)
a_t = torch.from_numpy(a_np).float() / 255.0 # H, W, C
b_t = torch.from_numpy(b_np).float() / 255.0 # H, W, C
a_t = a_t.permute(2, 0, 1).contiguous() # C, H, W
b_t = b_t.permute(2, 0, 1).contiguous() # C, H, W
c, h, w = a_t.shape
pair = torch.stack([a_t, b_t], dim=0) # 2, C, H, W
pair_chw_t = pair.permute(1, 2, 3, 0).contiguous() # C, H, W, 2
seq = pair_chw_t.view(1, c * h * w, 2) # 1, (C*H*W), 2
with torch.no_grad():
seq_interp = F.interpolate(seq, size=int(total_steps), mode="linear", align_corners=True)
seq_interp = seq_interp.view(c, h, w, int(total_steps)).permute(3, 0, 1, 2).contiguous() # T, C, H, W
out_frames = []
t_steps = int(total_steps)
for idx in pick_indices:
safe_idx = max(0, min(int(idx), t_steps - 1))
img = (seq_interp[safe_idx].clamp(0.0, 1.0) * 255.0).byte().permute(1, 2, 0).cpu().numpy()
out_frames.append(img)
return np.stack(out_frames, axis=0)
def _build_gray_mask(self, frame: np.ndarray) -> np.ndarray:
"""Detect gray regions in a frame using intensity range and tolerance."""
frame_float = frame.astype(np.float32)
if frame_float.max() <= 1.0:
frame_float = frame_float * 255.0
channel_max = frame_float.max(axis=2)
channel_min = frame_float.min(axis=2)
min_intensity, max_intensity = self.gray_intensity_range
tone_flatness = channel_max - channel_min
mask = tone_flatness <= float(self.gray_tolerance)
mask &= channel_max >= float(min_intensity)
mask &= channel_max <= float(max_intensity)
return mask
def _apply_gray_region_effect(self, frames_np: np.ndarray, mode: str) -> np.ndarray:
"""Apply requested effect on detected gray regions for a batch of frames."""
processed_frames = []
for frame in frames_np:
mask = self._build_gray_mask(frame)
if not np.any(mask):
processed_frames.append(frame)
continue
frame_out = frame.copy()
if np.issubdtype(frame_out.dtype, np.floating) and frame_out.max() <= 1.0:
red_value = np.array([1.0, 0.0, 0.0], dtype=frame_out.dtype)
else:
red_value = np.array([255, 0, 0], dtype=frame_out.dtype)
if mode == "red":
frame_out[mask] = red_value
else:
frame_out[:] = 0
frame_out[mask] = frame[mask]
processed_frames.append(frame_out)
return np.stack(processed_frames, axis=0)
def _apply_gray_overlay_from_reference(self, src_frames_np: np.ndarray, ref_frames_np: np.ndarray,
alpha: float = 0.5, gray_value: float = 0.5, num_frames: int = 4) -> np.ndarray:
"""
Detect gray regions on ref frames, and overlay gray with alpha onto the
first `num_frames` frames of src frames at the same positions.
"""
n = min(int(num_frames), int(src_frames_np.shape[0]), int(ref_frames_np.shape[0]))
if n <= 0:
return src_frames_np
out = src_frames_np.copy()
a = float(alpha)
a = 0.0 if a < 0.0 else (1.0 if a > 1.0 else a)
gv = float(gray_value)
gv = 0.0 if gv < 0.0 else (1.0 if gv > 1.0 else gv)
for i in range(n):
mask = self._build_gray_mask(ref_frames_np[i])
if not np.any(mask):
continue
src = out[i]
# normalize to 0..1 float
if np.issubdtype(src.dtype, np.floating):
f = src.astype(np.float32)
if f.max() > 1.0:
f = np.clip(f / 255.0, 0.0, 1.0)
back_to_uint8 = False
else:
f = src.astype(np.float32) / 255.0
back_to_uint8 = True
gray_color = np.array([gv, gv, gv], dtype=np.float32)
# boolean mask is (H,W); f[mask] -> (K,3), broadcast with gray_color (3,)
f[mask] = (1.0 - a) * f[mask] + a * gray_color
if back_to_uint8:
out[i] = (f * 255.0).clip(0, 255).astype(src.dtype)
else:
out[i] = f.astype(src.dtype)
return out
def get_batch(self, idx):
data_info = self.dataset[idx % len(self.dataset)]
data_type = data_info.get('type', 'video')
# Handle None or empty instruction with safety fallback
raw_text = data_info.get('text', '')
if raw_text is None or (isinstance(raw_text, str) and not raw_text.strip()):
raw_text = "the content has been modified"
if data_type == 'video':
# Video triplet branch: original + grounded + edited
src_rel = data_info['original_video']
# Support both 'grounded_video' and 'ground_video' keys
ground_rel = data_info.get('grounded_video', data_info.get('ground_video'))
tgt_rel = data_info['edited_video']
if self.data_root is not None:
src_path = os.path.join(self.data_root, src_rel)
ground_path = os.path.join(self.data_root, ground_rel)
tgt_path = os.path.join(self.data_root, tgt_rel)
else:
src_path = src_rel
ground_path = ground_rel
tgt_path = tgt_rel
# Force use CPU decoder to read all frames
from decord import cpu
with VideoReader_contextmanager(src_path, num_threads=2, ctx=cpu(0)) as src_reader, \
VideoReader_contextmanager(ground_path, num_threads=2, ctx=cpu(0)) as ground_reader, \
VideoReader_contextmanager(tgt_path, num_threads=2, ctx=cpu(0)) as tgt_reader:
# Get video lengths
src_length = len(src_reader)
ground_length = len(ground_reader)
tgt_length = len(tgt_reader)
# Check if video has enough frames
if src_length < self.source_frames:
raise ValueError(f"Source video only has {src_length} frames, but requested {self.source_frames}")
if tgt_length < self.target_frames:
raise ValueError(f"Target video only has {tgt_length} frames, but requested {self.target_frames}")
# Unified sampling strategy: start from beginning
start_idx = 0
# Sample source frames
src_indices = np.linspace(
start_idx,
min(start_idx + (self.source_frames - 1) * self.video_sample_stride, src_length - 1),
self.source_frames,
dtype=int
)
# Sample target frames
tgt_indices = np.linspace(
start_idx,
min(start_idx + (self.target_frames - 1) * self.video_sample_stride, tgt_length - 1),
self.target_frames,
dtype=int
)
# Read batches with timeout
try:
src_frames = func_timeout(VIDEO_READER_TIMEOUT, get_video_reader_batch, args=(src_reader, src_indices))
tgt_frames = func_timeout(VIDEO_READER_TIMEOUT, get_video_reader_batch, args=(tgt_reader, tgt_indices))
if self.enable_gradual_ground:
# Interpolate between first frame of grounded and edited videos
ground_first = func_timeout(VIDEO_READER_TIMEOUT, get_video_reader_batch, args=(ground_reader, [0]))
# Use the first decoded edited frame if available to avoid double decode
tgt_first_frame = tgt_frames[0]
# steps: 0..15, pick 0,3,6,9,12 -> 5 grounding frames
ground_frames = self._interpolate_ground_frames(
ground_first=ground_first[0],
target_first=tgt_first_frame,
total_steps=16,
pick_indices=(0, 3, 6, 9, 12),
)
else:
# # Original behavior: sample grounding frames evenly by stride
# ground_indices = np.linspace(
# start_idx,
# min(start_idx + (self.reasoning_frames - 1) * self.video_sample_stride, ground_length - 1),
# self.reasoning_frames,
# dtype=int
# )
#==============================================================
# New behavior: ground_indices are the first 'reasoning_frames' from src_indices
ground_indices = src_indices[:self.reasoning_frames]
# --- 增加这个重要的安全检查 ---
# 确保我们想采样的最后一帧 (ground_indices[-1])
# 没有超出 ground_video 的总长度 (ground_length)
if len(ground_indices) > 0 and ground_indices[-1] >= ground_length:
raise ValueError(
f"Data inconsistency error: Ground video has only {ground_length} frames, "
f"but the source-based sampling (stride={self.video_sample_stride}) "
f"requires reading up to frame {ground_indices[-1]}. "
f"File: {ground_path}"
)
ground_frames = func_timeout(VIDEO_READER_TIMEOUT, get_video_reader_batch, args=(ground_reader, ground_indices))
except FunctionTimedOut:
raise ValueError(f"Read {idx} timeout.")
except Exception as e:
raise ValueError(f"Failed to extract frames from triplet. Error is {e}.")
# Align HxW among source, ground, and target to enable concat
sh, sw = src_frames.shape[1], src_frames.shape[2]
gh, gw = ground_frames.shape[1], ground_frames.shape[2]
th, tw = tgt_frames.shape[1], tgt_frames.shape[2]
target_h = min(sh, gh, th)
target_w = min(sw, gw, tw)
if (sh != target_h or sw != target_w):
src_frames = self._resize_and_center_crop_batch(src_frames, target_h, target_w)
if (gh != target_h or gw != target_w):
ground_frames = self._resize_and_center_crop_batch(ground_frames, target_h, target_w)
if (th != target_h or tw != target_w):
tgt_frames = self._resize_and_center_crop_batch(tgt_frames, target_h, target_w)
if self.enable_gray_red_mask or self.enable_gray_black_background:
effect_mode = "red" if self.enable_gray_red_mask else "black"
ground_frames = self._apply_gray_region_effect(ground_frames, effect_mode)
elif self.enable_gray_alpha_overlay:
# Use gray regions detected on grounding frames to overlay 50% gray on the
# first 4 frames of the original video.
ground_frames = self._apply_gray_overlay_from_reference(
src_frames, ground_frames, alpha=self.gray_alpha, gray_value=0.5, num_frames=4
)
if not self.enable_bucket:
src_tensor = torch.from_numpy(src_frames).permute(0, 3, 1, 2).contiguous() / 255.
ground_tensor = torch.from_numpy(ground_frames).permute(0, 3, 1, 2).contiguous() / 255.
tgt_tensor = torch.from_numpy(tgt_frames).permute(0, 3, 1, 2).contiguous() / 255.
src_tensor = self.video_transforms(src_tensor)
ground_tensor = self.video_transforms(ground_tensor)
tgt_tensor = self.video_transforms(tgt_tensor)
else:
src_tensor = src_frames
ground_tensor = ground_frames
tgt_tensor = tgt_frames
# Prepare text with template
ground_instr = self._derive_ground_instruction(raw_text)
if self.instruction_template and "{edit_instruction}" in self.instruction_template:
text = self.instruction_template.format(
edit_instruction=raw_text,
ground_instruction=ground_instr
)
else:
text = raw_text
# Random text drop
if random.random() < self.text_drop_ratio:
text = ''
return src_tensor, ground_tensor, tgt_tensor, text, 'video'
else:
# Image pair branch (simple like ImageVideoEditDataset)
src_img_rel = data_info.get('original_image')
tgt_img_rel = data_info.get('edited_image')
if src_img_rel is None or tgt_img_rel is None:
raise ValueError('Missing original_image/edited_image for image sample')
if self.data_root is not None:
src_img_path = os.path.join(self.data_root, src_img_rel)
tgt_img_path = os.path.join(self.data_root, tgt_img_rel)
else:
src_img_path = src_img_rel
tgt_img_path = tgt_img_rel
src_img = Image.open(src_img_path).convert('RGB')
tgt_img = Image.open(tgt_img_path).convert('RGB')
if not self.enable_bucket:
# Apply transforms and add frame dimension
src_tensor = self.image_transforms(src_img).unsqueeze(0) # (1, C, H, W)
tgt_tensor = self.image_transforms(tgt_img).unsqueeze(0) # (1, C, H, W)
else:
# For bucket mode, keep as numpy and add frame dimension
src_tensor = np.expand_dims(np.array(src_img), axis=0) # (1, H, W, C)
tgt_tensor = np.expand_dims(np.array(tgt_img), axis=0) # (1, H, W, C)
# Apply instruction template if available
if self.instruction_template and "{edit_instruction}" in self.instruction_template:
text = self.instruction_template.format(edit_instruction=raw_text, ground_instruction="")
else:
text = raw_text
if random.random() < self.text_drop_ratio:
text = ''
# For images, ground_tensor is None
return src_tensor, None, tgt_tensor, text, 'image'
def __len__(self):
return self.length
def __getitem__(self, idx):
data_info = self.dataset[idx % len(self.dataset)]
data_type = data_info.get('type', 'video')
while True:
sample = {}
try:
data_info_local = self.dataset[idx % len(self.dataset)]
data_type_local = data_info_local.get('type', 'video')
if data_type_local != data_type:
raise ValueError("data_type_local != data_type")
result = self.get_batch(idx)
if data_type == 'video':
src_vals, ground_vals, tgt_vals, name, data_type = result
sample["pixel_values_src_video"] = src_vals
sample["pixel_values_ground_video"] = ground_vals
sample["pixel_values_tgt_video"] = tgt_vals
else:
src_vals, _, tgt_vals, name, data_type = result
sample["pixel_values_src_image"] = src_vals
sample["pixel_values_tgt_image"] = tgt_vals
sample["text"] = name
sample["data_type"] = data_type
sample["idx"] = idx
if len(sample) > 0:
break
except Exception as e:
print(e, self.dataset[idx % len(self.dataset)])
idx = random.randint(0, self.length-1)
return sample
def padding_image(images, new_width, new_height):
new_image = Image.new('RGB', (new_width, new_height), (255, 255, 255))
aspect_ratio = images.width / images.height
if new_width / new_height > 1:
if aspect_ratio > new_width / new_height:
new_img_width = new_width
new_img_height = int(new_img_width / aspect_ratio)
else:
new_img_height = new_height
new_img_width = int(new_img_height * aspect_ratio)
else:
if aspect_ratio > new_width / new_height:
new_img_width = new_width
new_img_height = int(new_img_width / aspect_ratio)
else:
new_img_height = new_height
new_img_width = int(new_img_height * aspect_ratio)
resized_img = images.resize((new_img_width, new_img_height))
paste_x = (new_width - new_img_width) // 2
paste_y = (new_height - new_img_height) // 2
new_image.paste(resized_img, (paste_x, paste_y))
return new_image
class ImageVideoControlDataset(Dataset):
def __init__(
self,
ann_path, data_root=None,
video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
image_sample_size=512,
video_repeat=0,
text_drop_ratio=0.1,
enable_bucket=False,
video_length_drop_start=0.1,
video_length_drop_end=0.9,
enable_inpaint=False,
enable_camera_info=False,
):
# Loading annotations from files
if ann_path.endswith('.csv'):
with open(ann_path, 'r') as csvfile:
dataset = list(csv.DictReader(csvfile))
elif ann_path.endswith('.json'):
dataset = json.load(open(ann_path))
self.data_root = data_root
# It's used to balance num of images and videos.
if video_repeat > 0:
self.dataset = []
for data in dataset:
if data.get('type', 'image') != 'video':
self.dataset.append(data)
for _ in range(video_repeat):
for data in dataset:
if data.get('type', 'image') == 'video':
self.dataset.append(data)
else:
self.dataset = dataset
del dataset
self.length = len(self.dataset)
print(f"data scale: {self.length}")
# TODO: enable bucket training
self.enable_bucket = enable_bucket
self.text_drop_ratio = text_drop_ratio
self.enable_inpaint = enable_inpaint
self.enable_camera_info = enable_camera_info
self.video_length_drop_start = video_length_drop_start
self.video_length_drop_end = video_length_drop_end
# Video params
self.video_sample_stride = video_sample_stride
self.video_sample_n_frames = video_sample_n_frames
self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
self.video_transforms = transforms.Compose(
[
transforms.Resize(min(self.video_sample_size)),
transforms.CenterCrop(self.video_sample_size),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
if self.enable_camera_info:
self.video_transforms_camera = transforms.Compose(
[
transforms.Resize(min(self.video_sample_size)),
transforms.CenterCrop(self.video_sample_size)
]
)
# Image params
self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
self.image_transforms = transforms.Compose([
transforms.Resize(min(self.image_sample_size)),
transforms.CenterCrop(self.image_sample_size),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
])
self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))
def get_batch(self, idx):
data_info = self.dataset[idx % len(self.dataset)]
video_id, text = data_info['file_path'], data_info['text']
if data_info.get('type', 'image')=='video':
if self.data_root is None:
video_dir = video_id
else:
video_dir = os.path.join(self.data_root, video_id)
with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
min_sample_n_frames = min(
self.video_sample_n_frames,
int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
)
if min_sample_n_frames == 0:
raise ValueError(f"No Frames in video.")
video_length = int(self.video_length_drop_end * len(video_reader))
clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
try:
sample_args = (video_reader, batch_index)
pixel_values = func_timeout(
VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
)
resized_frames = []
for i in range(len(pixel_values)):
frame = pixel_values[i]
resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
resized_frames.append(resized_frame)
pixel_values = np.array(resized_frames)
except FunctionTimedOut:
raise ValueError(f"Read {idx} timeout.")
except Exception as e:
raise ValueError(f"Failed to extract frames from video. Error is {e}.")
if not self.enable_bucket:
pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
pixel_values = pixel_values / 255.
del video_reader
else:
pixel_values = pixel_values
if not self.enable_bucket:
pixel_values = self.video_transforms(pixel_values)
# Random use no text generation
if random.random() < self.text_drop_ratio:
text = ''
control_video_id = data_info['control_file_path']
if self.data_root is None:
control_video_id = control_video_id
else:
control_video_id = os.path.join(self.data_root, control_video_id)
if self.enable_camera_info:
if control_video_id.lower().endswith('.txt'):
if not self.enable_bucket:
control_pixel_values = torch.zeros_like(pixel_values)
control_camera_values = process_pose_file(control_video_id, width=self.video_sample_size[1], height=self.video_sample_size[0])
control_camera_values = torch.from_numpy(control_camera_values).permute(0, 3, 1, 2).contiguous()
control_camera_values = F.interpolate(control_camera_values, size=(len(video_reader), control_camera_values.size(3)), mode='bilinear', align_corners=True)
control_camera_values = self.video_transforms_camera(control_camera_values)
else:
control_pixel_values = np.zeros_like(pixel_values)
control_camera_values = process_pose_file(control_video_id, width=self.video_sample_size[1], height=self.video_sample_size[0], return_poses=True)
control_camera_values = torch.from_numpy(np.array(control_camera_values)).unsqueeze(0).unsqueeze(0)
control_camera_values = F.interpolate(control_camera_values, size=(len(video_reader), control_camera_values.size(3)), mode='bilinear', align_corners=True)[0][0]
control_camera_values = np.array([control_camera_values[index] for index in batch_index])
else:
if not self.enable_bucket:
control_pixel_values = torch.zeros_like(pixel_values)
control_camera_values = None
else:
control_pixel_values = np.zeros_like(pixel_values)
control_camera_values = None
else:
with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader:
try:
sample_args = (control_video_reader, batch_index)
control_pixel_values = func_timeout(
VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
)
resized_frames = []
for i in range(len(control_pixel_values)):
frame = control_pixel_values[i]
resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
resized_frames.append(resized_frame)
control_pixel_values = np.array(resized_frames)
except FunctionTimedOut:
raise ValueError(f"Read {idx} timeout.")
except Exception as e:
raise ValueError(f"Failed to extract frames from video. Error is {e}.")
if not self.enable_bucket:
control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous()
control_pixel_values = control_pixel_values / 255.
del control_video_reader
else:
control_pixel_values = control_pixel_values
if not self.enable_bucket:
control_pixel_values = self.video_transforms(control_pixel_values)
control_camera_values = None
return pixel_values, control_pixel_values, control_camera_values, text, "video"
else:
image_path, text = data_info['file_path'], data_info['text']
if self.data_root is not None:
image_path = os.path.join(self.data_root, image_path)
image = Image.open(image_path).convert('RGB')
if not self.enable_bucket:
image = self.image_transforms(image).unsqueeze(0)
else:
image = np.expand_dims(np.array(image), 0)
if random.random() < self.text_drop_ratio:
text = ''
control_image_id = data_info['control_file_path']
if self.image_root is None:
control_image_id = control_image_id
else:
control_image_id = os.path.join(self.image_root, control_image_id)
control_image = Image.open(control_image_id).convert('RGB')
if not self.enable_bucket:
control_image = self.image_transforms(control_image).unsqueeze(0)
else:
control_image = np.expand_dims(np.array(control_image), 0)
return image, control_image, None, text, 'image'
def __len__(self):
return self.length
def __getitem__(self, idx):
data_info = self.dataset[idx % len(self.dataset)]
data_type = data_info.get('type', 'image')
while True:
sample = {}
try:
data_info_local = self.dataset[idx % len(self.dataset)]
data_type_local = data_info_local.get('type', 'image')
if data_type_local != data_type:
raise ValueError("data_type_local != data_type")
pixel_values, control_pixel_values, control_camera_values, name, data_type = self.get_batch(idx)
sample["pixel_values"] = pixel_values
sample["control_pixel_values"] = control_pixel_values
sample["text"] = name
sample["data_type"] = data_type
sample["idx"] = idx
if self.enable_camera_info:
sample["control_camera_values"] = control_camera_values
if len(sample) > 0:
break
except Exception as e:
print(e, self.dataset[idx % len(self.dataset)])
idx = random.randint(0, self.length-1)
if self.enable_inpaint and not self.enable_bucket:
mask = get_random_mask(pixel_values.size())
mask_pixel_values = pixel_values * (1 - mask) + torch.zeros_like(pixel_values) * mask
sample["mask_pixel_values"] = mask_pixel_values
sample["mask"] = mask
clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
sample["clip_pixel_values"] = clip_pixel_values
return sample