"""Argus: Region-Aware Query-Conditioned Mixture of Experts for Visual Document Retrieval.

Self-contained processor for Argus-Colqwen3.5-9B. Wraps the Qwen3-VL processor
(image processor + Qwen2 tokenizer + optional video processor) and adds
ColPali-style ``process_images`` / ``process_texts`` / ``score_multi_vector``
helpers.
"""

from __future__ import annotations

from pathlib import Path
from typing import ClassVar, List, Optional, Tuple, Union

import torch
from PIL import Image
from transformers import BatchEncoding, BatchFeature
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
from transformers.models.qwen3_vl import Qwen3VLProcessor


class ArgusProcessor(Qwen3VLProcessor):
    """Processor for Argus-Colqwen3.5-9B.

    Subclasses ``Qwen3VLProcessor`` (the Qwen3.5-9B hub repo ships that
    processor class even though the LLM is Qwen3.5). Adds:

    - ``process_images``: batch-encode PIL images into the exact dict the
      retriever forward expects (``pixel_values``, ``image_grid_thw``,
      ``input_ids``, ``attention_mask``).
    - ``process_texts``: batch-encode query strings.
    - ``process_queries``: ``process_texts`` plus the training-time query
      augmentation suffix.
    - ``score`` / ``score_multi_vector``: MaxSim scoring helper.
    - ``max_num_visual_tokens`` knob (on ``from_pretrained``): caps the
      longest-edge pixel budget per image so long documents don't blow up
      the vision encoder.
    """

    visual_prompt_prefix: ClassVar[str] = (
        "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|><|endoftext|>"
    )
    query_augmentation_token: ClassVar[str] = "<|endoftext|>"
    query_prefix: ClassVar[str] = ""
    image_token: ClassVar[str] = "<|image_pad|>"
    # Number of <|endoftext|> tokens appended to every query — matches the
    # training-time collator (``colpali_novel/data/layout_collator.py``).
    # Removing or changing this number measurably hurts retrieval scores.
    n_query_augmentation_tokens: ClassVar[int] = 10

    def __init__(
        self,
        image_processor=None,
        tokenizer=None,
        video_processor=None,
        chat_template=None,
        **kwargs,
    ):
        # Explicit signature matters for ``ProcessorMixin``: it inspects
        # __init__.__code__ to decide which modality attributes to set. A
        # *args,**kwargs signature silently drops tokenizer/image_processor.
        super().__init__(
            image_processor=image_processor,
            tokenizer=tokenizer,
            video_processor=video_processor,
            chat_template=chat_template,
            **kwargs,
        )
        # Left padding for every batch encoded through this processor.
        if getattr(self, "tokenizer", None) is not None:
            self.tokenizer.padding_side = "left"

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path,
        *args,
        device_map: Optional[str] = None,
        max_num_visual_tokens: Optional[int] = None,
        **kwargs,
    ):
        """Load the processor from a local folder or HF repo id.

        The Qwen3.5-9B hub repo declares ``processor_class=Qwen3VLProcessor``
        but ``tokenizer_class=Qwen2Tokenizer``. The stock ``Qwen3VLProcessor
        .from_pretrained`` returns ``tokenizer=None`` in that case and then
        crashes on ``tokenizer.convert_tokens_to_ids(self.image_token)``. We
        load tokenizer + image processor via the Auto* registry explicitly so
        both are real objects before ``__init__`` runs.

        Args:
            pretrained_model_name_or_path: Local directory or Hub repo id.
            *args, **kwargs: Forwarded unchanged to ``AutoTokenizer``,
                ``AutoImageProcessor`` and ``AutoVideoProcessor``
                ``from_pretrained``.
            device_map: Accepted for signature compatibility; not used.
            max_num_visual_tokens: If given, sets
                ``image_processor.max_pixels`` and
                ``image_processor.size["longest_edge"]`` to
                ``max_num_visual_tokens * (patch_size * merge_size) ** 2``.

        Returns:
            A new ``ArgusProcessor`` instance.

        Raises:
            ValueError: If ``max_num_visual_tokens`` is not positive, or the
                image processor has no ``patch_size`` / ``merge_size``.
        """
        from transformers import AutoImageProcessor, AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
        image_processor = AutoImageProcessor.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)

        # Best effort: if either the import or the load fails, the processor
        # is built without a video processor.
        video_processor = None
        try:
            from transformers import AutoVideoProcessor

            video_processor = AutoVideoProcessor.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
        except Exception:  # noqa: BLE001 — video processing is optional
            video_processor = None

        # Only the filesystem is probed for ``chat_template.jinja``; a Hub repo
        # id will typically not resolve to a local path, leaving this as None.
        chat_template = None
        try:
            candidate = Path(str(pretrained_model_name_or_path)) / "chat_template.jinja"
            if candidate.is_file():
                # Explicit UTF-8 so decoding does not depend on the platform locale.
                chat_template = candidate.read_text(encoding="utf-8")
        except (OSError, ValueError):  # unreadable file / bad path / undecodable bytes
            chat_template = None

        instance = cls(
            image_processor=image_processor,
            tokenizer=tokenizer,
            video_processor=video_processor,
            chat_template=chat_template,
        )

        if max_num_visual_tokens is not None:
            if max_num_visual_tokens <= 0:
                raise ValueError("max_num_visual_tokens must be a positive integer.")
            patch_size = getattr(instance.image_processor, "patch_size", None)
            merge_size = getattr(instance.image_processor, "merge_size", None)
            if patch_size is None or merge_size is None:
                raise ValueError("Argus image processor missing patch_size or merge_size.")
            # The budget assumes one visual token per (patch_size * merge_size)^2
            # pixel tile after spatial merging.
            tile = patch_size * merge_size
            instance.image_processor.max_pixels = max_num_visual_tokens * tile * tile
            instance.image_processor.size["longest_edge"] = instance.image_processor.max_pixels
        return instance

    # ------------------------------------------------------------------ #
    # Encoding
    # ------------------------------------------------------------------ #

    def process_images(self, images: List[Image.Image]) -> Union[BatchFeature, BatchEncoding]:
        """Encode PIL images into the backbone's expected input dict.

        Each image is converted to RGB and paired with ``visual_prompt_prefix``.
        The flat ``pixel_values`` produced by the underlying processor are then
        split per image (using ``image_grid_thw``) and zero-padded into a
        batch-first tensor with ``pad_sequence``.
        """
        images = [img.convert("RGB") for img in images]
        batch_doc = self(
            text=[self.visual_prompt_prefix] * len(images),
            images=images,
            padding="longest",
            return_tensors="pt",
        )
        # Pack pixel_values so the forward can scatter them per image via
        # image_grid_thw offsets. This mirrors the training-time collator.
        # The per-image row count is t * h * w; for still images t == 1, so this
        # equals the h * w product used at training time.
        offsets = batch_doc["image_grid_thw"].prod(dim=-1)
        pixel_values = list(torch.split(batch_doc["pixel_values"], offsets.tolist()))
        batch_doc["pixel_values"] = torch.nn.utils.rnn.pad_sequence(pixel_values, batch_first=True)
        return batch_doc

    def process_texts(
        self,
        texts: List[str],
        max_length: Optional[int] = None,
    ) -> Union[BatchFeature, BatchEncoding]:
        """Encode query strings into the backbone's expected input dict.

        Args:
            texts: Strings to tokenize as-is (no augmentation suffix is added).
            max_length: If given, enables truncation to this many tokens.
        """
        kwargs = {"text": texts, "return_tensors": "pt", "padding": "longest"}
        if max_length is not None:
            kwargs["max_length"] = max_length
            kwargs["truncation"] = True
        return self(**kwargs)

    def process_queries(
        self,
        queries: Optional[List[str]] = None,
        texts: Optional[List[str]] = None,
        max_length: Optional[int] = None,
        suffix: Optional[str] = None,
    ) -> Union[BatchFeature, BatchEncoding]:
        """Encode queries with the training-time augmentation.

        Each query becomes ``query_prefix + query +
        query_augmentation_token * n_query_augmentation_tokens``.

        Mirrors ``colpali_engine.utils.processing_utils.BaseVisualRetrieverProcessor
        .process_queries`` and the Argus training collator. The default 10
        trailing ``<|endoftext|>`` tokens are not optional — without them,
        MaxSim scoring drops several nDCG points because the query has fewer
        active multi-vectors.

        Args:
            queries: Query strings (preferred name).
            texts: Alias for ``queries``; pass at most one of the two.
            max_length: Forwarded to :meth:`process_texts`.
            suffix: Overrides the default augmentation suffix when given.

        Raises:
            ValueError: If both or neither of ``queries`` / ``texts`` are given.
        """
        if texts is not None and queries is not None:
            raise ValueError("Only one of 'texts' or 'queries' should be provided.")
        if queries is None:
            queries = texts
        if queries is None:
            raise ValueError("No queries provided.")
        if suffix is None:
            suffix = self.query_augmentation_token * self.n_query_augmentation_tokens
        wrapped = [self.query_prefix + q + suffix for q in queries]
        return self.process_texts(wrapped, max_length=max_length)

    # ------------------------------------------------------------------ #
    # Scoring
    # ------------------------------------------------------------------ #

    def score(
        self,
        qs: List[torch.Tensor],
        ps: List[torch.Tensor],
        device: Optional[Union[str, torch.device]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Alias for ``score_multi_vector`` (MaxSim over multi-vectors)."""
        return self.score_multi_vector(qs, ps, device=device, **kwargs)

    def score_multi_vector(
        self,
        qs: List[torch.Tensor],
        ps: List[torch.Tensor],
        batch_size: int = 128,
        device: Optional[Union[str, torch.device]] = None,
    ) -> torch.Tensor:
        """Compute an ``[N_q, N_p]`` score matrix via MaxSim (ColBERT scoring).

        For each (q, p) pair the score is ``sum_i max_j <q_i, p_j>``, summing
        over query vectors ``q_i`` and maximizing over document vectors
        ``p_j``. Inputs are the raw (potentially ragged) per-sample
        multi-vector tensors returned by :meth:`encode_queries` /
        :meth:`encode_images`; each is used as a 2-D ``(num_vectors, dim)``
        tensor.

        Padding handling:

        - Query rows that are entirely zero are excluded from the sum.
        - Document positions added here to pad a chunk to a common length are
          masked out of the max. Without this, padded slots would score 0
          and act as a floor whenever all real similarities are negative,
          making a pair's score depend on which documents share its chunk.
        - A document with no vectors scores 0 against every query.

        Args:
            qs: Query multi-vectors.
            ps: Document multi-vectors.
            batch_size: Number of queries, and of documents, per chunk.
            device: Device for the computation and the result; CPU if None.

        Returns:
            Tensor of shape ``(len(qs), len(ps))`` on ``device``.

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        if batch_size <= 0:
            raise ValueError("batch_size must be a positive integer.")
        dev = torch.device(device) if device is not None else torch.device("cpu")
        n_q, n_p = len(qs), len(ps)
        scores = torch.zeros(n_q, n_p, device=dev)

        for qi in range(0, n_q, batch_size):
            q_slice = [t.to(dev) for t in qs[qi : qi + batch_size]]
            q_len = max(t.size(0) for t in q_slice)
            q_pad = torch.zeros(len(q_slice), q_len, q_slice[0].size(-1), device=dev)
            q_mask = torch.zeros(len(q_slice), q_len, device=dev, dtype=torch.bool)
            for i, t in enumerate(q_slice):
                q_pad[i, : t.size(0)] = t
                q_mask[i, : t.size(0)] = t.abs().sum(dim=-1) > 0

            for pi in range(0, n_p, batch_size):
                p_slice = [t.to(dev) for t in ps[pi : pi + batch_size]]
                # At least one column so the max below is defined even if every
                # document in the chunk is empty.
                p_len = max(1, max(t.size(0) for t in p_slice))
                p_pad = torch.zeros(len(p_slice), p_len, p_slice[0].size(-1), device=dev)
                p_valid = torch.zeros(len(p_slice), p_len, device=dev, dtype=torch.bool)
                for j, t in enumerate(p_slice):
                    p_pad[j, : t.size(0)] = t
                    p_valid[j, : t.size(0)] = True

                # sim[q, p, l, k] = <query q token l, doc p token k>
                sim = torch.einsum("qld,pkd->qplk", q_pad, p_pad)
                # Padding introduced above must never win the max.
                sim = sim.masked_fill(~p_valid[None, :, None, :], float("-inf"))
                maxsim = sim.max(dim=-1).values
                # Keep only real query tokens against non-empty documents; this
                # also replaces the -inf produced by empty documents with 0.
                keep = q_mask[:, None, :] & p_valid.any(dim=-1)[None, :, None]
                maxsim = torch.where(keep, maxsim, torch.zeros_like(maxsim)).sum(dim=-1)
                scores[qi : qi + len(q_slice), pi : pi + len(p_slice)] = maxsim
        return scores

    # ------------------------------------------------------------------ #
    # Misc helpers (match colpali-engine BaseVisualRetrieverProcessor API)
    # ------------------------------------------------------------------ #

    def get_n_patches(
        self,
        image_size: Tuple[int, int],
        spatial_merge_size: int,
    ) -> Tuple[int, int]:
        """Return ``(n_patches_x, n_patches_y)`` for an image after resizing.

        ``image_size`` is read as ``(width, height)``. The image is resized
        with ``smart_resize`` using the processor's current
        ``size["shortest_edge"]`` / ``size["longest_edge"]`` pixel bounds. Each
        resulting side is then divided by ``patch_size * spatial_merge_size``.
        """
        patch_size = self.image_processor.patch_size
        height_new, width_new = smart_resize(
            width=image_size[0],
            height=image_size[1],
            factor=patch_size * self.image_processor.merge_size,
            min_pixels=self.image_processor.size["shortest_edge"],
            max_pixels=self.image_processor.size["longest_edge"],
        )
        n_patches_x = width_new // patch_size // spatial_merge_size
        n_patches_y = height_new // patch_size // spatial_merge_size
        return n_patches_x, n_patches_y

    def get_image_mask(self, batch_images: BatchFeature) -> torch.Tensor:
        """Boolean mask of positions in ``input_ids`` equal to ``image_token_id``."""
        return batch_images.input_ids == self.image_token_id


__all__ = ["ArgusProcessor"]