Spaces:
Running
Running
| from typing import Set | |
| from graphgen.bases import BaseEvaluator, QAPair | |
| from graphgen.utils import NLTKHelper, detect_main_language | |
| class MTLDEvaluator(BaseEvaluator): | |
| """ | |
| Metrics for measuring the lexical diversity of text. | |
| """ | |
| def __init__(self, threshold: float = 0.72): | |
| self.nltk_helper = NLTKHelper() | |
| self.stopwords_en: Set[str] = set(self.nltk_helper.get_stopwords("en")) | |
| self.stopwords_zh: Set[str] = set(self.nltk_helper.get_stopwords("zh")) | |
| self.threshold = threshold | |
| def evaluate(self, pair: QAPair) -> float: | |
| """ | |
| Calculate the MTLD (Mean Token Length Diversity) score for a given text. | |
| min is 1.0 | |
| higher is better | |
| """ | |
| text = pair.answer | |
| if not text or not text.strip(): | |
| return 0.0 | |
| lang = detect_main_language(text) | |
| tokens = self.nltk_helper.word_tokenize(text, lang) | |
| stopwords = self.stopwords_zh if lang == "zh" else self.stopwords_en | |
| filtered_tokens = [word for word in tokens if word not in stopwords] | |
| filtered_tokens = [word for word in filtered_tokens if word.isalnum()] | |
| if not filtered_tokens: | |
| return 0 | |
| # Compute forward factors | |
| forward_factors = self._compute_factors(filtered_tokens, self.threshold) | |
| # Compute backward factors | |
| backward_factors = self._compute_factors(filtered_tokens[::-1], self.threshold) | |
| # Compute average factors | |
| return (forward_factors + backward_factors) / 2 | |
| def _compute_factors(tokens: list, threshold: float) -> float: | |
| factors = 0 | |
| current_segment = [] | |
| unique_words = set() | |
| for token in tokens: | |
| current_segment.append(token) | |
| unique_words.add(token) | |
| ttr = len(unique_words) / len(current_segment) | |
| if ttr <= threshold: | |
| factors += 1 | |
| current_segment = [] | |
| unique_words = set() | |
| # handle last segment | |
| if current_segment: | |
| ttr = len(unique_words) / len(current_segment) | |
| if ttr <= threshold: | |
| factors += 1 | |
| else: | |
| factors += 1 - (ttr - threshold) / (1 - threshold) | |
| return len(tokens) / factors if factors > 0 else len(tokens) | |