vedatonuryilmaz commited on
Commit
a05da2c
·
verified ·
1 Parent(s): c61a578

Upload mulgit/perturb/encoder.py

Browse files
Files changed (1) hide show
  1. mulgit/perturb/encoder.py +84 -208
mulgit/perturb/encoder.py CHANGED
@@ -1,35 +1,22 @@
1
  """
2
  Perturbation Encoder for MuLGIT-Perturb.
3
 
4
- Encodes both drug (SMILES) and genetic (gene ID) perturbations into a
5
- unified embedding space. Uses pretrained transformers as frozen feature
6
- extractors, with a learned fusion gate that determines the relative
7
- contribution of drug vs. genetic perturbation.
8
-
9
- Architecture:
10
- Drug: SMILES → ChemBERTa/MolFormer (frozen) → drug_embed (768-dim)
11
- Genetic: Gene symbol → Geneformer (frozen) → gene_embed (256-dim)
12
- + Perturbation type (KO/KD/OE) → 3-dim one-hot
13
- Fusion: Learned convex combination z_pert = α·z_drug + (1-α)·z_gene
14
  """
15
 
16
  import torch
17
  import torch.nn as nn
18
  import torch.nn.functional as F
19
- from typing import Optional, List, Dict, Tuple
20
  import warnings
21
 
22
 
23
  class DrugEncoder(nn.Module):
24
- """
25
- Encodes drug SMILES strings into molecular embeddings.
26
-
27
- Two modes:
28
- 1. ChemBERTa/MolFormer (primary): pretrained transformer, frozen.
29
- Provides 768-dim continuous embedding capturing molecular properties.
30
- 2. Morgan fingerprints (fallback): 2048-bit ECFP4 fingerprint via RDKit.
31
- Used when SMILES parsing fails or transformer is unavailable.
32
- """
33
 
34
  def __init__(
35
  self,
@@ -43,39 +30,47 @@ class DrugEncoder(nn.Module):
43
  self.use_morgan_fallback = use_morgan_fallback
44
  self.max_smiles_len = max_smiles_len
45
  self.transformer_model = transformer_model
46
- self._transformer = None # lazy load
47
  self._tokenizer = None
48
-
49
- @property
50
- def transformer(self):
51
- """Lazy-load the transformer to avoid loading if not needed."""
52
- if self._transformer is None:
53
- try:
54
- from transformers import AutoModel, AutoTokenizer
55
- self._tokenizer = AutoTokenizer.from_pretrained(self.transformer_model)
56
- self._transformer = AutoModel.from_pretrained(self.transformer_model)
57
- for p in self._transformer.parameters():
58
- p.requires_grad = False
59
- self._transformer.eval()
60
- except Exception as e:
61
- warnings.warn(f"Failed to load {self.transformer_model}: {e}. "
62
- f"Falling back to Morgan fingerprints.")
63
- self._transformer = None
64
- return self._transformer
 
 
 
 
 
 
 
 
 
 
65
 
66
  def encode_morgan(self, smiles_list: List[str], device: torch.device) -> torch.Tensor:
67
- """Encode SMILES using Morgan (ECFP4) fingerprints via RDKit."""
68
  try:
69
  from rdkit import Chem
70
  from rdkit.Chem import AllChem
71
- except ImportError:
72
- raise ImportError("RDKit required for Morgan fingerprint fallback. "
73
- "Install with: pip install rdkit")
74
 
75
  fps = []
76
  for smi in smiles_list:
77
  try:
78
- mol = Chem.MolFromSmiles(smi)
79
  if mol is not None:
80
  fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=2048)
81
  fp_array = torch.tensor(list(fp), dtype=torch.float32)
@@ -84,18 +79,15 @@ class DrugEncoder(nn.Module):
84
  except Exception:
85
  fp_array = torch.zeros(2048, dtype=torch.float32)
86
  fps.append(fp_array)
87
-
88
- fps = torch.stack(fps).to(device) # (B, 2048)
89
- return fps
90
 
91
  def encode_transformer(self, smiles_list: List[str], device: torch.device) -> torch.Tensor:
92
- """Encode SMILES using ChemBERTa/MolFormer."""
93
- if self.transformer is None:
94
  if self.use_morgan_fallback:
95
  return self.encode_morgan(smiles_list, device)
96
  raise RuntimeError("No drug encoder available")
97
 
98
- # Tokenize
99
  tokens = self._tokenizer(
100
  smiles_list,
101
  padding=True,
@@ -103,142 +95,79 @@ class DrugEncoder(nn.Module):
103
  max_length=self.max_smiles_len,
104
  return_tensors="pt",
105
  ).to(device)
106
-
107
  with torch.no_grad():
108
- outputs = self.transformer(**tokens)
109
- # Use CLS token or mean pooling
110
  if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
111
- embeddings = outputs.pooler_output # (B, 768)
112
  else:
113
- embeddings = outputs.last_hidden_state.mean(dim=1) # (B, 768)
114
-
115
- return embeddings
116
 
117
  def forward(self, smiles_list: List[str], device: torch.device = None) -> torch.Tensor:
118
- """
119
- Encode a list of SMILES strings.
120
-
121
- Args:
122
- smiles_list: list of SMILES strings
123
- device: torch device
124
-
125
- Returns:
126
- drug_embeddings: (B, embed_dim)
127
- """
128
  if device is None:
129
- device = next(self.parameters()).device if list(self.parameters()) else torch.device("cpu")
130
-
131
- # Try transformer first
132
  try:
133
  return self.encode_transformer(smiles_list, device)
134
  except Exception as e:
135
  if self.use_morgan_fallback:
136
- warnings.warn(f"ChemBERTa encoding failed ({e}), using Morgan fingerprints")
137
  return self.encode_morgan(smiles_list, device)
138
  raise
139
 
140
 
141
  class GeneticPerturbationEncoder(nn.Module):
142
- """
143
- Encodes genetic perturbations (gene KO/KD/OE) into embeddings.
144
-
145
- Gene identity: gene name → pretrained gene embedding (Geneformer, gene2vec, or learned)
146
- Perturbation type: one-hot encoding (CRISPR_KO=0, CRISPRi=1, shRNA=2, OE=3, CRISPRa=4)
147
- """
148
 
149
  def __init__(
150
  self,
151
  gene_embed_dim: int = 256,
152
- pert_type_dim: int = 5, # KO, CRISPRi, shRNA, OE, CRISPRa
153
  output_dim: int = 768,
154
  n_genes: int = 20000,
155
  ):
156
  super().__init__()
157
- # Gene embedding: learned lookup table (can be initialized from Geneformer)
158
  self.gene_embedding = nn.Embedding(n_genes, gene_embed_dim, padding_idx=0)
159
-
160
- # Perturbation type embedding
161
  self.pert_type_embedding = nn.Embedding(pert_type_dim, 32)
162
-
163
- # Project combined gene + perturbation type to output dimension
164
- combined_dim = gene_embed_dim + 32
165
  self.projector = nn.Sequential(
166
- nn.Linear(combined_dim, 512),
167
  nn.SELU(),
168
  nn.Linear(512, output_dim),
169
  )
170
-
171
- # Map perturbation type string to index
172
  self.pert_type_map = {
173
- "CRISPR_KO": 0, "CRISPR_KO": 0,
174
- "CRISPRi": 1, "KD": 1, "knockdown": 1,
175
- "shRNA": 2,
176
- "OE": 3, "overexpression": 3, "CRISPRa": 3,
177
- "CRISPRa": 4, "activation": 4,
 
 
 
 
 
178
  "unknown": 0,
179
  }
180
 
181
  def load_geneformer_embeddings(self, geneformer_model: str = "ctheodoris/Geneformer"):
182
- """
183
- Initialize gene embeddings from Geneformer's pretrained gene embeddings.
184
- Falls back to random initialization if Geneformer is unavailable.
185
- """
186
- try:
187
- from transformers import AutoModel
188
- model = AutoModel.from_pretrained(geneformer_model)
189
- # Geneformer stores gene embeddings in the embedding layer
190
- # For now, keep learned initialization
191
- warnings.warn(
192
- "Geneformer embedding extraction not yet automated. "
193
- "Using learned embeddings. For best results, pre-extract "
194
- "Geneformer gene embeddings and load them manually."
195
- )
196
- except Exception:
197
- pass
198
-
199
- def forward(
200
- self,
201
- gene_ids: torch.LongTensor,
202
- pert_types: List[str],
203
- ) -> torch.Tensor:
204
- """
205
- Args:
206
- gene_ids: (B,) integer gene indices
207
- pert_types: list of perturbation type strings
208
 
209
- Returns:
210
- gene_pert_embedding: (B, output_dim)
211
- """
212
  device = gene_ids.device
213
-
214
- # Gene embedding
215
- gene_emb = self.gene_embedding(gene_ids) # (B, gene_embed_dim)
216
-
217
- # Perturbation type
218
  pert_type_idxs = torch.tensor(
219
- [self.pert_type_map.get(pt.lower(), 0) for pt in pert_types],
220
- dtype=torch.long, device=device,
 
221
  )
222
- pert_type_emb = self.pert_type_embedding(pert_type_idxs) # (B, 32)
223
-
224
- # Combine and project
225
- combined = torch.cat([gene_emb, pert_type_emb], dim=-1) # (B, gene_embed_dim + 32)
226
- return self.projector(combined) # (B, output_dim)
227
 
228
 
229
  class PerturbationEncoder(nn.Module):
230
- """
231
- Unified perturbation encoder for both drug and genetic perturbations.
232
-
233
- Architecture:
234
- Drug: SMILES → DrugEncoder (ChemBERTa or Morgan) → drug_embed (768-dim)
235
- Genetic: (gene_id, pert_type) → GeneticPerturbationEncoder → gene_embed (768-dim)
236
- Fusion: z_pert = α·z_drug + (1-α)·z_gene
237
- where α = sigmoid(learned_alpha) ∈ [0, 1]
238
-
239
- When only one type of perturbation is provided, the other branch
240
- outputs zeros and the fusion weight adapts accordingly.
241
- """
242
 
243
  def __init__(
244
  self,
@@ -251,41 +180,15 @@ class PerturbationEncoder(nn.Module):
251
  ):
252
  super().__init__()
253
  self.output_dim = output_dim
254
-
255
- # Drug encoder
256
  self.drug_encoder = DrugEncoder(
257
  transformer_model=drug_encoder_model,
258
  embed_dim=drug_embed_dim,
259
  use_morgan_fallback=use_morgan_fallback,
260
  )
261
-
262
- # Project drug embedding to output dim if needed
263
- if drug_embed_dim != output_dim:
264
- self.drug_proj = nn.Sequential(
265
- nn.Linear(drug_embed_dim, output_dim),
266
- nn.SELU(),
267
- )
268
- else:
269
- self.drug_proj = nn.Identity()
270
-
271
- # Genetic perturbation encoder
272
- self.gene_encoder = GeneticPerturbationEncoder(
273
- gene_embed_dim=gene_embed_dim,
274
- output_dim=output_dim,
275
- n_genes=n_genes,
276
- )
277
-
278
- # Fusion gate: learned α ∈ [0, 1] determines drug vs. gene contribution
279
- self.alpha_logit = nn.Parameter(torch.tensor(0.0)) # sigmoid(0) = 0.5
280
-
281
- # Output projection (after fusion)
282
- self.output_proj = nn.Sequential(
283
- nn.Linear(output_dim, output_dim),
284
- nn.SELU(),
285
- nn.Linear(output_dim, output_dim),
286
- )
287
-
288
- # Dropout for training
289
  self.dropout = nn.AlphaDropout(0.1)
290
 
291
  def forward(
@@ -294,58 +197,31 @@ class PerturbationEncoder(nn.Module):
294
  gene_ids: Optional[torch.LongTensor] = None,
295
  pert_types: Optional[List[str]] = None,
296
  ) -> torch.Tensor:
297
- """
298
- Encode perturbation into unified embedding.
299
-
300
- At least one of (smiles_list) or (gene_ids + pert_types) must be provided.
301
-
302
- Args:
303
- smiles_list: list of SMILES strings, or None
304
- gene_ids: (B,) integer gene indices, or None
305
- pert_types: list of perturbation type strings, or None
306
-
307
- Returns:
308
- z_pert: (B, output_dim) unified perturbation embedding
309
- """
310
- batch_size = None
311
- device = None
312
-
313
- # Determine batch size and device
314
  if gene_ids is not None:
315
  batch_size = gene_ids.shape[0]
316
  device = gene_ids.device
317
  elif smiles_list is not None:
318
  batch_size = len(smiles_list)
319
- device = next(self.parameters()).device if list(self.parameters()) else torch.device("cpu")
320
-
321
- if batch_size is None:
322
  raise ValueError("Either smiles_list or gene_ids must be provided")
323
 
324
- # Encode drug
325
  if smiles_list is not None and len(smiles_list) > 0:
326
- z_drug = self.drug_encoder(smiles_list, device) # (B, drug_embed_dim)
327
- z_drug = self.drug_proj(z_drug) # (B, output_dim)
328
  else:
329
  z_drug = torch.zeros(batch_size, self.output_dim, device=device)
330
 
331
- # Encode genetic perturbation
332
  if gene_ids is not None and pert_types is not None:
333
- z_gene = self.gene_encoder(gene_ids, pert_types) # (B, output_dim)
334
  else:
335
  z_gene = torch.zeros(batch_size, self.output_dim, device=device)
336
 
337
- # Learned fusion
338
- alpha = torch.sigmoid(self.alpha_logit) # [0, 1]
339
- z_pert = alpha * z_drug + (1 - alpha) * z_gene
340
-
341
- # Output projection
342
- z_pert = self.output_proj(self.dropout(F.selu(z_pert)))
343
-
344
- return z_pert
345
 
346
 
347
  def create_perturbation_encoder(config) -> PerturbationEncoder:
348
- """Factory function to create PerturbationEncoder from config."""
349
  return PerturbationEncoder(
350
  drug_encoder_model=config.drug_encoder_model,
351
  drug_embed_dim=config.drug_embed_dim,
 
1
  """
2
  Perturbation Encoder for MuLGIT-Perturb.
3
 
4
+ Encodes drug (SMILES) and genetic perturbations into a unified embedding.
5
+ Drug encoding uses ChemBERTa/MolFormer when available and falls back to
6
+ RDKit Morgan fingerprints. The encoder always returns the configured
7
+ `embed_dim`, so fallback fingerprints cannot break downstream projection
8
+ layers.
 
 
 
 
 
9
  """
10
 
11
  import torch
12
  import torch.nn as nn
13
  import torch.nn.functional as F
14
+ from typing import Optional, List
15
  import warnings
16
 
17
 
18
  class DrugEncoder(nn.Module):
19
+ """SMILES -> fixed-size molecular embedding."""
 
 
 
 
 
 
 
 
20
 
21
  def __init__(
22
  self,
 
30
  self.use_morgan_fallback = use_morgan_fallback
31
  self.max_smiles_len = max_smiles_len
32
  self.transformer_model = transformer_model
33
+ self._transformer = None
34
  self._tokenizer = None
35
+ self._tried_loading = False
36
+
37
+ def _match_dim(self, x: torch.Tensor) -> torch.Tensor:
38
+ """Pad/truncate any encoder output to self.embed_dim."""
39
+ if x.shape[-1] == self.embed_dim:
40
+ return x
41
+ if x.shape[-1] < self.embed_dim:
42
+ return F.pad(x, (0, self.embed_dim - x.shape[-1]))
43
+ return x[:, : self.embed_dim]
44
+
45
+ def _load_transformer(self):
46
+ if self._tried_loading:
47
+ return
48
+ self._tried_loading = True
49
+ try:
50
+ from transformers import AutoModel, AutoTokenizer
51
+ self._tokenizer = AutoTokenizer.from_pretrained(self.transformer_model)
52
+ self._transformer = AutoModel.from_pretrained(self.transformer_model)
53
+ for p in self._transformer.parameters():
54
+ p.requires_grad = False
55
+ self._transformer.eval()
56
+ except Exception as e:
57
+ warnings.warn(
58
+ f"Failed to load {self.transformer_model}: {e}. Falling back to Morgan fingerprints."
59
+ )
60
+ self._transformer = None
61
+ self._tokenizer = None
62
 
63
  def encode_morgan(self, smiles_list: List[str], device: torch.device) -> torch.Tensor:
 
64
  try:
65
  from rdkit import Chem
66
  from rdkit.Chem import AllChem
67
+ except ImportError as e:
68
+ raise ImportError("RDKit required for Morgan fingerprint fallback. Install with: pip install rdkit") from e
 
69
 
70
  fps = []
71
  for smi in smiles_list:
72
  try:
73
+ mol = Chem.MolFromSmiles(smi or "")
74
  if mol is not None:
75
  fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=2048)
76
  fp_array = torch.tensor(list(fp), dtype=torch.float32)
 
79
  except Exception:
80
  fp_array = torch.zeros(2048, dtype=torch.float32)
81
  fps.append(fp_array)
82
+ return self._match_dim(torch.stack(fps).to(device))
 
 
83
 
84
  def encode_transformer(self, smiles_list: List[str], device: torch.device) -> torch.Tensor:
85
+ self._load_transformer()
86
+ if self._transformer is None or self._tokenizer is None:
87
  if self.use_morgan_fallback:
88
  return self.encode_morgan(smiles_list, device)
89
  raise RuntimeError("No drug encoder available")
90
 
 
91
  tokens = self._tokenizer(
92
  smiles_list,
93
  padding=True,
 
95
  max_length=self.max_smiles_len,
96
  return_tensors="pt",
97
  ).to(device)
 
98
  with torch.no_grad():
99
+ outputs = self._transformer(**tokens)
 
100
  if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
101
+ embeddings = outputs.pooler_output
102
  else:
103
+ embeddings = outputs.last_hidden_state.mean(dim=1)
104
+ return self._match_dim(embeddings.float())
 
105
 
106
  def forward(self, smiles_list: List[str], device: torch.device = None) -> torch.Tensor:
 
 
 
 
 
 
 
 
 
 
107
  if device is None:
108
+ # DrugEncoder has no parameters; default to CPU unless caller passes device.
109
+ device = torch.device("cpu")
 
110
  try:
111
  return self.encode_transformer(smiles_list, device)
112
  except Exception as e:
113
  if self.use_morgan_fallback:
114
+ warnings.warn(f"Transformer encoding failed ({e}); using Morgan fingerprints")
115
  return self.encode_morgan(smiles_list, device)
116
  raise
117
 
118
 
119
  class GeneticPerturbationEncoder(nn.Module):
120
+ """Gene perturbation encoder: (gene_id, perturbation_type) -> embedding."""
 
 
 
 
 
121
 
122
  def __init__(
123
  self,
124
  gene_embed_dim: int = 256,
125
+ pert_type_dim: int = 5,
126
  output_dim: int = 768,
127
  n_genes: int = 20000,
128
  ):
129
  super().__init__()
 
130
  self.gene_embedding = nn.Embedding(n_genes, gene_embed_dim, padding_idx=0)
 
 
131
  self.pert_type_embedding = nn.Embedding(pert_type_dim, 32)
 
 
 
132
  self.projector = nn.Sequential(
133
+ nn.Linear(gene_embed_dim + 32, 512),
134
  nn.SELU(),
135
  nn.Linear(512, output_dim),
136
  )
 
 
137
  self.pert_type_map = {
138
+ "crispr_ko": 0,
139
+ "ko": 0,
140
+ "crispri": 1,
141
+ "kd": 1,
142
+ "knockdown": 1,
143
+ "shrna": 2,
144
+ "oe": 3,
145
+ "overexpression": 3,
146
+ "crispra": 4,
147
+ "activation": 4,
148
  "unknown": 0,
149
  }
150
 
151
  def load_geneformer_embeddings(self, geneformer_model: str = "ctheodoris/Geneformer"):
152
+ warnings.warn(
153
+ "Geneformer embedding extraction is not automated in this implementation; "
154
+ "using learned gene embeddings. Pre-extracted Geneformer embeddings can be loaded manually."
155
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
 
157
+ def forward(self, gene_ids: torch.LongTensor, pert_types: List[str]) -> torch.Tensor:
 
 
158
  device = gene_ids.device
159
+ gene_emb = self.gene_embedding(gene_ids)
 
 
 
 
160
  pert_type_idxs = torch.tensor(
161
+ [self.pert_type_map.get(str(pt).lower(), 0) for pt in pert_types],
162
+ dtype=torch.long,
163
+ device=device,
164
  )
165
+ pert_type_emb = self.pert_type_embedding(pert_type_idxs)
166
+ return self.projector(torch.cat([gene_emb, pert_type_emb], dim=-1))
 
 
 
167
 
168
 
169
  class PerturbationEncoder(nn.Module):
170
+ """Unified perturbation encoder for drugs and genetic perturbations."""
 
 
 
 
 
 
 
 
 
 
 
171
 
172
  def __init__(
173
  self,
 
180
  ):
181
  super().__init__()
182
  self.output_dim = output_dim
 
 
183
  self.drug_encoder = DrugEncoder(
184
  transformer_model=drug_encoder_model,
185
  embed_dim=drug_embed_dim,
186
  use_morgan_fallback=use_morgan_fallback,
187
  )
188
+ self.drug_proj = nn.Sequential(nn.Linear(drug_embed_dim, output_dim), nn.SELU()) if drug_embed_dim != output_dim else nn.Identity()
189
+ self.gene_encoder = GeneticPerturbationEncoder(gene_embed_dim=gene_embed_dim, output_dim=output_dim, n_genes=n_genes)
190
+ self.alpha_logit = nn.Parameter(torch.tensor(0.0))
191
+ self.output_proj = nn.Sequential(nn.Linear(output_dim, output_dim), nn.SELU(), nn.Linear(output_dim, output_dim))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
  self.dropout = nn.AlphaDropout(0.1)
193
 
194
  def forward(
 
197
  gene_ids: Optional[torch.LongTensor] = None,
198
  pert_types: Optional[List[str]] = None,
199
  ) -> torch.Tensor:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  if gene_ids is not None:
201
  batch_size = gene_ids.shape[0]
202
  device = gene_ids.device
203
  elif smiles_list is not None:
204
  batch_size = len(smiles_list)
205
+ device = self.alpha_logit.device
206
+ else:
 
207
  raise ValueError("Either smiles_list or gene_ids must be provided")
208
 
 
209
  if smiles_list is not None and len(smiles_list) > 0:
210
+ z_drug = self.drug_proj(self.drug_encoder(smiles_list, device))
 
211
  else:
212
  z_drug = torch.zeros(batch_size, self.output_dim, device=device)
213
 
 
214
  if gene_ids is not None and pert_types is not None:
215
+ z_gene = self.gene_encoder(gene_ids, pert_types)
216
  else:
217
  z_gene = torch.zeros(batch_size, self.output_dim, device=device)
218
 
219
+ alpha = torch.sigmoid(self.alpha_logit)
220
+ z_pert = alpha * z_drug + (1.0 - alpha) * z_gene
221
+ return self.output_proj(self.dropout(F.selu(z_pert)))
 
 
 
 
 
222
 
223
 
224
  def create_perturbation_encoder(config) -> PerturbationEncoder:
 
225
  return PerturbationEncoder(
226
  drug_encoder_model=config.drug_encoder_model,
227
  drug_embed_dim=config.drug_embed_dim,