"""End-to-end inference loop that emulates the ZeroSearch prompting style.

The policy model ("thinker") must:

  • reason inside <think> ... </think>
  • place a query inside <search> ... </search> whenever it needs external knowledge
  • return the final short answer inside <answer> ... </answer>

The wrapper intercepts each <search> request and fulfils it with either:

  (a) a **simulated search engine** (another small LLM fine-tuned as the
      ZeroSearch retriever) if `engine="simulated"`; or
  (b) a real search backend (Serper.dev) if `engine="real"` -- the default
      in `ZeroSearchConfig`.

It then injects the results between <information> ... </information> and hands
control back to the policy model. The loop repeats until </answer> is produced
or the maximum number of retrieval rounds is reached.

The goal is to mirror the ergonomics of the user's existing `ReCall` class so
that outer orchestration code can drop this in with minimal friction.
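
Example (illustrative sketch; assumes a generation server is already running
at ``thinker_url``):

    cfg = ZeroSearchConfig(engine="real")
    runner = ZeroSearchInference(cfg)
    trajectory, queries = runner.run("Who wrote The Master and Margarita?")
    print(trajectory)   # full <think>/<search>/<answer> transcript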
| """ |
from __future__ import annotations

import functools
import os
import time
from dataclasses import dataclass
from typing import List, Optional

import requests
from openai import OpenAI

__all__ = ["ZeroSearchInference", "ZeroSearchConfig"]

TOKENIZER_DIR = "/home/fractal_admin/shreyas/models/Qwen3-4B"

try:
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR, trust_remote_code=True)
except Exception as e:
    import sys

    sys.exit(f"❌ Could not load Qwen3 tokenizer: {e}")


def retry(max_attempts: int = 4, sleep: int = 1, fallback=None):
    """Retry the wrapped call up to `max_attempts` times; return `fallback` if all fail."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(max_attempts):
                try:
                    return func(*args, **kwargs)
                except Exception:
                    if attempt == max_attempts - 1:
                        return fallback
                    time.sleep(sleep)

        return wrapper

    return decorator
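
# Minimal usage sketch for `retry` (the names below are illustrative only):
#
#     @retry(max_attempts=2, sleep=0, fallback="No information available")
#     def flaky_fetch():
#         raise RuntimeError("network down")
#
#     flaky_fetch()  # -> "No information available" instead of raising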


@dataclass
class ZeroSearchConfig:
    # --- policy model ("thinker") ---
    thinker_url: str = "http://0.0.0.0:1214"
    thinker_temperature: float = 0.7
    thinker_max_tokens: int = 40960

    # --- retrieval backend: "real" (Serper.dev) or "simulated" (LLM retriever) ---
    engine: str = "real"

    # --- simulated search (LLM-as-retriever) ---
    retriever_model: str = "gpt-4o-mini"
    retriever_top_k: int = 5

    # --- real search (Serper.dev) ---
    # Read the key from the environment rather than hard-coding a secret.
    serper_api_key: Optional[str] = os.environ.get("SERPER_API_KEY")
    serper_url: str = "https://google.serper.dev/search"
    serper_top_k: int = 5

    # --- loop control: maximum think->search->information rounds ---
    max_rounds: int = 16
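
# Illustrative override: a fully offline run against the simulated retriever
# with a tighter budget (assumes OPENAI_API_KEY is set for the retriever client):
#
#     cfg = ZeroSearchConfig(engine="simulated", max_rounds=8)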


class ZeroSearchInference:
    SEARCH_OPEN = "<search>"
    SEARCH_CLOSE = "</search>"
    INFO_OPEN = "<information>"
    INFO_CLOSE = "</information>"

    ANSWER_CLOSE = "</answer>"
    THINK_OPEN = "<think>"
    THINK_CLOSE = "</think>"

    # Generation stops at end-of-turn or at any spacing/newline variant of
    # </search>, so the wrapper can intercept the query before continuing.
    STOP_TOKENS = [
        "<|im_end|>",
        "<|endoftext|>",
        "</search>",
        " </search>",
        "</search>\n",
        " </search>\n",
        "</search>\n\n",
        " </search>\n\n",
    ]

    def __init__(self, cfg: ZeroSearchConfig):
        self.cfg = cfg
        # Client for the LLM-as-retriever path (reads OPENAI_API_KEY from the
        # environment). Created lazily in `_simulated_search` when first used,
        # so real-search runs never require an OpenAI key.
        self.openai: Optional[OpenAI] = None

    def run(self, user_question: str) -> tuple[str, List[str]]:
        """Drive the think/search/answer loop; return the trajectory and queries."""
        tool_calls: List[str] = []
        prompt = self._build_initial_prompt(user_question)
        for _ in range(self.cfg.max_rounds):
            generated = self._call_thinker(prompt)
            prompt += generated

            # The model produced a final answer: stop.
            if self.ANSWER_CLOSE in generated:
                break

            # Otherwise expect a <search> query; if none, the model stalled.
            query = self._extract_query(generated)
            if not query:
                break
            tool_calls.append(query)
            info_block = self._retrieve_and_format(query)
            # Inject the results and re-open a <think> block for the next round.
            prompt += info_block + self.THINK_OPEN
        else:
            # for/else: reached only when max_rounds is exhausted without a break.
            prompt += "<answer>I don't know.</answer><|im_end|>"
        return prompt, tool_calls

    def _build_initial_prompt(self, question: str) -> str:
        user_msg = f"""Answer the given question. \
You must conduct reasoning inside <think> and </think> first every time you get new information. \
After reasoning, if you find you lack some knowledge, you can call a search engine by <search> query </search> and it will return the top searched results between <information> and </information>. \
You can search as many times as you want. \
If you find no further external knowledge needed, you can directly provide the answer inside <answer> and </answer>, without detailed illustrations. For example, <answer> Beijing </answer>. Question: {question}\n"""
        return f"<|im_start|>user\n{user_msg}<|im_end|>\n<|im_start|>assistant\n{self.THINK_OPEN}"
| @retry(fallback="") |
| def _call_thinker(self, prompt: str) -> str: |
| prompt_tokens = tokenizer(prompt, return_tensors=None, add_special_tokens=False)["input_ids"] |
| max_tokens_left = self.cfg.thinker_max_tokens - len(prompt_tokens) - 100 |
| resp = requests.post( |
| f"{self.cfg.thinker_url}/generate", |
| json={ |
| "text": prompt, |
| "sampling_params": { |
| "temperature": self.cfg.thinker_temperature, |
| "max_new_tokens": max_tokens_left, |
| "stop": self.STOP_TOKENS, |
| }, |
| |
| }, |
| timeout=60, |
| ).json() |
| generated = resp["text"] |
| matched = resp["meta_info"]["finish_reason"].get("matched") |
| reason = resp["meta_info"]["finish_reason"].get("type") |
| |
| if reason == "stop" and matched in self.STOP_TOKENS: |
| if not generated.endswith(matched): |
| generated += matched |
| if reason == "stop" and matched == 151645: |
| if not generated.endswith("<|im_end|>"): |
| generated += "<|im_end|>" |
| return generated |
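
    # Expected response shape from the generation server (schematic; this is
    # what the field accesses above assume):
    #
    #   {"text": "...generated continuation...",
    #    "meta_info": {"finish_reason": {"type": "stop", "matched": "</search>"}}}
    #
    # `matched` may also be a token id (e.g. 151645 for <|im_end|>).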
    def _extract_query(self, gen_text: str) -> Optional[str]:
        # Take the *last* <search> ... </search> block in the generation.
        if self.SEARCH_OPEN not in gen_text or self.SEARCH_CLOSE not in gen_text:
            return None
        query = gen_text.split(self.SEARCH_OPEN)[-1].split(self.SEARCH_CLOSE)[0].strip()
        return query or None
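
    # e.g. _extract_query("<think>need a fact</think><search> capital of France </search>")
    #      -> "capital of France"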
    def _retrieve_and_format(self, query: str) -> str:
        if self.cfg.engine == "real":
            docs = self._real_search(query)
        else:
            docs = self._simulated_search(query)
        return f"{self.INFO_OPEN}\n{docs}\n{self.INFO_CLOSE}\n\n"

| @retry(fallback="No information available") |
| def _simulated_search(self, query: str) -> str: |
| messages = [ |
| { |
| "role": "user", |
| "content": ( |
| "You are a search engine. Return up to " |
| f"{self.cfg.retriever_top_k} short documents (titles + snippets) " |
| "most relevant to the query, each on a new line.\n\n" |
| f"Query: {query}" |
| ), |
| } |
| ] |
| resp = self.openai.chat.completions.create( |
| model=self.cfg.retriever_model, |
| messages=messages, |
| max_tokens=256, |
| ) |
| return resp.choices[0].message.content.strip() |
| @retry(fallback="No information available") |
| def _real_search(self, query: str) -> str: |
| if not self.cfg.serper_api_key: |
| raise ValueError("serper_api_key must be set for real search mode") |
| headers = {"X-API-KEY": self.cfg.serper_api_key, "Content-Type": "application/json"} |
| payload = {"q": query, "num": self.cfg.serper_top_k} |
| resp = requests.post(self.cfg.serper_url, json=payload, headers=headers, timeout=20) |
| resp.raise_for_status() |
| data = resp.json().get("organic", [])[: self.cfg.serper_top_k] |
| lines = [] |
| for i, item in enumerate(data, 1): |
| snippet = f"Title: {item['title']}, \nSnippet{item['snippet']}" |
| lines.append(f"Doc {i}: {snippet}") |
| return "\n".join(lines) or "No information available" |