# HybriDNA-7B / hybridna_tokenizer.py
import json
import os
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
from transformers import PreTrainedTokenizer

class HybriDNATokenizer(PreTrainedTokenizer):
model_input_names = ["input_ids", "attention_mask"]
def __init__(self,
model_max_length: int,
bos_token="[BOS]",
eos_token="[SEP]",
sep_token="[SEP]",
cls_token="[CLS]",
pad_token="[PAD]",
mask_token="[MASK]",
unk_token="[UNK]",
**kwargs):
"""Character tokenizer for Hugging Face transformers.
Args:
characters (Sequence[str]): List of desired characters. Any character which
is not included in this list will be replaced by a special token called
[UNK] with id=6. Following are list of all of the special tokens with
their corresponding ids:
"[CLS]": 0
"[SEP]": 1
"[BOS]": 2
"[MASK]": 3
"[PAD]": 4
"[RESERVED]": 5
"[UNK]": 6
an id (starting at 7) will be assigned to each character.
model_max_length (int): Model maximum sequence length.
"""
self.characters = ('A', 'C', 'G', 'T', 'N')
self.model_max_length = model_max_length
self._vocab_str_to_int = {
"[CLS]": 0,
"[SEP]": 1,
"[BOS]": 2,
"[MASK]": 3,
"[PAD]": 4,
"[RESERVED]": 5,
"[UNK]": 6,
**{ch: i + 7 for i, ch in enumerate(self.characters)},
}
self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()}
self._bos_id = self._vocab_str_to_int["[BOS]"]
self._eos_id = self._vocab_str_to_int["[SEP]"]
self._pad_id = self._vocab_str_to_int["[PAD]"]
self._unk_id = self._vocab_str_to_int["[UNK]"]
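        # Precomputed single-element arrays for [BOS]/[SEP] let the fast
        # encoding path concatenate special tokens without per-call allocations.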
self._bos_np = np.array([self._bos_id], dtype=np.uint16)
self._eos_np = np.array([self._eos_id], dtype=np.uint16)
        # 256-entry table mapping each ASCII byte value to a token id, so a
        # whole sequence can be encoded with one numpy fancy-indexing lookup.
        # Bytes outside the DNA alphabet fall back to [UNK]. Special tokens
        # are multi-character bracketed strings that never occur in raw
        # nucleotide input, so they need no byte mapping here.
        self._numpy_lookup = np.full(256, self._unk_id, dtype=np.uint16)
        for ch in self.characters:
            self._numpy_lookup[ord(ch)] = self._vocab_str_to_int[ch]
add_prefix_space = kwargs.pop("add_prefix_space", False)
padding_side = kwargs.pop("padding_side", "left")
super().__init__(
bos_token=bos_token,
eos_token=eos_token,
sep_token=sep_token,
cls_token=cls_token,
pad_token=pad_token,
mask_token=mask_token,
unk_token=unk_token,
add_prefix_space=add_prefix_space,
model_max_length=model_max_length,
padding_side=padding_side,
**kwargs,
)
@property
def vocab_size(self) -> int:
return len(self._vocab_str_to_int)
def _tokenize(self, text: str) -> List[str]:
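        # Character-level tokenization: every base is its own token.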
return list(text)
def _convert_token_to_id(self, token: str) -> int:
return self._vocab_str_to_int.get(token, self._vocab_str_to_int["[UNK]"])
    def _convert_id_to_token(self, index: int) -> str:
        # Unknown ids decode to [UNK] rather than raising KeyError.
        return self._vocab_int_to_str.get(index, "[UNK]")
def convert_tokens_to_string(self, tokens):
return "".join(tokens)
def get_special_tokens_mask(
self,
token_ids_0: List[int],
token_ids_1: Optional[List[int]] = None,
already_has_special_tokens: bool = False,
) -> List[int]:
if already_has_special_tokens:
return super().get_special_tokens_mask(
token_ids_0=token_ids_0,
token_ids_1=token_ids_1,
already_has_special_tokens=True,
)
        # build_inputs_with_special_tokens prepends [BOS] and appends [SEP],
        # so mark both positions (1 = special token, 0 = sequence token).
        result = [1] + ([0] * len(token_ids_0)) + [1]
        if token_ids_1 is not None:
            result += ([0] * len(token_ids_1)) + [1]
        return result
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
        bos = [self.bos_token_id]
        eos = [self.eos_token_id]
result = bos + token_ids_0 + eos
if token_ids_1 is not None:
result += token_ids_1 + eos
return result
def create_attention_mask(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None) -> List[int]:
"""Creates an attention mask to differentiate between padding and non-padding tokens.
Args:
token_ids_0 (List[int]): List of token IDs for the first sequence.
token_ids_1 (Optional[List[int]]): List of token IDs for the second sequence if available.
Returns:
List[int]: A list where 1 represents non-padding tokens and 0 represents padding tokens.
"""
mask = [1] * len(token_ids_0)
if token_ids_1 is not None:
mask += [1] * len(token_ids_1)
return mask
def get_vocab(self) -> Dict[str, int]:
return self._vocab_str_to_int
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        vocab_file = os.path.join(save_directory, (filename_prefix or "") + "vocab.json")
with open(vocab_file, 'w') as f:
json.dump(self._vocab_str_to_int, f)
return (vocab_file,)
def __call__(
self,
text: Union[str, List[str]],
*,
        padding: Union[bool, str] = True,
truncation: bool = True,
max_length: Optional[int] = None,
add_special_tokens: bool = True,
return_tensors: Optional[str] = None,
**kwargs,
):
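        """Fast numpy-based encoding path.

        Supports a subset of the standard tokenizer call options: padding
        (True, "longest", "max_length", or False), truncation, max_length,
        add_special_tokens, and return_tensors in {"pt", "np", None}.
        Extra kwargs are accepted for interface compatibility and ignored.
        """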
# ---------- detect batch vs single ----------
is_batch = not isinstance(text, str)
seqs = text if is_batch else [text] # always work on a list internally
max_len = max_length or self.model_max_length
# ---------- encode every sequence ----------
batch_input_ids = []
for seq in seqs:
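            # Encode via the byte lookup table: view the sequence as raw ASCII
            # bytes and translate them all at once. Non-ASCII characters are
            # dropped by "ignore"; unmapped bytes (e.g. lowercase bases)
            # become [UNK].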
seq_bytes = np.frombuffer(seq.encode("ascii", "ignore"), dtype=np.uint8)
ids = self._numpy_lookup[seq_bytes]
if add_special_tokens:
ids = np.concatenate((self._bos_np, ids, self._eos_np))
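            # Truncation happens after specials are added, so an over-long
            # sequence keeps [BOS] but may lose its trailing [SEP].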
if truncation and ids.size > max_len:
ids = ids[:max_len]
batch_input_ids.append(ids.astype(np.uint16, copy=False))
# ---------- pad ----------
        if padding and batch_input_ids:
            if padding == "max_length":
                pad_len = max_len
            elif padding == "longest" or padding is True:
                pad_len = max(ids.size for ids in batch_input_ids)
            else:
                pad_len = None
            if pad_len is not None:
                pad_len = min(pad_len, max_len)
                padded_ids = []
                for ids in batch_input_ids:
                    if ids.size < pad_len:
                        pad_width = pad_len - ids.size
                        # Respect the configured padding side (default "left",
                        # set in __init__).
                        if self.padding_side == "left":
                            ids = np.pad(ids, (pad_width, 0), constant_values=self._pad_id)
                        else:
                            ids = np.pad(ids, (0, pad_width), constant_values=self._pad_id)
                    elif ids.size > pad_len:
                        ids = ids[:pad_len]
                    padded_ids.append(np.ascontiguousarray(ids, dtype=np.uint16))
                batch_input_ids = padded_ids
# ---------- masks ----------
batch_attention = []
for ids in batch_input_ids:
if isinstance(ids, np.ndarray):
mask = (ids != self._pad_id).astype(np.uint8, copy=False)
else:
mask = [0 if tok == self._pad_id else 1 for tok in ids]
batch_attention.append(mask)
# ---------- collapse back if it was a single example ----------
if not is_batch:
batch_input_ids = batch_input_ids[0]
batch_attention = batch_attention[0]
result = {
"input_ids": batch_input_ids,
"attention_mask": batch_attention,
}
# ---------- convert to tensors if requested ----------
if return_tensors == "pt":
import torch
if is_batch:
result["input_ids"] = torch.tensor(np.stack(result["input_ids"]), dtype=torch.long)
result["attention_mask"] = torch.tensor(np.stack(result["attention_mask"]), dtype=torch.long)
else:
result["input_ids"] = torch.tensor(result["input_ids"], dtype=torch.long).unsqueeze(0)
result["attention_mask"] = torch.tensor(result["attention_mask"], dtype=torch.long).unsqueeze(0)
elif return_tensors == "np":
if is_batch:
result["input_ids"] = np.stack(result["input_ids"])
result["attention_mask"] = np.stack(result["attention_mask"])
else:
result["input_ids"] = np.expand_dims(result["input_ids"], 0)
result["attention_mask"] = np.expand_dims(result["attention_mask"], 0)
return result
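

# --- Minimal usage sketch (illustrative, not part of the upstream file) ---
# Assumes `transformers` and `numpy` are installed; exercises the fast
# numpy encoding path for a single sequence and for a batch.
if __name__ == "__main__":
    tokenizer = HybriDNATokenizer(model_max_length=32)

    # Single sequence: "[BOS] A C G T N [SEP]" -> ids [2, 7, 8, 9, 10, 11, 1].
    enc = tokenizer("ACGTN", return_tensors="np")
    print(enc["input_ids"])       # shape (1, 7), no padding needed
    print(enc["attention_mask"])  # all ones

    # Batch: the shorter sequence is padded with [PAD] (id=4) on the
    # configured padding side and masked out in attention_mask.
    batch = tokenizer(["ACGT", "AC"], return_tensors="np")
    print(batch["input_ids"].shape)  # (2, 6): padded to the longest sequence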