| """Argus: Region-Aware Query-Conditioned Mixture of Experts for Visual Document Retrieval. |
| |
| Self-contained processor for Argus-Colqwen3.5-9B. Wraps the Qwen3-VL processor |
| (image processor + Qwen2 tokenizer + optional video processor) and adds ColPali- |
| style ``process_images`` / ``process_texts`` / ``score_multi_vector`` helpers. |
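
Example (illustrative sketch; the model path is a placeholder and ``q_embs`` /
``p_embs`` stand in for multi-vector embeddings produced by the retriever)::

    from PIL import Image

    processor = ArgusProcessor.from_pretrained("path/to/argus-colqwen3.5-9b")
    batch_images = processor.process_images([Image.open("page.png")])
    batch_queries = processor.process_queries(["total revenue in 2023"])
    scores = processor.score_multi_vector(q_embs, p_embs)  # (n_queries, n_pages)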
| """ |
| from __future__ import annotations |
|
|
| from pathlib import Path |
| from typing import ClassVar, List, Optional, Tuple, Union |
|
|
| import torch |
| from PIL import Image |
| from transformers import BatchEncoding, BatchFeature |
| from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize |
| from transformers.models.qwen3_vl import Qwen3VLProcessor |
|
|
|
|


class ArgusProcessor(Qwen3VLProcessor):
    """Processor for Argus-Colqwen3.5-9B.

    Subclasses ``Qwen3VLProcessor`` (the Qwen3.5-9B hub repo ships that
    processor class even though the underlying LLM is Qwen3.5, not Qwen3-VL).
    Adds:

    - ``process_images``: batch-encode PIL images into the exact dict the
      retriever forward expects (``pixel_values``, ``image_grid_thw``,
      ``input_ids``, ``attention_mask``).
    - ``process_texts`` / ``process_queries``: batch-encode query strings.
    - ``score`` / ``score_multi_vector``: MaxSim scoring helpers.
    - ``max_num_visual_tokens`` knob: caps the per-image pixel budget, and
      with it the visual-token count, so long documents don't blow up the
      vision encoder.
    """

    visual_prompt_prefix: ClassVar[str] = (
        "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|><|endoftext|>"
    )
    query_augmentation_token: ClassVar[str] = "<|endoftext|>"
    query_prefix: ClassVar[str] = ""
    image_token: ClassVar[str] = "<|image_pad|>"
    n_query_augmentation_tokens: ClassVar[int] = 10

    def __init__(
        self,
        image_processor=None,
        tokenizer=None,
        video_processor=None,
        chat_template=None,
        **kwargs,
    ):
        super().__init__(
            image_processor=image_processor,
            tokenizer=tokenizer,
            video_processor=video_processor,
            chat_template=chat_template,
            **kwargs,
        )
        if getattr(self, "tokenizer", None) is not None:
            # Left-pad so each sample's real tokens end at the same position.
            self.tokenizer.padding_side = "left"

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path,
        *args,
        device_map: Optional[str] = None,  # accepted for API parity; unused here
        max_num_visual_tokens: Optional[int] = None,
        **kwargs,
    ):
        """Load the processor from a local folder or HF repo id.

        The Qwen3.5-9B hub repo declares ``processor_class=Qwen3VLProcessor``
        but ``tokenizer_class=Qwen2Tokenizer``. In that case the stock
        ``Qwen3VLProcessor.from_pretrained`` returns ``tokenizer=None`` and
        then crashes on ``tokenizer.convert_tokens_to_ids(self.image_token)``.
        We load the tokenizer and image processor explicitly via the Auto*
        registry so both are real objects before ``__init__`` runs.
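
        Example (illustrative; the path and token budget are placeholders)::

            proc = ArgusProcessor.from_pretrained(
                "path/to/argus-colqwen3.5-9b",
                max_num_visual_tokens=768,
            )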
| """ |
        from transformers import AutoImageProcessor, AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
        image_processor = AutoImageProcessor.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)

        # The video processor is optional: older transformers releases do not
        # expose AutoVideoProcessor, and the repo may not ship a video config.
        video_processor = None
        try:
            from transformers import AutoVideoProcessor

            video_processor = AutoVideoProcessor.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
        except Exception:
            video_processor = None

        # Prefer a repo-local chat template when one ships with the weights.
        chat_template = None
        try:
            candidate = Path(str(pretrained_model_name_or_path)) / "chat_template.jinja"
            if candidate.is_file():
                chat_template = candidate.read_text()
        except Exception:
            chat_template = None

        instance = cls(
            image_processor=image_processor,
            tokenizer=tokenizer,
            video_processor=video_processor,
            chat_template=chat_template,
        )

        if max_num_visual_tokens is not None:
            patch_size = getattr(instance.image_processor, "patch_size", None)
            merge_size = getattr(instance.image_processor, "merge_size", None)
            if patch_size is None or merge_size is None:
                raise ValueError("Argus image processor missing patch_size or merge_size.")
            # One merged visual token covers a (patch_size * merge_size)^2 pixel
            # tile, so a cap of tokens * tile^2 total pixels caps the token count.
            tile = patch_size * merge_size
            instance.image_processor.max_pixels = max_num_visual_tokens * tile * tile
            instance.image_processor.size["longest_edge"] = instance.image_processor.max_pixels
        return instance

    def process_images(self, images: List[Image.Image]) -> Union[BatchFeature, BatchEncoding]:
        """Encode PIL images into the backbone's expected input dict."""
        images = [img.convert("RGB") for img in images]
        batch_doc = self(
            text=[self.visual_prompt_prefix] * len(images),
            images=images,
            padding="longest",
            return_tensors="pt",
        )
        # The image processor flattens all images' patches into one tensor.
        # Split it back per image (h * w patches each; t == 1 for stills) and
        # zero-pad into a rectangular (batch, max_patches, dim) tensor.
        offsets = batch_doc["image_grid_thw"][:, 1] * batch_doc["image_grid_thw"][:, 2]
        pixel_values = list(torch.split(batch_doc["pixel_values"], offsets.tolist()))
        batch_doc["pixel_values"] = torch.nn.utils.rnn.pad_sequence(pixel_values, batch_first=True)
        return batch_doc
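
    # Illustrative shapes (hypothetical sizes), e.g. for two page images:
    #     feats = processor.process_images(pages)
    #     feats["pixel_values"].shape   # (2, max_patches, patch_dim)
    #     feats["image_grid_thw"]       # (2, 3) rows of (t, h, w), with t == 1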

    def process_texts(
        self,
        texts: List[str],
        max_length: Optional[int] = None,
    ) -> Union[BatchFeature, BatchEncoding]:
        """Encode query strings into the backbone's expected input dict."""
        kwargs = {"text": texts, "return_tensors": "pt", "padding": "longest"}
        if max_length is not None:
            kwargs["max_length"] = max_length
            kwargs["truncation"] = True
        return self(**kwargs)

    def process_queries(
        self,
        queries: Optional[List[str]] = None,
        texts: Optional[List[str]] = None,
        max_length: Optional[int] = None,
        suffix: Optional[str] = None,
    ) -> Union[BatchFeature, BatchEncoding]:
        """Encode queries with the training-time augmentation:
        ``query_prefix + query + query_augmentation_token * n_query_augmentation_tokens``.

        Mirrors ``colpali_engine.utils.processing_utils.BaseVisualRetrieverProcessor
        .process_queries`` and the Argus training collator. The default 10 trailing
        ``<|endoftext|>`` tokens are not optional: without them, MaxSim scoring
        drops several nDCG points because the query contributes fewer active
        multi-vectors.
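
        Example (illustrative)::

            batch = processor.process_queries(["what is the 2023 revenue?"])
            # encodes "what is the 2023 revenue?" + "<|endoftext|>" * 10
            # (query_prefix is empty for Argus)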
| """ |
        if texts is not None and queries is not None:
            raise ValueError("Only one of 'texts' or 'queries' should be provided.")
        if queries is None:
            queries = texts
        if queries is None:
            raise ValueError("No queries provided.")

        if suffix is None:
            suffix = self.query_augmentation_token * self.n_query_augmentation_tokens

        wrapped = [self.query_prefix + q + suffix for q in queries]
        return self.process_texts(wrapped, max_length=max_length)

    def score(
        self,
        qs: List[torch.Tensor],
        ps: List[torch.Tensor],
        device: Optional[Union[str, torch.device]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Alias for ``score_multi_vector`` (MaxSim over multi-vectors)."""
        return self.score_multi_vector(qs, ps, device=device, **kwargs)

    def score_multi_vector(
        self,
        qs: List[torch.Tensor],
        ps: List[torch.Tensor],
        batch_size: int = 128,
        device: Optional[Union[str, torch.device]] = None,
    ) -> torch.Tensor:
        """Compute an [N_q, N_p] score matrix via MaxSim (ColBERT scoring).

        For each (q, p) pair: ``sum_i max_j <q_i, p_j>``, i.e. every query
        vector is matched to its best passage vector and the maxima are summed.
        Inputs are the raw (potentially ragged) per-sample multi-vector tensors
        returned by :meth:`encode_queries` / :meth:`encode_images`.
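
        Example (illustrative tensor sizes)::

            qs = [torch.randn(20, 128), torch.randn(23, 128)]  # 2 queries
            ps = [torch.randn(700, 128) for _ in range(3)]     # 3 pages
            scores = processor.score_multi_vector(qs, ps)      # shape (2, 3)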
| """ |
        dev = torch.device(device) if device is not None else torch.device("cpu")
        n_q, n_p = len(qs), len(ps)
        scores = torch.zeros(n_q, n_p, device=dev)

        for qi in range(0, n_q, batch_size):
            # Pad this query slice to a rectangular (B, L, D) tensor and build
            # a mask that drops both padding and all-zero query vectors.
            q_slice = qs[qi : qi + batch_size]
            q_len = max(x.size(0) for x in q_slice)
            q_pad = torch.zeros(len(q_slice), q_len, q_slice[0].size(-1), device=dev)
            q_mask = torch.zeros(len(q_slice), q_len, device=dev, dtype=torch.bool)
            for i, t in enumerate(q_slice):
                q_pad[i, : t.size(0)] = t.to(dev)
                q_mask[i, : t.size(0)] = t.abs().sum(dim=-1) > 0

            for pi in range(0, n_p, batch_size):
                # Pad passages to a rectangular (B, K, D) tensor; pad rows are
                # zero vectors.
                p_slice = ps[pi : pi + batch_size]
                p_len = max(x.size(0) for x in p_slice)
                p_pad = torch.zeros(len(p_slice), p_len, p_slice[0].size(-1), device=dev)
                for j, t in enumerate(p_slice):
                    p_pad[j, : t.size(0)] = t.to(dev)

                # (Q, L, D) x (P, K, D) -> (Q, P, L, K) token similarities;
                # max over passage tokens, masked sum over query tokens.
                sim = torch.einsum("qld,pkd->qplk", q_pad, p_pad)
                maxsim = sim.max(dim=-1).values
                maxsim = (maxsim * q_mask.unsqueeze(1).to(maxsim.dtype)).sum(dim=-1)
                scores[qi : qi + len(q_slice), pi : pi + len(p_slice)] = maxsim

        return scores

    def get_n_patches(
        self,
        image_size: Tuple[int, int],
        spatial_merge_size: int,
    ) -> Tuple[int, int]:
        """Return the merged patch grid (x, y) for a ``(width, height)`` image."""
        patch_size = self.image_processor.patch_size
        height_new, width_new = smart_resize(
            width=image_size[0],
            height=image_size[1],
            factor=patch_size * self.image_processor.merge_size,
            min_pixels=self.image_processor.size["shortest_edge"],
            max_pixels=self.image_processor.size["longest_edge"],
        )
        n_patches_x = width_new // patch_size // spatial_merge_size
        n_patches_y = height_new // patch_size // spatial_merge_size
        return n_patches_x, n_patches_y
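
    # Illustrative arithmetic (hypothetical sizes): with patch_size 14 and
    # spatial_merge_size 2, a 1036x784 input maps to (1036 // 28, 784 // 28)
    # = (37, 28) merged patches, subject to smart_resize's pixel-budget rounding.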

    def get_image_mask(self, batch_images: BatchFeature) -> torch.Tensor:
        """Return a boolean mask of image-token positions in ``input_ids``."""
        return batch_images.input_ids == self.image_token_id


__all__ = ["ArgusProcessor"]