File size: 2,343 Bytes
acd7cf4
 
7566ac3
 
acd7cf4
43d27f2
acd7cf4
 
7566ac3
acd7cf4
 
7566ac3
 
 
 
 
acd7cf4
7566ac3
acd7cf4
7566ac3
acd7cf4
 
 
 
7566ac3
acd7cf4
 
 
 
7566ac3
acd7cf4
 
 
 
 
 
 
 
7566ac3
 
acd7cf4
7566ac3
 
acd7cf4
7566ac3
acd7cf4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7566ac3
acd7cf4
 
 
 
 
43d27f2
acd7cf4
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from typing import Set

from graphgen.bases import BaseEvaluator, QAPair
from graphgen.utils import NLTKHelper, detect_main_language


class MTLDEvaluator(BaseEvaluator):
    """
    Evaluator measuring the lexical diversity of QA answers via MTLD.

    MTLD (Measure of Textual Lexical Diversity, McCarthy & Jarvis 2010)
    is the mean length of sequential token runs whose type-token ratio
    (TTR) stays above a threshold. Higher scores indicate a richer,
    less repetitive vocabulary.
    """

    def __init__(self, threshold: float = 0.72):
        # 0.72 is the conventional TTR cut-off from McCarthy & Jarvis (2010).
        self.nltk_helper = NLTKHelper()
        self.stopwords_en: Set[str] = set(self.nltk_helper.get_stopwords("en"))
        self.stopwords_zh: Set[str] = set(self.nltk_helper.get_stopwords("zh"))
        self.threshold = threshold

    def evaluate(self, pair: QAPair) -> float:
        """
        Compute the MTLD score for ``pair.answer``.

        Returns 0.0 when the answer is empty/whitespace or contains no
        alphanumeric non-stopword tokens; otherwise returns the mean of
        the forward and backward MTLD passes (higher is better).
        """
        text = pair.answer
        if not text or not text.strip():
            return 0.0

        # Tokenize with the language-appropriate tokenizer.
        lang = detect_main_language(text)
        tokens = self.nltk_helper.word_tokenize(text, lang)

        # Drop stopwords and punctuation in one pass so the score
        # reflects content words only.
        stopwords = self.stopwords_zh if lang == "zh" else self.stopwords_en
        filtered_tokens = [
            word for word in tokens if word not in stopwords and word.isalnum()
        ]

        if not filtered_tokens:
            return 0.0

        # MTLD is direction-sensitive, so the paper prescribes averaging
        # a forward pass and a backward (reversed-token) pass.
        forward_factors = self._compute_factors(filtered_tokens, self.threshold)
        backward_factors = self._compute_factors(filtered_tokens[::-1], self.threshold)

        return (forward_factors + backward_factors) / 2

    @staticmethod
    def _compute_factors(tokens: list, threshold: float) -> float:
        """
        Single-direction MTLD pass: ``len(tokens) / factor_count``.

        A factor completes each time the running TTR of the current
        segment drops to or below ``threshold``. A trailing unfinished
        segment contributes a fractional factor equal to
        ``(1 - TTR) / (1 - threshold)``, the remainder credit from the
        original MTLD definition.
        """
        factors = 0.0
        current_segment = []
        unique_words = set()

        for token in tokens:
            current_segment.append(token)
            unique_words.add(token)
            ttr = len(unique_words) / len(current_segment)

            if ttr <= threshold:
                # TTR collapsed: close this factor and start a new segment.
                factors += 1
                current_segment = []
                unique_words = set()

        # Credit the incomplete trailing segment proportionally.
        if current_segment:
            ttr = len(unique_words) / len(current_segment)
            if ttr <= threshold:
                factors += 1
            else:
                # Algebraically identical to (1 - ttr) / (1 - threshold).
                factors += 1 - (ttr - threshold) / (1 - threshold)

        # Guard the division: a short, fully diverse input may complete
        # zero factors.
        return len(tokens) / factors if factors > 0 else len(tokens)