# Copyright 2024 The HuggingFace Team and The MeissonFlow Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL.ImageOps import exif_transpose
from PIL import Image
import io
import pyarrow.parquet as pq
import random
import bisect
import pyarrow.fs as fs
import csv
import numpy as np
import logging

logger = logging.getLogger(__name__)


@torch.no_grad()
def tokenize_prompt(tokenizer, prompt, text_encoder_architecture='open_clip'):
    """Tokenize ``prompt`` for the given text-encoder architecture.

    Supported architectures:
      * ``'CLIP'`` / ``'open_clip'``: padded/truncated to 77 tokens.
      * ``'umt5-base'`` / ``'umt5-xxl'`` / ``'t5'``: padded/truncated to 512 tokens.
      * ``'CLIP_T5_base'``: ``tokenizer`` is a ``[clip_tokenizer, t5_tokenizer]`` pair;
        the CLIP ids use 77 tokens and the T5 ids use 512 tokens.

    Returns:
        The ``input_ids`` produced with ``return_tensors="pt"`` (callers in this
        module index ``[0]`` to drop what they treat as a batch dimension), or, for
        ``'CLIP_T5_base'``, a list ``[clip_input_ids, t5_input_ids]``.

    Raises:
        ValueError: If ``text_encoder_architecture`` is not recognised.
    """
    if text_encoder_architecture == 'CLIP' or text_encoder_architecture == 'open_clip':
        return tokenizer(
            prompt,
            truncation=True,
            padding="max_length",
            max_length=77,
            return_tensors="pt",
        ).input_ids
    elif text_encoder_architecture in ['umt5-base', 'umt5-xxl', 't5']:
        # T5/UMT5 tokenizer
        return tokenizer(
            prompt,
            truncation=True,
            padding="max_length",
            max_length=512,
            return_tensors="pt",
        ).input_ids
    elif text_encoder_architecture == 'CLIP_T5_base':
        # Two tokenizers: index 0 is CLIP, index 1 is T5.
        input_ids = []
        input_ids.append(tokenizer[0](
            prompt,
            truncation=True,
            padding="max_length",
            max_length=77,
            return_tensors="pt",
        ).input_ids)
        input_ids.append(tokenizer[1](
            prompt,
            truncation=True,
            padding="max_length",
            max_length=512,
            return_tensors="pt",
        ).input_ids)
        return input_ids
    else:
        raise ValueError(f"Unknown text_encoder_architecture: {text_encoder_architecture}")


def encode_prompt(text_encoder, input_ids, text_encoder_architecture='open_clip'):
    """Run the text encoder(s) on already-tokenized prompts.

    Args:
        text_encoder: A single encoder, or a ``[clip_encoder, t5_model]`` pair when
            ``text_encoder_architecture == 'CLIP_T5_base'``.
        input_ids: Output of :func:`tokenize_prompt` for the same architecture.
        text_encoder_architecture: Same values accepted by :func:`tokenize_prompt`.

    Returns:
        ``(encoder_hidden_states, cond_embeds)``:
          * CLIP / open_clip: the penultimate hidden state and ``outputs[0]``
            (presumably the projected text embedding -- confirm against the encoder class).
          * T5 / UMT5: ``last_hidden_state`` and ``None`` (no pooled projection).
          * CLIP_T5_base: the penultimate T5 encoder hidden state and the CLIP ``outputs[0]``.

    Raises:
        ValueError: If ``text_encoder_architecture`` is not recognised.
    """
    if text_encoder_architecture == 'CLIP' or text_encoder_architecture == 'open_clip':
        outputs = text_encoder(input_ids=input_ids, return_dict=True, output_hidden_states=True)
        encoder_hidden_states = outputs.hidden_states[-2]
        cond_embeds = outputs[0]
        return encoder_hidden_states, cond_embeds
    elif text_encoder_architecture in ['umt5-base', 'umt5-xxl', 't5']:
        # T5/UMT5 encoder: there is no pooled projection, so cond_embeds is None.
        # The video pipeline does not use cond_embeds.
        outputs = text_encoder(input_ids=input_ids, return_dict=True)
        encoder_hidden_states = outputs.last_hidden_state
        cond_embeds = None
        return encoder_hidden_states, cond_embeds
    elif text_encoder_architecture == 'CLIP_T5_base':
        outputs_clip = text_encoder[0](input_ids=input_ids[0], return_dict=True, output_hidden_states=True)
        # The T5 model is called with all-zero decoder inputs; only its encoder
        # hidden states are used below (presumably a full seq2seq model -- confirm).
        outputs_t5 = text_encoder[1](input_ids=input_ids[1], decoder_input_ids=torch.zeros_like(input_ids[1]),
                                     return_dict=True, output_hidden_states=True)
        encoder_hidden_states = outputs_t5.encoder_hidden_states[-2]
        cond_embeds = outputs_clip[0]
        return encoder_hidden_states, cond_embeds
    else:
        raise ValueError(f"Unknown text_encoder_architecture: {text_encoder_architecture}")


def _tokenize_for_sample(tokenizer, prompt, text_encoder_architecture):
    """Tokenize one prompt and drop the leading dimension added by ``tokenize_prompt``.

    Returns a single tensor for a single tokenizer, or ``[clip_ids, t5_ids]`` when
    ``tokenizer`` is a list -- the layout every dataset here stores under
    ``"prompt_input_ids"``.
    """
    input_ids = tokenize_prompt(tokenizer, prompt, text_encoder_architecture)
    if isinstance(tokenizer, list):
        return [input_ids[0][0], input_ids[1][0]]
    return input_ids[0]


def _clone_input_ids(input_ids):
    """Return an independent copy of tokenized ids (a tensor or a list of tensors)."""
    if isinstance(input_ids, list):
        return [ids.clone() for ids in input_ids]
    return input_ids.clone()


def process_image(image, size, Norm=False, hps_score=6.0):
    """Resize and randomly crop a PIL image to ``size x size`` and build micro-conditions.

    Args:
        image: PIL image; EXIF orientation is applied and it is converted to RGB.
        size: Output side length. ``transforms.Resize(size)`` is applied before a
            random ``size x size`` crop.
        Norm: If True, normalize channels with mean 0.5 / std 0.5 (maps [0, 1] to [-1, 1]).
        hps_score: Value appended as the last micro-condition.

    Returns:
        ``{"image": tensor, "micro_conds": tensor([orig_width, orig_height, c_top, c_left, hps_score])}``.
    """
    image = exif_transpose(image)
    if not image.mode == "RGB":
        image = image.convert("RGB")

    orig_height = image.height
    orig_width = image.width

    image = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)(image)

    c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(size, size))
    image = transforms.functional.crop(image, c_top, c_left, size, size)
    image = transforms.ToTensor()(image)

    if Norm:
        image = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)(image)

    micro_conds = torch.tensor(
        [orig_width, orig_height, c_top, c_left, hps_score],
    )

    return {"image": image, "micro_conds": micro_conds}


class MyParquetDataset(Dataset):
    """Image-caption dataset read from Parquet shards on HDFS.

    Shard paths are generated from the hard-coded ``dataset_receipt`` (sub-directory
    -> total shard count and sampling ratio). Rows are addressed globally via a
    cumulative row-count table, and one shard at a time is cached in memory.
    """

    def __init__(self, root_dir, tokenizer=None, size=512, text_encoder_architecture='CLIP', norm=False):
        # Private RNG: shard sampling stays reproducible (same draws as seeding the
        # global RNG with 23) without reseeding the process-wide ``random`` module.
        self._rng = random.Random(23)
        self.root_dir = root_dir
        self.dataset_receipt = {'MSCOCO_part1': {'total_num': 6212, 'ratio': 1},
                                'MSCOCO_part2': {'total_num': 6212, 'ratio': 1}}
        self.tokenizer = tokenizer
        self.size = size
        self.text_encoder_architecture = text_encoder_architecture
        self.norm = norm

        self.hdfs = fs.HadoopFileSystem(host="", port=0000)  # TODO: change to your own HDFS host and port
        self._init_mixed_parquet_dir_list()

        self.file_metadata = []
        # cumulative_sizes[i] is the global index of the first row of file i.
        self.cumulative_sizes = [0]
        total = 0
        for path in self.parquet_files:
            try:
                with pq.ParquetFile(path, filesystem=self.hdfs) as pf:
                    num_rows = pf.metadata.num_rows
                    self.file_metadata.append({
                        'path': path,
                        'num_rows': num_rows,
                        'global_offset': total
                    })
                    total += num_rows
                    self.cumulative_sizes.append(total)
            except Exception as e:
                # Unreadable shards are skipped rather than aborting dataset construction.
                print(f"Error processing {path}: {str(e)}")
                continue

        # Single-file cache used by __getitem__.
        self.current_file = None
        self.cached_data = None
        self.cached_file_index = -1

    def _init_mixed_parquet_dir_list(self):
        """Build ``self.parquet_files`` by sampling shard paths from every sub-dataset."""
        print('Loading parquet files, please be patient...')
        self.parquet_files = []

        for key, value in self.dataset_receipt.items():
            # Shard paths are generated by name; nothing is opened here.
            hdfs_path = os.path.join(self.root_dir, key)
            num = value['total_num']
            sampled_list = self._rng.sample(
                [f"{hdfs_path}/train-{idx:05d}-of-{num:05d}.parquet" for idx in range(num)],
                k=int(num * value['ratio'])
            )
            self.parquet_files += sampled_list

    def __len__(self):
        return self.cumulative_sizes[-1]

    def _locate_file(self, global_idx):
        """Map a global row index to ``(file_index, local_row_index)`` via binary search.

        Raises:
            IndexError: If ``global_idx`` is outside ``[0, len(self))``.
        """
        file_index = bisect.bisect_right(self.cumulative_sizes, global_idx) - 1

        if file_index < 0 or file_index >= len(self.file_metadata):
            raise IndexError(f"Index {global_idx} out of range")

        file_info = self.file_metadata[file_index]
        local_idx = global_idx - file_info['global_offset']
        return file_index, local_idx

    def _load_file(self, file_index):
        """Load Parquet files into cache on demand"""
        if self.cached_file_index != file_index:
            file_info = self.file_metadata[file_index]
            try:
                table = pq.read_table(file_info['path'], filesystem=self.hdfs)
                self.cached_data = table.to_pydict()
                self.cached_file_index = file_index
            except Exception as e:
                print(f"Error loading {file_info['path']}: {str(e)}")
                raise

    def __getitem__(self, idx):
        file_index, local_idx = self._locate_file(idx)
        self._load_file(file_index)
        sample = {k: v[local_idx] for k, v in self.cached_data.items()}

        # Column layout specific to the author's data: 'task2' holds the caption and
        # 'image' a record whose 'bytes' field is the encoded image.
        generated_caption, image_path = sample['task2'], sample['image']
        instance_image = Image.open(io.BytesIO(image_path['bytes']))

        rv = process_image(instance_image, self.size, self.norm)
        rv["prompt_input_ids"] = _tokenize_for_sample(self.tokenizer, generated_caption,
                                                      self.text_encoder_architecture)
        return rv


class HuggingFaceDataset(Dataset):
    """Image-caption dataset wrapping an indexable HuggingFace-style dataset."""

    def __init__(
            self,
            hf_dataset,
            tokenizer,
            image_key,
            prompt_key,
            prompt_prefix=None,
            size=512,
            text_encoder_architecture='CLIP',
    ):
        self.size = size
        self.image_key = image_key
        self.prompt_key = prompt_key
        self.tokenizer = tokenizer
        self.hf_dataset = hf_dataset
        self.prompt_prefix = prompt_prefix
        self.text_encoder_architecture = text_encoder_architecture

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, index):
        item = self.hf_dataset[index]

        rv = process_image(item[self.image_key], self.size)

        prompt = item[self.prompt_key]

        if self.prompt_prefix is not None:
            prompt = self.prompt_prefix + prompt

        rv["prompt_input_ids"] = _tokenize_for_sample(self.tokenizer, prompt, self.text_encoder_architecture)

        return rv


def process_video(video_tensor, num_frames, height, width, use_random_crop=True):
    """
    Process a video tensor for training.

    Frames are brought to exactly ``num_frames`` (short clips are padded by
    repeating the last frame, long clips are cut to a random contiguous window),
    then spatially resized with the aspect ratio preserved so both sides are at
    least the target size, and finally cropped to ``height x width``.

    Args:
        video_tensor: 4-D tensor in [C, F, H, W] or [F, H, W, C] layout (C is 1 or 3).
            Integer tensors are treated as 0-255 data; float tensors whose maximum
            exceeds 1.0 are also divided by 255.
        num_frames: Target number of frames.
        height: Target height.
        width: Target width.
        use_random_crop: If True, use a random crop (training). If False, use a
            center crop (validation/feature extraction).

    Returns:
        Float tensor of shape [C, num_frames, height, width] in the [0, 1] range.

    Raises:
        ValueError: If the tensor is not 4-D, its layout cannot be recognised, or it
            contains no frames.
    """
    if video_tensor.dim() != 4:
        raise ValueError(f"Unexpected video tensor shape: {video_tensor.shape}")

    # Ensure video is in [C, F, H, W] format.
    if video_tensor.shape[0] in (1, 3):
        pass  # Already [C, F, H, W].
    elif video_tensor.shape[-1] in (1, 3):
        video_tensor = video_tensor.permute(3, 0, 1, 2)  # [F, H, W, C] -> [C, F, H, W]
    else:
        raise ValueError(f"Unexpected video tensor shape: {video_tensor.shape}")

    # Normalize to [0, 1]. Integer tensors are always 0-255 data and must become
    # float anyway (bilinear interpolation does not accept integer tensors).
    if not video_tensor.is_floating_point():
        video_tensor = video_tensor.float() / 255.0
    elif video_tensor.max() > 1.0:
        video_tensor = video_tensor / 255.0

    C, F, H, W = video_tensor.shape
    if F == 0:
        raise ValueError("Video tensor contains no frames")

    # Temporal resampling: ensure exactly num_frames frames.
    if F < num_frames:
        # Shorter clip: pad by repeating the last frame.
        padding = video_tensor[:, -1:, :, :].repeat(1, num_frames - F, 1, 1)
        video_tensor = torch.cat([video_tensor, padding], dim=1)
    elif F > num_frames:
        # Longer clip: take a random contiguous window of num_frames.
        start_idx = random.randint(0, F - num_frames)
        video_tensor = video_tensor[:, start_idx:start_idx + num_frames, :, :]
    F = num_frames

    # Spatial resizing: aspect-ratio preserving resize + crop.
    if H != height or W != width:
        # The larger scale makes both sides at least the target; int() truncates, so
        # clamp as well. After this the crop below is always in bounds.
        scale = max(height / H, width / W)
        new_H = max(int(H * scale), height)
        new_W = max(int(W * scale), width)

        # Resize every frame of every channel as an independent 1-channel image.
        video_tensor = torch.nn.functional.interpolate(
            video_tensor.reshape(C * F, 1, H, W),
            size=(new_H, new_W),
            mode='bilinear',
            align_corners=False
        ).reshape(C, F, new_H, new_W)

        if use_random_crop:
            # Random crop for training (data augmentation).
            crop_h = random.randint(0, new_H - height)
            crop_w = random.randint(0, new_W - width)
        else:
            # Center crop for validation/feature extraction (deterministic).
            crop_h = (new_H - height) // 2
            crop_w = (new_W - width) // 2
        video_tensor = video_tensor[:, :, crop_h:crop_h + height, crop_w:crop_w + width]

    # Final verification: ensure output has exactly the expected shape.
    C, F, H, W = video_tensor.shape
    assert F == num_frames, f"Frame count mismatch: expected {num_frames}, got {F}"
    assert H == height, f"Height mismatch: expected {height}, got {H}"
    assert W == width, f"Width mismatch: expected {width}, got {W}"

    return video_tensor


class VideoDataset(Dataset):
    """
    Dataset for video training, compatible with HuggingFace datasets format.
    Items are expected to hold the video (a list of PIL images/tensors, or a tensor)
    under ``video_key`` and the caption under ``prompt_key``.
    """

    def __init__(
            self,
            hf_dataset,
            tokenizer,
            video_key="video",
            prompt_key="caption",
            prompt_prefix=None,
            num_frames=16,
            height=480,
            width=848,
            text_encoder_architecture='umt5-base',
            use_random_crop=True,  # Random crop for training, center crop for validation
    ):
        self.hf_dataset = hf_dataset
        self.tokenizer = tokenizer
        self.video_key = video_key
        self.prompt_key = prompt_key
        self.prompt_prefix = prompt_prefix
        self.num_frames = num_frames
        self.height = height
        self.width = width
        self.text_encoder_architecture = text_encoder_architecture
        self.use_random_crop = use_random_crop

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, index):
        item = self.hf_dataset[index]
        video = item[self.video_key]

        if isinstance(video, list):
            # List of PIL images and/or [C, H, W] tensors, stacked along a new frame axis.
            frames = [transforms.ToTensor()(frame) if isinstance(frame, Image.Image) else frame
                      for frame in video]
            video_tensor = torch.stack(frames, dim=1)  # [C, F, H, W]
        elif isinstance(video, torch.Tensor):
            video_tensor = video
        else:
            raise ValueError(f"Unsupported video type: {type(video)}")

        # process_video guarantees [C, num_frames, height, width]; honor the
        # configured crop mode (center crop for validation).
        video_tensor = process_video(video_tensor, self.num_frames, self.height, self.width,
                                     use_random_crop=self.use_random_crop)

        # Clone to ensure storage is resizable (required for DataLoader collate).
        video_tensor = video_tensor.contiguous().clone()

        prompt = item[self.prompt_key]
        if self.prompt_prefix is not None:
            prompt = self.prompt_prefix + prompt
        prompt_input_ids = _clone_input_ids(
            _tokenize_for_sample(self.tokenizer, prompt, self.text_encoder_architecture))

        return {
            "video": video_tensor,  # [C, num_frames, height, width], guaranteed shape
            "prompt_input_ids": prompt_input_ids
        }


class OpenVid1MDataset(Dataset):
    """
    Dataset for OpenVid1M video-text pairs from CSV file.

    CSV format: video,caption,aesthetic score,motion score,temporal consistency score,camera motion,frame,fps,seconds,new_id
    Only the ``video`` (path relative to ``video_root_dir``) and ``caption`` columns are used.

    Returns:
        dict with keys:
            - "video": torch.Tensor of shape [C, F, H, W] in [0, 1] range
            - "prompt_input_ids": tokenized prompt (tensor, or list of tensors for a list tokenizer)
    """

    def __init__(
            self,
            csv_path,
            video_root_dir,
            tokenizer,
            num_frames=16,
            height=480,
            width=848,
            text_encoder_architecture='umt5-base',
            prompt_prefix=None,
            use_random_temporal_crop=True,  # If False, always sample from the beginning
            use_random_crop=True,  # Random crop for training, center crop for validation/feature extraction
    ):
        """
        Args:
            csv_path: Path to the CSV file containing video metadata
            video_root_dir: Root directory where video files are stored
            tokenizer: Text tokenizer
            num_frames: Target number of frames to extract
            height: Target height
            width: Target width
            text_encoder_architecture: Architecture of text encoder
            prompt_prefix: Optional prefix to add to prompts
            use_random_temporal_crop: If True, read a random contiguous window of frames;
                otherwise start at frame 0.
            use_random_crop: If True, random spatial crop; otherwise center crop.

        Raises:
            ImportError: If none of decord, PyAV or OpenCV is installed.
        """
        self.csv_path = csv_path
        self.video_root_dir = video_root_dir
        self.tokenizer = tokenizer
        self.num_frames = num_frames
        self.height = height
        self.width = width
        self.text_encoder_architecture = text_encoder_architecture
        self.prompt_prefix = prompt_prefix
        self.use_random_temporal_crop = use_random_temporal_crop
        self.use_random_crop = use_random_crop

        # Load CSV data
        with open(csv_path, 'r', encoding='utf-8') as f:
            self.data = list(csv.DictReader(f))

        logger.info(f"Loaded {len(self.data)} video entries from {csv_path}")

        # Pick the first available video backend, in order of preference.
        self.video_loader = None
        try:
            import decord
            decord.bridge.set_bridge('torch')
            self.video_loader = 'decord'
            logger.info("Using decord for video loading")
        except ImportError:
            try:
                import av
                self.video_loader = 'av'
                logger.info("Using PyAV for video loading")
            except ImportError:
                try:
                    import cv2
                    self.video_loader = 'cv2'
                    logger.info("Using OpenCV for video loading")
                except ImportError:
                    raise ImportError(
                        "No video loading library found. Please install one of: "
                        "decord (pip install decord), PyAV (pip install av), or opencv-python (pip install opencv-python)"
                    )

    def __len__(self):
        return len(self.data)

    def _sample_frame_indices(self, total_frames):
        """Return the contiguous ``range`` of frame indices to read.

        Clips with at most ``num_frames`` frames are read entirely (``process_video``
        pads them later). Longer clips yield a ``num_frames`` window that starts at a
        random offset when ``use_random_temporal_crop`` is set, otherwise at 0.
        """
        if total_frames <= self.num_frames:
            return range(total_frames)
        start_idx = 0
        if self.use_random_temporal_crop:
            start_idx = random.randint(0, total_frames - self.num_frames)
        return range(start_idx, start_idx + self.num_frames)

    def _load_video_decord(self, video_path):
        """Load video using decord; returns [C, F, H, W] float in [0, 1]."""
        import decord
        vr = decord.VideoReader(video_path, ctx=decord.cpu(0))
        indices = list(self._sample_frame_indices(len(vr)))

        frames = vr.get_batch(indices)  # [F, H, W, C] uint8
        if isinstance(frames, torch.Tensor):
            # Torch bridge active: already a torch tensor.
            frames = frames.float()
        else:
            # Native decord NDArray (e.g. the torch bridge was not set in this worker
            # process): go through numpy. torch.tensor() copies fully, avoiding
            # "Trying to resize storage that is not resizable" in DataLoader collate.
            if hasattr(frames, 'asnumpy'):
                frames = frames.asnumpy()
            frames = torch.tensor(frames, dtype=torch.float32)

        frames = frames.permute(3, 0, 1, 2)  # [C, F, H, W]
        return frames / 255.0

    def _load_video_av(self, video_path):
        """Load video using PyAV; returns [C, F, H, W] float in [0, 1].

        Fewer than ``num_frames`` frames may be returned for short clips;
        ``process_video`` pads them by repeating the last frame (same as the decord path).
        """
        import av
        container = av.open(video_path)
        try:
            video_stream = container.streams.video[0]
            if video_stream.frames > 0:
                # Frame count known up front: decode sequentially and keep the window.
                frame_indices = self._sample_frame_indices(video_stream.frames)
                frames = []
                for frame_idx, frame in enumerate(container.decode(video_stream)):
                    if frame_idx in frame_indices:
                        frames.append(transforms.ToTensor()(frame.to_image()))  # [C, H, W]
                        if len(frames) >= self.num_frames:
                            break
            else:
                # Unknown frame count: decode everything, then sample.
                decoded = list(container.decode(video_stream))
                frame_indices = self._sample_frame_indices(len(decoded))
                frames = [transforms.ToTensor()(decoded[i].to_image()) for i in frame_indices]
        finally:
            container.close()

        if len(frames) == 0:
            raise ValueError(f"No frames extracted from {video_path}")

        return torch.stack(frames, dim=1)  # [C, F, H, W]

    def _load_video_cv2(self, video_path):
        """Load video using OpenCV; returns [C, F, H, W] float in [0, 1].

        Fewer than ``num_frames`` frames may be returned for short clips;
        ``process_video`` pads them by repeating the last frame (same as the decord path).
        """
        import cv2
        cap = cv2.VideoCapture(video_path)
        frames = []
        try:
            frame_indices = self._sample_frame_indices(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)))
            frame_idx = 0
            while len(frames) < self.num_frames:
                ret, frame = cap.read()
                if not ret:
                    break
                if frame_idx in frame_indices:
                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    # torch.tensor() copies fully, avoiding "Trying to resize storage that
                    # is not resizable" errors in DataLoader collate.
                    frames.append(torch.tensor(frame_rgb, dtype=torch.float32).permute(2, 0, 1) / 255.0)
                frame_idx += 1
        finally:
            cap.release()

        if len(frames) == 0:
            raise ValueError(f"No frames extracted from {video_path}")

        return torch.stack(frames, dim=1)  # [C, F, H, W]

    def _load_video(self, video_path):
        """Load video from path using the available video loader"""
        full_path = os.path.join(self.video_root_dir, video_path)

        if not os.path.exists(full_path):
            raise FileNotFoundError(f"Video file not found: {full_path}")

        if self.video_loader == 'decord':
            return self._load_video_decord(full_path)
        elif self.video_loader == 'av':
            return self._load_video_av(full_path)
        elif self.video_loader == 'cv2':
            return self._load_video_cv2(full_path)
        else:
            raise ValueError(f"Unknown video loader: {self.video_loader}")

    def __getitem__(self, index):
        row = self.data[index]

        video_path = row['video']
        try:
            video_tensor = self._load_video(video_path)
        except Exception as e:
            # Best-effort: an unreadable video becomes an all-black clip instead of
            # killing the DataLoader worker.
            logger.warning(f"Failed to load video {video_path}: {e}")
            video_tensor = torch.zeros(3, self.num_frames, self.height, self.width)

        # Aspect-ratio preserving resize + crop; output is guaranteed
        # [C, num_frames, height, width].
        video_tensor = process_video(video_tensor, self.num_frames, self.height, self.width,
                                     use_random_crop=self.use_random_crop)

        # Clone to ensure storage is resizable (required for DataLoader collate).
        video_tensor = video_tensor.contiguous().clone()

        prompt = row['caption']
        if self.prompt_prefix is not None:
            prompt = self.prompt_prefix + prompt
        prompt_input_ids = _clone_input_ids(
            _tokenize_for_sample(self.tokenizer, prompt, self.text_encoder_architecture))

        return {
            "video": video_tensor,  # [C, num_frames, height, width], guaranteed shape
            "prompt_input_ids": prompt_input_ids
        }


class TinyOpenVid1MDataset(OpenVid1MDataset):
    """
    A tiny subset of OpenVid1MDataset for overfitting experiments.
    Keeps a reproducible random subset of at most ``max_samples`` rows
    (chosen by shuffling with ``seed``), or every row if there are fewer.
    """

    def __init__(
            self,
            csv_path,
            video_root_dir=None,
            tokenizer=None,
            num_frames=16,
            height=480,
            width=848,
            text_encoder_architecture='umt5-base',
            prompt_prefix=None,
            max_samples=256,  # Maximum number of samples kept
            seed=42,  # Fixed seed for reproducibility
            use_random_temporal_crop=True,
            use_random_crop=True,
    ):
        """
        Args:
            max_samples: Maximum number of samples to use (default: 256)
            seed: Random seed for reproducibility (default: 42)
            use_random_temporal_crop: Forwarded to OpenVid1MDataset.
            use_random_crop: Forwarded to OpenVid1MDataset.
        """
        super().__init__(
            csv_path=csv_path,
            video_root_dir=video_root_dir,
            tokenizer=tokenizer,
            num_frames=num_frames,
            height=height,
            width=width,
            text_encoder_architecture=text_encoder_architecture,
            prompt_prefix=prompt_prefix,
            use_random_temporal_crop=use_random_temporal_crop,
            use_random_crop=use_random_crop,
        )

        original_len = len(self.data)
        if original_len > max_samples:
            # Private RNG: same permutation as seeding the global RNG with ``seed``,
            # but without resetting the process-wide ``random`` state used for crops.
            indices = list(range(original_len))
            random.Random(seed).shuffle(indices)
            self.data = [self.data[i] for i in indices[:max_samples]]
            logger.info(f"Limited dataset to {max_samples} samples (from {original_len} total) for overfitting experiment")
        else:
            logger.info(f"Using all {len(self.data)} samples (less than max_samples={max_samples})")


def get_hierarchical_path(base_dir, index):
    """
    Get hierarchical path for loading features from 3-level directory structure.

    Structure: base_dir/level1/level2/level3/filename.npy
    - level1: index // 1000000 (0-999)
    - level2: (index // 1000) % 1000 (0-999)
    - level3: index % 1000 (0-999)
    The filename is the index zero-padded to 8 digits.

    Args:
        base_dir: Base directory for features
        index: Sample index

    Returns:
        Full path to the file
    """
    level1 = index // 1000000
    level2 = (index // 1000) % 1000
    level3 = index % 1000

    file_path = os.path.join(
        base_dir,
        f"{level1:03d}",
        f"{level2:03d}",
        f"{level3:03d}",
        f"{index:08d}.npy"
    )
    return file_path


class PrecomputedFeatureDataset(Dataset):
    """
    Dataset for loading pre-extracted video codes and text embeddings.

    This dataset loads features that were pre-extracted by extract_features.py,
    avoiding the need to encode videos and text during training.

    Features are stored in a 3-level hierarchical directory structure:
    - video_codes/level1/level2/level3/index.npy
    - text_embeddings/level1/level2/level3/index.npy
    """

    def __init__(
            self,
            features_dir,
            num_samples=None,
            start_index=0,
    ):
        """
        Args:
            features_dir: Directory containing extracted features (should have video_codes/ and text_embeddings/ subdirs)
            num_samples: Number of samples to use. If None, use all available samples.
            start_index: Starting index for samples (for resuming or subset selection)
        """
        self.features_dir = features_dir
        self.video_codes_dir = os.path.join(features_dir, "video_codes")
        self.text_embeddings_dir = os.path.join(features_dir, "text_embeddings")
        self.metadata_file = os.path.join(features_dir, "metadata.json")

        if os.path.exists(self.metadata_file):
            import json
            with open(self.metadata_file, 'r') as f:
                self.metadata = json.load(f)
            logger.info(f"Loaded metadata from {self.metadata_file}")
            logger.info(f"  Total samples in metadata: {self.metadata.get('num_samples', 'unknown')}")

            # Prefer the index list recorded in the metadata.
            if 'samples' in self.metadata and len(self.metadata['samples']) > 0:
                available_indices = sorted([s['index'] for s in self.metadata['samples']])
            else:
                # Fallback: infer from directory structure
                available_indices = self._scan_hierarchical_directory(self.video_codes_dir)
        else:
            logger.warning(f"Metadata file not found: {self.metadata_file}, scanning directory structure")
            self.metadata = {}
            available_indices = self._scan_hierarchical_directory(self.video_codes_dir)

        # Filter by start_index and num_samples (indices are sorted at this point).
        available_indices = [idx for idx in available_indices if idx >= start_index]
        if num_samples is not None:
            available_indices = available_indices[:num_samples]

        self.indices = available_indices
        logger.info(f"PrecomputedFeatureDataset: {len(self.indices)} samples available")
        if len(self.indices) > 0:
            logger.info(f"  Index range: {min(self.indices)} to {max(self.indices)}")

    def _scan_hierarchical_directory(self, base_dir):
        """
        Scan hierarchical directory structure to find all available indices.

        Only sub-directories named ``000``-``999`` are visited at each level, and only
        ``<int>.npy`` files are counted. Directories are listed with ``os.scandir``
        rather than probing all 1000 candidate names per level.

        Args:
            base_dir: Base directory to scan

        Returns:
            Sorted list of available indices

        Raises:
            FileNotFoundError: If ``base_dir`` is not an existing directory.
        """
        if not os.path.isdir(base_dir):
            raise FileNotFoundError(f"Directory not found: {base_dir}")

        level_names = {f"{i:03d}" for i in range(1000)}

        def _level_subdirs(path):
            # Child directories whose names are exactly 000-999.
            with os.scandir(path) as entries:
                return [entry.path for entry in entries
                        if entry.name in level_names and entry.is_dir()]

        available_indices = []
        for level1_dir in _level_subdirs(base_dir):
            for level2_dir in _level_subdirs(level1_dir):
                for level3_dir in _level_subdirs(level2_dir):
                    for filename in os.listdir(level3_dir):
                        if not filename.endswith('.npy'):
                            continue
                        try:
                            available_indices.append(int(filename[:-len('.npy')]))
                        except ValueError:
                            continue

        return sorted(available_indices)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        sample_idx = self.indices[idx]

        video_code_path = get_hierarchical_path(self.video_codes_dir, sample_idx)
        text_embedding_path = get_hierarchical_path(self.text_embeddings_dir, sample_idx)

        # Load video codes directly (not mmap) to avoid storage sharing issues with torch.
        if not os.path.exists(video_code_path):
            raise FileNotFoundError(f"Video code not found: {video_code_path}")
        video_codes_np = np.load(video_code_path)  # presumably [F', H', W'] -- confirm against extract_features.py
        # torch.tensor() copies fully, avoiding "Trying to resize storage that is not
        # resizable" errors in DataLoader collate.
        video_codes = torch.tensor(video_codes_np, dtype=torch.int32)
        del video_codes_np

        # Load text embedding directly (not mmap) for the same reason.
        if not os.path.exists(text_embedding_path):
            raise FileNotFoundError(f"Text embedding not found: {text_embedding_path}")
        text_embedding_np = np.load(text_embedding_path)  # presumably [L, D]
        # Keep float16 if stored that way; any other dtype becomes float32.
        text_embedding_dtype = torch.float16 if text_embedding_np.dtype == np.float16 else torch.float32
        text_embedding = torch.tensor(text_embedding_np, dtype=text_embedding_dtype)
        del text_embedding_np

        return {
            "video_codes": video_codes,  # CPU tensor, int32
            "text_embedding": text_embedding,  # CPU tensor, float16 or float32
            "sample_index": sample_idx,
        }