import ast
from typing import Dict, List, Tuple, Union

from .evaluation.metrics.metrics import ErrorMetric


def parse_semantics_str2dict(semantics_str: Union[List[str], str, Dict]) -> Tuple[Dict, bool]:
    """
    Parse the input semantics string into a valid Python dictionary for later
    evaluation, and return the dictionary together with a flag indicating
    whether the input was invalid. Part of this function is adapted from
    https://github.com/speechbrain/speechbrain/blob/develop/recipes/SLURP/direct/train_with_wav2vec2.py#L110-L127
    """
    invalid = False
    if isinstance(semantics_str, dict):
        return semantics_str, invalid
    if isinstance(semantics_str, list):
        semantics_str = " ".join(semantics_str)

    try:
        # Model outputs may use "|" as the field separator; normalize it to
        # "," so the string parses as a Python dict literal.
        if "|" in semantics_str:
            semantics_str = semantics_str.replace("|", ",")
        _dict = ast.literal_eval(semantics_str)
        if not isinstance(_dict, dict):
            _dict = {
                "scenario": "none",
                "action": "none",
                "entities": [],
            }
            invalid = True
    except Exception:
        _dict = {
            "scenario": "none",
            "action": "none",
            "entities": [],
        }
        invalid = True

    # Backfill missing or malformed fields with "none" placeholders so the
    # downstream metrics always receive a well-formed dictionary.
    if "scenario" not in _dict or not isinstance(_dict["scenario"], str):
        _dict["scenario"] = "none"
        invalid = True
    if "action" not in _dict or not isinstance(_dict["action"], str):
        _dict["action"] = "none"
        invalid = True
    if "entities" not in _dict or not isinstance(_dict["entities"], list):
        _dict["entities"] = []
        invalid = True
    else:

        def _parse_entity(item) -> Tuple[Dict, bool]:
            # Entity items that are not dictionaries cannot carry the
            # "type"/"filler" fields the metrics expect; replace them outright.
            if not isinstance(item, dict):
                return {"type": "none", "filler": "none"}, True
            error = False
            for key in ["type", "filler"]:
                if key not in item or not isinstance(item[key], str):
                    item[key] = "none"
                    error = True
            return item, error

        for i, x in enumerate(_dict["entities"]):
            item, entity_error = _parse_entity(x)
            invalid = invalid or entity_error
            _dict["entities"][i] = item

    return _dict, invalid
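
# Illustrative behavior (hypothetical label values; the fallback logic is the
# code above):
#
#     parse_semantics_str2dict("{'scenario': 'alarm', 'action': 'set', 'entities': []}")
#     -> ({'scenario': 'alarm', 'action': 'set', 'entities': []}, False)
#
#     parse_semantics_str2dict("not a dict")
#     -> ({'scenario': 'none', 'action': 'none', 'entities': []}, True)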


class SLURPEvaluator:
    """
    Evaluator class for calculating SLURP metrics.
    """

    def __init__(self, average_mode: str = 'micro') -> None:
        if average_mode not in ['micro', 'macro']:
            raise ValueError(f"Only supports 'micro' or 'macro' average, but got {average_mode} instead.")
        self.average_mode = average_mode
        # All metric state is created in reset() so that construction and
        # re-initialization share one code path.
        self.reset()

    def reset(self):
        """Re-create every metric accumulator and clear the sample counters."""
        self.scenario_f1 = ErrorMetric.get_instance(metric="f1", average=self.average_mode)
        self.action_f1 = ErrorMetric.get_instance(metric="f1", average=self.average_mode)
        self.intent_f1 = ErrorMetric.get_instance(metric="f1", average=self.average_mode)
        self.span_f1 = ErrorMetric.get_instance(metric="span_f1", average=self.average_mode)
        self.distance_metrics = {}
        for distance in ['word', 'char']:
            self.distance_metrics[distance] = ErrorMetric.get_instance(
                metric="span_distance_f1", average=self.average_mode, distance=distance
            )
        self.slu_f1 = ErrorMetric.get_instance(metric="slu_f1", average=self.average_mode)
        self.invalid = 0
        self.total = 0

    def update(self, predictions: Union[List[str], str], groundtruth: Union[List[str], str]) -> None:
        if isinstance(predictions, str):
            predictions = [predictions]
        if isinstance(groundtruth, str):
            groundtruth = [groundtruth]

        for pred, truth in zip(predictions, groundtruth):
            pred, invalid_pred = parse_semantics_str2dict(pred)
            truth, _ = parse_semantics_str2dict(truth)
            self.scenario_f1(truth["scenario"], pred["scenario"])
            self.action_f1(truth["action"], pred["action"])
            self.intent_f1(f"{truth['scenario']}_{truth['action']}", f"{pred['scenario']}_{pred['action']}")
            self.span_f1(truth["entities"], pred["entities"])
            for metric in self.distance_metrics.values():
                metric(truth["entities"], pred["entities"])

            self.total += 1
            self.invalid += int(invalid_pred)
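
    # Note: predictions and groundtruth are paired with zip(), so callers
    # should pass lists of equal length. An illustrative batch call:
    #
    #     evaluator.update(
    #         predictions=["{'scenario': 'alarm'| 'action': 'set'| 'entities': []}"],
    #         groundtruth=["{'scenario': 'alarm', 'action': 'set', 'entities': []}"],
    #     )
    #
    # The "|" separators in the prediction are normalized to "," by
    # parse_semantics_str2dict().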

    def compute(self, aggregate: bool = True) -> Dict:
        scenario_results = self.scenario_f1.get_metric()
        action_results = self.action_f1.get_metric()
        intent_results = self.intent_f1.get_metric()
        entity_results = self.span_f1.get_metric()
        word_dist_results = self.distance_metrics['word'].get_metric()
        char_dist_results = self.distance_metrics['char'].get_metric()
        # SLU-F1 combines the word- and character-distance span results.
        self.slu_f1(word_dist_results)
        self.slu_f1(char_dist_results)
        slurp_results = self.slu_f1.get_metric()

        if not aggregate:
            return {
                "scenario": scenario_results,
                "action": action_results,
                "intent": intent_results,
                "entity": entity_results,
                "word_dist": word_dist_results,
                "char_dist": char_dist_results,
                "slurp": slurp_results,
                "invalid": self.invalid,
                "total": self.total,
            }

        scores = dict()
        scores["invalid"] = self.invalid
        scores["total"] = self.total
        self.update_scores_dict(scenario_results, scores, "scenario")
        self.update_scores_dict(action_results, scores, "action")
        self.update_scores_dict(intent_results, scores, "intent")
        self.update_scores_dict(entity_results, scores, "entity")
        self.update_scores_dict(word_dist_results, scores, "word_dist")
        self.update_scores_dict(char_dist_results, scores, "char_dist")
        self.update_scores_dict(slurp_results, scores, "slurp")

        return scores

    def update_scores_dict(self, source: Dict, target: Dict, tag: str = '') -> Dict:
        """Copy the overall precision/recall/F1 from ``source`` into ``target`` under ``tag``-prefixed keys."""
        p, r, f1 = source['overall'][:3]
        target[f"{tag}_p"] = p
        target[f"{tag}_r"] = r
        target[f"{tag}_f1"] = f1
        return target
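
# Typical end-to-end usage (illustrative sketch; this module relies on the
# relative `evaluation.metrics` import above, so it is meant to be imported
# rather than run as a script):
#
#     evaluator = SLURPEvaluator(average_mode="micro")
#     evaluator.update(predictions=pred_strings, groundtruth=gold_strings)
#     scores = evaluator.compute()       # flat dict, e.g. scores["slurp_f1"]
#     evaluator.reset()                  # clear state before the next run
#
# compute(aggregate=False) returns the nested per-metric dictionaries instead
# of the flat summary.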