Arastun
/

File size: 2,810 Bytes
82bc2ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83ceda7
 
 
 
82bc2ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

class EndpointHandler:
    """Hugging Face Inference Endpoints handler for the Qwen3-Reranker model.

    Scores each candidate document for relevance to a query by reading the
    model's next-token probability of "yes" versus "no" after a structured
    judging prompt, then returns per-document scores and a ranking.
    """

    def __init__(self, path=""):
        """Load tokenizer and model from *path*.

        Left padding is required so that the last position of every row in a
        batch is a real (non-pad) token — that is where the yes/no logits
        are read from in ``__call__``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path, padding_side='left')
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.float16,
            device_map="auto"
        )
        self.model.eval()

        # Qwen3-Reranker uses a specific token to extract the relevance score
        self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")
        self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")

    def format_input(self, query, document, instruction=None):
        """Build the chat-formatted reranker prompt for one (query, document) pair.

        The prompt ends with an empty ``<think>`` block so the model's very
        next token is its yes/no judgment.
        """
        if instruction is None:
            instruction = "Given a web search query, retrieve relevant passages that answer the query"
        prefix = f"<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct. Note only output a single token from [yes, no] after thinking.\n<|im_end|>\n<|im_start|>user\n<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {document}\n<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n"
        return prefix

    def __call__(self, data: dict) -> dict:
        """Score documents against a query.

        Expected input format:
        {
            "query": "what is the capital of France",
            "documents": ["Paris is the capital...", "London is the capital..."],
            "instruction": "optional custom instruction"
        }

        Returns ``{"scores": [...], "ranking": [...]}`` where ``ranking``
        holds document indices sorted by descending score, or
        ``{"error": ...}`` on invalid input.
        """
        payload = data.get("inputs", data)  # unwrap HF gateway nesting
        if not isinstance(payload, dict):
            # e.g. the gateway delivered a bare string under "inputs"
            return {"error": "Must provide 'query' and 'documents'"}

        query = payload.get("query", "")
        documents = payload.get("documents", [])
        instruction = payload.get("instruction", None)

        # A bare string would otherwise be iterated character by character,
        # producing one prompt per character.
        if isinstance(documents, str):
            documents = [documents]

        if not query or not documents:
            return {"error": "Must provide 'query' and 'documents'"}

        prompts = [self.format_input(query, doc, instruction) for doc in documents]

        batch = self.tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=4096
        ).to(self.model.device)

        with torch.no_grad():
            outputs = self.model(**batch)
            # Get logits for the final token position (real token thanks to
            # left padding).
            logits = outputs.logits[:, -1, :]
            # Score is the softmax probability of "yes" vs "no"
            true_logits = logits[:, self.token_true_id]
            false_logits = logits[:, self.token_false_id]
            scores = torch.softmax(
                torch.stack([false_logits, true_logits], dim=1), dim=1
            )[:, 1].tolist()

        return {
            "scores": scores,
            "ranking": sorted(
                range(len(documents)),
                key=lambda i: scores[i],
                reverse=True
            )
        }