| | """ |
| | Custom handler for ConvSearch-R1 query rewriting on HuggingFace Inference Endpoints. |
| | |
| | Accepts conversation context + query and returns a rewritten query. |
| | """ |
| |
|
| | import re |
| | import torch |
| | from typing import Any, Dict, List, Union |
| | from transformers import AutoModelForCausalLM, AutoTokenizer |
| |
|
| |
|
# Prompt for the rewriter model. Trailing backslashes inside the triple-quoted
# string are line continuations: the first four source lines form one long
# instruction sentence in the rendered prompt. Placeholders:
#   {context} - numbered "Qn:"/"An:" conversation history lines
#   {query}   - the current (context-dependent) user question
# The model is expected to answer with <think>...</think> followed by
# <rewrite>...</rewrite>; only the <rewrite> span is returned to the caller.
PROMPT_TEMPLATE = """Given a query and its context, you must first think about the reasoning process in the mind to decontextualize the query by resolving \
coreference and omission issues. Then, provide the user with a rewrite that retains its original meaning and is as informative as possible to help \
search engines retrieve relevant documents effectively. The reasoning process and rewrite should be enclosed within <think> </think> and <rewrite> </rewrite> \
tags, respectively, i.e., <think> reasoning process here </think>
<rewrite> rewrite here </rewrite>.

### Context Begin ###
{context}
### Context End ###

Query: {query}
Rewrite:"""
| |
|
| |
|
class EndpointHandler:
    """Handler for ConvSearch-R1 conversational query rewriting.

    Loads a causal LM from a local path / hub id and, per request, rewrites a
    context-dependent query into a self-contained one by prompting the model
    and extracting the text between <rewrite> ... </rewrite> tags.
    """

    def __init__(self, path: str = ""):
        """Load the tokenizer and model from *path* and place the model on
        GPU when available (CPU otherwise)."""
        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)
        self.model.eval()
        print(f"ConvSearch-R1 loaded on {self.device}")

    def _format_prompt(self, context: List[str], query: str) -> str:
        """Render alternating [Q, A, Q, A, ...] *context* plus *query* into
        the model prompt.

        An odd-length context (trailing unanswered question) is supported:
        the final Q line is emitted without a matching A line.
        """
        ctx_lines: List[str] = []
        for i in range(0, len(context), 2):
            turn = i // 2 + 1
            ctx_lines.append(f"Q{turn}: {context[i]}")
            if i + 1 < len(context):
                ctx_lines.append(f"A{turn}: {context[i + 1]}")
        return PROMPT_TEMPLATE.format(
            context="\n".join(ctx_lines),
            query=query,
        )

    def _extract_rewrite(self, output: str) -> str:
        """Return the text inside the first <rewrite>...</rewrite> pair.

        Falls back to the whole stripped output when the model did not emit
        the expected tags, so callers always get a usable string.
        """
        match = re.search(r"<rewrite>(.*?)</rewrite>", output, re.DOTALL)
        if match:
            return match.group(1).strip()
        return output.strip()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process inference requests.

        Input format:
            {
                "inputs": [
                    {"context": ["Q1", "A1", "Q2", "A2"], "query": "current question"},
                    ...
                ],
                "parameters": {"temperature": 0.7, "max_new_tokens": 1024}
            }

        Or single input:
            {
                "inputs": {"context": [...], "query": "..."},
                "parameters": {...}
            }

        Returns:
            [{"rewrite": "rewritten query", "raw_output": "full model output"}, ...]
        """
        inputs = data.get("inputs", data)
        params = data.get("parameters", {})
        temperature = params.get("temperature", 0.7)
        max_new_tokens = params.get("max_new_tokens", 4096)

        # Normalize a single-input request to a one-element batch.
        if isinstance(inputs, dict):
            inputs = [inputs]

        # Build generation kwargs once per request. Sampling knobs must only
        # be passed when do_sample=True: transformers warns about unused
        # sampling args under greedy decoding, and temperature=0 with
        # do_sample=True raises. temperature <= 0 therefore means "greedy".
        do_sample = temperature > 0
        gen_kwargs: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
            # Some causal LMs ship without a pad token; fall back to EOS so
            # generate() does not warn / misattribute padding.
            "pad_token_id": (
                self.tokenizer.pad_token_id
                if self.tokenizer.pad_token_id is not None
                else self.tokenizer.eos_token_id
            ),
        }
        if do_sample:
            gen_kwargs["temperature"] = temperature
            gen_kwargs["top_p"] = 1.0

        results = []
        for inp in inputs:
            context = inp.get("context", [])
            query = inp.get("query", "")

            # Wrap the task prompt in the model's chat template.
            prompt_text = self._format_prompt(context, query)
            messages = [{"role": "user", "content": prompt_text}]
            formatted = self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

            # NOTE(review): right-side truncation at 2048 tokens can silently
            # drop the query, which sits at the END of the prompt. Consider
            # tokenizer.truncation_side = "left" for long conversations.
            tokens = self.tokenizer(
                formatted, return_tensors="pt", truncation=True, max_length=2048
            ).to(self.device)

            with torch.no_grad():
                output_ids = self.model.generate(**tokens, **gen_kwargs)

            # Decode only the newly generated continuation, not the prompt.
            new_tokens = output_ids[0][tokens["input_ids"].shape[1]:]
            raw_output = self.tokenizer.decode(new_tokens, skip_special_tokens=True)

            rewrite = self._extract_rewrite(raw_output)
            results.append({
                "rewrite": rewrite,
                "raw_output": raw_output,
            })

        return results
| |
|