from sklearn.exceptions import NotFittedError
from sklearn.utils.validation import check_is_fitted

from bertopic.backend import BaseEmbedder


class SklearnEmbedder(BaseEmbedder):
    """Scikit-learn based embedding model.

    This component allows the usage of scikit-learn pipelines for generating
    document and word embeddings.

    Arguments:
        pipe: A scikit-learn pipeline that can `.transform()` text.

    Examples:
        Scikit-learn is very flexible and allows for many representations.
        A relatively simple pipeline is shown below.

        ```python
        from sklearn.pipeline import make_pipeline
        from sklearn.decomposition import TruncatedSVD
        from sklearn.feature_extraction.text import TfidfVectorizer

        from bertopic import BERTopic
        from bertopic.backend import SklearnEmbedder

        pipe = make_pipeline(
            TfidfVectorizer(),
            TruncatedSVD(100)
        )

        sklearn_embedder = SklearnEmbedder(pipe)
        topic_model = BERTopic(embedding_model=sklearn_embedder)
        ```

        This pipeline first constructs a sparse representation based on TF-IDF
        and then makes it dense by applying SVD. Alternatively, you might also
        construct something more elaborate. As long as you construct a
        scikit-learn compatible pipeline, you should be able to pass it to
        BERTopic.

    !!! Warning
        One caveat to be aware of is that scikit-learn's base `Pipeline` class
        does not support the `.partial_fit()` API. If you have a pipeline that
        theoretically should be able to support online learning, you might want
        to explore the [scikit-partial](https://github.com/koaning/scikit-partial)
        project.
    """

    def __init__(self, pipe):
        super().__init__()
        self.pipe = pipe

    def embed(self, documents, verbose=False):
        """Embed a list of `n` documents/words into an n-dimensional matrix of embeddings.

        Arguments:
            documents: A list of documents or words to be embedded
            verbose: No-op variable that is kept around to keep the API
                consistent. If you want feedback on training times, use the
                scikit-learn API instead.

        Returns:
            Document/word embeddings with shape (n, m) with `n` documents/words
            that each have an embedding size of `m`
        """
        try:
            check_is_fitted(self.pipe)
            embeddings = self.pipe.transform(documents)
        except NotFittedError:
            embeddings = self.pipe.fit_transform(documents)
        return embeddings