# STAR / fairseq / scoring / meteor.py
# Yixuan Li
# add fairseq folder
# 85ba398
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
from dataclasses import dataclass
from fairseq.dataclass import FairseqDataclass
from fairseq.scoring import BaseScorer, register_scorer
@dataclass
class MeteorScorerConfig(FairseqDataclass):
    """Configuration for the METEOR scorer; it declares no fields of its own."""
@register_scorer("meteor", dataclass=MeteorScorerConfig)
class MeteorScorer(BaseScorer):
    """METEOR scorer backed by NLTK.

    Every queued (reference, prediction) pair is scored with
    ``nltk.translate.meteor_score.single_meteor_score``; the reported score
    is the arithmetic mean of those per-pair scores.
    """

    def __init__(self, args):
        """Set up the scorer and bind the ``nltk`` module to the instance.

        Args:
            args: Scorer configuration, forwarded unchanged to the base class.

        Raises:
            ImportError: If ``nltk`` is not installed.
        """
        super().__init__(args)
        try:
            # Imported here so nltk is only required when this scorer is used.
            import nltk
        except ImportError as err:
            # Chain the original error so the real import failure stays visible.
            raise ImportError("Please install nltk to use METEOR scorer") from err
        self.nltk = nltk
        # Per-pair scores from the most recent call to score().
        self.scores = []

    def add_string(self, ref, pred):
        """Queue one reference/prediction pair for later scoring.

        NOTE(review): ``self.ref`` and ``self.pred`` are not created in this
        class; presumably the base class initializer provides them -- confirm.

        Args:
            ref: The reference, passed as-is to NLTK when scoring.
            pred: The prediction, passed as-is to NLTK when scoring.
        """
        self.ref.append(ref)
        self.pred.append(pred)

    def score(self, order=4):
        """Compute the mean METEOR score over all queued pairs.

        The per-pair scores are also stored in ``self.scores``.

        Args:
            order: Unused by this scorer; kept so callers that pass it
                continue to work.

        Returns:
            The mean of the per-pair scores, or ``0.0`` when no pairs have
            been added.
        """
        # Resolve the scoring function once instead of once per pair.
        single_meteor = self.nltk.translate.meteor_score.single_meteor_score
        # NOTE(review): inputs are forwarded exactly as given to add_string;
        # the input type NLTK expects (raw string vs. token list) may depend
        # on the installed NLTK version -- confirm.
        self.scores = [
            single_meteor(reference, prediction)
            for reference, prediction in zip(self.ref, self.pred)
        ]
        if not self.scores:
            # np.mean([]) evaluates to nan and emits a RuntimeWarning; report
            # a plain zero for an empty corpus instead.
            return 0.0
        return np.mean(self.scores)

    def result_string(self, order=4):
        """Return the score formatted as ``METEOR: <score>`` with 4 decimals.

        Args:
            order: Unused by this scorer; kept so callers that pass it
                continue to work.
        """
        return f"METEOR: {self.score():.4f}"