File size: 2,775 Bytes
c961996
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import torch
from data import LyricsCommentsDatasetPsuedo_fusion
from torch import utils, nn
from model import CommentGenerator
from model_fusion import CommentGenerator_fusion
import transformers
import datasets
from tqdm import tqdm
import statistics
import os
# Path to the pickled test split consumed by LyricsCommentsDatasetPsuedo_fusion.
DATASET_PATH = "dataset_test.pkl"
# Checkpoint to evaluate; the substring 'baseline' in this path selects the
# non-fusion CommentGenerator below, otherwise CommentGenerator_fusion is used.
MODEL_PATH = "model/bart_fusion_full.pt"
# MODEL_NAME = "bart"

# Pin evaluation to GPU 4. Setting this after `import torch` is safe only
# because CUDA is initialized lazily (no .cuda() call has happened yet).
os.environ["CUDA_VISIBLE_DEVICES"] = "4"

# Build the evaluation dataset and a deterministic (unshuffled) loader so that
# the generation pass and the later metric passes iterate batches in the same
# order — the metric loops index `samples_list` by batch position.
test_dataset = LyricsCommentsDatasetPsuedo_fusion(DATASET_PATH)

test_dataloader = utils.data.DataLoader(test_dataset,
                                        batch_size=32,
                                        shuffle=False)

# Select the architecture from the checkpoint path: 'baseline' checkpoints use
# the plain generator, everything else the fusion variant. The weights are then
# restored and the model switched to inference mode.
model_cls = CommentGenerator if 'baseline' in MODEL_PATH else CommentGenerator_fusion
model = model_cls().cuda()
model.load_state_dict(torch.load(MODEL_PATH))

model.eval()

# Generate one list of output strings per batch; samples_list[i] holds the
# predictions for batch i of test_dataloader (shuffle=False keeps this stable).
is_baseline = 'baseline' in MODEL_PATH
samples_list = []
for lyrics, comment, music_id in tqdm(test_dataloader):
    with torch.no_grad():
        # The fusion model additionally conditions on the music id.
        if is_baseline:
            batch_outputs = model.generate(lyrics)
        else:
            batch_outputs = model.generate(lyrics, music_id)
    samples_list.append(batch_outputs)

def _score_metric(metric_name, wrap_references=False, **compute_kwargs):
    """Accumulate cached predictions against the test set and compute one metric.

    Replays test_dataloader in order, pairing batch i's references with the
    generations cached in samples_list[i] (valid because shuffle=False).

    Args:
        metric_name: name passed to datasets.load_metric (e.g. 'rouge').
        wrap_references: if True, wrap each reference string in a singleton
            list — sacrebleu/bertscore/meteor expect a list of references
            per prediction, while rouge takes plain strings.
        **compute_kwargs: forwarded to metric.compute() (e.g. lang='en').

    Returns:
        The dict returned by metric.compute().
    """
    # NOTE(review): datasets.load_metric is deprecated/removed in recent
    # `datasets` releases in favor of the `evaluate` package — confirm the
    # pinned version before upgrading.
    metric = datasets.load_metric(metric_name)
    for batch_index, (lyrics, comment, music_id) in enumerate(tqdm(test_dataloader)):
        predictions = samples_list[batch_index]
        references = [[c] for c in comment] if wrap_references else comment
        metric.add_batch(predictions=predictions, references=references)
    return metric.compute(**compute_kwargs)


# ------ ROUGE ------ #
score = _score_metric('rouge')
print(score)

# ------ BLEU ------ #
score = _score_metric('sacrebleu', wrap_references=True)
print(score)

# ------ BERTScore ------ #
score = _score_metric('bertscore', wrap_references=True, lang='en')
# Report the corpus-level mean of the per-sentence F1 scores.
score = statistics.mean(score['f1'])
print(score)

# ------ METEOR ------ #
score = _score_metric('meteor', wrap_references=True)
print(score)