File size: 2,775 Bytes
c961996 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
import torch
from data import LyricsCommentsDatasetPsuedo_fusion
from torch import utils, nn
from model import CommentGenerator
from model_fusion import CommentGenerator_fusion
import transformers
import datasets
from tqdm import tqdm
import statistics
import os
# ------ Configuration ------ #
DATASET_PATH = "dataset_test.pkl"
# Checkpoints whose path contains 'baseline' are loaded into the lyrics-only
# CommentGenerator; any other path is loaded into CommentGenerator_fusion.
MODEL_PATH = "model/bart_fusion_full.pt"

# Restrict this process to a single physical GPU. This only takes effect if
# CUDA has not been initialised yet (torch initialises CUDA lazily).
# NOTE(review): confirm that importing `data`/`model`/`model_fusion` above
# does not touch the GPU; otherwise move this assignment above the imports.
os.environ["CUDA_VISIBLE_DEVICES"] = "4"

# ------ Data ------ #
# The test split is iterated in a fixed order: later sections pair generated
# samples with dataloader batches by position, so shuffling must stay off.
test_dataset = LyricsCommentsDatasetPsuedo_fusion(DATASET_PATH)
dataset_length = len(test_dataset)
test_dataloader = utils.data.DataLoader(test_dataset,
                                        batch_size=32,
                                        shuffle=False)

# ------ Model ------ #
if 'baseline' in MODEL_PATH:
    model = CommentGenerator().cuda()
else:
    model = CommentGenerator_fusion().cuda()
# Deserialize onto the CPU first. CUDA_VISIBLE_DEVICES remaps the visible GPU
# to cuda:0, so a checkpoint saved from a different device index would fail to
# load if torch.load tried to restore tensors to their original device.
# load_state_dict then copies the weights into the model's CUDA parameters.
model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
model.eval()
# ------ Generation ------ #
# Run the model once over the whole test set and cache the generated texts
# per batch. The metric sections below pair this list with the dataloader's
# batches by position, which relies on the dataloader not shuffling.
samples_list = []
use_baseline = 'baseline' in MODEL_PATH
with torch.no_grad():
    for lyrics, comment, music_id in tqdm(test_dataloader):
        # The baseline generator is given only the lyrics; the fusion model
        # additionally receives the music id.
        generate_args = (lyrics,) if use_baseline else (lyrics, music_id)
        samples_list.append(model.generate(*generate_args))
# ------ ROUGE ------ #
# Score every cached batch of generations against the reference comments of
# the same batch; references are passed as plain strings (one per example).
rouge_metric = datasets.load_metric('rouge')
for (_, batch_comments, _), batch_predictions in zip(tqdm(test_dataloader), samples_list):
    rouge_metric.add_batch(predictions=batch_predictions, references=batch_comments)
rouge_result = rouge_metric.compute()
print(rouge_result)
# ------ BLEU ------ #
# Each reference comment is wrapped in its own single-element list, i.e. one
# reference per prediction in the nested-list format passed to this metric.
bleu_metric = datasets.load_metric('sacrebleu')
for (_, batch_comments, _), batch_predictions in zip(tqdm(test_dataloader), samples_list):
    bleu_metric.add_batch(
        predictions=batch_predictions,
        references=[[reference] for reference in batch_comments],
    )
bleu_result = bleu_metric.compute()
print(bleu_result)
# ------ BERTScore ------ #
# Compute BERTScore with English settings and report only the mean F1 over
# all test examples.
bertscore_metric = datasets.load_metric('bertscore')
for (_, batch_comments, _), batch_predictions in zip(tqdm(test_dataloader), samples_list):
    bertscore_metric.add_batch(
        predictions=batch_predictions,
        references=[[reference] for reference in batch_comments],
    )
bertscore_result = bertscore_metric.compute(lang='en')
mean_f1 = statistics.mean(bertscore_result['f1'])
print(mean_f1)
# ------ METEOR ------ #
# References are passed one-per-prediction as single-element lists.
# NOTE(review): some older `datasets` releases of this metric accepted only a
# plain string per reference; confirm the installed version takes this format.
meteor_metric = datasets.load_metric('meteor')
for (_, batch_comments, _), batch_predictions in zip(tqdm(test_dataloader), samples_list):
    meteor_metric.add_batch(
        predictions=batch_predictions,
        references=[[reference] for reference in batch_comments],
    )
meteor_result = meteor_metric.compute()
print(meteor_result)