from typing import List, Optional, Union

import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor


class OpsColQwen3Embedder:
    """
    Embedder for the OpsColQwen3-4B model.
    """

    def __init__(
        self,
        model_name: str = "OpenSearch-AI/Ops-Colqwen3-4B",
        dims: int = 2560,
        device: Optional[str] = None,
        **kwargs
    ):
        """
        Initialize the embedder.

        Args:
            model_name: Model path or Hub name
            dims: Embedding dimensions
            device: Device to use for inference ('mps', 'cuda', or 'cpu')
            **kwargs: Additional arguments passed to from_pretrained
        """
        # Device resolution: an explicit device_map kwarg wins, then the
        # device argument, then CUDA, then Apple-Silicon MPS, then CPU.
        device_map = kwargs.pop('device_map', None)
        if not device_map:
            if device:
                device_map = device
            elif torch.cuda.is_available():
                device_map = "cuda"
            elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
                device_map = "mps"  # Use MPS on Apple Silicon
            else:
                device_map = "cpu"

        # Default to half precision on accelerators, full precision on CPU.
        dtype = kwargs.pop('dtype', torch.float16 if device_map != "cpu" else torch.float32)

        self.model = AutoModel.from_pretrained(
            model_name,
            dims=dims,
            trust_remote_code=True,
            dtype=dtype,
            device_map=device_map,
            **kwargs
        )
        self.model.eval()

        self.processor = AutoProcessor.from_pretrained(
            model_name,
            trust_remote_code=True,
            **kwargs
        )

        # Read the input device back from the loaded weights so that
        # device_map values such as "auto" also work with the .to() calls
        # in the encode methods below.
        self.device = next(self.model.parameters()).device
        self.dims = dims

    def encode_queries(
        self,
        queries: List[str]
    ) -> List[torch.Tensor]:
        """
        Encode a list of text queries.

        Args:
            queries: List of query texts

        Returns:
            List of query embeddings
        """
        query_inputs = self.processor.process_queries(queries)
        query_inputs = {k: v.to(self.device) for k, v in query_inputs.items()}
        with torch.no_grad():
            query_embeddings = self.model(**query_inputs)
        return [q.cpu() for q in query_embeddings]

    def encode_images(
        self,
        images: List[Union[str, Image.Image]]
    ) -> List[torch.Tensor]:
        """
        Encode a list of images.

        Args:
            images: List of image paths or PIL Images

        Returns:
            List of image embeddings
        """
        # Accept both file paths and in-memory PIL images.
        image_objects = []
        for img in images:
            if isinstance(img, str):
                image_objects.append(Image.open(img).convert("RGB"))
            elif isinstance(img, Image.Image):
                image_objects.append(img)
            else:
                raise ValueError(f"Unsupported image type: {type(img)}")

        image_inputs = self.processor.process_images(image_objects)
        image_inputs = {k: v.to(self.device) for k, v in image_inputs.items()}
        with torch.no_grad():
            image_embeddings = self.model(**image_inputs)
        return [emb.cpu() for emb in image_embeddings]
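
    # Illustrative addition, not part of the original script: for large image
    # collections, encoding in fixed-size chunks bounds peak device memory.
    # It simply reuses encode_images above; batch_size is an assumed tuning
    # knob, not a model requirement.
    def encode_images_batched(
        self,
        images: List[Union[str, Image.Image]],
        batch_size: int = 8
    ) -> List[torch.Tensor]:
        embeddings: List[torch.Tensor] = []
        for start in range(0, len(images), batch_size):
            embeddings.extend(self.encode_images(images[start:start + batch_size]))
        return embeddings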

    def compute_scores(
        self,
        query_embeddings: List[torch.Tensor],
        image_embeddings: List[torch.Tensor]
    ) -> torch.Tensor:
        """
        Compute similarity scores between queries and images.

        Args:
            query_embeddings: List of query embeddings
            image_embeddings: List of image embeddings

        Returns:
            Similarity scores matrix of shape (num_queries, num_images)
        """
        return self.processor.score_multi_vector(query_embeddings, image_embeddings)
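
    # For intuition only: score_multi_vector performs ColBERT-style
    # late-interaction (MaxSim) scoring. The sketch below shows that
    # computation for a single query/image pair over per-token embeddings;
    # it is an assumption-level reference, not the processor's actual code.
    @staticmethod
    def _maxsim_reference(
        query_embedding: torch.Tensor,  # (num_query_tokens, dims)
        image_embedding: torch.Tensor,  # (num_image_tokens, dims)
    ) -> torch.Tensor:
        # Dot products between every query token and every image token,
        # then each query token keeps its best-matching image token, and
        # the per-token maxima are summed into one scalar score.
        sim = query_embedding @ image_embedding.T
        return sim.max(dim=1).values.sum()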

    def encode_and_score(
        self,
        queries: List[str],
        images: List[Union[str, Image.Image]]
    ) -> torch.Tensor:
        """
        Convenience method: encode queries and images, then compute scores.

        Args:
            queries: List of query texts
            images: List of images (paths or PIL objects)

        Returns:
            Similarity scores between queries and images
        """
        query_embeddings = self.encode_queries(queries)
        image_embeddings = self.encode_images(images)
        return self.compute_scores(query_embeddings, image_embeddings)
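
    # Note: encode_and_score re-encodes the images on every call. For
    # retrieval over a fixed corpus, encode the images once, cache those
    # embeddings, and run only encode_queries + compute_scores per query.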


# Example usage
if __name__ == "__main__":
    images = [Image.new("RGB", (32, 32), color="white"), Image.new("RGB", (16, 16), color="black")]
    queries = ["Is attention really all you need?", "What is the amount of bananas farmed in Salvador?"]

    embedder = OpsColQwen3Embedder(
        model_name="OpenSearch-AI/Ops-Colqwen3-4B",
        dims=2560,
        dtype=torch.float16,
        attn_implementation="flash_attention_2",
    )

    query_embeddings = embedder.encode_queries(queries)
    image_embeddings = embedder.encode_images(images)
    print(query_embeddings[0].shape, image_embeddings[0].shape)  # e.g. (23, 2560) (18, 2560)

    scores = embedder.compute_scores(query_embeddings, image_embeddings)
    print(f"Scores:\n{scores}")