#!/usr/bin/env python3
"""
Video DeepResearch 公共工具模块。
输入:
- `config.py` 导出的模型端点、Vertex 配置、本地检索服务地址与价格参数。
- 视频帧、裁剪图、LLM messages、搜索请求与各阶段中间结果。
处理:
- 提供帧采样、图片编码、搜索访问、输出清洗、token 统计等通用能力。
- 在 `LLMClient` 中实现多 endpoint 轮询、429 冷却切换、Vertex 原生调用与 OAuth2 认证。
- 支持“一个 project 对应一个 service account json”的 Vertex 凭证池,并按 URL 选择对应凭证。
输出:
- 为 Phase1/Phase2 提供统一的工具函数、搜索函数、token 统计结构和 `LLMClient`。
- 返回规范化的模型响应、搜索结果、图片路径与 token 使用信息。
"""
import asyncio
import aiohttp
import base64
import hashlib
import itertools
import json
import math
import os
import random
import re
import subprocess
import time
import numpy as np
from PIL import Image
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any, Set
from dataclasses import dataclass, field
from config import (
API_ENDPOINTS, DEFAULT_MODEL, API_KEY,
WEB_SEARCH_ADDRESS, WEB_SEARCH_CONFIG,
SERPER_API_KEY, MOCK_SEARCH, IMAGE_SEARCH_CACHE_FILE,
DEFAULT_TEMPERATURE, DEFAULT_TOP_P, DEFAULT_MAX_TOKENS,
DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES,
VERTEX_MIN_REQUEST_INTERVAL_SECONDS,
VERTEX_RATE_LIMIT_COOLDOWN_SECONDS,
VERTEX_REQUEST_JITTER_SECONDS,
TOKEN_PRICING, THINKING_TOKEN_PRICING,
DEFAULT_INPUT_PRICE_PER_1M, DEFAULT_OUTPUT_PRICE_PER_1M,
OSS_ACCESS_KEY_ID, OSS_ACCESS_KEY_SECRET,
OSS_ENDPOINT, OSS_BUCKET_NAME, OSS_UPLOAD_PREFIX,
SEARCH_CROP_MAX_SIZE, SEARCH_CROP_JPEG_QUALITY,
IMAGE_SEARCH_SUMMARIZE_SERPER, IMAGE_SEARCH_SUMMARIZER_ADDRESS,
IMAGE_SEARCH_SUMMARIZER_MODEL, IMAGE_SEARCH_SUMMARIZER_MAX_RESULTS,
IMAGE_SEARCH_SUMMARIZER_MAX_TOKENS,
IMAGE_SEARCH_MODE, GATEWAY_URL, GATEWAY_USERNAME, GATEWAY_USERID, GATEWAY_TOKEN,
IMAGE_SEARCH_ALLOW_BASE64_FALLBACK,
)
from config import BBOX_CONFIGS
import config
# ── MARS-style web search config ──
# Optional settings: each value is read from the `config` module when the
# attribute exists; otherwise the default passed to getattr() is used.
MARS_RETRIEVAL_ADDRESS = getattr(config, 'MARS_RETRIEVAL_ADDRESS', '')
MARS_SUMMARIZER_ADDRESS = getattr(config, 'MARS_SUMMARIZER_ADDRESS', '')
MARS_RETRIEVAL_TOPK = getattr(config, 'MARS_RETRIEVAL_TOPK', 3)
MARS_SUMMARIZER_MODEL = getattr(config, 'MARS_SUMMARIZER_MODEL', '')
MARS_WEB_SEARCH_MODE = getattr(config, 'MARS_WEB_SEARCH_MODE', 'serper')
MARS_RETRIEVAL_TIMEOUT = getattr(config, 'MARS_RETRIEVAL_TIMEOUT', 120)
MARS_RETRIEVAL_CONCURRENCY = getattr(config, 'MARS_RETRIEVAL_CONCURRENCY', 0)
# NOTE(review): the five IMAGE_SEARCH_SUMMARIZE*/SUMMARIZER_* names below are
# also listed in the `from config import (...)` statement earlier in this file,
# which raises ImportError when any of them is missing. These getattr()
# fallbacks therefore only take effect if that import list is relaxed —
# confirm which behaviour is intended.
IMAGE_SEARCH_SUMMARIZE_SERPER = getattr(config, 'IMAGE_SEARCH_SUMMARIZE_SERPER', True)
# The image-search summarizer falls back to the MARS summarizer address/model.
IMAGE_SEARCH_SUMMARIZER_ADDRESS = getattr(config, 'IMAGE_SEARCH_SUMMARIZER_ADDRESS', MARS_SUMMARIZER_ADDRESS)
IMAGE_SEARCH_SUMMARIZER_MODEL = getattr(config, 'IMAGE_SEARCH_SUMMARIZER_MODEL', MARS_SUMMARIZER_MODEL)
IMAGE_SEARCH_SUMMARIZER_MAX_RESULTS = getattr(config, 'IMAGE_SEARCH_SUMMARIZER_MAX_RESULTS', 5)
IMAGE_SEARCH_SUMMARIZER_MAX_TOKENS = getattr(config, 'IMAGE_SEARCH_SUMMARIZER_MAX_TOKENS', 512)
# Optional Google Cloud / Vertex settings; empty defaults when not configured.
GCP_PROJECT_ID = getattr(config, 'GCP_PROJECT_ID', '')
GCP_LOCATION = getattr(config, 'GCP_LOCATION', '')
GCP_SERVICE_ACCOUNT_KEY = getattr(config, 'GCP_SERVICE_ACCOUNT_KEY', '')
VERTEX_CREDENTIALS_POOL = getattr(config, 'VERTEX_CREDENTIALS_POOL', [])
# Import phase-specific system prompts from prompts.py
from prompts import PHASE1_SYSTEM_PROMPT, PHASE2_SYSTEM_PROMPT, SYSTEM_PROMPT_BASE
# Legacy alias — kept for backward compatibility if any external code references it
SYSTEM_PROMPT = PHASE1_SYSTEM_PROMPT
# ════════════════════════════════════════════════════════════════════════
# Token Tracking & Cost Estimation
# ════════════════════════════════════════════════════════════════════════
def _get_pricing(model: str) -> Tuple[float, float, float]:
    """Return (input_price, output_price, thinking_price) in USD per 1M tokens.

    Pricing keys are matched against the model name as case-insensitive
    substrings (e.g. 'gemini-2.5-flash' matches
    'gemini-2.5-flash-preview-05-20'). When several keys match, the longest
    key wins, for both the input/output table and the thinking-token table,
    so a specific entry is never shadowed by a shorter, more generic one.

    Args:
        model: Model name; ``None`` or ``""`` yields the default prices.

    Returns:
        A 3-tuple of prices. The thinking price falls back to the output
        price when no thinking-specific entry matches.
    """
    model_lower = (model or "").lower()

    input_price = DEFAULT_INPUT_PRICE_PER_1M
    output_price = DEFAULT_OUTPUT_PRICE_PER_1M
    best_match = ""
    for pattern, (inp, outp) in TOKEN_PRICING.items():
        if pattern.lower() in model_lower and len(pattern) > len(best_match):
            best_match, input_price, output_price = pattern, inp, outp

    # Same longest-match rule as above, rather than "first key in dict order".
    thinking_price = None
    best_thinking_match = ""
    for pattern, price in THINKING_TOKEN_PRICING.items():
        if pattern.lower() in model_lower and len(pattern) > len(best_thinking_match):
            best_thinking_match, thinking_price = pattern, price

    if thinking_price is None:
        thinking_price = output_price
    return input_price, output_price, thinking_price
@dataclass
class TokenUsage:
    """Token counts reported for one LLM call."""

    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0
    # Reasoning / "thinking" tokens, which some APIs report separately
    thinking_tokens: int = 0
    # Cached prompt tokens, when the API reports them
    cached_tokens: int = 0

    def to_dict(self) -> Dict[str, int]:
        """Return the five counters as a plain dict keyed by field name."""
        names = (
            "prompt_tokens",
            "completion_tokens",
            "total_tokens",
            "thinking_tokens",
            "cached_tokens",
        )
        return {name: getattr(self, name) for name in names}

    @staticmethod
    def from_api_response(usage_data: Dict[str, Any]) -> "TokenUsage":
        """Build a TokenUsage from an API ``usage`` mapping.

        Recognised layouts:
        - prompt_tokens / completion_tokens / total_tokens
        - input_tokens / output_tokens (used when the names above are
          missing or zero)
        - reasoning_tokens or thinking_tokens, looked up first inside
          ``completion_tokens_details`` and then at the top level
        - cached_tokens, looked up first inside ``prompt_tokens_details``
          and then at the top level

        A falsy ``usage_data`` yields an all-zero TokenUsage. When
        ``total_tokens`` is missing or zero it is computed as
        prompt + completion.
        """
        if not usage_data:
            return TokenUsage()

        def first_truthy(source, *keys):
            # First truthy value among the keys, otherwise 0.
            for key in keys:
                value = source.get(key, 0)
                if value:
                    return value
            return 0

        prompt = first_truthy(usage_data, "prompt_tokens", "input_tokens")
        completion = first_truthy(usage_data, "completion_tokens", "output_tokens")
        total = usage_data.get("total_tokens", 0) or (prompt + completion)

        thinking = 0
        completion_details = usage_data.get("completion_tokens_details", {})
        if isinstance(completion_details, dict):
            thinking = first_truthy(completion_details, "reasoning_tokens", "thinking_tokens")
        if not thinking:
            thinking = first_truthy(usage_data, "reasoning_tokens", "thinking_tokens")

        cached = 0
        prompt_details = usage_data.get("prompt_tokens_details", {})
        if isinstance(prompt_details, dict):
            cached = first_truthy(prompt_details, "cached_tokens")
        if not cached:
            cached = first_truthy(usage_data, "cached_tokens")

        return TokenUsage(
            prompt_tokens=prompt,
            completion_tokens=completion,
            total_tokens=total,
            thinking_tokens=thinking,
            cached_tokens=cached,
        )
@dataclass
class TokenTracker:
    """Running token totals over the LLM calls made for one data entry.

    Sums prompt, completion, thinking and cached tokens, keeps a per-call
    record in ``call_details``, and can estimate the USD cost from the
    pricing returned by ``_get_pricing``.
    """

    model: str = ""
    total_prompt_tokens: int = 0
    total_completion_tokens: int = 0
    total_thinking_tokens: int = 0
    total_cached_tokens: int = 0
    num_calls: int = 0
    call_details: List[Dict[str, Any]] = field(default_factory=list)

    def add(self, usage: "TokenUsage", call_label: str = ""):
        """Fold the usage of one LLM call into the totals and log it."""
        self.total_prompt_tokens += usage.prompt_tokens
        self.total_completion_tokens += usage.completion_tokens
        self.total_thinking_tokens += usage.thinking_tokens
        self.total_cached_tokens += usage.cached_tokens
        self.num_calls += 1
        record = {"label": call_label}
        record.update(usage.to_dict())
        self.call_details.append(record)

    @property
    def total_tokens(self) -> int:
        """Prompt plus completion tokens accumulated so far."""
        return self.total_prompt_tokens + self.total_completion_tokens

    def estimate_cost(self, model: str = "") -> Dict[str, float]:
        """Estimate the USD cost of the accumulated usage.

        Args:
            model: Model name used for the pricing lookup; falls back to
                ``self.model`` when empty.

        Returns:
            Dict with input_cost_usd, output_cost_usd, thinking_cost_usd and
            total_cost_usd, each rounded to 6 decimals.
        """
        pricing_model = model or self.model
        input_price, output_price, thinking_price = _get_pricing(pricing_model)
        input_cost = self.total_prompt_tokens * input_price / 1_000_000
        # Thinking tokens are billed at their own rate, so they are taken out
        # of the completion count first (never letting it go below zero).
        plain_completion = max(0, self.total_completion_tokens - self.total_thinking_tokens)
        output_cost = plain_completion * output_price / 1_000_000
        thinking_cost = self.total_thinking_tokens * thinking_price / 1_000_000
        total_cost = input_cost + output_cost + thinking_cost
        return {
            "input_cost_usd": round(input_cost, 6),
            "output_cost_usd": round(output_cost, 6),
            "thinking_cost_usd": round(thinking_cost, 6),
            "total_cost_usd": round(total_cost, 6),
        }

    def to_dict(self, model: str = "") -> Dict[str, Any]:
        """Export totals, estimated cost and per-call details as a dict."""
        return {
            "num_llm_calls": self.num_calls,
            "total_prompt_tokens": self.total_prompt_tokens,
            "total_completion_tokens": self.total_completion_tokens,
            "total_thinking_tokens": self.total_thinking_tokens,
            "total_cached_tokens": self.total_cached_tokens,
            "total_tokens": self.total_tokens,
            "estimated_cost": self.estimate_cost(model),
            "call_details": self.call_details,
        }

    def summary_str(self, model: str = "") -> str:
        """Return a one-line, human-readable summary of usage and cost."""
        cost = self.estimate_cost(model)
        parts = [
            f"calls={self.num_calls}",
            f"prompt={self.total_prompt_tokens:,}",
            f"completion={self.total_completion_tokens:,}",
            f"thinking={self.total_thinking_tokens:,}",
            f"total={self.total_tokens:,}",
            f"cost=${cost['total_cost_usd']:.4f}",
        ]
        return " ".join(parts)
class GlobalTokenStats:
    """Aggregate token statistics across all processed entries.

    ``add`` serializes updates with an ``asyncio.Lock``; that coordinates
    coroutines, it is not an OS-thread lock.
    """

    def __init__(self, model: str = ""):
        self.model = model
        self.total_prompt_tokens = 0
        self.total_completion_tokens = 0
        self.total_thinking_tokens = 0
        self.total_cached_tokens = 0
        self.total_calls = 0
        self.total_entries = 0
        self._lock = asyncio.Lock()

    async def add(self, tracker: "TokenTracker"):
        """Fold one entry's tracker totals into the global counters."""
        async with self._lock:
            self.total_prompt_tokens += tracker.total_prompt_tokens
            self.total_completion_tokens += tracker.total_completion_tokens
            self.total_thinking_tokens += tracker.total_thinking_tokens
            self.total_cached_tokens += tracker.total_cached_tokens
            self.total_calls += tracker.num_calls
            self.total_entries += 1

    def estimate_cost(self) -> Dict[str, float]:
        """Estimate the USD cost of all accumulated usage (4-decimal rounding)."""
        input_price, output_price, thinking_price = _get_pricing(self.model)
        input_cost = self.total_prompt_tokens * input_price / 1_000_000
        # Thinking tokens are priced separately, so remove them from the
        # completion count (clamped at zero) before applying the output price.
        plain_completion = max(0, self.total_completion_tokens - self.total_thinking_tokens)
        output_cost = plain_completion * output_price / 1_000_000
        thinking_cost = self.total_thinking_tokens * thinking_price / 1_000_000
        total_cost = input_cost + output_cost + thinking_cost
        return {
            "input_cost_usd": round(input_cost, 4),
            "output_cost_usd": round(output_cost, 4),
            "thinking_cost_usd": round(thinking_cost, 4),
            "total_cost_usd": round(total_cost, 4),
        }

    def summary_str(self) -> str:
        """Return a multi-line, human-readable usage and cost report."""
        cost = self.estimate_cost()
        entries = max(1, self.total_entries)  # avoid dividing by zero
        combined_tokens = self.total_prompt_tokens + self.total_completion_tokens
        avg_per_entry = combined_tokens / entries
        avg_cost = cost["total_cost_usd"] / entries
        rule = "=" * 72
        divider = " ──────────────────────────────────────"
        lines = [
            "",
            rule,
            " Token Usage Summary",
            rule,
            f" Model : {self.model}",
            f" Total entries : {self.total_entries}",
            f" Total LLM calls : {self.total_calls}",
            f" Total prompt tok : {self.total_prompt_tokens:,}",
            f" Total completion : {self.total_completion_tokens:,}",
            f" Total thinking : {self.total_thinking_tokens:,}",
            f" Total cached : {self.total_cached_tokens:,}",
            f" Total tokens : {combined_tokens:,}",
            divider,
            f" Avg tokens/entry : {avg_per_entry:,.0f}",
            f" Avg cost/entry : ${avg_cost:.4f}",
            divider,
            f" Input cost : ${cost['input_cost_usd']:.4f}",
            f" Output cost : ${cost['output_cost_usd']:.4f}",
            f" Thinking cost : ${cost['thinking_cost_usd']:.4f}",
            f" TOTAL COST : ${cost['total_cost_usd']:.4f}",
            rule,
        ]
        return "\n".join(lines)

    def to_dict(self) -> Dict[str, Any]:
        """Export the global totals, per-entry averages and estimated cost."""
        cost = self.estimate_cost()
        entries = max(1, self.total_entries)  # avoid dividing by zero
        combined_tokens = self.total_prompt_tokens + self.total_completion_tokens
        return {
            "model": self.model,
            "total_entries": self.total_entries,
            "total_llm_calls": self.total_calls,
            "total_prompt_tokens": self.total_prompt_tokens,
            "total_completion_tokens": self.total_completion_tokens,
            "total_thinking_tokens": self.total_thinking_tokens,
            "total_cached_tokens": self.total_cached_tokens,
            "total_tokens": combined_tokens,
            "avg_tokens_per_entry": round(combined_tokens / entries),
            "avg_cost_per_entry_usd": round(cost["total_cost_usd"] / entries, 6),
            "estimated_cost": cost,
        }
# ════════════════════════════════════════════════════════════════════════
# Think Validation Utilities
# ════════════════════════════════════════════════════════════════════════
def think_is_nonempty(text: str) -> bool:
    """Return True when a think block carries at least minimal content.

    Deliberately lenient: there is no length or word-count requirement beyond
    ruling out a blank block or one made only of whitespace/punctuation.

    Args:
        text: Think content, with or without the surrounding
            ``<think>`` / ``</think>`` tags.

    Returns:
        True if at least 5 ASCII letters/digits remain once the tags are
        removed, otherwise False.
    """
    if not text:
        return False
    # `</?think>` matches both the opening and the closing tag.
    body = re.sub(r'</?think>', '', text).strip()
    ascii_alnum = re.sub(r'[^A-Za-z0-9]', '', body)
    return len(ascii_alnum) >= 5
def extract_think_text(text: str) -> str:
    """Return the stripped content of the first ``<think>...</think>`` block.

    Returns an empty string when *text* is empty or contains no complete
    think block.
    """
    if not text:
        return ''
    # DOTALL lets the think body span several lines.
    match = re.search(r'<think>(.*?)</think>', text, re.DOTALL)
    return match.group(1).strip() if match else ''
def _dedup_think_content(text: str) -> str:
    """Remove verbatim repetition from the body of a think block.

    Two passes are tried in order:
    1. Split on blank lines and keep only the first occurrence of each
       paragraph (paragraphs are compared after collapsing whitespace).
    2. If that removed nothing, look for the whole content repeated twice
       with only a newline between the copies, and keep one copy.

    Returns the de-duplicated text, or *text* unchanged when no repetition
    is found (including blank input).
    """
    if not text or not text.strip():
        return text
    body = text.strip()

    # Pass 1: paragraph-level duplicates.
    chunks = re.split(r'\n\s*\n', body)
    if len(chunks) >= 2:
        kept = []
        fingerprints = set()
        for chunk in chunks:
            fingerprint = ' '.join(chunk.split())
            if not fingerprint or fingerprint in fingerprints:
                continue
            fingerprints.add(fingerprint)
            kept.append(chunk)
        if len(kept) < len(chunks):
            return '\n\n'.join(kept)

    # Pass 2: "ABC\nABC" style duplication; skipped for short blocks.
    size = len(body)
    if size < 80:
        return text
    half = size // 2
    # Probe newline positions within 40 characters of the midpoint.
    for delta in range(min(40, half)):
        for cut in (half + delta, half - delta):
            if not 0 < cut < size:
                continue
            if body[cut] != '\n':
                continue
            left = body[:cut].strip()
            right = body[cut:].strip()
            if left == right:
                return left
    return text
# ════════════════════════════════════════════════════════════════════════
# Hallucination Detection & Sanitization
# ════════════════════════════════════════════════════════════════════════
def sanitize_llm_output(text: str) -> str:
    """Cut LLM output off right after its first action block.

    Removes content the model should never produce in a single turn:
    - anything from the first ``<tool_response>`` onwards (tool responses
      must only come from the system)
    - a ``---`` separator followed by a ``MODERATION:`` block
    - everything after the first ``<tool_call>...</tool_call>`` or
      ``<answer>...</answer>`` block, whichever starts earlier

    Returns:
        The cleaned text: optional leading content plus at most one action.
        Blank input is returned unchanged; text without any complete action
        block is returned after the first two clean-ups only.
    """
    if not text or not text.strip():
        return text
    text = text.strip()

    response_start = text.find('<tool_response>')
    if response_start != -1:
        text = text[:response_start].strip()

    text = re.sub(
        r'\n*---+\s*\n*MODERATION:.*', '', text,
        flags=re.DOTALL | re.IGNORECASE,
    ).strip()

    # The first match of each pattern is its earliest occurrence, so the
    # earliest of the two is the first action in the text.
    candidates = (
        re.search(r'<tool_call>.*?</tool_call>', text, re.DOTALL),
        re.search(r'<answer>.*?</answer>', text, re.DOTALL),
    )
    first_actions = [m for m in candidates if m]
    if not first_actions:
        return text
    earliest = min(first_actions, key=lambda m: m.start())
    return text[:earliest.end()].strip()
def is_hallucinated_output(text: str) -> bool:
    """Report whether LLM output contains hallucination markers.

    Returns True when the text contains any of:
    - more than one ``<tool_call>`` opening tag
    - a ``<tool_response>`` tag (these should only come from the system)
    - a ``MODERATION:`` marker (case-insensitive)
    - both ``<tool_call>`` and ``<answer>`` in the same turn

    Empty input returns False.
    """
    if not text:
        return False
    tool_call_count = text.count('<tool_call>')
    if tool_call_count > 1 or '<tool_response>' in text:
        return True
    if re.search(r'MODERATION:', text, re.IGNORECASE):
        return True
    return tool_call_count == 1 and '<answer>' in text
# ════════════════════════════════════════════════════════════════════════
# GPT Output Normalizer
# ════════════════════════════════════════════════════════════════════════
def _keep_first_occurrence(text: str, tag: str) -> str:
    """Remove every occurrence of *tag* from *text* except the first one."""
    head, found, tail = text.partition(tag)
    if not found:
        return text
    return head + found + tail.replace(tag, '')


def normalize_gpt_output(text: str) -> str:
    """Bring a gpt turn into the format ``<think>...</think>`` + one action.

    The action is ``<tool_call>...</tool_call>`` or ``<answer>...</answer>``.
    This function does not truncate extra actions; it only repairs the think
    part. Steps:
    1. Collapse repeated ``</think>`` and ``<think>`` tags to the first of each.
    2. If a ``<think>`` tag exists: close it when unclosed (before the first
       ``<tool_call>``, else before the first ``<answer>``, else at the end),
       then remove repeated paragraphs inside the think block.
    3. If no ``<think>`` tag exists: wrap the text preceding the action in a
       think block when ``think_is_nonempty`` accepts it; otherwise return the
       action alone.
    4. If there is no action at all, wrap the whole text in a think block.

    No filler is ever invented: an empty or missing think block is left as is.
    Blank input is returned unchanged.
    """
    if not text or not text.strip():
        return text
    text = text.strip()

    text = _keep_first_occurrence(text, '</think>')
    text = _keep_first_occurrence(text, '<think>')

    # A tool call takes priority over an answer as the anchoring action,
    # even if the answer appears earlier in the text.
    action_pos = text.find('<tool_call>')
    if action_pos == -1:
        action_pos = text.find('<answer>')

    if '<think>' in text:
        if '</think>' not in text:
            if action_pos != -1:
                text = text[:action_pos] + '\n</think>\n' + text[action_pos:]
            else:
                text = text + '</think>'
        think_text = extract_think_text(text)
        if think_text:
            deduped = _dedup_think_content(think_text)
            if deduped != think_text:
                text = text.replace(think_text, deduped, 1)
        # An empty think block is returned as is; no filler is added.
        return text

    if action_pos == -1:
        # Neither think nor action: wrap everything and let downstream
        # filters decide whether to drop the turn.
        return f"<think>{text}</think>"

    pre_text = text[:action_pos].strip()
    action_and_after = text[action_pos:]
    if pre_text and think_is_nonempty(pre_text):
        return f"<think>{pre_text}</think>\n\n{action_and_after}"
    # No usable leading text: return the bare action.
    return action_and_after
# ════════════════════════════════════════════════════════════════════════
# Training Data Conversation Cleaning
# ════════════════════════════════════════════════════════════════════════
def clean_conversations(conversations: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Strip error-recovery turns from a conversation for SFT training.

    Each turn is a dict with a ``from`` role and a ``value`` string:
    - gpt turns that are empty, or contain neither ``<tool_call>`` nor
      ``<answer>``, are dropped
    - remaining gpt turns are re-normalized with ``normalize_gpt_output``
    - human turns recognised by ``_is_error_tool_response`` are dropped
    - all other turns (system, normal human turns) are kept unchanged

    Every gpt turn is validated on its own, including one that directly
    follows a dropped gpt turn.

    Returns:
        A new list of cleaned turns; a falsy input is returned as is.
    """
    if not conversations:
        return conversations

    cleaned = []
    for turn in conversations:
        role = turn.get("from", "")
        value = turn.get("value", "")

        if role == "gpt":
            stripped = value.strip() if value else ""
            # An empty turn has no action either, so one check covers both.
            if '<tool_call>' not in stripped and '<answer>' not in stripped:
                continue
            cleaned.append({"from": "gpt", "value": normalize_gpt_output(stripped)})
        elif role == "human":
            if _is_error_tool_response(value):
                continue
            cleaned.append(turn)
        else:
            # system or any other role: keep as is
            cleaned.append(turn)
    return cleaned
def _is_error_tool_response(value: str) -> bool:
    """Return True when a human turn is an error tool response to remove.

    A turn counts as an error response when it contains any of the known
    error-message fragments listed below. Empty input returns False.
    """
    if not value:
        return False
    error_markers = (
        "No <tool_call> or <answer> found",
        "Error: Tool",
        "Malformed tool_call JSON",
        "Error: bbox must have exactly",
        "Error: web_search requires a non-empty query",
    )
    stripped = value.strip()
    return any(marker in stripped for marker in error_markers)
# ════════════════════════════════════════════════════════════════════════
# Frame Utilities
# ════════════════════════════════════════════════════════════════════════
def get_video_frame_count(video_path: str) -> int:
    """Return the frame count of a 1fps video, taken as ceil(duration).

    The duration in seconds is read with ``ffprobe``. Any failure (missing
    binary, timeout, unparsable output) yields -1.
    """
    probe_cmd = [
        "ffprobe", "-v", "error", "-show_entries", "format=duration",
        "-of", "csv=p=0", video_path,
    ]
    try:
        completed = subprocess.run(probe_cmd, capture_output=True, text=True, timeout=30)
        duration_seconds = float(completed.stdout.strip())
        return int(math.ceil(duration_seconds))
    except Exception:
        # Best effort: callers receive -1 for any failure.
        return -1
def _collect_frame_files(output_dir: str) -> Dict[int, str]:
    """Map frame index -> path for each ``frame_<digits>.jpg`` in *output_dir*."""
    found = {}
    for name in sorted(os.listdir(output_dir)):
        # Names that merely start with "frame_" but carry no numeric index
        # are ignored instead of crashing int().
        match = re.fullmatch(r'frame_(\d+)\.jpg', name)
        if match:
            found[int(match.group(1))] = os.path.join(output_dir, name)
    return found


def extract_all_frames(video_path: str, output_dir: str,
                       max_resolution: int = 768,
                       jpeg_quality: int = 85) -> Dict[int, str]:
    """Extract every frame of a video into *output_dir* as JPEG files.

    Frames are written as ``frame_%06d.jpg`` starting at index 0. If the
    directory already contains frame files, they are returned as is and
    ffmpeg is not run.

    Args:
        video_path: Source video passed to ffmpeg.
        output_dir: Directory for the frames (created when missing).
        max_resolution: Upper bound applied to both width and height; the
            aspect ratio is preserved and frames are never upscaled.
        jpeg_quality: 0-100 quality, mapped linearly onto ffmpeg's ``-q:v``
            scale (2 = best, 31 = worst).

    Returns:
        Dict mapping 0-based frame index to file path.

    Raises:
        RuntimeError: If ffmpeg exits non-zero or produces no frames.
        subprocess.TimeoutExpired: If ffmpeg runs longer than 600 seconds.
    """
    os.makedirs(output_dir, exist_ok=True)

    existing = _collect_frame_files(output_dir)
    if existing:
        return existing

    vf = (f"scale='min({max_resolution},iw)':'min({max_resolution},ih)'"
          f":force_original_aspect_ratio=decrease")
    q = max(2, min(31, round(2 + (100 - jpeg_quality) * 29 / 99)))
    cmd = [
        "ffmpeg", "-y", "-i", video_path,
        "-vf", vf, "-q:v", str(q),
        "-start_number", "0",
        os.path.join(output_dir, "frame_%06d.jpg"),
    ]
    proc = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
    if proc.returncode != 0:
        raise RuntimeError(f"ffmpeg failed: {proc.stderr[-500:]}")

    frames = _collect_frame_files(output_dir)
    if not frames:
        raise RuntimeError(f"No frames extracted from {video_path}")
    return frames
def uniform_sample_indices(total_frames: int, num_samples: int) -> List[int]:
    """Pick evenly spaced 0-based frame indices.

    Args:
        total_frames: Number of frames available (a count, not a max index).
        num_samples: Number of indices wanted.

    Returns:
        Sorted unique indices. When there are no more frames than requested
        samples, every index is returned.
    """
    if num_samples >= total_frames:
        return list(range(total_frames))
    positions = np.linspace(0, total_frames - 1, num_samples)
    # Truncating to int can produce duplicates, so collapse them via a set.
    return sorted({int(pos) for pos in positions})
def sample_interval(all_frames: Dict[int, str], start: int, end: int,
                    num_samples: int = 8) -> List[Tuple[int, str]]:
    """Evenly sample frames whose index lies in the closed range [start, end].

    Returns:
        ``(index, path)`` pairs in ascending index order: every frame in the
        range when there are at most *num_samples* of them, otherwise an
        evenly spaced subset. Empty when no frame falls in the range.
    """
    in_range = sorted(idx for idx in all_frames if start <= idx <= end)
    if not in_range:
        return []
    if len(in_range) > num_samples:
        picks = np.linspace(0, len(in_range) - 1, num_samples, dtype=int)
        in_range = sorted({in_range[pos] for pos in picks})
    return [(idx, all_frames[idx]) for idx in in_range]
def get_frame(all_frames: Dict[int, str], idx: int) -> Tuple[int, str]:
    """Return ``(index, path)`` for frame *idx*, or for the nearest index.

    When *idx* is missing, the available index with the smallest absolute
    distance to it is used; on a tie the one encountered first in the dict's
    iteration order wins.
    """
    if idx in all_frames:
        target = idx
    else:
        target = min(all_frames, key=lambda candidate: abs(candidate - idx))
    return (target, all_frames[target])
# ════════════════════════════════════════════════════════════════════════
# Bbox Format Detection & Normalization
# ════════════════════════════════════════════════════════════════════════
def get_bbox_config(model: str) -> dict:
    """Return the bbox format configuration for a model name.

    Keys of ``BBOX_CONFIGS`` other than ``"default"`` are matched against the
    model name as case-insensitive substrings; the longest matching key wins.
    Without a match, the ``"default"`` entry is returned, or
    ``{"order": "xyxy", "range": "norm"}`` when that entry is absent.

    Returns:
        A dict of the form
        {"order": "xyxy"|"yxyx", "range": "norm"|"permille"}.
    """
    needle = model.lower()
    fallback = BBOX_CONFIGS.get("default", {"order": "xyxy", "range": "norm"})
    matching = [
        pattern for pattern in BBOX_CONFIGS
        if pattern and pattern != "default" and pattern.lower() in needle
    ]
    if not matching:
        return fallback
    # max() keeps the first of equally long keys, matching dict order.
    return BBOX_CONFIGS[max(matching, key=len)]
def _ensure_min_extent(lo: float, hi: float) -> Tuple[float, float]:
    """Widen a degenerate [lo, hi] span inside [0, 1] to about 0.01."""
    if hi - lo < 0.001:
        hi = min(1.0, lo + 0.01)
        if hi - lo < 0.001:
            # Pinned against the upper edge: grow towards the lower side.
            lo = max(0.0, hi - 0.01)
    return lo, hi


def normalize_bbox(raw_bbox: list, bbox_config: dict) -> list:
    """Convert a model-produced bbox to [x1, y1, x2, y2] within [0.0, 1.0].

    Supported inputs:
    - Gemini style: [y_min, x_min, y_max, x_max] in [0, 1000]
    - Standard: [x1, y1, x2, y2] in [0.0, 1.0]
    - Values up to 1000 without a "permille" config are auto-detected as
      permille.
    - Values above 1000 (absolute pixels) cannot be normalized here because
      the image size is unknown; they are simply clamped to 1.0.

    Args:
        raw_bbox: Sequence of 4 numbers (numeric strings are accepted).
        bbox_config: Dict as returned by ``get_bbox_config``; the ``order``
            and ``range`` keys are read.

    Returns:
        [x1, y1, x2, y2] with x1 <= x2 and y1 <= y2, each span widened to
        about 0.01 when degenerate. The full frame [0.0, 0.0, 1.0, 1.0] is
        returned when the input does not hold exactly 4 numeric values.
    """
    full_frame = [0.0, 0.0, 1.0, 1.0]
    if len(raw_bbox) != 4:
        return full_frame
    try:
        coords = [float(value) for value in raw_bbox]
    except (TypeError, ValueError):
        return full_frame

    # Step 1: resolve the coordinate order.
    if bbox_config.get("order") == "yxyx":
        y1, x1, y2, x2 = coords
    else:
        x1, y1, x2, y2 = coords

    # Step 2: bring the value range to [0.0, 1.0].
    if bbox_config.get("range") == "permille":
        divisor = 1000.0
    else:
        max_coord = max(abs(x1), abs(y1), abs(x2), abs(y2))
        # Values in (1, 1000] are treated as permille; anything else is
        # left unscaled.
        divisor = 1000.0 if 1.0 < max_coord <= 1000 else 1.0
    x1, y1, x2, y2 = (value / divisor for value in (x1, y1, x2, y2))

    # Step 3: clamp to [0.0, 1.0].
    x1, y1, x2, y2 = (max(0.0, min(1.0, value)) for value in (x1, y1, x2, y2))

    # Step 4: order each axis so that lo <= hi.
    if x1 > x2:
        x1, x2 = x2, x1
    if y1 > y2:
        y1, y2 = y2, y1

    # Step 5: guarantee a minimum size, also at the right/bottom edge.
    x1, x2 = _ensure_min_extent(x1, x2)
    y1, y2 = _ensure_min_extent(y1, y2)
    return [x1, y1, x2, y2]
# ════════════════════════════════════════════════════════════════════════
# Image Search Failure Handling & Padding
# ════════════════════════════════════════════════════════════════════════
class RetrieverDownError(Exception):
    """The retriever service (IP/port) could not be reached.

    The whole pipeline should stop immediately.
    """
class ImageSearchFailedError(Exception):
    """image_search returned an error, timed out, or found nothing.

    Signals that the whole entry should be retried from scratch, so that
    failed Serper results never end up in the SFT dataset.
    """
class QuotaExhaustedError(Exception):
    """The API quota is genuinely exhausted (not a temporary rate limit).

    The whole pipeline should stop immediately.
    """
class ProhibitedContentError(Exception):
    """Content was blocked by Vertex AI safety policy (PROHIBITED_CONTENT).

    The request should not be retried.
    """
class ProjectDisabledError(Exception):
    """A Vertex project is unavailable.

    It should be disabled in the project pool for the current run.
    """
def _is_quota_exhausted(status_code: int, error_body: str) -> bool:
    """Tell real quota exhaustion apart from a temporary rate limit.

    Policy implemented here:
    - HTTP 429 is always treated as a transient rate limit (retry with
      backoff), never as quota exhaustion.
    - HTTP 403 counts as quota exhaustion only when the body contains one of
      the quota/billing keywords below (case-insensitive).
    - Every other status code returns False.

    Reference: https://docs.cloud.google.com/docs/quotas/troubleshoot
    """
    body_lower = error_body.lower()
    if status_code != 403:
        return False
    quota_signals = (
        "quota_exceeded",
        "quota exceeded",
        "rate_limit_exceeded",
        "billing account",
        "billing is disabled",
        "insufficient quota",
        "out of quota",
        "daily limit",
        "per-day limit",
    )
    for signal in quota_signals:
        if signal in body_lower:
            return True
    return False
def is_image_search_failed(result_text: str) -> bool:
    """Return True when an image-search result text indicates failure.

    Blank text counts as a failure, as does text containing any of the
    known error, timeout or no-result messages listed below.
    """
    if not result_text or not result_text.strip():
        return True
    failure_markers = (
        "Image search error:",
        "No results found from reverse image search",
        "Error: SERPER_API_KEY not configured",
        "request timed out",
    )
    return any(marker in result_text for marker in failure_markers)
def add_search_padding(bbox: List[float], frame_path: str,
                       padding: tuple = (0.5, 0.5),
                       padding_cap_px: int = 600) -> List[float]:
    """Expand a normalized [x1, y1, x2, y2] bbox for image search.

    Padding scales with the bbox size rather than the image size, so a tight
    crop gains moderate context while a large crop does not balloon to the
    whole frame.

    Args:
        bbox: [x1, y1, x2, y2] normalized to [0.0, 1.0].
        frame_path: Source frame; only its pixel size is read, to convert
            the pixel cap into normalized units.
        padding: (pad_x_ratio, pad_y_ratio) as fractions of the bbox width
            and height added on each side.
        padding_cap_px: Maximum padding per side, in pixels.

    Returns:
        The padded [x1, y1, x2, y2], clamped to [0.0, 1.0].
    """
    left, top, right, bottom = bbox
    ratio_x, ratio_y = padding

    grow_x = (right - left) * ratio_x
    grow_y = (bottom - top) * ratio_y

    try:
        with Image.open(frame_path) as frame:
            frame_w, frame_h = frame.size
            limit_x = padding_cap_px / frame_w
            limit_y = padding_cap_px / frame_h
            grow_x = min(grow_x, limit_x)
            grow_y = min(grow_y, limit_y)
    except Exception:
        # Best effort: without readable image dimensions the pixel cap is
        # skipped and the bbox-proportional padding is used as is.
        pass

    return [
        max(0.0, left - grow_x),
        max(0.0, top - grow_y),
        min(1.0, right + grow_x),
        min(1.0, bottom + grow_y),
    ]
def crop_frame(frame_path: str, bbox: List[float], output_path: str) -> str:
    """Crop a frame at *bbox*, upscale the crop 2x and save it as JPEG.

    The bbox format is auto-detected from the largest absolute coordinate:
    - <= 1.0   -> normalized relative coordinates
    - <= 1000  -> permille relative coordinates
    - > 1000   -> absolute pixel coordinates

    The crop is clamped to the image bounds and is at least 1 pixel wide and
    high. It is upscaled 2x with LANCZOS to help visual search recognise
    small objects, then saved at quality 95 to limit compression artifacts.

    Args:
        frame_path: Path of the source frame image.
        bbox: [x1, y1, x2, y2] in any of the three formats above.
        output_path: Destination path of the cropped JPEG.

    Returns:
        *output_path*.
    """
    with Image.open(frame_path) as img:
        w, h = img.size
        raw_x1, raw_y1, raw_x2, raw_y2 = bbox

        # Step 1: detect the coordinate format and convert to pixels.
        max_coord = max(abs(raw_x1), abs(raw_y1), abs(raw_x2), abs(raw_y2))
        if max_coord <= 1.0:
            # Normalized [0.0, 1.0]
            px_x1 = raw_x1 * w
            px_y1 = raw_y1 * h
            px_x2 = raw_x2 * w
            px_y2 = raw_y2 * h
        elif max_coord <= 1000:
            # Permille [0, 1000]
            px_x1 = raw_x1 / 1000.0 * w
            px_y1 = raw_y1 / 1000.0 * h
            px_x2 = raw_x2 / 1000.0 * w
            px_y2 = raw_y2 / 1000.0 * h
        else:
            # Absolute pixels
            px_x1, px_y1, px_x2, px_y2 = raw_x1, raw_y1, raw_x2, raw_y2

        # Step 2: clamp to the image bounds.
        x1 = max(0, min(int(round(px_x1)), w - 1))
        y1 = max(0, min(int(round(px_y1)), h - 1))
        x2 = max(0, min(int(round(px_x2)), w))
        y2 = max(0, min(int(round(px_y2)), h))
        # Keep at least 1 pixel in each direction (no zero-area crop).
        if x2 <= x1:
            x2 = min(x1 + 1, w)
        if y2 <= y1:
            y2 = min(y1 + 1, h)

        # Step 3: crop.
        cropped_img = img.crop((x1, y1, x2, y2))

        # The JPEG writer rejects modes such as RGBA, P or LA, so convert
        # those to RGB. Modes it already accepts are left untouched.
        if cropped_img.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
            cropped_img = cropped_img.convert("RGB")

        # Step 4: 2x LANCZOS upscale.
        cropped_img = cropped_img.resize(
            (cropped_img.width * 2, cropped_img.height * 2),
            Image.Resampling.LANCZOS,
        )

        # Step 5: save as high-quality JPEG.
        cropped_img.save(output_path, 'JPEG', quality=95)
    return output_path
def encode_image_b64(path: str) -> str:
    """Read the file at ``path`` and return its bytes as an ASCII base64 string."""
    with open(path, "rb") as image_file:
        raw_bytes = image_file.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("ascii")
# ════════════════════════════════════════════════════════════════════════
# LLM Client with Multi-API Load Balancing (v3 — Token Tracking)
# ════════════════════════════════════════════════════════════════════════
class LLMClient:
    """Async LLM client with a project pool and failover across API endpoints.

    v3: call() now returns (content, reasoning_content, TokenUsage).
    v4: Google Vertex AI authentication support (OAuth2 auto-refresh).
    v5: Vertex AI Native generateContent API support (for global region).
    """

    # Matches a stand-alone "429" but not one embedded in a longer number
    # (project numbers in URLs, token counts in error bodies, ...).
    _RATE_LIMIT_CODE_RE = re.compile(r"(?<!\d)429(?!\d)")

    def __init__(self, api_urls: List[str], model: str,
                 concurrency_per_url: int = 2,
                 temperature: float = DEFAULT_TEMPERATURE,
                 top_p: float = DEFAULT_TOP_P,
                 max_tokens: int = DEFAULT_MAX_TOKENS,
                 timeout: int = DEFAULT_TIMEOUT,
                 max_retries: int = DEFAULT_MAX_RETRIES,
                 api_key: str = API_KEY):
        """Store sampling defaults and set up per-URL state.

        Args:
            api_urls: endpoint URLs forming the project pool.
            model: model name sent with every request.
            concurrency_per_url: size of the semaphore created per URL.
            temperature, top_p, max_tokens: default generation parameters.
            timeout: total per-request timeout in seconds.
            max_retries: attempts per endpoint inside the low-level call.
            api_key: bearer token used when Vertex auth is not active.

        Raises:
            ImportError / RuntimeError: from ``_init_gcp_auth`` when a Vertex
                endpoint is configured but credentials cannot be set up.
        """
        self.api_urls = api_urls
        self.model = model
        self.temperature = temperature
        self.top_p = top_p
        self.max_tokens = max_tokens
        self.timeout = timeout
        self.max_retries = max_retries
        self.api_key = api_key

        self._url_cycle = itertools.cycle(api_urls)
        self._semaphores = {url: asyncio.Semaphore(concurrency_per_url)
                            for url in api_urls}
        self._cycle_lock = asyncio.Lock()

        # ── Project pool: switch automatically on 429 ──
        self._active_url_index = 0              # index of the currently active URL
        self._url_switch_lock = asyncio.Lock()  # guards pool state changes
        self._project_cooldown_until = {url: 0.0 for url in api_urls}
        self._disabled_urls: Set[str] = set()

        # ── Vertex request shaping: reduce 429s caused by per-second bursts ──
        self._vertex_min_interval = max(0.0, float(VERTEX_MIN_REQUEST_INTERVAL_SECONDS or 0.0))
        self._vertex_rate_limit_cooldown = max(0.0, float(VERTEX_RATE_LIMIT_COOLDOWN_SECONDS or 0.0))
        self._vertex_request_jitter = max(0.0, float(VERTEX_REQUEST_JITTER_SECONDS or 0.0))
        self._vertex_next_request_at = 0.0
        self._vertex_pacing_lock = asyncio.Lock()

        # ── Detect API mode ──
        self._use_vertex_auth = self._detect_vertex_endpoint()
        self._use_vertex_native = self._detect_vertex_native()
        self._gcp_credentials = None
        self._gcp_credentials_by_url: Dict[str, Any] = {}
        self._gcp_auth_request = None
        if self._use_vertex_auth:
            self._init_gcp_auth()

    def _detect_vertex_endpoint(self) -> bool:
        """Return True if any configured URL points at a Vertex AI endpoint."""
        return any("aiplatform.googleapis.com" in url for url in self.api_urls)

    def _detect_vertex_native(self) -> bool:
        """Return True if the Vertex AI native generateContent API is used.

        Native mode is selected when a URL contains
        ``/publishers/google/models`` (as opposed to the OpenAI-compatible
        ``/endpoints/openapi/chat/completions``), or when the config module
        sets a truthy ``USE_VERTEX_NATIVE_API`` flag.
        """
        if any("/publishers/google/models" in url for url in self.api_urls):
            return True
        return bool(getattr(config, 'USE_VERTEX_NATIVE_API', False))

    def _init_gcp_auth(self):
        """Initialise Google Cloud credentials.

        Credentials are resolved in this order:
        1. entries of ``VERTEX_CREDENTIALS_POOL`` whose ``api_url`` is one of
           this client's URLs (one credential per URL);
        2. the service-account key file ``GCP_SERVICE_ACCOUNT_KEY``;
        3. Application Default Credentials.

        Every credential is refreshed once up front to verify it works.

        Raises:
            ImportError: if the google-auth library is not installed.
            RuntimeError: if any credential cannot be loaded or refreshed.
        """
        try:
            import google.auth
            import google.auth.transport.requests
            from google.oauth2 import service_account as sa_module

            scopes = ["https://www.googleapis.com/auth/cloud-platform"]
            self._gcp_auth_request = google.auth.transport.requests.Request()

            pool_records = [
                item for item in VERTEX_CREDENTIALS_POOL
                if isinstance(item, dict) and item.get("api_url") in self.api_urls
            ]

            if pool_records:
                for item in pool_records:
                    key_path = item.get("service_account_key") or ""
                    key_info = item.get("service_account_info")
                    if key_path and os.path.exists(key_path):
                        credential = sa_module.Credentials.from_service_account_file(
                            key_path, scopes=scopes,
                        )
                        print(
                            f"[GCP AUTH] 使用账号池密钥: account={item.get('account_name', '')} "
                            f"project={item.get('project_id', '')} key={key_path}"
                        )
                    elif isinstance(key_info, dict):
                        credential = sa_module.Credentials.from_service_account_info(
                            key_info, scopes=scopes,
                        )
                        print(
                            f"[GCP AUTH] 使用账号池内嵌密钥: account={item.get('account_name', '')} "
                            f"project={item.get('project_id', '')}"
                        )
                    else:
                        raise RuntimeError(
                            f"Vertex 账号池条目缺少可用凭证: project={item.get('project_id', '')}"
                        )
                    credential.refresh(self._gcp_auth_request)
                    self._gcp_credentials_by_url[item["api_url"]] = credential
                print(f"[GCP AUTH] 账号池认证成功,可用条目数: {len(self._gcp_credentials_by_url)}")
                return

            if GCP_SERVICE_ACCOUNT_KEY and os.path.exists(GCP_SERVICE_ACCOUNT_KEY):
                # Use the service-account key file.
                self._gcp_credentials = sa_module.Credentials.from_service_account_file(
                    GCP_SERVICE_ACCOUNT_KEY, scopes=scopes,
                )
                print(f"[GCP AUTH] 使用服务账号密钥: {GCP_SERVICE_ACCOUNT_KEY}")
            else:
                # Use Application Default Credentials (ADC); requires a prior
                # `gcloud auth application-default login`.
                self._gcp_credentials, project = google.auth.default(scopes=scopes)
                print(f"[GCP AUTH] 使用 ADC (Application Default Credentials), project={project}")

            # Refresh once up front to confirm the credential is valid.
            self._gcp_credentials.refresh(self._gcp_auth_request)
            print(f"[GCP AUTH] 认证成功,token 有效期至 {self._gcp_credentials.expiry}")
        except ImportError as e:
            raise ImportError(
                "使用 Vertex AI 需要安装 google-auth 库:\n"
                "  pip install google-auth google-auth-httplib2"
            ) from e
        except Exception as e:
            # Chain the cause so the original traceback stays visible.
            raise RuntimeError(
                f"Google Cloud 认证失败: {e}\n"
                f"请确保已配置 GCP_SERVICE_ACCOUNT_KEY 或运行 "
                f"`gcloud auth application-default login`"
            ) from e

    def _get_gcp_token(self, api_url: Optional[str] = None) -> str:
        """Return a valid GCP OAuth2 access token, refreshing it if expired.

        The per-URL credential is preferred when one exists for ``api_url``;
        otherwise the client-wide credential is used.

        NOTE(review): ``credentials.refresh`` is a synchronous call made from
        async code paths, so it presumably blocks the event loop while it runs.

        Raises:
            RuntimeError: if neither a per-URL nor a client-wide credential
                is available.
        """
        credentials = self._gcp_credentials_by_url.get(api_url) if api_url else None
        if credentials is None:
            credentials = self._gcp_credentials
        if credentials is None:
            # Previously surfaced as an opaque AttributeError on None.
            raise RuntimeError(
                f"No GCP credential available for endpoint {api_url!r}"
            )
        if credentials.expired or not credentials.token:
            credentials.refresh(self._gcp_auth_request)
        return credentials.token

    async def _next_url(self) -> str:
        """Return the next URL from the round-robin cycle."""
        async with self._cycle_lock:
            return next(self._url_cycle)

    def _get_active_url(self) -> str:
        """Return the currently active project URL (changes after a 429 switch)."""
        return self.api_urls[self._active_url_index]

    async def _acquire_vertex_request_slot(self):
        """Throttle Vertex requests client-wide to avoid per-second bursts.

        Each caller reserves the next send time (minimum interval plus random
        jitter); callers arriving before that time sleep and try again.
        No-op when Vertex auth is not in use or the interval is zero.
        """
        if not self._use_vertex_auth or self._vertex_min_interval <= 0:
            return
        while True:
            async with self._vertex_pacing_lock:
                now = time.monotonic()
                wait = self._vertex_next_request_at - now
                if wait <= 0:
                    reserve = self._vertex_min_interval + random.uniform(0, self._vertex_request_jitter)
                    self._vertex_next_request_at = now + reserve
                    return
            if wait > 0.5:
                print(f" [VERTEX PACE] waiting {wait:.1f}s before next request")
            await asyncio.sleep(max(wait, 0.05))

    async def _mark_project_cooldown(self, failed_url: str):
        """Put a project that returned 429 into cooldown for a short while."""
        if self._vertex_rate_limit_cooldown <= 0:
            return
        cooldown = self._vertex_rate_limit_cooldown + random.uniform(0, self._vertex_request_jitter)
        async with self._url_switch_lock:
            until_ts = time.monotonic() + cooldown
            prev = self._project_cooldown_until.get(failed_url, 0.0)
            # Never shorten a cooldown that is already in effect.
            self._project_cooldown_until[failed_url] = max(prev, until_ts)
        print(f" [PROJECT POOL] cooldown {cooldown:.1f}s for {failed_url[:80]}...")

    async def _disable_project_url(self, failed_url: str, reason: str):
        """Disable a suspended / API-disabled project for this client.

        Raises:
            ProjectDisabledError: once every project in the pool is disabled.
        """
        async with self._url_switch_lock:
            self._disabled_urls.add(failed_url)
            print(f" [PROJECT POOL] disabled project: {failed_url[:80]}... reason={reason[:160]}")
            if len(self._disabled_urls) >= len(self.api_urls):
                raise ProjectDisabledError("All Vertex projects are disabled or suspended.")
            if self.api_urls[self._active_url_index] == failed_url:
                # Move the active pointer to the first project still enabled.
                for idx, url in enumerate(self.api_urls):
                    if url not in self._disabled_urls:
                        self._active_url_index = idx
                        break

    async def _acquire_project_url(self) -> Tuple[str, int]:
        """Pick a project that is not cooling down.

        Scans from the active index; if every enabled project is cooling
        down, sleeps until the earliest one recovers and scans again.

        Returns:
            (url, index into ``api_urls``)

        Raises:
            ProjectDisabledError: if every project is disabled.
        """
        while True:
            async with self._url_switch_lock:
                now = time.monotonic()
                count = len(self.api_urls)
                for offset in range(count):
                    idx = (self._active_url_index + offset) % count
                    url = self.api_urls[idx]
                    if url in self._disabled_urls:
                        continue
                    if self._project_cooldown_until.get(url, 0.0) <= now:
                        self._active_url_index = idx
                        return url, idx
                available_urls = [u for u in self.api_urls if u not in self._disabled_urls]
                if not available_urls:
                    raise ProjectDisabledError("All Vertex projects are disabled or suspended.")
                soonest_url = min(available_urls, key=lambda u: self._project_cooldown_until.get(u, 0.0))
                wait = max(0.0, self._project_cooldown_until.get(soonest_url, 0.0) - now)
            print(f" [PROJECT POOL] all projects cooling down, waiting {wait:.1f}s")
            await asyncio.sleep(max(wait, 0.1))

    async def _switch_to_next_project(self, failed_index: int) -> bool:
        """Advance the active project after a 429.

        Returns True if the active project is (now) different from
        ``failed_index``, False if no enabled, non-cooling project exists.
        """
        async with self._url_switch_lock:
            # Another coroutine may already have switched.
            if self._active_url_index != failed_index:
                return True
            now = time.monotonic()
            count = len(self.api_urls)
            for offset in range(1, count + 1):
                next_index = (failed_index + offset) % count
                next_url = self.api_urls[next_index]
                if next_url in self._disabled_urls:
                    continue
                if self._project_cooldown_until.get(next_url, 0.0) <= now:
                    self._active_url_index = next_index
                    print(f" [PROJECT POOL] 429 → 切换到 project #{next_index + 1}/{len(self.api_urls)}: {next_url[:80]}...")
                    return True
            return False

    def _all_projects_cooling_down(self) -> bool:
        """Return True if at least one project is enabled and all enabled ones are cooling down."""
        now = time.monotonic()
        active_urls = [url for url in self.api_urls if url not in self._disabled_urls]
        return bool(active_urls) and all(self._project_cooldown_until.get(url, 0.0) > now for url in active_urls)

    @staticmethod
    def _is_project_disabled_error(err_str: str) -> bool:
        """Return True if the error text indicates a suspended or API-disabled project."""
        err_lower = err_str.lower()
        return any(marker in err_lower for marker in (
            "consumer_suspended",
            "has been suspended",
            "service_disabled",
            "api has not been used",
            "api is disabled",
        ))

    @staticmethod
    def _is_rate_limit_error(err_str: str) -> bool:
        """Return True if the error text looks like a rate-limit (429) failure.

        The status code must appear as a stand-alone number so that digits
        inside project numbers or token counts are not misread as HTTP 429.
        """
        return bool(
            LLMClient._RATE_LIMIT_CODE_RE.search(err_str)
            or "RESOURCE_EXHAUSTED" in err_str
            or "rate limit" in err_str.lower()
        )

    # ════════════════════════════════════════════════════════════════
    # OpenAI → Vertex native message conversion
    # ════════════════════════════════════════════════════════════════
    @staticmethod
    def _image_url_to_vertex_part(url: str) -> Optional[Dict]:
        """Convert one OpenAI ``image_url`` URL into a Vertex part, or None.

        ``data:`` URIs become ``inlineData``; ``gs://`` URIs become
        ``fileData``. Any other URL scheme is not converted.
        """
        if url.startswith("data:"):
            # data URI layout: data:image/jpeg;base64,xxxx
            try:
                header, b64_data = url.split(",", 1)
                mime_type = header.split(":")[1].split(";")[0]
            except (ValueError, IndexError):
                print(" [WARN] skipping image with malformed data URI")
                return None
            return {"inlineData": {"mimeType": mime_type, "data": b64_data}}
        if url.startswith("gs://"):
            # NOTE(review): the MIME type is fixed to JPEG regardless of the
            # object's real type — confirm this is acceptable for non-JPEG files.
            return {"fileData": {"fileUri": url, "mimeType": "image/jpeg"}}
        return None

    @staticmethod
    def _convert_messages_to_vertex_native(
        messages: List[Dict],
    ) -> Tuple[Optional[Dict], List[Dict]]:
        """Convert OpenAI-format messages into Vertex AI native format.

        OpenAI format:
            [{"role": "system", "content": "..."},
             {"role": "user", "content": "..." | [{"type":"text",...}, {"type":"image_url",...}]},
             {"role": "assistant", "content": "..."}]

        Vertex native format:
            systemInstruction: {"parts": [{"text": "..."}]}
            contents: [
                {"role": "user", "parts": [{"text": "..."}, {"inlineData": {...}}]},
                {"role": "model", "parts": [{"text": "..."}]}
            ]

        When several system messages are present their parts are combined
        in order. Messages that end up with no parts are omitted.

        Returns: (system_instruction_dict_or_None, contents_list)
        """
        system_parts: List[Dict] = []
        has_system = False
        contents: List[Dict] = []

        for msg in messages:
            role = msg.get("role", "")
            content = msg.get("content", "")

            if role == "system":
                # System prompt → systemInstruction
                if isinstance(content, str):
                    has_system = True
                    system_parts.append({"text": content})
                elif isinstance(content, list):
                    has_system = True
                    for item in content:
                        if isinstance(item, str):
                            system_parts.append({"text": item})
                        elif isinstance(item, dict) and item.get("type") == "text":
                            system_parts.append({"text": item.get("text") or ""})
                continue

            # Role mapping: assistant → model, everything else → user
            vertex_role = "model" if role == "assistant" else "user"

            # content → parts
            parts: List[Dict] = []
            if isinstance(content, str):
                if content.strip():
                    parts.append({"text": content})
            elif isinstance(content, list):
                for item in content:
                    if isinstance(item, str):
                        if item.strip():
                            parts.append({"text": item})
                        continue
                    if not isinstance(item, dict):
                        continue
                    item_type = item.get("type", "")
                    if item_type == "text":
                        # Tolerate an explicit "text": None.
                        text_val = item.get("text") or ""
                        if text_val.strip():
                            parts.append({"text": text_val})
                    elif item_type == "image_url":
                        # OpenAI: {"type":"image_url","image_url":{"url":"data:image/jpeg;base64,..."}}
                        image_url = item.get("image_url", {})
                        url = image_url.get("url", "") if isinstance(image_url, dict) else ""
                        image_part = LLMClient._image_url_to_vertex_part(url or "")
                        if image_part is not None:
                            parts.append(image_part)

            if parts:
                contents.append({"role": vertex_role, "parts": parts})

        system_instruction = {"parts": system_parts} if has_system else None
        return system_instruction, contents

    @staticmethod
    def _parse_vertex_native_response(
        result: Dict[str, Any],
    ) -> Tuple[str, str, TokenUsage]:
        """Parse a Vertex AI native generateContent response.

        Vertex response format:
        {
            "candidates": [{
                "content": {
                    "role": "model",
                    "parts": [{"text": "..."}, ...]
                },
                "finishReason": "STOP"
            }],
            "usageMetadata": {
                "promptTokenCount": 100,
                "candidatesTokenCount": 50,
                "totalTokenCount": 150,
                "thoughtsTokenCount": 20   // optional, thinking tokens
            }
        }

        Returns: (content_text, reasoning_text, TokenUsage)

        Raises:
            ProhibitedContentError: if the prompt was blocked with
                ``blockReason == "PROHIBITED_CONTENT"`` (must not be retried).
            ValueError: if the response carries no candidates.
        """
        # Content-safety block — must not be retried.
        prompt_feedback = result.get("promptFeedback") or {}
        block_reason = prompt_feedback.get("blockReason", "")
        if block_reason == "PROHIBITED_CONTENT":
            raise ProhibitedContentError(
                f"Content blocked by Vertex safety filter: {block_reason}")

        candidates = result.get("candidates", [])
        if not candidates:
            raise ValueError(f"Empty candidates in Vertex response: {json.dumps(result)[:300]}")

        candidate = candidates[0]
        # "content" / "parts" may be absent or null; treat both as empty.
        parts = (candidate.get("content") or {}).get("parts") or []

        content_text = ""
        reasoning_text = ""
        for part in parts:
            if "text" in part:
                # Parts flagged with "thought" carry reasoning text.
                if part.get("thought", False):
                    reasoning_text += part["text"]
                else:
                    content_text += part["text"]

        # Token usage
        usage_meta = result.get("usageMetadata") or {}
        prompt_tokens = usage_meta.get("promptTokenCount", 0) or 0
        completion_tokens = usage_meta.get("candidatesTokenCount", 0) or 0
        total_tokens = usage_meta.get("totalTokenCount", 0) or (prompt_tokens + completion_tokens)
        thinking_tokens = usage_meta.get("thoughtsTokenCount", 0) or 0
        cached_tokens = usage_meta.get("cachedContentTokenCount", 0) or 0

        usage = TokenUsage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
            thinking_tokens=thinking_tokens,
            cached_tokens=cached_tokens,
        )
        return content_text, reasoning_text, usage

    # ════════════════════════════════════════════════════════════════
    # Main entry point
    # ════════════════════════════════════════════════════════════════
    async def call(
        self,
        messages: List[Dict],
        session: aiohttp.ClientSession,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
    ) -> Tuple[str, str, TokenUsage]:
        """Call the LLM with project-pool failover.

        Chooses the OpenAI-compatible or Vertex native request format based
        on the detected API mode. On a rate-limit error with more than one
        URL configured, the failing project is put into cooldown and the
        call is retried on the next project.

        Returns (content, reasoning_content, TokenUsage).

        Raises:
            ProjectDisabledError: when every project is disabled.
            Exception: the underlying error for non-rate-limit failures, or
                the last rate-limit error once the pool is exhausted.
        """
        last_err = None
        start_index = self._active_url_index
        invoke = self._call_vertex_native if self._use_vertex_native else self._call_openai_compat

        # Try each project in the pool at most once.
        for _ in range(len(self.api_urls)):
            base_url, url_index = await self._acquire_project_url()
            sem = self._semaphores[base_url]
            try:
                return await invoke(
                    messages, session, base_url, sem, temperature, max_tokens
                )
            except Exception as e:
                err_str = str(e)
                if self._is_project_disabled_error(err_str):
                    last_err = e
                    await self._disable_project_url(base_url, err_str)
                    continue
                if self._is_rate_limit_error(err_str) and len(self.api_urls) > 1:
                    last_err = e
                    await self._mark_project_cooldown(base_url)
                    await self._switch_to_next_project(url_index)
                    if self._all_projects_cooling_down():
                        print(f" [PROJECT POOL] 所有 {len(self.api_urls)} 个 project 正在 cooldown")
                        raise
                    # Back at the starting project: every project returned 429.
                    if self._active_url_index == start_index:
                        print(f" [PROJECT POOL] 所有 {len(self.api_urls)} 个 project 均已 429,上抛异常")
                        raise
                    continue  # retry on the new project
                # Not a rate-limit error → propagate as-is.
                raise

        # Every project was rate limited or disabled.
        raise last_err or RuntimeError("All projects in pool exhausted (429)")

    async def _call_openai_compat(
        self,
        messages: List[Dict],
        session: aiohttp.ClientSession,
        api_url: str,
        sem: asyncio.Semaphore,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
    ) -> Tuple[str, str, TokenUsage]:
        """Call an OpenAI-compatible chat-completions endpoint with retries.

        Args:
            temperature: per-call override; ``None`` uses the client default
                (an explicit ``0`` is honoured).
            max_tokens: per-call override; ``None`` or ``0`` uses the client
                default.

        Raises:
            QuotaExhaustedError, ProhibitedContentError, ProjectDisabledError:
                propagated immediately without retry.
            Exception: the last error once ``max_retries`` is exhausted, or a
                rate-limit error right away when several projects exist.
        """
        payload = {
            "model": self.model,
            "messages": messages,
            # `is None` rather than `or`, so temperature=0 is not replaced.
            "temperature": self.temperature if temperature is None else temperature,
            "top_p": self.top_p,
            "max_tokens": max_tokens or self.max_tokens,
        }

        for attempt in range(self.max_retries):
            try:
                headers = {"Content-Type": "application/json"}
                if self._use_vertex_auth:
                    token = self._get_gcp_token(api_url)
                    headers["Authorization"] = f"Bearer {token}"
                elif self.api_key:
                    headers["Authorization"] = f"Bearer {self.api_key}"

                async with sem:
                    await self._acquire_vertex_request_slot()
                    async with session.post(
                        api_url, json=payload, headers=headers,
                        timeout=aiohttp.ClientTimeout(total=self.timeout),
                    ) as resp:
                        if resp.status != 200:
                            error_body = await resp.text()
                            if _is_quota_exhausted(resp.status, error_body):
                                raise QuotaExhaustedError(
                                    f"API 额度用尽! HTTP {resp.status}: {error_body[:500]}")
                            raise RuntimeError(
                                f"HTTP {resp.status} from {api_url}: {error_body[:500]}")
                        result = await resp.json()

                choices = result.get("choices", [])
                if not choices:
                    raise ValueError(f"Empty choices: {json.dumps(result)[:300]}")

                message = choices[0].get("message", {})
                content = message.get("content", "") or ""
                reasoning = message.get("reasoning_content", "") or ""

                # Parse token usage
                usage_data = result.get("usage", {})
                usage = TokenUsage.from_api_response(usage_data)
                return content, reasoning, usage
            except QuotaExhaustedError:
                raise  # quota exhausted: do not retry
            except ProhibitedContentError:
                raise  # blocked by content policy: do not retry
            except ProjectDisabledError:
                raise
            except Exception as e:
                if attempt == self.max_retries - 1:
                    raise
                err_str = str(e)
                if self._is_project_disabled_error(err_str):
                    raise ProjectDisabledError(err_str)
                if self._is_rate_limit_error(err_str):
                    if len(self.api_urls) > 1:
                        # Several projects: let call() switch project at once.
                        raise
                    # Single project: exponential backoff.
                    wait = (15 * (2 ** attempt)) + random.uniform(0, 5)
                    print(f" [RATE LIMIT] 429 detected, backing off {wait:.1f}s (attempt {attempt+1}/{self.max_retries})")
                else:
                    wait = 2 ** attempt
                await asyncio.sleep(wait)
        raise RuntimeError("Unreachable")

    async def _call_vertex_native(
        self,
        messages: List[Dict],
        session: aiohttp.ClientSession,
        base_url: str,
        sem: asyncio.Semaphore,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
    ) -> Tuple[str, str, TokenUsage]:
        """Call the Vertex AI native generateContent endpoint with retries.

        base_url format: https://aiplatform.googleapis.com/v1/projects/{P}/locations/{L}/publishers/google/models
        Request URL = base_url/{model}:generateContent

        Args:
            temperature: per-call override; ``None`` uses the client default
                (an explicit ``0`` is honoured).
            max_tokens: per-call override; ``None`` or ``0`` uses the client
                default.

        Raises:
            QuotaExhaustedError, ProhibitedContentError, ProjectDisabledError:
                propagated immediately without retry.
            Exception: the last error once ``max_retries`` is exhausted, or a
                rate-limit error right away when several projects exist.
        """
        # The model name may carry a "google/" prefix left over from the
        # OpenAI-compatible naming; the native URL needs it removed.
        model_name = self.model
        if model_name.startswith("google/"):
            model_name = model_name[len("google/"):]
        full_url = f"{base_url}/{model_name}:generateContent"

        # Convert the message format.
        system_instruction, contents = self._convert_messages_to_vertex_native(messages)

        # Build the request body.
        payload: Dict[str, Any] = {
            "contents": contents,
            "generationConfig": {
                # `is None` rather than `or`, so temperature=0 is not replaced.
                "temperature": self.temperature if temperature is None else temperature,
                "topP": self.top_p,
                "maxOutputTokens": max_tokens or self.max_tokens,
            },
        }
        if system_instruction:
            payload["systemInstruction"] = system_instruction

        for attempt in range(self.max_retries):
            try:
                headers = {"Content-Type": "application/json"}
                if self._use_vertex_auth:
                    token = self._get_gcp_token(base_url)
                    headers["Authorization"] = f"Bearer {token}"

                async with sem:
                    await self._acquire_vertex_request_slot()
                    async with session.post(
                        full_url, json=payload, headers=headers,
                        timeout=aiohttp.ClientTimeout(total=self.timeout),
                    ) as resp:
                        if resp.status != 200:
                            error_body = await resp.text()
                            if _is_quota_exhausted(resp.status, error_body):
                                raise QuotaExhaustedError(
                                    f"API 额度用尽! HTTP {resp.status}: {error_body[:500]}")
                            raise RuntimeError(
                                f"HTTP {resp.status} from {full_url}: {error_body[:500]}")
                        result = await resp.json()

                content, reasoning, usage = self._parse_vertex_native_response(result)
                return content, reasoning, usage
            except QuotaExhaustedError:
                raise  # quota exhausted: do not retry
            except ProhibitedContentError:
                raise  # blocked by content policy: do not retry
            except ProjectDisabledError:
                raise
            except Exception as e:
                if attempt == self.max_retries - 1:
                    raise
                err_str = str(e)
                if self._is_project_disabled_error(err_str):
                    raise ProjectDisabledError(err_str)
                if self._is_rate_limit_error(err_str):
                    if len(self.api_urls) > 1:
                        # Several projects: let call() switch project at once.
                        raise
                    # Single project: exponential backoff.
                    wait = (15 * (2 ** attempt)) + random.uniform(0, 5)
                    print(f" [RATE LIMIT] 429 detected, backing off {wait:.1f}s (attempt {attempt+1}/{self.max_retries})")
                else:
                    wait = 2 ** attempt
                    print(f" [WARN] Vertex native call failed (attempt {attempt+1}): {e}")
                await asyncio.sleep(wait)
        raise RuntimeError("Unreachable")
# ════════════════════════════════════════════════════════════════════════
# Message Builders (OpenAI format)
# ════════════════════════════════════════════════════════════════════════
def build_user_message(text: str, image_paths: List[str] = None) -> Dict:
    """Build an OpenAI-format user message with optional base64 images.

    Args:
        text: the text of the message.
        image_paths: optional image files to attach; each is embedded as a
            ``data:`` URI placed before the text part.

    Returns:
        ``{"role": "user", "content": text}`` when no images are given,
        otherwise a message whose content is a list of image parts followed
        by one text part.
    """
    if not image_paths:
        return {"role": "user", "content": text}

    # Local import keeps this block self-contained.
    import mimetypes

    content = []
    for path in image_paths:
        # Label the payload with the file's real image type (PNG, WebP, ...)
        # and fall back to JPEG when the extension is unknown or non-image.
        guessed_type, _ = mimetypes.guess_type(str(path))
        if not guessed_type or not guessed_type.startswith("image/"):
            guessed_type = "image/jpeg"
        b64 = encode_image_b64(path)
        content.append({
            "type": "image_url",
            "image_url": {"url": f"data:{guessed_type};base64,{b64}"}
        })
    content.append({"type": "text", "text": text})
    return {"role": "user", "content": content}
def build_assistant_message(text: str) -> Dict:
    """Wrap ``text`` as an OpenAI-format assistant message."""
    return dict(role="assistant", content=text)
# ════════════════════════════════════════════════════════════════════════
# Response Parser (FIXED: position-based priority)
# ════════════════════════════════════════════════════════════════════════
def parse_llm_response(text: str) -> Tuple[str, Any]:
    """Parse an LLM response and extract the FIRST action by position.

    Returns ('tool_call', parsed_json), ('answer', str), or ('error', str).

    The first ``<tool_call>...</tool_call>`` or ``<answer>...</answer>``
    block that appears in the text wins, instead of one tag type always
    being preferred. This prevents a hallucinated later tag from overriding
    an earlier valid one.
    """
    def _first_block(tag: str):
        # Non-greedy so only the first complete <tag>...</tag> pair matches.
        return re.search(rf'<{tag}>\s*(.*?)\s*</{tag}>', text, re.DOTALL)

    tc_m = _first_block("tool_call")
    answer_m = _first_block("answer")

    # An answer wins when it is the only action or starts before the tool call.
    if answer_m and (tc_m is None or answer_m.start() < tc_m.start()):
        return ("answer", answer_m.group(1).strip())

    if tc_m:
        try:
            return ("tool_call", json.loads(tc_m.group(1)))
        except json.JSONDecodeError:
            return ("error", f"Malformed tool_call JSON: {tc_m.group(1)[:200]}")

    return ("error", "No <tool_call> or <answer> found in response")
# ════════════════════════════════════════════════════════════════════════
# Search Utilities
# ════════════════════════════════════════════════════════════════════════
# ── Image Search Cache ──
class ImageSearchCache:
    """MD5-keyed cache for image search results, persisted as a JSON file.

    Entries map the MD5 hex digest of the image bytes to the result text.
    The cache is flushed to disk after every 10 ``set`` calls; call
    ``save()`` explicitly to persist any remaining entries.
    """

    def __init__(self, cache_file: str):
        self.cache_file = cache_file
        self.cache = self._load()
        self._dirty_count = 0  # number of set() calls since the last save

    def _load(self) -> Dict[str, str]:
        """Load the cache file; return an empty cache if it is missing or unusable."""
        try:
            with open(self.cache_file, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError):
            # Missing/unreadable file, bad encoding, or invalid JSON:
            # start from an empty cache (best-effort, as before).
            return {}
        # A valid JSON file that is not an object would break .get()/item
        # assignment later, so treat it like a corrupt cache.
        return data if isinstance(data, dict) else {}

    def save(self):
        """Write the cache to ``cache_file`` atomically and reset the dirty counter."""
        directory = os.path.dirname(self.cache_file)
        if directory:
            os.makedirs(directory, exist_ok=True)
        # Write to a sibling temp file and swap it in, so an interrupted
        # write cannot leave a truncated cache file behind.
        tmp_path = f"{self.cache_file}.{os.getpid()}.tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(self.cache, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, self.cache_file)
        self._dirty_count = 0

    def get(self, image_bytes: bytes) -> Optional[str]:
        """Return the cached result for these image bytes, or None."""
        key = hashlib.md5(image_bytes).hexdigest()
        return self.cache.get(key)

    def set(self, image_bytes: bytes, result: str):
        """Store a result for these image bytes; flush to disk every 10 writes."""
        key = hashlib.md5(image_bytes).hexdigest()
        self.cache[key] = result
        self._dirty_count += 1
        if self._dirty_count >= 10:
            self.save()
# ── Mock Search Implementations ──
def mock_image_search(entity: str, bbox: List[float]) -> str:
    """Return a canned reverse-image-search result for offline testing.

    ``bbox`` is accepted for signature compatibility and is not used.
    """
    slug = entity.lower().replace(' ', '_')
    lines = [
        "Reverse Image Search Results:",
        "",
        "Result 1:",
        f" Title: {entity} - Character Profile",
        f" Snippet: {entity} is a well-known character/entity.",
        f" URL: https://example.com/{slug}",
        "",
        "Result 2:",
        f" Title: {entity} | Wiki",
        f" Snippet: Detailed information about {entity}.",
        f" URL: https://wiki.example.com/{slug}",
    ]
    return "\n".join(lines)
def mock_web_search(query: str) -> str:
    """Return a canned web-search result for offline testing."""
    lines = [
        f'Web Search Results for "{query}":',
        "",
        "Quick Answer: Information related to the query.",
        "",
        f"Result 1: {query} - Overview",
        f" Relevant information about {query}.",
        "",
        f"Result 2: {query} - Details",
        " Additional details and facts.",
    ]
    return "\n".join(lines)
# ── Real Search Implementations ──
def _format_serper_lens_results(organic_results: List[Dict[str, Any]], max_results: int = 5) -> str:
"""Format raw Serper Lens organic results into readable text."""
parts = []
for i, item in enumerate(organic_results[:max_results], 1):
title = item.get('title', '')
snippet = item.get('snippet', '')
link = item.get('link', '')
source = item.get('source', '') or item.get('domain', '')
block = [f"Result {i}:"]
if title:
block.append(f" Title: {title}")
if snippet:
block.append(f" Snippet: {snippet}")
if source:
block.append(f" Source: {source}")
if link:
block.append(f" URL: {link}")
parts.append("\n".join(block))
return "\n\n".join(parts)
def _build_serper_lens_summary_prompt(organic_results: List[Dict[str, Any]], max_results: int = 5) -> str:
"""Build English summarizer prompt using only Serper Lens result metadata."""
context_parts = []
for i, item in enumerate(organic_results[:max_results], 1):
title = item.get('title', '')
snippet = item.get('snippet', '')
link = item.get('link', '')
source = item.get('source', '') or item.get('domain', '')
block = [f"Result {i}:"]
if title:
block.append(f"Title: {title}")
if snippet:
block.append(f"Snippet: {snippet}")
if source:
block.append(f"Source: {source}")
if link:
block.append(f"Link: {link}")
context_parts.append("\n".join(block))
context_text = "\n\n".join(context_parts)
return (
"You are a helpful assistant. Your task is to summarize the main content of the given "
"Serper Lens reverse image search results in no more than five sentences.\n\n"
"Your summary should cover the overall key points across the results, not just the parts "
"most related to the user's question.\n\n"
"If any part of the results is helpful for identifying the entity or answering the user's "
"question, include it clearly in the summary. Do not ignore relevant information, but make "
"sure the general structure and main ideas of the results are preserved.\n\n"
"Your summary should be concise, factual, and informative. If the results are ambiguous, "
"conflicting, or insufficient, clearly state that uncertainty.\n\n"
"Use only the provided result titles, snippets, and source/link metadata. Do not invent facts "
"and do not assume content from the linked pages.\n\n"
f"{context_text}"
)
async def summarize_serper_image_results(
    organic_results: List[Dict[str, Any]],
    session: aiohttp.ClientSession,
    summarizer_address: str = "",
    summarizer_model: str = "",
    max_results: int = 5,
    max_tokens: int = 512,
) -> Optional[str]:
    """Summarize Serper Lens results with an LLM, without fetching linked pages.

    Args:
        organic_results: raw Serper Lens organic results.
        session: aiohttp session used for the summarizer request.
        summarizer_address: host[:port] of the summarizer; falls back to
            ``IMAGE_SEARCH_SUMMARIZER_ADDRESS`` when empty.
        summarizer_model: model name; falls back to
            ``IMAGE_SEARCH_SUMMARIZER_MODEL`` when empty.
        max_results: number of results included in the prompt.
        max_tokens: generation budget sent to the summarizer.

    Returns:
        The summary text, or None when there is nothing to summarize, the
        summarizer is not configured, or the request fails in any way
        (callers then fall back to the raw results).
    """
    endpoint_host = summarizer_address or IMAGE_SEARCH_SUMMARIZER_ADDRESS
    model_name = summarizer_model or IMAGE_SEARCH_SUMMARIZER_MODEL
    if not organic_results or not endpoint_host or not model_name:
        return None

    prompt = _build_serper_lens_summary_prompt(organic_results, max_results=max_results)
    request_body = {
        "model": model_name,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": 0.3,
        "chat_template_kwargs": {"enable_thinking": False},
    }

    try:
        async with session.post(
            f"http://{endpoint_host}/v1/chat/completions",
            json=request_body,
            headers={"Content-Type": "application/json"},
            timeout=aiohttp.ClientTimeout(total=120),
        ) as resp:
            if resp.status != 200:
                print(f" [IMAGE_SEARCH] Summarizer returned HTTP {resp.status}, falling back to raw results")
                return None

            data = await resp.json()
            choices = data.get("choices", [])
            if not (choices and isinstance(choices, list)):
                return None

            summary = choices[0].get("message", {}).get("content", "")
            if not summary or not summary.strip():
                return None

            # An output that is empty after removing thinking tags counts as no summary.
            return _strip_thinking_tags(summary).strip() or None
    except asyncio.TimeoutError:
        print(" [IMAGE_SEARCH] Summarizer timeout, falling back to raw results")
        return None
    except Exception as e:
        print(f" [IMAGE_SEARCH] Summarizer error: {e}, falling back to raw results")
        return None
async def real_image_search(
    image_b64_or_path: str,
    session: aiohttp.ClientSession,
    api_key: str,
    crop_path: str = None,
) -> str:
    """Run a reverse image search through the backend chosen by IMAGE_SEARCH_MODE.

    ``"gateway"`` routes through the internal gateway; any other value calls
    Serper directly with ``api_key``.
    """
    if IMAGE_SEARCH_MODE != "gateway":
        return await _serper_image_search(image_b64_or_path, session, api_key, crop_path)
    return await _gateway_image_search(image_b64_or_path, session, crop_path)
async def _gateway_image_search(
    image_b64_or_path: str,
    session: aiohttp.ClientSession,
    crop_path: str = None,
) -> str:
    """Reverse image search via the internal gateway → Serper Google Lens.

    The Lens results are optionally summarized by an LLM. The gateway call
    is attempted up to two times. Failures are reported as an error string
    rather than raised.

    Returns:
        Formatted results (optionally prefixed by a summary), a "no results"
        message, or an error message.
    """
    if not GATEWAY_TOKEN:
        return "Error: GATEWAY_TOKEN not configured."

    image_url, prep_error = _prepare_image_search_url(
        image_b64_or_path, crop_path, "IMAGE_SEARCH/GATEWAY"
    )
    if prep_error:
        return prep_error

    request_headers = {
        'Content-Type': 'application/json',
        'User-Agent': 'ifbook-http-client',
    }
    # The gateway expects the Serper parameters as a JSON-encoded string.
    gateway_payload = {
        "sec_info": {
            "username": GATEWAY_USERNAME,
            "userid": GATEWAY_USERID,
            "token": GATEWAY_TOKEN,
        },
        "model_type": "openai",
        "model_name": "serper",
        "params": json.dumps({"url": image_url, "type": "lens"}),
    }

    max_api_retries = 2
    last_error = None

    async def _pause_before_retry(reason: str, attempt: int) -> None:
        # Log the failure and back off linearly (3s, 6s, ...).
        print(f" [IMAGE_SEARCH/GATEWAY] {reason}, "
              f"retrying ({attempt+1}/{max_api_retries})...")
        await asyncio.sleep(3 * (attempt + 1))

    for api_attempt in range(max_api_retries):
        out_of_retries = api_attempt >= max_api_retries - 1
        try:
            async with session.post(
                GATEWAY_URL,
                headers=request_headers,
                json=gateway_payload,
                timeout=aiohttp.ClientTimeout(total=60),
            ) as resp:
                if resp.status != 200:
                    error_body = await resp.text()
                    last_error = (f"Gateway error: HTTP {resp.status}: "
                                  f"{error_body[:300]}")
                    if out_of_retries:
                        return last_error
                    await _pause_before_retry(f"HTTP {resp.status}", api_attempt)
                    continue

                gateway_resp = await resp.json()
                # The Serper response is carried as a JSON string in "model_output".
                data = json.loads(gateway_resp.get("model_output", "{}"))
                organic = data.get('organic', [])
                if not organic:
                    return "No results found from reverse image search."

                raw_results = _format_serper_lens_results(
                    organic,
                    max_results=IMAGE_SEARCH_SUMMARIZER_MAX_RESULTS,
                )
                if not IMAGE_SEARCH_SUMMARIZE_SERPER:
                    return raw_results

                summary = await summarize_serper_image_results(
                    organic,
                    session,
                    summarizer_address=IMAGE_SEARCH_SUMMARIZER_ADDRESS,
                    summarizer_model=IMAGE_SEARCH_SUMMARIZER_MODEL,
                    max_results=IMAGE_SEARCH_SUMMARIZER_MAX_RESULTS,
                    max_tokens=IMAGE_SEARCH_SUMMARIZER_MAX_TOKENS,
                )
                if not summary:
                    return raw_results
                return f"Summary: {summary}\n\nTop Lens Results:\n\n{raw_results}"
        except asyncio.TimeoutError:
            last_error = "Image search error: request timed out after 60s"
            if out_of_retries:
                return last_error
            await _pause_before_retry("Timeout", api_attempt)
        except Exception as e:
            last_error = f"Image search error: {e}"
            if out_of_retries:
                return last_error
            await _pause_before_retry(f"Error: {e}", api_attempt)

    return last_error or "Image search error: unknown failure"
async def _serper_image_search(
    image_b64_or_path: str,
    session: aiohttp.ClientSession,
    api_key: str,
    crop_path: str = None,
) -> str:
    """Reverse image search by calling Serper Google Lens directly.

    The Lens results are optionally summarized by an LLM. The Serper call
    is attempted up to two times. Failures are reported as an error string
    rather than raised.

    Returns:
        Formatted results (optionally prefixed by a summary), a "no results"
        message, or an error message.
    """
    if not api_key:
        return "Error: SERPER_API_KEY not configured."

    request_headers = {'X-API-KEY': api_key, 'Content-Type': 'application/json'}
    image_url, prep_error = _prepare_image_search_url(
        image_b64_or_path, crop_path, "IMAGE_SEARCH"
    )
    if prep_error:
        return prep_error

    max_api_retries = 2
    last_error = None

    async def _pause_before_retry(reason: str, attempt: int) -> None:
        # Log the failure and back off linearly (3s, 6s, ...).
        print(f" [IMAGE_SEARCH] {reason}, "
              f"retrying ({attempt+1}/{max_api_retries})...")
        await asyncio.sleep(3 * (attempt + 1))

    for api_attempt in range(max_api_retries):
        out_of_retries = api_attempt >= max_api_retries - 1
        try:
            async with session.post(
                "https://google.serper.dev/lens",
                headers=request_headers,
                json={"url": image_url},
                timeout=aiohttp.ClientTimeout(total=60),
            ) as resp:
                if resp.status != 200:
                    error_body = await resp.text()
                    last_error = (f"Image search error: HTTP {resp.status}: "
                                  f"{error_body[:300]}")
                    if out_of_retries:
                        return last_error
                    await _pause_before_retry(f"HTTP {resp.status}", api_attempt)
                    continue

                data = await resp.json()
                organic = data.get('organic', [])
                if not organic:
                    return "No results found from reverse image search."

                raw_results = _format_serper_lens_results(
                    organic,
                    max_results=IMAGE_SEARCH_SUMMARIZER_MAX_RESULTS,
                )
                if not IMAGE_SEARCH_SUMMARIZE_SERPER:
                    return raw_results

                summary = await summarize_serper_image_results(
                    organic,
                    session,
                    summarizer_address=IMAGE_SEARCH_SUMMARIZER_ADDRESS,
                    summarizer_model=IMAGE_SEARCH_SUMMARIZER_MODEL,
                    max_results=IMAGE_SEARCH_SUMMARIZER_MAX_RESULTS,
                    max_tokens=IMAGE_SEARCH_SUMMARIZER_MAX_TOKENS,
                )
                if not summary:
                    return raw_results
                return f"Summary: {summary}\n\nTop Lens Results:\n\n{raw_results}"
        except asyncio.TimeoutError:
            last_error = "Image search error: request timed out after 60s"
            if out_of_retries:
                return last_error
            await _pause_before_retry("Timeout", api_attempt)
        except Exception as e:
            last_error = f"Image search error: {e}"
            if out_of_retries:
                return last_error
            await _pause_before_retry(f"Error: {e}", api_attempt)

    return last_error or "Image search error: unknown failure"
async def real_web_search(
    query: str,
    session: aiohttp.ClientSession,
    address: str = WEB_SEARCH_ADDRESS,
) -> str:
    """Run a web search against the internal search server (SenseNova pattern).

    Args:
        query: Search query; surrounding whitespace is trimmed and newlines
            are replaced with spaces before sending.
        session: aiohttp session used for the request.
        address: ``host[:port]`` of the search server.

    Returns:
        The raw response body on success, otherwise an ``"Error: ..."`` string
        (timeouts and all other failures are reported, not raised).
    """
    base_fields = {
        "query": query.strip().replace("\n", " "),
        "top_k": 3,
        "retrieval_mode": "google_serper",
    }
    # Entries in WEB_SEARCH_CONFIG take precedence over the defaults above.
    request_body = {**base_fields, **WEB_SEARCH_CONFIG}
    endpoint = f"http://{address}/search"
    request_timeout = aiohttp.ClientTimeout(total=100)
    try:
        async with session.post(
            endpoint, json=request_body, timeout=request_timeout,
        ) as resp:
            resp.raise_for_status()
            return await resp.text()
    except asyncio.TimeoutError:
        return f"Error: Web search timeout for query: {query[:100]}"
    except Exception as e:
        return f"Error: Web search failed: {e}"
async def serper_web_search(
    query: str,
    session: aiohttp.ClientSession,
    api_key: str,
) -> str:
    """Fallback web search through Serper's Google Search endpoint.

    Args:
        query: Search query sent as-is.
        session: aiohttp session used for the request.
        api_key: Serper API key; a falsy value returns an error string.

    Returns:
        A newline-joined digest made of the answer box (if any), the
        knowledge graph (if any) and up to five organic results; a
        "no results" message when none of these are present; or a
        ``"Web search error: ..."`` string on any failure.
    """
    if not api_key:
        return "Error: SERPER_API_KEY not configured."

    request_headers = {'X-API-KEY': api_key, 'Content-Type': 'application/json'}
    try:
        async with session.post(
            "https://google.serper.dev/search",
            headers=request_headers,
            json={"q": query},
            timeout=aiohttp.ClientTimeout(total=20),
        ) as resp:
            resp.raise_for_status()
            payload = await resp.json()

            sections = []

            answer_box = payload.get('answerBox', {})
            if answer_box:
                quick_answer = answer_box.get('answer') or answer_box.get('snippet', '')
                if quick_answer:
                    sections.append(f"Quick Answer: {quick_answer}")

            knowledge = payload.get('knowledgeGraph', {})
            if knowledge:
                sections.append(
                    f"Knowledge Graph: {knowledge.get('title', '')} - "
                    f"{knowledge.get('description', '')}")

            top_hits = payload.get('organic', [])[:5]
            sections.extend(
                f"Result {rank}: {hit.get('title', '')}\n"
                f" {hit.get('snippet', '')}"
                for rank, hit in enumerate(top_hits, 1)
            )

            if not sections:
                return "No relevant results found."
            return "\n".join(sections)
    except Exception as e:
        return f"Web search error: {e}"
# ── MARS retrieval concurrency semaphore (lazy init) ──
_mars_retrieval_semaphore: Optional[asyncio.Semaphore] = None


def _get_mars_retrieval_semaphore() -> Optional[asyncio.Semaphore]:
    """Return the shared retrieval semaphore, creating it on first use.

    Returns ``None`` when ``MARS_RETRIEVAL_CONCURRENCY`` is not positive;
    ``mars_web_search`` then skips acquiring/releasing altogether.
    """
    global _mars_retrieval_semaphore
    if not MARS_RETRIEVAL_CONCURRENCY > 0:
        return None
    if _mars_retrieval_semaphore is None:
        _mars_retrieval_semaphore = asyncio.Semaphore(MARS_RETRIEVAL_CONCURRENCY)
    return _mars_retrieval_semaphore
async def mars_web_search(
    query: str,
    session: aiohttp.ClientSession,
    retrieval_address: str = "",
    summarizer_address: str = "",
    retrieval_topk: int = 3,
    summarizer_model: str = "",
) -> str:
    """SenseNova-MARS style web search: retrieve from local Wikipedia + summarize via LLM.

    Two-step pipeline:
      1. POST to Search-R1 retrieval server → get top-k document passages
      2. POST to summarizer LLM (OpenAI-compatible /v1/chat/completions) → get concise summary

    This mirrors the SenseNova-MARS web_search_server architecture but without
    the intermediate FastAPI layer.

    Args:
        query: Search query; trimmed and flattened to a single line.
        session: aiohttp session used for both requests.
        retrieval_address: ``host[:port]`` of the retrieval server; falls back
            to ``MARS_RETRIEVAL_ADDRESS`` when empty.
        summarizer_address: ``host[:port]`` of the summarizer; falls back to
            ``MARS_SUMMARIZER_ADDRESS`` when empty.
        retrieval_topk: Number of passages to request; a falsy value falls
            back to ``MARS_RETRIEVAL_TOPK``.
        summarizer_model: Model name for the summarizer; falls back to
            ``MARS_SUMMARIZER_MODEL`` when empty.

    Returns:
        The summarized search result, the raw passages when the summarizer
        fails or yields no usable text, a "no passages" message, or an
        ``"Error: ..."`` string for configuration / retrieval failures.

    Raises:
        RetrieverDownError: When the retrieval server cannot be reached
            (connection refused/reset, disconnect, or other ``OSError``).
    """
    retrieval_addr = retrieval_address or MARS_RETRIEVAL_ADDRESS
    summarizer_addr = summarizer_address or MARS_SUMMARIZER_ADDRESS
    topk = retrieval_topk or MARS_RETRIEVAL_TOPK
    sum_model = summarizer_model or MARS_SUMMARIZER_MODEL

    if not retrieval_addr:
        return "Error: MARS_RETRIEVAL_ADDRESS not configured."
    if not summarizer_addr:
        return "Error: MARS_SUMMARIZER_ADDRESS not configured."

    clean_query = query.strip().replace("\n", " ")

    # ════════════════════════════════════════════════════════════════
    # Step 1: Retrieve passages from Search-R1 retrieval server
    # ════════════════════════════════════════════════════════════════
    retrieval_payload = {
        "queries": [clean_query],  # Search-R1 expects `queries` to be a List[str]
        "return_scores": True,  # must be True, otherwise the server fails while unpacking and returns 500
        "topk": topk,
    }
    retrieved_passages = []
    sem = _get_mars_retrieval_semaphore()
    try:
        # Acquire outside the inner try so that release() only runs after a
        # successful acquire().
        if sem:
            await sem.acquire()
        try:
            async with session.post(
                f"http://{retrieval_addr}/retrieve",
                json=retrieval_payload,
                timeout=aiohttp.ClientTimeout(total=MARS_RETRIEVAL_TIMEOUT),
                proxy="",  # bypass the system proxy; connect directly to the intranet
            ) as resp:
                if resp.status != 200:
                    error_body = await resp.text()
                    return (f"Error: Retrieval server returned HTTP {resp.status}: "
                            f"{error_body[:300]}")
                data = await resp.json()
                # Confirmed Search-R1 response format:
                # {"result": [[{"document": {"id": "xx", "contents": "\"title\"\ntext"}, "score": 0.84}, ...]]}
                raw_results = data.get("result", [])
                for query_results in raw_results:
                    if not isinstance(query_results, list):
                        continue
                    for item in query_results:
                        if not isinstance(item, dict):
                            continue
                        doc = item.get("document", {})
                        if isinstance(doc, dict):
                            contents = doc.get("contents", "")
                        elif isinstance(doc, str):
                            contents = doc
                        else:
                            continue
                        if contents:
                            # First line is the (quoted) title, the rest is the body.
                            lines = contents.split("\n", 1)
                            title = lines[0].strip('"') if lines else ""
                            text = lines[1] if len(lines) > 1 else contents
                            # Truncate over-long passages so the summarizer
                            # context window is not exceeded.
                            retrieved_passages.append({"title": title, "text": text[:2000]})
        finally:
            if sem:
                sem.release()
    except asyncio.TimeoutError:
        return f"Error: Retrieval server timeout for query: {clean_query[:100]}"
    except (aiohttp.ClientConnectorError,
            aiohttp.ServerDisconnectedError,
            ConnectionRefusedError,
            ConnectionResetError,
            OSError) as e:
        raise RetrieverDownError(
            f"Retriever 服务 {retrieval_addr} 连接失败: {e}"
        ) from e
    except Exception as e:
        # Some connection failures surface as generic exceptions; detect them
        # by message so they are still reported as "retriever down".
        err_str = str(e).lower()
        if any(kw in err_str for kw in (
            "cannot connect", "connection refused", "connect call failed",
            "server disconnected",
        )):
            raise RetrieverDownError(
                f"Retriever 服务 {retrieval_addr} 连接失败: {e}"
            ) from e
        return f"Error: Retrieval failed: {e}"

    if not retrieved_passages:
        return f"No relevant passages found for query: {clean_query}"

    # ════════════════════════════════════════════════════════════════
    # Step 2: Summarize via LLM (OpenAI-compatible API)
    # ════════════════════════════════════════════════════════════════
    context_parts = []
    for i, p in enumerate(retrieved_passages, 1):
        title_str = f" (Title: {p['title']})" if p['title'] else ""
        context_parts.append(f"Passage {i}{title_str}:\n{p['text']}")
    context_text = "\n\n".join(context_parts)

    summarizer_prompt = (
        f"Based on the following retrieved passages, provide a concise and informative "
        f"summary that answers the query: \"{clean_query}\"\n\n"
        f"{context_text}\n\n"
        f"Please provide a concise summary focusing on the most relevant information. "
        f"If the passages do not contain relevant information, say so."
    )
    summarizer_payload = {
        "model": sum_model,
        "messages": [{"role": "user", "content": summarizer_prompt}],
        "max_tokens": 1024,
        "temperature": 0.3,
        "chat_template_kwargs": {"enable_thinking": False},
    }
    try:
        async with session.post(
            f"http://{summarizer_addr}/v1/chat/completions",
            json=summarizer_payload,
            headers={"Content-Type": "application/json"},
            timeout=aiohttp.ClientTimeout(total=120),
            proxy="",  # bypass the system proxy; connect directly to the intranet
        ) as resp:
            if resp.status != 200:
                print(f" [MARS_SEARCH] Summarizer returned HTTP {resp.status}, "
                      f"falling back to raw passages")
                return _format_raw_passages(clean_query, retrieved_passages)
            data = await resp.json()
            choices = data.get("choices", [])
            if choices and isinstance(choices, list):
                msg = choices[0].get("message", {})
                summary = msg.get("content", "")
                if summary and summary.strip():
                    # The output may consist of thinking content only; in that
                    # case nothing is left after stripping and the raw
                    # passages are more useful than an empty summary.
                    summary = _strip_thinking_tags(summary).strip()
                    if summary:
                        return (
                            f"Web Search Results for \"{clean_query}\" "
                            f"(MARS retrieve+summarize, {len(retrieved_passages)} passages):\n\n"
                            f"{summary}"
                        )
                    print(f" [MARS_SEARCH] Summarizer output empty after stripping "
                          f"thinking content, falling back to raw passages")
            return _format_raw_passages(clean_query, retrieved_passages)
    except asyncio.TimeoutError:
        print(f" [MARS_SEARCH] Summarizer timeout, falling back to raw passages")
        return _format_raw_passages(clean_query, retrieved_passages)
    except Exception as e:
        print(f" [MARS_SEARCH] Summarizer error: {e}, falling back to raw passages")
        return _format_raw_passages(clean_query, retrieved_passages)
def _strip_thinking_tags(text: str) -> str:
"""Remove thinking content from summarizer output."""
# 1. 移除 XML 风格的 ...
text = re.sub(r'.*?', '', text, flags=re.DOTALL).strip()
# 2. 移除 "Thinking Process:" 开头的内容块(到第一个连续空行为止)
text = re.sub(r'^Thinking Process:.*?(?=\n\n)', '', text, flags=re.DOTALL).strip()
# 3. 如果清理后还是以 "Thinking" 开头(边界情况),再截一次
if text.startswith("Thinking"):
parts = text.split("\n\n", 1)
if len(parts) > 1:
text = parts[1].strip()
return text
def _format_raw_passages(query: str, passages: list) -> str:
"""Format raw retrieved passages as fallback when summarizer fails."""
parts = [f"Web Search Results for \"{query}\" (raw retrieval, {len(passages)} passages):"]
for i, p in enumerate(passages, 1):
title_str = f" — {p['title']}" if p['title'] else ""
text_preview = p['text'][:500] + ("..." if len(p['text']) > 500 else "")
parts.append(f"\nResult {i}{title_str}:\n {text_preview}")
return "\n".join(parts)
# ════════════════════════════════════════════════════════════════════════
# I/O Helpers
# ════════════════════════════════════════════════════════════════════════
def make_uid(entry: Dict) -> str:
    """Build a composite identifier from an entry's id and video filename.

    One ``id`` may be attached to several videos, so the filename stem is
    appended (``"<id>__<stem>"``) to keep identifiers unique. When the entry
    has no ``video_filename`` the id is returned unchanged; a missing id is
    reported as ``"unknown"``.
    """
    entry_id = entry.get("id", "unknown")
    video_name = entry.get("video_filename", "")
    if not video_name:
        return entry_id
    stem, _extension = os.path.splitext(video_name)
    return f"{entry_id}__{stem}"
def load_completed_ids(output_file: str) -> Set[str]:
    """Load already-completed entry UIDs from output JSONL for resume.

    A line counts as completed when it parses to a JSON object that has a
    ``uid`` key and no ``error`` key. Blank lines, malformed JSON (e.g. a
    partially written last line), JSON values that are not objects, and
    records whose ``uid`` is not hashable are skipped instead of aborting
    the resume scan.

    Args:
        output_file: Path of the JSONL output file; it may not exist yet.

    Returns:
        The set of completed UIDs (empty when the file does not exist).
    """
    completed: Set[str] = set()
    if not os.path.exists(output_file):
        return completed
    with open(output_file, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                continue
            # Valid JSON that is not an object (number, string, list) cannot
            # carry a uid; the membership test below would raise or misfire.
            if not isinstance(record, dict):
                continue
            if "uid" not in record or "error" in record:
                continue
            try:
                completed.add(record["uid"])
            except TypeError:
                # Unhashable uid (list/dict): treat the line as unusable.
                continue
    return completed
def get_question(entry: Dict) -> str:
    """Pick the question text to use for an entry.

    Entries whose ``verdict`` is ``"rewrite"`` use ``rewritten_question``
    when it is non-empty; every other case uses ``original_question``
    (empty string when absent).
    """
    original = entry.get("original_question", "")
    if entry.get("verdict") != "rewrite":
        return original
    rewritten = entry.get("rewritten_question")
    return rewritten if rewritten else original
# ════════════════════════════════════════════════════════════════════════
# OSS Image Upload for Search (works around Serper Lens 400 Bad Request)
# ════════════════════════════════════════════════════════════════════════
# Follows the approach of the second codebase: first upload the cropped image
# to publicly reachable object storage → obtain an https://... URL →
# pass it to Serper Lens → fetch the search results.
#
# Serper Lens support for base64 data URIs is unreliable (frequently returns
# 400), but HTTPS URLs work fine.
_oss_bucket = None  # lazily initialized so a missing oss2 install does not fail at import time
_oss_init_attempted = False  # initialization is attempted only once
_oss_last_error = ""  # message of the most recent OSS init/upload failure recorded by the helpers below
def _get_oss_bucket():
    """Return the shared OSS bucket, initializing it on the first call only.

    Initialization is attempted exactly once; later calls return whatever
    the first attempt produced (the bucket, or ``None`` after a failure).
    On failure the reason is stored in ``_oss_last_error`` and printed.
    """
    global _oss_bucket, _oss_init_attempted, _oss_last_error
    if _oss_init_attempted:
        return _oss_bucket
    _oss_init_attempted = True

    try:
        import oss2
        if not (OSS_ACCESS_KEY_ID and OSS_ACCESS_KEY_SECRET):
            failure = "OSS_ACCESS_KEY_ID / OSS_ACCESS_KEY_SECRET 未配置"
        else:
            credentials = oss2.Auth(OSS_ACCESS_KEY_ID, OSS_ACCESS_KEY_SECRET)
            _oss_bucket = oss2.Bucket(credentials, f"https://{OSS_ENDPOINT}", OSS_BUCKET_NAME)
            _oss_last_error = ""
            print(f"[OSS] Bucket initialized: {OSS_BUCKET_NAME} @ {OSS_ENDPOINT}")
            return _oss_bucket
    except ImportError:
        failure = "oss2 未安装,请执行: pip install oss2 --break-system-packages"
    except Exception as e:
        failure = f"OSS Bucket 初始化失败: {e}"

    # Every failure path records and reports the reason the same way.
    _oss_last_error = failure
    print(f"[ERROR] {_oss_last_error}")
    return None
def optimize_crop_for_search(crop_path: str, output_path: str = None,
                             max_size: int = SEARCH_CROP_MAX_SIZE,
                             quality: int = SEARCH_CROP_JPEG_QUALITY) -> str:
    """Optimize a crop for searching: shrink it and re-encode as JPEG.

    Images in modes the JPEG writer cannot store (e.g. RGBA, LA, P) are
    converted to RGB first; without this the save fails and the function
    would silently fall back to the unoptimized original.

    Args:
        crop_path: Path of the original crop.
        output_path: Where to save the optimized image (defaults to the
            same directory with an ``_opt`` suffix before the extension).
        max_size: Maximum edge length in pixels.
        quality: JPEG quality.

    Returns:
        Path of the optimized image, or ``crop_path`` if optimization failed.
    """
    if output_path is None:
        base, ext = os.path.splitext(crop_path)
        output_path = f"{base}_opt{ext}"
    try:
        with Image.open(crop_path) as img:
            # Convert before resizing so palette images are resampled in RGB.
            if img.mode not in ("1", "L", "RGB", "CMYK", "YCbCr"):
                img = img.convert("RGB")
            w, h = img.size
            longest_edge = max(w, h)
            if longest_edge > max_size:
                ratio = max_size / longest_edge
                # Clamp to 1px: an extreme aspect ratio would otherwise round
                # the short edge down to 0 and make resize() fail.
                new_w = max(1, int(w * ratio))
                new_h = max(1, int(h * ratio))
                img = img.resize((new_w, new_h), Image.LANCZOS)
            img.save(output_path, 'JPEG', quality=quality)
        return output_path
    except Exception as e:
        print(f" [WARN] optimize_crop_for_search failed: {e}, using original")
        return crop_path
def _upload_to_oss(local_path: str, oss_object_name: str = None) -> Optional[str]:
    """Upload an image to Alibaba Cloud OSS and return its public URL.

    The image is optimized first. Unless a name is given, the object name is
    derived from the MD5 of the optimized bytes, so identical content is
    not uploaded twice.

    Args:
        local_path: Local image path.
        oss_object_name: OSS object name; defaults to an MD5-based name.

    Returns:
        The public URL, or ``None`` on failure (the reason is stored in
        ``_oss_last_error`` for upload errors).
    """
    global _oss_last_error
    oss_bucket = _get_oss_bucket()
    if oss_bucket is None:
        return None

    try:
        # Shrink / re-encode before hashing so the hash matches what is stored.
        optimized_path = optimize_crop_for_search(local_path)
        with open(optimized_path, "rb") as fh:
            payload = fh.read()
        digest = hashlib.md5(payload).hexdigest()

        object_key = oss_object_name
        if object_key is None:
            object_key = f"{OSS_UPLOAD_PREFIX}/{digest}.jpg"
        url = f"https://{OSS_BUCKET_NAME}.{OSS_ENDPOINT}/{object_key}"

        # A successful head_object means the object was uploaded before.
        try:
            oss_bucket.head_object(object_key)
        except Exception:
            # Missing object or any other probe error: go ahead and upload.
            already_uploaded = False
        else:
            already_uploaded = True

        if already_uploaded:
            print(f" [OSS] Cache hit: {url}")
            return url

        oss_bucket.put_object(object_key, payload, headers={
            'Content-Type': 'image/jpeg',
        })
        print(f" [OSS] Uploaded: {url}")
        return url
    except Exception as e:
        _oss_last_error = f"OSS upload failed: {e}"
        print(f" [WARN] {_oss_last_error}")
        return None
def _prepare_image_search_url(
    image_b64_or_path: str,
    crop_path: Optional[str],
    log_prefix: str,
) -> Tuple[Optional[str], Optional[str]]:
    """Prepare the URL handed to the image search; OSS upload is the default.

    Args:
        image_b64_or_path: Image payload embedded verbatim in a base64 data
            URI when the fallback is used.
            NOTE(review): despite the name, the value is never read as a
            path here — confirm callers always pass base64.
        crop_path: Local crop file to upload to OSS; may be ``None``.
        log_prefix: Tag used in log lines (e.g. ``"IMAGE_SEARCH"``).

    Returns:
        ``(url, None)`` on success, or ``(None, error_message)`` when neither
        an OSS URL nor the base64 fallback is available. The error's root
        cause reflects this call only: the OSS failure when an upload was
        attempted, or the missing crop file when it was not.
    """
    if crop_path and os.path.exists(crop_path):
        image_url = _upload_to_oss(crop_path)
        if image_url:
            print(f" [{log_prefix}] Using OSS URL: {image_url}")
            return image_url, None
        failure_reason = _oss_last_error or "OSS upload returned None"
    else:
        # No upload was attempted, so _oss_last_error (possibly left over
        # from an earlier call) must not be reported as this call's cause.
        failure_reason = f"crop image not available for upload (crop_path={crop_path!r})"

    if IMAGE_SEARCH_ALLOW_BASE64_FALLBACK:
        image_url = f"data:image/jpeg;base64,{image_b64_or_path}"
        print(f" [{log_prefix}] WARNING: OSS upload failed, falling back to "
              f"base64 data URI (len={len(image_b64_or_path)}, may get 400)")
        return image_url, None

    return None, (
        "Image search unavailable: OSS upload failed and base64 fallback is disabled. "
        f"Root cause: {failure_reason}"
    )