hemantn committed on
Commit
9d902f5
·
1 Parent(s): b3cd799

Implement proper alignment functionality using ANARCI

Browse files
Files changed (1) hide show
  1. adapter.py +67 -29
adapter.py CHANGED
@@ -568,46 +568,84 @@ class HFAbRestore(AbRestore):
568
  n_seqs = len(seqs)
569
 
570
  if align:
571
- # For now, skip alignment as it requires ANARCI
572
- print("WARNING: Alignment not implemented, skipping...")
573
- pass
574
-
575
- print(f"DEBUG: Processing sequences: {seqs}")
576
-
577
- # Tokenize sequences using original interface
578
- tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device)
579
- print(f"DEBUG: Tokenized shape: {tokens.shape}")
580
- print(f"DEBUG: First sequence tokens: {tokens[0]}")
581
-
582
- # Get predictions for amino acids (indices 1-20)
583
- predictions = self.AbLang(tokens)[:,:,1:21]
584
- print(f"DEBUG: Predictions shape: {predictions.shape}")
585
-
586
- # Find predicted tokens and replace mask tokens
 
 
 
 
 
 
 
 
 
587
  predicted_tokens = torch.max(predictions, -1).indices + 1
588
- print(f"DEBUG: Predicted tokens: {predicted_tokens[0]}")
589
-
590
- # Find mask token positions
591
- mask_positions = (tokens == 23).nonzero(as_tuple=True)
592
- print(f"DEBUG: Mask token positions: {mask_positions}")
593
-
594
  restored_tokens = torch.where(tokens==23, predicted_tokens, tokens)
595
- print(f"DEBUG: Restored tokens: {restored_tokens[0]}")
596
-
597
- # Decode back to sequences using original tokenizer
598
  restored_seqs = self.tokenizer(restored_tokens, mode="decode")
599
- print(f"DEBUG: Decoded sequences: {restored_seqs}")
600
-
601
- # Handle paired sequences format
602
  if n_seqs < len(restored_seqs):
603
  restored_seqs = [f"{h}|{l}".replace('-','') for h,l in zip(restored_seqs[:n_seqs], restored_seqs[n_seqs:])]
604
  seqs = [f"{h}|{l}" for h,l in zip(seqs[:n_seqs], seqs[n_seqs:])]
605
 
606
- # Apply final formatting
607
  from extra_utils import res_to_seq
608
  result = np.array([res_to_seq(seq, 'restore') for seq in np.c_[restored_seqs, np.vectorize(len)(seqs)]])
609
  print(f"DEBUG: Final result: {result}")
610
  return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
611
 
612
  def add_angle_brackets(seq):
613
  # Assumes input is 'VH|VL' or 'VH|' or '|VL'
 
568
  n_seqs = len(seqs)
569
 
570
  if align:
571
+ # Implement alignment using ANARCI to create spread sequences
572
+ print("DEBUG: Using alignment to create spread sequences...")
573
+ seqs = self._sequence_aligning(seqs)
574
+ nr_seqs = len(seqs)//self.spread
575
+
576
+ tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device)
577
+ predictions = self.AbLang(tokens)[:,:,1:21]
578
+
579
+ # Reshape
580
+ tokens = tokens.reshape(nr_seqs, self.spread, -1)
581
+ predictions = predictions.reshape(nr_seqs, self.spread, -1, 20)
582
+ seqs = seqs.reshape(nr_seqs, -1)
583
+
584
+ # Find index of best predictions
585
+ best_seq_idx = torch.argmax(torch.max(predictions, -1).values[:,:,1:2].mean(2), -1)
586
+
587
+ # Select best predictions
588
+ tokens = tokens.gather(1, best_seq_idx.view(-1, 1).unsqueeze(1).repeat(1, 1, tokens.shape[-1])).squeeze(1)
589
+ predictions = predictions[range(predictions.shape[0]), best_seq_idx]
590
+ seqs = np.take_along_axis(seqs, best_seq_idx.view(-1, 1).cpu().numpy(), axis=1)
591
+ else:
592
+ print(f"DEBUG: Processing sequences without alignment: {seqs}")
593
+ tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device)
594
+ predictions = self.AbLang(tokens)[:,:,1:21]
595
+
596
  predicted_tokens = torch.max(predictions, -1).indices + 1
 
 
 
 
 
 
597
  restored_tokens = torch.where(tokens==23, predicted_tokens, tokens)
598
+
 
 
599
  restored_seqs = self.tokenizer(restored_tokens, mode="decode")
600
+
 
 
601
  if n_seqs < len(restored_seqs):
602
  restored_seqs = [f"{h}|{l}".replace('-','') for h,l in zip(restored_seqs[:n_seqs], restored_seqs[n_seqs:])]
603
  seqs = [f"{h}|{l}" for h,l in zip(seqs[:n_seqs], seqs[n_seqs:])]
604
 
 
605
  from extra_utils import res_to_seq
606
  result = np.array([res_to_seq(seq, 'restore') for seq in np.c_[restored_seqs, np.vectorize(len)(seqs)]])
607
  print(f"DEBUG: Final result: {result}")
608
  return result
609
+
610
def _sequence_aligning(self, seqs):
    """Build the ANARCI-aligned "spread" input for restoration.

    Each input record is a 'VH|VL' string, optionally wrapped in the
    '<'/'>' start/end tokens.  The tokens are stripped, the record is
    split into its heavy/light parts, and each chain is expanded into
    spread variants which are re-wrapped in '<...>' and stacked
    heavy-before-light into one array.
    """
    chain_pairs = [record.replace(">", "").replace("<", "").split("|") for record in seqs]

    wrapped = {}
    for chain_id in ("H", "L"):
        spread = self._create_spread_of_sequences(chain_pairs, chain=chain_id)
        wrapped[chain_id] = np.array([f"<{seq}>" for seq in spread])

    return np.concatenate([wrapped["H"], wrapped["L"]])
618
+
619
def _create_spread_of_sequences(self, seqs, chain = 'H'):
    """Number one chain of each pair with ANARCI and expand it into
    ``self.spread`` aligned candidate sequences.

    Parameters
    ----------
    seqs : sequence of [heavy, light] string pairs (as produced by
        ``_sequence_aligning``).
    chain : 'H' to process the heavy chain; any other value selects the
        light chain.

    Returns
    -------
    numpy.ndarray
        Flat array of len(seqs) * self.spread spread sequences; inputs
        ANARCI could not number are represented by the 'ANARCI_error'
        placeholder string.
    """
    import pandas as pd
    import anarci

    chain_idx = 0 if chain == 'H' else 1
    # '*' is replaced with the unknown-residue code 'X' before numbering.
    numbered_seqs = anarci.run_anarci(
        pd.DataFrame([seq[chain_idx].replace('*', 'X') for seq in seqs]).reset_index().values.tolist(),
        ncpu=self.ncpu,
        scheme='imgt',
        allowed_species=['human', 'mouse'],
    )

    # numbered_seqs[1] holds one numbering result per input; a falsy entry
    # means ANARCI failed on that sequence.  The loop variable is named
    # 'result' (not 'anarci') so it does not shadow the module imported above.
    anarci_data = pd.DataFrame(
        [str(result[0][0]) if result else 'ANARCI_error' for result in numbered_seqs[1]],
        columns=['anarci']
    ).astype('<U90')  # NOTE(review): '<U90' caps each entry at 90 chars -- confirm this fits the full ANARCI numbering string

    # IMGT scheme: heavy chains run to position 128, light chains to 127.
    max_position = 128 if chain == 'H' else 127

    from extra_utils import get_sequences_from_anarci
    # Keep the parameter untouched; bind the expanded output to a new name.
    spread_seqs = anarci_data.apply(
        lambda x: get_sequences_from_anarci(
            x.anarci,
            max_position,
            self.spread
        ), axis=1, result_type='expand'
    ).to_numpy().reshape(-1)

    return spread_seqs
649
 
650
  def add_angle_brackets(seq):
651
  # Assumes input is 'VH|VL' or 'VH|' or '|VL'