DylanJHJ's picture
download
raw
5.58 kB
"""
Parse the reranker (LLM) outputs and apply them to the results, returning the updated results.
Different parsing is applied depending on the kind of LLM output:
* non-parallel reranking methods: len(output) == len(results).
* parallel reranking methods: the output length equals the number of queries.
    - responses: list of permutations (e.g., RankGPT)
    - swap: List[bool] (e.g., Pairwise top-k) # TODO: should be fixed. Use the doc-index pair instead.
    - scores:
        * absolute scores: List[List[float]] (e.g., Pairwise All, Pointwise)
        * partial scores: List[List[float]] (e.g., APRIL, Setwise)
"""
import copy
from typing import List, Optional, Tuple, Callable, Dict, Union
from abc import ABC, abstractmethod
from ..utils import Result
class ResultParser(ABC):
    """Apply raw reranker (LLM) outputs to ``Result`` objects.

    :meth:`parse` dispatches on the type of each output:

    * ``str`` -- a ranking response such as ``"[1] > [3] > [2]"`` (or
      ``"[A] > [C] > [B]"`` when ``use_alpha`` is set); the hits in
      ``hits[rank_start:rank_end]`` are reordered accordingly.
    * ``bool`` -- a swap decision for the hits at positions ``rank_end - 2``
      and ``rank_end - 1``.
    * ``list`` with exactly one entry per hit -- absolute scores written into
      each hit's ``"score"`` field; hit order is not changed.
    * any other ``list`` -- scores for the window ``hits[rank_start:rank_end]``;
      the window is reordered by descending score.
    """

    def __init__(self, use_alpha: bool = False):
        # Whether ranking responses identify passages by letters ("[A]", "[B]",
        # ...) instead of 1-based numbers ("[1]", "[2]", ...).
        self._use_alpha = use_alpha

    # TODO: parse all, or maybe multithreading
    # TODO: make the meaning of `rank_start` and `rank_end` similar across methods?
    def parse(
        self,
        outputs: Union[List[List[Union[float, int, str]]], List[str], List[bool]],
        results: List["Result"],
        rank_start: int = 0,
        rank_end: Optional[int] = None,
    ) -> List["Result"]:
        """Apply ``outputs[i]`` to ``results[i]`` for every ``i``.

        Args:
            outputs: One reranker output per result; see the class docstring
                for the supported kinds.
            results: Results to update. Each element is replaced by its parsed
                version, and this same list object is returned.
            rank_start: Start of the reranked window (used by responses and
                partial scores).
            rank_end: Exclusive end of the reranked window. Swap outputs need
                it, because it selects the pair ``rank_end - 2`` / ``rank_end - 1``.

        Returns:
            ``results``, updated in place.

        Raises:
            AssertionError: if ``outputs`` and ``results`` differ in length
                (only when assertions are enabled).
            TypeError: for an unsupported output type.
        """
        assert len(outputs) == len(results), "outputs and results must have the same length."
        for index, (output, result) in enumerate(zip(outputs, results)):
            if isinstance(output, str):  # "[1] > [3] > [5] > ... "
                parsed_result = self._parse_responses(output, result, rank_start, rank_end)
            elif isinstance(output, bool):  # True if swapping (for top-k)
                parsed_result = self._parse_swap(output, result, rank_end)
            elif isinstance(output, list):  # e.g., Pairwise or Pointwise
                # One score per hit is treated as absolute scores; any other
                # length is taken to cover only the [rank_start:rank_end] window.
                if len(output) == len(result.hits):
                    parsed_result = self._parse_absolute_scores(output, result)
                else:
                    parsed_result = self._parse_scores(output, result, rank_start, rank_end)
            else:
                raise TypeError(f"Unsupported outputs type: {type(output)}, {output}")
            results[index] = parsed_result
        return results

    # NOTE: this is suitable for continuous items sorting. But not for setwise items initially
    # NOTE: consider to add a design of APRIL
    def _parse_scores(self, scores: List[float], result: "Result", rank_start: int, rank_end: Optional[int]) -> "Result":
        """Reorder ``hits[rank_start:rank_end]`` by descending ``scores``.

        ``scores[i]`` belongs to the i-th hit of the window. Ties keep their
        original relative order (the sort is stable). Scores beyond the window
        length are ignored. If there are fewer scores than window hits, only
        the first ``len(scores)`` window positions are rearranged.
        """
        # Deep copy so the placed hits are independent of the originals.
        window = copy.deepcopy(result.hits[rank_start:rank_end])
        order = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
        order = [idx for idx in order if idx < len(window)]
        for offset, idx in enumerate(order):
            # `window` is already a private copy and each idx occurs once,
            # so no second deep copy is needed.
            result.hits[rank_start + offset] = window[idx]
        return result

    def _parse_responses(self, permutation: str, result: "Result", rank_start: int, rank_end: Optional[int]) -> "Result":
        """Reorder ``hits[rank_start:rank_end]`` according to a ranking string.

        Identifiers in ``permutation`` are 1-based positions within the window.
        Duplicates and out-of-range identifiers are dropped. Window hits the
        response does not mention are appended in their original order.
        """
        cleaned = self._clean_response(permutation)
        ranked = self._remove_duplicate([int(token) - 1 for token in cleaned.split()])
        window = copy.deepcopy(result.hits[rank_start:rank_end])
        size = len(window)
        order = [idx for idx in ranked if 0 <= idx < size]
        mentioned = set(order)
        order += [idx for idx in range(size) if idx not in mentioned]
        for offset, idx in enumerate(order):
            result.hits[rank_start + offset] = window[idx]
        return result

    def _parse_swap(self, swap: bool, result: "Result", target: Optional[int]) -> "Result":
        """Swap ``hits[target - 2]`` and ``hits[target - 1]`` when ``swap`` is truthy.

        Raises:
            TypeError: if a swap is requested but ``target`` is None.
            IndexError: if ``target`` is 0 or 1, which would otherwise wrap
                around via negative indexing and swap unrelated hits, or if the
                pair lies outside ``result.hits``.
        """
        if not swap:  # means passage [1] > [2] (hits[rank_end-1] > hits[rank_end-2])
            return result
        if target is None:
            raise TypeError("rank_end is required to parse a swap output.")
        if 0 <= target < 2:
            raise IndexError(f"swap target {target} does not address a pair of hits.")
        hits = result.hits
        hits[target - 2], hits[target - 1] = hits[target - 1], hits[target - 2]
        return result

    @staticmethod
    def _to_float(value) -> float:
        """Convert ``value`` to float, mapping unparseable values to 0.0."""
        try:
            return float(value)
        except (ValueError, TypeError):
            return 0.0

    def _parse_absolute_scores(self, scores: List[Union[int, float, str]], result: "Result") -> "Result":
        """Assign the scores from top to bottom, and fill the rest with decreasing scores.

        Hit order is not changed; each hit dict's ``"score"`` field is written
        in place. If any score is a string, every score goes through
        :meth:`_to_float` (unparseable values become 0.0). Hits past
        ``len(scores)`` get ``min(scores) - 1``, ``min(scores) - 2``, and so on.
        """
        if any(isinstance(s, str) for s in scores):
            scores = [self._to_float(s) for s in scores]
        # `default=0` keeps an empty `scores` (e.g. a result with no hits) from
        # crashing min().
        fill_score = min(scores, default=0) - 1
        for i, hit in enumerate(result.hits):
            if i < len(scores):
                hit["score"] = scores[i]
            else:
                hit["score"] = fill_score
                fill_score -= 1
        return result

    # TODO: use regular expression?
    def _clean_response(self, response: str) -> str:
        """Map a ranking response to space-separated 1-based integer tokens.

        In alpha mode, each letter becomes ``ord(letter) - 64`` (``'A'`` -> 1).
        Adjacent letters concatenate their numbers. In numeric mode, decimal
        digits are kept. Every other character becomes a space.
        """
        ALPH_START_IDX = 64  # ASCII 'A' starts at 65, so we use 64 to map 'A' to 1
        if self._use_alpha:
            pieces = [str(ord(c) - ALPH_START_IDX) if c.isalpha() else " " for c in response]
        else:
            # isdecimal(), not isdigit(): characters such as '²' satisfy
            # isdigit() but make int() raise ValueError.
            pieces = [c if c.isdecimal() else " " for c in response]
        return "".join(pieces).strip()

    def _remove_duplicate(self, response: List[int]) -> List[int]:
        """Return ``response`` without repeats, keeping first occurrences in order."""
        return list(dict.fromkeys(response))

Xet Storage Details

Size:
5.58 kB
·
Xet hash:
313d8d1624749dc9f0011cb11f1c2e2c7b5560e85c62b4b92683c438c2d3150e

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.