kisejin's picture
Upload 261 files
19b102a verified
raw
history blame
4.31 kB
from ._base import BaseEmbedder
# Imports for light-weight variant of BERTopic
from bertopic.backend._sklearn import SklearnEmbedder
from sklearn.pipeline import make_pipeline
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import Pipeline as ScikitPipeline
languages = [
"arabic",
"bulgarian",
"catalan",
"czech",
"danish",
"german",
"greek",
"english",
"spanish",
"estonian",
"persian",
"finnish",
"french",
"canadian french",
"galician",
"gujarati",
"hebrew",
"hindi",
"croatian",
"hungarian",
"armenian",
"indonesian",
"italian",
"japanese",
"georgian",
"korean",
"kurdish",
"lithuanian",
"latvian",
"macedonian",
"mongolian",
"marathi",
"malay",
"burmese",
"norwegian bokmal",
"dutch",
"polish",
"portuguese",
"brazilian portuguese",
"romanian",
"russian",
"slovak",
"slovenian",
"albanian",
"serbian",
"swedish",
"thai",
"turkish",
"ukrainian",
"urdu",
"vietnamese",
"chinese (simplified)",
"chinese (traditional)",
]
def select_backend(embedding_model,
language: str = None) -> BaseEmbedder:
""" Select an embedding model based on language or a specific sentence transformer models.
When selecting a language, we choose all-MiniLM-L6-v2 for English and
paraphrase-multilingual-MiniLM-L12-v2 for all other languages as it support 100+ languages.
Returns:
model: Either a Sentence-Transformer or Flair model
"""
# BERTopic language backend
if isinstance(embedding_model, BaseEmbedder):
return embedding_model
# Scikit-learn backend
if isinstance(embedding_model, ScikitPipeline):
return SklearnEmbedder(embedding_model)
# Flair word embeddings
if "flair" in str(type(embedding_model)):
from bertopic.backend._flair import FlairBackend
return FlairBackend(embedding_model)
# Spacy embeddings
if "spacy" in str(type(embedding_model)):
from bertopic.backend._spacy import SpacyBackend
return SpacyBackend(embedding_model)
# Gensim embeddings
if "gensim" in str(type(embedding_model)):
from bertopic.backend._gensim import GensimBackend
return GensimBackend(embedding_model)
# USE embeddings
if "tensorflow" and "saved_model" in str(type(embedding_model)):
from bertopic.backend._use import USEBackend
return USEBackend(embedding_model)
# Sentence Transformer embeddings
if "sentence_transformers" in str(type(embedding_model)) or isinstance(embedding_model, str):
from ._sentencetransformers import SentenceTransformerBackend
return SentenceTransformerBackend(embedding_model)
# Hugging Face embeddings
if "transformers" and "pipeline" in str(type(embedding_model)):
from ._hftransformers import HFTransformerBackend
return HFTransformerBackend(embedding_model)
# Select embedding model based on language
if language:
try:
from ._sentencetransformers import SentenceTransformerBackend
if language.lower() in ["English", "english", "en"]:
return SentenceTransformerBackend("sentence-transformers/all-MiniLM-L6-v2")
elif language.lower() in languages or language == "multilingual":
return SentenceTransformerBackend("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
else:
raise ValueError(f"{language} is currently not supported. However, you can "
f"create any embeddings yourself and pass it through fit_transform(docs, embeddings)\n"
"Else, please select a language from the following list:\n"
f"{languages}")
# Only for light-weight installation
except ModuleNotFoundError:
pipe = make_pipeline(TfidfVectorizer(), TruncatedSVD(100))
return SklearnEmbedder(pipe)
from ._sentencetransformers import SentenceTransformerBackend
return SentenceTransformerBackend("sentence-transformers/all-MiniLM-L6-v2")