File size: 1,792 Bytes
7566ac3
acd7cf4
7566ac3
 
acd7cf4
 
 
7566ac3
 
 
 
 
acd7cf4
 
7566ac3
 
 
 
 
 
 
 
 
 
 
acd7cf4
 
7566ac3
 
 
 
 
 
 
acd7cf4
 
7566ac3
 
 
 
 
 
 
 
 
 
 
 
 
acd7cf4
7566ac3
 
 
 
 
 
 
acd7cf4
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from functools import lru_cache
import os
from typing import Dict, List, Final, Optional
import warnings
# Register the filter BEFORE importing jieba. A warning filter only affects
# warnings raised after it is installed, so placing it after `import jieba`
# cannot silence a UserWarning attributed to `jieba._compat` if that warning
# is raised while the jieba package itself is being imported.
warnings.filterwarnings(
    "ignore",
    category=UserWarning,
    module=r"jieba\._compat",
)

import nltk  # noqa: E402  (deliberately imported after the filter above)
import jieba  # noqa: E402  (deliberately imported after the filter above)

class NLTKHelper:
    """
    Helper around NLTK and jieba for stopword lookup and word tokenization.

    On construction the helper registers an NLTK data directory, initializes
    jieba, and makes sure every resource listed in ``_NLTK_PACKAGES`` can be
    found, downloading missing ones into that directory.
    """

    # Language code accepted by the public methods -> NLTK language name.
    SUPPORTED_LANGUAGES: Final[Dict[str, str]] = {
        "en": "english",
        "zh": "chinese"
    }
    # NLTK resource name -> category directory used when looking it up.
    _NLTK_PACKAGES: Final[Dict[str, str]] = {
        "stopwords": "corpora",
        "punkt_tab": "tokenizers"
    }

    def __init__(self, nltk_data_path: Optional[str] = None):
        """
        Set up the NLTK data directory, jieba, and the required NLTK data.

        Args:
            nltk_data_path: Directory that holds (or will receive) NLTK data.
                When omitted or empty, ``resources/nltk_data`` located one
                directory above this file's directory is used.
        """
        # Per-instance stopword cache keyed by language code. A plain dict is
        # used instead of ``functools.lru_cache`` on the method, because a
        # method-level lru_cache keys on ``self`` and keeps every instance
        # alive for as long as the cache exists.
        self._stopwords_cache: Dict[str, List[str]] = {}

        self._nltk_path = nltk_data_path or os.path.join(
            os.path.dirname(os.path.dirname(__file__)),
            "resources",
            "nltk_data"
        )
        # nltk.data.path is shared by the whole process; only append when the
        # directory is not registered yet so that creating several helpers
        # does not pile up duplicate search-path entries.
        if self._nltk_path not in nltk.data.path:
            nltk.data.path.append(self._nltk_path)
        jieba.initialize()

        # Drive the checks from the table so a newly listed package is
        # picked up without touching this method.
        for package_name in self._NLTK_PACKAGES:
            self._ensure_nltk_data(package_name)

    def _ensure_nltk_data(self, package_name: str) -> None:
        """
        Ensure the given NLTK resource is available, downloading it if not.

        The resource is looked up as ``<category>/<package_name>`` where the
        category comes from ``_NLTK_PACKAGES``. If NLTK cannot find it, it is
        downloaded quietly into this helper's data directory.

        Args:
            package_name: A key of ``_NLTK_PACKAGES``.

        Raises:
            KeyError: If ``package_name`` is not listed in ``_NLTK_PACKAGES``.
        """
        resource = f"{self._NLTK_PACKAGES[package_name]}/{package_name}"
        try:
            nltk.data.find(resource)
        except LookupError:
            nltk.download(package_name, download_dir=self._nltk_path, quiet=True)

    def _check_language(self, lang: str) -> None:
        """Raise ``ValueError`` unless ``lang`` is a supported language code."""
        if lang not in self.SUPPORTED_LANGUAGES:
            raise ValueError(f"Language {lang} is not supported.")

    def get_stopwords(self, lang: str) -> List[str]:
        """
        Return the NLTK stopword list for a supported language.

        The corpus is read once per language and instance; later calls are
        served from the per-instance cache.

        Args:
            lang: A key of ``SUPPORTED_LANGUAGES`` (e.g. ``"en"``, ``"zh"``).

        Returns:
            A new list of stopwords on every call.

        Raises:
            ValueError: If ``lang`` is not supported.
        """
        self._check_language(lang)
        cached = self._stopwords_cache.get(lang)
        if cached is None:
            cached = list(
                nltk.corpus.stopwords.words(self.SUPPORTED_LANGUAGES[lang])
            )
            self._stopwords_cache[lang] = cached
        # Hand out a copy so a caller mutating the result cannot corrupt
        # what later calls receive.
        return list(cached)

    def word_tokenize(self, text: str, lang: str) -> List[str]:
        """
        Split ``text`` into tokens using the tokenizer for ``lang``.

        Chinese (``"zh"``) is tokenized with ``jieba.lcut``; every other
        supported language goes through ``nltk.word_tokenize`` with the
        NLTK language name taken from ``SUPPORTED_LANGUAGES``.

        Args:
            text: The text to tokenize.
            lang: A key of ``SUPPORTED_LANGUAGES``.

        Returns:
            The list of tokens.

        Raises:
            ValueError: If ``lang`` is not supported.
        """
        self._check_language(lang)
        if lang == "zh":
            return jieba.lcut(text)

        return nltk.word_tokenize(text, language=self.SUPPORTED_LANGUAGES[lang])