File size: 2,968 Bytes
19b102a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
import numpy as np
from tqdm import tqdm
from typing import Union, List
from flair.data import Sentence
from flair.embeddings import DocumentEmbeddings, TokenEmbeddings, DocumentPoolEmbeddings
from bertopic.backend import BaseEmbedder
class FlairBackend(BaseEmbedder):
    """ Flair Embedding Model

    The Flair embedding model used for generating document and
    word embeddings.

    Token-level (word) embedding models are wrapped in a
    `DocumentPoolEmbeddings` so that every input yields one vector.
    Document-level embedding models are used as-is, with `fine_tune`
    switched off when the model instance carries that attribute.

    Arguments:
        embedding_model: A Flair embedding model

    Raises:
        ValueError: if `embedding_model` is neither a Flair
                    `TokenEmbeddings` nor a `DocumentEmbeddings` instance.

    Examples:

    ```python
    from bertopic.backend import FlairBackend
    from flair.embeddings import WordEmbeddings, DocumentPoolEmbeddings

    # Create a Flair Embedding model
    glove_embedding = WordEmbeddings('crawl')
    document_glove_embeddings = DocumentPoolEmbeddings([glove_embedding])

    # Pass the Flair model to create a new backend
    flair_embedder = FlairBackend(document_glove_embeddings)
    ```
    """

    # Placeholder text that is embedded instead of an empty (falsy) document,
    # and instead of any document whose embedding raised a RuntimeError.
    _EMPTY_DOCUMENT = "an empty document"

    def __init__(self, embedding_model: Union[TokenEmbeddings, DocumentEmbeddings]):
        super().__init__()

        # Flair word embeddings: pool the token vectors into a single document vector
        if isinstance(embedding_model, TokenEmbeddings):
            self.embedding_model = DocumentPoolEmbeddings([embedding_model])

        # Flair document embeddings + disable fine tune to prevent CUDA OOM
        # https://github.com/flairNLP/flair/issues/1719
        elif isinstance(embedding_model, DocumentEmbeddings):
            # Only override `fine_tune` when it is set on this model instance
            if "fine_tune" in embedding_model.__dict__:
                embedding_model.fine_tune = False
            self.embedding_model = embedding_model

        else:
            raise ValueError("Please select a correct Flair model by preparing either a token or document "
                             "embedding model: \n"
                             "`from flair.embeddings import TransformerDocumentEmbeddings` \n"
                             "`roberta = TransformerDocumentEmbeddings('roberta-base')`")

    def embed(self,
              documents: List[str],
              verbose: bool = False) -> np.ndarray:
        """ Embed a list of n documents/words into an n-dimensional
        matrix of embeddings

        Each document is wrapped in a Flair `Sentence` and embedded one at a
        time. Empty documents are replaced by a fixed placeholder text. If
        embedding a document raises a RuntimeError, the placeholder is embedded
        in its place (best effort); a RuntimeError raised while embedding the
        placeholder itself is not caught.

        Arguments:
            documents: A list of documents or words to be embedded
            verbose: Controls the verbosity of the process

        Returns:
            Document/words embeddings with shape (n, m) with `n` documents/words
            that each have an embeddings size of `m`
        """
        embeddings = []
        for document in tqdm(documents, disable=not verbose):
            try:
                sentence = Sentence(document) if document else Sentence(self._EMPTY_DOCUMENT)
                self.embedding_model.embed(sentence)
            except RuntimeError:
                # Deliberate fallback: substitute the placeholder rather than abort the batch
                sentence = Sentence(self._EMPTY_DOCUMENT)
                self.embedding_model.embed(sentence)

            # Move the tensor off the autograd graph / device before converting to numpy
            embedding = sentence.embedding.detach().cpu().numpy()
            embeddings.append(embedding)

        embeddings = np.asarray(embeddings)
        return embeddings
|