# whats2000's picture
# feat(eda): normalize dataset paths and deduplicate results in summary
# 95969f7
#!/usr/bin/env python3
"""Distributed EDA with chunk-slice architecture for billion-scale datasets.
Each dataset is sliced into small row-ranges. Each slice is a Dask task that
processes a bounded amount of memory (chunk_size * n_vars). Slice results are
merged per-dataset on the scheduler side with O(1) memory.
No quantiles - only non-zero min, max, mean, sparsity, gene-level stats,
and metadata summaries. This handles datasets from 2 GB to 500 GB.
"""
from __future__ import annotations
import argparse
import concurrent.futures
import gc
import hashlib
import json
import math
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import anndata as ad
import dask
import numpy as np
import pandas as pd
import yaml
from dask.distributed import Client, LocalCluster
from scipy import sparse
from tqdm import tqdm
# ---------------------------------------------------------------------------
# Slice result: the only thing returned from each Dask task
# ---------------------------------------------------------------------------
@dataclass
class SliceResult:
    """Mergeable statistics from one row-slice of a dataset.

    Every field is O(1) memory except gene arrays which are O(n_vars).
    Instances are produced by ``process_slice`` and combined by
    ``merge_slice_results``; sums are additive across slices and the min/max
    fields start at sentinel values so they can be folded with ``min``/``max``.

    NOTE(review): an earlier version of this docstring said all fields are
    JSON-serialisable after ``to_dict()``, but no ``to_dict`` method is defined
    on this class in the visible source - confirm whether one is expected.
    """

    # identity
    dataset_path: str = ""
    slice_start: int = 0
    slice_end: int = 0
    # matrix global
    n_obs_slice: int = 0
    n_vars: int = 0
    nnz: int = 0
    x_sum: float = 0.0
    x_sum_sq: float = 0.0
    # cell-level running stats (non-zero counts per cell)
    cell_total_counts_sum: float = 0.0
    # +inf / -inf sentinels: any real value replaces them on the first min()/max()
    cell_total_counts_min: float = math.inf
    cell_total_counts_max: float = -math.inf
    cell_n_genes_sum: int = 0
    # largest signed 64-bit integer, used as the "nothing seen yet" minimum
    cell_n_genes_min: int = 2**63 - 1
    cell_n_genes_max: int = 0
    # gene-level accumulators (length = n_vars, stored as list for serialisation)
    gene_n_cells: list | None = None
    gene_total_counts: list | None = None
    # status
    status: str = "ok"  # set to "failed" by process_slice when an exception occurs
    error: str = ""  # str() of the exception when status is "failed"
    elapsed_sec: float = 0.0
def merge_slice_results(slices: list[SliceResult], n_obs: int, n_vars: int) -> dict:
    """Merge many SliceResults into one per-dataset summary dict.

    Uses O(n_vars) memory for gene arrays, everything else O(1).

    Slices whose ``status`` is not ``"ok"`` are skipped. All averages
    (matrix-level and cell-level) are computed over the rows that were
    actually processed (``n_obs_processed``), so a dataset with failed slices
    does not get an inflated sparsity or a deflated mean. When every slice
    succeeded, ``n_obs_processed`` is the sum of the slice sizes.

    Args:
        slices: Per-slice statistics to combine.
        n_obs: Total number of rows the dataset is expected to have; reported
            as-is under ``"n_obs"``.
        n_vars: Number of columns (genes); length of the gene accumulators.

    Returns:
        A flat dict of matrix-, cell- and gene-level summary statistics.
        Matrix- and cell-level values are ``None`` when no rows were processed.
    """
    nnz_total = 0
    x_sum = 0.0
    x_sum_sq = 0.0
    n_obs_seen = 0
    cell_total_counts_sum = 0.0
    cell_total_counts_min = math.inf
    cell_total_counts_max = -math.inf
    cell_n_genes_sum = 0
    cell_n_genes_min = 2**63 - 1
    cell_n_genes_max = 0
    gene_n_cells = np.zeros(n_vars, dtype=np.int64)
    gene_total_counts = np.zeros(n_vars, dtype=np.float64)

    for s in slices:
        if s.status != "ok":
            continue
        n_obs_seen += s.n_obs_slice
        nnz_total += s.nnz
        x_sum += s.x_sum
        x_sum_sq += s.x_sum_sq
        cell_total_counts_sum += s.cell_total_counts_sum
        cell_total_counts_min = min(cell_total_counts_min, s.cell_total_counts_min)
        cell_total_counts_max = max(cell_total_counts_max, s.cell_total_counts_max)
        cell_n_genes_sum += s.cell_n_genes_sum
        cell_n_genes_min = min(cell_n_genes_min, s.cell_n_genes_min)
        cell_n_genes_max = max(cell_n_genes_max, s.cell_n_genes_max)
        if s.gene_n_cells is not None:
            gene_n_cells += np.asarray(s.gene_n_cells, dtype=np.int64)
        if s.gene_total_counts is not None:
            gene_total_counts += np.asarray(s.gene_total_counts, dtype=np.float64)

    # Denominator is the number of matrix entries actually accumulated above;
    # using n_obs here would count rows from failed slices as all-zero rows.
    total_entries = n_obs_seen * n_vars

    row: dict[str, Any] = {
        "n_obs": n_obs,
        "n_vars": n_vars,
        "n_obs_processed": n_obs_seen,
        "nnz": int(nnz_total),
        "sparsity": float(1.0 - nnz_total / total_entries) if total_entries else None,
        "x_mean": float(x_sum / total_entries) if total_entries else None,
    }
    if total_entries:
        mean = x_sum / total_entries
        # Clamp at zero: E[x^2] - E[x]^2 can go slightly negative in floating point.
        var = max(0.0, x_sum_sq / total_entries - mean ** 2)
        row["x_std"] = float(math.sqrt(var))
    else:
        row["x_std"] = None

    # Cell-level summaries
    if n_obs_seen > 0:
        row["cell_total_counts_min"] = float(cell_total_counts_min)
        row["cell_total_counts_max"] = float(cell_total_counts_max)
        row["cell_total_counts_mean"] = float(cell_total_counts_sum / n_obs_seen)
        row["cell_n_genes_detected_min"] = int(cell_n_genes_min)
        row["cell_n_genes_detected_max"] = int(cell_n_genes_max)
        row["cell_n_genes_detected_mean"] = float(cell_n_genes_sum / n_obs_seen)
    else:
        for key in (
            "cell_total_counts_min",
            "cell_total_counts_max",
            "cell_total_counts_mean",
            "cell_n_genes_detected_min",
            "cell_n_genes_detected_max",
            "cell_n_genes_detected_mean",
        ):
            row[key] = None

    # Gene-level summaries (restricted to genes seen in at least one cell)
    genes_detected = int(np.count_nonzero(gene_n_cells))
    row["genes_detected_in_any_cell"] = genes_detected
    row["genes_detected_in_any_cell_pct"] = float(genes_detected / n_vars * 100) if n_vars else 0.0
    if genes_detected > 0:
        mask = gene_n_cells > 0
        row["gene_n_cells_min"] = int(gene_n_cells[mask].min())
        row["gene_n_cells_max"] = int(gene_n_cells[mask].max())
        row["gene_n_cells_mean"] = float(gene_n_cells[mask].mean())
        row["gene_total_counts_min"] = float(gene_total_counts[mask].min())
        row["gene_total_counts_max"] = float(gene_total_counts[mask].max())
        row["gene_total_counts_mean"] = float(gene_total_counts[mask].mean())
    else:
        for k in ("gene_n_cells_min", "gene_n_cells_max", "gene_n_cells_mean",
                  "gene_total_counts_min", "gene_total_counts_max", "gene_total_counts_mean"):
            row[k] = 0
    return row
# ---------------------------------------------------------------------------
# Simple worker function for small datasets (no Dask overhead)
# ---------------------------------------------------------------------------
def process_dataset_simple(
    path_str: str,
    n_obs: int,
    n_vars: int,
    chunk_size: int,
    max_meta_cols: int,
    max_categories: int,
) -> dict:
    """Process entire small dataset in one worker (no slicing, no Dask).

    Opens the file in backed read mode, streams ``X`` in row chunks of
    ``chunk_size`` and accumulates matrix-, cell- and gene-level statistics,
    then adds obs/var metadata summaries and schemas.

    Args:
        path_str: Path of the ``.h5ad`` file.
        n_obs: Number of rows to read (taken from the caller, not the file).
        n_vars: Number of columns; length of the gene accumulators.
        chunk_size: Rows read per iteration.
        max_meta_cols: Forwarded to ``summarize_metadata`` as ``max_cols``.
        max_categories: Forwarded to ``summarize_metadata``.

    Returns:
        A result row. ``status`` is ``"ok"`` on success; on any exception it is
        ``"failed"`` with ``error`` set to the exception text. Exceptions are
        never propagated. ``elapsed_sec`` is always present.
    """
    t0 = time.time()
    path = Path(path_str)
    row: dict[str, Any] = {
        "dataset_path": path_str,
        "dataset_file": path.name,
        "n_obs": n_obs,
        "n_vars": n_vars,
    }
    adata = None
    try:
        adata = ad.read_h5ad(path, backed="r")
        total_entries = n_obs * n_vars
        nnz_total = 0
        x_sum = 0.0
        x_sum_sq = 0.0
        # Cell-level accumulators
        cell_total_counts_sum = 0.0
        cell_total_counts_min = math.inf
        cell_total_counts_max = -math.inf
        cell_n_genes_sum = 0
        cell_n_genes_min = 2**63 - 1
        cell_n_genes_max = 0
        # Gene-level accumulators
        gene_n_cells = np.zeros(n_vars, dtype=np.int64)
        gene_total_counts = np.zeros(n_vars, dtype=np.float64)

        # Process in chunks
        for start in range(0, n_obs, chunk_size):
            end = min(start + chunk_size, n_obs)
            chunk = adata.X[start:end, :]
            if sparse.issparse(chunk):
                csr = chunk.tocsr() if not sparse.isspmatrix_csr(chunk) else chunk
                data = csr.data.astype(np.float64, copy=False)
                nnz_total += int(csr.nnz)
                x_sum += float(data.sum())
                x_sum_sq += float(np.square(data).sum())
                # Cell stats: row sums and stored entries per row
                cell_counts = np.asarray(csr.sum(axis=1)).ravel()
                cell_genes = np.diff(csr.indptr).astype(np.int64)
                # Gene stats: stored entries per column and column sums
                csc = csr.tocsc()
                gene_n_cells += np.diff(csc.indptr).astype(np.int64)
                gene_total_counts += np.asarray(csc.sum(axis=0)).ravel()
                del csr, csc, data
            else:
                arr = np.asarray(chunk, dtype=np.float64)
                nz = arr != 0
                nnz_total += int(nz.sum())
                x_sum += float(arr.sum())
                x_sum_sq += float(np.square(arr).sum())
                # Cell stats
                cell_counts = arr.sum(axis=1)
                cell_genes = nz.sum(axis=1).astype(np.int64)
                # Gene stats
                gene_n_cells += nz.sum(axis=0).astype(np.int64)
                gene_total_counts += arr.sum(axis=0)
                del arr, nz

            # Fold this chunk's per-cell values into the running stats
            cell_total_counts_sum += float(cell_counts.sum())
            cell_total_counts_min = min(cell_total_counts_min, float(cell_counts.min()))
            cell_total_counts_max = max(cell_total_counts_max, float(cell_counts.max()))
            cell_n_genes_sum += int(cell_genes.sum())
            cell_n_genes_min = min(cell_n_genes_min, int(cell_genes.min()))
            cell_n_genes_max = max(cell_n_genes_max, int(cell_genes.max()))
            del chunk, cell_counts, cell_genes
            gc.collect()

        # Matrix-level stats
        row["nnz"] = int(nnz_total)
        row["sparsity"] = float(1.0 - nnz_total / total_entries) if total_entries else None
        row["x_mean"] = float(x_sum / total_entries) if total_entries else None
        if total_entries:
            # Clamp at zero: E[x^2] - E[x]^2 can go slightly negative in floating point.
            var = max(0.0, x_sum_sq / total_entries - (x_sum / total_entries) ** 2)
            row["x_std"] = float(math.sqrt(var))
        else:
            row["x_std"] = None

        # Cell-level stats
        if n_obs > 0:
            row["cell_total_counts_min"] = float(cell_total_counts_min)
            row["cell_total_counts_max"] = float(cell_total_counts_max)
            row["cell_total_counts_mean"] = float(cell_total_counts_sum / n_obs)
            row["cell_n_genes_detected_min"] = int(cell_n_genes_min)
            row["cell_n_genes_detected_max"] = int(cell_n_genes_max)
            row["cell_n_genes_detected_mean"] = float(cell_n_genes_sum / n_obs)
        else:
            row["cell_total_counts_min"] = None
            row["cell_total_counts_max"] = None
            row["cell_total_counts_mean"] = None
            row["cell_n_genes_detected_min"] = None
            row["cell_n_genes_detected_max"] = None
            row["cell_n_genes_detected_mean"] = None

        # Gene-level stats (restricted to genes seen in at least one cell)
        genes_detected = int(np.count_nonzero(gene_n_cells))
        row["genes_detected_in_any_cell"] = genes_detected
        row["genes_detected_in_any_cell_pct"] = float(genes_detected / n_vars * 100) if n_vars else 0.0
        if genes_detected > 0:
            mask = gene_n_cells > 0
            row["gene_n_cells_min"] = int(gene_n_cells[mask].min())
            row["gene_n_cells_max"] = int(gene_n_cells[mask].max())
            row["gene_n_cells_mean"] = float(gene_n_cells[mask].mean())
            row["gene_total_counts_min"] = float(gene_total_counts[mask].min())
            row["gene_total_counts_max"] = float(gene_total_counts[mask].max())
            row["gene_total_counts_mean"] = float(gene_total_counts[mask].mean())
        else:
            for k in ("gene_n_cells_min", "gene_n_cells_max", "gene_n_cells_mean",
                      "gene_total_counts_min", "gene_total_counts_max", "gene_total_counts_mean"):
                row[k] = 0

        # Metadata
        row["obs_columns"] = int(len(adata.obs.columns))
        row["var_columns"] = int(len(adata.var.columns))
        row["metadata_obs_summary"] = summarize_metadata(
            adata.obs, max_cols=max_meta_cols, max_categories=max_categories
        )
        row["metadata_var_summary"] = summarize_metadata(
            adata.var, max_cols=max_meta_cols, max_categories=max_categories
        )
        row["obs_schema"] = extract_schema(adata.obs)
        row["var_schema"] = extract_schema(adata.var)

        row["status"] = "ok"
        row["n_slices_total"] = 1
        row["n_slices_ok"] = 1
        row["n_slices_failed"] = 0
    except Exception as exc:
        row["status"] = "failed"
        row["error"] = str(exc)
    finally:
        # Close the backed file on every path, including failures, so a bad
        # dataset does not leave an open handle behind in a long-lived worker.
        if adata is not None:
            try:
                if hasattr(adata, "file") and adata.file is not None:
                    adata.file.close()
            except Exception:
                pass  # best-effort close; the result row is already decided
        del adata
        gc.collect()
    row["elapsed_sec"] = round(time.time() - t0, 2)
    return row
# ---------------------------------------------------------------------------
# Core worker function: process ONE slice of ONE dataset (Dask)
# ---------------------------------------------------------------------------
def process_slice(
    path_str: str,
    obs_start: int,
    obs_end: int,
    chunk_size: int,
) -> SliceResult:
    """Process rows [obs_start, obs_end) of a dataset.

    Memory usage bounded by: chunk_size * n_vars * ~12 bytes * 3x overhead.

    The file is opened in backed read mode and the row range is read in
    sub-chunks of ``chunk_size`` rows. Exceptions are not propagated: on any
    error the returned result has ``status="failed"`` and ``error`` set to the
    exception text. The file handle is closed on both paths.

    Args:
        path_str: Path of the ``.h5ad`` file.
        obs_start: First row of the slice (inclusive).
        obs_end: End row of the slice (exclusive).
        chunk_size: Rows read per iteration within the slice.

    Returns:
        A ``SliceResult`` carrying the slice's accumulated statistics.
    """
    t0 = time.time()
    path = Path(path_str)
    result = SliceResult(dataset_path=path_str, slice_start=obs_start, slice_end=obs_end)
    adata = None
    try:
        adata = ad.read_h5ad(path, backed="r")
        n_vars = int(adata.n_vars)
        result.n_vars = n_vars
        result.n_obs_slice = obs_end - obs_start
        # Gene-level accumulators for this slice
        gene_n_cells = np.zeros(n_vars, dtype=np.int64)
        gene_total_counts = np.zeros(n_vars, dtype=np.float64)

        # Process in sub-chunks within this slice
        for start in range(obs_start, obs_end, chunk_size):
            end = min(start + chunk_size, obs_end)
            chunk = adata.X[start:end, :]
            if sparse.issparse(chunk):
                csr = chunk.tocsr() if not sparse.isspmatrix_csr(chunk) else chunk
                data = csr.data.astype(np.float64, copy=False)
                result.nnz += int(csr.nnz)
                result.x_sum += float(data.sum())
                result.x_sum_sq += float(np.square(data).sum())
                # Cell stats: row sums and stored entries per row
                cell_counts = np.asarray(csr.sum(axis=1)).ravel()
                cell_genes = np.diff(csr.indptr).astype(np.int64)
                # Gene stats (optimized: use bincount instead of CSC conversion)
                # Accumulate counts directly from CSR indices/data
                gene_total_counts += np.bincount(
                    csr.indices,
                    weights=data,
                    minlength=n_vars,
                )
                gene_n_cells += np.bincount(
                    csr.indices,
                    minlength=n_vars,
                )
                del csr, data
            else:
                arr = np.asarray(chunk, dtype=np.float64)
                nz = arr != 0
                result.nnz += int(nz.sum())
                result.x_sum += float(arr.sum())
                result.x_sum_sq += float(np.square(arr).sum())
                # Cell stats
                cell_counts = arr.sum(axis=1)
                cell_genes = nz.sum(axis=1).astype(np.int64)
                # Gene stats
                gene_n_cells += nz.sum(axis=0).astype(np.int64)
                gene_total_counts += arr.sum(axis=0)
                del arr, nz

            # Update cell-level running stats
            result.cell_total_counts_sum += float(cell_counts.sum())
            result.cell_total_counts_min = min(result.cell_total_counts_min, float(cell_counts.min()))
            result.cell_total_counts_max = max(result.cell_total_counts_max, float(cell_counts.max()))
            result.cell_n_genes_sum += int(cell_genes.sum())
            result.cell_n_genes_min = min(result.cell_n_genes_min, int(cell_genes.min()))
            result.cell_n_genes_max = max(result.cell_n_genes_max, int(cell_genes.max()))
            del chunk, cell_counts, cell_genes
            gc.collect()

        # Store gene arrays as lists for serialisation
        result.gene_n_cells = gene_n_cells.tolist()
        result.gene_total_counts = gene_total_counts.tolist()
        del gene_n_cells, gene_total_counts
    except Exception as exc:
        result.status = "failed"
        result.error = str(exc)
    finally:
        # Close the backed file on every path; a failed slice is retried by the
        # caller, so leaking the handle here would accumulate open files.
        if adata is not None:
            try:
                if hasattr(adata, "file") and adata.file is not None:
                    adata.file.close()
            except Exception:
                pass  # best-effort close; the slice status is already decided
        del adata
        gc.collect()
    result.elapsed_sec = round(time.time() - t0, 2)
    return result
# ---------------------------------------------------------------------------
# Metadata helpers (run on scheduler, not workers)
# ---------------------------------------------------------------------------
def safe_name(path: Path) -> str:
    """Build a filesystem-friendly name for ``path``.

    The result is the path's stem (spaces replaced by underscores, cut to at
    most 80 characters) followed by ``_`` and the first 10 hex digits of the
    MD5 of the full path string, so equal stems in different directories
    still map to different names.
    """
    encoded_path = str(path).encode("utf-8")
    short_hash = hashlib.md5(encoded_path, usedforsecurity=False).hexdigest()[:10]
    # Slicing is a no-op for stems of 80 characters or fewer.
    base = path.stem.replace(" ", "_")[:80]
    return f"{base}_{short_hash}"
def summarize_metadata(df: pd.DataFrame, max_cols: int, max_categories: int) -> dict[str, dict]:
    """Summarize DataFrame metadata with top categories.

    At most ``max_cols`` columns are summarised. A fixed list of preferred
    column names is considered first (in that order), followed by the
    remaining columns in DataFrame order, until the limit is reached.

    Args:
        df: Metadata table to summarise.
        max_cols: Hard upper bound on the number of columns in the result;
            values <= 0 yield an empty result.
        max_categories: Number of most frequent values reported per
            categorical / string column.

    Returns:
        Mapping of column name to a summary with ``dtype`` and
        ``missing_fraction``; categorical, string and object columns also get
        ``n_unique`` and ``top_values``. Empty when ``df.empty`` is true.
    """
    if df.empty:
        return {}

    limit = max(0, max_cols)
    preferred = ["cell_type", "assay", "tissue", "disease", "sex", "donor_id"]
    # Cap the preferred columns too, so the result never exceeds the limit.
    selected: list[str] = [c for c in preferred if c in df.columns][:limit]
    for col in df.columns:
        # Check the limit before appending so max_cols is never overshot.
        if len(selected) >= limit:
            break
        if col not in selected:
            selected.append(col)

    out: dict[str, dict] = {}
    n_rows = max(1, len(df))
    for col in selected:
        s = df[col]
        summary: dict[str, Any] = {
            "dtype": str(s.dtype),
            "missing_fraction": float(s.isna().sum()) / n_rows,
        }
        if isinstance(s.dtype, pd.CategoricalDtype):
            summary["n_unique"] = int(len(s.cat.categories))
            vc = s.value_counts(dropna=False).head(max_categories)
            summary["top_values"] = {str(k): int(v) for k, v in vc.items()}
        elif pd.api.types.is_string_dtype(s.dtype) or s.dtype == object:
            # Missing values are dropped before stringifying so they do not
            # show up as a literal "nan"/"None" category.
            s_str = s.dropna().astype(str)
            summary["n_unique"] = int(s_str.nunique())
            vc = s_str.value_counts(dropna=False).head(max_categories)
            summary["top_values"] = {str(k): int(v) for k, v in vc.items()}
        out[col] = summary
    return out
def extract_schema(df: pd.DataFrame) -> dict[str, object]:
    """Describe the columns of ``df``.

    Returns a dict with the column count, the column names as strings, and a
    mapping from each column name to the string form of its dtype.
    """
    column_names = [str(label) for label in df.columns]
    dtype_by_column = {}
    for label in df.columns:
        dtype_by_column[str(label)] = str(df[label].dtype)
    schema: dict[str, object] = {
        "n_columns": int(len(df.columns)),
        "columns": column_names,
        "dtypes": dtype_by_column,
    }
    return schema
def extract_metadata_on_scheduler(
    path: Path,
    max_meta_cols: int,
    max_categories: int,
) -> dict:
    """Extract obs/var metadata. Runs on scheduler (lightweight, no X access).

    Args:
        path: Path of the ``.h5ad`` file, opened in backed read mode.
        max_meta_cols: Forwarded to ``summarize_metadata`` as ``max_cols``.
        max_categories: Forwarded to ``summarize_metadata``.

    Returns:
        A dict with column counts, obs/var metadata summaries and schemas.
        On any exception, ``{"metadata_error": <message>}`` is returned
        instead; exceptions are never propagated. The file handle is closed
        on both paths.
    """
    adata = None
    try:
        adata = ad.read_h5ad(path, backed="r")
        return {
            "obs_columns": int(len(adata.obs.columns)),
            "var_columns": int(len(adata.var.columns)),
            "metadata_obs_summary": summarize_metadata(
                adata.obs, max_cols=max_meta_cols, max_categories=max_categories
            ),
            "metadata_var_summary": summarize_metadata(
                adata.var, max_cols=max_meta_cols, max_categories=max_categories
            ),
            "obs_schema": extract_schema(adata.obs),
            "var_schema": extract_schema(adata.var),
        }
    except Exception as exc:
        return {"metadata_error": str(exc)}
    finally:
        # Runs after either return above, so the handle is released even when
        # summarising fails part-way through.
        if adata is not None:
            try:
                if hasattr(adata, "file") and adata.file is not None:
                    adata.file.close()
            except Exception:
                pass  # best-effort close
        del adata
        gc.collect()
# ---------------------------------------------------------------------------
# Dask configuration
# ---------------------------------------------------------------------------
def configure_dask_for_hpc() -> None:
    """Apply this script's Dask settings via ``dask.config.set``.

    The settings cover worker memory thresholds, worker process options,
    scheduler failure/work-stealing behaviour, comm timeouts and admin
    tick/log options. They are grouped below by the config section they
    belong to and applied in a single call.
    """
    worker_settings = {
        # Fractions of the worker memory limit
        "distributed.worker.memory.target": 0.60,
        "distributed.worker.memory.spill": 0.70,
        "distributed.worker.memory.pause": 0.80,
        "distributed.worker.memory.terminate": 0.95,
        "distributed.worker.daemon": False,
        "distributed.worker.use-file-locking": False,
    }
    scheduler_settings = {
        "distributed.scheduler.allowed-failures": 10,
        "distributed.scheduler.work-stealing": True,
        "distributed.scheduler.work-stealing-interval": "100ms",
    }
    comm_and_admin_settings = {
        "distributed.comm.timeouts.connect": "120s",
        "distributed.comm.timeouts.tcp": "120s",
        "distributed.admin.tick.interval": "2s",
        "distributed.admin.log-length": 500,
    }
    dask.config.set({**worker_settings, **scheduler_settings, **comm_and_admin_settings})
# ---------------------------------------------------------------------------
# Config / metadata helpers
# ---------------------------------------------------------------------------
def load_config(config_path: Path) -> dict:
    """Parse the YAML file at ``config_path`` with ``yaml.safe_load``.

    The file is read as UTF-8 explicitly rather than with the platform's
    locale encoding, so non-ASCII values parse the same on every machine.

    Returns:
        Whatever ``yaml.safe_load`` produces for the document (the callers in
        this file index it as a nested mapping).

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the document is not valid YAML.
    """
    with open(config_path, encoding="utf-8") as f:
        return yaml.safe_load(f)
def load_enhanced_metadata(cache_path: Path) -> pd.DataFrame:
    """Read the enhanced-metadata parquet cache into a DataFrame.

    Raises:
        FileNotFoundError: If ``cache_path`` does not exist; the message
            includes the command used to build the cache.
    """
    if cache_path.exists():
        return pd.read_parquet(cache_path)

    message = (
        f"Enhanced metadata cache not found: {cache_path}\n"
        "Run: uv run python scripts/build_metadata_cache.py --config <config.yaml>"
    )
    raise FileNotFoundError(message)
def get_datasets_for_shard(
    metadata_df: pd.DataFrame,
    config: dict,
    num_shards: int,
    shard_index: int,
) -> list[dict]:
    """Select the datasets this shard should process.

    With more than one shard, rows are ordered by ``total_entries``
    (largest first) and dealt round-robin; this shard keeps the rows whose
    position modulo ``num_shards`` equals ``shard_index``. The kept rows are
    then restricted to those whose ``status`` starts with ``"ok"`` and whose
    ``total_entries`` does not exceed
    ``config["dataset_thresholds"]["max_entries"]``.

    Returns:
        One dict per dataset with keys ``dataset_path`` (resolved to an
        absolute path string), ``n_obs``, ``n_vars``, ``total_entries`` and
        ``size_category`` (``"large"`` when the column is absent).
    """
    frame = metadata_df
    if num_shards > 1:
        frame = frame.sort_values("total_entries", ascending=False).reset_index(drop=True)
        frame = frame[frame.index % num_shards == shard_index]

    status_ok = frame["status"].str.startswith("ok", na=False)
    entry_limit = config["dataset_thresholds"]["max_entries"]
    within_limit = frame["total_entries"] <= entry_limit
    frame = frame[status_ok & within_limit].copy()

    return [
        {
            # Resolve to an absolute path so relative/absolute spellings of
            # the same file produce the same string.
            "dataset_path": str(Path(str(record["dataset_path"])).resolve()),
            "n_obs": int(record.get("n_obs", 0)),
            "n_vars": int(record.get("n_vars", 0)),
            "total_entries": int(record.get("total_entries", 0)),
            "size_category": str(record.get("size_category", "large")),
        }
        for _, record in frame.iterrows()
    ]
# ---------------------------------------------------------------------------
# Main processing pipeline
# ---------------------------------------------------------------------------
def create_slice_tasks(
    dataset: dict,
    obs_slice_size: int,
    small_dataset_threshold: int,
) -> list[tuple[str, int, int]]:
    """Split a dataset into ``(path, start, end)`` row-range tasks.

    - A dataset with no rows yields the single empty task ``(path, 0, 0)``.
    - A dataset whose entry count is below ``small_dataset_threshold`` yields
      one task covering all rows.
    - Anything else is cut into consecutive ranges of at most
      ``obs_slice_size`` rows; ``end`` is exclusive.

    The entry count is ``dataset["total_entries"]`` when present, otherwise
    ``n_obs * n_vars`` (with ``n_vars`` defaulting to 0).
    """
    ds_path = dataset["dataset_path"]
    row_count = dataset["n_obs"]
    entry_count = dataset.get("total_entries", row_count * dataset.get("n_vars", 0))

    if row_count <= 0:
        return [(ds_path, 0, 0)]
    if entry_count < small_dataset_threshold:
        return [(ds_path, 0, row_count)]

    return [
        (ds_path, lo, min(lo + obs_slice_size, row_count))
        for lo in range(0, row_count, obs_slice_size)
    ]
def _failure_row(ds_path: str, error: str, elapsed: float | None = None) -> dict:
    """Build the failure record for a dataset that produced no usable result."""
    row: dict[str, Any] = {
        "dataset_path": ds_path,
        "dataset_file": Path(ds_path).name,
        "status": "failed",
        "error": error,
    }
    if elapsed is not None:
        row["elapsed_sec"] = elapsed
    return row


def _attach_file_size(row: dict, ds_path: str) -> None:
    """Best-effort: add the on-disk size in GiB to ``row`` as ``file_size_gib``."""
    try:
        row["file_size_gib"] = round(Path(ds_path).stat().st_size / (1024 ** 3), 4)
    except (OSError, ValueError):
        pass  # size is informational only; a missing file is reported elsewhere


def _save_row_json(row: dict, ds_path: str, per_dataset_dir: Path) -> None:
    """Write ``row`` as JSON under ``per_dataset_dir``; record ``save_error`` on failure."""
    try:
        payload_name = safe_name(Path(ds_path)) + ".json"
        (per_dataset_dir / payload_name).write_text(json.dumps(row, indent=2))
    except Exception as exc:
        row["save_error"] = str(exc)


def _run_slices_on_dask(
    client: Client,
    tasks: list[tuple[str, int, int]],
    chunk_size: int,
    ok_results: list[SliceResult],
    progress: Any = None,
    timeout: int = 3600,
) -> list[tuple[str, int, int]]:
    """Run slice tasks on Dask, append successes to ``ok_results``, return failed tasks."""
    futures = client.map(
        lambda t: process_slice(t[0], t[1], t[2], chunk_size),
        tasks,
        pure=False,
    )
    failed: list[tuple[str, int, int]] = []
    for task, future in zip(tasks, futures):
        try:
            sr = future.result(timeout=timeout)
            if sr.status == "ok":
                ok_results.append(sr)
            else:
                failed.append(task)
        except Exception:
            # Worker death, timeout, serialisation error: treat as a failed slice.
            failed.append(task)
        finally:
            if progress is not None:
                progress.set_postfix(ok=len(ok_results), fail=len(failed))
                progress.update(1)
    return failed


def _finalize_sliced_dataset(
    slice_results: list[SliceResult],
    dataset: dict,
    n_slices: int,
    fail_count: int,
    elapsed: float,
    per_dataset_dir: Path,
    max_meta_cols: int,
    max_categories: int,
    processing_mode: str | None = None,
) -> dict:
    """Merge slice results, attach file size and metadata, save and return the row."""
    ds_path = dataset["dataset_path"]
    row = merge_slice_results(slice_results, dataset["n_obs"], dataset["n_vars"])
    row["dataset_path"] = ds_path
    row["dataset_file"] = Path(ds_path).name
    row["n_slices_total"] = n_slices
    row["n_slices_ok"] = len(slice_results)
    row["n_slices_failed"] = fail_count
    if processing_mode is not None:
        row["processing_mode"] = processing_mode
    _attach_file_size(row, ds_path)
    # Extract metadata (lightweight, in the calling process)
    row.update(extract_metadata_on_scheduler(Path(ds_path), max_meta_cols, max_categories))
    row["status"] = "ok" if fail_count == 0 else "partial"
    row["elapsed_sec"] = elapsed
    _save_row_json(row, ds_path, per_dataset_dir)
    return row


def process_all_datasets(
    datasets: list[dict],
    config: dict,
    per_dataset_dir: Path,
    client: Client | None,
    max_retries: int = 3,
) -> tuple[list[dict], list[dict]]:
    """Process all datasets: small ones with ProcessPoolExecutor, large ones with Dask.

    Datasets are split into three groups and handled in order:

    1. Small (``total_entries`` below ``config["dataset_thresholds"]["small"]``):
       processed whole by ``process_dataset_simple`` in a process pool that is
       recreated for every batch; the worker count is halved when throughput
       falls below a fraction of the measured baseline.
    2. Medium/large (not small, ``size_category`` other than ``"xlarge"``):
       sliced and run through Dask, with up to ``max_retries`` retry rounds
       and a final one-slice-at-a-time "emergency" pass with a smaller chunk.
    3. XLarge: slices are processed one at a time in this process. If a Dask
       client was supplied it is closed before this phase starts.

    Every result row is also written as JSON into ``per_dataset_dir``.

    Args:
        datasets: Dataset descriptors as produced by ``get_datasets_for_shard``.
        config: Parsed configuration (resources, slicing, thresholds, metadata,
            and optionally ``strategy``).
        per_dataset_dir: Directory receiving one JSON file per dataset.
        client: Dask client for phase 2. When ``None``, medium/large datasets
            are reported as failures instead of being silently skipped.
        max_retries: Retry rounds for failed slices in phase 2.

    Returns:
        ``(successes, failures)``; rows with status ``"ok"`` or ``"partial"``
        are successes.
    """
    base_chunk_size = config["resources"]["chunk_size"]
    base_obs_slice_size = config["slicing"].get("obs_slice_size", 75_000)
    obs_slice_size_xlarge = config["slicing"].get("obs_slice_size_xlarge", 150_000)
    small_threshold = config["dataset_thresholds"]["small"]
    max_meta_cols = config["metadata"]["max_meta_cols"]
    max_categories = config["metadata"]["max_categories"]
    max_workers_base = config["resources"]["max_workers"]
    strategies = config.get("strategy") or {}

    # Helper function to get adjusted parameters based on size category
    def get_dataset_params(dataset):
        size_cat = dataset.get("size_category", "large")
        # Fall back to the "large" strategy, then to defaults, without
        # requiring a "strategy" section to exist in the config.
        strategy = strategies.get(size_cat, strategies.get("large", {}))
        chunk_mult = strategy.get("chunk_size_multiplier", 1.0)
        chunk_size = int(base_chunk_size * chunk_mult)
        # xlarge datasets use their own configured slice size
        if size_cat == "xlarge":
            obs_slice = obs_slice_size_xlarge
        else:
            obs_slice = base_obs_slice_size
        return chunk_size, obs_slice, size_cat

    successes: list[dict] = []
    failures: list[dict] = []

    # Categorize datasets: small, dask-ready (medium/large), xlarge (skip Dask)
    small_datasets = [d for d in datasets if d.get("total_entries", 0) < small_threshold]
    non_small = [d for d in datasets if d.get("total_entries", 0) >= small_threshold]
    # Split non-small into Dask-compatible and xlarge (which skip Dask due to failures)
    dask_datasets = [d for d in non_small if d.get("size_category", "large") != "xlarge"]
    xlarge_datasets = [d for d in non_small if d.get("size_category", "large") == "xlarge"]
    small_datasets.sort(key=lambda d: d["total_entries"])
    dask_datasets.sort(key=lambda d: d["total_entries"])
    xlarge_datasets.sort(key=lambda d: d["total_entries"])

    small_count = len(small_datasets)
    dask_count = len(dask_datasets)
    xlarge_count = len(xlarge_datasets)
    total_datasets = small_count + dask_count + xlarge_count

    print(f"\n{'=' * 80}")
    print(f"Processing {total_datasets} datasets")
    print(f" Small datasets (ProcessPoolExecutor): {small_count}")
    print(f" Medium/Large (Dask + slicing): {dask_count}")
    print(f" XLarge (Direct, skip Dask): {xlarge_count}")
    print(f"Slice size: {base_obs_slice_size:,} rows (medium/large), {obs_slice_size_xlarge:,} rows (xlarge)")
    print(f"Small threshold: {small_threshold:,} entries")
    print(f"Base chunk size: {base_chunk_size:,} rows (adjusted per dataset size)")
    print(f"{'=' * 80}\n")

    # ========================================================================
    # Phase 1: Process small datasets with ProcessPoolExecutor (batched)
    # ========================================================================
    if small_count > 0:
        print(f"{'='*80}")
        print(f"PHASE 1: Small datasets ({small_count}) - ProcessPoolExecutor")
        print(f"{'='*80}\n")

        # Adaptive worker management
        current_workers = max_workers_base
        min_workers_ratio = config["resources"].get("min_workers_ratio", 0.25)
        min_workers = max(1, int(max_workers_base * min_workers_ratio))
        batch_size = max(30, min(100, small_count // 4))

        # Throughput monitoring
        check_interval = 50
        baseline_throughput = None
        slowdown_threshold = config["resources"].get("slowdown_threshold", 0.5)
        last_check_idx = 0
        batch_start_time = time.time()

        print(f"Workers: {current_workers} (adaptive: {min_workers}-{max_workers_base})")
        print(f"Batch size: {batch_size} (recycled between batches)\n")

        with tqdm(total=small_count, desc="Small datasets", position=0) as pbar:
            for batch_start in range(0, small_count, batch_size):
                batch_end = min(batch_start + batch_size, small_count)
                batch = small_datasets[batch_start:batch_end]

                # Check throughput and adjust workers
                processed = len(successes) + len(failures)
                if processed >= last_check_idx + check_interval and processed > check_interval:
                    elapsed = time.time() - batch_start_time
                    current_throughput = processed / elapsed if elapsed > 0 else 0
                    if baseline_throughput is None and processed >= check_interval * 2:
                        baseline_throughput = current_throughput
                        tqdm.write(f"Baseline: {baseline_throughput:.2f} ds/sec")
                    if baseline_throughput and current_throughput < baseline_throughput * slowdown_threshold:
                        if current_workers > min_workers:
                            old_workers = current_workers
                            current_workers = max(min_workers, current_workers // 2)
                            tqdm.write(f"⚠️ Slowdown detected. Workers: {old_workers} → {current_workers}")
                            # Re-measure the baseline at the new worker count
                            baseline_throughput = None
                    last_check_idx = processed

                # Process batch in a fresh pool so worker memory is released between batches
                executor = concurrent.futures.ProcessPoolExecutor(max_workers=current_workers)
                futures = {}
                try:
                    for dataset in batch:
                        # Get chunk size for this dataset
                        chunk_size, _, _ = get_dataset_params(dataset)
                        future = executor.submit(
                            process_dataset_simple,
                            dataset["dataset_path"],
                            dataset["n_obs"],
                            dataset["n_vars"],
                            chunk_size,
                            max_meta_cols,
                            max_categories,
                        )
                        futures[future] = dataset

                    for future in concurrent.futures.as_completed(futures):
                        dataset = futures[future]
                        ds_path = dataset["dataset_path"]
                        ds_name = Path(ds_path).name
                        try:
                            row = future.result(timeout=3600)
                            _attach_file_size(row, ds_path)
                            _save_row_json(row, ds_path, per_dataset_dir)
                            if row.get("status") == "ok":
                                successes.append(row)
                                elapsed = row.get("elapsed_sec", "?")
                                tqdm.write(f" [{len(successes)}/{total_datasets}] ✓ {ds_name[:50]} | {elapsed}s")
                            else:
                                failures.append(row)
                                error = row.get("error", "Unknown")[:60]
                                tqdm.write(f" [{len(successes) + len(failures)}/{total_datasets}] ✗ {ds_name[:50]} | {error}")
                        except concurrent.futures.TimeoutError:
                            failures.append(_failure_row(ds_path, "Timeout"))
                            tqdm.write(f" [{len(successes) + len(failures)}/{total_datasets}] ✗ {ds_name[:50]} | Timeout")
                        except Exception as exc:
                            failures.append(_failure_row(ds_path, str(exc)))
                            tqdm.write(f" [{len(successes) + len(failures)}/{total_datasets}] ✗ {ds_name[:50]} | {exc}")
                        finally:
                            pbar.update(1)
                finally:
                    executor.shutdown(wait=True)
                    gc.collect()
                    time.sleep(1)

        # Only phase 1 has run so far, so both lists hold phase-1 rows only.
        print(f"\nPhase 1 complete: {len(successes)} ok, {len(failures)} failed\n")

    # ========================================================================
    # Phase 2: Process medium/large datasets with Dask
    # ========================================================================
    if dask_count > 0 and client:
        print(f"{'='*80}")
        print(f"PHASE 2: Medium/Large datasets ({dask_count}) - Dask + slicing")
        print(f"{'='*80}\n")

        with tqdm(
            total=dask_count,
            desc="Med/Large datasets",
            position=0,
            leave=True,
            ncols=100,
        ) as dataset_pbar:
            for ds_local_idx, dataset in enumerate(dask_datasets):
                dataset_idx = small_count + ds_local_idx
                ds_path = dataset["dataset_path"]
                ds_name = Path(ds_path).name

                # Get size-specific parameters
                chunk_size, obs_slice_size, size_cat = get_dataset_params(dataset)
                t0 = time.time()

                # Create slice tasks with adjusted slice size
                slice_tasks = create_slice_tasks(dataset, obs_slice_size, small_threshold)
                n_slices = len(slice_tasks)
                dataset_pbar.set_description(f"Med/Large [{ds_local_idx + 1}/{dask_count}] ({size_cat})")

                slice_results: list[SliceResult] = []

                # Nested progress bar only for datasets that were actually sliced
                slice_pbar = tqdm(
                    total=n_slices,
                    desc=" \u2514\u2500 Slices",
                    position=1,
                    leave=False,
                    ncols=100,
                ) if n_slices > 1 else None
                if slice_pbar is not None:
                    slice_pbar.set_postfix(ok=0, fail=0)
                try:
                    failed_slices = _run_slices_on_dask(
                        client, slice_tasks, chunk_size, slice_results, progress=slice_pbar
                    )
                finally:
                    if slice_pbar is not None:
                        slice_pbar.close()

                # Retry failed slices
                for retry in range(max_retries):
                    if not failed_slices:
                        break
                    tqdm.write(f" [{dataset_idx + 1}/{total_datasets}] Retry {retry + 1}/{max_retries}: {len(failed_slices)} failed slices")
                    time.sleep(1)
                    failed_slices = _run_slices_on_dask(
                        client, failed_slices, chunk_size, slice_results
                    )

                # EMERGENCY MODE: If still failing after all retries, use extreme settings
                if failed_slices:
                    # 10% of the original chunk with a 10K floor, but never
                    # larger than the chunk size that already failed.
                    emergency_chunk = max(1, min(chunk_size, max(10000, chunk_size // 10)))
                    tqdm.write(f" [{dataset_idx + 1}/{total_datasets}] ⚠️ EMERGENCY MODE: {len(failed_slices)} slices with extreme settings (chunk={emergency_chunk:,})")
                    time.sleep(2)

                    # Process failed slices one at a time with the reduced chunk size
                    still_failed: list[tuple[str, int, int]] = []
                    for task in failed_slices:
                        try:
                            future = client.submit(
                                process_slice, task[0], task[1], task[2], emergency_chunk,
                                pure=False,
                            )
                            sr = future.result(timeout=7200)  # 2 hour timeout for extreme cases
                        except Exception as e:
                            tqdm.write(f" Emergency failed for slice {task[1]}-{task[2]}: {str(e)[:100]}")
                            still_failed.append(task)
                            continue
                        if sr.status == "ok":
                            slice_results.append(sr)
                        else:
                            still_failed.append(task)

                    emergency_ok = len(failed_slices) - len(still_failed)
                    if emergency_ok > 0:
                        tqdm.write(f" [{dataset_idx + 1}/{total_datasets}] ✓ Emergency mode recovered {emergency_ok}/{len(failed_slices)} slices")
                    failed_slices = still_failed

                # Check if we got enough data
                ok_count = len(slice_results)
                fail_count = len(failed_slices)
                elapsed = round(time.time() - t0, 1)

                if ok_count == 0:
                    tqdm.write(f" [{dataset_idx + 1}/{total_datasets}] ✗ FAILED: {ds_name[:50]} | all {n_slices} slices failed | {elapsed}s")
                    failures.append(_failure_row(ds_path, f"All {n_slices} slices failed", elapsed))
                    dataset_pbar.update(1)
                    continue

                # Merge slice results into dataset summary, add metadata, save JSON
                row = _finalize_sliced_dataset(
                    slice_results, dataset, n_slices, fail_count, elapsed,
                    per_dataset_dir, max_meta_cols, max_categories,
                )
                successes.append(row)
                status = "✓" if fail_count == 0 else "⚠"
                tqdm.write(f" [{dataset_idx + 1}/{total_datasets}] {status} {ds_name[:50]} | {ok_count}/{n_slices} slices | {elapsed}s")

                # Free memory
                del slice_results
                gc.collect()

                # Update dataset progress
                dataset_pbar.update(1)

        print(f"\nPhase 2 complete\n")
    elif dask_count > 0:
        # No client: report these datasets instead of dropping them silently.
        print(f"WARNING: no Dask client available; {dask_count} medium/large datasets were not processed\n")
        for dataset in dask_datasets:
            failures.append(
                _failure_row(dataset["dataset_path"], "No Dask client available for medium/large dataset")
            )

    # Close Dask cluster before Phase 3 (xlarge direct processing doesn't use Dask)
    if xlarge_count > 0 and client:
        print("Closing Dask cluster before Phase 3 (xlarge datasets process directly)...")
        try:
            client.close()
            client = None
            gc.collect()
            time.sleep(2)
        except Exception as e:
            print(f"Warning: Error closing Dask client: {e}")

    # ========================================================================
    # Phase 3: Process xlarge datasets DIRECTLY (skip Dask - causes failures)
    # ========================================================================
    if xlarge_count > 0:
        print(f"{'='*80}")
        print(f"PHASE 3: XLarge datasets ({xlarge_count}) - Direct processing (no Dask)")
        print(f"{'='*80}\n")

        with tqdm(
            total=xlarge_count,
            desc="XLarge datasets",
            position=0,
            leave=True,
            ncols=100,
        ) as dataset_pbar:
            for ds_local_idx, dataset in enumerate(xlarge_datasets):
                dataset_idx = small_count + dask_count + ds_local_idx
                ds_path = dataset["dataset_path"]
                ds_name = Path(ds_path).name

                # Get xlarge-specific parameters
                chunk_size, obs_slice_size, size_cat = get_dataset_params(dataset)
                t0 = time.time()

                # Create slice tasks
                slice_tasks = create_slice_tasks(dataset, obs_slice_size, small_threshold)
                n_slices = len(slice_tasks)
                tqdm.write(f" [{dataset_idx + 1}/{total_datasets}] Processing {ds_name[:50]} | {n_slices} slices | chunk={chunk_size:,}")

                # Process slices DIRECTLY without Dask (one at a time)
                slice_results = []
                for slice_idx, (path, start, end) in enumerate(slice_tasks):
                    try:
                        sr = process_slice(path, start, end, chunk_size)
                        if sr.status == "ok":
                            slice_results.append(sr)
                        else:
                            tqdm.write(f" Slice {slice_idx+1}/{n_slices} failed: {sr.error}")
                    except Exception as e:
                        tqdm.write(f" Slice {slice_idx+1}/{n_slices} error: {str(e)[:100]}")

                ok_count = len(slice_results)
                fail_count = n_slices - ok_count
                elapsed = round(time.time() - t0, 1)

                if ok_count == 0:
                    tqdm.write(f" [{dataset_idx + 1}/{total_datasets}] ✗ FAILED: {ds_name[:50]} | all slices failed | {elapsed}s")
                    failures.append(
                        _failure_row(ds_path, f"All {n_slices} slices failed (xlarge direct mode)", elapsed)
                    )
                    dataset_pbar.update(1)
                    continue

                # Merge results, add metadata, save JSON
                row = _finalize_sliced_dataset(
                    slice_results, dataset, n_slices, fail_count, elapsed,
                    per_dataset_dir, max_meta_cols, max_categories,
                    processing_mode="xlarge_direct",
                )
                successes.append(row)
                status = "✓" if fail_count == 0 else "⚠"
                tqdm.write(f" [{dataset_idx + 1}/{total_datasets}] {status} {ds_name[:50]} | {ok_count}/{n_slices} slices | {elapsed}s")
                dataset_pbar.update(1)
                gc.collect()

        print(f"\nPhase 3 complete\n")

    return successes, failures
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def _load_results_from_disk(
    per_dataset_dir: Path,
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Rebuild success/failure lists from the per-dataset JSON files on disk.

    Used as the recovery path after an interrupt or an unexpected error, when
    the in-memory results of the current run are not available.

    Args:
        per_dataset_dir: Directory holding one ``*.json`` payload per dataset.

    Returns:
        ``(successes, failures)``. A payload whose ``status`` is ``"ok"`` is a
        success; every other readable payload (including ``"partial"``) is
        returned as a failure. Payloads sharing a non-empty ``dataset_path``
        are only taken once (first file seen wins).
    """
    successes: list[dict[str, Any]] = []
    failures: list[dict[str, Any]] = []
    seen_paths: set = set()
    for json_file in per_dataset_dir.glob("*.json"):
        try:
            result = json.loads(json_file.read_text())
            ds_path = result.get("dataset_path", "")
            # Only non-empty paths are deduplicated; payloads without a path
            # are always kept.
            if ds_path and ds_path in seen_paths:
                continue
            seen_paths.add(ds_path)
            if result.get("status") == "ok":
                successes.append(result)
            else:
                failures.append(result)
        except Exception as exc:
            # Best effort: a corrupt payload must not block saving the rest.
            print(f"  Warning: skipping unreadable result file {json_file.name}: {exc}")
    return successes, failures


def _merge_previous_successes(
    per_dataset_dir: Path, successes: list[dict[str, Any]]
) -> None:
    """Append previously saved ``"ok"`` payloads that are not yet in *successes*.

    Mutates *successes* in place. A payload is considered already present when
    an entry in *successes* carries the same ``dataset_path``.

    NOTE(review): this scans every ``*.json`` in the directory, not only the
    datasets skipped by this run — confirm that is intended when several
    shards share one output directory.
    """
    known_paths = {row.get("dataset_path") for row in successes}
    for json_file in per_dataset_dir.glob("*.json"):
        try:
            result = json.loads(json_file.read_text())
            if result.get("status") != "ok":
                continue
            if result.get("dataset_path", "") not in known_paths:
                successes.append(result)
                known_paths.add(result.get("dataset_path"))
        except Exception as exc:
            # Best effort: skip payloads that cannot be read or parsed.
            print(f"  Warning: skipping unreadable result file {json_file.name}: {exc}")


def _save_shard_results(
    output_dir: Path,
    shard_index: int,
    num_shards: int,
    successes: list[dict[str, Any]],
    failures: list[dict[str, Any]],
) -> None:
    """Write the shard summary CSV and failures JSON into *output_dir*.

    Rows in *successes* are deduplicated by ``dataset_path`` (first occurrence
    kept) before the CSV is written. Any error while saving is reported on
    stdout instead of being raised, so the caller's cleanup always runs.
    """
    try:
        summary_df = pd.DataFrame(successes)
        if not summary_df.empty and "dataset_path" in summary_df.columns:
            rows_before = len(summary_df)
            summary_df = summary_df.drop_duplicates(subset=["dataset_path"], keep="first")
            removed = rows_before - len(summary_df)
            if removed > 0:
                print(f"\nRemoved {removed} duplicate entries from results")

        shard_tag = f"shard_{shard_index:03d}_of_{num_shards:03d}"
        summary_csv = output_dir / f"eda_summary_{shard_tag}.csv"
        summary_df.to_csv(summary_csv, index=False)
        failures_path = output_dir / f"eda_failures_{shard_tag}.json"
        failures_path.write_text(json.dumps(failures, indent=2))

        print(f"\n{'=' * 80}")
        print("RESULTS SAVED")
        print(f" Summary CSV: {summary_csv}")
        print(f" Failures JSON: {failures_path}")
        print(json.dumps({
            "ok_count": len(successes),
            "failed_count": len(failures),
        }, indent=2))
        print(f"{'=' * 80}\n")
    except Exception as save_exc:
        print(f"ERROR saving results: {save_exc}")


def main() -> None:
    """Command-line entry point: run the EDA for one shard and save its results.

    Steps:
      1. Parse CLI arguments and load the YAML config (CLI overrides sharding).
      2. Load the metadata cache and select the datasets for this shard.
      3. Skip datasets that already have an ``"ok"`` per-dataset JSON (resume).
      4. Start an adaptive Dask ``LocalCluster`` only if at least one dataset
         is at or above the "small" threshold and not categorised "xlarge".
      5. Process all datasets, then write the summary CSV and failures JSON.
         On interrupt or error, results are rebuilt from the per-dataset JSON
         files already on disk so that partial progress is still saved.

    The Dask client and cluster are always closed on exit.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--config", type=Path, required=True, help="YAML config")
    parser.add_argument("--num-shards", type=int, help="Override num_shards")
    parser.add_argument("--shard-index", type=int, help="Override shard_index")
    parser.add_argument("--max-retries", type=int, default=3, help="Max retries per slice")
    args = parser.parse_args()

    config = load_config(args.config)
    if args.num_shards is not None:
        config["sharding"]["num_shards"] = args.num_shards
        config["sharding"]["enabled"] = args.num_shards > 1
    if args.shard_index is not None:
        config["sharding"]["shard_index"] = args.shard_index
    num_shards = config["sharding"]["num_shards"]
    shard_index = config["sharding"]["shard_index"]

    configure_dask_for_hpc()

    # Relative paths in the config are resolved against the grandparent
    # directory of the config file.
    config_root = Path(args.config).parent.parent

    # Load metadata
    cache_path = Path(config["paths"]["enhanced_metadata_cache"])
    if not cache_path.is_absolute():
        cache_path = config_root / cache_path
    print(f"Loading metadata from: {cache_path}")
    metadata_df = load_enhanced_metadata(cache_path)
    datasets = get_datasets_for_shard(metadata_df, config, num_shards, shard_index)
    if not datasets:
        print("No datasets scheduled for this shard.")
        return

    # Output dirs
    output_dir = Path(config["paths"]["output_dir"])
    if not output_dir.is_absolute():
        output_dir = config_root / output_dir
    output_dir.mkdir(parents=True, exist_ok=True)
    per_dataset_dir = output_dir / "per_dataset"
    per_dataset_dir.mkdir(parents=True, exist_ok=True)

    # Filter out already-completed datasets (resume capability)
    def is_dataset_done(ds_path: str) -> bool:
        """Return True if a per-dataset JSON with status "ok" already exists."""
        try:
            result_file = per_dataset_dir / (safe_name(Path(ds_path)) + ".json")
            if result_file.exists():
                return json.loads(result_file.read_text()).get("status") == "ok"
        except Exception:
            # An unreadable payload counts as "not done" so it is reprocessed.
            pass
        return False

    original_count = len(datasets)
    datasets = [d for d in datasets if not is_dataset_done(d["dataset_path"])]
    skipped_count = original_count - len(datasets)
    if skipped_count > 0:
        print(f"\n{'=' * 80}")
        print(f"RESUME MODE: Skipping {skipped_count} already-completed datasets")
        print(f"Remaining to process: {len(datasets)}")
        print(f"{'=' * 80}\n")
    if not datasets:
        print("All datasets already completed. Nothing to do.")
        return

    # Check if we need Dask cluster (for medium/large datasets)
    small_threshold = config["dataset_thresholds"]["small"]

    def needs_dask(d: dict[str, Any]) -> bool:
        """True for datasets at/above the small threshold that are not xlarge."""
        return (
            d.get("total_entries", 0) >= small_threshold
            and d.get("size_category", "large") != "xlarge"
        )

    # Count datasets by processing type
    small_count_init = sum(1 for d in datasets if d.get("total_entries", 0) < small_threshold)
    dask_count_init = sum(1 for d in datasets if needs_dask(d))
    xlarge_count_init = sum(1 for d in datasets if d.get("size_category", "large") == "xlarge")

    client = None
    cluster = None
    # Outer try/finally: the cluster is released even if its own setup fails
    # half-way or a non-Exception error escapes processing.
    try:
        if dask_count_init > 0:
            # Cluster setup for large datasets
            resources = config["resources"]
            max_memory_gib = resources["max_memory_gib"]
            max_workers = resources["max_workers"]
            min_workers = resources.get("min_workers", min(4, max_workers))
            threads_per_worker = resources.get("threads_per_worker", 1)

            # Adaptive scaling config
            adaptive_config = resources.get("adaptive_scaling", {})
            target_duration = adaptive_config.get("target_duration", "30s")
            wait_count = adaptive_config.get("wait_count", 3)
            interval = adaptive_config.get("interval", "2s")

            # Never give a worker less than 2 GiB, whatever the worker count.
            memory_per_worker_gib = max(2.0, max_memory_gib / max_workers)

            total_entries = sum(d.get("total_entries", 0) for d in datasets)
            obs_slice_size = config["slicing"].get("obs_slice_size", 75_000)
            total_slices = sum(
                max(1, math.ceil(d["n_obs"] / obs_slice_size))
                for d in datasets
                if needs_dask(d)
            )
            print(json.dumps({
                "total_datasets": len(datasets),
                "small_datasets": small_count_init,
                "large_datasets": dask_count_init + xlarge_count_init,
                "total_slices": total_slices,
                "total_entries": total_entries,
                "shard_index": shard_index,
                "num_shards": num_shards,
                "memory_per_worker_gib": round(memory_per_worker_gib, 1),
                "max_workers": max_workers,
            }, indent=2))
            print(f"\nStarting Dask LocalCluster (for {dask_count_init} medium/large datasets):")
            print(f" Workers: {min_workers} -> {max_workers} (adaptive)")
            print(f" Memory per worker: {memory_per_worker_gib:.1f} GiB")
            print(f" Total memory budget: {max_memory_gib} GiB\n")
            cluster = LocalCluster(
                n_workers=min_workers,
                threads_per_worker=threads_per_worker,
                processes=True,
                memory_limit=f"{memory_per_worker_gib}GiB",
                silence_logs=True,
                dashboard_address=None,
                lifetime="180 minutes",
                lifetime_stagger="20 minutes",
            )
            cluster.adapt(
                minimum=min_workers,
                maximum=max_workers,
                target_duration=target_duration,
                wait_count=wait_count,
                interval=interval,
            )
            client = Client(cluster)
            print(f"Dask cluster ready: {client}\n")
        else:
            print("No Dask-compatible datasets (all small or xlarge)\n")

        if xlarge_count_init > 0:
            print(f"Note: {xlarge_count_init} xlarge datasets will be processed directly (Phase 3, no Dask)\n")

        try:
            successes, failures = process_all_datasets(
                datasets, config, per_dataset_dir, client,
                max_retries=args.max_retries,
            )

            # Include previously completed datasets in final summary
            if skipped_count > 0:
                print(f"\nLoading {skipped_count} previously completed results...")
                _merge_previous_successes(per_dataset_dir, successes)
                print(f"Total results (new + previous): {len(successes)}")

            print(f"\n{'=' * 80}")
            print("PROCESSING COMPLETE")
            print(f" Succeeded: {len(successes)}")
            print(f" Failed: {len(failures)}")
            print(f" Success rate: {len(successes) / max(1, original_count) * 100:.1f}%")
            print(f"{'=' * 80}\n")
            if failures:
                print("WARNING: Some datasets failed permanently:")
                for fail in failures[:10]:
                    print(f" - {fail['dataset_file']}: {fail.get('error', 'Unknown')[:80]}")
                if len(failures) > 10:
                    print(f" ... and {len(failures) - 10} more")
        except KeyboardInterrupt:
            # The f-prefix matters here: without it the banner prints literally.
            print(f"\n\n{'=' * 80}")
            print("INTERRUPTED - Saving partial results...")
            print(f"{'=' * 80}\n")
            # Load all completed results from disk (deduplicate by dataset_path)
            successes, failures = _load_results_from_disk(per_dataset_dir)
        except Exception as exc:
            print(f"\n\nERROR during processing: {exc}")
            print("Saving partial results...")
            # Load all completed results from disk (deduplicate by dataset_path)
            successes, failures = _load_results_from_disk(per_dataset_dir)

        # Always save results, even on error/interrupt
        _save_shard_results(output_dir, shard_index, num_shards, successes, failures)
    finally:
        if client is not None:
            client.close()
        if cluster is not None:
            cluster.close()
# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()