| """ |
| Perturbation Data Loader for MuLGIT-Perturb. |
| |
| Loads perturbation data from Tahoe-100M and converts it to the format |
| required by MuLGITPerturb: (baseline omics, perturbation descriptor, delta expression). |
| |
| Tahoe-100M schema: |
| expression_data split: |
| - genes: int64[] — gene indices |
| - expressions: float32[] — post-perturbation expression values |
| - drug: str — drug name |
| - cell_line_id: str |
| - canonical_smiles: str — drug SMILES |
| - pubchem_cid: int64 |
| - moa-fine: str — mechanism of action (fine-grained) |
| - sample: str — sample/cell-line identifier |
| |
| For paired baseline data, use the vehicle control (DMSO) samples per cell line. |
| """ |
|
|
| import torch |
| from torch.utils.data import Dataset, DataLoader, IterableDataset |
| from typing import Optional, Dict, List, Tuple, Iterator |
| import numpy as np |
| from collections import defaultdict |
|
|
|
|
class TahoePerturbationDataset(IterableDataset):
    """
    Streaming IterableDataset for Tahoe-100M perturbation data.

    Since Tahoe-100M is 320GB (95.6M rows), we stream it and apply
    in-memory grouping to create (baseline, post-perturbation) pairs.

    Design:
        1. Stream vehicle-control rows (drug name DMSO / vehicle / control /
           untreated) and average their expression into baseline profiles,
           keyed both by (cell_line_id, sample) and by cell_line_id alone.
        2. Stream drug-treated rows and match each one to a baseline: the
           exact (cell_line_id, sample) profile when one exists, otherwise
           the cell-line-level average.
        3. Yield dicts with baseline, post-perturbation and delta expression
           tensors plus drug metadata.

    Gene handling:
        Rows that carry a ``genes`` field are treated as (gene index, value)
        pairs and scattered onto the selected gene indices (genes missing
        from a row are filled with 0), so rows with different or reordered
        gene lists become comparable element-wise. Rows without ``genes``
        are indexed positionally. Rows whose selected vector does not match
        the shape of their baseline are skipped.

    Args:
        dataset_path: HF dataset path (e.g., "tahoebio/Tahoe-100M")
        config: dataset config name
        split: dataset split
        max_samples: cap on the number of yielded samples (None = unlimited)
        n_genes: number of genes to return (None = all, or all of gene_list)
        gene_list: specific gene indices to select (None = indices
            0..n_genes-1); truncated to ``n_genes`` when both are given
        cache_baselines: if True, baseline profiles are computed once and
            reused across iterations; if False they are recomputed on every
            pass over the dataset.
    """

    # Lower-cased drug names that mark vehicle-control (baseline) rows.
    _VEHICLE_NAMES = frozenset({"dmso", "vehicle", "control", "untreated"})

    def __init__(
        self,
        dataset_path: str = "tahoebio/Tahoe-100M",
        config: str = "expression_data",
        split: str = "train",
        max_samples: Optional[int] = None,
        n_genes: Optional[int] = 978,
        gene_list: Optional[List[int]] = None,
        cache_baselines: bool = True,
    ):
        self.dataset_path = dataset_path
        self.config = config
        self.split = split
        self.max_samples = max_samples
        self.n_genes = n_genes
        self.gene_list = gene_list
        self.cache_baselines = cache_baselines
        # (cell_line_id, sample) -> mean vehicle expression vector.
        self._baselines = {}
        # cell_line_id -> mean vehicle expression over all of its samples;
        # used when a drug row's sample has no matching vehicle sample.
        self._cell_line_baselines = {}
        # True once _build_baseline_cache has run (even if it found nothing),
        # so an empty result does not trigger a rescan every epoch.
        self._baselines_built = False

    def _load_dataset(self):
        """
        Lazily open the streaming expression split and build the SMILES lookup.

        Does nothing when ``_ds`` is already set. The drug-metadata lookup is
        best-effort: if loading it fails, a RuntimeWarning is issued and the
        lookup is left empty, so only rows carrying their own
        ``canonical_smiles`` can be used.
        """
        if hasattr(self, "_ds"):
            return

        import warnings

        from datasets import load_dataset

        self._ds = load_dataset(
            self.dataset_path,
            self.config,
            split=self.split,
            streaming=True,
        )

        # drug name -> canonical SMILES, from the (non-streamed) drug_metadata config.
        lookup = {}
        try:
            self._drug_meta = load_dataset(
                self.dataset_path,
                "drug_metadata",
                split="train",
            )
            for row in self._drug_meta:
                if "drug" in row and "canonical_smiles" in row:
                    lookup[row["drug"]] = row["canonical_smiles"]
        except Exception as err:  # best-effort: rows may carry their own SMILES
            warnings.warn(
                f"Could not load drug_metadata SMILES lookup: {err}",
                RuntimeWarning,
            )
            lookup = {}
        self._smiles_lookup = lookup

    @classmethod
    def _is_vehicle(cls, row: Dict) -> bool:
        """Return True if the row's drug name marks a vehicle control (None-safe)."""
        return str(row.get("drug") or "").lower() in cls._VEHICLE_NAMES

    def _target_genes(self) -> Optional[np.ndarray]:
        """Return the gene indices to keep, or None to keep every value."""
        if self.gene_list is not None:
            selected = np.asarray(self.gene_list, dtype=np.int64)
            return selected if self.n_genes is None else selected[: self.n_genes]
        if self.n_genes is not None:
            return np.arange(self.n_genes, dtype=np.int64)
        return None

    def _row_expression(self, row: Dict, target: Optional[np.ndarray]) -> np.ndarray:
        """
        Extract a row's expression vector restricted to ``target`` genes.

        With a ``genes`` field, values are scattered onto ``target`` order
        (absent genes -> 0). Without it, ``target`` indexes positions.
        """
        values = np.asarray(row["expressions"], dtype=np.float32)
        if target is None:
            return values

        genes = row.get("genes")
        if genes is not None:
            dense = np.zeros(target.size, dtype=np.float32)
            if target.size == 0:
                return dense
            genes = np.asarray(genes, dtype=np.int64)
            # Vectorized lookup of each row gene's position within target.
            order = np.argsort(target, kind="stable")
            sorted_target = target[order]
            pos = np.minimum(np.searchsorted(sorted_target, genes), target.size - 1)
            hit = sorted_target[pos] == genes
            dense[order[pos[hit]]] = values[hit]
            return dense

        if self.gene_list is None and target.size >= values.size:
            # Without an explicit gene_list, rows with at most n_genes values
            # are returned unchanged rather than indexed out of range.
            return values
        return values[target]

    @staticmethod
    def _average_profiles(groups: Dict) -> Dict:
        """Average each group's vectors; groups with ragged shapes are dropped."""
        means = {}
        for key, profiles in groups.items():
            if len({p.shape for p in profiles}) != 1:
                continue  # cannot average vectors of different lengths element-wise
            means[key] = np.mean(profiles, axis=0)
        return means

    def _build_baseline_cache(self, n_prefetch: int = 10000):
        """
        Build baseline expression profiles from vehicle-control rows.

        In Tahoe-100M, vehicle control samples (DMSO) serve as baseline. We
        identify them by drug name (case-insensitive) in ``_VEHICLE_NAMES``.
        Streaming stops after ``n_prefetch`` vehicle rows have been collected.
        Profiles are stored already restricted to the selected genes.
        """
        self._load_dataset()
        target = self._target_genes()

        per_sample = defaultdict(list)
        per_cell_line = defaultdict(list)
        count = 0

        for row in self._ds:
            if not self._is_vehicle(row) or "expressions" not in row:
                continue
            expr = self._row_expression(row, target)
            cell_line = row.get("cell_line_id", "unknown")
            per_sample[(cell_line, row.get("sample", "default"))].append(expr)
            per_cell_line[cell_line].append(expr)
            count += 1
            if count >= n_prefetch:
                break

        self._baselines = self._average_profiles(per_sample)
        self._cell_line_baselines = self._average_profiles(per_cell_line)
        self._baselines_built = True

    def __iter__(self) -> Iterator[Dict]:
        """Yield drug-treated samples paired with a matched baseline profile."""
        self._load_dataset()

        already_cached = self._baselines_built or bool(self._baselines)
        if not (self.cache_baselines and already_cached):
            self._build_baseline_cache()

        target = self._target_genes()
        smiles_lookup = getattr(self, "_smiles_lookup", {})

        count = 0
        for row in self._ds:
            if self.max_samples is not None and count >= self.max_samples:
                break

            # Vehicle rows are baselines, not perturbations.
            if self._is_vehicle(row) or "expressions" not in row:
                continue

            drug = row.get("drug") or ""
            # Fall back to the metadata lookup when the row's SMILES is missing OR empty.
            smiles = row.get("canonical_smiles") or smiles_lookup.get(drug, "")
            if not smiles:
                continue

            cell_line = row.get("cell_line_id", "unknown")
            sample = row.get("sample", "default")

            # Exact (cell line, sample) match first, then the cell-line average.
            expr_baseline = self._baselines.get((cell_line, sample))
            if expr_baseline is None:
                expr_baseline = self._cell_line_baselines.get(cell_line)
            if expr_baseline is None:
                expr_baseline = self._baselines.get((cell_line, "default"))
            if expr_baseline is None:
                continue  # no vehicle control seen for this cell line

            expr_post = self._row_expression(row, target)
            if expr_post.shape != expr_baseline.shape:
                continue  # not comparable element-wise
            delta = expr_post - expr_baseline

            yield {
                # torch.tensor copies, so consumers cannot mutate the cached baseline.
                "baseline_expression": torch.tensor(expr_baseline),
                "post_expression": torch.tensor(expr_post),
                "delta_expression": torch.tensor(delta),
                "drug": drug,
                "smiles": smiles,
                "cell_line_id": cell_line,
                "moa": row.get("moa-fine", ""),
                "pubchem_cid": row.get("pubchem_cid", None),
            }
            count += 1
|
|
|
|
class SyntheticPerturbationDataset(Dataset):
    """
    Synthetic perturbation dataset for testing and development.

    Generates log-normal baseline expression plus sparse drug-induced
    expression changes with known ground truth for validation. Each sample
    is assigned one of ``n_drugs`` drugs; its delta is that drug's effect
    vector (about 10% non-zero entries) plus Gaussian noise.

    Randomness comes from a private ``np.random.RandomState(seed)``. The
    data therefore matches what seeding the legacy global NumPy RNG would
    produce, while the global NumPy / PyTorch RNG state is left untouched.

    Args:
        n_samples: number of samples
        n_genes: length of each expression vector
        n_drugs: number of distinct synthetic drugs
        noise_std: standard deviation of the Gaussian noise added to deltas
        seed: seed for the private random generator
    """

    def __init__(
        self,
        n_samples: int = 1000,
        n_genes: int = 978,
        n_drugs: int = 50,
        noise_std: float = 0.1,
        seed: int = 42,
    ):
        # Local generator: reproducible without mutating global RNG state.
        rng = np.random.RandomState(seed)

        self.n_samples = n_samples
        self.n_genes = n_genes
        self.n_drugs = n_drugs

        baseline = rng.lognormal(mean=0.0, sigma=0.5, size=(n_samples, n_genes))

        # Sparse per-drug effect vectors: ~10% of genes respond to each drug.
        drug_effects = rng.randn(n_drugs, n_genes) * 0.5
        sparsity_mask = rng.random_sample((n_drugs, n_genes)) < 0.1
        drug_effects = drug_effects * sparsity_mask

        drug_ids = rng.randint(0, n_drugs, size=n_samples)

        # Fancy indexing gives shape (n_samples, n_genes) even when n_samples == 0.
        deltas = drug_effects[drug_ids] + rng.randn(n_samples, n_genes) * noise_std
        post_expression = baseline + deltas

        self.baseline = torch.tensor(baseline, dtype=torch.float32)
        self.post = torch.tensor(post_expression, dtype=torch.float32)
        self.delta = torch.tensor(deltas, dtype=torch.float32)
        self.drug_ids = torch.tensor(drug_ids, dtype=torch.long)
        # Valid, distinct SMILES per drug: a benzodiazepine-like scaffold whose
        # N-alkyl chain length encodes the drug index. (Appending the index to
        # a SMILES string would turn it into unclosed ring-bond labels.)
        self.smiles = [
            f"{'C' * (i + 1)}N1C(=O)CN=C(C2CCCCC2)c3ccccc31" for i in range(n_drugs)
        ]

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        drug_id = int(self.drug_ids[idx])
        return {
            "baseline_expression": self.baseline[idx],
            "post_expression": self.post[idx],
            "delta_expression": self.delta[idx],
            "drug": f"Drug_{drug_id}",
            "smiles": self.smiles[drug_id],
            "cell_line_id": "synthetic",
            "moa": "synthetic",
        }
|
|
|
|
class PerturbationCollator:
    """
    Collate perturbation samples into model-ready batches.

    The expression fields (``baseline_expression``, ``post_expression``,
    ``delta_expression``) are stacked into ``(batch, length)`` tensors. When
    the 1-D vectors in a batch differ in length, they are right-padded with
    zeros to the longest one instead of failing in ``torch.stack``. Optional
    string fields (``smiles``, ``drug``, ``cell_line_id``) default to ``""``.

    Args:
        n_genes: expected genes per expression vector; stored for callers,
            not used to pad or truncate.
    """

    def __init__(self, n_genes: int = 978):
        self.n_genes = n_genes

    @staticmethod
    def _stack_padded(tensors: List[torch.Tensor]) -> torch.Tensor:
        """Stack tensors; zero-pad on the right only if 1-D lengths differ."""
        lengths = {t.shape[0] for t in tensors if t.dim() == 1}
        if len(lengths) <= 1 or any(t.dim() != 1 for t in tensors):
            # Uniform (or non-1-D) input: plain stack, same as before.
            return torch.stack(tensors)
        out = tensors[0].new_zeros((len(tensors), max(lengths)))
        for row, t in enumerate(tensors):
            out[row, : t.shape[0]] = t
        return out

    def __call__(self, batch: List[Dict]) -> Dict[str, object]:
        """Return a dict of stacked expression tensors and per-sample string lists."""
        return {
            "baseline_expression": self._stack_padded([item["baseline_expression"] for item in batch]),
            "post_expression": self._stack_padded([item["post_expression"] for item in batch]),
            "delta_expression": self._stack_padded([item["delta_expression"] for item in batch]),
            "smiles": [item.get("smiles", "") for item in batch],
            "drugs": [item.get("drug", "") for item in batch],
            "cell_lines": [item.get("cell_line_id", "") for item in batch],
        }
|
|
|
|
def create_perturbation_dataloader(
    dataset: Dataset,
    batch_size: int = 256,
    shuffle: bool = True,
    num_workers: int = 4,
    n_genes: int = 978,
    pin_memory: Optional[bool] = None,
    drop_last: bool = True,
) -> DataLoader:
    """
    Create a DataLoader for perturbation data using PerturbationCollator.

    Args:
        dataset: map-style or iterable perturbation dataset
        batch_size: samples per batch
        shuffle: shuffle map-style datasets; ignored for IterableDataset,
            which DataLoader cannot shuffle
        num_workers: DataLoader worker processes
        n_genes: forwarded to PerturbationCollator
        pin_memory: pin host memory for faster GPU transfer; None (default)
            enables it only when CUDA is available
        drop_last: drop the final incomplete batch

    NOTE(review): with an IterableDataset and num_workers > 1, every worker
    runs its own iterator; confirm the dataset shards its stream per worker,
    otherwise samples are duplicated across workers.
    """
    if pin_memory is None:
        # Pinning only helps host->GPU copies; without CUDA PyTorch warns and skips it.
        pin_memory = torch.cuda.is_available()

    is_iterable = isinstance(dataset, IterableDataset)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle and not is_iterable,
        num_workers=num_workers,
        collate_fn=PerturbationCollator(n_genes=n_genes),
        pin_memory=pin_memory,
        drop_last=drop_last,
    )
|
|