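# NOTE: hosting-page header captured together with this file (not part of the code):
# uploaded by camenduru, commit 7934b29, commit message "thanks to NVIDIA ❤".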
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from math import ceil
from time import perf_counter
from typing import List
import numpy as np
import torch.nn as nn
from tqdm import tqdm
from nemo.collections.nlp.data.text_normalization import TextNormalizationTestDataset, constants
from nemo.collections.nlp.data.text_normalization.utils import input_preprocessing
from nemo.collections.nlp.models.duplex_text_normalization.utils import get_formatted_string
from nemo.utils import logging
# Optional dependency: `post_process_punct` is only used when the output is
# detokenized for non-"processed" inference. `PYNINI_AVAILABLE` records whether
# the import succeeded so that code path can fall back to a warning instead of
# failing at import time.
try:
from nemo_text_processing.text_normalization.data_loader_utils import post_process_punct
PYNINI_AVAILABLE = True
except (ImportError, ModuleNotFoundError):
PYNINI_AVAILABLE = False
# Names exported by `from <this module> import *`.
__all__ = ['DuplexTextNormalizationModel']
class DuplexTextNormalizationModel(nn.Module):
    """
    DuplexTextNormalizationModel is a wrapper class that can be used to
    encapsulate a trained tagger and a trained decoder. The class is intended
    to be used for inference only (e.g., for evaluation).
    """

    def __init__(self, tagger, decoder, lang):
        """
        Args:
            tagger: Trained tagger. Must provide `_infer(sents, inst_directions)`.
            decoder: Trained decoder. Must provide
                `_infer(sents, nb_spans, span_starts, span_ends, inst_directions)` and a
                `processor` attribute with `tokenize` / `detokenize` methods.
            lang: Language identifier forwarded to `input_preprocessing`.
        """
        super().__init__()
        self.tagger = tagger
        self.decoder = decoder
        self.lang = lang

    def evaluate(
        self, dataset: TextNormalizationTestDataset, batch_size: int, errors_log_fp: str, verbose: bool = True
    ):
        """ Function for evaluating the performance of the model on a dataset

        Args:
            dataset: The dataset to be used for evaluation.
            batch_size: Batch size to use during inference. You can set it to be 1
                (no batching) if you want to measure the running time of the model
                per individual example (assuming requests are coming to the model one-by-one).
            errors_log_fp: Path to the file for logging the errors
            verbose: if true prints and logs various evaluation results

        Returns:
            results: A Dict containing the evaluation results. For every direction in
                `constants.INST_DIRECTIONS` it holds a dict with the keys `sent_accuracy`,
                `nb_instances` and `class_accuracy`; on top of that it holds
                `itn_error_ctx`, `tn_error_ctx` and `running_time` (average time per batch
                in milliseconds divided by `batch_size`).
        """
        results = {}
        # The log file is opened before inference so that an unwritable path fails fast.
        # The context manager guarantees the handle is closed even if evaluation raises.
        with open(errors_log_fp, 'w+') as error_f:
            # Apply the model on the dataset
            all_run_times, records = self._predict_dataset(dataset, batch_size)

            # Metrics
            tn_error_ctx, itn_error_ctx = 0, 0
            for direction in constants.INST_DIRECTIONS:
                cur_records = [rec for rec in records if rec[0] == direction]
                # Transpose the records into per-field lists; fall back to nine empty
                # lists (one per record field) when no instance has this direction.
                columns = [list(col) for col in zip(*cur_records)] or [[] for _ in range(9)]
                (
                    cur_dirs,
                    cur_inputs,
                    cur_tag_preds,
                    cur_final_preds,
                    cur_targets,
                    cur_classes,
                    cur_nb_spans,
                    cur_span_ends,
                    cur_output_spans,
                ) = columns
                nb_instances = len(cur_final_preds)
                cur_targets_sent = [" ".join(x) for x in cur_targets]

                sent_accuracy = TextNormalizationTestDataset.compute_sent_accuracy(
                    cur_final_preds, cur_targets_sent, cur_dirs
                )
                class_accuracy = TextNormalizationTestDataset.compute_class_accuracy(
                    [x.split() for x in cur_inputs],
                    cur_targets,
                    cur_tag_preds,
                    cur_dirs,
                    cur_output_spans,
                    cur_classes,
                    cur_nb_spans,
                    cur_span_ends,
                )
                # Built unconditionally: the value is stored in `results` below even when
                # `verbose` is False.
                log_class_accuracies = self._format_class_accuracy(class_accuracy)

                if verbose:
                    logging.info(f'\n============ Direction {direction} ============')
                    logging.info(f'Sentence Accuracy: {sent_accuracy}')
                    logging.info(f'nb_instances: {nb_instances}')
                    logging.info(f'class accuracies: {log_class_accuracies}')

                # Update results
                results[direction] = {
                    'sent_accuracy': sent_accuracy,
                    'nb_instances': nb_instances,
                    "class_accuracy": log_class_accuracies,
                }

                # Write errors to log file
                for _input, tag_pred, final_pred, target, classes in zip(
                    cur_inputs, cur_tag_preds, cur_final_preds, cur_targets_sent, cur_classes
                ):
                    if TextNormalizationTestDataset.is_same(final_pred, target, direction):
                        continue
                    if direction == constants.INST_BACKWARD:
                        error_f.write('Backward Problem (ITN)\n')
                        itn_error_ctx += 1
                    elif direction == constants.INST_FORWARD:
                        error_f.write('Forward Problem (TN)\n')
                        tn_error_ctx += 1
                    formatted_input_str = get_formatted_string(self.decoder.processor.tokenize(_input).split())
                    formatted_tag_pred_str = get_formatted_string(tag_pred)
                    class_str = " ".join(classes)
                    error_f.write(f'Original Input : {_input}\n')
                    error_f.write(f'Input : {formatted_input_str}\n')
                    error_f.write(f'Predicted Tags : {formatted_tag_pred_str}\n')
                    error_f.write(f'Ground Classes : {class_str}\n')
                    error_f.write(f'Predicted Str : {final_pred}\n')
                    error_f.write(f'Ground-Truth : {target}\n')
                    error_f.write('\n')

            results['itn_error_ctx'] = itn_error_ctx
            results['tn_error_ctx'] = tn_error_ctx

        # Running Time
        avg_running_time = np.average(all_run_times) / batch_size  # in ms
        if verbose:
            logging.info(f'Average running time (normalized by batch size): {avg_running_time} ms')
        results['running_time'] = avg_running_time

        logging.info(f'Errors are saved at {errors_log_fp}.')
        return results

    def _predict_dataset(self, dataset, batch_size):
        """
        Runs `_infer` over `dataset` in consecutive slices of `batch_size` instances.

        Args:
            dataset: Sliceable collection whose items unpack into
                (direction, input, target, classes, nb_spans, span_starts, span_ends).
            batch_size: Number of instances per call to `_infer`.

        Returns:
            run_times: Wall-clock time of each `_infer` call, in milliseconds.
            records: One tuple per instance:
                (direction, input, tag_pred, final_pred, target, classes,
                nb_spans, span_ends, output_spans).
        """
        run_times, records = [], []
        nb_iters = int(ceil(len(dataset) / batch_size))
        for i in tqdm(range(nb_iters)):
            start_idx = i * batch_size
            end_idx = (i + 1) * batch_size
            batch_insts = dataset[start_idx:end_idx]
            (
                batch_dirs,
                batch_inputs,
                batch_targets,
                batch_classes,
                batch_nb_spans,
                _batch_span_starts,  # gold span starts are not needed by any metric below
                batch_span_ends,
            ) = zip(*batch_insts)

            # Inference and Running Time Measurement
            batch_start_time = perf_counter()
            batch_tag_preds, batch_output_spans, batch_final_preds = self._infer(
                batch_inputs, batch_dirs, processed=True
            )
            run_times.append((perf_counter() - batch_start_time) * 1000)  # milliseconds

            records.extend(
                zip(
                    batch_dirs,
                    batch_inputs,
                    batch_tag_preds,
                    batch_final_preds,
                    batch_targets,
                    batch_classes,
                    batch_nb_spans,
                    batch_span_ends,
                    batch_output_spans,
                )
            )
        return run_times, records

    @staticmethod
    def _format_class_accuracy(class_accuracy):
        """
        Renders the class-accuracy result as a loggable string.

        A str input is returned unchanged; otherwise the input is treated as a mapping
        from class name to a sequence whose first three items are written as
        "<name>:\t<item0>\t<item1>/<item2>", one class per line.
        """
        if isinstance(class_accuracy, str):
            return class_accuracy
        return "".join(
            f"\n\t{key}:\t{value[0]}\t{value[1]}/{value[2]}" for key, value in class_accuracy.items()
        )

    @staticmethod
    def _merge_words_and_spans(sent, tags, cur_spans):
        """
        Rebuilds the output word sequence of one sentence.

        Words whose tag contains `constants.SAME_TAG` are copied through. Any other tag
        consumes the next entry of `cur_spans`, and the directly following words tagged
        `constants.I_PREFIX + constants.TRANSFORM_TAG` are skipped as part of that span.

        Raises:
            IndexError: if `tags` is shorter than `sent` or `cur_spans` runs out.
        """
        continuation_tag = constants.I_PREFIX + constants.TRANSFORM_TAG
        cur_words, jx, span_idx = [], 0, 0
        while jx < len(sent):
            tag, word = tags[jx], sent[jx]
            jx += 1
            if constants.SAME_TAG in tag:
                cur_words.append(word)
                continue
            cur_words.append(cur_spans[span_idx])
            span_idx += 1
            # Skip the remaining words that belong to the span just emitted
            while jx < len(sent) and tags[jx] == continuation_tag:
                jx += 1
        return cur_words

    # Functions for inference
    def _infer(self, sents: List[str], inst_directions: List[str], processed=False):
        """
        Main function for Inference

        If the 'joint' mode is used, "sents" will include both spoken and written forms on each input sentence,
        and "inst_directions" will include both constants.INST_BACKWARD and constants.INST_FORWARD

        Args:
            sents: A list of input texts.
            inst_directions: A list of str where each str indicates the direction of the corresponding instance \
                (i.e., constants.INST_BACKWARD for ITN or constants.INST_FORWARD for TN).
            processed: Set to True when used with TextNormalizationTestDataset, the data is already tokenized with moses,
                repetitive moses tokenization could lead to the number of tokens and class span mismatch

        Returns:
            tag_preds: A list of lists where the inner list contains the tag predictions from the tagger for each word in the input text.
            output_spans: A list of lists where each list contains the decoded semiotic spans from the decoder for an input text.
            final_outputs: A list of str where each str is the final output text for an input text.
                When an IndexError occurs while assembling a sentence, the original input text is returned for it.
        """
        original_sents = list(sents)
        # Separate into words
        if not processed:
            sents = [input_preprocessing(x, lang=self.lang) for x in sents]
            sents = [self.decoder.processor.tokenize(x).split() for x in sents]
        else:
            sents = [x.split() for x in sents]

        # Tagging
        # span_ends included, returns index wrt to words in input without auxiliary words
        tag_preds, nb_spans, span_starts, span_ends = self.tagger._infer(sents, inst_directions)
        output_spans = self.decoder._infer(sents, nb_spans, span_starts, span_ends, inst_directions)

        # Prepare final outputs
        final_outputs = []
        for ix, (sent, tags) in enumerate(zip(sents, tag_preds)):
            try:
                cur_words = self._merge_words_and_spans(sent, tags, output_spans[ix])
                if processed:
                    # for Class-based evaluation, don't apply Moses detokenization
                    cur_output_str = " ".join(cur_words)
                else:
                    # detokenize the output with Moses and fix punctuation marks to match the input
                    # for interactive inference or inference from a file
                    cur_output_str = self.decoder.processor.detokenize(cur_words)
                    if PYNINI_AVAILABLE:
                        cur_output_str = post_process_punct(input=original_sents[ix], normalized_text=cur_output_str)
                    else:
                        logging.warning(
                            "`pynini` not installed, please install via nemo_text_processing/pynini_install.sh"
                        )
                final_outputs.append(cur_output_str)
            except IndexError:
                # Best effort: keep the original text for this sentence rather than fail the batch
                logging.warning(f"Input sent is too long and will be skipped - {original_sents[ix]}")
                final_outputs.append(original_sents[ix])
        return tag_preds, output_spans, final_outputs