Text Ranking
Transformers
Safetensors
sentence-transformers
qwen3_vl
image-text-to-text
multimodal rerank
text rerank
Instructions to use Qwen/Qwen3-VL-Reranker-2B with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Qwen/Qwen3-VL-Reranker-2B with Transformers:
# Load model directly
from transformers import AutoProcessor, AutoModelForImageTextToText

processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-Reranker-2B")
model = AutoModelForImageTextToText.from_pretrained("Qwen/Qwen3-VL-Reranker-2B")
- sentence-transformers
How to use Qwen/Qwen3-VL-Reranker-2B with sentence-transformers:
from sentence_transformers import CrossEncoder

model = CrossEncoder("Qwen/Qwen3-VL-Reranker-2B")

query = "Which planet is known as the Red Planet?"
passages = [
    "Venus is often called Earth's twin because of its similar size and proximity.",
    "Mars, known for its reddish appearance, is often referred to as the Red Planet.",
    "Jupiter, the largest planet in our solar system, has a prominent red spot.",
    "Saturn, famous for its rings, is sometimes mistaken for the Red Planet."
]

scores = model.predict([(query, passage) for passage in passages])
print(scores)
- Notebooks
- Google Colab
- Kaggle
Update README.md
#3
by zyznull - opened
- README.md +3 -3
- scripts/qwen3_vl_reranker.py +88 -59
README.md
CHANGED
|
@@ -57,7 +57,7 @@ We utilize retrieval task datasets from various subtasks of [MMEB-v2](https://hu
|
|
| 57 |
|
| 58 |
| Model | Size | MMEB-v2(Retrieval) - Avg | MMEB-v2(Retrieval) - Image | MMEB-v2(Retrieval) - Video | MMEB-v2(Retrieval) - VisDoc | MMTEB(Retrieval) | JinaVDR | ViDoRe(v3) |
|
| 59 |
|-------|------|--------------------------|----------------------------|----------------------------|------------------------------|------------------|---------|------------|
|
| 60 |
-
| Qwen3-VL-Embedding-2B | 2B | 73.
|
| 61 |
| jina-reranker-m0 | 2B | - | 68.2 | - | 85.2 | - | 82.2 | 57.8 |
|
| 62 |
| Qwen3-VL-Reranker-2B | 2B | 75.1 | 73.8 | 52.1 | 83.4 | 70.0 | 80.9 | 60.8 |
|
| 63 |
| Qwen3-VL-Reranker-8B | 8B | 79.2 | 80.7 | 55.8 | 86.3 | 74.9 | 83.6 | 66.7 |
|
|
@@ -87,7 +87,7 @@ model = Qwen3VLReranker(model_name_or_path=model_name_or_path)
|
|
| 87 |
# Combine queries and documents into a single input list
|
| 88 |
|
| 89 |
inputs = {
|
| 90 |
-
"instruction": "
|
| 91 |
"query": {"text": "A woman playing with her dog on a beach at sunset."},
|
| 92 |
"documents": [
|
| 93 |
{"text": "A woman shares a joyful moment with her golden retriever on a sun-drenched beach at sunset, as the dog offers its paw in a heartwarming display of companionship and trust."},
|
|
@@ -99,7 +99,7 @@ inputs = {
|
|
| 99 |
|
| 100 |
scores = model.process(inputs)
|
| 101 |
print(scores)
|
| 102 |
-
# [0.
|
| 103 |
```
|
| 104 |
For more usage examples, please visit our [GitHub repository](https://github.com/QwenLM/Qwen3-VL-Embedding).
|
| 105 |
|
|
|
|
| 57 |
|
| 58 |
| Model | Size | MMEB-v2(Retrieval) - Avg | MMEB-v2(Retrieval) - Image | MMEB-v2(Retrieval) - Video | MMEB-v2(Retrieval) - VisDoc | MMTEB(Retrieval) | JinaVDR | ViDoRe(v3) |
|
| 59 |
|-------|------|--------------------------|----------------------------|----------------------------|------------------------------|------------------|---------|------------|
|
| 60 |
+
| Qwen3-VL-Embedding-2B | 2B | 73.4 | 74.8 | 53.6 | 79.2 | 68.1 | 71.0 | 52.9 |
|
| 61 |
| jina-reranker-m0 | 2B | - | 68.2 | - | 85.2 | - | 82.2 | 57.8 |
|
| 62 |
| Qwen3-VL-Reranker-2B | 2B | 75.1 | 73.8 | 52.1 | 83.4 | 70.0 | 80.9 | 60.8 |
|
| 63 |
| Qwen3-VL-Reranker-8B | 8B | 79.2 | 80.7 | 55.8 | 86.3 | 74.9 | 83.6 | 66.7 |
|
|
|
|
| 87 |
# Combine queries and documents into a single input list
|
| 88 |
|
| 89 |
inputs = {
|
| 90 |
+
"instruction": "Retrieve images or text relevant to the user's query.",
|
| 91 |
"query": {"text": "A woman playing with her dog on a beach at sunset."},
|
| 92 |
"documents": [
|
| 93 |
{"text": "A woman shares a joyful moment with her golden retriever on a sun-drenched beach at sunset, as the dog offers its paw in a heartwarming display of companionship and trust."},
|
|
|
|
| 99 |
|
| 100 |
scores = model.process(inputs)
|
| 101 |
print(scores)
|
| 102 |
+
# [0.8613124489784241, 0.6757137179374695, 0.8125371336936951]
|
| 103 |
```
|
| 104 |
For more usage examples, please visit our [GitHub repository](https://github.com/QwenLM/Qwen3-VL-Embedding).
|
| 105 |
|
scripts/qwen3_vl_reranker.py
CHANGED
|
@@ -10,6 +10,7 @@ from transformers import Qwen3VLForConditionalGeneration, AutoProcessor
|
|
| 10 |
|
| 11 |
logger = logging.getLogger(__name__)
|
| 12 |
|
|
|
|
| 13 |
IMAGE_BASE_FACTOR = 16
|
| 14 |
IMAGE_FACTOR = IMAGE_BASE_FACTOR * 2
|
| 15 |
MIN_PIXELS = 4 * IMAGE_FACTOR * IMAGE_FACTOR # 4 tokens
|
|
@@ -37,33 +38,37 @@ def sample_frames(frames, num_segments, max_segments):
|
|
| 37 |
except:
|
| 38 |
break
|
| 39 |
sampled_frames.append(single_frame_path)
|
| 40 |
-
#
|
| 41 |
while len(sampled_frames) < num_segments:
|
| 42 |
sampled_frames.append(frames[last_frame_id])
|
| 43 |
return sampled_frames[:max_segments]
|
| 44 |
|
| 45 |
-
|
| 46 |
-
|
| 47 |
class Qwen3VLReranker():
|
| 48 |
def __init__(
|
| 49 |
self,
|
| 50 |
model_name_or_path: str,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
**kwargs,
|
| 52 |
):
|
| 53 |
|
| 54 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 55 |
|
| 56 |
-
self.max_length =
|
| 57 |
-
|
| 58 |
-
self.
|
| 59 |
-
|
| 60 |
-
self.
|
| 61 |
-
self.
|
| 62 |
-
self.
|
| 63 |
-
self.fps = kwargs.pop('fps', FPS)
|
| 64 |
-
self.num_frames = kwargs.pop('num_frames', None)
|
| 65 |
-
self.max_frames = kwargs.pop('max_frames', None)
|
| 66 |
|
|
|
|
| 67 |
|
| 68 |
lm = Qwen3VLForConditionalGeneration.from_pretrained(
|
| 69 |
model_name_or_path,
|
|
@@ -71,16 +76,15 @@ class Qwen3VLReranker():
|
|
| 71 |
).to(self.device)
|
| 72 |
|
| 73 |
self.model = lm.model
|
| 74 |
-
|
| 75 |
self.processor = AutoProcessor.from_pretrained(
|
| 76 |
model_name_or_path, trust_remote_code=True,
|
| 77 |
padding_side='left'
|
| 78 |
)
|
|
|
|
| 79 |
|
| 80 |
token_true_id = self.processor.tokenizer.get_vocab()["yes"]
|
| 81 |
token_false_id = self.processor.tokenizer.get_vocab()["no"]
|
| 82 |
self.score_linear = self.get_binary_linear(lm, token_true_id, token_false_id)
|
| 83 |
-
self.model.eval()
|
| 84 |
self.score_linear.eval()
|
| 85 |
self.score_linear.to(self.device).to(self.model.dtype)
|
| 86 |
|
|
@@ -115,24 +119,19 @@ class Qwen3VLReranker():
|
|
| 115 |
|
| 116 |
special_tokens_set = set(special_tokens)
|
| 117 |
|
| 118 |
-
#
|
| 119 |
num_special = sum(1 for token in tokens if token in special_tokens_set)
|
| 120 |
-
|
| 121 |
-
# 根据保证(特殊token总数 < max_length),这个值总是非负的
|
| 122 |
num_non_special_to_keep = max_length - num_special
|
| 123 |
|
| 124 |
-
#
|
| 125 |
final_tokens = []
|
| 126 |
non_special_kept_count = 0
|
| 127 |
for token in tokens:
|
| 128 |
-
# 如果是特殊token,直接保留
|
| 129 |
if token in special_tokens_set:
|
| 130 |
final_tokens.append(token)
|
| 131 |
-
# 如果是非特殊token,并且我们还有预算
|
| 132 |
elif non_special_kept_count < num_non_special_to_keep:
|
| 133 |
final_tokens.append(token)
|
| 134 |
non_special_kept_count += 1
|
| 135 |
-
# 如果是非特殊token但预算已用完,则丢弃(即什么都不做)
|
| 136 |
|
| 137 |
return final_tokens
|
| 138 |
|
|
@@ -142,10 +141,11 @@ class Qwen3VLReranker():
|
|
| 142 |
try:
|
| 143 |
images, videos, video_kwargs = process_vision_info(
|
| 144 |
pairs, image_patch_size=16,
|
| 145 |
-
return_video_kwargs=True,
|
|
|
|
| 146 |
)
|
| 147 |
except Exception as e:
|
| 148 |
-
logger.
|
| 149 |
images = None
|
| 150 |
videos = None
|
| 151 |
video_kwargs = {'do_sample_frames': False}
|
|
@@ -159,60 +159,80 @@ class Qwen3VLReranker():
|
|
| 159 |
videos, video_metadatas = list(videos), list(video_metadatas)
|
| 160 |
else:
|
| 161 |
video_metadatas = None
|
| 162 |
-
inputs = self.processor(
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
|
|
|
| 171 |
for i, ele in enumerate(inputs['input_ids']):
|
| 172 |
-
inputs['input_ids'][i] = self.truncate_tokens_optimized(
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
|
|
|
|
|
|
|
|
|
| 177 |
for key in temp_inputs:
|
| 178 |
inputs[key] = temp_inputs[key]
|
| 179 |
return inputs
|
| 180 |
|
| 181 |
-
def format_mm_content(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
content = []
|
| 183 |
|
| 184 |
content.append({'type': 'text', 'text': prefix})
|
| 185 |
if not text and not image and not video:
|
| 186 |
-
content.append({'type': 'text', 'text': ""})
|
| 187 |
return content
|
|
|
|
| 188 |
if video:
|
| 189 |
video_content = None
|
|
|
|
| 190 |
if isinstance(video, list):
|
| 191 |
video_content = video
|
| 192 |
if self.num_frames is not None or self.max_frames is not None:
|
| 193 |
-
video_content =
|
| 194 |
-
video_content = [
|
| 195 |
-
|
| 196 |
-
|
|
|
|
| 197 |
elif isinstance(video, str):
|
| 198 |
-
video_content = 'file://' + video
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
if video_content:
|
| 200 |
-
content.append({
|
|
|
|
|
|
|
|
|
|
| 201 |
|
| 202 |
if image:
|
| 203 |
image_content = None
|
| 204 |
if isinstance(image, Image.Image):
|
| 205 |
image_content = image
|
| 206 |
-
|
| 207 |
-
elif image.startswith('http') or image.startswith('oss'):
|
| 208 |
-
image_content = image
|
| 209 |
elif isinstance(image, str):
|
| 210 |
-
image_content = 'file://' + image
|
| 211 |
else:
|
| 212 |
-
|
|
|
|
| 213 |
if image_content:
|
| 214 |
-
content.append({
|
| 215 |
-
|
|
|
|
|
|
|
|
|
|
| 216 |
|
| 217 |
if text:
|
| 218 |
content.append({'type': 'text', 'text': text})
|
|
@@ -222,7 +242,8 @@ class Qwen3VLReranker():
|
|
| 222 |
self,
|
| 223 |
query_text, query_image, query_video,
|
| 224 |
doc_text, doc_image, doc_video,
|
| 225 |
-
instruction=None,
|
|
|
|
| 226 |
):
|
| 227 |
inputs = []
|
| 228 |
inputs.append({
|
|
@@ -242,9 +263,15 @@ class Qwen3VLReranker():
|
|
| 242 |
"type": "text",
|
| 243 |
"text": '<Instruct>: ' + instruct
|
| 244 |
})
|
| 245 |
-
query_content = self.format_mm_content(
|
|
|
|
|
|
|
|
|
|
| 246 |
contents.extend(query_content)
|
| 247 |
-
doc_content = self.format_mm_content(
|
|
|
|
|
|
|
|
|
|
| 248 |
contents.extend(doc_content)
|
| 249 |
inputs.append({
|
| 250 |
"role": "user",
|
|
@@ -271,12 +298,14 @@ class Qwen3VLReranker():
|
|
| 271 |
document.get('image', None),
|
| 272 |
document.get('video', None),
|
| 273 |
instruction=instruction,
|
| 274 |
-
fps=inputs.get('fps', self.fps)
|
| 275 |
-
|
|
|
|
|
|
|
| 276 |
final_scores = []
|
| 277 |
for pair in pairs:
|
| 278 |
inputs = self.tokenize([pair])
|
| 279 |
inputs = inputs.to(self.model.device)
|
| 280 |
scores = self.compute_scores(inputs)
|
| 281 |
final_scores.extend(scores)
|
| 282 |
-
return final_scores
|
|
|
|
| 10 |
|
| 11 |
logger = logging.getLogger(__name__)
|
| 12 |
|
| 13 |
+
MAX_LENGTH = 8192
|
| 14 |
IMAGE_BASE_FACTOR = 16
|
| 15 |
IMAGE_FACTOR = IMAGE_BASE_FACTOR * 2
|
| 16 |
MIN_PIXELS = 4 * IMAGE_FACTOR * IMAGE_FACTOR # 4 tokens
|
|
|
|
| 38 |
except:
|
| 39 |
break
|
| 40 |
sampled_frames.append(single_frame_path)
|
| 41 |
+
# Pad with last frame if total frames less than num_segments
|
| 42 |
while len(sampled_frames) < num_segments:
|
| 43 |
sampled_frames.append(frames[last_frame_id])
|
| 44 |
return sampled_frames[:max_segments]
|
| 45 |
|
|
|
|
|
|
|
| 46 |
class Qwen3VLReranker():
|
| 47 |
def __init__(
|
| 48 |
self,
|
| 49 |
model_name_or_path: str,
|
| 50 |
+
max_length: int = MAX_LENGTH,
|
| 51 |
+
min_pixels: int = MIN_PIXELS,
|
| 52 |
+
max_pixels: int = MAX_PIXELS,
|
| 53 |
+
total_pixels: int = MAX_TOTAL_PIXELS,
|
| 54 |
+
fps: float = FPS,
|
| 55 |
+
num_frames: int = MAX_FRAMES,
|
| 56 |
+
max_frames: int = MAX_FRAMES,
|
| 57 |
+
default_instruction: str = "Given a search query, retrieve relevant candidates that answer the query.",
|
| 58 |
**kwargs,
|
| 59 |
):
|
| 60 |
|
| 61 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 62 |
|
| 63 |
+
self.max_length = max_length
|
| 64 |
+
self.min_pixels = min_pixels
|
| 65 |
+
self.max_pixels = max_pixels
|
| 66 |
+
self.total_pixels = total_pixels
|
| 67 |
+
self.fps = fps
|
| 68 |
+
self.num_frames = num_frames
|
| 69 |
+
self.max_frames = max_frames
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
+
self.default_instruction = default_instruction
|
| 72 |
|
| 73 |
lm = Qwen3VLForConditionalGeneration.from_pretrained(
|
| 74 |
model_name_or_path,
|
|
|
|
| 76 |
).to(self.device)
|
| 77 |
|
| 78 |
self.model = lm.model
|
|
|
|
| 79 |
self.processor = AutoProcessor.from_pretrained(
|
| 80 |
model_name_or_path, trust_remote_code=True,
|
| 81 |
padding_side='left'
|
| 82 |
)
|
| 83 |
+
self.model.eval()
|
| 84 |
|
| 85 |
token_true_id = self.processor.tokenizer.get_vocab()["yes"]
|
| 86 |
token_false_id = self.processor.tokenizer.get_vocab()["no"]
|
| 87 |
self.score_linear = self.get_binary_linear(lm, token_true_id, token_false_id)
|
|
|
|
| 88 |
self.score_linear.eval()
|
| 89 |
self.score_linear.to(self.device).to(self.model.dtype)
|
| 90 |
|
|
|
|
| 119 |
|
| 120 |
special_tokens_set = set(special_tokens)
|
| 121 |
|
| 122 |
+
# Calculate budget: how many non-special tokens we can keep
|
| 123 |
num_special = sum(1 for token in tokens if token in special_tokens_set)
|
|
|
|
|
|
|
| 124 |
num_non_special_to_keep = max_length - num_special
|
| 125 |
|
| 126 |
+
# Build final list according to budget
|
| 127 |
final_tokens = []
|
| 128 |
non_special_kept_count = 0
|
| 129 |
for token in tokens:
|
|
|
|
| 130 |
if token in special_tokens_set:
|
| 131 |
final_tokens.append(token)
|
|
|
|
| 132 |
elif non_special_kept_count < num_non_special_to_keep:
|
| 133 |
final_tokens.append(token)
|
| 134 |
non_special_kept_count += 1
|
|
|
|
| 135 |
|
| 136 |
return final_tokens
|
| 137 |
|
|
|
|
| 141 |
try:
|
| 142 |
images, videos, video_kwargs = process_vision_info(
|
| 143 |
pairs, image_patch_size=16,
|
| 144 |
+
return_video_kwargs=True,
|
| 145 |
+
return_video_metadata=True
|
| 146 |
)
|
| 147 |
except Exception as e:
|
| 148 |
+
logger.error(f"Error in processing vision info: {e}")
|
| 149 |
images = None
|
| 150 |
videos = None
|
| 151 |
video_kwargs = {'do_sample_frames': False}
|
|
|
|
| 159 |
videos, video_metadatas = list(videos), list(video_metadatas)
|
| 160 |
else:
|
| 161 |
video_metadatas = None
|
| 162 |
+
inputs = self.processor(
|
| 163 |
+
text=text,
|
| 164 |
+
images=images,
|
| 165 |
+
videos=videos,
|
| 166 |
+
video_metadata=video_metadatas,
|
| 167 |
+
truncation=False,
|
| 168 |
+
padding=False,
|
| 169 |
+
do_resize=False,
|
| 170 |
+
**video_kwargs
|
| 171 |
+
)
|
| 172 |
for i, ele in enumerate(inputs['input_ids']):
|
| 173 |
+
inputs['input_ids'][i] = self.truncate_tokens_optimized(
|
| 174 |
+
inputs['input_ids'][i][:-5], max_length,
|
| 175 |
+
self.processor.tokenizer.all_special_ids
|
| 176 |
+
) + inputs['input_ids'][i][-5:]
|
| 177 |
+
temp_inputs = self.processor.tokenizer.pad(
|
| 178 |
+
{'input_ids': inputs['input_ids']}, padding=True,
|
| 179 |
+
return_tensors="pt", max_length=self.max_length
|
| 180 |
+
)
|
| 181 |
for key in temp_inputs:
|
| 182 |
inputs[key] = temp_inputs[key]
|
| 183 |
return inputs
|
| 184 |
|
| 185 |
+
def format_mm_content(
|
| 186 |
+
self,
|
| 187 |
+
text, image, video,
|
| 188 |
+
prefix='Query:',
|
| 189 |
+
fps=None, max_frames=None,
|
| 190 |
+
):
|
| 191 |
content = []
|
| 192 |
|
| 193 |
content.append({'type': 'text', 'text': prefix})
|
| 194 |
if not text and not image and not video:
|
| 195 |
+
content.append({'type': 'text', 'text': "NULL"})
|
| 196 |
return content
|
| 197 |
+
|
| 198 |
if video:
|
| 199 |
video_content = None
|
| 200 |
+
video_kwargs = { 'total_pixels': self.total_pixels }
|
| 201 |
if isinstance(video, list):
|
| 202 |
video_content = video
|
| 203 |
if self.num_frames is not None or self.max_frames is not None:
|
| 204 |
+
video_content = self._sample_frames(video_content, self.num_frames, self.max_frames)
|
| 205 |
+
video_content = [
|
| 206 |
+
('file://' + ele if isinstance(ele, str) else ele)
|
| 207 |
+
for ele in video_content
|
| 208 |
+
]
|
| 209 |
elif isinstance(video, str):
|
| 210 |
+
video_content = video if video.startswith(('http://', 'https://')) else 'file://' + video
|
| 211 |
+
video_kwargs = {'fps': fps or self.fps, 'max_frames': max_frames or self.max_frames,}
|
| 212 |
+
else:
|
| 213 |
+
raise TypeError(f"Unrecognized video type: {type(video)}")
|
| 214 |
+
|
| 215 |
if video_content:
|
| 216 |
+
content.append({
|
| 217 |
+
'type': 'video', 'video': video_content,
|
| 218 |
+
**video_kwargs
|
| 219 |
+
})
|
| 220 |
|
| 221 |
if image:
|
| 222 |
image_content = None
|
| 223 |
if isinstance(image, Image.Image):
|
| 224 |
image_content = image
|
|
|
|
|
|
|
|
|
|
| 225 |
elif isinstance(image, str):
|
| 226 |
+
image_content = image if image.startswith(('http', 'oss')) else 'file://' + image
|
| 227 |
else:
|
| 228 |
+
raise TypeError(f"Unrecognized image type: {type(image)}")
|
| 229 |
+
|
| 230 |
if image_content:
|
| 231 |
+
content.append({
|
| 232 |
+
'type': 'image', 'image': image_content,
|
| 233 |
+
"min_pixels": self.min_pixels,
|
| 234 |
+
"max_pixels": self.max_pixels
|
| 235 |
+
})
|
| 236 |
|
| 237 |
if text:
|
| 238 |
content.append({'type': 'text', 'text': text})
|
|
|
|
| 242 |
self,
|
| 243 |
query_text, query_image, query_video,
|
| 244 |
doc_text, doc_image, doc_video,
|
| 245 |
+
instruction=None,
|
| 246 |
+
fps=None, max_frames=None
|
| 247 |
):
|
| 248 |
inputs = []
|
| 249 |
inputs.append({
|
|
|
|
| 263 |
"type": "text",
|
| 264 |
"text": '<Instruct>: ' + instruct
|
| 265 |
})
|
| 266 |
+
query_content = self.format_mm_content(
|
| 267 |
+
query_text, query_image, query_video, prefix='<Query>:',
|
| 268 |
+
fps=fps, max_frames=max_frames
|
| 269 |
+
)
|
| 270 |
contents.extend(query_content)
|
| 271 |
+
doc_content = self.format_mm_content(
|
| 272 |
+
doc_text, doc_image, doc_video, prefix='\n<Document>:',
|
| 273 |
+
fps=fps, max_frames=max_frames
|
| 274 |
+
)
|
| 275 |
contents.extend(doc_content)
|
| 276 |
inputs.append({
|
| 277 |
"role": "user",
|
|
|
|
| 298 |
document.get('image', None),
|
| 299 |
document.get('video', None),
|
| 300 |
instruction=instruction,
|
| 301 |
+
fps=inputs.get('fps', self.fps),
|
| 302 |
+
max_frames=inputs.get('max_frames', self.max_frames)
|
| 303 |
+
) for document in documents]
|
| 304 |
+
|
| 305 |
final_scores = []
|
| 306 |
for pair in pairs:
|
| 307 |
inputs = self.tokenize([pair])
|
| 308 |
inputs = inputs.to(self.model.device)
|
| 309 |
scores = self.compute_scores(inputs)
|
| 310 |
final_scores.extend(scores)
|
| 311 |
+
return final_scores
|