AlienChen committed on
Commit
13fd594
·
verified ·
1 Parent(s): a33cd10

Update models/peptide_classifiers.py

Browse files
Files changed (1) hide show
  1. models/peptide_classifiers.py +94 -429
models/peptide_classifiers.py CHANGED
@@ -147,366 +147,6 @@ class MotifModel(nn.Module):
147
  def forward(self, x):
148
  return self.bindevaluator.scoring(x, self.target_sequence, self.motifs, self.penalty)
149
 
150
class UnpooledBindingPredictor(nn.Module):
    """Predict protein-peptide binding affinity from raw token ids.

    Token ids are embedded on the fly with an (optionally frozen) ESM
    backbone; each stream runs through a bank of 1-D convolutions, is
    max/avg pooled with mask awareness, projected to a shared width,
    fused with alternating cross-attention, and reduced to a single
    regression score.
    """

    def __init__(self,
                 esm_model_name="facebook/esm2_t33_650M_UR50D",
                 hidden_dim=512,
                 kernel_sizes=[3, 5, 7],
                 n_heads=8,
                 n_layers=3,
                 dropout=0.1,
                 freeze_esm=True):
        super().__init__()

        # Affinity thresholds (pKd-like scale) used by get_binding_class.
        self.tight_threshold = 7.5  # Kd/Ki/IC50 ≤ ~30nM
        self.weak_threshold = 6.0  # Kd/Ki/IC50 > 1μM

        # Backbone language model used to embed token ids at forward time.
        self.esm_model = AutoModel.from_pretrained(esm_model_name)
        self.config = AutoConfig.from_pretrained(esm_model_name)

        if freeze_esm:
            # Train only the downstream head; leave the backbone untouched.
            for p in self.esm_model.parameters():
                p.requires_grad = False

        esm_dim = self.config.hidden_size
        channels_per_kernel = 64

        # Independent conv banks so each stream learns its own local filters.
        self.protein_conv_layers = nn.ModuleList([
            nn.Conv1d(in_channels=esm_dim,
                      out_channels=channels_per_kernel,
                      kernel_size=k,
                      padding='same')
            for k in kernel_sizes
        ])
        self.binder_conv_layers = nn.ModuleList([
            nn.Conv1d(in_channels=esm_dim,
                      out_channels=channels_per_kernel,
                      kernel_size=k,
                      padding='same')
            for k in kernel_sizes
        ])

        # Each stream contributes max- and avg-pooled features per kernel size.
        pooled_width = channels_per_kernel * len(kernel_sizes) * 2

        self.protein_projection = nn.Linear(pooled_width, hidden_dim)
        self.binder_projection = nn.Linear(pooled_width, hidden_dim)

        self.protein_norm = nn.LayerNorm(hidden_dim)
        self.binder_norm = nn.LayerNorm(hidden_dim)

        # Post-norm transformer-style blocks shared by both directions.
        self.cross_attention_layers = nn.ModuleList([
            nn.ModuleDict({
                'attention': nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout),
                'norm1': nn.LayerNorm(hidden_dim),
                'ffn': nn.Sequential(
                    nn.Linear(hidden_dim, hidden_dim * 4),
                    nn.ReLU(),
                    nn.Dropout(dropout),
                    nn.Linear(hidden_dim * 4, hidden_dim),
                ),
                'norm2': nn.LayerNorm(hidden_dim),
            })
            for _ in range(n_layers)
        ])

        self.shared_head = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

        # Regression output (affinity) plus an auxiliary 3-way classifier.
        self.regression_head = nn.Linear(hidden_dim, 1)
        self.classification_head = nn.Linear(hidden_dim, 3)

    def get_binding_class(self, affinity):
        """Map affinity to a class index.

        0: tight binding (>= 7.5)
        1: medium binding (6.0-7.5)
        2: weak binding (< 6.0)
        """
        if isinstance(affinity, torch.Tensor):
            labels = torch.zeros_like(affinity, dtype=torch.long)
            labels[(affinity < self.tight_threshold) & (affinity >= self.weak_threshold)] = 1
            labels[affinity < self.weak_threshold] = 2
            return labels
        if affinity >= self.tight_threshold:
            return 0  # tight binding
        if affinity < self.weak_threshold:
            return 2  # weak binding
        return 1  # medium binding

    def compute_embeddings(self, input_ids, attention_mask=None):
        """Return per-token ESM hidden states, shape (batch, seq_len, esm_dim)."""
        outputs = self.esm_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        return outputs.last_hidden_state

    def process_sequence(self, unpooled_emb, conv_layers, attention_mask=None):
        """Run one stream through its conv bank and mask-aware global pooling."""
        # Conv1d wants (batch, channels, seq_len).
        feats = unpooled_emb.transpose(1, 2)

        conv_maps = [F.relu(conv(feats)) for conv in conv_layers]
        stacked = torch.cat(conv_maps, dim=1)

        if attention_mask is not None:
            # Broadcast the token-validity mask over every conv channel.
            mask = attention_mask.unsqueeze(1).expand(-1, stacked.size(1), -1)

            # Padding must not win the max: push padded cells to -inf first.
            max_pooled = stacked.clone().masked_fill(mask == 0, float('-inf')).max(dim=2)[0]

            # Mean over valid positions only; clamp dodges divide-by-zero.
            totals = torch.sum(stacked * mask, dim=2)
            counts = torch.clamp(torch.sum(mask, dim=2), min=1.0)
            avg_pooled = totals / counts
        else:
            # No mask: plain global pooling over the full sequence.
            max_pooled = torch.max(stacked, dim=2)[0]
            avg_pooled = torch.mean(stacked, dim=2)

        return torch.cat([max_pooled, avg_pooled], dim=1)

    def forward(self, protein_input_ids, binder_input_ids, protein_mask=None, binder_mask=None):
        """Return a (batch, 1) affinity regression score per protein/binder pair."""
        # Per-token embeddings straight from the backbone.
        protein_tokens = self.compute_embeddings(protein_input_ids, protein_mask)
        binder_tokens = self.compute_embeddings(binder_input_ids, binder_mask)

        # CNN + pooling collapses each stream to a fixed-width vector.
        protein_vec = self.process_sequence(protein_tokens, self.protein_conv_layers, protein_mask)
        binder_vec = self.process_sequence(binder_tokens, self.binder_conv_layers, binder_mask)

        protein = self.protein_norm(self.protein_projection(protein_vec))
        binder = self.binder_norm(self.binder_projection(binder_vec))

        # nn.MultiheadAttention expects (seq, batch, dim); here seq == 1.
        protein = protein.unsqueeze(0)
        binder = binder.unsqueeze(0)

        for block in self.cross_attention_layers:
            # Protein queries the binder first...
            protein_ctx = block['attention'](protein, binder, binder)[0]
            protein = block['norm1'](protein + protein_ctx)
            protein = block['norm2'](protein + block['ffn'](protein))

            # ...then the binder queries the already-updated protein.
            binder_ctx = block['attention'](binder, protein, protein)[0]
            binder = block['norm1'](binder + binder_ctx)
            binder = block['norm2'](binder + block['ffn'](binder))

        # Drop the singleton sequence axis and fuse both sides.
        combined = torch.cat([protein.squeeze(0), binder.squeeze(0)], dim=-1)
        shared = self.shared_head(combined)

        # The auxiliary classification head is unused at inference time.
        return self.regression_head(shared)
362
-
363
class ImprovedBindingPredictor(nn.Module):
    """Binding-affinity regressor over precomputed embedding sequences.

    Takes per-token protein (ESM) and binder (SMILES-encoder) embeddings,
    projects both to a shared width, alternates cross-attention between the
    two streams, mean-pools each side, and emits a regression score.
    """

    def __init__(self,
                 esm_dim=1280,
                 smiles_dim=1280,
                 hidden_dim=512,
                 n_heads=8,
                 n_layers=5,
                 dropout=0.1):
        super().__init__()

        # Affinity thresholds (pKd-like scale) used by get_binding_class.
        self.tight_threshold = 7.5  # Kd/Ki/IC50 ≤ ~30nM
        self.weak_threshold = 6.0  # Kd/Ki/IC50 > 1μM

        # Bring both modalities to a common width before attention.
        self.smiles_projection = nn.Linear(smiles_dim, hidden_dim)
        self.protein_projection = nn.Linear(esm_dim, hidden_dim)
        self.protein_norm = nn.LayerNorm(hidden_dim)
        self.smiles_norm = nn.LayerNorm(hidden_dim)

        # Post-norm transformer-style blocks shared by both directions.
        self.cross_attention_layers = nn.ModuleList([
            nn.ModuleDict({
                'attention': nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout),
                'norm1': nn.LayerNorm(hidden_dim),
                'ffn': nn.Sequential(
                    nn.Linear(hidden_dim, hidden_dim * 4),
                    nn.ReLU(),
                    nn.Dropout(dropout),
                    nn.Linear(hidden_dim * 4, hidden_dim),
                ),
                'norm2': nn.LayerNorm(hidden_dim),
            })
            for _ in range(n_layers)
        ])

        self.shared_head = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

        # Regression output (affinity) plus an auxiliary 3-way classifier.
        self.regression_head = nn.Linear(hidden_dim, 1)
        self.classification_head = nn.Linear(hidden_dim, 3)

    def get_binding_class(self, affinity):
        """Map affinity to a class index.

        0: tight binding (>= 7.5)
        1: medium binding (6.0-7.5)
        2: weak binding (< 6.0)
        """
        if isinstance(affinity, torch.Tensor):
            labels = torch.zeros_like(affinity, dtype=torch.long)
            labels[(affinity < self.tight_threshold) & (affinity >= self.weak_threshold)] = 1
            labels[affinity < self.weak_threshold] = 2
            return labels
        if affinity >= self.tight_threshold:
            return 0  # tight binding
        if affinity < self.weak_threshold:
            return 2  # weak binding
        return 1  # medium binding

    def forward(self, protein_emb, binder_emb):
        """Return a (batch, 1) affinity score for embedded protein/binder pairs."""
        protein = self.protein_norm(self.protein_projection(protein_emb))
        smiles = self.smiles_norm(self.smiles_projection(binder_emb))

        # (batch, seq, dim) -> (seq, batch, dim) for nn.MultiheadAttention.
        protein = protein.transpose(0, 1)
        smiles = smiles.transpose(0, 1)

        for block in self.cross_attention_layers:
            # Protein queries the SMILES stream first...
            protein_ctx = block['attention'](protein, smiles, smiles)[0]
            protein = block['norm1'](protein + protein_ctx)
            protein = block['norm2'](protein + block['ffn'](protein))

            # ...then SMILES queries the already-updated protein.
            smiles_ctx = block['attention'](smiles, protein, protein)[0]
            smiles = block['norm1'](smiles + smiles_ctx)
            smiles = block['norm2'](smiles + block['ffn'](smiles))

        # Mean over the sequence axis gives one vector per stream.
        protein_pool = torch.mean(protein, dim=0)
        smiles_pool = torch.mean(smiles, dim=0)

        combined = torch.cat([protein_pool, smiles_pool], dim=-1)
        shared = self.shared_head(combined)

        # The auxiliary classification head is unused at inference time.
        return self.regression_head(shared)
471
-
472
class PooledAffinityModel(nn.Module):
    """Wrap an embedding-based affinity predictor behind a token-id interface.

    Embeds a fixed target sequence and each batch of binder ids with a frozen
    ESM backbone, then delegates scoring to the wrapped predictor.
    """

    def __init__(self, affinity_predictor, target_sequence):
        super(PooledAffinityModel, self).__init__()
        self.affinity_predictor = affinity_predictor
        # Tokenized target ids; repeated across the batch at forward time.
        self.target_sequence = target_sequence
        self.esm_model = AutoModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(self.target_sequence.device)
        # Backbone is frozen: embeddings only, no fine-tuning.
        for p in self.esm_model.parameters():
            p.requires_grad = False

    def compute_embeddings(self, input_ids, attention_mask=None):
        """Return per-token ESM hidden states, shape (batch, seq_len, hidden)."""
        outputs = self.esm_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        return outputs.last_hidden_state

    def forward(self, x):
        """Score each row of binder token ids `x` against the fixed target."""
        # One copy of the target per batch element.
        target_ids = self.target_sequence.repeat(x.shape[0], 1)

        target_emb = self.compute_embeddings(input_ids=target_ids)
        binder_emb = self.compute_embeddings(input_ids=x)
        return self.affinity_predictor(protein_emb=target_emb, binder_emb=binder_emb).squeeze(-1)
498
-
499
class AffinityModel(nn.Module):
    """Thin wrapper scoring binder token ids against one fixed target sequence."""

    def __init__(self, affinity_predictor, target_sequence):
        super(AffinityModel, self).__init__()
        self.affinity_predictor = affinity_predictor
        self.target_sequence = target_sequence

    def forward(self, x):
        """Return per-sample affinity for binder ids `x`, scaled by 1/10."""
        # Broadcast the single target sequence across the batch.
        target_ids = self.target_sequence.repeat(x.shape[0], 1)
        raw = self.affinity_predictor(protein_input_ids=target_ids, binder_input_ids=x).squeeze(-1)
        # Predictor output is rescaled by 1/10 before use downstream.
        return raw / 10
509
-
510
  class HemolysisModel:
511
  def __init__(self, device):
512
  self.predictor = xgb.Booster(model_file='../classifier_ckpt/wt_hemolysis.json')
@@ -516,17 +156,14 @@ class HemolysisModel:
516
 
517
  self.device = device
518
 
519
- def generate_embeddings(self, sequences):
520
- """Generate ESM embeddings for protein sequences"""
521
- with torch.no_grad():
522
- embeddings = self.model(input_ids=sequences).last_hidden_state.mean(dim=1)
523
- embeddings = embeddings.cpu().numpy()
524
-
525
- return embeddings
526
-
527
  def get_scores(self, input_seqs):
528
  scores = np.ones(len(input_seqs))
529
- features = self.generate_embeddings(input_seqs)
 
 
 
 
 
530
 
531
  if len(features) == 0:
532
  return scores
@@ -584,6 +221,9 @@ class NonfoulingModel:
584
  def get_scores(self, input_ids, attention_mask):
585
  with torch.no_grad():
586
  features = self.model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
 
 
 
587
  scores = self.predictor(features, attention_mask)
588
  return scores
589
 
@@ -591,45 +231,8 @@ class NonfoulingModel:
591
  attention_mask = torch.ones_like(input_ids).to(self.device)
592
  scores = self.get_scores(input_ids, attention_mask)
593
  return 1.0 / (1.0 + torch.exp(-scores))
594
-
595
class SolubilityModel:
    """XGBoost solubility scorer over mean-pooled ESM embeddings.

    Sequences (token-id tensors) are embedded with a frozen ESM-2 model,
    mean-pooled over the sequence axis, sanitized, and scored by a
    pre-trained XGBoost booster.
    """

    def __init__(self, device):
        # change model path
        self.predictor = xgb.Booster(model_file='../classifier_ckpt/best_model_solubility.json')

        self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
        self.model.eval()

        self.device = device

    def generate_embeddings(self, sequences):
        """Generate ESM embeddings for protein sequences (mean over tokens)."""
        with torch.no_grad():
            embeddings = self.model(input_ids=sequences).last_hidden_state.mean(dim=1)
        return embeddings.cpu().numpy()

    def get_scores(self, input_seqs: list):
        """Return a torch tensor of booster scores, one per input sequence.

        Fix: the empty-input path previously returned a NumPy array while the
        normal path returned a torch tensor on self.device; both paths now
        return a tensor so callers get a consistent type.
        """
        features = self.generate_embeddings(input_seqs)

        if len(features) == 0:
            return torch.zeros(len(input_seqs), device=self.device)

        # Guard against NaN/Inf before handing the matrix to XGBoost.
        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        dmatrix = xgb.DMatrix(features)
        scores = self.predictor.predict(dmatrix)
        return torch.from_numpy(scores).to(self.device)

    def __call__(self, input_seqs: list):
        return self.get_scores(input_seqs)
631
-
632
- class SolubilityModelNew:
633
  def __init__(self, device):
634
  self.hydro_ids = torch.tensor([5, 7, 4, 12, 20, 18, 22, 14], device=device)
635
  self.device = device
@@ -748,7 +351,7 @@ class HalfLifeModel:
748
  self.device = device
749
 
750
  # --- load NN checkpoint (saved by your finetune script) ---
751
- ckpt = torch.load(ckpt_path, map_location="cpu")
752
  if not isinstance(ckpt, dict) or "state_dict" not in ckpt:
753
  raise ValueError(f"Checkpoint at {ckpt_path} is not the expected dict with a 'state_dict' key.")
754
 
@@ -871,33 +474,95 @@ def load_solver(checkpoint_path, vocab_size, device):
871
  return solver
872
 
873
 
874
def load_pooled_affinity_predictor(checkpoint_path, device):
    """Restore an ImprovedBindingPredictor from a training checkpoint.

    NOTE(review): weights_only=False unpickles arbitrary objects — only load
    checkpoints from trusted sources.
    """
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    model = ImprovedBindingPredictor().to(device)
    model.load_state_dict(checkpoint['model_state_dict'])

    # Inference only: disables dropout.
    model.eval()
    return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
885
 
886
def load_affinity_predictor(checkpoint_path, device):
    """Restore the UnpooledBindingPredictor used for affinity scoring.

    NOTE(review): weights_only=False unpickles arbitrary objects — only load
    checkpoints from trusted sources.
    """
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    # Hyperparameters must match those used when the checkpoint was trained.
    model = UnpooledBindingPredictor(
        esm_model_name="facebook/esm2_t33_650M_UR50D",
        hidden_dim=384,
        kernel_sizes=[3, 5, 7],
        n_heads=8,
        n_layers=4,
        dropout=0.14561457009902096,
        freeze_esm=True,
    ).to(device)

    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  def forward(self, x):
148
  return self.bindevaluator.scoring(x, self.target_sequence, self.motifs, self.penalty)
149
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  class HemolysisModel:
151
  def __init__(self, device):
152
  self.predictor = xgb.Booster(model_file='../classifier_ckpt/wt_hemolysis.json')
 
156
 
157
  self.device = device
158
 
 
 
 
 
 
 
 
 
159
  def get_scores(self, input_seqs):
160
  scores = np.ones(len(input_seqs))
161
+ with torch.no_grad():
162
+ embeddings = self.model(input_ids=input_seqs, attention_mask=torch.ones_like(input_seqs).to(self.device)).last_hidden_state
163
+ keep = (input_seqs != 0) & (input_seqs != 1) & (input_seqs != 2)
164
+ embeddings[keep==False] = 0
165
+ features = torch.sum(embeddings, dim=1)/torch.sum(keep==True, dim=1).unsqueeze(-1)
166
+ features = features.cpu().numpy()
167
 
168
  if len(features) == 0:
169
  return scores
 
221
  def get_scores(self, input_ids, attention_mask):
222
  with torch.no_grad():
223
  features = self.model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
224
+
225
+ keep = (input_ids != 0) & (input_ids != 1) & (input_ids != 2)
226
+ attention_mask[keep==False] = 0
227
  scores = self.predictor(features, attention_mask)
228
  return scores
229
 
 
231
  attention_mask = torch.ones_like(input_ids).to(self.device)
232
  scores = self.get_scores(input_ids, attention_mask)
233
  return 1.0 / (1.0 + torch.exp(-scores))
 
 
 
 
 
 
 
 
234
 
235
+ class SolubilityModel:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
  def __init__(self, device):
237
  self.hydro_ids = torch.tensor([5, 7, 4, 12, 20, 18, 22, 14], device=device)
238
  self.device = device
 
351
  self.device = device
352
 
353
  # --- load NN checkpoint (saved by your finetune script) ---
354
+ ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
355
  if not isinstance(ckpt, dict) or "state_dict" not in ckpt:
356
  raise ValueError(f"Checkpoint at {ckpt_path} is not the expected dict with a 'state_dict' key.")
357
 
 
474
  return solver
475
 
476
 
477
class CrossAttnUnpooled(nn.Module):
    """Cross-attention affinity head over unpooled token embeddings.

    Consumes target and binder token sequences plus boolean validity masks,
    alternates target->binder and binder->target attention, masked-mean
    pools both sides, and emits (regression score, 3-class logits).
    """

    def __init__(self, Ht=1280, Hb=1280, hidden=768, n_heads=8, n_layers=1, dropout=0.16430662769055482):
        super().__init__()
        self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
        self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))

        self.layers = nn.ModuleList([])
        for _ in range(n_layers):
            # Key names below are part of checkpoint state_dicts; do not rename.
            self.layers.append(nn.ModuleDict({
                "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
                "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
                "n1t": nn.LayerNorm(hidden),
                "n2t": nn.LayerNorm(hidden),
                "n1b": nn.LayerNorm(hidden),
                "n2b": nn.LayerNorm(hidden),
                "fft": nn.Sequential(nn.Linear(hidden, 4 * hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4 * hidden, hidden)),
                "ffb": nn.Sequential(nn.Linear(hidden, 4 * hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4 * hidden, hidden)),
            }))

        self.shared = nn.Sequential(nn.Linear(2 * hidden, hidden), nn.GELU(), nn.Dropout(dropout))
        self.reg = nn.Linear(hidden, 1)
        self.cls = nn.Linear(hidden, 3)

    def masked_mean(self, X, M):
        """Mean of X over dim 1, restricted to positions where M is true."""
        weights = M.unsqueeze(-1).float()
        counts = weights.sum(dim=1).clamp(min=1.0)  # avoid 0-division on empty masks
        return (X * weights).sum(dim=1) / counts

    def forward(self, T, Mt, B, Mb):
        """T:(B,Lt,Ht), Mt:(B,Lt) ; B:(B,Lb,Hb), Mb:(B,Lb) -> (reg, logits)."""
        tgt = self.t_proj(T)
        bnd = self.b_proj(B)

        # key_padding_mask convention: True marks a PAD position to ignore.
        pad_t = ~Mt
        pad_b = ~Mb

        for blk in self.layers:
            # Target tokens query the binder...
            ctx_t, _ = blk["attn_tb"](tgt, bnd, bnd, key_padding_mask=pad_b)
            tgt = blk["n1t"](tgt + ctx_t)
            tgt = blk["n2t"](tgt + blk["fft"](tgt))

            # ...then binder tokens query the already-updated target.
            ctx_b, _ = blk["attn_bt"](bnd, tgt, tgt, key_padding_mask=pad_t)
            bnd = blk["n1b"](bnd + ctx_b)
            bnd = blk["n2b"](bnd + blk["ffb"](bnd))

        fused = torch.cat([self.masked_mean(tgt, Mt), self.masked_mean(bnd, Mb)], dim=-1)
        h = self.shared(fused)
        return self.reg(h).squeeze(-1), self.cls(h)
532
 
533
def load_affinity_predictor(checkpoint_path, device):
    """Restore a trained CrossAttnUnpooled affinity head from a checkpoint.

    NOTE(review): weights_only=False unpickles arbitrary objects — only load
    checkpoints from trusted sources.
    """
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    model = CrossAttnUnpooled()
    model.load_state_dict(checkpoint['state_dict'])

    # Inference only: disables dropout.
    model.eval()
    model = model.to(device)

    return model
544
+
545
class AffinityModel(nn.Module):
    """Score binder token ids against a fixed, pre-tokenized target sequence.

    The target and binder are embedded with an ESM-2 backbone (eval mode,
    wrapped in no_grad), BOS/EOS positions are stripped, and the pair is
    scored by a CrossAttnUnpooled-style predictor returning (reg, logits).
    """

    def __init__(self, affinity_predictor, target_sequence, device):
        super(AffinityModel, self).__init__()
        self.affinity_predictor = affinity_predictor
        # Tokenizer output for the target: dict with input_ids/attention_mask.
        self.target_sequence = target_sequence
        self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
        # eval() disables dropout; gradients are suppressed via no_grad in forward.
        self.esm_model.eval()
        self.device = device

    def forward(self, x):
        """Return affinity/10 for each row of binder token ids `x`.

        Assumes every row of `x` carries BOS/EOS at positions 0 and -1 with
        no padding in between — TODO confirm against the caller.
        """
        batch = x.shape[0]

        # Target-side mask/embeddings minus BOS/EOS, broadcast over the batch.
        Mt = self.target_sequence['attention_mask'][:, 1:-1].repeat(batch, 1)
        with torch.no_grad():
            T = self.esm_model(**self.target_sequence).last_hidden_state[:, 1:-1, :].repeat(batch, 1, 1)

            # Fix: the original rebuilt the attention mask inside a dead
            # per-sample loop whose result was never used; one batched
            # mask + one batched forward pass is sufficient.
            attention_mask = torch.ones_like(x).to(self.device)
            B = self.esm_model(input_ids=x, attention_mask=attention_mask).last_hidden_state[:, 1:-1]

        # Binder-side validity mask: all non-special positions are valid.
        Mb = torch.ones(batch, x.shape[1] - 2, dtype=torch.bool).to(self.device)

        affinity, _ = self.affinity_predictor(T, Mt.bool(), B, Mb)
        # Predictor output is rescaled by 1/10 before use downstream.
        return affinity / 10
568
+