Spaces:
Sleeping
Sleeping
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from dataclasses import dataclass, field | |
| import numpy as np | |
| from fairseq.dataclass import FairseqDataclass | |
| from fairseq.scoring import BaseScorer, register_scorer | |
@dataclass
class BertScoreScorerConfig(FairseqDataclass):
    """Configuration options for the BERTScore scorer.

    The ``@dataclass`` decorator is required: without it the ``field(...)``
    default below would remain a raw ``dataclasses.Field`` object on the
    class instead of becoming a real dataclass field whose default is "en".
    """

    # Language code passed as ``lang=`` to ``bert_score.score`` by the scorer.
    bert_score_lang: str = field(default="en", metadata={"help": "BERTScore language"})
@register_scorer("bert_score", dataclass=BertScoreScorerConfig)
class BertScoreScorer(BaseScorer):
    """Scorer that reports the mean BERTScore over accumulated sentence pairs.

    Reference/prediction strings are collected with :meth:`add_string` and
    scored in one batch by the optional ``bert_score`` package, which is
    imported lazily so this module can be imported without it installed.
    """

    def __init__(self, cfg):
        """Set up the scorer and import ``bert_score``.

        Args:
            cfg: scorer configuration; ``cfg.bert_score_lang`` is read when
                scoring.

        Raises:
            ImportError: if the ``bert_score`` package is not installed.
        """
        super().__init__(cfg)
        try:
            import bert_score as _bert_score
        except ImportError as err:
            # Chain the original failure so the underlying cause stays visible.
            raise ImportError(
                "Please install BERTScore: pip install bert-score"
            ) from err

        self.cfg = cfg
        self._bert_score = _bert_score
        # Per-pair scores from the most recent score() call; None until then.
        self.scores = None

    def add_string(self, ref, pred):
        """Queue one reference/prediction pair for scoring.

        Args:
            ref: reference string.
            pred: predicted (hypothesis) string.
        """
        # NOTE(review): self.ref / self.pred are not created in this class;
        # presumably BaseScorer.__init__ initialises them as lists - confirm.
        self.ref.append(ref)
        self.pred.append(pred)

    def score(self, order=4):
        """Score every queued pair and return the mean.

        Args:
            order: unused; kept for signature compatibility with callers.

        Returns:
            The mean of the per-pair scores, which are also stored on
            ``self.scores`` as a NumPy array.
        """
        # Only the third element of the returned triple is kept (F1 according
        # to the bert-score documentation's (P, R, F1) return convention).
        _, _, f1 = self._bert_score.score(
            self.pred, self.ref, lang=self.cfg.bert_score_lang
        )
        self.scores = f1.numpy()
        return np.mean(self.scores)

    def result_string(self, order=4):
        """Return the mean score formatted for display.

        Calls :meth:`score`, so every queued pair is re-scored on each call.

        Args:
            order: unused; kept for signature compatibility with callers.
        """
        return f"BERTScore: {self.score():.4f}"