|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from math import ceil |
|
|
from time import perf_counter |
|
|
from typing import List |
|
|
|
|
|
import numpy as np |
|
|
import torch.nn as nn |
|
|
from tqdm import tqdm |
|
|
|
|
|
from nemo.collections.nlp.data.text_normalization import TextNormalizationTestDataset, constants |
|
|
from nemo.collections.nlp.data.text_normalization.utils import input_preprocessing |
|
|
from nemo.collections.nlp.models.duplex_text_normalization.utils import get_formatted_string |
|
|
from nemo.utils import logging |
|
|
|
|
|
try: |
|
|
from nemo_text_processing.text_normalization.data_loader_utils import post_process_punct |
|
|
|
|
|
PYNINI_AVAILABLE = True |
|
|
except (ImportError, ModuleNotFoundError): |
|
|
PYNINI_AVAILABLE = False |
|
|
|
|
|
|
|
|
__all__ = ['DuplexTextNormalizationModel'] |
|
|
|
|
|
|
|
|
class DuplexTextNormalizationModel(nn.Module):
    """
    DuplexTextNormalizationModel is a wrapper class that can be used to
    encapsulate a trained tagger and a trained decoder. The class is intended
    to be used for inference only (e.g., for evaluation).
    """

    def __init__(self, tagger, decoder, lang):
        """Store the two sub-models and the language they operate on.

        Args:
            tagger: Object whose `_infer(sents, inst_directions)` returns the tag
                predictions together with the number, starts and ends of the spans.
            decoder: Object whose `_infer(...)` decodes the tagged spans and whose
                `processor` attribute provides `tokenize` / `detokenize`.
            lang: Language value forwarded to `input_preprocessing` for raw inputs.
        """
        super().__init__()

        self.tagger = tagger
        self.decoder = decoder
        self.lang = lang

    def evaluate(
        self, dataset: TextNormalizationTestDataset, batch_size: int, errors_log_fp: str, verbose: bool = True
    ):
        """ Function for evaluating the performance of the model on a dataset

        Args:
            dataset: The dataset to be used for evaluation.
            batch_size: Batch size to use during inference. You can set it to be 1
                (no batching) if you want to measure the running time of the model
                per individual example (assuming requests are coming to the model one-by-one).
            errors_log_fp: Path to the file for logging the errors
            verbose: if true prints and logs various evaluation results

        Returns:
            results: A Dict containing the evaluation results (e.g., accuracy, running time).
                For every direction in `constants.INST_DIRECTIONS` there is an entry with the
                keys 'sent_accuracy', 'nb_instances' and 'class_accuracy'; in addition the
                dict holds 'itn_error_ctx', 'tn_error_ctx' and 'running_time'.

        Raises:
            ValueError: If `batch_size` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f'batch_size must be a positive integer, got {batch_size}')

        results = {}
        # The log is opened before inference (so an unusable path fails early) and the
        # `with` block guarantees it is closed even if inference or scoring raises.
        with open(errors_log_fp, 'w+', encoding='utf-8') as error_f:
            # Apply the model on the dataset
            all_run_times, records = self._predict_dataset(dataset, batch_size)

            # Determine the sentence / class accuracy for each direction
            tn_error_ctx, itn_error_ctx = 0, 0
            for direction in constants.INST_DIRECTIONS:
                selected = [record for record in records if record[0] == direction]
                if selected:
                    columns = [list(column) for column in zip(*selected)]
                else:
                    columns = [[] for _ in range(9)]
                (
                    cur_dirs,
                    cur_inputs,
                    cur_tag_preds,
                    cur_final_preds,
                    cur_targets,
                    cur_classes,
                    cur_nb_spans,
                    cur_span_ends,
                    cur_output_spans,
                ) = columns
                nb_instances = len(cur_final_preds)
                cur_targets_sent = [" ".join(x) for x in cur_targets]

                sent_accuracy = TextNormalizationTestDataset.compute_sent_accuracy(
                    cur_final_preds, cur_targets_sent, cur_dirs
                )

                class_accuracy = TextNormalizationTestDataset.compute_class_accuracy(
                    [x.split() for x in cur_inputs],
                    cur_targets,
                    cur_tag_preds,
                    cur_dirs,
                    cur_output_spans,
                    cur_classes,
                    cur_nb_spans,
                    cur_span_ends,
                )
                # Formatted unconditionally: the value is stored in `results` below
                # regardless of `verbose`, so it must always be bound.
                log_class_accuracies = self._format_class_accuracy(class_accuracy)
                if verbose:
                    logging.info(f'\n============ Direction {direction} ============')
                    logging.info(f'Sentence Accuracy: {sent_accuracy}')
                    logging.info(f'nb_instances: {nb_instances}')
                    logging.info(f'class accuracies: {log_class_accuracies}')

                # Update results
                results[direction] = {
                    'sent_accuracy': sent_accuracy,
                    'nb_instances': nb_instances,
                    "class_accuracy": log_class_accuracies,
                }

                # Write every mismatching prediction to the errors log
                for _input, tag_pred, final_pred, target, classes in zip(
                    cur_inputs, cur_tag_preds, cur_final_preds, cur_targets_sent, cur_classes
                ):
                    if TextNormalizationTestDataset.is_same(final_pred, target, direction):
                        continue

                    if direction == constants.INST_BACKWARD:
                        error_f.write('Backward Problem (ITN)\n')
                        itn_error_ctx += 1
                    elif direction == constants.INST_FORWARD:
                        error_f.write('Forward Problem (TN)\n')
                        tn_error_ctx += 1

                    formatted_input_str = get_formatted_string(self.decoder.processor.tokenize(_input).split())
                    formatted_tag_pred_str = get_formatted_string(tag_pred)
                    class_str = " ".join(classes)
                    error_f.write(f'Original Input : {_input}\n')
                    error_f.write(f'Input : {formatted_input_str}\n')
                    error_f.write(f'Predicted Tags : {formatted_tag_pred_str}\n')
                    error_f.write(f'Ground Classes : {class_str}\n')
                    error_f.write(f'Predicted Str : {final_pred}\n')
                    error_f.write(f'Ground-Truth : {target}\n')
                    error_f.write('\n')

            results['itn_error_ctx'] = itn_error_ctx
            results['tn_error_ctx'] = tn_error_ctx

            # Running time: mean per-batch time (ms) divided by the nominal batch size
            avg_running_time = np.average(all_run_times) / batch_size
            if verbose:
                logging.info(f'Average running time (normalized by batch size): {avg_running_time} ms')
            results['running_time'] = avg_running_time

        logging.info(f'Errors are saved at {errors_log_fp}.')
        return results

    def _predict_dataset(self, dataset, batch_size):
        """Run `_infer` over `dataset` batch by batch and collect the predictions.

        Args:
            dataset: Sliceable collection whose items are 7-tuples
                (direction, input, target, classes, nb_spans, span_starts, span_ends).
            batch_size: Number of instances sliced from `dataset` per iteration.

        Returns:
            run_times: List with the wall-clock time of every `_infer` call, in milliseconds.
            records: List with one tuple per instance:
                (direction, input, tag_pred, final_pred, target, classes,
                nb_spans, span_ends, output_spans).
        """
        run_times, records = [], []
        nb_iters = int(ceil(len(dataset) / batch_size))
        for i in tqdm(range(nb_iters)):
            start_idx = i * batch_size
            end_idx = start_idx + batch_size
            batch_insts = dataset[start_idx:end_idx]
            (
                batch_dirs,
                batch_inputs,
                batch_targets,
                batch_classes,
                batch_nb_spans,
                _batch_span_starts,  # part of the dataset tuple, not needed for scoring
                batch_span_ends,
            ) = zip(*batch_insts)

            # Only the model call itself is timed
            batch_start_time = perf_counter()
            batch_tag_preds, batch_output_spans, batch_final_preds = self._infer(
                batch_inputs, batch_dirs, processed=True
            )
            run_times.append((perf_counter() - batch_start_time) * 1000)

            records.extend(
                zip(
                    batch_dirs,
                    batch_inputs,
                    batch_tag_preds,
                    batch_final_preds,
                    batch_targets,
                    batch_classes,
                    batch_nb_spans,
                    batch_span_ends,
                    batch_output_spans,
                )
            )
        return run_times, records

    @staticmethod
    def _format_class_accuracy(class_accuracy):
        """Render the class accuracy report as a single string.

        Args:
            class_accuracy: Either a ready-made string (returned unchanged) or a mapping
                from a class key to a sequence with at least three entries.
                NOTE(review): the three entries are presumably accuracy, correct and
                total counts -- confirm against `compute_class_accuracy`.

        Returns:
            The string that is logged and stored under 'class_accuracy'.
        """
        if isinstance(class_accuracy, str):
            return class_accuracy
        return "".join(
            f"\n\t{key}:\t{value[0]}\t{value[1]}/{value[2]}" for key, value in class_accuracy.items()
        )

    def _infer(self, sents: List[str], inst_directions: List[str], processed=False):
        """
        Main function for Inference

        If the 'joint' mode is used, "sents" will include both spoken and written forms on each input sentence,
        and "inst_directions" will include both constants.INST_BACKWARD and constants.INST_FORWARD

        Args:
            sents: A list of input texts.
            inst_directions: A list of str where each str indicates the direction of the corresponding instance \
                (i.e., constants.INST_BACKWARD for ITN or constants.INST_FORWARD for TN).
            processed: Set to True when used with TextNormalizationTestDataset, the data is already tokenized with moses,
                repetitive moses tokenization could lead to the number of tokens and class span mismatch

        Returns:
            tag_preds: A list of lists where the inner list contains the tag predictions from the tagger for each word in the input text.
            output_spans: A list of lists where each list contains the decoded semiotic spans from the decoder for an input text.
            final_outputs: A list of str where each str is the final output text for an input text.
        """
        # Keep the untouched inputs: they are used for punctuation post-processing
        # and as the fallback output when a sentence cannot be assembled.
        original_sents = list(sents)

        if processed:
            # Already tokenized upstream: a plain whitespace split keeps the token count stable
            sents = [x.split() for x in sents]
        else:
            sents = [input_preprocessing(x, lang=self.lang) for x in sents]
            sents = [self.decoder.processor.tokenize(x).split() for x in sents]

        # Tagging, then decoding of the tagged spans
        tag_preds, nb_spans, span_starts, span_ends = self.tagger._infer(sents, inst_directions)
        output_spans = self.decoder._infer(sents, nb_spans, span_starts, span_ends, inst_directions)

        # Tag that marks a token as the continuation of the current transformed span
        continuation_tag = constants.I_PREFIX + constants.TRANSFORM_TAG

        # Prepare final outputs
        final_outputs = []
        for ix, (sent, tags) in enumerate(zip(sents, tag_preds)):
            try:
                cur_words, jx, span_idx = [], 0, 0
                cur_spans = output_spans[ix]
                while jx < len(sent):
                    tag, word = tags[jx], sent[jx]
                    jx += 1
                    if constants.SAME_TAG in tag:
                        # Token is copied to the output unchanged
                        cur_words.append(word)
                        continue
                    # Start of a transformed span: emit the decoded span once ...
                    cur_words.append(cur_spans[span_idx])
                    span_idx += 1
                    # ... and skip the remaining tokens that belong to the same span
                    while jx < len(sent) and tags[jx] == continuation_tag:
                        jx += 1

                if processed:
                    # Pre-tokenized input: join with spaces, no detokenization
                    cur_output_str = " ".join(cur_words)
                else:
                    # Raw input: detokenize and, when available, restore the input punctuation
                    cur_output_str = self.decoder.processor.detokenize(cur_words)
                    if PYNINI_AVAILABLE:
                        cur_output_str = post_process_punct(input=original_sents[ix], normalized_text=cur_output_str)
                    else:
                        logging.warning(
                            "`pynini` not installed, please install via nemo_text_processing/pynini_install.sh"
                        )
                final_outputs.append(cur_output_str)
            except IndexError:
                # Raised when fewer tags or decoded spans are available than the sentence
                # needs; the input text is returned unchanged for that sentence.
                logging.warning(f"Input sent is too long and will be skipped - {original_sents[ix]}")
                final_outputs.append(original_sents[ix])
        return tag_preds, output_spans, final_outputs
|
|
|