DylanJHJ's picture
download
raw
2.34 kB
# NOTE: rank_start is not used.
# NOTE: Testing only the relevant anchor
import math
import copy
from tqdm import tqdm
from typing import Optional, Tuple, List, Dict, Union, Any
from ..utils import Result, batch_iterator
from .base import RerankStrategy
class RefRerank(RerankStrategy):
    """Rerank hits by comparing every hit against one reference (anchor) hit.

    For each query, every non-anchor hit is paired with the anchor in both
    orders, ``(anchor, j)`` and ``(i, anchor)``. The LLM scores each pair and
    the scores are folded into one number per hit. The anchor itself is never
    credited or debited, so it keeps the baseline score of 0.
    """

    def run(
        self,
        init_results: List[Result],
        rank_start: int = 0,
        rank_end: Optional[int] = None,
        batch_size: Optional[int] = 32,
        anchor_index: Optional[int] = 0,
        **kwargs
    ) -> List[Result]:
        """Score each hit against the anchor hit and parse the reranked results.

        Args:
            init_results: Initial results to rerank. They are deep-copied, so
                the caller's objects are not mutated by the truncation below.
            rank_start: Accepted for interface compatibility but not used;
                the prompt builder is always called with ``rank_start=0``.
            rank_end: Exclusive upper bound on the hits kept per query
                (``None`` keeps all of them).
            batch_size: Number of pair prompts sent to the LLM per call.
            anchor_index: Position, within the truncated hit list, of the hit
                every other hit is compared against.
            **kwargs: Ignored.

        Returns:
            Whatever ``self._result_parser.parse`` produces from the per-query
            score lists and ``init_results``.

        Raises:
            IndexError: If a query has hits but ``anchor_index`` does not
                address one of them.
            ValueError: From ``math.log`` when ``score_aggregation`` is
                ``'symsumlog'`` and a pair score is not positive.
        """
        results = [copy.deepcopy(result) for result in init_results]
        all_scores = {}

        for result in results:
            ## Placeholder for scores
            result.hits = list(result.hits[:rank_end])
            num_hits = len(result.hits)
            query_scores = [0] * num_hits
            all_scores[result.qid] = query_scores

            # Fail fast: an anchor outside the hit list would otherwise only
            # surface as an IndexError during aggregation, after every LLM
            # batch for this query has already been generated.
            if num_hits and not -num_hits <= anchor_index < num_hits:
                raise IndexError(
                    f"anchor_index {anchor_index} is out of range for "
                    f"query {result.qid!r} with {num_hits} hits"
                )

            ## Create prompts for enumerating pairs
            # Anchor-first pairs come before anchor-second pairs; the
            # aggregation below relies on scores arriving in this same order.
            idx_pairs = [
                (anchor_index, j) for j in range(num_hits) if j != anchor_index
            ]
            idx_pairs += [
                (i, anchor_index) for i in range(num_hits) if i != anchor_index
            ]
            prompts = self._prompt_builder.create_prompt(
                result,
                rank_start=0, rank_end=rank_end,
                idx_pairs=idx_pairs
            )

            ## Iterate over pairs
            scores = []
            for batch_prompts in tqdm(
                batch_iterator(prompts, batch_size),
                desc=f"Batch processing with {batch_size} pairs",
            ):
                batch_scores = self._llm.generate(
                    prompts=batch_prompts,
                    prob=self._rerank_mode.use_logits,
                )
                scores.extend(batch_scores)

            ## Pairwise score aggregation
            # NOTE(review): presumably each score expresses preference for the
            # first document of the pair -- confirm against the LLM wrapper.
            use_log = self.config.score_aggregation == 'symsumlog'
            for (i, j), score in zip(idx_pairs, scores):
                if use_log:
                    score = math.log(score)
                if j == anchor_index:
                    # Hit i was listed first, against the anchor: credit it.
                    query_scores[i] += score
                if i == anchor_index:
                    # The anchor was listed first, against hit j: debit j.
                    query_scores[j] -= score

        # All scores updates
        reranked_results = self._result_parser.parse(
            [all_scores[result.qid] for result in results],
            init_results
        )
        return reranked_results

    def run_pass(self, **kwargs):
        """Unsupported for this strategy; always raises ``NotImplementedError``."""
        raise NotImplementedError("RefRank does not support `run_pass`. Use run instead.")

Xet Storage Details

Size:
2.34 kB
·
Xet hash:
77c569fb5140d6cc78ecca6ce292399d7a266702d1fffb0165462cc436e42c19

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.