| | import json |
| | import re |
| | from collections import Counter |
| | import pickle |
| | import argparse |
| |
|
class Tokenizer:
    """Word-level tokenizer with [PAD]/[MASK] special tokens for caption datasets.

    The vocabulary is built from the 'caption' field of a JSON list file.
    Special tokens always occupy the first ids: [PAD] -> 0, [MASK] -> 1.
    """

    def __init__(self):
        # [PAD] must stay first: pad_sequence/encode_batch pad with its id.
        self.special_tokens = ["[PAD]", "[MASK]"]
        self.vocab = {}        # token -> id
        self.token_to_id = {}  # alias of self.vocab once built
        self.id_to_token = {}  # id -> token (inverse mapping)

    def tokenize(self, text):
        """Lowercase `text` and split it into word runs, '.'/',' punctuation,
        and bracketed [mask]/[pad] markers (restored to canonical upper case).

        Any other character is silently dropped by the pattern.
        """
        tokens = re.findall(r'\w+|[.,]|\[mask\]|\[pad\]', text.lower())
        # lower() destroyed the casing of the special markers; restore it so
        # they match the entries in self.special_tokens.
        return [t.replace("[mask]", "[MASK]").replace("[pad]", "[PAD]")
                for t in tokens]

    def pad_sequence(self, tokens, length):
        """Pad an encoded (token-id) sequence to `length` with the [PAD] id.

        Raises:
            ValueError: if the sequence is already longer than `length`.
        """
        if len(tokens) > length:
            raise ValueError(f"Token sequence length {len(tokens)} exceeds specified length {length}.")
        pad_token = self.token_to_id["[PAD]"]
        return tokens + [pad_token] * (length - len(tokens))

    def build_vocab(self, dataset_path, min_freq=1):
        """Build the vocabulary from the 'caption' entries of a JSON list file.

        Args:
            dataset_path: path to a JSON file holding a list of objects,
                each with a 'caption' key.
            min_freq: minimum occurrence count for a token to be kept.
        """
        token_counter = Counter()
        with open(dataset_path, 'r') as f:
            data = json.load(f)
            for entry in data:
                token_counter.update(self.tokenize(entry['caption']))

        # Exclude special tokens from the corpus-derived list: they are
        # prepended below, and counting them too would create duplicate
        # entries in all_tokens and leave gaps in the id space (ids with no
        # id_to_token entry, breaking decode()).
        tokens = [tok for tok, count in token_counter.items()
                  if count >= min_freq and tok not in self.special_tokens]

        # Special tokens first (fixed ids), then regular tokens in sorted
        # order for a deterministic vocabulary.
        all_tokens = self.special_tokens + sorted(tokens)
        self.vocab = {tok: idx for idx, tok in enumerate(all_tokens)}
        self.token_to_id = self.vocab
        self.id_to_token = {idx: tok for tok, idx in self.vocab.items()}

        print(f"Vocabulary size: {len(self.vocab)}")

    def encode(self, text):
        """Tokenize `text` and map every token to its vocabulary id.

        Raises:
            ValueError: on any token missing from the vocabulary.
        """
        encoded = []
        for tok in self.tokenize(text):
            if tok not in self.token_to_id:
                raise ValueError(f"Unknown token encountered: {tok} in {text}")
            encoded.append(self.token_to_id[tok])
        return encoded

    def encode_batch(self, texts, pad_to_length=None):
        """
        Encode a batch of texts into token IDs with padding to ensure uniform length.

        Args:
            texts (list): A list of strings to encode
            pad_to_length (int, optional): Length to pad all sequences to. If None,
                will pad to the length of the longest sequence.

        Returns:
            list: A list of lists, where each inner list contains the token IDs for a text

        Raises:
            ValueError: if any text contains a token missing from the vocabulary.
        """
        pad_token = self.token_to_id["[PAD]"]

        encoded_texts = []
        for text in texts:
            try:
                encoded_texts.append(self.encode(text))
            except ValueError as e:
                raise ValueError(f"Error encoding text: {text}. {str(e)}") from e

        if pad_to_length is None:
            # default=0 lets an empty batch return [] instead of raising
            # "max() arg is an empty sequence".
            pad_to_length = max((len(seq) for seq in encoded_texts), default=0)

        padded_texts = []
        for seq in encoded_texts:
            if len(seq) > pad_to_length:
                # Caller asked for a shorter length than this sequence: truncate.
                padded_texts.append(seq[:pad_to_length])
            else:
                padded_texts.append(seq + [pad_token] * (pad_to_length - len(seq)))

        return padded_texts

    def decode(self, token_ids):
        """Map ids back to their tokens and join them with single spaces."""
        return ' '.join(self.id_to_token[tok_id] for tok_id in token_ids)

    def save(self, path):
        """Pickle the vocabulary dict to `path`."""
        with open(path, 'wb') as f:
            pickle.dump({'vocab': self.vocab}, f)

    def load(self, path):
        """Restore the vocabulary from a pickle produced by save().

        SECURITY NOTE: pickle.load can execute arbitrary code; only load
        files from trusted sources.
        """
        with open(path, 'rb') as f:
            data = pickle.load(f)
        self.vocab = data['vocab']
        self.token_to_id = self.vocab
        self.id_to_token = {idx: tok for tok, idx in self.vocab.items()}

    def get_vocab(self):
        """Return all vocabulary tokens, sorted alphabetically."""
        return sorted(self.vocab.keys())

    def get_vocab_size(self):
        """Return the number of entries in the vocabulary."""
        return len(self.vocab)
| |
|
if __name__ == "__main__":
    # Command-line utility: build+save a vocabulary from a JSON dataset, or
    # load a previously pickled one.
    parser = argparse.ArgumentParser(description="Tokenizer utility for saving and loading vocabularies.")
    parser.add_argument("action", choices=["save", "load"], help="Action to perform: 'save' or 'load'.")
    parser.add_argument("--json_file", type=str, default='Mario_LevelsAndCaptions.json', help="Path to the JSON file containing the dataset (required for 'save').")
    parser.add_argument("--pkl_file", type=str, default='Mario_Tokenizer.pkl', help="Path to the pickle file to save/load the tokenizer.")
    args = parser.parse_args()

    tokenizer = Tokenizer()

    if args.action == "load":
        tokenizer.load(args.pkl_file)
    elif args.action == "save":
        # Guard kept from the original: with the default in place this never
        # fires, but it documents the requirement.
        if not args.json_file:
            raise ValueError("The --json_file argument is required for the 'save' action.")
        tokenizer.build_vocab(args.json_file)
        tokenizer.save(args.pkl_file)
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|