"""Hugging Face compatible SMILES tokenizer wrapper."""

import os
import json

import torch
from transformers import PreTrainedTokenizer

from .smiles_tokenizer import SmilesTokenizer, SmilesVocabulary


class HFSmilesTokenizer(PreTrainedTokenizer):
"""
Wrapper class for the SmilesTokenizer to make it compatible with the Hugging Face tokenizer interface.
This allows the tokenizer to be used with any Hugging Face model, especially GPT-2.
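
    Typical usage (a sketch; "CCO", ethanol, is an illustrative SMILES string):

        tokenizer = HFSmilesTokenizer()
        batch = tokenizer("CCO")  # -> input_ids and attention_mask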
"""

    # Required for Hugging Face tokenizers
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
self,
vocab=None,
pad_token="<pad>",
eos_token="</s>",
unk_token="<unk>",
bos_token="<go>",
**kwargs
):
# Initialize the base tokenizer
if vocab is None:
self.smiles_tokenizer = SmilesTokenizer()
else:
vocabulary = SmilesVocabulary(
pad=pad_token,
eos=eos_token,
unk=unk_token,
go=bos_token
)
            # Add custom vocab symbols when provided as a list of extra
            # tokens (any other type is silently ignored)
if isinstance(vocab, list):
for token in vocab:
vocabulary.add_symbol(token)
self.smiles_tokenizer = SmilesTokenizer(vocabulary=vocabulary)
        # Build the token<->id maps BEFORE calling super().__init__:
        # recent transformers versions look up the vocabulary (via
        # get_vocab) while the base class initializes, so these
        # attributes must already exist.
self._vocab = {
token: idx for idx, token in enumerate(self.smiles_tokenizer.vocabulary.symbols)
}
self._ids_to_tokens = {
idx: token for token, idx in self._vocab.items()
}
# Initialize the PreTrainedTokenizer with our special tokens
super().__init__(
unk_token=unk_token,
pad_token=pad_token,
eos_token=eos_token,
bos_token=bos_token,
**kwargs
)

    @property
def vocab_size(self):
"""Return the size of vocabulary."""
return len(self._vocab)

    def get_vocab(self):
"""Return the vocabulary dictionary."""
return self._vocab

    def _tokenize(self, text):
        """
        Tokenize a string into a list of tokens.

        The underlying SmilesTokenizer operates on batches, so a single
        string is wrapped in a one-element list and the first (and only)
        result is returned.
        """
        if isinstance(text, list):
            return self.smiles_tokenizer.tokenize(text, enclose=False)[0]
        return self.smiles_tokenizer.tokenize([text], enclose=False)[0]

    def _convert_token_to_id(self, token):
"""
Convert a token to its ID.
"""
return self.smiles_tokenizer.vocabulary.index(token)

    def _convert_id_to_token(self, index):
"""
Convert an ID to its token.
"""
return self.smiles_tokenizer.vocabulary[index]

    def convert_tokens_to_string(self, tokens):
"""
Convert a list of tokens to a string.
"""
return "".join(tokens)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
"""
Build model inputs from a sequence by adding special tokens.

        Returns:
            List[int]: List of input IDs with special tokens added.
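
        Example (illustrative ids; assumes bos_token_id=1 and eos_token_id=2):
            build_inputs_with_special_tokens([5, 6, 7]) -> [1, 5, 6, 7, 2]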
"""
bos_token_id = self.bos_token_id
eos_token_id = self.eos_token_id
if token_ids_1 is None:
return [bos_token_id] + token_ids_0 + [eos_token_id]
# For sequence pairs, we follow GPT-2 format: <bos> seq1 <eos> seq2 <eos>
return [bos_token_id] + token_ids_0 + [eos_token_id] + token_ids_1 + [eos_token_id]

    def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
"""
Retrieve sequence of special tokens mask.

        Returns:
            List[int]: A list of integers where 1 indicates a special
            token and 0 indicates a sequence token.
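
        Example (illustrative; a single three-token sequence):
            get_special_tokens_mask([5, 6, 7]) -> [1, 0, 0, 0, 1]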
"""
if already_has_special_tokens:
return super().get_special_tokens_mask(
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
)
if token_ids_1 is None:
return [1] + [0] * len(token_ids_0) + [1]
return [1] + [0] * len(token_ids_0) + [1] + [0] * len(token_ids_1) + [1]

    def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
"""
Create token type IDs for sequence pairs.

        Returns:
            List[int]: List of token type IDs.
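
        Example (illustrative; GPT-2-style models ignore token types, so
        every position, including <bos> and <eos>, gets 0):
            create_token_type_ids_from_sequences([5, 6, 7]) -> [0, 0, 0, 0, 0]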
"""
        if token_ids_1 is None:
            return [0] * (len(token_ids_0) + 2)  # +2 for <bos> and <eos>
        # For GPT-2, we use all 0s for token type IDs
        return [0] * (len(token_ids_0) + len(token_ids_1) + 3)  # +3 for <bos> and two <eos>

    def save_vocabulary(self, save_directory, filename_prefix=None):
"""
Save the tokenizer vocabulary to a directory.
"""
        os.makedirs(save_directory, exist_ok=True)
vocab_file = os.path.join(
save_directory,
(filename_prefix + "-" if filename_prefix else "") + "vocab.json"
)
with open(vocab_file, "w", encoding="utf-8") as f:
json.dump(self._vocab, f, ensure_ascii=False, indent=2)
return (vocab_file,)

    def encode_smiles(self, smiles, enclose=True, return_tensors=None):
"""
Encode a list of SMILES strings using the original SmilesTokenizer functionality.

        Args:
            smiles: A list of SMILES strings or a single SMILES string.
            enclose: Whether to add special tokens.
            return_tensors: The type of tensors to return ('pt' for PyTorch, None for lists).

        Returns:
            List of token-ID lists, or a list of PyTorch tensors.
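
        Example (a sketch; "CCO", ethanol, is an illustrative SMILES string,
        and the actual ids depend on the vocabulary):
            encode_smiles(["CCO"], return_tensors="pt") -> [tensor([...])]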
"""
ids_list = self.smiles_tokenizer.encode(smiles, enclose=enclose, aslist=True)
if return_tensors == "pt":
return [torch.tensor(ids, dtype=torch.long) for ids in ids_list]
return ids_list

    def decode_smiles(self, ids_list):
"""
Decode a list of token IDs back to SMILES strings using the original SmilesTokenizer functionality.

        Args:
            ids_list: A list of lists or tensors containing token IDs.

        Returns:
            List of SMILES strings.
"""
return self.smiles_tokenizer.decode(ids_list)

    def tokens_to_smiles(self, tokens):
"""
Convert generated tokens to SMILES strings.

        Args:
            tokens: List of token IDs.

        Returns:
            List of SMILES strings.
"""
return self.smiles_tokenizer.tokens_to_smiles(tokens)
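

# ---------------------------------------------------------------------------
# Usage sketch, not part of the tokenizer itself. It assumes the sibling
# smiles_tokenizer module behaves as the wrapper expects; because of the
# relative import at the top, this file must be run as part of its package
# rather than executed directly. "CCO" (ethanol) is just an illustrative
# SMILES string.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    tokenizer = HFSmilesTokenizer()

    # Standard Hugging Face entry point: adds the <go>/</s> special tokens
    # and an attention mask.
    batch = tokenizer("CCO")
    print(batch["input_ids"], batch["attention_mask"])

    # Round trip through the SMILES-specific helpers.
    ids = tokenizer.encode_smiles(["CCO"], return_tensors="pt")
    print(tokenizer.decode_smiles(ids))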