|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from collections import defaultdict |
|
|
from typing import List |
|
|
|
|
|
from nemo.collections.common.tokenizers.moses_tokenizers import MosesProcessor |
|
|
from nemo.collections.nlp.data.text_normalization import constants |
|
|
from nemo.collections.nlp.data.text_normalization.utils import normalize_str, read_data_file, remove_puncts |
|
|
from nemo.utils import logging |
|
|
|
|
|
__all__ = ['TextNormalizationTestDataset'] |
|
|
|
|
|
|
|
|
class TextNormalizationTestDataset:
    """
    Creates dataset to use to do end-to-end inference

    Args:
        input_file: path to the raw data file (e.g., train.tsv). For more info about the data format, refer to the `text_normalization doc <https://github.com/NVIDIA/NeMo/blob/main/docs/source/nlp/text_normalization.rst>`.
        mode: should be one of the values ['tn', 'itn', 'joint']. `tn` mode is for TN only. `itn` mode is for ITN only. `joint` is for training a system that can do both TN and ITN at the same time.
        lang: Language of the dataset
    """

    def __init__(self, input_file: str, mode: str, lang: str):
        """Read `input_file` and build per-direction examples with semiotic-span metadata."""
        self.lang = lang
        insts = read_data_file(input_file, lang=lang)
        processor = MosesProcessor(lang_id=lang)

        # Parallel per-example lists; zipped together into self.examples at the end.
        self.directions, self.inputs, self.targets, self.classes, self.nb_spans, self.span_starts, self.span_ends = (
            [],
            [],
            [],
            [],
            [],
            [],
            [],
        )
        for (classes, w_words, s_words) in insts:
            # Each raw instance yields up to two examples, one per direction;
            # the direction irrelevant to the current mode is skipped.
            for direction in constants.INST_DIRECTIONS:
                if direction == constants.INST_BACKWARD:
                    # ITN: spoken-form text is the input, written-form words are the targets.
                    if mode == constants.TN_MODE:
                        continue
                    (
                        processed_w_words,
                        processed_s_words,
                        processed_classes,
                        processed_nb_spans,
                        processed_s_span_starts,
                        processed_s_span_ends,
                    ) = ([], [], [], 0, [], [])
                    s_word_idx = 0
                    for cls, w_word, s_word in zip(classes, w_words, s_words):
                        if s_word == constants.SIL_WORD:
                            # Silence markers contribute nothing to the spoken input.
                            continue
                        elif s_word == constants.SELF_WORD:
                            # SELF: spoken form equals the written form.
                            processed_s_words.append(w_word)
                        else:
                            processed_s_words.append(s_word)

                        # Re-tokenize the just-appended spoken word with Moses so that
                        # span start/end offsets are counted in model tokens, not
                        # whitespace-separated words.
                        s_word_last = processor.tokenize(processed_s_words.pop()).split()
                        processed_s_words.append(" ".join(s_word_last))
                        num_tokens = len(s_word_last)
                        processed_nb_spans += 1
                        processed_classes.append(cls)
                        processed_s_span_starts.append(s_word_idx)
                        s_word_idx += num_tokens
                        processed_s_span_ends.append(s_word_idx)
                        processed_w_words.append(w_word)

                    self.span_starts.append(processed_s_span_starts)
                    self.span_ends.append(processed_s_span_ends)
                    self.classes.append(processed_classes)
                    self.nb_spans.append(processed_nb_spans)
                    input_words = ' '.join(processed_s_words)
                    # Update self.directions, self.inputs, self.targets
                    self.directions.append(direction)
                    self.inputs.append(input_words)
                    # Targets are kept as a list: one written-form string per semiotic span.
                    self.targets.append(processed_w_words)

                elif direction == constants.INST_FORWARD:
                    # TN: written-form text is the input, spoken-form words are the targets.
                    if mode == constants.ITN_MODE:
                        continue
                    (
                        processed_w_words,
                        processed_s_words,
                        processed_classes,
                        processed_nb_spans,
                        w_span_starts,
                        w_span_ends,
                    ) = ([], [], [], 0, [], [])
                    w_word_idx = 0
                    for cls, w_word, s_word in zip(classes, w_words, s_words):
                        # Tokenize the written word so span offsets are in Moses tokens.
                        w_word = processor.tokenize(w_word).split()
                        num_tokens = len(w_word)
                        if s_word in constants.SPECIAL_WORDS:
                            # SIL/SELF: the spoken target is the (tokenized) written form itself.
                            processed_s_words.append(" ".join(w_word))
                        else:
                            processed_s_words.append(s_word)
                        w_span_starts.append(w_word_idx)
                        w_word_idx += num_tokens
                        w_span_ends.append(w_word_idx)
                        processed_nb_spans += 1
                        processed_classes.append(cls)
                        processed_w_words.extend(w_word)

                    self.span_starts.append(w_span_starts)
                    self.span_ends.append(w_span_ends)
                    self.classes.append(processed_classes)
                    self.nb_spans.append(processed_nb_spans)
                    input_words = ' '.join(processed_w_words)
                    # Update self.directions, self.inputs, self.targets
                    self.directions.append(direction)
                    self.inputs.append(input_words)
                    # Targets are kept as a list: one spoken-form string per semiotic span.
                    self.targets.append(processed_s_words)

        self.examples = list(
            zip(
                self.directions,
                self.inputs,
                self.targets,
                self.classes,
                self.nb_spans,
                self.span_starts,
                self.span_ends,
            )
        )

    def __getitem__(self, idx):
        """Return example `idx` as a tuple (direction, input, targets, classes, nb_spans, span_starts, span_ends)."""
        return self.examples[idx]

    def __len__(self):
        """Return the number of examples in the dataset."""
        return len(self.inputs)

    @staticmethod
    def is_same(pred: str, target: str, inst_dir: str):
        """
        Function for checking whether the predicted string can be considered
        the same as the target string

        Args:
            pred: Predicted string
            target: Target string
            inst_dir: Direction of the instance (i.e., INST_BACKWARD or INST_FORWARD).
        Return: an int value (0/1) indicating whether pred and target are the same.
        """
        if inst_dir == constants.INST_BACKWARD:
            # Punctuation is not scored for ITN outputs.
            pred = remove_puncts(pred)
            target = remove_puncts(target)
        pred = normalize_str(pred)
        target = normalize_str(target)
        return int(pred == target)

    @staticmethod
    def compute_sent_accuracy(preds: List[str], targets: List[str], inst_directions: List[str]):
        """
        Compute the sentence accuracy metric.

        Args:
            preds: List of predicted strings.
            targets: List of target strings.
            inst_directions: A list of str where each str indicates the direction of the corresponding instance (i.e., INST_BACKWARD or INST_FORWARD).
        Return: the sentence accuracy score (or the string 'NA' when `targets` is empty).
        """
        assert len(preds) == len(targets)
        if len(targets) == 0:
            return 'NA'

        # A sentence counts as correct only if the full prediction matches the
        # full target under the direction-aware normalization in `is_same`.
        correct_count = 0
        for inst_dir, pred, target in zip(inst_directions, preds, targets):
            correct_count += TextNormalizationTestDataset.is_same(pred, target, inst_dir)
        sent_accuracy = correct_count / len(targets)

        return sent_accuracy

    @staticmethod
    def compute_class_accuracy(
        inputs: List[List[str]],
        targets: List[List[str]],
        tag_preds: List[List[str]],
        inst_directions: List[str],
        output_spans: List[List[str]],
        classes: List[List[str]],
        nb_spans: List[int],
        span_ends: List[List[int]],
    ) -> dict:
        """
        Compute the class based accuracy metric. This uses model's predicted tags.

        Args:
            inputs: List of lists where inner list contains words of input text
            targets: List of lists where inner list contains target strings grouped by class boundary
            tag_preds: List of lists where inner list contains predicted tags for each of the input words
            inst_directions: A list of str where each str indicates the direction of the corresponding instance (i.e., INST_BACKWARD or INST_FORWARD).
            output_spans: A list of lists where each inner list contains the decoded spans for the corresponding input sentence
            classes: A list of lists where inner list contains the class for each semiotic token in input sentence
            nb_spans: A list that contains the number of tokens in the input
            span_ends: A list of lists where inner list contains the end word index of the current token
        Return: the class accuracy scores as dict mapping class -> (accuracy, n_correct, n_total)
            (or the string 'NA' when `targets` is empty)
        """

        if len(targets) == 0:
            return 'NA'
        class2stats, class2correct = defaultdict(int), defaultdict(int)
        for ix, (sent, tags) in enumerate(zip(inputs, tag_preds)):
            if len(sent) != len(tags):
                # A word/tag length mismatch means the tagger output cannot be
                # aligned with the input sentence; skip the example instead of
                # crashing. (Previously a bare `except:` around an `assert`,
                # which swallowed unrelated exceptions and was stripped under -O.)
                logging.warning(f"Error: skipping example {ix}")
                continue
            # Reassemble the predicted output word-by-word, grouped by semiotic span.
            cur_words = [[] for _ in range(nb_spans[ix])]
            jx, span_idx = 0, 0
            cur_spans = output_spans[ix]
            class_idx = 0
            if classes[ix]:
                class2stats[classes[ix][class_idx]] += 1
            while jx < len(sent):
                tag, word = tags[jx], sent[jx]
                # Advance to the span that word jx belongs to, counting each
                # newly-entered span toward its class total.
                while jx >= span_ends[ix][class_idx]:
                    class_idx += 1
                    class2stats[classes[ix][class_idx]] += 1
                if constants.SAME_TAG in tag:
                    # SAME: the input word is copied through unchanged.
                    cur_words[class_idx].append(word)
                    jx += 1
                else:
                    # TRANSFORM: substitute the next decoded span, then skip the
                    # remaining I-tagged words that the span covers.
                    jx += 1
                    tmp = cur_spans[span_idx]
                    cur_words[class_idx].append(tmp)
                    span_idx += 1
                    while jx < len(sent) and tags[jx] == constants.I_PREFIX + constants.TRANSFORM_TAG:
                        while jx >= span_ends[ix][class_idx]:
                            class_idx += 1
                            class2stats[classes[ix][class_idx]] += 1
                            cur_words[class_idx].append(tmp)
                        jx += 1

            target_token_idx = 0

            # Score each span's reassembled prediction against its target.
            for class_idx in range(nb_spans[ix]):
                correct = TextNormalizationTestDataset.is_same(
                    " ".join(cur_words[class_idx]), targets[ix][target_token_idx], inst_directions[ix]
                )
                class2correct[classes[ix][class_idx]] += correct
                target_token_idx += 1

        # Convert raw counts to (accuracy, n_correct, n_total) per class.
        for key in class2stats:
            class2stats[key] = (class2correct[key] / class2stats[key], class2correct[key], class2stats[key])

        return class2stats
|
|
|