Text Ranking
sentence-transformers
Safetensors
Transformers
bert
mteb
custom_code
Eval Results (legacy)
Instructions to use ByteDance/ListConRanker with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- sentence-transformers
How to use ByteDance/ListConRanker with sentence-transformers:
```python
from sentence_transformers import CrossEncoder

model = CrossEncoder("ByteDance/ListConRanker", trust_remote_code=True)

query = "Which planet is known as the Red Planet?"
passages = [
    "Venus is often called Earth's twin because of its similar size and proximity.",
    "Mars, known for its reddish appearance, is often referred to as the Red Planet.",
    "Jupiter, the largest planet in our solar system, has a prominent red spot.",
    "Saturn, famous for its rings, is sometimes mistaken for the Red Planet.",
]

scores = model.predict([(query, passage) for passage in passages])
print(scores)
```
- Transformers
How to use ByteDance/ListConRanker with Transformers:
```python
# Load model directly
from transformers import AutoTokenizer, ListConRanker

tokenizer = AutoTokenizer.from_pretrained("ByteDance/ListConRanker", trust_remote_code=True)
model = ListConRanker.from_pretrained("ByteDance/ListConRanker", trust_remote_code=True)
```
- Notebooks
- Google Colab
- Kaggle
Roman Solomatin committed on
update after review
Browse files- listconranker.py +101 -4
listconranker.py
CHANGED
|
@@ -30,6 +30,9 @@ from transformers import (
|
|
| 30 |
import os
|
| 31 |
from transformers.modeling_outputs import SequenceClassifierOutput
|
| 32 |
from typing import Union, List, Optional
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
|
| 35 |
class ListConRankerConfig(BertConfig):
|
|
@@ -295,14 +298,15 @@ class ListConRankerModel(PreTrainedModel):
|
|
| 295 |
if sep_idxs.numel() == 0:
|
| 296 |
raise ValueError(f"No SEP in sequence {idx}")
|
| 297 |
first_sep = sep_idxs[0].item()
|
|
|
|
| 298 |
|
| 299 |
# Extract query and passage
|
| 300 |
q_seq = seq[: first_sep + 1]
|
| 301 |
q_mask = mask[: first_sep + 1]
|
| 302 |
q_tt = torch.zeros_like(q_seq)
|
| 303 |
|
| 304 |
-
p_seq = seq[first_sep:]
|
| 305 |
-
p_mask = mask[first_sep:]
|
| 306 |
p_seq = p_seq.clone()
|
| 307 |
p_seq[0] = self.config.cls_token_id
|
| 308 |
p_tt = torch.zeros_like(p_seq)
|
|
@@ -315,6 +319,16 @@ class ListConRankerModel(PreTrainedModel):
|
|
| 315 |
].tolist()
|
| 316 |
)
|
| 317 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 318 |
if key not in grouped:
|
| 319 |
grouped[key] = {
|
| 320 |
"query": (q_seq, q_mask, q_tt),
|
|
@@ -396,7 +410,7 @@ class ListConRankerModel(PreTrainedModel):
|
|
| 396 |
):
|
| 397 |
model = super().from_pretrained(model_name_or_path, config=config, **kwargs)
|
| 398 |
model.hf_model = BertModel.from_pretrained(
|
| 399 |
-
model_name_or_path, config=model.config.bert_config
|
| 400 |
)
|
| 401 |
|
| 402 |
linear_path = os.path.join(model_name_or_path, "linear_in_embedding.pt")
|
|
@@ -439,11 +453,94 @@ class ListConRankerModel(PreTrainedModel):
|
|
| 439 |
inputs = tokenizer(
|
| 440 |
batch_pairs,
|
| 441 |
padding=True,
|
| 442 |
-
truncation=
|
| 443 |
return_tensors="pt",
|
| 444 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 445 |
logits = self(**inputs)[0]
|
| 446 |
total_logits[batch * batch_size : (batch + 1) * batch_size] = (
|
| 447 |
logits.squeeze(1)
|
| 448 |
)
|
| 449 |
return total_logits
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
import os
|
| 31 |
from transformers.modeling_outputs import SequenceClassifierOutput
|
| 32 |
from typing import Union, List, Optional
|
| 33 |
+
from collections import defaultdict
|
| 34 |
+
import numpy as np
|
| 35 |
+
import math
|
| 36 |
|
| 37 |
|
| 38 |
class ListConRankerConfig(BertConfig):
|
|
|
|
| 298 |
if sep_idxs.numel() == 0:
|
| 299 |
raise ValueError(f"No SEP in sequence {idx}")
|
| 300 |
first_sep = sep_idxs[0].item()
|
| 301 |
+
second_sep = sep_idxs[1].item()
|
| 302 |
|
| 303 |
# Extract query and passage
|
| 304 |
q_seq = seq[: first_sep + 1]
|
| 305 |
q_mask = mask[: first_sep + 1]
|
| 306 |
q_tt = torch.zeros_like(q_seq)
|
| 307 |
|
| 308 |
+
p_seq = seq[first_sep : second_sep + 1]
|
| 309 |
+
p_mask = mask[first_sep : second_sep + 1]
|
| 310 |
p_seq = p_seq.clone()
|
| 311 |
p_seq[0] = self.config.cls_token_id
|
| 312 |
p_tt = torch.zeros_like(p_seq)
|
|
|
|
| 319 |
].tolist()
|
| 320 |
)
|
| 321 |
|
| 322 |
+
# truncation
|
| 323 |
+
q_seq = q_seq[: self.config.max_position_embeddings]
|
| 324 |
+
q_seq[-1] = self.config.sep_token_id
|
| 325 |
+
p_seq = p_seq[: self.config.max_position_embeddings]
|
| 326 |
+
p_seq[-1] = self.config.sep_token_id
|
| 327 |
+
q_mask = q_mask[: self.config.max_position_embeddings]
|
| 328 |
+
p_mask = p_mask[: self.config.max_position_embeddings]
|
| 329 |
+
q_tt = q_tt[: self.config.max_position_embeddings]
|
| 330 |
+
p_tt = p_tt[: self.config.max_position_embeddings]
|
| 331 |
+
|
| 332 |
if key not in grouped:
|
| 333 |
grouped[key] = {
|
| 334 |
"query": (q_seq, q_mask, q_tt),
|
|
|
|
| 410 |
):
|
| 411 |
model = super().from_pretrained(model_name_or_path, config=config, **kwargs)
|
| 412 |
model.hf_model = BertModel.from_pretrained(
|
| 413 |
+
model_name_or_path, config=model.config.bert_config, **kwargs
|
| 414 |
)
|
| 415 |
|
| 416 |
linear_path = os.path.join(model_name_or_path, "linear_in_embedding.pt")
|
|
|
|
| 453 |
inputs = tokenizer(
|
| 454 |
batch_pairs,
|
| 455 |
padding=True,
|
| 456 |
+
truncation=False,
|
| 457 |
return_tensors="pt",
|
| 458 |
)
|
| 459 |
+
|
| 460 |
+
for k, v in inputs.items():
|
| 461 |
+
inputs[k] = v.to(self.device)
|
| 462 |
+
|
| 463 |
logits = self(**inputs)[0]
|
| 464 |
total_logits[batch * batch_size : (batch + 1) * batch_size] = (
|
| 465 |
logits.squeeze(1)
|
| 466 |
)
|
| 467 |
return total_logits
|
| 468 |
+
|
| 469 |
+
def multi_passage_in_iterative_inference(
    self,
    sentences: List[str],
    stop_num: int = 20,
    decrement_rate: float = 0.2,
    min_filter_num: int = 10,
    tokenizer: Optional["AutoTokenizer"] = None,
):
    """
    Score the passages of one query with iterative (multi-round) inference.

    While more than ``stop_num`` passages remain, all remaining passages are
    scored together via ``self.multi_passage`` and the lowest-scoring ones
    are dropped. Every passage is recorded with the score from the round in
    which it was dropped (or from the final round if it survived), plus the
    number of rounds completed before that one.
    NOTE(review): the round offset presumably makes later survivors outrank
    earlier drops; that only holds if per-round scores span less than 1 —
    confirm against the model's output range.

    :param sentences: ``[query, passage_1, ..., passage_n]``.
    :param stop_num: Stop filtering once at most this many passages remain.
    :param decrement_rate: Fraction (exclusive 0..1) of the remaining
        passages to drop per round, rounded up.
    :param min_filter_num: Lower bound on the number dropped per round
        (capped at the number of passages still remaining).
    :param tokenizer: Tokenizer forwarded to ``self.multi_passage``. When
        ``None`` it is loaded on first use from "ByteDance/ListConRanker",
        not at import time.
    :return: List of float scores, one per passage, in the order the
        passages were given. Empty if ``sentences`` holds only the query.
    :raises ValueError: If a numeric argument is out of range or
        ``sentences`` is empty.
    """
    if stop_num < 1:
        raise ValueError("stop_num must be greater than 0")
    if decrement_rate <= 0 or decrement_rate >= 1:
        raise ValueError("decrement_rate must be in (0, 1)")
    if min_filter_num < 1:
        raise ValueError("min_filter_num must be greater than 0")
    if not sentences:
        raise ValueError("sentences must contain at least the query")

    query = sentences[0]
    passages = list(sentences[1:])
    if not passages:
        return []

    if tokenizer is None:
        # Loaded lazily: a from_pretrained() call in the signature default
        # would run (and hit the hub) as soon as the module is imported.
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained("ByteDance/ListConRanker")

    def score(indices):
        # One listwise forward pass over the query and the given passages.
        batch = [[query] + [passages[i] for i in indices]]
        return self.multi_passage(
            batch, batch_size=len(batch[0]) - 1, tokenizer=tokenizer
        ).tolist()

    # Track passages by original position, not by text, so duplicate
    # passages each keep their own score.
    final_score = [0.0] * len(passages)
    remaining = list(range(len(passages)))
    filter_times = 0

    while len(remaining) > stop_num:
        pred_scores = score(remaining)
        order = np.argsort(pred_scores).tolist()  # increasing score

        total = len(remaining)
        cut = max(math.ceil(total * decrement_rate), min_filter_num)
        cut = min(cut, total)  # never drop more than what is left
        # Passages tied with the last dropped one are dropped too; the
        # bound check prevents running off the end when everything ties.
        while cut < total and pred_scores[order[cut - 1]] == pred_scores[order[cut]]:
            cut += 1

        for pos in order[:cut]:
            final_score[remaining[pos]] = pred_scores[pos] + filter_times
        remaining = [remaining[pos] for pos in order[cut:]]
        filter_times += 1

    # Final round over the survivors; skipped when every passage was dropped.
    if remaining:
        pred_scores = score(remaining)
        for pos, original_idx in enumerate(remaining):
            final_score[original_idx] = pred_scores[pos] + filter_times

    return final_score
|