""" Run one official GenSearcher trajectory (OpenAI-compatible vLLM) then call FireRed /generate adapter. """ from __future__ import annotations import asyncio import base64 import json import os import uuid from typing import Any, Dict, List, Optional, Tuple import requests from space_health import check_v1_models, is_localhost_url from rllm.engine.agent_workflow_engine import AgentWorkflowEngine from rllm.engine.rollout import OpenAIEngine from vision_deepresearch_async_workflow.gen_image_deepresearch_tools_executor import ( create_gen_image_tools, ) from vision_deepresearch_async_workflow.gen_image_deepresearch_workflow import ( GenImageDeepResearchWorkflow, ) def _sanitize_content(msg: dict) -> dict: out = {"role": msg.get("role", ""), "content": ""} content = msg.get("content", "") if isinstance(content, str): out["content"] = content[:50000] + ("..." if len(content) > 50000 else "") else: out["content"] = str(content)[:50000] if "images" in msg: out["images"] = [ (p if isinstance(p, str) else p.get("image", ""))[:200] for p in (msg["images"] or [])[:10] ] return out def get_effective_prompt_and_images( user_prompt: str, prediction: dict ) -> Tuple[str, List[str]]: """Match eval/gen_image_from_results.get_effective_prompt_and_images.""" gen_prompt = (prediction.get("gen_prompt") or "").strip() paths: List[str] = [] for r in prediction.get("reference_images") or []: if not isinstance(r, dict): continue p = (r.get("local_path") or "").strip() if p and os.path.exists(p): paths.append(p) if gen_prompt and paths: return gen_prompt, paths return user_prompt, [] def _parse_qwen_edit_base_url() -> str: raw = os.environ.get("QWEN_EDIT_APP_URL", "http://127.0.0.1:8765").strip() try: urls = json.loads(raw) if isinstance(urls, list) and urls: return str(urls[0]).rstrip("/") except json.JSONDecodeError: pass return raw.rstrip("/").strip('"').strip("'") def call_generate_api( base_url: str, path: str, prompt: str, image_paths: List[str], timeout: int = 1800, ) -> bytes: if not path.startswith("/"): path = "/" + path url = base_url.rstrip("/") + path ref_images_b64: List[str] = [] for img_path in image_paths[:3]: with open(img_path, "rb") as f: ref_images_b64.append(base64.b64encode(f.read()).decode("utf-8")) image_urls = ( [f"data:image/jpeg;base64,{b}" for b in ref_images_b64] if ref_images_b64 else None ) payload = { "image_urls": image_urls, "prompt": prompt, "seed": int(os.environ.get("GEN_SEED", "0")), "true_cfg_scale": float(os.environ.get("GEN_TRUE_CFG_SCALE", "4.0")), "negative_prompt": os.environ.get("GEN_NEGATIVE_PROMPT", " "), "num_inference_steps": int(os.environ.get("GEN_NUM_INFERENCE_STEPS", "40")), "guidance_scale": float(os.environ.get("GEN_GUIDANCE_SCALE", "1.0")), "num_images_per_prompt": 1, } r = requests.post(url, json=payload, timeout=timeout) r.raise_for_status() result = r.json() if not result.get("success"): raise RuntimeError(result.get("message", str(result))) img_b64 = result.get("image") or "" if img_b64.startswith("data:image"): img_b64 = img_b64.split(",", 1)[-1] return base64.b64decode(img_b64) async def run_gensearcher_then_generate( user_prompt: str, *, temperature: float = 0.6, top_p: float = 0.9, skip_generation: bool = False, ) -> Dict[str, Any]: sample_id = str(uuid.uuid4())[:8] task = { "id": sample_id, "question": user_prompt, "prompt": user_prompt, "meta": {"source": "hf_space"}, } model = os.environ.get("GEN_EVAL_MODEL", "Gen-Searcher-8B") base_url = os.environ.get("OPENAI_BASE_URL", "http://127.0.0.1:8002/v1").strip().rstrip("/") if not base_url.endswith("/v1"): base_url = base_url + "/v1" api_key = os.environ.get("OPENAI_API_KEY", "EMPTY") ok_llm, llm_msg = check_v1_models(base_url, api_key) if not ok_llm: hint = "" if is_localhost_url(base_url): hint = ( " You are targeting localhost inside the Space container. Nothing is listening unless you set " "Space variable START_VLLM_GENSEARCHER=1 (and GPU) or change OPENAI_BASE_URL to a reachable " "OpenAI-compatible server (your vLLM / TGI URL ending in /v1)." ) raise RuntimeError( f"GenSearcher LLM is not reachable at {base_url}/models — {llm_msg}.{hint}" ) rollout_engine = OpenAIEngine( model=model, base_url=base_url, api_key=api_key, max_prompt_length=int(os.environ.get("MAX_PROMPT_LENGTH", "64000")), max_response_length=int(os.environ.get("MAX_RESPONSE_LENGTH", "64000")), sampling_params={ "temperature": temperature, "top_p": top_p, }, ) tools = create_gen_image_tools() workflow_engine = AgentWorkflowEngine( workflow_cls=GenImageDeepResearchWorkflow, workflow_args={"tools": tools, "reward_function": None}, rollout_engine=rollout_engine, n_parallel_tasks=1, retry_limit=2, ) await workflow_engine.initialize_pool() try: _, _, episode = await workflow_engine.process_task_with_retry( task, sample_id, 0 ) finally: workflow_engine.shutdown() info = episode.info or {} messages = info.get("messages", []) prediction = info.get("prediction", {}) if isinstance(info.get("prediction"), dict) else {} termination = info.get("termination") or ( episode.termination_reason.value if getattr(episode, "termination_reason", None) else "unknown" ) trajectory = [_sanitize_content(m) for m in messages] out: Dict[str, Any] = { "termination": termination, "trajectory_messages": trajectory, "gen_prompt": prediction.get("gen_prompt", ""), "prediction": prediction, } if skip_generation: out["image_png"] = None out["image_error"] = None return out if termination != "answer": out["image_png"] = None out["image_error"] = f"Agent did not finish with answer (termination={termination})" return out eff_prompt, img_paths = get_effective_prompt_and_images(user_prompt, prediction) gen_base = _parse_qwen_edit_base_url() gen_path = os.environ.get("QWEN_EDIT_APP_PATH", "/generate") timeout = int(os.environ.get("GEN_IMAGE_TIMEOUT", "1800")) try: png_bytes = call_generate_api(gen_base, gen_path, eff_prompt, img_paths, timeout=timeout) out["image_png"] = png_bytes out["used_prompt"] = eff_prompt out["reference_paths"] = img_paths out["image_error"] = None except Exception as e: out["image_png"] = None out["image_error"] = str(e) return out def run_sync( user_prompt: str, *, temperature: float = 0.6, top_p: float = 0.9, skip_generation: bool = False, ) -> Dict[str, Any]: return asyncio.run( run_gensearcher_then_generate( user_prompt, temperature=temperature, top_p=top_p, skip_generation=skip_generation, ) )