lmt-arr / src /mt_scoring.py
sleepyhead111's picture
Upload folder using huggingface_hub
9f73d88 verified
raw
history blame
7.87 kB
# coding=utf8
import os
import pandas as pd
from tqdm import tqdm
import subprocess
import json
import shutil
from collections import defaultdict
import argparse
import datetime
from openpyxl import load_workbook,Workbook
from openpyxl.utils import get_column_letter
from sacrebleu.metrics import BLEU, CHRF, TER
from comet import load_from_checkpoint
def bleu_scoring(ref_file, hypo_file, lp):
    """Compute a corpus-level BLEU score by invoking the ``sacrebleu`` CLI.

    Args:
        ref_file: Path to the reference translations.
        hypo_file: Path to the system output to score.
        lp: Language pair written as ``"<src>2<tgt>"`` (e.g. ``"zh2en"``). It is
            rewritten to ``"<src>-<tgt>"`` and passed to ``sacrebleu -l``.

    Returns:
        The number printed on stdout by ``sacrebleu -w 2 -b`` as a float.

    Raises:
        ValueError: If ``lp`` does not split into exactly two parts on ``"2"``,
            or if the command's stdout is not a number.
        RuntimeError: If ``sacrebleu`` exits with a non-zero status; the
            message carries the captured stderr.
        FileNotFoundError: If the ``sacrebleu`` executable is not on ``PATH``.
    """
    src, tgt = lp.split("2")
    langpair = f"{src}-{tgt}"
    # Pass the arguments as a list (no shell) so that paths containing spaces
    # or shell metacharacters reach sacrebleu unchanged and cannot inject
    # extra commands.
    command = ["sacrebleu", "-w", "2", "-b", ref_file, "-i", hypo_file, "-l", langpair]
    print(" ".join(command))
    completed = subprocess.run(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
    )
    print(completed.stdout)
    if completed.returncode != 0:
        # Surface the tool's own diagnostics instead of failing later on
        # float('') with an unrelated-looking error.
        raise RuntimeError(
            f"sacrebleu exited with status {completed.returncode}: "
            f"{completed.stderr.strip()}"
        )
    return float(completed.stdout.strip())
def comet22_scoring(src_file, ref_file, hypo_file, model, batch_size=128, gpus=1):
    """Score a hypothesis file with a reference-based COMET-style model.

    The three files are read line by line (UTF-8, surrounding whitespace
    stripped) and must contain the same number of lines; line ``i`` of each
    file forms one ``{"src", "mt", "ref"}`` sample.

    Args:
        src_file: Path to the source sentences.
        ref_file: Path to the reference translations.
        hypo_file: Path to the system translations.
        model: Object exposing ``predict(data, batch_size=..., gpus=...)``.
            The element at index 1 of its return value is used as the score
            (presumably the COMET system-level score -- confirm against the
            installed COMET version).
        batch_size: Batch size forwarded to ``model.predict``.
        gpus: GPU count forwarded to ``model.predict``.

    Returns:
        ``model.predict(...)[1]`` multiplied by 100 and rounded to 2 decimals.

    Raises:
        ValueError: If the three files do not have the same number of lines.
    """
    # Context managers close the files promptly instead of leaking handles.
    with open(src_file, encoding='utf-8') as handle:
        srcs = [line.strip() for line in handle]
    with open(ref_file, encoding='utf-8') as handle:
        refs = [line.strip() for line in handle]
    with open(hypo_file, encoding='utf-8') as handle:
        hypos = [line.strip() for line in handle]

    # Explicit check (not ``assert``) so it survives ``python -O`` and carries
    # a message that names the offending files.
    if not (len(srcs) == len(refs) == len(hypos)):
        raise ValueError(
            f"line count mismatch: {src_file} has {len(srcs)}, "
            f"{ref_file} has {len(refs)}, {hypo_file} has {len(hypos)}"
        )

    data = [{"src": s, "mt": h, "ref": r} for s, h, r in zip(srcs, hypos, refs)]
    print(f"comet22\nsrc_file: {src_file}\nref_file: {ref_file}\nhypo_file: {hypo_file}")
    model_output = model.predict(data, batch_size=batch_size, gpus=gpus)
    return round(model_output[1] * 100, 2)
def xcomet_scoring(src_file, hypo_file, model, batch_size=16, gpus=1):
    """Score a hypothesis file with a reference-free (source + MT) model.

    Both files are read as UTF-8; every line is stripped and lines that are
    empty after stripping are dropped. The remaining lines are paired by
    position into ``{"src", "mt"}`` samples.

    Args:
        src_file: Path to the source sentences.
        hypo_file: Path to the system translations.
        model: Object exposing ``predict(data, batch_size=..., gpus=...)``.
            The element at index 1 of its return value is used as the score
            (presumably the system-level score -- confirm against the
            installed COMET version).
        batch_size: Batch size forwarded to ``model.predict``.
        gpus: GPU count forwarded to ``model.predict``.

    Returns:
        ``model.predict(...)[1]`` multiplied by 100 and rounded to 2 decimals.

    Raises:
        ValueError: If the two files have a different number of non-blank
            lines.
    """
    # NOTE(review): blank lines are filtered from each file independently, so
    # an empty hypothesis line either trips the length check below or, if the
    # counts happen to match, shifts the pairing. Kept as-is to preserve the
    # existing scores; confirm whether joint filtering is wanted.
    with open(src_file, encoding='utf-8') as handle:
        srcs = [line.strip() for line in handle if line.strip()]
    with open(hypo_file, encoding='utf-8') as handle:
        hypos = [line.strip() for line in handle if line.strip()]

    # Explicit check (not ``assert``) so it survives ``python -O``.
    if len(srcs) != len(hypos):
        raise ValueError(
            f"non-blank line count mismatch: {src_file} has {len(srcs)}, "
            f"{hypo_file} has {len(hypos)}"
        )

    data = [{"src": s, "mt": h} for s, h in zip(srcs, hypos)]
    print(f"xcomet\nsrc_file: {src_file}\nhypo_file: {hypo_file}")
    model_output = model.predict(data, batch_size=batch_size, gpus=gpus)
    return round(model_output[1] * 100, 2)
def write_xlsl(file, data, flag=""):
    """Append one result table to the active sheet of an Excel workbook.

    The workbook at ``file`` is loaded if it exists, otherwise a new one is
    created. Starting at the first row whose column-A cell is empty, the
    function writes:

    * one row holding the current timestamp and ``flag`` in column A,
    * one header row with the keys of ``data`` (one key per column),
    * the values of each key below its header, one value per row.

    Args:
        file: Path of the ``.xlsx`` file to create or extend.
        data: Mapping from column header to a sequence of cell values.
        flag: Free-form label stored next to the timestamp.

    Columns of different lengths are allowed; shorter columns simply leave
    their trailing cells untouched and a warning is printed. An empty ``data``
    mapping writes only the timestamp row.
    """
    wb = load_workbook(file) if os.path.exists(file) else Workbook()
    ws = wb.active

    # Find the position of the first blank row (judged by column A only).
    first_free = 1
    while ws.cell(row=first_free, column=1).value is not None:
        first_free += 1

    timestamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    ws.cell(row=first_free, column=1).value = f"{timestamp}\n{flag}"

    header_row = first_free + 1
    for col_index, header in enumerate(data, start=1):
        ws.cell(row=header_row, column=col_index).value = header

    # Ragged columns used to be detected through a bare ``except`` around the
    # indexing; report them explicitly instead of swallowing every error.
    lengths = {len(values) for values in data.values()}
    if len(lengths) > 1:
        print(f"warning: columns have different lengths {sorted(lengths)} (flag={flag})")
        print(data)

    # Iterating each column over its own values cannot go out of range, so no
    # exception handling is needed; data row i lands at header_row + 1 + i.
    for col_index, values in enumerate(data.values(), start=1):
        for offset, value in enumerate(values, start=1):
            ws.cell(row=header_row + offset, column=col_index).value = value

    wb.save(file)
def sort_data(src_files, hypo_files, ref_files, lang_pairs, sort_order=None):
    """Reorder four parallel lists by a fixed language-pair ranking.

    Element ``i`` of every list describes the same test set. The quadruples
    are sorted by the rank of their language pair in ``sort_order``; pairs
    missing from the table are placed last. The sort is stable, so entries
    with equal rank keep their input order.

    Args:
        src_files: Source file paths.
        hypo_files: Hypothesis file paths.
        ref_files: Reference file paths.
        lang_pairs: Language pairs such as ``"zh2en"``.
        sort_order: Optional mapping from language pair to rank (lower comes
            first). Defaults to the built-in table below.

    Returns:
        A 4-tuple ``(src_files, hypo_files, ref_files, lang_pairs)`` of new
        lists in sorted order. Empty input yields four empty lists.
    """
    if sort_order is None:
        # Disabled alternative orderings, kept for reference:
        # sort_order = {'de2en': 1, 'cs2en': 2, 'ru2en': 3, 'zh2en': 4, 'en2de': 5,'en2cs': 6,'en2ru': 7,'en2zh': 8}
        # sort_order = {'zh2en': 1, 'zh2ja': 2, 'zh2ko': 3, 'zh2ru': 4, 'zh2de': 5,'zh2fr': 6,'zh2it': 7,'zh2pt': 8,'zh2es': 9,'zh2ar': 10,
        #     'en2zh': 11, 'ja2zh': 12, 'ko2zh': 13, 'ru2zh': 14, 'de2zh': 15,'fr2zh': 16,'it2zh': 17,'pt2zh': 18,'es2zh': 19,'ar2zh': 20,
        #     'en2ja': 21, 'en2ko': 22, 'en2ru': 23, 'en2de': 24,'en2fr': 25,'en2it': 26,'en2pt': 27,'en2es': 28, 'en2ar': 29,
        #     'ja2en': 30, 'ko2en': 31, 'ru2en': 32, 'de2en': 33,'fr2en': 34,'it2en': 35,'pt2en': 36,'es2en': 37, 'ar2en': 38,
        #     'zh2ug':39, 'zh2bo':40, 'zh2mn':41, 'ug2zh':42, 'bo2zh':43, 'mn2zh':44,
        #     'en2ug':45, 'en2bo':46, 'en2mn':47, 'ug2en':48, 'bo2en':49, 'mn2en':50,
        # }
        sort_order = {
            "zh2en": 1, "zh2ru": 2, "zh2de": 3, "zh2bn": 4, "zh2hi": 5,
            "zh2th": 6, "zh2jv": 7, "zh2sw": 8, "zh2si": 9, "zh2km": 10,
            "en2zh": 11, "ru2zh": 12, "de2zh": 13, "bn2zh": 14, "hi2zh": 15,
            "th2zh": 16, "jv2zh": 17, "sw2zh": 18, "si2zh": 19, "km2zh": 20,
        }

    combined = list(zip(src_files, hypo_files, ref_files, lang_pairs))
    if not combined:
        # zip(*[]) yields nothing to unpack, so handle empty input explicitly.
        return [], [], [], []

    # Unknown pairs rank after every known one regardless of the table's size.
    combined.sort(key=lambda item: sort_order.get(item[-1], float("inf")))
    sorted_src, sorted_hypo, sorted_ref, sorted_lp = zip(*combined)
    return list(sorted_src), list(sorted_hypo), list(sorted_ref), list(sorted_lp)
def main():
    """Command-line entry point: score translation outputs and log the results.

    ``--src_file``, ``--ref_file``, ``--hypo_file`` and ``--lang_pair`` each
    take a comma-separated list; the lists must have equal length and entry
    ``i`` of each describes one test set. Every metric named in ``--metric``
    is computed for every test set, and the collected scores are appended to
    the workbook given by ``--record_file``.

    All inputs are validated before any model checkpoint is loaded, so a typo
    in a path or metric name fails fast. Validation failures terminate the
    process with a non-zero exit status.
    """
    supported_metrics = ("bleu", "comet_22", "xcomet_xl", "xcomet_xxl")

    parser = argparse.ArgumentParser(description="Script with conditional parameters")
    parser.add_argument('--metric', type=str, help='The evaluate metric', default="bleu,comet_22,xcomet_xxl")
    parser.add_argument('--comet_22_path', default="/mnt/luoyingfeng/model_card/wmt22-comet-da/checkpoints/model.ckpt", type=str, help='The comet22 path model')
    parser.add_argument('--xcomet_xl_path', default="/mnt/luoyingfeng/model_card/XCOMET-XL/checkpoints/model.ckpt", type=str, help='The xcomet xl path model')
    parser.add_argument('--xcomet_xxl_path', default="/mnt/luoyingfeng/model_card/XCOMET-XXL/checkpoints/model.ckpt", type=str, help='The xcomet xxl path model')
    # The four list arguments are mandatory: without them the script used to
    # die with "'NoneType' object has no attribute 'split'".
    parser.add_argument('--lang_pair', type=str, required=True, help='plain text')
    parser.add_argument('--write_key', type=str, default="language", help='plain text')
    parser.add_argument('--src_file', type=str, required=True, help='plain text')
    parser.add_argument('--ref_file', type=str, required=True, help='plain text')
    parser.add_argument('--hypo_file', type=str, required=True, help='plain text')
    parser.add_argument('--record_file', default="result.xlsx", type=str, help='plain text')
    parser.add_argument('--gpu', type=str, default="0,1,2", help='plain text')
    args = parser.parse_args()

    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    src_files = args.src_file.split(",")
    hypo_files = args.hypo_file.split(",")
    ref_files = args.ref_file.split(",")
    lang_pairs = args.lang_pair.split(",")
    if not (len(src_files) == len(hypo_files) == len(lang_pairs) == len(ref_files)):
        parser.error(
            "--src_file, --hypo_file, --ref_file and --lang_pair must list "
            "the same number of comma-separated entries"
        )

    metrics = [name.strip() for name in args.metric.split(",") if name.strip()]
    unknown = [name for name in metrics if name not in supported_metrics]
    if unknown:
        # An unrecognised metric used to be skipped silently while still
        # occupying a row of the "metric" column, misaligning the table.
        parser.error(f"unsupported metric(s) {unknown}; choose from {list(supported_metrics)}")

    # Check every input file up front, before the (slow) checkpoint loading.
    # SystemExit with a message exits with status 1; the previous bare exit()
    # reported success to the calling shell.
    for src_file, ref_file, hypo_file in zip(src_files, ref_files, hypo_files):
        for path in (src_file, ref_file, hypo_file):
            if not os.path.isfile(path):
                raise SystemExit(f"file {path} not exist!")

    src_files, hypo_files, ref_files, lang_pairs = sort_data(src_files, hypo_files, ref_files, lang_pairs)

    # Map each requested metric to a scorer with a uniform signature; model
    # checkpoints are loaded only for the metrics actually requested.
    scorers = {}
    if "bleu" in metrics:
        scorers["bleu"] = lambda src, ref, hypo, lp: bleu_scoring(ref, hypo, lp)
    if "comet_22" in metrics:
        comet_22_model = load_from_checkpoint(args.comet_22_path, reload_hparams=True)
        scorers["comet_22"] = lambda src, ref, hypo, lp: comet22_scoring(src, ref, hypo, comet_22_model)
    if "xcomet_xl" in metrics:
        comet_xl_model = load_from_checkpoint(args.xcomet_xl_path, reload_hparams=True)
        scorers["xcomet_xl"] = lambda src, ref, hypo, lp: xcomet_scoring(src, hypo, comet_xl_model)
    if "xcomet_xxl" in metrics:
        comet_xxl_model = load_from_checkpoint(args.xcomet_xxl_path, reload_hparams=True)
        scorers["xcomet_xxl"] = lambda src, ref, hypo, lp: xcomet_scoring(src, hypo, comet_xxl_model)

    # result["metric"] is the first column; every other key collects one score
    # per metric, in the same order as the metric names.
    result = defaultdict(list)
    result["metric"] = metrics
    for metric in metrics:
        score_fn = scorers[metric]
        for lp, src_file, ref_file, hypo_file in zip(lang_pairs, src_files, ref_files, hypo_files):
            print(f"evaluate {lp}")
            # Column key: the language pair, or the hypothesis file name.
            if args.write_key == "language":
                wk = lp
            else:
                wk = os.path.basename(hypo_file)
            result[wk].append(score_fn(src_file, ref_file, hypo_file, lp))

    write_xlsl(args.record_file, result, flag=hypo_files[-1])
# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()