"""Argus: Region-Aware Query-Conditioned Mixture of Experts for Visual Document Retrieval.
Self-contained processor for Argus-Colqwen3.5-9B. Wraps the Qwen3-VL processor
(image processor + Qwen2 tokenizer + optional video processor) and adds ColPali-
style ``process_images`` / ``process_texts`` / ``score_multi_vector`` helpers.
"""
from __future__ import annotations
from pathlib import Path
from typing import ClassVar, List, Optional, Tuple, Union
import torch
from PIL import Image
from transformers import BatchEncoding, BatchFeature
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
from transformers.models.qwen3_vl import Qwen3VLProcessor
class ArgusProcessor(Qwen3VLProcessor):
"""Processor for Argus-Colqwen3.5-9B.
    Subclasses ``Qwen3VLProcessor`` (the processor class declared by the
    Qwen3.5-9B hub repo, even though the underlying LLM is Qwen3.5). Adds:
- ``process_images``: batch-encode PIL images into the exact dict the
retriever forward expects (``pixel_values``, ``image_grid_thw``,
``input_ids``, ``attention_mask``).
- ``process_texts``: batch-encode query strings.
- ``score`` / ``score_multi_vector``: MaxSim scoring helper.
- ``max_num_visual_tokens`` knob: caps the longest-edge pixel budget per
image so long documents don't blow up the vision encoder.
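
    Typical use (a sketch; the repo id below is a placeholder and
    ``pil_pages`` is assumed to be a list of ``PIL.Image`` pages)::

        processor = ArgusProcessor.from_pretrained("<org>/Argus-Colqwen3.5-9B")
        batch_images = processor.process_images(pil_pages)
        batch_queries = processor.process_queries(["total revenue in 2023"])
        # Run the retriever on both batches to get per-token embeddings, then:
        # scores = processor.score_multi_vector(query_embeddings, image_embeddings)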
"""
visual_prompt_prefix: ClassVar[str] = (
"<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|><|endoftext|>"
)
query_augmentation_token: ClassVar[str] = "<|endoftext|>"
query_prefix: ClassVar[str] = ""
image_token: ClassVar[str] = "<|image_pad|>"
# Number of <|endoftext|> tokens appended to every query — matches the
# training-time collator (``colpali_novel/data/layout_collator.py``).
# Removing or changing this number measurably hurts retrieval scores.
n_query_augmentation_tokens: ClassVar[int] = 10
def __init__(
self,
image_processor=None,
tokenizer=None,
video_processor=None,
chat_template=None,
**kwargs,
):
# Explicit signature matters for ``ProcessorMixin``: it inspects
# __init__.__code__ to decide which modality attributes to set. A
# *args,**kwargs signature silently drops tokenizer/image_processor.
super().__init__(
image_processor=image_processor,
tokenizer=tokenizer,
video_processor=video_processor,
chat_template=chat_template,
**kwargs,
)
if getattr(self, "tokenizer", None) is not None:
self.tokenizer.padding_side = "left"
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path,
*args,
device_map: Optional[str] = None,
max_num_visual_tokens: Optional[int] = None,
**kwargs,
):
"""Load the processor from a local folder or HF repo id.
The Qwen3.5-9B hub repo declares ``processor_class=Qwen3VLProcessor``
but ``tokenizer_class=Qwen2Tokenizer``. The stock ``Qwen3VLProcessor
.from_pretrained`` returns ``tokenizer=None`` in that case and then
crashes on ``tokenizer.convert_tokens_to_ids(self.image_token)``.
We load tokenizer + image processor via the Auto* registry
explicitly so both are real objects before ``__init__`` runs.
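
        ``device_map`` is accepted (and swallowed) here so that callers which
        pass it for the model do not accidentally forward it to the tokenizer
        or image-processor loaders. ``max_num_visual_tokens``, when given, is
        converted into a per-image pixel budget on the image processor; the
        arithmetic is spelled out in the inline comment below.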
"""
from transformers import AutoImageProcessor, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
image_processor = AutoImageProcessor.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
video_processor = None
try:
from transformers import AutoVideoProcessor
video_processor = AutoVideoProcessor.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
except Exception: # noqa: BLE001 — video processing is optional
video_processor = None
chat_template = None
try:
candidate = Path(str(pretrained_model_name_or_path)) / "chat_template.jinja"
if candidate.is_file():
chat_template = candidate.read_text()
except Exception: # noqa: BLE001
chat_template = None
instance = cls(
image_processor=image_processor,
tokenizer=tokenizer,
video_processor=video_processor,
chat_template=chat_template,
)
if max_num_visual_tokens is not None:
patch_size = getattr(instance.image_processor, "patch_size", None)
merge_size = getattr(instance.image_processor, "merge_size", None)
if patch_size is None or merge_size is None:
raise ValueError("Argus image processor missing patch_size or merge_size.")
tile = patch_size * merge_size
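            # One merged visual token covers a tile of
            # (patch_size * merge_size) x (patch_size * merge_size) pixels, so
            # capping max_pixels at max_num_visual_tokens * tile**2 bounds the
            # number of post-merge visual tokens per image. The Qwen-VL image
            # processors keep this pixel budget under size["longest_edge"]
            # (the value is an area, not an edge length); ``get_n_patches``
            # below reads it back as ``max_pixels``.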
instance.image_processor.max_pixels = max_num_visual_tokens * tile * tile
instance.image_processor.size["longest_edge"] = instance.image_processor.max_pixels
return instance
# ------------------------------------------------------------------ #
# Encoding
# ------------------------------------------------------------------ #
def process_images(self, images: List[Image.Image]) -> Union[BatchFeature, BatchEncoding]:
"""Encode PIL images into the backbone's expected input dict."""
images = [img.convert("RGB") for img in images]
batch_doc = self(
text=[self.visual_prompt_prefix] * len(images),
images=images,
padding="longest",
return_tensors="pt",
)
# Pack pixel_values so the forward can scatter them per image via
# image_grid_thw offsets. This mirrors the training-time collator.
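        # image_grid_thw rows are (t, h, w) patch counts; for still images the
        # temporal dim t is 1, so h * w is the number of vision patches each
        # image contributes to the flat pixel_values tensor.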
offsets = batch_doc["image_grid_thw"][:, 1] * batch_doc["image_grid_thw"][:, 2]
pixel_values = list(torch.split(batch_doc["pixel_values"], offsets.tolist()))
batch_doc["pixel_values"] = torch.nn.utils.rnn.pad_sequence(pixel_values, batch_first=True)
return batch_doc
def process_texts(
self,
texts: List[str],
max_length: Optional[int] = None,
) -> Union[BatchFeature, BatchEncoding]:
"""Encode query strings into the backbone's expected input dict."""
kwargs = {"text": texts, "return_tensors": "pt", "padding": "longest"}
if max_length is not None:
kwargs["max_length"] = max_length
kwargs["truncation"] = True
return self(**kwargs)
def process_queries(
self,
queries: Optional[List[str]] = None,
texts: Optional[List[str]] = None,
max_length: Optional[int] = None,
suffix: Optional[str] = None,
) -> Union[BatchFeature, BatchEncoding]:
"""Encode queries with the training-time augmentation:
``query_prefix + query + query_augmentation_token * n_query_augmentation_tokens``.
Mirrors ``colpali_engine.utils.processing_utils.BaseVisualRetrieverProcessor
.process_queries`` and the Argus training collator. The default 10 trailing
``<|endoftext|>`` tokens are not optional — without them, MaxSim scoring
drops several nDCG points because the query has fewer active multi-vectors.
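
        For example, with the defaults above the query ``"total revenue in 2023"``
        becomes ``"total revenue in 2023" + "<|endoftext|>" * 10`` before
        tokenization (``query_prefix`` is empty for this model).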
"""
if texts is not None and queries is not None:
raise ValueError("Only one of 'texts' or 'queries' should be provided.")
if queries is None:
queries = texts
if queries is None:
raise ValueError("No queries provided.")
if suffix is None:
suffix = self.query_augmentation_token * self.n_query_augmentation_tokens
wrapped = [self.query_prefix + q + suffix for q in queries]
return self.process_texts(wrapped, max_length=max_length)
# ------------------------------------------------------------------ #
# Scoring
# ------------------------------------------------------------------ #
def score(
self,
qs: List[torch.Tensor],
ps: List[torch.Tensor],
device: Optional[Union[str, torch.device]] = None,
**kwargs,
) -> torch.Tensor:
"""Alias for ``score_multi_vector`` (MaxSim over multi-vectors)."""
return self.score_multi_vector(qs, ps, device=device, **kwargs)
def score_multi_vector(
self,
qs: List[torch.Tensor],
ps: List[torch.Tensor],
batch_size: int = 128,
device: Optional[Union[str, torch.device]] = None,
) -> torch.Tensor:
"""Compute an [N_q, N_p] score matrix via MaxSim (ColBERT scoring).
        For each (q, p) pair: ``sum_i max_j <q_i, p_j>``. Inputs are the raw
(potentially ragged) per-sample multi-vector tensors returned by
:meth:`encode_queries` / :meth:`encode_images`.
"""
dev = torch.device(device) if device is not None else torch.device("cpu")
n_q, n_p = len(qs), len(ps)
scores = torch.zeros(n_q, n_p, device=dev)
for qi in range(0, n_q, batch_size):
q_slice = qs[qi : qi + batch_size]
q_len = max(x.size(0) for x in q_slice)
q_pad = torch.zeros(len(q_slice), q_len, q_slice[0].size(-1), device=dev)
q_mask = torch.zeros(len(q_slice), q_len, device=dev, dtype=torch.bool)
for i, t in enumerate(q_slice):
q_pad[i, : t.size(0)] = t.to(dev)
q_mask[i, : t.size(0)] = t.abs().sum(dim=-1) > 0
for pi in range(0, n_p, batch_size):
p_slice = ps[pi : pi + batch_size]
p_len = max(x.size(0) for x in p_slice)
p_pad = torch.zeros(len(p_slice), p_len, p_slice[0].size(-1), device=dev)
for j, t in enumerate(p_slice):
p_pad[j, : t.size(0)] = t.to(dev)
sim = torch.einsum("qld,pkd->qplk", q_pad, p_pad)
maxsim = sim.max(dim=-1).values
maxsim = (maxsim * q_mask.unsqueeze(1).to(maxsim.dtype)).sum(dim=-1)
scores[qi : qi + len(q_slice), pi : pi + len(p_slice)] = maxsim
return scores
# ------------------------------------------------------------------ #
# Misc helpers (match colpali-engine BaseVisualRetrieverProcessor API)
# ------------------------------------------------------------------ #
def get_n_patches(
self,
image_size: Tuple[int, int],
spatial_merge_size: int,
) -> Tuple[int, int]:
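        """Return ``(n_patches_x, n_patches_y)`` after the processor's resize.

        ``image_size`` is the original ``(width, height)`` of the image and
        ``spatial_merge_size`` is the vision tower's merge factor (typically 2
        for the Qwen-VL family).
        """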
patch_size = self.image_processor.patch_size
height_new, width_new = smart_resize(
width=image_size[0],
height=image_size[1],
factor=patch_size * self.image_processor.merge_size,
min_pixels=self.image_processor.size["shortest_edge"],
max_pixels=self.image_processor.size["longest_edge"],
)
n_patches_x = width_new // patch_size // spatial_merge_size
n_patches_y = height_new // patch_size // spatial_merge_size
return n_patches_x, n_patches_y
def get_image_mask(self, batch_images: BatchFeature) -> torch.Tensor:
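        """Boolean mask over ``input_ids`` marking the ``<|image_pad|>`` positions."""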
return batch_images.input_ids == self.image_token_id
__all__ = ["ArgusProcessor"]