|
|
from ._base import BaseEmbedder |
|
|
|
|
|
|
|
|
from bertopic.backend._sklearn import SklearnEmbedder |
|
|
from sklearn.pipeline import make_pipeline |
|
|
from sklearn.decomposition import TruncatedSVD |
|
|
from sklearn.feature_extraction.text import TfidfVectorizer |
|
|
from sklearn.pipeline import Pipeline as ScikitPipeline |
|
|
|
|
|
|
|
|
# Lowercase language names accepted by `select_backend` through its `language`
# argument. Kept as a list because it is printed verbatim in the error message
# raised for unsupported languages.
languages = [
    "arabic", "bulgarian", "catalan", "czech", "danish", "german",
    "greek", "english", "spanish", "estonian", "persian", "finnish",
    "french", "canadian french", "galician", "gujarati", "hebrew", "hindi",
    "croatian", "hungarian", "armenian", "indonesian", "italian", "japanese",
    "georgian", "korean", "kurdish", "lithuanian", "latvian", "macedonian",
    "mongolian", "marathi", "malay", "burmese", "norwegian bokmal", "dutch",
    "polish", "portuguese", "brazilian portuguese", "romanian", "russian",
    "slovak", "slovenian", "albanian", "serbian", "swedish", "thai",
    "turkish", "ukrainian", "urdu", "vietnamese",
    "chinese (simplified)", "chinese (traditional)",
]
|
|
|
|
|
|
|
|
def select_backend(embedding_model,
                   language: str = None) -> BaseEmbedder:
    """ Select an embedding backend from a model object or from a language.

    The checks run in order and the first match wins:

    1. A `BaseEmbedder` instance is returned unchanged.
    2. A scikit-learn `Pipeline` is wrapped in `SklearnEmbedder`.
    3. Flair, spaCy, Gensim, TensorFlow saved-model, sentence-transformers
       (or a plain string) and Hugging Face pipeline objects are
       recognised by the text of their type and wrapped in the matching backend.
    4. If `language` is set, all-MiniLM-L6-v2 is used for English and
       paraphrase-multilingual-MiniLM-L12-v2 for every other supported
       language or for "multilingual".
    5. Otherwise all-MiniLM-L6-v2 is used as the default.

    Arguments:
        embedding_model: An embedding model object, a string that is handed
                         to the sentence-transformers backend, or any other
                         value (e.g. None) to fall through to the
                         language-based / default selection.
        language: Name of the language, matched case-insensitively against
                  `languages`, "english"/"en" and "multilingual".

    Returns:
        model: The `BaseEmbedder` that was passed in, or a backend wrapping
               the selected embedding model.

    Raises:
        ValueError: If `language` is given but is not supported.
    """
    # The model is already a backend: nothing to wrap.
    if isinstance(embedding_model, BaseEmbedder):
        return embedding_model

    # Scikit-learn pipeline
    if isinstance(embedding_model, ScikitPipeline):
        return SklearnEmbedder(embedding_model)

    # Optional dependencies are detected through the textual type of the model
    # so that each backend module is only imported when it is actually needed.
    model_type = str(type(embedding_model))

    # Flair word embeddings
    if "flair" in model_type:
        from bertopic.backend._flair import FlairBackend
        return FlairBackend(embedding_model)

    # Spacy embeddings
    if "spacy" in model_type:
        from bertopic.backend._spacy import SpacyBackend
        return SpacyBackend(embedding_model)

    # Gensim embeddings
    if "gensim" in model_type:
        from bertopic.backend._gensim import GensimBackend
        return GensimBackend(embedding_model)

    # USE embeddings. Both substrings must be tested explicitly:
    # `"tensorflow" and "saved_model" in x` only ever checks "saved_model".
    if "tensorflow" in model_type and "saved_model" in model_type:
        from bertopic.backend._use import USEBackend
        return USEBackend(embedding_model)

    # Sentence Transformer embeddings (a string is passed on as-is)
    if "sentence_transformers" in model_type or isinstance(embedding_model, str):
        from ._sentencetransformers import SentenceTransformerBackend
        return SentenceTransformerBackend(embedding_model)

    # Hugging Face pipeline. Checked after sentence-transformers because
    # "transformers" is a substring of "sentence_transformers".
    if "transformers" in model_type and "pipeline" in model_type:
        from ._hftransformers import HFTransformerBackend
        return HFTransformerBackend(embedding_model)

    # A language was chosen
    if language:
        try:
            from ._sentencetransformers import SentenceTransformerBackend

            # Normalise once so every comparison is case-insensitive.
            normalized_language = language.lower()
            if normalized_language in ("english", "en"):
                return SentenceTransformerBackend("sentence-transformers/all-MiniLM-L6-v2")
            elif normalized_language in languages or normalized_language == "multilingual":
                return SentenceTransformerBackend("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
            else:
                raise ValueError(f"{language} is currently not supported. However, you can "
                                 f"create any embeddings yourself and pass it through fit_transform(docs, embeddings)\n"
                                 "Else, please select a language from the following list:\n"
                                 f"{languages}")

        # The sentence-transformers backend could not be imported:
        # fall back to a TF-IDF + truncated SVD scikit-learn pipeline.
        except ModuleNotFoundError:
            pipe = make_pipeline(TfidfVectorizer(), TruncatedSVD(100))
            return SklearnEmbedder(pipe)

    # Default model when neither a recognised model nor a language was given
    from ._sentencetransformers import SentenceTransformerBackend
    return SentenceTransformerBackend("sentence-transformers/all-MiniLM-L6-v2")
|
|
|