diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..e92cf9b594e8e4608b60733d93926bdec544b5c1 --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +*.pt filter=lfs diff=lfs merge=lfs -text diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..befbe78a0faf21576a9183576d050010940b6a33 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 hemantn + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
\ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..4c4344667fa8d45b41331635c07267d4e1a9ee51 --- /dev/null +++ b/README.md @@ -0,0 +1,111 @@ +--- +#language: +#- en +license: mit +tags: +- biology +- protein +- antibody +- ablang +- transformers +- pytorch +- chemistry +- oas +- cdr +- ablang2 hf implementation +- roberta +- ESM +- ablang2 +- antibody-design + +# datasets: +# - oas +metrics: +- sequence modeling +- protein language model +library_name: transformers +pipeline_tag: fill-mask +--- + +# 🧬 AbLang2: Transformer-based Antibody Language Model + +This repository provides HuggingFace-compatible 🤗 implementation of the AbLang2 language model for antibodies. The original AbLang2 model was developed by the [Oxford Protein Informatics Group (OPIG)](https://opig.stats.ox.ac.uk/) and is available at: +- **AbLang2**: [https://github.com/TobiasHeOl/AbLang2](https://github.com/TobiasHeOl/AbLang2) + +## 🎯 Model Available + +- **ablang2**: AbLang2 model for paired antibody sequences + +## 📦 Installation + +Install the required dependencies: + +```bash +pip install transformers torch numpy pandas anarci +``` + +## 🚀 Loading Models + +```python +from transformers import AutoModel, AutoTokenizer +from adapter import AbLang2PairedHuggingFaceAdapter + +# AbLang2 +model = AutoModel.from_pretrained("hemantn/ablang2", trust_remote_code=True) +tokenizer = AutoTokenizer.from_pretrained("hemantn/ablang2", trust_remote_code=True) +ablang = AbLang2PairedHuggingFaceAdapter(model=model, tokenizer=tokenizer) +``` + +**Note**: Models automatically use GPU when available, otherwise fall back to CPU. 
+ +## ⚙️ Available Utilities + +- **seqcoding**: Sequence-level representations (averaged across residues) +- **rescoding**: Residue-level representations (per-residue embeddings) +- **likelihood**: Raw logits for amino acid prediction at each position +- **probability**: Normalized probabilities for amino acid prediction +- **pseudo_log_likelihood**: Uncertainty scoring with stepwise masking (masks each residue) +- **confidence**: Fast uncertainty scoring (single forward pass, no masking) +- **restore**: Restore masked residues (*) with predicted amino acids + +## 💡 Examples + +### 🔗 AbLang2 (Paired Sequences) +```python +from transformers import AutoModel, AutoTokenizer +from adapter import AbLang2PairedHuggingFaceAdapter + +# Load model +model = AutoModel.from_pretrained("your-username/ablang2", trust_remote_code=True) +tokenizer = AutoTokenizer.from_pretrained("your-username/ablang2", trust_remote_code=True) +ablang = AbLang2PairedHuggingFaceAdapter(model=model, tokenizer=tokenizer) + +# Restore masked paired sequences +masked_seqs = [ + ['EVQ***SGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCAR**PGHGAAFMDVWGTGTTVTVSS', + 'DIQLTQSPLSLPVTLGQPASISCRSS*SLEASDTNIYLSWFQQRPGQSPRRLIYKI*NRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK'] +] +restored = ablang(masked_seqs, mode='restore') +``` + +## 📚 Detailed Usage + +For comprehensive examples and detailed usage instructions, see: +- [`test_ablang2_HF_implementation.ipynb`](test_ablang2_HF_implementation.ipynb) + +This notebook demonstrates all utilities with real examples, including alignment features and advanced usage patterns. + +## 📖 Citation + +If you use these models in your research, please cite the original AbLang2 paper: + +**AbLang2:** +``` +@article{Olsen2024, + title={Addressing the antibody germline bias and its effect on language models for improved antibody design}, + author={Tobias H. Olsen, Iain H. Moal and Charlotte M. 
Deane}, + journal={bioRxiv}, + doi={https://doi.org/10.1101/2024.02.02.578678}, + year={2024} +} +``` \ No newline at end of file diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1d7906a84d4709df91562c98c7a19e5fc3445169 --- /dev/null +++ b/__init__.py @@ -0,0 +1,6 @@ +from .configuration_ablang2paired import AbLang2PairedConfig +from .modeling_ablang2paired import AbLang2PairedHFModel +from .tokenizer_ablang2paired import AbLang2PairedTokenizer +from ablang2 import pretrained + +__all__ = ['AbLang2PairedConfig', 'AbLang2PairedHFModel', 'AbLang2PairedTokenizer'] diff --git a/ablang2/__init__.py b/ablang2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d7bb62b7fbcd07536d0f8bf33c2ae22333304a51 --- /dev/null +++ b/ablang2/__init__.py @@ -0,0 +1 @@ +from .pretrained import pretrained \ No newline at end of file diff --git a/ablang2/__pycache__/__init__.cpython-310.pyc b/ablang2/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32835a47dedbf5a9b24f32425bf2884fbb63d41d Binary files /dev/null and b/ablang2/__pycache__/__init__.cpython-310.pyc differ diff --git a/ablang2/__pycache__/adapter.cpython-310.pyc b/ablang2/__pycache__/adapter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90c8459af24406c7856cefbcebceeeb50640bc2e Binary files /dev/null and b/ablang2/__pycache__/adapter.cpython-310.pyc differ diff --git a/ablang2/__pycache__/configuration_ablang2paired.cpython-310.pyc b/ablang2/__pycache__/configuration_ablang2paired.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6eb4f11fa8e055b4c1cec8c0325296382531ebec Binary files /dev/null and b/ablang2/__pycache__/configuration_ablang2paired.cpython-310.pyc differ diff --git a/ablang2/__pycache__/load_model.cpython-310.pyc b/ablang2/__pycache__/load_model.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..2334a6bf2f07e81b60e12a86ed9ec8a223392893 Binary files /dev/null and b/ablang2/__pycache__/load_model.cpython-310.pyc differ diff --git a/ablang2/__pycache__/pretrained.cpython-310.pyc b/ablang2/__pycache__/pretrained.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ab558a3a472130dd3557c10922457569c956d40 Binary files /dev/null and b/ablang2/__pycache__/pretrained.cpython-310.pyc differ diff --git a/ablang2/adapter.py b/ablang2/adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..fd8fe459cdd6a90c00880b9b84dfc0064b4aa466 --- /dev/null +++ b/ablang2/adapter.py @@ -0,0 +1,306 @@ +from ablang2.pretrained_utils.restoration import AbRestore +from ablang2.pretrained_utils.encodings import AbEncoding +from ablang2.pretrained_utils.alignment import AbAlignment +from ablang2.pretrained_utils.scores import AbScores +import torch +import numpy as np +from ablang2.pretrained_utils.extra_utils import res_to_seq, res_to_list + +class HuggingFaceTokenizerAdapter: + def __init__(self, tokenizer, device): + self.tokenizer = tokenizer + self.device = device + self.pad_token_id = tokenizer.pad_token_id + self.mask_token_id = getattr(tokenizer, 'mask_token_id', None) or tokenizer.convert_tokens_to_ids(tokenizer.mask_token) + self.vocab = tokenizer.get_vocab() if hasattr(tokenizer, 'get_vocab') else tokenizer.vocab + self.inv_vocab = {v: k for k, v in self.vocab.items()} + self.all_special_tokens = tokenizer.all_special_tokens + + def __call__(self, seqs, pad=True, w_extra_tkns=False, device=None, mode=None): + tokens = self.tokenizer(seqs, padding=True, return_tensors='pt') + input_ids = tokens['input_ids'].to(self.device if device is None else device) + if mode == 'decode': + # seqs is a tensor of token ids + if isinstance(seqs, torch.Tensor): + seqs = seqs.cpu().numpy() + decoded = [] + for i, seq in enumerate(seqs): + chars = [self.inv_vocab.get(int(t), '') for t in seq if 
self.inv_vocab.get(int(t), '') not in {'-', '*', '<', '>'} and self.inv_vocab.get(int(t), '') != ''] + # Use res_to_seq for formatting, pass (sequence, length) tuple as in original code + # The length is not always available, so use len(chars) as fallback + formatted = res_to_seq([ ''.join(chars), len(chars) ], mode='restore') + decoded.append(formatted) + return decoded + return input_ids + +class HFAbRestore(AbRestore): + def __init__(self, hf_model, hf_tokenizer, spread=11, device='cpu', ncpu=1): + super().__init__(spread=spread, device=device, ncpu=ncpu) + self.used_device = device + self._hf_model = hf_model + self.tokenizer = HuggingFaceTokenizerAdapter(hf_tokenizer, device) + + @property + def AbLang(self): + def model_call(x): + output = self._hf_model(x) + if hasattr(output, 'last_hidden_state'): + return output.last_hidden_state + return output + return model_call + +def add_angle_brackets(seq): + # Assumes input is 'VH|VL' or 'VH|' or '|VL' + if '|' in seq: + vh, vl = seq.split('|', 1) + else: + vh, vl = seq, '' + return f"<{vh}>|<{vl}>" + +class AbLang2PairedHuggingFaceAdapter(AbEncoding, AbRestore, AbAlignment, AbScores): + """ + Adapter to use pretrained utilities with a HuggingFace-loaded ablang2_paired model and tokenizer. + Automatically uses CUDA if available, otherwise CPU. 
+ """ + def __init__(self, model, tokenizer, device=None, ncpu=1): + super().__init__() + if device is None: + self.used_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + else: + self.used_device = torch.device(device) + self.AbLang = model # HuggingFace model instance + self.tokenizer = tokenizer + self.AbLang.to(self.used_device) + self.AbLang.eval() + # Always get AbRep from the underlying model + if hasattr(self.AbLang, 'model') and hasattr(self.AbLang.model, 'AbRep'): + self.AbRep = self.AbLang.model.AbRep + else: + raise AttributeError("Could not find AbRep in the HuggingFace model or its underlying model.") + self.ncpu = ncpu + self.spread = 11 # For compatibility with original utilities + # The following is no longer needed since all_special_tokens now returns IDs directly + # self.tokenizer.all_special_token_ids = [ + # self.tokenizer.convert_tokens_to_ids(tok) for tok in self.tokenizer.all_special_tokens + # ] + # self.tokenizer._all_special_tokens_str = self.tokenizer.all_special_tokens + # self.tokenizer.all_special_tokens = [ + # self.tokenizer.convert_tokens_to_ids(tok) for tok in self.tokenizer._all_special_tokens_str + # ] + + def freeze(self): + self.AbLang.eval() + + def unfreeze(self): + self.AbLang.train() + + def _encode_sequences(self, seqs): + # Use HuggingFace-style padding and return PyTorch tensors + tokens = self.tokenizer(seqs, padding=True, return_tensors='pt') + tokens = extract_input_ids(tokens, self.used_device) + return self.AbRep(tokens).last_hidden_states.detach() + + def _predict_logits(self, seqs): + tokens = self.tokenizer(seqs, padding=True, return_tensors='pt') + tokens = extract_input_ids(tokens, self.used_device) + output = self.AbLang(tokens) + if hasattr(output, 'last_hidden_state'): + return output.last_hidden_state.detach() + return output.detach() + + def _preprocess_labels(self, labels): + labels = extract_input_ids(labels, self.used_device) + return labels + + def __call__(self, seqs, 
mode='seqcoding', align=False, stepwise_masking=False, fragmented=False, batch_size=50): + """ + Use different modes for different usecases, mimicking the original pretrained class. + """ + from ablang2.pretrained import format_seq_input + + valid_modes = [ + 'rescoding', 'seqcoding', 'restore', 'likelihood', 'probability', + 'pseudo_log_likelihood', 'confidence' + ] + if mode not in valid_modes: + raise SyntaxError(f"Given mode doesn't exist. Please select one of the following: {valid_modes}.") + + seqs, chain = format_seq_input(seqs, fragmented=fragmented) + + if align: + numbered_seqs, seqs, number_alignment = self.number_sequences( + seqs, chain=chain, fragmented=fragmented + ) + else: + numbered_seqs = None + number_alignment = None + + subset_list = [] + for subset in [seqs[x:x+batch_size] for x in range(0, len(seqs), batch_size)]: + subset_list.append(getattr(self, mode)(subset, align=align, stepwise_masking=stepwise_masking)) + + return self.reformat_subsets( + subset_list, + mode=mode, + align=align, + numbered_seqs=numbered_seqs, + seqs=seqs, + number_alignment=number_alignment, + ) + + def pseudo_log_likelihood(self, seqs, **kwargs): + """ + Original (non-vectorized) pseudo log-likelihood computation matching notebook behavior. 
+ """ + # Format input: join VH and VL with '|' + formatted_seqs = [] + for s in seqs: + if isinstance(s, (list, tuple)): + formatted_seqs.append('|'.join(s)) + else: + formatted_seqs.append(s) + + # Tokenize all sequences in batch + labels = self.tokenizer( + formatted_seqs, padding=True, return_tensors='pt' + ) + labels = extract_input_ids(labels, self.used_device) + + # Convert special tokens to IDs + if isinstance(self.tokenizer.all_special_tokens[0], int): + special_token_ids = set(self.tokenizer.all_special_tokens) + else: + special_token_ids = set(self.tokenizer.convert_tokens_to_ids(tok) for tok in self.tokenizer.all_special_tokens) + pad_token_id = self.tokenizer.pad_token_id + + mask_token_id = getattr(self.tokenizer, 'mask_token_id', None) + if mask_token_id is None: + mask_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token) + + plls = [] + with torch.no_grad(): + for i, seq_label in enumerate(labels): + seq_pll = [] + for j, token_id in enumerate(seq_label): + if token_id.item() in special_token_ids or token_id.item() == pad_token_id: + continue + masked = seq_label.clone() + masked[j] = mask_token_id + logits = self.AbLang(masked.unsqueeze(0)) + if hasattr(logits, 'last_hidden_state'): + logits = logits.last_hidden_state + logits = logits[0, j] + nll = torch.nn.functional.cross_entropy( + logits.unsqueeze(0), token_id.unsqueeze(0), reduction="none" + ) + seq_pll.append(-nll.item()) + if seq_pll: + plls.append(np.mean(seq_pll)) + else: + plls.append(float('nan')) + return np.array(plls) + + def confidence(self, seqs, **kwargs): + """Confidence calculation - match original ablang2 implementation by excluding all special tokens from loss.""" + # Format input: join VH and VL with '|' + formatted_seqs = [] + for s in seqs: + if isinstance(s, (list, tuple)): + formatted_seqs.append('|'.join(s)) + else: + formatted_seqs.append(s) + + plls = [] + for seq in formatted_seqs: + tokens = self.tokenizer([seq], padding=True, 
return_tensors='pt') + input_ids = extract_input_ids(tokens, self.used_device) + + with torch.no_grad(): + output = self.AbLang(input_ids) + if hasattr(output, 'last_hidden_state'): + logits = output.last_hidden_state + else: + logits = output + + # Get the sequence (remove batch dimension) + logits = logits[0] # [seq_len, vocab_size] + input_ids = input_ids[0] # [seq_len] + + # Exclude all special tokens (pad, mask, etc.) + if isinstance(self.tokenizer.all_special_tokens[0], int): + special_token_ids = set(self.tokenizer.all_special_tokens) + else: + special_token_ids = set(self.tokenizer.convert_tokens_to_ids(tok) for tok in self.tokenizer.all_special_tokens) + valid_mask = ~torch.isin(input_ids, torch.tensor(list(special_token_ids), device=input_ids.device)) + + if valid_mask.sum() > 0: + valid_logits = logits[valid_mask] + valid_labels = input_ids[valid_mask] + + # Calculate cross-entropy loss + nll = torch.nn.functional.cross_entropy( + valid_logits, + valid_labels, + reduction="mean" + ) + pll = -nll.item() + else: + pll = 0.0 + + plls.append(pll) + + return np.array(plls, dtype=np.float32) + + def probability(self, seqs, align=False, stepwise_masking=False, **kwargs): + """ + Probability of mutations - applies softmax to logits to get probabilities + """ + # Format input: join VH and VL with '|' + formatted_seqs = [] + for s in seqs: + if isinstance(s, (list, tuple)): + formatted_seqs.append('|'.join(s)) + else: + formatted_seqs.append(s) + + # Get logits + if stepwise_masking: + # For stepwise masking, we need to implement it similar to likelihood + # This is a simplified version - you might want to implement full stepwise masking + logits = self._predict_logits(formatted_seqs) + else: + logits = self._predict_logits(formatted_seqs) + + # Apply softmax to get probabilities + probs = logits.softmax(-1).cpu().numpy() + + if align: + return probs + else: + # Return residue-level probabilities (excluding special tokens) + return [res_to_list(state, seq) for 
state, seq in zip(probs, formatted_seqs)] + + def restore(self, seqs, align=False, **kwargs): + hf_abrestore = HFAbRestore(self.AbLang, self.tokenizer, spread=self.spread, device=self.used_device, ncpu=self.ncpu) + restored = hf_abrestore.restore(seqs, align=align) + # Apply angle brackets formatting + if isinstance(restored, np.ndarray): + restored = np.array([add_angle_brackets(seq) for seq in restored]) + else: + restored = [add_angle_brackets(seq) for seq in restored] + return restored + +def extract_input_ids(tokens, device): + if hasattr(tokens, 'input_ids'): + return tokens.input_ids.to(device) + elif isinstance(tokens, dict): + if 'input_ids' in tokens: + return tokens['input_ids'].to(device) + else: + for v in tokens.values(): + if hasattr(v, 'ndim') or torch.is_tensor(v): + return v.to(device) + elif torch.is_tensor(tokens): + return tokens.to(device) + else: + raise ValueError("Could not extract input_ids from tokenizer output") \ No newline at end of file diff --git a/ablang2/alignment.py b/ablang2/alignment.py new file mode 100644 index 0000000000000000000000000000000000000000..0d14b9d750509596eea8f2849f383f4434375044 --- /dev/null +++ b/ablang2/alignment.py @@ -0,0 +1,87 @@ +from dataclasses import dataclass +import numpy as np +import torch + +from .extra_utils import paired_msa_numbering, unpaired_msa_numbering, create_alignment + + +class AbAlignment: + + def __init__(self, device = 'cpu', ncpu = 1): + + self.device = device + self.ncpu = ncpu + + def number_sequences(self, seqs, chain = 'H', fragmented = False): + if chain == 'HL': + numbered_seqs, seqs, number_alignment = paired_msa_numbering(seqs, fragmented = fragmented, n_jobs = self.ncpu) + else: + assert chain == 'HL', 'Currently "Align==True" only works for paired sequences. \nPlease use paired sequences or Align=False.' 
+ numbered_seqs, seqs, number_alignment = unpaired_msa_numbering( + seqs, chain = chain, fragmented = fragmented, n_jobs = self.ncpu + ) + + return numbered_seqs, seqs, number_alignment + + def align_encodings(self, encodings, numbered_seqs, seqs, number_alignment): + + aligned_encodings = np.concatenate( + [[ + create_alignment( + res_embed, numbered_seq, seq, number_alignment + ) for res_embed, numbered_seq, seq in zip(encodings, numbered_seqs, seqs) + ]], axis=0 + ) + return aligned_encodings + + + def reformat_subsets( + self, + subset_list, + mode = 'seqcoding', + align = False, + numbered_seqs = None, + seqs = None, + number_alignment = None, + ): + + if mode in [ + 'seqcoding', + 'restore', + 'pseudo_log_likelihood', + 'confidence' + ]: + return np.concatenate(subset_list) + elif align: + subset_list = [ + self.align_encodings( + subset, + numbered_seqs[num*len(subset):(num+1)*len(subset)], + seqs[num*len(subset):(num+1)*len(subset)], + number_alignment + ) for num, subset in enumerate(subset_list) + ] + + subset = np.concatenate(subset_list) + + return aligned_results( + aligned_seqs = [''.join(alist) for alist in subset[:,:,-1]], + aligned_embeds = subset[:,:,:-1].astype(float), + number_alignment=number_alignment.apply(lambda x: '{}{}'.format(*x[0]), axis=1).values + ) + + elif not align: + return sum(subset_list, []) + else: + return np.concatenate(subset_list) # this needs to be changed + + +@dataclass +class aligned_results(): + """ + Dataclass used to store output. 
+ """ + + aligned_seqs: None + aligned_embeds: None + number_alignment: None \ No newline at end of file diff --git a/ablang2/config.json b/ablang2/config.json new file mode 100644 index 0000000000000000000000000000000000000000..53db5301093e416c4c3781509e63272d840b9615 --- /dev/null +++ b/ablang2/config.json @@ -0,0 +1,18 @@ +{ + "model_type": "ablang2-paired", + "vocab_size": 26, + "hidden_embed_size": 480, + "n_attn_heads": 20, + "n_encoder_blocks": 12, + "padding_tkn": 21, + "mask_tkn": 23, + "layer_norm_eps": 1e-12, + "a_fn": "swiglu", + "dropout": 0.0, + "tokenizer_class": "AbLang2PairedTokenizer", + "auto_map": { + "AutoConfig": "configuration_ablang2paired.AbLang2PairedConfig", + "AutoModel": "modeling_ablang2paired.AbLang2PairedHFModel", + "AutoTokenizer": ["tokenizer_ablang2paired.AbLang2PairedTokenizer", "tokenizer_ablang2paired.AbLang2PairedTokenizer"] + } +} diff --git a/ablang2/configuration_ablang2paired.py b/ablang2/configuration_ablang2paired.py new file mode 100644 index 0000000000000000000000000000000000000000..844e53b7a2c748fcb423e2d0fbc3fc15ec7faeb4 --- /dev/null +++ b/ablang2/configuration_ablang2paired.py @@ -0,0 +1,31 @@ +from transformers import PretrainedConfig + +class AbLang2PairedConfig(PretrainedConfig): + model_type = "ablang2-paired" + + def __init__( + self, + vocab_size=26, + hidden_embed_size=480, + n_attn_heads=20, + n_encoder_blocks=12, + padding_tkn=21, + mask_tkn=23, + layer_norm_eps=1e-12, + a_fn="swiglu", + dropout=0.0, + **kwargs + ): + super().__init__(**kwargs) + self.vocab_size = vocab_size + self.hidden_embed_size = hidden_embed_size + self.hidden_size = hidden_embed_size # Add this for Hugging Face compatibility + self.n_attn_heads = n_attn_heads + self.num_attention_heads = n_attn_heads # Add this for Hugging Face compatibility + self.num_hidden_layers = n_encoder_blocks # Add this for Hugging Face compatibility + self.n_encoder_blocks = n_encoder_blocks + self.padding_tkn = padding_tkn + self.mask_tkn = mask_tkn + 
self.layer_norm_eps = layer_norm_eps + self.a_fn = a_fn + self.dropout = dropout \ No newline at end of file diff --git a/ablang2/encodings.py b/ablang2/encodings.py new file mode 100644 index 0000000000000000000000000000000000000000..1946c5df116bb26915c4b552d9e9eb5d38033e04 --- /dev/null +++ b/ablang2/encodings.py @@ -0,0 +1,97 @@ +import numpy as np +import torch + +from .extra_utils import res_to_list, res_to_seq + + +class AbEncoding: + + def __init__(self, device = 'cpu', ncpu = 1): + + self.device = device + self.ncpu = ncpu + + def _initiate_abencoding(self, model, tokenizer): + self.AbLang = model + self.tokenizer = tokenizer + + def _encode_sequences(self, seqs): + tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device) + with torch.no_grad(): + return self.AbLang.AbRep(tokens).last_hidden_states + + def _predict_logits(self, seqs): + tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device) + with torch.no_grad(): + return self.AbLang(tokens) + + def _predict_logits_with_step_masking(self, seqs): + + tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device) + + logits = [] + for single_seq_tokens in tokens: + + tkn_len = len(single_seq_tokens) + masked_tokens = single_seq_tokens.repeat(tkn_len, 1) + for num in range(tkn_len): + masked_tokens[num, num] = self.tokenizer.mask_token + + with torch.no_grad(): + logits_tmp = self.AbLang(masked_tokens) + + logits_tmp = torch.stack([logits_tmp[num, num] for num in range(tkn_len)]) + + logits.append(logits_tmp) + + return torch.stack(logits, dim=0) + + def seqcoding(self, seqs, **kwargs): + """ + Sequence specific representations + """ + + encodings = self._encode_sequences(seqs).cpu().numpy() + + lens = np.vectorize(len)(seqs) + lens = np.tile(lens.reshape(-1,1,1), (encodings.shape[2], 1)) + + return np.apply_along_axis(res_to_seq, 2, np.c_[np.swapaxes(encodings,1,2), lens]) + + def rescoding(self, seqs, align=False, **kwargs): + 
""" + Residue specific representations. + """ + encodings = self._encode_sequences(seqs).cpu().numpy() + + if align: return encodings + + else: return [res_to_list(state, seq) for state, seq in zip(encodings, seqs)] + + def likelihood(self, seqs, align=False, stepwise_masking=False, **kwargs): + """ + Likelihood of mutations + """ + if stepwise_masking: + logits = self._predict_logits_with_step_masking(seqs).cpu().numpy() + else: + logits = self._predict_logits(seqs).cpu().numpy() + + if align: return logits + + else: return [res_to_list(state, seq) for state, seq in zip(logits, seqs)] + + def probability(self, seqs, align=False, stepwise_masking=False, **kwargs): + """ + Probability of mutations + """ + if stepwise_masking: + logits = self._predict_logits_with_step_masking(seqs) + else: + logits = self._predict_logits(seqs) + probs = logits.softmax(-1).cpu().numpy() + + if align: return probs + + else: return [res_to_list(state, seq) for state, seq in zip(probs, seqs)] + \ No newline at end of file diff --git a/ablang2/environment.yaml b/ablang2/environment.yaml new file mode 100644 index 0000000000000000000000000000000000000000..47b3456756d77a882f16e17885d97510929f36cf --- /dev/null +++ b/ablang2/environment.yaml @@ -0,0 +1,44 @@ +name: AbLang +channels: + - conda-forge + - pytorch + - bioconda + - defaults +dependencies: + - python=3.10.18 + - pip + - pytorch=2.5.1 + - pytorch-cuda=12.4 + - numpy=2.2.6 + - pandas=2.3.1 + - transformers=4.53.3 + - anarci=2024.05.21 + - jupyter=7.4.4 + - notebook=7.4.4 + - ipython=8.37.0 + - ipykernel=6.29.5 + - matplotlib-inline=0.1.7 + - scikit-learn + - matplotlib + - seaborn + - biopython=1.85 + - huggingface_hub=0.33.4 + - tokenizers=0.21.3 + - safetensors=0.5.3 + - einops=0.8.1 + - tqdm=4.67.1 + - requests=2.32.4 + - urllib3=2.5.0 + - certifi=2025.7.14 + - filelock=3.18.0 + - fsspec=2025.3.0 + - packaging=25.0 + - regex=2024.11.6 + - sympy=1.13.3 + - networkx=3.4.2 + - jinja2=3.1.6 + - pyyaml=6.0.2 + - 
typing_extensions=4.14.1 + - pip: + - numba=0.61.2 + - llvmlite=0.44.0 \ No newline at end of file diff --git a/ablang2/extra_utils.py b/ablang2/extra_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fca5c5caa944968d673164a1c29c28172b1aa888 --- /dev/null +++ b/ablang2/extra_utils.py @@ -0,0 +1,165 @@ +import string, re +import numpy as np + + +def res_to_list(logits, seq): + return logits[:len(seq)] + +def res_to_seq(a, mode='mean'): + """ + Function for how we go from n_values for each amino acid to n_values for each sequence. + + We leave out padding tokens. + """ + + if mode=='sum': + return a[0:(int(a[-1]))].sum() + + elif mode=='mean': + return a[0:(int(a[-1]))].mean() + + elif mode=='restore': + return a[0][0:(int(a[-1]))] + +def get_number_alignment(numbered_seqs): + """ + Creates a number alignment from the anarci results. + """ + import pandas as pd + + alist = [pd.DataFrame(aligned_seq, columns = [0,1,'resi']) for aligned_seq in numbered_seqs] + unsorted_alignment = pd.concat(alist).drop_duplicates(subset=0) + max_alignment = get_max_alignment() + + return max_alignment.merge(unsorted_alignment.query("resi!='-'"), left_on=0, right_on=0)[[0,1]] + +def get_max_alignment(): + """ + Create maximum possible alignment for sorting + """ + import pandas as pd + + sortlist = [[("<", "")]] + for num in range(1, 128+1): + if num in [33,61,112]: + for char in string.ascii_uppercase[::-1]: + sortlist.append([(num, char)]) + + sortlist.append([(num,' ')]) + else: + sortlist.append([(num,' ')]) + for char in string.ascii_uppercase: + sortlist.append([(num, char)]) + + return pd.DataFrame(sortlist + [[(">", "")]]) + + +def paired_msa_numbering(ab_seqs, fragmented = False, n_jobs = 10): + + import pandas as pd + + tmp_seqs = [pairs.replace(">", "").replace("<", "").split("|") for pairs in ab_seqs] + + numbered_seqs_heavy, seqs_heavy, number_alignment_heavy = unpaired_msa_numbering( + [i[0] for i in tmp_seqs], 'H', fragmented = fragmented, 
n_jobs = n_jobs + ) + numbered_seqs_light, seqs_light, number_alignment_light = unpaired_msa_numbering( + [i[1] for i in tmp_seqs], 'L', fragmented = fragmented, n_jobs = n_jobs + ) + + number_alignment = pd.concat([ + number_alignment_heavy, + pd.DataFrame([[("|",""), "|"]]), + number_alignment_light] + ).reset_index(drop=True) + + seqs = [f"{heavy}|{light}" for heavy, light in zip(seqs_heavy, seqs_light)] + numbered_seqs = [ + heavy + [(("|",""), "|", "|")] + light for heavy, light in zip(numbered_seqs_heavy, numbered_seqs_light) + ] + + return numbered_seqs, seqs, number_alignment + + +def unpaired_msa_numbering(seqs, chain = 'H', fragmented = False, n_jobs = 10): + + numbered_seqs = number_with_anarci(seqs, chain = chain, fragmented = fragmented, n_jobs = n_jobs) + number_alignment = get_number_alignment(numbered_seqs) + number_alignment[1] = chain + + seqs = [''.join([i[2] for i in numbered_seq]).replace('-','') for numbered_seq in numbered_seqs] + return numbered_seqs, seqs, number_alignment + + +def number_with_anarci(seqs, chain = 'H', fragmented = False, n_jobs = 1): + + import anarci + import pandas as pd + + anarci_out = anarci.run_anarci( + pd.DataFrame(seqs).reset_index().values.tolist(), + ncpu=n_jobs, + scheme='imgt', + allowed_species=['human', 'mouse'], + ) + + numbered_seqs = [] + for onarci in anarci_out[1]: + numbered_seq = [] + for i in onarci[0][0]: + if i[1] != '-': + numbered_seq.append((i[0], chain, i[1])) + + if fragmented: + numbered_seqs.append(numbered_seq) + else: + numbered_seqs.append([(("<",""), chain, "<")] + numbered_seq + [((">",""), chain, ">")]) + + return numbered_seqs + + +def create_alignment(res_embeds, numbered_seqs, seq, number_alignment): + + import pandas as pd + + datadf = pd.DataFrame(numbered_seqs) + sequence_alignment = number_alignment.merge(datadf, how='left', on=[0, 1]).fillna('-')[2] + + idxs = np.where(sequence_alignment.values == '-')[0] + idxs = [idx-num for num, idx in enumerate(idxs)] + + aligned_embeds = 
pd.DataFrame(np.insert(res_embeds[:len(seq)], idxs , 0, axis=0)) + + return pd.concat([aligned_embeds, sequence_alignment], axis=1).values + + +def get_spread_sequences(seq, spread, start_position): + """ + Test sequences which are 8 positions shorter (position 10 + max CDR1 gap of 7) up to 2 positions longer (possible insertions). + """ + spread_sequences = [] + + for diff in range(start_position-8, start_position+2+1): + spread_sequences.append('*'*diff+seq) + + return np.array(spread_sequences) + +def get_sequences_from_anarci(out_anarci, max_position, spread): + """ + Ensures correct masking on each side of sequence + """ + + if out_anarci == 'ANARCI_error': + return np.array(['ANARCI-ERR']*spread) + + end_position = int(re.search(r'\d+', out_anarci[::-1]).group()[::-1]) + # Fixes ANARCI error of poor numbering of the CDR1 region + start_position = int(re.search(r'\d+,\s\'.\'\),\s\'[^-]+\'\),\s\(\(\d+,\s\'.\'\),\s\'[^-]+\'\),\s\(\(\d+,\s\'.\'\),\s\'[^-]+\'\),\s\(\(\d+,\s\'.\'\),\s\'[^-]+', + out_anarci).group().split(',')[0]) - 1 + + sequence = "".join(re.findall(r"(?i)[A-Z*]", "".join(re.findall(r'\),\s\'[A-Z*]', out_anarci)))) + + sequence_j = ''.join(sequence).replace('-','').replace('X','*') + '*'*(max_position-int(end_position)) + + return get_spread_sequences(sequence_j, spread, start_position) + diff --git a/ablang2/hparams.json b/ablang2/hparams.json new file mode 100755 index 0000000000000000000000000000000000000000..65a58d738f1183520c5a26f1472fcc556524948b --- /dev/null +++ b/ablang2/hparams.json @@ -0,0 +1 @@ +{"name": "AbLang-2", "n_encoder_blocks": 12, "hidden_embed_size": 480, "n_attn_heads": 20, "a_fn": "swiglu", "layer_norm_eps": 1e-12, "pad_tkn": 21, "start_tkn": 0, "end_tkn": 22, "sep_tkn": 25, "mask_tkn": 23, "vocab_size": 26} \ No newline at end of file diff --git a/ablang2/load_model.py b/ablang2/load_model.py new file mode 100644 index 0000000000000000000000000000000000000000..2604c5e3f1d69693f8cce9f044306e728f4f7460 --- /dev/null +++ 
import os, subprocess, json, argparse, requests
import torch

# --- ablang2/load_model.py ---------------------------------------------------

# model name -> [archive URL, weight-file name inside the archive]
list_of_models = {
    "ablang1-heavy":["https://opig.stats.ox.ac.uk/data/downloads/ablang-heavy.tar.gz", "amodel.pt"],
    "ablang1-light":["https://opig.stats.ox.ac.uk/data/downloads/ablang-light.tar.gz", "amodel.pt"],
    "ablang2-paired":["https://zenodo.org/records/10185169/files/ablang2-weights.tar.gz", "model.pt"],
    "tcrlang-paired":["https://zenodo.org/records/11208211/files/tcrlang-weights.tar.gz", "model.pt"],
}
ablang1_models = ["ablang1-heavy", "ablang1-light"]
ablang2_models = ["ablang2-paired", "tcrlang-paired"]


def load_model(model_to_use = "ablang2-paired", random_init = False, device = 'cpu'):
    """
    Resolve `model_to_use` to a (model, tokenizer, hparams) triple.

    Args:
        model_to_use: a released model name, or a local checkpoint folder
            whose name contains "ABLANG-" (treated as an ablang2 checkpoint).
        random_init: skip loading pretrained weights.
        device: map_location for the state dict.

    Raises:
        ValueError: if `model_to_use` is not a known model.
    """
    if model_to_use in ablang1_models:
        AbLang, tokenizer, hparams = fetch_ablang1(
            model_to_use,
            random_init=random_init,
            device=device
        )
    # Released ablang2/tcrlang names and local "ABLANG-" folders both load
    # through fetch_ablang2 (the two original branches were identical).
    elif model_to_use in ablang2_models or "ABLANG-" in model_to_use:
        AbLang, tokenizer, hparams = fetch_ablang2(
            model_to_use,
            random_init=random_init,
            device=device
        )
    else:
        # raise (not assert) so the check survives `python -O`.
        raise ValueError(
            f"The selected model to use ({model_to_use}) does not exist."
            " Please select a valid model."
        )

    return AbLang, tokenizer, hparams


def download_model(model_to_use = "ablang2-paired"):
    """
    If not already downloaded, download model inside environment.

    Returns the local folder holding the unpacked weights.
    """
    local_model_folder = os.path.join(os.path.dirname(__file__), "model-weights-{}".format(model_to_use))
    os.makedirs(local_model_folder, exist_ok = True)

    file_w_weights, file_model = list_of_models[model_to_use]

    if not os.path.isfile(os.path.join(local_model_folder, file_model)):
        print("Downloading model ...")
        tmp_file = os.path.join(local_model_folder, "tmp.tar.gz")

        # Stream the archive and fail loudly on HTTP errors; previously an
        # error page was silently written to disk and fed to tar, producing
        # a confusing "corrupt archive" failure.
        response = requests.get(file_w_weights, stream=True, timeout=60)
        response.raise_for_status()
        with open(tmp_file, 'wb') as f:
            for chunk in response.iter_content(chunk_size=1 << 20):
                f.write(chunk)

        try:
            subprocess.run(["tar", "-zxvf", tmp_file, "-C", local_model_folder], check = True)
        finally:
            # Remove the temporary archive even if extraction fails.
            os.remove(tmp_file)

    return local_model_folder


def fetch_ablang1(model_to_use, random_init=False, device='cpu'):
    """Download (if needed) and build an AbLang-1 model, tokenizer and hparams."""
    from .models.ablang1 import model as ablang_1_model
    from .models.ablang1 import tokenizers as ablang_1_tokenizer

    local_model_folder = download_model(model_to_use)

    with open(os.path.join(local_model_folder, 'hparams.json'), 'r', encoding='utf-8') as f:
        hparams = argparse.Namespace(**json.load(f))

    AbLang = ablang_1_model.AbLang(hparams)
    if not random_init:
        AbLang.load_state_dict(
            torch.load(
                os.path.join(local_model_folder, 'amodel.pt'),
                map_location=torch.device(device)
            )
        )
    tokenizer = ablang_1_tokenizer.ABtokenizer(os.path.join(local_model_folder, 'vocab.json'))

    return AbLang, tokenizer, hparams


def fetch_ablang2(model_to_use, random_init=False, device='cpu'):
    """
    Download (if needed) and build an AbLang-2/TCRLang model.

    `model_to_use` may also be a local folder containing hparams.json and
    model.pt, in which case nothing is downloaded.
    """
    from .models.ablang2 import ablang
    from .models.ablang2 import tokenizers

    if model_to_use in ablang2_models:
        local_model_folder = download_model(model_to_use)
    else:
        local_model_folder = model_to_use

    with open(os.path.join(local_model_folder, 'hparams.json'), 'r', encoding='utf-8') as f:
        hparams = argparse.Namespace(**json.load(f))

    AbLang = ablang.AbLang(
        vocab_size = hparams.vocab_size,
        hidden_embed_size = hparams.hidden_embed_size,
        n_attn_heads = hparams.n_attn_heads,
        n_encoder_blocks = hparams.n_encoder_blocks,
        padding_tkn = hparams.pad_tkn,
        mask_tkn = hparams.mask_tkn,
        layer_norm_eps = hparams.layer_norm_eps,
        a_fn = hparams.a_fn,
    )

    if not random_init:
        AbLang.load_state_dict(
            torch.load(
                os.path.join(local_model_folder, 'model.pt'),
                map_location=torch.device(device)
            )
        )
    tokenizer = tokenizers.ABtokenizer()

    return AbLang, tokenizer, hparams


# --- ablang2/modeling_ablang2paired.py ---------------------------------------

from torch import nn
from transformers import PreTrainedModel
from ablang2.models.ablang2.ablang import AbLang as AbLang2
# NOTE(review): the config is imported from package "ablang2_paired" while
# this file lives under "ablang2/" — confirm the intended repo layout.
from ablang2_paired.configuration_ablang2paired import AbLang2PairedConfig


class ModelOutput:
    """Minimal output container exposing `last_hidden_state` (HF-style).

    Hoisted to module level: the original re-declared this class on every
    forward() call.
    """

    def __init__(self, last_hidden_state):
        self.last_hidden_state = last_hidden_state


class AbLang2PairedHFModel(PreTrainedModel):
    """HuggingFace PreTrainedModel wrapper around the original AbLang2 net."""

    config_class = AbLang2PairedConfig
    model_type = "ablang2-paired"

    def __init__(self, config: AbLang2PairedConfig):
        super().__init__(config)
        self.model = AbLang2(
            vocab_size=config.vocab_size,
            hidden_embed_size=config.hidden_embed_size,
            n_attn_heads=config.n_attn_heads,
            n_encoder_blocks=config.n_encoder_blocks,
            padding_tkn=config.padding_tkn,
            mask_tkn=config.mask_tkn,
            layer_norm_eps=config.layer_norm_eps,
            a_fn=config.a_fn,
            dropout=config.dropout,
        )

    def forward(self, input_ids=None, x=None, attention_mask=None, **kwargs):
        """Run the wrapped model.

        Accepts both the HuggingFace argument name (`input_ids`) and the
        original AbLang2 name (`x`); `input_ids` wins when both are given.
        """
        if input_ids is not None:
            x = input_ids
        elif x is None:
            raise ValueError("Either input_ids or x must be provided")

        output = self.model(x, attention_mask)
        return ModelOutput(output)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load from a local folder containing a raw `model.pt` state dict,
        falling back to standard HF loading otherwise.

        NOTE(review): the custom path only triggers for local directories —
        `os.path.exists` is False for hub model ids, which fall through to
        the HF loader.
        """
        model_path = pretrained_model_name_or_path
        custom_weights_path = os.path.join(model_path, "model.pt")

        if os.path.exists(custom_weights_path):
            config = kwargs.get("config")
            if config is None:
                from transformers import AutoConfig
                config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

            # Build from config only, then splice in the raw state dict.
            model = cls(config)
            # weights_only=True avoids arbitrary-code pickle loading.
            state_dict = torch.load(custom_weights_path, map_location="cpu", weights_only=True)
            model.model.load_state_dict(state_dict)

            # GPU when available, otherwise CPU (unless overridden via kwargs).
            device = kwargs.get("device", None)
            if device is None:
                device = "cuda" if torch.cuda.is_available() else "cpu"
            model = model.to(device)

            return model
        else:
            return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

    def save_pretrained(self, save_directory, **kwargs):
        """Save both the raw `model.pt` state dict and the HF artifacts.

        NOTE(review): super().save_pretrained also serializes the weights in
        HF format, so the checkpoint is stored twice — confirm intended.
        """
        os.makedirs(save_directory, exist_ok=True)
        torch.save(self.model.state_dict(), f"{save_directory}/model.pt")
        self.config.save_pretrained(save_directory)
        super().save_pretrained(save_directory, **kwargs)
b/ablang2/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e71159591e9a568180f6bf3ffc9c7be7c790b5f1 Binary files /dev/null and b/ablang2/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/ablang2/models/__pycache__/__init__.cpython-312.pyc b/ablang2/models/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d17f992753679cc14cb6d72a313b768a61817686 Binary files /dev/null and b/ablang2/models/__pycache__/__init__.cpython-312.pyc differ diff --git a/ablang2/models/ablang1/__init__.py b/ablang2/models/ablang1/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d999beef911268f95f604377c424fbcf63650f04 --- /dev/null +++ b/ablang2/models/ablang1/__init__.py @@ -0,0 +1,3 @@ +from .tokenizers import ABtokenizer +from .model import AbLang, AbRep, AbHead +from .pretrained import pretrained \ No newline at end of file diff --git a/ablang2/models/ablang1/__pycache__/__init__.cpython-310.pyc b/ablang2/models/ablang1/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f46470ccdc117e800e830095e4dd4c59507330f6 Binary files /dev/null and b/ablang2/models/ablang1/__pycache__/__init__.cpython-310.pyc differ diff --git a/ablang2/models/ablang1/__pycache__/__init__.cpython-312.pyc b/ablang2/models/ablang1/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9fef56d832d03975084cd5cec02520ba9e1ad155 Binary files /dev/null and b/ablang2/models/ablang1/__pycache__/__init__.cpython-312.pyc differ diff --git a/ablang2/models/ablang1/__pycache__/embedding.cpython-310.pyc b/ablang2/models/ablang1/__pycache__/embedding.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..337b2db205a7c2ce33381eeb8aa8cc5adde3b8cd Binary files /dev/null and b/ablang2/models/ablang1/__pycache__/embedding.cpython-310.pyc 
differ diff --git a/ablang2/models/ablang1/__pycache__/embedding.cpython-312.pyc b/ablang2/models/ablang1/__pycache__/embedding.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df16307edeafe5b159ad7fe344ed87f0fe90e2ca Binary files /dev/null and b/ablang2/models/ablang1/__pycache__/embedding.cpython-312.pyc differ diff --git a/ablang2/models/ablang1/__pycache__/encoderblocks.cpython-310.pyc b/ablang2/models/ablang1/__pycache__/encoderblocks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ee712a5694065f8c2865ab4024d7f9329d62f04 Binary files /dev/null and b/ablang2/models/ablang1/__pycache__/encoderblocks.cpython-310.pyc differ diff --git a/ablang2/models/ablang1/__pycache__/encoderblocks.cpython-312.pyc b/ablang2/models/ablang1/__pycache__/encoderblocks.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e320feb9085d0837ce1455711363ae2ec424910d Binary files /dev/null and b/ablang2/models/ablang1/__pycache__/encoderblocks.cpython-312.pyc differ diff --git a/ablang2/models/ablang1/__pycache__/extra_fns.cpython-310.pyc b/ablang2/models/ablang1/__pycache__/extra_fns.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ffe9de43dbd45a99d65b955439deb278c38fcb9 Binary files /dev/null and b/ablang2/models/ablang1/__pycache__/extra_fns.cpython-310.pyc differ diff --git a/ablang2/models/ablang1/__pycache__/extra_fns.cpython-312.pyc b/ablang2/models/ablang1/__pycache__/extra_fns.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ac2084247b3de36aebf345cb24a2333a22f1fca Binary files /dev/null and b/ablang2/models/ablang1/__pycache__/extra_fns.cpython-312.pyc differ diff --git a/ablang2/models/ablang1/__pycache__/fairseq_mha.cpython-310.pyc b/ablang2/models/ablang1/__pycache__/fairseq_mha.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..4c25c5e7c7ac2991a0f238941b49227a50be2a3f Binary files /dev/null and b/ablang2/models/ablang1/__pycache__/fairseq_mha.cpython-310.pyc differ diff --git a/ablang2/models/ablang1/__pycache__/fairseq_mha.cpython-312.pyc b/ablang2/models/ablang1/__pycache__/fairseq_mha.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3262cfc2a71d953f758f600a7dad998150a102f Binary files /dev/null and b/ablang2/models/ablang1/__pycache__/fairseq_mha.cpython-312.pyc differ diff --git a/ablang2/models/ablang1/__pycache__/model.cpython-310.pyc b/ablang2/models/ablang1/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14f930ae60c2b33b48366fc4cc023fc6e9f00a2d Binary files /dev/null and b/ablang2/models/ablang1/__pycache__/model.cpython-310.pyc differ diff --git a/ablang2/models/ablang1/__pycache__/model.cpython-312.pyc b/ablang2/models/ablang1/__pycache__/model.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a54b5be87735d4bfb4088fb6643bd94911807f77 Binary files /dev/null and b/ablang2/models/ablang1/__pycache__/model.cpython-312.pyc differ diff --git a/ablang2/models/ablang1/__pycache__/pretrained.cpython-310.pyc b/ablang2/models/ablang1/__pycache__/pretrained.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a2d16ab3486158a481a968ec8a22eb578cf35d5d Binary files /dev/null and b/ablang2/models/ablang1/__pycache__/pretrained.cpython-310.pyc differ diff --git a/ablang2/models/ablang1/__pycache__/pretrained.cpython-312.pyc b/ablang2/models/ablang1/__pycache__/pretrained.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94faacfc9ccf7ef8ba29b02d2e69c8e92f4f466c Binary files /dev/null and b/ablang2/models/ablang1/__pycache__/pretrained.cpython-312.pyc differ diff --git a/ablang2/models/ablang1/__pycache__/tokenizers.cpython-310.pyc 
# This span of the patch adds three modules, reproduced here in order:
# ablang1/embedding.py, ablang1/encoderblocks.py and ablang1/extra_fns.py.

# --- ablang2/models/ablang1/embedding.py -------------------------------------

import torch


class AbEmbeddings(torch.nn.Module):
    """
    Residue embedding and Positional embedding

    Learned amino-acid token embeddings are summed with learned absolute
    position embeddings (RoBERTa-style position ids derived from padding),
    then LayerNorm and dropout are applied.
    """

    def __init__(self, hparams):
        super().__init__()
        self.pad_token_id = hparams.pad_token_id

        # padding_idx pins the pad token's embedding row to zeros.
        self.AAEmbeddings = torch.nn.Embedding(hparams.vocab_size, hparams.hidden_size, padding_idx=self.pad_token_id)
        self.PositionEmbeddings = torch.nn.Embedding(hparams.max_position_embeddings, hparams.hidden_size, padding_idx=0) # here padding_idx is always 0

        self.LayerNorm = torch.nn.LayerNorm(hparams.hidden_size, eps=hparams.layer_norm_eps)
        self.Dropout = torch.nn.Dropout(hparams.hidden_dropout_prob)

    def forward(self, src):
        # src: (batch, seq_len) token ids -> (batch, seq_len, hidden_size)
        inputs_embeds = self.AAEmbeddings(src)

        position_ids = self.create_position_ids_from_input_ids(src, self.pad_token_id)
        position_embeddings = self.PositionEmbeddings(position_ids)

        embeddings = inputs_embeds + position_embeddings

        return self.Dropout(self.LayerNorm(embeddings))

    def create_position_ids_from_input_ids(self, input_ids, padding_idx):
        """
        Replace non-padding symbols with their position numbers.
        Padding idx will get position 0, which will be ignored later on.
        """
        mask = input_ids.ne(padding_idx).int()

        # cumsum yields 1-based positions for real tokens; multiplying by the
        # mask zeroes the padding positions again.
        return torch.cumsum(mask, dim=1).long() * mask


# --- ablang2/models/ablang1/encoderblocks.py ---------------------------------

import math
from typing import List, Optional, Tuple
from dataclasses import dataclass

import torch
import torch.nn as nn
# Vendored fairseq attention (see fairseq_mha.py) instead of the fairseq dep.
from .fairseq_mha import MultiheadAttention

from .extra_fns import ACT2FN


@dataclass
class AbRepOutput():
    """
    Dataclass used to store AbRep output.
    """

    # Hidden states after the final encoder block: (batch, seq, hidden).
    last_hidden_states: torch.FloatTensor
    # Per-block hidden states; populated only when output_hidden_states=True.
    all_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Per-block attention weights; populated only when output_attentions=True.
    attentions: Optional[Tuple[torch.FloatTensor]] = None


class EncoderBlocks(torch.nn.Module):
    """
    Wrapper for multiple EncoderBlocks (or a single).

    Runs the stack sequentially, optionally collecting the hidden states
    and attention weights after each block.
    """
    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        self.Layers = nn.ModuleList([EncoderBlock(hparams) for _ in range(hparams.num_hidden_layers)])

    def forward(self, hidden_states, attention_mask=None, output_attentions=False, output_hidden_states=False):

        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for num_block, a_EncoderBlock in enumerate(self.Layers):

            hidden_states, attentions = a_EncoderBlock(hidden_states, attention_mask, output_attentions)

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,) # Takes out each hidden states after each EncoderBlock

            if output_attentions:
                all_self_attentions = all_self_attentions + (attentions,) # Takes out attention layers for analysis

        return AbRepOutput(last_hidden_states=hidden_states, all_hidden_states=all_hidden_states, attentions=all_self_attentions)


class EncoderBlock(torch.nn.Module):
    """
    Single EncoderBlock.

    An EncoderBlock consists of a MultiHeadAttention and a IntermediateLayer,
    each followed by a residual add and LayerNorm (post-LN transformer).
    """
    def __init__(self, hparams):
        super().__init__()

        self.MultiHeadAttention = ThirdMultiHeadAttention(hparams)
        self.MHADropout = nn.Dropout(hparams.hidden_dropout_prob)
        self.MHALayerNorm = nn.LayerNorm(hparams.hidden_size, eps=hparams.layer_norm_eps)

        self.IntermediateLayer = IntermediateLayer(hparams)

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):

        MHAoutput, attentions = self.MultiHeadAttention(hidden_states, attention_mask, output_attentions=output_attentions)

        output = self.MHADropout(MHAoutput)
        output = self.MHALayerNorm(output + hidden_states) # HIDDEN_STATES ARE ADDED FOR RESIDUAL BLOCK EFFECT

        output = self.IntermediateLayer(output) # INTERMEDIATELAYER HAS RESIDUAL BLOCK EFFECT INTERNALLY

        return output, attentions


class ThirdMultiHeadAttention(torch.nn.Module):
    """
    New MultiHeadAttention which can return the weights of the individual heads.
    """

    def __init__(self, hparams):
        super().__init__()

        self.Attention = MultiheadAttention(hparams.hidden_size, hparams.num_attention_heads, dropout=hparams.attention_probs_dropout_prob, self_attention=True)

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):

        # fairseq MHA expects (seq, batch, hidden); transpose in and out.
        hidden_states = torch.transpose(hidden_states, 0, 1)

        # static_kv is only True because there is currently a bug which doesn't return the head weights unaveraged unless its true
        attn_output, attn_weights = self.Attention(hidden_states, hidden_states, hidden_states, key_padding_mask=attention_mask, static_kv=True,
                                                   need_weights=output_attentions, need_head_weights=output_attentions)

        return torch.transpose(attn_output, 0, 1), attn_weights


class OldMultiHeadAttention(torch.nn.Module):
    """
    MultiHeadAttention contains a Scaled Dot Product Attention and a Linear Layer.

    Kept for reference; uses torch.nn.MultiheadAttention, which only returns
    head-averaged attention weights.
    """
    def __init__(self, config):
        super().__init__()
        self.Attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, config.attention_probs_dropout_prob)

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):

        hidden_states = torch.transpose(hidden_states, 0, 1)
        output, attentions = self.Attention(hidden_states, hidden_states, hidden_states, key_padding_mask=attention_mask, need_weights=output_attentions)

        attention_output = torch.transpose(output, 0, 1)

        return attention_output, attentions


class IntermediateLayer(nn.Module):
    """
    Contains an expanding layer, while also functioning as a residual block ending with a drop-norm layer
    """
    def __init__(self, config):
        super().__init__()
        self.expand_dense = nn.Linear(config.hidden_size, config.intermediate_size)
        self.intermediate_act_fn = ACT2FN[config.hidden_act]

        self.dense_dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        output = self.expand_dense(hidden_states)
        output = self.intermediate_act_fn(output)

        output = self.dense_dense(output)
        output = self.dropout(output)
        output = self.LayerNorm(output + hidden_states)

        return output


# --- ablang2/models/ablang1/extra_fns.py -------------------------------------

import torch
import math


def gelu_new(x):
    """
    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
    the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
    """
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

def gelu_fast(x):
    # Tanh approximation of GELU with the constant sqrt(2/pi) pre-computed.
    return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))

def mish(x):
    # Mish activation: x * tanh(softplus(x)).
    return x * torch.tanh(torch.nn.functional.softplus(x))

# Lookup table mapping the `hidden_act` hparam string to its activation fn.
ACT2FN = {
    "relu": torch.nn.functional.relu,
    "gelu": torch.nn.functional.gelu,
    "tanh": torch.tanh,
    "gelu_new": gelu_new,
    "gelu_fast": gelu_fast,
    "mish": mish,
    "sigmoid": torch.sigmoid,
}
# Vendored fairseq base classes (trimmed copy of fairseq's decoder and
# incremental-state machinery) so MultiheadAttention below can run without a
# fairseq dependency.

class FairseqDecoder(nn.Module):
    """Base class for decoders."""

    def __init__(self, dictionary):
        super().__init__()
        self.dictionary = dictionary
        self.onnx_trace = False
        self.adaptive_softmax = None

    def forward(self, prev_output_tokens, encoder_out=None, **kwargs):
        """
        Args:
            prev_output_tokens (LongTensor): shifted output tokens of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (dict, optional): output from the encoder, used for
                encoder-side attention

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        x, extra = self.extract_features(
            prev_output_tokens, encoder_out=encoder_out, **kwargs
        )
        x = self.output_layer(x)
        return x, extra

    def extract_features(self, prev_output_tokens, encoder_out=None, **kwargs):
        """
        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        raise NotImplementedError

    def output_layer(self, features, **kwargs):
        """
        Project features to the default output size, e.g., vocabulary size.

        Args:
            features (Tensor): features returned by *extract_features*.
        """
        raise NotImplementedError

    def get_normalized_probs(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        log_probs: bool,
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        """Get normalized probabilities (or log probs) from a net's output."""
        return self.get_normalized_probs_scriptable(net_output, log_probs, sample)

    # TorchScript doesn't support super() method so that the scriptable Subclass
    # can't access the base class model in Torchscript.
    # Current workaround is to add a helper function with different name and
    # call the helper function from scriptable Subclass.
    def get_normalized_probs_scriptable(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        log_probs: bool,
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        """Get normalized probabilities (or log probs) from a net's output."""

        if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None:
            if sample is not None:
                assert "target" in sample
                target = sample["target"]
            else:
                target = None
            out = self.adaptive_softmax.get_log_prob(net_output[0], target=target)
            return out.exp_() if not log_probs else out

        logits = net_output[0]
        # NOTE(review): log_softmax/softmax are not defined in the visible part
        # of this module (in upstream fairseq they come from fairseq.utils) —
        # confirm they are defined elsewhere in this file, otherwise calling
        # this path raises NameError.
        if log_probs:
            return log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
        else:
            return softmax(logits, dim=-1, onnx_trace=self.onnx_trace)

    def max_positions(self):
        """Maximum input length supported by the decoder."""
        return 1e6  # an arbitrary large number

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade old state dicts to work with newer code."""
        return state_dict

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True


class FairseqIncrementalState(object):
    """Mixin giving a module a UUID-namespaced slot inside the shared
    `incremental_state` dict used during incremental decoding."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.init_incremental_state()

    def init_incremental_state(self):
        # Unique per module instance so different modules never collide.
        self._incremental_state_id = str(uuid.uuid4())

    def _get_full_incremental_state_key(self, key: str) -> str:
        return "{}.{}".format(self._incremental_state_id, key)

    def get_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
    ) -> Optional[Dict[str, Optional[Tensor]]]:
        """Helper for getting incremental state for an nn.Module."""
        full_key = self._get_full_incremental_state_key(key)
        if incremental_state is None or full_key not in incremental_state:
            return None
        return incremental_state[full_key]

    def set_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
        value: Dict[str, Optional[Tensor]],
    ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
        """Helper for setting incremental state for an nn.Module."""
        if incremental_state is not None:
            full_key = self._get_full_incremental_state_key(key)
            incremental_state[full_key] = value
        return incremental_state


def with_incremental_state(cls):
    # Class decorator: prepend FairseqIncrementalState to the MRO so the
    # decorated class gains the get/set_incremental_state helpers.
    cls.__bases__ = (FairseqIncrementalState,) + tuple(
        b for b in cls.__bases__ if b != FairseqIncrementalState
    )
    return cls


@with_incremental_state
class FairseqIncrementalDecoder(FairseqDecoder):
    """Base class for incremental decoders.

    Incremental decoding is a special mode at inference time where the Model
    only receives a single timestep of input corresponding to the previous
    output token (for teacher forcing) and must produce the next output
    *incrementally*. Thus the model must cache any long-term state that is
    needed about the sequence, e.g., hidden states, convolutional states, etc.

    Compared to the standard :class:`FairseqDecoder` interface, the incremental
    decoder interface allows :func:`forward` functions to take an extra keyword
    argument (*incremental_state*) that can be used to cache state across
    time-steps.

    The :class:`FairseqIncrementalDecoder` interface also defines the
    :func:`reorder_incremental_state` method, which is used during beam search
    to select and reorder the incremental state based on the selection of beams.
    """

    def __init__(self, dictionary):
        super().__init__(dictionary)

    def forward(
        self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs
    ):
        """
        Args:
            prev_output_tokens (LongTensor): shifted output tokens of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (dict, optional): output from the encoder, used for
                encoder-side attention
            incremental_state (dict, optional): dictionary used for storing
                state during :ref:`Incremental decoding`

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        raise NotImplementedError

    def extract_features(
        self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs
    ):
        """
        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        raise NotImplementedError

    def reorder_incremental_state(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        new_order: Tensor,
    ):
        """Reorder incremental state.

        This will be called when the order of the input has changed from the
        previous time step. A typical use case is beam search, where the input
        order changes between time steps based on the selection of beams.
        """
        pass

    def reorder_incremental_state_scripting(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        new_order: Tensor,
    ):
        """Main entry point for reordering the incremental state.

        Due to limitations in TorchScript, we call this function in
        :class:`fairseq.sequence_generator.SequenceGenerator` instead of
        calling :func:`reorder_incremental_state` directly.
        """
        for module in self.modules():
            if hasattr(module, "reorder_incremental_state"):
                result = module.reorder_incremental_state(incremental_state, new_order)
                if result is not None:
                    incremental_state = result

    def set_beam_size(self, beam_size):
        """Sets the beam size in the decoder and all children."""
        if getattr(self, "_beam_size", -1) != beam_size:
            seen = set()

            def apply_set_beam_size(module):
                if (
                    module != self
                    and hasattr(module, "set_beam_size")
                    and module not in seen
                ):
                    seen.add(module)
                    module.set_beam_size(beam_size)

            self.apply(apply_set_beam_size)
            self._beam_size = beam_size
+ # config defined in xformers.components.attention.AttentionConfig + xformers_att_config: Optional[str] = None, + xformers_blocksparse_layout: Optional[ + torch.Tensor + ] = None, # This should be part of the config + xformers_blocksparse_blocksize: Optional[ + int + ] = 16, # This should be part of the config + ): + super().__init__(dictionary) + + #xformers_att_config = utils.eval_str_dict(xformers_att_config) + self.use_xformers = False #xformers_att_config is not None + if self.use_xformers and not _xformers_available: + raise ImportError("\n\n Please install xFormers.") + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout_module = FairseqDropout( + dropout, module_name=self.__class__.__name__ + ) + + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim**-0.5 + + self.self_attention = self_attention + self.encoder_decoder_attention = encoder_decoder_attention + + assert not self.self_attention or self.qkv_same_dim, ( + "Self-attention requires query, key and " "value to be of the same size" + ) + + self.k_proj = quant_noise( + nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size + ) + self.v_proj = quant_noise( + nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size + ) + self.q_proj = quant_noise( + nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size + ) + + self.out_proj = quant_noise( + nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size + ) + + if add_bias_kv: + self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) + self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + self.beam_size = 1 + 
self.reset_parameters() + + if self.use_xformers: + xformers_att_config["dropout"] = xformers_att_config.get("dropout", dropout) + xformers_att_config["num_heads"] = xformers_att_config.get( + "num_heads", num_heads + ) + + if xformers_blocksparse_layout is not None: + # Could be part of a single config passed only once + xformers_att_config["block_size"] = xformers_blocksparse_blocksize + xformers_att_config["layout"] = xformers_blocksparse_layout + xformers_att_config["name"] = "blocksparse" + + self.attention = build_attention(xformers_att_config) + + self.onnx_trace = False + self.skip_embed_dim_check = False + self.init_incremental_state() + + def prepare_for_onnx_export_(self): + self.onnx_trace = True + + def reset_parameters(self): + if self.qkv_same_dim: + # Empirically observed the convergence to be much better with + # the scaled initialization + nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) + else: + nn.init.xavier_uniform_(self.k_proj.weight) + nn.init.xavier_uniform_(self.v_proj.weight) + nn.init.xavier_uniform_(self.q_proj.weight) + + nn.init.xavier_uniform_(self.out_proj.weight) + if self.out_proj.bias is not None: + nn.init.constant_(self.out_proj.bias, 0.0) + if self.bias_k is not None: + nn.init.xavier_normal_(self.bias_k) + if self.bias_v is not None: + nn.init.xavier_normal_(self.bias_v) + + def _get_reserve_head_index(self, num_heads_to_keep: int): + k_proj_heads_norm = [] + q_proj_heads_norm = [] + v_proj_heads_norm = [] + + for i in range(self.num_heads): + start_idx = i * self.head_dim + end_idx = (i + 1) * self.head_dim + k_proj_heads_norm.append( + torch.sum( + torch.abs( + self.k_proj.weight[ + start_idx:end_idx, + ] + ) + ).tolist() + + torch.sum(torch.abs(self.k_proj.bias[start_idx:end_idx])).tolist() + ) + q_proj_heads_norm.append( + torch.sum( + torch.abs( + self.q_proj.weight[ 
+ start_idx:end_idx, + ] + ) + ).tolist() + + torch.sum(torch.abs(self.q_proj.bias[start_idx:end_idx])).tolist() + ) + v_proj_heads_norm.append( + torch.sum( + torch.abs( + self.v_proj.weight[ + start_idx:end_idx, + ] + ) + ).tolist() + + torch.sum(torch.abs(self.v_proj.bias[start_idx:end_idx])).tolist() + ) + + heads_norm = [] + for i in range(self.num_heads): + heads_norm.append( + k_proj_heads_norm[i] + q_proj_heads_norm[i] + v_proj_heads_norm[i] + ) + + sorted_head_index = sorted( + range(self.num_heads), key=lambda k: heads_norm[k], reverse=True + ) + reserve_head_index = [] + for i in range(num_heads_to_keep): + start = sorted_head_index[i] * self.head_dim + end = (sorted_head_index[i] + 1) * self.head_dim + reserve_head_index.append((start, end)) + return reserve_head_index + + def _adaptive_prune_heads(self, reserve_head_index: List[Tuple[int, int]]): + new_q_weight = [] + new_q_bias = [] + new_k_weight = [] + new_k_bias = [] + new_v_weight = [] + new_v_bias = [] + new_out_proj_weight = [] + + for ele in reserve_head_index: + start_idx, end_idx = ele + new_q_weight.append( + self.q_proj.weight[ + start_idx:end_idx, + ] + ) + new_q_bias.append(self.q_proj.bias[start_idx:end_idx]) + + new_k_weight.append( + self.k_proj.weight[ + start_idx:end_idx, + ] + ) + + new_k_bias.append(self.k_proj.bias[start_idx:end_idx]) + + new_v_weight.append( + self.v_proj.weight[ + start_idx:end_idx, + ] + ) + new_v_bias.append(self.v_proj.bias[start_idx:end_idx]) + + new_out_proj_weight.append(self.out_proj.weight[:, start_idx:end_idx]) + + new_q_weight = torch.cat(new_q_weight).detach() + new_k_weight = torch.cat(new_k_weight).detach() + new_v_weight = torch.cat(new_v_weight).detach() + new_out_proj_weight = torch.cat(new_out_proj_weight, dim=-1).detach() + new_q_weight.requires_grad = True + new_k_weight.requires_grad = True + new_v_weight.requires_grad = True + new_out_proj_weight.requires_grad = True + + new_q_bias = torch.cat(new_q_bias).detach() + new_q_bias.requires_grad 
= True + + new_k_bias = torch.cat(new_k_bias).detach() + new_k_bias.requires_grad = True + + new_v_bias = torch.cat(new_v_bias).detach() + new_v_bias.requires_grad = True + + self.q_proj.weight = torch.nn.Parameter(new_q_weight) + self.q_proj.bias = torch.nn.Parameter(new_q_bias) + + self.k_proj.weight = torch.nn.Parameter(new_k_weight) + self.k_proj.bias = torch.nn.Parameter(new_k_bias) + + self.v_proj.weight = torch.nn.Parameter(new_v_weight) + self.v_proj.bias = torch.nn.Parameter(new_v_bias) + + self.out_proj.weight = torch.nn.Parameter(new_out_proj_weight) + + self.num_heads = len(reserve_head_index) + self.embed_dim = self.head_dim * self.num_heads + self.q_proj.out_features = self.embed_dim + self.k_proj.out_features = self.embed_dim + self.v_proj.out_features = self.embed_dim + + def _set_skip_embed_dim_check(self): + self.skip_embed_dim_check = True + + def _pad_masks( + self, + key_padding_mask: Optional[Tensor], + attn_mask: Optional[Tensor], + ) -> Tuple[Optional[Tensor], Optional[Tensor]]: + if attn_mask is not None: + shape = attn_mask.size()[:-1] + torch.Size([1]) + attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(shape)], dim=-1) + if key_padding_mask is not None: + shape = key_padding_mask.size()[:-1] + torch.Size([1]) + key_padding_mask = torch.cat( + [ + key_padding_mask, + key_padding_mask.new_zeros(shape), + ], + dim=-1, + ) + return key_padding_mask, attn_mask + + def _add_bias( + self, + k: Tensor, + v: Tensor, + key_padding_mask: Optional[Tensor], + attn_mask: Optional[Tensor], + bsz: int, + ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: + assert self.bias_k is not None + assert self.bias_v is not None + k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) + key_padding_mask, attn_mask = self._pad_masks( + key_padding_mask=key_padding_mask, attn_mask=attn_mask + ) + return k, v, key_padding_mask, attn_mask + + def _append_zero_attn( + self, + k: Tensor, + v: Tensor, + 
key_padding_mask: Optional[Tensor], + attn_mask: Optional[Tensor], + ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: + zero_attn_shape = k.size()[:-2] + torch.Size([1]) + k.size()[-1:] + k = torch.cat( + [k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=-2 + ) + v = torch.cat( + [v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=-2 + ) + key_padding_mask, attn_mask = self._pad_masks( + key_padding_mask=key_padding_mask, attn_mask=attn_mask + ) + return k, v, key_padding_mask, attn_mask + + def _xformers_attn_forward( + self, + query, + key: Optional[Tensor], + value: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + + tgt_len, bsz, embed_dim = query.size() + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == tgt_len + + if self.self_attention: + key = query + value = query + elif self.encoder_decoder_attention: + value = key + + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + + if self.bias_k is not None: + assert self.bias_v is not None + k, v, attn_mask, key_padding_mask = self._add_bias( + k, v, attn_mask, key_padding_mask, bsz + ) + + def fold_heads(x): + return ( + x.contiguous() + .view(-1, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + def split_heads(x): + return ( + x.contiguous() + .view(-1, bsz, self.num_heads, self.head_dim) + .transpose(0, 1) + .transpose(1, 2) + ) + + massage = split_heads if self.attention.requires_head_dimension else fold_heads + q = massage(q) + if k is not None: + k = massage(k) + if v is not None: + v = massage(v) + + if self.add_zero_attn: + k, v, key_padding_mask, attn_mask = self._append_zero_attn( + k=k, v=v, key_padding_mask=key_padding_mask, attn_mask=attn_mask + ) + + kwargs = {} + + if attn_mask is not None and 
self.attention.supports_attention_mask: + attn_mask = _mask_for_xformers(attn_mask, to_dtype=q.dtype) + kwargs["att_mask"] = attn_mask + + if key_padding_mask is not None: + key_padding_mask = _mask_for_xformers(key_padding_mask, to_dtype=torch.bool) + if not self.attention.requires_separate_masks: + attn_mask = maybe_merge_masks( + attn_mask, + key_padding_mask, + batch_size=bsz, + src_len=k.size(-2), + tgt_len=q.size(-2), + num_heads=self.num_heads, + ) + key_padding_mask = None + kwargs["att_mask"] = attn_mask + if self.attention.supports_key_padding_mask: + kwargs["key_padding_mask"] = key_padding_mask + + y = self.attention(q, k, v, **kwargs) + + y = ( + y.view(bsz, self.num_heads, tgt_len, self.head_dim) + .transpose(1, 2) + .flatten(start_dim=2, end_dim=3) + .transpose(0, 1) + ) + assert list(y.size()) == [tgt_len, bsz, embed_dim] + + # Dropout not needed because already applied in attention. + # It is applied to the attention weights before matmul with v. + y = self.out_proj(y) + + # TODO: support returning attention weights if needed. + return y, None + + def forward( + self, + query: Tensor, + key: Optional[Tensor], + value: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + need_weights: bool = True, + static_kv: bool = False, + attn_mask: Optional[Tensor] = None, + before_softmax: bool = False, + need_head_weights: bool = False, + ) -> Tuple[Tensor, Optional[Tensor]]: + """Input shape: Time x Batch x Channel + + Args: + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + need_weights (bool, optional): return the attention weights, + averaged over heads (default: False). + attn_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). 
+ before_softmax (bool, optional): return the raw attention + weights and values before the attention softmax. + need_head_weights (bool, optional): return the attention + weights for each head. Implies *need_weights*. Default: + return the average attention weights over all heads. + """ + if need_head_weights: + need_weights = True + + is_tpu = query.device.type == "xla" + + tgt_len, bsz, embed_dim = query.size() + src_len = tgt_len + if not self.skip_embed_dim_check: + assert ( + embed_dim == self.embed_dim + ), f"query dim {embed_dim} != {self.embed_dim}" + assert list(query.size()) == [tgt_len, bsz, embed_dim] + if key is not None: + src_len, key_bsz, _ = key.size() + if not torch.jit.is_scripting(): + assert value is not None + assert src_len, key_bsz == value.shape[:2] + + if ( + not self.onnx_trace + and not is_tpu # don't use PyTorch version on TPUs + and incremental_state is None + and not static_kv + # A workaround for quantization to work. Otherwise JIT compilation + # treats bias in linear module as method. + and not torch.jit.is_scripting() + # The Multihead attention implemented in pytorch forces strong dimension check + # for input embedding dimention and K,Q,V projection dimension. 
+ # Since pruning will break the dimension check and it is not easy to modify the pytorch API, + # it is preferred to bypass the pytorch MHA when we need to skip embed_dim_check + and not self.skip_embed_dim_check + ): + assert key is not None and value is not None + + if self.use_xformers: + return self._xformers_attn_forward( + query, key, value, key_padding_mask, need_weights, attn_mask + ) + + else: + return F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + torch.empty([0]), + torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout_module.p, + self.out_proj.weight, + self.out_proj.bias, + self.training or self.dropout_module.apply_during_inference, + key_padding_mask.bool() if key_padding_mask is not None else None, + need_weights, + attn_mask, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + ) + + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if saved_state is not None and "prev_key" in saved_state: + # previous time steps are cached - no need to recompute + # key and value if they are static + if static_kv: + assert self.encoder_decoder_attention and not self.self_attention + key = value = None + else: + saved_state = None + + if self.self_attention: + q = self.q_proj(query) + k = self.k_proj(query) + v = self.v_proj(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.q_proj(query) + if key is None: + assert value is None + k = v = None + else: + if self.beam_size > 1 and bsz == key.size(1): + # key is [T, bsz*beam_size, C], reduce to [T, bsz, C] + key = key.view(key.size(0), -1, self.beam_size, key.size(2))[ + :, :, 0, : + ] + if key_padding_mask is not None: + key_padding_mask = key_padding_mask.view( + -1, self.beam_size, key_padding_mask.size(1) + )[:, 0, :] + 
k = self.k_proj(key) + v = self.v_proj(key) + + else: + assert key is not None and value is not None + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + q *= self.scaling + + if self.bias_k is not None: + assert self.bias_v is not None + k, v, attn_mask, key_padding_mask = self._add_bias( + k, v, attn_mask, key_padding_mask, bsz + ) + + q = ( + q.contiguous() + .view(tgt_len, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + kv_bsz = bsz # need default value for scripting + if k is not None: + kv_bsz = k.size(1) + k = ( + k.contiguous() + .view(-1, kv_bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + if v is not None: + v = ( + v.contiguous() + .view(-1, kv_bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + if saved_state is not None: + # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) + if "prev_key" in saved_state: + _prev_key = saved_state["prev_key"] + assert _prev_key is not None + kv_bsz = _prev_key.size(0) + prev_key = _prev_key.view(kv_bsz * self.num_heads, -1, self.head_dim) + if static_kv: + k = prev_key + else: + assert k is not None + k = torch.cat([prev_key, k], dim=1) + src_len = k.size(1) + if "prev_value" in saved_state: + _prev_value = saved_state["prev_value"] + assert _prev_value is not None + assert kv_bsz == _prev_value.size(0) + prev_value = _prev_value.view( + kv_bsz * self.num_heads, -1, self.head_dim + ) + if static_kv: + v = prev_value + else: + assert v is not None + v = torch.cat([prev_value, v], dim=1) + prev_key_padding_mask: Optional[Tensor] = None + if "prev_key_padding_mask" in saved_state: + prev_key_padding_mask = saved_state["prev_key_padding_mask"] + assert k is not None and v is not None + key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( + key_padding_mask=key_padding_mask, + prev_key_padding_mask=prev_key_padding_mask, + batch_size=kv_bsz, + src_len=k.size(1), + static_kv=static_kv, + ) + + saved_state["prev_key"] = 
k.view(kv_bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_value"] = v.view( + kv_bsz, self.num_heads, -1, self.head_dim + ) + saved_state["prev_key_padding_mask"] = key_padding_mask + # In this branch incremental_state is never None + assert incremental_state is not None + incremental_state = self._set_input_buffer(incremental_state, saved_state) + assert k is not None + assert k.size(1) == src_len + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. + if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == kv_bsz + assert key_padding_mask.size(1) == src_len + + if self.add_zero_attn: + assert v is not None + src_len += 1 + k, v, key_padding_mask, attn_mask = self._append_zero_attn( + k=k, v=v, key_padding_mask=key_padding_mask, attn_mask=attn_mask + ) + + if self.encoder_decoder_attention and bsz != kv_bsz: + attn_weights = torch.einsum( + "bxhtd,bhsd->bxhts", + q.view((kv_bsz, -1, self.num_heads) + q.size()[1:]), + k.view((kv_bsz, self.num_heads) + k.size()[1:]), + ) + attn_weights = attn_weights.reshape((-1,) + attn_weights.size()[-2:]) + else: + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) + + assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + if self.onnx_trace: + attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1) + attn_weights += attn_mask + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + if not is_tpu: + attn_weights = attn_weights.view( + kv_bsz, -1, self.num_heads, tgt_len, src_len + ) + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .to(torch.bool), + 
float("-inf"), + ) + else: + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf")) + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if before_softmax: + return attn_weights, v + + attn_weights_float = softmax( + attn_weights, dim=-1, onnx_trace=self.onnx_trace + ) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = self.dropout_module(attn_weights) + + assert v is not None + attn: Optional[Tensor] = None + if self.encoder_decoder_attention and bsz != kv_bsz: + attn = torch.einsum( + "bxhts,bhsd->bxhtd", + attn_probs.view( + ( + kv_bsz, + -1, + self.num_heads, + ) + + attn_probs.size()[1:] + ), + v.view( + ( + kv_bsz, + self.num_heads, + ) + + v.size()[1:] + ), + ) + attn = attn.reshape((-1,) + attn.size()[-2:]) + else: + attn = torch.bmm(attn_probs, v) + assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] + if self.onnx_trace and attn.size(1) == 1: + # when ONNX tracing a single decoder step (sequence length == 1) + # the transpose is a no-op copy before view, thus unnecessary + attn = attn.contiguous().view(tgt_len, bsz, self.embed_dim) + else: + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim) + attn = self.out_proj(attn) + attn_weights: Optional[Tensor] = None + if need_weights: + attn_weights = attn_weights_float.view( + bsz, self.num_heads, tgt_len, src_len + ).transpose(1, 0) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(dim=0) + + return attn, attn_weights + + @staticmethod + def _append_prev_key_padding_mask( + key_padding_mask: Optional[Tensor], + prev_key_padding_mask: Optional[Tensor], + batch_size: int, + src_len: int, + static_kv: bool, + ) -> Optional[Tensor]: + # saved key padding masks have shape (bsz, seq_len) + if prev_key_padding_mask is not None and static_kv: + 
new_key_padding_mask = prev_key_padding_mask + elif prev_key_padding_mask is not None and key_padding_mask is not None: + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1 + ) + # During incremental decoding, as the padding token enters and + # leaves the frame, there will be a time when prev or current + # is None + elif prev_key_padding_mask is not None: + if src_len > prev_key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - prev_key_padding_mask.size(1)), + device=prev_key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), filler.float()], dim=1 + ) + else: + new_key_padding_mask = prev_key_padding_mask.float() + elif key_padding_mask is not None: + if src_len > key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - key_padding_mask.size(1)), + device=key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [filler.float(), key_padding_mask.float()], dim=1 + ) + else: + new_key_padding_mask = key_padding_mask.float() + else: + new_key_padding_mask = prev_key_padding_mask + return new_key_padding_mask + + @torch.jit.export + def reorder_incremental_state( + self, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], + new_order: Tensor, + ): + """Reorder buffered internal state (for incremental generation).""" + input_buffer = self._get_input_buffer(incremental_state) + if input_buffer is not None: + for k in input_buffer.keys(): + input_buffer_k = input_buffer[k] + if input_buffer_k is not None: + if self.encoder_decoder_attention: + if input_buffer_k.size(0) * self.beam_size == new_order.size(0): + return incremental_state + elif self.beam_size > 1: + input_buffer[k] = input_buffer_k.index_select( + 0, + new_order.reshape(-1, self.beam_size)[:, 0] + // self.beam_size, + ) + else: + input_buffer[k] = input_buffer_k.index_select(0, new_order) + else: + input_buffer[k] = input_buffer_k.index_select(0, 
def set_beam_size(self, beam_size):
    """Used for efficient beamable enc-dec attention."""
    self.beam_size = beam_size

def _get_input_buffer(
    self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
) -> Dict[str, Optional[Tensor]]:
    """Fetch the cached attention state, or an empty dict when absent."""
    buffer = self.get_incremental_state(incremental_state, "attn_state")
    if buffer is None:
        # TorchScript needs the explicit annotation on the fallback value.
        empty_buffer: Dict[str, Optional[Tensor]] = {}
        return empty_buffer
    return buffer

def _set_input_buffer(
    self,
    incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
    buffer: Dict[str, Optional[Tensor]],
):
    """Store the attention cache back into ``incremental_state``."""
    return self.set_incremental_state(incremental_state, "attn_state", buffer)

def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
    """Hook for subclasses to sparsify attention weights; identity here."""
    return attn_weights

def upgrade_state_dict_named(self, state_dict, name):
    """Split legacy fused ``in_proj_weight``/``in_proj_bias`` checkpoint
    entries into separate q/k/v projection tensors (in place)."""
    prefix = name + "." if name != "" else ""
    items_to_add = {}
    keys_to_remove = []
    for key in state_dict.keys():
        if not key.endswith(prefix + "in_proj_weight"):
            continue
        # Old checkpoints stored q, k and v stacked into one tensor.
        dim = int(state_dict[key].shape[0] / 3)
        items_to_add[prefix + "q_proj.weight"] = state_dict[key][:dim]
        items_to_add[prefix + "k_proj.weight"] = state_dict[key][dim : 2 * dim]
        items_to_add[prefix + "v_proj.weight"] = state_dict[key][2 * dim :]
        keys_to_remove.append(key)

        bias_key = prefix + "in_proj_bias"
        if bias_key in state_dict.keys():
            dim = int(state_dict[key].shape[0] / 3)
            items_to_add[prefix + "q_proj.bias"] = state_dict[bias_key][:dim]
            items_to_add[prefix + "k_proj.bias"] = state_dict[bias_key][dim : 2 * dim]
            items_to_add[prefix + "v_proj.bias"] = state_dict[bias_key][2 * dim :]
            keys_to_remove.append(prefix + "in_proj_bias")

    for key in keys_to_remove:
        del state_dict[key]

    for key, value in items_to_add.items():
        state_dict[key] = value
class FairseqDropout(nn.Module):
    """Dropout that can optionally stay active at inference time.

    Behaves like ``nn.Dropout`` during training; ``make_generation_fast_``
    may flip ``apply_during_inference`` so dropout keeps firing in eval
    mode (e.g. for retain-dropout generation).
    """

    def __init__(self, p, module_name=None):
        super().__init__()
        self.p = p  # drop probability
        self.module_name = module_name  # matched against retain_dropout_modules
        self.apply_during_inference = False

    def forward(self, x, inplace: bool = False):
        # Drop only when there is a non-zero probability AND we are either
        # training or explicitly keeping dropout on at inference.
        active = self.training or self.apply_during_inference
        if self.p <= 0 or not active:
            return x
        return F.dropout(x, p=self.p, training=True, inplace=inplace)

    def make_generation_fast_(
        self,
        name: str,
        retain_dropout: bool = False,
        retain_dropout_modules: Optional[List[str]] = None,
        **kwargs
    ):
        """Disable dropout for generation (default) or selectively retain it."""
        if not retain_dropout:
            logger.info("Disabling dropout for module: {}".format(name))
            return
        if retain_dropout_modules is not None and self.module_name is None:
            # Cannot match this module against the retain list without a name.
            logger.warning(
                "Cannot enable dropout during inference for module {} "
                "because module_name was not set".format(name)
            )
        elif (
            retain_dropout_modules is None  # if None, apply to all modules
            or self.module_name in retain_dropout_modules
        ):
            logger.info(
                "Enabling dropout during inference for module: {}".format(name)
            )
            self.apply_during_inference = True
def quant_noise(module, p, block_size):
    """
    Wraps modules and applies quantization noise to the weights for
    subsequent quantization with Iterative Product Quantization, as
    described in "Training with Quantization Noise for Extreme Model Compression".

    Args:
        - module: nn.Module (Linear, Embedding or Conv2d only)
        - p: amount of quantization noise (probability of dropping a block)
        - block_size: size of the blocks for subsequent quantization with iPQ

    Remarks:
        - Module weights must have the right sizes wrt the block size
        - We implement the simplest form of noise from the paper: randomly
          dropping whole blocks of the weight matrix during training.
    """
    if p <= 0:
        # No noise requested: return the module untouched (no hook).
        return module

    assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))

    weight_is_conv = module.weight.ndim == 4

    # Sanity-check that the weight layout is compatible with block_size.
    if not weight_is_conv:
        assert (
            module.weight.size(1) % block_size == 0
        ), "Input features must be a multiple of block sizes"
    elif module.kernel_size == (1, 1):
        # 1x1 convolutions are handled like linear layers over channels.
        assert (
            module.in_channels % block_size == 0
        ), "Input channels must be a multiple of block sizes"
    else:
        kernel_area = module.kernel_size[0] * module.kernel_size[1]
        assert kernel_area % block_size == 0, "Kernel size must be a multiple of block size"

    def _forward_pre_hook(mod, input):
        if not mod.training:
            return  # no noise at evaluation time

        weight = mod.weight
        if not weight_is_conv:
            in_features = weight.size(1)
            out_features = weight.size(0)
            # One Bernoulli draw per block, expanded back to element shape.
            drop = torch.zeros(
                in_features // block_size * out_features, device=weight.device
            )
            drop.bernoulli_(p)
            drop = drop.repeat_interleave(block_size, -1).view(-1, in_features)
        elif mod.kernel_size == (1, 1):
            drop = torch.zeros(
                int(mod.in_channels // block_size * mod.out_channels),
                device=weight.device,
            )
            drop.bernoulli_(p)
            drop = drop.repeat_interleave(block_size, -1).view(-1, mod.in_channels)
        else:
            # Regular convolutions: drop whole (out, in) kernel slices.
            drop = torch.zeros(weight.size(0), weight.size(1), device=weight.device)
            drop.bernoulli_(p)
            drop = (
                drop.unsqueeze(2)
                .unsqueeze(3)
                .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
            )

        # Zero the selected blocks and rescale survivors to keep expectation.
        drop = drop.to(torch.bool)  # x.bool() is not currently supported in TorchScript
        scale = 1 / (1 - p)
        mod.weight.data = scale * weight.masked_fill(drop, 0)

    module.register_forward_pre_hook(_forward_pre_hook)
    return module
def softmax(x, dim: int, onnx_trace: bool = False):
    """Softmax computed in float32 for numerical stability.

    During ONNX tracing the ``dtype=`` keyword is avoided (not traceable),
    so the input is cast to float explicitly instead.
    """
    if onnx_trace:
        return x.float().softmax(dim)
    return x.softmax(dim, dtype=torch.float32)


def log_softmax(x, dim: int, onnx_trace: bool = False):
    """Log-softmax computed in float32 (see ``softmax``)."""
    if onnx_trace:
        return x.float().log_softmax(dim)
    return x.log_softmax(dim, dtype=torch.float32)
+ """ + def __init__(self, hparams): + super().__init__() + self.hparams = hparams + + self.AbEmbeddings = AbEmbeddings(self.hparams) + self.EncoderBlocks = EncoderBlocks(self.hparams) + + self.init_weights() + + def forward(self, src, attention_mask=None, output_attentions=False): + + attention_mask = torch.zeros(*src.shape, device=src.device).masked_fill(src == self.hparams.pad_token_id, 1) + + src = self.AbEmbeddings(src) + + output = self.EncoderBlocks(src, attention_mask=attention_mask, output_attentions=output_attentions) + + return output + + def _init_weights(self, module): + """ Initialize the weights """ + if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=self.hparams.initializer_range) + elif isinstance(module, torch.nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, torch.nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + def init_weights(self): + """ + Initializes and prunes weights if needed. + """ + # Initialize weights + self.apply(self._init_weights) + + +class AbHead(torch.nn.Module): + """ + Head for masked sequence prediction. 
+ """ + + def __init__(self, hparams): + super().__init__() + self.hparams = hparams + self.dense = torch.nn.Linear(self.hparams.hidden_size, self.hparams.hidden_size) + self.layer_norm = torch.nn.LayerNorm(self.hparams.hidden_size, eps=self.hparams.layer_norm_eps) + + self.decoder = torch.nn.Linear(self.hparams.hidden_size, self.hparams.vocab_size, bias=False) + self.bias = torch.nn.Parameter(torch.zeros(self.hparams.vocab_size)) + + self.activation = ACT2FN[self.hparams.hidden_act] + + ## self.init_weights() - need to have a function doing this + + self.decoder.bias = self.bias # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + + def forward(self, features, **kwargs): + x = self.dense(features) + + x = self.activation(x) + x = self.layer_norm(x) + + # project back to size of vocabulary with bias + x = self.decoder(x) + + return x diff --git a/ablang2/models/ablang1/pretrained.py b/ablang2/models/ablang1/pretrained.py new file mode 100644 index 0000000000000000000000000000000000000000..6c8c231e0f28e7de27af0fe3529f6909059d6428 --- /dev/null +++ b/ablang2/models/ablang1/pretrained.py @@ -0,0 +1,358 @@ +import os, json, argparse, string, subprocess, re +from dataclasses import dataclass + +from numba import jit +from numba.typed import Dict, List +from numba.types import unicode_type, DictType + +import numpy as np +import torch +import requests + +from . import tokenizers, model + + +class pretrained: + """ + Initializes AbLang for heavy or light chains. 
+ """ + + def __init__(self, chain="heavy", model_folder="download", random_init=False, ncpu=7, device='cpu'): + super().__init__() + + self.used_device = torch.device(device) + + if model_folder == "download": + # Download model and save to specific place - if already downloaded do not download again + model_folder = os.path.join(os.path.dirname(__file__), "model-weights-{}".format(chain)) + os.makedirs(model_folder, exist_ok = True) + + if not os.path.isfile(os.path.join(model_folder, "amodel.pt")): + print("Downloading model ...") + + url = "https://opig.stats.ox.ac.uk/data/downloads/ablang-{}.tar.gz".format(chain) + tmp_file = os.path.join(model_folder, "tmp.tar.gz") + + with open(tmp_file,'wb') as f: f.write(requests.get(url).content) + + subprocess.run(["tar", "-zxvf", tmp_file, "-C", model_folder], check = True) + + os.remove(tmp_file) + + self.hparams_file = os.path.join(model_folder, 'hparams.json') + self.model_file = os.path.join(model_folder, 'amodel.pt') + + with open(self.hparams_file, 'r', encoding='utf-8') as f: + self.hparams = argparse.Namespace(**json.load(f)) + + self.AbLang = model.AbLang(self.hparams) + self.AbLang.to(self.used_device) + + if not random_init: + self.AbLang.load_state_dict(torch.load(self.model_file, map_location=self.used_device)) + + self.tokenizer = tokenizers.ABtokenizer(os.path.join(model_folder, 'vocab.json')) + self.AbRep = self.AbLang.AbRep + + self.ncpu = ncpu + self.spread = 11 # Based on get_spread_sequences function + if chain == 'heavy': + self.max_position = 128 + else: + self.max_position = 127 + + + def freeze(self): + self.AbLang.eval() + + def unfreeze(self): + self.AbLang.train() + + def __call__(self, sequence, mode='seqcoding', align=False, splitSize=50): + """ + Mode: sequence, residue, restore or likelihood. 
+ """ + if not mode in ['rescoding', 'seqcoding', 'restore', 'likelihood']: + raise SyntaxError("Given mode doesn't exist.") + + if isinstance(sequence, str): sequence = [sequence] + + + if align and mode=='restore': + sequence = self.sequence_aligning(sequence) + splitSize = ((splitSize//self.spread)+1)*self.spread + + aList = [] + for sequence_part in [sequence[x:x+splitSize] for x in range(0, len(sequence), splitSize)]: + aList.append(getattr(self, mode)(sequence_part, align)) + + if mode == 'rescoding': + if align==True: + return aList + + return sum(aList, []) + + return np.concatenate(aList) + + def seqcoding(self, seqs, align=False): + """ + Sequence specific representations + """ + + tokens = self.tokenizer(seqs, pad=True, device=self.used_device) + + residue_states = self.AbRep(tokens).last_hidden_states + + if torch.is_tensor(residue_states): residue_states = residue_states.cpu().detach().numpy() + + lens = np.vectorize(len)(seqs) + + lens = np.tile(lens.reshape(-1,1,1), (residue_states.shape[2], 1)) + + seq_codings = np.apply_along_axis(res_to_seq, 2, np.c_[np.swapaxes(residue_states,1,2), lens]) + + del lens + del residue_states + + return seq_codings + + def restore(self, seqs, align=False): + """ + Restore sequences + """ + + if align: + nr_seqs = len(seqs)//self.spread + + tokens = self.tokenizer(seqs, pad=True, device=self.used_device) + predictions = self.AbLang(tokens)[:,:,1:21] + + # Reshape + tokens = tokens.reshape(nr_seqs, self.spread, -1) + predictions = predictions.reshape(nr_seqs, self.spread, -1, 20) + seqs = seqs.reshape(nr_seqs, -1) + + # Find index of best predictions + best_seq_idx = torch.argmax(torch.max(predictions, -1).values[:,:,1:2].mean(2), -1) + + # Select best predictions + tokens = tokens.gather(1, best_seq_idx.view(-1, 1).unsqueeze(1).repeat(1, 1, tokens.shape[-1])).squeeze(1) + predictions = predictions[range(predictions.shape[0]), best_seq_idx] + seqs = np.take_along_axis(seqs, best_seq_idx.view(-1, 1).cpu().numpy(), 
axis=1) + + + else: + tokens = self.tokenizer(seqs, pad=True, device=self.used_device) + predictions = self.AbLang(tokens)[:,:,1:21] + + predicted_tokens = torch.max(predictions, -1).indices + 1 + restored_tokens = torch.where(tokens==23, predicted_tokens, tokens) + + restored_seqs = self.tokenizer(restored_tokens, encode=False) + + return np.array([res_to_seq(seq, 'reconstruct') for seq in np.c_[restored_seqs, np.vectorize(len)(seqs)]]) + + def likelihood(self, seqs, align=False): + """ + Possible Mutations + """ + + tokens = self.tokenizer(seqs, pad=True, device=self.used_device) + + predictions = self.AbLang(tokens)[:,:,1:21] + + if torch.is_tensor(predictions): predictions = predictions.cpu().detach().numpy() + + return predictions + + def rescoding(self, seqs, align=False): + """ + Residue specific representations. + """ + + if align: + + import pandas as pd + import anarci + + anarci_out = anarci.run_anarci(pd.DataFrame(seqs).reset_index().values.tolist(), ncpu=7, scheme='imgt') + number_alignment = get_number_alignment(anarci_out) + + seqs = np.array([''.join([i[1] for i in onarci[0][0]]).replace('-','') for onarci in anarci_out[1]]) + + tokens = self.tokenizer(seqs, pad=True, device=self.used_device) + residue_states = self.AbRep(tokens).last_hidden_states + + if torch.is_tensor(residue_states): residue_states = residue_states.cpu().detach().numpy() + + residue_output = np.array([create_alignment(res_embed, oanarci, seq, number_alignment) for res_embed, oanarci, seq in zip(residue_states, anarci_out[1], seqs)]) + del residue_states + del tokens + + return output(aligned_embeds=residue_output, number_alignment=number_alignment.apply(lambda x: '{}{}'.format(*x[0]), axis=1).values) + + else: + + tokens = self.tokenizer(seqs, pad=True, device=self.used_device) + residue_states = self.AbRep(tokens).last_hidden_states + + if torch.is_tensor(residue_states): residue_states = residue_states.cpu().detach().numpy() + + residue_output = [res_to_list(state, seq) for 
state, seq in zip(residue_states, seqs)] + + return residue_output + + def sequence_aligning(self, seqs): + + import pandas as pd + import anarci + + anarci_out = anarci.run_anarci( + pd.DataFrame([seq.replace('*', 'X') for seq in seqs]).reset_index().values.tolist(), + ncpu=self.ncpu, + scheme='imgt' + ) #, allowed_species=['human', 'mouse'] + anarci_data = pd.DataFrame([str(anarci[0][0]) if anarci else 'ANARCI_error' for anarci in anarci_out[1]], columns=['anarci']).astype('"]] + except KeyError as e: + + wrong_aa = e.args + + e.args = (f"Following character(s) not accepted in sequences: {wrong_aa}. \ +Please only use amino acids (MRHKDESTNQCGPAVIFYWL) or the mask token (*).",) + raise + + return torch.tensor(encoded, dtype=torch.long, device=device) + # Start and Stop token should probably not be added here, but instead earlier + + def decode(self, seqtokens): + + if torch.is_tensor(seqtokens): seqtokens = seqtokens.cpu().numpy() + + return ''.join([self.vocab_to_aa[token] for token in seqtokens]) + + + diff --git a/ablang2/models/ablang2/__init__.py b/ablang2/models/ablang2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ablang2/models/ablang2/__pycache__/__init__.cpython-310.pyc b/ablang2/models/ablang2/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1cbc3bb338b48d26f39a712e1275ef7adbd0b142 Binary files /dev/null and b/ablang2/models/ablang2/__pycache__/__init__.cpython-310.pyc differ diff --git a/ablang2/models/ablang2/__pycache__/__init__.cpython-312.pyc b/ablang2/models/ablang2/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e5a9d13ec09cb8720654821c60602827715042ca Binary files /dev/null and b/ablang2/models/ablang2/__pycache__/__init__.cpython-312.pyc differ diff --git a/ablang2/models/ablang2/__pycache__/ablang.cpython-310.pyc 
b/ablang2/models/ablang2/__pycache__/ablang.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a672e9b35cc8f7884194955398d55b6c190b86c Binary files /dev/null and b/ablang2/models/ablang2/__pycache__/ablang.cpython-310.pyc differ diff --git a/ablang2/models/ablang2/__pycache__/ablang.cpython-312.pyc b/ablang2/models/ablang2/__pycache__/ablang.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98b60a8969532bf62f8a1cf1f8700ec695b8cc69 Binary files /dev/null and b/ablang2/models/ablang2/__pycache__/ablang.cpython-312.pyc differ diff --git a/ablang2/models/ablang2/__pycache__/encoderblock.cpython-310.pyc b/ablang2/models/ablang2/__pycache__/encoderblock.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..814a6815ef9d4cd685316e3b345c69256fd64d90 Binary files /dev/null and b/ablang2/models/ablang2/__pycache__/encoderblock.cpython-310.pyc differ diff --git a/ablang2/models/ablang2/__pycache__/encoderblock.cpython-312.pyc b/ablang2/models/ablang2/__pycache__/encoderblock.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..85b0c5b26179c2f6dbafbb40f750d52b7cdb35c2 Binary files /dev/null and b/ablang2/models/ablang2/__pycache__/encoderblock.cpython-312.pyc differ diff --git a/ablang2/models/ablang2/__pycache__/tokenizers.cpython-312.pyc b/ablang2/models/ablang2/__pycache__/tokenizers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb1d9bca664f5d789ed5c1fb6e128faee95f1fec Binary files /dev/null and b/ablang2/models/ablang2/__pycache__/tokenizers.cpython-312.pyc differ diff --git a/ablang2/models/ablang2/__pycache__/vocab.cpython-312.pyc b/ablang2/models/ablang2/__pycache__/vocab.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e73a12d82f1a1ad6b7123353f1c281cf793870d1 Binary files /dev/null and b/ablang2/models/ablang2/__pycache__/vocab.cpython-312.pyc differ diff --git 
a/ablang2/models/ablang2/ablang.py b/ablang2/models/ablang2/ablang.py new file mode 100644 index 0000000000000000000000000000000000000000..60fb086da2703a3a9c0c3799b87d381ce1adc559 --- /dev/null +++ b/ablang2/models/ablang2/ablang.py @@ -0,0 +1,181 @@ +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +from torch import nn +import torch.nn.functional as F + +from .encoderblock import TransformerEncoder, get_activation_fn + + +class AbLang(torch.nn.Module): + """ + AbLang inspired by ESM-2's architecture. + """ + + def __init__( + self, + vocab_size, + hidden_embed_size, + n_attn_heads, + n_encoder_blocks, + padding_tkn, + mask_tkn, + layer_norm_eps: float = 1e-12, + a_fn: str = "gelu", + dropout: float = 0.0, + ): + super().__init__() + + self.AbRep = AbRep( + vocab_size, + hidden_embed_size, + n_attn_heads, + n_encoder_blocks, + padding_tkn, + mask_tkn, + layer_norm_eps, + a_fn, + dropout, + ) + self.AbHead = AbHead( + vocab_size, + hidden_embed_size, + self.AbRep.aa_embed_layer.weight, + layer_norm_eps, + a_fn, + ) + + def forward(self, tokens, return_attn_weights=False, return_rep_layers=[]): + + representations = self.AbRep(tokens, return_attn_weights, return_rep_layers) + + if return_attn_weights: + return representations.attention_weights + + elif return_rep_layers != []: + return representations.many_hidden_states + else: + likelihoods = self.AbHead(representations.last_hidden_states) + return likelihoods + + def get_aa_embeddings(self): + "Extracts the trained aa_embeddings." + return self.AbRep.aa_embed_layer + + +class AbRep(torch.nn.Module): + """ + AbRep (antibody representations), takes the tokenized sequence and create hidden_embed (representations). 
+ """ + + def __init__( + self, + vocab_size, + hidden_embed_size, + n_attn_heads, + n_encoder_blocks, + padding_tkn, + mask_tkn, + layer_norm_eps: float = 1e-12, + a_fn: str = "gelu", + dropout: float = 0.1, + ): + super().__init__() + self.padding_tkn = padding_tkn + self.mask_tkn = mask_tkn + + self.aa_embed_layer = nn.Embedding( + vocab_size, + hidden_embed_size, + padding_idx=padding_tkn, + ) + self.encoder_blocks = nn.ModuleList( + [TransformerEncoder( + hidden_embed_size, + n_attn_heads, + attn_dropout = dropout, + layer_norm_eps = layer_norm_eps, + a_fn = a_fn, + ) for _ in range(n_encoder_blocks)] + ) + self.layer_norm_after_encoder_blocks = nn.LayerNorm(hidden_embed_size, eps=layer_norm_eps) + + def forward(self, + tokens, + return_attn_weights=False, + return_rep_layers=[], + ): + + assert tokens.ndim == 2 + padding_mask = tokens.eq(self.padding_tkn) + + hidden_embed = self.aa_embed_layer(tokens) + + return_rep_layers = set(return_rep_layers) + rep_layers = {} + if 0 in return_rep_layers: rep_layers[0] = hidden_embed + + all_attn_weights = [] + + for n_layer, encoder_block in enumerate(self.encoder_blocks): + hidden_embed, attn_weights = encoder_block(hidden_embed, padding_mask, return_attn_weights) + + if (n_layer + 1) in return_rep_layers: + rep_layers[n_layer + 1] = hidden_embed + + if return_attn_weights: + all_attn_weights.append(attn_weights) + + hidden_embed = self.layer_norm_after_encoder_blocks(hidden_embed) + + return DataAbRep( + last_hidden_states=hidden_embed, + many_hidden_states=rep_layers, + attention_weights=all_attn_weights + ) + + +class AbHead(torch.nn.Module): + """ + AbHead (antibody head model), creates amino acid probabilities for each position based on the hidden_embed (representations). 
+ """ + + def __init__( + self, + vocab_size, + hidden_embed_size, + weights, + layer_norm_eps: float = 1e-12, + a_fn: str = "gelu", + ): + super().__init__() + + activation_fn, scale = get_activation_fn(a_fn) + + self.ff = torch.nn.Sequential( + nn.Linear(hidden_embed_size, hidden_embed_size * scale), + activation_fn(), + nn.LayerNorm(hidden_embed_size, eps=layer_norm_eps), + ) + + self.weights = weights + self.bias = nn.Parameter(torch.zeros(vocab_size)) + + def forward(self, hidden_embed): + + hidden_embed = self.ff(hidden_embed) + logits = F.linear(hidden_embed, self.weights) + self.bias + + return logits + + +@dataclass +class DataAbRep(): + """ + Dataclass used to store AbRep output. + """ + + last_hidden_states: torch.FloatTensor + many_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attention_weights: Optional[Tuple[torch.FloatTensor]] = None \ No newline at end of file diff --git a/ablang2/models/ablang2/encoderblock.py b/ablang2/models/ablang2/encoderblock.py new file mode 100644 index 0000000000000000000000000000000000000000..8310566c311bf45c12c46bc86b20039914bfe7fd --- /dev/null +++ b/ablang2/models/ablang2/encoderblock.py @@ -0,0 +1,173 @@ +import torch +import math +from torch import nn +import torch.nn.functional as F +import einops +from rotary_embedding_torch import RotaryEmbedding + +class TransformerEncoder(torch.nn.Module): + """ + Single Transformer Encoder. + + """ + def __init__( + self, + hidden_embed_size, + n_attn_heads, + attn_dropout: float = 0.0, + layer_norm_eps: float = 1e-05, + a_fn: str = "gelu", + ): + super().__init__() + + assert hidden_embed_size % n_attn_heads == 0, \ + "Embedding dimension must be devisible with the number of heads." 
+ + self.multihead_attention = MultiHeadAttention( + embed_dim = hidden_embed_size, + num_heads = n_attn_heads, + attention_dropout_prob = attn_dropout + ) + + activation_fn, scale = get_activation_fn(a_fn) + + self.intermediate_layer = torch.nn.Sequential( + torch.nn.Linear(hidden_embed_size, hidden_embed_size * 4 * scale), + activation_fn(), + torch.nn.Linear(hidden_embed_size * 4, hidden_embed_size), + ) + + self.pre_attn_layer_norm = torch.nn.LayerNorm(hidden_embed_size, eps=layer_norm_eps) + self.final_layer_norm = torch.nn.LayerNorm(hidden_embed_size, eps=layer_norm_eps) + + def forward(self, hidden_embed, attn_mask=None, return_attn_weights: bool = False): + + residual = hidden_embed + hidden_embed = self.pre_attn_layer_norm(hidden_embed.clone()) + hidden_embed, attn_weights = self.multihead_attention( + hidden_embed, + attn_mask=attn_mask, + return_attn_weights=return_attn_weights + ) + hidden_embed = residual + hidden_embed + + residual = hidden_embed + hidden_embed = self.final_layer_norm(hidden_embed) + hidden_embed = self.intermediate_layer(hidden_embed) + hidden_embed = residual + hidden_embed + return hidden_embed, attn_weights + +class MultiHeadAttention(torch.nn.Module): + + def __init__( + self, + embed_dim, + num_heads, + attention_dropout_prob: float = 0.0, + bias: bool = True, + ): + super().__init__() + + self.attention_dropout = torch.nn.Dropout(attention_dropout_prob) + + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + assert (self.head_dim * num_heads == self.embed_dim), "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim**-0.5 + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + self.reset_parameters() + + self.rotary_emb = RotaryEmbedding(dim = self.head_dim) + + def 
reset_parameters(self): + + nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) + + nn.init.xavier_uniform_(self.out_proj.weight) + if self.out_proj.bias is not None: + nn.init.constant_(self.out_proj.bias, 0.0) + + def attention(self, q, k, v, attn_mask=None): + + attn_weights = torch.matmul(q, k.transpose(-2, -1)) + attn_weights = attn_weights / math.sqrt(self.head_dim) + + if attn_mask is not None: + attn_mask = einops.rearrange( + attn_mask, + 'b_size (h1 h2 seq_len) -> b_size h1 h2 seq_len', + h1=1, h2=1 + ) + attn_weights = attn_weights.masked_fill(attn_mask, float("-inf")) + + attn_weights = F.softmax(attn_weights, dim=-1) + + attn = self.attention_dropout(attn_weights) + attn = torch.matmul(attn, v) + return attn, attn_weights + + def forward(self, x, attn_mask=None, return_attn_weights: bool = False): + + batch_size, seq_len, embed_dim = x.size() + + q, k, v = self.q_proj(x), self.k_proj(x), self.v_proj(x) + q *= self.scaling + + q = q.contiguous().view( + batch_size, + seq_len, + self.num_heads, + self.head_dim + ).transpose(1, 2) # [n_batch, n_heads, seq_len, head_dim] + k = k.contiguous().view( + batch_size, + seq_len, + self.num_heads, + self.head_dim + ).transpose(1, 2) # [n_batch, n_heads, seq_len, head_dim] + v = v.contiguous().view( + batch_size, + seq_len, + self.num_heads, + self.head_dim + ).transpose(1, 2) # [n_batch, n_heads, seq_len, head_dim] + + q = self.rotary_emb.rotate_queries_or_keys(q) + k = self.rotary_emb.rotate_queries_or_keys(k) + + # Determine value outputs + attn, attn_weights = self.attention( + q, k, v, + attn_mask=attn_mask + ) # attn_weights [n_batch, n_heads, seq_len (target), seq_len (source)] + + attn = attn.transpose(1, 2).reshape(batch_size, seq_len, embed_dim) + attn = self.out_proj(attn) + + if return_attn_weights: + return attn, attn_weights + else: + return attn, 
None + +class SwiGLU(torch.nn.Module): + def forward(self, x): + x, gate = x.chunk(2, dim=-1) + return F.silu(gate) * x + +def get_activation_fn(a_fn): + + if a_fn == "gelu": + return torch.nn.GELU, 1 + + elif a_fn == "swiglu": + return SwiGLU, 2 + \ No newline at end of file diff --git a/ablang2/models/ablang2/tokenizers.py b/ablang2/models/ablang2/tokenizers.py new file mode 100644 index 0000000000000000000000000000000000000000..882ec1fb4c3c431c5a05d864320e76044b4ede38 --- /dev/null +++ b/ablang2/models/ablang2/tokenizers.py @@ -0,0 +1,70 @@ +import json +import torch + +from .vocab import ablang_vocab + +class ABtokenizer: + """ + Tokenizer for the heavy/light chain of antibodies. + """ + + def __init__(self, vocab_dir=None): + self.set_vocab(vocab_dir) + + def __call__(self, sequence_list, mode='encode', pad=False, w_extra_tkns=True, device='cpu'): + + if w_extra_tkns: + sequence_list = [sequence_list] if isinstance(sequence_list[0], str) else sequence_list + else: + sequence_list = [sequence_list] if isinstance(sequence_list, str) else sequence_list + + if mode == 'encode': + data = [self.encode(seq, w_extra_tkns = w_extra_tkns, device = device) for seq in sequence_list] + if pad: return torch.nn.utils.rnn.pad_sequence(data, batch_first=True, padding_value=self.pad_token) + else: return data + elif mode == 'decode': + return [self.decode(tokenized_seq) for tokenized_seq in sequence_list] + else: + raise SyntaxError("Given mode doesn't exist. 
Use either encode or decode.") + + def set_vocab(self, vocab_dir): + + if vocab_dir: + with open(vocab_dir, encoding="utf-8") as vocab_handle: + self.vocab_to_token=json.load(vocab_handle) + else: + self.aa_to_token = ablang_vocab + + self.token_to_aa = {v: k for k, v in self.aa_to_token.items()} + self.pad_token = self.aa_to_token['-'] + self.start_token = self.aa_to_token['<'] + self.end_token = self.aa_to_token['>'] + self.sep_token = self.aa_to_token['|'] + self.mask_token = self.aa_to_token['*'] + self.unknown_token = self.aa_to_token['X'] + self.all_special_tokens = [ + self.pad_token, + self.start_token, + self.end_token, + self.sep_token, + self.mask_token, + self.unknown_token + ] + + def encode(self, sequence, w_extra_tkns=True, device='cpu'): + + if w_extra_tkns: + heavy, light = sequence + sequence = f"<{heavy}>|<{light}>".replace("<>","") + + tokenized_seq = [self.aa_to_token[resn] for resn in sequence] + return torch.tensor(tokenized_seq, dtype=torch.long, device=device) + + def decode(self, tokenized_seq): + + if torch.is_tensor(tokenized_seq): tokenized_seq = tokenized_seq.cpu().numpy() + + return ''.join([self.token_to_aa[token] for token in tokenized_seq]) + + + diff --git a/ablang2/models/ablang2/vocab.py b/ablang2/models/ablang2/vocab.py new file mode 100644 index 0000000000000000000000000000000000000000..76915cf3df11ad07ebfee4b59a0a810396ae9597 --- /dev/null +++ b/ablang2/models/ablang2/vocab.py @@ -0,0 +1,29 @@ + +ablang_vocab = { + "<": 0, # Start token + "-": 21, # Padding token + ">": 22, # End token + "*": 23, # Mask token + "X": 24, # Unknown (residue) token + "|": 25, # Separation (of heavy and light chain) token + "M": 1, + "R": 2, + "H": 3, + "K": 4, + "D": 5, + "E": 6, + "S": 7, + "T": 8, + "N": 9, + "Q": 10, + "C": 11, + "G": 12, + "P": 13, + "A": 14, + "V": 15, + "I": 16, + "F": 17, + "Y": 18, + "W": 19, + "L": 20 +} \ No newline at end of file diff --git a/ablang2/pretrained.py b/ablang2/pretrained.py new file mode 100644 index 
0000000000000000000000000000000000000000..bb52be4688c75de30c067cb7f5ced138b098d957 --- /dev/null +++ b/ablang2/pretrained.py @@ -0,0 +1,99 @@ +import numpy as np +import torch + +from .load_model import load_model +from .pretrained_utils.restoration import AbRestore +from .pretrained_utils.encodings import AbEncoding +from .pretrained_utils.alignment import AbAlignment +from .pretrained_utils.scores import AbScores + +valid_modes = [ + 'rescoding', 'seqcoding', 'restore', 'likelihood', 'probability', + 'pseudo_log_likelihood', 'confidence' +] + + +class pretrained(AbEncoding, AbRestore, AbAlignment, AbScores): + """ + Initializes AbLang for heavy or light chains. + """ + + def __init__(self, model_to_use = "ablang2-paired", random_init = False, ncpu = 1, device = 'cpu'): + super().__init__() + + self.used_device = torch.device(device) + + self.AbLang, self.tokenizer, self.hparams = load_model(model_to_use) + self.AbLang.to(self.used_device) + self.AbLang.eval() # Default + self.AbRep = self.AbLang.AbRep + + self.ncpu = ncpu + self.spread = 11 # Based on get_spread_sequences function + + def freeze(self): + self.AbLang.eval() + + def unfreeze(self): + self.AbLang.train() + + def __call__(self, seqs, mode = 'seqcoding', align = False, stepwise_masking=False, fragmented = False, batch_size = 50): + """ + Use different modes for different usecases + """ + if not mode in valid_modes: raise SyntaxError(f"Given mode doesn't exist. 
Please select one of the following: {valid_modes}.") + + seqs, chain = format_seq_input(seqs, fragmented = fragmented) + + if align: + numbered_seqs, seqs, number_alignment = self.number_sequences( + seqs, chain = chain, fragmented = fragmented + ) + else: + numbered_seqs = None + number_alignment = None + + subset_list = [] + for subset in [seqs[x:x+batch_size] for x in range(0, len(seqs), batch_size)]: + subset_list.append(getattr(self, mode)(subset, align = align, stepwise_masking=stepwise_masking)) + + return self.reformat_subsets( + subset_list, + mode = mode, + align = align, + numbered_seqs = numbered_seqs, + seqs = seqs, + number_alignment = number_alignment, + ) + + +def format_seq_input(seqs, fragmented = False): + """ + Formats input sequences into the correct format for the tokenizer. + """ + if isinstance(seqs[0], str): + seqs = [seqs] + + seqs = [add_extra_tokens(seq) for seq in seqs] + + return seqs, determine_chain(seqs[0]) + + +def add_extra_tokens(seq, fragmented = False): + + heavy, light = seq + + if fragmented: + return f"{heavy}|{light}" + else: + return f"<{heavy}>|<{light}>".replace("<>","") + + +def determine_chain(seq): + h, l = seq.split('|') + + chain = '' + if len(h)>2: chain+='H' + if len(l)>2: chain+='L' + + return chain diff --git a/ablang2/pretrained_utils/__init__.py b/ablang2/pretrained_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ablang2/pretrained_utils/__pycache__/__init__.cpython-310.pyc b/ablang2/pretrained_utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d8e0b41853c0af7a770a0091d76c8f447cf6a0f Binary files /dev/null and b/ablang2/pretrained_utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/ablang2/pretrained_utils/__pycache__/__init__.cpython-312.pyc b/ablang2/pretrained_utils/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..6bd519583cec84e2d3942c4c52ac855d0c781ab4 Binary files /dev/null and b/ablang2/pretrained_utils/__pycache__/__init__.cpython-312.pyc differ diff --git a/ablang2/pretrained_utils/__pycache__/alignment.cpython-310.pyc b/ablang2/pretrained_utils/__pycache__/alignment.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4bde5ddb87600557b8a12cafc968fa513894fcb9 Binary files /dev/null and b/ablang2/pretrained_utils/__pycache__/alignment.cpython-310.pyc differ diff --git a/ablang2/pretrained_utils/__pycache__/alignment.cpython-312.pyc b/ablang2/pretrained_utils/__pycache__/alignment.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8bd099a5bf600a016c0588ab8473606aefef7099 Binary files /dev/null and b/ablang2/pretrained_utils/__pycache__/alignment.cpython-312.pyc differ diff --git a/ablang2/pretrained_utils/__pycache__/encodings.cpython-310.pyc b/ablang2/pretrained_utils/__pycache__/encodings.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b934cc329abe510292002976680a0a078665e894 Binary files /dev/null and b/ablang2/pretrained_utils/__pycache__/encodings.cpython-310.pyc differ diff --git a/ablang2/pretrained_utils/__pycache__/encodings.cpython-312.pyc b/ablang2/pretrained_utils/__pycache__/encodings.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40848c101570b8f516c81665687944b1389af1af Binary files /dev/null and b/ablang2/pretrained_utils/__pycache__/encodings.cpython-312.pyc differ diff --git a/ablang2/pretrained_utils/__pycache__/extra_utils.cpython-310.pyc b/ablang2/pretrained_utils/__pycache__/extra_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..57f8fac7b977806135404614d9c22435b5a867bf Binary files /dev/null and b/ablang2/pretrained_utils/__pycache__/extra_utils.cpython-310.pyc differ diff --git 
a/ablang2/pretrained_utils/__pycache__/extra_utils.cpython-312.pyc b/ablang2/pretrained_utils/__pycache__/extra_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a09bff30f90ea8966296fa60db42c19ba6a17482 Binary files /dev/null and b/ablang2/pretrained_utils/__pycache__/extra_utils.cpython-312.pyc differ diff --git a/ablang2/pretrained_utils/__pycache__/restoration.cpython-310.pyc b/ablang2/pretrained_utils/__pycache__/restoration.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ef1ced41f9cd5d7bca6b6c944b020d43e2033f0 Binary files /dev/null and b/ablang2/pretrained_utils/__pycache__/restoration.cpython-310.pyc differ diff --git a/ablang2/pretrained_utils/__pycache__/restoration.cpython-312.pyc b/ablang2/pretrained_utils/__pycache__/restoration.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..48dc48286108598f60076113ac778515c736eb31 Binary files /dev/null and b/ablang2/pretrained_utils/__pycache__/restoration.cpython-312.pyc differ diff --git a/ablang2/pretrained_utils/__pycache__/scores.cpython-310.pyc b/ablang2/pretrained_utils/__pycache__/scores.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c58f9f1d6863438e21d54520beede96690c618c8 Binary files /dev/null and b/ablang2/pretrained_utils/__pycache__/scores.cpython-310.pyc differ diff --git a/ablang2/pretrained_utils/__pycache__/scores.cpython-312.pyc b/ablang2/pretrained_utils/__pycache__/scores.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3193be329972e2068714b9fe3742f86ffff64238 Binary files /dev/null and b/ablang2/pretrained_utils/__pycache__/scores.cpython-312.pyc differ diff --git a/ablang2/pretrained_utils/alignment.py b/ablang2/pretrained_utils/alignment.py new file mode 100644 index 0000000000000000000000000000000000000000..0d14b9d750509596eea8f2849f383f4434375044 --- /dev/null +++ b/ablang2/pretrained_utils/alignment.py 
@@ -0,0 +1,87 @@ +from dataclasses import dataclass +import numpy as np +import torch + +from .extra_utils import paired_msa_numbering, unpaired_msa_numbering, create_alignment + + +class AbAlignment: + + def __init__(self, device = 'cpu', ncpu = 1): + + self.device = device + self.ncpu = ncpu + + def number_sequences(self, seqs, chain = 'H', fragmented = False): + if chain == 'HL': + numbered_seqs, seqs, number_alignment = paired_msa_numbering(seqs, fragmented = fragmented, n_jobs = self.ncpu) + else: + assert chain == 'HL', 'Currently "Align==True" only works for paired sequences. \nPlease use paired sequences or Align=False.' + numbered_seqs, seqs, number_alignment = unpaired_msa_numbering( + seqs, chain = chain, fragmented = fragmented, n_jobs = self.ncpu + ) + + return numbered_seqs, seqs, number_alignment + + def align_encodings(self, encodings, numbered_seqs, seqs, number_alignment): + + aligned_encodings = np.concatenate( + [[ + create_alignment( + res_embed, numbered_seq, seq, number_alignment + ) for res_embed, numbered_seq, seq in zip(encodings, numbered_seqs, seqs) + ]], axis=0 + ) + return aligned_encodings + + + def reformat_subsets( + self, + subset_list, + mode = 'seqcoding', + align = False, + numbered_seqs = None, + seqs = None, + number_alignment = None, + ): + + if mode in [ + 'seqcoding', + 'restore', + 'pseudo_log_likelihood', + 'confidence' + ]: + return np.concatenate(subset_list) + elif align: + subset_list = [ + self.align_encodings( + subset, + numbered_seqs[num*len(subset):(num+1)*len(subset)], + seqs[num*len(subset):(num+1)*len(subset)], + number_alignment + ) for num, subset in enumerate(subset_list) + ] + + subset = np.concatenate(subset_list) + + return aligned_results( + aligned_seqs = [''.join(alist) for alist in subset[:,:,-1]], + aligned_embeds = subset[:,:,:-1].astype(float), + number_alignment=number_alignment.apply(lambda x: '{}{}'.format(*x[0]), axis=1).values + ) + + elif not align: + return sum(subset_list, []) + else: + 
return np.concatenate(subset_list) # this needs to be changed + + +@dataclass +class aligned_results(): + """ + Dataclass used to store output. + """ + + aligned_seqs: None + aligned_embeds: None + number_alignment: None \ No newline at end of file diff --git a/ablang2/pretrained_utils/encodings.py b/ablang2/pretrained_utils/encodings.py new file mode 100644 index 0000000000000000000000000000000000000000..1946c5df116bb26915c4b552d9e9eb5d38033e04 --- /dev/null +++ b/ablang2/pretrained_utils/encodings.py @@ -0,0 +1,97 @@ +import numpy as np +import torch + +from .extra_utils import res_to_list, res_to_seq + + +class AbEncoding: + + def __init__(self, device = 'cpu', ncpu = 1): + + self.device = device + self.ncpu = ncpu + + def _initiate_abencoding(self, model, tokenizer): + self.AbLang = model + self.tokenizer = tokenizer + + def _encode_sequences(self, seqs): + tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device) + with torch.no_grad(): + return self.AbLang.AbRep(tokens).last_hidden_states + + def _predict_logits(self, seqs): + tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device) + with torch.no_grad(): + return self.AbLang(tokens) + + def _predict_logits_with_step_masking(self, seqs): + + tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device) + + logits = [] + for single_seq_tokens in tokens: + + tkn_len = len(single_seq_tokens) + masked_tokens = single_seq_tokens.repeat(tkn_len, 1) + for num in range(tkn_len): + masked_tokens[num, num] = self.tokenizer.mask_token + + with torch.no_grad(): + logits_tmp = self.AbLang(masked_tokens) + + logits_tmp = torch.stack([logits_tmp[num, num] for num in range(tkn_len)]) + + logits.append(logits_tmp) + + return torch.stack(logits, dim=0) + + def seqcoding(self, seqs, **kwargs): + """ + Sequence specific representations + """ + + encodings = self._encode_sequences(seqs).cpu().numpy() + + lens = np.vectorize(len)(seqs) + lens = 
np.tile(lens.reshape(-1,1,1), (encodings.shape[2], 1)) + + return np.apply_along_axis(res_to_seq, 2, np.c_[np.swapaxes(encodings,1,2), lens]) + + def rescoding(self, seqs, align=False, **kwargs): + """ + Residue specific representations. + """ + encodings = self._encode_sequences(seqs).cpu().numpy() + + if align: return encodings + + else: return [res_to_list(state, seq) for state, seq in zip(encodings, seqs)] + + def likelihood(self, seqs, align=False, stepwise_masking=False, **kwargs): + """ + Likelihood of mutations + """ + if stepwise_masking: + logits = self._predict_logits_with_step_masking(seqs).cpu().numpy() + else: + logits = self._predict_logits(seqs).cpu().numpy() + + if align: return logits + + else: return [res_to_list(state, seq) for state, seq in zip(logits, seqs)] + + def probability(self, seqs, align=False, stepwise_masking=False, **kwargs): + """ + Probability of mutations + """ + if stepwise_masking: + logits = self._predict_logits_with_step_masking(seqs) + else: + logits = self._predict_logits(seqs) + probs = logits.softmax(-1).cpu().numpy() + + if align: return probs + + else: return [res_to_list(state, seq) for state, seq in zip(probs, seqs)] + \ No newline at end of file diff --git a/ablang2/pretrained_utils/extra_utils.py b/ablang2/pretrained_utils/extra_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fca5c5caa944968d673164a1c29c28172b1aa888 --- /dev/null +++ b/ablang2/pretrained_utils/extra_utils.py @@ -0,0 +1,165 @@ +import string, re +import numpy as np + + +def res_to_list(logits, seq): + return logits[:len(seq)] + +def res_to_seq(a, mode='mean'): + """ + Function for how we go from n_values for each amino acid to n_values for each sequence. + + We leave out padding tokens. 
+ """ + + if mode=='sum': + return a[0:(int(a[-1]))].sum() + + elif mode=='mean': + return a[0:(int(a[-1]))].mean() + + elif mode=='restore': + return a[0][0:(int(a[-1]))] + +def get_number_alignment(numbered_seqs): + """ + Creates a number alignment from the anarci results. + """ + import pandas as pd + + alist = [pd.DataFrame(aligned_seq, columns = [0,1,'resi']) for aligned_seq in numbered_seqs] + unsorted_alignment = pd.concat(alist).drop_duplicates(subset=0) + max_alignment = get_max_alignment() + + return max_alignment.merge(unsorted_alignment.query("resi!='-'"), left_on=0, right_on=0)[[0,1]] + +def get_max_alignment(): + """ + Create maximum possible alignment for sorting + """ + import pandas as pd + + sortlist = [[("<", "")]] + for num in range(1, 128+1): + if num in [33,61,112]: + for char in string.ascii_uppercase[::-1]: + sortlist.append([(num, char)]) + + sortlist.append([(num,' ')]) + else: + sortlist.append([(num,' ')]) + for char in string.ascii_uppercase: + sortlist.append([(num, char)]) + + return pd.DataFrame(sortlist + [[(">", "")]]) + + +def paired_msa_numbering(ab_seqs, fragmented = False, n_jobs = 10): + + import pandas as pd + + tmp_seqs = [pairs.replace(">", "").replace("<", "").split("|") for pairs in ab_seqs] + + numbered_seqs_heavy, seqs_heavy, number_alignment_heavy = unpaired_msa_numbering( + [i[0] for i in tmp_seqs], 'H', fragmented = fragmented, n_jobs = n_jobs + ) + numbered_seqs_light, seqs_light, number_alignment_light = unpaired_msa_numbering( + [i[1] for i in tmp_seqs], 'L', fragmented = fragmented, n_jobs = n_jobs + ) + + number_alignment = pd.concat([ + number_alignment_heavy, + pd.DataFrame([[("|",""), "|"]]), + number_alignment_light] + ).reset_index(drop=True) + + seqs = [f"{heavy}|{light}" for heavy, light in zip(seqs_heavy, seqs_light)] + numbered_seqs = [ + heavy + [(("|",""), "|", "|")] + light for heavy, light in zip(numbered_seqs_heavy, numbered_seqs_light) + ] + + return numbered_seqs, seqs, number_alignment + + +def 
unpaired_msa_numbering(seqs, chain = 'H', fragmented = False, n_jobs = 10): + + numbered_seqs = number_with_anarci(seqs, chain = chain, fragmented = fragmented, n_jobs = n_jobs) + number_alignment = get_number_alignment(numbered_seqs) + number_alignment[1] = chain + + seqs = [''.join([i[2] for i in numbered_seq]).replace('-','') for numbered_seq in numbered_seqs] + return numbered_seqs, seqs, number_alignment + + +def number_with_anarci(seqs, chain = 'H', fragmented = False, n_jobs = 1): + + import anarci + import pandas as pd + + anarci_out = anarci.run_anarci( + pd.DataFrame(seqs).reset_index().values.tolist(), + ncpu=n_jobs, + scheme='imgt', + allowed_species=['human', 'mouse'], + ) + + numbered_seqs = [] + for onarci in anarci_out[1]: + numbered_seq = [] + for i in onarci[0][0]: + if i[1] != '-': + numbered_seq.append((i[0], chain, i[1])) + + if fragmented: + numbered_seqs.append(numbered_seq) + else: + numbered_seqs.append([(("<",""), chain, "<")] + numbered_seq + [((">",""), chain, ">")]) + + return numbered_seqs + + +def create_alignment(res_embeds, numbered_seqs, seq, number_alignment): + + import pandas as pd + + datadf = pd.DataFrame(numbered_seqs) + sequence_alignment = number_alignment.merge(datadf, how='left', on=[0, 1]).fillna('-')[2] + + idxs = np.where(sequence_alignment.values == '-')[0] + idxs = [idx-num for num, idx in enumerate(idxs)] + + aligned_embeds = pd.DataFrame(np.insert(res_embeds[:len(seq)], idxs , 0, axis=0)) + + return pd.concat([aligned_embeds, sequence_alignment], axis=1).values + + +def get_spread_sequences(seq, spread, start_position): + """ + Test sequences which are 8 positions shorter (position 10 + max CDR1 gap of 7) up to 2 positions longer (possible insertions). 
+ """ + spread_sequences = [] + + for diff in range(start_position-8, start_position+2+1): + spread_sequences.append('*'*diff+seq) + + return np.array(spread_sequences) + +def get_sequences_from_anarci(out_anarci, max_position, spread): + """ + Ensures correct masking on each side of sequence + """ + + if out_anarci == 'ANARCI_error': + return np.array(['ANARCI-ERR']*spread) + + end_position = int(re.search(r'\d+', out_anarci[::-1]).group()[::-1]) + # Fixes ANARCI error of poor numbering of the CDR1 region + start_position = int(re.search(r'\d+,\s\'.\'\),\s\'[^-]+\'\),\s\(\(\d+,\s\'.\'\),\s\'[^-]+\'\),\s\(\(\d+,\s\'.\'\),\s\'[^-]+\'\),\s\(\(\d+,\s\'.\'\),\s\'[^-]+', + out_anarci).group().split(',')[0]) - 1 + + sequence = "".join(re.findall(r"(?i)[A-Z*]", "".join(re.findall(r'\),\s\'[A-Z*]', out_anarci)))) + + sequence_j = ''.join(sequence).replace('-','').replace('X','*') + '*'*(max_position-int(end_position)) + + return get_spread_sequences(sequence_j, spread, start_position) + diff --git a/ablang2/pretrained_utils/restoration.py b/ablang2/pretrained_utils/restoration.py new file mode 100644 index 0000000000000000000000000000000000000000..99745fd330cf51a43c8c17ad788d5831173d212c --- /dev/null +++ b/ablang2/pretrained_utils/restoration.py @@ -0,0 +1,96 @@ +import numpy as np +import torch + +from .extra_utils import res_to_seq, get_sequences_from_anarci + + +class AbRestore: + def __init__(self, spread = 11, device = 'cpu', ncpu = 1): + self.spread = spread + self.device = device + self.ncpu = ncpu + + def _initiate_abrestore(self, model, tokenizer): + self.AbLang = model + self.tokenizer = tokenizer + + def restore(self, seqs, align = False, **kwargs): + """ + Restore sequences + """ + n_seqs = len(seqs) + + if align: + + seqs = self._sequence_aligning(seqs) + nr_seqs = len(seqs)//self.spread + + tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device) + predictions = self.AbLang(tokens)[:,:,1:21] + + # Reshape + tokens = 
tokens.reshape(nr_seqs, self.spread, -1) + predictions = predictions.reshape(nr_seqs, self.spread, -1, 20) + seqs = seqs.reshape(nr_seqs, -1) + + # Find index of best predictions + best_seq_idx = torch.argmax(torch.max(predictions, -1).values[:,:,1:2].mean(2), -1) + + # Select best predictions + tokens = tokens.gather(1, best_seq_idx.view(-1, 1).unsqueeze(1).repeat(1, 1, tokens.shape[-1])).squeeze(1) + predictions = predictions[range(predictions.shape[0]), best_seq_idx] + seqs = np.take_along_axis(seqs, best_seq_idx.view(-1, 1).cpu().numpy(), axis=1) + + else: + tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device) + predictions = self.AbLang(tokens)[:,:,1:21] + + predicted_tokens = torch.max(predictions, -1).indices + 1 + restored_tokens = torch.where(tokens==23, predicted_tokens, tokens) + + restored_seqs = self.tokenizer(restored_tokens, mode="decode") + + if n_seqs < len(restored_seqs): + restored_seqs = [f"{h}|{l}".replace('-','') for h,l in zip(restored_seqs[:n_seqs], restored_seqs[n_seqs:])] + seqs = [f"{h}|{l}" for h,l in zip(seqs[:n_seqs], seqs[n_seqs:])] + + return np.array([res_to_seq(seq, 'restore') for seq in np.c_[restored_seqs, np.vectorize(len)(seqs)]]) + + def _create_spread_of_sequences(self, seqs, chain = 'H'): + import pandas as pd + import anarci + + chain_idx = 0 if chain == 'H' else 1 + numbered_seqs = anarci.run_anarci( + pd.DataFrame([seq[chain_idx].replace('*', 'X') for seq in seqs]).reset_index().values.tolist(), + ncpu=self.ncpu, + scheme='imgt', + allowed_species=['human', 'mouse'], + ) + + anarci_data = pd.DataFrame( + [str(anarci[0][0]) if anarci else 'ANARCI_error' for anarci in numbered_seqs[1]], + columns=['anarci'] + ).astype('", "").replace("<", "").split("|") for pairs in seqs] + + spread_heavy = [f"<{seq}>" for seq in self._create_spread_of_sequences(tmp_seqs, chain = 'H')] + spread_light = [f"<{seq}>" for seq in self._create_spread_of_sequences(tmp_seqs, chain = 'L')] + + return 
np.concatenate([np.array(spread_heavy),np.array(spread_light)]) \ No newline at end of file diff --git a/ablang2/pretrained_utils/scores.py b/ablang2/pretrained_utils/scores.py new file mode 100644 index 0000000000000000000000000000000000000000..cc5269ed0e3c8981e8babdec93229cba8f71ef83 --- /dev/null +++ b/ablang2/pretrained_utils/scores.py @@ -0,0 +1,98 @@ +import numpy as np +import torch + +from .extra_utils import res_to_list, res_to_seq + + +class AbScores: + + def __init__(self, device = 'cpu', ncpu = 1): + + self.device = device + self.ncpu = ncpu + + def _initiate_abencoding(self, model, tokenizer): + self.AbLang = model + self.tokenizer = tokenizer + + def _encode_sequences(self, seqs): + tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device) + with torch.no_grad(): + return self.AbLang.AbRep(tokens).last_hidden_states.numpy() + + def _predict_logits(self, seqs): + tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device) + with torch.no_grad(): + return self.AbLang(tokens), tokens + + def pseudo_log_likelihood(self, seqs, **kwargs): + """ + Pseudo log likelihood of sequences. 
+ """ + + plls = [] + for seq in seqs: + + labels = self.tokenizer( + seq, pad=True, w_extra_tkns=False, device=self.used_device + ) + + idxs = ( + ~torch.isin(labels, torch.Tensor(self.tokenizer.all_special_tokens).to(self.used_device)) + ).nonzero() + + masked_tokens = labels.repeat(len(idxs), 1) + for num, idx in enumerate(idxs): + masked_tokens[num, idx[1]] = self.tokenizer.mask_token + + with torch.no_grad(): + logits = self.AbLang(masked_tokens) + + logits[:, :, self.tokenizer.all_special_tokens] = -float("inf") + logits = torch.stack([logits[num, idx[1]] for num, idx in enumerate(idxs)]) + + labels = labels[:,idxs[:,1:]].squeeze(2)[0] + + nll = torch.nn.functional.cross_entropy( + logits, + labels, + reduction="mean", + ) + + pll = -nll + + plls.append(pll) + + plls = torch.stack(plls, dim=0).cpu().numpy() + + return plls + + def confidence(self, seqs, **kwargs): + """ + Log likelihood of sequences without masking. + """ + + labels = self.tokenizer( + seqs, pad=True, w_extra_tkns=False, device=self.used_device + ) + with torch.no_grad(): + logits = self.AbLang(labels) + logits[:, :, self.tokenizer.all_special_tokens] = -float("inf") + + plls = [] + for label, logit in zip(labels, logits): + + idxs = ( + ~torch.isin(label, torch.Tensor(self.tokenizer.all_special_tokens).to(self.used_device)) + ).nonzero().squeeze(1) + + nll = torch.nn.functional.cross_entropy( + logit[idxs], + label[idxs], + reduction="mean", + ) + + pll = -nll + plls.append(pll) + + return torch.stack(plls, dim=0).cpu().numpy() \ No newline at end of file diff --git a/ablang2/restoration.py b/ablang2/restoration.py new file mode 100644 index 0000000000000000000000000000000000000000..99745fd330cf51a43c8c17ad788d5831173d212c --- /dev/null +++ b/ablang2/restoration.py @@ -0,0 +1,96 @@ +import numpy as np +import torch + +from .extra_utils import res_to_seq, get_sequences_from_anarci + + +class AbRestore: + def __init__(self, spread = 11, device = 'cpu', ncpu = 1): + self.spread = spread + 
self.device = device + self.ncpu = ncpu + + def _initiate_abrestore(self, model, tokenizer): + self.AbLang = model + self.tokenizer = tokenizer + + def restore(self, seqs, align = False, **kwargs): + """ + Restore sequences + """ + n_seqs = len(seqs) + + if align: + + seqs = self._sequence_aligning(seqs) + nr_seqs = len(seqs)//self.spread + + tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device) + predictions = self.AbLang(tokens)[:,:,1:21] + + # Reshape + tokens = tokens.reshape(nr_seqs, self.spread, -1) + predictions = predictions.reshape(nr_seqs, self.spread, -1, 20) + seqs = seqs.reshape(nr_seqs, -1) + + # Find index of best predictions + best_seq_idx = torch.argmax(torch.max(predictions, -1).values[:,:,1:2].mean(2), -1) + + # Select best predictions + tokens = tokens.gather(1, best_seq_idx.view(-1, 1).unsqueeze(1).repeat(1, 1, tokens.shape[-1])).squeeze(1) + predictions = predictions[range(predictions.shape[0]), best_seq_idx] + seqs = np.take_along_axis(seqs, best_seq_idx.view(-1, 1).cpu().numpy(), axis=1) + + else: + tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device) + predictions = self.AbLang(tokens)[:,:,1:21] + + predicted_tokens = torch.max(predictions, -1).indices + 1 + restored_tokens = torch.where(tokens==23, predicted_tokens, tokens) + + restored_seqs = self.tokenizer(restored_tokens, mode="decode") + + if n_seqs < len(restored_seqs): + restored_seqs = [f"{h}|{l}".replace('-','') for h,l in zip(restored_seqs[:n_seqs], restored_seqs[n_seqs:])] + seqs = [f"{h}|{l}" for h,l in zip(seqs[:n_seqs], seqs[n_seqs:])] + + return np.array([res_to_seq(seq, 'restore') for seq in np.c_[restored_seqs, np.vectorize(len)(seqs)]]) + + def _create_spread_of_sequences(self, seqs, chain = 'H'): + import pandas as pd + import anarci + + chain_idx = 0 if chain == 'H' else 1 + numbered_seqs = anarci.run_anarci( + pd.DataFrame([seq[chain_idx].replace('*', 'X') for seq in seqs]).reset_index().values.tolist(), + 
ncpu=self.ncpu, + scheme='imgt', + allowed_species=['human', 'mouse'], + ) + + anarci_data = pd.DataFrame( + [str(anarci[0][0]) if anarci else 'ANARCI_error' for anarci in numbered_seqs[1]], + columns=['anarci'] + ).astype('", "").replace("<", "").split("|") for pairs in seqs] + + spread_heavy = [f"<{seq}>" for seq in self._create_spread_of_sequences(tmp_seqs, chain = 'H')] + spread_light = [f"<{seq}>" for seq in self._create_spread_of_sequences(tmp_seqs, chain = 'L')] + + return np.concatenate([np.array(spread_heavy),np.array(spread_light)]) \ No newline at end of file diff --git a/ablang2/scores.py b/ablang2/scores.py new file mode 100644 index 0000000000000000000000000000000000000000..cc5269ed0e3c8981e8babdec93229cba8f71ef83 --- /dev/null +++ b/ablang2/scores.py @@ -0,0 +1,98 @@ +import numpy as np +import torch + +from .extra_utils import res_to_list, res_to_seq + + +class AbScores: + + def __init__(self, device = 'cpu', ncpu = 1): + + self.device = device + self.ncpu = ncpu + + def _initiate_abencoding(self, model, tokenizer): + self.AbLang = model + self.tokenizer = tokenizer + + def _encode_sequences(self, seqs): + tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device) + with torch.no_grad(): + return self.AbLang.AbRep(tokens).last_hidden_states.numpy() + + def _predict_logits(self, seqs): + tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device) + with torch.no_grad(): + return self.AbLang(tokens), tokens + + def pseudo_log_likelihood(self, seqs, **kwargs): + """ + Pseudo log likelihood of sequences. 
+ """ + + plls = [] + for seq in seqs: + + labels = self.tokenizer( + seq, pad=True, w_extra_tkns=False, device=self.used_device + ) + + idxs = ( + ~torch.isin(labels, torch.Tensor(self.tokenizer.all_special_tokens).to(self.used_device)) + ).nonzero() + + masked_tokens = labels.repeat(len(idxs), 1) + for num, idx in enumerate(idxs): + masked_tokens[num, idx[1]] = self.tokenizer.mask_token + + with torch.no_grad(): + logits = self.AbLang(masked_tokens) + + logits[:, :, self.tokenizer.all_special_tokens] = -float("inf") + logits = torch.stack([logits[num, idx[1]] for num, idx in enumerate(idxs)]) + + labels = labels[:,idxs[:,1:]].squeeze(2)[0] + + nll = torch.nn.functional.cross_entropy( + logits, + labels, + reduction="mean", + ) + + pll = -nll + + plls.append(pll) + + plls = torch.stack(plls, dim=0).cpu().numpy() + + return plls + + def confidence(self, seqs, **kwargs): + """ + Log likelihood of sequences without masking. + """ + + labels = self.tokenizer( + seqs, pad=True, w_extra_tkns=False, device=self.used_device + ) + with torch.no_grad(): + logits = self.AbLang(labels) + logits[:, :, self.tokenizer.all_special_tokens] = -float("inf") + + plls = [] + for label, logit in zip(labels, logits): + + idxs = ( + ~torch.isin(label, torch.Tensor(self.tokenizer.all_special_tokens).to(self.used_device)) + ).nonzero().squeeze(1) + + nll = torch.nn.functional.cross_entropy( + logit[idxs], + label[idxs], + reduction="mean", + ) + + pll = -nll + plls.append(pll) + + return torch.stack(plls, dim=0).cpu().numpy() \ No newline at end of file diff --git a/ablang2/test_ablang2_HF_implementation.ipynb b/ablang2/test_ablang2_HF_implementation.ipynb new file mode 100755 index 0000000000000000000000000000000000000000..e918bb988145d8d51d161dc46997ad633b940e0b --- /dev/null +++ b/ablang2/test_ablang2_HF_implementation.ipynb @@ -0,0 +1,628 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "458aed0f", + "metadata": {}, + "source": [ + "Note: This notebook is adapted from the 
[AbLang2](https://github.com/TobiasHeOl/AbLang2) model's GitHub repository. It is used to verify that the Hugging Face implementation functions correctly and produces the same output as the original model." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "7ae54cd0-6253-46dd-a316-4f20b12041e0", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np \n", + "from transformers import AutoTokenizer, AutoModel\n", + "from ablang2_paired.adapter import AbLang2PairedHuggingFaceAdapter" + ] + }, + { + "cell_type": "markdown", + "id": "10801511-770d-46ac-a15d-a02d4ef9ec87", + "metadata": {}, + "source": [ + "# **0. Sequence input and its format**\n", + "\n", + "AbLang2 takes as input either the individual heavy variable domain (VH), light variable domain (VL), or the full variable domain (Fv).\n", + "\n", + "Each record (antibody) needs to be a list with the VH as the first element and the VL as the second. If either the VH or VL is not known, leave an empty string.\n", + "\n", + "An asterisk (\\*) is used for masking. It is recommended to mask residues which you are interested in mutating.\n", + "\n", + "**NB:** It is important that the VH and VL sequence is ordered correctly." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "99192978-a008-4a32-a80e-bba238e0ec7c", + "metadata": {}, + "outputs": [], + "source": [ + "seq1 = [\n", + " 'EVQLLESGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCARDVPGHGAAFMDVWGTGTTVTVSS', # VH sequence\n", + " 'DIQLTQSPLSLPVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKISNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK' # VL sequence\n", + "]\n", + "seq2 = [\n", + " 'EVQLLESGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCARDVPGHGAAFMDVWGTGTT',\n", + " 'PVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKISNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK'\n", + "]\n", + "seq3 = [\n", + " 'EVQLLESGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCARDVPGHGAAFMDVWGTGTTVTVSS',\n", + " '' # The VL sequence is not known, so an empty string is left instead. \n", + "]\n", + "seq4 = [\n", + " '',\n", + " 'DIQLTQSPLSLPVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKISNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK'\n", + "]\n", + "seq5 = [\n", + " 'EVQ***SGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCAR**PGHGAAFMDVWGTGTTVTVSS', # (*) is used to mask certain residues\n", + " 'DIQLTQSPLSLPVTLGQPASISCRSS*SLEASDTNIYLSWFQQRPGQSPRRLIYKI*NRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK'\n", + "]\n", + "\n", + "all_seqs = [seq1, seq2, seq3, seq4, seq5]\n", + "only_both_chains_seqs = [seq1, seq2, seq5]" + ] + }, + { + "cell_type": "markdown", + "id": "dffbacfa-8642-4d94-9572-2205a05c18f9", + "metadata": {}, + "source": [ + "# **1. How to use AbLang2**\n", + "\n", + "AbLang2 can be downloaded and used in its raw form as seen below. 
For convenience, we have also developed different \"modes\" which can be used for specific use cases (see Section 2) " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6d66ad84", + "metadata": {}, + "outputs": [], + "source": [ + "model = AutoModel.from_pretrained(\"hemantn/AbLang2/ablang2_paired\", trust_remote_code=True)\n", + "tokenizer = AutoTokenizer.from_pretrained(\"hemantn/AbLang2/ablang2_paired\", trust_remote_code=True)\n", + "ablang = AbLang2PairedHuggingFaceAdapter(model=model, tokenizer=tokenizer)" + ] + }, + { + "cell_type": "markdown", + "id": "48562761-6ebe-4025-be97-918c9f9eff7e", + "metadata": {}, + "source": [ + "# **2. Different modes for specific usecases**\n", + "\n", + "AbLang2 has already been implemented for a variety of different usecases. The benefit of these modes is that they handle extra tokens such as start, stop and separation tokens.\n", + "\n", + "1. seqcoding: Generates sequence representations for each sequence\n", + "2. rescoding: Generates residue representations for each residue in each sequence\n", + "3. likelihood: Generates likelihoods for each amino acid at each position in each sequence\n", + "4. probability: Generates probabilities for each amino acid at each position in each sequence\n", + "5. pseudo_log_likelihood: Returns the pseudo log likelihood for a sequence (based on masking each residue one at a time)\n", + "6. confidence: Returns a fast calculation of the log likelihood for a sequence (based on a single pass with no masking)\n", + "7. restore: Restores masked residues\n", + "\n", + "### **AbLang2 can also align the resulting representations using ANARCI**\n", + "\n", + "This can be done for 'rescoding', 'likelihood', and 'probability'. This is done by setting the argument \"align=True\".\n", + "\n", + "**NB**: Align can only be used on input with the same format, i.e. 
either all heavy, all light, or all both heavy and light.\n", + "\n", + "### **The align argument can also be used to restore variable missing lengths**\n", + "\n", + "For this, use \"align=True\" with the 'restore' mode." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "ceae4a88-0679-4704-8bad-c06a4569c497", + "metadata": {}, + "outputs": [], + "source": [ + "valid_modes = [\n", + " 'seqcoding', 'rescoding', 'likelihood', 'probability',\n", + " 'pseudo_log_likelihood', 'confidence', 'restore' \n", + "]" + ] + }, + { + "cell_type": "markdown", + "id": "aa333732-7508-4826-92ec-3acdd54bc1bb", + "metadata": {}, + "source": [ + "## **seqcoding** \n", + "\n", + "The seqcodings represents each sequence as a 480 sized embedding. It is derived from averaging across each rescoding embedding for a given sequence, including extra tokens. \n", + "\n", + "**NB:** Seqcodings can also be derived in other ways like using the sum or averaging across only parts of the input such as the CDRs. For such cases please use and adapt the below rescoding." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "d22f4302-1262-4cc1-8a1c-a36daa8c710c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[-0.25206311, 0.18189634, 0.00887137, ..., 0.15365517,\n", + " -0.14508603, -0.13381317],\n", + " [-0.25149415, 0.2086455 , 0.07518203, ..., 0.19478269,\n", + " -0.15227772, -0.08241647],\n", + " [-0.27468949, 0.16507216, 0.08667156, ..., 0.18776284,\n", + " -0.14165082, -0.16389885],\n", + " [-0.1982213 , 0.16841085, -0.04925933, ..., 0.11400164,\n", + " -0.14723683, -0.09713171],\n", + " [-0.29553188, 0.17239201, 0.05676926, ..., 0.15943622,\n", + " -0.16615383, -0.15569784]], shape=(5, 480))" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ablang(all_seqs, mode='seqcoding')\n" + ] + }, + { + "cell_type": "markdown", + "id": "4b5d9d60", + "metadata": {}, + "source": [ + "## **rescoding / likelihood / probability**\n", + "\n", + "The rescodings represents each residue as a 480 sized embedding. The likelihoods represents each residue as the predicted logits for each character in the vocabulary. The probabilities represents the normalised likelihoods.\n", + "\n", + "**NB:** The output includes extra tokens (start, stop and separation tokens) in the format \"|\". The length of the output is therefore 5 longer than the VH and VL.\n", + "\n", + "**NB:** By default the representations are derived using a single forward pass. To prevent the predicted likelihood and probability to be affected by the input residue at each position, setting the \"stepwise_masking\" argument to True can be used. This will run a forward pass for each position with the residue at that position masked. This is much slower than running a single forward pass." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "6227f661-575f-4b1e-9646-cfba7b10c3b4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[array([[-0.40741208, -0.5118987 , 0.06096708, ..., 0.3268144 ,\n", + " 0.03920235, -0.36715826],\n", + " [-0.5768883 , 0.38245413, -0.21791998, ..., 0.01250262,\n", + " -0.08844463, -0.32367525],\n", + " [-0.1475935 , 0.39639047, -0.38226923, ..., -0.10119921,\n", + " -0.41469565, -0.00319315],\n", + " ...,\n", + " [-0.14358369, 0.3124389 , -0.30157998, ..., -0.13289244,\n", + " -0.45353398, -0.07878865],\n", + " [ 0.17538925, 0.24394299, 0.20141171, ..., 0.14587352,\n", + " -0.38479003, 0.07409196],\n", + " [-0.23031706, -0.35487285, 0.1960684 , ..., -0.1283362 ,\n", + " 0.31107333, -0.3265108 ]], shape=(238, 480), dtype=float32),\n", + " array([[-0.41981837, -0.3666375 , 0.10595217, ..., 0.3903574 ,\n", + " 0.0382378 , -0.36337993],\n", + " [-0.5054137 , 0.38347068, -0.10992069, ..., -0.05231472,\n", + " -0.13636623, -0.34830108],\n", + " [-0.06784609, 0.69349885, -0.4212398 , ..., -0.24805346,\n", + " -0.39583805, -0.10972726],\n", + " ...,\n", + " [-0.2090099 , 0.29489496, -0.11039071, ..., -0.24245434,\n", + " -0.60625184, -0.02307999],\n", + " [ 0.19134358, 0.21744648, 0.2575827 , ..., 0.15845427,\n", + " -0.34743664, 0.10218249],\n", + " [-0.2551157 , -0.21778448, 0.21906358, ..., -0.09656111,\n", + " 0.22394855, -0.20267345]], shape=(222, 480), dtype=float32),\n", + " array([[-0.40043733, -0.48596814, 0.0886725 , ..., 0.38941646,\n", + " 0.06195956, -0.40999672],\n", + " [-0.54576075, 0.4312959 , -0.3451486 , ..., -0.09285564,\n", + " 0.03116508, -0.45269737],\n", + " [ 0.0221165 , 0.53196615, -0.30137214, ..., -0.1889072 ,\n", + " -0.32587305, 0.05078396],\n", + " ...,\n", + " [ 0.2630385 , -0.22976042, 0.5510368 , ..., 0.47436473,\n", + " -0.42733562, -0.83135855],\n", + " [-0.13752195, 0.28678602, -0.18887053, ..., 0.28262627,\n", + " 0.1254679 , -0.6496486 ],\n", + " 
[-0.4541417 , 0.24564984, 0.2132735 , ..., 0.03287445,\n", + " 0.03825552, -0.34259132]], shape=(124, 480), dtype=float32),\n", + " array([[-0.26863217, 0.32259187, 0.10813517, ..., 0.03953876,\n", + " 0.18312076, -0.00498045],\n", + " [-0.2165424 , -0.38562432, -0.02696264, ..., 0.20541488,\n", + " 0.18698391, -0.22639504],\n", + " [-0.41950518, 0.04743317, 0.0048816 , ..., 0.11408642,\n", + " -0.05384652, 0.1025871 ],\n", + " ...,\n", + " [-0.10960457, 0.35151365, -0.21752454, ..., -0.21448943,\n", + " -0.6396219 , -0.00839792],\n", + " [ 0.20491892, 0.36294487, 0.19217414, ..., 0.07750722,\n", + " -0.5039212 , 0.03793833],\n", + " [-0.11638474, -0.35350856, 0.13215722, ..., -0.1606055 ,\n", + " 0.23913842, -0.2565337 ]], shape=(115, 480), dtype=float32),\n", + " array([[-0.42062947, -0.44009134, 0.00152371, ..., 0.27141467,\n", + " 0.03798106, -0.397461 ],\n", + " [-0.57318133, 0.5258899 , -0.17001636, ..., -0.23864633,\n", + " 0.2088059 , -0.57877594],\n", + " [-0.38988614, 0.46168196, -0.3429413 , ..., -0.14872643,\n", + " -0.46576905, -0.21224979],\n", + " ...,\n", + " [-0.21528634, 0.30046722, -0.25216463, ..., -0.11576828,\n", + " -0.4704907 , -0.0740136 ],\n", + " [ 0.0633081 , 0.22700705, 0.28184187, ..., 0.15967266,\n", + " -0.377182 , 0.06188517],\n", + " [-0.27826303, -0.37297496, 0.21229912, ..., -0.14886017,\n", + " 0.24998347, -0.35954213]], shape=(238, 480), dtype=float32)]" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ablang(all_seqs, mode='rescoding', stepwise_masking = False)" + ] + }, + { + "cell_type": "markdown", + "id": "6da2183b-4306-49bd-a7fc-23e78a23f305", + "metadata": {}, + "source": [ + "## **Align rescoding/likelihood/probability output**\n", + "\n", + "For the 'rescoding', 'likelihood', and 'probability' modes, the output can also be aligned using the argument \"align=True\".\n", + "\n", + "This is done using the antibody numbering tool ANARCI, and requires manually 
installing **Pandas** and **[ANARCI](https://github.com/oxpig/ANARCI)**.\n", + "\n", + "**NB**: Align can only be used on input with the same format, i.e. either all heavy, all light, or all both heavy and light." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "e4bc0cb1-f5b0-4255-9e93-d643ae1396df", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['<' '1 ' '2 ' '3 ' '4 ' '5 ' '6 ' '7 ' '8 ' '9 ' '11 ' '12 ' '13 ' '14 '\n", + " '15 ' '16 ' '17 ' '18 ' '19 ' '20 ' '21 ' '22 ' '23 ' '24 ' '25 ' '26 '\n", + " '27 ' '28 ' '29 ' '30 ' '35 ' '36 ' '37 ' '38 ' '39 ' '40 ' '41 ' '42 '\n", + " '43 ' '44 ' '45 ' '46 ' '47 ' '48 ' '49 ' '50 ' '51 ' '52 ' '53 ' '54 '\n", + " '55 ' '56 ' '57 ' '58 ' '59 ' '62 ' '63 ' '64 ' '65 ' '66 ' '67 ' '68 '\n", + " '69 ' '70 ' '71 ' '72 ' '74 ' '75 ' '76 ' '77 ' '78 ' '79 ' '80 ' '81 '\n", + " '82 ' '83 ' '84 ' '85 ' '86 ' '87 ' '88 ' '89 ' '90 ' '91 ' '92 ' '93 '\n", + " '94 ' '95 ' '96 ' '97 ' '98 ' '99 ' '100 ' '101 ' '102 ' '103 ' '104 '\n", + " '105 ' '106 ' '107 ' '108 ' '109 ' '110 ' '111 ' '112A' '112 ' '113 '\n", + " '114 ' '115 ' '116 ' '117 ' '118 ' '119 ' '120 ' '121 ' '122 ' '123 '\n", + " '124 ' '125 ' '126 ' '127 ' '128 ' '>' '|' '<' '1 ' '2 ' '3 ' '4 ' '5 '\n", + " '6 ' '7 ' '8 ' '9 ' '10 ' '11 ' '12 ' '13 ' '14 ' '15 ' '16 ' '17 ' '18 '\n", + " '19 ' '20 ' '21 ' '22 ' '23 ' '24 ' '25 ' '26 ' '27 ' '28 ' '29 ' '30 '\n", + " '31 ' '32 ' '34 ' '35 ' '36 ' '37 ' '38 ' '39 ' '40 ' '41 ' '42 ' '43 '\n", + " '44 ' '45 ' '46 ' '47 ' '48 ' '49 ' '50 ' '51 ' '52 ' '53 ' '54 ' '55 '\n", + " '56 ' '57 ' '64 ' '65 ' '66 ' '67 ' '68 ' '69 ' '70 ' '71 ' '72 ' '74 '\n", + " '75 ' '76 ' '77 ' '78 ' '79 ' '80 ' '83 ' '84 ' '85 ' '86 ' '87 ' '88 '\n", + " '89 ' '90 ' '91 ' '92 ' '93 ' '94 ' '95 ' '96 ' '97 ' '98 ' '99 ' '100 '\n", + " '101 ' '102 ' '103 ' '104 ' '105 ' '106 ' '107 ' '108 ' '109 ' '114 '\n", + " '115 ' '116 ' '117 ' '118 ' '119 ' '120 ' '121 ' '122 ' '123 
' '124 '\n", + " '125 ' '126 ' '127 ' '>']\n", + "['|', '|<-----------PVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKI-SNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK>', '<------SGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCAR**PGHGAAFMDVWGTGTTVTVSS>|']\n", + "[[[ 9.31622028 -3.42184472 -3.59397936 ... -14.7370739 -6.89358234\n", + " -0.23662642]\n", + " [ -3.54718733 -5.84866858 -4.02423763 ... -12.93967152 -9.56145287\n", + " -4.48474121]\n", + " [-11.94997501 -2.24554634 -5.6948204 ... -15.19639015 -17.97454453\n", + " -12.56952095]\n", + " ...\n", + " [ -8.94504929 -0.42261285 -4.95588112 ... -16.66817665 -15.22247219\n", + " -10.37267971]\n", + " [-11.65150166 -5.4447751 -2.95585871 ... -16.25555801 -9.75158882\n", + " -11.75897121]\n", + " [ 1.79469919 -1.95846868 -3.59784389 ... -14.95585823 -7.47080564\n", + " -0.95226467]]\n", + "\n", + " [[ 8.55518436 -3.83663297 -2.33595777 ... -13.87456703 -8.14840508\n", + " -0.42472187]\n", + " [ -4.40701675 -5.53201342 -3.69396996 ... -12.97878265 -9.86258698\n", + " -4.9541502 ]\n", + " [-11.95642853 -3.86211038 -5.80935287 ... -14.89213085 -16.94556046\n", + " -11.36959743]\n", + " ...\n", + " [ -7.75924206 -0.66524202 -4.08643198 ... -16.16580582 -14.76507473\n", + " -8.35070801]\n", + " [-11.91039753 -4.86995268 -2.74777412 ... -16.07695007 -8.44975281\n", + " -10.45223904]\n", + " [ 0.86007357 -2.37964249 -3.58130407 ... -15.3542347 -7.73035717\n", + " -1.11989975]]\n", + "\n", + " [[ -4.37902927 -7.55587339 1.21958411 ... -15.48622513 -6.02184296\n", + " -3.79647899]\n", + " [ 0. 0. 0. ... 0. 0.\n", + " 0. ]\n", + " [ 0. 0. 0. ... 0. 0.\n", + " 0. ]\n", + " ...\n", + " [ -8.94207764 -0.51090211 -5.09760666 ... -16.69521904 -15.45450783\n", + " -10.50823402]\n", + " [-11.92355156 -5.55152702 -2.87667084 ... -16.40607834 -10.19432163\n", + " -12.13287926]\n", + " [ 2.42200565 -2.01573086 -3.6170125 ... 
-14.9590435 -7.19029188\n", + " -0.89829886]]]\n" + ] + } + ], + "source": [ + "results = ablang(only_both_chains_seqs, mode='likelihood', align=True)\n", + "\n", + "print(results.number_alignment)\n", + "print(results.aligned_seqs)\n", + "print(results.aligned_embeds)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "56be8cad", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[array([[9.9955505e-01, 2.9358694e-06, 2.4716087e-06, ..., 3.5776201e-11,\n", + " 9.1196831e-08, 7.0967326e-05],\n", + " [4.1573694e-06, 4.1619489e-07, 2.5800944e-06, ..., 3.4650952e-10,\n", + " 1.0159109e-08, 1.6279575e-06],\n", + " [7.8059600e-08, 1.2794037e-03, 4.0645118e-05, ..., 3.0375720e-09,\n", + " 1.8879491e-10, 4.2010839e-08],\n", + " ...,\n", + " [3.4210879e-07, 1.7195340e-03, 1.8477240e-05, ..., 1.5137445e-10,\n", + " 6.4255873e-10, 8.2064140e-08],\n", + " [9.1038084e-09, 4.5161755e-06, 5.4411950e-05, ..., 9.1139631e-11,\n", + " 6.0862085e-08, 8.1761966e-09],\n", + " [8.5759175e-04, 2.0104915e-05, 3.9023766e-06, ..., 4.5562460e-11,\n", + " 8.1156479e-08, 5.4990651e-05]], shape=(238, 26), dtype=float32),\n", + " array([[9.9939799e-01, 4.1499175e-06, 1.8611167e-05, ..., 1.8139243e-10,\n", + " 5.5649299e-08, 1.2583815e-04],\n", + " [1.6735513e-06, 5.4332406e-07, 3.4143472e-06, ..., 3.1693398e-10,\n", + " 7.1501400e-09, 9.6832969e-07],\n", + " [3.7784993e-08, 1.2377645e-04, 1.7658784e-05, ..., 2.0061326e-09,\n", + " 2.5737484e-10, 6.7947965e-08],\n", + " ...,\n", + " [1.1050455e-06, 1.3312638e-03, 4.3497097e-05, ..., 2.4686178e-10,\n", + " 1.0018089e-09, 6.1165900e-07],\n", + " [5.7270397e-09, 6.5396339e-06, 5.4601755e-05, ..., 8.8801404e-11,\n", + " 1.8233513e-07, 2.4615032e-08],\n", + " [7.3952030e-04, 2.8970928e-05, 8.7113440e-06, ..., 6.7168833e-11,\n", + " 1.3746008e-07, 1.0210846e-04]], shape=(222, 26), dtype=float32),\n", + " array([[9.99685407e-01, 3.35662639e-06, 1.14241482e-06, ...,\n", + " 2.32460891e-11, 6.88188067e-08, 
5.69467156e-05],\n", + " [6.38133372e-07, 1.01300586e-07, 5.64459742e-06, ...,\n", + " 4.09234556e-11, 2.53804799e-09, 4.31722100e-07],\n", + " [1.49096788e-08, 2.04515047e-04, 9.23794141e-06, ...,\n", + " 7.46306961e-10, 2.92107380e-11, 2.21786500e-08],\n", + " ...,\n", + " [2.15093763e-07, 1.06453872e-03, 1.62486140e-05, ...,\n", + " 1.12102910e-10, 1.47300866e-10, 4.73037538e-08],\n", + " [4.30136682e-09, 3.09317988e-06, 3.96632568e-05, ...,\n", + " 5.24226877e-11, 2.39579450e-08, 3.86403221e-09],\n", + " [9.77773685e-04, 1.29533228e-05, 2.78623725e-06, ...,\n", + " 2.73364300e-11, 3.96418649e-08, 4.04014427e-05]],\n", + " shape=(238, 26), dtype=float32)]" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ablang(only_both_chains_seqs, mode='probability')" + ] + }, + { + "cell_type": "markdown", + "id": "8f0a71ec-e916-4330-90d0-13a4b1121a89", + "metadata": {}, + "source": [ + "## **Pseudo log likelihood and Confidence scores**\n", + "\n", + "The pseudo log likelihood and confidence represents two methods for calculating the uncertainty for the input sequence.\n", + "\n", + "- pseudo_log_likelihood: For each position, the pseudo log likelihood is calculated when predicting the masked residue. The final score is an average across the whole input. This is similar to the approach taken in the ESM-2 paper for calculating pseudo perplexity [(Lin et al., 2023)](https://doi.org/10.1126/science.ade2574).\n", + "\n", + "- confidence: For each position, the log likelihood is calculated without masking the residue. The final score is an average across the whole input. \n", + "\n", + "**NB:** The **confidence is fast** to compute, requiring only a single forward pass per input. 
**Pseudo log likelihood is slow** to calculate, requiring L forward passes per input, where L is the length of the input.\n", + "\n", + "**NB:** It is recommended to use **pseudo log likelihood for final results** and **confidence for exploratory work**." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "83f3064b-48a7-42fb-ba82-ec153ea946da", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([1.96673731, 2.04801253, 2.09881898, 1.82533665, 1.97255249])" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "results = ablang(all_seqs, mode='pseudo_log_likelihood')\n", + "np.exp(-results) # convert to pseudo perplexity" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "42cc8b34-5ae9-4857-93fe-a438a0f2a868", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([1.2636039, 1.126463 , 1.3123759, 1.2140925, 1.1805097],\n", + " dtype=float32)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "results = ablang(all_seqs, mode='confidence')\n", + "np.exp(-results)" + ] + }, + { + "cell_type": "markdown", + "id": "e0b63e48-b2a1-4a8e-8ecb-449748a2cb25", + "metadata": {}, + "source": [ + "## **restore**\n", + "\n", + "This mode can be used to restore masked residues, and fragmented regions with \"align=True\". " + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "2d5b725c-4eac-4a4b-9331-357c3ac140f7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array(['|',\n", + " '|',\n", + " '|'],\n", + " dtype='|',\n", + " '|',\n", + " '|'],\n", + " dtype='|\". The length of the output is therefore 5 longer than the VH and VL.\n", + "\n", + "**NB:** By default the representations are derived using a single forward pass. 
To prevent the predicted likelihood and probability to be affected by the input residue at each position, setting the \"stepwise_masking\" argument to True can be used. This will run a forward pass for each position with the residue at that position masked. This is much slower than running a single forward pass." + ] + }, + { + "cell_type": "markdown", + "id": "b046ae57", + "metadata": {}, + "source": [ + "## **rescoding / likelihood / probability**\n", + "\n", + "The rescodings represents each residue as a 480 sized embedding. The likelihoods represents each residue as the predicted logits for each character in the vocabulary. The probabilities represents the normalised likelihoods.\n", + "\n", + "**NB:** The output includes extra tokens (start, stop and separation tokens) in the format \"|\". The length of the output is therefore 5 longer than the VH and VL.\n", + "\n", + "**NB:** By default the representations are derived using a single forward pass. To prevent the predicted likelihood and probability to be affected by the input residue at each position, setting the \"stepwise_masking\" argument to True can be used. This will run a forward pass for each position with the residue at that position masked. This is much slower than running a single forward pass." + ] + }, + { + "cell_type": "markdown", + "id": "78ccf7d8", + "metadata": {}, + "source": [ + "## **rescoding / likelihood / probability**\n", + "\n", + "The rescodings represents each residue as a 480 sized embedding. The likelihoods represents each residue as the predicted logits for each character in the vocabulary. The probabilities represents the normalised likelihoods.\n", + "\n", + "**NB:** The output includes extra tokens (start, stop and separation tokens) in the format \"|\". The length of the output is therefore 5 longer than the VH and VL.\n", + "\n", + "**NB:** By default the representations are derived using a single forward pass. 
To prevent the predicted likelihood and probability to be affected by the input residue at each position, setting the \"stepwise_masking\" argument to True can be used. This will run a forward pass for each position with the residue at that position masked. This is much slower than running a single forward pass." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "lib_transformer", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.18" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/ablang2/tokenizer_ablang2paired.py b/ablang2/tokenizer_ablang2paired.py new file mode 100644 index 0000000000000000000000000000000000000000..5522bbaa0d134f3a1e490d6e58114f1a08a05a9c --- /dev/null +++ b/ablang2/tokenizer_ablang2paired.py @@ -0,0 +1,97 @@ +import json +import os +from transformers import PreTrainedTokenizer + + +class AbLang2PairedTokenizer(PreTrainedTokenizer): + vocab_files_names = {"vocab_file": "vocab.json"} + model_input_names = ["input_ids"] + + def __init__(self, vocab_file=None, **kwargs): + if vocab_file is None: + # Try to find vocab file in the current directory + vocab_file = "vocab.json" + + self.vocab_file = vocab_file + with open(vocab_file, "r", encoding="utf-8") as f: + self.vocab = json.load(f) + + # Set required token attributes (all as strings, standard for HF) + kwargs.setdefault("pad_token", "-") + kwargs.setdefault("mask_token", "*") + kwargs.setdefault("unk_token", "X") + + super().__init__(**kwargs) + + @property + def pad_token_id(self): + return self.vocab[self.pad_token] + + @property + def mask_token_id(self): + return self.vocab[self.mask_token] + + def _tokenize(self, text): + return list(text) + + def tokenize(self, text, text_pair=None, **kwargs): + """Tokenize text or 
text pair.""" + if text_pair is not None: + # For paired sequences, combine them with a separator + combined_text = text + "|" + text_pair + return self._tokenize(combined_text) + else: + return self._tokenize(text) + + def _convert_token_to_id(self, token): + return self.vocab.get(token, self.vocab[self.unk_token]) + + def _convert_id_to_token(self, index): + inv_vocab = {v: k for k, v in self.vocab.items()} + return inv_vocab.get(index, self.unk_token) + + def get_vocab(self): + return self.vocab + + def save_vocabulary(self, save_directory, filename_prefix=None): + os.makedirs(save_directory, exist_ok=True) + path = os.path.join(save_directory, (filename_prefix or "") + "vocab.json") + with open(path, "w", encoding="utf-8") as f: + json.dump(self.vocab, f) + return (path,) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + vocab_file = os.path.join(pretrained_model_name_or_path, "vocab.json") + if not os.path.exists(vocab_file): + raise ValueError(f"Vocabulary file {vocab_file} not found") + return cls(vocab_file=vocab_file, **kwargs) + + def save_pretrained(self, save_directory, filename_prefix=None): + os.makedirs(save_directory, exist_ok=True) + vocab_files = self.save_vocabulary(save_directory, filename_prefix) + + tokenizer_config = { + "tokenizer_class": f"{self.__class__.__module__}.{self.__class__.__name__}" + } + with open(os.path.join(save_directory, "tokenizer_config.json"), "w") as f: + json.dump(tokenizer_config, f, indent=2) + + return vocab_files + + def __call__(self, sequences, padding=False, return_tensors=None, **kwargs): + # Accepts a string or a list of strings + if isinstance(sequences, str): + sequences = [sequences] + # Tokenize each sequence + input_ids = [[self._convert_token_to_id(tok) for tok in self._tokenize(seq)] for seq in sequences] + # Padding + if padding: + maxlen = max(len(ids) for ids in input_ids) + input_ids = [ids + [self.pad_token_id] * (maxlen - len(ids)) for ids in input_ids] + # 
Return tensors if requested + if return_tensors == 'pt': + import torch + input_ids = torch.tensor(input_ids) + return {'input_ids': input_ids} + diff --git a/ablang2/vocab.json b/ablang2/vocab.json new file mode 100644 index 0000000000000000000000000000000000000000..647358ba2a1b3199ed1c483af2bfaf4ccf844baa --- /dev/null +++ b/ablang2/vocab.json @@ -0,0 +1,28 @@ +{ + "<": 0, + "M": 1, + "R": 2, + "H": 3, + "K": 4, + "D": 5, + "E": 6, + "S": 7, + "T": 8, + "N": 9, + "Q": 10, + "C": 11, + "G": 12, + "P": 13, + "A": 14, + "V": 15, + "I": 16, + "F": 17, + "Y": 18, + "W": 19, + "L": 20, + "-": 21, + ">": 22, + "*": 23, + "X": 24, + "|": 25 +} \ No newline at end of file diff --git a/ablang2/vocab.py b/ablang2/vocab.py new file mode 100644 index 0000000000000000000000000000000000000000..25ad702a2ee8c04a6c8cde038a27bb7ce18b2d14 --- /dev/null +++ b/ablang2/vocab.py @@ -0,0 +1,28 @@ +ablang_vocab = { + "<": 0, + "M": 1, + "R": 2, + "H": 3, + "K": 4, + "D": 5, + "E": 6, + "S": 7, + "T": 8, + "N": 9, + "Q": 10, + "C": 11, + "G": 12, + "P": 13, + "A": 14, + "V": 15, + "I": 16, + "F": 17, + "Y": 18, + "W": 19, + "L": 20, + "-": 21, + ">": 22, + "*": 23, + "X": 24, + "|": 25 +} \ No newline at end of file diff --git a/adapter.py b/adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..fd8fe459cdd6a90c00880b9b84dfc0064b4aa466 --- /dev/null +++ b/adapter.py @@ -0,0 +1,306 @@ +from ablang2.pretrained_utils.restoration import AbRestore +from ablang2.pretrained_utils.encodings import AbEncoding +from ablang2.pretrained_utils.alignment import AbAlignment +from ablang2.pretrained_utils.scores import AbScores +import torch +import numpy as np +from ablang2.pretrained_utils.extra_utils import res_to_seq, res_to_list + +class HuggingFaceTokenizerAdapter: + def __init__(self, tokenizer, device): + self.tokenizer = tokenizer + self.device = device + self.pad_token_id = tokenizer.pad_token_id + self.mask_token_id = getattr(tokenizer, 'mask_token_id', None) or 
tokenizer.convert_tokens_to_ids(tokenizer.mask_token) + self.vocab = tokenizer.get_vocab() if hasattr(tokenizer, 'get_vocab') else tokenizer.vocab + self.inv_vocab = {v: k for k, v in self.vocab.items()} + self.all_special_tokens = tokenizer.all_special_tokens + + def __call__(self, seqs, pad=True, w_extra_tkns=False, device=None, mode=None): + tokens = self.tokenizer(seqs, padding=True, return_tensors='pt') + input_ids = tokens['input_ids'].to(self.device if device is None else device) + if mode == 'decode': + # seqs is a tensor of token ids + if isinstance(seqs, torch.Tensor): + seqs = seqs.cpu().numpy() + decoded = [] + for i, seq in enumerate(seqs): + chars = [self.inv_vocab.get(int(t), '') for t in seq if self.inv_vocab.get(int(t), '') not in {'-', '*', '<', '>'} and self.inv_vocab.get(int(t), '') != ''] + # Use res_to_seq for formatting, pass (sequence, length) tuple as in original code + # The length is not always available, so use len(chars) as fallback + formatted = res_to_seq([ ''.join(chars), len(chars) ], mode='restore') + decoded.append(formatted) + return decoded + return input_ids + +class HFAbRestore(AbRestore): + def __init__(self, hf_model, hf_tokenizer, spread=11, device='cpu', ncpu=1): + super().__init__(spread=spread, device=device, ncpu=ncpu) + self.used_device = device + self._hf_model = hf_model + self.tokenizer = HuggingFaceTokenizerAdapter(hf_tokenizer, device) + + @property + def AbLang(self): + def model_call(x): + output = self._hf_model(x) + if hasattr(output, 'last_hidden_state'): + return output.last_hidden_state + return output + return model_call + +def add_angle_brackets(seq): + # Assumes input is 'VH|VL' or 'VH|' or '|VL' + if '|' in seq: + vh, vl = seq.split('|', 1) + else: + vh, vl = seq, '' + return f"<{vh}>|<{vl}>" + +class AbLang2PairedHuggingFaceAdapter(AbEncoding, AbRestore, AbAlignment, AbScores): + """ + Adapter to use pretrained utilities with a HuggingFace-loaded ablang2_paired model and tokenizer. 
+ Automatically uses CUDA if available, otherwise CPU. + """ + def __init__(self, model, tokenizer, device=None, ncpu=1): + super().__init__() + if device is None: + self.used_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + else: + self.used_device = torch.device(device) + self.AbLang = model # HuggingFace model instance + self.tokenizer = tokenizer + self.AbLang.to(self.used_device) + self.AbLang.eval() + # Always get AbRep from the underlying model + if hasattr(self.AbLang, 'model') and hasattr(self.AbLang.model, 'AbRep'): + self.AbRep = self.AbLang.model.AbRep + else: + raise AttributeError("Could not find AbRep in the HuggingFace model or its underlying model.") + self.ncpu = ncpu + self.spread = 11 # For compatibility with original utilities + # The following is no longer needed since all_special_tokens now returns IDs directly + # self.tokenizer.all_special_token_ids = [ + # self.tokenizer.convert_tokens_to_ids(tok) for tok in self.tokenizer.all_special_tokens + # ] + # self.tokenizer._all_special_tokens_str = self.tokenizer.all_special_tokens + # self.tokenizer.all_special_tokens = [ + # self.tokenizer.convert_tokens_to_ids(tok) for tok in self.tokenizer._all_special_tokens_str + # ] + + def freeze(self): + self.AbLang.eval() + + def unfreeze(self): + self.AbLang.train() + + def _encode_sequences(self, seqs): + # Use HuggingFace-style padding and return PyTorch tensors + tokens = self.tokenizer(seqs, padding=True, return_tensors='pt') + tokens = extract_input_ids(tokens, self.used_device) + return self.AbRep(tokens).last_hidden_states.detach() + + def _predict_logits(self, seqs): + tokens = self.tokenizer(seqs, padding=True, return_tensors='pt') + tokens = extract_input_ids(tokens, self.used_device) + output = self.AbLang(tokens) + if hasattr(output, 'last_hidden_state'): + return output.last_hidden_state.detach() + return output.detach() + + def _preprocess_labels(self, labels): + labels = extract_input_ids(labels, self.used_device) 
+ return labels + + def __call__(self, seqs, mode='seqcoding', align=False, stepwise_masking=False, fragmented=False, batch_size=50): + """ + Use different modes for different usecases, mimicking the original pretrained class. + """ + from ablang2.pretrained import format_seq_input + + valid_modes = [ + 'rescoding', 'seqcoding', 'restore', 'likelihood', 'probability', + 'pseudo_log_likelihood', 'confidence' + ] + if mode not in valid_modes: + raise SyntaxError(f"Given mode doesn't exist. Please select one of the following: {valid_modes}.") + + seqs, chain = format_seq_input(seqs, fragmented=fragmented) + + if align: + numbered_seqs, seqs, number_alignment = self.number_sequences( + seqs, chain=chain, fragmented=fragmented + ) + else: + numbered_seqs = None + number_alignment = None + + subset_list = [] + for subset in [seqs[x:x+batch_size] for x in range(0, len(seqs), batch_size)]: + subset_list.append(getattr(self, mode)(subset, align=align, stepwise_masking=stepwise_masking)) + + return self.reformat_subsets( + subset_list, + mode=mode, + align=align, + numbered_seqs=numbered_seqs, + seqs=seqs, + number_alignment=number_alignment, + ) + + def pseudo_log_likelihood(self, seqs, **kwargs): + """ + Original (non-vectorized) pseudo log-likelihood computation matching notebook behavior. 
+ """ + # Format input: join VH and VL with '|' + formatted_seqs = [] + for s in seqs: + if isinstance(s, (list, tuple)): + formatted_seqs.append('|'.join(s)) + else: + formatted_seqs.append(s) + + # Tokenize all sequences in batch + labels = self.tokenizer( + formatted_seqs, padding=True, return_tensors='pt' + ) + labels = extract_input_ids(labels, self.used_device) + + # Convert special tokens to IDs + if isinstance(self.tokenizer.all_special_tokens[0], int): + special_token_ids = set(self.tokenizer.all_special_tokens) + else: + special_token_ids = set(self.tokenizer.convert_tokens_to_ids(tok) for tok in self.tokenizer.all_special_tokens) + pad_token_id = self.tokenizer.pad_token_id + + mask_token_id = getattr(self.tokenizer, 'mask_token_id', None) + if mask_token_id is None: + mask_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token) + + plls = [] + with torch.no_grad(): + for i, seq_label in enumerate(labels): + seq_pll = [] + for j, token_id in enumerate(seq_label): + if token_id.item() in special_token_ids or token_id.item() == pad_token_id: + continue + masked = seq_label.clone() + masked[j] = mask_token_id + logits = self.AbLang(masked.unsqueeze(0)) + if hasattr(logits, 'last_hidden_state'): + logits = logits.last_hidden_state + logits = logits[0, j] + nll = torch.nn.functional.cross_entropy( + logits.unsqueeze(0), token_id.unsqueeze(0), reduction="none" + ) + seq_pll.append(-nll.item()) + if seq_pll: + plls.append(np.mean(seq_pll)) + else: + plls.append(float('nan')) + return np.array(plls) + + def confidence(self, seqs, **kwargs): + """Confidence calculation - match original ablang2 implementation by excluding all special tokens from loss.""" + # Format input: join VH and VL with '|' + formatted_seqs = [] + for s in seqs: + if isinstance(s, (list, tuple)): + formatted_seqs.append('|'.join(s)) + else: + formatted_seqs.append(s) + + plls = [] + for seq in formatted_seqs: + tokens = self.tokenizer([seq], padding=True, 
return_tensors='pt') + input_ids = extract_input_ids(tokens, self.used_device) + + with torch.no_grad(): + output = self.AbLang(input_ids) + if hasattr(output, 'last_hidden_state'): + logits = output.last_hidden_state + else: + logits = output + + # Get the sequence (remove batch dimension) + logits = logits[0] # [seq_len, vocab_size] + input_ids = input_ids[0] # [seq_len] + + # Exclude all special tokens (pad, mask, etc.) + if isinstance(self.tokenizer.all_special_tokens[0], int): + special_token_ids = set(self.tokenizer.all_special_tokens) + else: + special_token_ids = set(self.tokenizer.convert_tokens_to_ids(tok) for tok in self.tokenizer.all_special_tokens) + valid_mask = ~torch.isin(input_ids, torch.tensor(list(special_token_ids), device=input_ids.device)) + + if valid_mask.sum() > 0: + valid_logits = logits[valid_mask] + valid_labels = input_ids[valid_mask] + + # Calculate cross-entropy loss + nll = torch.nn.functional.cross_entropy( + valid_logits, + valid_labels, + reduction="mean" + ) + pll = -nll.item() + else: + pll = 0.0 + + plls.append(pll) + + return np.array(plls, dtype=np.float32) + + def probability(self, seqs, align=False, stepwise_masking=False, **kwargs): + """ + Probability of mutations - applies softmax to logits to get probabilities + """ + # Format input: join VH and VL with '|' + formatted_seqs = [] + for s in seqs: + if isinstance(s, (list, tuple)): + formatted_seqs.append('|'.join(s)) + else: + formatted_seqs.append(s) + + # Get logits + if stepwise_masking: + # For stepwise masking, we need to implement it similar to likelihood + # This is a simplified version - you might want to implement full stepwise masking + logits = self._predict_logits(formatted_seqs) + else: + logits = self._predict_logits(formatted_seqs) + + # Apply softmax to get probabilities + probs = logits.softmax(-1).cpu().numpy() + + if align: + return probs + else: + # Return residue-level probabilities (excluding special tokens) + return [res_to_list(state, seq) for 
state, seq in zip(probs, formatted_seqs)] + + def restore(self, seqs, align=False, **kwargs): + hf_abrestore = HFAbRestore(self.AbLang, self.tokenizer, spread=self.spread, device=self.used_device, ncpu=self.ncpu) + restored = hf_abrestore.restore(seqs, align=align) + # Apply angle brackets formatting + if isinstance(restored, np.ndarray): + restored = np.array([add_angle_brackets(seq) for seq in restored]) + else: + restored = [add_angle_brackets(seq) for seq in restored] + return restored + +def extract_input_ids(tokens, device): + if hasattr(tokens, 'input_ids'): + return tokens.input_ids.to(device) + elif isinstance(tokens, dict): + if 'input_ids' in tokens: + return tokens['input_ids'].to(device) + else: + for v in tokens.values(): + if hasattr(v, 'ndim') or torch.is_tensor(v): + return v.to(device) + elif torch.is_tensor(tokens): + return tokens.to(device) + else: + raise ValueError("Could not extract input_ids from tokenizer output") \ No newline at end of file diff --git a/config.json b/config.json new file mode 100644 index 0000000000000000000000000000000000000000..53db5301093e416c4c3781509e63272d840b9615 --- /dev/null +++ b/config.json @@ -0,0 +1,18 @@ +{ + "model_type": "ablang2-paired", + "vocab_size": 26, + "hidden_embed_size": 480, + "n_attn_heads": 20, + "n_encoder_blocks": 12, + "padding_tkn": 21, + "mask_tkn": 23, + "layer_norm_eps": 1e-12, + "a_fn": "swiglu", + "dropout": 0.0, + "tokenizer_class": "AbLang2PairedTokenizer", + "auto_map": { + "AutoConfig": "configuration_ablang2paired.AbLang2PairedConfig", + "AutoModel": "modeling_ablang2paired.AbLang2PairedHFModel", + "AutoTokenizer": ["tokenizer_ablang2paired.AbLang2PairedTokenizer", "tokenizer_ablang2paired.AbLang2PairedTokenizer"] + } +} diff --git a/configuration_ablang2paired.py b/configuration_ablang2paired.py new file mode 100644 index 0000000000000000000000000000000000000000..844e53b7a2c748fcb423e2d0fbc3fc15ec7faeb4 --- /dev/null +++ b/configuration_ablang2paired.py @@ -0,0 +1,31 @@ +from 
transformers import PretrainedConfig + +class AbLang2PairedConfig(PretrainedConfig): + model_type = "ablang2-paired" + + def __init__( + self, + vocab_size=26, + hidden_embed_size=480, + n_attn_heads=20, + n_encoder_blocks=12, + padding_tkn=21, + mask_tkn=23, + layer_norm_eps=1e-12, + a_fn="swiglu", + dropout=0.0, + **kwargs + ): + super().__init__(**kwargs) + self.vocab_size = vocab_size + self.hidden_embed_size = hidden_embed_size + self.hidden_size = hidden_embed_size # Add this for Hugging Face compatibility + self.n_attn_heads = n_attn_heads + self.num_attention_heads = n_attn_heads # Add this for Hugging Face compatibility + self.num_hidden_layers = n_encoder_blocks # Add this for Hugging Face compatibility + self.n_encoder_blocks = n_encoder_blocks + self.padding_tkn = padding_tkn + self.mask_tkn = mask_tkn + self.layer_norm_eps = layer_norm_eps + self.a_fn = a_fn + self.dropout = dropout \ No newline at end of file diff --git a/environment.yaml b/environment.yaml new file mode 100644 index 0000000000000000000000000000000000000000..47b3456756d77a882f16e17885d97510929f36cf --- /dev/null +++ b/environment.yaml @@ -0,0 +1,44 @@ +name: AbLang +channels: + - conda-forge + - pytorch + - bioconda + - defaults +dependencies: + - python=3.10.18 + - pip + - pytorch=2.5.1 + - pytorch-cuda=12.4 + - numpy=2.2.6 + - pandas=2.3.1 + - transformers=4.53.3 + - anarci=2024.05.21 + - jupyter=7.4.4 + - notebook=7.4.4 + - ipython=8.37.0 + - ipykernel=6.29.5 + - matplotlib-inline=0.1.7 + - scikit-learn + - matplotlib + - seaborn + - biopython=1.85 + - huggingface_hub=0.33.4 + - tokenizers=0.21.3 + - safetensors=0.5.3 + - einops=0.8.1 + - tqdm=4.67.1 + - requests=2.32.4 + - urllib3=2.5.0 + - certifi=2025.7.14 + - filelock=3.18.0 + - fsspec=2025.3.0 + - packaging=25.0 + - regex=2024.11.6 + - sympy=1.13.3 + - networkx=3.4.2 + - jinja2=3.1.6 + - pyyaml=6.0.2 + - typing_extensions=4.14.1 + - pip: + - numba=0.61.2 + - llvmlite=0.44.0 \ No newline at end of file diff --git a/hparams.json 
class AbLang2PairedModelOutput:
    """Lightweight output container exposing ``last_hidden_state``.

    Defined at module level (rather than inside ``forward``) so the class
    object is created once instead of on every forward pass, and
    ``isinstance`` checks on outputs from different calls behave
    consistently.
    """

    def __init__(self, last_hidden_state):
        self.last_hidden_state = last_hidden_state


class AbLang2PairedHFModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the original AbLang2
    paired-chain antibody language model.

    Accepts inputs under either the Hugging Face name (``input_ids``) or the
    original AbLang2 name (``x``), and supports loading custom weights from
    a ``model.pt`` file alongside the standard HF loading path.
    """

    config_class = AbLang2PairedConfig
    model_type = "ablang2-paired"

    def __init__(self, config: AbLang2PairedConfig):
        super().__init__(config)
        # Build the underlying AbLang2 network from the HF config values.
        self.model = AbLang2(
            vocab_size=config.vocab_size,
            hidden_embed_size=config.hidden_embed_size,
            n_attn_heads=config.n_attn_heads,
            n_encoder_blocks=config.n_encoder_blocks,
            padding_tkn=config.padding_tkn,
            mask_tkn=config.mask_tkn,
            layer_norm_eps=config.layer_norm_eps,
            a_fn=config.a_fn,
            dropout=config.dropout,
        )

    def forward(self, input_ids=None, x=None, attention_mask=None, **kwargs):
        """Run a forward pass through the wrapped AbLang2 model.

        Args:
            input_ids: token ids in Hugging Face convention. Takes priority
                over ``x`` when both are supplied.
            x: token ids in the original AbLang2 convention.
            attention_mask: optional mask forwarded to the underlying model.

        Returns:
            AbLang2PairedModelOutput: object with a ``last_hidden_state``
            attribute holding the model output.

        Raises:
            ValueError: if neither ``input_ids`` nor ``x`` is provided.
        """
        # Handle both Hugging Face format (input_ids) and original format (x).
        if input_ids is not None:
            x = input_ids
        elif x is None:
            raise ValueError("Either input_ids or x must be provided")

        output = self.model(x, attention_mask)
        return AbLang2PairedModelOutput(output)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load the model, preferring a local ``model.pt`` state dict.

        When ``pretrained_model_name_or_path`` is a local directory that
        contains ``model.pt``, the raw AbLang2 state dict is loaded from it;
        otherwise this falls back to the standard Hugging Face loading path
        (e.g. for Hub model ids, where ``os.path.exists`` is False).
        """
        model_path = pretrained_model_name_or_path
        custom_weights_path = os.path.join(model_path, "model.pt")

        if not os.path.exists(custom_weights_path):
            # No custom weights present: standard Hugging Face loading.
            return super().from_pretrained(
                pretrained_model_name_or_path, *model_args, **kwargs
            )

        config = kwargs.get("config")
        if config is None:
            from transformers import AutoConfig

            config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

        # Create the model from config only, then splice in the raw weights.
        model = cls(config)
        # weights_only=True avoids unpickling arbitrary objects from model.pt.
        state_dict = torch.load(
            custom_weights_path, map_location="cpu", weights_only=True
        )
        model.model.load_state_dict(state_dict)

        # Move to the requested device, defaulting to GPU when available.
        device = kwargs.get("device", None)
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        return model.to(device)

    def save_pretrained(self, save_directory, **kwargs):
        """Save custom weights (``model.pt``), the config, and standard HF files."""
        os.makedirs(save_directory, exist_ok=True)
        # Save the raw AbLang2 state dict for the custom loading path above.
        torch.save(self.model.state_dict(), os.path.join(save_directory, "model.pt"))
        # Save config.
        self.config.save_pretrained(save_directory)
        # Parent also writes the standard HF artifacts (e.g. safetensors weights),
        # so the directory remains loadable via plain AutoModel.from_pretrained.
        super().save_pretrained(save_directory, **kwargs)
b/test_ablang2_HF_implementation.ipynb @@ -0,0 +1,628 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "458aed0f", + "metadata": {}, + "source": [ + "Note: This notebook is adapted from the [AbLang2](https://github.com/TobiasHeOl/AbLang2) model's GitHub repository. It is used to verify that the Hugging Face implementation functions correctly and produces the same output as the original model." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "7ae54cd0-6253-46dd-a316-4f20b12041e0", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np \n", + "from transformers import AutoTokenizer, AutoModel\n", + "from ablang2.adapter import AbLang2PairedHuggingFaceAdapter" + ] + }, + { + "cell_type": "markdown", + "id": "10801511-770d-46ac-a15d-a02d4ef9ec87", + "metadata": {}, + "source": [ + "# **0. Sequence input and its format**\n", + "\n", + "AbLang2 takes as input either the individual heavy variable domain (VH), light variable domain (VL), or the full variable domain (Fv).\n", + "\n", + "Each record (antibody) needs to be a list with the VH as the first element and the VL as the second. If either the VH or VL is not known, leave an empty string.\n", + "\n", + "An asterisk (\\*) is used for masking. It is recommended to mask residues which you are interested in mutating.\n", + "\n", + "**NB:** It is important that the VH and VL sequence is ordered correctly." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "99192978-a008-4a32-a80e-bba238e0ec7c", + "metadata": {}, + "outputs": [], + "source": [ + "seq1 = [\n", + " 'EVQLLESGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCARDVPGHGAAFMDVWGTGTTVTVSS', # VH sequence\n", + " 'DIQLTQSPLSLPVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKISNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK' # VL sequence\n", + "]\n", + "seq2 = [\n", + " 'EVQLLESGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCARDVPGHGAAFMDVWGTGTT',\n", + " 'PVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKISNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK'\n", + "]\n", + "seq3 = [\n", + " 'EVQLLESGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCARDVPGHGAAFMDVWGTGTTVTVSS',\n", + " '' # The VL sequence is not known, so an empty string is left instead. \n", + "]\n", + "seq4 = [\n", + " '',\n", + " 'DIQLTQSPLSLPVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKISNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK'\n", + "]\n", + "seq5 = [\n", + " 'EVQ***SGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCAR**PGHGAAFMDVWGTGTTVTVSS', # (*) is used to mask certain residues\n", + " 'DIQLTQSPLSLPVTLGQPASISCRSS*SLEASDTNIYLSWFQQRPGQSPRRLIYKI*NRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK'\n", + "]\n", + "\n", + "all_seqs = [seq1, seq2, seq3, seq4, seq5]\n", + "only_both_chains_seqs = [seq1, seq2, seq5]" + ] + }, + { + "cell_type": "markdown", + "id": "dffbacfa-8642-4d94-9572-2205a05c18f9", + "metadata": {}, + "source": [ + "# **1. How to use AbLang2**\n", + "\n", + "AbLang2 can be downloaded and used in its raw form as seen below. 
For convenience, we have also developed different \"modes\" which can be used for specific use cases (see Section 2) " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6d66ad84", + "metadata": {}, + "outputs": [], + "source": [ + "model = AutoModel.from_pretrained(\"hemantn/ablang2\", trust_remote_code=True)\n", + "tokenizer = AutoTokenizer.from_pretrained(\"hemantn/ablang2\", trust_remote_code=True)\n", + "ablang = AbLang2PairedHuggingFaceAdapter(model=model, tokenizer=tokenizer)" + ] + }, + { + "cell_type": "markdown", + "id": "48562761-6ebe-4025-be97-918c9f9eff7e", + "metadata": {}, + "source": [ + "# **2. Different modes for specific usecases**\n", + "\n", + "AbLang2 has already been implemented for a variety of different usecases. The benefit of these modes is that they handle extra tokens such as start, stop and separation tokens.\n", + "\n", + "1. seqcoding: Generates sequence representations for each sequence\n", + "2. rescoding: Generates residue representations for each residue in each sequence\n", + "3. likelihood: Generates likelihoods for each amino acid at each position in each sequence\n", + "4. probability: Generates probabilities for each amino acid at each position in each sequence\n", + "5. pseudo_log_likelihood: Returns the pseudo log likelihood for a sequence (based on masking each residue one at a time)\n", + "6. confidence: Returns a fast calculation of the log likelihood for a sequence (based on a single pass with no masking)\n", + "7. restore: Restores masked residues\n", + "\n", + "### **AbLang2 can also align the resulting representations using ANARCI**\n", + "\n", + "This can be done for 'rescoding', 'likelihood', and 'probability'. This is done by setting the argument \"align=True\".\n", + "\n", + "**NB**: Align can only be used on input with the same format, i.e.
either all heavy, all light, or all both heavy and light.\n", + "\n", + "### **The align argument can also be used to restore variable missing lengths**\n", + "\n", + "For this, use \"align=True\" with the 'restore' mode." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "ceae4a88-0679-4704-8bad-c06a4569c497", + "metadata": {}, + "outputs": [], + "source": [ + "valid_modes = [\n", + " 'seqcoding', 'rescoding', 'likelihood', 'probability',\n", + " 'pseudo_log_likelihood', 'confidence', 'restore' \n", + "]" + ] + }, + { + "cell_type": "markdown", + "id": "aa333732-7508-4826-92ec-3acdd54bc1bb", + "metadata": {}, + "source": [ + "## **seqcoding** \n", + "\n", + "The seqcodings represents each sequence as a 480 sized embedding. It is derived from averaging across each rescoding embedding for a given sequence, including extra tokens. \n", + "\n", + "**NB:** Seqcodings can also be derived in other ways like using the sum or averaging across only parts of the input such as the CDRs. For such cases please use and adapt the below rescoding." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "d22f4302-1262-4cc1-8a1c-a36daa8c710c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[-0.25206311, 0.18189634, 0.00887137, ..., 0.15365517,\n", + " -0.14508603, -0.13381317],\n", + " [-0.25149415, 0.2086455 , 0.07518203, ..., 0.19478269,\n", + " -0.15227772, -0.08241647],\n", + " [-0.27468949, 0.16507216, 0.08667156, ..., 0.18776284,\n", + " -0.14165082, -0.16389885],\n", + " [-0.1982213 , 0.16841085, -0.04925933, ..., 0.11400164,\n", + " -0.14723683, -0.09713171],\n", + " [-0.29553188, 0.17239201, 0.05676926, ..., 0.15943622,\n", + " -0.16615383, -0.15569784]], shape=(5, 480))" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ablang(all_seqs, mode='seqcoding')\n" + ] + }, + { + "cell_type": "markdown", + "id": "4b5d9d60", + "metadata": {}, + "source": [ + "## **rescoding / likelihood / probability**\n", + "\n", + "The rescodings represents each residue as a 480 sized embedding. The likelihoods represents each residue as the predicted logits for each character in the vocabulary. The probabilities represents the normalised likelihoods.\n", + "\n", + "**NB:** The output includes extra tokens (start, stop and separation tokens) in the format \"|\". The length of the output is therefore 5 longer than the VH and VL.\n", + "\n", + "**NB:** By default the representations are derived using a single forward pass. To prevent the predicted likelihood and probability to be affected by the input residue at each position, setting the \"stepwise_masking\" argument to True can be used. This will run a forward pass for each position with the residue at that position masked. This is much slower than running a single forward pass." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "6227f661-575f-4b1e-9646-cfba7b10c3b4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[array([[-0.40741208, -0.5118987 , 0.06096708, ..., 0.3268144 ,\n", + " 0.03920235, -0.36715826],\n", + " [-0.5768883 , 0.38245413, -0.21791998, ..., 0.01250262,\n", + " -0.08844463, -0.32367525],\n", + " [-0.1475935 , 0.39639047, -0.38226923, ..., -0.10119921,\n", + " -0.41469565, -0.00319315],\n", + " ...,\n", + " [-0.14358369, 0.3124389 , -0.30157998, ..., -0.13289244,\n", + " -0.45353398, -0.07878865],\n", + " [ 0.17538925, 0.24394299, 0.20141171, ..., 0.14587352,\n", + " -0.38479003, 0.07409196],\n", + " [-0.23031706, -0.35487285, 0.1960684 , ..., -0.1283362 ,\n", + " 0.31107333, -0.3265108 ]], shape=(238, 480), dtype=float32),\n", + " array([[-0.41981837, -0.3666375 , 0.10595217, ..., 0.3903574 ,\n", + " 0.0382378 , -0.36337993],\n", + " [-0.5054137 , 0.38347068, -0.10992069, ..., -0.05231472,\n", + " -0.13636623, -0.34830108],\n", + " [-0.06784609, 0.69349885, -0.4212398 , ..., -0.24805346,\n", + " -0.39583805, -0.10972726],\n", + " ...,\n", + " [-0.2090099 , 0.29489496, -0.11039071, ..., -0.24245434,\n", + " -0.60625184, -0.02307999],\n", + " [ 0.19134358, 0.21744648, 0.2575827 , ..., 0.15845427,\n", + " -0.34743664, 0.10218249],\n", + " [-0.2551157 , -0.21778448, 0.21906358, ..., -0.09656111,\n", + " 0.22394855, -0.20267345]], shape=(222, 480), dtype=float32),\n", + " array([[-0.40043733, -0.48596814, 0.0886725 , ..., 0.38941646,\n", + " 0.06195956, -0.40999672],\n", + " [-0.54576075, 0.4312959 , -0.3451486 , ..., -0.09285564,\n", + " 0.03116508, -0.45269737],\n", + " [ 0.0221165 , 0.53196615, -0.30137214, ..., -0.1889072 ,\n", + " -0.32587305, 0.05078396],\n", + " ...,\n", + " [ 0.2630385 , -0.22976042, 0.5510368 , ..., 0.47436473,\n", + " -0.42733562, -0.83135855],\n", + " [-0.13752195, 0.28678602, -0.18887053, ..., 0.28262627,\n", + " 0.1254679 , -0.6496486 ],\n", + " 
[-0.4541417 , 0.24564984, 0.2132735 , ..., 0.03287445,\n", + " 0.03825552, -0.34259132]], shape=(124, 480), dtype=float32),\n", + " array([[-0.26863217, 0.32259187, 0.10813517, ..., 0.03953876,\n", + " 0.18312076, -0.00498045],\n", + " [-0.2165424 , -0.38562432, -0.02696264, ..., 0.20541488,\n", + " 0.18698391, -0.22639504],\n", + " [-0.41950518, 0.04743317, 0.0048816 , ..., 0.11408642,\n", + " -0.05384652, 0.1025871 ],\n", + " ...,\n", + " [-0.10960457, 0.35151365, -0.21752454, ..., -0.21448943,\n", + " -0.6396219 , -0.00839792],\n", + " [ 0.20491892, 0.36294487, 0.19217414, ..., 0.07750722,\n", + " -0.5039212 , 0.03793833],\n", + " [-0.11638474, -0.35350856, 0.13215722, ..., -0.1606055 ,\n", + " 0.23913842, -0.2565337 ]], shape=(115, 480), dtype=float32),\n", + " array([[-0.42062947, -0.44009134, 0.00152371, ..., 0.27141467,\n", + " 0.03798106, -0.397461 ],\n", + " [-0.57318133, 0.5258899 , -0.17001636, ..., -0.23864633,\n", + " 0.2088059 , -0.57877594],\n", + " [-0.38988614, 0.46168196, -0.3429413 , ..., -0.14872643,\n", + " -0.46576905, -0.21224979],\n", + " ...,\n", + " [-0.21528634, 0.30046722, -0.25216463, ..., -0.11576828,\n", + " -0.4704907 , -0.0740136 ],\n", + " [ 0.0633081 , 0.22700705, 0.28184187, ..., 0.15967266,\n", + " -0.377182 , 0.06188517],\n", + " [-0.27826303, -0.37297496, 0.21229912, ..., -0.14886017,\n", + " 0.24998347, -0.35954213]], shape=(238, 480), dtype=float32)]" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ablang(all_seqs, mode='rescoding', stepwise_masking = False)" + ] + }, + { + "cell_type": "markdown", + "id": "6da2183b-4306-49bd-a7fc-23e78a23f305", + "metadata": {}, + "source": [ + "## **Align rescoding/likelihood/probability output**\n", + "\n", + "For the 'rescoding', 'likelihood', and 'probability' modes, the output can also be aligned using the argument \"align=True\".\n", + "\n", + "This is done using the antibody numbering tool ANARCI, and requires manually 
installing **Pandas** and **[ANARCI](https://github.com/oxpig/ANARCI)**.\n", + "\n", + "**NB**: Align can only be used on input with the same format, i.e. either all heavy, all light, or all both heavy and light." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "e4bc0cb1-f5b0-4255-9e93-d643ae1396df", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['<' '1 ' '2 ' '3 ' '4 ' '5 ' '6 ' '7 ' '8 ' '9 ' '11 ' '12 ' '13 ' '14 '\n", + " '15 ' '16 ' '17 ' '18 ' '19 ' '20 ' '21 ' '22 ' '23 ' '24 ' '25 ' '26 '\n", + " '27 ' '28 ' '29 ' '30 ' '35 ' '36 ' '37 ' '38 ' '39 ' '40 ' '41 ' '42 '\n", + " '43 ' '44 ' '45 ' '46 ' '47 ' '48 ' '49 ' '50 ' '51 ' '52 ' '53 ' '54 '\n", + " '55 ' '56 ' '57 ' '58 ' '59 ' '62 ' '63 ' '64 ' '65 ' '66 ' '67 ' '68 '\n", + " '69 ' '70 ' '71 ' '72 ' '74 ' '75 ' '76 ' '77 ' '78 ' '79 ' '80 ' '81 '\n", + " '82 ' '83 ' '84 ' '85 ' '86 ' '87 ' '88 ' '89 ' '90 ' '91 ' '92 ' '93 '\n", + " '94 ' '95 ' '96 ' '97 ' '98 ' '99 ' '100 ' '101 ' '102 ' '103 ' '104 '\n", + " '105 ' '106 ' '107 ' '108 ' '109 ' '110 ' '111 ' '112A' '112 ' '113 '\n", + " '114 ' '115 ' '116 ' '117 ' '118 ' '119 ' '120 ' '121 ' '122 ' '123 '\n", + " '124 ' '125 ' '126 ' '127 ' '128 ' '>' '|' '<' '1 ' '2 ' '3 ' '4 ' '5 '\n", + " '6 ' '7 ' '8 ' '9 ' '10 ' '11 ' '12 ' '13 ' '14 ' '15 ' '16 ' '17 ' '18 '\n", + " '19 ' '20 ' '21 ' '22 ' '23 ' '24 ' '25 ' '26 ' '27 ' '28 ' '29 ' '30 '\n", + " '31 ' '32 ' '34 ' '35 ' '36 ' '37 ' '38 ' '39 ' '40 ' '41 ' '42 ' '43 '\n", + " '44 ' '45 ' '46 ' '47 ' '48 ' '49 ' '50 ' '51 ' '52 ' '53 ' '54 ' '55 '\n", + " '56 ' '57 ' '64 ' '65 ' '66 ' '67 ' '68 ' '69 ' '70 ' '71 ' '72 ' '74 '\n", + " '75 ' '76 ' '77 ' '78 ' '79 ' '80 ' '83 ' '84 ' '85 ' '86 ' '87 ' '88 '\n", + " '89 ' '90 ' '91 ' '92 ' '93 ' '94 ' '95 ' '96 ' '97 ' '98 ' '99 ' '100 '\n", + " '101 ' '102 ' '103 ' '104 ' '105 ' '106 ' '107 ' '108 ' '109 ' '114 '\n", + " '115 ' '116 ' '117 ' '118 ' '119 ' '120 ' '121 ' '122 ' '123 
' '124 '\n", + " '125 ' '126 ' '127 ' '>']\n", + "['|', '|<-----------PVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKI-SNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK>', '<------SGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCAR**PGHGAAFMDVWGTGTTVTVSS>|']\n", + "[[[ 9.31621838 -3.42184329 -3.59397745 ... -14.73707485 -6.8935833\n", + " -0.23662776]\n", + " [ -3.54718232 -5.84866619 -4.02423859 ... -12.93966579 -9.5614481\n", + " -4.48473835]\n", + " [-11.94997597 -2.245543 -5.69481373 ... -15.19639015 -17.97454071\n", + " -12.56952095]\n", + " ...\n", + " [ -8.94504833 -0.42261261 -4.95588207 ... -16.66817474 -15.2224741\n", + " -10.37267494]\n", + " [-11.65150356 -5.44477606 -2.95585775 ... -16.25555801 -9.75158596\n", + " -11.75897026]\n", + " [ 1.79469728 -1.95846701 -3.59784532 ... -14.95585823 -7.47080708\n", + " -0.95226753]]\n", + "\n", + " [[ 8.55518723 -3.83663297 -2.33595967 ... -13.87456799 -8.14840603\n", + " -0.42472434]\n", + " [ -4.40701294 -5.53201008 -3.69397402 ... -12.97877789 -9.86258411\n", + " -4.95414352]\n", + " [-11.95642853 -3.86210871 -5.80935192 ... -14.89213085 -16.94556236\n", + " -11.36959839]\n", + " ...\n", + " [ -7.75924015 -0.66524202 -4.08643246 ... -16.16580772 -14.76507473\n", + " -8.3507061 ]\n", + " [-11.91039753 -4.86995983 -2.74777436 ... -16.07694817 -8.44974899\n", + " -10.45223904]\n", + " [ 0.86006832 -2.37964034 -3.58130741 ... -15.35423565 -7.73035526\n", + " -1.11989737]]\n", + "\n", + " [[ -4.37902737 -7.55587149 1.21958363 ... -15.48622513 -6.021842\n", + " -3.79647374]\n", + " [ 0. 0. 0. ... 0. 0.\n", + " 0. ]\n", + " [ 0. 0. 0. ... 0. 0.\n", + " 0. ]\n", + " ...\n", + " [ -8.94207573 -0.51090252 -5.09760332 ... -16.69521713 -15.45450687\n", + " -10.50823212]\n", + " [-11.92354965 -5.55152607 -2.87666893 ... -16.40607834 -10.19431686\n", + " -12.1328764 ]\n", + " [ 2.42200375 -2.01573253 -3.61701298 ... 
-14.9590435 -7.19029331\n", + " -0.89830256]]]\n" + ] + } + ], + "source": [ + "results = ablang(only_both_chains_seqs, mode='likelihood', align=True)\n", + "\n", + "print(results.number_alignment)\n", + "print(results.aligned_seqs)\n", + "print(results.aligned_embeds)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "56be8cad", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[array([[9.9955505e-01, 2.9358694e-06, 2.4716087e-06, ..., 3.5776201e-11,\n", + " 9.1196831e-08, 7.0967326e-05],\n", + " [4.1573694e-06, 4.1619489e-07, 2.5800944e-06, ..., 3.4650952e-10,\n", + " 1.0159109e-08, 1.6279575e-06],\n", + " [7.8059600e-08, 1.2794037e-03, 4.0645118e-05, ..., 3.0375720e-09,\n", + " 1.8879491e-10, 4.2010839e-08],\n", + " ...,\n", + " [3.4210879e-07, 1.7195340e-03, 1.8477240e-05, ..., 1.5137445e-10,\n", + " 6.4255873e-10, 8.2064140e-08],\n", + " [9.1038084e-09, 4.5161755e-06, 5.4411950e-05, ..., 9.1139631e-11,\n", + " 6.0862085e-08, 8.1761966e-09],\n", + " [8.5759175e-04, 2.0104915e-05, 3.9023766e-06, ..., 4.5562460e-11,\n", + " 8.1156479e-08, 5.4990651e-05]], shape=(238, 26), dtype=float32),\n", + " array([[9.9939799e-01, 4.1499175e-06, 1.8611167e-05, ..., 1.8139243e-10,\n", + " 5.5649299e-08, 1.2583815e-04],\n", + " [1.6735513e-06, 5.4332406e-07, 3.4143472e-06, ..., 3.1693398e-10,\n", + " 7.1501400e-09, 9.6832969e-07],\n", + " [3.7784993e-08, 1.2377645e-04, 1.7658784e-05, ..., 2.0061326e-09,\n", + " 2.5737484e-10, 6.7947965e-08],\n", + " ...,\n", + " [1.1050455e-06, 1.3312638e-03, 4.3497097e-05, ..., 2.4686178e-10,\n", + " 1.0018089e-09, 6.1165900e-07],\n", + " [5.7270397e-09, 6.5396339e-06, 5.4601755e-05, ..., 8.8801404e-11,\n", + " 1.8233513e-07, 2.4615032e-08],\n", + " [7.3952030e-04, 2.8970928e-05, 8.7113440e-06, ..., 6.7168833e-11,\n", + " 1.3746008e-07, 1.0210846e-04]], shape=(222, 26), dtype=float32),\n", + " array([[9.99685407e-01, 3.35662639e-06, 1.14241482e-06, ...,\n", + " 2.32460891e-11, 6.88188067e-08, 
5.69467156e-05],\n", + " [6.38133372e-07, 1.01300586e-07, 5.64459742e-06, ...,\n", + " 4.09234556e-11, 2.53804799e-09, 4.31722100e-07],\n", + " [1.49096788e-08, 2.04515047e-04, 9.23794141e-06, ...,\n", + " 7.46306961e-10, 2.92107380e-11, 2.21786500e-08],\n", + " ...,\n", + " [2.15093763e-07, 1.06453872e-03, 1.62486140e-05, ...,\n", + " 1.12102910e-10, 1.47300866e-10, 4.73037538e-08],\n", + " [4.30136682e-09, 3.09317988e-06, 3.96632568e-05, ...,\n", + " 5.24226877e-11, 2.39579450e-08, 3.86403221e-09],\n", + " [9.77773685e-04, 1.29533228e-05, 2.78623725e-06, ...,\n", + " 2.73364300e-11, 3.96418649e-08, 4.04014427e-05]],\n", + " shape=(238, 26), dtype=float32)]" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ablang(only_both_chains_seqs, mode='probability')" + ] + }, + { + "cell_type": "markdown", + "id": "8f0a71ec-e916-4330-90d0-13a4b1121a89", + "metadata": {}, + "source": [ + "## **Pseudo log likelihood and Confidence scores**\n", + "\n", + "The pseudo log likelihood and confidence represents two methods for calculating the uncertainty for the input sequence.\n", + "\n", + "- pseudo_log_likelihood: For each position, the pseudo log likelihood is calculated when predicting the masked residue. The final score is an average across the whole input. This is similar to the approach taken in the ESM-2 paper for calculating pseudo perplexity [(Lin et al., 2023)](https://doi.org/10.1126/science.ade2574).\n", + "\n", + "- confidence: For each position, the log likelihood is calculated without masking the residue. The final score is an average across the whole input. \n", + "\n", + "**NB:** The **confidence is fast** to compute, requiring only a single forward pass per input. 
**Pseudo log likelihood is slow** to calculate, requiring L forward passes per input, where L is the length of the input.\n", + "\n", + "**NB:** It is recommended to use **pseudo log likelihood for final results** and **confidence for exploratory work**." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "83f3064b-48a7-42fb-ba82-ec153ea946da", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([1.96673731, 2.04801253, 2.09881898, 1.82533665, 1.97255249])" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "results = ablang(all_seqs, mode='pseudo_log_likelihood')\n", + "np.exp(-results) # convert to pseudo perplexity" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "42cc8b34-5ae9-4857-93fe-a438a0f2a868", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([1.2636038, 1.126463 , 1.3123759, 1.2140924, 1.1805094],\n", + " dtype=float32)" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "results = ablang(all_seqs, mode='confidence')\n", + "np.exp(-results)" + ] + }, + { + "cell_type": "markdown", + "id": "e0b63e48-b2a1-4a8e-8ecb-449748a2cb25", + "metadata": {}, + "source": [ + "## **restore**\n", + "\n", + "This mode can be used to restore masked residues, and fragmented regions with \"align=True\". " + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "2d5b725c-4eac-4a4b-9331-357c3ac140f7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array(['|',\n", + " '|',\n", + " '|'],\n", + " dtype='|',\n", + " '|',\n", + " '|'],\n", + " dtype='|\". The length of the output is therefore 5 longer than the VH and VL.\n", + "\n", + "**NB:** By default the representations are derived using a single forward pass. 
To prevent the predicted likelihood and probability to be affected by the input residue at each position, setting the \"stepwise_masking\" argument to True can be used. This will run a forward pass for each position with the residue at that position masked. This is much slower than running a single forward pass." + ] + }, + { + "cell_type": "markdown", + "id": "b046ae57", + "metadata": {}, + "source": [ + "## **rescoding / likelihood / probability**\n", + "\n", + "The rescodings represents each residue as a 480 sized embedding. The likelihoods represents each residue as the predicted logits for each character in the vocabulary. The probabilities represents the normalised likelihoods.\n", + "\n", + "**NB:** The output includes extra tokens (start, stop and separation tokens) in the format \"|\". The length of the output is therefore 5 longer than the VH and VL.\n", + "\n", + "**NB:** By default the representations are derived using a single forward pass. To prevent the predicted likelihood and probability to be affected by the input residue at each position, setting the \"stepwise_masking\" argument to True can be used. This will run a forward pass for each position with the residue at that position masked. This is much slower than running a single forward pass." + ] + }, + { + "cell_type": "markdown", + "id": "78ccf7d8", + "metadata": {}, + "source": [ + "## **rescoding / likelihood / probability**\n", + "\n", + "The rescodings represents each residue as a 480 sized embedding. The likelihoods represents each residue as the predicted logits for each character in the vocabulary. The probabilities represents the normalised likelihoods.\n", + "\n", + "**NB:** The output includes extra tokens (start, stop and separation tokens) in the format \"|\". The length of the output is therefore 5 longer than the VH and VL.\n", + "\n", + "**NB:** By default the representations are derived using a single forward pass. 
class AbLang2PairedTokenizer(PreTrainedTokenizer):
    """Character-level tokenizer for AbLang2 paired antibody sequences.

    Each amino-acid character is one token. By the vocabulary in vocab.json:
    '|' separates the heavy and light chains, '*' is the mask token, '-' is
    the pad token, and 'X' is the unknown token.
    """

    vocab_files_names = {"vocab_file": "vocab.json"}
    model_input_names = ["input_ids"]

    def __init__(self, vocab_file=None, **kwargs):
        if vocab_file is None:
            # Try to find vocab file in the current directory.
            vocab_file = "vocab.json"

        self.vocab_file = vocab_file
        with open(vocab_file, "r", encoding="utf-8") as f:
            self.vocab = json.load(f)
        # Precompute the inverse vocabulary once so that id -> token lookup
        # is O(1) instead of rebuilding the whole mapping on every call.
        self.ids_to_tokens = {index: token for token, index in self.vocab.items()}

        # Set required token attributes (all as strings, standard for HF).
        kwargs.setdefault("pad_token", "-")
        kwargs.setdefault("mask_token", "*")
        kwargs.setdefault("unk_token", "X")

        super().__init__(**kwargs)

    @property
    def pad_token_id(self):
        # Resolve directly from the vocabulary rather than HF's internal state.
        return self.vocab[self.pad_token]

    @property
    def mask_token_id(self):
        return self.vocab[self.mask_token]

    def _tokenize(self, text):
        # One token per character.
        return list(text)

    def tokenize(self, text, text_pair=None, **kwargs):
        """Tokenize text or text pair.

        Paired sequences are joined with the '|' chain-separator character
        before character-level tokenization.
        """
        if text_pair is not None:
            # For paired sequences, combine them with a separator.
            combined_text = text + "|" + text_pair
            return self._tokenize(combined_text)
        else:
            return self._tokenize(text)

    def _convert_token_to_id(self, token):
        # Unknown characters map to the unk token's id.
        return self.vocab.get(token, self.vocab[self.unk_token])

    def _convert_id_to_token(self, index):
        # O(1) lookup via the precomputed inverse vocabulary.
        return self.ids_to_tokens.get(index, self.unk_token)

    def get_vocab(self):
        return self.vocab

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """Write the vocabulary to ``vocab.json`` and return the path tuple."""
        os.makedirs(save_directory, exist_ok=True)
        path = os.path.join(save_directory, (filename_prefix or "") + "vocab.json")
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.vocab, f)
        return (path,)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Load the tokenizer from a directory containing ``vocab.json``.

        Raises:
            ValueError: if no vocabulary file is found in the directory.
        """
        vocab_file = os.path.join(pretrained_model_name_or_path, "vocab.json")
        if not os.path.exists(vocab_file):
            raise ValueError(f"Vocabulary file {vocab_file} not found")
        return cls(vocab_file=vocab_file, **kwargs)

    def save_pretrained(self, save_directory, filename_prefix=None):
        """Save the vocabulary and a minimal tokenizer_config.json."""
        os.makedirs(save_directory, exist_ok=True)
        vocab_files = self.save_vocabulary(save_directory, filename_prefix)

        # Record the fully-qualified class so AutoTokenizer can resolve it
        # with trust_remote_code=True.
        tokenizer_config = {
            "tokenizer_class": f"{self.__class__.__module__}.{self.__class__.__name__}"
        }
        with open(os.path.join(save_directory, "tokenizer_config.json"), "w") as f:
            json.dump(tokenizer_config, f, indent=2)

        return vocab_files

    def __call__(self, sequences, padding=False, return_tensors=None, **kwargs):
        """Encode a string or list of strings into ``{'input_ids': ...}``.

        Args:
            sequences: a single sequence string or a list of them.
            padding: when True, right-pad all sequences to the longest one
                with the pad token id.
            return_tensors: 'pt' returns a ``torch.Tensor``; otherwise a
                list of id lists is returned.
        """
        # Accepts a string or a list of strings.
        if isinstance(sequences, str):
            sequences = [sequences]
        # Tokenize each sequence.
        input_ids = [
            [self._convert_token_to_id(tok) for tok in self._tokenize(seq)]
            for seq in sequences
        ]
        # Padding.
        if padding:
            maxlen = max(len(ids) for ids in input_ids)
            input_ids = [
                ids + [self.pad_token_id] * (maxlen - len(ids)) for ids in input_ids
            ]
        # Return tensors if requested.
        if return_tensors == 'pt':
            import torch
            input_ids = torch.tensor(input_ids)
        return {'input_ids': input_ids}