# File size: 5,599 Bytes
# Revision: 4894b7d
import logging
from typing import List, Optional, Union
import torch
from PIL import Image
from transformers import BatchEncoding, BatchFeature
from transformers.models.qwen3_vl import Qwen3VLProcessor
# Module-level logger, named after this module; get_torch_device reports the chosen device through it.
logger = logging.getLogger(name=__name__)
def get_torch_device(device: str = "auto") -> str:
    """
    Resolve the device string to be used by PyTorch.

    Args:
        device: An explicit device string, returned unchanged, or "auto"
            (the default) to pick one automatically, in this order:
            - "cuda:0" if CUDA is available
            - else "mps" if available (Apple Silicon)
            - else "cpu".

    Returns:
        The resolved device string. The choice is logged at INFO level.
    """
    if device == "auto":
        if torch.cuda.is_available():
            device = "cuda:0"
        elif torch.backends.mps.is_available():  # for Apple Silicon
            device = "mps"
        else:
            device = "cpu"
    # Lazy %-style arguments: the message is only formatted if INFO is enabled.
    logger.info("Using device: %s", device)
    return device
class OpsColQwen3Processor(Qwen3VLProcessor):
    """
    Processor for the OpsColQwen3 model.

    Extends ``Qwen3VLProcessor`` with ColPali-style helpers:
    ``process_images`` and ``process_queries`` build model inputs for document
    images and text queries, and ``score_multi_vector`` computes
    late-interaction (MaxSim) scores between multi-vector embeddings.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    # Prepended to every text query in `process_queries`.
    query_prefix: str = "Query: "
    # Text prompt paired with every image in `process_images`; contains one `<|image_pad|>` placeholder.
    visual_prompt_prefix: str = "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|><|im_start|>assistant\n<|endoftext|>"
    # Token appended to every query in `process_queries`.
    query_augmentation_token: str = "<|endoftext|>"
    # How many times `query_augmentation_token` is appended to each query.
    query_augmentation_length: int = 10
    image_token: str = "<|image_pad|>"

    def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
        """
        Initialize the processor.

        Args:
            image_processor: Image processor instance.
            tokenizer: Tokenizer instance.
            chat_template: Optional chat template.
            **kwargs: Additional arguments forwarded to ``Qwen3VLProcessor.__init__``.
        """
        super().__init__(image_processor=image_processor, tokenizer=tokenizer, chat_template=chat_template, **kwargs)
        # Text batches built by this processor are padded on the left.
        if self.tokenizer is not None:
            self.tokenizer.padding_side = "left"

    def process_images(self, images: List[Image.Image], return_tensors: str = "pt", **kwargs) -> Union[BatchFeature, BatchEncoding]:
        """
        Process a batch of PIL images for the model.

        Each image is converted to RGB and paired with ``visual_prompt_prefix``
        as its text prompt. The processor's ``pixel_values`` are then split per
        image (image ``i`` owns ``prod(image_grid_thw[i])`` rows) and
        zero-padded at the end into a ``(n_images, max_rows, ...)`` tensor.

        Args:
            images: Images to encode.
            return_tensors: Tensor type forwarded to the processor call.
            **kwargs: Extra processor-call arguments. ``padding`` defaults to
                ``"longest"`` and may be overridden here.

        Returns:
            The processor output, with ``pixel_values`` batched per image.
            It is returned unchanged if ``pixel_values`` is empty.
        """
        kwargs.setdefault("padding", "longest")
        rgb_images = [image.convert("RGB") for image in images]
        batch_doc = self(
            text=[self.visual_prompt_prefix] * len(rgb_images),
            images=rgb_images,
            return_tensors=return_tensors,
            **kwargs,
        )
        if batch_doc["pixel_values"].numel() == 0:
            return batch_doc
        # Rows of `pixel_values` belonging to each image: product of that image's grid entries.
        offsets = batch_doc["image_grid_thw"].prod(dim=1)
        pixel_values = list(torch.split(batch_doc["pixel_values"], offsets.tolist()))
        batch_doc["pixel_values"] = torch.nn.utils.rnn.pad_sequence(pixel_values, batch_first=True)
        return batch_doc

    def process_queries(self, queries: List[str], return_tensors: str = "pt", **kwargs) -> Union[BatchFeature, BatchEncoding]:
        """
        Process a list of text queries.

        Each query becomes ``query_prefix + query`` followed by
        ``query_augmentation_token`` repeated ``query_augmentation_length`` times.

        Args:
            queries: Query strings.
            return_tensors: Tensor type forwarded to the processor call.
            **kwargs: Extra processor-call arguments. ``padding`` defaults to
                ``"longest"`` and may be overridden here.

        Returns:
            The processor output for the augmented queries.
        """
        kwargs.setdefault("padding", "longest")
        suffix = self.query_augmentation_token * self.query_augmentation_length
        processed_queries = [self.query_prefix + query + suffix for query in queries]
        return self(text=processed_queries, return_tensors=return_tensors, **kwargs)

    @staticmethod
    def score_multi_vector(
        qs: Union[torch.Tensor, List[torch.Tensor]],
        ps: Union[torch.Tensor, List[torch.Tensor]],
        batch_size: int = 128,
        device: Optional[Union[str, torch.device]] = None,
    ) -> torch.Tensor:
        """
        Compute the late-interaction/MaxSim score (ColBERT-like) for the given multi-vector
        query embeddings (`qs`) and passage embeddings (`ps`). For ColPali, a passage is the
        image of a document page.

        Because the embedding tensors are multi-vector and can thus have different shapes, they
        should be fed as:
        (1) a list of tensors, where the i-th tensor is of shape (sequence_length_i, embedding_dim)
        (2) a single tensor of shape (n_passages, max_sequence_length, embedding_dim) -> usually
            obtained by padding the list of tensors.

        Args:
            qs (`Union[torch.Tensor, List[torch.Tensor]]`): Query embeddings.
            ps (`Union[torch.Tensor, List[torch.Tensor]]`): Passage embeddings.
            batch_size (`int`, *optional*, defaults to 128): Batch size for computing scores.
                Must be positive.
            device (`Union[str, torch.device]`, *optional*): Device to use for computation. If not
                provided, uses `get_torch_device("auto")`.

        Returns:
            `torch.Tensor`: A float32 tensor of shape `(n_queries, n_passages)` containing the
            scores. The score tensor is saved on the "cpu" device.

        Raises:
            ValueError: If `batch_size` is not positive, or `qs` or `ps` is empty.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        device = device or get_torch_device("auto")
        if len(qs) == 0:
            raise ValueError("No queries provided")
        if len(ps) == 0:
            raise ValueError("No passages provided")

        scores_list: List[torch.Tensor] = []
        for i in range(0, len(qs), batch_size):
            # Zero padding: padded query tokens add 0 to the sum, and padded
            # passage tokens offer a candidate similarity of 0 to the max.
            qs_batch = torch.nn.utils.rnn.pad_sequence(qs[i : i + batch_size], batch_first=True, padding_value=0).to(device)
            row_blocks: List[torch.Tensor] = []
            for j in range(0, len(ps), batch_size):
                ps_batch = torch.nn.utils.rnn.pad_sequence(ps[j : j + batch_size], batch_first=True, padding_value=0).to(device)
                # sim[b, c, n, s] = dot(query b token n, passage c token s)
                sim = torch.einsum("bnd,csd->bcns", qs_batch, ps_batch)
                # MaxSim: best passage token per query token, summed over query tokens -> (b, c).
                row_blocks.append(sim.max(dim=3).values.sum(dim=2))
            scores_list.append(torch.cat(row_blocks, dim=1).cpu())

        scores = torch.cat(scores_list, dim=0)
        assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}"
        return scores.to(torch.float32)
|