# NOTE: upload metadata scraped from the HuggingFace web UI (uploader "pashaa",
# commit message "Upload handler.py with huggingface_hub", commit 083dc97 verified).
# Kept as a comment so the file remains valid Python.
"""
Custom handler for ConvSearch-R1 query rewriting on HuggingFace Inference Endpoints.
Accepts conversation context + query and returns a rewritten query.
"""
import re
import torch
from typing import Any, Dict, List, Union
from transformers import AutoModelForCausalLM, AutoTokenizer
# Instruction prompt for ConvSearch-R1: the model reasons inside <think>...</think>
# and emits the final decontextualized query inside <rewrite>...</rewrite>.
# `_extract_rewrite` parses the <rewrite> tags, so this template must stay in
# sync with that regex. Placeholders: {context} (Q/A turn lines), {query}.
PROMPT_TEMPLATE = """Given a query and its context, you must first think about the reasoning process in the mind to decontextualize the query by resolving \
coreference and omission issues. Then, provide the user with a rewrite that retains its original meaning and is as informative as possible to help \
search engines retrieve relevant documents effectively. The reasoning process and rewrite should be enclosed within <think> </think> and <rewrite> </rewrite> \
tags, respectively, i.e., <think> reasoning process here </think>
<rewrite> rewrite here </rewrite>.
### Context Begin ###
{context}
### Context End ###
Query: {query}
Rewrite:"""
class EndpointHandler:
    """HuggingFace Inference Endpoints handler for ConvSearch-R1 query rewriting.

    Takes a conversation context plus the current query, prompts the model to
    produce a self-contained ("decontextualized") rewrite, and returns the text
    extracted from the model's <rewrite>...</rewrite> tags.
    """

    def __init__(self, path: str = ""):
        """Load the model and tokenizer and place the model on GPU if available.

        Args:
            path: Local path or hub id of the ConvSearch-R1 checkpoint
                  (supplied by the Inference Endpoints runtime).
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        # Many causal-LM tokenizers ship without a pad token; generate() warns
        # on every call without one, so fall back to EOS.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)
        self.model.eval()
        print(f"ConvSearch-R1 loaded on {self.device}")

    def _format_prompt(self, context: List[str], query: str) -> str:
        """Render conversation *context* and *query* into the model prompt.

        Context entries are consumed pairwise: even indices become "Qn:" lines,
        odd indices become "An:" lines; a trailing unanswered question (odd-length
        context) is kept as a lone "Qn:" line.
        """
        ctx_lines = []
        for i in range(0, len(context), 2):
            turn = i // 2 + 1
            ctx_lines.append(f"Q{turn}: {context[i]}")
            if i + 1 < len(context):
                ctx_lines.append(f"A{turn}: {context[i + 1]}")
        return PROMPT_TEMPLATE.format(
            context="\n".join(ctx_lines),
            query=query,
        )

    def _extract_rewrite(self, output: str) -> str:
        """Return the text inside the first <rewrite>...</rewrite> pair.

        Falls back to the whole stripped output when the model emitted no
        (or malformed) rewrite tags, so callers always get usable text.
        """
        match = re.search(r"<rewrite>(.*?)</rewrite>", output, re.DOTALL)
        if match:
            return match.group(1).strip()
        return output.strip()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process inference requests.
        Input format:
            {
                "inputs": [
                    {"context": ["Q1", "A1", "Q2", "A2"], "query": "current question"},
                    ...
                ],
                "parameters": {"temperature": 0.7, "max_new_tokens": 1024}
            }
        Or single input:
            {
                "inputs": {"context": [...], "query": "..."},
                "parameters": {...}
            }
        Returns:
            [{"rewrite": "rewritten query", "raw_output": "full model output"}, ...]
        """
        inputs = data.get("inputs", data)
        params = data.get("parameters", {})
        temperature = params.get("temperature", 0.7)
        max_new_tokens = params.get("max_new_tokens", 4096)
        # Normalize a single request to the batched form.
        if isinstance(inputs, dict):
            inputs = [inputs]
        # temperature <= 0 means deterministic (greedy) decoding. The original
        # code still passed temperature=0 to generate(), which current
        # transformers versions reject ("temperature has to be strictly
        # positive") — only send sampling parameters when we actually sample.
        do_sample = temperature > 0
        gen_kwargs: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
            # Explicit pad_token_id silences the per-call generate() warning.
            "pad_token_id": self.tokenizer.pad_token_id,
        }
        if do_sample:
            gen_kwargs["temperature"] = temperature
            gen_kwargs["top_p"] = 1.0
        results = []
        for inp in inputs:
            context = inp.get("context", [])
            query = inp.get("query", "")
            # Format the prompt through the model's chat template.
            prompt_text = self._format_prompt(context, query)
            messages = [{"role": "user", "content": prompt_text}]
            formatted = self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            # Truncate overly long conversations so the prompt fits the window.
            tokens = self.tokenizer(
                formatted, return_tensors="pt", truncation=True, max_length=2048
            ).to(self.device)
            with torch.no_grad():
                output_ids = self.model.generate(**tokens, **gen_kwargs)
            # Decode only the newly generated tokens (skip the echoed prompt).
            new_tokens = output_ids[0][tokens["input_ids"].shape[1]:]
            raw_output = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
            results.append({
                "rewrite": self._extract_rewrite(raw_output),
                "raw_output": raw_output,
            })
        return results