|
|
import numpy as np |
|
|
from tqdm import tqdm |
|
|
from typing import Union, List |
|
|
from flair.data import Sentence |
|
|
from flair.embeddings import DocumentEmbeddings, TokenEmbeddings, DocumentPoolEmbeddings |
|
|
|
|
|
from bertopic.backend import BaseEmbedder |
|
|
|
|
|
|
|
|
class FlairBackend(BaseEmbedder):
    """ Flair Embedding Model

    The Flair embedding model used for generating document and
    word embeddings.

    Arguments:
        embedding_model: A Flair embedding model. Token embeddings
                         (`TokenEmbeddings`) are wrapped in a
                         `DocumentPoolEmbeddings` so that one vector is
                         produced per document; document embeddings
                         (`DocumentEmbeddings`) are used as-is.

    Raises:
        ValueError: If `embedding_model` is neither a Flair
                    `TokenEmbeddings` nor a `DocumentEmbeddings` instance.

    Examples:

    ```python
    from bertopic.backend import FlairBackend
    from flair.embeddings import WordEmbeddings, DocumentPoolEmbeddings

    # Create a Flair Embedding model
    glove_embedding = WordEmbeddings('crawl')
    document_glove_embeddings = DocumentPoolEmbeddings([glove_embedding])

    # Pass the Flair model to create a new backend
    flair_embedder = FlairBackend(document_glove_embeddings)
    ```
    """

    # Placeholder text embedded instead of documents that are empty,
    # whitespace-only, or whose embedding raises a RuntimeError, so that
    # every input document still produces one row of output.
    _EMPTY_DOCUMENT = "an empty document"

    def __init__(self, embedding_model: Union[TokenEmbeddings, DocumentEmbeddings]):
        super().__init__()

        if isinstance(embedding_model, TokenEmbeddings):
            # Token embeddings give one vector per token; pool them into a
            # single document-level vector.
            self.embedding_model = DocumentPoolEmbeddings([embedding_model])

        elif isinstance(embedding_model, DocumentEmbeddings):
            # This backend only calls `embed`, so switch off fine-tuning on
            # models that expose the flag as an instance attribute.
            if "fine_tune" in embedding_model.__dict__:
                embedding_model.fine_tune = False
            self.embedding_model = embedding_model

        else:
            raise ValueError("Please select a correct Flair model by either preparing a token or document "
                             "embedding model: \n"
                             "`from flair.embeddings import TransformerDocumentEmbeddings` \n"
                             "`roberta = TransformerDocumentEmbeddings('roberta-base')`")

    def embed(self,
              documents: List[str],
              verbose: bool = False) -> np.ndarray:
        """ Embed a list of n documents/words into an (n, m) matrix of embeddings

        Empty or whitespace-only documents, and documents whose embedding
        raises a `RuntimeError`, are embedded as the placeholder text
        "an empty document" so that every input still yields a row.

        Arguments:
            documents: A list of documents or words to be embedded
            verbose: Controls the verbosity of the process

        Returns:
            Document/words embeddings with shape (n, m) with `n` documents/words
            that each have an embeddings size of `m`
        """
        embeddings = []
        for document in tqdm(documents, disable=not verbose):
            sentence = self._embed_sentence(document)
            # Move the tensor off the autograd graph / device before converting.
            embeddings.append(sentence.embedding.detach().cpu().numpy())
        return np.asarray(embeddings)

    def _embed_sentence(self, document: str) -> Sentence:
        """Wrap `document` in a Flair `Sentence`, embed it in place and return it.

        Blank documents are replaced by the placeholder text up front; if the
        embedding call raises a `RuntimeError`, the placeholder is embedded
        instead (a failure on the placeholder itself propagates).
        """
        text = self._EMPTY_DOCUMENT if self._is_blank(document) else document
        try:
            sentence = Sentence(text)
            self.embedding_model.embed(sentence)
        except RuntimeError:
            sentence = Sentence(self._EMPTY_DOCUMENT)
            self.embedding_model.embed(sentence)
        return sentence

    @staticmethod
    def _is_blank(document) -> bool:
        """Return True for falsy documents and for strings containing only whitespace.

        Whitespace-only text presumably tokenizes to a Flair `Sentence` with no
        tokens, so it is treated the same as an empty document. Non-string
        truthy inputs are passed through unchanged, as before.
        """
        return not document or (isinstance(document, str) and not document.strip())
|
|
|