| """Semi-synthetic pseudo-bulk generation from scRNA-seq reference.""" |
| import numpy as np |
| import logging |
| from dataclasses import dataclass |
|
|
| log = logging.getLogger(__name__) |
|
|
|
|
@dataclass
class PseudoBulkDataset:
    """Container for a generated pseudo-bulk dataset and its ground truth.

    Attributes:
        bulk: pseudo-bulk expression matrix, shape (n_samples, n_genes)
        proportions: true cell type proportions, shape (n_samples, K)
        signature: cell-type signature matrix, shape (n_genes, K)
        cell_type_names: list of cell type names
        gene_names: list of gene names
    """

    # Simulated mixtures and the proportions that produced them.
    bulk: np.ndarray
    proportions: np.ndarray
    # Per-cell-type reference profile.
    signature: np.ndarray
    # Axis labels for the K and n_genes dimensions above.
    cell_type_names: list[str]
    gene_names: list[str]
|
|
|
|
def load_scrna_reference(h5ad_path: str, celltype_key: str = "cell_type",
                         min_cells_per_type: int = 50,
                         n_top_genes: int = 2000):
    """Load an scRNA-seq reference from h5ad and return expression and labels.

    Steps, in order: read the file, drop cell types with fewer than
    ``min_cells_per_type`` cells, normalize every cell to a total of 1e4,
    apply log1p, and keep the genes flagged as highly variable.

    NOTE(review): normalization and log1p are applied unconditionally, so
    this presumably expects raw counts in ``adata.X`` -- confirm for inputs
    that are already normalized.

    Args:
        h5ad_path: path to .h5ad file
        celltype_key: obs column with cell type labels
        min_cells_per_type: drop types with fewer cells
        n_top_genes: number of highly variable genes to keep

    Returns:
        expr: dense float64 expression matrix (n_cells, n_genes),
            total-count normalized and log1p-transformed
        labels: cell type labels (n_cells,) as a NumPy array
        gene_names: list of gene names
        cell_type_names: sorted list of retained cell types

    Raises:
        KeyError: if ``celltype_key`` is not a column of ``adata.obs``.
        ValueError: if no cell type passes the ``min_cells_per_type`` filter.
    """
    # Imported lazily so the rest of the module works without scanpy.
    import scanpy as sc

    adata = sc.read_h5ad(h5ad_path)
    log.info(f"Loaded {adata.n_obs} cells, {adata.n_vars} genes")

    # Fail with a message that names the missing column.
    if celltype_key not in adata.obs.columns:
        raise KeyError(
            f"celltype_key {celltype_key!r} not found in adata.obs "
            f"(available columns: {list(adata.obs.columns)})"
        )

    # Filter out rare cell types. For categorical columns value_counts can
    # list unused categories with a count of 0; never retain those, even
    # when min_cells_per_type <= 0.
    type_counts = adata.obs[celltype_key].value_counts()
    retained = (type_counts >= min_cells_per_type) & (type_counts > 0)
    keep_types = type_counts[retained].index.tolist()
    if not keep_types:
        raise ValueError(
            f"no cell type in {celltype_key!r} has at least "
            f"{min_cells_per_type} cells"
        )
    adata = adata[adata.obs[celltype_key].isin(keep_types)].copy()
    log.info(f"Kept {len(keep_types)} cell types, {adata.n_obs} cells")

    # Normalize, log-transform, then restrict to highly variable genes.
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)
    sc.pp.highly_variable_genes(adata, n_top_genes=n_top_genes)
    adata = adata[:, adata.var["highly_variable"]].copy()

    # Densify: sparse matrices expose toarray(), dense arrays pass through.
    expr = adata.X
    if hasattr(expr, "toarray"):
        expr = expr.toarray()
    expr = np.asarray(expr, dtype=np.float64)

    # to_numpy() yields a plain ndarray even for categorical columns, so
    # element-wise comparisons such as ``labels == ct`` behave uniformly.
    labels = adata.obs[celltype_key].to_numpy()
    gene_names = adata.var_names.tolist()
    cell_type_names = sorted(keep_types)

    return expr, labels, gene_names, cell_type_names
|
|
|
|
def build_signature(expr: np.ndarray, labels: np.ndarray,
                    cell_type_names: list[str]) -> np.ndarray:
    """Build signature matrix: mean expression per cell type.

    Args:
        expr: (n_cells, n_genes)
        labels: (n_cells,) cell type label of every cell; any array-like
            (NumPy array, list, pandas Series/Categorical) is accepted
        cell_type_names: ordered cell type names; column k of the result
            corresponds to cell_type_names[k]

    Returns:
        signature: (n_genes, K)

    Raises:
        ValueError: if a requested cell type has no cells in ``labels``
            (its mean would otherwise be NaN).
    """
    # Coerce so that ``labels == ct`` is always an element-wise comparison;
    # a plain Python list would compare as a single scalar False.
    labels = np.asarray(labels)

    n_genes = expr.shape[1]
    K = len(cell_type_names)
    sig = np.zeros((n_genes, K))
    for k, ct in enumerate(cell_type_names):
        mask = labels == ct
        # Averaging an empty selection silently yields a NaN column;
        # report the offending type instead.
        if not np.any(mask):
            raise ValueError(
                f"cell type {ct!r} has no cells in labels; "
                "cannot compute its signature"
            )
        sig[:, k] = expr[mask].mean(axis=0)
    return sig
|
|
|
|
def sample_proportions(n: int, K: int, rng: np.random.Generator,
                       concentration: float = 1.0) -> np.ndarray:
    """Draw random cell type proportions from a Dirichlet distribution.

    Every one of the K components gets the same concentration value.

    Args:
        n: how many proportion vectors to draw
        K: number of cell types (length of each vector)
        rng: numpy random generator used for the draw
        concentration: Dirichlet concentration.
            <1: sparse/peaky, =1: uniform, >1: balanced

    Returns:
        proportions: (n, K), rows sum to 1
    """
    return rng.dirichlet(np.full(shape=K, fill_value=concentration), size=n)
|
|
|
|
def generate_pseudobulk(
    expr: np.ndarray,
    labels: np.ndarray,
    cell_type_names: list[str],
    gene_names: list[str],
    n_samples: int = 5000,
    cells_per_sample: int = 200,
    concentration: float = 1.0,
    noise_sd: float = 0.1,
    seed: int = 2026,
) -> PseudoBulkDataset:
    """Generate pseudo-bulk dataset with known ground-truth proportions.

    For each sample:
    1. Draw target proportions from Dirichlet(concentration)
    2. Draw per-type cell counts from a multinomial with those proportions
       and sample that many cells of each type (with replacement)
    3. Average the sampled cells' expression (sum / cells_per_sample)
    4. If noise_sd > 0, add Gaussian noise and clip negatives to 0

    The reported proportions are the realized ones (counts divided by the
    number of cells drawn), not the Dirichlet targets. The signature is the
    per-type mean over all cells in ``expr``.

    Args:
        expr: scRNA-seq expression (n_cells, n_genes)
        labels: cell type labels (n_cells,); any array-like is accepted
        cell_type_names: ordered cell type names
        gene_names: gene names
        n_samples: number of pseudo-bulk samples
        cells_per_sample: cells mixed per sample, must be positive
        concentration: Dirichlet parameter
        noise_sd: standard deviation of Gaussian noise added to the
            averaged expression values (values <= 0 = no noise)
        seed: random seed

    Returns:
        PseudoBulkDataset

    Raises:
        ValueError: if ``cells_per_sample`` is not positive, or if a name
            in ``cell_type_names`` has no cells in ``labels``.
    """
    # Zero would make both the proportions and the averaged expression
    # NaN (0 / 0); negative values are rejected explicitly here rather
    # than deep inside the sampling loop.
    if cells_per_sample <= 0:
        raise ValueError(
            f"cells_per_sample must be positive, got {cells_per_sample}"
        )

    # Coerce so that ``labels == ct`` is an element-wise comparison even
    # when labels is a list or a pandas object.
    labels = np.asarray(labels)

    rng = np.random.default_rng(seed)
    K = len(cell_type_names)
    n_genes = expr.shape[1]

    # Row indices of the cells belonging to each type.
    type_indices = {ct: np.flatnonzero(labels == ct) for ct in cell_type_names}

    # Fail fast: sampling from an empty pool would otherwise abort
    # somewhere in the middle of the loop with an obscure message.
    empty_types = [ct for ct, idx in type_indices.items() if idx.size == 0]
    if empty_types:
        raise ValueError(f"no cells found for cell types: {empty_types}")

    # Target proportions for every sample, drawn up front.
    props = sample_proportions(n_samples, K, rng, concentration)

    bulk = np.zeros((n_samples, n_genes))
    actual_props = np.zeros((n_samples, K))

    for i in range(n_samples):
        # Integer cell counts per type; the realized proportions are the
        # ground truth that matches the cells actually mixed.
        counts = rng.multinomial(cells_per_sample, props[i])
        actual_props[i] = counts / counts.sum()

        sample_expr = np.zeros(n_genes)
        for k, ct in enumerate(cell_type_names):
            if counts[k] > 0:
                idx = rng.choice(type_indices[ct], size=counts[k], replace=True)
                sample_expr += expr[idx].sum(axis=0)

        # Per-cell average, so the scale does not depend on cells_per_sample.
        bulk[i] = sample_expr / cells_per_sample

    if noise_sd > 0:
        bulk = bulk + rng.normal(0, noise_sd, bulk.shape)
        # Noise can push values below zero; clip them back to 0.
        bulk = np.maximum(bulk, 0)

    sig = build_signature(expr, labels, cell_type_names)

    log.info(f"Generated {n_samples} pseudo-bulk samples, {K} types, {n_genes} genes")

    return PseudoBulkDataset(
        bulk=bulk,
        proportions=actual_props,
        signature=sig,
        cell_type_names=cell_type_names,
        gene_names=gene_names,
    )
|
|