File size: 1,164 Bytes
85ba398 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
from dataclasses import dataclass
from fairseq.dataclass import FairseqDataclass
from fairseq.scoring import BaseScorer, register_scorer
@dataclass
class MeteorScorerConfig(FairseqDataclass):
    """Configuration for the METEOR scorer.

    Declares no fields of its own; it exists so the scorer can be registered
    with a dataclass.
    """
@register_scorer("meteor", dataclass=MeteorScorerConfig)
class MeteorScorer(BaseScorer):
    """METEOR scorer backed by NLTK.

    A sentence-level METEOR score is computed for every (reference,
    prediction) pair collected through :meth:`add_string`; the value returned
    by :meth:`score` is the arithmetic mean of those sentence scores.
    """

    def __init__(self, args):
        """Initialise the scorer and import NLTK lazily.

        Args:
            args: scorer configuration, forwarded unchanged to the base class.

        Raises:
            ImportError: if ``nltk`` is not installed.
        """
        super().__init__(args)
        try:
            import nltk
        except ImportError as err:
            # Chain the original error so the real import failure stays visible.
            raise ImportError("Please install nltk to use METEOR scorer") from err
        self.nltk = nltk
        # Per-sentence scores from the most recent call to score().
        self.scores = []

    def add_string(self, ref, pred):
        """Queue one reference/prediction pair for later scoring.

        Args:
            ref: the reference text.
            pred: the predicted (hypothesis) text.
        """
        # NOTE(review): self.ref / self.pred are not created in this class;
        # presumably BaseScorer.__init__ sets them up - confirm.
        self.ref.append(ref)
        self.pred.append(pred)

    def _sentence_score(self, ref, pred):
        """Return the METEOR score of one pair across NLTK calling conventions."""
        single_meteor_score = self.nltk.translate.meteor_score.single_meteor_score
        try:
            return single_meteor_score(ref, pred)
        except TypeError:
            # Newer NLTK releases reject raw strings and expect pre-tokenized
            # input. Retry with whitespace tokens, but only for plain strings
            # so unrelated TypeErrors still surface unchanged.
            if not (isinstance(ref, str) and isinstance(pred, str)):
                raise
            return single_meteor_score(ref.split(), pred.split())

    def score(self, order=4):
        """Compute the mean sentence-level METEOR over all queued pairs.

        Args:
            order: unused; kept so existing callers passing it keep working.

        Returns:
            The mean of the per-sentence scores, or ``0.0`` when no pairs have
            been added.
        """
        self.scores = [
            self._sentence_score(r, p) for r, p in zip(self.ref, self.pred)
        ]
        if not self.scores:
            # np.mean([]) would return nan and emit a RuntimeWarning.
            return 0.0
        return np.mean(self.scores)

    def result_string(self, order=4):
        """Return the score formatted as ``"METEOR: <score>"`` with 4 decimals.

        Args:
            order: unused; kept so existing callers passing it keep working.
        """
        return f"METEOR: {self.score():.4f}"
|