File size: 7,774 Bytes
80b7188
 
 
 
 
 
 
 
 
 
 
 
 
 
138b29f
 
80b7188
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138b29f
80b7188
 
 
 
138b29f
 
 
 
 
 
 
 
 
 
 
 
 
80b7188
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
"""
Run one official GenSearcher agent trajectory against an OpenAI-compatible (e.g. vLLM)
server, then send the resulting prompt and reference images to the FireRed /generate
image adapter.
"""
from __future__ import annotations

import asyncio
import base64
import json
import os
import uuid
from typing import Any, Dict, List, Optional, Tuple

import requests

from space_health import check_v1_models, is_localhost_url

from rllm.engine.agent_workflow_engine import AgentWorkflowEngine
from rllm.engine.rollout import OpenAIEngine
from vision_deepresearch_async_workflow.gen_image_deepresearch_tools_executor import (
    create_gen_image_tools,
)
from vision_deepresearch_async_workflow.gen_image_deepresearch_workflow import (
    GenImageDeepResearchWorkflow,
)


def _sanitize_content(
    msg: dict,
    max_chars: int = 50000,
    max_images: int = 10,
    max_image_chars: int = 200,
) -> dict:
    """Return a display-safe, size-bounded copy of one chat message.

    Only ``role``, ``content`` and (when present) ``images`` are kept, so a
    trajectory can be shown or serialized without shipping huge payloads.

    Args:
        msg: A chat message dict with ``role`` / ``content`` and optionally
            ``images``.
        max_chars: Maximum number of content characters kept. Longer content
            is cut and suffixed with ``"..."``, whether it was a string or not.
        max_images: Maximum number of image references kept.
        max_image_chars: Maximum length of each image reference string.

    Returns:
        A new dict whose ``role`` and ``content`` are strings. If the input had
        an ``images`` key, the dict also has a list of truncated image
        reference strings.
    """

    def _image_ref(item) -> str:
        # Image entries are plain path/URL strings or dicts with an "image"
        # key. A missing/None "image" becomes "", and any other object is
        # stringified instead of raising AttributeError/TypeError.
        if isinstance(item, str):
            return item
        if isinstance(item, dict):
            return str(item.get("image") or "")
        return str(item)

    content = msg.get("content", "")
    if content is None:
        # Avoid showing the literal string "None" for empty content.
        text = ""
    elif isinstance(content, str):
        text = content
    else:
        # e.g. a list of multimodal content parts: stringify for display.
        text = str(content)
    if len(text) > max_chars:
        text = text[:max_chars] + "..."

    out = {"role": str(msg.get("role") or ""), "content": text}
    if "images" in msg:
        out["images"] = [
            _image_ref(item)[:max_image_chars]
            for item in (msg["images"] or [])[:max_images]
        ]
    return out


def get_effective_prompt_and_images(
    user_prompt: str, prediction: dict
) -> Tuple[str, List[str]]:
    """Choose the prompt and reference images to send to the image generator.

    Mirrors ``eval/gen_image_from_results.get_effective_prompt_and_images``.
    The agent's ``gen_prompt`` is used only when both of these hold:

    - it is non-blank after stripping;
    - at least one ``reference_images[*].local_path`` exists on disk.

    Otherwise the raw user prompt is returned with no reference images.

    Returns:
        ``(prompt, paths)``, where ``paths`` are the stripped local paths that
        exist, in their original order. Non-dict entries are ignored.
    """
    refined_prompt = (prediction.get("gen_prompt") or "").strip()
    entries = prediction.get("reference_images") or []
    existing = [
        candidate
        for candidate in (
            (entry.get("local_path") or "").strip()
            for entry in entries
            if isinstance(entry, dict)
        )
        if candidate and os.path.exists(candidate)
    ]
    if not (refined_prompt and existing):
        return user_prompt, []
    return refined_prompt, existing


def _parse_qwen_edit_base_url() -> str:
    """Return the image-generation service base URL from ``QWEN_EDIT_APP_URL``.

    The variable may hold any of:

    - a plain URL;
    - a quoted URL (JSON double quotes or single quotes);
    - a JSON list of URLs, in which case only the first entry is used.

    Surrounding whitespace and quotes are removed, then trailing slashes.
    An unset, empty or empty-list value falls back to ``http://127.0.0.1:8765``.
    """
    default_url = "http://127.0.0.1:8765"
    raw = os.environ.get("QWEN_EDIT_APP_URL", default_url).strip()
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        parsed = raw
    if isinstance(parsed, list):
        candidate = str(parsed[0]) if parsed else ""
    elif isinstance(parsed, str):
        candidate = parsed
    else:
        # e.g. a bare JSON number/bool: keep the original text.
        candidate = raw
    # Strip quotes *before* trailing slashes so that '"http://h/"' and
    # "'http://h/'" both normalize to 'http://h'.
    url = candidate.strip().strip('"').strip("'").strip().rstrip("/")
    return url or default_url


def call_generate_api(
    base_url: str,
    path: str,
    prompt: str,
    image_paths: List[str],
    timeout: int = 1800,
    max_images: int = 3,
) -> bytes:
    """POST a prompt, plus optional reference images, to the image-generation adapter.

    Args:
        base_url: Service root, e.g. ``http://127.0.0.1:8765``. A trailing
            ``/`` is ignored.
        path: Endpoint path such as ``/generate``. A leading ``/`` is added if
            missing.
        prompt: Text prompt for generation.
        image_paths: Local reference image files. Only the first
            ``max_images`` are sent, each as a base64 ``data:`` URL. The MIME
            type is guessed from the file extension and falls back to
            ``image/jpeg``.
        timeout: Request timeout in seconds, passed to ``requests.post``.
        max_images: Maximum number of reference images to send (default 3).

    Sampling settings are read from these environment variables:
    ``GEN_SEED``, ``GEN_TRUE_CFG_SCALE``, ``GEN_NEGATIVE_PROMPT``,
    ``GEN_NUM_INFERENCE_STEPS`` and ``GEN_GUIDANCE_SCALE``.

    Returns:
        The base64-decoded bytes of the response's ``image`` field. A
        ``data:image...`` prefix is removed before decoding.

    Raises:
        requests.HTTPError: The server answered with a non-2xx status.
        RuntimeError: The response is not a JSON object, reports a falsy
            ``success``, or contains no image data.
    """
    import mimetypes

    def _to_data_url(img_path: str) -> str:
        # Label each payload with its real image type. Unknown or non-image
        # extensions keep the historical image/jpeg label.
        mime, _ = mimetypes.guess_type(img_path)
        if not mime or not mime.startswith("image/"):
            mime = "image/jpeg"
        with open(img_path, "rb") as f:
            encoded = base64.b64encode(f.read()).decode("utf-8")
        return f"data:{mime};base64,{encoded}"

    if not path.startswith("/"):
        path = "/" + path
    url = base_url.rstrip("/") + path

    # Send None (not an empty list) when there are no reference images.
    image_urls = [_to_data_url(p) for p in image_paths[:max_images]] or None
    payload = {
        "image_urls": image_urls,
        "prompt": prompt,
        "seed": int(os.environ.get("GEN_SEED", "0")),
        "true_cfg_scale": float(os.environ.get("GEN_TRUE_CFG_SCALE", "4.0")),
        "negative_prompt": os.environ.get("GEN_NEGATIVE_PROMPT", " "),
        "num_inference_steps": int(os.environ.get("GEN_NUM_INFERENCE_STEPS", "40")),
        "guidance_scale": float(os.environ.get("GEN_GUIDANCE_SCALE", "1.0")),
        "num_images_per_prompt": 1,
    }
    r = requests.post(url, json=payload, timeout=timeout)
    r.raise_for_status()
    result = r.json()
    if not isinstance(result, dict):
        raise RuntimeError(f"Unexpected response from {url}: {repr(result)[:500]}")
    if not result.get("success"):
        raise RuntimeError(result.get("message", str(result)))
    img_b64 = result.get("image") or ""
    if img_b64.startswith("data:image"):
        img_b64 = img_b64.split(",", 1)[-1]
    if not img_b64:
        # An empty payload would decode to b"" and look like a successful image.
        raise RuntimeError("Generation API reported success but returned no image")
    return base64.b64decode(img_b64)


def _resolve_llm_endpoint() -> Tuple[str, str, str]:
    """Read ``(model, base_url, api_key)`` for the GenSearcher LLM from the environment.

    ``base_url`` is trimmed of whitespace and trailing slashes and always ends
    in ``/v1``.
    """
    model = os.environ.get("GEN_EVAL_MODEL", "Gen-Searcher-8B")
    base_url = os.environ.get("OPENAI_BASE_URL", "http://127.0.0.1:8002/v1").strip().rstrip("/")
    if not base_url.endswith("/v1"):
        base_url = base_url + "/v1"
    api_key = os.environ.get("OPENAI_API_KEY", "EMPTY")
    return model, base_url, api_key


def _ensure_llm_reachable(base_url: str, api_key: str) -> None:
    """Raise ``RuntimeError`` if ``check_v1_models`` reports the LLM endpoint as unusable.

    When the URL is a localhost address, the error message includes a
    deployment hint.
    """
    ok_llm, llm_msg = check_v1_models(base_url, api_key)
    if ok_llm:
        return
    hint = ""
    if is_localhost_url(base_url):
        hint = (
            " You are targeting localhost inside the Space container. Nothing is listening unless you set "
            "Space variable START_VLLM_GENSEARCHER=1 (and GPU) or change OPENAI_BASE_URL to a reachable "
            "OpenAI-compatible server (your vLLM / TGI URL ending in /v1)."
        )
    raise RuntimeError(
        f"GenSearcher LLM is not reachable at {base_url}/models — {llm_msg}.{hint}"
    )


def _episode_termination(episode: Any, info: Dict[str, Any]) -> str:
    """Return the episode's termination label.

    Prefers ``info["termination"]``, then ``episode.termination_reason.value``,
    and finally ``"unknown"``.
    """
    termination = info.get("termination")
    if termination:
        return termination
    reason = getattr(episode, "termination_reason", None)
    return reason.value if reason else "unknown"


async def run_gensearcher_then_generate(
    user_prompt: str,
    *,
    temperature: float = 0.6,
    top_p: float = 0.9,
    skip_generation: bool = False,
) -> Dict[str, Any]:
    """Run one GenSearcher agent episode for ``user_prompt``, then optionally generate an image.

    The LLM endpoint and image service are configured through environment
    variables (``GEN_EVAL_MODEL``, ``OPENAI_BASE_URL``, ``OPENAI_API_KEY``,
    ``QWEN_EDIT_APP_URL``, ``QWEN_EDIT_APP_PATH``, ``GEN_IMAGE_TIMEOUT``, ...).

    Image generation is attempted only when ``skip_generation`` is False and
    the episode terminated with ``"answer"``. Generation failures are reported
    in ``image_error`` rather than raised.

    Args:
        user_prompt: The user's image request.
        temperature: Sampling temperature for the agent LLM.
        top_p: Nucleus-sampling ``top_p`` for the agent LLM.
        skip_generation: If True, stop after the agent episode.

    Returns:
        A dict that always contains:

        - ``termination``;
        - ``trajectory_messages`` (sanitized chat messages);
        - ``gen_prompt``;
        - ``prediction``;
        - ``image_png`` (bytes or None);
        - ``image_error`` (str or None).

        Once a generation request is attempted, it also contains
        ``used_prompt`` and ``reference_paths``.

    Raises:
        RuntimeError: The LLM endpoint health check fails.
    """
    sample_id = str(uuid.uuid4())[:8]
    task = {
        "id": sample_id,
        "question": user_prompt,
        "prompt": user_prompt,
        "meta": {"source": "hf_space"},
    }

    model, base_url, api_key = _resolve_llm_endpoint()
    _ensure_llm_reachable(base_url, api_key)

    rollout_engine = OpenAIEngine(
        model=model,
        base_url=base_url,
        api_key=api_key,
        max_prompt_length=int(os.environ.get("MAX_PROMPT_LENGTH", "64000")),
        max_response_length=int(os.environ.get("MAX_RESPONSE_LENGTH", "64000")),
        sampling_params={
            "temperature": temperature,
            "top_p": top_p,
        },
    )
    tools = create_gen_image_tools()
    workflow_engine = AgentWorkflowEngine(
        workflow_cls=GenImageDeepResearchWorkflow,
        workflow_args={"tools": tools, "reward_function": None},
        rollout_engine=rollout_engine,
        n_parallel_tasks=1,
        retry_limit=2,
    )
    try:
        # Initialize inside the try so shutdown() also runs if pool setup fails.
        await workflow_engine.initialize_pool()
        _, _, episode = await workflow_engine.process_task_with_retry(
            task, sample_id, 0
        )
    finally:
        workflow_engine.shutdown()

    info = episode.info or {}
    messages = info.get("messages") or []
    prediction = info.get("prediction") if isinstance(info.get("prediction"), dict) else {}
    termination = _episode_termination(episode, info)

    out: Dict[str, Any] = {
        "termination": termination,
        "trajectory_messages": [_sanitize_content(m) for m in messages],
        "gen_prompt": prediction.get("gen_prompt", ""),
        "prediction": prediction,
        "image_png": None,
        "image_error": None,
    }

    if skip_generation:
        return out

    if termination != "answer":
        out["image_error"] = f"Agent did not finish with answer (termination={termination})"
        return out

    eff_prompt, img_paths = get_effective_prompt_and_images(user_prompt, prediction)
    gen_base = _parse_qwen_edit_base_url()
    gen_path = os.environ.get("QWEN_EDIT_APP_PATH", "/generate")
    timeout = int(os.environ.get("GEN_IMAGE_TIMEOUT", "1800"))
    out["used_prompt"] = eff_prompt
    out["reference_paths"] = img_paths

    loop = asyncio.get_running_loop()
    try:
        # call_generate_api uses blocking `requests` with a long timeout. Run it
        # in the default executor so the event loop driving this coroutine
        # stays responsive.
        out["image_png"] = await loop.run_in_executor(
            None, call_generate_api, gen_base, gen_path, eff_prompt, img_paths, timeout
        )
    except Exception as e:  # deliberate boundary: report instead of raising
        out["image_png"] = None
        out["image_error"] = str(e)
    return out


def run_sync(
    user_prompt: str,
    *,
    temperature: float = 0.6,
    top_p: float = 0.9,
    skip_generation: bool = False,
) -> Dict[str, Any]:
    """Blocking wrapper around ``run_gensearcher_then_generate``.

    Builds the coroutine with the same keyword options and drives it to
    completion with ``asyncio.run``. Because of that, it cannot be called from
    a thread that already has a running event loop.
    """
    pipeline = run_gensearcher_then_generate(
        user_prompt,
        temperature=temperature,
        top_p=top_p,
        skip_generation=skip_generation,
    )
    return asyncio.run(pipeline)