| import torch |
| from torch import nn |
| import numpy as np |
| from typing import Optional, Tuple, List, Union |
| from transformers import Qwen2VLForConditionalGeneration |
| import logging |
| import warnings |
| from PIL import Image |
| from transformers.image_utils import load_image |
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


# Offset subtracted from the raw ranking logits before the sigmoid that
# JinaVLForRanking.compute_score applies when turning them into scores.
LOGIT_BIAS = 2.65
|
|
def load_images(images, lazy_load: bool = True):
    """Load a batch of images as PIL images.

    Items that are already ``PIL.Image.Image`` instances are passed through
    unchanged; every other item is handed to ``load_image``.

    Pillow's ``Image.MAX_IMAGE_PIXELS`` limit is lifted while the batch is
    loaded and is always restored afterwards, including when loading one of
    the images raises.

    Args:
        images: Iterable of PIL images or inputs accepted by ``load_image``.
        lazy_load: When True, keep the object returned by ``load_image`` as
            is. When False, store a copy of it and close the original.

    Returns:
        A list of PIL images in the same order as ``images``.
    """
    saved_pixel_limit = Image.MAX_IMAGE_PIXELS
    # This is process-wide state, so it must be put back no matter what
    # happens below; otherwise a single failed load would leave the limit
    # disabled for every later caller.
    Image.MAX_IMAGE_PIXELS = None
    try:
        images_batch = []
        for image in images:
            if isinstance(image, Image.Image):
                images_batch.append(image)
                continue

            pil_image = load_image(image)
            if lazy_load:
                images_batch.append(pil_image)
                continue

            # Eager mode: keep a detached copy and release the source image
            # right away, even if the copy itself fails.
            try:
                images_batch.append(pil_image.copy())
            finally:
                pil_image.close()
    finally:
        Image.MAX_IMAGE_PIXELS = saved_pixel_limit

    return images_batch
|
|
|
|
def formatting_prompts_func(
    query: str,
    doc: str,
    query_type: str = 'text',
    doc_type: str = 'text',
    prefix_str: str = '',
) -> str:
    """
    Build the ranking prompt for one query/document pair.

    The prompt lists the document section first and the query section second,
    separated by a newline. A side whose type is ``'image'`` is rendered as
    the vision placeholder tokens instead of its text.

    Args:
        query: Query text (ignored in the prompt when ``query_type`` is ``'image'``)
        doc: Document text (ignored in the prompt when ``doc_type`` is ``'image'``)
        query_type: ``'image'`` to emit the image placeholder; any other value uses ``query``
        doc_type: ``'image'`` to emit the image placeholder; any other value uses ``doc``
        prefix_str: Optional text placed on its own line before the prompt

    Returns:
        The assembled prompt string.
    """
    image_placeholder = "<|vision_start|><|image_pad|><|vision_end|>"

    doc_body = image_placeholder if doc_type == 'image' else doc
    query_body = image_placeholder if query_type == 'image' else query

    # Document section always precedes the query section.
    sections = [f"**Document**:\n{doc_body}", f"**Query**:\n{query_body}"]

    # An empty prefix adds nothing, not even a leading newline.
    if prefix_str:
        sections.insert(0, prefix_str)

    return '\n'.join(sections)
|
|
|
|
class JinaVLForRanking(Qwen2VLForConditionalGeneration):
    """Qwen2-VL model with a scalar scoring head for query/document ranking.

    The language-model head is replaced by ``nn.Identity`` and a two-layer MLP
    (``self.score``) maps the last layer's hidden state at the final sequence
    position to a single logit.
    """

    def __init__(self, config):
        super().__init__(config)

        self.padding_side = "left"
        self.num_labels = 1

        # forward() reads hidden states rather than vocabulary logits, so the
        # vocabulary projection is replaced by a pass-through.
        self.lm_head = nn.Identity()

        # Scoring head: hidden_size -> hidden_size -> num_labels.
        self.score = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.ReLU(),
            nn.Linear(config.hidden_size, self.num_labels),
        )

        self.post_init()

        # Token id that compute_score() appends to every sequence; the logit
        # is read from the hidden state at that final position.
        self.score_token_id = 100

    def forward(self, *args, **kwargs) -> Union[torch.Tensor, Tuple]:
        """Run the backbone and score the last position of every sequence.

        ``output_hidden_states`` and ``use_cache`` supplied by the caller are
        discarded: hidden states are always requested and caching is always
        disabled. All other arguments are forwarded to the parent ``forward``.

        Returns:
            The scoring head's output at the last sequence position with the
            trailing label dimension squeezed out. When ``output_attentions``
            is truthy, a ``(scores, attentions)`` tuple is returned instead.

        Raises:
            AssertionError: If a non-None ``labels`` argument is passed.
        """
        kwargs.pop("output_hidden_states", None)
        kwargs.pop("use_cache", None)
        output_attentions = kwargs.pop("output_attentions", False)

        assert kwargs.pop("labels", None) is None, "labels should not be passed to forward()"

        outputs = super().forward(
            *args,
            use_cache=False,
            output_hidden_states=True,
            output_attentions=output_attentions,
            **kwargs,
        )

        # Last layer's hidden states; only the final position is scored.
        hidden_states = outputs.hidden_states[-1]
        pooled_logits = self.score(hidden_states[:, -1]).squeeze(-1)

        if output_attentions:
            return pooled_logits, outputs.attentions
        return pooled_logits

    @torch.no_grad()
    def compute_score(
        self,
        pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
        batch_size: int = 8,
        max_length: int = 10240,
        max_query_length: int = 512,
        max_doc_length: Optional[int] = None,
        query_type: str = 'text',
        doc_type: str = 'text',
        normalize_scores: bool = True,
        show_progress: bool = False,
    ) -> Union[float, List[float]]:
        """Score query/document pairs.

        Args:
            pairs: A list of ``(query, document)`` pairs, or a single pair
                given as a list or tuple of two strings. For a side whose
                type is ``'image'`` the entry is passed to ``load_images``.
            batch_size: Number of pairs processed per forward pass.
            max_length: Total token budget per sequence, including the score
                token appended to each prompt. A falsy value falls back to
                ``self.config.max_length``.
            max_query_length: Token budget reserved for the query. It is only
                used to derive and validate the document budget; queries are
                not truncated individually.
            max_doc_length: Token budget for text documents, which are
                truncated to it before prompt formatting. Defaults to
                ``max(max_length - max_query_length, max_query_length)``.
            query_type: ``'text'`` or ``'image'``.
            doc_type: ``'text'`` or ``'image'``.
            normalize_scores: When True, map logits into (0, 1) with
                ``sigmoid(logit - LOGIT_BIAS)``. When False, return the raw
                logits.
            show_progress: When True, iterate batches with a ``tqdm``
                progress bar (``tqdm`` is imported only in that case).

        Returns:
            A single float when exactly one score was produced, otherwise a
            list of floats in the order of ``pairs``. An empty ``pairs``
            yields an empty list.

        Raises:
            AssertionError: If ``pairs`` is neither a list nor a tuple, or if
                ``max_doc_length + max_query_length`` exceeds ``max_length``.
        """
        assert isinstance(pairs, (list, tuple)), "pairs should be a list of pairs or a single pair"

        if not pairs:
            return []

        # A lone (query, document) pair is promoted to a one-element batch.
        if isinstance(pairs[0], str):
            pairs = [pairs]

        # The processor is loaded on first use and cached on the instance.
        if not hasattr(self, "_processor"):
            from transformers import AutoProcessor

            self._processor = AutoProcessor.from_pretrained(
                self.name_or_path, max_pixels=602112, min_pixels=3136, trust_remote_code=True
            )
        processor = self._processor
        tokenizer = processor.tokenizer

        max_length = max_length or self.config.max_length

        if max_doc_length is None:
            max_doc_length = max(max_length - max_query_length, max_query_length)

        if max_doc_length < max_query_length:
            warnings.warn(
                f"max_doc_length={max_doc_length} should be greater than max_query_length={max_query_length}"
            )

        assert (
            max_doc_length + max_query_length <= max_length
        ), f"max_doc_length ({max_doc_length}) + max_query_length ({max_query_length}) should be less than max_length ({max_length})"

        # Keep one position free for the score token appended after tokenization.
        prompt_max_length = max_length - 1

        all_scores = []
        device = next(self.parameters()).device

        batch_starts = range(0, len(pairs), batch_size)
        if show_progress:
            from tqdm import trange

            batch_starts = trange(0, len(pairs), batch_size, desc="Computing scores")

        for start_index in batch_starts:
            mini_batch = pairs[start_index : start_index + batch_size]

            prompts = []
            for q, d in mini_batch:
                if doc_type == 'text':
                    # Cut over-long documents at the token level, then turn
                    # the kept tokens back into text for the prompt.
                    tokens = tokenizer(d, truncation=True, max_length=max_doc_length)
                    if len(tokens['input_ids']) >= max_doc_length:
                        d = tokenizer.decode(tokens['input_ids'])

                prompts.append(formatting_prompts_func(q, d, query_type=query_type, doc_type=doc_type))

            doc_images = load_images([d for _, d in mini_batch]) if doc_type == 'image' else []
            query_images = load_images([q for q, _ in mini_batch]) if query_type == 'image' else []

            if doc_images and query_images:
                # Per-sample image order follows the prompt: document first, then query.
                batch_images = [[d_img, q_img] for q_img, d_img in zip(query_images, doc_images)]
            else:
                batch_images = doc_images or query_images or None

            batch = processor(
                text=prompts,
                images=batch_images,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=prompt_max_length,
            )

            # Append the score token (and a matching attention-mask column) to
            # every row. new_full/new_ones keep the dtype and device of the
            # source tensors, so the mask is not promoted to float, and a
            # separate local avoids overwriting the batch_size argument.
            input_ids = batch["input_ids"]
            attention_mask = batch["attention_mask"]
            num_rows = input_ids.size(0)
            batch["input_ids"] = torch.cat(
                [input_ids, input_ids.new_full((num_rows, 1), self.score_token_id)],
                dim=1,
            )
            batch["attention_mask"] = torch.cat(
                [attention_mask, attention_mask.new_ones((num_rows, 1))],
                dim=1,
            )

            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}

            scores = self.forward(**batch).view(-1).cpu().float().numpy()

            if normalize_scores:
                # Shifted sigmoid: maps the logits into (0, 1).
                scores = 1.0 / (1.0 + np.exp(-(scores - LOGIT_BIAS)))

            all_scores.extend(scores.tolist())

        # A single score is returned bare rather than wrapped in a list.
        if len(all_scores) == 1:
            return all_scores[0]
        return all_scores
|
|