File size: 4,629 Bytes
25e23ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
083dc97
25e23ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
"""
Custom handler for ConvSearch-R1 query rewriting on HuggingFace Inference Endpoints.

Accepts conversation context + query and returns a rewritten query.
"""

import re
import torch
from typing import Any, Dict, List, Union
from transformers import AutoModelForCausalLM, AutoTokenizer


# Instruction prompt sent to the model. It asks for chain-of-thought inside
# <think>...</think> followed by the final decontextualized query inside
# <rewrite>...</rewrite>; the <rewrite> block is what gets parsed out of the
# model's output. NOTE: this text is part of runtime behavior — do not edit
# casually, the trained model expects this exact prompt format.
PROMPT_TEMPLATE = """Given a query and its context, you must first think about the reasoning process in the mind to decontextualize the query by resolving \
coreference and omission issues. Then, provide the user with a rewrite that retains its original meaning and is as informative as possible to help \
search engines retrieve relevant documents effectively. The reasoning process and rewrite should be enclosed within <think> </think> and <rewrite> </rewrite> \
tags, respectively, i.e., <think> reasoning process here </think>
<rewrite> rewrite here </rewrite>.

### Context Begin ###
{context}
### Context End ###

Query: {query}
Rewrite:"""


class EndpointHandler:
    """Handler for ConvSearch-R1 conversational query rewriting.

    Loads a causal LM once at startup and, per request, formats the
    conversation context + current query into the ConvSearch-R1 prompt,
    generates, and extracts the <rewrite>...</rewrite> span from the output.
    """

    def __init__(self, path: str = ""):
        """Load the tokenizer and model from ``path`` and move the model to
        GPU when available.

        Args:
            path: Model directory / hub id (supplied by the Endpoints runtime).
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)
        self.model.eval()
        print(f"ConvSearch-R1 loaded on {self.device}")

    def _format_prompt(self, context: List[str], query: str) -> str:
        """Render the conversation into the model prompt.

        ``context`` is a flat, alternating list ``[Q1, A1, Q2, A2, ...]``;
        an odd length (trailing unanswered question) is tolerated — the
        last question is emitted without an answer line.
        """
        ctx_lines: List[str] = []
        for i in range(0, len(context), 2):
            turn = i // 2 + 1
            ctx_lines.append(f"Q{turn}: {context[i]}")
            if i + 1 < len(context):
                ctx_lines.append(f"A{turn}: {context[i + 1]}")
        return PROMPT_TEMPLATE.format(
            context="\n".join(ctx_lines),
            query=query,
        )

    def _extract_rewrite(self, output: str) -> str:
        """Return the text inside the first ``<rewrite>...</rewrite>`` pair,
        stripped; fall back to the whole stripped output if the tag is absent
        (e.g. generation was cut off before the closing tag)."""
        match = re.search(r"<rewrite>(.*?)</rewrite>", output, re.DOTALL)
        if match:
            return match.group(1).strip()
        return output.strip()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process inference requests.

        Input format:
        {
            "inputs": [
                {"context": ["Q1", "A1", "Q2", "A2"], "query": "current question"},
                ...
            ],
            "parameters": {"temperature": 0.7, "max_new_tokens": 1024}
        }

        Or single input:
        {
            "inputs": {"context": [...], "query": "..."},
            "parameters": {...}
        }

        A temperature of 0 selects greedy decoding.

        Returns:
        [{"rewrite": "rewritten query", "raw_output": "full model output"}, ...]
        """
        inputs = data.get("inputs", data)
        params = data.get("parameters", {})
        temperature = params.get("temperature", 0.7)
        max_new_tokens = params.get("max_new_tokens", 4096)

        # Normalize single-item requests to a batch of one.
        if isinstance(inputs, dict):
            inputs = [inputs]

        # Build generation kwargs once. BUG FIX: the old code always passed
        # temperature/top_p to generate(); with temperature == 0 (greedy,
        # do_sample=False) transformers rejects temperature <= 0 in the
        # generation config, so sampling knobs are only passed when sampling.
        do_sample = temperature > 0
        gen_kwargs: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
            # Explicit pad token avoids the "Setting pad_token_id to
            # eos_token_id" warning and undefined padding on models that
            # ship without a pad token.
            "pad_token_id": (
                self.tokenizer.pad_token_id
                if self.tokenizer.pad_token_id is not None
                else self.tokenizer.eos_token_id
            ),
        }
        if do_sample:
            gen_kwargs["temperature"] = temperature
            gen_kwargs["top_p"] = 1.0

        results = []
        for inp in inputs:
            context = inp.get("context", [])
            query = inp.get("query", "")

            # Wrap the task prompt in the model's chat template.
            prompt_text = self._format_prompt(context, query)
            messages = [{"role": "user", "content": prompt_text}]
            formatted = self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

            # Tokenize. NOTE(review): very long conversations are silently
            # truncated to 2048 prompt tokens — confirm this limit matches
            # the model's context window.
            tokens = self.tokenizer(
                formatted, return_tensors="pt", truncation=True, max_length=2048
            ).to(self.device)

            # Generate without tracking gradients (inference only).
            with torch.no_grad():
                output_ids = self.model.generate(**tokens, **gen_kwargs)

            # Decode only the newly generated tokens, not the echoed prompt.
            new_tokens = output_ids[0][tokens["input_ids"].shape[1]:]
            raw_output = self.tokenizer.decode(new_tokens, skip_special_tokens=True)

            # Pull the final rewrite out of the tagged output.
            rewrite = self._extract_rewrite(raw_output)
            results.append({
                "rewrite": rewrite,
                "raw_output": raw_output,
            })

        return results