File size: 2,170 Bytes
843111c
 
 
 
 
bcdd774
 
843111c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from utils import encode_image
from utils import bt_embeddings
from tqdm import tqdm
from typing import List
from langchain_core.embeddings import Embeddings
#from langchain_core.pydantic_v1 import BaseModel
from pydantic.v1 import BaseModel

class BridgeTowerEmbeddings(BaseModel, Embeddings):
    """BridgeTower embedding model behind the LangChain ``Embeddings`` interface.

    Every embedding is computed by the ``bt_embeddings`` helper imported from
    ``utils``; this class only adapts the call shapes (pairs, documents,
    single query) to that helper.
    """

    def embed_image_text_pairs(
        self, texts: List[str], images: List[str], batch_size: int = 2
    ) -> List[List[float]]:
        """Embed a list of image-text pairs using BridgeTower.

        Parameters:
        -----------
        texts: List[str]
            The texts (captions) to embed, one per image.
        images: List[str]
            The paths to the images to embed, aligned index-by-index with
            ``texts``. Each path is passed through ``encode_image`` before
            being handed to ``bt_embeddings``.
        batch_size: int
            Kept for interface compatibility, defaults to 2. It is currently
            unused: pairs are embedded one at a time.

        Returns:
        --------
            List of embeddings, one for each image-text pair, in input order.

        Raises:
        -------
        ValueError
            If ``texts`` and ``images`` do not have the same length.
        """
        # Validate explicitly instead of using ``assert``: assertions are
        # removed under ``python -O``, after which ``zip`` would silently
        # drop the unmatched tail of the longer list.
        if len(texts) != len(images):
            raise ValueError(
                "the len of captions should be equal to the len of images"
            )

        print(f"Embedding {len(texts)} image-text pairs...")

        pairs = tqdm(zip(images, texts), total=len(images), desc="Processing pairs")
        return [
            bt_embeddings(text, encode_image(path_to_img))
            for path_to_img, text in pairs
        ]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of documents using BridgeTower.

        Parameters:
        -----------
        texts: List[str]
            The list of texts to embed.

        Returns:
        --------
            List of embeddings, one for each text, in input order.
        """
        # Text-only call: an empty string is passed where the image-text path
        # passes the encoded image.
        return [bt_embeddings(text, "") for text in texts]

    def embed_query(self, text: str) -> List[float]:
        """Embed a query using BridgeTower.

        Parameters:
        -----------
        text: str
            The text to embed.

        Returns:
        --------
            Embedding for the text, computed via ``embed_documents``.
        """
        return self.embed_documents([text])[0]