# MuLGIT / mulgit / path_ae.py
# vedatonuryilmaz's picture
# Upload mulgit/path_ae.py with huggingface_hub
# c7ad40e verified
"""
PATH-AE: Cross-Species Omics Transfer Learning via Autoencoders
Enables knowledge transfer between species (e.g., C. elegans → mouse → human)
for elucidating evolutionarily conserved and species-specific molecular
determinants of exceptional longevity.
Architecture:
1. Per-species autoencoders learn compact latent representations
2. Cross-species alignment loss maps orthologous features to shared space
3. Phylogenetic distance weighting accounts for evolutionary relationships
4. Species-specific decoders identify conserved vs species-specific features
References:
- EVA (arxiv:2602.10168): Cross-species multimodal foundation model
- GPN/PhyloGPN (arxiv:2503.03773): Phylogenetic genomic language model
- Domain adaptation via autoencoders for cross-species transfer
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict, List, Tuple, Set
# ─── Ortholog Mapping ────────────────────────────────────────────────────────
class OrthologMapper:
    """
    Maps features between species using ortholog relationships.

    In practice, this uses Ensembl BioMart data to map human genes to
    their orthologs in model organisms (mouse, C. elegans).
    For the MVP, we provide a programmatic interface that can be populated
    from Ensembl BioMart exports.
    """

    def __init__(self):
        # species_a → species_b → [[gene_a1, gene_b1], [gene_a2, gene_b2], ...]
        self._ortholog_map: Dict[str, Dict[str, List[List[str]]]] = {}

    def load_from_ensembl(self, biomart_csv_path: str):
        """Load ortholog mappings from a CSV export, replacing any loaded before.

        The file must have a header row. Columns read per row:
            source_species: defaults to "human" when the column is absent
            target_species: defaults to "mouse" when the column is absent
            source_gene, target_gene: the ortholog pair

        Rows whose source or target gene is missing or empty are skipped,
        since they do not describe an ortholog pair. The previously loaded
        map is only replaced once the whole file has been read, so a failed
        load leaves the mapper unchanged.

        Args:
            biomart_csv_path: path of the CSV file to read.

        Raises:
            OSError: if the file cannot be opened.
        """
        import csv

        loaded: Dict[str, Dict[str, List[List[str]]]] = {}
        # newline="" is what the csv module requires for correct handling of
        # quoted fields; utf-8-sig also tolerates a leading BOM, which would
        # otherwise corrupt the first header name.
        with open(biomart_csv_path, newline="", encoding="utf-8-sig") as f:
            for row in csv.DictReader(f):
                source_gene = row.get("source_gene")
                target_gene = row.get("target_gene")
                if not source_gene or not target_gene:
                    # No ortholog on one side: nothing to align.
                    continue
                source_species = row.get("source_species", "human")
                target_species = row.get("target_species", "mouse")
                pairs = loaded.setdefault(source_species, {}).setdefault(
                    target_species, []
                )
                pairs.append([source_gene, target_gene])
        self._ortholog_map = loaded

    def get_shared_genes(
        self, species_a: str, species_b: str
    ) -> Tuple[List[str], List[str]]:
        """Get aligned ortholog pairs between two species.

        Lookup is directional: only pairs loaded with ``species_a`` as the
        source and ``species_b`` as the target are returned.

        Returns:
            (genes_a, genes_b): two lists of equal length where
            ``genes_a[i]`` and ``genes_b[i]`` form one ortholog pair.
            Both lists are empty when nothing is known for the pair.
        """
        pairs = self._ortholog_map.get(species_a, {}).get(species_b)
        if not pairs:
            return [], []
        genes_a, genes_b = zip(*pairs)
        return list(genes_a), list(genes_b)
# ─── Per-Species Autoencoder ─────────────────────────────────────────────────
class SpeciesAutoencoder(nn.Module):
    """
    Autoencoder for a single species' omics data.

    Encoder: omics features → compact latent representation
    Decoder: latent → reconstructed omics features

    Both halves are stacks of Linear → SELU → AlphaDropout blocks closed by
    a plain Linear layer; the decoder mirrors the encoder's hidden widths
    in reverse order. The latent space serves as a "molecular embedding"
    that can be aligned across species.

    Args:
        input_dim: number of input features.
        latent_dim: size of the latent representation.
        hidden_dims: hidden layer widths of the encoder, in order;
            ``[512, 256]`` is used when None.
        dropout: probability passed to every AlphaDropout layer.
    """

    def __init__(
        self,
        input_dim: int,
        latent_dim: int = 128,
        hidden_dims: List[int] = None,
        dropout: float = 0.1,
    ):
        super().__init__()
        widths = [512, 256] if hidden_dims is None else hidden_dims
        # The encoder is created before the decoder so that parameter
        # initialisation consumes the RNG in the same order as always.
        self.encoder = self._stack([input_dim, *widths, latent_dim], dropout)
        self.decoder = self._stack(
            [latent_dim, *reversed(widths), input_dim], dropout
        )

    @staticmethod
    def _stack(sizes: List[int], dropout: float) -> nn.Sequential:
        """Build an MLP through ``sizes``; the final Linear has no activation."""
        modules = []
        final_idx = len(sizes) - 2
        for idx, (fan_in, fan_out) in enumerate(zip(sizes, sizes[1:])):
            modules.append(nn.Linear(fan_in, fan_out))
            if idx < final_idx:
                modules.append(nn.SELU())
                modules.append(nn.AlphaDropout(dropout))
        return nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encode ``x`` and decode it again.

        Returns:
            latent: encoded representation (B, latent_dim)
            reconstructed: decoded reconstruction (B, input_dim)
        """
        z = self.encoder(x)
        return z, self.decoder(z)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Return only the latent representation of ``x``."""
        return self.encoder(x)
# ─── Cross-Species Alignment ─────────────────────────────────────────────────
class CrossSpeciesAligner(nn.Module):
    """
    Aligns latent representations across species using ortholog-guided
    contrastive learning.

    For each ortholog pair, the model learns to map the species-specific
    representations close together in the shared latent space, while
    separating non-orthologous gene pairs.

    Args:
        latent_dim: width of the incoming latents and of the shared space.
        temperature: divisor applied to the cosine similarities.
    """

    def __init__(
        self,
        latent_dim: int = 128,
        temperature: float = 0.07,
    ):
        super().__init__()
        # Projection heads to map species-specific latents to shared space
        self.projectors = nn.ModuleDict()
        self.latent_dim = latent_dim
        self.temperature = temperature

    def add_species(self, species_name: str):
        """Register a new species with its projection head.

        Calling this again for an already registered species is a no-op, so
        an existing (possibly trained) head is never replaced.
        """
        if species_name not in self.projectors:
            self.projectors[species_name] = nn.Sequential(
                nn.Linear(self.latent_dim, self.latent_dim),
                nn.SELU(),
                nn.Linear(self.latent_dim, self.latent_dim),
            )

    def project(
        self, species_name: str, latent: torch.Tensor
    ) -> torch.Tensor:
        """Project species-specific latent to shared alignment space.

        Raises:
            KeyError: if ``species_name`` was never passed to ``add_species``.
        """
        return self.projectors[species_name](latent)

    def alignment_loss(
        self,
        latent_a: torch.Tensor,
        species_a: str,
        latent_b: torch.Tensor,
        species_b: str,
        ortholog_mask: torch.Tensor,  # (N_a, N_b) boolean
    ) -> torch.Tensor:
        """
        Symmetric InfoNCE alignment loss over orthologous pairs.

        Every marked (i, j) pair contributes one A→B term (row i of the
        similarity matrix as the softmax denominator) and one B→A term
        (column j as the denominator). Indexing by pair rather than by row
        keeps the positive and denominator terms the same length even when
        one gene has several orthologs.

        Args:
            latent_a: (N_a, D) representations from species A
            species_a: name of species A
            latent_b: (N_b, D) representations from species B
            species_b: name of species B
            ortholog_mask: (N_a, N_b) 1 if (i,j) are orthologs, 0 otherwise

        Returns:
            Scalar loss: the mean of the two directional losses. When the
            mask marks no pair at all, a zero that is still connected to the
            autograd graph is returned instead of NaN.

        Raises:
            IndexError: if ``ortholog_mask`` is not shaped (N_a, N_b).
        """
        # Project to the shared space and L2-normalise so that the dot
        # product below is a cosine similarity.
        z_a = F.normalize(self.project(species_a, latent_a), dim=-1)
        z_b = F.normalize(self.project(species_b, latent_b), dim=-1)

        # Similarity matrix
        sim = torch.matmul(z_a, z_b.T) / self.temperature  # (N_a, N_b)

        mask = ortholog_mask.to(sim.device).bool()
        if mask.shape != sim.shape:
            raise IndexError(
                f"ortholog_mask has shape {tuple(mask.shape)} but the "
                f"similarity matrix has shape {tuple(sim.shape)}"
            )

        rows, cols = mask.nonzero(as_tuple=True)
        if rows.numel() == 0:
            # No positives: averaging an empty tensor would produce NaN and
            # poison the optimiser step.
            return sim.sum() * 0.0

        # Positive pairs: orthologs
        pos_sim = sim[rows, cols]
        # A→B: for each pair, contrast against every gene in B
        loss_a_to_b = torch.logsumexp(sim, dim=1)[rows] - pos_sim
        # B→A: for each pair, contrast against every gene in A
        loss_b_to_a = torch.logsumexp(sim, dim=0)[cols] - pos_sim
        return (loss_a_to_b.mean() + loss_b_to_a.mean()) / 2
# ─── PATH-AE: Full Cross-Species Transfer Framework ──────────────────────────
class PATH_AE(nn.Module):
    """
    Phylogenetic Autoencoder for Transfer Heterogeneity.

    Cross-species omics transfer learning framework combining:
    1. Per-species autoencoders for feature compression
    2. Cross-species alignment via ortholog contrastive learning
    3. A single linear "conservation" head applied to the latents

    NOTE: phylogenetic distance weighting is part of the intended design
    but is not implemented in this class; all species pairs are weighted
    equally.
    """

    def __init__(
        self,
        species_dims: Dict[str, int],
        latent_dim: int = 128,
        alignment_strength: float = 0.1,
        reconstruction_strength: float = 1.0,
    ):
        """
        Args:
            species_dims: {"human": 20000, "mouse": 20000, "celegans": 18000}
            latent_dim: dimension of shared latent space
            alignment_strength: weight for cross-species alignment loss
            reconstruction_strength: weight for autoencoder reconstruction loss
        """
        super().__init__()
        # Per-species autoencoders
        self.autoencoders = nn.ModuleDict({
            species: SpeciesAutoencoder(dim, latent_dim)
            for species, dim in species_dims.items()
        })
        # Cross-species aligner with one projection head per species
        self.aligner = CrossSpeciesAligner(latent_dim)
        for species in species_dims:
            self.aligner.add_species(species)
        # Classifier for conserved vs species-specific (single binary head)
        self.conservation_classifier = nn.Linear(latent_dim, 1)
        self.alignment_strength = alignment_strength
        self.reconstruction_strength = reconstruction_strength
        self.species_dims = species_dims

    def encode(
        self, species: str, x: torch.Tensor
    ) -> torch.Tensor:
        """Encode species-specific features to latent space."""
        return self.autoencoders[species].encode(x)

    def encode_all(
        self, species_data: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """Encode all species' data to latent space."""
        return {
            species: self.encode(species, x)
            for species, x in species_data.items()
        }

    def compute_conservation_score(
        self, species: str, x: torch.Tensor
    ) -> torch.Tensor:
        """
        Score each input row with the conservation head.

        Returns the sigmoid of the linear head applied to the latent, i.e.
        one value in (0, 1) per row. The intended reading is high =
        conserved across species, low = species-specific; no training
        objective for this head is defined in this class.
        """
        latent = self.encode(species, x)
        return torch.sigmoid(self.conservation_classifier(latent))

    def reconstruction_loss(
        self,
        species: str,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """MSE reconstruction loss for autoencoder."""
        _, reconstructed = self.autoencoders[species](x)
        return F.mse_loss(reconstructed, x)

    def transfer_features(
        self,
        source_species: str,
        source_data: torch.Tensor,
        target_species: str,
    ) -> torch.Tensor:
        """
        Embed source-species data in the shared alignment space.

        Encodes ``source_data`` with the source autoencoder and applies the
        source projection head. The result lives in the shared space, not
        in the target species' own latent space: ``target_species`` is
        accepted for API symmetry but is currently unused, because no
        reverse (shared → target) mapping exists.
        """
        source_latent = self.encode(source_species, source_data)
        return self.aligner.project(source_species, source_latent)

    def _mean_loss(self, losses: Dict) -> torch.Tensor:
        """Average the scalar losses in ``losses``; zero when there are none."""
        if not losses:
            return self.conservation_classifier.weight.new_zeros(())
        return torch.stack(list(losses.values())).mean()

    def forward(
        self,
        species_data: Dict[str, torch.Tensor],
        ortholog_masks: Optional[Dict[Tuple[str, str], torch.Tensor]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Full forward pass.

        Args:
            species_data: {"human": tensor, "mouse": tensor, ...}
            ortholog_masks: {(species_a, species_b): mask_tensor, ...}.
                Each mask must be shaped (rows of species_a's data, rows
                of species_b's data), as required by the aligner's
                ``alignment_loss``. Pairs naming a species that is absent
                from ``species_data`` are skipped.

        Returns dict with:
            latents: per-species latent representations
            aligned: per-species aligned representations
            reconstructed: per-species reconstructions
            conservation_scores: per-species conservation scores
            reconstruction_losses: per-species MSE reconstruction loss
            alignment_losses: per-pair contrastive loss, keyed like
                ``ortholog_masks`` (empty when no masks are given)
            loss: ``reconstruction_strength`` × mean reconstruction loss
                + ``alignment_strength`` × mean alignment loss
        """
        latents = {}
        aligned = {}
        reconstructed = {}
        conservation_scores = {}
        reconstruction_losses = {}
        for species, x in species_data.items():
            latent, recon = self.autoencoders[species](x)
            latents[species] = latent
            reconstructed[species] = recon
            aligned[species] = self.aligner.project(species, latent)
            conservation_scores[species] = torch.sigmoid(
                self.conservation_classifier(latent)
            )
            # Reuse the reconstruction from this pass instead of running the
            # autoencoder a second time.
            reconstruction_losses[species] = F.mse_loss(recon, x)

        alignment_losses = {}
        for (species_a, species_b), mask in (ortholog_masks or {}).items():
            if species_a not in latents or species_b not in latents:
                continue
            alignment_losses[(species_a, species_b)] = (
                self.aligner.alignment_loss(
                    latents[species_a],
                    species_a,
                    latents[species_b],
                    species_b,
                    mask.to(latents[species_a].device),
                )
            )

        loss = (
            self.reconstruction_strength * self._mean_loss(reconstruction_losses)
            + self.alignment_strength * self._mean_loss(alignment_losses)
        )
        return {
            "latents": latents,
            "aligned": aligned,
            "reconstructed": reconstructed,
            "conservation_scores": conservation_scores,
            "reconstruction_losses": reconstruction_losses,
            "alignment_losses": alignment_losses,
            "loss": loss,
        }
# ─── Pre-built Ortholog Maps (from Ensembl BioMart) ─────────────────────────
# These are programmatic stubs — in production, load from BioMart exports.
def create_human_mouse_ortholog_map() -> OrthologMapper:
    """
    Return the ortholog mapper meant to hold human ↔ mouse pairs.

    This is a stub: the mapper comes back empty. In production it would be
    filled from Ensembl BioMart (one-to-one orthologs, ~16,000 genes),
    e.g. by querying the BioMart REST API.
    """
    # Nothing is populated here yet; callers receive a fresh, empty mapper.
    return OrthologMapper()
def create_default_path_ae() -> PATH_AE:
    """
    Build the default PATH-AE model covering human, mouse, and C. elegans.

    Every species is given an input width of 20,000 features, matching the
    approximate number of protein-coding genes:
    - Human: ~20,000
    - Mouse: ~20,000
    - C. elegans: ~20,000

    The shared latent space has 128 dimensions.
    """
    # Same width for every species; insertion order is human, mouse, celegans.
    gene_counts = dict.fromkeys(("human", "mouse", "celegans"), 20000)
    return PATH_AE(species_dims=gene_counts, latent_dim=128)