whats2000 committed on
Commit
ac07329
·
1 Parent(s): 4e03c42

fix(eda): optimize gene statistics calculation in distributed EDA

Browse files

refactor(eda): adjust resource settings for improved performance and reduced I/O contention

configs/eda_optimized.yaml CHANGED
@@ -4,8 +4,8 @@
4
 
5
  resources:
6
  max_memory_gib: 5500 # Leave ~500 GB buffer for system
7
- max_workers: 100 # Based on actual RAM availability
8
- chunk_size: 50000 # Increased for larger memory
9
 
10
  paths:
11
  input_dirs:
@@ -24,30 +24,30 @@ paths:
24
  enhanced_metadata_cache: output/cache/enhanced_metadata.parquet
25
 
26
  dataset_thresholds:
27
- small: 2_000_000_000 # < 2B entries: full speed
28
- medium: 15_000_000_000 # < 15B entries: moderate
29
- large: 75_000_000_000 # < 75B entries: slice required
30
  max_entries: 1_000_000_000_000 # Max 1T entries (accommodates largest dataset: 520B)
31
 
32
  slicing:
33
  enabled: true
34
- obs_slice_size: 150000 # Increased for larger memory
35
  overlap: 0
36
  merge_strategy: "combine"
37
 
38
  strategy:
39
  small:
40
- workers_fraction: 1.0 # Use all 42 workers
41
  chunk_size_multiplier: 1.0
42
  priority: 1
43
 
44
  medium:
45
- workers_fraction: 0.7 # ~30 workers
46
  chunk_size_multiplier: 0.85
47
  priority: 2
48
 
49
  large:
50
- workers_fraction: 0.4 # ~17 workers with slicing
51
  chunk_size_multiplier: 0.6
52
  priority: 3
53
  require_slicing: true
 
4
 
5
  resources:
6
  max_memory_gib: 5500 # Leave ~500 GB buffer for system
7
+ max_workers: 48 # Reduced to minimize I/O contention and oversubscription
8
+ chunk_size: 50000 # Optimized for current workload
9
 
10
  paths:
11
  input_dirs:
 
24
  enhanced_metadata_cache: output/cache/enhanced_metadata.parquet
25
 
26
  dataset_thresholds:
27
+ small: 5_000_000_000 # < 5B entries: full speed (reduced overhead)
28
+ medium: 30_000_000_000 # < 30B entries: moderate (sparse matrices safe)
29
+ large: 150_000_000_000 # < 150B entries: slice required
30
  max_entries: 1_000_000_000_000 # Max 1T entries (accommodates largest dataset: 520B)
31
 
32
  slicing:
33
  enabled: true
34
+ obs_slice_size: 300000 # Increased to reduce scheduling overhead and HDF5 opens
35
  overlap: 0
36
  merge_strategy: "combine"
37
 
38
  strategy:
39
  small:
40
+ workers_fraction: 1.0 # Use all 48 workers
41
  chunk_size_multiplier: 1.0
42
  priority: 1
43
 
44
  medium:
45
+ workers_fraction: 0.75 # ~36 workers (reduced I/O pressure)
46
  chunk_size_multiplier: 0.85
47
  priority: 2
48
 
49
  large:
50
+ workers_fraction: 0.5 # ~24 workers (I/O-bound, fewer is faster)
51
  chunk_size_multiplier: 0.6
52
  priority: 3
53
  require_slicing: true
run_eda_slurm.sh CHANGED
@@ -29,6 +29,13 @@ echo "========================================="
29
 
30
  cd /project/GOV108018/whats2000_work/cell_x_gene_visualization
31
 
 
 
 
 
 
 
 
32
  # Create logs directory if it doesn't exist
33
  mkdir -p logs
34
 
 
29
 
30
  cd /project/GOV108018/whats2000_work/cell_x_gene_visualization
31
 
32
+ # Limit BLAS/NumPy threading to prevent oversubscription
33
+ export OMP_NUM_THREADS=1
34
+ export MKL_NUM_THREADS=1
35
+ export OPENBLAS_NUM_THREADS=1
36
+ export NUMEXPR_NUM_THREADS=1
37
+ export VECLIB_MAXIMUM_THREADS=1
38
+
39
  # Create logs directory if it doesn't exist
40
  mkdir -p logs
41
 
scripts/distributed_eda.py CHANGED
@@ -386,12 +386,19 @@ def process_slice(
386
  cell_counts = np.asarray(csr.sum(axis=1)).ravel()
387
  cell_genes = np.diff(csr.indptr).astype(np.int64)
388
 
389
- # Gene stats
390
- csc = csr.tocsc()
391
- gene_n_cells += np.diff(csc.indptr).astype(np.int64)
392
- gene_total_counts += np.asarray(csc.sum(axis=0)).ravel()
 
 
 
 
 
 
 
393
 
394
- del csr, csc, data
395
  else:
396
  arr = np.asarray(chunk, dtype=np.float64)
397
  nz = arr != 0
 
386
  cell_counts = np.asarray(csr.sum(axis=1)).ravel()
387
  cell_genes = np.diff(csr.indptr).astype(np.int64)
388
 
389
+ # Gene stats (optimized: use bincount instead of CSC conversion)
390
+ # Accumulate counts directly from CSR indices/data
391
+ gene_total_counts += np.bincount(
392
+ csr.indices,
393
+ weights=data,
394
+ minlength=n_vars
395
+ )
396
+ gene_n_cells += np.bincount(
397
+ csr.indices,
398
+ minlength=n_vars
399
+ )
400
 
401
+ del csr, data
402
  else:
403
  arr = np.asarray(chunk, dtype=np.float64)
404
  nz = arr != 0