| import numpy as np |
| from typing import List, Union |
| from sentence_transformers import SentenceTransformer |
|
|
| from bertopic.backend import BaseEmbedder |
|
|
|
|
class SentenceTransformerBackend(BaseEmbedder):
    """ Sentence-transformers embedding model

    The sentence-transformers embedding model used for generating document and
    word embeddings.

    Arguments:
        embedding_model: A sentence-transformers embedding model, given either
                         as a model name (string) or as an already-instantiated
                         `SentenceTransformer` object

    Examples:

    A model can be created from a string that points to a
    sentence-transformers model:

    ```python
    from bertopic.backend import SentenceTransformerBackend

    sentence_model = SentenceTransformerBackend("all-MiniLM-L6-v2")
    ```

    Alternatively, pass in a model that was instantiated beforehand:
    ```python
    from bertopic.backend import SentenceTransformerBackend
    from sentence_transformers import SentenceTransformer

    embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
    sentence_model = SentenceTransformerBackend(embedding_model)
    ```
    """
    def __init__(self, embedding_model: Union[str, SentenceTransformer]):
        super().__init__()

        # Keep track of the model name only when we loaded it ourselves
        # from a string; a pre-built SentenceTransformer gives us no name.
        self._hf_model = None
        if isinstance(embedding_model, str):
            self._hf_model = embedding_model
            self.embedding_model = SentenceTransformer(embedding_model)
        elif isinstance(embedding_model, SentenceTransformer):
            self.embedding_model = embedding_model
        else:
            raise ValueError("Please select a correct SentenceTransformers model: \n"
                             "`from sentence_transformers import SentenceTransformer` \n"
                             "`model = SentenceTransformer('all-MiniLM-L6-v2')`")

    def embed(self,
              documents: List[str],
              verbose: bool = False) -> np.ndarray:
        """ Embed a list of n documents/words into an n-dimensional
        matrix of embeddings

        Arguments:
            documents: A list of documents or words to be embedded
            verbose: Controls the verbosity of the process

        Returns:
            Document/words embeddings with shape (n, m) with `n` documents/words
            that each have an embeddings size of `m`
        """
        # Delegate directly to sentence-transformers; `verbose` toggles
        # its built-in progress bar.
        return self.embedding_model.encode(documents, show_progress_bar=verbose)