import numpy as np
from PIL import Image
from tqdm import tqdm
from typing import List, Union
from sentence_transformers import SentenceTransformer

from bertopic.backend import BaseEmbedder


class MultiModalBackend(BaseEmbedder):
    """ Multimodal backend using Sentence-transformers

    The sentence-transformers embedding model used for
    generating word, document, and image embeddings.

    Arguments:
        embedding_model: A sentence-transformers embedding model that
                         can either embed both images and text or only text.
                         If it only embeds text, then `image_model` needs
                         to be used to embed the images.
        image_model: A sentence-transformers embedding model that is used
                     to embed only images.
        batch_size: The sizes of image batches to pass

    Examples:
    To create a model, you can load in a string pointing to a
    sentence-transformers model:

    ```python
    from bertopic.backend import MultiModalBackend

    sentence_model = MultiModalBackend("clip-ViT-B-32")
    ```

    or  you can instantiate a model yourself:

    ```python
    from bertopic.backend import MultiModalBackend
    from sentence_transformers import SentenceTransformer

    embedding_model = SentenceTransformer("clip-ViT-B-32")
    sentence_model = MultiModalBackend(embedding_model)
    ```
    """
    def __init__(self,
                 embedding_model: Union[str, SentenceTransformer],
                 image_model: Union[str, SentenceTransformer] = None,
                 batch_size: int = 32):
        super().__init__()
        self.batch_size = batch_size

        # Text or Text+Image model
        if isinstance(embedding_model, SentenceTransformer):
            self.embedding_model = embedding_model
        elif isinstance(embedding_model, str):
            self.embedding_model = SentenceTransformer(embedding_model)
        else:
            raise ValueError("Please select a correct SentenceTransformers model: \n"
                             "`from sentence_transformers import SentenceTransformer` \n"
                             "`model = SentenceTransformer('clip-ViT-B-32')`")

        # Image Model
        self.image_model = None
        if image_model is not None:
            if isinstance(image_model, SentenceTransformer):
                self.image_model = image_model
            elif isinstance(image_model, str):
                self.image_model = SentenceTransformer(image_model)
            else:
                raise ValueError("Please select a correct SentenceTransformers model: \n"
                                 "`from sentence_transformers import SentenceTransformer` \n"
                                 "`model = SentenceTransformer('clip-ViT-B-32')`")

        # CLIP-style models expose a tokenizer through the processor; plain
        # text models expose `.tokenizer` directly. If neither is available
        # we fall back to no truncation at all (see `_truncate_document`).
        try:
            self.tokenizer = self.embedding_model._first_module().processor.tokenizer
        except AttributeError:
            self.tokenizer = self.embedding_model.tokenizer
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # are not silently swallowed during construction.
            self.tokenizer = None

    def embed(self,
              documents: List[str],
              images: List[str] = None,
              verbose: bool = False) -> np.ndarray:
        """ Embed a list of n documents/words into an n-dimensional
        matrix of embeddings

        Arguments:
            documents: A list of documents or words to be embedded
            images: A list of image paths (or pre-loaded PIL images)
                    to be embedded alongside the documents
            verbose: Controls the verbosity of the process

        Returns:
            Document/words embeddings with shape (n, m) with `n` documents/words
            that each have an embeddings size of `m`. When both documents and
            images are provided, their embeddings are averaged element-wise.
        """
        # Embed documents — guard against an empty list before peeking at
        # documents[0] (previously raised IndexError on `[]`)
        doc_embeddings = None
        if documents and documents[0] is not None:
            doc_embeddings = self.embed_documents(documents)

        # Embed images
        image_embeddings = None
        if isinstance(images, list):
            image_embeddings = self.embed_images(images, verbose)

        # Average embeddings when both modalities are present
        averaged_embeddings = None
        if doc_embeddings is not None and image_embeddings is not None:
            averaged_embeddings = np.mean([doc_embeddings, image_embeddings], axis=0)

        if averaged_embeddings is not None:
            return averaged_embeddings
        elif doc_embeddings is not None:
            return doc_embeddings
        elif image_embeddings is not None:
            return image_embeddings

    def embed_documents(self,
                        documents: List[str],
                        verbose: bool = False) -> np.ndarray:
        """ Embed a list of n documents/words into an n-dimensional
        matrix of embeddings

        Arguments:
            documents: A list of documents or words to be embedded
            verbose: Controls the verbosity of the process

        Returns:
            Document/words embeddings with shape (n, m) with `n` documents/words
            that each have an embeddings size of `m`
        """
        # Truncate to the model's context window (e.g. CLIP's 77 tokens)
        # before encoding
        truncated_docs = [self._truncate_document(doc) for doc in documents]
        embeddings = self.embedding_model.encode(truncated_docs, show_progress_bar=verbose)
        return embeddings

    def embed_words(self,
                    words: List[str],
                    verbose: bool = False) -> np.ndarray:
        """ Embed a list of n words into an n-dimensional
        matrix of embeddings

        Arguments:
            words: A list of words to be embedded
            verbose: Controls the verbosity of the process

        Returns:
            Document/words embeddings with shape (n, m) with `n` documents/words
            that each have an embeddings size of `m`
        """
        embeddings = self.embedding_model.encode(words, show_progress_bar=verbose)
        return embeddings

    def embed_images(self, images, verbose):
        """ Embed a list of images (file paths or pre-loaded PIL images)

        Arguments:
            images: A list of image file paths and/or PIL Image objects
            verbose: Controls the verbosity of the process

        Returns:
            Image embeddings with shape (n, m)
        """
        if self.batch_size:
            nr_iterations = int(np.ceil(len(images) / self.batch_size))

            # Embed images per batch
            embeddings = []
            for i in tqdm(range(nr_iterations), disable=not verbose):
                start_index = i * self.batch_size
                end_index = (i * self.batch_size) + self.batch_size

                images_to_embed = [Image.open(image) if isinstance(image, str) else image
                                   for image in images[start_index:end_index]]
                if self.image_model is not None:
                    img_emb = self.image_model.encode(images_to_embed)
                else:
                    img_emb = self.embedding_model.encode(images_to_embed, show_progress_bar=False)
                embeddings.extend(img_emb.tolist())

                # Close images we opened ourselves to release file handles
                if isinstance(images[0], str):
                    for image in images_to_embed:
                        image.close()
            embeddings = np.array(embeddings)
        else:
            # Non-batched path, made consistent with the batched path:
            # accept pre-loaded PIL images (previously crashed on non-str
            # entries) and close any files we opened (previously leaked
            # every file handle).
            images_to_embed = [Image.open(image) if isinstance(image, str) else image
                               for image in images]
            if self.image_model is not None:
                embeddings = self.image_model.encode(images_to_embed)
            else:
                embeddings = self.embedding_model.encode(images_to_embed, show_progress_bar=False)
            for original, image in zip(images, images_to_embed):
                if isinstance(original, str):
                    image.close()
        return embeddings

    def _truncate_document(self, document):
        """ Truncate a document to at most 75 tokens (CLIP's usable window)

        Re-encodes the decoded text recursively because decode(encode(x))
        is not guaranteed to round-trip to the same token count.
        """
        if self.tokenizer:
            tokens = self.tokenizer.encode(document)

            if len(tokens) > 77:
                # Skip the starting token, only include 75 tokens
                truncated_tokens = tokens[1:76]
                document = self.tokenizer.decode(truncated_tokens)

                # Recursive call here, because encode(decode()) can have a different result
                return self._truncate_document(document)

        return document