from collections import Counter


def gt_label2entity(gt_infos):
    """Get all entities from the ground-truth infos.

    Args:
        gt_infos (list[dict]): Ground-truth information containing the text
            and its label.

    Returns:
        gt_entities (list[list]): Original labeled entities in the ground
            truth, each as [category, start_position, end_position].
    """
    gt_entities = []
    for gt_info in gt_infos:
        line_entities = []
        label = gt_info['label']
        for key, value in label.items():
            for _, places in value.items():
                for place in places:
                    line_entities.append([key, place[0], place[1]])
        gt_entities.append(line_entities)
    return gt_entities
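

# Sketch of the assumed label format (category names, text and spans below
# are illustrative only): gt_info['label'] is assumed to map each category to
# a dict of {labeled text piece: list of [start, end] spans}, e.g.
#
#     gt_info = {
#         'text': 'John lives in Paris',
#         'label': {
#             'PER': {'John': [[0, 3]]},
#             'LOC': {'Paris': [[14, 18]]},
#         },
#     }
#
# for which gt_label2entity([gt_info]) returns
# [[['PER', 0, 3], ['LOC', 14, 18]]].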


def _compute_f1(origin, found, right):
    """Calculate recall, precision and f1-score.

    Args:
        origin (int): Number of entities in the ground truth.
        found (int): Number of entities predicted by the model.
        right (int): Number of predicted entities that match the
            ground-truth annotation.

    Returns:
        recall (float): Metric of recall.
        precision (float): Metric of precision.
        f1 (float): Metric of f1-score.
    """
    recall = 0. if origin == 0 else (right / origin)
    precision = 0. if found == 0 else (right / found)
    f1 = 0. if recall + precision == 0 else (2 * precision * recall) / (
        precision + recall)
    return recall, precision, f1
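

# Worked example (illustrative numbers): with origin=10, found=8 and right=6,
# recall = 6 / 10 = 0.6, precision = 6 / 8 = 0.75 and
# f1 = 2 * 0.75 * 0.6 / (0.75 + 0.6) ~= 0.667.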


def compute_f1_all(pred_entities, gt_entities):
    """Calculate precision, recall and f1-score for all categories.

    Args:
        pred_entities (list[list]): The predicted entities from the model.
        gt_entities (list[list]): The entities of the ground truth.

    Returns:
        class_info (dict): Precision, recall and f1-score of each category
            and of all categories together (key 'all').
    """
    origins = []
    founds = []
    rights = []
    for i, _ in enumerate(pred_entities):
        origins.extend(gt_entities[i])
        founds.extend(pred_entities[i])
        rights.extend([
            pred_entity for pred_entity in pred_entities[i]
            if pred_entity in gt_entities[i]
        ])

    class_info = {}
    origin_counter = Counter([x[0] for x in origins])
    found_counter = Counter([x[0] for x in founds])
    right_counter = Counter([x[0] for x in rights])
    for type_, count in origin_counter.items():
        origin = count
        found = found_counter.get(type_, 0)
        right = right_counter.get(type_, 0)
        recall, precision, f1 = _compute_f1(origin, found, right)
        class_info[type_] = {
            'precision': precision,
            'recall': recall,
            'f1-score': f1
        }
    origin = len(origins)
    found = len(founds)
    right = len(rights)
    recall, precision, f1 = _compute_f1(origin, found, right)
    class_info['all'] = {
        'precision': precision,
        'recall': recall,
        'f1-score': f1
    }
    return class_info
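

# Minimal sketch of a direct call (the entities below are made up for
# illustration):
#
#     pred = [[['PER', 0, 3], ['LOC', 14, 18]]]
#     gt = [[['PER', 0, 3], ['LOC', 10, 12]]]
#     compute_f1_all(pred, gt)
#
# Here 'PER' gets precision, recall and f1 of 1.0, 'LOC' gets 0.0, and the
# aggregate 'all' entry gets 0.5 for each metric.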


def eval_ner_f1(results, gt_infos):
    """Evaluate for the NER task.

    Args:
        results (list): Predicted entities for each line.
        gt_infos (list[dict]): Ground-truth information containing the text
            and its label.

    Returns:
        class_info (dict): Precision, recall and f1-score of each category
            and of all categories together.
    """
    assert len(results) == len(gt_infos)
    gt_entities = gt_label2entity(gt_infos)
    # Copy the per-line predictions so that each prediction line is compared
    # against the ground truth of the same line.
    pred_entities = [list(result) for result in results]
    assert len(pred_entities) == len(gt_entities)
    class_info = compute_f1_all(pred_entities, gt_entities)
    return class_info
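

if __name__ == '__main__':
    # Minimal usage sketch. The sentence, categories and spans below are made
    # up for illustration (they follow the label format sketched above) and
    # are not part of the original module.
    gt_infos = [{
        'text': 'John lives in Paris',
        'label': {
            'PER': {'John': [[0, 3]]},
            'LOC': {'Paris': [[14, 18]]},
        },
    }]
    # Pretend the model found the person but missed the location.
    results = [[['PER', 0, 3]]]
    class_info = eval_ner_f1(results, gt_infos)
    print(class_info)
    # f1-score: 1.0 for 'PER', 0.0 for 'LOC', roughly 0.667 for 'all'.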