Safetensors
English
llava
video-retrieval
text-to-video-search
multimodal-embedding
TARA / shared / scripts / create_webdataset.py
bpiyush's picture
Update TARA to latest Tarsier2 checkpoint and runnable demo.
7daf628
import os
import csv
import argparse
import multiprocessing as mp
from pathlib import Path
from typing import List, Dict
from functools import partial
import webdataset as wds
import torch
import numpy as np
from tqdm import tqdm
from decord import VideoReader
def parse_args():
    """Parse the command-line options of the CSV-to-WebDataset converter.

    Returns:
        The argparse namespace with csv_path, output_dir, num_shards,
        samples_per_shard, worker_count, shard_prefix, video_extension,
        debug, si and ei.
    """
    cli = argparse.ArgumentParser(description="Convert CSV file to WebDataset format with video data")

    # One (flag, add_argument keyword arguments) entry per option, in the
    # order they should be registered.
    option_specs = (
        ("--csv_path", dict(type=str, required=True, help="Path to the CSV file")),
        ("--output_dir", dict(type=str, required=True, help="Output directory for WebDataset shards")),
        ("--num_shards", dict(type=int, default=128, help="Number of shards to create")),
        ("--samples_per_shard", dict(type=int, default=None,
                                     help="Max samples per shard (overrides num_shards if specified)")),
        ("--worker_count", dict(type=int, default=mp.cpu_count(),
                                help="Number of worker processes")),
        ("--shard_prefix", dict(type=str, default="shard",
                                help="Prefix for shard filenames")),
        ("--video_extension", dict(type=str, default=".webm",
                                   help="Extension of video files (default: .webm)")),
        ("--debug", dict(action="store_true",
                         help="Debug mode: create a shard_debug.tar file with max 1000 videos")),
        ("--si", dict(type=int, default=0, help="Start index")),
        ("--ei", dict(type=int, default=None, help="End index")),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    return cli.parse_args()
def read_csv_data(csv_path: str, debug: bool = False,
                  si: int = None, ei: int = None) -> List[Dict]:
    """Read the CSV file and return the selected rows as a list of dicts.

    Args:
        csv_path: Path to a CSV file with a header row.
        debug: When True, stop reading after the first 1000 rows.
        si: Start index of the slice to keep. When omitted, falls back to
            ``args.si`` of the module-level CLI namespace if the script
            defined one, otherwise 0.
        ei: End index (exclusive) of the slice to keep. When omitted, falls
            back to ``args.ei`` of the module-level CLI namespace if present,
            otherwise the number of rows read.

    Returns:
        The rows ``[si:ei]``, each as produced by ``csv.DictReader``.
    """
    samples = []
    # newline="" lets the csv module do its own newline handling, so quoted
    # fields containing line breaks are parsed correctly.
    with open(csv_path, 'r', newline='') as f:
        reader = csv.DictReader(f)
        for i, row in enumerate(reader):
            samples.append(row)
            # In debug mode, limit to 1000 samples
            if debug and i >= 999:
                break

    # Previously this function read the global `args` unconditionally, which
    # raised NameError whenever the module was imported instead of executed.
    # The global is now only an optional fallback for the CLI entry point.
    cli_args = globals().get("args")
    if si is None:
        si = getattr(cli_args, "si", None)
    if si is None:
        si = 0
    if ei is None:
        ei = getattr(cli_args, "ei", None)
    if ei is None:
        ei = len(samples)

    print("Selected samples from index", si, "to", ei)
    return samples[si:ei]
def distribute_samples(samples: List[Dict], num_shards: int) -> List[List[Dict]]:
    """Split samples into num_shards contiguous, near-equal chunks.

    Every chunk receives ``len(samples) // num_shards`` items, and the first
    ``len(samples) % num_shards`` chunks receive one more, so the input order
    is preserved across the concatenated chunks.
    """
    base_size, extra = divmod(len(samples), num_shards)

    chunks = []
    cursor = 0
    for shard_idx in range(num_shards):
        # The leading `extra` shards absorb the leftover samples, one each.
        size = base_size + 1 if shard_idx < extra else base_size
        chunks.append(samples[cursor:cursor + size])
        cursor += size
    return chunks
def process_shard(shard_samples: List[Dict], shard_path: str, video_extension: str = ".webm"):
    """Write a single shard holding the raw bytes of each sample's video file.

    Args:
        shard_samples: Rows for this shard; each row is read for its
            'video_path' and 'target' entries.
        shard_path: Path of the tar file to create.
        video_extension: Extension stored with every sample (leading dot
            removed).

    A sample whose video cannot be opened, probed or read is reported on
    stdout and skipped; the remaining samples are still written.
    """
    with wds.TarWriter(shard_path) as sink:
        for sample in tqdm(shard_samples, desc=f"Processing {shard_path}"):
            video_path = sample['video_path']
            try:
                # Probing is done inside the try block so that a corrupt or
                # missing video skips only this sample instead of aborting
                # the whole shard.
                vr = VideoReader(video_path, num_threads=1)
                n_frames = len(vr)
                fps = vr.get_avg_fps()
                H, W, _ = vr[0].shape

                # Read video file as binary data
                with open(video_path, 'rb') as f:
                    video_data = f.read()

                # Get filename without path for the key
                # NOTE(review): samples sharing a stem get the same key, and a
                # stem containing '.' may be split by WebDataset readers -
                # confirm the input file names are unique and dot-free.
                filename = Path(video_path).stem

                # Create sample with the actual video data
                sample_dict = {
                    "__key__": filename,
                    "video": video_data,  # Actual video binary data
                    "video.extension": video_extension.lstrip('.'),  # Store extension without dot
                    "target": str(sample['target']),  # Target/label
                    # "split": str(sample['split']),  # Train/val/test split
                    "json": dict(n_frames=n_frames, fps=fps, H=H, W=W),  # Additional metadata
                }
                sink.write(sample_dict)
            except Exception as e:
                print(f"Error processing {video_path}: {str(e)}")
import io
import torchvision
def encode_tensor(tensor):
    """
    Serialize a torch tensor to bytes with torch.save, entirely in memory.

    Anything that is not a torch.Tensor yields an empty bytes object.
    """
    # Nothing is written for non-tensor inputs, so the result is empty.
    if not isinstance(tensor, torch.Tensor):
        return b""

    with io.BytesIO() as stream:
        torch.save(tensor, stream)
        payload = stream.getvalue()
    return payload
def process_shard_tensor(shard_samples: List[Dict], shard_path: str, video_extension: str = ".webm"):
    """Write a single shard holding each video as a serialized frame tensor.

    Args:
        shard_samples: Rows for this shard; each row is read for its
            'video_path' and 'target' entries.
        shard_path: Path of the tar file to create.
        video_extension: Extension stored with every sample (leading dot
            removed).

    The "video" entry is the output of encode_tensor on the decoded frames,
    not the original file bytes. A sample that fails to decode or encode is
    reported on stdout and skipped; the remaining samples are still written.
    """
    with wds.TarWriter(shard_path) as sink:
        for sample in tqdm(shard_samples, desc=f"Processing {shard_path}"):
            video_path = sample['video_path']
            try:
                # Decoding is done inside the try block so that an unreadable
                # video (or one without 'video_fps' metadata) skips only this
                # sample instead of aborting the whole shard.
                # Load the entire video as a tensor
                video, _, info = torchvision.io.read_video(video_path, pts_unit='sec')
                video_data = encode_tensor(video)
                n_frames = len(video)
                fps = info['video_fps']
                # presumably frames are laid out (T, H, W, C) - TODO confirm
                H, W = video.shape[1:-1]

                # Get filename without path for the key
                filename = Path(video_path).stem

                # Create sample with the actual video data
                # NOTE(review): "video.extension" still records video_extension
                # although the payload is a torch.save blob - confirm readers
                # expect this.
                sample_dict = {
                    "__key__": filename,
                    "video": video_data,  # Serialized frame tensor
                    "video.extension": video_extension.lstrip('.'),  # Store extension without dot
                    "target": str(sample['target']),  # Target/label
                    # "split": str(sample['split']),  # Train/val/test split
                    "json": dict(n_frames=n_frames, fps=fps, H=H, W=W),  # Additional metadata
                }
                sink.write(sample_dict)
            except Exception as e:
                print(f"Error processing {video_path}: {str(e)}")
def _write_shard_job(job):
    """Write one shard from a (shard_samples, shard_path, video_extension) tuple; return its path."""
    shard_samples, shard_path, video_extension = job
    process_shard_tensor(shard_samples, shard_path, video_extension)
    return shard_path


def create_webdataset(csv_path: str, output_dir: str, num_shards: int,
                      samples_per_shard: int = None, worker_count: int = None,
                      shard_prefix: str = "shard", video_extension: str = ".webm",
                      debug: bool = False):
    """Convert CSV to WebDataset format with video data and parallel processing.

    Args:
        csv_path: CSV file listing the samples.
        output_dir: Directory that receives the shards (created if missing).
        num_shards: Number of shards to write when samples_per_shard is None.
        samples_per_shard: Upper bound on samples per shard; when given, the
            shard count is derived from it and num_shards is ignored.
        worker_count: Number of worker processes; defaults to the CPU count
            and is capped at the number of shards.
        shard_prefix: Shard files are named ``{shard_prefix}_{index:05d}.tar``.
        video_extension: Extension recorded with every sample.
        debug: When True, write a single ``shard_debug.tar`` in-process, read
            a few samples back from it and return.

    Raises:
        ValueError: If samples_per_shard, num_shards or worker_count is not
            positive when it is used.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Read all samples from the CSV
    print(f"Reading samples from {csv_path}...")
    samples = read_csv_data(csv_path, debug=debug)
    total_samples = len(samples)
    print(f"Found {total_samples} samples in the CSV file")

    # Handle debug mode with a specific shard_debug.tar file
    if debug:
        print("Debug mode enabled: Creating shard_debug.tar with max 1000 videos")
        debug_shard_path = os.path.join(output_dir, "shard_debug.tar")
        process_shard_tensor(samples, debug_shard_path, video_extension)

        # Calculate and display file size
        file_size = os.path.getsize(debug_shard_path)
        print(f"Created debug shard: {debug_shard_path}")
        print(f"Debug shard size: {file_size / (1024**2):.2f} MB")

        # Test the debug shard
        test_dataset(output_dir, debug_pattern="shard_debug.tar")
        return

    # Determine number of shards based on samples_per_shard if provided
    if samples_per_shard is not None:
        if samples_per_shard <= 0:
            raise ValueError(f"samples_per_shard must be positive, got {samples_per_shard}")
        # Ceiling division; keep at least one shard so an empty CSV does not
        # lead to a division by zero when distributing samples.
        num_shards = max(1, (total_samples + samples_per_shard - 1) // samples_per_shard)
        print(f"Creating {num_shards} shards with max {samples_per_shard} samples per shard")
    else:
        if num_shards <= 0:
            raise ValueError(f"num_shards must be positive, got {num_shards}")
        print(f"Creating {num_shards} shards")

    # Distribute samples across shards
    shard_samples = distribute_samples(samples, num_shards)

    # Prepare shard paths
    shard_paths = [
        os.path.join(output_dir, f"{shard_prefix}_{i:05d}.tar")
        for i in range(num_shards)
    ]

    # Use all available cores if worker_count is not specified
    if worker_count is None:
        worker_count = mp.cpu_count()
    elif worker_count <= 0:
        raise ValueError(f"worker_count must be positive, got {worker_count}")
    worker_count = min(worker_count, num_shards)  # Don't use more workers than shards
    print(f"Using {worker_count} worker processes")

    # imap_unordered yields a result as each shard finishes, so the progress
    # bar advances during the run. Wrapping the return value of the blocking
    # pool.starmap only rendered the bar after all work was already done.
    jobs = [
        (one_shard, shard_path, video_extension)
        for one_shard, shard_path in zip(shard_samples, shard_paths)
    ]
    with mp.Pool(worker_count) as pool:
        for _ in tqdm(
            pool.imap_unordered(_write_shard_job, jobs),
            total=num_shards,
            desc="Creating WebDataset shards with video data",
        ):
            pass

    print(f"Successfully created {num_shards} WebDataset shards in {output_dir}")

    # Calculate and display total dataset size
    total_size = sum(os.path.getsize(path) for path in shard_paths)
    print(f"Total dataset size: {total_size / (1024**2):.2f} MB")
def test_dataset(output_dir: str, shard_prefix: str = "shard", debug_pattern: str = None):
    """Read back and print the first three samples of the created WebDataset.

    Args:
        output_dir: Directory containing the shards.
        shard_prefix: Prefix of the shard files, used when debug_pattern is
            not given (matches ``{shard_prefix}_*.tar``).
        debug_pattern: File name or pattern, relative to output_dir, to read
            instead of the regular shards.
    """
    import glob  # local import: only this function needs it

    # Find all shard files or use debug pattern
    if debug_pattern:
        shard_pattern = os.path.join(output_dir, debug_pattern)
    else:
        shard_pattern = os.path.join(output_dir, f"{shard_prefix}_*.tar")

    # WebDataset expands brace notation in shard specs, not glob wildcards,
    # so a pattern such as "shard_*.tar" has to be resolved to concrete files
    # here. Specs without wildcards are passed through unchanged.
    if any(ch in shard_pattern for ch in "*?["):
        shards = sorted(glob.glob(shard_pattern))
        if not shards:
            print(f"\nNo shards found matching {shard_pattern}")
            return
    else:
        shards = shard_pattern

    # Create a dataset
    dataset = wds.WebDataset(shards)

    # Display sample info
    print("\nTesting dataset:")
    for i, sample in enumerate(dataset):
        print(f"Sample {i}:")
        for key, value in sample.items():
            if key == "video":
                # The payload is large, so only its size is shown.
                print(f"  {key}: <binary data of length {len(value)}>")
            else:
                print(f"  {key}: {value}")
        if i >= 2:  # Just show a few samples
            break
if __name__ == "__main__":
    # Script entry point: parse the CLI options and run the conversion.
    # `args` stays a module-level name on purpose; read_csv_data looks up
    # the start/end indices on it.
    args = parse_args()

    # Forward only the options create_webdataset accepts as keyword arguments.
    forwarded_options = (
        "csv_path",
        "output_dir",
        "num_shards",
        "samples_per_shard",
        "worker_count",
        "shard_prefix",
        "video_extension",
        "debug",
    )
    create_webdataset(**{name: getattr(args, name) for name in forwarded_options})