| |
| """ |
| 统一评测脚本。 |
| |
| 输入: |
| - 旧版图像 benchmark 的 JSONL 配置、VideoDR 的无表头 CSV 配置,或 |
| `SenseNova-MARS-Data/data/eval` 下按子目录组织的图像 benchmark。 |
| - 模型服务地址、可选的工具配置 YAML、本地 MARS 检索服务,以及评测输出目录。 |
| - 图像路径或视频路径;对视频会在评测过程中按 1fps 抽帧并缓存到输出目录。 |
| - 可选的 Vertex Gemini 账号池 JSON;每个条目绑定一个 GCP project 与一个 service account |
| JSON,用于 `--model-client vertex` 时轮询调用 Gemini。 |
| |
| 处理: |
| - 当传入 `--eval-root` / `--benchmarks` 时,自动扫描每个 benchmark 子目录的 |
| `data.jsonl`,按子目录内的 `images/` 与可选 image search 缩略图构造临时数据集配置; |
| 多选题 benchmark 使用 EM,开放问答 / 搜索型 benchmark 使用 LLM judge。 |
| - 保留原始 image benchmark 的 direct/tool 评测流程,并把工具名兼容到新论文的 |
| `zoom_in / image_search / web_search`;direct answer 模式只做单轮模型调用,不进入工具循环。 |
| - 为 VideoDR 样本增加 `choose_frames / find_frame / zoom_in / image_search / web_search` |
| 的多轮执行状态机,工具实现复用 SFT 构造代码中的定义;当 VideoDR 使用 direct answer |
| 模式时,每条视频只抽取一次 1fps 均匀采样帧并直接回答,不调用任何工具。 |
| - 支持工具消融 profile:在 search 消融中移除 `image_search / web_search`,在 |
| location 消融中移除 `choose_frames / zoom_in`,并同步约束 prompt 与运行时白名单。 |
| - 支持 Vertex 原生 Gemini `generateContent`:把 OpenAI-style messages 转为 |
| `systemInstruction + contents`,每次模型请求按账号池选择 project,遇到 429/RESOURCE_EXHAUSTED |
| 时切换到下一个 project 重试。 |
| - 根据数据集配置动态执行 EM 或 LLM judge,并生成 HTML 结果浏览页。 |
| |
| 输出: |
| - 在输出目录下写入 `results.jsonl`、`summary.json`、`results.html`。 |
| - 对视频样本额外落盘抽帧缓存与 HTML 可视化图片,便于复查轨迹。 |
| """ |
|
|
| import argparse |
| import asyncio |
| import base64 |
| import ipaddress |
| import io |
| import json |
| import math |
| import os |
| import re |
| import string |
| import sys |
| import time |
| from pathlib import Path |
| from typing import Optional |
| from urllib.parse import urlparse |
|
|
| import atexit |
| import hashlib |
| import sqlite3 |
|
|
| import aiohttp |
| import yaml |
| from PIL import Image |
|
|
| |
| from playwright.async_api import async_playwright |
|
|
| from video_dr_bridge import ( |
| DEFAULT_VIDEO_INITIAL_FRAMES, |
| DEFAULT_VIDEO_INTERVAL_SAMPLES, |
| DEFAULT_VIDEO_JPEG_QUALITY, |
| DEFAULT_VIDEO_MAX_RESOLUTION, |
| GATEWAY_TOKEN as VIDEO_DR_GATEWAY_TOKEN, |
| GATEWAY_USERID as VIDEO_DR_GATEWAY_USERID, |
| GATEWAY_USERNAME as VIDEO_DR_GATEWAY_USERNAME, |
| GATEWAY_URL as VIDEO_DR_GATEWAY_URL, |
| IMAGE_SEARCH_MODE as VIDEO_DR_IMAGE_SEARCH_MODE, |
| ImageSearchCache, |
| MARS_RETRIEVAL_ADDRESS, |
| MARS_SUMMARIZER_ADDRESS, |
| MARS_SUMMARIZER_MODEL, |
| VIDEO_DR_SYSTEM_PROMPT, |
| add_search_padding as vdr_add_search_padding, |
| crop_frame as vdr_crop_frame, |
| extract_video_frames_1fps, |
| get_bbox_config as vdr_get_bbox_config, |
| get_video_dr_system_prompt, |
| get_frame as vdr_get_frame, |
| load_videodr_csv_samples, |
| mars_web_search as vdr_mars_web_search, |
| normalize_bbox as vdr_normalize_bbox, |
| real_image_search as vdr_real_image_search, |
| sample_interval as vdr_sample_interval, |
| uniform_sample_indices as vdr_uniform_sample_indices, |
| ) |
|
|
# Project root: two directory levels above this script.
DEMO_ROOT = Path(__file__).resolve().parent.parent
# Local data and secrets directories; each can be overridden through an
# environment variable, and `~` is expanded.
LOCAL_DATA_ROOT = Path(
    os.environ.get("VIDEO_DEEP_RESEARCH_DATA_ROOT", str(DEMO_ROOT / "local_data"))
).expanduser()
LOCAL_SECRETS_DIR = Path(
    os.environ.get("VIDEO_DEEP_RESEARCH_SECRETS_DIR", str(DEMO_ROOT / "secrets"))
).expanduser()


# Default benchmark root (`<data root>/eval`); presumably the default for
# `--eval-root` — confirm against the argument parser.
DEFAULT_EVAL_ROOT = str(LOCAL_DATA_ROOT / "eval")


# Default OpenAI-compatible endpoint and model-gateway identities.
# Tokens / API keys are not inlined here: only the paths of the files under
# LOCAL_SECRETS_DIR that hold them are configured.
# NOTE(review): the endpoint IP, usernames and user IDs are hard-coded defaults;
# consider sourcing them from the environment as well.
DEFAULT_COMPANY_OPENAI_BASE_URL = "http://35.220.164.252:3888/v1"
DEFAULT_COMPANY_OPENAI_API_KEY_FILE = str(LOCAL_SECRETS_DIR / "company_openai_api_key.txt")
DEFAULT_MODEL_GATEWAY_USERNAME = "kaiweichen"
DEFAULT_MODEL_GATEWAY_USERID = "483124"
DEFAULT_MODEL_GATEWAY_TOKEN_FILE = str(LOCAL_SECRETS_DIR / "model_gateway_token.txt")
DEFAULT_GPT54_GATEWAY_USERNAME = "jemminyang"
DEFAULT_GPT54_GATEWAY_USERID = "491087"
DEFAULT_GPT54_GATEWAY_TOKEN_FILE = str(LOCAL_SECRETS_DIR / "model_gateway_gpt54_token.txt")
DEFAULT_GEMINI_GATEWAY_USERNAME = "suyuanhuang"
DEFAULT_GEMINI_GATEWAY_USERID = "485311"
DEFAULT_GEMINI_GATEWAY_TOKEN_FILE = str(LOCAL_SECRETS_DIR / "model_gateway_gemini_token.txt")
DEFAULT_TAVILY_API_KEY_FILE = str(LOCAL_SECRETS_DIR / "tavily_api_keys.txt")
# Alias tables: lower-cased user-facing model names -> exact model names sent to
# the model gateway (GATEWAY_*) or to the OpenAI-compatible API (OPENAI_*).
# normalize_model_name_for_client looks names up after `.strip().lower()` and
# passes unknown names through unchanged.
GATEWAY_MODEL_NAME_ALIASES = {
    "gpt5.4": "GPT-5.4",
    "gpt-5.4": "GPT-5.4",
    "gpt_5.4": "GPT-5.4",
    "gpt5.2": "gpt-5.2",
    "gpt-5.2": "gpt-5.2",
    "gpt_5.2": "gpt-5.2",
    "gpt4o": "gpt-4o",
    "gpt-4o": "gpt-4o",
    "gpt_4o": "gpt-4o",
    "gemini3propreview": "gemini-3-pro-preview",
    "gemini-3-pro-preview": "gemini-3-pro-preview",
    "gemini_3_pro_preview": "gemini-3-pro-preview",
}
OPENAI_MODEL_NAME_ALIASES = {
    "gpt4o": "gpt-4o",
    "gpt-4o": "gpt-4o",
    "gpt_4o": "gpt-4o",
    "gpt5.2": "gpt-5.2",
    "gpt-5.2": "gpt-5.2",
    "gpt_5.2": "gpt-5.2",
}
|
|
# Default system prompt for image benchmarks in tool mode: every turn starts with
# a <thinking> block and then contains either one <tool_call> or the final
# <answer>.
DEFAULT_IMAGE_DATASET_SYSTEM_PROMPT = """#Role
You are a step-by-step reasoning assistant.
Given a question, your task is to solve the problem one substep at a time.

## Guiding Principles
At each turn, you must either:
1. Issue one specific tool enclosed in <tool_call></tool_call> tags,
2. Or provide the final answer enclosed in <answer></answer> tags.

All outputs must begin with a thought enclosed in <thinking></thinking> tags,
explaining your current reasoning and what to do next.

## Output Format (strict)
Always start with <thinking>. Do not output the previous reasoning chain.

1. If reasoning continues:
<thinking> Your current reasoning and next plan </thinking>
<tool_call> One precise tool call to assist your reasoning </tool_call>

2. If ready to conclude:
<thinking> Summarize all reasoning and derive the answer </thinking>
<answer> Final answer </answer>"""


# System prompt for answering general video questions directly from the sampled
# (1 fps) frames plus optional supplemental images, without any tool calls.
# NOTE(review): this prompt asks for <think></think> while the image prompt above
# uses <thinking></thinking>; confirm the answer parser accepts both tag styles.
GENERAL_VIDEO_DIRECT_SYSTEM_PROMPT = """# Role
You are an advanced general video understanding assistant.
Given a user query about a video, answer directly from the provided video frames and any supplemental images.

# Video Context
- The input video has already been converted to 1 frame per second (1 fps).
- You are provided with uniformly sampled frames from the video.
- Some questions may include supplemental images referenced as <image 1>, <image 2>, etc.; these images are provided after the sampled video frames.

# Direct Answer Rules
1. Do not call tools. Do not output tool-call tags or tool-like JSON.
2. Use only the provided frames, supplemental images, and your internal knowledge.
3. For multiple-choice questions, put only the option letter in the final answer.
4. For short-answer questions, put only the concise answer in the final answer.
5. If evidence is incomplete, answer with the best supported conclusion.

# Output Format (STRICT)
You may include brief reasoning in <think></think>, then provide the final answer in <answer></answer>.

<think> ... </think>
<answer> Final answer to the user's query </answer>"""
|
|
# Benchmark names treated as multiple-choice by default. Per the module
# docstring, multiple-choice benchmarks are scored with exact match (EM), while
# open-QA / search benchmarks use an LLM judge.
DEFAULT_MCQ_BENCHMARKS = {
    "hr_bench_4k",
    "hr_bench_8k",
    "vstar_bench",
}
|
|
|
|
class FatalAPIError(Exception):
    """Signal a search or LLM-judge failure that must abort the whole evaluation."""
|
|
|
|
class URLFetchError(Exception):
    """Signal a retriable URL fetch failure (timeout, HTTP error status, and similar)."""
|
|
|
|
class WebFetchStats:
    """Process-wide counters for web page fetch attempts.

    Every attempt is recorded as exactly one of success / failure / skip, so
    ``total == successful + failed + skipped`` always holds. Failures whose
    message contains ``HTTP <code>`` are additionally bucketed by status code.
    """

    # Extracts the numeric status from messages such as "HTTP 404 Not Found".
    _HTTP_CODE_RE = re.compile(r"HTTP (\d+)")

    def __init__(self):
        self.total = 0
        self.successful = 0
        self.failed = 0
        self.skipped = 0
        self.errors_by_code = {}

    def record_success(self):
        """Count one successful fetch."""
        self.total += 1
        self.successful += 1

    def record_failure(self, error_msg: str = ""):
        """Count one failed fetch, bucketing it by HTTP status when the message has one.

        ``error_msg`` may be empty or ``None``; the failure is still counted.
        """
        self.total += 1
        self.failed += 1
        match = self._HTTP_CODE_RE.search(error_msg or "")
        if match:
            code = match.group(1)
            self.errors_by_code[code] = self.errors_by_code.get(code, 0) + 1

    def record_skip(self):
        """Count one fetch that was skipped without being attempted."""
        self.total += 1
        self.skipped += 1

    def get_stats(self) -> dict:
        """Return a JSON-serializable snapshot of the counters.

        ``success_rate`` is a percentage (``0.0`` when nothing was recorded).
        ``errors_by_code`` is a copy, so callers cannot mutate the live tracker.
        """
        return {
            "total": self.total,
            "successful": self.successful,
            "failed": self.failed,
            "skipped": self.skipped,
            "success_rate": (self.successful / self.total * 100) if self.total > 0 else 0.0,
            "errors_by_code": dict(self.errors_by_code),
        }

    def format_progress(self) -> str:
        """Return a compact one-line progress string, or "" before any attempt."""
        if self.total == 0:
            return ""
        rate = self.successful / self.total * 100
        failed_str = f", {self.failed} failed" if self.failed > 0 else ""
        skipped_str = f", {self.skipped} skipped" if self.skipped > 0 else ""
        return f"Web: {self.successful}/{self.total} ({rate:.0f}%){failed_str}{skipped_str}"


# Shared tracker instance used by the rest of the module.
_web_fetch_stats = WebFetchStats()
|
|
|
|
# Matches an inline ``data:image...`` URI up to the next quote, whitespace or ``>``.
# (The previous pattern used ``\\s`` inside a raw string, which excluded the letter
# ``s`` rather than whitespace, so the base64 payload leaked into the logs.)
_DATA_IMAGE_URI_RE = re.compile(r"data:image[^\"'\s>]*")


def _truncate_debug_text(text: str, limit: int = 2000) -> str:
    """Clean and truncate debug text so error logs are not flooded by base64 images.

    Every inline ``data:image...`` URI is replaced with ``[IMAGE]`` before the
    text is cut to at most ``limit`` characters. ``None`` / empty input yields "".
    Non-string input is converted with ``str``.
    """
    cleaned = _DATA_IMAGE_URI_RE.sub("[IMAGE]", str(text or ""))
    return cleaned[:limit]
|
|
|
|
def _summarize_messages_for_debug(messages: list[dict]) -> dict:
    """Summarize the size of a chat request to help tell whether an HTTP 400 came from an oversized context.

    Returns a dict with the message count, per-role counts, total text length,
    the number of ``image_url`` content parts, the number of ``<tool_response>``
    markers, and the last non-empty user / assistant texts (cleaned and
    truncated to 400 characters). Messages whose ``content`` is ``None`` count
    as empty text instead of the literal string "None".
    """
    summary = {
        "num_messages": len(messages),
        "role_counts": {},
        "text_chars": 0,
        "num_image_parts": 0,
        "num_tool_responses": 0,
        "last_user_text": "",
        "last_assistant_text": "",
    }

    def _content_to_text(content) -> str:
        # Flatten a plain-string or list-of-parts content into text, counting
        # image parts along the way.
        if content is None:
            # The OpenAI chat format allows content to be null (e.g. on
            # assistant messages); treat it as empty text.
            return ""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for item in content:
                if isinstance(item, str):
                    parts.append(item)
                elif isinstance(item, dict):
                    if item.get("type") == "text":
                        parts.append(item.get("text", ""))
                    elif item.get("type") == "image_url":
                        summary["num_image_parts"] += 1
            return "\n".join(parts)
        return str(content)

    for msg in messages:
        role = msg.get("role", "")
        summary["role_counts"][role] = summary["role_counts"].get(role, 0) + 1
        text = _content_to_text(msg.get("content", ""))
        summary["text_chars"] += len(text)
        summary["num_tool_responses"] += text.count("<tool_response>")
        if role == "user" and text:
            summary["last_user_text"] = _truncate_debug_text(text, limit=400)
        elif role == "assistant" and text:
            summary["last_assistant_text"] = _truncate_debug_text(text, limit=400)
    return summary
|
|
|
|
def _extract_http_error_message(resp_status: int, data: dict, raw_text: str) -> str:
    """Pull a readable error message out of an OpenAI-compatible error response.

    Lookup order: ``data["error"]["message"]``, ``data["error"]`` when it is a
    non-empty string, ``data["message"]``, ``data["detail"]``, then the raw
    response body, and finally ``"HTTP <status>"``. The chosen text is passed
    through ``_truncate_debug_text`` with a 500-character limit.
    """
    def _clip(value) -> str:
        return _truncate_debug_text(value, limit=500)

    if not isinstance(data, dict):
        return _clip(raw_text or f"HTTP {resp_status}")

    error_field = data.get("error")
    if isinstance(error_field, dict):
        if error_field.get("message"):
            return _clip(error_field["message"])
    elif isinstance(error_field, str) and error_field:
        return _clip(error_field)

    for fallback_key in ("message", "detail"):
        candidate = data.get(fallback_key)
        if candidate:
            return _clip(candidate)
    return _clip(raw_text or f"HTTP {resp_status}")
|
|
|
|
def _append_no_proxy_entry(raw_entry: str) -> None:
    """Add a host (or ``host:port``) to both ``no_proxy`` and ``NO_PROXY``.

    URLs are reduced to their network location, surrounding whitespace and
    slashes are stripped, empty entries are ignored, and an entry that is
    already listed is not duplicated. Both variables are rewritten as a clean
    comma-separated list.
    """
    cleaned = (raw_entry or "").strip()
    if cleaned and "://" in cleaned:
        pieces = urlparse(cleaned)
        cleaned = pieces.netloc or pieces.path
    cleaned = cleaned.strip().strip("/")
    if not cleaned:
        return

    for var_name in ("no_proxy", "NO_PROXY"):
        stripped = (part.strip() for part in os.environ.get(var_name, "").split(","))
        listed = [part for part in stripped if part]
        if cleaned not in listed:
            listed.append(cleaned)
        os.environ[var_name] = ",".join(listed)
|
|
|
|
def configure_local_service_no_proxy(raw_url: str) -> None:
    """Make a local / private-network service bypass any HTTP proxy.

    The host is added to ``no_proxy`` / ``NO_PROXY`` (both ``host:port`` and the
    bare host) only when it is ``localhost`` or an IP literal that is private,
    loopback, or inside the carrier-grade NAT range ``100.64.0.0/10``. Other
    hosts, including DNS names other than ``localhost``, are left untouched.
    The scheme is optional (``http://`` is assumed).

    Malformed input never raises: an unparsable URL is ignored, and an invalid
    (non-numeric or out-of-range) port falls back to registering the bare host.
    """
    if not raw_url:
        return

    normalized_url = raw_url if "://" in raw_url else f"http://{raw_url}"
    try:
        parsed = urlparse(normalized_url)
    except Exception:
        return

    host = (parsed.hostname or "").strip()
    if not host:
        return
    try:
        # `.port` raises ValueError for a non-numeric or out-of-range port.
        port = parsed.port
    except ValueError:
        port = None

    is_local = host == "localhost"
    if not is_local:
        try:
            ip_obj = ipaddress.ip_address(host)
        except ValueError:
            # Not an IP literal (e.g. a DNS name): leave proxy settings alone.
            return
        in_cgn = ip_obj in ipaddress.ip_network("100.64.0.0/10")
        is_local = ip_obj.is_private or ip_obj.is_loopback or in_cgn
    if not is_local:
        return

    # Register both forms so a client matching either style bypasses the proxy.
    _append_no_proxy_entry(f"{host}:{port}" if port else host)
    _append_no_proxy_entry(host)
|
|
|
|
def _openai_api_base(base_url: str) -> str:
    """Normalize an OpenAI-compatible base URL so that it ends with ``/v1``.

    Trailing slashes are dropped, and ``/v1`` is appended only when missing.
    A falsy ``base_url`` yields ``"/v1"``.
    """
    trimmed = (base_url or "").rstrip("/")
    if trimmed.endswith("/v1"):
        return trimmed
    return trimmed + "/v1"


def _openai_chat_completions_url(base_url: str) -> str:
    """Return the ``/chat/completions`` endpoint under the normalized API base."""
    return _openai_api_base(base_url) + "/chat/completions"


def _openai_models_url(base_url: str) -> str:
    """Return the ``/models`` endpoint under the normalized API base."""
    return _openai_api_base(base_url) + "/models"
|
|
|
|
def _read_secret_file(path: str) -> str:
    """Return the whitespace-stripped contents of a secret file.

    An empty ``path`` or a file that does not exist yields "" so callers can
    fall through to the next credential source. ``~`` in the path is expanded.
    """
    if not path:
        return ""
    resolved = os.path.expanduser(path)
    try:
        with open(resolved, encoding="utf-8") as handle:
            contents = handle.read()
    except FileNotFoundError:
        return ""
    return contents.strip()
|
|
|
|
def normalize_model_name_for_client(model: str, client: str) -> str:
    """Map a user-facing model alias to the model name the given client expects.

    * ``gateway`` / ``openai``: case-insensitive lookup in the matching alias
      table; unknown names are returned stripped but otherwise unchanged.
    * ``vertex``: a leading ``google/`` prefix is removed.
    * any other client: the stripped name is returned as-is.
    """
    name = (model or "").strip()
    if client == "gateway":
        alias_table = GATEWAY_MODEL_NAME_ALIASES
    elif client == "openai":
        alias_table = OPENAI_MODEL_NAME_ALIASES
    else:
        if client == "vertex" and name.startswith("google/"):
            _, _, bare = name.partition("/")
            return bare
        return name
    return alias_table.get(name.lower(), name)
|
|
|
|
def _is_gemini_gateway_model(model: str) -> bool:
    """Return True when ``model`` resolves, after gateway aliasing, to a ``gemini-*`` model."""
    canonical = normalize_model_name_for_client(model, "gateway")
    return canonical.lower().startswith("gemini-")


def _is_gpt54_gateway_model(model: str) -> bool:
    """Return True when ``model`` resolves, after gateway aliasing, to GPT-5.4."""
    canonical = normalize_model_name_for_client(model, "gateway")
    return canonical.lower() == "gpt-5.4"


def _gateway_model_type(model: str) -> str:
    """Return the gateway model-type setting for ``model``.

    Gemini models read ``MODEL_GATEWAY_GEMINI_MODEL_TYPE`` and every other model
    reads ``MODEL_GATEWAY_MODEL_TYPE``; both default to ``"openai"``.
    """
    if _is_gemini_gateway_model(model):
        env_var = "MODEL_GATEWAY_GEMINI_MODEL_TYPE"
    else:
        env_var = "MODEL_GATEWAY_MODEL_TYPE"
    return os.environ.get(env_var, "openai")
|
|
|
|
def _get_model_gateway_credentials(model: str, api_key: str = "") -> tuple[str, str, str]:
    """Resolve ``(username, userid, token)`` for a model-gateway request.

    Gemini and GPT-5.4 models use dedicated identities; every other model uses
    the default gateway identity. For each field the first non-empty source
    wins:

    * username / userid: ``MODEL_GATEWAY_<X>_*`` env var, then ``<X>_GATEWAY_*``
      (``GATEWAY_*`` for the default identity), then the built-in default.
    * token for Gemini / GPT-5.4: env vars, then the token file, then ``api_key``.
    * token for the default identity: ``api_key`` first, then env vars, then
      the token file.

    Later sources are consulted only when earlier ones are empty, so a token
    file is never read if a token is already available.
    """
    env = os.environ.get

    if _is_gemini_gateway_model(model):
        return (
            (env("MODEL_GATEWAY_GEMINI_USERNAME")
             or env("GEMINI_GATEWAY_USERNAME")
             or DEFAULT_GEMINI_GATEWAY_USERNAME),
            (env("MODEL_GATEWAY_GEMINI_USERID")
             or env("GEMINI_GATEWAY_USERID")
             or DEFAULT_GEMINI_GATEWAY_USERID),
            (env("MODEL_GATEWAY_GEMINI_TOKEN")
             or env("GEMINI_GATEWAY_TOKEN", "")
             or _read_secret_file(
                 env("MODEL_GATEWAY_GEMINI_TOKEN_FILE", DEFAULT_GEMINI_GATEWAY_TOKEN_FILE)
             )
             or api_key),
        )

    if _is_gpt54_gateway_model(model):
        return (
            (env("MODEL_GATEWAY_GPT54_USERNAME")
             or env("GPT54_GATEWAY_USERNAME")
             or DEFAULT_GPT54_GATEWAY_USERNAME),
            (env("MODEL_GATEWAY_GPT54_USERID")
             or env("GPT54_GATEWAY_USERID")
             or DEFAULT_GPT54_GATEWAY_USERID),
            (env("MODEL_GATEWAY_GPT54_TOKEN")
             or env("GPT54_GATEWAY_TOKEN", "")
             or _read_secret_file(
                 env("MODEL_GATEWAY_GPT54_TOKEN_FILE", DEFAULT_GPT54_GATEWAY_TOKEN_FILE)
             )
             or api_key),
        )

    # Default identity: an explicit api_key takes precedence over env / file tokens.
    return (
        (env("MODEL_GATEWAY_USERNAME")
         or env("GATEWAY_USERNAME")
         or DEFAULT_MODEL_GATEWAY_USERNAME),
        (env("MODEL_GATEWAY_USERID")
         or env("GATEWAY_USERID")
         or DEFAULT_MODEL_GATEWAY_USERID),
        (api_key
         or env("MODEL_GATEWAY_TOKEN")
         or env("GATEWAY_TOKEN", "")
         or _read_secret_file(env("MODEL_GATEWAY_TOKEN_FILE", DEFAULT_MODEL_GATEWAY_TOKEN_FILE))),
    )
|
|
|
|
def create_http_session(timeout: aiohttp.ClientTimeout) -> aiohttp.ClientSession:
    """Open an aiohttp session that honours proxy settings from the environment.

    With ``trust_env=True`` aiohttp reads the standard proxy environment
    variables, so external services go through the configured proxy while hosts
    listed in ``no_proxy`` (e.g. via configure_local_service_no_proxy) are
    reached directly.
    """
    session_options = {"timeout": timeout, "trust_env": True}
    return aiohttp.ClientSession(**session_options)
|
|
|
|
class SearchCache:
    """SQLite-backed cache of web-search summaries.

    Entries live in ``<cache_dir>/search_cache.db`` in table ``cache``
    (``key``, ``value``, ``created_at``). Keys are MD5 digests of the normalized
    query, ``top_k`` and model name; values written by :meth:`set` are JSON
    objects ``{"summaries": ...}``. A fresh connection is opened for every
    operation, the database uses WAL journaling, and the WAL is checkpointed at
    interpreter exit.
    """

    def __init__(self, cache_dir: str):
        os.makedirs(cache_dir, exist_ok=True)
        self.db_path = os.path.join(cache_dir, "search_cache.db")
        self._init_db()
        self.hits = 0
        self.misses = 0
        print(f"Search cache: {self.db_path} ({self._get_entry_count()} entries)")
        # Flush the WAL back into the main database file when the process exits.
        atexit.register(self.close)

    def _init_db(self):
        """Switch the database to WAL mode and create the ``cache`` table if needed."""
        conn = sqlite3.connect(self.db_path)
        try:
            conn.execute("PRAGMA journal_mode=WAL")
            conn.execute("""
                CREATE TABLE IF NOT EXISTS cache (
                    key TEXT PRIMARY KEY,
                    value TEXT NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)
        finally:
            conn.close()

    def _get_entry_count(self) -> int:
        """Return the number of rows currently stored in ``cache``."""
        conn = sqlite3.connect(self.db_path)
        try:
            (count,) = conn.execute("SELECT COUNT(*) FROM cache").fetchone()
        finally:
            conn.close()
        return count

    def _iter_seed_db_paths(self, seed_paths: list[str]) -> list[str]:
        """Collect historical ``search_cache.db`` files from the given files / directories.

        An existing file path is taken as-is; a directory is walked recursively
        for files named ``search_cache.db``; a missing path is reported and
        skipped. This cache's own database and duplicates (compared after
        ``realpath``) are excluded. Returns the resolved paths, sorted.
        """
        own_db = os.path.realpath(self.db_path)
        found = []
        seen = set()
        for seed_path in filter(None, seed_paths):
            expanded = os.path.realpath(os.path.expanduser(seed_path))
            if os.path.isfile(expanded):
                candidates = [expanded]
            elif os.path.isdir(expanded):
                candidates = [
                    os.path.join(root, "search_cache.db")
                    for root, _, files in os.walk(expanded)
                    if "search_cache.db" in files
                ]
            else:
                print(f"[CACHE SEED] Skip missing path: {seed_path}", flush=True)
                continue

            for candidate in map(os.path.realpath, candidates):
                if candidate != own_db and candidate not in seen:
                    seen.add(candidate)
                    found.append(candidate)
        return sorted(found)

    @staticmethod
    def _read_seed_rows(db_path: str):
        """Read ``(key, value, created_at)`` rows from a cache db opened read-only.

        Returns None when the db has no compatible ``cache`` table (missing
        ``key`` or ``value`` columns). A missing ``created_at`` column yields
        ``None`` timestamps. sqlite3 errors propagate to the caller.
        """
        src = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True, timeout=10.0)
        try:
            columns = {row[1] for row in src.execute("PRAGMA table_info(cache)").fetchall()}
            if not {"key", "value"}.issubset(columns):
                return None
            if "created_at" in columns:
                return src.execute("SELECT key, value, created_at FROM cache").fetchall()
            return [
                (key, value, None)
                for key, value in src.execute("SELECT key, value FROM cache").fetchall()
            ]
        finally:
            src.close()

    def seed_from_paths(self, seed_paths: list[str]) -> dict:
        """Copy rows from historical caches into this cache without overwriting local rows.

        Returns counters: ``sources`` (db files found), ``inserted``,
        ``existing`` (rows whose key was already present) and ``skipped``
        (incompatible or unreadable dbs).
        """
        db_paths = self._iter_seed_db_paths(seed_paths)
        stats = {"sources": len(db_paths), "inserted": 0, "existing": 0, "skipped": 0}
        if not db_paths:
            print("[CACHE SEED] No historical search_cache.db found.", flush=True)
            return stats

        # INSERT OR IGNORE keeps local rows; missing timestamps default to "now".
        insert_sql = (
            "INSERT OR IGNORE INTO cache (key, value, created_at) "
            "VALUES (?, ?, COALESCE(?, CURRENT_TIMESTAMP))"
        )
        dest = sqlite3.connect(self.db_path, timeout=30.0)
        try:
            for db_path in db_paths:
                try:
                    rows = self._read_seed_rows(db_path)
                    if rows is None:
                        print(f"[CACHE SEED] Skip incompatible cache: {db_path}", flush=True)
                        stats["skipped"] += 1
                        continue
                    changes_before = dest.total_changes
                    dest.executemany(insert_sql, rows)
                    dest.commit()
                    added = dest.total_changes - changes_before
                    already = max(0, len(rows) - added)
                    stats["inserted"] += added
                    stats["existing"] += already
                    print(
                        f"[CACHE SEED] {db_path}: {added} inserted, {already} already present",
                        flush=True,
                    )
                except sqlite3.Error as exc:
                    print(f"[CACHE SEED] Skip unreadable cache {db_path}: {exc}", flush=True)
                    stats["skipped"] += 1
        finally:
            dest.close()

        total = self._get_entry_count()
        print(
            f"[CACHE SEED] Done: {stats['inserted']} inserted from {stats['sources']} source db(s); "
            f"{stats['existing']} duplicate row(s); current cache has {total} entries.",
            flush=True,
        )
        return stats

    def _make_key(self, query: str, top_k: int, model: str) -> str:
        """Hash the normalized query plus search settings into an MD5 cache key.

        The query is stripped, lower-cased and has whitespace runs collapsed.
        The fixed ``prompt`` / ``jina`` / ``think`` fields are part of the
        hashed text, so changing them would change every key.
        """
        normalized = re.sub(r'\s+', ' ', query.strip().lower())
        fields = (
            f"q={normalized}",
            f"k={top_k}",
            f"model={model}",
            "prompt=mmsearch_r1",
            "jina=0",
            "think=0",
        )
        return hashlib.md5("|".join(fields).encode("utf-8")).hexdigest()

    async def get(self, query: str, top_k: int, model: str) -> Optional[str]:
        """Return the cached summaries for this query, or None on a miss.

        Payloads shaped like ``{"summaries": ...}`` are unwrapped; any other
        stored value is returned as the raw string. Updates the hit/miss
        counters. The SQLite call itself is synchronous.
        """
        key = self._make_key(query, top_k, model)
        conn = sqlite3.connect(self.db_path)
        try:
            row = conn.execute("SELECT value FROM cache WHERE key = ?", (key,)).fetchone()
        finally:
            conn.close()
        if not row:
            self.misses += 1
            return None

        self.hits += 1
        print(f"[CACHE HIT] {query[:50]}...")
        raw_value = row[0]
        try:
            payload = json.loads(raw_value)
        except json.JSONDecodeError:
            return raw_value
        if isinstance(payload, dict) and "summaries" in payload:
            return payload["summaries"]
        return raw_value

    async def set(self, query: str, top_k: int, model: str, value: str):
        """Store ``value`` as ``{"summaries": value}``, replacing any existing entry.

        A "database is locked" error is retried up to 4 times with exponential
        backoff (0.05s, 0.1s, 0.2s, 0.4s); the fifth failure, or any other
        ``OperationalError``, is re-raised.
        """
        key = self._make_key(query, top_k, model)
        payload = json.dumps({"summaries": value})
        max_attempts = 5
        for attempt in range(max_attempts):
            conn = sqlite3.connect(self.db_path, timeout=5.0)
            try:
                conn.execute("INSERT OR REPLACE INTO cache (key, value) VALUES (?, ?)", (key, payload))
                conn.commit()
                print(f"[CACHE STORE] {query[:50]}...")
                return
            except sqlite3.OperationalError as exc:
                is_locked = "locked" in str(exc).lower()
                if not is_locked or attempt >= max_attempts - 1:
                    raise
                await asyncio.sleep(0.05 * (2 ** attempt))
            finally:
                conn.close()

    def get_stats(self) -> dict:
        """Return hit/miss counters and the hit rate as a percentage (0.0 when unused)."""
        lookups = self.hits + self.misses
        hit_rate = (self.hits / lookups * 100) if lookups > 0 else 0.0
        return {"hits": self.hits, "misses": self.misses, "total": lookups, "hit_rate": hit_rate}

    def close(self):
        """Checkpoint and truncate the WAL file; best effort, never raises."""
        try:
            conn = sqlite3.connect(self.db_path, timeout=10.0)
            try:
                conn.execute("PRAGMA wal_checkpoint(TRUNCATE)")
            finally:
                conn.close()
        except Exception:
            # Runs from atexit; a failed checkpoint must not break shutdown.
            pass
|
|
|
def resolve_cache_seed_paths(cli_paths: Optional[list[str]], no_auto_seed: bool) -> list[str]:
    """Decide which paths the search caches should be seeded from.

    Precedence: ``no_auto_seed`` disables seeding entirely; explicit CLI paths
    come next; then the ``SEARCH_CACHE_SEED_ROOTS`` environment variable
    (entries separated by ``:``, ``;`` or ``,``); otherwise the ``runs``
    directory next to this script.
    """
    if no_auto_seed:
        return []
    if cli_paths is not None:
        return list(cli_paths)

    raw_roots = os.environ.get("SEARCH_CACHE_SEED_ROOTS", "").strip()
    if not raw_roots:
        return [os.path.join(os.path.dirname(__file__), "runs")]
    pieces = (piece.strip() for piece in re.split(r"[:;,]", raw_roots))
    return [piece for piece in pieces if piece]
|
|
|
|
def seed_image_search_cache(image_search_cache: ImageSearchCache, seed_paths: list[str]) -> dict:
    """Merge historical ``image_search_cache.json`` files into the current image-search cache.

    Seed paths may be files (used only when named ``image_search_cache.json``)
    or directories (walked recursively); missing paths are silently ignored,
    as are this cache's own file and duplicates (compared after ``realpath``).
    Keys already present in ``image_search_cache.cache`` are never overwritten,
    and ``image_search_cache.save()`` is called once at the end only if
    something was inserted.

    Returns counters: ``sources``, ``inserted``, ``existing``, ``skipped``.
    """
    target_name = "image_search_cache.json"
    own_file = os.path.realpath(image_search_cache.cache_file)

    def _candidates_under(resolved_seed: str) -> list[str]:
        # An explicit file counts only with the expected name; a directory is
        # searched recursively; anything else (e.g. a missing path) yields nothing.
        if os.path.isfile(resolved_seed):
            return [resolved_seed] if os.path.basename(resolved_seed) == target_name else []
        if os.path.isdir(resolved_seed):
            return [
                os.path.join(root, target_name)
                for root, _, files in os.walk(resolved_seed)
                if target_name in files
            ]
        return []

    cache_files = []
    seen = set()
    for seed_path in filter(None, seed_paths):
        resolved_seed = os.path.realpath(os.path.expanduser(seed_path))
        for candidate in map(os.path.realpath, _candidates_under(resolved_seed)):
            if candidate == own_file or candidate in seen:
                continue
            seen.add(candidate)
            cache_files.append(candidate)

    stats = {"sources": len(cache_files), "inserted": 0, "existing": 0, "skipped": 0}
    if not cache_files:
        print("[IMAGE CACHE SEED] No historical image_search_cache.json found.", flush=True)
        return stats

    for cache_file in sorted(cache_files):
        try:
            with open(cache_file, "r", encoding="utf-8") as f:
                data = json.load(f)
            if not isinstance(data, dict):
                print(f"[IMAGE CACHE SEED] Skip incompatible cache: {cache_file}", flush=True)
                stats["skipped"] += 1
                continue
            inserted = existing = 0
            for key, value in data.items():
                if key in image_search_cache.cache:
                    existing += 1
                    continue
                image_search_cache.cache[key] = value
                inserted += 1
            stats["inserted"] += inserted
            stats["existing"] += existing
            print(
                f"[IMAGE CACHE SEED] {cache_file}: {inserted} inserted, {existing} already present",
                flush=True,
            )
        except Exception as e:
            # Best effort: a corrupt or unreadable seed file must not abort the run.
            print(f"[IMAGE CACHE SEED] Skip unreadable cache {cache_file}: {e}", flush=True)
            stats["skipped"] += 1

    if stats["inserted"]:
        image_search_cache.save()
    print(
        f"[IMAGE CACHE SEED] Done: {stats['inserted']} inserted from {stats['sources']} file(s); "
        f"{stats['existing']} duplicate row(s); current cache has {len(image_search_cache.cache)} entries.",
        flush=True,
    )
    return stats
|
|
|
|
# Supported tool-ablation profile names.
TOOL_ABLATION_PROFILES = ("none", "nosearch", "nolocation")
# Full tool sets for image benchmarks and for VideoDR samples.
IMAGE_BENCHMARK_TOOL_NAMES = {"zoom_in", "image_search", "web_search"}
VIDEO_DR_TOOL_NAMES = {"choose_frames", "find_frame", "zoom_in", "image_search", "web_search"}
# Legacy tool names mapped to the current tool names.
TOOL_NAME_ALIASES = {
    "image_zoom_in_tool": "zoom_in",
    "image_search_tool": "image_search",
    "text_search_tool": "web_search",
}


def canonicalize_tool_name(tool_name: str) -> str:
    """Translate a legacy tool name to its current name; other names pass through unchanged."""
    canonical = TOOL_NAME_ALIASES.get(tool_name)
    return tool_name if canonical is None else canonical
|
|
|
|
def get_allowed_tool_names(profile: str, task_kind: str) -> set[str]:
    """Return the tools permitted for ``task_kind`` under an ablation ``profile``.

    * ``nosearch`` removes ``image_search`` and ``web_search``.
    * ``nolocation`` removes ``choose_frames`` and ``zoom_in`` (VideoDR keeps
      ``find_frame``).
    * Any other profile, including ``None`` / ``"none"``, allows the full set.

    ``task_kind == "video_dr"`` selects the VideoDR tool set; any other value
    selects the image-benchmark tool set. A new set is returned on every call.
    """
    is_video = task_kind == "video_dr"
    restricted = {
        (True, "nosearch"): ("choose_frames", "find_frame", "zoom_in"),
        (True, "nolocation"): ("find_frame", "image_search", "web_search"),
        (False, "nosearch"): ("zoom_in",),
        (False, "nolocation"): ("image_search", "web_search"),
    }.get((is_video, profile or "none"))
    if restricted is not None:
        return set(restricted)
    return set(VIDEO_DR_TOOL_NAMES if is_video else IMAGE_BENCHMARK_TOOL_NAMES)
|
|
|
|
def format_tool_names(tool_names: set[str]) -> str:
    """Render known tool names as a comma-separated list of backticked names.

    Names appear in a fixed canonical order; names outside that list are dropped.
    """
    display_order = ("choose_frames", "find_frame", "zoom_in", "image_search", "web_search")
    return ", ".join("`" + name + "`" for name in display_order if name in tool_names)


def build_disallowed_tool_response(tool_name: str, allowed_tool_names: set[str]) -> str:
    """Build the ``<tool_response>`` returned when a tool outside the current profile is called."""
    message = (
        f"Error: tool `{tool_name}` is not available in this evaluation profile. "
        f"Use only these tools: {format_tool_names(allowed_tool_names)}."
    )
    return "<tool_response>\n" + message + "\n</tool_response>"
|
|
|
|
def load_tool_config(
    tool_config_path: str,
    allowed_tool_names: Optional[set[str]] = None,
    normalize_image_schema: bool = False,
    normalize_video_schema: bool = False,
) -> str:
    """Load tool schemas from a YAML file and render the ``# Tools`` prompt section.

    Args:
        tool_config_path: YAML file (read as UTF-8) with a top-level ``tools``
            list; each entry's ``tool_schema`` is emitted as one JSON line.
        allowed_tool_names: if given, tools whose ``function.name`` is not in
            this set are skipped (and reported).
        normalize_image_schema / normalize_video_schema: rewrite tool
            descriptions / bbox schema on a deep copy of each schema.

    Returns:
        The prompt section listing the schemas inside ``<tools></tools>``
        followed by the ``<tool_call>`` calling convention. An empty file or a
        null ``tools`` key yields a section with no tools instead of crashing.
    """
    with open(tool_config_path, encoding="utf-8") as f:
        # An empty YAML document parses to None.
        config = yaml.safe_load(f) or {}

    tools = config.get("tools") or []
    tool_definitions = []
    loaded_names = []
    skipped_names = []

    for tool in tools:
        schema = tool.get("tool_schema", {})
        tool_name = schema.get("function", {}).get("name", "")
        if allowed_tool_names is not None and tool_name not in allowed_tool_names:
            skipped_names.append(tool_name or "<unnamed>")
            continue
        if normalize_image_schema or normalize_video_schema:
            # Deep copy so the normalizers never mutate the loaded config.
            schema = json.loads(json.dumps(schema))
            if normalize_image_schema:
                _normalize_image_tool_schema(schema)
            if normalize_video_schema:
                _normalize_video_tool_schema(schema)
        tool_definitions.append(json.dumps(schema, ensure_ascii=False))
        loaded_names.append(tool_name or "<unnamed>")

    print(f" Loaded {len(tool_definitions)} tools: {', '.join(loaded_names)}")
    if skipped_names:
        print(f" Skipped tools for this benchmark: {', '.join(skipped_names)}")

    tool_def_str = "\n".join(tool_definitions)
    tools_section = f"""# Tools

You may call one or more functions to assist with the user query.

You are provided with function signatures within <tools></tools> XML tags:
<tools>
{tool_def_str}
</tools>

For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{{"name": <function-name>, "arguments": <args-json-object>}}
</tool_call>"""

    return tools_section
|
|
|
|
def _normalize_image_tool_schema(schema: dict) -> None:
    """Rewrite an image-benchmark tool schema in place for single-image semantics.

    ``zoom_in`` and ``image_search`` receive fixed descriptions, and any
    ``bbox`` parameter is documented as 0-1000 relative coordinates of the
    original image. Other tools keep their description; a schema without a
    ``function`` entry is left unchanged.
    """
    function = schema.get("function", {})
    tool_name = function.get("name", "")
    if tool_name == "zoom_in":
        new_description = "Magnify a specific region of the input image for detailed visual analysis."
    elif tool_name == "image_search":
        new_description = (
            "Reverse image search a specific region of the input image to identify unknown entities. "
            "Use this if you DO NOT understand what the zoomed-in entity is."
        )
    else:
        new_description = None
    if new_description is not None:
        function["description"] = new_description

    bbox_schema = function.get("parameters", {}).get("properties", {}).get("bbox")
    if not isinstance(bbox_schema, dict):
        return
    bbox_schema["description"] = (
        "Bounding box [x1, y1, x2, y2] in 0-1000 relative coordinates of the original input image."
    )
|
|
|
|
def _normalize_video_tool_schema(schema: dict) -> None:
    """Rewrite a VideoDR tool schema in place for 0-1000 bbox semantics on the locked frame.

    ``zoom_in`` and ``image_search`` receive descriptions that refer to the
    frame locked by ``find_frame``. Any ``bbox`` parameter is constrained to
    exactly four integers and documented as 0-1000 relative coordinates. Keys
    are assigned in the order items, minItems, maxItems, description, which
    fixes their order in the serialized JSON.
    """
    function = schema.get("function", {})
    tool_name = function.get("name", "")
    if tool_name == "zoom_in":
        new_description = (
            "Magnify a specific region of the frame currently locked by find_frame for detailed visual analysis."
        )
    elif tool_name == "image_search":
        new_description = (
            "Reverse image search a specific region of the locked frame to identify unknown entities. "
            "Use this if you DO NOT understand what the entity in the locked frame is."
        )
    else:
        new_description = None
    if new_description is not None:
        function["description"] = new_description

    bbox_schema = function.get("parameters", {}).get("properties", {}).get("bbox")
    if not isinstance(bbox_schema, dict):
        return
    bbox_schema["items"] = {"type": "integer"}
    bbox_schema["minItems"] = 4
    bbox_schema["maxItems"] = 4
    bbox_schema["description"] = (
        "Bounding box [x1, y1, x2, y2] in 0-1000 relative coordinates on the frame currently locked by find_frame."
    )
|
|
|
|
def build_image_tool_system_prompt(
    task_instruction: str,
    tools_section: str,
    max_turns: int = 10,
    allowed_tool_names: Optional[set[str]] = None,
) -> str:
    """Build the DeepResearch-style tool system prompt for image benchmarks.

    Every workflow rule is derived from ``allowed_tool_names`` so that the
    prompt never tells the model to use a tool the evaluation profile removed.

    Args:
        task_instruction: Optional dataset-specific instruction; appended as a
            "Task-Specific Instruction" section when non-empty after stripping.
        tools_section: Pre-rendered tools description, inserted verbatim.
        max_turns: Maximum number of attempts announced to the model.
        allowed_tool_names: Tool names exposed in this profile. ``None`` falls
            back to ``IMAGE_BENCHMARK_TOOL_NAMES``.

    Returns:
        The complete system prompt string.
    """
    task_instruction = (task_instruction or "").strip()
    # NOTE(review): because of `or`, an empty collection also falls back to the
    # full tool list; kept as-is for caller compatibility.
    allowed_tool_names = set(allowed_tool_names or IMAGE_BENCHMARK_TOOL_NAMES)
    task_block = ""
    if task_instruction:
        task_block = f"""
# Task-Specific Instruction
{task_instruction}
"""
    can_zoom = "zoom_in" in allowed_tool_names
    can_image_search = "image_search" in allowed_tool_names
    can_web_search = "web_search" in allowed_tool_names
    can_search = can_image_search or can_web_search

    if can_search:
        role_goal = (
            "combining visual clues from a single image with available external search evidence"
        )
    else:
        role_goal = "using the visual evidence in a single image without external search tools"

    if can_zoom:
        if can_search:
            initial_action = (
                "Start from the full image. If the key evidence is already clear, answer directly or use an available search tool. "
                "Otherwise, use `zoom_in(bbox)` to inspect a relevant region more closely."
            )
        else:
            initial_action = (
                "Start from the full image. If the key evidence is already clear, answer directly. "
                "Otherwise, use `zoom_in(bbox)` to inspect a relevant region more closely."
            )
        region_rule = (
            "`zoom_in` is used for detailed local inspection. After examining a region, continue with another "
            "allowed tool only when it adds new evidence."
        )
    else:
        # Without zoom_in, only mention the tools that actually remain.
        if can_image_search:
            initial_action = (
                "Start from the full image. `zoom_in` is not available in this evaluation profile, so choose any bbox "
                "for `image_search` directly from the original image when identification is needed."
            )
        elif can_web_search:
            initial_action = (
                "Start from the full image. `zoom_in` and `image_search` are not available in this evaluation "
                "profile, so reason over the original image and use `web_search(query)` only for grounded queries."
            )
        else:
            initial_action = (
                "Start from the full image. No inspection or search tool is available in this evaluation "
                "profile, so answer directly from the original image."
            )
        if can_search:
            region_rule = (
                "No zoomed crop will be returned. Use the original image and the available search tools to gather evidence."
            )
        else:
            region_rule = "No zoomed crop will be returned. Use only the original image to gather evidence."

    search_rules = []
    if can_image_search:
        # Only mention web_search in the confirmation rule when it is allowed.
        if can_web_search:
            confirm_rule = "Before using a visually inferred entity name in `web_search` or in the final answer, confirm that entity with `image_search(bbox)` on the most diagnostic region."
        else:
            confirm_rule = "Before using a visually inferred entity name in the final answer, confirm that entity with `image_search(bbox)` on the most diagnostic region."
        search_rules.extend([
            "Do NOT rely on a visual guess, memory, or resemblance to identify a real-world person, place, artwork, product, logo, team, landmark, media title, or scientific object.",
            confirm_rule,
            "If the entity name is directly readable from text in the image, verify the text carefully with the full image or `zoom_in` when available; otherwise use `image_search` for identity confirmation.",
            "If `image_search` contradicts your initial guess, discard the guess and reason from the search evidence.",
        ])
    elif can_web_search:
        search_rules.extend([
            "`image_search` is not available. Do not invent an entity name from appearance alone; use `web_search(query)` only when the query can be grounded in readable text or unambiguous visual evidence.",
            "If the visual identity is uncertain and no image search is available, state the uncertainty in `<think>` and answer from the best supported evidence.",
        ])
    else:
        search_rules.extend([
            "`image_search` and `web_search` are not available in this evaluation profile.",
            "Answer from the image, the question, and any knowledge already contained in the model. Do not call search tools or claim that external verification was performed.",
        ])
    search_rules_text = "\n".join(f" - {line}" for line in search_rules)

    loop_rule_parts = []
    if can_zoom:
        loop_rule_parts.append("Do not spend many turns only zooming.")
    if can_search:
        loop_rule_parts.append("If the current evidence is not sufficient, switch to another allowed evidence-gathering tool or provide the best supported answer.")
    elif can_zoom:
        loop_rule_parts.append("If zooming no longer adds visual evidence, stop using tools and answer.")
    else:
        # No zoom and no search: the old zoom-specific rule would be contradictory here.
        loop_rule_parts.append("No evidence-gathering tool is available, so do not call tools; answer directly.")
    loop_rule = " ".join(loop_rule_parts)

    bbox_tools = [name for name in ["zoom_in", "image_search"] if name in allowed_tool_names]
    if bbox_tools:
        bbox_rule = (
            f"Whenever you use {format_tool_names(set(bbox_tools))}, provide a precise bbox in 0-1000 "
            "relative coordinates that tightly covers the relevant region in the original input image."
        )
    else:
        bbox_rule = "No bbox-based inspection tool is available in this evaluation profile."

    return f"""# Role
You are an advanced Image DeepResearch reasoning assistant.
Given a user query that requires {role_goal}, your task is to solve the problem step-by-step while using only the tools exposed below.

# Image Context
- You are given one input image.
- The image may contain multiple entities, regions, text snippets, or visual clues relevant to the question.
- You should first reason over the full image, then inspect specific regions only when additional detail is necessary.

{tools_section}

# Tool Dependency & Workflow Rules (STRICT)
1. **Allowed Tools:** You may only call {format_tool_names(allowed_tool_names)}. Do not call any tool that is not listed in the `<tools>` section.
2. **Initial Action:** {initial_action}
3. **Region-Based Analysis:** {region_rule}
4. **Search Strategy Selection:**
{search_rules_text}
5. **Avoid Tool Loops:** {loop_rule}
6. **Bounding Boxes:** {bbox_rule}
7. **Retry Mechanism:** If the current region or search result is not sufficient, inspect another region or try another search. You have a MAXIMUM OF {max_turns} ATTEMPTS (loops) to find the answer. If all attempts are exhausted, provide the best answer you have with an explicit note about remaining uncertainty.
{task_block}
# Output Format (STRICT)
At each turn, you must either issue ONE precise tool call OR provide the final answer.
All outputs MUST begin with a thought process enclosed in <think></think> tags.

1. If reasoning continues:
<think> ... </think>
<tool_call>{{"name": "<function-name>", "arguments": <args-json-object>}}</tool_call>

2. If ready to conclude (after gathering sufficient information):
<think> ... </think>
<answer> Final answer to the user's query </answer>"""
|
|
|
|
def build_video_tool_system_prompt(
    tools_section: str,
    allowed_tool_names: set[str],
    max_turns: int = 10,
) -> str:
    """Build the system prompt for a VideoDR tool-ablation profile.

    All workflow rules are derived from ``allowed_tool_names`` so the prompt
    only describes tools that the profile actually exposes.

    Args:
        tools_section: Pre-rendered tools description, inserted verbatim.
        allowed_tool_names: Tool names exposed in this profile.
        max_turns: Maximum number of attempts announced to the model.

    Returns:
        The complete system prompt string.
    """
    has_choose = "choose_frames" in allowed_tool_names
    has_zoom = "zoom_in" in allowed_tool_names
    has_image_search = "image_search" in allowed_tool_names
    has_web_search = "web_search" in allowed_tool_names

    # Navigation rules: with `choose_frames` the model may first narrow an
    # interval; without it, it goes straight from the sparse frames to `find_frame`.
    if not has_choose:
        initial_action = (
            "`choose_frames` is not available. Use the initial 64 sparse frames to select the most relevant "
            "frame index, then call `find_frame` to lock onto that frame."
        )
        retry_action = "call `find_frame` on another candidate frame from the initial sparse frames"
        interval_rule = (
            "Since `choose_frames` is not available, do not ask for an interval. Move from the initial sparse "
            "frames directly to `find_frame`."
        )
    else:
        initial_action = (
            "You can directly call `find_frame` if the target is obvious in the initial 64 sparse frames. "
            "Otherwise, call `choose_frames` to narrow down the search interval."
        )
        retry_action = "loop back to `choose_frames` or `find_frame` to explore another segment/frame"
        interval_rule = (
            "After calling `choose_frames` and receiving the sub-frames, you MUST call `find_frame` to lock "
            "onto a specific single frame before performing any detailed actions."
        )

    # Actions permitted once a frame is locked, in a fixed order.
    detail_catalog = (
        (has_zoom, "use `zoom_in` to inspect local details"),
        (has_image_search, "use `image_search` to identify an unknown entity"),
        (has_web_search, "use `web_search` for external facts after the entity or clue is grounded"),
    )
    detail_actions = [phrase for enabled, phrase in detail_catalog if enabled]
    if not detail_actions:
        detail_rule = (
            "`find_frame` locks the working frame. No detail or search tool is available after that, so answer "
            "from the locked frame and the initial sparse frames."
        )
    else:
        detail_rule = (
            "`find_frame` locks the working frame. After a successful `find_frame`, you may "
            + "; ".join(detail_actions)
            + "."
        )

    # One rule per search tool, plus a closing rule when both are disabled.
    image_search_rule = (
        "If you visually locate an entity but do not know its name or identity, use `image_search(bbox)`."
        if has_image_search
        else "`image_search` is not available; do not claim reverse-image verification."
    )
    web_search_rule = (
        "If you visually locate and recognize the entity, use `web_search(query)` for related external knowledge or facts."
        if has_web_search
        else "`web_search` is not available; do not call text search or claim web verification."
    )
    search_rules = [image_search_rule, web_search_rule]
    if not (has_image_search or has_web_search):
        search_rules.append("For this no-search profile, answer only from video evidence and model-internal knowledge.")
    search_rules_text = "\n".join(" - " + rule for rule in search_rules)

    # Tools that take a bbox argument, in a fixed order.
    bbox_tools = [name for name in ("zoom_in", "image_search") if name in allowed_tool_names]
    if not bbox_tools:
        bbox_rule = "No bbox-based tool is available in this evaluation profile."
    else:
        bbox_rule = (
            f"Whenever you use {format_tool_names(set(bbox_tools))}, provide a precise bbox in 0-1000 "
            "relative coordinates on the frame currently locked by `find_frame`."
        )

    return f"""# Role
You are an advanced Video DeepResearch reasoning assistant.
Given a user query that requires combining visual clues from a video with the tools exposed below, your task is to solve the problem step-by-step by deeply navigating the video.

# Video Context
- The original video has been converted to 1 frame per second (1 fps).
- The original frame index (e.g., Frame 1, Frame 10) is watermarked in the top-left corner of every frame.
- You are initially provided with 64 uniformly sampled frames from the video.

{tools_section}

# Tool Dependency & Workflow Rules (STRICT)
1. **Allowed Tools:** You may only call {format_tool_names(allowed_tool_names)}. Do not call any tool that is not listed in the `<tools>` section.
2. **Initial Action:** {initial_action}
3. **Interval to Frame:** {interval_rule}
4. **Detailing & Searching:** {detail_rule}
5. **Search Strategy Selection:**
{search_rules_text}
6. **Bounding Boxes:** {bbox_rule}
7. **Retry Mechanism:** If your current evidence fails, yields incorrect info, or does not solve the query, you can {retry_action}. You have a MAXIMUM OF {max_turns} ATTEMPTS (loops) to find the answer. If all attempts are exhausted, provide the best answer you have with an explicit note about remaining uncertainty.

# Output Format (STRICT)
At each turn, you must either issue ONE precise tool call OR provide the final answer.
All outputs MUST begin with a thought process enclosed in <think></think> tags.

1. If reasoning continues:
<think> ... </think>
<tool_call>{{"name": "<function-name>", "arguments": <args-json-object>}}</tool_call>

2. If ready to conclude:
<think> ... </think>
<answer> Final answer to the user's query </answer>"""
|
|
|
|
| |
| |
| |
|
|
def round_by_factor(number: int, factor: int) -> int:
    """Return the multiple of ``factor`` nearest to ``number`` (ties follow built-in ``round``)."""
    quotient = number / factor
    return factor * round(quotient)
|
|
|
|
def ceil_by_factor(number: int, factor: int) -> int:
    """Return the smallest multiple of ``factor`` that is >= ``number``."""
    steps = math.ceil(number / factor)
    return factor * steps
|
|
|
|
def floor_by_factor(number: int, factor: int) -> int:
    """Return the largest multiple of ``factor`` that is <= ``number``."""
    steps = math.floor(number / factor)
    return factor * steps
|
|
|
|
def smart_resize(
    height: int,
    width: int,
    factor: int = 32,
    min_pixels: int = 65536,
    max_pixels: int = 8294400,
) -> tuple[int, int]:
    """Resize dimensions to be divisible by factor and within pixel limits (Qwen-VL style).

    Both sides are first rounded to the nearest multiple of ``factor`` (at least
    ``factor``). If the area then exceeds ``max_pixels`` the image is scaled down
    (flooring each side to a multiple of ``factor``, never below ``factor``); if it
    is below ``min_pixels`` it is scaled up (ceiling each side).

    Args:
        height: Source height in pixels; must be positive.
        width: Source width in pixels; must be positive.
        factor: Each returned side is a multiple of this value.
        min_pixels: Lower bound on the target area.
        max_pixels: Upper bound on the target area.

    Returns:
        ``(resized_height, resized_width)``.

    Raises:
        ValueError: if a side is not positive or the aspect ratio exceeds 200:1.
    """
    # Guard before dividing by min(height, width) below.
    if height <= 0 or width <= 0:
        raise ValueError(f"Image dimensions must be positive, got {height}x{width}")
    if max(height, width) / min(height, width) > 200:
        raise ValueError("Aspect ratio too extreme")

    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))

    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        # Clamp to one `factor`: with a small max_pixels and a wide aspect ratio the
        # floored short side could otherwise become 0 and break the later resize.
        h_bar = max(factor, floor_by_factor(int(height / beta), factor))
        w_bar = max(factor, floor_by_factor(int(width / beta), factor))
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(int(height * beta), factor)
        w_bar = ceil_by_factor(int(width * beta), factor)

    return h_bar, w_bar
|
|
|
|
def process_image(
    image: Image.Image,
    min_pixels: int = 65536,
    max_pixels: int = 8294400,
    factor: int = 32,
    qwen_vl_processing: bool = True,
) -> Image.Image:
    """Convert an image to RGB and resize it for model input.

    Args:
        image: Source PIL image (any mode).
        min_pixels: Lower pixel bound (used only in Qwen-VL mode).
        max_pixels: Upper pixel bound.
        factor: Side-length multiple (used only in Qwen-VL mode).
        qwen_vl_processing: If True, use ``smart_resize``; otherwise only
            downscale (aspect-preserving) when the area exceeds ``max_pixels``.

    Returns:
        An RGB image; the input object itself when no conversion or resize was needed.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")

    width, height = image.size

    if qwen_vl_processing:
        # Qwen-VL style: snap both sides to multiples of `factor` within the pixel budget.
        resized_height, resized_width = smart_resize(height, width, factor, min_pixels, max_pixels)
    elif width * height > max_pixels:
        # Plain aspect-preserving downscale; never collapse a side to 0 px.
        scale = math.sqrt(max_pixels / (width * height))
        resized_width = max(1, int(width * scale))
        resized_height = max(1, int(height * scale))
    else:
        resized_width, resized_height = width, height

    if (resized_width, resized_height) != (width, height):
        image = image.resize((resized_width, resized_height), Image.Resampling.BICUBIC)

    return image
|
|
|
|
def load_and_process_image(
    path: str,
    min_pixels: int = 65536,
    max_pixels: int = 8294400,
    factor: int = 32,
    qwen_vl_processing: bool = True,
) -> Image.Image:
    """Load an image from ``path`` and run it through ``process_image``.

    The image is opened in a ``with`` block so the underlying file is closed
    deterministically. This matters for multi-frame formats, which ``load()``
    alone does not close. ``load()`` reads the pixel data first, so the
    returned image does not need the file to remain open.
    """
    with Image.open(path) as img:
        img.load()
        return process_image(img, min_pixels, max_pixels, factor, qwen_vl_processing)
|
|
|
|
def image_to_base64(img: Image.Image) -> tuple[str, str]:
    """Encode a PIL image as base64 PNG.

    Images whose mode is neither RGB nor RGBA are converted to RGB first.

    Returns:
        ``(base64_text, "image/png")``.
    """
    encodable = img if img.mode in ("RGB", "RGBA") else img.convert("RGB")
    with io.BytesIO() as png_buffer:
        encodable.save(png_buffer, format="PNG")
        raw_bytes = png_buffer.getvalue()
    encoded = base64.standard_b64encode(raw_bytes)
    return encoded.decode("utf-8"), "image/png"
|
|
|
|
def crop_frame_to_rgb_jpeg(frame_path: str, bbox: list[float], output_path: str, quality: int = 90) -> None:
    """Crop an image with a normalized bbox and save the result as an RGB JPEG.

    ``bbox`` is ``[x1, y1, x2, y2]`` as fractions of the image width/height. The
    coordinates are scaled to pixels, rounded and clamped to the image. If the
    resulting box is empty, a warning is printed and the full image is saved.
    Non-RGB inputs (e.g. RGBA) are converted because JPEG has no alpha channel.
    Parent directories of ``output_path`` are created as needed.
    """
    def to_pixel(fraction, limit):
        # Scale a fraction to a pixel offset clamped into [0, limit].
        return max(0, min(limit, int(round(fraction * limit))))

    with Image.open(frame_path) as source:
        source.load()
        frame = source if source.mode == "RGB" else source.convert("RGB")
        frame_w, frame_h = frame.size
        left, top, right, bottom = bbox
        box = (
            to_pixel(left, frame_w),
            to_pixel(top, frame_h),
            to_pixel(right, frame_w),
            to_pixel(bottom, frame_h),
        )
        if not (box[2] > box[0] and box[3] > box[1]):
            print(
                f"[WARN] Invalid crop bbox after scaling: {bbox}; using full image instead",
                flush=True,
            )
            box = (0, 0, frame_w, frame_h)
        region = frame.crop(box)
        os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
        region.save(output_path, "JPEG", quality=quality)
|
|
|
|
def crop_image(
    image: Image.Image,
    bbox: list,
    coord_scale: float = 1000.0,
    min_pixels: int = 65536,
    max_pixels: int = 8294400,
    factor: int = 32,
    qwen_vl_processing: bool = True,
    padding: tuple = (0.0, 0.0),
) -> Image.Image:
    """Crop image using bbox [x1, y1, x2, y2] on a 0-``coord_scale`` grid (default 0-1000).

    - Every coordinate is clamped into the image and swapped corners are reordered.
    - Supports configurable padding as fractions of width/height (default: none),
      each side capped at 600 px.
    - If the box is empty after scaling, a warning is printed and the full image
      is used (same policy as ``crop_frame_to_rgb_jpeg``).
    - In Qwen-VL mode, crops smaller than 28 px are upsampled and then passed
      through ``process_image`` (smart_resize); otherwise the crop is returned as RGB.

    Raises:
        ValueError: if ``bbox`` does not have exactly four values.
    """
    img_w, img_h = image.size

    # Normalize to [0, 1] and clamp BOTH ends; clamping only one side lets
    # e.g. x1 > 1 produce an inverted box.
    x1, y1, x2, y2 = [min(1.0, max(0.0, float(c) / coord_scale)) for c in bbox]
    # Accept corners given in the wrong order instead of letting PIL reject the box.
    x1, x2 = min(x1, x2), max(x1, x2)
    y1, y2 = min(y1, y2), max(y1, y2)

    if padding[0] > 0 or padding[1] > 0:
        # Padding is relative to image size but capped at 600 px per side.
        padding_cap = (600.0 / img_w, 600.0 / img_h)
        actual_padding = (
            min(padding[0], padding_cap[0]),
            min(padding[1], padding_cap[1]),
        )
        x1 = max(0.0, x1 - actual_padding[0])
        y1 = max(0.0, y1 - actual_padding[1])
        x2 = min(1.0, x2 + actual_padding[0])
        y2 = min(1.0, y2 + actual_padding[1])

    crop_box = (int(x1 * img_w), int(y1 * img_h), int(x2 * img_w), int(y2 * img_h))
    if crop_box[2] <= crop_box[0] or crop_box[3] <= crop_box[1]:
        print(
            f"[WARN] Invalid crop bbox after scaling: {bbox}; using full image instead",
            flush=True,
        )
        crop_box = (0, 0, img_w, img_h)
    cropped = image.crop(crop_box)

    if qwen_vl_processing:
        # Upsample tiny crops to at least 28 px per side before smart_resize.
        w, h = cropped.size
        if w < 28 or h < 28:
            cropped = cropped.resize((max(w, 28), max(h, 28)), Image.Resampling.LANCZOS)
        return process_image(cropped, min_pixels, max_pixels, factor, qwen_vl_processing)

    if cropped.mode != "RGB":
        cropped = cropped.convert("RGB")
    return cropped
|
|
|
|
| |
| |
| |
|
|
def extract_mcq_answer(text: str) -> Optional[str]:
    """Extract MCQ answer (A, B, C, D) from model output.

    The rules below are tried in priority order; the first that applies decides.

    1. Only the last assistant turn, and only the text after the final
       ``</think>`` / ``</thinking>`` tag, is examined.
    2. If any ``<answer>...</answer>`` block exists, ONLY the last one is parsed
       and None is returned when it contains no usable letter (no fallback).
    3. Otherwise the FIRST LaTeX ``boxed{...}`` is examined; if it contains an
       A-D letter (case-insensitive) it is decisive, else parsing falls through.
    4. Free-text heuristics: ``Answer: (X)``, ``answer is (X)``, the last bold
       segment containing ``(X)``, any ``(X)``, ``X)`` / ``X]``, and finally a
       standalone capital letter. Each takes the last occurrence.

    Returns:
        The upper-cased option letter, or None when nothing matched.
    """
    if not text:
        return None

    # Keep only the last assistant turn and the text after the final think tag.
    text = text.rsplit("<|im_start|>assistant", 1)[-1]
    text = re.split(r'</think(?:ing)?>', text)[-1]

    # 1) <answer> tags are authoritative: parse the LAST block only.
    matches = list(re.finditer(r'<answer>(.*?)</answer>', text, re.DOTALL))
    if matches:
        candidate = matches[-1].group(1).strip()
        # Bare letter, e.g. "B" or "b".
        if re.match(r'^[A-Da-d]$', candidate):
            return candidate.upper()
        # Punctuated letter "(B)", "[B]", "B.", "B)" or "B]" (case-insensitive); last wins.
        punct = re.findall(r'(?:\(([A-D])\)|\[([A-D])\]|(?<![A-Za-z])([A-D])[.\)\]])', candidate, re.IGNORECASE)
        if punct:
            last = punct[-1]
            return (last[0] or last[1] or last[2]).upper()
        # Standalone capital letter not adjacent to other letters; last wins.
        standalone = re.findall(r'(?<![A-Za-z])([A-D])(?![A-Za-z])', candidate)
        if standalone:
            return standalone[-1].upper()
        return None

    # 2) LaTeX boxed answer: only the FIRST occurrence is examined.
    match = re.search(r'\\boxed\{([^}]+)\}', text, re.IGNORECASE)
    if match:
        boxed = match.group(1).strip()
        # Only decisive when the box contains some A-D letter; otherwise
        # (e.g. a number) fall through to the free-text heuristics below.
        if re.search(r'[A-Da-d]', boxed):
            # Bare letter.
            if re.match(r'^[A-Da-d]$', boxed):
                return boxed.upper()
            # Punctuated letter; last wins.
            punct = re.findall(r'(?:\(([A-D])\)|\[([A-D])\]|(?<![A-Za-z])([A-D])[.\)\]])', boxed, re.IGNORECASE)
            if punct:
                last = punct[-1]
                return (last[0] or last[1] or last[2]).upper()
            # Standalone capital letter; last wins.
            standalone = re.findall(r'(?<![A-Za-z])([A-D])(?![A-Za-z])', boxed)
            if standalone:
                return standalone[-1].upper()
            # Box has letters but no parseable option.
            return None

    # 3) "Answer: (X)" - last occurrence.
    answer_matches = re.findall(r'Answer:\s*\(([A-D])\)', text, re.IGNORECASE)
    if answer_matches:
        return answer_matches[-1].upper()

    # 4) "correct answer is (X)" / "answer is (X)" - last occurrence.
    phrase = re.findall(r'(?:correct answer is|answer is)[:\s]*\(([A-D])\)', text, re.IGNORECASE)
    if phrase:
        return phrase[-1].upper()

    # 5) Last **bold** segment that contains "(X)".
    bold_segments = re.findall(r'\*\*[^*]+\*\*', text)
    for seg in reversed(bold_segments):
        m = re.search(r'\(([A-D])\)', seg, re.IGNORECASE)
        if m:
            return m.group(1).upper()

    # 6) Any "(X)" - last occurrence.
    paren = re.findall(r'\(([A-D])\)', text, re.IGNORECASE)
    if paren:
        return paren[-1].upper()

    # 7) "X)" or "X]" not preceded by a letter - last occurrence.
    bracket = re.findall(r'(?<![A-Za-z])([A-D])[\)\]]', text, re.IGNORECASE)
    if bracket:
        return bracket[-1].upper()

    # 8) Standalone capital letter - last occurrence (case-sensitive).
    standalone = re.findall(r'(?<![A-Za-z])([A-D])(?![A-Za-z])', text)
    if standalone:
        return standalone[-1].upper()

    return None
|
|
|
|
| def check_answer(extracted: Optional[str], ground_truth: list | str) -> bool: |
| """Check MCQ answer against ground truth(s).""" |
| if extracted is None: |
| return False |
| |
| if isinstance(ground_truth, str): |
| ground_truth = [ground_truth] |
| extracted_upper = extracted.upper() |
| for gt in ground_truth: |
| if extracted_upper == gt.upper(): |
| return True |
| return False |
|
|
|
|
| |
| |
| |
|
|
def is_answer_acceptable(generated: str, gt_answer: str) -> bool:
    """Deterministic pre-check run before the LLM judge.

    Returns True only when both answers are non-empty and are equal after
    trimming surrounding whitespace and lowercasing.
    """
    if generated and gt_answer:
        return generated.strip().lower() == gt_answer.strip().lower()
    return False
|
|
|
|
| def _normalize_answer_for_em(s: str) -> str: |
| """Normalize answer for EM comparison (aligned with reference code).""" |
| def remove_articles(text): |
| if text.strip().lower() in ["a"]: |
| return text |
| return re.sub(r"\b(a|an|the)\b", " ", text) |
|
|
| def white_space_fix(text): |
| return " ".join(text.split()) |
|
|
| def remove_punc(text): |
| exclude = set(string.punctuation) |
| return "".join(ch for ch in text if ch not in exclude) |
|
|
| return white_space_fix(remove_articles(remove_punc(s.lower()))) |
|
|
|
|
def _em_check(prediction: str, golden_answer: str) -> bool:
    """Return True when both answers are non-empty and equal after EM normalization."""
    if prediction and golden_answer:
        return _normalize_answer_for_em(prediction) == _normalize_answer_for_em(golden_answer)
    return False
|
|
|
|
def extract_answer_allow_no_tag(text: str) -> str:
    """Extract the final answer from a model response, tolerating a missing tag.

    Only the last assistant turn is considered. If it contains an
    ``<answer>...</answer>`` block and no ``<tool_call>``, the last block's
    stripped content is returned. Otherwise the text after the final
    ``</think>`` / ``</thinking>`` tag, cut at ``<|im_end|>``, is returned.
    Empty input yields "".
    """
    if not text:
        return ""
    last_turn = text.rsplit("<|im_start|>assistant", 1)[-1]
    # A tagged answer only counts when the turn did not also issue a tool call.
    if "<tool_call>" not in last_turn:
        tagged = re.findall(r'<answer>(.*?)</answer>', last_turn, re.DOTALL)
        if tagged:
            return tagged[-1].strip()
    after_think = re.split(r'</think(?:ing)?>', last_turn)[-1]
    return after_think.split('<|im_end|>')[0].strip()
|
|
|
|
| |
# System prompt for the LLM judge. The judge must reply with
# `<judge>Yes/No</judge>`; `llm_judge_acceptable` extracts that tag with a regex.
LLM_JUDGE_SYSTEM_PROMPT = """You are an AI assistant tasked with evaluating the correctness of model responses based on a question and ground truth answer. Your judgment should follow these principles:

1. Consider the question and ground truth answer holistically before evaluating the model's response.
2. Your decision should be strictly Yes or No, based on whether the model's response is factually accurate and aligns with the ground truth answer.
3. If the model response is a more specific form of the ground truth answer, it is correct.
4. If the model response includes all key information but adds minor details, it is correct as long as the extra details are factually correct.
5. If the model response contradicts, modifies, or omits critical parts of the answer, it is incorrect.
6. For numerical values, ensure correctness even when presented in different units.
7. For names, check for first and last name correctness. If the middle name is extra but correct, consider it correct.
8. For yes/no questions, the response must exactly match "Yes" or "No" to be correct.
9. If the model response is a common abbreviation, nickname, or alternative name for the ground truth answer, it is correct.
10. If there are multiple candidate answers, evaluate the model's response against all of them. If the response aligns with at least one candidate according to the rules above, it should be considered correct.
11. For multiple choice questions (A, B, C, D), be more lenient. If the model provides the correct letter choice, even with additional text or formatting, consider it correct.
12. If the model's answer contains the correct choice letter (A, B, C, or D) anywhere in the response, and it's clear this is the intended answer, mark it as correct.
13. Ignore formatting issues like extra parentheses, brackets, or minor text variations as long as the core answer is correct.

Your output must be in the following format:
<judge>Yes/No</judge>
<reason>Explanation of why the answer is correct or incorrect.</reason>"""


# User-turn template for the judge; filled via str.format with the
# `question`, `ground_truth` and `model_response` fields.
LLM_JUDGE_USER_PROMPT = """Question and Model Response Evaluation
Question: {question}
Ground Truth Answer: {ground_truth}
Model Response: {model_response}

Evaluation Instructions
Evaluate whether the Model Response is correct based on the Question and Ground Truth Answer. Follow the predefined judgment rules and provide a clear Yes/No answer along with a justification.

Output Format
<judge>Yes/No</judge>
<reason>Detailed reasoning following the evaluation principles.</reason>"""
|
|
async def llm_judge_acceptable(
    generated: str,
    gt_answer: str,
    question: str,
    session: aiohttp.ClientSession,
    max_retries: int = 3,
) -> bool:
    """Judge whether ``generated`` correctly answers ``question`` using the MARS summarizer.

    Cheap deterministic checks run first: normalized exact match, then
    case-insensitive equality. Only if both fail is the OpenAI-compatible
    ``/v1/chat/completions`` endpoint at ``MARS_SUMMARIZER_ADDRESS`` queried.

    Args:
        generated: Extracted model answer.
        gt_answer: One ground-truth candidate.
        question: Original question text, shown to the judge.
        session: Open aiohttp session used for the request.
        max_retries: Attempts for HTTP errors, timeouts and exceptions.

    Returns:
        True if a deterministic check passes or the judge replies
        ``<judge>Yes</judge>``. Missing inputs, an unconfigured judge, an
        unparseable verdict, or exhausted retries all yield False.
    """
    if not generated or not gt_answer:
        return False

    if _em_check(generated, gt_answer) or is_answer_acceptable(generated, gt_answer):
        return True

    addr = MARS_SUMMARIZER_ADDRESS
    model = MARS_SUMMARIZER_MODEL
    if not addr or not model:
        print("[WARN] MARS_SUMMARIZER not configured, falling back to is_answer_acceptable()")
        return False

    user_prompt = LLM_JUDGE_USER_PROMPT.format(
        question=question,
        ground_truth=gt_answer,
        model_response=generated,
    )

    payload = {
        "model": model,
        "messages": [
            {"role": "system", "content": LLM_JUDGE_SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt},
        ],
        "max_tokens": 1024,
        "temperature": 0.0,
        "chat_template_kwargs": {"enable_thinking": False},
    }

    for attempt in range(max_retries):
        try:
            async with session.post(
                f"http://{addr}/v1/chat/completions",
                json=payload,
                headers={"Content-Type": "application/json"},
                timeout=aiohttp.ClientTimeout(total=120),
                # Empty proxy string: request is sent directly, not via env proxies.
                proxy="",
            ) as resp:
                if resp.status != 200:
                    text = await resp.text()
                    print(f"  [JUDGE] HTTP {resp.status}: {text[:200]}")
                    if attempt < max_retries - 1:
                        await asyncio.sleep(1)
                        continue
                    return False

                data = await resp.json()
                choices = data.get("choices") or []
                if choices:
                    message = choices[0].get("message") or {}
                    # `content` may be null or a list of content parts; normalize it
                    # to text so the regexes below cannot raise TypeError.
                    content = _message_content_to_text(message.get("content"))
                    content = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL).strip()
                    match = re.search(r"<judge>\s*(Yes|No)\s*</judge>", content, re.IGNORECASE | re.DOTALL)
                    if match:
                        return match.group(1).lower() == "yes"
                # An answered request without a parseable verdict counts as "No".
                return False
        except asyncio.TimeoutError:
            print(f"  [JUDGE] Timeout (attempt {attempt + 1}/{max_retries})")
            if attempt < max_retries - 1:
                await asyncio.sleep(1)
        except Exception as e:
            print(f"  [JUDGE] Error: {e} (attempt {attempt + 1}/{max_retries})")
            if attempt < max_retries - 1:
                await asyncio.sleep(1)

    return False
|
|
|
|
async def llm_judge_score(
    question: str,
    model_answer: str,
    ground_truth: list | str,
    image_path: str,
    judge_client: str,
    judge_base_url: str,
    judge_api_key: str,
    judge_temperature: float = 0.0,
    shared_session: aiohttp.ClientSession | None = None,
) -> float:
    """Use VideoDR reference LLM judge logic. Returns 1.0 or 0.0.

    The answer is extracted from ``model_answer`` with
    ``extract_answer_allow_no_tag``. It is checked against each ground-truth
    candidate in order, and the first accepted candidate yields 1.0. The image
    path and the ``judge_*`` arguments are accepted for interface compatibility
    but are not used. When ``shared_session`` is None, a temporary HTTP session
    with a 3600 s total timeout is opened for this call.
    """
    del image_path, judge_client, judge_base_url, judge_api_key, judge_temperature

    answer = extract_answer_allow_no_tag(model_answer)
    candidates = [ground_truth] if isinstance(ground_truth, str) else ground_truth

    async def first_accepted(session) -> float:
        # Stop at the first candidate the judge accepts.
        for candidate in candidates:
            if await llm_judge_acceptable(answer, candidate, question, session):
                return 1.0
        return 0.0

    if shared_session is not None:
        return await first_accepted(shared_session)

    session_timeout = aiohttp.ClientTimeout(total=3600)
    async with create_http_session(session_timeout) as temp_session:
        return await first_accepted(temp_session)
|
|
|
|
| |
| |
| |
|
|
| def _read_first_jsonl_record(path: str) -> dict: |
| """读取 JSONL 的第一条有效样本,用于推断 benchmark 类型。""" |
| with open(path) as f: |
| for line in f: |
| line = line.strip() |
| if line: |
| return json.loads(line) |
| return {} |
|
|
|
|
| def _count_jsonl_records(path: str) -> int: |
| """统计 JSONL 样本数,写入自动生成的数据集配置。""" |
| total = 0 |
| with open(path) as f: |
| for line in f: |
| if line.strip(): |
| total += 1 |
| return total |
|
|
|
|
| def _prompt_text_from_raw_record(raw: dict) -> str: |
| """从原始样本中取 user prompt 文本。""" |
| prompt = raw.get("prompt", []) |
| if isinstance(prompt, list): |
| for msg in prompt: |
| if isinstance(msg, dict) and msg.get("role") == "user": |
| return str(msg.get("content", "")) |
| return "" |
|
|
|
|
| def _ground_truths_from_raw_record(raw: dict) -> list[str]: |
| """从原始样本中取 ground_truth 列表。""" |
| reward_model = raw.get("reward_model", {}) |
| ground_truth = reward_model.get("ground_truth", []) |
| if isinstance(ground_truth, str): |
| return [ground_truth] |
| if isinstance(ground_truth, list): |
| return [str(item) for item in ground_truth] |
| return [] |
|
|
|
|
def _looks_like_mcq_record(raw: dict) -> bool:
    """Heuristically decide from one record whether a benchmark is A-D multiple choice.

    Requires a non-empty ground-truth list in which every entry is a single
    letter A-D (case-insensitive, surrounding whitespace ignored). The user
    prompt must also contain an option marker such as "(B)".
    """
    answers = _ground_truths_from_raw_record(raw)
    every_answer_is_letter = bool(answers) and all(
        re.fullmatch(r"[A-Da-d]", answer.strip()) for answer in answers
    )
    if not every_answer_is_letter:
        return False
    return re.search(r"\([A-Da-d]\)", _prompt_text_from_raw_record(raw)) is not None
|
|
|
|
def build_eval_root_dataset_config(
    eval_root: str,
    benchmarks: Optional[list[str]] = None,
) -> dict:
    """Build an ``eval.py`` dataset config from the benchmark sub-directories of ``data/eval``.

    Every benchmark directory must contain ``data.jsonl`` and becomes the
    dataset's ``root``. A benchmark is scored with ``em_score_mcq`` when it is
    listed in ``DEFAULT_MCQ_BENCHMARKS`` or its first record looks like an A-D
    question; otherwise it uses ``llm_score`` (as an unused reward fn).

    Args:
        eval_root: Root directory; ``DEFAULT_EVAL_ROOT`` is used when empty.
        benchmarks: Explicit benchmark names. When empty or omitted, every
            sub-directory containing ``data.jsonl`` is used, sorted by name.

    Raises:
        FileNotFoundError: if the root is missing or a benchmark lacks data.jsonl.
        ValueError: if no benchmark was found.
    """
    root = os.path.abspath(eval_root or DEFAULT_EVAL_ROOT)
    if not os.path.isdir(root):
        raise FileNotFoundError(f"eval root not found: {root}")

    def annotation_path(name: str) -> str:
        return os.path.join(root, name, "data.jsonl")

    if benchmarks:
        names = benchmarks
    else:
        names = sorted(
            name
            for name in os.listdir(root)
            if os.path.isdir(os.path.join(root, name)) and os.path.exists(annotation_path(name))
        )

    config = {}
    missing = []
    for name in names:
        benchmark_dir = os.path.join(root, name)
        annotation = os.path.join(benchmark_dir, "data.jsonl")
        if not os.path.exists(annotation):
            missing.append(name)
            continue

        # The first record is always read, even for names already known to be MCQ.
        first_record = _read_first_jsonl_record(annotation)
        is_mcq = name in DEFAULT_MCQ_BENCHMARKS or _looks_like_mcq_record(first_record)

        config[name] = {
            "root": benchmark_dir,
            "annotation": annotation,
            "length": _count_jsonl_records(annotation),
            "repeat_time": 1,
            "reward_fn": ["em_score_mcq"] if is_mcq else [],
            "unused_reward_fn": [] if is_mcq else ["llm_score"],
            "input_template": {
                "name": "general",
                "arguments": {
                    "system_prompt": DEFAULT_IMAGE_DATASET_SYSTEM_PROMPT,
                    "format_instruction": "",
                    "add_image_path": False,
                },
            },
            "comment": name,
        }

    if missing:
        raise FileNotFoundError(
            "benchmark data.jsonl not found under eval root: " + ", ".join(missing)
        )
    if not config:
        raise ValueError(f"no benchmark data.jsonl found under eval root: {root}")
    return config
|
|
|
|
def _load_image_jsonl_samples(annotation: str, root: str, dataset_name: str) -> list[dict]:
    """Parse an image-benchmark JSONL annotation file into evaluation sample dicts.

    Each non-blank line is one record. The question is the first ``user``
    message with ``<image>`` removed. The answer is ``reward_model.ground_truth``,
    kept as stored. The image path is the first ``image`` entry joined onto
    ``root``. Cached image-search results, when present, are truncated and
    attached as ``image_search_data``.
    """
    # Maximum number of cached image_search titles/thumbnails kept per sample.
    max_image_search_results = 5
    samples = []
    with open(annotation, encoding="utf-8") as f:
        for line in f:
            # Skip blank lines, matching `_count_jsonl_records`; json.loads would raise on them.
            if not line.strip():
                continue
            raw = json.loads(line)
            question = ""
            for msg in raw.get("prompt", []):
                if msg.get("role") == "user":
                    question = msg.get("content", "").replace("<image>", "").strip()
                    break

            reward_model = raw.get("reward_model", {})
            answer = reward_model.get("ground_truth", [""])
            images = raw.get("image", [])
            image_path = os.path.join(root, images[0]) if images else ""

            image_search_kwargs = {}
            if "image_search_title_list" in raw:
                titles = raw["image_search_title_list"]
                image_search_kwargs["image_search_title_list"] = titles[:max_image_search_results] if titles else None
            if "image_search_thumbnail_list" in raw:
                thumbs = raw["image_search_thumbnail_list"]
                image_search_kwargs["image_search_thumbnail_list"] = thumbs[:max_image_search_results] if thumbs else None

            sample = {
                # Sequential over records, so ids stay contiguous even when blank lines are skipped.
                "id": f"{dataset_name}-{len(samples)}",
                "question": question,
                "answer": answer,
                "image_path": image_path,
                "judge_image_path": image_path,
                "dataset": dataset_name,
                "data_root": root,
                "task_kind": "image",
            }
            if image_search_kwargs:
                sample["image_search_data"] = image_search_kwargs
            samples.append(sample)
    return samples


def load_datasets(
    config_path: str | dict,
    data_root: str = "",
    video_dr_system_prompt: str = VIDEO_DR_SYSTEM_PROMPT,
) -> tuple[list[dict], dict]:
    """Load datasets from a JSON config. Returns (samples, dataset_configs).

    Args:
        config_path: Path to a JSON config file, or an already-parsed config dict
            (e.g. the output of ``build_eval_root_dataset_config``).
        data_root: Optional prefix applied to relative ``root`` / ``annotation``
            / ``video_root`` paths.
        video_dr_system_prompt: System prompt used for VideoDR datasets whose
            config does not set one.

    Returns:
        A flat list of samples across all datasets, and per dataset a dict with
        ``score_methods``, ``system_prompt``, ``format_instruction`` and ``task_kind``.

    Raises:
        ValueError: for a VideoDR dataset without ``video_root``/``root``, or a
            dataset with no or unknown scoring methods.
    """
    if isinstance(config_path, dict):
        config = config_path
    else:
        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)

    all_samples = []
    dataset_configs = {}

    for dataset_name, ds_config in config.items():
        ds_type = ds_config.get("type", "")
        root = ds_config.get("root", "")
        annotation = ds_config.get("annotation", "")
        video_root = ds_config.get("video_root", "")

        if data_root and not os.path.isabs(root):
            root = os.path.join(data_root, root)
        if data_root and not os.path.isabs(annotation):
            annotation = os.path.join(data_root, annotation)
        if data_root and video_root and not os.path.isabs(video_root):
            video_root = os.path.join(data_root, video_root)

        if not os.path.exists(annotation):
            print(f"Warning: {annotation} not found, skipping {dataset_name}")
            continue

        # A CSV annotation is treated as VideoDR even without an explicit type;
        # use the same decision for loading, the default prompt and task_kind.
        is_video = ds_type == "video_dr_csv" or annotation.lower().endswith(".csv")

        if is_video:
            if not video_root:
                video_root = ds_config.get("root", "")
                if data_root and video_root and not os.path.isabs(video_root):
                    video_root = os.path.join(data_root, video_root)
            if not video_root:
                raise ValueError(
                    f"Dataset '{dataset_name}' is video_dr_csv but missing video_root/root"
                )
            samples = load_videodr_csv_samples(annotation, video_root, dataset_name)
        else:
            samples = _load_image_jsonl_samples(annotation, root, dataset_name)

        all_samples.extend(samples)

        reward_fn = ds_config.get("reward_fn", [])
        unused_reward_fn = ds_config.get("unused_reward_fn", [])
        # Order-preserving de-duplication (reward_fn first) keeps runs deterministic;
        # list(set(...)) order varies with string hash randomization.
        score_methods = list(dict.fromkeys(reward_fn + unused_reward_fn))

        if not score_methods:
            raise ValueError(f"Dataset '{dataset_name}' has no scoring methods (reward_fn or unused_reward_fn)")

        valid_methods = {"em_score_mcq", "llm_score"}
        invalid = set(score_methods) - valid_methods
        if invalid:
            raise ValueError(f"Dataset '{dataset_name}' has invalid scoring methods: {invalid}")

        input_template = ds_config.get("input_template", {})
        template_args = input_template.get("arguments", {})
        system_prompt = template_args.get("system_prompt", "")
        format_instruction = template_args.get("format_instruction", "")
        if is_video and not system_prompt:
            system_prompt = video_dr_system_prompt

        dataset_configs[dataset_name] = {
            "score_methods": score_methods,
            "system_prompt": system_prompt,
            "format_instruction": format_instruction,
            "task_kind": "video_dr" if is_video else "image",
        }

        print(f"  Loaded {len(samples)} samples from {dataset_name}")
        print(f"    reward_fn: {reward_fn}")
        print(f"    unused_reward_fn: {unused_reward_fn}")

    return all_samples, dataset_configs
|
|
|
|
| |
| |
| |
|
|
| def _message_content_to_text(content) -> str: |
| """把 OpenAI 兼容返回里的 content/reasoning_content 统一转成纯文本。""" |
| if isinstance(content, str): |
| return content |
| if isinstance(content, list): |
| text_parts = [] |
| for item in content: |
| if isinstance(item, str): |
| text_parts.append(item) |
| elif isinstance(item, dict) and item.get("type") == "text": |
| text_parts.append(item.get("text", "")) |
| return "".join(text_parts) |
| return "" |
|
|
|
|
| def _normalize_openai_tool_calls(raw_tool_calls) -> list[dict]: |
| """把 OpenAI/sglang 风格的 tool_calls 规范化成统一结构。""" |
| normalized = [] |
| for tool_call in raw_tool_calls or []: |
| if isinstance(tool_call, dict): |
| function = tool_call.get("function") or {} |
| name = function.get("name", "") |
| raw_args = function.get("arguments") |
| else: |
| function = getattr(tool_call, "function", None) |
| name = getattr(function, "name", "") if function is not None else "" |
| raw_args = getattr(function, "arguments", None) if function is not None else None |
|
|
| if isinstance(raw_args, str): |
| try: |
| args = json.loads(raw_args) |
| except Exception: |
| args = {} |
| elif isinstance(raw_args, dict): |
| args = raw_args |
| else: |
| args = {} |
|
|
| normalized.append({ |
| "name": name, |
| "arguments": args, |
| }) |
| return normalized |
|
|
|
|
def reconstruct_assistant_text(
    content: str = "",
    reasoning_content: str = "",
    tool_calls: Optional[list[dict]] = None,
) -> str:
    """Rebuild the full assistant text from separated reasoning / tool calls / content.

    All inputs are stripped (``None`` treated as empty). Without reasoning and
    tool calls the stripped content is returned as-is. Otherwise the pieces
    are newline-joined in this order: ``<think>...</think>`` (if reasoning),
    one ``<tool_call>{json}</tool_call>`` per call (``name`` defaults to "",
    ``arguments`` to ``{}``; non-ASCII kept verbatim), then the content.
    """
    answer = (content or "").strip()
    thinking = (reasoning_content or "").strip()
    calls = tool_calls or []

    if not (thinking or calls):
        return answer

    segments = [f"<think>\n{thinking}\n</think>"] if thinking else []
    for call in calls:
        payload = {
            "name": call.get("name", ""),
            "arguments": call.get("arguments", {}),
        }
        serialized = json.dumps(payload, ensure_ascii=False)
        segments.append("<tool_call>\n" + serialized + "\n</tool_call>")
    if answer:
        segments.append(answer)
    return "\n".join(segments).strip()
|
|
|
|
async def call_openai_api(
    messages: list[dict],
    model: str,
    base_url: str,
    api_key: str = "",
    **kwargs,
) -> dict:
    """Call an OpenAI-compatible chat-completions endpoint once (no retries).

    Args:
        messages: OpenAI-style chat messages, sent unchanged.
        model: Model name, normalized via ``normalize_model_name_for_client(model, "openai")``.
        base_url: Server base URL; the endpoint comes from ``_openai_chat_completions_url``.
        api_key: Optional bearer token; no ``Authorization`` header when empty.
        **kwargs: ``max_tokens`` (default 4096); ``temperature``, ``top_p``,
            ``top_k``, ``presence_penalty``, ``repetition_penalty`` forwarded
            when present; ``seed`` forwarded when not ``None``; ``extra_body``
            merged into the request body last (so it can override fields);
            ``request_debug_context`` (dict with ``sample_id``/``turn``/
            ``task_kind``) used only to enrich error logs.

    Returns:
        A dict with ``content``, ``reasoning_content``, ``tool_calls``,
        ``full_text``, ``finish_reason``, ``error`` and ``request_debug``.
        On failure ``finish_reason == "error"``, ``error`` holds a message and
        ``error_status_code`` / ``error_raw_response`` are filled in. Errors
        are logged to stdout and returned, never raised.
    """
    request_debug = _summarize_messages_for_debug(messages)
    debug_context = kwargs.get("request_debug_context") or {}

    def _error_result(error, status_code, raw_response) -> dict:
        # Shape shared by every failure path of this function.
        return {
            "content": "",
            "reasoning_content": "",
            "tool_calls": [],
            "full_text": "",
            "finish_reason": "error",
            "error": error,
            "error_status_code": status_code,
            "error_raw_response": raw_response,
            "request_debug": request_debug,
        }

    body = {
        "model": normalize_model_name_for_client(model, "openai"),
        "messages": messages,
        "max_tokens": kwargs.get("max_tokens", 4096),
    }
    for key in ("temperature", "top_p", "top_k", "presence_penalty", "repetition_penalty"):
        if key in kwargs:
            body[key] = kwargs[key]
    if kwargs.get("seed") is not None:
        body["seed"] = kwargs["seed"]
    # extra_body is merged last so callers can override any field set above.
    if kwargs.get("extra_body"):
        body.update(kwargs["extra_body"])

    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    url = _openai_chat_completions_url(base_url)

    try:
        timeout = aiohttp.ClientTimeout(total=300)
        async with create_http_session(timeout) as session:
            async with session.post(url, headers=headers, json=body) as resp:
                status = resp.status
                raw_text = await resp.text()

        try:
            data = json.loads(raw_text)
        except json.JSONDecodeError:
            data = {}

        if status != 200:
            error = _extract_http_error_message(status, data, raw_text)
            raw_error_text = _truncate_debug_text(raw_text, limit=4000)
            print(
                "[MODEL HTTP ERROR]",
                json.dumps(
                    {
                        "status": status,
                        "error": error,
                        "sample_id": debug_context.get("sample_id"),
                        "turn": debug_context.get("turn"),
                        "task_kind": debug_context.get("task_kind"),
                        "request_debug": request_debug,
                        "raw_response": raw_error_text,
                    },
                    ensure_ascii=False,
                ),
                flush=True,
            )
            return _error_result(error, status, raw_error_text)

        # A 200 with an unusable body (non-JSON, no/empty ``choices``) used to
        # surface as an opaque KeyError/IndexError; report it explicitly.
        choices = data.get("choices") if isinstance(data, dict) else None
        if not (isinstance(choices, list) and choices and isinstance(choices[0], dict)):
            error = "Malformed chat completion response: missing or empty 'choices'"
            raw_error_text = _truncate_debug_text(raw_text, limit=4000)
            print(
                "[MODEL RESPONSE ERROR]",
                json.dumps(
                    {
                        "status": status,
                        "error": error,
                        "sample_id": debug_context.get("sample_id"),
                        "turn": debug_context.get("turn"),
                        "task_kind": debug_context.get("task_kind"),
                        "request_debug": request_debug,
                        "raw_response": raw_error_text,
                    },
                    ensure_ascii=False,
                ),
                flush=True,
            )
            return _error_result(error, status, raw_error_text)

        choice = choices[0]
        message = choice.get("message", {}) or {}
        content = _message_content_to_text(message.get("content"))
        reasoning_content = _message_content_to_text(message.get("reasoning_content"))
        tool_calls = _normalize_openai_tool_calls(message.get("tool_calls"))
        full_text = reconstruct_assistant_text(
            content=content,
            reasoning_content=reasoning_content,
            tool_calls=tool_calls,
        )
        return {
            "content": content,
            "reasoning_content": reasoning_content,
            "tool_calls": tool_calls,
            "full_text": full_text,
            "finish_reason": choice.get("finish_reason", "stop"),
            "error": None,
            "request_debug": request_debug,
        }
    except Exception as e:
        error = _truncate_debug_text(str(e), limit=1000)
        print(
            "[MODEL REQUEST EXCEPTION]",
            json.dumps(
                {
                    "error": error,
                    "sample_id": debug_context.get("sample_id"),
                    "turn": debug_context.get("turn"),
                    "task_kind": debug_context.get("task_kind"),
                    "request_debug": request_debug,
                },
                ensure_ascii=False,
            ),
            flush=True,
        )
        return _error_result(error, None, "")
|
|
|
|
| def _parse_gateway_json(value): |
| """Parse JSON strings returned by the gateway; leave other values unchanged.""" |
| if isinstance(value, str): |
| stripped = value.strip() |
| if not stripped: |
| return value |
| try: |
| return json.loads(stripped) |
| except Exception: |
| return value |
| return value |
|
|
|
|
def _extract_gateway_response_payload(data):
    """Unwrap the model payload from the gateway's (possibly nested) envelope.

    ``data`` is JSON-decoded first if it is a JSON string. For a dict, the first
    present key among ``model_output``, ``response``, ``result`` and ``data``
    is followed (its value JSON-decoded too): dict values are unwrapped
    recursively, any other value is returned as-is. A dict with none of these
    keys, or a non-dict input, is returned unchanged.
    """
    envelope = _parse_gateway_json(data)
    if not isinstance(envelope, dict):
        return envelope

    wrapper_key = next(
        (key for key in ("model_output", "response", "result", "data") if key in envelope),
        None,
    )
    if wrapper_key is None:
        return envelope

    inner = _parse_gateway_json(envelope.get(wrapper_key))
    if isinstance(inner, dict) and inner is not envelope:
        return _extract_gateway_response_payload(inner)
    return inner
|
|
|
|
def _gateway_text_value(value) -> str:
    """Pull human-readable text out of a gateway / OpenAI content value.

    ``value`` is first JSON-decoded when it is a JSON string. Then:
    - strings are returned as-is;
    - lists are flattened recursively and their non-empty pieces joined with
      newlines (result stripped);
    - dicts are searched in order under ``output_text``, ``text``, ``value``
      and ``summary`` (non-empty strings taken directly, list/dict values
      searched recursively), then under ``content``;
    - anything else, or nothing found, yields "".
    """
    decoded = _parse_gateway_json(value)
    if isinstance(decoded, str):
        return decoded
    if isinstance(decoded, list):
        pieces = (_gateway_text_value(entry) for entry in decoded)
        return "\n".join(piece for piece in pieces if piece).strip()
    if not isinstance(decoded, dict):
        return ""

    for field in ("output_text", "text", "value", "summary"):
        candidate = decoded.get(field)
        if isinstance(candidate, (list, dict)):
            found = _gateway_text_value(candidate)
            if found:
                return found
        elif isinstance(candidate, str) and candidate:
            return candidate

    fallback = decoded.get("content")
    if fallback is None:
        return ""
    return _gateway_text_value(fallback) or ""
|
|
|
|
def _extract_gateway_text(payload) -> str:
    """Extract assistant text from a Responses-API, chat-completions or plain gateway payload.

    Tried in order on the (JSON-decoded) payload:
    1. a plain string payload is returned as-is; other non-dicts give "";
    2. a non-empty ``output_text`` value;
    3. ``choices[0].message``, rebuilt with ``reconstruct_assistant_text``;
    4. an ``output`` list: text of ``message`` items or items with list
       ``content``, else an item's string ``text``, newline-joined;
    5. a top-level ``content`` value;
    otherwise "".
    """
    decoded = _parse_gateway_json(payload)
    if isinstance(decoded, str):
        return decoded
    if not isinstance(decoded, dict):
        return ""

    direct = _gateway_text_value(decoded.get("output_text"))
    if direct:
        return direct

    choices = decoded.get("choices")
    if isinstance(choices, list) and choices:
        first_message = choices[0].get("message", {}) or {}
        return reconstruct_assistant_text(
            _message_content_to_text(first_message.get("content")),
            _message_content_to_text(first_message.get("reasoning_content")),
            _normalize_openai_tool_calls(first_message.get("tool_calls")),
        )

    items = decoded.get("output")
    if isinstance(items, list):
        collected = []
        for item in items:
            if not isinstance(item, dict):
                continue
            # "message" items and any item carrying a content list are read the same way.
            if item.get("type") == "message" or isinstance(item.get("content"), list):
                text = _gateway_text_value(item.get("content", []))
                if text:
                    collected.append(text)
            elif isinstance(item.get("text"), str):
                collected.append(item["text"])
        return "\n".join(collected).strip()

    raw_content = decoded.get("content")
    if raw_content is None:
        return ""
    return _message_content_to_text(raw_content) or _gateway_text_value(raw_content)
|
|
|
|
def _gateway_error_summary(payload) -> str:
    """Summarize a gateway/OpenAI response payload in at most 500 characters.

    For a dict, reports the truthy ``id``/``status``/``error``/
    ``incomplete_details`` fields (each value cut to 200 chars) plus the count
    and first ten types of ``output`` items, joined with "; ". Falls back to
    ``str(payload)`` when nothing qualifies or the payload is not a dict.
    """
    decoded = _parse_gateway_json(payload)
    if not isinstance(decoded, dict):
        return str(decoded)[:500]

    fields = [
        f"{key}={str(decoded.get(key))[:200]}"
        for key in ("id", "status", "error", "incomplete_details")
        if decoded.get(key)
    ]
    output_items = decoded.get("output")
    if isinstance(output_items, list):
        fields.append(f"output_items={len(output_items)}")
        kinds = [str(entry.get("type")) for entry in output_items if isinstance(entry, dict)]
        if kinds:
            fields.append("output_types=" + ",".join(kinds[:10]))

    summary = "; ".join(fields) if fields else str(decoded)
    return summary[:500]
|
|
|
|
def _gateway_uses_completion_token_limit(model: str) -> bool:
    """Tell whether the gateway expects ``max_completion_tokens`` instead of ``max_tokens``.

    True for models whose normalized gateway name starts with ``gpt-5``
    (case-insensitive).
    """
    gateway_name = normalize_model_name_for_client(model, "gateway")
    return gateway_name.lower().startswith("gpt-5")
|
|
|
|
def _gateway_uses_responses_api(model: str) -> bool:
    """Report whether requests for *model* must be sent as OpenAI Responses API params.

    Currently this is exactly the set of models matched by ``_is_gpt54_gateway_model``.
    """
    needs_responses_api = _is_gpt54_gateway_model(model)
    return needs_responses_api
|
|
|
|
| def _gateway_responses_content(content): |
| """Convert OpenAI chat-style content parts to Responses API content parts.""" |
| if isinstance(content, str): |
| return content |
| if not isinstance(content, list): |
| return str(content or "") |
|
|
| converted = [] |
| for item in content: |
| if isinstance(item, str): |
| if item: |
| converted.append({"type": "input_text", "text": item}) |
| continue |
| if not isinstance(item, dict): |
| continue |
| item_type = item.get("type") |
| if item_type == "text": |
| text = item.get("text", "") |
| if text: |
| converted.append({"type": "input_text", "text": text}) |
| elif item_type == "image_url": |
| image_url = item.get("image_url", {}) |
| url = image_url.get("url", "") if isinstance(image_url, dict) else str(image_url or "") |
| if url: |
| converted.append({"type": "input_image", "image_url": url}) |
| return converted or "" |
|
|
|
|
def _gateway_responses_input(messages: list[dict]) -> list[dict]:
    """Convert OpenAI chat-style messages into Responses API ``input`` messages.

    Each message keeps its role (default ``"user"``) and has its content
    converted by ``_gateway_responses_content``. For ``assistant`` messages
    whose content becomes a part list, ``input_text`` parts are re-typed as
    ``output_text``: the Responses API accepts only output-side part types
    (``output_text`` / ``refusal``) for assistant content. String content is
    left untouched.
    """
    converted = []
    for message in messages:
        role = message.get("role", "user")
        content = _gateway_responses_content(message.get("content", ""))
        if role == "assistant" and isinstance(content, list):
            content = [
                {**part, "type": "output_text"} if part.get("type") == "input_text" else part
                for part in content
            ]
        converted.append({"role": role, "content": content})
    return converted
|
|
|
|
def _is_retryable_gateway_result(result: dict) -> bool:
    """Decide whether a failed gateway call looks transient and worth retrying.

    Quota-exhaustion failures (``_is_fatal_gateway_quota_result``) are never
    retryable. Otherwise the call is retryable when ``error_status_code`` is
    408/409/425/429 or one of the listed 5xx codes (including 520-524), or when
    the lower-cased ``error`` + ``error_raw_response`` text contains one of the
    connection-level / gateway-specific transient markers below.
    """
    if _is_fatal_gateway_quota_result(result):
        return False

    transient_statuses = {408, 409, 425, 429, 500, 502, 503, 504, 520, 521, 522, 523, 524}
    if result.get("error_status_code") in transient_statuses:
        return True

    error_text = str(result.get("error") or "")
    raw_text = str(result.get("error_raw_response") or "")
    haystack = f"{error_text} {raw_text}".lower()
    transient_markers = (
        "connection reset",
        "connection aborted",
        "server disconnected",
        "connection closed",
        "read tcp",
        "broken pipe",
        "timeout",
        "timed out",
        "temporarily unavailable",
        "empty gateway response",
        "error_code': '50001'",
        '"error_code": "50001"',
    )
    for marker in transient_markers:
        if marker in haystack:
            return True
    return False
|
|
|
|
| def _is_fatal_gateway_quota_result(result: dict) -> bool: |
| """Return whether a gateway failure means the whole eval should stop.""" |
| text = " ".join( |
| str(result.get(key) or "") |
| for key in ("error", "error_raw_response") |
| ).lower() |
| fatal_markers = ( |
| "application quota has been exhausted", |
| "quota has been exhausted", |
| "quota exhausted", |
| "insufficient quota", |
| ) |
| return any(marker in text for marker in fatal_markers) |
|
|
|
|
async def _maybe_retry_gateway_result(
    result: dict,
    *,
    attempt: int,
    max_attempts: int,
    model: str,
    request_debug: dict,
    initial_delay: float,
) -> bool:
    """Record the attempt on the result and decide whether the caller should retry.

    The attempt counters are written into ``result["request_debug"]`` (which
    is set to *request_debug* if absent).

    Raises:
        FatalAPIError: the failure indicates quota exhaustion.

    Returns:
        False when attempts are used up or the failure is not retryable.
        Otherwise logs a ``[GATEWAY RETRY]`` line, sleeps
        ``initial_delay * 2 ** (attempt - 1)`` seconds (capped at 20 s) and
        returns True.
    """
    debug = result.setdefault("request_debug", request_debug)
    debug["gateway_attempt"] = attempt
    debug["gateway_max_attempts"] = max_attempts

    if _is_fatal_gateway_quota_result(result):
        reason = _truncate_debug_text(str(result.get("error") or ""), limit=500)
        raise FatalAPIError(f"[GATEWAY ERROR] {model} quota exhausted: {reason}")

    should_retry = attempt < max_attempts and _is_retryable_gateway_result(result)
    if not should_retry:
        return False

    backoff = min(initial_delay * (2 ** (attempt - 1)), 20.0)
    log_entry = {
        "model": model,
        "attempt": attempt,
        "max_attempts": max_attempts,
        "sleep_seconds": round(backoff, 2),
        "status": result.get("error_status_code"),
        "error": _truncate_debug_text(str(result.get("error") or ""), limit=500),
        "sample_id": request_debug.get("sample_id"),
        "turn": request_debug.get("turn"),
        "task_kind": request_debug.get("task_kind"),
    }
    print("[GATEWAY RETRY]", json.dumps(log_entry, ensure_ascii=False), flush=True)
    await asyncio.sleep(backoff)
    return True
|
|
|
|
def _gateway_failure_result(
    error: str,
    status_code: Optional[int],
    raw_response: str,
    request_debug: dict,
) -> dict:
    """Return the failed-call result dict shared by every ``call_gateway_api`` error path."""
    return {
        "content": "",
        "reasoning_content": "",
        "tool_calls": [],
        "full_text": "",
        "finish_reason": "error",
        "error": error,
        "error_status_code": status_code,
        "error_raw_response": raw_response,
        "request_debug": request_debug,
    }


async def _send_gateway_request_once(
    url: str,
    headers: dict,
    payload: dict,
    timeout,
    request_debug: dict,
) -> dict:
    """POST one gateway request and classify the reply, without retrying.

    Returns a success result (``finish_reason == "stop"``) or a failure result
    from ``_gateway_failure_result``. Failures are detected, in order, as:
    non-200 HTTP status, non-JSON body, non-zero wrapper ``code``, a payload
    ``error``, a payload ``status`` other than completed/succeeded/success,
    and empty extracted text. Any ``Exception`` raised while sending or
    parsing becomes a failure result with ``error_status_code=None``.
    """
    try:
        async with create_http_session(timeout) as session:
            async with session.post(url, headers=headers, json=payload) as resp:
                status = resp.status
                text = await resp.text()
        raw_text = _truncate_debug_text(text, limit=4000)

        def _fail(error: str) -> dict:
            return _gateway_failure_result(error, status, raw_text, request_debug)

        if status != 200:
            return _fail(f"HTTP {status}: {raw_text[:500]}")

        try:
            data = json.loads(text)
        except Exception:
            return _fail(f"Invalid JSON response: {raw_text[:500]}")

        # The gateway wrapper reports its own failures through a non-zero "code".
        code = str(data.get("code", "0"))
        if code not in ("0", ""):
            return _fail(str(data.get("message", data))[:500])

        response_payload = _extract_gateway_response_payload(data)
        if isinstance(response_payload, dict):
            if response_payload.get("error"):
                return _fail(f"Gateway response error: {_gateway_error_summary(response_payload)}")
            payload_status = response_payload.get("status")
            if payload_status and payload_status not in ("completed", "succeeded", "success"):
                return _fail(
                    f"Gateway response not completed: {_gateway_error_summary(response_payload)}"
                )

        content = _extract_gateway_text(response_payload)
        if not content.strip():
            return _fail(f"Empty gateway response: {_gateway_error_summary(response_payload)}")
    except Exception as e:
        error = f"{type(e).__name__}: {str(e)}".strip()
        return _gateway_failure_result(
            _truncate_debug_text(error, limit=1000), None, "", request_debug
        )

    return {
        "content": content,
        "reasoning_content": "",
        "tool_calls": [],
        "full_text": content,
        "finish_reason": "stop",
        "error": None,
        "request_debug": request_debug,
    }


async def call_gateway_api(
    messages: list[dict],
    model: str,
    base_url: str,
    api_key: str = "",
    **kwargs,
) -> dict:
    """Call the company model gateway (GPT / Gemini gateway models) with retries.

    Args:
        messages: OpenAI chat-style messages. For Responses-API models
            (``_gateway_uses_responses_api``) they are converted with
            ``_gateway_responses_input``; otherwise sent as chat ``messages``.
        model: Model name; selects credentials, ``model_type`` and token-limit key.
        base_url: Gateway endpoint URL (posted to directly).
        api_key: Passed to ``_get_model_gateway_credentials``.
        **kwargs: ``temperature`` (default 0.7), ``max_tokens`` (default 4096),
            ``gateway_max_retries`` (default env ``MODEL_GATEWAY_MAX_RETRIES``
            or 10), ``model_request_timeout`` (seconds, default 300),
            ``request_debug_context`` (merged into ``request_debug``).

    Retry backoff starts at env ``MODEL_GATEWAY_RETRY_INITIAL_DELAY`` /
    ``MODEL_GATEWAY_INITIAL_RETRY_DELAY`` (default 1.0 s, minimum 0.1 s).

    Returns:
        The success result, or the last failure result when retries are
        exhausted or the failure is not retryable.

    Raises:
        FatalAPIError: the gateway reports exhausted quota.
    """
    request_debug = _summarize_messages_for_debug(messages)
    debug_context = kwargs.get("request_debug_context") or {}
    request_debug.update(debug_context)
    gateway_username, gateway_userid, gateway_token = _get_model_gateway_credentials(model, api_key)
    if not gateway_token:
        token_hint = (
            "MODEL_GATEWAY_GPT54_TOKEN or GPT54_GATEWAY_TOKEN is required"
            if _is_gpt54_gateway_model(model)
            else "MODEL_GATEWAY_GEMINI_TOKEN or GEMINI_GATEWAY_TOKEN is required"
            if _is_gemini_gateway_model(model)
            else "MODEL_GATEWAY_TOKEN or GATEWAY_TOKEN is required"
        )
        return {
            "content": "",
            "reasoning_content": "",
            "tool_calls": [],
            "full_text": "",
            "finish_reason": "error",
            "error": token_hint,
            "request_debug": request_debug,
        }

    if _gateway_uses_responses_api(model):
        params = {
            "input": _gateway_responses_input(messages),
            "temperature": kwargs.get("temperature", 0.7),
            "max_output_tokens": kwargs.get("max_tokens", 4096),
        }
    else:
        params = {
            "messages": messages,
            "temperature": kwargs.get("temperature", 0.7),
        }
        token_limit_key = (
            "max_completion_tokens"
            if _gateway_uses_completion_token_limit(model)
            else "max_tokens"
        )
        params[token_limit_key] = kwargs.get("max_tokens", 4096)

    # The gateway expects the model params as a JSON string inside the envelope.
    payload = {
        "sec_info": {
            "username": gateway_username,
            "userid": gateway_userid,
            "token": gateway_token,
        },
        "model_type": _gateway_model_type(model),
        "model_name": normalize_model_name_for_client(model, "gateway"),
        "params": json.dumps(params, ensure_ascii=False),
    }
    headers = {
        "Content-Type": "application/json",
        "User-Agent": "ifbook-http-client",
    }

    raw_max_retries = kwargs.get(
        "gateway_max_retries",
        os.environ.get("MODEL_GATEWAY_MAX_RETRIES", "10"),
    )
    try:
        max_attempts = max(1, int(raw_max_retries))
    except (TypeError, ValueError):
        max_attempts = 10
    try:
        initial_delay = max(
            0.1,
            float(
                os.environ.get("MODEL_GATEWAY_RETRY_INITIAL_DELAY")
                or os.environ.get("MODEL_GATEWAY_INITIAL_RETRY_DELAY", "1.0")
            ),
        )
    except ValueError:
        initial_delay = 1.0

    timeout = aiohttp.ClientTimeout(total=kwargs.get("model_request_timeout", 300))
    last_result = None
    for attempt in range(1, max_attempts + 1):
        result = await _send_gateway_request_once(base_url, headers, payload, timeout, request_debug)
        if result["finish_reason"] != "error":
            request_debug["gateway_attempt"] = attempt
            request_debug["gateway_max_attempts"] = max_attempts
            return result
        last_result = result
        # Deliberately outside any try/except: a FatalAPIError (quota exhausted)
        # must propagate as-is instead of being re-wrapped as a request exception.
        should_retry = await _maybe_retry_gateway_result(
            result,
            attempt=attempt,
            max_attempts=max_attempts,
            model=model,
            request_debug=request_debug,
            initial_delay=initial_delay,
        )
        if not should_retry:
            return result

    return last_result or {
        "content": "",
        "reasoning_content": "",
        "tool_calls": [],
        "full_text": "",
        "finish_reason": "error",
        "error": "Gateway request failed after retries",
        "error_status_code": None,
        "error_raw_response": "",
        "request_debug": request_debug,
    }
|
|
|
|
async def call_gemini_api(
    messages: list[dict],
    model: str,
    base_url: str,
    api_key: str,
    **kwargs,
) -> dict:
    """Call the Gemini ``models/{model}:generateContent`` REST endpoint with an API key.

    Messages are converted by ``_convert_messages_to_vertex_native``:
    - ``system`` messages become ``systemInstruction`` (previously they were
      sent as ``model`` turns);
    - ``assistant`` turns become ``model`` and every other role becomes ``user``;
    - ``data:`` image URLs become ``inlineData`` parts;
    - blank text parts and empty messages are dropped.

    Args:
        kwargs: ``temperature`` (default 0.7) and ``max_tokens`` (default 4096).

    Returns:
        A dict with ``content``, ``reasoning_content``, ``tool_calls``,
        ``full_text``, ``finish_reason``, ``error`` and ``error_status_code``.
        Only text parts of the first candidate are concatenated. Errors are
        returned (``finish_reason == "error"``), never raised.
    """

    def _error_result(error: str, status_code=None) -> dict:
        return {
            "content": "",
            "reasoning_content": "",
            "tool_calls": [],
            "full_text": "",
            "finish_reason": "error",
            "error": error,
            "error_status_code": status_code,
        }

    system_instruction, contents = _convert_messages_to_vertex_native(messages)
    body = {
        "contents": contents,
        "generationConfig": {
            "temperature": kwargs.get("temperature", 0.7),
            "maxOutputTokens": kwargs.get("max_tokens", 4096),
        },
    }
    if system_instruction:
        body["systemInstruction"] = system_instruction

    url = f"{base_url.rstrip('/')}/models/{model}:generateContent"
    headers = {"x-goog-api-key": api_key, "Content-Type": "application/json"}

    try:
        timeout = aiohttp.ClientTimeout(total=300)
        async with create_http_session(timeout) as session:
            async with session.post(url, headers=headers, json=body) as resp:
                status = resp.status
                raw_text = await resp.text()

        # Error pages (e.g. proxy HTML) are not JSON; don't let that mask the HTTP status.
        try:
            data = json.loads(raw_text)
        except json.JSONDecodeError:
            data = {}
        if not isinstance(data, dict):
            data = {}

        if status != 200:
            error_info = data.get("error")
            message = error_info.get("message") if isinstance(error_info, dict) else None
            return _error_result(message or f"HTTP {status}", status)

        content = ""
        candidates = data.get("candidates")
        if candidates:
            parts = (candidates[0].get("content") or {}).get("parts") or []
            content = "".join(p.get("text") or "" for p in parts)

        return {
            "content": content,
            "reasoning_content": "",
            "tool_calls": [],
            "full_text": content,
            "finish_reason": "stop",
            "error": None,
        }
    except Exception as e:
        return _error_result(str(e))
|
|
|
|
class VertexAccountPool:
    """Pool of Vertex AI service accounts handed out round-robin, one per request.

    Each enabled entry of the pool JSON binds a GCP ``project_id`` (plus an
    optional ``location``) to a service-account key, given either as a key
    file path or as inline key info. At construction every account's
    credentials are loaded and refreshed once; accounts whose key file is
    missing, cannot be loaded, or fail that refresh are skipped. If none
    remain, ``ValueError`` is raised.

    ``mark_cooldown(idx)`` takes an account out of rotation for
    ``cooldown_seconds`` (0 disables cooldown).
    """

    def __init__(
        self,
        pool_file: str,
        default_location: str = "global",
        cooldown_seconds: float = 60.0,
    ):
        self.pool_file = pool_file
        self.default_location = default_location or "global"
        # Negative / falsy values clamp to 0, which makes mark_cooldown a no-op.
        self.cooldown_seconds = max(0.0, float(cooldown_seconds or 0.0))
        self.accounts = self._load_accounts(pool_file)
        if not self.accounts:
            raise ValueError(f"Vertex account pool is empty or missing: {pool_file}")
        self._lock = asyncio.Lock()
        self._next_index = 0
        # account index -> time.monotonic() deadline until which it is unavailable
        self._cooldown_until = {}
        self._auth_request = None
        self._init_credentials()

    def _load_accounts(self, pool_file: str) -> list[dict]:
        """Parse the pool JSON into normalized account dicts.

        Returns ``[]`` when ``pool_file`` is empty or does not exist. A relative
        ``pool_file`` is resolved against the current working directory; a
        relative key path against the pool file's directory. The file must be
        a JSON list or ``{"accounts": [...]}``; entries with ``"enabled": false``
        are skipped.

        Raises:
            ValueError: malformed file or entry (non-object entry, missing
                ``project_id``, or neither a key path nor inline key info).
        """
        if not pool_file:
            return []
        path = Path(pool_file).expanduser()
        if not path.is_absolute():
            path = Path.cwd() / path
        if not path.exists():
            return []

        with path.open("r", encoding="utf-8") as f:
            data = json.load(f)

        if isinstance(data, list):
            records = data
        elif isinstance(data, dict) and isinstance(data.get("accounts"), list):
            records = data["accounts"]
        else:
            raise ValueError(
                f"Vertex account pool must be a JSON list or {{'accounts': [...]}}: {path}"
            )

        accounts = []
        for idx, item in enumerate(records, start=1):
            if not isinstance(item, dict):
                raise ValueError(f"Vertex account #{idx} must be an object")
            if item.get("enabled", True) is False:
                continue
            project_id = str(item.get("project_id") or "").strip()
            if not project_id:
                raise ValueError(f"Vertex account #{idx} missing project_id")
            location = str(item.get("location") or self.default_location).strip() or "global"
            key_path = str(item.get("service_account_key") or item.get("json_path") or "").strip()
            key_info = item.get("service_account_info")
            if key_path:
                resolved = Path(key_path).expanduser()
                if not resolved.is_absolute():
                    resolved = path.parent / resolved
                key_path = str(resolved.resolve())
            if not key_path and not isinstance(key_info, dict):
                raise ValueError(
                    f"Vertex account #{idx} must provide service_account_key/json_path "
                    f"or service_account_info. project_id={project_id}"
                )
            accounts.append({
                "account_name": item.get("account_name") or item.get("name") or f"vertex-{idx}",
                "project_id": project_id,
                "location": location,
                "service_account_key": key_path,
                "service_account_info": key_info if isinstance(key_info, dict) else None,
                "credentials": None,
            })
        return accounts

    def _init_credentials(self) -> None:
        """Create and pre-refresh credentials for every account, dropping unusable ones.

        An account is skipped when its key file is missing, the key cannot be
        loaded, or the initial token refresh fails; up to five skip reasons are
        printed.

        Raises:
            ImportError: ``google-auth`` is not installed.
            ValueError: no account has usable credentials.
        """
        try:
            import google.auth.transport.requests
            from google.oauth2 import service_account
        except ImportError as exc:
            raise ImportError(
                "Vertex mode requires google-auth: pip install google-auth google-auth-httplib2"
            ) from exc

        scopes = ["https://www.googleapis.com/auth/cloud-platform"]
        self._auth_request = google.auth.transport.requests.Request()
        valid_accounts = []
        skipped_accounts = []
        for account in self.accounts:
            account_name = account.get("account_name", "")
            key_path = account.get("service_account_key") or ""
            key_info = account.get("service_account_info")
            if key_path and not os.path.exists(key_path):
                skipped_accounts.append((account_name, f"missing key file: {key_path}"))
                continue
            try:
                # Loading is inside the try too: a corrupt key file should skip
                # this account like a failed refresh does, not abort the whole pool.
                if key_path:
                    credentials = service_account.Credentials.from_service_account_file(
                        key_path,
                        scopes=scopes,
                    )
                else:
                    credentials = service_account.Credentials.from_service_account_info(
                        key_info,
                        scopes=scopes,
                    )
                credentials.refresh(self._auth_request)
            except Exception as exc:
                skipped_accounts.append((
                    account_name,
                    f"{type(exc).__name__}: {str(exc)[:300]}",
                ))
                continue
            account["credentials"] = credentials
            valid_accounts.append(account)
        self.accounts = valid_accounts
        if not self.accounts:
            details = "; ".join(f"{name}: {reason}" for name, reason in skipped_accounts[:5])
            raise ValueError(f"No valid Vertex account credentials in {self.pool_file}. {details}")
        if skipped_accounts:
            skipped_text = "; ".join(f"{name}: {reason}" for name, reason in skipped_accounts[:5])
            print(
                f"[VERTEX] Skipped {len(skipped_accounts)} invalid account(s): {skipped_text}",
                flush=True,
            )
        print(f"[VERTEX] Loaded {len(self.accounts)} account(s) from {self.pool_file}", flush=True)

    async def acquire_account(self) -> tuple[int, dict]:
        """Return ``(index, account)`` for the next account not on cooldown, round-robin.

        If every account is cooling down, wait until the earliest cooldown
        expires (at least 0.1 s per wait) and try again. The wait happens
        *outside* ``self._lock`` so ``mark_cooldown`` and other callers are
        not blocked for the whole cooldown window.
        """
        while True:
            async with self._lock:
                now = time.monotonic()
                count = len(self.accounts)
                for offset in range(count):
                    idx = (self._next_index + offset) % count
                    if self._cooldown_until.get(idx, 0.0) <= now:
                        self._next_index = (idx + 1) % count
                        return idx, self.accounts[idx]
                wait = min(self._cooldown_until.values()) - now
            await asyncio.sleep(max(wait, 0.1))

    async def mark_cooldown(self, idx: int) -> None:
        """Take account *idx* out of rotation for ``cooldown_seconds`` (no-op when 0)."""
        if self.cooldown_seconds <= 0:
            return
        async with self._lock:
            self._cooldown_until[idx] = time.monotonic() + self.cooldown_seconds

    def token_for(self, account: dict) -> str:
        """Return a bearer token for *account*, refreshing it first if expired or missing.

        NOTE(review): the refresh is a synchronous (non-awaited) call, so it
        blocks the event loop while it runs.
        """
        credentials = account["credentials"]
        if credentials.expired or not credentials.token:
            credentials.refresh(self._auth_request)
        return credentials.token
|
|
|
|
| def _convert_messages_to_vertex_native(messages: list[dict]) -> tuple[Optional[dict], list[dict]]: |
| system_instruction = None |
| contents = [] |
|
|
| for msg in messages: |
| role = msg.get("role", "") |
| content = msg.get("content", "") |
| if role == "system": |
| parts = [] |
| if isinstance(content, str): |
| if content.strip(): |
| parts.append({"text": content}) |
| elif isinstance(content, list): |
| for item in content: |
| if isinstance(item, str) and item.strip(): |
| parts.append({"text": item}) |
| elif isinstance(item, dict) and item.get("type") == "text": |
| text = item.get("text", "") |
| if text.strip(): |
| parts.append({"text": text}) |
| if parts: |
| system_instruction = {"parts": parts} |
| continue |
|
|
| vertex_role = "model" if role == "assistant" else "user" |
| parts = [] |
| if isinstance(content, str): |
| if content.strip(): |
| parts.append({"text": content}) |
| elif isinstance(content, list): |
| for item in content: |
| if isinstance(item, str): |
| if item.strip(): |
| parts.append({"text": item}) |
| elif isinstance(item, dict): |
| item_type = item.get("type", "") |
| if item_type == "text": |
| text = item.get("text", "") |
| if text.strip(): |
| parts.append({"text": text}) |
| elif item_type == "image_url": |
| image_url = item.get("image_url", {}) |
| url = image_url.get("url", "") if isinstance(image_url, dict) else "" |
| if url.startswith("data:"): |
| try: |
| header, b64_data = url.split(",", 1) |
| mime_type = header.split(":", 1)[1].split(";", 1)[0] |
| except (ValueError, IndexError): |
| continue |
| parts.append({ |
| "inlineData": { |
| "mimeType": mime_type, |
| "data": b64_data, |
| } |
| }) |
| elif url.startswith("gs://"): |
| parts.append({ |
| "fileData": { |
| "fileUri": url, |
| "mimeType": "image/jpeg", |
| } |
| }) |
| if parts: |
| contents.append({"role": vertex_role, "parts": parts}) |
|
|
| return system_instruction, contents |
|
|
|
|
| def _parse_vertex_response(data: dict) -> tuple[str, str, str]: |
| prompt_feedback = data.get("promptFeedback", {}) |
| block_reason = prompt_feedback.get("blockReason", "") |
| if block_reason: |
| raise ValueError(f"Vertex prompt blocked: {block_reason}") |
|
|
| candidates = data.get("candidates", []) |
| if not candidates: |
| raise ValueError(f"Empty candidates in Vertex response: {json.dumps(data)[:500]}") |
|
|
| candidate = candidates[0] |
| content_text = "" |
| reasoning_text = "" |
| for part in candidate.get("content", {}).get("parts", []) or []: |
| text = part.get("text", "") |
| if not text: |
| continue |
| if part.get("thought", False): |
| reasoning_text += text |
| else: |
| content_text += text |
| finish_reason = candidate.get("finishReason", "STOP") |
| return content_text, reasoning_text, finish_reason |
|
|
|
|
| def _vertex_stop_sequences() -> list[str]: |
| """返回 Vertex 文本生成的协议边界停止串,防止模型续写用户回合或工具响应。""" |
| return ["<|im_end|>", "<|im_start|>user", "<tool_response>"] |
|
|
|
|
| def _is_vertex_rate_limit(status: int, body: str) -> bool: |
| body_lower = (body or "").lower() |
| return ( |
| status == 429 |
| or "resource_exhausted" in body_lower |
| or "quota" in body_lower |
| or "rate limit" in body_lower |
| ) |
|
|
|
|
| def _is_vertex_transient_http_error(status: int) -> bool: |
| """判断 Vertex HTTP 状态是否适合换账号重试,而不是记为样本错误。""" |
| return status in {408, 409, 425, 499} or status >= 500 |
|
|
|
|
# Matches rate-limit signals in exception text. The word boundaries keep a bare
# "429" inside unrelated numbers (e.g. token counts in a response dump) from
# triggering an account cooldown.
_VERTEX_EXCEPTION_RATE_LIMIT_RE = re.compile(r"\b429\b|RESOURCE_EXHAUSTED")


def _vertex_error_result(
    error: str,
    request_debug,
    status_code: Optional[int] = None,
    raw_response: str = "",
) -> dict:
    """Build the empty-output result dict shared by every Vertex failure path.

    All error results carry the same keys, including ``error_status_code`` and
    ``error_raw_response``, so consumers can read them uniformly.
    """
    return {
        "content": "",
        "reasoning_content": "",
        "tool_calls": [],
        "full_text": "",
        "finish_reason": "error",
        "error": error,
        "error_status_code": status_code,
        "error_raw_response": raw_response,
        "request_debug": request_debug,
    }


async def call_vertex_gemini_api(
    messages: list[dict],
    model: str,
    base_url: str,
    api_key: str,
    **kwargs,
) -> dict:
    """Call Gemini through Vertex AI native generateContent with account-pool rotation.

    ``base_url`` and ``api_key`` are accepted for signature parity with the other
    ``call_*_api`` clients but are not used: the endpoint is built from each pool
    account's ``project_id``/``location``, and auth uses that account's token.

    Recognised kwargs:
        vertex_account_pool: required ``VertexAccountPool``.
        temperature (0.7), max_tokens (4096), top_p, top_k: generation config.
        vertex_stop_sequences: overrides ``_vertex_stop_sequences()`` when truthy.
        model_request_timeout: total HTTP timeout in seconds (default 300).

    At most ``len(pool.accounts)`` attempts are made:
        * Token-refresh failures put the account in cooldown and move on.
        * Rate-limit HTTP responses do the same, and transient HTTP statuses move
          on without cooldown. Both only retry when the pool has more than one
          account; otherwise the HTTP error is returned directly.
        * Network/timeout exceptions, and exceptions whose text carries a 429 /
          RESOURCE_EXHAUSTED signal (the latter also cool the account down), move
          on to the next attempt.

    Returns:
        A result dict with ``content``, ``reasoning_content``, ``tool_calls``,
        ``full_text``, ``finish_reason``, ``error`` and ``request_debug``. Success
        adds ``vertex_account``/``vertex_project_id``; errors add
        ``error_status_code``/``error_raw_response``.
    """
    request_debug = _summarize_messages_for_debug(messages)
    vertex_model = normalize_model_name_for_client(model, "vertex")
    vertex_pool = kwargs.get("vertex_account_pool")
    if not isinstance(vertex_pool, VertexAccountPool):
        return _vertex_error_result(
            "vertex_account_pool is required for --model-client vertex",
            request_debug,
        )

    system_instruction, contents = _convert_messages_to_vertex_native(messages)
    generation_config = {
        "temperature": kwargs.get("temperature", 0.7),
        "maxOutputTokens": kwargs.get("max_tokens", 4096),
        "stopSequences": kwargs.get("vertex_stop_sequences") or _vertex_stop_sequences(),
    }
    if "top_p" in kwargs:
        generation_config["topP"] = kwargs["top_p"]
    if "top_k" in kwargs:
        generation_config["topK"] = kwargs["top_k"]
    body = {"contents": contents, "generationConfig": generation_config}
    if system_instruction:
        body["systemInstruction"] = system_instruction

    timeout = aiohttp.ClientTimeout(total=kwargs.get("model_request_timeout", 300))
    max_attempts = len(vertex_pool.accounts)
    last_error = ""
    last_status: Optional[int] = None
    last_raw_response = ""
    for _ in range(max_attempts):
        account_idx, account = await vertex_pool.acquire_account()
        try:
            token = vertex_pool.token_for(account)
        except Exception as e:
            last_error = f"{type(e).__name__}: {str(e)}"
            last_status, last_raw_response = None, ""
            await vertex_pool.mark_cooldown(account_idx)
            continue

        url = (
            "https://aiplatform.googleapis.com/v1/projects/"
            f"{account['project_id']}/locations/{account['location']}"
            f"/publishers/google/models/{vertex_model}:generateContent"
        )
        headers = {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        }
        try:
            async with create_http_session(timeout) as session:
                async with session.post(url, headers=headers, json=body) as resp:
                    raw_text = await resp.text()
                    raw_error_text = _truncate_debug_text(raw_text, limit=4000)
                    try:
                        data = json.loads(raw_text)
                    except json.JSONDecodeError:
                        data = {}

                    if resp.status != 200:
                        error = _extract_http_error_message(resp.status, data, raw_text)
                        last_error = f"HTTP {resp.status}: {error}"
                        last_status, last_raw_response = resp.status, raw_error_text
                        if _is_vertex_rate_limit(resp.status, raw_text) and max_attempts > 1:
                            await vertex_pool.mark_cooldown(account_idx)
                            continue
                        if _is_vertex_transient_http_error(resp.status) and max_attempts > 1:
                            continue
                        return _vertex_error_result(
                            last_error, request_debug, resp.status, raw_error_text
                        )

                    content, reasoning, finish_reason = _parse_vertex_response(data)
                    full_text = reconstruct_assistant_text(
                        content=content,
                        reasoning_content=reasoning,
                        tool_calls=[],
                    )
                    return {
                        "content": content,
                        "reasoning_content": reasoning,
                        "tool_calls": [],
                        "full_text": full_text,
                        "finish_reason": finish_reason.lower(),
                        "error": None,
                        "request_debug": request_debug,
                        "vertex_account": account.get("account_name", ""),
                        "vertex_project_id": account.get("project_id", ""),
                    }
        except Exception as e:
            last_error = f"{type(e).__name__}: {str(e)}"
            last_status, last_raw_response = None, ""
            if _VERTEX_EXCEPTION_RATE_LIMIT_RE.search(last_error):
                await vertex_pool.mark_cooldown(account_idx)
                continue
            if isinstance(e, (aiohttp.ClientError, asyncio.TimeoutError, TimeoutError, OSError)):
                continue
            return _vertex_error_result(
                _truncate_debug_text(last_error, limit=1000), request_debug
            )

    return _vertex_error_result(
        f"All Vertex accounts failed or are cooling down. Last error: {last_error[:500]}",
        request_debug,
        last_status,
        last_raw_response,
    )
|
|
|
|
async def call_azure_api(
    messages: list[dict],
    model: str,
    base_url: str,
    api_key: str,
    **kwargs,
) -> dict:
    """Call an Azure OpenAI chat-completions deployment.

    ``model`` is used both as the deployment name in the URL and as the ``model``
    body field. Recognised kwargs:
        azure_api_version: API version query parameter (default "2025-04-01-preview").
        temperature (0.7), max_tokens (4096), top_p, reasoning_effort: body fields.
        model_request_timeout: total HTTP timeout in seconds (default 300), the
            same kwarg the Vertex client honours.

    Never raises for request failures. Returns a result dict whose ``error`` is
    None on success and a non-empty message otherwise. Error messages have inline
    ``data:image`` payloads redacted and are truncated. A null
    ``message.content`` (e.g. a filtered completion) is returned as ``""`` so
    callers can always treat the output as text.
    """
    api_version = kwargs.get("azure_api_version", "2025-04-01-preview")

    body = {
        "model": model,
        "messages": messages,
        "temperature": kwargs.get("temperature", 0.7),
        "max_tokens": kwargs.get("max_tokens", 4096),
    }
    if "top_p" in kwargs:
        body["top_p"] = kwargs["top_p"]
    if "reasoning_effort" in kwargs:
        body["reasoning_effort"] = kwargs["reasoning_effort"]

    url = f"{base_url.rstrip('/')}/openai/deployments/{model}/chat/completions?api-version={api_version}"
    headers = {"api-key": api_key, "Content-Type": "application/json"}

    def _error_result(message: str) -> dict:
        return {
            "content": "",
            "reasoning_content": "",
            "tool_calls": [],
            "full_text": "",
            "finish_reason": "error",
            "error": message,
        }

    try:
        timeout = aiohttp.ClientTimeout(total=kwargs.get("model_request_timeout", 300))
        async with create_http_session(timeout) as session:
            async with session.post(url, headers=headers, json=body) as resp:
                # Read text first: a non-JSON error page (e.g. from a gateway) would
                # make resp.json() raise before the HTTP status could be reported.
                raw_text = await resp.text()
                try:
                    data = json.loads(raw_text)
                except json.JSONDecodeError:
                    data = {}

                if resp.status != 200:
                    error_obj = data.get("error") if isinstance(data, dict) else None
                    if isinstance(error_obj, dict) and error_obj.get("message"):
                        error = str(error_obj["message"])
                    elif isinstance(error_obj, str) and error_obj:
                        error = error_obj
                    else:
                        error = f"HTTP {resp.status}"
                    # Redact echoed base64 images so logs stay small and readable.
                    error = re.sub(r'data:image[^"]*', '[IMAGE]', error)[:500]
                    return _error_result(error)

                choice = data["choices"][0]
                content = choice["message"].get("content") or ""
                return {
                    "content": content,
                    "reasoning_content": "",
                    "tool_calls": [],
                    "full_text": content,
                    "finish_reason": choice.get("finish_reason") or "stop",
                    "error": None,
                }
    except Exception as e:
        # Include the type name: str(asyncio.TimeoutError()) is "", and an empty
        # "error" would be mistaken for success by callers testing truthiness.
        error = re.sub(r'data:image[^"]*', '[IMAGE]', f"{type(e).__name__}: {e}")[:1000]
        return _error_result(error)
|
|
|
|
| |
| |
| |
|
|
# Tool names accepted in a parsed tool-call JSON object: the current short names
# plus the ``*_tool`` names kept for compatibility with older prompts/models.
_CURRENT_TOOL_NAMES = ("choose_frames", "find_frame", "image_search", "web_search", "zoom_in")
_COMPAT_TOOL_NAMES = ("image_search_tool", "image_zoom_in_tool", "text_search_tool")
VALID_TOOL_CALL_NAMES = {*_CURRENT_TOOL_NAMES, *_COMPAT_TOOL_NAMES}
|
|
|
|
| def _strip_code_fence(text: str) -> str: |
| content = text.strip() |
| match = re.fullmatch(r"```(?:json)?\s*(.*?)\s*```", content, re.DOTALL | re.IGNORECASE) |
| return match.group(1).strip() if match else content |
|
|
|
|
def _load_tool_call_json(content: str) -> Optional[dict]:
    """Load a tool call JSON object from a possibly noisy text fragment.

    After removing a surrounding code fence, these candidates are tried in order:
      1. the whole fragment;
      2. the first JSON object decoded from its start (re-serialised), which
         drops trailing noise;
      3. the span from the first ``{`` to the last ``}``.
    Each candidate is repaired iteratively: "Extra data" errors truncate at the
    error position, and other errors on text ending in ``}`` drop that brace (to
    absorb a doubled closing brace).

    Returns:
        The first decoded dict whose ``name`` is a string in
        ``VALID_TOOL_CALL_NAMES``, or None. A non-string ``name`` (list, dict, ...)
        is rejected instead of raising TypeError from the set membership test.
    """
    content = _strip_code_fence(content)
    candidates = [content]

    decoder = json.JSONDecoder()
    try:
        obj, _ = decoder.raw_decode(content)
        if isinstance(obj, dict):
            candidates.append(json.dumps(obj, ensure_ascii=False))
    except json.JSONDecodeError:
        pass

    first_brace = content.find("{")
    last_brace = content.rfind("}")
    if 0 <= first_brace < last_brace:
        candidates.append(content[first_brace:last_brace + 1])

    for candidate in candidates:
        candidate = candidate.strip()
        while candidate:
            try:
                obj = json.loads(candidate)
            except json.JSONDecodeError as exc:
                if exc.msg == "Extra data" and 0 <= exc.pos < len(candidate):
                    candidate = candidate[:exc.pos].strip()
                    continue
                if candidate.endswith("}"):
                    candidate = candidate[:-1].rstrip()
                    continue
                break
            name = obj.get("name") if isinstance(obj, dict) else None
            # isinstance guard: an unhashable name would make `in` raise TypeError.
            if isinstance(name, str) and name in VALID_TOOL_CALL_NAMES:
                return obj
            break
    return None
|
|
|
|
def _extract_tool_call_json_from_text(text: str) -> Optional[dict]:
    """Recover a tool call JSON object when XML tags are missing or malformed.

    Skips work unless the text mentions a quoted ``name`` key. Otherwise tries
    ``_load_tool_call_json`` on the suffix starting at every ``{`` from left to
    right and returns the first successful parse, or None.
    """
    mentions_name = '"name"' in text or "'name'" in text
    if not mentions_name:
        return None
    start = text.find("{")
    while start != -1:
        parsed = _load_tool_call_json(text[start:])
        if parsed:
            return parsed
        start = text.find("{", start + 1)
    return None
|
|
|
|
| def _clean_search_query(text: str) -> str: |
| text = re.sub(r"\[IMAGE\s+\d+\]", " ", text, flags=re.IGNORECASE) |
| text = re.sub(r"<[^>]+>", " ", text) |
| text = re.sub(r"\s+", " ", text).strip(" .:-\n\t") |
| return text[:240].strip() |
|
|
|
|
def _extract_visual_search_hints(text: str) -> list[str]:
    """Collect up to five short search hints from model prose.

    Latin hints come from the quoted-text / parenthetical / "is X" regexes below.
    They are cleaned with ``_clean_search_query`` and skipped when empty or a bare
    article. Every run of 2-20 CJK ideographs is then added verbatim. Duplicates
    are dropped, keeping first-seen order.
    """
    latin_patterns = (
        r'visible text\s+"([^"]+)"',
        r'prominent text\s+"([^"]+)"',
        r'reads\s+"([^"]+)"',
        r'\(([^()]{3,80})\)',
        r'\b(?:as|is|likely|identified as|reads)\s+(?:the\s+)?([A-Z][A-Za-z0-9&.\' -]{2,60})',
    )
    articles = {"the", "a", "an"}
    hints: list[str] = []
    for pattern in latin_patterns:
        for found in re.finditer(pattern, text):
            hint = _clean_search_query(found.group(1))
            if not hint or hint.lower() in articles or hint in hints:
                continue
            hints.append(hint)
    for found in re.finditer(r"[\u4e00-\u9fff]{2,20}", text):
        cjk_run = found.group(0)
        if cjk_run not in hints:
            hints.append(cjk_run)
    return hints[:5]
|
|
|
|
def _recover_intent_tool_call(text: str, user_query: str = "") -> Optional[dict]:
    """Infer a tool call only when the prose explicitly announces a search.

    Returns None if ``text`` already contains ``<answer>``. If the prose says
    e.g. "I will use web_search", builds a ``web_search`` query from the cleaned
    user question plus visual hints in the prose, falling back to the phrase
    after the search mention when neither yields anything. If no usable web
    query results and an image search was announced, returns an ``image_search``
    over the full-frame bbox ``[0, 0, 1000, 1000]``. Otherwise returns None.
    """
    if "<answer>" in text:
        return None

    announce = (
        r"\b(?:i|we)\s+"
        r"(?:will|would|need to|should|can|must|am going to|plan to)\s+"
        r"(?:directly\s+)?(?:use|perform|conduct|call|do)\s+(?:an?\s+)?"
    )
    web_target = r"`?(?:web_search|web search|search the web)`?\b"
    image_target = r"`?(?:image_search|image search|reverse image search)`?\b"
    wants_web = re.search(announce + web_target, text, flags=re.IGNORECASE) is not None
    wants_image = re.search(announce + image_target, text, flags=re.IGNORECASE) is not None

    if wants_web:
        pieces: list[str] = []
        question = _clean_search_query(user_query)
        if question:
            pieces.append(question)
        pieces += _extract_visual_search_hints(text)
        if not pieces:
            tail = re.search(
                r"(?:web_search|web search|search the web)\s+(?:to|for)?\s*"
                r"(?:find|determine|identify|look up|verify|confirm)?\s*([^.\n<]{8,180})",
                text,
                flags=re.IGNORECASE,
            )
            if tail:
                pieces.append(_clean_search_query(tail.group(1)))
        query = _clean_search_query(" ".join(filter(None, pieces)))
        if query:
            return {"name": "web_search", "arguments": {"query": query}}

    if wants_image:
        return {"name": "image_search", "arguments": {"bbox": [0, 0, 1000, 1000]}}
    return None
|
|
|
|
def parse_tool_call(text: str, user_query: str = "") -> Optional[dict]:
    """Parse or conservatively recover a tool call from model output.

    Tries, in order: the JSON inside the first ``<tool_call>...</tool_call>``; a
    tool-call JSON object anywhere in the text; and finally a call inferred from
    an explicit "I will use web_search/image_search" intent.
    """
    tagged = re.search(r'<tool_call>\s*(.*?)\s*</tool_call>', text, re.DOTALL)
    parsed = _load_tool_call_json(tagged.group(1).strip()) if tagged else None
    if not parsed:
        parsed = _extract_tool_call_json_from_text(text)
    return parsed or _recover_intent_tool_call(text, user_query=user_query)
|
|
|
|
def sanitize_assistant_tool_turn(text: str) -> str:
    """Keep only the assistant's own part of a tool turn.

    If a complete ``<tool_call>...</tool_call>`` exists (case-insensitive),
    everything after the first one is dropped. Otherwise the text is cut at the
    earliest of ``<|im_start|>user``, ``<tool_response>`` or ``<|im_end|>``. This
    discards any user turn or tool output the model continued writing. The result
    is stripped; None becomes "".
    """
    body = text or ""
    call = re.search(r"<tool_call>\s*.*?</tool_call>", body, re.DOTALL | re.IGNORECASE)
    if call is not None:
        return body[:call.end()].strip()
    hits = [
        pos
        for pos in (body.find(marker) for marker in ("<|im_start|>user", "<tool_response>", "<|im_end|>"))
        if pos >= 0
    ]
    end = min(hits) if hits else len(body)
    return body[:end].strip()
|
|
|
|
def extract_tool_name_and_args(tool_call: Optional[dict]) -> tuple[str, dict]:
    """Split a tool call into ``(name, arguments)``.

    Supports both the nested form ``{"name": ..., "arguments": {...}}`` and the
    flat form, where every key other than ``name`` is an argument. Non-dict
    input yields ``("", {})``.
    """
    if isinstance(tool_call, dict):
        name = tool_call.get("name", "")
        nested = tool_call.get("arguments")
        if isinstance(nested, dict):
            return name, nested
        flat = dict(tool_call)
        flat.pop("name", None)
        return name, flat
    return "", {}
|
|
|
|
def extract_tool_call_from_result(result: dict, assistant_text: str, user_query: str = "") -> Optional[dict]:
    """Return the first tool call, preferring structured ``tool_calls`` over text parsing.

    The first structured entry is accepted in either the flat shape
    ``{"name", "arguments"}`` or the OpenAI shape ``{"function": {"name",
    "arguments"}}``. JSON-encoded string arguments are decoded into a dict when
    possible; otherwise they are passed through unchanged. Without structured
    calls, the assistant text is parsed with ``parse_tool_call``.

    Returns:
        ``{"name": ..., "arguments": ...}`` or whatever ``parse_tool_call`` returns.
    """
    tool_calls = result.get("tool_calls") or []
    if tool_calls:
        first = tool_calls[0]
        if isinstance(first, dict):
            function = first.get("function")
            # OpenAI-style entries keep name/arguments under "function".
            source = function if isinstance(function, dict) and "name" not in first else first
            arguments = source.get("arguments", {})
            if isinstance(arguments, str):
                try:
                    decoded = json.loads(arguments)
                except json.JSONDecodeError:
                    decoded = None
                if isinstance(decoded, dict):
                    arguments = decoded
            return {
                "name": source.get("name", ""),
                "arguments": arguments,
            }
    return parse_tool_call(assistant_text, user_query=user_query)
|
|
|
|
# Fake tool response injected (as a user turn, via _make_format_repair_turn) when
# the assistant's message could not be parsed as a tool call or a final answer;
# it restates the two allowed output forms.
FORMAT_REPAIR_PROMPT = """<tool_response>
Your previous assistant message was empty or did not follow the required protocol, so no valid tool call or final answer could be parsed.

Continue from the current evidence. Return exactly ONE of the following forms:

<think>brief reasoning</think>
<tool_call>{"name": "<tool-name>", "arguments": {...}}</tool_call>

or:

<think>brief reasoning</think>
<answer>final answer only</answer>

Do not output an empty message. Do not output an empty <tool_call>. Do not put the final answer only inside <think>.
</tool_response>"""


# Tool-response-shaped message telling the model that tools are exhausted and it
# must answer from the video evidence and tool results already collected.
FINAL_ANSWER_PROMPT = """<tool_response>
No more tool calls are available. Use only the video evidence and tool results already collected.

You must now provide the final answer in exactly this form:

<think>brief reasoning based on the collected evidence</think>
<answer>final answer only</answer>

Do not call any tool. Do not search again. Do not output an empty message.
</tool_response>"""


# Follow-up to FINAL_ANSWER_PROMPT when the previous message still lacked a final answer.
FINAL_ANSWER_REPAIR_PROMPT = """<tool_response>
Your previous message still did not provide a final answer. Tool calls are disabled.

Return exactly:

<think>brief reasoning based on the collected evidence</think>
<answer>final answer only</answer>

Do not call any tool.
</tool_response>"""


# System prompt for call_final_only_answer: the model continues an assistant
# prefix ("<answer>") with answer text only.
FINAL_ONLY_SYSTEM_PROMPT = """Use the provided evidence to answer the question.

Continue the assistant prefix with the final answer text only.
"""


# User message appended by call_final_only_answer before each retry.
FINAL_ONLY_REPAIR_PROMPT = """Your previous message still was not a valid final answer.
Use the provided evidence and continue the assistant prefix with the final answer text only.
"""
|
|
|
|
| def _has_answer_tag(text: str) -> bool: |
| return bool(re.search(r"<answer>.*?</answer>", text or "", re.DOTALL)) |
|
|
|
|
| def _extract_last_answer_tag(text: str) -> str: |
| matches = list(re.finditer(r"<answer>(.*?)</answer>", text or "", re.DOTALL)) |
| if not matches: |
| return "" |
| return matches[-1].group(1).strip() |
|
|
|
|
| def _has_tool_call_tag(text: str) -> bool: |
| return "<tool_call" in (text or "").lower() |
|
|
|
|
def _looks_like_malformed_tool_shell(text: str) -> bool:
    """Return True when ``text`` is blank or contains any ``<tool_call`` tag.

    Acts as a guard before reasoning text is promoted to a final answer. Empty
    output is rejected, and so is anything carrying a tool-call tag, whether
    unclosed, empty, or complete. The tag test is case-insensitive.
    """
    stripped = (text or "").strip()
    # The tag test subsumes the narrower "unclosed" / "empty <tool_call>" checks
    # this function used to perform afterwards; those were unreachable.
    return not stripped or _has_tool_call_tag(stripped)
|
|
|
|
| def _looks_like_unfinished_tool_intent(text: str) -> bool: |
| """Detect prose that asks to use a tool but failed to emit a valid call.""" |
| text = (text or "").lower() |
| if not text: |
| return True |
| intent_patterns = [ |
| r"\b(?:i|we)\s+(?:will|would|need to|should|must|am going to|plan to)\s+" |
| r"(?:use|call|perform|conduct|do)\s+(?:the\s+)?(?:web_search|image_search|web search|image search|reverse image search|zoom_in|zoom in|find_frame|choose_frames)", |
| r"\b(?:next|now)\s+(?:i|we)\s+(?:will|would|should|need to)\s+" |
| r"(?:search|zoom|inspect|call|use|look up)", |
| ] |
| return any(re.search(pattern, text) for pattern in intent_patterns) |
|
|
|
|
def _looks_like_unfinished_no_tool_answer(text: str) -> bool:
    """Detect tool-free prose that is still analysis rather than a final answer.

    Returns False as soon as a closed ``<answer>`` tag is present. Otherwise it
    returns True for:
      * empty text;
      * leftover ``<think>``/``<thinking>`` tags;
      * protocol-stripped text that is empty, longer than 1200 characters, or
        ends in ``,``, ``;`` or ``:``;
      * any of the planning / hedging phrases below.
    """
    raw = (text or "").strip()
    if not raw:
        return True
    if _has_answer_tag(raw):
        return False
    # A think tag without an answer tag means the model is still reasoning.
    if re.search(r"</?think(?:ing)?>", raw, flags=re.IGNORECASE):
        return True

    plain = _strip_protocol_text(raw)
    if not plain or len(plain) > 1200 or plain[-1:] in {",", ";", ":"}:
        return True

    lowered = raw.lower()
    planning_or_hedging = (
        r"\b(?:i|we)\s+(?:need to|should|must|will|would|can|could|am going to|plan to)\s+"
        r"(?:find|confirm|verify|check|search|look up|identify|zoom|inspect|use|call)",
        r"\b(?:let me|let's|now i|next i|first i|i'll|i will)\b",
        r"\b(?:it could be|could be|might be|may be|likely|probably|possibly)\b",
        r"\b(?:wait|actually|however|but hold on|not sure|unclear)\b",
        r"\bcommon .* (?:are|include):",
        r"\b(?:first|second|third),?\s+(?:i|we)\b",
        r"\b(?:step|plan|search strategy)\b",
        r"\n\s*(?:1\.|2\.|3\.|\*)\s+",
    )
    for pattern in planning_or_hedging:
        if re.search(pattern, lowered, flags=re.DOTALL):
            return True
    return False
|
|
|
|
| def _strip_protocol_text(text: str) -> str: |
| text = re.sub(r"<\|im_end\|>", " ", text or "") |
| text = re.sub(r"</?think(?:ing)?>", " ", text, flags=re.IGNORECASE) |
| text = re.sub(r"</?answer>", " ", text, flags=re.IGNORECASE) |
| text = re.sub(r"</?tool_call>", " ", text, flags=re.IGNORECASE) |
| return " ".join(text.split()).strip() |
|
|
|
|
def _recover_answer_from_no_tool_output(result: dict, assistant_text: str) -> Optional[str]:
    """Return final-answer text when no tool call exists but the model clearly answered.

    Resolution order:
      1. ``assistant_text`` with an ``<answer>`` tag: returned as-is, or reduced
         to just its last ``<answer>`` when it also carries a tool-call tag
         (None if that answer is empty).
      2. ``result["content"]``, when it is non-empty after protocol stripping and
         shows no tool tag, tool intent or unfinished-analysis signs.
      3. ``result["reasoning_content"]``, when its stripped form is at most 800
         characters and passes the same checks. It is wrapped as
         ``<think>...</think><answer>...</answer>``.
    Returns None if nothing qualifies.
    """
    assistant_text = assistant_text or ""
    if _has_answer_tag(assistant_text):
        if not _has_tool_call_tag(assistant_text):
            return assistant_text
        answer = _extract_last_answer_tag(assistant_text)
        return f"<answer>{answer}</answer>" if answer else None

    content = _message_content_to_text(result.get("content")).strip()
    reasoning = _message_content_to_text(result.get("reasoning_content")).strip()

    # Visible content wins when it reads like a finished answer.
    content_is_answer = bool(_strip_protocol_text(content)) and not (
        _has_tool_call_tag(content)
        or _looks_like_unfinished_tool_intent(content)
        or _looks_like_unfinished_no_tool_answer(content)
    )
    if content_is_answer:
        return content

    # Otherwise a short, clean reasoning trace may itself be the answer.
    reasoning_plain = _strip_protocol_text(reasoning)
    reasoning_is_answer = (
        bool(reasoning_plain)
        and len(reasoning_plain) <= 800
        and not (
            _looks_like_malformed_tool_shell(reasoning)
            or _looks_like_unfinished_tool_intent(reasoning)
            or _looks_like_unfinished_no_tool_answer(reasoning)
        )
    )
    if reasoning_is_answer:
        return f"<think>\n{reasoning}\n</think>\n<answer>{reasoning_plain}</answer>"
    return None
|
|
|
|
def _make_format_repair_turn(assistant_text: str) -> tuple[dict | None, dict, str]:
    """Build the pieces for one protocol-repair retry.

    Returns:
        ``(assistant_message, user_message, output_snippet)``. ``assistant_message``
        echoes ``assistant_text`` (None when it is empty). ``user_message``
        carries FORMAT_REPAIR_PROMPT. ``output_snippet`` is the chat-template
        transcript text for both turns, ending with an open assistant turn.
    """
    assistant_message = None
    if assistant_text:
        assistant_message = {"role": "assistant", "content": assistant_text}
    repair_prompt = FORMAT_REPAIR_PROMPT
    user_message = {"role": "user", "content": repair_prompt}
    output_snippet = (
        f"{assistant_text}<|im_end|><|im_start|>user\n"
        f"{repair_prompt}<|im_end|>\n<|im_start|>assistant\n"
    )
    return assistant_message, user_message, output_snippet
|
|
|
|
| def _build_final_only_user_prompt(question: str, output_parts: list[str], max_chars: int = 8000) -> str: |
| transcript = "".join(output_parts) |
| responses = re.findall(r"<tool_response>\s*(.*?)\s*</tool_response>", transcript, re.DOTALL) |
| cleaned_responses = [] |
| for idx, response in enumerate(responses, 1): |
| text = re.sub(r"<\|im_start\|>.*", "", response, flags=re.DOTALL) |
| text = re.sub(r"\n{3,}", "\n\n", text).strip() |
| if text: |
| cleaned_responses.append(f"[Evidence {idx}]\n{text}") |
|
|
| evidence_text = "\n\n".join(cleaned_responses) |
| if len(evidence_text) > max_chars: |
| evidence_text = evidence_text[-max_chars:] |
| evidence_text = "[Evidence truncated to the most recent collected results]\n" + evidence_text |
|
|
| return ( |
| "Question:\n" |
| f"{question}\n\n" |
| "Evidence:\n" |
| f"{evidence_text or 'No textual evidence was collected.'}" |
| ) |
|
|
|
|
| def _with_final_only_extra_body(kwargs: dict) -> dict: |
| extra_body = dict(kwargs.get("extra_body") or {}) |
| chat_template_kwargs = dict(extra_body.get("chat_template_kwargs") or {}) |
| chat_template_kwargs["enable_thinking"] = False |
| extra_body["chat_template_kwargs"] = chat_template_kwargs |
| extra_body["continue_final_message"] = True |
| extra_body["stop"] = ["</answer>", "<tool_call"] |
| return extra_body |
|
|
|
|
def _recover_final_only_json_output(result: dict, assistant_text: str) -> Optional[str]:
    """Recover an answer when the model replied with a JSON ``{"reasoning", "answer"}`` object.

    ``result["content"]`` and then ``assistant_text`` are tried, after removing
    an optional leading ```` ``` ````/```` ```json ```` fence and a trailing fence.
    The first dict with a non-empty ``answer`` and no tool-call tag in either
    field is returned as ``<think>reasoning</think>\\n<answer>answer</answer>``.
    Returns None otherwise.
    """
    sources = (
        _message_content_to_text(result.get("content")).strip(),
        (assistant_text or "").strip(),
    )
    for raw in sources:
        if not raw:
            continue
        unfenced = re.sub(r"^```(?:json)?\s*", "", raw.strip(), flags=re.IGNORECASE)
        unfenced = re.sub(r"\s*```$", "", unfenced.strip())
        try:
            payload = json.loads(unfenced)
        except json.JSONDecodeError:
            continue
        if not isinstance(payload, dict):
            continue
        reasoning = str(payload.get("reasoning") or "").strip()
        answer = str(payload.get("answer") or "").strip()
        if answer and not (_has_tool_call_tag(reasoning) or _has_tool_call_tag(answer)):
            return f"<think>{reasoning}</think>\n<answer>{answer}</answer>"
    return None
|
|
|
|
async def call_final_only_answer(
    client: str,
    model: str,
    base_url: str,
    api_key: str,
    question: str,
    output_parts: list[str],
    tool_calls: list[dict],
    sample_id: str,
    retry_limit: int,
    **kwargs,
) -> dict:
    """Ask the model for a final answer only (no tools), retrying on unusable output.

    Builds a fresh conversation:
      * FINAL_ONLY_SYSTEM_PROMPT;
      * a user message with the question and the tool-response evidence
        extracted from ``output_parts``;
      * an assistant prefix ``"<answer>"``.
    It is sent to the client selected by ``client`` ("gemini", "azure",
    "vertex", "gateway"; any other value uses the OpenAI-compatible client).
    ``extra_body`` is set for prefix continuation, no thinking, and stops at
    ``</answer>``/``<tool_call``.

    Each attempt tries, in order:
      1. the text of an ``<answer>`` tag or the protocol-stripped output;
      2. a JSON ``{"reasoning", "answer"}`` reply;
      3. the generic no-tool recovery.
    The first candidate free of tool-call tags is accepted. Otherwise, up to
    ``retry_limit`` more attempts are made, each appending
    FINAL_ONLY_REPAIR_PROMPT and a new ``"<answer>"`` prefix. An API error stops
    immediately.

    Args:
        tool_calls: accepted but not read by this function.
        kwargs: forwarded to the client call, plus ``request_debug_context`` and
            the adjusted ``extra_body``.

    Returns:
        dict with ``output`` (final assistant text), ``transcript`` (chat-template
        text of the exchange), ``messages``, ``finish_reason`` ("answer" on
        success, else "error" after an API error or "max_turns"), ``result``
        (last raw client result) and ``retries`` (repair prompts issued).
    """
    final_messages = [
        {"role": "system", "content": FINAL_ONLY_SYSTEM_PROMPT},
        {"role": "user", "content": _build_final_only_user_prompt(question, output_parts)},
        # Assistant prefix the model is asked to continue.
        {"role": "assistant", "content": "<answer>"},
    ]
    # Chat-template rendering of the same exchange, kept for logging/inspection.
    transcript_parts = [
        f"<|im_start|>system\n{FINAL_ONLY_SYSTEM_PROMPT}<|im_end|>\n",
        f"<|im_start|>user\n{final_messages[1]['content']}<|im_end|>\n<|im_start|>assistant\n<answer>",
    ]
    last_result = {}
    last_output = ""
    retries_used = 0

    # One initial attempt plus up to retry_limit repairs (negative limits mean no retries).
    for attempt in range(max(0, retry_limit) + 1):
        api_kwargs = dict(kwargs)
        api_kwargs["request_debug_context"] = {
            "sample_id": sample_id,
            "turn": f"final_only_{attempt}",
            "task_kind": "video_final_only",
        }
        api_kwargs["extra_body"] = _with_final_only_extra_body(api_kwargs)
        if client == "gemini":
            result = await call_gemini_api(final_messages, model, base_url, api_key, **api_kwargs)
        elif client == "azure":
            result = await call_azure_api(final_messages, model, base_url, api_key, **api_kwargs)
        elif client == "vertex":
            result = await call_vertex_gemini_api(final_messages, model, base_url, api_key, **api_kwargs)
        elif client == "gateway":
            result = await call_gateway_api(final_messages, model, base_url, api_key, **api_kwargs)
        else:
            result = await call_openai_api(final_messages, model, base_url, api_key, **api_kwargs)

        last_result = result
        last_output = result.get("full_text") or result.get("content", "")
        # API errors are not retried here; the caller sees finish_reason "error".
        if result.get("error"):
            transcript_parts.append(last_output + "<|im_end|>")
            break

        # Recovery 1: the tagged answer, or the whole output with protocol markup
        # stripped (the "</answer>" stop string usually removes the closing tag).
        answer_text = _extract_last_answer_tag(last_output) or _strip_protocol_text(last_output)
        recovered_answer = None
        if (
            answer_text
            and not _has_tool_call_tag(last_output)
            and not _looks_like_unfinished_tool_intent(last_output)
        ):
            recovered_answer = f"<answer>{answer_text}</answer>"
        # Recovery 2: a JSON {"reasoning", "answer"} reply.
        if not recovered_answer:
            recovered_answer = _recover_final_only_json_output(result, last_output)
        # Recovery 3: generic no-tool answer recovery from content / reasoning.
        if not recovered_answer:
            recovered_answer = _recover_answer_from_no_tool_output(result, last_output)
        if recovered_answer and not _has_tool_call_tag(recovered_answer):
            if recovered_answer != last_output:
                last_output = recovered_answer
            # The transcript already ends with the "<answer>" prefix, so do not repeat it.
            transcript_output = last_output
            if transcript_output.startswith("<answer>"):
                transcript_output = transcript_output[len("<answer>"):]
            transcript_parts.append(transcript_output + "<|im_end|>")
            final_messages[-1] = {"role": "assistant", "content": last_output}
            return {
                "output": last_output,
                "transcript": "".join(transcript_parts),
                "messages": final_messages,
                "finish_reason": "answer",
                "result": result,
                "retries": retries_used,
            }

        # Unusable output: record it as the completed prefixed turn, then (if
        # retries remain) ask again with a repair prompt and a fresh prefix.
        transcript_parts.append(last_output + "<|im_end|>")
        final_messages[-1] = {"role": "assistant", "content": f"<answer>{last_output}"}
        if attempt < retry_limit:
            final_messages.append({"role": "user", "content": FINAL_ONLY_REPAIR_PROMPT})
            final_messages.append({"role": "assistant", "content": "<answer>"})
            transcript_parts.append(
                f"\n<|im_start|>user\n{FINAL_ONLY_REPAIR_PROMPT}<|im_end|>\n<|im_start|>assistant\n<answer>"
            )
            retries_used += 1

    return {
        "output": last_output,
        "transcript": "".join(transcript_parts),
        "messages": final_messages,
        "finish_reason": "error" if last_result.get("error") else "max_turns",
        "result": last_result,
        "retries": retries_used,
    }
|
|
|
|
| |
| |
# System prompt for summarising a fetched web page in at most five sentences,
# keeping any content relevant to the user's question.
SUMMARY_SYSTEM_PROMPT = """You are a helpful assistant. Your task is to summarize the main content of the given web page in no more than five sentences. Your summary should cover the overall key points of the page, not just parts related to the user's question.

If any part of the content is helpful for answering the user's question, be sure to include it clearly in the summary. Do not ignore relevant information, but also make sure the general structure and main ideas of the page are preserved. Your summary should be concise, factual, and informative."""


# User-turn template with {content_limit}, {content} and {query} placeholders
# (presumably filled via str.format by the summarisation caller).
SUMMARY_USER_PROMPT = """Webpage Content (first {content_limit} characters) is: {content}
Question: {query}"""


# URL suffixes that fetch_url_content skips without opening a browser page
# (documents and images, presumably because their text is not HTML-extractable).
SKIP_EXTENSIONS = ['.pdf', '.doc', '.docx', '.ppt', '.pptx', '.xls', '.xlsx', '.jpg', '.jpeg', '.png', '.gif']


# Shared Playwright state: created lazily by _get_browser_context, torn down by
# _cleanup_browser_unlocked; _browser_lock serialises creation and teardown.
_playwright = None
_browser = None
_browser_context = None
_browser_lock = asyncio.Lock()
# Caps concurrent page fetches; override with WEB_FETCH_PLAYWRIGHT_CONCURRENCY
# (default 4, clamped to at least 1).
_playwright_fetch_semaphore = asyncio.Semaphore(
    max(1, int(os.environ.get("WEB_FETCH_PLAYWRIGHT_CONCURRENCY", "4")))
)
|
|
|
|
| def _is_browser_closed_error(error: Exception) -> bool: |
| """Return True when Playwright's shared browser/context must be recreated.""" |
| text = str(error) |
| closed_markers = [ |
| "Target page, context or browser has been closed", |
| "BrowserContext.new_page", |
| "BrowserType.launch", |
| "browser has been closed", |
| "context has been closed", |
| ] |
| return any(marker in text for marker in closed_markers) |
|
|
|
|
async def _get_browser_context():
    """Get or create the shared Playwright browser context.

    Runs under ``_browser_lock``. The cached context is returned while its
    browser is still connected. Otherwise any leftover state is torn down with
    ``_cleanup_browser_unlocked`` before a new Playwright driver, headless
    Chromium browser and context are started. Leftover state means a
    disconnected browser, or a driver/browser left by a launch that failed
    part-way. Previously such partial state was overwritten and the old driver
    leaked. A launch failure also cleans up immediately before re-raising.
    """
    global _playwright, _browser, _browser_context
    async with _browser_lock:
        browser_disconnected = _browser is not None and not _browser.is_connected()
        if browser_disconnected or _browser_context is None:
            # No-op when nothing was started; otherwise stops stale/partial resources.
            await _cleanup_browser_unlocked()
            try:
                _playwright = await async_playwright().start()
                _browser = await _playwright.chromium.launch(headless=True)
                # Desktop Chrome user agent, presumably so sites serve their normal HTML.
                _browser_context = await _browser.new_context(
                    user_agent="Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.6261.57 Safari/537.36"
                )
            except Exception:
                # Do not leave a half-started driver/browser behind.
                await _cleanup_browser_unlocked()
                raise
        return _browser_context
|
|
|
|
async def _cleanup_browser():
    """Close the shared Playwright context, browser and driver.

    This is the locked entry point: it takes ``_browser_lock`` and delegates to
    ``_cleanup_browser_unlocked``.
    """
    lock = _browser_lock
    async with lock:
        await _cleanup_browser_unlocked()
|
|
|
|
async def _cleanup_browser_unlocked():
    """Close context, browser and driver in that order, resetting each global to None.

    Caller must hold ``_browser_lock``. Each close is best-effort: it is bounded
    to 5 seconds and any exception is ignored, so a dead or hung resource cannot
    block teardown.
    """
    global _playwright, _browser, _browser_context

    async def _close_quietly(make_closer) -> None:
        try:
            await asyncio.wait_for(make_closer(), timeout=5.0)
        except Exception:
            pass

    if _browser_context:
        await _close_quietly(lambda: _browser_context.close())
        _browser_context = None
    if _browser:
        await _close_quietly(lambda: _browser.close())
        _browser = None
    if _playwright:
        await _close_quietly(lambda: _playwright.stop())
        _playwright = None
|
|
|
|
| |
# HTTP status codes treated as likely bot-detection / access-blocking responses.
BOT_DETECTION_CODES = {403, 429, 406, 418, 421, 451}


# Playwright route glob patterns whose requests _fetch_url_content_with_playwright
# aborts, so analytics, advertising and monitoring scripts are never loaded.
TRACKING_DOMAINS = [
    # Google analytics / tag managers / ad serving.
    "*google-analytics.com*", "*googletagmanager.com*", "*doubleclick.net*",
    "*googleadservices.com*", "*googlesyndication.com*", "*googletagservices.com*",
    # Facebook pixel / SDK.
    "*facebook.com/tr*", "*connect.facebook.net*", "*facebook.net*",
    # Product analytics and session recording.
    "*hotjar.com*", "*mixpanel.com*", "*segment.com*", "*amplitude.com*",
    "*fullstory.com*", "*logrocket.com*", "*mouseflow.com*",
    # Ad networks and ad verification.
    "*adsystem.com*", "*pubmatic.com*", "*rubiconproject.com*",
    "*amazon-adsystem.com*", "*adsafeprotected.com*",
    # Performance monitoring, A/B testing and audience measurement.
    "*newrelic.com*", "*nr-data.net*", "*pingdom.net*",
    "*optimizely.com*", "*quantserve.com*", "*scorecardresearch.com*"
]
|
|
|
|
async def fetch_url_content(url: str, timeout_sec: int = 20) -> Optional[str]:
    """Fetch URL content using Playwright with JS rendering.

    Returns None for skip cases. These are URLs ending in one of
    SKIP_EXTENSIONS — checked on the raw URL and on its parsed path, so
    ``report.pdf?download=1`` or ``img.png#x`` are skipped too — and non-HTML
    responses. Raises URLFetchError for retriable errors (timeout, HTTP error,
    etc.). Concurrent fetches are bounded by ``_playwright_fetch_semaphore``.
    """
    skip_suffixes = tuple(SKIP_EXTENSIONS)
    lowered = url.lower()
    try:
        path = urlparse(lowered).path
    except ValueError:
        # Malformed URL (e.g. unbalanced IPv6 brackets): rely on the raw-URL check.
        path = ""
    if lowered.endswith(skip_suffixes) or path.endswith(skip_suffixes):
        return None

    async with _playwright_fetch_semaphore:
        return await _fetch_url_content_with_playwright(url, timeout_sec)
|
|
|
|
async def _fetch_url_content_with_playwright(url: str, timeout_sec: int = 20) -> Optional[str]:
    """Fetch one URL while holding the Playwright fetch semaphore.

    The page is opened in the shared browser context. Images, stylesheets,
    fonts, media and similar resources are blocked, as is every TRACKING_DOMAINS
    host. Navigation waits for ``domcontentloaded`` with Playwright's timeout
    set to ``timeout_sec``. The visible text is then extracted with
    BeautifulSoup (lxml), with script/style elements removed.

    Returns:
        The page text, or None when the response is not ``text/html``.

    Raises:
        URLFetchError: on navigation/content timeouts, missing responses, HTTP
            status >= 400, empty text, or any other failure. Browser-closed
            failures first reset the shared browser so the next fetch relaunches.
    """
    from bs4 import BeautifulSoup

    page = None
    try:
        context = await _get_browser_context()
        page = await context.new_page()

        # Only the DOM text is needed; skip heavy or non-textual resources.
        excluded_resource_types = {"image", "stylesheet", "font", "media", "websocket", "eventsource", "manifest"}

        async def resource_handler(route):
            if route.request.resource_type in excluded_resource_types:
                await route.abort()
            else:
                await route.continue_()

        await page.route("**/*", resource_handler)

        async def tracking_handler(route):
            await route.abort()

        for domain_pattern in TRACKING_DOMAINS:
            try:
                await page.route(domain_pattern, tracking_handler)
            except Exception:
                # Best-effort: a pattern that cannot be registered is simply not blocked.
                pass

        # asyncio guard slightly longer than Playwright's own navigation timeout so
        # Playwright normally reports first; 25 s at the default timeout_sec=20.
        navigation_guard = max(25.0, float(timeout_sec) + 5.0)
        try:
            response = await asyncio.wait_for(
                page.goto(url, timeout=timeout_sec * 1000, wait_until='domcontentloaded'),
                timeout=navigation_guard,
            )
        except asyncio.TimeoutError as exc:
            raise URLFetchError(f"Timeout fetching {url}") from exc

        if response is None:
            raise URLFetchError(f"No response from {url}")

        if response.status >= 400:
            raise URLFetchError(f"HTTP {response.status} from {url}")

        content_type = response.headers.get('content-type', '')
        if 'text/html' not in content_type:
            return None

        try:
            html = await asyncio.wait_for(page.content(), timeout=15.0)
        except asyncio.TimeoutError as exc:
            raise URLFetchError(f"Content timeout for {url}") from exc

        soup = BeautifulSoup(html, "lxml")
        for script in soup(["script", "style"]):
            script.decompose()
        raw_content = soup.get_text(separator='\n', strip=True)

        if not raw_content:
            raise URLFetchError(f"Empty content from {url}")

        return raw_content

    except URLFetchError:
        raise
    except Exception as e:
        if _is_browser_closed_error(e):
            await _cleanup_browser()
        raise URLFetchError(f"Error fetching {url}: {e}") from e
    finally:
        if page:
            try:
                await asyncio.wait_for(page.close(), timeout=5.0)
            except Exception:
                # Best-effort close; CancelledError (a BaseException) still propagates.
                pass
|
|
|
|
def _clean_think_blocks(text: str) -> str:
    """Strip model reasoning (``<think>...</think>``) from a response.

    Every complete, case-insensitive ``<think>...</think>`` block is removed.
    If a dangling closing tag remains, everything up to and including the last
    ``</think>`` is dropped as well. That shape (``reasoning</think>answer``)
    appears when the chat template injects the opening ``<think>`` into the
    prompt. The result is whitespace-stripped; falsy input (``""``/``None``)
    is returned unchanged.
    """
    if not text:
        return text
    cleaned = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL | re.IGNORECASE)
    # Keep only what follows the last orphaned closing tag (no-op if absent).
    return re.split(r"</think>", cleaned, flags=re.IGNORECASE)[-1].strip()
|
|
|
|
async def summarize_content(
    query: str,
    content: str,
    summarizer_base_url: str,
    summarizer_model: str,
    content_limit: int = 30000,
    max_retries: int = 5,
) -> Optional[str]:
    """Summarize *content* for *query* with the summarizer LLM (mmsearch_r1-style prompt).

    The content is cut to ``content_limit`` characters before prompting. Up
    to ``max_retries`` calls are made, 1 s apart, when the API reports an
    error or raises.

    Returns:
        The summary with ``<think>`` blocks removed, or ``None`` for empty
        content or when every attempt fails.
    """
    if not content:
        return None

    user_prompt = SUMMARY_USER_PROMPT.format(
        query=query,
        content=content[:content_limit],
        content_limit=content_limit,
    )
    chat = [
        {"role": "system", "content": SUMMARY_SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt},
    ]

    # Qwen3 models: turn off the thinking mode via chat-template kwargs.
    if "qwen3" in summarizer_model.lower():
        extra_body = {"chat_template_kwargs": {"enable_thinking": False}}
    else:
        extra_body = None

    for attempt_no in range(max_retries):
        final_attempt = attempt_no == max_retries - 1
        try:
            response = await call_openai_api(
                messages=chat,
                model=summarizer_model,
                base_url=summarizer_base_url,
                api_key="",
                max_tokens=8192,
                extra_body=extra_body,
            )
            if not response.get("error"):
                return _clean_think_blocks(response["content"])
            if final_attempt:
                return None
            await asyncio.sleep(1)
        except Exception as exc:
            if final_attempt:
                print(f"[SUMMARIZER ERROR] {exc}")
                return None
            await asyncio.sleep(1)
    return None
|
|
|
|
async def call_text_search(
    query: str,
    serper_api_key: str,
    summarizer_base_url: str,
    summarizer_model: str,
    serper_semaphore: asyncio.Semaphore,
    serper_concurrency: int = 5,
    top_k: int = 3,
    content_limit: int = 30000,
    search_cache: Optional["SearchCache"] = None,
    max_serper_attempts: int = 3,
) -> str:
    """Search Google via Serper, fetch and summarize top hits, return a final summary.

    Flow:
    1. Return a cached summary when ``search_cache`` has one.
    2. Query Serper, up to ``max_serper_attempts`` tries with linear backoff.
    3. Fetch/summarize the top ``top_k`` organic URLs and build a final
       summary (``summarize_serper_organic_results``).
    4. Cache the final summary unless it is an ``"Error:"`` string.

    ``serper_semaphore`` bounds concurrent Serper requests. It is held only
    for the HTTP call, never during retry backoff. ``serper_concurrency`` is
    unused and kept for call-site compatibility.

    Raises:
        FatalAPIError: when every Serper attempt fails.
    """
    if search_cache:
        cached = await search_cache.get(query, top_k, summarizer_model)
        if cached:
            return f"Found cached summary for query: {query}\n{cached}"

    url = "https://google.serper.dev/search"
    headers = {"X-API-KEY": serper_api_key, "Content-Type": "application/json"}
    payload = {"q": query}

    search_results = None
    last_error = None

    for attempt in range(max_serper_attempts):
        try:
            async with serper_semaphore:
                async with aiohttp.ClientSession() as session:
                    async with session.post(url, headers=headers, json=payload) as resp:
                        if resp.status != 200:
                            error_text = await resp.text()
                            raise RuntimeError(f"HTTP {resp.status}: {error_text}")
                        data = await resp.json()
            organic_results = data.get("organic", [])
            if not organic_results:
                return f"No search results found for query: {query}"
            search_results = organic_results
            break
        except Exception as e:
            last_error = e
            if attempt < max_serper_attempts - 1:
                # Back off *outside* the semaphore so other searches can proceed.
                backoff = 1.0 * (attempt + 1)
                print(f"[SERPER RETRY] Attempt {attempt + 1}/{max_serper_attempts} failed: {e}, retrying in {backoff}s...", flush=True)
                await asyncio.sleep(backoff)

    if search_results is None:
        error_msg = str(last_error) if last_error and str(last_error) else "Unknown error"
        raise FatalAPIError(f"[SERPER ERROR] Failed after {max_serper_attempts} attempts: {error_msg}") from last_error

    result = await summarize_serper_organic_results(
        query=query,
        organic_results=search_results,
        summarizer_base_url=summarizer_base_url,
        summarizer_model=summarizer_model,
        top_k=top_k,
        content_limit=content_limit,
    )

    if search_cache and not result.startswith("Error:"):
        await search_cache.set(query, top_k, summarizer_model, result)

    return f"Final summary generated for query: {query}\n{result}"
|
|
|
|
async def summarize_serper_organic_results(
    query: str,
    organic_results: list[dict],
    summarizer_base_url: str,
    summarizer_model: str,
    top_k: int = 3,
    content_limit: int = 30000,
) -> str:
    """Fetch and summarize Serper organic hits, then summarize everything together.

    Every hit contributes its title/snippet/link to the final prompt. Only the
    first ``top_k`` links are fetched (up to 3 tries each) and summarized
    individually.

    Returns:
        The final summary, or an ``"Error: ..."`` string when the final
        summarization yields nothing.
    """
    hits = [
        {
            'title': raw.get('title', ''),
            'snippet': raw.get('snippet', ''),
            'link': raw.get('link', ''),
        }
        for raw in organic_results
    ]

    async def summarize_hit(hit, max_retries: int = 3):
        # Fetch one page and summarize it; None when skipped or failing every try.
        page_url = hit.get("link", "")
        if not page_url:
            return None

        for attempt in range(max_retries):
            try:
                page_text = await fetch_url_content(page_url)
                if not page_text:
                    # Skip result (e.g. non-HTML): counted, not retried.
                    _web_fetch_stats.record_skip()
                    return None
                page_summary = await summarize_content(
                    query=query,
                    content=page_text,
                    summarizer_base_url=summarizer_base_url,
                    summarizer_model=summarizer_model,
                    content_limit=content_limit,
                )
                _web_fetch_stats.record_success()
                return page_summary
            except Exception as exc:
                if attempt == max_retries - 1:
                    print(f"[WEB] Failed to fetch {page_url} after {max_retries} attempts: {exc}", flush=True)
                    _web_fetch_stats.record_failure(str(exc))
                    return None
                # Exponential backoff capped at 5 s.
                await asyncio.sleep(min(2 ** attempt, 5))
        return None

    outcomes = await asyncio.gather(
        *(summarize_hit(hit) for hit in hits[:top_k]),
        return_exceptions=True,
    )
    page_summaries = [None if isinstance(outcome, Exception) else outcome for outcome in outcomes]

    sections = [f"Search Query: {query}\n\n", "--- Search Results ---\n"]
    for rank, hit in enumerate(hits, 1):
        sections.append(f"[{rank}] Title: {hit.get('title', '')}\n")
        sections.append(f" Snippet: {hit.get('snippet', 'No snippet')}\n")
        sections.append(f" Link: {hit.get('link', 'No link')}\n\n")
    sections.append("--- Web Page Summaries ---\n")
    sections.extend(f"{summary}\n\n" for summary in page_summaries if summary)

    final_summary = await summarize_content(
        query=query,
        content="".join(sections),
        summarizer_base_url=summarizer_base_url,
        summarizer_model=summarizer_model,
        content_limit=content_limit,
    )

    if not final_summary:
        return f"Error: Failed to generate final summary for query: {query}"
    return final_summary
|
|
|
|
| def _coerce_tavily_option(value: str): |
| """把 CLI/env 字符串转成 Tavily API 可接受的布尔值或枚举字符串。""" |
| normalized = (value or "").strip().lower() |
| if normalized in {"", "none", "null"}: |
| return False |
| if normalized in {"1", "true", "yes", "y", "on"}: |
| return True |
| if normalized in {"0", "false", "no", "n", "off"}: |
| return False |
| return normalized |
|
|
|
|
def _split_csv_env(value: str) -> list[str]:
    """Split a comma-separated domain filter string into trimmed, non-empty entries."""
    if not value:
        return []
    pieces = (piece.strip() for piece in value.split(","))
    return [piece for piece in pieces if piece]
|
|
|
|
def _split_secret_list(value: str) -> list[str]:
    """Split an API-key list on commas, semicolons or whitespace, dropping blanks."""
    tokens = re.split(r"[,;\s]+", value or "")
    return list(filter(None, map(str.strip, tokens)))
|
|
|
|
def load_tavily_api_keys() -> list[str]:
    """Collect the Tavily API key pool from the environment and an optional key file.

    Sources, in order:
    1. ``TAVILY_API_KEYS`` (comma/semicolon/whitespace separated);
    2. the file named by ``TAVILY_API_KEY_FILE`` (defaulting to
       ``DEFAULT_TAVILY_API_KEY_FILE``), skipping blank and ``#`` lines;
    3. ``TAVILY_API_KEY``.

    Duplicates are removed, keeping the first occurrence. An unreadable key
    file only produces a warning.
    """
    pool: list[str] = list(_split_secret_list(os.environ.get("TAVILY_API_KEYS", "")))

    key_path = os.environ.get("TAVILY_API_KEY_FILE", DEFAULT_TAVILY_API_KEY_FILE).strip()
    if key_path:
        try:
            with open(key_path, "r", encoding="utf-8") as handle:
                for raw_line in handle:
                    entry = raw_line.strip()
                    if entry and not entry.startswith("#"):
                        pool.extend(_split_secret_list(entry))
        except OSError as e:
            print(f"[WARN] Failed to read TAVILY_API_KEY_FILE={key_path}: {e}", flush=True)

    fallback_key = os.environ.get("TAVILY_API_KEY", "").strip()
    if fallback_key:
        pool.append(fallback_key)

    # dict preserves insertion order -> order-stable de-duplication.
    return list(dict.fromkeys(pool))
|
|
|
|
| class TavilyApiKeyPool: |
| """Tavily API key 轮询池,遇到限流或服务错误时临时冷却单个 key。""" |
|
|
| def __init__(self, keys: list[str], cooldown_seconds: int = 60): |
| self.keys = list(keys) |
| self.cooldown_seconds = max(1, int(cooldown_seconds or 60)) |
| self._cooldown_until = {key: 0.0 for key in self.keys} |
| self._cursor = 0 |
| self._lock = asyncio.Lock() |
|
|
| def __len__(self) -> int: |
| return len(self.keys) |
|
|
| async def next_key(self, exclude: Optional[set[str]] = None) -> str: |
| if not self.keys: |
| return "" |
|
|
| async with self._lock: |
| excluded = exclude or set() |
| available_keys = [key for key in self.keys if key not in excluded] |
| if not available_keys: |
| return "" |
|
|
| now = time.time() |
| n = len(self.keys) |
| for _ in range(n): |
| key = self.keys[self._cursor % n] |
| self._cursor += 1 |
| if key in excluded: |
| continue |
| if self._cooldown_until.get(key, 0.0) <= now: |
| return key |
|
|
| |
| |
| key = min(available_keys, key=lambda k: self._cooldown_until.get(k, 0.0)) |
| self._cursor += 1 |
| return key |
|
|
| async def mark_failure(self, key: str, status: int | None = None) -> None: |
| if not key: |
| return |
| if status in {401, 403, 432}: |
| cooldown = max(self.cooldown_seconds, 3600) |
| elif status == 429: |
| cooldown = self.cooldown_seconds |
| elif status and status >= 500: |
| cooldown = min(self.cooldown_seconds, 30) |
| else: |
| cooldown = min(self.cooldown_seconds, 10) |
|
|
| async with self._lock: |
| self._cooldown_until[key] = time.time() + cooldown |
|
|
|
|
def _truncate_text(value: str, limit: int = 1400) -> str:
    """Collapse whitespace and cap *value* at *limit* characters.

    Used to keep individual Tavily results short in the tool response. Text
    longer than ``limit`` is cut and suffixed with ``"..."`` so the result
    is exactly ``limit`` characters or shorter. When ``limit`` is too small
    to hold an ellipsis (<= 3), the text is hard-cut instead of overshooting.
    ``None`` is treated as ``""``.
    """
    text = re.sub(r"\s+", " ", value or "").strip()
    if len(text) <= limit:
        return text
    if limit <= 3:
        return text[: max(limit, 0)]
    return text[: limit - 3].rstrip() + "..."
|
|
|
|
def _format_tavily_search_results(data: dict, clean_query: str, search_depth: str, max_results: int) -> str:
    """Render a Tavily ``/search`` JSON response as tool-response text.

    Layout: a header line, an optional ``Answer:``, numbered ``Results:``
    (title, URL, score, truncated content / raw content), and a
    ``Metadata:`` line with ``response_time`` / ``request_id`` when present.
    """
    lines = [
        f'Web Search Results for "{clean_query}" '
        f"(Tavily, depth={search_depth}, max_results={max_results}):"
    ]
    # str() guards against non-string answers instead of crashing on .strip().
    answer = str(data.get("answer") or "").strip()
    if answer:
        lines.extend(["", f"Answer: {answer}"])

    results = data.get("results") or []
    if results:
        lines.append("")
        lines.append("Results:")
        for idx, item in enumerate(results, 1):
            title = (item.get("title") or "").strip()
            url = (item.get("url") or "").strip()
            score = item.get("score")
            content = _truncate_text(item.get("content") or "")
            raw_content = item.get("raw_content")

            lines.append(f"{idx}. {title or 'Untitled'}")
            if url:
                lines.append(f" URL: {url}")
            if score is not None:
                lines.append(f" Score: {score}")
            if content:
                lines.append(f" Content: {content}")
            if raw_content:
                lines.append(f" Raw Content: {_truncate_text(str(raw_content), 1800)}")
    else:
        lines.extend(["", "No Tavily results returned."])

    request_id = data.get("request_id")
    response_time = data.get("response_time")
    if request_id or response_time:
        suffix = []
        if response_time:
            suffix.append(f"response_time={response_time}")
        if request_id:
            suffix.append(f"request_id={request_id}")
        lines.extend(["", "Metadata: " + ", ".join(suffix)])

    return "\n".join(lines)


async def call_tavily_search(
    query: str,
    session: "aiohttp.ClientSession",
    api_key: str = "",
    api_key_pool: Optional["TavilyApiKeyPool"] = None,
    search_depth: str = "advanced",
    max_results: int = 8,
    include_answer: str = "advanced",
    include_raw_content: str = "false",
    topic: str = "general",
    auto_parameters: bool = False,
    include_domains: Optional[list[str]] = None,
    exclude_domains: Optional[list[str]] = None,
    timeout_seconds: int = 60,
    search_cache: Optional["SearchCache"] = None,
) -> str:
    """Call the Tavily Search API and return its answer and results as text.

    The Tavily answer/results are returned directly (no second summarizer
    pass). Results are cached per (query, max_results, option fingerprint).

    With ``api_key_pool``, one attempt is made per pooled key. Failing keys
    are put into cooldown via ``mark_failure`` and the next key is tried.
    Without a pool, a single attempt uses ``api_key``.

    Returns:
        The formatted results, or an ``"Error: ..."`` string (empty query,
        no key configured, or single-key failure).

    Raises:
        FatalAPIError: when a key pool is configured and every key failed.
    """
    clean_query = query.strip().replace("\n", " ")
    if not clean_query:
        return "Error: Tavily search requires a non-empty query."

    include_answer_value = _coerce_tavily_option(include_answer)
    include_raw_value = _coerce_tavily_option(include_raw_content)
    max_results = max(1, min(int(max_results or 8), 20))
    cache_model = (
        f"tavily:depth={search_depth}:answer={include_answer_value}:"
        f"raw={include_raw_value}:topic={topic}:auto={int(bool(auto_parameters))}"
    )

    if search_cache:
        cached = await search_cache.get(clean_query, max_results, cache_model)
        if cached:
            return cached

    if not api_key and (not api_key_pool or len(api_key_pool) == 0):
        return "Error: TAVILY_API_KEY/TAVILY_API_KEYS not configured."

    payload = {
        "query": clean_query,
        "topic": topic,
        "search_depth": search_depth,
        "max_results": max_results,
        "include_answer": include_answer_value,
        "include_raw_content": include_raw_value,
        "auto_parameters": bool(auto_parameters),
    }
    if include_domains:
        payload["include_domains"] = include_domains
    if exclude_domains:
        payload["exclude_domains"] = exclude_domains

    max_attempts = max(1, len(api_key_pool) if api_key_pool else 1)
    last_error = ""
    data = None
    attempted_keys: set[str] = set()
    for _ in range(max_attempts):
        selected_key = await api_key_pool.next_key(exclude=attempted_keys) if api_key_pool else api_key
        if not selected_key:
            break
        attempted_keys.add(selected_key)
        failure_status = None
        try:
            headers = {
                "Authorization": f"Bearer {selected_key}",
                "Content-Type": "application/json",
            }
            async with session.post(
                "https://api.tavily.com/search",
                headers=headers,
                json=payload,
                timeout=aiohttp.ClientTimeout(total=timeout_seconds),
            ) as resp:
                body = await resp.text()
                if resp.status == 200:
                    try:
                        parsed = json.loads(body)
                    except json.JSONDecodeError as e:
                        last_error = f"Error: Tavily search returned invalid JSON: {e}"
                    else:
                        if isinstance(parsed, dict):
                            data = parsed
                            break
                        # A list/scalar body would crash the formatter later.
                        last_error = (
                            "Error: Tavily search returned unexpected JSON type: "
                            f"{type(parsed).__name__}"
                        )
                else:
                    last_error = f"Error: Tavily search returned HTTP {resp.status}: {body[:500]}"
                    failure_status = resp.status
        except asyncio.TimeoutError:
            last_error = f"Error: Tavily search timeout for query: {clean_query[:100]}"
        except Exception as e:
            last_error = f"Error: Tavily search failed: {e}"

        # Park the failing key so the next attempt rotates to another one.
        if api_key_pool:
            await api_key_pool.mark_failure(selected_key, failure_status)

    if data is None:
        if api_key_pool:
            raise FatalAPIError(
                "[TAVILY ERROR] All Tavily API keys failed for one search request "
                f"after {len(attempted_keys)}/{max_attempts} attempts. Last error: "
                f"{last_error or 'no key available'}"
            )
        return last_error or "Error: Tavily search failed before receiving a response."

    result = _format_tavily_search_results(data, clean_query, search_depth, max_results)
    if search_cache:
        await search_cache.set(clean_query, max_results, cache_model, result)
    return result
|
|
|
|
def _default_web_search_backend() -> str:
    """Return the default web_search backend from the environment.

    ``VIDEO_DR_WEB_SEARCH_BACKEND`` wins over ``WEB_SEARCH_BACKEND``; when
    neither is set, the company Serper gateway (``"serper_gateway"``) is used.
    """
    generic_choice = os.environ.get("WEB_SEARCH_BACKEND", "serper_gateway")
    return os.environ.get("VIDEO_DR_WEB_SEARCH_BACKEND", generic_choice)
|
|
|
|
def _normalize_web_search_backend(raw_backend: str) -> str:
    """Canonicalize a web_search backend name.

    The name is stripped, lower-cased and ``-`` becomes ``_``. Gateway aliases
    collapse to ``"serper_gateway"``. An empty name falls back to
    ``_default_web_search_backend()``.
    """
    source = raw_backend if raw_backend else _default_web_search_backend()
    backend = source.strip().lower().replace("-", "_")
    if backend in {"gateway", "gateway_serper", "internal_serper", "company_serper"}:
        return "serper_gateway"
    return backend
|
|
|
|
def _get_gateway_text_search_config(kwargs: dict) -> tuple[str, str, str, str]:
    """Resolve the company Serper gateway settings.

    Each field is the first truthy value among the ``serper_gateway_*``
    kwarg, the ``SERPER_GATEWAY_*`` env var and the ``GATEWAY_*`` env var.
    Otherwise the ``VIDEO_DR_GATEWAY_*`` constant is used.

    Returns:
        ``(url, username, userid, token)``.
    """

    def first_truthy(kw_key: str, env_keys: tuple[str, str], default_factory):
        value = kwargs.get(kw_key)
        if value:
            return value
        for env_key in env_keys:
            env_value = os.environ.get(env_key)
            if env_value:
                return env_value
        # Constants are looked up lazily, only when nothing else is set.
        return default_factory()

    gateway_url = first_truthy(
        "serper_gateway_url", ("SERPER_GATEWAY_URL", "GATEWAY_URL"), lambda: VIDEO_DR_GATEWAY_URL
    )
    gateway_username = first_truthy(
        "serper_gateway_username", ("SERPER_GATEWAY_USERNAME", "GATEWAY_USERNAME"), lambda: VIDEO_DR_GATEWAY_USERNAME
    )
    gateway_userid = first_truthy(
        "serper_gateway_userid", ("SERPER_GATEWAY_USERID", "GATEWAY_USERID"), lambda: VIDEO_DR_GATEWAY_USERID
    )
    gateway_token = first_truthy(
        "serper_gateway_token", ("SERPER_GATEWAY_TOKEN", "GATEWAY_TOKEN"), lambda: VIDEO_DR_GATEWAY_TOKEN
    )
    return gateway_url, gateway_username, gateway_userid, gateway_token
|
|
|
|
def _parse_gateway_model_output(gateway_resp: dict) -> dict:
    """Extract the ``model_output`` payload from a company-gateway response.

    A dict is returned as-is. A JSON string is decoded, with an empty string
    treated as ``{}``. A missing key yields ``{}``.

    Raises:
        ValueError: for any other ``model_output`` type.
    """
    payload = gateway_resp.get("model_output", {})
    if isinstance(payload, dict):
        return payload
    if isinstance(payload, str):
        return json.loads(payload or "{}")
    raise ValueError(f"Unexpected gateway model_output type: {type(payload).__name__}")
|
|
|
|
def _format_serper_text_results(data: dict, max_results: int = 8) -> str:
    """Render a Serper text-search response as readable tool-response text.

    Sections, each emitted only when present:
    - quick answer (``answerBox``);
    - knowledge graph;
    - up to ``max_results`` organic hits;
    - up to ``min(3, max_results)`` "People Also Ask" entries.

    Returns ``"No relevant results found."`` when nothing was rendered.
    """
    out: list[str] = []

    def emit(prefix: str, value) -> None:
        # Append ``prefix + value`` only when the value is non-empty.
        if value:
            out.append(f"{prefix}{value}")

    box = data.get("answerBox") or {}
    if box:
        box_answer = box.get("answer") or box.get("snippet") or ""
        if box_answer:
            out.append("Quick Answer:")
            emit(" Title: ", box.get("title") or "")
            out.append(f" Answer: {box_answer}")
            emit(" URL: ", box.get("link") or "")
            out.append("")

    graph = data.get("knowledgeGraph") or {}
    if graph:
        out.append("Knowledge Graph:")
        emit(" Title: ", graph.get("title") or "")
        emit(" Type: ", graph.get("type") or "")
        emit(" Description: ", graph.get("description") or "")
        emit(" URL: ", graph.get("descriptionLink") or graph.get("website") or "")
        out.append("")

    organic_hits = data.get("organic") or []
    if organic_hits:
        out.append("Organic Results:")
        for rank, hit in enumerate(organic_hits[:max_results], 1):
            out.append(f"Result {rank}:")
            emit(" Title: ", hit.get("title") or "")
            emit(" Snippet: ", hit.get("snippet") or "")
            emit(" Source: ", hit.get("source") or hit.get("domain") or "")
            emit(" URL: ", hit.get("link") or "")
            out.append("")

    related = data.get("peopleAlsoAsk") or []
    if related:
        out.append("People Also Ask:")
        for rank, entry in enumerate(related[: min(3, max_results)], 1):
            out.append(f"{rank}. {entry.get('question') or ''}")
            emit(" ", entry.get("snippet") or "")
            emit(" URL: ", entry.get("link") or "")
            out.append("")

    rendered = "\n".join(out).strip()
    return rendered or "No relevant results found."
|
|
|
|
def _build_serper_text_summary_prompt(query: str, formatted_results: str) -> str:
    """Build the user prompt that asks the summarizer to condense Serper results.

    The prompt caps the summary at five sentences and restricts the model to
    the supplied titles/snippets/quick answers/metadata. It also asks the
    model to call out ambiguity. The query and formatted results come last.
    """
    instructions = (
        "You are a helpful assistant. Your task is to summarize the following Google "
        "Serper text search results in no more than five sentences.\n\n",
        "Focus on facts that help answer the user's query. Preserve important names, "
        "dates, entities, and distinctions when they appear in the search results. "
        "Use only the provided search result titles, snippets, quick answers, and "
        "metadata. Do not invent facts and do not assume content from linked pages.\n\n",
        "If the results are ambiguous, conflicting, or insufficient, clearly state "
        "that uncertainty.\n\n",
    )
    tail = f"Query: {query}\n\nSearch Results:\n{formatted_results}"
    return "".join(instructions) + tail
|
|
|
|
def _summarizer_chat_url(address: str) -> str:
    """Build the OpenAI-compatible chat-completions URL for a summarizer address.

    ``address`` is normally ``host:port``, in which case ``http://`` is
    prepended. An address that already carries an ``http://`` / ``https://``
    scheme is used as-is rather than producing ``http://http://...``.
    Trailing slashes are dropped.
    """
    base = address.strip().rstrip("/")
    if not base.lower().startswith(("http://", "https://")):
        base = f"http://{base}"
    return f"{base}/v1/chat/completions"


async def summarize_serper_text_results(
    query: str,
    formatted_results: str,
    session: "aiohttp.ClientSession",
    summarizer_address: str = "",
    summarizer_model: str = "",
    max_tokens: int = 1024,
) -> Optional[str]:
    """Summarize formatted Serper search metadata with the local summarizer.

    Only the formatted titles/snippets/metadata are sent; linked pages are
    not fetched. ``summarizer_address`` / ``summarizer_model`` default to
    ``MARS_SUMMARIZER_ADDRESS`` / ``MARS_SUMMARIZER_MODEL``.

    Returns:
        The summary with ``<think>`` blocks removed, or ``None`` when inputs
        or configuration are missing, the summarizer fails (logged), or it
        returns no text.
    """
    summarizer_addr = summarizer_address or MARS_SUMMARIZER_ADDRESS
    sum_model = summarizer_model or MARS_SUMMARIZER_MODEL
    if not formatted_results or not summarizer_addr or not sum_model:
        return None

    payload = {
        "model": sum_model,
        "messages": [{"role": "user", "content": _build_serper_text_summary_prompt(query, formatted_results)}],
        "max_tokens": max_tokens,
        "temperature": 0.3,
        "chat_template_kwargs": {"enable_thinking": False},
    }
    try:
        async with session.post(
            _summarizer_chat_url(summarizer_addr),
            json=payload,
            headers={"Content-Type": "application/json"},
            timeout=aiohttp.ClientTimeout(total=120),
        ) as resp:
            if resp.status != 200:
                print(f" [SERPER_GATEWAY] Summarizer returned HTTP {resp.status}, falling back to raw results", flush=True)
                return None
            data = await resp.json()
        choices = data.get("choices", [])
        if choices and isinstance(choices, list):
            message = choices[0].get("message") or {}
            # ``content`` can be null; coerce before cleaning instead of
            # raising inside the try and logging a spurious error.
            summary = _clean_think_blocks(message.get("content") or "").strip()
            return summary or None
    except asyncio.TimeoutError:
        print(" [SERPER_GATEWAY] Summarizer timeout, falling back to raw results", flush=True)
    except Exception as e:
        print(f" [SERPER_GATEWAY] Summarizer error: {e}, falling back to raw results", flush=True)
    return None
|
|
|
|
async def call_serper_gateway_search(
    query: str,
    session: "aiohttp.ClientSession",
    kwargs: dict,
    max_results: int = 8,
    timeout_seconds: int = 60,
    search_cache: Optional["SearchCache"] = None,
) -> str:
    """Run a Serper text search through the company gateway, then fetch and summarize.

    The gateway request is retried once (2 attempts, 3 s linear backoff).
    The organic hits then go through ``summarize_serper_organic_results``,
    the same fetch-and-summarize pipeline as ``text_search_tool``. That
    pipeline runs once, outside the retry loop and after the gateway
    response has been fully read. A summarization failure therefore never
    re-issues the gateway search.

    Returns:
        ``"Final summary generated for query: ..."`` on success, a cached
        summary when available, ``"No search results found ..."``, or an
        ``"Error: ..."`` string. This function does not raise for
        gateway/summarizer failures.
    """
    clean_query = query.strip().replace("\n", " ")
    if not clean_query:
        return "Error: Serper gateway search requires a non-empty query."

    max_results = max(1, min(int(max_results or 8), 20))
    summarizer_base_url = kwargs.get("summarizer_base_url", "") or (
        f"http://{MARS_SUMMARIZER_ADDRESS.rstrip('/')}" if MARS_SUMMARIZER_ADDRESS else ""
    )
    summarizer_model = kwargs.get("summarizer_model", "") or MARS_SUMMARIZER_MODEL
    top_k = 3
    content_limit = 30000
    if not summarizer_base_url or not summarizer_model:
        return "Error: Summarizer not configured for Serper gateway search."

    cache_model = (
        f"serper_gateway_fetch:{summarizer_model}:max_results={max_results}:"
        f"top_k={top_k}:content_limit={content_limit}"
    )
    if search_cache:
        cached = await search_cache.get(clean_query, top_k, cache_model)
        if cached:
            return f"Found cached summary for query: {clean_query}\n{cached}"

    gateway_url, gateway_username, gateway_userid, gateway_token = _get_gateway_text_search_config(kwargs)
    if not gateway_url:
        return "Error: GATEWAY_URL/SERPER_GATEWAY_URL not configured."
    if not gateway_token:
        return "Error: GATEWAY_TOKEN/SERPER_GATEWAY_TOKEN not configured."

    headers = {
        "Content-Type": "application/json",
        "User-Agent": "ifbook-http-client",
    }
    serper_params = {
        "q": clean_query,
        "type": "search",
        "num": max_results,
    }
    gateway_payload = {
        "sec_info": {
            "username": gateway_username,
            "userid": gateway_userid,
            "token": gateway_token,
        },
        "model_type": "openai",
        "model_name": "serper",
        "params": json.dumps(serper_params),
    }

    max_attempts = 2
    last_error = ""
    organic_results = None
    for attempt in range(max_attempts):
        try:
            # Read the body and leave the response context before any slow
            # page fetching / summarization happens.
            async with session.post(
                gateway_url,
                headers=headers,
                json=gateway_payload,
                timeout=aiohttp.ClientTimeout(total=timeout_seconds),
            ) as resp:
                status = resp.status
                body = await resp.text()
            if status != 200:
                last_error = f"Error: Serper gateway returned HTTP {status}: {body[:500]}"
                if attempt < max_attempts - 1:
                    print(
                        f" [SERPER_GATEWAY] HTTP {status}, retrying ({attempt + 1}/{max_attempts})...",
                        flush=True,
                    )
                    await asyncio.sleep(3 * (attempt + 1))
                    continue
                return last_error

            data = _parse_gateway_model_output(json.loads(body))
            # ``"organic": null`` is treated like an empty hit list.
            organic_results = data.get("organic") or []
            break
        except asyncio.TimeoutError:
            last_error = f"Error: Serper gateway search timeout for query: {clean_query[:100]}"
        except Exception as e:
            last_error = f"Error: Serper gateway search failed: {e}"

        if attempt < max_attempts - 1:
            print(f" [SERPER_GATEWAY] Error, retrying ({attempt + 1}/{max_attempts}): {last_error}", flush=True)
            await asyncio.sleep(3 * (attempt + 1))

    if organic_results is None:
        return last_error or "Error: Serper gateway search failed."
    if not organic_results:
        return f"No search results found for query: {clean_query}"

    try:
        result = await summarize_serper_organic_results(
            query=clean_query,
            organic_results=organic_results,
            summarizer_base_url=summarizer_base_url,
            summarizer_model=summarizer_model,
            top_k=top_k,
            content_limit=content_limit,
        )
    except Exception as e:
        return f"Error: Serper gateway search failed: {e}"

    if search_cache and not result.startswith("Error:"):
        try:
            await search_cache.set(clean_query, top_k, cache_model, result)
        except Exception as e:
            # A cache write failure must not discard a good result.
            print(f" [SERPER_GATEWAY] Failed to cache result: {e}", flush=True)
    return f"Final summary generated for query: {clean_query}\n{result}"
|
|
|
|
def resolve_web_search_backend(raw_backend: str, tavily_api_keys: list[str]) -> str:
    """Resolve the effective web_search backend name.

    The name is normalized first. ``"auto"`` resolves to ``"tavily"`` when
    any Tavily key is available and to the company ``"serper_gateway"``
    otherwise. Other names are returned as normalized.
    """
    backend = _normalize_web_search_backend(raw_backend)
    if backend != "auto":
        return backend
    return "tavily" if tavily_api_keys else "serper_gateway"
|
|
|
|
async def call_configured_web_search(query: str, kwargs: dict) -> tuple[str, str]:
    """Dispatch *query* to the configured VideoDR/web_search backend.

    The backend comes from ``kwargs["web_search_backend"]``, defaulting to
    ``_default_web_search_backend()``. ``"auto"`` picks Tavily when any key
    is configured. Unknown backends yield an error string instead of raising.

    Returns:
        ``(result_text, backend_name)``.
    """
    session = kwargs.get("shared_session")
    raw_backend = kwargs.get("web_search_backend", _default_web_search_backend())
    single_key = kwargs.get("tavily_api_key", "")
    key_pool = kwargs.get("tavily_api_key_pool")

    if key_pool:
        known_keys = key_pool.keys
    elif single_key:
        known_keys = [single_key]
    else:
        known_keys = []
    backend = resolve_web_search_backend(raw_backend, known_keys)
    cache = kwargs.get("search_cache")

    if backend == "tavily":
        text = await call_tavily_search(
            query=query,
            session=session,
            api_key=single_key,
            api_key_pool=key_pool,
            search_depth=kwargs.get("tavily_search_depth", "advanced"),
            max_results=kwargs.get("tavily_max_results", 8),
            include_answer=kwargs.get("tavily_include_answer", "advanced"),
            include_raw_content=kwargs.get("tavily_include_raw_content", "false"),
            topic=kwargs.get("tavily_topic", "general"),
            auto_parameters=kwargs.get("tavily_auto_parameters", False),
            include_domains=kwargs.get("tavily_include_domains") or None,
            exclude_domains=kwargs.get("tavily_exclude_domains") or None,
            timeout_seconds=kwargs.get("tavily_timeout", 60),
            search_cache=cache,
        )
    elif backend == "serper_gateway":
        text = await call_serper_gateway_search(
            query=query,
            session=session,
            kwargs=kwargs,
            max_results=kwargs.get("serper_gateway_max_results", kwargs.get("tavily_max_results", 8)),
            timeout_seconds=kwargs.get("serper_gateway_timeout", 60),
            search_cache=cache,
        )
    elif backend == "mars":
        text = await vdr_mars_web_search(query, session)
    else:
        text = (
            f"Error: Unknown WEB_SEARCH_BACKEND '{raw_backend}'. "
            "Supported values: serper_gateway, mars, tavily, auto."
        )
    return text, backend
|
|
|
|
| |
| |
| |
|
|
async def evaluate_direct(
    client: str,
    model: str,
    base_url: str,
    api_key: str,
    question: str,
    image_path: str,
    system_prompt: str = "",
    format_instruction: str = "",
    **kwargs,
) -> dict:
    """Answer an image question with a single model call (direct mode, no tools).

    The image is loaded/processed with the ``min_pixels``, ``max_pixels``,
    ``factor`` and ``qwen_vl_processing`` kwargs. It is sent as a base64 data
    URL together with the question, plus ``format_instruction`` on a new line
    when given. The request goes to the client named by ``client``; any
    unrecognized value uses the OpenAI-compatible API with a request debug
    context.

    Returns:
        A dict with the assistant output, the request messages, the
        conversation including the assistant turn, finish reason, tool calls,
        error (if any), and ``num_round == 1``.
    """
    processed = load_and_process_image(
        image_path,
        kwargs.get("min_pixels", 65536),
        kwargs.get("max_pixels", 8294400),
        kwargs.get("factor", 32),
        kwargs.get("qwen_vl_processing", True),
    )
    encoded, mime = image_to_base64(processed)

    user_text = f"{question}\n{format_instruction}" if format_instruction else question
    user_message = {
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": f"data:{mime};base64,{encoded}"}},
            {"type": "text", "text": user_text},
        ],
    }
    messages = [{"role": "system", "content": system_prompt}] if system_prompt else []
    messages.append(user_message)

    dedicated_clients = {
        "gemini": call_gemini_api,
        "azure": call_azure_api,
        "vertex": call_vertex_gemini_api,
        "gateway": call_gateway_api,
    }
    api_fn = dedicated_clients.get(client)
    if api_fn is not None:
        result = await api_fn(messages, model, base_url, api_key, **kwargs)
    else:
        debug_context = {
            "sample_id": kwargs.get("sample_id", ""),
            "turn": 0,
            "task_kind": "image_direct",
        }
        result = await call_openai_api(
            messages, model, base_url, api_key, **{**kwargs, "request_debug_context": debug_context}
        )

    assistant_text = result.get("full_text") or result["content"]
    # JSON round-trip deep-copies the request so appending the assistant turn
    # leaves ``messages`` (returned as ``input_messages``) untouched.
    conversation_messages = json.loads(json.dumps(messages))
    if assistant_text:
        conversation_messages.append({"role": "assistant", "content": assistant_text})

    return {
        "output": assistant_text,
        "input_messages": messages,
        "conversation_messages": conversation_messages,
        "finish_reason": result["finish_reason"],
        "num_round": 1,
        "tool_calls": result.get("tool_calls", []),
        "error": result.get("error"),
        "saved_images": [],
    }
|
|
|
|
async def evaluate_video_direct(
    client: str,
    model: str,
    base_url: str,
    api_key: str,
    question: str,
    video_path: str,
    system_prompt: str = GENERAL_VIDEO_DIRECT_SYSTEM_PROMPT,
    sample_id: str = "",
    images_dir: str = "",
    extra_image_paths: Optional[list[str]] = None,
    **kwargs,
) -> dict:
    """Answer a video question with one model call and no tools.

    The video is decoded by ``extract_video_frames_1fps`` into
    ``<frame_cache_dir or images_dir or ".">/<sample_id>``. The frames chosen
    by ``vdr_uniform_sample_indices`` (``video_initial_frames`` of them) form
    the start of the single user message. They are followed by each
    supplemental image with a short label, and finally by the question.

    Args:
        client: Backend selector: ``"gemini"``, ``"azure"``, ``"vertex"`` or
            ``"gateway"``; any other value goes through ``call_openai_api``.
        model: Model name forwarded to the backend call.
        base_url: Backend endpoint forwarded to the backend call.
        api_key: Credential forwarded to the backend call.
        question: Question text, placed after all images.
        video_path: Path of the video to sample.
        system_prompt: Optional system message; omitted when empty.
        sample_id: Frame sub-directory name and HTML image file-name prefix.
        images_dir: Directory for HTML-viewer copies of every image sent
            (skipped unless both ``images_dir`` and ``sample_id`` are set).
        extra_image_paths: Optional supplemental images the question refers
            to as ``<image k>``.
        **kwargs: Frame-cache, video sampling, image-processing and
            backend-call options.

    Returns:
        A result record with ``output``, ``input_messages``,
        ``conversation_messages``, ``finish_reason`` (``"error"`` on API
        failure), ``num_round`` (always 1), ``tool_calls`` (always empty),
        error diagnostics, ``request_debug`` and ``saved_images``.
    """
    saved_images = []
    html_image_count = 0
    cache_root = kwargs.get("frame_cache_dir", "") or images_dir or "."
    frames_dir = os.path.join(cache_root, sample_id)
    os.makedirs(frames_dir, exist_ok=True)

    # Image preprocessing options shared by every image in the prompt.
    pixel_args = (
        kwargs.get("min_pixels", 65536),
        kwargs.get("max_pixels", 8294400),
        kwargs.get("factor", 32),
        kwargs.get("qwen_vl_processing", True),
    )

    def record_html_image(picture: Image.Image, kind: str) -> str:
        """Give ``picture`` the next ``[IMAGE k]`` marker and save it when HTML output is on."""
        nonlocal html_image_count
        html_image_count += 1
        tag = f"[IMAGE {html_image_count}]"
        if not (images_dir and sample_id):
            return tag
        name = f"{sample_id}_{kind}_{html_image_count}.jpg"
        rgb_picture = picture if picture.mode == "RGB" else picture.convert("RGB")
        rgb_picture.save(os.path.join(images_dir, name), "JPEG", quality=85)
        saved_images.append({"marker": tag, "path": name, "type": kind})
        return tag

    def to_image_part(path: str, kind: str) -> dict:
        """Load and preprocess ``path`` and wrap it as an OpenAI-style data-URL image part."""
        picture = load_and_process_image(path, *pixel_args)
        encoded, media_type = image_to_base64(picture)
        record_html_image(picture, kind)
        return {"type": "image_url", "image_url": {"url": f"data:{media_type};base64,{encoded}"}}

    frame_paths = extract_video_frames_1fps(
        video_path,
        frames_dir,
        kwargs.get("video_max_resolution", DEFAULT_VIDEO_MAX_RESOLUTION),
        kwargs.get("video_jpeg_quality", DEFAULT_VIDEO_JPEG_QUALITY),
    )
    chosen = vdr_uniform_sample_indices(
        len(frame_paths),
        kwargs.get("video_initial_frames", DEFAULT_VIDEO_INITIAL_FRAMES),
    )

    user_content = [to_image_part(frame_paths[pos], "input_frame") for pos in chosen]
    for number, extra_path in enumerate(extra_image_paths or [], start=1):
        extra_part = to_image_part(extra_path, f"supplemental_image_{number}")
        label = (
            f"\nSupplemental image {number} referenced by the question "
            f"as <image {number}>:\n"
        )
        user_content.extend([{"type": "text", "text": label}, extra_part])
    user_content.append({"type": "text", "text": question})

    messages = [{"role": "system", "content": system_prompt}] if system_prompt else []
    messages.append({"role": "user", "content": user_content})

    if client == "gemini":
        reply = await call_gemini_api(messages, model, base_url, api_key, **kwargs)
    elif client == "azure":
        reply = await call_azure_api(messages, model, base_url, api_key, **kwargs)
    elif client == "vertex":
        reply = await call_vertex_gemini_api(messages, model, base_url, api_key, **kwargs)
    elif client == "gateway":
        reply = await call_gateway_api(messages, model, base_url, api_key, **kwargs)
    else:
        debug_kwargs = {
            **kwargs,
            "request_debug_context": {
                "sample_id": sample_id,
                "turn": 0,
                "task_kind": "video_direct",
            },
        }
        reply = await call_openai_api(messages, model, base_url, api_key, **debug_kwargs)

    answer = reply.get("full_text") or reply["content"]
    # JSON round-trip gives a transcript that does not alias `messages`.
    transcript = json.loads(json.dumps(messages))
    if answer:
        transcript.append({"role": "assistant", "content": answer})

    return {
        "output": answer,
        "input_messages": messages,
        "conversation_messages": transcript,
        "finish_reason": "error" if reply.get("error") else reply["finish_reason"],
        "num_round": 1,
        "tool_calls": [],
        "error": reply.get("error"),
        "error_status_code": reply.get("error_status_code"),
        "error_raw_response": reply.get("error_raw_response", ""),
        "request_debug": reply.get("request_debug"),
        "saved_images": saved_images,
    }
|
|
|
|
async def evaluate_tool(
    client: str,
    model: str,
    base_url: str,
    api_key: str,
    question: str,
    image_path: str,
    tool_system_prompt: str,
    max_turns: int = 10,
    serper_api_key: str = "",
    image_search_data: Optional[dict] = None,
    sample_id: str = "",
    images_dir: str = "",
    **kwargs,
) -> dict:
    """Run the multi-turn tool loop for one image-benchmark sample.

    Accepts both the legacy tool names (``image_zoom_in_tool``,
    ``text_search_tool``, ``image_search_tool``) and the VideoDR names
    (``zoom_in``, ``web_search``, ``image_search``). Each model round is either
    a tool call, which is executed and answered with a ``<tool_response>`` user
    message, or a reply without a tool call, which ends the loop. A reply
    without a tool call may instead get a format-repair retry.

    Args:
        client: Backend selector: ``"gemini"``, ``"azure"``, ``"vertex"`` or
            ``"gateway"``; any other value goes through ``call_openai_api``.
        model: Model name; also selects the bbox convention when
            ``bbox_config`` is not supplied.
        base_url: Backend endpoint forwarded to the backend call.
        api_key: Credential forwarded to the backend call.
        question: Question sent next to the input image.
        image_path: Input image; zoom-in and live image-search crops are cut
            from this original file.
        tool_system_prompt: System prompt describing the available tools.
        max_turns: Base round budget; the loop allows at most
            ``max_turns + format_retry_limit`` model rounds.
        serper_api_key: Key used by ``text_search_tool`` and live image search.
        image_search_data: Optional pre-computed results
            (``image_search_title_list`` / ``image_search_thumbnail_list``);
            when given, image_search replays them instead of searching live.
        sample_id: File-name prefix and request debug context.
        images_dir: Directory for HTML-viewer image copies (skipped unless both
            ``images_dir`` and ``sample_id`` are set).
        **kwargs: Image-processing options, retry limits, search settings,
            caches, ``allowed_tool_names`` and backend-call options.

    Returns:
        A result record with the raw transcript (``output``), the initial and
        full message lists, ``finish_reason`` (``"answer"``, ``"no_tool_calls"``,
        ``"max_turns"`` or ``"error"``), ``num_round`` (model rounds actually
        run), ``tool_calls``, ``error`` and ``saved_images``. Error records also
        carry the backend's ``error_status_code`` / ``error_raw_response`` /
        ``request_debug``.
    """
    saved_images = []
    image_counter = 0
    bbox_config = kwargs.get("bbox_config") or vdr_get_bbox_config(model)
    # Treat an empty/None frame_cache_dir as unset (like the video evaluators) so
    # tool crops do not silently land in the current working directory.
    tool_root = kwargs.get("frame_cache_dir") or images_dir or "."
    tool_work_dir = os.path.join(tool_root, f"{sample_id}_tool")
    os.makedirs(tool_work_dir, exist_ok=True)
    image_search_cache = kwargs.get("image_search_cache")
    pixel_args = (
        kwargs.get("min_pixels", 65536),
        kwargs.get("max_pixels", 8294400),
        kwargs.get("factor", 32),
        kwargs.get("qwen_vl_processing", True),
    )
    # Live image-search results containing any of these are failures and are
    # never written to the cache.
    uncacheable_patterns = (
        "Image search error:",
        "Error:",
        "No results found from reverse image search",
    )

    def save_image_for_html(img: Image.Image, img_type: str) -> tuple[str, str]:
        """Assign the next ``[IMAGE k]`` marker; save ``img`` as JPEG when HTML output is on."""
        nonlocal image_counter
        image_counter += 1
        marker = f"[IMAGE {image_counter}]"
        if images_dir and sample_id:
            filename = f"{sample_id}_{img_type}_{image_counter}.jpg"
            if img.mode != "RGB":
                img = img.convert("RGB")
            img.save(os.path.join(images_dir, filename), "JPEG", quality=85)
            saved_images.append({"marker": marker, "path": filename, "type": img_type})
            return marker, filename
        return marker, ""

    def build_multimodal_tool_message(prefix: str, img: Image.Image, img_type: str) -> tuple[str, list[dict]]:
        """Return (text transcript with marker, multimodal message content) for one tool image."""
        marker, _ = save_image_for_html(img, img_type)
        img_b64, img_mime = image_to_base64(img)
        tool_response = f"<tool_response>\n{prefix}{marker}\n</tool_response>"
        content = [
            {"type": "text", "text": f"<tool_response>\n{prefix}"},
            {"type": "image_url", "image_url": {"url": f"data:{img_mime};base64,{img_b64}"}},
            {"type": "text", "text": "\n</tool_response>"},
        ]
        return tool_response, content

    def is_bbox4(value) -> bool:
        """True when ``value`` is a 4-element list/tuple, the only accepted bbox shape."""
        return isinstance(value, (list, tuple)) and len(value) == 4

    # The context manager closes the file handle (multi-frame formats keep it
    # open after load()); convert() returns a fully loaded, independent copy.
    with Image.open(image_path) as source_image:
        original_image = source_image.convert("RGB")

    processed = process_image(original_image, *pixel_args)
    img_b64, mime = image_to_base64(processed)
    save_image_for_html(processed, "input")

    messages = [
        {"role": "system", "content": tool_system_prompt},
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": f"data:{mime};base64,{img_b64}"}},
                {"type": "text", "text": question},
            ],
        },
    ]
    initial_messages = json.loads(json.dumps(messages))

    tool_calls = []
    output_parts = []
    frame_label = kwargs.get("frame_label", 0)
    format_retry_limit = int(kwargs.get("format_retry_limit", 2))
    format_retries = 0
    allowed_tool_names = kwargs.get("allowed_tool_names")
    rounds_used = 0

    for turn in range(max_turns + format_retry_limit):
        rounds_used = turn + 1
        if client == "gemini":
            result = await call_gemini_api(messages, model, base_url, api_key, **kwargs)
        elif client == "azure":
            result = await call_azure_api(messages, model, base_url, api_key, **kwargs)
        elif client == "vertex":
            result = await call_vertex_gemini_api(messages, model, base_url, api_key, **kwargs)
        elif client == "gateway":
            result = await call_gateway_api(messages, model, base_url, api_key, **kwargs)
        else:
            api_kwargs = dict(kwargs)
            api_kwargs["request_debug_context"] = {
                "sample_id": sample_id,
                "turn": turn,
                "task_kind": "image_tool",
            }
            result = await call_openai_api(messages, model, base_url, api_key, **api_kwargs)

        output = result.get("full_text") or result["content"]
        if result.get("error"):
            if output:
                messages.append({"role": "assistant", "content": output})
                output_parts.append(output + "<|im_end|>")
            return {
                "output": "".join(output_parts),
                "input_messages": initial_messages,
                "conversation_messages": json.loads(json.dumps(messages)),
                "finish_reason": "error",
                "num_round": rounds_used,
                "tool_calls": tool_calls,
                "error": result["error"],
                "error_status_code": result.get("error_status_code"),
                "error_raw_response": result.get("error_raw_response", ""),
                "request_debug": result.get("request_debug"),
                "saved_images": saved_images,
                "format_retries": format_retries,
            }

        tool_call = extract_tool_call_from_result(result, output, user_query=question)
        if not tool_call:
            recovered_answer = (
                _recover_answer_from_no_tool_output(result, output)
                if kwargs.get("recover_no_tool_answer", True)
                else None
            )
            if recovered_answer:
                output = recovered_answer
                messages.append({"role": "assistant", "content": output})
                output_parts.append(output + "<|im_end|>")
                return {
                    "output": "".join(output_parts),
                    "input_messages": initial_messages,
                    "conversation_messages": json.loads(json.dumps(messages)),
                    "finish_reason": "answer",
                    "num_round": rounds_used,
                    "tool_calls": tool_calls,
                    "error": None,
                    "saved_images": saved_images,
                    "format_retries": format_retries,
                }
            if format_retries < format_retry_limit:
                assistant_message, user_message, output_snippet = _make_format_repair_turn(output)
                if assistant_message:
                    messages.append(assistant_message)
                messages.append(user_message)
                output_parts.append(output_snippet)
                format_retries += 1
                continue
            if output:
                messages.append({"role": "assistant", "content": output})
                output_parts.append(output + "<|im_end|>")
            return {
                "output": "".join(output_parts),
                "input_messages": initial_messages,
                "conversation_messages": json.loads(json.dumps(messages)),
                "finish_reason": "no_tool_calls",
                "num_round": rounds_used,
                "tool_calls": tool_calls,
                "error": None,
                "saved_images": saved_images,
                "format_retries": format_retries,
            }

        tool_name, args = extract_tool_name_and_args(tool_call)
        assistant_tool_output = sanitize_assistant_tool_turn(output)
        canonical_tool_name = canonicalize_tool_name(tool_name)
        if allowed_tool_names is not None and canonical_tool_name not in allowed_tool_names:
            # Tool removed by the ablation profile: record it and tell the model.
            format_retries = 0
            messages.append({"role": "assistant", "content": assistant_tool_output})
            tool_response = build_disallowed_tool_response(tool_name, allowed_tool_names)
            tool_calls.append({
                "name": canonical_tool_name,
                "raw_name": tool_name,
                "blocked": True,
                "reason": "tool_ablation_profile",
            })
            messages.append({"role": "user", "content": tool_response})
            output_parts.append(
                f"{assistant_tool_output}<|im_end|><|im_start|>user\n{tool_response}<|im_end|>\n<|im_start|>assistant\n"
            )
            continue
        format_retries = 0
        messages.append({"role": "assistant", "content": assistant_tool_output})
        tool_response = ""

        if tool_name in {"image_zoom_in_tool", "zoom_in"}:
            raw_bbox = args.get("bbox_2d") or args.get("bbox")
            if is_bbox4(raw_bbox):
                bbox = vdr_normalize_bbox(raw_bbox, bbox_config)
                crop_path = os.path.join(tool_work_dir, f"zoom_turn{turn:03d}.jpg")
                crop_frame_to_rgb_jpeg(image_path, bbox, crop_path)
                cropped = load_and_process_image(crop_path, *pixel_args)
                tool_calls.append({"name": "zoom_in", "bbox": bbox})
                tool_response, content = build_multimodal_tool_message(
                    (
                        f"Here is the zoomed-in region "
                        f"[x1={bbox[0]:.3f}, y1={bbox[1]:.3f}, "
                        f"x2={bbox[2]:.3f}, y2={bbox[3]:.3f}] of Frame {frame_label}:\n"
                    ),
                    cropped,
                    "zoom",
                )
                messages.append({"role": "user", "content": content})
            else:
                tool_response = "<tool_response>\nError: bbox must have exactly 4 values.\n</tool_response>"
                messages.append({"role": "user", "content": tool_response})

        elif tool_name == "text_search_tool":
            query = args.get("query", "")
            summarizer_base_url = kwargs.get("summarizer_base_url", "")
            summarizer_model = kwargs.get("summarizer_model", "")
            if query and serper_api_key and summarizer_base_url and summarizer_model:
                search_result = await call_text_search(
                    query=query,
                    serper_api_key=serper_api_key,
                    summarizer_base_url=summarizer_base_url,
                    summarizer_model=summarizer_model,
                    serper_semaphore=kwargs.get("serper_semaphore"),
                    serper_concurrency=kwargs.get("serper_concurrency", 5),
                    search_cache=kwargs.get("search_cache"),
                )
                tool_calls.append({"name": tool_name, "query": query})
                tool_response = f"<tool_response>\n{search_result}\n</tool_response>"
            else:
                tool_response = "<tool_response>\nError: Search not available.\n</tool_response>"
            messages.append({"role": "user", "content": tool_response})

        elif tool_name == "web_search":
            query = args.get("query", "")
            if query:
                search_result, search_backend = await call_configured_web_search(query, kwargs)
                tool_calls.append({"name": "web_search", "query": query, "backend": search_backend})
                # Only add the standard header when the backend did not already.
                if search_result.startswith(f'Web Search Results for "{query}"'):
                    tool_response = f"<tool_response>\n{search_result}\n</tool_response>"
                else:
                    tool_response = f'<tool_response>\nWeb Search Results for "{query}":\n\n{search_result}\n</tool_response>'
            else:
                tool_response = "<tool_response>\nError: web_search requires a non-empty query.\n</tool_response>"
            messages.append({"role": "user", "content": tool_response})

        elif tool_name in {"image_search_tool", "image_search"}:
            raw_bbox = args.get("bbox", [0, 0, 1000, 1000])
            tool_calls.append({"name": "image_search", "bbox": raw_bbox})

            if image_search_data:
                # Replay the pre-computed results shipped with the benchmark.
                title_list = image_search_data.get("image_search_title_list", [])
                thumbnail_list = image_search_data.get("image_search_thumbnail_list", [])
                data_root = kwargs.get("data_root", "")

                if title_list and thumbnail_list:
                    header = "<tool_response>\nReverse Image Search Results:"
                    content_parts = [{"type": "text", "text": header}]
                    # The text transcript is built in lockstep with the message so
                    # each title keeps its own thumbnail marker even when an
                    # earlier thumbnail file is missing.
                    response_parts = [header]
                    for i, title in enumerate(title_list):
                        entry = f"\n\nTitle {i+1}: {title}\nThumbnail {i+1}: "
                        content_parts.append({"type": "text", "text": entry})
                        response_parts.append(entry)
                        thumb_path = thumbnail_list[i] if i < len(thumbnail_list) else ""
                        if not thumb_path:
                            continue
                        if data_root and not os.path.isabs(thumb_path):
                            thumb_path = os.path.join(data_root, thumb_path)
                        if not os.path.isfile(thumb_path):
                            continue
                        thumb_img = load_and_process_image(thumb_path, *pixel_args)
                        marker, _ = save_image_for_html(thumb_img, "thumbnail")
                        response_parts.append(marker)
                        thumb_b64, thumb_mime = image_to_base64(thumb_img)
                        content_parts.append({"type": "image_url", "image_url": {"url": f"data:{thumb_mime};base64,{thumb_b64}"}})
                    content_parts.append({"type": "text", "text": "\n</tool_response>"})
                    messages.append({"role": "user", "content": content_parts})
                    response_parts.append("\n</tool_response>")
                    tool_response = "".join(response_parts)
                else:
                    tool_response = "<tool_response>\nNo matching images were found.\n</tool_response>"
                    messages.append({"role": "user", "content": tool_response})
            elif not is_bbox4(raw_bbox):
                tool_response = "<tool_response>\nError: bbox must have exactly 4 values.\n</tool_response>"
                messages.append({"role": "user", "content": tool_response})
            else:
                # Live reverse image search on a padded crop of the original image.
                bbox = vdr_normalize_bbox(raw_bbox, bbox_config)
                padded_bbox = vdr_add_search_padding(bbox, image_path, padding=(0.2, 0.2), padding_cap_px=600)
                crop_path = os.path.join(tool_work_dir, f"image_search_turn{turn:03d}.jpg")
                crop_frame_to_rgb_jpeg(image_path, padded_bbox, crop_path)
                with open(crop_path, "rb") as f:
                    crop_bytes = f.read()
                # `is not None`: an empty cache object may be falsy but still usable.
                cached = image_search_cache.get(crop_bytes) if image_search_cache is not None else None
                if cached:
                    result_text = cached
                else:
                    crop_b64 = base64.b64encode(crop_bytes).decode("ascii")
                    result_text = await vdr_real_image_search(crop_b64, kwargs.get("shared_session"), serper_api_key, crop_path=crop_path)
                    if image_search_cache is not None and not any(pat in result_text for pat in uncacheable_patterns):
                        image_search_cache.set(crop_bytes, result_text)
                tool_response = (
                    f"<tool_response>\nReverse Image Search Results "
                    f"(region [{bbox[0]:.3f}, {bbox[1]:.3f}, {bbox[2]:.3f}, {bbox[3]:.3f}] "
                    f"of Frame {frame_label}):\n\n{result_text}\n</tool_response>"
                )
                messages.append({"role": "user", "content": tool_response})

        else:
            tool_response = f"<tool_response>\nError: Unknown tool '{tool_name}'.\n</tool_response>"
            messages.append({"role": "user", "content": tool_response})

        output_parts.append(f"{assistant_tool_output}<|im_end|><|im_start|>user\n{tool_response}<|im_end|>\n<|im_start|>assistant\n")

    return {
        "output": "".join(output_parts),
        "input_messages": initial_messages,
        "conversation_messages": json.loads(json.dumps(messages)),
        "finish_reason": "max_turns",
        # Rounds actually executed (up to max_turns + format_retry_limit).
        "num_round": rounds_used,
        "tool_calls": tool_calls,
        "error": None,
        "saved_images": saved_images,
        "format_retries": format_retries,
    }
|
|
|
|
| async def evaluate_video_tool( |
| client: str, |
| model: str, |
| base_url: str, |
| api_key: str, |
| question: str, |
| video_path: str, |
| tool_system_prompt: str = VIDEO_DR_SYSTEM_PROMPT, |
| max_turns: int = 10, |
| serper_api_key: str = "", |
| sample_id: str = "", |
| images_dir: str = "", |
| **kwargs, |
| ) -> dict: |
| """VideoDR 工具模式评测:64 帧初始化 + 五工具状态机。""" |
| saved_images = [] |
| image_counter = 0 |
| bbox_config = kwargs.get("bbox_config") or vdr_get_bbox_config(model) |
| frame_cache_dir = kwargs.get("frame_cache_dir", "") |
| sample_frame_dir = os.path.join(frame_cache_dir or images_dir or ".", sample_id) |
| tool_work_dir = os.path.join(sample_frame_dir, "_tool") |
| os.makedirs(tool_work_dir, exist_ok=True) |
| image_search_cache = kwargs.get("image_search_cache") |
|
|
| def save_image_for_html(img: Image.Image, img_type: str) -> str: |
| nonlocal image_counter |
| image_counter += 1 |
| marker = f"[IMAGE {image_counter}]" |
| if images_dir and sample_id: |
| filename = f"{sample_id}_{img_type}_{image_counter}.jpg" |
| filepath = os.path.join(images_dir, filename) |
| if img.mode != "RGB": |
| img = img.convert("RGB") |
| img.save(filepath, "JPEG", quality=85) |
| saved_images.append({"marker": marker, "path": filename, "type": img_type}) |
| return marker |
|
|
| def build_image_part(img_path: str, img_type: str) -> tuple[dict, str]: |
| img = load_and_process_image( |
| img_path, |
| kwargs.get("min_pixels", 65536), |
| kwargs.get("max_pixels", 8294400), |
| kwargs.get("factor", 32), |
| kwargs.get("qwen_vl_processing", True), |
| ) |
| img_b64, img_mime = image_to_base64(img) |
| marker = save_image_for_html(img, img_type) |
| return {"type": "image_url", "image_url": {"url": f"data:{img_mime};base64,{img_b64}"}}, marker |
|
|
| all_frames = extract_video_frames_1fps( |
| video_path, |
| sample_frame_dir, |
| kwargs.get("video_max_resolution", DEFAULT_VIDEO_MAX_RESOLUTION), |
| kwargs.get("video_jpeg_quality", DEFAULT_VIDEO_JPEG_QUALITY), |
| ) |
| initial_indices = vdr_uniform_sample_indices(len(all_frames), kwargs.get("video_initial_frames", DEFAULT_VIDEO_INITIAL_FRAMES)) |
|
|
| user_content = [] |
| for idx in initial_indices: |
| image_part, _ = build_image_part(all_frames[idx], "input_frame") |
| user_content.append(image_part) |
| user_content.append({"type": "text", "text": question}) |
|
|
| messages = [ |
| {"role": "system", "content": tool_system_prompt}, |
| {"role": "user", "content": user_content}, |
| ] |
| initial_messages = json.loads(json.dumps(messages)) |
|
|
| locked_frame_idx = None |
| locked_frame_path = "" |
| tool_calls = [] |
| output_parts = [] |
| format_retry_limit = int(kwargs.get("format_retry_limit", 2)) |
| format_retries = 0 |
| force_final_answer_turn = bool(kwargs.get("force_final_answer_turn", True)) |
| final_answer_retry_limit = int(kwargs.get("final_answer_retry_limit", 1)) |
| final_answer_retries = 0 |
| forced_final_answer = False |
|
|
| for turn in range(max_turns + format_retry_limit): |
| force_final_this_turn = ( |
| force_final_answer_turn |
| and not forced_final_answer |
| and bool(tool_calls) |
| and turn >= max_turns - 1 |
| ) |
| if force_final_this_turn: |
| forced_final_answer = True |
| final_result = await call_final_only_answer( |
| client=client, |
| model=model, |
| base_url=base_url, |
| api_key=api_key, |
| question=question, |
| output_parts=output_parts, |
| tool_calls=tool_calls, |
| sample_id=sample_id, |
| retry_limit=final_answer_retry_limit, |
| **kwargs, |
| ) |
| output_parts.append(final_result["transcript"]) |
| final_api_result = final_result.get("result") or {} |
| return { |
| "output": "".join(output_parts), |
| "input_messages": initial_messages, |
| "conversation_messages": json.loads(json.dumps(messages + final_result.get("messages", []))), |
| "finish_reason": final_result["finish_reason"], |
| "num_round": turn + 1, |
| "tool_calls": tool_calls, |
| "error": final_api_result.get("error"), |
| "error_status_code": final_api_result.get("error_status_code"), |
| "error_raw_response": final_api_result.get("error_raw_response", ""), |
| "request_debug": final_api_result.get("request_debug"), |
| "saved_images": saved_images, |
| "locked_frame_path": locked_frame_path, |
| "format_retries": format_retries, |
| "forced_final_answer": forced_final_answer, |
| "final_answer_retries": final_result.get("retries", 0), |
| "final_only_answer": True, |
| } |
|
|
| if client == "gemini": |
| result = await call_gemini_api(messages, model, base_url, api_key, **kwargs) |
| elif client == "azure": |
| result = await call_azure_api(messages, model, base_url, api_key, **kwargs) |
| elif client == "vertex": |
| result = await call_vertex_gemini_api(messages, model, base_url, api_key, **kwargs) |
| elif client == "gateway": |
| result = await call_gateway_api(messages, model, base_url, api_key, **kwargs) |
| else: |
| api_kwargs = dict(kwargs) |
| api_kwargs["request_debug_context"] = { |
| "sample_id": sample_id, |
| "turn": turn, |
| "task_kind": "video_tool", |
| } |
| result = await call_openai_api(messages, model, base_url, api_key, **api_kwargs) |
|
|
| output = result.get("full_text") or result["content"] |
| if result.get("error"): |
| if output: |
| messages.append({"role": "assistant", "content": output}) |
| output_parts.append(output + "<|im_end|>") |
| return { |
| "output": "".join(output_parts), |
| "input_messages": initial_messages, |
| "conversation_messages": json.loads(json.dumps(messages)), |
| "finish_reason": "error", |
| "num_round": turn + 1, |
| "tool_calls": tool_calls, |
| "error": result["error"], |
| "error_status_code": result.get("error_status_code"), |
| "error_raw_response": result.get("error_raw_response", ""), |
| "request_debug": result.get("request_debug"), |
| "saved_images": saved_images, |
| "locked_frame_path": locked_frame_path, |
| } |
|
|
| tool_call = extract_tool_call_from_result(result, output, user_query=question) |
| if not tool_call: |
| recovered_answer = ( |
| _recover_answer_from_no_tool_output(result, output) |
| if kwargs.get("recover_no_tool_answer", True) |
| else None |
| ) |
| if recovered_answer: |
| if recovered_answer != output: |
| output = recovered_answer |
| messages.append({"role": "assistant", "content": output}) |
| output_parts.append(output + "<|im_end|>") |
| return { |
| "output": "".join(output_parts), |
| "input_messages": initial_messages, |
| "conversation_messages": json.loads(json.dumps(messages)), |
| "finish_reason": "answer", |
| "num_round": turn + 1, |
| "tool_calls": tool_calls, |
| "error": None, |
| "error_status_code": result.get("error_status_code"), |
| "error_raw_response": result.get("error_raw_response", ""), |
| "request_debug": result.get("request_debug"), |
| "saved_images": saved_images, |
| "locked_frame_path": locked_frame_path, |
| "format_retries": format_retries, |
| "forced_final_answer": forced_final_answer, |
| "final_answer_retries": final_answer_retries, |
| } |
| if forced_final_answer and final_answer_retries < final_answer_retry_limit: |
| assistant_message = {"role": "assistant", "content": output} if output else None |
| if assistant_message: |
| messages.append(assistant_message) |
| messages.append({"role": "user", "content": FINAL_ANSWER_REPAIR_PROMPT}) |
| output_parts.append( |
| f"{output}<|im_end|><|im_start|>user\n" |
| f"{FINAL_ANSWER_REPAIR_PROMPT}<|im_end|>\n<|im_start|>assistant\n" |
| ) |
| final_answer_retries += 1 |
| continue |
| if format_retries < format_retry_limit: |
| assistant_message, user_message, output_snippet = _make_format_repair_turn(output) |
| if assistant_message: |
| messages.append(assistant_message) |
| messages.append(user_message) |
| output_parts.append(output_snippet) |
| format_retries += 1 |
| continue |
| if output: |
| messages.append({"role": "assistant", "content": output}) |
| output_parts.append(output + "<|im_end|>") |
| return { |
| "output": "".join(output_parts), |
| "input_messages": initial_messages, |
| "conversation_messages": json.loads(json.dumps(messages)), |
| "finish_reason": "no_tool_calls", |
| "num_round": turn + 1, |
| "tool_calls": tool_calls, |
| "error": None, |
| "error_status_code": result.get("error_status_code"), |
| "error_raw_response": result.get("error_raw_response", ""), |
| "request_debug": result.get("request_debug"), |
| "saved_images": saved_images, |
| "locked_frame_path": locked_frame_path, |
| "format_retries": format_retries, |
| "forced_final_answer": forced_final_answer, |
| "final_answer_retries": final_answer_retries, |
| } |
|
|
| tool_name, args = extract_tool_name_and_args(tool_call) |
| assistant_tool_output = sanitize_assistant_tool_turn(output) |
| allowed_tool_names = kwargs.get("allowed_tool_names") |
| canonical_tool_name = canonicalize_tool_name(tool_name) |
| if allowed_tool_names is not None and canonical_tool_name not in allowed_tool_names: |
| if forced_final_answer: |
| if final_answer_retries < final_answer_retry_limit: |
| messages.append({"role": "assistant", "content": assistant_tool_output}) |
| messages.append({"role": "user", "content": FINAL_ANSWER_REPAIR_PROMPT}) |
| output_parts.append( |
| f"{assistant_tool_output}<|im_end|><|im_start|>user\n" |
| f"{FINAL_ANSWER_REPAIR_PROMPT}<|im_end|>\n<|im_start|>assistant\n" |
| ) |
| final_answer_retries += 1 |
| continue |
| messages.append({"role": "assistant", "content": assistant_tool_output}) |
| output_parts.append(assistant_tool_output + "<|im_end|>") |
| return { |
| "output": "".join(output_parts), |
| "input_messages": initial_messages, |
| "conversation_messages": json.loads(json.dumps(messages)), |
| "finish_reason": "max_turns", |
| "num_round": turn + 1, |
| "tool_calls": tool_calls, |
| "error": None, |
| "error_status_code": result.get("error_status_code"), |
| "error_raw_response": result.get("error_raw_response", ""), |
| "request_debug": result.get("request_debug"), |
| "saved_images": saved_images, |
| "locked_frame_path": locked_frame_path, |
| "format_retries": format_retries, |
| "forced_final_answer": forced_final_answer, |
| "final_answer_retries": final_answer_retries, |
| } |
| format_retries = 0 |
| messages.append({"role": "assistant", "content": assistant_tool_output}) |
| tool_response = build_disallowed_tool_response(tool_name, allowed_tool_names) |
| tool_calls.append({ |
| "name": canonical_tool_name, |
| "raw_name": tool_name, |
| "blocked": True, |
| "reason": "tool_ablation_profile", |
| }) |
| messages.append({"role": "user", "content": tool_response}) |
| output_parts.append( |
| f"{assistant_tool_output}<|im_end|><|im_start|>user\n{tool_response}<|im_end|>\n<|im_start|>assistant\n" |
| ) |
| continue |
| if forced_final_answer: |
| if final_answer_retries < final_answer_retry_limit: |
| messages.append({"role": "assistant", "content": assistant_tool_output}) |
| messages.append({"role": "user", "content": FINAL_ANSWER_REPAIR_PROMPT}) |
| output_parts.append( |
| f"{assistant_tool_output}<|im_end|><|im_start|>user\n" |
| f"{FINAL_ANSWER_REPAIR_PROMPT}<|im_end|>\n<|im_start|>assistant\n" |
| ) |
| final_answer_retries += 1 |
| continue |
| messages.append({"role": "assistant", "content": assistant_tool_output}) |
| output_parts.append(assistant_tool_output + "<|im_end|>") |
| return { |
| "output": "".join(output_parts), |
| "input_messages": initial_messages, |
| "conversation_messages": json.loads(json.dumps(messages)), |
| "finish_reason": "max_turns", |
| "num_round": turn + 1, |
| "tool_calls": tool_calls, |
| "error": None, |
| "error_status_code": result.get("error_status_code"), |
| "error_raw_response": result.get("error_raw_response", ""), |
| "request_debug": result.get("request_debug"), |
| "saved_images": saved_images, |
| "locked_frame_path": locked_frame_path, |
| "format_retries": format_retries, |
| "forced_final_answer": forced_final_answer, |
| "final_answer_retries": final_answer_retries, |
| } |
| format_retries = 0 |
| messages.append({"role": "assistant", "content": assistant_tool_output}) |
| tool_response = "" |
|
|
| if tool_name == "choose_frames": |
| start = int(args.get("start_frame_index", 0)) |
| end = int(args.get("end_frame_index", start)) |
| start = max(0, min(start, len(all_frames) - 1)) |
| end = max(0, min(end, len(all_frames) - 1)) |
| if start > end: |
| start, end = end, start |
| interval_frames = vdr_sample_interval( |
| all_frames, |
| start, |
| end, |
| kwargs.get("video_interval_samples", DEFAULT_VIDEO_INTERVAL_SAMPLES), |
| ) |
| if not interval_frames: |
| interval_frames = [vdr_get_frame(all_frames, (start + end) // 2)] |
|
|
| content_parts = [{ |
| "type": "text", |
| "text": f"<tool_response>\nHere are {len(interval_frames)} uniformly sampled frames from the interval [Frame {start} to Frame {end}]:\n\n", |
| }] |
| markers = [] |
| for idx, frame_path in interval_frames: |
| image_part, marker = build_image_part(frame_path, "interval_frame") |
| markers.append(marker) |
| content_parts.append(image_part) |
| content_parts.append({"type": "text", "text": "\n"}) |
| content_parts.append({"type": "text", "text": "</tool_response>"}) |
| messages.append({"role": "user", "content": content_parts}) |
|
|
| tool_calls.append({"name": "choose_frames", "start_frame_index": start, "end_frame_index": end}) |
| tool_response = ( |
| f"<tool_response>\nHere are {len(interval_frames)} uniformly sampled frames from the interval " |
| f"[Frame {start} to Frame {end}]:\n\n" + "\n".join(markers) + "\n</tool_response>" |
| ) |
|
|
| elif tool_name == "find_frame": |
| requested_idx = int(args.get("frame_index", 0)) |
| locked_frame_idx, locked_frame_path = vdr_get_frame(all_frames, requested_idx) |
| image_part, marker = build_image_part(locked_frame_path, "locked_frame") |
| messages.append({ |
| "role": "user", |
| "content": [ |
| {"type": "text", "text": f"<tool_response>\nHere is Frame {locked_frame_idx}:\n"}, |
| image_part, |
| {"type": "text", "text": "\n</tool_response>"}, |
| ], |
| }) |
| tool_calls.append({"name": "find_frame", "frame_index": requested_idx, "actual_frame_index": locked_frame_idx}) |
| tool_response = f"<tool_response>\nHere is Frame {locked_frame_idx}:\n{marker}\n</tool_response>" |
|
|
| elif tool_name == "zoom_in": |
| if not locked_frame_path: |
| tool_response = "<tool_response>\nError: zoom_in can only be used after find_frame.\n</tool_response>" |
| messages.append({"role": "user", "content": tool_response}) |
| else: |
| raw_bbox = args.get("bbox") |
| if not raw_bbox or len(raw_bbox) != 4: |
| tool_response = "<tool_response>\nError: bbox must have exactly 4 values.\n</tool_response>" |
| messages.append({"role": "user", "content": tool_response}) |
| else: |
| bbox = vdr_normalize_bbox(raw_bbox, bbox_config) |
| crop_path = os.path.join(tool_work_dir, f"zoom_turn{turn:03d}.jpg") |
| crop_frame_to_rgb_jpeg(locked_frame_path, bbox, crop_path) |
| cropped = load_and_process_image( |
| crop_path, |
| kwargs.get("min_pixels", 65536), |
| kwargs.get("max_pixels", 8294400), |
| kwargs.get("factor", 32), |
| kwargs.get("qwen_vl_processing", True), |
| ) |
| crop_b64, crop_mime = image_to_base64(cropped) |
| marker = save_image_for_html(cropped, "zoom") |
| messages.append({ |
| "role": "user", |
| "content": [ |
| { |
| "type": "text", |
| "text": ( |
| f"<tool_response>\nHere is the zoomed-in region " |
| f"[x1={bbox[0]:.3f}, y1={bbox[1]:.3f}, x2={bbox[2]:.3f}, y2={bbox[3]:.3f}] " |
| f"of Frame {locked_frame_idx}:\n" |
| ), |
| }, |
| {"type": "image_url", "image_url": {"url": f"data:{crop_mime};base64,{crop_b64}"}}, |
| {"type": "text", "text": "\n</tool_response>"}, |
| ], |
| }) |
| tool_calls.append({"name": "zoom_in", "bbox": bbox, "frame_index": locked_frame_idx}) |
| tool_response = ( |
| f"<tool_response>\nHere is the zoomed-in region " |
| f"[x1={bbox[0]:.3f}, y1={bbox[1]:.3f}, x2={bbox[2]:.3f}, y2={bbox[3]:.3f}] " |
| f"of Frame {locked_frame_idx}:\n{marker}\n</tool_response>" |
| ) |
|
|
| elif tool_name == "image_search": |
| if not locked_frame_path: |
| tool_response = "<tool_response>\nError: image_search can only be used after find_frame.\n</tool_response>" |
| messages.append({"role": "user", "content": tool_response}) |
| else: |
| raw_bbox = args.get("bbox") |
| if not raw_bbox or len(raw_bbox) != 4: |
| tool_response = "<tool_response>\nError: bbox must have exactly 4 values.\n</tool_response>" |
| messages.append({"role": "user", "content": tool_response}) |
| else: |
| bbox = vdr_normalize_bbox(raw_bbox, bbox_config) |
| padded_bbox = vdr_add_search_padding(bbox, locked_frame_path, padding=(0.2, 0.2), padding_cap_px=600) |
| crop_path = os.path.join(tool_work_dir, f"image_search_turn{turn:03d}.jpg") |
| crop_frame_to_rgb_jpeg(locked_frame_path, padded_bbox, crop_path) |
| with open(crop_path, "rb") as f: |
| crop_bytes = f.read() |
| cached = image_search_cache.get(crop_bytes) if image_search_cache else None |
| if cached: |
| result_text = cached |
| else: |
| crop_b64 = base64.b64encode(crop_bytes).decode("ascii") |
| result_text = await vdr_real_image_search(crop_b64, kwargs.get("shared_session"), serper_api_key, crop_path=crop_path) |
| if image_search_cache and not any(pat in result_text for pat in ("Image search error:", "Error:", "No results found from reverse image search")): |
| image_search_cache.set(crop_bytes, result_text) |
| tool_calls.append({"name": "image_search", "bbox": bbox, "frame_index": locked_frame_idx}) |
| tool_response = ( |
| f"<tool_response>\nReverse Image Search Results " |
| f"(region [{bbox[0]:.3f}, {bbox[1]:.3f}, {bbox[2]:.3f}, {bbox[3]:.3f}] " |
| f"of Frame {locked_frame_idx}):\n\n{result_text}\n</tool_response>" |
| ) |
| messages.append({"role": "user", "content": tool_response}) |
|
|
| elif tool_name == "web_search": |
| query = args.get("query", "") |
| if not query: |
| tool_response = "<tool_response>\nError: web_search requires a non-empty query.\n</tool_response>" |
| else: |
| result_text, search_backend = await call_configured_web_search(query, kwargs) |
| if result_text.startswith(f'Web Search Results for "{query}"'): |
| tool_response = f"<tool_response>\n{result_text}\n</tool_response>" |
| else: |
| tool_response = f'<tool_response>\nWeb Search Results for "{query}":\n\n{result_text}\n</tool_response>' |
| tool_calls.append({ |
| "name": "web_search", |
| "query": query, |
| "frame_index": locked_frame_idx, |
| "backend": search_backend, |
| }) |
| messages.append({"role": "user", "content": tool_response}) |
|
|
| else: |
| tool_response = f"<tool_response>\nError: Unknown tool '{tool_name}'.\n</tool_response>" |
| messages.append({"role": "user", "content": tool_response}) |
|
|
| output_parts.append(f"{assistant_tool_output}<|im_end|><|im_start|>user\n{tool_response}<|im_end|>\n<|im_start|>assistant\n") |
|
|
| return { |
| "output": "".join(output_parts), |
| "input_messages": initial_messages, |
| "conversation_messages": json.loads(json.dumps(messages)), |
| "finish_reason": "max_turns", |
| "num_round": max_turns, |
| "tool_calls": tool_calls, |
| "error": None, |
| "error_status_code": None, |
| "error_raw_response": "", |
| "request_debug": result.get("request_debug") if "result" in locals() else None, |
| "saved_images": saved_images, |
| "locked_frame_path": locked_frame_path, |
| "format_retries": format_retries, |
| "forced_final_answer": forced_final_answer, |
| "final_answer_retries": final_answer_retries, |
| } |
|
|
|
|
| |
| |
| |
|
|
def messages_to_input_string(
    messages: list,
    saved_images: Optional[list[dict]] = None,
    append_assistant_prompt: bool = True,
) -> str:
    """Render chat messages as a ChatML-style transcript, like the rollout data.

    Each message becomes ``<|im_start|>{role}\\n{content}<|im_end|>`` and messages are
    joined with newlines. For list-valued content, text items are joined with newlines
    and every ``image_url`` item is replaced by a marker: the ``marker`` values from
    ``saved_images`` are consumed in order, and once they run out a positional
    ``[IMAGE N]`` placeholder is used (N = 1-based index of the image across the whole
    transcript).

    Args:
        messages: OpenAI-style messages; ``content`` may be a str, a list of parts, or None.
        saved_images: Optional saved-image records; only entries with a truthy
            ``marker`` contribute markers.
        append_assistant_prompt: When True, append an open assistant turn so the
            string matches the model's input prompt.

    Returns:
        The transcript string.
    """
    markers = [img["marker"] for img in (saved_images or []) if img.get("marker")]
    image_count = 0

    def next_marker() -> str:
        # Hand out saved markers in order, then fall back to positional placeholders.
        nonlocal image_count
        image_count += 1
        if image_count <= len(markers):
            return markers[image_count - 1]
        return f"[IMAGE {image_count}]"

    parts = []
    for msg in messages:
        role = msg.get("role", "")
        content = msg.get("content", "")
        if content is None:
            # OpenAI-style messages may carry `content: None`; render it as empty
            # instead of failing on iteration.
            content = ""

        if isinstance(content, str):
            body = content
        else:
            pieces = []
            for item in content:
                if isinstance(item, str):
                    pieces.append(item)
                elif item.get("type") == "text":
                    pieces.append(item["text"])
                elif item.get("type") == "image_url":
                    pieces.append(next_marker())
            body = "\n".join(pieces)
        parts.append(f"<|im_start|>{role}\n{body}<|im_end|>")

    transcript = "\n".join(parts)
    if append_assistant_prompt:
        return transcript + "\n<|im_start|>assistant\n"
    return transcript
|
|
|
|
async def evaluate_sample(
    sample: dict,
    client: str,
    model: str,
    base_url: str,
    api_key: str,
    mode: str,
    dataset_configs: dict,
    semaphore: asyncio.Semaphore,
    **kwargs,
) -> dict:
    """Run the model on one sample, score it, and return a result row.

    Dispatch depends on ``mode`` ("direct" vs. tool mode) and the sample's task kind
    ("video_dr" vs. image, taken from the sample or the dataset config). Scores are
    computed for each method in the dataset config's ``score_methods``:
    ``em_score_mcq`` -> MCQ exact match, ``llm_score`` -> LLM judge, anything else -> None.

    The ``semaphore`` is held for the whole model call plus scoring.

    Returns:
        A dict with input/trajectory transcripts, model output, finish metadata and one
        key per score method. On a non-fatal exception a reduced error row is returned
        in which every score method is 0.0.

    Raises:
        FatalAPIError: re-raised unchanged so the caller can abort the run.
    """
    dataset_name = sample.get("dataset", "")
    ds_config = dataset_configs.get(dataset_name, {})
    score_methods = ds_config.get("score_methods", [])
    task_kind = sample.get("task_kind") or ds_config.get("task_kind", "image")

    # Prompt pieces configured per dataset.
    system_prompt = ds_config.get("system_prompt", "")
    format_instruction = ds_config.get("format_instruction", "")

    async with semaphore:
        try:
            if mode == "direct":
                if task_kind == "video_dr":
                    result = await evaluate_video_direct(
                        client,
                        model,
                        base_url,
                        api_key,
                        sample["question"],
                        sample["video_path"],
                        system_prompt=kwargs.get(
                            "general_video_direct_system_prompt",
                            GENERAL_VIDEO_DIRECT_SYSTEM_PROMPT,
                        ),
                        sample_id=sample["id"],
                        extra_image_paths=sample.get("extra_image_paths") or [],
                        **kwargs,
                    )
                else:
                    result = await evaluate_direct(
                        client, model, base_url, api_key,
                        sample["question"], sample["image_path"],
                        system_prompt=system_prompt,
                        format_instruction=format_instruction,
                        **kwargs
                    )
            else:
                if task_kind == "video_dr":
                    result = await evaluate_video_tool(
                        client,
                        model,
                        base_url,
                        api_key,
                        sample["question"],
                        sample["video_path"],
                        tool_system_prompt=system_prompt or kwargs.get("video_dr_system_prompt", VIDEO_DR_SYSTEM_PROMPT),
                        sample_id=sample["id"],
                        **kwargs,
                    )
                else:
                    tools_section = kwargs.get("tools_section", "")
                    if not tools_section:
                        raise ValueError("tools_section required for image tool mode (use --tool-config)")
                    # Task instruction = dataset system prompt + required answer format.
                    task_instruction_parts = []
                    if system_prompt:
                        task_instruction_parts.append(system_prompt.strip())
                    if format_instruction:
                        task_instruction_parts.append(
                            "# Required Answer Format\n"
                            + format_instruction.strip()
                        )
                    task_instruction = "\n\n".join(part for part in task_instruction_parts if part)
                    if not task_instruction:
                        raise ValueError(
                            f"system_prompt or format_instruction required in dataset config for tool mode "
                            f"(dataset: {dataset_name})"
                        )
                    tool_system_prompt = build_image_tool_system_prompt(
                        task_instruction=task_instruction,
                        tools_section=tools_section,
                        max_turns=kwargs.get("max_turns", 10),
                        allowed_tool_names=kwargs.get("allowed_tool_names"),
                    )
                    result = await evaluate_tool(
                        client, model, base_url, api_key,
                        sample["question"], sample["image_path"],
                        tool_system_prompt=tool_system_prompt,
                        image_search_data=sample.get("image_search_data"),
                        data_root=sample.get("data_root", ""),
                        sample_id=sample["id"],
                        **kwargs,
                    )

            # Score the final output with every configured method.
            scores = {}
            for method in score_methods:
                if method == "em_score_mcq":
                    extracted = extract_mcq_answer(result["output"])
                    em_correct = check_answer(extracted, sample["answer"])
                    scores[method] = 1.0 if em_correct else 0.0
                elif method == "llm_score":
                    judge_client = kwargs.get("judge_client", "azure")
                    judge_base_url = kwargs.get("judge_base_url", "")
                    judge_api_key = kwargs.get("judge_api_key", "")
                    judge_temperature = kwargs.get("judge_temperature", 0.0)
                    scores[method] = await llm_judge_score(
                        sample["question"], result["output"], sample["answer"],
                        # Judge image preference: explicit judge image, then the frame
                        # the video run locked on, then the sample image.
                        sample.get("judge_image_path") or result.get("locked_frame_path", "") or sample.get("image_path", ""),
                        judge_client, judge_base_url, judge_api_key, judge_temperature,
                        shared_session=kwargs.get("shared_session"),
                    )
                else:
                    # Unknown score method: recorded but not scored.
                    scores[method] = None

            # Transcripts: the prompt (with an open assistant turn) and the full conversation.
            input_str = messages_to_input_string(result["input_messages"], result.get("saved_images", []))
            trajectory_str = messages_to_input_string(
                result.get("conversation_messages", result["input_messages"]),
                result.get("saved_images", []),
                append_assistant_prompt=False,
            )

            result_dict = {
                "sample_id": sample["id"],
                "dataset": dataset_name,
                "input": input_str,
                "output": result["output"],
                "trajectory": trajectory_str,
                "gts": sample["answer"],
                "finish_reason": result["finish_reason"],
                "num_round": result["num_round"],
                "tool_calls": result.get("tool_calls", []),
                "error": result.get("error"),
                "error_status_code": result.get("error_status_code"),
                "error_raw_response": result.get("error_raw_response", ""),
                "request_debug": result.get("request_debug"),
                "format_retries": result.get("format_retries", 0),
                "forced_final_answer": result.get("forced_final_answer", False),
                "final_answer_retries": result.get("final_answer_retries", 0),
                "final_only_answer": result.get("final_only_answer", False),
                "eval_compat_profile": kwargs.get("eval_compat_profile", "current"),
                "saved_images": result.get("saved_images", []),
            }
            for key in ("category", "difficulty", "task_kind"):
                if key in sample:
                    result_dict[key] = sample[key]
            if result.get("locked_frame_path"):
                result_dict["locked_frame_path"] = result["locked_frame_path"]

            result_dict.update(scores)
            return result_dict

        except FatalAPIError:
            raise
        except Exception as e:
            # Some exceptions (e.g. asyncio.TimeoutError()) stringify to "". An empty
            # "error" is falsy, so the row would not be counted as an error by callers
            # that test it for truthiness; fall back to the exception class name.
            error_text = str(e) or type(e).__name__
            print(f"[ERROR] {sample['id']}: {error_text}")
            error_dict = {
                "sample_id": sample["id"],
                "dataset": dataset_name,
                "input": "",
                "output": "",
                "gts": sample["answer"],
                "finish_reason": "error",
                "num_round": 0,
                "error": error_text,
            }
            for key in ("category", "difficulty", "task_kind"):
                if key in sample:
                    error_dict[key] = sample[key]
            # Failed samples count as incorrect for every configured metric.
            for method in score_methods:
                error_dict[method] = 0.0
            return error_dict
|
|
|
|
def _new_dataset_counters() -> dict:
    """Return zeroed per-dataset counters used by run_evaluation."""
    return {"total": 0, "errors": 0, "em_correct": 0, "em_total": 0, "llm_correct": 0, "llm_total": 0}


def _tally_result_into_counters(dataset_stats: dict, result: dict) -> None:
    """Add one result row to its dataset's counters; rows for untracked datasets are ignored.

    A score strictly greater than 0.5 counts as correct; a missing or None score is not counted.
    """
    st = dataset_stats.get(result.get("dataset", ""))
    if st is None:
        return
    st["total"] += 1
    if result.get("error"):
        st["errors"] += 1
    if result.get("em_score_mcq") is not None:
        st["em_total"] += 1
        if result["em_score_mcq"] > 0.5:
            st["em_correct"] += 1
    if result.get("llm_score") is not None:
        st["llm_total"] += 1
        if result["llm_score"] > 0.5:
            st["llm_correct"] += 1


def _summarize_dataset_counters(st: dict) -> dict:
    """Turn raw counters into a summary dict; accuracy keys appear only for scored metrics."""
    ds_stats = {"total": st["total"], "errors": st["errors"]}
    if st["em_total"] > 0:
        ds_stats["em_correct"] = st["em_correct"]
        ds_stats["em_total"] = st["em_total"]
        ds_stats["em_accuracy"] = st["em_correct"] / st["em_total"]
    if st["llm_total"] > 0:
        ds_stats["llm_correct"] = st["llm_correct"]
        ds_stats["llm_total"] = st["llm_total"]
        ds_stats["llm_accuracy"] = st["llm_correct"] / st["llm_total"]
    return ds_stats


async def run_evaluation(
    samples: list[dict],
    dataset_configs: dict,
    client: str,
    model: str,
    base_url: str,
    api_key: str,
    mode: str,
    max_concurrent: int = 4,
    output_dir: str = "",
    **kwargs,
) -> dict:
    """Evaluate all samples concurrently and return per-dataset metrics.

    Resume behavior: if ``output_dir/results.jsonl`` exists, its rows are loaded,
    samples whose ``id`` already appears there are skipped, and new rows are appended.
    Blank or malformed lines in that file are skipped with a warning.

    Per finished sample, a row (without ``saved_images``) is appended to results.jsonl
    and flushed immediately. At the end an HTML viewer covering old and new rows is
    written, and per-dataset results are printed.

    Args:
        samples: Sample dicts (must have ``id``; ``dataset`` selects the config).
        dataset_configs: Mapping of dataset name -> config dict used by evaluate_sample.
        client, model, base_url, api_key, mode: Forwarded to evaluate_sample.
        max_concurrent: Maximum number of samples evaluated at once.
        output_dir: Output directory; when empty nothing is written to disk.
        **kwargs: Extra options forwarded to evaluate_sample. ``serper_concurrency``
            (default 5) sizes a shared semaphore exposed as ``serper_semaphore``.

    Returns:
        Mapping of dataset name -> summary dict (``total``, ``errors`` and, when
        available, ``em_*`` / ``llm_*`` counts and accuracy).

    Raises:
        FatalAPIError: propagated from evaluate_sample; remaining samples are cancelled.
    """
    datasets = {s.get("dataset", "") for s in samples}

    # Resume support: load rows already written and skip their samples.
    completed_ids = set()
    existing_results = []
    results_file = os.path.join(output_dir, "results.jsonl") if output_dir else ""
    # A crash mid-write can leave an unterminated last line; appending right after it
    # would glue the next record onto it, so remember whether the file ended cleanly.
    last_line_terminated = True
    if results_file and os.path.exists(results_file):
        with open(results_file, encoding="utf-8") as f:
            for line_no, raw_line in enumerate(f, 1):
                last_line_terminated = raw_line.endswith("\n")
                if not raw_line.strip():
                    continue
                try:
                    r = json.loads(raw_line)
                except json.JSONDecodeError:
                    print(f"[WARN] Skipping malformed line {line_no} in {results_file}", flush=True)
                    continue
                existing_results.append(r)
                completed_ids.add(r.get("sample_id"))
        print(f"Resuming: loaded {len(existing_results)} completed results, skipping {len(completed_ids)} samples")
        samples = [s for s in samples if s["id"] not in completed_ids]

    print(f"Evaluating {len(samples)} samples across {len(datasets)} datasets")
    print(f"Client: {client}, Model: {model}")
    print(f"Mode: {mode}")
    print()

    # Running per-dataset counters, seeded with the resumed rows.
    dataset_stats = {ds: _new_dataset_counters() for ds in datasets}
    for r in existing_results:
        _tally_result_into_counters(dataset_stats, r)

    if not samples:
        print("All samples already completed!")
        # Still refresh the HTML viewer from the existing rows.
        if output_dir and existing_results:
            html_path = os.path.join(output_dir, "results.html")
            generate_html(existing_results, html_path, images_subdir="images")
            print(f"Generated HTML viewer: {html_path}", flush=True)
        await _cleanup_browser()
        return {name: _summarize_dataset_counters(dataset_stats[name]) for name in sorted(datasets)}

    semaphore = asyncio.Semaphore(max_concurrent)

    # One Serper semaphore shared by all samples.
    serper_concurrency = kwargs.pop("serper_concurrency", 5)
    kwargs["serper_semaphore"] = asyncio.Semaphore(serper_concurrency)
    kwargs["serper_concurrency"] = serper_concurrency

    # Output directories and the on-disk image-search cache.
    images_dir = ""
    frame_cache_dir = ""
    image_search_cache = None
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
        images_dir = os.path.join(output_dir, "images")
        frame_cache_dir = os.path.join(output_dir, "frame_cache")
        os.makedirs(images_dir, exist_ok=True)
        os.makedirs(frame_cache_dir, exist_ok=True)
        image_search_cache = ImageSearchCache(os.path.join(output_dir, "image_search_cache.json"))
        image_seed_paths = kwargs.get("image_search_cache_seed_paths") or []
        if image_seed_paths:
            seed_image_search_cache(image_search_cache, image_seed_paths)
    kwargs["images_dir"] = images_dir
    kwargs["frame_cache_dir"] = frame_cache_dir
    kwargs["image_search_cache"] = image_search_cache

    start_time = time.time()
    num_existing = len(existing_results)
    last_log_time = time.time()

    results_file_handle = None
    all_results_for_html = []

    try:
        shared_timeout = aiohttp.ClientTimeout(total=3600)
        async with create_http_session(shared_timeout) as shared_session:
            kwargs["shared_session"] = shared_session
            # Explicit tasks (instead of bare coroutines) so unfinished work can be
            # cancelled if the loop below exits early.
            tasks = [
                asyncio.create_task(
                    evaluate_sample(s, client, model, base_url, api_key, mode, dataset_configs, semaphore, **kwargs)
                )
                for s in samples
            ]
            try:
                if output_dir:
                    write_mode = "a" if existing_results else "w"
                    # UTF-8 explicitly: rows are dumped with ensure_ascii=False.
                    results_file_handle = open(results_file, write_mode, encoding="utf-8")
                    if write_mode == "a" and not last_line_terminated:
                        results_file_handle.write("\n")

                for i, fut in enumerate(asyncio.as_completed(tasks)):
                    try:
                        result = await fut
                    except FatalAPIError as e:
                        print(f"\n{'='*70}", flush=True)
                        print(f"FATAL ERROR: {e}", flush=True)
                        print(f"{'='*70}\n", flush=True)
                        raise

                    _tally_result_into_counters(dataset_stats, result)

                    # Persist immediately so an interrupted run can resume.
                    if results_file_handle:
                        jsonl_skip_keys = {"saved_images"}
                        jsonl_result = {k: v for k, v in result.items() if k not in jsonl_skip_keys}
                        results_file_handle.write(json.dumps(jsonl_result, ensure_ascii=False) + "\n")
                        results_file_handle.flush()

                    # Full rows (including saved_images) are kept for the HTML viewer.
                    all_results_for_html.append(result)

                    now = time.time()
                    total_done = num_existing + i + 1
                    if (i + 1) % 10 == 0 or (i + 1) == len(tasks) or (now - last_log_time > 30):
                        elapsed = now - start_time
                        rate = (i + 1) / elapsed if elapsed > 0 else 0
                        eta = (len(tasks) - i - 1) / rate if rate > 0 else 0

                        ds_stats_strs = []
                        for ds_name in sorted(datasets):
                            st = dataset_stats[ds_name]
                            if st["em_total"] > 0:
                                ds_acc = 100 * st["em_correct"] / st["em_total"]
                                ds_stats_strs.append(f"{ds_name}(em): {st['em_correct']}/{st['em_total']} ({ds_acc:.1f}%)")
                            if st["llm_total"] > 0:
                                ds_acc = 100 * st["llm_correct"] / st["llm_total"]
                                ds_stats_strs.append(f"{ds_name}(llm): {st['llm_correct']}/{st['llm_total']} ({ds_acc:.1f}%)")

                        cache_info = ""
                        search_cache = kwargs.get("search_cache")
                        if search_cache:
                            cs = search_cache.get_stats()
                            cache_info = f" | Cache: {cs['hits']}/{cs['total']} ({cs['hit_rate']:.0f}%)"

                        web_stats = _web_fetch_stats.format_progress()
                        web_info = f" | {web_stats}" if web_stats else ""

                        print(
                            f"Progress: {total_done}/{num_existing + len(tasks)} | "
                            f"Elapsed: {elapsed:.0f}s | {rate:.1f} samples/s | ETA: {eta:.0f}s"
                            f"{cache_info}{web_info}",
                            flush=True,
                        )
                        if ds_stats_strs:
                            print(f" {' | '.join(ds_stats_strs)}", flush=True)
                        last_log_time = now
            finally:
                # On early exit (e.g. FatalAPIError) cancel in-flight samples before the
                # shared HTTP session closes, so they do not run against a closed session.
                pending = [t for t in tasks if not t.done()]
                for t in pending:
                    t.cancel()
                if pending:
                    await asyncio.gather(*pending, return_exceptions=True)
    finally:
        if results_file_handle:
            results_file_handle.close()
        if image_search_cache is not None:
            image_search_cache.save()
        await _cleanup_browser()

    # HTML viewer over resumed + new rows.
    if output_dir and (all_results_for_html or existing_results):
        html_results = existing_results + all_results_for_html
        html_path = os.path.join(output_dir, "results.html")
        generate_html(html_results, html_path, images_subdir="images")
        print(f"Generated HTML viewer: {html_path}", flush=True)

    elapsed = time.time() - start_time

    print(flush=True)
    print("=" * 70, flush=True)
    print("Results by Dataset:", flush=True)
    print("=" * 70, flush=True)

    dataset_results = {}
    for dataset_name in sorted(datasets):
        st = dataset_stats[dataset_name]
        ds_stats = _summarize_dataset_counters(st)
        dataset_results[dataset_name] = ds_stats

        print(f"\n{dataset_name} ({st['total']} samples):", flush=True)
        if "em_accuracy" in ds_stats:
            print(f" em_score_mcq: {ds_stats['em_accuracy']:.4f} ({ds_stats['em_correct']}/{ds_stats['em_total']})", flush=True)
        if "llm_accuracy" in ds_stats:
            print(f" llm_score_allow_no_answer: {ds_stats['llm_accuracy']:.4f} ({ds_stats['llm_correct']}/{ds_stats['llm_total']})", flush=True)
        if ds_stats["errors"] > 0:
            print(f" errors: {ds_stats['errors']}", flush=True)

    print(flush=True)
    # `elapsed` only covers this run, so divide by the samples processed in this run
    # (resumed rows would otherwise understate the per-sample time).
    processed_this_run = len(all_results_for_html)
    if processed_this_run > 0:
        print(f"Time: {elapsed:.1f}s ({elapsed/processed_this_run:.2f}s per sample)", flush=True)
    else:
        print(f"Time: {elapsed:.1f}s (no samples processed)", flush=True)

    ws = _web_fetch_stats.get_stats()
    if ws["total"] > 0:
        print(f"\nWeb Fetch Stats: {ws['successful']}/{ws['total']} ({ws['success_rate']:.1f}%) successful", flush=True)
        if ws["failed"] > 0:
            print(f" Failed: {ws['failed']}", flush=True)
        if ws["errors_by_code"]:
            code_str = ", ".join(f"HTTP {code}: {cnt}" for code, cnt in sorted(ws["errors_by_code"].items()))
            print(f" Error codes: {code_str}", flush=True)
        if ws["skipped"] > 0:
            print(f" Skipped (non-HTML): {ws['skipped']}", flush=True)

    print("=" * 70, flush=True)

    return dataset_results
|
|
|
|
def generate_html(results: list[dict], output_path: str, images_subdir: str = "images"):
    """Generate a self-contained HTML viewer for evaluation results.

    Each result becomes a collapsible card with scalar metadata (ground truth, scores,
    short strings), the input transcript and the model output. Image markers listed in
    a result's ``saved_images`` (e.g. ``[IMAGE 1]``) are replaced by ``<img>`` tags
    pointing at ``images_subdir/<path>``, and chat control tags (``<think>``,
    ``<thinking>``, ``<tool_call>``, ``<answer>``, ``<tool_response>``) are highlighted.

    Args:
        results: List of result dicts; each may contain ``saved_images`` entries with
            ``marker``, ``path`` and ``type`` keys.
        output_path: Path of the HTML file to write (UTF-8).
        images_subdir: Directory holding the saved images, relative to the HTML file.
    """
    import html as html_lib

    # The transcript is HTML-escaped before highlighting, so the chat tags must be
    # matched, and re-emitted, in escaped form (e.g. "&lt;think&gt;").
    tag_styles = []
    for tag_name, css_class in (
        ("think", "tag-think"),
        ("thinking", "tag-think"),
        ("tool_call", "tag-tool"),
        ("answer", "tag-answer"),
        ("tool_response", "tag-response"),
    ):
        for raw_tag in (f"<{tag_name}>", f"</{tag_name}>"):
            escaped_tag = html_lib.escape(raw_tag)
            tag_styles.append((escaped_tag, f'<span class="{css_class}">{escaped_tag}</span>'))

    def highlight_tags_with_images(text: str, saved_images: list[dict]) -> str:
        """Escape `text`, inline saved images at their markers, and highlight chat tags."""
        text = html_lib.escape(text)
        for img_info in saved_images:
            marker = img_info.get("marker", "")
            path = img_info.get("path", "")
            if marker and path:
                src = html_lib.escape(f"{images_subdir}/{path}", quote=True)
                title = html_lib.escape(str(img_info.get("type", "")), quote=True)
                img_tag = f'<br><img src="{src}" class="inline-image" title="{title}"><br>'
                text = text.replace(html_lib.escape(marker), img_tag)
        text = text.replace("\n", "<br>")
        for pattern, replacement in tag_styles:
            text = text.replace(pattern, replacement)
        return text

    html_parts = [
        '''<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Eval Results Viewer</title>
    <style>
        :root {
            --bg-primary: #1e1e2e; --bg-secondary: #282839; --bg-tertiary: #313244;
            --text-primary: #cdd6f4; --text-secondary: #a6adc8;
            --accent: #cba6f7; --green: #a6e3a1; --red: #f38ba8; --yellow: #f9e2af; --blue: #89b4fa;
            --border: #45475a;
        }
        * { box-sizing: border-box; margin: 0; padding: 0; }
        body { font-family: 'SF Mono', Consolas, monospace; background: var(--bg-primary); color: var(--text-primary); line-height: 1.6; padding: 20px; }
        .header { background: var(--bg-secondary); padding: 20px; border-radius: 12px; margin-bottom: 20px; border: 1px solid var(--border); }
        .header h1 { color: var(--accent); margin-bottom: 10px; font-size: 1.5em; }
        .stats { display: flex; gap: 20px; flex-wrap: wrap; }
        .stat { background: var(--bg-tertiary); padding: 8px 16px; border-radius: 8px; font-size: 0.9em; }
        .stat-label { color: var(--text-secondary); }
        .stat-value { color: var(--accent); font-weight: bold; }
        .controls { background: var(--bg-secondary); padding: 15px; border-radius: 12px; margin-bottom: 20px; display: flex; gap: 10px; flex-wrap: wrap; border: 1px solid var(--border); }
        .controls input { flex: 1; min-width: 200px; padding: 10px 15px; border: 1px solid var(--border); border-radius: 8px; background: var(--bg-tertiary); color: var(--text-primary); font-family: inherit; }
        .controls button { padding: 10px 20px; border: none; border-radius: 8px; background: var(--accent); color: var(--bg-primary); cursor: pointer; font-weight: bold; }
        .sample { background: var(--bg-secondary); border-radius: 12px; margin-bottom: 15px; overflow: hidden; border: 1px solid var(--border); }
        .sample-header { background: var(--bg-tertiary); padding: 15px 20px; cursor: pointer; display: flex; align-items: center; gap: 15px; border-bottom: 1px solid var(--border); }
        .sample-header:hover { background: #3b3b4f; }
        .toggle { color: var(--accent); transition: transform 0.2s; }
        .sample.collapsed .toggle { transform: rotate(-90deg); }
        .sample-title { flex: 1; font-weight: bold; }
        .sample-badges { display: flex; gap: 10px; }
        .badge { padding: 4px 12px; border-radius: 20px; font-size: 0.8em; font-weight: bold; }
        .badge-positive { background: var(--green); color: var(--bg-primary); }
        .badge-negative { background: var(--red); color: var(--bg-primary); }
        .badge-neutral { background: var(--text-secondary); color: var(--bg-primary); }
        .sample-content { padding: 20px; }
        .sample.collapsed .sample-content { display: none; }
        .meta-row { display: flex; gap: 15px; margin-bottom: 15px; flex-wrap: wrap; }
        .meta-item { background: var(--bg-tertiary); padding: 8px 12px; border-radius: 8px; font-size: 0.85em; }
        .meta-label { color: var(--text-secondary); }
        .meta-value { color: var(--green); }
        .conversation { display: flex; flex-direction: column; gap: 15px; }
        .turn { border-radius: 12px; overflow: hidden; border: 1px solid var(--border); }
        .turn-header { padding: 10px 15px; font-weight: bold; font-size: 0.85em; text-transform: uppercase; }
        .turn.user .turn-header { background: #2a4a6a; color: var(--blue); }
        .turn.assistant .turn-header { background: #2a4a3a; color: var(--green); }
        .turn-content { padding: 15px; background: var(--bg-tertiary); white-space: pre-wrap; word-wrap: break-word; font-size: 0.9em; max-height: 800px; overflow-y: auto; }
        .tag-think { color: var(--yellow); }
        .tag-tool { color: var(--accent); }
        .tag-answer { color: var(--green); }
        .tag-response { color: var(--blue); }
        .inline-image { max-width: 400px; max-height: 300px; border-radius: 8px; margin: 8px 0; border: 2px solid var(--border); cursor: pointer; transition: transform 0.2s; }
        .inline-image:hover { transform: scale(1.02); border-color: var(--accent); }
    </style>
</head>
<body>
    <div class="header">
        <h1>Eval Results Viewer</h1>
        <div class="stats">
            <div class="stat"><span class="stat-label">Samples:</span> <span class="stat-value">'''
        + str(len(results))
        + '''</span></div>
            <div class="stat"><span class="stat-label">Generated:</span> <span class="stat-value">'''
        + time.strftime("%Y-%m-%d %H:%M:%S")
        + '''</span></div>
        </div>
    </div>
    <div class="controls">
        <input type="text" id="search" placeholder="Search samples..." onkeyup="filterSamples()">
        <button onclick="expandAll()">Expand All</button>
        <button onclick="collapseAll()">Collapse All</button>
    </div>
    <div id="samples">
'''
    ]

    # Keys rendered elsewhere in the card (or not at all) are excluded from metadata.
    skip_keys = {"sample_id", "dataset", "input", "output", "gts", "saved_images"}

    for i, r in enumerate(results):
        # str(): sample ids may be ints, which html.escape cannot handle.
        sample_id = str(r.get("sample_id", f"sample-{i}"))
        gts = r.get("gts", "")
        saved_images = r.get("saved_images") or []

        meta_items = [
            f'<div class="meta-item"><span class="meta-label">Ground Truth:</span> '
            f'<span class="meta-value">{html_lib.escape(str(gts))}</span></div>'
        ]
        # Only scalar values and short strings are shown; long strings/containers are skipped.
        for key, value in r.items():
            if key in skip_keys:
                continue
            if value is None:
                formatted = "None"
            elif isinstance(value, float):
                formatted = f"{value:.4f}"
            elif isinstance(value, (int, bool)):
                formatted = str(value)
            elif isinstance(value, str) and len(value) < 200:
                formatted = html_lib.escape(value)
            else:
                continue
            meta_items.append(
                f'<div class="meta-item"><span class="meta-label">{html_lib.escape(str(key))}:</span> '
                f'<span class="meta-value">{formatted}</span></div>'
            )

        input_html = highlight_tags_with_images(str(r.get("input") or ""), saved_images)
        output_html = highlight_tags_with_images(str(r.get("output") or ""), saved_images)
        meta_html = "".join(meta_items)

        html_parts.append(f'''
        <div class="sample collapsed" id="sample-{i}">
            <div class="sample-header" onclick="toggle({i})">
                <span class="toggle">▼</span>
                <span class="sample-title">{html_lib.escape(sample_id)}</span>
            </div>
            <div class="sample-content">
                <div class="meta-row">{meta_html}</div>
                <div class="conversation">
                    <div class="turn user">
                        <div class="turn-header">Input</div>
                        <div class="turn-content">{input_html}</div>
                    </div>
                    <div class="turn assistant">
                        <div class="turn-header">Output</div>
                        <div class="turn-content">{output_html}</div>
                    </div>
                </div>
            </div>
        </div>
''')

    html_parts.append('''
    </div>
    <script>
        function toggle(i) {
            document.getElementById('sample-' + i).classList.toggle('collapsed');
        }
        function expandAll() {
            document.querySelectorAll('.sample').forEach(s => s.classList.remove('collapsed'));
        }
        function collapseAll() {
            document.querySelectorAll('.sample').forEach(s => s.classList.add('collapsed'));
        }
        function filterSamples() {
            const query = document.getElementById('search').value.toLowerCase();
            document.querySelectorAll('.sample').forEach(s => {
                s.style.display = s.textContent.toLowerCase().includes(query) ? '' : 'none';
            });
        }
    </script>
</body>
</html>''')

    # UTF-8 explicitly: the page declares charset=UTF-8 and transcripts may be non-ASCII.
    with open(output_path, "w", encoding="utf-8") as f:
        f.write("".join(html_parts))
|
|
|
|
def save_results(dataset_results: dict, output_dir: str, run_config: Optional[dict] = None):
    """Write ``summary.json`` into ``output_dir`` and print where the outputs live.

    The summary holds the per-dataset metrics, the optional run configuration, and
    web-fetch statistics when any fetches happened. ``results.jsonl`` is written
    incrementally during evaluation, so this function only reports its path; the HTML
    path is reported only if that file exists.
    """
    os.makedirs(output_dir, exist_ok=True)

    summary: dict = {"datasets": dataset_results}
    if run_config:
        summary["run_config"] = run_config

    fetch_stats = _web_fetch_stats.get_stats()
    if fetch_stats["total"] > 0:
        summary["web_fetch_stats"] = fetch_stats

    summary_path = os.path.join(output_dir, "summary.json")
    with open(summary_path, "w") as fh:
        json.dump(summary, fh, indent=2)

    html_path = os.path.join(output_dir, "results.html")
    report_lines = [
        f"\nResults saved to: {output_dir}",
        f" Results: {output_dir}/results.jsonl",
    ]
    if os.path.exists(html_path):
        report_lines.append(f" HTML: {html_path}")
    report_lines.append(f" Summary: {output_dir}/summary.json")
    for line in report_lines:
        print(line, flush=True)
|
|
|
|
| |
| |
| |
|
|
# Accepted values for --eval-compat-profile (also settable via the
# EVAL_COMPAT_PROFILE environment variable). "current" is the CLI default.
EVAL_COMPAT_PROFILES = (
    "current",
    "qwen235b_repair",
    "step14_plus_tavily432",
)
|
|
|
|
| def _env_flag(name: str, default: bool = False) -> bool: |
| value = os.environ.get(name) |
| if value is None: |
| return default |
| return value.lower() in {"1", "true", "yes", "on"} |
|
|
|
|
| def _argv_has_option(option: str) -> bool: |
| prefix = option + "=" |
| return any(arg == option or arg.startswith(prefix) for arg in sys.argv[1:]) |
|
|
|
|
def _env_retry_limit(parser: argparse.ArgumentParser, env_name: str, fallback: str) -> int:
    """Read an integer retry limit from env var *env_name* (default *fallback*).

    A non-integer value is reported through ``parser.error`` (usage message and
    exit) instead of escaping as a raw ``ValueError`` traceback.
    """
    raw = os.environ.get(env_name, fallback)
    try:
        return int(raw)
    except ValueError:
        parser.error(f"{env_name} must be an integer, got {raw!r}")
        raise  # only reached if parser.error() was overridden not to exit


def apply_eval_compat_profile(args: argparse.Namespace, parser: argparse.ArgumentParser) -> str:
    """Normalize evaluation behavior for reproducible cross-run comparisons.

    Mutates *args* in place and returns the resolved profile name.

    - ``step14_plus_tavily432``: forces ``format_retry_limit`` and
      ``final_answer_retry_limit`` to 0 and sets
      ``disable_force_final_answer_turn`` / ``disable_no_tool_answer_recovery``
      to True. An explicitly requested non-zero retry limit is a usage error.
    - Any other profile: fills retry limits that were not given on the command
      line from ``FORMAT_RETRY_LIMIT`` (default 2) and
      ``FINAL_ANSWER_RETRY_LIMIT`` (default 1).

    Raises:
        SystemExit: via ``parser.error`` for an unknown profile, a conflicting
            explicit retry limit, or a non-integer retry-limit env var.
    """
    profile = args.eval_compat_profile
    # argparse only checks ``choices`` for command-line values, not for the
    # env-derived default, so validate explicitly.
    if profile not in EVAL_COMPAT_PROFILES:
        parser.error(f"Unknown --eval-compat-profile: {args.eval_compat_profile}")

    if profile == "step14_plus_tavily432":
        # Both retry-limit options default to None, so any non-None value was
        # given explicitly. Checking the parsed value (instead of scanning
        # sys.argv for exact spellings) also catches abbreviated options that
        # argparse accepts, e.g. ``--format-retry=2``, which would otherwise be
        # silently overridden to 0.
        if args.format_retry_limit not in (None, 0):
            parser.error("--eval-compat-profile step14_plus_tavily432 requires --format-retry-limit 0")
        if args.final_answer_retry_limit not in (None, 0):
            parser.error("--eval-compat-profile step14_plus_tavily432 requires --final-answer-retry-limit 0")
        args.format_retry_limit = 0
        args.disable_force_final_answer_turn = True
        args.final_answer_retry_limit = 0
        args.disable_no_tool_answer_recovery = True
        return "step14_plus_tavily432"

    if args.format_retry_limit is None:
        args.format_retry_limit = _env_retry_limit(parser, "FORMAT_RETRY_LIMIT", "2")
    if args.final_answer_retry_limit is None:
        args.final_answer_retry_limit = _env_retry_limit(parser, "FINAL_ANSWER_RETRY_LIMIT", "1")
    return profile
|
|
|
|
def _build_eval_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the evaluation script.

    Several option defaults are read from environment variables when this
    function runs (e.g. ``VERTEX_LOCATION``, ``EVAL_COMPAT_PROFILE``,
    ``TAVILY_*``), so the environment must be set before calling it.
    """
    parser = argparse.ArgumentParser(description="VLM Evaluation Script v2")

    # Model, datasets and output location.
    parser.add_argument("--model", type=str, required=True, help="Model name")
    parser.add_argument("--mode", type=str, required=True, choices=["direct", "tool"])
    parser.add_argument("--datasets", type=str, default="",
                        help="Path to datasets JSON config. If omitted, use --benchmarks with --eval-root.")
    parser.add_argument("--eval-root", type=str, default="",
                        help=(
                            "Root directory whose subfolders are image benchmarks with data.jsonl. "
                            f"Default used with --benchmarks: {DEFAULT_EVAL_ROOT}"
                        ))
    parser.add_argument("--benchmarks", type=str, nargs="+", default=None,
                        help="Benchmark subfolder names under --eval-root, e.g. mmsearch_end2end_only_image hr_mmsearch.")
    parser.add_argument("--save-resolved-datasets-config", type=str, default="",
                        help="Optional path to save the auto-generated datasets JSON config.")
    parser.add_argument("--model-client", type=str, choices=["gemini", "openai", "azure", "gateway", "vertex"], required=True,
                        help="Model API client: gemini, openai (Qwen/MARS), azure (GPT), gateway (company GPT gateway), or vertex (Vertex Gemini service-account pool)")
    parser.add_argument("--judge-client", type=str, choices=["openai", "azure"], default=None,
                        help="Legacy judge client flag; llm_score now follows video_dr_gen and ignores this option")
    parser.add_argument("--judge-temperature", type=float, default=0.0,
                        help="Legacy judge temperature flag; llm_score now follows video_dr_gen and ignores this option")
    parser.add_argument("--data-root", type=str, default="", help="Root for relative paths")
    parser.add_argument("--output-dir", type=str, default=None)
    parser.add_argument("--target-ids", type=str, nargs="+", default=None,
                        help="Only evaluate the specified sample ids or source ids")
    parser.add_argument("--vertex-account-pool-file", type=str,
                        default=os.environ.get("VERTEX_ACCOUNT_POOL_FILE", ""),
                        help="Vertex Gemini account pool JSON; defaults to env VERTEX_ACCOUNT_POOL_FILE")
    parser.add_argument("--vertex-location", type=str,
                        default=os.environ.get("VERTEX_LOCATION", "global"),
                        help="Default Vertex location for account-pool entries without location (default: global)")
    parser.add_argument("--vertex-rate-limit-cooldown", type=float,
                        default=float(os.environ.get("VERTEX_RATE_LIMIT_COOLDOWN_SECONDS", "60")),
                        help="Cooldown seconds for a Vertex account after 429/RESOURCE_EXHAUSTED before rotating back")

    # Generation parameters and multi-turn control.
    parser.add_argument("--max-concurrent", type=int, default=4)
    parser.add_argument("--max-tokens", type=int, default=4096)
    parser.add_argument("--max-turns", type=int, default=50,
                        help="Max assistant turns (default: 50)")
    parser.add_argument("--eval-compat-profile", type=str, choices=EVAL_COMPAT_PROFILES,
                        default=os.environ.get("EVAL_COMPAT_PROFILE", "current"),
                        help=(
                            "Evaluation behavior profile. Use step14_plus_tavily432 to keep "
                            "step14 prompt/control flow while retaining later Tavily key-pool fixes."
                        ))
    parser.add_argument("--format-retry-limit", type=int, default=None,
                        help="Extra protocol-repair retries when a tool-mode response is empty, malformed, or only in reasoning_content (default: 2; forced to 0 by step14_plus_tavily432)")
    parser.add_argument("--disable-force-final-answer-turn", action="store_true",
                        default=_env_flag("DISABLE_FORCE_FINAL_ANSWER_TURN"),
                        help="Disable the final no-tool answer turn used to prevent VideoDR tool loops near max_turns")
    parser.add_argument("--final-answer-retry-limit", type=int, default=None,
                        help="Extra retries after the forced final-answer instruction if the model still fails to answer (default: 1; forced to 0 by step14_plus_tavily432)")
    parser.add_argument("--disable-no-tool-answer-recovery", action="store_true",
                        default=_env_flag("DISABLE_NO_TOOL_ANSWER_RECOVERY"),
                        help="Disable recovery of plain no-tool outputs as final answers.")
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--top-p", type=float, default=0.8)
    parser.add_argument("--top-k", type=int, default=20)
    parser.add_argument("--presence-penalty", type=float, default=1.5)
    parser.add_argument("--repetition-penalty", type=float, default=1.0,
                        help="Repetition penalty for model generation (default: 1.0)")
    parser.add_argument("--seed", type=int, default=3407,
                        help="Random seed for deterministic generation (default: 3407)")

    # Image processing, tools and search backends.
    parser.add_argument("--min-pixels", type=int, default=65536, help="Min pixels for image processing")
    parser.add_argument("--max-pixels", type=int, default=8294400, help="Max pixels for image processing")
    parser.add_argument("--factor", type=int, default=32, help="Alignment factor (32 for Qwen3-VL, 28 for Qwen2-VL)")
    parser.add_argument("--qwen-vl-processing", type=lambda x: x.lower() == 'true', default=True,
                        help="Use Qwen-VL style image processing (default: True, set False for Gemini/GPT)")
    parser.add_argument("--tool-config", type=str, default="", help="Path to tool config YAML (optional, overrides dataset system_prompt)")
    parser.add_argument("--tool-ablation-profile", type=str, choices=TOOL_ABLATION_PROFILES,
                        default=os.environ.get("TOOL_ABLATION_PROFILE", "none"),
                        help=(
                            "Tool ablation profile for tool-mode evaluation: none, nosearch "
                            "(remove image_search/web_search), or nolocation (remove choose_frames/zoom_in)."
                        ))
    parser.add_argument("--serper-concurrency", type=int, default=5, help="Max concurrent Serper API requests (default: 5)")
    parser.add_argument("--search-cache-dir", type=str, default="", help="Directory for search result cache (optional, no cache if not set)")
    parser.add_argument("--seed-search-cache-from", type=str, nargs="*", default=None,
                        help="Seed this run's search cache from historical search_cache.db files under these files/directories. Defaults to inference/runs when --search-cache-dir is set.")
    parser.add_argument("--no-auto-seed-search-cache", action="store_true",
                        default=_env_flag("NO_AUTO_SEED_SEARCH_CACHE"),
                        help="Disable automatic historical search cache seeding.")
    parser.add_argument("--seed-image-search-cache-from", type=str, nargs="*", default=None,
                        help="Seed this run's image_search_cache.json from historical image_search_cache.json files under these files/directories. Defaults to inference/runs in tool mode.")
    parser.add_argument("--no-auto-seed-image-search-cache", action="store_true",
                        default=_env_flag("NO_AUTO_SEED_IMAGE_SEARCH_CACHE"),
                        help="Disable automatic historical image_search_cache.json seeding.")
    parser.add_argument("--web-search-backend", type=str,
                        choices=["serper_gateway", "gateway", "gateway_serper", "internal_serper", "company_serper", "mars", "tavily", "auto"],
                        default=_default_web_search_backend(),
                        help="Backend for web_search tool: serper_gateway, mars, tavily, or auto (default: env WEB_SEARCH_BACKEND/VIDEO_DR_WEB_SEARCH_BACKEND or serper_gateway)")
    parser.add_argument("--serper-gateway-max-results", type=int,
                        default=int(os.environ.get("SERPER_GATEWAY_MAX_RESULTS", "8")),
                        help="Serper gateway max results, clamped to 1-20 (default: 8)")
    parser.add_argument("--serper-gateway-timeout", type=int,
                        default=int(os.environ.get("SERPER_GATEWAY_TIMEOUT", "60")),
                        help="Serper gateway request timeout in seconds (default: 60)")
    parser.add_argument("--serper-gateway-summary-max-tokens", type=int,
                        default=int(os.environ.get("SERPER_GATEWAY_SUMMARY_MAX_TOKENS", "1024")),
                        help="Deprecated compatibility flag for the old metadata-only Serper gateway summarizer.")
    parser.add_argument("--tavily-search-depth", type=str, choices=["basic", "advanced"],
                        default=os.environ.get("TAVILY_SEARCH_DEPTH", "advanced"),
                        help="Tavily search_depth when --web-search-backend=tavily")
    parser.add_argument("--tavily-max-results", type=int,
                        default=int(os.environ.get("TAVILY_MAX_RESULTS", "8")),
                        help="Tavily max_results, clamped to 1-20 (default: 8)")
    parser.add_argument("--tavily-include-answer", type=str,
                        default=os.environ.get("TAVILY_INCLUDE_ANSWER", "advanced"),
                        help="Tavily include_answer: false, true/basic, or advanced (default: advanced)")
    parser.add_argument("--tavily-include-raw-content", type=str,
                        default=os.environ.get("TAVILY_INCLUDE_RAW_CONTENT", "false"),
                        help="Tavily include_raw_content: false, true/markdown, or text (default: false)")
    parser.add_argument("--tavily-topic", type=str, choices=["general", "news", "finance"],
                        default=os.environ.get("TAVILY_TOPIC", "general"),
                        help="Tavily topic (default: general)")
    parser.add_argument("--tavily-auto-parameters", action="store_true",
                        default=_env_flag("TAVILY_AUTO_PARAMETERS"),
                        help="Enable Tavily auto_parameters; uses extra Tavily credits")
    parser.add_argument("--tavily-timeout", type=int,
                        default=int(os.environ.get("TAVILY_TIMEOUT", "60")),
                        help="Tavily request timeout in seconds (default: 60)")
    parser.add_argument("--tavily-key-cooldown-seconds", type=int,
                        default=int(os.environ.get("TAVILY_KEY_COOLDOWN_SECONDS", "60")),
                        help="Cooldown for a Tavily key after 429/5xx before rotating back (default: 60)")

    # VideoDR frame sampling.
    parser.add_argument("--video-initial-frames", type=int, default=DEFAULT_VIDEO_INITIAL_FRAMES,
                        help="Number of initial 1fps frames shown to the model for VideoDR")
    parser.add_argument("--video-interval-samples", type=int, default=DEFAULT_VIDEO_INTERVAL_SAMPLES,
                        help="Number of uniformly sampled frames returned by choose_frames")
    parser.add_argument("--video-max-resolution", type=int, default=DEFAULT_VIDEO_MAX_RESOLUTION,
                        help="Max side length for 1fps extracted video frames")
    parser.add_argument("--video-jpeg-quality", type=int, default=DEFAULT_VIDEO_JPEG_QUALITY,
                        help="JPEG quality for cached video frames")
    return parser


def _resolve_model_endpoint(args: argparse.Namespace, parser: argparse.ArgumentParser) -> tuple:
    """Return ``(base_url, api_key)`` for the selected ``--model-client``.

    Endpoints and credentials come from environment variables (and, for the
    default OpenAI-compatible client, optionally a secret file). A missing
    required credential is reported via ``parser.error``.
    """
    model_client = args.model_client
    if model_client == "gemini":
        if "GEMINI_API_KEY" not in os.environ:
            parser.error("GEMINI_API_KEY required")
        if "GEMINI_BASE_URL" not in os.environ:
            parser.error("GEMINI_BASE_URL required")
        return os.environ["GEMINI_BASE_URL"], os.environ["GEMINI_API_KEY"]

    if model_client == "azure":
        if "AZURE_OPENAI_API_KEY" not in os.environ:
            parser.error("AZURE_OPENAI_API_KEY required")
        if "AZURE_OPENAI_BASE_URL" not in os.environ:
            parser.error("AZURE_OPENAI_BASE_URL required")
        return os.environ["AZURE_OPENAI_BASE_URL"], os.environ["AZURE_OPENAI_API_KEY"]

    if model_client == "gateway":
        _, _, api_key = _get_model_gateway_credentials(args.model)
        if not api_key:
            token_hint = (
                "MODEL_GATEWAY_GEMINI_TOKEN or GEMINI_GATEWAY_TOKEN required"
                if _is_gemini_gateway_model(args.model)
                else "MODEL_GATEWAY_TOKEN or GATEWAY_TOKEN required"
            )
            parser.error(token_hint)
        base_url = (
            os.environ.get("MODEL_GATEWAY_URL")
            or os.environ.get("GATEWAY_URL")
            or "http://112.65.194.90:8000/trpc.youtu.llm_interface_service.Greeter/DescribeLlmResult"
        )
        return base_url, api_key

    if model_client == "vertex":
        if not args.vertex_account_pool_file:
            parser.error("--vertex-account-pool-file or VERTEX_ACCOUNT_POOL_FILE required for --model-client vertex")
        # Vertex authenticates per request through the account pool, so no
        # single API key or base URL applies here.
        return "vertex://account-pool", ""

    # Default: OpenAI-compatible endpoint (Qwen/MARS).
    base_url = os.environ.get("MODEL_BASE_URL", DEFAULT_COMPANY_OPENAI_BASE_URL)
    api_key = (
        os.environ.get("MODEL_API_KEY")
        or os.environ.get("MODEL_OPENAI_API_KEY")
        or os.environ.get("OPENAI_API_KEY", "")
        or _read_secret_file(os.environ.get("MODEL_API_KEY_FILE", DEFAULT_COMPANY_OPENAI_API_KEY_FILE))
    )
    return base_url, api_key


def main():
    """Command-line entry point for the evaluation script.

    Parses arguments, resolves model credentials and datasets, configures
    tools and search caches, runs health checks, executes ``run_evaluation``
    and writes ``summary.json`` via ``save_results``.
    """
    parser = _build_eval_arg_parser()
    args = parser.parse_args()
    resolved_eval_profile = apply_eval_compat_profile(args, parser)
    video_dr_system_prompt = get_video_dr_system_prompt(resolved_eval_profile)
    print(
        "Evaluation compatibility profile: "
        f"{resolved_eval_profile} "
        f"(format_retry_limit={args.format_retry_limit}, "
        f"force_final_answer_turn={not args.disable_force_final_answer_turn}, "
        f"final_answer_retry_limit={args.final_answer_retry_limit}, "
        f"no_tool_answer_recovery={not args.disable_no_tool_answer_recovery})",
        flush=True,
    )

    # ---- Model endpoint ----------------------------------------------------
    model_client = args.model_client
    base_url, api_key = _resolve_model_endpoint(args, parser)
    # NOTE(review): presumably keyless endpoints are local/in-cluster services
    # that must bypass any HTTP proxy; confirm against the deployment setup.
    if model_client not in {"gateway", "vertex"} and not api_key:
        configure_local_service_no_proxy(base_url)

    vertex_account_pool = None
    if model_client == "vertex":
        vertex_account_pool = VertexAccountPool(
            args.vertex_account_pool_file,
            default_location=args.vertex_location,
            cooldown_seconds=args.vertex_rate_limit_cooldown,
        )

    # ---- Datasets ----------------------------------------------------------
    auto_dataset_config = None
    if args.benchmarks or args.eval_root:
        if args.datasets:
            parser.error("Use either --datasets or --eval-root/--benchmarks, not both.")
        eval_root = args.eval_root or DEFAULT_EVAL_ROOT
        auto_dataset_config = build_eval_root_dataset_config(eval_root, args.benchmarks)
        print(
            "Loading benchmarks from eval root "
            f"{eval_root}: {', '.join(auto_dataset_config.keys())}",
            flush=True,
        )
        if args.save_resolved_datasets_config:
            save_dir = os.path.dirname(os.path.abspath(args.save_resolved_datasets_config))
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            # ensure_ascii=False may emit non-ASCII text, so pin the encoding
            # instead of depending on the platform locale.
            with open(args.save_resolved_datasets_config, "w", encoding="utf-8") as f:
                json.dump(auto_dataset_config, f, ensure_ascii=False, indent=2)
            print(
                f"Saved resolved datasets config to {args.save_resolved_datasets_config}",
                flush=True,
            )
        samples, dataset_configs = load_datasets(
            auto_dataset_config,
            args.data_root,
            video_dr_system_prompt=video_dr_system_prompt,
        )
    else:
        if not args.datasets:
            parser.error("--datasets is required unless --benchmarks or --eval-root is provided.")
        print(f"Loading datasets from {args.datasets}...", flush=True)
        samples, dataset_configs = load_datasets(
            args.datasets,
            args.data_root,
            video_dr_system_prompt=video_dr_system_prompt,
        )

    if args.target_ids:
        target_ids = set(args.target_ids)
        total_before_filter = len(samples)
        samples = [
            sample for sample in samples
            if sample["id"] in target_ids or sample.get("source_id") in target_ids
        ]
        print(
            f"Filtered samples with --target-ids: {len(samples)}/{total_before_filter}",
            flush=True,
        )
        if not samples:
            parser.error("--target-ids did not match any sample id/source id")
    print(f"Loaded {len(samples)} samples total\n", flush=True)

    selected_dataset_names = {sample.get("dataset", "") for sample in samples}
    selected_dataset_configs = [
        dataset_configs[name] for name in selected_dataset_names if name in dataset_configs
    ]
    has_video_dr_dataset = any(cfg.get("task_kind") == "video_dr" for cfg in selected_dataset_configs)
    has_image_dataset = any(cfg.get("task_kind") == "image" for cfg in selected_dataset_configs)
    if args.tool_ablation_profile != "none" and args.mode != "tool":
        parser.error("--tool-ablation-profile 仅支持 --mode tool")
    if args.tool_ablation_profile != "none" and has_video_dr_dataset and has_image_dataset:
        parser.error("同一次工具消融评测暂不支持混合 VideoDR 与 image benchmark")

    # ---- LLM judge ---------------------------------------------------------
    needs_llm = any(
        "llm_score" in cfg.get("score_methods", [])
        for cfg in selected_dataset_configs
    )
    judge_client = args.judge_client or ""
    judge_base_url = ""
    judge_api_key = ""
    if needs_llm and (not MARS_SUMMARIZER_ADDRESS or not MARS_SUMMARIZER_MODEL):
        print(
            "[WARN] llm_score 已切换到 video_dr_gen 的 MARS summarizer judge,"
            "但 MARS_SUMMARIZER_ADDRESS 或 MARS_SUMMARIZER_MODEL 未配置;"
            "此时 llm_score 只会走快速路径,无法调用 LLM judge。",
            flush=True,
        )

    # ---- Search backends ---------------------------------------------------
    serper_api_key = os.environ.get("SERPER_API_KEY", "")
    tavily_api_keys = load_tavily_api_keys()
    tavily_key_pool = TavilyApiKeyPool(tavily_api_keys, cooldown_seconds=args.tavily_key_cooldown_seconds)
    summarizer_base_url = os.environ.get("SUMMARIZER_BASE_URL", "")
    summarizer_model = os.environ.get("SUMMARIZER_MODEL", "")
    resolved_web_search_backend = resolve_web_search_backend(args.web_search_backend, tavily_api_keys)
    if args.mode == "tool":
        print(
            f"web_search backend: {resolved_web_search_backend}"
            + (" (from auto)" if args.web_search_backend == "auto" else ""),
            flush=True,
        )
        # Search backends are only used by tools, so only validate keys there.
        if resolved_web_search_backend == "tavily" and not tavily_api_keys:
            parser.error("TAVILY_API_KEY, TAVILY_API_KEYS, or TAVILY_API_KEY_FILE required when --web-search-backend=tavily")
        if resolved_web_search_backend == "tavily":
            print(f"Tavily API key pool: {len(tavily_api_keys)} key(s)", flush=True)

    # ---- Tool configuration ------------------------------------------------
    tools_section = ""
    allowed_tool_names = None
    if args.mode == "tool":
        if has_video_dr_dataset:
            allowed_tool_names = get_allowed_tool_names(args.tool_ablation_profile, "video_dr")
        elif has_image_dataset:
            allowed_tool_names = get_allowed_tool_names(args.tool_ablation_profile, "image")
        if has_image_dataset and not args.tool_config:
            parser.error("图像 benchmark 的 tool 模式需要提供 --tool-config")
        if args.tool_config:
            print(f"Loading tool config from {args.tool_config}...", flush=True)
            tools_section = load_tool_config(
                args.tool_config,
                allowed_tool_names=allowed_tool_names,
                normalize_image_schema=has_image_dataset,
                normalize_video_schema=has_video_dr_dataset,
            )
        if args.tool_ablation_profile != "none":
            print(
                "Tool ablation profile: "
                f"{args.tool_ablation_profile}; allowed tools: {format_tool_names(allowed_tool_names or set())}",
                flush=True,
            )
        if has_video_dr_dataset and args.tool_ablation_profile != "none":
            if not tools_section:
                parser.error("VideoDR 工具消融需要提供 --tool-config 以生成消融后的工具定义 prompt")
            # Rebuild the VideoDR prompt so it only advertises the tools that
            # survive the ablation, and push it into every VideoDR config.
            video_dr_system_prompt = build_video_tool_system_prompt(
                tools_section=tools_section,
                allowed_tool_names=allowed_tool_names or set(),
                max_turns=args.max_turns,
            )
            for cfg in selected_dataset_configs:
                if cfg.get("task_kind") == "video_dr":
                    cfg["system_prompt"] = video_dr_system_prompt

    configure_local_service_no_proxy(summarizer_base_url)
    configure_local_service_no_proxy(MARS_RETRIEVAL_ADDRESS)
    configure_local_service_no_proxy(MARS_SUMMARIZER_ADDRESS)
    if args.mode == "tool" and VIDEO_DR_IMAGE_SEARCH_MODE == "gateway":
        configure_local_service_no_proxy(VIDEO_DR_GATEWAY_URL)

    # ---- Output directory and caches --------------------------------------
    if args.output_dir:
        output_dir = args.output_dir
    else:
        timestamp = time.strftime("%y%m%d%H%M%S")
        output_dir = f"eval_{args.model.replace('/', '_')}_{args.mode}_{timestamp}"
    os.makedirs(output_dir, exist_ok=True)
    if args.mode == "tool" and not args.search_cache_dir:
        args.search_cache_dir = os.path.join(output_dir, "search_cache")
        print(
            f"Search cache dir not provided; using {args.search_cache_dir}",
            flush=True,
        )

    search_cache = None
    search_cache_seed_paths = resolve_cache_seed_paths(
        args.seed_search_cache_from,
        args.no_auto_seed_search_cache,
    )
    if args.search_cache_dir:
        search_cache = SearchCache(args.search_cache_dir)
        if search_cache_seed_paths:
            search_cache.seed_from_paths(search_cache_seed_paths)
    image_search_cache_seed_paths = resolve_cache_seed_paths(
        args.seed_image_search_cache_from,
        args.no_auto_seed_image_search_cache,
    )

    # NOTE(review): this passes the raw --web-search-backend (possibly "auto")
    # while run_config records resolved_web_search_backend; confirm that
    # run_evaluation resolves "auto" the same way.
    kwargs = {
        "eval_compat_profile": resolved_eval_profile,
        "video_dr_system_prompt": video_dr_system_prompt,
        "general_video_direct_system_prompt": GENERAL_VIDEO_DIRECT_SYSTEM_PROMPT,
        "max_tokens": args.max_tokens,
        "max_turns": args.max_turns,
        "format_retry_limit": args.format_retry_limit,
        "force_final_answer_turn": not args.disable_force_final_answer_turn,
        "final_answer_retry_limit": args.final_answer_retry_limit,
        "recover_no_tool_answer": not args.disable_no_tool_answer_recovery,
        "temperature": args.temperature,
        "top_p": args.top_p,
        "top_k": args.top_k,
        "presence_penalty": args.presence_penalty,
        "repetition_penalty": args.repetition_penalty,
        "seed": args.seed,
        "min_pixels": args.min_pixels,
        "max_pixels": args.max_pixels,
        "factor": args.factor,
        "qwen_vl_processing": args.qwen_vl_processing,
        "serper_api_key": serper_api_key,
        "web_search_backend": args.web_search_backend,
        "serper_gateway_max_results": args.serper_gateway_max_results,
        "serper_gateway_timeout": args.serper_gateway_timeout,
        "serper_gateway_summary_max_tokens": args.serper_gateway_summary_max_tokens,
        "tavily_api_key": tavily_api_keys[0] if tavily_api_keys else "",
        "tavily_api_key_pool": tavily_key_pool,
        "tavily_search_depth": args.tavily_search_depth,
        "tavily_max_results": args.tavily_max_results,
        "tavily_include_answer": args.tavily_include_answer,
        "tavily_include_raw_content": args.tavily_include_raw_content,
        "tavily_topic": args.tavily_topic,
        "tavily_auto_parameters": args.tavily_auto_parameters,
        "tavily_timeout": args.tavily_timeout,
        "tavily_include_domains": _split_csv_env(os.environ.get("TAVILY_INCLUDE_DOMAINS", "")),
        "tavily_exclude_domains": _split_csv_env(os.environ.get("TAVILY_EXCLUDE_DOMAINS", "")),
        "summarizer_base_url": summarizer_base_url,
        "summarizer_model": summarizer_model,
        "serper_concurrency": args.serper_concurrency,
        "search_cache": search_cache,
        "image_search_cache_seed_paths": image_search_cache_seed_paths,
        "tools_section": tools_section,
        "tool_ablation_profile": args.tool_ablation_profile,
        "allowed_tool_names": allowed_tool_names,
        "judge_client": judge_client,
        "judge_base_url": judge_base_url,
        "judge_api_key": judge_api_key,
        "judge_temperature": args.judge_temperature,
        "video_initial_frames": args.video_initial_frames,
        "video_interval_samples": args.video_interval_samples,
        "video_max_resolution": args.video_max_resolution,
        "video_jpeg_quality": args.video_jpeg_quality,
        "vertex_account_pool": vertex_account_pool,
    }

    async def health_check():
        """Fail fast if the model server or auxiliary services are unreachable.

        Raises:
            RuntimeError: naming the unreachable service; the original error
                is chained as ``__cause__``.
        """
        print("Running health checks...", flush=True)
        # Shared by every check below; defined before any try block so later
        # checks never depend on an earlier try body having run.
        timeout = aiohttp.ClientTimeout(total=10)

        print(f" Checking model server: {base_url}", flush=True)
        try:
            if model_client == "gateway":
                result = await call_gateway_api(
                    [{"role": "user", "content": "Reply with exactly: OK"}],
                    args.model,
                    base_url,
                    api_key,
                    max_tokens=128,
                    temperature=0.0,
                    model_request_timeout=10,
                )
                if result.get("error"):
                    raise Exception(result["error"])
                print(" Model gateway OK", flush=True)
            elif model_client == "vertex":
                result = await call_vertex_gemini_api(
                    [{"role": "user", "content": "Hello, please test the connection and reply with OK."}],
                    args.model,
                    base_url,
                    api_key,
                    vertex_account_pool=vertex_account_pool,
                    max_tokens=32,
                    temperature=0.0,
                    model_request_timeout=10,
                )
                if result.get("error"):
                    raise Exception(result["error"])
                project = result.get("vertex_project_id", "")
                print(f" Vertex Gemini OK project={project}", flush=True)
            else:
                # OpenAI-compatible servers: listing models is a cheap probe.
                async with create_http_session(timeout) as session:
                    url = _openai_models_url(base_url)
                    headers = {"Content-Type": "application/json"}
                    if api_key:
                        headers["Authorization"] = f"Bearer {api_key}"
                    async with session.get(url, headers=headers) as resp:
                        if resp.status != 200:
                            raise Exception(f"Model server returned HTTP {resp.status}")
                        print(" Model server OK", flush=True)
        except Exception as e:
            raise RuntimeError(f"Model server not reachable at {base_url}: {e}") from e

        if args.mode == "tool" and summarizer_base_url:
            print(f" Checking summarizer server: {summarizer_base_url}", flush=True)
            try:
                async with create_http_session(timeout) as session:
                    url = f"{summarizer_base_url.rstrip('/')}/v1/models"
                    async with session.get(url) as resp:
                        if resp.status != 200:
                            raise Exception(f"Summarizer server returned HTTP {resp.status}")
                        print(" Summarizer server OK", flush=True)
            except Exception as e:
                raise RuntimeError(f"Summarizer server not reachable at {summarizer_base_url}: {e}") from e

        if needs_llm and MARS_SUMMARIZER_ADDRESS:
            print(f" Checking LLM judge summarizer: http://{MARS_SUMMARIZER_ADDRESS}", flush=True)
            try:
                async with create_http_session(timeout) as session:
                    url = f"http://{MARS_SUMMARIZER_ADDRESS.rstrip('/')}/v1/models"
                    async with session.get(url, proxy="") as resp:
                        if resp.status != 200:
                            raise Exception(f"Judge summarizer returned HTTP {resp.status}")
                        print(" LLM judge summarizer OK", flush=True)
            except Exception as e:
                raise RuntimeError(f"LLM judge summarizer not reachable at http://{MARS_SUMMARIZER_ADDRESS}: {e}") from e

        print("Health checks passed!", flush=True)

    # The health check runs inside the try so the search cache opened above is
    # closed even when a service is unreachable.
    try:
        asyncio.run(health_check())

        dataset_results = asyncio.run(run_evaluation(
            samples, dataset_configs, model_client, args.model, base_url, api_key,
            args.mode, args.max_concurrent, output_dir=output_dir, **kwargs
        ))

        save_results(
            dataset_results,
            output_dir,
            run_config={
                "eval_compat_profile": resolved_eval_profile,
                "format_retry_limit": args.format_retry_limit,
                "force_final_answer_turn": not args.disable_force_final_answer_turn,
                "final_answer_retry_limit": args.final_answer_retry_limit,
                "recover_no_tool_answer": not args.disable_no_tool_answer_recovery,
                "web_search_backend": resolved_web_search_backend,
                "tavily_search_depth": args.tavily_search_depth,
                "tavily_max_results": args.tavily_max_results,
                "tavily_include_answer": args.tavily_include_answer,
                "tavily_include_raw_content": args.tavily_include_raw_content,
                "tavily_topic": args.tavily_topic,
                "eval_root": args.eval_root or (DEFAULT_EVAL_ROOT if auto_dataset_config else ""),
                "benchmarks": list(auto_dataset_config.keys()) if auto_dataset_config else [],
                "resolved_datasets_config": args.save_resolved_datasets_config,
            },
        )
    finally:
        if search_cache is not None:
            try:
                stats = search_cache.get_stats()
                print(f"Search cache: {stats['hits']}/{stats['total']} hits ({stats['hit_rate']:.1f}%), {stats['misses']} new searches cached", flush=True)
            finally:
                search_cache.close()
    print("Done.", flush=True)
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
|
|