import numpy as np
from PIL import Image
from tqdm import tqdm
from typing import List, Union
from sentence_transformers import SentenceTransformer

from bertopic.backend import BaseEmbedder


class MultiModalBackend(BaseEmbedder):
    """ Multimodal backend using Sentence-transformers

    The sentence-transformers embedding model used for 
    generating word, document, and image embeddings.  

    Arguments:
        embedding_model: A sentence-transformers embedding model that 
                         can either embed both images and text or only text.
                         If it only embeds text, then `image_model` needs 
                         to be used to embed the images.
        image_model: A sentence-transformers embedding model that is used
                     to embed only images.
        batch_size: The sizes of image batches to pass

    Examples:

    To create a model, you can load in a string pointing to a
    sentence-transformers model:

    ```python
    from bertopic.backend import MultiModalBackend

    sentence_model = MultiModalBackend("clip-ViT-B-32")
    ```

    or you can instantiate a model yourself:
    ```python
    from bertopic.backend import MultiModalBackend
    from sentence_transformers import SentenceTransformer

    embedding_model = SentenceTransformer("clip-ViT-B-32")
    sentence_model = MultiModalBackend(embedding_model)
    ```
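
    If the `embedding_model` can only embed text, pass a separate
    `image_model` for the images (an illustrative combination of
    checkpoints; any compatible sentence-transformers models will do):
    ```python
    from bertopic.backend import MultiModalBackend

    sentence_model = MultiModalBackend(
        embedding_model="all-MiniLM-L6-v2",
        image_model="clip-ViT-B-32",
    )
    ```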
    """
    def __init__(self, 
                 embedding_model: Union[str, SentenceTransformer], 
                 image_model: Union[str, SentenceTransformer] = None,
                 batch_size: int = 32):
        super().__init__()
        self.batch_size = batch_size
        
        # Text or Text+Image model
        if isinstance(embedding_model, SentenceTransformer):
            self.embedding_model = embedding_model
        elif isinstance(embedding_model, str):
            self.embedding_model = SentenceTransformer(embedding_model)
        else:
            raise ValueError("Please select a correct SentenceTransformers model: \n"
                             "`from sentence_transformers import SentenceTransformer` \n"
                             "`model = SentenceTransformer('clip-ViT-B-32')`")

        # Image Model
        self.image_model = None
        if image_model is not None:
            if isinstance(image_model, SentenceTransformer):
                self.image_model = image_model
            elif isinstance(image_model, str):
                self.image_model = SentenceTransformer(image_model)
            else:
                raise ValueError("Please select a correct SentenceTransformers model: \n"
                                 "`from sentence_transformers import SentenceTransformer` \n"
                                 "`model = SentenceTransformer('clip-ViT-B-32')`")
            
        # Resolve a tokenizer so documents can be truncated to the
        # model's maximum sequence length before encoding
        try:
            self.tokenizer = self.embedding_model._first_module().processor.tokenizer
        except AttributeError:
            self.tokenizer = getattr(self.embedding_model, "tokenizer", None)
        except Exception:
            self.tokenizer = None

    def embed(self,
              documents: List[str],
              images: List[str] = None,
              verbose: bool = False) -> np.ndarray:
        """ Embed a list of n documents/words into an n-dimensional
        matrix of embeddings

        Arguments:
            documents: A list of documents or words to be embedded
            verbose: Controls the verbosity of the process

        Returns:
            Document/words embeddings with shape (n, m) with `n` documents/words
            that each have an embeddings size of `m`
        """
        # Embed documents; a list of Nones may be passed
        # when only images are to be embedded
        doc_embeddings = None
        if documents is not None and len(documents) > 0 and documents[0] is not None:
            doc_embeddings = self.embed_documents(documents)

        # Embed images
        image_embeddings = None
        if isinstance(images, list):
            image_embeddings = self.embed_images(images, verbose)

        # Average the document and image embeddings element-wise
        averaged_embeddings = None
        if doc_embeddings is not None and image_embeddings is not None:
            averaged_embeddings = np.mean([doc_embeddings, image_embeddings], axis=0)

        if averaged_embeddings is not None:
            return averaged_embeddings
        elif doc_embeddings is not None:
            return doc_embeddings
        elif image_embeddings is not None:
            return image_embeddings
    
    def embed_documents(self,
                        documents: List[str],
                        verbose: bool = False) -> np.ndarray:
        """ Embed a list of `n` documents/words into an `(n, m)`
        matrix of embeddings

        Documents longer than the tokenizer's maximum sequence
        length are truncated before encoding.

        Arguments:
            documents: A list of documents or words to be embedded
            verbose: Controls the verbosity of the process

        Returns:
            Document/word embeddings with shape `(n, m)` where each of
            the `n` documents/words has an embedding of size `m`
        """
        # Truncate documents to the tokenizer's maximum length, then encode
        truncated_docs = [self._truncate_document(doc) for doc in documents]
        embeddings = self.embedding_model.encode(truncated_docs, show_progress_bar=verbose)
        return embeddings
    
    def embed_words(self, words: List[str], verbose: bool = False) -> np.ndarray:
        """ Embed a list of `n` words into an `(n, m)` matrix
        of embeddings

        Arguments:
            words: A list of words to be embedded
            verbose: Controls the verbosity of the process

        Returns:
            Word embeddings with shape `(n, m)` where each of the
            `n` words has an embedding of size `m`
        """
        embeddings = self.embedding_model.encode(words, show_progress_bar=verbose)
        return embeddings
    
    def embed_images(self,
                     images: List[Union[str, Image.Image]],
                     verbose: bool = False) -> np.ndarray:
        """ Embed a list of `n` images into an `(n, m)` matrix of embeddings

        Arguments:
            images: A list of image paths or pre-loaded PIL images
            verbose: Controls the verbosity of the process

        Returns:
            Image embeddings with shape `(n, m)` where each of the
            `n` images has an embedding of size `m`
        """
        if self.batch_size:
            nr_iterations = int(np.ceil(len(images) / self.batch_size))

            # Embed images per batch
            embeddings = []
            for i in tqdm(range(nr_iterations), disable=not verbose):
                start_index = i * self.batch_size
                end_index = (i * self.batch_size) + self.batch_size

                # Open images from disk only when a file path is passed
                images_to_embed = [Image.open(image) if isinstance(image, str) else image
                                   for image in images[start_index:end_index]]
                if self.image_model is not None:
                    img_emb = self.image_model.encode(images_to_embed)
                else:
                    img_emb = self.embedding_model.encode(images_to_embed, show_progress_bar=False)
                embeddings.extend(img_emb.tolist())

                # Close images that were opened from disk
                if isinstance(images[0], str):
                    for image in images_to_embed:
                        image.close()
            embeddings = np.array(embeddings)
        else:
            images_to_embed = [Image.open(image) if isinstance(image, str) else image for image in images]
            if self.image_model is not None:
                embeddings = self.image_model.encode(images_to_embed)
            else:
                embeddings = self.embedding_model.encode(images_to_embed, show_progress_bar=False)
        return embeddings
    
    def _truncate_document(self, document: str) -> str:
        """ Truncate a document to the tokenizer's maximum sequence length """
        if self.tokenizer:
            tokens = self.tokenizer.encode(document)

            if len(tokens) > 77:
                # CLIP's text encoder accepts at most 77 tokens; skip the
                # start token and keep 75 tokens to leave room for the
                # special tokens added back during encoding
                truncated_tokens = tokens[1:76]
                document = self.tokenizer.decode(truncated_tokens)

                # Recurse because encode(decode(tokens)) does not always
                # round-trip to the same number of tokens
                return self._truncate_document(document)

        return document
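

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only; `clip-ViT-B-32` is
    # downloaded on first use and embeds both text and images):
    backend = MultiModalBackend("clip-ViT-B-32")

    docs = ["a photo of a cat", "a photo of a dog"]
    doc_embeddings = backend.embed_documents(docs)
    print(doc_embeddings.shape)  # (2, 512) for clip-ViT-B-32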