File size: 4,629 Bytes
25e23ab 083dc97 25e23ab | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 | """
Custom handler for ConvSearch-R1 query rewriting on HuggingFace Inference Endpoints.
Accepts conversation context + query and returns a rewritten query.
"""
import re
import torch
from typing import Any, Dict, List, Union
from transformers import AutoModelForCausalLM, AutoTokenizer
# Instruction template sent to the model (filled by EndpointHandler._format_prompt).
# {context} receives the "Q1:/A1:/..." conversation history and {query} the current
# question. The model is expected to emit its reasoning in <think>...</think> and the
# final decontextualized query in <rewrite>...</rewrite>, which
# EndpointHandler._extract_rewrite parses out of the generation.
PROMPT_TEMPLATE = """Given a query and its context, you must first think about the reasoning process in the mind to decontextualize the query by resolving \
coreference and omission issues. Then, provide the user with a rewrite that retains its original meaning and is as informative as possible to help \
search engines retrieve relevant documents effectively. The reasoning process and rewrite should be enclosed within <think> </think> and <rewrite> </rewrite> \
tags, respectively, i.e., <think> reasoning process here </think>
<rewrite> rewrite here </rewrite>.
### Context Begin ###
{context}
### Context End ###
Query: {query}
Rewrite:"""
class EndpointHandler:
    """Handler for ConvSearch-R1 conversational query rewriting.

    Loads a causal LM once at endpoint start-up and, per request, formats the
    conversation history + current query into the ConvSearch-R1 prompt,
    generates, and extracts the <rewrite>...</rewrite> span from the output.
    """

    def __init__(self, path: str = ""):
        """Load tokenizer and model from ``path`` and move the model to GPU if available."""
        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        # Truncate over-long prompts from the LEFT so we drop the oldest context
        # turns rather than the current query and the trailing "Rewrite:" cue,
        # which the model needs to perform the task at all.
        self.tokenizer.truncation_side = "left"
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)
        self.model.eval()
        print(f"ConvSearch-R1 loaded on {self.device}")

    def _format_prompt(self, context: List[str], query: str) -> str:
        """Render alternating [Q1, A1, Q2, A2, ...] context plus ``query`` into the prompt.

        A trailing unanswered question (odd-length context) is emitted as a
        lone "Qn:" line.
        """
        ctx_lines: List[str] = []
        for i in range(0, len(context), 2):
            turn = i // 2 + 1
            ctx_lines.append(f"Q{turn}: {context[i]}")
            if i + 1 < len(context):  # the last question may not have an answer yet
                ctx_lines.append(f"A{turn}: {context[i + 1]}")
        return PROMPT_TEMPLATE.format(
            context="\n".join(ctx_lines),
            query=query,
        )

    def _extract_rewrite(self, output: str) -> str:
        """Return the text inside the first <rewrite>...</rewrite> pair, stripped.

        Falls back to the whole (stripped) output when the model failed to
        emit the tags, so callers always get a usable string.
        """
        match = re.search(r"<rewrite>(.*?)</rewrite>", output, re.DOTALL)
        if match:
            return match.group(1).strip()
        return output.strip()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process inference requests.
        Input format:
            {
                "inputs": [
                    {"context": ["Q1", "A1", "Q2", "A2"], "query": "current question"},
                    ...
                ],
                "parameters": {"temperature": 0.7, "max_new_tokens": 1024}
            }
        Or single input:
            {
                "inputs": {"context": [...], "query": "..."},
                "parameters": {...}
            }
        Returns:
            [{"rewrite": "rewritten query", "raw_output": "full model output"}, ...]
        """
        inputs = data.get("inputs", data)
        params = data.get("parameters", {})
        temperature = params.get("temperature", 0.7)
        max_new_tokens = params.get("max_new_tokens", 4096)
        # Normalize a single input dict to a one-element batch.
        if isinstance(inputs, dict):
            inputs = [inputs]

        # Build generation kwargs once per request. Passing temperature/top_p
        # alongside do_sample=False makes transformers emit "generation flags
        # not used" warnings, so sampling knobs are only included when sampling.
        gen_kwargs: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            # Explicit pad token avoids the per-call "Setting pad_token_id to
            # eos_token_id" warning on tokenizers without a pad token.
            "pad_token_id": (
                self.tokenizer.pad_token_id
                if self.tokenizer.pad_token_id is not None
                else self.tokenizer.eos_token_id
            ),
        }
        if temperature > 0:
            gen_kwargs.update(do_sample=True, temperature=temperature, top_p=1.0)
        else:
            gen_kwargs["do_sample"] = False  # greedy decoding

        results = []
        for inp in inputs:
            context = inp.get("context", [])
            query = inp.get("query", "")
            # Format prompt and wrap it in the model's chat template.
            prompt_text = self._format_prompt(context, query)
            messages = [{"role": "user", "content": prompt_text}]
            formatted = self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            # Tokenize (left-truncated to 2048 tokens; see __init__).
            tokens = self.tokenizer(
                formatted, return_tensors="pt", truncation=True, max_length=2048
            ).to(self.device)
            # Generate
            with torch.no_grad():
                output_ids = self.model.generate(**tokens, **gen_kwargs)
            # Decode only the newly generated tokens (skip the echoed prompt).
            new_tokens = output_ids[0][tokens["input_ids"].shape[1]:]
            raw_output = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
            # Extract rewrite
            rewrite = self._extract_rewrite(raw_output)
            results.append({
                "rewrite": rewrite,
                "raw_output": raw_output,
            })
        return results
|