infoseeker-4b / inference /zerosearch.py
shreyess's picture
Upload folder using huggingface_hub
85d096e verified
# zero_search_inference.py
"""End‑to‑end inference loop that emulates the ZeroSearch prompting style.
The policy model ("thinker") must:
• reason inside <think> … </think>
• place a query inside <search> … </search> whenever it needs external knowledge
• return the final short answer inside <answer> … </answer>
The wrapper intercepts each <search> request, fulfils it with either:
(a) a **simulated search engine** (another small LLM fine‑tuned as ZeroSearch
retriever) ‑‑ default; or
(b) a real search backend (e.g. Serper.dev, Bing) if `engine="real"`.
It then injects results between <information> … </information> and hands control
back to the policy model. The loop repeats until </answer> is produced or a
maximum number of retrieval rounds is reached.
The goal is to mirror the ergonomics of the user’s existing `ReCall` class so
that outer orchestration code can drop this in with minimal friction.
"""
from __future__ import annotations

import functools
import json
import os
import re
import time
from dataclasses import dataclass, field
from typing import List, Optional

import requests
from openai import OpenAI
# Public API of this module.
__all__ = ["ZeroSearchInference", "ZeroSearchConfig"]
# Directory holding the locally-downloaded Qwen3-4B tokenizer files.
TOKENIZER_DIR = "/home/fractal_admin/shreyas/models/Qwen3-4B"
# ───────────────────────── tokenizer ────────────────────────────────────────
try:
    from transformers import AutoTokenizer

    # Loaded once at import time; used by _call_thinker to count prompt tokens
    # so the remaining generation budget can be computed.
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR, trust_remote_code=True)
except Exception as e:
    # The whole module is unusable without the tokenizer, so abort hard with a
    # readable message rather than failing later inside the inference loop.
    import sys
    sys.exit(f"❌ Could not load Qwen3 tokenizer: {e}")
# ---------------------------------------------------------------------------
# Utility: retry decorator ---------------------------------------------------
# ---------------------------------------------------------------------------
def retry(max_attempts: int = 4, sleep: int = 1, fallback=None):
    """Decorator factory: call the wrapped function up to ``max_attempts`` times.

    Parameters
    ----------
    max_attempts:
        Number of calls to attempt before giving up.
    sleep:
        Seconds to wait between consecutive attempts.
    fallback:
        Value returned when every attempt raised (best-effort semantics:
        exceptions are deliberately swallowed, never propagated).
    """
    def decorator(func):
        @functools.wraps(func)  # keep __name__/__doc__ of the wrapped function
        def wrapper(*args, **kwargs):
            for attempt in range(max_attempts):
                try:
                    return func(*args, **kwargs)
                except Exception:
                    # Last attempt failed: degrade gracefully to the fallback.
                    if attempt == max_attempts - 1:
                        return fallback
                    time.sleep(sleep)
        return wrapper
    return decorator
# ---------------------------------------------------------------------------
# Configuration dataclass ----------------------------------------------------
# ---------------------------------------------------------------------------
@dataclass
class ZeroSearchConfig:
    """Configuration for a :class:`ZeroSearchInference` run.

    Groups the policy-model ("thinker") endpoint settings, the retrieval
    engine selection ("sim" LLM retriever vs. "real" Serper web search)
    and the loop limits.
    """

    # thinker LLM endpoint (SGLang-style server exposing POST /generate)
    thinker_url: str = "http://0.0.0.0:1214"
    thinker_temperature: float = 0.7
    # Total token budget (prompt + generation) for the thinker model.
    thinker_max_tokens: int = 40960
    # retrieval engine mode: "sim" (LLM-simulated search) or "real" (Serper)
    engine: str = "real"
    # simulated search model (only used if engine == "sim")
    retriever_model: str = "gpt-4o-mini"
    retriever_top_k: int = 5
    # Real search backend (engine == "real").
    # SECURITY FIX: a live API key used to be hard-coded here (a leaked
    # secret); it is now read from the SERPER_API_KEY environment variable
    # at instantiation time. Pass serper_api_key=... to override explicitly.
    serper_api_key: Optional[str] = field(
        default_factory=lambda: os.environ.get("SERPER_API_KEY")
    )
    serper_url: str = "https://google.serper.dev/search"
    serper_top_k: int = 5
    # Maximum think/search rounds before giving up with "I don't know".
    max_rounds: int = 16
# ---------------------------------------------------------------------------
# Main wrapper ---------------------------------------------------------------
# ---------------------------------------------------------------------------
class ZeroSearchInference:
    """Think → search → answer loop around a ZeroSearch-style policy model.

    Each round the thinker generates until it either closes an answer
    (``</answer>``) or requests retrieval (a ``</search>`` stop sequence).
    Search requests are fulfilled by a simulated LLM retriever or the real
    Serper backend, injected back between ``<information>`` tags, and
    generation resumes, up to ``cfg.max_rounds`` rounds.
    """

    # Tag vocabulary of the ZeroSearch prompting format.
    SEARCH_OPEN = "<search>"
    SEARCH_CLOSE = "</search>"
    INFO_OPEN = "<information>"
    INFO_CLOSE = "</information>"
    ANSWER_CLOSE = "</answer>"
    THINK_OPEN = "<think>"
    THINK_CLOSE = "</think>"
    # Stop on end-of-turn tokens or any whitespace variant of </search> so the
    # wrapper can intercept the query before the model invents its own results.
    STOP_TOKENS = ["<|im_end|>", "<|endoftext|>", "</search>", " </search>", "</search>\n", " </search>\n", "</search>\n\n", " </search>\n\n"]

    def __init__(self, cfg: ZeroSearchConfig):
        self.cfg = cfg
        # OpenAI client for the simulated retriever, created on first use so
        # that "real" mode never requires OPENAI_API_KEY. (BUG FIX: the
        # original read ``self.openai`` without ever assigning it, which
        # raised AttributeError whenever engine == "sim".)
        self._openai_client = None

    @property
    def openai(self) -> OpenAI:
        """Lazily construct and cache the OpenAI client used in "sim" mode."""
        if self._openai_client is None:
            self._openai_client = OpenAI()
        return self._openai_client

    # ------------------------------------------------------------------
    # Public driver -----------------------------------------------------
    # ------------------------------------------------------------------
    def run(self, user_question: str) -> tuple[str, List[str]]:
        """Run the full retrieval loop for one question.

        Returns
        -------
        (prompt, tool_calls):
            ``prompt`` is the complete transcript (initial prompt plus all
            generations and injected information blocks); ``tool_calls`` is
            the list of search queries issued, in order. (The original
            annotation said ``-> str`` but a 2-tuple was always returned.)
        """
        tool_calls: List[str] = []
        prompt = self._build_initial_prompt(user_question)
        for _round in range(self.cfg.max_rounds):
            generated = self._call_thinker(prompt)
            prompt += generated
            if self.ANSWER_CLOSE in generated:
                break  # final answer emitted
            query = self._extract_query(generated)
            if not query:
                break  # neither answer nor search request: give up
            tool_calls.append(query)
            # Inject retrieved docs and hand control back with a fresh <think>.
            prompt += self._retrieve_and_format(query) + self.THINK_OPEN
        else:  # exhausted max_rounds without an answer
            prompt += "<answer>I don't know.</answer><|im_end|>"
        return prompt, tool_calls

    # ------------------------------------------------------------------
    # Prompt construction helpers --------------------------------------
    # ------------------------------------------------------------------
    def _build_initial_prompt(self, question: str) -> str:
        """Build the ChatML prompt with the ZeroSearch instruction block.

        The instruction text is kept verbatim — including its original
        phrasing ("as your want") — because the policy model was presumably
        trained against this exact prompt; do not "fix" it.
        """
        user_msg = f"""Answer the given question. \
You must conduct reasoning inside <think> and </think> first every time you get new information. \
After reasoning, if you find you lack some knowledge, you can call a search engine by <search> query </search> and it will return the top searched results between <information> and </information>. \
You can search as many times as your want. \
If you find no further external knowledge needed, you can directly provide the answer inside <answer> and </answer>, without detailed illustrations. For example, <answer> Beijing </answer>. Question: {question}\n"""
        return f"<|im_start|>user\n{user_msg}<|im_end|>\n<|im_start|>assistant\n{self.THINK_OPEN}"

    # ------------------------------------------------------------------
    # Thinker model call ------------------------------------------------
    # ------------------------------------------------------------------
    @retry(fallback="")
    def _call_thinker(self, prompt: str) -> str:
        """One generation call to the policy model's ``/generate`` endpoint.

        Returns the generated text with the matched stop sequence re-appended
        (servers typically strip it from the output). After repeated failures
        the ``retry`` decorator yields "".
        """
        prompt_tokens = tokenizer(prompt, return_tensors=None, add_special_tokens=False)["input_ids"]
        # Leave a 100-token safety margin; never request a non-positive budget
        # (the original could send a negative max_new_tokens on long prompts).
        max_tokens_left = max(self.cfg.thinker_max_tokens - len(prompt_tokens) - 100, 1)
        resp = requests.post(
            f"{self.cfg.thinker_url}/generate",
            json={
                "text": prompt,
                "sampling_params": {
                    "temperature": self.cfg.thinker_temperature,
                    "max_new_tokens": max_tokens_left,
                    "stop": self.STOP_TOKENS,
                },
            },
            timeout=60,
        ).json()
        generated = resp["text"]
        finish = resp["meta_info"]["finish_reason"]
        matched = finish.get("matched")
        reason = finish.get("type")
        # Re-append the stop string the server consumed, if it was removed.
        if reason == "stop" and matched in self.STOP_TOKENS and not generated.endswith(matched):
            generated += matched
        # ``matched`` may also be the <|im_end|> token id instead of a string.
        if reason == "stop" and matched == 151645 and not generated.endswith("<|im_end|>"):
            generated += "<|im_end|>"
        return generated

    # ------------------------------------------------------------------
    # Query extraction --------------------------------------------------
    # ------------------------------------------------------------------
    def _extract_query(self, gen_text: str) -> Optional[str]:
        """Return the text of the LAST <search>…</search> block, or None."""
        if self.SEARCH_OPEN not in gen_text or self.SEARCH_CLOSE not in gen_text:
            return None
        query = gen_text.split(self.SEARCH_OPEN)[-1].split(self.SEARCH_CLOSE)[0].strip()
        return query or None

    # ------------------------------------------------------------------
    # Retrieval path ----------------------------------------------------
    # ------------------------------------------------------------------
    def _retrieve_and_format(self, query: str) -> str:
        """Fulfil one search query and wrap the docs in <information> tags."""
        if self.cfg.engine == "real":
            docs = self._real_search(query)
        else:
            docs = self._simulated_search(query)
        return f"{self.INFO_OPEN}\n{docs}\n{self.INFO_CLOSE}\n\n"

    # --- simulated search with LLM ------------------------------------
    @retry(fallback="No information available")
    def _simulated_search(self, query: str) -> str:
        """Ask the retriever LLM to play search engine for ``query``."""
        messages = [
            {
                "role": "user",
                "content": (
                    "You are a search engine. Return up to "
                    f"{self.cfg.retriever_top_k} short documents (titles + snippets) "
                    "most relevant to the query, each on a new line.\n\n"
                    f"Query: {query}"
                ),
            }
        ]
        resp = self.openai.chat.completions.create(
            model=self.cfg.retriever_model,
            messages=messages,
            max_tokens=256,
        )
        return resp.choices[0].message.content.strip()

    # --- real web search via Serper ----------------------------------
    @retry(fallback="No information available")
    def _real_search(self, query: str) -> str:
        """Query the Serper.dev search API and format the organic results."""
        if not self.cfg.serper_api_key:
            raise ValueError("serper_api_key must be set for real search mode")
        headers = {"X-API-KEY": self.cfg.serper_api_key, "Content-Type": "application/json"}
        payload = {"q": query, "num": self.cfg.serper_top_k}
        resp = requests.post(self.cfg.serper_url, json=payload, headers=headers, timeout=20)
        resp.raise_for_status()
        data = resp.json().get("organic", [])[: self.cfg.serper_top_k]
        lines = []
        for i, item in enumerate(data, 1):
            # .get(): some organic results lack a snippet (the original
            # KeyError'd); also fixes the missing ": " after "Snippet".
            snippet = f"Title: {item.get('title', '')}, \nSnippet: {item.get('snippet', '')}"
            lines.append(f"Doc {i}: {snippet}")
        return "\n".join(lines) or "No information available"