# NOTE: the lines below are residue from the Hugging Face file-viewer page this
# file was copied from; they are preserved as comments so the module stays valid Python.
# 43 / Meissonic / train / dataset_utils.py
# BryanW's picture — Upload folder using huggingface_hub — 3d1c0e1 verified
# Copyright 2024 The HuggingFace Team and The MeissonFlow Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL.ImageOps import exif_transpose
from PIL import Image
import io
import pyarrow.parquet as pq
import random
import bisect
import pyarrow.fs as fs
import csv
import numpy as np
import logging
logger = logging.getLogger(__name__)
@torch.no_grad()
def tokenize_prompt(tokenizer, prompt, text_encoder_architecture='open_clip'):
    """Tokenize *prompt* for the configured text-encoder architecture.

    Supported architectures:
      * 'CLIP' / 'open_clip'             -> single tokenizer, padded/truncated to 77 ids
      * 'umt5-base' / 'umt5-xxl' / 't5'  -> single tokenizer, padded/truncated to 512 ids
      * 'CLIP_T5_base'                   -> ``tokenizer`` is a pair [clip_tok, t5_tok];
                                            returns a list of two id tensors (77 and 512)

    Raises:
        ValueError: for any unrecognized architecture string.
    """
    def _ids(tok, max_len):
        # Shared tokenizer invocation: fixed-length padding + truncation, torch tensors.
        return tok(
            prompt,
            truncation=True,
            padding="max_length",
            max_length=max_len,
            return_tensors="pt",
        ).input_ids

    if text_encoder_architecture in ('CLIP', 'open_clip'):
        return _ids(tokenizer, 77)
    if text_encoder_architecture in ('umt5-base', 'umt5-xxl', 't5'):
        # T5/UMT5 tokenizers use a longer context window.
        return _ids(tokenizer, 512)
    if text_encoder_architecture == 'CLIP_T5_base':
        # Two tokenizers: index 0 is CLIP (77 tokens), index 1 is T5 (512 tokens).
        return [_ids(tokenizer[0], 77), _ids(tokenizer[1], 512)]
    raise ValueError(f"Unknown text_encoder_architecture: {text_encoder_architecture}")
def encode_prompt(text_encoder, input_ids, text_encoder_architecture='open_clip'):
    """Encode tokenized prompts into (encoder_hidden_states, cond_embeds).

    * 'CLIP' / 'open_clip': penultimate hidden layer + pooled output.
    * T5/UMT5 variants: last hidden state only; pooled embedding is None
      (the video pipeline does not consume it).
    * 'CLIP_T5_base': ``text_encoder``/``input_ids`` are pairs; sequence states
      come from T5, the pooled embedding from CLIP.

    Raises:
        ValueError: for any unrecognized architecture string.
    """
    if text_encoder_architecture in ('CLIP', 'open_clip'):
        out = text_encoder(input_ids=input_ids, return_dict=True, output_hidden_states=True)
        # Penultimate layer is used as the conditioning sequence; out[0] is the pooled output.
        return out.hidden_states[-2], out[0]
    if text_encoder_architecture in ('umt5-base', 'umt5-xxl', 't5'):
        out = text_encoder(input_ids=input_ids, return_dict=True)
        # No pooled projection exists for T5-style encoders; callers receive None.
        return out.last_hidden_state, None
    if text_encoder_architecture == 'CLIP_T5_base':
        clip_out = text_encoder[0](input_ids=input_ids[0], return_dict=True, output_hidden_states=True)
        t5_out = text_encoder[1](
            input_ids=input_ids[1],
            decoder_input_ids=torch.zeros_like(input_ids[1]),
            return_dict=True,
            output_hidden_states=True,
        )
        return t5_out.encoder_hidden_states[-2], clip_out[0]
    raise ValueError(f"Unknown text_encoder_architecture: {text_encoder_architecture}")
def process_image(image, size, Norm=False, hps_score=6.0):
    """Square-crop and tensorize a PIL image for training.

    Steps: honor EXIF orientation, force RGB, resize the short side to ``size``,
    take a random ``size``x``size`` crop, convert to a float tensor, and
    optionally normalize to [-1, 1].

    Returns:
        dict with:
          "image":       tensor [3, size, size]
          "micro_conds": tensor [orig_width, orig_height, crop_top, crop_left, hps_score]
    """
    image = exif_transpose(image)
    if image.mode != "RGB":
        image = image.convert("RGB")
    # Record pre-resize dimensions for the micro-conditioning vector.
    orig_height = image.height
    orig_width = image.width
    resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
    image = resize(image)
    # Random square crop of side `size` (offsets feed the micro conditions).
    c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(size, size))
    image = transforms.functional.crop(image, c_top, c_left, size, size)
    image = transforms.ToTensor()(image)
    if Norm:
        image = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)(image)
    micro_conds = torch.tensor(
        [orig_width, orig_height, c_top, c_left, hps_score],
    )
    return {"image": image, "micro_conds": micro_conds}
class MyParquetDataset(Dataset):
    """Streams image/caption samples out of Parquet shards stored on HDFS.

    Shard paths are sampled in ``_init_mixed_parquet_dir_list``; per-shard row
    counts are read from Parquet footers so samples can be addressed by a
    global index. Rows are loaded lazily, with the most recently used shard
    kept in an in-memory cache.
    """
    def __init__(self, root_dir, tokenizer=None, size=512,
                 text_encoder_architecture='CLIP', norm=False):
        # NOTE(review): this seeds the *global* RNG, affecting every other
        # consumer of `random` in the process — confirm that is intentional.
        random.seed(23)
        self.root_dir = root_dir
        # Per-sub-dataset shard counts and sampling ratios (1 == use every shard).
        self.dataset_receipt = {'MSCOCO_part1': {'total_num': 6212, 'ratio':1}, 'MSCOCO_part2': {'total_num': 6212, 'ratio':1}}
        self.tokenizer = tokenizer
        self.size = size
        self.text_encoder_architecture = text_encoder_architecture
        self.norm = norm
        self.hdfs = fs.HadoopFileSystem(host="", port=0000) # TODO: change to your own HDFS host and port
        self._init_mixed_parquet_dir_list()
        # Per-shard metadata plus cumulative row offsets for global indexing.
        self.file_metadata = []
        self.cumulative_sizes = [0]
        total = 0
        for path in self.parquet_files:
            try:
                # Only the Parquet footer is read here; row data loads lazily later.
                with pq.ParquetFile(path, filesystem=self.hdfs) as pf:
                    num_rows = pf.metadata.num_rows
                    self.file_metadata.append({
                        'path': path,
                        'num_rows': num_rows,
                        'global_offset': total
                    })
                    total += num_rows
                    self.cumulative_sizes.append(total)
            except Exception as e:
                # Skip unreadable shards instead of failing the whole dataset.
                print(f"Error processing {path}: {str(e)}")
                continue
        # init cache
        self.current_file = None
        self.cached_data = None
        self.cached_file_index = -1
    def _init_mixed_parquet_dir_list(self):
        """Build the shard path list, sampling each sub-dataset by its ratio."""
        print('Loading parquet files, please be patient...')
        self.parquet_files = []
        for key, value in self.dataset_receipt.items():
            # Generate a list of standard Parquet file paths, lazy load
            hdfs_path = os.path.join(self.root_dir, key)
            num = value['total_num']
            sampled_list = random.sample(
                [f"{hdfs_path}/train-{idx:05d}-of-{num:05d}.parquet" for idx in range(num)],
                k=int(num * value['ratio'])
            )
            self.parquet_files += sampled_list
    def __len__(self):
        # Total rows across all successfully indexed shards.
        return self.cumulative_sizes[-1]
    def _locate_file(self, global_idx):
        """Map a global row index to (shard index, row index within that shard)."""
        # Use binary search to quickly locate files
        file_index = bisect.bisect_right(self.cumulative_sizes, global_idx) - 1
        if file_index < 0 or file_index >= len(self.file_metadata):
            raise IndexError(f"Index {global_idx} out of range")
        file_info = self.file_metadata[file_index]
        local_idx = global_idx - file_info['global_offset']
        return file_index, local_idx
    def _load_file(self, file_index):
        """Load Parquet files into cache on demand"""
        if self.cached_file_index != file_index:
            file_info = self.file_metadata[file_index]
            try:
                table = pq.read_table(file_info['path'], filesystem=self.hdfs)
                # Materialize as a plain column dict for cheap row access.
                self.cached_data = table.to_pydict()
                self.cached_file_index = file_index
            except Exception as e:
                print(f"Error loading {file_info['path']}: {str(e)}")
                raise
    def __getitem__(self, idx):
        """Return {"image", "micro_conds", "prompt_input_ids"} for global row *idx*."""
        file_index, local_idx = self._locate_file(idx)
        self._load_file(file_index)
        # Assemble one row, column by column, from the cached table.
        sample = {k: v[local_idx] for k, v in self.cached_data.items()}
        # cprint(sample.keys(), 'red')
        generated_caption, image_path = sample['task2'], sample['image'] # only suitable for my data
        # 'image' appears to hold {'bytes': ...}; decode from the in-memory buffer.
        instance_image = Image.open(io.BytesIO(image_path['bytes']))
        # if instance_image.width < self.size or instance_image.height < self.size:
        # raise ValueError(f"Image at {image_path} is too small")
        rv = process_image(instance_image, self.size, self.norm)
        if isinstance(self.tokenizer, list):
            # Dual-tokenizer case (CLIP_T5_base): strip the batch dim from each tensor.
            _tmp_ = tokenize_prompt(self.tokenizer, generated_caption, self.text_encoder_architecture)
            rv["prompt_input_ids"] = [_tmp_[0][0], _tmp_[1][0]]
        else:
            rv["prompt_input_ids"] = tokenize_prompt(self.tokenizer, generated_caption, self.text_encoder_architecture)[
                0]
        return rv
class HuggingFaceDataset(Dataset):
    """Image/prompt dataset backed by an indexable HuggingFace dataset.

    Each item is processed with ``process_image`` (random square crop of side
    ``size``) and its prompt tokenized with ``tokenize_prompt``.
    """
    def __init__(
        self,
        hf_dataset,
        tokenizer,
        image_key,
        prompt_key,
        prompt_prefix=None,
        size=512,
        text_encoder_architecture='CLIP',
    ):
        self.hf_dataset = hf_dataset
        self.tokenizer = tokenizer
        self.image_key = image_key
        self.prompt_key = prompt_key
        self.prompt_prefix = prompt_prefix
        self.size = size
        self.text_encoder_architecture = text_encoder_architecture
    def __len__(self):
        return len(self.hf_dataset)
    def __getitem__(self, index):
        record = self.hf_dataset[index]
        sample = process_image(record[self.image_key], self.size)
        caption = record[self.prompt_key]
        if self.prompt_prefix is not None:
            caption = self.prompt_prefix + caption
        ids = tokenize_prompt(self.tokenizer, caption, self.text_encoder_architecture)
        if isinstance(self.tokenizer, list):
            # Dual-tokenizer case: keep one unbatched id tensor per tokenizer.
            sample["prompt_input_ids"] = [ids[0][0], ids[1][0]]
        else:
            sample["prompt_input_ids"] = ids[0]
        return sample
def process_video(video_tensor, num_frames, height, width, use_random_crop=True):
    """
    Process video tensor for training.
    Uses aspect-ratio preserving resize + crop to avoid distortion.
    Args:
        video_tensor: Video tensor of shape [C, F, H, W] or [F, H, W, C]
        num_frames: Target number of frames
        height: Target height
        width: Target width
        use_random_crop: If True, use random crop (for training). If False, use
            center crop (for validation/feature extraction). Note the temporal
            segment is always chosen randomly when the clip is too long.
    Returns:
        Processed video tensor of shape [C, num_frames, height, width] in [0, 1] range
    """
    # --- Layout normalization: accept [C, F, H, W] or [F, H, W, C] ---
    if video_tensor.dim() == 4:
        if video_tensor.shape[0] in (1, 3):
            pass  # already channel-first
        elif video_tensor.shape[-1] in (1, 3):
            video_tensor = video_tensor.permute(3, 0, 1, 2)  # [F, H, W, C] -> [C, F, H, W]
        else:
            raise ValueError(f"Unexpected video tensor shape: {video_tensor.shape}")
    # --- Value range: map uint8-style [0, 255] inputs into [0, 1] ---
    if video_tensor.max() > 1.0:
        video_tensor = video_tensor / 255.0
    C, F, H, W = video_tensor.shape
    # --- Temporal resampling: guarantee exactly `num_frames` frames ---
    if F < num_frames:
        # Short clip: pad by repeating the last frame.
        last_frame = video_tensor[:, -1:, :, :]  # [C, 1, H, W]
        video_tensor = torch.cat(
            [video_tensor, last_frame.repeat(1, num_frames - F, 1, 1)], dim=1
        )
    elif F > num_frames:
        # Long clip: random contiguous segment for temporal coherence.
        start_idx = random.randint(0, F - num_frames)
        video_tensor = video_tensor[:, start_idx:start_idx + num_frames, :, :]
    F = num_frames  # frame count is now exact
    # --- Spatial resizing: aspect-ratio preserving resize, then crop ---
    if H != height or W != width:
        # Scale by the larger factor so both dimensions cover the target, then
        # clamp up to the target to absorb int() rounding. This guarantees
        # new_H >= height and new_W >= width, so the crop offsets below are
        # always non-negative (the original code carried unreachable padding
        # fallbacks for the impossible "resized smaller than target" case).
        scale = max(height / H, width / W)
        new_H = max(int(H * scale), height)
        new_W = max(int(W * scale), width)
        # Bilinear resize; each channel/frame plane is interpolated independently.
        video_tensor = torch.nn.functional.interpolate(
            video_tensor.reshape(C * F, 1, H, W),
            size=(new_H, new_W),
            mode='bilinear',
            align_corners=False
        ).reshape(C, F, new_H, new_W)
        if use_random_crop:
            # Random crop (data augmentation for training).
            crop_h = random.randint(0, new_H - height)
            crop_w = random.randint(0, new_W - width)
        else:
            # Deterministic center crop (validation / feature extraction).
            crop_h = (new_H - height) // 2
            crop_w = (new_W - width) // 2
        video_tensor = video_tensor[:, :, crop_h:crop_h + height, crop_w:crop_w + width]
    # Final verification: ensure output has exactly the expected shape
    C, F, H, W = video_tensor.shape
    assert F == num_frames, f"Frame count mismatch: expected {num_frames}, got {F}"
    assert H == height, f"Height mismatch: expected {height}, got {H}"
    assert W == width, f"Width mismatch: expected {width}, got {W}"
    return video_tensor
class VideoDataset(Dataset):
    """
    Dataset for video training, compatible with HuggingFace datasets format.
    Supports OpenVid1M and similar video-text datasets.

    Items are dicts with:
      "video":            float tensor [C, num_frames, height, width] in [0, 1]
      "prompt_input_ids": tokenized prompt (first element of the tokenizer batch)
    """
    def __init__(
        self,
        hf_dataset,
        tokenizer,
        video_key="video",
        prompt_key="caption",
        prompt_prefix=None,
        num_frames=16,
        height=480,
        width=848,
        text_encoder_architecture='umt5-base',
        use_random_crop=True, # Random crop for training, center crop for validation
    ):
        self.hf_dataset = hf_dataset
        self.tokenizer = tokenizer
        self.video_key = video_key
        self.prompt_key = prompt_key
        self.prompt_prefix = prompt_prefix
        self.num_frames = num_frames
        self.height = height
        self.width = width
        self.text_encoder_architecture = text_encoder_architecture
        self.use_random_crop = use_random_crop
    def __len__(self):
        return len(self.hf_dataset)
    def __getitem__(self, index):
        item = self.hf_dataset[index]
        # Load video
        video = item[self.video_key]
        # Convert to tensor if needed (handle different formats)
        if isinstance(video, list):
            # List of PIL Images or tensors
            frames = []
            for frame in video:
                if isinstance(frame, Image.Image):
                    frame = transforms.ToTensor()(frame)
                frames.append(frame)
            video_tensor = torch.stack(frames, dim=1) # [C, F, H, W]
        elif isinstance(video, torch.Tensor):
            video_tensor = video
        else:
            raise ValueError(f"Unsupported video type: {type(video)}")
        # Process video
        # NOTE(review): process_video asserts the exact output shape, so the
        # fallback below should be unreachable unless asserts are disabled (-O).
        video_tensor = process_video(video_tensor, self.num_frames, self.height, self.width)
        # Ensure video tensor has exactly the expected shape
        C, F, H, W = video_tensor.shape
        if F != self.num_frames or H != self.height or W != self.width:
            # If shape doesn't match, create a properly sized tensor
            video_tensor = torch.nn.functional.interpolate(
                video_tensor.reshape(C * F, 1, H, W),
                size=(self.height, self.width),
                mode='bilinear',
                align_corners=False
            ).reshape(C, F, self.height, self.width)
            # Ensure exactly num_frames
            if F < self.num_frames:
                # Pad by repeating last frame
                num_pad = self.num_frames - F
                last_frame = video_tensor[:, -1:, :, :]
                padding = last_frame.repeat(1, num_pad, 1, 1)
                video_tensor = torch.cat([video_tensor, padding], dim=1)
            elif F > self.num_frames:
                # Crop to num_frames
                video_tensor = video_tensor[:, :self.num_frames, :, :]
        # Clone to ensure storage is resizable (required for DataLoader collate)
        video_tensor = video_tensor.contiguous().clone()
        # Process prompt
        prompt = item[self.prompt_key]
        if self.prompt_prefix is not None:
            prompt = self.prompt_prefix + prompt
        prompt_input_ids = tokenize_prompt(self.tokenizer, prompt, self.text_encoder_architecture)[0]
        # Clone to ensure storage is resizable
        prompt_input_ids = prompt_input_ids.clone()
        rv = {
            "video": video_tensor, # [C, num_frames, height, width], guaranteed shape
            "prompt_input_ids": prompt_input_ids
        }
        return rv
class OpenVid1MDataset(Dataset):
    """
    Dataset for OpenVid1M video-text pairs from CSV file.
    CSV format:
    video,caption,aesthetic score,motion score,temporal consistency score,camera motion,frame,fps,seconds,new_id
    Returns:
    dict with keys:
    - "video": torch.Tensor of shape [C, F, H, W] in [0, 1] range
    - "prompt_input_ids": torch.Tensor of tokenized prompt
    """
    def __init__(
        self,
        csv_path,
        video_root_dir,
        tokenizer,
        num_frames=16,
        height=480,
        width=848,
        text_encoder_architecture='umt5-base',
        prompt_prefix=None,
        use_random_temporal_crop=True, # If False, always sample from the beginning
        use_random_crop=True, # Random crop for training, center crop for validation/feature extraction
    ):
        """
        Args:
            csv_path: Path to the CSV file containing video metadata
            video_root_dir: Root directory where video files are stored
            tokenizer: Text tokenizer
            num_frames: Target number of frames to extract
            height: Target height
            width: Target width
            text_encoder_architecture: Architecture of text encoder
            prompt_prefix: Optional prefix to add to prompts
        """
        self.csv_path = csv_path
        self.video_root_dir = video_root_dir
        self.tokenizer = tokenizer
        self.num_frames = num_frames
        self.height = height
        self.width = width
        self.text_encoder_architecture = text_encoder_architecture
        self.prompt_prefix = prompt_prefix
        self.use_random_temporal_crop = use_random_temporal_crop
        self.use_random_crop = use_random_crop
        # Load CSV data
        self.data = []
        with open(csv_path, 'r', encoding='utf-8') as f:
            reader = csv.DictReader(f)
            for row in reader:
                self.data.append(row)
        logger.info(f"Loaded {len(self.data)} video entries from {csv_path}")
        # Try to import video loading library
        # Preference order: decord (fast random frame access) > PyAV > OpenCV.
        self.video_loader = None
        try:
            import decord
            decord.bridge.set_bridge('torch')
            self.video_loader = 'decord'
            logger.info("Using decord for video loading")
        except ImportError:
            try:
                import av
                self.video_loader = 'av'
                logger.info("Using PyAV for video loading")
            except ImportError:
                try:
                    import cv2
                    self.video_loader = 'cv2'
                    logger.info("Using OpenCV for video loading")
                except ImportError:
                    raise ImportError(
                        "No video loading library found. Please install one of: "
                        "decord (pip install decord), PyAV (pip install av), or opencv-python (pip install opencv-python)"
                    )
    def __len__(self):
        return len(self.data)
    def _load_video_decord(self, video_path):
        """Load video using decord"""
        import decord
        vr = decord.VideoReader(video_path, ctx=decord.cpu(0))
        total_frames = len(vr)
        # Sample frames: random temporal crop (continuous segment) for better temporal coherence
        if total_frames <= self.num_frames:
            # Short clip: take every frame; process_video pads to num_frames later.
            indices = list(range(total_frames))
        else:
            if self.use_random_temporal_crop:
                # Randomly select a continuous segment of num_frames
                max_start = total_frames - self.num_frames
                start_idx = random.randint(0, max_start)
            else:
                # Fixed sampling: always start from the beginning
                start_idx = 0
            indices = list(range(start_idx, start_idx + self.num_frames))
        frames = vr.get_batch(indices) # [F, H, W, C] in uint8
        # If using torch bridge, frames is already a torch Tensor
        if isinstance(frames, torch.Tensor):
            frames = frames.float() # [F, H, W, C]
        else:
            # Use torch.tensor() instead of torch.from_numpy() to ensure a complete copy
            # This avoids "Trying to resize storage that is not resizable" errors in DataLoader collate
            frames = torch.tensor(frames, dtype=torch.float32) # [F, H, W, C], fully copied
        frames = frames.permute(3, 0, 1, 2) # [C, F, H, W]
        frames = frames / 255.0 # Normalize to [0, 1]
        return frames
    def _load_video_av(self, video_path):
        """Load video using PyAV"""
        import av
        container = av.open(video_path)
        frames = []
        # Get video stream
        video_stream = container.streams.video[0]
        # Some containers do not report a frame count; fall back to full decode.
        total_frames = video_stream.frames if video_stream.frames > 0 else None
        # Sample frames: random temporal crop (continuous segment) for better temporal coherence
        if total_frames is None:
            # If we can't get frame count, decode all frames and sample
            frame_list = []
            for frame in container.decode(video_stream):
                frame_list.append(frame)
            total_frames = len(frame_list)
            if total_frames <= self.num_frames:
                frame_indices = list(range(total_frames))
            else:
                if self.use_random_temporal_crop:
                    # Randomly select a continuous segment of num_frames
                    max_start = total_frames - self.num_frames
                    start_idx = random.randint(0, max_start)
                else:
                    # Fixed sampling: always start from the beginning
                    start_idx = 0
                frame_indices = list(range(start_idx, start_idx + self.num_frames))
            frames = [transforms.ToTensor()(frame_list[i].to_image()) for i in frame_indices]
        else:
            if total_frames <= self.num_frames:
                frame_indices = list(range(total_frames))
            else:
                if self.use_random_temporal_crop:
                    # Randomly select a continuous segment of num_frames
                    max_start = total_frames - self.num_frames
                    start_idx = random.randint(0, max_start)
                else:
                    # Fixed sampling: always start from the beginning
                    start_idx = 0
                frame_indices = list(range(start_idx, start_idx + self.num_frames))
            # Sequential decode; frames outside the selected segment are skipped.
            frame_idx = 0
            for frame in container.decode(video_stream):
                if frame_idx in frame_indices:
                    img = frame.to_image() # PIL Image
                    img_tensor = transforms.ToTensor()(img) # [C, H, W]
                    frames.append(img_tensor)
                    if len(frames) >= self.num_frames:
                        break
                frame_idx += 1
        container.close()
        if len(frames) == 0:
            raise ValueError(f"No frames extracted from {video_path}")
        # Stack frames: [C, F, H, W]
        video_tensor = torch.stack(frames, dim=1)
        # Pad if needed
        # NOTE(review): pads with zero (black) frames, unlike process_video which
        # repeats the last frame — confirm this difference is intended.
        if video_tensor.shape[1] < self.num_frames:
            padding = torch.zeros(
                video_tensor.shape[0],
                self.num_frames - video_tensor.shape[1],
                video_tensor.shape[2],
                video_tensor.shape[3]
            )
            video_tensor = torch.cat([video_tensor, padding], dim=1)
        return video_tensor
    def _load_video_cv2(self, video_path):
        """Load video using OpenCV"""
        import cv2
        cap = cv2.VideoCapture(video_path)
        frames = []
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Sample frames: random temporal crop (continuous segment) for better temporal coherence
        if total_frames <= self.num_frames:
            frame_indices = list(range(total_frames))
        else:
            if self.use_random_temporal_crop:
                # Randomly select a continuous segment of num_frames
                max_start = total_frames - self.num_frames
                start_idx = random.randint(0, max_start)
            else:
                # Fixed sampling: always start from the beginning
                start_idx = 0
            frame_indices = list(range(start_idx, start_idx + self.num_frames))
        # Sequential decode; frames outside the selected segment are skipped.
        frame_idx = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if frame_idx in frame_indices:
                # Convert BGR to RGB
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                # Convert to tensor [C, H, W] and normalize to [0, 1]
                # Use torch.tensor() instead of torch.from_numpy() to ensure a complete copy
                # This avoids "Trying to resize storage that is not resizable" errors in DataLoader collate
                frame_tensor = torch.tensor(frame_rgb, dtype=torch.float32).permute(2, 0, 1) / 255.0
                frames.append(frame_tensor)
                if len(frames) >= self.num_frames:
                    break
            frame_idx += 1
        cap.release()
        if len(frames) == 0:
            raise ValueError(f"No frames extracted from {video_path}")
        # Stack frames: [C, F, H, W]
        video_tensor = torch.stack(frames, dim=1)
        # Pad if needed
        # NOTE(review): zero-frame padding here too; see _load_video_av.
        if video_tensor.shape[1] < self.num_frames:
            padding = torch.zeros(
                video_tensor.shape[0],
                self.num_frames - video_tensor.shape[1],
                video_tensor.shape[2],
                video_tensor.shape[3]
            )
            video_tensor = torch.cat([video_tensor, padding], dim=1)
        return video_tensor
    def _load_video(self, video_path):
        """Load video from path using the available video loader"""
        full_path = os.path.join(self.video_root_dir, video_path)
        if not os.path.exists(full_path):
            raise FileNotFoundError(f"Video file not found: {full_path}")
        if self.video_loader == 'decord':
            return self._load_video_decord(full_path)
        elif self.video_loader == 'av':
            return self._load_video_av(full_path)
        elif self.video_loader == 'cv2':
            return self._load_video_cv2(full_path)
        else:
            raise ValueError(f"Unknown video loader: {self.video_loader}")
    def __getitem__(self, index):
        """Return {"video", "prompt_input_ids"} for row *index* of the CSV."""
        row = self.data[index]
        # Load video
        video_path = row['video']
        try:
            video_tensor = self._load_video(video_path)
        except Exception as e:
            # If video loading fails, return a zero tensor and log error
            logger.warning(f"Failed to load video {video_path}: {e}")
            video_tensor = torch.zeros(3, self.num_frames, self.height, self.width)
        # Process video: aspect-ratio preserving resize + crop to target dimensions
        video_tensor = process_video(video_tensor, self.num_frames, self.height, self.width, use_random_crop=self.use_random_crop)
        # Ensure video tensor has exactly the expected shape
        # NOTE(review): process_video asserts the exact output shape, so this
        # fallback should be unreachable unless asserts are disabled (-O).
        C, F, H, W = video_tensor.shape
        if F != self.num_frames or H != self.height or W != self.width:
            # If shape doesn't match, create a properly sized tensor
            video_tensor = torch.nn.functional.interpolate(
                video_tensor.reshape(C * F, 1, H, W),
                size=(self.height, self.width),
                mode='bilinear',
                align_corners=False
            ).reshape(C, F, self.height, self.width)
            # Ensure exactly num_frames
            if F < self.num_frames:
                # Pad by repeating last frame
                num_pad = self.num_frames - F
                last_frame = video_tensor[:, -1:, :, :]
                padding = last_frame.repeat(1, num_pad, 1, 1)
                video_tensor = torch.cat([video_tensor, padding], dim=1)
            elif F > self.num_frames:
                # Crop to num_frames
                video_tensor = video_tensor[:, :self.num_frames, :, :]
        # Clone to ensure storage is resizable (required for DataLoader collate)
        video_tensor = video_tensor.contiguous().clone()
        # Process prompt
        prompt = row['caption']
        if self.prompt_prefix is not None:
            prompt = self.prompt_prefix + prompt
        prompt_input_ids = tokenize_prompt(self.tokenizer, prompt, self.text_encoder_architecture)[0]
        # Clone to ensure storage is resizable
        prompt_input_ids = prompt_input_ids.clone()
        return {
            "video": video_tensor, # [C, num_frames, height, width], guaranteed shape
            "prompt_input_ids": prompt_input_ids
        }
class TinyOpenVid1MDataset(OpenVid1MDataset):
    """
    A tiny subset of OpenVid1MDataset for overfitting experiments.
    Keeps a fixed random subset of at most ``max_samples`` rows.
    """
    def __init__(
        self,
        csv_path,
        video_root_dir=None,
        tokenizer=None,
        num_frames=16,
        height=480,
        width=848,
        text_encoder_architecture='umt5-base',
        prompt_prefix=None,
        max_samples=256, # Only use first N samples
        seed=42, # Fixed seed for reproducibility
    ):
        """
        Args:
            max_samples: Maximum number of samples to use (default: 256)
            seed: Random seed for reproducibility (default: 42)
        """
        # Let the parent load the full CSV first.
        super().__init__(
            csv_path=csv_path,
            video_root_dir=video_root_dir,
            tokenizer=tokenizer,
            num_frames=num_frames,
            height=height,
            width=width,
            text_encoder_architecture=text_encoder_architecture,
            prompt_prefix=prompt_prefix,
        )
        total = len(self.data)
        if total > max_samples:
            import random
            # Seed the global RNG so the shuffled subset is reproducible.
            random.seed(seed)
            order = list(range(total))
            random.shuffle(order)
            self.data = [self.data[pos] for pos in order[:max_samples]]
            logger.info(f"Limited dataset to {max_samples} samples (from {total} total) for overfitting experiment")
        else:
            logger.info(f"Using all {len(self.data)} samples (less than max_samples={max_samples})")
def get_hierarchical_path(base_dir, index):
    """
    Get hierarchical path for loading features from 3-level directory structure.
    Structure: base_dir/level1/level2/level3/filename.npy
    - level1: index // 1000000
    - level2: (index // 1000) % 1000
    - level3: index % 1000
    Args:
        base_dir: Base directory for features
        index: Sample index
    Returns:
        Full path to the file
    """
    # Peel off the three base-1000 digits of the index, least significant first.
    upper, leaf = divmod(index, 1000)   # leaf == index % 1000
    top, mid = divmod(upper, 1000)      # top == index // 1000000, mid == (index // 1000) % 1000
    filename = f"{index:08d}.npy"
    return os.path.join(base_dir, f"{top:03d}", f"{mid:03d}", f"{leaf:03d}", filename)
class PrecomputedFeatureDataset(Dataset):
    """
    Dataset for loading pre-extracted video codes and text embeddings.
    This dataset loads features that were pre-extracted by extract_features.py,
    avoiding the need to encode videos and text during training.
    Features are stored in a 3-level hierarchical directory structure:
    - video_codes/level1/level2/level3/index.npy
    - text_embeddings/level1/level2/level3/index.npy
    """
    def __init__(
        self,
        features_dir,
        num_samples=None,
        start_index=0,
    ):
        """
        Args:
            features_dir: Directory containing extracted features (should have video_codes/ and text_embeddings/ subdirs)
            num_samples: Number of samples to use. If None, use all available samples.
            start_index: Starting index for samples (for resuming or subset selection)
        """
        self.features_dir = features_dir
        self.video_codes_dir = os.path.join(features_dir, "video_codes")
        self.text_embeddings_dir = os.path.join(features_dir, "text_embeddings")
        self.metadata_file = os.path.join(features_dir, "metadata.json")
        # Load metadata
        if os.path.exists(self.metadata_file):
            import json
            with open(self.metadata_file, 'r') as f:
                self.metadata = json.load(f)
            logger.info(f"Loaded metadata from {self.metadata_file}")
            logger.info(f" Total samples in metadata: {self.metadata.get('num_samples', 'unknown')}")
            # Get available indices from metadata
            if 'samples' in self.metadata and len(self.metadata['samples']) > 0:
                available_indices = sorted([s['index'] for s in self.metadata['samples']])
            else:
                # Fallback: infer from directory structure
                available_indices = self._scan_hierarchical_directory(self.video_codes_dir)
        else:
            # If no metadata, scan directory structure
            logger.warning(f"Metadata file not found: {self.metadata_file}, scanning directory structure")
            self.metadata = {}
            available_indices = self._scan_hierarchical_directory(self.video_codes_dir)
        # Filter by start_index and num_samples
        available_indices = [idx for idx in available_indices if idx >= start_index]
        if num_samples is not None:
            available_indices = available_indices[:num_samples]
        self.indices = available_indices
        logger.info(f"PrecomputedFeatureDataset: {len(self.indices)} samples available")
        if len(self.indices) > 0:
            logger.info(f" Index range: {min(self.indices)} to {max(self.indices)}")
    def _scan_hierarchical_directory(self, base_dir):
        """
        Scan hierarchical directory structure to find all available indices.
        Args:
            base_dir: Base directory to scan
        Returns:
            List of available indices
        """
        available_indices = []
        if not os.path.exists(base_dir):
            raise FileNotFoundError(f"Directory not found: {base_dir}")
        # Scan level1 directories (000-999)
        # Missing directories are skipped cheaply via os.path.exists, so the
        # triple loop only descends into levels that actually exist on disk.
        for level1 in range(1000):
            level1_dir = os.path.join(base_dir, f"{level1:03d}")
            if not os.path.exists(level1_dir):
                continue
            # Scan level2 directories (000-999)
            for level2 in range(1000):
                level2_dir = os.path.join(level1_dir, f"{level2:03d}")
                if not os.path.exists(level2_dir):
                    continue
                # Scan level3 directories (000-999)
                for level3 in range(1000):
                    level3_dir = os.path.join(level2_dir, f"{level3:03d}")
                    if not os.path.exists(level3_dir):
                        continue
                    # List all .npy files in level3 directory
                    for filename in os.listdir(level3_dir):
                        if filename.endswith('.npy'):
                            try:
                                # Filenames are zero-padded sample indices, e.g. 00001234.npy.
                                index = int(filename.replace('.npy', ''))
                                available_indices.append(index)
                            except ValueError:
                                # Ignore stray non-numeric .npy files.
                                continue
        return sorted(available_indices)
    def __len__(self):
        return len(self.indices)
    def __getitem__(self, idx):
        """Return the precomputed features for the *idx*-th available sample."""
        sample_idx = self.indices[idx]
        # Get hierarchical paths
        video_code_path = get_hierarchical_path(self.video_codes_dir, sample_idx)
        text_embedding_path = get_hierarchical_path(self.text_embeddings_dir, sample_idx)
        # Load video codes
        # Note: We load directly (not mmap) to avoid storage sharing issues with torch
        # The files are small enough (video codes are int32, typically < 1MB per sample)
        if not os.path.exists(video_code_path):
            raise FileNotFoundError(f"Video code not found: {video_code_path}")
        video_codes_np = np.load(video_code_path) # [F', H', W']
        # Use torch.tensor() instead of torch.from_numpy() to ensure a complete copy
        # This avoids "Trying to resize storage that is not resizable" errors in DataLoader collate
        video_codes = torch.tensor(video_codes_np, dtype=torch.int32) # CPU tensor, int32, fully copied
        del video_codes_np # Release numpy array reference
        # Load text embedding
        # Note: We load directly (not mmap) to avoid storage sharing issues with torch
        if not os.path.exists(text_embedding_path):
            raise FileNotFoundError(f"Text embedding not found: {text_embedding_path}")
        text_embedding_np = np.load(text_embedding_path) # [L, D]
        # Use torch.tensor() instead of torch.from_numpy() to ensure a complete copy
        # Preserve original dtype (should be float16 from extraction)
        text_embedding_dtype = torch.float16 if text_embedding_np.dtype == np.float16 else torch.float32
        text_embedding = torch.tensor(text_embedding_np, dtype=text_embedding_dtype) # CPU tensor, fully copied
        del text_embedding_np # Release numpy array reference
        return {
            "video_codes": video_codes, # [F', H', W'], CPU tensor, int32
            "text_embedding": text_embedding, # [L, D], CPU tensor, float16/bfloat16
            "sample_index": sample_idx,
        }