from typing import List, Optional, Union

import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor


class OpsColQwen3Embedder:
"""
Embedder for OpsColQwen3-4B model.
"""
def __init__(
self,
model_name: str = "OpenSearch-AI/Ops-Colqwen3-4B",
dims: int = 2560,
device: Optional[str] = None,
**kwargs
):
"""
Initialize the embedder.
Args:
model_name: Model path or hub name
dims: Embedding dimensions
device: Device to use for inference ('mps', 'cuda', or 'cpu')
**kwargs: Additional arguments passed to from_pretrained
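
        Example (illustrative; downloads the checkpoint on first use):
            >>> embedder = OpsColQwen3Embedder(device="cpu")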
"""
device_map = kwargs.pop('device_map', None)
if not device_map:
if device:
device_map = device
elif torch.cuda.is_available():
device_map = "cuda"
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
device_map = "mps" # Use MPS for Apple Silicon
else:
device_map = "cpu"
dtype = kwargs.pop('dtype', torch.float16 if device_map != "cpu" else torch.float32)
self.model = AutoModel.from_pretrained(
model_name,
dims=dims,
trust_remote_code=True,
dtype=dtype,
device_map=device_map,
**kwargs
)
self.model.eval()
self.processor = AutoProcessor.from_pretrained(
model_name,
trust_remote_code=True,
**kwargs
)
self.device = device_map
self.dims = dims

    def encode_queries(
self,
queries: List[str]
) -> List[torch.Tensor]:
"""
Encode a list of text queries.
Args:
queries: List of query texts
Returns:
List of query embeddings
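
        Example (illustrative; the per-query shape is (num_query_tokens, dims)):
            >>> embs = embedder.encode_queries(["what is late interaction?"])
            >>> embs[0].ndim
            2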
"""
query_inputs = self.processor.process_queries(queries)
query_inputs = {k: v.to(self.device) for k, v in query_inputs.items()}
with torch.no_grad():
query_embeddings = self.model(**query_inputs)
return [q.cpu() for q in query_embeddings]

    def encode_images(
self,
images: List[Union[str, Image.Image]]
) -> List[torch.Tensor]:
"""
Encode a list of images.
Args:
images: List of image paths or PIL Images
Returns:
List of image embeddings
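
        Example (illustrative; paths and PIL images may be mixed):
            >>> embs = embedder.encode_images([Image.new("RGB", (32, 32))])
            >>> embs[0].ndim  # (num_image_patches, dims) per image
            2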
"""
image_objects = []
for img in images:
if isinstance(img, str):
image_objects.append(Image.open(img).convert("RGB"))
elif isinstance(img, Image.Image):
image_objects.append(img)
else:
raise ValueError(f"Unsupported image type: {type(img)}")
image_inputs = self.processor.process_images(image_objects)
image_inputs = {k: v.to(self.device) for k, v in image_inputs.items()}
with torch.no_grad():
image_embeddings = self.model(**image_inputs)
return [i.cpu() for i in image_embeddings]

    def compute_scores(
self,
query_embeddings: List[torch.Tensor],
image_embeddings: List[torch.Tensor]
) -> torch.Tensor:
"""
Compute similarity scores between queries and images.
Args:
query_embeddings: List of query embeddings
image_embeddings: List of image embeddings
Returns:
Similarity scores matrix
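
        Example (illustrative; `q_embs`/`img_embs` come from encode_queries /
        encode_images, and the result has shape (len(queries), len(images))):
            >>> scores = embedder.compute_scores(q_embs, img_embs)
            >>> scores[0].argmax()  # index of the best image for query 0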
"""
return self.processor.score_multi_vector(query_embeddings, image_embeddings)

    def encode_and_score(
self,
queries: List[str],
images: List[Union[str, Image.Image]]
):
"""
Convenience method to encode queries and images and compute scores.
Args:
queries: List of query texts
images: List of images (paths or PIL objects)
Returns:
Similarity scores between queries and images
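
        Example (illustrative; "page.png" is a hypothetical path):
            >>> scores = embedder.encode_and_score(
            ...     ["what does the chart show?"], ["page.png"]
            ... )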
"""
query_embeddings = self.encode_queries(queries)
image_embeddings = self.encode_images(images)
return self.compute_scores(query_embeddings, image_embeddings)


# Example usage
if __name__ == "__main__":
images = [Image.new("RGB", (32, 32), color="white"), Image.new("RGB", (16, 16), color="black")]
queries = ["Is attention really all you need?", "What is the amount of bananas farmed in Salvador?"]
embedder = OpsColQwen3Embedder(
model_name="OpenSearch-AI/Ops-Colqwen3-4B",
dims=2560,
dtype=torch.float16,
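        # Note: flash_attention_2 requires a CUDA GPU with the flash-attn
        # package installed; drop this kwarg (or use "sdpa") on CPU/MPS.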
attn_implementation="flash_attention_2",
)
query_embeddings = embedder.encode_queries(queries)
image_embeddings = embedder.encode_images(images)
print(query_embeddings[0].shape, image_embeddings[0].shape) # (23, 2560) (18, 2560)
scores = embedder.compute_scores(query_embeddings, image_embeddings)
print(f"Scores:\n{scores}") |