MuLGIT / mulgit /data.py
vedatonuryilmaz's picture
Upload mulgit/data.py with huggingface_hub
3305fae verified
"""
Multi-Omics Data Loading and Preprocessing
Handles loading MLOmics, Tabula Muris Senis, ComputAge, and BALM datasets,
converting them into compatible formats for MuLGIT training.
Data sources:
- AIBIC/MLOmics: TCGA multi-omics (mRNA, miRNA, methylation, CNV) + survival
- longevity-db/Tabula_Muris_Senis_10x: Mouse aging scRNA-seq
- computage/computage_bench: DNA methylation aging clocks
- BALM/BALM-benchmark: Drug-target binding affinity
"""
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from typing import Optional, Dict, List, Tuple, Any
from pathlib import Path
import os
from datasets import load_dataset
# ─── MLOmics Dataset Loader ─────────────────────────────────────────────────
class MLOmicsDataset(Dataset):
    """
    Loads TCGA multi-omics data from AIBIC/MLOmics.

    Data format: CSV files with genes as rows, samples as columns (transposed).
    We transpose to get a samples x genes layout.

    Output per sample:
    - methylation: tensor of methylation features
    - cnv: tensor of copy number variation features
    - mrna: tensor of gene expression features
    - mirna: tensor of microRNA expression features
    - survival_times: event/censoring time in days
    - event_observed: 1 if death, 0 if censored

    Only the modalities whose CSV file is found are present in the output;
    the two survival keys are present only when a survival file is found.
    """

    # Modalities are restricted to their shared genes only when strictly more
    # than this many genes are common to all of them.
    _MIN_COMMON_GENES = 100

    def __init__(
        self,
        cache_dir: str,
        cancer_type: str = "pan-cancer",
        feature_scale: str = "Original",  # "Original", "Aligned", "Top"
        normalize: bool = True,
        common_genes_only: bool = True,
    ):
        """
        Args:
            cache_dir: path to downloaded MLOmics dataset
            cancer_type: TCGA cancer code or "pan-cancer"
            feature_scale: which feature set to use
            normalize: whether to z-score each returned feature vector
                (applied per sample, across that sample's features, in
                ``__getitem__``)
            common_genes_only: use only genes present across all modalities

        Raises:
            ValueError: if ``cancer_type`` is "pan-cancer" (no survival
                labels), if no modality file is found, or if survival rows
                cannot be matched to the omics samples.
        """
        self.cache_dir = Path(cache_dir)
        self.normalize = normalize

        if cancer_type == "pan-cancer":
            # Pan-cancer classification data has no survival labels, so it
            # cannot be used here; point the caller at per-cancer data.
            raise ValueError(
                "Pan-cancer classification data lacks survival labels. "
                "Use individual cancer types from Clustering_datasets which have survival data. "
                "Example: cancer_type='ACC', 'KIRC', 'LIHC', etc."
            )

        main_dir = self.cache_dir / "Main_Dataset"
        # Prefer the clustering layout; fall back to the classification
        # layout when it is absent.
        cluster_path = main_dir / "Clustering_datasets" / cancer_type / feature_scale
        if cluster_path.exists():
            base_path = cluster_path
        else:
            base_path = main_dir / "Classification_datasets" / f"GS-{cancer_type}" / feature_scale

        # _load_data decides has_survival from the presence of the survival file.
        self._load_data(base_path, cancer_type)
        if common_genes_only and len(self.gene_sets) > 1:
            self._align_genes()

    def _load_data(self, base_path: Path, cancer_type: str):
        """Load and transpose the per-modality CSV files and the survival table."""
        self.modalities = {}
        self.gene_sets = {}

        # File naming convention
        file_map = {
            "methylation": f"{cancer_type}_Methy",
            "cnv": f"{cancer_type}_CNV",
            "mrna": f"{cancer_type}_mRNA",
            "mirna": f"{cancer_type}_miRNA",
        }
        for mod_name, file_prefix in file_map.items():
            # Try the known suffixes in priority order; take the first that exists.
            candidates = (
                base_path / f"{file_prefix}{ext}"
                for ext in ("_top.csv", "_aligned.csv", ".csv")
            )
            filepath = next((p for p in candidates if p.exists()), None)
            if filepath is None:
                print(f"Warning: {mod_name} file not found for {cancer_type}")
                continue

            df = pd.read_csv(filepath, index_col=0)
            # Transpose: genes x samples -> samples x genes
            self.modalities[mod_name] = df.T.astype(np.float32)
            self.gene_sets[mod_name] = set(df.index)

        # Load survival data if available
        surv_path = base_path / f"survival_{cancer_type}.csv"
        if surv_path.exists():
            surv_df = pd.read_csv(surv_path, index_col=0)
            surv_df = surv_df[["survival_times", "event_observed"]].astype(np.float32)
            # Keep the indexed table so _align_samples can match rows by sample ID.
            self._survival_df = surv_df
            self.survival_times = surv_df["survival_times"].values
            self.event_observed = surv_df["event_observed"].values
            self.has_survival = True
        else:
            self._survival_df = None
            self.survival_times = None
            self.event_observed = None
            self.has_survival = False

        # Align sample IDs across modalities (and survival labels)
        self._align_samples()

    def _align_samples(self):
        """Ensure all modalities and survival labels share the same samples in the same order."""
        if not self.modalities:
            raise ValueError("No common samples across modalities")

        # Intersection of sample IDs over every loaded modality.
        common = set.intersection(*(set(df.index) for df in self.modalities.values()))

        surv_df = getattr(self, "_survival_df", None)
        match_by_id = False
        if self.has_survival and surv_df is not None and common:
            surv_df = surv_df.copy()
            # Compare IDs as strings and keep one row per sample ID.
            surv_df.index = surv_df.index.astype(str)
            surv_df = surv_df[~surv_df.index.duplicated(keep="first")]
            labelled = common & set(surv_df.index)
            if labelled:
                dropped = len(common) - len(labelled)
                if dropped:
                    print(f"Warning: dropping {dropped} samples without survival labels")
                common = labelled
                match_by_id = True
            elif len(surv_df) == len(common):
                # No ID overlap at all: keep the positional pairing, but say so.
                print(
                    "Warning: survival sample IDs do not match omics sample IDs; "
                    "assuming survival rows follow sorted sample order"
                )
            else:
                raise ValueError(
                    "Survival data cannot be matched to omics samples: "
                    "no shared sample IDs and differing sample counts"
                )

        sample_ids = sorted(common)

        # Reindex all modalities
        for mod_name in self.modalities:
            self.modalities[mod_name] = self.modalities[mod_name].loc[sample_ids]

        # Reindex survival data so row i belongs to sample_ids[i]
        if match_by_id:
            aligned = surv_df.loc[sample_ids]
            self.survival_times = aligned["survival_times"].to_numpy(dtype=np.float32)
            self.event_observed = aligned["event_observed"].to_numpy(dtype=np.float32)

        self.sample_ids = sample_ids
        self.n_samples = len(sample_ids)

    def _align_genes(self):
        """Use only genes present in all modalities (where applicable)."""
        if not self.gene_sets:
            return

        # set.intersection builds a new set, so self.gene_sets is left untouched.
        common_genes = set.intersection(*self.gene_sets.values())

        # Without a reasonable overlap, keep all genes per modality.
        if len(common_genes) <= self._MIN_COMMON_GENES:
            return

        for mod_name, df in list(self.modalities.items()):
            # A boolean mask keeps the frame's own column order (deterministic).
            self.modalities[mod_name] = df.loc[:, df.columns.isin(common_genes)]

    def __len__(self) -> int:
        return self.n_samples

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return the feature tensors (and survival labels, if any) for sample ``idx``."""
        item = {}
        for mod_name, df in self.modalities.items():
            values = df.iloc[idx].values
            if self.normalize:
                # z-score across this sample's features; epsilon avoids division by zero
                values = (values - values.mean()) / (values.std() + 1e-8)
            item[mod_name] = torch.tensor(values, dtype=torch.float32)

        if self.has_survival:
            item["survival_times"] = torch.tensor(self.survival_times[idx], dtype=torch.float32)
            item["event_observed"] = torch.tensor(self.event_observed[idx], dtype=torch.float32)
        return item

    @property
    def feature_dims(self) -> Dict[str, int]:
        """Number of feature columns per loaded modality."""
        return {name: df.shape[1] for name, df in self.modalities.items()}
# ─── Simplified: Synthetic Multi-Omics Generator for Rapid Prototyping ──────
def generate_synthetic_multi_omics(
    n_samples: int = 1000,
    n_methylation: int = 1000,
    n_cnv: int = 1000,
    n_mrna: int = 1000,
    n_mirna: int = 300,
    seed: int = 42,
) -> Tuple[Dict[str, np.ndarray], np.ndarray, np.ndarray]:
    """
    Generate synthetic multi-omics data with known survival structure.

    Useful for testing the pipeline without downloading large datasets.
    A single latent risk value per sample drives both the omics features and
    the simulated event times, so the features carry survival signal.

    Args:
        n_samples: number of samples (rows) to generate
        n_methylation: number of methylation features
        n_cnv: number of CNV features; only the first ``min(10, n_cnv)``
            columns carry risk signal
        n_mrna: number of mRNA features
        n_mirna: number of miRNA features
        seed: seed for the ``numpy.random.RandomState`` generator

    Returns:
        modalities: dict of float32 feature matrices keyed by
            "methylation", "cnv", "mrna", "mirna"
        survival_times: float32 array, minimum of event and censoring time
        event_observed: float32 array, 1.0 where the event preceded censoring
    """
    rng = np.random.RandomState(seed)

    # Latent risk factor, as a column so it broadcasts over feature loadings.
    latent_risk = rng.randn(n_samples)
    risk_col = latent_risk[:, None]

    modalities = {}

    # Methylation: every feature gets a random loading on the risk factor.
    meth_noise = rng.randn(n_samples, n_methylation) * 0.5
    meth_signal = risk_col * rng.randn(1, n_methylation) * 0.5
    modalities["methylation"] = (meth_signal + meth_noise).astype(np.float32)

    # CNV: mostly noise; only the leading columns are tied to risk.
    cnv = rng.randn(n_samples, n_cnv) * 0.3
    # Clamp so fewer than 10 CNV features no longer triggers a broadcast error.
    n_driver = min(10, n_cnv)
    cnv[:, :n_driver] += risk_col * rng.randn(1, n_driver) * 0.3
    modalities["cnv"] = cnv.astype(np.float32)

    # mRNA: strongest loading on risk, lowest noise.
    mrna_noise = rng.randn(n_samples, n_mrna) * 0.3
    mrna_signal = risk_col * rng.randn(1, n_mrna) * 0.7
    modalities["mrna"] = (mrna_signal + mrna_noise).astype(np.float32)

    # miRNA: regulatory
    mirna_noise = rng.randn(n_samples, n_mirna) * 0.4
    mirna_signal = risk_col * rng.randn(1, n_mirna) * 0.6
    modalities["mirna"] = (mirna_signal + mirna_noise).astype(np.float32)

    # Event times: exponential baseline times (scale 365 days) shortened by
    # the per-sample hazard ratio exp(0.5 * risk).
    baseline_times = rng.exponential(scale=365.0, size=n_samples)
    hazard_ratio = np.exp(latent_risk * 0.5)
    event_times = baseline_times / hazard_ratio

    # Independent exponential censoring times (scale 1000 days).
    censor_time = rng.exponential(scale=1000.0, size=n_samples)
    observed_times = np.minimum(event_times, censor_time)
    event_observed = (event_times <= censor_time).astype(np.float32)

    return modalities, observed_times.astype(np.float32), event_observed
class SyntheticMultiOmicsDataset(Dataset):
    """Expose in-memory multi-omics arrays and survival labels as a PyTorch Dataset."""

    def __init__(
        self,
        modalities: Dict[str, np.ndarray],
        survival_times: np.ndarray,
        event_observed: np.ndarray,
    ):
        # Convert every array to a tensor once, up front.
        converted = {}
        for name, matrix in modalities.items():
            converted[name] = torch.tensor(matrix)
        self.modalities = converted
        self.survival_times = torch.tensor(survival_times)
        self.event_observed = torch.tensor(event_observed)
        # Dataset length is taken from the survival-time array.
        self.n_samples = len(survival_times)

    def __len__(self) -> int:
        return self.n_samples

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        # Modality entries first, then the two survival labels.
        sample = {name: tensor[idx] for name, tensor in self.modalities.items()}
        sample["survival_times"] = self.survival_times[idx]
        sample["event_observed"] = self.event_observed[idx]
        return sample
# ─── HF Dataset Loaders ─────────────────────────────────────────────────────
def load_tabula_muris_senis(split: str = "train") -> Dict:
    """
    Fetch the Tabula Muris Senis aging mouse scRNA-seq dataset from the HF Hub.

    Args:
        split: dataset split to request (default "train")

    Returns:
        Whatever ``load_dataset`` returns for that split of
        "longevity-db/Tabula_Muris_Senis_10x".
    """
    repo_id = "longevity-db/Tabula_Muris_Senis_10x"
    return load_dataset(repo_id, split=split)
def load_computage_bench(split: str = "train") -> Dict:
    """
    Fetch the ComputAge benchmark for epigenetic aging clocks from the HF Hub.

    Args:
        split: dataset split to request (default "train")

    Returns:
        Whatever ``load_dataset`` returns for that split of
        "computage/computage_bench".
    """
    repo_id = "computage/computage_bench"
    return load_dataset(repo_id, split=split)
def load_balm_benchmark(config: str = "BindingDB_filtered") -> Dict:
    """
    Fetch the BALM drug-target binding affinity benchmark from the HF Hub.

    Args:
        config: dataset configuration name passed to ``load_dataset``

    Returns:
        Whatever ``load_dataset`` returns for the "train" split of
        "BALM/BALM-benchmark" under the given configuration.
    """
    repo_id = "BALM/BALM-benchmark"
    return load_dataset(repo_id, config, split="train")
# ─── Data Collation ──────────────────────────────────────────────────────────
def collate_multi_omics(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    """
    Collate function for multi-omics batches.

    Stacks the tensors found under each key along a new leading batch
    dimension. Scalars and vectors are handled the same way by
    ``torch.stack``, so all tensors under one key must share a shape;
    variable-length tensors are not padded.

    Args:
        batch: list of per-sample dicts; the keys of the first sample define
            the keys of the result

    Returns:
        Dict mapping each key to the stacked batch tensor, or an empty dict
        for an empty batch.
    """
    if not batch:
        return {}
    return {key: torch.stack([item[key] for item in batch]) for key in batch[0]}
def create_data_loader(
    dataset: Dataset,
    batch_size: int = 256,
    shuffle: bool = True,
    num_workers: int = 4,
    pin_memory: Optional[bool] = None,
) -> DataLoader:
    """
    Create a DataLoader for multi-omics data.

    Args:
        dataset: dataset yielding per-sample dicts of tensors
        batch_size: number of samples per batch
        shuffle: whether to reshuffle the data every epoch
        num_workers: number of worker processes used for loading
        pin_memory: whether to pin host memory; when None (default) it is
            enabled only if CUDA is available, since pinning only benefits
            host-to-GPU transfers

    Returns:
        A DataLoader that batches samples with ``collate_multi_omics``.
    """
    if pin_memory is None:
        pin_memory = torch.cuda.is_available()
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate_multi_omics,
        pin_memory=pin_memory,
    )