whats2000 committed on
Commit
b8d98f3
·
1 Parent(s): 74e20c3

feat(eda): implement hybrid processing strategy for small and large datasets

Browse files
Files changed (1) hide show
  1. scripts/distributed_eda.py +400 -131
scripts/distributed_eda.py CHANGED
@@ -12,6 +12,7 @@ and metadata summaries. This handles datasets from 2 GB to 500 GB.
12
  from __future__ import annotations
13
 
14
  import argparse
 
15
  import gc
16
  import hashlib
17
  import json
@@ -165,7 +166,184 @@ def merge_slice_results(slices: list[SliceResult], n_obs: int, n_vars: int) -> d
165
 
166
 
167
  # ---------------------------------------------------------------------------
168
- # Core worker function: process ONE slice of ONE dataset
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  # ---------------------------------------------------------------------------
170
  def process_slice(
171
  path_str: str,
@@ -458,122 +636,203 @@ def process_all_datasets(
458
  datasets: list[dict],
459
  config: dict,
460
  per_dataset_dir: Path,
461
- client: Client,
462
  max_retries: int = 3,
463
  ) -> tuple[list[dict], list[dict]]:
464
- """Process all datasets by slicing into bounded tasks.
465
-
466
- Each task processes at most obs_slice_size rows. Results are merged
467
- per-dataset with O(n_vars) memory on the scheduler side.
468
- """
469
  chunk_size = config["resources"]["chunk_size"]
470
  obs_slice_size = config["slicing"].get("obs_slice_size", 75_000)
471
  small_threshold = config["dataset_thresholds"]["small"]
472
  max_meta_cols = config["metadata"]["max_meta_cols"]
473
  max_categories = config["metadata"]["max_categories"]
 
474
 
475
  successes = []
476
  failures = []
477
 
478
- # Categorize and sort datasets
479
  small_datasets = [d for d in datasets if d.get("total_entries", 0) < small_threshold]
480
  large_datasets = [d for d in datasets if d.get("total_entries", 0) >= small_threshold]
481
 
482
- # Sort each category
483
  small_datasets.sort(key=lambda d: d["total_entries"])
484
  large_datasets.sort(key=lambda d: d["total_entries"])
485
 
486
- # Process small first in batches for speed, then large one-by-one
487
  datasets_sorted = small_datasets + large_datasets
488
  small_count = len(small_datasets)
489
  sliced_count = len(large_datasets)
490
 
491
- # Calculate batch size for small datasets (maximize worker utilization)
492
- max_workers = client.cluster.maximum if hasattr(client.cluster, 'maximum') else 48
493
- small_batch_size = max(1, min(max_workers, small_count)) # Process up to max_workers at once
494
-
495
  print(f"\n{'=' * 80}")
496
  print(f"Processing {len(datasets_sorted)} datasets")
497
- print(f" Small datasets (no slicing): {small_count} (batch size: {small_batch_size})")
498
- print(f" Medium/Large (sliced): {sliced_count}")
499
  print(f"Slice size: {obs_slice_size:,} rows per task (for medium/large)")
500
  print(f"Small threshold: {small_threshold:,} entries")
501
  print(f"Chunk size: {chunk_size:,} rows per sub-chunk")
502
  print(f"{'=' * 80}\n")
503
 
504
  total_datasets = len(datasets_sorted)
505
- is_small_batch = True # Track if we're in small dataset phase
506
-
507
- # Overall progress bar for all datasets
508
- with tqdm(
509
- total=total_datasets,
510
- desc="Datasets",
511
- position=0,
512
- leave=True,
513
- ncols=100
514
- ) as dataset_pbar:
515
- # Process datasets in batches for small, one-by-one for large
516
- ds_idx = 0
517
- while ds_idx < total_datasets:
518
- # Determine batch size
519
- if ds_idx < small_count:
520
- # Small datasets: batch processing
521
- batch_end = min(ds_idx + small_batch_size, small_count)
522
- batch = datasets_sorted[ds_idx:batch_end]
523
- is_small_batch = True
524
- else:
525
- # Large datasets: process one at a time
526
- batch = [datasets_sorted[ds_idx]]
527
- batch_end = ds_idx + 1
528
- is_small_batch = False
529
-
530
- # Submit all tasks for the batch
531
- batch_futures = []
532
- batch_info = []
533
-
534
- for dataset in batch:
535
- dataset_idx = ds_idx + batch.index(dataset)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
536
  ds_path = dataset["dataset_path"]
537
  ds_name = Path(ds_path).name
538
  n_obs = dataset["n_obs"]
539
  n_vars = dataset["n_vars"]
540
  total_entries = dataset["total_entries"]
541
 
542
- # Create slice tasks (small datasets = 1 task, large = sliced)
 
 
543
  slice_tasks = create_slice_tasks(dataset, obs_slice_size, small_threshold)
544
  n_slices = len(slice_tasks)
545
 
 
 
 
 
 
 
546
  # Submit slice tasks to Dask
547
  futures = client.map(
548
  lambda t: process_slice(t[0], t[1], t[2], chunk_size),
549
  slice_tasks,
550
  pure=False,
551
  )
552
-
553
- batch_futures.append(futures)
554
- batch_info.append({
555
- 'dataset': dataset,
556
- 'ds_idx': dataset_idx,
557
- 'ds_path': ds_path,
558
- 'ds_name': ds_name,
559
- 'n_obs': n_obs,
560
- 'n_vars': n_vars,
561
- 'total_entries': total_entries,
562
- 'slice_tasks': slice_tasks,
563
- 'n_slices': n_slices,
564
- 't0': time.time()
565
- })
566
-
567
- # Process results for the batch
568
- dataset_pbar.set_description(f"Datasets [{ds_idx + 1}-{batch_end}/{total_datasets}]" if len(batch) > 1 else f"Datasets [{ds_idx + 1}/{total_datasets}]")
569
-
570
- for info_idx, (futures, info) in enumerate(zip(batch_futures, batch_info)):
571
- dataset = info['dataset']
572
- dataset_idx = info['ds_idx']
573
- ds_path = info['ds_path']
574
- ds_name = info['ds_name']
575
- n_obs = info['n_obs']
576
- n_vars = info['n_vars']
577
  total_entries = info['total_entries']
578
  slice_tasks = info['slice_tasks']
579
  n_slices = info['n_slices']
@@ -583,11 +842,11 @@ def process_all_datasets(
583
  slice_results: list[SliceResult] = []
584
  failed_slices: list[tuple[str, int, int]] = []
585
 
586
- # Collect results with progress bar (show only for sliced datasets)
587
- show_slice_bar = n_slices > 1 and not is_small_batch
588
  slice_pbar = tqdm(
589
  total=n_slices,
590
- desc=f" └─ Slices",
591
  position=1,
592
  leave=False,
593
  ncols=100,
@@ -696,8 +955,7 @@ def process_all_datasets(
696
  # Update dataset progress
697
  dataset_pbar.update(1)
698
 
699
- # Move to next batch
700
- ds_idx = batch_end
701
 
702
  return successes, failures
703
 
@@ -746,60 +1004,69 @@ def main() -> None:
746
  per_dataset_dir = output_dir / "per_dataset"
747
  per_dataset_dir.mkdir(parents=True, exist_ok=True)
748
 
749
- # Cluster setup
750
- max_memory_gib = config["resources"]["max_memory_gib"]
751
- max_workers = config["resources"]["max_workers"]
752
- min_workers = min(4, max_workers)
753
-
754
- # Each worker needs enough memory for: chunk_size * n_vars * 12 bytes * 3x overhead
755
- # With slice architecture, workers are lightweight - give them decent memory
756
- memory_per_worker_gib = max(2.0, max_memory_gib / max_workers)
757
-
758
- total_entries = sum(d["total_entries"] for d in datasets)
759
- total_slices = sum(
760
- max(1, math.ceil(d["n_obs"] / config["slicing"].get("obs_slice_size", 50_000)))
761
- for d in datasets
762
- )
763
-
764
- print(json.dumps({
765
- "total_datasets": len(datasets),
766
- "total_slices": total_slices,
767
- "total_entries": total_entries,
768
- "shard_index": shard_index,
769
- "num_shards": num_shards,
770
- "memory_per_worker_gib": round(memory_per_worker_gib, 1),
771
- "max_workers": max_workers,
772
- }, indent=2))
773
-
774
- print(f"\nStarting Dask LocalCluster:")
775
- print(f" Workers: {min_workers} -> {max_workers} (adaptive)")
776
- print(f" Memory per worker: {memory_per_worker_gib:.1f} GiB")
777
- print(f" Total memory budget: {max_memory_gib} GiB\n")
778
-
779
- cluster = LocalCluster(
780
- n_workers=min_workers,
781
- threads_per_worker=1,
782
- processes=True,
783
- memory_limit=f"{memory_per_worker_gib}GiB",
784
- silence_logs=True,
785
- dashboard_address=None,
786
- lifetime="120 minutes",
787
- lifetime_stagger="15 minutes",
788
- )
789
-
790
- cluster.adapt(
791
- minimum=min_workers,
792
- maximum=max_workers,
793
- target_duration="30s",
794
- wait_count=3,
795
- interval="2s",
796
- )
797
-
798
- client = Client(cluster)
799
 
800
- try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
801
  print(f"Dask cluster ready: {client}\n")
 
 
802
 
 
803
  successes, failures = process_all_datasets(
804
  datasets, config, per_dataset_dir, client,
805
  max_retries=args.max_retries,
@@ -835,8 +1102,10 @@ def main() -> None:
835
  }, indent=2))
836
 
837
  finally:
838
- client.close()
839
- cluster.close()
 
 
840
 
841
 
842
  if __name__ == "__main__":
 
12
  from __future__ import annotations
13
 
14
  import argparse
15
+ import concurrent.futures
16
  import gc
17
  import hashlib
18
  import json
 
166
 
167
 
168
  # ---------------------------------------------------------------------------
169
+ # Simple worker function for small datasets (no Dask overhead)
170
+ # ---------------------------------------------------------------------------
171
def process_dataset_simple(
    path_str: str,
    n_obs: int,
    n_vars: int,
    chunk_size: int,
    max_meta_cols: int,
    max_categories: int,
) -> dict:
    """Process an entire small dataset in one worker (no slicing, no Dask).

    Reads the ``.h5ad`` file at *path_str* in backed (on-disk) mode and streams
    over ``adata.X`` in row blocks of *chunk_size*, accumulating matrix-, cell-
    and gene-level summary statistics plus obs/var metadata summaries into a
    single flat result dict.

    Args:
        path_str: Filesystem path to the ``.h5ad`` dataset.
        n_obs: Number of observations (rows) in the matrix, supplied by the
            caller's pre-scanned manifest.
        n_vars: Number of variables (columns) in the matrix.
        chunk_size: Rows per streamed sub-chunk.
        max_meta_cols: Cap on metadata columns passed to ``summarize_metadata``.
        max_categories: Cap on category values passed to ``summarize_metadata``.

    Returns:
        A dict with ``status`` of ``"ok"`` (statistics populated, slice counters
        set to a single slice) or ``"failed"`` (``error`` holds the exception
        text). Always includes ``elapsed_sec``.

    Note:
        Any exception during processing is caught and reported via the returned
        dict rather than raised, so a pool of these workers never dies on one
        bad file.
    """
    t0 = time.time()
    path = Path(path_str)
    # Base fields present in the result regardless of success/failure.
    row: dict[str, Any] = {
        "dataset_path": path_str,
        "dataset_file": path.name,
        "n_obs": n_obs,
        "n_vars": n_vars,
    }

    try:
        # Backed mode keeps X on disk; chunks are materialized per-slice below.
        adata = ad.read_h5ad(path, backed="r")
        total_entries = n_obs * n_vars

        # Matrix-level accumulators (nnz, first/second moments of the values).
        nnz_total = 0
        x_sum = 0.0
        x_sum_sq = 0.0

        # Cell-level accumulators
        cell_total_counts_sum = 0.0
        cell_total_counts_min = math.inf
        cell_total_counts_max = -math.inf
        cell_n_genes_sum = 0
        # Sentinel: max int64, so any real per-cell gene count replaces it.
        cell_n_genes_min = 2**63 - 1
        cell_n_genes_max = 0

        # Gene-level accumulators
        gene_n_cells = np.zeros(n_vars, dtype=np.int64)
        gene_total_counts = np.zeros(n_vars, dtype=np.float64)

        # Process in chunks
        for start in range(0, n_obs, chunk_size):
            end = min(start + chunk_size, n_obs)
            chunk = adata.X[start:end, :]

            if sparse.issparse(chunk):
                # Normalize to CSR so row sums / indptr diffs are cheap.
                csr = chunk.tocsr() if not sparse.isspmatrix_csr(chunk) else chunk
                data = csr.data.astype(np.float64, copy=False)

                nnz_total += int(csr.nnz)
                x_sum += float(data.sum())
                x_sum_sq += float(np.square(data).sum())

                # Cell stats: per-row totals; indptr diff = nonzeros per row
                # (genes detected per cell).
                cell_counts = np.asarray(csr.sum(axis=1)).ravel()
                cell_genes = np.diff(csr.indptr).astype(np.int64)

                cell_total_counts_sum += float(cell_counts.sum())
                cell_total_counts_min = min(cell_total_counts_min, float(cell_counts.min()))
                cell_total_counts_max = max(cell_total_counts_max, float(cell_counts.max()))
                cell_n_genes_sum += int(cell_genes.sum())
                cell_n_genes_min = min(cell_n_genes_min, int(cell_genes.min()))
                cell_n_genes_max = max(cell_n_genes_max, int(cell_genes.max()))

                # Gene stats: convert to CSC so per-column counts/sums are cheap.
                csc = csr.tocsc()
                gene_n_cells += np.diff(csc.indptr).astype(np.int64)
                gene_total_counts += np.asarray(csc.sum(axis=0)).ravel()

                # Drop chunk-local arrays promptly to bound peak memory.
                del csr, csc, data
            else:
                # Dense branch: same statistics via a boolean nonzero mask.
                arr = np.asarray(chunk, dtype=np.float64)
                nz = arr != 0

                nnz_total += int(nz.sum())
                x_sum += float(arr.sum())
                x_sum_sq += float(np.square(arr).sum())

                # Cell stats
                cell_counts = arr.sum(axis=1)
                cell_genes = nz.sum(axis=1).astype(np.int64)

                cell_total_counts_sum += float(cell_counts.sum())
                cell_total_counts_min = min(cell_total_counts_min, float(cell_counts.min()))
                cell_total_counts_max = max(cell_total_counts_max, float(cell_counts.max()))
                cell_n_genes_sum += int(cell_genes.sum())
                cell_n_genes_min = min(cell_n_genes_min, int(cell_genes.min()))
                cell_n_genes_max = max(cell_n_genes_max, int(cell_genes.max()))

                # Gene stats
                gene_n_cells += nz.sum(axis=0).astype(np.int64)
                gene_total_counts += arr.sum(axis=0)

                del arr, nz

            del chunk
            gc.collect()

        # Matrix-level stats. Mean/std over ALL entries (zeros included), with
        # std via E[x^2] - E[x]^2 clamped at 0 against float round-off.
        row["nnz"] = int(nnz_total)
        row["sparsity"] = float(1.0 - nnz_total / total_entries) if total_entries else None
        row["x_mean"] = float(x_sum / total_entries) if total_entries else None
        if total_entries:
            var = max(0.0, x_sum_sq / total_entries - (x_sum / total_entries) ** 2)
            row["x_std"] = float(math.sqrt(var))
        else:
            row["x_std"] = None

        # Cell-level stats (None when the dataset has no rows, so the inf/max
        # sentinels above never leak into the output).
        if n_obs > 0:
            row["cell_total_counts_min"] = float(cell_total_counts_min)
            row["cell_total_counts_max"] = float(cell_total_counts_max)
            row["cell_total_counts_mean"] = float(cell_total_counts_sum / n_obs)
            row["cell_n_genes_detected_min"] = int(cell_n_genes_min)
            row["cell_n_genes_detected_max"] = int(cell_n_genes_max)
            row["cell_n_genes_detected_mean"] = float(cell_n_genes_sum / n_obs)
        else:
            row["cell_total_counts_min"] = None
            row["cell_total_counts_max"] = None
            row["cell_total_counts_mean"] = None
            row["cell_n_genes_detected_min"] = None
            row["cell_n_genes_detected_max"] = None
            row["cell_n_genes_detected_mean"] = None

        # Gene-level stats, restricted to genes detected in at least one cell.
        genes_detected = int(np.count_nonzero(gene_n_cells))
        row["genes_detected_in_any_cell"] = genes_detected
        row["genes_detected_in_any_cell_pct"] = float(genes_detected / n_vars * 100) if n_vars else 0.0
        if genes_detected > 0:
            mask = gene_n_cells > 0
            row["gene_n_cells_min"] = int(gene_n_cells[mask].min())
            row["gene_n_cells_max"] = int(gene_n_cells[mask].max())
            row["gene_n_cells_mean"] = float(gene_n_cells[mask].mean())
            row["gene_total_counts_min"] = float(gene_total_counts[mask].min())
            row["gene_total_counts_max"] = float(gene_total_counts[mask].max())
            row["gene_total_counts_mean"] = float(gene_total_counts[mask].mean())
        else:
            # NOTE(review): zeros here (vs None used for the empty cell stats
            # above) — presumably intentional for downstream aggregation;
            # confirm consumers treat 0 and None consistently.
            for k in ("gene_n_cells_min", "gene_n_cells_max", "gene_n_cells_mean",
                      "gene_total_counts_min", "gene_total_counts_max", "gene_total_counts_mean"):
                row[k] = 0

        # Metadata summaries and schemas from obs/var dataframes (helpers are
        # defined elsewhere in this module).
        row["obs_columns"] = int(len(adata.obs.columns))
        row["var_columns"] = int(len(adata.var.columns))
        row["metadata_obs_summary"] = summarize_metadata(
            adata.obs, max_cols=max_meta_cols, max_categories=max_categories
        )
        row["metadata_var_summary"] = summarize_metadata(
            adata.var, max_cols=max_meta_cols, max_categories=max_categories
        )
        row["obs_schema"] = extract_schema(adata.obs)
        row["var_schema"] = extract_schema(adata.var)

        # Clean up: explicitly close the backing HDF5 file handle (best-effort;
        # backed AnnData keeps the file open until closed).
        del gene_n_cells, gene_total_counts
        try:
            if hasattr(adata, "file") and adata.file is not None:
                adata.file.close()
        except Exception:
            pass
        del adata

        row["status"] = "ok"
        # Mirror the sliced-pipeline bookkeeping: a small dataset is one slice.
        row["n_slices_total"] = 1
        row["n_slices_ok"] = 1
        row["n_slices_failed"] = 0

    except Exception as exc:
        # Report failure in-band so the worker pool keeps running.
        row["status"] = "failed"
        row["error"] = str(exc)

    gc.collect()
    row["elapsed_sec"] = round(time.time() - t0, 2)
    return row
343
+
344
+
345
+ # ---------------------------------------------------------------------------
346
+ # Core worker function: process ONE slice of ONE dataset (Dask)
347
  # ---------------------------------------------------------------------------
348
  def process_slice(
349
  path_str: str,
 
636
  datasets: list[dict],
637
  config: dict,
638
  per_dataset_dir: Path,
639
+ client: Client | None,
640
  max_retries: int = 3,
641
  ) -> tuple[list[dict], list[dict]]:
642
+ """Process all datasets: small ones with ProcessPoolExecutor, large ones with Dask."""
 
 
 
 
643
  chunk_size = config["resources"]["chunk_size"]
644
  obs_slice_size = config["slicing"].get("obs_slice_size", 75_000)
645
  small_threshold = config["dataset_thresholds"]["small"]
646
  max_meta_cols = config["metadata"]["max_meta_cols"]
647
  max_categories = config["metadata"]["max_categories"]
648
+ max_workers_base = config["resources"]["max_workers"]
649
 
650
  successes = []
651
  failures = []
652
 
653
+ # Categorize datasets
654
  small_datasets = [d for d in datasets if d.get("total_entries", 0) < small_threshold]
655
  large_datasets = [d for d in datasets if d.get("total_entries", 0) >= small_threshold]
656
 
 
657
  small_datasets.sort(key=lambda d: d["total_entries"])
658
  large_datasets.sort(key=lambda d: d["total_entries"])
659
 
 
660
  datasets_sorted = small_datasets + large_datasets
661
  small_count = len(small_datasets)
662
  sliced_count = len(large_datasets)
663
 
 
 
 
 
664
  print(f"\n{'=' * 80}")
665
  print(f"Processing {len(datasets_sorted)} datasets")
666
+ print(f" Small datasets (ProcessPoolExecutor): {small_count}")
667
+ print(f" Medium/Large (Dask + slicing): {sliced_count}")
668
  print(f"Slice size: {obs_slice_size:,} rows per task (for medium/large)")
669
  print(f"Small threshold: {small_threshold:,} entries")
670
  print(f"Chunk size: {chunk_size:,} rows per sub-chunk")
671
  print(f"{'=' * 80}\n")
672
 
673
  total_datasets = len(datasets_sorted)
674
+
675
+ # ========================================================================
676
+ # Phase 1: Process small datasets with ProcessPoolExecutor (batched)
677
+ # ========================================================================
678
+ if small_count > 0:
679
+ print(f"{'='*80}")
680
+ print(f"PHASE 1: Small datasets ({small_count}) - ProcessPoolExecutor")
681
+ print(f"{'='*80}\n")
682
+
683
+ # Adaptive worker management
684
+ current_workers = max_workers_base
685
+ min_workers = max(1, max_workers_base // 4)
686
+ batch_size = max(30, min(100, small_count // 4))
687
+
688
+ # Throughput monitoring
689
+ check_interval = 50
690
+ baseline_throughput = None
691
+ slowdown_threshold = 0.5
692
+ last_check_idx = 0
693
+ batch_start_time = time.time()
694
+
695
+ print(f"Workers: {current_workers} (adaptive: {min_workers}-{max_workers_base})")
696
+ print(f"Batch size: {batch_size} (recycled between batches)\n")
697
+
698
+ with tqdm(total=small_count, desc="Small datasets", position=0) as pbar:
699
+ for batch_start in range(0, small_count, batch_size):
700
+ batch_end = min(batch_start + batch_size, small_count)
701
+ batch = small_datasets[batch_start:batch_end]
702
+
703
+ # Check throughput and adjust workers
704
+ processed = len(successes) + len(failures)
705
+ if processed >= last_check_idx + check_interval and processed > check_interval:
706
+ elapsed = time.time() - batch_start_time
707
+ current_throughput = processed / elapsed if elapsed > 0 else 0
708
+
709
+ if baseline_throughput is None and processed >= check_interval * 2:
710
+ baseline_throughput = current_throughput
711
+ tqdm.write(f"Baseline: {baseline_throughput:.2f} ds/sec")
712
+
713
+ if baseline_throughput and current_throughput < baseline_throughput * slowdown_threshold:
714
+ if current_workers > min_workers:
715
+ old_workers = current_workers
716
+ current_workers = max(min_workers, current_workers // 2)
717
+ tqdm.write(f"⚠️ Slowdown detected. Workers: {old_workers} → {current_workers}")
718
+ baseline_throughput = None
719
+
720
+ last_check_idx = processed
721
+
722
+ # Process batch
723
+ executor = concurrent.futures.ProcessPoolExecutor(max_workers=current_workers)
724
+ futures = {}
725
+
726
+ try:
727
+ for dataset in batch:
728
+ future = executor.submit(
729
+ process_dataset_simple,
730
+ dataset["dataset_path"],
731
+ dataset["n_obs"],
732
+ dataset["n_vars"],
733
+ chunk_size,
734
+ max_meta_cols,
735
+ max_categories,
736
+ )
737
+ futures[future] = dataset
738
+
739
+ for future in concurrent.futures.as_completed(futures):
740
+ dataset = futures[future]
741
+ ds_path = dataset["dataset_path"]
742
+ ds_name = Path(ds_path).name
743
+
744
+ try:
745
+ row = future.result(timeout=3600)
746
+
747
+ # File size
748
+ try:
749
+ row["file_size_gib"] = round(Path(ds_path).stat().st_size / (1024 ** 3), 4)
750
+ except Exception:
751
+ pass
752
+
753
+ # Save JSON
754
+ try:
755
+ payload_name = safe_name(Path(ds_path)) + ".json"
756
+ (per_dataset_dir / payload_name).write_text(json.dumps(row, indent=2))
757
+ except Exception as exc:
758
+ row["save_error"] = str(exc)
759
+
760
+ if row.get("status") == "ok":
761
+ successes.append(row)
762
+ elapsed = row.get("elapsed_sec", "?")
763
+ tqdm.write(f" [{len(successes)}/{total_datasets}] ✓ {ds_name[:50]} | {elapsed}s")
764
+ else:
765
+ failures.append(row)
766
+ error = row.get("error", "Unknown")[:60]
767
+ tqdm.write(f" [{len(successes) + len(failures)}/{total_datasets}] ✗ {ds_name[:50]} | {error}")
768
+
769
+ except concurrent.futures.TimeoutError:
770
+ failures.append({
771
+ "dataset_path": ds_path,
772
+ "dataset_file": ds_name,
773
+ "status": "failed",
774
+ "error": "Timeout",
775
+ })
776
+ tqdm.write(f" [{len(successes) + len(failures)}/{total_datasets}] ✗ {ds_name[:50]} | Timeout")
777
+ except Exception as exc:
778
+ failures.append({
779
+ "dataset_path": ds_path,
780
+ "dataset_file": ds_name,
781
+ "status": "failed",
782
+ "error": str(exc),
783
+ })
784
+ tqdm.write(f" [{len(successes) + len(failures)}/{total_datasets}] ✗ {ds_name[:50]} | {exc}")
785
+ finally:
786
+ pbar.update(1)
787
+ finally:
788
+ executor.shutdown(wait=True)
789
+ gc.collect()
790
+ time.sleep(1)
791
+
792
+ print(f"\nPhase 1 complete: {len([s for s in successes if s in successes[-small_count:]])} ok, " +
793
+ f"{len([f for f in failures if f in failures[-small_count:]])} failed\n")
794
+
795
+ # ========================================================================
796
+ # Phase 2: Process large datasets with Dask (existing logic)
797
+ # ========================================================================
798
+ if sliced_count > 0 and client:
799
+ print(f"{'='*80}")
800
+ print(f"PHASE 2: Medium/Large datasets ({sliced_count}) - Dask + slicing")
801
+ print(f"{'='*80}\n")
802
+
803
+ with tqdm(
804
+ total=sliced_count,
805
+ desc="Med/Large datasets",
806
+ position=0,
807
+ leave=True,
808
+ ncols=100
809
+ ) as dataset_pbar:
810
+ for ds_local_idx, dataset in enumerate(large_datasets):
811
+ dataset_idx = small_count + ds_local_idx
812
  ds_path = dataset["dataset_path"]
813
  ds_name = Path(ds_path).name
814
  n_obs = dataset["n_obs"]
815
  n_vars = dataset["n_vars"]
816
  total_entries = dataset["total_entries"]
817
 
818
+ t0 = time.time()
819
+
820
+ # Create slice tasks
821
  slice_tasks = create_slice_tasks(dataset, obs_slice_size, small_threshold)
822
  n_slices = len(slice_tasks)
823
 
824
+ dataset_pbar.set_description(f"Med/Large [{ds_local_idx + 1}/{sliced_count}]")
825
+
826
+ # Submit all slices for this dataset
827
+ slice_results: list[SliceResult] = []
828
+ failed_slices: list[tuple[str, int, int]] = []
829
+
830
  # Submit slice tasks to Dask
831
  futures = client.map(
832
  lambda t: process_slice(t[0], t[1], t[2], chunk_size),
833
  slice_tasks,
834
  pure=False,
835
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
836
  total_entries = info['total_entries']
837
  slice_tasks = info['slice_tasks']
838
  n_slices = info['n_slices']
 
842
  slice_results: list[SliceResult] = []
843
  failed_slices: list[tuple[str, int, int]] = []
844
 
845
+ # Collect results with progress bar (show for sliced datasets)
846
+ show_slice_bar = n_slices > 1
847
  slice_pbar = tqdm(
848
  total=n_slices,
849
+ desc=f" \u2514\u2500 Slices",
850
  position=1,
851
  leave=False,
852
  ncols=100,
 
955
  # Update dataset progress
956
  dataset_pbar.update(1)
957
 
958
+ print(f"\nPhase 2 complete\n")
 
959
 
960
  return successes, failures
961
 
 
1004
  per_dataset_dir = output_dir / "per_dataset"
1005
  per_dataset_dir.mkdir(parents=True, exist_ok=True)
1006
 
1007
+ # Check if we need Dask cluster (for medium/large datasets)
1008
+ small_threshold = config["dataset_thresholds"]["small"]
1009
+ large_count = sum(1 for d in datasets if d.get("total_entries", 0) >= small_threshold)
1010
+
1011
+ client = None
1012
+ cluster = None
1013
+
1014
+ if large_count > 0:
1015
+ # Cluster setup for large datasets
1016
+ max_memory_gib = config["resources"]["max_memory_gib"]
1017
+ max_workers = config["resources"]["max_workers"]
1018
+ min_workers = min(4, max_workers)
1019
+
1020
+ memory_per_worker_gib = max(2.0, max_memory_gib / max_workers)
1021
+
1022
+ total_entries = sum(d["total_entries"] for d in datasets)
1023
+ total_slices = sum(
1024
+ max(1, math.ceil(d["n_obs"] / config["slicing"].get("obs_slice_size", 75_000)))
1025
+ for d in datasets if d.get("total_entries", 0) >= small_threshold
1026
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1027
 
1028
+ print(json.dumps({
1029
+ "total_datasets": len(datasets),
1030
+ "small_datasets": len(datasets) - large_count,
1031
+ "large_datasets": large_count,
1032
+ "total_slices": total_slices,
1033
+ "total_entries": total_entries,
1034
+ "shard_index": shard_index,
1035
+ "num_shards": num_shards,
1036
+ "memory_per_worker_gib": round(memory_per_worker_gib, 1),
1037
+ "max_workers": max_workers,
1038
+ }, indent=2))
1039
+
1040
+ print(f"\nStarting Dask LocalCluster (for {large_count} large datasets):")
1041
+ print(f" Workers: {min_workers} -> {max_workers} (adaptive)")
1042
+ print(f" Memory per worker: {memory_per_worker_gib:.1f} GiB")
1043
+ print(f" Total memory budget: {max_memory_gib} GiB\n")
1044
+
1045
+ cluster = LocalCluster(
1046
+ n_workers=min_workers,
1047
+ threads_per_worker=1,
1048
+ processes=True,
1049
+ memory_limit=f"{memory_per_worker_gib}GiB",
1050
+ silence_logs=True,
1051
+ dashboard_address=None,
1052
+ lifetime="120 minutes",
1053
+ lifetime_stagger="15 minutes",
1054
+ )
1055
+
1056
+ cluster.adapt(
1057
+ minimum=min_workers,
1058
+ maximum=max_workers,
1059
+ target_duration="30s",
1060
+ wait_count=3,
1061
+ interval="2s",
1062
+ )
1063
+
1064
+ client = Client(cluster)
1065
  print(f"Dask cluster ready: {client}\n")
1066
+ else:
1067
+ print(f"All {len(datasets)} datasets are small - using ProcessPoolExecutor only\n")
1068
 
1069
+ try:
1070
  successes, failures = process_all_datasets(
1071
  datasets, config, per_dataset_dir, client,
1072
  max_retries=args.max_retries,
 
1102
  }, indent=2))
1103
 
1104
  finally:
1105
+ if client:
1106
+ client.close()
1107
+ if cluster:
1108
+ cluster.close()
1109
 
1110
 
1111
  if __name__ == "__main__":