# SMIRK_Tokenizer / tok.py
import regex as re
from tokenizers import Regex, Tokenizer, decoders, models, pre_tokenizers, processors
# Special control tokens. Their position in the final TOKENS list fixes
# their ids in VOCAB, so [PAD] gets id 0, [BOS] id 1, and so on.
SPECIAL = [
"[PAD]",
"[BOS]",
"[EOS]",
"[MASK]",
"[UNK]",
]
# fmt: off
# Periodic-table element symbols, listed row by row (lanthanides and
# actinides after the main table).
ELEMENTS = [
'H', 'He',
'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne',
'Na', 'Mg', 'Al', 'Si', 'P', 'S', 'Cl', 'Ar',
'K', 'Ca', 'Sc', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Cu', 'Zn', 'Ga', 'Ge', 'As', 'Se', 'Br', 'Kr',
'Rb', 'Sr', 'Y', 'Zr', 'Nb', 'Mo', 'Tc', 'Ru', 'Rh', 'Pd', 'Ag', 'Cd', 'In', 'Sn', 'Sb', 'Te', 'I', 'Xe',
'Cs', 'Ba', 'La', 'Hf', 'Ta', 'W', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg', 'Tl', 'Pb', 'Bi', 'Po', 'At', 'Rn',
'Fr', 'Ra', 'Ac', 'Rf', 'Db', 'Sg', 'Bh', 'Hs', 'Mt', 'Ds', 'Rg', 'Cn', 'Nh', 'Fl', 'Mc', 'Lv', 'Ts', 'Og',
'Ce', 'Pr', 'Nd', 'Pm', 'Sm', 'Eu', 'Gd', 'Tb', 'Dy', 'Ho', 'Er', 'Tm', 'Yb', 'Lu',
'Th', 'Pa', 'U', 'Np', 'Pu', 'Am', 'Cm', 'Bk', 'Cf', 'Es', 'Fm', 'Md', 'No', 'Lr'
]
ELEMENTS += ["te", "si"] # These are not 'correct', but RDKit is allowing them, and they show up in PubChem.
# fmt: on
# Lowercase aromatic atom symbols used by SMILES.
AROMATIC = ["b", "c", "n", "o", "p", "s", "se", "as"]
# "Organic subset" atoms — symbols SMILES allows outside of brackets.
ORGANIC = ["B", "C", "N", "O", "P", "S", "F", "I", "Cl", "Br"] + AROMATIC
# Bond/connectivity symbols ('.' separates disconnected components).
BONDS = ["-", "=", "#", "$", ":", "/", "\\", "."]
# Charge signs appearing inside bracket atoms.
CHARGE = ["+", "-"]
# Chirality markers ('@', '@@', and the named chirality classes).
CHIRAL = ["@", "@@", "@TH", "@AL", "@SP", "@TB", "@OH"]
# Branch delimiters plus the '*' wildcard atom.
BRANCH = ["(", ")", "*"]
# Ring-closure digits 0-9 and the '%' multi-digit prefix.
RINGS = [str(i) for i in range(10)] + ["%"]
# Assemble the full vocabulary in id order. Some groups overlap (the
# aromatic and organic subsets repeat symbols already in ELEMENTS), so
# duplicates are dropped while the first-seen order is preserved.
_seen = set()
TOKENS = []
for _group in (SPECIAL, ELEMENTS, AROMATIC, ORGANIC, BONDS, CHARGE, CHIRAL, BRANCH, RINGS, ["[", "]"]):
    for _tok in _group:
        if _tok not in _seen:
            _seen.add(_tok)
            TOKENS.append(_tok)
# Token -> integer id, assigned by position in TOKENS.
VOCAB = {}
for _idx, _tok in enumerate(TOKENS):
    VOCAB[_tok] = _idx
# Aromatic symbols partitioned by length.
# NOTE(review): these three sets are not referenced elsewhere in this chunk —
# presumably consumed by other code; verify before removing.
AROMATIC_SINGLE = set("bcnops")
AROMATIC_MULTI = {"se", "as"}  # two-letter aromatic tokens
AROMATIC_ALL = AROMATIC_SINGLE.union(AROMATIC_MULTI)
def is_ambiguous(elem: str, tokens=None) -> bool:
    """Return True if *elem* can be read as a concatenation of two tokens.

    A symbol like 'Sc' is ambiguous because 'S' and 'c' are both valid
    tokens on their own, so matching it outside of brackets could shadow
    the two-token reading.

    Args:
        elem: candidate element symbol (in practice one or two letters).
        tokens: collection of known tokens to test the two-way splits
            against; defaults to the module-level TOKENS vocabulary,
            preserving the original behavior.

    Returns:
        True if any split point yields a (head, tail) pair that are both
        in *tokens*; False otherwise (including for 1-char symbols, which
        have no split points).
    """
    pool = TOKENS if tokens is None else tokens
    return any(elem[:i] in pool and elem[i:] in pool for i in range(1, len(elem)))
# Element symbols that cannot be mistaken for two concatenated tokens;
# only these are safe to match outside of brackets.
UNAMBIGUOUS_ELEMENTS = list(filter(lambda sym: not is_ambiguous(sym), ELEMENTS))
# Tokens recognised at the top level of a string. Bracket atoms are
# captured whole by the [...] alternative and re-split by the inner pass.
OUTER_TOKENS = ORGANIC + BONDS + CHARGE + CHIRAL + BRANCH + RINGS + UNAMBIGUOUS_ELEMENTS
_outer_parts = [re.escape(tok) for tok in OUTER_TOKENS]
_outer_parts.append(r"\[[^\[\]]+\]")  # a whole bracket-atom expression
_outer_parts.sort(key=len, reverse=True)  # longest-first so multi-char tokens win
OUTER_REGEX = Regex("|".join(_outer_parts))
# Tokens recognised inside bracket atoms (second splitting pass).
INNER_TOKENS = ELEMENTS + AROMATIC + BONDS + CHARGE + CHIRAL + RINGS + ["%"] + ["[", "]"]
_inner_parts = [re.escape(tok) for tok in INNER_TOKENS]
_inner_parts.sort(key=len, reverse=True)
INNER_REGEX = Regex("|".join(_inner_parts))
# Word-level model: every emitted token must be in VOCAB; anything
# unrecognised maps to [UNK].
tokenizer = Tokenizer(models.WordLevel(VOCAB, unk_token="[UNK]"))
# Pad batches with [PAD], using its id from the vocabulary.
tokenizer.enable_padding(pad_id=tokenizer.token_to_id("[PAD]"), pad_token="[PAD]")
tokenizer.add_special_tokens(SPECIAL)
# Two-pass pre-tokenization: the outer split isolates top-level tokens and
# whole bracket atoms; the inner split then breaks the bracket atoms apart.
tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
[
pre_tokenizers.Split(pattern=OUTER_REGEX, behavior="isolated"),
pre_tokenizers.Split(pattern=INNER_REGEX, behavior="isolated"),
]
) # type: ignore
# Wrap every encoded single sequence as [BOS] ... [EOS].
tokenizer.post_processor = processors.TemplateProcessing(
single="[BOS] $A [EOS]", special_tokens=[("[BOS]", VOCAB["[BOS]"]), ("[EOS]", VOCAB["[EOS]"])]
) # type: ignore
# WordPiece decoder with an empty prefix concatenates tokens back together
# with no separator on decode.
tokenizer.decoder = decoders.WordPiece(prefix="") # type: ignore