MuLGIT / mulgit /perturb /data.py
vedatonuryilmaz's picture
Upload mulgit/perturb/data.py
5bc61e0 verified
"""
Perturbation Data Loader for MuLGIT-Perturb.
Loads perturbation data from Tahoe-100M and converts it to the format
required by MuLGITPerturb: (baseline omics, perturbation descriptor, delta expression).
Tahoe-100M schema:
expression_data split:
- genes: int64[] — gene indices
- expressions: float32[] — post-perturbation expression values
- drug: str — drug name
- cell_line_id: str
- canonical_smiles: str — drug SMILES
- pubchem_cid: int64
- moa-fine: str — mechanism of action (fine-grained)
- sample: str — sample/cell-line identifier
For paired baseline data, use the vehicle control (DMSO) samples per cell line.
"""
import torch
from torch.utils.data import Dataset, DataLoader, IterableDataset
from typing import Optional, Dict, List, Tuple, Iterator
import numpy as np
from collections import defaultdict
class TahoePerturbationDataset(IterableDataset):
    """
    Streaming IterableDataset for Tahoe-100M perturbation data.

    Since Tahoe-100M is 320GB (95.6M rows), we stream it and pair every
    drug-treated row with a cached vehicle-control baseline.

    Design:
        1. Stream vehicle-control rows and average them per
           (cell_line_id, sample) to build the baseline cache.
        2. Stream again, skipping vehicle rows and rows without SMILES.
        3. Match each treated row to a baseline: exact (cell_line_id, sample)
           first, then any baseline cached for the same cell line.
        4. Yield a dict with baseline / post / delta expression tensors and
           the drug metadata.

    NOTE(review): ``expressions`` is treated as a dense vector whose positions
    line up across rows; the ``genes`` index column is never read. If rows are
    variable-length (sparse), baseline averaging and the delta subtraction
    will raise on the shape mismatch -- confirm the row layout before training.

    Args:
        dataset_path: HF dataset path (e.g., "tahoebio/Tahoe-100M")
        config: dataset config name
        split: dataset split
        max_samples: cap on yielded samples (None = unlimited). With a
            multi-worker DataLoader the cap applies to each worker.
        n_genes: number of genes to return (None = all)
        gene_list: specific gene indices to keep (None = first ``n_genes``).
            When both are given, the first ``n_genes`` entries of
            ``gene_list`` are used.
        cache_baselines: build the baseline cache on first iteration when it
            is still empty.
    """

    # Lower-cased drug names that mark a vehicle-control (baseline) row.
    # NOTE(review): "dmso_tf" is included because Tahoe-100M appears to label
    # its DMSO control that way -- confirm against the drug_metadata config.
    VEHICLE_NAMES = frozenset({"dmso", "dmso_tf", "vehicle", "control", "untreated"})

    def __init__(
        self,
        dataset_path: str = "tahoebio/Tahoe-100M",
        config: str = "expression_data",
        split: str = "train",
        max_samples: Optional[int] = None,
        n_genes: Optional[int] = 978,  # L1000 landmark size
        gene_list: Optional[List[int]] = None,
        cache_baselines: bool = True,
    ):
        super().__init__()
        self.dataset_path = dataset_path
        self.config = config
        self.split = split
        self.max_samples = max_samples
        self.n_genes = n_genes
        self.gene_list = gene_list
        self.cache_baselines = cache_baselines
        self._baselines = {}  # (cell_line_id, sample) -> baseline expression vector
        self._smiles_lookup = {}  # drug name -> canonical SMILES (filled lazily)

    @classmethod
    def _is_vehicle(cls, row) -> bool:
        """Return True when the row's drug name marks a vehicle control."""
        # `or ""` guards against an explicit None in the drug column.
        drug = row.get("drug") or ""
        return str(drug).strip().lower() in cls.VEHICLE_NAMES

    def _gene_indices(self, n_available: int):
        """Return the gene positions to keep, or None to keep every gene."""
        if self.gene_list is not None:
            indices = np.asarray(self.gene_list, dtype=np.int64)
            if self.n_genes is not None:
                indices = indices[: self.n_genes]
            return indices
        if self.n_genes is not None and self.n_genes < n_available:
            return np.arange(self.n_genes)
        return None

    def _load_dataset(self):
        """Lazy-load the HF dataset and the drug-name -> SMILES lookup."""
        if hasattr(self, "_ds"):
            return

        import warnings

        from datasets import load_dataset

        self._ds = load_dataset(
            self.dataset_path,
            self.config,
            split=self.split,
            streaming=True,
        )

        # Drug metadata is optional: rows normally carry their own SMILES, so
        # a failure here is reported but does not abort loading.
        lookup = {}
        try:
            self._drug_meta = load_dataset(
                self.dataset_path,
                "drug_metadata",
                split="train",
            )
            for row in self._drug_meta:
                if "drug" in row and "canonical_smiles" in row:
                    lookup[row["drug"]] = row["canonical_smiles"]
        except Exception as exc:
            warnings.warn(
                f"Could not load drug_metadata for SMILES lookup: {exc!r}",
                RuntimeWarning,
                stacklevel=2,
            )
            lookup = {}
        self._smiles_lookup = lookup

    def _build_baseline_cache(self, n_prefetch: int = 10000):
        """
        Build the baseline expression cache from vehicle-control rows.

        Streams the dataset, collects rows whose drug name is in
        ``VEHICLE_NAMES`` and stops once ``n_prefetch`` such rows have been
        collected. Rows sharing a (cell_line_id, sample) key are averaged.

        Args:
            n_prefetch: number of vehicle rows to collect before stopping.
        """
        self._load_dataset()
        grouped = defaultdict(list)
        collected = 0
        for row in self._ds:
            if not self._is_vehicle(row) or row.get("expressions") is None:
                continue
            vector = np.array(row["expressions"], dtype=np.float32)
            key = (row.get("cell_line_id", "unknown"), row.get("sample", "default"))
            grouped[key].append(vector)
            collected += 1
            if collected >= n_prefetch:
                break

        # Average multiple baseline measurements per (cell_line, sample)
        self._baselines = {
            key: np.mean(vectors, axis=0) for key, vectors in grouped.items()
        }

    def __iter__(self) -> Iterator[Dict]:
        """
        Yield perturbation samples with a matched baseline.

        Each item is a dict with keys ``baseline_expression``,
        ``post_expression``, ``delta_expression`` (tensors), ``drug``,
        ``smiles``, ``cell_line_id``, ``moa`` and ``pubchem_cid``. Rows that
        are vehicle controls, lack expressions or SMILES, or have no baseline
        for their cell line are skipped.
        """
        self._load_dataset()

        # Build baseline cache if needed
        if self.cache_baselines and not self._baselines:
            self._build_baseline_cache()

        # First cached baseline per cell line, used when a treated row's
        # sample id has no control of its own.
        per_cell_line = {}
        for (cell_line_id, _sample), vector in self._baselines.items():
            per_cell_line.setdefault(cell_line_id, vector)

        # Under a multi-worker DataLoader every worker holds a full copy of
        # this dataset; shard rows by index so samples are not duplicated.
        worker = torch.utils.data.get_worker_info()
        n_shards = 1 if worker is None else worker.num_workers
        shard_id = 0 if worker is None else worker.id

        count = 0
        for row_idx, row in enumerate(self._ds):
            if self.max_samples is not None and count >= self.max_samples:
                break
            if row_idx % n_shards != shard_id:
                continue

            # Skip vehicle controls
            if self._is_vehicle(row):
                continue

            # Get expression values
            if row.get("expressions") is None:
                continue
            expr_post = np.array(row["expressions"], dtype=np.float32)
            cell_line = row.get("cell_line_id", "unknown")
            sample = row.get("sample", "default")

            # Fall back to the metadata lookup when the row's own SMILES is
            # missing *or* empty/None.
            smiles = row.get("canonical_smiles") or self._smiles_lookup.get(
                row.get("drug", ""), ""
            )
            if not smiles:
                continue

            # Get baseline expression: exact key, legacy "default" key, then
            # any baseline for the cell line.
            expr_baseline = self._baselines.get((cell_line, sample))
            if expr_baseline is None:
                expr_baseline = self._baselines.get((cell_line, "default"))
            if expr_baseline is None:
                expr_baseline = per_cell_line.get(cell_line)
            if expr_baseline is None:
                # No baseline available, skip
                continue

            # Compute delta
            delta = expr_post - expr_baseline

            # Subset genes if requested
            indices = self._gene_indices(len(expr_post))
            if indices is not None:
                expr_post = expr_post[indices]
                expr_baseline = expr_baseline[indices]
                delta = delta[indices]

            yield {
                "baseline_expression": torch.tensor(expr_baseline),
                "post_expression": torch.tensor(expr_post),
                "delta_expression": torch.tensor(delta),
                "drug": row.get("drug", ""),
                "smiles": smiles,
                "cell_line_id": cell_line,
                "moa": row.get("moa-fine", ""),
                "pubchem_cid": row.get("pubchem_cid", None),
            }
            count += 1
class SyntheticPerturbationDataset(Dataset):
    """
    Synthetic perturbation dataset for testing and development.

    Generates random baseline expression + drug-induced expression changes
    with known ground truth for validation.

    Args:
        n_samples: number of samples to generate (0 gives an empty dataset).
        n_genes: length of every expression vector.
        n_drugs: number of distinct synthetic drugs; each sample is assigned
            one uniformly at random.
        noise_std: standard deviation of the Gaussian noise added to each
            sample's drug effect.
        seed: seed for the random generators.
    """

    def __init__(
        self,
        n_samples: int = 1000,
        n_genes: int = 978,
        n_drugs: int = 50,
        noise_std: float = 0.1,
        seed: int = 42,
    ):
        # NOTE(review): these calls reseed the *global* NumPy and torch RNGs
        # as a side effect of construction. Kept for backward compatibility;
        # a local generator would be cleaner if no caller relies on it.
        np.random.seed(seed)
        torch.manual_seed(seed)
        self.n_samples = n_samples
        self.n_genes = n_genes

        # Generate baseline expression (log-normal)
        baseline = np.random.lognormal(mean=0.0, sigma=0.5, size=(n_samples, n_genes))

        # Generate drug effects: each drug affects a sparse set of genes
        drug_effects = np.random.randn(n_drugs, n_genes) * 0.5
        sparsity_mask = np.random.random((n_drugs, n_genes)) < 0.1  # 10% of genes affected
        drug_effects = drug_effects * sparsity_mask

        # Generate drug assignments
        drug_ids = np.random.randint(0, n_drugs, size=n_samples)

        # Generate deltas. Row-indexing keeps the (n_samples, n_genes) shape
        # even when n_samples == 0, where building the rows from a Python list
        # would collapse to shape (0,) and fail to broadcast with the noise.
        noise = np.random.randn(n_samples, n_genes) * noise_std
        deltas = drug_effects[drug_ids] + noise
        post_expression = baseline + deltas

        self.baseline = torch.tensor(baseline, dtype=torch.float32)
        self.post = torch.tensor(post_expression, dtype=torch.float32)
        self.delta = torch.tensor(deltas, dtype=torch.float32)
        self.drug_ids = torch.tensor(drug_ids, dtype=torch.long)
        # Placeholder per-drug strings: a fixed stem with the drug index
        # appended. NOTE(review): not validated as chemically valid SMILES.
        self.smiles = [f"CN1C(=O)CN=C(C2CCCCC2)c3ccccc3{str(i)}" for i in range(n_drugs)]

    def __len__(self):
        """Return the number of generated samples."""
        return self.n_samples

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict of tensors and metadata."""
        drug_id = self.drug_ids[idx].item()
        return {
            "baseline_expression": self.baseline[idx],
            "post_expression": self.post[idx],
            "delta_expression": self.delta[idx],
            "drug": f"Drug_{drug_id}",
            "smiles": self.smiles[drug_id % len(self.smiles)],
            "cell_line_id": "synthetic",
            "moa": "synthetic",
        }
class PerturbationCollator:
    """
    Collates perturbation data into model-ready batches.

    Expression vectors that already share one shape are stacked unchanged.
    When a batch is ragged, each vector is zero-padded or truncated along its
    last dimension to ``n_genes`` before stacking. Missing string fields
    (``smiles``, ``drug``, ``cell_line_id``) default to ``""``.

    Args:
        n_genes: target vector length used only when a batch is ragged.
    """

    def __init__(self, n_genes: int = 978):
        self.n_genes = n_genes

    def _stack(self, batch: List[Dict], key: str) -> torch.Tensor:
        """Stack ``item[key]`` across the batch, fitting ragged vectors to n_genes."""
        vectors = [torch.as_tensor(item[key]) for item in batch]
        if len({tuple(v.shape) for v in vectors}) <= 1:
            # Uniform (or empty) batch: same result as a plain torch.stack.
            return torch.stack(vectors)

        fitted = []
        for vector in vectors:
            vector = vector[..., : self.n_genes]
            missing = self.n_genes - vector.shape[-1]
            if missing > 0:
                # Zero-pad on the right of the last dimension.
                vector = torch.nn.functional.pad(vector, (0, missing))
            fitted.append(vector)
        return torch.stack(fitted)

    def __call__(self, batch: List[Dict]) -> Dict[str, object]:
        """
        Collate a list of sample dicts into one batch dict.

        Returns:
            Dict with stacked tensors under ``baseline_expression``,
            ``post_expression`` and ``delta_expression``, and per-sample
            lists under ``smiles``, ``drugs`` and ``cell_lines``.

        Raises:
            KeyError: if a sample lacks one of the three expression fields.
        """
        # Stack expression tensors
        baseline = self._stack(batch, "baseline_expression")
        post = self._stack(batch, "post_expression")
        delta_true = self._stack(batch, "delta_expression")

        # Collect non-tensor fields
        smiles_list = [item.get("smiles", "") for item in batch]
        drugs = [item.get("drug", "") for item in batch]
        cell_lines = [item.get("cell_line_id", "") for item in batch]

        return {
            "baseline_expression": baseline,
            "post_expression": post,
            "delta_expression": delta_true,
            "smiles": smiles_list,
            "drugs": drugs,
            "cell_lines": cell_lines,
        }
def create_perturbation_dataloader(
    dataset: Dataset,
    batch_size: int = 256,
    shuffle: bool = True,
    num_workers: int = 4,
    n_genes: int = 978,
    drop_last: bool = True,
    pin_memory: bool = True,
) -> DataLoader:
    """
    Create a DataLoader for perturbation data.

    Uses PerturbationCollator for correct batching.

    Args:
        dataset: map-style or iterable dataset yielding perturbation dicts.
        batch_size: samples per batch.
        shuffle: shuffle each epoch; ignored for IterableDataset instances,
            which DataLoader does not allow to be shuffled.
        num_workers: number of loader worker processes.
        n_genes: forwarded to PerturbationCollator.
        drop_last: drop the final incomplete batch. Pass False for
            evaluation so trailing samples are not discarded.
        pin_memory: forwarded to DataLoader. Pass False on CPU-only hosts.

    Returns:
        A configured ``torch.utils.data.DataLoader``.
    """
    is_streaming = isinstance(dataset, IterableDataset)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle and not is_streaming,
        num_workers=num_workers,
        collate_fn=PerturbationCollator(n_genes=n_genes),
        pin_memory=pin_memory,
        drop_last=drop_last,
    )