KinetoLabs Claude Opus 4.5 committed
Commit 455c786 · 1 Parent(s): f3ebc82

Fix embedding/reranker loading with official Qwen3-VL classes

Root cause: AutoModel.from_pretrained() loads base transformer
instead of specialized embedding/reranking variants.

Changes:
- Vendor official scripts from QwenLM/Qwen3-VL-Embedding repo
- Replace AutoModel with Qwen3VLEmbedder for embedding model
- Replace AutoModel with Qwen3VLReranker for reranker model
- Update embed()/rerank() methods to use official process() API

The official loaders handle:
- Proper last-token pooling and L2 normalization (embedding)
- Yes/no binary scoring from LM head weights (reranker)

This eliminates the fallback L2 norm heuristic scoring that was
producing "less accurate" results.
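
For reference, the new loading path is exercised roughly like this (a
minimal sketch based on the vendored scripts below; the model names are
the 8B checkpoints from the vendored docstrings):

    from scripts.qwen3_vl import Qwen3VLEmbedder, Qwen3VLReranker

    # Embedder.process() returns L2-normalized embeddings, one row per input
    embedder = Qwen3VLEmbedder(model_name_or_path="Qwen/Qwen3-VL-Embedding-8B")
    vecs = embedder.process([{"text": "hello world"}])

    # Reranker.process() returns a sigmoid score in [0, 1] per document
    reranker = Qwen3VLReranker(model_name_or_path="Qwen/Qwen3-VL-Reranker-8B")
    scores = reranker.process({
        "query": {"text": "greeting"},
        "documents": [{"text": "hello"}, {"text": "stock report"}],
    })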

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

models/real.py CHANGED
@@ -2,6 +2,11 @@
 
 This module loads the actual Qwen3-VL models for production use.
 Requires ~90GB VRAM (4xL4 with 96GB total).
+
+Model Loading:
+- Vision: Qwen3VLMoeForConditionalGeneration (standard transformers)
+- Embedding: Qwen3VLEmbedder (official scripts from QwenLM/Qwen3-VL-Embedding)
+- Reranker: Qwen3VLReranker (official scripts from QwenLM/Qwen3-VL-Embedding)
 """
 
 import json
@@ -28,7 +33,7 @@ class RealModelStack:
 
     def load_all(self) -> "RealModelStack":
         """Load all models with device_map='auto' for multi-GPU distribution."""
-        from transformers import AutoModel, AutoProcessor
+        from transformers import AutoProcessor
 
         device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
         logger.info(f"Loading models on {device_type}")
@@ -71,34 +76,30 @@ class RealModelStack:
         )
         logger.info(f"Fallback vision model loaded in {time.time() - vision_start:.2f}s")
 
-        # Embedding model (~16GB in BF16)
+        # Embedding model (~16GB in BF16) - Using official Qwen3VLEmbedder
         logger.info(f"Loading embedding model: {settings.embedding_model}")
         embed_start = time.time()
-        self.models["embedding"] = AutoModel.from_pretrained(
-            settings.embedding_model,
+        from scripts.qwen3_vl import Qwen3VLEmbedder
+
+        self.models["embedding"] = Qwen3VLEmbedder(
+            model_name_or_path=settings.embedding_model,
             torch_dtype=torch.bfloat16,
-            device_map="auto",
-            trust_remote_code=True,
-        )
-        self.processors["embedding"] = AutoProcessor.from_pretrained(
-            settings.embedding_model,
-            trust_remote_code=True,
         )
+        # Processor is internal to Qwen3VLEmbedder, but store reference for compatibility
+        self.processors["embedding"] = self.models["embedding"].processor
         logger.info(f"Embedding model loaded in {time.time() - embed_start:.2f}s")
 
-        # Reranker model (~16GB in BF16)
+        # Reranker model (~16GB in BF16) - Using official Qwen3VLReranker
        logger.info(f"Loading reranker model: {settings.reranker_model}")
         reranker_start = time.time()
-        self.models["reranker"] = AutoModel.from_pretrained(
-            settings.reranker_model,
+        from scripts.qwen3_vl import Qwen3VLReranker
+
+        self.models["reranker"] = Qwen3VLReranker(
+            model_name_or_path=settings.reranker_model,
             torch_dtype=torch.bfloat16,
-            device_map="auto",
-            trust_remote_code=True,
-        )
-        self.processors["reranker"] = AutoProcessor.from_pretrained(
-            settings.reranker_model,
-            trust_remote_code=True,
         )
+        # Processor is internal to Qwen3VLReranker, but store reference for compatibility
+        self.processors["reranker"] = self.models["reranker"].processor
         logger.info(f"Reranker model loaded in {time.time() - reranker_start:.2f}s")
 
         self.loaded = True
@@ -370,80 +371,68 @@ IMPORTANT: Return ONLY valid JSON, no additional text."""
 class RealEmbeddingModel:
     """Wrapper for real embedding model inference.
 
-    Uses last-token pooling per official Qwen3-VL-Embedding implementation:
-    https://github.com/QwenLM/Qwen3-VL-Embedding
+    Uses the official Qwen3VLEmbedder from QwenLM/Qwen3-VL-Embedding.
+    The model handles last-token pooling and L2 normalization internally.
     """
 
     def __init__(self, model, processor):
+        """Initialize with Qwen3VLEmbedder instance.
+
+        Args:
+            model: Qwen3VLEmbedder instance (official loader)
+            processor: Processor (stored for compatibility, but model has its own)
+        """
         self.model = model
         self.processor = processor
 
-    @staticmethod
-    def _pooling_last(hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
-        """Extract the last valid token's hidden state based on attention mask.
-
-        This is the official pooling method from Qwen3-VL-Embedding.
-        It finds the last position where attention_mask == 1 and extracts that token.
-        """
-        # Flip attention mask to find last 1 position
-        flipped_tensor = attention_mask.flip(dims=[1])
-        last_one_positions = flipped_tensor.argmax(dim=1)
-        col = attention_mask.shape[1] - last_one_positions - 1
-        row = torch.arange(hidden_state.shape[0], device=hidden_state.device)
-        return hidden_state[row, col]
-
-    def embed(self, text: str) -> list[float]:
-        """Generate embedding for text using last-token pooling.
-
-        Per Qwen3-VL-Embedding: extracts the hidden state of the last valid token,
-        then applies L2 normalization.
+    def embed(self, text: str) -> list[float]:
+        """Generate embedding for text using official Qwen3VLEmbedder.
+
+        The official model.process() handles:
+        - Tokenization and preprocessing
+        - Last-token pooling
+        - L2 normalization
+
+        Args:
+            text: Input text to embed
+
+        Returns:
+            List of floats representing the embedding (4096-dim for 8B model)
         """
         try:
-            # Tokenize input
-            inputs = self.processor(
-                text,
-                return_tensors="pt",
-                padding=True,
-                truncation=True,
-                max_length=512,
-            )
-
-            # Move to model device
-            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
-
-            # Generate embeddings
-            with torch.no_grad():
-                outputs = self.model(**inputs)
-
-            # Use last-token pooling (official Qwen3-VL-Embedding method)
-            # outputs.last_hidden_state shape: (batch, seq_len, hidden_dim)
-            attention_mask = inputs.get("attention_mask")
-            if attention_mask is not None:
-                embeddings = self._pooling_last(outputs.last_hidden_state, attention_mask)
-            else:
-                # Fallback: use last token if no attention mask
-                embeddings = outputs.last_hidden_state[:, -1, :]
-
-            # L2 normalize (per official implementation)
-            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=-1)
+            # Use official process() API - expects list of dicts
+            inputs = [{"text": text}]
+            embeddings = self.model.process(inputs, normalize=True)
 
+            # embeddings is a tensor of shape (1, hidden_dim)
             return embeddings[0].cpu().tolist()
 
         except Exception as e:
             logger.error(f"Embedding generation failed: {e}")
             # Return zero vector as fallback (4096-dim per Qwen3-VL-Embedding-8B)
-            hidden_size = getattr(self.model.config, "hidden_size", 4096)
+            hidden_size = getattr(self.model.model.config, "hidden_size", 4096)
             return [0.0] * hidden_size
 
     def embed_batch(self, texts: list[str]) -> list[list[float]]:
-        """Generate embeddings for a batch of texts."""
-        return [self.embed(text) for text in texts]
+        """Generate embeddings for a batch of texts.
+
+        Uses official batch processing for efficiency.
+        """
+        try:
+            inputs = [{"text": text} for text in texts]
+            embeddings = self.model.process(inputs, normalize=True)
+            return [emb.cpu().tolist() for emb in embeddings]
+        except Exception as e:
+            logger.error(f"Batch embedding generation failed: {e}")
+            hidden_size = getattr(self.model.model.config, "hidden_size", 4096)
+            return [[0.0] * hidden_size for _ in texts]
 
 
 class RealRerankerModel:
     """Wrapper for real reranker model inference.
 
-    Uses the official Qwen3-VL-Reranker scoring method:
+    Uses the official Qwen3VLReranker from QwenLM/Qwen3-VL-Embedding.
+    The model handles yes/no scoring internally via:
     - Extracts "yes" and "no" token weights from the LM head
     - Creates a binary linear layer: weight = yes_weight - no_weight
     - Scores = sigmoid(linear(last_token_hidden_state))
@@ -452,118 +441,48 @@ class RealRerankerModel:
     """
 
     def __init__(self, model, processor):
-        self.model = model
-        self.processor = processor
-        self.score_linear = None
-        self._initialize_score_linear()
-
-    def _initialize_score_linear(self):
-        """Initialize the binary scoring linear layer from LM head weights.
-
-        Per Qwen3-VL-Reranker: the scoring layer uses the difference between
-        "yes" and "no" token embeddings from the language model head.
+        """Initialize with Qwen3VLReranker instance.
+
+        Args:
+            model: Qwen3VLReranker instance (official loader)
+            processor: Processor (stored for compatibility, but model has its own)
         """
-        try:
-            # Get tokenizer vocab to find yes/no token IDs
-            tokenizer = self.processor.tokenizer if hasattr(self.processor, 'tokenizer') else self.processor
-            vocab = tokenizer.get_vocab()
-
-            # Find yes/no token IDs
-            token_yes_id = vocab.get("yes")
-            token_no_id = vocab.get("no")
-
-            if token_yes_id is None or token_no_id is None:
-                logger.warning("Could not find 'yes'/'no' tokens in vocab, using fallback scoring")
-                return
-
-            # Get LM head weights
-            if not hasattr(self.model, 'lm_head'):
-                logger.warning("Model does not have lm_head, using fallback scoring")
-                return
-
-            lm_head_weights = self.model.lm_head.weight.data
-
-            # Extract yes/no weights
-            weight_yes = lm_head_weights[token_yes_id]
-            weight_no = lm_head_weights[token_no_id]
-
-            # Create binary linear layer: weight = yes - no
-            hidden_size = weight_yes.shape[0]
-            self.score_linear = torch.nn.Linear(hidden_size, 1, bias=False)
-            self.score_linear.weight.data[0] = weight_yes - weight_no
-            self.score_linear = self.score_linear.to(self.model.device)
-            self.score_linear.eval()
-
-            logger.info(f"Initialized reranker score linear from yes/no LM head weights (hidden_size={hidden_size})")
-
-        except Exception as e:
-            logger.warning(f"Failed to initialize score linear from LM head: {e}, using fallback scoring")
-            self.score_linear = None
-
-    def rerank(self, query: str, documents: list[str]) -> list[float]:
-        """Rerank documents by relevance to query.
-
-        Returns a list of relevance scores (0-1) for each document.
-        Higher scores indicate more relevant documents.
+        self.model = model
+        self.processor = processor
+
+    def rerank(self, query: str, documents: list[str]) -> list[float]:
+        """Rerank documents by relevance to query using official Qwen3VLReranker.
+
+        The official model.process() handles:
+        - Proper message formatting
+        - Tokenization
+        - Yes/no scoring with LM head weights
+        - Sigmoid normalization
+
+        Args:
+            query: The search query
+            documents: List of documents to rerank
+
+        Returns:
+            List of relevance scores (0-1) for each document.
+            Higher scores indicate more relevant documents.
         """
         if not documents:
             return []
 
-        scores = []
-        for doc in documents:
-            try:
-                score = self._score_pair(query, doc)
-                scores.append(score)
-            except Exception as e:
-                logger.warning(f"Reranking failed for document: {e}")
-                scores.append(0.0)
-
-        return scores
-
-    def _score_pair(self, query: str, document: str) -> float:
-        """Score a single query-document pair using official Qwen3-VL-Reranker method."""
-        # Truncate document if too long
-        max_doc_len = 400
-        if len(document) > max_doc_len:
-            document = document[:max_doc_len] + "..."
-
-        # Format as query-document pair
-        pair_text = f"Query: {query}\n\nDocument: {document}"
-
         try:
-            inputs = self.processor(
-                pair_text,
-                return_tensors="pt",
-                padding=True,
-                truncation=True,
-                max_length=512,
-            )
-
-            # Move to model device
-            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
-
-            with torch.no_grad():
-                outputs = self.model(**inputs)
-
-            # Use LAST token hidden state (not CLS/first token)
-            # Per official implementation: last_hidden_state[:, -1]
-            last_token_hidden = outputs.last_hidden_state[:, -1, :]
-
-            if self.score_linear is not None:
-                # Official scoring: linear(last_token) -> sigmoid
-                raw_score = self.score_linear(last_token_hidden)
-                score = torch.sigmoid(raw_score).squeeze(-1).item()
-            else:
-                # Fallback: use L2 norm with better scaling
-                # This is less accurate but provides reasonable ordering
-                norm = last_token_hidden.norm(dim=-1).item()
-                score = min(1.0, max(0.0, norm / 50.0))  # Heuristic scaling
-
-            return score
+            # Use official process() API - expects dict with query and documents
+            inputs = {
+                "instruction": "Retrieve relevant documents for the query.",
+                "query": {"text": query},
+                "documents": [{"text": doc} for doc in documents],
+            }
+            scores = self.model.process(inputs)
+            return scores
 
         except Exception as e:
-            logger.error(f"Reranker scoring failed: {e}")
-            return 0.0
+            logger.error(f"Reranking failed: {e}")
+            return [0.0] * len(documents)
 
     def rerank_with_indices(
         self, query: str, documents: list[str], top_k: int = None
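
Since embed() now returns L2-normalized vectors from the official
pipeline, a plain dot product of two embeddings is already a cosine
similarity. A hypothetical smoke test of the updated wrappers (assumes
the stack and settings defined in this module):

    stack = RealModelStack().load_all()
    embedder = RealEmbeddingModel(stack.models["embedding"], stack.processors["embedding"])

    a = embedder.embed("a cat sitting on a mat")
    b = embedder.embed("a kitten resting on a rug")
    # Both vectors are unit-norm, so the dot product equals cosine similarity
    cosine = sum(x * y for x, y in zip(a, b))

    reranker = RealRerankerModel(stack.models["reranker"], stack.processors["reranker"])
    scores = reranker.rerank("cat on mat", ["a cat sits on a mat", "quarterly earnings"])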
scripts/qwen3_vl/__init__.py ADDED
@@ -0,0 +1,14 @@
+"""Vendored Qwen3-VL embedding and reranker implementations.
+
+Source: https://github.com/QwenLM/Qwen3-VL-Embedding
+License: Apache 2.0
+
+These are the official loading classes for:
+- Qwen/Qwen3-VL-Embedding-8B
+- Qwen/Qwen3-VL-Reranker-8B
+"""
+
+from scripts.qwen3_vl.qwen3_vl_embedding import Qwen3VLEmbedder, Qwen3VLForEmbedding
+from scripts.qwen3_vl.qwen3_vl_reranker import Qwen3VLReranker
+
+__all__ = ["Qwen3VLEmbedder", "Qwen3VLForEmbedding", "Qwen3VLReranker"]
scripts/qwen3_vl/qwen3_vl_embedding.py ADDED
@@ -0,0 +1,393 @@
+"""Official Qwen3-VL Embedding implementation.
+
+Source: https://github.com/QwenLM/Qwen3-VL-Embedding/blob/main/src/models/qwen3_vl_embedding.py
+License: Apache 2.0
+"""
+
+import torch
+import torch.nn.functional as F
+import unicodedata
+import numpy as np
+import logging
+
+from PIL import Image
+from dataclasses import dataclass
+from typing import Optional, List, Union, Dict, Any
+from transformers.models.qwen3_vl.modeling_qwen3_vl import (
+    Qwen3VLPreTrainedModel,
+    Qwen3VLModel,
+    Qwen3VLConfig,
+)
+from transformers.models.qwen3_vl.processing_qwen3_vl import Qwen3VLProcessor
+from transformers.modeling_outputs import ModelOutput
+from transformers.processing_utils import Unpack
+from transformers.utils import TransformersKwargs
+from transformers.cache_utils import Cache
+from qwen_vl_utils.vision_process import process_vision_info
+
+logger = logging.getLogger(__name__)
+
+# Constants for configuration
+MAX_LENGTH = 8192
+IMAGE_BASE_FACTOR = 16
+IMAGE_FACTOR = IMAGE_BASE_FACTOR * 2
+MIN_PIXELS = 4 * IMAGE_FACTOR * IMAGE_FACTOR
+MAX_PIXELS = 1800 * IMAGE_FACTOR * IMAGE_FACTOR
+FPS = 1
+MAX_FRAMES = 64
+FRAME_MAX_PIXELS = 768 * IMAGE_FACTOR * IMAGE_FACTOR
+MAX_TOTAL_PIXELS = 10 * FRAME_MAX_PIXELS
+PAD_TOKEN = "<|endoftext|>"
+
+
+@dataclass
+class Qwen3VLForEmbeddingOutput(ModelOutput):
+    last_hidden_state: Optional[torch.FloatTensor] = None
+    attention_mask: Optional[torch.Tensor] = None
+
+
+class Qwen3VLForEmbedding(Qwen3VLPreTrainedModel):
+    _checkpoint_conversion_mapping = {}
+    accepts_loss_kwargs = False
+    config: Qwen3VLConfig
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = Qwen3VLModel(config)
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.get_input_embeddings()
+
+    def set_input_embeddings(self, value):
+        self.model.set_input_embeddings(value)
+
+    def set_decoder(self, decoder):
+        self.model.set_decoder(decoder)
+
+    def get_decoder(self):
+        return self.model.get_decoder()
+
+    def get_video_features(
+        self,
+        pixel_values_videos: torch.FloatTensor,
+        video_grid_thw: Optional[torch.LongTensor] = None,
+    ):
+        return self.model.get_video_features(pixel_values_videos, video_grid_thw)
+
+    def get_image_features(
+        self,
+        pixel_values: torch.FloatTensor,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+    ):
+        return self.model.get_image_features(pixel_values, image_grid_thw)
+
+    @property
+    def language_model(self):
+        return self.model.language_model
+
+    @property
+    def visual(self):
+        return self.model.visual
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        pixel_values: Optional[torch.Tensor] = None,
+        pixel_values_videos: Optional[torch.FloatTensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+        video_grid_thw: Optional[torch.LongTensor] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        logits_to_keep: Union[int, torch.Tensor] = 0,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> Union[tuple, Qwen3VLForEmbeddingOutput]:
+        outputs = self.model(
+            input_ids=input_ids,
+            pixel_values=pixel_values,
+            pixel_values_videos=pixel_values_videos,
+            image_grid_thw=image_grid_thw,
+            video_grid_thw=video_grid_thw,
+            position_ids=position_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            cache_position=cache_position,
+            **kwargs,
+        )
+        return Qwen3VLForEmbeddingOutput(
+            last_hidden_state=outputs.last_hidden_state,
+            attention_mask=attention_mask,
+        )
+
+
+def sample_frames(
+    frames: List[Union[str, Image.Image]], num_segments: int, max_segments: int
+) -> List[str]:
+    duration = len(frames)
+    frame_id_array = np.linspace(0, duration - 1, num_segments, dtype=int)
+    frame_id_list = frame_id_array.tolist()
+    last_frame_id = frame_id_list[-1]
+
+    sampled_frames = []
+    for frame_idx in frame_id_list:
+        try:
+            sampled_frames.append(frames[frame_idx])
+        except:
+            break
+    while len(sampled_frames) < num_segments:
+        sampled_frames.append(frames[last_frame_id])
+    return sampled_frames[:max_segments]
+
+
+class Qwen3VLEmbedder:
+    """Official Qwen3-VL embedding model wrapper.
+
+    Usage:
+        model = Qwen3VLEmbedder(model_name_or_path="Qwen/Qwen3-VL-Embedding-8B")
+        embeddings = model.process([{"text": "Hello world"}])
+    """
+
+    def __init__(
+        self,
+        model_name_or_path: str,
+        max_length: int = MAX_LENGTH,
+        min_pixels: int = MIN_PIXELS,
+        max_pixels: int = MAX_PIXELS,
+        total_pixels: int = MAX_TOTAL_PIXELS,
+        fps: float = FPS,
+        num_frames: int = MAX_FRAMES,
+        max_frames: int = MAX_FRAMES,
+        default_instruction: str = "Represent the user's input.",
+        **kwargs,
+    ):
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        self.max_length = max_length
+        self.min_pixels = min_pixels
+        self.max_pixels = max_pixels
+        self.total_pixels = total_pixels
+        self.fps = fps
+        self.num_frames = num_frames
+        self.max_frames = max_frames
+        self.default_instruction = default_instruction
+
+        self.model = Qwen3VLForEmbedding.from_pretrained(
+            model_name_or_path, trust_remote_code=True, **kwargs
+        ).to(device)
+        self.processor = Qwen3VLProcessor.from_pretrained(
+            model_name_or_path, padding_side="right"
+        )
+        self.model.eval()
+
+    @property
+    def device(self):
+        return self.model.device
+
+    @torch.no_grad()
+    def forward(self, inputs: Dict[str, Any]) -> Dict[str, torch.Tensor]:
+        outputs = self.model(**inputs)
+        return {
+            "last_hidden_state": outputs.last_hidden_state,
+            "attention_mask": inputs.get("attention_mask"),
+        }
+
+    def _truncate_tokens(self, token_ids: List[int], max_length: int) -> List[int]:
+        if len(token_ids) <= max_length:
+            return token_ids
+
+        special_token_ids = set(self.processor.tokenizer.all_special_ids)
+        num_special = sum(1 for token_idx in token_ids if token_idx in special_token_ids)
+        num_non_special_to_keep = max_length - num_special
+
+        final_token_ids = []
+        non_special_kept_count = 0
+        for token_idx in token_ids:
+            if token_idx in special_token_ids:
+                final_token_ids.append(token_idx)
+            elif non_special_kept_count < num_non_special_to_keep:
+                final_token_ids.append(token_idx)
+                non_special_kept_count += 1
+        return final_token_ids
+
+    def format_model_input(
+        self,
+        text: Optional[str] = None,
+        image: Optional[Union[str, Image.Image]] = None,
+        video: Optional[Union[str, List[Union[str, Image.Image]]]] = None,
+        instruction: Optional[str] = None,
+        fps: Optional[float] = None,
+        max_frames: Optional[int] = None,
+    ) -> List[Dict]:
+
+        if instruction:
+            instruction = instruction.strip()
+            if instruction and not unicodedata.category(instruction[-1]).startswith("P"):
+                instruction = instruction + "."
+
+        content = []
+        conversation = [
+            {
+                "role": "system",
+                "content": [
+                    {"type": "text", "text": instruction or self.default_instruction}
+                ],
+            },
+            {"role": "user", "content": content},
+        ]
+
+        if not text and not image and not video:
+            content.append({"type": "text", "text": "NULL"})
+            return conversation
+
+        if video:
+            video_content = None
+            video_kwargs = {"total_pixels": self.total_pixels}
+            if isinstance(video, list):
+                video_content = video
+                if self.num_frames is not None or self.max_frames is not None:
+                    video_content = sample_frames(
+                        video_content, self.num_frames, self.max_frames
+                    )
+                video_content = [
+                    ("file://" + ele if isinstance(ele, str) else ele)
+                    for ele in video_content
+                ]
+            elif isinstance(video, str):
+                video_content = (
+                    video
+                    if video.startswith(("http://", "https://"))
+                    else "file://" + video
+                )
+                video_kwargs = {
+                    "fps": fps or self.fps,
+                    "max_frames": max_frames or self.max_frames,
+                }
+            else:
+                raise TypeError(f"Unrecognized video type: {type(video)}")
+
+            if video_content:
+                content.append({"type": "video", "video": video_content, **video_kwargs})
+
+        if image:
+            image_content = None
+            if isinstance(image, Image.Image):
+                image_content = image
+            elif isinstance(image, str):
+                image_content = (
+                    image if image.startswith(("http", "oss")) else "file://" + image
+                )
+            else:
+                raise TypeError(f"Unrecognized image type: {type(image)}")
+
+            if image_content:
+                content.append(
+                    {
+                        "type": "image",
+                        "image": image_content,
+                        "min_pixels": self.min_pixels,
+                        "max_pixels": self.max_pixels,
+                    }
+                )
+
+        if text:
+            content.append({"type": "text", "text": text})
+
+        return conversation
+
+    def _preprocess_inputs(
+        self, conversations: List[List[Dict]]
+    ) -> Dict[str, torch.Tensor]:
+        text = self.processor.apply_chat_template(
+            conversations, add_generation_prompt=True, tokenize=False
+        )
+
+        try:
+            images, video_inputs, video_kwargs = process_vision_info(
+                conversations,
+                image_patch_size=16,
+                return_video_metadata=True,
+                return_video_kwargs=True,
+            )
+        except Exception as e:
+            logger.error(f"Error in processing vision info: {e}")
+            images = None
+            video_inputs = None
+            video_kwargs = {"do_sample_frames": False}
+            text = self.processor.apply_chat_template(
+                [{"role": "user", "content": [{"type": "text", "text": "NULL"}]}],
+                add_generation_prompt=True,
+                tokenize=False,
+            )
+
+        if video_inputs is not None:
+            videos, video_metadata = zip(*video_inputs)
+            videos = list(videos)
+            video_metadata = list(video_metadata)
+        else:
+            videos, video_metadata = None, None
+
+        inputs = self.processor(
+            text=text,
+            images=images,
+            videos=videos,
+            video_metadata=video_metadata,
+            truncation=True,
+            max_length=self.max_length,
+            padding=True,
+            do_resize=False,
+            return_tensors="pt",
+            **video_kwargs,
+        )
+        return inputs
+
+    @staticmethod
+    def _pooling_last(
+        hidden_state: torch.Tensor, attention_mask: torch.Tensor
+    ) -> torch.Tensor:
+        """Extract the last valid token's hidden state based on attention mask."""
+        flipped_tensor = attention_mask.flip(dims=[1])
+        last_one_positions = flipped_tensor.argmax(dim=1)
+        col = attention_mask.shape[1] - last_one_positions - 1
+        row = torch.arange(hidden_state.shape[0], device=hidden_state.device)
+        return hidden_state[row, col]
+
+    def process(
+        self, inputs: List[Dict[str, Any]], normalize: bool = True
+    ) -> torch.Tensor:
+        """Generate embeddings for a list of inputs.
+
+        Args:
+            inputs: List of dicts with 'text', 'image', and/or 'video' keys
+            normalize: Whether to L2 normalize embeddings (default True)
+
+        Returns:
+            Tensor of shape (batch_size, hidden_dim) with embeddings
+        """
+        conversations = [
+            self.format_model_input(
+                text=ele.get("text"),
+                image=ele.get("image"),
+                video=ele.get("video"),
+                instruction=ele.get("instruction"),
+                fps=ele.get("fps"),
+                max_frames=ele.get("max_frames"),
+            )
+            for ele in inputs
+        ]
+
+        processed_inputs = self._preprocess_inputs(conversations)
+        processed_inputs = {k: v.to(self.model.device) for k, v in processed_inputs.items()}
+
+        outputs = self.forward(processed_inputs)
+        embeddings = self._pooling_last(
+            outputs["last_hidden_state"], outputs["attention_mask"]
+        )
+
+        if normalize:
+            embeddings = F.normalize(embeddings, p=2, dim=-1)
+
+        return embeddings
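
Because process() L2-normalizes by default, retrieval scoring against
this embedder reduces to a matrix product. A minimal sketch (the query
and document strings are illustrative; torch_dtype is forwarded to
from_pretrained via **kwargs):

    import torch
    from scripts.qwen3_vl import Qwen3VLEmbedder

    embedder = Qwen3VLEmbedder(
        model_name_or_path="Qwen/Qwen3-VL-Embedding-8B",
        torch_dtype=torch.bfloat16,  # forwarded to from_pretrained
    )
    queries = embedder.process([{"text": "What is the capital of France?"}])
    docs = embedder.process([
        {"text": "Paris is the capital of France."},
        {"text": "The Nile is the longest river."},
    ])
    # Rows are unit-norm, so queries @ docs.T is a cosine-similarity matrix
    sim = queries @ docs.T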
scripts/qwen3_vl/qwen3_vl_reranker.py ADDED
@@ -0,0 +1,371 @@
+"""Official Qwen3-VL Reranker implementation.
+
+Source: https://github.com/QwenLM/Qwen3-VL-Embedding/blob/main/src/models/qwen3_vl_reranker.py
+License: Apache 2.0
+"""
+
+import torch
+import numpy as np
+import logging
+
+from PIL import Image
+from typing import List, Optional, Union, Dict, Any
+from qwen_vl_utils import process_vision_info
+from transformers import Qwen3VLForConditionalGeneration, AutoProcessor
+
+logger = logging.getLogger(__name__)
+
+MAX_LENGTH = 8192
+IMAGE_BASE_FACTOR = 16
+IMAGE_FACTOR = IMAGE_BASE_FACTOR * 2
+MIN_PIXELS = 4 * IMAGE_FACTOR * IMAGE_FACTOR  # 4 tokens
+MAX_PIXELS = 1280 * IMAGE_FACTOR * IMAGE_FACTOR  # 1280 tokens
+MAX_RATIO = 200
+
+FRAME_FACTOR = 2
+FPS = 1
+MIN_FRAMES = 2
+MAX_FRAMES = 64
+MIN_TOTAL_PIXELS = 1 * FRAME_FACTOR * MIN_PIXELS  # 1 frame
+MAX_TOTAL_PIXELS = 4 * FRAME_FACTOR * MAX_PIXELS  # 4 frames
+
+
+def sample_frames(frames, num_segments, max_segments):
+    duration = len(frames)
+    frame_id_array = np.linspace(0, duration - 1, num_segments, dtype=int)
+    frame_id_list = frame_id_array.tolist()
+    last_frame_id = frame_id_list[-1]
+
+    sampled_frames = []
+    for frame_idx in frame_id_list:
+        try:
+            single_frame_path = frames[frame_idx]
+        except:
+            break
+        sampled_frames.append(single_frame_path)
+    # Pad with last frame if total frames less than num_segments
+    while len(sampled_frames) < num_segments:
+        sampled_frames.append(frames[last_frame_id])
+    return sampled_frames[:max_segments]
+
+
+class Qwen3VLReranker:
+    """Official Qwen3-VL reranker model wrapper.
+
+    Usage:
+        model = Qwen3VLReranker(model_name_or_path="Qwen/Qwen3-VL-Reranker-8B")
+        scores = model.process({
+            "instruction": "Retrieve relevant documents.",
+            "query": {"text": "search query"},
+            "documents": [{"text": "doc1"}, {"text": "doc2"}]
+        })
+    """
+
+    def __init__(
+        self,
+        model_name_or_path: str,
+        max_length: int = MAX_LENGTH,
+        min_pixels: int = MIN_PIXELS,
+        max_pixels: int = MAX_PIXELS,
+        total_pixels: int = MAX_TOTAL_PIXELS,
+        fps: float = FPS,
+        num_frames: int = MAX_FRAMES,
+        max_frames: int = MAX_FRAMES,
+        default_instruction: str = "Given a search query, retrieve relevant candidates that answer the query.",
+        **kwargs,
+    ):
+
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        self.max_length = max_length
+        self.min_pixels = min_pixels
+        self.max_pixels = max_pixels
+        self.total_pixels = total_pixels
+        self.fps = fps
+        self.num_frames = num_frames
+        self.max_frames = max_frames
+
+        self.default_instruction = default_instruction
+
+        lm = Qwen3VLForConditionalGeneration.from_pretrained(
+            model_name_or_path, trust_remote_code=True, **kwargs
+        ).to(self.device)
+
+        self.model = lm.model
+        self.processor = AutoProcessor.from_pretrained(
+            model_name_or_path, trust_remote_code=True, padding_side="left"
+        )
+        self.model.eval()
+
+        token_true_id = self.processor.tokenizer.get_vocab()["yes"]
+        token_false_id = self.processor.tokenizer.get_vocab()["no"]
+        self.score_linear = self.get_binary_linear(lm, token_true_id, token_false_id)
+        self.score_linear.eval()
+        self.score_linear.to(self.device).to(self.model.dtype)
+
+        logger.info(
+            f"Initialized Qwen3VLReranker with yes/no scoring layer (device={self.device})"
+        )
+
+    def get_binary_linear(self, model, token_yes, token_no):
+        """Extract yes/no token weights from LM head and create scoring layer."""
+        lm_head_weights = model.lm_head.weight.data
+
+        weight_yes = lm_head_weights[token_yes]
+        weight_no = lm_head_weights[token_no]
+
+        D = weight_yes.size()[0]
+        linear_layer = torch.nn.Linear(D, 1, bias=False)
+        with torch.no_grad():
+            linear_layer.weight[0] = weight_yes - weight_no
+        return linear_layer
+
+    @torch.no_grad()
+    def compute_scores(self, inputs):
+        """Compute relevance scores using the binary linear layer."""
+        batch_scores = self.model(**inputs).last_hidden_state[:, -1]
+        scores = self.score_linear(batch_scores)
+        scores = torch.sigmoid(scores).squeeze(-1).cpu().detach().tolist()
+        return scores
+
+    def truncate_tokens_optimized(
+        self, tokens: List[str], max_length: int, special_tokens: List[str]
+    ) -> List[str]:
+        if len(tokens) <= max_length:
+            return tokens
+
+        special_tokens_set = set(special_tokens)
+
+        # Calculate budget: how many non-special tokens we can keep
+        num_special = sum(1 for token in tokens if token in special_tokens_set)
+        num_non_special_to_keep = max_length - num_special
+
+        # Build final list according to budget
+        final_tokens = []
+        non_special_kept_count = 0
+        for token in tokens:
+            if token in special_tokens_set:
+                final_tokens.append(token)
+            elif non_special_kept_count < num_non_special_to_keep:
+                final_tokens.append(token)
+                non_special_kept_count += 1
+
+        return final_tokens
+
+    def tokenize(self, pairs: list, **kwargs):
+        max_length = self.max_length
+        text = self.processor.apply_chat_template(
+            pairs, tokenize=False, add_generation_prompt=True
+        )
+        try:
+            images, videos, video_kwargs = process_vision_info(
+                pairs,
+                image_patch_size=16,
+                return_video_kwargs=True,
+                return_video_metadata=True,
+            )
+        except Exception as e:
+            logger.error(f"Error in processing vision info: {e}")
+            images = None
+            videos = None
+            video_kwargs = {"do_sample_frames": False}
+            text = self.processor.apply_chat_template(
+                [{"role": "user", "content": [{"type": "text", "text": "NULL"}]}],
+                add_generation_prompt=True,
+                tokenize=False,
+            )
+
+        if videos is not None:
+            videos, video_metadatas = zip(*videos)
+            videos, video_metadatas = list(videos), list(video_metadatas)
+        else:
+            video_metadatas = None
+        inputs = self.processor(
+            text=text,
+            images=images,
+            videos=videos,
+            video_metadata=video_metadatas,
+            truncation=False,
+            padding=False,
+            do_resize=False,
+            **video_kwargs,
+        )
+        for i, ele in enumerate(inputs["input_ids"]):
+            inputs["input_ids"][i] = (
+                self.truncate_tokens_optimized(
+                    inputs["input_ids"][i][:-5],
+                    max_length,
+                    self.processor.tokenizer.all_special_ids,
+                )
+                + inputs["input_ids"][i][-5:]
+            )
+        temp_inputs = self.processor.tokenizer.pad(
+            {"input_ids": inputs["input_ids"]},
+            padding=True,
+            return_tensors="pt",
+            max_length=self.max_length,
+        )
+        for key in temp_inputs:
+            inputs[key] = temp_inputs[key]
+        return inputs
+
+    def format_mm_content(
+        self,
+        text,
+        image,
+        video,
+        prefix="Query:",
+        fps=None,
+        max_frames=None,
+    ):
+        content = []
+
+        content.append({"type": "text", "text": prefix})
+        if not text and not image and not video:
+            content.append({"type": "text", "text": "NULL"})
+            return content
+
+        if video:
+            video_content = None
+            video_kwargs = {"total_pixels": self.total_pixels}
+            if isinstance(video, list):
+                video_content = video
+                if self.num_frames is not None or self.max_frames is not None:
+                    video_content = sample_frames(
+                        video_content, self.num_frames, self.max_frames
+                    )
+                video_content = [
+                    ("file://" + ele if isinstance(ele, str) else ele)
+                    for ele in video_content
+                ]
+            elif isinstance(video, str):
+                video_content = (
+                    video
+                    if video.startswith(("http://", "https://"))
+                    else "file://" + video
+                )
+                video_kwargs = {
+                    "fps": fps or self.fps,
+                    "max_frames": max_frames or self.max_frames,
+                }
+            else:
+                raise TypeError(f"Unrecognized video type: {type(video)}")
+
+            if video_content:
+                content.append({"type": "video", "video": video_content, **video_kwargs})
+
+        if image:
+            image_content = None
+            if isinstance(image, Image.Image):
+                image_content = image
+            elif isinstance(image, str):
+                image_content = (
+                    image if image.startswith(("http", "oss")) else "file://" + image
+                )
+            else:
+                raise TypeError(f"Unrecognized image type: {type(image)}")
+
+            if image_content:
+                content.append(
+                    {
+                        "type": "image",
+                        "image": image_content,
+                        "min_pixels": self.min_pixels,
+                        "max_pixels": self.max_pixels,
+                    }
+                )
+
+        if text:
+            content.append({"type": "text", "text": text})
+        return content
+
+    def format_mm_instruction(
+        self,
+        query_text,
+        query_image,
+        query_video,
+        doc_text,
+        doc_image,
+        doc_video,
+        instruction=None,
+        fps=None,
+        max_frames=None,
+    ):
+        inputs = []
+        inputs.append(
+            {
+                "role": "system",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": 'Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".',
+                    }
+                ],
+            }
+        )
+        if isinstance(query_text, tuple):
+            instruct, query_text = query_text
+        else:
+            instruct = instruction
+        contents = []
+        contents.append({"type": "text", "text": "<Instruct>: " + instruct})
+        query_content = self.format_mm_content(
+            query_text,
+            query_image,
+            query_video,
+            prefix="<Query>:",
+            fps=fps,
+            max_frames=max_frames,
+        )
+        contents.extend(query_content)
+        doc_content = self.format_mm_content(
+            doc_text,
+            doc_image,
+            doc_video,
+            prefix="\n<Document>:",
+            fps=fps,
+            max_frames=max_frames,
+        )
+        contents.extend(doc_content)
+        inputs.append({"role": "user", "content": contents})
+        return inputs
+
+    def process(self, inputs: Dict[str, Any]) -> List[float]:
+        """Score documents by relevance to query.
+
+        Args:
+            inputs: Dict with 'instruction', 'query', and 'documents' keys.
+                query and documents can have 'text', 'image', 'video' fields.
+
+        Returns:
+            List of relevance scores (0-1) for each document.
+        """
+        instruction = inputs.get("instruction", self.default_instruction)
+
+        query = inputs.get("query", {})
+        documents = inputs.get("documents", [])
+        if not query or not documents:
+            return []
+
+        pairs = [
+            self.format_mm_instruction(
+                query.get("text", None),
+                query.get("image", None),
+                query.get("video", None),
+                document.get("text", None),
+                document.get("image", None),
+                document.get("video", None),
+                instruction=instruction,
+                fps=inputs.get("fps", self.fps),
+                max_frames=inputs.get("max_frames", self.max_frames),
+            )
+            for document in documents
+        ]
+
+        final_scores = []
+        for pair in pairs:
+            tokenized_inputs = self.tokenize([pair])
+            tokenized_inputs = tokenized_inputs.to(self.model.device)
+            scores = self.compute_scores(tokenized_inputs)
+            final_scores.extend(scores)
+        return final_scores
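
And a matching sketch for the reranker (instruction, query, and document
texts are illustrative). Scores come back in document order, so ranking
is just an argsort:

    from scripts.qwen3_vl import Qwen3VLReranker

    reranker = Qwen3VLReranker(model_name_or_path="Qwen/Qwen3-VL-Reranker-8B")
    scores = reranker.process({
        "instruction": "Given a web search query, retrieve relevant passages.",
        "query": {"text": "how do vision-language models pool embeddings?"},
        "documents": [
            {"text": "Last-token pooling takes the final valid hidden state."},
            {"text": "A recipe for sourdough bread."},
        ],
    })
    # Sort document indices by descending relevance score
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)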