import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor
from typing import List, Union, Optional


class OpsColQwen3Embedder:
    """
    Embedder for OpsColQwen3-4B model.
    """

    def __init__(
        self,
        model_name: str = "OpenSearch-AI/Ops-Colqwen3-4B",
        dims: int = 2560,
        device: Optional[str] = None,
        **kwargs
    ):
        """
        Initialize the embedder.

        Args:
            model_name: Model path or hub name
            dims: Embedding dimensions
            device: Device to use for inference ('mps', 'cuda', or 'cpu')
            **kwargs: Additional arguments passed to from_pretrained
        """

        device_map = kwargs.pop('device_map', None)
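        # Resolution order: an explicit device_map kwarg wins, then the
        # device argument, then CUDA, then Apple-Silicon MPS, else CPU.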
        if not device_map:
            if device:
                device_map = device
            elif torch.cuda.is_available():
                device_map = "cuda"
            elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
                device_map = "mps" # Use MPS for Apple Silicon
            else:
                device_map = "cpu"

        # Default to half precision on accelerators, full precision on CPU.
        dtype = kwargs.pop('dtype', torch.float16 if device_map != "cpu" else torch.float32)

        self.model = AutoModel.from_pretrained(
            model_name,
            dims=dims,
            trust_remote_code=True,
            dtype=dtype,
            device_map=device_map,
            **kwargs
        )
        self.model.eval()

        self.processor = AutoProcessor.from_pretrained(
            model_name,
            trust_remote_code=True,
            **kwargs
        )

        self.device = device_map
        self.dims = dims

    def encode_queries(
        self,
        queries: List[str]
    ) -> List[torch.Tensor]:
        """
        Encode a list of text queries.

        Args:
            queries: List of query texts

        Returns:
            List of query embeddings
        """
        query_inputs = self.processor.process_queries(queries)
        query_inputs = {k: v.to(self.device) for k, v in query_inputs.items()}

        with torch.no_grad():
            query_embeddings = self.model(**query_inputs)

        return [q.cpu() for q in query_embeddings]

    def encode_images(
        self,
        images: List[Union[str, Image.Image]]
    ) -> List[torch.Tensor]:
        """
        Encode a list of images.

        Args:
            images: List of image paths or PIL Images

        Returns:
            List of image embeddings
        """
        image_objects = []
        for img in images:
            if isinstance(img, str):
                image_objects.append(Image.open(img).convert("RGB"))
            elif isinstance(img, Image.Image):
                image_objects.append(img)
            else:
                raise ValueError(f"Unsupported image type: {type(img)}")

        image_inputs = self.processor.process_images(image_objects)
        image_inputs = {k: v.to(self.device) for k, v in image_inputs.items()}

        with torch.no_grad():
            image_embeddings = self.model(**image_inputs)

        return [i.cpu() for i in image_embeddings]

    def compute_scores(
        self,
        query_embeddings: List[torch.Tensor],
        image_embeddings: List[torch.Tensor]
    ) -> torch.Tensor:
        """
        Compute similarity scores between queries and images.

        Args:
            query_embeddings: List of query embeddings
            image_embeddings: List of image embeddings

        Returns:
            Similarity scores matrix
        """
        return self.processor.score_multi_vector(query_embeddings, image_embeddings)

    def encode_and_score(
        self,
        queries: List[str],
        images: List[Union[str, Image.Image]]
    ) -> torch.Tensor:
        """
        Convenience method to encode queries and images and compute scores.

        Args:
            queries: List of query texts
            images: List of images (paths or PIL objects)

        Returns:
            Similarity scores between queries and images
        """
        query_embeddings = self.encode_queries(queries)
        image_embeddings = self.encode_images(images)
        return self.compute_scores(query_embeddings, image_embeddings)


# Example usage
if __name__ == "__main__":
    images = [Image.new("RGB", (32, 32), color="white"), Image.new("RGB", (16, 16), color="black")]
    queries = ["Is attention really all you need?", "What is the amount of bananas farmed in Salvador?"]

    embedder = OpsColQwen3Embedder(
        model_name="OpenSearch-AI/Ops-Colqwen3-4B",
        dims=2560,
        dtype=torch.float16,
        attn_implementation="flash_attention_2",  # requires a CUDA GPU with flash-attn installed
    )

    query_embeddings = embedder.encode_queries(queries)
    image_embeddings = embedder.encode_images(images)
    print(query_embeddings[0].shape, image_embeddings[0].shape)  # e.g. torch.Size([23, 2560]) torch.Size([18, 2560])

    scores = embedder.compute_scores(query_embeddings, image_embeddings)

    print(f"Scores:\n{scores}")