import ast
from typing import Dict, List, Tuple, Union

from .evaluation.metrics.metrics import ErrorMetric


def parse_semantics_str2dict(semantics_str: Union[List[str], str, Dict]) -> Tuple[Dict, bool]:
    """
    Parse the input string into a valid Python dictionary for later evaluation.

    Part of this function is adapted from
    https://github.com/speechbrain/speechbrain/blob/develop/recipes/SLURP/direct/train_with_wav2vec2.py#L110-L127
    """
    invalid = False
    if isinstance(semantics_str, dict):
        return semantics_str, invalid
    if isinstance(semantics_str, list):
        semantics_str = " ".join(semantics_str)

    try:
        # Some decoders emit "|" as a field separator; normalize it to commas
        # so the string can be parsed as a Python literal.
        if "|" in semantics_str:
            semantics_str = semantics_str.replace("|", ",")
        _dict = ast.literal_eval(semantics_str)
        if not isinstance(_dict, dict):
            _dict = {
                "scenario": "none",
                "action": "none",
                "entities": [],
            }
            invalid = True
    except (SyntaxError, ValueError):
        # ast.literal_eval raises SyntaxError for unparsable strings and
        # ValueError for strings that parse but are not valid literals.
        _dict = {
            "scenario": "none",
            "action": "none",
            "entities": [],
        }
        invalid = True

    # Backfill any missing or ill-typed field with a "none" placeholder and
    # flag the prediction as invalid.
    if "scenario" not in _dict or not isinstance(_dict["scenario"], str):
        _dict["scenario"] = "none"
        invalid = True
    if "action" not in _dict or not isinstance(_dict["action"], str):
        _dict["action"] = "none"
        invalid = True
    if "entities" not in _dict:
        _dict["entities"] = []
        invalid = True
    else:

        def _parse_entity(item: Dict):
            error = False
            for key in ["type", "filler"]:
                if key not in item or not isinstance(item[key], str):
                    item[key] = "none"
                    error = True
            return item, error

        for i, x in enumerate(_dict["entities"]):
            item, entity_error = _parse_entity(x)
            invalid = invalid or entity_error
            _dict["entities"][i] = item

    return _dict, invalid
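
# A hedged usage sketch (illustration only, not part of the original module):
# pipe-separated fields in a decoder hypothesis are normalized before parsing.
#
#     pred = "{'scenario': 'calendar'| 'action': 'set'| 'entities': []}"
#     parsed, invalid = parse_semantics_str2dict(pred)
#     # parsed  -> {'scenario': 'calendar', 'action': 'set', 'entities': []}
#     # invalid -> False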


class SLURPEvaluator:
    """
    Evaluator class for calculating SLURP metrics.
    """

    def __init__(self, average_mode: str = "micro") -> None:
        if average_mode not in ["micro", "macro"]:
            raise ValueError(
                f"Only supports 'micro' or 'macro' average, but got {average_mode} instead."
            )
        self.average_mode = average_mode
        self.reset()

    def reset(self) -> None:
        # (Re)create every metric so the evaluator starts from a clean state.
        # __init__ delegates here to avoid duplicating the construction logic.
        self.scenario_f1 = ErrorMetric.get_instance(metric="f1", average=self.average_mode)
        self.action_f1 = ErrorMetric.get_instance(metric="f1", average=self.average_mode)
        self.intent_f1 = ErrorMetric.get_instance(metric="f1", average=self.average_mode)
        self.span_f1 = ErrorMetric.get_instance(metric="span_f1", average=self.average_mode)
        self.distance_metrics = {}
        for distance in ["word", "char"]:
            self.distance_metrics[distance] = ErrorMetric.get_instance(
                metric="span_distance_f1", average=self.average_mode, distance=distance
            )
        self.slu_f1 = ErrorMetric.get_instance(metric="slu_f1", average=self.average_mode)
        self.invalid = 0
        self.total = 0

    def update(self, predictions: Union[List[str], str], groundtruth: Union[List[str], str]) -> None:
        if isinstance(predictions, str):
            predictions = [predictions]
        if isinstance(groundtruth, str):
            groundtruth = [groundtruth]

        for pred, truth in zip(predictions, groundtruth):
            pred, syntax_error = parse_semantics_str2dict(pred)
            truth, _ = parse_semantics_str2dict(truth)
            self.scenario_f1(truth["scenario"], pred["scenario"])
            self.action_f1(truth["action"], pred["action"])
            # Intent is scored as the joint scenario/action label.
            self.intent_f1(
                f"{truth['scenario']}_{truth['action']}",
                f"{pred['scenario']}_{pred['action']}",
            )
            self.span_f1(truth["entities"], pred["entities"])
            for distance, metric in self.distance_metrics.items():
                metric(truth["entities"], pred["entities"])

            self.total += 1
            self.invalid += int(syntax_error)

    def compute(self, aggregate: bool = True) -> Dict:
        scenario_results = self.scenario_f1.get_metric()
        action_results = self.action_f1.get_metric()
        intent_results = self.intent_f1.get_metric()
        entity_results = self.span_f1.get_metric()
        word_dist_results = self.distance_metrics["word"].get_metric()
        char_dist_results = self.distance_metrics["char"].get_metric()
        # SLU-F1 is accumulated from the word- and char-distance span scores.
        self.slu_f1(word_dist_results)
        self.slu_f1(char_dist_results)
        slurp_results = self.slu_f1.get_metric()

        if not aggregate:
            return {
                "scenario": scenario_results,
                "action": action_results,
                "intent": intent_results,
                "entity": entity_results,
                "word_dist": word_dist_results,
                "char_dist": char_dist_results,
                "slurp": slurp_results,
                "invalid": self.invalid,
                "total": self.total,
            }

        # Flatten every result into "<tag>_p" / "<tag>_r" / "<tag>_f1" keys.
        scores = dict()
        scores["invalid"] = self.invalid
        scores["total"] = self.total
        self.update_scores_dict(scenario_results, scores, "scenario")
        self.update_scores_dict(action_results, scores, "action")
        self.update_scores_dict(intent_results, scores, "intent")
        self.update_scores_dict(entity_results, scores, "entity")
        self.update_scores_dict(word_dist_results, scores, "word_dist")
        self.update_scores_dict(char_dist_results, scores, "char_dist")
        self.update_scores_dict(slurp_results, scores, "slurp")

        return scores

    def update_scores_dict(self, source: Dict, target: Dict, tag: str = "") -> Dict:
        # Copy the overall precision/recall/F1 triple from a metric result
        # into the flat scores dictionary under the given tag.
        p, r, f1 = source["overall"][:3]
        target[f"{tag}_p"] = p
        target[f"{tag}_r"] = r
        target[f"{tag}_f1"] = f1
        return target
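
# Hedged end-to-end sketch (illustration only; assumes the package context
# required by the relative ErrorMetric import above):
#
#     evaluator = SLURPEvaluator(average_mode="micro")
#     evaluator.update(
#         predictions="{'scenario': 'calendar', 'action': 'set', 'entities': []}",
#         groundtruth="{'scenario': 'calendar', 'action': 'set', 'entities': []}",
#     )
#     scores = evaluator.compute()
#     # scores holds flattened keys such as "intent_f1" and "slurp_f1",
#     # plus the "invalid" and "total" counters.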