File size: 1,791 Bytes
c4d24a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import logging
from typing import Dict, Iterable, Optional

import evaluate

class TranslationEvaluator:
    """Score machine-translation output with BLEU, BERTScore and COMET.

    The metric modules are loaded once, at construction time, through
    ``evaluate.load`` and are reused by every :meth:`evaluate` call.
    """

    # Keys of the dict returned by ``evaluate``, in the order they are filled.
    _METRIC_NAMES = ("BLEU", "BERTScore", "BERTurk", "COMET")

    def __init__(self, *, comet_model: Optional[str] = None):
        """Load the BLEU, BERTScore and COMET metric modules.

        Args:
            comet_model: Checkpoint name for the ``comet`` metric, forwarded
                to ``evaluate.load`` as ``config_name``. ``None`` (the
                default) lets the metric use its own default checkpoint.
        """
        self.bleu = evaluate.load("bleu")
        self.bertscore = evaluate.load("bertscore")
        # The comet metric selects its checkpoint from ``config_name``; a
        # ``model_id=`` keyword is not how it picks a model, so the metric's
        # default checkpoint was loaded regardless. Pass ``comet_model`` to
        # choose a specific (e.g. MQM or QE) checkpoint explicitly.
        self.comet = evaluate.load("comet", config_name=comet_model)
        logging.info("Loaded BLEU, BERTScore, COMET metrics")

    def evaluate(
        self,
        sources: Iterable[str],
        references: Iterable[str],
        predictions: Iterable[str],
    ) -> Dict[str, float]:
        """Compute corpus-level BLEU, BERTScore, BERTurk and COMET scores.

        Args:
            sources: Source-language sentences (only COMET uses them).
            references: Reference translations, one per source sentence.
            predictions: System translations, one per source sentence.

        Returns:
            ``{"BLEU": float, "BERTScore": float, "BERTurk": float,
            "COMET": float}``. BERTScore/BERTurk are mean F1 over all
            sentence pairs; COMET is the metric's corpus-level mean score.
            Every value is 0.0 when the inputs are empty.

        Raises:
            ValueError: If the three inputs do not have the same length.
        """
        # Materialize once: every input is consumed by several metric calls,
        # so a generator would otherwise be exhausted after the first one.
        sources = list(sources)
        references = list(references)
        predictions = list(predictions)

        if not len(sources) == len(references) == len(predictions):
            raise ValueError(
                "sources, references and predictions must have the same "
                f"length (got {len(sources)}, {len(references)}, "
                f"{len(predictions)})"
            )
        if not predictions:
            # Nothing to score: report 0.0 for every metric (matching the
            # 0.0 fallbacks below) instead of handing an empty corpus to
            # the metric backends.
            return dict.fromkeys(self._METRIC_NAMES, 0.0)

        results = {}

        # BLEU takes a list of reference lists (one reference per sentence).
        results["BLEU"] = float(
            self.bleu.compute(
                predictions=predictions,
                references=[[ref] for ref in references],
            )["bleu"]
        )

        # "xx" is not a language code bert_score special-cases, so it
        # presumably resolves to the library's generic multilingual model —
        # NOTE(review): confirm against the installed bert_score version.
        results["BERTScore"] = self._bertscore_f1(predictions, references, "xx")
        # lang="tr" selects bert_score's Turkish model (reported as BERTurk).
        results["BERTurk"] = self._bertscore_f1(predictions, references, "tr")

        results["COMET"] = self._comet_score(sources, references, predictions)
        return results

    def _bertscore_f1(self, predictions, references, lang) -> float:
        """Return the mean BERTScore F1 over all sentence pairs for ``lang``."""
        out = self.bertscore.compute(
            predictions=predictions,
            references=references,
            lang=lang,
        )
        return self._mean(out["f1"])

    def _comet_score(self, sources, references, predictions) -> float:
        """Return the corpus-level COMET score for the given triples."""
        out = self.comet.compute(
            sources=sources,
            predictions=predictions,
            references=references,
        )
        # Prefer the metric's own corpus score; otherwise average the
        # per-sentence scores rather than reporting only the first one.
        mean_score = out.get("mean_score")
        if mean_score is not None:
            return float(mean_score)
        scores = out.get("scores")
        if scores is None:
            return 0.0
        if isinstance(scores, (list, tuple)):
            return self._mean(scores)
        return float(scores)

    @staticmethod
    def _mean(values) -> float:
        """Arithmetic mean of ``values`` as a float, or 0.0 when empty."""
        values = list(values)
        return float(sum(values) / len(values)) if values else 0.0