Spaces:
Running
Running
| import re | |
| import subprocess | |
| import operator | |
| import collections | |
# Matches a document header line of the form
# "#begin document (<doc id>); part <digits>".
# Group 1 is the document id, group 2 the part number (digits, as a string).
BEGIN_DOCUMENT_REGEX = re.compile(r"#begin document \((.*)\); part (\d+)")

# Pulls the three percentages out of a line shaped like
# "Coreference: Recall: (x / y) R%<TAB>Precision: (x / y) P%<TAB>F1: F%".
# Groups 1-3 are recall, precision and F1 as strings.  DOTALL lets the leading
# and trailing ".*" span newlines, so the line may sit anywhere inside a
# multi-line text even though the pattern is applied with match().
COREF_RESULTS_REGEX = re.compile(
    r".*Coreference: Recall: \([0-9.]+ / [0-9.]+\) ([0-9.]+)%\tPrecision: \([0-9.]+ / [0-9.]+\) "
    r"([0-9.]+)%\tF1: ([0-9.]+)%.*",
    flags=re.DOTALL,
)
def get_doc_key(doc_id, part):
    """Return the key ``"<doc_id>_<part>"`` identifying one part of a document.

    ``part`` is passed through ``int()`` first, so zero-padded strings such as
    ``"000"`` and the integer ``0`` produce the same key.
    """
    part_number = int(part)
    return "{doc}_{part}".format(doc=doc_id, part=part_number)
def output_conll(input_file, output_file, predictions, subtoken_map):
    """Copy a CoNLL file, replacing the last column with predicted coreference tags.

    Args:
        input_file: Readable text file object; consumed with ``readlines()``.
        output_file: Writable text file object receiving the rewritten lines.
        predictions: Mapping ``doc_key -> clusters``.  Each cluster is an
            iterable of ``(start, end)`` pairs; a cluster's id is its position
            in ``clusters``.
        subtoken_map: Mapping ``doc_key -> sequence`` used to translate the
            ``start`` / ``end`` indices of each mention into word indices.

    Output rules:
        * Blank (whitespace-only) lines become a single ``"\\n"``.
        * Lines whose first field starts with ``#`` are copied verbatim; a
          ``#begin document`` header additionally selects that document's
          predictions and resets the word counter.
        * Every other line is a token row: its last field is replaced by
          ``-`` or by ``|``-joined tags, and the fields are re-joined with a
          single space.

    Raises:
        KeyError: a ``#begin document`` header names a document that is not in
            ``predictions``.
        AssertionError: the first two fields of a token row do not yield the
            key of the current document.
    """
    other_end = operator.itemgetter(1)

    # NOTE: ``doc_key``, ``start_map``, ``end_map`` and ``word_map`` are shared
    # by both loops on purpose.  A token row seen before any "#begin document"
    # header is therefore checked against (and annotated from) the last
    # document of ``predictions``.
    prediction_map = {}
    for doc_key, clusters in predictions.items():
        start_map = collections.defaultdict(list)  # word -> spans opening there
        end_map = collections.defaultdict(list)  # word -> spans closing there
        word_map = collections.defaultdict(list)  # word -> single-word mentions

        for cluster_id, mentions in enumerate(clusters):
            for sub_start, sub_end in mentions:
                to_word = subtoken_map[doc_key]
                first_word, last_word = to_word[sub_start], to_word[sub_end]
                if first_word == last_word:
                    word_map[first_word].append(cluster_id)
                    continue
                start_map[first_word].append((cluster_id, last_word))
                end_map[last_word].append((cluster_id, first_word))

        # Reduce each span list to cluster ids.  Spans opening on the same word
        # are ordered by descending end; spans closing on the same word by
        # descending start.
        for spans in start_map.values():
            spans[:] = [cid for cid, _ in sorted(spans, key=other_end, reverse=True)]
        for spans in end_map.values():
            spans[:] = [cid for cid, _ in sorted(spans, key=other_end, reverse=True)]

        prediction_map[doc_key] = (start_map, end_map, word_map)

    word_index = 0
    for line in input_file.readlines():
        row = line.split()

        if not row:
            output_file.write("\n")
            continue

        if row[0].startswith("#"):
            begin_match = BEGIN_DOCUMENT_REGEX.match(line)
            if begin_match:
                doc_key = get_doc_key(begin_match.group(1), begin_match.group(2))
                start_map, end_map, word_map = prediction_map[doc_key]
                word_index = 0
            # Comment lines (headers and any other "#..." line) pass through as-is.
            output_file.write(line)
            continue

        assert get_doc_key(row[0], row[1]) == doc_key

        # Tag order on one token: closing brackets, single-word mentions, openings.
        # ``.get`` keeps the defaultdicts from growing on lookup.
        coref_list = ["{})".format(cid) for cid in end_map.get(word_index, ())]
        coref_list += ["({})".format(cid) for cid in word_map.get(word_index, ())]
        coref_list += ["({}".format(cid) for cid in start_map.get(word_index, ())]

        row[-1] = "|".join(coref_list) if coref_list else "-"

        output_file.write(" ".join(row))
        output_file.write("\n")
        word_index += 1
def official_conll_eval(
    conll_scorer, gold_path, predicted_path, metric, official_stdout=False
):
    """Run the external CoNLL scorer for one metric and parse its summary line.

    The scorer is invoked as
    ``<conll_scorer> <metric> <gold_path> <predicted_path> none`` and its
    standard output is decoded as UTF-8.  Standard error is not captured, so
    anything the scorer writes there goes straight to this process's stderr.

    Args:
        conll_scorer: Path of the scorer executable/script.
        gold_path: Path of the gold CoNLL file.
        predicted_path: Path of the predicted CoNLL file.
        metric: Metric name handed to the scorer as its first argument.
        official_stdout: When true, print the scorer's full output.

    Returns:
        ``{"r": recall, "p": precision, "f": f1}`` with the percentages found
        by ``COREF_RESULTS_REGEX`` converted to ``float``.

    Raises:
        RuntimeError: the scorer output contains no coreference summary line
            (for example because the scorer failed to run).
    """
    import os  # local import: only needed for the platform check below

    cmd = [conll_scorer, metric, gold_path, predicted_path, "none"]

    # An argument *list* combined with shell=True only behaves as intended on
    # Windows.  On POSIX the extra items become arguments of /bin/sh itself and
    # the scorer would start without any arguments, so run it directly there.
    completed = subprocess.run(cmd, stdout=subprocess.PIPE, shell=(os.name == "nt"))
    stdout = completed.stdout.decode("utf-8")

    if official_stdout:
        print("Official result for {}".format(metric))
        print(stdout)

    coref_results_match = re.match(COREF_RESULTS_REGEX, stdout)
    if coref_results_match is None:
        # Fail with the evidence instead of an AttributeError on None.
        raise RuntimeError(
            "Could not find coreference scores in the output of {!r} "
            "(exit code {}):\n{}".format(cmd, completed.returncode, stdout)
        )

    recall, precision, f1 = (float(score) for score in coref_results_match.groups())
    return {"r": recall, "p": precision, "f": f1}
def evaluate_conll(
    conll_scorer,
    gold_path,
    predictions,
    subtoken_maps,
    prediction_path,
    all_metrics=False,
    official_stdout=False,
):
    """Write predictions in CoNLL format and score them with the external scorer.

    The gold file at ``gold_path`` is copied to ``prediction_path`` (which is
    created or overwritten) with its coreference column replaced by
    ``predictions`` via ``output_conll``; the scorer is then run once per
    metric via ``official_conll_eval``.

    Args:
        conll_scorer: Path of the scorer executable/script.
        gold_path: Path of the gold CoNLL file.
        predictions: Passed through to ``output_conll``.
        subtoken_maps: Passed through to ``output_conll``.
        prediction_path: Where the predicted CoNLL file is written.
        all_metrics: Accepted for interface compatibility; not used by this
            function.
        official_stdout: Forwarded to ``official_conll_eval``.

    Returns:
        A dict keyed by ``"muc"``, ``"bcub"`` and ``"ceafe"`` whose values are
        the per-metric results returned by ``official_conll_eval``.
    """
    with open(prediction_path, "w") as prediction_file:
        with open(gold_path, "r") as gold_file:
            output_conll(gold_file, prediction_file, predictions, subtoken_maps)

    # Score only after the prediction file has been closed, so every buffered
    # line is on disk before the external scorer opens it.
    return {
        metric: official_conll_eval(
            conll_scorer, gold_path, prediction_path, metric, official_stdout
        )
        for metric in ("muc", "bcub", "ceafe")
    }