"""Perturbation data loaders for MuLGIT-Perturb.

Loads perturbation data from Tahoe-100M and converts it to the format
required by MuLGITPerturb: (baseline omics, perturbation descriptor,
delta expression).

Tahoe-100M schema, ``expression_data`` config:

- ``genes``: int64[] -- gene indices
- ``expressions``: float32[] -- post-perturbation expression values,
  presumably one value per entry of ``genes`` (this module relies on that)
- ``drug``: str -- drug name
- ``cell_line_id``: str
- ``canonical_smiles``: str -- drug SMILES
- ``pubchem_cid``: int64
- ``moa-fine``: str -- mechanism of action (fine-grained)
- ``sample``: str -- sample/cell-line identifier

For paired baseline data, the vehicle-control (DMSO) samples of each cell
line are used.
"""

import warnings
from collections import defaultdict
from typing import Dict, Iterator, List, Optional, Tuple, Union

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, IterableDataset

# Lower-cased drug names treated as vehicle controls (the baseline condition).
# NOTE(review): Tahoe-100M appears to label its vehicle wells "DMSO_TF";
# confirm against the ``drug_metadata`` config.
_VEHICLE_DRUGS = frozenset({"dmso", "dmso_tf", "vehicle", "control", "untreated"})


def _is_vehicle(drug) -> bool:
    """Return True if *drug* names a vehicle control (case-insensitive; None -> False)."""
    return str(drug or "").strip().lower() in _VEHICLE_DRUGS


def _mean_if_uniform(vectors: List[np.ndarray]) -> Optional[np.ndarray]:
    """Return the float32 element-wise mean of equally-shaped vectors.

    Returns None for an empty list or when the shapes disagree, instead of
    letting NumPy raise on a ragged stack.
    """
    if not vectors or any(v.shape != vectors[0].shape for v in vectors):
        return None
    return np.mean(np.stack(vectors), axis=0).astype(np.float32)


class TahoePerturbationDataset(IterableDataset):
    """Streaming IterableDataset of Tahoe-100M (baseline, perturbed) pairs.

    Tahoe-100M is ~320 GB (95.6M rows), so it is streamed rather than
    materialised:

    1. Scan the first ``n_prefetch`` rows and average the vehicle-control
       expression per ``(cell_line_id, sample)`` and per ``cell_line_id``.
    2. Stream all rows, skipping vehicle controls.
    3. Pair each perturbed row with the baseline of the same
       ``(cell_line_id, sample)``; otherwise fall back to the pooled
       vehicle baseline of its cell line. Rows with no baseline are skipped.
    4. Yield one dict per pair with ``baseline_expression``,
       ``post_expression`` and ``delta_expression`` tensors plus drug
       metadata.

    Every expression row is converted to a dense vector over a fixed gene
    panel before any arithmetic. When a row carries ``genes``, each value
    in ``expressions`` is scattered to the column of its gene index. Genes
    outside the panel are dropped, and panel genes absent from the row are
    0. Rows without ``genes`` are treated as already dense (position ``i``
    holds gene ``i``).

    Args:
        dataset_path: HF dataset path (e.g. ``"tahoebio/Tahoe-100M"``).
        config: dataset config name.
        split: dataset split.
        max_samples: cap on yielded samples (None = unlimited, 0 = none).
        n_genes: number of genes kept. With ``gene_list`` it truncates that
            list; otherwise the panel is gene indices ``0..n_genes-1``.
            None keeps the whole ``gene_list`` or, for dense rows, the
            whole row. Sparse rows require ``n_genes`` or ``gene_list``.
        gene_list: ordered gene indices forming the panel (output column
            ``j`` holds gene ``gene_list[j]``).
        cache_baselines: build the vehicle-baseline cache on the first
            iteration. When False, only baselines already present in
            ``_baselines`` can be matched; otherwise every row is skipped.
    """

    def __init__(
        self,
        dataset_path: str = "tahoebio/Tahoe-100M",
        config: str = "expression_data",
        split: str = "train",
        max_samples: Optional[int] = None,
        n_genes: Optional[int] = 978,  # L1000 landmark count; default panel is gene indices 0..977
        gene_list: Optional[List[int]] = None,
        cache_baselines: bool = True,
    ):
        self.dataset_path = dataset_path
        self.config = config
        self.split = split
        self.max_samples = max_samples
        self.n_genes = n_genes
        self.gene_list = gene_list
        self.cache_baselines = cache_baselines
        # (cell_line_id, sample) -> mean vehicle expression vector.
        self._baselines: Dict[Tuple[str, str], np.ndarray] = {}
        # cell_line_id -> mean over every vehicle row of that cell line.
        self._cell_line_baselines: Dict[str, np.ndarray] = {}
        self._baselines_built = False
        # drug name -> canonical SMILES; filled by _load_dataset when available.
        self._smiles_lookup: Dict[str, str] = {}

        # Gene panel plus a sorted view used to scatter sparse rows in O(k log n).
        self._panel = self._resolve_gene_panel()
        if self._panel is not None and self._panel.size:
            self._panel_order = np.argsort(self._panel, kind="stable")
            self._panel_sorted = self._panel[self._panel_order]
        else:
            self._panel_order = None
            self._panel_sorted = None

    def _resolve_gene_panel(self) -> Optional[np.ndarray]:
        """Return the ordered gene indices to keep, or None to keep whole dense rows."""
        if self.gene_list is not None:
            chosen = list(self.gene_list)
            if self.n_genes is not None:
                chosen = chosen[: self.n_genes]
            return np.asarray(chosen, dtype=np.int64)
        if self.n_genes is not None:
            return np.arange(self.n_genes, dtype=np.int64)
        return None

    def _load_dataset(self):
        """Lazily open the streaming split and build the drug -> SMILES lookup."""
        if hasattr(self, "_ds"):
            return
        from datasets import load_dataset

        self._ds = load_dataset(
            self.dataset_path,
            self.config,
            split=self.split,
            streaming=True,
        )
        # Best-effort: SMILES can still come from each row's canonical_smiles.
        try:
            self._drug_meta = load_dataset(
                self.dataset_path,
                "drug_metadata",
                split="train",
            )
            self._smiles_lookup = {
                row["drug"]: row["canonical_smiles"]
                for row in self._drug_meta
                if "drug" in row and "canonical_smiles" in row
            }
        except Exception as err:
            warnings.warn(
                f"Could not load drug_metadata for SMILES lookup ({err!r}); "
                "relying on per-row canonical_smiles only.",
                RuntimeWarning,
            )
            self._smiles_lookup = {}

    def _row_to_dense(self, row: Dict) -> Optional[np.ndarray]:
        """Convert a row's expression data to a dense float32 vector over the panel.

        Returns None when the row has no expression values or when its
        ``genes``/``expressions`` lengths disagree.

        Raises:
            ValueError: the row is sparse (has ``genes``) but no gene panel
                is configured, so the output dimension is undefined.
        """
        values = row.get("expressions")
        if values is None:
            return None
        values = np.asarray(values, dtype=np.float32).ravel()
        panel = self._panel
        genes = row.get("genes")

        if genes is None:
            # Dense layout: position i holds gene i.
            if panel is None:
                return values
            dense = np.zeros(panel.shape[0], dtype=np.float32)
            in_row = (panel >= 0) & (panel < values.shape[0])
            dense[in_row] = values[panel[in_row]]
            return dense

        genes = np.asarray(genes, dtype=np.int64).ravel()
        if genes.shape[0] != values.shape[0]:
            return None  # malformed row: values cannot be assigned to genes
        if panel is None:
            raise ValueError(
                "Rows with a 'genes' field need a fixed gene panel: "
                "set n_genes or gene_list."
            )
        dense = np.zeros(panel.shape[0], dtype=np.float32)
        if panel.size == 0 or genes.size == 0:
            return dense
        # Locate each gene in the sorted panel; genes not in the panel miss.
        pos = np.searchsorted(self._panel_sorted, genes)
        pos = np.minimum(pos, self._panel_sorted.shape[0] - 1)
        hit = self._panel_sorted[pos] == genes
        dense[self._panel_order[pos[hit]]] = values[hit]
        return dense

    def _lookup_baseline(self, cell_line, sample) -> Optional[np.ndarray]:
        """Return the best available baseline: exact sample, 'default' sample, then cell-line pool."""
        for key in ((cell_line, sample), (cell_line, "default")):
            baseline = self._baselines.get(key)
            if baseline is not None:
                return baseline
        return self._cell_line_baselines.get(cell_line)

    def _build_baseline_cache(self, n_prefetch: int = 10000):
        """Build the vehicle-control baseline cache.

        Scans at most ``n_prefetch`` rows (all rows count, not just vehicle
        rows). For each ``(cell_line_id, sample)`` it averages the dense
        vehicle expression vectors. It also averages all vehicle vectors
        per ``cell_line_id`` to serve as the fallback baseline.
        """
        self._load_dataset()
        per_sample = defaultdict(list)
        scanned = 0
        for row in self._ds:
            if _is_vehicle(row.get("drug")):
                expr = self._row_to_dense(row)
                if expr is not None:
                    key = (row.get("cell_line_id", "unknown"), row.get("sample", "default"))
                    per_sample[key].append(expr)
            scanned += 1
            if scanned >= n_prefetch:
                break

        self._baselines = {}
        per_cell_line = defaultdict(list)
        for key, vectors in per_sample.items():
            mean = _mean_if_uniform(vectors)
            if mean is not None:
                self._baselines[key] = mean
            per_cell_line[key[0]].extend(vectors)

        self._cell_line_baselines = {}
        for cell_line, vectors in per_cell_line.items():
            mean = _mean_if_uniform(vectors)
            if mean is not None:
                self._cell_line_baselines[cell_line] = mean

        # Mark as built even if empty so later epochs do not rescan.
        self._baselines_built = True
        if not self._baselines:
            warnings.warn(
                f"No vehicle-control rows found in the first {scanned} rows; "
                "perturbed rows cannot be paired with a baseline.",
                RuntimeWarning,
            )

    def __iter__(self) -> Iterator[Dict]:
        """Yield perturbation samples paired with a matched vehicle baseline."""
        self._load_dataset()
        if self.cache_baselines and not self._baselines_built:
            self._build_baseline_cache()

        count = 0
        for row in self._ds:
            if self.max_samples is not None and count >= self.max_samples:
                break

            drug_name = row.get("drug") or ""
            if _is_vehicle(drug_name):
                continue

            expr_post = self._row_to_dense(row)
            if expr_post is None:
                continue

            # Fall back to the metadata lookup when the row's SMILES is missing OR empty.
            smiles = row.get("canonical_smiles") or self._smiles_lookup.get(drug_name, "")
            if not smiles:
                continue

            cell_line = row.get("cell_line_id", "unknown")
            sample = row.get("sample", "default")
            expr_baseline = self._lookup_baseline(cell_line, sample)
            if expr_baseline is None or expr_baseline.shape != expr_post.shape:
                continue

            delta = expr_post - expr_baseline

            yield {
                # torch.tensor copies, so consumers cannot mutate the cached baseline.
                "baseline_expression": torch.tensor(expr_baseline),
                "post_expression": torch.tensor(expr_post),
                "delta_expression": torch.tensor(delta),
                "drug": drug_name,
                "smiles": smiles,
                "cell_line_id": cell_line,
                "moa": row.get("moa-fine", ""),
                "pubchem_cid": row.get("pubchem_cid", None),
            }
            count += 1


# Valid 1,4-benzodiazepine core; the N-alkyl prefix length varies per drug so
# each synthetic SMILES is distinct and its ring-closure labels are balanced.
_SYNTHETIC_SMILES_CORE = "N1C(=O)CN=C(c2ccccc2)c2ccccc21"


class SyntheticPerturbationDataset(Dataset):
    """Synthetic perturbation dataset for testing and development.

    Generates log-normal baseline expression plus sparse drug-induced
    changes (about 10% of genes per drug) with Gaussian noise. The ground
    truth is therefore known for validation.

    Construction uses a private ``np.random.RandomState(seed)``. It draws
    the same values the global legacy generator would after
    ``np.random.seed(seed)``, but leaves global NumPy/PyTorch RNG state
    untouched.
    """

    def __init__(
        self,
        n_samples: int = 1000,
        n_genes: int = 978,
        n_drugs: int = 50,
        noise_std: float = 0.1,
        seed: int = 42,
    ):
        rng = np.random.RandomState(seed)
        self.n_samples = n_samples
        self.n_genes = n_genes

        # Draw order matters for reproducibility; keep it fixed.
        baseline = rng.lognormal(mean=0.0, sigma=0.5, size=(n_samples, n_genes))
        drug_effects = rng.randn(n_drugs, n_genes) * 0.5
        sparsity_mask = rng.random_sample((n_drugs, n_genes)) < 0.1  # ~10% of genes affected
        drug_effects = drug_effects * sparsity_mask
        drug_ids = rng.randint(0, n_drugs, size=n_samples)
        # Fancy indexing keeps shape (n_samples, n_genes) even when n_samples == 0.
        deltas = drug_effects[drug_ids] + rng.randn(n_samples, n_genes) * noise_std
        post_expression = baseline + deltas

        self.baseline = torch.tensor(baseline, dtype=torch.float32)
        self.post = torch.tensor(post_expression, dtype=torch.float32)
        self.delta = torch.tensor(deltas, dtype=torch.float32)
        self.drug_ids = torch.tensor(drug_ids, dtype=torch.long)
        self.smiles = ["C" * (i + 1) + _SYNTHETIC_SMILES_CORE for i in range(n_drugs)]

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        drug_id = self.drug_ids[idx].item()
        return {
            "baseline_expression": self.baseline[idx],
            "post_expression": self.post[idx],
            "delta_expression": self.delta[idx],
            "drug": f"Drug_{drug_id}",
            "smiles": self.smiles[drug_id % len(self.smiles)],
            "cell_line_id": "synthetic",
            "moa": "synthetic",
        }


class PerturbationCollator:
    """Collate perturbation sample dicts into model-ready batches.

    Expression tensors are stacked along a new batch dimension. If the 1-D
    expression vectors in a batch differ in length, they are right-padded
    with zeros to the longest one. Missing string fields default to "".

    Args:
        n_genes: expected gene count; stored for reference, not used to
            resize vectors.
    """

    def __init__(self, n_genes: int = 978):
        self.n_genes = n_genes

    @staticmethod
    def _stack(tensors: List[torch.Tensor]) -> torch.Tensor:
        """Stack tensors, zero-padding ragged 1-D vectors to the longest length."""
        shapes = {tuple(t.shape) for t in tensors}
        if len(shapes) <= 1 or any(t.dim() != 1 for t in tensors):
            # Uniform (or non-vector) tensors: plain stack; it raises on true mismatches.
            return torch.stack(tensors)
        width = max(t.shape[0] for t in tensors)
        out = tensors[0].new_zeros((len(tensors), width))
        for row, t in enumerate(tensors):
            out[row, : t.shape[0]] = t
        return out

    def __call__(self, batch: List[Dict]) -> Dict[str, Union[torch.Tensor, List[str]]]:
        baseline = self._stack([item["baseline_expression"] for item in batch])
        post = self._stack([item["post_expression"] for item in batch])
        delta_true = self._stack([item["delta_expression"] for item in batch])

        return {
            "baseline_expression": baseline,
            "post_expression": post,
            "delta_expression": delta_true,
            "smiles": [item.get("smiles", "") for item in batch],
            "drugs": [item.get("drug", "") for item in batch],
            "cell_lines": [item.get("cell_line_id", "") for item in batch],
        }


def create_perturbation_dataloader(
    dataset: Dataset,
    batch_size: int = 256,
    shuffle: bool = True,
    num_workers: int = 4,
    n_genes: int = 978,
    drop_last: bool = True,
) -> DataLoader:
    """Create a DataLoader for perturbation data using PerturbationCollator.

    Shuffling is disabled for IterableDatasets, which cannot be shuffled by
    the sampler. Memory pinning is enabled only when CUDA is available. With
    ``drop_last=True`` (the default), a final partial batch is discarded.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle and not isinstance(dataset, IterableDataset),
        num_workers=num_workers,
        collate_fn=PerturbationCollator(n_genes=n_genes),
        pin_memory=torch.cuda.is_available(),
        drop_last=drop_last,
    )