| import os |
| import csv |
| import argparse |
| import multiprocessing as mp |
| from pathlib import Path |
| from typing import List, Dict |
| from functools import partial |
| import webdataset as wds |
| import torch |
| import numpy as np |
| from tqdm import tqdm |
| from decord import VideoReader |
|
|
|
|
def parse_args():
    """Parse the converter's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace with csv_path, output_dir, num_shards,
        samples_per_shard, worker_count, shard_prefix, video_extension,
        debug, si and ei.
    """
    # (flag, add_argument keyword arguments), in the order they appear in --help.
    option_specs = [
        ("--csv_path", dict(type=str, required=True, help="Path to the CSV file")),
        ("--output_dir", dict(type=str, required=True,
                              help="Output directory for WebDataset shards")),
        ("--num_shards", dict(type=int, default=128, help="Number of shards to create")),
        ("--samples_per_shard", dict(
            type=int, default=None,
            help="Max samples per shard (overrides num_shards if specified)")),
        ("--worker_count", dict(type=int, default=mp.cpu_count(),
                                help="Number of worker processes")),
        ("--shard_prefix", dict(type=str, default="shard",
                                help="Prefix for shard filenames")),
        ("--video_extension", dict(type=str, default=".webm",
                                   help="Extension of video files (default: .webm)")),
        ("--debug", dict(
            action="store_true",
            help="Debug mode: create a shard_debug.tar file with max 1000 videos")),
        ("--si", dict(type=int, default=0, help="Start index")),
        ("--ei", dict(type=int, default=None, help="End index")),
    ]

    cli = argparse.ArgumentParser(
        description="Convert CSV file to WebDataset format with video data"
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
|
|
|
|
def read_csv_data(csv_path: str, debug: bool = False,
                  si: int = None, ei: int = None) -> List[Dict]:
    """Read the CSV file and return the selected slice of its rows.

    Args:
        csv_path: Path to a CSV file whose first row is the header.
        debug: When True, stop reading after the first 1000 rows.
        si: Start index of the slice to keep. When None, the module-level
            ``args.si`` is used if the script defined ``args``; otherwise 0.
        ei: End index (exclusive) of the slice to keep. When None, the
            module-level ``args.ei`` is used if available; otherwise all
            rows that were read are kept.

    Returns:
        One dict per selected row, keyed by column name.
    """
    samples = []
    # newline='' is the csv module's documented way to open files so that
    # quoted fields containing line breaks are parsed correctly.
    with open(csv_path, 'r', newline='') as f:
        for i, row in enumerate(csv.DictReader(f)):
            samples.append(row)
            # The debug cap is applied before the si/ei slice below.
            if debug and i >= 999:
                break

    # Historically the slice bounds came from the global `args` set in the
    # __main__ guard. Look it up defensively so importing this module and
    # calling the function directly no longer raises NameError.
    cli_args = globals().get("args")
    if si is None:
        si = getattr(cli_args, "si", None)
    if si is None:
        si = 0
    if ei is None:
        ei = getattr(cli_args, "ei", None)
    if ei is None:
        ei = len(samples)

    print("Selected samples from index", si, "to", ei)
    return samples[si:ei]
|
|
|
|
def distribute_samples(samples: List[Dict], num_shards: int) -> List[List[Dict]]:
    """Split samples into ``num_shards`` contiguous, near-equal chunks.

    The first ``len(samples) % num_shards`` chunks receive one extra sample,
    so chunk sizes differ by at most one. Chunks may be empty when there are
    fewer samples than shards.

    Args:
        samples: Samples to distribute, in order.
        num_shards: Number of chunks to produce; must be positive.

    Returns:
        A list of ``num_shards`` lists that together contain every sample
        exactly once, in the original order.

    Raises:
        ValueError: If ``num_shards`` is not positive.
    """
    # Zero used to surface as a bare ZeroDivisionError and negatives silently
    # returned no shards at all; fail loudly with a clear message instead.
    if num_shards <= 0:
        raise ValueError(f"num_shards must be a positive integer, got {num_shards}")

    base_size, extra = divmod(len(samples), num_shards)

    shards = []
    start = 0
    for shard_idx in range(num_shards):
        stop = start + base_size + (1 if shard_idx < extra else 0)
        shards.append(samples[start:stop])
        start = stop
    return shards
|
|
|
|
def process_shard(shard_samples: List[Dict], shard_path: str, video_extension: str = ".webm"):
    """Write one shard containing the raw bytes of each sample's video file.

    For every sample the video is opened with decord to collect frame count,
    average fps and frame height/width, then the file's bytes are stored
    unchanged in the tar together with that metadata.

    Args:
        shard_samples: Rows for this shard; each needs 'video_path' and
            'target' entries.
        shard_path: Destination path of the tar shard.
        video_extension: Video file extension; stored without its leading
            dot under the "video.extension" key.

    A sample that cannot be probed, read or written is reported on stdout and
    skipped; the remaining samples are still written.
    """
    with wds.TarWriter(shard_path) as sink:
        for sample in tqdm(shard_samples, desc=f"Processing {shard_path}"):
            video_path = sample['video_path']

            try:
                # Probe inside the try block: a corrupt or missing video must
                # only skip this sample, not abort the whole shard.
                vr = VideoReader(video_path, num_threads=1)
                n_frames = len(vr)
                fps = vr.get_avg_fps()
                H, W, _ = vr[0].shape

                with open(video_path, 'rb') as f:
                    video_data = f.read()

                sink.write({
                    # File name without directory or extension is the key.
                    "__key__": Path(video_path).stem,
                    "video": video_data,
                    "video.extension": video_extension.lstrip('.'),
                    "target": str(sample['target']),
                    "json": dict(n_frames=n_frames, fps=fps, H=H, W=W),
                })
            except Exception as e:
                print(f"Error processing {video_path}: {str(e)}")
|
|
|
|
| import io |
| import torchvision |
|
|
def encode_tensor(tensor):
    """Serialize a torch tensor to bytes in memory using ``torch.save``.

    Args:
        tensor: The ``torch.Tensor`` to serialize.

    Returns:
        The bytes written by ``torch.save``.

    Raises:
        TypeError: If ``tensor`` is not a ``torch.Tensor``. Non-tensor input
            used to yield an empty payload without any error, which would
            end up as an unusable entry in the dataset.
    """
    if not isinstance(tensor, torch.Tensor):
        raise TypeError(
            f"encode_tensor expects a torch.Tensor, got {type(tensor).__name__}"
        )

    with io.BytesIO() as buf:
        torch.save(tensor, buf)
        return buf.getvalue()
|
|
|
|
def process_shard_tensor(shard_samples: List[Dict], shard_path: str, video_extension: str = ".webm"):
    """Write one shard containing each sample's decoded video as a tensor.

    Every video is decoded with ``torchvision.io.read_video`` and the frames
    tensor is serialized with ``encode_tensor`` before being stored in the tar
    alongside frame count, fps and frame height/width.

    Args:
        shard_samples: Rows for this shard; each needs 'video_path' and
            'target' entries.
        shard_path: Destination path of the tar shard.
        video_extension: Stored without its leading dot under the
            "video.extension" key.
            NOTE(review): the "video" payload here is a serialized tensor, not
            a file of this extension -- confirm readers expect that.

    A sample that cannot be decoded, encoded or written is reported on stdout
    and skipped; the remaining samples are still written.
    """
    with wds.TarWriter(shard_path) as sink:
        for sample in tqdm(shard_samples, desc=f"Processing {shard_path}"):
            video_path = sample['video_path']

            try:
                # Decode inside the try block: a corrupt or missing video must
                # only skip this sample, not abort the whole shard.
                video, _, info = torchvision.io.read_video(video_path, pts_unit='sec')
                video_data = encode_tensor(video)
                n_frames = len(video)
                fps = info['video_fps']
                # Drops the first and last axes; presumably the tensor is
                # (frames, H, W, channels) -- the unpacking requires 4 dims.
                H, W = video.shape[1:-1]

                sink.write({
                    # File name without directory or extension is the key.
                    "__key__": Path(video_path).stem,
                    "video": video_data,
                    "video.extension": video_extension.lstrip('.'),
                    "target": str(sample['target']),
                    "json": dict(n_frames=n_frames, fps=fps, H=H, W=W),
                })
            except Exception as e:
                print(f"Error processing {video_path}: {str(e)}")
|
|
|
|
def _write_shard_job(job, video_extension):
    """Write a single (samples, shard_path) job and return the shard path."""
    shard_samples, shard_path = job
    process_shard_tensor(shard_samples, shard_path, video_extension=video_extension)
    return shard_path


def create_webdataset(csv_path: str, output_dir: str, num_shards: int,
                      samples_per_shard: int = None, worker_count: int = None,
                      shard_prefix: str = "shard", video_extension: str = ".webm",
                      debug: bool = False):
    """Convert a CSV of videos into WebDataset tar shards, in parallel.

    Args:
        csv_path: CSV file listing the samples.
        output_dir: Directory for the shards; created if missing.
        num_shards: Number of shards to write when ``samples_per_shard`` is
            not given.
        samples_per_shard: If given, the shard count is derived as
            ceil(total_samples / samples_per_shard) and ``num_shards`` is
            ignored.
        worker_count: Number of worker processes; defaults to the CPU count
            and is capped at the number of shards.
        shard_prefix: Shards are named ``{shard_prefix}_{index:05d}.tar``.
        video_extension: Forwarded to the shard writer.
        debug: When True, write everything read (read_csv_data caps the rows
            in debug mode) into a single ``shard_debug.tar`` in this process,
            print a few samples back, and return.

    Raises:
        ValueError: If ``samples_per_shard`` or the resulting shard count is
            not positive.
    """
    os.makedirs(output_dir, exist_ok=True)

    print(f"Reading samples from {csv_path}...")
    samples = read_csv_data(csv_path, debug=debug)
    total_samples = len(samples)
    print(f"Found {total_samples} samples in the CSV file")

    if debug:
        print("Debug mode enabled: Creating shard_debug.tar with max 1000 videos")
        debug_shard_path = os.path.join(output_dir, "shard_debug.tar")

        process_shard_tensor(samples, debug_shard_path, video_extension)

        file_size = os.path.getsize(debug_shard_path)
        print(f"Created debug shard: {debug_shard_path}")
        print(f"Debug shard size: {file_size / (1024**2):.2f} MB")

        test_dataset(output_dir, debug_pattern="shard_debug.tar")
        return

    if samples_per_shard is not None and samples_per_shard <= 0:
        raise ValueError(
            f"samples_per_shard must be a positive integer, got {samples_per_shard}"
        )

    # With nothing to write, the derived shard count would be 0 and both the
    # sample distribution and the process pool would fail on it.
    if total_samples == 0:
        print("No samples selected; no shards were created")
        return

    if samples_per_shard is not None:
        # Ceiling division.
        num_shards = (total_samples + samples_per_shard - 1) // samples_per_shard
        print(f"Creating {num_shards} shards with max {samples_per_shard} samples per shard")
    else:
        print(f"Creating {num_shards} shards")

    if num_shards <= 0:
        raise ValueError(f"num_shards must be a positive integer, got {num_shards}")

    shard_samples = distribute_samples(samples, num_shards)
    shard_paths = [
        os.path.join(output_dir, f"{shard_prefix}_{i:05d}.tar")
        for i in range(num_shards)
    ]

    if worker_count is None:
        worker_count = mp.cpu_count()
    # Never start more workers than shards, and never fewer than one.
    worker_count = max(1, min(worker_count, num_shards))
    print(f"Using {worker_count} worker processes")

    write_job = partial(_write_shard_job, video_extension=video_extension)
    jobs = list(zip(shard_samples, shard_paths))

    with mp.Pool(worker_count) as pool:
        # imap_unordered yields as each shard finishes, so the bar advances
        # live; starmap only returned after every shard was already done.
        finished = pool.imap_unordered(write_job, jobs)
        for _ in tqdm(finished, total=num_shards,
                      desc="Creating WebDataset shards with video data"):
            pass

    print(f"Successfully created {num_shards} WebDataset shards in {output_dir}")

    total_size = sum(os.path.getsize(path) for path in shard_paths)
    print(f"Total dataset size: {total_size / (1024**2):.2f} MB")
|
|
|
|
def test_dataset(output_dir: str, shard_prefix: str = "shard", debug_pattern: str = None):
    """Read back the first three samples of the written shards and print them.

    Args:
        output_dir: Directory containing the shards.
        shard_prefix: Prefix of the shard files to read when no
            ``debug_pattern`` is given (``{shard_prefix}_*.tar``).
        debug_pattern: If given, this name inside ``output_dir`` is handed to
            WebDataset unchanged instead of searching by prefix.
    """
    import glob  # local import: only this function needs it

    if debug_pattern:
        shards = os.path.join(output_dir, debug_pattern)
    else:
        # WebDataset expands brace patterns but not shell wildcards, so the
        # '*' pattern is resolved to concrete file names here.
        shards = sorted(
            glob.glob(os.path.join(glob.escape(output_dir), f"{shard_prefix}_*.tar"))
        )
        if not shards:
            print(f"No shards matching {shard_prefix}_*.tar found in {output_dir}")
            return

    dataset = wds.WebDataset(shards)

    print("\nTesting dataset:")
    for i, sample in enumerate(dataset):
        print(f"Sample {i}:")
        for key, value in sample.items():
            if key == "video":
                # Video payloads are large; print only their length.
                print(f"  {key}: <binary data of length {len(value)}>")
            else:
                print(f"  {key}: {value}")

        if i >= 2:
            break
|
|
|
|
if __name__ == "__main__":
    # The namespace must stay in a module-level variable named `args`:
    # read_csv_data looks that name up to obtain the si/ei slice bounds.
    args = parse_args()

    # Options forwarded one-to-one to create_webdataset as keyword arguments.
    forwarded_options = (
        "csv_path",
        "output_dir",
        "num_shards",
        "samples_per_shard",
        "worker_count",
        "shard_prefix",
        "video_extension",
        "debug",
    )
    create_webdataset(**{name: getattr(args, name) for name in forwarded_options})