""" Perturbation Encoder for MuLGIT-Perturb. Encodes drug (SMILES) and genetic perturbations into a unified embedding. Drug encoding uses ChemBERTa/MolFormer when available and falls back to RDKit Morgan fingerprints. The encoder always returns the configured `embed_dim`, so fallback fingerprints cannot break downstream projection layers. """ import torch import torch.nn as nn import torch.nn.functional as F from typing import Optional, List import warnings class DrugEncoder(nn.Module): """SMILES -> fixed-size molecular embedding.""" def __init__( self, transformer_model: str = "seyonec/ChemBERTa-zinc-base-v1", embed_dim: int = 768, use_morgan_fallback: bool = True, max_smiles_len: int = 512, ): super().__init__() self.embed_dim = embed_dim self.use_morgan_fallback = use_morgan_fallback self.max_smiles_len = max_smiles_len self.transformer_model = transformer_model self._transformer = None self._tokenizer = None self._tried_loading = False def _match_dim(self, x: torch.Tensor) -> torch.Tensor: """Pad/truncate any encoder output to self.embed_dim.""" if x.shape[-1] == self.embed_dim: return x if x.shape[-1] < self.embed_dim: return F.pad(x, (0, self.embed_dim - x.shape[-1])) return x[:, : self.embed_dim] def _load_transformer(self): if self._tried_loading: return self._tried_loading = True try: from transformers import AutoModel, AutoTokenizer self._tokenizer = AutoTokenizer.from_pretrained(self.transformer_model) self._transformer = AutoModel.from_pretrained(self.transformer_model) for p in self._transformer.parameters(): p.requires_grad = False self._transformer.eval() except Exception as e: warnings.warn( f"Failed to load {self.transformer_model}: {e}. Falling back to Morgan fingerprints." ) self._transformer = None self._tokenizer = None def encode_morgan(self, smiles_list: List[str], device: torch.device) -> torch.Tensor: try: from rdkit import Chem from rdkit.Chem import AllChem except ImportError as e: raise ImportError("RDKit required for Morgan fingerprint fallback. Install with: pip install rdkit") from e fps = [] for smi in smiles_list: try: mol = Chem.MolFromSmiles(smi or "") if mol is not None: fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=2048) fp_array = torch.tensor(list(fp), dtype=torch.float32) else: fp_array = torch.zeros(2048, dtype=torch.float32) except Exception: fp_array = torch.zeros(2048, dtype=torch.float32) fps.append(fp_array) return self._match_dim(torch.stack(fps).to(device)) def encode_transformer(self, smiles_list: List[str], device: torch.device) -> torch.Tensor: self._load_transformer() if self._transformer is None or self._tokenizer is None: if self.use_morgan_fallback: return self.encode_morgan(smiles_list, device) raise RuntimeError("No drug encoder available") tokens = self._tokenizer( smiles_list, padding=True, truncation=True, max_length=self.max_smiles_len, return_tensors="pt", ).to(device) with torch.no_grad(): outputs = self._transformer(**tokens) if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None: embeddings = outputs.pooler_output else: embeddings = outputs.last_hidden_state.mean(dim=1) return self._match_dim(embeddings.float()) def forward(self, smiles_list: List[str], device: torch.device = None) -> torch.Tensor: if device is None: # DrugEncoder has no parameters; default to CPU unless caller passes device. device = torch.device("cpu") try: return self.encode_transformer(smiles_list, device) except Exception as e: if self.use_morgan_fallback: warnings.warn(f"Transformer encoding failed ({e}); using Morgan fingerprints") return self.encode_morgan(smiles_list, device) raise class GeneticPerturbationEncoder(nn.Module): """Gene perturbation encoder: (gene_id, perturbation_type) -> embedding.""" def __init__( self, gene_embed_dim: int = 256, pert_type_dim: int = 5, output_dim: int = 768, n_genes: int = 20000, ): super().__init__() self.gene_embedding = nn.Embedding(n_genes, gene_embed_dim, padding_idx=0) self.pert_type_embedding = nn.Embedding(pert_type_dim, 32) self.projector = nn.Sequential( nn.Linear(gene_embed_dim + 32, 512), nn.SELU(), nn.Linear(512, output_dim), ) self.pert_type_map = { "crispr_ko": 0, "ko": 0, "crispri": 1, "kd": 1, "knockdown": 1, "shrna": 2, "oe": 3, "overexpression": 3, "crispra": 4, "activation": 4, "unknown": 0, } def load_geneformer_embeddings(self, geneformer_model: str = "ctheodoris/Geneformer"): warnings.warn( "Geneformer embedding extraction is not automated in this implementation; " "using learned gene embeddings. Pre-extracted Geneformer embeddings can be loaded manually." ) def forward(self, gene_ids: torch.LongTensor, pert_types: List[str]) -> torch.Tensor: device = gene_ids.device gene_emb = self.gene_embedding(gene_ids) pert_type_idxs = torch.tensor( [self.pert_type_map.get(str(pt).lower(), 0) for pt in pert_types], dtype=torch.long, device=device, ) pert_type_emb = self.pert_type_embedding(pert_type_idxs) return self.projector(torch.cat([gene_emb, pert_type_emb], dim=-1)) class PerturbationEncoder(nn.Module): """Unified perturbation encoder for drugs and genetic perturbations.""" def __init__( self, drug_encoder_model: str = "seyonec/ChemBERTa-zinc-base-v1", drug_embed_dim: int = 768, gene_embed_dim: int = 256, output_dim: int = 768, use_morgan_fallback: bool = True, n_genes: int = 20000, ): super().__init__() self.output_dim = output_dim self.drug_encoder = DrugEncoder( transformer_model=drug_encoder_model, embed_dim=drug_embed_dim, use_morgan_fallback=use_morgan_fallback, ) self.drug_proj = nn.Sequential(nn.Linear(drug_embed_dim, output_dim), nn.SELU()) if drug_embed_dim != output_dim else nn.Identity() self.gene_encoder = GeneticPerturbationEncoder(gene_embed_dim=gene_embed_dim, output_dim=output_dim, n_genes=n_genes) self.alpha_logit = nn.Parameter(torch.tensor(0.0)) self.output_proj = nn.Sequential(nn.Linear(output_dim, output_dim), nn.SELU(), nn.Linear(output_dim, output_dim)) self.dropout = nn.AlphaDropout(0.1) def forward( self, smiles_list: Optional[List[str]] = None, gene_ids: Optional[torch.LongTensor] = None, pert_types: Optional[List[str]] = None, ) -> torch.Tensor: if gene_ids is not None: batch_size = gene_ids.shape[0] device = gene_ids.device elif smiles_list is not None: batch_size = len(smiles_list) device = self.alpha_logit.device else: raise ValueError("Either smiles_list or gene_ids must be provided") if smiles_list is not None and len(smiles_list) > 0: z_drug = self.drug_proj(self.drug_encoder(smiles_list, device)) else: z_drug = torch.zeros(batch_size, self.output_dim, device=device) if gene_ids is not None and pert_types is not None: z_gene = self.gene_encoder(gene_ids, pert_types) else: z_gene = torch.zeros(batch_size, self.output_dim, device=device) alpha = torch.sigmoid(self.alpha_logit) z_pert = alpha * z_drug + (1.0 - alpha) * z_gene return self.output_proj(self.dropout(F.selu(z_pert))) def create_perturbation_encoder(config) -> PerturbationEncoder: return PerturbationEncoder( drug_encoder_model=config.drug_encoder_model, drug_embed_dim=config.drug_embed_dim, gene_embed_dim=config.gene_embed_dim, output_dim=config.pert_output_dim, use_morgan_fallback=config.use_morgan_fallback, )