from bertopic.backend import BaseEmbedder
from sklearn.exceptions import NotFittedError
from sklearn.utils.validation import check_is_fitted


class SklearnEmbedder(BaseEmbedder):
    """Scikit-learn based embedding model.

    This component allows the usage of scikit-learn pipelines for generating
    document and word embeddings.

    Arguments:
        pipe: A scikit-learn pipeline that can `.transform()` text.

    Examples:
        Scikit-learn is very flexible and allows for many representations.
        A relatively simple pipeline is shown below.

        ```python
        from sklearn.pipeline import make_pipeline
        from sklearn.decomposition import TruncatedSVD
        from sklearn.feature_extraction.text import TfidfVectorizer

        from bertopic import BERTopic
        from bertopic.backend import SklearnEmbedder

        pipe = make_pipeline(
            TfidfVectorizer(),
            TruncatedSVD(100)
        )

        sklearn_embedder = SklearnEmbedder(pipe)
        topic_model = BERTopic(embedding_model=sklearn_embedder)
        ```

        This pipeline first constructs a sparse representation based on TF-IDF
        and then makes it dense by applying SVD. Alternatively, you might also
        construct something more elaborate. As long as you construct a
        scikit-learn compatible pipeline, you should be able to pass it to
        BERTopic.

    !!! Warning
        One caveat to be aware of is that scikit-learn's base `Pipeline` class
        does not support the `.partial_fit()` API. If you have a pipeline that
        theoretically should be able to support online learning, you might want
        to explore the [scikit-partial](https://github.com/koaning/scikit-partial)
        project.
    """
    def __init__(self, pipe):
        super().__init__()
        self.pipe = pipe
    def embed(self, documents, verbose=False):
        """Embed a list of `n` documents/words into an `n`-dimensional
        matrix of embeddings.

        Arguments:
            documents: A list of documents or words to be embedded
            verbose: No-op variable kept around for API consistency. If you
                want feedback on training times, use the scikit-learn API
                instead.

        Returns:
            Document/word embeddings with shape (n, m), with `n`
            documents/words that each have an embedding size of `m`
        """
        # Only fit the pipeline on first use; afterwards, re-use the fitted
        # pipeline and merely transform the incoming documents.
        try:
            check_is_fitted(self.pipe)
            embeddings = self.pipe.transform(documents)
        except NotFittedError:
            embeddings = self.pipe.fit_transform(documents)

        return embeddings