# File size: 2,297 Bytes (revision 19b102a)
import numpy as np
from tqdm import tqdm
from typing import List
from bertopic.backend import BaseEmbedder
from gensim.models.keyedvectors import Word2VecKeyedVectors
class GensimBackend(BaseEmbedder):
    """ Gensim Embedding Model

    The Gensim embedding model is typically used for word embeddings with
    GloVe, Word2Vec or FastText. Documents are embedded by averaging the
    vectors of their whitespace-separated tokens that are present in the
    model's vocabulary.

    Arguments:
        embedding_model: A Gensim embedding model

    Examples:

    ```python
    from bertopic.backend import GensimBackend
    import gensim.downloader as api

    ft = api.load('fasttext-wiki-news-subwords-300')
    ft_embedder = GensimBackend(ft)
    ```
    """
    def __init__(self, embedding_model: Word2VecKeyedVectors):
        """ Store the Gensim model after validating its type

        Arguments:
            embedding_model: A Gensim keyed-vectors model

        Raises:
            ValueError: If `embedding_model` is not a `Word2VecKeyedVectors`
                        instance
        """
        super().__init__()

        # Fail fast on anything that is not a Gensim keyed-vectors model
        if not isinstance(embedding_model, Word2VecKeyedVectors):
            raise ValueError("Please select a correct Gensim model: \n"
                             "`import gensim.downloader as api` \n"
                             "`ft = api.load('fasttext-wiki-news-subwords-300')`")
        self.embedding_model = embedding_model

    def embed(self,
              documents: List[str],
              verbose: bool = False) -> np.ndarray:
        """ Embed a list of n documents/words into an n-dimensional
        matrix of embeddings

        Each document is split on whitespace; the vectors of all tokens
        found in the model's vocabulary are averaged. A document without
        any known token is represented by a zero vector.

        Arguments:
            documents: A list of documents or words to be embedded
            verbose: Controls the verbosity of the process

        Returns:
            Document/words embeddings with shape (n, m) with `n` documents/words
            that each have an embeddings size of `m`. An empty `documents`
            input yields an array of shape (0, m).
        """
        model = self.embedding_model

        # Read the dimensionality directly from the model rather than copying
        # the full vocabulary just to look up a single vector
        vector_size = model.vector_size
        vocabulary = model.key_to_index
        empty_vector = np.zeros(vector_size)

        # Extract word embeddings and pool to document-level
        embeddings = []
        for doc in tqdm(documents, disable=not verbose, position=0, leave=True):
            word_vectors = [model.get_vector(word) for word in doc.split()
                            if word in vocabulary]

            if word_vectors:
                embeddings.append(np.mean(word_vectors, axis=0))
            else:
                # No token of this document is in the vocabulary
                embeddings.append(empty_vector)

        # Keep the documented (n, m) shape even when there are no documents
        if not embeddings:
            return np.zeros((0, vector_size))

        return np.array(embeddings)
|