diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000000000000000000000000000000000000..fbfb9ec876f2a4eeeb410565c6e9c4279b71a83b --- /dev/null +++ b/.dockerignore @@ -0,0 +1,7 @@ +_upstream_gen_searcher +.git +**/__pycache__ +**/*.pyc +**/.pytest_cache +.env +*.md.bak diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..accd2a49353007a3536eb923508f62e17dbfbe06 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +vendor/rllm/docs/assets/rllm_components.png filter=lfs diff=lfs merge=lfs -text +vendor/rllm/docs/assets/sdk_arch.png filter=lfs diff=lfs merge=lfs -text diff --git a/.pytest_cache/.gitignore b/.pytest_cache/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..08a7f458f1f002823bc794c47ca1996a57e72c86 --- /dev/null +++ b/.pytest_cache/.gitignore @@ -0,0 +1,2 @@ +# Created by pytest automatically. +* diff --git a/.pytest_cache/CACHEDIR.TAG b/.pytest_cache/CACHEDIR.TAG new file mode 100644 index 0000000000000000000000000000000000000000..fce15ad7eaa74e5682b644c84efb75334c112f95 --- /dev/null +++ b/.pytest_cache/CACHEDIR.TAG @@ -0,0 +1,4 @@ +Signature: 8a477f597d28d172789f06886806bc55 +# This file is a cache directory tag created by pytest. +# For information about cache directory tags, see: +# https://bford.info/cachedir/spec.html diff --git a/.pytest_cache/README.md b/.pytest_cache/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c7526af2448672de4537dfed042ed74daadb17bf --- /dev/null +++ b/.pytest_cache/README.md @@ -0,0 +1,8 @@ +# pytest cache directory # + +This directory contains data from the pytest's cache plugin, +which provides the `--lf` and `--ff` options, as well as the `cache` fixture. + +**Do not** commit this to version control. 
+ +See [the docs](https://docs.pytest.org/en/stable/how-to/cache.html) for more information. diff --git a/.pytest_cache/v/cache/nodeids b/.pytest_cache/v/cache/nodeids new file mode 100644 index 0000000000000000000000000000000000000000..52cf1bc405b75f5eecd11f9bc8bbe08ec07b8504 --- /dev/null +++ b/.pytest_cache/v/cache/nodeids @@ -0,0 +1,4 @@ +[ + "tests/test_imports.py::test_firered_service_parse", + "tests/test_imports.py::test_space_gen_importable" +] \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..de1f90a5c9be47ff9c70a597725a50550b966703 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,35 @@ +# Hugging Face Space (Docker) — GenSearcher + FireRed +# Requires GPU. For multi-GPU full-local mode, set START_VLLM_*=1 and CUDA device envs in README. + +FROM pytorch/pytorch:2.5.1-cuda12.4-cudnn9-runtime + +ENV DEBIAN_FRONTEND=noninteractive +RUN apt-get update && apt-get install -y --no-install-recommends \ + curl \ + git \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /app + +COPY vendor/rllm /app/vendor/rllm +COPY requirements.txt /app/requirements.txt +COPY app.py space_gen.py /app/ +COPY services /app/services +COPY scripts /app/scripts + +ENV PYTHONPATH=/app/vendor/rllm +ENV GRADIO_SERVER_PORT=7860 + +RUN pip install --no-cache-dir --upgrade pip setuptools wheel \ + && pip install --no-cache-dir -e /app/vendor/rllm \ + && pip install --no-cache-dir -r /app/requirements.txt + +# Optional: local vLLM inside the image (large). Disable with build-arg if you only use external APIs. 
+ARG INSTALL_VLLM=1 +RUN if [ "$INSTALL_VLLM" = "1" ]; then pip install --no-cache-dir "vllm>=0.6.3"; fi + +RUN chmod +x /app/scripts/entrypoint.sh + +EXPOSE 7860 + +CMD ["/app/scripts/entrypoint.sh"] diff --git a/README.md b/README.md index 62af97ab5757ccf932a12e9bce7f991ef62ad5ee..90eb7abe1b0c266ce41f37829d47e1c25dd32750 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,70 @@ ---- -title: Gensearcher Firered -emoji: 🐢 -colorFrom: yellow -colorTo: red -sdk: docker -pinned: false ---- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +--- +title: GenSearcher + FireRed +emoji: 🔍 +colorFrom: blue +colorTo: purple +sdk: docker +pinned: false +suggested_hardware: a100-large +--- + +# GenSearcher + FireRed-Image-Edit-1.1 + +This Space runs the **official** [Gen-Searcher](https://github.com/tulerfeng/Gen-Searcher) image workflow (`GenImageDeepResearchWorkflow` + `create_gen_image_tools`) against an OpenAI-compatible **GenSearcher-8B** server, then calls **FireRed-Image-Edit-1.1** through the same HTTP `/generate` contract as the upstream Qwen image API. + +## Architecture + +1. **Agent** — vendored `vision_deepresearch_async_workflow` from Gen-Searcher (unchanged `create_gen_image_tools`). +2. **LLM** — `OPENAI_BASE_URL` + `GEN_EVAL_MODEL` (default `Gen-Searcher-8B`). +3. **Browse summaries** — `BROWSE_SUMMARY_BASE_URL` + `BROWSE_SUMMARY_MODEL` with `BROWSE_GENERATE_ENGINE=vllm` (see [`.env.gen_image`](https://github.com/tulerfeng/Gen-Searcher/blob/main/Gen-DeepResearch-RL/rllm/.env.gen_image)). +4. **Image generation** — local FastAPI adapter at `QWEN_EDIT_APP_URL` (default `http://127.0.0.1:8765`), compatible with `call_qwen_edit_to_generate_image` in upstream `gen_image_deepresearch_reward.py`. 
+ +## Space secrets / environment + +Configure in the Space **Settings → Variables and secrets** (or a mounted `.env.gen_image`): + +| Variable | Purpose | +|----------|---------| +| `SERPER_KEY_ID` | Serper API key ([serper.dev](https://serper.dev)) | +| `JINA_API_KEYS` | Jina reader key for `r.jina.ai` | +| `OPENAI_BASE_URL` | OpenAI-compatible base URL for GenSearcher-8B (e.g. `https://.../v1`) | +| `OPENAI_API_KEY` | API key for that endpoint (use `EMPTY` if unused) | +| `GEN_EVAL_MODEL` | Served model name (default `Gen-Searcher-8B`) | +| `BROWSE_SUMMARY_BASE_URL` | OpenAI-compatible base for Qwen3-VL browse summarizer | +| `BROWSE_SUMMARY_MODEL` | Model id (e.g. `Qwen3-VL-30B-A3B-Instruct`) | +| `BROWSE_SUMMARY_API_KEY` | Key for browse server (`EMPTY` if none) | +| `BROWSE_GENERATE_ENGINE` | Set to `vllm` for OpenAI-compatible servers | + +If the FireRed adapter runs **inside** this container (default), you usually do **not** need to set `QWEN_EDIT_APP_URL` (the entrypoint sets it to `http://127.0.0.1:8765`). + +See [`dotenv.example`](./dotenv.example) for a full template. + +## Hardware + +- **Minimum practical:** 1× GPU for FireRed + Gradio, with **external** vLLM endpoints for GenSearcher and browse (set `START_VLLM_GENSEARCHER=0`, `START_VLLM_BROWSE=0` — defaults). +- **Full local (as in upstream scripts):** multiple GPUs — enable `START_VLLM_GENSEARCHER=1`, `START_VLLM_BROWSE=1`, and set `GENSEARCHER_CUDA_VISIBLE_DEVICES`, `BROWSE_CUDA_VISIBLE_DEVICES`, `FIRERED_CUDA_VISIBLE_DEVICES` to disjoint GPU indices. + +## Local build + +```bash +cd hf-space +docker build -t gensearcher-firered . +docker run --gpus all -p 7860:7860 --env-file .env.gen_image gensearcher-firered +``` + +## Deploy to your Hugging Face account + +```bash +hf auth login +hf repos create JSCPPProgrammer/gensearcher-firered --type space --sdk docker --private +# from hf-space/ +hf upload JSCPPProgrammer/gensearcher-firered . . +``` + +Then set Space GPU and secrets in the Hub UI. 
+ +## References + +- [Gen-Searcher](https://github.com/tulerfeng/Gen-Searcher) +- [GenSearcher/Gen-Searcher-8B](https://huggingface.co/GenSearcher/Gen-Searcher-8B) +- [FireRed-Image-Edit-1.1](https://huggingface.co/FireRedTeam/FireRed-Image-Edit-1.1) diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..24517a19856b4023fe067aba8e72c3fa470317cf --- /dev/null +++ b/app.py @@ -0,0 +1,108 @@ +""" +Hugging Face Space: official GenSearcher agent + FireRed-Image-Edit-1.1 generation. + +Set PYTHONPATH to include vendor/rllm (see Dockerfile / README). +""" +from __future__ import annotations + +import io +import json +import os +from pathlib import Path + +import gradio as gr +from PIL import Image + +from space_gen import run_sync + + +def _trajectory_to_markdown(trajectory: list) -> str: + if not trajectory: + return "_No messages_" + parts = [] + for i, m in enumerate(trajectory): + role = m.get("role", "") + content = m.get("content", "") + if isinstance(content, list): + content = json.dumps(content, ensure_ascii=False)[:8000] + parts.append(f"### {i + 1}. 
{role}\n\n```\n{content}\n```\n") + return "\n".join(parts) + + +def run_pipeline( + prompt: str, + temperature: float, + top_p: float, + research_only: bool, +): + prompt = (prompt or "").strip() + if not prompt: + return None, "Enter a non-empty prompt.", "", "" + + try: + result = run_sync( + prompt, + temperature=float(temperature), + top_p=float(top_p), + skip_generation=bool(research_only), + ) + except Exception as e: + import traceback + + return None, f"**Error**\n\n```\n{e}\n{traceback.format_exc()}\n```", "", "" + + traj_md = _trajectory_to_markdown(result.get("trajectory_messages") or []) + meta = { + "termination": result.get("termination"), + "gen_prompt": result.get("gen_prompt"), + "used_prompt": result.get("used_prompt"), + "reference_paths": result.get("reference_paths"), + "image_error": result.get("image_error"), + } + meta_txt = "```json\n" + json.dumps(meta, ensure_ascii=False, indent=2) + "\n```" + + png = result.get("image_png") + if png: + img = Image.open(io.BytesIO(png)).convert("RGB") + return img, meta_txt, traj_md, result.get("gen_prompt") or "" + + return None, meta_txt, traj_md, result.get("gen_prompt") or "" + + +with gr.Blocks(title="GenSearcher + FireRed") as demo: + gr.Markdown( + "## GenSearcher + FireRed-Image-Edit-1.1\n" + "Runs the **official** GenSearcher search/browse/image-search agent (vLLM), " + "then generates with **FireRed** via the same `/generate` API as the Qwen edit server.\n\n" + "**Required secrets:** `SERPER_KEY_ID`, `JINA_API_KEYS`, and vLLM endpoints for " + "`OPENAI_BASE_URL` + `BROWSE_SUMMARY_BASE_URL` (see README)." 
+ ) + with gr.Row(): + prompt = gr.Textbox( + label="Image task / prompt", + lines=4, + placeholder="Describe the image you want, including any real-world facts to verify.", + ) + with gr.Row(): + temperature = gr.Slider(0.0, 1.5, value=0.6, label="Temperature") + top_p = gr.Slider(0.0, 1.0, value=0.9, label="Top-p") + research_only = gr.Checkbox( + label="Research only (no FireRed generation)", + value=False, + ) + run_btn = gr.Button("Run", variant="primary") + with gr.Row(): + out_image = gr.Image(label="Generated image", type="pil") + out_meta = gr.Markdown(label="Run metadata") + out_traj = gr.Markdown(label="Trajectory (sanitized)") + out_gen_prompt = gr.Textbox(label="gen_prompt (from agent)", lines=6) + + run_btn.click( + fn=run_pipeline, + inputs=[prompt, temperature, top_p, research_only], + outputs=[out_image, out_meta, out_traj, out_gen_prompt], + ) + +if __name__ == "__main__": + demo.queue(default_concurrency_limit=1) + demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("GRADIO_SERVER_PORT", "7860"))) diff --git a/dotenv.example b/dotenv.example new file mode 100644 index 0000000000000000000000000000000000000000..15e8c74219acdce084be2473bedb428ce26c4b0a --- /dev/null +++ b/dotenv.example @@ -0,0 +1,31 @@ +# Copy to .env.gen_image in Space secrets or mount. See README. 
+ +# GenSearcher agent (OpenAI-compatible vLLM) +export OPENAI_API_KEY="EMPTY" +export OPENAI_BASE_URL="http://127.0.0.1:8002/v1" +export GEN_EVAL_MODEL="Gen-Searcher-8B" + +# FireRed adapter (this Space sets automatically if START_FIRERED_API=1) +export QWEN_EDIT_APP_URL="http://127.0.0.1:8765" +export QWEN_EDIT_APP_PATH="/generate" + +# Serper + Jina (required for official tools) +export SERPER_KEY_ID="" +export JINA_API_KEYS="" +export TEXT_SEARCH_API_BASE_URL="https://google.serper.dev/search" +export IMAGE_SEARCH_API_BASE_URL="https://google.serper.dev/images" +export IMAGE_SEARCH_SAVE_DIR="/tmp/cached_images" + +# Browse summarization (vLLM OpenAI-compatible) +export BROWSE_GENERATE_ENGINE="vllm" +export BROWSE_SUMMARY_BASE_URL="http://127.0.0.1:8003/v1" +export BROWSE_SUMMARY_API_KEY="EMPTY" +export BROWSE_SUMMARY_MODEL="Qwen3-VL-30B-A3B-Instruct" + +export MAX_LLM_CALL_PER_RUN=9 +export GEN_MAX_NEW_TOKENS_PER_TURN=4096 +export GEN_IMAGE_TIMEOUT=1800 + +# Optional: launch local vLLM inside the container (needs extra GPUs) +# export START_VLLM_GENSEARCHER=1 +# export START_VLLM_BROWSE=1 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..0a0b767f24188f379721339ca0e441165ee48519 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,6 @@ +# Install rllm from vendored tree first (see Dockerfile). 
+diffusers>=0.31.0 +accelerate>=0.26.0 +gradio>=4.44.0 +tiktoken>=0.7.0 +uvicorn[standard]>=0.30.0 diff --git a/scripts/entrypoint.sh b/scripts/entrypoint.sh new file mode 100644 index 0000000000000000000000000000000000000000..66b71057c3d8d235bbe8d801f25f6aa33a361e6e --- /dev/null +++ b/scripts/entrypoint.sh @@ -0,0 +1,77 @@ +#!/usr/bin/env bash +set -euo pipefail +cd /app + +export PYTHONPATH="/app/vendor/rllm:${PYTHONPATH:-}" + +# Optional: load Space secrets copied to this path +if [[ -f /app/.env.gen_image ]]; then + set -a + # shellcheck source=/dev/null + source /app/.env.gen_image + set +a +fi + +wait_http() { + local url=$1 + local name=$2 + local max_attempts=${3:-90} + local i=0 + echo "[entrypoint] Waiting for ${name} (${url})..." + until curl -sf "$url" >/dev/null 2>&1; do + i=$((i + 1)) + if [[ $i -ge $max_attempts ]]; then + echo "[entrypoint] Timeout waiting for ${name}" + exit 1 + fi + sleep 2 + done + echo "[entrypoint] ${name} is up." +} + +# Defaults: only FireRed + Gradio in-container. Point OPENAI_BASE_URL / BROWSE_SUMMARY_BASE_URL +# to your vLLM (or other OpenAI-compatible) endpoints via Space secrets. 
+ +# --- Optional local vLLM: GenSearcher-8B (OpenAI-compatible) --- +if [[ "${START_VLLM_GENSEARCHER:-0}" == "1" ]]; then + CUDA_VISIBLE_DEVICES="${GENSEARCHER_CUDA_VISIBLE_DEVICES:-0}" \ + vllm serve "${GENSEARCHER_MODEL_ID:-GenSearcher/Gen-Searcher-8B}" \ + --host 0.0.0.0 \ + --port 8002 \ + --tensor-parallel-size "${GENSEARCHER_TP:-1}" \ + --gpu-memory-utilization "${VLLM_GPU_MEMORY_UTIL:-0.85}" \ + --served-model-name "${GEN_EVAL_MODEL:-Gen-Searcher-8B}" \ + --max-model-len "${GENSEARCHER_MAX_MODEL_LEN:-65536}" \ + --no-enable-prefix-caching & + wait_http "http://127.0.0.1:8002/v1/models" "GenSearcher vLLM" + export OPENAI_BASE_URL="${OPENAI_BASE_URL:-http://127.0.0.1:8002/v1}" +fi + +# --- Optional local vLLM: browse summarization (Qwen3-VL) --- +if [[ "${START_VLLM_BROWSE:-0}" == "1" ]]; then + export BROWSE_GENERATE_ENGINE=vllm + CUDA_VISIBLE_DEVICES="${BROWSE_CUDA_VISIBLE_DEVICES:-1}" \ + vllm serve "${BROWSE_MODEL_ID:-Qwen/Qwen3-VL-30B-A3B-Instruct}" \ + --host 0.0.0.0 \ + --port 8003 \ + --tensor-parallel-size "${BROWSE_TP:-1}" \ + --gpu-memory-utilization "${VLLM_GPU_MEMORY_UTIL:-0.85}" \ + --served-model-name "${BROWSE_SUMMARY_MODEL:-Qwen3-VL-30B-A3B-Instruct}" \ + --max-model-len "${BROWSE_MAX_MODEL_LEN:-65536}" \ + --mm-processor-cache-gb 0 \ + --no-enable-prefix-caching & + wait_http "http://127.0.0.1:8003/v1/models" "Browse-summary vLLM" + export BROWSE_SUMMARY_BASE_URL="${BROWSE_SUMMARY_BASE_URL:-http://127.0.0.1:8003/v1}" +fi + +# --- FireRed adapter (GenSearcher /generate contract) --- +if [[ "${START_FIRERED_API:-1}" == "1" ]]; then + CUDA_VISIBLE_DEVICES="${FIRERED_CUDA_VISIBLE_DEVICES:-0}" \ + python -m uvicorn services.firered_generate:app --host 0.0.0.0 --port 8765 & + wait_http "http://127.0.0.1:8765/health" "FireRed API" 120 + export QWEN_EDIT_APP_URL="${QWEN_EDIT_APP_URL:-http://127.0.0.1:8765}" +else + echo "[entrypoint] START_FIRERED_API=0 — use external QWEN_EDIT_APP_URL for generation." 
+fi + +exec python app.py diff --git a/scripts/verify_env.py b/scripts/verify_env.py new file mode 100644 index 0000000000000000000000000000000000000000..783411be5a984db7fc51cf35931f31955061701a --- /dev/null +++ b/scripts/verify_env.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 +"""Print which GenSearcher Space env vars are set (never print secret values).""" +from __future__ import annotations + +import os + +CHECKS = [ + ("SERPER_KEY_ID", True), + ("JINA_API_KEYS", True), + ("OPENAI_BASE_URL", True), + ("GEN_EVAL_MODEL", False), + ("OPENAI_API_KEY", False), + ("BROWSE_SUMMARY_BASE_URL", True), + ("BROWSE_SUMMARY_MODEL", False), + ("BROWSE_SUMMARY_API_KEY", False), + ("BROWSE_GENERATE_ENGINE", False), + ("QWEN_EDIT_APP_URL", False), + ("QWEN_EDIT_APP_PATH", False), +] + + +def main() -> None: + missing_required = [] + for name, required in CHECKS: + val = os.environ.get(name, "").strip() + ok = bool(val) + status = "OK" if ok else ("MISSING" if required else "optional empty") + print(f"{name}: {status}") + if required and not ok: + missing_required.append(name) + if missing_required: + print("\nSet required variables (see README / dotenv.example):", ", ".join(missing_required)) + raise SystemExit(1) + print("\nRequired variables present.") + + +if __name__ == "__main__": + main() diff --git a/services/__init__.py b/services/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..500287872c4ce3d1a2bd9ec53e26bee64a8fa650 --- /dev/null +++ b/services/__init__.py @@ -0,0 +1 @@ +# FireRed GenSearcher adapter service package. diff --git a/services/firered_generate.py b/services/firered_generate.py new file mode 100644 index 0000000000000000000000000000000000000000..b9ce3144a07808b682109ced2b058fb3afbc1d94 --- /dev/null +++ b/services/firered_generate.py @@ -0,0 +1,143 @@ +""" +FireRed-Image-Edit HTTP service matching GenSearcher Qwen /generate contract. 
+ +Request/response aligned with qwen_image_api_server and gen_image_deepresearch_reward.call_qwen_edit_to_generate_image. +""" +from __future__ import annotations + +import argparse +import base64 +import io +import os +import re +from typing import List, Optional + +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel +from PIL import Image + +app = FastAPI(title="FireRed-Image-Edit GenSearcher adapter") + +_pipe = None + + +def _load_image_from_url_or_data(url_or_data: str) -> Image.Image: + if url_or_data.startswith("data:image/"): + m = re.match(r"data:image/[^;]+;base64,(.*)", url_or_data, re.DOTALL) + if not m: + raise ValueError("Invalid data URL") + raw = base64.b64decode(m.group(1)) + return Image.open(io.BytesIO(raw)).convert("RGB") + raise ValueError("Only data:image/...;base64,... URLs are supported in Space adapter") + + +class GenerateRequest(BaseModel): + image_urls: Optional[List[str]] = None + prompt: str + seed: int = 0 + true_cfg_scale: float = 4.0 + negative_prompt: str = " " + num_inference_steps: int = 40 + guidance_scale: float = 1.0 + num_images_per_prompt: int = 1 + + +def get_pipeline(): + global _pipe + if _pipe is None: + import torch + from diffusers import QwenImageEditPlusPipeline + + model_path = os.environ.get( + "FIRERED_MODEL_ID", "FireRedTeam/FireRed-Image-Edit-1.1" + ) + dtype = torch.bfloat16 + _pipe = QwenImageEditPlusPipeline.from_pretrained( + model_path, + torch_dtype=dtype, + ) + _pipe.to("cuda") + _pipe.set_progress_bar_config(disable=True) + return _pipe + + +@app.get("/health") +def health(): + return {"status": "ok", "model_loaded": _pipe is not None} + + +@app.post("/generate") +def generate(request: GenerateRequest): + try: + pipe = get_pipeline() + except Exception as e: + raise HTTPException(status_code=503, detail=f"Model not ready: {e}") + + import torch + + images: List[Image.Image] = [] + if request.image_urls: + for u in request.image_urls[:3]: + if u: + try: + 
images.append(_load_image_from_url_or_data(u)) + except Exception as ex: + raise HTTPException( + status_code=400, detail=f"Bad image_urls entry: {ex}" + ) + + gen = torch.Generator(device="cuda").manual_seed(int(request.seed)) + + try: + with torch.inference_mode(): + if not images: + # Text-only: FireRed is edit-focused; synthesize a neutral canvas for conditioning-free edit + blank = Image.new("RGB", (1024, 1024), (240, 240, 240)) + out = pipe( + image=[blank], + prompt=request.prompt, + generator=gen, + true_cfg_scale=float(request.true_cfg_scale), + negative_prompt=request.negative_prompt or " ", + num_inference_steps=int(request.num_inference_steps), + guidance_scale=float(request.guidance_scale), + num_images_per_prompt=int(request.num_images_per_prompt), + ) + else: + out = pipe( + image=images, + prompt=request.prompt, + generator=gen, + true_cfg_scale=float(request.true_cfg_scale), + negative_prompt=request.negative_prompt or " ", + num_inference_steps=int(request.num_inference_steps), + guidance_scale=float(request.guidance_scale), + num_images_per_prompt=int(request.num_images_per_prompt), + ) + pil = out.images[0] + except Exception as e: + import traceback + + return { + "success": False, + "message": f"{e}\n{traceback.format_exc()}", + } + + buf = io.BytesIO() + pil.save(buf, format="PNG") + b64 = base64.b64encode(buf.getvalue()).decode("utf-8") + return {"success": True, "image": b64} + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--host", default="0.0.0.0") + parser.add_argument("--port", type=int, default=8765) + args = parser.parse_args() + import uvicorn + + uvicorn.run(app, host=args.host, port=args.port) + + +if __name__ == "__main__": + main() diff --git a/space_gen.py b/space_gen.py new file mode 100644 index 0000000000000000000000000000000000000000..e2edc2cbc0f4782ce5e4baed7166266c7c4e2bbc --- /dev/null +++ b/space_gen.py @@ -0,0 +1,214 @@ +""" +Run one official GenSearcher trajectory (OpenAI-compatible vLLM) 
then call FireRed /generate adapter. +""" +from __future__ import annotations + +import asyncio +import base64 +import json +import os +import uuid +from typing import Any, Dict, List, Optional, Tuple + +import requests + +from rllm.engine.agent_workflow_engine import AgentWorkflowEngine +from rllm.engine.rollout import OpenAIEngine +from vision_deepresearch_async_workflow.gen_image_deepresearch_tools_executor import ( + create_gen_image_tools, +) +from vision_deepresearch_async_workflow.gen_image_deepresearch_workflow import ( + GenImageDeepResearchWorkflow, +) + + +def _sanitize_content(msg: dict) -> dict: + out = {"role": msg.get("role", ""), "content": ""} + content = msg.get("content", "") + if isinstance(content, str): + out["content"] = content[:50000] + ("..." if len(content) > 50000 else "") + else: + out["content"] = str(content)[:50000] + if "images" in msg: + out["images"] = [ + (p if isinstance(p, str) else p.get("image", ""))[:200] + for p in (msg["images"] or [])[:10] + ] + return out + + +def get_effective_prompt_and_images( + user_prompt: str, prediction: dict +) -> Tuple[str, List[str]]: + """Match eval/gen_image_from_results.get_effective_prompt_and_images.""" + gen_prompt = (prediction.get("gen_prompt") or "").strip() + paths: List[str] = [] + for r in prediction.get("reference_images") or []: + if not isinstance(r, dict): + continue + p = (r.get("local_path") or "").strip() + if p and os.path.exists(p): + paths.append(p) + if gen_prompt and paths: + return gen_prompt, paths + return user_prompt, [] + + +def _parse_qwen_edit_base_url() -> str: + raw = os.environ.get("QWEN_EDIT_APP_URL", "http://127.0.0.1:8765").strip() + try: + urls = json.loads(raw) + if isinstance(urls, list) and urls: + return str(urls[0]).rstrip("/") + except json.JSONDecodeError: + pass + return raw.rstrip("/").strip('"').strip("'") + + +def call_generate_api( + base_url: str, + path: str, + prompt: str, + image_paths: List[str], + timeout: int = 1800, +) -> bytes: + if not 
path.startswith("/"): + path = "/" + path + url = base_url.rstrip("/") + path + ref_images_b64: List[str] = [] + for img_path in image_paths[:3]: + with open(img_path, "rb") as f: + ref_images_b64.append(base64.b64encode(f.read()).decode("utf-8")) + image_urls = ( + [f"data:image/jpeg;base64,{b}" for b in ref_images_b64] + if ref_images_b64 + else None + ) + payload = { + "image_urls": image_urls, + "prompt": prompt, + "seed": int(os.environ.get("GEN_SEED", "0")), + "true_cfg_scale": float(os.environ.get("GEN_TRUE_CFG_SCALE", "4.0")), + "negative_prompt": os.environ.get("GEN_NEGATIVE_PROMPT", " "), + "num_inference_steps": int(os.environ.get("GEN_NUM_INFERENCE_STEPS", "40")), + "guidance_scale": float(os.environ.get("GEN_GUIDANCE_SCALE", "1.0")), + "num_images_per_prompt": 1, + } + r = requests.post(url, json=payload, timeout=timeout) + r.raise_for_status() + result = r.json() + if not result.get("success"): + raise RuntimeError(result.get("message", str(result))) + img_b64 = result.get("image") or "" + if img_b64.startswith("data:image"): + img_b64 = img_b64.split(",", 1)[-1] + return base64.b64decode(img_b64) + + +async def run_gensearcher_then_generate( + user_prompt: str, + *, + temperature: float = 0.6, + top_p: float = 0.9, + skip_generation: bool = False, +) -> Dict[str, Any]: + sample_id = str(uuid.uuid4())[:8] + task = { + "id": sample_id, + "question": user_prompt, + "prompt": user_prompt, + "meta": {"source": "hf_space"}, + } + + model = os.environ.get("GEN_EVAL_MODEL", "Gen-Searcher-8B") + base_url = os.environ.get("OPENAI_BASE_URL", "http://127.0.0.1:8002/v1").rstrip("/") + if not base_url.endswith("/v1"): + base_url = base_url + "/v1" + api_key = os.environ.get("OPENAI_API_KEY", "EMPTY") + + rollout_engine = OpenAIEngine( + model=model, + base_url=base_url, + api_key=api_key, + max_prompt_length=int(os.environ.get("MAX_PROMPT_LENGTH", "64000")), + max_response_length=int(os.environ.get("MAX_RESPONSE_LENGTH", "64000")), + sampling_params={ + 
"temperature": temperature, + "top_p": top_p, + }, + ) + tools = create_gen_image_tools() + workflow_engine = AgentWorkflowEngine( + workflow_cls=GenImageDeepResearchWorkflow, + workflow_args={"tools": tools, "reward_function": None}, + rollout_engine=rollout_engine, + n_parallel_tasks=1, + retry_limit=2, + ) + await workflow_engine.initialize_pool() + try: + _, _, episode = await workflow_engine.process_task_with_retry( + task, sample_id, 0 + ) + finally: + workflow_engine.shutdown() + + info = episode.info or {} + messages = info.get("messages", []) + prediction = info.get("prediction", {}) if isinstance(info.get("prediction"), dict) else {} + termination = info.get("termination") or ( + episode.termination_reason.value + if getattr(episode, "termination_reason", None) + else "unknown" + ) + trajectory = [_sanitize_content(m) for m in messages] + + out: Dict[str, Any] = { + "termination": termination, + "trajectory_messages": trajectory, + "gen_prompt": prediction.get("gen_prompt", ""), + "prediction": prediction, + } + + if skip_generation: + out["image_png"] = None + out["image_error"] = None + return out + + if termination != "answer": + out["image_png"] = None + out["image_error"] = f"Agent did not finish with answer (termination={termination})" + return out + + eff_prompt, img_paths = get_effective_prompt_and_images(user_prompt, prediction) + gen_base = _parse_qwen_edit_base_url() + gen_path = os.environ.get("QWEN_EDIT_APP_PATH", "/generate") + timeout = int(os.environ.get("GEN_IMAGE_TIMEOUT", "1800")) + + try: + png_bytes = call_generate_api(gen_base, gen_path, eff_prompt, img_paths, timeout=timeout) + out["image_png"] = png_bytes + out["used_prompt"] = eff_prompt + out["reference_paths"] = img_paths + out["image_error"] = None + except Exception as e: + out["image_png"] = None + out["image_error"] = str(e) + return out + + +def run_sync( + user_prompt: str, + *, + temperature: float = 0.6, + top_p: float = 0.9, + skip_generation: bool = False, +) -> 
Dict[str, Any]: + return asyncio.run( + run_gensearcher_then_generate( + user_prompt, + temperature=temperature, + top_p=top_p, + skip_generation=skip_generation, + ) + ) diff --git a/tests/test_imports.py b/tests/test_imports.py new file mode 100644 index 0000000000000000000000000000000000000000..88926f8d38ded439a02d23cfb8c35cb0f48e6cd9 --- /dev/null +++ b/tests/test_imports.py @@ -0,0 +1,23 @@ +"""Smoke test: package layout and syntax (no GPU required).""" + + +def test_space_gen_importable(): + import importlib.util + from pathlib import Path + + root = Path(__file__).resolve().parents[1] + sg = root / "space_gen.py" + assert sg.exists() + spec = importlib.util.spec_from_file_location("space_gen", sg) + assert spec and spec.loader + # Do not execute full import (pulls torch/rllm); file must parse + src = sg.read_text(encoding="utf-8") + compile(src, str(sg), "exec") + + +def test_firered_service_parse(): + from pathlib import Path + + root = Path(__file__).resolve().parents[1] + fp = root / "services" / "firered_generate.py" + compile(fp.read_text(encoding="utf-8"), str(fp), "exec") diff --git a/vendor/rllm/.env.gen_image b/vendor/rllm/.env.gen_image new file mode 100644 index 0000000000000000000000000000000000000000..ddaa03459d18eb0f263742a3c967002a9aaa0d48 --- /dev/null +++ b/vendor/rllm/.env.gen_image @@ -0,0 +1,68 @@ +# Gen Image environment variables +# Usage: source .env.gen_image + +# ===== Image generation service for train/eval: qwen_image | nano ===== +# qwen_image: Qwen Edit HTTP service (see below); nano: Nano/Gemini image generation API (aligned with gen_image_from_results nano) +export GEN_IMAGE_SERVICE="qwen_image" + +# ===== Qwen Edit image generation service endpoints (used when GEN_IMAGE_SERVICE=qwen_image) ===== +# vLLM deployed service host/IP placeholder (IPs removed for open-source). 
+# Use multiple urls for acceleration +export QWEN_EDIT_APP_URL='["http://xxxx:8001", "http://xxxx:8001"]' +export QWEN_EDIT_APP_PATH="/generate" + +# ===== Nano image generation API (used when GEN_IMAGE_SERVICE=nano; aligned with eval gen_image_from_results nano) ===== +# export GEN_IMAGE_NANO_API_KEY="" +# export GEN_IMAGE_NANO_MODEL="gemini-3-pro-image-preview" +# # Total timeout (seconds): one Nano call (including retries) must not exceed this duration +# export GEN_IMAGE_NANO_TIMEOUT=1200 +# # Max retries: retry on submit/poll failure up to this count +# export GEN_IMAGE_NANO_MAX_TRY=100 +# # Max poll time per attempt (seconds): after task_id, poll imageGenerateQuery until status=1 or timeout +# export GEN_IMAGE_NANO_MAX_POLL=300 + +# ===== Reward scoring aligned with worldgen eval (GPT-4.1 + same prompt/overall formula) ===== +# OpenAI API key for gpt-4.1 scoring +export GEN_REWARD_API_KEY="" +export GEN_REWARD_API_BASE_URL="https://api.openai.com/v1" +export GEN_REWARD_MODEL="gpt-4.1" +# Text reward coefficient in [0,1]. 
+# Final reward = (1-GEN_REWARD_TEXT_COEF)*image_reward + GEN_REWARD_TEXT_COEF*text_reward +# 0 means image reward only (no extra text scoring call) +export GEN_REWARD_TEXT_COEF=0.5 + +# ===== Max LLM calls per agent run (shared by train/eval) ===== +export MAX_LLM_CALL_PER_RUN=9 + +# ===== Image generation settings ===== +export GEN_IMAGE_OUTPUT_DIR="./output_images" +export GEN_IMAGE_TIMEOUT=1800 +export GEN_MIN_INPUT_IMAGES=1 +export GEN_MAX_INPUT_IMAGES=4 +export GEN_API_CONCURRENCY=32 +# Per-turn generation token cap (rollout only; training still uses data.max_response_length) +export GEN_MAX_NEW_TOKENS_PER_TURN=4096 +# export QWEN_VL_MAX_PIXELS=262144 +export QWEN_VL_MAX_PIXELS=160000 + +# ===== Web tools (text + image search; default Serper endpoints) ===== +export TEXT_SEARCH_API_BASE_URL="https://google.serper.dev/search" +export IMAGE_SEARCH_API_BASE_URL="https://google.serper.dev/images" +# API key sent as X-API-KEY (Serper: get key from serper.dev) +export SERPER_KEY_ID="" + +# ===== Jina API (for web browsing) ===== +export JINA_API_KEYS="" + +# ===== Local cache directory for image search (IMAGE_SEARCH_SAVE_DIR) ===== +export IMAGE_SEARCH_SAVE_DIR="./cached_images" + +# ===== Browse summary model (Qwen3 via vLLM; shared by train/eval) ===== +# vLLM deployed service host/IP (IPs removed for open-source) +export BROWSE_SUMMARY_BASE_URL="http://xxx:8001/v1" +export BROWSE_SUMMARY_MODEL="Qwen3-VL-30B-A3B-Instruct" +export BROWSE_GENERATE_ENGINE="vllm" + +# ===== Other optional settings ===== +export BROWSE_RANDOM_SLEEP="0" + diff --git a/vendor/rllm/.github/workflows/pre-commit.yml b/vendor/rllm/.github/workflows/pre-commit.yml new file mode 100644 index 0000000000000000000000000000000000000000..db0a804ecc87052cf875ca85b2cc931bd08ecae0 --- /dev/null +++ b/vendor/rllm/.github/workflows/pre-commit.yml @@ -0,0 +1,35 @@ +name: Pre-commit + +on: + pull_request: + push: + branches: + - main + - v0.* + +jobs: + pre-commit: + runs-on: ubuntu-latest + steps: + - 
uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.11' + + - name: Cache pre-commit + uses: actions/cache@v3 + with: + path: ~/.cache/pre-commit + key: pre-commit-${{ hashFiles('.pre-commit-config.yaml') }} + restore-keys: | + pre-commit- + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install pre-commit + + - name: Run pre-commit + run: pre-commit run --all-files diff --git a/vendor/rllm/.gitignore b/vendor/rllm/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..a9c78451eaca180da8b5a675666dc5599f92a9d5 --- /dev/null +++ b/vendor/rllm/.gitignore @@ -0,0 +1,214 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +*.whl +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# UV +# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +#uv.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. 
+# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# PyPI configuration file +.pypirc + +# DS_Store +.DS_Store + +# Ignore parquet files. +data/* + +# Ignore verl script outputs. 
+**/outputs/ +**/wandb/ +**/checkpoints/ +tmp/ + +# Ignore debug logs and run result logs +logs/ +rllm/*.json + +# Ignore the big datasets +rllm/data/test/ +rllm/data/train/ + +rllm/data/datasets/ +rllm/registry/ + +# Coding assistant local rules ignore +.cursor/rules/* +CLAUDE.md + +# Strands outputs ignore +examples/strands_outputs/* +strands_outputs/* +examples/strands/strands_outputs/* + +# Deepresearch outputs ignore +examples/deepresearch/deepresearch_outputs/* +deepresearch_outputs/* +examples/deepresearch/hle_outputs/* +*/hle_outputs/* +examples/deepresearch/HLE_OUTPUT_EVOLUTION.md + +# Until we have a good way to handle cuda-version specific pkgs, we ignore uv.lock +uv.lock \ No newline at end of file diff --git a/vendor/rllm/.pre-commit-config.yaml b/vendor/rllm/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e56ed52cfeb780b93585a16e8116b5c5e856a524 --- /dev/null +++ b/vendor/rllm/.pre-commit-config.yaml @@ -0,0 +1,9 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: "v0.11.4" + hooks: + - id: ruff + args: ["--fix", "--show-fixes", "--output-format=full"] + exclude: ^.*\.(ipynb)$|^verl/.*$ + - id: ruff-format + exclude: ^verl/.*$ \ No newline at end of file diff --git a/vendor/rllm/.readthedocs.yaml b/vendor/rllm/.readthedocs.yaml new file mode 100644 index 0000000000000000000000000000000000000000..171e2e26f4cea9c57b9830d7a8f82b5f6dce8b38 --- /dev/null +++ b/vendor/rllm/.readthedocs.yaml @@ -0,0 +1,27 @@ +# Read the Docs configuration file for rLLM +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +version: 2 + +# Set the OS, Python version and other tools you might need +build: + os: ubuntu-22.04 + tools: + python: "3.10" + jobs: + post_create_environment: + # Install poetry + - pip install --upgrade pip + post_install: + # Install any additional system dependencies if needed + - echo "Build environment ready" + +# Build documentation in the "docs/" directory 
with MkDocs +mkdocs: + configuration: mkdocs.yml + fail_on_warning: false + +# Python configuration +python: + install: + - requirements: docs/requirements.txt \ No newline at end of file diff --git a/vendor/rllm/Dockerfile b/vendor/rllm/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..c852ea28f700d4689fa11791f2ad47e7156da34d --- /dev/null +++ b/vendor/rllm/Dockerfile @@ -0,0 +1,26 @@ +FROM verlai/verl:vllm011.latest + +WORKDIR /workspace + +RUN git clone https://github.com/volcengine/verl.git +RUN cd verl && \ + git checkout v0.6.1 && \ + pip install -e . + +# 2) Install rllm (editable) +RUN git clone https://github.com/rllm-org/rllm.git +RUN cd rllm && \ + pip install -e . + +# 3) Install playwright +RUN pip install playwright && \ + playwright install chromium && \ + playwright install-deps + +CMD ["/bin/bash"] + +# Docker Usage +# docker build -t rllm . +# docker create --runtime=nvidia --gpus all --net=host --shm-size="10g" --cap-add=SYS_ADMIN -v .:/workspace/rllm -v /tmp:/tmp --name rllm-container rllm sleep infinity +# docker start rllm-container +# docker exec -it rllm-container bash diff --git a/vendor/rllm/LICENSE b/vendor/rllm/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..9b5e4019df618fc47d429529c369f4903142669b --- /dev/null +++ b/vendor/rllm/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/vendor/rllm/README.md b/vendor/rllm/README.md new file mode 100644 index 0000000000000000000000000000000000000000..39877f133e50351219fdd289ddedadb77814e4bb --- /dev/null +++ b/vendor/rllm/README.md @@ -0,0 +1,126 @@ +
and tags.
+
+IMPORTANT: Any output you want to see MUST be printed to standard output using the print() function.
+
+Example of a correct call:
+
+import numpy as np
+# Your code here
+print(f"The result is: {np.mean([1,2,3])}")
+
+ and tags.
+
+IMPORTANT: Any output you want to see MUST be printed to standard output using the print() function.
+
+Example of a correct call:
+
+import numpy as np
+# Your code here
+print(f"The result is: {np.mean([1,2,3])}")
+
+" in content:
+ pass
+ elif '"name":' in content:
+ try:
+ tool_text = content.split("")[1].split(
+ " "
+ )[0]
+ tool_data = json5.loads(tool_text)
+ tool_name = tool_data.get("name", "Unknown")
+ if "arguments" in tool_data:
+ args_str = str(tool_data["arguments"])
+ pass
+ else:
+ pass
+ except Exception:
+ pass
+ else:
+ pass
+
+ # Clean up content if it contains tool_response
+ if "" in content:
+ pos = content.find("")
+ content = content[:pos]
+
+ # Only XML ReAct tool calls are supported.
+ if "" in content and " " in content:
+ # ReAct text format path
+ assistant_message = {
+ "role": "assistant",
+ "content": content.strip(),
+ "step_error": False,
+ }
+ messages.append(assistant_message)
+ tool_error = False
+
+ tool_call_text = content.split("")[1].split(" ")[
+ 0
+ ]
+ # Special handling for Python code (match original logic)
+ if "python" in tool_call_text.lower():
+ try:
+ # Extract code from the original content (not just tool_call_text)
+ code_raw = (
+ content.split("")[1]
+ .split(" ")[0]
+ .split("")[1]
+ .split("")[0]
+ .strip()
+ )
+ result = await self.execute_python(code_raw)
+ if isinstance(result, str) and result.startswith(
+ (
+ "Python execution error:",
+ "PythonInterpreter tool not available",
+ "PythonInterpreter tool is not callable",
+ )
+ ):
+ tool_error = True
+ except Exception:
+ result = (
+ "[Python Interpreter Error]: Python code formatting error."
+ )
+ tool_error = True
+ else:
+ try:
+ # Parse JSON tool call
+ tool_call = json5.loads(tool_call_text)
+ tool_name = tool_call.get("name", "")
+ tool_args = tool_call.get("arguments", {})
+ if tool_name == "crop_and_search":
+ tool_args["image_id"] = image_path
+ result = await self.custom_call_tool(tool_name, tool_args)
+ except Exception:
+ result = "[Json Parse Error]: Tool call is not a valid JSON."
+ tool_error = True
+
+ if tool_error:
+ assistant_message["step_error"] = True
+
+ # Add tool response in ReAct format
+ tool_response = f"\n{result}\n "
+ messages.append({"role": "user", "content": tool_response})
+
+ # Check for final answer AFTER processing tools
+ # This allows o3 to execute tools even when it includes answer in same message
+ elif "" in content and " " in content:
+ messages.append(
+ {
+ "role": "assistant",
+ "content": content.strip(),
+ "step_error": False,
+ }
+ )
+ prediction = content.split("")[1].split(" ")[0].strip()
+ termination = "answer"
+ consecutive_bad_steps = 0
+ break
+
+ # Priority 3: No tool call and answer, just reasoning or format error
+ else:
+ is_repetitive = analyze_repetition_ngram(content)
+ is_overlong = count_words(content) > 2500
+ if is_repetitive and is_overlong:
+ repetition_count += 1
+ print(f"Round {round}: Content repetition detected (count: {repetition_count}/{MAX_REPEAT_TURN})")
+
+ if repetition_count >= MAX_REPEAT_TURN:
+ final_instruction = {
+ "role": "user",
+ "content": f"Based on all the information above, please provide your best answer now in the format: your final thinking \nyour answer "
+ }
+
+ messages.append(final_instruction)
+
+ print(f"Round {round}: Content repetition threshold reached, requesting final answer.")
+
+ try:
+ response = await self.call_server(messages)
+ final_content = response.text if hasattr(
+ response, "text") and response.text else ""
+ messages.append(
+ {"role": "assistant", "content": final_content.strip()})
+
+ if "" in final_content and " " in final_content:
+ prediction = final_content.split(
+ "")[1].split(" ")[0].strip()
+ termination = "answer"
+ else:
+ prediction = final_content.strip() if final_content.strip(
+ ) else "No answer found due to content repetition."
+ termination = "answer"
+ except Exception as exc:
+ prediction = "No answer found due to content repetition and model failure."
+ termination = f"answer"
+
+ break
+ else:
+ # Repetition detected but below threshold, continue with format error handling
+ observation = f"Error: Content repetition detected. Invalid content format. Content must contain or tags. Let's try again."
+ messages.append(
+ {
+ "role": "assistant",
+ "content": content.strip(),
+ "step_error": True,
+ }
+ )
+ messages.append(
+ {"role": "user", "content": observation})
+ else:
+ observation = "Error: Invalid content format. Content must contain or tags. Let's try again."
+ messages.append(
+ {
+ "role": "assistant",
+ "content": content.strip(),
+ "step_error": True,
+ }
+ )
+ messages.append({"role": "user", "content": observation})
+
+ # Determine whether another round is feasible
+ if num_llm_calls_available <= 0 and "" not in content:
+ # Round limit reached, give model one final chance to answer
+ final_instruction = {
+ "role": "user",
+ "content": f"You have reached the maximum number of reasoning rounds ({self.max_llm_calls}). Based on all the information gathered so far, please provide your best final answer now in the format: your final thinking \nyour answer "
+ }
+
+ messages.append(final_instruction)
+
+ print(f"Round {round}: Round limit reached, requesting final answer")
+
+ try:
+ response = await self.call_server(messages)
+ final_content = response.text if hasattr(
+ response, "text") and response.text else ""
+ messages.append(
+ {"role": "assistant", "content": final_content.strip()})
+
+ if "" in final_content and " " in final_content:
+ prediction = final_content.split(
+ "")[1].split(" ")[0].strip()
+ termination = "answer"
+ else:
+ prediction = final_content.strip() if final_content.strip(
+ ) else f"No answer found after {self.max_llm_calls} rounds."
+ termination = "answer"
+ except Exception as exc:
+ prediction = f"No answer found after {self.max_llm_calls} rounds and model failure."
+ termination = f"round limit reached, model failed: {str(exc)}"
+
+ return self._build_result(
+ question=question,
+ answer=answer,
+ messages=messages,
+ prediction=prediction,
+ termination=termination,
+ rounds=round,
+ start_time=start_time,
+ )
+
+ last_message_content = (
+ messages[-1].get("content", "") if isinstance(messages[-1], dict) else ""
+ )
+ if last_message_content and "" in last_message_content:
+ prediction = last_message_content.split("")[1].split(" ")[0]
+ termination = "answer"
+ else:
+ prediction = "No answer found."
+ termination = "answer not found"
+ if num_llm_calls_available == 0:
+ termination = "exceed available llm calls"
+
+ result = self._build_result(
+ question=question,
+ answer=answer,
+ messages=messages,
+ prediction=prediction,
+ termination=termination,
+ rounds=round,
+ start_time=start_time,
+ )
+
+ print("\n🏁 DeepResearch completed:")
+ print(f" Rounds: {round}")
+ print(f" Time: {result['time_taken']:.1f}s")
+ print(f" Termination: {termination}")
+ print(
+ " Token usage: prompt={prompt}, completion={completion}, max_prompt={max_prompt}".format(
+ prompt=self.total_prompt_tokens,
+ completion=self.total_completion_tokens,
+ max_prompt=self.max_prompt_tokens,
+ )
+ )
+ return result
+
async def custom_call_tool(self, tool_name: str, tool_args: dict, **kwargs) -> str:
    """
    Execute a tool call with the available tools.

    The tool registered under ``tool_name`` in ``self.tools`` is invoked with
    ``tool_args`` expanded as keyword arguments. A tool may either expose a
    ``call`` method or be directly callable, and either form may be
    synchronous or asynchronous: whatever the invocation returns is awaited
    when it is awaitable, so coroutine-returning tools are never stringified
    as ``<coroutine object ...>``.

    Args:
        tool_name: Name of the tool to call (key into ``self.tools``)
        tool_args: Keyword arguments to pass to the tool
        **kwargs: Accepted for interface compatibility; not used here

    Returns:
        Tool execution result as string. Problems are reported in-band as a
        string as well (unknown tool, non-callable tool, or an exception
        raised by the tool); this method does not raise for tool failures.
    """
    # Local import keeps this block self-contained.
    import inspect

    if tool_name not in self.tools:
        available_tools = list(self.tools.keys())
        return f"Tool {tool_name} not found. Available tools: {available_tools}"

    try:
        tool = self.tools[tool_name]
        if hasattr(tool, "call"):
            result = tool.call(**tool_args)
        elif callable(tool):
            # Direct callable
            result = tool(**tool_args)
        else:
            result = f"Tool {tool_name} is not callable"

        # Covers async `call` methods, async plain callables, and sync
        # wrappers that hand back a coroutine/future.
        if inspect.isawaitable(result):
            result = await result

        return str(result)

    except Exception as e:
        return f"Error calling tool {tool_name}: {e}"
+
async def execute_python(self, code: str) -> str:
    """
    Execute Python code using the PythonInterpreter tool.

    Looks up ``self.tools["PythonInterpreter"]`` and invokes its ``call``
    method with ``code=code``. The return value is awaited when it is
    awaitable, so both async ``call`` methods and sync wrappers that return
    a coroutine are handled.

    Args:
        code: Python code to execute

    Returns:
        Execution result as string. Failures are returned in-band with one
        of these fixed prefixes, which the ReAct loop in this file matches
        with ``startswith`` to flag a step error, so they must not change:
        ``"Python execution error:"``, ``"PythonInterpreter tool not
        available"``, ``"PythonInterpreter tool is not callable"``.
    """
    # Local import keeps this block self-contained.
    import inspect

    if "PythonInterpreter" not in self.tools:
        return "PythonInterpreter tool not available"

    try:
        # Use the PythonInterpreter tool
        tool = self.tools["PythonInterpreter"]
        if not hasattr(tool, "call"):
            return "PythonInterpreter tool is not callable"

        result = tool.call(code=code)
        if inspect.isawaitable(result):
            result = await result
        return str(result)
    except Exception as e:
        return f"Python execution error: {e}"
+
def reset(self):
    """Clear per-task agent state (kept for compatibility with the rLLM workflow)."""
    # Both token usage counters restart from zero for every new task.
    self.total_prompt_tokens = self.total_completion_tokens = 0
+
async def run(
    self,
    question: str,
    answer: str = None,
    images: list = None,
    image_path: str = None,
    **kwargs,
) -> dict:
    """
    Public entry point: clear per-run state, then delegate to ``_run``.

    Args:
        question: Research question to answer
        answer: Ground truth answer (optional, for evaluation)
        images: Optional images, forwarded unchanged to ``_run``
        image_path: Optional image path, forwarded unchanged to ``_run``
        **kwargs: Extra keyword arguments, forwarded unchanged to ``_run``

    Returns:
        Result dictionary produced by ``_run``
    """
    # Token counters must start from zero for each new run.
    self.reset()
    forwarded = (question, answer, images, image_path)
    outcome = await self._run(*forwarded, **kwargs)
    return outcome
+
+
+DeepResearchAgent = MultiTurnReactAgent
diff --git a/vendor/rllm/eval/deepresearch_workflow.py b/vendor/rllm/eval/deepresearch_workflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..38ec572394bf7daacbbf058d5358ef2b3e40d9fc
--- /dev/null
+++ b/vendor/rllm/eval/deepresearch_workflow.py
@@ -0,0 +1,393 @@
+from io import BytesIO
+import asyncio
+import re
+from typing import Any, List, Optional
+
+from PIL import Image
+
+from eval.deepresearch_agent import DeepResearchAgent
+from rllm.agents.agent import Action, Episode, Step, Trajectory
+from rllm.engine.rollout import RolloutEngine
+from rllm.rewards.reward_fn import RewardFunction
+from rllm.workflows.workflow import TerminationReason, Workflow
+
+import base64
+
+
def _rgb_from_source(source: Any) -> Image.Image | None:
    """Open a file path or binary stream as an RGB image; return None on any failure."""
    try:
        # The context manager releases the underlying file handle promptly;
        # ``convert`` returns an independent, fully loaded copy.
        with Image.open(source) as opened:
            return opened.convert("RGB")
    except Exception:  # noqa: BLE001
        return None


def _rgb_from_bytes(raw: Any) -> Image.Image | None:
    """Decode raw encoded image bytes into an RGB image; return None on any failure."""
    try:
        stream = BytesIO(raw)
    except Exception:  # noqa: BLE001
        return None
    return _rgb_from_source(stream)


def _rgb_from_data_uri(data_uri: str) -> Image.Image | None:
    """Decode a ``data:image/...,<base64>`` URI into an RGB image; return None on any failure."""
    try:
        _, encoded = data_uri.split(",", 1)
        image_bytes = base64.b64decode(encoded)
    except Exception:  # noqa: BLE001
        return None
    return _rgb_from_bytes(image_bytes)


def as_pil_image(image: Any) -> Image.Image | None:
    """Best-effort conversion of several image representations to ``PIL.Image``.

    Accepted inputs:
      * a ``PIL.Image.Image`` (returned unchanged);
      * a string: either a ``data:image/`` URI or something ``Image.open`` accepts
        as a path;
      * a dict with non-None ``"bytes"``, or else the first of ``"data"``,
        ``"path"``, ``"url"`` whose value is a string (handled like a string input).

    Args:
        image: The candidate image object.

    Returns:
        An RGB ``PIL.Image.Image`` (or the original instance when one is passed
        in), or None when the input is unsupported or cannot be decoded. This
        function never raises for bad input.
    """
    if isinstance(image, Image.Image):
        return image

    if isinstance(image, str):
        if image.startswith("data:image/"):
            return _rgb_from_data_uri(image)
        return _rgb_from_source(image)

    if isinstance(image, dict):
        if image.get("bytes") is not None:
            return _rgb_from_bytes(image["bytes"])

        # First string-valued key wins, in this priority order.
        data_str = next(
            (image[key] for key in ("data", "path", "url") if isinstance(image.get(key), str)),
            None,
        )
        if data_str:
            if data_str.startswith("data:image/"):
                return _rgb_from_data_uri(data_str)
            return _rgb_from_source(data_str)

    return None
+
+
def _extract_action_from_response(response: str) -> Action:
    """Classify one assistant response as a tool call, a final answer, or plain reasoning.

    Args:
        response: Raw assistant message text.

    Returns:
        An ``Action`` whose ``action`` dict has ``type`` set to ``"tool_call"``
        (with the stripped payload under ``"tool_call"``), ``"final_answer"``
        (with the stripped text under ``"answer"``), or ``"reasoning"`` (with
        the whole response under ``"content"``).
    """
    # NOTE(review): the delimiters here were empty/blank string literals, which
    # makes ``str.split("")`` raise ValueError. They are restored as the
    # <tool_call>/<answer> tag pairs; confirm against the agent's prompt format.
    if "<tool_call>" in response and "</tool_call>" in response:
        tool_call_text = response.split("<tool_call>")[1].split("</tool_call>")[0]
        return Action(action={"type": "tool_call", "tool_call": tool_call_text.strip()})
    if "<answer>" in response and "</answer>" in response:
        answer = response.split("<answer>")[1].split("</answer>")[0].strip()
        return Action(action={"type": "final_answer", "answer": answer})
    return Action(action={"type": "reasoning", "content": response})
+
+
+def _is_valid_format(content: str) -> bool:
+ if not isinstance(content, str) or not content:
+ return False
+ pattern = (
+ r"^.*? \s*(.*? |.*? )\s*$"
+ )
+ return re.match(pattern, content, re.DOTALL) is not None
+
+
def _format_reward_for_step(step: Step) -> float:
    """Return 1.0 for an error-free step whose response is well-formed, else 0.0.

    A step scores 0.0 when it is flagged with ``step_error``, when its
    observation contains a tool-error marker, or when its model response is not
    a string in the expected format.
    """
    errored = step.info.get("step_error") or _has_tool_error_observation(step.observation)
    if errored:
        return 0.0

    response_text = step.model_response
    if not isinstance(response_text, str):
        response_text = ""
    return 1.0 if _is_valid_format(response_text) else 0.0
+
+
+def _has_tool_error_observation(observation: Any) -> bool:
+ if not isinstance(observation, str):
+ return False
+ error_markers = (
+ "[Json Parse Error]",
+ "[Python Interpreter Error]",
+ "Python execution error:",
+ "PythonInterpreter tool not available",
+ "PythonInterpreter tool is not callable",
+ )
+ return any(marker in observation for marker in error_markers)
+
+
def _is_step_error(step: Step) -> bool:
    """A step is an error step if it is flagged ``step_error`` or its observation shows a tool error."""
    flagged = step.info.get("step_error")
    return True if flagged else _has_tool_error_observation(step.observation)
+
+
+def _get_next_observation(messages: list[dict], current_index: int) -> str:
+ if current_index + 1 < len(messages):
+ next_msg = messages[current_index + 1]
+ if next_msg["role"] == "user" and "" in next_msg["content"]:
+ return next_msg["content"]
+ return ""
+
+
def _map_termination_reason(termination: str) -> TerminationReason:
    """Translate the agent's termination string into a ``TerminationReason``.

    Only ``"answer"`` maps to ``ENV_DONE``; every other value, listed or not,
    maps to ``UNKNOWN``.
    """
    # Terminations the agent is known to report besides "answer". The last
    # four are the ones whose episodes will be masked.
    non_answer_terminations = (
        "timeout",
        "max_rounds_reached",
        "token_limit_no_answer",
        "answer_token_limit",
        "exceed available llm calls",
        "prompt_budget_reached",
        "max_rounds_reached_no_answer",
        "repetition_detected",
        "tool_errors_too_many",
        "consecutive_step_errors",
        "error",
    )
    table = dict.fromkeys(non_answer_terminations, TerminationReason.UNKNOWN)
    table["answer"] = TerminationReason.ENV_DONE
    return table.get(termination, TerminationReason.UNKNOWN)
+
+
+def _evaluate_answer(prediction: str, ground_truth: str) -> bool:
+ if not prediction or not ground_truth:
+ return False
+ return prediction.strip().lower() == ground_truth.strip().lower()
+
+
def _should_mask_episode(result: dict, episode: Episode) -> tuple[bool, str]:
    """Decide whether the whole episode must be masked, and why.

    Returns:
        ``(True, reason)`` when the run did not terminate with ``"answer"``
        (reason is the termination string, or ``"no_final_answer"`` when it is
        empty), or when more than half of the first trajectory's steps are
        error steps (reason ``"tool_errors_too_many"``); otherwise ``(False, "")``.
    """
    steps = episode.trajectories[0].steps if episode.trajectories else []
    termination = result.get("termination", "")

    # No final answer: mask immediately.
    if termination != "answer":
        return True, termination or "no_final_answer"

    # Too many step errors: mask as well.
    step_count = len(steps)
    if step_count > 0:
        error_count = 0
        for step in steps:
            if _is_step_error(step):
                error_count += 1
        if error_count / step_count > 0.5:
            return True, "tool_errors_too_many"

    return False, ""
+
+
def _to_pil_image(img: Any) -> Optional[Image.Image]:
    """Best-effort conversion to PIL.Image for downstream multi-modal pipeline.

    Args:
        img: A ``PIL.Image.Image``, a string, or a dict; see ``as_pil_image``
            for the accepted forms.

    Returns:
        The converted image, or None when conversion fails.
    """
    # ``as_pil_image`` already handles PIL instances, dicts carrying "bytes"
    # and string inputs, swallowing decode errors. The extra fallbacks that
    # used to follow only repeated the same failing open, so this delegates.
    return as_pil_image(img)
+
+
class DeepResearchWorkflow(Workflow):
    """Workflow that runs a ``DeepResearchAgent`` on one task and packages the result as an ``Episode``."""

    def __init__(
        self,
        rollout_engine: RolloutEngine,
        executor,
        tools: dict | None = None,
        system_prompt: str | None = None,
        reward_function: RewardFunction | None = None,
        **kwargs,
    ):
        """Build the workflow and its agent.

        Args:
            rollout_engine: Engine handed to both the base ``Workflow`` and the agent.
            executor: Executor handed to the base ``Workflow``; also used to run a
                synchronous reward function off the event loop.
            tools: Mapping of tool name to tool object (defaults to an empty dict).
            system_prompt: Optional system prompt forwarded to the agent.
            reward_function: Optional callable invoked as ``reward_function(task, prediction)``.
            **kwargs: Forwarded to ``Workflow.__init__``.
        """
        super().__init__(rollout_engine, executor, **kwargs)

        self.tools = tools or {}
        # Tools that expose ``set_executor`` receive the workflow's executor.
        # NOTE(review): ``self.executor`` is presumably set by Workflow.__init__ - confirm.
        for tool in self.tools.values():
            if hasattr(tool, "set_executor"):
                tool.set_executor(self.executor)
        self.system_prompt = system_prompt
        self.reward_function = reward_function

        self.agent = DeepResearchAgent(
            rollout_engine=rollout_engine,
            tools=self.tools,
            system_prompt=self.system_prompt,
        )

    async def run(self, task: dict, uid: str, **kwargs) -> Episode:
        """Run the agent on ``task`` and return a scored (or masked) ``Episode``.

        Args:
            task: Task dict; reads ``question`` (falling back to ``query``),
                ``answer`` and optionally ``images``.
            uid: Identifier stored as the episode id.
            **kwargs: Forwarded to ``self.agent.run``.

        Returns:
            The episode built from the agent result. Any exception raised inside
            the body is caught and turned into an ``ERROR`` episode with no
            trajectories and ``masked`` set to 1.0.
        """
        self.reset(task=task, uid=uid)

        question = task.get("question", task.get("query", "No question provided"))
        answer = task.get("answer", "")

        print(f"🚀 Starting DeepResearch workflow for task {uid}")
        print(f" Question: {question}")

        try:
            raw_images = None
            if "images" in task:
                raw_images = task.get("images")

            # Convert whatever image representations the task carries to PIL,
            # silently dropping the ones that cannot be converted.
            pil_images: List[Image.Image] = []
            if raw_images is not None:
                if not isinstance(raw_images, list):
                    raw_images = [raw_images]
                for img in raw_images:
                    pil = _to_pil_image(img)
                    if pil is not None:
                        pil_images.append(pil)

            if pil_images:
                # NOTE(review): ``image_path`` receives the first *raw* image entry,
                # which is not necessarily a path string - confirm what the agent expects.
                result = await self.agent.run(
                    question=question,
                    answer=answer,
                    images=pil_images,
                    image_path=raw_images[0],
                    **kwargs,
                )
            else:
                result = await self.agent.run(
                    question=question, answer=answer, **kwargs
                )

            episode = self._convert_result_to_episode(result, task, uid)

            prediction = result.get("prediction", "")
            # The reward function runs only when there is a non-empty prediction.
            if self.reward_function is not None and prediction:
                try:
                    if asyncio.iscoroutinefunction(self.reward_function):
                        reward_out = await self.reward_function(task, prediction)
                    else:
                        # Synchronous reward functions run in the executor so the
                        # event loop is not blocked.
                        loop = asyncio.get_running_loop()
                        reward_out = await loop.run_in_executor(
                            self.executor,
                            lambda: self.reward_function(task, prediction),
                        )
                except Exception as err:  # noqa: BLE001
                    # Best effort: a failing reward function leaves the
                    # exact-match correctness from _convert_result_to_episode.
                    print(f"Reward function failed: {err}")
                else:
                    if reward_out.is_correct is not None:
                        episode.is_correct = bool(reward_out.is_correct)
                    if isinstance(reward_out.metadata, dict):
                        reward_metadata = episode.info.setdefault("reward_metadata", {})
                        # Existing keys are kept; only new keys are added.
                        for key, value in reward_out.metadata.items():
                            if key not in reward_metadata:
                                reward_metadata[key] = value
                    if getattr(reward_out, "reward", None) is not None:
                        episode.info["reward_function_reward"] = float(
                            reward_out.reward
                        )

            # Check whether to mask the whole episode.
            should_mask_episode, mask_reason = _should_mask_episode(result, episode)

            if should_mask_episode:
                episode.termination_reason = TerminationReason.UNKNOWN
                episode.metrics = {
                    "reward/outcome": 0.0,
                    "masked": 1.0,
                }
                episode.info["mask_reason"] = mask_reason or result.get(
                    "termination", "unknown"
                )
            else:
                # No mask: use outcome_reward only, no format reward.
                outcome_reward = 1.0 if episode.is_correct else 0.0
                for trajectory in episode.trajectories:
                    if not trajectory.steps:
                        continue
                    trajectory.reward = outcome_reward

                    # Only the final step is kept, and it carries the outcome reward.
                    last_step = trajectory.steps[-1]
                    last_step.reward = trajectory.reward
                    trajectory.steps = [last_step]

                episode.metrics = {
                    "reward/outcome": outcome_reward,
                    "masked": 0.0,
                }

            print(f"✅ DeepResearch workflow completed for task {uid}")
            print(f" Prediction: {result.get('prediction', 'No prediction')}")
            print(f" True Answer: {answer}")
            print(f" Metrics: {episode.metrics}")
            if episode.info.get("mask_reason"):
                print(f" Mask Reason: {episode.info['mask_reason']}")
            return episode

        except Exception as exc:  # noqa: BLE001
            # Any failure yields a masked, incorrect episode rather than propagating.
            print(f"❌ DeepResearch workflow failed for task {uid}: {exc}")
            episode = Episode()
            episode.id = uid
            episode.task = task
            episode.termination_reason = TerminationReason.ERROR
            episode.is_correct = False
            episode.trajectories = []
            episode.metrics = {
                "reward/outcome": 0.0,
                "masked": 1.0,
            }
            episode.info = {"error": str(exc)}
            return episode

    def _convert_result_to_episode(self, result: dict, task: dict, uid: str) -> Episode:
        """Turn the agent's result dict into an ``Episode`` with one trajectory.

        Reads ``messages``, ``prediction``, ``termination``, ``rounds``,
        ``time_taken`` and ``token_usage`` from ``result``. Every assistant
        message becomes one ``Step`` whose context is the conversation up to and
        including that message. ``is_correct`` is an exact-match comparison
        against ``task["answer"]`` (False when the task has no answer).
        """
        messages = result.get("messages", [])
        prediction = result.get("prediction", "")
        termination = result.get("termination", "unknown")
        rounds = result.get("rounds", 0)
        time_taken = result.get("time_taken", 0.0)

        trajectories: list[Trajectory] = []
        steps: list[Step] = []

        i = 0
        while i < len(messages):
            msg = messages[i]
            if msg["role"] == "assistant":
                context = messages[: i + 1]
                assistant_content = msg.get("content", "")
                action = _extract_action_from_response(assistant_content)
                observation = _get_next_observation(messages, i)
                step = Step(
                    chat_completions=context.copy(),
                    model_response=assistant_content,
                    action=action,
                    observation=observation,
                    reward=0.0,
                )
                step.model_output = None
                # A step is an error step if the message was flagged or the
                # following observation contains a tool-error marker.
                if msg.get("step_error"):
                    step.info["step_error"] = True
                if _has_tool_error_observation(observation):
                    step.info["step_error"] = True
                steps.append(step)
            i += 1

        trajectory = Trajectory(
            name="deepresearch_agent",
            task=task,
            steps=steps,
            reward=0.0,
            info={},
        )
        trajectories.append(trajectory)

        answer_text = task.get("answer", "")
        is_correct = _evaluate_answer(prediction, answer_text) if answer_text else False

        episode = Episode()
        episode.id = uid
        episode.task = task
        episode.termination_reason = _map_termination_reason(termination)
        episode.is_correct = is_correct
        episode.trajectories = trajectories
        episode.metrics = {}
        episode.info = {
            "rounds": rounds,
            "time_taken": time_taken,
            "prediction": prediction,
            "answer": answer_text,
            "token_usage": result.get("token_usage", {}),
        }
        return episode

    def reset(self, task: dict | None = None, uid: str | None = None):
        """No-op reset hook; both arguments are accepted and ignored."""
        # MultiTurnReactAgent handles per-run state; nothing to reset here.
        return

    def is_multithread_safe(self) -> bool:
        """Report this workflow as multithread-safe (always returns True)."""
        return True
+
+
# Public API of this module: only the workflow class is exported by ``import *``.
__all__ = ["DeepResearchWorkflow"]
diff --git a/vendor/rllm/eval/eval_runner.py b/vendor/rllm/eval/eval_runner.py
new file mode 100644
index 0000000000000000000000000000000000000000..78d72476407b19d16a5c9fe41a4d112e7b2179b4
--- /dev/null
+++ b/vendor/rllm/eval/eval_runner.py
@@ -0,0 +1,431 @@
+"""
+Benchmark evaluation runner for Vision DeepResearch (single workflow).
+
+Goals:
+- Use the same workflow/tools/reward as training (`DeepResearchWorkflow`, `deepresearch_reward_fn`).
+- Default rollout: OpenAI-compatible (can point to local vLLM server via base_url).
+- Input: Parquet only with columns question/answer/(images).
+- Save full trajectories (sanitized) and metrics under eval/outputs/<timestamp>/.
+"""
+
+from __future__ import annotations
+
+import argparse
+import asyncio
+import json
+import os
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Dict, Iterable, List, Optional
+from urllib.parse import urlparse
+from urllib.request import urlopen
+
+from datasets import load_dataset
+
+import yaml
+from PIL import Image
+
+from vision_deepresearch_async_workflow.deepresearch_tools_async_executor import (
+ get_all_tools,
+)
+from vision_deepresearch_async_workflow.deepresearch_workflow import (
+ DeepResearchWorkflow,
+)
+from rllm.engine.agent_workflow_engine import AgentWorkflowEngine
+from rllm.engine.rollout import OpenAIEngine
+from rllm.rewards.reward_fn import deepresearch_reward_fn
+
+
+# ---------------------- Data loading ---------------------- #
+
+
+def _extract_from_content(content: list) -> tuple[str, List[Any]]:
+ """Convert OpenAI-style content array (text + image_url) to question + images list."""
+ texts: List[str] = []
+ images: List[Any] = []
+ for item in content:
+ if not isinstance(item, dict):
+ continue
+ if item.get("type") == "text" and "text" in item:
+ texts.append(str(item["text"]))
+ elif item.get("type") == "image_url":
+ url = item.get("image_url", {}) or {}
+ if isinstance(url, dict) and "url" in url:
+ images.append({"url": url["url"]})
+ question = "\n".join([t for t in texts if t.strip()])
+ return question, images
+
+
+def _url_to_bytes(url: str) -> dict:
+ try:
+ parsed = urlparse(url)
+ if parsed.scheme in {"http", "https"}:
+ with urlopen(url, timeout=10) as resp:
+ data = resp.read()
+ return {"bytes": data, "origin_url": url}
+ except Exception:
+ pass
+ return {"url": url}
+
+
+def _normalize_images(images: List[Any]) -> List[Any]:
+ normalized: List[Any] = []
+ for img in images:
+ if isinstance(img, str) and img.startswith(("http://", "https://")):
+ normalized.append(_url_to_bytes(img))
+ elif isinstance(img, dict) and "url" in img and isinstance(img["url"], str):
+ normalized.append(_url_to_bytes(img["url"]))
+ else:
+ normalized.append(img)
+ return normalized
+
+
+def _record_to_task(record: dict) -> dict:
+ """Strict mapping to the unified schema question/answer/images."""
+ question = record.get("question", "")
+ answer = record.get("answer", "")
+ images: List[Any] = []
+
+ if "images" in record and record["images"] is not None:
+ imgs = record["images"]
+ images.extend(imgs if isinstance(imgs, list) else [imgs])
+
+ if not isinstance(question, str):
+ question = str(question)
+ if not isinstance(answer, str):
+ answer = str(answer)
+
+ return {
+ "id": record.get("id") or record.get("idx"),
+ "question": question,
+ "answer": answer,
+ "images": _normalize_images(images),
+ }
+
+
+def _record_from_parquet(rec: dict, idx: int) -> dict:
+ # Strict allowed fields
+ if "question" not in rec or "answer" not in rec:
+ raise ValueError(f"Record {idx} missing required fields 'question'/'answer'")
+
+ question_raw = rec.get("question")
+ answer_raw = rec.get("answer")
+
+ question = str(question_raw) if question_raw is not None else ""
+ if not question.strip():
+ raise ValueError(f"Record {idx} has empty question")
+
+ answer = str(answer_raw) if answer_raw is not None else ""
+ if not answer.strip():
+ raise ValueError(f"Record {idx} has empty answer")
+ images_raw = rec.get("images", [])
+ if images_raw is None:
+ images_raw = []
+ images_list: List[Any] = (
+ images_raw if isinstance(images_raw, list) else [images_raw]
+ )
+
+ return {
+ "id": rec.get("id") or rec.get("idx") or rec.get("_id"),
+ "question": question,
+ "answer": answer,
+ "images": _normalize_images(images_list),
+ }
+
+
def load_tasks(args) -> List[dict]:
    """Load evaluation tasks from the Parquet file named by ``args.parquet``.

    Args:
        args: Namespace providing ``parquet`` (path) and ``max_samples``
            (optional cap on the number of rows).

    Returns:
        One task dict per row, as produced by ``_record_from_parquet``.

    Raises:
        ValueError: If no path is given, required columns are missing,
            unsupported columns are present, a row is invalid, or no task
            results.
    """
    if not args.parquet:
        raise ValueError(
            "Parquet path is required. Please set --parquet or data.parquet."
        )

    ds_dict = load_dataset("parquet", data_files=str(args.parquet))
    ds = ds_dict["train"]

    allowed_cols = {
        "question",
        "answer",
        "images",
        "id",
        "idx",
        "_id",
        "image_caption",
        "question_original",
    }
    required_cols = {"question", "answer"}
    cols = set(ds.column_names)

    missing = required_cols - cols
    if missing:
        raise ValueError(f"Parquet file missing required columns: {sorted(missing)}")

    extras = cols - allowed_cols
    if extras:
        # The allowed list is built from ``allowed_cols`` so the message
        # cannot drift from the actual check.
        raise ValueError(
            "Parquet file contains unsupported columns "
            f"(allowed: {'/'.join(sorted(allowed_cols))}): {sorted(extras)}"
        )

    if args.max_samples is not None:
        ds = ds.select(range(min(args.max_samples, len(ds))))

    tasks: List[dict] = [
        _record_from_parquet(dict(rec), idx) for idx, rec in enumerate(ds)
    ]

    if not tasks:
        raise ValueError("No valid tasks loaded from Parquet.")
    return tasks
+
+
+# ---------------------- Rollout setup ---------------------- #
+
+
def build_rollout_engine(args):
    """Create the ``OpenAIEngine`` used for rollouts from the merged arguments.

    ``max_tokens`` is included in the sampling parameters only when set.
    """
    sampling_params = dict(temperature=args.temperature, top_p=args.top_p)
    token_cap = args.max_tokens
    if token_cap is not None:
        sampling_params["max_tokens"] = token_cap

    # OpenAIEngine also works with local OpenAI-compatible servers (e.g., vLLM)
    engine = OpenAIEngine(
        model=args.model,
        base_url=args.base_url,
        api_key=args.api_key,
        sampling_params=sampling_params,
    )
    return engine
+
+
+# ---------------------- Serialization helpers ---------------------- #
+
+
def _sanitize(obj: Any) -> Any:
    """Recursively replace values that cannot be JSON-serialized with placeholders.

    PIL images become ``{"type": "PIL.Image", "size": ...}``; dicts, lists and
    tuples are walked recursively (tuples come back as lists); everything else
    is returned unchanged.
    """
    if isinstance(obj, Image.Image):
        return {"type": "PIL.Image", "size": obj.size}
    if isinstance(obj, bytes):
        # NOTE(review): bytes are replaced by an empty string. The f-string has
        # no placeholders, so its text may have been lost - confirm intent.
        return f""
    if isinstance(obj, dict):
        cleaned = {}
        for key, value in obj.items():
            cleaned[key] = _sanitize(value)
        return cleaned
    if isinstance(obj, (list, tuple)):
        return list(map(_sanitize, obj))
    return obj
+
+
def episode_to_dict(ep) -> Dict[str, Any]:
    """Flatten an episode into a JSON-friendly dict, with every payload passed through ``_sanitize``."""

    def _step_record(st) -> Dict[str, Any]:
        # ``Action`` objects are unwrapped to their ``action`` payload.
        action_val = getattr(st.action, "action", st.action)
        return {
            "chat_completions": _sanitize(st.chat_completions),
            "model_response": _sanitize(st.model_response),
            "action": _sanitize(action_val),
            "observation": _sanitize(st.observation),
            "reward": st.reward,
        }

    trajectories = []
    for tr in ep.trajectories or []:
        step_records = [_step_record(st) for st in tr.steps or []]
        trajectories.append(
            {
                "name": tr.name,
                "reward": tr.reward,
                "info": _sanitize(getattr(tr, "info", {})),
                "steps": step_records,
            }
        )

    reason = ep.termination_reason
    return {
        "id": ep.id,
        "task": _sanitize(ep.task),
        "termination_reason": reason.value if reason else None,
        "is_correct": ep.is_correct,
        "metrics": _sanitize(ep.metrics),
        "info": _sanitize(ep.info),
        "trajectories": trajectories,
    }
+
+
+# ---------------------- Metrics & IO ---------------------- #
+
+
def compute_metrics(episodes: Iterable[Any]) -> dict:
    """Aggregate accuracy, termination counts and mean trajectory reward.

    Returns:
        Dict with ``total``, ``correct``, ``accuracy`` (0.0 when there are no
        episodes), ``termination_distribution`` (reason value to count, with
        ``"unknown"`` for episodes without a reason) and ``average_reward``
        (None when no trajectory has a reward).
    """
    episode_list = list(episodes)
    total = len(episode_list)

    correct = 0
    termination: dict = {}
    rewards: List[float] = []
    for ep in episode_list:
        if getattr(ep, "is_correct", False):
            correct += 1

        reason = ep.termination_reason.value if ep.termination_reason else "unknown"
        termination[reason] = termination.get(reason, 0) + 1

        for tr in ep.trajectories or []:
            if tr.reward is not None:
                rewards.append(tr.reward)

    return {
        "total": total,
        "correct": correct,
        "accuracy": correct / total if total else 0.0,
        "termination_distribution": termination,
        "average_reward": (sum(rewards) / len(rewards)) if rewards else None,
    }
+
+
def save_outputs(episodes: List[Any], metrics: dict, output_dir: Path, config: dict):
    """Write ``episodes.jsonl``, ``metrics.json`` and ``config.json`` under ``output_dir``.

    The directory is created if needed. Episodes are serialized one per line
    via ``episode_to_dict``.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    ep_path = output_dir / "episodes.jsonl"
    with ep_path.open("w", encoding="utf-8") as sink:
        for ep in episodes:
            line = json.dumps(episode_to_dict(ep), ensure_ascii=False)
            sink.write(line + "\n")

    # Metrics and config share the same pretty-printed JSON layout.
    for filename, payload in (("metrics.json", metrics), ("config.json", config)):
        with (output_dir / filename).open("w", encoding="utf-8") as sink:
            json.dump(payload, sink, indent=2, ensure_ascii=False)

    print(f"💾 Saved episodes to {ep_path}")
    print(f"📊 Metrics: {metrics}")
+
+
+# ---------------------- Arg parsing ---------------------- #
+
+
def parse_args():
    """Parse the command-line options of the Parquet evaluation runner.

    All options except ``--config`` default to None, so unset values can be
    told apart from explicit ones.
    """
    parser = argparse.ArgumentParser(
        description="DeepResearch evaluation (Parquet only)"
    )

    # (flag, keyword arguments) in the order they appear in --help.
    option_specs = [
        ("--config", dict(default="eval/config/eval_hle.yaml", help="YAML config path")),
        # Data (Parquet only)
        ("--parquet", dict(default=None, help="Local Parquet path (required)")),
        ("--max-samples", dict(type=int, default=None, help="Max samples")),
        # Rollout
        (
            "--provider",
            dict(
                default=None,
                help="openai (default) | vllm (still uses OpenAIEngine base_url)",
            ),
        ),
        ("--model", dict(default=None, help="Model name")),
        ("--base-url", dict(default=None, help="OpenAI-compatible base URL")),
        ("--api-key", dict(default=None, help="API key")),
        ("--temperature", dict(type=float, default=None, help="Sampling temperature")),
        ("--top-p", dict(type=float, default=None, help="Top-p")),
        ("--max-tokens", dict(type=int, default=None, help="Max tokens for completion")),
        # Execution
        ("--parallel-tasks", dict(type=int, default=None, help="Parallel tasks")),
        ("--retry-limit", dict(type=int, default=None, help="Retry limit")),
        ("--output-dir", dict(default=None, help="Output dir (relative to eval/)")),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    return parser.parse_args()
+
+
def load_config(path: Path) -> dict:
    """Read a YAML config file; a missing or empty file yields ``{}``."""
    if not path.exists():
        return {}
    with path.open("r", encoding="utf-8") as handle:
        loaded = yaml.safe_load(handle)
    return loaded or {}
+
+
def merge_args_with_config(cfg: dict, args) -> argparse.Namespace:
    """Merge CLI arguments with the YAML config; CLI values take precedence.

    Precedence per field: CLI value, then config value, then an environment
    variable (``base_url`` / ``api_key`` only), then a built-in default.
    Config sections that are missing, null, or not mappings are treated as
    empty rather than raising.

    Args:
        cfg: Parsed YAML config. Reads ``provider``, the section named after
            the provider, ``data``, ``execution`` and ``output``.
        args: Namespace produced by ``parse_args``.

    Returns:
        A new ``argparse.Namespace`` with the effective settings.
    """
    if not isinstance(cfg, dict):
        cfg = {}

    def _section(container: dict, key) -> dict:
        """Return ``container[key]`` when it is a dict, else an empty dict."""
        value = container.get(key)
        return value if isinstance(value, dict) else {}

    def _pick(cli_value, fallback):
        """Prefer the CLI value unless it is None (0 and 0.0 are valid)."""
        return cli_value if cli_value is not None else fallback

    # Rollout source
    provider = args.provider or cfg.get("provider") or "openai"
    rollout_cfg = _section(cfg, provider)
    sampling_cfg = _section(rollout_cfg, "sampling_params")

    data_cfg = _section(cfg, "data")
    exec_cfg = _section(cfg, "execution")
    out_cfg = _section(cfg, "output")

    merged = argparse.Namespace(
        provider=provider,
        model=args.model or rollout_cfg.get("model") or "gpt-4o",
        base_url=args.base_url
        or rollout_cfg.get("base_url")
        or os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1"),
        api_key=args.api_key
        or rollout_cfg.get("api_key")
        or os.getenv("OPENAI_API_KEY", ""),
        temperature=_pick(args.temperature, sampling_cfg.get("temperature", 0.6)),
        top_p=_pick(args.top_p, sampling_cfg.get("top_p", 0.95)),
        max_tokens=_pick(args.max_tokens, sampling_cfg.get("max_tokens")),
        parquet=args.parquet or data_cfg.get("parquet"),
        max_samples=_pick(args.max_samples, data_cfg.get("max_samples")),
        parallel_tasks=_pick(args.parallel_tasks, exec_cfg.get("parallel_tasks", 4)),
        retry_limit=_pick(args.retry_limit, exec_cfg.get("retry_limit", 1)),
        output_dir=args.output_dir or out_cfg.get("dir") or "./outputs",
    )
    return merged
+
+
+# ---------------------- Main ---------------------- #
+
+
async def main():
    """Run the Parquet evaluation end to end.

    Steps: parse the CLI, merge it with the YAML config, load tasks, run them
    through ``AgentWorkflowEngine`` with ``DeepResearchWorkflow``, compute
    metrics, and save everything under ``<this dir>/<output_dir>/<timestamp>``.

    Raises:
        ValueError: If no tasks could be loaded.
    """
    args_cli = parse_args()
    cfg = load_config(Path(args_cli.config))
    args = merge_args_with_config(cfg, args_cli)

    tasks = load_tasks(args)
    if not tasks:
        # This runner only reads Parquet, so the message points at the Parquet path.
        raise ValueError("No tasks loaded. Please check the Parquet path.")

    tools = get_all_tools()
    rollout_engine = build_rollout_engine(args)

    workflow_args = {
        "tools": tools,
        "reward_function": deepresearch_reward_fn,
    }

    workflow_engine = AgentWorkflowEngine(
        workflow_cls=DeepResearchWorkflow,
        workflow_args=workflow_args,
        rollout_engine=rollout_engine,
        n_parallel_tasks=args.parallel_tasks,
        retry_limit=args.retry_limit,
    )

    print(f"🚀 Running DeepResearch evaluation with {len(tasks)} tasks")
    episodes = await workflow_engine.execute_tasks(tasks)

    metrics = compute_metrics(episodes)

    # Each run gets its own timestamped directory next to this file.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = Path(__file__).parent / args.output_dir / timestamp
    # Note: the saved config is the raw YAML config, without CLI overrides.
    save_outputs(list(episodes), metrics, output_dir, cfg)
+
+
# Script entry point: run the async ``main`` coroutine to completion.
if __name__ == "__main__":
    asyncio.run(main())
diff --git a/vendor/rllm/eval/gen_image_eval_runner.py b/vendor/rllm/eval/gen_image_eval_runner.py
new file mode 100644
index 0000000000000000000000000000000000000000..805ca3fb9cd59d538e9a30f9c58accc39e9df95b
--- /dev/null
+++ b/vendor/rllm/eval/gen_image_eval_runner.py
@@ -0,0 +1,329 @@
+"""
+Gen Image eval: run the Gen workflow and only produce trajectory logs, gen_prompt, and reference_images.
+No image generation and no Gemini-based scoring.
+Input: a JSON array file or a JSONL file (one sample per line); each sample is {id, prompt, meta, gen_image}.
+Output: results.json (appended after each sample; includes id, trajectory_messages, gen_prompt, reference_images, termination, etc.).
+Supports --resume: skip ids that are already completed.
+"""
+
+from __future__ import annotations
+
+import argparse
+import asyncio
+import json
+import os
+import shutil
+from pathlib import Path
+from typing import Any, Dict, List
+
+from tqdm import tqdm
+
+from rllm.engine.agent_workflow_engine import AgentWorkflowEngine
+from rllm.engine.rollout import OpenAIEngine
+from vision_deepresearch_async_workflow.gen_image_deepresearch_tools_executor import (
+ create_gen_image_tools,
+)
+from vision_deepresearch_async_workflow.gen_image_deepresearch_workflow import (
+ GenImageDeepResearchWorkflow,
+)
+
+
+# ---------------------- Data loading ---------------------- #
+
+
def load_tasks_from_json(path: str, max_samples: int | None = None) -> List[dict]:
    """Load JSON input. Supports a JSON array or JSONL (one JSON object per line).

    Args:
        path: Input file path.
        max_samples: When given, only the first ``max_samples`` tasks are returned.

    Returns:
        The task dicts. A task without an ``id`` gets its position in the file
        as its id (assigned before truncation).

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        ValueError: If the file holds no tasks, a task is not a JSON object,
            or a task lacks ``prompt``.
    """
    path = Path(path).resolve()
    if not path.exists():
        raise FileNotFoundError(f"Input file not found: {path}")

    raw = path.read_text(encoding="utf-8").strip()

    if raw.startswith("["):
        data = json.loads(raw)
        if not isinstance(data, list):
            raise ValueError("JSON root must be an array")
        tasks: List[dict] = data
    else:
        # JSONL: blank lines are skipped.
        tasks = [json.loads(line) for line in raw.splitlines() if line.strip()]

    if not tasks:
        raise ValueError("No valid tasks in input file")

    for i, t in enumerate(tasks):
        # Reject non-object entries explicitly. Otherwise a string entry would
        # be substring-matched against "prompt" and fail later with a TypeError.
        if not isinstance(t, dict):
            raise ValueError(
                f"Task {i} must be a JSON object, got {type(t).__name__}"
            )
        if "prompt" not in t:
            raise ValueError(f"Task {i} missing 'prompt'")
        t.setdefault("id", i)

    if max_samples is not None:
        tasks = tasks[:max_samples]

    return tasks
+
+
def task_to_workflow_input(record: dict) -> dict:
    """Build the workflow task dict from an input record.

    The record's ``prompt`` is exposed both as ``question`` and as ``prompt``.
    """
    prompt_text = record.get("prompt", "")
    workflow_task = {"id": record.get("id")}
    workflow_task["question"] = prompt_text
    workflow_task["prompt"] = prompt_text
    workflow_task["meta"] = record.get("meta", {})
    workflow_task["gen_image"] = record.get("gen_image")  # Not used for now
    return workflow_task
+
+
+# ---------------------- Sanitize & output ---------------------- #
+
+
+def _sanitize_content(msg: dict) -> dict:
+ """Sanitize message content by trimming oversized fields."""
+ out = {"role": msg.get("role", ""), "content": ""}
+ content = msg.get("content", "")
+ if isinstance(content, str):
+ out["content"] = content[:50000] + "..." if len(content) > 50000 else content
+ else:
+ out["content"] = str(content)[:50000]
+ if "images" in msg:
+ out["images"] = [
+ (p if isinstance(p, str) else p.get("image", ""))[:200]
+ for p in (msg["images"] or [])[:10]
+ ]
+ return out
+
+
+def _copy_ref_images_and_build_list(
+ prediction: dict,
+ sample_id: str,
+ ref_images_dir: Path,
+) -> List[dict]:
+ """Copy images in reference_images into ref_images_dir and return the updated list with new paths."""
+ ref_images_dir.mkdir(parents=True, exist_ok=True)
+ out = []
+ raw_refs = prediction.get("reference_images", []) if isinstance(prediction, dict) else []
+ for r in raw_refs:
+ if not isinstance(r, dict):
+ continue
+ local_path = (r.get("local_path") or "").strip()
+ img_id = r.get("img_id", "").strip() or "img"
+ note = r.get("note", "")
+ url = r.get("url", "")
+ title = r.get("title", "")
+ new_path = ""
+ if local_path and os.path.exists(local_path):
+ ext = Path(local_path).suffix or ".jpg"
+ safe_img_id = "".join(c if c.isalnum() or c in "_-" else "_" for c in img_id)
+ dest_name = f"{sample_id}_{safe_img_id}{ext}"
+ dest_path = ref_images_dir / dest_name
+ try:
+ shutil.copy2(local_path, dest_path)
+ new_path = str(dest_path)
+ except Exception as e:
+ print(f"[GenEval] Failed to copy ref image {local_path} -> {dest_path}: {e}")
+ new_path = local_path
+ else:
+ new_path = local_path
+ out.append({
+ "img_id": img_id,
+ "note": note,
+ "local_path": new_path,
+ "url": url,
+ "title": title,
+ })
+ return out
+
+
def episode_to_output_record(
    episode: Any,
    original_task: dict,
    ref_images_dir: Path | None = None,
) -> dict:
    """Build the per-sample result record from an Episode and its source task.

    When ``ref_images_dir`` is given, reference images are copied there via
    ``_copy_ref_images_and_build_list``; otherwise the reference entries are
    reduced to their known fields with paths left untouched.
    """
    info = episode.info or {}
    prediction = info.get("prediction", {})
    prediction_is_dict = isinstance(prediction, dict)

    # Prefer the termination recorded in info, then the episode's reason.
    termination = info.get("termination")
    if not termination:
        reason = getattr(episode, "termination_reason", None)
        termination = reason.value if reason else "unknown"

    trajectory_messages = [_sanitize_content(m) for m in info.get("messages", [])]
    gen_prompt = prediction.get("gen_prompt", "") if prediction_is_dict else ""
    sample_id = str(original_task.get("id", "unknown"))

    if ref_images_dir:
        reference_images = _copy_ref_images_and_build_list(prediction, sample_id, ref_images_dir)
    else:
        reference_images = []
        if prediction_is_dict and "reference_images" in prediction:
            field_names = ("img_id", "note", "local_path", "url", "title")
            for ref in prediction["reference_images"]:
                if isinstance(ref, dict):
                    reference_images.append({name: ref.get(name, "") for name in field_names})

    record = {
        "id": original_task.get("id"),
        "prompt": original_task.get("prompt", ""),
        "meta": original_task.get("meta", {}),
        "termination": termination,
        "trajectory_messages": trajectory_messages,
        "gen_prompt": gen_prompt,
        "reference_images": reference_images,
        # Open-source data convention: GT image path is stored at top-level task field `gt_image`
        "gt_image": original_task.get("gt_image", ""),
    }
    return record
+
+
+# ---------------------- Rollout ---------------------- #
+
+
def build_rollout_engine(args: argparse.Namespace) -> OpenAIEngine:
    """Create the ``OpenAIEngine`` for Gen Image eval from parsed arguments.

    Missing optional attributes fall back to: temperature 0.7, top_p 1.0,
    prompt and response length 32768; ``max_tokens`` is sent only when set.
    """

    def _opt(name: str, default: Any) -> Any:
        return getattr(args, name, default)

    sampling_params = {"temperature": _opt("temperature", 0.7), "top_p": _opt("top_p", 1.0)}
    token_cap = _opt("max_tokens", None)
    if token_cap is not None:
        sampling_params["max_tokens"] = token_cap

    engine = OpenAIEngine(
        model=args.model,
        base_url=args.base_url,
        api_key=args.api_key,
        max_prompt_length=_opt("max_prompt_length", 32768),
        max_response_length=_opt("max_response_length", 32768),
        sampling_params=sampling_params,
    )
    return engine
+
+
+# ---------------------- Main ---------------------- #
+
+
def parse_args():
    """Parse the command-line options of the Gen Image eval runner."""
    parser = argparse.ArgumentParser(
        description="Gen Image eval: produce trajectory logs, gen_prompt, reference_images (no generation/scoring)"
    )

    # (flags, keyword arguments) in the order they appear in --help.
    option_specs = [
        (("--input", "--json"), dict(default=None, dest="input_json", help="Input JSON path (array or JSONL)")),
        (("--max-samples",), dict(type=int, default=None, help="Max number of samples to evaluate")),
        (("--model",), dict(default=None, help="Model name (inference service)")),
        (("--base-url",), dict(default=None, help="OpenAI-compatible inference service base_url")),
        (("--api-key",), dict(default=None, help="API Key")),
        (("--temperature",), dict(type=float, default=0.7)),
        (("--top-p",), dict(type=float, default=1.0)),
        (("--max-tokens",), dict(type=int, default=None)),
        (("--max-prompt-length",), dict(type=int, default=32768, help="Max prompt length (default 32k)")),
        (("--max-response-length",), dict(type=int, default=32768, help="Max response length (default 32k)")),
        (("--parallel-tasks",), dict(type=int, default=4, help="Number of parallel tasks")),
        (("--output-dir",), dict(default="./gen_eval_outputs", help="Output directory")),
        (("--resume",), dict(action="store_true", help="Skip completed ids and resume from checkpoint")),
    ]
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)

    return parser.parse_args()
+
+
+def _append_and_save(results: List[dict], rec: dict, results_path: Path) -> None:
+ """Append one result and immediately write back to the JSON file."""
+ results.append(rec)
+ with open(results_path, "w", encoding="utf-8") as f:
+ json.dump(results, f, ensure_ascii=False, indent=2)
+
+
+async def _run_one(
+ workflow_engine: AgentWorkflowEngine,
+ task: dict,
+ task_id: str,
+):
+ """Run a single task and return the episode."""
+ _, _, episode = await workflow_engine.process_task_with_retry(task, task_id, 0)
+ return episode
+
+
async def main():
    """Run the GenEval rollout: load tasks, optionally resume, roll out, persist results.

    Steps:
      1. Resolve input path / model / endpoint / API key (CLI flag first, then
         environment variable, then built-in default).
      2. Load tasks and, with ``--resume``, previously saved records from
         ``<output-dir>/results.json``; ids already present are skipped.
      3. Run the remaining tasks through ``AgentWorkflowEngine`` and rewrite
         ``results.json`` after every finished sample.

    Raises:
        ValueError: when no input path is given via ``--input``/``--json`` or
            ``GEN_EVAL_INPUT_JSON``.
    """
    args = parse_args()

    input_path = args.input_json or os.environ.get("GEN_EVAL_INPUT_JSON")
    if not input_path:
        raise ValueError("Please provide --input/--json or set GEN_EVAL_INPUT_JSON")

    # build_rollout_engine() reads these back from args, so store the resolved values there.
    args.model = args.model or os.environ.get("GEN_EVAL_MODEL", "Vision-DeepResearch-8B")
    args.base_url = args.base_url or os.environ.get("OPENAI_BASE_URL", "http://localhost:8000/v1")
    args.api_key = args.api_key or os.environ.get("OPENAI_API_KEY", "EMPTY")

    tasks_raw = load_tasks_from_json(input_path, args.max_samples)
    tasks = [task_to_workflow_input(t) for t in tasks_raw]
    # Tasks without an explicit id are identified by their position in the input.
    task_ids = [str(t.get("id", i)) for i, t in enumerate(tasks_raw)]
    id_to_task = dict(zip(task_ids, tasks_raw))

    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    ref_images_dir = out_dir / "ref_images"
    results_path = out_dir / "results.json"

    results: List[dict] = []
    done_ids: set = set()
    if args.resume and results_path.exists():
        try:
            with open(results_path, "r", encoding="utf-8") as f:
                loaded = json.load(f)
            if not isinstance(loaded, list):
                raise ValueError("results file root must be a JSON array")
            loaded_ids = {
                str(r.get("id"))
                for r in loaded
                if isinstance(r, dict) and r.get("id") is not None
            }
            # Commit only after the file has been fully validated, so a bad file
            # really does mean "starting from scratch" with an empty list.
            results, done_ids = loaded, loaded_ids
            print(f"[GenEval] Resume: loaded {len(results)} records, skipping {len(done_ids)} ids")
        except Exception as e:
            print(f"[GenEval] Resume load failed: {e}; starting from scratch")

    pending = [(t, tid) for t, tid in zip(tasks, task_ids) if tid not in done_ids]
    if not pending:
        print("[GenEval] No pending samples; already completed")
        return

    tools = create_gen_image_tools()
    rollout_engine = build_rollout_engine(args)
    workflow_engine = AgentWorkflowEngine(
        workflow_cls=GenImageDeepResearchWorkflow,
        workflow_args={
            "tools": tools,
            "reward_function": None,
        },
        rollout_engine=rollout_engine,
        n_parallel_tasks=args.parallel_tasks,
        retry_limit=2,
    )
    await workflow_engine.initialize_pool()

    # Serializes appends and file rewrites across concurrently finishing tasks.
    write_lock = asyncio.Lock()
    total_count = len(results) + len(pending)
    pbar = tqdm(total=len(pending), desc="GenEval", unit="sample")

    async def run_and_save(task: dict, task_id: str) -> None:
        episode = await _run_one(workflow_engine, task, task_id)
        orig = id_to_task.get(task_id, {"id": task_id, "prompt": "", "meta": {}})
        rec = episode_to_output_record(episode, orig, ref_images_dir=ref_images_dir)
        async with write_lock:
            _append_and_save(results, rec, results_path)
            pbar.update(1)
            pbar.set_postfix(completed=len(results), total=total_count)

    print(f"[GenEval] Pending {len(pending)} samples; starting")
    try:
        await asyncio.gather(*[run_and_save(t, tid) for t, tid in pending])
    finally:
        # Close the bar even when a task raises, so the terminal is left clean.
        pbar.close()
    print(f"[GenEval] Done. Results saved to {results_path}; ref images saved to {ref_images_dir}")


if __name__ == "__main__":
    asyncio.run(main())
diff --git a/vendor/rllm/eval/gen_image_from_results.py b/vendor/rllm/eval/gen_image_from_results.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee6b91e27290947c692fc510a8d17d9ae9608a51
--- /dev/null
+++ b/vendor/rllm/eval/gen_image_from_results.py
@@ -0,0 +1,1652 @@
+"""
+Generate images from a result JSON (with gen_prompt/reference_images) or a prompt-only JSON using an API backend or local diffusers.
+- If gen_prompt exists and reference_images is non-empty: generate with gen_prompt + reference images
+- Otherwise: generate from the original prompt (text-only)
+- If input only contains prompt: generate directly from prompt
+- diffuser_flux: unified FLUX text-to-image backend (text-only)
+Output: save images under OUTPUT_DIR/model_xxx/ and write results.json.
+"""
+from __future__ import annotations
+
+import argparse
+import io
+import json
+import os
+import re
+import shutil
+import subprocess
+import sys
+import tempfile
+import threading
+import time
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from pathlib import Path
+from typing import Any, List, Optional
+
+from tqdm import tqdm
+
+# ---------------------------------------------------------------------------
+# Utilities: image base64 (no external api_generator dependency)
+# ---------------------------------------------------------------------------
+
+try:
+ from PIL import Image
+except ImportError:
+ Image = None
+
+
+def _img2base64(img: str | "Image.Image", format: str = "JPEG") -> str:
+ if Image is None:
+ raise RuntimeError("PIL is required for image encoding")
+ if isinstance(img, str):
+ if not os.path.exists(img):
+ raise FileNotFoundError(f"File not found: {img}")
+ Image.MAX_IMAGE_PIXELS = None
+ img = Image.open(img)
+ if getattr(img, "mode", "") == "RGBA":
+ img = img.convert("RGB")
+ buf = io.BytesIO()
+ img.save(buf, format=format)
+ return __import__("base64").b64encode(buf.getvalue()).decode()
+
+
def _base642img(b64: str) -> "Image.Image":
    """Decode base64 text into a PIL image opened from an in-memory buffer.

    Raises:
        RuntimeError: Pillow is not installed.
    """
    if Image is None:
        raise RuntimeError("PIL is required")
    import base64

    raw_bytes = base64.b64decode(b64)
    stream = io.BytesIO(raw_bytes)
    return Image.open(stream)
+
+
+# ---------------------------------------------------------------------------
+# Data loading and "which backend to use" logic
+# ---------------------------------------------------------------------------
+
+
def load_records(path: str) -> List[dict]:
    """Read input records from a JSON array file or a JSONL file.

    The file is treated as a JSON array when its stripped content starts with
    ``[``; otherwise every non-blank line is parsed as one JSON document.

    Raises:
        FileNotFoundError: the (resolved) path does not exist.
        ValueError: the JSON root is not an array, or no record was found.
    """
    source = Path(path).resolve()
    if not source.exists():
        raise FileNotFoundError(f"Input file not found: {source}")

    text = source.read_text(encoding="utf-8").strip()
    if text.startswith("["):
        parsed = json.loads(text)
        if not isinstance(parsed, list):
            raise ValueError("JSON root must be an array")
        records = parsed
    else:
        stripped_lines = (ln.strip() for ln in text.splitlines())
        records = [json.loads(ln) for ln in stripped_lines if ln]

    if not records:
        raise ValueError("No valid records in input")
    return records
+
+
def get_effective_prompt_and_images(record: dict) -> tuple[str, List[str]]:
    """
    Decide which prompt and which reference images a record should be generated with.

    Returns ``(prompt, image_paths)``:
    - ``(gen_prompt, paths)`` when ``gen_prompt`` is non-blank and at least one
      ``reference_images[*].local_path`` exists on disk;
    - ``(record["prompt"] or "", [])`` otherwise (text-only generation).
    """
    fallback_prompt = record.get("prompt") or ""
    rewritten_prompt = (record.get("gen_prompt") or "").strip()

    raw_refs = record.get("reference_images") or []
    ref_entries = raw_refs if isinstance(raw_refs, list) else []

    existing_paths = []
    for entry in ref_entries:
        if isinstance(entry, dict):
            candidate = (entry.get("local_path") or "").strip()
            if candidate and os.path.exists(candidate):
                existing_paths.append(candidate)

    if not (rewritten_prompt and existing_paths):
        return fallback_prompt, []
    return rewritten_prompt, existing_paths
+
+
+# ---------------------------------------------------------------------------
+# Generator interface and API retries
+# ---------------------------------------------------------------------------
+
+
class ImageGeneratorBase:
    """Common interface of all image backends.

    Subclasses implement ``generate(prompt, image_paths)`` and return an image;
    a missing or empty ``image_paths`` means text-only generation.
    """

    def generate(self, prompt: str, image_paths: Optional[List[str]] = None):
        """Produce one image for ``prompt``; must be overridden by subclasses."""
        raise NotImplementedError
+
+
+def _api_request_with_retry(
+ method: str,
+ url: str,
+ payload: str,
+ headers: dict,
+ timeout: int = 120,
+ max_try: int = 5,
+ print_log: bool = False,
+) -> Any:
+ import requests
+ for i in range(max_try):
+ try:
+ if print_log and max_try > 1:
+ print(f" Request {i + 1}/{max_try}: {url[:80]}...", flush=True)
+ r = requests.post(url, data=payload.encode("utf-8") if isinstance(payload, str) else payload, headers=headers, timeout=timeout)
+ return r
+ except Exception as e:
+ if print_log:
+ print(f" Request exception: {e}", flush=True)
+ time.sleep(min(30, 3 + 3 * (1.1 ** i)))
+ return None
+
+
+# ---------------------------------------------------------------------------
+# API generators: Nano / Seed / GPT (with retries)
+# ---------------------------------------------------------------------------
+
+
+def _extract_inline_image_b64_from_gemini_generate_content(data: Any) -> Optional[str]:
+ """Parse Gemini `generateContent` JSON; return base64 image data from the first inlineData part."""
+ if not isinstance(data, dict):
+ return None
+ for cand in data.get("candidates") or []:
+ if not isinstance(cand, dict):
+ continue
+ content = cand.get("content") or {}
+ if not isinstance(content, dict):
+ continue
+ for part in content.get("parts") or []:
+ if not isinstance(part, dict):
+ continue
+ inline = part.get("inlineData") or part.get("inline_data")
+ if isinstance(inline, dict):
+ b64 = inline.get("data")
+ if b64:
+ return str(b64)
+ return None
+
+
class NanoAPIGenerator(ImageGeneratorBase):
    """Nano (e.g. gemini-3-pro-image-preview): Google Generative Language `generateContent` (official REST).

    The prompt and any existing reference images are sent as one ``contents``
    entry; the first inline image of the reply is returned as a PIL image.
    """

    def __init__(
        self,
        api_key: str,
        model_name: str,
        timeout: int = 60,
        max_try: int = 5,
        print_log: bool = False,
        poll_interval: float = 1.0,
        max_poll_seconds: int = 180,
    ):
        """Store credentials and retry settings and build the endpoint URL.

        ``poll_interval`` and ``max_poll_seconds`` are stored but not read by
        this class.
        """
        import requests
        self.requests = requests
        self.api_key = api_key
        self.model_name = model_name
        # Official endpoint (hardcoded): https://ai.google.dev/api/rest/v1beta/models.generateContent
        self.usage_app_url = (
            f"https://generativelanguage.googleapis.com/v1beta/models/{model_name}:generateContent"
        )
        self.timeout = timeout
        self.max_try = max_try
        self.print_log = print_log
        self.poll_interval = poll_interval
        self.max_poll_seconds = max_poll_seconds

    def generate(self, prompt: str, image_paths: Optional[List[str]] = None):
        """Generate one image, retrying failed attempts up to ``max_try`` times.

        Args:
            prompt: Text prompt.
            image_paths: Optional reference image files; non-existent paths are skipped.

        Returns:
            PIL image decoded from the first inline image of the response.

        Raises:
            The last error seen once attempts or the time budget are exhausted.
        """
        parts = [{"text": prompt}]
        for p in image_paths or []:
            if os.path.isfile(p):
                # Encode straight from the path (as the Seed/GPT backends do) so no
                # opened image handle is left dangling.
                parts.append({
                    "inlineData": {"mimeType": "image/jpeg", "data": _img2base64(p)}
                })
        payload = json.dumps({"contents": [{"parts": parts}]}, ensure_ascii=False)
        headers = {"x-goog-api-key": self.api_key, "Content-Type": "application/json"}
        last_err = None
        total_deadline = time.time() + self.timeout  # Budget checked before each attempt
        for attempt in range(self.max_try):
            if time.time() >= total_deadline:
                raise last_err or TimeoutError(f"Total duration exceeded {self.timeout}s")
            # NOTE(review): the 60 s floor lets the final attempt run past total_deadline
            # when less time remains — confirm this is intended.
            req_timeout = min(self.timeout, max(60, int(total_deadline - time.time())))
            try:
                r = _api_request_with_retry(
                    "POST", self.usage_app_url, payload, headers, req_timeout, 1, self.print_log
                )  # 1=no inner retry; outer max_try controls retries
                if r is None or r.status_code != 200:
                    # Compare against None explicitly: a requests.Response is falsy for
                    # 4xx/5xx, which would hide the real status code behind 'NoResponse'.
                    raise RuntimeError(
                        f"Nano generateContent failed: {r.status_code if r is not None else 'NoResponse'} {getattr(r, 'text', '')}"
                    )
                data = r.json()
                if isinstance(data, dict) and data.get("error"):
                    raise RuntimeError(f"Nano API error: {data.get('error')}")
                img_b64 = _extract_inline_image_b64_from_gemini_generate_content(data)
                if not img_b64:
                    raise RuntimeError(f"Nano: no inline image in response: {str(data)[:500]}")
                raw = __import__("base64").b64decode(img_b64)
                return Image.open(io.BytesIO(raw))
            except Exception as e:
                last_err = e
                if (attempt + 1) % 10 == 0:
                    print(f" [Nano] Still failing after {attempt + 1}/{self.max_try} retries: {e}", flush=True)
                if attempt < self.max_try - 1 and time.time() < total_deadline - 5:
                    time.sleep(min(30, 3 + 3 * (1.1 ** attempt)))
        # last_err is None only when the loop never ran (max_try <= 0); `raise None` is a TypeError.
        raise last_err or RuntimeError("Nano: no request attempted (max_try must be >= 1)")
+
+
class SeedAPIGenerator(ImageGeneratorBase):
    """Seedream (Doubao) via Volcengine Ark: POST .../images/generations (OpenAI-style), returns URL or bytes.

    Existing reference images are attached as base64 data URLs; the generated
    image is downloaded from the URL in the response.
    """

    def __init__(
        self,
        api_key: str,
        model_name: str = "doubao-seedream-4-0-250828",
        timeout: int = 60,
        max_try: int = 5,
        print_log: bool = False,
    ):
        """Keep credentials, model name and retry settings for later calls."""
        import requests
        self.requests = requests
        # Volcengine Ark API v3 (images/generations)
        self.usage_app_url = "https://ark.cn-beijing.volces.com/api/v3/images/generations"
        self.api_key = api_key
        self.model_name = model_name
        self.timeout = timeout
        self.max_try = max_try
        self.print_log = print_log

    def _attempt_once(self, body: str, headers: dict, req_timeout: int):
        """Perform one request/download round trip; raise on any failure."""
        resp = _api_request_with_retry(
            "POST", self.usage_app_url, body, headers, req_timeout, 1, self.print_log
        )  # 1=no inner retry; generate() owns the retry loop
        if resp is None:
            raise RuntimeError("Seed: no response")
        parsed = resp.json()
        if resp.status_code != 200 or "data" not in parsed:
            raise RuntimeError(f"Seed API error: {parsed}")
        image_url = parsed["data"][0]["url"]
        download = self.requests.get(image_url, timeout=req_timeout)
        if download.status_code != 200:
            raise RuntimeError(f"Seed: failed to download image: {download.status_code}")
        return Image.open(io.BytesIO(download.content))

    def generate(self, prompt: str, image_paths: Optional[List[str]] = None):
        """Generate one image for ``prompt``, optionally conditioned on reference images.

        Paths that are not regular files are ignored. Failed attempts are
        retried up to ``max_try`` times; the last error is raised at the end.
        """
        encoded_refs = [_img2base64(p) for p in (image_paths or []) if os.path.isfile(p)]

        request_fields = {
            "model": self.model_name,
            "prompt": prompt,
            "size": "2K",
            "response_format": "url",
            "watermark": False,
            "sequential_image_generation": "disabled",
        }
        if encoded_refs:
            request_fields["image"] = [f"data:image/jpg;base64,{b}" for b in encoded_refs]
        body = json.dumps(request_fields)
        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}

        failure = None
        deadline = time.time() + self.timeout  # Budget checked before each attempt
        for attempt in range(self.max_try):
            now = time.time()
            if now >= deadline:
                raise failure or TimeoutError(f"Total duration exceeded {self.timeout}s")
            req_timeout = min(self.timeout, max(60, int(deadline - time.time())))
            try:
                return self._attempt_once(body, headers, req_timeout)
            except Exception as exc:
                failure = exc
                tries_so_far = attempt + 1
                if tries_so_far % 10 == 0:
                    print(f" [Seed] Still failing after {attempt + 1}/{self.max_try} retries: {exc}", flush=True)
                more_attempts_left = attempt < self.max_try - 1
                if more_attempts_left and time.time() < deadline - 5:
                    time.sleep(min(30, 3 + 3 * (1.1 ** attempt)))
        raise failure
+
+
class GPTImageAPIGenerator(ImageGeneratorBase):
    """OpenAI Images API: edits when reference images are provided; otherwise generations."""

    def __init__(
        self,
        api_key: str,
        model_name: str = "gpt-image-1",
        timeout: int = 300,
        max_try: int = 5,
        print_log: bool = False,
    ):
        """Keep credentials, model name and retry settings for later calls."""
        import requests
        self.requests = requests
        # OpenAI official API v1 base; endpoints: /v1/images/generations, /v1/images/edits
        self.base_url = "https://api.openai.com/v1"
        self.api_key = api_key
        self.model_name = model_name
        self.timeout = timeout
        self.max_try = max_try
        self.print_log = print_log

    def _build_request(self, prompt: str, encoded_refs: List[str]):
        """Return ``(url, payload_dict)`` for the edits or generations endpoint."""
        if encoded_refs:
            return f"{self.base_url}/images/edits", {
                "model": "gpt-image-1",
                "prompt": prompt,
                "image": encoded_refs,
                "n": 1,
                "size": "auto",
                "quality": "auto",
            }
        fields = {
            "model": self.model_name,
            "prompt": prompt,
            "n": 1,
            "size": "1024x1024",
            "quality": "standard" if self.model_name == "dall-e-3" else "auto",
        }
        if self.model_name == "gpt-image-1":
            fields["background"] = "auto"
            fields["moderation"] = "low"
            fields["output_compression"] = 100
            fields["size"] = "auto"
        return f"{self.base_url}/images/generations", fields

    def _attempt_once(self, url: str, body: str, headers: dict, req_timeout: int):
        """Perform one API round trip and decode the returned image; raise on failure."""
        resp = _api_request_with_retry(
            "POST", url, body, headers, req_timeout, 1, self.print_log
        )  # 1=no inner retry; generate() owns the retry loop
        if resp is None:
            raise RuntimeError("GPT Image: no response")
        parsed = resp.json()
        if resp.status_code != 200 or "data" not in parsed:
            raise RuntimeError(f"GPT Image API error: {parsed}")
        first = parsed["data"][0]
        if "b64_json" in first:
            return _base642img(first["b64_json"])
        if "url" in first:
            download = self.requests.get(first["url"], timeout=req_timeout)
            if download.status_code != 200:
                raise RuntimeError(f"GPT Image: download failed: {download.status_code}")
            return Image.open(io.BytesIO(download.content))
        raise RuntimeError("GPT Image response contains no image data")

    def generate(self, prompt: str, image_paths: Optional[List[str]] = None):
        """Generate one image; uses the edits endpoint when a reference image file exists.

        Paths that are not regular files are ignored; with no usable reference
        the generations endpoint is used. Failed attempts are retried up to
        ``max_try`` times and the last error is raised at the end.
        """
        if image_paths:
            encoded_refs = [_img2base64(p) for p in image_paths if os.path.isfile(p)]
        else:
            encoded_refs = []

        url, fields = self._build_request(prompt, encoded_refs)
        body = json.dumps(fields)
        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}

        failure = None
        deadline = time.time() + self.timeout  # Budget checked before each attempt
        for attempt in range(self.max_try):
            if time.time() >= deadline:
                raise failure or TimeoutError(f"Total duration exceeded {self.timeout}s")
            req_timeout = min(self.timeout, max(60, int(deadline - time.time())))
            try:
                return self._attempt_once(url, body, headers, req_timeout)
            except Exception as exc:
                failure = exc
                if (attempt + 1) % 10 == 0:
                    print(
                        f" [GPT Image] Still failing after {attempt + 1}/{self.max_try} retries: {exc}",
                        flush=True,
                    )
                more_attempts_left = attempt < self.max_try - 1
                if more_attempts_left and time.time() < deadline - 5:
                    time.sleep(min(30, 3 + 3 * (1.1 ** attempt)))
        raise failure
+
+
+# ---------------------------------------------------------------------------
+# Local diffusers: Qwen (separate gen/edit), LongCat, Z-Image, Z-Image-Turbo, FLUX
+# ---------------------------------------------------------------------------
+
+
class DiffuserQwenGenerator(ImageGeneratorBase):
    """
    Text-only uses the gen model (e.g. Qwen-Image) on gen_device.
    Text+image uses the edit model (e.g. Qwen-Image-Edit-2509) on the edit device pool.
    gen_device: e.g. "cuda:0"
    edit_device: e.g. "cuda:2,cuda:3" (comma-separated)

    Scheduling strategy:
    - Text-only tasks are serialized by a lock around the gen pipeline (concurrency=1)
    - Text+image tasks take a slot from the edit pool (concurrency = number of edit devices)
    - ``max_parallel`` is set to min(3, 1 + number of edit devices)

    Pipelines are loaded lazily on first use.
    """

    def __init__(
        self,
        gen_model: str = "Qwen/Qwen-Image",
        edit_model: str = "Qwen/Qwen-Image-Edit-2509",
        gen_device: str = "cuda:0",
        edit_device: str = "cuda:1",
        torch_dtype: Optional[str] = None,
    ):
        """Record model ids and devices and set up the edit-slot pool; loads no model."""
        import torch
        self.torch = torch
        self.gen_device = gen_device
        self.edit_devices = [d.strip() for d in str(edit_device).split(",") if d.strip()]
        if not self.edit_devices:
            self.edit_devices = ["cuda:1"]
        # Unknown dtype names fall back to bfloat16.
        dtype = getattr(torch, torch_dtype or "bfloat16", torch.bfloat16)
        self._gen_pipe = None
        self._edit_pipes = [None for _ in self.edit_devices]
        self._gen_model_id = (gen_model or "").strip() or "Qwen/Qwen-Image"
        self._edit_model_id = (edit_model or "").strip() or "Qwen/Qwen-Image-Edit-2509"
        self._dtype = dtype
        self._pipe_init_lock = threading.Lock()
        self._gen_lock = threading.Lock()
        import queue
        # One token per edit device; generate() borrows a token for each edit call.
        self._edit_slots = queue.Queue()
        for i in range(len(self.edit_devices)):
            self._edit_slots.put(i)
        self.max_parallel = min(3, 1 + len(self.edit_devices))

    def _get_gen_pipe(self):
        """Return the text-to-image pipeline, loading it on first use."""
        if self._gen_pipe is None:
            with self._pipe_init_lock:
                # Re-check under the lock so concurrent callers load the model only once.
                if self._gen_pipe is None:
                    from diffusers import DiffusionPipeline
                    self._gen_pipe = DiffusionPipeline.from_pretrained(
                        self._gen_model_id, torch_dtype=self._dtype
                    ).to(self.gen_device)
        return self._gen_pipe

    def _get_edit_pipe(self, slot_idx: int):
        """Return the edit pipeline bound to ``edit_devices[slot_idx]``, loading it on first use."""
        if self._edit_pipes[slot_idx] is None:
            with self._pipe_init_lock:
                if self._edit_pipes[slot_idx] is None:
                    from diffusers import QwenImageEditPlusPipeline
                    pipe = QwenImageEditPlusPipeline.from_pretrained(
                        self._edit_model_id, torch_dtype=self._dtype
                    ).to(self.edit_devices[slot_idx])
                    pipe.set_progress_bar_config(disable=None)
                    self._edit_pipes[slot_idx] = pipe
        return self._edit_pipes[slot_idx]

    def generate(self, prompt: str, image_paths: Optional[List[str]] = None):
        """Generate one image; uses the edit model when a reference image file exists.

        Args:
            prompt: Text prompt.
            image_paths: Optional reference images; only the first three are
                considered and paths that are not files are dropped. If none
                remain, generation falls back to text-only.

        Returns:
            The first image produced by the pipeline.
        """
        if not image_paths:
            with self._gen_lock:
                pipe = self._get_gen_pipe()
                positive_magic = ", Ultra HD, 4K, cinematic composition."
                image = pipe(
                    prompt=prompt + positive_magic,
                    negative_prompt=" ",
                    width=1664,
                    height=928,
                    num_inference_steps=50,
                    true_cfg_scale=4.0,
                    generator=self.torch.Generator(device=self.gen_device).manual_seed(0),
                ).images[0]
                return image

        image_paths = image_paths[:3]  # Qwen supports at most 3 reference images
        images = [Image.open(p) for p in image_paths if os.path.isfile(p)]
        if not images:
            return self.generate(prompt, None)
        slot_idx = self._edit_slots.get()
        try:
            # Everything after taking the slot runs inside try/finally: a failure while
            # building inputs (e.g. an invalid device) must still return the slot, or
            # later edit calls would block forever on an empty queue.
            inputs = {
                "image": images,
                "prompt": prompt,
                "generator": self.torch.Generator(device=self.edit_devices[slot_idx]).manual_seed(0),
                "true_cfg_scale": 4.0,
                "negative_prompt": " ",
                "num_inference_steps": 40,
                "guidance_scale": 1.0,
                "num_images_per_prompt": 1,
            }
            pipe = self._get_edit_pipe(slot_idx)
            with self.torch.inference_mode():
                out = pipe(**inputs)
            return out.images[0]
        finally:
            self._edit_slots.put(slot_idx)
+
+
class DiffuserLongCatGenerator(ImageGeneratorBase):
    """
    LongCat: text-only uses LongCatImagePipeline (gen); text+image uses LongCatImageEditPipeline (edit).
    Similar to Qwen, gen and edit can use different devices. Pipelines are loaded lazily.
    """

    def __init__(
        self,
        gen_model: str = "meituan-longcat/LongCat-Image",
        edit_model: str = "meituan-longcat/LongCat-Image-Edit",
        gen_device: str = "cuda:0",
        edit_device: str = "cuda:1",
        torch_dtype: Optional[str] = None,
    ):
        """Record model ids, devices and dtype; loads no model."""
        import torch
        self.torch = torch
        self.gen_device = gen_device
        self.edit_device = edit_device
        # Unknown dtype names fall back to bfloat16.
        dtype = getattr(torch, torch_dtype or "bfloat16", torch.bfloat16)
        self._gen_pipe = None
        self._edit_pipe = None
        self._gen_model_id = (gen_model or "").strip() or "meituan-longcat/LongCat-Image"
        self._edit_model_id = (edit_model or "").strip() or "meituan-longcat/LongCat-Image-Edit"
        self._dtype = dtype

    def _get_gen_pipe(self):
        """Return the text-to-image pipeline, loading it on first use."""
        if self._gen_pipe is None:
            from diffusers import LongCatImagePipeline
            self._gen_pipe = LongCatImagePipeline.from_pretrained(
                self._gen_model_id, torch_dtype=self._dtype
            ).to(self.gen_device)
        return self._gen_pipe

    def _get_edit_pipe(self):
        """Return the image-edit pipeline, loading it on first use."""
        if self._edit_pipe is None:
            from diffusers import LongCatImageEditPipeline
            self._edit_pipe = LongCatImageEditPipeline.from_pretrained(
                self._edit_model_id, torch_dtype=self._dtype
            ).to(self.edit_device)
        return self._edit_pipe

    def generate(self, prompt: str, image_paths: Optional[List[str]] = None):
        """Generate one image; edits the first existing reference image when one is given.

        Args:
            prompt: Text prompt.
            image_paths: Optional reference images. Paths that are not files are
                dropped; only the first remaining image is used. With none left,
                generation falls back to text-only.

        Returns:
            The first image produced by the pipeline.
        """
        if not image_paths:
            pipe = self._get_gen_pipe()
            image = pipe(
                prompt,
                height=768,
                width=1344,
                guidance_scale=4.0,
                num_inference_steps=50,
                num_images_per_prompt=1,
                generator=self.torch.Generator(self.gen_device).manual_seed(0),
                enable_cfg_renorm=True,
                enable_prompt_rewrite=True,
            ).images[0]
            return image

        # Validate the references before touching the edit pipeline, so the text-only
        # fallback does not pay for loading the edit model.
        images = [Image.open(p).convert("RGB") for p in image_paths if os.path.isfile(p)]
        if not images:
            return self.generate(prompt, None)
        pipe = self._get_edit_pipe()
        img = images[0]
        image = pipe(
            img,
            prompt,
            negative_prompt="",
            guidance_scale=4.5,
            num_inference_steps=50,
            num_images_per_prompt=1,
            generator=self.torch.Generator(self.edit_device).manual_seed(0),
        ).images[0]
        return image
+
+
class DiffuserZImageGenerator(ImageGeneratorBase):
    """Z-Image text-to-image backend; reference images are not used.

    ``model_id`` is passed to ``ZImagePipeline.from_pretrained`` (blank falls
    back to "Tongyi-MAI/Z-Image"). The pipeline is loaded in the constructor.
    """

    def __init__(
        self,
        model_id: str = "Tongyi-MAI/Z-Image",
        device: str = "cuda",
        torch_dtype: str = "bfloat16",
    ):
        import torch
        from diffusers import ZImagePipeline

        self.torch = torch
        self.device = device
        resolved_id = (model_id or "").strip() or "Tongyi-MAI/Z-Image"
        # Unknown dtype names fall back to bfloat16.
        resolved_dtype = getattr(torch, torch_dtype, torch.bfloat16)
        loaded = ZImagePipeline.from_pretrained(
            resolved_id, torch_dtype=resolved_dtype, low_cpu_mem_usage=False
        )
        self.pipe = loaded.to(device)
        self._model_id = resolved_id

    def generate(self, prompt: str, image_paths: Optional[List[str]] = None):
        """Render ``prompt`` with a fixed seed of 0; ``image_paths`` is ignored."""
        seeded = self.torch.Generator(self.device).manual_seed(0)
        call_kwargs = dict(
            prompt=prompt,
            negative_prompt="",
            height=1280,
            width=720,
            cfg_normalization=False,
            num_inference_steps=50,
            guidance_scale=4,
            generator=seeded,
        )
        result = self.pipe(**call_kwargs)
        return result.images[0]
+
+
class DiffuserZImageTurboGenerator(ImageGeneratorBase):
    """Z-Image-Turbo text-to-image backend; reference images are not used.

    ``model_id`` is passed to ``ZImagePipeline.from_pretrained`` (blank falls
    back to "Tongyi-MAI/Z-Image-Turbo"). The pipeline is loaded in the constructor.
    """

    def __init__(
        self,
        model_id: str = "Tongyi-MAI/Z-Image-Turbo",
        device: str = "cuda",
        torch_dtype: str = "bfloat16",
    ):
        import torch
        from diffusers import ZImagePipeline

        self.torch = torch
        self.device = device
        resolved_id = (model_id or "").strip() or "Tongyi-MAI/Z-Image-Turbo"
        # Unknown dtype names fall back to bfloat16.
        resolved_dtype = getattr(torch, torch_dtype, torch.bfloat16)
        loaded = ZImagePipeline.from_pretrained(
            resolved_id, torch_dtype=resolved_dtype, low_cpu_mem_usage=False
        )
        self.pipe = loaded.to(device)

    def generate(self, prompt: str, image_paths: Optional[List[str]] = None):
        """Render ``prompt`` with a fixed seed of 0; ``image_paths`` is ignored."""
        seeded = self.torch.Generator(self.device).manual_seed(0)
        call_kwargs = dict(
            prompt=prompt,
            height=1024,
            width=1024,
            num_inference_steps=9,
            guidance_scale=0.0,
            generator=seeded,
        )
        result = self.pipe(**call_kwargs)
        return result.images[0]
+
+
+def _resolve_local_model_path_case_insensitive_contains(path_hint: str, keyword: str) -> Optional[str]:
+ """
+ Case-insensitive "contains" matching for local model paths.
+ - If path_hint exists and contains keyword, return it.
+ - If path_hint is a directory, recursively match subdirectories.
+ - If path_hint does not exist, fall back to the nearest existing parent directory and recurse.
+ """
+ if not path_hint:
+ return None
+ keyword_l = (keyword or "").strip().lower()
+ if not keyword_l:
+ return None
+
+ hint = os.path.expanduser(path_hint.strip())
+ if os.path.exists(hint) and keyword_l in hint.lower():
+ return hint
+
+ roots = []
+ if os.path.isdir(hint):
+ roots.append(hint)
+ else:
+ cur = hint
+ while cur and cur != os.path.dirname(cur):
+ parent = os.path.dirname(cur)
+ if parent and os.path.isdir(parent):
+ roots.append(parent)
+ break
+ cur = parent
+
+ matches = []
+ for root in roots:
+ for dirpath, dirnames, _ in os.walk(root):
+ for d in dirnames:
+ full = os.path.join(dirpath, d)
+ if keyword_l in full.lower():
+ matches.append(full)
+
+ if not matches:
+ return None
+ matches.sort(key=lambda x: (len(x), x.lower()))
+ return matches[0]
+
+
class DiffuserFluxGenerator(ImageGeneratorBase):
    """
    Single FLUX text-to-image backend covering several checkpoints (text-only):
    - FLUX.1-dev
    - FLUX.1-Krea-dev
    - FLUX.2-klein-9B
    - FLUX.2-klein-4B
    Which one is used is decided by a case-insensitive substring match on model_path.
    """

    force_text_only = True

    _VARIANTS = [
        ("flux.1-krea-dev", "black-forest-labs/FLUX.1-Krea-dev"),
        ("flux.1-dev", "black-forest-labs/FLUX.1-dev"),
        ("flux.2-klein-9b", "black-forest-labs/FLUX.2-klein-9B"),
        ("flux.2-klein-4b", "black-forest-labs/FLUX.2-klein-4B"),
    ]

    def __init__(self, model_path: str, device: str = "cuda", torch_dtype: str = "bfloat16"):
        """Pick the variant from ``model_path``, then load the matching pipeline onto ``device``.

        A local directory containing the variant keyword is preferred; otherwise
        the variant's Hugging Face id is used.

        Raises:
            ValueError: ``model_path`` contains none of the supported variant keywords.
        """
        import torch
        self.torch = torch
        self.device = device
        # Unknown dtype names fall back to bfloat16.
        self.dtype = getattr(torch, torch_dtype, torch.bfloat16)

        hint = (model_path or "").strip()
        lowered_hint = hint.lower()
        matched = next(
            ((key, hub_id) for key, hub_id in self._VARIANTS if key in lowered_hint),
            None,
        )
        if matched is None:
            supported = ", ".join(k for k, _ in self._VARIANTS)
            raise ValueError(
                f"diffuser_flux cannot infer model variant from path: {hint}. "
                f"It must contain one of: {supported}"
            )
        variant, hub_id = matched
        local_dir = _resolve_local_model_path_case_insensitive_contains(hint, variant)

        self.variant = variant
        self.model_ref = local_dir or hub_id
        if "flux.2-klein" in self.variant:
            from diffusers import Flux2KleinPipeline as pipeline_cls
        else:
            from diffusers import FluxPipeline as pipeline_cls
        self.pipe = pipeline_cls.from_pretrained(self.model_ref, torch_dtype=self.dtype)
        self.pipe = self.pipe.to(self.device)

    def generate(self, prompt: str, image_paths: Optional[List[str]] = None):
        """Render ``prompt`` at 1024x1024 with per-variant settings; ``image_paths`` is ignored."""
        # FLUX backend is text-only by design; ignore image_paths.
        variant = self.variant
        if variant == "flux.1-dev":
            dev_kwargs = dict(
                height=1024,
                width=1024,
                guidance_scale=3.5,
                num_inference_steps=50,
                max_sequence_length=512,
                generator=self.torch.Generator("cpu").manual_seed(0),
            )
            return self.pipe(prompt, **dev_kwargs).images[0]

        if variant == "flux.1-krea-dev":
            krea_kwargs = dict(height=1024, width=1024, guidance_scale=4.5)
            return self.pipe(prompt, **krea_kwargs).images[0]

        # Remaining variants are the FLUX.2-klein checkpoints.
        klein_kwargs = dict(
            prompt=prompt,
            height=1024,
            width=1024,
            guidance_scale=1.0,
            num_inference_steps=4,
            generator=self.torch.Generator(device=self.device).manual_seed(0),
        )
        return self.pipe(**klein_kwargs).images[0]
+
+
class DiffuserLumina2Generator(ImageGeneratorBase):
    """
    Lumina-Image-2.0 text-to-image backend (text-only).
    - ``model_path`` is passed to ``Lumina2Pipeline.from_pretrained``; blank falls back
      to Alpha-VLLM/Lumina-Image-2.0
    - The pipeline is set up with enable_model_cpu_offload instead of being moved to ``device``
    """

    force_text_only = True

    def __init__(self, model_path: str, device: str = "cuda", torch_dtype: str = "bfloat16"):
        import torch
        from diffusers import Lumina2Pipeline

        self.torch = torch
        self.device = device
        # Unknown dtype names fall back to bfloat16.
        self.dtype = getattr(torch, torch_dtype, torch.bfloat16)

        checkpoint = (model_path or "").strip() or "Alpha-VLLM/Lumina-Image-2.0"
        pipeline = Lumina2Pipeline.from_pretrained(checkpoint, torch_dtype=self.dtype)
        self.pipe = pipeline
        # Offload to CPU as needed (aligned with the reference settings)
        self.pipe.enable_model_cpu_offload()

    def generate(self, prompt: str, image_paths: Optional[List[str]] = None):
        """Render ``prompt`` at 1024x1024 with a CPU generator seeded to 0; ``image_paths`` is ignored."""
        call_kwargs = dict(
            height=1024,
            width=1024,
            guidance_scale=4.0,
            num_inference_steps=50,
            cfg_trunc_ratio=0.25,
            cfg_normalization=True,
            generator=self.torch.Generator("cpu").manual_seed(0),
        )
        return self.pipe(prompt, **call_kwargs).images[0]
+
+
class DiffuserSD3Generator(ImageGeneratorBase):
    """
    Stable Diffusion 3.5 text-to-image backend (large / medium, text-only).
    The checkpoint is chosen by substring match on model_path:
    - contains "3.5-large"  -> stabilityai/stable-diffusion-3.5-large
    - contains "3.5-medium" -> stabilityai/stable-diffusion-3.5-medium
    - neither               -> large
    """

    force_text_only = True

    _VARIANTS = [
        ("3.5-large", "stabilityai/stable-diffusion-3.5-large"),
        ("3.5-medium", "stabilityai/stable-diffusion-3.5-medium"),
    ]

    def __init__(self, model_path: str, device: str = "cuda", torch_dtype: str = "bfloat16"):
        """Pick the variant from ``model_path`` and load its pipeline onto ``device``.

        A local directory containing the variant keyword is preferred over the
        Hugging Face id.
        """
        import torch
        from diffusers import StableDiffusion3Pipeline

        self.torch = torch
        self.device = device
        # Unknown dtype names fall back to bfloat16.
        self.dtype = getattr(torch, torch_dtype, torch.bfloat16)

        hint = (model_path or "").strip()
        lowered_hint = hint.lower()
        matched = next(
            ((key, hub_id) for key, hub_id in self._VARIANTS if key.lower() in lowered_hint),
            None,
        )
        if matched is not None:
            variant, hub_id = matched
            local_dir = _resolve_local_model_path_case_insensitive_contains(hint, variant)
        else:
            # If cannot infer from path, default to large
            variant = "3.5-large"
            hub_id = "stabilityai/stable-diffusion-3.5-large"
            local_dir = None

        self.variant = variant
        self.model_ref = local_dir or hub_id
        self.pipe = StableDiffusion3Pipeline.from_pretrained(self.model_ref, torch_dtype=self.dtype)
        self.pipe = self.pipe.to(self.device)

    def generate(self, prompt: str, image_paths: Optional[List[str]] = None):
        """Render ``prompt`` with per-variant step count and guidance; ``image_paths`` is ignored."""
        if self.variant == "3.5-medium":
            steps, guidance = 40, 4.5
        else:
            steps, guidance = 28, 3.5
        result = self.pipe(
            prompt,
            num_inference_steps=steps,
            guidance_scale=guidance,
        )
        return result.images[0]
+
+
class HunyuanImage3Generator(ImageGeneratorBase):
    """
    HunyuanImage-3 text-to-image (Transformers AutoModelForCausalLM).
    - Text prompt only; reference images are ignored.
    - Model path comes from diffuser_gen_model_path; default is ./HunyuanImage-3.
    """

    force_text_only = True

    def __init__(self, model_path: str, device_spec: Optional[str] = None):
        """Load the model (and its tokenizer) from ``model_path``.

        Args:
            model_path: Local path or model id; empty/None falls back to "./HunyuanImage-3".
            device_spec: Optional GPU selection. Supported forms are "cuda:0,1,2",
                "0,1,2" and "cuda:3". Empty, None, "auto" or anything without numeric
                indices leaves the environment untouched.
        """
        import os
        from transformers import AutoModelForCausalLM

        self.model_id = (model_path or "").strip() or "./HunyuanImage-3"

        # Reuse the module-level parser instead of a private copy of the same loop,
        # so every backend interprets device specs identically. It returns [] for
        # empty input and for "auto" (no numeric entries), so no special-casing is needed.
        indices = _parse_device_list(device_spec or "")
        if indices:
            # Restrict the visible GPUs; device_map="auto" below then spreads the
            # model over whatever is visible.
            # NOTE(review): presumably this must run before CUDA is initialised in
            # this process to take effect - confirm against the callers.
            os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(indices)

        kwargs = dict(
            attn_implementation="sdpa",
            trust_remote_code=True,
            torch_dtype="auto",
            device_map="auto",
            moe_impl="eager",
        )
        self.model = AutoModelForCausalLM.from_pretrained(self.model_id, **kwargs)
        self.model.load_tokenizer(self.model_id)

    def generate(self, prompt: str, image_paths: Optional[List[str]] = None):
        """Generate one image for ``prompt``; ``image_paths`` is ignored."""
        _ = image_paths
        return self.model.generate_image(prompt=prompt, stream=True)
+
+
+# ---------------------------------------------------------------------------
+# Bagel: text-to-image generation
+# ---------------------------------------------------------------------------
+
+
class BagelTextGenerator(ImageGeneratorBase):
    """
    Bagel text-to-image (text-only).
    - force_text_only=True: ignore reference_images
    - mode=1: normal mode (aligned with eval/new_test/t2i_infer_batch.py defaults)
    - Model path is passed via --diffuser-gen-model-path to gen_image_from_results.py

    The constructor builds the model skeleton with empty weights, derives a device
    map from the free memory of the visible GPUs, loads ``ema.safetensors`` from the
    model directory and wraps everything in an ``InterleaveInferencer``.
    """

    force_text_only = True

    def __init__(
        self,
        model_path: str,
        mode: int = 1,
        device_spec: Optional[str] = None,
        image_shapes: tuple = (1024, 1024),
        cfg_text_scale: float = 4.0,
        cfg_interval: tuple = (0.4, 1.0),
        timestep_shift: float = 3.0,
        num_timesteps: int = 50,
        cfg_renorm_min: float = 0.0,
        cfg_renorm_type: str = "global",
        seed: int = 0,
    ):
        """Load a Bagel checkpoint directory and prepare it for inference.

        Args:
            model_path: Directory that must contain ``llm_config.json``,
                ``vit_config.json``, ``ae.safetensors`` and ``ema.safetensors``
                (these are the files opened below). Required.
            mode: Only ``1`` is accepted.
            device_spec: GPU selection such as "cuda:0", "cuda:0,1" or "0,1";
                empty/None is treated as "cuda:0", "auto" leaves the environment alone.
            image_shapes, cfg_text_scale, cfg_interval, timestep_shift, num_timesteps,
            cfg_renorm_min, cfg_renorm_type: Stored and forwarded to the inferencer
                on every ``generate`` call.
            seed: Seed re-applied to all RNGs at the start of every ``generate`` call.

        Raises:
            ValueError: If ``model_path`` is empty or ``mode`` is not 1.
            RuntimeError: If exactly three GPUs were requested but fewer are visible or
                one of them has too little free memory, or if loading the checkpoint fails.

        Environment variables read: BAGEL_MAX_MEM_PER_GPU, BAGEL_MIN_FREE_GIB,
        BAGEL_MAX_MEM_CPU. CUDA_VISIBLE_DEVICES is overwritten when ``device_spec``
        contains numeric indices.
        """
        import os
        self.model_path = (model_path or "").strip()
        if not self.model_path:
            raise ValueError("bagel backend requires a model path (pass via --diffuser-gen-model-path)")
        # mode=1 means normal mode. For compatibility, only normal mode is implemented.
        if mode != 1:
            raise ValueError("bagel backend currently only supports normal mode (mode=1)")

        # Reuse DIFFUSER_GEN_DEVICE format (e.g. cuda:0 / cuda:0,1 / 0,1)
        dev = (device_spec or "").strip() or "cuda:0"
        dev_l = dev.lower().strip()
        # Number of GPU indices explicitly requested; stays None for "auto" or
        # when no numeric index could be parsed.
        requested_gpu_count = None
        if dev_l and dev_l != "auto":
            parts = [p.strip() for p in dev.split(",") if p.strip()]
            indices = []
            for p in parts:
                # The "cuda:" prefix is matched case-sensitively; entries that are not
                # plain digits after stripping it are silently dropped.
                if p.startswith("cuda:"):
                    p = p[len("cuda:") :]
                if p.isdigit():
                    indices.append(p)
            if indices:
                # Set before torch is imported below, so device_count() reflects the selection.
                os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(indices)
                requested_gpu_count = len(indices)

        import random
        import numpy as np
        import torch
        from accelerate import infer_auto_device_map, load_checkpoint_and_dispatch, init_empty_weights

        # NOTE(review): these are top-level module names (data, modeling, inferencer);
        # presumably they come from a Bagel checkout on sys.path - confirm.
        from data.transforms import ImageTransform
        from data.data_utils import add_special_tokens
        from modeling.bagel import (
            Bagel,
            BagelConfig,
            Qwen2Config,
            Qwen2ForCausalLM,
            SiglipVisionConfig,
            SiglipVisionModel,
        )
        from modeling.autoencoder import load_ae
        from modeling.qwen2 import Qwen2Tokenizer
        from inferencer import InterleaveInferencer

        # Keep module handles on the instance so generate() can seed them without re-importing.
        self.os = os
        self.random = random
        self.np = np
        self.torch = torch

        # ----------------- Initialize following the official Bagel pipeline -----------------
        llm_config = Qwen2Config.from_json_file(os.path.join(self.model_path, "llm_config.json"))
        llm_config.qk_norm = True
        llm_config.tie_word_embeddings = False
        llm_config.layer_module = "Qwen2MoTDecoderLayer"

        vit_config = SiglipVisionConfig.from_json_file(os.path.join(self.model_path, "vit_config.json"))
        vit_config.rope = False
        # One fewer hidden layer than the config file declares.
        vit_config.num_hidden_layers = vit_config.num_hidden_layers - 1

        vae_model, vae_config = load_ae(local_path=os.path.join(self.model_path, "ae.safetensors"))

        config = BagelConfig(
            visual_gen=True,
            visual_und=True,
            llm_config=llm_config,
            vit_config=vit_config,
            vae_config=vae_config,
            vit_max_num_patch_per_side=70,
            connector_act="gelu_pytorch_tanh",
            latent_patch_size=2,
            max_latent_size=64,
        )

        # Build the module tree without allocating real weights; they are loaded
        # and placed by load_checkpoint_and_dispatch further down.
        with init_empty_weights():
            language_model = Qwen2ForCausalLM(llm_config)
            vit_model = SiglipVisionModel(vit_config)
            model = Bagel(language_model, vit_model, config)
            model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config, meta=True)

        tokenizer = Qwen2Tokenizer.from_pretrained(self.model_path)
        tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)

        vae_transform = ImageTransform(1024, 512, 16)
        vit_transform = ImageTransform(980, 224, 14)

        # Dynamically estimate max_memory from free memory on visible GPUs to avoid forcing weights onto saturated GPUs
        max_mem_per_gpu = os.environ.get("BAGEL_MAX_MEM_PER_GPU", "40GiB")
        min_free_gib = int(os.environ.get("BAGEL_MIN_FREE_GIB", "8"))
        if torch.cuda.device_count() > 0:
            gib = 1024 ** 3
            max_memory = {}
            # One (device index, free GiB, usable GiB) tuple per visible GPU;
            # (-1.0, -1) marks a GPU whose free memory could not be queried.
            free_report = []
            for i in range(torch.cuda.device_count()):
                try:
                    free_bytes, _ = torch.cuda.mem_get_info(i)
                    # Reserve 2GiB buffer; keep at least 1GiB to avoid hitting the exact limit
                    usable_gib = max(1, int(free_bytes / gib) - 2)
                    max_memory[i] = f"{usable_gib}GiB"
                    free_report.append((i, round(free_bytes / gib, 2), usable_gib))
                except Exception:
                    # Query failed: fall back to the configured per-GPU cap.
                    max_memory[i] = max_mem_per_gpu
                    free_report.append((i, -1.0, -1))
            # Allow CPU offload to reduce OOM risk
            max_memory["cpu"] = os.environ.get("BAGEL_MAX_MEM_CPU", "256GiB")
            print(f"[bagel] visible_cuda={torch.cuda.device_count()}, requested={requested_gpu_count or 'auto'}", flush=True)
            for idx, free_gib, usable in free_report:
                print(f"[bagel] cuda:{idx} free={free_gib}GiB usable_for_map={usable}GiB", flush=True)

            # If user explicitly requested 3-GPU parallelism, all 3 GPUs must have sufficient free memory
            if requested_gpu_count == 3:
                # GPUs whose query failed (free_gib == -1.0) are excluded from this check.
                low_free = [idx for idx, free_gib, _ in free_report if free_gib >= 0 and free_gib < min_free_gib]
                if low_free:
                    raise RuntimeError(
                        f"Bagel requires 3-GPU parallelism, but these visible GPUs have free memory < {min_free_gib}GiB: {low_free}. "
                        f"Please free the corresponding GPUs and retry."
                    )
                if torch.cuda.device_count() < 3:
                    raise RuntimeError(
                        f"Bagel requires 3-GPU parallelism, but only {torch.cuda.device_count()} GPUs are visible."
                    )
        else:
            # No visible GPU: let accelerate choose without explicit limits.
            max_memory = None

        device_map = infer_auto_device_map(
            model,
            max_memory=max_memory,
            no_split_module_classes=["Bagel", "Qwen2MoTDecoderLayer"],
        )

        # Modules that are forced onto the device chosen for the token embeddings.
        same_device_modules = [
            "language_model.model.embed_tokens",
            "time_embedder",
            "latent_pos_embed",
            "vae2llm",
            "llm2vae",
            "connector",
            "vit_pos_embed",
        ]
        if torch.cuda.device_count() <= 1:
            # Zero or one GPU: only fill in entries missing from the map,
            # defaulting to the embedding device (or "cuda:0").
            first_device = device_map.get(same_device_modules[0], "cuda:0")
            for k in same_device_modules:
                device_map[k] = device_map.get(k, first_device)
        else:
            # Multi-GPU: overwrite entries already present in the map with the embedding device.
            first_device = device_map.get(same_device_modules[0])
            for k in same_device_modules:
                if k in device_map:
                    device_map[k] = first_device

        try:
            model = load_checkpoint_and_dispatch(
                model,
                checkpoint=os.path.join(self.model_path, "ema.safetensors"),
                device_map=device_map,
                offload_buffers=True,
                dtype=torch.bfloat16,
                force_hooks=True,
                offload_folder="/tmp/offload",
            )
        except Exception as e:
            # Re-raise with a backend-specific message; the original error is chained.
            raise RuntimeError(f"Bagel checkpoint load failed: {e}") from e
        self.model = model.eval()
        self.inferencer = InterleaveInferencer(
            model=self.model,
            vae_model=vae_model,
            tokenizer=tokenizer,
            vae_transform=vae_transform,
            vit_transform=vit_transform,
            new_token_ids=new_token_ids,
        )
        self.image_shapes = image_shapes
        self.cfg_text_scale = cfg_text_scale
        self.cfg_interval = cfg_interval
        self.timestep_shift = timestep_shift
        self.num_timesteps = num_timesteps
        self.cfg_renorm_min = cfg_renorm_min
        self.cfg_renorm_type = cfg_renorm_type
        self.seed = seed

    def generate(self, prompt: str, image_paths: Optional[List[str]] = None):
        """Generate one image for ``prompt``.

        All RNGs (random, numpy, torch and, when available, CUDA) are re-seeded with
        ``self.seed`` on every call, and cuDNN is switched to deterministic mode.

        Args:
            prompt: Text prompt forwarded to the inferencer as ``text``.
            image_paths: Ignored.

        Returns:
            The value stored under ``"image"`` in the inferencer's output dict.

        Raises:
            RuntimeError: If the inferencer output is not a dict or has no image.
        """
        # Text-to-image only; do not use reference_images
        _ = image_paths
        self.random.seed(self.seed)
        self.np.random.seed(self.seed)
        self.torch.manual_seed(self.seed)
        if self.torch.cuda.is_available():
            self.torch.cuda.manual_seed(self.seed)
            self.torch.cuda.manual_seed_all(self.seed)
        # These are process-wide torch settings, not per-instance ones.
        self.torch.backends.cudnn.deterministic = True
        self.torch.backends.cudnn.benchmark = False

        output_dict = self.inferencer(
            text=prompt,
            think=True,
            max_think_token_n=1024,
            do_sample=False,
            text_temperature=0.3,
            cfg_text_scale=self.cfg_text_scale,
            cfg_img_scale=1.0,
            # Stored as a tuple; passed on as a list.
            cfg_interval=list(self.cfg_interval),
            timestep_shift=self.timestep_shift,
            num_timesteps=self.num_timesteps,
            cfg_renorm_min=self.cfg_renorm_min,
            cfg_renorm_type=self.cfg_renorm_type,
            image_shapes=self.image_shapes,
        )
        image = output_dict.get("image") if isinstance(output_dict, dict) else None
        if image is None:
            raise RuntimeError("bagel inference failed: no image returned")
        return image
+
+
+# ---------------------------------------------------------------------------
+# Build generator
+# ---------------------------------------------------------------------------
+
+
def build_generator(args: argparse.Namespace) -> ImageGeneratorBase:
    """Instantiate the image generator selected by ``args`` (with environment fallbacks).

    Args:
        args: Parsed CLI namespace. Every field is read with ``getattr`` so a partial
            namespace works. Fields used: backend, api_type, api_key, model_name,
            timeout, max_try, print_log, diffuser_gen_model_path,
            diffuser_edit_model_path, diffuser_gen_device, diffuser_edit_device.

    Environment fallbacks (used when the matching field is missing or falsy):
        GEN_IMAGE_BACKEND, GEN_IMAGE_API_TYPE, GEN_IMAGE_API_KEY, GEN_IMAGE_MODEL,
        GEN_IMAGE_TIMEOUT, GEN_IMAGE_MAX_TRY.

    Returns:
        A ready-to-use generator instance for the chosen backend.

    Raises:
        ValueError: For the disabled ``bagel`` backend and for unknown backend names.
    """
    backend = (getattr(args, "backend", None) or os.environ.get("GEN_IMAGE_BACKEND", "api")).strip().lower()
    api_type = (getattr(args, "api_type", None) or os.environ.get("GEN_IMAGE_API_TYPE", "gpt")).strip().lower()

    if backend == "api":
        api_key = getattr(args, "api_key", None) or os.environ.get("GEN_IMAGE_API_KEY", "")
        model_name = getattr(args, "model_name", None) or os.environ.get("GEN_IMAGE_MODEL", "gpt-image-1")
        timeout = int(getattr(args, "timeout", None) or os.environ.get("GEN_IMAGE_TIMEOUT", "120"))
        max_try = int(getattr(args, "max_try", None) or os.environ.get("GEN_IMAGE_MAX_TRY", "5"))
        # Read like every other field so a namespace without print_log does not crash.
        print_log = getattr(args, "print_log", False)
        api_kwargs = dict(
            api_key=api_key,
            model_name=model_name,
            timeout=timeout,
            max_try=max_try,
            print_log=print_log,
        )
        if api_type == "nano":
            return NanoAPIGenerator(**api_kwargs)
        if api_type == "seed":
            return SeedAPIGenerator(**api_kwargs)
        # Any other api_type (including the default "gpt") uses the GPT image API.
        return GPTImageAPIGenerator(**api_kwargs)

    # Paths exactly as the user supplied them (possibly empty).
    raw_gen_model = (getattr(args, "diffuser_gen_model_path", None) or "").strip()
    raw_edit_model = (getattr(args, "diffuser_edit_model_path", None) or "").strip()
    # Qwen defaults: correct for diffuser_qwen only.
    gen_model = raw_gen_model or "Qwen/Qwen-Image"
    edit_model = raw_edit_model or "Qwen/Qwen-Image-Edit-2509"
    gen_device = getattr(args, "diffuser_gen_device", None) or "cuda:0"
    edit_device = getattr(args, "diffuser_edit_device", None) or "cuda:1"

    if backend == "diffuser_qwen":
        return DiffuserQwenGenerator(
            gen_model=gen_model,
            edit_model=edit_model,
            gen_device=gen_device,
            edit_device=edit_device,
        )

    if backend == "diffuser_longcat":
        return DiffuserLongCatGenerator(
            gen_model=raw_gen_model or "meituan-longcat/LongCat-Image",
            edit_model=raw_edit_model or "meituan-longcat/LongCat-Image-Edit",
            gen_device=gen_device,
            edit_device=edit_device,
        )

    # NOTE(review): the three backends below still receive the Qwen fallback id when
    # no path is given, because their own defaults are not visible here - confirm
    # whether they should receive the raw path as well.
    if backend == "diffuser_zimage":
        return DiffuserZImageGenerator(model_id=gen_model, device=gen_device)

    if backend == "diffuser_zimage_turbo":
        return DiffuserZImageTurboGenerator(model_id=gen_model, device=gen_device)

    if backend == "diffuser_flux":
        return DiffuserFluxGenerator(model_path=gen_model, device=gen_device)

    # The following generators apply their own default when given an empty path, so
    # they must get the raw value; handing them the Qwen fallback would make them
    # load a Qwen repo id with the wrong pipeline/model class.
    if backend == "diffuser_lumina2":
        return DiffuserLumina2Generator(model_path=raw_gen_model, device=gen_device)

    if backend == "diffuser_sd3":
        return DiffuserSD3Generator(model_path=raw_gen_model, device=gen_device)

    if backend == "hunyuan_image3":
        return HunyuanImage3Generator(model_path=raw_gen_model, device_spec=gen_device)

    if backend == "bagel":
        # Open-source build: bagel backend is disabled
        raise ValueError("bagel backend is disabled (open-source build does not support it)")

    raise ValueError(f"Unsupported backend: {backend}")
+
+
+# ---------------------------------------------------------------------------
+# Main: read JSON -> generate per record -> write images and results.json
+# ---------------------------------------------------------------------------
+
+
+def _safe_model_suffix(s: str) -> str:
+ return re.sub(r"[^\w\-]", "_", s).strip("_") or "model"
+
+
+def _pick_gt_image(rec: dict) -> Optional[str]:
+ """Compatibility for GT fields: prediction.gen_image / top-level gt_image / top-level gen_image."""
+ return (
+ (rec.get("prediction") or {}).get("gen_image")
+ or rec.get("gt_image")
+ or rec.get("gen_image")
+ )
+
+
def _run_one_record(
    generator: ImageGeneratorBase,
    rec: dict,
    index: int,
    images_dir: Path,
    print_log: bool = False,
) -> tuple[int, dict]:
    """Generate and save the image for a single record.

    The image is written to ``images_dir/<id>.png``, where ``<id>`` is the record's
    ``id`` or, when absent, ``index``. Any exception raised while generating or
    saving is captured in the returned entry instead of propagating.

    Returns:
        ``(index, entry)`` where ``entry`` describes the record, the prompt and
        images actually used, the output path, ``success`` and, on failure, ``error``.
    """
    record_id = rec.get("id", index)
    label = str(record_id)

    prompt, image_paths = get_effective_prompt_and_images(rec)
    # Text-only generators never receive reference images.
    text_only = getattr(generator, "force_text_only", False)
    if text_only:
        image_paths = []

    target = (images_dir / f"{label}.png").resolve()
    entry = {
        "id": record_id,
        "prompt": rec.get("prompt", ""),
        "gen_prompt": rec.get("gen_prompt", ""),
        "meta": rec.get("meta", {}),
        "gt_image": _pick_gt_image(rec),
        "used_prompt": prompt,
        "used_images": image_paths,
        "output_path": str(target),
        "success": False,
    }

    try:
        # An empty list is passed on as None.
        rendered = generator.generate(prompt, image_paths or None)
        rendered.save(target)
    except Exception as exc:
        entry.update(success=False, error=str(exc))
        if print_log:
            print(f"[{label}] Generation failed: {exc}", flush=True)
    else:
        entry["success"] = True

    return (index, entry)
+
+
+def _parse_device_list(device_spec: str) -> List[str]:
+ """Parse 'cuda:1,cuda:2,cuda:3' / '1,2,3' into ['1','2','3']."""
+ raw = (device_spec or "").strip()
+ if not raw:
+ return []
+ out: List[str] = []
+ for p in raw.split(","):
+ p = p.strip()
+ if not p:
+ continue
+ if p.startswith("cuda:"):
+ p = p[len("cuda:") :]
+ if p.isdigit():
+ out.append(p)
+ return out
+
+
+def _run_bagel_worker_mode(args: argparse.Namespace) -> int:
+ """
+ Internal worker:
+ - Read shard input (JSON array)
+ - Run bagel generation serially
+ - Write results to _bagel_worker_output
+ """
+ if not args._bagel_worker_output:
+ raise ValueError("worker mode missing --_bagel-worker-output")
+
+ records = load_records(args.input)
+ out_root = Path(args.output_dir)
+ folder_name = _safe_model_suffix(getattr(args, "suffix", "default"))
+ save_dir = out_root / folder_name
+ save_dir.mkdir(parents=True, exist_ok=True)
+ images_dir = save_dir / "images"
+ images_dir.mkdir(parents=True, exist_ok=True)
+ results_path = save_dir / "results.json"
+
+ generator = build_generator(args)
+ worker_results = []
+ for i, rec in tqdm(list(enumerate(records)), desc=f"BagelWorker-{args._bagel_worker_id}", unit="sample"):
+ _, entry = _run_one_record(generator, rec, i, images_dir, args.print_log)
+ try:
+ entry["_orig_index"] = int(rec.get("__orig_index", i))
+ except Exception:
+ entry["_orig_index"] = i
+ worker_results.append(entry)
+
+ with open(args._bagel_worker_output, "w", encoding="utf-8") as f:
+ json.dump(worker_results, f, ensure_ascii=False, indent=2)
+ return 0
+
+
+def _run_bagel_data_parallel(args: argparse.Namespace, records: list, results_path: Path) -> bool:
+ """
+ Bagel data-parallel entry:
+ - When --diffuser-gen-device includes multiple GPUs, shard samples and launch worker subprocesses
+ - Each worker sees only one GPU and loads its own model copy
+ - Workers write partial outputs; the main process merges them into results.json
+
+ Returns True if handled and caller can return; False if this branch is not used.
+ """
+ if getattr(args, "_bagel_worker_mode", False):
+ return False
+ if getattr(args, "backend", "") != "bagel":
+ return False
+
+ # Main process: prefer CLI device list; fall back to env DIFFUSER_GEN_DEVICE
+ raw_device_spec = getattr(args, "diffuser_gen_device", None) or os.environ.get("DIFFUSER_GEN_DEVICE", "")
+ device_list = _parse_device_list(raw_device_spec)
+ if len(device_list) <= 1:
+ return False
+
+ success_ids = set()
+ id_to_index = {}
+ existing_results = []
+ if args.resume and results_path.exists():
+ try:
+ with open(results_path, "r", encoding="utf-8") as f:
+ existing_results = json.load(f)
+ for idx, r in enumerate(existing_results):
+ iid = r.get("id")
+ if iid is not None:
+ sid = str(iid)
+ id_to_index[sid] = idx
+ if r.get("success") is True:
+ success_ids.add(sid)
+ except Exception:
+ existing_results = []
+ id_to_index = {}
+ success_ids = set()
+
+ pending = []
+ for i, rec in enumerate(records):
+ sid = str(rec.get("id", i))
+ if args.resume and sid in success_ids:
+ continue
+ rec_copy = dict(rec)
+ rec_copy["__orig_index"] = i
+ pending.append(rec_copy)
+
+ if not pending:
+ if existing_results:
+ with open(results_path, "w", encoding="utf-8") as f:
+ json.dump(existing_results, f, ensure_ascii=False, indent=2)
+ else:
+ with open(results_path, "w", encoding="utf-8") as f:
+ json.dump([], f, ensure_ascii=False, indent=2)
+ return True
+
+ shards = [[] for _ in device_list]
+ for idx, rec in enumerate(pending):
+ shards[idx % len(device_list)].append(rec)
+
+ temp_dir = Path(tempfile.mkdtemp(prefix="bagel_dp_"))
+ procs = []
+ worker_out_files = []
+
+ try:
+ script_path = Path(__file__).resolve()
+ for worker_id, (gpu_id, shard) in enumerate(zip(device_list, shards)):
+ if not shard:
+ continue
+ shard_in = temp_dir / f"shard_{worker_id}.json"
+ shard_out = temp_dir / f"worker_{worker_id}_results.json"
+ with open(shard_in, "w", encoding="utf-8") as f:
+ json.dump(shard, f, ensure_ascii=False, indent=2)
+
+ cmd = [
+ sys.executable,
+ str(script_path),
+ "--input",
+ str(shard_in),
+ "--output-dir",
+ str(args.output_dir),
+ "--suffix",
+ str(args.suffix),
+ "--backend",
+ "bagel",
+ "--diffuser-gen-model-path",
+ str(args.diffuser_gen_model_path or ""),
+ "--diffuser-gen-device",
+ # Note: in the worker process CUDA_VISIBLE_DEVICES is controlled by env;
+ # use auto here to avoid BagelTextGenerator overriding env again
+ "auto",
+ "--_bagel-worker-mode",
+ "--_bagel-worker-id",
+ str(worker_id),
+ "--_bagel-worker-output",
+ str(shard_out),
+ ]
+ if args.print_log:
+ cmd.append("--print-log")
+ if getattr(args, "diffuser_edit_model_path", None):
+ cmd.extend(["--diffuser-edit-model-path", str(args.diffuser_edit_model_path)])
+ if getattr(args, "diffuser_edit_device", None):
+ cmd.extend(["--diffuser-edit-device", str(args.diffuser_edit_device)])
+
+ env = os.environ.copy()
+ env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
+ if args.print_log:
+ print(
+ f"[bagel-dp] start worker_id={worker_id} gpu_id={gpu_id} shard_size={len(shard)} CUDA_VISIBLE_DEVICES={env['CUDA_VISIBLE_DEVICES']}",
+ flush=True,
+ )
+ p = subprocess.Popen(cmd, env=env)
+ procs.append((worker_id, p))
+ worker_out_files.append(shard_out)
+
+ for worker_id, p in procs:
+ rc = p.wait()
+ if rc != 0:
+ raise RuntimeError(f"bagel worker {worker_id} exited with non-zero code: {rc}")
+
+ merged = [None] * len(records)
+ if args.resume and existing_results:
+ for i, rec in enumerate(records):
+ sid = str(rec.get("id", i))
+ if sid in id_to_index:
+ old = existing_results[id_to_index[sid]]
+ if old.get("success") is True:
+ merged[i] = old
+ for out_file in worker_out_files:
+ if not out_file.exists():
+ continue
+ with open(out_file, "r", encoding="utf-8") as f:
+ part = json.load(f)
+ for entry in part:
+ orig_index = int(entry.pop("_orig_index", -1))
+ if 0 <= orig_index < len(records):
+ merged[orig_index] = entry
+
+ # Keep one-to-one correspondence with input records to avoid holes
+ for i, rec in enumerate(records):
+ if merged[i] is None:
+ merged[i] = {"id": rec.get("id", i), "pending": True}
+
+ with open(results_path, "w", encoding="utf-8") as f:
+ json.dump(merged, f, ensure_ascii=False, indent=2)
+ return True
+ finally:
+ shutil.rmtree(temp_dir, ignore_errors=True)
+
+
def main():
    """CLI entry point: read records, generate one image per record, write results.json.

    Output layout: ``<output-dir>/<suffix>/images/<id>.png`` plus
    ``<output-dir>/<suffix>/results.json``.

    Two execution paths:
    - Parallel (backend ``api`` with --parallel > 1, or ``diffuser_qwen`` when the
      generator exposes ``max_parallel`` > 1): records run in a thread pool and
      results.json always has one entry per input record, in input order.
    - Serial (everything else): records run one by one and entries are appended, or
      replaced in place when resuming.

    In both paths results.json is rewritten after every finished record, and with
    --resume any record whose previous entry has ``success is True`` is skipped.
    """
    parser = argparse.ArgumentParser(description="Generate images from result JSON (API or local diffuser)")
    parser.add_argument("--input", "-i", required=True, help="Input JSON/JSONL (id/prompt/gen_prompt/reference_images or prompt-only)")
    parser.add_argument("--output-dir", "-o", required=True, help="Output root directory")
    parser.add_argument("--suffix", default="default", help="Output subdirectory name; results under OUTPUT_DIR/{suffix}/")
    parser.add_argument("--resume", action="store_true", help="Skip ids that already have output")
    parser.add_argument("--print-log", action="store_true", help="Print request/retry logs")

    parser.add_argument(
        "--backend",
        choices=[
            "api",
            "diffuser_qwen",
            "diffuser_longcat",
            "diffuser_zimage",
            "diffuser_zimage_turbo",
            "diffuser_flux",
            "diffuser_lumina2",
            "diffuser_sd3",
            "hunyuan_image3",
        ],
        default="api",
    )
    parser.add_argument("--api-type", choices=["nano", "seed", "gpt"], default="gpt")
    parser.add_argument("--api-key", default=None)
    parser.add_argument("--model-name", default=None)
    parser.add_argument("--timeout", type=int, default=None)
    parser.add_argument("--max-try", type=int, default=None)
    parser.add_argument("--parallel", type=int, default=12, help="API-mode parallelism (effective only when backend=api)")

    parser.add_argument("--diffuser-gen-model-path", default=None, help="Diffuser gen model: local path or HuggingFace id (e.g. Qwen/Qwen-Image); HF ids will be downloaded")
    parser.add_argument("--diffuser-edit-model-path", default=None, help="Diffuser edit model: local path or HuggingFace id")
    parser.add_argument("--diffuser-gen-device", default=None, help="Diffuser gen device, e.g. cuda:0")
    parser.add_argument("--diffuser-edit-device", default=None, help="Diffuser edit device, e.g. cuda:1")
    # Open-source build disables bagel, so hidden bagel-worker args are not kept
    args = parser.parse_args()

    # if args._bagel_worker_mode:
    #     _run_bagel_worker_mode(args)
    #     return

    records = load_records(args.input)
    out_root = Path(args.output_dir)
    # The suffix is sanitised before being used as a directory name.
    folder_name = _safe_model_suffix(getattr(args, "suffix", "default"))
    save_dir = out_root / folder_name
    save_dir.mkdir(parents=True, exist_ok=True)
    images_dir = save_dir / "images"
    images_dir.mkdir(parents=True, exist_ok=True)
    results_path = save_dir / "results.json"

    # Bagel multi-GPU data-parallel mode: disabled

    generator = build_generator(args)
    backend = getattr(args, "backend", "api")
    # Decide how many records may be in flight at once.
    if backend == "api":
        worker_count = max(1, int(getattr(args, "parallel", 1)))
    elif backend == "diffuser_qwen":
        # The qwen generator may advertise its own concurrency limit.
        worker_count = max(1, int(getattr(generator, "max_parallel", 1)))
    else:
        worker_count = 1
    use_parallel = worker_count > 1

    if use_parallel:
        # Parallel mode (API or qwen): keep results in input order and update by index
        by_id = {}
        if args.resume and results_path.exists():
            try:
                with open(results_path, "r", encoding="utf-8") as f:
                    loaded = json.load(f)
                by_id = {str(r.get("id")): r for r in loaded if r.get("id") is not None}
            except Exception:
                # Unreadable previous results are ignored; everything is regenerated.
                pass
        # One slot per input record; None means "not done yet".
        results = [None] * len(records)
        for i, rec in enumerate(records):
            sid = str(rec.get("id", i))
            # Only previously successful entries are carried over.
            if sid in by_id and by_id[sid].get("success") is True:
                results[i] = by_id[sid]
        pending_indices = [i for i in range(len(records)) if results[i] is None or results[i].get("success") is not True]
        write_lock = threading.Lock()

        def _write_results():
            # Snapshot of all slots; unfinished ones are written as {"id", "pending": True}.
            out = []
            for i, r in enumerate(results):
                if r is not None:
                    out.append(r)
                else:
                    out.append({"id": records[i].get("id", i), "pending": True})
            with open(results_path, "w", encoding="utf-8") as f:
                json.dump(out, f, ensure_ascii=False, indent=2)

        with ThreadPoolExecutor(max_workers=worker_count) as ex:
            # Map each future back to its record index for the error path below.
            futures = {
                ex.submit(
                    _run_one_record,
                    generator,
                    records[i],
                    i,
                    images_dir,
                    args.print_log,
                ): i
                for i in pending_indices
            }
            for fut in tqdm(as_completed(futures), total=len(pending_indices), desc="GenImage", unit="sample"):
                try:
                    i, entry = fut.result()
                    with write_lock:
                        results[i] = entry
                        _write_results()
                except Exception as e:
                    # The task itself raised: record a failure entry for that index.
                    idx = futures[fut]
                    rec = records[idx]
                    with write_lock:
                        results[idx] = {
                            "id": rec.get("id", idx),
                            "prompt": rec.get("prompt", ""),
                            "gen_prompt": rec.get("gen_prompt", ""),
                            "meta": rec.get("meta", {}),
                            "gt_image": _pick_gt_image(rec),
                            "success": False,
                            "error": str(e),
                        }
                        _write_results()
        # Final write (aligned with input order; unfinished entries keep pending=True)
        final_results = [
            results[i] if results[i] is not None else {"id": records[i].get("id", i), "pending": True}
            for i in range(len(records))
        ]
        with open(results_path, "w", encoding="utf-8") as f:
            json.dump(final_results, f, ensure_ascii=False, indent=2)
    else:
        # Serial mode (non-API backends or parallel=1)
        results = []
        success_ids = set()
        # Maps str(id) -> position in the loaded results list, for in-place replacement.
        id_to_index = {}
        if args.resume and results_path.exists():
            try:
                with open(results_path, "r", encoding="utf-8") as f:
                    results = json.load(f)
                for idx, r in enumerate(results):
                    iid = r.get("id")
                    if iid is not None:
                        sid = str(iid)
                        id_to_index[sid] = idx
                        if r.get("success") is True:
                            success_ids.add(sid)
            except Exception:
                # Best effort: whatever was loaded or indexed before the error is kept as is.
                pass

        # Skip records that already succeeded (only when resuming).
        to_run = [
            (i, rec)
            for i, rec in enumerate(records)
            if not (args.resume and str(rec.get("id", i)) in success_ids)
        ]
        for i, rec in tqdm(to_run, desc="GenImage", unit="sample"):
            _, entry = _run_one_record(generator, rec, i, images_dir, args.print_log)
            rid = rec.get("id", i)
            sid = str(rid)
            if args.resume and sid in id_to_index:
                # Replace the earlier (failed or pending) entry for this id.
                results[id_to_index[sid]] = entry
            else:
                results.append(entry)
            # Rewritten after every record so progress survives an interruption.
            with open(results_path, "w", encoding="utf-8") as f:
                json.dump(results, f, ensure_ascii=False, indent=2)
    print(f"Done. Output dir: {save_dir}; results: {results_path}")
+
+
# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
diff --git a/vendor/rllm/eval/run_eval.sh b/vendor/rllm/eval/run_eval.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6a459a424488feb646bce70d3590fbbebf5c8845
--- /dev/null
+++ b/vendor/rllm/eval/run_eval.sh
@@ -0,0 +1,8 @@
#!/usr/bin/env bash
# Run eval.eval_runner against an OpenAI-compatible endpoint and mirror all output
# to eval/log.txt.
#
# pipefail matters here: without it the pipeline's exit status is tee's, so a failed
# python run would still make this script exit 0.
set -euo pipefail

# NOTE: --parquet and --base-url below are placeholders for a specific deployment;
# edit them to match your dataset path and inference server before running.
python3 -m eval.eval_runner \
    --parquet /path/to/test.parquet \
    --base-url http://10.144.200.237:8001/v1 \
    --model Vision-DeepResearch-8B \
    --parallel-tasks 5 \
    --api-key EMPTY \
    2>&1 | tee eval/log.txt
+
diff --git a/vendor/rllm/eval/run_gen_image_eval.sh b/vendor/rllm/eval/run_gen_image_eval.sh
new file mode 100644
index 0000000000000000000000000000000000000000..284561e2ddb15f18cd509eca505a2c67c4ed20b8
--- /dev/null
+++ b/vendor/rllm/eval/run_gen_image_eval.sh
@@ -0,0 +1,55 @@
#!/usr/bin/env bash
# Gen Image eval: only produce trajectory logs, gen_prompt, and reference_images; no image generation or scoring
#
# Input JSON format (array or JSONL, one object per line):
#   {"id": 7012, "prompt": "...", "meta": {"category": "...", ...}, "gen_image": "/path/to/gt.png"}
# gen_image is the ground-truth path in the legacy format; current logic does not use it.
#
# Environment variables (optional; each overrides the default set further down):
#   GEN_EVAL_INPUT_JSON   input JSON path            (default: KnowGen-Bench.json)
#   GEN_EVAL_OUTPUT_DIR   output directory           (default: ./eval_output/test)
#   OPENAI_BASE_URL       inference service base_url (default: placeholder, must be edited)
#   GEN_EVAL_MODEL        model name                 (default: Gen-Searcher-8B)
#   OPENAI_API_KEY        API Key (not passed on the command line by this script)
# Values exported by .env.gen_image count as overrides too, because that file is
# sourced before the defaults are applied.
set -euo pipefail

# Ensure we run under the rllm directory
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
RLLM_DIR="$(cd "${SCRIPT_DIR}/.." && pwd)"
cd "${RLLM_DIR}"

# Load Gen Image environment variables (required by search/browse/image tools)
if [ -f .env.gen_image ]; then
    echo "[GenEval] Loading .env.gen_image"
    # set -a exports every variable assigned while the file is sourced.
    set -a
    source .env.gen_image
    set +a
else
    echo "[GenEval] WARNING: .env.gen_image not found; some tools may be unavailable"
fi

# Honour the documented environment overrides; fall back to the previous
# hard-coded values when a variable is unset or empty.
INPUT_JSON="${GEN_EVAL_INPUT_JSON:-KnowGen-Bench.json}"
SUFFIX="test"
OUTPUT_DIR="${GEN_EVAL_OUTPUT_DIR:-./eval_output/${SUFFIX}}"

# vLLM OpenAI-compatible API base URL for Gen-Searcher-8B (replace host/IP after deployment)
BASE_URL="${OPENAI_BASE_URL:-http://xxx:8001/v1}"
MODEL="${GEN_EVAL_MODEL:-Gen-Searcher-8B}"

mkdir -p "${OUTPUT_DIR}"
LOG_FILE="${OUTPUT_DIR}/gen_eval_$(date +%Y%m%d_%H%M%S).log"

# Extra command-line arguments given to this script are appended to the runner call.
python3 -m eval.gen_image_eval_runner \
    --input "${INPUT_JSON}" \
    --output-dir "${OUTPUT_DIR}" \
    --base-url "${BASE_URL}" \
    --model "${MODEL}" \
    --max-prompt-length 64000 \
    --max-response-length 64000 \
    --parallel-tasks 5 \
    --temperature 0.6 \
    --top-p 0.9 \
    --resume \
    "$@" 2>&1 | tee "${LOG_FILE}"
+
diff --git a/vendor/rllm/eval/run_gen_image_from_results.sh b/vendor/rllm/eval/run_gen_image_from_results.sh
new file mode 100644
index 0000000000000000000000000000000000000000..605ea18896d987394967882c5972c34a810bee49
--- /dev/null
+++ b/vendor/rllm/eval/run_gen_image_from_results.sh
@@ -0,0 +1,58 @@
#!/usr/bin/env bash
# Turn a result JSON (gen_prompt/reference_images) or a prompt-only JSON into images.
# Backends: remote APIs (nano/seed/gpt, with retries) or local diffusers
# (qwen gen+edit, zimage, zimage_turbo, etc.).
# Results land in OUTPUT_DIR/{suffix}/images/ and OUTPUT_DIR/{suffix}/results.json

set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
cd "${SCRIPT_DIR}"

# ----------------- I/O (edit here) -----------------
# Relative paths resolve against this script's directory (rllm/eval)
INPUT_JSON="eval_output/results.json"
OUTPUT_DIR="./output"
SUFFIX="qwen_image_test_gen"

# ----------------- Backend (edit here) -----------------
# One of: api | diffuser_qwen | diffuser_longcat | diffuser_zimage | diffuser_zimage_turbo |
#         diffuser_flux | diffuser_lumina2 | diffuser_sd3 | hunyuan_image3
BACKEND="diffuser_qwen"

# --- Settings that matter only when BACKEND is api ---
API_TYPE="nano"   # nano | seed | gpt
API_KEY=""
# e.g., "gemini-3-pro-image-preview" "doubao-seedream-4-5-251128"
MODEL_NAME="gemini-3-pro-image-preview"
TIMEOUT=1600
MAX_TRY=50
PARALLEL=10       # API-mode parallelism

# --- Settings that matter only when BACKEND is NOT api ---
# Shared by every diffuser_* backend; which ones are used depends on the backend.
# gen:  text-to-image model
# edit: image-editing model (only for backends that need both, e.g. qwen)
# An empty path makes Python fall back to the default HuggingFace id.
DIFFUSER_GEN_MODEL_PATH="Qwen/Qwen-Image"
DIFFUSER_EDIT_MODEL_PATH="Qwen/Qwen-Image-Edit-2509"
DIFFUSER_GEN_DEVICE="cuda:0"
DIFFUSER_EDIT_DEVICE="cuda:1,cuda:2,cuda:3"   # use more GPUs for Qwen-Image-Edit for acceleration

# ----------------- Run -----------------
mkdir -p "${OUTPUT_DIR}"

# Collect the arguments in an array so each value stays a single word.
gen_args=(
    --input "${INPUT_JSON}"
    --output-dir "${OUTPUT_DIR}"
    --suffix "${SUFFIX}"
    --backend "${BACKEND}"
    --api-type "${API_TYPE}"
    --api-key "${API_KEY}"
    --model-name "${MODEL_NAME}"
    --timeout "${TIMEOUT}"
    --max-try "${MAX_TRY}"
    --parallel "${PARALLEL}"
    --diffuser-gen-model-path "${DIFFUSER_GEN_MODEL_PATH}"
    --diffuser-edit-model-path "${DIFFUSER_EDIT_MODEL_PATH}"
    --diffuser-gen-device "${DIFFUSER_GEN_DEVICE}"
    --diffuser-edit-device "${DIFFUSER_EDIT_DEVICE}"
    --resume
    --print-log
)

python3 "${SCRIPT_DIR}/gen_image_from_results.py" "${gen_args[@]}"
diff --git a/vendor/rllm/examples/appworld/run_appworld_agent.py b/vendor/rllm/examples/appworld/run_appworld_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..47466a5b978f6a457de0e0caf0d38ca96e8ab4cb
--- /dev/null
+++ b/vendor/rllm/examples/appworld/run_appworld_agent.py
@@ -0,0 +1,136 @@
+import argparse
+import asyncio
+import os
+
+from transformers import AutoTokenizer
+
+from rllm.agents.appworld_react_agents import AppWorldReactAgent
+from rllm.engine.agent_execution_engine import AgentExecutionEngine
+from rllm.environments.appworld.appworld_env import AppWorldEnv
+
+# ============================================================================
+# Fix AppWorld multithreading issues: apply signal patch
+# ============================================================================
+from rllm.environments.appworld.signal_patch import apply_signal_patch
+from rllm.utils import compute_pass_at_k, save_trajectories
+
+# Apply patch before importing AppWorld
+# The signal module can only be used from the main thread, but in the async engine the executing thread is not the main thread.
+apply_signal_patch(verbose=True)
+# ============================================================================
+
+
+async def main(num_tasks=10, max_turns=40, split="dev"):
+ os.environ["TOKENIZERS_PARALLELISM"] = "true"
+
+ # Check API key
+ if not os.getenv("OPENAI_API_KEY"):
+ print("No OPENAI_API_KEY")
+ return
+
+ n_parallel_agents = 4
+
+ model_name = "gpt-4o-mini"
+ # Use a tokenizer with chat template (only for formatting messages and calculating token counts in the engine)
+ # Qwen2-0.5B is small and fast to download
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+
+ sampling_params = {"temperature": 0.6, "top_p": 0.95, "model": model_name}
+ agent_args = {}
+ env_args = {"max_turns": max_turns}
+
+ # Create engine
+ engine = AgentExecutionEngine(
+ agent_class=AppWorldReactAgent,
+ agent_args=agent_args,
+ env_class=AppWorldEnv,
+ env_args=env_args,
+ engine_name="openai",
+ tokenizer=tokenizer,
+ sampling_params=sampling_params,
+ rollout_engine_args={"base_url": "https://api.openai.com/v1", "api_key": os.getenv("OPENAI_API_KEY")},
+ n_parallel_agents=n_parallel_agents,
+ max_response_length=16384,
+ max_prompt_length=4096,
+ max_steps=max_turns,
+ )
+
+ tasks = load_appworld_official_tasks(split=split, num_tasks=num_tasks)
+
+ if not tasks:
+ print("No tasks loaded, exiting...")
+ return
+
+ print(f"Running evaluation on {len(tasks)} AppWorld tasks...")
+ results = await engine.execute_tasks(tasks)
+
+ # Save trajectories
+ save_trajectories(results, save_dir="./trajectories/appworld", filename="trajectories.pt")
+ compute_pass_at_k(results)
+ # Compute accuracy and show per-task results
+ print("\n" + "=" * 80)
+ print("Task Completion Results")
+ print("=" * 80)
+ n_passed = 0
+ for i, trajectory in enumerate(results, 1):
+ task_id = trajectory.task.get("task_id", f"task_{i}") if isinstance(trajectory.task, dict) else f"task_{i}"
+ reward = trajectory.reward
+ status = "PASSED" if reward >= 1.0 else "FAILED"
+
+ print(f"{i:2d}. {task_id:20s} | Reward: {reward:.2f} | {status}")
+
+ if reward >= 1.0:
+ n_passed += 1
+
+ accuracy = n_passed / num_tasks if num_tasks > 0 else 0.0
+
+ print("=" * 80)
+ print(f"Summary: {n_passed} out of {num_tasks} tasks passed")
+ print(f"Accuracy: {accuracy:.2%} ({n_passed}/{num_tasks})")
+ print("=" * 80 + "\n")
+
+
+def load_appworld_official_tasks(split="dev", num_tasks=10):
+ """
+ Load tasks from the official AppWorld tasks.
+ """
+ try:
+ # lazy load the appworld package
+ from appworld import AppWorld, load_task_ids
+
+ # Use 'dev' split for development/testing
+ # Available splits: 'train', 'dev', 'test_normal', 'test_challenge'
+ task_ids = load_task_ids(split)[:num_tasks] # Get first 10 task IDs
+
+ # Create task dictionaries with task_id
+ # The AppWorldEnv will load the instruction when it initializes
+ tasks = []
+ for task_id in task_ids:
+ # Temporarily create AppWorld instance to get instruction for display
+ try:
+ world = AppWorld(task_id=task_id)
+ instruction = world.task.instruction
+ except Exception:
+ instruction = f"Task {task_id}"
+
+ tasks.append({"task_id": task_id, "instruction": instruction})
+
+ print(f"Loaded {len(tasks)} official AppWorld tasks from 'dev' split")
+
+ for task in tasks:
+ print(f"Task {task['task_id']}: {task['instruction'][:80]}...")
+ return tasks
+ except Exception as e:
+ print(f"Warning: Cannot load AppWorld - {e}")
+ raise e
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Run AppWorld Agent with rLLM", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument("-n", "--num-tasks", type=int, default=10, help="Number of tasks to run (use -1 for all tasks)")
+ parser.add_argument("-t", "--max-turns", type=int, default=40, help="Maximum number of turns per task")
+ parser.add_argument("-s", "--split", type=str, default="dev", choices=["train", "dev", "test_normal", "test_challenge"], help="Which split to use")
+
+ args = parser.parse_args()
+
+ asyncio.run(main(num_tasks=args.num_tasks, max_turns=args.max_turns, split=args.split))
diff --git a/vendor/rllm/examples/countdown/prepare_countdown_data.py b/vendor/rllm/examples/countdown/prepare_countdown_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..68f0a9d143f500a6701f36da820dd2182eaac4c1
--- /dev/null
+++ b/vendor/rllm/examples/countdown/prepare_countdown_data.py
@@ -0,0 +1,96 @@
+import random
+
+from datasets import load_dataset
+
+from rllm.data.dataset import DatasetRegistry
+
+
+def prepare_countdown_data():
+ """
+ Prepare the countdown task dataset from HuggingFace.
+ Take 1024 examples as test set, remaining as training set.
+ Also create stage 2 and stage 3 training sets with 50k examples each.
+ """
+ # Load the countdown dataset
+ dataset = load_dataset("Jiayi-Pan/Countdown-Tasks-3to4", split="train")
+
+ # Split dataset: 1024 examples for test, rest for training
+ test_size = 1024
+ total_size = len(dataset)
+
+ # Create train/test split
+ test_dataset = dataset.select(range(test_size))
+ train_dataset = dataset.select(range(test_size, total_size))
+
+ def preprocess_fn(example, idx):
+ """
+ Convert countdown task format to math problem format.
+ Example: target=98, nums=[44, 19, 35] becomes a math word problem.
+ """
+ target = example["target"]
+ nums = example["nums"]
+
+ # Format as a math problem
+ nums_str = ", ".join(map(str, nums))
+ question = f"Using the numbers {nums_str}, find a way to reach the target number {target}. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your step-by-step calculation and output the final answer within ... , for example (1 + 2) / 3 ."
+
+ return {
+ "question": question,
+ "ground_truth": str(target),
+ "data_source": "countdown",
+ "target": target,
+ "nums": nums,
+ }
+
+ # Apply preprocessing
+ train_dataset = train_dataset.map(preprocess_fn, with_indices=True)
+ test_dataset = test_dataset.map(preprocess_fn, with_indices=True)
+
+ # Create stage 2 and stage 3 training datasets
+ train_size = len(train_dataset)
+ stage_size = 50000
+
+ # Ensure we have enough data for both stages
+ if train_size < 2 * stage_size:
+ print(f"Warning: Training set has only {train_size} examples, but need {2 * stage_size} for both stages")
+ stage_size = min(stage_size, train_size // 2)
+
+ # Shuffle and select indices for stage 2 and stage 3
+ all_indices = list(range(train_size))
+ random.shuffle(all_indices)
+
+ stage2_indices = all_indices[:stage_size]
+ stage3_indices = all_indices[stage_size : 2 * stage_size]
+
+ # Create stage datasets
+ stage2_dataset = train_dataset.select(stage2_indices)
+ stage3_dataset = train_dataset.select(stage3_indices)
+
+ # Register datasets
+ train_dataset = DatasetRegistry.register_dataset("countdown", train_dataset, "train")
+ test_dataset = DatasetRegistry.register_dataset("countdown", test_dataset, "test")
+ stage2_dataset = DatasetRegistry.register_dataset("countdown", stage2_dataset, "stage2_train")
+ stage3_dataset = DatasetRegistry.register_dataset("countdown", stage3_dataset, "stage3_train")
+
+ print(f"Train dataset size: {len(train_dataset)}")
+ print(f"Test dataset size: {len(test_dataset)}")
+ print(f"Stage 2 train dataset size: {len(stage2_dataset)}")
+ print(f"Stage 3 train dataset size: {len(stage3_dataset)}")
+
+ return train_dataset, test_dataset, stage2_dataset, stage3_dataset
+
+
+if __name__ == "__main__":
+ train_dataset, test_dataset, stage2_dataset, stage3_dataset = prepare_countdown_data()
+ print("Train dataset path:", train_dataset.get_data_path())
+ print("Test dataset path:", test_dataset.get_data_path())
+ print("Stage 2 train dataset path:", stage2_dataset.get_data_path())
+ print("Stage 3 train dataset path:", stage3_dataset.get_data_path())
+
+ # Print a sample
+ print("\nSample train example:")
+ print(train_dataset[0])
+ print("\nSample stage 2 train example:")
+ print(stage2_dataset[0])
+ print("\nSample stage 3 train example:")
+ print(stage3_dataset[0])
diff --git a/vendor/rllm/examples/countdown/train_countdown.py b/vendor/rllm/examples/countdown/train_countdown.py
new file mode 100644
index 0000000000000000000000000000000000000000..f272bcbda64013e8719501249ee8608b23d353c0
--- /dev/null
+++ b/vendor/rllm/examples/countdown/train_countdown.py
@@ -0,0 +1,27 @@
+import hydra
+
+from rllm.data.dataset import DatasetRegistry
+from rllm.rewards.countdown_reward import countdown_reward_fn
+from rllm.trainer.agent_trainer import AgentTrainer
+from rllm.workflows.simple_workflow import SimpleWorkflow
+
+
+@hydra.main(config_path="pkg://rllm.trainer.config", config_name="agent_ppo_trainer", version_base=None)
+def main(config):
+ train_dataset = DatasetRegistry.load_dataset("countdown", "train")
+ test_dataset = DatasetRegistry.load_dataset("countdown", "test")
+
+ trainer = AgentTrainer(
+ workflow_class=SimpleWorkflow,
+ workflow_args={
+ "reward_function": countdown_reward_fn,
+ },
+ config=config,
+ train_dataset=train_dataset,
+ val_dataset=test_dataset,
+ )
+ trainer.train()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/vendor/rllm/examples/countdown/train_countdown.sh b/vendor/rllm/examples/countdown/train_countdown.sh
new file mode 100644
index 0000000000000000000000000000000000000000..d941d9c979f032195f74a9aff71e8fa7f6e1060a
--- /dev/null
+++ b/vendor/rllm/examples/countdown/train_countdown.sh
@@ -0,0 +1,64 @@
+set -x
+
+export VLLM_ATTENTION_BACKEND=FLASH_ATTN
+export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:False"
+export VLLM_USE_V1=1
+export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
+export VLLM_ENGINE_ITERATION_TIMEOUT_S=100000000000
+export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+python3 -m examples.countdown.train_countdown \
+ data.train_batch_size=64 \
+ data.max_prompt_length=2048 \
+ data.max_response_length=1024 \
+ actor_rollout_ref.model.path=Qwen/Qwen3-0.6B \
+ actor_rollout_ref.actor.optim.lr=1e-6 \
+ actor_rollout_ref.model.use_remove_padding=True \
+ actor_rollout_ref.actor.loss_agg_mode=seq-mean-token-mean \
+ actor_rollout_ref.actor.use_dynamic_bsz=True \
+ actor_rollout_ref.actor.ppo_max_token_len_per_gpu=32768 \
+ actor_rollout_ref.actor.ppo_mini_batch_size=64 \
+ actor_rollout_ref.actor.use_kl_loss=False \
+ actor_rollout_ref.actor.kl_loss_coef=0.001 \
+ actor_rollout_ref.actor.kl_loss_type=low_var_kl \
+ actor_rollout_ref.actor.entropy_coeff=0.0 \
+ actor_rollout_ref.actor.clip_ratio_low=0.2 \
+ actor_rollout_ref.actor.clip_ratio_high=0.28 \
+ actor_rollout_ref.model.enable_gradient_checkpointing=True \
+ actor_rollout_ref.actor.fsdp_config.param_offload=True \
+ actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
+ actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \
+ actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
+ actor_rollout_ref.rollout.name=vllm \
+ actor_rollout_ref.rollout.mode="async" \
+ actor_rollout_ref.rollout.enforce_eager=False \
+ actor_rollout_ref.rollout.temperature=0.6 \
+ actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
+ actor_rollout_ref.rollout.n=8 \
+ actor_rollout_ref.rollout.val_kwargs.n=1 \
+ actor_rollout_ref.rollout.val_kwargs.temperature=0.6 \
+ actor_rollout_ref.rollout.val_kwargs.top_p=0.95 \
+ actor_rollout_ref.ref.fsdp_config.param_offload=True \
+ algorithm.adv_estimator=grpo \
+ rllm.compact_filtering.enable=False \
+ rllm.compact_filtering.mask_max_prompt_length_exceeded=True \
+ rllm.compact_filtering.mask_max_response_length_exceeded=True \
+ rllm.compact_filtering.mask_max_turns_exceeded=False \
+ rllm.compact_filtering.mask_timeout=True \
+ rllm.rejection_sample.enable=False \
+ rllm.rejection_sample.multiplier=1.0 \
+ rllm.stepwise_advantage.enable=False \
+ rllm.stepwise_advantage.mode=per_step \
+ trainer.critic_warmup=0 \
+ trainer.logger=['console','wandb'] \
+ trainer.project_name='rllm-agent' \
+ trainer.experiment_name='countdown' \
+ trainer.val_before_train=True \
+ trainer.n_gpus_per_node=8 \
+ trainer.nnodes=1 \
+ trainer.save_freq=1000 \
+ trainer.test_freq=10 \
+ trainer.default_hdfs_dir=null \
+ trainer.total_epochs=100 \
+ rllm.workflow.use_workflow=True
+
+pkill -9 -f 'ray::WorkerDict'
\ No newline at end of file
diff --git a/vendor/rllm/examples/deepcoder/README.md b/vendor/rllm/examples/deepcoder/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..d9f80286bbf495e0b3f8287c462d30d6cad035df
--- /dev/null
+++ b/vendor/rllm/examples/deepcoder/README.md
@@ -0,0 +1,88 @@
+# DeepCoder Training Examples
+
+This directory contains examples for training and running DeepCoder, a code reasoning LLM fine-tuned from DeepSeek-R1-Distill-Qwen-14B using distributed reinforcement learning (RL).
+
+Our examples use the following:
+* DeepSeek-R1-Distill-Qwen-14B as the base model
+* agentica-org/DeepCoder-Preview-Dataset (primeintellect, taco and lcbv5 subsets for training; codeforces and lcbv5 subsets for evaluation)
+
+
+
+## Model Hosting
+
+### Option 1: Using vLLM
+
+Start a vLLM server with OpenAI-compatible API:
+
+```bash
+python -m vllm.entrypoints.openai.api_server \
+ --model agentica-org/DeepCoder-14B-Preview \
+ --host 0.0.0.0 \
+ --port 30000 \
+ --dtype bfloat16 \
+ --max-model-len 65536
+```
+
+### Option 2: Using SGLang
+
+```bash
+python -m sglang_router.launch_server \
+ --model-path agentica-org/DeepCoder-14B-Preview \
+ --dp-size 1 \
+ --dtype bfloat16
+# increase dp_size to enable data-parallel processing on multi-GPU
+```
+
+The server should be accessible at `http://localhost:30000/v1`
+
+## Dataset Preparation
+
+Prepare the DeepCoder Preview Dataset:
+
+```bash
+cd examples/deepcoder
+python prepare_deepcoder_data.py
+```
+
+This will:
+- Download the agentica-org/DeepCoder-Preview-Dataset (primeintellect, taco and lcbv5 subsets for the train split; codeforces and lcbv5 subsets for the test split)
+- Register both train/test splits with the RLLM DatasetRegistry
+
+## Running Inference
+
+Once your model server is running and datasets are prepared, you can run inference:
+
+```bash
+cd examples/deepcoder
+python run_deepcoder.py
+```
+
+### Configuration Options
+
+You can modify the inference script parameters:
+
+- `n_parallel_agents`: Number of parallel agents (default: 64)
+- `model_name`: Model to use (default: "agentica-org/DeepCoder-14B-Preview")
+- `base_url`: API server URL (default: "http://localhost:30000/v1")
+- `max_response_length`: Maximum response length (default: 65536)
+- `max_prompt_length`: Maximum prompt length (default: 4096)
+- `temperature`: Sampling temperature (default: 0.6)
+- `top_p`: Top-p sampling (default: 0.95)
+
+The script will:
+1. Load the DeepCoder Preview test dataset
+2. Run parallel and async trajectory collection using the agent execution engine
+3. Evaluate results and report accuracy metrics
+
+## Training
+
+### Basic Training
+
+To train DeepCoder with iterative context lengthening (16K -> 32K -> 64K):
+
+```bash
+bash examples/deepcoder/train_deepcoder_16k.sh
+
+# modify MODEL_PATH to the 16k checkpoint path before running the script.
+bash examples/deepcoder/train_deepcoder_32k.sh
+```
diff --git a/vendor/rllm/examples/deepcoder/prepare_deepcoder_data.py b/vendor/rllm/examples/deepcoder/prepare_deepcoder_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdd634a78af23248ac6aff2b5daeb1d5a3d9883a
--- /dev/null
+++ b/vendor/rllm/examples/deepcoder/prepare_deepcoder_data.py
@@ -0,0 +1,61 @@
+import json
+
+from datasets import concatenate_datasets, load_dataset
+
+from rllm.data.dataset import DatasetRegistry
+from rllm.data.utils import fetch_live_code_bench_system_prompt
+
+
+def prepare_deepcoder_data(train_size: int = None, test_size: int = None):
+ train_dataset = concatenate_datasets([load_dataset("agentica-org/DeepCoder-Preview-Dataset", name="primeintellect", split="train"), load_dataset("agentica-org/DeepCoder-Preview-Dataset", name="taco", split="train"), load_dataset("agentica-org/DeepCoder-Preview-Dataset", name="lcbv5", split="train")])
+ test_dataset = concatenate_datasets([load_dataset("agentica-org/DeepCoder-Preview-Dataset", name="codeforces", split="test"), load_dataset("agentica-org/DeepCoder-Preview-Dataset", name="lcbv5", split="test")])
+
+ def preprocess_fn(example, idx):
+ starter_code = example.get("starter_code", "")
+ question = fetch_live_code_bench_system_prompt(example["problem"], starter_code if starter_code else None)
+
+ tests_raw = example["tests"]
+ # Handle different test formats
+ if isinstance(tests_raw, str):
+ tests = json.loads(tests_raw)
+ else:
+ tests = tests_raw
+ metadata = example.get("metadata", {})
+
+ # Convert TACO format to standard format
+ if isinstance(tests, dict) and "inputs" in tests and "outputs" in tests:
+ normalized_tests = []
+ for input_val, output_val in zip(tests["inputs"], tests["outputs"], strict=False):
+ normalized_tests.append({"input": input_val, "output": output_val, "testtype": "stdin_stdout"})
+ tests = normalized_tests
+
+ # Ensure tests is always a list
+ if not isinstance(tests, list):
+ tests = [tests] if tests else []
+
+ for test in tests:
+ if test.get("testtype") == "functional" and metadata.get("func_name") is not None:
+ test["metadata"] = {"func_name": str(metadata["func_name"])}
+ else:
+ test["metadata"] = {"func_name": None}
+
+ return {"question": question, "ground_truth": json.dumps(tests), "data_source": "livecodebench", "uid": f"deepcoder_{idx}", "index": idx, "starter_code": starter_code, "metadata": json.dumps(metadata)}
+
+ if train_size:
+ train_dataset = train_dataset.select(range(min(train_size, len(train_dataset))))
+ if test_size:
+ test_dataset = test_dataset.select(range(min(test_size, len(test_dataset))))
+
+ train_dataset = train_dataset.map(preprocess_fn, with_indices=True, writer_batch_size=10, num_proc=16)
+ test_dataset = test_dataset.map(preprocess_fn, with_indices=True, writer_batch_size=10, num_proc=16)
+ train_dataset = DatasetRegistry.register_dataset("deepcoder", train_dataset, "train")
+ test_dataset = DatasetRegistry.register_dataset("deepcoder", test_dataset, "test")
+
+ return train_dataset, test_dataset
+
+
+if __name__ == "__main__":
+ train_dataset, test_dataset = prepare_deepcoder_data()
+ print(f" - Train dataset: {len(train_dataset.get_data())} examples")
+ print(f" - Test dataset: {len(test_dataset.get_data())} examples")
+ print(train_dataset.get_data()[0])
diff --git a/vendor/rllm/examples/deepcoder/run_deepcoder.py b/vendor/rllm/examples/deepcoder/run_deepcoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6c20ca8bce627a57c52939ec364b306be9a231b
--- /dev/null
+++ b/vendor/rllm/examples/deepcoder/run_deepcoder.py
@@ -0,0 +1,59 @@
+import asyncio
+import os
+from datetime import datetime
+
+from transformers import AutoTokenizer
+
+from rllm.agents.code_agent import CompetitionCodingAgent
+from rllm.data.dataset import DatasetRegistry
+from rllm.engine.agent_execution_engine import AgentExecutionEngine
+from rllm.environments.base.single_turn_env import SingleTurnEnvironment
+from rllm.rewards.reward_fn import code_reward_fn
+from rllm.utils import save_trajectories
+
+if __name__ == "__main__":
+ os.environ["TOKENIZERS_PARALLELISM"] = "true"
+
+ n_parallel_agents = 64
+
+ model_name = "agentica-org/DeepCoder-14B-Preview"
+
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+ reward_fn = code_reward_fn
+
+ env_args = {
+ "reward_fn": reward_fn,
+ }
+
+ sampling_params = {"temperature": 0.6, "top_p": 0.95, "model": model_name}
+
+ engine = AgentExecutionEngine(
+ agent_class=CompetitionCodingAgent,
+ env_class=SingleTurnEnvironment,
+ agent_args={},
+ env_args=env_args,
+ engine_name="openai",
+ tokenizer=tokenizer,
+ sampling_params=sampling_params,
+ rollout_engine_args={
+ "base_url": "http://localhost:30000/v1",
+ "api_key": "None",
+ },
+ max_response_length=65536,
+ max_prompt_length=4096,
+ n_parallel_agents=n_parallel_agents,
+ )
+
+ test_dataset = DatasetRegistry.load_dataset("deepcoder", "test")
+ if test_dataset is None:
+ print("Dataset not found, preparing dataset...")
+ from prepare_deepcoder_data import prepare_deepcoder_data
+
+ _, test_dataset = prepare_deepcoder_data()
+
+ tasks = test_dataset.get_data()
+
+ results = asyncio.run(engine.execute_tasks(tasks))
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ save_trajectories(results, filename=f"deepcoder_trajectories_{len(tasks)}_{timestamp}.pt")
diff --git a/vendor/rllm/examples/deepcoder/train_deepcoder.py b/vendor/rllm/examples/deepcoder/train_deepcoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..baa90b42537ae860b9b732c34b9022267673d1c8
--- /dev/null
+++ b/vendor/rllm/examples/deepcoder/train_deepcoder.py
@@ -0,0 +1,30 @@
+import hydra
+
+from rllm.agents.code_agent import CompetitionCodingAgent
+from rllm.data.dataset import DatasetRegistry
+from rllm.environments.base.single_turn_env import SingleTurnEnvironment
+from rllm.rewards.reward_fn import code_reward_fn
+from rllm.trainer.agent_trainer import AgentTrainer
+
+
+@hydra.main(config_path="pkg://rllm.trainer.config", config_name="agent_ppo_trainer", version_base=None)
+def main(config):
+ train_dataset = DatasetRegistry.load_dataset("deepcoder", "train")
+ test_dataset = DatasetRegistry.load_dataset("deepcoder", "test")
+
+ env_args = {"reward_fn": code_reward_fn}
+
+ trainer = AgentTrainer(
+ agent_class=CompetitionCodingAgent,
+ agent_args={},
+ env_args=env_args,
+ env_class=SingleTurnEnvironment,
+ config=config,
+ train_dataset=train_dataset,
+ val_dataset=test_dataset,
+ )
+ trainer.train()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/vendor/rllm/examples/deepcoder/train_deepcoder_16k.sh b/vendor/rllm/examples/deepcoder/train_deepcoder_16k.sh
new file mode 100644
index 0000000000000000000000000000000000000000..55baf0f271587abf82d93d34a1b287e0cb02ceed
--- /dev/null
+++ b/vendor/rllm/examples/deepcoder/train_deepcoder_16k.sh
@@ -0,0 +1,71 @@
+set -x
+
+ulimit -n 1048576
+export VLLM_ATTENTION_BACKEND=FLASH_ATTN
+export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:False"
+export VLLM_USE_V1=1
+export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
+export VLLM_ENGINE_ITERATION_TIMEOUT_S=1000000000
+
+# Find the directory where rllm package is located
+RLLM_DIR=$(python3 -c "import rllm; import os; print(os.path.dirname(os.path.dirname(rllm.__file__)))")
+
+# DeepCoder base model - 14B parameter DeepSeek-R1 distilled model
+MODEL_PATH=deepseek-ai/DeepSeek-R1-Distill-Qwen-14B
+
+python3 -m examples.deepcoder.train_deepcoder \
+ algorithm.adv_estimator=grpo \
+ data.train_batch_size=128 \
+ data.val_batch_size=512 \
+ data.max_prompt_length=2048 \
+ data.max_response_length=16384 \
+ actor_rollout_ref.model.path=$MODEL_PATH \
+ actor_rollout_ref.hybrid_engine=True \
+ actor_rollout_ref.actor.optim.lr=1e-6 \
+ actor_rollout_ref.model.use_remove_padding=True \
+ actor_rollout_ref.model.enable_gradient_checkpointing=True \
+ actor_rollout_ref.actor.loss_agg_mode=seq-mean-token-mean \
+ actor_rollout_ref.actor.ppo_mini_batch_size=64 \
+ actor_rollout_ref.actor.ppo_micro_batch_size=16 \
+ actor_rollout_ref.actor.ppo_epochs=1 \
+ actor_rollout_ref.actor.use_dynamic_bsz=True \
+ actor_rollout_ref.actor.ppo_max_token_len_per_gpu=20000 \
+ actor_rollout_ref.actor.use_kl_loss=False \
+ actor_rollout_ref.actor.kl_loss_coef=0 \
+ actor_rollout_ref.actor.kl_loss_type=low_var_kl \
+ actor_rollout_ref.actor.ulysses_sequence_parallel_size=2 \
+ actor_rollout_ref.actor.entropy_coeff=0 \
+ actor_rollout_ref.actor.grad_clip=1.0 \
+ actor_rollout_ref.actor.clip_ratio_low=0.2 \
+ actor_rollout_ref.actor.clip_ratio_high=0.28 \
+ actor_rollout_ref.actor.fsdp_config.param_offload=False \
+ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
+ actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
+ actor_rollout_ref.rollout.name=vllm \
+ actor_rollout_ref.rollout.mode="async" \
+ actor_rollout_ref.rollout.enforce_eager=False \
+ actor_rollout_ref.rollout.temperature=0.6 \
+ actor_rollout_ref.rollout.top_p=0.95 \
+ actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
+ actor_rollout_ref.rollout.n=8 \
+ actor_rollout_ref.rollout.val_kwargs.n=2 \
+ actor_rollout_ref.rollout.val_kwargs.temperature=0.6 \
+ actor_rollout_ref.rollout.val_kwargs.top_p=0.95 \
+ actor_rollout_ref.ref.fsdp_config.param_offload=True \
+ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \
+ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \
+ algorithm.kl_ctrl.kl_coef=0.001 \
+ rllm.mask_truncated_samples=True \
+ trainer.critic_warmup=0 \
+ trainer.logger=['console','wandb'] \
+ trainer.project_name='rllm-deepcoder' \
+ trainer.experiment_name='deepcoder-14b-16k-stage1' \
+ trainer.val_before_train=True \
+ trainer.n_gpus_per_node=8 \
+ trainer.nnodes=1 \
+ trainer.save_freq=10 \
+ trainer.test_freq=10 \
+ trainer.default_hdfs_dir=null \
+ rllm.agent.max_steps=1 \
+ rllm.stepwise_advantage.enable=False \
+ trainer.total_epochs=100
\ No newline at end of file
diff --git a/vendor/rllm/examples/deepcoder/train_deepcoder_32k.sh b/vendor/rllm/examples/deepcoder/train_deepcoder_32k.sh
new file mode 100644
index 0000000000000000000000000000000000000000..2656fc20bbb2fbff7dadfdcb56360db514129085
--- /dev/null
+++ b/vendor/rllm/examples/deepcoder/train_deepcoder_32k.sh
@@ -0,0 +1,71 @@
+set -x
+
+ulimit -n 1048576
+export VLLM_ATTENTION_BACKEND=XFORMERS
+export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:False"
+export VLLM_USE_V1=1
+export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
+export VLLM_ENGINE_ITERATION_TIMEOUT_S=1000000000
+
+# Find the directory where rllm package is located
+RLLM_DIR=$(python3 -c "import rllm; import os; print(os.path.dirname(os.path.dirname(rllm.__file__)))")
+
+# TODO: Update this path to your 16K checkpoint after stage 1 training
+MODEL_PATH=deepseek-ai/DeepSeek-R1-Distill-Qwen-14B
+
+python3 -m examples.deepcoder.train_deepcoder \
+ algorithm.adv_estimator=grpo \
+ data.train_batch_size=128 \
+ data.val_batch_size=512 \
+ data.max_prompt_length=2048 \
+ data.max_response_length=32768 \
+ actor_rollout_ref.model.path=$MODEL_PATH \
+ actor_rollout_ref.hybrid_engine=True \
+ actor_rollout_ref.actor.optim.lr=1e-6 \
+ actor_rollout_ref.model.use_remove_padding=True \
+ actor_rollout_ref.model.enable_gradient_checkpointing=True \
+ actor_rollout_ref.actor.loss_agg_mode=seq-mean-token-mean \
+ actor_rollout_ref.actor.ppo_mini_batch_size=64 \
+ actor_rollout_ref.actor.ppo_micro_batch_size=16 \
+ actor_rollout_ref.actor.ppo_epochs=1 \
+ actor_rollout_ref.actor.use_dynamic_bsz=True \
+ actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \
+ actor_rollout_ref.actor.use_kl_loss=False \
+ actor_rollout_ref.actor.kl_loss_coef=0 \
+ actor_rollout_ref.actor.kl_loss_type=low_var_kl \
+ actor_rollout_ref.actor.ulysses_sequence_parallel_size=2 \
+ actor_rollout_ref.actor.entropy_coeff=0 \
+ actor_rollout_ref.actor.grad_clip=1.0 \
+ actor_rollout_ref.actor.clip_ratio_low=0.2 \
+ actor_rollout_ref.actor.clip_ratio_high=0.28 \
+ actor_rollout_ref.actor.fsdp_config.param_offload=False \
+ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
+ actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
+ actor_rollout_ref.rollout.name=vllm \
+ actor_rollout_ref.rollout.mode="async" \
+ actor_rollout_ref.rollout.enforce_eager=False \
+ actor_rollout_ref.rollout.temperature=0.6 \
+ actor_rollout_ref.rollout.top_p=0.95 \
+ actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
+ actor_rollout_ref.rollout.n=8 \
+ actor_rollout_ref.rollout.val_kwargs.n=2 \
+ actor_rollout_ref.rollout.val_kwargs.temperature=0.6 \
+ actor_rollout_ref.rollout.val_kwargs.top_p=0.95 \
+ actor_rollout_ref.ref.fsdp_config.param_offload=True \
+ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \
+ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \
+ algorithm.kl_ctrl.kl_coef=0.001 \
+ rllm.mask_truncated_samples=False \
+ trainer.critic_warmup=0 \
+ trainer.logger=['console','wandb'] \
+ trainer.project_name='rllm-deepcoder' \
+ trainer.experiment_name='deepcoder-14b-32k-stage2' \
+ trainer.val_before_train=True \
+ trainer.n_gpus_per_node=8 \
+ trainer.nnodes=1 \
+ trainer.save_freq=10 \
+ trainer.test_freq=10 \
+ trainer.default_hdfs_dir=null \
+ rllm.agent.max_steps=1 \
+ rllm.stepwise_advantage.enable=False \
+ trainer.total_epochs=100
\ No newline at end of file
diff --git a/vendor/rllm/examples/deepresearch/.env.example b/vendor/rllm/examples/deepresearch/.env.example
new file mode 100644
index 0000000000000000000000000000000000000000..ecb7d218587a1f7fe08de5be771c78c6ed419175
--- /dev/null
+++ b/vendor/rllm/examples/deepresearch/.env.example
@@ -0,0 +1,28 @@
+# DeepResearch API Configuration
+# Copy this file to .env and fill in your API keys
+
+# OpenAI API (recommended for best performance)
+OPENAI_API_KEY=your_openai_api_key_here
+OPENAI_BASE_URL=https://api.openai.com/v1
+MODEL_NAME=gpt-4
+
+# Alternative: Together AI (cost-effective option)
+# TOGETHER_AI_API_KEY=your_together_ai_key_here
+# TOGETHER_AI_MODEL_NAME=Qwen/Qwen2.5-7B-Instruct-Turbo
+
+# Alternative: Custom OpenAI-compatible endpoint (for vLLM hosting)
+# OPENAI_API_KEY=your_custom_api_key
+# OPENAI_BASE_URL=http://your-vllm-server:8000/v1
+# MODEL_NAME=your-hosted-model-name
+
+# Search API keys for research tools
+# Serper API (required for web search functionality)
+SERPER_KEY_ID=your_serper_api_key_from_serper.dev
+
+# Alternative: Google Custom Search API (if you prefer Google over Serper)
+# GOOGLE_SEARCH_SECRET_KEY=your_google_api_key
+# GOOGLE_SEARCH_ENGINE_ID=your_custom_search_engine_id
+
+# Evaluation settings
+# DEEPRESEARCH_TASK=Custom research question to test
+# GAIA_DATASET_PATH=path/to/gaia.json
\ No newline at end of file
diff --git a/vendor/rllm/examples/deepresearch/README.md b/vendor/rllm/examples/deepresearch/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..384305c40d4d5de0d9a10a81d58944c28f78d6ca
--- /dev/null
+++ b/vendor/rllm/examples/deepresearch/README.md
@@ -0,0 +1,260 @@
+# DeepResearch Integration for rLLM
+
+## Overview
+
+This module integrates Tongyi's DeepResearch ReAct agent into the rLLM framework, enabling evaluation on academic benchmarks like HLE (Humanity's Last Exam). The integration demonstrates how to port external agent architectures into rLLM's workflow system while maintaining compatibility with the training and evaluation infrastructure.
+
+## Architecture
+
+```
+DeepResearch Agent (ReAct with XML-based tool calling)
+ ↓
+DeepResearchWorkflow (rLLM Workflow wrapper)
+ ↓
+AgentWorkflowEngine (Parallel execution)
+ ↓
+Episode/Trajectory (rLLM data format)
+```
+
+### Key Components
+
+- **`deepresearch_agent.py`**: MultiTurnReactAgent implementing Tongyi's ReAct loop with tool calling
+- **`deepresearch_workflow.py`**: Wrapper that converts agent outputs to rLLM Episodes for trajectory tracking
+- **`deepresearch_tools.py`**: Tool implementations (Search, Scholar, Visit, FileParser, PythonInterpreter)
+- **`evaluate_hle.py`**: Evaluation script for HLE (Humanity's Last Exam) benchmark
+
+## Installation
+
+### Prerequisites
+
+```bash
+# Activate rLLM environment
+conda activate rllm
+
+# Install required dependencies
+pip install datasets # For HLE dataset access
+pip install tiktoken # Optional: for better token counting with OpenAI models
+```
+
+### Environment Setup
+
+Create a `.env` file with your API keys:
+
+```bash
+# For model inference (choose one)
+OPENAI_API_KEY=your_openai_key
+TOGETHER_AI_API_KEY=your_together_key
+
+# Optional: For web search tool
+SERPER_API_KEY=your_serper_key # Get free key from serper.dev (note: .env.example names this variable SERPER_KEY_ID; set whichever name the Search tool reads)
+```
+
+## Usage
+
+### Running HLE Evaluation
+
+```bash
+# Evaluate on HLE dataset with default settings
+python evaluate_hle.py --hf-dataset cais/hle --max-samples 10 --parallel-tasks 4
+
+# Use specific model
+python evaluate_hle.py --model gpt-4o --max-samples 5
+
+# Use Together AI for evaluation
+python evaluate_hle.py --model Qwen/Qwen2.5-7B-Instruct-Turbo \
+ --base-url https://api.together.xyz/v1 \
+ --max-samples 20
+
+# Custom output directory
+python evaluate_hle.py --output-dir ./my_results --max-samples 20
+```
+
+### Using DeepResearch Agent Directly
+
+```python
+from rllm.engine.rollout import OpenAIEngine
+from deepresearch_agent import MultiTurnReactAgent
+from deepresearch_tools import get_all_tools
+
+# Setup rollout engine
+engine = OpenAIEngine(
+ model="gpt-4o",
+ api_key="your_key",
+ base_url="https://api.openai.com/v1"
+)
+
+# Create agent with tools
+agent = MultiTurnReactAgent(
+ rollout_engine=engine,
+ tools=get_all_tools()
+)
+
+# Run a research task
+result = await agent.run(
+ question="What is the reduced 12th dimensional Spin bordism of BG2?",
+ answer="Z/2" # Optional ground truth for evaluation
+)
+
+print(f"Prediction: {result['prediction']}")
+print(f"Rounds: {result['rounds']}")
+print(f"Time taken: {result['time_taken']}s")
+```
+
+### Integrating with rLLM Workflows
+
+```python
+from rllm.engine.agent_workflow_engine import AgentWorkflowEngine
+from deepresearch_workflow import DeepResearchWorkflow
+
+# Create workflow engine for parallel execution
+workflow_engine = AgentWorkflowEngine(
+ workflow_cls=DeepResearchWorkflow,
+ workflow_args={
+ "tools": get_all_tools(),
+ "max_prompt_length": 4096,
+ "max_response_length": 2048
+ },
+ rollout_engine=engine,
+ n_parallel_tasks=4 # Run 4 tasks in parallel
+)
+
+# Run evaluation on multiple tasks
+tasks = [
+ {"question": "Question 1", "answer": "Answer 1"},
+ {"question": "Question 2", "answer": "Answer 2"}
+]
+
+episodes = await workflow_engine.execute_tasks(tasks)
+
+# Episodes contain full trajectories for training
+for episode in episodes:
+ print(f"Task: {episode.task}")
+ print(f"Prediction: {episode.metrics.get('prediction')}")
+ print(f"Is correct: {episode.is_correct}")
+```
+
+## Tools
+
+The agent has access to the following research tools:
+
+| Tool | Description | Implementation Status |
+| --------------------- | --------------------------- | ------------------------------------ |
+| **Search** | Web search via Serper API | ✅ Fully implemented (needs API key) |
+| **PythonInterpreter** | Execute Python code safely | ✅ Fully implemented with security |
+| **Scholar** | Academic paper search | ❌ Placeholder only |
+| **Visit** | Visit and analyze web pages | ❌ Placeholder only |
+| **FileParser** | Parse various file formats | ⚠️ Basic text only (no PDF/DOCX) |
+
+### Tool Implementation Notes
+
+- **Search**: Real web search with Serper API integration. Configure API key in `.env` file
+- **PythonInterpreter**: Enhanced security, 50s timeout, supports numpy/pandas when available
+- **Scholar**: Returns placeholder results. Needs integration with arXiv/Google Scholar APIs
+- **Visit**: Returns placeholder content. Needs requests/BeautifulSoup implementation
+- **FileParser**: Only reads text files up to 5000 chars. Original supports PDF/DOCX/media files
+
+## Key Improvements from Original
+
+### 1. Token Counting Fix
+
+- **Problem**: Original used mismatched tokenizers (GPT-2 for GPT-4o) causing incorrect context limits
+- **Solution**: Now uses the token counts reported with each model response (`prompt_length` and `completion_length` on the rollout engine's response object) instead of a local tokenizer estimate
+- **Impact**: No more false "context exceeded" errors at 13k tokens when limit is 128k
+
+### 2. Context Management
+
+- **Problem**: System would incorrectly truncate messages based on wrong token counts
+- **Solution**: Track actual cumulative API token consumption for accurate context management
+- **Impact**: Model can use full context window effectively
+
+### 3. System Prompt Optimization
+
+- **Problem**: Over-constrained prompt requiring specific tags caused unnatural responses
+- **Solution**: Simplified prompt matching original Tongyi design, letting model reason naturally
+- **Impact**: Better convergence, fewer infinite loops
+
+### 4. Parallel Execution
+
+- Leverages AgentWorkflowEngine for concurrent task processing
+- Configurable parallelism (`n_parallel_tasks` parameter)
+- Automatic retry on failures
+
+## Evaluation Results
+
+Evaluation results will be added after running benchmarks. The system is designed to evaluate on HLE and other academic benchmarks.
+
+## Known Issues and Limitations
+
+1. **Tool Placeholders**: Scholar and Visit tools need real implementations for research tasks
+2. **Model-Specific Behavior**:
+ - Some models may not consistently use `<answer>` tags
+ - Tool calling format adherence varies by model
+3. **Long Context Tasks**: Very complex research may still hit token limits
+4. **Judge Accuracy**: LLM judge may not perfectly evaluate complex answers
+
+## Future Improvements
+
+- [ ] Implement real Scholar tool using arXiv/Semantic Scholar APIs
+- [ ] Implement real Visit tool using requests/BeautifulSoup
+- [ ] Add PDF/DOCX parsing to FileParser
+- [ ] Create unified evaluation framework for multiple benchmarks
+- [ ] Add more Tongyi agents (QwenCoder, etc.)
+- [ ] Improve judge accuracy with better prompts
+
+## Project Structure
+
+```
+examples/deepresearch/
+├── deepresearch_agent.py # Core ReAct agent implementation
+├── deepresearch_workflow.py # rLLM workflow wrapper
+├── deepresearch_tools.py # Tool implementations
+├── evaluate_hle.py # HLE evaluation script
+├── react_agent_original.py # Original Tongyi reference
+├── tool_*_original.py # Original tool references
+├── hle_outputs/ # Evaluation results (git ignored)
+└── README.md # This file
+```
+
+## Contributing
+
+To add new tools or improve existing ones:
+
+1. Implement tool in `deepresearch_tools.py` following the pattern:
+
+ ```python
+ class YourTool(DeepResearchTool):
+ async def call(self, **kwargs) -> str:
+ # Your implementation
+ return result_string
+ ```
+
+2. Add to `DEEPRESEARCH_TOOLS` registry
+
+3. Test with evaluation script
+
+4. Submit PR with test results
+
+## Related Work
+
+This integration is part of the rLLM evaluation framework initiative. See also:
+
+- `examples/strands/` - Strands agent integration
+- `rllm/agents/` - Native rLLM agents
+- `rllm/workflows/` - Workflow base classes
+
+## Citation
+
+If you use this integration, please cite:
+
+```bibtex
+@misc{deepresearch2024,
+ title={DeepResearch: Multi-turn Research Agent},
+ author={Alibaba NLP Team},
+ year={2024},
+ url={https://github.com/Alibaba-NLP/DeepResearch}
+}
+```
+
+## License
+
+This integration follows rLLM's license. The original DeepResearch implementation is from Alibaba's Tongyi team.
diff --git a/vendor/rllm/examples/deepresearch/deepresearch_agent.py b/vendor/rllm/examples/deepresearch/deepresearch_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ab74238cd1dbb3e53903ef9b4b4f4621cc2893b
--- /dev/null
+++ b/vendor/rllm/examples/deepresearch/deepresearch_agent.py
@@ -0,0 +1,733 @@
+"""
+DeepResearch Agent - Adapted from Tongyi DeepResearch for rLLM
+
+This is the core ReAct agent that implements DeepResearch's reasoning and tool-calling logic,
+adapted to work with rLLM's OpenAI engine instead of the original server-based approach.
+
+Original: https://github.com/Alibaba-NLP/DeepResearch/blob/main/inference/react_agent.py
+"""
+
+import asyncio
+import json
+import time
+from datetime import datetime
+
+# rLLM imports
+from rllm.engine.rollout import RolloutEngine
+
+# Constants from original DeepResearch
+OBS_START = ""
+OBS_END = "\n "
+MAX_LLM_CALL_PER_RUN = 100
+
+# System prompt adapted from DeepResearch
+DEEPRESEARCH_SYSTEM_PROMPT = """You are a deep research assistant. Your core function is to conduct thorough, multi-source investigations into any topic. You MUST use the provided tools to research and verify information before answering. Do NOT answer directly from memory - always use tools to gather current, accurate information.
+
+IMPORTANT: You are REQUIRED to use at least one tool before providing any answer. Even if you think you know the answer, you must verify it using the appropriate tools. Direct answers without tool use are not acceptable.
+
+When you have gathered sufficient information through tool use and are ready to provide the definitive response, you must enclose the entire final answer within tags.
+
+# Tools
+
+You MUST use one or more of the following tools to research the query:
+
+You are provided with the following tools:
+- Search: for web searches to find current information
+- Scholar: for academic research and paper searches
+- Visit: for visiting and analyzing web pages
+- PythonInterpreter: for running Python code and calculations
+- FileParser: for reading and analyzing files
+
+For each function call, return a json object with function name and arguments within XML tags:
+
+{"name": , "arguments": }
+
+
+For Python code execution, use:
+
+python
+
+# Your Python code here
+print("Hello World")
+
+
+
+Current date: """
+
+
+def today_date():
+ """Get today's date in YYYY-MM-DD format."""
+ return datetime.now().date().strftime("%Y-%m-%d")
+
+
+def build_text_completion_prompt(messages: list[dict], allow_special: bool = True) -> str:
+ """
+ Build text completion prompt from messages list.
+ Adapted from qwen_agent.utils.utils.build_text_completion_prompt
+
+ Args:
+ messages: List of message dictionaries with 'role' and 'content' keys
+ allow_special: Whether to allow special tokens (for compatibility)
+
+ Returns:
+ Formatted prompt string
+ """
+ im_start = "<|im_start|>"
+ im_end = "<|im_end|>"
+
+ prompt_parts = []
+
+ # Handle system message
+ if messages and messages[0]["role"] == "system":
+ sys_content = messages[0]["content"]
+ prompt_parts.append(f"{im_start}system\n{sys_content}{im_end}")
+ messages = messages[1:]
+
+ # Ensure chat completes with assistant
+ if messages and messages[-1]["role"] != "assistant":
+ messages = messages + [{"role": "assistant", "content": ""}]
+
+ # Format each message
+ for msg in messages:
+ role = msg["role"]
+ content = msg["content"]
+ prompt_parts.append(f"{im_start}{role}\n{content}{im_end}")
+
+ return "\n".join(prompt_parts)
+
+
+class MultiTurnReactAgent:
+ """
+ Multi-turn ReAct Agent adapted from Tongyi DeepResearch.
+
+ This agent implements the core reasoning loop with tool calling capabilities,
+ using rLLM's OpenAI engine for model inference.
+ """
+
+ def __init__(
+ self,
+ rollout_engine: RolloutEngine,
+ tools: dict = None,
+ system_prompt: str | None = None,
+ use_native_function_calling: bool = False,
+ **kwargs,
+ ):
+ """
+ Initialize the ReAct agent.
+
+ Args:
+ rollout_engine: rLLM OpenAI engine for model inference
+ tools: Dictionary of available tools {tool_name: tool_instance}
+ system_prompt: Optional custom system prompt
+ use_native_function_calling: Whether to use OpenAI native function calling (supports o3)
+ """
+ self.rollout_engine = rollout_engine
+ self.tools = tools or {}
+ self.system_prompt = system_prompt
+ self.use_native_function_calling = use_native_function_calling
+
+ # Convert tools to OpenAI format if using native function calling
+ if use_native_function_calling and self.tools:
+ self.openai_tools = [tool.json for tool in self.tools.values()]
+ else:
+ self.openai_tools = None
+
+ # Configuration from original DeepResearch
+ self.max_llm_calls = MAX_LLM_CALL_PER_RUN
+ self.max_time = 150 * 60 # 150 minutes timeout
+
+ # Smart context management using actual API consumption
+ self.total_prompt_tokens = 0
+ self.total_completion_tokens = 0
+
+ # Auto-detect context limit based on model capabilities
+ # This ensures we don't hit limits too early for capable models
+ self.max_context_tokens = self._get_model_context_limit(rollout_engine)
+
+ def _get_model_context_limit(self, rollout_engine) -> int:
+ """
+ Auto-detect context limit based on model capabilities.
+ Uses LiteLLM's model info when available, falls back to conservative estimates.
+ Returns 90% of max to leave safety headroom.
+ """
+ model_name = rollout_engine.model
+
+ # Method 1: Try LiteLLM's get_model_info (most accurate)
+ try:
+ import litellm
+
+ model_info = litellm.get_model_info(model_name)
+ if model_info and "max_input_tokens" in model_info:
+ max_tokens = model_info["max_input_tokens"]
+ conservative_limit = int(max_tokens * 0.90) # Use 90% for safety
+ if not hasattr(MultiTurnReactAgent, "_context_limit_reported"):
+ print(f" 📏 Detected context window: {max_tokens:,} tokens (using 90% = {conservative_limit:,})")
+ MultiTurnReactAgent._context_limit_reported = True
+ return conservative_limit
+ except Exception:
+ # LiteLLM might not have info for all models, that's ok
+ pass
+
+ # Method 2: Try tiktoken to get model family info
+ try:
+ import tiktoken
+
+ # tiktoken.encoding_for_model will throw if model unknown
+ encoding = tiktoken.encoding_for_model(model_name)
+ # Map known encodings to context limits
+ encoding_limits = {
+ "cl100k_base": 128 * 1024, # GPT-4, GPT-3.5-turbo-16k
+ "p50k_base": 4 * 1024, # text-davinci-002/003
+ "r50k_base": 4 * 1024, # GPT-3 base models
+ }
+ if encoding.name in encoding_limits:
+ max_tokens = encoding_limits[encoding.name]
+ conservative_limit = int(max_tokens * 0.90)
+ if not hasattr(MultiTurnReactAgent, "_context_limit_reported"):
+ print(f" 📏 Inferred context from encoding '{encoding.name}': {conservative_limit:,} tokens")
+ MultiTurnReactAgent._context_limit_reported = True
+ return conservative_limit
+ except Exception:
+ pass
+
+ # Method 3: Pattern matching fallback (least accurate but works)
+ model_lower = model_name.lower()
+ fallback_limits = {
+ # OpenAI reasoning models
+ ("o3", "o1"): 128 * 1024,
+ # GPT-4 family
+ ("gpt-4o", "gpt-4-turbo"): 128 * 1024,
+ ("gpt-4-32k",): 32 * 1024,
+ ("gpt-4",): 8 * 1024,
+ # Claude family
+ ("claude-3-5", "claude-3.5"): 200 * 1024,
+ ("claude-3",): 200 * 1024,
+ ("claude-2",): 100 * 1024,
+ # Gemini family
+ ("gemini-1.5", "gemini-2"): 1000 * 1024,
+ ("gemini",): 32 * 1024,
+ # Qwen
+ ("qwen2", "qwen-2"): 128 * 1024,
+ ("qwen",): 32 * 1024,
+ }
+
+ for patterns, max_tokens in fallback_limits.items():
+ if any(pattern in model_lower for pattern in patterns):
+ conservative_limit = int(max_tokens * 0.90)
+ if not hasattr(MultiTurnReactAgent, "_context_limit_reported"):
+ print(f" 📏 Pattern-matched context limit: {conservative_limit:,} tokens (90% of {max_tokens:,})")
+ MultiTurnReactAgent._context_limit_reported = True
+ return conservative_limit
+
+ # Method 4: Ultimate fallback
+ default_limit = 100 * 1024
+ if not hasattr(MultiTurnReactAgent, "_context_limit_reported"):
+ print(f" ⚠️ Unknown model '{model_name}', using conservative default: {default_limit:,} tokens")
+ MultiTurnReactAgent._context_limit_reported = True
+ return default_limit
+
+ def sanity_check_output(self, content: str) -> bool:
+ """Check if the model output contains the expected thinking structure."""
+ return "" in content and " " in content
+
+ async def call_server(self, messages: list[dict], max_tries: int = 10):
+ """
+ Call rLLM OpenAI engine with hybrid mode support.
+
+ Supports both:
+ - Native function calling (for o3, gpt-4-turbo)
+ - ReAct text format (for gpt-4o, Claude)
+
+ Args:
+ messages: List of chat completion messages
+ max_tries: Maximum number of retry attempts
+
+ Returns:
+ ModelOutput with text and tool_calls
+ """
+ for attempt in range(max_tries):
+ try:
+ # Base parameters
+ api_params = {"messages": messages}
+
+ # Model-specific parameter configuration
+ model_name = self.rollout_engine.model.lower()
+
+ if "o3" in model_name or "o1" in model_name:
+ # O3/O1: Very limited parameter support
+ api_params["max_completion_tokens"] = 4096
+ elif "gpt-4" in model_name:
+ # GPT-4: Full parameter support
+ api_params.update(
+ {
+ "stop": ["\n", ""],
+ "temperature": 0.6,
+ "top_p": 0.95,
+ "max_tokens": 4096,
+ "presence_penalty": 1.1,
+ }
+ )
+ elif "qwen" in model_name:
+ # Qwen models
+ api_params.update(
+ {
+ "temperature": 0.6,
+ "top_p": 0.95,
+ "max_tokens": 4096,
+ }
+ )
+ else:
+ # Fallback: Conservative params
+ api_params.update(
+ {
+ "temperature": 0.6,
+ "max_tokens": 4096,
+ }
+ )
+
+ # Add tools parameter for native function calling
+ if self.use_native_function_calling and self.openai_tools:
+ api_params["tools"] = self.openai_tools
+ api_params["tool_choice"] = "auto"
+
+ # Call rLLM OpenAI Engine
+ response = await self.rollout_engine.get_model_response(**api_params)
+
+ # Track actual token consumption from API
+ if hasattr(response, "prompt_length") and hasattr(response, "completion_length"):
+ self.total_prompt_tokens += response.prompt_length
+ self.total_completion_tokens += response.completion_length
+
+ # Return full ModelOutput (contains both text and tool_calls)
+ return response
+
+ except Exception as e:
+ print(f"Error: Attempt {attempt + 1} failed: {e}")
+ if attempt < max_tries - 1:
+ # Exponential backoff
+ sleep_time = 2**attempt
+ print(f"Waiting {sleep_time} seconds before retry...")
+ await asyncio.sleep(sleep_time)
+
+ raise Exception(f"Failed to get response after {max_tries} attempts")
+
+ def get_total_tokens_used(self) -> int:
+ """
+ Get total tokens consumed so far from actual API usage.
+ This is much more accurate than any tokenizer estimation.
+
+ Returns:
+ Total tokens used (prompt + completion)
+ """
+ return self.total_prompt_tokens + self.total_completion_tokens
+
+ async def _run(self, question: str, answer: str = None, images: list = None, **kwargs) -> dict:
+ """
+ Main reasoning loop adapted from original DeepResearch.
+
+ This is the core ReAct implementation that handles:
+ - Multi-turn conversation
+ - Tool calling and execution
+ - Context length management
+ - Termination conditions
+
+ Args:
+ question: The research question to answer
+ answer: Ground truth answer (for evaluation)
+ images: List of image data URLs (base64 encoded)
+
+ Returns:
+ Dictionary with results including messages, prediction, and termination reason
+ """
+ start_time = time.time()
+
+ # Setup system prompt with current date
+ system_prompt = (self.system_prompt or DEEPRESEARCH_SYSTEM_PROMPT) + today_date()
+
+ # Construct initial user message (multimodal if images present)
+ if images:
+ # Build multimodal message with images
+ user_content = [{"type": "text", "text": question}]
+ for image_data in images:
+ user_content.append({"type": "image_url", "image_url": {"url": image_data}})
+ user_message = {"role": "user", "content": user_content}
+ else:
+ # Plain text message
+ user_message = {"role": "user", "content": question}
+
+ messages = [
+ {"role": "system", "content": system_prompt},
+ user_message,
+ ]
+
+ num_llm_calls_available = self.max_llm_calls
+ round = 0
+ termination = None
+ prediction = ""
+
+ # Truncate question for display
+ q_display = str(question).replace("\n", " ").strip()
+ if len(q_display) > 200:
+ q_display = q_display[:200] + "..."
+ print(f"🔍 Starting DeepResearch for question: {q_display}")
+
+ while num_llm_calls_available > 0:
+ # Check time limit (150 minutes)
+ if time.time() - start_time > self.max_time:
+ prediction = "No answer found after 2h30mins"
+ termination = "No answer found after 2h30mins"
+ result = {
+ "question": question,
+ "answer": answer,
+ "messages": messages,
+ "prediction": prediction,
+ "termination": termination,
+ }
+ return result
+
+ round += 1
+ num_llm_calls_available -= 1
+
+ # Get model response (ModelOutput with text and tool_calls)
+ response = await self.call_server(messages)
+
+ # Extract text content (may be None for pure function calling)
+ content = response.text if hasattr(response, "text") and response.text else ""
+
+ # Debug: Print raw model response to see format
+ if round == 1:
+ print(f"[DEBUG] Raw model response (first 500 chars): {content[:500]}")
+ if hasattr(response, "tool_calls") and response.tool_calls:
+ print(f"[DEBUG] Native tool_calls detected: {len(response.tool_calls)} call(s)")
+
+ # Print concise round info with truncation
+ MAX_PRINT_LENGTH = 200
+
+ # Simple truncation for all prints
+ def truncate(text, max_len=MAX_PRINT_LENGTH):
+ text = str(text).replace("\n", " ").strip()
+ # Special handling for base64 images
+ if "data:image" in text or ";base64," in text:
+ # Find the base64 part and truncate it
+ if "base64," in text:
+ parts = text.split("base64,", 1)
+ return parts[0] + "base64,[truncated]"
+ return "[base64 image data]"
+ if len(text) > max_len:
+ return text[:max_len] + "..."
+ return text
+
+ # Print round info based on content type
+ if "" in content:
+ # Extract tool name for display
+ if "python" in content.lower() and "" in content:
+ print(f"Round {round}: 🐍 Executing Python code")
+ elif '"name":' in content:
+ try:
+ import json5
+
+ tool_text = content.split("")[1].split(" ")[0]
+ tool_text = tool_text[:1000] # Limit for parsing
+ tool_data = json5.loads(tool_text)
+ tool_name = tool_data.get("name", "Unknown")
+ if "arguments" in tool_data:
+ args_str = truncate(str(tool_data["arguments"]), 100)
+ print(f"Round {round}: 🔧 Calling {tool_name} with args: {args_str}")
+ else:
+ print(f"Round {round}: 🔧 Calling {tool_name}")
+ except Exception:
+ print(f"Round {round}: 🔧 Tool call")
+ else:
+ print(f"Round {round}: 🔧 Tool call")
+ elif "" in content:
+ # Final answer
+ answer_preview = content.split("")[1].split(" ")[0]
+ print(f"Round {round}: ✅ Final answer: {truncate(answer_preview, 100)}")
+ else:
+ # Show internal reasoning if available, otherwise show content
+ if hasattr(response, "reasoning") and response.reasoning:
+ reasoning_preview = truncate(response.reasoning, 300)
+ print(f"Round {round}: 💭 [Internal] {reasoning_preview}")
+ elif content:
+ print(f"Round {round}: 💭 Reasoning: {truncate(content)}")
+
+ # Clean up content if it contains tool_response
+ if "" in content:
+ pos = content.find("")
+ content = content[:pos]
+
+ # HYBRID MODE: Handle both native tool_calls and ReAct text format
+
+ # Priority 1: Check for native function calling (o3, gpt-4-turbo)
+ if hasattr(response, "tool_calls") and response.tool_calls:
+ # Native function calling path - build ALL messages first, then append atomically
+ tool_calls_formatted = []
+ tool_responses = []
+
+ for tool_call in response.tool_calls:
+ try:
+ # Follow strands.py tolerant extraction of function/name/arguments
+ try:
+ function = tool_call.get("function", {}) if isinstance(tool_call, dict) else getattr(tool_call, "function", {})
+ except Exception:
+ function = tool_call
+
+ tool_id = tool_call.get("id") if isinstance(tool_call, dict) else getattr(tool_call, "id", "unknown")
+ tool_name = function.get("name") if isinstance(function, dict) else getattr(function, "name", "")
+ arguments_raw = function.get("arguments") if isinstance(function, dict) else getattr(function, "arguments", "{}")
+
+ # Parse arguments if provided as JSON string
+ tool_args = json.loads(arguments_raw) if isinstance(arguments_raw, str) else arguments_raw
+
+ # Print tool call with arguments (for consistency with ReAct format)
+ def truncate(text, max_len=100):
+ text = str(text).replace("\n", " ").strip()
+ if len(text) > max_len:
+ return text[:max_len] + "..."
+ return text
+
+ args_str = truncate(str(tool_args), 100)
+ print(f"Round {round}: 🔧 [Native] Calling {tool_name} with args: {args_str}")
+
+ # Execute tool
+ result = await self.custom_call_tool(tool_name, tool_args)
+
+ # Collect tool call and response (don't append yet)
+ tool_calls_formatted.append(
+ {
+ "id": tool_id,
+ "type": "function",
+ "function": {
+ "name": tool_name,
+ "arguments": arguments_raw,
+ },
+ }
+ )
+ tool_responses.append({"role": "tool", "tool_call_id": tool_id, "content": result})
+
+ except Exception as e:
+ print(f"Error processing native tool call: {e}")
+ # On error, append error message and skip this tool call
+ messages.append({"role": "assistant", "content": content.strip()})
+ messages.append({"role": "user", "content": f"Tool call error: {e}"})
+ continue
+
+ # Only append to messages if we have successful tool calls
+ if tool_calls_formatted:
+ # Add assistant message with ALL tool calls at once
+ messages.append(
+ {
+ "role": "assistant",
+ "content": content or "", # May be empty for pure function calling
+ "tool_calls": tool_calls_formatted,
+ }
+ )
+ # Add all tool responses
+ messages.extend(tool_responses)
+
+ # Priority 2: Check for ReAct text format (gpt-4o, Claude)
+ elif "" in content and " " in content:
+ # ReAct text format path
+ messages.append({"role": "assistant", "content": content.strip()})
+
+ tool_call_text = content.split("")[1].split(" ")[0]
+ try:
+ # Special handling for Python code (match original logic)
+ if "python" in tool_call_text.lower():
+ try:
+ # Extract code from the original content (not just tool_call_text)
+ code_raw = content.split("")[1].split(" ")[0].split("")[1].split("")[0].strip()
+ result = await self.execute_python(code_raw)
+ except Exception:
+ result = "[Python Interpreter Error]: Formatting error."
+ else:
+ # Parse JSON tool call
+ tool_call = json5.loads(tool_call_text)
+ tool_name = tool_call.get("name", "")
+ tool_args = tool_call.get("arguments", {})
+ result = await self.custom_call_tool(tool_name, tool_args)
+
+ except Exception:
+ result = 'Error: Tool call is not a valid JSON. Tool call must contain a valid "name" and "arguments" field.'
+
+ # Add tool response in ReAct format
+ tool_response = f"\n{result}\n "
+ messages.append({"role": "user", "content": tool_response})
+
+ # Priority 3: No tool call, just reasoning or answer
+ else:
+ messages.append({"role": "assistant", "content": content.strip()})
+
+ # Check for final answer AFTER processing tools
+ # This allows o3 to execute tools even when it includes answer in same message
+ if "" in content and " " in content:
+ prediction = content.split("")[1].split(" ")[0].strip()
+ termination = "answer"
+ break
+
+ # Check if we've exceeded call limit
+ if num_llm_calls_available <= 0 and "" not in content:
+ # Handle both message formats
+ if isinstance(messages[-1], dict) and "content" in messages[-1]:
+ messages[-1]["content"] = "Sorry, the number of llm calls exceeds the limit."
+
+ # Handle context length limit using actual API consumption
+ total_tokens_used = self.get_total_tokens_used()
+
+ if total_tokens_used > self.max_context_tokens:
+ # Instead of replacing the last message, add a clear instruction
+ final_instruction = {
+ "role": "user",
+ "content": "You have reached the maximum context length. Based on all the information above, please provide your best answer now in the format: your final thinking \nyour answer ",
+ }
+
+ # Truncate conversation history to make room for final answer
+ # Keep system prompt, original question, and recent context
+ if len(messages) > 4: # system + user + at least 2 exchanges
+ # Keep first 2 messages (system + original question) and last 2 meaningful exchanges
+ truncated_messages = messages[:2] # system + original question
+ recent_messages = messages[-4:] # last 4 messages for context
+ truncated_messages.extend(recent_messages)
+ messages = truncated_messages
+
+ messages.append(final_instruction)
+
+ # Note: After truncation, we'll let the next API call handle any remaining limits
+ print(f"Round {round + 1}: ⚠️ Context limit reached, requesting final answer")
+
+ response = await self.call_server(messages)
+ content = response.text if hasattr(response, "text") and response.text else ""
+ messages.append({"role": "assistant", "content": content.strip()})
+
+ if "" in content and " " in content:
+ prediction = content.split("")[1].split(" ")[0].strip()
+ termination = "answer generated due to token limit"
+ else:
+ prediction = content.strip()
+ termination = "response generated due to token limit (no answer format)"
+
+ result = {
+ "question": question,
+ "answer": answer,
+ "messages": messages,
+ "prediction": prediction,
+ "termination": termination,
+ }
+ return result
+
+ # Final validation logic from original Tongyi implementation
+ # Handle both native function calling and ReAct text format
+ last_message_content = messages[-1].get("content", "") if isinstance(messages[-1], dict) else ""
+ if last_message_content and "" in last_message_content:
+ prediction = last_message_content.split("")[1].split(" ")[0]
+ termination = "answer"
+ else:
+ prediction = "No answer found."
+ termination = "answer not found"
+ if num_llm_calls_available == 0:
+ termination = "exceed available llm calls"
+
+ # Final result
+ result = {
+ "question": question,
+ "answer": answer,
+ "messages": messages,
+ "prediction": prediction,
+ "termination": termination,
+ "rounds": round,
+ "time_taken": time.time() - start_time,
+ }
+
+ print("\n🏁 DeepResearch completed:")
+ print(f" Rounds: {round}")
+ print(f" Time: {result['time_taken']:.1f}s")
+ print(f" Termination: {termination}")
+ # Truncate prediction for display
+ pred_display = str(prediction).replace("\n", " ").strip()
+ if len(pred_display) > 200:
+ pred_display = pred_display[:200] + "..."
+ print(f" Prediction: {pred_display}")
+
+ return result
+
+ async def custom_call_tool(self, tool_name: str, tool_args: dict, **kwargs) -> str:
+ """
+ Execute tool calls with the available tools.
+
+ Args:
+ tool_name: Name of the tool to call
+ tool_args: Arguments to pass to the tool
+
+ Returns:
+ Tool execution result as string
+ """
+ if tool_name in self.tools:
+ try:
+ # Call the tool
+ if hasattr(self.tools[tool_name], "call"):
+ # Async tool
+ if asyncio.iscoroutinefunction(self.tools[tool_name].call):
+ result = await self.tools[tool_name].call(**tool_args)
+ else:
+ result = self.tools[tool_name].call(**tool_args)
+ elif callable(self.tools[tool_name]):
+ # Direct callable
+ result = self.tools[tool_name](**tool_args)
+ else:
+ result = f"Tool {tool_name} is not callable"
+
+ return str(result)
+
+ except Exception as e:
+ return f"Error calling tool {tool_name}: {e}"
+ else:
+ available_tools = list(self.tools.keys())
+ return f"Tool {tool_name} not found. Available tools: {available_tools}"
+
+ async def execute_python(self, code: str) -> str:
+ """
+ Execute Python code using the PythonInterpreter tool.
+
+ Args:
+ code: Python code to execute
+
+ Returns:
+ Execution result as string
+ """
+ if "PythonInterpreter" in self.tools:
+ try:
+ # Use the PythonInterpreter tool
+ tool = self.tools["PythonInterpreter"]
+ if hasattr(tool, "call"):
+ if asyncio.iscoroutinefunction(tool.call):
+ result = await tool.call(code=code)
+ else:
+ result = tool.call(code=code)
+ return str(result)
+ else:
+ return "PythonInterpreter tool is not callable"
+ except Exception as e:
+ return f"Python execution error: {e}"
+ else:
+ return "PythonInterpreter tool not available"
+
+ def reset(self):
+ """Reset the agent state (for compatibility with rLLM workflow)."""
+ # Reset token counters for each new task
+ self.total_prompt_tokens = 0
+ self.total_completion_tokens = 0
+
+ async def run(self, question: str, answer: str = None, **kwargs) -> dict:
+ """
+ Public interface for running the agent.
+
+ Args:
+ question: Research question to answer
+ answer: Ground truth answer (optional, for evaluation)
+
+ Returns:
+ Result dictionary
+ """
+ # Reset token counters for each new run
+ self.reset()
+ return await self._run(question, answer, **kwargs)
diff --git a/vendor/rllm/examples/deepresearch/deepresearch_tools.py b/vendor/rllm/examples/deepresearch/deepresearch_tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b7b61de7eb2d1640a7de5c8c3a3e59d538ab401
--- /dev/null
+++ b/vendor/rllm/examples/deepresearch/deepresearch_tools.py
@@ -0,0 +1,750 @@
+"""
+DeepResearch Tools - Production-ready implementations
+
+This module provides tool implementations for the DeepResearch agent, with real
+functionality ported from Tongyi's original implementations where possible.
+
+Now supports both:
+- ReAct text format (for gpt-4o, Claude, etc.)
+- OpenAI native function calling (for o3, o3-mini, etc.)
+"""
+
+import http.client
+import json
+import os
+from abc import ABC, abstractmethod
+
+from rllm.tools.tool_base import Tool as RLLMTool
+
+
class DeepResearchTool(RLLMTool, ABC):
    """
    Common base for every DeepResearch tool.

    Subclasses rLLM's Tool so the tool can be exposed through OpenAI native
    function calling, while the ``call`` coroutine keeps it usable from the
    ReAct text format.
    """

    def __init__(self, name: str, description: str, parameters: dict | None = None):
        """
        Build the OpenAI function-calling schema and initialise the rLLM Tool.

        Args:
            name: Tool name
            description: Tool description
            parameters: OpenAI-style parameter schema (optional); an empty
                object schema is used when omitted
        """
        schema = parameters or {"type": "object", "properties": {}, "required": []}
        function_spec = {"name": name, "description": description, "parameters": schema}

        # _json has to be assigned BEFORE super().__init__ because the
        # parent's __init__ may access self.json.
        self._json = {"type": "function", "function": function_spec}

        super().__init__(name=name, description=description)

    @abstractmethod
    async def call(self, **kwargs) -> str:
        """Execute the tool with given arguments."""

    async def async_forward(self, **kwargs):
        """rLLM Tool interface: run ``call`` and wrap its result or error in a ToolOutput."""
        from rllm.tools.tool_base import ToolOutput

        try:
            produced = await self.call(**kwargs)
            return ToolOutput(name=self.name, output=produced)
        except Exception as exc:
            # Failures are reported inside the ToolOutput rather than raised.
            return ToolOutput(name=self.name, error=f"{type(exc).__name__} - {str(exc)}")
+
+
class SearchTool(DeepResearchTool):
    """
    Web search tool using Serper API (ported from Tongyi).

    Uses Serper when SERPER_API_KEY is set, otherwise Google Custom Search when
    GOOGLE_SEARCH_SECRET_KEY and GOOGLE_SEARCH_ENGINE_ID are set, otherwise
    returns setup instructions.
    """

    # Socket timeout (seconds) for each Serper request, so a stalled
    # connection cannot hang the agent indefinitely.
    SERPER_TIMEOUT_SECONDS = 30
    # How many times a single query is attempted before giving up.
    SERPER_MAX_ATTEMPTS = 5

    def __init__(self):
        super().__init__(
            name="Search",
            description="Performs web searches using Google via Serper API",
            parameters={
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "The search query string",
                    }
                },
                "required": ["query"],
            },
        )

    def contains_chinese(self, text: str) -> bool:
        """Check if text contains Chinese characters."""
        return any("\u4e00" <= char <= "\u9fff" for char in text)

    @staticmethod
    def _as_query_list(query) -> list:
        """Normalise a single query string, a list of queries, or None into a list."""
        if isinstance(query, str):
            return [query]
        return list(query or [])

    @staticmethod
    def _join_results(all_results: list) -> str:
        """Join per-query results; an empty list yields a message instead of an IndexError."""
        if not all_results:
            return "No search queries provided"
        # join() on a one-element list returns that element unchanged.
        return "\n=======\n".join(all_results)

    def _google_search_fallback(self, query: str | list) -> str:
        """
        Use Google Custom Search API as fallback.

        Args:
            query: Search query string or list of queries

        Returns:
            Formatted results, or an error message if anything fails
        """
        try:
            import requests

            google_key = os.getenv("GOOGLE_SEARCH_SECRET_KEY")
            engine_id = os.getenv("GOOGLE_SEARCH_ENGINE_ID")

            all_results = []

            for q in self._as_query_list(query):
                params = {"key": google_key, "cx": engine_id, "q": q, "num": 10}

                response = requests.get(
                    "https://customsearch.googleapis.com/customsearch/v1",
                    params=params,
                    timeout=5,
                )

                if response.status_code == 200:
                    data = response.json()
                    items = data.get("items", [])

                    web_snippets = []
                    for idx, item in enumerate(items[:10], 1):
                        title = item.get("title", "")
                        link = item.get("link", "")
                        snippet = item.get("snippet", "")
                        entry = f"{idx}. [{title}]({link})\n {snippet}"
                        web_snippets.append(entry)

                    result = f"Google search for '{q}' found {len(web_snippets)} results:\n\n" + "\n\n".join(web_snippets)
                    all_results.append(result)
                else:
                    all_results.append(f"Google search error for '{q}': {response.status_code}")

            return self._join_results(all_results)

        except Exception as e:
            return f"Google search fallback error: {e}"

    def _query_serper(self, api_key: str, q: str):
        """
        POST one query to Serper and return the decoded JSON response.

        Each attempt uses a fresh connection that is always closed afterwards.
        Returns None when every attempt fails at the transport level; JSON
        decoding errors propagate to the caller.
        """
        # Localize for Chinese queries
        if self.contains_chinese(q):
            payload = json.dumps({"q": q, "location": "China", "gl": "cn", "hl": "zh-cn"})
        else:
            payload = json.dumps({"q": q, "location": "United States", "gl": "us", "hl": "en"})

        headers = {"X-API-KEY": api_key, "Content-Type": "application/json"}

        for _ in range(self.SERPER_MAX_ATTEMPTS):
            # A connection that failed mid-request cannot be reused reliably,
            # so every attempt opens its own.
            conn = http.client.HTTPSConnection("google.serper.dev", timeout=self.SERPER_TIMEOUT_SECONDS)
            try:
                conn.request("POST", "/search", payload, headers)
                raw = conn.getresponse().read()
            except Exception:
                continue
            finally:
                conn.close()
            return json.loads(raw.decode("utf-8"))

        return None

    async def call(self, query: str | list, **kwargs) -> str:
        """
        Search the web using Serper API or Google Custom Search.

        Args:
            query: Search query string or list of queries

        Returns:
            Formatted search results; multiple queries are separated by a
            "=======" line. Failures are reported in the returned text
            rather than raised.
        """
        api_key = os.getenv("SERPER_API_KEY")

        # Try Google Custom Search as fallback if no Serper key
        if not api_key:
            google_key = os.getenv("GOOGLE_SEARCH_SECRET_KEY")
            google_engine_id = os.getenv("GOOGLE_SEARCH_ENGINE_ID")

            if google_key and google_engine_id:
                return self._google_search_fallback(query)

            return f"""[Search - API Key Required]

To enable real web search, use one of these options:

Option 1 - Serper (Recommended, simpler):
1. Get a free API key from https://serper.dev (2500 searches/month free)
2. Add to .env: SERPER_API_KEY=your_key_here

Option 2 - Google Custom Search:
1. Set up at https://developers.google.com/custom-search
2. Add to .env:
GOOGLE_SEARCH_SECRET_KEY=your_key
GOOGLE_SEARCH_ENGINE_ID=your_engine_id

Placeholder results for '{query}'..."""

        all_results = []

        for q in self._as_query_list(query):
            try:
                results = self._query_serper(api_key, q)

                # All attempts failed: report once and move on to the next query.
                if results is None:
                    all_results.append(f"Google search timeout for '{q}'")
                    continue

                if "organic" not in results:
                    all_results.append(f"No results found for '{q}'")
                    continue

                web_snippets = []
                for idx, page in enumerate(results.get("organic", [])[:10], 1):
                    date_published = f"\nDate: {page['date']}" if "date" in page else ""
                    source = f"\nSource: {page['source']}" if "source" in page else ""
                    snippet = f"\n{page['snippet']}" if "snippet" in page else ""

                    entry = f"{idx}. [{page.get('title', 'Untitled')}]({page.get('link', '')}){date_published}{source}{snippet}"
                    web_snippets.append(entry)

                content = f"Google search for '{q}' found {len(web_snippets)} results:\n\n" + "\n\n".join(web_snippets)
                all_results.append(content)

            except Exception as e:
                all_results.append(f"Search error for '{q}': {e}")

        return self._join_results(all_results)
+
+
class ScholarTool(DeepResearchTool):
    """Google Scholar search using Serper API (ported from Tongyi)."""

    # Socket timeout (seconds) so a stalled Serper connection cannot hang the agent.
    REQUEST_TIMEOUT_SECONDS = 30

    def __init__(self):
        super().__init__(
            name="Scholar",
            description="Search Google Scholar for academic papers",
            parameters={
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "The academic search query",
                    }
                },
                "required": ["query"],
            },
        )

    async def call(self, query: str | list, **kwargs) -> str:
        """
        Search Google Scholar using Serper API.

        Args:
            query: Search query string or list of queries

        Returns:
            Academic search results; multiple queries are separated by a
            "=======" line. Failures are reported in the returned text
            rather than raised.
        """
        api_key = os.getenv("SERPER_API_KEY")
        if not api_key:
            return """[Scholar - API Key Required]

To enable Google Scholar search, configure SERPER_API_KEY in your .env file."""

        queries = [query] if isinstance(query, str) else list(query or [])
        if not queries:
            # Previously an empty list crashed on all_results[0].
            return "No scholar queries provided"

        all_results = []

        for q in queries:
            try:
                payload = json.dumps({"q": q, "type": "scholar", "num": 10})
                headers = {"X-API-KEY": api_key, "Content-Type": "application/json"}

                conn = http.client.HTTPSConnection("google.serper.dev", timeout=self.REQUEST_TIMEOUT_SECONDS)
                try:
                    conn.request("POST", "/scholar", payload, headers)
                    data = conn.getresponse().read()
                finally:
                    # Always release the socket, even when the request fails.
                    conn.close()

                results = json.loads(data.decode("utf-8"))

                if "organic" not in results:
                    all_results.append(f"No scholar results for '{q}'")
                    continue

                papers = []
                for idx, paper in enumerate(results.get("organic", [])[:10], 1):
                    title = paper.get("title", "Untitled")
                    link = paper.get("link", "")
                    snippet = paper.get("snippet", "")
                    # NOTE(review): Serper appears to name this field "publicationInfo";
                    # both spellings are accepted - confirm against the API.
                    publication = paper.get("publication") or paper.get("publicationInfo", "")
                    year = paper.get("year", "")

                    # "citedBy" may be a bare count or a {"value": n} object; calling
                    # .get() on a bare count would raise and discard the whole query.
                    cited_by = paper.get("citedBy", 0)
                    if isinstance(cited_by, dict):
                        cited_by = cited_by.get("value", 0)

                    entry = f"{idx}. [{title}]({link})"
                    if publication:
                        entry += f"\n Publication: {publication}"
                    if year:
                        entry += f" ({year})"
                    if cited_by:
                        entry += f"\n Cited by: {cited_by}"
                    if snippet:
                        entry += f"\n {snippet}"

                    papers.append(entry)

                result_text = f"Google Scholar search for '{q}':\n\n" + "\n\n".join(papers)
                all_results.append(result_text)

            except Exception as e:
                all_results.append(f"Scholar search error for '{q}': {e}")

        return "\n=======\n".join(all_results)
+
+
class VisitTool(DeepResearchTool):
    """Web page visiting with content extraction."""

    def __init__(self):
        super().__init__(
            name="Visit",
            description="Visit and extract content from web pages",
            parameters={
                "type": "object",
                "properties": {
                    "url": {"type": "string", "description": "The URL to visit"},
                    "goal": {
                        "type": "string",
                        "description": "Optional goal for the visit",
                    },
                },
                "required": ["url"],
            },
        )

    async def call(self, url: str | list, goal: str = "", **kwargs) -> str:
        """
        Visit web pages and extract content.

        Args:
            url: URL string or list of URLs (only the first 5 are visited)
            goal: Optional goal for the visit; echoed in the output

        Returns:
            Extracted webpage content, one section per URL separated by a
            "=======" line. Per-URL failures are reported in the text.
        """
        try:
            import requests
            from bs4 import BeautifulSoup
        except ImportError:
            return """[Visit Tool - Dependencies Required]

To enable webpage visiting:
pip install requests beautifulsoup4

Then the tool will fetch and parse webpage content."""

        import re
        from urllib.parse import urlparse

        # None is treated as "no URLs" rather than crashing on slicing.
        urls = [url] if isinstance(url, str) else list(url or [])
        all_results = []

        # NOTE(review): URLs originate from model output and are fetched as-is;
        # consider an allow-list if this runs where internal hosts are reachable.
        for target_url in urls[:5]:  # Limit to 5 URLs
            try:
                # Validate and normalize URL
                parsed = urlparse(target_url)
                if not parsed.scheme:
                    target_url = f"https://{target_url}"

                # Fetch webpage
                headers = {"User-Agent": "Mozilla/5.0 (compatible; DeepResearch/1.0)"}
                response = requests.get(target_url, headers=headers, timeout=10)
                response.raise_for_status()

                # Parse HTML
                soup = BeautifulSoup(response.text, "html.parser")

                # Remove unwanted elements
                for element in soup(["script", "style", "nav", "footer", "header", "aside"]):
                    element.decompose()

                # Extract title. Tag.string is None for an empty <title> or one
                # with nested markup, so use get_text() and fall back explicitly.
                title = "No Title"
                if soup.title is not None:
                    title = soup.title.get_text(strip=True) or "No Title"

                # Extract main content: first matching container wins
                content = ""
                for selector in ["main", "article", ".content", "#content", ".post"]:
                    element = soup.select_one(selector)
                    if element:
                        content = element.get_text(separator="\n", strip=True)
                        break

                if not content:
                    body = soup.find("body")
                    if body:
                        content = body.get_text(separator="\n", strip=True)

                # Clean up text
                content = re.sub(r"\n{3,}", "\n\n", content)
                content = re.sub(r" {2,}", " ", content)

                # Limit length
                if len(content) > 5000:
                    content = content[:5000] + "\n[Content truncated...]"

                # Format result
                result = f"[Webpage: {target_url}]\nTitle: {title}"
                if goal:
                    result += f"\nGoal: {goal}"
                result += f"\n\nContent:\n{content}"

                all_results.append(result)

            except Exception as e:
                all_results.append(f"[Error visiting {target_url}]: {e}")

        return "\n\n=======\n\n".join(all_results)
+
+
class FileParserTool(DeepResearchTool):
    """Enhanced file parsing for multiple formats."""

    # Maximum number of characters of file content included per file.
    MAX_CONTENT_CHARS = 10000

    # Extensions read directly as UTF-8 text.
    TEXT_EXTENSIONS = (".txt", ".md", ".log", ".py", ".js", ".java", ".cpp", ".c", ".h")

    def __init__(self):
        super().__init__(
            name="FileParser",
            description="Parse files: TXT, JSON, CSV, PDF, DOCX, etc.",
            parameters={
                "type": "object",
                "properties": {
                    "files": {
                        "type": "string",
                        "description": "File path or list of file paths to parse",
                    }
                },
                "required": ["files"],
            },
        )

    @classmethod
    def _read_text_prefix(cls, file_path: str) -> str:
        """
        Read at most MAX_CONTENT_CHARS + 1 characters of a text file.

        The extra character lets the caller detect that the file is longer
        than the limit without loading the whole file into memory.
        """
        with open(file_path, encoding="utf-8", errors="ignore") as f:
            return f.read(cls.MAX_CONTENT_CHARS + 1)

    async def call(self, files: str | list, **kwargs) -> str:
        """
        Parse files and extract content.

        Args:
            files: File path string or list of paths (only the first 10 are parsed)

        Returns:
            Extracted file content, one section per file separated by a
            "=======" line. Missing files and parse failures are reported in
            the text rather than raised.
        """
        import csv
        from pathlib import Path

        # None is treated as "no files" rather than crashing on slicing.
        file_paths = [files] if isinstance(files, str) else list(files or [])
        all_results = []

        for file_path in file_paths[:10]:  # Limit to 10 files
            if not os.path.exists(file_path):
                all_results.append(f"Error: File not found at {file_path}")
                continue

            try:
                file_ext = Path(file_path).suffix.lower()
                file_name = os.path.basename(file_path)
                file_size = os.path.getsize(file_path)

                content = ""

                # Text files
                if file_ext in self.TEXT_EXTENSIONS:
                    content = self._read_text_prefix(file_path)

                # JSON files (must be loaded whole to be re-serialised)
                elif file_ext == ".json":
                    with open(file_path, encoding="utf-8") as f:
                        data = json.load(f)
                    content = json.dumps(data, indent=2, ensure_ascii=False)

                # CSV files: first 100 rows only
                elif file_ext == ".csv":
                    rows = []
                    with open(file_path, encoding="utf-8", errors="ignore") as f:
                        reader = csv.reader(f)
                        for i, row in enumerate(reader):
                            if i >= 100:
                                rows.append("[... truncated ...]")
                                break
                            rows.append(", ".join(row))
                    content = "\n".join(rows)

                # PDF files: first 10 pages only
                elif file_ext == ".pdf":
                    try:
                        import PyPDF2

                        with open(file_path, "rb") as f:
                            pdf_reader = PyPDF2.PdfReader(f)
                            pages = []
                            for i in range(min(len(pdf_reader.pages), 10)):
                                page = pdf_reader.pages[i]
                                # Guard against a None result so "None" is never printed.
                                pages.append(f"Page {i + 1}:\n{page.extract_text() or ''}")
                            content = "\n\n".join(pages)
                    except ImportError:
                        content = "[PDF parsing requires: pip install PyPDF2]"

                # Word documents: first 100 paragraphs only
                elif file_ext in [".docx", ".doc"]:
                    try:
                        from docx import Document

                        doc = Document(file_path)
                        paragraphs = []
                        for i, para in enumerate(doc.paragraphs):
                            if i >= 100:
                                paragraphs.append("[... truncated ...]")
                                break
                            if para.text.strip():
                                paragraphs.append(para.text)
                        content = "\n\n".join(paragraphs)
                    except ImportError:
                        content = "[DOCX parsing requires: pip install python-docx]"

                # Default: try as text
                else:
                    try:
                        content = self._read_text_prefix(file_path)
                    except Exception:
                        content = f"[Cannot parse file type: {file_ext}]"

                # Limit content
                if len(content) > self.MAX_CONTENT_CHARS:
                    content = content[: self.MAX_CONTENT_CHARS] + "\n[Content truncated...]"

                result = f"[File: {file_name}]\nType: {file_ext}\nSize: {file_size:,} bytes\n\nContent:\n{content}"
                all_results.append(result)

            except Exception as e:
                all_results.append(f"Error parsing {file_path}: {e}")

        return "\n\n=======\n\n".join(all_results)
+
+
class PythonInterpreterTool(DeepResearchTool):
    """
    Restricted Python code execution (from existing implementation).

    SECURITY NOTE: the substring deny-list and reduced builtins below are a
    best-effort filter, not a sandbox. Do not rely on them to contain
    untrusted code; run this tool inside an isolated process or container
    if the code source is not trusted.
    """

    def __init__(self):
        super().__init__(
            name="PythonInterpreter",
            description="Execute Python code for calculations and analysis",
            parameters={
                "type": "object",
                "properties": {"code": {"type": "string", "description": "Python code to execute"}},
                "required": ["code"],
            },
        )
        self.timeout = 50

    async def call(self, code: str, timeout: int = None, **kwargs) -> str:
        """
        Execute Python code with restricted builtins and a timeout.

        Args:
            code: Python source to execute
            timeout: Seconds to wait for completion; defaults to self.timeout

        Returns:
            One of: "[Security Error] ...", "[Error]\\n<stderr>",
            "[Output]\\n<stdout>", "[Variables]\\n<dict>",
            "[Success] Code executed (no output)" or "[Timeout] ...".
        """
        timeout = timeout or self.timeout

        # Security checks - check for dangerous imports/operations
        dangerous_patterns = [
            "import os",
            "import subprocess",
            "import sys",
            "from os import",
            "from subprocess import",
            "from sys import",
            "exec(",
            "eval(",
            "compile(",
            "open(",
            "file(",
        ]

        code_lower = code.lower()
        for pattern in dangerous_patterns:
            if pattern in code_lower:
                return f"[Security Error] '{pattern}' not allowed for safety reasons"

        import io
        import sys
        from concurrent.futures import ThreadPoolExecutor, TimeoutError

        # Modules pre-loaded into the execution namespace
        allowed_modules = {
            "math": __import__("math"),
            "datetime": __import__("datetime"),
            "json": __import__("json"),
            "random": __import__("random"),
            "re": __import__("re"),
            "collections": __import__("collections"),
            "itertools": __import__("itertools"),
            "statistics": __import__("statistics"),
        }

        # Add numpy/pandas if available
        try:
            import numpy as np

            allowed_modules["numpy"] = np
            allowed_modules["np"] = np
        except ImportError:
            pass

        try:
            import pandas as pd

            allowed_modules["pandas"] = pd
            allowed_modules["pd"] = pd
        except ImportError:
            pass

        # Restricted builtins with safe import capability
        def safe_import(name, *args, **kwargs):
            """Allow importing only safe modules."""
            safe_modules = [
                "math",
                "datetime",
                "json",
                "random",
                "re",
                "collections",
                "itertools",
                "statistics",
                "numpy",
                "pandas",
                "scipy",
                "scipy.linalg",  # Add scipy submodules
                "scipy.optimize",
                "scipy.signal",
                "scipy.special",
                "matplotlib",
                "matplotlib.pyplot",
            ]
            # Check if the module or its parent is allowed
            if name in safe_modules or any(name.startswith(m + ".") for m in safe_modules):
                return __import__(name, *args, **kwargs)
            else:
                raise ImportError(f"Module '{name}' is not allowed for safety reasons")

        restricted_builtins = {
            "abs": abs,
            "all": all,
            "any": any,
            "bin": bin,
            "bool": bool,
            "chr": chr,
            "dict": dict,
            "enumerate": enumerate,
            "filter": filter,
            "float": float,
            "hex": hex,
            "int": int,
            "len": len,
            "list": list,
            "map": map,
            "max": max,
            "min": min,
            "oct": oct,
            "ord": ord,
            "pow": pow,
            "print": print,
            "range": range,
            "reversed": reversed,
            "round": round,
            "set": set,
            "slice": slice,
            "sorted": sorted,
            "str": str,
            "sum": sum,
            "tuple": tuple,
            "type": type,
            "zip": zip,
            "__import__": safe_import,  # Allow safe imports
            # Add exception classes for proper error handling
            "Exception": Exception,
            "ImportError": ImportError,
            "ValueError": ValueError,
            "TypeError": TypeError,
            "KeyError": KeyError,
            "IndexError": IndexError,
            "AttributeError": AttributeError,
        }

        # One dict serves as both globals and locals. With two separate dicts,
        # exec() treats the code like a class body, so functions defined by the
        # snippet cannot see each other or top-level variables (recursion and
        # helper functions fail with NameError).
        namespace = {"__builtins__": restricted_builtins}
        namespace.update(allowed_modules)
        preset_names = set(namespace)

        # Capture output
        old_stdout = sys.stdout
        old_stderr = sys.stderr
        stdout_buffer = io.StringIO()
        stderr_buffer = io.StringIO()

        def execute_with_timeout():
            try:
                # NOTE: this redirect is process-wide, so concurrent calls
                # can interleave their captured output.
                sys.stdout = stdout_buffer
                sys.stderr = stderr_buffer
                exec(code, namespace)
                return True
            except Exception as e:
                stderr_buffer.write(f"Execution error: {e}")
                return False
            finally:
                sys.stdout = old_stdout
                sys.stderr = old_stderr

        # Execute with timeout. The executor is NOT used as a context manager:
        # its __exit__ joins the worker thread, which would make the timeout
        # branch block until the (possibly never-ending) code finishes.
        executor = ThreadPoolExecutor(max_workers=1)
        try:
            future = executor.submit(execute_with_timeout)
            try:
                future.result(timeout=timeout)
            except TimeoutError:
                # The worker cannot be killed and may keep running; give the
                # real streams back so the rest of the process is unaffected.
                sys.stdout = old_stdout
                sys.stderr = old_stderr
                return f"[Timeout] Execution exceeded {timeout}s"
        finally:
            executor.shutdown(wait=False)

        stdout_content = stdout_buffer.getvalue()
        stderr_content = stderr_buffer.getvalue()

        if stderr_content:
            return f"[Error]\n{stderr_content}"
        if stdout_content:
            return f"[Output]\n{stdout_content.rstrip()}"

        # No output: report the names the snippet defined, excluding private
        # names and everything that was pre-populated.
        meaningful_vars = {k: v for k, v in namespace.items() if not k.startswith("_") and k not in preset_names}
        if meaningful_vars:
            return f"[Variables]\n{meaningful_vars}"
        return "[Success] Code executed (no output)"
+
+
+# Tool registry
# Tool registry: one shared instance per tool, keyed by the tool's name.
DEEPRESEARCH_TOOLS = dict(
    Search=SearchTool(),
    Scholar=ScholarTool(),
    Visit=VisitTool(),
    FileParser=FileParserTool(),
    PythonInterpreter=PythonInterpreterTool(),
)
+
+
def get_tool(name: str) -> DeepResearchTool | None:
    """
    Get a tool by name.

    Args:
        name: Registry key, e.g. "Search" or "PythonInterpreter"

    Returns:
        The shared tool instance from DEEPRESEARCH_TOOLS, or None when no
        tool is registered under that name.
    """
    return DEEPRESEARCH_TOOLS.get(name)
+
+
def get_all_tools() -> dict[str, DeepResearchTool]:
    """Return a shallow copy of the registry, so callers can add or remove entries without touching DEEPRESEARCH_TOOLS."""
    return dict(DEEPRESEARCH_TOOLS)
diff --git a/vendor/rllm/examples/deepresearch/deepresearch_workflow.py b/vendor/rllm/examples/deepresearch/deepresearch_workflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..18d557892ca155f404213ff48d93366ca2de000d
--- /dev/null
+++ b/vendor/rllm/examples/deepresearch/deepresearch_workflow.py
@@ -0,0 +1,271 @@
+"""
+DeepResearch Workflow for rLLM
+
+This workflow integrates the DeepResearch agent with rLLM's AgentWorkflowEngine,
+enabling parallel execution and trajectory tracking while maintaining DeepResearch's
+core reasoning capabilities.
+"""
+
+from deepresearch_agent import MultiTurnReactAgent
+
+from rllm.agents.agent import Action, Episode, Step, Trajectory
+from rllm.engine.rollout import RolloutEngine
+from rllm.workflows.workflow import TerminationReason, Workflow
+
+
class DeepResearchWorkflow(Workflow):
    """
    Workflow that wraps the DeepResearch MultiTurnReactAgent for use with AgentWorkflowEngine.

    This workflow:
    1. Creates a DeepResearch agent instance
    2. Executes the research task using the agent's ReAct loop
    3. Converts the results to rLLM Episode format for trajectory tracking
    """

    # Markers of the ReAct text format used to classify messages.
    # NOTE(review): restored from the DeepResearch format; confirm they match
    # the tags emitted by the agent's prompt.
    TOOL_CALL_OPEN = "<tool_call>"
    TOOL_CALL_CLOSE = "</tool_call>"
    ANSWER_OPEN = "<answer>"
    ANSWER_CLOSE = "</answer>"
    TOOL_RESPONSE_OPEN = "<tool_response>"

    def __init__(
        self,
        rollout_engine: RolloutEngine,
        executor,
        tools: dict = None,
        system_prompt: str = None,
        **kwargs,
    ):
        """
        Initialize the DeepResearch workflow.

        Args:
            rollout_engine: rLLM rollout engine for model inference
            executor: Thread pool executor for async operations
            tools: Dictionary of available tools for research tasks
            system_prompt: Custom system prompt (optional, uses default if None)
            **kwargs: Additional arguments passed to parent Workflow
        """
        super().__init__(rollout_engine, executor, **kwargs)

        self.tools = tools or {}
        self.system_prompt = system_prompt

        # Auto-detect if we should use native function calling
        # O3 models require native function calling, other models use XML format
        model_name = rollout_engine.model.lower()
        use_native_fc = "o3" in model_name or "o1" in model_name

        # Create the DeepResearch agent
        self.agent = MultiTurnReactAgent(
            rollout_engine=rollout_engine,
            tools=self.tools,
            system_prompt=self.system_prompt,
            use_native_function_calling=use_native_fc,
        )

        # Note: We don't register the agent since DeepResearch handles its own trajectory

    async def run(self, task: dict, uid: str, **kwargs) -> Episode:
        """
        Execute the DeepResearch workflow on a single task.

        Args:
            task: Task dictionary containing:
                - question: The research question to answer ("query" is accepted as a fallback key)
                - answer: Ground truth answer (optional, for evaluation)
                - Any other task metadata
            uid: Unique identifier for this episode

        Returns:
            Episode object with trajectory and results. On failure an Episode
            with no trajectories and the error text in ``metrics`` is returned
            instead of raising.
        """
        # Reset workflow state for this task
        self.reset(task=task, uid=uid)

        # Extract question and answer from task
        question = task.get("question", task.get("query", "No question provided"))
        answer = task.get("answer", "")

        print(f"🚀 Starting DeepResearch workflow for task {uid}")
        print(f" Question: {question}")

        try:
            # Run the DeepResearch agent
            result = await self.agent.run(question=question, answer=answer, **kwargs)

            # Convert the result to rLLM Episode format
            episode = self._convert_to_episode(result, task, uid)

            print(f"✅ DeepResearch workflow completed for task {uid}")
            print(f" Prediction: {result.get('prediction', 'No prediction')}")

            return episode

        except Exception as e:
            print(f"❌ DeepResearch workflow failed for task {uid}: {e}")

            # Create a failed episode
            episode = Episode()
            episode.id = uid
            episode.task = task
            episode.termination_reason = TerminationReason.UNKNOWN
            episode.is_correct = False
            episode.trajectories = []
            episode.metrics = {"error": str(e)}
            return episode

    def _convert_to_episode(self, result: dict, task: dict, uid: str) -> Episode:
        """
        Convert DeepResearch result to rLLM Episode format.

        Args:
            result: Result dictionary from DeepResearch agent
            task: Original task dictionary
            uid: Episode unique identifier

        Returns:
            Episode object with trajectory
        """
        # Create trajectory from the conversation messages
        trajectory = Trajectory(task=task.get("question", ""))

        # Convert conversation to steps: one Step per assistant message
        messages = result.get("messages", [])

        for i, message in enumerate(messages):
            if message["role"] != "assistant":
                continue

            # Chat context is every message up to and including this one
            # (the slice is already a new list).
            step = Step(
                chat_completions=messages[: i + 1],
                model_response=message["content"],
                action=self._extract_action_from_response(message["content"]),
                observation=self._get_next_observation(messages, i),
                reward=0.0,  # Will be computed later if needed
            )

            trajectory.steps.append(step)

        # Determine if the answer is correct (if ground truth available)
        prediction = result.get("prediction", "")
        ground_truth = task.get("answer", "")
        is_correct = self._evaluate_answer(prediction, ground_truth) if ground_truth else False

        # Map termination reason
        termination_reason = self._map_termination_reason(result.get("termination", "unknown"))

        # Create episode
        episode = Episode()
        episode.id = uid
        episode.task = task
        episode.termination_reason = termination_reason
        episode.is_correct = is_correct
        episode.trajectories = [("deepresearch_agent", trajectory)]
        episode.metrics = {
            "rounds": result.get("rounds", 0),
            "time_taken": result.get("time_taken", 0),
            "prediction": prediction,
            "ground_truth": ground_truth,
        }

        return episode

    def _extract_action_from_response(self, response: str) -> Action:
        """
        Extract action information from model response.

        Args:
            response: Model response text (None is treated as empty text)

        Returns:
            Action whose payload has type "tool_call", "final_answer" or "reasoning"
        """
        # Assistant messages may carry no text content; treat that as empty
        # text rather than failing on the membership tests below.
        text = response or ""

        # Check for tool calls
        if self.TOOL_CALL_OPEN in text and self.TOOL_CALL_CLOSE in text:
            tool_call_text = text.split(self.TOOL_CALL_OPEN)[1].split(self.TOOL_CALL_CLOSE)[0]
            return Action(action={"type": "tool_call", "tool_call": tool_call_text.strip()})

        # Check for final answer
        if self.ANSWER_OPEN in text and self.ANSWER_CLOSE in text:
            answer = text.split(self.ANSWER_OPEN)[1].split(self.ANSWER_CLOSE)[0].strip()
            return Action(action={"type": "final_answer", "answer": answer})

        # Just thinking/reasoning
        return Action(action={"type": "reasoning", "content": response})

    def _get_next_observation(self, messages: list, current_index: int) -> str:
        """
        Get the observation that follows the current assistant message.

        Args:
            messages: List of all messages
            current_index: Index of current assistant message

        Returns:
            Content of the next message when it is a user message containing
            a tool response, otherwise an empty string
        """
        if current_index + 1 < len(messages):
            next_msg = messages[current_index + 1]
            content = next_msg.get("content")
            if next_msg["role"] == "user" and isinstance(content, str) and self.TOOL_RESPONSE_OPEN in content:
                return content

        return ""

    def _evaluate_answer(self, prediction: str, ground_truth: str) -> bool:
        """
        Simple answer evaluation (can be enhanced with specific metrics).

        Args:
            prediction: Model's predicted answer
            ground_truth: Correct answer

        Returns:
            True if both are non-empty and equal after stripping whitespace
            and lower-casing, False otherwise
        """
        if not prediction or not ground_truth:
            return False

        # Simple string matching (can be enhanced with fuzzy matching, etc.)
        return prediction.strip().lower() == ground_truth.strip().lower()

    def _map_termination_reason(self, termination: str) -> TerminationReason:
        """
        Map DeepResearch termination reasons to rLLM TerminationReason enum.

        Args:
            termination: DeepResearch termination string

        Returns:
            Mapped TerminationReason (UNKNOWN for unrecognised strings)
        """
        mapping = {
            "answer": TerminationReason.ENV_DONE,
            "timeout": TerminationReason.TIMEOUT,
            "max_rounds_reached": TerminationReason.MAX_TURNS_EXCEEDED,
            "token_limit_no_answer": TerminationReason.MAX_RESPONSE_LENGTH_EXCEEDED,
            "answer_token_limit": TerminationReason.MAX_RESPONSE_LENGTH_EXCEEDED,
        }

        return mapping.get(termination, TerminationReason.UNKNOWN)

    def reset(self, task: dict = None, uid: str = None):
        """
        Reset the workflow for a new task.

        Intentionally a no-op: the parent reset is skipped because no agents
        are registered, and the DeepResearch agent resets itself in run().

        Args:
            task: New task dictionary (unused)
            uid: New unique identifier (unused)
        """
        pass

    def is_multithread_safe(self) -> bool:
        """
        Indicate whether this workflow is safe for multithreaded execution.

        Returns:
            True, as each workflow instance manages its own state
        """
        return True
diff --git a/vendor/rllm/examples/deepresearch/evaluate_hle.py b/vendor/rllm/examples/deepresearch/evaluate_hle.py
new file mode 100644
index 0000000000000000000000000000000000000000..62f3fde5c787a62c6fefb5d7c3643655661e55d9
--- /dev/null
+++ b/vendor/rllm/examples/deepresearch/evaluate_hle.py
@@ -0,0 +1,517 @@
+"""
+Humanity's Last Exam (HLE) Evaluation for DeepResearch + rLLM
+
+Adapted from original DeepResearch HLE evaluation to work with rLLM's
+DeepResearch integration and AgentWorkflowEngine.
+
+Original: https://github.com/Alibaba-NLP/DeepResearch/blob/main/evaluation/evaluate_hle_official.py
+"""
+
+import argparse
+import asyncio
+import json
+import os
+import statistics
+from datetime import datetime
+from typing import Any
+
+from datasets import load_dataset
+from deepresearch_tools import get_all_tools
+from deepresearch_workflow import DeepResearchWorkflow
+from dotenv import find_dotenv, load_dotenv
+
+from rllm.engine.agent_workflow_engine import AgentWorkflowEngine
+from rllm.engine.rollout import OpenAIEngine
+
+
class HLEJudge:
    """
    LLM-as-judge for HLE responses.

    Formats a binary yes/no grading prompt, sends it to the judge engine and
    parses the ``correct: yes|no`` verdict out of the reply.
    """

    def __init__(self, judge_engine: "OpenAIEngine"):
        """
        Args:
            judge_engine: Engine used to query the judge model. It must expose
                a ``model`` name and an async ``get_model_response`` method.
        """
        self.judge_engine = judge_engine
        # Binary yes/no judge prompt aligned with Tongyi DeepResearch
        self.judge_prompt = """You are an impartial judge evaluating the correctness of an AI assistant's answer.

[Question]
{question}

[Correct Answer]
{reference_answer}

[Assistant's Answer]
{assistant_answer}

Task: Determine if the assistant's answer is correct by comparing it to the correct answer.

Instructions:
1. Extract the final answer from the assistant's response
2. Compare it with the correct answer
3. Provide your reasoning
4. Answer with "yes" if correct, "no" if incorrect

Output format:
correct: [yes/no]
reasoning: [your explanation]"""

    @staticmethod
    def _parse_verdict(judgment_text: str) -> bool:
        """
        Extract the binary verdict from the judge's reply.

        Only the token that directly follows a standalone ``correct:`` label
        counts. This avoids three false positives of a plain substring test:
        ``incorrect: ...`` (label is part of another word), a "yes" that
        merely appears later on the verdict line, and an echoed
        ``correct: [yes/no]`` template.

        Args:
            judgment_text: Raw text returned by the judge model.

        Returns:
            True only when the first well-formed verdict is "yes".
        """
        import re

        verdict = re.search(
            # not preceded by a letter -> rejects "incorrect:"
            # [^\w\n]* tolerates markdown/brackets around the colon, same line only
            # (?!\s*/) rejects the echoed "yes/no" template
            r"(?<![A-Za-z])correct[^\w\n]*:[^\w\n]*(yes|no)\b(?!\s*/)",
            judgment_text,
            re.IGNORECASE,
        )
        return verdict is not None and verdict.group(1).lower() == "yes"

    async def judge_response(self, question: str, reference_answer: str, assistant_answer: str) -> dict[str, Any]:
        """
        Judge a single response.

        Args:
            question: Original question
            reference_answer: Ground truth answer
            assistant_answer: Model's prediction

        Returns:
            Dictionary with ``judgment`` (raw judge text, or an error message)
            and ``is_correct`` (bool). Any exception raised while judging is
            reported on stdout and yields ``is_correct=False`` instead of
            propagating.
        """
        try:
            prompt = self.judge_prompt.format(
                question=question,
                reference_answer=reference_answer,
                assistant_answer=assistant_answer,
            )

            messages = [{"role": "user", "content": prompt}]

            # Use appropriate token parameter based on model
            model_name = self.judge_engine.model.lower()
            if "o3" in model_name or "o1" in model_name:
                response = await self.judge_engine.get_model_response(messages=messages, max_completion_tokens=1000)
            else:
                response = await self.judge_engine.get_model_response(messages=messages, temperature=0.1, max_tokens=1000)

            judgment_text = response.text if hasattr(response, "text") else str(response)

            return {
                "judgment": judgment_text,
                "is_correct": self._parse_verdict(judgment_text),
            }

        except Exception as e:
            print(f"Judge error: {e}")
            return {"judgment": f"Judge error: {e}", "is_correct": False}
+
+
# Keys checked, in priority order, for the question text and the reference answer.
_QUESTION_KEYS = ("question", "prompt", "input")
_ANSWER_KEYS = ("answer", "target", "output", "correct_answer")
# Free-text fields that are appended to the question as extra context when present.
_TEXT_CONTEXT_KEYS = (
    "context",
    "contexts",
    "extra",
    "additional_context",
    "background",
    "passage",
    "passages",
)


def _first_present(example: dict[str, Any], keys) -> Any:
    """Return the value of the first key of ``keys`` found in ``example``, or "" if none is present."""
    for key in keys:
        if key in example:
            return example[key]
    return ""


def _as_list(value: Any):
    """Return ``value`` unchanged if it is a list/tuple, otherwise wrap it in a one-element list."""
    return value if isinstance(value, list | tuple) else [value]


def _bullet_section(title: str, items, limit: int = 10) -> str:
    """Render ``title`` followed by at most ``limit`` items as a "- item" bullet list."""
    bullet_lines = "\n".join([f"- {item}" for item in items[:limit]])
    return f"{title}:\n{bullet_lines}"


def _collect_context_extras(example: dict[str, Any]) -> list[str]:
    """Build the optional context sections (text, URLs, files, images) found in ``example``."""
    extras: list[str] = []

    # Text contexts (lists are truncated to their first 5 entries)
    for key in _TEXT_CONTEXT_KEYS:
        if key in example and example[key]:
            val = example[key]
            if isinstance(val, list | tuple):
                val_str = "\n".join([str(v) for v in val][:5])
            else:
                val_str = str(val)
            if val_str.strip():
                extras.append(f"{key.title()}:\n{val_str}")

    # URLs ("urls" takes precedence over "url")
    urls = []
    if "urls" in example and example["urls"]:
        urls = _as_list(example["urls"])
    elif "url" in example and example["url"]:
        urls = [example["url"]]
    if urls:
        extras.append(_bullet_section("URLs", urls))

    # File paths
    file_paths = []
    for key in ("file_paths", "file_path", "files"):
        if key in example and example[key]:
            file_paths.extend([str(v) for v in _as_list(example[key])])
    if file_paths:
        extras.append(_bullet_section("Files", file_paths))

    # Images
    images = []
    for key in ("images", "image"):
        if key in example and example[key]:
            images.extend([str(v) for v in _as_list(example[key])])
    if images:
        extras.append(_bullet_section("Images", images))

    return extras


def _extract_qa(example: dict[str, Any]) -> dict[str, str]:
    """
    Turn one dataset row into ``{"question": str, "answer": str}``.

    Multiple-choice options and any external context are appended to the
    question on a best-effort basis: a failure there is reported and the
    question is used without that enrichment.
    """
    q = _first_present(example, _QUESTION_KEYS)
    a = _first_present(example, _ANSWER_KEYS)

    if "choices" in example and a:
        try:
            choices_text = "\n".join([f"{i + 1}. {choice}" for i, choice in enumerate(example["choices"])])
            q = f"{q}\n\nChoices:\n{choices_text}"
        except Exception as e:
            # Best effort: keep the bare question, but do not hide the problem.
            print(f"Warning: could not format choices: {e}")

    # Inject external contexts (urls/files/images/extra text) to help tools
    try:
        extras = _collect_context_extras(example)
        if extras:
            q = f"{q}\n\nAdditional context for tools:\n" + "\n\n".join(extras)
    except Exception as e:
        # Best effort: keep the question without extra context.
        print(f"Warning: could not attach extra context: {e}")

    return {
        "question": str(q) if q is not None else "",
        "answer": str(a) if a is not None else "",
    }


async def evaluate_hle_dataset(dataset_path: str, args) -> dict[str, Any]:
    """
    Evaluate DeepResearch on the HLE dataset.

    Loads the dataset from Hugging Face, runs every question through the
    DeepResearch workflow, grades the predictions with the LLM judge and
    saves results and metrics to ``args.output_dir``.

    Args:
        dataset_path: Hugging Face dataset identifier. ``args.hf_dataset``
            takes precedence when it is set; this value is the fallback
            (and "cais/hle" the final default).
        args: Parsed command line arguments (``hf_dataset``, ``hf_config``,
            ``hf_split``, ``max_samples``, ``parallel_tasks`` plus the model
            and output options consumed by the helpers).

    Returns:
        Evaluation metrics dictionary.

    Raises:
        Exception: Any failure while loading the dataset or running the
            evaluation is reported on stdout and re-raised.
    """
    dataset_name = args.hf_dataset or dataset_path or "cais/hle"
    split_name = args.hf_split or "test"

    print("📊 Starting HLE Evaluation")
    # Report the dataset that is actually loaded below.
    print(f"Dataset: {dataset_name}")
    print(f"Max samples: {args.max_samples}")
    print("=" * 60)

    # Load dataset (HF only to align with examples pattern)
    questions = []

    print(f"🧰 Loading dataset from Hugging Face: {dataset_name} (split={split_name})")
    try:
        if args.hf_config:
            ds = load_dataset(dataset_name, args.hf_config, split=split_name)
        else:
            ds = load_dataset(dataset_name, split=split_name)

        total_len = len(ds)
        limit = min(args.max_samples, total_len) if args.max_samples else total_len
        for idx in range(limit):
            qa = _extract_qa(ds[idx])
            if qa["question"] and qa["answer"]:
                questions.append(
                    {
                        "id": f"hle_{idx}",
                        "question": qa["question"],
                        "answer": qa["answer"],
                    }
                )
            else:
                print(f"Warning: Could not extract question/answer from example {idx}")

    except Exception as e:
        print(f"❌ Failed to load dataset from Hugging Face: {e}")
        raise

    print(f"📋 Loaded {len(questions)} questions from HLE dataset")

    # Setup rollout engine
    load_dotenv(find_dotenv())

    # Engine that answers the questions (provider/model chosen by setup_rollout_engine)
    model_engine = setup_rollout_engine(args, model_role="evaluation")

    # Setup judge (can use same or different model)
    judge_engine = setup_rollout_engine(args, model_role="judge")
    judge = HLEJudge(judge_engine)

    # Setup tools
    tools = get_all_tools()

    # Create AgentWorkflowEngine
    workflow_engine = AgentWorkflowEngine(
        workflow_cls=DeepResearchWorkflow,
        workflow_args={
            "tools": tools,
            "max_prompt_length": 4096,
            "max_response_length": 2048,
        },
        rollout_engine=model_engine,
        n_parallel_tasks=args.parallel_tasks,
        retry_limit=1,
    )

    print(f"⚙️ Created evaluation setup with {args.parallel_tasks} parallel tasks")

    # Run DeepResearch evaluation
    print("\n🔬 Running DeepResearch evaluation...")
    # We are inside a coroutine, so the running loop is the right clock source.
    loop = asyncio.get_running_loop()
    start_time = loop.time()

    try:
        episodes = await workflow_engine.execute_tasks(questions)
        eval_time = loop.time() - start_time

        print(f"\n✅ Evaluation completed in {eval_time:.1f}s")

        # Extract predictions
        results = []
        for episode in episodes:
            prediction = episode.metrics.get("prediction", "No prediction available")
            results.append(
                {
                    "question": episode.task.get("question", ""),
                    "reference_answer": episode.task.get("answer", ""),
                    "prediction": prediction,
                    "episode_id": episode.id,
                    "is_correct": episode.is_correct,
                    "rounds": episode.metrics.get("rounds", 0),
                    "termination_reason": episode.termination_reason.value if episode.termination_reason else "unknown",
                }
            )

        # Judge responses
        print(f"\n⚖️ Judging {len(results)} responses...")

        # Judged one at a time on purpose: judge errors are converted into
        # "incorrect" verdicts, so provoking rate limits would skew accuracy.
        judge_results = []
        for result in results:
            judgment = await judge.judge_response(
                question=result["question"],
                reference_answer=result["reference_answer"],
                assistant_answer=result["prediction"],
            )
            # The judge verdict overwrites the episode's own "is_correct".
            result.update(judgment)
            judge_results.append(result)

        # Calculate metrics
        metrics = calculate_hle_metrics(judge_results)
        metrics["evaluation_time"] = eval_time
        metrics["total_questions"] = len(questions)

        # Save results
        save_hle_results(judge_results, metrics, args)

        return metrics

    except Exception as e:
        print(f"❌ Evaluation failed: {e}")
        raise
+
+
def setup_rollout_engine(args, model_role="evaluation") -> OpenAIEngine:
    """
    Build the OpenAIEngine used either to answer questions or to judge them.

    Provider selection, in priority order:
      1. ``args.api_key`` (OpenAI-compatible endpoint, default model "gpt-4")
      2. TOGETHER_AI_API_KEY, only for ``model_role == "evaluation"``
      3. OPENAI_API_KEY (default model "gpt-4o")

    ``args.base_url`` and ``args.model`` override the provider defaults.

    Args:
        args: Parsed command line arguments (``api_key``, ``base_url``, ``model``).
        model_role: "evaluation" or "judge"; only the judge gets default
            sampling parameters.

    Returns:
        A configured OpenAIEngine.

    Raises:
        ValueError: If no API key is available from the arguments or environment.
    """
    # Load environment variables
    load_dotenv(find_dotenv())

    together_key = os.getenv("TOGETHER_AI_API_KEY")
    openai_key = os.getenv("OPENAI_API_KEY")
    openai_url = "https://api.openai.com/v1"

    # Provider selection
    if args.api_key:
        api_key = args.api_key
        fallback_url = openai_url
        model_name = args.model or "gpt-4"
        provider_label = None
    elif together_key and model_role == "evaluation":
        api_key = together_key
        fallback_url = "https://api.together.xyz/v1"
        model_name = args.model or os.getenv("TOGETHER_AI_MODEL_NAME", "Qwen/Qwen2.5-7B-Instruct-Turbo")
        provider_label = "Together AI"
    elif openai_key:
        api_key = openai_key
        fallback_url = openai_url
        model_name = args.model or "gpt-4o"
        provider_label = "OpenAI"
    else:
        raise ValueError("❌ API key required. Please set OPENAI_API_KEY or TOGETHER_AI_API_KEY in .env file")

    base_url = args.base_url or fallback_url
    if provider_label is not None:
        print(f"🔧 Using {provider_label} for {model_role}")

    # For evaluation, DeepResearch handles all sampling params internally;
    # only the judge needs basic params.
    sampling_params = {}
    if model_role == "judge":
        lowered_name = model_name.lower()
        # O3/O1 models take max_completion_tokens instead of max_tokens/temperature.
        if "o3" in lowered_name or "o1" in lowered_name:
            sampling_params = {
                "max_completion_tokens": 1000,
            }
        else:
            sampling_params = {
                "temperature": 0.1,
                "top_p": 0.95,
                "max_tokens": 1000,
            }

    return OpenAIEngine(
        model=model_name,
        tokenizer=None,
        base_url=base_url,
        api_key=api_key,
        sampling_params=sampling_params,
    )
+
+
def calculate_hle_metrics(results: list[dict[str, Any]]) -> dict[str, Any]:
    """
    Aggregate judged results into summary metrics.

    Args:
        results: One dictionary per judged question; the keys ``is_correct``,
            ``termination_reason`` and ``rounds`` are read (all optional).

    Returns:
        ``{"error": ...}`` when there are no results, otherwise a dictionary
        with the number of questions, judge accuracy, number of correct
        answers, average rounds and a count per termination reason.
    """
    n_results = len(results)
    if n_results == 0:
        return {"error": "No results to evaluate"}

    n_correct = 0
    reason_tally = {}
    rounds_per_question = []
    for entry in results:
        # Judge-based binary verdict; a missing verdict counts as incorrect.
        if entry.get("is_correct", False):
            n_correct += 1

        label = entry.get("termination_reason", "unknown")
        reason_tally[label] = reason_tally.get(label, 0) + 1

        rounds_per_question.append(entry.get("rounds", 0))

    mean_rounds = statistics.mean(rounds_per_question) if rounds_per_question else 0

    return {
        "total_questions": n_results,
        "judge_accuracy": n_correct / n_results,
        "judge_correct": n_correct,
        "average_rounds": mean_rounds,
        "termination_distribution": reason_tally,
    }
+
+
def save_hle_results(results: list[dict], metrics: dict, args):
    """
    Write the detailed results and the metrics summary as timestamped JSON files.

    Two files are created in ``args.output_dir`` (created if missing):
    ``hle_results_<timestamp>.json`` with metadata, per-question results and
    metrics, and ``hle_metrics_<timestamp>.json`` with the metrics only.

    Args:
        results: Per-question result dictionaries.
        metrics: Aggregated metrics dictionary.
        args: Object providing ``output_dir`` and ``model``.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    out_dir = args.output_dir

    results_file = os.path.join(out_dir, f"hle_results_{stamp}.json")
    metrics_file = os.path.join(out_dir, f"hle_metrics_{stamp}.json")
    os.makedirs(out_dir, exist_ok=True)

    detailed_report = {
        "metadata": {
            "timestamp": stamp,
            "dataset": "HLE",
            "model": args.model,
            "total_questions": len(results),
        },
        "results": results,
        "metrics": metrics,
    }

    # Detailed results first, metrics summary second.
    for target_path, payload in ((results_file, detailed_report), (metrics_file, metrics)):
        with open(target_path, "w", encoding="utf-8") as handle:
            json.dump(payload, handle, indent=2, ensure_ascii=False)

    print(f"💾 Results saved to: {results_file}")
    print(f"📊 Metrics saved to: {metrics_file}")
+
+
def print_hle_summary(metrics: dict[str, Any]):
    """
    Print a human-readable summary of the evaluation metrics to stdout.

    Missing metrics are shown as 0; termination reasons are listed one per
    line in the order stored in ``metrics["termination_distribution"]``.

    Args:
        metrics: Metrics dictionary of an evaluation run.
    """
    rule = "=" * 60

    def summary_lines():
        # Generated lazily so each line is printed as soon as it is formatted.
        yield "\n" + rule
        yield "📊 HLE EVALUATION SUMMARY"
        yield rule
        yield f"Total Questions: {metrics.get('total_questions', 0)}"
        yield f"Judge Accuracy: {metrics.get('judge_accuracy', 0):.2%}"
        yield f"Correct Answers: {metrics.get('judge_correct', 0)}/{metrics.get('total_questions', 0)}"
        yield f"Average Rounds: {metrics.get('average_rounds', 0):.1f}"
        yield f"Evaluation Time: {metrics.get('evaluation_time', 0):.1f}s"

        yield "\nTermination Reasons:"
        for reason, count in metrics.get("termination_distribution", {}).items():
            yield f" {reason}: {count}"

        yield rule

    for line in summary_lines():
        print(line)
+
+
async def main():
    """
    Command line entry point: parse arguments, run the HLE evaluation and
    print the summary.

    Raises:
        SystemExit: With status 1 when the evaluation fails, so that shells
            and CI pipelines can detect the failure (the error and traceback
            are printed first).
    """
    parser = argparse.ArgumentParser(description="Run HLE evaluation with DeepResearch + rLLM")

    # Dataset options (HF only)
    parser.add_argument(
        "--hf-dataset",
        default="cais/hle",
        help="Hugging Face dataset path (default: cais/hle)",
    )
    parser.add_argument(
        "--hf-config",
        default=None,
        help="Optional dataset configuration name for HF datasets that require it.",
    )
    parser.add_argument(
        "--hf-split",
        default="test",
        help="Dataset split to load from HF (default: test)",
    )
    parser.add_argument(
        "--max-samples",
        type=int,
        default=None,
        help="Maximum number of samples to evaluate",
    )

    # Model options
    parser.add_argument("--model", default=None, help="Model name to use")
    parser.add_argument("--base-url", default=None, help="API base URL")
    parser.add_argument("--api-key", default=None, help="API key (uses env vars if not provided)")

    # Execution options
    parser.add_argument("--parallel-tasks", type=int, default=4, help="Number of parallel tasks")
    parser.add_argument("--output-dir", default="./hle_outputs", help="Output directory for results")

    args = parser.parse_args()

    try:
        metrics = await evaluate_hle_dataset(args.hf_dataset, args)
        print_hle_summary(metrics)

    except Exception as e:
        print(f"❌ HLE evaluation failed: {e}")
        import traceback

        traceback.print_exc()
        # Previously the failure was only printed and the process still
        # exited with status 0; surface it through the exit status instead.
        raise SystemExit(1) from e


if __name__ == "__main__":
    # Set environment for tokenizers
    os.environ["TOKENIZERS_PARALLELISM"] = "true"

    asyncio.run(main())
diff --git a/vendor/rllm/examples/deepscaler/README.md b/vendor/rllm/examples/deepscaler/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f79387bbdfa0385dd8a8c984031e50561a814aea
--- /dev/null
+++ b/vendor/rllm/examples/deepscaler/README.md
@@ -0,0 +1,92 @@
+# DeepScaleR Examples
+
+This directory contains examples for training and running math reasoning agents with tool usage capabilities using the RLLM framework. The math tool agent has access to a Python interpreter to solve mathematical problems through step-by-step reasoning and tool-use.
+
+Our examples use the following:
+* DeepSeek-R1-Distill-Qwen-1.5B as the base model
+* DeepScaleR-Math dataset for training
+* AIME2024 dataset for evaluation
+
+
+## Model Hosting
+
+### Option 1: Using vLLM
+
+Start a vLLM server with OpenAI-compatible API:
+
+```bash
+python -m vllm.entrypoints.openai.api_server \
+ --model agentica-org/DeepScaleR-1.5B-Preview \
+ --host 0.0.0.0 \
+ --port 30000 \
+ --dtype bfloat16
+```
+
+### Option 2: Using SGLang
+
+```bash
+python -m sglang_router.launch_server \
+ --model-path agentica-org/DeepScaleR-1.5B-Preview \
+ --dp-size 1 \
+ --dtype bfloat16
+# increase dp_size to enable data-parallel processing on multi-GPU
+```
+
+The server should be accessible at `http://localhost:30000/v1`
+
+## Dataset Preparation
+
+Prepare the required datasets (AIME 2024 for testing, DeepScaleR for training):
+
+```bash
+cd examples/deepscaler
+python prepare_math_data.py
+```
+
+This will:
+- Download AIME 2024 dataset from HuggingFace
+- Download DeepScaleR math dataset for training
+- Register both datasets with the RLLM DatasetRegistry
+
+## Running Inference
+
+Once your model server is running and datasets are prepared, you can run inference:
+
+```bash
+cd examples/deepscaler
+python run_deepscaler.py
+```
+
+### Configuration Options
+
+You can modify the inference script parameters:
+
+- `n_parallel_agents`: Number of parallel agents (default: 64)
+- `model_name`: Model to use (default: "agentica-org/DeepScaleR-1.5B-Preview")
+- `base_url`: API server URL (default: "http://localhost:30000/v1")
+- `max_response_length`: Maximum response length (default: 32768)
+- `max_prompt_length`: Maximum prompt length (default: 2048)
+- `temperature`: Sampling temperature (default: 0.6)
+- `top_p`: Top-p sampling (default: 0.95)
+
+The script will:
+1. Load the AIME 2024 test dataset
+2. Repeat each problem 16 times for Pass@K evaluation
+3. Run parallel and async trajectory collection using the agent execution engine
+4. Evaluate results and report Pass@1 and Pass@K accuracy
+
+## Training
+
+### Basic Training
+
+To train Deepscaler with iterative context lengthening (8K -> 16K -> 24K):
+
+```bash
+bash examples/deepscaler/train_deepscaler_8k.sh
+
+# modify MODEL_PATH to the 8k checkpoint path before running the script.
+bash examples/deepscaler/train_deepscaler_16k.sh
+
+# modify MODEL_PATH to the 16k checkpoint path before running the script
+bash examples/deepscaler/train_deepscaler_24k.sh
+```
\ No newline at end of file
diff --git a/vendor/rllm/examples/deepscaler/prepare_math_data.py b/vendor/rllm/examples/deepscaler/prepare_math_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..08981f4fca410dfae77c81ab59f880dbb4488f2e
--- /dev/null
+++ b/vendor/rllm/examples/deepscaler/prepare_math_data.py
@@ -0,0 +1,28 @@
+from datasets import load_dataset
+
+from rllm.data.dataset import DatasetRegistry
+
+
def prepare_math_data():
    """
    Download the math datasets and register them with the DatasetRegistry.

    The DeepScaleR preview dataset is registered as ("deepscaler_math", "train")
    and AIME 2024 as ("aime2024", "test"). Every row gains the fields
    ``question``, ``ground_truth`` and ``data_source``.

    Returns:
        Tuple ``(train_dataset, test_dataset)`` as returned by
        ``DatasetRegistry.register_dataset``.
    """

    def to_record(example, idx):
        # idx is supplied because map() runs with with_indices=True; it is not used.
        return {
            "question": example["problem"],
            "ground_truth": example["answer"],
            "data_source": "math",
        }

    raw_train = load_dataset("agentica-org/DeepScaleR-Preview-Dataset", split="train")
    raw_test = load_dataset("HuggingFaceH4/aime_2024", split="train")

    # Mapped in order: training split first, then the test split.
    mapped_train, mapped_test = (split.map(to_record, with_indices=True) for split in (raw_train, raw_test))

    registered_train = DatasetRegistry.register_dataset("deepscaler_math", mapped_train, "train")
    registered_test = DatasetRegistry.register_dataset("aime2024", mapped_test, "test")
    return registered_train, registered_test


if __name__ == "__main__":
    # Show where each registered dataset was stored.
    for registered in prepare_math_data():
        print(registered.get_data_path())
diff --git a/vendor/rllm/examples/deepscaler/run_deepscaler.py b/vendor/rllm/examples/deepscaler/run_deepscaler.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d6a738a490f1e480cae2db23a9f1b31c2aea455
--- /dev/null
+++ b/vendor/rllm/examples/deepscaler/run_deepscaler.py
@@ -0,0 +1,58 @@
+import asyncio
+
+from transformers import AutoTokenizer
+
+from rllm.agents.math_agent import MathAgent
+from rllm.data.dataset import DatasetRegistry
+from rllm.engine.agent_execution_engine import AgentExecutionEngine
+from rllm.environments.base.single_turn_env import SingleTurnEnvironment
+from rllm.rewards.reward_fn import math_reward_fn
+from rllm.utils import compute_pass_at_k
+
if __name__ == "__main__":
    # Script body: evaluate DeepScaleR on AIME 2024 through an OpenAI-compatible
    # server and report pass@k over 16 samples per problem.
    import os

    os.environ["TOKENIZERS_PARALLELISM"] = "true"

    model_id = "agentica-org/DeepScaleR-1.5B-Preview"
    hf_tokenizer = AutoTokenizer.from_pretrained(model_id)

    execution_engine = AgentExecutionEngine(
        agent_class=MathAgent,
        env_class=SingleTurnEnvironment,
        agent_args={},
        env_args={"reward_fn": math_reward_fn},
        engine_name="openai",
        tokenizer=hf_tokenizer,
        sampling_params={"temperature": 0.6, "top_p": 0.95, "model": model_id},
        rollout_engine_args={
            "base_url": "http://localhost:30000/v1",
            "api_key": "None",
        },
        max_response_length=32768,
        max_prompt_length=2048,
        n_parallel_agents=64,
    )

    aime_dataset = DatasetRegistry.load_dataset("aime2024", "test")
    if aime_dataset is None:
        # Not registered yet: build it on the fly.
        print("Dataset not found, preparing dataset...")
        from prepare_math_data import prepare_math_data

        aime_dataset = prepare_math_data()[1]

    # repeat to evaluate pass@k
    repeated_tasks = aime_dataset.repeat(n=16)

    rollouts = asyncio.run(execution_engine.execute_tasks(repeated_tasks))
    compute_pass_at_k(rollouts)
diff --git a/vendor/rllm/examples/deepscaler/train_deepscaler.py b/vendor/rllm/examples/deepscaler/train_deepscaler.py
new file mode 100644
index 0000000000000000000000000000000000000000..a730358862f603f184d27d20aa6448a3f7f761f8
--- /dev/null
+++ b/vendor/rllm/examples/deepscaler/train_deepscaler.py
@@ -0,0 +1,30 @@
+import hydra
+
+from rllm.agents.math_agent import MathAgent
+from rllm.data.dataset import DatasetRegistry
+from rllm.environments.base.single_turn_env import SingleTurnEnvironment
+from rllm.rewards.reward_fn import math_reward_fn
+from rllm.trainer.agent_trainer import AgentTrainer
+
+
@hydra.main(config_path="pkg://rllm.trainer.config", config_name="agent_ppo_trainer", version_base=None)
def main(config):
    """
    Train the math agent with the Hydra-provided ``agent_ppo_trainer`` config.

    Uses the registered "deepscaler_math" train split for training and the
    registered "aime2024" test split for validation.

    Args:
        config: Hydra configuration object passed to the AgentTrainer.
    """
    AgentTrainer(
        agent_class=MathAgent,
        agent_args={},
        env_args={"reward_fn": math_reward_fn},
        env_class=SingleTurnEnvironment,
        config=config,
        # Loaded in this order: training split first, then validation split.
        train_dataset=DatasetRegistry.load_dataset("deepscaler_math", "train"),
        val_dataset=DatasetRegistry.load_dataset("aime2024", "test"),
    ).train()


if __name__ == "__main__":
    main()
diff --git a/vendor/rllm/examples/deepscaler/train_deepscaler_16k.sh b/vendor/rllm/examples/deepscaler/train_deepscaler_16k.sh
new file mode 100644
index 0000000000000000000000000000000000000000..745f74355a82d41f6558e9e0de8c71e17604da46
--- /dev/null
+++ b/vendor/rllm/examples/deepscaler/train_deepscaler_16k.sh
@@ -0,0 +1,64 @@
#!/bin/bash
# DeepScaleR training, 16K stage: GRPO with a 16384-token response budget.
# Per the accompanying README, point MODEL_PATH at the 8k checkpoint before
# running this stage.
# The `examples` package must be importable for `python3 -m` below, e.g. by
# launching from the repository root (`bash examples/deepscaler/train_deepscaler_16k.sh`).
set -x

export VLLM_ATTENTION_BACKEND=FLASH_ATTN
export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:False"
export VLLM_USE_V1=1
export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
export VLLM_ENGINE_ITERATION_TIMEOUT_S=100000000000

# Find the directory where rllm package is located
# NOTE(review): RLLM_DIR is computed but not referenced again in this script.
RLLM_DIR=$(python3 -c "import rllm; import os; print(os.path.dirname(os.path.dirname(rllm.__file__)))")

MODEL_PATH=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B

# MODEL_PATH is quoted so a checkpoint path containing spaces stays one argument.
# trainer.logger is single-quoted so the brackets cannot be taken as a glob
# pattern; the value handed to Hydra is unchanged ([console,wandb]).
python3 -m examples.deepscaler.train_deepscaler \
    algorithm.adv_estimator=grpo \
    data.train_batch_size=128 \
    data.val_batch_size=30 \
    data.max_prompt_length=2048 \
    data.max_response_length=16384 \
    actor_rollout_ref.model.path="$MODEL_PATH" \
    actor_rollout_ref.hybrid_engine=True \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.loss_agg_mode=seq-mean-token-mean \
    actor_rollout_ref.actor.ppo_mini_batch_size=64 \
    actor_rollout_ref.actor.use_dynamic_bsz=True \
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=30000 \
    actor_rollout_ref.actor.use_kl_loss=False \
    actor_rollout_ref.actor.clip_ratio_high=0.28 \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=True \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.mode="async" \
    actor_rollout_ref.rollout.enforce_eager=False \
    actor_rollout_ref.rollout.temperature=0.6 \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \
    actor_rollout_ref.rollout.n=8 \
    actor_rollout_ref.rollout.val_kwargs.n=8 \
    actor_rollout_ref.rollout.val_kwargs.temperature=0.6 \
    actor_rollout_ref.rollout.val_kwargs.top_p=0.95 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \
    actor_rollout_ref.actor.entropy_coeff=0 \
    algorithm.kl_ctrl.kl_coef=0.001 \
    rllm.mask_truncated_samples=False \
    trainer.critic_warmup=0 \
    'trainer.logger=[console,wandb]' \
    trainer.project_name='rllm-agent' \
    trainer.experiment_name='deepscaler-1.5b-16k' \
    trainer.val_before_train=True \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=20 \
    trainer.test_freq=20 \
    trainer.default_hdfs_dir=null \
    rllm.agent.max_steps=1 \
    rllm.stepwise_advantage.enable=False \
    trainer.total_epochs=100
\ No newline at end of file
diff --git a/vendor/rllm/examples/deepscaler/train_deepscaler_24k.sh b/vendor/rllm/examples/deepscaler/train_deepscaler_24k.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9c35ae3c78fdcb24b459e3945d23bf1e3eab43b3
--- /dev/null
+++ b/vendor/rllm/examples/deepscaler/train_deepscaler_24k.sh
@@ -0,0 +1,64 @@
#!/bin/bash
# DeepScaleR training, 24K stage: GRPO with a 24576-token response budget.
# Per the accompanying README, point MODEL_PATH at the 16k checkpoint before
# running this stage.
# The `examples` package must be importable for `python3 -m` below, e.g. by
# launching from the repository root (`bash examples/deepscaler/train_deepscaler_24k.sh`).
set -x

export VLLM_ATTENTION_BACKEND=FLASH_ATTN
export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:False"
export VLLM_USE_V1=1
export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
export VLLM_ENGINE_ITERATION_TIMEOUT_S=100000000000

# Find the directory where rllm package is located
# NOTE(review): RLLM_DIR is computed but not referenced again in this script.
RLLM_DIR=$(python3 -c "import rllm; import os; print(os.path.dirname(os.path.dirname(rllm.__file__)))")

MODEL_PATH=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B

# MODEL_PATH is quoted so a checkpoint path containing spaces stays one argument.
# trainer.logger is single-quoted so the brackets cannot be taken as a glob
# pattern; the value handed to Hydra is unchanged ([console,wandb]).
python3 -m examples.deepscaler.train_deepscaler \
    algorithm.adv_estimator=grpo \
    data.train_batch_size=128 \
    data.val_batch_size=30 \
    data.max_prompt_length=2048 \
    data.max_response_length=24576 \
    actor_rollout_ref.model.path="$MODEL_PATH" \
    actor_rollout_ref.hybrid_engine=True \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.loss_agg_mode=seq-mean-token-mean \
    actor_rollout_ref.actor.ppo_mini_batch_size=64 \
    actor_rollout_ref.actor.use_dynamic_bsz=True \
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=30000 \
    actor_rollout_ref.actor.use_kl_loss=False \
    actor_rollout_ref.actor.clip_ratio_high=0.28 \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=True \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.mode="async" \
    actor_rollout_ref.rollout.enforce_eager=False \
    actor_rollout_ref.rollout.temperature=0.6 \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \
    actor_rollout_ref.rollout.n=8 \
    actor_rollout_ref.rollout.val_kwargs.n=8 \
    actor_rollout_ref.rollout.val_kwargs.temperature=0.6 \
    actor_rollout_ref.rollout.val_kwargs.top_p=0.95 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \
    actor_rollout_ref.actor.entropy_coeff=0 \
    algorithm.kl_ctrl.kl_coef=0.001 \
    rllm.mask_truncated_samples=False \
    trainer.critic_warmup=0 \
    'trainer.logger=[console,wandb]' \
    trainer.project_name='rllm-agent' \
    trainer.experiment_name='deepscaler-1.5b-24k' \
    trainer.val_before_train=True \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=20 \
    trainer.test_freq=20 \
    trainer.default_hdfs_dir=null \
    rllm.agent.max_steps=1 \
    rllm.stepwise_advantage.enable=False \
    trainer.total_epochs=100
\ No newline at end of file
diff --git a/vendor/rllm/examples/deepscaler/train_deepscaler_8k.sh b/vendor/rllm/examples/deepscaler/train_deepscaler_8k.sh
new file mode 100644
index 0000000000000000000000000000000000000000..b939ad3486547badb5e0d5390c6d078f028d32e7
--- /dev/null
+++ b/vendor/rllm/examples/deepscaler/train_deepscaler_8k.sh
@@ -0,0 +1,64 @@
#!/bin/bash
# DeepScaleR training, 8K stage: GRPO with an 8192-token response budget,
# starting from the base model set in MODEL_PATH.
# The `examples` package must be importable for `python3 -m` below, e.g. by
# launching from the repository root (`bash examples/deepscaler/train_deepscaler_8k.sh`).
set -x

export VLLM_ATTENTION_BACKEND=FLASH_ATTN
export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:False"
export VLLM_USE_V1=1
export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
export VLLM_ENGINE_ITERATION_TIMEOUT_S=100000000000

# Find the directory where rllm package is located
# NOTE(review): RLLM_DIR is computed but not referenced again in this script.
RLLM_DIR=$(python3 -c "import rllm; import os; print(os.path.dirname(os.path.dirname(rllm.__file__)))")

MODEL_PATH=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B

# MODEL_PATH is quoted so a checkpoint path containing spaces stays one argument.
# trainer.logger is single-quoted so the brackets cannot be taken as a glob
# pattern; the value handed to Hydra is unchanged ([console,wandb]).
python3 -m examples.deepscaler.train_deepscaler \
    algorithm.adv_estimator=grpo \
    data.train_batch_size=128 \
    data.val_batch_size=30 \
    data.max_prompt_length=2048 \
    data.max_response_length=8192 \
    actor_rollout_ref.model.path="$MODEL_PATH" \
    actor_rollout_ref.hybrid_engine=True \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.loss_agg_mode=seq-mean-token-mean \
    actor_rollout_ref.actor.ppo_mini_batch_size=64 \
    actor_rollout_ref.actor.use_dynamic_bsz=True \
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=30000 \
    actor_rollout_ref.actor.use_kl_loss=False \
    actor_rollout_ref.actor.clip_ratio_high=0.28 \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=True \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.mode="async" \
    actor_rollout_ref.rollout.enforce_eager=False \
    actor_rollout_ref.rollout.temperature=0.6 \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \
    actor_rollout_ref.rollout.n=8 \
    actor_rollout_ref.rollout.val_kwargs.n=8 \
    actor_rollout_ref.rollout.val_kwargs.temperature=0.6 \
    actor_rollout_ref.rollout.val_kwargs.top_p=0.95 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \
    actor_rollout_ref.actor.entropy_coeff=0 \
    algorithm.kl_ctrl.kl_coef=0.001 \
    rllm.mask_truncated_samples=False \
    trainer.critic_warmup=0 \
    'trainer.logger=[console,wandb]' \
    trainer.project_name='rllm-agent' \
    trainer.experiment_name='deepscaler-1.5b-8k' \
    trainer.val_before_train=True \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=20 \
    trainer.test_freq=20 \
    trainer.default_hdfs_dir=null \
    rllm.agent.max_steps=1 \
    rllm.stepwise_advantage.enable=False \
    trainer.total_epochs=100
\ No newline at end of file
diff --git a/vendor/rllm/examples/deepscaler/train_deepscaler_megatron.py b/vendor/rllm/examples/deepscaler/train_deepscaler_megatron.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b90eb31492403be4f7a3dc95742c57b01248ff6
--- /dev/null
+++ b/vendor/rllm/examples/deepscaler/train_deepscaler_megatron.py
@@ -0,0 +1,30 @@
+import hydra
+
+from rllm.agents.math_agent import MathAgent
+from rllm.data.dataset import DatasetRegistry
+from rllm.environments.base.single_turn_env import SingleTurnEnvironment
+from rllm.rewards.reward_fn import math_reward_fn
+from rllm.trainer.agent_trainer import AgentTrainer
+
+
+@hydra.main(config_path="pkg://rllm.trainer.config", config_name="agent_ppo_trainer_megatron", version_base=None)
+def main(config):
+    """Train MathAgent with the Megatron PPO trainer configuration.
+
+    Args:
+        config: Hydra-composed configuration (including command-line overrides),
+            forwarded unchanged to AgentTrainer.
+    """
+    # Training split comes from "deepscaler_math"; validation uses the "aime2024" test split.
+    train_split = DatasetRegistry.load_dataset("deepscaler_math", "train")
+    val_split = DatasetRegistry.load_dataset("aime2024", "test")
+
+    # The environment receives math_reward_fn as its reward function.
+    AgentTrainer(
+        agent_class=MathAgent,
+        agent_args={},
+        env_class=SingleTurnEnvironment,
+        env_args={"reward_fn": math_reward_fn},
+        config=config,
+        train_dataset=train_split,
+        val_dataset=val_split,
+    ).train()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/vendor/rllm/examples/deepscaler/train_deepscaler_megatron.sh b/vendor/rllm/examples/deepscaler/train_deepscaler_megatron.sh
new file mode 100644
index 0000000000000000000000000000000000000000..344de069ad53440ca74bbe5d57a43ea16f4b7f5b
--- /dev/null
+++ b/vendor/rllm/examples/deepscaler/train_deepscaler_megatron.sh
@@ -0,0 +1,75 @@
+#!/bin/bash
+# Launches examples.deepscaler.train_deepscaler_megatron, passing every training
+# setting as a Hydra command-line override (GRPO advantage estimator, vLLM rollout,
+# Megatron parallelism/offload settings for actor, ref and critic).
+# Print each command before it is executed.
+set -x
+
+# vLLM settings exported to the training process.
+export VLLM_ATTENTION_BACKEND=FLASH_ATTN
+export VLLM_USE_V1=1
+export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
+export VLLM_ENGINE_ITERATION_TIMEOUT_S=100000000000
+
+# Find the directory where rllm package is located
+# NOTE(review): RLLM_DIR is not referenced anywhere else in this script; confirm whether it is still needed.
+RLLM_DIR=$(python3 -c "import rllm; import os; print(os.path.dirname(os.path.dirname(rllm.__file__)))")
+
+
+# Model passed to actor_rollout_ref.model.path below.
+MODEL_PATH=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
+
+# gen_tp: tensor-parallel size for the rollout engine.
+# train_tp / train_pp: Megatron tensor- and pipeline-parallel sizes, applied to actor, ref and critic below.
+gen_tp=2
+train_tp=2
+train_pp=2
+
+# Run DeepScaler training with Megatron
+# NOTE(review): `python -m examples...` resolves the module from the current directory,
+# so this presumably has to be launched from the rllm repository root; verify.
+# NOTE(review): the interpreter here is `python`, while RLLM_DIR above uses `python3`; confirm both resolve to the same environment.
+python -m examples.deepscaler.train_deepscaler_megatron \
+ algorithm.adv_estimator=grpo \
+ algorithm.kl_ctrl.kl_coef=0.001 \
+ data.train_batch_size=128 \
+ data.val_batch_size=30 \
+ data.max_prompt_length=2048 \
+ data.max_response_length=24576 \
+ actor_rollout_ref.model.path=$MODEL_PATH \
+ actor_rollout_ref.hybrid_engine=True \
+ actor_rollout_ref.actor.optim.lr=1e-6 \
+ actor_rollout_ref.actor.loss_agg_mode=seq-mean-token-mean \
+ actor_rollout_ref.actor.ppo_mini_batch_size=64 \
+ actor_rollout_ref.actor.use_kl_loss=False \
+ actor_rollout_ref.actor.clip_ratio_high=0.28 \
+ actor_rollout_ref.actor.kl_loss_coef=0.001 \
+ actor_rollout_ref.actor.kl_loss_type=low_var_kl \
+ actor_rollout_ref.actor.megatron.param_offload=True \
+ actor_rollout_ref.actor.megatron.grad_offload=True \
+ actor_rollout_ref.actor.megatron.optimizer_offload=True \
+ actor_rollout_ref.ref.megatron.param_offload=True \
+ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \
+ critic.ppo_micro_batch_size_per_gpu=1 \
+ critic.ppo_mini_batch_size=1 \
+ actor_rollout_ref.rollout.name=vllm \
+ actor_rollout_ref.rollout.mode="async" \
+ actor_rollout_ref.rollout.enforce_eager=False \
+ actor_rollout_ref.rollout.temperature=0.6 \
+ actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \
+ actor_rollout_ref.rollout.n=8 \
+ actor_rollout_ref.rollout.val_kwargs.n=8 \
+ actor_rollout_ref.rollout.val_kwargs.temperature=0.6 \
+ actor_rollout_ref.rollout.val_kwargs.top_p=0.95 \
+ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \
+ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \
+ actor_rollout_ref.actor.entropy_coeff=0 \
+ actor_rollout_ref.rollout.tensor_model_parallel_size=$gen_tp \
+ actor_rollout_ref.actor.megatron.tensor_model_parallel_size=$train_tp \
+ actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=$train_pp \
+ actor_rollout_ref.ref.megatron.tensor_model_parallel_size=$train_tp \
+ actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=$train_pp \
+ critic.megatron.tensor_model_parallel_size=$train_tp \
+ critic.megatron.pipeline_model_parallel_size=$train_pp \
+ rllm.mask_truncated_samples=False \
+ trainer.critic_warmup=0 \
+ trainer.logger=['console','wandb'] \
+ trainer.project_name='rllm-agent' \
+ trainer.experiment_name='deepscaler-1.5b-megatron' \
+ trainer.val_before_train=True \
+ trainer.n_gpus_per_node=8 \
+ trainer.nnodes=1 \
+ trainer.save_freq=20 \
+ trainer.test_freq=20 \
+ trainer.default_hdfs_dir=null \
+ rllm.agent.max_steps=1 \
+ rllm.stepwise_advantage.enable=False \
+ trainer.total_epochs=100
\ No newline at end of file
diff --git a/vendor/rllm/examples/eval_protocol/README.md b/vendor/rllm/examples/eval_protocol/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..fe028f3ef600d4a4324990ba7b062aafcbca898d
--- /dev/null
+++ b/vendor/rllm/examples/eval_protocol/README.md
@@ -0,0 +1,78 @@
+# Eval Protocol FrozenLake Example
+
+This example shows how to use **Eval Protocol**'s FrozenLake environment from within **rLLM** using the generic `EvalProtocolWorkflow`.
+
+For a conceptual overview of how this integration works and how it generalizes to other benchmarks, see the core-concepts page on [Eval Protocol Integration](../../docs/core-concepts/eval-protocol.md).
+
+---
+
+## Quick Start
+
+### Prepare FrozenLake dataset
+
+From the project root:
+
+```bash
+cd examples/eval_protocol
+python prepare_frozen_lake_data.py
+```
+
+This script builds and registers the `frozen_lake_eval_protocol` train/test splits in the rLLM `DatasetRegistry`.
+
+### Run FrozenLake workflow (inference)
+
+Once your Fireworks API credentials are configured (`run_frozen_lake_flow.py` reads the key from the `FIREWORKS_API_KEY` environment variable), you can run a small batch of FrozenLake tasks through Eval Protocol and rLLM:
+
+```bash
+python run_frozen_lake_flow.py
+```
+
+This will:
+
+- Load the `frozen_lake_eval_protocol` test split.
+- Use `EvalProtocolWorkflow` (with `env_path="eval_protocol.benchmarks.test_frozen_lake"`) to run rollouts via Eval Protocol.
+- Print per-task rewards/accuracy and save results to `logs/frozen_lake_results.json`.
+
+### Train an RL agent
+
+To train an agent against the same Eval Protocol FrozenLake environment:
+
+```bash
+bash train_frozen_lake_flow.sh
+```
+
+This uses `EvalProtocolWorkflow` inside `AgentTrainer` (via Hydra configs) to:
+
+- Generate rollouts using Eval Protocol’s rollout processor and MCP server.
+- Compute rewards via the Eval Protocol evaluation function.
+- Optimize the underlying model with PPO/GRPO.
+
+You can edit `train_frozen_lake_flow.sh` to customize model path, Fireworks deployment, and training hyperparameters.
+
+---
+
+## Code Reference
+
+### Data preparation
+
+Script that builds and registers the FrozenLake Eval Protocol dataset:
+
+```python title="examples/eval_protocol/prepare_frozen_lake_data.py"
+--8<-- "examples/eval_protocol/prepare_frozen_lake_data.py"
+```
+
+### Workflow runner
+
+Main script for running the FrozenLake Eval Protocol workflow through rLLM:
+
+```python title="examples/eval_protocol/run_frozen_lake_flow.py"
+--8<-- "examples/eval_protocol/run_frozen_lake_flow.py"
+```
+
+### Training script
+
+Agent training implementation using `EvalProtocolWorkflow` and `AgentTrainer`:
+
+```python title="examples/eval_protocol/train_frozen_lake_flow.py"
+--8<-- "examples/eval_protocol/train_frozen_lake_flow.py"
+```
diff --git a/vendor/rllm/examples/eval_protocol/frozen_lake_flow.py b/vendor/rllm/examples/eval_protocol/frozen_lake_flow.py
new file mode 100644
index 0000000000000000000000000000000000000000..37640b3ac43356361cc868ac2b712d8657e8c9f5
--- /dev/null
+++ b/vendor/rllm/examples/eval_protocol/frozen_lake_flow.py
@@ -0,0 +1,227 @@
+"""
+This workflow bridges eval-protocol's MCPGymRolloutProcessor with rllm-fw's Workflow pattern
+for the FrozenLake environment.
+"""
+
+import asyncio
+from pathlib import Path
+
+import eval_protocol
+from eval_protocol.benchmarks.test_frozen_lake import test_frozen_lake_evaluation
+from eval_protocol.models import EvaluationRow, InputMetadata, Message
+from eval_protocol.pytest.default_mcp_gym_rollout_processor import (
+ MCPGymRolloutProcessor,
+)
+from eval_protocol.pytest.types import RolloutProcessorConfig
+
+from rllm.agents.agent import Episode, Step, Trajectory
+from rllm.engine.rollout.openai_engine import OpenAIEngine
+from rllm.workflows.workflow import Workflow
+
+
+class FrozenLakeWorkflow(Workflow):
+    """
+    Workflow that executes frozen lake tasks using MCPGymRolloutProcessor.
+
+    One MCPGymRolloutProcessor is shared by all instances of this class. The first
+    ``run`` call to claim the class-level flag passes ``start_server=True`` to the
+    processor; every other call passes ``start_server=False``. ``cleanup`` clears
+    the flag again so that a later ``run`` requests a new server.
+
+    Task format expected:
+    {
+        "id": "frozen_lake_task_0",
+        "system_prompt": "...",
+        "environment_context": {...},
+        "user_prompt_template": "{observation}"
+    }
+    """
+
+    # Class variables (shared across all workflow instances)
+    _shared_server_started = False
+    _server_lock = asyncio.Lock()
+    _shared_rollout_processor = MCPGymRolloutProcessor()
+
+    def __init__(self, rollout_engine: OpenAIEngine, lite_llm_prefix: str = "fireworks_ai/", max_steps: int = 30, temperature: float = 1.0, max_tokens: int = 4096, **kwargs):
+        """
+        Args:
+            rollout_engine: Engine whose ``model`` attribute names the model to query.
+            lite_llm_prefix: Prefix prepended to the model name in the completion params.
+            max_steps: Value passed as ``steps`` in the rollout processor config.
+            temperature: Sampling temperature placed in the completion params.
+            max_tokens: Token limit placed in the completion params.
+            **kwargs: Forwarded to ``Workflow.__init__``.
+        """
+        super().__init__(rollout_engine, **kwargs)
+
+        self._rollout_processor_server_started = False
+        self._rollout_processor_semaphore = asyncio.Semaphore(1)
+        self._lite_llm_prefix = lite_llm_prefix
+        self._temperature = temperature
+        self._max_tokens = max_tokens
+        self._max_steps = max_steps
+
+        # The FrozenLake MCP server script is located inside the installed eval_protocol package.
+        eval_protocol_path = Path(eval_protocol.__file__).parent
+        self._server_script_path = eval_protocol_path / "mcp_servers" / "frozen_lake" / "server.py"
+
+        # Use shared rollout processor across all instances
+        self.rollout_processor = FrozenLakeWorkflow._shared_rollout_processor
+
+    def _build_rollout_processor_config(self):
+        """Build the RolloutProcessorConfig for the next rollout from this instance's settings."""
+        model = self._lite_llm_prefix + self.rollout_engine.model
+        print("model in frozen_lake_flow", model)
+        return RolloutProcessorConfig(
+            completion_params={
+                "model": model,
+                "temperature": self._temperature,
+                "max_tokens": self._max_tokens,
+            },
+            mcp_config_path="",
+            server_script_path=str(self._server_script_path),
+            steps=self._max_steps,
+            semaphore=self._rollout_processor_semaphore,
+            # True only for the run() call that claimed server startup.
+            kwargs={"start_server": self._rollout_processor_server_started},
+        )
+
+    async def run(self, task: dict, uid: str, **kwargs) -> Episode:
+        """
+        Execute the frozen lake workflow.
+
+        Args:
+            task: Dict containing frozen lake task data
+            uid: Unique identifier for this episode
+            **kwargs: Additional arguments
+
+        Returns:
+            Episode with trajectory and computed rewards. If anything raises during the
+            rollout or evaluation, a failed Episode (no trajectories, ``is_correct=False``,
+            error text in ``metrics["error"]``) is returned instead.
+        """
+        # cleanup() drops this instance's reference; pick the shared processor up again
+        # so the workflow stays usable afterwards.
+        # NOTE(review): assumes the processor can start a new server after cleanup(); confirm.
+        if self.rollout_processor is None:
+            self.rollout_processor = FrozenLakeWorkflow._shared_rollout_processor
+
+        # Decide whether this call is the one that asks the processor to start the server.
+        start_server = False
+        if not FrozenLakeWorkflow._shared_server_started:
+            # Only acquire the lock while the flag is still unset.
+            async with FrozenLakeWorkflow._server_lock:
+                # Check again inside the lock (another workflow might have claimed startup).
+                if not FrozenLakeWorkflow._shared_server_started:
+                    FrozenLakeWorkflow._shared_server_started = True
+                    start_server = True
+        self._rollout_processor_server_started = start_server
+
+        self.reset(task=task, uid=uid)
+
+        try:
+            eval_row = self._task_to_evaluation_row(task)
+
+            tasks = self.rollout_processor([eval_row], self._build_rollout_processor_config())
+
+            if not tasks:
+                raise ValueError("MCPGymRolloutProcessor returned no tasks")
+
+            result_row: EvaluationRow = await tasks[0]
+
+            episode = await self._evaluate_and_create_episode(result_row, task, uid)
+
+            return episode
+
+        except Exception as e:
+            # Gracefully handle failures - return a failed Episode instead of crashing
+            print(f"⚠️ Task {uid} failed: {e}")
+
+            failed_episode = Episode(
+                id=uid,
+                task=task,
+                is_correct=False,
+                trajectories=[],
+                metrics={"frozen_lake_reward": 0.0, "error": str(e)},
+            )
+            return failed_episode
+
+    def _task_to_evaluation_row(self, task: dict) -> EvaluationRow:
+        """Convert rllm task dict to eval protocol EvaluationRow.
+
+        Reads ``system_prompt``, ``id``, ``environment_context`` and
+        ``user_prompt_template`` from ``task``; a missing key raises KeyError.
+        """
+        return EvaluationRow(
+            messages=[Message(role="system", content=task["system_prompt"])],
+            input_metadata=InputMetadata(
+                row_id=task["id"],
+                dataset_info={
+                    "environment_context": task["environment_context"],
+                    "user_prompt_template": task["user_prompt_template"],
+                },
+            ),
+        )
+
+    async def _evaluate_and_create_episode(
+        self,
+        row: EvaluationRow,
+        task: dict,
+        uid: str,
+    ) -> Episode:
+        """
+        Evaluate the rollout and convert to rllm Episode.
+
+        Args:
+            row: Rollout result whose messages become the trajectory steps.
+            task: Original task dict, attached to the trajectory and episode.
+            uid: Episode identifier.
+
+        Returns:
+            Episode holding one trajectory; the evaluation score is stored on the
+            trajectory and on its last step.
+
+        Raises:
+            ValueError: If the evaluation function leaves ``evaluation_result`` unset.
+        """
+        # Call the evaluation function
+        evaluated_row: EvaluationRow = await test_frozen_lake_evaluation(row)
+
+        # Extract reward and metrics from evaluation_result
+        if evaluated_row.evaluation_result is None:
+            raise ValueError("Evaluation function did not return a result")
+
+        reward = evaluated_row.evaluation_result.score
+        reward_info = evaluated_row.evaluation_result.metrics or {}
+
+        def msg_to_dict(msg: Message) -> dict:
+            """Convert eval_protocol Message to chat completion dict."""
+            d = {"role": msg.role, "content": msg.content}
+            if msg.tool_calls:
+                d["tool_calls"] = [
+                    {
+                        "id": tc.id,
+                        "type": tc.type,
+                        "function": {
+                            "name": tc.function.name,
+                            "arguments": tc.function.arguments,
+                        },
+                    }
+                    for tc in msg.tool_calls
+                ]
+            if msg.tool_call_id:
+                d["tool_call_id"] = msg.tool_call_id
+            if msg.name:
+                d["name"] = msg.name
+            return d
+
+        trajectory = Trajectory()
+        all_messages = []
+
+        for msg in row.messages:
+            msg_dict = msg_to_dict(msg)
+            all_messages.append(msg_dict)
+
+            # Create Step with only observation and chat_completions for user or tool message
+            if msg.role in ["user", "tool"]:
+                # Each step gets its own snapshot of the conversation so far.
+                new_step = Step(observation=str(msg.content or ""), chat_completions=all_messages.copy())
+                trajectory.steps.append(new_step)
+
+            # Create new Step with action/response for assistant message
+            elif msg.role == "assistant":
+                # Extract action: tool calls if present, otherwise message content
+                action_data = msg_dict.get("tool_calls") if msg.tool_calls else str(msg.content or "")
+
+                new_step = Step(
+                    model_response=str(msg.content) if msg.content else "",
+                    action=action_data,
+                    chat_completions=all_messages.copy(),
+                )
+                trajectory.steps.append(new_step)
+
+        # Assign final reward to the last step (sparse reward)
+        if trajectory.steps:
+            trajectory.steps[-1].reward = reward
+            trajectory.steps[-1].info = reward_info
+
+        trajectory.reward = reward
+        trajectory.task = task
+
+        # Create episode
+        episode = Episode(
+            id=uid,
+            task=task,
+            is_correct=(reward == 1.0),
+            trajectories=[trajectory],
+            metrics={"frozen_lake_reward": reward, **reward_info},
+        )
+
+        return episode
+
+    def cleanup(self):
+        """Cleanup MCP server resources.
+
+        The processor is shared by every instance, so cleaning it up affects all
+        workflows. The class-level "server started" flag is cleared as well; otherwise
+        later ``run`` calls would keep passing ``start_server=False`` to a processor
+        that has already been cleaned up.
+        """
+        if self.rollout_processor:
+            self.rollout_processor.cleanup()
+            self.rollout_processor = None
+            FrozenLakeWorkflow._shared_server_started = False
diff --git a/vendor/rllm/examples/eval_protocol/prepare_frozen_lake_data.py b/vendor/rllm/examples/eval_protocol/prepare_frozen_lake_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..52bc46af04ef4180078fa0322584a0e288d39954
--- /dev/null
+++ b/vendor/rllm/examples/eval_protocol/prepare_frozen_lake_data.py
@@ -0,0 +1,33 @@
+import random
+
+from datasets import Dataset
+
+from rllm.data.dataset import DatasetRegistry
+
+
+def prepare_frozen_lake_data(train_size: int, test_size: int, seed=None):
+    """Build and register the ``frozen_lake_eval_protocol`` train/test datasets.
+
+    Every row carries the same system prompt and user prompt template plus an
+    ``environment_context`` with a distinct environment seed drawn without
+    replacement from 1..1_000_000.
+
+    Args:
+        train_size: Number of rows in the train split.
+        test_size: Number of rows in the test split.
+        seed: Optional seed for the random generator that picks the environment
+            seeds. Pass a fixed value to get the same dataset on every run; the
+            default (None) keeps the previous behaviour of a fresh random draw.
+
+    Raises:
+        ValueError: Propagated from ``sample`` when ``train_size + test_size`` is
+            negative or larger than the 1_000_000 available seeds.
+    """
+    system_prompt = "You are playing FrozenLake, a grid-based navigation game displayed as a 4x4 text grid. The grid contains: S (Start), F (Frozen safe), H (Hole - deadly), G (Goal). You start at position S and must reach G while avoiding H tiles. In this version, the surface is not slippery so your moves are deterministic. IMPORTANT: When you are at the starting position, you appear as 'S'. When you move to other positions, the hightlighted position will change on the grid. If you step on H, the episode ends with failure. Use the lake_move tool with actions LEFT, DOWN, RIGHT, UP to navigate the grid."
+    user_prompt_template = "Current game state grid:\n{observation}\n\nYou are navigating the 4x4 grid above. Navigate safely to reach the goal 'G' while avoiding holes 'H'. Choose your next move from: LEFT, DOWN, RIGHT, or UP."
+
+    def create_row(idx, env_seed):
+        return {"id": f"run_{idx}", "system_prompt": system_prompt, "user_prompt_template": user_prompt_template, "environment_context": {"game": "FrozenLake", "map_name": "4x4", "seed": env_seed}}
+
+    # A dedicated generator makes the draw reproducible when a seed is supplied.
+    rng = random.Random(seed)
+    env_seeds = rng.sample(range(1, 1_000_001), train_size + test_size)
+    all_rows = [create_row(idx, env_seed) for idx, env_seed in enumerate(env_seeds)]
+
+    # The first train_size rows form the train split; the remainder is the test split.
+    train_rows = all_rows[:train_size]
+    test_rows = all_rows[train_size:]
+
+    train_dataset = Dataset.from_list(train_rows)
+    test_dataset = Dataset.from_list(test_rows)
+
+    DatasetRegistry.register_dataset("frozen_lake_eval_protocol", train_dataset, "train")
+    DatasetRegistry.register_dataset("frozen_lake_eval_protocol", test_dataset, "test")
+
+    print(f"Train dataset size: {len(train_dataset)}")
+    print(f"Test dataset size: {len(test_dataset)}")
+
+
+if __name__ == "__main__":
+    prepare_frozen_lake_data(train_size=100, test_size=100)
diff --git a/vendor/rllm/examples/eval_protocol/run_frozen_lake_flow.py b/vendor/rllm/examples/eval_protocol/run_frozen_lake_flow.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4f3f45560a3c2a975e274c9c474f78ff68e7899
--- /dev/null
+++ b/vendor/rllm/examples/eval_protocol/run_frozen_lake_flow.py
@@ -0,0 +1,118 @@
+"""
+Run Frozen Lake Workflow with rllm-fw using EvalProtocolWorkflow
+
+This script demonstrates how to execute frozen lake tasks using rllm-fw's
+AgentWorkflowEngine with the generic EvalProtocolWorkflow.
+"""
+
+import asyncio
+import json
+import os
+from pathlib import Path
+
+from rllm.data.dataset import DatasetRegistry
+from rllm.engine.agent_workflow_engine import AgentWorkflowEngine
+from rllm.engine.rollout.openai_engine import OpenAIEngine
+from rllm.workflows.eval_protocol_workflow import EvalProtocolWorkflow
+
+
+def evaluate_results(episodes):
+    """
+    Print a summary of the episodes and return the overall accuracy.
+
+    Args:
+        episodes: List of Episode objects; each is read for ``is_correct``,
+            ``id`` and ``metrics["evaluation_reward"]`` (0.0 when absent).
+
+    Returns:
+        Fraction of episodes with ``is_correct`` set, or 0.0 for an empty list.
+    """
+    num_tasks = len(episodes)
+    num_correct = len([item for item in episodes if item.is_correct])
+    if num_tasks > 0:
+        accuracy = num_correct / num_tasks
+    else:
+        accuracy = 0.0
+
+    banner = "=" * 60
+
+    # Header with the aggregate numbers.
+    print("\n" + banner)
+    print("EVALUATION RESULTS")
+    print(banner)
+    print(f"Total tasks: {num_tasks}")
+    print(f"Correct: {num_correct}")
+    print(f"Accuracy: {accuracy:.2%}")
+    print()
+
+    # One line per episode.
+    for item in episodes:
+        if item.is_correct:
+            status = "✅"
+        else:
+            status = "❌"
+        reward = item.metrics.get("evaluation_reward", 0.0)
+        print(f"{status} Task {item.id}: reward={reward:.3f}")
+
+    print(banner)
+
+    return accuracy
+
+
+async def main():
+    """Run a few FrozenLake test tasks through EvalProtocolWorkflow and save the episodes.
+
+    Returns:
+        The accuracy computed by ``evaluate_results``.
+    """
+
+    n_parallel_tasks = 4
+    max_tasks = 4
+    model_id = "accounts/fireworks/models/kimi-k2-instruct"
+
+    # Create dummy rollout_engine (required by Workflow base class but not used)
+    rollout_engine = OpenAIEngine(
+        model=model_id,
+        base_url="https://api.fireworks.ai/inference/v1",
+        api_key=os.getenv("FIREWORKS_API_KEY"),
+    )
+
+    workflow_args = {
+        "env_path": "eval_protocol.benchmarks.test_frozen_lake",
+        "lite_llm_prefix": "fireworks_ai/",
+        "steps": 30,
+        "temperature": 1.0,
+        "max_tokens": 16384,
+    }
+    workflow_engine = AgentWorkflowEngine(
+        workflow_cls=EvalProtocolWorkflow,
+        workflow_args=workflow_args,
+        rollout_engine=rollout_engine,
+        n_parallel_tasks=n_parallel_tasks,
+        retry_limit=1,
+    )
+
+    # Take the first max_tasks rows of the registered test split.
+    test_dataset = DatasetRegistry.load_dataset("frozen_lake_eval_protocol", "test")
+    tasks = [test_dataset[i] for i in range(max_tasks)]
+
+    print("Starting frozen lake workflow execution...")
+    print(f"Model: {model_id}")
+    print(f"Parallel tasks: {n_parallel_tasks}")
+    print()
+
+    try:
+        episodes = await workflow_engine.execute_tasks(tasks)
+        for episode in episodes:
+            print(episode.trajectories)
+        accuracy = evaluate_results(episodes)
+
+        # Persist all episodes as JSON under ./logs.
+        log_dir = Path("logs")
+        log_dir.mkdir(exist_ok=True)
+        results_path = log_dir / "frozen_lake_results.json"
+
+        serialized = [episode.to_dict() for episode in episodes]
+        with open(results_path, "w") as f:
+            json.dump(serialized, f, indent=2)
+
+        print(f"\n✅ Results saved to {results_path}")
+
+        return accuracy
+
+    except Exception as e:
+        # Report the failure with a traceback, then let it propagate.
+        print(f"❌ Error during execution: {e}")
+        import traceback
+
+        traceback.print_exc()
+        raise
+    finally:
+        # The engine is shut down on both success and failure.
+        workflow_engine.shutdown()
+
+
+if __name__ == "__main__":
+    os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+    accuracy = asyncio.run(main())
+
+    print(f"\n🎯 Final Accuracy: {accuracy:.2%}")
diff --git a/vendor/rllm/examples/eval_protocol/train_frozen_lake_flow.py b/vendor/rllm/examples/eval_protocol/train_frozen_lake_flow.py
new file mode 100644
index 0000000000000000000000000000000000000000..18b3079b466bbd8f157c80338121e8ad1f59f54c
--- /dev/null
+++ b/vendor/rllm/examples/eval_protocol/train_frozen_lake_flow.py
@@ -0,0 +1,31 @@
+import hydra
+
+from rllm.data.dataset import DatasetRegistry
+from rllm.trainer.agent_trainer import AgentTrainer
+from rllm.workflows.eval_protocol_workflow import EvalProtocolWorkflow
+
+
+@hydra.main(config_path="pkg://rllm.trainer.config", config_name="agent_ppo_trainer", version_base=None)
+def main(config):
+    """Train on the FrozenLake Eval Protocol datasets using EvalProtocolWorkflow.
+
+    Args:
+        config: Hydra-composed configuration, forwarded unchanged to AgentTrainer.
+    """
+    train_split = DatasetRegistry.load_dataset("frozen_lake_eval_protocol", "train")
+    val_split = DatasetRegistry.load_dataset("frozen_lake_eval_protocol", "test")
+
+    # Arguments handed to every EvalProtocolWorkflow instance.
+    eval_protocol_args = {
+        "env_path": "eval_protocol.benchmarks.test_frozen_lake",
+        "lite_llm_prefix": "fireworks_ai/",
+        "steps": 30,
+        "temperature": 1.0,
+        "max_tokens": 32768,
+    }
+
+    AgentTrainer(
+        workflow_class=EvalProtocolWorkflow,
+        workflow_args=eval_protocol_args,
+        config=config,
+        train_dataset=train_split,
+        val_dataset=val_split,
+        backend="fireworks",
+    ).train()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/vendor/rllm/examples/eval_protocol/train_frozen_lake_flow.sh b/vendor/rllm/examples/eval_protocol/train_frozen_lake_flow.sh
new file mode 100644
index 0000000000000000000000000000000000000000..972d8d04ddd7f54101c0bf09735a8104b289852d
--- /dev/null
+++ b/vendor/rllm/examples/eval_protocol/train_frozen_lake_flow.sh
@@ -0,0 +1,81 @@
+# Launches examples.eval_protocol.train_frozen_lake_flow, passing every training
+# setting as a Hydra command-line override (GRPO advantage estimator, LoRA rank 32,
+# FSDP2 actor, vLLM rollout settings, Fireworks deployment identifiers).
+# Print each command before it is executed.
+set -x
+
+# vLLM / PyTorch allocator settings exported to the training process.
+export VLLM_ATTENTION_BACKEND=FLASH_ATTN
+export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:False"
+export VLLM_USE_V1=1
+export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
+export VLLM_ENGINE_ITERATION_TIMEOUT_S=100000000000
+
+# Find the directory where rllm package is located
+# NOTE(review): RLLM_DIR is not referenced anywhere else in this script; confirm whether it is still needed.
+RLLM_DIR=$(python3 -c "import rllm; import os; print(os.path.dirname(os.path.dirname(rllm.__file__)))")
+
+# Model passed to actor_rollout_ref.model.path below.
+MODEL_PATH=Qwen/Qwen3-8B
+
+# NOTE(review): `python3 -m examples...` resolves the module from the current directory,
+# so this presumably has to be launched from the rllm repository root; verify.
+# NOTE(review): the extra leading '+' on trainer.n_training_gpus_per_node is presumably
+# Hydra's syntax for adding a key that the base config does not define; verify.
+# fireworks.deployment_id and fireworks.model_id_prefix at the end are specific to one
+# deployment and need to be replaced with your own values.
+python3 -m examples.eval_protocol.train_frozen_lake_flow \
+ algorithm.adv_estimator=grpo \
+ data.train_batch_size=16 \
+ data.val_batch_size=48 \
+ data.max_prompt_length=16384 \
+ data.max_response_length=4096 \
+ actor_rollout_ref.model.lora_rank=32 \
+ actor_rollout_ref.model.lora_alpha=32 \
+ actor_rollout_ref.rollout.load_format=safetensors \
+ actor_rollout_ref.model.target_modules=all-linear \
+ actor_rollout_ref.model.path=$MODEL_PATH \
+ actor_rollout_ref.hybrid_engine=True \
+ actor_rollout_ref.actor.optim.lr=1e-6 \
+ actor_rollout_ref.actor.strategy=fsdp2 \
+ actor_rollout_ref.actor.loss_agg_mode=seq-mean-token-sum-norm \
+ actor_rollout_ref.model.use_remove_padding=True \
+ actor_rollout_ref.actor.ppo_mini_batch_size=8 \
+ actor_rollout_ref.actor.use_dynamic_bsz=True \
+ actor_rollout_ref.actor.ppo_max_token_len_per_gpu=30000 \
+ actor_rollout_ref.actor.use_kl_loss=False \
+ actor_rollout_ref.actor.clip_ratio_high=0.28 \
+ actor_rollout_ref.actor.kl_loss_coef=0.001 \
+ actor_rollout_ref.actor.kl_loss_type=low_var_kl \
+ actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \
+ actor_rollout_ref.model.enable_gradient_checkpointing=True \
+ actor_rollout_ref.actor.fsdp_config.param_offload=True \
+ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
+ actor_rollout_ref.rollout.calculate_log_probs=True \
+ actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
+ actor_rollout_ref.rollout.name=vllm \
+ actor_rollout_ref.rollout.mode="async" \
+ actor_rollout_ref.rollout.enforce_eager=False \
+ actor_rollout_ref.rollout.temperature=0.6 \
+ actor_rollout_ref.rollout.gpu_memory_utilization=0.9 \
+ actor_rollout_ref.rollout.n=8 \
+ actor_rollout_ref.rollout.val_kwargs.n=1 \
+ actor_rollout_ref.rollout.val_kwargs.temperature=0.6 \
+ actor_rollout_ref.rollout.val_kwargs.top_p=0.9 \
+ actor_rollout_ref.ref.fsdp_config.param_offload=True \
+ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \
+ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \
+ actor_rollout_ref.actor.entropy_coeff=0 \
+ algorithm.kl_ctrl.kl_coef=0.001 \
+ rllm.compact_filtering.enable=True \
+ rllm.compact_filtering.mask_max_prompt_length_exceeded=True \
+ rllm.compact_filtering.mask_max_response_length_exceeded=True \
+ rllm.compact_filtering.mask_max_turns_exceeded=False \
+ rllm.compact_filtering.mask_timeout=True \
+ rllm.rejection_sample.enable=False \
+ rllm.rejection_sample.multiplier=1.0 \
+ rllm.stepwise_advantage.enable=False \
+ rllm.stepwise_advantage.mode=per_step \
+ trainer.critic_warmup=0 \
+ trainer.logger=['console','wandb'] \
+ trainer.project_name='rllm-fireworks-workflow' \
+ trainer.experiment_name='fireworks-frozen-lake-8b' \
+ trainer.max_actor_ckpt_to_keep=2 \
+ trainer.val_before_train=False \
+ trainer.n_gpus_per_node=8 \
+ +trainer.n_training_gpus_per_node=8 \
+ trainer.nnodes=1 \
+ trainer.save_freq=1 \
+ trainer.test_freq=10 \
+ trainer.default_hdfs_dir=null \
+ trainer.total_epochs=100 \
+ rllm.workflow.use_workflow=True \
+ fireworks.deployment_id=rllm-qwen3-8b-1 \
+ fireworks.model_id_prefix=test-frozen-lake-qwen3-8b-1
\ No newline at end of file
diff --git a/vendor/rllm/examples/fireworks_math/README.md b/vendor/rllm/examples/fireworks_math/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..ecaf3e8031fd7b4ce267e15fa52e8830d403a7fb
--- /dev/null
+++ b/vendor/rllm/examples/fireworks_math/README.md
@@ -0,0 +1,75 @@
+## Before Running Your Training Job
+
+First, install the Fireworks SDK and export your `FIREWORKS_API_KEY`:
+
+```bash
+pip install fireworks-ai
+```
+
+```bash
+export FIREWORKS_API_KEY="<your-api-key>"
+```
+
+Before starting your training, create a **Fireworks deployment**.
+
+We recommend installing **firectl** by following the guide here:
+[firectl Documentation](https://docs.fireworks.ai/tools-sdks/firectl/firectl)
+
+Then, create your deployment:
+
+```bash
+firectl create deployment accounts/fireworks/models/qwen3-4b --enable-hot-reload-latest-addon --deployment-id <your-deployment-id> --accelerator-type NVIDIA_H100_80GB
+```
+
+---
+
+## How Fireworks Loads LoRA Adapters
+
+### Inference Against `addon1`
+
+```bash
+firectl load-lora addon1 --replace-merged-addon
+```
+
+---
+
+### Swap to a Second LoRA Adapter (`addon2`)
+
+```bash
+firectl load-lora addon2 --replace-merged-addon
+```
+
+---
+
+### Unload LoRA Adapter (Return to Base Model)
+
+```bash
+firectl unload-lora addon2 --deployment <your-deployment-id>
+```
+
+---
+
+
+## 🚀 After Deployment Is Ready
+
+Once your deployment state becomes **`READY`**, append the following arguments to your **training command**:
+
+```bash
+fireworks.deployment_id=<your-deployment-id> \
+fireworks.model_id_prefix=<your-model-id-prefix>
+```
+
+Also make sure you set
+```bash
+trainer.save_freq=1 \ # So that Fireworks stores every intermediate checkpoint
+trainer.max_actor_ckpt_to_keep=2 \ # To prevent unnecessary storage usage
++trainer.n_training_gpus_per_node=<number-of-local-gpus> \ # So that all your local GPUs are only used for training
+```
+
+Currently **firectl** only supports LoRA reload, so make sure you also set:
+```bash
+actor_rollout_ref.model.lora_rank=32 \
+actor_rollout_ref.model.lora_alpha=32 \
+actor_rollout_ref.rollout.load_format=safetensors \
+actor_rollout_ref.model.target_modules=all-linear \
+```
diff --git a/vendor/rllm/examples/fireworks_math/prepare_hendrycks_math_dataset.py b/vendor/rllm/examples/fireworks_math/prepare_hendrycks_math_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..719f782c9c01fd1a620f31ca0e23dd9f24f3ef80
--- /dev/null
+++ b/vendor/rllm/examples/fireworks_math/prepare_hendrycks_math_dataset.py
@@ -0,0 +1,42 @@
+from datasets import concatenate_datasets, load_dataset
+
+from rllm.data.dataset import DatasetRegistry
+
+
+def prepare_math_data_hendrycks():
+    """Register the Hendrycks MATH train data and the MATH-500 test data.
+
+    The "train" split of every listed EleutherAI/hendrycks_math config is loaded and
+    concatenated, then registered as "hendrycks_math"/"train". The "test" split of
+    HuggingFaceH4/MATH-500 is registered as "math500"/"test". Both go through the
+    same row mapping first.
+    """
+    subjects = ("algebra", "counting_and_probability", "geometry", "intermediate_algebra", "number_theory", "prealgebra", "precalculus")
+
+    def to_record(example):
+        # The problem text becomes a single user message and the solution text the
+        # ground truth; either falls back to "" when the field is missing.
+        user_turn = {
+            "role": "user",
+            "content": example.get("problem", ""),
+        }
+        return {
+            "messages": [user_turn],
+            "ground_truth": example.get("solution", ""),
+            "data_source": "hendrycks_math",
+        }
+
+    # One train split per subject, merged into a single dataset.
+    per_subject = [load_dataset("EleutherAI/hendrycks_math", subject, split="train") for subject in subjects]
+    train_split = concatenate_datasets(per_subject).map(to_record)
+
+    test_split = load_dataset("HuggingFaceH4/MATH-500", split="test").map(to_record)
+
+    DatasetRegistry.register_dataset("hendrycks_math", train_split, "train")
+    DatasetRegistry.register_dataset("math500", test_split, "test")
+
+
+if __name__ == "__main__":
+    prepare_math_data_hendrycks()
diff --git a/vendor/rllm/examples/fireworks_math/train_fireworks_math.py b/vendor/rllm/examples/fireworks_math/train_fireworks_math.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6c539746abc9d9824ab22f3b7b8ef726d335e61
--- /dev/null
+++ b/vendor/rllm/examples/fireworks_math/train_fireworks_math.py
@@ -0,0 +1,44 @@
+import hydra
+
+from rllm.agents.agent import Action
+from rllm.data.dataset import DatasetRegistry
+from rllm.engine.rollout.rollout_engine import ModelOutput
+from rllm.rewards.reward_fn import math_reward_fn
+from rllm.rewards.reward_types import RewardOutput
+from rllm.trainer.agent_trainer import AgentTrainer
+from rllm.workflows.simple_workflow import SimpleWorkflow
+
+# from rllm.agents.math_agent import MathAgent
+# from rllm.environments.base.single_turn_env import SingleTurnEnvironment
+
+
+def math_workflow_reward_fn(task_info: dict, action: str) -> RewardOutput:
+    """Unwrap ``action`` and score it with ``math_reward_fn``.
+
+    An ``Action`` is replaced by its ``.action`` attribute; a ``ModelOutput``
+    (passed directly or found inside the ``Action``) is replaced by its ``.text``.
+    Anything else is passed through unchanged.
+    """
+    payload = action.action if isinstance(action, Action) else action
+    if isinstance(payload, ModelOutput):
+        payload = payload.text
+    return math_reward_fn(task_info, payload)
+
+
+@hydra.main(config_path="pkg://rllm.trainer.config", config_name="agent_ppo_trainer", version_base=None)
+def main(config):
+    """Train on hendrycks_math with SimpleWorkflow, validating on math500.
+
+    Args:
+        config: Hydra-composed configuration. ``data.max_prompt_length`` and
+            ``data.max_response_length`` are also passed on to the workflow.
+    """
+    train_split = DatasetRegistry.load_dataset("hendrycks_math", "train")
+    val_split = DatasetRegistry.load_dataset("math500", "test")
+
+    # Arguments handed to every SimpleWorkflow instance.
+    simple_workflow_args = {
+        "reward_function": math_workflow_reward_fn,
+        "max_prompt_length": config.data.max_prompt_length,
+        "max_response_length": config.data.max_response_length,
+    }
+
+    AgentTrainer(
+        workflow_class=SimpleWorkflow,
+        workflow_args=simple_workflow_args,
+        config=config,
+        train_dataset=train_split,
+        val_dataset=val_split,
+        backend="fireworks",
+    ).train()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/vendor/rllm/examples/fireworks_math/train_fireworks_math.sh b/vendor/rllm/examples/fireworks_math/train_fireworks_math.sh
new file mode 100644
index 0000000000000000000000000000000000000000..adcedb2b01689b7cbfb142b9a42d09a06bfc7060
--- /dev/null
+++ b/vendor/rllm/examples/fireworks_math/train_fireworks_math.sh
@@ -0,0 +1,81 @@
+# Echo each command before it is executed.
+set -x
+
+# Runtime settings for vLLM and the PyTorch CUDA allocator; exported so the
+# python3 process started below inherits them.
+export VLLM_ATTENTION_BACKEND=FLASH_ATTN
+export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:False"
+export VLLM_USE_V1=1
+export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
+export VLLM_ENGINE_ITERATION_TIMEOUT_S=100000000000
+
+# Find the directory where rllm package is located
+# NOTE(review): RLLM_DIR is never referenced later in this script; confirm whether it is still needed.
+RLLM_DIR=$(python3 -c "import rllm; import os; print(os.path.dirname(os.path.dirname(rllm.__file__)))")
+
+# Model identifier substituted into actor_rollout_ref.model.path below.
+MODEL_PATH=Qwen/Qwen3-4B
+
+# Start training through the examples.fireworks_math.train_fireworks_math module.
+# Every argument is a Hydra override; in Hydra's syntax a leading "+" on a key
+# (here +trainer.n_training_gpus_per_node) appends a key the base config does not define.
+# NOTE(review): the module path presumably requires running from the directory that contains examples/ - confirm.
+# NOTE(review): fireworks.deployment_id looks like an account-specific value; confirm before reuse.
+python3 -m examples.fireworks_math.train_fireworks_math \
+ algorithm.adv_estimator=grpo \
+ data.train_batch_size=8 \
+ data.val_batch_size=512 \
+ data.max_prompt_length=4096 \
+ data.max_response_length=2048 \
+ actor_rollout_ref.model.lora_rank=32 \
+ actor_rollout_ref.model.lora_alpha=32 \
+ actor_rollout_ref.rollout.load_format=safetensors \
+ actor_rollout_ref.model.target_modules=all-linear \
+ actor_rollout_ref.model.path=$MODEL_PATH \
+ actor_rollout_ref.hybrid_engine=True \
+ actor_rollout_ref.actor.optim.lr=1e-6 \
+ actor_rollout_ref.actor.strategy=fsdp2 \
+ actor_rollout_ref.actor.loss_agg_mode=seq-mean-token-sum-norm \
+ actor_rollout_ref.model.use_remove_padding=True \
+ actor_rollout_ref.actor.ppo_mini_batch_size=8 \
+ actor_rollout_ref.actor.use_dynamic_bsz=True \
+ actor_rollout_ref.actor.ppo_max_token_len_per_gpu=30000 \
+ actor_rollout_ref.actor.use_kl_loss=False \
+ actor_rollout_ref.actor.clip_ratio_high=0.28 \
+ actor_rollout_ref.actor.kl_loss_coef=0.001 \
+ actor_rollout_ref.actor.kl_loss_type=low_var_kl \
+ actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \
+ actor_rollout_ref.model.enable_gradient_checkpointing=True \
+ actor_rollout_ref.actor.fsdp_config.param_offload=True \
+ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
+ actor_rollout_ref.rollout.calculate_log_probs=True \
+ actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
+ actor_rollout_ref.rollout.name=vllm \
+ actor_rollout_ref.rollout.mode="async" \
+ actor_rollout_ref.rollout.enforce_eager=False \
+ actor_rollout_ref.rollout.temperature=0.6 \
+ actor_rollout_ref.rollout.gpu_memory_utilization=0.9 \
+ actor_rollout_ref.rollout.n=8 \
+ actor_rollout_ref.rollout.val_kwargs.n=1 \
+ actor_rollout_ref.rollout.val_kwargs.temperature=0.6 \
+ actor_rollout_ref.rollout.val_kwargs.top_p=0.9 \
+ actor_rollout_ref.ref.fsdp_config.param_offload=True \
+ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \
+ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \
+ actor_rollout_ref.actor.entropy_coeff=0 \
+ algorithm.kl_ctrl.kl_coef=0.001 \
+ rllm.compact_filtering.enable=False \
+ rllm.compact_filtering.mask_max_prompt_length_exceeded=True \
+ rllm.compact_filtering.mask_max_response_length_exceeded=True \
+ rllm.compact_filtering.mask_max_turns_exceeded=False \
+ rllm.compact_filtering.mask_timeout=True \
+ rllm.rejection_sample.enable=False \
+ rllm.rejection_sample.multiplier=1.0 \
+ rllm.stepwise_advantage.enable=False \
+ rllm.stepwise_advantage.mode=per_step \
+ trainer.critic_warmup=0 \
+ trainer.logger=['console','wandb'] \
+ trainer.project_name='rllm-fireworks-workflow' \
+ trainer.experiment_name='fireworks-hendrycks-math-4b' \
+ trainer.max_actor_ckpt_to_keep=2 \
+ trainer.val_before_train=False \
+ trainer.n_gpus_per_node=2 \
+ +trainer.n_training_gpus_per_node=2 \
+ trainer.nnodes=1 \
+ trainer.save_freq=1 \
+ trainer.test_freq=10 \
+ trainer.default_hdfs_dir=null \
+ trainer.total_epochs=100 \
+ rllm.workflow.use_workflow=True \
+ fireworks.deployment_id=wtk15cs9 \
+ fireworks.model_id_prefix=qwen3-4b-math
\ No newline at end of file
diff --git a/vendor/rllm/examples/frozenlake/README.md b/vendor/rllm/examples/frozenlake/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..6b7276dcb68852bec78af288c27fb6c9a1a95bb5
--- /dev/null
+++ b/vendor/rllm/examples/frozenlake/README.md
@@ -0,0 +1,97 @@
+# FrozenLake Agent Examples
+
+This directory contains examples for training and running FrozenLake RL agents using the rLLM framework. The FrozenLake agent learns to navigate a slippery grid world environment to reach a goal while avoiding holes.
+
+Our examples use the following:
+* Qwen3-4B as the base model for inference (the provided training scripts use Qwen3-0.6B)
+* Randomly generated FrozenLake environments with varying sizes and slip probabilities
+* GRPO for training
+
+## Environment Overview
+
+FrozenLake is a classic reinforcement learning environment where:
+- **Objective**: Navigate from start position to goal position
+- **Dynamics**: Depending on configuration, the surface may be slippery, causing actions to execute stochastically (e.g., intended moves may go sideways)
+- **Termination**: Episode ends when reaching goal (reward +1) or falling into hole (reward 0)
+- **Parameters**:
+ - `size`: Grid size (e.g., 4x4, 8x8)
+ - `p`: Probability that the agent performs the intended action (remainder is split among unintended directions)
+ - `seed`: Random seed for environment generation
+ - `is_slippery`: Boolean flag controlling whether movement is stochastic (True) or deterministic (False)
+
+## Model Hosting
+
+### Option 1: Using vLLM
+
+Start a vLLM server with OpenAI-compatible API:
+
+```bash
+python -m vllm.entrypoints.openai.api_server \
+ --model Qwen/Qwen3-4B \
+ --host 0.0.0.0 \
+ --port 30000 \
+ --dtype bfloat16
+```
+
+### Option 2: Using SGLang
+
+```bash
+python -m sglang_router.launch_server \
+ --model-path Qwen/Qwen3-4B \
+ --dp-size 1 \
+ --dtype bfloat16
+# increase dp_size to enable data-parallel processing on multiple GPUs
+```
+
+The server should be accessible at `http://localhost:30000/v1`, which is the `base_url` hard-coded in `run_frozenlake_agent.py`.
+
+## Dataset Preparation
+
+Prepare the FrozenLake datasets (randomly generated environments for training and testing):
+
+```bash
+cd examples/frozenlake
+python prepare_frozenlake_data.py
+```
+
+This will:
+- Generate 10,000 random FrozenLake environments for training
+- Generate 100 random FrozenLake environments for testing
+- Register both datasets with the RLLM DatasetRegistry
+- Each environment has a random size (2-9, since the upper bound of 10 is exclusive), a random intended-action probability `p` (0.6-0.85), and a random seed
+
+## Running Inference
+
+Once your model server is running and datasets are prepared, you can run inference:
+
+```bash
+cd examples/frozenlake
+python run_frozenlake_agent.py
+```
+
+### Configuration Options
+
+You can modify the inference script parameters:
+
+- `n_parallel_agents`: Number of parallel agents (default: 256)
+- `model_name`: Model to use (default: "Qwen/Qwen3-4B")
+- `base_url`: API server URL (default: "http://localhost:30000/v1")
+- `max_response_length`: Maximum response length (default: 16384)
+- `max_prompt_length`: Maximum prompt length (default: 4096)
+- `temperature`: Sampling temperature (default: 0.6)
+- `top_p`: Top-p sampling (default: 0.95)
+
+The script will:
+1. Load the FrozenLake test dataset (or generate it if it does not exist)
+2. Run parallel inference using the async agent execution engine
+3. Evaluate results and compute success rates
+
+## Training
+
+### Basic Training
+
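+To train a FrozenLake agent, run the following from the repository root (the script launches `python3 -m examples.frozenlake.train_frozenlake_agent`):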
+
+```bash
+bash examples/frozenlake/train_frozenlake_agent.sh
+```
diff --git a/vendor/rllm/examples/frozenlake/prepare_frozenlake_data.py b/vendor/rllm/examples/frozenlake/prepare_frozenlake_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2af779dd26047119f59c890cae4ee4c3f655b57
--- /dev/null
+++ b/vendor/rllm/examples/frozenlake/prepare_frozenlake_data.py
@@ -0,0 +1,48 @@
+import numpy as np
+
+from rllm.data.dataset import DatasetRegistry
+
+
+def prepare_frozenlake_data(train_size=10000, test_size=100):
+    """
+    Generate the FrozenLake train/test task lists and register them as the "frozenlake" dataset.
+
+    Every task is a dict with a ``seed``, a grid ``size`` drawn by
+    ``np.random.randint(2, 10)`` (upper bound exclusive), a probability ``p`` drawn
+    uniformly between 0.6 and 0.85, its position ``index`` within the split and a
+    ``uid`` of the form "<seed>_<size>_<p>".
+
+    Args:
+        train_size (int): Number of training tasks to generate
+        test_size (int): Number of test tasks to generate
+
+    Returns:
+        tuple: (train_dataset, test_dataset) as returned by DatasetRegistry.register_dataset
+    """
+    # Fixed global seed so repeated runs generate the same tasks.
+    np.random.seed(42)
+
+    # The draw order (seeds, sizes, then p; train before test each time) determines the
+    # generated values, so the three statements below must stay in this order.
+    split_counts = {"train": train_size, "test": test_size}
+    seeds = {split: np.random.randint(0, 100000, size=count) for split, count in split_counts.items()}
+    sizes = {split: np.random.randint(2, 10, size=count) for split, count in split_counts.items()}
+    probs = {split: np.random.uniform(0.6, 0.85, size=count) for split, count in split_counts.items()}
+
+    def build_tasks(split):
+        """Combine the per-split parameter arrays into a list of task dicts."""
+        return [
+            {"seed": seed, "size": size, "p": p, "index": idx, "uid": f"{seed}_{size}_{p}"}
+            for idx, (seed, size, p) in enumerate(zip(seeds[split], sizes[split], probs[split]))
+        ]
+
+    train_tasks = build_tasks("train")
+    test_tasks = build_tasks("test")
+
+    # Register train first, then test, under the shared "frozenlake" name.
+    registered_train = DatasetRegistry.register_dataset("frozenlake", train_tasks, "train")
+    registered_test = DatasetRegistry.register_dataset("frozenlake", test_tasks, "test")
+    return registered_train, registered_test
+
+
+if __name__ == "__main__":
+    # Generate with the default sizes and show a summary of what was registered.
+    train_ds, test_ds = prepare_frozenlake_data()
+    print(f"Train dataset: {len(train_ds.get_data())} examples")
+    print(f"Test dataset: {len(test_ds.get_data())} examples")
+    print("Sample train example:", train_ds.get_data()[0])
+    print("Sample test example:", test_ds.get_data()[0])
diff --git a/vendor/rllm/examples/frozenlake/run_frozenlake_agent.py b/vendor/rllm/examples/frozenlake/run_frozenlake_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f396eeb1f71119543389dc81a321f265620efa5
--- /dev/null
+++ b/vendor/rllm/examples/frozenlake/run_frozenlake_agent.py
@@ -0,0 +1,68 @@
+import asyncio
+
+from transformers import AutoTokenizer
+
+from rllm.agents.frozenlake_agent import FrozenLakeAgent
+from rllm.data.dataset import DatasetRegistry
+from rllm.engine.agent_execution_engine import AgentExecutionEngine
+from rllm.environments.frozenlake.frozenlake import FrozenLakeEnv
+from rllm.utils import compute_pass_at_k
+
+
+def load_frozenlake_data():
+    """Return the FrozenLake test tasks, generating both splits first when the test split is not registered."""
+    if not DatasetRegistry.dataset_exists("frozenlake", "test"):
+        print("FrozenLake datasets not found. Preparing datasets...")
+        # NOTE(review): this sibling-module import presumably only resolves when the
+        # script's own directory is on sys.path (e.g. run from examples/frozenlake) - confirm.
+        from prepare_frozenlake_data import prepare_frozenlake_data
+
+        _, generated_test = prepare_frozenlake_data()
+        return generated_test.get_data()
+
+    return DatasetRegistry.load_dataset("frozenlake", "test").get_data()
+
+
+if __name__ == "__main__":
+    import os
+
+    # Set before the tokenizer is loaded.
+    os.environ["TOKENIZERS_PARALLELISM"] = "true"
+
+    model_name = "Qwen/Qwen3-4B"
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+    # The agent and the environment each get their own step budget (10 and 8).
+    engine = AgentExecutionEngine(
+        agent_class=FrozenLakeAgent,
+        env_class=FrozenLakeEnv,
+        agent_args={"max_steps": 10, "use_accumulate_history": True},
+        env_args={"max_steps": 8, "is_slippery": False},
+        engine_name="openai",
+        tokenizer=tokenizer,
+        sampling_params={"temperature": 0.6, "top_p": 0.95, "model": model_name},
+        rollout_engine_args={"base_url": "http://localhost:30000/v1", "api_key": "None"},
+        max_response_length=16384,
+        max_prompt_length=4096,
+        n_parallel_agents=256,
+    )
+
+    # Load (or generate) the test tasks, run them all, then report pass@k.
+    results = asyncio.run(engine.execute_tasks(load_frozenlake_data()))
+    compute_pass_at_k(results)
diff --git a/vendor/rllm/examples/frozenlake/train_frozenlake_agent.py b/vendor/rllm/examples/frozenlake/train_frozenlake_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..beb4b47eb51cacbe4d3e2d068b20b3d807a7cca1
--- /dev/null
+++ b/vendor/rllm/examples/frozenlake/train_frozenlake_agent.py
@@ -0,0 +1,25 @@
+import hydra
+
+from rllm.agents.frozenlake_agent import FrozenLakeAgent
+from rllm.data import DatasetRegistry
+from rllm.environments.frozenlake.frozenlake import FrozenLakeEnv
+from rllm.trainer.agent_trainer import AgentTrainer
+
+
+@hydra.main(config_path="pkg://rllm.trainer.config", config_name="agent_ppo_trainer", version_base=None)
+def main(config):
+    """Train ``FrozenLakeAgent`` in ``FrozenLakeEnv`` on the registered "frozenlake" splits.
+
+    The "train" split is used for training and the "test" split for validation;
+    ``config`` is the Hydra-composed trainer configuration.
+    """
+    splits = {
+        "train_dataset": DatasetRegistry.load_dataset("frozenlake", "train"),
+        "val_dataset": DatasetRegistry.load_dataset("frozenlake", "test"),
+    }
+    AgentTrainer(agent_class=FrozenLakeAgent, env_class=FrozenLakeEnv, config=config, **splits).train()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/vendor/rllm/examples/frozenlake/train_frozenlake_agent.sh b/vendor/rllm/examples/frozenlake/train_frozenlake_agent.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f98d654f467855b9047646b26b48332e3043cc03
--- /dev/null
+++ b/vendor/rllm/examples/frozenlake/train_frozenlake_agent.sh
@@ -0,0 +1,70 @@
+# Echo each command before it is executed.
+set -x
+
+# Runtime settings for vLLM and the PyTorch CUDA allocator; exported so the
+# python3 process started below inherits them.
+export VLLM_ATTENTION_BACKEND=FLASH_ATTN
+export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:False"
+export VLLM_USE_V1=1
+export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
+export VLLM_ENGINE_ITERATION_TIMEOUT_S=100000000000
+
+# Find the directory where rllm package is located
+# NOTE(review): RLLM_DIR is never referenced later in this script; confirm whether it is still needed.
+RLLM_DIR=$(python3 -c "import rllm; import os; print(os.path.dirname(os.path.dirname(rllm.__file__)))")
+
+# Start training through the examples.frozenlake.train_frozenlake_agent module.
+# Every argument is a Hydra override; in Hydra's syntax a leading "+" on a key
+# (the +rllm.env.env_args.* and +rllm.agent.agent_args.* entries) appends a key the
+# base config does not define.
+# NOTE(review): the module path presumably requires running from the directory that contains examples/ - confirm.
+python3 -m examples.frozenlake.train_frozenlake_agent \
+ algorithm.adv_estimator=grpo \
+ data.train_batch_size=64 \
+ data.val_batch_size=128 \
+ data.max_prompt_length=4096 \
+ data.max_response_length=10240 \
+ actor_rollout_ref.model.path=Qwen/Qwen3-0.6B \
+ actor_rollout_ref.hybrid_engine=True \
+ actor_rollout_ref.actor.optim.lr=1e-6 \
+ actor_rollout_ref.model.use_remove_padding=True \
+ actor_rollout_ref.actor.loss_agg_mode=seq-mean-token-sum \
+ actor_rollout_ref.actor.ppo_mini_batch_size=32 \
+ actor_rollout_ref.actor.use_dynamic_bsz=True \
+ actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \
+ actor_rollout_ref.actor.use_kl_loss=False \
+ actor_rollout_ref.actor.clip_ratio_high=0.28 \
+ actor_rollout_ref.actor.kl_loss_coef=0.001 \
+ actor_rollout_ref.actor.kl_loss_type=low_var_kl \
+ actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \
+ actor_rollout_ref.model.enable_gradient_checkpointing=True \
+ actor_rollout_ref.actor.fsdp_config.param_offload=True \
+ actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
+ actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
+ actor_rollout_ref.rollout.name=vllm \
+ actor_rollout_ref.rollout.mode="async" \
+ actor_rollout_ref.rollout.enforce_eager=False \
+ actor_rollout_ref.rollout.temperature=0.7 \
+ actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \
+ actor_rollout_ref.rollout.n=8 \
+ actor_rollout_ref.rollout.val_kwargs.n=4 \
+ actor_rollout_ref.rollout.val_kwargs.temperature=0.7 \
+ actor_rollout_ref.rollout.val_kwargs.top_p=0.8 \
+ actor_rollout_ref.rollout.val_kwargs.top_k=20 \
+ actor_rollout_ref.ref.fsdp_config.param_offload=True \
+ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \
+ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \
+ actor_rollout_ref.actor.entropy_coeff=0 \
+ algorithm.kl_ctrl.kl_coef=0.001 \
+ rllm.mask_truncated_samples=False \
+ trainer.critic_warmup=0 \
+ trainer.logger=['console','wandb'] \
+ trainer.project_name='rllm-agent' \
+ trainer.experiment_name='frozenlake-agent-0.6B' \
+ trainer.val_before_train=True \
+ trainer.n_gpus_per_node=8 \
+ trainer.nnodes=1 \
+ trainer.save_freq=40 \
+ trainer.test_freq=10 \
+ trainer.default_hdfs_dir=null \
+ rllm.rejection_sample.enable=True \
+ rllm.rejection_sample.multiplier=2 \
+ +rllm.env.env_args.max_steps=8 \
+ +rllm.env.env_args.is_slippery=False \
+ rllm.agent.max_steps=10 \
+ rllm.stepwise_advantage.enable=False \
+ rllm.disable_thinking=False \
+ +rllm.agent.agent_args.max_steps=10 \
+ +rllm.agent.agent_args.use_accumulate_history=True \
+ trainer.total_epochs=1
\ No newline at end of file
diff --git a/vendor/rllm/examples/frozenlake/workflow/train_frozenlake_agent.py b/vendor/rllm/examples/frozenlake/workflow/train_frozenlake_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..d23b37c124f4a24ea2a5059499e6da47c696453c
--- /dev/null
+++ b/vendor/rllm/examples/frozenlake/workflow/train_frozenlake_agent.py
@@ -0,0 +1,32 @@
+import hydra
+
+from rllm.agents.frozenlake_agent import FrozenLakeAgent
+from rllm.data import DatasetRegistry
+from rllm.environments.frozenlake.frozenlake import FrozenLakeEnv
+from rllm.trainer.agent_trainer import AgentTrainer
+from rllm.workflows.cumulative_workflow import CumulativeWorkflow
+
+
+@hydra.main(config_path="pkg://rllm.trainer.config", config_name="agent_ppo_trainer", version_base=None)
+def main(config):
+    """Train the FrozenLake agent through ``CumulativeWorkflow``.
+
+    The workflow is given the agent/environment classes and their arguments;
+    the registered "frozenlake" train and test splits are used for training
+    and validation.
+    """
+    train_split = DatasetRegistry.load_dataset("frozenlake", "train")
+    val_split = DatasetRegistry.load_dataset("frozenlake", "test")
+
+    # Agent and workflow allow 10 steps, the environment 8.
+    workflow_args = {
+        "agent_cls": FrozenLakeAgent,
+        "agent_args": {"max_steps": 10, "use_accumulate_history": True},
+        "env_cls": FrozenLakeEnv,
+        "env_args": {"max_steps": 8, "is_slippery": False},
+        "max_steps": 10,
+    }
+
+    AgentTrainer(
+        workflow_class=CumulativeWorkflow,
+        workflow_args=workflow_args,
+        config=config,
+        train_dataset=train_split,
+        val_dataset=val_split,
+    ).train()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/vendor/rllm/examples/frozenlake/workflow/train_frozenlake_agent.sh b/vendor/rllm/examples/frozenlake/workflow/train_frozenlake_agent.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f7261fd900c015bfb9f08bcbd292fee593577a10
--- /dev/null
+++ b/vendor/rllm/examples/frozenlake/workflow/train_frozenlake_agent.sh
@@ -0,0 +1,68 @@
+# Echo each command before it is executed.
+set -x
+
+# Runtime settings for vLLM and the PyTorch CUDA allocator; exported so the
+# python3 process started below inherits them.
+export VLLM_ATTENTION_BACKEND=FLASH_ATTN
+export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:False"
+export VLLM_USE_V1=1
+export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
+export VLLM_ENGINE_ITERATION_TIMEOUT_S=100000000000
+
+# Find the directory where rllm package is located
+# NOTE(review): RLLM_DIR is never referenced later in this script; confirm whether it is still needed.
+RLLM_DIR=$(python3 -c "import rllm; import os; print(os.path.dirname(os.path.dirname(rllm.__file__)))")
+
+# Start training through the examples.frozenlake.workflow.train_frozenlake_agent module.
+# Every argument is a Hydra override of the trainer configuration.
+# NOTE(review): the module path presumably requires running from the directory that contains examples/ - confirm.
+python3 -m examples.frozenlake.workflow.train_frozenlake_agent \
+ algorithm.adv_estimator=grpo \
+ data.train_batch_size=64 \
+ data.val_batch_size=128 \
+ data.max_prompt_length=4096 \
+ data.max_response_length=10240 \
+ actor_rollout_ref.model.path=Qwen/Qwen3-0.6B \
+ actor_rollout_ref.hybrid_engine=True \
+ actor_rollout_ref.actor.optim.lr=1e-6 \
+ actor_rollout_ref.model.use_remove_padding=True \
+ actor_rollout_ref.actor.loss_agg_mode=seq-mean-token-sum \
+ actor_rollout_ref.actor.ppo_mini_batch_size=32 \
+ actor_rollout_ref.actor.use_dynamic_bsz=True \
+ actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \
+ actor_rollout_ref.actor.use_kl_loss=False \
+ actor_rollout_ref.actor.clip_ratio_high=0.28 \
+ actor_rollout_ref.actor.kl_loss_coef=0.001 \
+ actor_rollout_ref.actor.kl_loss_type=low_var_kl \
+ actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \
+ actor_rollout_ref.model.enable_gradient_checkpointing=True \
+ actor_rollout_ref.actor.fsdp_config.param_offload=True \
+ actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
+ actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
+ actor_rollout_ref.rollout.name=vllm \
+ actor_rollout_ref.rollout.mode="async" \
+ actor_rollout_ref.rollout.enforce_eager=False \
+ actor_rollout_ref.rollout.temperature=0.7 \
+ actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \
+ actor_rollout_ref.rollout.n=8 \
+ actor_rollout_ref.rollout.val_kwargs.n=4 \
+ actor_rollout_ref.rollout.val_kwargs.temperature=0.7 \
+ actor_rollout_ref.rollout.val_kwargs.top_p=0.8 \
+ actor_rollout_ref.rollout.val_kwargs.top_k=20 \
+ actor_rollout_ref.ref.fsdp_config.param_offload=True \
+ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \
+ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \
+ actor_rollout_ref.actor.entropy_coeff=0 \
+ algorithm.kl_ctrl.kl_coef=0.001 \
+ rllm.mask_truncated_samples=False \
+ trainer.critic_warmup=0 \
+ trainer.logger=['console','wandb'] \
+ trainer.project_name='rllm-agent' \
+ trainer.experiment_name='frozenlake-agent-0.6B-workflow' \
+ trainer.val_before_train=True \
+ trainer.n_gpus_per_node=8 \
+ trainer.nnodes=1 \
+ trainer.save_freq=40 \
+ trainer.test_freq=10 \
+ trainer.default_hdfs_dir=null \
+ trainer.total_epochs=1 \
+ rllm.rejection_sample.enable=True \
+ rllm.rejection_sample.multiplier=4 \
+ rllm.stepwise_advantage.enable=False \
+ rllm.disable_thinking=False \
+ rllm.accumulate_reasoning=True \
+ rllm.workflow.use_workflow=True \
+ rllm.workflow.n_parallel_tasks=1024
\ No newline at end of file
diff --git a/vendor/rllm/examples/geo3k/geo3k_workflow.py b/vendor/rllm/examples/geo3k/geo3k_workflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ab1a5f748f7835c3719ee8c6caca6871e93bd38
--- /dev/null
+++ b/vendor/rllm/examples/geo3k/geo3k_workflow.py
@@ -0,0 +1,71 @@
+import base64
+from io import BytesIO
+
+from PIL import Image
+
+from rllm.agents.agent import Action, Episode, Step, Trajectory
+from rllm.engine import ModelOutput, RolloutEngine
+from rllm.rewards.reward_fn import RewardFunction, math_reward_fn
+from rllm.workflows.simple_workflow import SimpleAgent
+from rllm.workflows.workflow import TerminationEvent, TerminationReason, Workflow
+
+
+class Geo3KWorkflow(Workflow):
+    """Single-turn workflow: send one (optionally illustrated) question to the model and score the answer."""
+
+    def __init__(self, rollout_engine: RolloutEngine, reward_function: RewardFunction = None, encode_as_base64: bool = False, **kwargs):
+        """
+        Args:
+            rollout_engine: Engine used to obtain the model response.
+            reward_function: Called as ``reward_function(task, action)``; defaults to ``math_reward_fn``.
+            encode_as_base64: When True, the image is sent as an OpenAI-style ``image_url``
+                data URI (JPEG, base64) instead of being attached as a PIL object.
+            **kwargs: Forwarded to ``Workflow.__init__``.
+        """
+        super().__init__(rollout_engine, **kwargs)
+        self.agent = SimpleAgent()
+        self.reward_fn: RewardFunction = reward_function or math_reward_fn
+        self.encode_as_base64 = encode_as_base64
+
+    @staticmethod
+    def _extract_image(task: dict):
+        """Return the task's image as a PIL image, or None when the task carries no image."""
+        image = task.get("image", task.get("images", None))
+        if isinstance(image, list):
+            # Only the first image is used. An empty list means "no image"; it used to
+            # fall through to the assertion below and abort the task.
+            image = image[0] if image else None
+        if isinstance(image, dict) and "bytes" in image:
+            image = Image.open(BytesIO(image["bytes"]))
+        assert isinstance(image, Image.Image) or image is None, f"Image must be a PIL.Image.Image, but got {type(image)}"
+        return image
+
+    def _build_messages(self, question, image):
+        """Build the single user message in the format selected by ``encode_as_base64``."""
+        if image is None:
+            return [{"role": "user", "content": question}]
+
+        if not self.encode_as_base64:
+            return [{"role": "user", "content": question, "images": [image]}]
+
+        # format as openai compatible base64 encoded image
+        buffer = BytesIO()
+        image.convert("RGB").save(buffer, format="JPEG")
+        image_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
+        return [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": question},
+                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}},
+                ],
+            }
+        ]
+
+    async def run(self, task: dict, uid: str, **kwargs) -> Episode:
+        """Run one question through the model, record the step and signal termination.
+
+        Args:
+            task: Task dict; ``question`` and optionally ``image``/``images`` are read here,
+                and the whole dict is passed to the reward function.
+            uid: Identifier passed to ``reset`` and, as ``application_id``, to the rollout engine.
+            **kwargs: Forwarded to ``rollout_engine.get_model_response``.
+
+        Raises:
+            TerminationEvent: Always. The reason is ``MAX_RESPONSE_LENGTH_EXCEEDED`` when the
+                model stopped with finish reason "length", otherwise ``ENV_DONE``.
+                NOTE(review): the Episode is presumably assembled by the base ``Workflow``
+                from the committed trajectory - confirm.
+        """
+        self.reset(task, uid)
+
+        question = task.get("question")
+        image = self._extract_image(task)
+        messages = self._build_messages(question, image)
+
+        output: ModelOutput = await self.rollout_engine.get_model_response(messages, application_id=uid, **kwargs)
+        action = Action(output.content)
+        reward_result = self.reward_fn(task, action)
+
+        trajectory: Trajectory = self.agent.trajectory
+        trajectory.steps.append(
+            Step(
+                chat_completions=messages + [{"role": "assistant", "content": output.content, "reasoning": output.reasoning}],
+                thought=output.reasoning,
+                action=action,
+                reward=reward_result.reward,
+                model_output=output,
+            )
+        )
+
+        self.commit(agent=self.agent, reset=True)
+
+        if output.finish_reason == "length":
+            raise TerminationEvent(TerminationReason.MAX_RESPONSE_LENGTH_EXCEEDED)
+
+        raise TerminationEvent(TerminationReason.ENV_DONE)
diff --git a/vendor/rllm/examples/geo3k/preprocess_geo3k.py b/vendor/rllm/examples/geo3k/preprocess_geo3k.py
new file mode 100644
index 0000000000000000000000000000000000000000..d650c9b7968b3146214fca688e14b6e28f79a13e
--- /dev/null
+++ b/vendor/rllm/examples/geo3k/preprocess_geo3k.py
@@ -0,0 +1,49 @@
+from datasets import load_dataset
+
+from rllm.data.dataset import DatasetRegistry
+
+
+def prepare_geo3k_data():
+    """Load hiyouga/geometry3k, convert every row to a task record and register both splits as "geo3k".
+
+    Returns:
+        tuple: (train_dataset, test_dataset) as returned by DatasetRegistry.register_dataset
+    """
+    raw = load_dataset("hiyouga/geometry3k")
+    raw_splits = {"train": raw["train"], "test": raw["test"]}
+
+    # Alternative instruction, kept for reference:
+    # instruction_following = (
+    #     r"You FIRST think about the reasoning process as an internal monologue and then provide the final answer. "
+    #     r"The reasoning process MUST BE enclosed within tags. "
+    #     r"The final answer MUST BE put in \boxed{}."
+    # )
+    instruction_following = "Let's think step by step and output your final answer in \\boxed{}."
+
+    def to_task(example, idx):
+        """Turn one raw row into a task record (idx, data_source, image, question, ground_truth)."""
+        # The instruction is appended directly to the problem text, with no separator.
+        question = example.pop("problem") + instruction_following
+        ground_truth = example.pop("answer")
+        image = example.pop("images")
+        return {
+            "idx": idx,
+            "data_source": "geo3k",
+            "image": image,
+            "question": question,
+            "ground_truth": ground_truth,
+        }
+
+    # Convert both splits first, then register them (train before test in each step).
+    converted = {name: split.map(function=to_task, with_indices=True, num_proc=8) for name, split in raw_splits.items()}
+    train_dataset, test_dataset = (DatasetRegistry.register_dataset("geo3k", converted[name], name) for name in ("train", "test"))
+
+    return train_dataset, test_dataset
+
+
+if __name__ == "__main__":
+    # Show where each registered split was written.
+    registered_train, registered_test = prepare_geo3k_data()
+    print(registered_train.get_data_path())
+    print(registered_test.get_data_path())
diff --git a/vendor/rllm/examples/geo3k/run_geo3k.py b/vendor/rllm/examples/geo3k/run_geo3k.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c8031ad7af87b7298aa68b0c5f312422e3656ed
--- /dev/null
+++ b/vendor/rllm/examples/geo3k/run_geo3k.py
@@ -0,0 +1,106 @@
+import asyncio
+import json
+import os
+from copy import deepcopy
+
+from geo3k_workflow import Geo3KWorkflow
+
+from rllm.data.dataset import DatasetRegistry
+from rllm.engine import AgentWorkflowEngine, OpenAIEngine
+from rllm.rewards.reward_fn import math_reward_fn
+
+
+def load_data(n=1):
+    """Load the geo3k test split and repeat every example ``n`` times.
+
+    Args:
+        n (int): Number of copies of each example to emit.
+
+    Returns:
+        list: Deep copies of the test examples, ``n`` consecutive copies per example.
+    """
+    dataset = DatasetRegistry.load_dataset("geo3k", "test")
+    if dataset is None:
+        print("Dataset not found, preparing dataset...")
+        # The preparation code lives in the sibling file preprocess_geo3k.py and exposes
+        # prepare_geo3k_data(); the module/function names imported here before did not exist.
+        from preprocess_geo3k import prepare_geo3k_data
+
+        _, dataset = prepare_geo3k_data()
+
+    # Deep-copy so that the repeated tasks do not share mutable state.
+    return [deepcopy(example) for example in dataset for _ in range(n)]
+
+
+def evaluate_results(results):
+    """Print pass@1 and pass@k for a list of finished episodes.
+
+    Episodes are grouped by ``episode.task["idx"]``. pass@1 is the fraction of all
+    episodes whose ``is_correct`` flag is truthy; pass@k is the fraction of problems
+    with at least one correct episode, where k is the largest number of episodes
+    seen for any single problem (1 when ``results`` is empty).
+    """
+    attempts = {}
+    successes = {}
+    for episode in results:
+        problem_id = episode.task["idx"]
+        # Use the episode-level is_correct flag set by the workflow
+        solved = int(episode.is_correct)
+        successes[problem_id] = successes.get(problem_id, 0) + solved
+        attempts[problem_id] = attempts.get(problem_id, 0) + 1
+
+    k = max(attempts.values(), default=1)
+    total_problems = len(successes)
+
+    pass_at_1 = 0.0
+    pass_at_k = 0.0
+    if total_problems > 0:
+        pass_at_1 = sum(successes.values()) / sum(attempts.values())
+        pass_at_k = sum(1 for count in successes.values() if count > 0) / total_problems
+
+    print("Total unique problems:", total_problems)
+    print("Average Pass@1 Accuracy:", pass_at_1)
+    print(f"Average Pass@{k} Accuracy:", pass_at_k)
+
+
+if __name__ == "__main__":
+    # os is already imported at the top of this file; the redundant re-import was dropped.
+    os.environ["TOKENIZERS_PARALLELISM"] = "true"
+
+    n_parallel_tasks = 128
+    model_name = "Qwen/Qwen3-VL-2B-Instruct"
+
+    # OpenAI-compatible endpoint expected at localhost:30000.
+    rollout_engine = OpenAIEngine(
+        model=model_name,
+        max_prompt_length=1024,
+        max_response_length=2048,
+        base_url="http://localhost:30000/v1",
+        api_key="None",
+        sampling_params={"temperature": 0.6, "top_p": 0.95},
+    )
+
+    # Images are sent as base64 data URIs because the engine talks to a remote API.
+    engine = AgentWorkflowEngine(
+        workflow_cls=Geo3KWorkflow,
+        workflow_args={
+            "reward_function": math_reward_fn,
+            "encode_as_base64": True,
+        },
+        rollout_engine=rollout_engine,
+        config=None,
+        n_parallel_tasks=n_parallel_tasks,
+        retry_limit=1,
+    )
+
+    # Four copies of every test example, so pass@4 can be reported.
+    tasks = load_data(n=4)
+    print(f"Loaded {len(tasks)} geo3k tasks")
+
+    results = asyncio.run(engine.execute_tasks(tasks))
+
+    # Evaluate results (rewards are already assigned in the workflow)
+    print("Evaluating results...")
+    evaluate_results(results)
+
+    # Save results
+    os.makedirs("logs", exist_ok=True)
+    with open("logs/geo3k.json", "w") as f:
+        json.dump([episode.to_dict() for episode in results], f, indent=4)
+
+    print("\nResults saved to logs/geo3k.json")
diff --git a/vendor/rllm/examples/geo3k/train_geo3k.py b/vendor/rllm/examples/geo3k/train_geo3k.py
new file mode 100644
index 0000000000000000000000000000000000000000..16c832c7a4431c9f169e1b4e2e0917ddbe821d75
--- /dev/null
+++ b/vendor/rllm/examples/geo3k/train_geo3k.py
@@ -0,0 +1,27 @@
+import hydra
+
+from examples.geo3k.geo3k_workflow import Geo3KWorkflow
+from rllm.data.dataset import DatasetRegistry
+from rllm.rewards.reward_fn import math_reward_fn
+from rllm.trainer.agent_trainer import AgentTrainer
+
+
+@hydra.main(config_path="pkg://rllm.trainer.config", config_name="agent_ppo_trainer", version_base=None)
+def main(config):
+    """Train ``Geo3KWorkflow`` on the registered "geo3k" splits, scoring answers with ``math_reward_fn``."""
+    splits = {
+        "train_dataset": DatasetRegistry.load_dataset("geo3k", "train"),
+        "val_dataset": DatasetRegistry.load_dataset("geo3k", "test"),
+    }
+    AgentTrainer(
+        workflow_class=Geo3KWorkflow,
+        workflow_args={"reward_function": math_reward_fn},
+        config=config,
+        **splits,
+    ).train()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/vendor/rllm/examples/geo3k/train_geo3k.sh b/vendor/rllm/examples/geo3k/train_geo3k.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f37f2ab1bdf04337301f943e40b116bdd208af96
--- /dev/null
+++ b/vendor/rllm/examples/geo3k/train_geo3k.sh
@@ -0,0 +1,57 @@
+# Echo each command before it is executed.
+set -x
+
+# PyTorch CUDA allocator setting; exported so the python3 process started below inherits it.
+export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:False"
+
+# Start training through the examples.geo3k.train_geo3k module.
+# Every argument is a Hydra override of the trainer configuration.
+# NOTE(review): the module path presumably requires running from the directory that contains examples/ - confirm.
+python3 -m examples.geo3k.train_geo3k \
+ algorithm.adv_estimator=grpo \
+ data.train_batch_size=32 \
+ data.val_batch_size=512 \
+ data.max_prompt_length=1024 \
+ data.max_response_length=2048 \
+ actor_rollout_ref.model.path=Qwen/Qwen3-VL-2B-Instruct \
+ actor_rollout_ref.hybrid_engine=True \
+ actor_rollout_ref.actor.optim.lr=1e-6 \
+ actor_rollout_ref.actor.strategy=fsdp \
+ actor_rollout_ref.actor.loss_agg_mode=token-mean \
+ actor_rollout_ref.model.use_remove_padding=True \
+ actor_rollout_ref.actor.ppo_mini_batch_size=16 \
+ actor_rollout_ref.actor.use_dynamic_bsz=True \
+ actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \
+ actor_rollout_ref.actor.use_kl_loss=False \
+ actor_rollout_ref.actor.clip_ratio_high=0.28 \
+ actor_rollout_ref.actor.kl_loss_coef=0.001 \
+ actor_rollout_ref.actor.kl_loss_type=low_var_kl \
+ actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \
+ actor_rollout_ref.model.enable_gradient_checkpointing=True \
+ actor_rollout_ref.actor.fsdp_config.param_offload=False \
+ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
+ actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
+ actor_rollout_ref.rollout.name=vllm \
+ actor_rollout_ref.rollout.mode="async" \
+ actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
+ actor_rollout_ref.rollout.enforce_eager=False \
+ actor_rollout_ref.rollout.n=8 \
+ actor_rollout_ref.rollout.temperature=0.6 \
+ actor_rollout_ref.rollout.val_kwargs.n=1 \
+ actor_rollout_ref.rollout.val_kwargs.temperature=0.6 \
+ actor_rollout_ref.rollout.val_kwargs.top_p=0.9 \
+ actor_rollout_ref.ref.fsdp_config.param_offload=True \
+ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \
+ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \
+ actor_rollout_ref.actor.entropy_coeff=0 \
+ algorithm.kl_ctrl.kl_coef=0.001 \
+ rllm.mask_truncated_samples=False \
+ trainer.critic_warmup=0 \
+ trainer.logger=['console','wandb'] \
+ trainer.project_name='geo3k' \
+ trainer.experiment_name='qwen3-vl-2b-instruct' \
+ trainer.val_before_train=True \
+ trainer.n_gpus_per_node=8 \
+ trainer.nnodes=1 \
+ trainer.save_freq=1000 \
+ trainer.test_freq=10 \
+ trainer.default_hdfs_dir=null \
+ rllm.agent.max_steps=1 \
+ rllm.stepwise_advantage.enable=False \
+ rllm.workflow.use_workflow=True \
+ trainer.total_epochs=3
\ No newline at end of file
diff --git a/vendor/rllm/examples/gsm8k_lora/prepare_gsm8k_data.py b/vendor/rllm/examples/gsm8k_lora/prepare_gsm8k_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..f309ab1044c235ebd38cdcf67695db6a4bb78c17
--- /dev/null
+++ b/vendor/rllm/examples/gsm8k_lora/prepare_gsm8k_data.py
@@ -0,0 +1,40 @@
+import re
+
+from datasets import load_dataset
+
+from rllm.data.dataset import DatasetRegistry
+
+
+# Adapted from verl/examples/data_preprocess/gsm8k.py
+def extract_solution(solution_str):
+    """Return the number that follows the "#### " marker, with commas removed, as a string."""
+    marker_match = re.search("#### (\\-?[0-9\\.\\,]+)", solution_str)
+    # A missing marker is a hard failure; the captured group is everything after "#### ".
+    assert marker_match is not None
+    return marker_match.group(1).replace(",", "")
+
+
+def prepare_gsm8k_data():
+    """Load GSM8K ("openai/gsm8k", config "main") and register both splits as "gsm8k".
+
+    Returns the (train, test) pair handed back by DatasetRegistry.register_dataset.
+    """
+    names = ("train", "test")
+    raw = load_dataset("openai/gsm8k", "main")
+    splits = [raw[name] for name in names]
+
+    def preprocess_fn(example, idx):
+        # idx arrives because map() runs with with_indices=True; it is not used.
+        question = example["question"]
+        return {"question": question, "ground_truth": extract_solution(example["answer"]), "data_source": "gsm8k"}
+
+    # Both splits are mapped before either one is registered.
+    mapped = [split.map(preprocess_fn, with_indices=True) for split in splits]
+    train_dataset, test_dataset = (DatasetRegistry.register_dataset("gsm8k", m, n) for m, n in zip(mapped, names))
+    return train_dataset, test_dataset
+
+
+if __name__ == "__main__":
+    # Script entry point: build and register both splits, then print train followed by test.
+    for registered_split in prepare_gsm8k_data():
+        print(registered_split)
diff --git a/vendor/rllm/examples/gsm8k_lora/train_gsm8k_lora.sh b/vendor/rllm/examples/gsm8k_lora/train_gsm8k_lora.sh
new file mode 100644
index 0000000000000000000000000000000000000000..166a979161087c992744a54e5e7be4dbb6e08dc0
--- /dev/null
+++ b/vendor/rllm/examples/gsm8k_lora/train_gsm8k_lora.sh
@@ -0,0 +1,64 @@
+set -x
+# GRPO + LoRA training on GSM8K; the entry module loads the "gsm8k" dataset that prepare_gsm8k_data.py registers.
+export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:False"
+export VLLM_USE_V1=1
+export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
+export VLLM_ENGINE_ITERATION_TIMEOUT_S=100000000000
+# Base model; the LoRA rank, alpha and target modules are passed as overrides below.
+MODEL_PATH=Qwen/Qwen2.5-3B-Instruct
+# The entry module is train_gsm8k_with_lora.py; run from the directory that contains examples/.
+python3 -m examples.gsm8k_lora.train_gsm8k_with_lora \
+    algorithm.adv_estimator=grpo \
+    data.train_batch_size=8 \
+    data.val_batch_size=512 \
+    data.max_prompt_length=512 \
+    data.max_response_length=1024 \
+    actor_rollout_ref.model.path="$MODEL_PATH" \
+    actor_rollout_ref.model.lora_rank=32 \
+    actor_rollout_ref.model.lora_alpha=32 \
+    actor_rollout_ref.model.target_modules=all-linear \
+    actor_rollout_ref.actor.optim.lr=5e-6 \
+    actor_rollout_ref.actor.strategy=fsdp2 \
+    actor_rollout_ref.actor.loss_agg_mode=token-mean \
+    actor_rollout_ref.model.use_remove_padding=True \
+    actor_rollout_ref.actor.ppo_mini_batch_size=8 \
+    actor_rollout_ref.actor.use_dynamic_bsz=False \
+    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=20000 \
+    actor_rollout_ref.actor.use_kl_loss=False \
+    actor_rollout_ref.actor.clip_ratio_high=0.2 \
+    actor_rollout_ref.actor.kl_loss_coef=0.001 \
+    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
+    actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \
+    actor_rollout_ref.model.enable_gradient_checkpointing=True \
+    actor_rollout_ref.actor.fsdp_config.param_offload=False \
+    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
+    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
+    actor_rollout_ref.rollout.name=vllm \
+    actor_rollout_ref.rollout.mode="async" \
+    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
+    actor_rollout_ref.rollout.enforce_eager=True \
+    actor_rollout_ref.rollout.n=8 \
+    actor_rollout_ref.rollout.temperature=0.7 \
+    actor_rollout_ref.rollout.top_p=0.95 \
+    actor_rollout_ref.rollout.val_kwargs.n=1 \
+    actor_rollout_ref.rollout.val_kwargs.temperature=0.7 \
+    actor_rollout_ref.rollout.val_kwargs.top_p=0.95 \
+    actor_rollout_ref.ref.fsdp_config.param_offload=False \
+    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \
+    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \
+    actor_rollout_ref.actor.entropy_coeff=0 \
+    algorithm.kl_ctrl.kl_coef=0.001 \
+    rllm.mask_truncated_samples=False \
+    trainer.critic_warmup=0 \
+    trainer.logger=['console','wandb'] \
+    trainer.project_name='rllm-experiment' \
+    trainer.experiment_name='gsm8k-lora' \
+    trainer.val_before_train=True \
+    trainer.n_gpus_per_node=4 \
+    trainer.nnodes=1 \
+    trainer.save_freq=1000 \
+    trainer.test_freq=10 \
+    trainer.default_hdfs_dir=null \
+    rllm.agent.max_steps=1 \
+    rllm.stepwise_advantage.enable=False \
+    trainer.total_epochs=100
\ No newline at end of file
diff --git a/vendor/rllm/examples/gsm8k_lora/train_gsm8k_with_lora.py b/vendor/rllm/examples/gsm8k_lora/train_gsm8k_with_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..7af8017ea7d805033c2f26fa062924e7cd0cf6aa
--- /dev/null
+++ b/vendor/rllm/examples/gsm8k_lora/train_gsm8k_with_lora.py
@@ -0,0 +1,30 @@
+import hydra
+
+from rllm.agents.math_agent import MathAgent
+from rllm.data.dataset import DatasetRegistry
+from rllm.environments.base.single_turn_env import SingleTurnEnvironment
+from rllm.rewards.reward_fn import math_reward_fn
+from rllm.trainer.agent_trainer import AgentTrainer
+
+
+@hydra.main(config_path="pkg://rllm.trainer.config", config_name="agent_ppo_trainer", version_base=None)
+def main(config):
+    """Train MathAgent on the registered "gsm8k" splits using the Hydra-supplied config."""
+    # Both splits come from DatasetRegistry under the name "gsm8k"; "test" serves as validation data.
+    splits = {name: DatasetRegistry.load_dataset("gsm8k", name) for name in ("train", "test")}
+
+    trainer_kwargs = dict(
+        agent_class=MathAgent,
+        agent_args={},
+        env_args={"reward_fn": math_reward_fn},
+        env_class=SingleTurnEnvironment,
+        config=config,
+        train_dataset=splits["train"],
+        val_dataset=splits["test"],
+    )
+    # Build the trainer and start training straight away.
+    AgentTrainer(**trainer_kwargs).train()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/vendor/rllm/examples/math_tinker/math_agent_with_fewshot.py b/vendor/rllm/examples/math_tinker/math_agent_with_fewshot.py
new file mode 100644
index 0000000000000000000000000000000000000000..6133afbfa03492481d2a81009fc8c863f1dcb076
--- /dev/null
+++ b/vendor/rllm/examples/math_tinker/math_agent_with_fewshot.py
@@ -0,0 +1,143 @@
+"""
+MathAgent with few-shot prompting support to match tinker-cookbook math_rl.
+
+This agent variant includes:
+1. Few-shot prefix with a standard example (strawberry)
+2. Instruction text matching math_rl: " Provide a numerical answer without units, written inside \\boxed{}."
+"""
+
+import copy
+from typing import Any
+
+from rllm.agents.agent import Action, BaseAgent, Step, Trajectory
+
+
+class MathAgentWithFewshot(BaseAgent):
+ """
+ A math agent with few-shot prompting that matches tinker-cookbook math_rl behavior.
+ """
+
+ # Standard few-shot example from tinker-cookbook math_rl
+ STANDARD_FEWSHOT_PREFIX = [
+ {
+ "role": "user",
+ "content": "How many r's are in strawberry? Provide a numerical answer without units, written inside \\boxed{}.",
+ },
+ {
+ "role": "assistant",
+ "content": "Let's spell the word out and number all the letters: 1) s 2) t 3) r 4) a 5) w 6) b 7) e 8) r 9) r 10) y. We have r's at positions 3, 8, and 9. \\boxed{3}",
+ },
+ ]
+
+    def __init__(self, accumulate_thinking=True, use_fewshot=True):
+        """
+        Set up an empty trajectory and the initial chat history.
+
+        Args:
+            accumulate_thinking: Stored on the instance as ``accumulate_thinking``
+            use_fewshot: When true, the history starts with a copy of STANDARD_FEWSHOT_PREFIX
+        """
+        self._trajectory = Trajectory()
+        self.accumulate_thinking = accumulate_thinking
+        self.use_fewshot = use_fewshot
+
+        # Deep-copy the class-level example so per-instance edits never touch the shared list.
+        self.messages = copy.deepcopy(self.STANDARD_FEWSHOT_PREFIX) if self.use_fewshot else []
+
+ def update_from_env(self, observation: Any, reward: float, done: bool, info: dict, **kwargs):
+ """Record environment feedback on the current step and, for a new observation, open a new step.
+
+ Args:
+ observation: None or {} for a reward-only update; otherwise a str, or a dict with a "question" key.
+ reward: Written to the current step's reward when the trajectory already has steps.
+ done: Written to the current step's done flag when the trajectory already has steps.
+ info: Replaces the current step's info on a reward-only update; merged into it with update() otherwise.
+ **kwargs: Accepted but not read in this method.
+
+ Raises:
+ ValueError: If a dict observation has no "question" key, or the observation is neither dict nor str.
+ """
+
+ # Reward update for existing step (None OR empty dict)
+ if observation is None or (isinstance(observation, dict) and observation == {}):
+ if self.trajectory.steps:
+ cur_step = self.get_current_state()
+ cur_step.reward = reward
+ cur_step.done = done
+ cur_step.info = info
+ # NOTE(review): indentation is collapsed in this patch; presumably this return closes the
+ # whole reward-only branch rather than only the steps check -- confirm against the real file.
+ return
+
+ # Update reward/done/info on existing step if we have steps already
+ if self.trajectory.steps:
+ cur_step = self.get_current_state()
+ cur_step.reward = reward
+ cur_step.done = done
+ # Unlike the reward-only branch above, info is merged into the existing dict here.
+ cur_step.info.update(info)
+
+ # NOTE(review): whether this check sits inside the steps branch above cannot be told from
+ # this view -- confirm before restructuring.
+ if done:
+ return
+
+ # This is a new observation, create a new step
+ if isinstance(observation, dict):
+ if "question" not in observation:
+ raise ValueError(f"Observation dict missing required 'question' field: {observation}")
+ # Match math_rl instruction text exactly
+ # (same suffix as the user turn in STANDARD_FEWSHOT_PREFIX)
+ formatted_observation = observation["question"] + " Provide a numerical answer without units, written inside \\boxed{}."
+ elif isinstance(observation, str):
+ formatted_observation = observation + " Provide a numerical answer without units, written inside \\boxed{}."
+ else:
+ raise ValueError(f"Invalid observation type: {type(observation)}")
+
+ # The formatted question becomes the next user turn in the chat history.
+ self.messages.append({"role": "user", "content": formatted_observation})
+
+ # The new step stores the formatted text, not the raw observation.
+ new_step = Step(observation=formatted_observation)
+ self._trajectory.steps.append(new_step)
+
+ def update_from_model(self, response: str, **kwargs) -> Action:
+ """
+ Updates the agent's internal state based on the model's response.
+ """
+
+ # Update the latest step
+ self.messages.append({"role": "assistant", "content": response})
+
+ cur_step = self.get_current_state()
+ cur_step.chat_completions = self.chat_completions
+ cur_step.model_response = response
+
+ if response.count(" +📃 Blog Post +• + 🤗 HF Dataset (R2E-Gym) +• + +🔥 WandB Logs +• + 🤗 DeepSWE-Preview +• + 📈 Evaluation Logs +• + 🌐 Project Page +• + 🧑💻 Code +
+ +Maud is best pony
\r\n\t\tMaud is
\r\n\t\r\n\r\n", "target": "html! {\r\n\t(DOCTYPE)\r\n\thtml {\r\n\t\thead {\r\n\t\t\tmeta charset=\"UTF-8\"\r\n\t\t\ttitle \"This is title\"\r\n\t\t}\r\n\t\tbody {\r\n\t\t\th1.hello \"Hello maud!\"\r\n\t\t\tp#best.truth \"Maud is best pony\"\r\n\t\t\tdiv#beautiful.truth \"Maud is the most beautiful pony\"\r\n\t\t\tbr;\r\n\t\t\tspan#adorable.truth \"Maud is adorable\"\r\n\t\t\tbr;\r\n\t\t\tp.anything {\r\n\t\t\t\t\"Maud is\"\r\n\t\t\t\tinput#something type=\"text\" value=\"anything\"\r\n\t\t\t}\r\n\t\t}\r\n\t}\r\n}\r\n", "metadata": {"href": "/challenges/5fe3ddcb8a967b00099d7e59", "title": "RUST HTML to maud.", "detail": "RUST HTML to maud. maud is one of the template engines for server-side rendering based on rust. rusty practice on vimgolf", "challenge_hash": "5fe3ddcb8a967b00099d7e59"}, "id": "5fe3ddcb8a967b00099d7e59"} +{"input": "[package]\nname = \"rust-web\"\nversion = \"0.1.0\"\nauthors = [\"The Rust Developers\"]\nedition = \"2018\"\n\n[dependencies]\nlazy_static = \"1.2.0\"\nfluent = \"0.13\"\nfluent-bundle = \"0.6.0\"\nfluent-syntax = \"0.10.0\"\nfluent-locale = \"0.10.1\"\nhandlebars-fluent = \"0.2.0\"\nrand = \"0.8\"\nregex = \"1\"\nrocket = \"0.4.6\"\nserde = { version = \"1.0\", features = [\"derive\"] }\nserde_yaml = \"0.8.14\"\nsass-rs = \"0.2.1\"\nreqwest = { version = \"0.10.10\", features = [\"blocking\", \"json\"] }\ntoml = \"0.5\"\nserde_json = \"1.0\"\nrust_team_data = { git = \"https://github.com/rust-lang/team\" }\nhandlebars = \"1.1.0\"\nsiphasher = \"0.3.3\"\npercent-encoding = \"2.1.0\"\n\n[dependencies.rocket_contrib]\nversion = \"0.4\"\ndefault-features = false\nfeatures = [\"handlebars_templates\"]\n", "target": "[package]\nname = \"rust-web\"\nversion = \"0.1.0\"\nauthors = [\"The Rust Developers\"]\nedition = \"2018\"\n\n[dependencies]\nlazy_static = \"*\"\nfluent = \"*\"\nfluent-bundle = \"*\"\nfluent-syntax = \"*\"\nfluent-locale = \"*\"\nhandlebars-fluent = \"*\"\nrand = \"*\"\nregex = \"*\"\nrocket = \"*\"\nserde = { version = 
\"*\", features = [\"derive\"] }\nserde_yaml = \"*\"\nsass-rs = \"*\"\nreqwest = { version = \"*\", features = [\"blocking\", \"json\"] }\ntoml = \"*\"\nserde_json = \"*\"\nrust_team_data = { git = \"https://github.com/rust-lang/team\" }\nhandlebars = \"*\"\nsiphasher = \"*\"\npercent-encoding = \"*\"\n\n[dependencies.rocket_contrib]\nversion = \"*\"\ndefault-features = false\nfeatures = [\"handlebars_templates\"]\n", "metadata": {"href": "/challenges/5fe326eb11ba250006cbf2cd", "title": "RUST Cargo.toml version to last", "detail": "RUST Cargo.toml version to last. rusty practice on vimgolf", "challenge_hash": "5fe326eb11ba250006cbf2cd"}, "id": "5fe326eb11ba250006cbf2cd"} +{"input": "\uac00\r\n\t\r\n\u5bb6 \u4f73 \u8857 \u53ef \u6b4c \u52a0 \u50f9 \u5047 \u67b6 \u6687\r\n\t\r\n\u5609 \u5ac1 \u7a3c \u8cc8 \u99d5 \u4f3d \u8fe6 \u67ef \u5475 \u54e5 \u67b7 \u73c2 \u75c2 \u82db \u8304 \u8888 \u8a36 \u8dcf \u8efb \u54ff \u560f \u8238 \u73c8 \u5777 \u659d \u698e \u6a9f \u7b33 \u801e \u846d \u8b0c \u6cc7\r\n\t\r\n\u2013\r\n\uac01\r\n\t\r\n\u5404 \u89d2 \u811a \u95a3 \u5374 \u89ba \u523b\r\n\t\r\n\u73cf \u606a \u6bbc \u6128 \u537b \u54af \u57c6 \u6409 \u64f1 \u6877\r\n\t\r\n\u6128(\u6164)\r\n\uac04\r\n\t\r\n\u5e72 \u9593 \u770b \u520a \u809d \u5e79 \u7c21 \u59e6 \u61c7\r\n\t\r\n\u826e \u4f83 \u6746 \u7395 \u7aff \u63c0 \u8aeb \u58be \u681e \u5978 \u67ec \u6f97 \u78f5 \u7a08 \u8271 \u7647 \u5fd3 \u77f8 \u5058 \u6173 \u69a6 \u79c6 \u831b \u884e \u8d76 \u8fc0 \u9f66\r\n\t\r\n\u6746(\u687f) \u7647(\u764e)\r\n\uac08\r\n\t\r\n\u6e34\r\n\t\r\n\u845b \u4e6b \u559d \u66f7 \u78a3 \u7aed \u8910 \u874e \u97a8 \u5676 \u696c \u79f8 \u7faf \u880d\r\n\t\r\n\u2013\r\n\uac10\r\n\t\r\n\u7518 \u6e1b \u611f \u6562 \u76e3 \u9451\r\n\t\r\n\u52d8 \u582a \u77b0 \u574e \u5d4c \u61be \u6221 \u67d1 \u6a44 \u75b3 \u7d3a \u90af \u9f95 \u73aa \u5769 \u57f3 \u5d41 \u5f07 \u61a8 \u64bc \u6b3f \u6b5b \u6cd4 \u6de6 \u6f89 \u77d9 \u8f57 \u9163 
\u9e7b\r\n\t\r\n\u9451(\u9452)\r\n\uac11\r\n\t\r\n\u7532\r\n\t\r\n\u9240 \u5323 \u5cac \u80db \u9598\r\n\t\r\n\u2013\r\n\uac1c\r\n\t\r\n\u6539 \u7686 \u500b \u958b \u4ecb \u6168 \u69ea \u84cb\r\n\t\r\n\u4ef7 \u51f1 \u6137 \u6f11 \u584f \u613e \u75a5 \u82a5 \u8c48 \u93a7 \u73a0 \u5274 \u5303 \u63e9 \u69e9 \u78d5 \u95d3\r\n\t\r\n\u500b(\u7b87) \u84cb(\u76d6)\r\n\uac1d\r\n\t\r\n\u5ba2\r\n\t\r\n\u5580\r\n\t\r\n\u2013\r\n\uac31\r\n\t\r\n\u66f4\r\n\t\r\n\u5751 \u7cb3 \u7fb9 \u785c \u8ce1 \u93d7\r\n\t\r\n\u2013\r\n\uac39\r\n\t\r\n\u2013\r\n\t\r\n\u91b5\r\n\t\r\n\u2013\r\n\uac70\r\n\t\r\n\u53bb \u5de8 \u5c45 \u8eca \u64e7 \u8ddd \u62d2 \u64da\r\n\t\r\n\u6e20 \u907d \u9245 \u70ac \u5028 \u636e \u795b \u8e1e \u92f8 \u99cf \u547f \u661b \u79ec \u7b65 \u7c67 \u80e0 \u8152 \u82e3 \u8392 \u8556 \u8627 \u88aa \u88fe\r\n\t\r\n\u2013\r\n\uac74\r\n\t\r\n\u5efa \u4e7e \u4ef6 \u5065\r\n\t\r\n\u5dfe \u8654 \u6957 \u9375 \u6106 \u8171 \u8e47 \u9a2b \u6434 \u6e55 \u8e3a \u63f5 \u728d \u7777 \u8930 \u8b07 \u97ac\r\n\t\r\n\u5efa(\u4896) \u4e7e(\u6f27 \u4e79)\r\n\uac78\r\n\t\r\n\u5091 \u4e5e\r\n\t\r\n\u6840 \u4e6c \u6705 \u69a4\r\n\t\r\n\u5091(\u6770)\r\n\uac80\r\n\t\r\n\u5109 \u528d \u6aa2\r\n\t\r\n\u77bc \u9210 \u9ed4 \u64bf \u82a1\r\n\t\r\n\u528d(\u5292)\r\n\uac81\r\n\t\r\n\u2013\r\n\t\r\n\u52ab \u602f \u8ff2 \u5226 \u5227\r\n\t\r\n\u2013\r\n\uac8c\r\n\t\r\n\u2013\r\n\t\r\n\u63ed \u5048 \u61a9\r\n\t\r\n\u2013\r\n\uaca9\r\n\t\r\n\u683c \u64ca \u6fc0 \u9694\r\n\t\r\n\u6a84 \u8188 \u89a1 \u630c \u6bc4 \u95c3 \u9abc \u9b32 \u9d03\r\n\t\r\n\u2013\r\n\uacac\r\n\t\r\n\u72ac \u898b \u5805 \u80a9 \u7d79 \u9063 \u727d\r\n\t\r\n\u9d51 \u7504 \u7e6d \u8b74 \u72f7 \u754e \u7b67 \u7e33 \u7e7e \u7f82 \u8832 \u9c39\r\n\t\r\n\u2013\r\n\uacb0\r\n\t\r\n\u6c7a \u7d50 \u6f54 \u7f3a\r\n\t\r\n\u8a23 \u6289 \u36c3 \u7106 \u8ffc \u73a6 \u9365 \u89d6 \u95cb\r\n\t\r\n\u6f54(\u34d7 \u6d2f) \u9365(\u493f)\r\n\uacb8\r\n\t\r\n\u517c \u8b19\r\n\t\r\n\u938c \u614a \u7b9d \u9257 \u55db \u69cf \u5094 \u5c92 \u62d1 \u6b49 
\u7e11 \u84b9 \u9eda \u9f38 \u5d70\r\n\t\r\n\u2013\r\n\uad18\r\n\t\r\n\u639b\r\n\t\r\n\u5366 \u7f6b \u54bc \u6302 \u7f63 \u8a7f\r\n\t\r\n\u2013\r\n\uad34\r\n\t\r\n\u584a \u6127 \u602a \u58de\r\n\t\r\n\u4e56 \u5080 \u62d0 \u69d0 \u9b41 \u5abf \u5ee5 \u7470 \u749d \u84af \u8958\r\n\t\r\n\u2013\r\n\uad35\r\n\t\r\n\u2013\r\n\t\r\n\u9998\r\n\t\r\n\u2013\r\n\uad49\r\n\t\r\n\u2013\r\n\t\r\n\u5b8f \u7d18 \u80b1 \u8f5f \u6d64 \u89e5 \u8a07 \u958e\r\n\t\r\n\u2013\r\n\uad6d\r\n\t\r\n\u570b \u83ca \u5c40\r\n\t\r\n\u97a0 \u97ab \u9eb4 \ud856\udf97 \u530a \u63ac \u8dfc \u9eaf \u8d9c\r\n\t\r\n\u570b(\u56fd)\r\n\uad70\r\n\t\r\n\u541b \u90e1 \u8ecd \u7fa4\r\n\t\r\n\u7a98 \u88d9 \u6343 \u687e \u76b8\r\n\t\r\n\u2013\r\n\uad74\r\n\t\r\n\u5c48\r\n\t\r\n\u7a9f \u5800 \u6398 \u5014 \u5d1b \u6dc8 \u8a58\r\n\t\r\n\u2013\r\n\uad81\r\n\t\r\n\u5f13 \u5bae \u7aae\r\n\t\r\n\u8eac \u7a79 \u828e \u8eb3\r\n\t\r\n\u2013\r\n\uad8c\r\n\t\r\n\u5238 \u6b0a \u52f8 \u5377 \u62f3\r\n\t\r\n\u5708 \u7737 \u5026 \u6372 \u6dc3 \u52cc \u60d3 \u68ec \u7760 \u7da3 \u8737\r\n\t\r\n\u6b0a(\u6a29)\r\n\uad90\r\n\t\r\n\u53a5\r\n\t\r\n\u95d5 \u7357 \u8568 \u8e76\r\n\t\r\n\u2013\r\n\uada4\r\n\t\r\n\u8ecc\r\n\t\r\n\u673a \u6ac3 \u6f70 \u8a6d \u994b \u4f79 \u51e0 \u5282 \u5331 \u6192 \u6485 \u6a3b \u6c3f \u7c0b \u7e62 \u8dea \u95e0 \u993d \u9e82\r\n\t\r\n\u2013\r\n\uadc0\r\n\t\r\n\u8cb4 \u6b78 \u9b3c\r\n\t\r\n\u53e5 \u6677 \u4925 \u9f9c\r\n\t\r\n\u9f9c(\u9f9c)\r\n\uaddc\r\n\t\r\n\u53eb \u898f \u7cfe\r\n\t\r\n\u572d \u594e \u73ea \u63c6 \u9035 \u7aba \u8475 \u69fb \u7845 \u7ac5 \u8d73 \u95a8 \u90bd \u5ae2 \u6e40 \u8325 \u7143 \u5232 \u5b00 \u5dcb \u668c \u694f \u6a1b \u6f59 \u777d \u866f \u8dec \u95da \u980d \u9997 \u9a24\r\n\t\r\n\u7cfe(\u7cfa)\r\n\uade0\r\n\t\r\n\u5747 \u83cc\r\n\t\r\n\u7547 \u921e \u7b60 \u52fb \u9f9c \u89a0 \u56f7 \u9e8f\r\n\t\r\n\u52fb(\u5300) \u9f9c(\u9f9c)\r\n\uade4\r\n\t\r\n\u2013\r\n\t\r\n\u6a58\r\n\t\r\n\u2013\r\n\uadf9\r\n\t\r\n\u6975 \u514b \u5287\r\n\t\r\n\u524b \u9699 \u621f \u68d8 \u4e9f 
\u5c05 \u5c50 \u90c4\r\n\t\r\n\u2013\r\n\uadfc\r\n\t\r\n\u8fd1 \u52e4 \u6839 \u65a4 \u50c5 \u8b39\r\n\t\r\n\u5890 \u6f0c \u69ff \u747e \u5ae4 \u7b4b \u52a4 \u61c3 \u82b9 \u83eb \u89b2 \u9949 \u5df9 \u5ed1 \u89d4 \u8ddf \u91ff \u9773 \u5807\r\n\t\r\n\u2013\r\n\uae00\r\n\t\r\n\u2013\r\n\t\r\n\u5951 \u3515\r\n\t\r\n\u2013\r\n\uae08\r\n\t\r\n\u91d1 \u4eca \u7981 \u9326 \u79bd \u7434\r\n\t\r\n\u887e \u895f \u6611 \u5997 \u64d2 \u6a8e \u82a9 \u887f \u552b \u5664 \u5d94 \u7b12 \u9ec5\r\n\t\r\n\u2013\r\n\uae09\r\n\t\r\n\u53ca \u7d66 \u6025 \u7d1a\r\n\t\r\n\u6c72 \u4f0b \u6271 \u573e \u5c8c \u7680 \u790f \u7b08 \u82a8\r\n\t\r\n\u2013\r\n\uae0d\r\n\t\r\n\u80af\r\n\t\r\n\u4e98 \u5162 \u77dc \u6b91\r\n\t\r\n\u4e98(\u4e99)\r\n\uae34\r\n\t\r\n\u7dca\r\n\t\r\n\u2013\r\n\t\r\n\u2013\r\n\uae38\r\n\t\r\n\u5409\r\n\t\r\n\u4f76 \u6854 \u59de \u62ee \u86e3\r\n\t\r\n\u2013\r\n\uae40\r\n\t\r\n\u2013\r\n\t\r\n\u91d1\r\n\t\r\n\u2013\r\n\ub07d\r\n\t\r\n\u2013\r\n\t\r\n\u55ab\r\n\t\r\n\u2013", "target": "\"\uac00\":[\"\u5bb6\", \"\u4f73\", \"\u8857\", \"\u53ef\", \"\u6b4c\", \"\u52a0\", \"\u50f9\", \"\u5047\", \"\u67b6\", \"\u6687\", \"\u5609\", \"\u5ac1\", \"\u7a3c\", \"\u8cc8\", \"\u99d5\", \"\u4f3d\", \"\u8fe6\", \"\u67ef\", \"\u5475\", \"\u54e5\", \"\u67b7\", \"\u73c2\", \"\u75c2\", \"\u82db\", \"\u8304\", \"\u8888\", \"\u8a36\", \"\u8dcf\", \"\u8efb\", \"\u54ff\", \"\u560f\", \"\u8238\", \"\u73c8\", \"\u5777\", \"\u659d\", \"\u698e\", \"\u6a9f\", \"\u7b33\", \"\u801e\", \"\u846d\", \"\u8b0c\", \"\u6cc7\"],\r\n\"\uac01\":[\"\u5404\", \"\u89d2\", \"\u811a\", \"\u95a3\", \"\u5374\", \"\u89ba\", \"\u523b\", \"\u73cf\", \"\u606a\", \"\u6bbc\", \"\u6128\", \"\u537b\", \"\u54af\", \"\u57c6\", \"\u6409\", \"\u64f1\", \"\u6877\", \"\u6164\"],\r\n\"\uac04\":[\"\u5e72\", \"\u9593\", \"\u770b\", \"\u520a\", \"\u809d\", \"\u5e79\", \"\u7c21\", \"\u59e6\", \"\u61c7\", \"\u826e\", \"\u4f83\", \"\u6746\", \"\u7395\", \"\u7aff\", \"\u63c0\", \"\u8aeb\", \"\u58be\", \"\u681e\", \"\u5978\", \"\u67ec\", 
\"\u6f97\", \"\u78f5\", \"\u7a08\", \"\u8271\", \"\u7647\", \"\u5fd3\", \"\u77f8\", \"\u5058\", \"\u6173\", \"\u69a6\", \"\u79c6\", \"\u831b\", \"\u884e\", \"\u8d76\", \"\u8fc0\", \"\u9f66\", \"\u687f\", \"\u764e\"],\r\n\"\uac08\":[\"\u6e34\", \"\u845b\", \"\u4e6b\", \"\u559d\", \"\u66f7\", \"\u78a3\", \"\u7aed\", \"\u8910\", \"\u874e\", \"\u97a8\", \"\u5676\", \"\u696c\", \"\u79f8\", \"\u7faf\", \"\u880d\"],\r\n\"\uac10\":[\"\u7518\", \"\u6e1b\", \"\u611f\", \"\u6562\", \"\u76e3\", \"\u9451\", \"\u52d8\", \"\u582a\", \"\u77b0\", \"\u574e\", \"\u5d4c\", \"\u61be\", \"\u6221\", \"\u67d1\", \"\u6a44\", \"\u75b3\", \"\u7d3a\", \"\u90af\", \"\u9f95\", \"\u73aa\", \"\u5769\", \"\u57f3\", \"\u5d41\", \"\u5f07\", \"\u61a8\", \"\u64bc\", \"\u6b3f\", \"\u6b5b\", \"\u6cd4\", \"\u6de6\", \"\u6f89\", \"\u77d9\", \"\u8f57\", \"\u9163\", \"\u9e7b\", \"\u9452\"],\r\n\"\uac11\":[\"\u7532\", \"\u9240\", \"\u5323\", \"\u5cac\", \"\u80db\", \"\u9598\"],\r\n\"\uac1c\":[\"\u6539\", \"\u7686\", \"\u500b\", \"\u958b\", \"\u4ecb\", \"\u6168\", \"\u69ea\", \"\u84cb\", \"\u4ef7\", \"\u51f1\", \"\u6137\", \"\u6f11\", \"\u584f\", \"\u613e\", \"\u75a5\", \"\u82a5\", \"\u8c48\", \"\u93a7\", \"\u73a0\", \"\u5274\", \"\u5303\", \"\u63e9\", \"\u69e9\", \"\u78d5\", \"\u95d3\", \"\u7b87\", \"\u76d6\"],\r\n\"\uac1d\":[\"\u5ba2\", \"\u5580\"],\r\n\"\uac31\":[\"\u66f4\", \"\u5751\", \"\u7cb3\", \"\u7fb9\", \"\u785c\", \"\u8ce1\", \"\u93d7\"],\r\n\"\uac39\":[\"\u91b5\"],\r\n\"\uac70\":[\"\u53bb\", \"\u5de8\", \"\u5c45\", \"\u8eca\", \"\u64e7\", \"\u8ddd\", \"\u62d2\", \"\u64da\", \"\u6e20\", \"\u907d\", \"\u9245\", \"\u70ac\", \"\u5028\", \"\u636e\", \"\u795b\", \"\u8e1e\", \"\u92f8\", \"\u99cf\", \"\u547f\", \"\u661b\", \"\u79ec\", \"\u7b65\", \"\u7c67\", \"\u80e0\", \"\u8152\", \"\u82e3\", \"\u8392\", \"\u8556\", \"\u8627\", \"\u88aa\", \"\u88fe\"],\r\n\"\uac74\":[\"\u5efa\", \"\u4e7e\", \"\u4ef6\", \"\u5065\", \"\u5dfe\", \"\u8654\", \"\u6957\", \"\u9375\", \"\u6106\", \"\u8171\", \"\u8e47\", 
\"\u9a2b\", \"\u6434\", \"\u6e55\", \"\u8e3a\", \"\u63f5\", \"\u728d\", \"\u7777\", \"\u8930\", \"\u8b07\", \"\u97ac\", \"\u4896\", \"\u6f27\", \"\u4e79\"],\r\n\"\uac78\":[\"\u5091\", \"\u4e5e\", \"\u6840\", \"\u4e6c\", \"\u6705\", \"\u69a4\", \"\u6770\"],\r\n\"\uac80\":[\"\u5109\", \"\u528d\", \"\u6aa2\", \"\u77bc\", \"\u9210\", \"\u9ed4\", \"\u64bf\", \"\u82a1\", \"\u5292\"],\r\n\"\uac81\":[\"\u52ab\", \"\u602f\", \"\u8ff2\", \"\u5226\", \"\u5227\"],\r\n\"\uac8c\":[\"\u63ed\", \"\u5048\", \"\u61a9\"],\r\n\"\uaca9\":[\"\u683c\", \"\u64ca\", \"\u6fc0\", \"\u9694\", \"\u6a84\", \"\u8188\", \"\u89a1\", \"\u630c\", \"\u6bc4\", \"\u95c3\", \"\u9abc\", \"\u9b32\", \"\u9d03\"],\r\n\"\uacac\":[\"\u72ac\", \"\u898b\", \"\u5805\", \"\u80a9\", \"\u7d79\", \"\u9063\", \"\u727d\", \"\u9d51\", \"\u7504\", \"\u7e6d\", \"\u8b74\", \"\u72f7\", \"\u754e\", \"\u7b67\", \"\u7e33\", \"\u7e7e\", \"\u7f82\", \"\u8832\", \"\u9c39\"],\r\n\"\uacb0\":[\"\u6c7a\", \"\u7d50\", \"\u6f54\", \"\u7f3a\", \"\u8a23\", \"\u6289\", \"\u36c3\", \"\u7106\", \"\u8ffc\", \"\u73a6\", \"\u9365\", \"\u89d6\", \"\u95cb\", \"\u34d7\", \"\u6d2f\", \"\u493f\"],\r\n\"\uacb8\":[\"\u517c\", \"\u8b19\", \"\u938c\", \"\u614a\", \"\u7b9d\", \"\u9257\", \"\u55db\", \"\u69cf\", \"\u5094\", \"\u5c92\", \"\u62d1\", \"\u6b49\", \"\u7e11\", \"\u84b9\", \"\u9eda\", \"\u9f38\", \"\u5d70\"],\r\n\"\uad18\":[\"\u639b\", \"\u5366\", \"\u7f6b\", \"\u54bc\", \"\u6302\", \"\u7f63\", \"\u8a7f\"],\r\n\"\uad34\":[\"\u584a\", \"\u6127\", \"\u602a\", \"\u58de\", \"\u4e56\", \"\u5080\", \"\u62d0\", \"\u69d0\", \"\u9b41\", \"\u5abf\", \"\u5ee5\", \"\u7470\", \"\u749d\", \"\u84af\", \"\u8958\"],\r\n\"\uad35\":[\"\u9998\"],\r\n\"\uad49\":[\"\u5b8f\", \"\u7d18\", \"\u80b1\", \"\u8f5f\", \"\u6d64\", \"\u89e5\", \"\u8a07\", \"\u958e\"],\r\n\"\uad6d\":[\"\u570b\", \"\u83ca\", \"\u5c40\", \"\u97a0\", \"\u97ab\", \"\u9eb4\", \"\ud856\udf97\", \"\u530a\", \"\u63ac\", \"\u8dfc\", \"\u9eaf\", \"\u8d9c\", \"\u56fd\"],\r\n\"\uad70\":[\"\u541b\", 
\"\u90e1\", \"\u8ecd\", \"\u7fa4\", \"\u7a98\", \"\u88d9\", \"\u6343\", \"\u687e\", \"\u76b8\"],\r\n\"\uad74\":[\"\u5c48\", \"\u7a9f\", \"\u5800\", \"\u6398\", \"\u5014\", \"\u5d1b\", \"\u6dc8\", \"\u8a58\"],\r\n\"\uad81\":[\"\u5f13\", \"\u5bae\", \"\u7aae\", \"\u8eac\", \"\u7a79\", \"\u828e\", \"\u8eb3\"],\r\n\"\uad8c\":[\"\u5238\", \"\u6b0a\", \"\u52f8\", \"\u5377\", \"\u62f3\", \"\u5708\", \"\u7737\", \"\u5026\", \"\u6372\", \"\u6dc3\", \"\u52cc\", \"\u60d3\", \"\u68ec\", \"\u7760\", \"\u7da3\", \"\u8737\", \"\u6a29\"],\r\n\"\uad90\":[\"\u53a5\", \"\u95d5\", \"\u7357\", \"\u8568\", \"\u8e76\"],\r\n\"\uada4\":[\"\u8ecc\", \"\u673a\", \"\u6ac3\", \"\u6f70\", \"\u8a6d\", \"\u994b\", \"\u4f79\", \"\u51e0\", \"\u5282\", \"\u5331\", \"\u6192\", \"\u6485\", \"\u6a3b\", \"\u6c3f\", \"\u7c0b\", \"\u7e62\", \"\u8dea\", \"\u95e0\", \"\u993d\", \"\u9e82\"],\r\n\"\uadc0\":[\"\u8cb4\", \"\u6b78\", \"\u9b3c\", \"\u53e5\", \"\u6677\", \"\u4925\", \"\u9f9c\", \"\u9f9c\"],\r\n\"\uaddc\":[\"\u53eb\", \"\u898f\", \"\u7cfe\", \"\u572d\", \"\u594e\", \"\u73ea\", \"\u63c6\", \"\u9035\", \"\u7aba\", \"\u8475\", \"\u69fb\", \"\u7845\", \"\u7ac5\", \"\u8d73\", \"\u95a8\", \"\u90bd\", \"\u5ae2\", \"\u6e40\", \"\u8325\", \"\u7143\", \"\u5232\", \"\u5b00\", \"\u5dcb\", \"\u668c\", \"\u694f\", \"\u6a1b\", \"\u6f59\", \"\u777d\", \"\u866f\", \"\u8dec\", \"\u95da\", \"\u980d\", \"\u9997\", \"\u9a24\", \"\u7cfa\"],\r\n\"\uade0\":[\"\u5747\", \"\u83cc\", \"\u7547\", \"\u921e\", \"\u7b60\", \"\u52fb\", \"\u9f9c\", \"\u89a0\", \"\u56f7\", \"\u9e8f\", \"\u5300\", \"\u9f9c\"],\r\n\"\uade4\":[\"\u6a58\"],\r\n\"\uadf9\":[\"\u6975\", \"\u514b\", \"\u5287\", \"\u524b\", \"\u9699\", \"\u621f\", \"\u68d8\", \"\u4e9f\", \"\u5c05\", \"\u5c50\", \"\u90c4\"],\r\n\"\uadfc\":[\"\u8fd1\", \"\u52e4\", \"\u6839\", \"\u65a4\", \"\u50c5\", \"\u8b39\", \"\u5890\", \"\u6f0c\", \"\u69ff\", \"\u747e\", \"\u5ae4\", \"\u7b4b\", \"\u52a4\", \"\u61c3\", \"\u82b9\", \"\u83eb\", \"\u89b2\", \"\u9949\", \"\u5df9\", \"\u5ed1\", 
\"\u89d4\", \"\u8ddf\", \"\u91ff\", \"\u9773\", \"\u5807\"],\r\n\"\uae00\":[\"\u5951\", \"\u3515\"],\r\n\"\uae08\":[\"\u91d1\", \"\u4eca\", \"\u7981\", \"\u9326\", \"\u79bd\", \"\u7434\", \"\u887e\", \"\u895f\", \"\u6611\", \"\u5997\", \"\u64d2\", \"\u6a8e\", \"\u82a9\", \"\u887f\", \"\u552b\", \"\u5664\", \"\u5d94\", \"\u7b12\", \"\u9ec5\"],\r\n\"\uae09\":[\"\u53ca\", \"\u7d66\", \"\u6025\", \"\u7d1a\", \"\u6c72\", \"\u4f0b\", \"\u6271\", \"\u573e\", \"\u5c8c\", \"\u7680\", \"\u790f\", \"\u7b08\", \"\u82a8\"],\r\n\"\uae0d\":[\"\u80af\", \"\u4e98\", \"\u5162\", \"\u77dc\", \"\u6b91\", \"\u4e99\"],\r\n\"\uae34\":[\"\u7dca\"],\r\n\"\uae38\":[\"\u5409\", \"\u4f76\", \"\u6854\", \"\u59de\", \"\u62ee\", \"\u86e3\"],\r\n\"\uae40\":[\"\u91d1\"],\r\n\"\ub07d\":[\"\u55ab\"]", "metadata": {"href": "/challenges/5fe14618f5abb00009be3ace", "title": "\ub300\ubc95\uc6d0 \uc778\uba85\uc6a9 \ud55c\uc790 \ubc14\uc778\ub529(Combine Hangul and Chinese characters)", "detail": "The character system of Chinese characters has caused difficulties in character encoding for a long time.\r\n\r\nThe content is to link the Korean name and the corresponding Chinese character.\r\n\r\ninput is simply pasted Chinese table, and you need to make it in json format.\r\n\r\nMost of them are omitted because there are too many Chinese characters, but I think it's better to work from \"\uac00\" to \"\ud790\".\r\n\r\nThere are a few rules here.\r\n\r\n1. One Hangul must correspond to multiple Chinese characters.\r\n2. In the case of \"a(b c)\", \"a\" is a redundant Chinese character, so you must remove \"a\" and use only \"b c\".\r\n3. There are some characters here that don't appear to be an encoding issue, but they should all be used.\r\n4. 
Should follow json format, Chinese characters are arrays.\r\n\r\n\r\n\ubb38\uc790\uc778\ucf54\ub529\uc758 \uc5ed\uc0ac\uc18d\uc5d0\uc11c \ud55c\ubb38\uc740 \uc5ec\ub7ec\uac00\uc9c0 \ub09c\uc81c\ub97c \ub9cc\ub4e4\uc5b4\uc654\uc2b5\ub2c8\ub2e4.\r\n\uc774 \ubb38\uc81c\ub294 \ub300\ubc95\uc6d0 \uc778\uba85\uc6a9 \ud55c\uc790\ub97c \ubc14\uc778\ub529\ud558\ub294 \uac83\uc785\ub2c8\ub2e4.\r\n\r\ninput\uc740 \ub2e8\uc21c\ud788 \uc778\uba85\uc6a9\ud55c\uc790 \ud45c\ub97c \ubd99\uc5ec\ub123\uae30 \ud55c \uac83\uc774\uba70, \uc774\uac83\uc744 json\ud615\uc2dd\uc73c\ub85c \ub9cc\ub4e4\uc5b4\uc57c\ud569\ub2c8\ub2e4.\r\n\r\n\ud55c\uc790\uac00 \ub108\ubb34 \ub9ce\uc740 \uad00\uacc4\ub85c \ub300\ubd80\ubd84\uc744 \uc0dd\ub7b5\ud558\uc600\uc73c\ub098, \uae30\ubcf8\uc801\uc73c\ub85c \"\uac00\"\ubd80\ud130 \"\ud790\"\uae4c\uc9c0 \ubaa8\ub450 \ub3d9\uc791\ud558\ub294 \uac83\uc774 \uc88b\ub2e4\uace0 \uc0dd\uac01\ud569\ub2c8\ub2e4.\r\n\r\n\uc5ec\uae30\uc5d0\ub294 \uba87\uac00\uc9c0 \uaddc\uce59\uc774 \uc788\uc2b5\ub2c8\ub2e4.\r\n\r\n1. \ud55c\uae00 \ud558\ub098\uc5d0\ub294 \uc5ec\ub7ec\uac1c\uc758 \ud55c\uc790\uac00 \ub300\uc751\ud569\ub2c8\ub2e4.\r\n2. \"a(b c)\"\uac19\uc740 \uacbd\uc6b0 \"a\"\ub294 \uc911\ubcf5\ub41c \ud55c\uc790\uc774\ubbc0\ub85c, \"a\"\ub97c \uc81c\uac70\ud558\uace0 \"b c\"\ub9cc\uc744 \uc0ac\uc6a9\ud574\uc57c\ud569\ub2c8\ub2e4.\r\n3. \uba87\uba87 \ud55c\uc790\ub294 \uc778\ucf54\ub529\uc758 \ubb38\uc81c\ub85c \ubcf4\uc774\uc9c0 \uc54a\uc744 \uc218 \uc788\uc2b5\ub2c8\ub2e4. \uadf8\ub7ec\ub098 \ubaa8\ub450 \uc0ac\uc6a9\ud574\uc57c\ud569\ub2c8\ub2e4.\r\n4. 
json\ud615\uc2dd\uc744 \ub530\ub974\uba70, \ud55c\uc790\ub294 \ubc30\uc5f4\uc785\ub2c8\ub2e4.", "challenge_hash": "5fe14618f5abb00009be3ace"}, "id": "5fe14618f5abb00009be3ace"} +{"input": "enum Coin {\r\n\tPenny,\r\n\tNickel,\r\n\tDime,\r\n\tQuarter,\r\n}\r\n", "target": "enum Coin {\r\n\tPenny,\r\n\tNickel,\r\n\tDime,\r\n\tQuarter,\r\n}\r\n\r\nfn value_in_cents(coin: Coin) -> u32 {\r\n\tmatch coin {\r\n\t\tCoin::Penny => 1,\r\n\t\tCoin::Nickel => 5,\r\n\t\tCoin::Dime => 10,\r\n\t\tCoin::Quarter => 25,\r\n\t}\r\n}\r\n", "metadata": {"href": "/challenges/5fe050e596c8f7000cda4ddc", "title": "RUST match with enum for Coin", "detail": "RUST match with enum for Coin\r\n\r\nrusty practice on vimgolf", "challenge_hash": "5fe050e596c8f7000cda4ddc"}, "id": "5fe050e596c8f7000cda4ddc"} +{"input": "Pixel *operator[](Position p) {\n return &pixels[p.getc()][p.getr()];\n}\n", "target": "Pixel *operator[](Position p) { return &pixels[p.getc()][p.getr()]; }\n", "metadata": {"href": "/challenges/5fd623f3cf017601b2b63cbc", "title": "Making 3 line function a one liner.", "detail": "Simple challenge for frequent action while programming.", "challenge_hash": "5fd623f3cf017601b2b63cbc"}, "id": "5fd623f3cf017601b2b63cbc"} +{"input": "std::vectortask 1
\ndescription here ....
\ntask 2
\nsome other description here ....
\n| task 1 | \ndescription here .... | \n
| task 2 | \nsome other description here .... | \n
| 0 | 1 | 2 |
| 3 | 4 | 5 |
| 6 | 7 | 8 |
| 9 | a | b |
| c | d | e |
| f |
| \n | First Name | \nLast Name | \nAddress | \nCity | \nState | \nZip | \nCountry | \nPhone | \n\n\n
|---|
| \n\t\t | First Name | \n\t\tLast Name | \n\t\tAddress | \n\t\tCity | \n\t\tState | \n\t\tZip | \n\t\tCountry | \n\t\tPhone | \n\t\t\n\t\n
|---|
112, Hex 70, Octal 160\n
112, Hex 70, Octal 160\n \nVim is an advanced text editor that seeks to provide the power of\nthe de-facto Unix editor 'Vi', with a more complete feature set.\nIt's useful whether you're already using vi or using a different editor. Users of Vim 5 and 6 should\nconsider upgrading to Vim 7. The main advantages of Vim 6 compared to Vim 5 \ncan be found on this page.\n \nVim is a highly configurable text editor built to enable efficient text\nediting. It is an improved version of the vi editor distributed with\nmost UNIX systems. \n \nVim is often called a \"programmer's editor,\" and so useful for\nprogramming that many consider it an entire IDE. It's not just for programmers,\nthough. Vim is perfect for all kinds of text editing, from composing\nemail to editing configuration files.\n \nDespite what the above comic suggests, Vim can be configured to work in a very\nsimple (Notepad-like) way, called evim or Easy Vim.\n \nVim isn't an editor designed to hold its users' hands. It is a tool,\nthe use of which must be learned.\n \nVim isn't a word processor. Although it can display text with various\nforms of highlighting and formatting, it isn't there to provide WYSIWYG\nediting of typeset documents. (It is great for editing TeX, though.)\n \nVim is charityware. Its license is GPL-compatible, so it's\ndistributed freely, but we ask that if you find it useful you make a\ndonation to help children in Uganda through the\nICCF. The full license text can be\nfound in the documentation.\nMuch more information about charityware on\nCharityware.info.\nWhat Vim Can Do
\nA General Overview
\n \n
\nCopyright (c) 2007 Laurent Gregoire\n \nWhat Is Vim?
\nWhat Vim Is Not?
\nVim's License
\n and tags.
+
+IMPORTANT: Any output you want to see MUST be printed to standard output using the print() function.
+
+Example of a correct call:
+
+import numpy as np
+# Your code here
+print(f"The result is: {np.mean([1,2,3])}")
+
+ and tags.
+
+IMPORTANT: Any output you want to see MUST be printed to standard output using the print() function.
+
+Example of a correct call:
+
+import numpy as np
+# Your code here
+print(f"The result is: {np.mean([1,2,3])}")
+
+" in content:
+ pass
+ elif '"name":' in content:
+ try:
+ tool_text = content.split("")[1]
+ .split("")[0]
+ .strip()
+ )
+ result = await self.execute_python(code_raw)
+ if isinstance(result, str) and result.startswith(
+ (
+ "Python execution error:",
+ "PythonInterpreter tool not available",
+ "PythonInterpreter tool is not callable",
+ )
+ ):
+ tool_error = True
+ except Exception:
+ result = (
+ "[Python Interpreter Error]: Python code formatting error."
+ )
+ tool_error = True
+ else:
+ try:
+ # Parse JSON tool call
+ tool_call = json5.loads(tool_call_text)
+ tool_name = tool_call.get("name", "")
+ tool_args = tool_call.get("arguments", {})
+ if tool_name == "crop_and_search":
+ tool_args["image_id"] = image_path
+ result = await self.custom_call_tool(tool_name, tool_args)
+ except Exception:
+ result = "[Json Parse Error]: Tool call is not a valid JSON."
+ tool_error = True
+
+ if tool_error:
+ assistant_message["step_error"] = True
+
+ # Add tool response in ReAct format
+ tool_response = f" XML tags, such as:\n\nHere is the code.\n\n XML tags. Remember to use print() statements for any output you want to see.",
+ }
+ },
+ "required": ["code"],
+ },
+ )
+ self.timeout = 30
+
+    async def call(self, code: str, timeout: int | None = None, **kwargs) -> str:
+        """Run a Python snippet with reduced builtins and report what it produced.
+
+        The snippet is screened against a substring blocklist, then executed on
+        ``self.executor`` while ``sys.stdout`` and ``sys.stderr`` are pointed at
+        in-memory buffers.
+
+        Args:
+            code: Python source to execute. ``None`` is treated as an empty snippet.
+            timeout: Seconds to wait for the result; ``self.timeout`` is used when falsy.
+            **kwargs: Accepted but not read in this method.
+
+        Returns:
+            A tagged string. One of: a "[Security Error]" message, a "[Timeout]"
+            message, "[Error]" followed by captured stderr, "[Output]" followed by
+            captured stdout, "[Variables]" followed by the snippet's public names,
+            or "[Success] Code executed (no output)".
+        """
+        timeout = timeout or self.timeout
+
+        # Normalise a missing snippet so the string operations below never see None.
+        code = code or ""
+        code_len = len(code)
+
+        def log_result(
+            status: str,
+            message: str,
+            extra: str | None = None,
+            *,
+            level: str | None = None,
+        ) -> None:
+            """Send one tool-event log entry describing ``message``."""
+            preview = shorten_for_log(message)
+            details = f"code_len={code_len} result_len={len(message)} preview={json.dumps(preview, ensure_ascii=False)}"
+            if extra:
+                details += f" {extra}"
+            log_tool_event(
+                source="PythonInterpreter",
+                status=status,
+                message=details,
+                level=level or "INFO",
+            )
+
+        # Security checks - check for dangerous imports/operations
+        # NOTE(review): this substring blocklist and the safe_import allowlist below
+        # disagree (e.g. "import sys" is rejected here while "sys" is importable there),
+        # and exec() on generated code is not a real sandbox -- review before exposing
+        # this tool to untrusted input.
+        dangerous_patterns = [
+            "import os",
+            "import subprocess",
+            "import sys",
+            "from os import",
+            "from subprocess import",
+            "from sys import",
+            "exec(",
+            "eval(",
+            "compile(",
+            "open(",
+            "file(",
+        ]
+
+        code_lower = code.lower()
+        for pattern in dangerous_patterns:
+            if pattern in code_lower:
+                result = f"[Security Error] '{pattern}' not allowed for safety reasons"
+                log_result(
+                    "SecurityBlocked",
+                    result,
+                    extra=f"pattern={json.dumps(pattern, ensure_ascii=False)}",
+                    level="WARNING",
+                )
+                return result
+
+        import io
+        import sys
+
+        # Modules made available to the snippet without an explicit import
+        allowed_modules = {
+            "math": __import__("math"),
+            "datetime": __import__("datetime"),
+            "json": __import__("json"),
+            "random": __import__("random"),
+            "re": __import__("re"),
+            "collections": __import__("collections"),
+            "itertools": __import__("itertools"),
+            "statistics": __import__("statistics"),
+        }
+
+        # Add numpy/pandas if available
+        try:
+            import numpy as np
+
+            allowed_modules["numpy"] = np
+            allowed_modules["np"] = np
+        except ImportError:
+            pass
+
+        try:
+            import pandas as pd
+
+            allowed_modules["pandas"] = pd
+            allowed_modules["pd"] = pd
+        except ImportError:
+            pass
+
+        # Restricted builtins with safe import capability
+        def safe_import(name, *args, **kwargs):
+            """Allow importing only safe modules."""
+            safe_modules = [
+                "math",
+                "datetime",
+                "json",
+                "random",
+                "re",
+                "collections",
+                "itertools",
+                "statistics",
+                "numpy",
+                "pandas",
+                "scipy",
+                "scipy.linalg",  # Add scipy submodules
+                "scipy.optimize",
+                "scipy.signal",
+                "scipy.special",
+                "matplotlib",
+                "matplotlib.pyplot",
+                "urllib.request",
+                "requests",
+                "sys",
+            ]
+            # Check if the module or its parent is allowed
+            if name in safe_modules or any(
+                name.startswith(m + ".") for m in safe_modules
+            ):
+                return __import__(name, *args, **kwargs)
+            else:
+                raise ImportError(f"Module '{name}' is not allowed for safety reasons")
+
+        restricted_builtins = {
+            "abs": abs,
+            "all": all,
+            "any": any,
+            "bin": bin,
+            "bool": bool,
+            "chr": chr,
+            "dict": dict,
+            "enumerate": enumerate,
+            "filter": filter,
+            "float": float,
+            "hex": hex,
+            "int": int,
+            "len": len,
+            "list": list,
+            "map": map,
+            "max": max,
+            "min": min,
+            "oct": oct,
+            "ord": ord,
+            "pow": pow,
+            "print": print,
+            "range": range,
+            "reversed": reversed,
+            "round": round,
+            "set": set,
+            "slice": slice,
+            "sorted": sorted,
+            "str": str,
+            "sum": sum,
+            "tuple": tuple,
+            "type": type,
+            "zip": zip,
+            "__import__": safe_import,  # Allow safe imports
+            # Add exception classes for proper error handling
+            "Exception": Exception,
+            "ImportError": ImportError,
+            "ValueError": ValueError,
+            "TypeError": TypeError,
+            "KeyError": KeyError,
+            "IndexError": IndexError,
+            "AttributeError": AttributeError,
+        }
+
+        # One dict serves as both globals and locals. exec() with two distinct
+        # mappings runs the snippet like a class body, so functions and
+        # comprehensions it defines cannot see its own top-level names.
+        exec_namespace = {"__builtins__": restricted_builtins}
+        exec_namespace.update(allowed_modules)
+
+        # Capture output
+        # NOTE(review): rebinding sys.stdout/sys.stderr affects the whole process, not
+        # just the worker, and stays in effect for as long as the snippet keeps running.
+        old_stdout = sys.stdout
+        old_stderr = sys.stderr
+        stdout_buffer = io.StringIO()
+        stderr_buffer = io.StringIO()
+
+        def execute_with_timeout():
+            """Execute the snippet with output redirected; errors go to the stderr buffer."""
+            try:
+                sys.stdout = stdout_buffer
+                sys.stderr = stderr_buffer
+                exec(code, exec_namespace)
+                return True
+            except Exception as e:
+                stderr_buffer.write(f"Execution error: {e}")
+                return False
+            finally:
+                sys.stdout = old_stdout
+                sys.stderr = old_stderr
+
+        # Execute with timeout
+        loop = asyncio.get_running_loop()
+        future = loop.run_in_executor(self.executor, execute_with_timeout)
+        try:
+            await asyncio.wait_for(future, timeout=timeout)
+        except asyncio.TimeoutError:
+            result = f"[Timeout] Execution exceeded {timeout}s"
+            log_result("Timeout", result, level="WARNING")
+            return result
+        except Exception as exc:  # noqa: BLE001
+            result = f"[Error] Unexpected execution error: {exc}"
+            log_result("UnexpectedError", result, level="ERROR")
+            return result
+
+        stdout_content = stdout_buffer.getvalue()
+        stderr_content = stderr_buffer.getvalue()
+
+        # Anything on stderr wins over stdout; stdout wins over the variable dump.
+        if stderr_content:
+            result = f"[Error]\n{stderr_content}"
+            log_result("Error", result, level="ERROR")
+            return result
+        elif stdout_content:
+            cleaned_output = stdout_content.rstrip()
+            result = f"[Output]\n{cleaned_output}"
+            return result
+        else:
+            # Report names the snippet bound, hiding underscore-prefixed entries
+            # (including __builtins__) and the pre-imported module names.
+            meaningful_vars = {
+                k: v
+                for k, v in exec_namespace.items()
+                if not k.startswith("_") and k not in allowed_modules
+            }
+            if meaningful_vars:
+                result = f"[Variables]\n{meaningful_vars}"
+                return result
+            else:
+                result = "[Success] Code executed (no output)"
+                return result
diff --git a/vendor/rllm/vision_deepresearch_async_workflow/tools/search_tool.py b/vendor/rllm/vision_deepresearch_async_workflow/tools/search_tool.py
new file mode 100644
index 0000000000000000000000000000000000000000..c43b1d491431d2693f16335f37bf602c17a9eb7b
--- /dev/null
+++ b/vendor/rllm/vision_deepresearch_async_workflow/tools/search_tool.py
@@ -0,0 +1,304 @@
+import asyncio
+import json
+import os
+
+from vision_deepresearch_async_workflow.tools.shared import (
+ DeepResearchTool,
+ get_cache_async,
+ get_cache_key,
+ log_search,
+ log_tool_event,
+ run_with_retries_async,
+ set_cache_async,
+)
+
+
+class SearchTool(DeepResearchTool):
+    """Batched web search tool backed by the Zhipu or Serp (serper.dev) HTTP API.
+
+    ``call`` prefers Zhipu when ``ZHIPU_API_KEY`` is set, falls back to Serp
+    when ``SERP_API_KEY`` is set, and otherwise returns a configuration hint.
+    Each query is looked up in the ``text_search`` cache table before any HTTP
+    request is made, and successful (non-empty) results are stored there.
+    """
+
+    MAX_URLS = 10  # Maximum number of results rendered per query.
+
+    def __init__(self):
+        super().__init__(
+            name="search",
+            description="Performs batched web searches: supply an array 'query'; the tool retrieves the top 10 results for each query in one call.",
+            parameters={
+                "type": "object",
+                "properties": {
+                    "query": {
+                        "type": "array",
+                        "items": {"type": "string"},
+                        "description": "Array of query strings. Include multiple complementary search queries in a single call.",
+                    },
+                },
+                "required": ["query"],
+            },
+        )
+        self.zhipu_api_key = os.getenv("ZHIPU_API_KEY")
+        self.serp_api_key = os.getenv("SERP_API_KEY")
+        # NOTE: both endpoints read the same TEXT_SEARCH_URL override, so
+        # setting it redirects whichever backend ``call`` ends up selecting.
+        self.zhipu_search_url = os.getenv(
+            "TEXT_SEARCH_URL", "https://search-svip.bigmodel.cn/api/paas/v4/search"
+        )
+        self.serp_search_url = os.getenv(
+            "TEXT_SEARCH_URL", "https://google.serper.dev/search"
+        )
+
+    def contains_chinese(self, text: str) -> bool:
+        """Check if text contains Chinese characters."""
+        return any("\u4e00" <= char <= "\u9fff" for char in text)
+
+    def _format_results(
+        self, q: str, items, url_key: str, snippet_key: str
+    ) -> tuple[str, bool]:
+        """Render raw result items as a numbered list.
+
+        Args:
+            q: The query the items belong to (used in the header line).
+            items: Result payload from the search API; anything that is not a
+                list is treated as "no results".
+            url_key: Item key holding the result URL.
+            snippet_key: Item key holding the result snippet.
+
+        Returns:
+            ``(content, has_results)`` where ``has_results`` is False when the
+            "No search results found" message was produced.
+        """
+        if not isinstance(items, list):
+            # Slicing a dict raises and slicing a str iterates characters;
+            # neither is a usable result list.
+            items = []
+
+        web_snippets: list[str] = []
+        for idx, item in enumerate(items[: self.MAX_URLS], 1):
+            # Non-dict items still get a numbered placeholder entry.
+            record = item if isinstance(item, dict) else {}
+            title = record.get("title", "Untitled")
+            url = record.get(url_key, "")
+            snippet = (record.get(snippet_key, "") or "").strip()
+            date = record.get("date")
+
+            entry = f"{idx}. [{title}]({url})"
+            if date:
+                entry += f"\n Date published: {date}"
+            if snippet:
+                entry += f"\n {snippet}"
+            web_snippets.append(entry)
+
+        if not web_snippets:
+            return f"No search results found for '{q}'", False
+
+        content = (
+            f"Search for '{q}' returned {len(web_snippets)} results:\n\n"
+            + "\n\n".join(web_snippets)
+        )
+        return content, True
+
+    async def _batched_search(
+        self,
+        query: str | list,
+        *,
+        provider: str,
+        endpoint: str,
+        headers: dict,
+        build_payload,
+        item_keys: tuple,
+        url_key: str,
+        snippet_key: str,
+    ) -> str:
+        """Run every query against one backend concurrently and join the results.
+
+        Args:
+            query: A single query string or a list of query strings.
+            provider: Backend label used in log lines ("Zhipu" or "Serp").
+            endpoint: URL the JSON request is POSTed to.
+            headers: HTTP headers (including credentials) for the request.
+            build_payload: Callable mapping a query string to the JSON body dict.
+            item_keys: Response keys tried in order; the first truthy value is
+                used as the result list.
+            url_key: Item key holding the result URL.
+            snippet_key: Item key holding the result snippet.
+
+        Returns:
+            Per-query result blocks joined by ``"\\n=======\\n"``; request and
+            HTTP errors are returned as text rather than raised.
+        """
+        try:
+            import requests
+        except ImportError:
+            return """[Search - Dependencies Required]
+
+Please install requests: pip install requests"""
+
+        queries = [query] if isinstance(query, str) else query
+        proxies = self._get_requests_proxies()
+
+        async def search_single_query(q: str) -> str:
+            # Check cache for individual query
+            cache_key = get_cache_key(q)
+            cached_result = await get_cache_async(
+                "text_search", cache_key, executor=self.executor
+            )
+            if cached_result:
+                return cached_result
+
+            # Encode explicitly: a ``str`` body containing non-Latin-1
+            # characters (ensure_ascii=False) is rejected by older
+            # requests/urllib3 stacks, while UTF-8 bytes work everywhere.
+            body = json.dumps(build_payload(q), ensure_ascii=False).encode("utf-8")
+
+            def send_request():
+                return requests.post(
+                    endpoint,
+                    headers=headers,
+                    data=body,
+                    timeout=300,
+                    proxies=proxies,
+                )
+
+            try:
+                resp = await run_with_retries_async(
+                    send_request, executor=self.executor
+                )
+            except Exception as exc:  # noqa: BLE001
+                error_message = f"Search request failed for '{q}': {exc}"
+                log_search(provider, "Exception", q, error=error_message)
+                return error_message
+
+            text = resp.text
+            try:
+                data_obj = resp.json()
+            except Exception:
+                # Non-JSON body: handled below as "no results".
+                data_obj = None
+
+            if resp.status_code != 200:
+                error_message = f"HTTP {resp.status_code}: {text}"
+                log_search(provider, "HTTPError", q, error=error_message)
+                return f"Search returned HTTP {resp.status_code} for '{q}'\n{text}"
+
+            items = []
+            if isinstance(data_obj, dict):
+                items = next(
+                    (data_obj[k] for k in item_keys if data_obj.get(k)), []
+                )
+
+            content, has_results = self._format_results(
+                q, items, url_key, snippet_key
+            )
+            # Cache only real results. The previous condition was inverted and
+            # stored nothing but "No search results" messages, pinning
+            # transient empty responses while never reusing successful ones.
+            if has_results:
+                await set_cache_async(
+                    "text_search", cache_key, q, content, executor=self.executor
+                )
+
+            return content
+
+        tasks = [search_single_query(q) for q in queries]
+        all_results: list[str] = await asyncio.gather(*tasks) if tasks else []
+
+        # join() returns the lone element for one result and "" for none.
+        return "\n=======\n".join(all_results)
+
+    async def _zhipu_search(self, query: str | list) -> str:
+        """Use Zhipu web_search API when key is available."""
+        return await self._batched_search(
+            query,
+            provider="Zhipu",
+            endpoint=self.zhipu_search_url,
+            headers={
+                # Zhipu PaaS expects raw token in Authorization; keep value as-is
+                "Authorization": self.zhipu_api_key,
+                "Content-Type": "application/json",
+            },
+            build_payload=lambda q: {
+                "q": q,
+                "search_engine": "search_prime",
+                "location": "us",
+                "query_rewrite": False,
+                "content_size": "high",
+            },
+            item_keys=("search_result", "data"),
+            url_key="url",
+            snippet_key="description",
+        )
+
+    async def _serp_search(self, query: str | list) -> str:
+        """Use Serp web search API when key is available."""
+        return await self._batched_search(
+            query,
+            provider="Serp",
+            endpoint=self.serp_search_url,
+            headers={
+                "X-API-KEY": self.serp_api_key,
+                "Content-Type": "application/json",
+            },
+            build_payload=lambda q: {
+                "q": q,
+                "hl": "en",
+                "gl": "us",
+            },
+            item_keys=("organic",),
+            url_key="link",
+            snippet_key="snippet",
+        )
+
+    async def call(self, query: str | list, **kwargs) -> str:
+        """
+        Search the web using Zhipu or Serp API.
+
+        Args:
+            query: Search query string or list of queries
+
+        Returns:
+            Formatted search results, or a configuration hint when neither
+            API key is set.
+        """
+        # Prefer Zhipu if key available
+        if self.zhipu_api_key:
+            return await self._zhipu_search(query)
+
+        if not self.serp_api_key:
+            message = f"""[Search - API Key Required]
+
+To enable real web search, configure one of these options:
+
+Option 1 - Zhipu:
+1. Add to .env: ZHIPU_API_KEY=your_key_here
+
+Option 2 - Serp:
+1. Get an API key from https://serper.dev
+2. Add to .env: SERP_API_KEY=your_key_here
+
+Placeholder results for '{query}'..."""
+
+            log_tool_event("Search", "APIKeyMissing", f"query={query}", level="ERROR")
+            log_search("Serp", "Config", str(query), error=message)
+            return message
+
+        return await self._serp_search(query)
diff --git a/vendor/rllm/vision_deepresearch_async_workflow/tools/shared.py b/vendor/rllm/vision_deepresearch_async_workflow/tools/shared.py
new file mode 100644
index 0000000000000000000000000000000000000000..84be811432dfb79a9afc6c05ccb5ff6304ee6179
--- /dev/null
+++ b/vendor/rllm/vision_deepresearch_async_workflow/tools/shared.py
@@ -0,0 +1,686 @@
+"""
+DeepResearch Tools - Shared utilities
+"""
+
+import asyncio
+import hashlib
+import json
+import os
+import random
+import re
+import time
+from abc import ABC, abstractmethod
+from collections.abc import Callable
+from concurrent.futures import ThreadPoolExecutor
+from typing import Optional, TypeVar
+
+# Async SQLite support (required dependency)
+import aiosqlite
+
+from rllm.tools.tool_base import Tool as RLLMTool
+
+# Return type of the zero-argument callables accepted by the retry and
+# executor helpers below.
+T = TypeVar("T")
+
+
+def _normalize_level(level: str | None) -> str:
+    """Return ``level`` upper-cased, or ``"INFO"`` when it is empty or None."""
+    return str(level).upper() if level else "INFO"
+
+
+def run_with_retries(func: Callable[[], T], attempts: int = 5, delay: float = 0.5) -> T:
+    """Call ``func`` until it succeeds or the attempts are used up.
+
+    Args:
+        func: Zero-argument callable to invoke.
+        attempts: Maximum number of calls; values below 1 still allow one call.
+        delay: Seconds to ``time.sleep`` between attempts (skipped when <= 0).
+
+    Returns:
+        Whatever ``func`` returns on its first successful call.
+
+    Raises:
+        Exception: The exception raised by the final failed attempt.
+    """
+
+    total_tries = max(attempts, 1)
+    failure: Exception | None = None
+    attempt = 0
+    while attempt < total_tries:
+        attempt += 1
+        try:
+            return func()
+        except Exception as exc:  # noqa: BLE001
+            failure = exc
+
+        if attempt >= attempts:
+            # Out of budget: neither log "will retry" nor sleep.
+            break
+
+        # Log retry attempt
+        log_tool_event(
+            "Retry",
+            "AttemptFailed",
+            f"Attempt {attempt}/{attempts} failed, will retry",
+            level="WARNING",
+        )
+        if delay > 0:
+            time.sleep(delay)
+
+    if failure is not None:
+        raise failure
+
+    raise RuntimeError("run_with_retries executed without performing any attempts")
+
+
+async def run_blocking(
+    func: Callable[[], T], executor: ThreadPoolExecutor | None = None
+) -> T:
+    """Await ``func`` executed via the running loop's ``run_in_executor``.
+
+    ``executor`` is passed through unchanged, so ``None`` selects the loop's
+    default executor.
+    """
+    return await asyncio.get_running_loop().run_in_executor(executor, func)
+
+
+async def run_with_retries_async(
+    func: Callable[[], T],
+    attempts: int = 20,
+    delay: float = 1,
+    executor: ThreadPoolExecutor | None = None,
+) -> T:
+    """Retry a blocking callable off the event loop until it succeeds.
+
+    Args:
+        func: Zero-argument blocking callable, run through ``run_blocking``.
+        attempts: Maximum number of calls; values below 1 still allow one call.
+        delay: Seconds to ``asyncio.sleep`` between attempts (skipped when <= 0).
+        executor: Executor handed to ``run_blocking``.
+
+    Returns:
+        The result of the first successful call.
+
+    Raises:
+        Exception: The exception raised by the final failed attempt.
+    """
+
+    total_tries = max(attempts, 1)
+    failure: Exception | None = None
+    attempt = 0
+    while attempt < total_tries:
+        attempt += 1
+        try:
+            return await run_blocking(func, executor=executor)
+        except Exception as exc:  # noqa: BLE001
+            failure = exc
+
+        if attempt >= attempts:
+            # Out of budget: neither log "will retry" nor sleep.
+            break
+
+        # Log retry attempt
+        log_tool_event(
+            "Retry",
+            "AttemptFailed",
+            f"Attempt {attempt}/{attempts} failed, will retry",
+            level="WARNING",
+        )
+        if delay > 0:
+            await asyncio.sleep(delay)
+
+    if failure is not None:
+        raise failure
+
+    raise RuntimeError(
+        "run_with_retries_async executed without performing any attempts"
+    )
+
+
+def shorten_for_log(text: str, limit: int = 200) -> str:
+    """Build a single-line preview of ``text`` for log output.
+
+    Newlines are replaced by a literal ``\\n``. Values longer than
+    ``2 * limit`` characters are reduced to their first and last ``limit``
+    characters joined by ``" ... "``. ``None`` and empty input yield ``""``;
+    non-string input is converted with ``str`` first.
+    """
+
+    if text is None:
+        return ""
+
+    as_text = text if isinstance(text, str) else str(text)
+    if as_text == "":
+        return ""
+
+    flattened = as_text.replace("\n", "\\n")
+    if len(flattened) > limit * 2:
+        return f"{flattened[:limit]} ... {flattened[-limit:]}"
+    return flattened
+
+
+def _select_extract_url(env_key: str = "EXTRACT_URL") -> str | None:
+    """Pick one endpoint at random from a comma-separated environment variable.
+
+    Blank entries are ignored. The chosen entry is returned unchanged when it
+    already ends in ``/v1/chat/completions`` (optionally followed by ``/``);
+    otherwise that path is appended. Returns None when the variable is unset
+    or holds no usable entry.
+    """
+    endpoints = [
+        part.strip() for part in os.getenv(env_key, "").split(",") if part.strip()
+    ]
+    if not endpoints:
+        return None
+
+    chosen = random.choice(endpoints)
+    if re.search(r"/v1/chat/completions/?$", chosen):
+        return chosen
+    return f"{chosen.rstrip('/')}/v1/chat/completions"
+
+
+# Cache database configuration
+# Every value is read from the environment once, at import time. A value that
+# cannot be parsed by int()/float() raises ValueError during import.
+CACHE_CONFIG = {
+    "db_path": os.getenv("CACHE_DB_PATH", "deepresearch_cache.db"),
+    "max_age_days": int(os.getenv("CACHE_MAX_AGE_DAYS", "30")),  # Cache validity period
+    "max_retries": int(os.getenv("CACHE_MAX_RETRIES", "3")),  # Max retry attempts
+    "base_retry_delay": float(
+        os.getenv("CACHE_RETRY_DELAY", "0.1")
+    ),  # Base delay in seconds
+    "busy_timeout": int(
+        os.getenv("CACHE_BUSY_TIMEOUT", "1000")
+    ),  # SQLite busy timeout in ms
+}
+
+# Global async database connection manager
+# Assigned lazily by get_async_cache_db(); the lock guards that first assignment.
+_async_db_pool: Optional["AsyncCacheDB"] = None
+_async_db_lock = asyncio.Lock()
+
+
+class AsyncCacheDB:
+    """
+    Async SQLite connection manager for the tool cache.
+
+    A single ``aiosqlite`` connection is created lazily and shared by all
+    callers:
+    - the connection is switched to WAL journaling when it is created
+    - writes are serialized with an ``asyncio.Lock`` and rolled back on failure
+    - reads are issued directly on the shared connection without that lock
+    """
+
+    def __init__(self, db_path: str):
+        """Remember ``db_path``; no file or connection is opened yet."""
+        self.db_path = db_path
+        self._connection: Optional[aiosqlite.Connection] = None
+        self._init_lock = asyncio.Lock()  # Initialization lock
+        self._write_lock = asyncio.Lock()  # Write lock to protect concurrent writes
+        self._initialized = False
+
+    async def get_connection(self) -> aiosqlite.Connection:
+        """Get a database connection (lazy initialization).
+
+        The unlocked fast path is safe because ``_create_connection`` only
+        publishes ``self._connection`` once setup has fully completed.
+        """
+        if self._connection is None:
+            async with self._init_lock:
+                if self._connection is None:
+                    await self._create_connection()
+        return self._connection
+
+    async def _create_connection(self):
+        """Create, configure and publish the database connection.
+
+        The connection is kept in a local variable until every PRAGMA and the
+        table creation have succeeded. Publishing it earlier would let a
+        concurrent ``get_connection`` caller skip the init lock and query
+        tables that do not exist yet; a failed setup would also leave a
+        half-configured connection behind.
+        """
+        # Ensure the database directory exists.
+        db_dir = os.path.dirname(self.db_path)
+        if db_dir:
+            os.makedirs(db_dir, exist_ok=True)
+
+        conn = await aiosqlite.connect(self.db_path)
+        try:
+            # Configure database options for high concurrency.
+            pragmas = (
+                "PRAGMA journal_mode=WAL",  # WAL enables concurrent reads
+                "PRAGMA synchronous=NORMAL",  # Balance performance and durability
+                "PRAGMA cache_size=-64000",  # 64MB cache
+                f"PRAGMA busy_timeout={CACHE_CONFIG['busy_timeout']}",
+                "PRAGMA wal_autocheckpoint=1000",
+                "PRAGMA temp_store=MEMORY",
+                "PRAGMA read_uncommitted=1",  # Allow dirty reads to improve concurrency
+            )
+            for pragma in pragmas:
+                await conn.execute(pragma)
+
+            # Initialize tables.
+            if not self._initialized:
+                await self._initialize_tables(conn)
+                self._initialized = True
+        except BaseException:
+            # Do not leak the connection when setup fails or is cancelled.
+            await conn.close()
+            raise
+
+        self._connection = conn
+
+    async def _initialize_tables(self, connection=None):
+        """Create the cache tables and indexes declared in ``CACHE_TABLES``.
+
+        Args:
+            connection: Connection to run the schema scripts on; defaults to
+                ``self._connection`` for backward compatibility.
+        """
+        conn = connection if connection is not None else self._connection
+        for schema in CACHE_TABLES.values():
+            await conn.executescript(schema)
+        await conn.commit()
+
+    @staticmethod
+    async def _rollback_quietly(conn) -> None:
+        """Roll back the open transaction without masking the original error."""
+        try:
+            await conn.rollback()
+        except Exception:  # noqa: BLE001
+            # The caller re-raises the write error; a secondary rollback
+            # failure must not replace it.
+            pass
+
+    async def execute_read(self, sql: str, params: tuple = ()) -> Optional[tuple]:
+        """Run ``sql`` and return its first row (or None) without the write lock."""
+        conn = await self.get_connection()
+        async with conn.execute(sql, params) as cursor:
+            return await cursor.fetchone()
+
+    async def execute_write(self, sql: str, params: tuple = ()):
+        """Execute and commit one write statement under the write lock.
+
+        On failure the transaction is rolled back so the next writer's commit
+        cannot silently persist leftover changes; the error is re-raised.
+        """
+        async with self._write_lock:
+            conn = await self.get_connection()
+            try:
+                await conn.execute(sql, params)
+                await conn.commit()
+            except Exception:
+                await self._rollback_quietly(conn)
+                raise
+
+    async def execute_write_batch(self, operations: list[tuple[str, tuple]]):
+        """Execute batched writes atomically (single lock acquisition).
+
+        Either every ``(sql, params)`` pair is committed, or, when any
+        statement or the commit fails, the whole batch is rolled back and the
+        error is re-raised.
+        """
+        async with self._write_lock:
+            conn = await self.get_connection()
+            try:
+                for sql, params in operations:
+                    await conn.execute(sql, params)
+                await conn.commit()
+            except Exception:
+                await self._rollback_quietly(conn)
+                raise
+
+    async def close(self):
+        """Close the database connection; a later call reopens it lazily."""
+        if self._connection is not None:
+            await self._connection.close()
+            self._connection = None
+
+
+async def get_async_cache_db() -> AsyncCacheDB:
+    """Return the process-wide ``AsyncCacheDB``, creating it on first use.
+
+    Creation is guarded by the module-level ``asyncio.Lock`` so concurrent
+    coroutines end up sharing one instance.
+    """
+    global _async_db_pool
+    if _async_db_pool is not None:
+        return _async_db_pool
+
+    async with _async_db_lock:
+        # Re-check: another coroutine may have created it while we waited.
+        if _async_db_pool is None:
+            _async_db_pool = AsyncCacheDB(CACHE_CONFIG["db_path"])
+    return _async_db_pool
+
+
+# Cache table schemas
+# Maps each cache table name to the SQL script that creates the table and its
+# indexes. All four tables share one layout: a hash primary key, the original
+# input, the cached result, and access bookkeeping columns.
+CACHE_TABLES = {
+    "text_search": """
+        CREATE TABLE IF NOT EXISTS text_search (
+            query_hash TEXT PRIMARY KEY,
+            query TEXT NOT NULL,
+            result TEXT NOT NULL,
+            created_at REAL NOT NULL,
+            last_accessed REAL NOT NULL,
+            access_count INTEGER DEFAULT 1
+        );
+        CREATE INDEX IF NOT EXISTS idx_text_search_query_hash ON text_search(query_hash);
+        CREATE INDEX IF NOT EXISTS idx_text_search_last_accessed ON text_search(last_accessed);
+    """,
+    "text_visit": """
+        CREATE TABLE IF NOT EXISTS text_visit (
+            url_hash TEXT PRIMARY KEY,
+            url TEXT NOT NULL,
+            result TEXT NOT NULL,
+            created_at REAL NOT NULL,
+            last_accessed REAL NOT NULL,
+            access_count INTEGER DEFAULT 1
+        );
+        CREATE INDEX IF NOT EXISTS idx_text_visit_url_hash ON text_visit(url_hash);
+        CREATE INDEX IF NOT EXISTS idx_text_visit_last_accessed ON text_visit(last_accessed);
+    """,
+    "image_search": """
+        CREATE TABLE IF NOT EXISTS image_search (
+            image_url_hash TEXT PRIMARY KEY,
+            image_url TEXT NOT NULL,
+            result TEXT NOT NULL,
+            created_at REAL NOT NULL,
+            last_accessed REAL NOT NULL,
+            access_count INTEGER DEFAULT 1
+        );
+        CREATE INDEX IF NOT EXISTS idx_image_search_url_hash ON image_search(image_url_hash);
+        CREATE INDEX IF NOT EXISTS idx_image_search_last_accessed ON image_search(last_accessed);
+    """,
+    "image_visit": """
+        CREATE TABLE IF NOT EXISTS image_visit (
+            url_hash TEXT PRIMARY KEY,
+            url TEXT NOT NULL,
+            result TEXT NOT NULL,
+            created_at REAL NOT NULL,
+            last_accessed REAL NOT NULL,
+            access_count INTEGER DEFAULT 1
+        );
+        CREATE INDEX IF NOT EXISTS idx_image_visit_url_hash ON image_visit(url_hash);
+        CREATE INDEX IF NOT EXISTS idx_image_visit_last_accessed ON image_visit(last_accessed);
+    """,
+}
+
+
+def get_cache_key(text: str) -> str:
+    """Return the hex SHA-256 digest of ``text`` encoded as UTF-8."""
+    digest = hashlib.sha256()
+    digest.update(text.encode("utf-8"))
+    return digest.hexdigest()
+
+
+# ============================================================================
+# Async cache operations (high-concurrency optimized)
+# ============================================================================
+
+# Column name mapping (for table-specific queries)
+# Cache table -> name of its primary-key hash column. get_cache_async falls
+# back to "hash" for a table that is not listed here.
+_HASH_COL_MAPPING = {
+    "text_search": "query_hash",
+    "text_visit": "url_hash",
+    "image_search": "image_url_hash",
+    "image_visit": "url_hash",
+}
+
+# Cache table -> (hash column, original-input column). set_cache_async uses
+# the pair to build its INSERT and falls back to ("hash", "input").
+_INPUT_COL_MAPPING = {
+    "text_search": ("query_hash", "query"),
+    "text_visit": ("url_hash", "url"),
+    "image_search": ("image_url_hash", "image_url"),
+    "image_visit": ("url_hash", "url"),
+}
+
+
+# Strong references to in-flight access-time updates. The event loop only
+# keeps weak references to tasks, so an unreferenced fire-and-forget task can
+# be garbage-collected before it finishes.
+_pending_access_updates: set = set()
+
+
+async def get_cache_async(
+    table: str, key: str, executor: ThreadPoolExecutor | None = None
+) -> Optional[str]:
+    """
+    Fetch a cache entry asynchronously.
+
+    Args:
+        table: Cache table name (interpolated into the SQL text, so it must be
+            a trusted identifier, never user input).
+        key: Hash key of the entry, as produced by ``get_cache_key``.
+        executor: Kept for backward compatibility but unused.
+
+    Returns:
+        The cached ``result`` string, or None on a miss or on any database
+        error (errors are logged, not raised).
+
+    "database is locked" / "database is busy" errors are retried up to
+    ``CACHE_CONFIG["max_retries"]`` times with exponential backoff.
+    """
+    max_retries = CACHE_CONFIG["max_retries"]
+    base_delay = CACHE_CONFIG["base_retry_delay"]
+    hash_col = _HASH_COL_MAPPING.get(table, "hash")
+
+    for attempt in range(max_retries):
+        try:
+            cache_db = await get_async_cache_db()
+            current_time = time.time()
+
+            # Read first (no lock needed).
+            row = await cache_db.execute_read(
+                f"SELECT result FROM {table} WHERE {hash_col} = ?", (key,)
+            )
+
+            if not row:
+                return None
+
+            # Update access time in the background; do not wait for it, but
+            # hold a reference until it completes so it cannot be collected.
+            touch_task = asyncio.create_task(
+                _update_access_time_async(table, hash_col, key, current_time)
+            )
+            _pending_access_updates.add(touch_task)
+            touch_task.add_done_callback(_pending_access_updates.discard)
+            return row[0]
+
+        except Exception as e:
+            error_msg = str(e).lower()
+            if (
+                "database is locked" in error_msg or "database is busy" in error_msg
+            ) and attempt < max_retries - 1:
+                # Async wait; do not block the event loop.
+                wait_time = (2**attempt) * base_delay
+                await asyncio.sleep(wait_time)
+                continue
+            else:
+                log_tool_event(
+                    "Cache",
+                    "AsyncGetError",
+                    f"table={table} error={str(e)}",
+                    level="ERROR",
+                )
+                return None
+
+    return None
+
+
+async def _update_access_time_async(
+    table: str, hash_col: str, key: str, current_time: float
+):
+    """Best-effort bump of ``last_accessed`` and ``access_count`` for one entry.
+
+    Any failure is deliberately ignored: this bookkeeping must never affect
+    the cache read that triggered it.
+    """
+    try:
+        db = await get_async_cache_db()
+        touch_sql = f"UPDATE {table} SET last_accessed = ?, access_count = access_count + 1 WHERE {hash_col} = ?"
+        await db.execute_write(touch_sql, (current_time, key))
+    except Exception:
+        pass  # Ignore update failures; does not affect main flow.
+
+
+async def set_cache_async(
+    table: str,
+    key: str,
+    original_input: str,
+    result: str,
+    executor: ThreadPoolExecutor | None = None,
+):
+    """
+    Store (insert or replace) a cache entry asynchronously.
+
+    Args:
+        table: Cache table name.
+        key: Hash key of the entry.
+        original_input: The un-hashed input stored next to the key.
+        result: Text to cache; values longer than 100 * 1024 * 1024 are
+            skipped with a warning.
+        executor: Kept for backward compatibility but unused.
+
+    Errors are logged rather than raised. "database is locked" / "database is
+    busy" errors are retried with exponential backoff.
+    """
+    retry_budget = CACHE_CONFIG["max_retries"]
+    backoff_unit = CACHE_CONFIG["base_retry_delay"]
+    hash_col, input_col = _INPUT_COL_MAPPING.get(table, ("hash", "input"))
+
+    # Validate data size.
+    size_limit = 100 * 1024 * 1024  # 100MB limit
+    if len(result) > size_limit:
+        log_tool_event(
+            "Cache",
+            "SizeError",
+            f"table={table} result too large: {len(result)} bytes",
+            level="WARNING",
+        )
+        return
+
+    attempt = 0
+    while attempt < retry_budget:
+        try:
+            db = await get_async_cache_db()
+            now = time.time()
+
+            # Write operation (protected by a lock).
+            await db.execute_write(
+                f"""
+                INSERT OR REPLACE INTO {table}
+                ({hash_col}, {input_col}, result, created_at, last_accessed, access_count)
+                VALUES (?, ?, ?, ?, ?, 1)
+                """,
+                (key, original_input, result, now, now),
+            )
+            return  # Success
+
+        except Exception as e:
+            lowered = str(e).lower()
+            contended = "database is locked" in lowered or "database is busy" in lowered
+            if not (contended and attempt < retry_budget - 1):
+                log_tool_event(
+                    "Cache",
+                    "AsyncSetError",
+                    f"table={table} error={str(e)}",
+                    level="ERROR",
+                )
+                return
+
+            # Async wait; do not block the event loop.
+            await asyncio.sleep((2**attempt) * backoff_unit)
+
+        attempt += 1
+
+
+def log_tool_event(
+    source: str,
+    status: str,
+    message: str | None,
+    *,
+    error: str | None = None,
+    level: str | None = "INFO",
+) -> None:
+    """Print one structured log line for a DeepResearch tool event.
+
+    The line carries a ``[Tool][source][status][LEVEL]`` tag, the message
+    length and a shortened JSON-quoted preview, plus the same pair for
+    ``error`` when one is given.
+    """
+
+    body = message or ""
+    body_preview = shorten_for_log(body)
+
+    fields = [
+        f"[Tool][{source}][{status}][{_normalize_level(level)}]",
+        f"message_len={len(body)}",
+        f"preview={json.dumps(body_preview, ensure_ascii=False)}",
+    ]
+
+    if error is not None:
+        error_preview = shorten_for_log(error)
+        fields += [
+            f"error_len={len(error)}",
+            f"error={json.dumps(error_preview, ensure_ascii=False)}",
+        ]
+
+    print(" ".join(fields))
+
+
+def log_search(
+    source: str,
+    status: str,
+    query: str,
+    result: str | None = None,
+    error: str | None = None,
+) -> None:
+    """Log a search-tool event through ``log_tool_event``.
+
+    The message holds the JSON-quoted query and, when ``result`` is given, its
+    length and a shortened preview. The level is ``ERROR`` when ``error`` is
+    truthy and ``INFO`` otherwise.
+    """
+
+    fields = [f"query={json.dumps(query, ensure_ascii=False)}"]
+
+    if result is not None:
+        result_preview = shorten_for_log(result)
+        fields.extend(
+            (
+                f"result_len={len(result)}",
+                f"preview={json.dumps(result_preview, ensure_ascii=False)}",
+            )
+        )
+
+    log_tool_event(
+        source=f"Search/{source}",
+        status=status,
+        message=" ".join(fields),
+        error=error,
+        level="ERROR" if error else "INFO",
+    )
+
+
+class DeepResearchTool(RLLMTool, ABC):
+    """
+    Common base for the DeepResearch tools.
+
+    Builds an OpenAI function-calling schema in ``self._json``, carries an
+    optional executor for blocking work, and adapts the subclass's ``call``
+    coroutine to the rLLM ``async_forward`` interface.
+    """
+
+    def __init__(self, name: str, description: str, parameters: dict | None = None):
+        """
+        Build the function schema and initialize the rLLM base tool.
+
+        Args:
+            name: Tool name
+            description: Tool description
+            parameters: OpenAI-style parameter schema; an empty object schema
+                is used when omitted
+        """
+        schema = parameters or {"type": "object", "properties": {}, "required": []}
+        # Set _json BEFORE calling super().__init__
+        # because the parent's __init__ may access self.json
+        self._json = {
+            "type": "function",
+            "function": {
+                "name": name,
+                "description": description,
+                "parameters": schema,
+            },
+        }
+
+        super().__init__(name=name, description=description)
+        self.executor: ThreadPoolExecutor | None = None
+
+    @abstractmethod
+    async def call(self, **kwargs) -> str:
+        """Execute the tool with given arguments."""
+        pass
+
+    def set_executor(self, executor: ThreadPoolExecutor | None) -> None:
+        """Attach the executor used for this tool's blocking calls."""
+        self.executor = executor
+
+    def _get_requests_proxies(self) -> dict | None:
+        """Translate TOOL_HTTPS_PROXY into a ``requests`` proxies mapping.
+
+        Returns None when the variable is unset, a mapping with both schemes
+        set to None when it is blank or "none" (any case), and otherwise a
+        mapping pointing both schemes at the configured proxy.
+        """
+        raw = os.getenv("TOOL_HTTPS_PROXY")
+        if raw is None:
+            return None
+
+        cleaned = raw.strip()
+        if cleaned and cleaned.lower() != "none":
+            return {"http": cleaned, "https": cleaned}
+        return {"http": None, "https": None}
+
+    async def _run_blocking(self, func: Callable[[], T]) -> T:
+        """Await ``func`` on the executor bound to this tool."""
+        return await run_blocking(func, executor=self.executor)
+
+    async def async_forward(self, **kwargs):
+        """rLLM Tool interface: run ``call`` and wrap the outcome in ToolOutput.
+
+        Exceptions raised by ``call`` are reported through the ``error`` field
+        instead of propagating.
+        """
+        try:
+            from rllm.tools.tool_base import ToolOutput
+        except ImportError:
+            from rllm_mllm.rllm.tools.tool_base import ToolOutput
+
+        try:
+            output = await self.call(**kwargs)
+        except Exception as e:
+            return ToolOutput(name=self.name, error=f"{type(e).__name__} - {str(e)}")
+        return ToolOutput(name=self.name, output=output)
+
+
+async def check_cache_health_async() -> bool:
+    """Check cache database health asynchronously.
+
+    Runs a probe query, warns when an expected cache table is missing, and
+    warns when the WAL file next to the database exceeds 100MB.
+
+    Returns:
+        True when the database could be queried, False when any database
+        step raised (the failure is logged at ERROR level).
+    """
+    try:
+        cache_db = await get_async_cache_db()
+        conn = await cache_db.get_connection()
+
+        # Test basic connectivity; only success matters, not the count.
+        async with conn.execute(
+            "SELECT COUNT(*) FROM sqlite_master WHERE type='table'"
+        ) as cursor:
+            await cursor.fetchone()
+
+        # Check if our tables exist
+        expected_tables = {"text_search", "text_visit", "image_search", "image_visit"}
+        async with conn.execute(
+            "SELECT name FROM sqlite_master WHERE type='table'"
+        ) as cursor:
+            rows = await cursor.fetchall()
+        missing_tables = expected_tables - {row[0] for row in rows}
+
+        if missing_tables:
+            # Not fatal for this check, but worth surfacing instead of
+            # ignoring silently.
+            log_tool_event(
+                "Cache",
+                "MissingTables",
+                f"Missing cache tables: {sorted(missing_tables)}",
+                level="WARNING",
+            )
+
+        # Test WAL file size (rough check)
+        try:
+            wal_path = CACHE_CONFIG["db_path"] + "-wal"
+            if os.path.exists(wal_path):
+                wal_size = os.path.getsize(wal_path)
+                if wal_size > 100 * 1024 * 1024:  # 100MB
+                    log_tool_event(
+                        "Cache",
+                        "WALSize",
+                        f"WAL file too large: {wal_size} bytes",
+                        level="WARNING",
+                    )
+        except OSError:
+            # The WAL file may vanish between exists() and getsize(); the size
+            # check is advisory, so filesystem errors are ignored. A bare
+            # ``except`` would also swallow KeyboardInterrupt/SystemExit.
+            pass
+
+        return True
+
+    except Exception as e:
+        log_tool_event(
+            "Cache", "HealthCheck", f"Health check failed: {str(e)}", level="ERROR"
+        )
+        return False
+
+
+async def cleanup_expired_cache_async():
+    """Delete entries not accessed within ``CACHE_CONFIG["max_age_days"]``.
+
+    Each of the four cache tables is purged with its own write. Failures are
+    logged at ERROR level and not raised.
+    """
+    try:
+        db = await get_async_cache_db()
+        retention_seconds = CACHE_CONFIG["max_age_days"] * 24 * 60 * 60
+        cutoff_time = time.time() - retention_seconds
+
+        # Clean up expired entries
+        for table in ("text_search", "text_visit", "image_search", "image_visit"):
+            await db.execute_write(
+                f"DELETE FROM {table} WHERE last_accessed < ?", (cutoff_time,)
+            )
+
+    except Exception as e:
+        log_tool_event(
+            "Cache", "CleanupError", f"Failed to cleanup cache: {str(e)}", level="ERROR"
+        )
+
+
+async def initialize_cache_async():
+    """Open the cache database, then run the health check and expiry cleanup.
+
+    Any failure is logged at WARNING level and swallowed.
+    """
+    try:
+        db = await get_async_cache_db()
+        # Opening the connection is what triggers table initialization.
+        await db.get_connection()
+        for step in (check_cache_health_async, cleanup_expired_cache_async):
+            await step()
+    except Exception as e:
+        log_tool_event(
+            "Cache",
+            "InitError",
+            f"Failed to initialize cache: {str(e)}",
+            level="WARNING",
+        )
+
+
+# Cache initialization flag
+_cache_initialized = False
+_cache_init_lock = asyncio.Lock()
+
+
+async def ensure_cache_initialized():
+    """Run ``initialize_cache_async`` at most once per process.
+
+    The flag is re-checked under an ``asyncio.Lock`` so concurrent coroutines
+    do not initialize twice.
+    """
+    global _cache_initialized
+    if _cache_initialized:
+        return
+
+    async with _cache_init_lock:
+        if _cache_initialized:
+            return
+        await initialize_cache_async()
+        _cache_initialized = True
diff --git a/vendor/rllm/vision_deepresearch_async_workflow/tools/visit_tool.py b/vendor/rllm/vision_deepresearch_async_workflow/tools/visit_tool.py
new file mode 100644
index 0000000000000000000000000000000000000000..7111214c81dccc2107f3e4fbb60cf97f9bfcb05e
--- /dev/null
+++ b/vendor/rllm/vision_deepresearch_async_workflow/tools/visit_tool.py
@@ -0,0 +1,477 @@
+import asyncio
+import json
+import os
+import random
+import re
+from typing import Any
+
+from vision_deepresearch_async_workflow.tools.shared import (
+ DeepResearchTool,
+ get_cache_async,
+ get_cache_key,
+ log_tool_event,
+ run_with_retries_async,
+ set_cache_async,
+ shorten_for_log,
+)
+
+
+class VisitTool(DeepResearchTool):
+ """Web page visiting with content extraction."""
+
+ MAX_URLS = 5
+ MAX_CONTENT_CHARS = 120000
+
+ EXTRACTOR_PROMPT = """Please process the following webpage content and user goal to extract relevant information:
+
+## **Webpage Content**
+{webpage_content}
+
+## **User Goal**
+{goal}
+
+## **Task Guidelines**
+1. **Content Scanning for Rational**: Locate the **specific sections/data** directly related to the user's goal within the webpage content
+2. **Key Extraction for Evidence**: Identify and extract the **most relevant information** from the content, you never miss any important information, output the **full original context** of the content as far as possible, it can be more than three paragraphs.
+3. **Summary Output for Summary**: Organize into a concise paragraph with logical flow, prioritizing clarity and judge the contribution of the information to the goal.
+
+**Final Output Requirements**
+- Return a valid JSON object only (no code fences, Markdown, comments, or additional text).
+- The JSON must contain exactly the keys "rational", "evidence", and "summary".
+- Each key must map to a string value. Use an empty string if no content is available.
+- Do not include any extra keys or explanatory sentences outside the JSON object.
+
+Example:
+{{"rational": "Explain why the information is relevant to the goal.", "evidence": "Quote or paraphrase the key supporting content from the webpage.", "summary": "Provide a concise summary that connects the evidence back to the goal."}}
+"""
+
+    def __init__(self):
+        """Register the ``visit`` tool schema and load reader/extractor settings from the environment."""
+        url_schema = {
+            "type": ["string", "array"],
+            "items": {"type": "string"},
+            "minItems": 1,
+            "description": "The URL(s) of the webpage(s) to visit. Can be a single URL or an array of URLs.",
+        }
+        goal_schema = {
+            "type": "string",
+            "description": "The goal of the visit for webpage(s).",
+        }
+        super().__init__(
+            name="visit",
+            description="Visit webpage(s) and return the summary of the content.",
+            parameters={
+                "type": "object",
+                "properties": {"url": url_schema, "goal": goal_schema},
+                "required": ["url", "goal"],
+            },
+        )
+
+        # Reader credentials; the Zhipu reader is used when its key is present.
+        self.zhipu_api_key = os.getenv("ZHIPU_API_KEY")
+        self.jina_api_key = os.getenv("JINA_API_KEY")
+
+        # READER_URL, when set, overrides the endpoint of both readers.
+        self.zhipu_reader_url = os.getenv(
+            "READER_URL", "https://search-svip.bigmodel.cn/api/paas/v4/reader"
+        )
+        self.jina_reader_url = os.getenv("READER_URL", "https://r.jina.ai")
+
+        # Extract (summarization) model settings.
+        self.extract_model = os.getenv("EXTRACT_MODEL", "Qwen3-VL-30B-A3B-Instruct")
+        self.extract_max_tokens = 16384
+
+        # EXTRACT_URL is a comma-separated list; blank entries are dropped.
+        endpoints = []
+        for piece in os.getenv("EXTRACT_URL", "").split(","):
+            cleaned = piece.strip()
+            if cleaned:
+                endpoints.append(cleaned)
+        self.extract_urls = endpoints
+
+    async def call(self, url: str | list, goal: str = "", **kwargs) -> str:
+        """Visit up to ``MAX_URLS`` pages concurrently and join their reports.
+
+        Args:
+            url: A single URL or a list of URLs. Blank and non-string entries
+                are ignored; entries beyond ``MAX_URLS`` are dropped.
+            goal: What the caller wants to learn from the page(s).
+            **kwargs: Accepted for interface compatibility; not used here.
+
+        Returns:
+            The per-URL reports joined by a ``=======`` separator, or a fixed
+            notice when no usable URL was supplied.
+        """
+        if isinstance(url, str):
+            candidates = [url]
+        elif isinstance(url, (list, tuple)):
+            candidates = list(url)
+        else:
+            candidates = []
+
+        # Drop blanks and non-strings so we never try to visit a bare "https://".
+        urls = [
+            item.strip()
+            for item in candidates
+            if isinstance(item, str) and item.strip()
+        ]
+        if not urls:
+            return "[Visit] No valid URL provided"
+
+        results = await asyncio.gather(
+            *(
+                self._handle_single_url(target_url, goal)
+                for target_url in urls[: self.MAX_URLS]
+            )
+        )
+
+        return "\n\n=======\n\n".join(results)
+
+    async def _handle_single_url(self, url: str, goal: str) -> str:
+        """Fetch one page, summarize it against ``goal`` and format the report.
+
+        Args:
+            url: Page address; normalized with ``_normalize_url`` before use.
+            goal: User goal passed to the extract model and echoed in the report.
+
+        Returns:
+            A success report built by ``_format_success``, or a failure report
+            built by ``_build_failure_message`` when the page cannot be fetched
+            or has no content. Fetch errors are logged, not raised.
+        """
+
+        def as_text(value: Any) -> str:
+            # Reader/model output is parsed JSON, so a list or dict can appear where
+            # a string is expected; the formatters concatenate strings only.
+            if value is None:
+                return ""
+            if isinstance(value, str):
+                return value
+            return json.dumps(value, ensure_ascii=False, default=str)
+
+        normalized_url = self._normalize_url(url)
+
+        try:
+            reader_payload = await self._fetch_reader_content(normalized_url)
+        except Exception as exc:  # noqa: BLE001
+            log_tool_event(
+                source="Visit/Reader",
+                status="Exception",
+                message=f"url={normalized_url}",
+                error=str(exc),
+                level="ERROR",
+            )
+            return self._build_failure_message(
+                normalized_url, goal, f"Unable to fetch webpage content: {exc}"
+            )
+
+        if reader_payload is None:
+            return self._build_failure_message(
+                normalized_url, goal, "Reader API returned empty payload"
+            )
+        if not isinstance(reader_payload, dict):
+            # e.g. a cached entry that decoded to something other than an object
+            return self._build_failure_message(
+                normalized_url, goal, "Reader API returned malformed payload"
+            )
+
+        content = as_text(reader_payload.get("content") or "")
+        description = as_text(reader_payload.get("description") or "")
+
+        if not content:
+            fallback = description or "Webpage content is empty"
+            return self._build_failure_message(normalized_url, goal, fallback)
+
+        content = self._truncate_content(content)
+
+        summary_result = await self._summarize_with_extract(
+            content, goal, reader_payload
+        )
+
+        if summary_result is None:
+            log_tool_event(
+                "Visit", "ExtractSummaryFailed", f"url={normalized_url}", level="ERROR"
+            )
+            # No summary available: hand back the raw page text as evidence.
+            evidence_text = content
+            summary_text = (
+                description or "Summary service unavailable. Returning raw content."
+            )
+        else:
+            evidence_text = as_text(summary_result.get("evidence") or content)
+            summary_text = as_text(summary_result.get("summary") or description or "")
+
+        return self._format_success(normalized_url, goal, evidence_text, summary_text)
+
+    def _normalize_url(self, url: str) -> str:
+        """Return ``url`` with an explicit scheme, defaulting to ``https``.
+
+        Surrounding whitespace is removed. ``scheme://...`` addresses are
+        returned unchanged, protocol-relative ``//host/...`` addresses get
+        ``https:``, and ``host:port[/path]`` addresses get ``https://``.
+
+        Args:
+            url: Address as supplied by the caller.
+
+        Returns:
+            The normalized address.
+        """
+        from urllib.parse import urlparse
+
+        candidate = url.strip()
+
+        # Already "scheme://..." - nothing to add.
+        if re.match(r"[A-Za-z][A-Za-z0-9+.\-]*://", candidate):
+            return candidate
+
+        # Protocol-relative form: only the scheme is missing.
+        if candidate.startswith("//"):
+            return f"https:{candidate}"
+
+        # "host:port[/path]": urlparse would report the host as the scheme,
+        # so this case has to be recognized before consulting it.
+        if re.match(r"[^/?#:@]+:\d+(?:[/?#]|$)", candidate):
+            return f"https://{candidate}"
+
+        if not urlparse(candidate).scheme:
+            return f"https://{candidate}"
+        return candidate
+
+    def _select_extract_url(self) -> str | None:
+        """Pick one configured extract endpoint at random and return its chat-completions URL.
+
+        Entries already ending in ``/v1/chat/completions`` are returned as
+        configured. Entries ending in ``/v1`` get ``/chat/completions``
+        appended, so the version segment is not duplicated. Any other entry
+        gets ``/v1/chat/completions`` appended.
+
+        Returns:
+            The endpoint URL, or ``None`` when no extract URL is configured.
+        """
+        if not self.extract_urls:
+            return None
+
+        selected = random.choice(self.extract_urls)
+        if re.search(r"/v1/chat/completions/?$", selected):
+            return selected
+
+        base = selected.rstrip("/")
+        if base.endswith("/v1"):
+            return f"{base}/chat/completions"
+        return f"{base}/v1/chat/completions"
+
+    async def _fetch_reader_content(self, url: str) -> dict[str, Any] | None:
+        """Return reader output for ``url``, consulting the ``text_visit`` cache first.
+
+        The Zhipu reader is used when ``self.zhipu_api_key`` is set, otherwise
+        the Jina reader. A result with non-blank content is written back to the
+        cache.
+
+        Args:
+            url: Normalized page address.
+
+        Returns:
+            A dict with ``content``, ``description`` and ``meta`` keys.
+
+        Raises:
+            RuntimeError: ``requests`` is not installed, or the reader answered
+                with a non-200 status or an unusable payload.
+        """
+        # Check cache first
+        cache_key = get_cache_key(url)
+        cached_result = await get_cache_async(
+            "text_visit", cache_key, executor=self.executor
+        )
+        if cached_result:
+            try:
+                cached_payload = json.loads(cached_result)
+            except (TypeError, ValueError):
+                cached_payload = None
+            if isinstance(cached_payload, dict):
+                return cached_payload
+            # Corrupted or malformed cache entry: ignore it and fetch live.
+
+        try:
+            import requests
+        except ImportError as exc:  # noqa: PERF203
+            raise RuntimeError("Visit tool requires 'requests' package") from exc
+
+        proxies = self._get_requests_proxies()
+
+        if self.zhipu_api_key:
+            result = await self._fetch_via_zhipu(requests, url, proxies)
+        else:
+            result = await self._fetch_via_jina(requests, url, proxies)
+
+        # Store result in cache only if we have valid content
+        if result["content"].strip():
+            await set_cache_async(
+                "text_visit",
+                cache_key,
+                url,
+                json.dumps(result, ensure_ascii=False),
+                executor=self.executor,
+            )
+
+        return result
+
+    async def _fetch_via_zhipu(
+        self, requests_module: Any, url: str, proxies: Any
+    ) -> dict[str, Any]:
+        """POST ``url`` to the Zhipu reader and unwrap the ``data`` object of its reply."""
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": self.zhipu_api_key,
+            # Optional headers consistent with the demo scripts
+            "X-Return-Format": "markdown",
+            "X-No-Cache": "false",
+            "X-Timeout": "60",
+            "X-Retain-Images": "false",
+            "X-With-Images-Summary": "false",
+            "X-With-Links-Summary": "false",
+        }
+
+        # Encode explicitly: a str body is sent as ISO-8859-1 by http.client,
+        # which fails for URLs containing characters outside Latin-1.
+        body = json.dumps({"url": url}, ensure_ascii=False).encode("utf-8")
+
+        def send_request():
+            return requests_module.post(
+                self.zhipu_reader_url,
+                headers=headers,
+                data=body,
+                timeout=60,
+                proxies=proxies,
+            )
+
+        response = await run_with_retries_async(send_request, executor=self.executor)
+
+        if response.status_code != 200:
+            raise RuntimeError(f"Reader API returned HTTP {response.status_code}")
+
+        try:
+            payload = response.json()
+        except ValueError as exc:  # covers json / simplejson / requests decode errors
+            raise RuntimeError("Reader API returned non-JSON payload") from exc
+
+        if not isinstance(payload, dict):
+            raise RuntimeError("Reader API payload structure is invalid")
+
+        if payload.get("code") != 200:
+            raise RuntimeError(
+                f"Reader API returned error code: {payload.get('code')}"
+            )
+
+        data = payload.get("data")
+        if not isinstance(data, dict):
+            raise RuntimeError("Reader API data field missing or malformed")
+
+        return {
+            "content": data.get("content") or "",
+            "description": data.get("description") or "",
+            "meta": data,
+        }
+
+    async def _fetch_via_jina(
+        self, requests_module: Any, url: str, proxies: Any
+    ) -> dict[str, Any]:
+        """POST ``url`` to the Jina reader and wrap the response text as content."""
+        headers = {}
+        if self.jina_api_key:
+            headers["Authorization"] = self.jina_api_key
+        body = {
+            "url": url,
+        }
+
+        def send_request():
+            return requests_module.post(
+                self.jina_reader_url,
+                headers=headers,
+                data=body,
+                timeout=60,
+                proxies=proxies,
+            )
+
+        response = await run_with_retries_async(send_request, executor=self.executor)
+
+        if response.status_code != 200:
+            raise RuntimeError(f"Reader API returned HTTP {response.status_code}")
+
+        return {
+            "content": response.text or "",
+            "description": "",
+            "meta": {
+                "provider": "jina",
+                "url": url,
+                "reader_url": self.jina_reader_url,
+            },
+        }
+
+    def _truncate_content(self, content: str) -> str:
+        """Clip ``content`` to ``MAX_CONTENT_CHARS`` characters, appending a truncation marker when clipped."""
+        limit = self.MAX_CONTENT_CHARS
+        if len(content) > limit:
+            clipped = content[:limit]
+            return clipped + "\n[Content truncated...]"
+        return content
+
+    async def _summarize_with_extract(
+        self, content: str, goal: str, reader_payload: dict[str, Any]
+    ) -> dict[str, Any] | None:
+        """Ask the extract model to distill ``content`` with respect to ``goal``.
+
+        Args:
+            content: Page text to embed in ``EXTRACTOR_PROMPT``.
+            goal: User goal; ``"N/A"`` is substituted when empty.
+            reader_payload: Reader result for the page. Not used here; kept so
+                the signature stays unchanged for callers.
+
+        Returns:
+            The dict parsed from the model reply, or a dict carrying the raw
+            reply as both ``evidence`` and ``summary`` when the reply is not
+            JSON. ``None`` when the service is not configured, unreachable, or
+            returns nothing usable. Every failure is logged, none is raised.
+        """
+        extract_url = self._select_extract_url()
+        if not extract_url:
+            log_tool_event(
+                source="Visit/Extract",
+                status="Config",
+                message="EXTRACT_URL is not set, skip extract service",
+            )
+            return None
+
+        if not self.extract_model:
+            # Without a model name no request payload can be built; say so
+            # explicitly instead of failing later with an unrelated error.
+            log_tool_event(
+                source="Visit/Extract",
+                status="Config",
+                message="EXTRACT_MODEL is not set, skip extract service",
+                level="WARNING",
+            )
+            return None
+
+        try:
+            import requests
+        except ImportError:
+            log_tool_event(
+                source="Visit/Extract",
+                status="DependencyMissing",
+                message="'requests' package not installed, cannot call extract service",
+                level="WARNING",
+            )
+            return None
+
+        prompt = self.EXTRACTOR_PROMPT.format(
+            webpage_content=content, goal=goal or "N/A"
+        )
+        payload = {
+            "model": self.extract_model,
+            "messages": [
+                {"role": "system", "content": "You are a helpful assistant."},
+                {"role": "user", "content": prompt},
+            ],
+            "max_tokens": self.extract_max_tokens,
+        }
+        headers = {"Content-Type": "application/json"}
+        proxies = self._get_requests_proxies()
+
+        try:
+            response = await run_with_retries_async(
+                lambda: requests.post(
+                    url=extract_url,
+                    headers=headers,
+                    json=payload,
+                    timeout=60,
+                    proxies=proxies,
+                ),
+                executor=self.executor,
+            )
+        except Exception as exc:  # noqa: BLE001
+            log_tool_event(
+                source="Visit/Extract",
+                status="RequestError",
+                message=f"url={extract_url}",
+                error=str(exc),
+                level="ERROR",
+            )
+            return None
+
+        if response.status_code != 200:
+            log_tool_event(
+                source="Visit/Extract",
+                status="HTTPError",
+                message=f"url={extract_url} status={response.status_code}",
+                level="WARNING",
+            )
+            return None
+
+        try:
+            result = response.json()
+        except ValueError:  # covers json / simplejson / requests decode errors
+            log_tool_event(
+                source="Visit/Extract",
+                status="ParseError",
+                message="Extract service returned non-JSON response, unable to parse",
+                level="WARNING",
+            )
+            return None
+
+        raw_payload = self._pick_extract_payload(result)
+        if raw_payload is None:
+            log_tool_event(
+                source="Visit/Extract",
+                status="InvalidContent",
+                message="Extract service response missing usable content",
+                level="WARNING",
+            )
+            return None
+
+        content_dict = self._parse_extract_payload(raw_payload)
+        if not isinstance(content_dict, dict):
+            log_tool_event(
+                source="Visit/Extract",
+                status="InvalidContent",
+                message="Extract service response does not contain JSON summary content",
+                level="WARNING",
+            )
+            return None
+
+        return content_dict
+
+    @staticmethod
+    def _pick_extract_payload(result: Any) -> str | dict | None:
+        """Find the model reply in ``result``: chat message, then completion text, then top-level ``content``/``data``."""
+        if not isinstance(result, dict):
+            return None
+
+        choices = result.get("choices")
+        if isinstance(choices, list) and choices:
+            first_choice = choices[0] or {}
+            if isinstance(first_choice, dict):
+                message_dict = first_choice.get("message")
+                if isinstance(message_dict, dict):
+                    message_content = message_dict.get("content")
+                    if isinstance(message_content, str) and message_content.strip():
+                        return message_content
+                text_candidate = first_choice.get("text")
+                if isinstance(text_candidate, str) and text_candidate.strip():
+                    return text_candidate
+
+        fallback_payload = result.get("content") or result.get("data")
+        if isinstance(fallback_payload, (str, dict)):
+            return fallback_payload
+        return None
+
+    @staticmethod
+    def _parse_extract_payload(raw_payload: str | dict) -> Any:
+        """Decode the model reply as JSON, unwrapping a Markdown code fence if present.
+
+        A reply that is not valid JSON is wrapped into a dict that carries the
+        text as both ``evidence`` and ``summary``.
+        """
+        if isinstance(raw_payload, dict):
+            return raw_payload
+
+        candidate = raw_payload.strip()
+        # Remove the whole fence, including an optional "json" tag; stripping
+        # only the backticks would leave the tag in front of the JSON.
+        fenced = re.match(
+            r"^```(?:json)?\s*(.*?)\s*```$", candidate, re.DOTALL | re.IGNORECASE
+        )
+        if fenced:
+            candidate = fenced.group(1)
+        elif candidate.startswith("`"):
+            candidate = candidate.strip("`")
+
+        try:
+            return json.loads(candidate)
+        except json.JSONDecodeError:
+            return {
+                "rational": "",
+                "evidence": candidate,
+                "summary": candidate,
+            }
+
+    def _build_failure_message(self, url: str, goal: str, reason: str) -> str:
+        """Build the report returned when a page could not be used, and log it as a ``Failure`` warning.
+
+        ``reason`` is placed in the evidence section; the summary section
+        carries a fixed "unable to retrieve" notice.
+        """
+        sections = [
+            f"The useful information in {url} for user goal {goal or 'N/A'} as follows: \n\n",
+            "Evidence in page: \n" + reason + "\n\n",
+            (
+                "Summary: \n"
+                + "Unable to retrieve webpage content. Please check the link or try again later."
+                + "\n\n"
+            ),
+        ]
+        useful_information = "".join(sections)
+
+        # Only shortened previews go to the log; the lengths record the full sizes.
+        reason_preview = shorten_for_log(reason)
+        result_preview = shorten_for_log(useful_information)
+        log_message = (
+            f"url={url} "
+            f"reason_len={len(reason)} "
+            f"result_len={len(useful_information)} "
+            f"reason_preview={json.dumps(reason_preview, ensure_ascii=False)} "
+            f"result_preview={json.dumps(result_preview, ensure_ascii=False)}"
+        )
+        log_tool_event(
+            source="Visit",
+            status="Failure",
+            message=log_message,
+            level="WARNING",
+        )
+
+        return useful_information
+
+    def _format_success(self, url: str, goal: str, evidence: str, summary: str) -> str:
+        """Build the report for a page that was fetched successfully.
+
+        Args:
+            url: Page address shown in the header line.
+            goal: User goal; ``"N/A"`` is shown when empty.
+            evidence: Text for the evidence section; ``None`` is treated as empty.
+            summary: Text for the summary section; a placeholder is used when empty.
+
+        Returns:
+            The report text: header, evidence section, summary section.
+        """
+        useful_information = f"The useful information in {url} for user goal {goal or 'N/A'} as follows: \n\n"
+        # Guard evidence the same way summary is guarded, so None cannot break the concatenation.
+        useful_information += "Evidence in page: \n" + (evidence or "") + "\n\n"
+        useful_information += (
+            "Summary: \n" + (summary or "No summary generated") + "\n\n"
+        )
+        return useful_information
diff --git a/vendor/rllm/vision_deepresearch_async_workflow/train_deepresearch_workflow_fsdp.py b/vendor/rllm/vision_deepresearch_async_workflow/train_deepresearch_workflow_fsdp.py
new file mode 100644
index 0000000000000000000000000000000000000000..88065fc6f2911f8fb499475b8b529b772cd772ea
--- /dev/null
+++ b/vendor/rllm/vision_deepresearch_async_workflow/train_deepresearch_workflow_fsdp.py
@@ -0,0 +1,38 @@
+import hydra
+
+from vision_deepresearch_async_workflow.deepresearch_tools_async_executor import (
+ get_all_tools,
+)
+from vision_deepresearch_async_workflow.deepresearch_workflow import (
+ DeepResearchWorkflow,
+)
+
+from rllm.data.dataset import DatasetRegistry
+from rllm.rewards.reward_fn import deepresearch_reward_fn_async
+from rllm.trainer.agent_trainer import AgentTrainer
+
+
+@hydra.main(
+    config_path="pkg://rllm.trainer.config",
+    config_name="agent_ppo_trainer",
+    version_base=None,
+)
+def main(config):
+    """Train ``DeepResearchWorkflow`` on Vision-DeepResearch-QA using the ``agent_ppo_trainer`` Hydra config."""
+    dataset_name = "Vision-DeepResearch-QA"
+    train_split = DatasetRegistry.load_dataset(dataset_name, "train")
+    eval_split = DatasetRegistry.load_dataset(dataset_name, "test")
+
+    # The workflow receives the reward function and the full tool set.
+    workflow_args = {
+        "reward_function": deepresearch_reward_fn_async,
+        "tools": get_all_tools(),
+    }
+
+    AgentTrainer(
+        workflow_class=DeepResearchWorkflow,
+        workflow_args=workflow_args,
+        config=config,
+        train_dataset=train_split,
+        val_dataset=eval_split,
+    ).train()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/vendor/rllm/vision_deepresearch_async_workflow/train_deepresearch_workflow_fsdp_gen.py b/vendor/rllm/vision_deepresearch_async_workflow/train_deepresearch_workflow_fsdp_gen.py
new file mode 100644
index 0000000000000000000000000000000000000000..344905bf0cfbbd1d8c053ea683e6a5e97c8c52a6
--- /dev/null
+++ b/vendor/rllm/vision_deepresearch_async_workflow/train_deepresearch_workflow_fsdp_gen.py
@@ -0,0 +1,47 @@
+"""
+Gen training entry point: 3 company tools + Gen system prompt + Gen reward (company DeepSeek judge).
+Usage: python -m vision_deepresearch_async_workflow.train_deepresearch_workflow_fsdp_gen
+Before running, source .env.gen or set DEEPSEEK_API_KEY / DEEPSEEK_API_BASE / JUDGE_MODEL, etc.
+"""
+import hydra
+
+from vision_deepresearch_async_workflow.gen_deepresearch_reward import (
+ gen_deepresearch_reward_fn_async,
+)
+from vision_deepresearch_async_workflow.gen_deepresearch_tools_async_executor import (
+ GEN_DEEPRESEARCH_SYSTEM_PROMPT,
+ get_all_tools,
+)
+from vision_deepresearch_async_workflow.deepresearch_workflow import (
+ DeepResearchWorkflow,
+)
+
+from rllm.data.dataset import DatasetRegistry
+from rllm.trainer.agent_trainer import AgentTrainer
+
+
+@hydra.main(
+    config_path="pkg://rllm.trainer.config",
+    config_name="agent_ppo_trainer",
+    version_base=None,
+)
+def main(config):
+    """Train ``DeepResearchWorkflow`` with the Gen reward function, Gen tools and Gen system prompt."""
+    dataset_name = "Vision-DeepResearch-QA"
+    train_split = DatasetRegistry.load_dataset(dataset_name, "train")
+    eval_split = DatasetRegistry.load_dataset(dataset_name, "test")
+
+    # Same workflow class as the non-Gen entry point; only these arguments differ.
+    workflow_args = {
+        "reward_function": gen_deepresearch_reward_fn_async,
+        "tools": get_all_tools(),
+        "system_prompt": GEN_DEEPRESEARCH_SYSTEM_PROMPT,
+    }
+
+    AgentTrainer(
+        workflow_class=DeepResearchWorkflow,
+        workflow_args=workflow_args,
+        config=config,
+        train_dataset=train_split,
+        val_dataset=eval_split,
+    ).train()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/vendor/rllm/vision_deepresearch_async_workflow/train_deepresearch_workflow_megatron.py b/vendor/rllm/vision_deepresearch_async_workflow/train_deepresearch_workflow_megatron.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ed44aa9f04de38b88ea3823bee4be83cb7a1345
--- /dev/null
+++ b/vendor/rllm/vision_deepresearch_async_workflow/train_deepresearch_workflow_megatron.py
@@ -0,0 +1,38 @@
+import hydra
+
+from vision_deepresearch_async_workflow.deepresearch_tools_async_executor import (
+ get_all_tools,
+)
+from vision_deepresearch_async_workflow.deepresearch_workflow import (
+ DeepResearchWorkflow,
+)
+
+from rllm.data.dataset import DatasetRegistry
+from rllm.rewards.reward_fn import deepresearch_reward_fn_async
+from rllm.trainer.agent_trainer import AgentTrainer
+
+
+@hydra.main(
+    config_path="pkg://rllm.trainer.config",
+    config_name="agent_ppo_trainer_megatron",
+    version_base=None,
+)
+def main(config):
+    """Train ``DeepResearchWorkflow`` on Vision-DeepResearch-QA using the ``agent_ppo_trainer_megatron`` Hydra config."""
+    dataset_name = "Vision-DeepResearch-QA"
+    train_split = DatasetRegistry.load_dataset(dataset_name, "train")
+    eval_split = DatasetRegistry.load_dataset(dataset_name, "test")
+
+    # The workflow receives the reward function and the full tool set.
+    workflow_args = {
+        "reward_function": deepresearch_reward_fn_async,
+        "tools": get_all_tools(),
+    }
+
+    AgentTrainer(
+        workflow_class=DeepResearchWorkflow,
+        workflow_args=workflow_args,
+        config=config,
+        train_dataset=train_split,
+        val_dataset=eval_split,
+    ).train()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/vendor/rllm/vision_deepresearch_async_workflow/train_image_deepresearch_workflow_fsdp_gen.py b/vendor/rllm/vision_deepresearch_async_workflow/train_image_deepresearch_workflow_fsdp_gen.py
new file mode 100644
index 0000000000000000000000000000000000000000..6019de7667fd155c7327d9ac21b77ba33fd96ffb
--- /dev/null
+++ b/vendor/rllm/vision_deepresearch_async_workflow/train_image_deepresearch_workflow_fsdp_gen.py
@@ -0,0 +1,85 @@
+"""
+Gen Image training entry: image generation task.
+Usage: python -m vision_deepresearch_async_workflow.train_image_deepresearch_workflow_fsdp_gen
+"""
+import os
+import hydra
+
+from vision_deepresearch_async_workflow.gen_image_deepresearch_reward import (
+ gen_image_deepresearch_reward_fn_async,
+)
+from vision_deepresearch_async_workflow.gen_image_deepresearch_tools_executor import (
+ create_gen_image_tools,
+)
+from vision_deepresearch_async_workflow.gen_image_deepresearch_workflow import (
+ GenImageDeepResearchWorkflow,
+)
+
+from rllm.data.dataset import DatasetRegistry
+from rllm.trainer.agent_trainer import AgentTrainer
+
+
+@hydra.main(
+    config_path="pkg://rllm.trainer.config",
+    config_name="agent_ppo_trainer",
+    version_base=None,
+)
+def main(config):
+    """
+    Entry point for training on the image generation task.
+
+    The dataset name is read from DATASET_NAME (default: Vision-DeepResearch-Gen).
+
+    Environment variables listed by the original author for this workflow:
+    - DEEPSEEK_API_KEY: DeepSeek API key
+    - DEEPSEEK_API_BASE: DeepSeek API base URL
+    - QWEN_EDIT_APP_URL / QWEN_EDIT_APP_PATH: Qwen Edit image generation service URLs (JSON array or comma-separated) and HTTP path (default /generate)
+    - GEN_REWARD_API_KEY / GEN_REWARD_API_BASE_URL / GEN_REWARD_MODEL: GPT-4.1 scoring aligned with KnowGen eval (gpt_eval_knowgen)
+      (same prompt, same overall formula; scoring uses the original question, not gen_prompt)
+    - GEN_IMAGE_OUTPUT_DIR: output directory for generated images
+    - GEN_IMAGE_TIMEOUT: image generation timeout (seconds)
+    - IMAGE_SEARCH_PROXY_IPS: proxy IP list for image search
+    - BROWSE_JINA_PROXY: proxy for read-proxy browsing
+    - JINA_API_KEYS: read-proxy API keys
+    - SERPER_KEY_ID: API key (X-API-KEY) for text + image search when using Serper
+    - TEXT_SEARCH_API_BASE_URL: full POST URL for web search (e.g. https://google.serper.dev/search)
+    - IMAGE_SEARCH_API_BASE_URL: full POST URL for image search (e.g. https://google.serper.dev/images)
+    """
+    print("\n[TrainGenImage] ===== Initialize Training =====")
+    print(f"[TrainGenImage] Config: {config}")
+
+    # Dataset name comes from the environment, with a fixed default.
+    dataset_name = os.environ.get("DATASET_NAME", "Vision-DeepResearch-Gen").strip()
+    print(f"[TrainGenImage] Dataset name: {dataset_name}")
+    print("[TrainGenImage] Loading dataset...")
+    train_dataset = DatasetRegistry.load_dataset(dataset_name, "train")
+    test_dataset = DatasetRegistry.load_dataset(dataset_name, "test")
+
+    # Sizes are reported only for datasets that define __len__.
+    print(f"[TrainGenImage] Train dataset size: {len(train_dataset) if hasattr(train_dataset, '__len__') else 'unknown'}")
+    print(f"[TrainGenImage] Test dataset size: {len(test_dataset) if hasattr(test_dataset, '__len__') else 'unknown'}")
+
+    print("[TrainGenImage] Creating tools...")
+    workflow_args = {
+        "reward_function": gen_image_deepresearch_reward_fn_async,
+        "tools": create_gen_image_tools(),
+        "system_prompt": None,  # Defined in GenImageDeepResearchAgent
+    }
+
+    print("[TrainGenImage] Creating Trainer...")
+    trainer = AgentTrainer(
+        workflow_class=GenImageDeepResearchWorkflow,
+        workflow_args=workflow_args,
+        config=config,
+        train_dataset=train_dataset,
+        val_dataset=test_dataset,
+    )
+
+    print("[TrainGenImage] ===== Start Training =====\n")
+
+    trainer.train()
+
+    print("\n[TrainGenImage] ===== Training Completed =====")
+
+
+if __name__ == "__main__":
+    main()