# gensearcher-firered / space_gen.py
# JSCPPProgrammer's picture
# Keyless search: DuckDuckGo + direct HTTP browse; optional Serper/Jina
# 138b29f verified
"""
Run one official GenSearcher trajectory (OpenAI-compatible vLLM) then call FireRed /generate adapter.
"""
from __future__ import annotations
import asyncio
import base64
import json
import os
import uuid
from typing import Any, Dict, List, Optional, Tuple
import requests
from space_health import check_v1_models, is_localhost_url
from rllm.engine.agent_workflow_engine import AgentWorkflowEngine
from rllm.engine.rollout import OpenAIEngine
from vision_deepresearch_async_workflow.gen_image_deepresearch_tools_executor import (
create_gen_image_tools,
)
from vision_deepresearch_async_workflow.gen_image_deepresearch_workflow import (
GenImageDeepResearchWorkflow,
)
def _sanitize_content(msg: dict) -> dict:
    """Return a compact copy of one chat message that is safe to display.

    Only ``role``, ``content`` and (when the key is present) ``images`` are
    carried over; everything is reduced to bounded-length strings.

    Args:
        msg: Message dict. ``content`` may be a string, ``None`` or any other
            object; ``images`` may hold strings, dicts with an ``"image"``
            entry, or other objects.

    Returns:
        A new dict with ``role``, a ``content`` string capped at 50000
        characters (followed by ``"..."`` when it was cut) and, if ``msg`` has
        an ``images`` key, at most 10 image references of at most 200
        characters each.
    """
    max_content_chars = 50000
    max_images = 10
    max_image_ref_chars = 200

    content = msg.get("content", "")
    if content is None:
        # An explicit None would otherwise be rendered as the literal "None".
        text = ""
    elif isinstance(content, str):
        text = content
    else:
        text = str(content)
    # Same truncation marker regardless of the original content type.
    if len(text) > max_content_chars:
        text = text[:max_content_chars] + "..."

    out = {"role": msg.get("role", ""), "content": text}

    if "images" in msg:
        refs = []
        for item in (msg["images"] or [])[:max_images]:
            ref = item.get("image", "") if isinstance(item, dict) else item
            if ref is None:
                ref = ""
            elif not isinstance(ref, str):
                # Non-string payloads cannot be sliced as text; stringify them
                # instead of raising.
                ref = str(ref)
            refs.append(ref[:max_image_ref_chars])
        out["images"] = refs
    return out
def get_effective_prompt_and_images(
    user_prompt: str, prediction: dict
) -> Tuple[str, List[str]]:
    """Choose the prompt and reference images to hand to the image generator.

    Noted in the original as matching
    eval/gen_image_from_results.get_effective_prompt_and_images.

    Args:
        user_prompt: The prompt originally typed by the user.
        prediction: Agent prediction dict; ``gen_prompt`` and
            ``reference_images`` (a list of dicts with ``local_path``) are read.

    Returns:
        ``(gen_prompt, existing_paths)`` when the prediction carries a
        non-blank ``gen_prompt`` and at least one ``local_path`` that exists
        on disk; otherwise ``(user_prompt, [])``.
    """
    searched_prompt = (prediction.get("gen_prompt") or "").strip()

    # Lazily normalise the local_path of every dict-shaped reference entry.
    candidates = (
        (entry.get("local_path") or "").strip()
        for entry in (prediction.get("reference_images") or [])
        if isinstance(entry, dict)
    )
    existing = [path for path in candidates if path and os.path.exists(path)]

    # The searched prompt is only used together with usable reference images.
    if not (searched_prompt and existing):
        return user_prompt, []
    return searched_prompt, existing
def _parse_qwen_edit_base_url() -> str:
    """Resolve the image-generation base URL from ``QWEN_EDIT_APP_URL``.

    Accepted forms of the environment variable:

    * a plain URL (``http://host:8765/``),
    * a URL wrapped in single or double quotes,
    * a JSON string (``"http://host:8765"``),
    * a JSON list of URLs, of which the first entry is used.

    Returns:
        The URL without surrounding whitespace/quotes and without a trailing
        slash. ``http://127.0.0.1:8765`` is returned when the variable is
        unset, blank, or an empty JSON list.
    """
    default_url = "http://127.0.0.1:8765"
    raw = os.environ.get("QWEN_EDIT_APP_URL", default_url).strip()

    candidate = raw
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        parsed = None  # Not JSON: treat the raw value as the URL itself.
    if isinstance(parsed, list):
        candidate = str(parsed[0]) if parsed else ""
    elif isinstance(parsed, str):
        candidate = parsed

    # Quotes must come off before the trailing slash; otherwise a value such
    # as "http://host/" keeps its slash because it ends with a quote.
    candidate = candidate.strip().strip('"').strip("'").strip().rstrip("/")
    return candidate or default_url
def call_generate_api(
    base_url: str,
    path: str,
    prompt: str,
    image_paths: List[str],
    timeout: int = 1800,
) -> bytes:
    """POST one generation request to the image adapter and return the image.

    Sampling settings are read from the environment: ``GEN_SEED``,
    ``GEN_TRUE_CFG_SCALE``, ``GEN_NEGATIVE_PROMPT``,
    ``GEN_NUM_INFERENCE_STEPS`` and ``GEN_GUIDANCE_SCALE``.

    Args:
        base_url: Scheme/host/port of the adapter; a trailing slash is ignored.
        path: Endpoint path; a leading slash is added when missing.
        prompt: Text prompt sent as ``prompt``.
        image_paths: Local reference image files; only the first three are
            read and sent as base64 data URLs.
        timeout: Timeout in seconds passed to ``requests.post``.

    Returns:
        The decoded bytes of the ``image`` field of the JSON response.

    Raises:
        OSError: A reference image cannot be read.
        ValueError: A numeric ``GEN_*`` variable is malformed, or the returned
            image is not valid base64.
        requests.HTTPError: The adapter answered with an error status.
        RuntimeError: The response is not a JSON object, reports
            ``success`` as falsy, or carries no image data.
    """
    if not path.startswith("/"):
        path = "/" + path
    url = base_url.rstrip("/") + path

    ref_images_b64: List[str] = []
    for img_path in image_paths[:3]:
        with open(img_path, "rb") as f:
            ref_images_b64.append(base64.b64encode(f.read()).decode("utf-8"))
    # NOTE(review): every reference is labelled image/jpeg regardless of the
    # actual file type — confirm the adapter ignores the MIME label.
    image_urls = (
        [f"data:image/jpeg;base64,{b}" for b in ref_images_b64]
        if ref_images_b64
        else None
    )
    payload = {
        "image_urls": image_urls,
        "prompt": prompt,
        "seed": int(os.environ.get("GEN_SEED", "0")),
        "true_cfg_scale": float(os.environ.get("GEN_TRUE_CFG_SCALE", "4.0")),
        "negative_prompt": os.environ.get("GEN_NEGATIVE_PROMPT", " "),
        "num_inference_steps": int(os.environ.get("GEN_NUM_INFERENCE_STEPS", "40")),
        "guidance_scale": float(os.environ.get("GEN_GUIDANCE_SCALE", "1.0")),
        "num_images_per_prompt": 1,
    }

    r = requests.post(url, json=payload, timeout=timeout)
    r.raise_for_status()
    result = r.json()
    if not isinstance(result, dict):
        # Anything but a JSON object would fail below with an opaque
        # AttributeError; report it in the same way as other adapter errors.
        raise RuntimeError(f"Unexpected response from {url}: {str(result)[:200]}")
    if not result.get("success"):
        raise RuntimeError(result.get("message", str(result)))

    img_b64 = result.get("image") or ""
    if not img_b64:
        # Without this check an empty payload decodes to b"" and the caller
        # would treat zero bytes as a generated image.
        raise RuntimeError(f"Generation API at {url} reported success but returned no image data")
    if img_b64.startswith("data:image"):
        # Drop the "data:image/...;base64," prefix of a data URL.
        img_b64 = img_b64.split(",", 1)[-1]
    return base64.b64decode(img_b64)
async def run_gensearcher_then_generate(
    user_prompt: str,
    *,
    temperature: float = 0.6,
    top_p: float = 0.9,
    skip_generation: bool = False,
) -> Dict[str, Any]:
    """Run one GenSearcher agent trajectory and, optionally, generate an image.

    Configuration is read from the environment: ``GEN_EVAL_MODEL``,
    ``OPENAI_BASE_URL`` (``/v1`` is appended when missing), ``OPENAI_API_KEY``,
    ``MAX_PROMPT_LENGTH``, ``MAX_RESPONSE_LENGTH``, ``QWEN_EDIT_APP_URL``,
    ``QWEN_EDIT_APP_PATH`` and ``GEN_IMAGE_TIMEOUT``.

    Args:
        user_prompt: The user's request; used as the task question/prompt.
        temperature: Sampling temperature passed to the rollout engine.
        top_p: Nucleus-sampling value passed to the rollout engine.
        skip_generation: When true, return after the trajectory without
            calling the image-generation endpoint.

    Returns:
        A dict with ``termination``, ``trajectory_messages``, ``gen_prompt``,
        ``prediction``, ``image_png`` (bytes returned by the generation
        endpoint, or ``None``) and ``image_error`` (message or ``None``).
        ``used_prompt`` and ``reference_paths`` are added when generation
        succeeded.

    Raises:
        RuntimeError: The LLM endpoint health check failed.
    """
    sample_id = str(uuid.uuid4())[:8]
    task = {
        "id": sample_id,
        "question": user_prompt,
        "prompt": user_prompt,
        "meta": {"source": "hf_space"},
    }

    model = os.environ.get("GEN_EVAL_MODEL", "Gen-Searcher-8B")
    base_url = os.environ.get("OPENAI_BASE_URL", "http://127.0.0.1:8002/v1").strip().rstrip("/")
    if not base_url.endswith("/v1"):
        base_url = base_url + "/v1"
    api_key = os.environ.get("OPENAI_API_KEY", "EMPTY")

    # Fail early with an actionable message instead of a late rollout error.
    ok_llm, llm_msg = check_v1_models(base_url, api_key)
    if not ok_llm:
        hint = ""
        if is_localhost_url(base_url):
            hint = (
                " You are targeting localhost inside the Space container. Nothing is listening unless you set "
                "Space variable START_VLLM_GENSEARCHER=1 (and GPU) or change OPENAI_BASE_URL to a reachable "
                "OpenAI-compatible server (your vLLM / TGI URL ending in /v1)."
            )
        raise RuntimeError(
            f"GenSearcher LLM is not reachable at {base_url}/models — {llm_msg}.{hint}"
        )

    rollout_engine = OpenAIEngine(
        model=model,
        base_url=base_url,
        api_key=api_key,
        max_prompt_length=int(os.environ.get("MAX_PROMPT_LENGTH", "64000")),
        max_response_length=int(os.environ.get("MAX_RESPONSE_LENGTH", "64000")),
        sampling_params={
            "temperature": temperature,
            "top_p": top_p,
        },
    )
    tools = create_gen_image_tools()
    workflow_engine = AgentWorkflowEngine(
        workflow_cls=GenImageDeepResearchWorkflow,
        workflow_args={"tools": tools, "reward_function": None},
        rollout_engine=rollout_engine,
        n_parallel_tasks=1,
        retry_limit=2,
    )
    await workflow_engine.initialize_pool()
    try:
        _, _, episode = await workflow_engine.process_task_with_retry(
            task, sample_id, 0
        )
    finally:
        # Shut the engine down whether or not the task succeeded.
        workflow_engine.shutdown()

    info = episode.info or {}
    messages = info.get("messages", [])
    prediction = info.get("prediction", {}) if isinstance(info.get("prediction"), dict) else {}
    # Prefer the termination recorded in info; fall back to the episode's own
    # termination_reason, then to "unknown".
    termination = info.get("termination") or (
        episode.termination_reason.value
        if getattr(episode, "termination_reason", None)
        else "unknown"
    )
    trajectory = [_sanitize_content(m) for m in messages]
    out: Dict[str, Any] = {
        "termination": termination,
        "trajectory_messages": trajectory,
        "gen_prompt": prediction.get("gen_prompt", ""),
        "prediction": prediction,
    }

    if skip_generation:
        out["image_png"] = None
        out["image_error"] = None
        return out
    if termination != "answer":
        out["image_png"] = None
        out["image_error"] = f"Agent did not finish with answer (termination={termination})"
        return out

    eff_prompt, img_paths = get_effective_prompt_and_images(user_prompt, prediction)
    gen_base = _parse_qwen_edit_base_url()
    gen_path = os.environ.get("QWEN_EDIT_APP_PATH", "/generate")
    timeout = int(os.environ.get("GEN_IMAGE_TIMEOUT", "1800"))
    try:
        # call_generate_api performs a blocking HTTP request that may take up
        # to `timeout` seconds; run it in a worker thread so the event loop
        # stays responsive while waiting.
        png_bytes = await asyncio.get_running_loop().run_in_executor(
            None,
            lambda: call_generate_api(
                gen_base, gen_path, eff_prompt, img_paths, timeout=timeout
            ),
        )
        out["image_png"] = png_bytes
        out["used_prompt"] = eff_prompt
        out["reference_paths"] = img_paths
        out["image_error"] = None
    except Exception as e:
        # Best effort: a generation failure is reported in the result rather
        # than discarding the finished trajectory.
        out["image_png"] = None
        # Some exceptions stringify to ""; fall back to repr so the error
        # field is never empty on failure.
        out["image_error"] = str(e) or repr(e)
    return out
def run_sync(
    user_prompt: str,
    *,
    temperature: float = 0.6,
    top_p: float = 0.9,
    skip_generation: bool = False,
) -> Dict[str, Any]:
    """Blocking wrapper around ``run_gensearcher_then_generate``.

    Builds the coroutine with the given arguments and drives it to completion
    with ``asyncio.run``.

    Args:
        user_prompt: Prompt forwarded unchanged to the coroutine.
        temperature: Forwarded as ``temperature``.
        top_p: Forwarded as ``top_p``.
        skip_generation: Forwarded as ``skip_generation``.

    Returns:
        The result dict produced by ``run_gensearcher_then_generate``.
    """
    options = {
        "temperature": temperature,
        "top_p": top_p,
        "skip_generation": skip_generation,
    }
    pending = run_gensearcher_then_generate(user_prompt, **options)
    return asyncio.run(pending)