# Source metadata (extraction artifact): 7,622 bytes, revision 19b102a
import numpy as np
from PIL import Image
from tqdm import tqdm
from typing import List, Union
from sentence_transformers import SentenceTransformer
from bertopic.backend import BaseEmbedder
class MultiModalBackend(BaseEmbedder):
    """ Multimodal backend using Sentence-transformers

    The sentence-transformers embedding model used for
    generating word, document, and image embeddings.

    Arguments:
        embedding_model: A sentence-transformers embedding model that
                         can either embed both images and text or only text.
                         If it only embeds text, then `image_model` needs
                         to be used to embed the images.
        image_model: A sentence-transformers embedding model that is used
                     to embed only images.
        batch_size: The sizes of image batches to pass

    Examples:
    To create a model, you can load in a string pointing to a
    sentence-transformers model:

    ```python
    from bertopic.backend import MultiModalBackend

    sentence_model = MultiModalBackend("clip-ViT-B-32")
    ```

    or you can instantiate a model yourself:

    ```python
    from bertopic.backend import MultiModalBackend
    from sentence_transformers import SentenceTransformer

    embedding_model = SentenceTransformer("clip-ViT-B-32")
    sentence_model = MultiModalBackend(embedding_model)
    ```
    """
    def __init__(self,
                 embedding_model: Union[str, SentenceTransformer],
                 image_model: Union[str, SentenceTransformer] = None,
                 batch_size: int = 32):
        super().__init__()
        self.batch_size = batch_size

        # Text or Text+Image model
        self.embedding_model = self._validate_model(embedding_model)

        # Optional image-only model; stays None when not provided
        self.image_model = self._validate_model(image_model) if image_model is not None else None

        # Find a tokenizer for document truncation. CLIP-style models expose it
        # through the first module's processor; plain text models expose
        # `.tokenizer` directly; otherwise truncation is skipped (tokenizer=None).
        try:
            self.tokenizer = self.embedding_model._first_module().processor.tokenizer
        except AttributeError:
            # NOTE: the original bare `except:` did not cover an AttributeError
            # raised here, so a model without `.tokenizer` crashed __init__.
            self.tokenizer = getattr(self.embedding_model, "tokenizer", None)
        except Exception:
            # Best-effort lookup: any other failure means we simply do not truncate.
            self.tokenizer = None

    @staticmethod
    def _validate_model(model: Union[str, SentenceTransformer]) -> SentenceTransformer:
        """ Resolve a model name or instance into a SentenceTransformer.

        Arguments:
            model: Either an already-instantiated SentenceTransformer or a
                   string name/path that can be loaded as one.

        Returns:
            A SentenceTransformer instance.

        Raises:
            ValueError: If `model` is neither a SentenceTransformer nor a string.
        """
        if isinstance(model, SentenceTransformer):
            return model
        if isinstance(model, str):
            return SentenceTransformer(model)
        raise ValueError("Please select a correct SentenceTransformers model: \n"
                         "`from sentence_transformers import SentenceTransformer` \n"
                         "`model = SentenceTransformer('clip-ViT-B-32')`")

    def embed(self,
              documents: List[str],
              images: List[str] = None,
              verbose: bool = False) -> np.ndarray:
        """ Embed a list of n documents/words into an n-dimensional
        matrix of embeddings

        Arguments:
            documents: A list of documents or words to be embedded
            images: A list of image paths (or PIL images) to be embedded
            verbose: Controls the verbosity of the process

        Returns:
            Document/words embeddings with shape (n, m) with `n` documents/words
            that each have an embeddings size of `m`. When both documents and
            images are given, the element-wise average of the two embedding
            matrices is returned; `None` when neither is given.
        """
        # Embed documents; guard against an empty list so `documents[0]`
        # cannot raise an IndexError
        doc_embeddings = None
        if documents is not None and len(documents) > 0 and documents[0] is not None:
            doc_embeddings = self.embed_documents(documents)

        # Embed images
        image_embeddings = None
        if isinstance(images, list):
            image_embeddings = self.embed_images(images, verbose)

        # Average the two modalities when both are available
        if doc_embeddings is not None and image_embeddings is not None:
            return np.mean([doc_embeddings, image_embeddings], axis=0)
        if doc_embeddings is not None:
            return doc_embeddings
        if image_embeddings is not None:
            return image_embeddings
        return None

    def embed_documents(self,
                        documents: List[str],
                        verbose: bool = False) -> np.ndarray:
        """ Embed a list of n documents/words into an n-dimensional
        matrix of embeddings

        Arguments:
            documents: A list of documents or words to be embedded
            verbose: Controls the verbosity of the process

        Returns:
            Document/words embeddings with shape (n, m) with `n` documents/words
            that each have an embeddings size of `m`
        """
        # Truncate to the tokenizer's maximum length before encoding so the
        # underlying model never sees an over-long input
        truncated_docs = [self._truncate_document(doc) for doc in documents]
        return self.embedding_model.encode(truncated_docs, show_progress_bar=verbose)

    def embed_words(self, words: List[str], verbose: bool = False) -> np.ndarray:
        """ Embed a list of n words into an n-dimensional
        matrix of embeddings

        Arguments:
            words: A list of words to be embedded
            verbose: Controls the verbosity of the process

        Returns:
            Document/words embeddings with shape (n, m) with `n` documents/words
            that each have an embeddings size of `m`
        """
        return self.embedding_model.encode(words, show_progress_bar=verbose)

    def embed_images(self, images, verbose: bool = False) -> np.ndarray:
        """ Embed a list of images into an n-dimensional matrix of embeddings.

        Arguments:
            images: A list of image file paths and/or already-opened PIL images
            verbose: Controls the verbosity of the process

        Returns:
            Image embeddings with shape (n, m) with `n` images that each have
            an embeddings size of `m`
        """
        if self.batch_size:
            nr_iterations = int(np.ceil(len(images) / self.batch_size))

            # Embed images per batch so only `batch_size` files are open at once
            embeddings = []
            for i in tqdm(range(nr_iterations), disable=not verbose):
                start_index = i * self.batch_size
                end_index = start_index + self.batch_size
                img_emb = self._encode_image_batch(images[start_index:end_index])
                embeddings.extend(img_emb.tolist())
            return np.array(embeddings)

        # No batching requested: embed everything in a single pass. Unlike the
        # original, this path also accepts already-opened PIL images and closes
        # any files it opened itself.
        return self._encode_image_batch(images)

    def _encode_image_batch(self, batch) -> np.ndarray:
        """ Open (if given as paths), encode, and close a batch of images.

        Arguments:
            batch: A list of image file paths and/or PIL images

        Returns:
            The encoded batch as returned by the underlying model
        """
        opened = []
        images_to_embed = []
        for image in batch:
            if isinstance(image, str):
                # Opened here, so we are responsible for closing it below
                img = Image.open(image)
                opened.append(img)
                images_to_embed.append(img)
            else:
                # Caller-supplied PIL image: caller owns its lifetime
                images_to_embed.append(image)
        try:
            model = self.image_model if self.image_model is not None else self.embedding_model
            return model.encode(images_to_embed, show_progress_bar=False)
        finally:
            # Close only the files we opened ourselves, even on encode failure
            for img in opened:
                img.close()

    def _truncate_document(self, document: str) -> str:
        """ Truncate a document to the tokenizer's maximum sequence length.

        CLIP tokenizers accept at most 77 tokens (including the start/end
        tokens), so over-long documents are cut to 75 content tokens.

        Arguments:
            document: The document to (possibly) truncate

        Returns:
            The document, truncated when a tokenizer is available and the
            document exceeds the maximum length; unchanged otherwise
        """
        if self.tokenizer:
            tokens = self.tokenizer.encode(document)

            if len(tokens) > 77:
                # Skip the starting token, only include 75 tokens
                truncated_tokens = tokens[1:76]
                document = self.tokenizer.decode(truncated_tokens)

                # Recursive call here, because encode(decode()) can yield a
                # different (still too long) token sequence
                return self._truncate_document(document)
        return document
|