| | |
| | |
| | |
| | |
| |
|
| |
|
| | import importlib |
| | import os |
| | from abc import ABC, abstractmethod |
| |
|
| | from fairseq import registry |
| | from omegaconf import DictConfig |
| |
|
| |
|
class BaseScorer(ABC):
    """Abstract base for scorers that accumulate (reference, prediction) pairs.

    Pairs are recorded with :meth:`add_string` into the parallel lists
    ``self.ref`` and ``self.pred``. Concrete subclasses must implement
    :meth:`score` and :meth:`result_string`.
    """

    def __init__(self, cfg):
        # The base class itself only stores the config; subclasses may read it.
        self.cfg = cfg
        self.ref, self.pred = [], []

    def add_string(self, ref, pred):
        """Append ``ref`` to ``self.ref`` and ``pred`` to ``self.pred``.

        Both lists grow together, so index ``i`` of each refers to the same pair.
        """
        for bucket, item in ((self.ref, ref), (self.pred, pred)):
            bucket.append(item)

    @abstractmethod
    def score(self) -> float:
        """Return the score as a float; must be implemented by subclasses."""

    @abstractmethod
    def result_string(self) -> str:
        """Return a textual description of the result; must be implemented by subclasses."""
| |
|
| |
|
# ``registry.setup_registry`` returns a 4-tuple. The first three items are
# bound to module-level names used elsewhere in this file (``_build_scorer``
# in build_scorer) and, presumably, by scorer modules (``register_scorer``).
# The fourth item is intentionally discarded.
# NOTE(review): "--scoring" / default="bleu" are presumably the CLI option
# name and its default choice -- confirm against fairseq.registry.
(
    _build_scorer,
    register_scorer,
    SCORER_REGISTRY,
    _,
) = registry.setup_registry("--scoring", default="bleu")
| |
|
| |
|
def build_scorer(choice, tgt_dict):
    """Construct a scorer for ``choice``.

    Args:
        choice: either a scorer name (e.g. ``"bleu"``) or a ``DictConfig``
            whose ``_name`` attribute holds the scorer name.
        tgt_dict: an object exposing ``pad()``, ``eos()`` and ``unk()``; only
            consulted when the BLEU scorer is selected.

    Returns:
        For ``"bleu"``, a ``bleu.Scorer`` configured with the pad/eos/unk
        indices from ``tgt_dict``. Otherwise, whatever ``_build_scorer``
        returns for ``choice``.
    """
    name = choice._name if isinstance(choice, DictConfig) else choice

    if name != "bleu":
        # Hand the registry builder the original ``choice`` (possibly a
        # DictConfig), not just the extracted name.
        return _build_scorer(choice)

    # Imported lazily, only on the BLEU path.
    from fairseq.scoring import bleu

    bleu_cfg = bleu.BleuConfig(
        pad=tgt_dict.pad(), eos=tgt_dict.eos(), unk=tgt_dict.unk()
    )
    return bleu.Scorer(bleu_cfg)
| |
|
| |
|
| | |
# Import every public ``.py`` module that sits next to this file, presumably so
# each scorer module can register itself via ``register_scorer``.
# Skipped names:
#   - names starting with "_" (e.g. ``__init__.py``);
#   - hidden files starting with "." (e.g. editor lock/backup files such as
#     ``.#bleu.py``). These would otherwise produce an invalid module path.
for file in sorted(os.listdir(os.path.dirname(__file__))):
    if file.endswith(".py") and not file.startswith(("_", ".")):
        # Strip exactly the trailing ".py" suffix. ``file.find(".py")`` would
        # cut at the first ".py" occurrence anywhere in the filename.
        module = file[: -len(".py")]
        importlib.import_module("fairseq.scoring." + module)
| |
|