#!/usr/bin/env python3
"""Distributed EDA with chunk-slice architecture for billion-scale datasets.

Each dataset is sliced into small row-ranges. Each slice is a Dask task that
processes a bounded amount of memory (chunk_size * n_vars). Slice results are
merged per-dataset on the scheduler side with O(1) memory.

No quantiles - only non-zero min, max, mean, sparsity, gene-level stats, and
metadata summaries.

This handles datasets from 2 GB to 500 GB.
"""

from __future__ import annotations

import argparse
import concurrent.futures
import gc
import hashlib
import json
import math
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import anndata as ad
import dask
import numpy as np
import pandas as pd
import yaml
from dask.distributed import Client, LocalCluster
from scipy import sparse
from tqdm import tqdm

# ---------------------------------------------------------------------------
# Slice result: the only thing returned from each Dask task
# ---------------------------------------------------------------------------


@dataclass
class SliceResult:
    """Mergeable statistics from one row-slice of a dataset.

    Every field is O(1) memory except the gene arrays, which are O(n_vars).
    All fields are JSON-serialisable once the gene arrays are stored as
    plain lists.
    """

    # identity
    dataset_path: str = ""
    slice_start: int = 0
    slice_end: int = 0
    # matrix global
    n_obs_slice: int = 0
    n_vars: int = 0
    nnz: int = 0
    x_sum: float = 0.0
    x_sum_sq: float = 0.0
    # cell-level running stats (non-zero counts per cell)
    cell_total_counts_sum: float = 0.0
    cell_total_counts_min: float = math.inf
    cell_total_counts_max: float = -math.inf
    cell_n_genes_sum: int = 0
    cell_n_genes_min: int = 2**63 - 1
    cell_n_genes_max: int = 0
    # gene-level accumulators (length = n_vars, stored as list for serialisation)
    gene_n_cells: list | None = None
    gene_total_counts: list | None = None
    # status
    status: str = "ok"
    error: str = ""
    elapsed_sec: float = 0.0


def merge_slice_results(slices: list[SliceResult], n_obs: int, n_vars: int) -> dict:
    """Fold many SliceResults into one per-dataset summary dict.

    Peak memory is O(n_vars) for the two gene accumulators; every other
    aggregate is a scalar. Slices whose ``status`` is not ``"ok"`` are
    ignored entirely.
    """
    # Only successful slices contribute; failed ones carry no usable stats.
    ok = [s for s in slices if s.status == "ok"]

    # Scalar aggregates. ``sum``/``min``/``max`` over the ok-list visit the
    # slices in the same order the original running accumulators would, so
    # the floating-point results are identical. The ``default`` values are
    # the same sentinels the accumulators started from.
    n_obs_seen = sum(s.n_obs_slice for s in ok)
    nnz_total = sum(s.nnz for s in ok)
    x_sum = sum(s.x_sum for s in ok)
    x_sum_sq = sum(s.x_sum_sq for s in ok)
    cell_counts_sum = sum(s.cell_total_counts_sum for s in ok)
    cell_counts_min = min((s.cell_total_counts_min for s in ok), default=math.inf)
    cell_counts_max = max((s.cell_total_counts_max for s in ok), default=-math.inf)
    cell_genes_sum = sum(s.cell_n_genes_sum for s in ok)
    cell_genes_min = min((s.cell_n_genes_min for s in ok), default=2**63 - 1)
    cell_genes_max = max((s.cell_n_genes_max for s in ok), default=0)

    # Gene-level accumulators: element-wise sums over the per-slice arrays.
    gene_n_cells = np.zeros(n_vars, dtype=np.int64)
    gene_total_counts = np.zeros(n_vars, dtype=np.float64)
    for s in ok:
        if s.gene_n_cells is not None:
            gene_n_cells += np.asarray(s.gene_n_cells, dtype=np.int64)
        if s.gene_total_counts is not None:
            gene_total_counts += np.asarray(s.gene_total_counts, dtype=np.float64)

    total_entries = n_obs * n_vars
    summary: dict[str, Any] = {
        "n_obs": n_obs,
        "n_vars": n_vars,
        "n_obs_processed": n_obs_seen,
        "nnz": int(nnz_total),
        "sparsity": float(1.0 - nnz_total / total_entries) if total_entries else None,
        "x_mean": float(x_sum / total_entries) if total_entries else None,
    }
    if total_entries:
        # Population variance over ALL matrix entries (zeros included),
        # clipped at 0 to absorb floating-point cancellation.
        variance = max(0.0, x_sum_sq / total_entries - (x_sum / total_entries) ** 2)
        summary["x_std"] = float(math.sqrt(variance))
    else:
        summary["x_std"] = None

    # Cell-level summaries (means use rows actually processed, not n_obs).
    if n_obs_seen > 0:
        summary["cell_total_counts_min"] = float(cell_counts_min)
        summary["cell_total_counts_max"] = float(cell_counts_max)
        summary["cell_total_counts_mean"] = float(cell_counts_sum / n_obs_seen)
        summary["cell_n_genes_detected_min"] = int(cell_genes_min)
        summary["cell_n_genes_detected_max"] = int(cell_genes_max)
        summary["cell_n_genes_detected_mean"] = float(cell_genes_sum / n_obs_seen)
    else:
        for key in (
            "cell_total_counts_min",
            "cell_total_counts_max",
            "cell_total_counts_mean",
            "cell_n_genes_detected_min",
            "cell_n_genes_detected_max",
            "cell_n_genes_detected_mean",
        ):
            summary[key] = None

    # Gene-level summaries over genes detected in at least one cell.
    detected = int(np.count_nonzero(gene_n_cells))
    summary["genes_detected_in_any_cell"] = detected
    summary["genes_detected_in_any_cell_pct"] = (
        float(detected / n_vars * 100) if n_vars else 0.0
    )
    if detected > 0:
        nonzero = gene_n_cells > 0
        summary["gene_n_cells_min"] = int(gene_n_cells[nonzero].min())
        summary["gene_n_cells_max"] = int(gene_n_cells[nonzero].max())
        summary["gene_n_cells_mean"] = float(gene_n_cells[nonzero].mean())
        summary["gene_total_counts_min"] = float(gene_total_counts[nonzero].min())
        summary["gene_total_counts_max"] = float(gene_total_counts[nonzero].max())
        summary["gene_total_counts_mean"] = float(gene_total_counts[nonzero].mean())
    else:
        for key in (
            "gene_n_cells_min",
            "gene_n_cells_max",
            "gene_n_cells_mean",
            "gene_total_counts_min",
            "gene_total_counts_max",
            "gene_total_counts_mean",
        ):
            summary[key] = 0

    # Drop the O(n_vars) arrays before returning the scalar-only dict.
    del gene_n_cells, gene_total_counts
    return summary


# ---------------------------------------------------------------------------
# Simple worker function for small datasets (no Dask overhead)
overhead) # --------------------------------------------------------------------------- def process_dataset_simple( path_str: str, n_obs: int, n_vars: int, chunk_size: int, max_meta_cols: int, max_categories: int, ) -> dict: """Process entire small dataset in one worker (no slicing, no Dask).""" t0 = time.time() path = Path(path_str) row: dict[str, Any] = { "dataset_path": path_str, "dataset_file": path.name, "n_obs": n_obs, "n_vars": n_vars, } try: adata = ad.read_h5ad(path, backed="r") total_entries = n_obs * n_vars nnz_total = 0 x_sum = 0.0 x_sum_sq = 0.0 # Cell-level accumulators cell_total_counts_sum = 0.0 cell_total_counts_min = math.inf cell_total_counts_max = -math.inf cell_n_genes_sum = 0 cell_n_genes_min = 2**63 - 1 cell_n_genes_max = 0 # Gene-level accumulators gene_n_cells = np.zeros(n_vars, dtype=np.int64) gene_total_counts = np.zeros(n_vars, dtype=np.float64) # Process in chunks for start in range(0, n_obs, chunk_size): end = min(start + chunk_size, n_obs) chunk = adata.X[start:end, :] if sparse.issparse(chunk): csr = chunk.tocsr() if not sparse.isspmatrix_csr(chunk) else chunk data = csr.data.astype(np.float64, copy=False) nnz_total += int(csr.nnz) x_sum += float(data.sum()) x_sum_sq += float(np.square(data).sum()) # Cell stats cell_counts = np.asarray(csr.sum(axis=1)).ravel() cell_genes = np.diff(csr.indptr).astype(np.int64) cell_total_counts_sum += float(cell_counts.sum()) cell_total_counts_min = min(cell_total_counts_min, float(cell_counts.min())) cell_total_counts_max = max(cell_total_counts_max, float(cell_counts.max())) cell_n_genes_sum += int(cell_genes.sum()) cell_n_genes_min = min(cell_n_genes_min, int(cell_genes.min())) cell_n_genes_max = max(cell_n_genes_max, int(cell_genes.max())) # Gene stats csc = csr.tocsc() gene_n_cells += np.diff(csc.indptr).astype(np.int64) gene_total_counts += np.asarray(csc.sum(axis=0)).ravel() del csr, csc, data else: arr = np.asarray(chunk, dtype=np.float64) nz = arr != 0 nnz_total += int(nz.sum()) x_sum += 
float(arr.sum()) x_sum_sq += float(np.square(arr).sum()) # Cell stats cell_counts = arr.sum(axis=1) cell_genes = nz.sum(axis=1).astype(np.int64) cell_total_counts_sum += float(cell_counts.sum()) cell_total_counts_min = min(cell_total_counts_min, float(cell_counts.min())) cell_total_counts_max = max(cell_total_counts_max, float(cell_counts.max())) cell_n_genes_sum += int(cell_genes.sum()) cell_n_genes_min = min(cell_n_genes_min, int(cell_genes.min())) cell_n_genes_max = max(cell_n_genes_max, int(cell_genes.max())) # Gene stats gene_n_cells += nz.sum(axis=0).astype(np.int64) gene_total_counts += arr.sum(axis=0) del arr, nz del chunk gc.collect() # Matrix-level stats row["nnz"] = int(nnz_total) row["sparsity"] = float(1.0 - nnz_total / total_entries) if total_entries else None row["x_mean"] = float(x_sum / total_entries) if total_entries else None if total_entries: var = max(0.0, x_sum_sq / total_entries - (x_sum / total_entries) ** 2) row["x_std"] = float(math.sqrt(var)) else: row["x_std"] = None # Cell-level stats if n_obs > 0: row["cell_total_counts_min"] = float(cell_total_counts_min) row["cell_total_counts_max"] = float(cell_total_counts_max) row["cell_total_counts_mean"] = float(cell_total_counts_sum / n_obs) row["cell_n_genes_detected_min"] = int(cell_n_genes_min) row["cell_n_genes_detected_max"] = int(cell_n_genes_max) row["cell_n_genes_detected_mean"] = float(cell_n_genes_sum / n_obs) else: row["cell_total_counts_min"] = None row["cell_total_counts_max"] = None row["cell_total_counts_mean"] = None row["cell_n_genes_detected_min"] = None row["cell_n_genes_detected_max"] = None row["cell_n_genes_detected_mean"] = None # Gene-level stats genes_detected = int(np.count_nonzero(gene_n_cells)) row["genes_detected_in_any_cell"] = genes_detected row["genes_detected_in_any_cell_pct"] = float(genes_detected / n_vars * 100) if n_vars else 0.0 if genes_detected > 0: mask = gene_n_cells > 0 row["gene_n_cells_min"] = int(gene_n_cells[mask].min()) row["gene_n_cells_max"] = 
int(gene_n_cells[mask].max()) row["gene_n_cells_mean"] = float(gene_n_cells[mask].mean()) row["gene_total_counts_min"] = float(gene_total_counts[mask].min()) row["gene_total_counts_max"] = float(gene_total_counts[mask].max()) row["gene_total_counts_mean"] = float(gene_total_counts[mask].mean()) else: for k in ("gene_n_cells_min", "gene_n_cells_max", "gene_n_cells_mean", "gene_total_counts_min", "gene_total_counts_max", "gene_total_counts_mean"): row[k] = 0 # Metadata row["obs_columns"] = int(len(adata.obs.columns)) row["var_columns"] = int(len(adata.var.columns)) row["metadata_obs_summary"] = summarize_metadata( adata.obs, max_cols=max_meta_cols, max_categories=max_categories ) row["metadata_var_summary"] = summarize_metadata( adata.var, max_cols=max_meta_cols, max_categories=max_categories ) row["obs_schema"] = extract_schema(adata.obs) row["var_schema"] = extract_schema(adata.var) # Clean up del gene_n_cells, gene_total_counts try: if hasattr(adata, "file") and adata.file is not None: adata.file.close() except Exception: pass del adata row["status"] = "ok" row["n_slices_total"] = 1 row["n_slices_ok"] = 1 row["n_slices_failed"] = 0 except Exception as exc: row["status"] = "failed" row["error"] = str(exc) gc.collect() row["elapsed_sec"] = round(time.time() - t0, 2) return row # --------------------------------------------------------------------------- # Core worker function: process ONE slice of ONE dataset (Dask) # --------------------------------------------------------------------------- def process_slice( path_str: str, obs_start: int, obs_end: int, chunk_size: int, ) -> SliceResult: """Process rows [obs_start, obs_end) of a dataset. Memory usage bounded by: chunk_size * n_vars * ~12 bytes * 3x overhead. 
""" t0 = time.time() path = Path(path_str) result = SliceResult(dataset_path=path_str, slice_start=obs_start, slice_end=obs_end) try: adata = ad.read_h5ad(path, backed="r") n_vars = int(adata.n_vars) result.n_vars = n_vars result.n_obs_slice = obs_end - obs_start # Gene-level accumulators for this slice gene_n_cells = np.zeros(n_vars, dtype=np.int64) gene_total_counts = np.zeros(n_vars, dtype=np.float64) # Process in sub-chunks within this slice for start in range(obs_start, obs_end, chunk_size): end = min(start + chunk_size, obs_end) chunk = adata.X[start:end, :] if sparse.issparse(chunk): csr = chunk.tocsr() if not sparse.isspmatrix_csr(chunk) else chunk data = csr.data.astype(np.float64, copy=False) result.nnz += int(csr.nnz) result.x_sum += float(data.sum()) result.x_sum_sq += float(np.square(data).sum()) # Cell stats cell_counts = np.asarray(csr.sum(axis=1)).ravel() cell_genes = np.diff(csr.indptr).astype(np.int64) # Gene stats (optimized: use bincount instead of CSC conversion) # Accumulate counts directly from CSR indices/data gene_total_counts += np.bincount( csr.indices, weights=data, minlength=n_vars ) gene_n_cells += np.bincount( csr.indices, minlength=n_vars ) del csr, data else: arr = np.asarray(chunk, dtype=np.float64) nz = arr != 0 result.nnz += int(nz.sum()) result.x_sum += float(arr.sum()) result.x_sum_sq += float(np.square(arr).sum()) # Cell stats cell_counts = arr.sum(axis=1) cell_genes = nz.sum(axis=1).astype(np.int64) # Gene stats gene_n_cells += nz.sum(axis=0).astype(np.int64) gene_total_counts += arr.sum(axis=0) del arr, nz # Update cell-level running stats result.cell_total_counts_sum += float(cell_counts.sum()) result.cell_total_counts_min = min(result.cell_total_counts_min, float(cell_counts.min())) result.cell_total_counts_max = max(result.cell_total_counts_max, float(cell_counts.max())) result.cell_n_genes_sum += int(cell_genes.sum()) result.cell_n_genes_min = min(result.cell_n_genes_min, int(cell_genes.min())) result.cell_n_genes_max = 
max(result.cell_n_genes_max, int(cell_genes.max())) del chunk, cell_counts, cell_genes gc.collect() # Store gene arrays as lists for serialisation result.gene_n_cells = gene_n_cells.tolist() result.gene_total_counts = gene_total_counts.tolist() del gene_n_cells, gene_total_counts # Close file try: if hasattr(adata, "file") and adata.file is not None: adata.file.close() except Exception: pass del adata except Exception as exc: result.status = "failed" result.error = str(exc) gc.collect() result.elapsed_sec = round(time.time() - t0, 2) return result # --------------------------------------------------------------------------- # Metadata helpers (run on scheduler, not workers) # --------------------------------------------------------------------------- def safe_name(path: Path) -> str: """Generate safe filename from path.""" digest = hashlib.md5(str(path).encode("utf-8"), usedforsecurity=False).hexdigest()[:10] stem = path.stem.replace(" ", "_") if len(stem) > 80: stem = stem[:80] return f"{stem}_{digest}" def summarize_metadata(df: pd.DataFrame, max_cols: int, max_categories: int) -> dict[str, dict]: """Summarize DataFrame metadata with top categories.""" if df.empty: return {} preferred = ["cell_type", "assay", "tissue", "disease", "sex", "donor_id"] selected: list[str] = [c for c in preferred if c in df.columns] for col in df.columns: if col not in selected: selected.append(col) if len(selected) >= max_cols: break out: dict[str, dict] = {} n_rows = max(1, len(df)) for col in selected: s = df[col] summary: dict[str, Any] = { "dtype": str(s.dtype), "missing_fraction": float(s.isna().sum()) / n_rows, } if isinstance(s.dtype, pd.CategoricalDtype): summary["n_unique"] = int(len(s.cat.categories)) vc = s.value_counts(dropna=False).head(max_categories) summary["top_values"] = {str(k): int(v) for k, v in vc.items()} elif pd.api.types.is_string_dtype(s.dtype) or s.dtype == object: s_str = s.dropna().astype(str) summary["n_unique"] = int(s_str.nunique()) vc = 
def extract_schema(df: pd.DataFrame) -> dict[str, object]:
    """Return the column names and dtypes of *df* as a plain dict."""
    columns = [str(c) for c in df.columns]
    return {
        "n_columns": len(columns),
        "columns": columns,
        "dtypes": {str(c): str(df[c].dtype) for c in df.columns},
    }


def extract_metadata_on_scheduler(
    path: Path,
    max_meta_cols: int,
    max_categories: int,
) -> dict:
    """Extract obs/var metadata. Runs on scheduler (lightweight, no X access)."""
    try:
        adata = ad.read_h5ad(path, backed="r")
        obs, var = adata.obs, adata.var
        info = {
            "obs_columns": int(len(obs.columns)),
            "var_columns": int(len(var.columns)),
            "metadata_obs_summary": summarize_metadata(
                obs, max_cols=max_meta_cols, max_categories=max_categories
            ),
            "metadata_var_summary": summarize_metadata(
                var, max_cols=max_meta_cols, max_categories=max_categories
            ),
            "obs_schema": extract_schema(obs),
            "var_schema": extract_schema(var),
        }
        # Best-effort close of the backing HDF5 handle.
        try:
            if hasattr(adata, "file") and adata.file is not None:
                adata.file.close()
        except Exception:
            pass
        del adata
        gc.collect()
        return info
    except Exception as exc:
        # Any failure (read, summarise, close) is reported, never raised.
        return {"metadata_error": str(exc)}


# ---------------------------------------------------------------------------
# Dask configuration
# ---------------------------------------------------------------------------


def configure_dask_for_hpc() -> None:
    """Configure Dask for HPC with aggressive memory management."""
    settings = {
        "distributed.worker.memory.target": 0.60,
        "distributed.worker.memory.spill": 0.70,
        "distributed.worker.memory.pause": 0.80,
        "distributed.worker.memory.terminate": 0.95,
        "distributed.worker.daemon": False,
        "distributed.worker.use-file-locking": False,
        "distributed.scheduler.allowed-failures": 10,
        "distributed.scheduler.work-stealing": True,
        "distributed.scheduler.work-stealing-interval": "100ms",
        "distributed.comm.timeouts.connect": "120s",
        "distributed.comm.timeouts.tcp": "120s",
        "distributed.admin.tick.interval": "2s",
        "distributed.admin.log-length": 500,
    }
    dask.config.set(settings)


# ---------------------------------------------------------------------------
# Config / metadata helpers
# ---------------------------------------------------------------------------


def load_config(config_path: Path) -> dict:
    """Parse the YAML run configuration."""
    with open(config_path) as fh:
        return yaml.safe_load(fh)


def load_enhanced_metadata(cache_path: Path) -> pd.DataFrame:
    """Load the pre-built enhanced-metadata parquet cache, or fail loudly."""
    if cache_path.exists():
        return pd.read_parquet(cache_path)
    raise FileNotFoundError(
        f"Enhanced metadata cache not found: {cache_path}\n"
        "Run: uv run python scripts/build_metadata_cache.py --config "
    )


def get_datasets_for_shard(
    metadata_df: pd.DataFrame,
    config: dict,
    num_shards: int,
    shard_index: int,
) -> list[dict]:
    """Select the dataset records this shard is responsible for.

    Returns a list of dicts with keys: dataset_path, n_obs, n_vars,
    total_entries, size_category. Rows are kept only when their status
    starts with "ok" and total_entries is within the configured cap.
    """
    if num_shards > 1:
        # Sort biggest-first so the modulo split balances work across shards.
        ordered = metadata_df.sort_values("total_entries", ascending=False).reset_index(drop=True)
        selected = ordered[ordered.index % num_shards == shard_index].copy()
    else:
        selected = metadata_df.copy()

    ok_mask = selected["status"].str.startswith("ok", na=False)
    selected = selected[ok_mask].copy()

    entry_cap = config["dataset_thresholds"]["max_entries"]
    selected = selected[selected["total_entries"] <= entry_cap].copy()

    records: list[dict] = []
    for _, rec in selected.iterrows():
        # Absolute paths prevent duplicate entries from relative/absolute mixups.
        abs_path = Path(str(rec["dataset_path"])).resolve()
        records.append({
            "dataset_path": str(abs_path),
            "n_obs": int(rec.get("n_obs", 0)),
            "n_vars": int(rec.get("n_vars", 0)),
            "total_entries": int(rec.get("total_entries", 0)),
            "size_category": str(rec.get("size_category", "large")),
        })
    return records


# ---------------------------------------------------------------------------
# Main processing pipeline
# ---------------------------------------------------------------------------
def create_slice_tasks(
    dataset: dict,
    obs_slice_size: int,
    small_dataset_threshold: int,
) -> list[tuple[str, int, int]]:
    """Create (path, start, end) slice tasks for a dataset.

    Small datasets (< threshold): Single task for entire dataset (faster, no slicing overhead)
    Medium/Large datasets: Sliced into obs_slice_size chunks (memory-safe)
    """
    path = dataset["dataset_path"]
    n_obs = dataset["n_obs"]
    # Fall back to n_obs * n_vars when the cache did not precompute total_entries.
    total_entries = dataset.get("total_entries", n_obs * dataset.get("n_vars", 0))
    if n_obs <= 0:
        # Degenerate dataset: emit one empty slice so downstream counting still works.
        return [(path, 0, 0)]
    # For small datasets, process entire dataset in one task (no slicing overhead)
    if total_entries < small_dataset_threshold:
        return [(path, 0, n_obs)]
    # For medium/large datasets, slice to manage memory
    tasks = []
    for start in range(0, n_obs, obs_slice_size):
        end = min(start + obs_slice_size, n_obs)
        tasks.append((path, start, end))
    return tasks


def process_all_datasets(
    datasets: list[dict],
    config: dict,
    per_dataset_dir: Path,
    client: Client | None,
    max_retries: int = 3,
) -> tuple[list[dict], list[dict]]:
    """Process all datasets: small ones with ProcessPoolExecutor, large ones with Dask.

    Three phases:
      1. small datasets  -> batched ProcessPoolExecutor with adaptive worker count
      2. medium/large    -> Dask slice tasks with retries and an emergency fallback
      3. xlarge          -> direct in-process slicing (Dask is skipped entirely)

    Returns (successes, failures) where each element is a per-dataset summary
    dict. Per-dataset JSON payloads are also written to *per_dataset_dir*.

    NOTE(review): when xlarge datasets are present this function closes and
    deletes the caller's *client* before Phase 3 — callers must not reuse the
    client afterwards.
    """
    base_chunk_size = config["resources"]["chunk_size"]
    base_obs_slice_size = config["slicing"].get("obs_slice_size", 75_000)
    obs_slice_size_xlarge = config["slicing"].get("obs_slice_size_xlarge", 150_000)
    small_threshold = config["dataset_thresholds"]["small"]
    max_meta_cols = config["metadata"]["max_meta_cols"]
    max_categories = config["metadata"]["max_categories"]
    max_workers_base = config["resources"]["max_workers"]

    # Helper function to get adjusted parameters based on size category
    def get_dataset_params(dataset):
        # Unknown categories fall back to the "large" strategy.
        size_cat = dataset.get("size_category", "large")
        strategy = config.get("strategy", {}).get(size_cat, config["strategy"]["large"])
        chunk_mult = strategy.get("chunk_size_multiplier", 1.0)
        chunk_size = int(base_chunk_size * chunk_mult)
        # Use smaller slice size for xlarge datasets
        if size_cat == "xlarge":
            obs_slice = obs_slice_size_xlarge
        else:
            obs_slice = base_obs_slice_size
        return chunk_size, obs_slice, size_cat

    successes = []
    failures = []
    # Categorize datasets: small, dask-ready (medium/large), xlarge (skip Dask)
    small_datasets = [d for d in datasets if d.get("total_entries", 0) < small_threshold]
    non_small = [d for d in datasets if d.get("total_entries", 0) >= small_threshold]
    # Split non-small into Dask-compatible and xlarge (which skip Dask due to failures)
    dask_datasets = [d for d in non_small if d.get("size_category", "large") != "xlarge"]
    xlarge_datasets = [d for d in non_small if d.get("size_category", "large") == "xlarge"]
    # Smallest-first within each tier so quick wins surface early.
    small_datasets.sort(key=lambda d: d["total_entries"])
    dask_datasets.sort(key=lambda d: d["total_entries"])
    xlarge_datasets.sort(key=lambda d: d["total_entries"])
    datasets_sorted = small_datasets + dask_datasets + xlarge_datasets
    small_count = len(small_datasets)
    dask_count = len(dask_datasets)
    xlarge_count = len(xlarge_datasets)
    # NOTE(review): the four statements below are an exact duplicate of the
    # four above — harmless (idempotent) but should be removed.
    datasets_sorted = small_datasets + dask_datasets + xlarge_datasets
    small_count = len(small_datasets)
    dask_count = len(dask_datasets)
    xlarge_count = len(xlarge_datasets)
    print(f"\n{'=' * 80}")
    print(f"Processing {len(datasets_sorted)} datasets")
    print(f" Small datasets (ProcessPoolExecutor): {small_count}")
    print(f" Medium/Large (Dask + slicing): {dask_count}")
    print(f" XLarge (Direct, skip Dask): {xlarge_count}")
    print(f"Slice size: {base_obs_slice_size:,} rows (medium/large), {obs_slice_size_xlarge:,} rows (xlarge)")
    print(f"Small threshold: {small_threshold:,} entries")
    print(f"Base chunk size: {base_chunk_size:,} rows (adjusted per dataset size)")
    print(f"{'=' * 80}\n")
    total_datasets = len(datasets_sorted)

    # ========================================================================
    # Phase 1: Process small datasets with ProcessPoolExecutor (batched)
    # ========================================================================
    if small_count > 0:
        print(f"{'='*80}")
        print(f"PHASE 1: Small datasets ({small_count}) - ProcessPoolExecutor")
        print(f"{'='*80}\n")
        # Adaptive worker management
        current_workers = max_workers_base
        min_workers_ratio = config["resources"].get("min_workers_ratio", 0.25)
        min_workers = max(1, int(max_workers_base * min_workers_ratio))
        # Batches keep the pool short-lived so leaked worker memory is reclaimed.
        batch_size = max(30, min(100, small_count // 4))
        # Throughput monitoring
        check_interval = 50
        baseline_throughput = None
        slowdown_threshold = config["resources"].get("slowdown_threshold", 0.5)
        last_check_idx = 0
        batch_start_time = time.time()
        print(f"Workers: {current_workers} (adaptive: {min_workers}-{max_workers_base})")
        print(f"Batch size: {batch_size} (recycled between batches)\n")
        with tqdm(total=small_count, desc="Small datasets", position=0) as pbar:
            for batch_start in range(0, small_count, batch_size):
                batch_end = min(batch_start + batch_size, small_count)
                batch = small_datasets[batch_start:batch_end]
                # Check throughput and adjust workers
                processed = len(successes) + len(failures)
                if processed >= last_check_idx + check_interval and processed > check_interval:
                    elapsed = time.time() - batch_start_time
                    current_throughput = processed / elapsed if elapsed > 0 else 0
                    if baseline_throughput is None and processed >= check_interval * 2:
                        baseline_throughput = current_throughput
                        tqdm.write(f"Baseline: {baseline_throughput:.2f} ds/sec")
                    # Halve the worker count when throughput drops well below
                    # baseline (likely memory pressure / thrashing).
                    if baseline_throughput and current_throughput < baseline_throughput * slowdown_threshold:
                        if current_workers > min_workers:
                            old_workers = current_workers
                            current_workers = max(min_workers, current_workers // 2)
                            tqdm.write(f"⚠️ Slowdown detected. Workers: {old_workers} → {current_workers}")
                            # Reset so a new baseline is measured at the new size.
                            baseline_throughput = None
                    last_check_idx = processed
                # Process batch
                executor = concurrent.futures.ProcessPoolExecutor(max_workers=current_workers)
                futures = {}
                try:
                    for dataset in batch:
                        # Get chunk size for this dataset
                        chunk_size, _, _ = get_dataset_params(dataset)
                        future = executor.submit(
                            process_dataset_simple,
                            dataset["dataset_path"],
                            dataset["n_obs"],
                            dataset["n_vars"],
                            chunk_size,
                            max_meta_cols,
                            max_categories,
                        )
                        futures[future] = dataset
                    for future in concurrent.futures.as_completed(futures):
                        dataset = futures[future]
                        ds_path = dataset["dataset_path"]
                        ds_name = Path(ds_path).name
                        try:
                            row = future.result(timeout=3600)
                            # File size
                            try:
                                row["file_size_gib"] = round(Path(ds_path).stat().st_size / (1024 ** 3), 4)
                            except Exception:
                                pass
                            # Save JSON
                            try:
                                payload_name = safe_name(Path(ds_path)) + ".json"
                                (per_dataset_dir / payload_name).write_text(json.dumps(row, indent=2))
                            except Exception as exc:
                                row["save_error"] = str(exc)
                            if row.get("status") == "ok":
                                successes.append(row)
                                elapsed = row.get("elapsed_sec", "?")
                                tqdm.write(f" [{len(successes)}/{total_datasets}] ✓ {ds_name[:50]} | {elapsed}s")
                            else:
                                failures.append(row)
                                error = row.get("error", "Unknown")[:60]
                                tqdm.write(f" [{len(successes) + len(failures)}/{total_datasets}] ✗ {ds_name[:50]} | {error}")
                        except concurrent.futures.TimeoutError:
                            failures.append({
                                "dataset_path": ds_path,
                                "dataset_file": ds_name,
                                "status": "failed",
                                "error": "Timeout",
                            })
                            tqdm.write(f" [{len(successes) + len(failures)}/{total_datasets}] ✗ {ds_name[:50]} | Timeout")
                        except Exception as exc:
                            failures.append({
                                "dataset_path": ds_path,
                                "dataset_file": ds_name,
                                "status": "failed",
                                "error": str(exc),
                            })
                            tqdm.write(f" [{len(successes) + len(failures)}/{total_datasets}] ✗ {ds_name[:50]} | {exc}")
                        finally:
                            pbar.update(1)
                finally:
                    # Recycle the pool between batches to cap worker memory growth.
                    executor.shutdown(wait=True)
                    gc.collect()
                    time.sleep(1)
        # NOTE(review): this summary uses an O(n^2) membership scan and equality
        # (not identity) against the last small_count entries — counts can be
        # off; simpler to snapshot len(successes)/len(failures) before Phase 1.
        print(f"\nPhase 1 complete: {len([s for s in successes if s in successes[-small_count:]])} ok, " +
              f"{len([f for f in failures if f in failures[-small_count:]])} failed\n")

    # ========================================================================
    # Phase 2: Process medium/large datasets with Dask
    # ========================================================================
    if dask_count > 0 and client:
        print(f"{'='*80}")
        print(f"PHASE 2: Medium/Large datasets ({dask_count}) - Dask + slicing")
        print(f"{'='*80}\n")
        with tqdm(
            total=dask_count, desc="Med/Large datasets", position=0, leave=True, ncols=100
        ) as dataset_pbar:
            for ds_local_idx, dataset in enumerate(dask_datasets):
                dataset_idx = small_count + ds_local_idx
                ds_path = dataset["dataset_path"]
                ds_name = Path(ds_path).name
                n_obs = dataset["n_obs"]
                n_vars = dataset["n_vars"]
                total_entries = dataset["total_entries"]
                # Get size-specific parameters
                chunk_size, obs_slice_size, size_cat = get_dataset_params(dataset)
                t0 = time.time()
                # Create slice tasks with adjusted slice size
                slice_tasks = create_slice_tasks(dataset, obs_slice_size, small_threshold)
                n_slices = len(slice_tasks)
                dataset_pbar.set_description(f"Med/Large [{ds_local_idx + 1}/{dask_count}] ({size_cat})")
                # Submit all slices for this dataset
                slice_results: list[SliceResult] = []
                failed_slices: list[tuple[str, int, int]] = []
                # Submit slice tasks to Dask. The lambda closes over chunk_size;
                # safe here because futures are fully consumed within this
                # loop iteration before chunk_size is rebound.
                futures = client.map(
                    lambda t: process_slice(t[0], t[1], t[2], chunk_size),
                    slice_tasks,
                    pure=False,
                )
                # Collect results with progress bar (show for sliced datasets)
                show_slice_bar = n_slices > 1
                slice_pbar = tqdm(
                    total=n_slices, desc=f" \u2514\u2500 Slices", position=1, leave=False, ncols=100,
                    disable=not show_slice_bar
                ) if show_slice_bar else None
                if slice_pbar:
                    slice_pbar.set_postfix(ok=0, fail=0)
                for task, future in zip(slice_tasks, futures):
                    try:
                        sr = future.result(timeout=3600)
                        if sr.status == "ok":
                            slice_results.append(sr)
                        else:
                            failed_slices.append(task)
                    except Exception:
                        failed_slices.append(task)
                    finally:
                        if slice_pbar:
                            slice_pbar.set_postfix(ok=len(slice_results), fail=len(failed_slices))
                            slice_pbar.update(1)
                if slice_pbar:
                    slice_pbar.close()
                # Retry failed slices
                for retry in range(max_retries):
                    if not failed_slices:
                        break
                    tqdm.write(f" [{dataset_idx + 1}/{total_datasets}] Retry {retry + 1}/{max_retries}: {len(failed_slices)} failed slices")
                    time.sleep(1)
                    retry_futures = client.map(
                        lambda t: process_slice(t[0], t[1], t[2], chunk_size),
                        failed_slices,
                        pure=False,
                    )
                    next_failed = []
                    for task, future in zip(failed_slices, retry_futures):
                        try:
                            sr = future.result(timeout=3600)
                            if sr.status == "ok":
                                slice_results.append(sr)
                            else:
                                next_failed.append(task)
                        except Exception:
                            next_failed.append(task)
                    failed_slices = next_failed
                # EMERGENCY MODE: If still failing after all retries, use extreme settings
                if failed_slices and len(failed_slices) > 0:
                    emergency_chunk = max(10000, chunk_size // 10)  # Use 10% of original or 10K min
                    tqdm.write(f" [{dataset_idx + 1}/{total_datasets}] ⚠️ EMERGENCY MODE: {len(failed_slices)} slices with extreme settings (chunk={emergency_chunk:,})")
                    time.sleep(2)
                    # Process failed slices one at a time with minimal chunk size
                    emergency_ok = 0
                    for task in failed_slices:
                        try:
                            future = client.submit(
                                process_slice,
                                task[0],
                                task[1],
                                task[2],
                                emergency_chunk,
                                pure=False,
                            )
                            sr = future.result(timeout=7200)  # 2 hour timeout for extreme cases
                            if sr.status == "ok":
                                slice_results.append(sr)
                                emergency_ok += 1
                        except Exception as e:
                            tqdm.write(f" Emergency failed for slice {task[1]}-{task[2]}: {str(e)[:100]}")
                            continue
                    if emergency_ok > 0:
                        tqdm.write(f" [{dataset_idx + 1}/{total_datasets}] ✓ Emergency mode recovered {emergency_ok}/{len(failed_slices)} slices")
                    # Update failed_slices to only those that still failed
                    failed_slices = [t for t in failed_slices if not any(
                        sr.slice_start == t[1] and sr.slice_end == t[2] for sr in slice_results
                    )]
                # Check if we got enough data
                ok_count = len(slice_results)
                fail_count = len(failed_slices)
                elapsed = round(time.time() - t0, 1)
                if ok_count == 0:
                    tqdm.write(f" [{dataset_idx + 1}/{total_datasets}] ✗ FAILED: {ds_name[:50]} | all {n_slices} slices failed | {elapsed}s")
                    failures.append({
                        "dataset_path": ds_path,
                        "dataset_file": ds_name,
                        "status": "failed",
                        "error": f"All {n_slices} slices failed",
                        "elapsed_sec": elapsed,
                    })
                    dataset_pbar.update(1)
                    continue
                # Merge slice results into dataset summary
                row = merge_slice_results(slice_results, n_obs, n_vars)
                row["dataset_path"] = ds_path
                row["dataset_file"] = ds_name
                row["n_slices_total"] = n_slices
                row["n_slices_ok"] = ok_count
                row["n_slices_failed"] = fail_count
                # File size
                try:
                    row["file_size_gib"] = round(Path(ds_path).stat().st_size / (1024 ** 3), 4)
                except Exception:
                    pass
                # Extract metadata (lightweight, on scheduler)
                meta = extract_metadata_on_scheduler(
                    Path(ds_path), max_meta_cols, max_categories
                )
                row.update(meta)
                # "partial" means some slices never succeeded; stats cover only
                # the rows that were processed.
                row["status"] = "ok" if fail_count == 0 else "partial"
                row["elapsed_sec"] = elapsed
                # Save per-dataset JSON
                try:
                    payload_name = safe_name(Path(ds_path)) + ".json"
                    (per_dataset_dir / payload_name).write_text(json.dumps(row, indent=2))
                except Exception as exc:
                    row["save_error"] = str(exc)
                successes.append(row)
                status = "✓" if fail_count == 0 else "⚠"
                tqdm.write(f" [{dataset_idx + 1}/{total_datasets}] {status} {ds_name[:50]} | {ok_count}/{n_slices} slices | {elapsed}s")
                # Free memory
                del slice_results
                gc.collect()
                # Update dataset progress
                dataset_pbar.update(1)
        print(f"\nPhase 2 complete\n")

    # Close Dask cluster before Phase 3 (xlarge direct processing doesn't use Dask)
    if xlarge_count > 0 and client:
        print("Closing Dask cluster before Phase 3 (xlarge datasets process directly)...")
        try:
            # NOTE(review): closes and unbinds the caller-owned client; any
            # later use of `client` in this scope would raise NameError.
            client.close()
            del client
            gc.collect()
            time.sleep(2)
        except Exception as e:
            print(f"Warning: Error closing Dask client: {e}")

    # ========================================================================
    # Phase 3: Process xlarge datasets DIRECTLY (skip Dask - causes failures)
    # ========================================================================
    if xlarge_count > 0:
        print(f"{'='*80}")
        print(f"PHASE 3: XLarge datasets ({xlarge_count}) - Direct processing (no Dask)")
        print(f"{'='*80}\n")
        with tqdm(
            total=xlarge_count, desc="XLarge datasets", position=0, leave=True, ncols=100
        ) as dataset_pbar:
            for ds_local_idx, dataset in enumerate(xlarge_datasets):
                dataset_idx = small_count + dask_count + ds_local_idx
                ds_path = dataset["dataset_path"]
                ds_name = Path(ds_path).name
                n_obs = dataset["n_obs"]
                n_vars = dataset["n_vars"]
                # Get xlarge-specific parameters
                chunk_size, obs_slice_size, size_cat = get_dataset_params(dataset)
                t0 = time.time()
                # Create slice tasks
                slice_tasks = create_slice_tasks(dataset, obs_slice_size, small_threshold)
                n_slices = len(slice_tasks)
                tqdm.write(f" [{dataset_idx + 1}/{total_datasets}] Processing {ds_name[:50]} | {n_slices} slices | chunk={chunk_size:,}")
                # Process slices DIRECTLY without Dask (one at a time)
                slice_results: list[SliceResult] = []
                for slice_idx, (path, start, end) in enumerate(slice_tasks):
                    try:
                        sr = process_slice(path, start, end, chunk_size)
                        if sr.status == "ok":
                            slice_results.append(sr)
                        else:
                            tqdm.write(f" Slice {slice_idx+1}/{n_slices} failed: {sr.error}")
                    except Exception as e:
                        tqdm.write(f" Slice {slice_idx+1}/{n_slices} error: {str(e)[:100]}")
                ok_count = len(slice_results)
                fail_count = n_slices - ok_count
                elapsed = round(time.time() - t0, 1)
                if ok_count == 0:
                    tqdm.write(f" [{dataset_idx + 1}/{total_datasets}] ✗ FAILED: {ds_name[:50]} | all slices failed | {elapsed}s")
                    failures.append({
                        "dataset_path": ds_path,
                        "dataset_file": ds_name,
                        "status": "failed",
                        "error": f"All {n_slices} slices failed (xlarge direct mode)",
                        "elapsed_sec": elapsed,
                    })
                    dataset_pbar.update(1)
                    continue
                # Merge results
                row = merge_slice_results(slice_results, n_obs, n_vars)
                row["dataset_path"] = ds_path
                row["dataset_file"] = ds_name
                row["n_slices_total"] = n_slices
                row["n_slices_ok"] = ok_count
                row["n_slices_failed"] = fail_count
                row["processing_mode"] = "xlarge_direct"
                # File size
                try:
                    row["file_size_gib"] = round(Path(ds_path).stat().st_size / (1024 ** 3), 4)
                except
Exception: pass # Extract metadata meta = extract_metadata_on_scheduler( Path(ds_path), max_meta_cols, max_categories ) row.update(meta) row["status"] = "ok" if fail_count == 0 else "partial" row["elapsed_sec"] = elapsed # Save per-dataset JSON try: payload_name = safe_name(Path(ds_path)) + ".json" (per_dataset_dir / payload_name).write_text(json.dumps(row, indent=2)) except Exception as exc: row["save_error"] = str(exc) successes.append(row) status = "✓" if fail_count == 0 else "⚠" tqdm.write(f" [{dataset_idx + 1}/{total_datasets}] {status} {ds_name[:50]} | {ok_count}/{n_slices} slices | {elapsed}s") dataset_pbar.update(1) gc.collect() print(f"\nPhase 3 complete\n") return successes, failures # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- def main() -> None: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("--config", type=Path, required=True, help="YAML config") parser.add_argument("--num-shards", type=int, help="Override num_shards") parser.add_argument("--shard-index", type=int, help="Override shard_index") parser.add_argument("--max-retries", type=int, default=3, help="Max retries per slice") args = parser.parse_args() config = load_config(args.config) if args.num_shards is not None: config["sharding"]["num_shards"] = args.num_shards config["sharding"]["enabled"] = args.num_shards > 1 if args.shard_index is not None: config["sharding"]["shard_index"] = args.shard_index num_shards = config["sharding"]["num_shards"] shard_index = config["sharding"]["shard_index"] configure_dask_for_hpc() # Load metadata cache_path = Path(config["paths"]["enhanced_metadata_cache"]) if not cache_path.is_absolute(): cache_path = Path(args.config).parent.parent / cache_path print(f"Loading metadata from: {cache_path}") metadata_df = load_enhanced_metadata(cache_path) datasets = get_datasets_for_shard(metadata_df, config, num_shards, 
shard_index) if not datasets: print("No datasets scheduled for this shard.") return # Output dirs output_dir = Path(config["paths"]["output_dir"]) if not output_dir.is_absolute(): output_dir = Path(args.config).parent.parent / output_dir output_dir.mkdir(parents=True, exist_ok=True) per_dataset_dir = output_dir / "per_dataset" per_dataset_dir.mkdir(parents=True, exist_ok=True) # Filter out already-completed datasets (resume capability) def is_dataset_done(ds_path: str) -> bool: """Check if dataset already has a successful result.""" try: payload_name = safe_name(Path(ds_path)) + ".json" result_file = per_dataset_dir / payload_name if result_file.exists(): result_data = json.loads(result_file.read_text()) return result_data.get("status") == "ok" except Exception: pass return False original_count = len(datasets) datasets = [d for d in datasets if not is_dataset_done(d["dataset_path"])] skipped_count = original_count - len(datasets) if skipped_count > 0: print(f"\n{'='*80}") print(f"RESUME MODE: Skipping {skipped_count} already-completed datasets") print(f"Remaining to process: {len(datasets)}") print(f"{'='*80}\n") if not datasets: print("All datasets already completed. 
Nothing to do.") return # Check if we need Dask cluster (for medium/large datasets) small_threshold = config["dataset_thresholds"]["small"] # Count datasets by processing type small_count_init = sum(1 for d in datasets if d.get("total_entries", 0) < small_threshold) dask_count_init = sum(1 for d in datasets if d.get("total_entries", 0) >= small_threshold and d.get("size_category", "large") != "xlarge") xlarge_count_init = sum(1 for d in datasets if d.get("size_category", "large") == "xlarge") client = None cluster = None if dask_count_init > 0: # Cluster setup for large datasets max_memory_gib = config["resources"]["max_memory_gib"] max_workers = config["resources"]["max_workers"] min_workers = config["resources"].get("min_workers", min(4, max_workers)) threads_per_worker = config["resources"].get("threads_per_worker", 1) # Adaptive scaling config adaptive_config = config["resources"].get("adaptive_scaling", {}) target_duration = adaptive_config.get("target_duration", "30s") wait_count = adaptive_config.get("wait_count", 3) interval = adaptive_config.get("interval", "2s") memory_per_worker_gib = max(2.0, max_memory_gib / max_workers) total_entries = sum(d["total_entries"] for d in datasets) total_slices = sum( max(1, math.ceil(d["n_obs"] / config["slicing"].get("obs_slice_size", 75_000))) for d in datasets if d.get("total_entries", 0) >= small_threshold and d.get("size_category", "large") != "xlarge" ) print(json.dumps({ "total_datasets": len(datasets), "small_datasets": small_count_init, "large_datasets": dask_count_init + xlarge_count_init, "total_slices": total_slices, "total_entries": total_entries, "shard_index": shard_index, "num_shards": num_shards, "memory_per_worker_gib": round(memory_per_worker_gib, 1), "max_workers": max_workers, }, indent=2)) print(f"\nStarting Dask LocalCluster (for {dask_count_init} medium/large datasets):") print(f" Workers: {min_workers} -> {max_workers} (adaptive)") print(f" Memory per worker: {memory_per_worker_gib:.1f} GiB") 
print(f" Total memory budget: {max_memory_gib} GiB\n") cluster = LocalCluster( n_workers=min_workers, threads_per_worker=threads_per_worker, processes=True, memory_limit=f"{memory_per_worker_gib}GiB", silence_logs=True, dashboard_address=None, lifetime="180 minutes", lifetime_stagger="20 minutes", ) cluster.adapt( minimum=min_workers, maximum=max_workers, target_duration=target_duration, wait_count=wait_count, interval=interval, ) client = Client(cluster) print(f"Dask cluster ready: {client}\n") else: print(f"No Dask-compatible datasets (all small or xlarge)\n") if xlarge_count_init > 0: print(f"Note: {xlarge_count_init} xlarge datasets will be processed directly (Phase 3, no Dask)\n") try: successes, failures = process_all_datasets( datasets, config, per_dataset_dir, client, max_retries=args.max_retries, ) # Include previously completed datasets in final summary if skipped_count > 0: print(f"\nLoading {skipped_count} previously completed results...") for json_file in per_dataset_dir.glob("*.json"): try: result = json.loads(json_file.read_text()) if result.get("status") == "ok": # Check if not already in successes ds_path = result.get("dataset_path", "") if not any(s.get("dataset_path") == ds_path for s in successes): successes.append(result) except Exception: pass print(f"Total results (new + previous): {len(successes)}") print(f"\n{'=' * 80}") print(f"PROCESSING COMPLETE") print(f" Succeeded: {len(successes)}") print(f" Failed: {len(failures)}") print(f" Success rate: {len(successes) / max(1, original_count) * 100:.1f}%") print(f"{'=' * 80}\n") if failures: print("WARNING: Some datasets failed permanently:") for fail in failures[:10]: print(f" - {fail['dataset_file']}: {fail.get('error', 'Unknown')[:80]}") if len(failures) > 10: print(f" ... 
and {len(failures) - 10} more") except KeyboardInterrupt: print("\n\n{'=' * 80}") print("INTERRUPTED - Saving partial results...") print(f"{'=' * 80}\n") successes = [] failures = [] seen_paths = set() # Load all completed results from disk (deduplicate by dataset_path) for json_file in per_dataset_dir.glob("*.json"): try: result = json.loads(json_file.read_text()) ds_path = result.get("dataset_path", "") if ds_path and ds_path in seen_paths: continue # Skip duplicate seen_paths.add(ds_path) if result.get("status") == "ok": successes.append(result) else: failures.append(result) except Exception: pass except Exception as exc: print(f"\n\nERROR during processing: {exc}") print("Saving partial results...") successes = [] failures = [] seen_paths = set() # Load all completed results from disk (deduplicate by dataset_path) for json_file in per_dataset_dir.glob("*.json"): try: result = json.loads(json_file.read_text()) ds_path = result.get("dataset_path", "") if ds_path and ds_path in seen_paths: continue # Skip duplicate seen_paths.add(ds_path) if result.get("status") == "ok": successes.append(result) else: failures.append(result) except Exception: pass # Always save results, even on error/interrupt try: summary_df = pd.DataFrame(successes) # Deduplicate by dataset_path (keep first/most recent) if not summary_df.empty and 'dataset_path' in summary_df.columns: original_count = len(summary_df) summary_df = summary_df.drop_duplicates(subset=['dataset_path'], keep='first') if len(summary_df) < original_count: print(f"\nRemoved {original_count - len(summary_df)} duplicate entries from results") summary_csv = output_dir / f"eda_summary_shard_{shard_index:03d}_of_{num_shards:03d}.csv" summary_df.to_csv(summary_csv, index=False) failures_path = output_dir / f"eda_failures_shard_{shard_index:03d}_of_{num_shards:03d}.json" failures_path.write_text(json.dumps(failures, indent=2)) print(f"\n{'=' * 80}") print("RESULTS SAVED") print(f" Summary CSV: {summary_csv}") print(f" Failures 
JSON: {failures_path}") print(json.dumps({ "ok_count": len(successes), "failed_count": len(failures), }, indent=2)) print(f"{'=' * 80}\n") except Exception as save_exc: print(f"ERROR saving results: {save_exc}") finally: if client: client.close() if cluster: cluster.close() if __name__ == "__main__": main()