|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
import torch |
|
|
from torch.utils.data import Dataset |
|
|
from torchvision import transforms |
|
|
from PIL.ImageOps import exif_transpose |
|
|
from PIL import Image |
|
|
import io |
|
|
import pyarrow.parquet as pq |
|
|
import random |
|
|
import bisect |
|
|
import pyarrow.fs as fs |
|
|
import csv |
|
|
import numpy as np |
|
|
import logging |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
@torch.no_grad()
def tokenize_prompt(tokenizer, prompt, text_encoder_architecture='open_clip'):
    """Tokenize *prompt* for the given text-encoder architecture.

    Args:
        tokenizer: A HuggingFace-style tokenizer, or — for 'CLIP_T5_base' —
            a two-element list/tuple ``[clip_tokenizer, t5_tokenizer]``.
        prompt: A string (or batch of strings) to tokenize.
        text_encoder_architecture: One of 'CLIP', 'open_clip', 'umt5-base',
            'umt5-xxl', 't5', or 'CLIP_T5_base'.

    Returns:
        The ``input_ids`` tensor, padded/truncated to the architecture's
        max length (77 for CLIP variants, 512 for T5 variants). For
        'CLIP_T5_base', a two-element list ``[clip_ids, t5_ids]``.

    Raises:
        ValueError: If the architecture name is not recognized.
    """
    # Every branch issues the same tokenizer call; only the tokenizer object
    # and its max sequence length differ, so factor the call out once.
    def _tokenize(tok, max_length):
        return tok(
            prompt,
            truncation=True,
            padding="max_length",
            max_length=max_length,
            return_tensors="pt",
        ).input_ids

    if text_encoder_architecture in ('CLIP', 'open_clip'):
        return _tokenize(tokenizer, 77)
    elif text_encoder_architecture in ['umt5-base', 'umt5-xxl', 't5']:
        return _tokenize(tokenizer, 512)
    elif text_encoder_architecture == 'CLIP_T5_base':
        # Dual-encoder setup: tokenizer[0] is the CLIP tokenizer (77 tokens),
        # tokenizer[1] is the T5 tokenizer (512 tokens).
        return [_tokenize(tokenizer[0], 77), _tokenize(tokenizer[1], 512)]
    else:
        raise ValueError(f"Unknown text_encoder_architecture: {text_encoder_architecture}")
|
|
|
|
|
def encode_prompt(text_encoder, input_ids, text_encoder_architecture='open_clip'):
    """Run the text encoder(s) and return ``(encoder_hidden_states, cond_embeds)``.

    ``cond_embeds`` is the pooled embedding for CLIP-style encoders and
    ``None`` for pure T5-style encoders. For 'CLIP_T5_base', ``text_encoder``
    and ``input_ids`` are two-element sequences (CLIP first, T5 second).
    """
    if text_encoder_architecture in ('CLIP', 'open_clip'):
        out = text_encoder(input_ids=input_ids, return_dict=True, output_hidden_states=True)
        # Penultimate hidden layer is the conditioning sequence; out[0] is
        # the pooled/projection output.
        return out.hidden_states[-2], out[0]

    if text_encoder_architecture in ['umt5-base', 'umt5-xxl', 't5']:
        out = text_encoder(input_ids=input_ids, return_dict=True)
        # T5-style encoders provide no pooled embedding.
        return out.last_hidden_state, None

    if text_encoder_architecture == 'CLIP_T5_base':
        clip_out = text_encoder[0](input_ids=input_ids[0], return_dict=True, output_hidden_states=True)
        # Zero decoder input: only the T5 encoder states are consumed.
        t5_out = text_encoder[1](input_ids=input_ids[1], decoder_input_ids=torch.zeros_like(input_ids[1]),
                                 return_dict=True, output_hidden_states=True)
        return t5_out.encoder_hidden_states[-2], clip_out[0]

    raise ValueError(f"Unknown text_encoder_architecture: {text_encoder_architecture}")
|
|
|
|
|
|
|
|
def process_image(image, size, Norm=False, hps_score=6.0):
    """Resize, randomly square-crop and tensorize a PIL image.

    Returns a dict with:
        "image":       [C, size, size] float tensor in [0, 1] ([-1, 1] if Norm)
        "micro_conds": tensor [orig_width, orig_height, crop_top, crop_left, hps_score]
    """
    # Respect EXIF orientation before recording the original resolution.
    image = exif_transpose(image)
    if image.mode != "RGB":
        image = image.convert("RGB")

    original_h = image.height
    original_w = image.width

    # Shorter side -> `size`, then a random square crop of side `size`.
    resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
    image = resize(image)

    top, left, _, _ = transforms.RandomCrop.get_params(image, output_size=(size, size))
    image = transforms.functional.crop(image, top, left, size, size)
    tensor = transforms.ToTensor()(image)

    if Norm:
        # Map [0, 1] -> [-1, 1].
        tensor = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)(tensor)

    # Micro-conditioning vector: original resolution, crop offset, HPS score.
    micro_conds = torch.tensor(
        [original_w, original_h, top, left, hps_score],
    )

    return {"image": tensor, "micro_conds": micro_conds}
|
|
|
|
|
|
|
|
class MyParquetDataset(Dataset):
    """Image/caption dataset backed by parquet shards stored on HDFS.

    Shards listed in ``dataset_receipt`` are sampled per-dataset (by
    ``ratio``), row counts are read from parquet metadata so rows can be
    addressed by a single global index (binary search over cumulative
    sizes), and exactly one shard at a time is held decoded in memory.
    """

    def __init__(self, root_dir, tokenizer=None, size=512,
                 text_encoder_architecture='CLIP', norm=False):
        # Fixed seed so every process/worker samples the same shard subset.
        # NOTE(review): this reseeds the global `random` module.
        random.seed(23)

        self.root_dir = root_dir
        # dataset name -> {total shard count, fraction of shards to keep}.
        self.dataset_receipt = {'MSCOCO_part1': {'total_num': 6212, 'ratio':1}, 'MSCOCO_part2': {'total_num': 6212, 'ratio':1}}

        self.tokenizer = tokenizer
        self.size = size
        self.text_encoder_architecture = text_encoder_architecture
        self.norm = norm

        # NOTE(review): host=""/port=0000 are placeholders — must be
        # configured for a real HDFS cluster before use.
        self.hdfs = fs.HadoopFileSystem(host="", port=0000)
        self._init_mixed_parquet_dir_list()

        # Per-shard metadata plus cumulative row counts for global indexing.
        # cumulative_sizes[i] is the global index of the first row of shard i.
        self.file_metadata = []
        self.cumulative_sizes = [0]
        total = 0
        for path in self.parquet_files:
            try:
                # Only the footer metadata is read here — cheap per shard.
                with pq.ParquetFile(path, filesystem=self.hdfs) as pf:
                    num_rows = pf.metadata.num_rows
                    self.file_metadata.append({
                        'path': path,
                        'num_rows': num_rows,
                        'global_offset': total
                    })
                    total += num_rows
                    self.cumulative_sizes.append(total)
            except Exception as e:
                # Skip unreadable shards rather than failing the whole run.
                print(f"Error processing {path}: {str(e)}")
                continue

        # Single-shard cache used by __getitem__ (see _load_file).
        self.current_file = None
        self.cached_data = None
        self.cached_file_index = -1

    def _init_mixed_parquet_dir_list(self):
        """Build self.parquet_files by sampling shard paths per dataset."""
        print('Loading parquet files, please be patient...')
        self.parquet_files = []

        for key, value in self.dataset_receipt.items():
            hdfs_path = os.path.join(self.root_dir, key)

            # Shards follow the HF naming scheme train-XXXXX-of-NNNNN.parquet;
            # keep a random `ratio` fraction of them.
            num = value['total_num']
            sampled_list = random.sample(
                [f"{hdfs_path}/train-{idx:05d}-of-{num:05d}.parquet" for idx in range(num)],
                k=int(num * value['ratio'])
            )
            self.parquet_files += sampled_list

    def __len__(self):
        # Total rows across all successfully indexed shards.
        return self.cumulative_sizes[-1]

    def _locate_file(self, global_idx):
        """Map a global row index to (shard index, row index within shard)."""
        # bisect_right on the cumulative sizes finds the owning shard.
        file_index = bisect.bisect_right(self.cumulative_sizes, global_idx) - 1
        if file_index < 0 or file_index >= len(self.file_metadata):
            raise IndexError(f"Index {global_idx} out of range")

        file_info = self.file_metadata[file_index]
        local_idx = global_idx - file_info['global_offset']
        return file_index, local_idx

    def _load_file(self, file_index):
        """Load Parquet files into cache on demand"""
        # Cache hit when the same shard is requested consecutively — works
        # best with a sampler that visits rows in shard order.
        if self.cached_file_index != file_index:
            file_info = self.file_metadata[file_index]
            try:
                # Decode the whole shard into a column -> list-of-values dict.
                table = pq.read_table(file_info['path'], filesystem=self.hdfs)
                self.cached_data = table.to_pydict()
                self.cached_file_index = file_index
            except Exception as e:
                print(f"Error loading {file_info['path']}: {str(e)}")
                raise

    def __getitem__(self, idx):
        file_index, local_idx = self._locate_file(idx)
        self._load_file(file_index)
        # Columnar dict -> a single row dict.
        sample = {k: v[local_idx] for k, v in self.cached_data.items()}

        # NOTE(review): assumes the shard schema has a 'task2' caption column
        # and an 'image' struct column with raw bytes under 'bytes'.
        generated_caption, image_path = sample['task2'], sample['image']
        instance_image = Image.open(io.BytesIO(image_path['bytes']))

        rv = process_image(instance_image, self.size, self.norm)

        if isinstance(self.tokenizer, list):
            # Dual-tokenizer (CLIP_T5_base): one ids tensor per encoder,
            # batch dimension stripped.
            _tmp_ = tokenize_prompt(self.tokenizer, generated_caption, self.text_encoder_architecture)
            rv["prompt_input_ids"] = [_tmp_[0][0], _tmp_[1][0]]
        else:
            rv["prompt_input_ids"] = tokenize_prompt(self.tokenizer, generated_caption, self.text_encoder_architecture)[
                0]

        return rv
|
|
|
|
|
class HuggingFaceDataset(Dataset):
    """Image/text dataset wrapping a HuggingFace-style dataset object.

    Each item is the dict produced by ``process_image`` plus a
    ``"prompt_input_ids"`` entry holding the tokenized caption.
    """

    def __init__(
        self,
        hf_dataset,
        tokenizer,
        image_key,
        prompt_key,
        prompt_prefix=None,
        size=512,
        text_encoder_architecture='CLIP',
    ):
        self.hf_dataset = hf_dataset
        self.tokenizer = tokenizer
        self.image_key = image_key
        self.prompt_key = prompt_key
        self.prompt_prefix = prompt_prefix
        self.size = size
        self.text_encoder_architecture = text_encoder_architecture

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, index):
        record = self.hf_dataset[index]

        # Resize/crop/tensorize the image and collect micro-conditioning.
        sample = process_image(record[self.image_key], self.size)

        text = record[self.prompt_key]
        if self.prompt_prefix is not None:
            text = self.prompt_prefix + text

        ids = tokenize_prompt(self.tokenizer, text, self.text_encoder_architecture)
        if isinstance(self.tokenizer, list):
            # Dual-tokenizer case: strip the batch dim from each ids tensor.
            sample["prompt_input_ids"] = [ids[0][0], ids[1][0]]
        else:
            sample["prompt_input_ids"] = ids[0]

        return sample
|
|
|
|
|
|
|
|
def process_video(video_tensor, num_frames, height, width, use_random_crop=True):
    """
    Process video tensor for training.

    Uses aspect-ratio preserving resize + crop to avoid distortion.

    Args:
        video_tensor: Video tensor of shape [C, F, H, W] or [F, H, W, C]
        num_frames: Target number of frames
        height: Target height
        width: Target width
        use_random_crop: If True, use random crop (for training). If False, use center crop (for validation/feature extraction)

    Returns:
        Processed video tensor of shape [C, F, H, W] in [0, 1] range
    """
    # --- Layout normalization: end up channels-first [C, F, H, W]. ---
    if video_tensor.dim() == 4:
        if video_tensor.shape[0] == 3 or video_tensor.shape[0] == 1:
            # Already [C, F, H, W].
            pass
        elif video_tensor.shape[-1] == 3 or video_tensor.shape[-1] == 1:
            # [F, H, W, C] -> [C, F, H, W].
            video_tensor = video_tensor.permute(3, 0, 1, 2)
        else:
            raise ValueError(f"Unexpected video tensor shape: {video_tensor.shape}")

    # Heuristic rescale: values above 1.0 are assumed to be 0-255 encoded.
    # NOTE(review): a genuinely dark uint8-range clip (max <= 1) would be
    # left unscaled by this check — confirm upstream decoders always scale.
    if video_tensor.max() > 1.0:
        video_tensor = video_tensor / 255.0

    C, F, H, W = video_tensor.shape

    # --- Temporal adjustment to exactly num_frames. ---
    if F != num_frames:
        if F < num_frames:
            # Too short: pad by repeating the last frame.
            num_pad = num_frames - F
            last_frame = video_tensor[:, -1:, :, :]
            padding = last_frame.repeat(1, num_pad, 1, 1)
            video_tensor = torch.cat([video_tensor, padding], dim=1)
            F = num_frames
        else:
            # Too long: take a random contiguous window.
            # NOTE(review): the window is random even when
            # use_random_crop=False — confirm validation callers expect
            # non-deterministic temporal selection.
            max_start = F - num_frames
            start_idx = random.randint(0, max_start)
            indices = torch.arange(start_idx, start_idx + num_frames)
            video_tensor = video_tensor[:, indices, :, :]
            F = num_frames

    # --- Spatial adjustment: aspect-preserving resize so both sides cover
    # the target, then crop down to (height, width). ---
    if H != height or W != width:
        scale_h = height / H
        scale_w = width / W

        # Larger scale guarantees the resized frame covers the target area.
        scale = max(scale_h, scale_w)

        new_H = int(H * scale)
        new_W = int(W * scale)

        # Guard against int() rounding making a side smaller than the target.
        if new_H < height:
            new_H = height
        if new_W < width:
            new_W = width

        # Interpolate each (channel, frame) plane independently: the
        # [C, F, H, W] tensor is flattened to [C*F, 1, H, W] 2-D images.
        video_tensor = torch.nn.functional.interpolate(
            video_tensor.reshape(C * F, 1, H, W),
            size=(new_H, new_W),
            mode='bilinear',
            align_corners=False
        ).reshape(C, F, new_H, new_W)

        if use_random_crop:
            # Training path: random crop position.
            max_h = new_H - height
            max_w = new_W - width
            if max_h < 0 or max_w < 0:
                # Defensive fallback — the clamping above should make this
                # unreachable; pad to size, then force-resize if still off.
                pad_h = max(0, height - new_H)
                pad_w = max(0, width - new_W)
                video_tensor = torch.nn.functional.pad(
                    video_tensor,
                    (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2),
                    mode='constant',
                    value=0
                )

                if video_tensor.shape[2] != height or video_tensor.shape[3] != width:
                    video_tensor = torch.nn.functional.interpolate(
                        video_tensor.reshape(C * F, 1, video_tensor.shape[2], video_tensor.shape[3]),
                        size=(height, width),
                        mode='bilinear',
                        align_corners=False
                    ).reshape(C, F, height, width)
            else:
                crop_h = random.randint(0, max_h)
                crop_w = random.randint(0, max_w)
                video_tensor = video_tensor[:, :, crop_h:crop_h + height, crop_w:crop_w + width]
        else:
            # Validation/feature-extraction path: deterministic center crop.
            crop_h = (new_H - height) // 2
            crop_w = (new_W - width) // 2
            if crop_h < 0 or crop_w < 0:
                # Same defensive pad + force-resize fallback as above.
                pad_h = max(0, height - new_H)
                pad_w = max(0, width - new_W)
                video_tensor = torch.nn.functional.pad(
                    video_tensor,
                    (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2),
                    mode='constant',
                    value=0
                )

                if video_tensor.shape[2] != height or video_tensor.shape[3] != width:
                    video_tensor = torch.nn.functional.interpolate(
                        video_tensor.reshape(C * F, 1, video_tensor.shape[2], video_tensor.shape[3]),
                        size=(height, width),
                        mode='bilinear',
                        align_corners=False
                    ).reshape(C, F, height, width)
            else:
                video_tensor = video_tensor[:, :, crop_h:crop_h + height, crop_w:crop_w + width]

    # --- Final shape sanity checks. ---
    C, F, H, W = video_tensor.shape
    assert F == num_frames, f"Frame count mismatch: expected {num_frames}, got {F}"
    assert H == height, f"Height mismatch: expected {height}, got {H}"
    assert W == width, f"Width mismatch: expected {width}, got {W}"

    return video_tensor
|
|
|
|
|
|
|
|
class VideoDataset(Dataset):
    """
    Dataset for video training, compatible with HuggingFace datasets format.
    Supports OpenVid1M and similar video-text datasets.

    Each item is a dict with:
        "video":            [C, num_frames, height, width] float tensor in [0, 1]
        "prompt_input_ids": tokenized caption tensor
    """
    def __init__(
        self,
        hf_dataset,
        tokenizer,
        video_key="video",
        prompt_key="caption",
        prompt_prefix=None,
        num_frames=16,
        height=480,
        width=848,
        text_encoder_architecture='umt5-base',
        use_random_crop=True,
    ):
        self.hf_dataset = hf_dataset
        self.tokenizer = tokenizer
        self.video_key = video_key
        self.prompt_key = prompt_key
        self.prompt_prefix = prompt_prefix
        self.num_frames = num_frames
        self.height = height
        self.width = width
        self.text_encoder_architecture = text_encoder_architecture
        # Spatial crop mode forwarded to process_video (random for training,
        # center for validation/feature extraction).
        self.use_random_crop = use_random_crop

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, index):
        item = self.hf_dataset[index]

        video = item[self.video_key]

        # Accept either a list of frames (PIL images or tensors) or a
        # ready-made tensor.
        if isinstance(video, list):
            frames = []
            for frame in video:
                if isinstance(frame, Image.Image):
                    frame = transforms.ToTensor()(frame)
                frames.append(frame)
            # Stack along the frame axis: [C, F, H, W].
            video_tensor = torch.stack(frames, dim=1)
        elif isinstance(video, torch.Tensor):
            video_tensor = video
        else:
            raise ValueError(f"Unsupported video type: {type(video)}")

        # BUGFIX: previously self.use_random_crop was accepted but never
        # forwarded, so validation callers silently got random crops.
        video_tensor = process_video(video_tensor, self.num_frames, self.height, self.width,
                                     use_random_crop=self.use_random_crop)

        # Defensive post-checks: process_video asserts the target shape, but
        # guard here as well so a future change cannot break the collate.
        C, F, H, W = video_tensor.shape
        if F != self.num_frames or H != self.height or W != self.width:
            video_tensor = torch.nn.functional.interpolate(
                video_tensor.reshape(C * F, 1, H, W),
                size=(self.height, self.width),
                mode='bilinear',
                align_corners=False
            ).reshape(C, F, self.height, self.width)

            if F < self.num_frames:
                # Pad missing frames by repeating the last one.
                num_pad = self.num_frames - F
                last_frame = video_tensor[:, -1:, :, :]
                padding = last_frame.repeat(1, num_pad, 1, 1)
                video_tensor = torch.cat([video_tensor, padding], dim=1)
            elif F > self.num_frames:
                video_tensor = video_tensor[:, :self.num_frames, :, :]

        # Detach from any shared storage before returning to the DataLoader.
        video_tensor = video_tensor.contiguous().clone()

        prompt = item[self.prompt_key]
        if self.prompt_prefix is not None:
            prompt = self.prompt_prefix + prompt

        # Strip the batch dim added by the tokenizer.
        prompt_input_ids = tokenize_prompt(self.tokenizer, prompt, self.text_encoder_architecture)[0]
        prompt_input_ids = prompt_input_ids.clone()

        rv = {
            "video": video_tensor,
            "prompt_input_ids": prompt_input_ids
        }

        return rv
|
|
|
|
|
|
|
|
class OpenVid1MDataset(Dataset):
    """
    Dataset for OpenVid1M video-text pairs from CSV file.

    CSV format:
    video,caption,aesthetic score,motion score,temporal consistency score,camera motion,frame,fps,seconds,new_id

    Returns:
        dict with keys:
            - "video": torch.Tensor of shape [C, F, H, W] in [0, 1] range
            - "prompt_input_ids": torch.Tensor of tokenized prompt
    """
    def __init__(
        self,
        csv_path,
        video_root_dir,
        tokenizer,
        num_frames=16,
        height=480,
        width=848,
        text_encoder_architecture='umt5-base',
        prompt_prefix=None,
        use_random_temporal_crop=True,
        use_random_crop=True,
    ):
        """
        Args:
            csv_path: Path to the CSV file containing video metadata
            video_root_dir: Root directory where video files are stored
            tokenizer: Text tokenizer
            num_frames: Target number of frames to extract
            height: Target height
            width: Target width
            text_encoder_architecture: Architecture of text encoder
            prompt_prefix: Optional prefix to add to prompts
            use_random_temporal_crop: Randomize the temporal window start (training) vs always start at frame 0
            use_random_crop: Spatial crop mode forwarded to process_video (random vs center)
        """
        self.csv_path = csv_path
        self.video_root_dir = video_root_dir
        self.tokenizer = tokenizer
        self.num_frames = num_frames
        self.height = height
        self.width = width
        self.text_encoder_architecture = text_encoder_architecture
        self.prompt_prefix = prompt_prefix
        self.use_random_temporal_crop = use_random_temporal_crop
        self.use_random_crop = use_random_crop

        # Load all CSV rows up-front; each row is a dict keyed by header.
        self.data = []
        with open(csv_path, 'r', encoding='utf-8') as f:
            reader = csv.DictReader(f)
            for row in reader:
                self.data.append(row)

        logger.info(f"Loaded {len(self.data)} video entries from {csv_path}")

        # Pick the first available video-decoding backend, in order of
        # preference: decord > PyAV > OpenCV.
        self.video_loader = None
        try:
            import decord
            # Make decord return torch tensors directly.
            decord.bridge.set_bridge('torch')
            self.video_loader = 'decord'
            logger.info("Using decord for video loading")
        except ImportError:
            try:
                import av
                self.video_loader = 'av'
                logger.info("Using PyAV for video loading")
            except ImportError:
                try:
                    import cv2
                    self.video_loader = 'cv2'
                    logger.info("Using OpenCV for video loading")
                except ImportError:
                    raise ImportError(
                        "No video loading library found. Please install one of: "
                        "decord (pip install decord), PyAV (pip install av), or opencv-python (pip install opencv-python)"
                    )

    def __len__(self):
        return len(self.data)

    def _load_video_decord(self, video_path):
        """Load video using decord"""
        import decord
        vr = decord.VideoReader(video_path, ctx=decord.cpu(0))
        total_frames = len(vr)

        # Choose a contiguous window of num_frames (all frames if short).
        if total_frames <= self.num_frames:
            indices = list(range(total_frames))
        else:
            if self.use_random_temporal_crop:
                # Random window start for training diversity.
                max_start = total_frames - self.num_frames
                start_idx = random.randint(0, max_start)
            else:
                # Deterministic: always start at the first frame.
                start_idx = 0
            indices = list(range(start_idx, start_idx + self.num_frames))

        frames = vr.get_batch(indices)

        # With the torch bridge this is already a tensor; otherwise convert.
        if isinstance(frames, torch.Tensor):
            frames = frames.float()
        else:
            frames = torch.tensor(frames, dtype=torch.float32)
        # decord yields [F, H, W, C]; convert to [C, F, H, W] in [0, 1].
        frames = frames.permute(3, 0, 1, 2)
        frames = frames / 255.0

        # NOTE(review): a short clip is returned with fewer than num_frames
        # here — process_video pads it downstream.
        return frames

    def _load_video_av(self, video_path):
        """Load video using PyAV"""
        import av
        container = av.open(video_path)
        frames = []

        # Some containers do not report a frame count; handle both cases.
        video_stream = container.streams.video[0]
        total_frames = video_stream.frames if video_stream.frames > 0 else None

        if total_frames is None:
            # Unknown length: decode everything first, then pick the window.
            frame_list = []
            for frame in container.decode(video_stream):
                frame_list.append(frame)
            total_frames = len(frame_list)
            if total_frames <= self.num_frames:
                frame_indices = list(range(total_frames))
            else:
                if self.use_random_temporal_crop:
                    max_start = total_frames - self.num_frames
                    start_idx = random.randint(0, max_start)
                else:
                    start_idx = 0
                frame_indices = list(range(start_idx, start_idx + self.num_frames))
            frames = [transforms.ToTensor()(frame_list[i].to_image()) for i in frame_indices]
        else:
            # Known length: pick the window first, then decode only until
            # the window is filled.
            if total_frames <= self.num_frames:
                frame_indices = list(range(total_frames))
            else:
                if self.use_random_temporal_crop:
                    max_start = total_frames - self.num_frames
                    start_idx = random.randint(0, max_start)
                else:
                    start_idx = 0
                frame_indices = list(range(start_idx, start_idx + self.num_frames))

            frame_idx = 0
            for frame in container.decode(video_stream):
                # NOTE(review): `in` on a list is O(window); fine for small
                # num_frames.
                if frame_idx in frame_indices:
                    img = frame.to_image()
                    img_tensor = transforms.ToTensor()(img)
                    frames.append(img_tensor)
                    if len(frames) >= self.num_frames:
                        break
                frame_idx += 1

        container.close()

        if len(frames) == 0:
            raise ValueError(f"No frames extracted from {video_path}")

        # [C, H, W] frames stacked along a new frame axis -> [C, F, H, W].
        video_tensor = torch.stack(frames, dim=1)

        # NOTE(review): short clips are zero-padded here, unlike
        # process_video which repeats the last frame — confirm intentional.
        if video_tensor.shape[1] < self.num_frames:
            padding = torch.zeros(
                video_tensor.shape[0],
                self.num_frames - video_tensor.shape[1],
                video_tensor.shape[2],
                video_tensor.shape[3]
            )
            video_tensor = torch.cat([video_tensor, padding], dim=1)

        return video_tensor

    def _load_video_cv2(self, video_path):
        """Load video using OpenCV"""
        import cv2
        cap = cv2.VideoCapture(video_path)
        frames = []

        # NOTE(review): CAP_PROP_FRAME_COUNT is a container estimate and can
        # be inaccurate for some codecs.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Pick the contiguous window (same policy as the other loaders).
        if total_frames <= self.num_frames:
            frame_indices = list(range(total_frames))
        else:
            if self.use_random_temporal_crop:
                max_start = total_frames - self.num_frames
                start_idx = random.randint(0, max_start)
            else:
                start_idx = 0
            frame_indices = list(range(start_idx, start_idx + self.num_frames))

        # Sequentially decode, keeping only frames inside the window.
        frame_idx = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if frame_idx in frame_indices:
                # OpenCV decodes BGR; convert to RGB.
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                # [H, W, C] uint8 -> [C, H, W] float in [0, 1].
                frame_tensor = torch.tensor(frame_rgb, dtype=torch.float32).permute(2, 0, 1) / 255.0
                frames.append(frame_tensor)
                if len(frames) >= self.num_frames:
                    break
            frame_idx += 1

        cap.release()

        if len(frames) == 0:
            raise ValueError(f"No frames extracted from {video_path}")

        # [C, H, W] frames stacked along a new frame axis -> [C, F, H, W].
        video_tensor = torch.stack(frames, dim=1)

        # NOTE(review): zero-padding here, vs last-frame repeat elsewhere.
        if video_tensor.shape[1] < self.num_frames:
            padding = torch.zeros(
                video_tensor.shape[0],
                self.num_frames - video_tensor.shape[1],
                video_tensor.shape[2],
                video_tensor.shape[3]
            )
            video_tensor = torch.cat([video_tensor, padding], dim=1)

        return video_tensor

    def _load_video(self, video_path):
        """Load video from path using the available video loader"""
        full_path = os.path.join(self.video_root_dir, video_path)

        if not os.path.exists(full_path):
            raise FileNotFoundError(f"Video file not found: {full_path}")

        # Dispatch to the backend chosen in __init__.
        if self.video_loader == 'decord':
            return self._load_video_decord(full_path)
        elif self.video_loader == 'av':
            return self._load_video_av(full_path)
        elif self.video_loader == 'cv2':
            return self._load_video_cv2(full_path)
        else:
            raise ValueError(f"Unknown video loader: {self.video_loader}")

    def __getitem__(self, index):
        row = self.data[index]

        # Relative path from the CSV's 'video' column.
        video_path = row['video']
        try:
            video_tensor = self._load_video(video_path)
        except Exception as e:
            # Best-effort: substitute a black clip so one corrupt file does
            # not abort the whole epoch.
            logger.warning(f"Failed to load video {video_path}: {e}")
            video_tensor = torch.zeros(3, self.num_frames, self.height, self.width)

        # Resize/crop/pad to the exact target shape.
        video_tensor = process_video(video_tensor, self.num_frames, self.height, self.width, use_random_crop=self.use_random_crop)

        # Defensive post-checks (process_video already asserts the shape).
        C, F, H, W = video_tensor.shape
        if F != self.num_frames or H != self.height or W != self.width:
            video_tensor = torch.nn.functional.interpolate(
                video_tensor.reshape(C * F, 1, H, W),
                size=(self.height, self.width),
                mode='bilinear',
                align_corners=False
            ).reshape(C, F, self.height, self.width)

            if F < self.num_frames:
                # Pad missing frames by repeating the last one.
                num_pad = self.num_frames - F
                last_frame = video_tensor[:, -1:, :, :]
                padding = last_frame.repeat(1, num_pad, 1, 1)
                video_tensor = torch.cat([video_tensor, padding], dim=1)
            elif F > self.num_frames:
                video_tensor = video_tensor[:, :self.num_frames, :, :]

        # Detach from any shared storage before handing to the DataLoader.
        video_tensor = video_tensor.contiguous().clone()

        prompt = row['caption']
        if self.prompt_prefix is not None:
            prompt = self.prompt_prefix + prompt

        # Strip the batch dim added by the tokenizer.
        prompt_input_ids = tokenize_prompt(self.tokenizer, prompt, self.text_encoder_architecture)[0]

        prompt_input_ids = prompt_input_ids.clone()

        return {
            "video": video_tensor,
            "prompt_input_ids": prompt_input_ids
        }
|
|
|
|
|
|
|
|
class TinyOpenVid1MDataset(OpenVid1MDataset):
    """
    A tiny subset of OpenVid1MDataset for overfitting experiments.
    Only takes a fixed random sample of N entries from the full dataset.
    """
    def __init__(
        self,
        csv_path,
        video_root_dir=None,
        tokenizer=None,
        num_frames=16,
        height=480,
        width=848,
        text_encoder_architecture='umt5-base',
        prompt_prefix=None,
        max_samples=256,
        seed=42,
        use_random_temporal_crop=True,
        use_random_crop=True,
    ):
        """
        Args:
            max_samples: Maximum number of samples to use (default: 256)
            seed: Random seed for reproducibility (default: 42)
            use_random_temporal_crop: Forwarded to OpenVid1MDataset (new, default preserves old behavior)
            use_random_crop: Forwarded to OpenVid1MDataset (new, default preserves old behavior)

        Other arguments are forwarded unchanged to OpenVid1MDataset.
        """
        super().__init__(
            csv_path=csv_path,
            video_root_dir=video_root_dir,
            tokenizer=tokenizer,
            num_frames=num_frames,
            height=height,
            width=width,
            text_encoder_architecture=text_encoder_architecture,
            prompt_prefix=prompt_prefix,
            use_random_temporal_crop=use_random_temporal_crop,
            use_random_crop=use_random_crop,
        )

        original_len = len(self.data)
        if original_len > max_samples:
            # Use a dedicated RNG instead of reseeding the global `random`
            # module (the old `random.seed(seed)` clobbered global RNG state
            # for the rest of the process). random.Random(seed).shuffle
            # yields the identical permutation, so the selected subset is
            # unchanged. The module-level `random` import is used — the old
            # local `import random` was redundant.
            rng = random.Random(seed)
            indices = list(range(original_len))
            rng.shuffle(indices)
            self.data = [self.data[i] for i in indices[:max_samples]]
            logger.info(f"Limited dataset to {max_samples} samples (from {original_len} total) for overfitting experiment")
        else:
            logger.info(f"Using all {len(self.data)} samples (less than max_samples={max_samples})")
|
|
|
|
|
|
|
|
def get_hierarchical_path(base_dir, index):
    """
    Build the path of a feature file in the 3-level directory structure.

    Layout: base_dir/AAA/BBB/CCC/IIIIIIII.npy where
        AAA = index // 1000000      (0-999)
        BBB = (index // 1000) % 1000 (0-999)
        CCC = index % 1000           (0-999)
    and the filename is the zero-padded 8-digit sample index.

    Args:
        base_dir: Base directory for features
        index: Sample index

    Returns:
        Full path to the file
    """
    # divmod peels the three 3-digit groups off the index.
    millions, remainder = divmod(index, 1000000)
    thousands, units = divmod(remainder, 1000)

    return os.path.join(
        base_dir,
        f"{millions:03d}",
        f"{thousands:03d}",
        f"{units:03d}",
        f"{index:08d}.npy",
    )
|
|
|
|
|
|
|
|
class PrecomputedFeatureDataset(Dataset):
    """
    Dataset for loading pre-extracted video codes and text embeddings.

    This dataset loads features that were pre-extracted by extract_features.py,
    avoiding the need to encode videos and text during training.

    Features are stored in a 3-level hierarchical directory structure:
    - video_codes/level1/level2/level3/index.npy
    - text_embeddings/level1/level2/level3/index.npy
    """

    def __init__(
        self,
        features_dir,
        num_samples=None,
        start_index=0,
    ):
        """
        Args:
            features_dir: Directory containing extracted features (should have video_codes/ and text_embeddings/ subdirs)
            num_samples: Number of samples to use. If None, use all available samples.
            start_index: Starting index for samples (for resuming or subset selection)
        """
        self.features_dir = features_dir
        self.video_codes_dir = os.path.join(features_dir, "video_codes")
        self.text_embeddings_dir = os.path.join(features_dir, "text_embeddings")
        self.metadata_file = os.path.join(features_dir, "metadata.json")

        # Prefer the metadata manifest; fall back to scanning the directory
        # tree when it is absent or lists no samples.
        if os.path.exists(self.metadata_file):
            import json
            with open(self.metadata_file, 'r') as f:
                self.metadata = json.load(f)
            logger.info(f"Loaded metadata from {self.metadata_file}")
            logger.info(f"  Total samples in metadata: {self.metadata.get('num_samples', 'unknown')}")

            if 'samples' in self.metadata and len(self.metadata['samples']) > 0:
                available_indices = sorted([s['index'] for s in self.metadata['samples']])
            else:
                available_indices = self._scan_hierarchical_directory(self.video_codes_dir)
        else:
            logger.warning(f"Metadata file not found: {self.metadata_file}, scanning directory structure")
            self.metadata = {}
            available_indices = self._scan_hierarchical_directory(self.video_codes_dir)

        # Subset selection: drop indices below start_index, then cap count.
        available_indices = [idx for idx in available_indices if idx >= start_index]
        if num_samples is not None:
            available_indices = available_indices[:num_samples]

        self.indices = available_indices
        logger.info(f"PrecomputedFeatureDataset: {len(self.indices)} samples available")
        if len(self.indices) > 0:
            logger.info(f"  Index range: {min(self.indices)} to {max(self.indices)}")

    def _scan_hierarchical_directory(self, base_dir):
        """
        Scan hierarchical directory structure to find all available indices.

        Instead of probing every possible 000-999 directory name with
        os.path.exists (up to 1000 stat calls per level even on sparse
        trees), list only the directories that actually exist. Only the
        well-formed 3-digit names produced by get_hierarchical_path are
        considered, matching the original range(1000) behavior.

        Args:
            base_dir: Base directory to scan

        Returns:
            Sorted list of available indices

        Raises:
            FileNotFoundError: If base_dir does not exist.
        """
        if not os.path.exists(base_dir):
            raise FileNotFoundError(f"Directory not found: {base_dir}")

        def _level_dirs(parent):
            # Yield child directories whose names are exactly 3 digits
            # ("000".."999"), i.e. the names the writer produces.
            for name in os.listdir(parent):
                full = os.path.join(parent, name)
                if len(name) == 3 and name.isdigit() and os.path.isdir(full):
                    yield full

        available_indices = []
        for level1_dir in _level_dirs(base_dir):
            for level2_dir in _level_dirs(level1_dir):
                for level3_dir in _level_dirs(level2_dir):
                    for filename in os.listdir(level3_dir):
                        if filename.endswith('.npy'):
                            try:
                                index = int(filename.replace('.npy', ''))
                                available_indices.append(index)
                            except ValueError:
                                # Non-numeric .npy filenames are ignored.
                                continue

        return sorted(available_indices)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        """Load one (video_codes, text_embedding) pair by position.

        Returns:
            dict with "video_codes" (int32 tensor), "text_embedding"
            (float16/float32 tensor, matching the stored dtype), and
            "sample_index" (the original extraction index).
        """
        sample_idx = self.indices[idx]

        video_code_path = get_hierarchical_path(self.video_codes_dir, sample_idx)
        text_embedding_path = get_hierarchical_path(self.text_embeddings_dir, sample_idx)

        if not os.path.exists(video_code_path):
            raise FileNotFoundError(f"Video code not found: {video_code_path}")
        video_codes_np = np.load(video_code_path)

        # Codes are discrete token ids; int32 is sufficient and halves memory
        # versus int64.
        video_codes = torch.tensor(video_codes_np, dtype=torch.int32)
        del video_codes_np

        if not os.path.exists(text_embedding_path):
            raise FileNotFoundError(f"Text embedding not found: {text_embedding_path}")
        text_embedding_np = np.load(text_embedding_path)

        # Preserve half precision when the features were stored as float16.
        text_embedding_dtype = torch.float16 if text_embedding_np.dtype == np.float16 else torch.float32
        text_embedding = torch.tensor(text_embedding_np, dtype=text_embedding_dtype)
        del text_embedding_np

        return {
            "video_codes": video_codes,
            "text_embedding": text_embedding,
            "sample_index": sample_idx,
        }