DylanJHJ's picture
download
raw
2.02 kB
# NOTE: consider renaming this to "swap top-k"?
import math
import copy
from tqdm import tqdm
import numpy as np
from typing import Optional, Tuple, List, Dict, Union, Any
from ..utils import Result
from .base import RerankStrategy
class SetBubbleTopK(RerankStrategy):
    """Setwise bubble reranking over a sliding window.

    Each run walks a window of ``self._window_size`` candidates from the
    bottom of the ranked list (``rank_end``) towards the top (``rank_start``)
    in steps of ``self._step_size``, reranking the window contents at every
    stop via ``run_pass``.

    NOTE(review): ``_window_size``, ``_step_size``, ``_prompt_builder``,
    ``_llm`` and ``_result_parser`` are read here but never assigned in this
    class; presumably ``RerankStrategy`` provides them -- confirm.
    """

    def run(
        self,
        init_results: List[Result],
        rank_start: int = 0,
        rank_end: Optional[int] = None,
        num_runs: int = 10,
        **kwargs
    ) -> List[Result]:
        """Rerank deep copies of ``init_results`` with repeated bubble runs.

        Args:
            init_results: Results to rerank. They are deep-copied first, so
                the caller's objects are not modified by this method.
            rank_start: Upper bound of the ranked region; a window is never
                allowed to start before this index.
            rank_end: Index at which the first (bottom-most) window ends.
                When ``None``, the length of the longest ``hits`` list among
                the results is used.
            num_runs: Number of complete bottom-to-top runs.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The reranked results, where every hit carries ``rank`` (1-based
            position in ``hits``) and ``score`` (``1 / rank``).
        """
        results = [copy.deepcopy(result) for result in init_results]

        if rank_end is None:
            # ``range(None, ...)`` raises TypeError, so the declared default
            # was unusable. Fall back to covering every available hit.
            # NOTE(review): assumes ``hits`` supports ``len()`` and that the
            # prompt builder / parser tolerate results shorter than this
            # bound -- confirm against their implementations.
            rank_end = max((len(result.hits) for result in results), default=0)

        for i_run in range(num_runs):
            for curr_end in tqdm(
                range(rank_end, rank_start, -self._step_size),
                desc=f"Setwise Bubble (the {i_run+1} run)"
            ):
                # Stop once a full window no longer fits above ``rank_start``.
                if curr_end - self._window_size < rank_start:
                    break
                results = self.run_pass(results, rank_start, rank_end, curr_end)

        # Assign reciprocal-rank scores based on the final ordering of hits.
        for result in results:
            for rank, hit in enumerate(result.hits, start=1):
                hit['score'] = float(1 / rank)
                hit['rank'] = rank

        return results

    def run_pass(
        self,
        results: List[Result],
        rank_start: int,
        rank_end: int,
        curr_end: int,
    ) -> List[Result]:
        """Rerank the single window that ends at ``curr_end``.

        Args:
            results: Current results to rerank.
            rank_start: Not used by this method; kept in the signature for
                callers that pass it.
            rank_end: Forwarded to the prompt builder as its ``rank_end``.
            curr_end: Exclusive end index of the window to rerank.

        Returns:
            Whatever ``self._result_parser.parse`` returns for the window
            ``[curr_start, curr_end)``.
        """
        # Clamp at zero so the window never reaches before the first hit.
        curr_start = max(0, curr_end - self._window_size)

        # One index tuple covering every position inside the window.
        prompts = self._prompt_builder.create_prompt_batched(
            results=results,
            rank_start=0,
            rank_end=rank_end,
            idx_pairs=[tuple(range(curr_start, curr_end))],
        )
        # NOTE: consider moving this suffix into the prompt builder.
        # NOTE(review): presumably '[' primes the model to emit a bracketed
        # identifier -- confirm against the prompt template.
        prompts = [p + '[' for p in prompts]

        outputs = self._llm.generate(prompts, dist_logp=True)
        # Keep at most one output entry per position in the window.
        outputs = [o[:(curr_end - curr_start)] for o in outputs]

        reranked_results = self._result_parser.parse(
            outputs=outputs,
            results=results,
            rank_start=curr_start,
            rank_end=curr_end,
        )
        return reranked_results

Xet Storage Details

Size:
2.02 kB
·
Xet hash:
0368d63ec8df11301810a22e4bce749a081b9841860dcf747f814e1d27e10265

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.