| | import json
|
| | from typing import List, Dict, Optional, Tuple, Union
|
| | from pathlib import Path
|
| | import regex as re
|
| |
|
class OpenPeerTokenizer:
    """Simple whitespace/character tokenizer implementation for testing.

    Known whole words encode to a single id; out-of-vocabulary words fall
    back to per-character ids, with the unknown-token id as a last resort.
    """

    def __init__(self, unk_token="<|endoftext|>",
                 bos_token="<|endoftext|>",
                 eos_token="<|endoftext|>",
                 pad_token="<|endoftext|>"):
        self.unk_token = unk_token
        self.bos_token = bos_token
        self.eos_token = eos_token
        self.pad_token = pad_token

        self.vocab = self._get_default_vocab()
        self.vocab_size = len(self.vocab)
        # Derive the eos id from the vocabulary instead of hard-coding 0,
        # so it stays correct for custom special tokens (falls back to 0
        # only if the eos token is somehow absent).
        self.eos_token_id = self.vocab.get(self.eos_token, 0)

    def _get_default_vocab(self) -> Dict[str, int]:
        """Build the default vocabulary.

        Contents, in id order: special tokens, printable ASCII characters,
        then a few common English words. Ids are assigned sequentially and
        only to strings not already present — the default special tokens
        are all the same string ("<|endoftext|>"), and the original code's
        fixed ids made duplicate strings overwrite each other, leaving the
        space character colliding with "<|mask|>" at id 2.

        Returns:
            Mapping from token string to a unique integer id.
        """
        vocab: Dict[str, int] = {}

        # Special tokens first; skip duplicates so each distinct string
        # gets exactly one id.
        for token in (self.unk_token, self.pad_token, "<|mask|>"):
            if token not in vocab:
                vocab[token] = len(vocab)

        # Printable ASCII characters (space through '~').
        for code in range(32, 127):
            char = chr(code)
            if char not in vocab:
                vocab[char] = len(vocab)

        # A handful of common whole words.
        for word in ("the", "be", "to", "of", "and", "a", "in", "that", "have"):
            if word not in vocab:
                vocab[word] = len(vocab)

        return vocab

    def _encode_text(self, text: str) -> List[int]:
        """Encode one string: whole-word lookup, then per-character fallback."""
        ids: List[int] = []
        unk_id = self.vocab[self.unk_token]
        for word in text.split():
            if word in self.vocab:
                ids.append(self.vocab[word])
            else:
                # Unknown word: encode character by character; unknown
                # characters map to the unk-token id.
                ids.extend(self.vocab.get(char, unk_id) for char in word)
        return ids

    def __call__(self, text: Union[str, List[str]], **kwargs) -> Dict[str, list]:
        """Tokenize text.

        Args:
            text: A single string or a batch (list) of strings.

        Returns:
            {"input_ids": ..., "attention_mask": ...} where both values are
            flat lists for a single string and nested lists for a batch;
            the attention mask is all ones, shaped like input_ids.
        """
        if isinstance(text, str):
            input_ids = self._encode_text(text)
            return {"input_ids": input_ids,
                    "attention_mask": [1] * len(input_ids)}

        batch_ids = [self._encode_text(item) for item in text]
        return {"input_ids": batch_ids,
                "attention_mask": [[1] * len(ids) for ids in batch_ids]}

    def _decode_ids(self, ids: List[int], id_to_token: Dict[int, str],
                    skip_special_tokens: bool) -> str:
        """Decode one flat id sequence to a space-joined string."""
        special = {self.unk_token, self.pad_token, "<|mask|>"}
        pieces = []
        for token_id in ids:
            token = id_to_token.get(token_id, self.unk_token)
            if not skip_special_tokens or token not in special:
                pieces.append(token)
        return " ".join(pieces)

    def decode(self, token_ids: Union[List[int], List[List[int]]],
               skip_special_tokens: bool = True) -> Union[str, List[str]]:
        """Decode token ids back to text.

        A flat list yields a single string; a list of lists yields one
        string per sequence (the original annotation claimed `str` for
        both cases). An empty input decodes to "" instead of raising
        IndexError on token_ids[0].

        Args:
            token_ids: Flat id list or batch of id lists.
            skip_special_tokens: Drop unk/pad/mask tokens from the output.
        """
        id_to_token = {v: k for k, v in self.vocab.items()}
        # Guard the empty case before peeking at token_ids[0].
        if token_ids and isinstance(token_ids[0], list):
            return [self._decode_ids(ids, id_to_token, skip_special_tokens)
                    for ids in token_ids]
        return self._decode_ids(token_ids, id_to_token, skip_special_tokens)