| """ |
| PATH-AE: Cross-Species Omics Transfer Learning via Autoencoders |
| |
| Enables knowledge transfer between species (e.g., C. elegans → mouse → human) |
| for elucidating evolutionarily conserved and species-specific molecular |
| determinants of exceptional longevity. |
| |
| Architecture: |
| 1. Per-species autoencoders learn compact latent representations |
| 2. Cross-species alignment loss maps orthologous features to shared space |
| 3. Phylogenetic distance weighting accounts for evolutionary relationships |
| 4. Species-specific decoders identify conserved vs species-specific features |
| |
| References: |
| - EVA (arxiv:2602.10168): Cross-species multimodal foundation model |
| - GPN/PhyloGPN (arxiv:2503.03773): Phylogenetic genomic language model |
| - Domain adaptation via autoencoders for cross-species transfer |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import Optional, Dict, List, Tuple, Set |
|
|
|
|
| |
|
|
class OrthologMapper:
    """
    Maps features between species using ortholog relationships.

    The mapper is an in-memory table keyed by source species, then target
    species, holding ``[source_gene, target_gene]`` pairs in the order they
    were loaded. It starts empty and is populated with
    :meth:`load_from_ensembl` from a CSV export (e.g. from Ensembl BioMart).
    """

    def __init__(self):
        # source species -> target species -> list of [source_gene, target_gene]
        self._ortholog_map: Dict[str, Dict[str, List[List[str]]]] = {}

    def load_from_ensembl(self, biomart_csv_path: str):
        """
        Replace the current mappings with those read from a CSV file.

        The file must have a header row. The columns read are
        ``source_species``, ``target_species``, ``source_gene`` and
        ``target_gene``. A missing or blank species cell falls back to
        ``"human"`` (source) / ``"mouse"`` (target). Rows where either gene
        id is missing or blank are skipped, because they do not describe an
        ortholog pair.

        Args:
            biomart_csv_path: path to the CSV export.

        Raises:
            OSError: if the file cannot be opened.
        """
        import csv

        # Build into a local dict so a failure mid-file does not leave the
        # mapper half-populated.
        loaded: Dict[str, Dict[str, List[List[str]]]] = {}

        # newline="" is what the csv module expects for correct handling of
        # quoted fields with embedded newlines; utf-8-sig drops a leading
        # BOM that would otherwise corrupt the first header name.
        with open(biomart_csv_path, newline="", encoding="utf-8-sig") as f:
            for row in csv.DictReader(f):
                source_gene = (row.get("source_gene") or "").strip()
                target_gene = (row.get("target_gene") or "").strip()
                if not source_gene or not target_gene:
                    continue

                source_species = (row.get("source_species") or "").strip() or "human"
                target_species = (row.get("target_species") or "").strip() or "mouse"

                loaded.setdefault(source_species, {}).setdefault(
                    target_species, []
                ).append([source_gene, target_gene])

        self._ortholog_map = loaded

    def get_shared_genes(
        self, species_a: str, species_b: str
    ) -> Tuple[List[str], List[str]]:
        """
        Get aligned ortholog pairs between two species.

        Lookup is directional: only pairs loaded with ``species_a`` as the
        source and ``species_b`` as the target are returned.

        Returns:
            ``(genes_a, genes_b)`` where ``genes_a[i]`` and ``genes_b[i]``
            form the i-th loaded pair; two empty lists if nothing is known
            for this species pair.
        """
        pairs = self._ortholog_map.get(species_a, {}).get(species_b, [])
        genes_a = [pair[0] for pair in pairs]
        genes_b = [pair[1] for pair in pairs]
        return genes_a, genes_b
|
|
|
|
| |
|
|
class SpeciesAutoencoder(nn.Module):
    """
    Autoencoder for a single species' omics data.

    Encoder: omics features -> compact latent representation
    Decoder: latent -> reconstructed omics features

    Both halves are stacks of ``Linear -> SELU -> AlphaDropout`` blocks
    followed by a final ``Linear``; the decoder mirrors the encoder's hidden
    widths in reverse order.
    """

    def __init__(
        self,
        input_dim: int,
        latent_dim: int = 128,
        hidden_dims: List[int] = None,
        dropout: float = 0.1,
    ):
        """
        Args:
            input_dim: number of input (and reconstructed) features.
            latent_dim: size of the latent representation.
            hidden_dims: encoder hidden widths, outermost first; defaults to
                ``[512, 256]`` when ``None``.
            dropout: probability passed to every ``AlphaDropout`` layer.
        """
        super().__init__()
        widths = [512, 256] if hidden_dims is None else hidden_dims

        # Encoder is built before the decoder so parameter creation order
        # (and hence seeded initialisation) is encoder-first.
        self.encoder = self._build_stack(input_dim, widths, latent_dim, dropout)
        self.decoder = self._build_stack(
            latent_dim, list(reversed(widths)), input_dim, dropout
        )

    @staticmethod
    def _build_stack(in_dim, widths, out_dim, dropout):
        """Build Linear/SELU/AlphaDropout blocks ending in a plain Linear."""
        sizes = [in_dim, *widths]
        modules = []
        for fan_in, fan_out in zip(sizes, sizes[1:]):
            modules.append(nn.Linear(fan_in, fan_out))
            modules.append(nn.SELU())
            modules.append(nn.AlphaDropout(dropout))
        modules.append(nn.Linear(sizes[-1], out_dim))
        return nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encode ``x`` and decode it again.

        Returns:
            latent: encoded representation (B, latent_dim)
            reconstructed: decoded reconstruction (B, input_dim)
        """
        code = self.encoder(x)
        return code, self.decoder(code)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Return only the encoder output for ``x``."""
        return self.encoder(x)
|
|
|
|
| |
|
|
class CrossSpeciesAligner(nn.Module):
    """
    Aligns latent representations across species using ortholog-guided
    contrastive learning.

    Each registered species gets its own projection head into a shared
    space. :meth:`alignment_loss` pulls projected rows marked as orthologous
    together and pushes all other cross-species rows apart (symmetric
    InfoNCE-style objective on cosine similarity).
    """

    def __init__(
        self,
        latent_dim: int = 128,
        temperature: float = 0.07,
    ):
        """
        Args:
            latent_dim: input and output width of every projection head.
            temperature: divisor applied to cosine similarities before the
                softmax; smaller values sharpen the distribution.
        """
        super().__init__()

        # One projection head per species, created by add_species().
        self.projectors = nn.ModuleDict()
        self.latent_dim = latent_dim
        self.temperature = temperature

    def add_species(self, species_name: str):
        """Register a species with its projection head (no-op if present)."""
        if species_name not in self.projectors:
            self.projectors[species_name] = nn.Sequential(
                nn.Linear(self.latent_dim, self.latent_dim),
                nn.SELU(),
                nn.Linear(self.latent_dim, self.latent_dim),
            )

    def project(
        self, species_name: str, latent: torch.Tensor
    ) -> torch.Tensor:
        """
        Project species-specific latent to shared alignment space.

        Raises:
            KeyError: if ``species_name`` was never passed to ``add_species``.
        """
        return self.projectors[species_name](latent)

    def alignment_loss(
        self,
        latent_a: torch.Tensor,
        species_a: str,
        latent_b: torch.Tensor,
        species_b: str,
        ortholog_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Symmetric contrastive alignment loss over orthologous pairs.

        For every positive pair ``(i, j)`` in the mask the loss is the
        average of two cross-entropy terms: row ``i`` of A classified against
        all rows of B, and row ``j`` of B classified against all rows of A.
        The result is the mean over positive pairs, so one-to-many and
        many-to-many ortholog relationships are supported.

        Args:
            latent_a: (N_a, D) representations from species A
            species_a: name of species A
            latent_b: (N_b, D) representations from species B
            species_b: name of species B
            ortholog_mask: (N_a, N_b) non-zero where (i, j) are orthologs

        Returns:
            Scalar loss tensor. When the mask contains no positive pair the
            result is a zero that is still attached to the autograd graph.

        Raises:
            ValueError: if ``ortholog_mask`` is not of shape (N_a, N_b).
        """
        # Cosine similarity between every A row and every B row.
        z_a = F.normalize(self.project(species_a, latent_a), dim=-1)
        z_b = F.normalize(self.project(species_b, latent_b), dim=-1)
        sim = torch.matmul(z_a, z_b.T) / self.temperature

        if tuple(ortholog_mask.shape) != tuple(sim.shape):
            raise ValueError(
                f"ortholog_mask has shape {tuple(ortholog_mask.shape)}, "
                f"expected {tuple(sim.shape)}"
            )

        # Coordinates of the positive pairs. Indexing by coordinates (not by
        # "rows that have any positive") keeps the positive term and the
        # log-partition term the same length even when a row or column has
        # several orthologs.
        rows, cols = ortholog_mask.to(sim.device).bool().nonzero(as_tuple=True)
        if rows.numel() == 0:
            # Nothing to align: mean over an empty set would be NaN.
            return sim.sum() * 0.0

        pos_sim = sim[rows, cols]
        loss_a_to_b = torch.logsumexp(sim, dim=1)[rows] - pos_sim
        loss_b_to_a = torch.logsumexp(sim, dim=0)[cols] - pos_sim

        return (loss_a_to_b.mean() + loss_b_to_a.mean()) / 2
|
|
|
|
| |
|
|
class PATH_AE(nn.Module):
    """
    Phylogenetic Autoencoder for Transfer Heterogeneity.

    Cross-species omics transfer learning framework combining:
    1. One ``SpeciesAutoencoder`` per species for feature compression
    2. A ``CrossSpeciesAligner`` with one projection head per species
    3. A linear "conservation" head applied to the per-species latent

    NOTE(review): no phylogenetic distance weighting is applied by the code
    in this class; alignment terms for all species pairs are weighted equally.
    """

    def __init__(
        self,
        species_dims: Dict[str, int],
        latent_dim: int = 128,
        alignment_strength: float = 0.1,
        reconstruction_strength: float = 1.0,
    ):
        """
        Args:
            species_dims: {"human": 20000, "mouse": 20000, "celegans": 18000}
            latent_dim: dimension of shared latent space
            alignment_strength: weight for cross-species alignment loss
            reconstruction_strength: weight for autoencoder reconstruction loss
        """
        super().__init__()

        # Per-species encoder/decoder pairs.
        self.autoencoders = nn.ModuleDict({
            species: SpeciesAutoencoder(dim, latent_dim)
            for species, dim in species_dims.items()
        })

        # Shared-space projection heads, one per species.
        self.aligner = CrossSpeciesAligner(latent_dim)
        for species in species_dims:
            self.aligner.add_species(species)

        # Single linear head shared by all species.
        self.conservation_classifier = nn.Linear(latent_dim, 1)

        self.alignment_strength = alignment_strength
        self.reconstruction_strength = reconstruction_strength
        # Copy so later mutation of the caller's dict cannot desynchronise
        # this record from the modules built above.
        self.species_dims = dict(species_dims)

    def encode(
        self, species: str, x: torch.Tensor
    ) -> torch.Tensor:
        """Encode species-specific features to latent space."""
        return self.autoencoders[species].encode(x)

    def encode_all(
        self, species_data: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """Encode all species' data to latent space."""
        return {
            species: self.encode(species, x)
            for species, x in species_data.items()
        }

    def compute_conservation_score(
        self, species: str, x: torch.Tensor
    ) -> torch.Tensor:
        """
        Return the sigmoid of the conservation head applied to the latent.

        Intended reading: high score = conserved across species, low =
        species-specific. Nothing in this class trains the head, so the
        score is only meaningful after external supervision.
        """
        latent = self.encode(species, x)
        return torch.sigmoid(self.conservation_classifier(latent))

    def reconstruction_loss(
        self,
        species: str,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """MSE reconstruction loss for autoencoder."""
        _, reconstructed = self.autoencoders[species](x)
        return F.mse_loss(reconstructed, x)

    def transfer_features(
        self,
        source_species: str,
        source_data: torch.Tensor,
        target_species: str,
    ) -> torch.Tensor:
        """
        Map source-species data into the shared alignment space.

        Encodes ``source_data`` with the source autoencoder and applies the
        source species' projection head. ``target_species`` is accepted for
        API symmetry but is not used: the returned tensor lives in the
        shared space, not in a target-specific one.
        """
        source_latent = self.encode(source_species, source_data)
        return self.aligner.project(source_species, source_latent)

    def _compute_losses(
        self,
        species_data: Dict[str, torch.Tensor],
        latents: Dict[str, torch.Tensor],
        reconstructed: Dict[str, torch.Tensor],
        ortholog_masks: Dict[Tuple[str, str], torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        """Combine per-species MSE and per-pair alignment terms into a total."""
        zero = self.conservation_classifier.weight.new_zeros(())

        recon_terms = [
            F.mse_loss(reconstructed[species], x)
            for species, x in species_data.items()
        ]
        reconstruction = (
            sum(recon_terms) / len(recon_terms) if recon_terms else zero
        )

        align_terms = []
        for (species_a, species_b), mask in ortholog_masks.items():
            for name in (species_a, species_b):
                if name not in latents:
                    raise KeyError(
                        f"ortholog mask refers to species {name!r}, "
                        "which is not present in species_data"
                    )
            align_terms.append(
                self.aligner.alignment_loss(
                    latents[species_a], species_a,
                    latents[species_b], species_b,
                    mask,
                )
            )
        alignment = sum(align_terms) / len(align_terms) if align_terms else zero

        total = (
            self.reconstruction_strength * reconstruction
            + self.alignment_strength * alignment
        )
        return {
            "reconstruction": reconstruction,
            "alignment": alignment,
            "total": total,
        }

    def forward(
        self,
        species_data: Dict[str, torch.Tensor],
        ortholog_masks: Dict[Tuple[str, str], torch.Tensor] = None,
    ) -> Dict[str, Dict[str, torch.Tensor]]:
        """
        Full forward pass.

        Args:
            species_data: {"human": tensor, "mouse": tensor, ...}
            ortholog_masks: {(species_a, species_b): mask_tensor, ...}; each
                mask has one row per row of species_a's batch and one column
                per row of species_b's batch. Optional.

        Returns dict with:
            latents: per-species latent representations
            aligned: per-species aligned representations
            reconstructed: per-species reconstructions
            conservation_scores: per-species conservation scores
            losses: only when ``ortholog_masks`` is not None. Contains
                ``reconstruction`` (mean MSE over species), ``alignment``
                (mean alignment loss over mask entries) and ``total``
                (``reconstruction_strength * reconstruction +
                alignment_strength * alignment``).

        Raises:
            KeyError: if a species has no autoencoder, or a mask names a
                species absent from ``species_data``.
        """
        latents = {}
        aligned = {}
        reconstructed = {}
        conservation_scores = {}

        for species, x in species_data.items():
            latent, recon = self.autoencoders[species](x)
            latents[species] = latent
            reconstructed[species] = recon
            aligned[species] = self.aligner.project(species, latent)
            conservation_scores[species] = torch.sigmoid(
                self.conservation_classifier(latent)
            )

        outputs = {
            "latents": latents,
            "aligned": aligned,
            "reconstructed": reconstructed,
            "conservation_scores": conservation_scores,
        }
        # The original four keys are always returned unchanged; losses are
        # added only on request so existing callers see the same dict.
        if ortholog_masks is not None:
            outputs["losses"] = self._compute_losses(
                species_data, latents, reconstructed, ortholog_masks
            )
        return outputs
|
|
|
|
| |
| |
|
|
def create_human_mouse_ortholog_map() -> OrthologMapper:
    """
    Create ortholog map between human and mouse.

    Currently a placeholder: the returned mapper is empty and must be
    populated by the caller.
    In production: query Ensembl BioMart REST API.
    """
    return OrthologMapper()
|
|
|
|
def create_default_path_ae() -> PATH_AE:
    """
    Create a default PATH-AE model for human, mouse, and C. elegans.

    Every species is given an input width of 20,000 features (approximate
    protein-coding gene count) and the latent space has 128 dimensions.
    """
    approx_gene_count = 20000
    dims = {
        name: approx_gene_count
        for name in ("human", "mouse", "celegans")
    }
    return PATH_AE(species_dims=dims, latent_dim=128)
|
|