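"""Qwen3-VL multimodal reranker.

Scores (query, document) pairs with a Qwen3-VL backbone by casting reranking
as a binary "yes"/"no" judgement and reading the relevance score off the
difference between the two token logits.
"""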
import torch
import numpy as np
import logging
from PIL import Image
from scipy import special
from typing import List
from qwen_vl_utils import process_vision_info
from transformers import Qwen3VLForConditionalGeneration, AutoProcessor
logger = logging.getLogger(__name__)
IMAGE_BASE_FACTOR = 16
IMAGE_FACTOR = IMAGE_BASE_FACTOR * 2
MIN_PIXELS = 4 * IMAGE_FACTOR * IMAGE_FACTOR # 4 tokens
MAX_PIXELS = 1280 * IMAGE_FACTOR * IMAGE_FACTOR # 1280 tokens
MAX_RATIO = 200
FRAME_FACTOR = 2
FPS = 1
MIN_FRAMES = 2
MAX_FRAMES = 64
MIN_TOTAL_PIXELS = 1 * FRAME_FACTOR * MIN_PIXELS # 1 frame
MAX_TOTAL_PIXELS = 4 * FRAME_FACTOR * MAX_PIXELS # 4 frames
def sample_frames(frames, num_segments, max_segments):
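    """Uniformly sample num_segments frame paths from frames, padding with the
    last frame when the clip is shorter, then cap the result at max_segments."""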
    # Fall back to the other limit when only one of num_segments / max_segments is set.
    if num_segments is None:
        num_segments = max_segments
    if max_segments is None:
        max_segments = num_segments
    duration = len(frames)
frame_id_array = np.linspace(0, duration - 1, num_segments, dtype=int)
frame_id_list = frame_id_array.tolist()
last_frame_id = frame_id_list[-1]
sampled_frames = []
for frame_idx in frame_id_list:
        try:
            single_frame_path = frames[frame_idx]
        except IndexError:
            break
sampled_frames.append(single_frame_path)
    # If the total number of frames is less than num_segments, pad with the last frame.
while len(sampled_frames) < num_segments:
sampled_frames.append(frames[last_frame_id])
return sampled_frames[:max_segments]
class Qwen3VLReranker:
def __init__(
self,
model_name_or_path: str,
**kwargs,
):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.max_length = kwargs.pop('max_length', 8192)
self.default_instruction = "Given a search query, retrieve relevant candidates that answer the query."
self.min_pixels = kwargs.pop('min_pixels', MIN_PIXELS)
self.max_pixels = kwargs.pop('max_pixels', MAX_PIXELS)
self.total_pixels = kwargs.pop('total_pixels', MAX_TOTAL_PIXELS)
self.fps = kwargs.pop('fps', FPS)
self.num_frames = kwargs.pop('num_frames', None)
self.max_frames = kwargs.pop('max_frames', None)
lm = Qwen3VLForConditionalGeneration.from_pretrained(
model_name_or_path,
trust_remote_code=True, **kwargs
).to(self.device)
self.model = lm.model
self.processor = AutoProcessor.from_pretrained(
model_name_or_path, trust_remote_code=True,
padding_side='left'
)
token_true_id = self.processor.tokenizer.get_vocab()["yes"]
token_false_id = self.processor.tokenizer.get_vocab()["no"]
self.score_linear = self.get_binary_linear(lm, token_true_id, token_false_id)
self.model.eval()
self.score_linear.eval()
self.score_linear.to(self.device).to(self.model.dtype)
def get_binary_linear(self, model, token_yes, token_no):
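        """Build a 1-D scoring head from the LM-head rows for "yes" and "no".

        Because sigmoid(z_yes - z_no) == softmax([z_yes, z_no])[0], a linear
        layer with weight (w_yes - w_no) followed by a sigmoid gives the
        probability of "yes" restricted to the two answer tokens."""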
lm_head_weights = model.lm_head.weight.data
weight_yes = lm_head_weights[token_yes]
weight_no = lm_head_weights[token_no]
D = weight_yes.size()[0]
linear_layer = torch.nn.Linear(D, 1, bias=False)
with torch.no_grad():
linear_layer.weight[0] = weight_yes - weight_no
return linear_layer
@torch.no_grad()
def compute_scores(self, inputs):
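        """Run the backbone and score the hidden state at the last position
        (the true final token, since padding is on the left)."""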
        last_hidden = self.model(**inputs).last_hidden_state[:, -1]
        scores = self.score_linear(last_hidden)
        scores = torch.sigmoid(scores).squeeze(-1).cpu().detach().tolist()
return scores
def truncate_tokens_optimized(
self,
        tokens: List[int],
        max_length: int,
        special_tokens: List[int]
    ) -> List[int]:
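        """Truncate to at most max_length tokens while keeping every special
        token (e.g. vision placeholders); only non-special tokens are dropped,
        starting from the end of the sequence."""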
if len(tokens) <= max_length:
return tokens
special_tokens_set = set(special_tokens)
        # 1. Determine the budget: how many non-special tokens we can keep.
        num_special = sum(1 for token in tokens if token in special_tokens_set)
        # Guaranteed non-negative because the total number of special tokens is < max_length.
        num_non_special_to_keep = max_length - num_special
        # 2. Build the final list within that budget.
        final_tokens = []
        non_special_kept_count = 0
        for token in tokens:
            # Special tokens are always kept.
            if token in special_tokens_set:
                final_tokens.append(token)
            # Non-special tokens are kept while budget remains.
            elif non_special_kept_count < num_non_special_to_keep:
                final_tokens.append(token)
                non_special_kept_count += 1
            # Non-special tokens beyond the budget are dropped (i.e. do nothing).
return final_tokens
def tokenize(self, pairs: list, **kwargs):
max_length = self.max_length
text = self.processor.apply_chat_template(pairs, tokenize=False, add_generation_prompt=True)
try:
images, videos, video_kwargs = process_vision_info(
pairs, image_patch_size=16,
return_video_kwargs=True, return_video_metadata=True
)
except Exception as e:
logger.warning(f"Error in processing vision info: {e}")
images = None
videos = None
video_kwargs = {'do_sample_frames': False}
text = self.processor.apply_chat_template(
[{'role': 'user', 'content': [{'type': 'text', 'text': 'NULL'}]}],
add_generation_prompt=True, tokenize=False
)
if videos is not None:
videos, video_metadatas = zip(*videos)
videos, video_metadatas = list(videos), list(video_metadatas)
else:
video_metadatas = None
inputs = self.processor(text=text,
images=images,
videos=videos,
video_metadata=video_metadatas,
truncation=False,
padding=False,
max_length=max_length,
do_resize=False,
**video_kwargs)
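        # Enforce max_length manually: keep every special token (vision
        # placeholders must survive truncation) and keep the final 5 tokens,
        # which appear to be the chat-template suffix opening the assistant turn.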
        for i, ele in enumerate(inputs['input_ids']):
            inputs['input_ids'][i] = self.truncate_tokens_optimized(
                inputs['input_ids'][i][:-5], max_length,
                self.processor.tokenizer.all_special_ids
            ) + inputs['input_ids'][i][-5:]
temp_inputs = self.processor.tokenizer.pad({'input_ids': inputs['input_ids']}, padding=True,
return_tensors="pt", max_length=self.max_length)
for key in temp_inputs:
inputs[key] = temp_inputs[key]
return inputs
def format_mm_content(self, text, image, video, prefix='Query:', fps=None):
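        """Assemble the chat-content list for one side of the pair: a prefix
        label, then any video / image entry (local paths get a file:// scheme),
        then the text."""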
content = []
content.append({'type': 'text', 'text': prefix})
if not text and not image and not video:
content.append({'type': 'text', 'text': ""})
return content
        if video:
            video_content = None
            if isinstance(video, list):
                video_content = video
                if self.num_frames is not None or self.max_frames is not None:
                    video_content = sample_frames(video_content, self.num_frames, self.max_frames)
                video_content = ['file://' + ele for ele in video_content]
            elif isinstance(video, str) and (video.startswith('http') or video.startswith('oss')):
                video_content = video
            elif isinstance(video, str):
                video_content = 'file://' + video
if video_content:
content.append({'type': 'video', 'video': video_content, 'total_pixels': self.total_pixels, 'fps': fps})
if image:
image_content = None
if isinstance(image, Image.Image):
image_content = image
            elif isinstance(image, str) and (image.startswith('http') or image.startswith('oss')):
                image_content = image
            elif isinstance(image, str):
                image_content = 'file://' + image
else:
image_content = image
if image_content:
content.append({'type': 'image', 'image': image_content, "min_pixels": self.min_pixels,
"max_pixels": self.max_pixels})
if text:
content.append({'type': 'text', 'text': text})
return content
def format_mm_instruction(
self,
query_text, query_image, query_video,
doc_text, doc_image, doc_video,
instruction=None, fps=None
):
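        """Build the system + user messages for one (query, document) pair.

        When query_text is a (instruction, text) tuple, that instruction
        overrides the `instruction` argument."""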
inputs = []
inputs.append({
"role": "system",
"content": [{
"type": "text",
"text": "Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\"."
}
]
})
        if isinstance(query_text, tuple):
            instruct, query_text = query_text
        else:
            instruct = instruction
        if instruct is None:
            instruct = self.default_instruction
contents = []
contents.append({
"type": "text",
"text": '<Instruct>: ' + instruct
})
query_content = self.format_mm_content(query_text, query_image, query_video, prefix='<Query>:', fps=fps)
contents.extend(query_content)
doc_content = self.format_mm_content(doc_text, doc_image, doc_video, prefix='\n<Document>:', fps=fps)
contents.extend(doc_content)
inputs.append({
"role": "user",
"content": contents
})
return inputs
def process(
self,
inputs,
    ) -> List[float]:
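        """Score each document in `inputs` against the query.

        Returns one sigmoid relevance score per document, in document order;
        an empty list if the query or documents are missing."""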
instruction = inputs.get('instruction', self.default_instruction)
query = inputs.get("query", {})
documents = inputs.get("documents", [])
if not query or not documents:
return []
pairs = [self.format_mm_instruction(
query.get('text', None),
query.get('image', None),
query.get('video', None),
document.get('text', None),
document.get('image', None),
document.get('video', None),
instruction=instruction,
fps=inputs.get('fps', self.fps))
for document in documents]
final_scores = []
        for pair in pairs:
            model_inputs = self.tokenize([pair])
            model_inputs = model_inputs.to(self.model.device)
            scores = self.compute_scores(model_inputs)
            final_scores.extend(scores)
return final_scores
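

if __name__ == "__main__":
    # Minimal usage sketch: the checkpoint path and example data below are
    # hypothetical placeholders; point at a real Qwen3-VL reranker checkpoint.
    reranker = Qwen3VLReranker("path/to/Qwen3-VL-Reranker-2B")
    scores = reranker.process({
        "query": {"text": "a cat sleeping on a sofa"},
        "documents": [
            {"text": "A tabby cat curled up asleep on a couch."},
            {"text": "A recipe for tomato soup."},
        ],
    })
    print(scores)  # one sigmoid score per document, higher = more relevant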