# MuLGIT / mulgit/perturb/encoder.py
# Uploaded by vedatonuryilmaz ("Upload mulgit/perturb/encoder.py"), commit a05da2c.
"""
Perturbation Encoder for MuLGIT-Perturb.
Encodes drug (SMILES) and genetic perturbations into a unified embedding.
Drug encoding uses ChemBERTa/MolFormer when available and falls back to
RDKit Morgan fingerprints. The encoder always returns the configured
`embed_dim`, so fallback fingerprints cannot break downstream projection
layers.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, List
import warnings
class DrugEncoder(nn.Module):
"""SMILES -> fixed-size molecular embedding."""
def __init__(
self,
transformer_model: str = "seyonec/ChemBERTa-zinc-base-v1",
embed_dim: int = 768,
use_morgan_fallback: bool = True,
max_smiles_len: int = 512,
):
super().__init__()
self.embed_dim = embed_dim
self.use_morgan_fallback = use_morgan_fallback
self.max_smiles_len = max_smiles_len
self.transformer_model = transformer_model
self._transformer = None
self._tokenizer = None
self._tried_loading = False
def _match_dim(self, x: torch.Tensor) -> torch.Tensor:
"""Pad/truncate any encoder output to self.embed_dim."""
if x.shape[-1] == self.embed_dim:
return x
if x.shape[-1] < self.embed_dim:
return F.pad(x, (0, self.embed_dim - x.shape[-1]))
return x[:, : self.embed_dim]
def _load_transformer(self):
if self._tried_loading:
return
self._tried_loading = True
try:
from transformers import AutoModel, AutoTokenizer
self._tokenizer = AutoTokenizer.from_pretrained(self.transformer_model)
self._transformer = AutoModel.from_pretrained(self.transformer_model)
for p in self._transformer.parameters():
p.requires_grad = False
self._transformer.eval()
except Exception as e:
warnings.warn(
f"Failed to load {self.transformer_model}: {e}. Falling back to Morgan fingerprints."
)
self._transformer = None
self._tokenizer = None
def encode_morgan(self, smiles_list: List[str], device: torch.device) -> torch.Tensor:
try:
from rdkit import Chem
from rdkit.Chem import AllChem
except ImportError as e:
raise ImportError("RDKit required for Morgan fingerprint fallback. Install with: pip install rdkit") from e
fps = []
for smi in smiles_list:
try:
mol = Chem.MolFromSmiles(smi or "")
if mol is not None:
fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=2048)
fp_array = torch.tensor(list(fp), dtype=torch.float32)
else:
fp_array = torch.zeros(2048, dtype=torch.float32)
except Exception:
fp_array = torch.zeros(2048, dtype=torch.float32)
fps.append(fp_array)
return self._match_dim(torch.stack(fps).to(device))
def encode_transformer(self, smiles_list: List[str], device: torch.device) -> torch.Tensor:
self._load_transformer()
if self._transformer is None or self._tokenizer is None:
if self.use_morgan_fallback:
return self.encode_morgan(smiles_list, device)
raise RuntimeError("No drug encoder available")
tokens = self._tokenizer(
smiles_list,
padding=True,
truncation=True,
max_length=self.max_smiles_len,
return_tensors="pt",
).to(device)
with torch.no_grad():
outputs = self._transformer(**tokens)
if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
embeddings = outputs.pooler_output
else:
embeddings = outputs.last_hidden_state.mean(dim=1)
return self._match_dim(embeddings.float())
def forward(self, smiles_list: List[str], device: torch.device = None) -> torch.Tensor:
if device is None:
# DrugEncoder has no parameters; default to CPU unless caller passes device.
device = torch.device("cpu")
try:
return self.encode_transformer(smiles_list, device)
except Exception as e:
if self.use_morgan_fallback:
warnings.warn(f"Transformer encoding failed ({e}); using Morgan fingerprints")
return self.encode_morgan(smiles_list, device)
raise
class GeneticPerturbationEncoder(nn.Module):
    """Gene perturbation encoder: (gene_id, perturbation_type) -> embedding.

    Each row concatenates a learned gene embedding with a learned 32-d
    perturbation-type embedding. The result is projected through a two-layer
    SELU MLP to ``output_dim``.

    Perturbation-type labels are matched case-insensitively and ignoring
    ``-``, ``_`` and spaces, so "CRISPR-KO", "crispr_ko" and "CRISPR KO" all
    match. Labels absent from ``pert_type_map`` are encoded as index 0 (the
    slot shared by "crispr_ko"/"ko"/"unknown"), and a warning lists them.

    Args:
        gene_embed_dim: Width of the learned gene embedding.
        pert_type_dim: Number of perturbation-type embedding rows. It must
            exceed the largest index used in ``pert_type_map``.
        output_dim: Width of the returned embedding.
        n_genes: Gene vocabulary size. Id 0 is the embedding's padding index.
    """

    def __init__(
        self,
        gene_embed_dim: int = 256,
        pert_type_dim: int = 5,
        output_dim: int = 768,
        n_genes: int = 20000,
    ):
        super().__init__()
        self.gene_embedding = nn.Embedding(n_genes, gene_embed_dim, padding_idx=0)
        self.pert_type_embedding = nn.Embedding(pert_type_dim, 32)
        self.projector = nn.Sequential(
            nn.Linear(gene_embed_dim + 32, 512),
            nn.SELU(),
            nn.Linear(512, output_dim),
        )
        # Label -> perturbation-type index. Several aliases share an index.
        self.pert_type_map = {
            "crispr_ko": 0,
            "ko": 0,
            "crispri": 1,
            "kd": 1,
            "knockdown": 1,
            "shrna": 2,
            "oe": 3,
            "overexpression": 3,
            "crispra": 4,
            "activation": 4,
            "unknown": 0,
        }

    @staticmethod
    def _normalize_pert_type(pert_type) -> str:
        """Lower-case a label and drop '-', '_' and spaces for alias matching."""
        return "".join(ch for ch in str(pert_type).strip().lower() if ch not in "-_ ")

    def load_geneformer_embeddings(self, geneformer_model: str = "ctheodoris/Geneformer"):
        """Placeholder: only warns; gene embeddings remain the learned ones."""
        warnings.warn(
            "Geneformer embedding extraction is not automated in this implementation; "
            "using learned gene embeddings. Pre-extracted Geneformer embeddings can be loaded manually."
        )

    def forward(self, gene_ids: torch.LongTensor, pert_types: List[str]) -> torch.Tensor:
        """Embed each (gene id, perturbation type) pair.

        Args:
            gene_ids: 1-D tensor of gene indices, one per row. It must be 1-D
                to concatenate with the per-row type embeddings.
            pert_types: One label per row of ``gene_ids``. A single string is
                applied to every row.

        Returns:
            Tensor of shape ``(len(gene_ids), output_dim)`` on ``gene_ids.device``.

        Raises:
            ValueError: if the number of labels does not match ``gene_ids``.
        """
        device = gene_ids.device
        n_rows = gene_ids.shape[0]
        if isinstance(pert_types, str):
            # Broadcast instead of iterating the string character by character.
            pert_types = [pert_types] * n_rows
        if len(pert_types) != n_rows:
            raise ValueError(
                f"Got {len(pert_types)} perturbation types for {n_rows} gene ids"
            )
        # Built per call so edits to self.pert_type_map take effect.
        lookup = {self._normalize_pert_type(k): v for k, v in self.pert_type_map.items()}
        idxs = []
        unknown = set()
        for pt in pert_types:
            key = self._normalize_pert_type(pt)
            if key not in lookup:
                unknown.add(str(pt))
            idxs.append(lookup.get(key, 0))
        if unknown:
            warnings.warn(
                f"Unrecognized perturbation type(s) {sorted(unknown)}; "
                "encoding them as index 0 (same slot as 'crispr_ko'/'unknown')."
            )
        gene_emb = self.gene_embedding(gene_ids)
        pert_type_emb = self.pert_type_embedding(
            torch.tensor(idxs, dtype=torch.long, device=device)
        )
        return self.projector(torch.cat([gene_emb, pert_type_emb], dim=-1))
class PerturbationEncoder(nn.Module):
    """Unified perturbation encoder for drugs and genetic perturbations.

    The drug branch (DrugEncoder + optional projection) and the genetic
    branch (GeneticPerturbationEncoder) each produce ``output_dim`` vectors.
    An absent branch contributes zeros. The two are blended as
    ``alpha * drug + (1 - alpha) * gene``, with ``alpha = sigmoid(alpha_logit)``
    learned (0.5 at initialisation). The blend then passes through SELU,
    AlphaDropout and a two-layer MLP.
    """

    def __init__(
        self,
        drug_encoder_model: str = "seyonec/ChemBERTa-zinc-base-v1",
        drug_embed_dim: int = 768,
        gene_embed_dim: int = 256,
        output_dim: int = 768,
        use_morgan_fallback: bool = True,
        n_genes: int = 20000,
    ):
        super().__init__()
        self.output_dim = output_dim
        # Submodules are created in a fixed order; this keeps parameter
        # initialisation under a given seed and the state_dict layout stable.
        self.drug_encoder = DrugEncoder(
            transformer_model=drug_encoder_model,
            embed_dim=drug_embed_dim,
            use_morgan_fallback=use_morgan_fallback,
        )
        if drug_embed_dim == output_dim:
            # Widths already agree: no projection needed.
            self.drug_proj = nn.Identity()
        else:
            self.drug_proj = nn.Sequential(nn.Linear(drug_embed_dim, output_dim), nn.SELU())
        self.gene_encoder = GeneticPerturbationEncoder(
            gene_embed_dim=gene_embed_dim,
            output_dim=output_dim,
            n_genes=n_genes,
        )
        self.alpha_logit = nn.Parameter(torch.tensor(0.0))
        self.output_proj = nn.Sequential(
            nn.Linear(output_dim, output_dim),
            nn.SELU(),
            nn.Linear(output_dim, output_dim),
        )
        self.dropout = nn.AlphaDropout(0.1)

    def forward(
        self,
        smiles_list: Optional[List[str]] = None,
        gene_ids: Optional[torch.LongTensor] = None,
        pert_types: Optional[List[str]] = None,
    ) -> torch.Tensor:
        """Encode a batch of drug and/or genetic perturbations.

        Batch size and device come from ``gene_ids`` when given. Otherwise
        they come from ``smiles_list`` and this module's parameters. The
        genetic branch runs only when both ``gene_ids`` and ``pert_types``
        are supplied; with ``gene_ids`` alone it contributes zeros.

        Raises:
            ValueError: if neither ``smiles_list`` nor ``gene_ids`` is given.
        """
        if gene_ids is None and smiles_list is None:
            raise ValueError("Either smiles_list or gene_ids must be provided")
        if gene_ids is not None:
            n_rows, target_device = gene_ids.shape[0], gene_ids.device
        else:
            n_rows, target_device = len(smiles_list), self.alpha_logit.device

        def _blank() -> torch.Tensor:
            # Stand-in for an absent branch.
            return torch.zeros(n_rows, self.output_dim, device=target_device)

        has_drugs = smiles_list is not None and len(smiles_list) > 0
        drug_part = (
            self.drug_proj(self.drug_encoder(smiles_list, target_device))
            if has_drugs
            else _blank()
        )
        has_genes = gene_ids is not None and pert_types is not None
        gene_part = self.gene_encoder(gene_ids, pert_types) if has_genes else _blank()

        mix = torch.sigmoid(self.alpha_logit)
        fused = mix * drug_part + (1.0 - mix) * gene_part
        return self.output_proj(self.dropout(F.selu(fused)))
def create_perturbation_encoder(config) -> PerturbationEncoder:
    """Build a PerturbationEncoder from a config object.

    Required config attributes: ``drug_encoder_model``, ``drug_embed_dim``,
    ``gene_embed_dim``, ``pert_output_dim`` and ``use_morgan_fallback``.
    The optional ``pert_n_genes`` sets the gene vocabulary size of the
    genetic branch. It defaults to 20000, the constructor's default, so
    configs without it behave exactly as before.
    """
    return PerturbationEncoder(
        drug_encoder_model=config.drug_encoder_model,
        drug_embed_dim=config.drug_embed_dim,
        gene_embed_dim=config.gene_embed_dim,
        output_dim=config.pert_output_dim,
        use_morgan_fallback=config.use_morgan_fallback,
        n_genes=getattr(config, "pert_n_genes", 20000),
    )