qxyf commited on
Commit
bbe53f9
·
1 Parent(s): 22d5b8d
Files changed (2) hide show
  1. app.py +50 -31
  2. scripts/qwen3_vl_reranker.py +311 -0
app.py CHANGED
@@ -3,6 +3,7 @@ import torch
3
  import numpy as np
4
  from transformers import AutoProcessor, AutoModelForSequenceClassification
5
  from scripts.qwen3_vl_embedding import Qwen3VLEmbedder
 
6
  import cv2
7
  import os
8
  from typing import List
@@ -22,61 +23,79 @@ embedder = Qwen3VLEmbedder(
22
  # Reranker 模型
23
  try:
24
  reranker_id = "Qwen/Qwen3-VL-Reranker-2B"
25
- reranker_processor = AutoProcessor.from_pretrained(reranker_id)
26
- reranker_model = AutoModelForSequenceClassification.from_pretrained(
27
- reranker_id,
28
  torch_dtype=torch.float16,
29
- device_map="cpu"
30
  )
31
- reranker_model.eval()
32
- print("Qwen3-VL-Reranker-2B 加载完成")
33
  except Exception as e:
34
  print(f"Reranker 加载失败: {str(e)}")
35
  reranker_model = None
36
- reranker_processor = None
37
 
38
  # ================================
39
  # Reranker 函数(核心修复版)
40
  # ================================
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  def qwen_vl_rerank(query_content: list, candidates: list):
43
  """
44
  query_content: list of dict,例如 [{"type": "text", "text": "..."}, {"type": "image", "image": pil_img}]
45
  candidates: list of list,每个元素是 candidate 的 content list,例如 [{"type": "image", "image": frame}]
46
  返回: list of (original_index, score) 按分数降序
47
  """
48
- if not candidates or reranker_model is None or reranker_processor is None:
49
  return []
50
 
51
- # 构造符合 Qwen3-VL-Reranker 格式的输入
52
- formatted_messages = []
 
 
53
  for cand in candidates:
54
- # 每个 pair messages 格式:query 和 document 拼接成一个 content list
55
- combined_content = query_content + cand
56
- formatted_messages.append([
57
- {"role": "user", "content": combined_content}
58
- ])
 
 
 
 
 
 
 
 
 
 
59
 
60
  try:
61
- # 使processor 处理批量
62
- inputs = reranker_processor.apply_chat_template(
63
- formatted_messages,
64
- tokenize=True,
65
- add_generation_prompt=False,
66
- return_tensors="pt",
67
- padding=True
68
- )
69
- inputs = {k: v.to(reranker_model.device) for k, v in inputs.items()}
70
-
71
- with torch.no_grad():
72
- outputs = reranker_model(**inputs)
73
- logits = outputs.logits.squeeze(-1) # [batch_size]
74
- scores = torch.sigmoid(logits).cpu().numpy() # 转换为 0~1 分数
75
-
76
  # 返回 (原索引, 分数) 并排序
77
  indexed_scores = sorted(enumerate(scores), key=lambda x: x[1], reverse=True)
78
  return indexed_scores
79
-
80
  except Exception as e:
81
  print(f"Reranker 执行失败: {str(e)}")
82
  return []
 
3
  import numpy as np
4
  from transformers import AutoProcessor, AutoModelForSequenceClassification
5
  from scripts.qwen3_vl_embedding import Qwen3VLEmbedder
6
+ from scripts.qwen3_vl_reranker import Qwen3VLReranker
7
  import cv2
8
  import os
9
  from typing import List
 
23
  # Reranker 模型
24
  try:
25
  reranker_id = "Qwen/Qwen3-VL-Reranker-2B"
26
+ # 使用官方脚本加载模型
27
+ reranker_model = Qwen3VLReranker(
28
+ model_name_or_path=reranker_id,
29
  torch_dtype=torch.float16,
30
+ device_map="cuda" if torch.cuda.is_available() else "cpu"
31
  )
32
+ print("Qwen3-VL-Reranker-2B 加载完成 (Using official script)")
 
33
  except Exception as e:
34
  print(f"Reranker 加载失败: {str(e)}")
35
  reranker_model = None
 
36
 
37
  # ================================
38
  # Reranker 函数(核心修复版)
39
  # ================================
40
 
41
def extract_content_from_list(content_list):
    """Split a mixed content list into its text, image, and video parts.

    Args:
        content_list: list of dicts like {"type": "text", "text": ...},
            {"type": "image", "image": ...} or {"type": "video", "video": ...}.

    Returns:
        (text, image, video): all text entries joined with newlines (or
        None if there were none), plus the first image and the first video
        found (or None). Extra images/videos beyond the first are dropped.
    """
    collected_text = [entry['text'] for entry in content_list if entry['type'] == 'text']
    collected_images = [entry['image'] for entry in content_list if entry['type'] == 'image']
    collected_videos = [entry['video'] for entry in content_list if entry['type'] == 'video']

    text = "\n".join(collected_text) if collected_text else None
    image = collected_images[0] if collected_images else None
    video = collected_videos[0] if collected_videos else None
    return text, image, video
60
+
61
def qwen_vl_rerank(query_content: list, candidates: list):
    """Rerank candidates against a query with the Qwen3-VL reranker.

    Args:
        query_content: content list for the query, e.g.
            [{"type": "text", "text": "..."}, {"type": "image", "image": pil_img}].
        candidates: one content list per candidate, e.g.
            [[{"type": "image", "image": frame}], ...].

    Returns:
        List of (original_index, score) tuples sorted by score descending;
        [] when there are no candidates, the model is unavailable, or
        scoring fails.
    """
    # Nothing to rank, or the model failed to load at startup.
    if not candidates or reranker_model is None:
        return []

    # Flatten the query and every candidate into the {text, image, video}
    # structure that Qwen3VLReranker.process expects.
    query_text, query_image, query_video = extract_content_from_list(query_content)
    document_entries = [
        dict(zip(("text", "image", "video"), extract_content_from_list(candidate)))
        for candidate in candidates
    ]

    request = {
        "query": {
            "text": query_text,
            "image": query_image,
            "video": query_video
        },
        "documents": document_entries
    }

    try:
        relevance = reranker_model.process(request)
        # Pair each score with its original candidate index, best first.
        return sorted(enumerate(relevance), key=lambda pair: pair[1], reverse=True)
    except Exception as e:
        print(f"Reranker 执行失败: {str(e)}")
        return []
scripts/qwen3_vl_reranker.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import logging
4
+
5
+ from PIL import Image
6
+ from scipy import special
7
+ from typing import List
8
+ from qwen_vl_utils import process_vision_info
9
+ from transformers import Qwen3VLForConditionalGeneration, AutoProcessor
10
+
11
logger = logging.getLogger(__name__)

# Token / image budgets. IMAGE_FACTOR (32 px) is two 16-px base patches per
# side, so the pixel budgets below are expressed as token counts times
# IMAGE_FACTOR**2 pixels.
MAX_LENGTH = 8192
IMAGE_BASE_FACTOR = 16
IMAGE_FACTOR = IMAGE_BASE_FACTOR * 2
MIN_PIXELS = 4 * IMAGE_FACTOR * IMAGE_FACTOR  # 4 tokens
MAX_PIXELS = 1280 * IMAGE_FACTOR * IMAGE_FACTOR  # 1280 tokens
MAX_RATIO = 200

# Video sampling budgets.
FRAME_FACTOR = 2
FPS = 1
MIN_FRAMES = 2
MAX_FRAMES = 64
MIN_TOTAL_PIXELS = 1 * FRAME_FACTOR * MIN_PIXELS  # 1 frames
MAX_TOTAL_PIXELS = 4 * FRAME_FACTOR * MAX_PIXELS  # 4 frames
26
+
27
+
28
def sample_frames(frames, num_segments, max_segments):
    """Uniformly sample `num_segments` frames, capped at `max_segments`.

    Frames are picked at evenly spaced indices across the clip; if fewer
    frames are available than requested, the last sampled frame is
    repeated as padding.

    Args:
        frames: indexable sequence of frames (paths or objects).
        num_segments: number of evenly spaced samples to draw.
        max_segments: hard cap on the returned list length.

    Returns:
        List of sampled frames, at most `max_segments` long. An empty
        `frames` yields [] (previously this raised an uncaught IndexError
        in the padding loop).
    """
    if len(frames) == 0:
        return []
    duration = len(frames)
    frame_id_array = np.linspace(0, duration - 1, num_segments, dtype=int)
    frame_id_list = frame_id_array.tolist()
    last_frame_id = frame_id_list[-1]

    sampled_frames = []
    for frame_idx in frame_id_list:
        try:
            single_frame_path = frames[frame_idx]
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # are no longer swallowed; an out-of-range index is the only
        # failure this loop is meant to tolerate.
        except IndexError:
            break
        sampled_frames.append(single_frame_path)
    # Pad with the last frame if total frames are fewer than num_segments.
    while len(sampled_frames) < num_segments:
        sampled_frames.append(frames[last_frame_id])
    return sampled_frames[:max_segments]
45
+
46
class Qwen3VLReranker():
    """Multimodal reranker wrapping a Qwen3-VL causal LM.

    Each (query, document) pair — text, image and/or video on either side —
    is rendered as a yes/no relevance prompt. Instead of generating, the
    final hidden state is projected onto the (w_yes - w_no) row of the LM
    head; the sigmoid of that single logit is the relevance score.
    """

    def __init__(
        self,
        model_name_or_path: str,
        max_length: int = MAX_LENGTH,
        min_pixels: int = MIN_PIXELS,
        max_pixels: int = MAX_PIXELS,
        total_pixels: int = MAX_TOTAL_PIXELS,
        fps: float = FPS,
        num_frames: int = MAX_FRAMES,
        max_frames: int = MAX_FRAMES,
        default_instruction: str = "Given a search query, retrieve relevant candidates that answer the query.",
        **kwargs,
    ):
        """Load model + processor and build the yes/no scoring head.

        Args:
            model_name_or_path: HF hub id or local path of the reranker.
            max_length: token budget per pair (special tokens always kept).
            min_pixels: lower pixel budget per image.
            max_pixels: upper pixel budget per image.
            total_pixels: overall pixel budget for list-of-frames videos.
            fps: default sampling rate for path/URL videos.
            num_frames: frames to draw when sampling a frame list.
            max_frames: hard cap on sampled frames.
            default_instruction: used when the caller supplies none.
            **kwargs: forwarded to `from_pretrained` (e.g. torch_dtype).
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.max_length = max_length
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        self.total_pixels = total_pixels
        self.fps = fps
        self.num_frames = num_frames
        self.max_frames = max_frames

        self.default_instruction = default_instruction

        lm = Qwen3VLForConditionalGeneration.from_pretrained(
            model_name_or_path,
            trust_remote_code=True, **kwargs
        ).to(self.device)

        # Keep only the backbone; the generative lm_head is collapsed into
        # `score_linear` below.
        self.model = lm.model
        self.processor = AutoProcessor.from_pretrained(
            model_name_or_path, trust_remote_code=True,
            padding_side='left'  # left padding so [:, -1] is the true last token
        )
        self.model.eval()

        token_true_id = self.processor.tokenizer.get_vocab()["yes"]
        token_false_id = self.processor.tokenizer.get_vocab()["no"]
        self.score_linear = self.get_binary_linear(lm, token_true_id, token_false_id)
        self.score_linear.eval()
        self.score_linear.to(self.device).to(self.model.dtype)

    def get_binary_linear(self, model, token_yes, token_no):
        """Return a Linear(D, 1, bias=False) with weight lm_head[yes] - lm_head[no].

        Applied to a hidden state it yields logit(yes) - logit(no), i.e. the
        log-odds of "yes" restricted to the two answer tokens.
        """
        lm_head_weights = model.lm_head.weight.data

        weight_yes = lm_head_weights[token_yes]
        weight_no = lm_head_weights[token_no]

        D = weight_yes.size()[0]
        linear_layer = torch.nn.Linear(D, 1, bias=False)
        with torch.no_grad():
            linear_layer.weight[0] = weight_yes - weight_no
        return linear_layer

    @torch.no_grad()
    def compute_scores(self, inputs):
        """Score a tokenized batch; returns a list of floats in (0, 1)."""
        batch_scores = self.model(**inputs).last_hidden_state[:, -1]
        scores = self.score_linear(batch_scores)
        scores = torch.sigmoid(scores).squeeze(-1).cpu().detach().tolist()
        return scores

    def truncate_tokens_optimized(
        self,
        tokens: List[int],
        max_length: int,
        special_tokens: List[int]
    ) -> List[int]:
        """Truncate `tokens` to `max_length` without dropping special tokens.

        Non-special tokens are kept in order until the budget
        (max_length - count of special tokens) is exhausted.
        Annotated as int token ids: `tokenize` passes input_ids and
        `tokenizer.all_special_ids` here (the original `List[str]` hints
        were wrong).
        """
        if len(tokens) <= max_length:
            return tokens

        special_tokens_set = set(special_tokens)

        # Budget: how many non-special tokens we can keep.
        num_special = sum(1 for token in tokens if token in special_tokens_set)
        num_non_special_to_keep = max_length - num_special

        # Build final list according to budget.
        final_tokens = []
        non_special_kept_count = 0
        for token in tokens:
            if token in special_tokens_set:
                final_tokens.append(token)
            elif non_special_kept_count < num_non_special_to_keep:
                final_tokens.append(token)
                non_special_kept_count += 1

        return final_tokens

    def tokenize(self, pairs: list, **kwargs):
        """Turn chat-format pair messages into padded model inputs.

        Falls back to a minimal "NULL" prompt if vision preprocessing fails,
        so one bad image/video cannot crash the whole call.
        """
        max_length = self.max_length
        text = self.processor.apply_chat_template(pairs, tokenize=False, add_generation_prompt=True)
        try:
            images, videos, video_kwargs = process_vision_info(
                pairs, image_patch_size=16,
                return_video_kwargs=True,
                return_video_metadata=True
            )
        except Exception as e:
            logger.error(f"Error in processing vision info: {e}")
            images = None
            videos = None
            video_kwargs = {'do_sample_frames': False}
            text = self.processor.apply_chat_template(
                [{'role': 'user', 'content': [{'type': 'text', 'text': 'NULL'}]}],
                add_generation_prompt=True, tokenize=False
            )

        if videos is not None:
            # process_vision_info returned (frames, metadata) tuples.
            videos, video_metadatas = zip(*videos)
            videos, video_metadatas = list(videos), list(video_metadatas)
        else:
            video_metadatas = None
        inputs = self.processor(
            text=text,
            images=images,
            videos=videos,
            video_metadata=video_metadatas,
            truncation=False,
            padding=False,
            do_resize=False,
            **video_kwargs
        )
        # Truncate each sequence while preserving its last 5 tokens (the
        # chat-template suffix) and every special token.
        for i, ele in enumerate(inputs['input_ids']):
            inputs['input_ids'][i] = self.truncate_tokens_optimized(
                inputs['input_ids'][i][:-5], max_length,
                self.processor.tokenizer.all_special_ids
            ) + inputs['input_ids'][i][-5:]
        temp_inputs = self.processor.tokenizer.pad(
            {'input_ids': inputs['input_ids']}, padding=True,
            return_tensors="pt", max_length=self.max_length
        )
        for key in temp_inputs:
            inputs[key] = temp_inputs[key]
        return inputs

    def format_mm_content(
        self,
        text, image, video,
        prefix='Query:',
        fps=None, max_frames=None,
    ):
        """Build the content list for one side (query or document).

        Emits the prefix, then video/image/text entries in that order; an
        entirely empty side becomes the literal text "NULL".
        """
        content = []

        content.append({'type': 'text', 'text': prefix})
        if not text and not image and not video:
            content.append({'type': 'text', 'text': "NULL"})
            return content

        if video:
            video_content = None
            video_kwargs = { 'total_pixels': self.total_pixels }
            if isinstance(video, list):
                video_content = video
                if self.num_frames is not None or self.max_frames is not None:
                    # BUG FIX: was `self._sample_frames(...)`, a method that
                    # does not exist — the module-level sample_frames helper
                    # is the intended one (the old code raised AttributeError
                    # for any list-of-frames video).
                    video_content = sample_frames(video_content, self.num_frames, self.max_frames)
                video_content = [
                    ('file://' + ele if isinstance(ele, str) else ele)
                    for ele in video_content
                ]
            elif isinstance(video, str):
                video_content = video if video.startswith(('http://', 'https://')) else 'file://' + video
                # Path/URL videos are decoded by the processor, so the
                # fps/max_frames budget replaces the total_pixels budget.
                video_kwargs = {'fps': fps or self.fps, 'max_frames': max_frames or self.max_frames,}
            else:
                raise TypeError(f"Unrecognized video type: {type(video)}")

            if video_content:
                content.append({
                    'type': 'video', 'video': video_content,
                    **video_kwargs
                })

        if image:
            image_content = None
            if isinstance(image, Image.Image):
                image_content = image
            elif isinstance(image, str):
                image_content = image if image.startswith(('http', 'oss')) else 'file://' + image
            else:
                raise TypeError(f"Unrecognized image type: {type(image)}")

            if image_content:
                content.append({
                    'type': 'image', 'image': image_content,
                    "min_pixels": self.min_pixels,
                    "max_pixels": self.max_pixels
                })

        if text:
            content.append({'type': 'text', 'text': text})
        return content

    def format_mm_instruction(
        self,
        query_text, query_image, query_video,
        doc_text, doc_image, doc_video,
        instruction=None,
        fps=None, max_frames=None
    ):
        """Assemble the full chat messages (system + user) for one pair."""
        inputs = []
        inputs.append({
            "role": "system",
            "content": [{
                "type": "text",
                "text": "Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\"."
            }
            ]
        })
        # An (instruction, text) tuple in query_text overrides `instruction`.
        if isinstance(query_text, tuple):
            instruct, query_text = query_text
        else:
            instruct = instruction
        contents = []
        contents.append({
            "type": "text",
            "text": '<Instruct>: ' + instruct
        })
        query_content = self.format_mm_content(
            query_text, query_image, query_video, prefix='<Query>:',
            fps=fps, max_frames=max_frames
        )
        contents.extend(query_content)
        doc_content = self.format_mm_content(
            doc_text, doc_image, doc_video, prefix='\n<Document>:',
            fps=fps, max_frames=max_frames
        )
        contents.extend(doc_content)
        inputs.append({
            "role": "user",
            "content": contents
        })
        return inputs

    def process(
        self,
        inputs,
    ) -> list[float]:
        """Score every document in `inputs` against its query.

        Args:
            inputs: dict with "query" ({text, image, video}), "documents"
                (list of such dicts), and optional "instruction", "fps",
                "max_frames".

        Returns:
            One relevance score per document, in input order; [] when the
            query or document list is missing/empty. (Return annotation
            fixed: scores are floats, not tensors.)
        """
        instruction = inputs.get('instruction', self.default_instruction)

        query = inputs.get("query", {})
        documents = inputs.get("documents", [])
        if not query or not documents:
            return []

        pairs = [self.format_mm_instruction(
            query.get('text', None),
            query.get('image', None),
            query.get('video', None),
            document.get('text', None),
            document.get('image', None),
            document.get('video', None),
            instruction=instruction,
            fps=inputs.get('fps', self.fps),
            max_frames=inputs.get('max_frames', self.max_frames)
        ) for document in documents]

        final_scores = []
        # One pair at a time; a distinct local avoids the original's
        # shadowing of the `inputs` argument inside this loop.
        for pair in pairs:
            model_inputs = self.tokenize([pair])
            model_inputs = model_inputs.to(self.model.device)
            scores = self.compute_scores(model_inputs)
            final_scores.extend(scores)
        return final_scores