#!/usr/bin/env python3
# Estimate the storage needed for a WAV-only WebDataset (a set of .tar shards)
# generated from a parquet manifest that lists source audio files and their
# durations.
#
# torchaudio (header probing) and rich (progress bars) are imported lazily in
# the functions that use them: torchaudio pulls in torch, which is slow to
# import and only needed when some durations are missing, and keeping both out
# of module scope lets the pure estimation helpers be imported on their own.
from __future__ import annotations

import argparse
import math
import sys
from collections.abc import Callable
from concurrent.futures import Executor
from concurrent.futures import Future
from concurrent.futures import ProcessPoolExecutor
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import as_completed
from pathlib import Path
from typing import NamedTuple
from typing import TypeVar

import numpy as np
import pandas as pd

_T = TypeVar("_T")

# Size of one tar block: every member header occupies one block and member
# data is zero-padded to a whole number of blocks.
TAR_BLOCK_SIZE = 512
# Every tar archive is terminated by two zero-filled blocks.
TAR_END_OF_ARCHIVE_BYTES = 2 * TAR_BLOCK_SIZE
# Size of a canonical PCM WAV header (RIFF + "fmt " + "data" chunk headers).
# NOTE(review): writers that emit WAVE_FORMAT_EXTENSIBLE or extra chunks
# (LIST, PEAK, ...) produce larger headers; confirm against the actual writer.
WAV_HEADER_BYTES = 44


def parse_args() -> argparse.Namespace:
    """Build and parse the command-line interface."""
    parser = argparse.ArgumentParser(
        description=(
            "Estimate storage required for a WAV-only WebDataset generated from a parquet manifest."
        )
    )
    parser.add_argument(
        "parquet_path",
        type=Path,
        help="Path to parquet file (must include audio path column).",
    )
    parser.add_argument(
        "--audio-column",
        type=str,
        default="file_path",
        help="Column name containing source audio file paths.",
    )
    parser.add_argument(
        "--duration-column",
        type=str,
        default="duration_sec",
        help="Column name containing durations in seconds.",
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=24,
        help="Parallel worker count used to probe missing durations.",
    )
    parser.add_argument(
        "--sample-rate",
        type=int,
        default=16000,
        help="Target WAV sample rate.",
    )
    parser.add_argument(
        "--channels",
        type=int,
        default=1,
        help="Target WAV channel count.",
    )
    parser.add_argument(
        "--bits-per-sample",
        type=int,
        default=16,
        choices=[8, 16, 24, 32],
        help="Target WAV PCM bit depth.",
    )
    parser.add_argument(
        "--shard-size-gb",
        type=float,
        default=1.0,
        help=(
            "Assumed max size per .tar shard in GB (decimal), used for the "
            "shard-count and trailer overhead estimate."
        ),
    )
    parser.add_argument(
        "--no-probe-missing",
        action="store_true",
        help="Fail if duration is missing/invalid instead of probing audio headers.",
    )
    return parser.parse_args()


def validate_args(args: argparse.Namespace) -> None:
    """Reject CLI values that would make the estimate meaningless.

    Raises:
        FileNotFoundError: If the parquet manifest does not exist.
        ValueError: If a numeric option is out of range.
    """
    if not args.parquet_path.exists():
        raise FileNotFoundError(f"Parquet not found: {args.parquet_path}")
    if args.workers < 1:
        raise ValueError("--workers must be >= 1")
    if args.sample_rate < 1:
        raise ValueError("--sample-rate must be >= 1")
    if args.channels < 1:
        raise ValueError("--channels must be >= 1")
    # Written as a positive check so that NaN is rejected as well.
    if not (math.isfinite(args.shard_size_gb) and args.shard_size_gb > 0):
        raise ValueError("--shard-size-gb must be a finite number > 0")


def format_bytes(n_bytes: float) -> str:
    """Render a byte count with binary (1024-based) units, e.g. ``1.50 KiB``.

    Values at or beyond 1024 TiB are always shown in PiB.
    """
    units = ["B", "KiB", "MiB", "GiB", "TiB", "PiB"]
    value = float(n_bytes)
    for unit in units[:-1]:
        if value < 1024.0:
            return f"{value:,.2f} {unit}"
        value /= 1024.0
    return f"{value:,.2f} {units[-1]}"


def estimate_shard_count(total_bytes: int, shard_size_gb: float) -> int:
    """Return how many shards of ``shard_size_gb`` decimal GB hold ``total_bytes``.

    This is a lower bound: real shards cannot split a sample, so they are
    rarely filled exactly. At least one shard is always reported.

    Raises:
        ValueError: If ``shard_size_gb`` is not a positive number.
    """
    if not shard_size_gb > 0:
        raise ValueError(f"shard_size_gb must be > 0, got {shard_size_gb!r}")
    shard_size_bytes = max(1, int(shard_size_gb * 1_000_000_000))
    return max(1, math.ceil(total_bytes / shard_size_bytes))


def parse_duration_value(value: object) -> float:
    """Coerce a manifest duration cell to float seconds; NaN when unparseable.

    Numbers (including numpy scalars) and numeric strings are converted
    directly. Any other object is converted through ``str()`` first. ``None``
    and anything that fails to parse become NaN, so callers can treat them as
    missing.
    """
    if value is None:
        return float("nan")
    if isinstance(value, (int, float, np.integer, np.floating, str)):
        candidate: object = value
    else:
        # e.g. Decimal or pandas NA; pd.NA stringifies to "<NA>" and so yields NaN.
        candidate = str(value)
    try:
        return float(candidate)  # type: ignore[arg-type]
    except (TypeError, ValueError, OverflowError):
        return float("nan")


def probe_duration_sec(audio_path: str) -> float:
    """Read the duration of ``audio_path`` from its header via ``torchaudio.info``.

    Best effort: returns NaN when torchaudio (or its ``info`` function) is
    unavailable, when the file cannot be read, or when the reported sample
    rate is not positive.
    """
    try:
        # Lazy import: torchaudio pulls in torch and is only needed for probing.
        import torchaudio

        info_fn = getattr(torchaudio, "info", None)
        if info_fn is None:
            return float("nan")
        info = info_fn(audio_path)
        if info.sample_rate <= 0:
            return float("nan")
        return float(info.num_frames) / float(info.sample_rate)
    except Exception:
        # Deliberately broad: any unreadable/unsupported file counts as "unknown".
        return float("nan")


def probe_file_size_bytes(audio_path: str) -> int:
    """Return the on-disk size of ``audio_path`` in bytes, or -1 if it cannot be stat'ed."""
    try:
        return Path(audio_path).stat().st_size
    except (OSError, ValueError):
        # OSError: missing file, permission problem, ...;
        # ValueError: e.g. a path containing an embedded NUL byte.
        return -1


def _map_with_progress(
    fn: Callable[[str], _T],
    paths: list[str],
    *,
    workers: int,
    description: str,
    fallback: _T,
    executor_cls: type[Executor],
) -> list[_T]:
    """Apply ``fn`` to every path in parallel while showing a progress bar.

    Results are returned in the same order as ``paths``. A task that raises
    is recorded as ``fallback`` instead of aborting the whole run.
    """
    if not paths:
        return []
    # Lazy import: rich is only needed for the interactive progress display.
    from rich.progress import BarColumn
    from rich.progress import MofNCompleteColumn
    from rich.progress import Progress
    from rich.progress import TaskProgressColumn
    from rich.progress import TextColumn
    from rich.progress import TimeElapsedColumn
    from rich.progress import TimeRemainingColumn

    results: list[_T] = [fallback] * len(paths)
    with executor_cls(max_workers=workers) as executor:
        future_to_idx: dict[Future[_T], int] = {
            executor.submit(fn, path): idx for idx, path in enumerate(paths)
        }
        with Progress(
            TextColumn("[bold cyan]{task.description}"),
            BarColumn(),
            MofNCompleteColumn(),
            TaskProgressColumn(),
            TimeRemainingColumn(),
            TimeElapsedColumn(),
        ) as progress:
            task_id = progress.add_task(description, total=len(paths))
            for future in as_completed(future_to_idx):
                idx = future_to_idx[future]
                try:
                    results[idx] = future.result()
                except Exception:
                    results[idx] = fallback
                progress.advance(task_id, 1)
    return results


def probe_durations_parallel(paths: list[str], workers: int) -> list[float]:
    """Probe header durations for ``paths`` in worker processes (NaN on failure)."""
    # Processes rather than threads: header parsing may hold the GIL.
    return _map_with_progress(
        probe_duration_sec,
        paths,
        workers=workers,
        description="Probing missing durations",
        fallback=float("nan"),
        executor_cls=ProcessPoolExecutor,
    )


def probe_file_sizes_parallel(paths: list[str], workers: int) -> list[int]:
    """Return the on-disk size of every path (-1 when it cannot be stat'ed)."""
    # stat() is I/O-bound, so threads avoid the per-task pickling/IPC cost of
    # a process pool.
    return _map_with_progress(
        probe_file_size_bytes,
        paths,
        workers=workers,
        description="Probing current file sizes",
        fallback=-1,
        executor_cls=ThreadPoolExecutor,
    )


class SizeEstimate(NamedTuple):
    """Byte totals estimated for a WAV-only WebDataset."""

    wav_total_bytes: int  # sum of the individual .wav file sizes
    tar_payload_bytes: int  # tar member headers + block-padded .wav data
    tar_total_bytes: int  # tar_payload_bytes + end-of-archive blocks of every shard
    shard_count: int  # shards needed at the requested shard size


def estimate_wds_sizes(
    durations_sec: np.ndarray,
    sample_rate: int,
    channels: int,
    bits_per_sample: int,
    shard_size_gb: float,
) -> SizeEstimate:
    """Estimate WAV and tar sizes for clips of the given durations.

    Each clip becomes one PCM WAV file (fixed-size header plus frame data)
    stored as one tar member: one header block plus the data padded to whole
    512-byte blocks. The shard count is derived from the tar payload, because
    the shard size limit applies to the .tar files rather than the raw WAVs.
    Each shard then adds its end-of-archive blocks.

    Not included: extra PAX/GNU header blocks for member names longer than
    100 bytes, and any record-size padding a tar writer may append to each
    archive.
    """
    durations = np.asarray(durations_sec, dtype=np.float64)
    bytes_per_frame = channels * (bits_per_sample // 8)
    frames = np.rint(durations * sample_rate).astype(np.int64)
    wav_file_bytes = WAV_HEADER_BYTES + frames * bytes_per_frame
    wav_total_bytes = int(wav_file_bytes.sum())

    padded_wav_bytes = (
        (wav_file_bytes + TAR_BLOCK_SIZE - 1) // TAR_BLOCK_SIZE
    ) * TAR_BLOCK_SIZE
    tar_payload_bytes = TAR_BLOCK_SIZE * int(durations.size) + int(padded_wav_bytes.sum())

    shard_count = estimate_shard_count(tar_payload_bytes, shard_size_gb)
    tar_total_bytes = tar_payload_bytes + TAR_END_OF_ARCHIVE_BYTES * shard_count
    return SizeEstimate(wav_total_bytes, tar_payload_bytes, tar_total_bytes, shard_count)


def _torchaudio_info_available() -> bool:
    """Return True when ``torchaudio.info`` can be used to probe audio headers."""
    try:
        import torchaudio
    except Exception:  # any import-time failure means probing cannot work
        return False
    return callable(getattr(torchaudio, "info", None))


def main() -> None:
    """Read the manifest, fill in missing durations, and print the size report."""
    args = parse_args()
    validate_args(args)

    df = pd.read_parquet(args.parquet_path)
    if args.audio_column not in df.columns:
        raise ValueError(f"Missing audio column '{args.audio_column}' in parquet.")
    if args.duration_column not in df.columns:
        raise ValueError(
            f"Missing duration column '{args.duration_column}' in parquet. "
            "Either add it or adapt the script."
        )

    durations = np.array(
        [parse_duration_value(value) for value in df[args.duration_column].tolist()],
        dtype=np.float64,
    )
    paths = [str(path) for path in df[args.audio_column].tolist()]
    n_rows = len(paths)

    # Current on-disk volume: each distinct file is stat'ed once; -1 = unknown.
    unique_paths = list(dict.fromkeys(paths))
    unique_sizes = np.array(
        probe_file_sizes_parallel(unique_paths, workers=args.workers), dtype=np.int64
    )
    size_by_path = dict(zip(unique_paths, unique_sizes.tolist(), strict=True))
    row_sizes = np.array([size_by_path[path] for path in paths], dtype=np.int64)
    current_row_known_mask = row_sizes >= 0
    current_unique_known_mask = unique_sizes >= 0
    current_rows_total_bytes = int(row_sizes[current_row_known_mask].sum())
    current_unique_total_bytes = int(unique_sizes[current_unique_known_mask].sum())
    current_rows_missing = int((~current_row_known_mask).sum())
    current_unique_missing = int((~current_unique_known_mask).sum())

    invalid_mask = ~np.isfinite(durations) | (durations <= 0.0)
    n_missing = int(invalid_mask.sum())
    if n_missing > 0:
        if args.no_probe_missing:
            raise ValueError(
                f"Found {n_missing} rows with missing/invalid durations and --no-probe-missing was set."
            )
        if not _torchaudio_info_available():
            # Every probe would return NaN; say so instead of failing silently.
            print(
                "Warning: torchaudio.info is unavailable; missing durations "
                "cannot be probed and will remain unresolved.",
                file=sys.stderr,
            )
        else:
            probe_indices = np.where(invalid_mask)[0].tolist()
            # Probe each distinct file once, even if several rows reference it.
            probe_paths = list(dict.fromkeys(paths[i] for i in probe_indices))
            probed = probe_durations_parallel(probe_paths, workers=args.workers)
            duration_by_path = dict(zip(probe_paths, probed, strict=True))
            for i in probe_indices:
                durations[i] = duration_by_path[paths[i]]

    unresolved_mask = ~np.isfinite(durations) | (durations <= 0.0)
    n_unresolved = int(unresolved_mask.sum())
    valid_durations = durations[~unresolved_mask]
    n_valid = int(valid_durations.shape[0])
    if n_valid == 0:
        raise RuntimeError("No valid durations available; cannot estimate size.")

    estimate = estimate_wds_sizes(
        valid_durations,
        sample_rate=args.sample_rate,
        channels=args.channels,
        bits_per_sample=args.bits_per_sample,
        shard_size_gb=args.shard_size_gb,
    )
    total_hours = float(valid_durations.sum()) / 3600.0

    print("=== WebDataset WAV Size Estimate ===")
    print(f"Parquet: {args.parquet_path}")
    print(f"Rows total: {n_rows:,}")
    print(f"Rows valid: {n_valid:,}")
    print(f"Rows missing/invalid duration initially: {n_missing:,}")
    print(f"Rows unresolved after probe: {n_unresolved:,}")
    print()
    print("Current dataset volume (before conversion):")
    print(
        "- Referenced rows total bytes: "
        f"{format_bytes(current_rows_total_bytes)} "
        f"(missing rows: {current_rows_missing:,})"
    )
    print(
        "- Unique source files total bytes: "
        f"{format_bytes(current_unique_total_bytes)} "
        f"(missing files: {current_unique_missing:,})"
    )
    print()
    print(
        f"Target WAV format: {args.sample_rate} Hz, {args.channels} ch, "
        f"PCM {args.bits_per_sample}-bit"
    )
    print(f"Total audio duration: {total_hours:,.2f} h")
    print()
    print(f"Estimated WAV bytes (sum of .wav files): {format_bytes(estimate.wav_total_bytes)}")
    print(f"Estimated WDS TAR bytes (wav-only): {format_bytes(estimate.tar_total_bytes)}")
    print(f"Average WAV per sample: {format_bytes(estimate.wav_total_bytes / n_valid)}")
    print(f"Average WDS per sample: {format_bytes(estimate.tar_total_bytes / n_valid)}")
    print()
    print(f"Assumed shard size: {args.shard_size_gb} GB")
    print(f"Estimated shard count: {estimate.shard_count:,}")
    print()
    # Shard sizes bound the .tar files, so the table uses the tar payload.
    print("Shard count quick table (for TAR payload):")
    for shard_size_gb in (1.0, 2.0, 5.0, 10.0):
        shard_count = estimate_shard_count(estimate.tar_payload_bytes, shard_size_gb)
        print(f"- {shard_size_gb:>4.1f} GB -> {shard_count:,} shards")


if __name__ == "__main__":
    main()