from transformers import AutoTokenizer, AutoModelForCausalLM
import torch


class EndpointHandler:
    def __init__(self, path=""):
        # Left padding keeps the last position of every padded sequence on its
        # final real token, which is where the yes/no logits are read below.
        self.tokenizer = AutoTokenizer.from_pretrained(path, padding_side='left')
        # Causal LM tokenizers sometimes ship without a pad token; fall back
        # to EOS so batched padding works.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.float16,
            device_map="auto"
        )
        self.model.eval()

        # Relevance is scored from the logits of the "yes"/"no" tokens; this
        # assumes both exist as single tokens in the model's vocabulary.
        self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")
        self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")

    def format_input(self, query, document, instruction=None):
        """Render one (query, document) pair into the model's chat template."""
        if instruction is None:
            instruction = "Given a web search query, retrieve relevant passages that answer the query"
        prefix = (
            "<|im_start|>system\n"
            "Judge whether the Document meets the requirements based on the "
            "Query and the Instruct. Note only output a single token from "
            "[yes, no] after thinking.\n<|im_end|>\n"
            "<|im_start|>user\n"
            f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {document}\n"
            "<|im_end|>\n"
            "<|im_start|>assistant\n<think>\n\n</think>\n"
        )
        return prefix
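    # For documentation only: with the default instruction, a toy call like
    # format_input("capital of France?", "Paris is the capital of France.")
    # renders the following (each line break is a literal "\n" in the string):
    #   <|im_start|>system
    #   Judge whether the Document meets the requirements based on the Query and the Instruct. Note only output a single token from [yes, no] after thinking.
    #   <|im_end|>
    #   <|im_start|>user
    #   <Instruct>: Given a web search query, retrieve relevant passages that answer the query
    #   <Query>: capital of France?
    #   <Document>: Paris is the capital of France.
    #   <|im_end|>
    #   <|im_start|>assistant
    #   <think>
    #
    #   </think>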

    def __call__(self, data: dict) -> dict:
        """
        Score each document against the query and rank them.

        Expected input format:
        {
            "query": "what is the capital of France",
            "documents": ["Paris is the capital...", "London is the capital..."],
            "instruction": "optional custom instruction"
        }

        Returns {"scores": [...], "ranking": [...]}, where "ranking" holds
        document indices sorted from most to least relevant.
        """
        inputs = data.get("inputs", data)
        query = inputs.get("query", "")
        documents = inputs.get("documents", [])
        instruction = inputs.get("instruction", None)

        if not query or not documents:
            return {"error": "Must provide 'query' and 'documents'"}

        prompts = [self.format_input(query, doc, instruction) for doc in documents]

        # Tokenize all prompts as one left-padded batch.
        batch = self.tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=4096
        ).to(self.model.device)

        # A single forward pass suffices: nothing is generated, only the
        # logits at the last position of each prompt are needed.
        with torch.no_grad():
            outputs = self.model(**batch)

        logits = outputs.logits[:, -1, :]

        # Softmax over just the "no"/"yes" logits; the second column is
        # P("yes"), used as the relevance score.
        true_logits = logits[:, self.token_true_id]
        false_logits = logits[:, self.token_false_id]
        scores = torch.softmax(
            torch.stack([false_logits, true_logits], dim=1), dim=1
        )[:, 1].tolist()

        return {
            "scores": scores,
            "ranking": sorted(
                range(len(documents)),
                key=lambda i: scores[i],
                reverse=True
            )
        }
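

# Minimal local smoke test, not part of the endpoint contract. This is a
# sketch: MODEL_PATH below is a placeholder, not a checkpoint this handler
# depends on; point it at any causal-LM reranker that follows the yes/no
# chat template rendered by format_input above.
if __name__ == "__main__":
    MODEL_PATH = "path/to/reranker-checkpoint"  # placeholder, set before running
    handler = EndpointHandler(MODEL_PATH)
    result = handler({
        "inputs": {
            "query": "what is the capital of France",
            "documents": [
                "Paris is the capital of France.",
                "London is the capital of the United Kingdom.",
            ],
        }
    })
    print(result["scores"])   # one relevance score in [0, 1] per document
    print(result["ranking"])  # document indices, most relevant first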