zyznull commited on
Commit
79d4fd0
·
verified ·
1 Parent(s): de4ea3c

Update scripts/qwen3_vl_reranker.py

Browse files
Files changed (1) hide show
  1. scripts/qwen3_vl_reranker.py +88 -59
scripts/qwen3_vl_reranker.py CHANGED
@@ -10,6 +10,7 @@ from transformers import Qwen3VLForConditionalGeneration, AutoProcessor
10
 
11
  logger = logging.getLogger(__name__)
12
 
 
13
  IMAGE_BASE_FACTOR = 16
14
  IMAGE_FACTOR = IMAGE_BASE_FACTOR * 2
15
  MIN_PIXELS = 4 * IMAGE_FACTOR * IMAGE_FACTOR # 4 tokens
@@ -37,33 +38,37 @@ def sample_frames(frames, num_segments, max_segments):
37
  except:
38
  break
39
  sampled_frames.append(single_frame_path)
40
- # If total frame numbers is less than num_segments, append the last images to achieve
41
  while len(sampled_frames) < num_segments:
42
  sampled_frames.append(frames[last_frame_id])
43
  return sampled_frames[:max_segments]
44
 
45
-
46
-
47
  class Qwen3VLReranker():
48
  def __init__(
49
  self,
50
  model_name_or_path: str,
 
 
 
 
 
 
 
 
51
  **kwargs,
52
  ):
53
 
54
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
55
 
56
- self.max_length = kwargs.pop('max_length', 8192)
57
-
58
- self.default_instruction = "Given a search query, retrieve relevant candidates that answer the query."
59
-
60
- self.min_pixels = kwargs.pop('min_pixels', MIN_PIXELS)
61
- self.max_pixels = kwargs.pop('max_pixels', MAX_PIXELS)
62
- self.total_pixels = kwargs.pop('total_pixels', MAX_TOTAL_PIXELS)
63
- self.fps = kwargs.pop('fps', FPS)
64
- self.num_frames = kwargs.pop('num_frames', None)
65
- self.max_frames = kwargs.pop('max_frames', None)
66
 
 
67
 
68
  lm = Qwen3VLForConditionalGeneration.from_pretrained(
69
  model_name_or_path,
@@ -71,16 +76,15 @@ class Qwen3VLReranker():
71
  ).to(self.device)
72
 
73
  self.model = lm.model
74
-
75
  self.processor = AutoProcessor.from_pretrained(
76
  model_name_or_path, trust_remote_code=True,
77
  padding_side='left'
78
  )
 
79
 
80
  token_true_id = self.processor.tokenizer.get_vocab()["yes"]
81
  token_false_id = self.processor.tokenizer.get_vocab()["no"]
82
  self.score_linear = self.get_binary_linear(lm, token_true_id, token_false_id)
83
- self.model.eval()
84
  self.score_linear.eval()
85
  self.score_linear.to(self.device).to(self.model.dtype)
86
 
@@ -115,24 +119,19 @@ class Qwen3VLReranker():
115
 
116
  special_tokens_set = set(special_tokens)
117
 
118
- # 1. 确定预算:计算我们能保留多少个非特殊token
119
  num_special = sum(1 for token in tokens if token in special_tokens_set)
120
-
121
- # 根据保证(特殊token总数 < max_length),这个值总是非负的
122
  num_non_special_to_keep = max_length - num_special
123
 
124
- # 2. 按预算构建最终列表
125
  final_tokens = []
126
  non_special_kept_count = 0
127
  for token in tokens:
128
- # 如果是特殊token,直接保留
129
  if token in special_tokens_set:
130
  final_tokens.append(token)
131
- # 如果是非特殊token,并且我们还有预算
132
  elif non_special_kept_count < num_non_special_to_keep:
133
  final_tokens.append(token)
134
  non_special_kept_count += 1
135
- # 如果是非特殊token但预算已用完,则丢弃(即什么都不做)
136
 
137
  return final_tokens
138
 
@@ -142,10 +141,11 @@ class Qwen3VLReranker():
142
  try:
143
  images, videos, video_kwargs = process_vision_info(
144
  pairs, image_patch_size=16,
145
- return_video_kwargs=True, return_video_metadata=True
 
146
  )
147
  except Exception as e:
148
- logger.warning(f"Error in processing vision info: {e}")
149
  images = None
150
  videos = None
151
  video_kwargs = {'do_sample_frames': False}
@@ -159,60 +159,80 @@ class Qwen3VLReranker():
159
  videos, video_metadatas = list(videos), list(video_metadatas)
160
  else:
161
  video_metadatas = None
162
- inputs = self.processor(text=text,
163
- images=images,
164
- videos=videos,
165
- video_metadata=video_metadatas,
166
- truncation=False,
167
- padding=False,
168
- max_length=max_length,
169
- do_resize=False,
170
- **video_kwargs)
 
171
  for i, ele in enumerate(inputs['input_ids']):
172
- inputs['input_ids'][i] = self.truncate_tokens_optimized(inputs['input_ids'][i][:-5], max_length,
173
- self.processor.tokenizer.all_special_ids) + \
174
- inputs['input_ids'][i][-5:]
175
- temp_inputs = self.processor.tokenizer.pad({'input_ids': inputs['input_ids']}, padding=True,
176
- return_tensors="pt", max_length=self.max_length)
 
 
 
177
  for key in temp_inputs:
178
  inputs[key] = temp_inputs[key]
179
  return inputs
180
 
181
- def format_mm_content(self, text, image, video, prefix='Query:', fps=None):
 
 
 
 
 
182
  content = []
183
 
184
  content.append({'type': 'text', 'text': prefix})
185
  if not text and not image and not video:
186
- content.append({'type': 'text', 'text': ""})
187
  return content
 
188
  if video:
189
  video_content = None
 
190
  if isinstance(video, list):
191
  video_content = video
192
  if self.num_frames is not None or self.max_frames is not None:
193
- video_content = sample_frames(video_content, self.num_frames, self.max_frames)
194
- video_content = ['file://' + ele for ele in video_content]
195
- if video.startswith('http') or video.startswith('oss'):
196
- video_content = video
 
197
  elif isinstance(video, str):
198
- video_content = 'file://' + video
 
 
 
 
199
  if video_content:
200
- content.append({'type': 'video', 'video': video_content, 'total_pixels': self.total_pixels, 'fps': fps})
 
 
 
201
 
202
  if image:
203
  image_content = None
204
  if isinstance(image, Image.Image):
205
  image_content = image
206
-
207
- elif image.startswith('http') or image.startswith('oss'):
208
- image_content = image
209
  elif isinstance(image, str):
210
- image_content = 'file://' + image
211
  else:
212
- image_content = image
 
213
  if image_content:
214
- content.append({'type': 'image', 'image': image_content, "min_pixels": self.min_pixels,
215
- "max_pixels": self.max_pixels})
 
 
 
216
 
217
  if text:
218
  content.append({'type': 'text', 'text': text})
@@ -222,7 +242,8 @@ class Qwen3VLReranker():
222
  self,
223
  query_text, query_image, query_video,
224
  doc_text, doc_image, doc_video,
225
- instruction=None, fps=None
 
226
  ):
227
  inputs = []
228
  inputs.append({
@@ -242,9 +263,15 @@ class Qwen3VLReranker():
242
  "type": "text",
243
  "text": '<Instruct>: ' + instruct
244
  })
245
- query_content = self.format_mm_content(query_text, query_image, query_video, prefix='<Query>:', fps=fps)
 
 
 
246
  contents.extend(query_content)
247
- doc_content = self.format_mm_content(doc_text, doc_image, doc_video, prefix='\n<Document>:', fps=fps)
 
 
 
248
  contents.extend(doc_content)
249
  inputs.append({
250
  "role": "user",
@@ -271,12 +298,14 @@ class Qwen3VLReranker():
271
  document.get('image', None),
272
  document.get('video', None),
273
  instruction=instruction,
274
- fps=inputs.get('fps', self.fps))
275
- for document in documents]
 
 
276
  final_scores = []
277
  for pair in pairs:
278
  inputs = self.tokenize([pair])
279
  inputs = inputs.to(self.model.device)
280
  scores = self.compute_scores(inputs)
281
  final_scores.extend(scores)
282
- return final_scores
 
10
 
11
  logger = logging.getLogger(__name__)
12
 
13
+ MAX_LENGTH = 8192
14
  IMAGE_BASE_FACTOR = 16
15
  IMAGE_FACTOR = IMAGE_BASE_FACTOR * 2
16
  MIN_PIXELS = 4 * IMAGE_FACTOR * IMAGE_FACTOR # 4 tokens
 
38
  except:
39
  break
40
  sampled_frames.append(single_frame_path)
41
+ # Pad with last frame if total frames less than num_segments
42
  while len(sampled_frames) < num_segments:
43
  sampled_frames.append(frames[last_frame_id])
44
  return sampled_frames[:max_segments]
45
 
 
 
46
  class Qwen3VLReranker():
47
  def __init__(
48
  self,
49
  model_name_or_path: str,
50
+ max_length: int = MAX_LENGTH,
51
+ min_pixels: int = MIN_PIXELS,
52
+ max_pixels: int = MAX_PIXELS,
53
+ total_pixels: int = MAX_TOTAL_PIXELS,
54
+ fps: float = FPS,
55
+ num_frames: int = MAX_FRAMES,
56
+ max_frames: int = MAX_FRAMES,
57
+ default_instruction: str = "Given a search query, retrieve relevant candidates that answer the query.",
58
  **kwargs,
59
  ):
60
 
61
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
62
 
63
+ self.max_length = max_length
64
+ self.min_pixels = min_pixels
65
+ self.max_pixels = max_pixels
66
+ self.total_pixels = total_pixels
67
+ self.fps = fps
68
+ self.num_frames = num_frames
69
+ self.max_frames = max_frames
 
 
 
70
 
71
+ self.default_instruction = default_instruction
72
 
73
  lm = Qwen3VLForConditionalGeneration.from_pretrained(
74
  model_name_or_path,
 
76
  ).to(self.device)
77
 
78
  self.model = lm.model
 
79
  self.processor = AutoProcessor.from_pretrained(
80
  model_name_or_path, trust_remote_code=True,
81
  padding_side='left'
82
  )
83
+ self.model.eval()
84
 
85
  token_true_id = self.processor.tokenizer.get_vocab()["yes"]
86
  token_false_id = self.processor.tokenizer.get_vocab()["no"]
87
  self.score_linear = self.get_binary_linear(lm, token_true_id, token_false_id)
 
88
  self.score_linear.eval()
89
  self.score_linear.to(self.device).to(self.model.dtype)
90
 
 
119
 
120
  special_tokens_set = set(special_tokens)
121
 
122
+ # Calculate budget: how many non-special tokens we can keep
123
  num_special = sum(1 for token in tokens if token in special_tokens_set)
 
 
124
  num_non_special_to_keep = max_length - num_special
125
 
126
+ # Build final list according to budget
127
  final_tokens = []
128
  non_special_kept_count = 0
129
  for token in tokens:
 
130
  if token in special_tokens_set:
131
  final_tokens.append(token)
 
132
  elif non_special_kept_count < num_non_special_to_keep:
133
  final_tokens.append(token)
134
  non_special_kept_count += 1
 
135
 
136
  return final_tokens
137
 
 
141
  try:
142
  images, videos, video_kwargs = process_vision_info(
143
  pairs, image_patch_size=16,
144
+ return_video_kwargs=True,
145
+ return_video_metadata=True
146
  )
147
  except Exception as e:
148
+ logger.error(f"Error in processing vision info: {e}")
149
  images = None
150
  videos = None
151
  video_kwargs = {'do_sample_frames': False}
 
159
  videos, video_metadatas = list(videos), list(video_metadatas)
160
  else:
161
  video_metadatas = None
162
+ inputs = self.processor(
163
+ text=text,
164
+ images=images,
165
+ videos=videos,
166
+ video_metadata=video_metadatas,
167
+ truncation=False,
168
+ padding=False,
169
+ do_resize=False,
170
+ **video_kwargs
171
+ )
172
  for i, ele in enumerate(inputs['input_ids']):
173
+ inputs['input_ids'][i] = self.truncate_tokens_optimized(
174
+ inputs['input_ids'][i][:-5], max_length,
175
+ self.processor.tokenizer.all_special_ids
176
+ ) + inputs['input_ids'][i][-5:]
177
+ temp_inputs = self.processor.tokenizer.pad(
178
+ {'input_ids': inputs['input_ids']}, padding=True,
179
+ return_tensors="pt", max_length=self.max_length
180
+ )
181
  for key in temp_inputs:
182
  inputs[key] = temp_inputs[key]
183
  return inputs
184
 
185
+ def format_mm_content(
186
+ self,
187
+ text, image, video,
188
+ prefix='Query:',
189
+ fps=None, max_frames=None,
190
+ ):
191
  content = []
192
 
193
  content.append({'type': 'text', 'text': prefix})
194
  if not text and not image and not video:
195
+ content.append({'type': 'text', 'text': "NULL"})
196
  return content
197
+
198
  if video:
199
  video_content = None
200
+ video_kwargs = { 'total_pixels': self.total_pixels }
201
  if isinstance(video, list):
202
  video_content = video
203
  if self.num_frames is not None or self.max_frames is not None:
204
+ video_content = self._sample_frames(video_content, self.num_frames, self.max_frames)
205
+ video_content = [
206
+ ('file://' + ele if isinstance(ele, str) else ele)
207
+ for ele in video_content
208
+ ]
209
  elif isinstance(video, str):
210
+ video_content = video if video.startswith(('http://', 'https://')) else 'file://' + video
211
+ video_kwargs = {'fps': fps or self.fps, 'max_frames': max_frames or self.max_frames,}
212
+ else:
213
+ raise TypeError(f"Unrecognized video type: {type(video)}")
214
+
215
  if video_content:
216
+ content.append({
217
+ 'type': 'video', 'video': video_content,
218
+ **video_kwargs
219
+ })
220
 
221
  if image:
222
  image_content = None
223
  if isinstance(image, Image.Image):
224
  image_content = image
 
 
 
225
  elif isinstance(image, str):
226
+ image_content = image if image.startswith(('http', 'oss')) else 'file://' + image
227
  else:
228
+ raise TypeError(f"Unrecognized image type: {type(image)}")
229
+
230
  if image_content:
231
+ content.append({
232
+ 'type': 'image', 'image': image_content,
233
+ "min_pixels": self.min_pixels,
234
+ "max_pixels": self.max_pixels
235
+ })
236
 
237
  if text:
238
  content.append({'type': 'text', 'text': text})
 
242
  self,
243
  query_text, query_image, query_video,
244
  doc_text, doc_image, doc_video,
245
+ instruction=None,
246
+ fps=None, max_frames=None
247
  ):
248
  inputs = []
249
  inputs.append({
 
263
  "type": "text",
264
  "text": '<Instruct>: ' + instruct
265
  })
266
+ query_content = self.format_mm_content(
267
+ query_text, query_image, query_video, prefix='<Query>:',
268
+ fps=fps, max_frames=max_frames
269
+ )
270
  contents.extend(query_content)
271
+ doc_content = self.format_mm_content(
272
+ doc_text, doc_image, doc_video, prefix='\n<Document>:',
273
+ fps=fps, max_frames=max_frames
274
+ )
275
  contents.extend(doc_content)
276
  inputs.append({
277
  "role": "user",
 
298
  document.get('image', None),
299
  document.get('video', None),
300
  instruction=instruction,
301
+ fps=inputs.get('fps', self.fps),
302
+ max_frames=inputs.get('max_frames', self.max_frames)
303
+ ) for document in documents]
304
+
305
  final_scores = []
306
  for pair in pairs:
307
  inputs = self.tokenize([pair])
308
  inputs = inputs.to(self.model.device)
309
  scores = self.compute_scores(inputs)
310
  final_scores.extend(scores)
311
+ return final_scores