| """ |
| Multi-Omics Data Loading and Preprocessing |
| |
| Handles loading MLOmics, Tabula Muris Senis, ComputAge, and BALM datasets, |
| converting them into compatible formats for MuLGIT training. |
| |
| Data sources: |
| - AIBIC/MLOmics: TCGA multi-omics (mRNA, miRNA, methylation, CNV) + survival |
| - longevity-db/Tabula_Muris_Senis_10x: Mouse aging scRNA-seq |
| - computage/computage_bench: DNA methylation aging clocks |
| - BALM/BALM-benchmark: Drug-target binding affinity |
| """ |
|
|
| import torch |
| from torch.utils.data import Dataset, DataLoader |
| import pandas as pd |
| import numpy as np |
| from typing import Optional, Dict, List, Tuple, Any |
| from pathlib import Path |
| import os |
| from datasets import load_dataset |
|
|
|
|
| |
|
|
class MLOmicsDataset(Dataset):
    """
    Loads TCGA multi-omics data from AIBIC/MLOmics.

    Data format: CSV files with genes as rows, samples as columns (transposed).
    We transpose to get samples × genes format.

    Output per sample:
        - methylation: tensor of methylation features
        - cnv: tensor of copy number variation features
        - mrna: tensor of gene expression features
        - mirna: tensor of microRNA expression features
        - survival_times: event/censoring time in days
        - event_observed: 1 if death, 0 if censored

    A modality key is present only if its CSV file was found, and the survival
    keys only if ``survival_{cancer_type}.csv`` exists. Samples are ordered by
    sorted sample ID, and survival labels are matched to samples by ID.
    """

    def __init__(
        self,
        cache_dir: str,
        cancer_type: str = "pan-cancer",
        feature_scale: str = "Original",
        normalize: bool = True,
        common_genes_only: bool = True,
    ):
        """
        Args:
            cache_dir: path to downloaded MLOmics dataset
            cancer_type: TCGA cancer code (e.g. "ACC", "KIRC"); "pan-cancer"
                is rejected because that split has no survival labels
            feature_scale: which feature set (sub-directory) to use
            normalize: whether to z-score each sample's feature vector
                (per sample, across features) in ``__getitem__``
            common_genes_only: restrict every modality to the features shared
                by all modalities (applied only when more than 100 are shared)

        Raises:
            ValueError: for ``cancer_type="pan-cancer"``, when no samples are
                shared across modalities, or when survival labels cannot be
                matched to the omics samples.
        """
        self.cache_dir = Path(cache_dir)
        self.normalize = normalize

        if cancer_type == "pan-cancer":
            raise ValueError(
                "Pan-cancer classification data lacks survival labels. "
                "Use individual cancer types from Clustering_datasets which have survival data. "
                "Example: cancer_type='ACC', 'KIRC', 'LIHC', etc."
            )

        # Prefer the clustering split (which carries survival files); otherwise
        # fall back to the classification gold-standard split.
        cluster_path = self.cache_dir / "Main_Dataset" / "Clustering_datasets" / cancer_type / feature_scale
        if cluster_path.exists():
            base_path = cluster_path
        else:
            base_path = (
                self.cache_dir / "Main_Dataset" / "Classification_datasets" / f"GS-{cancer_type}" / feature_scale
            )

        # _load_data decides has_survival from whether a survival file exists.
        self._load_data(base_path, cancer_type)

        if common_genes_only and len(self.gene_sets) > 1:
            self._align_genes()

    def _load_data(self, base_path: Path, cancer_type: str):
        """Load every available modality CSV plus optional survival labels, then align samples."""
        self.modalities = {}
        self.gene_sets = {}

        file_map = {
            "methylation": f"{cancer_type}_Methy",
            "cnv": f"{cancer_type}_CNV",
            "mrna": f"{cancer_type}_mRNA",
            "mirna": f"{cancer_type}_miRNA",
        }

        for mod_name, file_prefix in file_map.items():
            # Filename variants are tried in priority order; the first that exists wins.
            for ext in ["_top.csv", "_aligned.csv", ".csv"]:
                filepath = base_path / f"{file_prefix}{ext}"
                if filepath.exists():
                    break
            else:
                print(f"Warning: {mod_name} file not found for {cancer_type}")
                continue

            df = pd.read_csv(filepath, index_col=0)
            # File is features x samples; store samples x features.
            self.modalities[mod_name] = df.T.astype(np.float32)
            self.gene_sets[mod_name] = set(df.index)

        surv_path = base_path / f"survival_{cancer_type}.csv"
        if surv_path.exists():
            surv_df = pd.read_csv(surv_path, index_col=0)
            labels = surv_df[["survival_times", "event_observed"]].astype(np.float32)
            self._survival_df = self._index_survival(labels)
            self.has_survival = True
        else:
            self._survival_df = None
            self.has_survival = False

        self._align_samples()

    def _index_survival(self, surv_df: pd.DataFrame) -> pd.DataFrame:
        """Return survival labels indexed by sample ID.

        If the file's index shares IDs with the omics samples it is used as-is.
        Otherwise the rows are relabelled positionally with the sample order of
        the first loaded omics file.

        Raises:
            ValueError: if positional relabelling is needed but the row counts differ.
        """
        if self.modalities:
            reference = next(iter(self.modalities.values())).index
            if not surv_df.index.isin(reference).any():
                if len(surv_df) != len(reference):
                    raise ValueError(
                        f"Survival file has {len(surv_df)} rows whose index matches no omics "
                        f"sample, and the omics data has {len(reference)} samples; "
                        "cannot align survival labels."
                    )
                # NOTE(review): positional fallback assumes survival rows are listed in
                # the same order as the omics sample columns; confirm for the data release.
                surv_df = surv_df.set_axis(reference, axis=0)
        # One label row per sample ID, so .loc lookups return exactly one row each.
        return surv_df[~surv_df.index.duplicated(keep="first")]

    def _align_samples(self):
        """Restrict all modalities (and survival labels) to shared samples in sorted ID order.

        Raises:
            ValueError: if no modality was loaded or no sample is shared.
        """
        if not self.modalities:
            raise ValueError("No common samples across modalities")

        common = set.intersection(*(set(df.index) for df in self.modalities.values()))
        if not common:
            raise ValueError("No common samples across modalities")

        if self._survival_df is not None:
            # Drop samples without a survival label so every item carries both.
            common &= set(self._survival_df.index)
            if not common:
                raise ValueError("No omics samples have survival labels")

        self.sample_ids = sorted(common)
        for mod_name in list(self.modalities):
            self.modalities[mod_name] = self.modalities[mod_name].loc[self.sample_ids]

        if self._survival_df is not None:
            # Re-index labels to the same sorted order as the modalities.
            labels = self._survival_df.loc[self.sample_ids]
            self.survival_times = labels["survival_times"].to_numpy(dtype=np.float32)
            self.event_observed = labels["event_observed"].to_numpy(dtype=np.float32)
        else:
            self.survival_times = None
            self.event_observed = None

        self.n_samples = len(self.sample_ids)

    def _align_genes(self):
        """Keep only features present in every modality, in a deterministic (sorted) order.

        Nothing is changed when 100 or fewer features are shared.
        """
        # set.intersection returns a new set, so self.gene_sets is left untouched.
        common_genes = set.intersection(*self.gene_sets.values())
        if len(common_genes) <= 100:
            return

        # Sort so feature columns are identical across runs (set order is hash-dependent).
        shared = sorted(common_genes, key=str)
        for mod_name in list(self.modalities):
            df = self.modalities[mod_name]
            self.modalities[mod_name] = df[[g for g in shared if g in df.columns]]

    def __len__(self) -> int:
        return self.n_samples

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return one sample's modality tensors (and survival labels when available)."""
        item = {}

        for mod_name, df in self.modalities.items():
            values = df.iloc[idx].to_numpy()
            if self.normalize:
                # Per-sample z-score across this modality's features.
                values = (values - values.mean()) / (values.std() + 1e-8)
            item[mod_name] = torch.tensor(values, dtype=torch.float32)

        if self.has_survival:
            item["survival_times"] = torch.tensor(self.survival_times[idx], dtype=torch.float32)
            item["event_observed"] = torch.tensor(self.event_observed[idx], dtype=torch.float32)

        return item

    @property
    def feature_dims(self) -> Dict[str, int]:
        """Number of features per loaded modality."""
        return {name: df.shape[1] for name, df in self.modalities.items()}
|
|
|
|
| |
|
|
def generate_synthetic_multi_omics(
    n_samples: int = 1000,
    n_methylation: int = 1000,
    n_cnv: int = 1000,
    n_mrna: int = 1000,
    n_mirna: int = 300,
    seed: int = 42,
) -> Tuple[Dict[str, np.ndarray], np.ndarray, np.ndarray]:
    """
    Generate synthetic multi-omics data with known survival structure.
    Useful for testing the pipeline without downloading large datasets.

    A single latent risk score per sample adds a linear signal to every
    modality and shortens event times (hazard scaled by exp(0.5 * risk)).
    In the CNV block only the first min(10, n_cnv) features carry signal.

    Args:
        n_samples: number of samples (rows in every matrix)
        n_methylation, n_cnv, n_mrna, n_mirna: feature counts per modality
        seed: seed for ``np.random.RandomState``; equal seeds give equal output

    Returns:
        modalities: dict of float32 feature matrices, each (n_samples, n_features)
        survival_times: float32 observed times, min(event time, censoring time)
        event_observed: float32 indicators, 1.0 where the event precedes censoring
    """
    rng = np.random.RandomState(seed)

    # Shared latent risk drives both the omics signal and the survival times.
    latent_risk = rng.randn(n_samples)

    modalities = {}

    # Methylation: every feature carries a moderate risk signal.
    meth_noise = rng.randn(n_samples, n_methylation) * 0.5
    meth_signal = latent_risk[:, None] * rng.randn(1, n_methylation) * 0.5
    modalities["methylation"] = (meth_signal + meth_noise).astype(np.float32)

    # CNV: mostly noise; only the first (up to) 10 features carry signal.
    # Clamping keeps small n_cnv from failing to broadcast (n_cnv >= 10 is unchanged).
    cnv = rng.randn(n_samples, n_cnv) * 0.3
    n_cnv_signal = min(10, n_cnv)
    cnv[:, :n_cnv_signal] += latent_risk[:, None] * rng.randn(1, n_cnv_signal) * 0.3
    modalities["cnv"] = cnv.astype(np.float32)

    # mRNA: strongest signal-to-noise ratio.
    mrna_noise = rng.randn(n_samples, n_mrna) * 0.3
    mrna_signal = latent_risk[:, None] * rng.randn(1, n_mrna) * 0.7
    modalities["mrna"] = (mrna_signal + mrna_noise).astype(np.float32)

    # miRNA
    mirna_noise = rng.randn(n_samples, n_mirna) * 0.4
    mirna_signal = latent_risk[:, None] * rng.randn(1, n_mirna) * 0.6
    modalities["mirna"] = (mirna_signal + mirna_noise).astype(np.float32)

    # Exponential baseline event times, shortened for high-risk samples.
    baseline_times = rng.exponential(scale=365.0, size=n_samples)
    risk_factor = np.exp(latent_risk * 0.5)
    event_times = baseline_times / risk_factor

    # Independent exponential censoring.
    censor_time = rng.exponential(scale=1000.0, size=n_samples)
    observed_times = np.minimum(event_times, censor_time)
    event_observed = (event_times <= censor_time).astype(np.float32)

    return modalities, observed_times.astype(np.float32), event_observed
|
|
|
|
class SyntheticMultiOmicsDataset(Dataset):
    """PyTorch Dataset wrapper for in-memory (e.g. synthetic) multi-omics arrays.

    All inputs are copied into float32 tensors, so every item has the same
    dtype regardless of the input arrays' dtype (e.g. float64 or int).
    """

    def __init__(
        self,
        modalities: Dict[str, np.ndarray],
        survival_times: np.ndarray,
        event_observed: np.ndarray,
    ):
        """
        Args:
            modalities: mapping of modality name -> array with one row per sample
            survival_times: per-sample event/censoring times
            event_observed: per-sample event indicators

        Raises:
            ValueError: if any modality or ``event_observed`` has a different
                number of rows than ``survival_times``.
        """
        n_samples = len(survival_times)
        for name, values in [*modalities.items(), ("event_observed", event_observed)]:
            if len(values) != n_samples:
                raise ValueError(
                    f"{name} has {len(values)} rows but survival_times has {n_samples}"
                )

        # torch.tensor always copies, so later edits to the inputs do not leak in.
        self.modalities = {
            k: torch.tensor(np.asarray(v, dtype=np.float32)) for k, v in modalities.items()
        }
        self.survival_times = torch.tensor(np.asarray(survival_times, dtype=np.float32))
        self.event_observed = torch.tensor(np.asarray(event_observed, dtype=np.float32))
        self.n_samples = n_samples

    def __len__(self) -> int:
        return self.n_samples

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return one sample's modality rows plus its survival labels."""
        return {
            **{k: v[idx] for k, v in self.modalities.items()},
            "survival_times": self.survival_times[idx],
            "event_observed": self.event_observed[idx],
        }
|
|
|
|
| |
|
|
def load_tabula_muris_senis(split: str = "train", **load_kwargs: Any) -> Any:
    """
    Load Tabula Muris Senis aging mouse scRNA-seq data from the Hugging Face Hub.

    Args:
        split: dataset split to load
        **load_kwargs: extra keyword arguments forwarded to
            ``datasets.load_dataset`` (e.g. ``cache_dir``, ``revision``)

    Returns:
        The object returned by ``load_dataset`` for that split: a
        ``datasets.Dataset`` (not a plain dict), whose columns are defined by
        the Hub dataset.
    """
    return load_dataset("longevity-db/Tabula_Muris_Senis_10x", split=split, **load_kwargs)
|
|
|
|
def load_computage_bench(split: str = "train", **load_kwargs: Any) -> Any:
    """
    Load the ComputAge benchmark for epigenetic aging clocks from the Hugging Face Hub.

    Args:
        split: dataset split to load
        **load_kwargs: extra keyword arguments forwarded to
            ``datasets.load_dataset`` (e.g. ``cache_dir``, ``revision``)

    Returns:
        The ``datasets.Dataset`` returned by ``load_dataset`` (not a plain
        dict). The original notes describe methylation samples with age
        labels; confirm the exact column names against the Hub dataset.
    """
    return load_dataset("computage/computage_bench", split=split, **load_kwargs)
|
|
|
|
def load_balm_benchmark(
    config: str = "BindingDB_filtered",
    split: str = "train",
    **load_kwargs: Any,
) -> Any:
    """
    Load the BALM drug-target binding affinity benchmark from the Hugging Face Hub.

    Args:
        config: benchmark configuration (sub-dataset) name
        split: dataset split to load (previously always "train")
        **load_kwargs: extra keyword arguments forwarded to
            ``datasets.load_dataset`` (e.g. ``cache_dir``, ``revision``)

    Returns:
        The ``datasets.Dataset`` returned by ``load_dataset`` (not a plain
        dict). The original notes list the columns Drug (SMILES), Target
        (protein sequence) and Y (affinity); confirm per config.
    """
    return load_dataset("BALM/BALM-benchmark", config, split=split, **load_kwargs)
|
|
|
|
| |
|
|
def collate_multi_omics(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    """
    Collate function for multi-omics batches.

    Stacks each key of the first item across the batch along a new leading
    dimension: 0-dim entries become shape ``(B,)`` and 1-D feature vectors
    become ``(B, n_features)``. Tensors under one key must share a shape;
    variable-length inputs are NOT padded (``torch.stack`` raises instead).

    Raises:
        ValueError: if ``batch`` is empty.
    """
    if not batch:
        raise ValueError("Cannot collate an empty batch")
    return {key: torch.stack([item[key] for item in batch]) for key in batch[0]}
|
|
|
|
def create_data_loader(
    dataset: Dataset,
    batch_size: int = 256,
    shuffle: bool = True,
    num_workers: int = 4,
    pin_memory: Optional[bool] = None,
) -> DataLoader:
    """
    Create a DataLoader for multi-omics data, batched with ``collate_multi_omics``.

    Args:
        dataset: dataset yielding dicts of tensors
        batch_size: samples per batch
        shuffle: reshuffle sample order every epoch
        num_workers: number of loader worker processes
        pin_memory: page-lock host batches for faster host-to-GPU copies.
            ``None`` (default) enables it only when CUDA is available, since
            pinning brings no benefit on a CPU-only machine.

    Returns:
        A configured ``torch.utils.data.DataLoader``.
    """
    if pin_memory is None:
        pin_memory = torch.cuda.is_available()
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate_multi_omics,
        pin_memory=pin_memory,
    )
|
|