frozenc committed on
Commit
4894b7d
·
verified ·
1 Parent(s): 785e185

update usage

Browse files
README.md CHANGED
@@ -32,33 +32,35 @@ The model is trained using a multi-stage strategy that combines large-scale text
32
 
33
  **Requirements**
34
  ```
 
35
  transformers>=4.57.0
36
  qwen-vl-utils>=0.0.14
37
  torch==2.8.0
38
- colpali_engine==0.3.12
39
  ```
40
 
41
  **Basic Usage**
42
 
43
  ```python
 
44
  from PIL import Image
45
  from scripts.ops_colqwen3_embedder import OpsColQwen3Embedder
46
 
47
  images = [Image.new("RGB", (32, 32), color="white"), Image.new("RGB", (16, 16), color="black")]
48
-
49
  queries = ["Is attention really all you need?", "What is the amount of bananas farmed in Salvador?"]
50
 
51
- encoder = OpsColQwen3Embedder(
52
  model_name="OpenSearch-AI/Ops-Colqwen3-4B",
53
- dims=320,
54
  dtype=torch.float16,
55
  attn_implementation="flash_attention_2",
56
  )
57
 
58
- query_embeddings = encoder.encode_texts(queries, batch_size=2)
59
- image_embeddings = encoder.encode_images(images, batch_size=2)
 
 
 
60
 
61
- scores = encoder.compute_scores(query_embeddings, image_embeddings)
62
  print(f"Scores:\n{scores}")
63
  ```
64
 
 
32
 
33
  **Requirements**
34
  ```
35
+ pillow
36
  transformers>=4.57.0
37
  qwen-vl-utils>=0.0.14
38
  torch==2.8.0
 
39
  ```
40
 
41
  **Basic Usage**
42
 
43
  ```python
44
+ import torch
45
  from PIL import Image
46
  from scripts.ops_colqwen3_embedder import OpsColQwen3Embedder
47
 
48
  images = [Image.new("RGB", (32, 32), color="white"), Image.new("RGB", (16, 16), color="black")]
 
49
  queries = ["Is attention really all you need?", "What is the amount of bananas farmed in Salvador?"]
50
 
51
+ embedder = OpsColQwen3Embedder(
52
  model_name="OpenSearch-AI/Ops-Colqwen3-4B",
53
+ dims=2560,
54
  dtype=torch.float16,
55
  attn_implementation="flash_attention_2",
56
  )
57
 
58
+ query_embeddings = embedder.encode_queries(queries)
59
+ image_embeddings = embedder.encode_images(images)
60
+ print(query_embeddings[0].shape, image_embeddings[0].shape) # (23, 2560) (18, 2560)
61
+
62
+ scores = embedder.compute_scores(query_embeddings, image_embeddings)
63
 
 
64
  print(f"Scores:\n{scores}")
65
  ```
66
 
__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from .configuration_ops_colqwen3 import OpsColQwen3Config
2
+ from .modeling_ops_colqwen3 import OpsColQwen3Model, OpsColQwen3PreTrainedModel
3
+ from .processing_ops_colqwen3 import OpsColQwen3Processor
4
+
5
+ __all__ = [
6
+ "OpsColQwen3Config",
7
+ "OpsColQwen3Model",
8
+ "OpsColQwen3PreTrainedModel",
9
+ "OpsColQwen3Processor",
10
+ ]
config.json CHANGED
@@ -1,10 +1,17 @@
1
  {
2
  "architectures": [
3
- "ColQwen3VLModel"
4
  ],
 
 
 
 
 
 
 
5
  "dtype": "float32",
6
  "image_token_id": 151655,
7
- "model_type": "qwen3_vl",
8
  "text_config": {
9
  "attention_bias": false,
10
  "attention_dropout": 0.0,
 
1
  {
2
  "architectures": [
3
+ "OpsColQwen3Model"
4
  ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_ops_colqwen3.OpsColQwen3Config",
7
+ "AutoModel": "modeling_ops_colqwen3.OpsColQwen3Model",
8
+ "AutoModelForVision2Seq": "modeling_ops_colqwen3.OpsColQwen3Model",
9
+ "AutoProcessor": "processing_ops_colqwen3.OpsColQwen3Processor"
10
+ },
11
+ "dims": 2560,
12
  "dtype": "float32",
13
  "image_token_id": 151655,
14
+ "model_type": "ops_colqwen3",
15
  "text_config": {
16
  "attention_bias": false,
17
  "attention_dropout": 0.0,
configuration_ops_colqwen3.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from transformers import Qwen3VLConfig


class OpsColQwen3Config(Qwen3VLConfig):
    """Configuration for the OpsColQwen3 retrieval model.

    Extends ``Qwen3VLConfig`` with two retrieval-specific fields:

    * ``dims`` — dimensionality of the per-token output embeddings
      (the model truncates its projection to this width when it is
      smaller than the text hidden size).
    * ``mask_non_image_embeddings`` — when True, embeddings for image
      inputs keep only image-token positions.
    """

    model_type = "ops_colqwen3"

    def __init__(
        self,
        dims: int = 2560,
        mask_non_image_embeddings: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Output embedding dimensionality used for truncation in the model.
        self.dims = dims
        # Zero out non-image token embeddings for image batches when True.
        self.mask_non_image_embeddings = mask_non_image_embeddings
modeling_ops_colqwen3.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+ import torch
3
+ from torch import nn
4
+ from transformers import PreTrainedModel
5
+ from transformers.models.qwen3_vl import Qwen3VLModel
6
+ from transformers.utils import logging
7
+
8
+ from .configuration_ops_colqwen3 import OpsColQwen3Config
9
+
10
+ logger = logging.get_logger(__name__)
11
+
12
+
13
class OpsColQwen3PreTrainedModel(PreTrainedModel):
    """Base class wiring OpsColQwen3 into the transformers loading machinery.

    Declares the config class, the weight-name prefix, and the backend
    capabilities (FlashAttention-2, SDPA, cache classes) supported by the
    underlying Qwen3-VL stack.
    """

    config_class = OpsColQwen3Config
    base_model_prefix = "ops_colqwen3"
    supports_gradient_checkpointing = True
    # Modules that `device_map="auto"` must keep on a single device.
    _no_split_modules = ["Qwen3VLVisionBlock", "Qwen3DecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
22
+
23
+
24
class OpsColQwen3Model(OpsColQwen3PreTrainedModel):
    """ColBERT-style multi-vector embedding model built on Qwen3-VL.

    ``forward`` returns one embedding per input token: last hidden states
    are passed through ``custom_text_proj``, optionally truncated to
    ``self.dims`` columns, L2-normalized, and zeroed at padding positions.
    """

    # Remap plain Qwen3-VL checkpoints into the `qwen3vl.` sub-module namespace.
    _checkpoint_conversion_mapping = {
        r"^language_model": r"qwen3vl.language_model",
        r"^visual": "qwen3vl.visual",
    }

    def __init__(self, config: OpsColQwen3Config):
        super().__init__(config)
        self.config = config

        self.qwen3vl = Qwen3VLModel(config)
        # Fix: honour a truncation dimension declared on the config.
        # Previously this was hard-wired to `text_config.hidden_size`, so a
        # config-declared `dims` (e.g. 320) was silently ignored unless the
        # caller also passed `dims=` to `from_pretrained`.
        self.dims = getattr(config, "dims", None) or config.text_config.hidden_size
        # The projection stays full-width (hidden -> hidden) so checkpoint
        # weight shapes are unchanged; truncation happens in forward().
        self.custom_text_proj = nn.Linear(
            config.text_config.hidden_size, config.text_config.hidden_size
        )

        self.mask_non_image_embeddings = config.mask_non_image_embeddings
        self.post_init()

    @classmethod
    def from_pretrained(cls, *args, config: Optional[OpsColQwen3Config] = None, **kwargs):
        """Load weights, applying the checkpoint key remapping and an
        optional ``dims=`` override (kwarg wins over the config value)."""
        key_mapping = kwargs.pop("key_mapping", None)
        if key_mapping is None:
            key_mapping = getattr(cls, "_checkpoint_conversion_mapping", None)
        dims = None
        if 'dims' in kwargs:
            dims = kwargs.pop('dims')
        elif config is not None:
            dims = config.dims

        model = super().from_pretrained(*args, config=config, **kwargs, key_mapping=key_mapping)
        if dims is not None:
            model.dims = dims
        return model

    def forward(self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, pixel_values: Optional[torch.Tensor] = None, image_grid_thw: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
        """Return normalized per-token embeddings of shape (batch, seq, dims)."""
        has_pixel_values = pixel_values is not None

        if has_pixel_values:
            if image_grid_thw is None:
                raise ValueError("`image_grid_thw` must be provided when `pixel_values` is passed.")
            if not torch.is_tensor(image_grid_thw):
                image_grid_thw = torch.as_tensor(image_grid_thw, device=pixel_values.device)

            # The processor right-pads per-image patch sequences; strip that
            # padding and re-concatenate into the flat layout Qwen3-VL expects.
            offsets = image_grid_thw.prod(dim=1)
            unpadded = [pixel_sequence[: int(offset.item())] for pixel_sequence, offset in zip(pixel_values, offsets)]
            pixel_values = torch.cat(unpadded, dim=0) if unpadded else None

        outputs = self.qwen3vl(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            image_grid_thw=image_grid_thw,
            use_cache=False,
            output_hidden_states=True,
            return_dict=True,
        )

        last_hidden_states = outputs.last_hidden_state
        proj = self.custom_text_proj(last_hidden_states)

        # Matryoshka-style truncation to the configured embedding width.
        if self.dims < self.config.text_config.hidden_size:
            proj = proj[..., : self.dims]

        # L2-normalize each token embedding, then zero padded positions.
        proj = proj / proj.norm(dim=-1, keepdim=True)

        if attention_mask is not None:
            proj = proj * attention_mask.unsqueeze(-1)

        # Optionally keep only image-token embeddings for image batches.
        if has_pixel_values and self.mask_non_image_embeddings and input_ids is not None:
            image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
            proj = proj * image_mask

        return proj

    @property
    def patch_size(self) -> int:
        return self.qwen3vl.visual.config.patch_size

    @property
    def spatial_merge_size(self) -> int:
        return self.qwen3vl.visual.config.spatial_merge_size
preprocessor_config.json CHANGED
@@ -1,4 +1,7 @@
1
  {
 
 
 
2
  "crop_size": null,
3
  "data_format": "channels_first",
4
  "default_to_square": true,
 
1
  {
2
+ "auto_map": {
3
+ "AutoProcessor": "processing_ops_colqwen3.OpsColQwen3Processor"
4
+ },
5
  "crop_size": null,
6
  "data_format": "channels_first",
7
  "default_to_square": true,
processing_ops_colqwen3.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ from typing import List, Optional, Union
4
+
5
+ import torch
6
+ from PIL import Image
7
+ from transformers import BatchEncoding, BatchFeature
8
+ from transformers.models.qwen3_vl import Qwen3VLProcessor
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
def get_torch_device(device: str = "auto") -> str:
    """Resolve the torch device string to use.

    With the default ``"auto"``, picks the first available backend in the
    order "cuda:0", "mps" (Apple Silicon), then "cpu", and logs the choice.
    Any explicit device string is returned unchanged.
    """
    if device != "auto":
        return device

    if torch.cuda.is_available():
        resolved = "cuda:0"
    elif torch.backends.mps.is_available():  # Apple Silicon fallback
        resolved = "mps"
    else:
        resolved = "cpu"
    logger.info(f"Using device: {resolved}")

    return resolved
33
+
34
+
35
class OpsColQwen3Processor(Qwen3VLProcessor):
    """
    Processor for the OpsColQwen3 model.

    Wraps the standard Qwen3-VL processor and adds retrieval helpers:
    `process_images` / `process_queries` build model inputs, and
    `score_multi_vector` computes late-interaction (MaxSim) scores.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    # Prompt templates / special tokens used when building model inputs.
    query_prefix: str = "Query: "
    visual_prompt_prefix: str = "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|><|im_start|>assistant\n<|endoftext|>"
    query_augmentation_token: str = "<|endoftext|>"
    image_token: str = "<|image_pad|>"

    def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
        """
        Initialize the processor.

        Args:
            image_processor: Image processor instance
            tokenizer: Tokenizer instance
            chat_template: Optional chat template
            **kwargs: Additional arguments
        """
        super().__init__(image_processor=image_processor, tokenizer=tokenizer, chat_template=chat_template, **kwargs)

        # Left padding keeps real tokens right-aligned across a batch.
        if self.tokenizer is not None:
            self.tokenizer.padding_side = "left"

    def process_images(self, images: List[Image.Image], return_tensors: str = "pt", **kwargs) -> Union[BatchFeature, BatchEncoding]:
        """
        Process a batch of PIL images for the model.
        """
        images = [image.convert("RGB") for image in images]

        batch_doc = self(text=[self.visual_prompt_prefix] * len(images), images=images, padding="longest", return_tensors=return_tensors, **kwargs)

        if batch_doc["pixel_values"].numel() == 0:
            return batch_doc

        # Split the flat patch sequence per image (t*h*w patches each) and
        # right-pad so the batch can be stacked into one tensor; the model
        # undoes this padding in forward().
        offsets = batch_doc["image_grid_thw"].prod(dim=1)
        pixel_values = list(torch.split(batch_doc["pixel_values"], offsets.tolist()))
        batch_doc["pixel_values"] = torch.nn.utils.rnn.pad_sequence(pixel_values, batch_first=True)

        return batch_doc

    def process_queries(self, queries: List[str], return_tensors: str = "pt", **kwargs) -> Union[BatchFeature, BatchEncoding]:
        """
        Process a list of text queries.
        """
        # Prefix each query and append 10 augmentation tokens as in ColPali.
        processed_queries = [self.query_prefix + q + self.query_augmentation_token * 10 for q in queries]
        return self(text=processed_queries, return_tensors=return_tensors, padding="longest", **kwargs)

    @staticmethod
    def score_multi_vector(
        qs: Union[torch.Tensor, List[torch.Tensor]],
        ps: Union[torch.Tensor, List[torch.Tensor]],
        batch_size: int = 128,
        device: Optional[Union[str, torch.device]] = None,
    ) -> torch.Tensor:
        """
        Compute the late-interaction/MaxSim score (ColBERT-like) for the given multi-vector
        query embeddings (`qs`) and passage embeddings (`ps`). For ColPali, a passage is the
        image of a document page.

        Because the embedding tensors are multi-vector and can thus have different shapes, they
        should be fed as:
        (1) a list of tensors, where the i-th tensor is of shape (sequence_length_i, embedding_dim)
        (2) a single tensor of shape (n_passages, max_sequence_length, embedding_dim) -> usually
            obtained by padding the list of tensors.

        Args:
            qs (`Union[torch.Tensor, List[torch.Tensor]`): Query embeddings.
            ps (`Union[torch.Tensor, List[torch.Tensor]`): Passage embeddings.
            batch_size (`int`, *optional*, defaults to 128): Batch size for computing scores.
            device (`Union[str, torch.device]`, *optional*): Device to use for computation. If not
                provided, uses `get_torch_device("auto")`.

        Returns:
            `torch.Tensor`: A tensor of shape `(n_queries, n_passages)` containing the scores. The score
            tensor is saved on the "cpu" device.
        """
        device = device or get_torch_device("auto")

        if len(qs) == 0:
            raise ValueError("No queries provided")
        if len(ps) == 0:
            raise ValueError("No passages provided")

        scores_list: List[torch.Tensor] = []

        # Double-batched MaxSim: for each (query, passage) pair, take the max
        # similarity over passage tokens, then sum over query tokens.
        for i in range(0, len(qs), batch_size):
            scores_batch = []
            qs_batch = torch.nn.utils.rnn.pad_sequence(qs[i : i + batch_size], batch_first=True, padding_value=0).to(device)
            for j in range(0, len(ps), batch_size):
                ps_batch = torch.nn.utils.rnn.pad_sequence(ps[j : j + batch_size], batch_first=True, padding_value=0).to(device)
                # einsum -> (b queries, c passages, n q-tokens, s p-tokens)
                scores_batch.append(torch.einsum("bnd,csd->bcns", qs_batch, ps_batch).max(dim=3)[0].sum(dim=2))
            scores_batch = torch.cat(scores_batch, dim=1).cpu()
            scores_list.append(scores_batch)

        scores = torch.cat(scores_list, dim=0)
        assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}"

        scores = scores.to(torch.float32)
        return scores
scripts/ops_colqwen3_embedder.py CHANGED
@@ -1,338 +1,167 @@
1
- from typing import List, Union, Optional, Tuple
2
  import torch
3
- from torch import nn
4
  from PIL import Image
5
- from tqdm.auto import tqdm
6
- from transformers.models.qwen3_vl import Qwen3VLConfig, Qwen3VLModel, Qwen3VLProcessor
7
- from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor
8
- from transformers import BatchEncoding, BatchFeature
9
- from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
10
-
11
-
12
- class OpsColQwen3(Qwen3VLModel):
13
- """
14
- OpsColQwen3 model implementation for multi-vector document retrieval.
15
- """
16
-
17
- def __init__(self, config: Qwen3VLConfig, dims: int = 320, mask_non_image_embeddings: bool = False):
18
- super().__init__(config=config)
19
- self.custom_text_proj = nn.Linear(self.config.text_config.hidden_size, self.config.text_config.hidden_size)
20
- self.dims = dims
21
- self.padding_side = "left"
22
- self.mask_non_image_embeddings = mask_non_image_embeddings
23
- self.post_init()
24
-
25
- @classmethod
26
- def from_pretrained(cls, *args, **kwargs):
27
- key_mapping = kwargs.pop("key_mapping", None)
28
- if key_mapping is None:
29
- key_mapping = {
30
- r"^base_model\.model\.(.*)": r"\1",
31
- r"^model\.(.*)": r"\1",
32
- }
33
-
34
- return super().from_pretrained(*args, **kwargs, key_mapping=key_mapping)
35
-
36
- def forward(self, *args, **kwargs) -> torch.Tensor:
37
- attention_mask = kwargs.get("attention_mask")
38
- has_pixel_values = "pixel_values" in kwargs and kwargs["pixel_values"] is not None
39
-
40
- if has_pixel_values:
41
- image_grid_thw = kwargs.get("image_grid_thw")
42
- if image_grid_thw is None:
43
- raise ValueError("`image_grid_thw` must be provided when `pixel_values` is passed.")
44
-
45
- if not torch.is_tensor(image_grid_thw):
46
- image_grid_thw = torch.as_tensor(image_grid_thw, device=kwargs["pixel_values"].device)
47
-
48
- offsets = image_grid_thw.prod(dim=1)
49
- unpadded = [pixel_sequence[: int(offset.item())] for pixel_sequence, offset in zip(kwargs["pixel_values"], offsets)]
50
-
51
- if unpadded:
52
- kwargs["pixel_values"] = torch.cat(unpadded, dim=0)
53
- else:
54
- kwargs["pixel_values"] = None
55
-
56
- kwargs.pop("return_dict", True)
57
- kwargs.pop("output_hidden_states", None)
58
- kwargs.pop("use_cache", None)
59
-
60
- last_hidden_states = super().forward(*args, **kwargs, use_cache=False, output_hidden_states=True, return_dict=True).last_hidden_state
61
-
62
- proj = self.custom_text_proj(last_hidden_states)
63
- if self.dims < self.config.text_config.hidden_size:
64
- proj = proj[..., : self.dims]
65
- proj = proj / proj.norm(dim=-1, keepdim=True)
66
-
67
- if attention_mask is not None:
68
- proj = proj * attention_mask.unsqueeze(-1)
69
-
70
- if has_pixel_values and self.mask_non_image_embeddings and kwargs.get("input_ids") is not None:
71
- image_mask = (kwargs["input_ids"] == self.config.image_token_id).unsqueeze(-1)
72
- proj = proj * image_mask
73
-
74
- return proj
75
-
76
- @property
77
- def patch_size(self) -> int:
78
- return self.visual.config.patch_size
79
-
80
- @property
81
- def spatial_merge_size(self) -> int:
82
- return self.visual.config.spatial_merge_size
83
-
84
- @property
85
- def temporal_patch_size(self) -> int:
86
- return getattr(self.visual.config, "temporal_patch_size", 1)
87
-
88
-
89
- class OpsColQwen3Processor(BaseVisualRetrieverProcessor, Qwen3VLProcessor):
90
- """
91
- Processor for OpsColQwen3.
92
- """
93
-
94
- query_prefix: str = "Query: "
95
- visual_prompt_prefix: str = "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|><|im_start|>assistant\n<|endoftext|>"
96
- query_augmentation_token: str = "<|endoftext|>"
97
- image_token: str = "<|image_pad|>"
98
-
99
- def __init__(self, *args, **kwargs) -> None:
100
- super().__init__(*args, **kwargs)
101
- self.tokenizer.padding_side = "left"
102
-
103
- @classmethod
104
- def from_pretrained(cls, *args, device_map: Optional[str] = None, **kwargs):
105
- instance = super().from_pretrained(*args, device_map=device_map, **kwargs)
106
-
107
- if "max_num_visual_tokens" in kwargs:
108
- instance.image_processor.max_pixels = kwargs["max_num_visual_tokens"] * 32 * 32
109
- instance.image_processor.size["longest_edge"] = instance.image_processor.max_pixels
110
-
111
- return instance
112
-
113
- def process_images(self, images: List[Image.Image]) -> Union[BatchFeature, BatchEncoding]:
114
- """Process a batch of PIL images."""
115
- images = [image.convert("RGB") for image in images]
116
-
117
- batch_doc = self.__call__(
118
- text=[self.visual_prompt_prefix] * len(images),
119
- images=images,
120
- padding="longest",
121
- return_tensors="pt",
122
- )
123
-
124
- if batch_doc["pixel_values"].numel() == 0:
125
- return batch_doc
126
-
127
- offsets = batch_doc["image_grid_thw"].prod(dim=1)
128
- pixel_values = list(torch.split(batch_doc["pixel_values"], offsets.tolist()))
129
- batch_doc["pixel_values"] = torch.nn.utils.rnn.pad_sequence(pixel_values, batch_first=True)
130
-
131
- return batch_doc
132
-
133
- def process_texts(self, texts: List[str]) -> Union[BatchFeature, BatchEncoding]:
134
- """Process a list of texts."""
135
- return self(text=texts, return_tensors="pt", padding="longest")
136
-
137
- def score(
138
- self,
139
- qs: Union[torch.Tensor, List[torch.Tensor]],
140
- ps: Union[torch.Tensor, List[torch.Tensor]],
141
- device: Optional[Union[str, torch.device]] = None,
142
- **kwargs,
143
- ) -> torch.Tensor:
144
- """Compute the MaxSim score (ColBERT-like) for query and passage embeddings."""
145
- return self.score_multi_vector(qs, ps, device=device, **kwargs)
146
-
147
- def get_n_patches(
148
- self,
149
- image_size: Tuple[int, int],
150
- spatial_merge_size: int,
151
- ) -> Tuple[int, int]:
152
- """
153
- Compute the number of patches (n_patches_x, n_patches_y) for an image.
154
- """
155
- patch_size = self.image_processor.patch_size
156
- merge_size = getattr(self.image_processor, "merge_size", 1)
157
-
158
- height_new, width_new = smart_resize(
159
- width=image_size[0],
160
- height=image_size[1],
161
- factor=patch_size * merge_size,
162
- min_pixels=self.image_processor.size["shortest_edge"],
163
- max_pixels=self.image_processor.size["longest_edge"],
164
- )
165
-
166
- n_patches_x = width_new // patch_size // spatial_merge_size
167
- n_patches_y = height_new // patch_size // spatial_merge_size
168
-
169
- return n_patches_x, n_patches_y
170
-
171
- def get_image_mask(self, batch_images: BatchFeature) -> torch.Tensor:
172
- """Return a boolean tensor identifying image tokens."""
173
- return batch_images.input_ids == self.image_token_id
174
 
175
 
176
  class OpsColQwen3Embedder:
177
  """
178
- Simple embedder wrapper for OpsColQwen3 model.
179
-
180
- Args:
181
- model_name: HuggingFace model name or local path
182
- dims: Embedding dimension after projection
183
- device: Device to run the model on
184
- attn_implementation: Attention implementation
185
  """
186
 
187
  def __init__(
188
  self,
189
- model_name: str = "OpenSearch-AI/Ops-ColQwen3-4B",
190
  dims: int = 2560,
191
  device: Optional[str] = None,
192
- attn_implementation: Optional[str] = None,
193
- **kwargs,
194
  ):
195
- self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
196
- self.dims = dims
197
 
198
- if attn_implementation is None:
199
- try:
200
- from transformers.utils.import_utils import is_flash_attn_2_available
 
 
 
201
 
202
- attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else None
203
- except ImportError:
204
- attn_implementation = None
 
 
 
 
 
 
 
205
 
206
- load_kwargs = {"dims": dims, "device_map": self.device, **kwargs}
207
- if attn_implementation:
208
- load_kwargs["attn_implementation"] = attn_implementation
209
 
210
- self.model = OpsColQwen3.from_pretrained(model_name, **load_kwargs)
 
 
 
 
 
 
 
211
  self.model.eval()
212
 
213
- self.processor = OpsColQwen3Processor.from_pretrained(model_name)
 
 
 
 
 
 
 
214
 
215
- def encode_texts(
216
  self,
217
- texts: List[str],
218
- batch_size: int = 32,
219
- show_progress: bool = False,
220
  ) -> List[torch.Tensor]:
221
  """
222
  Encode a list of text queries.
223
 
224
  Args:
225
- texts: List of text strings to encode
226
- batch_size: Batch size for processing
227
- show_progress: Whether to show progress bar
228
 
229
  Returns:
230
- List of embedding tensors
231
  """
232
- all_embeddings = []
233
-
234
- iterator = range(0, len(texts), batch_size)
235
- if show_progress:
236
- iterator = tqdm(iterator, desc="Encoding texts")
237
 
238
  with torch.no_grad():
239
- for i in iterator:
240
- batch_texts = texts[i : i + batch_size]
241
 
242
- batch_texts = [self.processor.query_prefix + t + self.processor.query_augmentation_token * 10 for t in batch_texts]
243
-
244
- inputs = self.processor.process_texts(batch_texts)
245
- inputs = {k: v.to(self.device) for k, v in inputs.items()}
246
-
247
- embeddings = self.model(**inputs)
248
- all_embeddings.extend(embeddings.cpu().to(torch.float32))
249
-
250
- return all_embeddings
251
 
252
  def encode_images(
253
  self,
254
- images: List[Union[str, Image.Image]],
255
- batch_size: int = 32,
256
- show_progress: bool = False,
257
  ) -> List[torch.Tensor]:
258
  """
259
  Encode a list of images.
260
 
261
  Args:
262
  images: List of image paths or PIL Images
263
- batch_size: Batch size for processing
264
- show_progress: Whether to show progress bar
265
 
266
  Returns:
267
- List of embedding tensors
268
  """
269
- image_list = []
270
  for img in images:
271
  if isinstance(img, str):
272
- image_list.append(Image.open(img).convert("RGB"))
273
  elif isinstance(img, Image.Image):
274
- image_list.append(img.convert("RGB"))
275
  else:
276
  raise ValueError(f"Unsupported image type: {type(img)}")
277
 
278
- all_embeddings = []
279
-
280
- iterator = range(0, len(image_list), batch_size)
281
- if show_progress:
282
- iterator = tqdm(iterator, desc="Encoding images")
283
 
284
  with torch.no_grad():
285
- for i in iterator:
286
- batch_images = image_list[i : i + batch_size]
287
-
288
- inputs = self.processor.process_images(batch_images)
289
- inputs = {k: v.to(self.device) for k, v in inputs.items()}
290
 
291
- embeddings = self.model(**inputs)
292
- all_embeddings.extend(embeddings.cpu().to(torch.float32))
293
-
294
- return all_embeddings
295
 
296
  def compute_scores(
297
  self,
298
  query_embeddings: List[torch.Tensor],
299
- image_embeddings: List[torch.Tensor],
300
- batch_size: int = 128,
301
  ) -> torch.Tensor:
302
  """
303
- Compute relevance scores between queries and images using MaxSim.
304
 
305
  Args:
306
- query_embeddings: List of query embedding tensors
307
- image_embeddings: List of image embedding tensors
308
- batch_size: Batch size for score computation
309
 
310
  Returns:
311
- Score matrix of shape (num_queries, num_images)
312
  """
313
- return self.processor.score_multi_vector(
314
- query_embeddings,
315
- image_embeddings,
316
- batch_size=batch_size,
317
- device=self.device,
318
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319
 
320
 
321
  # Example usage
322
  if __name__ == "__main__":
323
  images = [Image.new("RGB", (32, 32), color="white"), Image.new("RGB", (16, 16), color="black")]
324
-
325
  queries = ["Is attention really all you need?", "What is the amount of bananas farmed in Salvador?"]
326
 
327
- encoder = OpsColQwen3Embedder(
328
  model_name="OpenSearch-AI/Ops-Colqwen3-4B",
329
- dims=320,
330
  dtype=torch.float16,
331
  attn_implementation="flash_attention_2",
332
  )
333
 
334
- query_embeddings = encoder.encode_texts(queries, batch_size=2)
335
- image_embeddings = encoder.encode_images(images, batch_size=2)
 
 
 
336
 
337
- scores = encoder.compute_scores(query_embeddings, image_embeddings)
338
- print(f"Scores:\n{scores}")
 
 
1
  import torch
 
2
  from PIL import Image
3
+ from transformers import AutoModel, AutoProcessor
4
+ from typing import List, Union, Optional
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
 
7
class OpsColQwen3Embedder:
    """
    Embedder for OpsColQwen3-4B model.

    Thin convenience wrapper: loads the model/processor via the Auto classes
    (with `trust_remote_code`) and exposes query/image encoding plus
    late-interaction scoring.
    """

    def __init__(
        self,
        model_name: str = "OpenSearch-AI/Ops-Colqwen3-4B",
        dims: int = 2560,
        device: Optional[str] = None,
        **kwargs
    ):
        """
        Initialize the embedder.

        Args:
            model_name: Model path or hub name
            dims: Embedding dimensions
            device: Device to use for inference ('mps', 'cuda', or 'cpu')
            **kwargs: Additional arguments passed to from_pretrained
        """
        # Device resolution: explicit device_map kwarg wins, then the
        # `device` argument, then CUDA -> MPS -> CPU auto-detection.
        target = kwargs.pop('device_map', None)
        if not target:
            if device:
                target = device
            elif torch.cuda.is_available():
                target = "cuda"
            elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
                target = "mps"  # Apple Silicon
            else:
                target = "cpu"

        # Half precision on accelerators, float32 on CPU, unless overridden.
        model_dtype = kwargs.pop('dtype', torch.float16 if target != "cpu" else torch.float32)

        self.model = AutoModel.from_pretrained(
            model_name,
            dims=dims,
            trust_remote_code=True,
            dtype=model_dtype,
            device_map=target,
            **kwargs
        )
        self.model.eval()

        self.processor = AutoProcessor.from_pretrained(
            model_name,
            trust_remote_code=True,
            **kwargs
        )

        self.device = target
        self.dims = dims

    def encode_queries(
        self,
        queries: List[str]
    ) -> List[torch.Tensor]:
        """
        Encode a list of text queries.

        Args:
            queries: List of query texts

        Returns:
            List of query embeddings
        """
        batch = self.processor.process_queries(queries)
        batch = {name: tensor.to(self.device) for name, tensor in batch.items()}

        with torch.no_grad():
            embeddings = self.model(**batch)

        return [embedding.cpu() for embedding in embeddings]

    def encode_images(
        self,
        images: List[Union[str, Image.Image]]
    ) -> List[torch.Tensor]:
        """
        Encode a list of images.

        Args:
            images: List of image paths or PIL Images

        Returns:
            List of image embeddings
        """
        # Accept both file paths and already-loaded PIL images.
        loaded = []
        for img in images:
            if isinstance(img, str):
                loaded.append(Image.open(img).convert("RGB"))
            elif isinstance(img, Image.Image):
                loaded.append(img)
            else:
                raise ValueError(f"Unsupported image type: {type(img)}")

        batch = self.processor.process_images(loaded)
        batch = {name: tensor.to(self.device) for name, tensor in batch.items()}

        with torch.no_grad():
            embeddings = self.model(**batch)

        return [embedding.cpu() for embedding in embeddings]

    def compute_scores(
        self,
        query_embeddings: List[torch.Tensor],
        image_embeddings: List[torch.Tensor]
    ) -> torch.Tensor:
        """
        Compute similarity scores between queries and images.

        Args:
            query_embeddings: List of query embeddings
            image_embeddings: List of image embeddings

        Returns:
            Similarity scores matrix
        """
        return self.processor.score_multi_vector(query_embeddings, image_embeddings)

    def encode_and_score(
        self,
        queries: List[str],
        images: List[Union[str, Image.Image]]
    ):
        """
        Convenience method to encode queries and images and compute scores.

        Args:
            queries: List of query texts
            images: List of images (paths or PIL objects)

        Returns:
            Similarity scores between queries and images
        """
        return self.compute_scores(self.encode_queries(queries), self.encode_images(images))
147
 
148
 
149
# Example usage
if __name__ == "__main__":
    # Tiny synthetic images so the demo runs without any files on disk.
    images = [Image.new("RGB", (32, 32), color="white"), Image.new("RGB", (16, 16), color="black")]
    queries = ["Is attention really all you need?", "What is the amount of bananas farmed in Salvador?"]

    embedder = OpsColQwen3Embedder(
        model_name="OpenSearch-AI/Ops-Colqwen3-4B",
        dims=2560,
        dtype=torch.float16,
        attn_implementation="flash_attention_2",  # requires flash-attn; omit on unsupported hardware
    )

    # Each result is a list of (num_tokens, dims) tensors, one per input.
    query_embeddings = embedder.encode_queries(queries)
    image_embeddings = embedder.encode_images(images)
    print(query_embeddings[0].shape, image_embeddings[0].shape) # (23, 2560) (18, 2560)

    scores = embedder.compute_scores(query_embeddings, image_embeddings)

    print(f"Scores:\n{scores}")