File size: 6,009 Bytes
"""Semi-synthetic pseudo-bulk generation from scRNA-seq reference."""
import numpy as np
import logging
from dataclasses import dataclass
log = logging.getLogger(__name__)
@dataclass
class PseudoBulkDataset:
    """Bundle returned by pseudo-bulk generation.

    Attributes:
        bulk: pseudo-bulk expression matrix (n_samples, n_genes)
        proportions: true cell type proportions (n_samples, K)
        signature: cell-type signature matrix (n_genes, K)
        cell_type_names: list of cell type names
        gene_names: list of gene names
    """

    # Simulated mixtures, one row per pseudo-bulk sample.
    bulk: np.ndarray
    # Ground-truth composition of each sample.
    proportions: np.ndarray
    # Per-cell-type reference expression profiles.
    signature: np.ndarray
    # Names of the cell types.
    cell_type_names: list[str]
    # Names of the genes.
    gene_names: list[str]
def load_scrna_reference(h5ad_path: str, celltype_key: str = "cell_type",
                         min_cells_per_type: int = 50,
                         n_top_genes: int = 2000):
    """Read an scRNA-seq reference from an .h5ad file and prepare it for mixing.

    Cell types with fewer than ``min_cells_per_type`` cells are dropped. The
    remaining cells are total-count normalized (target sum 1e4) and
    log1p-transformed, then restricted to the genes flagged as highly variable.

    Args:
        h5ad_path: path to .h5ad file
        celltype_key: obs column with cell type labels
        min_cells_per_type: drop types with fewer cells
        n_top_genes: number of highly variable genes to keep
    Returns:
        expr: dense float64 expression matrix (n_cells, n_genes) holding the
            normalized, log1p-transformed values
        labels: cell type labels (n_cells,)
        gene_names: list of gene names
        cell_type_names: sorted list of retained cell types
    """
    # Function-scope import: scanpy is only required when this loader is called.
    import scanpy as sc

    adata = sc.read_h5ad(h5ad_path)
    log.info(f"Loaded {adata.n_obs} cells, {adata.n_vars} genes")

    # Discard sparsely populated cell types.
    cells_per_type = adata.obs[celltype_key].value_counts()
    populated = cells_per_type[cells_per_type >= min_cells_per_type]
    keep_types = populated.index.tolist()
    in_kept_type = adata.obs[celltype_key].isin(keep_types)
    adata = adata[in_kept_type].copy()
    log.info(f"Kept {len(keep_types)} cell types, {adata.n_obs} cells")

    # Normalize, log-transform, then restrict to highly variable genes.
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)
    sc.pp.highly_variable_genes(adata, n_top_genes=n_top_genes)
    hvg_mask = adata.var["highly_variable"]
    adata = adata[:, hvg_mask].copy()

    # Sparse matrices expose toarray(); anything else is passed through as-is.
    matrix = adata.X
    dense = matrix.toarray() if hasattr(matrix, "toarray") else matrix

    return (
        np.asarray(dense, dtype=np.float64),
        adata.obs[celltype_key].values,
        adata.var_names.tolist(),
        sorted(keep_types),
    )
def build_signature(expr: np.ndarray, labels: np.ndarray,
                    cell_type_names: list[str]) -> np.ndarray:
    """Compute the mean expression profile of every cell type.

    Args:
        expr: expression matrix (n_cells, n_genes)
        labels: cell type label of each cell (n_cells,)
        cell_type_names: cell types, in the column order of the result
    Returns:
        signature: (n_genes, K); column k is the mean over the cells whose
            label equals cell_type_names[k]
    """
    signature = np.zeros((expr.shape[1], len(cell_type_names)))
    for column, cell_type in enumerate(cell_type_names):
        # Rows of expr belonging to this cell type.
        members = expr[labels == cell_type]
        signature[:, column] = members.mean(axis=0)
    return signature
def sample_proportions(n: int, K: int, rng: np.random.Generator,
                       concentration: float = 1.0) -> np.ndarray:
    """Draw composition vectors from a symmetric Dirichlet distribution.

    Args:
        n: number of samples
        K: number of cell types
        rng: numpy random generator
        concentration: Dirichlet concentration shared by all K components.
            <1: sparse/peaky, =1: uniform, >1: balanced
    Returns:
        proportions: (n, K), rows sum to 1
    """
    # Every component gets the same concentration value.
    symmetric_alpha = concentration * np.ones(K)
    draws = rng.dirichlet(symmetric_alpha, size=n)
    return draws
def generate_pseudobulk(
    expr: np.ndarray,
    labels: np.ndarray,
    cell_type_names: list[str],
    gene_names: list[str],
    n_samples: int = 5000,
    cells_per_sample: int = 200,
    concentration: float = 1.0,
    noise_sd: float = 0.1,
    seed: int = 2026,
) -> PseudoBulkDataset:
    """Generate pseudo-bulk dataset with known ground-truth proportions.

    For each sample:
      1. Draw proportions from Dirichlet(concentration)
      2. Draw per-type cell counts from a multinomial over those proportions
      3. Sample that many cells of each type (with replacement) and average
         their expression, i.e. sum divided by cells_per_sample
    After all samples are built, optional Gaussian noise is added to the whole
    matrix and the result is clipped at zero.

    The returned ``proportions`` are the realized cell fractions from the
    multinomial draw, not the raw Dirichlet draws. The returned ``signature``
    is built from the full ``expr``/``labels`` reference.

    Args:
        expr: scRNA-seq expression (n_cells, n_genes)
        labels: cell type labels (n_cells,)
        cell_type_names: ordered cell type names
        gene_names: gene names
        n_samples: number of pseudo-bulk samples
        cells_per_sample: cells mixed per sample; must be >= 1
        concentration: Dirichlet parameter
        noise_sd: standard deviation of the Gaussian noise added to the
            pseudo-bulk values (0 = no noise)
        seed: random seed
    Returns:
        PseudoBulkDataset
    Raises:
        ValueError: if cells_per_sample < 1, or if cells must be drawn from a
            cell type that has no cells in ``labels``.
    """
    # A zero-cell sample would divide 0 by 0 below and silently yield NaNs.
    if cells_per_sample < 1:
        raise ValueError(
            f"cells_per_sample must be >= 1, got {cells_per_sample}"
        )

    rng = np.random.default_rng(seed)
    K = len(cell_type_names)
    n_genes = expr.shape[1]

    # Row indices of the reference cells belonging to each type.
    type_indices = {ct: np.where(labels == ct)[0] for ct in cell_type_names}

    # Target proportions for every sample, drawn up front.
    props = sample_proportions(n_samples, K, rng, concentration)

    bulk = np.zeros((n_samples, n_genes))
    actual_props = np.zeros((n_samples, K))
    for i in range(n_samples):
        # Number of cells per type (multinomial)
        counts = rng.multinomial(cells_per_sample, props[i])
        actual_props[i] = counts / counts.sum()

        # Sample cells of each type with replacement and accumulate expression.
        sample_expr = np.zeros(n_genes)
        for k, ct in enumerate(cell_type_names):
            if counts[k] == 0:
                continue
            pool = type_indices[ct]
            if pool.size == 0:
                # Fail with the offending type's name instead of numpy's
                # generic empty-population error from rng.choice.
                raise ValueError(
                    f"cell type {ct!r} has no cells in labels but "
                    f"{counts[k]} cells were requested from it"
                )
            idx = rng.choice(pool, size=counts[k], replace=True)
            sample_expr += expr[idx].sum(axis=0)
        bulk[i] = sample_expr / cells_per_sample

    # Optional noise; clip so expression values never go negative.
    if noise_sd > 0:
        bulk = bulk + rng.normal(0, noise_sd, bulk.shape)
        bulk = np.maximum(bulk, 0)

    # Build signature
    sig = build_signature(expr, labels, cell_type_names)
    log.info(f"Generated {n_samples} pseudo-bulk samples, {K} types, {n_genes} genes")
    return PseudoBulkDataset(
        bulk=bulk,
        proportions=actual_props,
        signature=sig,
        cell_type_names=cell_type_names,
        gene_names=gene_names,
    )
|