| |
| |
| import json |
| import numpy as np |
| import sys |
|
|
| sys.path.append("metric") |
| from metric.smooth_bleu import codenn_smooth_bleu |
| from metric.meteor.meteor import Meteor |
| from metric.rouge.rouge import Rouge |
| from metric.cider.cider import Cider |
| import warnings |
| import argparse |
| import logging |
| import prettytable as pt |
|
|
# Script-wide setup: suppress every Python warning and send INFO-and-above
# log records through a timestamped root handler.
warnings.filterwarnings('ignore')
logging.basicConfig(
    level=logging.INFO,
    datefmt='%m/%d/%Y %H:%M:%S',
    format='[%(asctime)s - %(levelname)s - %(name)s] %(message)s',
)
|
|
|
|
def Commitbleus(refs, preds):
    """Score predictions against references with ``codenn_smooth_bleu``.

    Pairs whose first reference or whose prediction is empty are skipped.
    The remaining token sequences are joined with single spaces before being
    handed to the scorer.

    Args:
        refs: Sequence whose items are lists; only the first element of each
            item (``r[0]``, a token sequence) is used as the reference.
        preds: Sequence of predicted token sequences, paired with ``refs``
            positionally (extra items on either side are ignored by ``zip``).

    Returns:
        The first element of the scorer's result rounded to 4 decimal places,
        or ``0`` when the scorer raises.
    """
    r_str_list = []
    p_str_list = []
    for r, p in zip(refs, preds):
        # An empty reference or prediction contributes nothing to the score.
        if len(r[0]) == 0 or len(p) == 0:
            continue
        r_str_list.append([" ".join(str(token) for token in r[0])])
        p_str_list.append(" ".join(str(token) for token in p))

    try:
        bleu_list = codenn_smooth_bleu(r_str_list, p_str_list)
    except Exception:
        # Keep the best-effort fallback to 0, but leave a trace of the failure
        # and let KeyboardInterrupt / SystemExit propagate.
        logging.getLogger(__name__).exception(
            "codenn_smooth_bleu failed; reporting BLEU as 0")
        bleu_list = [0, 0, 0, 0]

    return round(bleu_list[0], 4)
|
|
|
|
def read_to_list(filename):
    """Read a UTF-8 text file as lower-cased, whitespace-split lines.

    Args:
        filename: Path of the file to read.

    Returns:
        A list with one entry per line of the file, in file order; each entry
        is the list of whitespace-separated tokens of the lower-cased line.
        Blank lines produce empty lists.
    """
    # The context manager closes the handle even if reading fails midway.
    with open(filename, 'r', encoding="utf-8") as f:
        return [row.lower().split() for row in f]
|
|
def metetor_rouge_cider(refs, preds):
    """Print the Meteor, Rouge-L and Cider scores of ``preds`` against ``refs``.

    Both inputs are turned into dictionaries keyed by sample index, each value
    being a one-element list holding the space-joined tokens. Every scorer's
    ``compute_score`` is called with the references first and the predictions
    second; only the first value it returns is printed.

    Args:
        refs: Indexable collection whose ``refs[i][0]`` is the reference token
            sequence for sample ``i``.
        preds: Sequence of predicted token sequences; its length decides how
            many samples are scored.

    Returns:
        None. The three scores are written to standard output.
    """
    hyp_by_id = {}
    gold_by_id = {}
    for idx, hyp_tokens in enumerate(preds):
        hyp_by_id[idx] = [" ".join(hyp_tokens)]
        gold_by_id[idx] = [" ".join(refs[idx][0])]

    # Meteor and Rouge-L are printed scaled by 100; Cider is printed unscaled.
    meteor_value, _ = Meteor().compute_score(gold_by_id, hyp_by_id)
    meteor_pct = round(meteor_value * 100, 2)
    print("Meteor: ", meteor_pct)

    rouge_value, _ = Rouge().compute_score(gold_by_id, hyp_by_id)
    rouge_pct = round(rouge_value * 100, 2)
    print("Rouge-L: ", rouge_pct)

    cider_value, _ = Cider().compute_score(gold_by_id, hyp_by_id)
    print("Cider: ", round(cider_value, 2))
|
|
|
|
|
|
def main():
    """Load the reference and prediction files and print all four metrics.

    Command-line options:
        --refs_filename: path of the reference file.
        --preds_filename: path of the prediction file.

    Both options are optional and default to files under ``../saved_model``.
    BLEU is printed first, followed by Meteor, Rouge-L and Cider.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--refs_filename',
        type=str,
        required=False,
        default="../saved_model/tlcodesum/UNLC/ref.txt",
    )
    cli.add_argument(
        '--preds_filename',
        type=str,
        required=False,
        default="../saved_model/tlcodesum/UNLC/dlen500-clen30-dvoc30000-cvoc30000-bs-ddim64-cdim-rhs64-lr0_Medit_pred.txt",
    )
    opts = cli.parse_args()

    # Each reference line is wrapped in a one-element list because the
    # scoring functions read the reference tokens from index 0.
    refs = [[tokens] for tokens in read_to_list(opts.refs_filename)]
    preds = read_to_list(opts.preds_filename)

    print("BLEU: %.2f" % Commitbleus(refs, preds))
    metetor_rouge_cider(refs, preds)


if __name__ == '__main__':
    main()
|
|