| | import torch |
| | import math |
| | from vllm import LLM, SamplingParams |
| | from utils import prompt_template, truncate |
| |
|
| |
|
class ERank_vLLM:
    """Pointwise LLM reranker served through vLLM.

    Each (query, document) pair is sent to the model as a single-turn chat.
    The integer the model writes between ``<answer>`` and ``</answer>`` near
    the end of its generation is weighted by the probability of its digit
    tokens. That value becomes the document's ``rank_score``.
    """

    # Only this many trailing generated tokens are scanned for the answer tag.
    _ANSWER_WINDOW = 10
    # Largest integer accepted as a valid score (the smallest is 0).
    _MAX_SCORE = 10

    def __init__(self, model_name_or_path: str):
        """
        Initializes the ERank_vLLM reranker.

        Args:
            model_name_or_path (str): The name or path of the model to be loaded.
                This can be a Hugging Face model ID or a local path.
        """
        # torch reports 0 devices on CPU-only hosts, but a tensor-parallel
        # size must be at least 1.
        num_gpu = max(torch.cuda.device_count(), 1)
        self.ranker = LLM(
            model=model_name_or_path,
            tensor_parallel_size=num_gpu,
            gpu_memory_utilization=0.95,
            enable_prefix_caching=True,
        )
        self.tokenizer = self.ranker.get_tokenizer()
        # Greedy decoding. Logprobs are requested because the final score is
        # weighted by the probability of the emitted answer digits.
        self.sampling_params = SamplingParams(
            temperature=0,
            max_tokens=4096,
            logprobs=20,
        )

    @classmethod
    def _score_from_logprobs(cls, token_logprobs) -> float:
        """Convert the per-token logprobs of one generation into a score.

        Args:
            token_logprobs: One mapping per generated token whose values expose
                ``decoded_token`` and ``logprob`` (the shape of vLLM's
                ``CompletionOutput.logprobs``). ``None`` is treated as empty.

        Returns:
            ``score * p``. ``score`` is the integer written after ``<answer>``
            (up to ``</answer>``) within the last ``_ANSWER_WINDOW`` tokens.
            ``p`` is the product of the probabilities of its digit tokens.
            If no integer in ``[0, _MAX_SCORE]`` is found, ``score`` is -1,
            which makes the result negative.
        """
        consumed = ""  # non-digit text seen so far; used to detect the tags
        digits = ""
        prob = 1.0
        in_answer = False
        for position in (token_logprobs or [])[-cls._ANSWER_WINDOW:]:
            # Assumes vLLM lists the sampled token first in each position's
            # dict. Under greedy decoding it is also the top-1 entry.
            detail = next(iter(position.values()))
            token = detail.decoded_token or ""
            piece = token.strip()  # tolerate digit tokens carrying whitespace
            if in_answer and piece.isdigit():
                digits += piece
                prob *= math.exp(detail.logprob)
                continue
            consumed += token
            if consumed.endswith("<answer>"):
                in_answer = True
            elif consumed.endswith("</answer>"):
                # Stop collecting so digits after the closing tag are ignored.
                in_answer = False

        try:
            score = int(digits)
        except ValueError:
            # No digits at all, or Unicode digits (e.g. "²") that
            # str.isdigit accepts but int() rejects.
            score = -1
        if not 0 <= score <= cls._MAX_SCORE:
            score = -1
        return score * prob

    def rerank(self, query: str, docs: list, instruction: str, truncate_length: int=None) -> list:
        """
        Reranks a list of documents based on a query and a specific instruction.

        Args:
            query (str): The search query provided by the user.
            docs (list): A list of dictionaries, where each dictionary represents a document
                and must contain a "content" key.
            instruction (str): The instruction for the model, guiding it on how to evaluate the documents.
            truncate_length (int, optional): The maximum length to truncate the query and document
                content to, using the helper `truncate`. A falsy value (None or 0) disables
                truncation. Defaults to None.

        Returns:
            list: A new list of shallow-copied document dictionaries, each with an added
                "rank_score" key, sorted by "rank_score" in descending order. Documents whose
                answer could not be parsed get a negative score. An empty `docs` yields [].
        """
        if not docs:
            # Nothing to score; don't hand the model an empty batch.
            return []

        # The query is the same for every document, so truncate it only once.
        if truncate_length:
            query = truncate(self.tokenizer, query, length=truncate_length)

        messages = [
            [{
                "role": "user",
                "content": prompt_template.format(
                    query=query,
                    doc=truncate(self.tokenizer, doc["content"], length=truncate_length) if truncate_length else doc["content"],
                    instruction=instruction,
                ),
            }]
            for doc in docs
        ]

        outputs = self.ranker.chat(messages, self.sampling_params)

        results = [
            {**doc, "rank_score": self._score_from_logprobs(output.outputs[0].logprobs)}
            for doc, output in zip(docs, outputs)
        ]
        results.sort(key=lambda item: item["rank_score"], reverse=True)
        return results