| import torch |
| import math |
| from vllm import LLM, SamplingParams |
| from utils import prompt_template, truncate |
|
|
|
|
class ERank_vLLM:
    """Pointwise LLM reranker backed by vLLM.

    Each document is scored independently: the model is prompted with the
    query/document pair plus an instruction and is expected to emit an
    integer relevance score (0-10) inside ``<answer>...</answer>`` tags.
    The final score is that integer weighted by the generation probability
    of its digit tokens.
    """

    def __init__(self, model_name_or_path: str):
        """
        Initializes the ERank_vLLM reranker.

        Args:
            model_name_or_path (str): The name or path of the model to be loaded.
                This can be a Hugging Face model ID or a local path.
        """
        num_gpu = torch.cuda.device_count()
        self.ranker = LLM(
            model=model_name_or_path,
            tensor_parallel_size=num_gpu,
            gpu_memory_utilization=0.95,
            # Prompts share a long common prefix (template + instruction),
            # so prefix caching saves prefill work across documents.
            enable_prefix_caching=True
        )
        self.tokenizer = self.ranker.get_tokenizer()
        self.sampling_params = SamplingParams(
            temperature=0,  # greedy decoding -> deterministic scores
            max_tokens=4096,
            # Request per-step logprobs; only the sampled token's logprob
            # is consumed when weighting the <answer> digits.
            logprobs=20
        )

    @staticmethod
    def _parse_answer(logprobs) -> tuple:
        """Extract the integer score from the tail of one generation.

        Scans the last 10 decoding steps for an "<answer>" marker and
        accumulates the digit tokens that follow it, multiplying their
        token probabilities together.

        Args:
            logprobs (list): Per-step ``{token_id: Logprob}`` dicts as
                produced by vLLM (the sampled token is assumed to be the
                first entry of each dict — TODO confirm against vLLM docs).

        Returns:
            tuple: ``(answer, prob)`` where ``answer`` is an int in
            ``[0, 10]``, or ``-1`` when no valid score could be parsed;
            ``prob`` is the product of the digit-token probabilities
            (1.0 when no digits were consumed).
        """
        seen = ""
        digits = ""
        in_answer = False
        prob = 1.0
        for step in logprobs[-10:]:
            _, detail = next(iter(step.items()))
            token = detail.decoded_token
            if in_answer and token.isdigit():
                digits += token
                prob *= math.exp(detail.logprob)
            else:
                seen += token
                if seen.endswith("<answer>"):
                    in_answer = True
        try:
            answer = int(digits)
        except ValueError:
            # No digit tokens were found after "<answer>". (Was a bare
            # `except:` that also swallowed KeyboardInterrupt/SystemExit.)
            return -1, prob
        if answer > 10:
            # Out-of-range score counts as a parse failure. (Was enforced
            # with `assert`, which is silently stripped under `python -O`.)
            return -1, prob
        return answer, prob

    def rerank(self, query: str, docs: list, instruction: str, truncate_length: int = None) -> list:
        """
        Reranks a list of documents based on a query and a specific instruction.

        Args:
            query (str): The search query provided by the user.
            docs (list): A list of dictionaries, where each dictionary represents a document
                and must contain a "content" key.
            instruction (str): The instruction for the model, guiding it on how to evaluate the documents.
            truncate_length (int, optional): The maximum length to truncate the query and document content to. Defaults to None.

        Returns:
            list: A new list of document dictionaries, sorted by their "rank_score" in descending order.
        """
        if not docs:
            # Nothing to rank; skip the model call entirely.
            return []

        if truncate_length:
            # Hoisted out of the per-doc comprehension: query truncation
            # is loop-invariant.
            query = truncate(self.tokenizer, query, length=truncate_length)

        # One single-turn chat conversation per document.
        messages = [
            [{
                "role": "user",
                "content": prompt_template.format(
                    query=query,
                    doc=truncate(self.tokenizer, doc["content"], length=truncate_length) if truncate_length else doc["content"],
                    instruction=instruction
                )
            }] for doc in docs
        ]

        outputs = self.ranker.chat(messages, self.sampling_params)

        results = []
        for doc, output in zip(docs, outputs):
            answer, prob = self._parse_answer(output.outputs[0].logprobs)
            # Failed parses score -1 * prob (<= 0), sinking them below any
            # successfully parsed document.
            results.append({
                **doc,
                "rank_score": answer * prob
            })

        results.sort(key=lambda x: x["rank_score"], reverse=True)
        return results