File size: 4,305 Bytes
19b102a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
from ._base import BaseEmbedder

# Imports for light-weight variant of BERTopic
from bertopic.backend._sklearn import SklearnEmbedder
from sklearn.pipeline import make_pipeline
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import Pipeline as ScikitPipeline


# Language names (lower-case) that `select_backend` maps to the multilingual
# sentence-transformers model; "english" is matched separately before this
# list is consulted. The list is also printed verbatim in the error message
# for unsupported languages.
languages = [
    "arabic", "bulgarian", "catalan", "czech", "danish", "german",
    "greek", "english", "spanish", "estonian", "persian", "finnish",
    "french", "canadian french", "galician", "gujarati", "hebrew", "hindi",
    "croatian", "hungarian", "armenian", "indonesian", "italian", "japanese",
    "georgian", "korean", "kurdish", "lithuanian", "latvian", "macedonian",
    "mongolian", "marathi", "malay", "burmese", "norwegian bokmal", "dutch",
    "polish", "portuguese", "brazilian portuguese", "romanian", "russian",
    "slovak", "slovenian", "albanian", "serbian", "swedish", "thai",
    "turkish", "ukrainian", "urdu", "vietnamese",
    "chinese (simplified)", "chinese (traditional)",
]


def select_backend(embedding_model,
                   language: str = None) -> BaseEmbedder:
    """ Select an embedding backend based on a given model or on a language.

    An explicitly passed `embedding_model` takes precedence over `language`.
    The model is matched, in order, against: an existing `BaseEmbedder`, a
    scikit-learn `Pipeline`, Flair, spaCy, Gensim, TensorFlow saved models (USE),
    sentence-transformers (or a model name passed as a string) and
    Hugging Face transformers pipelines.

    If none of these match and `language` is given, all-MiniLM-L6-v2 is chosen
    for English and paraphrase-multilingual-MiniLM-L12-v2 for every language in
    `languages` as well as for "multilingual". If the sentence-transformers
    backend cannot be imported (light-weight installation), a
    TF-IDF + TruncatedSVD scikit-learn pipeline is used instead.

    Arguments:
        embedding_model: The embedding model to wrap. Can be a `BaseEmbedder`,
                         a scikit-learn `Pipeline`, a string with the name of a
                         sentence-transformers model, or a model object of one
                         of the supported libraries. Anything else falls
                         through to the language-based selection.
        language: Name of the language of the documents (case-insensitive),
                  "en", or "multilingual". Only used if `embedding_model`
                  could not be matched to a backend.

    Returns:
        model: The backend that wraps the selected embedding model

    Raises:
        ValueError: If `language` is given but is neither English, one of
                    `languages`, nor "multilingual".
    """
    # BERTopic language backend
    if isinstance(embedding_model, BaseEmbedder):
        return embedding_model

    # Scikit-learn backend
    if isinstance(embedding_model, ScikitPipeline):
        return SklearnEmbedder(embedding_model)

    # The remaining backends are detected through the fully qualified type
    # name, so that their (optional) libraries are only imported when needed
    model_type = str(type(embedding_model))

    # Flair word embeddings
    if "flair" in model_type:
        from bertopic.backend._flair import FlairBackend
        return FlairBackend(embedding_model)

    # Spacy embeddings
    if "spacy" in model_type:
        from bertopic.backend._spacy import SpacyBackend
        return SpacyBackend(embedding_model)

    # Gensim embeddings
    if "gensim" in model_type:
        from bertopic.backend._gensim import GensimBackend
        return GensimBackend(embedding_model)

    # USE embeddings
    # NOTE: both substrings have to be tested explicitly; `"a" and "b" in s`
    # only evaluates `"b" in s` since a non-empty string is always truthy
    if "tensorflow" in model_type and "saved_model" in model_type:
        from bertopic.backend._use import USEBackend
        return USEBackend(embedding_model)

    # Sentence Transformer embeddings
    if "sentence_transformers" in model_type or isinstance(embedding_model, str):
        from ._sentencetransformers import SentenceTransformerBackend
        return SentenceTransformerBackend(embedding_model)

    # Hugging Face embeddings
    if "transformers" in model_type and "pipeline" in model_type:
        from ._hftransformers import HFTransformerBackend
        return HFTransformerBackend(embedding_model)

    # Select embedding model based on language
    if language:
        try:
            from ._sentencetransformers import SentenceTransformerBackend

            # Normalize once so that all comparisons are case-insensitive
            normalized_language = language.lower()
            if normalized_language in ("english", "en"):
                return SentenceTransformerBackend("sentence-transformers/all-MiniLM-L6-v2")
            elif normalized_language in languages or normalized_language == "multilingual":
                return SentenceTransformerBackend("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
            else:
                raise ValueError(f"{language} is currently not supported. However, you can "
                                f"create any embeddings yourself and pass it through fit_transform(docs, embeddings)\n"
                                "Else, please select a language from the following list:\n"
                                f"{languages}")

        # Only for light-weight installation
        except ModuleNotFoundError:
            pipe = make_pipeline(TfidfVectorizer(), TruncatedSVD(100))
            return SklearnEmbedder(pipe)

    # Neither a recognized model nor a language was given: default to the
    # English sentence-transformers model
    from ._sentencetransformers import SentenceTransformerBackend
    return SentenceTransformerBackend("sentence-transformers/all-MiniLM-L6-v2")