import pickle
from typing import List

import numpy

from nemo.collections.common.tokenizers.column_coder import ColumnCodes
from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec

__all__ = ['TabularTokenizer']

END_OF_TEXT = '<|endoftext|>'
NEW_LINE = '\n'
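

# ``list.index`` raises ValueError when the item is missing; this helper
# mirrors ``str.find`` and returns -1 instead.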
def find_index_of(list_input, item):
    output = -1
    try:
        output = list_input.index(item)
    except ValueError:
        pass
    return output
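

# Tokenizer for tabular (CSV-style) text: each column value is encoded by a
# fitted ColumnCodes coder into a fixed number of ids, fields are separated
# by ``delimiter``, rows by NEW_LINE, and documents by END_OF_TEXT.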
class TabularTokenizer(TokenizerSpec):
    def __init__(self, coder, special_tokens=None, delimiter=','):
        if special_tokens is None:
            special_tokens = [END_OF_TEXT, NEW_LINE]
        if isinstance(coder, ColumnCodes):
            self.code_column: ColumnCodes = coder
        else:
            # `coder` may also be a path to a pickled ColumnCodes object
            with open(coder, 'rb') as handle:
                self.code_column: ColumnCodes = pickle.load(handle)
        self.num_columns = len(self.code_column.columns)
        self.special_tokens = {}
        self.special_tokens_decoder = {}
        self.add_special_tokens(special_tokens)
        self.delimiter = delimiter
        self.eod_id = self.special_tokens[END_OF_TEXT]
        self.eos_id = self.eod_id
        self.bos_id = self.eos_id

    def __len__(self):
        return self.vocab_size

    @property
    def vocab_size(self):
        # Special tokens occupy the highest ids, so the largest special id
        # bounds the vocabulary.
        return max(self.special_tokens_decoder.keys()) + 1

    def text_to_ids(self, text):
        return self.encode(text)

    def ids_to_text(self, token_ids):
        return self.decode(token_ids)

    @property
    def eod(self):
        return self.eod_id

    @property
    def eor(self):
        return self.special_tokens[NEW_LINE]

    def add_special_tokens(self, special_tokens):
        """Add a list of additional tokens to the encoder.

        The additional tokens are indexed starting from the last index of
        the current vocabulary, in the order of the `special_tokens` list.
        """
        if not special_tokens:
            self.special_tokens = {}
            self.special_tokens_decoder = {}
            return
        new = dict(
            (tok, self.code_column.vocab_size + i)
            for i, tok in enumerate(special_tokens)
            if tok not in self.special_tokens
        )
        self.special_tokens.update(new)
        self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}
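
    # With the default specials, END_OF_TEXT is assigned id
    # code_column.vocab_size and NEW_LINE id code_column.vocab_size + 1.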

    def text_to_tokens(self, text):
        """Tokenize a string into column-value tokens and special tokens."""
        tokens = []
        rows = text.split(NEW_LINE)
        num_rows = len(rows)
        for row_id in range(num_rows):
            row = rows[row_id]
            if row == '':
                continue
            fields = row.split(self.delimiter)
            for f in fields:
                # END_OF_TEXT may be glued to a field; split it out as its
                # own token.
                splits = f.split(END_OF_TEXT)
                if len(splits) == 1:
                    tokens.append(f.strip())
                elif len(splits) == 2:
                    if splits[0] != '':
                        tokens.append(splits[0].strip())
                    tokens.append(END_OF_TEXT)
                    if splits[1] != '':
                        tokens.append(splits[1].strip())
                else:
                    raise ValueError(f"unexpected multiple {END_OF_TEXT} in field: {f!r}")
            if row_id != num_rows - 1:
                tokens.append(NEW_LINE)
        return tokens
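
    # For example, text_to_tokens('0.1,foo<|endoftext|>') returns
    # ['0.1', 'foo', '<|endoftext|>'].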

    def tokens_to_ids(self, tokens: List[str]):
        """Converts a sequence of tokens into ids using the vocab."""
        ids = []
        cindex = 0
        # The token stream may start mid-row. The position of the first
        # NEW_LINE reveals how many columns precede it, which fixes the
        # column the first token belongs to.
        if NEW_LINE in tokens:
            idd = tokens.index(NEW_LINE)
            cindex = (self.num_columns - idd) % self.num_columns
        for token in tokens:
            if token in self.special_tokens:
                ids.append(self.special_tokens[token])
            else:
                # Cycle through the columns; each column has its own coder.
                index = cindex % self.num_columns
                column = self.code_column.columns[index]
                ids.extend(self.code_column.encode(column, token))
                cindex += 1
        return ids
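
    # Example: with three columns [a, b, c], tokens ['x', '\n', 'p', 'q', 'r']
    # start mid-row, so the leading 'x' is aligned to the last column c
    # (cindex starts at (3 - 1) % 3 = 2).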

    def ids_to_tokens(self, ids, skip_special_tokens=False):
        """Converts a sequence of ids into tabular tokens using the vocab."""
        tokens = []
        sizes = self.code_column.sizes
        # Total number of ids used to encode one full row.
        ids_size = sum(sizes)
        cindex = 0
        # As in tokens_to_ids, the first end-of-row or end-of-document token
        # fixes the alignment of a stream that starts mid-row.
        eor_pos = find_index_of(ids, self.eor)
        eod_pos = find_index_of(ids, self.eod)
        positions = [pos for pos in (eor_pos, eod_pos) if pos >= 0]
        if positions:
            idd = min(positions)
            cindex = (ids_size - idd) % ids_size
        cum_sizes = numpy.cumsum(sizes)
        old_column_index = -1
        token_ids = []
        for i in ids:
            if i in self.special_tokens_decoder:
                if not skip_special_tokens:
                    tokens.append(self.special_tokens_decoder[i])
            else:
                index = cindex % ids_size
                # Map the position within the row to its column.
                column_index = numpy.where(index < cum_sizes)[0][0]
                column = self.code_column.columns[column_index]
                if old_column_index != column_index:
                    token_ids = [i]
                    old_column_index = column_index
                else:
                    token_ids.append(i)
                # Decode once all ids belonging to this column are collected.
                if len(token_ids) == sizes[column_index]:
                    tokens.append(self.code_column.decode(column, token_ids))
                cindex += 1
        return tokens

    def encode(self, text):
        return self.tokens_to_ids(self.text_to_tokens(text))

    def decode(self, token_ids):
        tokens = self.ids_to_tokens(token_ids, skip_special_tokens=False)
        return self.tokens_to_text(tokens)

    def tokens_to_text(self, tokens):
        all_lines = []
        line = []
        for token in tokens:
            if token == END_OF_TEXT or token == NEW_LINE:
                # Flush the current row, then emit the special token verbatim.
                if len(line) != 0:
                    line_text = self.delimiter.join(line)
                    all_lines.append(line_text)
                all_lines.append(token)
                line = []
            else:
                line.append(token)
        if len(line) != 0:
            line_text = self.delimiter.join(line)
            all_lines.append(line_text)
        text = "".join(all_lines)
        return text
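

# Example usage (a minimal sketch; assumes ``coder.pickle`` is a hypothetical
# path holding a ColumnCodes object fitted on two-column CSV data):
#
#     tokenizer = TabularTokenizer('coder.pickle')
#     ids = tokenizer.text_to_ids('0.1,2.5\n0.3,4.1\n')
#     text = tokenizer.ids_to_text(ids)  # approximates the input up to the
#                                        # coder's quantization of numerics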