| from torch.nn import functional as F |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
| from utils import prompt_template, truncate, hybrid_scores |
|
|
class ERank_Transformer:
    """Pointwise LLM reranker built on a Hugging Face causal language model.

    Each document is placed in its own prompt together with the query and the
    instruction. The model generates a reply. The integer written after the
    last ``<answer>`` marker in that reply (0-10 accepted) is multiplied by the
    probability the model assigned to those digit tokens; the product is the
    document's ``rank_score``.
    """

    def __init__(self, model_name_or_path: str, device: str = "cuda"):
        """
        Initializes the ERank_Transformer reranker.

        Args:
            model_name_or_path (str): The name or path of the model to be loaded.
                This can be a Hugging Face model ID or a local path.
            device (str, optional): Device the model is moved to. Defaults to "cuda".
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        # Batched generation with a decoder-only model must pad on the LEFT so
        # that every prompt ends at the same column: rerank() assumes generated
        # tokens for every row begin at input_ids.size(1).
        self.tokenizer.padding_side = "left"
        self.reranker = AutoModelForCausalLM.from_pretrained(model_name_or_path).eval()
        self.reranker.to(device)

    def rerank(self, query: str, docs: list, instruction: str, truncate_length: int = None,
               max_new_tokens: int = 8192) -> list:
        """
        Reranks a list of documents based on a query and a specific instruction.

        Args:
            query (str): The search query provided by the user.
            docs (list): A list of dictionaries, where each dictionary represents a document
                and must contain a "content" key.
            instruction (str): The instruction for the model, guiding it on how to evaluate
                the documents.
            truncate_length (int, optional): The maximum length to truncate the query and
                document content to. Defaults to None (no truncation; 0 also disables it).
            max_new_tokens (int, optional): Generation budget per document. Defaults to 8192.

        Returns:
            list: Shallow copies of the input dictionaries, each with an added "rank_score",
            sorted by "rank_score" in descending order. A document whose reply contains no
            parsable answer in 0..10 gets a non-positive score (-1 times the digit
            probability). An empty ``docs`` list returns ``[]``.
        """
        if not docs:
            # The tokenizer cannot build a padded batch from zero prompts.
            return []

        # The query is identical for every document, so truncate it only once.
        query_text = truncate(self.tokenizer, query, length=truncate_length) if truncate_length else query
        texts = [
            self._build_prompt(
                query_text,
                truncate(self.tokenizer, doc["content"], length=truncate_length) if truncate_length else doc["content"],
                instruction,
            )
            for doc in docs
        ]
        inputs = self.tokenizer(texts, padding=True, return_tensors="pt").to(self.reranker.device)

        # Decoding settings not passed here come from the model's generation config.
        outputs = self.reranker.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            output_scores=True,
            return_dict_in_generate=True,
        )

        # With left padding, every row's generated tokens start at this column.
        prompt_len = inputs.input_ids.size(1)
        marker_ids = self.tokenizer.encode("<answer>", add_special_tokens=False)

        results = []
        for row, doc in enumerate(docs):
            answer, prob = self._read_answer(
                outputs.sequences[row].tolist(), outputs.scores, row, prompt_len, marker_ids
            )
            results.append({
                **doc,
                "rank_score": answer * prob,
            })

        results.sort(key=lambda x: x["rank_score"], reverse=True)
        return results

    def _build_prompt(self, query_text: str, doc_text: str, instruction: str) -> str:
        """Render one (query, document) pair as a chat-formatted prompt string."""
        messages = [{
            "role": "user",
            "content": prompt_template.format(query=query_text, doc=doc_text, instruction=instruction),
        }]
        return self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    @staticmethod
    def _find_answer_start(token_ids: list, marker_ids: list, search_from: int) -> int:
        """Return the index just past the last ``marker_ids`` run located at or
        after ``search_from`` in ``token_ids``, or -1 if there is none."""
        width = len(marker_ids)
        if width == 0:
            return -1
        # Only the generated reply is searched (the prompt may itself mention the
        # marker). The start includes len - width, so a marker ending on the
        # final token is still found.
        for i in range(len(token_ids) - width, search_from - 1, -1):
            if token_ids[i:i + width] == marker_ids:
                return i + width
        return -1

    def _read_answer(self, token_ids: list, scores, row: int, prompt_len: int, marker_ids: list) -> tuple:
        """Parse the integer that follows the answer marker in one generated row.

        Returns:
            tuple: ``(answer, prob)``. ``prob`` is the product of the softmax
            probabilities of the consecutive digit tokens read after the marker
            (1.0 if none were read). ``answer`` is that integer, or -1 when no
            digits were found or the value exceeds 10.
        """
        start = self._find_answer_start(token_ids, marker_ids, prompt_len)
        digits = ""
        prob = 1.0
        if start != -1:
            # scores[step][row] holds the logits for generated position
            # prompt_len + step of this row.
            for step in range(start - prompt_len, len(scores)):
                token_id = token_ids[prompt_len + step]
                token = self.tokenizer.decode(token_id)
                if not token.isdigit():
                    break
                prob *= F.softmax(scores[step][row], dim=-1)[token_id].item()
                digits += token
        try:
            answer = int(digits)
        except ValueError:  # no digits, or digit characters int() rejects (e.g. superscripts)
            return -1, prob
        # Explicit range check rather than `assert`, which `python -O` strips.
        return (answer, prob) if answer <= 10 else (-1, prob)
| |
| |
if __name__ == "__main__":
    # Demo: score three one-word candidates against a short query with the
    # ERank-4B checkpoint, then print the raw and the hybrid rankings.
    demo_reranker = ERank_Transformer("Ucreate/ERank-4B")

    demo_instruction = "Retrieve relevant documents for the query."
    demo_query = "I am happy"
    # Each candidate also carries a first-stage retrieval score for the hybrid step.
    demo_docs = [
        {"content": text, "first_stage_score": first_stage}
        for text, first_stage in (("excited", 46.7), ("sad", 1.5), ("peaceful", 2.3))
    ]

    ranked = demo_reranker.rerank(demo_query, demo_docs, demo_instruction, truncate_length=2048)
    print(ranked)

    # hybrid_scores comes from utils; presumably it blends "rank_score" with
    # "first_stage_score" using this weight (confirm in utils).
    blend_weight = 0.2
    print(hybrid_scores(ranked, blend_weight))
| |
| |
| |
| |
| |