import json
import re
from collections import Counter
import pickle
import argparse
class Tokenizer:
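    """A simple word-level tokenizer for level captions, with [PAD] and [MASK] special tokens."""
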
    def __init__(self):
        self.special_tokens = ["[PAD]", "[MASK]"]
        self.vocab = {}
        self.token_to_id = {}
        self.id_to_token = {}

    def tokenize(self, text):
        """Lowercases the text and splits it into word, punctuation, and special-token strings."""
        # Match words, numbers, periods, and commas as separate tokens,
        # along with the lowercased special tokens [mask] and [pad]
        tokens = re.findall(r'\w+|[.,]|\[mask\]|\[pad\]', text.lower())
        # Restore [MASK] and [PAD] to upper case
        modified_list = []
        for s in tokens:
            modified_s = s.replace("[mask]", "[MASK]").replace("[pad]", "[PAD]")
            modified_list.append(modified_s)
        return modified_list

    def pad_sequence(self, tokens, length):
        """Pads an encoded token ID sequence to the given length with the [PAD] token ID."""
        if len(tokens) > length:
            raise ValueError(f"Token sequence length {len(tokens)} exceeds specified length {length}.")
        pad_token = self.token_to_id["[PAD]"]
        return tokens + [pad_token] * (length - len(tokens))

    def build_vocab(self, dataset_path, min_freq=1):
        """Builds the vocabulary from the captions in a JSON dataset file."""
        token_counter = Counter()
        with open(dataset_path, 'r') as f:
            data = json.load(f)
        for entry in data:
            caption = entry['caption']
            tokens = self.tokenize(caption)
            token_counter.update(tokens)
        # Keep tokens that meet the min frequency, excluding special tokens
        # so they are not added twice below
        tokens = [tok for tok, count in token_counter.items()
                  if count >= min_freq and tok not in self.special_tokens]
        # Ensure special tokens are always included and come first
        all_tokens = self.special_tokens + sorted(tokens)
        # Build vocab dictionaries
        self.vocab = {tok: idx for idx, tok in enumerate(all_tokens)}
        self.token_to_id = self.vocab
        self.id_to_token = {idx: tok for tok, idx in self.vocab.items()}
        print(f"Vocabulary size: {len(self.vocab)}")

    def encode(self, text):
        """Tokenizes text and converts each token to its ID, raising on unknown tokens."""
        tokens = self.tokenize(text)
        encoded = []
        for tok in tokens:
            if tok not in self.token_to_id:
                raise ValueError(f"Unknown token encountered: {tok} in {text}")
            encoded.append(self.token_to_id[tok])
        return encoded

    def encode_batch(self, texts, pad_to_length=None):
        """
        Encode a batch of texts into token IDs with padding to ensure uniform length.

        Args:
            texts (list): A list of strings to encode
            pad_to_length (int, optional): Length to pad all sequences to. If None,
                will pad to the length of the longest sequence.

        Returns:
            list: A list of lists, where each inner list contains the token IDs for a text
        """
        # Get the padding token ID
        pad_token = self.token_to_id["[PAD]"]
        # First encode all texts
        encoded_texts = []
        for text in texts:
            try:
                encoded = self.encode(text)
                encoded_texts.append(encoded)
            except ValueError as e:
                raise ValueError(f"Error encoding text: {text}. {str(e)}")
        # Determine padding length
        if pad_to_length is None:
            pad_to_length = max(len(seq) for seq in encoded_texts)
        # Pad sequences to uniform length
        padded_texts = []
        for seq in encoded_texts:
            if len(seq) > pad_to_length:
                # Truncate if too long
                padded_texts.append(seq[:pad_to_length])
            else:
                # Pad if too short
                padding = [pad_token] * (pad_to_length - len(seq))
                padded_texts.append(seq + padding)
        return padded_texts

    def decode(self, token_ids):
        """Converts a list of token IDs back into a space-separated string of tokens."""
        return ' '.join(self.id_to_token[tok_id] for tok_id in token_ids)

    def save(self, path):
        """Pickles the vocabulary to the given path."""
        with open(path, 'wb') as f:
            pickle.dump({'vocab': self.vocab}, f)

    def load(self, path):
        """Loads a pickled vocabulary and rebuilds the lookup tables."""
        with open(path, 'rb') as f:
            data = pickle.load(f)
        self.vocab = data['vocab']
        self.token_to_id = self.vocab
        self.id_to_token = {idx: tok for tok, idx in self.vocab.items()}

    def get_vocab(self):
        """Returns the vocabulary tokens in sorted order."""
        return sorted(self.vocab.keys())

    def get_vocab_size(self):
        """Returns the number of tokens in the vocabulary."""
        return len(self.vocab)

if __name__ == "__main__":
    tokenizer = Tokenizer()
    parser = argparse.ArgumentParser(description="Tokenizer utility for saving and loading vocabularies.")
    parser.add_argument("action", choices=["save", "load"], help="Action to perform: 'save' or 'load'.")
    parser.add_argument("--json_file", type=str, default='Mario_LevelsAndCaptions.json', help="Path to the JSON file containing the dataset (required for 'save').")
    parser.add_argument("--pkl_file", type=str, default='Mario_Tokenizer.pkl', help="Path to the pickle file to save/load the tokenizer.")
    args = parser.parse_args()

    if args.action == "save":
        if not args.json_file:
            raise ValueError("The --json_file argument is required for the 'save' action.")
        tokenizer.build_vocab(args.json_file)
        tokenizer.save(args.pkl_file)
    elif args.action == "load":
        tokenizer.load(args.pkl_file)
        # Example usage
        #print(tokenizer.encode("floor with one gap. one enemy."))
        #print(tokenizer.get_vocab())
        #for id, token in tokenizer.id_to_token.items():
        #    print(id, ":", token)
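        # Minimal batch round-trip sketch (assumes the loaded vocabulary was built from
        # Mario_LevelsAndCaptions.json, so the example words below are only assumed to be present):
        #ids = tokenizer.encode_batch(["floor with one gap.", "one enemy."])
        #print(ids)                      # padded to the length of the longest sequence
        #print(tokenizer.decode(ids[0])) # space-joined tokens, including any trailing [PAD]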