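# NOTE(review): the three lines below appear to be file-viewer residue (avatar alt
# text, last commit message, commit hash) rather than program text; they are kept
# as comments so the module docstring that follows is the first statement.
# JSCPPProgrammer's picture
# Fix HF exec entrypoint: generate entrypoint.sh in Docker; body in entrypoint_body.sh
# c994fd2 verified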
"""
Hugging Face Space: official GenSearcher agent + FireRed-Image-Edit-1.1 generation.
Set PYTHONPATH to include vendor/rllm (see Dockerfile / README).
"""
from __future__ import annotations
import io
import json
import os
from pathlib import Path
import gradio as gr
from PIL import Image
from space_gen import run_sync
from space_health import llm_endpoint_status
def _trajectory_to_markdown(trajectory: list) -> str:
if not trajectory:
return "_No messages_"
parts = []
for i, m in enumerate(trajectory):
role = m.get("role", "")
content = m.get("content", "")
if isinstance(content, list):
content = json.dumps(content, ensure_ascii=False)[:8000]
parts.append(f"### {i + 1}. {role}\n\n```\n{content}\n```\n")
return "\n".join(parts)
def run_pipeline(
    prompt: str,
    temperature: float,
    top_p: float,
    research_only: bool,
):
    """Run the agent for one prompt and shape the result for the four UI outputs.

    Args:
        prompt: User text; ``None`` is treated as empty and surrounding
            whitespace is stripped.
        temperature: Forwarded to ``run_sync`` after ``float()`` conversion.
        top_p: Forwarded to ``run_sync`` after ``float()`` conversion.
        research_only: Forwarded to ``run_sync`` as ``skip_generation``.

    Returns:
        A 4-tuple ``(image, metadata_markdown, trajectory_markdown, gen_prompt)``.
        ``image`` is an RGB PIL image, or ``None`` when the result has no
        ``image_png`` or its bytes cannot be decoded. For an empty prompt or
        a failing ``run_sync`` the tuple is ``(None, <message>, "", "")``.
    """
    prompt = (prompt or "").strip()
    if not prompt:
        return None, "Enter a non-empty prompt.", "", ""
    try:
        result = run_sync(
            prompt,
            temperature=float(temperature),
            top_p=float(top_p),
            skip_generation=bool(research_only),
        )
    except Exception as e:
        import traceback

        return None, f"**Error**\n\n```\n{e}\n{traceback.format_exc()}\n```", "", ""

    img = None
    image_error = result.get("image_error")
    png = result.get("image_png")
    if png:
        try:
            img = Image.open(io.BytesIO(png)).convert("RGB")
        except Exception as e:
            # UI boundary: undecodable bytes must not discard the metadata and
            # trajectory already produced, so report the failure in the metadata.
            decode_error = f"Could not decode generated image: {e}"
            image_error = f"{image_error}; {decode_error}" if image_error else decode_error

    traj_md = _trajectory_to_markdown(result.get("trajectory_messages") or [])
    meta = {
        "termination": result.get("termination"),
        "gen_prompt": result.get("gen_prompt"),
        "used_prompt": result.get("used_prompt"),
        "reference_paths": result.get("reference_paths"),
        "image_error": image_error,
    }
    meta_txt = "```json\n" + json.dumps(meta, ensure_ascii=False, indent=2) + "\n```"
    return img, meta_txt, traj_md, result.get("gen_prompt") or ""
# Build the Gradio UI as `demo`: intro text, endpoint status panel, inputs,
# outputs, and the wiring of the Run button to run_pipeline.
with gr.Blocks(title="GenSearcher + FireRed") as demo:
# Static introduction and configuration notes shown at the top of the page.
gr.Markdown(
"## GenSearcher + FireRed-Image-Edit-1.1\n"
"Runs the **official** GenSearcher search/browse/image-search agent (vLLM), "
"then generates with **FireRed** via the same `/generate` API as the Qwen edit server.\n\n"
"**LLM:** Either run Gen-Searcher **in this same Space** (`START_VLLM_GENSEARCHER=1` → vLLM on localhost; "
"no second Space), **or** set `OPENAI_BASE_URL` to an OpenAI-compatible **`…/v1`** endpoint. "
"Browse summarization needs `BROWSE_SUMMARY_BASE_URL` when `BROWSE_GENERATE_ENGINE=vllm` (see README).\n\n"
"**Search / browse (optional keys):** without `SERPER_KEY_ID` and `JINA_API_KEYS`, the agent uses **DuckDuckGo** "
"for web and image search and **direct HTTP** page fetch for visits. Set those secrets if you prefer Serper + Jina.\n\n"
"**Connection errors:** On Hugging Face Spaces, `http://127.0.0.1:8002/v1` only works if you run vLLM "
"in the same container (`START_VLLM_GENSEARCHER=1` + GPU). Otherwise set `OPENAI_BASE_URL` to your **public** inference server."
)
# Status panel, seeded with the result of llm_endpoint_status() at build time.
status_md = gr.Markdown(llm_endpoint_status())
refresh_status = gr.Button("Re-check endpoints", size="sm")
# Callback that re-evaluates llm_endpoint_status(); shared by the button and the load event.
def _refresh():
return llm_endpoint_status()
# Both the button click and the Blocks load event write the fresh status into status_md.
refresh_status.click(fn=_refresh, outputs=status_md)
demo.load(fn=_refresh, outputs=status_md)
# Prompt input.
with gr.Row():
prompt = gr.Textbox(
label="Image task / prompt",
lines=4,
placeholder="Describe the image you want, including any real-world facts to verify.",
)
# Sampling controls and the research-only toggle.
with gr.Row():
temperature = gr.Slider(0.0, 1.5, value=0.6, label="Temperature")
top_p = gr.Slider(0.0, 1.0, value=0.9, label="Top-p")
research_only = gr.Checkbox(
label="Research only (no FireRed generation)",
value=False,
)
run_btn = gr.Button("Run", variant="primary")
# Output components; they are passed to run_btn.click in the same order as
# the 4-tuple returned by run_pipeline (image, metadata, trajectory, gen_prompt).
with gr.Row():
out_image = gr.Image(label="Generated image", type="pil")
out_meta = gr.Markdown(label="Run metadata")
out_traj = gr.Markdown(label="Trajectory (sanitized)")
out_gen_prompt = gr.Textbox(label="gen_prompt (from agent)", lines=6)
# The four inputs map positionally onto run_pipeline's parameters.
run_btn.click(
fn=run_pipeline,
inputs=[prompt, temperature, top_p, research_only],
outputs=[out_image, out_meta, out_traj, out_gen_prompt],
)
if __name__ == "__main__":
    # Script entry point: turn on the request queue with a default
    # concurrency limit of 1, then serve on all interfaces. The port comes
    # from GRADIO_SERVER_PORT and falls back to 7860.
    demo.queue(default_concurrency_limit=1)
    listen_port = int(os.getenv("GRADIO_SERVER_PORT", "7860"))
    demo.launch(server_name="0.0.0.0", server_port=listen_port)