|
|
import torch |
|
|
from PIL import Image |
|
|
from transformers import AutoModel, AutoProcessor |
|
|
from typing import List, Union, Optional |
|
|
|
|
|
|
|
|
class OpsColQwen3Embedder:
    """Embedder for the OpsColQwen3-4B late-interaction retrieval model.

    Wraps a Hugging Face ``AutoModel`` / ``AutoProcessor`` pair loaded with
    ``trust_remote_code=True``. Queries and images are encoded into lists of
    multi-vector embeddings (one tensor per input), and scored against each
    other with the processor's MaxSim-style ``score_multi_vector``.
    """

    def __init__(
        self,
        model_name: str = "OpenSearch-AI/Ops-Colqwen3-4B",
        dims: int = 2560,
        device: Optional[str] = None,
        **kwargs
    ):
        """
        Initialize the embedder.

        Args:
            model_name: Model path or hub name
            dims: Embedding dimensions
            device: Device to use for inference ('mps', 'cuda', or 'cpu')
            **kwargs: Additional arguments passed to from_pretrained
                (a ``device_map`` or ``dtype`` entry here overrides the
                defaults chosen below)
        """
        # Honor an explicit device_map kwarg; otherwise pick the best
        # available backend in preference order CUDA > Apple MPS > CPU.
        device_map = kwargs.pop('device_map', None)
        if not device_map:
            if device:
                device_map = device
            elif torch.cuda.is_available():
                device_map = "cuda"
            elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
                device_map = "mps"
            else:
                device_map = "cpu"

        # Half precision on accelerators; float32 on CPU, where fp16 kernels
        # are slow or unsupported. Caller may override via dtype=... kwarg.
        dtype = kwargs.pop('dtype', torch.float16 if device_map != "cpu" else torch.float32)

        self.model = AutoModel.from_pretrained(
            model_name,
            dims=dims,
            trust_remote_code=True,
            dtype=dtype,
            device_map=device_map,
            **kwargs
        )
        self.model.eval()

        # NOTE(review): leftover model-oriented kwargs (e.g.
        # attn_implementation) are also forwarded to the processor here.
        # ProcessorMixin generally tolerates unknown kwargs, but confirm if
        # processor loading starts to fail.
        self.processor = AutoProcessor.from_pretrained(
            model_name,
            trust_remote_code=True,
            **kwargs
        )

        # BUGFIX: device_map may be "auto" or a per-layer dict (when supplied
        # via kwargs), neither of which is a valid target for Tensor.to().
        # Resolve the concrete device from the loaded weights instead, falling
        # back to device_map only if the model exposes no parameters.
        try:
            self.device = next(self.model.parameters()).device
        except StopIteration:
            self.device = device_map
        self.dims = dims

    def encode_queries(
        self,
        queries: List[str]
    ) -> List[torch.Tensor]:
        """
        Encode a list of text queries.

        Args:
            queries: List of query texts

        Returns:
            List of query embeddings (one CPU tensor per query)
        """
        query_inputs = self.processor.process_queries(queries)
        # Move only tensor-like entries to the model device; processors may
        # also return plain Python values.
        query_inputs = {
            k: v.to(self.device) if hasattr(v, "to") else v
            for k, v in query_inputs.items()
        }

        with torch.no_grad():
            query_embeddings = self.model(**query_inputs)

        # Detach to CPU so results outlive the GPU context and are cheap to hold.
        return [q.cpu() for q in query_embeddings]

    def encode_images(
        self,
        images: List[Union[str, Image.Image]]
    ) -> List[torch.Tensor]:
        """
        Encode a list of images.

        Args:
            images: List of image paths or PIL Images

        Returns:
            List of image embeddings (one CPU tensor per image)

        Raises:
            ValueError: If an entry is neither a path string nor a PIL Image.
        """
        image_objects = []
        for img in images:
            if isinstance(img, str):
                # Force RGB so grayscale/RGBA files get a uniform channel layout.
                image_objects.append(Image.open(img).convert("RGB"))
            elif isinstance(img, Image.Image):
                image_objects.append(img)
            else:
                raise ValueError(f"Unsupported image type: {type(img)}")

        image_inputs = self.processor.process_images(image_objects)
        # Move only tensor-like entries to the model device; processors may
        # also return plain Python values.
        image_inputs = {
            k: v.to(self.device) if hasattr(v, "to") else v
            for k, v in image_inputs.items()
        }

        with torch.no_grad():
            image_embeddings = self.model(**image_inputs)

        return [i.cpu() for i in image_embeddings]

    def compute_scores(
        self,
        query_embeddings: List[torch.Tensor],
        image_embeddings: List[torch.Tensor]
    ) -> torch.Tensor:
        """
        Compute similarity scores between queries and images.

        Args:
            query_embeddings: List of query embeddings
            image_embeddings: List of image embeddings

        Returns:
            Similarity scores matrix (queries x images)
        """
        return self.processor.score_multi_vector(query_embeddings, image_embeddings)

    def encode_and_score(
        self,
        queries: List[str],
        images: List[Union[str, Image.Image]]
    ):
        """
        Convenience method to encode queries and images and compute scores.

        Args:
            queries: List of query texts
            images: List of images (paths or PIL objects)

        Returns:
            Similarity scores between queries and images
        """
        query_embeddings = self.encode_queries(queries)
        image_embeddings = self.encode_images(images)
        return self.compute_scores(query_embeddings, image_embeddings)
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: two tiny blank canvases stand in for document pages.
    canvas_specs = [((32, 32), "white"), ((16, 16), "black")]
    images = [Image.new("RGB", size, color=fill) for size, fill in canvas_specs]
    queries = ["Is attention really all you need?", "What is the amount of bananas farmed in Salvador?"]

    embedder = OpsColQwen3Embedder(
        model_name="OpenSearch-AI/Ops-Colqwen3-4B",
        dims=2560,
        dtype=torch.float16,
        attn_implementation="flash_attention_2",
    )

    # Encode both sides, then sanity-check one embedding shape per modality.
    q_embs = embedder.encode_queries(queries)
    i_embs = embedder.encode_images(images)
    print(q_embs[0].shape, i_embs[0].shape)

    scores = embedder.compute_scores(q_embs, i_embs)

    print(f"Scores:\n{scores}")