import argparse
import json
import pickle
import re
from collections import Counter
|
|
class Tokenizer:
    """Simple word-level tokenizer with a vocabulary built from caption data."""

    def __init__(self):
        self.special_tokens = ["[PAD]", "[MASK]"]
        self.vocab = {}
        self.token_to_id = {}
        self.id_to_token = {}
|
|
    def tokenize(self, text):
        """Lowercase the text and split it into words, punctuation, and special tokens."""
        tokens = re.findall(r'\[mask\]|\[pad\]|\w+|[.,]', text.lower())

        # Restore the canonical casing of the special tokens.
        return [tok.replace("[mask]", "[MASK]").replace("[pad]", "[PAD]") for tok in tokens]
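
    # Illustrative behavior of tokenize(); the caption text is a hypothetical example:
    #   tokenize("Two pipes, one enemy. [mask]")
    #   -> ['two', 'pipes', ',', 'one', 'enemy', '.', '[MASK]']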
|
|
    def pad_sequence(self, tokens, length):
        """Pad an encoded (token-ID) sequence to `length` with the [PAD] token ID."""
        if len(tokens) > length:
            raise ValueError(f"Token sequence length {len(tokens)} exceeds specified length {length}.")

        pad_token = self.token_to_id["[PAD]"]
        return tokens + [pad_token] * (length - len(tokens))
|
|
    def build_vocab(self, dataset_path, min_freq=1):
        """Build the vocabulary from the captions in a JSON dataset."""
        token_counter = Counter()

        with open(dataset_path, 'r') as f:
            data = json.load(f)
            for entry in data:
                caption = entry['caption']
                tokens = self.tokenize(caption)
                token_counter.update(tokens)

        # Keep tokens that appear at least min_freq times. Special tokens are
        # excluded here because they are always placed at the start of the vocabulary.
        tokens = [tok for tok, count in token_counter.items()
                  if count >= min_freq and tok not in self.special_tokens]

        all_tokens = self.special_tokens + sorted(tokens)

        self.vocab = {tok: idx for idx, tok in enumerate(all_tokens)}
        self.token_to_id = self.vocab
        self.id_to_token = {idx: tok for tok, idx in self.vocab.items()}

        print(f"Vocabulary size: {len(self.vocab)}")
|
|
    def encode(self, text):
        """Convert a text string into a list of token IDs, raising on unknown tokens."""
        tokens = self.tokenize(text)
        encoded = []
        for tok in tokens:
            if tok not in self.token_to_id:
                raise ValueError(f"Unknown token encountered: {tok} in {text}")
            encoded.append(self.token_to_id[tok])
        return encoded
|
|
    def encode_batch(self, texts, pad_to_length=None):
        """
        Encode a batch of texts into token IDs, padded to a uniform length.

        Args:
            texts (list): A list of strings to encode.
            pad_to_length (int, optional): Length to pad all sequences to. If None,
                sequences are padded to the length of the longest sequence.

        Returns:
            list: A list of lists, where each inner list contains the token IDs for
                one text. Sequences longer than pad_to_length are truncated.
        """
        pad_token = self.token_to_id["[PAD]"]

        encoded_texts = []
        for text in texts:
            try:
                encoded = self.encode(text)
                encoded_texts.append(encoded)
            except ValueError as e:
                raise ValueError(f"Error encoding text: {text}. {str(e)}")

        # Default to the length of the longest sequence in the batch.
        if pad_to_length is None:
            pad_to_length = max(len(seq) for seq in encoded_texts)

        padded_texts = []
        for seq in encoded_texts:
            if len(seq) > pad_to_length:
                # Truncate sequences that exceed the target length.
                padded_texts.append(seq[:pad_to_length])
            else:
                # Pad shorter sequences with the [PAD] token ID.
                padding = [pad_token] * (pad_to_length - len(seq))
                padded_texts.append(seq + padding)

        return padded_texts
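
    # Illustrative call; the word IDs shown are placeholders since they depend on
    # the built vocabulary, while 0 is the [PAD] ID because "[PAD]" is the first
    # special token:
    #   encode_batch(["two pipes .", "one enemy"], pad_to_length=4)
    #   -> [[id_two, id_pipes, id_period, 0], [id_one, id_enemy, 0, 0]]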
|
|
    def decode(self, token_ids):
        """Convert a list of token IDs back into a space-separated string."""
        return ' '.join(self.id_to_token[tok_id] for tok_id in token_ids)
|
|
    def save(self, path):
        """Serialize the vocabulary to a pickle file."""
        with open(path, 'wb') as f:
            pickle.dump({'vocab': self.vocab}, f)


    def load(self, path):
        """Load a vocabulary from a pickle file and rebuild the lookup tables."""
        with open(path, 'rb') as f:
            data = pickle.load(f)
        self.vocab = data['vocab']
        self.token_to_id = self.vocab
        self.id_to_token = {idx: tok for tok, idx in self.vocab.items()}
|
|
    def get_vocab(self):
        return sorted(self.vocab.keys())


    def get_vocab_size(self):
        return len(self.vocab)
|
|
if __name__ == "__main__":
    tokenizer = Tokenizer()

    parser = argparse.ArgumentParser(description="Tokenizer utility for saving and loading vocabularies.")
    parser.add_argument("action", choices=["save", "load"], help="Action to perform: 'save' or 'load'.")
    parser.add_argument("--json_file", type=str, default='Mario_LevelsAndCaptions.json', help="Path to the JSON file containing the dataset (used by the 'save' action).")
    parser.add_argument("--pkl_file", type=str, default='Mario_Tokenizer.pkl', help="Path to the pickle file to save/load the tokenizer.")

    args = parser.parse_args()

    if args.action == "save":
        if not args.json_file:
            raise ValueError("The --json_file argument is required for the 'save' action.")
        tokenizer.build_vocab(args.json_file)
        tokenizer.save(args.pkl_file)
    elif args.action == "load":
        tokenizer.load(args.pkl_file)
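

# End-to-end usage sketch. The command assumes this module is saved as
# tokenizer.py (hypothetical file name); the data and pickle paths are the
# argparse defaults above, and the caption text is an invented example:
#   python tokenizer.py save --json_file Mario_LevelsAndCaptions.json
#
#   tokenizer = Tokenizer()
#   tokenizer.load('Mario_Tokenizer.pkl')
#   ids = tokenizer.encode("two pipes, one enemy.")
#   print(tokenizer.decode(ids))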
|
|
|
|