Arastun
/

Arastun commited on
Commit
82bc2ed
·
verified ·
1 Parent(s): c87e81a

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +68 -0
handler.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM
2
+ import torch
3
+
4
+ class EndpointHandler:
5
+ def __init__(self, path=""):
6
+ self.tokenizer = AutoTokenizer.from_pretrained(path, padding_side='left')
7
+ self.model = AutoModelForCausalLM.from_pretrained(
8
+ path,
9
+ torch_dtype=torch.float16,
10
+ device_map="auto"
11
+ )
12
+ self.model.eval()
13
+
14
+ # Qwen3-Reranker uses a specific token to extract the relevance score
15
+ self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")
16
+ self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")
17
+
18
+ def format_input(self, query, document, instruction=None):
19
+ if instruction is None:
20
+ instruction = "Given a web search query, retrieve relevant passages that answer the query"
21
+ prefix = f"<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct. Note only output a single token from [yes, no] after thinking.\n<|im_end|>\n<|im_start|>user\n<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {document}\n<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n"
22
+ return prefix
23
+
24
+ def __call__(self, data: dict) -> dict:
25
+ """
26
+ Expected input format:
27
+ {
28
+ "query": "what is the capital of France",
29
+ "documents": ["Paris is the capital...", "London is the capital..."],
30
+ "instruction": "optional custom instruction"
31
+ }
32
+ """
33
+ query = data.get("query", "")
34
+ documents = data.get("documents", [])
35
+ instruction = data.get("instruction", None)
36
+
37
+ if not query or not documents:
38
+ return {"error": "Must provide 'query' and 'documents'"}
39
+
40
+ prompts = [self.format_input(query, doc, instruction) for doc in documents]
41
+
42
+ inputs = self.tokenizer(
43
+ prompts,
44
+ return_tensors="pt",
45
+ padding=True,
46
+ truncation=True,
47
+ max_length=4096
48
+ ).to(self.model.device)
49
+
50
+ with torch.no_grad():
51
+ outputs = self.model(**inputs)
52
+ # Get logits for the final token position
53
+ logits = outputs.logits[:, -1, :]
54
+ # Score is the softmax probability of "yes" vs "no"
55
+ true_logits = logits[:, self.token_true_id]
56
+ false_logits = logits[:, self.token_false_id]
57
+ scores = torch.softmax(
58
+ torch.stack([false_logits, true_logits], dim=1), dim=1
59
+ )[:, 1].tolist()
60
+
61
+ return {
62
+ "scores": scores,
63
+ "ranking": sorted(
64
+ range(len(documents)),
65
+ key=lambda i: scores[i],
66
+ reverse=True
67
+ )
68
+ }