schrum2 commited on
Commit
f93d408
·
verified ·
1 Parent(s): 0ff087c

load from here

Browse files
Files changed (1) hide show
  1. tokenizer/tokenizer.py +147 -0
tokenizer/tokenizer.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import re
3
+ from collections import Counter
4
+ import pickle
5
+ import argparse
6
+
7
class Tokenizer:
    """Word-level tokenizer for level captions with [PAD]/[MASK] special tokens.

    The vocabulary is built from a JSON dataset of ``{"caption": ...}`` entries,
    maps each token to an integer id, and can be round-tripped through pickle.
    """

    # Compiled once at class level instead of on every tokenize() call.
    # Matches words, single '.'/',' punctuation, and the literal lowercase
    # markers '[mask]' / '[pad]' (input is lowercased before matching).
    _TOKEN_RE = re.compile(r'\w+|[.,]|\[mask\]|\[pad\]')

    def __init__(self):
        # Special tokens are always placed at the front of the vocabulary,
        # so "[PAD]" gets id 0 and "[MASK]" gets id 1 after build_vocab().
        self.special_tokens = ["[PAD]", "[MASK]"]
        self.vocab = {}          # token -> id
        self.token_to_id = {}    # alias of self.vocab once built/loaded
        self.id_to_token = {}    # id -> token (inverse mapping)

    def tokenize(self, text):
        """Lowercase *text* and split it into word/punctuation tokens.

        The special markers are restored to their canonical uppercase form
        ("[MASK]", "[PAD]") after lowercasing.

        Returns:
            list[str]: the token strings in order of appearance.
        """
        tokens = self._TOKEN_RE.findall(text.lower())
        # Restore MASK and PAD to all caps.
        return [t.replace("[mask]", "[MASK]").replace("[pad]", "[PAD]") for t in tokens]

    def pad_sequence(self, tokens, length):
        """Pads tokenized sequences to length with a padding token (assumed to be '[PAD]')."""
        if len(tokens) > length:
            raise ValueError(f"Token sequence length {len(tokens)} exceeds specified length {length}.")

        pad_token = self.token_to_id["[PAD]"]
        return tokens + [pad_token] * (length - len(tokens))

    def build_vocab(self, dataset_path, min_freq=1):
        """Build the vocabulary from a JSON list of entries with a 'caption' key.

        Args:
            dataset_path (str): path to the JSON dataset file.
            min_freq (int): minimum number of occurrences for a token to be kept.
        """
        token_counter = Counter()

        with open(dataset_path, 'r') as f:
            data = json.load(f)
            for entry in data:
                tokens = self.tokenize(entry['caption'])
                token_counter.update(tokens)

        # Keep tokens that meet the min frequency.
        kept = [tok for tok, count in token_counter.items() if count >= min_freq]

        # Ensure special tokens are always included (and come first, so that
        # '[PAD]' -> 0 regardless of the dataset contents).
        all_tokens = self.special_tokens + sorted(kept)

        # Build vocab dictionaries; token_to_id deliberately aliases vocab.
        self.vocab = {tok: idx for idx, tok in enumerate(all_tokens)}
        self.token_to_id = self.vocab
        self.id_to_token = {idx: tok for tok, idx in self.vocab.items()}

        print(f"Vocabulary size: {len(self.vocab)}")

    def encode(self, text):
        """Convert *text* to a list of token ids.

        Raises:
            ValueError: if a token is not present in the vocabulary.
        """
        tokens = self.tokenize(text)
        encoded = []
        for tok in tokens:
            if tok not in self.token_to_id:
                raise ValueError(f"Unknown token encountered: {tok} in {text}")
            encoded.append(self.token_to_id[tok])
        return encoded

    def encode_batch(self, texts, pad_to_length=None):
        """
        Encode a batch of texts into token IDs with padding to ensure uniform length.

        Args:
            texts (list): A list of strings to encode
            pad_to_length (int, optional): Length to pad all sequences to. If None,
                                           will pad to the length of the longest sequence.

        Returns:
            list: A list of lists, where each inner list contains the token IDs for a text
        """
        pad_token = self.token_to_id["[PAD]"]

        # First encode all texts.
        encoded_texts = []
        for text in texts:
            try:
                encoded_texts.append(self.encode(text))
            except ValueError as e:
                # Chain the original error so the unknown-token detail survives.
                raise ValueError(f"Error encoding text: {text}. {str(e)}") from e

        # Determine padding length; default=0 keeps an empty batch from
        # crashing max() and lets us simply return [].
        if pad_to_length is None:
            pad_to_length = max((len(seq) for seq in encoded_texts), default=0)

        # Pad (or truncate) every sequence to the uniform length.
        padded_texts = []
        for seq in encoded_texts:
            if len(seq) > pad_to_length:
                padded_texts.append(seq[:pad_to_length])
            else:
                padded_texts.append(seq + [pad_token] * (pad_to_length - len(seq)))

        return padded_texts

    def decode(self, token_ids):
        """Convert a list of token ids back into a space-joined string."""
        return ' '.join(self.id_to_token[tok_id] for tok_id in token_ids)

    def save(self, path):
        """Persist the vocabulary to *path* as a pickle file."""
        with open(path, 'wb') as f:
            pickle.dump({'vocab': self.vocab}, f)

    def load(self, path):
        """Restore the vocabulary from a pickle file written by save().

        SECURITY NOTE: pickle.load can execute arbitrary code; only load
        tokenizer files from trusted sources.
        """
        with open(path, 'rb') as f:
            data = pickle.load(f)
            self.vocab = data['vocab']
            self.token_to_id = self.vocab
            self.id_to_token = {idx: tok for tok, idx in self.vocab.items()}

    def get_vocab(self):
        """Return all vocabulary tokens in sorted order."""
        return sorted(self.vocab.keys())

    def get_vocab_size(self):
        """Return the number of tokens in the vocabulary."""
        return len(self.vocab)
+
125
if __name__ == "__main__":
    # Parse arguments before doing any work.
    parser = argparse.ArgumentParser(description="Tokenizer utility for saving and loading vocabularies.")
    parser.add_argument("action", choices=["save", "load"], help="Action to perform: 'save' or 'load'.")
    parser.add_argument("--json_file", type=str, default='Mario_LevelsAndCaptions.json', help="Path to the JSON file containing the dataset (required for 'save').")
    parser.add_argument("--pkl_file", type=str, default='Mario_Tokenizer.pkl', help="Path to the pickle file to save/load the tokenizer.")

    args = parser.parse_args()

    tokenizer = Tokenizer()

    # No emptiness check on args.json_file is needed: the argument has a
    # non-empty default, so it can never be falsy here.
    if args.action == "save":
        tokenizer.build_vocab(args.json_file)
        tokenizer.save(args.pkl_file)
    elif args.action == "load":
        tokenizer.load(args.pkl_file)

    # Example usage
    #print(tokenizer.encode("floor with one gap. one enemy."))
    #print(tokenizer.get_vocab())
    #for id, token in tokenizer.id_to_token.items():
    #    print(id, ":", token)