import argparse
import json
import pickle
import re
from collections import Counter


class Tokenizer:

    def __init__(self):
        self.special_tokens = ["[PAD]", "[MASK]"]
        self.vocab = {}
        self.token_to_id = {}
        self.id_to_token = {}

    def tokenize(self, text):
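        """Lower-case the text and split it into word, '.'/',', and special-token pieces.

        Illustrative example (the input string is made up, not taken from the dataset):
        tokenize("two pipes. [MASK]") returns ['two', 'pipes', '.', '[MASK]'].
        """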
        # Match words, '.' or ',', and the lower-cased special tokens.
        tokens = re.findall(r'\w+|[.,]|\[mask\]|\[pad\]', text.lower())

        # Restore the canonical upper-case form of the special tokens.
        modified_list = []
        for s in tokens:
            modified_s = s.replace("[mask]", "[MASK]").replace("[pad]", "[PAD]")
            modified_list.append(modified_s)
        return modified_list

    def pad_sequence(self, tokens, length):
        """Pad a sequence of token IDs to the given length with the '[PAD]' token ID."""
        if len(tokens) > length:
            raise ValueError(f"Token sequence length {len(tokens)} exceeds specified length {length}.")

        pad_token = self.token_to_id["[PAD]"]
        return tokens + [pad_token] * (length - len(tokens))

    def build_vocab(self, dataset_path, min_freq=1):
        token_counter = Counter()

        with open(dataset_path, 'r') as f:
            data = json.load(f)
            for entry in data:
                caption = entry['caption']
                tokens = self.tokenize(caption)
                token_counter.update(tokens)

        # Keep only tokens that occur at least min_freq times.
        tokens = [tok for tok, count in token_counter.items() if count >= min_freq]

        # Special tokens come first so their IDs stay fixed regardless of the dataset.
        all_tokens = self.special_tokens + sorted(tokens)

        self.vocab = {tok: idx for idx, tok in enumerate(all_tokens)}
        self.token_to_id = self.vocab
        self.id_to_token = {idx: tok for tok, idx in self.vocab.items()}

        print(f"Vocabulary size: {len(self.vocab)}")

    def encode(self, text):
        tokens = self.tokenize(text)
        encoded = []
        for tok in tokens:
            if tok not in self.token_to_id:
                raise ValueError(f"Unknown token encountered: {tok} in {text}")
            encoded.append(self.token_to_id[tok])
        return encoded

    def encode_batch(self, texts, pad_to_length=None):
        """
        Encode a batch of texts into token IDs with padding to ensure uniform length.

        Args:
            texts (list): A list of strings to encode.
            pad_to_length (int, optional): Length to pad all sequences to. If None,
                will pad to the length of the longest sequence.

        Returns:
            list: A list of lists, where each inner list contains the token IDs for a text.
        """
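        # For instance (illustrative captions; assumes every word below is in the
        # vocabulary): encode_batch(["one pipe.", "many pipes. one enemy."]) returns
        # two ID lists of the same length, the shorter caption padded with the [PAD] ID.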
        pad_token = self.token_to_id["[PAD]"]

        encoded_texts = []
        for text in texts:
            try:
                encoded = self.encode(text)
                encoded_texts.append(encoded)
            except ValueError as e:
                raise ValueError(f"Error encoding text: {text}. {str(e)}")

        if pad_to_length is None:
            pad_to_length = max(len(seq) for seq in encoded_texts)

        padded_texts = []
        for seq in encoded_texts:
            if len(seq) > pad_to_length:
                # Truncate sequences that are longer than the requested length.
                padded_texts.append(seq[:pad_to_length])
            else:
                # Pad shorter sequences with the '[PAD]' token ID.
                padding = [pad_token] * (pad_to_length - len(seq))
                padded_texts.append(seq + padding)

        return padded_texts

    def decode(self, token_ids):
        return ' '.join(self.id_to_token[tok_id] for tok_id in token_ids)

    def save(self, path):
        with open(path, 'wb') as f:
            pickle.dump({'vocab': self.vocab}, f)

    def load(self, path):
        with open(path, 'rb') as f:
            data = pickle.load(f)
        self.vocab = data['vocab']
        self.token_to_id = self.vocab
        self.id_to_token = {idx: tok for tok, idx in self.vocab.items()}

    def get_vocab(self):
        return sorted(self.vocab.keys())

    def get_vocab_size(self):
        return len(self.vocab)

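
# A minimal usage sketch of the Tokenizer class (illustrative: it assumes a JSON
# caption file in the Mario_LevelsAndCaptions.json format and that the example
# captions below only contain words present in that file):
#
#     tokenizer = Tokenizer()
#     tokenizer.build_vocab("Mario_LevelsAndCaptions.json")
#     ids = tokenizer.encode_batch(["many pipes.", "no enemies."], pad_to_length=8)
#     print(tokenizer.decode(ids[0]))
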
if __name__ == "__main__":
    tokenizer = Tokenizer()

    parser = argparse.ArgumentParser(description="Tokenizer utility for saving and loading vocabularies.")
    parser.add_argument("action", choices=["save", "load"], help="Action to perform: 'save' or 'load'.")
    parser.add_argument("--json_file", type=str, default='Mario_LevelsAndCaptions.json', help="Path to the JSON file containing the dataset (required for 'save').")
    parser.add_argument("--pkl_file", type=str, default='Mario_Tokenizer.pkl', help="Path to the pickle file to save/load the tokenizer.")

    args = parser.parse_args()

    if args.action == "save":
        if not args.json_file:
            raise ValueError("The --json_file argument is required for the 'save' action.")
        tokenizer.build_vocab(args.json_file)
        tokenizer.save(args.pkl_file)
    elif args.action == "load":
        tokenizer.load(args.pkl_file)
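        # Quick sanity check: report how many tokens the reloaded vocabulary contains,
        # using only the class's own accessor.
        print(f"Loaded vocabulary with {tokenizer.get_vocab_size()} tokens.")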