Commit: 455c786
Parent(s): f3ebc82
Fix embedding/reranker loading with official Qwen3-VL classes
Root cause: AutoModel.from_pretrained() loads the base transformer
instead of the specialized embedding/reranking variants.
Changes:
- Vendor official scripts from QwenLM/Qwen3-VL-Embedding repo
- Replace AutoModel with Qwen3VLEmbedder for embedding model
- Replace AutoModel with Qwen3VLReranker for reranker model
- Update embed()/rerank() methods to use official process() API
The official loaders handle:
- Proper last-token pooling and L2 normalization (embedding)
- Yes/no binary scoring from LM head weights (reranker)
This eliminates the fallback L2 norm heuristic scoring that was
producing "less accurate" results.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- models/real.py +86 -167
- scripts/qwen3_vl/__init__.py +14 -0
- scripts/qwen3_vl/qwen3_vl_embedding.py +393 -0
- scripts/qwen3_vl/qwen3_vl_reranker.py +371 -0
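
The commit message relies on two mechanisms from the official loaders: last-token pooling with L2 normalization for embeddings, and yes/no LM-head scoring for reranking. A minimal sketch of the first mechanism follows (illustrative only, not code from this commit; hidden_states and attention_mask are assumed right-padded model outputs):

    import torch
    import torch.nn.functional as F

    def last_token_embedding(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # hidden_states: (batch, seq_len, hidden_dim); attention_mask: (batch, seq_len) of 0/1
        last_idx = attention_mask.sum(dim=1) - 1  # index of the last attended token (right padding assumed)
        rows = torch.arange(hidden_states.shape[0], device=hidden_states.device)
        pooled = hidden_states[rows, last_idx]    # (batch, hidden_dim)
        # L2 normalize so a dot product between embeddings behaves as cosine similarity
        return F.normalize(pooled, p=2, dim=-1)

The yes/no scoring path appears in full in the vendored reranker further down.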
models/real.py
CHANGED
@@ -2,6 +2,11 @@
 
 This module loads the actual Qwen3-VL models for production use.
 Requires ~90GB VRAM (4xL4 with 96GB total).
+
+Model Loading:
+- Vision: Qwen3VLMoeForConditionalGeneration (standard transformers)
+- Embedding: Qwen3VLEmbedder (official scripts from QwenLM/Qwen3-VL-Embedding)
+- Reranker: Qwen3VLReranker (official scripts from QwenLM/Qwen3-VL-Embedding)
 """
 
 import json
@@ -28,7 +33,7 @@ class RealModelStack:
 
     def load_all(self) -> "RealModelStack":
         """Load all models with device_map='auto' for multi-GPU distribution."""
-        from transformers import
+        from transformers import AutoProcessor
 
         device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
         logger.info(f"Loading models on {device_type}")
@@ -71,34 +76,30 @@ class RealModelStack:
         )
         logger.info(f"Fallback vision model loaded in {time.time() - vision_start:.2f}s")
 
-        # Embedding model (~16GB in BF16)
+        # Embedding model (~16GB in BF16) - Using official Qwen3VLEmbedder
         logger.info(f"Loading embedding model: {settings.embedding_model}")
         embed_start = time.time()
-
-
+        from scripts.qwen3_vl import Qwen3VLEmbedder
+
+        self.models["embedding"] = Qwen3VLEmbedder(
+            model_name_or_path=settings.embedding_model,
             torch_dtype=torch.bfloat16,
-            device_map="auto",
-            trust_remote_code=True,
-        )
-        self.processors["embedding"] = AutoProcessor.from_pretrained(
-            settings.embedding_model,
-            trust_remote_code=True,
         )
+        # Processor is internal to Qwen3VLEmbedder, but store reference for compatibility
+        self.processors["embedding"] = self.models["embedding"].processor
         logger.info(f"Embedding model loaded in {time.time() - embed_start:.2f}s")
 
-        # Reranker model (~16GB in BF16)
+        # Reranker model (~16GB in BF16) - Using official Qwen3VLReranker
         logger.info(f"Loading reranker model: {settings.reranker_model}")
         reranker_start = time.time()
-
-
+        from scripts.qwen3_vl import Qwen3VLReranker
+
+        self.models["reranker"] = Qwen3VLReranker(
+            model_name_or_path=settings.reranker_model,
             torch_dtype=torch.bfloat16,
-            device_map="auto",
-            trust_remote_code=True,
-        )
-        self.processors["reranker"] = AutoProcessor.from_pretrained(
-            settings.reranker_model,
-            trust_remote_code=True,
         )
+        # Processor is internal to Qwen3VLReranker, but store reference for compatibility
+        self.processors["reranker"] = self.models["reranker"].processor
         logger.info(f"Reranker model loaded in {time.time() - reranker_start:.2f}s")
 
         self.loaded = True
@@ -370,80 +371,68 @@ IMPORTANT: Return ONLY valid JSON, no additional text."""
 class RealEmbeddingModel:
     """Wrapper for real embedding model inference.
 
-    Uses
-
+    Uses the official Qwen3VLEmbedder from QwenLM/Qwen3-VL-Embedding.
+    The model handles last-token pooling and L2 normalization internally.
     """
 
     def __init__(self, model, processor):
+        """Initialize with Qwen3VLEmbedder instance.
+
+        Args:
+            model: Qwen3VLEmbedder instance (official loader)
+            processor: Processor (stored for compatibility, but model has its own)
+        """
         self.model = model
         self.processor = processor
 
-        """Extract the last valid token's hidden state based on attention mask.
-
-        flipped_tensor = attention_mask.flip(dims=[1])
-        last_one_positions = flipped_tensor.argmax(dim=1)
-        col = attention_mask.shape[1] - last_one_positions - 1
-        row = torch.arange(hidden_state.shape[0], device=hidden_state.device)
-        return hidden_state[row, col]
-
+    def embed(self, text: str) -> list[float]:
+        """Generate embedding for text using official Qwen3VLEmbedder.
+
+        The official model.process() handles:
+        - Tokenization and preprocessing
+        - Last-token pooling
+        - L2 normalization
+
+        Args:
+            text: Input text to embed
+
+        Returns:
+            List of floats representing the embedding (4096-dim for 8B model)
         """
         try:
-            inputs =
-                return_tensors="pt",
-                padding=True,
-                truncation=True,
-                max_length=512,
-            )
-
-            # Move to model device
-            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
-
-            # Generate embeddings
-            with torch.no_grad():
-                outputs = self.model(**inputs)
-
-            # Use last-token pooling (official Qwen3-VL-Embedding method)
-            # outputs.last_hidden_state shape: (batch, seq_len, hidden_dim)
-            attention_mask = inputs.get("attention_mask")
-            if attention_mask is not None:
-                embeddings = self._pooling_last(outputs.last_hidden_state, attention_mask)
-            else:
-                # Fallback: use last token if no attention mask
-                embeddings = outputs.last_hidden_state[:, -1, :]
-
-            # L2 normalize (per official implementation)
-            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=-1)
+            # Use official process() API - expects list of dicts
+            inputs = [{"text": text}]
+            embeddings = self.model.process(inputs, normalize=True)
 
+            # embeddings is a tensor of shape (1, hidden_dim)
             return embeddings[0].cpu().tolist()
 
         except Exception as e:
             logger.error(f"Embedding generation failed: {e}")
             # Return zero vector as fallback (4096-dim per Qwen3-VL-Embedding-8B)
-            hidden_size = getattr(self.model.config, "hidden_size", 4096)
+            hidden_size = getattr(self.model.model.config, "hidden_size", 4096)
             return [0.0] * hidden_size
 
     def embed_batch(self, texts: list[str]) -> list[list[float]]:
-        """Generate embeddings for a batch of texts.
-
+        """Generate embeddings for a batch of texts.
+
+        Uses official batch processing for efficiency.
+        """
+        try:
+            inputs = [{"text": text} for text in texts]
+            embeddings = self.model.process(inputs, normalize=True)
+            return [emb.cpu().tolist() for emb in embeddings]
+        except Exception as e:
+            logger.error(f"Batch embedding generation failed: {e}")
+            hidden_size = getattr(self.model.model.config, "hidden_size", 4096)
+            return [[0.0] * hidden_size for _ in texts]
 
 
 class RealRerankerModel:
     """Wrapper for real reranker model inference.
 
-    Uses the official Qwen3-VL-
+    Uses the official Qwen3VLReranker from QwenLM/Qwen3-VL-Embedding.
+    The model handles yes/no scoring internally via:
     - Extracts "yes" and "no" token weights from the LM head
     - Creates a binary linear layer: weight = yes_weight - no_weight
     - Scores = sigmoid(linear(last_token_hidden_state))
@@ -452,118 +441,48 @@ class RealRerankerModel:
     """
 
     def __init__(self, model, processor):
-
-        self.processor = processor
-        self.score_linear = None
-        self._initialize_score_linear()
-
-        Per Qwen3-VL-Reranker: the scoring layer uses the difference between
-        "yes" and "no" token embeddings from the language model head.
-        """
-        tokenizer = self.processor.tokenizer if hasattr(self.processor, 'tokenizer') else self.processor
-        vocab = tokenizer.get_vocab()
-
-        # Find yes/no token IDs
-        token_yes_id = vocab.get("yes")
-        token_no_id = vocab.get("no")
-
-        if token_yes_id is None or token_no_id is None:
-            logger.warning("Could not find 'yes'/'no' tokens in vocab, using fallback scoring")
-            return
-
-        # Get LM head weights
-        if not hasattr(self.model, 'lm_head'):
-            logger.warning("Model does not have lm_head, using fallback scoring")
-            return
-
-        lm_head_weights = self.model.lm_head.weight.data
-
-        # Extract yes/no weights
-        weight_yes = lm_head_weights[token_yes_id]
-        weight_no = lm_head_weights[token_no_id]
-
-        # Create binary linear layer: weight = yes - no
-        hidden_size = weight_yes.shape[0]
-        self.score_linear = torch.nn.Linear(hidden_size, 1, bias=False)
-        self.score_linear.weight.data[0] = weight_yes - weight_no
-        self.score_linear = self.score_linear.to(self.model.device)
-        self.score_linear.eval()
+        """Initialize with Qwen3VLReranker instance.
 
+        Args:
+            model: Qwen3VLReranker instance (official loader)
+            processor: Processor (stored for compatibility, but model has its own)
         """
+        self.model = model
+        self.processor = processor
 
-        Returns
-
+    def rerank(self, query: str, documents: list[str]) -> list[float]:
+        """Rerank documents by relevance to query using official Qwen3VLReranker.
+
+        The official model.process() handles:
+        - Proper message formatting
+        - Tokenization
+        - Yes/no scoring with LM head weights
+        - Sigmoid normalization
+
+        Args:
+            query: The search query
+            documents: List of documents to rerank
+
+        Returns:
+            List of relevance scores (0-1) for each document.
+            Higher scores indicate more relevant documents.
+        """
         if not documents:
             return []
 
-        scores = []
-        for doc in documents:
-            try:
-                score = self._score_pair(query, doc)
-                scores.append(score)
-            except Exception as e:
-                logger.warning(f"Reranking failed for document: {e}")
-                scores.append(0.0)
-
-        return scores
-
-    def _score_pair(self, query: str, document: str) -> float:
-        """Score a single query-document pair using official Qwen3-VL-Reranker method."""
-        # Truncate document if too long
-        max_doc_len = 400
-        if len(document) > max_doc_len:
-            document = document[:max_doc_len] + "..."
-
-        # Format as query-document pair
-        pair_text = f"Query: {query}\n\nDocument: {document}"
-
         try:
-            )
-
-            # Move to model device
-            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
-
-            with torch.no_grad():
-                outputs = self.model(**inputs)
-
-            # Use LAST token hidden state (not CLS/first token)
-            # Per official implementation: last_hidden_state[:, -1]
-            last_token_hidden = outputs.last_hidden_state[:, -1, :]
-
-            if self.score_linear is not None:
-                # Official scoring: linear(last_token) -> sigmoid
-                raw_score = self.score_linear(last_token_hidden)
-                score = torch.sigmoid(raw_score).squeeze(-1).item()
-            else:
-                # Fallback: use L2 norm with better scaling
-                # This is less accurate but provides reasonable ordering
-                norm = last_token_hidden.norm(dim=-1).item()
-                score = min(1.0, max(0.0, norm / 50.0))  # Heuristic scaling
-
-            return score
+            # Use official process() API - expects dict with query and documents
+            inputs = {
+                "instruction": "Retrieve relevant documents for the query.",
+                "query": {"text": query},
+                "documents": [{"text": doc} for doc in documents],
+            }
+            scores = self.model.process(inputs)
+            return scores
 
         except Exception as e:
-            logger.error(f"
-            return 0.0
+            logger.error(f"Reranking failed: {e}")
+            return [0.0] * len(documents)
 
     def rerank_with_indices(
         self, query: str, documents: list[str], top_k: int = None
scripts/qwen3_vl/__init__.py
ADDED
"""Vendored Qwen3-VL embedding and reranker implementations.

Source: https://github.com/QwenLM/Qwen3-VL-Embedding
License: Apache 2.0

These are the official loading classes for:
- Qwen/Qwen3-VL-Embedding-8B
- Qwen/Qwen3-VL-Reranker-8B
"""

from scripts.qwen3_vl.qwen3_vl_embedding import Qwen3VLEmbedder, Qwen3VLForEmbedding
from scripts.qwen3_vl.qwen3_vl_reranker import Qwen3VLReranker

__all__ = ["Qwen3VLEmbedder", "Qwen3VLForEmbedding", "Qwen3VLReranker"]
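
Note: the vendored modules import qwen_vl_utils and the transformers qwen3_vl model classes, so they assume an environment where both are available. Callers elsewhere in the repo simply import from the package (mirroring the imports in load_all above):

    from scripts.qwen3_vl import Qwen3VLEmbedder, Qwen3VLReranker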
scripts/qwen3_vl/qwen3_vl_embedding.py
ADDED
"""Official Qwen3-VL Embedding implementation.

Source: https://github.com/QwenLM/Qwen3-VL-Embedding/blob/main/src/models/qwen3_vl_embedding.py
License: Apache 2.0
"""

import torch
import torch.nn.functional as F
import unicodedata
import numpy as np
import logging

from PIL import Image
from dataclasses import dataclass
from typing import Optional, List, Union, Dict, Any
from transformers.models.qwen3_vl.modeling_qwen3_vl import (
    Qwen3VLPreTrainedModel,
    Qwen3VLModel,
    Qwen3VLConfig,
)
from transformers.models.qwen3_vl.processing_qwen3_vl import Qwen3VLProcessor
from transformers.modeling_outputs import ModelOutput
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs
from transformers.cache_utils import Cache
from qwen_vl_utils.vision_process import process_vision_info

logger = logging.getLogger(__name__)

# Constants for configuration
MAX_LENGTH = 8192
IMAGE_BASE_FACTOR = 16
IMAGE_FACTOR = IMAGE_BASE_FACTOR * 2
MIN_PIXELS = 4 * IMAGE_FACTOR * IMAGE_FACTOR
MAX_PIXELS = 1800 * IMAGE_FACTOR * IMAGE_FACTOR
FPS = 1
MAX_FRAMES = 64
FRAME_MAX_PIXELS = 768 * IMAGE_FACTOR * IMAGE_FACTOR
MAX_TOTAL_PIXELS = 10 * FRAME_MAX_PIXELS
PAD_TOKEN = "<|endoftext|>"


@dataclass
class Qwen3VLForEmbeddingOutput(ModelOutput):
    last_hidden_state: Optional[torch.FloatTensor] = None
    attention_mask: Optional[torch.Tensor] = None


class Qwen3VLForEmbedding(Qwen3VLPreTrainedModel):
    _checkpoint_conversion_mapping = {}
    accepts_loss_kwargs = False
    config: Qwen3VLConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen3VLModel(config)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)

    def set_decoder(self, decoder):
        self.model.set_decoder(decoder)

    def get_decoder(self):
        return self.model.get_decoder()

    def get_video_features(
        self,
        pixel_values_videos: torch.FloatTensor,
        video_grid_thw: Optional[torch.LongTensor] = None,
    ):
        return self.model.get_video_features(pixel_values_videos, video_grid_thw)

    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        image_grid_thw: Optional[torch.LongTensor] = None,
    ):
        return self.model.get_image_features(pixel_values, image_grid_thw)

    @property
    def language_model(self):
        return self.model.language_model

    @property
    def visual(self):
        return self.model.visual

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, Qwen3VLForEmbeddingOutput]:
        outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            **kwargs,
        )
        return Qwen3VLForEmbeddingOutput(
            last_hidden_state=outputs.last_hidden_state,
            attention_mask=attention_mask,
        )


def sample_frames(
    frames: List[Union[str, Image.Image]], num_segments: int, max_segments: int
) -> List[str]:
    duration = len(frames)
    frame_id_array = np.linspace(0, duration - 1, num_segments, dtype=int)
    frame_id_list = frame_id_array.tolist()
    last_frame_id = frame_id_list[-1]

    sampled_frames = []
    for frame_idx in frame_id_list:
        try:
            sampled_frames.append(frames[frame_idx])
        except:
            break
    while len(sampled_frames) < num_segments:
        sampled_frames.append(frames[last_frame_id])
    return sampled_frames[:max_segments]


class Qwen3VLEmbedder:
    """Official Qwen3-VL embedding model wrapper.

    Usage:
        model = Qwen3VLEmbedder(model_name_or_path="Qwen/Qwen3-VL-Embedding-8B")
        embeddings = model.process([{"text": "Hello world"}])
    """

    def __init__(
        self,
        model_name_or_path: str,
        max_length: int = MAX_LENGTH,
        min_pixels: int = MIN_PIXELS,
        max_pixels: int = MAX_PIXELS,
        total_pixels: int = MAX_TOTAL_PIXELS,
        fps: float = FPS,
        num_frames: int = MAX_FRAMES,
        max_frames: int = MAX_FRAMES,
        default_instruction: str = "Represent the user's input.",
        **kwargs,
    ):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.max_length = max_length
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        self.total_pixels = total_pixels
        self.fps = fps
        self.num_frames = num_frames
        self.max_frames = max_frames
        self.default_instruction = default_instruction

        self.model = Qwen3VLForEmbedding.from_pretrained(
            model_name_or_path, trust_remote_code=True, **kwargs
        ).to(device)
        self.processor = Qwen3VLProcessor.from_pretrained(
            model_name_or_path, padding_side="right"
        )
        self.model.eval()

    @property
    def device(self):
        return self.model.device

    @torch.no_grad()
    def forward(self, inputs: Dict[str, Any]) -> Dict[str, torch.Tensor]:
        outputs = self.model(**inputs)
        return {
            "last_hidden_state": outputs.last_hidden_state,
            "attention_mask": inputs.get("attention_mask"),
        }

    def _truncate_tokens(self, token_ids: List[int], max_length: int) -> List[int]:
        if len(token_ids) <= max_length:
            return token_ids

        special_token_ids = set(self.processor.tokenizer.all_special_ids)
        num_special = sum(1 for token_idx in token_ids if token_idx in special_token_ids)
        num_non_special_to_keep = max_length - num_special

        final_token_ids = []
        non_special_kept_count = 0
        for token_idx in token_ids:
            if token_idx in special_token_ids:
                final_token_ids.append(token_idx)
            elif non_special_kept_count < num_non_special_to_keep:
                final_token_ids.append(token_idx)
                non_special_kept_count += 1
        return final_token_ids

    def format_model_input(
        self,
        text: Optional[str] = None,
        image: Optional[Union[str, Image.Image]] = None,
        video: Optional[Union[str, List[Union[str, Image.Image]]]] = None,
        instruction: Optional[str] = None,
        fps: Optional[float] = None,
        max_frames: Optional[int] = None,
    ) -> List[Dict]:

        if instruction:
            instruction = instruction.strip()
            if instruction and not unicodedata.category(instruction[-1]).startswith("P"):
                instruction = instruction + "."

        content = []
        conversation = [
            {
                "role": "system",
                "content": [
                    {"type": "text", "text": instruction or self.default_instruction}
                ],
            },
            {"role": "user", "content": content},
        ]

        if not text and not image and not video:
            content.append({"type": "text", "text": "NULL"})
            return conversation

        if video:
            video_content = None
            video_kwargs = {"total_pixels": self.total_pixels}
            if isinstance(video, list):
                video_content = video
                if self.num_frames is not None or self.max_frames is not None:
                    video_content = sample_frames(
                        video_content, self.num_frames, self.max_frames
                    )
                video_content = [
                    ("file://" + ele if isinstance(ele, str) else ele)
                    for ele in video_content
                ]
            elif isinstance(video, str):
                video_content = (
                    video
                    if video.startswith(("http://", "https://"))
                    else "file://" + video
                )
                video_kwargs = {
                    "fps": fps or self.fps,
                    "max_frames": max_frames or self.max_frames,
                }
            else:
                raise TypeError(f"Unrecognized video type: {type(video)}")

            if video_content:
                content.append({"type": "video", "video": video_content, **video_kwargs})

        if image:
            image_content = None
            if isinstance(image, Image.Image):
                image_content = image
            elif isinstance(image, str):
                image_content = (
                    image if image.startswith(("http", "oss")) else "file://" + image
                )
            else:
                raise TypeError(f"Unrecognized image type: {type(image)}")

            if image_content:
                content.append(
                    {
                        "type": "image",
                        "image": image_content,
                        "min_pixels": self.min_pixels,
                        "max_pixels": self.max_pixels,
                    }
                )

        if text:
            content.append({"type": "text", "text": text})

        return conversation

    def _preprocess_inputs(
        self, conversations: List[List[Dict]]
    ) -> Dict[str, torch.Tensor]:
        text = self.processor.apply_chat_template(
            conversations, add_generation_prompt=True, tokenize=False
        )

        try:
            images, video_inputs, video_kwargs = process_vision_info(
                conversations,
                image_patch_size=16,
                return_video_metadata=True,
                return_video_kwargs=True,
            )
        except Exception as e:
            logger.error(f"Error in processing vision info: {e}")
            images = None
            video_inputs = None
            video_kwargs = {"do_sample_frames": False}
            text = self.processor.apply_chat_template(
                [{"role": "user", "content": [{"type": "text", "text": "NULL"}]}],
                add_generation_prompt=True,
                tokenize=False,
            )

        if video_inputs is not None:
            videos, video_metadata = zip(*video_inputs)
            videos = list(videos)
            video_metadata = list(video_metadata)
        else:
            videos, video_metadata = None, None

        inputs = self.processor(
            text=text,
            images=images,
            videos=videos,
            video_metadata=video_metadata,
            truncation=True,
            max_length=self.max_length,
            padding=True,
            do_resize=False,
            return_tensors="pt",
            **video_kwargs,
        )
        return inputs

    @staticmethod
    def _pooling_last(
        hidden_state: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Extract the last valid token's hidden state based on attention mask."""
        flipped_tensor = attention_mask.flip(dims=[1])
        last_one_positions = flipped_tensor.argmax(dim=1)
        col = attention_mask.shape[1] - last_one_positions - 1
        row = torch.arange(hidden_state.shape[0], device=hidden_state.device)
        return hidden_state[row, col]

    def process(
        self, inputs: List[Dict[str, Any]], normalize: bool = True
    ) -> torch.Tensor:
        """Generate embeddings for a list of inputs.

        Args:
            inputs: List of dicts with 'text', 'image', and/or 'video' keys
            normalize: Whether to L2 normalize embeddings (default True)

        Returns:
            Tensor of shape (batch_size, hidden_dim) with embeddings
        """
        conversations = [
            self.format_model_input(
                text=ele.get("text"),
                image=ele.get("image"),
                video=ele.get("video"),
                instruction=ele.get("instruction"),
                fps=ele.get("fps"),
                max_frames=ele.get("max_frames"),
            )
            for ele in inputs
        ]

        processed_inputs = self._preprocess_inputs(conversations)
        processed_inputs = {k: v.to(self.model.device) for k, v in processed_inputs.items()}

        outputs = self.forward(processed_inputs)
        embeddings = self._pooling_last(
            outputs["last_hidden_state"], outputs["attention_mask"]
        )

        if normalize:
            embeddings = F.normalize(embeddings, p=2, dim=-1)

        return embeddings
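
For reference, a hedged usage sketch of the vendored embedder (not part of the commit; the image path is a placeholder, and passing torch_dtype mirrors the kwargs forwarding shown in load_all above):

    import torch
    from scripts.qwen3_vl import Qwen3VLEmbedder

    embedder = Qwen3VLEmbedder(
        model_name_or_path="Qwen/Qwen3-VL-Embedding-8B",
        torch_dtype=torch.bfloat16,
    )
    embs = embedder.process(
        [
            {"text": "a red bicycle leaning against a brick wall"},
            {"image": "/path/to/photo.jpg"},  # placeholder path
        ],
        normalize=True,
    )
    # Embeddings are already L2-normalized, so a dot product is cosine similarity.
    similarity = (embs[0] @ embs[1]).item()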
scripts/qwen3_vl/qwen3_vl_reranker.py
ADDED
"""Official Qwen3-VL Reranker implementation.

Source: https://github.com/QwenLM/Qwen3-VL-Embedding/blob/main/src/models/qwen3_vl_reranker.py
License: Apache 2.0
"""

import torch
import numpy as np
import logging

from PIL import Image
from typing import List, Optional, Union, Dict, Any
from qwen_vl_utils import process_vision_info
from transformers import Qwen3VLForConditionalGeneration, AutoProcessor

logger = logging.getLogger(__name__)

MAX_LENGTH = 8192
IMAGE_BASE_FACTOR = 16
IMAGE_FACTOR = IMAGE_BASE_FACTOR * 2
MIN_PIXELS = 4 * IMAGE_FACTOR * IMAGE_FACTOR  # 4 tokens
MAX_PIXELS = 1280 * IMAGE_FACTOR * IMAGE_FACTOR  # 1280 tokens
MAX_RATIO = 200

FRAME_FACTOR = 2
FPS = 1
MIN_FRAMES = 2
MAX_FRAMES = 64
MIN_TOTAL_PIXELS = 1 * FRAME_FACTOR * MIN_PIXELS  # 1 frame
MAX_TOTAL_PIXELS = 4 * FRAME_FACTOR * MAX_PIXELS  # 4 frames


def sample_frames(frames, num_segments, max_segments):
    duration = len(frames)
    frame_id_array = np.linspace(0, duration - 1, num_segments, dtype=int)
    frame_id_list = frame_id_array.tolist()
    last_frame_id = frame_id_list[-1]

    sampled_frames = []
    for frame_idx in frame_id_list:
        try:
            single_frame_path = frames[frame_idx]
        except:
            break
        sampled_frames.append(single_frame_path)
    # Pad with last frame if total frames less than num_segments
    while len(sampled_frames) < num_segments:
        sampled_frames.append(frames[last_frame_id])
    return sampled_frames[:max_segments]


class Qwen3VLReranker:
    """Official Qwen3-VL reranker model wrapper.

    Usage:
        model = Qwen3VLReranker(model_name_or_path="Qwen/Qwen3-VL-Reranker-8B")
        scores = model.process({
            "instruction": "Retrieve relevant documents.",
            "query": {"text": "search query"},
            "documents": [{"text": "doc1"}, {"text": "doc2"}]
        })
    """

    def __init__(
        self,
        model_name_or_path: str,
        max_length: int = MAX_LENGTH,
        min_pixels: int = MIN_PIXELS,
        max_pixels: int = MAX_PIXELS,
        total_pixels: int = MAX_TOTAL_PIXELS,
        fps: float = FPS,
        num_frames: int = MAX_FRAMES,
        max_frames: int = MAX_FRAMES,
        default_instruction: str = "Given a search query, retrieve relevant candidates that answer the query.",
        **kwargs,
    ):

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.max_length = max_length
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        self.total_pixels = total_pixels
        self.fps = fps
        self.num_frames = num_frames
        self.max_frames = max_frames

        self.default_instruction = default_instruction

        lm = Qwen3VLForConditionalGeneration.from_pretrained(
            model_name_or_path, trust_remote_code=True, **kwargs
        ).to(self.device)

        self.model = lm.model
        self.processor = AutoProcessor.from_pretrained(
            model_name_or_path, trust_remote_code=True, padding_side="left"
        )
        self.model.eval()

        token_true_id = self.processor.tokenizer.get_vocab()["yes"]
        token_false_id = self.processor.tokenizer.get_vocab()["no"]
        self.score_linear = self.get_binary_linear(lm, token_true_id, token_false_id)
        self.score_linear.eval()
        self.score_linear.to(self.device).to(self.model.dtype)

        logger.info(
            f"Initialized Qwen3VLReranker with yes/no scoring layer (device={self.device})"
        )

    def get_binary_linear(self, model, token_yes, token_no):
        """Extract yes/no token weights from LM head and create scoring layer."""
        lm_head_weights = model.lm_head.weight.data

        weight_yes = lm_head_weights[token_yes]
        weight_no = lm_head_weights[token_no]

        D = weight_yes.size()[0]
        linear_layer = torch.nn.Linear(D, 1, bias=False)
        with torch.no_grad():
            linear_layer.weight[0] = weight_yes - weight_no
        return linear_layer

    @torch.no_grad()
    def compute_scores(self, inputs):
        """Compute relevance scores using the binary linear layer."""
        batch_scores = self.model(**inputs).last_hidden_state[:, -1]
        scores = self.score_linear(batch_scores)
        scores = torch.sigmoid(scores).squeeze(-1).cpu().detach().tolist()
        return scores

    def truncate_tokens_optimized(
        self, tokens: List[str], max_length: int, special_tokens: List[str]
    ) -> List[str]:
        if len(tokens) <= max_length:
            return tokens

        special_tokens_set = set(special_tokens)

        # Calculate budget: how many non-special tokens we can keep
        num_special = sum(1 for token in tokens if token in special_tokens_set)
        num_non_special_to_keep = max_length - num_special

        # Build final list according to budget
        final_tokens = []
        non_special_kept_count = 0
        for token in tokens:
            if token in special_tokens_set:
                final_tokens.append(token)
            elif non_special_kept_count < num_non_special_to_keep:
                final_tokens.append(token)
                non_special_kept_count += 1

        return final_tokens

    def tokenize(self, pairs: list, **kwargs):
        max_length = self.max_length
        text = self.processor.apply_chat_template(
            pairs, tokenize=False, add_generation_prompt=True
        )
        try:
            images, videos, video_kwargs = process_vision_info(
                pairs,
                image_patch_size=16,
                return_video_kwargs=True,
                return_video_metadata=True,
            )
        except Exception as e:
            logger.error(f"Error in processing vision info: {e}")
            images = None
            videos = None
            video_kwargs = {"do_sample_frames": False}
            text = self.processor.apply_chat_template(
                [{"role": "user", "content": [{"type": "text", "text": "NULL"}]}],
                add_generation_prompt=True,
                tokenize=False,
            )

        if videos is not None:
            videos, video_metadatas = zip(*videos)
            videos, video_metadatas = list(videos), list(video_metadatas)
        else:
            video_metadatas = None
        inputs = self.processor(
            text=text,
            images=images,
            videos=videos,
            video_metadata=video_metadatas,
            truncation=False,
            padding=False,
            do_resize=False,
            **video_kwargs,
        )
        for i, ele in enumerate(inputs["input_ids"]):
            inputs["input_ids"][i] = (
                self.truncate_tokens_optimized(
                    inputs["input_ids"][i][:-5],
                    max_length,
                    self.processor.tokenizer.all_special_ids,
                )
                + inputs["input_ids"][i][-5:]
            )
        temp_inputs = self.processor.tokenizer.pad(
            {"input_ids": inputs["input_ids"]},
            padding=True,
            return_tensors="pt",
            max_length=self.max_length,
        )
        for key in temp_inputs:
            inputs[key] = temp_inputs[key]
        return inputs

    def format_mm_content(
        self,
        text,
        image,
        video,
        prefix="Query:",
        fps=None,
        max_frames=None,
    ):
        content = []

        content.append({"type": "text", "text": prefix})
        if not text and not image and not video:
            content.append({"type": "text", "text": "NULL"})
            return content

        if video:
            video_content = None
            video_kwargs = {"total_pixels": self.total_pixels}
            if isinstance(video, list):
                video_content = video
                if self.num_frames is not None or self.max_frames is not None:
                    video_content = sample_frames(
                        video_content, self.num_frames, self.max_frames
                    )
                video_content = [
                    ("file://" + ele if isinstance(ele, str) else ele)
                    for ele in video_content
                ]
            elif isinstance(video, str):
                video_content = (
                    video
                    if video.startswith(("http://", "https://"))
                    else "file://" + video
                )
                video_kwargs = {
                    "fps": fps or self.fps,
                    "max_frames": max_frames or self.max_frames,
                }
            else:
                raise TypeError(f"Unrecognized video type: {type(video)}")

            if video_content:
                content.append({"type": "video", "video": video_content, **video_kwargs})

        if image:
            image_content = None
            if isinstance(image, Image.Image):
                image_content = image
            elif isinstance(image, str):
                image_content = (
                    image if image.startswith(("http", "oss")) else "file://" + image
                )
            else:
                raise TypeError(f"Unrecognized image type: {type(image)}")

            if image_content:
                content.append(
                    {
                        "type": "image",
                        "image": image_content,
                        "min_pixels": self.min_pixels,
                        "max_pixels": self.max_pixels,
                    }
                )

        if text:
            content.append({"type": "text", "text": text})
        return content

    def format_mm_instruction(
        self,
        query_text,
        query_image,
        query_video,
        doc_text,
        doc_image,
        doc_video,
        instruction=None,
        fps=None,
        max_frames=None,
    ):
        inputs = []
        inputs.append(
            {
                "role": "system",
                "content": [
                    {
                        "type": "text",
                        "text": 'Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".',
                    }
                ],
            }
        )
        if isinstance(query_text, tuple):
            instruct, query_text = query_text
        else:
            instruct = instruction
        contents = []
        contents.append({"type": "text", "text": "<Instruct>: " + instruct})
        query_content = self.format_mm_content(
            query_text,
            query_image,
            query_video,
            prefix="<Query>:",
            fps=fps,
            max_frames=max_frames,
        )
        contents.extend(query_content)
        doc_content = self.format_mm_content(
            doc_text,
            doc_image,
            doc_video,
            prefix="\n<Document>:",
            fps=fps,
            max_frames=max_frames,
        )
        contents.extend(doc_content)
        inputs.append({"role": "user", "content": contents})
        return inputs

    def process(self, inputs: Dict[str, Any]) -> List[float]:
        """Score documents by relevance to query.

        Args:
            inputs: Dict with 'instruction', 'query', and 'documents' keys.
                query and documents can have 'text', 'image', 'video' fields.

        Returns:
            List of relevance scores (0-1) for each document.
        """
        instruction = inputs.get("instruction", self.default_instruction)

        query = inputs.get("query", {})
        documents = inputs.get("documents", [])
        if not query or not documents:
            return []

        pairs = [
            self.format_mm_instruction(
                query.get("text", None),
                query.get("image", None),
                query.get("video", None),
                document.get("text", None),
                document.get("image", None),
                document.get("video", None),
                instruction=instruction,
                fps=inputs.get("fps", self.fps),
                max_frames=inputs.get("max_frames", self.max_frames),
            )
            for document in documents
        ]

        final_scores = []
        for pair in pairs:
            tokenized_inputs = self.tokenize([pair])
            tokenized_inputs = tokenized_inputs.to(self.model.device)
            scores = self.compute_scores(tokenized_inputs)
            final_scores.extend(scores)
        return final_scores
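
And a matching usage sketch for the vendored reranker (illustrative only; the query and document strings are made up, and passing torch_dtype mirrors the kwargs forwarding in load_all):

    import torch
    from scripts.qwen3_vl import Qwen3VLReranker

    reranker = Qwen3VLReranker(
        model_name_or_path="Qwen/Qwen3-VL-Reranker-8B",
        torch_dtype=torch.bfloat16,
    )
    scores = reranker.process(
        {
            "instruction": "Retrieve relevant documents for the query.",
            "query": {"text": "how does last-token pooling work?"},
            "documents": [
                {"text": "Last-token pooling takes the hidden state of the final attended token."},
                {"text": "The 4xL4 node has 96GB of VRAM in total."},
            ],
        }
    )
    # Pair each document index with its sigmoid score and sort, highest first.
    ranked = sorted(enumerate(scores), key=lambda pair: pair[1], reverse=True)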