File size: 3,129 Bytes
40b3335
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
from bert_score import score
import json
import argparse
import csv
import torch
import warnings
warnings.filterwarnings("ignore")

# Ensure the HF_HOME environment variable points to your desired cache location
# NOTE(review): placeholder credential — replace with a real Hugging Face token
# before running, and avoid committing real tokens to source control.
os.environ["HF_TOKEN"] = "Your_HF_TOKEN"
# Placeholder cache path; not referenced anywhere in this file — presumably
# intended to be passed to the model loader. TODO confirm it is still needed.
cache_dir = 'Your_cache_directory'

def main(args):
    """Score watermarked candidate summaries against reference summaries.

    Computes BERTScore F1 for each (candidate, reference) pair and
    periodically checkpoints the accumulated results to a CSV file,
    deleting the previous checkpoint each time.

    Args:
        args: argparse.Namespace with attributes:
            data_can (str): JSON file of candidates; each item holds the
                key "Watermarked_summary".
            data_ref (str): JSON file of references; each item holds the
                key "summary".
            N (int): number of items (from index 0) to process.
            Output_name (str): prefix for the output CSV file name.
    """
    start = 0  # index of the first processed item; embedded in file names
    # Release cached GPU memory before the scoring model is loaded.
    torch.cuda.empty_cache()

    # Load candidate summaries.
    with open(args.data_can, 'r') as f:
        data_1 = json.load(f)[start:args.N]
        cands = [item["Watermarked_summary"] for item in data_1]

    # Load reference summaries (assumed aligned index-by-index with cands).
    with open(args.data_ref, 'r') as f:
        data_2 = json.load(f)[start:args.N]
        refs = [item["summary"] for item in data_2]

    saving_freq = 10   # checkpoint every `saving_freq` processed inputs
    input_counter = 0  # number of candidates processed so far
    results = []       # rows of [item_index, F1]
    last_saved = None  # counter value of the most recent checkpoint, if any

    def _save_checkpoint(counter):
        """Write `results` to a new checkpoint CSV, removing the old one.

        Tracking `last_saved` (instead of `counter - saving_freq`) ensures
        the correct previous file is removed even for a final partial save.
        """
        nonlocal last_saved
        if last_saved is not None:
            prev = f"{args.Output_name}{start}_{last_saved}.csv"
            if os.path.isfile(prev):
                os.remove(prev)
        with open(f'{args.Output_name}{start}_{counter}.csv', 'w', newline='') as out:
            writer = csv.writer(out)
            writer.writerow(["data_item", 'BERTScore'])
            writer.writerows(results)
        last_saved = counter

    # Score each candidate against its reference.
    for i, cand in enumerate(cands):
        print(f"Item number: {i}")

        # Only consider items with at least 16 whitespace-separated tokens;
        # very short texts do not give a valid assessment.
        if len(cand.split()) >= 16:
            P, R, F1 = score([cand], [refs[i]], lang="en", verbose=True)
            results.append([i, F1.mean().item()])
        else:
            print(f"Skipping item number {i} due to insufficient tokens.")

        input_counter += 1
        if input_counter % saving_freq == 0:
            _save_checkpoint(input_counter)

    # BUG FIX: the original only saved inside the loop, so when the input
    # count was not a multiple of saving_freq the trailing results were lost.
    if input_counter % saving_freq != 0:
        _save_checkpoint(input_counter)

if __name__ == '__main__':
    # Command-line entry point: parse arguments and run the BERTScore pass.
    arg_parser = argparse.ArgumentParser(description='Calculate BERTScore')
    arg_parser.add_argument(
        '--data_can',
        type=str,
        default='DeepSeek_TW_Summarization_test__1000.json',
        help='a file containing the candidate document to test',
    )
    arg_parser.add_argument(
        '--data_ref',
        type=str,
        default='DeepSeek_No_WM_Summarization_test_0_1000_1000.json',
        help='a file containing the reference document to test',
    )
    arg_parser.add_argument(
        '--N',
        type=int,
        default=1000,
        help='Number of data items to process',
    )
    arg_parser.add_argument(
        '--Output_name',
        type=str,
        default="BERTScore_DeepSeek_Summarization_TW_ref_No_WM_",
        help='Name of the output file',
    )
    main(arg_parser.parse_args())