| |
| """Distributed EDA with chunk-slice architecture for billion-scale datasets. |
| |
| Each dataset is sliced into small row-ranges. Each slice is a Dask task that |
| processes a bounded amount of memory (chunk_size * n_vars). Slice results are |
| merged per-dataset on the scheduler side with O(1) memory. |
| |
| No quantiles - only non-zero min, max, mean, sparsity, gene-level stats, |
| and metadata summaries. This handles datasets from 2 GB to 500 GB. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import concurrent.futures |
| import gc |
| import hashlib |
| import json |
| import math |
| import time |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Any |
|
|
| import anndata as ad |
| import dask |
| import numpy as np |
| import pandas as pd |
| import yaml |
| from dask.distributed import Client, LocalCluster |
| from scipy import sparse |
| from tqdm import tqdm |
|
|
|
|
| |
| |
| |
@dataclass
class SliceResult:
    """Mergeable statistics from one row-slice of a dataset.

    Every field is O(1) memory except gene arrays which are O(n_vars).
    All fields are JSON-serialisable once converted to a plain dict
    (e.g. via dataclasses.asdict).
    """

    # Identity of the slice: source file plus [slice_start, slice_end) row range.
    dataset_path: str = ""
    slice_start: int = 0
    slice_end: int = 0

    # Matrix-wide accumulators; merged across slices by summation.
    n_obs_slice: int = 0
    n_vars: int = 0
    nnz: int = 0
    x_sum: float = 0.0
    x_sum_sq: float = 0.0

    # Per-cell accumulators. Min/max start at their identity values so
    # min()/max() merging needs no special case for the first slice.
    cell_total_counts_sum: float = 0.0
    cell_total_counts_min: float = math.inf
    cell_total_counts_max: float = -math.inf
    cell_n_genes_sum: int = 0
    cell_n_genes_min: int = 2**63 - 1  # effectively +inf for int64 counts
    cell_n_genes_max: int = 0

    # Per-gene vectors of length n_vars, stored as plain lists so the
    # object stays cheaply picklable across Dask workers.
    gene_n_cells: list | None = None
    gene_total_counts: list | None = None

    # Outcome bookkeeping: "ok" or "failed", error text, and wall time.
    status: str = "ok"
    error: str = ""
    elapsed_sec: float = 0.0
|
|
|
|
def merge_slice_results(slices: list[SliceResult], n_obs: int, n_vars: int) -> dict:
    """Merge many SliceResults into one per-dataset summary dict.

    Uses O(n_vars) memory for gene arrays, everything else O(1).

    Parameters
    ----------
    slices : per-slice partial statistics; entries with status != "ok"
        are skipped entirely.
    n_obs, n_vars : full dataset dimensions (from metadata, not derived
        from the slices); used as denominators for matrix-wide stats.
    """
    # Scalar accumulators, summed / min-maxed across the ok slices.
    nnz_total = 0
    x_sum = 0.0
    x_sum_sq = 0.0
    n_obs_seen = 0

    cell_total_counts_sum = 0.0
    cell_total_counts_min = math.inf
    cell_total_counts_max = -math.inf
    cell_n_genes_sum = 0
    cell_n_genes_min = 2**63 - 1  # identity value for min() merging
    cell_n_genes_max = 0

    # Per-gene accumulators, the only O(n_vars) allocations here.
    gene_n_cells = np.zeros(n_vars, dtype=np.int64)
    gene_total_counts = np.zeros(n_vars, dtype=np.float64)

    for s in slices:
        if s.status != "ok":
            continue  # failed slices contribute nothing
        n_obs_seen += s.n_obs_slice
        nnz_total += s.nnz
        x_sum += s.x_sum
        x_sum_sq += s.x_sum_sq

        cell_total_counts_sum += s.cell_total_counts_sum
        cell_total_counts_min = min(cell_total_counts_min, s.cell_total_counts_min)
        cell_total_counts_max = max(cell_total_counts_max, s.cell_total_counts_max)
        cell_n_genes_sum += s.cell_n_genes_sum
        cell_n_genes_min = min(cell_n_genes_min, s.cell_n_genes_min)
        cell_n_genes_max = max(cell_n_genes_max, s.cell_n_genes_max)

        if s.gene_n_cells is not None:
            gene_n_cells += np.asarray(s.gene_n_cells, dtype=np.int64)
        if s.gene_total_counts is not None:
            gene_total_counts += np.asarray(s.gene_total_counts, dtype=np.float64)

    # Matrix-wide stats are over ALL entries, zeros included.
    # NOTE(review): denominators use the full n_obs * n_vars even when some
    # slices failed (n_obs_processed < n_obs), which biases mean/std slightly
    # low in the "partial" case — confirm this is intended.
    total_entries = n_obs * n_vars
    row: dict[str, Any] = {
        "n_obs": n_obs,
        "n_vars": n_vars,
        "n_obs_processed": n_obs_seen,
        "nnz": int(nnz_total),
        "sparsity": float(1.0 - nnz_total / total_entries) if total_entries else None,
        "x_mean": float(x_sum / total_entries) if total_entries else None,
    }
    if total_entries:
        # Population variance via E[x^2] - E[x]^2, clamped at 0 against
        # floating-point cancellation.
        var = max(0.0, x_sum_sq / total_entries - (x_sum / total_entries) ** 2)
        row["x_std"] = float(math.sqrt(var))
    else:
        row["x_std"] = None

    # Per-cell stats use the rows actually processed as denominator.
    if n_obs_seen > 0:
        row["cell_total_counts_min"] = float(cell_total_counts_min)
        row["cell_total_counts_max"] = float(cell_total_counts_max)
        row["cell_total_counts_mean"] = float(cell_total_counts_sum / n_obs_seen)
        row["cell_n_genes_detected_min"] = int(cell_n_genes_min)
        row["cell_n_genes_detected_max"] = int(cell_n_genes_max)
        row["cell_n_genes_detected_mean"] = float(cell_n_genes_sum / n_obs_seen)
    else:
        row["cell_total_counts_min"] = None
        row["cell_total_counts_max"] = None
        row["cell_total_counts_mean"] = None
        row["cell_n_genes_detected_min"] = None
        row["cell_n_genes_detected_max"] = None
        row["cell_n_genes_detected_mean"] = None

    # Gene-level stats restricted to genes seen in at least one cell.
    genes_detected = int(np.count_nonzero(gene_n_cells))
    row["genes_detected_in_any_cell"] = genes_detected
    row["genes_detected_in_any_cell_pct"] = float(genes_detected / n_vars * 100) if n_vars else 0.0
    if genes_detected > 0:
        mask = gene_n_cells > 0
        row["gene_n_cells_min"] = int(gene_n_cells[mask].min())
        row["gene_n_cells_max"] = int(gene_n_cells[mask].max())
        row["gene_n_cells_mean"] = float(gene_n_cells[mask].mean())
        row["gene_total_counts_min"] = float(gene_total_counts[mask].min())
        row["gene_total_counts_max"] = float(gene_total_counts[mask].max())
        row["gene_total_counts_mean"] = float(gene_total_counts[mask].mean())
    else:
        for k in ("gene_n_cells_min", "gene_n_cells_max", "gene_n_cells_mean",
                  "gene_total_counts_min", "gene_total_counts_max", "gene_total_counts_mean"):
            row[k] = 0

    # Drop the O(n_vars) arrays before returning the O(1) summary.
    del gene_n_cells, gene_total_counts
    return row
|
|
|
|
| |
| |
| |
def process_dataset_simple(
    path_str: str,
    n_obs: int,
    n_vars: int,
    chunk_size: int,
    max_meta_cols: int,
    max_categories: int,
) -> dict:
    """Process entire small dataset in one worker (no slicing, no Dask).

    Streams X in ``chunk_size``-row chunks, accumulating matrix-wide,
    per-cell and per-gene statistics, then attaches obs/var metadata
    summaries. Returns a JSON-serialisable dict; on any exception the
    dict instead carries status="failed" plus the error string.

    Parameters
    ----------
    path_str : path to the .h5ad file.
    n_obs, n_vars : dataset dimensions (from the metadata cache).
    chunk_size : rows read per iteration; bounds peak memory.
    max_meta_cols, max_categories : limits forwarded to summarize_metadata().
    """
    t0 = time.time()
    path = Path(path_str)
    row: dict[str, Any] = {
        "dataset_path": path_str,
        "dataset_file": path.name,
        "n_obs": n_obs,
        "n_vars": n_vars,
    }

    try:
        adata = ad.read_h5ad(path, backed="r")
        total_entries = n_obs * n_vars

        nnz_total = 0
        x_sum = 0.0
        x_sum_sq = 0.0

        # Per-cell accumulators; min/max start at identity values.
        cell_total_counts_sum = 0.0
        cell_total_counts_min = math.inf
        cell_total_counts_max = -math.inf
        cell_n_genes_sum = 0
        cell_n_genes_min = 2**63 - 1
        cell_n_genes_max = 0

        # Per-gene accumulators, O(n_vars).
        gene_n_cells = np.zeros(n_vars, dtype=np.int64)
        gene_total_counts = np.zeros(n_vars, dtype=np.float64)

        for start in range(0, n_obs, chunk_size):
            end = min(start + chunk_size, n_obs)
            chunk = adata.X[start:end, :]

            if sparse.issparse(chunk):
                csr = chunk.tocsr() if not sparse.isspmatrix_csr(chunk) else chunk
                data = csr.data.astype(np.float64, copy=False)

                nnz_total += int(csr.nnz)
                x_sum += float(data.sum())
                x_sum_sq += float(np.square(data).sum())

                # Row sums and stored-entries-per-row come straight from CSR.
                cell_counts = np.asarray(csr.sum(axis=1)).ravel()
                cell_genes = np.diff(csr.indptr).astype(np.int64)

                cell_total_counts_sum += float(cell_counts.sum())
                cell_total_counts_min = min(cell_total_counts_min, float(cell_counts.min()))
                cell_total_counts_max = max(cell_total_counts_max, float(cell_counts.max()))
                cell_n_genes_sum += int(cell_genes.sum())
                cell_n_genes_min = min(cell_n_genes_min, int(cell_genes.min()))
                cell_n_genes_max = max(cell_n_genes_max, int(cell_genes.max()))

                # Column-wise stats via bincount over the CSR column indices.
                # Equivalent to converting each chunk to CSC and summing, but
                # avoids that O(nnz) copy and matches process_slice().
                gene_n_cells += np.bincount(csr.indices, minlength=n_vars).astype(np.int64)
                gene_total_counts += np.bincount(csr.indices, weights=data, minlength=n_vars)

                del csr, data
            else:
                arr = np.asarray(chunk, dtype=np.float64)
                nz = arr != 0

                nnz_total += int(nz.sum())
                x_sum += float(arr.sum())
                x_sum_sq += float(np.square(arr).sum())

                cell_counts = arr.sum(axis=1)
                cell_genes = nz.sum(axis=1).astype(np.int64)

                cell_total_counts_sum += float(cell_counts.sum())
                cell_total_counts_min = min(cell_total_counts_min, float(cell_counts.min()))
                cell_total_counts_max = max(cell_total_counts_max, float(cell_counts.max()))
                cell_n_genes_sum += int(cell_genes.sum())
                cell_n_genes_min = min(cell_n_genes_min, int(cell_genes.min()))
                cell_n_genes_max = max(cell_n_genes_max, int(cell_genes.max()))

                gene_n_cells += nz.sum(axis=0).astype(np.int64)
                gene_total_counts += arr.sum(axis=0)

                del arr, nz

            # Backed reads create temporaries; free them before the next chunk.
            del chunk
            gc.collect()

        # Matrix-wide summary (zeros included in mean/std denominators).
        row["nnz"] = int(nnz_total)
        row["sparsity"] = float(1.0 - nnz_total / total_entries) if total_entries else None
        row["x_mean"] = float(x_sum / total_entries) if total_entries else None
        if total_entries:
            # Population variance via E[x^2] - E[x]^2, clamped against FP error.
            var = max(0.0, x_sum_sq / total_entries - (x_sum / total_entries) ** 2)
            row["x_std"] = float(math.sqrt(var))
        else:
            row["x_std"] = None

        # Per-cell summary.
        if n_obs > 0:
            row["cell_total_counts_min"] = float(cell_total_counts_min)
            row["cell_total_counts_max"] = float(cell_total_counts_max)
            row["cell_total_counts_mean"] = float(cell_total_counts_sum / n_obs)
            row["cell_n_genes_detected_min"] = int(cell_n_genes_min)
            row["cell_n_genes_detected_max"] = int(cell_n_genes_max)
            row["cell_n_genes_detected_mean"] = float(cell_n_genes_sum / n_obs)
        else:
            row["cell_total_counts_min"] = None
            row["cell_total_counts_max"] = None
            row["cell_total_counts_mean"] = None
            row["cell_n_genes_detected_min"] = None
            row["cell_n_genes_detected_max"] = None
            row["cell_n_genes_detected_mean"] = None

        # Per-gene summary, restricted to genes detected in >= 1 cell.
        genes_detected = int(np.count_nonzero(gene_n_cells))
        row["genes_detected_in_any_cell"] = genes_detected
        row["genes_detected_in_any_cell_pct"] = float(genes_detected / n_vars * 100) if n_vars else 0.0
        if genes_detected > 0:
            mask = gene_n_cells > 0
            row["gene_n_cells_min"] = int(gene_n_cells[mask].min())
            row["gene_n_cells_max"] = int(gene_n_cells[mask].max())
            row["gene_n_cells_mean"] = float(gene_n_cells[mask].mean())
            row["gene_total_counts_min"] = float(gene_total_counts[mask].min())
            row["gene_total_counts_max"] = float(gene_total_counts[mask].max())
            row["gene_total_counts_mean"] = float(gene_total_counts[mask].mean())
        else:
            for k in ("gene_n_cells_min", "gene_n_cells_max", "gene_n_cells_mean",
                      "gene_total_counts_min", "gene_total_counts_max", "gene_total_counts_mean"):
                row[k] = 0

        # Metadata summaries (cheap: obs/var only, no X access).
        row["obs_columns"] = int(len(adata.obs.columns))
        row["var_columns"] = int(len(adata.var.columns))
        row["metadata_obs_summary"] = summarize_metadata(
            adata.obs, max_cols=max_meta_cols, max_categories=max_categories
        )
        row["metadata_var_summary"] = summarize_metadata(
            adata.var, max_cols=max_meta_cols, max_categories=max_categories
        )
        row["obs_schema"] = extract_schema(adata.obs)
        row["var_schema"] = extract_schema(adata.var)

        # Free the O(n_vars) arrays and best-effort close the HDF5 handle.
        del gene_n_cells, gene_total_counts
        try:
            if hasattr(adata, "file") and adata.file is not None:
                adata.file.close()
        except Exception:
            pass
        del adata

        row["status"] = "ok"
        # Slice bookkeeping kept for schema parity with the sliced path.
        row["n_slices_total"] = 1
        row["n_slices_ok"] = 1
        row["n_slices_failed"] = 0

    except Exception as exc:
        row["status"] = "failed"
        row["error"] = str(exc)

    gc.collect()
    row["elapsed_sec"] = round(time.time() - t0, 2)
    return row
|
|
|
|
| |
| |
| |
def process_slice(
    path_str: str,
    obs_start: int,
    obs_end: int,
    chunk_size: int,
) -> SliceResult:
    """Process rows [obs_start, obs_end) of a dataset.

    Memory usage bounded by: chunk_size * n_vars * ~12 bytes * 3x overhead.

    Opens the .h5ad file in backed mode, streams the row range in
    ``chunk_size`` pieces, and accumulates everything into a SliceResult.
    Never raises: failures are reported via result.status / result.error.
    """
    t0 = time.time()
    path = Path(path_str)
    result = SliceResult(dataset_path=path_str, slice_start=obs_start, slice_end=obs_end)

    try:
        adata = ad.read_h5ad(path, backed="r")
        n_vars = int(adata.n_vars)
        result.n_vars = n_vars
        result.n_obs_slice = obs_end - obs_start

        # Per-gene accumulators, O(n_vars); converted to lists at the end.
        gene_n_cells = np.zeros(n_vars, dtype=np.int64)
        gene_total_counts = np.zeros(n_vars, dtype=np.float64)

        for start in range(obs_start, obs_end, chunk_size):
            end = min(start + chunk_size, obs_end)
            chunk = adata.X[start:end, :]

            if sparse.issparse(chunk):
                csr = chunk.tocsr() if not sparse.isspmatrix_csr(chunk) else chunk
                data = csr.data.astype(np.float64, copy=False)

                result.nnz += int(csr.nnz)
                result.x_sum += float(data.sum())
                result.x_sum_sq += float(np.square(data).sum())

                # Row sums and stored-entries-per-row direct from CSR structure.
                cell_counts = np.asarray(csr.sum(axis=1)).ravel()
                cell_genes = np.diff(csr.indptr).astype(np.int64)

                # Column-wise accumulation via bincount over CSR column
                # indices; avoids a per-chunk CSC conversion.
                gene_total_counts += np.bincount(
                    csr.indices,
                    weights=data,
                    minlength=n_vars
                )
                gene_n_cells += np.bincount(
                    csr.indices,
                    minlength=n_vars
                )

                del csr, data
            else:
                arr = np.asarray(chunk, dtype=np.float64)
                nz = arr != 0

                result.nnz += int(nz.sum())
                result.x_sum += float(arr.sum())
                result.x_sum_sq += float(np.square(arr).sum())

                cell_counts = arr.sum(axis=1)
                cell_genes = nz.sum(axis=1).astype(np.int64)

                gene_n_cells += nz.sum(axis=0).astype(np.int64)
                gene_total_counts += arr.sum(axis=0)

                del arr, nz

            # Shared per-cell accumulation for both sparse and dense branches.
            result.cell_total_counts_sum += float(cell_counts.sum())
            result.cell_total_counts_min = min(result.cell_total_counts_min, float(cell_counts.min()))
            result.cell_total_counts_max = max(result.cell_total_counts_max, float(cell_counts.max()))
            result.cell_n_genes_sum += int(cell_genes.sum())
            result.cell_n_genes_min = min(result.cell_n_genes_min, int(cell_genes.min()))
            result.cell_n_genes_max = max(result.cell_n_genes_max, int(cell_genes.max()))

            # Free chunk temporaries before the next read to cap peak RSS.
            del chunk, cell_counts, cell_genes
            gc.collect()

        # Lists, not ndarrays: keeps the SliceResult cheaply serialisable.
        result.gene_n_cells = gene_n_cells.tolist()
        result.gene_total_counts = gene_total_counts.tolist()
        del gene_n_cells, gene_total_counts

        # Best-effort close of the backing HDF5 handle.
        try:
            if hasattr(adata, "file") and adata.file is not None:
                adata.file.close()
        except Exception:
            pass
        del adata

    except Exception as exc:
        result.status = "failed"
        result.error = str(exc)

    gc.collect()
    result.elapsed_sec = round(time.time() - t0, 2)
    return result
|
|
|
|
| |
| |
| |
def safe_name(path: Path) -> str:
    """Build a filesystem-safe, collision-resistant name for *path*.

    The name is the (space-sanitised, 80-char-capped) stem followed by a
    10-hex-digit MD5 tag of the full path, so distinct paths with equal
    stems do not collide.
    """
    tag = hashlib.md5(str(path).encode("utf-8"), usedforsecurity=False).hexdigest()[:10]
    base = path.stem.replace(" ", "_")[:80]
    return f"{base}_{tag}"
|
|
|
|
def summarize_metadata(df: pd.DataFrame, max_cols: int, max_categories: int) -> dict[str, dict]:
    """Summarize DataFrame metadata with top categories.

    Selects at most ``max_cols`` columns — well-known single-cell
    annotation columns first, then remaining columns in DataFrame order —
    and reports per column: dtype, missing-value fraction, and for
    categorical/string columns the unique count plus the
    ``max_categories`` most frequent values.

    Fix: the previous version never truncated the preferred columns to
    ``max_cols`` and checked the cap only after appending, so the result
    could exceed ``max_cols``; selection is now capped strictly.
    """
    if df.empty:
        return {}

    # Well-known annotation columns get priority; cap immediately so a
    # fully-annotated DataFrame cannot exceed max_cols on its own.
    preferred = ["cell_type", "assay", "tissue", "disease", "sex", "donor_id"]
    selected: list[str] = [c for c in preferred if c in df.columns][:max_cols]
    for col in df.columns:
        if len(selected) >= max_cols:
            break  # check BEFORE appending so we never overshoot the cap
        if col not in selected:
            selected.append(col)

    out: dict[str, dict] = {}
    n_rows = max(1, len(df))  # guard against division by zero
    for col in selected:
        s = df[col]
        summary: dict[str, Any] = {
            "dtype": str(s.dtype),
            "missing_fraction": float(s.isna().sum()) / n_rows,
        }
        if isinstance(s.dtype, pd.CategoricalDtype):
            summary["n_unique"] = int(len(s.cat.categories))
            vc = s.value_counts(dropna=False).head(max_categories)
            summary["top_values"] = {str(k): int(v) for k, v in vc.items()}
        elif pd.api.types.is_string_dtype(s.dtype) or s.dtype == object:
            s_str = s.dropna().astype(str)
            summary["n_unique"] = int(s_str.nunique())
            vc = s_str.value_counts(dropna=False).head(max_categories)
            summary["top_values"] = {str(k): int(v) for k, v in vc.items()}
        out[col] = summary
    return out
|
|
|
|
def extract_schema(df: pd.DataFrame) -> dict[str, object]:
    """Extract DataFrame schema: column count, names, and dtype strings."""
    names = [str(c) for c in df.columns]
    dtype_map = {str(c): str(df[c].dtype) for c in df.columns}
    return {"n_columns": len(names), "columns": names, "dtypes": dtype_map}
|
|
|
|
def extract_metadata_on_scheduler(
    path: Path,
    max_meta_cols: int,
    max_categories: int,
) -> dict:
    """Extract obs/var metadata. Runs on scheduler (lightweight, no X access).

    Returns the summary dict on success, or {"metadata_error": ...} if
    anything goes wrong while opening or summarising the file.
    """
    try:
        adata = ad.read_h5ad(path, backed="r")
        obs, var = adata.obs, adata.var
        summary = {
            "obs_columns": int(len(obs.columns)),
            "var_columns": int(len(var.columns)),
            "metadata_obs_summary": summarize_metadata(
                obs, max_cols=max_meta_cols, max_categories=max_categories
            ),
            "metadata_var_summary": summarize_metadata(
                var, max_cols=max_meta_cols, max_categories=max_categories
            ),
            "obs_schema": extract_schema(obs),
            "var_schema": extract_schema(var),
        }
        # Best-effort close of the backing HDF5 handle.
        try:
            if hasattr(adata, "file") and adata.file is not None:
                adata.file.close()
        except Exception:
            pass
        del adata
        gc.collect()
        return summary
    except Exception as exc:
        return {"metadata_error": str(exc)}
|
|
|
|
| |
| |
| |
def configure_dask_for_hpc() -> None:
    """Configure Dask for HPC with aggressive memory management."""
    # Spill early, pause before the OS OOM-killer would act, tolerate
    # flaky workers, and keep scheduler chatter/log volume low.
    settings = {
        "distributed.worker.memory.target": 0.60,
        "distributed.worker.memory.spill": 0.70,
        "distributed.worker.memory.pause": 0.80,
        "distributed.worker.memory.terminate": 0.95,
        "distributed.worker.daemon": False,
        "distributed.worker.use-file-locking": False,
        "distributed.scheduler.allowed-failures": 10,
        "distributed.scheduler.work-stealing": True,
        "distributed.scheduler.work-stealing-interval": "100ms",
        "distributed.comm.timeouts.connect": "120s",
        "distributed.comm.timeouts.tcp": "120s",
        "distributed.admin.tick.interval": "2s",
        "distributed.admin.log-length": 500,
    }
    dask.config.set(settings)
|
|
|
|
| |
| |
| |
def load_config(config_path: Path) -> dict:
    """Parse the YAML pipeline configuration file into a dict."""
    with config_path.open() as fh:
        return yaml.safe_load(fh)
|
|
|
|
def load_enhanced_metadata(cache_path: Path) -> pd.DataFrame:
    """Load the pre-built metadata cache, failing loudly if it is absent."""
    if cache_path.exists():
        return pd.read_parquet(cache_path)
    raise FileNotFoundError(
        f"Enhanced metadata cache not found: {cache_path}\n"
        "Run: uv run python scripts/build_metadata_cache.py --config <config.yaml>"
    )
|
|
|
|
def get_datasets_for_shard(
    metadata_df: pd.DataFrame,
    config: dict,
    num_shards: int,
    shard_index: int,
) -> list[dict]:
    """Get dataset info for this shard.

    Returns list of dicts with keys: dataset_path, n_obs, n_vars, total_entries.
    """
    if num_shards > 1:
        # Largest-first ordering + round-robin index assignment roughly
        # balances total work across shards.
        ordered = metadata_df.sort_values("total_entries", ascending=False).reset_index(drop=True)
        shard_df = ordered[ordered.index % num_shards == shard_index].copy()
    else:
        shard_df = metadata_df.copy()

    # Keep only healthy datasets that fit under the size ceiling.
    ok_mask = shard_df["status"].str.startswith("ok", na=False)
    shard_df = shard_df[ok_mask].copy()
    max_entries = config["dataset_thresholds"]["max_entries"]
    shard_df = shard_df[shard_df["total_entries"] <= max_entries].copy()

    datasets: list[dict] = []
    for _, rec in shard_df.iterrows():
        resolved = Path(str(rec["dataset_path"])).resolve()
        datasets.append({
            "dataset_path": str(resolved),
            "n_obs": int(rec.get("n_obs", 0)),
            "n_vars": int(rec.get("n_vars", 0)),
            "total_entries": int(rec.get("total_entries", 0)),
            "size_category": str(rec.get("size_category", "large")),
        })
    return datasets
|
|
|
|
| |
| |
| |
def create_slice_tasks(
    dataset: dict,
    obs_slice_size: int,
    small_dataset_threshold: int,
) -> list[tuple[str, int, int]]:
    """Create (path, start, end) slice tasks for a dataset.

    Small datasets (< threshold): Single task for entire dataset (faster, no slicing overhead)
    Medium/Large datasets: Sliced into obs_slice_size chunks (memory-safe)
    """
    path = dataset["dataset_path"]
    n_obs = dataset["n_obs"]
    total_entries = dataset.get("total_entries", n_obs * dataset.get("n_vars", 0))

    if n_obs <= 0:
        # Degenerate dataset: one empty slice keeps downstream bookkeeping uniform.
        return [(path, 0, 0)]
    if total_entries < small_dataset_threshold:
        return [(path, 0, n_obs)]
    return [
        (path, lo, min(lo + obs_slice_size, n_obs))
        for lo in range(0, n_obs, obs_slice_size)
    ]
|
|
|
|
| def process_all_datasets( |
| datasets: list[dict], |
| config: dict, |
| per_dataset_dir: Path, |
| client: Client | None, |
| max_retries: int = 3, |
| ) -> tuple[list[dict], list[dict]]: |
| """Process all datasets: small ones with ProcessPoolExecutor, large ones with Dask.""" |
| base_chunk_size = config["resources"]["chunk_size"] |
| base_obs_slice_size = config["slicing"].get("obs_slice_size", 75_000) |
| obs_slice_size_xlarge = config["slicing"].get("obs_slice_size_xlarge", 150_000) |
| small_threshold = config["dataset_thresholds"]["small"] |
| max_meta_cols = config["metadata"]["max_meta_cols"] |
| max_categories = config["metadata"]["max_categories"] |
| max_workers_base = config["resources"]["max_workers"] |
| |
| |
| def get_dataset_params(dataset): |
| size_cat = dataset.get("size_category", "large") |
| strategy = config.get("strategy", {}).get(size_cat, config["strategy"]["large"]) |
| |
| chunk_mult = strategy.get("chunk_size_multiplier", 1.0) |
| chunk_size = int(base_chunk_size * chunk_mult) |
| |
| |
| if size_cat == "xlarge": |
| obs_slice = obs_slice_size_xlarge |
| else: |
| obs_slice = base_obs_slice_size |
| |
| return chunk_size, obs_slice, size_cat |
|
|
| successes = [] |
| failures = [] |
|
|
| |
| small_datasets = [d for d in datasets if d.get("total_entries", 0) < small_threshold] |
| non_small = [d for d in datasets if d.get("total_entries", 0) >= small_threshold] |
| |
| |
| dask_datasets = [d for d in non_small if d.get("size_category", "large") != "xlarge"] |
| xlarge_datasets = [d for d in non_small if d.get("size_category", "large") == "xlarge"] |
| |
| small_datasets.sort(key=lambda d: d["total_entries"]) |
| dask_datasets.sort(key=lambda d: d["total_entries"]) |
| xlarge_datasets.sort(key=lambda d: d["total_entries"]) |
| |
| datasets_sorted = small_datasets + dask_datasets + xlarge_datasets |
| small_count = len(small_datasets) |
| dask_count = len(dask_datasets) |
| xlarge_count = len(xlarge_datasets) |
| |
| datasets_sorted = small_datasets + dask_datasets + xlarge_datasets |
| small_count = len(small_datasets) |
| dask_count = len(dask_datasets) |
| xlarge_count = len(xlarge_datasets) |
| |
| print(f"\n{'=' * 80}") |
| print(f"Processing {len(datasets_sorted)} datasets") |
| print(f" Small datasets (ProcessPoolExecutor): {small_count}") |
| print(f" Medium/Large (Dask + slicing): {dask_count}") |
| print(f" XLarge (Direct, skip Dask): {xlarge_count}") |
| print(f"Slice size: {base_obs_slice_size:,} rows (medium/large), {obs_slice_size_xlarge:,} rows (xlarge)") |
| print(f"Small threshold: {small_threshold:,} entries") |
| print(f"Base chunk size: {base_chunk_size:,} rows (adjusted per dataset size)") |
| print(f"{'=' * 80}\n") |
|
|
| total_datasets = len(datasets_sorted) |
|
|
| |
| |
| |
| if small_count > 0: |
| print(f"{'='*80}") |
| print(f"PHASE 1: Small datasets ({small_count}) - ProcessPoolExecutor") |
| print(f"{'='*80}\n") |
| |
| |
| current_workers = max_workers_base |
| min_workers_ratio = config["resources"].get("min_workers_ratio", 0.25) |
| min_workers = max(1, int(max_workers_base * min_workers_ratio)) |
| batch_size = max(30, min(100, small_count // 4)) |
| |
| |
| check_interval = 50 |
| baseline_throughput = None |
| slowdown_threshold = config["resources"].get("slowdown_threshold", 0.5) |
| last_check_idx = 0 |
| batch_start_time = time.time() |
| |
| print(f"Workers: {current_workers} (adaptive: {min_workers}-{max_workers_base})") |
| print(f"Batch size: {batch_size} (recycled between batches)\n") |
| |
| with tqdm(total=small_count, desc="Small datasets", position=0) as pbar: |
| for batch_start in range(0, small_count, batch_size): |
| batch_end = min(batch_start + batch_size, small_count) |
| batch = small_datasets[batch_start:batch_end] |
| |
| |
| processed = len(successes) + len(failures) |
| if processed >= last_check_idx + check_interval and processed > check_interval: |
| elapsed = time.time() - batch_start_time |
| current_throughput = processed / elapsed if elapsed > 0 else 0 |
| |
| if baseline_throughput is None and processed >= check_interval * 2: |
| baseline_throughput = current_throughput |
| tqdm.write(f"Baseline: {baseline_throughput:.2f} ds/sec") |
| |
| if baseline_throughput and current_throughput < baseline_throughput * slowdown_threshold: |
| if current_workers > min_workers: |
| old_workers = current_workers |
| current_workers = max(min_workers, current_workers // 2) |
| tqdm.write(f"⚠️ Slowdown detected. Workers: {old_workers} → {current_workers}") |
| baseline_throughput = None |
| |
| last_check_idx = processed |
| |
| |
| executor = concurrent.futures.ProcessPoolExecutor(max_workers=current_workers) |
| futures = {} |
| |
| try: |
| for dataset in batch: |
| |
| chunk_size, _, _ = get_dataset_params(dataset) |
| future = executor.submit( |
| process_dataset_simple, |
| dataset["dataset_path"], |
| dataset["n_obs"], |
| dataset["n_vars"], |
| chunk_size, |
| max_meta_cols, |
| max_categories, |
| ) |
| futures[future] = dataset |
| |
| for future in concurrent.futures.as_completed(futures): |
| dataset = futures[future] |
| ds_path = dataset["dataset_path"] |
| ds_name = Path(ds_path).name |
| |
| try: |
| row = future.result(timeout=3600) |
| |
| |
| try: |
| row["file_size_gib"] = round(Path(ds_path).stat().st_size / (1024 ** 3), 4) |
| except Exception: |
| pass |
| |
| |
| try: |
| payload_name = safe_name(Path(ds_path)) + ".json" |
| (per_dataset_dir / payload_name).write_text(json.dumps(row, indent=2)) |
| except Exception as exc: |
| row["save_error"] = str(exc) |
| |
| if row.get("status") == "ok": |
| successes.append(row) |
| elapsed = row.get("elapsed_sec", "?") |
| tqdm.write(f" [{len(successes)}/{total_datasets}] ✓ {ds_name[:50]} | {elapsed}s") |
| else: |
| failures.append(row) |
| error = row.get("error", "Unknown")[:60] |
| tqdm.write(f" [{len(successes) + len(failures)}/{total_datasets}] ✗ {ds_name[:50]} | {error}") |
| |
| except concurrent.futures.TimeoutError: |
| failures.append({ |
| "dataset_path": ds_path, |
| "dataset_file": ds_name, |
| "status": "failed", |
| "error": "Timeout", |
| }) |
| tqdm.write(f" [{len(successes) + len(failures)}/{total_datasets}] ✗ {ds_name[:50]} | Timeout") |
| except Exception as exc: |
| failures.append({ |
| "dataset_path": ds_path, |
| "dataset_file": ds_name, |
| "status": "failed", |
| "error": str(exc), |
| }) |
| tqdm.write(f" [{len(successes) + len(failures)}/{total_datasets}] ✗ {ds_name[:50]} | {exc}") |
| finally: |
| pbar.update(1) |
| finally: |
| executor.shutdown(wait=True) |
| gc.collect() |
| time.sleep(1) |
| |
| print(f"\nPhase 1 complete: {len([s for s in successes if s in successes[-small_count:]])} ok, " + |
| f"{len([f for f in failures if f in failures[-small_count:]])} failed\n") |
|
|
| |
| |
| |
| if dask_count > 0 and client: |
| print(f"{'='*80}") |
| print(f"PHASE 2: Medium/Large datasets ({dask_count}) - Dask + slicing") |
| print(f"{'='*80}\n") |
| |
| with tqdm( |
| total=dask_count, |
| desc="Med/Large datasets", |
| position=0, |
| leave=True, |
| ncols=100 |
| ) as dataset_pbar: |
| for ds_local_idx, dataset in enumerate(dask_datasets): |
| dataset_idx = small_count + ds_local_idx |
| ds_path = dataset["dataset_path"] |
| ds_name = Path(ds_path).name |
| n_obs = dataset["n_obs"] |
| n_vars = dataset["n_vars"] |
| total_entries = dataset["total_entries"] |
| |
| |
| chunk_size, obs_slice_size, size_cat = get_dataset_params(dataset) |
|
|
| t0 = time.time() |
|
|
| |
| slice_tasks = create_slice_tasks(dataset, obs_slice_size, small_threshold) |
| n_slices = len(slice_tasks) |
|
|
| dataset_pbar.set_description(f"Med/Large [{ds_local_idx + 1}/{dask_count}] ({size_cat})") |
|
|
| |
| slice_results: list[SliceResult] = [] |
| failed_slices: list[tuple[str, int, int]] = [] |
|
|
| |
| futures = client.map( |
| lambda t: process_slice(t[0], t[1], t[2], chunk_size), |
| slice_tasks, |
| pure=False, |
| ) |
| |
| |
| show_slice_bar = n_slices > 1 |
| slice_pbar = tqdm( |
| total=n_slices, |
| desc=f" \u2514\u2500 Slices", |
| position=1, |
| leave=False, |
| ncols=100, |
| disable=not show_slice_bar |
| ) if show_slice_bar else None |
| |
| if slice_pbar: |
| slice_pbar.set_postfix(ok=0, fail=0) |
| |
| for task, future in zip(slice_tasks, futures): |
| try: |
| sr = future.result(timeout=3600) |
| if sr.status == "ok": |
| slice_results.append(sr) |
| else: |
| failed_slices.append(task) |
| except Exception: |
| failed_slices.append(task) |
| finally: |
| if slice_pbar: |
| slice_pbar.set_postfix(ok=len(slice_results), fail=len(failed_slices)) |
| slice_pbar.update(1) |
| |
| if slice_pbar: |
| slice_pbar.close() |
|
|
| |
| for retry in range(max_retries): |
| if not failed_slices: |
| break |
| tqdm.write(f" [{dataset_idx + 1}/{total_datasets}] Retry {retry + 1}/{max_retries}: {len(failed_slices)} failed slices") |
| time.sleep(1) |
|
|
| retry_futures = client.map( |
| lambda t: process_slice(t[0], t[1], t[2], chunk_size), |
| failed_slices, |
| pure=False, |
| ) |
| next_failed = [] |
| for task, future in zip(failed_slices, retry_futures): |
| try: |
| sr = future.result(timeout=3600) |
| if sr.status == "ok": |
| slice_results.append(sr) |
| else: |
| next_failed.append(task) |
| except Exception: |
| next_failed.append(task) |
| failed_slices = next_failed |
| |
| |
| if failed_slices and len(failed_slices) > 0: |
| emergency_chunk = max(10000, chunk_size // 10) |
| tqdm.write(f" [{dataset_idx + 1}/{total_datasets}] ⚠️ EMERGENCY MODE: {len(failed_slices)} slices with extreme settings (chunk={emergency_chunk:,})") |
| time.sleep(2) |
| |
| |
| emergency_ok = 0 |
| for task in failed_slices: |
| try: |
| future = client.submit( |
| process_slice, task[0], task[1], task[2], emergency_chunk, |
| pure=False, |
| ) |
| sr = future.result(timeout=7200) |
| if sr.status == "ok": |
| slice_results.append(sr) |
| emergency_ok += 1 |
| except Exception as e: |
| tqdm.write(f" Emergency failed for slice {task[1]}-{task[2]}: {str(e)[:100]}") |
| continue |
| |
| if emergency_ok > 0: |
| tqdm.write(f" [{dataset_idx + 1}/{total_datasets}] ✓ Emergency mode recovered {emergency_ok}/{len(failed_slices)} slices") |
| |
| |
| failed_slices = [t for t in failed_slices if not any( |
| sr.slice_start == t[1] and sr.slice_end == t[2] for sr in slice_results |
| )] |
|
|
| |
| ok_count = len(slice_results) |
| fail_count = len(failed_slices) |
| elapsed = round(time.time() - t0, 1) |
|
|
| if ok_count == 0: |
| tqdm.write(f" [{dataset_idx + 1}/{total_datasets}] ✗ FAILED: {ds_name[:50]} | all {n_slices} slices failed | {elapsed}s") |
| failures.append({ |
| "dataset_path": ds_path, |
| "dataset_file": ds_name, |
| "status": "failed", |
| "error": f"All {n_slices} slices failed", |
| "elapsed_sec": elapsed, |
| }) |
| dataset_pbar.update(1) |
| continue |
|
|
| |
| row = merge_slice_results(slice_results, n_obs, n_vars) |
| row["dataset_path"] = ds_path |
| row["dataset_file"] = ds_name |
| row["n_slices_total"] = n_slices |
| row["n_slices_ok"] = ok_count |
| row["n_slices_failed"] = fail_count |
|
|
| |
| try: |
| row["file_size_gib"] = round(Path(ds_path).stat().st_size / (1024 ** 3), 4) |
| except Exception: |
| pass |
|
|
| |
| meta = extract_metadata_on_scheduler( |
| Path(ds_path), max_meta_cols, max_categories |
| ) |
| row.update(meta) |
|
|
| row["status"] = "ok" if fail_count == 0 else "partial" |
| row["elapsed_sec"] = elapsed |
|
|
| |
| try: |
| payload_name = safe_name(Path(ds_path)) + ".json" |
| (per_dataset_dir / payload_name).write_text(json.dumps(row, indent=2)) |
| except Exception as exc: |
| row["save_error"] = str(exc) |
|
|
| successes.append(row) |
| status = "✓" if fail_count == 0 else "⚠" |
| tqdm.write(f" [{dataset_idx + 1}/{total_datasets}] {status} {ds_name[:50]} | {ok_count}/{n_slices} slices | {elapsed}s") |
|
|
| |
| del slice_results |
| gc.collect() |
| |
| |
| dataset_pbar.update(1) |
| |
| print(f"\nPhase 2 complete\n") |
| |
| |
| if xlarge_count > 0 and client: |
| print("Closing Dask cluster before Phase 3 (xlarge datasets process directly)...") |
| try: |
| client.close() |
| del client |
| gc.collect() |
| time.sleep(2) |
| except Exception as e: |
| print(f"Warning: Error closing Dask client: {e}") |
| |
| |
| |
| |
| if xlarge_count > 0: |
| print(f"{'='*80}") |
| print(f"PHASE 3: XLarge datasets ({xlarge_count}) - Direct processing (no Dask)") |
| print(f"{'='*80}\n") |
| |
| with tqdm( |
| total=xlarge_count, |
| desc="XLarge datasets", |
| position=0, |
| leave=True, |
| ncols=100 |
| ) as dataset_pbar: |
| for ds_local_idx, dataset in enumerate(xlarge_datasets): |
| dataset_idx = small_count + dask_count + ds_local_idx |
| ds_path = dataset["dataset_path"] |
| ds_name = Path(ds_path).name |
| n_obs = dataset["n_obs"] |
| n_vars = dataset["n_vars"] |
| |
| |
| chunk_size, obs_slice_size, size_cat = get_dataset_params(dataset) |
| |
| t0 = time.time() |
| |
| |
| slice_tasks = create_slice_tasks(dataset, obs_slice_size, small_threshold) |
| n_slices = len(slice_tasks) |
| |
| tqdm.write(f" [{dataset_idx + 1}/{total_datasets}] Processing {ds_name[:50]} | {n_slices} slices | chunk={chunk_size:,}") |
| |
| |
| slice_results: list[SliceResult] = [] |
| for slice_idx, (path, start, end) in enumerate(slice_tasks): |
| try: |
| sr = process_slice(path, start, end, chunk_size) |
| if sr.status == "ok": |
| slice_results.append(sr) |
| else: |
| tqdm.write(f" Slice {slice_idx+1}/{n_slices} failed: {sr.error}") |
| except Exception as e: |
| tqdm.write(f" Slice {slice_idx+1}/{n_slices} error: {str(e)[:100]}") |
| |
| ok_count = len(slice_results) |
| fail_count = n_slices - ok_count |
| elapsed = round(time.time() - t0, 1) |
| |
| if ok_count == 0: |
| tqdm.write(f" [{dataset_idx + 1}/{total_datasets}] ✗ FAILED: {ds_name[:50]} | all slices failed | {elapsed}s") |
| failures.append({ |
| "dataset_path": ds_path, |
| "dataset_file": ds_name, |
| "status": "failed", |
| "error": f"All {n_slices} slices failed (xlarge direct mode)", |
| "elapsed_sec": elapsed, |
| }) |
| dataset_pbar.update(1) |
| continue |
| |
| |
| row = merge_slice_results(slice_results, n_obs, n_vars) |
| row["dataset_path"] = ds_path |
| row["dataset_file"] = ds_name |
| row["n_slices_total"] = n_slices |
| row["n_slices_ok"] = ok_count |
| row["n_slices_failed"] = fail_count |
| row["processing_mode"] = "xlarge_direct" |
| |
| |
| try: |
| row["file_size_gib"] = round(Path(ds_path).stat().st_size / (1024 ** 3), 4) |
| except Exception: |
| pass |
| |
| |
| meta = extract_metadata_on_scheduler( |
| Path(ds_path), max_meta_cols, max_categories |
| ) |
| row.update(meta) |
| |
| row["status"] = "ok" if fail_count == 0 else "partial" |
| row["elapsed_sec"] = elapsed |
| |
| |
| try: |
| payload_name = safe_name(Path(ds_path)) + ".json" |
| (per_dataset_dir / payload_name).write_text(json.dumps(row, indent=2)) |
| except Exception as exc: |
| row["save_error"] = str(exc) |
| |
| successes.append(row) |
| status = "✓" if fail_count == 0 else "⚠" |
| tqdm.write(f" [{dataset_idx + 1}/{total_datasets}] {status} {ds_name[:50]} | {ok_count}/{n_slices} slices | {elapsed}s") |
| |
| dataset_pbar.update(1) |
| gc.collect() |
| |
| print(f"\nPhase 3 complete\n") |
|
|
| return successes, failures |
|
|
|
|
| |
| |
| |
def main() -> None:
    """CLI entry point: configure, shard, process datasets, save shard results.

    Workflow:
      1. Parse CLI args and load the YAML config (CLI sharding flags win).
      2. Load cached metadata and select this shard's datasets.
      3. Resume mode: skip datasets that already have an "ok" result on disk.
      4. Start an adaptive Dask LocalCluster only if any medium/large datasets
         need it (small and xlarge datasets are processed without Dask).
      5. Run process_all_datasets(); on interrupt or unexpected error, rebuild
         partial results from the per-dataset JSON payloads already on disk.
      6. Write a per-shard summary CSV and a failures JSON.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--config", type=Path, required=True, help="YAML config")
    parser.add_argument("--num-shards", type=int, help="Override num_shards")
    parser.add_argument("--shard-index", type=int, help="Override shard_index")
    parser.add_argument("--max-retries", type=int, default=3, help="Max retries per slice")
    args = parser.parse_args()

    config = load_config(args.config)

    # CLI overrides take precedence over the YAML sharding section.
    if args.num_shards is not None:
        config["sharding"]["num_shards"] = args.num_shards
        config["sharding"]["enabled"] = args.num_shards > 1
    if args.shard_index is not None:
        config["sharding"]["shard_index"] = args.shard_index

    num_shards = config["sharding"]["num_shards"]
    shard_index = config["sharding"]["shard_index"]

    configure_dask_for_hpc()

    # Relative config paths are resolved against the config file's grandparent
    # directory (i.e. the project root, assuming configs live in <root>/config/).
    cache_path = Path(config["paths"]["enhanced_metadata_cache"])
    if not cache_path.is_absolute():
        cache_path = Path(args.config).parent.parent / cache_path
    print(f"Loading metadata from: {cache_path}")
    metadata_df = load_enhanced_metadata(cache_path)

    datasets = get_datasets_for_shard(metadata_df, config, num_shards, shard_index)
    if not datasets:
        print("No datasets scheduled for this shard.")
        return

    output_dir = Path(config["paths"]["output_dir"])
    if not output_dir.is_absolute():
        output_dir = Path(args.config).parent.parent / output_dir
    output_dir.mkdir(parents=True, exist_ok=True)
    per_dataset_dir = output_dir / "per_dataset"
    per_dataset_dir.mkdir(parents=True, exist_ok=True)

    def is_dataset_done(ds_path: str) -> bool:
        """Return True if this dataset already has a successful result on disk."""
        try:
            payload_name = safe_name(Path(ds_path)) + ".json"
            result_file = per_dataset_dir / payload_name
            if result_file.exists():
                result_data = json.loads(result_file.read_text())
                return result_data.get("status") == "ok"
        except Exception:
            # Unreadable/corrupt payload: treat as not done so it reprocesses.
            pass
        return False

    def recover_results_from_disk() -> tuple[list[dict], list[dict]]:
        """Rebuild (successes, failures) from the per-dataset JSON payloads.

        Used when processing is interrupted or crashes: each dataset's result
        was persisted as soon as it finished, so partial progress is never
        lost. Deduplicates by dataset_path, keeping the first file seen.
        """
        recovered_ok: list[dict] = []
        recovered_bad: list[dict] = []
        seen_paths: set[str] = set()
        for json_file in per_dataset_dir.glob("*.json"):
            try:
                result = json.loads(json_file.read_text())
            except Exception:
                continue  # skip corrupt/partial files
            ds_path = result.get("dataset_path", "")
            if ds_path and ds_path in seen_paths:
                continue
            seen_paths.add(ds_path)
            if result.get("status") == "ok":
                recovered_ok.append(result)
            else:
                recovered_bad.append(result)
        return recovered_ok, recovered_bad

    # Resume mode: drop datasets that already completed successfully.
    original_count = len(datasets)
    datasets = [d for d in datasets if not is_dataset_done(d["dataset_path"])]
    skipped_count = original_count - len(datasets)

    if skipped_count > 0:
        print(f"\n{'='*80}")
        print(f"RESUME MODE: Skipping {skipped_count} already-completed datasets")
        print(f"Remaining to process: {len(datasets)}")
        print(f"{'='*80}\n")

    if not datasets:
        print("All datasets already completed. Nothing to do.")
        return

    # Classify datasets into the three processing phases by size.
    small_threshold = config["dataset_thresholds"]["small"]

    small_count_init = sum(
        1 for d in datasets if d.get("total_entries", 0) < small_threshold
    )
    dask_count_init = sum(
        1 for d in datasets
        if d.get("total_entries", 0) >= small_threshold
        and d.get("size_category", "large") != "xlarge"
    )
    xlarge_count_init = sum(
        1 for d in datasets if d.get("size_category", "large") == "xlarge"
    )

    client = None
    cluster = None

    if dask_count_init > 0:
        max_memory_gib = config["resources"]["max_memory_gib"]
        max_workers = config["resources"]["max_workers"]
        min_workers = config["resources"].get("min_workers", min(4, max_workers))
        threads_per_worker = config["resources"].get("threads_per_worker", 1)

        adaptive_config = config["resources"].get("adaptive_scaling", {})
        target_duration = adaptive_config.get("target_duration", "30s")
        wait_count = adaptive_config.get("wait_count", 3)
        interval = adaptive_config.get("interval", "2s")

        # Floor of 2 GiB per worker even under tight total-memory budgets.
        memory_per_worker_gib = max(2.0, max_memory_gib / max_workers)

        # .get() for consistency with the phase classification above — some
        # metadata rows may lack total_entries.
        total_entries = sum(d.get("total_entries", 0) for d in datasets)
        total_slices = sum(
            max(1, math.ceil(d["n_obs"] / config["slicing"].get("obs_slice_size", 75_000)))
            for d in datasets
            if d.get("total_entries", 0) >= small_threshold
            and d.get("size_category", "large") != "xlarge"
        )

        # Machine-readable run manifest, handy for log scraping.
        print(json.dumps({
            "total_datasets": len(datasets),
            "small_datasets": small_count_init,
            "large_datasets": dask_count_init + xlarge_count_init,
            "total_slices": total_slices,
            "total_entries": total_entries,
            "shard_index": shard_index,
            "num_shards": num_shards,
            "memory_per_worker_gib": round(memory_per_worker_gib, 1),
            "max_workers": max_workers,
        }, indent=2))

        print(f"\nStarting Dask LocalCluster (for {dask_count_init} medium/large datasets):")
        print(f"  Workers: {min_workers} -> {max_workers} (adaptive)")
        print(f"  Memory per worker: {memory_per_worker_gib:.1f} GiB")
        print(f"  Total memory budget: {max_memory_gib} GiB\n")

        # lifetime/lifetime_stagger recycle workers periodically, bounding
        # slow memory growth over multi-hour runs.
        cluster = LocalCluster(
            n_workers=min_workers,
            threads_per_worker=threads_per_worker,
            processes=True,
            memory_limit=f"{memory_per_worker_gib}GiB",
            silence_logs=True,
            dashboard_address=None,
            lifetime="180 minutes",
            lifetime_stagger="20 minutes",
        )

        cluster.adapt(
            minimum=min_workers,
            maximum=max_workers,
            target_duration=target_duration,
            wait_count=wait_count,
            interval=interval,
        )

        client = Client(cluster)
        print(f"Dask cluster ready: {client}\n")
    else:
        print("No Dask-compatible datasets (all small or xlarge)\n")

    if xlarge_count_init > 0:
        print(f"Note: {xlarge_count_init} xlarge datasets will be processed directly (Phase 3, no Dask)\n")

    try:
        successes, failures = process_all_datasets(
            datasets, config, per_dataset_dir, client,
            max_retries=args.max_retries,
        )

        # Resume mode: merge previously completed results into this run's
        # successes so the summary CSV covers the whole shard.
        if skipped_count > 0:
            print(f"\nLoading {skipped_count} previously completed results...")
            for json_file in per_dataset_dir.glob("*.json"):
                try:
                    result = json.loads(json_file.read_text())
                except Exception:
                    continue
                if result.get("status") != "ok":
                    continue
                ds_path = result.get("dataset_path", "")
                if not any(s.get("dataset_path") == ds_path for s in successes):
                    successes.append(result)
            print(f"Total results (new + previous): {len(successes)}")

        print(f"\n{'=' * 80}")
        print("PROCESSING COMPLETE")
        print(f"  Succeeded: {len(successes)}")
        print(f"  Failed: {len(failures)}")
        print(f"  Success rate: {len(successes) / max(1, original_count) * 100:.1f}%")
        print(f"{'=' * 80}\n")

        if failures:
            print("WARNING: Some datasets failed permanently:")
            for fail in failures[:10]:
                print(f"  - {fail['dataset_file']}: {fail.get('error', 'Unknown')[:80]}")
            if len(failures) > 10:
                print(f"  ... and {len(failures) - 10} more")

    except KeyboardInterrupt:
        # Fix: this separator was missing its f-prefix and printed the
        # literal text "{'=' * 80}".
        print(f"\n\n{'=' * 80}")
        print("INTERRUPTED - Saving partial results...")
        print(f"{'=' * 80}\n")
        successes, failures = recover_results_from_disk()
    except Exception as exc:
        print(f"\n\nERROR during processing: {exc}")
        print("Saving partial results...")
        successes, failures = recover_results_from_disk()

    # Persist the shard-level summary regardless of how processing ended.
    try:
        summary_df = pd.DataFrame(successes)

        if not summary_df.empty and 'dataset_path' in summary_df.columns:
            pre_dedup_count = len(summary_df)
            summary_df = summary_df.drop_duplicates(subset=['dataset_path'], keep='first')
            if len(summary_df) < pre_dedup_count:
                print(f"\nRemoved {pre_dedup_count - len(summary_df)} duplicate entries from results")

        summary_csv = output_dir / f"eda_summary_shard_{shard_index:03d}_of_{num_shards:03d}.csv"
        summary_df.to_csv(summary_csv, index=False)

        failures_path = output_dir / f"eda_failures_shard_{shard_index:03d}_of_{num_shards:03d}.json"
        failures_path.write_text(json.dumps(failures, indent=2))

        print(f"\n{'=' * 80}")
        print("RESULTS SAVED")
        print(f"  Summary CSV: {summary_csv}")
        print(f"  Failures JSON: {failures_path}")
        print(json.dumps({
            "ok_count": len(successes),
            "failed_count": len(failures),
        }, indent=2))
        print(f"{'=' * 80}\n")
    except Exception as save_exc:
        print(f"ERROR saving results: {save_exc}")

    finally:
        # Tear down Dask even if saving failed. Phase 3 may have already
        # closed the client inside process_all_datasets; close() is safe to
        # call again on a closed client/cluster.
        if client:
            client.close()
        if cluster:
            cluster.close()
|
|
|
|
# Script entry point: run the sharded, resumable EDA pipeline.
if __name__ == "__main__":
    main()
|
|