| """ |
| Perturbation Encoder for MuLGIT-Perturb. |
| |
| Encodes drug (SMILES) and genetic perturbations into a unified embedding. |
| Drug encoding uses ChemBERTa/MolFormer when available and falls back to |
| RDKit Morgan fingerprints. The encoder always returns the configured |
| `embed_dim`, so fallback fingerprints cannot break downstream projection |
| layers. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import Optional, List |
| import warnings |
|
|
|
|
class DrugEncoder(nn.Module):
    """SMILES -> fixed-size molecular embedding.

    The transformer named by ``transformer_model`` is loaded lazily on first
    use (a single attempt) and kept frozen. If it cannot be loaded, or if
    encoding with it raises, RDKit Morgan fingerprints are used instead when
    ``use_morgan_fallback`` is True. Every path returns a float tensor of
    shape ``(len(smiles_list), embed_dim)``.

    Args:
        transformer_model: Hugging Face model id passed to
            ``AutoTokenizer.from_pretrained`` / ``AutoModel.from_pretrained``.
        embed_dim: Width of the returned embedding. Encoder outputs are
            zero-padded or truncated to this size.
        use_morgan_fallback: Fall back to Morgan fingerprints when the
            transformer is unavailable or fails.
        max_smiles_len: Tokenizer truncation length, in tokens.
    """

    # Morgan fingerprint settings. NOTE: when embed_dim < _MORGAN_N_BITS
    # (e.g. the default 768), ``_match_dim`` truncates the fingerprint, so
    # only the first ``embed_dim`` bits reach downstream layers.
    _MORGAN_RADIUS = 2
    _MORGAN_N_BITS = 2048

    def __init__(
        self,
        transformer_model: str = "seyonec/ChemBERTa-zinc-base-v1",
        embed_dim: int = 768,
        use_morgan_fallback: bool = True,
        max_smiles_len: int = 512,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.use_morgan_fallback = use_morgan_fallback
        self.max_smiles_len = max_smiles_len
        self.transformer_model = transformer_model
        # Populated by _load_transformer on first use, so construction does no I/O.
        self._transformer = None
        self._tokenizer = None
        self._tried_loading = False

    def _match_dim(self, x: torch.Tensor) -> torch.Tensor:
        """Zero-pad or truncate the last dimension of ``x`` to ``self.embed_dim``."""
        width = x.shape[-1]
        if width == self.embed_dim:
            return x
        if width < self.embed_dim:
            return F.pad(x, (0, self.embed_dim - width))
        # Slice the last axis, consistent with the shape check and F.pad above.
        return x[..., : self.embed_dim]

    def _load_transformer(self):
        """Try once to load the tokenizer and a frozen transformer.

        On any failure a warning is emitted and both attributes are reset to
        None, so callers see "no transformer" and can fall back.
        """
        if self._tried_loading:
            return
        self._tried_loading = True
        try:
            from transformers import AutoModel, AutoTokenizer

            self._tokenizer = AutoTokenizer.from_pretrained(self.transformer_model)
            self._transformer = AutoModel.from_pretrained(self.transformer_model)
            for p in self._transformer.parameters():
                p.requires_grad = False
            self._transformer.eval()
        except Exception as e:
            warnings.warn(
                f"Failed to load {self.transformer_model}: {e}. Falling back to Morgan fingerprints.",
                stacklevel=2,
            )
            self._transformer = None
            self._tokenizer = None

    def encode_morgan(self, smiles_list: List[str], device: torch.device) -> torch.Tensor:
        """Encode SMILES as Morgan bit-vector fingerprints, sized to ``embed_dim``.

        SMILES that RDKit cannot parse (or that raise during fingerprinting)
        get an all-zero row rather than an error.

        Raises:
            ImportError: if RDKit is not installed and the list is non-empty.
        """
        if len(smiles_list) == 0:
            # Nothing to fingerprint; avoids torch.stack([]) and the RDKit import.
            return torch.zeros(0, self.embed_dim, device=device)

        try:
            from rdkit import Chem
            from rdkit.Chem import AllChem
        except ImportError as e:
            raise ImportError("RDKit required for Morgan fingerprint fallback. Install with: pip install rdkit") from e

        n_bits = self._MORGAN_N_BITS
        fps = []
        for smi in smiles_list:
            try:
                mol = Chem.MolFromSmiles(smi or "")
                if mol is not None:
                    fp = AllChem.GetMorganFingerprintAsBitVect(
                        mol, radius=self._MORGAN_RADIUS, nBits=n_bits
                    )
                    fp_array = torch.tensor(list(fp), dtype=torch.float32)
                else:
                    fp_array = torch.zeros(n_bits, dtype=torch.float32)
            except Exception:
                # Best effort: malformed input becomes a zero fingerprint.
                fp_array = torch.zeros(n_bits, dtype=torch.float32)
            fps.append(fp_array)
        return self._match_dim(torch.stack(fps).to(device))

    @staticmethod
    def _masked_mean(hidden: torch.Tensor, attention_mask: Optional[torch.Tensor]) -> torch.Tensor:
        """Average token states over axis 1, ignoring positions where the mask is 0."""
        if attention_mask is None:
            return hidden.mean(dim=1)
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
        # clamp guards against an all-padding row dividing by zero.
        return (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)

    def encode_transformer(self, smiles_list: List[str], device: torch.device) -> torch.Tensor:
        """Encode SMILES with the frozen transformer (Morgan fallback if unavailable).

        Raises:
            RuntimeError: if no transformer is available and the Morgan
                fallback is disabled.
        """
        self._load_transformer()
        if self._transformer is None or self._tokenizer is None:
            if self.use_morgan_fallback:
                return self.encode_morgan(smiles_list, device)
            raise RuntimeError("No drug encoder available")

        # Once loaded, the model is a registered submodule, so a parent
        # .train()/.to() call can change its mode or device. It must sit on the
        # same device as the tokens, and stay in eval mode so dropout does not
        # make the frozen embeddings stochastic.
        self._transformer.to(device)
        self._transformer.eval()

        tokens = self._tokenizer(
            smiles_list,
            padding=True,
            truncation=True,
            max_length=self.max_smiles_len,
            return_tensors="pt",
        ).to(device)
        with torch.no_grad():
            outputs = self._transformer(**tokens)

        # NOTE(review): for checkpoints trained only with MLM, Hugging Face may
        # initialise the pooler freshly on load; confirm pooler_output is
        # meaningful for the configured model.
        pooled = getattr(outputs, "pooler_output", None)
        if pooled is not None:
            embeddings = pooled
        else:
            # Padding tokens must not contribute, or a molecule's embedding
            # would depend on the other SMILES in the batch.
            embeddings = self._masked_mean(outputs.last_hidden_state, tokens.get("attention_mask"))
        return self._match_dim(embeddings.float())

    def forward(self, smiles_list: List[str], device: torch.device = None) -> torch.Tensor:
        """Return a ``(len(smiles_list), embed_dim)`` embedding on ``device`` (default CPU).

        If the transformer path raises and the fallback is enabled, a warning
        is emitted and Morgan fingerprints are returned instead.
        """
        if device is None:
            device = torch.device("cpu")
        if len(smiles_list) == 0:
            return torch.zeros(0, self.embed_dim, device=device)
        try:
            return self.encode_transformer(smiles_list, device)
        except Exception as e:
            if self.use_morgan_fallback:
                warnings.warn(f"Transformer encoding failed ({e}); using Morgan fingerprints", stacklevel=2)
                return self.encode_morgan(smiles_list, device)
            raise
|
|
|
|
class GeneticPerturbationEncoder(nn.Module):
    """Gene perturbation encoder: (gene_id, perturbation_type) -> embedding.

    Each sample's learned gene embedding is concatenated with a 32-dim
    perturbation-type embedding and projected to ``output_dim`` by a
    two-layer SELU MLP.

    Args:
        gene_embed_dim: Width of the learned gene embedding.
        pert_type_dim: Number of perturbation-type slots. The built-in label
            map uses indices 0..4, so this must be at least 5.
        output_dim: Width of the returned embedding.
        n_genes: Gene vocabulary size. Gene id 0 is the padding index, so its
            embedding is fixed at zero.
    """

    def __init__(
        self,
        gene_embed_dim: int = 256,
        pert_type_dim: int = 5,
        output_dim: int = 768,
        n_genes: int = 20000,
    ):
        super().__init__()
        self.gene_embedding = nn.Embedding(n_genes, gene_embed_dim, padding_idx=0)
        self.pert_type_embedding = nn.Embedding(pert_type_dim, 32)
        self.projector = nn.Sequential(
            nn.Linear(gene_embed_dim + 32, 512),
            nn.SELU(),
            nn.Linear(512, output_dim),
        )
        # Label -> perturbation-type index. Aliases share an index. "unknown"
        # and any unmapped label fall back to index 0, the same slot as CRISPR
        # knockout.
        self.pert_type_map = {
            "crispr_ko": 0,
            "ko": 0,
            "crispri": 1,
            "kd": 1,
            "knockdown": 1,
            "shrna": 2,
            "oe": 3,
            "overexpression": 3,
            "crispra": 4,
            "activation": 4,
            "unknown": 0,
        }

    def load_geneformer_embeddings(self, geneformer_model: str = "ctheodoris/Geneformer"):
        """Placeholder: only warns; the learned gene embeddings stay in use."""
        warnings.warn(
            "Geneformer embedding extraction is not automated in this implementation; "
            "using learned gene embeddings. Pre-extracted Geneformer embeddings can be loaded manually."
        )

    def forward(self, gene_ids: torch.LongTensor, pert_types: List[str]) -> torch.Tensor:
        """Embed a batch of genetic perturbations.

        Args:
            gene_ids: 1-D tensor of gene indices, one per sample.
            pert_types: One label per sample (case- and surrounding-whitespace-
                insensitive), or a single string applied to every sample.

        Returns:
            Tensor of shape ``(len(gene_ids), output_dim)``.

        Raises:
            ValueError: if the number of labels differs from the batch size.
        """
        device = gene_ids.device
        batch_size = gene_ids.shape[0]
        if isinstance(pert_types, str):
            # A bare string would otherwise be iterated character by character.
            pert_types = [pert_types] * batch_size
        if len(pert_types) != batch_size:
            raise ValueError(
                f"Got {len(pert_types)} perturbation types for {batch_size} gene ids"
            )

        gene_emb = self.gene_embedding(gene_ids)
        pert_type_idxs = torch.tensor(
            [self.pert_type_map.get(str(pt).strip().lower(), 0) for pt in pert_types],
            dtype=torch.long,
            device=device,
        )
        pert_type_emb = self.pert_type_embedding(pert_type_idxs)
        return self.projector(torch.cat([gene_emb, pert_type_emb], dim=-1))
|
|
|
|
class PerturbationEncoder(nn.Module):
    """Unified perturbation encoder for drugs and genetic perturbations.

    The drug embedding (SMILES -> DrugEncoder -> optional projection) and the
    genetic embedding (gene id + perturbation type) are mixed with one
    learned weight ``alpha = sigmoid(alpha_logit)``:
    ``z = alpha * z_drug + (1 - alpha) * z_gene``. A missing modality
    contributes a zero vector. The mix then goes through SELU, AlphaDropout
    and a two-layer projection, yielding ``(batch, output_dim)``.
    """

    def __init__(
        self,
        drug_encoder_model: str = "seyonec/ChemBERTa-zinc-base-v1",
        drug_embed_dim: int = 768,
        gene_embed_dim: int = 256,
        output_dim: int = 768,
        use_morgan_fallback: bool = True,
        n_genes: int = 20000,
    ):
        super().__init__()
        self.output_dim = output_dim
        self.drug_encoder = DrugEncoder(
            transformer_model=drug_encoder_model,
            embed_dim=drug_embed_dim,
            use_morgan_fallback=use_morgan_fallback,
        )
        # Projection is skipped (Identity, no SELU) when the widths already match.
        self.drug_proj = nn.Sequential(nn.Linear(drug_embed_dim, output_dim), nn.SELU()) if drug_embed_dim != output_dim else nn.Identity()
        self.gene_encoder = GeneticPerturbationEncoder(gene_embed_dim=gene_embed_dim, output_dim=output_dim, n_genes=n_genes)
        # sigmoid(0) = 0.5: both modalities start with equal weight.
        self.alpha_logit = nn.Parameter(torch.tensor(0.0))
        self.output_proj = nn.Sequential(nn.Linear(output_dim, output_dim), nn.SELU(), nn.Linear(output_dim, output_dim))
        self.dropout = nn.AlphaDropout(0.1)

    def forward(
        self,
        smiles_list: Optional[List[str]] = None,
        gene_ids: Optional[torch.LongTensor] = None,
        pert_types: Optional[List[str]] = None,
    ) -> torch.Tensor:
        """Encode a batch of perturbations.

        The batch size and device come from ``gene_ids`` when given, otherwise
        from ``smiles_list`` and the module's parameters. The genetic branch
        runs only when both ``gene_ids`` and ``pert_types`` are given;
        otherwise it contributes zeros.

        Raises:
            ValueError: if neither input is given, or if a non-empty
                ``smiles_list`` does not match ``gene_ids``' batch size (a
                single SMILES would otherwise broadcast silently across every
                genetic perturbation).
        """
        if gene_ids is not None:
            batch_size = gene_ids.shape[0]
            device = gene_ids.device
        elif smiles_list is not None:
            batch_size = len(smiles_list)
            device = self.alpha_logit.device
        else:
            raise ValueError("Either smiles_list or gene_ids must be provided")

        has_drug = smiles_list is not None and len(smiles_list) > 0
        if has_drug and len(smiles_list) != batch_size:
            raise ValueError(
                f"smiles_list has {len(smiles_list)} entries but gene_ids has batch size {batch_size}"
            )

        if has_drug:
            z_drug = self.drug_proj(self.drug_encoder(smiles_list, device))
        else:
            z_drug = torch.zeros(batch_size, self.output_dim, device=device)

        if gene_ids is not None and pert_types is not None:
            z_gene = self.gene_encoder(gene_ids, pert_types)
        else:
            z_gene = torch.zeros(batch_size, self.output_dim, device=device)

        alpha = torch.sigmoid(self.alpha_logit)
        z_pert = alpha * z_drug + (1.0 - alpha) * z_gene
        return self.output_proj(self.dropout(F.selu(z_pert)))
|
|
|
|
def create_perturbation_encoder(config) -> PerturbationEncoder:
    """Build a PerturbationEncoder from a config object.

    Required attributes: ``drug_encoder_model``, ``drug_embed_dim``,
    ``gene_embed_dim``, ``pert_output_dim`` and ``use_morgan_fallback``.
    ``n_genes`` is optional and defaults to 20000, matching
    PerturbationEncoder. Without it, a larger gene vocabulary would hit an
    out-of-range index in the gene embedding.
    """
    return PerturbationEncoder(
        drug_encoder_model=config.drug_encoder_model,
        drug_embed_dim=config.drug_embed_dim,
        gene_embed_dim=config.gene_embed_dim,
        output_dim=config.pert_output_dim,
        use_morgan_fallback=config.use_morgan_fallback,
        n_genes=getattr(config, "n_genes", 20000),
    )
|
|