""" Custom handler for ConvSearch-R1 query rewriting on HuggingFace Inference Endpoints. Accepts conversation context + query and returns a rewritten query. """ import re import torch from typing import Any, Dict, List, Union from transformers import AutoModelForCausalLM, AutoTokenizer PROMPT_TEMPLATE = """Given a query and its context, you must first think about the reasoning process in the mind to decontextualize the query by resolving \ coreference and omission issues. Then, provide the user with a rewrite that retains its original meaning and is as informative as possible to help \ search engines retrieve relevant documents effectively. The reasoning process and rewrite should be enclosed within and \ tags, respectively, i.e., reasoning process here rewrite here . ### Context Begin ### {context} ### Context End ### Query: {query} Rewrite:""" class EndpointHandler: """Handler for ConvSearch-R1 query rewriting.""" def __init__(self, path: str = ""): """Initialize the model and tokenizer.""" self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True) self.model = AutoModelForCausalLM.from_pretrained( path, torch_dtype=torch.bfloat16, trust_remote_code=True, ) self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.model = self.model.to(self.device) self.model.eval() print(f"ConvSearch-R1 loaded on {self.device}") def _format_prompt(self, context: List[str], query: str) -> str: """Format conversation context and query into the model prompt.""" ctx_lines = [] for i in range(0, len(context), 2): turn = i // 2 + 1 ctx_lines.append(f"Q{turn}: {context[i]}") if i + 1 < len(context): ctx_lines.append(f"A{turn}: {context[i + 1]}") return PROMPT_TEMPLATE.format( context="\n".join(ctx_lines), query=query, ) def _extract_rewrite(self, output: str) -> str: """Extract rewrite from model output.""" match = re.search(r"(.*?)", output, re.DOTALL) if match: return match.group(1).strip() return output.strip() def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]: """ Process inference requests. Input format: { "inputs": [ {"context": ["Q1", "A1", "Q2", "A2"], "query": "current question"}, ... ], "parameters": {"temperature": 0.7, "max_new_tokens": 1024} } Or single input: { "inputs": {"context": [...], "query": "..."}, "parameters": {...} } Returns: [{"rewrite": "rewritten query", "raw_output": "full model output"}, ...] """ inputs = data.get("inputs", data) params = data.get("parameters", {}) temperature = params.get("temperature", 0.7) max_new_tokens = params.get("max_new_tokens", 4096) # Normalize to list if isinstance(inputs, dict): inputs = [inputs] results = [] for inp in inputs: context = inp.get("context", []) query = inp.get("query", "") # Format prompt prompt_text = self._format_prompt(context, query) messages = [{"role": "user", "content": prompt_text}] formatted = self.tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) # Tokenize tokens = self.tokenizer( formatted, return_tensors="pt", truncation=True, max_length=2048 ).to(self.device) # Generate with torch.no_grad(): output_ids = self.model.generate( **tokens, max_new_tokens=max_new_tokens, temperature=temperature, top_p=1.0, do_sample=temperature > 0, ) # Decode only new tokens new_tokens = output_ids[0][tokens["input_ids"].shape[1]:] raw_output = self.tokenizer.decode(new_tokens, skip_special_tokens=True) # Extract rewrite rewrite = self._extract_rewrite(raw_output) results.append({ "rewrite": rewrite, "raw_output": raw_output, }) return results