TopicModelingRepo / BERTopic /bertopic /backend /_sentencetransformers.py
kisejin's picture
Upload 261 files
19b102a verified
raw
history blame
2.32 kB
import numpy as np
from typing import List, Union
from sentence_transformers import SentenceTransformer
from bertopic.backend import BaseEmbedder
class SentenceTransformerBackend(BaseEmbedder):
    """ Embedding backend built on a sentence-transformers model

    Wraps a `SentenceTransformer` so that documents and words can be
    turned into embeddings through the common `embed` interface.

    Arguments:
        embedding_model: Either an already instantiated `SentenceTransformer`
                         or a string that is passed to `SentenceTransformer`
                         to load the model.

    Examples:

    Pass a string naming the sentence-transformers model to load:

    ```python
    from bertopic.backend import SentenceTransformerBackend

    sentence_model = SentenceTransformerBackend("all-MiniLM-L6-v2")
    ```

    or hand over a model that you created yourself:

    ```python
    from bertopic.backend import SentenceTransformerBackend
    from sentence_transformers import SentenceTransformer

    embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
    sentence_model = SentenceTransformerBackend(embedding_model)
    ```
    """
    def __init__(self, embedding_model: Union[str, SentenceTransformer]):
        """ Store the given model, or load it when a string is given

        Arguments:
            embedding_model: A `SentenceTransformer` instance, or a string
                             that is passed to `SentenceTransformer`.

        Raises:
            ValueError: If `embedding_model` is neither a
                        `SentenceTransformer` nor a string.
        """
        super().__init__()

        # Stays None unless the model was specified as a string
        self._hf_model = None

        # A ready-made model is used as-is
        if isinstance(embedding_model, SentenceTransformer):
            self.embedding_model = embedding_model
            return

        # A string is handed to SentenceTransformer and remembered afterwards
        if isinstance(embedding_model, str):
            self.embedding_model = SentenceTransformer(embedding_model)
            self._hf_model = embedding_model
            return

        raise ValueError("Please select a correct SentenceTransformers model: \n"
                         "`from sentence_transformers import SentenceTransformer` \n"
                         "`model = SentenceTransformer('all-MiniLM-L6-v2')`")

    def embed(self,
              documents: List[str],
              verbose: bool = False) -> np.ndarray:
        """ Embed a list of n documents/words into a matrix of embeddings

        Arguments:
            documents: The documents or words that should be embedded
            verbose: Forwarded to `encode` as `show_progress_bar`

        Returns:
            Document/words embeddings with shape (n, m) with `n` documents/words
            that each have an embeddings size of `m`
        """
        return self.embedding_model.encode(documents, show_progress_bar=verbose)