#!/usr/bin/env python3 """ 统一评测脚本。 输入: - 旧版图像 benchmark 的 JSONL 配置、VideoDR 的无表头 CSV 配置,或 `SenseNova-MARS-Data/data/eval` 下按子目录组织的图像 benchmark。 - 模型服务地址、可选的工具配置 YAML、本地 MARS 检索服务,以及评测输出目录。 - 图像路径或视频路径;对视频会在评测过程中按 1fps 抽帧并缓存到输出目录。 - 可选的 Vertex Gemini 账号池 JSON;每个条目绑定一个 GCP project 与一个 service account JSON,用于 `--model-client vertex` 时轮询调用 Gemini。 处理: - 当传入 `--eval-root` / `--benchmarks` 时,自动扫描每个 benchmark 子目录的 `data.jsonl`,按子目录内的 `images/` 与可选 image search 缩略图构造临时数据集配置; 多选题 benchmark 使用 EM,开放问答 / 搜索型 benchmark 使用 LLM judge。 - 保留原始 image benchmark 的 direct/tool 评测流程,并把工具名兼容到新论文的 `zoom_in / image_search / web_search`;direct answer 模式只做单轮模型调用,不进入工具循环。 - 为 VideoDR 样本增加 `choose_frames / find_frame / zoom_in / image_search / web_search` 的多轮执行状态机,工具实现复用 SFT 构造代码中的定义;当 VideoDR 使用 direct answer 模式时,每条视频只抽取一次 1fps 均匀采样帧并直接回答,不调用任何工具。 - 支持工具消融 profile:在 search 消融中移除 `image_search / web_search`,在 location 消融中移除 `choose_frames / zoom_in`,并同步约束 prompt 与运行时白名单。 - 支持 Vertex 原生 Gemini `generateContent`:把 OpenAI-style messages 转为 `systemInstruction + contents`,每次模型请求按账号池选择 project,遇到 429/RESOURCE_EXHAUSTED 时切换到下一个 project 重试。 - 根据数据集配置动态执行 EM 或 LLM judge,并生成 HTML 结果浏览页。 输出: - 在输出目录下写入 `results.jsonl`、`summary.json`、`results.html`。 - 对视频样本额外落盘抽帧缓存与 HTML 可视化图片,便于复查轨迹。 """ import argparse import asyncio import base64 import ipaddress import io import json import math import os import re import string import sys import time from pathlib import Path from typing import Optional from urllib.parse import urlparse import atexit import hashlib import sqlite3 import aiohttp import yaml from PIL import Image # Required for URL fetching with JS rendering from playwright.async_api import async_playwright from video_dr_bridge import ( DEFAULT_VIDEO_INITIAL_FRAMES, DEFAULT_VIDEO_INTERVAL_SAMPLES, DEFAULT_VIDEO_JPEG_QUALITY, DEFAULT_VIDEO_MAX_RESOLUTION, GATEWAY_TOKEN as VIDEO_DR_GATEWAY_TOKEN, GATEWAY_USERID as VIDEO_DR_GATEWAY_USERID, GATEWAY_USERNAME as VIDEO_DR_GATEWAY_USERNAME, GATEWAY_URL as VIDEO_DR_GATEWAY_URL, IMAGE_SEARCH_MODE as VIDEO_DR_IMAGE_SEARCH_MODE, ImageSearchCache, MARS_RETRIEVAL_ADDRESS, MARS_SUMMARIZER_ADDRESS, MARS_SUMMARIZER_MODEL, VIDEO_DR_SYSTEM_PROMPT, add_search_padding as vdr_add_search_padding, crop_frame as vdr_crop_frame, extract_video_frames_1fps, get_bbox_config as vdr_get_bbox_config, get_video_dr_system_prompt, get_frame as vdr_get_frame, load_videodr_csv_samples, mars_web_search as vdr_mars_web_search, normalize_bbox as vdr_normalize_bbox, real_image_search as vdr_real_image_search, sample_interval as vdr_sample_interval, uniform_sample_indices as vdr_uniform_sample_indices, ) DEMO_ROOT = Path(__file__).resolve().parent.parent LOCAL_DATA_ROOT = Path( os.environ.get("VIDEO_DEEP_RESEARCH_DATA_ROOT", str(DEMO_ROOT / "local_data")) ).expanduser() LOCAL_SECRETS_DIR = Path( os.environ.get("VIDEO_DEEP_RESEARCH_SECRETS_DIR", str(DEMO_ROOT / "secrets")) ).expanduser() DEFAULT_EVAL_ROOT = str(LOCAL_DATA_ROOT / "eval") DEFAULT_COMPANY_OPENAI_BASE_URL = "http://35.220.164.252:3888/v1" DEFAULT_COMPANY_OPENAI_API_KEY_FILE = str(LOCAL_SECRETS_DIR / "company_openai_api_key.txt") DEFAULT_MODEL_GATEWAY_USERNAME = "kaiweichen" DEFAULT_MODEL_GATEWAY_USERID = "483124" DEFAULT_MODEL_GATEWAY_TOKEN_FILE = str(LOCAL_SECRETS_DIR / "model_gateway_token.txt") DEFAULT_GPT54_GATEWAY_USERNAME = "jemminyang" DEFAULT_GPT54_GATEWAY_USERID = "491087" DEFAULT_GPT54_GATEWAY_TOKEN_FILE = str(LOCAL_SECRETS_DIR / "model_gateway_gpt54_token.txt") DEFAULT_GEMINI_GATEWAY_USERNAME = "suyuanhuang" DEFAULT_GEMINI_GATEWAY_USERID = "485311" DEFAULT_GEMINI_GATEWAY_TOKEN_FILE = str(LOCAL_SECRETS_DIR / "model_gateway_gemini_token.txt") DEFAULT_TAVILY_API_KEY_FILE = str(LOCAL_SECRETS_DIR / "tavily_api_keys.txt") GATEWAY_MODEL_NAME_ALIASES = { "gpt5.4": "GPT-5.4", "gpt-5.4": "GPT-5.4", "gpt_5.4": "GPT-5.4", "gpt5.2": "gpt-5.2", "gpt-5.2": "gpt-5.2", "gpt_5.2": "gpt-5.2", "gpt4o": "gpt-4o", "gpt-4o": "gpt-4o", "gpt_4o": "gpt-4o", "gemini3propreview": "gemini-3-pro-preview", "gemini-3-pro-preview": "gemini-3-pro-preview", "gemini_3_pro_preview": "gemini-3-pro-preview", } OPENAI_MODEL_NAME_ALIASES = { "gpt4o": "gpt-4o", "gpt-4o": "gpt-4o", "gpt_4o": "gpt-4o", "gpt5.2": "gpt-5.2", "gpt-5.2": "gpt-5.2", "gpt_5.2": "gpt-5.2", } DEFAULT_IMAGE_DATASET_SYSTEM_PROMPT = """#Role You are a step-by-step reasoning assistant. Given a question, your task is to solve the problem one substep at a time. ## Guiding Principles At each turn, you must either: 1. Issue one specific tool enclosed in tags, 2. Or provide the final answer enclosed in tags. All outputs must begin with a thought enclosed in tags, explaining your current reasoning and what to do next. ## Output Format (strict) Always start with . Do not output the previous reasoning chain. 1. If reasoning continues: Your current reasoning and next plan One precise tool call to assist your reasoning 2. If ready to conclude: Summarize all reasoning and derive the answer Final answer """ GENERAL_VIDEO_DIRECT_SYSTEM_PROMPT = """# Role You are an advanced general video understanding assistant. Given a user query about a video, answer directly from the provided video frames and any supplemental images. # Video Context - The input video has already been converted to 1 frame per second (1 fps). - You are provided with uniformly sampled frames from the video. - Some questions may include supplemental images referenced as , , etc.; these images are provided after the sampled video frames. # Direct Answer Rules 1. Do not call tools. Do not output tool-call tags or tool-like JSON. 2. Use only the provided frames, supplemental images, and your internal knowledge. 3. For multiple-choice questions, put only the option letter in the final answer. 4. For short-answer questions, put only the concise answer in the final answer. 5. If evidence is incomplete, answer with the best supported conclusion. # Output Format (STRICT) You may include brief reasoning in , then provide the final answer in . ... Final answer to the user's query """ DEFAULT_MCQ_BENCHMARKS = { "hr_bench_4k", "hr_bench_8k", "vstar_bench", } class FatalAPIError(Exception): """Raised when search or LLM judge has an error that should stop evaluation.""" pass class URLFetchError(Exception): """Raised when URL fetch fails with a retriable error (timeout, HTTP error, etc.).""" pass class WebFetchStats: """Global tracker for web fetch statistics.""" def __init__(self): self.total = 0 self.successful = 0 self.failed = 0 self.skipped = 0 # Non-HTML, skip extensions self.errors_by_code = {} # HTTP status code -> count def record_success(self): self.total += 1 self.successful += 1 def record_failure(self, error_msg: str = ""): self.total += 1 self.failed += 1 # Extract HTTP status code if present match = re.search(r'HTTP (\d+)', error_msg) if match: code = match.group(1) self.errors_by_code[code] = self.errors_by_code.get(code, 0) + 1 def record_skip(self): self.total += 1 self.skipped += 1 def get_stats(self) -> dict: return { "total": self.total, "successful": self.successful, "failed": self.failed, "skipped": self.skipped, "success_rate": (self.successful / self.total * 100) if self.total > 0 else 0, "errors_by_code": self.errors_by_code, } def format_progress(self) -> str: if self.total == 0: return "" rate = self.successful / self.total * 100 failed_str = f", {self.failed} failed" if self.failed > 0 else "" skipped_str = f", {self.skipped} skipped" if self.skipped > 0 else "" return f"Web: {self.successful}/{self.total} ({rate:.0f}%){failed_str}{skipped_str}" # Global web fetch stats tracker _web_fetch_stats = WebFetchStats() def _truncate_debug_text(text: str, limit: int = 2000) -> str: """清理并截断调试文本,避免错误日志被 base64 图片淹没。""" cleaned = re.sub(r"data:image[^\"'\\s>]*", "[IMAGE]", str(text or "")) return cleaned[:limit] def _summarize_messages_for_debug(messages: list[dict]) -> dict: """统计请求消息规模,帮助定位 400 是否由超长上下文触发。""" summary = { "num_messages": len(messages), "role_counts": {}, "text_chars": 0, "num_image_parts": 0, "num_tool_responses": 0, "last_user_text": "", "last_assistant_text": "", } def _content_to_text(content) -> str: if isinstance(content, str): return content if isinstance(content, list): parts = [] for item in content: if isinstance(item, str): parts.append(item) elif isinstance(item, dict): if item.get("type") == "text": parts.append(item.get("text", "")) elif item.get("type") == "image_url": summary["num_image_parts"] += 1 return "\n".join(parts) return str(content) for msg in messages: role = msg.get("role", "") summary["role_counts"][role] = summary["role_counts"].get(role, 0) + 1 text = _content_to_text(msg.get("content", "")) summary["text_chars"] += len(text) summary["num_tool_responses"] += text.count("") if role == "user" and text: summary["last_user_text"] = _truncate_debug_text(text, limit=400) elif role == "assistant" and text: summary["last_assistant_text"] = _truncate_debug_text(text, limit=400) return summary def _extract_http_error_message(resp_status: int, data: dict, raw_text: str) -> str: """兼容不同 OpenAI 兼容服务的错误格式。""" if isinstance(data, dict): nested = data.get("error") if isinstance(nested, dict) and nested.get("message"): return _truncate_debug_text(nested["message"], limit=500) if isinstance(nested, str) and nested: return _truncate_debug_text(nested, limit=500) for key in ("message", "detail"): value = data.get(key) if value: return _truncate_debug_text(value, limit=500) return _truncate_debug_text(raw_text or f"HTTP {resp_status}", limit=500) def _append_no_proxy_entry(raw_entry: str) -> None: """将本地服务地址追加到 no_proxy / NO_PROXY。""" entry = (raw_entry or "").strip() if not entry: return if "://" in entry: parsed = urlparse(entry) entry = parsed.netloc or parsed.path entry = entry.strip().strip("/") if not entry: return for env_name in ("no_proxy", "NO_PROXY"): current = os.environ.get(env_name, "") items = [item.strip() for item in current.split(",") if item.strip()] if entry not in items: items.append(entry) os.environ[env_name] = ",".join(items) def configure_local_service_no_proxy(raw_url: str) -> None: """为本地或内网服务设置 no_proxy,避免访问时误走代理。""" if not raw_url: return normalized_url = raw_url if "://" in raw_url else f"http://{raw_url}" try: parsed = urlparse(normalized_url) except Exception: return host = (parsed.hostname or "").strip() port = parsed.port if not host: return is_local = host == "localhost" if not is_local: try: ip_obj = ipaddress.ip_address(host) in_cgn = ip_obj in ipaddress.ip_network("100.64.0.0/10") is_local = ip_obj.is_private or ip_obj.is_loopback or in_cgn except ValueError: is_local = False if not is_local: return # 强制包含 `host:port`,对应 AGENTS 中的 `export no_proxy="$no_proxy,:"` 约束。 _append_no_proxy_entry(f"{host}:{port}" if port else host) _append_no_proxy_entry(host) def _openai_api_base(base_url: str) -> str: """Return a normalized OpenAI-compatible API base that ends with /v1.""" base = (base_url or "").rstrip("/") return base if base.endswith("/v1") else f"{base}/v1" def _openai_chat_completions_url(base_url: str) -> str: return f"{_openai_api_base(base_url)}/chat/completions" def _openai_models_url(base_url: str) -> str: return f"{_openai_api_base(base_url)}/models" def _read_secret_file(path: str) -> str: if not path: return "" try: with open(os.path.expanduser(path), "r", encoding="utf-8") as f: return f.read().strip() except FileNotFoundError: return "" def normalize_model_name_for_client(model: str, client: str) -> str: """Normalize common user-facing model aliases before sending requests.""" raw = (model or "").strip() key = raw.lower() if client == "gateway": return GATEWAY_MODEL_NAME_ALIASES.get(key, raw) if client == "openai": return OPENAI_MODEL_NAME_ALIASES.get(key, raw) if client == "vertex" and raw.startswith("google/"): return raw.split("/", 1)[1] return raw def _is_gemini_gateway_model(model: str) -> bool: normalized = normalize_model_name_for_client(model, "gateway").lower() return normalized.startswith("gemini-") def _is_gpt54_gateway_model(model: str) -> bool: return normalize_model_name_for_client(model, "gateway").lower() == "gpt-5.4" def _gateway_model_type(model: str) -> str: if _is_gemini_gateway_model(model): return os.environ.get("MODEL_GATEWAY_GEMINI_MODEL_TYPE", "openai") return os.environ.get("MODEL_GATEWAY_MODEL_TYPE", "openai") def _get_model_gateway_credentials(model: str, api_key: str = "") -> tuple[str, str, str]: if _is_gemini_gateway_model(model): username = ( os.environ.get("MODEL_GATEWAY_GEMINI_USERNAME") or os.environ.get("GEMINI_GATEWAY_USERNAME") or DEFAULT_GEMINI_GATEWAY_USERNAME ) userid = ( os.environ.get("MODEL_GATEWAY_GEMINI_USERID") or os.environ.get("GEMINI_GATEWAY_USERID") or DEFAULT_GEMINI_GATEWAY_USERID ) token = ( os.environ.get("MODEL_GATEWAY_GEMINI_TOKEN") or os.environ.get("GEMINI_GATEWAY_TOKEN", "") or _read_secret_file( os.environ.get("MODEL_GATEWAY_GEMINI_TOKEN_FILE", DEFAULT_GEMINI_GATEWAY_TOKEN_FILE) ) or api_key ) return username, userid, token if _is_gpt54_gateway_model(model): username = ( os.environ.get("MODEL_GATEWAY_GPT54_USERNAME") or os.environ.get("GPT54_GATEWAY_USERNAME") or DEFAULT_GPT54_GATEWAY_USERNAME ) userid = ( os.environ.get("MODEL_GATEWAY_GPT54_USERID") or os.environ.get("GPT54_GATEWAY_USERID") or DEFAULT_GPT54_GATEWAY_USERID ) token = ( os.environ.get("MODEL_GATEWAY_GPT54_TOKEN") or os.environ.get("GPT54_GATEWAY_TOKEN", "") or _read_secret_file( os.environ.get("MODEL_GATEWAY_GPT54_TOKEN_FILE", DEFAULT_GPT54_GATEWAY_TOKEN_FILE) ) or api_key ) return username, userid, token username = ( os.environ.get("MODEL_GATEWAY_USERNAME") or os.environ.get("GATEWAY_USERNAME") or DEFAULT_MODEL_GATEWAY_USERNAME ) userid = ( os.environ.get("MODEL_GATEWAY_USERID") or os.environ.get("GATEWAY_USERID") or DEFAULT_MODEL_GATEWAY_USERID ) token = ( api_key or os.environ.get("MODEL_GATEWAY_TOKEN") or os.environ.get("GATEWAY_TOKEN", "") or _read_secret_file(os.environ.get("MODEL_GATEWAY_TOKEN_FILE", DEFAULT_MODEL_GATEWAY_TOKEN_FILE)) ) return username, userid, token def create_http_session(timeout: aiohttp.ClientTimeout) -> aiohttp.ClientSession: """创建 HTTP 会话:外部服务走环境代理,本地服务依赖 no_proxy 直连。""" return aiohttp.ClientSession(timeout=timeout, trust_env=True) class SearchCache: """Simple SQLite cache for search results.""" def __init__(self, cache_dir: str): self.db_path = os.path.join(cache_dir, "search_cache.db") os.makedirs(cache_dir, exist_ok=True) self._init_db() self.hits = 0 self.misses = 0 # Print initial stats count = self._get_entry_count() print(f"Search cache: {self.db_path} ({count} entries)") # Register cleanup on exit atexit.register(self.close) def _init_db(self): conn = sqlite3.connect(self.db_path) try: conn.execute("PRAGMA journal_mode=WAL") conn.execute(""" CREATE TABLE IF NOT EXISTS cache ( key TEXT PRIMARY KEY, value TEXT NOT NULL, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ) """) finally: conn.close() def _get_entry_count(self) -> int: conn = sqlite3.connect(self.db_path) try: count = conn.execute("SELECT COUNT(*) FROM cache").fetchone()[0] finally: conn.close() return count def _iter_seed_db_paths(self, seed_paths: list[str]) -> list[str]: """Find historical search_cache.db files under files or directories.""" own_db = os.path.realpath(self.db_path) db_paths = [] seen = set() for seed_path in seed_paths: if not seed_path: continue expanded = os.path.realpath(os.path.expanduser(seed_path)) if os.path.isfile(expanded): candidates = [expanded] elif os.path.isdir(expanded): candidates = [] for root, _, files in os.walk(expanded): if "search_cache.db" in files: candidates.append(os.path.join(root, "search_cache.db")) else: print(f"[CACHE SEED] Skip missing path: {seed_path}", flush=True) continue for candidate in candidates: real_candidate = os.path.realpath(candidate) if real_candidate == own_db or real_candidate in seen: continue seen.add(real_candidate) db_paths.append(real_candidate) return sorted(db_paths) def seed_from_paths(self, seed_paths: list[str]) -> dict: """Copy historical cache rows into this run cache without overwriting local rows.""" db_paths = self._iter_seed_db_paths(seed_paths) stats = { "sources": len(db_paths), "inserted": 0, "existing": 0, "skipped": 0, } if not db_paths: print("[CACHE SEED] No historical search_cache.db found.", flush=True) return stats dest = sqlite3.connect(self.db_path, timeout=30.0) try: for db_path in db_paths: try: src = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True, timeout=10.0) try: columns = { row[1] for row in src.execute("PRAGMA table_info(cache)").fetchall() } if not {"key", "value"}.issubset(columns): print(f"[CACHE SEED] Skip incompatible cache: {db_path}", flush=True) stats["skipped"] += 1 continue if "created_at" in columns: rows = src.execute("SELECT key, value, created_at FROM cache").fetchall() else: rows = [ (key, value, None) for key, value in src.execute("SELECT key, value FROM cache").fetchall() ] finally: src.close() before = dest.total_changes dest.executemany( "INSERT OR IGNORE INTO cache (key, value, created_at) VALUES (?, ?, COALESCE(?, CURRENT_TIMESTAMP))", rows, ) dest.commit() inserted = dest.total_changes - before stats["inserted"] += inserted stats["existing"] += max(0, len(rows) - inserted) print( f"[CACHE SEED] {db_path}: {inserted} inserted, " f"{max(0, len(rows) - inserted)} already present", flush=True, ) except sqlite3.Error as e: print(f"[CACHE SEED] Skip unreadable cache {db_path}: {e}", flush=True) stats["skipped"] += 1 finally: dest.close() total = self._get_entry_count() print( f"[CACHE SEED] Done: {stats['inserted']} inserted from {stats['sources']} source db(s); " f"{stats['existing']} duplicate row(s); current cache has {total} entries.", flush=True, ) return stats def _make_key(self, query: str, top_k: int, model: str) -> str: # Create normalized cache key normalized = re.sub(r'\s+', ' ', query.strip().lower()) parts = [ f"q={normalized}", f"k={top_k}", f"model={model}", f"prompt=mmsearch_r1", f"jina=0", f"think=0" ] return hashlib.md5("|".join(parts).encode("utf-8")).hexdigest() async def get(self, query: str, top_k: int, model: str) -> Optional[str]: key = self._make_key(query, top_k, model) conn = sqlite3.connect(self.db_path) try: row = conn.execute("SELECT value FROM cache WHERE key = ?", (key,)).fetchone() finally: conn.close() if row: self.hits += 1 print(f"[CACHE HIT] {query[:50]}...") # Cache stores JSON {"summaries": "..."}, extract summaries field try: data = json.loads(row[0]) if isinstance(data, dict) and "summaries" in data: return data["summaries"] return row[0] # Fallback: return as-is if not JSON format except json.JSONDecodeError: return row[0] # Plain string, return as-is self.misses += 1 return None async def set(self, query: str, top_k: int, model: str, value: str): # Store in JSON format key = self._make_key(query, top_k, model) data = json.dumps({"summaries": value}) for attempt in range(5): conn = sqlite3.connect(self.db_path, timeout=5.0) try: conn.execute("INSERT OR REPLACE INTO cache (key, value) VALUES (?, ?)", (key, data)) conn.commit() print(f"[CACHE STORE] {query[:50]}...") return except sqlite3.OperationalError as e: if "locked" in str(e).lower() and attempt < 4: await asyncio.sleep(0.05 * (2 ** attempt)) else: raise finally: conn.close() def get_stats(self) -> dict: total = self.hits + self.misses return { "hits": self.hits, "misses": self.misses, "total": total, "hit_rate": (self.hits / total * 100) if total > 0 else 0.0, } def close(self): """Checkpoint WAL to prevent corruption on shutdown.""" try: conn = sqlite3.connect(self.db_path, timeout=10.0) try: conn.execute("PRAGMA wal_checkpoint(TRUNCATE)") finally: conn.close() except Exception: pass # Best effort on shutdown def resolve_cache_seed_paths(cli_paths: Optional[list[str]], no_auto_seed: bool) -> list[str]: """Resolve cache seed roots from CLI, env, or the default inference/runs tree.""" if no_auto_seed: return [] if cli_paths is not None: return list(cli_paths) env_seed_roots = os.environ.get("SEARCH_CACHE_SEED_ROOTS", "").strip() if env_seed_roots: return [ item.strip() for item in re.split(r"[:;,]", env_seed_roots) if item.strip() ] return [os.path.join(os.path.dirname(__file__), "runs")] def seed_image_search_cache(image_search_cache: ImageSearchCache, seed_paths: list[str]) -> dict: """Merge historical image_search_cache.json files into this run cache.""" own_file = os.path.realpath(image_search_cache.cache_file) cache_files = [] seen = set() for seed_path in seed_paths: if not seed_path: continue expanded = os.path.realpath(os.path.expanduser(seed_path)) if os.path.isfile(expanded): candidates = [expanded] if os.path.basename(expanded) == "image_search_cache.json" else [] elif os.path.isdir(expanded): candidates = [] for root, _, files in os.walk(expanded): if "image_search_cache.json" in files: candidates.append(os.path.join(root, "image_search_cache.json")) else: continue for candidate in candidates: real_candidate = os.path.realpath(candidate) if real_candidate == own_file or real_candidate in seen: continue seen.add(real_candidate) cache_files.append(real_candidate) stats = { "sources": len(cache_files), "inserted": 0, "existing": 0, "skipped": 0, } if not cache_files: print("[IMAGE CACHE SEED] No historical image_search_cache.json found.", flush=True) return stats for cache_file in sorted(cache_files): try: with open(cache_file, "r", encoding="utf-8") as f: data = json.load(f) if not isinstance(data, dict): print(f"[IMAGE CACHE SEED] Skip incompatible cache: {cache_file}", flush=True) stats["skipped"] += 1 continue inserted = 0 existing = 0 for key, value in data.items(): if key in image_search_cache.cache: existing += 1 else: image_search_cache.cache[key] = value inserted += 1 stats["inserted"] += inserted stats["existing"] += existing print( f"[IMAGE CACHE SEED] {cache_file}: {inserted} inserted, {existing} already present", flush=True, ) except Exception as e: print(f"[IMAGE CACHE SEED] Skip unreadable cache {cache_file}: {e}", flush=True) stats["skipped"] += 1 if stats["inserted"]: image_search_cache.save() print( f"[IMAGE CACHE SEED] Done: {stats['inserted']} inserted from {stats['sources']} file(s); " f"{stats['existing']} duplicate row(s); current cache has {len(image_search_cache.cache)} entries.", flush=True, ) return stats TOOL_ABLATION_PROFILES = ("none", "nosearch", "nolocation") IMAGE_BENCHMARK_TOOL_NAMES = {"zoom_in", "image_search", "web_search"} VIDEO_DR_TOOL_NAMES = {"choose_frames", "find_frame", "zoom_in", "image_search", "web_search"} TOOL_NAME_ALIASES = { "image_zoom_in_tool": "zoom_in", "image_search_tool": "image_search", "text_search_tool": "web_search", } def canonicalize_tool_name(tool_name: str) -> str: return TOOL_NAME_ALIASES.get(tool_name, tool_name) def get_allowed_tool_names(profile: str, task_kind: str) -> set[str]: profile = profile or "none" if task_kind == "video_dr": if profile == "nosearch": return {"choose_frames", "find_frame", "zoom_in"} if profile == "nolocation": return {"find_frame", "image_search", "web_search"} return set(VIDEO_DR_TOOL_NAMES) if profile == "nosearch": return {"zoom_in"} if profile == "nolocation": return {"image_search", "web_search"} return set(IMAGE_BENCHMARK_TOOL_NAMES) def format_tool_names(tool_names: set[str]) -> str: ordered = [ name for name in ["choose_frames", "find_frame", "zoom_in", "image_search", "web_search"] if name in tool_names ] return ", ".join(f"`{name}`" for name in ordered) def build_disallowed_tool_response(tool_name: str, allowed_tool_names: set[str]) -> str: return ( "\n" f"Error: tool `{tool_name}` is not available in this evaluation profile. " f"Use only these tools: {format_tool_names(allowed_tool_names)}.\n" "" ) def load_tool_config( tool_config_path: str, allowed_tool_names: Optional[set[str]] = None, normalize_image_schema: bool = False, normalize_video_schema: bool = False, ) -> str: with open(tool_config_path) as f: config = yaml.safe_load(f) tools = config.get("tools", []) tool_definitions = [] loaded_names = [] skipped_names = [] for tool in tools: schema = tool.get("tool_schema", {}) tool_name = schema.get("function", {}).get("name", "") if allowed_tool_names is not None and tool_name not in allowed_tool_names: skipped_names.append(tool_name or "") continue if normalize_image_schema or normalize_video_schema: schema = json.loads(json.dumps(schema)) if normalize_image_schema: _normalize_image_tool_schema(schema) if normalize_video_schema: _normalize_video_tool_schema(schema) tool_definitions.append(json.dumps(schema, ensure_ascii=False)) loaded_names.append(tool_name or "") print(f" Loaded {len(tool_definitions)} tools: {', '.join(loaded_names)}") if skipped_names: print(f" Skipped tools for this benchmark: {', '.join(skipped_names)}") # Build # Tools section tool_def_str = "\n".join(tool_definitions) tools_section = f"""# Tools You may call one or more functions to assist with the user query. You are provided with function signatures within XML tags: {tool_def_str} For each function call, return a json object with function name and arguments within XML tags: {{"name": , "arguments": }} """ return tools_section def _normalize_image_tool_schema(schema: dict) -> None: """把图像 benchmark 工具描述规范成单图三工具语义。""" function = schema.get("function", {}) tool_name = function.get("name", "") if tool_name == "zoom_in": function["description"] = "Magnify a specific region of the input image for detailed visual analysis." elif tool_name == "image_search": function["description"] = ( "Reverse image search a specific region of the input image to identify unknown entities. " "Use this if you DO NOT understand what the zoomed-in entity is." ) properties = function.get("parameters", {}).get("properties", {}) bbox_schema = properties.get("bbox") if isinstance(bbox_schema, dict): bbox_schema["description"] = ( "Bounding box [x1, y1, x2, y2] in 0-1000 relative coordinates of the original input image." ) def _normalize_video_tool_schema(schema: dict) -> None: """把视频 benchmark 工具描述规范成 VideoDR 评测的 0-1000 bbox 语义。""" function = schema.get("function", {}) tool_name = function.get("name", "") if tool_name == "zoom_in": function["description"] = ( "Magnify a specific region of the frame currently locked by find_frame for detailed visual analysis." ) elif tool_name == "image_search": function["description"] = ( "Reverse image search a specific region of the locked frame to identify unknown entities. " "Use this if you DO NOT understand what the entity in the locked frame is." ) properties = function.get("parameters", {}).get("properties", {}) bbox_schema = properties.get("bbox") if isinstance(bbox_schema, dict): bbox_schema["items"] = {"type": "integer"} bbox_schema["minItems"] = 4 bbox_schema["maxItems"] = 4 bbox_schema["description"] = ( "Bounding box [x1, y1, x2, y2] in 0-1000 relative coordinates on the frame currently locked by find_frame." ) def build_image_tool_system_prompt( task_instruction: str, tools_section: str, max_turns: int = 10, allowed_tool_names: Optional[set[str]] = None, ) -> str: """构造 image benchmark 的 DeepResearch 风格 tool system prompt。""" task_instruction = (task_instruction or "").strip() allowed_tool_names = set(allowed_tool_names or IMAGE_BENCHMARK_TOOL_NAMES) task_block = "" if task_instruction: task_block = f""" # Task-Specific Instruction {task_instruction} """ can_zoom = "zoom_in" in allowed_tool_names can_image_search = "image_search" in allowed_tool_names can_web_search = "web_search" in allowed_tool_names if can_image_search or can_web_search: role_goal = ( "combining visual clues from a single image with available external search evidence" ) else: role_goal = "using the visual evidence in a single image without external search tools" if can_zoom: if can_image_search or can_web_search: initial_action = ( "Start from the full image. If the key evidence is already clear, answer directly or use an available search tool. " "Otherwise, use `zoom_in(bbox)` to inspect a relevant region more closely." ) else: initial_action = ( "Start from the full image. If the key evidence is already clear, answer directly. " "Otherwise, use `zoom_in(bbox)` to inspect a relevant region more closely." ) region_rule = ( "`zoom_in` is used for detailed local inspection. After examining a region, continue with another " "allowed tool only when it adds new evidence." ) else: initial_action = ( "Start from the full image. `zoom_in` is not available in this evaluation profile, so choose any bbox " "for `image_search` directly from the original image when identification is needed." ) region_rule = ( "No zoomed crop will be returned. Use the original image and the available search tools to gather evidence." ) search_rules = [] if can_image_search: search_rules.extend([ "Do NOT rely on a visual guess, memory, or resemblance to identify a real-world person, place, artwork, product, logo, team, landmark, media title, or scientific object.", "Before using a visually inferred entity name in `web_search` or in the final answer, confirm that entity with `image_search(bbox)` on the most diagnostic region.", "If the entity name is directly readable from text in the image, verify the text carefully with the full image or `zoom_in` when available; otherwise use `image_search` for identity confirmation.", "If `image_search` contradicts your initial guess, discard the guess and reason from the search evidence.", ]) elif can_web_search: search_rules.extend([ "`image_search` is not available. Do not invent an entity name from appearance alone; use `web_search(query)` only when the query can be grounded in readable text or unambiguous visual evidence.", "If the visual identity is uncertain and no image search is available, state the uncertainty in `` and answer from the best supported evidence.", ]) else: search_rules.extend([ "`image_search` and `web_search` are not available in this evaluation profile.", "Answer from the image, the question, and any knowledge already contained in the model. Do not call search tools or claim that external verification was performed.", ]) search_rules_text = "\n".join(f" - {line}" for line in search_rules) loop_rule_parts = [] if can_zoom: loop_rule_parts.append("Do not spend many turns only zooming.") if can_image_search or can_web_search: loop_rule_parts.append("If the current evidence is not sufficient, switch to another allowed evidence-gathering tool or provide the best supported answer.") else: loop_rule_parts.append("If zooming no longer adds visual evidence, stop using tools and answer.") loop_rule = " ".join(loop_rule_parts) bbox_tools = [name for name in ["zoom_in", "image_search"] if name in allowed_tool_names] if bbox_tools: bbox_rule = ( f"Whenever you use {format_tool_names(set(bbox_tools))}, provide a precise bbox in 0-1000 " "relative coordinates that tightly covers the relevant region in the original input image." ) else: bbox_rule = "No bbox-based inspection tool is available in this evaluation profile." return f"""# Role You are an advanced Image DeepResearch reasoning assistant. Given a user query that requires {role_goal}, your task is to solve the problem step-by-step while using only the tools exposed below. # Image Context - You are given one input image. - The image may contain multiple entities, regions, text snippets, or visual clues relevant to the question. - You should first reason over the full image, then inspect specific regions only when additional detail is necessary. {tools_section} # Tool Dependency & Workflow Rules (STRICT) 1. **Allowed Tools:** You may only call {format_tool_names(allowed_tool_names)}. Do not call any tool that is not listed in the `` section. 2. **Initial Action:** {initial_action} 3. **Region-Based Analysis:** {region_rule} 4. **Search Strategy Selection:** {search_rules_text} 5. **Avoid Tool Loops:** {loop_rule} 6. **Bounding Boxes:** {bbox_rule} 7. **Retry Mechanism:** If the current region or search result is not sufficient, inspect another region or try another search. You have a MAXIMUM OF {max_turns} ATTEMPTS (loops) to find the answer. If all attempts are exhausted, provide the best answer you have with an explicit note about remaining uncertainty. {task_block} # Output Format (STRICT) At each turn, you must either issue ONE precise tool call OR provide the final answer. All outputs MUST begin with a thought process enclosed in tags. 1. If reasoning continues: ... {{"name": "", "arguments": }} 2. If ready to conclude (after gathering sufficient information): ... Final answer to the user's query """ def build_video_tool_system_prompt( tools_section: str, allowed_tool_names: set[str], max_turns: int = 10, ) -> str: """构造 VideoDR 工具消融 profile 的 system prompt。""" can_choose = "choose_frames" in allowed_tool_names can_zoom = "zoom_in" in allowed_tool_names can_image_search = "image_search" in allowed_tool_names can_web_search = "web_search" in allowed_tool_names if can_choose: initial_action = ( "You can directly call `find_frame` if the target is obvious in the initial 64 sparse frames. " "Otherwise, call `choose_frames` to narrow down the search interval." ) retry_action = "loop back to `choose_frames` or `find_frame` to explore another segment/frame" else: initial_action = ( "`choose_frames` is not available. Use the initial 64 sparse frames to select the most relevant " "frame index, then call `find_frame` to lock onto that frame." ) retry_action = "call `find_frame` on another candidate frame from the initial sparse frames" if can_choose: interval_rule = ( "After calling `choose_frames` and receiving the sub-frames, you MUST call `find_frame` to lock " "onto a specific single frame before performing any detailed actions." ) else: interval_rule = ( "Since `choose_frames` is not available, do not ask for an interval. Move from the initial sparse " "frames directly to `find_frame`." ) detail_actions = [] if can_zoom: detail_actions.append("use `zoom_in` to inspect local details") if can_image_search: detail_actions.append("use `image_search` to identify an unknown entity") if can_web_search: detail_actions.append("use `web_search` for external facts after the entity or clue is grounded") if detail_actions: detail_rule = ( "`find_frame` locks the working frame. After a successful `find_frame`, you may " + "; ".join(detail_actions) + "." ) else: detail_rule = ( "`find_frame` locks the working frame. No detail or search tool is available after that, so answer " "from the locked frame and the initial sparse frames." ) search_rules = [] if can_image_search: search_rules.append("If you visually locate an entity but do not know its name or identity, use `image_search(bbox)`.") else: search_rules.append("`image_search` is not available; do not claim reverse-image verification.") if can_web_search: search_rules.append("If you visually locate and recognize the entity, use `web_search(query)` for related external knowledge or facts.") else: search_rules.append("`web_search` is not available; do not call text search or claim web verification.") if not can_image_search and not can_web_search: search_rules.append("For this no-search profile, answer only from video evidence and model-internal knowledge.") search_rules_text = "\n".join(f" - {line}" for line in search_rules) bbox_tools = [name for name in ["zoom_in", "image_search"] if name in allowed_tool_names] if bbox_tools: bbox_rule = ( f"Whenever you use {format_tool_names(set(bbox_tools))}, provide a precise bbox in 0-1000 " "relative coordinates on the frame currently locked by `find_frame`." ) else: bbox_rule = "No bbox-based tool is available in this evaluation profile." return f"""# Role You are an advanced Video DeepResearch reasoning assistant. Given a user query that requires combining visual clues from a video with the tools exposed below, your task is to solve the problem step-by-step by deeply navigating the video. # Video Context - The original video has been converted to 1 frame per second (1 fps). - The original frame index (e.g., Frame 1, Frame 10) is watermarked in the top-left corner of every frame. - You are initially provided with 64 uniformly sampled frames from the video. {tools_section} # Tool Dependency & Workflow Rules (STRICT) 1. **Allowed Tools:** You may only call {format_tool_names(allowed_tool_names)}. Do not call any tool that is not listed in the `` section. 2. **Initial Action:** {initial_action} 3. **Interval to Frame:** {interval_rule} 4. **Detailing & Searching:** {detail_rule} 5. **Search Strategy Selection:** {search_rules_text} 6. **Bounding Boxes:** {bbox_rule} 7. **Retry Mechanism:** If your current evidence fails, yields incorrect info, or does not solve the query, you can {retry_action}. You have a MAXIMUM OF {max_turns} ATTEMPTS (loops) to find the answer. If all attempts are exhausted, provide the best answer you have with an explicit note about remaining uncertainty. # Output Format (STRICT) At each turn, you must either issue ONE precise tool call OR provide the final answer. All outputs MUST begin with a thought process enclosed in tags. 1. If reasoning continues: ... {{"name": "", "arguments": }} 2. If ready to conclude: ... Final answer to the user's query """ # ============================================================================= # Image Processing # ============================================================================= def round_by_factor(number: int, factor: int) -> int: return round(number / factor) * factor def ceil_by_factor(number: int, factor: int) -> int: return math.ceil(number / factor) * factor def floor_by_factor(number: int, factor: int) -> int: return math.floor(number / factor) * factor def smart_resize( height: int, width: int, factor: int = 32, min_pixels: int = 65536, max_pixels: int = 8294400, ) -> tuple[int, int]: """Resize dimensions to be divisible by factor and within pixel limits.""" if max(height, width) / min(height, width) > 200: raise ValueError("Aspect ratio too extreme") h_bar = max(factor, round_by_factor(height, factor)) w_bar = max(factor, round_by_factor(width, factor)) if h_bar * w_bar > max_pixels: beta = math.sqrt((height * width) / max_pixels) h_bar = floor_by_factor(int(height / beta), factor) w_bar = floor_by_factor(int(width / beta), factor) elif h_bar * w_bar < min_pixels: beta = math.sqrt(min_pixels / (height * width)) h_bar = ceil_by_factor(int(height * beta), factor) w_bar = ceil_by_factor(int(width * beta), factor) return h_bar, w_bar def process_image( image: Image.Image, min_pixels: int = 65536, max_pixels: int = 8294400, factor: int = 32, qwen_vl_processing: bool = True, ) -> Image.Image: """Process image with smart_resize (Qwen-VL style) or simple max_pixels resize.""" if image.mode != "RGB": image = image.convert("RGB") width, height = image.size if qwen_vl_processing: # Qwen-VL style: align to factor and respect min/max pixels resized_height, resized_width = smart_resize(height, width, factor, min_pixels, max_pixels) else: # Simple resize: just ensure within max_pixels if width * height > max_pixels: scale = math.sqrt(max_pixels / (width * height)) resized_width = int(width * scale) resized_height = int(height * scale) else: resized_width, resized_height = width, height if (resized_width, resized_height) != (width, height): image = image.resize((resized_width, resized_height), Image.BICUBIC) return image def load_and_process_image( path: str, min_pixels: int = 65536, max_pixels: int = 8294400, factor: int = 32, qwen_vl_processing: bool = True, ) -> Image.Image: """Load image from path and process it.""" img = Image.open(path) img.load() return process_image(img, min_pixels, max_pixels, factor, qwen_vl_processing) def image_to_base64(img: Image.Image) -> tuple[str, str]: """Convert PIL Image to base64 string.""" if img.mode not in ("RGB", "RGBA"): img = img.convert("RGB") buffer = io.BytesIO() img.save(buffer, format="PNG") base64_str = base64.standard_b64encode(buffer.getvalue()).decode("utf-8") return base64_str, "image/png" def crop_frame_to_rgb_jpeg(frame_path: str, bbox: list[float], output_path: str, quality: int = 90) -> None: """按归一化 bbox 裁剪图片,并强制以 RGB JPEG 保存,兼容 RGBA 输入。""" with Image.open(frame_path) as img: img.load() if img.mode != "RGB": img = img.convert("RGB") width, height = img.size x1, y1, x2, y2 = bbox crop_box = ( max(0, min(width, int(round(x1 * width)))), max(0, min(height, int(round(y1 * height)))), max(0, min(width, int(round(x2 * width)))), max(0, min(height, int(round(y2 * height)))), ) if crop_box[2] <= crop_box[0] or crop_box[3] <= crop_box[1]: print( f"[WARN] Invalid crop bbox after scaling: {bbox}; using full image instead", flush=True, ) crop_box = (0, 0, width, height) cropped = img.crop(crop_box) os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True) cropped.save(output_path, "JPEG", quality=quality) def crop_image( image: Image.Image, bbox: list, coord_scale: float = 1000.0, min_pixels: int = 65536, max_pixels: int = 8294400, factor: int = 32, qwen_vl_processing: bool = True, padding: tuple = (0.0, 0.0), # Aligned with training config (no padding) ) -> Image.Image: """Crop image using bbox [x1, y1, x2, y2] in 0-1000 coords. - Supports configurable padding (default: no padding) - Padding is capped at 600px - Uses smart_resize for output """ img_w, img_h = image.size # Normalize to 0-1 range x1, y1, x2, y2 = [float(c) / coord_scale for c in bbox] # Clamp to valid range before padding x1, y1 = max(0, x1), max(0, y1) x2, y2 = min(1, x2), min(1, y2) # Apply padding if specified (capped at 600px) if padding[0] > 0 or padding[1] > 0: padding_cap = (600.0 / img_w, 600.0 / img_h) actual_padding = ( min(padding[0], padding_cap[0]), min(padding[1], padding_cap[1]), ) x1 = max(0.0, x1 - actual_padding[0]) y1 = max(0.0, y1 - actual_padding[1]) x2 = min(1.0, x2 + actual_padding[0]) y2 = min(1.0, y2 + actual_padding[1]) crop_box = (int(x1 * img_w), int(y1 * img_h), int(x2 * img_w), int(y2 * img_h)) cropped = image.crop(crop_box) if qwen_vl_processing: # Qwen-VL: ensure minimum size and apply smart_resize w, h = cropped.size if w < 28 or h < 28: cropped = cropped.resize((max(w, 28), max(h, 28)), Image.Resampling.LANCZOS) return process_image(cropped, min_pixels, max_pixels, factor, qwen_vl_processing) else: # No processing: just convert to RGB if cropped.mode != "RGB": cropped = cropped.convert("RGB") return cropped # ============================================================================= # Answer Extraction (em_score_mcq) # ============================================================================= def extract_mcq_answer(text: str) -> Optional[str]: """Extract MCQ answer (A, B, C, D) from model output.""" if not text: return None # Preprocess text = text.rsplit("<|im_start|>assistant", 1)[-1] text = re.split(r'', text)[-1] # Try tags first matches = list(re.finditer(r'(.*?)', text, re.DOTALL)) if matches: candidate = matches[-1].group(1).strip() # Single letter if re.match(r'^[A-Da-d]$', candidate): return candidate.upper() # Letter with punctuation punct = re.findall(r'(?:\(([A-D])\)|\[([A-D])\]|(? tags match = re.search(r'\\boxed\{([^}]+)\}', text, re.IGNORECASE) if match: boxed = match.group(1).strip() if re.search(r'[A-Da-d]', boxed): # Single letter if re.match(r'^[A-Da-d]$', boxed): return boxed.upper() # Letter with punctuation punct = re.findall(r'(?:\(([A-D])\)|\[([A-D])\]|(? return None (skip fallback) return None # Fallback patterns on full text (only when no tag and no \boxed{} with a-d) # "Answer: (A)" answer_matches = re.findall(r'Answer:\s*\(([A-D])\)', text, re.IGNORECASE) if answer_matches: return answer_matches[-1].upper() # "answer is (A)" phrase = re.findall(r'(?:correct answer is|answer is)[:\s]*\(([A-D])\)', text, re.IGNORECASE) if phrase: return phrase[-1].upper() # Bold **(A)** bold_segments = re.findall(r'\*\*[^*]+\*\*', text) for seg in reversed(bold_segments): m = re.search(r'\(([A-D])\)', seg, re.IGNORECASE) if m: return m.group(1).upper() # (A) in parentheses paren = re.findall(r'\(([A-D])\)', text, re.IGNORECASE) if paren: return paren[-1].upper() # A) or A] format bracket = re.findall(r'(? bool: """Check MCQ answer against ground truth(s).""" if extracted is None: return False # Handle list of ground truths if isinstance(ground_truth, str): ground_truth = [ground_truth] extracted_upper = extracted.upper() for gt in ground_truth: if extracted_upper == gt.upper(): return True return False # ============================================================================= # LLM Judge (llm_score) # ============================================================================= def is_answer_acceptable(generated: str, gt_answer: str) -> bool: """Check deterministic exact answer equivalence before using LLM judge.""" if not generated or not gt_answer: return False gen_lower = generated.strip().lower() gt_lower = gt_answer.strip().lower() return gen_lower == gt_lower def _normalize_answer_for_em(s: str) -> str: """Normalize answer for EM comparison (aligned with reference code).""" def remove_articles(text): if text.strip().lower() in ["a"]: return text return re.sub(r"\b(a|an|the)\b", " ", text) def white_space_fix(text): return " ".join(text.split()) def remove_punc(text): exclude = set(string.punctuation) return "".join(ch for ch in text if ch not in exclude) return white_space_fix(remove_articles(remove_punc(s.lower()))) def _em_check(prediction: str, golden_answer: str) -> bool: """Check if prediction matches golden answer after normalization.""" if not prediction or not golden_answer: return False return _normalize_answer_for_em(prediction) == _normalize_answer_for_em(golden_answer) def extract_answer_allow_no_tag(text: str) -> str: """Extract answer from tags, fallback to full response.""" if not text: return "" # Preprocess: remove assistant prefix and thinking tags text = text.rsplit("<|im_start|>assistant", 1)[-1] # Try tags first on the final assistant turn. Accept malformed # final answers that place inside ..., but do not # score answer-like text from a tool-call turn. matches = list(re.finditer(r'(.*?)', text, re.DOTALL)) if matches and "" not in text: return matches[-1].group(1).strip() text = re.split(r'', text)[-1] # Fallback: return entire response (stripped at <|im_end|> if present) return text.split('<|im_end|>')[0].strip() # System prompt for LLM judge LLM_JUDGE_SYSTEM_PROMPT = """You are an AI assistant tasked with evaluating the correctness of model responses based on a question and ground truth answer. Your judgment should follow these principles: 1. Consider the question and ground truth answer holistically before evaluating the model's response. 2. Your decision should be strictly Yes or No, based on whether the model's response is factually accurate and aligns with the ground truth answer. 3. If the model response is a more specific form of the ground truth answer, it is correct. 4. If the model response includes all key information but adds minor details, it is correct as long as the extra details are factually correct. 5. If the model response contradicts, modifies, or omits critical parts of the answer, it is incorrect. 6. For numerical values, ensure correctness even when presented in different units. 7. For names, check for first and last name correctness. If the middle name is extra but correct, consider it correct. 8. For yes/no questions, the response must exactly match "Yes" or "No" to be correct. 9. If the model response is a common abbreviation, nickname, or alternative name for the ground truth answer, it is correct. 10. If there are multiple candidate answers, evaluate the model's response against all of them. If the response aligns with at least one candidate according to the rules above, it should be considered correct. 11. For multiple choice questions (A, B, C, D), be more lenient. If the model provides the correct letter choice, even with additional text or formatting, consider it correct. 12. If the model's answer contains the correct choice letter (A, B, C, or D) anywhere in the response, and it's clear this is the intended answer, mark it as correct. 13. Ignore formatting issues like extra parentheses, brackets, or minor text variations as long as the core answer is correct. Your output must be in the following format: Yes/No Explanation of why the answer is correct or incorrect.""" LLM_JUDGE_USER_PROMPT = """Question and Model Response Evaluation Question: {question} Ground Truth Answer: {ground_truth} Model Response: {model_response} Evaluation Instructions Evaluate whether the Model Response is correct based on the Question and Ground Truth Answer. Follow the predefined judgment rules and provide a clear Yes/No answer along with a justification. Output Format Yes/No Detailed reasoning following the evaluation principles.""" async def llm_judge_acceptable( generated: str, gt_answer: str, question: str, session: aiohttp.ClientSession, max_retries: int = 3, ) -> bool: """用 MARS summarizer 判定生成答案是否正确。""" if not generated or not gt_answer: return False if _em_check(generated, gt_answer): return True if is_answer_acceptable(generated, gt_answer): return True addr = MARS_SUMMARIZER_ADDRESS model = MARS_SUMMARIZER_MODEL if not addr or not model: print("[WARN] MARS_SUMMARIZER not configured, falling back to is_answer_acceptable()") return False user_prompt = LLM_JUDGE_USER_PROMPT.format( question=question, ground_truth=gt_answer, model_response=generated, ) payload = { "model": model, "messages": [ {"role": "system", "content": LLM_JUDGE_SYSTEM_PROMPT}, {"role": "user", "content": user_prompt}, ], "max_tokens": 1024, "temperature": 0.0, "chat_template_kwargs": {"enable_thinking": False}, } for attempt in range(max_retries): try: async with session.post( f"http://{addr}/v1/chat/completions", json=payload, headers={"Content-Type": "application/json"}, timeout=aiohttp.ClientTimeout(total=120), proxy="", ) as resp: if resp.status != 200: text = await resp.text() print(f" [JUDGE] HTTP {resp.status}: {text[:200]}") if attempt < max_retries - 1: await asyncio.sleep(1) continue return False data = await resp.json() choices = data.get("choices", []) if choices: content = choices[0].get("message", {}).get("content", "") content = re.sub(r".*?", "", content, flags=re.DOTALL).strip() match = re.search(r"\s*(Yes|No)\s*", content, re.IGNORECASE | re.DOTALL) if match: return match.group(1).lower() == "yes" return False except asyncio.TimeoutError: print(f" [JUDGE] Timeout (attempt {attempt + 1}/{max_retries})") if attempt < max_retries - 1: await asyncio.sleep(1) except Exception as e: print(f" [JUDGE] Error: {e} (attempt {attempt + 1}/{max_retries})") if attempt < max_retries - 1: await asyncio.sleep(1) return False async def llm_judge_score( question: str, model_answer: str, ground_truth: list | str, image_path: str, judge_client: str, judge_base_url: str, judge_api_key: str, judge_temperature: float = 0.0, shared_session: aiohttp.ClientSession | None = None, ) -> float: """Use VideoDR reference LLM judge logic. Returns 1.0 or 0.0.""" del image_path, judge_client, judge_base_url, judge_api_key, judge_temperature extracted_answer = extract_answer_allow_no_tag(model_answer) if isinstance(ground_truth, str): ground_truth = [ground_truth] if shared_session is not None: for gt in ground_truth: if await llm_judge_acceptable(extracted_answer, gt, question, shared_session): return 1.0 return 0.0 timeout = aiohttp.ClientTimeout(total=3600) async with create_http_session(timeout) as session: for gt in ground_truth: if await llm_judge_acceptable(extracted_answer, gt, question, session): return 1.0 return 0.0 # ============================================================================= # Data Loading # ============================================================================= def _read_first_jsonl_record(path: str) -> dict: """读取 JSONL 的第一条有效样本,用于推断 benchmark 类型。""" with open(path) as f: for line in f: line = line.strip() if line: return json.loads(line) return {} def _count_jsonl_records(path: str) -> int: """统计 JSONL 样本数,写入自动生成的数据集配置。""" total = 0 with open(path) as f: for line in f: if line.strip(): total += 1 return total def _prompt_text_from_raw_record(raw: dict) -> str: """从原始样本中取 user prompt 文本。""" prompt = raw.get("prompt", []) if isinstance(prompt, list): for msg in prompt: if isinstance(msg, dict) and msg.get("role") == "user": return str(msg.get("content", "")) return "" def _ground_truths_from_raw_record(raw: dict) -> list[str]: """从原始样本中取 ground_truth 列表。""" reward_model = raw.get("reward_model", {}) ground_truth = reward_model.get("ground_truth", []) if isinstance(ground_truth, str): return [ground_truth] if isinstance(ground_truth, list): return [str(item) for item in ground_truth] return [] def _looks_like_mcq_record(raw: dict) -> bool: """根据首条样本判断是否为 A-D 多选题。""" ground_truths = _ground_truths_from_raw_record(raw) if not ground_truths: return False if not all(re.fullmatch(r"[A-Da-d]", gt.strip()) for gt in ground_truths): return False prompt_text = _prompt_text_from_raw_record(raw) return bool(re.search(r"\([A-Da-d]\)", prompt_text)) def build_eval_root_dataset_config( eval_root: str, benchmarks: Optional[list[str]] = None, ) -> dict: """从 `data/eval` 子目录自动构造 `eval.py` 可用的数据集配置。""" root = os.path.abspath(eval_root or DEFAULT_EVAL_ROOT) if not os.path.isdir(root): raise FileNotFoundError(f"eval root not found: {root}") if benchmarks: benchmark_names = benchmarks else: benchmark_names = sorted( name for name in os.listdir(root) if os.path.isdir(os.path.join(root, name)) and os.path.exists(os.path.join(root, name, "data.jsonl")) ) config = {} missing = [] for benchmark_name in benchmark_names: benchmark_dir = os.path.join(root, benchmark_name) annotation = os.path.join(benchmark_dir, "data.jsonl") if not os.path.exists(annotation): missing.append(benchmark_name) continue first_record = _read_first_jsonl_record(annotation) is_mcq = benchmark_name in DEFAULT_MCQ_BENCHMARKS or _looks_like_mcq_record(first_record) reward_fn = ["em_score_mcq"] if is_mcq else [] unused_reward_fn = [] if is_mcq else ["llm_score"] config[benchmark_name] = { "root": benchmark_dir, "annotation": annotation, "length": _count_jsonl_records(annotation), "repeat_time": 1, "reward_fn": reward_fn, "unused_reward_fn": unused_reward_fn, "input_template": { "name": "general", "arguments": { "system_prompt": DEFAULT_IMAGE_DATASET_SYSTEM_PROMPT, "format_instruction": "", "add_image_path": False, }, }, "comment": benchmark_name, } if missing: raise FileNotFoundError( "benchmark data.jsonl not found under eval root: " + ", ".join(missing) ) if not config: raise ValueError(f"no benchmark data.jsonl found under eval root: {root}") return config def load_datasets( config_path: str | dict, data_root: str = "", video_dr_system_prompt: str = VIDEO_DR_SYSTEM_PROMPT, ) -> tuple[list[dict], dict]: """Load datasets from JSON config. Returns (samples, dataset_configs).""" if isinstance(config_path, dict): config = config_path else: with open(config_path) as f: config = json.load(f) all_samples = [] dataset_configs = {} for dataset_name, ds_config in config.items(): ds_type = ds_config.get("type", "") root = ds_config.get("root", "") annotation = ds_config.get("annotation", "") video_root = ds_config.get("video_root", "") if data_root and not os.path.isabs(root): root = os.path.join(data_root, root) if data_root and not os.path.isabs(annotation): annotation = os.path.join(data_root, annotation) if data_root and video_root and not os.path.isabs(video_root): video_root = os.path.join(data_root, video_root) if not os.path.exists(annotation): print(f"Warning: {annotation} not found, skipping {dataset_name}") continue samples = [] if ds_type == "video_dr_csv" or annotation.lower().endswith(".csv"): if not video_root: video_root = ds_config.get("root", "") if data_root and video_root and not os.path.isabs(video_root): video_root = os.path.join(data_root, video_root) if not video_root: raise ValueError( f"Dataset '{dataset_name}' is video_dr_csv but missing video_root/root" ) samples = load_videodr_csv_samples(annotation, video_root, dataset_name) else: with open(annotation) as f: for idx, line in enumerate(f): raw = json.loads(line) prompt = raw.get("prompt", []) question = "" for msg in prompt: if msg.get("role") == "user": question = msg.get("content", "").replace("", "").strip() break reward_model = raw.get("reward_model", {}) # Keep full ground_truth list - loops through all for LLM judge answer = reward_model.get("ground_truth", [""]) images = raw.get("image", []) image_path = os.path.join(root, images[0]) if images else "" # Extract image search data if present (for data-driven image_search_tool) # Data is at top level in data.jsonl: image_search_title_list, image_search_thumbnail_list # Limit to 5 results (default) IMAGE_SEARCH_MAX_RESULTS = 5 image_search_kwargs = {} if "image_search_title_list" in raw: titles = raw["image_search_title_list"] image_search_kwargs["image_search_title_list"] = titles[:IMAGE_SEARCH_MAX_RESULTS] if titles else None if "image_search_thumbnail_list" in raw: thumbs = raw["image_search_thumbnail_list"] image_search_kwargs["image_search_thumbnail_list"] = thumbs[:IMAGE_SEARCH_MAX_RESULTS] if thumbs else None sample = { "id": f"{dataset_name}-{idx}", "question": question, "answer": answer, "image_path": image_path, "judge_image_path": image_path, "dataset": dataset_name, "data_root": root, "task_kind": "image", } if image_search_kwargs: sample["image_search_data"] = image_search_kwargs samples.append(sample) all_samples.extend(samples) # Scoring methods: combine reward_fn and unused_reward_fn into single list reward_fn = ds_config.get("reward_fn", []) unused_reward_fn = ds_config.get("unused_reward_fn", []) score_methods = list(set(reward_fn + unused_reward_fn)) if not score_methods: raise ValueError(f"Dataset '{dataset_name}' has no scoring methods (reward_fn or unused_reward_fn)") valid_methods = {"em_score_mcq", "llm_score"} invalid = set(score_methods) - valid_methods if invalid: raise ValueError(f"Dataset '{dataset_name}' has invalid scoring methods: {invalid}") # Get input_template config input_template = ds_config.get("input_template", {}) template_args = input_template.get("arguments", {}) system_prompt = template_args.get("system_prompt", "") format_instruction = template_args.get("format_instruction", "") if ds_type == "video_dr_csv" and not system_prompt: system_prompt = video_dr_system_prompt dataset_configs[dataset_name] = { "score_methods": score_methods, "system_prompt": system_prompt, "format_instruction": format_instruction, "task_kind": "video_dr" if ds_type == "video_dr_csv" or annotation.lower().endswith(".csv") else "image", } print(f" Loaded {len(samples)} samples from {dataset_name}") print(f" reward_fn: {reward_fn}") print(f" unused_reward_fn: {unused_reward_fn}") return all_samples, dataset_configs # ============================================================================= # API Calls # ============================================================================= def _message_content_to_text(content) -> str: """把 OpenAI 兼容返回里的 content/reasoning_content 统一转成纯文本。""" if isinstance(content, str): return content if isinstance(content, list): text_parts = [] for item in content: if isinstance(item, str): text_parts.append(item) elif isinstance(item, dict) and item.get("type") == "text": text_parts.append(item.get("text", "")) return "".join(text_parts) return "" def _normalize_openai_tool_calls(raw_tool_calls) -> list[dict]: """把 OpenAI/sglang 风格的 tool_calls 规范化成统一结构。""" normalized = [] for tool_call in raw_tool_calls or []: if isinstance(tool_call, dict): function = tool_call.get("function") or {} name = function.get("name", "") raw_args = function.get("arguments") else: function = getattr(tool_call, "function", None) name = getattr(function, "name", "") if function is not None else "" raw_args = getattr(function, "arguments", None) if function is not None else None if isinstance(raw_args, str): try: args = json.loads(raw_args) except Exception: args = {} elif isinstance(raw_args, dict): args = raw_args else: args = {} normalized.append({ "name": name, "arguments": args, }) return normalized def reconstruct_assistant_text( content: str = "", reasoning_content: str = "", tool_calls: Optional[list[dict]] = None, ) -> str: """把 reasoning_content/tool_calls/content 拼回完整 assistant 文本。""" content = (content or "").strip() reasoning_content = (reasoning_content or "").strip() tool_calls = tool_calls or [] if not reasoning_content and not tool_calls: return content parts = [] if reasoning_content: parts.append(f"\n{reasoning_content}\n") for tool_call in tool_calls: call_json = json.dumps({ "name": tool_call.get("name", ""), "arguments": tool_call.get("arguments", {}), }, ensure_ascii=False) parts.append(f"\n{call_json}\n") if content: parts.append(content) return "\n".join(parts).strip() async def call_openai_api( messages: list[dict], model: str, base_url: str, api_key: str = "", **kwargs, ) -> dict: """Call OpenAI-compatible API.""" request_debug = _summarize_messages_for_debug(messages) debug_context = kwargs.get("request_debug_context") or {} # Build request # Note: Don't set temperature by default - let API use its default (1.0) body = { "model": normalize_model_name_for_client(model, "openai"), "messages": messages, "max_tokens": kwargs.get("max_tokens", 4096), } if "temperature" in kwargs: body["temperature"] = kwargs["temperature"] if "top_p" in kwargs: body["top_p"] = kwargs["top_p"] if "top_k" in kwargs: body["top_k"] = kwargs["top_k"] if "presence_penalty" in kwargs: body["presence_penalty"] = kwargs["presence_penalty"] if "repetition_penalty" in kwargs: body["repetition_penalty"] = kwargs["repetition_penalty"] if "seed" in kwargs and kwargs["seed"] is not None: body["seed"] = kwargs["seed"] # For Qwen3 models via vLLM/SGLang - control thinking mode if kwargs.get("extra_body"): body.update(kwargs["extra_body"]) headers = {"Content-Type": "application/json"} if api_key: headers["Authorization"] = f"Bearer {api_key}" url = _openai_chat_completions_url(base_url) try: timeout = aiohttp.ClientTimeout(total=300) async with create_http_session(timeout) as session: async with session.post(url, headers=headers, json=body) as resp: raw_text = await resp.text() try: data = json.loads(raw_text) except json.JSONDecodeError: data = {} if resp.status != 200: error = _extract_http_error_message(resp.status, data, raw_text) raw_error_text = _truncate_debug_text(raw_text, limit=4000) print( "[MODEL HTTP ERROR]", json.dumps( { "status": resp.status, "error": error, "sample_id": debug_context.get("sample_id"), "turn": debug_context.get("turn"), "task_kind": debug_context.get("task_kind"), "request_debug": request_debug, "raw_response": raw_error_text, }, ensure_ascii=False, ), flush=True, ) return { "content": "", "reasoning_content": "", "tool_calls": [], "full_text": "", "finish_reason": "error", "error": error, "error_status_code": resp.status, "error_raw_response": raw_error_text, "request_debug": request_debug, } choice = data["choices"][0] message = choice.get("message", {}) or {} content = _message_content_to_text(message.get("content")) reasoning_content = _message_content_to_text(message.get("reasoning_content")) tool_calls = _normalize_openai_tool_calls(message.get("tool_calls")) full_text = reconstruct_assistant_text( content=content, reasoning_content=reasoning_content, tool_calls=tool_calls, ) return { "content": content, "reasoning_content": reasoning_content, "tool_calls": tool_calls, "full_text": full_text, "finish_reason": choice.get("finish_reason", "stop"), "error": None, "request_debug": request_debug, } except Exception as e: error = _truncate_debug_text(str(e), limit=1000) print( "[MODEL REQUEST EXCEPTION]", json.dumps( { "error": error, "sample_id": debug_context.get("sample_id"), "turn": debug_context.get("turn"), "task_kind": debug_context.get("task_kind"), "request_debug": request_debug, }, ensure_ascii=False, ), flush=True, ) return { "content": "", "reasoning_content": "", "tool_calls": [], "full_text": "", "finish_reason": "error", "error": error, "error_status_code": None, "error_raw_response": "", "request_debug": request_debug, } def _parse_gateway_json(value): """Parse JSON strings returned by the gateway; leave other values unchanged.""" if isinstance(value, str): stripped = value.strip() if not stripped: return value try: return json.loads(stripped) except Exception: return value return value def _extract_gateway_response_payload(data): """Find the model payload inside known gateway wrapper shapes.""" data = _parse_gateway_json(data) if not isinstance(data, dict): return data for key in ("model_output", "response", "result", "data"): if key in data: value = _parse_gateway_json(data.get(key)) if isinstance(value, dict) and value is not data: return _extract_gateway_response_payload(value) return value return data def _gateway_text_value(value) -> str: """Extract text from a known gateway/OpenAI content value.""" value = _parse_gateway_json(value) if isinstance(value, str): return value if isinstance(value, list): parts = [_gateway_text_value(item) for item in value] return "\n".join(part for part in parts if part).strip() if not isinstance(value, dict): return "" for key in ("output_text", "text", "value", "summary"): text = value.get(key) if isinstance(text, str) and text: return text if isinstance(text, (list, dict)): nested = _gateway_text_value(text) if nested: return nested content = value.get("content") if content is not None: nested = _gateway_text_value(content) if nested: return nested return "" def _extract_gateway_text(payload) -> str: """Extract assistant text from OpenAI Responses, chat, or plain gateway payloads.""" payload = _parse_gateway_json(payload) if isinstance(payload, str): return payload if not isinstance(payload, dict): return "" output_text = _gateway_text_value(payload.get("output_text")) if output_text: return output_text choices = payload.get("choices") if isinstance(choices, list) and choices: message = choices[0].get("message", {}) or {} content = _message_content_to_text(message.get("content")) reasoning_content = _message_content_to_text(message.get("reasoning_content")) tool_calls = _normalize_openai_tool_calls(message.get("tool_calls")) return reconstruct_assistant_text(content, reasoning_content, tool_calls) output = payload.get("output") if isinstance(output, list): parts = [] for item in output: if not isinstance(item, dict): continue if item.get("type") == "message": text = _gateway_text_value(item.get("content", [])) if text: parts.append(text) elif isinstance(item.get("content"), list): text = _gateway_text_value(item.get("content", [])) if text: parts.append(text) elif isinstance(item.get("text"), str): parts.append(item["text"]) return "\n".join(parts).strip() content = payload.get("content") if content is not None: text = _message_content_to_text(content) return text or _gateway_text_value(content) return "" def _gateway_error_summary(payload) -> str: """Build a compact error summary from a gateway/OpenAI response payload.""" payload = _parse_gateway_json(payload) if not isinstance(payload, dict): return str(payload)[:500] fields = [] for key in ("id", "status", "error", "incomplete_details"): value = payload.get(key) if value: fields.append(f"{key}={str(value)[:200]}") output = payload.get("output") if isinstance(output, list): fields.append(f"output_items={len(output)}") item_types = [str(item.get("type")) for item in output if isinstance(item, dict)] if item_types: fields.append(f"output_types={','.join(item_types[:10])}") if fields: return "; ".join(fields)[:500] return str(payload)[:500] def _gateway_uses_completion_token_limit(model: str) -> bool: """Return whether this gateway model expects max_completion_tokens.""" normalized = normalize_model_name_for_client(model, "gateway").lower() return normalized.startswith("gpt-5") def _gateway_uses_responses_api(model: str) -> bool: """Return whether this gateway model expects OpenAI Responses API params.""" return _is_gpt54_gateway_model(model) def _gateway_responses_content(content): """Convert OpenAI chat-style content parts to Responses API content parts.""" if isinstance(content, str): return content if not isinstance(content, list): return str(content or "") converted = [] for item in content: if isinstance(item, str): if item: converted.append({"type": "input_text", "text": item}) continue if not isinstance(item, dict): continue item_type = item.get("type") if item_type == "text": text = item.get("text", "") if text: converted.append({"type": "input_text", "text": text}) elif item_type == "image_url": image_url = item.get("image_url", {}) url = image_url.get("url", "") if isinstance(image_url, dict) else str(image_url or "") if url: converted.append({"type": "input_image", "image_url": url}) return converted or "" def _gateway_responses_input(messages: list[dict]) -> list[dict]: """Convert OpenAI chat-style messages to Responses API input messages.""" converted = [] for message in messages: role = message.get("role", "user") converted.append( { "role": role, "content": _gateway_responses_content(message.get("content", "")), } ) return converted def _is_retryable_gateway_result(result: dict) -> bool: """Return whether a gateway model-call failure is likely transient.""" if _is_fatal_gateway_quota_result(result): return False status = result.get("error_status_code") if status in {408, 409, 425, 429, 500, 502, 503, 504, 520, 521, 522, 523, 524}: return True text = " ".join( str(result.get(key) or "") for key in ("error", "error_raw_response") ).lower() retry_markers = ( "connection reset", "connection aborted", "server disconnected", "connection closed", "read tcp", "broken pipe", "timeout", "timed out", "temporarily unavailable", "empty gateway response", "error_code': '50001'", '"error_code": "50001"', ) return any(marker in text for marker in retry_markers) def _is_fatal_gateway_quota_result(result: dict) -> bool: """Return whether a gateway failure means the whole eval should stop.""" text = " ".join( str(result.get(key) or "") for key in ("error", "error_raw_response") ).lower() fatal_markers = ( "application quota has been exhausted", "quota has been exhausted", "quota exhausted", "insufficient quota", ) return any(marker in text for marker in fatal_markers) async def _maybe_retry_gateway_result( result: dict, *, attempt: int, max_attempts: int, model: str, request_debug: dict, initial_delay: float, ) -> bool: """Sleep and return True when the caller should retry a gateway request.""" result.setdefault("request_debug", request_debug) result["request_debug"]["gateway_attempt"] = attempt result["request_debug"]["gateway_max_attempts"] = max_attempts if _is_fatal_gateway_quota_result(result): error = _truncate_debug_text(str(result.get("error") or ""), limit=500) raise FatalAPIError(f"[GATEWAY ERROR] {model} quota exhausted: {error}") if attempt >= max_attempts or not _is_retryable_gateway_result(result): return False delay = min(initial_delay * (2 ** (attempt - 1)), 20.0) print( "[GATEWAY RETRY]", json.dumps( { "model": model, "attempt": attempt, "max_attempts": max_attempts, "sleep_seconds": round(delay, 2), "status": result.get("error_status_code"), "error": _truncate_debug_text(str(result.get("error") or ""), limit=500), "sample_id": request_debug.get("sample_id"), "turn": request_debug.get("turn"), "task_kind": request_debug.get("task_kind"), }, ensure_ascii=False, ), flush=True, ) await asyncio.sleep(delay) return True async def call_gateway_api( messages: list[dict], model: str, base_url: str, api_key: str = "", **kwargs, ) -> dict: """Call the company gateway for GPT models.""" request_debug = _summarize_messages_for_debug(messages) debug_context = kwargs.get("request_debug_context") or {} request_debug.update(debug_context) gateway_username, gateway_userid, gateway_token = _get_model_gateway_credentials(model, api_key) if not gateway_token: token_hint = ( "MODEL_GATEWAY_GPT54_TOKEN or GPT54_GATEWAY_TOKEN is required" if _is_gpt54_gateway_model(model) else "MODEL_GATEWAY_GEMINI_TOKEN or GEMINI_GATEWAY_TOKEN is required" if _is_gemini_gateway_model(model) else "MODEL_GATEWAY_TOKEN or GATEWAY_TOKEN is required" ) return { "content": "", "reasoning_content": "", "tool_calls": [], "full_text": "", "finish_reason": "error", "error": token_hint, "request_debug": request_debug, } if _gateway_uses_responses_api(model): params = { "input": _gateway_responses_input(messages), "temperature": kwargs.get("temperature", 0.7), "max_output_tokens": kwargs.get("max_tokens", 4096), } else: params = { "messages": messages, "temperature": kwargs.get("temperature", 0.7), } token_limit_key = ( "max_completion_tokens" if _gateway_uses_completion_token_limit(model) else "max_tokens" ) params[token_limit_key] = kwargs.get("max_tokens", 4096) payload = { "sec_info": { "username": gateway_username, "userid": gateway_userid, "token": gateway_token, }, "model_type": _gateway_model_type(model), "model_name": normalize_model_name_for_client(model, "gateway"), "params": json.dumps(params, ensure_ascii=False), } headers = { "Content-Type": "application/json", "User-Agent": "ifbook-http-client", } raw_max_retries = kwargs.get( "gateway_max_retries", os.environ.get("MODEL_GATEWAY_MAX_RETRIES", "10"), ) try: max_attempts = max(1, int(raw_max_retries)) except (TypeError, ValueError): max_attempts = 10 try: initial_delay = max( 0.1, float( os.environ.get("MODEL_GATEWAY_RETRY_INITIAL_DELAY") or os.environ.get("MODEL_GATEWAY_INITIAL_RETRY_DELAY", "1.0") ), ) except ValueError: initial_delay = 1.0 last_result = None for attempt in range(1, max_attempts + 1): try: timeout = aiohttp.ClientTimeout(total=kwargs.get("model_request_timeout", 300)) async with create_http_session(timeout) as session: async with session.post(base_url, headers=headers, json=payload) as resp: text = await resp.text() raw_error_text = _truncate_debug_text(text, limit=4000) if resp.status != 200: result = { "content": "", "reasoning_content": "", "tool_calls": [], "full_text": "", "finish_reason": "error", "error": f"HTTP {resp.status}: {raw_error_text[:500]}", "error_status_code": resp.status, "error_raw_response": raw_error_text, "request_debug": request_debug, } last_result = result if await _maybe_retry_gateway_result( result, attempt=attempt, max_attempts=max_attempts, model=model, request_debug=request_debug, initial_delay=initial_delay, ): continue return result try: data = json.loads(text) except Exception: result = { "content": "", "reasoning_content": "", "tool_calls": [], "full_text": "", "finish_reason": "error", "error": f"Invalid JSON response: {raw_error_text[:500]}", "error_status_code": resp.status, "error_raw_response": raw_error_text, "request_debug": request_debug, } last_result = result if await _maybe_retry_gateway_result( result, attempt=attempt, max_attempts=max_attempts, model=model, request_debug=request_debug, initial_delay=initial_delay, ): continue return result code = str(data.get("code", "0")) if code not in ("0", ""): result = { "content": "", "reasoning_content": "", "tool_calls": [], "full_text": "", "finish_reason": "error", "error": str(data.get("message", data))[:500], "error_status_code": resp.status, "error_raw_response": raw_error_text, "request_debug": request_debug, } last_result = result if await _maybe_retry_gateway_result( result, attempt=attempt, max_attempts=max_attempts, model=model, request_debug=request_debug, initial_delay=initial_delay, ): continue return result response_payload = _extract_gateway_response_payload(data) if isinstance(response_payload, dict): payload_error = response_payload.get("error") if payload_error: result = { "content": "", "reasoning_content": "", "tool_calls": [], "full_text": "", "finish_reason": "error", "error": f"Gateway response error: {_gateway_error_summary(response_payload)}", "error_status_code": resp.status, "error_raw_response": raw_error_text, "request_debug": request_debug, } last_result = result if await _maybe_retry_gateway_result( result, attempt=attempt, max_attempts=max_attempts, model=model, request_debug=request_debug, initial_delay=initial_delay, ): continue return result status = response_payload.get("status") if status and status not in ("completed", "succeeded", "success"): result = { "content": "", "reasoning_content": "", "tool_calls": [], "full_text": "", "finish_reason": "error", "error": f"Gateway response not completed: {_gateway_error_summary(response_payload)}", "error_status_code": resp.status, "error_raw_response": raw_error_text, "request_debug": request_debug, } last_result = result if await _maybe_retry_gateway_result( result, attempt=attempt, max_attempts=max_attempts, model=model, request_debug=request_debug, initial_delay=initial_delay, ): continue return result content = _extract_gateway_text(response_payload) if not content.strip(): result = { "content": "", "reasoning_content": "", "tool_calls": [], "full_text": "", "finish_reason": "error", "error": f"Empty gateway response: {_gateway_error_summary(response_payload)}", "error_status_code": resp.status, "error_raw_response": raw_error_text, "request_debug": request_debug, } last_result = result if await _maybe_retry_gateway_result( result, attempt=attempt, max_attempts=max_attempts, model=model, request_debug=request_debug, initial_delay=initial_delay, ): continue return result request_debug["gateway_attempt"] = attempt request_debug["gateway_max_attempts"] = max_attempts return { "content": content, "reasoning_content": "", "tool_calls": [], "full_text": content, "finish_reason": "stop", "error": None, "request_debug": request_debug, } except Exception as e: error = f"{type(e).__name__}: {str(e)}".strip() result = { "content": "", "reasoning_content": "", "tool_calls": [], "full_text": "", "finish_reason": "error", "error": _truncate_debug_text(error, limit=1000), "error_status_code": None, "error_raw_response": "", "request_debug": request_debug, } last_result = result if await _maybe_retry_gateway_result( result, attempt=attempt, max_attempts=max_attempts, model=model, request_debug=request_debug, initial_delay=initial_delay, ): continue return result return last_result or { "content": "", "reasoning_content": "", "tool_calls": [], "full_text": "", "finish_reason": "error", "error": "Gateway request failed after retries", "error_status_code": None, "error_raw_response": "", "request_debug": request_debug, } async def call_gemini_api( messages: list[dict], model: str, base_url: str, api_key: str, **kwargs, ) -> dict: """Call Gemini API.""" # Convert messages to Gemini format contents = [] for msg in messages: role = "user" if msg["role"] == "user" else "model" parts = [] content = msg.get("content", []) if isinstance(content, str): parts.append({"text": content}) else: for item in content: if isinstance(item, str): parts.append({"text": item}) elif item.get("type") == "text": parts.append({"text": item["text"]}) elif item.get("type") == "image_url": url = item["image_url"]["url"] if url.startswith("data:"): mime, data = url.split(";base64,", 1) mime = mime.replace("data:", "") parts.append({"inlineData": {"mimeType": mime, "data": data}}) contents.append({"role": role, "parts": parts}) body = { "contents": contents, "generationConfig": { "temperature": kwargs.get("temperature", 0.7), "maxOutputTokens": kwargs.get("max_tokens", 4096), } } url = f"{base_url.rstrip('/')}/models/{model}:generateContent" headers = {"x-goog-api-key": api_key, "Content-Type": "application/json"} try: timeout = aiohttp.ClientTimeout(total=300) async with create_http_session(timeout) as session: async with session.post(url, headers=headers, json=body) as resp: data = await resp.json() if resp.status != 200: error = data.get('error', {}).get('message', f'HTTP {resp.status}') return { "content": "", "reasoning_content": "", "tool_calls": [], "full_text": "", "finish_reason": "error", "error": error, } content = "" if data.get("candidates"): parts = data["candidates"][0].get("content", {}).get("parts", []) content = "".join(p.get("text", "") for p in parts) return { "content": content, "reasoning_content": "", "tool_calls": [], "full_text": content, "finish_reason": "stop", "error": None, } except Exception as e: return { "content": "", "reasoning_content": "", "tool_calls": [], "full_text": "", "finish_reason": "error", "error": str(e), } class VertexAccountPool: """Vertex Gemini service-account pool with per-request rotation.""" def __init__( self, pool_file: str, default_location: str = "global", cooldown_seconds: float = 60.0, ): self.pool_file = pool_file self.default_location = default_location or "global" self.cooldown_seconds = max(0.0, float(cooldown_seconds or 0.0)) self.accounts = self._load_accounts(pool_file) if not self.accounts: raise ValueError(f"Vertex account pool is empty or missing: {pool_file}") self._lock = asyncio.Lock() self._next_index = 0 self._cooldown_until = {} self._auth_request = None self._init_credentials() def _load_accounts(self, pool_file: str) -> list[dict]: if not pool_file: return [] path = Path(pool_file).expanduser() if not path.is_absolute(): path = Path.cwd() / path if not path.exists(): return [] with path.open("r", encoding="utf-8") as f: data = json.load(f) if isinstance(data, list): records = data elif isinstance(data, dict) and isinstance(data.get("accounts"), list): records = data["accounts"] else: raise ValueError( f"Vertex account pool must be a JSON list or {{'accounts': [...]}}: {path}" ) accounts = [] for idx, item in enumerate(records, start=1): if not isinstance(item, dict): raise ValueError(f"Vertex account #{idx} must be an object") if item.get("enabled", True) is False: continue project_id = str(item.get("project_id") or "").strip() if not project_id: raise ValueError(f"Vertex account #{idx} missing project_id") location = str(item.get("location") or self.default_location).strip() or "global" key_path = str(item.get("service_account_key") or item.get("json_path") or "").strip() key_info = item.get("service_account_info") if key_path: resolved = Path(key_path).expanduser() if not resolved.is_absolute(): resolved = path.parent / resolved key_path = str(resolved.resolve()) if not key_path and not isinstance(key_info, dict): raise ValueError( f"Vertex account #{idx} must provide service_account_key/json_path " f"or service_account_info. project_id={project_id}" ) accounts.append({ "account_name": item.get("account_name") or item.get("name") or f"vertex-{idx}", "project_id": project_id, "location": location, "service_account_key": key_path, "service_account_info": key_info if isinstance(key_info, dict) else None, "credentials": None, }) return accounts def _init_credentials(self) -> None: try: import google.auth.transport.requests from google.oauth2 import service_account except ImportError as exc: raise ImportError( "Vertex mode requires google-auth: pip install google-auth google-auth-httplib2" ) from exc scopes = ["https://www.googleapis.com/auth/cloud-platform"] self._auth_request = google.auth.transport.requests.Request() valid_accounts = [] skipped_accounts = [] for account in self.accounts: key_path = account.get("service_account_key") or "" key_info = account.get("service_account_info") if key_path: if not os.path.exists(key_path): skipped_accounts.append((account.get("account_name", ""), f"missing key file: {key_path}")) continue credentials = service_account.Credentials.from_service_account_file( key_path, scopes=scopes, ) else: credentials = service_account.Credentials.from_service_account_info( key_info, scopes=scopes, ) try: credentials.refresh(self._auth_request) except Exception as exc: skipped_accounts.append(( account.get("account_name", ""), f"{type(exc).__name__}: {str(exc)[:300]}", )) continue account["credentials"] = credentials valid_accounts.append(account) self.accounts = valid_accounts if not self.accounts: details = "; ".join(f"{name}: {reason}" for name, reason in skipped_accounts[:5]) raise ValueError(f"No valid Vertex account credentials in {self.pool_file}. {details}") if skipped_accounts: skipped_text = "; ".join(f"{name}: {reason}" for name, reason in skipped_accounts[:5]) print( f"[VERTEX] Skipped {len(skipped_accounts)} invalid account(s): {skipped_text}", flush=True, ) print(f"[VERTEX] Loaded {len(self.accounts)} account(s) from {self.pool_file}", flush=True) async def acquire_account(self) -> tuple[int, dict]: while True: async with self._lock: now = time.monotonic() count = len(self.accounts) for offset in range(count): idx = (self._next_index + offset) % count if self._cooldown_until.get(idx, 0.0) <= now: self._next_index = (idx + 1) % count return idx, self.accounts[idx] wait = min(self._cooldown_until.values()) - now await asyncio.sleep(max(wait, 0.1)) async def mark_cooldown(self, idx: int) -> None: if self.cooldown_seconds <= 0: return async with self._lock: self._cooldown_until[idx] = time.monotonic() + self.cooldown_seconds def token_for(self, account: dict) -> str: credentials = account["credentials"] if credentials.expired or not credentials.token: credentials.refresh(self._auth_request) return credentials.token def _convert_messages_to_vertex_native(messages: list[dict]) -> tuple[Optional[dict], list[dict]]: system_instruction = None contents = [] for msg in messages: role = msg.get("role", "") content = msg.get("content", "") if role == "system": parts = [] if isinstance(content, str): if content.strip(): parts.append({"text": content}) elif isinstance(content, list): for item in content: if isinstance(item, str) and item.strip(): parts.append({"text": item}) elif isinstance(item, dict) and item.get("type") == "text": text = item.get("text", "") if text.strip(): parts.append({"text": text}) if parts: system_instruction = {"parts": parts} continue vertex_role = "model" if role == "assistant" else "user" parts = [] if isinstance(content, str): if content.strip(): parts.append({"text": content}) elif isinstance(content, list): for item in content: if isinstance(item, str): if item.strip(): parts.append({"text": item}) elif isinstance(item, dict): item_type = item.get("type", "") if item_type == "text": text = item.get("text", "") if text.strip(): parts.append({"text": text}) elif item_type == "image_url": image_url = item.get("image_url", {}) url = image_url.get("url", "") if isinstance(image_url, dict) else "" if url.startswith("data:"): try: header, b64_data = url.split(",", 1) mime_type = header.split(":", 1)[1].split(";", 1)[0] except (ValueError, IndexError): continue parts.append({ "inlineData": { "mimeType": mime_type, "data": b64_data, } }) elif url.startswith("gs://"): parts.append({ "fileData": { "fileUri": url, "mimeType": "image/jpeg", } }) if parts: contents.append({"role": vertex_role, "parts": parts}) return system_instruction, contents def _parse_vertex_response(data: dict) -> tuple[str, str, str]: prompt_feedback = data.get("promptFeedback", {}) block_reason = prompt_feedback.get("blockReason", "") if block_reason: raise ValueError(f"Vertex prompt blocked: {block_reason}") candidates = data.get("candidates", []) if not candidates: raise ValueError(f"Empty candidates in Vertex response: {json.dumps(data)[:500]}") candidate = candidates[0] content_text = "" reasoning_text = "" for part in candidate.get("content", {}).get("parts", []) or []: text = part.get("text", "") if not text: continue if part.get("thought", False): reasoning_text += text else: content_text += text finish_reason = candidate.get("finishReason", "STOP") return content_text, reasoning_text, finish_reason def _vertex_stop_sequences() -> list[str]: """返回 Vertex 文本生成的协议边界停止串,防止模型续写用户回合或工具响应。""" return ["<|im_end|>", "<|im_start|>user", ""] def _is_vertex_rate_limit(status: int, body: str) -> bool: body_lower = (body or "").lower() return ( status == 429 or "resource_exhausted" in body_lower or "quota" in body_lower or "rate limit" in body_lower ) def _is_vertex_transient_http_error(status: int) -> bool: """判断 Vertex HTTP 状态是否适合换账号重试,而不是记为样本错误。""" return status in {408, 409, 425, 499} or status >= 500 async def call_vertex_gemini_api( messages: list[dict], model: str, base_url: str, api_key: str, **kwargs, ) -> dict: """Call Gemini through Vertex AI native generateContent with account-pool rotation.""" request_debug = _summarize_messages_for_debug(messages) vertex_model = normalize_model_name_for_client(model, "vertex") vertex_pool = kwargs.get("vertex_account_pool") if not isinstance(vertex_pool, VertexAccountPool): return { "content": "", "reasoning_content": "", "tool_calls": [], "full_text": "", "finish_reason": "error", "error": "vertex_account_pool is required for --model-client vertex", "request_debug": request_debug, } system_instruction, contents = _convert_messages_to_vertex_native(messages) body = { "contents": contents, "generationConfig": { "temperature": kwargs.get("temperature", 0.7), "maxOutputTokens": kwargs.get("max_tokens", 4096), "stopSequences": kwargs.get("vertex_stop_sequences") or _vertex_stop_sequences(), }, } if system_instruction: body["systemInstruction"] = system_instruction if "top_p" in kwargs: body["generationConfig"]["topP"] = kwargs["top_p"] if "top_k" in kwargs: body["generationConfig"]["topK"] = kwargs["top_k"] max_attempts = len(vertex_pool.accounts) last_error = "" for _ in range(max_attempts): account_idx, account = await vertex_pool.acquire_account() try: token = vertex_pool.token_for(account) except Exception as e: last_error = f"{type(e).__name__}: {str(e)}" await vertex_pool.mark_cooldown(account_idx) continue url = ( "https://aiplatform.googleapis.com/v1/projects/" f"{account['project_id']}/locations/{account['location']}" f"/publishers/google/models/{vertex_model}:generateContent" ) headers = { "Authorization": f"Bearer {token}", "Content-Type": "application/json", } try: timeout = aiohttp.ClientTimeout(total=kwargs.get("model_request_timeout", 300)) async with create_http_session(timeout) as session: async with session.post(url, headers=headers, json=body) as resp: raw_text = await resp.text() raw_error_text = _truncate_debug_text(raw_text, limit=4000) try: data = json.loads(raw_text) except json.JSONDecodeError: data = {} if resp.status != 200: error = _extract_http_error_message(resp.status, data, raw_text) last_error = f"HTTP {resp.status}: {error}" if _is_vertex_rate_limit(resp.status, raw_text) and max_attempts > 1: await vertex_pool.mark_cooldown(account_idx) continue if _is_vertex_transient_http_error(resp.status) and max_attempts > 1: continue return { "content": "", "reasoning_content": "", "tool_calls": [], "full_text": "", "finish_reason": "error", "error": last_error, "error_status_code": resp.status, "error_raw_response": raw_error_text, "request_debug": request_debug, } content, reasoning, finish_reason = _parse_vertex_response(data) full_text = reconstruct_assistant_text( content=content, reasoning_content=reasoning, tool_calls=[], ) return { "content": content, "reasoning_content": reasoning, "tool_calls": [], "full_text": full_text, "finish_reason": finish_reason.lower(), "error": None, "request_debug": request_debug, "vertex_account": account.get("account_name", ""), "vertex_project_id": account.get("project_id", ""), } except Exception as e: last_error = f"{type(e).__name__}: {str(e)}" if "429" in last_error or "RESOURCE_EXHAUSTED" in last_error: await vertex_pool.mark_cooldown(account_idx) continue if isinstance(e, (aiohttp.ClientError, asyncio.TimeoutError, TimeoutError, OSError)): continue return { "content": "", "reasoning_content": "", "tool_calls": [], "full_text": "", "finish_reason": "error", "error": _truncate_debug_text(last_error, limit=1000), "error_status_code": None, "error_raw_response": "", "request_debug": request_debug, } return { "content": "", "reasoning_content": "", "tool_calls": [], "full_text": "", "finish_reason": "error", "error": f"All Vertex accounts failed or are cooling down. Last error: {last_error[:500]}", "request_debug": request_debug, } async def call_azure_api( messages: list[dict], model: str, base_url: str, api_key: str, **kwargs, ) -> dict: """Call Azure OpenAI API.""" api_version = kwargs.get("azure_api_version", "2025-04-01-preview") body = { "model": model, "messages": messages, "temperature": kwargs.get("temperature", 0.7), "max_tokens": kwargs.get("max_tokens", 4096), } if "top_p" in kwargs: body["top_p"] = kwargs["top_p"] if "reasoning_effort" in kwargs: body["reasoning_effort"] = kwargs["reasoning_effort"] url = f"{base_url.rstrip('/')}/openai/deployments/{model}/chat/completions?api-version={api_version}" headers = {"api-key": api_key, "Content-Type": "application/json"} try: timeout = aiohttp.ClientTimeout(total=300) async with create_http_session(timeout) as session: async with session.post(url, headers=headers, json=body) as resp: data = await resp.json() if resp.status != 200: error = data.get('error', {}).get('message', f'HTTP {resp.status}') error = re.sub(r'data:image[^"]*', '[IMAGE]', str(error))[:500] return { "content": "", "reasoning_content": "", "tool_calls": [], "full_text": "", "finish_reason": "error", "error": error, } choice = data["choices"][0] return { "content": choice["message"]["content"], "reasoning_content": "", "tool_calls": [], "full_text": choice["message"]["content"], "finish_reason": choice.get("finish_reason", "stop"), "error": None, } except Exception as e: return { "content": "", "reasoning_content": "", "tool_calls": [], "full_text": "", "finish_reason": "error", "error": str(e), } # ============================================================================= # Tool Execution # ============================================================================= VALID_TOOL_CALL_NAMES = { "choose_frames", "find_frame", "image_search", "image_search_tool", "image_zoom_in_tool", "text_search_tool", "web_search", "zoom_in", } def _strip_code_fence(text: str) -> str: content = text.strip() match = re.fullmatch(r"```(?:json)?\s*(.*?)\s*```", content, re.DOTALL | re.IGNORECASE) return match.group(1).strip() if match else content def _load_tool_call_json(content: str) -> Optional[dict]: """Load a tool call JSON object from a possibly noisy text fragment.""" content = _strip_code_fence(content) candidates = [content] decoder = json.JSONDecoder() try: obj, _ = decoder.raw_decode(content) if isinstance(obj, dict): candidates.append(json.dumps(obj, ensure_ascii=False)) except json.JSONDecodeError: pass first_brace = content.find("{") last_brace = content.rfind("}") if 0 <= first_brace < last_brace: candidates.append(content[first_brace:last_brace + 1]) for candidate in candidates: candidate = candidate.strip() while candidate: try: obj = json.loads(candidate) except json.JSONDecodeError as exc: if exc.msg == "Extra data" and 0 <= exc.pos < len(candidate): candidate = candidate[:exc.pos].strip() continue if candidate.endswith("}"): candidate = candidate[:-1].rstrip() continue break if isinstance(obj, dict) and obj.get("name") in VALID_TOOL_CALL_NAMES: return obj break return None def _extract_tool_call_json_from_text(text: str) -> Optional[dict]: """Recover a tool call JSON object when XML tags are missing or malformed.""" if '"name"' not in text and "'name'" not in text: return None for match in re.finditer(r"\{", text): candidate = _load_tool_call_json(text[match.start():]) if candidate: return candidate return None def _clean_search_query(text: str) -> str: text = re.sub(r"\[IMAGE\s+\d+\]", " ", text, flags=re.IGNORECASE) text = re.sub(r"<[^>]+>", " ", text) text = re.sub(r"\s+", " ", text).strip(" .:-\n\t") return text[:240].strip() def _extract_visual_search_hints(text: str) -> list[str]: hints: list[str] = [] patterns = [ r'visible text\s+"([^"]+)"', r'prominent text\s+"([^"]+)"', r'reads\s+"([^"]+)"', r'\(([^()]{3,80})\)', r'\b(?:as|is|likely|identified as|reads)\s+(?:the\s+)?([A-Z][A-Za-z0-9&.\' -]{2,60})', ] for pattern in patterns: for match in re.finditer(pattern, text): value = _clean_search_query(match.group(1)) if value and value.lower() not in {"the", "a", "an"} and value not in hints: hints.append(value) for match in re.finditer(r"[\u4e00-\u9fff]{2,20}", text): value = match.group(0) if value not in hints: hints.append(value) return hints[:5] def _recover_intent_tool_call(text: str, user_query: str = "") -> Optional[dict]: """Infer a tool call only when the model clearly announces a tool intent.""" if "" in text: return None intent_prefix = ( r"\b(?:i|we)\s+" r"(?:will|would|need to|should|can|must|am going to|plan to)\s+" r"(?:directly\s+)?(?:use|perform|conduct|call|do)\s+(?:an?\s+)?" ) asks_web_search = bool(re.search( intent_prefix + r"`?(?:web_search|web search|search the web)`?\b", text, flags=re.IGNORECASE, )) asks_image_search = bool(re.search( intent_prefix + r"`?(?:image_search|image search|reverse image search)`?\b", text, flags=re.IGNORECASE, )) if asks_web_search: query_parts = [] cleaned_user_query = _clean_search_query(user_query) if cleaned_user_query: query_parts.append(cleaned_user_query) query_parts.extend(_extract_visual_search_hints(text)) if not query_parts: match = re.search( r"(?:web_search|web search|search the web)\s+(?:to|for)?\s*" r"(?:find|determine|identify|look up|verify|confirm)?\s*([^.\n<]{8,180})", text, flags=re.IGNORECASE, ) if match: query_parts.append(_clean_search_query(match.group(1))) query = _clean_search_query(" ".join(part for part in query_parts if part)) if query: return {"name": "web_search", "arguments": {"query": query}} if asks_image_search: return {"name": "image_search", "arguments": {"bbox": [0, 0, 1000, 1000]}} return None def parse_tool_call(text: str, user_query: str = "") -> Optional[dict]: """Parse or conservatively recover a tool call from model output.""" match = re.search(r'\s*(.*?)\s*', text, re.DOTALL) if match: parsed = _load_tool_call_json(match.group(1).strip()) if parsed: return parsed parsed = _extract_tool_call_json_from_text(text) if parsed: return parsed return _recover_intent_tool_call(text, user_query=user_query) def sanitize_assistant_tool_turn(text: str) -> str: """工具回合只保留 assistant 自己的工具调用,截掉模型续写的用户回合。""" text = text or "" match = re.search(r"\s*.*?", text, re.DOTALL | re.IGNORECASE) if match: return text[:match.end()].strip() cutoff = len(text) for marker in ("<|im_start|>user", "", "<|im_end|>"): idx = text.find(marker) if idx >= 0: cutoff = min(cutoff, idx) return text[:cutoff].strip() def extract_tool_name_and_args(tool_call: Optional[dict]) -> tuple[str, dict]: """兼容嵌套 `arguments` 与扁平 JSON 两种 tool_call。""" if not isinstance(tool_call, dict): return "", {} tool_name = tool_call.get("name", "") arguments = tool_call.get("arguments") if isinstance(arguments, dict): return tool_name, arguments return tool_name, {k: v for k, v in tool_call.items() if k != "name"} def extract_tool_call_from_result(result: dict, assistant_text: str, user_query: str = "") -> Optional[dict]: """优先读取结构化 tool_calls,回退到文本里的 。""" tool_calls = result.get("tool_calls") or [] if tool_calls: first = tool_calls[0] if isinstance(first, dict): return { "name": first.get("name", ""), "arguments": first.get("arguments", {}), } return parse_tool_call(assistant_text, user_query=user_query) FORMAT_REPAIR_PROMPT = """ Your previous assistant message was empty or did not follow the required protocol, so no valid tool call or final answer could be parsed. Continue from the current evidence. Return exactly ONE of the following forms: brief reasoning {"name": "", "arguments": {...}} or: brief reasoning final answer only Do not output an empty message. Do not output an empty . Do not put the final answer only inside . """ FINAL_ANSWER_PROMPT = """ No more tool calls are available. Use only the video evidence and tool results already collected. You must now provide the final answer in exactly this form: brief reasoning based on the collected evidence final answer only Do not call any tool. Do not search again. Do not output an empty message. """ FINAL_ANSWER_REPAIR_PROMPT = """ Your previous message still did not provide a final answer. Tool calls are disabled. Return exactly: brief reasoning based on the collected evidence final answer only Do not call any tool. """ FINAL_ONLY_SYSTEM_PROMPT = """Use the provided evidence to answer the question. Continue the assistant prefix with the final answer text only. """ FINAL_ONLY_REPAIR_PROMPT = """Your previous message still was not a valid final answer. Use the provided evidence and continue the assistant prefix with the final answer text only. """ def _has_answer_tag(text: str) -> bool: return bool(re.search(r".*?", text or "", re.DOTALL)) def _extract_last_answer_tag(text: str) -> str: matches = list(re.finditer(r"(.*?)", text or "", re.DOTALL)) if not matches: return "" return matches[-1].group(1).strip() def _has_tool_call_tag(text: str) -> bool: return " bool: text = (text or "").strip() if not text: return True if _has_tool_call_tag(text): return True if "" in text and "" not in text: return True match = re.search(r"\s*(.*?)\s*", text, re.DOTALL) if match and not match.group(1).strip(): return True return False def _looks_like_unfinished_tool_intent(text: str) -> bool: """Detect prose that asks to use a tool but failed to emit a valid call.""" text = (text or "").lower() if not text: return True intent_patterns = [ r"\b(?:i|we)\s+(?:will|would|need to|should|must|am going to|plan to)\s+" r"(?:use|call|perform|conduct|do)\s+(?:the\s+)?(?:web_search|image_search|web search|image search|reverse image search|zoom_in|zoom in|find_frame|choose_frames)", r"\b(?:next|now)\s+(?:i|we)\s+(?:will|would|should|need to)\s+" r"(?:search|zoom|inspect|call|use|look up)", ] return any(re.search(pattern, text) for pattern in intent_patterns) def _looks_like_unfinished_no_tool_answer(text: str) -> bool: """Detect no-tool prose that is still analysis instead of a final answer.""" raw = (text or "").strip() if not raw: return True if _has_answer_tag(raw): return False lowered = raw.lower() # A reasoning block without an answer tag is usually an unfinished tool-mode # turn, even if it contains useful partial conclusions. if re.search(r"", raw, flags=re.IGNORECASE): return True plain = _strip_protocol_text(raw) if not plain: return True if len(plain) > 1200: return True if plain[-1:] in {",", ";", ":"}: return True unfinished_patterns = [ r"\b(?:i|we)\s+(?:need to|should|must|will|would|can|could|am going to|plan to)\s+" r"(?:find|confirm|verify|check|search|look up|identify|zoom|inspect|use|call)", r"\b(?:let me|let's|now i|next i|first i|i'll|i will)\b", r"\b(?:it could be|could be|might be|may be|likely|probably|possibly)\b", r"\b(?:wait|actually|however|but hold on|not sure|unclear)\b", r"\bcommon .* (?:are|include):", r"\b(?:first|second|third),?\s+(?:i|we)\b", r"\b(?:step|plan|search strategy)\b", r"\n\s*(?:1\.|2\.|3\.|\*)\s+", ] return any(re.search(pattern, lowered, flags=re.DOTALL) for pattern in unfinished_patterns) def _strip_protocol_text(text: str) -> str: text = re.sub(r"<\|im_end\|>", " ", text or "") text = re.sub(r"", " ", text, flags=re.IGNORECASE) text = re.sub(r"", " ", text, flags=re.IGNORECASE) text = re.sub(r"", " ", text, flags=re.IGNORECASE) return " ".join(text.split()).strip() def _recover_answer_from_no_tool_output(result: dict, assistant_text: str) -> Optional[str]: """Return a final-answer text when no tool call exists but the model clearly answered.""" assistant_text = assistant_text or "" if _has_answer_tag(assistant_text): if _has_tool_call_tag(assistant_text): answer = _extract_last_answer_tag(assistant_text) return f"{answer}" if answer else None return assistant_text content = _message_content_to_text(result.get("content")).strip() reasoning = _message_content_to_text(result.get("reasoning_content")).strip() # If the model produced visible content without tool intent, treat it as the # final answer. This avoids labeling plain-answer turns as protocol errors. content_plain = _strip_protocol_text(content) if ( content_plain and not _has_tool_call_tag(content) and not _looks_like_unfinished_tool_intent(content) and not _looks_like_unfinished_no_tool_answer(content) ): return content # Some Qwen3/SGLang deployments put the final answer entirely in # reasoning_content and leave content=None. Recover only direct-looking # answers; malformed tool shells and tool intents should be retried instead. reasoning_plain = _strip_protocol_text(reasoning) if ( reasoning_plain and len(reasoning_plain) <= 800 and not _looks_like_malformed_tool_shell(reasoning) and not _looks_like_unfinished_tool_intent(reasoning) and not _looks_like_unfinished_no_tool_answer(reasoning) ): return f"\n{reasoning}\n\n{reasoning_plain}" return None def _make_format_repair_turn(assistant_text: str) -> tuple[dict | None, dict, str]: """Build messages/output snippet for one protocol-repair retry.""" assistant_message = {"role": "assistant", "content": assistant_text} if assistant_text else None user_message = {"role": "user", "content": FORMAT_REPAIR_PROMPT} output_snippet = ( f"{assistant_text}<|im_end|><|im_start|>user\n" f"{FORMAT_REPAIR_PROMPT}<|im_end|>\n<|im_start|>assistant\n" ) return assistant_message, user_message, output_snippet def _build_final_only_user_prompt(question: str, output_parts: list[str], max_chars: int = 8000) -> str: transcript = "".join(output_parts) responses = re.findall(r"\s*(.*?)\s*", transcript, re.DOTALL) cleaned_responses = [] for idx, response in enumerate(responses, 1): text = re.sub(r"<\|im_start\|>.*", "", response, flags=re.DOTALL) text = re.sub(r"\n{3,}", "\n\n", text).strip() if text: cleaned_responses.append(f"[Evidence {idx}]\n{text}") evidence_text = "\n\n".join(cleaned_responses) if len(evidence_text) > max_chars: evidence_text = evidence_text[-max_chars:] evidence_text = "[Evidence truncated to the most recent collected results]\n" + evidence_text return ( "Question:\n" f"{question}\n\n" "Evidence:\n" f"{evidence_text or 'No textual evidence was collected.'}" ) def _with_final_only_extra_body(kwargs: dict) -> dict: extra_body = dict(kwargs.get("extra_body") or {}) chat_template_kwargs = dict(extra_body.get("chat_template_kwargs") or {}) chat_template_kwargs["enable_thinking"] = False extra_body["chat_template_kwargs"] = chat_template_kwargs extra_body["continue_final_message"] = True extra_body["stop"] = ["", " Optional[str]: candidates = [ _message_content_to_text(result.get("content")).strip(), (assistant_text or "").strip(), ] for text in candidates: if not text: continue text = re.sub(r"^```(?:json)?\s*", "", text.strip(), flags=re.IGNORECASE) text = re.sub(r"\s*```$", "", text.strip()) try: data = json.loads(text) except json.JSONDecodeError: continue if not isinstance(data, dict): continue reasoning = str(data.get("reasoning") or "").strip() answer = str(data.get("answer") or "").strip() if not answer: continue if _has_tool_call_tag(reasoning) or _has_tool_call_tag(answer): continue return f"{reasoning}\n{answer}" return None async def call_final_only_answer( client: str, model: str, base_url: str, api_key: str, question: str, output_parts: list[str], tool_calls: list[dict], sample_id: str, retry_limit: int, **kwargs, ) -> dict: final_messages = [ {"role": "system", "content": FINAL_ONLY_SYSTEM_PROMPT}, {"role": "user", "content": _build_final_only_user_prompt(question, output_parts)}, {"role": "assistant", "content": ""}, ] transcript_parts = [ f"<|im_start|>system\n{FINAL_ONLY_SYSTEM_PROMPT}<|im_end|>\n", f"<|im_start|>user\n{final_messages[1]['content']}<|im_end|>\n<|im_start|>assistant\n", ] last_result = {} last_output = "" retries_used = 0 for attempt in range(max(0, retry_limit) + 1): api_kwargs = dict(kwargs) api_kwargs["request_debug_context"] = { "sample_id": sample_id, "turn": f"final_only_{attempt}", "task_kind": "video_final_only", } api_kwargs["extra_body"] = _with_final_only_extra_body(api_kwargs) if client == "gemini": result = await call_gemini_api(final_messages, model, base_url, api_key, **api_kwargs) elif client == "azure": result = await call_azure_api(final_messages, model, base_url, api_key, **api_kwargs) elif client == "vertex": result = await call_vertex_gemini_api(final_messages, model, base_url, api_key, **api_kwargs) elif client == "gateway": result = await call_gateway_api(final_messages, model, base_url, api_key, **api_kwargs) else: result = await call_openai_api(final_messages, model, base_url, api_key, **api_kwargs) last_result = result last_output = result.get("full_text") or result.get("content", "") if result.get("error"): transcript_parts.append(last_output + "<|im_end|>") break answer_text = _extract_last_answer_tag(last_output) or _strip_protocol_text(last_output) recovered_answer = None if ( answer_text and not _has_tool_call_tag(last_output) and not _looks_like_unfinished_tool_intent(last_output) ): recovered_answer = f"{answer_text}" if not recovered_answer: recovered_answer = _recover_final_only_json_output(result, last_output) if not recovered_answer: recovered_answer = _recover_answer_from_no_tool_output(result, last_output) if recovered_answer and not _has_tool_call_tag(recovered_answer): if recovered_answer != last_output: last_output = recovered_answer transcript_output = last_output if transcript_output.startswith(""): transcript_output = transcript_output[len(""):] transcript_parts.append(transcript_output + "<|im_end|>") final_messages[-1] = {"role": "assistant", "content": last_output} return { "output": last_output, "transcript": "".join(transcript_parts), "messages": final_messages, "finish_reason": "answer", "result": result, "retries": retries_used, } transcript_parts.append(last_output + "<|im_end|>") final_messages[-1] = {"role": "assistant", "content": f"{last_output}"} if attempt < retry_limit: final_messages.append({"role": "user", "content": FINAL_ONLY_REPAIR_PROMPT}) final_messages.append({"role": "assistant", "content": ""}) transcript_parts.append( f"\n<|im_start|>user\n{FINAL_ONLY_REPAIR_PROMPT}<|im_end|>\n<|im_start|>assistant\n" ) retries_used += 1 return { "output": last_output, "transcript": "".join(transcript_parts), "messages": final_messages, "finish_reason": "error" if last_result.get("error") else "max_turns", "result": last_result, "retries": retries_used, } # Summarization prompt (mmsearch_r1 style - 5 sentence summary) # Used for both per-URL summaries and final summary SUMMARY_SYSTEM_PROMPT = """You are a helpful assistant. Your task is to summarize the main content of the given web page in no more than five sentences. Your summary should cover the overall key points of the page, not just parts related to the user's question. If any part of the content is helpful for answering the user's question, be sure to include it clearly in the summary. Do not ignore relevant information, but also make sure the general structure and main ideas of the page are preserved. Your summary should be concise, factual, and informative.""" SUMMARY_USER_PROMPT = """Webpage Content (first {content_limit} characters) is: {content} Question: {query}""" # Skip extensions for non-HTML resources # Aligned with server config.json excluded_extensions SKIP_EXTENSIONS = ['.pdf', '.doc', '.docx', '.ppt', '.pptx', '.xls', '.xlsx', '.jpg', '.jpeg', '.png', '.gif'] # Global browser context for reuse _playwright = None _browser = None _browser_context = None _browser_lock = asyncio.Lock() _playwright_fetch_semaphore = asyncio.Semaphore( max(1, int(os.environ.get("WEB_FETCH_PLAYWRIGHT_CONCURRENCY", "4"))) ) def _is_browser_closed_error(error: Exception) -> bool: """Return True when Playwright's shared browser/context must be recreated.""" text = str(error) closed_markers = [ "Target page, context or browser has been closed", "BrowserContext.new_page", "BrowserType.launch", "browser has been closed", "context has been closed", ] return any(marker in text for marker in closed_markers) async def _get_browser_context(): """Get or create a shared browser context.""" global _playwright, _browser, _browser_context async with _browser_lock: if _browser is not None and not _browser.is_connected(): await _cleanup_browser_unlocked() if _browser_context is None: _playwright = await async_playwright().start() _browser = await _playwright.chromium.launch(headless=True) # Use Chrome on Windows user agent _browser_context = await _browser.new_context( user_agent="Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.6261.57 Safari/537.36" ) return _browser_context async def _cleanup_browser(): """Cleanup browser resources.""" async with _browser_lock: await _cleanup_browser_unlocked() async def _cleanup_browser_unlocked(): """Cleanup browser resources. Caller must hold _browser_lock.""" global _playwright, _browser, _browser_context if _browser_context: try: await asyncio.wait_for(_browser_context.close(), timeout=5.0) except Exception: pass _browser_context = None if _browser: try: await asyncio.wait_for(_browser.close(), timeout=5.0) except Exception: pass _browser = None if _playwright: try: await asyncio.wait_for(_playwright.stop(), timeout=5.0) except Exception: pass _playwright = None # Bot detection status codes BOT_DETECTION_CODES = {403, 429, 406, 418, 421, 451} # Tracking domains to block TRACKING_DOMAINS = [ # Google tracking and ads "*google-analytics.com*", "*googletagmanager.com*", "*doubleclick.net*", "*googleadservices.com*", "*googlesyndication.com*", "*googletagservices.com*", # Facebook/Meta tracking "*facebook.com/tr*", "*connect.facebook.net*", "*facebook.net*", # Other major tracking networks "*hotjar.com*", "*mixpanel.com*", "*segment.com*", "*amplitude.com*", "*fullstory.com*", "*logrocket.com*", "*mouseflow.com*", # Ad networks and exchanges "*adsystem.com*", "*pubmatic.com*", "*rubiconproject.com*", "*amazon-adsystem.com*", "*adsafeprotected.com*", # Analytics and tracking "*newrelic.com*", "*nr-data.net*", "*pingdom.net*", "*optimizely.com*", "*quantserve.com*", "*scorecardresearch.com*" ] async def fetch_url_content(url: str, timeout_sec: int = 20) -> Optional[str]: """Fetch URL content using Playwright with JS rendering. Returns None for skip cases (non-HTML, skip extensions). Raises URLFetchError for retriable errors (timeout, HTTP error, etc.). """ # Skip case: non-HTML resource extensions (don't retry) if any(url.lower().endswith(ext) for ext in SKIP_EXTENSIONS): return None async with _playwright_fetch_semaphore: return await _fetch_url_content_with_playwright(url, timeout_sec) async def _fetch_url_content_with_playwright(url: str, timeout_sec: int = 20) -> Optional[str]: """Fetch one URL while holding the Playwright fetch semaphore.""" from bs4 import BeautifulSoup page = None try: context = await _get_browser_context() page = await context.new_page() # Block unnecessary resources excluded_resource_types = ["image", "stylesheet", "font", "media", "websocket", "eventsource", "manifest"] async def resource_handler(route): if route.request.resource_type in excluded_resource_types: await route.abort() else: await route.continue_() await page.route("**/*", resource_handler) # Block tracking domains async def tracking_handler(route): await route.abort() for domain_pattern in TRACKING_DOMAINS: try: await page.route(domain_pattern, tracking_handler) except: pass # wait_until='domcontentloaded' try: response = await asyncio.wait_for( page.goto(url, timeout=timeout_sec * 1000, wait_until='domcontentloaded'), timeout=25.0 ) except asyncio.TimeoutError: raise URLFetchError(f"Timeout fetching {url}") if response is None: raise URLFetchError(f"No response from {url}") if response.status >= 400: raise URLFetchError(f"HTTP {response.status} from {url}") # Skip case: non-HTML content type (don't retry) content_type = response.headers.get('content-type', '') if 'text/html' not in content_type: return None try: html = await asyncio.wait_for(page.content(), timeout=15.0) except asyncio.TimeoutError: raise URLFetchError(f"Content timeout for {url}") # Parse with BeautifulSoup soup = BeautifulSoup(html, "lxml") for script in soup(["script", "style"]): script.decompose() raw_content = soup.get_text(separator='\n', strip=True) if not raw_content: raise URLFetchError(f"Empty content from {url}") return raw_content except URLFetchError: raise # Re-raise URLFetchError for retry logic except Exception as e: if _is_browser_closed_error(e): await _cleanup_browser() raise URLFetchError(f"Error fetching {url}: {e}") finally: if page: try: await asyncio.wait_for(page.close(), timeout=5.0) except: pass def _clean_think_blocks(text: str) -> str: """Remove ... blocks from response.""" if not text: return text return re.sub(r'.*?', '', text, flags=re.DOTALL | re.IGNORECASE).strip() async def summarize_content( query: str, content: str, summarizer_base_url: str, summarizer_model: str, content_limit: int = 30000, max_retries: int = 5, # Aligned with server config.json max_try_times=5 ) -> Optional[str]: """Summarize content with LLM using mmsearch_r1 style prompt.""" # Server has no minimum content check - send whatever content we have if not content: return None messages = [ {"role": "system", "content": SUMMARY_SYSTEM_PROMPT}, {"role": "user", "content": SUMMARY_USER_PROMPT.format( query=query, content=content[:content_limit], content_limit=content_limit, )}, ] # Retry logic aligned with server's LLMGenerator (fixed 1s backoff) # Note: Server does NOT set temperature - uses API default (typically 1.0) # For Qwen3 models, disable thinking mode extra_body = None if "qwen3" in summarizer_model.lower(): extra_body = {"chat_template_kwargs": {"enable_thinking": False}} for attempt in range(max_retries): try: result = await call_openai_api( messages=messages, model=summarizer_model, base_url=summarizer_base_url, api_key="", max_tokens=8192, extra_body=extra_body, ) if result.get("error"): if attempt < max_retries - 1: await asyncio.sleep(1) # Fixed 1s backoff like server continue return None # Clean blocks from response return _clean_think_blocks(result["content"]) except Exception as e: if attempt < max_retries - 1: await asyncio.sleep(1) # Fixed 1s backoff like server else: print(f"[SUMMARIZER ERROR] {e}") return None return None async def call_text_search( query: str, serper_api_key: str, summarizer_base_url: str, summarizer_model: str, serper_semaphore: asyncio.Semaphore, serper_concurrency: int = 5, top_k: int = 3, content_limit: int = 30000, search_cache: Optional["SearchCache"] = None, max_serper_attempts: int = 3, ) -> str: """Call Google Serper API for text search, fetch each URL, summarize each, then generate final summary. Flow: 1. Check cache (if enabled) 2. Search via Serper to get URLs (with retry) 3. For each URL: fetch content and summarize with LLM 4. Generate final summary from all per-URL summaries 5. Store in cache (if enabled) """ # Step 0: Check cache if search_cache: cached = await search_cache.get(query, top_k, summarizer_model) if cached: return f"Found cached summary for query: {query}\n{cached}" # Step 1: Search via Serper (rate limited by semaphore) # Search API doesn't use "num" parameter - gets default results then slices to top_k # Retry logic with max_attempts=3 and exponential backoff url = "https://google.serper.dev/search" headers = {"X-API-KEY": serper_api_key, "Content-Type": "application/json"} payload = {"q": query} all_search_results = None search_results = None last_error = None for attempt in range(max_serper_attempts): async with serper_semaphore: try: async with aiohttp.ClientSession() as session: async with session.post(url, headers=headers, json=payload) as resp: if resp.status != 200: error_text = await resp.text() raise Exception(f"HTTP {resp.status}: {error_text}") data = await resp.json() organic_results = data.get("organic", []) if not organic_results: return f"No search results found for query: {query}" search_results = organic_results break # Success, exit retry loop except Exception as e: last_error = e if attempt < max_serper_attempts - 1: # Exponential backoff: 1.0 * (attempt + 1) seconds backoff = 1.0 * (attempt + 1) print(f"[SERPER RETRY] Attempt {attempt + 1}/{max_serper_attempts} failed: {e}, retrying in {backoff}s...", flush=True) await asyncio.sleep(backoff) # If all retries failed, raise error if search_results is None: error_msg = str(last_error) if last_error and str(last_error) else "Unknown error" raise FatalAPIError(f"[SERPER ERROR] Failed after {max_serper_attempts} attempts: {error_msg}") from last_error result = await summarize_serper_organic_results( query=query, organic_results=search_results, summarizer_base_url=summarizer_base_url, summarizer_model=summarizer_model, top_k=top_k, content_limit=content_limit, ) # Store in cache if search_cache and not result.startswith("Error:"): await search_cache.set(query, top_k, summarizer_model, result) return f"Final summary generated for query: {query}\n{result}" async def summarize_serper_organic_results( query: str, organic_results: list[dict], summarizer_base_url: str, summarizer_model: str, top_k: int = 3, content_limit: int = 30000, ) -> str: """Fetch Serper organic result URLs and summarize them with the text_search_tool pipeline.""" all_search_results = [ {'title': res.get('title', ''), 'snippet': res.get('snippet', ''), 'link': res.get('link', '')} for res in organic_results ] search_results = all_search_results[:top_k] # Step 2: Fetch each URL and summarize in PARALLEL # 3 retries with exponential backoff async def fetch_and_summarize(item, max_retries: int = 3): item_url = item.get("link", "") if not item_url: return None for attempt in range(max_retries): try: content = await fetch_url_content(item_url) if not content: # No content but no exception - don't retry (e.g., non-HTML, skip extensions) _web_fetch_stats.record_skip() return None summary = await summarize_content( query=query, content=content, summarizer_base_url=summarizer_base_url, summarizer_model=summarizer_model, content_limit=content_limit, ) _web_fetch_stats.record_success() return summary except Exception as e: if attempt == max_retries - 1: print(f"[WEB] Failed to fetch {item_url} after {max_retries} attempts: {e}", flush=True) _web_fetch_stats.record_failure(str(e)) return None # Exponential backoff: min(2^attempt, 5) backoff_time = min(2 ** attempt, 5) await asyncio.sleep(backoff_time) return None tasks = [fetch_and_summarize(item) for item in search_results] summaries = await asyncio.gather(*tasks, return_exceptions=True) # Filter out exceptions summaries = [s if not isinstance(s, Exception) else None for s in summaries] # Step 3: Format all content and generate final summary # Use ALL search results for Search Results section # But only top_k summaries in Web Page Summaries section all_content = f"Search Query: {query}\n\n" all_content += "--- Search Results ---\n" for i, result in enumerate(all_search_results, 1): all_content += f"[{i}] Title: {result.get('title', '')}\n" all_content += f" Snippet: {result.get('snippet', 'No snippet')}\n" all_content += f" Link: {result.get('link', 'No link')}\n\n" all_content += "--- Web Page Summaries ---\n" for summary in summaries: if summary: all_content += f"{summary}\n\n" # Generate final summary using same prompt as per-URL summaries final_summary = await summarize_content( query=query, content=all_content, summarizer_base_url=summarizer_base_url, summarizer_model=summarizer_model, content_limit=content_limit, ) # Return error if final summary fails if not final_summary: return f"Error: Failed to generate final summary for query: {query}" return final_summary def _coerce_tavily_option(value: str): """把 CLI/env 字符串转成 Tavily API 可接受的布尔值或枚举字符串。""" normalized = (value or "").strip().lower() if normalized in {"", "none", "null"}: return False if normalized in {"1", "true", "yes", "y", "on"}: return True if normalized in {"0", "false", "no", "n", "off"}: return False return normalized def _split_csv_env(value: str) -> list[str]: """解析逗号分隔的域名过滤配置。""" return [item.strip() for item in (value or "").split(",") if item.strip()] def _split_secret_list(value: str) -> list[str]: """解析 API key 列表,支持逗号、分号和空白分隔。""" return [item.strip() for item in re.split(r"[,;\s]+", value or "") if item.strip()] def load_tavily_api_keys() -> list[str]: """从环境变量和可选文件加载 Tavily API key 池。""" keys: list[str] = [] keys.extend(_split_secret_list(os.environ.get("TAVILY_API_KEYS", ""))) key_file = os.environ.get("TAVILY_API_KEY_FILE", DEFAULT_TAVILY_API_KEY_FILE).strip() if key_file: try: with open(key_file, "r", encoding="utf-8") as f: for line in f: item = line.strip() if not item or item.startswith("#"): continue keys.extend(_split_secret_list(item)) except OSError as e: print(f"[WARN] Failed to read TAVILY_API_KEY_FILE={key_file}: {e}", flush=True) single_key = os.environ.get("TAVILY_API_KEY", "").strip() if single_key: keys.append(single_key) deduped = [] seen = set() for key in keys: if key not in seen: deduped.append(key) seen.add(key) return deduped class TavilyApiKeyPool: """Tavily API key 轮询池,遇到限流或服务错误时临时冷却单个 key。""" def __init__(self, keys: list[str], cooldown_seconds: int = 60): self.keys = list(keys) self.cooldown_seconds = max(1, int(cooldown_seconds or 60)) self._cooldown_until = {key: 0.0 for key in self.keys} self._cursor = 0 self._lock = asyncio.Lock() def __len__(self) -> int: return len(self.keys) async def next_key(self, exclude: Optional[set[str]] = None) -> str: if not self.keys: return "" async with self._lock: excluded = exclude or set() available_keys = [key for key in self.keys if key not in excluded] if not available_keys: return "" now = time.time() n = len(self.keys) for _ in range(n): key = self.keys[self._cursor % n] self._cursor += 1 if key in excluded: continue if self._cooldown_until.get(key, 0.0) <= now: return key # 如果候选 key 全部都在冷却中,仍按最早恢复顺序尝试一次,确保本次请求 # 能明确判断“全池失败”并触发上层止损,而不是永久等待。 key = min(available_keys, key=lambda k: self._cooldown_until.get(k, 0.0)) self._cursor += 1 return key async def mark_failure(self, key: str, status: int | None = None) -> None: if not key: return if status in {401, 403, 432}: cooldown = max(self.cooldown_seconds, 3600) elif status == 429: cooldown = self.cooldown_seconds elif status and status >= 500: cooldown = min(self.cooldown_seconds, 30) else: cooldown = min(self.cooldown_seconds, 10) async with self._lock: self._cooldown_until[key] = time.time() + cooldown def _truncate_text(value: str, limit: int = 1400) -> str: """限制 Tavily 返回给模型的单条结果长度,避免工具响应过长。""" text = re.sub(r"\s+", " ", value or "").strip() if len(text) <= limit: return text return text[: limit - 3].rstrip() + "..." async def call_tavily_search( query: str, session: aiohttp.ClientSession, api_key: str = "", api_key_pool: Optional[TavilyApiKeyPool] = None, search_depth: str = "advanced", max_results: int = 8, include_answer: str = "advanced", include_raw_content: str = "false", topic: str = "general", auto_parameters: bool = False, include_domains: Optional[list[str]] = None, exclude_domains: Optional[list[str]] = None, timeout_seconds: int = 60, search_cache: Optional["SearchCache"] = None, ) -> str: """调用 Tavily Search API,直接返回 answer 与搜索结果,不再二次 summarizer。""" clean_query = query.strip().replace("\n", " ") if not clean_query: return "Error: Tavily search requires a non-empty query." include_answer_value = _coerce_tavily_option(include_answer) include_raw_value = _coerce_tavily_option(include_raw_content) max_results = max(1, min(int(max_results or 8), 20)) cache_model = ( f"tavily:depth={search_depth}:answer={include_answer_value}:" f"raw={include_raw_value}:topic={topic}:auto={int(bool(auto_parameters))}" ) if search_cache: cached = await search_cache.get(clean_query, max_results, cache_model) if cached: return cached if not api_key and (not api_key_pool or len(api_key_pool) == 0): return "Error: TAVILY_API_KEY/TAVILY_API_KEYS not configured." payload = { "query": clean_query, "topic": topic, "search_depth": search_depth, "max_results": max_results, "include_answer": include_answer_value, "include_raw_content": include_raw_value, "auto_parameters": bool(auto_parameters), } if include_domains: payload["include_domains"] = include_domains if exclude_domains: payload["exclude_domains"] = exclude_domains max_attempts = max(1, len(api_key_pool) if api_key_pool else 1) last_error = "" data = None attempted_keys: set[str] = set() for attempt in range(max_attempts): selected_key = await api_key_pool.next_key(exclude=attempted_keys) if api_key_pool else api_key if not selected_key: break attempted_keys.add(selected_key) try: headers = { "Authorization": f"Bearer {selected_key}", "Content-Type": "application/json", } async with session.post( "https://api.tavily.com/search", headers=headers, json=payload, timeout=aiohttp.ClientTimeout(total=timeout_seconds), ) as resp: body = await resp.text() if resp.status == 200: try: data = json.loads(body) except json.JSONDecodeError as e: last_error = f"Error: Tavily search returned invalid JSON: {e}" if api_key_pool: await api_key_pool.mark_failure(selected_key, None) if attempt < max_attempts - 1: continue break break last_error = f"Error: Tavily search returned HTTP {resp.status}: {body[:500]}" if api_key_pool: await api_key_pool.mark_failure(selected_key, resp.status) if attempt < max_attempts - 1: continue break return last_error except asyncio.TimeoutError: last_error = f"Error: Tavily search timeout for query: {clean_query[:100]}" if api_key_pool: await api_key_pool.mark_failure(selected_key, None) if attempt < max_attempts - 1: continue break return last_error except aiohttp.ClientError as e: last_error = f"Error: Tavily search failed: {e}" if api_key_pool: await api_key_pool.mark_failure(selected_key, None) if attempt < max_attempts - 1: continue break return last_error except Exception as e: last_error = f"Error: Tavily search failed: {e}" if api_key_pool: await api_key_pool.mark_failure(selected_key, None) if attempt < max_attempts - 1: continue break return last_error if data is None: if api_key_pool: raise FatalAPIError( "[TAVILY ERROR] All Tavily API keys failed for one search request " f"after {len(attempted_keys)}/{max_attempts} attempts. Last error: " f"{last_error or 'no key available'}" ) return last_error or "Error: Tavily search failed before receiving a response." lines = [ f'Web Search Results for "{clean_query}" ' f"(Tavily, depth={search_depth}, max_results={max_results}):" ] answer = (data.get("answer") or "").strip() if answer: lines.extend(["", f"Answer: {answer}"]) results = data.get("results") or [] if results: lines.append("") lines.append("Results:") for idx, item in enumerate(results, 1): title = (item.get("title") or "").strip() url = (item.get("url") or "").strip() score = item.get("score") content = _truncate_text(item.get("content") or "") raw_content = item.get("raw_content") lines.append(f"{idx}. {title or 'Untitled'}") if url: lines.append(f" URL: {url}") if score is not None: lines.append(f" Score: {score}") if content: lines.append(f" Content: {content}") if raw_content: lines.append(f" Raw Content: {_truncate_text(str(raw_content), 1800)}") else: lines.extend(["", "No Tavily results returned."]) request_id = data.get("request_id") response_time = data.get("response_time") if request_id or response_time: suffix = [] if response_time: suffix.append(f"response_time={response_time}") if request_id: suffix.append(f"request_id={request_id}") lines.extend(["", "Metadata: " + ", ".join(suffix)]) result = "\n".join(lines) if search_cache: await search_cache.set(clean_query, max_results, cache_model, result) return result def _default_web_search_backend() -> str: """读取 web_search 默认后端;未配置时使用企业 Serper 网关。""" return os.environ.get( "VIDEO_DR_WEB_SEARCH_BACKEND", os.environ.get("WEB_SEARCH_BACKEND", "serper_gateway"), ) def _normalize_web_search_backend(raw_backend: str) -> str: """归一化 web_search 后端别名。""" backend = (raw_backend or _default_web_search_backend()).strip().lower().replace("-", "_") aliases = { "gateway": "serper_gateway", "gateway_serper": "serper_gateway", "internal_serper": "serper_gateway", "company_serper": "serper_gateway", } return aliases.get(backend, backend) def _get_gateway_text_search_config(kwargs: dict) -> tuple[str, str, str, str]: """从 kwargs/env/VideoDR 配置读取企业 Serper 网关参数。""" gateway_url = ( kwargs.get("serper_gateway_url") or os.environ.get("SERPER_GATEWAY_URL") or os.environ.get("GATEWAY_URL") or VIDEO_DR_GATEWAY_URL ) gateway_username = ( kwargs.get("serper_gateway_username") or os.environ.get("SERPER_GATEWAY_USERNAME") or os.environ.get("GATEWAY_USERNAME") or VIDEO_DR_GATEWAY_USERNAME ) gateway_userid = ( kwargs.get("serper_gateway_userid") or os.environ.get("SERPER_GATEWAY_USERID") or os.environ.get("GATEWAY_USERID") or VIDEO_DR_GATEWAY_USERID ) gateway_token = ( kwargs.get("serper_gateway_token") or os.environ.get("SERPER_GATEWAY_TOKEN") or os.environ.get("GATEWAY_TOKEN") or VIDEO_DR_GATEWAY_TOKEN ) return gateway_url, gateway_username, gateway_userid, gateway_token def _parse_gateway_model_output(gateway_resp: dict) -> dict: """解析企业网关返回的 model_output 字段。""" raw_output = gateway_resp.get("model_output", {}) if isinstance(raw_output, str): return json.loads(raw_output or "{}") if isinstance(raw_output, dict): return raw_output raise ValueError(f"Unexpected gateway model_output type: {type(raw_output).__name__}") def _format_serper_text_results(data: dict, max_results: int = 8) -> str: """把 Serper 文搜结果格式化成工具响应中的可读文本。""" lines = [] answer_box = data.get("answerBox") or {} if answer_box: answer = answer_box.get("answer") or answer_box.get("snippet") or "" title = answer_box.get("title") or "" link = answer_box.get("link") or "" if answer: lines.append("Quick Answer:") if title: lines.append(f" Title: {title}") lines.append(f" Answer: {answer}") if link: lines.append(f" URL: {link}") lines.append("") knowledge_graph = data.get("knowledgeGraph") or {} if knowledge_graph: kg_title = knowledge_graph.get("title") or "" kg_type = knowledge_graph.get("type") or "" kg_desc = knowledge_graph.get("description") or "" kg_link = knowledge_graph.get("descriptionLink") or knowledge_graph.get("website") or "" lines.append("Knowledge Graph:") if kg_title: lines.append(f" Title: {kg_title}") if kg_type: lines.append(f" Type: {kg_type}") if kg_desc: lines.append(f" Description: {kg_desc}") if kg_link: lines.append(f" URL: {kg_link}") lines.append("") organic = data.get("organic") or [] if organic: lines.append("Organic Results:") for idx, item in enumerate(organic[:max_results], 1): title = item.get("title") or "" snippet = item.get("snippet") or "" link = item.get("link") or "" source = item.get("source") or item.get("domain") or "" lines.append(f"Result {idx}:") if title: lines.append(f" Title: {title}") if snippet: lines.append(f" Snippet: {snippet}") if source: lines.append(f" Source: {source}") if link: lines.append(f" URL: {link}") lines.append("") people_also_ask = data.get("peopleAlsoAsk") or [] if people_also_ask: lines.append("People Also Ask:") for idx, item in enumerate(people_also_ask[: min(3, max_results)], 1): question = item.get("question") or "" snippet = item.get("snippet") or "" link = item.get("link") or "" lines.append(f"{idx}. {question}") if snippet: lines.append(f" {snippet}") if link: lines.append(f" URL: {link}") lines.append("") formatted = "\n".join(lines).strip() return formatted or "No relevant results found." def _build_serper_text_summary_prompt(query: str, formatted_results: str) -> str: """构造企业 Serper 文搜结果 summarizer prompt。""" return ( "You are a helpful assistant. Your task is to summarize the following Google " "Serper text search results in no more than five sentences.\n\n" "Focus on facts that help answer the user's query. Preserve important names, " "dates, entities, and distinctions when they appear in the search results. " "Use only the provided search result titles, snippets, quick answers, and " "metadata. Do not invent facts and do not assume content from linked pages.\n\n" "If the results are ambiguous, conflicting, or insufficient, clearly state " "that uncertainty.\n\n" f"Query: {query}\n\n" f"Search Results:\n{formatted_results}" ) async def summarize_serper_text_results( query: str, formatted_results: str, session: aiohttp.ClientSession, summarizer_address: str = "", summarizer_model: str = "", max_tokens: int = 1024, ) -> Optional[str]: """用本地 summarizer 汇总 Serper 文搜元数据,不抓取网页正文。""" summarizer_addr = summarizer_address or MARS_SUMMARIZER_ADDRESS sum_model = summarizer_model or MARS_SUMMARIZER_MODEL if not formatted_results or not summarizer_addr or not sum_model: return None payload = { "model": sum_model, "messages": [{"role": "user", "content": _build_serper_text_summary_prompt(query, formatted_results)}], "max_tokens": max_tokens, "temperature": 0.3, "chat_template_kwargs": {"enable_thinking": False}, } try: async with session.post( f"http://{summarizer_addr.rstrip('/')}/v1/chat/completions", json=payload, headers={"Content-Type": "application/json"}, timeout=aiohttp.ClientTimeout(total=120), ) as resp: if resp.status != 200: print(f" [SERPER_GATEWAY] Summarizer returned HTTP {resp.status}, falling back to raw results", flush=True) return None data = await resp.json() choices = data.get("choices", []) if choices and isinstance(choices, list): content = choices[0].get("message", {}).get("content", "") summary = _clean_think_blocks(content).strip() return summary or None except asyncio.TimeoutError: print(" [SERPER_GATEWAY] Summarizer timeout, falling back to raw results", flush=True) except Exception as e: print(f" [SERPER_GATEWAY] Summarizer error: {e}, falling back to raw results", flush=True) return None async def call_serper_gateway_search( query: str, session: aiohttp.ClientSession, kwargs: dict, max_results: int = 8, timeout_seconds: int = 60, search_cache: Optional["SearchCache"] = None, ) -> str: """通过企业内部网关调用 Serper 文本搜索,后续复用 text_search_tool 的网页抓取与总结流程。""" clean_query = query.strip().replace("\n", " ") if not clean_query: return "Error: Serper gateway search requires a non-empty query." max_results = max(1, min(int(max_results or 8), 20)) summarizer_base_url = kwargs.get("summarizer_base_url", "") or ( f"http://{MARS_SUMMARIZER_ADDRESS.rstrip('/')}" if MARS_SUMMARIZER_ADDRESS else "" ) summarizer_model = kwargs.get("summarizer_model", "") or MARS_SUMMARIZER_MODEL top_k = 3 content_limit = 30000 if not summarizer_base_url or not summarizer_model: return "Error: Summarizer not configured for Serper gateway search." cache_model = ( f"serper_gateway_fetch:{summarizer_model}:max_results={max_results}:" f"top_k={top_k}:content_limit={content_limit}" ) if search_cache: cached = await search_cache.get(clean_query, top_k, cache_model) if cached: return f"Found cached summary for query: {clean_query}\n{cached}" gateway_url, gateway_username, gateway_userid, gateway_token = _get_gateway_text_search_config(kwargs) if not gateway_url: return "Error: GATEWAY_URL/SERPER_GATEWAY_URL not configured." if not gateway_token: return "Error: GATEWAY_TOKEN/SERPER_GATEWAY_TOKEN not configured." headers = { "Content-Type": "application/json", "User-Agent": "ifbook-http-client", } serper_params = { "q": clean_query, "type": "search", "num": max_results, } gateway_payload = { "sec_info": { "username": gateway_username, "userid": gateway_userid, "token": gateway_token, }, "model_type": "openai", "model_name": "serper", "params": json.dumps(serper_params), } max_attempts = 2 last_error = "" for attempt in range(max_attempts): try: async with session.post( gateway_url, headers=headers, json=gateway_payload, timeout=aiohttp.ClientTimeout(total=timeout_seconds), ) as resp: body = await resp.text() if resp.status != 200: last_error = f"Error: Serper gateway returned HTTP {resp.status}: {body[:500]}" if attempt < max_attempts - 1: print( f" [SERPER_GATEWAY] HTTP {resp.status}, retrying ({attempt + 1}/{max_attempts})...", flush=True, ) await asyncio.sleep(3 * (attempt + 1)) continue return last_error data = _parse_gateway_model_output(json.loads(body)) organic_results = data.get("organic", []) if not organic_results: return f"No search results found for query: {clean_query}" result = await summarize_serper_organic_results( query=clean_query, organic_results=organic_results, summarizer_base_url=summarizer_base_url, summarizer_model=summarizer_model, top_k=top_k, content_limit=content_limit, ) if search_cache and not result.startswith("Error:"): await search_cache.set(clean_query, top_k, cache_model, result) return f"Final summary generated for query: {clean_query}\n{result}" except asyncio.TimeoutError: last_error = f"Error: Serper gateway search timeout for query: {clean_query[:100]}" except Exception as e: last_error = f"Error: Serper gateway search failed: {e}" if attempt < max_attempts - 1: print(f" [SERPER_GATEWAY] Error, retrying ({attempt + 1}/{max_attempts}): {last_error}", flush=True) await asyncio.sleep(3 * (attempt + 1)) return last_error or "Error: Serper gateway search failed." def resolve_web_search_backend(raw_backend: str, tavily_api_keys: list[str]) -> str: """解析 web_search 后端;auto 在有 Tavily key 时优先 Tavily,否则用企业 Serper 网关。""" backend = _normalize_web_search_backend(raw_backend) if backend == "auto": return "tavily" if tavily_api_keys else "serper_gateway" return backend async def call_configured_web_search(query: str, kwargs: dict) -> tuple[str, str]: """按配置调用 VideoDR/web_search 后端,返回 `(结果文本, 后端名)`。""" session = kwargs.get("shared_session") raw_backend = kwargs.get("web_search_backend", _default_web_search_backend()) tavily_api_key = kwargs.get("tavily_api_key", "") tavily_key_pool = kwargs.get("tavily_api_key_pool") backend = resolve_web_search_backend( raw_backend, tavily_key_pool.keys if tavily_key_pool else ([tavily_api_key] if tavily_api_key else []), ) if backend == "tavily": result = await call_tavily_search( query=query, session=session, api_key=tavily_api_key, api_key_pool=tavily_key_pool, search_depth=kwargs.get("tavily_search_depth", "advanced"), max_results=kwargs.get("tavily_max_results", 8), include_answer=kwargs.get("tavily_include_answer", "advanced"), include_raw_content=kwargs.get("tavily_include_raw_content", "false"), topic=kwargs.get("tavily_topic", "general"), auto_parameters=kwargs.get("tavily_auto_parameters", False), include_domains=kwargs.get("tavily_include_domains") or None, exclude_domains=kwargs.get("tavily_exclude_domains") or None, timeout_seconds=kwargs.get("tavily_timeout", 60), search_cache=kwargs.get("search_cache"), ) return result, backend if backend == "serper_gateway": result = await call_serper_gateway_search( query=query, session=session, kwargs=kwargs, max_results=kwargs.get("serper_gateway_max_results", kwargs.get("tavily_max_results", 8)), timeout_seconds=kwargs.get("serper_gateway_timeout", 60), search_cache=kwargs.get("search_cache"), ) return result, backend if backend != "mars": return ( f"Error: Unknown WEB_SEARCH_BACKEND '{raw_backend}'. " "Supported values: serper_gateway, mars, tavily, auto.", backend, ) return await vdr_mars_web_search(query, session), backend # ============================================================================= # Evaluation Functions # ============================================================================= async def evaluate_direct( client: str, model: str, base_url: str, api_key: str, question: str, image_path: str, system_prompt: str = "", format_instruction: str = "", **kwargs, ) -> dict: """Direct mode - single turn, no tools.""" # Process image img = load_and_process_image(image_path, kwargs.get("min_pixels", 65536), kwargs.get("max_pixels", 8294400), kwargs.get("factor", 32), kwargs.get("qwen_vl_processing", True)) img_b64, mime = image_to_base64(img) # Build prompt with format instruction from dataset config prompt = question if format_instruction: prompt = f"{question}\n{format_instruction}" messages = [] if system_prompt: messages.append({"role": "system", "content": system_prompt}) messages.append({ "role": "user", "content": [ {"type": "image_url", "image_url": {"url": f"data:{mime};base64,{img_b64}"}}, {"type": "text", "text": prompt}, ] }) if client == "gemini": result = await call_gemini_api(messages, model, base_url, api_key, **kwargs) elif client == "azure": result = await call_azure_api(messages, model, base_url, api_key, **kwargs) elif client == "vertex": result = await call_vertex_gemini_api(messages, model, base_url, api_key, **kwargs) elif client == "gateway": result = await call_gateway_api(messages, model, base_url, api_key, **kwargs) else: api_kwargs = dict(kwargs) api_kwargs["request_debug_context"] = { "sample_id": kwargs.get("sample_id", ""), "turn": 0, "task_kind": "image_direct", } result = await call_openai_api(messages, model, base_url, api_key, **api_kwargs) assistant_text = result.get("full_text") or result["content"] conversation_messages = json.loads(json.dumps(messages)) if assistant_text: conversation_messages.append({"role": "assistant", "content": assistant_text}) return { "output": assistant_text, "input_messages": messages, "conversation_messages": conversation_messages, "finish_reason": result["finish_reason"], "num_round": 1, "tool_calls": result.get("tool_calls", []), "error": result.get("error"), "saved_images": [], # Direct mode doesn't save images for HTML } async def evaluate_video_direct( client: str, model: str, base_url: str, api_key: str, question: str, video_path: str, system_prompt: str = GENERAL_VIDEO_DIRECT_SYSTEM_PROMPT, sample_id: str = "", images_dir: str = "", extra_image_paths: Optional[list[str]] = None, **kwargs, ) -> dict: """General video direct mode: sampled frames in, one model answer out.""" saved_images = [] image_counter = 0 frame_cache_dir = kwargs.get("frame_cache_dir", "") sample_frame_dir = os.path.join(frame_cache_dir or images_dir or ".", sample_id) os.makedirs(sample_frame_dir, exist_ok=True) def save_image_for_html(img: Image.Image, img_type: str) -> str: nonlocal image_counter image_counter += 1 marker = f"[IMAGE {image_counter}]" if images_dir and sample_id: filename = f"{sample_id}_{img_type}_{image_counter}.jpg" filepath = os.path.join(images_dir, filename) if img.mode != "RGB": img = img.convert("RGB") img.save(filepath, "JPEG", quality=85) saved_images.append({"marker": marker, "path": filename, "type": img_type}) return marker def build_image_part(img_path: str, img_type: str) -> tuple[dict, str]: img = load_and_process_image( img_path, kwargs.get("min_pixels", 65536), kwargs.get("max_pixels", 8294400), kwargs.get("factor", 32), kwargs.get("qwen_vl_processing", True), ) img_b64, img_mime = image_to_base64(img) marker = save_image_for_html(img, img_type) return {"type": "image_url", "image_url": {"url": f"data:{img_mime};base64,{img_b64}"}}, marker all_frames = extract_video_frames_1fps( video_path, sample_frame_dir, kwargs.get("video_max_resolution", DEFAULT_VIDEO_MAX_RESOLUTION), kwargs.get("video_jpeg_quality", DEFAULT_VIDEO_JPEG_QUALITY), ) initial_indices = vdr_uniform_sample_indices( len(all_frames), kwargs.get("video_initial_frames", DEFAULT_VIDEO_INITIAL_FRAMES), ) user_content = [] for idx in initial_indices: image_part, _ = build_image_part(all_frames[idx], "input_frame") user_content.append(image_part) for image_idx, extra_image_path in enumerate(extra_image_paths or [], start=1): image_part, _ = build_image_part(extra_image_path, f"supplemental_image_{image_idx}") user_content.append({ "type": "text", "text": ( f"\nSupplemental image {image_idx} referenced by the question " f"as :\n" ), }) user_content.append(image_part) user_content.append({"type": "text", "text": question}) messages = [] if system_prompt: messages.append({"role": "system", "content": system_prompt}) messages.append({"role": "user", "content": user_content}) if client == "gemini": result = await call_gemini_api(messages, model, base_url, api_key, **kwargs) elif client == "azure": result = await call_azure_api(messages, model, base_url, api_key, **kwargs) elif client == "vertex": result = await call_vertex_gemini_api(messages, model, base_url, api_key, **kwargs) elif client == "gateway": result = await call_gateway_api(messages, model, base_url, api_key, **kwargs) else: api_kwargs = dict(kwargs) api_kwargs["request_debug_context"] = { "sample_id": sample_id, "turn": 0, "task_kind": "video_direct", } result = await call_openai_api(messages, model, base_url, api_key, **api_kwargs) assistant_text = result.get("full_text") or result["content"] conversation_messages = json.loads(json.dumps(messages)) if assistant_text: conversation_messages.append({"role": "assistant", "content": assistant_text}) return { "output": assistant_text, "input_messages": messages, "conversation_messages": conversation_messages, "finish_reason": "error" if result.get("error") else result["finish_reason"], "num_round": 1, "tool_calls": [], "error": result.get("error"), "error_status_code": result.get("error_status_code"), "error_raw_response": result.get("error_raw_response", ""), "request_debug": result.get("request_debug"), "saved_images": saved_images, } async def evaluate_tool( client: str, model: str, base_url: str, api_key: str, question: str, image_path: str, tool_system_prompt: str, max_turns: int = 10, serper_api_key: str = "", image_search_data: Optional[dict] = None, sample_id: str = "", images_dir: str = "", **kwargs, ) -> dict: """图像 benchmark 的工具模式,兼容旧工具名与 VideoDR 新工具名。""" saved_images = [] image_counter = 0 bbox_config = kwargs.get("bbox_config") or vdr_get_bbox_config(model) tool_work_dir = os.path.join(kwargs.get("frame_cache_dir", images_dir or "."), f"{sample_id}_tool") os.makedirs(tool_work_dir, exist_ok=True) image_search_cache = kwargs.get("image_search_cache") def save_image_for_html(img: Image.Image, img_type: str) -> tuple[str, str]: nonlocal image_counter image_counter += 1 marker = f"[IMAGE {image_counter}]" if images_dir and sample_id: filename = f"{sample_id}_{img_type}_{image_counter}.jpg" filepath = os.path.join(images_dir, filename) if img.mode != "RGB": img = img.convert("RGB") img.save(filepath, "JPEG", quality=85) saved_images.append({"marker": marker, "path": filename, "type": img_type}) return marker, filename return marker, "" def build_multimodal_tool_message(prefix: str, img: Image.Image, img_type: str) -> tuple[str, list[dict]]: marker, _ = save_image_for_html(img, img_type) img_b64, img_mime = image_to_base64(img) tool_response = f"\n{prefix}{marker}\n" content = [ {"type": "text", "text": f"\n{prefix}"}, {"type": "image_url", "image_url": {"url": f"data:{img_mime};base64,{img_b64}"}}, {"type": "text", "text": "\n"}, ] return tool_response, content original_image = Image.open(image_path) original_image.load() if original_image.mode != "RGB": original_image = original_image.convert("RGB") processed = process_image( original_image, kwargs.get("min_pixels", 65536), kwargs.get("max_pixels", 8294400), kwargs.get("factor", 32), kwargs.get("qwen_vl_processing", True), ) img_b64, mime = image_to_base64(processed) save_image_for_html(processed, "input") messages = [ {"role": "system", "content": tool_system_prompt}, { "role": "user", "content": [ {"type": "image_url", "image_url": {"url": f"data:{mime};base64,{img_b64}"}}, {"type": "text", "text": question}, ], }, ] initial_messages = json.loads(json.dumps(messages)) tool_calls = [] output_parts = [] frame_label = kwargs.get("frame_label", 0) format_retry_limit = int(kwargs.get("format_retry_limit", 2)) format_retries = 0 for turn in range(max_turns + format_retry_limit): if client == "gemini": result = await call_gemini_api(messages, model, base_url, api_key, **kwargs) elif client == "azure": result = await call_azure_api(messages, model, base_url, api_key, **kwargs) elif client == "vertex": result = await call_vertex_gemini_api(messages, model, base_url, api_key, **kwargs) elif client == "gateway": result = await call_gateway_api(messages, model, base_url, api_key, **kwargs) else: api_kwargs = dict(kwargs) api_kwargs["request_debug_context"] = { "sample_id": sample_id, "turn": turn, "task_kind": "image_tool", } result = await call_openai_api(messages, model, base_url, api_key, **api_kwargs) output = result.get("full_text") or result["content"] if result.get("error"): if output: messages.append({"role": "assistant", "content": output}) output_parts.append(output + "<|im_end|>") return { "output": "".join(output_parts), "input_messages": initial_messages, "conversation_messages": json.loads(json.dumps(messages)), "finish_reason": "error", "num_round": turn + 1, "tool_calls": tool_calls, "error": result["error"], "saved_images": saved_images, } tool_call = extract_tool_call_from_result(result, output, user_query=question) if not tool_call: recovered_answer = ( _recover_answer_from_no_tool_output(result, output) if kwargs.get("recover_no_tool_answer", True) else None ) if recovered_answer: if recovered_answer != output: output = recovered_answer messages.append({"role": "assistant", "content": output}) output_parts.append(output + "<|im_end|>") return { "output": "".join(output_parts), "input_messages": initial_messages, "conversation_messages": json.loads(json.dumps(messages)), "finish_reason": "answer", "num_round": turn + 1, "tool_calls": tool_calls, "error": None, "saved_images": saved_images, "format_retries": format_retries, } if format_retries < format_retry_limit: assistant_message, user_message, output_snippet = _make_format_repair_turn(output) if assistant_message: messages.append(assistant_message) messages.append(user_message) output_parts.append(output_snippet) format_retries += 1 continue if output: messages.append({"role": "assistant", "content": output}) output_parts.append(output + "<|im_end|>") return { "output": "".join(output_parts), "input_messages": initial_messages, "conversation_messages": json.loads(json.dumps(messages)), "finish_reason": "no_tool_calls", "num_round": turn + 1, "tool_calls": tool_calls, "error": None, "saved_images": saved_images, "format_retries": format_retries, } tool_name, args = extract_tool_name_and_args(tool_call) assistant_tool_output = sanitize_assistant_tool_turn(output) allowed_tool_names = kwargs.get("allowed_tool_names") canonical_tool_name = canonicalize_tool_name(tool_name) if allowed_tool_names is not None and canonical_tool_name not in allowed_tool_names: format_retries = 0 messages.append({"role": "assistant", "content": assistant_tool_output}) tool_response = build_disallowed_tool_response(tool_name, allowed_tool_names) tool_calls.append({ "name": canonical_tool_name, "raw_name": tool_name, "blocked": True, "reason": "tool_ablation_profile", }) messages.append({"role": "user", "content": tool_response}) output_parts.append( f"{assistant_tool_output}<|im_end|><|im_start|>user\n{tool_response}<|im_end|>\n<|im_start|>assistant\n" ) continue format_retries = 0 messages.append({"role": "assistant", "content": assistant_tool_output}) tool_response = "" if tool_name in {"image_zoom_in_tool", "zoom_in"}: raw_bbox = args.get("bbox_2d") or args.get("bbox") if raw_bbox and len(raw_bbox) == 4: bbox = vdr_normalize_bbox(raw_bbox, bbox_config) crop_path = os.path.join(tool_work_dir, f"zoom_turn{turn:03d}.jpg") crop_frame_to_rgb_jpeg(image_path, bbox, crop_path) cropped = load_and_process_image( crop_path, kwargs.get("min_pixels", 65536), kwargs.get("max_pixels", 8294400), kwargs.get("factor", 32), kwargs.get("qwen_vl_processing", True), ) tool_calls.append({"name": "zoom_in", "bbox": bbox}) tool_response, content = build_multimodal_tool_message( ( f"Here is the zoomed-in region " f"[x1={bbox[0]:.3f}, y1={bbox[1]:.3f}, " f"x2={bbox[2]:.3f}, y2={bbox[3]:.3f}] of Frame {frame_label}:\n" ), cropped, "zoom", ) messages.append({"role": "user", "content": content}) else: tool_response = "\nError: bbox must have exactly 4 values.\n" messages.append({"role": "user", "content": tool_response}) elif tool_name == "text_search_tool": query = args.get("query", "") summarizer_base_url = kwargs.get("summarizer_base_url", "") summarizer_model = kwargs.get("summarizer_model", "") if query and serper_api_key and summarizer_base_url and summarizer_model: search_result = await call_text_search( query=query, serper_api_key=serper_api_key, summarizer_base_url=summarizer_base_url, summarizer_model=summarizer_model, serper_semaphore=kwargs.get("serper_semaphore"), serper_concurrency=kwargs.get("serper_concurrency", 5), search_cache=kwargs.get("search_cache"), ) tool_calls.append({"name": tool_name, "query": query}) tool_response = f"\n{search_result}\n" else: tool_response = "\nError: Search not available.\n" messages.append({"role": "user", "content": tool_response}) elif tool_name == "web_search": query = args.get("query", "") if query: search_result, search_backend = await call_configured_web_search(query, kwargs) tool_calls.append({"name": "web_search", "query": query, "backend": search_backend}) if search_result.startswith(f'Web Search Results for "{query}"'): tool_response = f"\n{search_result}\n" else: tool_response = f'\nWeb Search Results for "{query}":\n\n{search_result}\n' else: tool_response = "\nError: web_search requires a non-empty query.\n" messages.append({"role": "user", "content": tool_response}) elif tool_name in {"image_search_tool", "image_search"}: raw_bbox = args.get("bbox", [0, 0, 1000, 1000]) tool_calls.append({"name": "image_search", "bbox": raw_bbox}) if image_search_data: title_list = image_search_data.get("image_search_title_list", []) thumbnail_list = image_search_data.get("image_search_thumbnail_list", []) data_root = kwargs.get("data_root", "") if title_list and thumbnail_list: content_parts = [{"type": "text", "text": "\nReverse Image Search Results:"}] thumb_markers = [] for i, title in enumerate(title_list): content_parts.append({"type": "text", "text": f"\n\nTitle {i+1}: {title}\nThumbnail {i+1}: "}) if i < len(thumbnail_list): thumb_path = thumbnail_list[i] if data_root and not os.path.isabs(thumb_path): thumb_path = os.path.join(data_root, thumb_path) if os.path.exists(thumb_path): thumb_img = load_and_process_image( thumb_path, kwargs.get("min_pixels", 65536), kwargs.get("max_pixels", 8294400), kwargs.get("factor", 32), kwargs.get("qwen_vl_processing", True), ) marker, _ = save_image_for_html(thumb_img, "thumbnail") thumb_markers.append(marker) thumb_b64, thumb_mime = image_to_base64(thumb_img) content_parts.append({"type": "image_url", "image_url": {"url": f"data:{thumb_mime};base64,{thumb_b64}"}}) content_parts.append({"type": "text", "text": "\n"}) messages.append({"role": "user", "content": content_parts}) tool_response = "\nReverse Image Search Results:" for i, title in enumerate(title_list): tool_response += f"\n\nTitle {i+1}: {title}\nThumbnail {i+1}: " if i < len(thumb_markers): tool_response += thumb_markers[i] tool_response += "\n" else: tool_response = "\nNo matching images were found.\n" messages.append({"role": "user", "content": tool_response}) else: if not raw_bbox or len(raw_bbox) != 4: tool_response = "\nError: bbox must have exactly 4 values.\n" messages.append({"role": "user", "content": tool_response}) else: bbox = vdr_normalize_bbox(raw_bbox, bbox_config) padded_bbox = vdr_add_search_padding(bbox, image_path, padding=(0.2, 0.2), padding_cap_px=600) crop_path = os.path.join(tool_work_dir, f"image_search_turn{turn:03d}.jpg") crop_frame_to_rgb_jpeg(image_path, padded_bbox, crop_path) with open(crop_path, "rb") as f: crop_bytes = f.read() cached = image_search_cache.get(crop_bytes) if image_search_cache else None if cached: result_text = cached else: crop_b64 = base64.b64encode(crop_bytes).decode("ascii") result_text = await vdr_real_image_search(crop_b64, kwargs.get("shared_session"), serper_api_key, crop_path=crop_path) if image_search_cache and not any(pat in result_text for pat in ("Image search error:", "Error:", "No results found from reverse image search")): image_search_cache.set(crop_bytes, result_text) tool_response = ( f"\nReverse Image Search Results " f"(region [{bbox[0]:.3f}, {bbox[1]:.3f}, {bbox[2]:.3f}, {bbox[3]:.3f}] " f"of Frame {frame_label}):\n\n{result_text}\n" ) messages.append({"role": "user", "content": tool_response}) else: tool_response = f"\nError: Unknown tool '{tool_name}'.\n" messages.append({"role": "user", "content": tool_response}) output_parts.append(f"{assistant_tool_output}<|im_end|><|im_start|>user\n{tool_response}<|im_end|>\n<|im_start|>assistant\n") return { "output": "".join(output_parts), "input_messages": initial_messages, "conversation_messages": json.loads(json.dumps(messages)), "finish_reason": "max_turns", "num_round": max_turns, "tool_calls": tool_calls, "error": None, "saved_images": saved_images, "format_retries": format_retries, } async def evaluate_video_tool( client: str, model: str, base_url: str, api_key: str, question: str, video_path: str, tool_system_prompt: str = VIDEO_DR_SYSTEM_PROMPT, max_turns: int = 10, serper_api_key: str = "", sample_id: str = "", images_dir: str = "", **kwargs, ) -> dict: """VideoDR 工具模式评测:64 帧初始化 + 五工具状态机。""" saved_images = [] image_counter = 0 bbox_config = kwargs.get("bbox_config") or vdr_get_bbox_config(model) frame_cache_dir = kwargs.get("frame_cache_dir", "") sample_frame_dir = os.path.join(frame_cache_dir or images_dir or ".", sample_id) tool_work_dir = os.path.join(sample_frame_dir, "_tool") os.makedirs(tool_work_dir, exist_ok=True) image_search_cache = kwargs.get("image_search_cache") def save_image_for_html(img: Image.Image, img_type: str) -> str: nonlocal image_counter image_counter += 1 marker = f"[IMAGE {image_counter}]" if images_dir and sample_id: filename = f"{sample_id}_{img_type}_{image_counter}.jpg" filepath = os.path.join(images_dir, filename) if img.mode != "RGB": img = img.convert("RGB") img.save(filepath, "JPEG", quality=85) saved_images.append({"marker": marker, "path": filename, "type": img_type}) return marker def build_image_part(img_path: str, img_type: str) -> tuple[dict, str]: img = load_and_process_image( img_path, kwargs.get("min_pixels", 65536), kwargs.get("max_pixels", 8294400), kwargs.get("factor", 32), kwargs.get("qwen_vl_processing", True), ) img_b64, img_mime = image_to_base64(img) marker = save_image_for_html(img, img_type) return {"type": "image_url", "image_url": {"url": f"data:{img_mime};base64,{img_b64}"}}, marker all_frames = extract_video_frames_1fps( video_path, sample_frame_dir, kwargs.get("video_max_resolution", DEFAULT_VIDEO_MAX_RESOLUTION), kwargs.get("video_jpeg_quality", DEFAULT_VIDEO_JPEG_QUALITY), ) initial_indices = vdr_uniform_sample_indices(len(all_frames), kwargs.get("video_initial_frames", DEFAULT_VIDEO_INITIAL_FRAMES)) user_content = [] for idx in initial_indices: image_part, _ = build_image_part(all_frames[idx], "input_frame") user_content.append(image_part) user_content.append({"type": "text", "text": question}) messages = [ {"role": "system", "content": tool_system_prompt}, {"role": "user", "content": user_content}, ] initial_messages = json.loads(json.dumps(messages)) locked_frame_idx = None locked_frame_path = "" tool_calls = [] output_parts = [] format_retry_limit = int(kwargs.get("format_retry_limit", 2)) format_retries = 0 force_final_answer_turn = bool(kwargs.get("force_final_answer_turn", True)) final_answer_retry_limit = int(kwargs.get("final_answer_retry_limit", 1)) final_answer_retries = 0 forced_final_answer = False for turn in range(max_turns + format_retry_limit): force_final_this_turn = ( force_final_answer_turn and not forced_final_answer and bool(tool_calls) and turn >= max_turns - 1 ) if force_final_this_turn: forced_final_answer = True final_result = await call_final_only_answer( client=client, model=model, base_url=base_url, api_key=api_key, question=question, output_parts=output_parts, tool_calls=tool_calls, sample_id=sample_id, retry_limit=final_answer_retry_limit, **kwargs, ) output_parts.append(final_result["transcript"]) final_api_result = final_result.get("result") or {} return { "output": "".join(output_parts), "input_messages": initial_messages, "conversation_messages": json.loads(json.dumps(messages + final_result.get("messages", []))), "finish_reason": final_result["finish_reason"], "num_round": turn + 1, "tool_calls": tool_calls, "error": final_api_result.get("error"), "error_status_code": final_api_result.get("error_status_code"), "error_raw_response": final_api_result.get("error_raw_response", ""), "request_debug": final_api_result.get("request_debug"), "saved_images": saved_images, "locked_frame_path": locked_frame_path, "format_retries": format_retries, "forced_final_answer": forced_final_answer, "final_answer_retries": final_result.get("retries", 0), "final_only_answer": True, } if client == "gemini": result = await call_gemini_api(messages, model, base_url, api_key, **kwargs) elif client == "azure": result = await call_azure_api(messages, model, base_url, api_key, **kwargs) elif client == "vertex": result = await call_vertex_gemini_api(messages, model, base_url, api_key, **kwargs) elif client == "gateway": result = await call_gateway_api(messages, model, base_url, api_key, **kwargs) else: api_kwargs = dict(kwargs) api_kwargs["request_debug_context"] = { "sample_id": sample_id, "turn": turn, "task_kind": "video_tool", } result = await call_openai_api(messages, model, base_url, api_key, **api_kwargs) output = result.get("full_text") or result["content"] if result.get("error"): if output: messages.append({"role": "assistant", "content": output}) output_parts.append(output + "<|im_end|>") return { "output": "".join(output_parts), "input_messages": initial_messages, "conversation_messages": json.loads(json.dumps(messages)), "finish_reason": "error", "num_round": turn + 1, "tool_calls": tool_calls, "error": result["error"], "error_status_code": result.get("error_status_code"), "error_raw_response": result.get("error_raw_response", ""), "request_debug": result.get("request_debug"), "saved_images": saved_images, "locked_frame_path": locked_frame_path, } tool_call = extract_tool_call_from_result(result, output, user_query=question) if not tool_call: recovered_answer = ( _recover_answer_from_no_tool_output(result, output) if kwargs.get("recover_no_tool_answer", True) else None ) if recovered_answer: if recovered_answer != output: output = recovered_answer messages.append({"role": "assistant", "content": output}) output_parts.append(output + "<|im_end|>") return { "output": "".join(output_parts), "input_messages": initial_messages, "conversation_messages": json.loads(json.dumps(messages)), "finish_reason": "answer", "num_round": turn + 1, "tool_calls": tool_calls, "error": None, "error_status_code": result.get("error_status_code"), "error_raw_response": result.get("error_raw_response", ""), "request_debug": result.get("request_debug"), "saved_images": saved_images, "locked_frame_path": locked_frame_path, "format_retries": format_retries, "forced_final_answer": forced_final_answer, "final_answer_retries": final_answer_retries, } if forced_final_answer and final_answer_retries < final_answer_retry_limit: assistant_message = {"role": "assistant", "content": output} if output else None if assistant_message: messages.append(assistant_message) messages.append({"role": "user", "content": FINAL_ANSWER_REPAIR_PROMPT}) output_parts.append( f"{output}<|im_end|><|im_start|>user\n" f"{FINAL_ANSWER_REPAIR_PROMPT}<|im_end|>\n<|im_start|>assistant\n" ) final_answer_retries += 1 continue if format_retries < format_retry_limit: assistant_message, user_message, output_snippet = _make_format_repair_turn(output) if assistant_message: messages.append(assistant_message) messages.append(user_message) output_parts.append(output_snippet) format_retries += 1 continue if output: messages.append({"role": "assistant", "content": output}) output_parts.append(output + "<|im_end|>") return { "output": "".join(output_parts), "input_messages": initial_messages, "conversation_messages": json.loads(json.dumps(messages)), "finish_reason": "no_tool_calls", "num_round": turn + 1, "tool_calls": tool_calls, "error": None, "error_status_code": result.get("error_status_code"), "error_raw_response": result.get("error_raw_response", ""), "request_debug": result.get("request_debug"), "saved_images": saved_images, "locked_frame_path": locked_frame_path, "format_retries": format_retries, "forced_final_answer": forced_final_answer, "final_answer_retries": final_answer_retries, } tool_name, args = extract_tool_name_and_args(tool_call) assistant_tool_output = sanitize_assistant_tool_turn(output) allowed_tool_names = kwargs.get("allowed_tool_names") canonical_tool_name = canonicalize_tool_name(tool_name) if allowed_tool_names is not None and canonical_tool_name not in allowed_tool_names: if forced_final_answer: if final_answer_retries < final_answer_retry_limit: messages.append({"role": "assistant", "content": assistant_tool_output}) messages.append({"role": "user", "content": FINAL_ANSWER_REPAIR_PROMPT}) output_parts.append( f"{assistant_tool_output}<|im_end|><|im_start|>user\n" f"{FINAL_ANSWER_REPAIR_PROMPT}<|im_end|>\n<|im_start|>assistant\n" ) final_answer_retries += 1 continue messages.append({"role": "assistant", "content": assistant_tool_output}) output_parts.append(assistant_tool_output + "<|im_end|>") return { "output": "".join(output_parts), "input_messages": initial_messages, "conversation_messages": json.loads(json.dumps(messages)), "finish_reason": "max_turns", "num_round": turn + 1, "tool_calls": tool_calls, "error": None, "error_status_code": result.get("error_status_code"), "error_raw_response": result.get("error_raw_response", ""), "request_debug": result.get("request_debug"), "saved_images": saved_images, "locked_frame_path": locked_frame_path, "format_retries": format_retries, "forced_final_answer": forced_final_answer, "final_answer_retries": final_answer_retries, } format_retries = 0 messages.append({"role": "assistant", "content": assistant_tool_output}) tool_response = build_disallowed_tool_response(tool_name, allowed_tool_names) tool_calls.append({ "name": canonical_tool_name, "raw_name": tool_name, "blocked": True, "reason": "tool_ablation_profile", }) messages.append({"role": "user", "content": tool_response}) output_parts.append( f"{assistant_tool_output}<|im_end|><|im_start|>user\n{tool_response}<|im_end|>\n<|im_start|>assistant\n" ) continue if forced_final_answer: if final_answer_retries < final_answer_retry_limit: messages.append({"role": "assistant", "content": assistant_tool_output}) messages.append({"role": "user", "content": FINAL_ANSWER_REPAIR_PROMPT}) output_parts.append( f"{assistant_tool_output}<|im_end|><|im_start|>user\n" f"{FINAL_ANSWER_REPAIR_PROMPT}<|im_end|>\n<|im_start|>assistant\n" ) final_answer_retries += 1 continue messages.append({"role": "assistant", "content": assistant_tool_output}) output_parts.append(assistant_tool_output + "<|im_end|>") return { "output": "".join(output_parts), "input_messages": initial_messages, "conversation_messages": json.loads(json.dumps(messages)), "finish_reason": "max_turns", "num_round": turn + 1, "tool_calls": tool_calls, "error": None, "error_status_code": result.get("error_status_code"), "error_raw_response": result.get("error_raw_response", ""), "request_debug": result.get("request_debug"), "saved_images": saved_images, "locked_frame_path": locked_frame_path, "format_retries": format_retries, "forced_final_answer": forced_final_answer, "final_answer_retries": final_answer_retries, } format_retries = 0 messages.append({"role": "assistant", "content": assistant_tool_output}) tool_response = "" if tool_name == "choose_frames": start = int(args.get("start_frame_index", 0)) end = int(args.get("end_frame_index", start)) start = max(0, min(start, len(all_frames) - 1)) end = max(0, min(end, len(all_frames) - 1)) if start > end: start, end = end, start interval_frames = vdr_sample_interval( all_frames, start, end, kwargs.get("video_interval_samples", DEFAULT_VIDEO_INTERVAL_SAMPLES), ) if not interval_frames: interval_frames = [vdr_get_frame(all_frames, (start + end) // 2)] content_parts = [{ "type": "text", "text": f"\nHere are {len(interval_frames)} uniformly sampled frames from the interval [Frame {start} to Frame {end}]:\n\n", }] markers = [] for idx, frame_path in interval_frames: image_part, marker = build_image_part(frame_path, "interval_frame") markers.append(marker) content_parts.append(image_part) content_parts.append({"type": "text", "text": "\n"}) content_parts.append({"type": "text", "text": ""}) messages.append({"role": "user", "content": content_parts}) tool_calls.append({"name": "choose_frames", "start_frame_index": start, "end_frame_index": end}) tool_response = ( f"\nHere are {len(interval_frames)} uniformly sampled frames from the interval " f"[Frame {start} to Frame {end}]:\n\n" + "\n".join(markers) + "\n" ) elif tool_name == "find_frame": requested_idx = int(args.get("frame_index", 0)) locked_frame_idx, locked_frame_path = vdr_get_frame(all_frames, requested_idx) image_part, marker = build_image_part(locked_frame_path, "locked_frame") messages.append({ "role": "user", "content": [ {"type": "text", "text": f"\nHere is Frame {locked_frame_idx}:\n"}, image_part, {"type": "text", "text": "\n"}, ], }) tool_calls.append({"name": "find_frame", "frame_index": requested_idx, "actual_frame_index": locked_frame_idx}) tool_response = f"\nHere is Frame {locked_frame_idx}:\n{marker}\n" elif tool_name == "zoom_in": if not locked_frame_path: tool_response = "\nError: zoom_in can only be used after find_frame.\n" messages.append({"role": "user", "content": tool_response}) else: raw_bbox = args.get("bbox") if not raw_bbox or len(raw_bbox) != 4: tool_response = "\nError: bbox must have exactly 4 values.\n" messages.append({"role": "user", "content": tool_response}) else: bbox = vdr_normalize_bbox(raw_bbox, bbox_config) crop_path = os.path.join(tool_work_dir, f"zoom_turn{turn:03d}.jpg") crop_frame_to_rgb_jpeg(locked_frame_path, bbox, crop_path) cropped = load_and_process_image( crop_path, kwargs.get("min_pixels", 65536), kwargs.get("max_pixels", 8294400), kwargs.get("factor", 32), kwargs.get("qwen_vl_processing", True), ) crop_b64, crop_mime = image_to_base64(cropped) marker = save_image_for_html(cropped, "zoom") messages.append({ "role": "user", "content": [ { "type": "text", "text": ( f"\nHere is the zoomed-in region " f"[x1={bbox[0]:.3f}, y1={bbox[1]:.3f}, x2={bbox[2]:.3f}, y2={bbox[3]:.3f}] " f"of Frame {locked_frame_idx}:\n" ), }, {"type": "image_url", "image_url": {"url": f"data:{crop_mime};base64,{crop_b64}"}}, {"type": "text", "text": "\n"}, ], }) tool_calls.append({"name": "zoom_in", "bbox": bbox, "frame_index": locked_frame_idx}) tool_response = ( f"\nHere is the zoomed-in region " f"[x1={bbox[0]:.3f}, y1={bbox[1]:.3f}, x2={bbox[2]:.3f}, y2={bbox[3]:.3f}] " f"of Frame {locked_frame_idx}:\n{marker}\n" ) elif tool_name == "image_search": if not locked_frame_path: tool_response = "\nError: image_search can only be used after find_frame.\n" messages.append({"role": "user", "content": tool_response}) else: raw_bbox = args.get("bbox") if not raw_bbox or len(raw_bbox) != 4: tool_response = "\nError: bbox must have exactly 4 values.\n" messages.append({"role": "user", "content": tool_response}) else: bbox = vdr_normalize_bbox(raw_bbox, bbox_config) padded_bbox = vdr_add_search_padding(bbox, locked_frame_path, padding=(0.2, 0.2), padding_cap_px=600) crop_path = os.path.join(tool_work_dir, f"image_search_turn{turn:03d}.jpg") crop_frame_to_rgb_jpeg(locked_frame_path, padded_bbox, crop_path) with open(crop_path, "rb") as f: crop_bytes = f.read() cached = image_search_cache.get(crop_bytes) if image_search_cache else None if cached: result_text = cached else: crop_b64 = base64.b64encode(crop_bytes).decode("ascii") result_text = await vdr_real_image_search(crop_b64, kwargs.get("shared_session"), serper_api_key, crop_path=crop_path) if image_search_cache and not any(pat in result_text for pat in ("Image search error:", "Error:", "No results found from reverse image search")): image_search_cache.set(crop_bytes, result_text) tool_calls.append({"name": "image_search", "bbox": bbox, "frame_index": locked_frame_idx}) tool_response = ( f"\nReverse Image Search Results " f"(region [{bbox[0]:.3f}, {bbox[1]:.3f}, {bbox[2]:.3f}, {bbox[3]:.3f}] " f"of Frame {locked_frame_idx}):\n\n{result_text}\n" ) messages.append({"role": "user", "content": tool_response}) elif tool_name == "web_search": query = args.get("query", "") if not query: tool_response = "\nError: web_search requires a non-empty query.\n" else: result_text, search_backend = await call_configured_web_search(query, kwargs) if result_text.startswith(f'Web Search Results for "{query}"'): tool_response = f"\n{result_text}\n" else: tool_response = f'\nWeb Search Results for "{query}":\n\n{result_text}\n' tool_calls.append({ "name": "web_search", "query": query, "frame_index": locked_frame_idx, "backend": search_backend, }) messages.append({"role": "user", "content": tool_response}) else: tool_response = f"\nError: Unknown tool '{tool_name}'.\n" messages.append({"role": "user", "content": tool_response}) output_parts.append(f"{assistant_tool_output}<|im_end|><|im_start|>user\n{tool_response}<|im_end|>\n<|im_start|>assistant\n") return { "output": "".join(output_parts), "input_messages": initial_messages, "conversation_messages": json.loads(json.dumps(messages)), "finish_reason": "max_turns", "num_round": max_turns, "tool_calls": tool_calls, "error": None, "error_status_code": None, "error_raw_response": "", "request_debug": result.get("request_debug") if "result" in locals() else None, "saved_images": saved_images, "locked_frame_path": locked_frame_path, "format_retries": format_retries, "forced_final_answer": forced_final_answer, "final_answer_retries": final_answer_retries, } # ============================================================================= # Main Evaluation Loop # ============================================================================= def messages_to_input_string( messages: list, saved_images: Optional[list[dict]] = None, append_assistant_prompt: bool = True, ) -> str: """Convert messages to transcript string like rollout data. When saved_images is provided, image markers are assigned in saved order. """ parts = [] image_markers = [img.get("marker", "[IMAGE]") for img in (saved_images or []) if img.get("marker")] image_marker_idx = 0 def next_marker() -> str: nonlocal image_marker_idx if image_marker_idx < len(image_markers): marker = image_markers[image_marker_idx] image_marker_idx += 1 return marker image_marker_idx += 1 return f"[IMAGE {image_marker_idx}]" for msg in messages: role = msg.get("role", "") content = msg.get("content", "") if isinstance(content, str): parts.append(f"<|im_start|>{role}\n{content}<|im_end|>") else: # Handle multimodal content text_parts = [] has_image = False for item in content: if isinstance(item, str): text_parts.append(item) elif item.get("type") == "text": text_parts.append(item["text"]) elif item.get("type") == "image_url": has_image = True text_parts.append(next_marker()) content_str = "\n".join(text_parts) parts.append(f"<|im_start|>{role}\n{content_str}<|im_end|>") transcript = "\n".join(parts) if append_assistant_prompt: return transcript + "\n<|im_start|>assistant\n" return transcript async def evaluate_sample( sample: dict, client: str, model: str, base_url: str, api_key: str, mode: str, dataset_configs: dict, semaphore: asyncio.Semaphore, **kwargs, ) -> dict: """Evaluate a single sample.""" dataset_name = sample.get("dataset", "") ds_config = dataset_configs.get(dataset_name, {}) score_methods = ds_config.get("score_methods", []) task_kind = sample.get("task_kind") or ds_config.get("task_kind", "image") # Get prompts from dataset config system_prompt = ds_config.get("system_prompt", "") format_instruction = ds_config.get("format_instruction", "") async with semaphore: try: if mode == "direct": if task_kind == "video_dr": result = await evaluate_video_direct( client, model, base_url, api_key, sample["question"], sample["video_path"], system_prompt=kwargs.get( "general_video_direct_system_prompt", GENERAL_VIDEO_DIRECT_SYSTEM_PROMPT, ), sample_id=sample["id"], extra_image_paths=sample.get("extra_image_paths") or [], **kwargs, ) else: result = await evaluate_direct( client, model, base_url, api_key, sample["question"], sample["image_path"], system_prompt=system_prompt, format_instruction=format_instruction, **kwargs ) else: if task_kind == "video_dr": result = await evaluate_video_tool( client, model, base_url, api_key, sample["question"], sample["video_path"], tool_system_prompt=system_prompt or kwargs.get("video_dr_system_prompt", VIDEO_DR_SYSTEM_PROMPT), sample_id=sample["id"], **kwargs, ) else: tools_section = kwargs.get("tools_section", "") if not tools_section: raise ValueError("tools_section required for image tool mode (use --tool-config)") task_instruction_parts = [] if system_prompt: task_instruction_parts.append(system_prompt.strip()) if format_instruction: task_instruction_parts.append( "# Required Answer Format\n" + format_instruction.strip() ) task_instruction = "\n\n".join(part for part in task_instruction_parts if part) if not task_instruction: raise ValueError( f"system_prompt or format_instruction required in dataset config for tool mode " f"(dataset: {dataset_name})" ) tool_system_prompt = build_image_tool_system_prompt( task_instruction=task_instruction, tools_section=tools_section, max_turns=kwargs.get("max_turns", 10), allowed_tool_names=kwargs.get("allowed_tool_names"), ) result = await evaluate_tool( client, model, base_url, api_key, sample["question"], sample["image_path"], tool_system_prompt=tool_system_prompt, image_search_data=sample.get("image_search_data"), data_root=sample.get("data_root", ""), sample_id=sample["id"], **kwargs, ) # Compute scores dynamically based on score_methods from config scores = {} extracted = None for method in score_methods: if method == "em_score_mcq": extracted = extract_mcq_answer(result["output"]) em_correct = check_answer(extracted, sample["answer"]) scores[method] = 1.0 if em_correct else 0.0 elif method == "llm_score": judge_client = kwargs.get("judge_client", "azure") judge_base_url = kwargs.get("judge_base_url", "") judge_api_key = kwargs.get("judge_api_key", "") judge_temperature = kwargs.get("judge_temperature", 0.0) scores[method] = await llm_judge_score( sample["question"], result["output"], sample["answer"], sample.get("judge_image_path") or result.get("locked_frame_path", "") or sample.get("image_path", ""), judge_client, judge_base_url, judge_api_key, judge_temperature, shared_session=kwargs.get("shared_session"), ) else: # Unknown method - skip or set to None scores[method] = None # Convert messages to input string format input_str = messages_to_input_string(result["input_messages"], result.get("saved_images", [])) trajectory_str = messages_to_input_string( result.get("conversation_messages", result["input_messages"]), result.get("saved_images", []), append_assistant_prompt=False, ) result_dict = { "sample_id": sample["id"], "dataset": dataset_name, "input": input_str, "output": result["output"], "trajectory": trajectory_str, "gts": sample["answer"], # answer is already a list "finish_reason": result["finish_reason"], "num_round": result["num_round"], "tool_calls": result.get("tool_calls", []), "error": result.get("error"), "error_status_code": result.get("error_status_code"), "error_raw_response": result.get("error_raw_response", ""), "request_debug": result.get("request_debug"), "format_retries": result.get("format_retries", 0), "forced_final_answer": result.get("forced_final_answer", False), "final_answer_retries": result.get("final_answer_retries", 0), "final_only_answer": result.get("final_only_answer", False), "eval_compat_profile": kwargs.get("eval_compat_profile", "current"), "saved_images": result.get("saved_images", []), # For HTML generation only } for key in ("category", "difficulty", "task_kind"): if key in sample: result_dict[key] = sample[key] if result.get("locked_frame_path"): result_dict["locked_frame_path"] = result["locked_frame_path"] # Add all scores with their method names result_dict.update(scores) return result_dict except FatalAPIError: raise # Re-raise to crash on SERPER/judge errors except Exception as e: print(f"[ERROR] {sample['id']}: {e}") error_dict = { "sample_id": sample["id"], "dataset": dataset_name, "input": "", "output": "", "gts": sample["answer"], # answer is already a list "finish_reason": "error", "num_round": 0, "error": str(e), } for key in ("category", "difficulty", "task_kind"): if key in sample: error_dict[key] = sample[key] # Add all score methods as 0.0 for errors for method in score_methods: error_dict[method] = 0.0 return error_dict async def run_evaluation( samples: list[dict], dataset_configs: dict, client: str, model: str, base_url: str, api_key: str, mode: str, max_concurrent: int = 4, output_dir: str = "", **kwargs, ) -> dict: """Run evaluation on all samples.""" datasets = set(s.get("dataset", "") for s in samples) # Resume: load existing results and skip completed samples completed_ids = set() existing_results = [] results_file = os.path.join(output_dir, "results.jsonl") if output_dir else "" if results_file and os.path.exists(results_file): with open(results_file) as f: for line in f: r = json.loads(line) existing_results.append(r) completed_ids.add(r.get("sample_id")) print(f"Resuming: loaded {len(existing_results)} completed results, skipping {len(completed_ids)} samples") samples = [s for s in samples if s["id"] not in completed_ids] print(f"Evaluating {len(samples)} samples across {len(datasets)} datasets") print(f"Client: {client}, Model: {model}") print(f"Mode: {mode}") print() # Incremental stats tracking (avoid recomputing from full results list) dataset_stats = {ds: {"total": 0, "errors": 0, "em_correct": 0, "em_total": 0, "llm_correct": 0, "llm_total": 0} for ds in datasets} # Initialize with existing results for r in existing_results: ds = r.get("dataset", "") if ds in dataset_stats: dataset_stats[ds]["total"] += 1 if r.get("error"): dataset_stats[ds]["errors"] += 1 if r.get("em_score_mcq") is not None: dataset_stats[ds]["em_total"] += 1 if r["em_score_mcq"] > 0.5: dataset_stats[ds]["em_correct"] += 1 if r.get("llm_score") is not None: dataset_stats[ds]["llm_total"] += 1 if r["llm_score"] > 0.5: dataset_stats[ds]["llm_correct"] += 1 if not samples: print("All samples already completed!") # Generate HTML from existing results (note: saved_images not in JSONL, so images won't be inlined) if output_dir and existing_results: html_path = os.path.join(output_dir, "results.html") generate_html(existing_results, html_path, images_subdir="images") print(f"Generated HTML viewer: {html_path}", flush=True) # Return stats from existing results dataset_results = {} for dataset_name in sorted(datasets): st = dataset_stats[dataset_name] ds_stats = {"total": st["total"], "errors": st["errors"]} if st["em_total"] > 0: ds_stats["em_correct"] = st["em_correct"] ds_stats["em_total"] = st["em_total"] ds_stats["em_accuracy"] = st["em_correct"] / st["em_total"] if st["llm_total"] > 0: ds_stats["llm_correct"] = st["llm_correct"] ds_stats["llm_total"] = st["llm_total"] ds_stats["llm_accuracy"] = st["llm_correct"] / st["llm_total"] dataset_results[dataset_name] = ds_stats await _cleanup_browser() return dataset_results semaphore = asyncio.Semaphore(max_concurrent) # Create serper semaphore for rate limiting (must be inside event loop) serper_concurrency = kwargs.pop("serper_concurrency", 5) serper_semaphore = asyncio.Semaphore(serper_concurrency) kwargs["serper_semaphore"] = serper_semaphore kwargs["serper_concurrency"] = serper_concurrency # Create images directory for HTML viewer images_dir = "" frame_cache_dir = "" image_search_cache = None if output_dir: os.makedirs(output_dir, exist_ok=True) images_dir = os.path.join(output_dir, "images") frame_cache_dir = os.path.join(output_dir, "frame_cache") os.makedirs(images_dir, exist_ok=True) os.makedirs(frame_cache_dir, exist_ok=True) image_search_cache = ImageSearchCache(os.path.join(output_dir, "image_search_cache.json")) image_seed_paths = kwargs.get("image_search_cache_seed_paths") or [] if image_seed_paths: seed_image_search_cache(image_search_cache, image_seed_paths) kwargs["images_dir"] = images_dir kwargs["frame_cache_dir"] = frame_cache_dir kwargs["image_search_cache"] = image_search_cache start_time = time.time() num_existing = len(existing_results) last_log_time = time.time() # Prepare output file for incremental saves results_file_handle = None all_results_for_html = [] try: shared_timeout = aiohttp.ClientTimeout(total=3600) async with create_http_session(shared_timeout) as shared_session: kwargs["shared_session"] = shared_session tasks = [ evaluate_sample(s, client, model, base_url, api_key, mode, dataset_configs, semaphore, **kwargs) for s in samples ] if output_dir: write_mode = "a" if existing_results else "w" results_file_handle = open(os.path.join(output_dir, "results.jsonl"), write_mode) for i, task in enumerate(asyncio.as_completed(tasks)): try: result = await task except FatalAPIError as e: print(f"\n{'='*70}", flush=True) print(f"FATAL ERROR: {e}", flush=True) print(f"{'='*70}\n", flush=True) raise # Update incremental stats ds = result.get("dataset", "") if ds in dataset_stats: dataset_stats[ds]["total"] += 1 if result.get("error"): dataset_stats[ds]["errors"] += 1 if result.get("em_score_mcq") is not None: dataset_stats[ds]["em_total"] += 1 if result["em_score_mcq"] > 0.5: dataset_stats[ds]["em_correct"] += 1 if result.get("llm_score") is not None: dataset_stats[ds]["llm_total"] += 1 if result["llm_score"] > 0.5: dataset_stats[ds]["llm_correct"] += 1 # Save immediately (skip keys not suitable for jsonl) if results_file_handle: jsonl_skip_keys = {"saved_images"} jsonl_result = {k: v for k, v in result.items() if k not in jsonl_skip_keys} results_file_handle.write(json.dumps(jsonl_result, ensure_ascii=False) + "\n") results_file_handle.flush() # Collect for HTML generation (keeps saved_images) all_results_for_html.append(result) now = time.time() total_done = num_existing + i + 1 if (i + 1) % 10 == 0 or (i + 1) == len(tasks) or (now - last_log_time > 30): elapsed = now - start_time rate = (i + 1) / elapsed if elapsed > 0 else 0 eta = (len(tasks) - i - 1) / rate if rate > 0 else 0 ds_stats_strs = [] for ds_name in sorted(datasets): st = dataset_stats[ds_name] if st["em_total"] > 0: ds_acc = 100 * st["em_correct"] / st["em_total"] ds_stats_strs.append(f"{ds_name}(em): {st['em_correct']}/{st['em_total']} ({ds_acc:.1f}%)") if st["llm_total"] > 0: ds_acc = 100 * st["llm_correct"] / st["llm_total"] ds_stats_strs.append(f"{ds_name}(llm): {st['llm_correct']}/{st['llm_total']} ({ds_acc:.1f}%)") cache_info = "" search_cache = kwargs.get("search_cache") if search_cache: cs = search_cache.get_stats() cache_info = f" | Cache: {cs['hits']}/{cs['total']} ({cs['hit_rate']:.0f}%)" web_stats = _web_fetch_stats.format_progress() web_info = f" | {web_stats}" if web_stats else "" print( f"Progress: {total_done}/{num_existing + len(tasks)} | " f"Elapsed: {elapsed:.0f}s | {rate:.1f} samples/s | ETA: {eta:.0f}s" f"{cache_info}{web_info}", flush=True, ) if ds_stats_strs: print(f" {' | '.join(ds_stats_strs)}", flush=True) last_log_time = now finally: if results_file_handle: results_file_handle.close() if image_search_cache: image_search_cache.save() await _cleanup_browser() # Generate HTML viewer with all results (new + existing) # Note: existing_results from resume don't have saved_images (not in JSONL), # so their [IMAGE N] markers won't be replaced with actual images in HTML if output_dir and (all_results_for_html or existing_results): html_results = existing_results + all_results_for_html html_path = os.path.join(output_dir, "results.html") generate_html(html_results, html_path, images_subdir="images") print(f"Generated HTML viewer: {html_path}", flush=True) elapsed = time.time() - start_time # Per-dataset results (use incremental stats - no recomputation needed) print(flush=True) print("=" * 70, flush=True) print("Results by Dataset:", flush=True) print("=" * 70, flush=True) dataset_results = {} total_samples = 0 for dataset_name in sorted(datasets): st = dataset_stats[dataset_name] total_samples += st["total"] ds_stats = {"total": st["total"], "errors": st["errors"]} if st["em_total"] > 0: ds_stats["em_correct"] = st["em_correct"] ds_stats["em_total"] = st["em_total"] ds_stats["em_accuracy"] = st["em_correct"] / st["em_total"] if st["llm_total"] > 0: ds_stats["llm_correct"] = st["llm_correct"] ds_stats["llm_total"] = st["llm_total"] ds_stats["llm_accuracy"] = st["llm_correct"] / st["llm_total"] dataset_results[dataset_name] = ds_stats print(f"\n{dataset_name} ({st['total']} samples):", flush=True) if "em_accuracy" in ds_stats: print(f" em_score_mcq: {ds_stats['em_accuracy']:.4f} ({ds_stats['em_correct']}/{ds_stats['em_total']})", flush=True) if "llm_accuracy" in ds_stats: print(f" llm_score_allow_no_answer: {ds_stats['llm_accuracy']:.4f} ({ds_stats['llm_correct']}/{ds_stats['llm_total']})", flush=True) if ds_stats["errors"] > 0: print(f" errors: {ds_stats['errors']}", flush=True) print(flush=True) if total_samples > 0: print(f"Time: {elapsed:.1f}s ({elapsed/total_samples:.2f}s per sample)", flush=True) else: print(f"Time: {elapsed:.1f}s (no samples processed)", flush=True) # Web fetch stats summary ws = _web_fetch_stats.get_stats() if ws["total"] > 0: print(f"\nWeb Fetch Stats: {ws['successful']}/{ws['total']} ({ws['success_rate']:.1f}%) successful", flush=True) if ws["failed"] > 0: print(f" Failed: {ws['failed']}", flush=True) if ws["errors_by_code"]: code_str = ", ".join(f"HTTP {code}: {cnt}" for code, cnt in sorted(ws["errors_by_code"].items())) print(f" Error codes: {code_str}", flush=True) if ws["skipped"] > 0: print(f" Skipped (non-HTML): {ws['skipped']}", flush=True) print("=" * 70, flush=True) # Return dataset_results only - results are already saved incrementally to JSONL return dataset_results def generate_html(results: list[dict], output_path: str, images_subdir: str = "images"): """Generate HTML viewer for results with actual images displayed. Args: results: List of result dicts, each may contain 'saved_images' with image paths output_path: Path to write HTML file images_subdir: Subdirectory containing images (relative to HTML file) """ import html as html_lib def highlight_tags_with_images(text: str, saved_images: list[dict]) -> str: """Add syntax highlighting for special tags and replace [IMAGE N] with actual images.""" text = html_lib.escape(text) # Build marker -> img tag mapping for img_info in saved_images: marker = img_info.get("marker", "") path = img_info.get("path", "") img_type = img_info.get("type", "") if marker and path: escaped_marker = html_lib.escape(marker) # Create img tag with relative path img_tag = f'

' text = text.replace(escaped_marker, img_tag) text = text.replace("\n", "
") # Highlight tags tag_styles = [ (r"<think>", '<think>'), (r"</think>", '</think>'), (r"<thinking>", '<thinking>'), (r"</thinking>", '</thinking>'), (r"<tool_call>", '<tool_call>'), (r"</tool_call>", '</tool_call>'), (r"<answer>", '<answer>'), (r"</answer>", '</answer>'), (r"<tool_response>", '<tool_response>'), (r"</tool_response>", '</tool_response>'), ] for pattern, replacement in tag_styles: text = text.replace(pattern, replacement) return text html_content = ''' Eval Results Viewer

Eval Results Viewer

Samples: ''' + str(len(results)) + '''
Generated: ''' + time.strftime("%Y-%m-%d %H:%M:%S") + '''
''' for i, r in enumerate(results): sample_id = r.get("sample_id", f"sample-{i}") gts = r.get("gts", "") saved_images = r.get("saved_images", []) # Keys to skip (large or not useful for display) skip_keys = {"sample_id", "dataset", "input", "output", "gts", "saved_images"} # Display all keys equally in meta-row meta_items = [] meta_items.append(f'
Ground Truth: {html_lib.escape(str(gts))}
') for key, value in r.items(): if key in skip_keys: continue # Format values if value is None: formatted = "None" elif isinstance(value, float): formatted = f"{value:.4f}" elif isinstance(value, (int, bool)): formatted = str(value) elif isinstance(value, str) and len(value) < 200: formatted = html_lib.escape(value) else: continue # Skip complex or long values meta_items.append(f'
{html_lib.escape(key)}: {formatted}
') # Process input and output with image replacements input_html = highlight_tags_with_images(r.get("input", ""), saved_images) output_html = highlight_tags_with_images(r.get("output", ""), saved_images) html_content += f''' ''' html_content += '''
''' with open(output_path, "w") as f: f.write(html_content) def save_results(dataset_results: dict, output_dir: str, run_config: Optional[dict] = None): """Save summary to output directory. Note: results.jsonl is saved incrementally during eval.""" os.makedirs(output_dir, exist_ok=True) # Summary only (results.jsonl already saved incrementally) summary = {"datasets": dataset_results} if run_config: summary["run_config"] = run_config # Add web fetch stats if any ws = _web_fetch_stats.get_stats() if ws["total"] > 0: summary["web_fetch_stats"] = ws with open(os.path.join(output_dir, "summary.json"), "w") as f: json.dump(summary, f, indent=2) print(f"\nResults saved to: {output_dir}", flush=True) print(f" Results: {output_dir}/results.jsonl", flush=True) html_path = os.path.join(output_dir, "results.html") if os.path.exists(html_path): print(f" HTML: {html_path}", flush=True) print(f" Summary: {output_dir}/summary.json", flush=True) # ============================================================================= # CLI # ============================================================================= EVAL_COMPAT_PROFILES = ("current", "qwen235b_repair", "step14_plus_tavily432") def _env_flag(name: str, default: bool = False) -> bool: value = os.environ.get(name) if value is None: return default return value.lower() in {"1", "true", "yes", "on"} def _argv_has_option(option: str) -> bool: prefix = option + "=" return any(arg == option or arg.startswith(prefix) for arg in sys.argv[1:]) def apply_eval_compat_profile(args: argparse.Namespace, parser: argparse.ArgumentParser) -> str: """Normalize evaluation behavior for reproducible cross-run comparisons.""" profile = args.eval_compat_profile if profile not in EVAL_COMPAT_PROFILES: parser.error(f"Unknown --eval-compat-profile: {args.eval_compat_profile}") if profile == "step14_plus_tavily432": if _argv_has_option("--format-retry-limit") and args.format_retry_limit not in (None, 0): parser.error("--eval-compat-profile step14_plus_tavily432 requires --format-retry-limit 0") if _argv_has_option("--final-answer-retry-limit") and args.final_answer_retry_limit not in (None, 0): parser.error("--eval-compat-profile step14_plus_tavily432 requires --final-answer-retry-limit 0") args.format_retry_limit = 0 args.disable_force_final_answer_turn = True args.final_answer_retry_limit = 0 args.disable_no_tool_answer_recovery = True return "step14_plus_tavily432" if args.format_retry_limit is None: args.format_retry_limit = int(os.environ.get("FORMAT_RETRY_LIMIT", "2")) if args.final_answer_retry_limit is None: args.final_answer_retry_limit = int(os.environ.get("FINAL_ANSWER_RETRY_LIMIT", "1")) return profile def main(): parser = argparse.ArgumentParser(description="VLM Evaluation Script v2") parser.add_argument("--model", type=str, required=True, help="Model name") parser.add_argument("--mode", type=str, required=True, choices=["direct", "tool"]) parser.add_argument("--datasets", type=str, default="", help="Path to datasets JSON config. If omitted, use --benchmarks with --eval-root.") parser.add_argument("--eval-root", type=str, default="", help=( "Root directory whose subfolders are image benchmarks with data.jsonl. " f"Default used with --benchmarks: {DEFAULT_EVAL_ROOT}" )) parser.add_argument("--benchmarks", type=str, nargs="+", default=None, help="Benchmark subfolder names under --eval-root, e.g. mmsearch_end2end_only_image hr_mmsearch.") parser.add_argument("--save-resolved-datasets-config", type=str, default="", help="Optional path to save the auto-generated datasets JSON config.") parser.add_argument("--model-client", type=str, choices=["gemini", "openai", "azure", "gateway", "vertex"], required=True, help="Model API client: gemini, openai (Qwen/MARS), azure (GPT), gateway (company GPT gateway), or vertex (Vertex Gemini service-account pool)") parser.add_argument("--judge-client", type=str, choices=["openai", "azure"], default=None, help="Legacy judge client flag; llm_score now follows video_dr_gen and ignores this option") parser.add_argument("--judge-temperature", type=float, default=0.0, help="Legacy judge temperature flag; llm_score now follows video_dr_gen and ignores this option") parser.add_argument("--data-root", type=str, default="", help="Root for relative paths") parser.add_argument("--output-dir", type=str, default=None) parser.add_argument("--target-ids", type=str, nargs="+", default=None, help="Only evaluate the specified sample ids or source ids") parser.add_argument("--vertex-account-pool-file", type=str, default=os.environ.get("VERTEX_ACCOUNT_POOL_FILE", ""), help="Vertex Gemini account pool JSON; defaults to env VERTEX_ACCOUNT_POOL_FILE") parser.add_argument("--vertex-location", type=str, default=os.environ.get("VERTEX_LOCATION", "global"), help="Default Vertex location for account-pool entries without location (default: global)") parser.add_argument("--vertex-rate-limit-cooldown", type=float, default=float(os.environ.get("VERTEX_RATE_LIMIT_COOLDOWN_SECONDS", "60")), help="Cooldown seconds for a Vertex account after 429/RESOURCE_EXHAUSTED before rotating back") parser.add_argument("--max-concurrent", type=int, default=4) parser.add_argument("--max-tokens", type=int, default=4096) parser.add_argument("--max-turns", type=int, default=50, help="Max assistant turns (default: 50)") parser.add_argument("--eval-compat-profile", type=str, choices=EVAL_COMPAT_PROFILES, default=os.environ.get("EVAL_COMPAT_PROFILE", "current"), help=( "Evaluation behavior profile. Use step14_plus_tavily432 to keep " "step14 prompt/control flow while retaining later Tavily key-pool fixes." )) parser.add_argument("--format-retry-limit", type=int, default=None, help="Extra protocol-repair retries when a tool-mode response is empty, malformed, or only in reasoning_content (default: 2; forced to 0 by step14_plus_tavily432)") parser.add_argument("--disable-force-final-answer-turn", action="store_true", default=_env_flag("DISABLE_FORCE_FINAL_ANSWER_TURN"), help="Disable the final no-tool answer turn used to prevent VideoDR tool loops near max_turns") parser.add_argument("--final-answer-retry-limit", type=int, default=None, help="Extra retries after the forced final-answer instruction if the model still fails to answer (default: 1; forced to 0 by step14_plus_tavily432)") parser.add_argument("--disable-no-tool-answer-recovery", action="store_true", default=_env_flag("DISABLE_NO_TOOL_ANSWER_RECOVERY"), help="Disable recovery of plain no-tool outputs as final answers.") parser.add_argument("--temperature", type=float, default=0.7) parser.add_argument("--top-p", type=float, default=0.8) parser.add_argument("--top-k", type=int, default=20) parser.add_argument("--presence-penalty", type=float, default=1.5) parser.add_argument("--repetition-penalty", type=float, default=1.0, help="Repetition penalty for model generation (default: 1.0)") parser.add_argument("--seed", type=int, default=3407, help="Random seed for deterministic generation (default: 3407)") parser.add_argument("--min-pixels", type=int, default=65536, help="Min pixels for image processing") parser.add_argument("--max-pixels", type=int, default=8294400, help="Max pixels for image processing") parser.add_argument("--factor", type=int, default=32, help="Alignment factor (32 for Qwen3-VL, 28 for Qwen2-VL)") parser.add_argument("--qwen-vl-processing", type=lambda x: x.lower() == 'true', default=True, help="Use Qwen-VL style image processing (default: True, set False for Gemini/GPT)") parser.add_argument("--tool-config", type=str, default="", help="Path to tool config YAML (optional, overrides dataset system_prompt)") parser.add_argument("--tool-ablation-profile", type=str, choices=TOOL_ABLATION_PROFILES, default=os.environ.get("TOOL_ABLATION_PROFILE", "none"), help=( "Tool ablation profile for tool-mode evaluation: none, nosearch " "(remove image_search/web_search), or nolocation (remove choose_frames/zoom_in)." )) parser.add_argument("--serper-concurrency", type=int, default=5, help="Max concurrent Serper API requests (default: 5)") parser.add_argument("--search-cache-dir", type=str, default="", help="Directory for search result cache (optional, no cache if not set)") parser.add_argument("--seed-search-cache-from", type=str, nargs="*", default=None, help="Seed this run's search cache from historical search_cache.db files under these files/directories. Defaults to inference/runs when --search-cache-dir is set.") parser.add_argument("--no-auto-seed-search-cache", action="store_true", default=os.environ.get("NO_AUTO_SEED_SEARCH_CACHE", "").lower() in {"1", "true", "yes", "on"}, help="Disable automatic historical search cache seeding.") parser.add_argument("--seed-image-search-cache-from", type=str, nargs="*", default=None, help="Seed this run's image_search_cache.json from historical image_search_cache.json files under these files/directories. Defaults to inference/runs in tool mode.") parser.add_argument("--no-auto-seed-image-search-cache", action="store_true", default=os.environ.get("NO_AUTO_SEED_IMAGE_SEARCH_CACHE", "").lower() in {"1", "true", "yes", "on"}, help="Disable automatic historical image_search_cache.json seeding.") parser.add_argument("--web-search-backend", type=str, choices=["serper_gateway", "gateway", "gateway_serper", "internal_serper", "company_serper", "mars", "tavily", "auto"], default=_default_web_search_backend(), help="Backend for web_search tool: serper_gateway, mars, tavily, or auto (default: env WEB_SEARCH_BACKEND/VIDEO_DR_WEB_SEARCH_BACKEND or serper_gateway)") parser.add_argument("--serper-gateway-max-results", type=int, default=int(os.environ.get("SERPER_GATEWAY_MAX_RESULTS", "8")), help="Serper gateway max results, clamped to 1-20 (default: 8)") parser.add_argument("--serper-gateway-timeout", type=int, default=int(os.environ.get("SERPER_GATEWAY_TIMEOUT", "60")), help="Serper gateway request timeout in seconds (default: 60)") parser.add_argument("--serper-gateway-summary-max-tokens", type=int, default=int(os.environ.get("SERPER_GATEWAY_SUMMARY_MAX_TOKENS", "1024")), help="Deprecated compatibility flag for the old metadata-only Serper gateway summarizer.") parser.add_argument("--tavily-search-depth", type=str, choices=["basic", "advanced"], default=os.environ.get("TAVILY_SEARCH_DEPTH", "advanced"), help="Tavily search_depth when --web-search-backend=tavily") parser.add_argument("--tavily-max-results", type=int, default=int(os.environ.get("TAVILY_MAX_RESULTS", "8")), help="Tavily max_results, clamped to 1-20 (default: 8)") parser.add_argument("--tavily-include-answer", type=str, default=os.environ.get("TAVILY_INCLUDE_ANSWER", "advanced"), help="Tavily include_answer: false, true/basic, or advanced (default: advanced)") parser.add_argument("--tavily-include-raw-content", type=str, default=os.environ.get("TAVILY_INCLUDE_RAW_CONTENT", "false"), help="Tavily include_raw_content: false, true/markdown, or text (default: false)") parser.add_argument("--tavily-topic", type=str, choices=["general", "news", "finance"], default=os.environ.get("TAVILY_TOPIC", "general"), help="Tavily topic (default: general)") parser.add_argument("--tavily-auto-parameters", action="store_true", default=os.environ.get("TAVILY_AUTO_PARAMETERS", "").lower() in {"1", "true", "yes", "on"}, help="Enable Tavily auto_parameters; uses extra Tavily credits") parser.add_argument("--tavily-timeout", type=int, default=int(os.environ.get("TAVILY_TIMEOUT", "60")), help="Tavily request timeout in seconds (default: 60)") parser.add_argument("--tavily-key-cooldown-seconds", type=int, default=int(os.environ.get("TAVILY_KEY_COOLDOWN_SECONDS", "60")), help="Cooldown for a Tavily key after 429/5xx before rotating back (default: 60)") parser.add_argument("--video-initial-frames", type=int, default=DEFAULT_VIDEO_INITIAL_FRAMES, help="Number of initial 1fps frames shown to the model for VideoDR") parser.add_argument("--video-interval-samples", type=int, default=DEFAULT_VIDEO_INTERVAL_SAMPLES, help="Number of uniformly sampled frames returned by choose_frames") parser.add_argument("--video-max-resolution", type=int, default=DEFAULT_VIDEO_MAX_RESOLUTION, help="Max side length for 1fps extracted video frames") parser.add_argument("--video-jpeg-quality", type=int, default=DEFAULT_VIDEO_JPEG_QUALITY, help="JPEG quality for cached video frames") args = parser.parse_args() resolved_eval_profile = apply_eval_compat_profile(args, parser) video_dr_system_prompt = get_video_dr_system_prompt(resolved_eval_profile) print( "Evaluation compatibility profile: " f"{resolved_eval_profile} " f"(format_retry_limit={args.format_retry_limit}, " f"force_final_answer_turn={not args.disable_force_final_answer_turn}, " f"final_answer_retry_limit={args.final_answer_retry_limit}, " f"no_tool_answer_recovery={not args.disable_no_tool_answer_recovery})", flush=True, ) model_client = args.model_client # Get API credentials based on model_client if model_client == "gemini": if "GEMINI_API_KEY" not in os.environ: parser.error("GEMINI_API_KEY required") if "GEMINI_BASE_URL" not in os.environ: parser.error("GEMINI_BASE_URL required") api_key = os.environ["GEMINI_API_KEY"] base_url = os.environ["GEMINI_BASE_URL"] elif model_client == "azure": if "AZURE_OPENAI_API_KEY" not in os.environ: parser.error("AZURE_OPENAI_API_KEY required") if "AZURE_OPENAI_BASE_URL" not in os.environ: parser.error("AZURE_OPENAI_BASE_URL required") api_key = os.environ["AZURE_OPENAI_API_KEY"] base_url = os.environ["AZURE_OPENAI_BASE_URL"] elif model_client == "gateway": _, _, api_key = _get_model_gateway_credentials(args.model) if not api_key: token_hint = ( "MODEL_GATEWAY_GEMINI_TOKEN or GEMINI_GATEWAY_TOKEN required" if _is_gemini_gateway_model(args.model) else "MODEL_GATEWAY_TOKEN or GATEWAY_TOKEN required" ) parser.error(token_hint) base_url = ( os.environ.get("MODEL_GATEWAY_URL") or os.environ.get("GATEWAY_URL") or "http://112.65.194.90:8000/trpc.youtu.llm_interface_service.Greeter/DescribeLlmResult" ) elif model_client == "vertex": if not args.vertex_account_pool_file: parser.error("--vertex-account-pool-file or VERTEX_ACCOUNT_POOL_FILE required for --model-client vertex") api_key = "" base_url = "vertex://account-pool" else: # OpenAI-compatible (Qwen, MARS, company OpenAI proxy, etc.) base_url = os.environ.get("MODEL_BASE_URL", DEFAULT_COMPANY_OPENAI_BASE_URL) api_key = ( os.environ.get("MODEL_API_KEY") or os.environ.get("MODEL_OPENAI_API_KEY") or os.environ.get("OPENAI_API_KEY", "") or _read_secret_file(os.environ.get("MODEL_API_KEY_FILE", DEFAULT_COMPANY_OPENAI_API_KEY_FILE)) ) if model_client not in {"gateway", "vertex"} and not api_key: configure_local_service_no_proxy(base_url) vertex_account_pool = None if model_client == "vertex": vertex_account_pool = VertexAccountPool( args.vertex_account_pool_file, default_location=args.vertex_location, cooldown_seconds=args.vertex_rate_limit_cooldown, ) # Load datasets auto_dataset_config = None if args.benchmarks or args.eval_root: if args.datasets: parser.error("Use either --datasets or --eval-root/--benchmarks, not both.") eval_root = args.eval_root or DEFAULT_EVAL_ROOT auto_dataset_config = build_eval_root_dataset_config(eval_root, args.benchmarks) print( "Loading benchmarks from eval root " f"{eval_root}: {', '.join(auto_dataset_config.keys())}", flush=True, ) if args.save_resolved_datasets_config: save_dir = os.path.dirname(os.path.abspath(args.save_resolved_datasets_config)) if save_dir: os.makedirs(save_dir, exist_ok=True) with open(args.save_resolved_datasets_config, "w") as f: json.dump(auto_dataset_config, f, ensure_ascii=False, indent=2) print( f"Saved resolved datasets config to {args.save_resolved_datasets_config}", flush=True, ) samples, dataset_configs = load_datasets( auto_dataset_config, args.data_root, video_dr_system_prompt=video_dr_system_prompt, ) else: if not args.datasets: parser.error("--datasets is required unless --benchmarks or --eval-root is provided.") print(f"Loading datasets from {args.datasets}...") samples, dataset_configs = load_datasets( args.datasets, args.data_root, video_dr_system_prompt=video_dr_system_prompt, ) if args.target_ids: target_ids = set(args.target_ids) total_before_filter = len(samples) samples = [ sample for sample in samples if sample["id"] in target_ids or sample.get("source_id") in target_ids ] print( f"Filtered samples with --target-ids: {len(samples)}/{total_before_filter}", flush=True, ) if not samples: parser.error("--target-ids did not match any sample id/source id") print(f"Loaded {len(samples)} samples total\n") selected_dataset_names = {sample.get("dataset", "") for sample in samples} selected_dataset_configs = [ dataset_configs[name] for name in selected_dataset_names if name in dataset_configs ] has_video_dr_dataset = any(cfg.get("task_kind") == "video_dr" for cfg in selected_dataset_configs) has_image_dataset = any(cfg.get("task_kind") == "image" for cfg in selected_dataset_configs) if args.tool_ablation_profile != "none" and args.mode != "tool": parser.error("--tool-ablation-profile 仅支持 --mode tool") if args.tool_ablation_profile != "none" and has_video_dr_dataset and has_image_dataset: parser.error("同一次工具消融评测暂不支持混合 VideoDR 与 image benchmark") # Check LLM judge requirement needs_llm = any( "llm_score" in cfg.get("score_methods", []) for cfg in selected_dataset_configs ) judge_client = args.judge_client or "" judge_base_url = "" judge_api_key = "" if needs_llm and (not MARS_SUMMARIZER_ADDRESS or not MARS_SUMMARIZER_MODEL): print( "[WARN] llm_score 已切换到 video_dr_gen 的 MARS summarizer judge," "但 MARS_SUMMARIZER_ADDRESS 或 MARS_SUMMARIZER_MODEL 未配置;" "此时 llm_score 只会走快速路径,无法调用 LLM judge。", flush=True, ) serper_api_key = os.environ.get("SERPER_API_KEY", "") tavily_api_keys = load_tavily_api_keys() tavily_key_pool = TavilyApiKeyPool(tavily_api_keys, cooldown_seconds=args.tavily_key_cooldown_seconds) summarizer_base_url = os.environ.get("SUMMARIZER_BASE_URL", "") summarizer_model = os.environ.get("SUMMARIZER_MODEL", "") resolved_web_search_backend = resolve_web_search_backend(args.web_search_backend, tavily_api_keys) if args.mode == "tool": print( f"web_search backend: {resolved_web_search_backend}" + (" (from auto)" if args.web_search_backend == "auto" else ""), flush=True, ) if resolved_web_search_backend == "tavily" and not tavily_api_keys: parser.error("TAVILY_API_KEY, TAVILY_API_KEYS, or TAVILY_API_KEY_FILE required when --web-search-backend=tavily") if resolved_web_search_backend == "tavily": print(f"Tavily API key pool: {len(tavily_api_keys)} key(s)", flush=True) tools_section = "" allowed_tool_names = None if args.mode == "tool": if has_video_dr_dataset: allowed_tool_names = get_allowed_tool_names(args.tool_ablation_profile, "video_dr") elif has_image_dataset: allowed_tool_names = get_allowed_tool_names(args.tool_ablation_profile, "image") if has_image_dataset and not args.tool_config: parser.error("图像 benchmark 的 tool 模式需要提供 --tool-config") if args.tool_config: print(f"Loading tool config from {args.tool_config}...") tools_section = load_tool_config( args.tool_config, allowed_tool_names=allowed_tool_names, normalize_image_schema=has_image_dataset, normalize_video_schema=has_video_dr_dataset, ) if args.tool_ablation_profile != "none": print( "Tool ablation profile: " f"{args.tool_ablation_profile}; allowed tools: {format_tool_names(allowed_tool_names or set())}", flush=True, ) if has_video_dr_dataset and args.tool_ablation_profile != "none": if not tools_section: parser.error("VideoDR 工具消融需要提供 --tool-config 以生成消融后的工具定义 prompt") video_dr_system_prompt = build_video_tool_system_prompt( tools_section=tools_section, allowed_tool_names=allowed_tool_names or set(), max_turns=args.max_turns, ) for cfg in selected_dataset_configs: if cfg.get("task_kind") == "video_dr": cfg["system_prompt"] = video_dr_system_prompt configure_local_service_no_proxy(summarizer_base_url) configure_local_service_no_proxy(MARS_RETRIEVAL_ADDRESS) configure_local_service_no_proxy(MARS_SUMMARIZER_ADDRESS) if args.mode == "tool" and VIDEO_DR_IMAGE_SEARCH_MODE == "gateway": configure_local_service_no_proxy(VIDEO_DR_GATEWAY_URL) # Output dir if args.output_dir: output_dir = args.output_dir else: timestamp = time.strftime("%y%m%d%H%M%S") output_dir = f"eval_{args.model.replace('/', '_')}_{args.mode}_{timestamp}" os.makedirs(output_dir, exist_ok=True) if args.mode == "tool" and not args.search_cache_dir: args.search_cache_dir = os.path.join(output_dir, "search_cache") print( f"Search cache dir not provided; using {args.search_cache_dir}", flush=True, ) # Create search cache if directory provided search_cache = None search_cache_seed_paths = resolve_cache_seed_paths( args.seed_search_cache_from, args.no_auto_seed_search_cache, ) if args.search_cache_dir: search_cache = SearchCache(args.search_cache_dir) if search_cache_seed_paths: search_cache.seed_from_paths(search_cache_seed_paths) image_search_cache_seed_paths = resolve_cache_seed_paths( args.seed_image_search_cache_from, args.no_auto_seed_image_search_cache, ) # Run evaluation kwargs = { "eval_compat_profile": resolved_eval_profile, "video_dr_system_prompt": video_dr_system_prompt, "general_video_direct_system_prompt": GENERAL_VIDEO_DIRECT_SYSTEM_PROMPT, "max_tokens": args.max_tokens, "max_turns": args.max_turns, "format_retry_limit": args.format_retry_limit, "force_final_answer_turn": not args.disable_force_final_answer_turn, "final_answer_retry_limit": args.final_answer_retry_limit, "recover_no_tool_answer": not args.disable_no_tool_answer_recovery, "temperature": args.temperature, "top_p": args.top_p, "top_k": args.top_k, "presence_penalty": args.presence_penalty, "repetition_penalty": args.repetition_penalty, "seed": args.seed, "min_pixels": args.min_pixels, "max_pixels": args.max_pixels, "factor": args.factor, "qwen_vl_processing": args.qwen_vl_processing, "serper_api_key": serper_api_key, "web_search_backend": args.web_search_backend, "serper_gateway_max_results": args.serper_gateway_max_results, "serper_gateway_timeout": args.serper_gateway_timeout, "serper_gateway_summary_max_tokens": args.serper_gateway_summary_max_tokens, "tavily_api_key": tavily_api_keys[0] if tavily_api_keys else "", "tavily_api_key_pool": tavily_key_pool, "tavily_search_depth": args.tavily_search_depth, "tavily_max_results": args.tavily_max_results, "tavily_include_answer": args.tavily_include_answer, "tavily_include_raw_content": args.tavily_include_raw_content, "tavily_topic": args.tavily_topic, "tavily_auto_parameters": args.tavily_auto_parameters, "tavily_timeout": args.tavily_timeout, "tavily_include_domains": _split_csv_env(os.environ.get("TAVILY_INCLUDE_DOMAINS", "")), "tavily_exclude_domains": _split_csv_env(os.environ.get("TAVILY_EXCLUDE_DOMAINS", "")), "summarizer_base_url": summarizer_base_url, "summarizer_model": summarizer_model, "serper_concurrency": args.serper_concurrency, "search_cache": search_cache, "image_search_cache_seed_paths": image_search_cache_seed_paths, "tools_section": tools_section, "tool_ablation_profile": args.tool_ablation_profile, "allowed_tool_names": allowed_tool_names, "judge_client": judge_client, "judge_base_url": judge_base_url, "judge_api_key": judge_api_key, "judge_temperature": args.judge_temperature, "video_initial_frames": args.video_initial_frames, "video_interval_samples": args.video_interval_samples, "video_max_resolution": args.video_max_resolution, "video_jpeg_quality": args.video_jpeg_quality, "vertex_account_pool": vertex_account_pool, } # Health check - crash early if servers are down async def health_check(): import aiohttp print("Running health checks...", flush=True) # Check model server print(f" Checking model server: {base_url}", flush=True) try: timeout = aiohttp.ClientTimeout(total=10) if model_client == "gateway": result = await call_gateway_api( [{"role": "user", "content": "Reply with exactly: OK"}], args.model, base_url, api_key, max_tokens=128, temperature=0.0, model_request_timeout=10, ) if result.get("error"): raise Exception(result["error"]) print(" Model gateway OK", flush=True) elif model_client == "vertex": result = await call_vertex_gemini_api( [{"role": "user", "content": "Hello, please test the connection and reply with OK."}], args.model, base_url, api_key, vertex_account_pool=vertex_account_pool, max_tokens=32, temperature=0.0, model_request_timeout=10, ) if result.get("error"): raise Exception(result["error"]) project = result.get("vertex_project_id", "") print(f" Vertex Gemini OK project={project}", flush=True) else: async with create_http_session(timeout) as session: # Try /v1/models endpoint (OpenAI-compatible) url = _openai_models_url(base_url) headers = {"Content-Type": "application/json"} if api_key: headers["Authorization"] = f"Bearer {api_key}" async with session.get(url, headers=headers) as resp: if resp.status != 200: raise Exception(f"Model server returned HTTP {resp.status}") print(f" Model server OK", flush=True) except Exception as e: raise RuntimeError(f"Model server not reachable at {base_url}: {e}") # Check summarizer server (if in tool mode) if args.mode == "tool" and summarizer_base_url: print(f" Checking summarizer server: {summarizer_base_url}", flush=True) try: async with create_http_session(timeout) as session: url = f"{summarizer_base_url.rstrip('/')}/v1/models" async with session.get(url) as resp: if resp.status != 200: raise Exception(f"Summarizer server returned HTTP {resp.status}") print(f" Summarizer server OK", flush=True) except Exception as e: raise RuntimeError(f"Summarizer server not reachable at {summarizer_base_url}: {e}") if needs_llm and MARS_SUMMARIZER_ADDRESS: print(f" Checking LLM judge summarizer: http://{MARS_SUMMARIZER_ADDRESS}", flush=True) try: async with create_http_session(timeout) as session: url = f"http://{MARS_SUMMARIZER_ADDRESS.rstrip('/')}/v1/models" async with session.get(url, proxy="") as resp: if resp.status != 200: raise Exception(f"Judge summarizer returned HTTP {resp.status}") print(" LLM judge summarizer OK", flush=True) except Exception as e: raise RuntimeError(f"LLM judge summarizer not reachable at http://{MARS_SUMMARIZER_ADDRESS}: {e}") print("Health checks passed!", flush=True) asyncio.run(health_check()) try: dataset_results = asyncio.run(run_evaluation( samples, dataset_configs, model_client, args.model, base_url, api_key, args.mode, args.max_concurrent, output_dir=output_dir, **kwargs )) save_results( dataset_results, output_dir, run_config={ "eval_compat_profile": resolved_eval_profile, "format_retry_limit": args.format_retry_limit, "force_final_answer_turn": not args.disable_force_final_answer_turn, "final_answer_retry_limit": args.final_answer_retry_limit, "recover_no_tool_answer": not args.disable_no_tool_answer_recovery, "web_search_backend": resolved_web_search_backend, "tavily_search_depth": args.tavily_search_depth, "tavily_max_results": args.tavily_max_results, "tavily_include_answer": args.tavily_include_answer, "tavily_include_raw_content": args.tavily_include_raw_content, "tavily_topic": args.tavily_topic, "eval_root": args.eval_root or (DEFAULT_EVAL_ROOT if auto_dataset_config else ""), "benchmarks": list(auto_dataset_config.keys()) if auto_dataset_config else [], "resolved_datasets_config": args.save_resolved_datasets_config, }, ) finally: # Skip browser cleanup - Playwright can hang indefinitely and OS will clean up on exit # Always print cache stats and close, even on error if search_cache: stats = search_cache.get_stats() print(f"Search cache: {stats['hits']}/{stats['total']} hits ({stats['hit_rate']:.1f}%), {stats['misses']} new searches cached", flush=True) search_cache.close() print("Done.", flush=True) if __name__ == "__main__": main()