DylanJHJ's picture
download
raw
3.07 kB
import math
import copy
from tqdm import tqdm
from typing import Optional, Tuple, List, Dict, Union, Any
from .base import RerankStrategy
from ..utils import Result
class April(RerankStrategy):
    """APRIL rerank strategy.

    Queries are processed one at a time so that prompts sharing the same
    body can be generated in the same batch.
    """

    def run(
        self,
        init_results: List[Result],
        rank_start: int,
        rank_end: int,
        num_runs: Optional[int] = 1,
        **kwargs
    ) -> List[Result]:
        """Rerank every result with a sliding window of pairwise comparisons.

        For each run and each query, a window is moved from ``rank_end``
        towards ``rank_start`` in steps of ``self._step_size``. Inside a
        window, one score per candidate is accumulated from the LLM outputs
        and handed to ``self._result_parser`` to reorder that window.

        Args:
            init_results: Initial results; they are deep-copied, so the
                caller's objects are not modified.
            rank_start: Lower bound of the rank range to rerank.
            rank_end: Upper bound of the rank range to rerank.
            num_runs: Number of passes over all queries. ``None`` is treated
                as a single run.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The reranked copies, with each hit's ``score`` set to the
            reciprocal of its final rank and ``rank`` set to that rank.
        """
        # NOTE(review): the loop nesting below was reconstructed from the
        # previously disabled implementation, whose indentation was lost.
        # Confirm that `rank_start` is meant to advance once per run.
        if num_runs is None:
            num_runs = 1

        rerank_results = [copy.deepcopy(result) for result in init_results]

        for i_run in range(num_runs):
            for index, result in enumerate(rerank_results):
                # w_end = [rank_end, rank_end - step_size, ...]
                for w_end in tqdm(
                    range(rank_end, rank_start, -self._step_size),
                    desc=f"Setwise APRIL (the {i_run+1} run) query:{index}",
                ):
                    w_start = max(rank_start, w_end - self._window_size)
                    w_size = w_end - w_start
                    # Every ordered pair of distinct positions in the window.
                    idx_pairs = [
                        (i, j)
                        for i in range(w_size)
                        for j in range(w_size)
                        if i != j
                    ]

                    # assumes create_prompt yields one prompt per ordered
                    # pair, in the same order as idx_pairs -- TODO confirm
                    prompts = self._prompt_builder.create_prompt(
                        result=result, rank_start=w_start, rank_end=w_end
                    )
                    outputs = self._llm.generate(
                        prompts=prompts, prob=self._rerank_mode.use_logits
                    )

                    # presumably `output` is the preference for candidate i
                    # over candidate j, in [0, 1] -- verify against the LLM
                    # wrapper
                    scores = [0 for _ in range(w_size)]
                    for (i, j), output in zip(idx_pairs, outputs):
                        scores[i] += output
                        scores[j] += 1 - output

                    result = self._result_parser.parse(
                        outputs=[scores],
                        results=[result],
                        rank_start=w_start,
                        rank_end=w_end,
                    )[0]

                    # The window has reached rank_start; any later window
                    # would be clamped to the same start and not be full.
                    if w_start == rank_start:
                        break

                # update the rerank result
                rerank_results[index] = result

            # update the rank_start for the next run
            rank_start = rank_start + self._step_size

        # Assign reciprocal rank
        for result in rerank_results:
            for rank, hit in enumerate(result.hits, start=1):
                hit['score'] = float(1 / rank)
                hit['rank'] = rank

        return rerank_results

    def run_pass(self, *args: Any, **kwargs: Any):
        """Always raise: APRIL has no single-pass mode.

        Positional arguments are accepted as well so that any call reaches
        the explanatory error instead of failing on the signature.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("APRIL does not support `run_pass`. Use run instead.")

Xet Storage Details

Size:
3.07 kB
·
Xet hash:
b0d277e621d5dc43015ff660ceae11425c90d428849cc6c19c54887303d725f2

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.