feat(eda): implement hybrid processing strategy for small and large datasets
Browse files- scripts/distributed_eda.py +400 -131
scripts/distributed_eda.py
CHANGED
|
@@ -12,6 +12,7 @@ and metadata summaries. This handles datasets from 2 GB to 500 GB.
|
|
| 12 |
from __future__ import annotations
|
| 13 |
|
| 14 |
import argparse
|
|
|
|
| 15 |
import gc
|
| 16 |
import hashlib
|
| 17 |
import json
|
|
@@ -165,7 +166,184 @@ def merge_slice_results(slices: list[SliceResult], n_obs: int, n_vars: int) -> d
|
|
| 165 |
|
| 166 |
|
| 167 |
# ---------------------------------------------------------------------------
|
| 168 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
# ---------------------------------------------------------------------------
|
| 170 |
def process_slice(
|
| 171 |
path_str: str,
|
|
@@ -458,122 +636,203 @@ def process_all_datasets(
|
|
| 458 |
datasets: list[dict],
|
| 459 |
config: dict,
|
| 460 |
per_dataset_dir: Path,
|
| 461 |
-
client: Client,
|
| 462 |
max_retries: int = 3,
|
| 463 |
) -> tuple[list[dict], list[dict]]:
|
| 464 |
-
"""Process all datasets
|
| 465 |
-
|
| 466 |
-
Each task processes at most obs_slice_size rows. Results are merged
|
| 467 |
-
per-dataset with O(n_vars) memory on the scheduler side.
|
| 468 |
-
"""
|
| 469 |
chunk_size = config["resources"]["chunk_size"]
|
| 470 |
obs_slice_size = config["slicing"].get("obs_slice_size", 75_000)
|
| 471 |
small_threshold = config["dataset_thresholds"]["small"]
|
| 472 |
max_meta_cols = config["metadata"]["max_meta_cols"]
|
| 473 |
max_categories = config["metadata"]["max_categories"]
|
|
|
|
| 474 |
|
| 475 |
successes = []
|
| 476 |
failures = []
|
| 477 |
|
| 478 |
-
# Categorize
|
| 479 |
small_datasets = [d for d in datasets if d.get("total_entries", 0) < small_threshold]
|
| 480 |
large_datasets = [d for d in datasets if d.get("total_entries", 0) >= small_threshold]
|
| 481 |
|
| 482 |
-
# Sort each category
|
| 483 |
small_datasets.sort(key=lambda d: d["total_entries"])
|
| 484 |
large_datasets.sort(key=lambda d: d["total_entries"])
|
| 485 |
|
| 486 |
-
# Process small first in batches for speed, then large one-by-one
|
| 487 |
datasets_sorted = small_datasets + large_datasets
|
| 488 |
small_count = len(small_datasets)
|
| 489 |
sliced_count = len(large_datasets)
|
| 490 |
|
| 491 |
-
# Calculate batch size for small datasets (maximize worker utilization)
|
| 492 |
-
max_workers = client.cluster.maximum if hasattr(client.cluster, 'maximum') else 48
|
| 493 |
-
small_batch_size = max(1, min(max_workers, small_count)) # Process up to max_workers at once
|
| 494 |
-
|
| 495 |
print(f"\n{'=' * 80}")
|
| 496 |
print(f"Processing {len(datasets_sorted)} datasets")
|
| 497 |
-
print(f" Small datasets (
|
| 498 |
-
print(f" Medium/Large (
|
| 499 |
print(f"Slice size: {obs_slice_size:,} rows per task (for medium/large)")
|
| 500 |
print(f"Small threshold: {small_threshold:,} entries")
|
| 501 |
print(f"Chunk size: {chunk_size:,} rows per sub-chunk")
|
| 502 |
print(f"{'=' * 80}\n")
|
| 503 |
|
| 504 |
total_datasets = len(datasets_sorted)
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
#
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
|
| 526 |
-
|
| 527 |
-
|
| 528 |
-
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
|
| 534 |
-
|
| 535 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 536 |
ds_path = dataset["dataset_path"]
|
| 537 |
ds_name = Path(ds_path).name
|
| 538 |
n_obs = dataset["n_obs"]
|
| 539 |
n_vars = dataset["n_vars"]
|
| 540 |
total_entries = dataset["total_entries"]
|
| 541 |
|
| 542 |
-
|
|
|
|
|
|
|
| 543 |
slice_tasks = create_slice_tasks(dataset, obs_slice_size, small_threshold)
|
| 544 |
n_slices = len(slice_tasks)
|
| 545 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 546 |
# Submit slice tasks to Dask
|
| 547 |
futures = client.map(
|
| 548 |
lambda t: process_slice(t[0], t[1], t[2], chunk_size),
|
| 549 |
slice_tasks,
|
| 550 |
pure=False,
|
| 551 |
)
|
| 552 |
-
|
| 553 |
-
batch_futures.append(futures)
|
| 554 |
-
batch_info.append({
|
| 555 |
-
'dataset': dataset,
|
| 556 |
-
'ds_idx': dataset_idx,
|
| 557 |
-
'ds_path': ds_path,
|
| 558 |
-
'ds_name': ds_name,
|
| 559 |
-
'n_obs': n_obs,
|
| 560 |
-
'n_vars': n_vars,
|
| 561 |
-
'total_entries': total_entries,
|
| 562 |
-
'slice_tasks': slice_tasks,
|
| 563 |
-
'n_slices': n_slices,
|
| 564 |
-
't0': time.time()
|
| 565 |
-
})
|
| 566 |
-
|
| 567 |
-
# Process results for the batch
|
| 568 |
-
dataset_pbar.set_description(f"Datasets [{ds_idx + 1}-{batch_end}/{total_datasets}]" if len(batch) > 1 else f"Datasets [{ds_idx + 1}/{total_datasets}]")
|
| 569 |
-
|
| 570 |
-
for info_idx, (futures, info) in enumerate(zip(batch_futures, batch_info)):
|
| 571 |
-
dataset = info['dataset']
|
| 572 |
-
dataset_idx = info['ds_idx']
|
| 573 |
-
ds_path = info['ds_path']
|
| 574 |
-
ds_name = info['ds_name']
|
| 575 |
-
n_obs = info['n_obs']
|
| 576 |
-
n_vars = info['n_vars']
|
| 577 |
total_entries = info['total_entries']
|
| 578 |
slice_tasks = info['slice_tasks']
|
| 579 |
n_slices = info['n_slices']
|
|
@@ -583,11 +842,11 @@ def process_all_datasets(
|
|
| 583 |
slice_results: list[SliceResult] = []
|
| 584 |
failed_slices: list[tuple[str, int, int]] = []
|
| 585 |
|
| 586 |
-
# Collect results with progress bar (show
|
| 587 |
-
show_slice_bar = n_slices > 1
|
| 588 |
slice_pbar = tqdm(
|
| 589 |
total=n_slices,
|
| 590 |
-
desc=f"
|
| 591 |
position=1,
|
| 592 |
leave=False,
|
| 593 |
ncols=100,
|
|
@@ -696,8 +955,7 @@ def process_all_datasets(
|
|
| 696 |
# Update dataset progress
|
| 697 |
dataset_pbar.update(1)
|
| 698 |
|
| 699 |
-
|
| 700 |
-
ds_idx = batch_end
|
| 701 |
|
| 702 |
return successes, failures
|
| 703 |
|
|
@@ -746,60 +1004,69 @@ def main() -> None:
|
|
| 746 |
per_dataset_dir = output_dir / "per_dataset"
|
| 747 |
per_dataset_dir.mkdir(parents=True, exist_ok=True)
|
| 748 |
|
| 749 |
-
#
|
| 750 |
-
|
| 751 |
-
|
| 752 |
-
|
| 753 |
-
|
| 754 |
-
|
| 755 |
-
|
| 756 |
-
|
| 757 |
-
|
| 758 |
-
|
| 759 |
-
|
| 760 |
-
|
| 761 |
-
|
| 762 |
-
|
| 763 |
-
|
| 764 |
-
|
| 765 |
-
|
| 766 |
-
|
| 767 |
-
|
| 768 |
-
|
| 769 |
-
"num_shards": num_shards,
|
| 770 |
-
"memory_per_worker_gib": round(memory_per_worker_gib, 1),
|
| 771 |
-
"max_workers": max_workers,
|
| 772 |
-
}, indent=2))
|
| 773 |
-
|
| 774 |
-
print(f"\nStarting Dask LocalCluster:")
|
| 775 |
-
print(f" Workers: {min_workers} -> {max_workers} (adaptive)")
|
| 776 |
-
print(f" Memory per worker: {memory_per_worker_gib:.1f} GiB")
|
| 777 |
-
print(f" Total memory budget: {max_memory_gib} GiB\n")
|
| 778 |
-
|
| 779 |
-
cluster = LocalCluster(
|
| 780 |
-
n_workers=min_workers,
|
| 781 |
-
threads_per_worker=1,
|
| 782 |
-
processes=True,
|
| 783 |
-
memory_limit=f"{memory_per_worker_gib}GiB",
|
| 784 |
-
silence_logs=True,
|
| 785 |
-
dashboard_address=None,
|
| 786 |
-
lifetime="120 minutes",
|
| 787 |
-
lifetime_stagger="15 minutes",
|
| 788 |
-
)
|
| 789 |
-
|
| 790 |
-
cluster.adapt(
|
| 791 |
-
minimum=min_workers,
|
| 792 |
-
maximum=max_workers,
|
| 793 |
-
target_duration="30s",
|
| 794 |
-
wait_count=3,
|
| 795 |
-
interval="2s",
|
| 796 |
-
)
|
| 797 |
-
|
| 798 |
-
client = Client(cluster)
|
| 799 |
|
| 800 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 801 |
print(f"Dask cluster ready: {client}\n")
|
|
|
|
|
|
|
| 802 |
|
|
|
|
| 803 |
successes, failures = process_all_datasets(
|
| 804 |
datasets, config, per_dataset_dir, client,
|
| 805 |
max_retries=args.max_retries,
|
|
@@ -835,8 +1102,10 @@ def main() -> None:
|
|
| 835 |
}, indent=2))
|
| 836 |
|
| 837 |
finally:
|
| 838 |
-
client
|
| 839 |
-
|
|
|
|
|
|
|
| 840 |
|
| 841 |
|
| 842 |
if __name__ == "__main__":
|
|
|
|
| 12 |
from __future__ import annotations
|
| 13 |
|
| 14 |
import argparse
|
| 15 |
+
import concurrent.futures
|
| 16 |
import gc
|
| 17 |
import hashlib
|
| 18 |
import json
|
|
|
|
| 166 |
|
| 167 |
|
| 168 |
# ---------------------------------------------------------------------------
|
| 169 |
+
# Simple worker function for small datasets (no Dask overhead)
|
| 170 |
+
# ---------------------------------------------------------------------------
|
| 171 |
+
def process_dataset_simple(
    path_str: str,
    n_obs: int,
    n_vars: int,
    chunk_size: int,
    max_meta_cols: int,
    max_categories: int,
) -> dict:
    """Compute EDA statistics for one whole dataset inside a single worker.

    The file is opened in backed (read-only) mode and its ``X`` matrix is
    walked in row chunks of ``chunk_size``; matrix-, cell- and gene-level
    statistics are accumulated incrementally, then ``obs``/``var`` metadata
    summaries and schemas are attached.

    Args:
        path_str: Path of the ``.h5ad`` file to read.
        n_obs: Number of rows to scan; also the denominator of per-cell means.
        n_vars: Number of columns; sizes the per-gene accumulators.
        chunk_size: Rows read per iteration. Must be >= 1.
        max_meta_cols: Forwarded to ``summarize_metadata`` as ``max_cols``.
        max_categories: Forwarded to ``summarize_metadata``.

    Returns:
        A result row. It always contains ``dataset_path``, ``dataset_file``,
        ``n_obs``, ``n_vars``, ``status`` and ``elapsed_sec``. On success
        ``status`` is ``"ok"`` and the statistics, metadata and
        ``n_slices_*`` keys are present. On any exception ``status`` is
        ``"failed"`` and ``error`` holds the message (or the exception class
        name when the message is empty); statistics written before the
        failure may still be present. This function does not raise for
        processing errors.
    """
    t0 = time.time()
    path = Path(path_str)
    row: dict[str, Any] = {
        "dataset_path": path_str,
        "dataset_file": path.name,
        "n_obs": n_obs,
        "n_vars": n_vars,
    }

    adata = None
    try:
        # range() rejects a zero step and yields nothing for a negative one;
        # the latter would report the +/-inf sentinels as real statistics.
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

        adata = ad.read_h5ad(path, backed="r")
        total_entries = n_obs * n_vars

        # Matrix-level accumulators
        nnz_total = 0
        x_sum = 0.0
        x_sum_sq = 0.0

        # Cell-level accumulators
        cell_total_counts_sum = 0.0
        cell_total_counts_min = math.inf
        cell_total_counts_max = -math.inf
        cell_n_genes_sum = 0
        cell_n_genes_min = 2**63 - 1
        cell_n_genes_max = 0

        # Gene-level accumulators
        gene_n_cells = np.zeros(n_vars, dtype=np.int64)
        gene_total_counts = np.zeros(n_vars, dtype=np.float64)

        # Process in chunks
        for start in range(0, n_obs, chunk_size):
            end = min(start + chunk_size, n_obs)
            chunk = adata.X[start:end, :]

            if sparse.issparse(chunk):
                csr = chunk if sparse.isspmatrix_csr(chunk) else chunk.tocsr()
                data = csr.data.astype(np.float64, copy=False)

                chunk_nnz = int(csr.nnz)
                chunk_sum = float(data.sum())
                chunk_sum_sq = float(np.square(data).sum())

                cell_counts = np.asarray(csr.sum(axis=1)).ravel()
                cell_genes = np.diff(csr.indptr).astype(np.int64)

                # Stored entries per column, read straight from the CSR
                # column indices; avoids a CSR -> CSC conversion per chunk.
                chunk_gene_cells = np.bincount(csr.indices, minlength=csr.shape[1])
                chunk_gene_totals = np.asarray(csr.sum(axis=0)).ravel()

                del csr, data
            else:
                arr = np.asarray(chunk, dtype=np.float64)
                nz = arr != 0

                chunk_nnz = int(nz.sum())
                chunk_sum = float(arr.sum())
                chunk_sum_sq = float(np.square(arr).sum())

                cell_counts = arr.sum(axis=1)
                cell_genes = nz.sum(axis=1).astype(np.int64)

                chunk_gene_cells = nz.sum(axis=0).astype(np.int64)
                chunk_gene_totals = arr.sum(axis=0)

                del arr, nz

            nnz_total += chunk_nnz
            x_sum += chunk_sum
            x_sum_sq += chunk_sum_sq

            cell_total_counts_sum += float(cell_counts.sum())
            cell_total_counts_min = min(cell_total_counts_min, float(cell_counts.min()))
            cell_total_counts_max = max(cell_total_counts_max, float(cell_counts.max()))
            cell_n_genes_sum += int(cell_genes.sum())
            cell_n_genes_min = min(cell_n_genes_min, int(cell_genes.min()))
            cell_n_genes_max = max(cell_n_genes_max, int(cell_genes.max()))

            gene_n_cells += chunk_gene_cells
            gene_total_counts += chunk_gene_totals

            # Drop chunk-sized temporaries before the next read.
            del chunk, cell_counts, cell_genes, chunk_gene_cells, chunk_gene_totals
            gc.collect()

        # Matrix-level stats
        row["nnz"] = int(nnz_total)
        row["sparsity"] = float(1.0 - nnz_total / total_entries) if total_entries else None
        row["x_mean"] = float(x_sum / total_entries) if total_entries else None
        if total_entries:
            # E[x^2] - E[x]^2, clamped because rounding can push it below 0.
            var = max(0.0, x_sum_sq / total_entries - (x_sum / total_entries) ** 2)
            row["x_std"] = float(math.sqrt(var))
        else:
            row["x_std"] = None

        # Cell-level stats
        if n_obs > 0:
            row["cell_total_counts_min"] = float(cell_total_counts_min)
            row["cell_total_counts_max"] = float(cell_total_counts_max)
            row["cell_total_counts_mean"] = float(cell_total_counts_sum / n_obs)
            row["cell_n_genes_detected_min"] = int(cell_n_genes_min)
            row["cell_n_genes_detected_max"] = int(cell_n_genes_max)
            row["cell_n_genes_detected_mean"] = float(cell_n_genes_sum / n_obs)
        else:
            for k in ("cell_total_counts_min", "cell_total_counts_max", "cell_total_counts_mean",
                      "cell_n_genes_detected_min", "cell_n_genes_detected_max",
                      "cell_n_genes_detected_mean"):
                row[k] = None

        # Gene-level stats (min/max/mean are over genes seen in >= 1 cell)
        genes_detected = int(np.count_nonzero(gene_n_cells))
        row["genes_detected_in_any_cell"] = genes_detected
        row["genes_detected_in_any_cell_pct"] = float(genes_detected / n_vars * 100) if n_vars else 0.0
        if genes_detected > 0:
            mask = gene_n_cells > 0
            row["gene_n_cells_min"] = int(gene_n_cells[mask].min())
            row["gene_n_cells_max"] = int(gene_n_cells[mask].max())
            row["gene_n_cells_mean"] = float(gene_n_cells[mask].mean())
            row["gene_total_counts_min"] = float(gene_total_counts[mask].min())
            row["gene_total_counts_max"] = float(gene_total_counts[mask].max())
            row["gene_total_counts_mean"] = float(gene_total_counts[mask].mean())
        else:
            for k in ("gene_n_cells_min", "gene_n_cells_max", "gene_n_cells_mean",
                      "gene_total_counts_min", "gene_total_counts_max", "gene_total_counts_mean"):
                row[k] = 0

        # Metadata
        row["obs_columns"] = int(len(adata.obs.columns))
        row["var_columns"] = int(len(adata.var.columns))
        row["metadata_obs_summary"] = summarize_metadata(
            adata.obs, max_cols=max_meta_cols, max_categories=max_categories
        )
        row["metadata_var_summary"] = summarize_metadata(
            adata.var, max_cols=max_meta_cols, max_categories=max_categories
        )
        row["obs_schema"] = extract_schema(adata.obs)
        row["var_schema"] = extract_schema(adata.var)

        del gene_n_cells, gene_total_counts

        row["status"] = "ok"
        row["n_slices_total"] = 1
        row["n_slices_ok"] = 1
        row["n_slices_failed"] = 0

    except Exception as exc:
        row["status"] = "failed"
        # Some exceptions (e.g. MemoryError) carry no message; keep the
        # record diagnosable by falling back to the class name.
        row["error"] = str(exc) or type(exc).__name__

    finally:
        # Release the backing file on failure as well as on success so a
        # long-lived worker does not accumulate open handles.
        if adata is not None:
            try:
                backing = getattr(adata, "file", None)
                if backing is not None:
                    backing.close()
            except Exception:
                # Best-effort: a close failure must not mask the real result.
                pass
            del adata

    gc.collect()
    row["elapsed_sec"] = round(time.time() - t0, 2)
    return row
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
# ---------------------------------------------------------------------------
|
| 346 |
+
# Core worker function: process ONE slice of ONE dataset (Dask)
|
| 347 |
# ---------------------------------------------------------------------------
|
| 348 |
def process_slice(
|
| 349 |
path_str: str,
|
|
|
|
| 636 |
datasets: list[dict],
|
| 637 |
config: dict,
|
| 638 |
per_dataset_dir: Path,
|
| 639 |
+
client: Client | None,
|
| 640 |
max_retries: int = 3,
|
| 641 |
) -> tuple[list[dict], list[dict]]:
|
| 642 |
+
"""Process all datasets: small ones with ProcessPoolExecutor, large ones with Dask."""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 643 |
chunk_size = config["resources"]["chunk_size"]
|
| 644 |
obs_slice_size = config["slicing"].get("obs_slice_size", 75_000)
|
| 645 |
small_threshold = config["dataset_thresholds"]["small"]
|
| 646 |
max_meta_cols = config["metadata"]["max_meta_cols"]
|
| 647 |
max_categories = config["metadata"]["max_categories"]
|
| 648 |
+
max_workers_base = config["resources"]["max_workers"]
|
| 649 |
|
| 650 |
successes = []
|
| 651 |
failures = []
|
| 652 |
|
| 653 |
+
# Categorize datasets
|
| 654 |
small_datasets = [d for d in datasets if d.get("total_entries", 0) < small_threshold]
|
| 655 |
large_datasets = [d for d in datasets if d.get("total_entries", 0) >= small_threshold]
|
| 656 |
|
|
|
|
| 657 |
small_datasets.sort(key=lambda d: d["total_entries"])
|
| 658 |
large_datasets.sort(key=lambda d: d["total_entries"])
|
| 659 |
|
|
|
|
| 660 |
datasets_sorted = small_datasets + large_datasets
|
| 661 |
small_count = len(small_datasets)
|
| 662 |
sliced_count = len(large_datasets)
|
| 663 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 664 |
print(f"\n{'=' * 80}")
|
| 665 |
print(f"Processing {len(datasets_sorted)} datasets")
|
| 666 |
+
print(f" Small datasets (ProcessPoolExecutor): {small_count}")
|
| 667 |
+
print(f" Medium/Large (Dask + slicing): {sliced_count}")
|
| 668 |
print(f"Slice size: {obs_slice_size:,} rows per task (for medium/large)")
|
| 669 |
print(f"Small threshold: {small_threshold:,} entries")
|
| 670 |
print(f"Chunk size: {chunk_size:,} rows per sub-chunk")
|
| 671 |
print(f"{'=' * 80}\n")
|
| 672 |
|
| 673 |
total_datasets = len(datasets_sorted)
|
| 674 |
+
|
| 675 |
+
# ========================================================================
|
| 676 |
+
# Phase 1: Process small datasets with ProcessPoolExecutor (batched)
|
| 677 |
+
# ========================================================================
|
| 678 |
+
if small_count > 0:
|
| 679 |
+
print(f"{'='*80}")
|
| 680 |
+
print(f"PHASE 1: Small datasets ({small_count}) - ProcessPoolExecutor")
|
| 681 |
+
print(f"{'='*80}\n")
|
| 682 |
+
|
| 683 |
+
# Adaptive worker management
|
| 684 |
+
current_workers = max_workers_base
|
| 685 |
+
min_workers = max(1, max_workers_base // 4)
|
| 686 |
+
batch_size = max(30, min(100, small_count // 4))
|
| 687 |
+
|
| 688 |
+
# Throughput monitoring
|
| 689 |
+
check_interval = 50
|
| 690 |
+
baseline_throughput = None
|
| 691 |
+
slowdown_threshold = 0.5
|
| 692 |
+
last_check_idx = 0
|
| 693 |
+
batch_start_time = time.time()
|
| 694 |
+
|
| 695 |
+
print(f"Workers: {current_workers} (adaptive: {min_workers}-{max_workers_base})")
|
| 696 |
+
print(f"Batch size: {batch_size} (recycled between batches)\n")
|
| 697 |
+
|
| 698 |
+
with tqdm(total=small_count, desc="Small datasets", position=0) as pbar:
|
| 699 |
+
for batch_start in range(0, small_count, batch_size):
|
| 700 |
+
batch_end = min(batch_start + batch_size, small_count)
|
| 701 |
+
batch = small_datasets[batch_start:batch_end]
|
| 702 |
+
|
| 703 |
+
# Check throughput and adjust workers
|
| 704 |
+
processed = len(successes) + len(failures)
|
| 705 |
+
if processed >= last_check_idx + check_interval and processed > check_interval:
|
| 706 |
+
elapsed = time.time() - batch_start_time
|
| 707 |
+
current_throughput = processed / elapsed if elapsed > 0 else 0
|
| 708 |
+
|
| 709 |
+
if baseline_throughput is None and processed >= check_interval * 2:
|
| 710 |
+
baseline_throughput = current_throughput
|
| 711 |
+
tqdm.write(f"Baseline: {baseline_throughput:.2f} ds/sec")
|
| 712 |
+
|
| 713 |
+
if baseline_throughput and current_throughput < baseline_throughput * slowdown_threshold:
|
| 714 |
+
if current_workers > min_workers:
|
| 715 |
+
old_workers = current_workers
|
| 716 |
+
current_workers = max(min_workers, current_workers // 2)
|
| 717 |
+
tqdm.write(f"⚠️ Slowdown detected. Workers: {old_workers} → {current_workers}")
|
| 718 |
+
baseline_throughput = None
|
| 719 |
+
|
| 720 |
+
last_check_idx = processed
|
| 721 |
+
|
| 722 |
+
# Process batch
|
| 723 |
+
executor = concurrent.futures.ProcessPoolExecutor(max_workers=current_workers)
|
| 724 |
+
futures = {}
|
| 725 |
+
|
| 726 |
+
try:
|
| 727 |
+
for dataset in batch:
|
| 728 |
+
future = executor.submit(
|
| 729 |
+
process_dataset_simple,
|
| 730 |
+
dataset["dataset_path"],
|
| 731 |
+
dataset["n_obs"],
|
| 732 |
+
dataset["n_vars"],
|
| 733 |
+
chunk_size,
|
| 734 |
+
max_meta_cols,
|
| 735 |
+
max_categories,
|
| 736 |
+
)
|
| 737 |
+
futures[future] = dataset
|
| 738 |
+
|
| 739 |
+
for future in concurrent.futures.as_completed(futures):
|
| 740 |
+
dataset = futures[future]
|
| 741 |
+
ds_path = dataset["dataset_path"]
|
| 742 |
+
ds_name = Path(ds_path).name
|
| 743 |
+
|
| 744 |
+
try:
|
| 745 |
+
row = future.result(timeout=3600)
|
| 746 |
+
|
| 747 |
+
# File size
|
| 748 |
+
try:
|
| 749 |
+
row["file_size_gib"] = round(Path(ds_path).stat().st_size / (1024 ** 3), 4)
|
| 750 |
+
except Exception:
|
| 751 |
+
pass
|
| 752 |
+
|
| 753 |
+
# Save JSON
|
| 754 |
+
try:
|
| 755 |
+
payload_name = safe_name(Path(ds_path)) + ".json"
|
| 756 |
+
(per_dataset_dir / payload_name).write_text(json.dumps(row, indent=2))
|
| 757 |
+
except Exception as exc:
|
| 758 |
+
row["save_error"] = str(exc)
|
| 759 |
+
|
| 760 |
+
if row.get("status") == "ok":
|
| 761 |
+
successes.append(row)
|
| 762 |
+
elapsed = row.get("elapsed_sec", "?")
|
| 763 |
+
tqdm.write(f" [{len(successes)}/{total_datasets}] ✓ {ds_name[:50]} | {elapsed}s")
|
| 764 |
+
else:
|
| 765 |
+
failures.append(row)
|
| 766 |
+
error = row.get("error", "Unknown")[:60]
|
| 767 |
+
tqdm.write(f" [{len(successes) + len(failures)}/{total_datasets}] ✗ {ds_name[:50]} | {error}")
|
| 768 |
+
|
| 769 |
+
except concurrent.futures.TimeoutError:
|
| 770 |
+
failures.append({
|
| 771 |
+
"dataset_path": ds_path,
|
| 772 |
+
"dataset_file": ds_name,
|
| 773 |
+
"status": "failed",
|
| 774 |
+
"error": "Timeout",
|
| 775 |
+
})
|
| 776 |
+
tqdm.write(f" [{len(successes) + len(failures)}/{total_datasets}] ✗ {ds_name[:50]} | Timeout")
|
| 777 |
+
except Exception as exc:
|
| 778 |
+
failures.append({
|
| 779 |
+
"dataset_path": ds_path,
|
| 780 |
+
"dataset_file": ds_name,
|
| 781 |
+
"status": "failed",
|
| 782 |
+
"error": str(exc),
|
| 783 |
+
})
|
| 784 |
+
tqdm.write(f" [{len(successes) + len(failures)}/{total_datasets}] ✗ {ds_name[:50]} | {exc}")
|
| 785 |
+
finally:
|
| 786 |
+
pbar.update(1)
|
| 787 |
+
finally:
|
| 788 |
+
executor.shutdown(wait=True)
|
| 789 |
+
gc.collect()
|
| 790 |
+
time.sleep(1)
|
| 791 |
+
|
| 792 |
+
print(f"\nPhase 1 complete: {len([s for s in successes if s in successes[-small_count:]])} ok, " +
|
| 793 |
+
f"{len([f for f in failures if f in failures[-small_count:]])} failed\n")
|
| 794 |
+
|
| 795 |
+
# ========================================================================
|
| 796 |
+
# Phase 2: Process large datasets with Dask (existing logic)
|
| 797 |
+
# ========================================================================
|
| 798 |
+
if sliced_count > 0 and client:
|
| 799 |
+
print(f"{'='*80}")
|
| 800 |
+
print(f"PHASE 2: Medium/Large datasets ({sliced_count}) - Dask + slicing")
|
| 801 |
+
print(f"{'='*80}\n")
|
| 802 |
+
|
| 803 |
+
with tqdm(
|
| 804 |
+
total=sliced_count,
|
| 805 |
+
desc="Med/Large datasets",
|
| 806 |
+
position=0,
|
| 807 |
+
leave=True,
|
| 808 |
+
ncols=100
|
| 809 |
+
) as dataset_pbar:
|
| 810 |
+
for ds_local_idx, dataset in enumerate(large_datasets):
|
| 811 |
+
dataset_idx = small_count + ds_local_idx
|
| 812 |
ds_path = dataset["dataset_path"]
|
| 813 |
ds_name = Path(ds_path).name
|
| 814 |
n_obs = dataset["n_obs"]
|
| 815 |
n_vars = dataset["n_vars"]
|
| 816 |
total_entries = dataset["total_entries"]
|
| 817 |
|
| 818 |
+
t0 = time.time()
|
| 819 |
+
|
| 820 |
+
# Create slice tasks
|
| 821 |
slice_tasks = create_slice_tasks(dataset, obs_slice_size, small_threshold)
|
| 822 |
n_slices = len(slice_tasks)
|
| 823 |
|
| 824 |
+
dataset_pbar.set_description(f"Med/Large [{ds_local_idx + 1}/{sliced_count}]")
|
| 825 |
+
|
| 826 |
+
# Submit all slices for this dataset
|
| 827 |
+
slice_results: list[SliceResult] = []
|
| 828 |
+
failed_slices: list[tuple[str, int, int]] = []
|
| 829 |
+
|
| 830 |
# Submit slice tasks to Dask
|
| 831 |
futures = client.map(
|
| 832 |
lambda t: process_slice(t[0], t[1], t[2], chunk_size),
|
| 833 |
slice_tasks,
|
| 834 |
pure=False,
|
| 835 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 836 |
total_entries = info['total_entries']
|
| 837 |
slice_tasks = info['slice_tasks']
|
| 838 |
n_slices = info['n_slices']
|
|
|
|
| 842 |
slice_results: list[SliceResult] = []
|
| 843 |
failed_slices: list[tuple[str, int, int]] = []
|
| 844 |
|
| 845 |
+
# Collect results with progress bar (show for sliced datasets)
|
| 846 |
+
show_slice_bar = n_slices > 1
|
| 847 |
slice_pbar = tqdm(
|
| 848 |
total=n_slices,
|
| 849 |
+
desc=f" \u2514\u2500 Slices",
|
| 850 |
position=1,
|
| 851 |
leave=False,
|
| 852 |
ncols=100,
|
|
|
|
| 955 |
# Update dataset progress
|
| 956 |
dataset_pbar.update(1)
|
| 957 |
|
| 958 |
+
print(f"\nPhase 2 complete\n")
|
|
|
|
| 959 |
|
| 960 |
return successes, failures
|
| 961 |
|
|
|
|
| 1004 |
per_dataset_dir = output_dir / "per_dataset"
|
| 1005 |
per_dataset_dir.mkdir(parents=True, exist_ok=True)
|
| 1006 |
|
| 1007 |
+
# Check if we need Dask cluster (for medium/large datasets)
|
| 1008 |
+
small_threshold = config["dataset_thresholds"]["small"]
|
| 1009 |
+
large_count = sum(1 for d in datasets if d.get("total_entries", 0) >= small_threshold)
|
| 1010 |
+
|
| 1011 |
+
client = None
|
| 1012 |
+
cluster = None
|
| 1013 |
+
|
| 1014 |
+
if large_count > 0:
|
| 1015 |
+
# Cluster setup for large datasets
|
| 1016 |
+
max_memory_gib = config["resources"]["max_memory_gib"]
|
| 1017 |
+
max_workers = config["resources"]["max_workers"]
|
| 1018 |
+
min_workers = min(4, max_workers)
|
| 1019 |
+
|
| 1020 |
+
memory_per_worker_gib = max(2.0, max_memory_gib / max_workers)
|
| 1021 |
+
|
| 1022 |
+
total_entries = sum(d["total_entries"] for d in datasets)
|
| 1023 |
+
total_slices = sum(
|
| 1024 |
+
max(1, math.ceil(d["n_obs"] / config["slicing"].get("obs_slice_size", 75_000)))
|
| 1025 |
+
for d in datasets if d.get("total_entries", 0) >= small_threshold
|
| 1026 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1027 |
|
| 1028 |
+
print(json.dumps({
|
| 1029 |
+
"total_datasets": len(datasets),
|
| 1030 |
+
"small_datasets": len(datasets) - large_count,
|
| 1031 |
+
"large_datasets": large_count,
|
| 1032 |
+
"total_slices": total_slices,
|
| 1033 |
+
"total_entries": total_entries,
|
| 1034 |
+
"shard_index": shard_index,
|
| 1035 |
+
"num_shards": num_shards,
|
| 1036 |
+
"memory_per_worker_gib": round(memory_per_worker_gib, 1),
|
| 1037 |
+
"max_workers": max_workers,
|
| 1038 |
+
}, indent=2))
|
| 1039 |
+
|
| 1040 |
+
print(f"\nStarting Dask LocalCluster (for {large_count} large datasets):")
|
| 1041 |
+
print(f" Workers: {min_workers} -> {max_workers} (adaptive)")
|
| 1042 |
+
print(f" Memory per worker: {memory_per_worker_gib:.1f} GiB")
|
| 1043 |
+
print(f" Total memory budget: {max_memory_gib} GiB\n")
|
| 1044 |
+
|
| 1045 |
+
cluster = LocalCluster(
|
| 1046 |
+
n_workers=min_workers,
|
| 1047 |
+
threads_per_worker=1,
|
| 1048 |
+
processes=True,
|
| 1049 |
+
memory_limit=f"{memory_per_worker_gib}GiB",
|
| 1050 |
+
silence_logs=True,
|
| 1051 |
+
dashboard_address=None,
|
| 1052 |
+
lifetime="120 minutes",
|
| 1053 |
+
lifetime_stagger="15 minutes",
|
| 1054 |
+
)
|
| 1055 |
+
|
| 1056 |
+
cluster.adapt(
|
| 1057 |
+
minimum=min_workers,
|
| 1058 |
+
maximum=max_workers,
|
| 1059 |
+
target_duration="30s",
|
| 1060 |
+
wait_count=3,
|
| 1061 |
+
interval="2s",
|
| 1062 |
+
)
|
| 1063 |
+
|
| 1064 |
+
client = Client(cluster)
|
| 1065 |
print(f"Dask cluster ready: {client}\n")
|
| 1066 |
+
else:
|
| 1067 |
+
print(f"All {len(datasets)} datasets are small - using ProcessPoolExecutor only\n")
|
| 1068 |
|
| 1069 |
+
try:
|
| 1070 |
successes, failures = process_all_datasets(
|
| 1071 |
datasets, config, per_dataset_dir, client,
|
| 1072 |
max_retries=args.max_retries,
|
|
|
|
| 1102 |
}, indent=2))
|
| 1103 |
|
| 1104 |
finally:
|
| 1105 |
+
if client:
|
| 1106 |
+
client.close()
|
| 1107 |
+
if cluster:
|
| 1108 |
+
cluster.close()
|
| 1109 |
|
| 1110 |
|
| 1111 |
if __name__ == "__main__":
|