from typing import List, Tuple

import numpy as np
import torch


class TextTokenCollater:
    """Collate lists of text tokens.

    Maps sentences to integer sequences and pads them to equal length.
    Beginning- and end-of-sequence symbols can be added.

    Example:
        >>> token_collater = TextTokenCollater(text_tokens)
        >>> tokens_batch, tokens_lens = token_collater(text)

    Returns:
        tokens_batch: IntTensor of shape (B, L)
            B: batch dimension, number of input sentences
            L: length of the longest sentence
        tokens_lens: IntTensor of shape (B,)
            Length of each sentence after adding <bos> and <eos>,
            but before padding.
    """

    def __init__(
        self,
        text_tokens: List[str],
        add_eos: bool = True,
        add_bos: bool = True,
        pad_symbol: str = "<pad>",
        bos_symbol: str = "<bos>",
        eos_symbol: str = "<eos>",
    ):
        self.pad_symbol = pad_symbol

        self.add_eos = add_eos
        self.add_bos = add_bos

        self.bos_symbol = bos_symbol
        self.eos_symbol = eos_symbol

        # <pad> comes first so that it always maps to index 0; <bos>/<eos>
        # precede the sorted text tokens. `text_tokens` is expected to
        # contain no duplicates.
        unique_tokens = (
            [pad_symbol]
            + ([bos_symbol] if add_bos else [])
            + ([eos_symbol] if add_eos else [])
            + sorted(text_tokens)
        )

        self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)}
        self.idx2token = list(unique_tokens)

    def index(
        self, tokens_list: List[str]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        seqs, seq_lens = [], []
        for tokens in tokens_list:
            assert all(
                s in self.token2idx for s in tokens
            ), f"unknown token in {tokens!r}"
            seq = (
                ([self.bos_symbol] if self.add_bos else [])
                + list(tokens)
                + ([self.eos_symbol] if self.add_eos else [])
            )
            seqs.append(seq)
            seq_lens.append(len(seq))

        # Pad every sequence with <pad> up to the longest length.
        max_len = max(seq_lens)
        for seq, seq_len in zip(seqs, seq_lens):
            seq.extend([self.pad_symbol] * (max_len - seq_len))

        tokens = torch.from_numpy(
            np.array(
                [[self.token2idx[token] for token in seq] for seq in seqs],
                dtype=np.int64,
            )
        )
        tokens_lens = torch.IntTensor(seq_lens)

        return tokens, tokens_lens

    def __call__(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
        # Split each text into individual character tokens.
        tokens_seqs = [[p for p in text] for text in texts]
        max_len = len(max(tokens_seqs, key=len))

        seqs = [
            ([self.bos_symbol] if self.add_bos else [])
            + list(seq)
            + ([self.eos_symbol] if self.add_eos else [])
            + [self.pad_symbol] * (max_len - len(seq))
            for seq in tokens_seqs
        ]

        # Map token symbols to their integer ids before building the array.
        tokens_batch = torch.from_numpy(
            np.array(
                [[self.token2idx[token] for token in seq] for seq in seqs],
                dtype=np.int64,
            )
        )

        tokens_lens = torch.IntTensor(
            [
                len(seq) + int(self.add_eos) + int(self.add_bos)
                for seq in tokens_seqs
            ]
        )

        return tokens_batch, tokens_lens


def get_text_token_collater() -> TextTokenCollater:
    # The vocabulary is hard-coded to the single symbol "0", and no
    # <bos>/<eos> symbols are added.
    collater = TextTokenCollater(
        ["0"], add_bos=False, add_eos=False
    )
    return collater
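

# A minimal usage sketch (illustrative only, not part of the module's API):
# build a collater over a small character vocabulary, batch two raw strings
# via __call__, then feed pre-tokenized sequences through index().
if __name__ == "__main__":
    # Character vocabulary covering both example strings below.
    demo_collater = TextTokenCollater(sorted(set("helo wrd")))
    tokens_batch, tokens_lens = demo_collater(["hello", "world"])
    print(tokens_batch.shape)  # torch.Size([2, 7]): 5 chars + <bos> + <eos>
    print(tokens_lens)         # tensor([7, 7], dtype=torch.int32)

    # index() takes already-tokenized sequences instead of raw strings.
    tokens, lens = demo_collater.index([["h", "e"], list("world")])
    print(tokens.shape, lens)  # torch.Size([2, 7]) tensor([4, 7], ...)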