| | |
| | |
| | |
| | |
| |
|
| | import argparse |
| | import math |
| | import os |
| | import subprocess |
| | import sys |
| | import tempfile |
| | from collections import defaultdict |
| | from itertools import combinations |
| |
|
| |
|
def read_translations(path, n_repeats):
    """Group consecutive lines of a file into per-segment translation lists.

    Every ``n_repeats`` consecutive lines of *path* form one segment. Each line
    is whitespace-normalised: runs of whitespace collapse to a single space and
    leading/trailing whitespace is dropped.

    Args:
        path: Text file with one translation per line.
        n_repeats: Number of consecutive lines that make up one segment.

    Returns:
        A ``defaultdict(list)`` mapping the 0-based segment index to the list of
        its ``n_repeats`` normalised translations. A trailing group with fewer
        than ``n_repeats`` lines is silently discarded.
    """
    segment_counter = 0
    segment_translations = []
    translations = defaultdict(list)
    # Context manager so the input handle is closed deterministically.
    with open(path) as fh:
        for line in fh:
            segment_translations.append(" ".join(line.split()))
            if len(segment_translations) == n_repeats:
                translations[segment_counter] = segment_translations
                segment_translations = []
                segment_counter += 1
    return translations
| |
|
| |
|
def generate_input(translations, n_repeats):
    """Write every pairwise combination of each segment's translations to disk.

    Segments are visited in ascending id order. For each unordered index pair
    ``(i, j)`` with ``i < j``, translation ``i`` goes to the hypothesis ("mt")
    file and translation ``j`` to the reference file, one per line, so the two
    files are line-aligned.

    Args:
        translations: Mapping of segment id -> list of ``n_repeats`` strings
            (the shape produced by ``read_translations``).
        n_repeats: Expected number of translations per segment.

    Returns:
        ``(ref_path, mt_path)``: paths of two temporary files. The caller is
        responsible for deleting them.

    Raises:
        ValueError: if a segment does not hold exactly ``n_repeats``
            translations. Both temporary files are removed in that case.
    """
    ref_fd, ref_path = tempfile.mkstemp()
    mt_fd, mt_path = tempfile.mkstemp()
    try:
        # Wrap the descriptors returned by mkstemp (instead of reopening the
        # paths) so they are not leaked. Leaving the ``with`` block closes
        # both files, which flushes them before an external process reads them.
        with os.fdopen(ref_fd, "w") as ref_fh, os.fdopen(mt_fd, "w") as mt_fh:
            for segid in sorted(translations):
                segment = translations[segid]
                if len(segment) != n_repeats:
                    raise ValueError(
                        "segment %r has %d translations, expected %d"
                        % (segid, len(segment), n_repeats)
                    )
                for idx1, idx2 in combinations(range(n_repeats), 2):
                    mt_fh.write(segment[idx1].strip() + "\n")
                    ref_fh.write(segment[idx2].strip() + "\n")
    except BaseException:
        # Do not leave half-written temp files behind on failure.
        os.remove(ref_path)
        os.remove(mt_path)
        raise
    sys.stderr.write("\nSaved translations to %s and %s" % (ref_path, mt_path))
    return ref_path, mt_path
| |
|
| |
|
def run_meteor(ref_path, mt_path, metric_path, lang="en"):
    """Score ``mt_path`` against ``ref_path`` with the Meteor jar via ``java``.

    Meteor's stdout is captured into a new temporary file. The two input files
    are deleted once Meteor has run, whether or not it succeeded.

    Args:
        ref_path: Reference file, line-aligned with ``mt_path``.
        mt_path: Hypothesis file.
        metric_path: Path to the Meteor ``.jar`` passed to ``java -jar``.
        lang: Value passed to Meteor's ``-l`` option.

    Returns:
        Path of the temporary file holding Meteor's stdout. The caller is
        responsible for deleting it.

    Raises:
        subprocess.CalledProcessError: if the java process exits non-zero
            (the output file is removed in that case).
        OSError: if ``java`` cannot be launched.
    """
    cmd = [
        "java",
        "-Xmx2G",
        "-jar",
        metric_path,
        mt_path,
        ref_path,
        # Meteor parameter string; presumably alpha/beta/gamma/delta in
        # Meteor's ``-p`` convention -- confirm against the Meteor docs.
        "-p",
        "0.5 0.2 0.6 0.75",
        "-norm",
        "-l",
        lang,
    ]
    out_fd, out_path = tempfile.mkstemp()
    try:
        # Reuse mkstemp's descriptor and close it once java has finished.
        with os.fdopen(out_fd, "w") as out_fh:
            returncode = subprocess.call(cmd, stdout=out_fh)
        # Previously a failed run was ignored and produced an empty score list.
        if returncode != 0:
            raise subprocess.CalledProcessError(returncode, cmd)
    except BaseException:
        os.remove(out_path)
        raise
    finally:
        os.remove(ref_path)
        os.remove(mt_path)
    sys.stderr.write("\nSaved Meteor output to %s" % out_path)
    return out_path
| |
|
| |
|
def read_output(meteor_output_path, n_repeats):
    """Average Meteor's per-line scores back into per-segment scores.

    Only lines starting with ``"Segment "`` are used. Their score is the
    second tab-separated field. Each consecutive run of
    ``n_repeats * (n_repeats - 1) / 2`` such lines (one per translation pair)
    is averaged into one value. The Meteor output file is deleted afterwards.

    Args:
        meteor_output_path: File containing Meteor's stdout.
        n_repeats: Number of translations per segment.

    Returns:
        List of averaged scores, one per complete run of pair scores. A
        trailing incomplete run is discarded.

    Raises:
        ValueError: if ``n_repeats < 2`` (no pairs can be formed).
    """
    if n_repeats < 2:
        raise ValueError("n_repeats must be at least 2, got %r" % (n_repeats,))
    # C(n_repeats, 2) computed exactly in integers (the factorial-division form
    # produced a float).
    n_combinations = n_repeats * (n_repeats - 1) // 2
    raw_scores = []
    average_scores = []
    with open(meteor_output_path) as fh:
        for line in fh:
            if not line.startswith("Segment "):
                continue
            score = float(line.strip().split("\t")[1])
            raw_scores.append(score)
            if len(raw_scores) == n_combinations:
                average_scores.append(sum(raw_scores) / n_combinations)
                raw_scores = []
    # The handle is closed at this point, so removal also works on Windows.
    os.remove(meteor_output_path)
    return average_scores
| |
|
| |
|
def main():
    """Command-line entry point.

    Reads the repeated translations, writes the pairwise Meteor inputs, runs
    Meteor, and writes one averaged score per segment to ``--output``.
    """
    parser = argparse.ArgumentParser(
        description="Average pairwise Meteor scores between repeated "
        "translations of each segment."
    )
    parser.add_argument(
        "-i", "--infile", required=True,
        help="file with n_repeats consecutive translations per segment",
    )
    parser.add_argument(
        "-n", "--repeat_times", type=int, required=True,
        help="number of translations per segment (must be >= 2)",
    )
    parser.add_argument(
        "-m", "--meteor", required=True, help="path to the Meteor jar"
    )
    parser.add_argument(
        "-o", "--output", required=True,
        help="file to write one averaged score per line",
    )
    parser.add_argument(
        "-l", "--lang", default="en",
        help="language passed to Meteor's -l option (default: en)",
    )
    args = parser.parse_args()
    # Fail fast, before spending time running Java on an impossible config.
    if args.repeat_times < 2:
        parser.error("--repeat_times must be at least 2")

    translations = read_translations(args.infile, args.repeat_times)
    sys.stderr.write("\nGenerating input for Meteor...")
    ref_path, mt_path = generate_input(translations, args.repeat_times)
    sys.stderr.write("\nRunning Meteor...")
    out_path = run_meteor(ref_path, mt_path, args.meteor, lang=args.lang)
    sys.stderr.write("\nReading output...")
    scores = read_output(out_path, args.repeat_times)
    sys.stderr.write("\nWriting results...")
    with open(args.output, "w") as o:
        for scr in scores:
            o.write("{}\n".format(scr))
| |
|
| |
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
| |
|