| | import argparse |
| | import pandas as pd |
| | from pathlib import Path |
| | import json |
| | import numpy as np |
| | from model.help_funcs import caption_evaluate |
| | from transformers import BertTokenizer, AutoTokenizer |
| | from nltk.translate.bleu_score import corpus_bleu |
| |
|
# Show long/wide metric tables in full when printed.
pd.set_option('display.max_rows', 1000)
pd.set_option('display.max_columns', 1000)
| |
|
| |
|
def print_std(accs, stds, categories, append_mean=False):
    """Print a header of category names and one line of accuracy values.

    Args:
        accs: sequence of accuracy values (percent).
        stds: matching sequence of standard deviations, or None to omit them.
        categories: category names printed as the header line.
        append_mean: if True, append 'Mean' to the header and the mean of
            ``accs`` to the value line.
    """
    category_line = ' '.join(categories)
    if append_mean:
        category_line += ' Mean'

    # Build cells with str.join instead of repeated string concatenation in
    # a loop (also drops the stray trailing space the old code printed).
    if stds is None:
        cells = ['{:0.1f}'.format(acc) for acc in accs]
    else:
        cells = ['{:0.1f}±{:0.1f}'.format(acc, std) for acc, std in zip(accs, stds)]
    if append_mean:
        cells.append('{:0.1f}'.format(sum(accs) / len(accs)))

    print(category_line)
    print(' '.join(cells))
| |
|
| |
|
def get_mode(df):
    """Infer the experiment type from the columns of a metrics dataframe.

    Checks marker columns in priority order and returns one of
    'caption', 'mix_caption', 'retrieval' or 'mix_retrieval'.
    Raises NotImplementedError when no marker column is present.
    """
    column_to_mode = (
        ('bleu2', 'caption'),
        ('dataset0/bleu2', 'mix_caption'),
        ('test_inbatch_p2t_acc', 'retrieval'),
        ('onto_test_rerank_inbatch_p2t_rec20', 'mix_retrieval'),
    )
    for marker, mode in column_to_mode:
        if marker in df.columns:
            return mode
    raise NotImplementedError
| |
|
| |
|
def read_retrieval(df, args):
    """Print retrieval metrics (in-batch and full-set, p2t and t2p).

    Selects the reranked columns unless ``args.disable_rerank`` is set.
    """
    df = df.round(2)
    prefix = '' if args.disable_rerank else 'rerank_'
    retrieval_cols = [
        prefix + col for col in (
            'test_inbatch_p2t_acc', 'test_inbatch_p2t_rec20',
            'test_inbatch_t2p_acc', 'test_inbatch_t2p_rec20',
            'test_fullset_p2t_acc', 'test_fullset_p2t_rec20',
            'test_fullset_t2p_acc', 'test_fullset_t2p_rec20',
        )
    ]
    # Filter on a column that is guaranteed to be among the selected ones;
    # the old code always filtered on 'test_inbatch_t2p_acc', even when
    # only the rerank_* columns were requested.
    retrieval_log = df[~df[retrieval_cols[0]].isnull()][retrieval_cols]
    print(retrieval_cols)
    print(retrieval_log.to_string(header=False))
| |
|
def read_mix_retrieval(df, args):
    """Print retrieval metrics for the 'swiss' and 'onto' test splits.

    Selects the reranked columns unless ``args.disable_rerank`` is set.
    """
    df = df.round(2)
    metrics = (
        'inbatch_p2t_acc', 'inbatch_p2t_rec20',
        'inbatch_t2p_acc', 'inbatch_t2p_rec20',
        'fullset_p2t_acc', 'fullset_p2t_rec20',
        'fullset_t2p_acc', 'fullset_t2p_rec20',
    )
    infix = 'test_' if args.disable_rerank else 'test_rerank_'
    for i, dataset in enumerate(('swiss', 'onto')):
        retrieval_cols = ['{}_{}{}'.format(dataset, infix, m) for m in metrics]
        # Filter on the first selected column instead of a hard-coded
        # rerank column, which raised KeyError when --disable_rerank was
        # used (the rerank_* columns do not exist then).
        retrieval_log = df[~df[retrieval_cols[0]].isnull()][retrieval_cols]
        print(retrieval_cols)
        print(retrieval_log.to_string(header=False))
        if i == 0:
            print('--------------------')
| | |
| | |
| | |
| |
|
| |
|
| |
|
def read_caption(df, args):
    """Print per-epoch captioning metrics from a metrics dataframe.

    Rows without a 'bleu2' value are dropped; the 'acc' column is included
    only when present.
    """
    df = df.round(2)
    df = df[df['bleu2'].notnull()]
    cols = ['epoch']
    if 'acc' in df.columns:
        cols.append('acc')
    cols.extend(['bleu2', 'bleu4', 'rouge_1', 'rouge_2', 'rouge_l', 'meteor_score'])
    print(cols)
    print(df[cols])
| |
|
| |
|
def read_mix_caption(df, args):
    """Print per-epoch captioning metrics for each dataset of a mixed run.

    Always prints the 'dataset0/*' metrics; prints the 'dataset1/*' block
    too when those columns exist.
    """
    metric_names = ['acc', 'bleu2', 'bleu4', 'rouge_1', 'rouge_2', 'rouge_l', 'meteor_score']
    df = df.round(2)
    df = df[df['dataset0/bleu2'].notnull()]

    cols = ['epoch'] + ['dataset0/' + m for m in metric_names]
    print('dataset 0')
    print([c.split('/')[-1] for c in cols])
    print(df[cols].to_string(header=False))

    if 'dataset1/acc' in df.columns:
        print('------------------------------')
        cols = ['epoch'] + ['dataset1/' + m for m in metric_names]
        print('dataset 1')
        print([c.split('/')[-1] for c in cols])
        print(df[cols].to_string(header=False))
| | |
| |
|
def exact_match(prediction_list, target_list):
    """Return exact-match accuracy (percent, rounded to 2 decimals).

    Predictions and targets are whitespace-stripped before comparison.
    Returns 0.0 for empty input instead of raising ZeroDivisionError.
    """
    if not prediction_list:
        return 0.0
    match = sum(
        1 for prediction, target in zip(prediction_list, target_list)
        if prediction.strip() == target.strip()
    )
    return round(match / len(prediction_list) * 100, 2)
| |
|
def read_caption_prediction(args):
    """Evaluate a predictions file and print captioning metrics.

    The file at ``args.path`` holds one JSON object per line with
    'prediction' and 'target' fields. Prints exact-match accuracy plus
    BLEU/ROUGE/METEOR scores from ``caption_evaluate``.
    """
    with open(args.path, 'r') as f:
        records = [json.loads(line) for line in f]

    tokenizer = AutoTokenizer.from_pretrained('facebook/galactica-1.3b', use_fast=False, padding_side='right')
    tokenizer.add_special_tokens({'pad_token': '<pad>'})
    tokenizer.add_special_tokens({"bos_token": "[DEC]"})

    prediction_list = [record['prediction'].strip() for record in records]
    target_list = [record['target'].strip() for record in records]

    scores = caption_evaluate(prediction_list, target_list, tokenizer, 128)
    bleu2, bleu4, rouge_1, rouge_2, rouge_l, meteor_score = (round(s, 2) for s in scores)

    acc = exact_match(prediction_list, target_list)
    print(['Exact match', 'bleu2', 'bleu4', 'rouge_1', 'rouge_2', 'rouge_l', 'meteor_score'])
    print(acc, bleu2, bleu4, rouge_1, rouge_2, rouge_l, meteor_score)
| |
|
| |
|
def read_mpp_results(args):
    """Aggregate molecular-property-prediction test ROC-AUC over seeds.

    For each known dataset directory under ``args.path``, reads every
    lightning_logs/version_*/metrics.csv, takes the test ROC at the row
    with the best validation ROC, and prints mean±std per dataset
    (plus the overall mean) via ``print_std``.
    """
    ds_list = ['bace', 'bbbp', 'clintox', 'toxcast', 'sider', 'tox21']
    results = []
    stds = []
    used_ds = []
    for ds in ds_list:
        # Path is already imported at module level; the old local
        # 'from pathlib import Path' was redundant and removed.
        ds_path = Path(args.path) / ds
        if not ds_path.exists():
            continue
        ds_path = ds_path / 'lightning_logs'
        test_roc_list = []
        for version_dir in ds_path.glob("version_*"):
            df = pd.read_csv(version_dir / 'metrics.csv')
            df = df[['val roc', 'test roc']]
            df = df[~df['val roc'].isnull()]
            array = df.to_numpy()
            # test ROC at the epoch with the best validation ROC
            test_roc_list.append(array[array[:, 0].argmax(), 1])
        if not test_roc_list:
            # No completed runs for this dataset: skip instead of taking
            # the mean of an empty array (NaN + RuntimeWarning).
            continue
        test_roc_list = np.asarray(test_roc_list)
        results.append(round(test_roc_list.mean() * 100, 2))
        stds.append(round(test_roc_list.std() * 100, 2))
        used_ds.append(ds)

    print_std(results, stds, used_ds, True)
| |
|
def read_regression_results(args):
    """Aggregate regression test RMSE over version_* runs under args.path.

    For each run, takes the test RMSE at the row with the lowest
    validation RMSE, then prints mean±std across runs.
    """
    rmse_per_run = []
    for version_dir in Path(args.path).glob('version_*'):
        metrics = pd.read_csv(version_dir / 'metrics.csv')
        metrics = metrics[['val rmse', 'test rmse']]
        metrics = metrics[metrics['val rmse'].notnull()]
        values = metrics.to_numpy()
        best_row = values[:, 0].argmin()
        rmse_per_run.append(values[best_row, 1])
    rmse_per_run = np.asarray(rmse_per_run)
    mean = round(rmse_per_run.mean(), 3)
    std = round(rmse_per_run.std(), 3)
    print(f'{mean}±{std}')
| |
|
| |
|
def read_qa_results(path, text_trunc_length):
    """Evaluate a QA predictions file and print accuracy / BLEU-2 stats.

    The file at ``path`` holds one JSON object per line with
    'prediction', 'target' and 'q_type' fields.

    Args:
        path: path to the predictions file.
        text_trunc_length: max token length used when tokenizing the
            free-text answers for BLEU-2.

    Returns:
        (overall_acc, q_type2acc, q_type2bleu2): overall exact-match
        accuracy, per-question-type accuracy, and BLEU-2 for the
        free-text ('String') question types.
    """
    tokenizer = AutoTokenizer.from_pretrained('facebook/galactica-1.3b', use_fast=False, padding_side='right')
    tokenizer.add_special_tokens({'pad_token': '<pad>'})
    tokenizer.add_special_tokens({"bos_token": "[DEC]"})

    with open(path, 'r') as f:
        lines = [json.loads(line.strip()) for line in f]
    for line in lines:
        line['target'] = line['target'].strip()
        line['prediction'] = line['prediction'].strip()

    correct = sum(1 for line in lines if line['target'] == line['prediction'])
    overall_acc = round(correct / len(lines) * 100, 2)

    special_tokens = {'<pad>', '[PAD]', '[CLS]', '[SEP]'}
    q_types = ['Number structure/property', 'Number side information', 'String structure/property', 'String side information', 'Number', 'String']
    q_type2acc = {q_type: 0 for q_type in q_types}
    q_type2bleu2 = {q_type: 0 for q_type in q_types if q_type.find('String') >= 0}
    for q_type in q_types:
        # Substring match, so the generic 'Number' / 'String' entries
        # aggregate all of their specific sub-types.
        prediction_list = [line['prediction'] for line in lines if line['q_type'].find(q_type) >= 0]
        target_list = [line['target'] for line in lines if line['q_type'].find(q_type) >= 0]
        if not prediction_list:
            continue
        correct = sum(1 for p, t in zip(prediction_list, target_list) if p == t)
        q_type2acc[q_type] = round(correct / len(prediction_list) * 100, 2)

        if q_type.find('String') >= 0:
            pred_ids = tokenizer(prediction_list, truncation=True, max_length=text_trunc_length, padding=False)['input_ids']
            pred_tokens = [tokenizer.convert_ids_to_tokens(i) for i in pred_ids]
            tgt_ids = tokenizer(target_list, truncation=True, max_length=text_trunc_length, padding=False)['input_ids']
            tgt_tokens = [tokenizer.convert_ids_to_tokens(i) for i in tgt_ids]

            # Drop special tokens from each token sequence. The original
            # code compared whole token *lists* against the special-token
            # strings, which never matched and therefore filtered nothing.
            pred_tokens = [[t for t in toks if t not in special_tokens] for toks in pred_tokens]
            tgt_tokens = [[t for t in toks if t not in special_tokens] for toks in tgt_tokens]

            hypothesis = pred_tokens
            references = [[t] for t in tgt_tokens]
            bleu2 = corpus_bleu(references, hypothesis, weights=(.5, .5))
            q_type2bleu2[q_type] = round(bleu2 * 100, 2)

    print('overall accuracy')
    print(overall_acc)
    print('accuracy')
    print(q_type2acc)
    print('bleu-2')
    print(q_type2bleu2)
    return overall_acc, q_type2acc, q_type2bleu2
| |
|
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--path', type=str)
    parser.add_argument('--tag', type=str, default='train_loss_gtm')
    parser.add_argument('--max_step', type=int, default=1000)
    parser.add_argument('--disable_rerank', action='store_true', default=False)
    parser.add_argument('--qa_question', action='store_true', default=False)
    args = parser.parse_args()
    args.path = Path(args.path)

    # Special-purpose readers dispatch on flags / path naming and exit early.
    if args.qa_question:
        read_qa_results(args.path, 128)
        exit()

    if args.path.name.find('predictions') >= 0:
        read_caption_prediction(args)
        exit()
    elif str(args.path).find('mpp') >= 0:
        read_mpp_results(args)
        exit()
    elif str(args.path).find('regression') >= 0:
        read_regression_results(args)
        exit()

    # Otherwise args.path is a lightning log dir: recover the config file
    # name from the first line of hparams.yaml, then read metrics.csv.
    with open(args.path / 'hparams.yaml', 'r') as f:
        file_name = f.readline().strip().split(' ')[1]

    log = pd.read_csv(args.path / 'metrics.csv')
    print(f'File name: {file_name}')

    # get_mode raises NotImplementedError for anything outside these four.
    mode = get_mode(log)
    dispatch = {
        'retrieval': read_retrieval,
        'caption': read_caption,
        'mix_caption': read_mix_caption,
        'mix_retrieval': read_mix_retrieval,
    }
    dispatch[mode](log, args)