File size: 4,240 Bytes
0a535de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
"""Character-level RNA tokenizer for UTR-LM."""

import json
import os
from typing import Dict, List, Optional, Tuple

from transformers import PreTrainedTokenizer

# Canonical vocab - fixed; never changes across checkpoints.
# A token's id is its position in the tuple below, so the order is part of the
# contract: <pad>=0, <eos>=1, <unk>=2, A=3, G=4, C=5, T=6, <cls>=7, <mask>=8,
# <sep>=9.
_VOCAB: Dict[str, int] = {
    token: index
    for index, token in enumerate(
        ("<pad>", "<eos>", "<unk>", "A", "G", "C", "T", "<cls>", "<mask>", "<sep>")
    )
}
# Reverse lookup (id -> token) derived from the canonical vocab.
_IDS_TO_TOKENS: Dict[int, str] = {index: token for token, index in _VOCAB.items()}


class UtrLmTokenizer(PreTrainedTokenizer):
    """
    Character-level tokenizer for UTR-LM RNA sequences.

    Each nucleotide (A / G / C / T) maps to a single token; tokens missing
    from the vocabulary fall back to the ``<unk>`` id.
    Sequences are automatically wrapped with [CLS] ... [EOS] on encoding.

    Example::

        tok = UtrLmTokenizer()
        enc = tok("ATGCATG", return_tensors="pt")
        # enc.input_ids: [[7, 3, 6, 4, 5, 3, 6, 4, 1]]
        # tokens:          <cls> A T G C A T G <eos>
    """

    vocab_files_names = {"vocab_file": "vocab.json"}
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        cls_token: str = "<cls>",
        pad_token: str = "<pad>",
        mask_token: str = "<mask>",
        eos_token: str = "<eos>",
        unk_token: str = "<unk>",
        sep_token: str = "<sep>",
        **kwargs,
    ):
        """
        Build the tokenizer.

        Args:
            vocab_file: Optional path to a JSON file mapping token -> id.
                When it is ``None`` or does not point to an existing file,
                the built-in default vocabulary is used instead.
            cls_token, pad_token, mask_token, eos_token, unk_token, sep_token:
                Special-token strings forwarded to the base class.
            **kwargs: Forwarded unchanged to the base class constructor.
        """
        # Build vocab from file if provided (allows future extension), else use default
        if vocab_file is not None and os.path.isfile(vocab_file):
            # Explicit encoding keeps loading independent of the platform locale.
            with open(vocab_file, encoding="utf-8") as f:
                self._vocab = json.load(f)
        else:
            self._vocab = dict(_VOCAB)
        self._ids_to_tokens = {v: k for k, v in self._vocab.items()}

        # NOTE(review): the vocab is populated before the base constructor runs,
        # presumably because the base class queries the vocabulary during
        # initialization - confirm before reordering.
        super().__init__(
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            eos_token=eos_token,
            unk_token=unk_token,
            sep_token=sep_token,
            **kwargs,
        )

    # ------------------------------------------------------------------
    # Required overrides
    # ------------------------------------------------------------------

    @property
    def vocab_size(self) -> int:
        """Number of entries in the vocabulary."""
        return len(self._vocab)

    def get_vocab(self) -> Dict[str, int]:
        """Return a copy of the token -> id mapping."""
        return dict(self._vocab)

    def _tokenize(self, text: str) -> List[str]:
        """Split sequence into individual characters."""
        return list(text)

    def _convert_token_to_id(self, token: str) -> int:
        """Map a token to its id, falling back to the ``<unk>`` id."""
        return self._vocab.get(token, self._vocab["<unk>"])

    def _convert_id_to_token(self, index: int) -> str:
        """Map an id to its token, falling back to ``"<unk>"``."""
        return self._ids_to_tokens.get(index, "<unk>")

    def save_vocabulary(
        self, save_directory: str, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        """
        Write the vocabulary as JSON into ``save_directory``.

        Args:
            save_directory: Target directory; created if it does not exist.
            filename_prefix: Optional prefix; when given the file is named
                ``<prefix>-vocab.json`` instead of ``vocab.json``.

        Returns:
            A one-element tuple holding the path of the written file.
        """
        os.makedirs(save_directory, exist_ok=True)
        fname = (filename_prefix + "-" if filename_prefix else "") + "vocab.json"
        path = os.path.join(save_directory, fname)
        # Explicit encoding so the file round-trips regardless of platform locale.
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, indent=2)
        return (path,)

    # ------------------------------------------------------------------
    # Special-token wrapping: prepend [CLS], append [EOS]
    # ------------------------------------------------------------------

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Wrap one or two id sequences with special tokens.

        Single sequence: ``<cls> A <eos>``.
        Pair of sequences: ``<cls> A <eos> <cls> B <eos>``.
        """
        cls = [self.cls_token_id]
        eos = [self.eos_token_id]
        if token_ids_1 is None:
            return cls + token_ids_0 + eos
        return cls + token_ids_0 + eos + cls + token_ids_1 + eos

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        """
        Return a mask with 1 at special-token positions and 0 elsewhere.

        When ``already_has_special_tokens`` is true the work is delegated to
        the base class; otherwise the mask mirrors the layout produced by
        ``build_inputs_with_special_tokens``.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0, token_ids_1, already_has_special_tokens=True
            )
        mask = [1] + [0] * len(token_ids_0) + [1]
        if token_ids_1 is not None:
            mask += [1] + [0] * len(token_ids_1) + [1]
        return mask

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Return token type ids for the wrapped input: all zeros.

        The result has the same length as the output of
        ``build_inputs_with_special_tokens`` for the same arguments. Only the
        lengths of the inputs are used; the token ids themselves must never
        appear in the result.
        """
        # +2 accounts for the <cls> and <eos> wrapped around each sequence.
        length = len(token_ids_0) + 2
        if token_ids_1 is not None:
            length += len(token_ids_1) + 2
        return [0] * length