Spaces:
Running
Running
| from coref_utils.utils import get_mention_to_cluster_idx | |
| from collections import defaultdict | |
def get_gt_actions(pred_mentions, document, mem_type_config, mapped_mentions=None):
    """Return the action sequence for the predicted mentions of a document.

    When ``document`` has a ``"clusters"`` entry, the actions are derived from
    those clusters by ``get_actions_unbounded_fast``. Otherwise one dummy
    ``(-1, "i")`` action is produced per predicted mention.

    Args:
        pred_mentions: Sequence of predicted mentions.
        document: Mapping describing the document; only the ``"clusters"``
            key is read here.
        mem_type_config: Not used by this function; kept so existing callers
            keep working.
        mapped_mentions: Optional mentions forwarded to
            ``get_actions_unbounded_fast``. ``None`` (the default) is
            forwarded as a fresh empty list, so no list object is shared
            between calls.

    Returns:
        A list of ``(cell_idx, action_type)`` tuples.
    """
    if mapped_mentions is None:
        mapped_mentions = []

    if "clusters" not in document:
        # No ground-truth clusters, i.e. running in the wild:
        # generate dummy actions, one per predicted mention.
        return [(-1, "i")] * len(pred_mentions)

    # Ground truth is available.
    return get_actions_unbounded_fast(
        pred_mentions, document["clusters"], mapped_mentions
    )
def action_sequences_to_clusters(actions, mentions, num_major_entities):
    """Group mentions into clusters according to their paired actions.

    ``mentions`` and ``actions`` are walked in lockstep. Each action is a
    ``(cell_idx, action_type)`` pair:

    * ``"i"``: the mention is skipped.
    * ``"o"``: the mention goes into the extra slot at index
      ``num_major_entities``; ``cell_idx`` is disregarded.
    * anything else: the mention goes into slot ``cell_idx``.

    Args:
        actions: Iterable of ``(cell_idx, action_type)`` pairs.
        mentions: Iterable of mentions, aligned with ``actions``.
        num_major_entities: Number of regular slots; one additional slot is
            appended for ``"o"`` mentions.

    Returns:
        A list of ``num_major_entities + 1`` lists of mentions; slots that
        received no mention are empty lists.
    """
    overflow_slot = num_major_entities

    # First pass: bucket mentions by target slot, keeping first-seen order.
    grouped = {}
    for mention, action in zip(mentions, actions):
        cell_idx, action_type = action
        if action_type == "i":
            continue
        slot = overflow_slot if action_type == "o" else cell_idx
        grouped.setdefault(slot, []).append(mention)

    # Second pass: place each bucket into the fixed-size result list.
    clusters = [[] for _ in range(num_major_entities + 1)]
    for slot, members in grouped.items():
        clusters[slot] = members
    return clusters
def get_cluster_to_cell(mapped_mentions, mention_to_cluster):
    """Assign consecutive cell indices to clusters in order of first appearance.

    Each mention in ``mapped_mentions`` is converted to a tuple and looked up
    in ``mention_to_cluster``. The first time a cluster is encountered it is
    given the next free cell index, starting at 0. Mentions missing from
    ``mention_to_cluster`` are reported on stdout and otherwise ignored.

    Args:
        mapped_mentions: Iterable of mentions (each convertible to a tuple).
        mention_to_cluster: Mapping from mention tuple to cluster identifier.

    Returns:
        A dict mapping cluster identifier to its assigned cell index.
    """
    cell_of_cluster = {}
    for mention in mapped_mentions:
        key = tuple(mention)
        if key not in mention_to_cluster:
            print("Error: Mention not in mentions", key)
            continue

        cluster_id = mention_to_cluster[key]
        if cluster_id not in cell_of_cluster:
            # Cells are handed out consecutively, so the next free index
            # always equals the number of clusters registered so far.
            cell_of_cluster[cluster_id] = len(cell_of_cluster)
    return cell_of_cluster
def get_actions_unbounded_fast(pred_mentions, gt_clusters, mapped_mentions=None):
    """Derive one action per predicted mention from ground-truth clusters.

    The last cluster index, ``len(gt_clusters) - 1``, is treated specially:
    a mention that maps to that index, or that is not found in the
    mention-to-cluster lookup at all, receives an ``"o"`` action carrying
    that index. Every other mention receives ``(cluster_idx, "c")``.

    Args:
        pred_mentions: Iterable of predicted mentions (each convertible to a
            tuple).
        gt_clusters: Ground-truth clusters; passed to
            ``get_mention_to_cluster_idx`` and measured with ``len``.
        mapped_mentions: Unused; accepted only so existing callers keep
            working.

    Returns:
        A list of ``(cluster_idx, action_type)`` tuples, one per predicted
        mention, where ``action_type`` is ``"o"`` or ``"c"``.
    """
    # Note: for empty ``gt_clusters`` this index is -1.
    last_cluster_idx = len(gt_clusters) - 1

    # NOTE(review): presumably maps each mention tuple to the index of its
    # cluster within gt_clusters; confirm against coref_utils.utils.
    mention_to_cluster = get_mention_to_cluster_idx(gt_clusters)

    actions = []
    for mention in pred_mentions:
        key = tuple(mention)
        if key not in mention_to_cluster:
            # Unknown mention: routed to the last cluster index.
            actions.append((last_cluster_idx, "o"))
            continue

        cluster_idx = mention_to_cluster[key]
        action_type = "o" if cluster_idx == last_cluster_idx else "c"
        actions.append((cluster_idx, action_type))
    return actions