|
|
import os |
|
|
from bert_score import score |
|
|
import json |
|
|
import argparse |
|
|
import csv |
|
|
import torch |
|
|
import warnings |
|
|
warnings.filterwarnings("ignore") |
|
|
|
|
|
|
|
|
# Hugging Face access token. Replace the placeholder below, or export
# HF_TOKEN in the shell before running. ``setdefault`` is used so that a
# token already present in the environment is not overwritten by the
# placeholder string.
os.environ.setdefault("HF_TOKEN", "Your_HF_TOKEN")

# Placeholder path for a model cache directory. It is not referenced by
# main() or the CLI entry point in this script.
cache_dir = 'Your_cache_directory'
|
|
|
|
|
def _write_checkpoint(output_prefix, start, count, previous_count, rows):
    """Write ``rows`` to ``{output_prefix}{start}_{count}.csv``.

    The checkpoint written for ``previous_count`` (if that file exists) is
    deleted first, so only the most recent checkpoint is kept on disk.
    """
    stale_path = f"{output_prefix}{start}_{previous_count}.csv"
    if os.path.isfile(stale_path):
        os.remove(stale_path)

    with open(f'{output_prefix}{start}_{count}.csv', 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(["data_item", 'BERTScore'])
        writer.writerows(rows)


def main(args):
    """Score candidate summaries against reference summaries with BERTScore.

    Args:
        args: Namespace with the attributes
            ``data_can``    - path to a JSON list whose items carry a
                              ``"Watermarked_summary"`` field (candidates),
            ``data_ref``    - path to a JSON list whose items carry a
                              ``"summary"`` field (references),
            ``N``           - upper slice bound on the items read from each file,
            ``Output_name`` - prefix of the CSV checkpoint file names.

    Candidates with fewer than 16 whitespace-separated tokens are skipped.
    For every other pair the BERTScore F1 is recorded as ``[index, f1]``.
    A CSV checkpoint named ``{Output_name}{start}_{count}.csv`` is rewritten
    every ``saving_freq`` processed items, and once more after the loop when
    the total is not a multiple of ``saving_freq`` so that the trailing
    results are not lost.
    """
    start = 0
    saving_freq = 10
    min_tokens = 16

    torch.cuda.empty_cache()

    # JSON text is UTF-8 by specification; do not depend on the locale default.
    with open(args.data_can, 'r', encoding='utf-8') as f:
        cands = [item["Watermarked_summary"] for item in json.load(f)[start:args.N]]

    with open(args.data_ref, 'r', encoding='utf-8') as f:
        refs = [item["summary"] for item in json.load(f)[start:args.N]]

    if len(cands) != len(refs):
        # zip() below stops at the shorter list instead of raising IndexError
        # part-way through the run.
        print(
            f"Warning: {len(cands)} candidates but {len(refs)} references; "
            f"only the first {min(len(cands), len(refs))} pairs are scored."
        )

    input_counter = 0
    results = []

    for i, (cand, ref) in enumerate(zip(cands, refs)):
        print(f"Item number: {i}")

        if len(cand.split()) >= min_tokens:
            # Only the F1 component is kept; precision and recall are discarded.
            _, _, F1 = score([cand], [ref], lang="en", verbose=True)
            results.append([i, F1.mean().item()])
        else:
            print(f"Skipping item number {i} due to insufficient tokens.")

        input_counter += 1

        if input_counter % saving_freq == 0:
            _write_checkpoint(
                args.Output_name,
                start,
                input_counter,
                input_counter - saving_freq,
                results,
            )

    # Flush the tail: items processed since the last periodic checkpoint.
    leftover = input_counter % saving_freq
    if leftover:
        _write_checkpoint(
            args.Output_name,
            start,
            input_counter,
            input_counter - leftover,
            results,
        )
|
|
|
|
|
if __name__ == '__main__':
    # Command-line entry point: declare the options, parse them, and hand
    # the resulting namespace to main().
    arg_parser = argparse.ArgumentParser(description='Calculate BERTScore')

    # Input files: candidate (watermarked) and reference summaries.
    arg_parser.add_argument(
        '--data_can',
        type=str,
        default='DeepSeek_TW_Summarization_test__1000.json',
        help='a file containing the candidate document to test',
    )
    arg_parser.add_argument(
        '--data_ref',
        type=str,
        default='DeepSeek_No_WM_Summarization_test_0_1000_1000.json',
        help='a file containing the reference document to test',
    )

    # Upper bound on the number of items read from each file.
    arg_parser.add_argument(
        '--N',
        type=int,
        default=1000,
        help='Number of data items to process',
    )

    # Prefix used for the CSV file names written by main().
    arg_parser.add_argument(
        '--Output_name',
        type=str,
        default="BERTScore_DeepSeek_Summarization_TW_ref_No_WM_",
        help='Name of the output file',
    )

    cli_args = arg_parser.parse_args()
    main(cli_args)