martymukherjee commited on
Commit
a9fdaa8
·
verified ·
1 Parent(s): ac70ffc

Create LyapunovTokenizer.py

Browse files
Files changed (1) hide show
  1. LyapunovTokenizer.py +92 -0
LyapunovTokenizer.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+ import sympy as sp
3
+ from transformers import PreTrainedTokenizer
4
+ import json
5
+ import os
6
+ from huggingface_hub import upload_folder
7
+
8
# Reserved tokens: sequence delimiters, padding, parentheses, plus ten
# numbered placeholder slots kept free for future special symbols.
_BASE_SPECIALS = ["<s>", "</s>", "<pad>", "(", ")"]
_PLACEHOLDER_SPECIALS = [f"<SPECIAL_{slot}>" for slot in range(10)]
SPECIAL_WORDS = _BASE_SPECIALS + _PLACEHOLDER_SPECIALS
11
class LyapunovTokenizer(PreTrainedTokenizer):
    """Whitespace-split tokenizer over a fixed symbolic-expression vocabulary.

    The vocabulary covers prefix-notation math tokens: arithmetic and
    transcendental operators, variables ``x0 .. x{2*max_degree-1}``, the
    constants ``pi``/``E``, integer digit tokens in base ``int_base``, sign/
    exponent markers, and the special/mask tokens.  Tokenization itself is a
    plain ``str.split`` — inputs are expected to be pre-spaced token streams.
    """

    def __init__(self):
        # Mapping from SymPy node classes to the token used to emit them.
        # NOTE(review): consumed by expression-serialization code outside this
        # class — kept as a public attribute for that reason.
        self.SYMPY_OPERATORS = {
            sp.Add: "+",
            sp.Mul: "*",
            sp.Pow: "^",
            sp.exp: "exp",
            sp.log: "ln",
            sp.Abs: "Abs",
            sp.sin: "sin",
            sp.cos: "cos",
            sp.tan: "tan",
            sp.asin: "asin",
            sp.acos: "acos",
            sp.atan: "atan",
            sp.DiracDelta: "delta0",
        }

        self.trig_ops = ["sin", "cos", "tan"]
        self.arctrig_ops = ["asin", "acos", "atan"]
        self.exp_ops = ["exp", "ln"]
        self.other_ops = ["sqrt"]

        # Operator token -> arity.
        op_set = {
            "+": 2,
            "-": 2,
            "*": 2,
            "/": 2,
            "^": 2,
            "sqrt": 1,
            "exp": 1,
            "ln": 1,
            "sin": 1,
            "cos": 1,
            "tan": 1,
            "asin": 1,
            "acos": 1,
            "atan": 1,
            "Abs": 1,
        }

        self.int_base = 1000  # integers are encoded as digit tokens in this base
        self.max_degree = 6

        self.operators_lyap = op_set
        self.operators = self.operators_lyap

        self.variables = OrderedDict(
            {f"x{i}": sp.Symbol(f"x{i}") for i in range(2 * self.max_degree)}
        )
        self.constants = ["pi", "E"]
        self.symbols = ["I", "INT+", "INT-", "FLOAT+", "FLOAT-", ".", "10^"]
        # max(10, int_base) guarantees at least the decimal digits exist even
        # for a small base.
        self.elements = [str(i) for i in range(max(10, self.int_base))]
        self.mask_symbol = ["<mask>"]

        self.words = (
            SPECIAL_WORDS
            + self.constants
            + list(self.variables.keys())
            + list(self.operators.keys())
            + self.symbols
            + self.elements
            + self.mask_symbol
        )
        # BUGFIX: "<unk>" is declared as unk_token below but was never added
        # to the vocabulary, so `unk_token_id` could not resolve and the
        # fallback in `_convert_token_to_id` was broken.  Appending it at the
        # end keeps every pre-existing token id unchanged.
        if "<unk>" not in self.words:
            self.words.append("<unk>")

        self.vocab = {s: i for i, s in enumerate(self.words)}
        self.inv_vocab = {v: k for k, v in self.vocab.items()}
        super().__init__(
            model_max_length=2048,
            bos_token="<s>",
            eos_token="</s>",
            unk_token="<unk>",
            mask_token="<mask>",
        )

    def _tokenize(self, text):
        """Split *text* on whitespace; each piece should be a vocabulary word."""
        return text.split()

    def _convert_token_to_id(self, token):
        """Return the id of *token*, or the unk id for out-of-vocab tokens."""
        return self.vocab.get(token, self.unk_token_id)

    def _convert_id_to_token(self, index):
        """Return the token for *index*, or the unk token for unknown ids."""
        return self.inv_vocab.get(index, self.unk_token)

    def get_vocab(self):
        """Return the full token -> id mapping."""
        return self.vocab

    @property
    def vocab_size(self):
        """Number of tokens in the base vocabulary."""
        return len(self.vocab)

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """Write the vocabulary to ``[prefix-]vocab.json`` in *save_directory*.

        BUGFIX: *filename_prefix* was previously ignored, which breaks
        ``save_pretrained`` round-trips that pass a prefix (the HF contract
        is that the prefix is prepended to every saved file name).

        Returns:
            Tuple containing the path of the written vocab file.
        """
        prefix = "" if filename_prefix is None else filename_prefix + "-"
        vocab_file = os.path.join(save_directory, prefix + "vocab.json")
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self.vocab, f, ensure_ascii=False)
        return (vocab_file,)