"""Custom SMILES tokenizer implementation.""" import re import json import warnings from re import Pattern from typing import Dict, List, Optional, Union, Any import torch # Vocabulary class class SmilesVocabulary(Vocabulary): def __init__(self, pad="", eos="", unk="", go=""): self.unk_word, self.pad_word, self.eos_word, self.go_word = ( unk, pad, eos, go, ) self.symbols = [] self.count = [] self.indices = {} self.pad_index = self.add_symbol(pad) self.eos_index = self.add_symbol(eos) self.unk_index = self.add_symbol(unk) self.go_index = self.add_symbol(go) self.nspecial = len(self.symbols) for token in self.__get_smile_tokens(): self.add_symbol(token) def __get_smile_tokens(self): SMILE_TOKENS = [ "S", "O", "2", "n", "l", "F", "H", "C", "o", "5", "r", "s", "=", "6", "[", "N", "4", "c", "-", "3", ")", "#", "]", "B", "(", "1", ] return SMILE_TOKENS def finalize(self, threshold=-1, nwords=-1, padding_factor=1): super(SmilesVocabulary, self).finalize( threshold=threshold, nwords=nwords, padding_factor=padding_factor ) def go(self): """GO index.""" return self.go_index @classmethod def load(cls, f=None, ignore_utf_errors=False): """Load function for SMILE data. Ignore the file and just initialize the vocab. """ return cls() # Tokenizer class class SmilesTokenizer: """ Smiles Tokenizer """ def __init__( self, vocabulary: Vocabulary = None, ) -> None: if vocabulary is None: self.vocabulary = SmilesVocabulary() else: self.vocabulary = vocabulary self._re: Optional[Pattern] = None @property def re(self) -> Pattern: """Tokens Regex Object. :return: Tokens Regex Object """ if not self._re: self._re = self._get_compiled_regex(self.vocabulary.symbols) return self._re def tokenize(self, smiles: List[str], enclose: bool = True) -> List[List[str]]: """ convert list of smiles strings to list of lists containing tokens for each """ if isinstance(smiles, str): # Convert string to a list with one string smiles = [smiles] tokenized_data = [] for smi in smiles: tokens = self.re.findall(smi) if enclose: tokenized_data.append( [self.vocabulary.go_word] + tokens + [self.vocabulary.eos_word] ) else: tokenized_data.append(tokens) return tokenized_data def encode(self, smiles: List[str], enclose: bool = True, aslist=False): """ convert a list of smiles strings to list of tensors containing token indices """ if isinstance(smiles, str): # Convert string to a list with one string smiles = [smiles] tokenized_smiles = self.tokenize(smiles, enclose=enclose) tokens_lengths = list(map(len, tokenized_smiles)) ids_list = [] for tokens, length in zip(tokenized_smiles, tokens_lengths): ids_tensor = [] # torch.zeros(length, dtype=torch.long) for tdx, token in enumerate(tokens): ids_tensor.append(self.vocabulary.index(token)) if not aslist: ids_tensor = torch.tensor(ids_tensor, dtype=torch.long) ids_list.append(ids_tensor) return ids_list def detokenize( self, token_data: List[List[str]], include_control_tokens: bool = False, include_end_of_line_token: bool = False, truncate_at_end_token: bool = False, ) -> List[str]: """ Detokenizes lists of tokens into SMILES by concatenating the token strings. """ character_lists = [tokens.copy() for tokens in token_data] character_lists = [ self._strip_list( tokens, strip_control_tokens=not include_control_tokens, truncate_at_end_token=truncate_at_end_token, ) for tokens in character_lists ] if include_end_of_line_token: for s in character_lists: s.append("\n") strings = ["".join(s) for s in character_lists] return strings def decode(self, ids_list: List[torch.Tensor]): """ decodes lists of encodings (ids as tensors) back into smiles strings """ tokenized_smiles = [] for ids in ids_list: if not isinstance(ids, list): ids = ids.tolist() tokens = [self.vocabulary[i] for i in ids] tokenized_smiles.append(tokens) smiles = self.detokenize(tokenized_smiles, truncate_at_end_token=True) return smiles def tokens_to_smiles(self, tokens): """ Convert generated tokens to smiles. Arguments: tokens: list of tokens Returns: list of smiles strings """ # convert tokens to smiles smiles = self.decode(tokens) smiles = [smi.replace("", "") for smi in smiles] return smiles def _strip_list( self, tokens: List[str], strip_control_tokens: bool = False, truncate_at_end_token: bool = False, ) -> List[str]: """Cleanup tokens list from control tokens. :param tokens: List of tokens :param strip_control_tokens: Flag to remove control tokens, defaults to False :param truncate_at_end_token: If True truncate tokens after end-token """ if truncate_at_end_token and self.vocabulary.eos_word in tokens: end_token_idx = tokens.index(self.vocabulary.eos_word) tokens = tokens[: end_token_idx + 1] strip_characters: List[str] = [self.vocabulary.pad_word] if strip_control_tokens: strip_characters.extend([self.vocabulary.go_word, self.vocabulary.eos_word]) while len(tokens) > 0 and tokens[0] in strip_characters: tokens.pop(0) while len(tokens) > 0 and tokens[-1] in strip_characters: tokens.pop() return tokens def _get_compiled_regex(self, tokens: List[str]) -> Pattern: """Create a Regular Expression Object from a list of tokens and regular expression tokens. :param tokens: List of tokens :return: Regular Expression Object """ regex_string = r"(" # r"(" for ix, token in enumerate(tokens): processed_token = token for special_character in "()[]+*": processed_token = processed_token.replace( special_character, f"\{special_character}" ) if ix < len(tokens) - 1: regex_string += processed_token + r"|" else: regex_string += processed_token regex_string += r")" pattern = re.compile(regex_string) return pattern