schrum2 committed on
Commit
7bd88dc
·
verified ·
1 Parent(s): f9fce19

moving this

Browse files
Files changed (1) hide show
  1. tokenizer.py +0 -147
tokenizer.py DELETED
@@ -1,147 +0,0 @@
1
- import json
2
- import re
3
- from collections import Counter
4
- import pickle
5
- import argparse
6
-
7
- class Tokenizer:
8
- def __init__(self):
9
- self.special_tokens = ["[PAD]", "[MASK]"]
10
- self.vocab = {}
11
- self.token_to_id = {}
12
- self.id_to_token = {}
13
-
14
- def tokenize(self, text):
15
- # Match words, numbers, periods, and commas as separate tokens
16
- tokens = re.findall(r'\w+|[.,]|\[mask\]|\[pad\]', text.lower())
17
- # Restore MASK and PAD to all caps
18
- modified_list = []
19
- for s in tokens:
20
- modified_s = s.replace("[mask]", "[MASK]").replace("[pad]", "[PAD]")
21
- modified_list.append(modified_s)
22
- return modified_list
23
-
24
- def pad_sequence(self, tokens, length):
25
- """Pads tokenized sequences to length with a padding token (assumed to be '[PAD]')."""
26
- if len(tokens) > length:
27
- raise ValueError(f"Token sequence length {len(tokens)} exceeds specified length {length}.")
28
-
29
- pad_token = self.token_to_id["[PAD]"]
30
- return tokens + [pad_token] * (length - len(tokens))
31
-
32
- def build_vocab(self, dataset_path, min_freq=1):
33
- token_counter = Counter()
34
-
35
- with open(dataset_path, 'r') as f:
36
- data = json.load(f)
37
- for entry in data:
38
- caption = entry['caption']
39
- tokens = self.tokenize(caption)
40
- token_counter.update(tokens)
41
-
42
- # Keep tokens that meet the min frequency
43
- tokens = [tok for tok, count in token_counter.items() if count >= min_freq]
44
-
45
- # Ensure special tokens are always included
46
- all_tokens = self.special_tokens + sorted(tokens)
47
-
48
- # Build vocab dictionaries
49
- self.vocab = {tok: idx for idx, tok in enumerate(all_tokens)}
50
- self.token_to_id = self.vocab
51
- self.id_to_token = {idx: tok for tok, idx in self.vocab.items()}
52
-
53
- print(f"Vocabulary size: {len(self.vocab)}")
54
-
55
- def encode(self, text):
56
- tokens = self.tokenize(text)
57
- encoded = []
58
- for tok in tokens:
59
- if tok not in self.token_to_id:
60
- raise ValueError(f"Unknown token encountered: {tok} in {text}")
61
- encoded.append(self.token_to_id[tok])
62
- return encoded
63
-
64
- def encode_batch(self, texts, pad_to_length=None):
65
- """
66
- Encode a batch of texts into token IDs with padding to ensure uniform length.
67
-
68
- Args:
69
- texts (list): A list of strings to encode
70
- pad_to_length (int, optional): Length to pad all sequences to. If None,
71
- will pad to the length of the longest sequence.
72
-
73
- Returns:
74
- list: A list of lists, where each inner list contains the token IDs for a text
75
- """
76
- # Get the padding token ID
77
- pad_token = self.token_to_id["[PAD]"]
78
-
79
- # First encode all texts
80
- encoded_texts = []
81
- for text in texts:
82
- try:
83
- encoded = self.encode(text)
84
- encoded_texts.append(encoded)
85
- except ValueError as e:
86
- raise ValueError(f"Error encoding text: {text}. {str(e)}")
87
-
88
- # Determine padding length
89
- if pad_to_length is None:
90
- pad_to_length = max(len(seq) for seq in encoded_texts)
91
-
92
- # Pad sequences to uniform length
93
- padded_texts = []
94
- for seq in encoded_texts:
95
- if len(seq) > pad_to_length:
96
- # Truncate if too long
97
- padded_texts.append(seq[:pad_to_length])
98
- else:
99
- # Pad if too short
100
- padding = [pad_token] * (pad_to_length - len(seq))
101
- padded_texts.append(seq + padding)
102
-
103
- return padded_texts
104
-
105
- def decode(self, token_ids):
106
- return ' '.join(self.id_to_token[tok_id] for tok_id in token_ids)
107
-
108
- def save(self, path):
109
- with open(path, 'wb') as f:
110
- pickle.dump({'vocab': self.vocab}, f)
111
-
112
- def load(self, path):
113
- with open(path, 'rb') as f:
114
- data = pickle.load(f)
115
- self.vocab = data['vocab']
116
- self.token_to_id = self.vocab
117
- self.id_to_token = {idx: tok for tok, idx in self.vocab.items()}
118
-
119
- def get_vocab(self):
120
- return sorted(self.vocab.keys())
121
-
122
- def get_vocab_size(self):
123
- return len(self.vocab)
124
-
125
- if __name__ == "__main__":
126
- tokenizer = Tokenizer()
127
-
128
- parser = argparse.ArgumentParser(description="Tokenizer utility for saving and loading vocabularies.")
129
- parser.add_argument("action", choices=["save", "load"], help="Action to perform: 'save' or 'load'.")
130
- parser.add_argument("--json_file", type=str, default='Mario_LevelsAndCaptions.json', help="Path to the JSON file containing the dataset (required for 'save').")
131
- parser.add_argument("--pkl_file", type=str, default='Mario_Tokenizer.pkl', help="Path to the pickle file to save/load the tokenizer.")
132
-
133
- args = parser.parse_args()
134
-
135
- if args.action == "save":
136
- if not args.json_file:
137
- raise ValueError("The --json_file argument is required for the 'save' action.")
138
- tokenizer.build_vocab(args.json_file)
139
- tokenizer.save(args.pkl_file)
140
- elif args.action == "load":
141
- tokenizer.load(args.pkl_file)
142
-
143
- # Example usage
144
- #print(tokenizer.encode("floor with one gap. one enemy."))
145
- #print(tokenizer.get_vocab())
146
- #for id, token in tokenizer.id_to_token.items():
147
- # print(id,":",token)