# camenduru's picture
# thanks to NVIDIA ❤
# 7934b29
# ! /usr/bin/python
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
from typing import Dict, List, Tuple, Union
from .evaluation.metrics.metrics import ErrorMetric
def parse_semantics_str2dict(semantics_str: Union[List[str], str, Dict]) -> Tuple[Dict, bool]:
    """
    Parse a (possibly malformed) semantics string into a valid dict for evaluation.

    Part of this function is adapted from
    https://github.com/speechbrain/speechbrain/blob/develop/recipes/SLURP/direct/train_with_wav2vec2.py#L110-L127

    Args:
        semantics_str: a dict (returned unchanged), a string containing a python
            dict literal, or a list of string tokens joined with spaces before
            parsing. Pipe characters are treated as comma separators.

    Returns:
        Tuple of (dict with guaranteed str keys 'scenario' and 'action' and a
        list 'entities' whose items are dicts with str 'type'/'filler', and a
        bool that is True when any part of the input was invalid and defaulted).
    """
    invalid = False
    if isinstance(semantics_str, dict):
        return semantics_str, invalid
    if isinstance(semantics_str, list):
        semantics_str = " ".join(semantics_str)
    try:
        if "|" in semantics_str:
            semantics_str = semantics_str.replace("|", ",")
        _dict = ast.literal_eval(semantics_str)
        if not isinstance(_dict, dict):
            _dict = {
                "scenario": "none",
                "action": "none",
                "entities": [],
            }
            invalid = True
    # literal_eval raises ValueError (e.g. a bare name like "foo") and TypeError
    # in addition to SyntaxError when the text is not a valid python literal
    except (SyntaxError, ValueError, TypeError):
        _dict = {
            "scenario": "none",
            "action": "none",
            "entities": [],
        }
        invalid = True
    if "scenario" not in _dict or not isinstance(_dict["scenario"], str):
        _dict["scenario"] = "none"
        invalid = True
    if "action" not in _dict or not isinstance(_dict["action"], str):
        _dict["action"] = "none"
        invalid = True
    # entities must be a list of dicts; anything else is replaced with defaults
    if "entities" not in _dict or not isinstance(_dict["entities"], list):
        _dict["entities"] = []
        invalid = True
    else:
        def _parse_entity(item):
            # Ensure the entity is a dict with str 'type' and 'filler' fields.
            if not isinstance(item, dict):
                return {"type": "none", "filler": "none"}, True
            error = False
            for key in ["type", "filler"]:
                if key not in item or not isinstance(item[key], str):
                    item[key] = "none"
                    error = True
            return item, error
        for i, x in enumerate(_dict["entities"]):
            item, entity_error = _parse_entity(x)
            invalid = invalid or entity_error
            _dict["entities"][i] = item
    return _dict, invalid
class SLURPEvaluator:
    """
    Evaluator class for calculating SLURP metrics.

    Accumulates scenario/action/intent F1, entity span F1, word/char span
    distance F1 and the aggregated SLU-F1, plus counts of total and invalid
    (unparseable) predictions.
    """
    def __init__(self, average_mode: str = 'micro') -> None:
        """
        Args:
            average_mode: averaging mode for all F1 metrics, 'micro' or 'macro'.

        Raises:
            ValueError: if average_mode is neither 'micro' nor 'macro'.
        """
        if average_mode not in ['micro', 'macro']:
            raise ValueError(f"Only supports 'micro' or 'macro' average, but got {average_mode} instead.")
        self.average_mode = average_mode
        # All metric state lives in reset() so construction and re-use cannot drift apart.
        self.reset()
    def reset(self) -> None:
        """Re-create every metric accumulator and zero the counters."""
        self.scenario_f1 = ErrorMetric.get_instance(metric="f1", average=self.average_mode)
        self.action_f1 = ErrorMetric.get_instance(metric="f1", average=self.average_mode)
        self.intent_f1 = ErrorMetric.get_instance(metric="f1", average=self.average_mode)
        self.span_f1 = ErrorMetric.get_instance(metric="span_f1", average=self.average_mode)
        self.distance_metrics = {}
        for distance in ['word', 'char']:
            self.distance_metrics[distance] = ErrorMetric.get_instance(
                metric="span_distance_f1", average=self.average_mode, distance=distance
            )
        self.slu_f1 = ErrorMetric.get_instance(metric="slu_f1", average=self.average_mode)
        self.invalid = 0
        self.total = 0
    def update(self, predictions: Union[List[str], str], groundtruth: Union[List[str], str]) -> None:
        """
        Feed prediction/groundtruth pairs into all metric accumulators.

        Args:
            predictions: predicted semantics string(s); a bare str is treated
                as a single-element list.
            groundtruth: reference semantics string(s), same convention.

        NOTE(review): pairs are consumed via zip, so if the two lists differ
        in length the extra items are silently ignored.
        """
        if isinstance(predictions, str):
            predictions = [predictions]
        if isinstance(groundtruth, str):
            groundtruth = [groundtruth]
        for pred, truth in zip(predictions, groundtruth):
            pred, syntax_error = parse_semantics_str2dict(pred)
            truth, _ = parse_semantics_str2dict(truth)
            self.scenario_f1(truth["scenario"], pred["scenario"])
            self.action_f1(truth["action"], pred["action"])
            # Intent is evaluated as the joint scenario_action label.
            self.intent_f1(f"{truth['scenario']}_{truth['action']}", f"{pred['scenario']}_{pred['action']}")
            self.span_f1(truth["entities"], pred["entities"])
            for distance, metric in self.distance_metrics.items():
                metric(truth["entities"], pred["entities"])
            self.total += 1
            self.invalid += int(syntax_error)
    def compute(self, aggregate=True) -> Dict:
        """
        Compute all metrics from the accumulated statistics.

        Args:
            aggregate: if False, return the raw per-metric result objects;
                if True, flatten each metric's overall (p, r, f1) into a
                single dict of scalar scores.

        Returns:
            Dict of metric results (raw or flattened) plus 'invalid'/'total' counts.
        """
        scenario_results = self.scenario_f1.get_metric()
        action_results = self.action_f1.get_metric()
        intent_results = self.intent_f1.get_metric()
        entity_results = self.span_f1.get_metric()
        word_dist_results = self.distance_metrics['word'].get_metric()
        char_dist_results = self.distance_metrics['char'].get_metric()
        # SLU-F1 aggregates the word- and char-distance results.
        self.slu_f1(word_dist_results)
        self.slu_f1(char_dist_results)
        slurp_results = self.slu_f1.get_metric()
        if not aggregate:
            return {
                "scenario": scenario_results,
                "action": action_results,
                "intent": intent_results,
                "entity": entity_results,
                "word_dist": word_dist_results,
                "char_dist": char_dist_results,
                "slurp": slurp_results,
                "invalid": self.invalid,
                "total": self.total,
            }
        scores = dict()
        scores["invalid"] = self.invalid
        scores["total"] = self.total
        self.update_scores_dict(scenario_results, scores, "scenario")
        self.update_scores_dict(action_results, scores, "action")
        self.update_scores_dict(intent_results, scores, "intent")
        self.update_scores_dict(entity_results, scores, "entity")
        self.update_scores_dict(word_dist_results, scores, "word_dist")
        self.update_scores_dict(char_dist_results, scores, "char_dist")
        self.update_scores_dict(slurp_results, scores, "slurp")
        return scores
    def update_scores_dict(self, source: Dict, target: Dict, tag: str = '') -> Dict:
        """
        Copy the first three values of source['overall'] (precision, recall, f1)
        into target under '{tag}_p', '{tag}_r' and '{tag}_f1' keys.

        Returns:
            The mutated target dict.
        """
        scores = source['overall']
        p, r, f1 = scores[:3]
        target[f"{tag}_p"] = p
        target[f"{tag}_r"] = r
        target[f"{tag}_f1"] = f1
        return target