"""
Multi-Omics Data Loading and Preprocessing

Handles loading MLOmics, Tabula Muris Senis, ComputAge, and BALM datasets,
converting them into compatible formats for MuLGIT training.

Data sources:
- AIBIC/MLOmics: TCGA multi-omics (mRNA, miRNA, methylation, CNV) + survival
- longevity-db/Tabula_Muris_Senis_10x: Mouse aging scRNA-seq
- computage/computage_bench: DNA methylation aging clocks
- BALM/BALM-benchmark: Drug-target binding affinity
"""

import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset

try:
    from datasets import load_dataset
except ImportError:
    # `datasets` is only needed by the Hugging Face loaders at the bottom of
    # this module; keep the MLOmics / synthetic utilities importable without it.
    load_dataset = None


# ─── MLOmics Dataset Loader ─────────────────────────────────────────────────

class MLOmicsDataset(Dataset):
    """
    Loads TCGA multi-omics data from AIBIC/MLOmics.

    Data format: CSV files with genes as rows and samples as columns. Each file
    is transposed on load to get a samples × genes table.

    Output per sample (modality keys appear only for files found on disk; the
    survival keys only when a survival file was found):
        - methylation: tensor of methylation features
        - cnv: tensor of copy number variation features
        - mrna: tensor of gene expression features
        - mirna: tensor of microRNA expression features
        - survival_times: event/censoring time in days
        - event_observed: 1 if death, 0 if censored
    """

    def __init__(
        self,
        cache_dir: str,
        cancer_type: str = "pan-cancer",
        feature_scale: str = "Original",  # "Original", "Aligned", "Top"
        normalize: bool = True,
        common_genes_only: bool = True,
    ):
        """
        Args:
            cache_dir: path to the downloaded MLOmics dataset.
            cancer_type: TCGA cancer code (e.g. "ACC", "KIRC", "LIHC").
                "pan-cancer" is rejected because that split has no survival labels.
            feature_scale: which feature-set sub-directory to read.
            normalize: if True, ``__getitem__`` z-scores each sample's feature
                vector within each modality (per sample, not per feature).
            common_genes_only: restrict every modality to the features shared by
                all loaded modalities (only applied when more than 100 are shared).

        Raises:
            ValueError: for "pan-cancer", when no modality file is found, when
                the modalities share no samples, or when survival labels cannot
                be aligned to the samples.
        """
        self.cache_dir = Path(cache_dir)
        self.normalize = normalize

        if cancer_type == "pan-cancer":
            # Pan-cancer classification data doesn't come with survival; training
            # uses individual cancer types from Clustering_datasets instead.
            raise ValueError(
                "Pan-cancer classification data lacks survival labels. "
                "Use individual cancer types from Clustering_datasets which have survival data. "
                "Example: cancer_type='ACC', 'KIRC', 'LIHC', etc."
            )

        main_dir = self.cache_dir / "Main_Dataset"
        cluster_path = main_dir / "Clustering_datasets" / cancer_type / feature_scale
        if cluster_path.exists():
            # Clustering data is the variant that ships with survival labels.
            base_path = cluster_path
        else:
            # Fall back to classification data (may have no survival file).
            base_path = (
                main_dir / "Classification_datasets" / f"GS-{cancer_type}" / feature_scale
            )

        # _load_data sets self.has_survival from whether a survival file exists.
        self._load_data(base_path, cancer_type)

        if common_genes_only and len(self.gene_sets) > 1:
            self._align_genes()

    def _load_data(self, base_path: Path, cancer_type: str):
        """Load each modality CSV (transposed to samples × features) and the
        optional survival table, then align samples across all of them."""
        self.modalities = {}
        self.gene_sets = {}

        # File naming convention
        file_map = {
            "methylation": f"{cancer_type}_Methy",
            "cnv": f"{cancer_type}_CNV",
            "mrna": f"{cancer_type}_mRNA",
            "mirna": f"{cancer_type}_miRNA",
        }

        for mod_name, file_prefix in file_map.items():
            # First existing file wins, in this suffix priority order.
            candidates = [
                base_path / f"{file_prefix}{ext}"
                for ext in ("_top.csv", "_aligned.csv", ".csv")
            ]
            filepath = next((path for path in candidates if path.exists()), None)
            if filepath is None:
                print(f"Warning: {mod_name} file not found for {cancer_type}")
                continue

            df = pd.read_csv(filepath, index_col=0)
            # Transpose: genes × samples → samples × genes
            self.modalities[mod_name] = df.T.astype(np.float32)
            self.gene_sets[mod_name] = set(df.index)

        surv_path = base_path / f"survival_{cancer_type}.csv"
        if surv_path.exists():
            surv_df = pd.read_csv(surv_path, index_col=0)
            # Assumes the first column holds sample IDs matching the omics CSV
            # headers — TODO confirm against the MLOmics layout. Headers are
            # always parsed as strings, so compare IDs as strings.
            surv_df.index = surv_df.index.astype(str)
            surv_df = surv_df[["survival_times", "event_observed"]].astype(np.float32)
            self._survival_df = surv_df
            self.survival_times = surv_df["survival_times"].to_numpy()
            self.event_observed = surv_df["event_observed"].to_numpy()
            self.has_survival = True
        else:
            self._survival_df = None
            self.survival_times = None
            self.event_observed = None
            self.has_survival = False

        self._align_samples()

    def _align_samples(self):
        """Restrict every modality (and the survival labels) to the samples they
        all share, in sorted sample-ID order."""
        if not self.modalities:
            raise ValueError("No modality files found; cannot determine samples")

        common = set.intersection(*(set(df.index) for df in self.modalities.values()))
        if not common:
            raise ValueError("No common samples across modalities")
        common_samples = sorted(common)

        if self.has_survival:
            # May drop samples that have no survival row.
            common_samples = self._align_survival(common_samples)

        self.modalities = {
            name: df.loc[common_samples] for name, df in self.modalities.items()
        }
        self.sample_ids = common_samples
        self.n_samples = len(common_samples)

    def _align_survival(self, common_samples: List[str]) -> List[str]:
        """Match survival labels to ``common_samples`` by sample ID.

        Sets ``self.survival_times`` / ``self.event_observed`` in the same order
        as the returned sample list. Samples without a survival row are dropped.
        If no sample ID matches the survival index at all, the labels are kept
        in file order (legacy behaviour), which requires equal counts.

        Raises:
            ValueError: if no IDs match and the row counts differ.
        """
        surv = self._survival_df
        matched = [sample for sample in common_samples if sample in surv.index]

        if not matched:
            if len(surv) != len(common_samples):
                raise ValueError(
                    f"Survival file has {len(surv)} rows but there are "
                    f"{len(common_samples)} omics samples and no sample IDs match; "
                    "cannot align survival labels."
                )
            print(
                "Warning: survival sample IDs do not match omics sample IDs; "
                "assuming survival rows are in the same order as the samples"
            )
            return common_samples

        if len(matched) < len(common_samples):
            print(
                f"Warning: dropping {len(common_samples) - len(matched)} samples "
                "without survival labels"
            )

        # Keep the first row for any duplicated sample ID so .loc yields one row each.
        surv = surv[~surv.index.duplicated(keep="first")].loc[matched]
        self.survival_times = surv["survival_times"].to_numpy()
        self.event_observed = surv["event_observed"].to_numpy()
        return matched

    def _align_genes(self, min_overlap: int = 100):
        """Keep only features present in every modality, if enough are shared.

        Args:
            min_overlap: the filter is applied only when more than this many
                features are shared; otherwise each modality keeps all of its own.
        """
        if not self.gene_sets:
            return
        common_genes = set.intersection(*self.gene_sets.values())
        if len(common_genes) <= min_overlap:
            return

        # A boolean mask keeps each modality's own column order, so the feature
        # layout is identical across processes (iterating a set of strings is
        # not, because of hash randomization).
        self.modalities = {
            name: df.loc[:, df.columns.isin(common_genes)]
            for name, df in self.modalities.items()
        }

    def __len__(self) -> int:
        return self.n_samples

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        item = {}
        for mod_name, df in self.modalities.items():
            values = df.iloc[idx].values
            if self.normalize:
                # Per-sample z-score across this modality's features.
                values = (values - values.mean()) / (values.std() + 1e-8)
            item[mod_name] = torch.tensor(values, dtype=torch.float32)

        if self.has_survival:
            item["survival_times"] = torch.tensor(self.survival_times[idx], dtype=torch.float32)
            item["event_observed"] = torch.tensor(self.event_observed[idx], dtype=torch.float32)

        return item

    @property
    def feature_dims(self) -> Dict[str, int]:
        """Number of features (columns) per loaded modality."""
        return {name: df.shape[1] for name, df in self.modalities.items()}


# ─── Simplified: Synthetic Multi-Omics Generator for Rapid Prototyping ──────

def generate_synthetic_multi_omics(
    n_samples: int = 1000,
    n_methylation: int = 1000,
    n_cnv: int = 1000,
    n_mrna: int = 1000,
    n_mirna: int = 300,
    seed: int = 42,
) -> Tuple[Dict[str, np.ndarray], np.ndarray, np.ndarray]:
    """
    Generate synthetic multi-omics data with known survival structure.

    One latent risk score per sample drives part of every modality and the
    simulated event times. Useful for testing the pipeline without downloading
    large datasets. Output is deterministic for a given ``seed``.

    Returns:
        modalities: dict of float32 feature matrices, each (n_samples, n_features)
        survival_times: float32 observed times, min(event time, censoring time)
        event_observed: float32 indicators, 1.0 where the event precedes censoring
    """
    rng = np.random.RandomState(seed)

    # Latent risk factor shared by all modalities and the survival process.
    latent_risk = rng.randn(n_samples)

    modalities = {}

    # Methylation: every feature loads (weakly) on the latent risk.
    meth_noise = rng.randn(n_samples, n_methylation) * 0.5
    meth_signal = latent_risk[:, None] * rng.randn(1, n_methylation) * 0.5
    modalities["methylation"] = (meth_signal + meth_noise).astype(np.float32)

    # CNV: sparse — only the first (up to) 10 features carry signal. Clamping to
    # n_cnv avoids a broadcast error when fewer than 10 CNV features are requested.
    cnv = rng.randn(n_samples, n_cnv) * 0.3
    n_cnv_signal = min(10, n_cnv)
    cnv[:, :n_cnv_signal] += latent_risk[:, None] * rng.randn(1, n_cnv_signal) * 0.3
    modalities["cnv"] = cnv.astype(np.float32)

    # mRNA: strongly correlated with risk.
    mrna_noise = rng.randn(n_samples, n_mrna) * 0.3
    mrna_signal = latent_risk[:, None] * rng.randn(1, n_mrna) * 0.7
    modalities["mrna"] = (mrna_signal + mrna_noise).astype(np.float32)

    # miRNA: regulatory signal.
    mirna_noise = rng.randn(n_samples, n_mirna) * 0.4
    mirna_signal = latent_risk[:, None] * rng.randn(1, n_mirna) * 0.6
    modalities["mirna"] = (mirna_signal + mirna_noise).astype(np.float32)

    # Proportional hazards: dividing exponential baseline event times by the
    # hazard ratio exp(0.5 * risk) gives h(t) = h0 * exp(0.5 * risk).
    baseline_times = rng.exponential(scale=365.0, size=n_samples)  # ~1 year mean
    hazard_ratio = np.exp(latent_risk * 0.5)
    event_times = baseline_times / hazard_ratio

    # Independent exponential censoring (tuned for roughly 30% censored).
    censor_time = rng.exponential(scale=1000.0, size=n_samples)
    observed_times = np.minimum(event_times, censor_time)
    event_observed = (event_times <= censor_time).astype(np.float32)

    return modalities, observed_times.astype(np.float32), event_observed


class SyntheticMultiOmicsDataset(Dataset):
    """PyTorch Dataset wrapper for in-memory multi-omics arrays.

    Each item holds one tensor per modality plus scalar ``survival_times`` and
    ``event_observed`` tensors.
    """

    def __init__(
        self,
        modalities: Dict[str, np.ndarray],
        survival_times: np.ndarray,
        event_observed: np.ndarray,
    ):
        """
        Raises:
            ValueError: if ``event_observed`` or any modality's first dimension
                differs in length from ``survival_times``.
        """
        n_samples = len(survival_times)
        if len(event_observed) != n_samples:
            raise ValueError(
                f"event_observed has {len(event_observed)} entries, expected {n_samples}"
            )
        for name, values in modalities.items():
            if len(values) != n_samples:
                raise ValueError(
                    f"modality {name!r} has {len(values)} samples, expected {n_samples}"
                )

        self.modalities = {k: torch.tensor(v) for k, v in modalities.items()}
        self.survival_times = torch.tensor(survival_times)
        self.event_observed = torch.tensor(event_observed)
        self.n_samples = n_samples

    def __len__(self) -> int:
        return self.n_samples

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        return {
            **{k: v[idx] for k, v in self.modalities.items()},
            "survival_times": self.survival_times[idx],
            "event_observed": self.event_observed[idx],
        }


# ─── HF Dataset Loaders ─────────────────────────────────────────────────────

def _require_load_dataset():
    """Return ``datasets.load_dataset``, or raise ImportError if not installed."""
    if load_dataset is None:
        raise ImportError(
            "The Hugging Face `datasets` package is required for this loader: "
            "pip install datasets"
        )
    return load_dataset


def load_tabula_muris_senis(split: str = "train") -> Any:
    """
    Load Tabula Muris Senis aging mouse scRNA-seq data from HF.

    Returns the Hugging Face ``datasets.Dataset`` for ``split``; its columns
    follow the dataset card and are not validated here.
    """
    return _require_load_dataset()("longevity-db/Tabula_Muris_Senis_10x", split=split)


def load_computage_bench(split: str = "train") -> Any:
    """
    Load the ComputAge benchmark for epigenetic aging clocks.

    Returns the Hugging Face ``datasets.Dataset`` for ``split``; its columns
    follow the dataset card and are not validated here.
    """
    return _require_load_dataset()("computage/computage_bench", split=split)


def load_balm_benchmark(config: str = "BindingDB_filtered", split: str = "train") -> Any:
    """
    Load a BALM drug-target binding affinity benchmark configuration.

    Returns the Hugging Face ``datasets.Dataset`` for ``config`` / ``split``
    (expected columns: Drug (SMILES), Target (protein sequence), Y (affinity)
    — per the dataset card, not validated here).
    """
    return _require_load_dataset()("BALM/BALM-benchmark", config, split=split)


# ─── Data Collation ──────────────────────────────────────────────────────────

def collate_multi_omics(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    """
    Collate function for multi-omics batches.

    Stacks each key's tensors along a new leading batch dimension. Every item
    must have the keys of the first item and equal per-key shapes
    (``torch.stack`` raises otherwise; variable-length tensors are not padded).
    An empty batch yields an empty dict.
    """
    if not batch:
        return {}
    return {key: torch.stack([item[key] for item in batch]) for key in batch[0]}


def create_data_loader(
    dataset: Dataset,
    batch_size: int = 256,
    shuffle: bool = True,
    num_workers: int = 4,
    pin_memory: Optional[bool] = None,
) -> DataLoader:
    """Create a DataLoader for multi-omics data.

    Args:
        pin_memory: pin host memory for faster GPU transfer. ``None`` (default)
            enables it only when CUDA is available, avoiding PyTorch's warning
            on CPU-only machines.
    """
    if pin_memory is None:
        pin_memory = torch.cuda.is_available()
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate_multi_omics,
        pin_memory=pin_memory,
    )