Instructions for using Taykhoom/RNA-MSM with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Taykhoom/RNA-MSM with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("fill-mask", model="Taykhoom/RNA-MSM", trust_remote_code=True)
```
```python
# Load model directly
from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained("Taykhoom/RNA-MSM", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| import json | |
| import os | |
| from typing import Dict, List, Optional, Union | |
| import torch | |
| from transformers import PreTrainedTokenizer | |
| from transformers.tokenization_utils_base import BatchEncoding | |
# Default character-level vocabulary (token -> id), used when no vocab file is
# supplied. Each token's id is its position in the list below, so the order of
# the entries is significant and must not be changed.
_VOCAB = {
    token: index
    for index, token in enumerate(
        ["<cls>", "<pad>", "<eos>", "<unk>", "A", "G", "C", "U", "X", "N", "-", "<mask>"]
    )
}
class RNAMSMTokenizer(PreTrainedTokenizer):
    """
    Tokenizer for RNA-MSM.

    Vocabulary: <cls>(0) <pad>(1) <eos>(2) <unk>(3) A(4) G(5) C(6) U(7) X(8) N(9) -(10) <mask>(11)

    RNA-MSM is an MSA Transformer: it always expects 3D input
    (batch, num_alignments, seqlen). This tokenizer treats each input string
    as a single-sequence MSA (1 alignment row), so the standard __call__ API:

        enc = tokenizer(["AGCU", "GAUC"], return_tensors="pt", padding=True)
        # enc.input_ids: (2, 1, T) -- batch of 2 single-sequence MSAs

    For real MSAs (multiple aligned sequences), use encode_msa():

        enc = tokenizer.encode_msa([["AGCU--", "AGCUUU"]], return_tensors="pt")
        # enc["input_ids"]: (1, 2, T) -- 1 MSA with 2 alignment rows

    Tokenization is character-level. Special tokens written out in the text
    (e.g. "AG<mask>U") are kept whole rather than split into characters.
    Only a leading <cls> is added as a special token; no <eos> is appended.
    """

    vocab_files_names = {"vocab_file": "vocab.json"}
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file=None,
        cls_token="<cls>",
        pad_token="<pad>",
        eos_token="<eos>",
        unk_token="<unk>",
        mask_token="<mask>",
        **kwargs,
    ):
        # The vocabulary is built before the base-class __init__ runs,
        # presumably so base-class setup can resolve the special tokens
        # against it.
        if vocab_file and os.path.isfile(vocab_file):
            with open(vocab_file, encoding="utf-8") as f:
                self._vocab = json.load(f)
        else:
            # No (or missing) vocab file: fall back to the built-in vocabulary.
            self._vocab = dict(_VOCAB)
        self._ids_to_tokens = {v: k for k, v in self._vocab.items()}
        super().__init__(
            cls_token=cls_token,
            pad_token=pad_token,
            eos_token=eos_token,
            unk_token=unk_token,
            mask_token=mask_token,
            **kwargs,
        )

    @property
    def vocab_size(self):
        """Size of the base vocabulary (tokens added later are not counted)."""
        # The HF base class declares vocab_size as a property. As a plain
        # method, `tokenizer.vocab_size` evaluated to a bound method, not an int.
        return len(self._vocab)

    def get_vocab(self):
        """Return a copy of the token -> id mapping, including added tokens."""
        vocab = dict(self._vocab)
        # Report tokens registered via add_tokens() as well. The base class
        # is expected to derive ids for new tokens from this mapping. getattr
        # guards against being called before base-class init has run.
        vocab.update(getattr(self, "added_tokens_encoder", None) or {})
        return vocab

    def _tokenize(self, text):
        """Split text into single-character tokens."""
        return list(text)

    def _convert_token_to_id(self, token):
        # Characters missing from the vocabulary map to <unk>.
        return self._vocab.get(token, self._vocab["<unk>"])

    def _convert_id_to_token(self, index):
        # Unknown ids decode as "<unk>".
        return self._ids_to_tokens.get(index, "<unk>")

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """
        Write the base vocabulary as JSON.

        The file is named "[<prefix>-]vocab.json" inside save_directory,
        which is created if needed. Returns a 1-tuple holding the written path.
        """
        os.makedirs(save_directory, exist_ok=True)
        fname = (filename_prefix + "-" if filename_prefix else "") + "vocab.json"
        path = os.path.join(save_directory, fname)
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, indent=2)
        return (path,)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """Prefix each segment with <cls>; no <eos> is appended."""
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return cls + token_ids_0
        return cls + token_ids_0 + cls + token_ids_1

    def get_special_tokens_mask(self, token_ids_0, token_ids_1=None,
                                already_has_special_tokens=False):
        """Mark the <cls> slot(s) added by build_inputs_with_special_tokens with 1."""
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0, token_ids_1, already_has_special_tokens=True)
        mask = [1] + [0] * len(token_ids_0)
        if token_ids_1 is not None:
            mask += [1] + [0] * len(token_ids_1)
        return mask

    def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
        """
        Return all-zero token type ids aligned with build_inputs_with_special_tokens.

        The previous implementation returned the token ids themselves
        ([0] + token_ids_0) instead of type ids.
        """
        # +1 per segment for its leading <cls>.
        type_ids = [0] * (len(token_ids_0) + 1)
        if token_ids_1 is not None:
            type_ids += [0] * (len(token_ids_1) + 1)
        return type_ids

    @staticmethod
    def _strategy_enabled(value, disabled_name):
        """True unless value is False/None or the named "off" strategy string."""
        return value not in (False, None, disabled_name)

    def __call__(
        self,
        text,
        text_pair=None,
        add_special_tokens=True,
        padding=False,
        truncation=False,
        max_length=None,
        return_tensors=None,
        **kwargs,
    ):
        """
        Tokenize one or more sequences, each treated as a 1-row MSA.

        Args:
            text: str or List[str].
            text_pair: accepted for API compatibility; currently ignored.
            add_special_tokens: prepend <cls> to every sequence.
            padding: False / None / "do_not_pad" disables padding.
                "max_length" pads to max_length, or to the longest sequence
                if that is longer or max_length is None. Any other truthy
                value (True, "longest") pads to the longest sequence in the
                batch.
            truncation: when enabled (not False / None / "do_not_truncate")
                and max_length is given, each sequence, including its <cls>,
                is cut to max_length tokens.
            max_length: length limit used by padding="max_length" and
                truncation.
            return_tensors: "pt" returns torch.long tensors; anything else
                returns nested Python lists.
            **kwargs: ignored.

        Returns:
            BatchEncoding with input_ids and attention_mask of shape
            (batch, 1, seqlen). attention_mask is 1 for real tokens and 0
            for padding.
        """
        sequences = [text] if isinstance(text, str) else list(text)
        encoded = [self._tokenize_single(seq, add_special_tokens) for seq in sequences]

        if self._strategy_enabled(truncation, "do_not_truncate") and max_length is not None:
            encoded = [ids[:max_length] for ids in encoded]

        # Real (unpadded) lengths drive the attention mask.
        lengths = [len(ids) for ids in encoded]

        if self._strategy_enabled(padding, "do_not_pad") and encoded:
            target = max(lengths)
            if padding == "max_length" and max_length is not None:
                # Never pad below the longest sequence, so the batch stays
                # rectangular even without truncation.
                target = max(target, max_length)
            pad_id = self.pad_token_id
            encoded = [ids + [pad_id] * (target - len(ids)) for ids in encoded]

        input_ids = [[ids] for ids in encoded]
        attention_mask = [[[1] * n + [0] * (len(ids) - n)]
                          for ids, n in zip(encoded, lengths)]
        if return_tensors == "pt":
            input_ids = torch.tensor(input_ids, dtype=torch.long)
            attention_mask = torch.tensor(attention_mask, dtype=torch.long)
        return BatchEncoding({"input_ids": input_ids, "attention_mask": attention_mask})

    def _tokenize_single(self, sequence, add_special_tokens=True):
        """Convert one sequence string to ids, optionally prefixed with <cls>."""
        # Use the base-class tokenize() so registered special tokens (e.g.
        # "<mask>") are kept whole. Only the remaining text goes through the
        # character-level _tokenize(). A plain list(sequence) would turn
        # "<mask>" into six <unk> ids.
        ids = self.convert_tokens_to_ids(self.tokenize(sequence))
        if add_special_tokens:
            ids = [self.cls_token_id] + ids
        return ids

    def encode_msa(
        self,
        msas,
        add_special_tokens=True,
        padding=False,
        return_tensors=None,
    ):
        """
        Tokenize a batch of MSAs.

        Args:
            msas: List[List[str]]. Each inner list is one MSA, i.e. several
                aligned sequences that should share one length; this is not
                validated. A single MSA (List[str]) or a single sequence
                (str) is wrapped into a batch of one.
            add_special_tokens: prepend <cls> to every row.
            padding: when enabled (not False / None / "do_not_pad"), every
                row is padded to the longest row in the whole batch, and
                every MSA gets all-pad rows up to the largest row count.
            return_tensors: "pt" returns torch.long tensors; anything else
                returns nested Python lists.

        Returns:
            BatchEncoding with:
                input_ids: (batch, max_alignments, max_seqlen)
                attention_mask: (batch, max_alignments, max_seqlen)
        """
        if isinstance(msas, str):
            msas = [[msas]]
        elif msas and isinstance(msas[0], str):
            msas = [msas]

        # Tokenize every row exactly once.
        tokenized = [[self._tokenize_single(seq, add_special_tokens) for seq in msa]
                     for msa in msas]

        do_pad = self._strategy_enabled(padding, "do_not_pad")
        max_rows = max((len(rows) for rows in tokenized), default=0)
        max_seqlen = max((len(ids) for rows in tokenized for ids in rows), default=0)
        pad_id = self.pad_token_id

        batch_ids = []
        batch_mask = []
        for rows in tokenized:
            msa_ids = []
            msa_mask = []
            for ids in rows:
                pad_len = max_seqlen - len(ids) if do_pad else 0
                msa_ids.append(ids + [pad_id] * pad_len)
                msa_mask.append([1] * len(ids) + [0] * pad_len)
            if do_pad:
                # Fresh lists per filler row, so rows never alias each other.
                for _ in range(max_rows - len(rows)):
                    msa_ids.append([pad_id] * max_seqlen)
                    msa_mask.append([0] * max_seqlen)
            batch_ids.append(msa_ids)
            batch_mask.append(msa_mask)

        if return_tensors == "pt":
            batch_ids = torch.tensor(batch_ids, dtype=torch.long)
            batch_mask = torch.tensor(batch_mask, dtype=torch.long)
        return BatchEncoding({"input_ids": batch_ids, "attention_mask": batch_mask})

    def decode(self, token_ids, skip_special_tokens=False, **kwargs):
        """
        Join the tokens for token_ids into a string without separators.

        Accepts a sequence of ids, a 1-D tensor, a single int, or a 0-dim
        tensor. With skip_special_tokens, <cls>, <pad>, <eos>, <unk> and
        <mask> are dropped. Extra kwargs are ignored.
        """
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.tolist()
        if isinstance(token_ids, int):
            # A single id (or a 0-dim tensor, whose .tolist() is an int).
            token_ids = [token_ids]
        tokens = [self._convert_id_to_token(i) for i in token_ids]
        if skip_special_tokens:
            special = {self.cls_token, self.pad_token, self.eos_token,
                       self.unk_token, self.mask_token}
            tokens = [t for t in tokens if t not in special]
        return "".join(tokens)

    def num_special_tokens_to_add(self, pair=False):
        """One <cls> per segment, matching build_inputs_with_special_tokens."""
        return 2 if pair else 1