|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
|
from dataclasses import dataclass |
|
|
|
|
|
from fairseq.dataclass import FairseqDataclass |
|
|
from fairseq.scoring import BaseScorer, register_scorer |
|
|
|
|
|
|
|
|
@dataclass
class MeteorScorerConfig(FairseqDataclass):
    """Configuration for the ``meteor`` scorer.

    The METEOR scorer declares no options of its own, so this dataclass
    adds no fields on top of ``FairseqDataclass``.
    """
|
|
|
|
|
|
|
|
@register_scorer("meteor", dataclass=MeteorScorerConfig)
class MeteorScorer(BaseScorer):
    """METEOR scorer backed by NLTK.

    Reference/prediction pairs are collected with :meth:`add_string`;
    :meth:`score` evaluates every collected pair with NLTK's
    ``single_meteor_score`` and returns the mean of the per-pair scores.
    """

    def __init__(self, args):
        """Initialize the scorer and bind the optional ``nltk`` dependency.

        Args:
            args: scorer configuration, forwarded unchanged to
                ``BaseScorer.__init__``.

        Raises:
            ImportError: if ``nltk`` cannot be imported.
        """
        super().__init__(args)
        try:
            import nltk
        except ImportError as err:
            # Chain the original error so the root cause stays visible.
            raise ImportError("Please install nltk to use METEOR scorer") from err

        # Keep a handle on the lazily imported module for use in score().
        self.nltk = nltk
        # Per-pair METEOR scores from the most recent score() call.
        self.scores = []

    def add_string(self, ref, pred):
        """Record one reference/prediction pair for later scoring.

        Args:
            ref: the reference for this example.
            pred: the system output for this example.
        """
        # NOTE(review): self.ref / self.pred are presumably created by
        # BaseScorer.__init__ (not visible in this file) -- confirm.
        self.ref.append(ref)
        self.pred.append(pred)

    def score(self, order=4):
        """Return the mean METEOR score over all recorded pairs.

        The per-pair scores are also stored in ``self.scores``.

        Args:
            order: unused; kept so the signature stays compatible with
                existing callers.

        Returns:
            The mean of the per-pair scores, or ``0.0`` when no pairs have
            been added (avoids NumPy's NaN / "mean of empty slice" warning).
        """
        # NOTE(review): recent nltk releases appear to expect pre-tokenized
        # inputs for single_meteor_score -- confirm what callers pass to
        # add_string().
        single_meteor = self.nltk.translate.meteor_score.single_meteor_score
        self.scores = [
            single_meteor(ref, pred) for ref, pred in zip(self.ref, self.pred)
        ]
        if not self.scores:
            return 0.0
        return np.mean(self.scores)

    def result_string(self, order=4):
        """Return the score formatted as ``"METEOR: <score>"`` (4 decimals).

        Args:
            order: unused; kept so the signature stays compatible with
                existing callers.
        """
        return f"METEOR: {self.score():.4f}"
|
|
|