soiz1 committed
Commit 256f6fb · verified · 1 Parent(s): 9f9eb74

Create collation.py

Files changed (1)
data/collation.py +118 -0
data/collation.py ADDED
from typing import List, Tuple

import numpy as np
import torch


class TextTokenCollater:
    """Collate lists of text tokens.

    Maps sentences to integers. Sentences are padded to equal length, and
    beginning- and end-of-sequence symbols can be added.

    Example:
        >>> token_collater = TextTokenCollater(text_tokens)
        >>> tokens_batch, tokens_lens = token_collater(text)

    Returns:
        tokens_batch: IntTensor of shape (B, L)
            B: batch dimension, number of input sentences
            L: length of the longest sentence
        tokens_lens: IntTensor of shape (B,)
            Length of each sentence after adding <eos> and <bos>
            but before padding.
    """

    def __init__(
        self,
        text_tokens: List[str],
        add_eos: bool = True,
        add_bos: bool = True,
        pad_symbol: str = "<pad>",
        bos_symbol: str = "<bos>",
        eos_symbol: str = "<eos>",
    ):
        self.pad_symbol = pad_symbol

        self.add_eos = add_eos
        self.add_bos = add_bos

        self.bos_symbol = bos_symbol
        self.eos_symbol = eos_symbol

        # The pad symbol comes first so that padding always maps to index 0;
        # bos/eos follow, then the vocabulary in sorted order.
        unique_tokens = (
            [pad_symbol]
            + ([bos_symbol] if add_bos else [])
            + ([eos_symbol] if add_eos else [])
            + sorted(text_tokens)
        )

        self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)}
        self.idx2token = list(unique_tokens)

    def index(
        self, tokens_list: List[str]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        seqs, seq_lens = [], []
        for tokens in tokens_list:
            assert all(
                s in self.token2idx for s in tokens
            ), f"Unknown token in input: {tokens}"
            seq = (
                ([self.bos_symbol] if self.add_bos else [])
                + list(tokens)
                + ([self.eos_symbol] if self.add_eos else [])
            )
            seqs.append(seq)
            seq_lens.append(len(seq))

        # Right-pad every sequence to the length of the longest one.
        max_len = max(seq_lens)
        for seq, seq_len in zip(seqs, seq_lens):
            seq.extend([self.pad_symbol] * (max_len - seq_len))

        tokens = torch.from_numpy(
            np.array(
                [[self.token2idx[token] for token in seq] for seq in seqs],
                dtype=np.int64,
            )
        )
        tokens_lens = torch.IntTensor(seq_lens)

        return tokens, tokens_lens

    def __call__(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
        # Split each text into per-character tokens.
        tokens_seqs = [list(text) for text in texts]
        max_len = len(max(tokens_seqs, key=len))

        seqs = [
            ([self.bos_symbol] if self.add_bos else [])
            + list(seq)
            + ([self.eos_symbol] if self.add_eos else [])
            + [self.pad_symbol] * (max_len - len(seq))
            for seq in tokens_seqs
        ]

        # Map symbols to integer ids before building the array; feeding the
        # raw symbol strings to np.array(..., dtype=np.int64) would fail.
        tokens_batch = torch.from_numpy(
            np.array(
                [[self.token2idx[token] for token in seq] for seq in seqs],
                dtype=np.int64,
            )
        )

        tokens_lens = torch.IntTensor(
            [
                len(seq) + int(self.add_eos) + int(self.add_bos)
                for seq in tokens_seqs
            ]
        )

        return tokens_batch, tokens_lens


def get_text_token_collater() -> TextTokenCollater:
    # Single-symbol vocabulary with no bos/eos: sequences are only indexed
    # and padded.
    collater = TextTokenCollater(
        ["0"], add_bos=False, add_eos=False
    )
    return collater
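
A minimal usage sketch, not part of the commit: the character vocabulary and the sample strings below are illustrative assumptions, chosen so every character is present in the vocabulary.

# Usage sketch; vocab and sample texts are illustrative assumptions.
vocab = sorted(set("hello world"))   # [' ', 'd', 'e', 'h', 'l', 'o', 'r', 'w']
collater = TextTokenCollater(vocab)  # add_bos/add_eos default to True

tokens_batch, tokens_lens = collater(["hello", "old world"])
print(tokens_batch.shape)  # torch.Size([2, 11]): 9 chars + <bos> + <eos>
print(tokens_lens)         # tensor([ 7, 11], dtype=torch.int32)

Note that <pad> is placed first in the vocabulary, so padded positions are always id 0, which makes masking them out downstream straightforward.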