DylanJHJ's picture
download
raw
4.45 kB
import math
import copy
from tqdm import tqdm
import numpy as np
from typing import Optional, Tuple, List, Dict, Union, Any
from ..utils import Result
from .base import RerankStrategy
import pdb
class SetMaxHeapTopK(RerankStrategy):
    """Top-k reranking with a d-ary max-heap whose comparisons are made by an LLM.

    Each node of the heap has up to ``self._window_size`` children.  A single
    "comparison" hands a parent and its children to the LLM as one set and
    promotes the highest-scoring member to the parent slot.  Only the first
    ``num_runs`` hits are fully ordered; the rest are left in heap order.

    Relies on ``self._window_size``, ``self._prompt_builder`` and ``self._llm``,
    which are not defined in this class (presumably set up by
    ``RerankStrategy`` -- confirm against the base class).
    """

    def run(
        self,
        init_results: List[Result],
        rank_start: int = 0,
        rank_end: Optional[int] = None,
        batch_size: Optional[int] = 32,
        num_runs: int = 10,
        **kwargs
    ) -> List[Result]:
        """Rerank every result so that its best ``num_runs`` hits come first.

        Args:
            init_results: Results to rerank.  They are deep-copied; neither the
                results nor their hits are modified.
            rank_start: Not used by this strategy.
            rank_end: Not used by this strategy.
            batch_size: Not used by this strategy.
            num_runs: Number of hits to extract from the heap (the "k" of
                top-k).  At least one hit is always extracted from a non-empty
                result.
            **kwargs: Ignored.

        Returns:
            New results whose hits are the extracted top hits followed by the
            remaining hits, each hit carrying ``'rank'`` (1-based) and
            ``'score'`` (``1 / rank``).
        """
        # Work on private copies so the caller's results and hit dicts are
        # neither reordered, truncated, nor re-scored.
        results = [copy.deepcopy(result) for result in init_results]

        for index, result in enumerate(tqdm(results, desc="Setwise HeapSort")):
            if not result.hits:
                # Nothing to rank; indexing hits[0] below would fail.
                continue

            sorted_hits = []

            # 0. Last parent: the parent of the last index n - 1 is (n - 2) // d.
            last_parent = (len(result.hits) - 2) // self._window_size
            # 1. Build the max-heap by sifting down every parent, bottom-up.
            for i_visit in range(last_parent, -1, -1):
                result = self.run_pass(result, target=i_visit)
            # 2-3. Extract the current maximum.
            sorted_hits.append(self._pop_root(result.hits))

            # Keep extracting until enough hits are sorted or the heap is empty.
            # TODO: maybe we should use variable top_k
            while len(sorted_hits) < num_runs and result.hits:
                # Only the root can violate the heap property after a pop.
                result = self.run_pass(result, target=0)
                sorted_hits.append(self._pop_root(result.hits))

            # NOTE(review): the original indentation was lost; this line is
            # emitted once per query -- confirm that is the intended placement.
            print(f"Sorted hits: {len(sorted_hits)}, Remaining hits: {len(result.hits)}")

            # 4. Sorted prefix first, then whatever is left in heap order.
            results[index].hits = sorted_hits + result.hits

        # Assign reciprocal-rank scores.
        for result in results:
            for rank, hit in enumerate(result.hits, start=1):
                hit['score'] = 1.0 / rank
                hit['rank'] = rank
        return results

    @staticmethod
    def _pop_root(hits: list) -> Any:
        """Remove and return ``hits[0]``, moving the last hit into its place."""
        hits[0], hits[-1] = hits[-1], hits[0]
        return hits.pop()

    def run_pass(self, result: Result, target: int) -> Result:
        """Sift the hit at index ``target`` down until its subtree is a max-heap.

        Args:
            result: Result whose ``hits`` list is reordered in place.
            target: Index of the node to sift down.

        Returns:
            The same ``result`` object.
        """
        # Stop as soon as the current node has no children.
        while target * self._window_size + 1 < len(result.hits):
            # Comparison set: the current root i followed by its children
            # i * n_childs + {1, 2, ..., n_childs}, clipped to the list length.
            candidates = [target] + [
                target * self._window_size + i
                for i in range(1, self._window_size + 1)
            ]
            idx_pair = tuple(idx for idx in candidates if idx < len(result.hits))

            prompt = self._prompt_builder.create_prompt(
                result=result,
                rank_start=0,
                rank_end=len(result.hits),
                idx_pairs=[idx_pair]
            )
            prompt += '['
            # One score per candidate is read from the first generation output
            # (presumably log-probabilities, given dist_logp=True -- confirm).
            outputs = self._llm.generate(prompt, dist_logp=True)[0]
            outputs = outputs[:len(idx_pair)]

            # Version 1 (swap only): promote the best candidate to the parent.
            i_max = int(np.argmax(outputs))
            if i_max == 0:
                # The parent already wins; the subtree satisfies the heap property.
                break
            max_idx = idx_pair[i_max]
            result.hits[target], result.hits[max_idx] = result.hits[max_idx], result.hits[target]
            # Continue sifting from the child slot that received the old parent.
            target = max_idx
        return result
# # NOTE: Do we need to try all the combinations of the set? Currently we do.
# # TODO: look for a better implementation
# # NOTE: the first item is the target; the last n_child items are the child nodes.
# # NOTE: the items in the middle remain the same
# # NOTE: Do we need to also swap the order of the entire set-subtree? Currently we do.
# max_2, max_1 = np.sort(outputs)[-2:]
# dummy = (max_1 + max_2) / 2
# final_outputs = [dummy for _ in range(curr_start, curr_end)]
# final_outputs[0] = outputs[0]
# final_outputs[1-len(idx_pair):] = outputs[1:len(idx_pair)]
# print(f"Final outputs: {outputs}, max_1: {max_1}, max_2: {max_2}")
#
# result = self._result_parser.parse(
# outputs=[final_outputs],
# results=[result],
# rank_start=curr_start,
# rank_end=curr_end,
# )[0]
#
# # After reranking, we need to sift down if largest element is not at the root
# if outputs[0] != max_1:
# idx_swap = np.argsort(outputs[:len(idx_pair)])[-1]
# i_target = idx_pair[idx_swap]
# result = self.run_pass(result, target=i_target)
#
# return result

Xet Storage Details

Size:
4.45 kB
·
Xet hash:
e91d28e9fc1740d8234079554e3be63c87dced326b8a5d70bcce6982ce2abc3d

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.