# coding=utf-8
"""
Vocabulary helper class
"""

import re

import numpy as np


class Vocabulary:
    """Bidirectional mapping between string tokens and integer ids.

    Both directions live in the single dictionary ``self._tokens``: every
    token maps to its id and every id maps back to its token. That is why
    ``len(vocabulary)`` is half the size of the underlying dictionary.
    """

    def __init__(self, tokens=None, starting_id=0):
        """Create a vocabulary.

        :param tokens: Optional mapping ``{token: id}`` used to pre-populate
            the vocabulary.
        :param starting_id: First id handed out by :meth:`add`; it is raised
            past the largest id found in ``tokens``.
        :raises ValueError: If ``tokens`` assigns the same id twice.
        """
        self._tokens = {}
        self._current_id = starting_id

        if tokens:
            for token, idx in tokens.items():
                self._add(token, idx)
                # Keep the next free id strictly above every preset id.
                self._current_id = max(self._current_id, idx + 1)

    def __getitem__(self, token_or_id):
        """Return the id of a token, or the token of an id."""
        return self._tokens[token_or_id]

    def add(self, token):
        """Add a token and return its newly assigned id.

        :param token: Token to register; must be a ``str``.
        :return: The new id, or ``None`` when the token was already present
            (a notice is printed and the vocabulary is left untouched).
        :raises TypeError: If ``token`` is not a string.
        """
        if not isinstance(token, str):
            raise TypeError("Token is not a string")
        if token in self:
            # Duplicates are reported but deliberately not treated as errors.
            print(f'=== Token "{token}"already present in the vocabulary')
            return None
        new_id = self._current_id
        self._add(token, new_id)
        self._current_id += 1
        return new_id

    def update(self, tokens):
        """Add many tokens and return the list of :meth:`add` results."""
        return [self.add(token) for token in tokens]

    def __delitem__(self, token_or_id):
        """Remove a token/id pair, addressed by either the token or the id."""
        other_val = self._tokens[token_or_id]
        del self._tokens[other_val]
        del self._tokens[token_or_id]

    def __contains__(self, token_or_id):
        """Tell whether a token or an id is known to the vocabulary."""
        return token_or_id in self._tokens

    def __eq__(self, other_vocabulary):
        """Compare two vocabularies by their token/id mappings."""
        if not isinstance(other_vocabulary, Vocabulary):
            # Let Python fall back to its default handling instead of
            # failing with AttributeError on a foreign object.
            return NotImplemented
        return self._tokens == other_vocabulary._tokens

    def __len__(self):
        """Return the number of tokens (each one is stored twice)."""
        return len(self._tokens) // 2

    def encode(self, tokens):
        """Encode a list of tokens as a 1-D ``float32`` array of token ids.

        Despite the "ohe" naming, the result holds one id per token rather
        than one-hot rows. Unknown tokens are mapped to the id of the
        ``"default_key"`` token.

        :param tokens: Sequence of tokens to encode.
        :return: ``numpy.ndarray`` of shape ``(len(tokens),)``.
        :raises KeyError: If a token is unknown and ``"default_key"`` is not
            in the vocabulary either.
        """
        ohe_vect = np.zeros(len(tokens), dtype=np.float32)
        for i, token in enumerate(tokens):
            try:
                ohe_vect[i] = self._tokens[token]
            except KeyError:
                ohe_vect[i] = self._tokens["default_key"]
        return ohe_vect

    def decode(self, ohe_vect):
        """Decode an iterable of ids back to a list of tokens.

        Ids that are not in the vocabulary are rendered as the literal
        string ``"default_key"``.
        """
        tokens = []
        for ohv in ohe_vect:
            try:
                tokens.append(self[ohv])
            except KeyError:
                tokens.append("default_key")
        return tokens

    def _add(self, token, idx):
        """Register the ``token <-> idx`` pair in both directions.

        :raises ValueError: If either the id or the token is already
            registered; accepting it would leave a stale reverse entry.
        """
        if idx in self._tokens:
            raise ValueError("IDX already present in vocabulary")
        if token in self._tokens:
            raise ValueError("Token already present in vocabulary")
        self._tokens[token] = idx
        self._tokens[idx] = token

    def tokens(self):
        """Returns the tokens from the vocabulary"""
        return [t for t in self._tokens if isinstance(t, str)]

    def word2idx(self):
        """Return a plain ``{token: id}`` dictionary."""
        return {k: v for k, v in self._tokens.items() if isinstance(k, str)}


class SMILESTokenizer:
    """Deals with the tokenization and untokenization of SMILES."""

    # Each pattern has exactly one capturing group so that ``re.split``
    # keeps the matched token at the odd positions of its result.
    REGEXPS = {
        "brackets": re.compile(r"(\[[^\]]*\])"),
        "2_ring_nums": re.compile(r"(%\d{2})"),
        "brcl": re.compile(r"(Br|Cl)")
    }
    # Order in which the patterns are applied: text left over by one pattern
    # is split further by the next one.
    REGEXP_ORDER = ["brackets", "2_ring_nums", "brcl"]

    def tokenize(self, data, with_begin_and_end=True):
        """Tokenizes a SMILES string.

        :param data: SMILES string to split.
        :param with_begin_and_end: When true, wrap the result in the start
            (``"^"``) and end (``"$"``) tokens.
        :return: List of string tokens.
        """
        def split_by(data, regexps):
            # No pattern left: what remains is split into single characters.
            if not regexps:
                return list(data)
            regexp = self.REGEXPS[regexps[0]]
            splitted = regexp.split(data)
            tokens = []
            for i, split in enumerate(splitted):
                if i % 2 == 0:
                    # Even positions hold unmatched text: keep splitting.
                    tokens += split_by(split, regexps[1:])
                else:
                    # Odd positions hold a full match: it is one token.
                    tokens.append(split)
            return tokens

        tokens = split_by(data, self.REGEXP_ORDER)
        if with_begin_and_end:
            tokens = ["^"] + tokens + ["$"]
        return tokens

    def untokenize(self, tokens):
        """Untokenizes a SMILES string.

        Start tokens (``"^"``) are skipped and everything from the first
        end token (``"$"``) onwards is discarded.
        """
        pieces = []
        for token in tokens:
            if token == "$":
                break
            if token != "^":
                pieces.append(token)
        return "".join(pieces)


def create_vocabulary(smiles_list, tokenizer, property_condition=None):
    """Creates a vocabulary for the SMILES syntax.

    :param smiles_list: Iterable of SMILES strings to collect tokens from.
    :param tokenizer: Object exposing
        ``tokenize(smi, with_begin_and_end=False)``.
    :param property_condition: Optional iterable of extra tokens appended
        after the SMILES tokens.
    :return: A :class:`Vocabulary` with pad/start/end at ids 0/1/2, then the
        sorted SMILES tokens, then the optional condition tokens.
    """
    tokens = set()
    for smi in smiles_list:
        tokens.update(tokenizer.tokenize(smi, with_begin_and_end=False))

    vocabulary = Vocabulary()
    vocabulary.update(["*", "^", "$"] + sorted(tokens))  # pad=0, start=1, end=2
    if property_condition is not None:
        vocabulary.update(property_condition)
    # for random smiles
    if "8" not in vocabulary:
        vocabulary.update(["8"])

    return vocabulary