| """ | |
| Run one official GenSearcher trajectory (OpenAI-compatible vLLM) then call FireRed /generate adapter. | |
| """ | |
| from __future__ import annotations | |
| import asyncio | |
| import base64 | |
| import json | |
| import os | |
| import uuid | |
| from typing import Any, Dict, List, Optional, Tuple | |
| import requests | |
| from space_health import check_v1_models, is_localhost_url | |
| from rllm.engine.agent_workflow_engine import AgentWorkflowEngine | |
| from rllm.engine.rollout import OpenAIEngine | |
| from vision_deepresearch_async_workflow.gen_image_deepresearch_tools_executor import ( | |
| create_gen_image_tools, | |
| ) | |
| from vision_deepresearch_async_workflow.gen_image_deepresearch_workflow import ( | |
| GenImageDeepResearchWorkflow, | |
| ) | |
| def _sanitize_content(msg: dict) -> dict: | |
| out = {"role": msg.get("role", ""), "content": ""} | |
| content = msg.get("content", "") | |
| if isinstance(content, str): | |
| out["content"] = content[:50000] + ("..." if len(content) > 50000 else "") | |
| else: | |
| out["content"] = str(content)[:50000] | |
| if "images" in msg: | |
| out["images"] = [ | |
| (p if isinstance(p, str) else p.get("image", ""))[:200] | |
| for p in (msg["images"] or [])[:10] | |
| ] | |
| return out | |
def get_effective_prompt_and_images(
    user_prompt: str, prediction: dict
) -> Tuple[str, List[str]]:
    """Choose the prompt and reference images to send to the image generator.

    Mirrors ``eval/gen_image_from_results.get_effective_prompt_and_images``.
    The agent's ``gen_prompt`` is used only when at least one
    ``reference_images`` entry has a ``local_path`` that exists on disk.
    Otherwise the original ``user_prompt`` is returned with no reference images.
    """
    agent_prompt = (prediction.get("gen_prompt") or "").strip()
    candidate_paths = (
        (ref.get("local_path") or "").strip()
        for ref in prediction.get("reference_images") or []
        if isinstance(ref, dict)
    )
    existing = [cand for cand in candidate_paths if cand and os.path.exists(cand)]
    if not (agent_prompt and existing):
        return user_prompt, []
    return agent_prompt, existing
| def _parse_qwen_edit_base_url() -> str: | |
| raw = os.environ.get("QWEN_EDIT_APP_URL", "http://127.0.0.1:8765").strip() | |
| try: | |
| urls = json.loads(raw) | |
| if isinstance(urls, list) and urls: | |
| return str(urls[0]).rstrip("/") | |
| except json.JSONDecodeError: | |
| pass | |
| return raw.rstrip("/").strip('"').strip("'") | |
def call_generate_api(
    base_url: str,
    path: str,
    prompt: str,
    image_paths: List[str],
    timeout: int = 1800,
) -> bytes:
    """POST a generation request to the image service and return the image bytes.

    Args:
        base_url: Service root URL; a trailing ``/`` is ignored.
        path: Endpoint path; a leading ``/`` is added if missing.
        prompt: Text prompt for the generator.
        image_paths: Local reference image files. Only the first three are
            sent, each as a base64 ``data:`` URL. ``image_urls`` is ``None``
            when there are none.
        timeout: Request timeout in seconds passed to ``requests.post``.

    Sampling settings (seed, CFG scales, negative prompt, steps) are read from
    the ``GEN_*`` environment variables.

    Returns:
        The base64-decoded ``image`` field of the service's JSON response.

    Raises:
        OSError: If a reference image cannot be read.
        requests.HTTPError: On a non-2xx HTTP response.
        RuntimeError: If the service reports failure, returns an unexpected
            body, or reports success without an image.
    """
    if not path.startswith("/"):
        path = "/" + path
    url = base_url.rstrip("/") + path

    ref_images_b64: List[str] = []
    for img_path in image_paths[:3]:
        with open(img_path, "rb") as f:
            ref_images_b64.append(base64.b64encode(f.read()).decode("utf-8"))
    # NOTE(review): every reference is labelled image/jpeg regardless of the
    # file's real type; presumably the service decodes the bytes itself — confirm.
    image_urls = (
        [f"data:image/jpeg;base64,{b}" for b in ref_images_b64]
        if ref_images_b64
        else None
    )

    payload = {
        "image_urls": image_urls,
        "prompt": prompt,
        "seed": int(os.environ.get("GEN_SEED", "0")),
        "true_cfg_scale": float(os.environ.get("GEN_TRUE_CFG_SCALE", "4.0")),
        "negative_prompt": os.environ.get("GEN_NEGATIVE_PROMPT", " "),
        "num_inference_steps": int(os.environ.get("GEN_NUM_INFERENCE_STEPS", "40")),
        "guidance_scale": float(os.environ.get("GEN_GUIDANCE_SCALE", "1.0")),
        "num_images_per_prompt": 1,
    }
    r = requests.post(url, json=payload, timeout=timeout)
    r.raise_for_status()
    result = r.json()
    if not isinstance(result, dict):
        raise RuntimeError(f"Unexpected response from {url}: {str(result)[:500]}")
    if not result.get("success"):
        raise RuntimeError(result.get("message", str(result)))

    img_b64 = result.get("image") or ""
    if not isinstance(img_b64, str):
        raise RuntimeError(
            f"Unexpected 'image' field type from {url}: {type(img_b64).__name__}"
        )
    if img_b64.startswith("data:image"):
        # Strip a data-URL header ("data:image/png;base64,") if present.
        img_b64 = img_b64.split(",", 1)[-1]
    if not img_b64:
        # Previously this silently returned b"", which callers treated as a
        # successfully generated (but empty) image.
        raise RuntimeError(f"Generation service at {url} reported success but returned no image")
    return base64.b64decode(img_b64)
async def run_gensearcher_then_generate(
    user_prompt: str,
    *,
    temperature: float = 0.6,
    top_p: float = 0.9,
    skip_generation: bool = False,
) -> Dict[str, Any]:
    """Run one GenSearcher agent trajectory, then optionally generate an image.

    Steps:
      1. Check that the OpenAI-compatible LLM endpoint is reachable via
         ``check_v1_models``. The endpoint is ``OPENAI_BASE_URL``, normalized
         to end in ``/v1``. Raise if it is not reachable.
      2. Run a single ``GenImageDeepResearchWorkflow`` episode for the prompt.
      3. Unless ``skip_generation`` is set or the agent's termination is not
         ``"answer"``, call the image service through ``call_generate_api``.

    The blocking HTTP calls (health check, image generation) run in the
    loop's default executor so they do not stall the event loop.

    Args:
        user_prompt: The user's image request.
        temperature: Sampling temperature for the agent LLM.
        top_p: Nucleus-sampling parameter for the agent LLM.
        skip_generation: If True, return after the trajectory without
            calling the image service.

    Returns:
        A dict with these keys:

        - ``termination``, ``trajectory_messages``, ``gen_prompt``, ``prediction``;
        - ``image_png``: bytes or None;
        - ``image_error``: str or None;
        - ``used_prompt`` and ``reference_paths``: present whenever generation
          was attempted.

    Raises:
        RuntimeError: If the LLM endpoint does not answer ``/models``.
    """
    loop = asyncio.get_running_loop()

    sample_id = str(uuid.uuid4())[:8]
    task = {
        "id": sample_id,
        "question": user_prompt,
        "prompt": user_prompt,
        "meta": {"source": "hf_space"},
    }
    model = os.environ.get("GEN_EVAL_MODEL", "Gen-Searcher-8B")
    base_url = os.environ.get("OPENAI_BASE_URL", "http://127.0.0.1:8002/v1").strip().rstrip("/")
    if not base_url.endswith("/v1"):
        base_url = base_url + "/v1"
    api_key = os.environ.get("OPENAI_API_KEY", "EMPTY")

    # check_v1_models is called synchronously (not awaited); keep it off the loop.
    ok_llm, llm_msg = await loop.run_in_executor(None, check_v1_models, base_url, api_key)
    if not ok_llm:
        hint = ""
        if is_localhost_url(base_url):
            hint = (
                " You are targeting localhost inside the Space container. Nothing is listening unless you set "
                "Space variable START_VLLM_GENSEARCHER=1 (and GPU) or change OPENAI_BASE_URL to a reachable "
                "OpenAI-compatible server (your vLLM / TGI URL ending in /v1)."
            )
        raise RuntimeError(
            f"GenSearcher LLM is not reachable at {base_url}/models — {llm_msg}.{hint}"
        )

    rollout_engine = OpenAIEngine(
        model=model,
        base_url=base_url,
        api_key=api_key,
        max_prompt_length=int(os.environ.get("MAX_PROMPT_LENGTH", "64000")),
        max_response_length=int(os.environ.get("MAX_RESPONSE_LENGTH", "64000")),
        sampling_params={
            "temperature": temperature,
            "top_p": top_p,
        },
    )
    tools = create_gen_image_tools()
    workflow_engine = AgentWorkflowEngine(
        workflow_cls=GenImageDeepResearchWorkflow,
        workflow_args={"tools": tools, "reward_function": None},
        rollout_engine=rollout_engine,
        n_parallel_tasks=1,
        retry_limit=2,
    )
    await workflow_engine.initialize_pool()
    try:
        _, _, episode = await workflow_engine.process_task_with_retry(
            task, sample_id, 0
        )
    finally:
        workflow_engine.shutdown()

    info = episode.info or {}
    raw_prediction = info.get("prediction")
    prediction = raw_prediction if isinstance(raw_prediction, dict) else {}
    termination = info.get("termination") or (
        episode.termination_reason.value
        if getattr(episode, "termination_reason", None)
        else "unknown"
    )
    # "messages" may be absent or None; skip anything that is not a message dict.
    trajectory = [
        _sanitize_content(m) for m in (info.get("messages") or []) if isinstance(m, dict)
    ]
    out: Dict[str, Any] = {
        "termination": termination,
        "trajectory_messages": trajectory,
        "gen_prompt": prediction.get("gen_prompt", ""),
        "prediction": prediction,
    }
    if skip_generation:
        out["image_png"] = None
        out["image_error"] = None
        return out
    if termination != "answer":
        out["image_png"] = None
        out["image_error"] = f"Agent did not finish with answer (termination={termination})"
        return out

    eff_prompt, img_paths = get_effective_prompt_and_images(user_prompt, prediction)
    # Record what was sent even if generation fails, to aid debugging.
    out["used_prompt"] = eff_prompt
    out["reference_paths"] = img_paths
    gen_base = _parse_qwen_edit_base_url()
    gen_path = os.environ.get("QWEN_EDIT_APP_PATH", "/generate")
    timeout = int(os.environ.get("GEN_IMAGE_TIMEOUT", "1800"))
    try:
        # call_generate_api uses blocking requests.post (up to `timeout` s).
        png_bytes = await loop.run_in_executor(
            None,
            lambda: call_generate_api(gen_base, gen_path, eff_prompt, img_paths, timeout=timeout),
        )
    except Exception as e:  # best-effort: report the failure instead of raising
        out["image_png"] = None
        out["image_error"] = str(e) or repr(e)
    else:
        out["image_png"] = png_bytes
        out["image_error"] = None
    return out
def run_sync(
    user_prompt: str,
    *,
    temperature: float = 0.6,
    top_p: float = 0.9,
    skip_generation: bool = False,
) -> Dict[str, Any]:
    """Blocking wrapper around ``run_gensearcher_then_generate``.

    Executes the coroutine with ``asyncio.run``. That creates a fresh event
    loop, so this cannot be called while another event loop is running in
    the same thread.
    """
    pipeline = run_gensearcher_then_generate(
        user_prompt,
        temperature=temperature,
        top_p=top_p,
        skip_generation=skip_generation,
    )
    return asyncio.run(pipeline)