File size: 2,138 Bytes
d541e5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
from typing import List, Tuple

import numpy as np
import torch
from torchtext.vocab import Vocab
from torch import nn, Tensor

from src.util import device


class Tokenizer(nn.Module):
    """Whitespace tokenizer that maps sentences to padded index tensors.

    Sentences are split on whitespace, each token is looked up in the
    vocabulary, and the resulting variable-length index lists are packed
    into a pair of padded tensors (inputs and shifted targets).
    """

    def __init__(self, vocab: str | Vocab):
        """
        Parameters
        ----------
        vocab: Either a path to a vocabulary saved with ``torch.save`` or an
            already constructed vocabulary object. The vocabulary must
            contain the special tokens '<EDGE>', '<PAD>' and '<UNK>'.

        Raises
        ------
        FileNotFoundError
            If ``vocab`` is a path that does not exist.
        """
        super().__init__()

        if isinstance(vocab, str):
            # Explicit check instead of ``assert`` so it survives ``python -O``.
            if not os.path.exists(vocab):
                raise FileNotFoundError(f"vocab file not found: {vocab}")
            # NOTE(review): torch.load unpickles the file; only load
            # vocabulary files from trusted sources.
            self.vocab = torch.load(vocab, map_location=device)
        else:
            self.vocab = vocab

        # Look the special tokens up in the *loaded* vocabulary; indexing the
        # ``vocab`` argument would index the path string when a path is given.
        self.edge_index = self.vocab['<EDGE>']
        self.pad_index = self.vocab['<PAD>']
        self.unk_index = self.vocab['<UNK>']

    def get_tensors(self, data: List[List[int]]) -> Tuple[Tensor, Tensor]:
        """
        Builds torch.Tensor from a variable length 2D python list. The return
        value is a tuple of two tensors, one for input and the other for
        output.

        Each input row is the sentence prefixed with the edge token; each
        output row is the sentence followed by the edge token. Both are
        right-padded with the pad index to ``longest sentence + 1`` columns.

        Parameters
        ----------
        data: Nested list of token indices
            [[1,2,3],
             [4,2,3,4,2],
             [223,4,2]]
            This example has three sentences.

        Returns
        -------
        (X, Y): two int64 tensors of shape (len(data), max_len + 1) placed on
            ``device``. An empty ``data`` yields tensors of shape (0, 1).
        """
        # ``default=0`` keeps an empty batch from raising in ``max``.
        max_len = max((len(datum) for datum in data), default=0) + 1
        N = len(data)
        X = np.full((N, max_len), self.pad_index, np.int64)
        Y = np.full((N, max_len), self.pad_index, np.int64)

        for i, datum in enumerate(data):
            length = len(datum)

            # prepend the inputs with edge token
            X[i, 0] = self.edge_index
            if length:
                X[i, 1:length + 1] = datum
                Y[i, :length] = datum

            # finish the outputs with edge token, placed right AFTER the last
            # real token so that it does not overwrite it
            Y[i, length] = self.edge_index

        return torch.tensor(X, device=device), torch.tensor(Y, device=device)

    def forward(self, text: List[str]) -> Tuple[Tensor, Tensor]:
        """
        Tokenizes a list of natural text. The return value is the pair of
        tensors produced by ``get_tensors``.

        Parameters
        ----------
        text: List[str]. A list of natural language strings. Each string is
            split on whitespace.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]. Input token ids and output token
        ids, see ``get_tensors``.
        """
        token_lists = [sentence.split() for sentence in text]
        tokenized = [self.vocab(tokens) for tokens in token_lists]
        return self.get_tensors(tokenized)