| | from huggingface_hub import HfApi |
| | import torch |
| | from tqdm import tqdm |
| | from transformers import AutoTokenizer, AutoModelForMaskedLM |
| | from transformers.tokenization_utils_base import BatchEncoding |
| | from transformers.modeling_outputs import MaskedLMOutput |
| |
|
| | |
def get_models() -> list[None|str]:
    """List ESM fill-mask model IDs published by ``facebook`` on the Hub.

    Queries ``HfApi().list_models`` with these filters:

    * author ``facebook``
    * model name ``esm``
    * task ``fill-mask``

    Results are sorted by ``lastModified`` with ``direction=-1`` (descending).

    Returns:
        The ``modelId`` of every model the Hub returned, in the Hub's order.

    Raises:
        RuntimeError: if the listing contains no truthy model ID (for example
            when the query matched nothing).
    """
    listing = HfApi().list_models(
        author="facebook",
        model_name="esm",
        task="fill-mask",
        sort="lastModified",
        direction=-1,
    )
    model_ids = [info.modelId for info in listing]
    # An empty listing (or one made only of falsy IDs) is reported as a failure.
    if any(model_ids):
        return model_ids
    raise RuntimeError("Error while retrieving models from HuggingFace Hub")
| |
|
| | |
class Model:
    """Wrapper around a HuggingFace masked-LM (ESM) model and its tokenizer.

    Provides helpers to tokenise a sequence, run the model, map tokens to
    vocabulary IDs and score substitutions.
    """

    def __init__(self, model_name: str = ""):
        """Load the masked-LM model and tokenizer called ``model_name``.

        With an empty ``model_name`` only ``self.model_name`` is set. The
        attributes ``model``, ``batch_converter``, ``alphabet`` and ``device``
        are then left undefined. Otherwise the model is moved to CUDA whenever
        ``torch.cuda.is_available()`` is true.
        """
        self.model_name = model_name
        if not model_name:
            return

        self.model = AutoModelForMaskedLM.from_pretrained(model_name)
        self.batch_converter = AutoTokenizer.from_pretrained(model_name)
        self.alphabet = self.batch_converter.get_vocab()

        use_cuda = torch.cuda.is_available()
        if use_cuda:
            self.model = self.model.cuda()
        self.device = torch.device("cuda" if use_cuda else "cpu")

    def tokenise(self, input: str) -> BatchEncoding:
        """Tokenise ``input`` and return the encoding as PyTorch tensors."""
        tokenizer = self.batch_converter
        return tokenizer(input, return_tensors="pt")

    def __call__(self, batch_tokens: torch.Tensor, **kwargs) -> MaskedLMOutput:
        """Move ``batch_tokens`` to ``self.device`` and run the model on them.

        Any extra keyword arguments are forwarded unchanged to the model call.
        """
        on_device = batch_tokens.to(self.device)
        return self.model(on_device, **kwargs)

    def __getitem__(self, key: str) -> int:
        """Look up the vocabulary ID of token ``key`` in ``self.alphabet``."""
        return self.alphabet[key]

    def run_model(self, data):
        """Score every substitution in ``data.sub`` and store the results.

        Reads ``data.seq``, ``data.scoring_strategy``, ``data.resi`` and
        ``data.sub``. Writes two results:

        * ``data.token_probs``: the unmasked log-softmax of the logits, as a
          NumPy array.
        * ``data.out[self.model_name]``: one score per row of ``data.sub``.

        Each row's substitution string is taken from column ``'0'``. It is
        parsed as <wild-type char><integer position><mutant char>; presumably
        the position is 1-based (e.g. ``"A23C"``) — confirm with callers.
        The score is log P(mutant) - log P(wild type) at that position.

        When ``data.scoring_strategy`` starts with ``"masked-marginals"``, the
        log-probabilities are recomputed for each token index ``i`` for which
        ``i in data.resi`` holds. Each such index is scored with that single
        token replaced by ``<mask>``. All other indices keep the unmasked
        values.
        """
        def score_substitution(variant, log_probs):
            """Return log P(mutant) - log P(wild type) for one variant string."""
            wild_type = variant[0]
            mutant = variant[-1]
            # Residue N is read at token index N. Presumably index 0 is a
            # leading special token added by the tokenizer — TODO confirm.
            token_pos = int(variant[1:-1])
            mutant_score = log_probs[0, token_pos, self[mutant]]
            wild_score = log_probs[0, token_pos, self[wild_type]]
            return (mutant_score - wild_score).item()

        input_ids = self.tokenise(data.seq).input_ids

        with torch.no_grad():
            logits = self(input_ids).logits
            token_probs = torch.log_softmax(logits, dim=-1)
        data.token_probs = token_probs.cpu().numpy()

        if data.scoring_strategy.startswith("masked-marginals"):
            per_position = []
            for pos in tqdm(range(input_ids.size(1))):
                if pos not in data.resi:
                    # Unselected positions reuse the unmasked distribution.
                    per_position.append(token_probs[:, pos])
                    continue
                masked = input_ids.clone()
                masked[0, pos] = self['<mask>']
                with torch.no_grad():
                    masked_log_probs = torch.log_softmax(
                        self(masked).logits, dim=-1
                    )
                per_position.append(masked_log_probs[:, pos])
            # Restore the (1, seq_len, vocab) layout expected by the scorer.
            token_probs = torch.cat(per_position, dim=0).unsqueeze(0)

        data.out[self.model_name] = data.sub.apply(
            lambda record: score_substitution(record['0'], token_probs),
            axis=1,
        )