| |
| """ |
| Video DeepResearch 公共工具模块。 |
| |
| 输入: |
| - `config.py` 导出的模型端点、Vertex 配置、本地检索服务地址与价格参数。 |
| - 视频帧、裁剪图、LLM messages、搜索请求与各阶段中间结果。 |
| |
| 处理: |
| - 提供帧采样、图片编码、搜索访问、输出清洗、token 统计等通用能力。 |
| - 在 `LLMClient` 中实现多 endpoint 轮询、429 冷却切换、Vertex 原生调用与 OAuth2 认证。 |
| - 支持“一个 project 对应一个 service account json”的 Vertex 凭证池,并按 URL 选择对应凭证。 |
| |
| 输出: |
| - 为 Phase1/Phase2 提供统一的工具函数、搜索函数、token 统计结构和 `LLMClient`。 |
| - 返回规范化的模型响应、搜索结果、图片路径与 token 使用信息。 |
| """ |
|
|
| import asyncio |
| import aiohttp |
| import base64 |
| import hashlib |
| import itertools |
| import json |
| import math |
| import os |
| import random |
| import re |
| import subprocess |
| import time |
| import numpy as np |
| from PIL import Image |
| from pathlib import Path |
| from typing import Dict, List, Tuple, Optional, Any, Set |
| from dataclasses import dataclass, field |
|
|
| from config import ( |
| API_ENDPOINTS, DEFAULT_MODEL, API_KEY, |
| WEB_SEARCH_ADDRESS, WEB_SEARCH_CONFIG, |
| SERPER_API_KEY, MOCK_SEARCH, IMAGE_SEARCH_CACHE_FILE, |
| DEFAULT_TEMPERATURE, DEFAULT_TOP_P, DEFAULT_MAX_TOKENS, |
| DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES, |
| VERTEX_MIN_REQUEST_INTERVAL_SECONDS, |
| VERTEX_RATE_LIMIT_COOLDOWN_SECONDS, |
| VERTEX_REQUEST_JITTER_SECONDS, |
| TOKEN_PRICING, THINKING_TOKEN_PRICING, |
| DEFAULT_INPUT_PRICE_PER_1M, DEFAULT_OUTPUT_PRICE_PER_1M, |
| OSS_ACCESS_KEY_ID, OSS_ACCESS_KEY_SECRET, |
| OSS_ENDPOINT, OSS_BUCKET_NAME, OSS_UPLOAD_PREFIX, |
| SEARCH_CROP_MAX_SIZE, SEARCH_CROP_JPEG_QUALITY, |
| IMAGE_SEARCH_SUMMARIZE_SERPER, IMAGE_SEARCH_SUMMARIZER_ADDRESS, |
| IMAGE_SEARCH_SUMMARIZER_MODEL, IMAGE_SEARCH_SUMMARIZER_MAX_RESULTS, |
| IMAGE_SEARCH_SUMMARIZER_MAX_TOKENS, |
| IMAGE_SEARCH_MODE, GATEWAY_URL, GATEWAY_USERNAME, GATEWAY_USERID, GATEWAY_TOKEN, |
| IMAGE_SEARCH_ALLOW_BASE64_FALLBACK, |
| ) |
| from config import BBOX_CONFIGS |
| import config |
| |
# Optional settings are read with getattr() so that a config.py lacking them
# still imports; the literal fallback is used only when the attribute is missing.
MARS_RETRIEVAL_ADDRESS = getattr(config, 'MARS_RETRIEVAL_ADDRESS', '')
MARS_SUMMARIZER_ADDRESS = getattr(config, 'MARS_SUMMARIZER_ADDRESS', '')
MARS_RETRIEVAL_TOPK = getattr(config, 'MARS_RETRIEVAL_TOPK', 3)
MARS_SUMMARIZER_MODEL = getattr(config, 'MARS_SUMMARIZER_MODEL', '')
MARS_WEB_SEARCH_MODE = getattr(config, 'MARS_WEB_SEARCH_MODE', 'serper')
MARS_RETRIEVAL_TIMEOUT = getattr(config, 'MARS_RETRIEVAL_TIMEOUT', 120)
MARS_RETRIEVAL_CONCURRENCY = getattr(config, 'MARS_RETRIEVAL_CONCURRENCY', 0)
# NOTE(review): the five IMAGE_SEARCH_* names below are also imported by name in
# the `from config import (...)` statement at the top of this module. A config.py
# that lacks any of them fails at that import, so these getattr() fallbacks
# (including the MARS_SUMMARIZER_* defaults) never actually take effect.
IMAGE_SEARCH_SUMMARIZE_SERPER = getattr(config, 'IMAGE_SEARCH_SUMMARIZE_SERPER', True)
IMAGE_SEARCH_SUMMARIZER_ADDRESS = getattr(config, 'IMAGE_SEARCH_SUMMARIZER_ADDRESS', MARS_SUMMARIZER_ADDRESS)
IMAGE_SEARCH_SUMMARIZER_MODEL = getattr(config, 'IMAGE_SEARCH_SUMMARIZER_MODEL', MARS_SUMMARIZER_MODEL)
IMAGE_SEARCH_SUMMARIZER_MAX_RESULTS = getattr(config, 'IMAGE_SEARCH_SUMMARIZER_MAX_RESULTS', 5)
IMAGE_SEARCH_SUMMARIZER_MAX_TOKENS = getattr(config, 'IMAGE_SEARCH_SUMMARIZER_MAX_TOKENS', 512)


# Vertex AI / GCP settings; empty defaults when config.py does not define them.
GCP_PROJECT_ID = getattr(config, 'GCP_PROJECT_ID', '')
GCP_LOCATION = getattr(config, 'GCP_LOCATION', '')
GCP_SERVICE_ACCOUNT_KEY = getattr(config, 'GCP_SERVICE_ACCOUNT_KEY', '')
VERTEX_CREDENTIALS_POOL = getattr(config, 'VERTEX_CREDENTIALS_POOL', [])


from prompts import PHASE1_SYSTEM_PROMPT, PHASE2_SYSTEM_PROMPT, SYSTEM_PROMPT_BASE


# Generic alias for the Phase 1 system prompt (presumably kept for callers that
# import SYSTEM_PROMPT by that name; verify before removing).
SYSTEM_PROMPT = PHASE1_SYSTEM_PROMPT
|
|
|
|
| |
| |
| |
|
|
def _get_pricing(model: str) -> Tuple[float, float, float]:
    """Return ``(input_price, output_price, thinking_price)`` in USD per 1M tokens.

    Both pricing tables are matched by case-insensitive substring against the
    model name, and the LONGEST matching pattern wins. For example,
    'gemini-2.5-flash' beats 'gemini-2.5' for 'gemini-2.5-flash-preview-05-20'.
    This makes the result independent of dict insertion order.

    Fallbacks:
      - no ``TOKEN_PRICING`` match: ``DEFAULT_INPUT_PRICE_PER_1M`` /
        ``DEFAULT_OUTPUT_PRICE_PER_1M``;
      - no ``THINKING_TOKEN_PRICING`` match (or a ``None`` entry): the
        resolved output price.

    Args:
        model: model name; ``None`` or ``""`` yields the defaults.

    Returns:
        Tuple of three prices in USD per 1M tokens.
    """
    model_lower = (model or "").lower()

    def _longest_match(table):
        # Pick the value of the longest pattern contained in the model name.
        # An empty pattern never wins (length 0 is not > 0).
        best_pattern, best_value = "", None
        for pattern, value in table.items():
            if pattern.lower() in model_lower and len(pattern) > len(best_pattern):
                best_pattern, best_value = pattern, value
        return best_value

    io_prices = _longest_match(TOKEN_PRICING)
    if io_prices is None:
        input_price = DEFAULT_INPUT_PRICE_PER_1M
        output_price = DEFAULT_OUTPUT_PRICE_PER_1M
    else:
        input_price, output_price = io_prices

    # Previously this used the first match in dict order; longest match keeps
    # it consistent with the input/output lookup above.
    thinking_price = _longest_match(THINKING_TOKEN_PRICING)
    if thinking_price is None:
        thinking_price = output_price

    return input_price, output_price, thinking_price
|
|
|
|
@dataclass
class TokenUsage:
    """Token counts reported for a single LLM call."""
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0
    # Reasoning / thinking tokens, when the API reports them.
    thinking_tokens: int = 0
    # Prompt tokens served from cache, when the API reports them.
    cached_tokens: int = 0

    def to_dict(self) -> Dict[str, int]:
        """Return the five counters as a plain dict, in field order."""
        keys = ("prompt_tokens", "completion_tokens", "total_tokens",
                "thinking_tokens", "cached_tokens")
        return {key: getattr(self, key) for key in keys}

    @staticmethod
    def from_api_response(usage_data: Dict[str, Any]) -> "TokenUsage":
        """Build a ``TokenUsage`` from an OpenAI-compatible ``usage`` payload.

        Accepted layouts:
          - standard: ``prompt_tokens`` / ``completion_tokens`` / ``total_tokens``;
          - alternative names: ``input_tokens`` / ``output_tokens``;
          - thinking: ``completion_tokens_details.reasoning_tokens`` or
            ``.thinking_tokens``, otherwise the same keys at the top level;
          - cache: ``prompt_tokens_details.cached_tokens``, otherwise a
            top-level ``cached_tokens``.

        For each counter, the first truthy value among its candidate keys is
        used, else 0. ``total_tokens`` falls back to prompt + completion.
        A falsy payload gives an all-zero ``TokenUsage``.
        """
        if not usage_data:
            return TokenUsage()

        def first_truthy(source, *keys):
            for key in keys:
                value = source.get(key, 0)
                if value:
                    return value
            return 0

        prompt = first_truthy(usage_data, "prompt_tokens", "input_tokens")
        completion = first_truthy(usage_data, "completion_tokens", "output_tokens")
        total = usage_data.get("total_tokens", 0) or (prompt + completion)

        completion_details = usage_data.get("completion_tokens_details", {})
        thinking = 0
        if isinstance(completion_details, dict):
            thinking = first_truthy(completion_details, "reasoning_tokens", "thinking_tokens")
        if not thinking:
            thinking = first_truthy(usage_data, "reasoning_tokens", "thinking_tokens")

        prompt_details = usage_data.get("prompt_tokens_details", {})
        cached = first_truthy(prompt_details, "cached_tokens") if isinstance(prompt_details, dict) else 0
        if not cached:
            cached = first_truthy(usage_data, "cached_tokens")

        return TokenUsage(
            prompt_tokens=prompt,
            completion_tokens=completion,
            total_tokens=total,
            thinking_tokens=thinking,
            cached_tokens=cached,
        )
|
|
|
|
@dataclass
class TokenTracker:
    """Accumulate token usage over every LLM call made for one data entry.

    Running totals for prompt, completion, thinking and cached tokens are kept,
    alongside a per-call log in ``call_details``. ``estimate_cost`` converts
    the totals to USD using ``_get_pricing``.
    """
    model: str = ""
    total_prompt_tokens: int = 0
    total_completion_tokens: int = 0
    total_thinking_tokens: int = 0
    total_cached_tokens: int = 0
    num_calls: int = 0
    call_details: List[Dict[str, Any]] = field(default_factory=list)

    def add(self, usage: TokenUsage, call_label: str = ""):
        """Fold one call's usage into the totals and log it under ``call_label``."""
        self.num_calls += 1
        self.total_prompt_tokens += usage.prompt_tokens
        self.total_completion_tokens += usage.completion_tokens
        self.total_thinking_tokens += usage.thinking_tokens
        self.total_cached_tokens += usage.cached_tokens
        record = {"label": call_label}
        record.update(usage.to_dict())
        self.call_details.append(record)

    @property
    def total_tokens(self) -> int:
        """Prompt plus completion tokens; thinking is counted inside completion."""
        return self.total_prompt_tokens + self.total_completion_tokens

    def estimate_cost(self, model: str = "") -> Dict[str, float]:
        """Estimate the USD cost of the accumulated tokens.

        Args:
            model: model name used for pricing; falls back to ``self.model``.

        Returns:
            Dict with ``input_cost_usd``, ``output_cost_usd``,
            ``thinking_cost_usd`` and ``total_cost_usd``, each rounded to
            6 decimals. Thinking tokens are treated as part of the completion
            count and billed at the thinking rate. The remaining completion
            tokens are billed at the output rate.
        """
        rate_in, rate_out, rate_think = _get_pricing(model or self.model)
        billed_output = max(0, self.total_completion_tokens - self.total_thinking_tokens)
        cost_in = self.total_prompt_tokens * rate_in / 1_000_000
        cost_out = billed_output * rate_out / 1_000_000
        cost_think = self.total_thinking_tokens * rate_think / 1_000_000
        return {
            "input_cost_usd": round(cost_in, 6),
            "output_cost_usd": round(cost_out, 6),
            "thinking_cost_usd": round(cost_think, 6),
            "total_cost_usd": round(cost_in + cost_out + cost_think, 6),
        }

    def to_dict(self, model: str = "") -> Dict[str, Any]:
        """Export totals, the cost estimate and the per-call log as a dict."""
        estimate = self.estimate_cost(model)
        exported = {
            "num_llm_calls": self.num_calls,
            "total_prompt_tokens": self.total_prompt_tokens,
            "total_completion_tokens": self.total_completion_tokens,
            "total_thinking_tokens": self.total_thinking_tokens,
            "total_cached_tokens": self.total_cached_tokens,
            "total_tokens": self.total_tokens,
        }
        exported["estimated_cost"] = estimate
        exported["call_details"] = self.call_details
        return exported

    def summary_str(self, model: str = "") -> str:
        """Return a one-line, human-readable usage and cost summary."""
        total_usd = self.estimate_cost(model)['total_cost_usd']
        parts = [
            f"calls={self.num_calls} ",
            f"prompt={self.total_prompt_tokens:,} ",
            f"completion={self.total_completion_tokens:,} ",
            f"thinking={self.total_thinking_tokens:,} ",
            f"total={self.total_tokens:,} ",
            f"cost=${total_usd:.4f}",
        ]
        return "".join(parts)
|
|
|
|
class GlobalTokenStats:
    """Aggregate per-entry ``TokenTracker`` totals across a whole run.

    ``add`` is a coroutine guarded by an ``asyncio.Lock``. That lock
    serializes coroutines running on one event loop. It is NOT a thread lock,
    so do not call ``add`` concurrently from several threads or event loops.
    """

    def __init__(self, model: str = ""):
        self.model = model
        self.total_prompt_tokens = 0
        self.total_completion_tokens = 0
        self.total_thinking_tokens = 0
        self.total_cached_tokens = 0
        self.total_calls = 0
        self.total_entries = 0
        self._lock = asyncio.Lock()

    async def add(self, tracker: TokenTracker):
        """Fold one entry's tracker totals into the run-wide totals (one entry)."""
        async with self._lock:
            self.total_prompt_tokens += tracker.total_prompt_tokens
            self.total_completion_tokens += tracker.total_completion_tokens
            self.total_thinking_tokens += tracker.total_thinking_tokens
            self.total_cached_tokens += tracker.total_cached_tokens
            self.total_calls += tracker.num_calls
            self.total_entries += 1

    @property
    def total_tokens(self) -> int:
        """Prompt plus completion tokens (thinking is counted inside completion)."""
        return self.total_prompt_tokens + self.total_completion_tokens

    def _raw_costs(self) -> Tuple[float, float, float]:
        """Unrounded (input, output, thinking) cost in USD.

        Thinking tokens are billed at the thinking rate. Completion tokens
        minus thinking tokens are billed at the output rate.
        """
        input_price, output_price, thinking_price = _get_pricing(self.model)
        input_cost = self.total_prompt_tokens * input_price / 1_000_000
        non_thinking = max(0, self.total_completion_tokens - self.total_thinking_tokens)
        output_cost = non_thinking * output_price / 1_000_000
        thinking_cost = self.total_thinking_tokens * thinking_price / 1_000_000
        return input_cost, output_cost, thinking_cost

    def _avg_cost_per_entry(self) -> float:
        # Averaged from the UNROUNDED total. Dividing the 4-decimal rounded
        # total lost precision, e.g. a small run could report $0 per entry.
        input_cost, output_cost, thinking_cost = self._raw_costs()
        return (input_cost + output_cost + thinking_cost) / max(1, self.total_entries)

    def estimate_cost(self) -> Dict[str, float]:
        """Return run-wide costs in USD, each rounded to 4 decimals."""
        input_cost, output_cost, thinking_cost = self._raw_costs()
        return {
            "input_cost_usd": round(input_cost, 4),
            "output_cost_usd": round(output_cost, 4),
            "thinking_cost_usd": round(thinking_cost, 4),
            "total_cost_usd": round(input_cost + output_cost + thinking_cost, 4),
        }

    def summary_str(self) -> str:
        """Return a multi-line, human-readable summary of the whole run."""
        cost = self.estimate_cost()
        avg_per_entry = self.total_tokens / max(1, self.total_entries)
        avg_cost = self._avg_cost_per_entry()
        return (
            f"\n{'=' * 72}\n"
            f" Token Usage Summary\n"
            f"{'=' * 72}\n"
            f" Model : {self.model}\n"
            f" Total entries : {self.total_entries}\n"
            f" Total LLM calls : {self.total_calls}\n"
            f" Total prompt tok : {self.total_prompt_tokens:,}\n"
            f" Total completion : {self.total_completion_tokens:,}\n"
            f" Total thinking : {self.total_thinking_tokens:,}\n"
            f" Total cached : {self.total_cached_tokens:,}\n"
            f" Total tokens : {self.total_tokens:,}\n"
            f" ──────────────────────────────────────\n"
            f" Avg tokens/entry : {avg_per_entry:,.0f}\n"
            f" Avg cost/entry : ${avg_cost:.4f}\n"
            f" ──────────────────────────────────────\n"
            f" Input cost : ${cost['input_cost_usd']:.4f}\n"
            f" Output cost : ${cost['output_cost_usd']:.4f}\n"
            f" Thinking cost : ${cost['thinking_cost_usd']:.4f}\n"
            f" TOTAL COST : ${cost['total_cost_usd']:.4f}\n"
            f"{'=' * 72}"
        )

    def to_dict(self) -> Dict[str, Any]:
        """Export run-wide totals, averages and the cost estimate as a dict."""
        cost = self.estimate_cost()
        total_tokens = self.total_tokens
        return {
            "model": self.model,
            "total_entries": self.total_entries,
            "total_llm_calls": self.total_calls,
            "total_prompt_tokens": self.total_prompt_tokens,
            "total_completion_tokens": self.total_completion_tokens,
            "total_thinking_tokens": self.total_thinking_tokens,
            "total_cached_tokens": self.total_cached_tokens,
            "total_tokens": total_tokens,
            "avg_tokens_per_entry": round(total_tokens / max(1, self.total_entries)),
            "avg_cost_per_entry_usd": round(self._avg_cost_per_entry(), 6),
            "estimated_cost": cost,
        }
|
|
|
|
| |
| |
| |
|
|
def think_is_nonempty(text: str, min_alnum_chars: int = 5) -> bool:
    """Return True if a think block carries at least minimal real content.

    ``<think>`` / ``</think>`` tags are ignored. The remaining text must
    contain at least ``min_alnum_chars`` alphanumeric characters.
    ``str.isalnum`` is Unicode-aware, so reasoning written in any script
    (e.g. Chinese) is counted. An ASCII-only filter would treat such text as
    empty. For ASCII input the count is identical to ``[A-Za-z0-9]``.

    Args:
        text: candidate think text (may include the tags); ``None``/``""``
            returns False.
        min_alnum_chars: minimum number of alphanumeric characters required.

    Returns:
        True when enough alphanumeric content is present.
    """
    if not text:
        return False
    body = re.sub(r'</?think>', '', text)
    return sum(1 for ch in body if ch.isalnum()) >= min_alnum_chars
|
|
def extract_think_text(text: str) -> str:
    """Return the stripped body of the first ``<think>...</think>`` block, or ''."""
    found = re.search(r'<think>(.*?)</think>', text, re.DOTALL)
    if found is None:
        return ''
    return found.group(1).strip()
|
|
|
|
| def _dedup_think_content(text: str) -> str: |
| """Remove paragraph-level repetition inside a <think> block. |
| |
| LLMs sometimes repeat entire paragraphs verbatim within a single turn |
| (decoding-level repetition / "repetition hallucination"). |
| This function detects and removes such duplicates while preserving |
| unique content and ordering. |
| |
| Strategy: split by double-newline into paragraphs, keep only the first |
| occurrence of each paragraph (compared after whitespace normalisation). |
| Also detects the case where the entire content is duplicated as one |
| contiguous block (no blank-line separator between the copies). |
| """ |
| if not text or not text.strip(): |
| return text |
|
|
| |
| paragraphs = re.split(r'\n\s*\n', text.strip()) |
| if len(paragraphs) >= 2: |
| seen = set() |
| unique = [] |
| for p in paragraphs: |
| key = ' '.join(p.split()) |
| if key and key not in seen: |
| seen.add(key) |
| unique.append(p) |
| if len(unique) < len(paragraphs): |
| return '\n\n'.join(unique) |
|
|
| |
| |
| stripped = text.strip() |
| length = len(stripped) |
| if length >= 80: |
| |
| mid = length // 2 |
| for offset in range(0, min(40, mid)): |
| for pos in (mid + offset, mid - offset): |
| if pos <= 0 or pos >= length: |
| continue |
| if stripped[pos] != '\n': |
| continue |
| first_half = stripped[:pos].strip() |
| second_half = stripped[pos:].strip() |
| if first_half == second_half: |
| return first_half |
|
|
| return text |
|
|
|
|
| |
| |
| |
|
|
def sanitize_llm_output(text: str) -> str:
    """Cut an LLM turn down to its reasoning plus the FIRST action block.

    Hallucinated continuations removed:
      - everything from the first ``<tool_response>`` onward (tool results
        must only come from the system);
      - a trailing ``---`` ... ``MODERATION:`` block (case-insensitive);
      - anything after the first complete ``<tool_call>...</tool_call>`` or
        ``<answer>...</answer>``, whichever starts first. This drops extra
        tool calls and a tool_call + answer combination.

    Empty or whitespace-only input is returned unchanged. Otherwise the result
    is stripped.
    """
    if not text or not text.strip():
        return text

    cleaned = text.strip()

    # 1. A fabricated tool response ends the model's legitimate output.
    response_at = cleaned.find('<tool_response>')
    if response_at != -1:
        cleaned = cleaned[:response_at].strip()

    # 2. Strip a fake moderation footer introduced by a '---' rule.
    cleaned = re.sub(
        r'\n*---+\s*\n*MODERATION:.*', '', cleaned,
        flags=re.DOTALL | re.IGNORECASE,
    ).strip()

    # 3. Keep text only up to the end of the earliest complete action block.
    spans = [
        match.span()
        for pattern in (r'<tool_call>.*?</tool_call>', r'<answer>.*?</answer>')
        for match in re.finditer(pattern, cleaned, re.DOTALL)
    ]
    if not spans:
        return cleaned
    _, first_end = min(spans)
    return cleaned[:first_end].strip()
|
|
|
|
def is_hallucinated_output(text: str) -> bool:
    """Return True when an LLM turn shows hallucination markers.

    Markers: more than one ``<tool_call>`` opening tag, any ``<tool_response>``
    (only the system may produce one), a ``MODERATION:`` marker
    (case-insensitive), or both ``<tool_call>`` and ``<answer>`` in one turn.
    Empty input is never flagged.
    """
    if not text:
        return False
    if text.count('<tool_call>') > 1:
        return True
    if '<tool_response>' in text:
        return True
    if re.search(r'MODERATION:', text, re.IGNORECASE):
        return True
    return '<tool_call>' in text and '<answer>' in text
|
|
|
|
| |
| |
| |
|
|
def normalize_gpt_output(text: str) -> str:
    """Coerce one assistant ("gpt") turn into the strict training format.

    Target shape::

        <think>...</think> followed by <tool_call>...</tool_call> or <answer>...</answer>

    Steps:
      1. Collapse repeated tags: keep only the first ``</think>``, then only
         the first ``<think>``.
      2. If ``<think>`` is present and ``</think>`` is missing, insert
         ``</think>`` right before the EARLIEST action tag, or append it when
         there is no action. Then remove paragraph-level repetition inside the
         think block (``_dedup_think_content``).
      3. If ``<think>`` is absent, stray ``</think>`` tags are dropped from the
         pre-action text. That text becomes the think block only if
         ``think_is_nonempty`` accepts it; otherwise the action is returned
         alone. Text with no action at all is wrapped entirely in a think
         block.
      4. ZERO FILLER: a missing or empty think block is never replaced with
         fabricated fallback text.

    This function does NOT truncate hallucinated continuations. Callers that
    need that should run ``sanitize_llm_output`` first.

    Args:
        text: raw model output for one turn.

    Returns:
        The normalized turn. Empty or whitespace-only input is returned
        unchanged.
    """
    if not text or not text.strip():
        return text

    text = text.strip()

    # Keep only the first occurrence of each tag ('</think>' first, then
    # '<think>'). Neither tag is a substring of the other.
    for tag in ('</think>', '<think>'):
        if text.count(tag) > 1:
            cut = text.index(tag) + len(tag)
            text = text[:cut] + text[cut:].replace(tag, '')

    # Earliest action tag, or -1. Using the earliest (instead of always
    # preferring <tool_call>) keeps an earlier <answer> out of the think block.
    action_positions = [pos for pos in (text.find('<tool_call>'), text.find('<answer>')) if pos != -1]
    first_action = min(action_positions) if action_positions else -1

    if '<think>' in text:
        if '</think>' not in text:
            if first_action != -1:
                text = text[:first_action] + '</think>\n\n' + text[first_action:]
            else:
                text = text + '</think>'
        think_text = extract_think_text(text)
        if think_text:
            deduped = _dedup_think_content(think_text)
            if deduped != think_text:
                text = text.replace(think_text, deduped, 1)
        return text

    # No opening tag. A bare '</think>' here (e.g. when the opening tag was
    # emitted outside the captured output; assumption) must be dropped before
    # wrapping, otherwise the result would read '<think>...</think></think>'.
    if first_action != -1:
        pre_text = text[:first_action].replace('</think>', '').strip()
        action_text = text[first_action:]
        if pre_text and think_is_nonempty(pre_text):
            return f"<think>{pre_text}</think>\n\n{action_text}"
        return action_text

    body = text.replace('</think>', '').strip()
    return f"<think>{body}</think>"
|
|
|
|
|
|
| |
| |
| |
|
|
def clean_conversations(conversations: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Strip error-recovery turns to obtain a clean SFT trajectory.

    Rules, applied to every turn independently:
      1. ``gpt`` turns that are empty or whitespace-only are dropped.
      2. ``gpt`` turns containing neither ``<tool_call>`` nor ``<answer>``
         are dropped.
      3. Every remaining ``gpt`` turn is re-normalized with
         ``normalize_gpt_output`` and emitted as a new
         ``{"from": "gpt", "value": ...}`` dict. This includes a gpt turn that
         immediately follows a dropped one.
      4. ``human`` turns recognized by ``_is_error_tool_response`` are
         dropped; other ``human`` turns are kept as-is.
      5. Turns with any other role are kept as-is.

    Args:
        conversations: list of ``{"from": role, "value": text}`` dicts.

    Returns:
        A new list of turns; falsy input is returned unchanged.
    """
    if not conversations:
        return conversations

    cleaned = []
    for turn in conversations:
        role = turn.get("from", "")
        value = turn.get("value", "")

        if role == "gpt":
            stripped = value.strip() if value else ""
            # Also covers the empty case: "" contains neither action tag.
            if '<tool_call>' not in stripped and '<answer>' not in stripped:
                continue
            cleaned.append({"from": "gpt", "value": normalize_gpt_output(stripped)})
        elif role == "human":
            if not _is_error_tool_response(value):
                cleaned.append(turn)
        else:
            cleaned.append(turn)

    return cleaned
|
|
|
|
| def _is_error_tool_response(value: str) -> bool: |
| """Check if a human turn is an error tool_response that should be removed.""" |
| if not value: |
| return False |
| v = value.strip() |
| error_markers = [ |
| "No <tool_call> or <answer> found", |
| "Error: Tool", |
| "Malformed tool_call JSON", |
| "Error: bbox must have exactly", |
| "Error: web_search requires a non-empty query", |
| ] |
| for marker in error_markers: |
| if marker in v: |
| return True |
| return False |
|
|
|
|
| |
| |
| |
|
|
def get_video_frame_count(video_path: str) -> int:
    """Return the frame count of a 1 fps video, or -1 on any failure.

    ffprobe reports the container duration in seconds. For a 1 fps video,
    ceil(duration) equals the number of frames. Any error (ffprobe missing,
    timeout after 30 s, unparsable output) yields -1.
    """
    probe_cmd = [
        "ffprobe", "-v", "error", "-show_entries", "format=duration",
        "-of", "csv=p=0", video_path,
    ]
    try:
        probe = subprocess.run(probe_cmd, capture_output=True, text=True, timeout=30)
        duration_s = float(probe.stdout.strip())
        return int(math.ceil(duration_s))
    except Exception:
        return -1
|
|
|
|
_FRAME_FILE_RE = re.compile(r'frame_(\d+)\.jpg')


def _scan_frame_dir(output_dir: str) -> Dict[int, str]:
    """Map frame index -> path for every ``frame_<digits>.jpg`` in output_dir.

    Files that merely start with 'frame_' and end with '.jpg' but have a
    non-numeric index are ignored instead of crashing ``int()``.
    """
    frames = {}
    for name in sorted(os.listdir(output_dir)):
        match = _FRAME_FILE_RE.fullmatch(name)
        if match:
            frames[int(match.group(1))] = os.path.join(output_dir, name)
    return frames


def _discard_frames(output_dir: str) -> None:
    """Best-effort removal of frame files left behind by a failed extraction."""
    for path in _scan_frame_dir(output_dir).values():
        try:
            os.remove(path)
        except OSError:
            # Cleanup precedes re-raising the real error; a file we cannot
            # delete must not mask that error.
            pass


def extract_all_frames(video_path: str, output_dir: str,
                       max_resolution: int = 768,
                       jpeg_quality: int = 85,
                       timeout: int = 600) -> Dict[int, str]:
    """Extract every frame of a 1 fps video into output_dir as JPEGs.

    If output_dir already contains frame files, they are reused and ffmpeg is
    not run. Frames are downscaled to fit within ``max_resolution`` (aspect
    ratio preserved) and written as ``frame_%06d.jpg`` starting at index 0.

    Args:
        video_path: input video.
        output_dir: destination directory (created if missing).
        max_resolution: maximum width/height in pixels.
        jpeg_quality: 1-100 quality, mapped onto ffmpeg's ``-q:v`` scale
            (100 -> 2, 1 -> 31).
        timeout: seconds before ffmpeg is killed.

    Returns:
        ``{frame_index (0-based): path}``.

    Raises:
        RuntimeError: ffmpeg exited non-zero, or produced no frames.
        subprocess.TimeoutExpired: ffmpeg exceeded ``timeout``.

    On ffmpeg failure or timeout, the partially written frames are deleted.
    Otherwise a later call would silently reuse an incomplete set. They can
    only come from this run, because existing frames cause an early return
    above. (Assumes no other process extracts into the same directory
    concurrently.)
    """
    os.makedirs(output_dir, exist_ok=True)

    existing = _scan_frame_dir(output_dir)
    if existing:
        return existing

    vf = (f"scale='min({max_resolution},iw)':'min({max_resolution},ih)'"
          f":force_original_aspect_ratio=decrease")
    q = max(2, min(31, round(2 + (100 - jpeg_quality) * 29 / 99)))

    cmd = [
        "ffmpeg", "-y", "-i", video_path,
        "-vf", vf, "-q:v", str(q),
        "-start_number", "0",
        os.path.join(output_dir, "frame_%06d.jpg"),
    ]
    try:
        proc = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)
    except subprocess.TimeoutExpired:
        _discard_frames(output_dir)
        raise
    if proc.returncode != 0:
        _discard_frames(output_dir)
        raise RuntimeError(f"ffmpeg failed: {proc.stderr[-500:]}")

    frames = _scan_frame_dir(output_dir)
    if not frames:
        raise RuntimeError(f"No frames extracted from {video_path}")
    return frames
|
|
|
|
def uniform_sample_indices(total_frames: int, num_samples: int) -> List[int]:
    """Pick up to ``num_samples`` evenly spaced 0-based indices in [0, total_frames).

    If there are no more frames than requested samples, every index is
    returned. Otherwise evenly spaced positions are floored to ints, with
    duplicates removed and the result sorted.
    """
    if num_samples >= total_frames:
        return list(range(total_frames))
    grid = np.linspace(0, total_frames - 1, num_samples)
    return sorted({int(point) for point in grid})
|
|
|
|
def sample_interval(all_frames: Dict[int, str], start: int, end: int,
                    num_samples: int = 8) -> List[Tuple[int, str]]:
    """Evenly sample ``(index, path)`` pairs from frames with start <= index <= end.

    Returns every frame in range when there are at most ``num_samples`` of
    them. Otherwise evenly spaced positions are floored, with duplicates
    removed, and the result is sorted by index. Returns [] when no frame
    falls in range.
    """
    in_range = sorted(idx for idx in all_frames if start <= idx <= end)
    if not in_range:
        return []
    if len(in_range) > num_samples:
        picks = np.linspace(0, len(in_range) - 1, num_samples, dtype=int)
        in_range = sorted({in_range[pos] for pos in picks})
    return [(idx, all_frames[idx]) for idx in in_range]
|
|
|
|
def get_frame(all_frames: Dict[int, str], idx: int) -> Tuple[int, str]:
    """Return ``(index, path)`` for frame ``idx``, or for the closest available index.

    Among equally distant candidates, the one appearing first in the dict
    wins. An empty ``all_frames`` raises ValueError (from ``min``).
    """
    if idx not in all_frames:
        requested = idx
        idx = min(all_frames, key=lambda candidate: abs(candidate - requested))
    return (idx, all_frames[idx])
|
|
|
|
| |
| |
| |
|
|
def get_bbox_config(model: str) -> dict:
    """Return the bbox output-format config for a model name.

    The longest ``BBOX_CONFIGS`` pattern (case-insensitive substring of the
    model name) wins. The ``"default"`` entry, or
    ``{"order": "xyxy", "range": "norm"}`` when absent, is used if nothing
    matches.

    Returns:
        e.g. ``{"order": "xyxy"|"yxyx", "range": "norm"|"permille"}``.
    """
    needle = model.lower()
    chosen_pattern = ""
    chosen = BBOX_CONFIGS.get("default", {"order": "xyxy", "range": "norm"})
    for pattern, candidate in BBOX_CONFIGS.items():
        if pattern == "default" or pattern.lower() not in needle:
            continue
        if len(pattern) > len(chosen_pattern):
            chosen_pattern, chosen = pattern, candidate
    return chosen
|
|
|
|
def normalize_bbox(raw_bbox: list, bbox_config: dict) -> list:
    """Convert a model-emitted bbox to ``[x1, y1, x2, y2]`` with values in [0.0, 1.0].

    Input interpretation:
      - ``bbox_config["order"] == "yxyx"``: input is ``[y1, x1, y2, x2]``
        (e.g. Gemini); otherwise ``[x1, y1, x2, y2]``.
      - ``bbox_config["range"] == "permille"``: every value is divided by 1000.
        Otherwise the scale is inferred from the largest absolute value:
        <= 1 means already normalized; (1, 1000] is treated as permille;
        > 1000 is left unscaled. Absolute pixels cannot be converted without
        the image size, so such boxes end up clamped to the frame edge.

    Post-processing: clamp to [0, 1], swap inverted corners, then give any
    axis narrower than 0.001 an extent of 0.01. The box is widened toward the
    far edge, or toward 0 when it already sits at 1.0, so the result never has
    zero area.

    Args:
        raw_bbox: sequence of 4 numbers.
        bbox_config: dict as returned by ``get_bbox_config``.

    Returns:
        ``[x1, y1, x2, y2]`` floats in [0, 1]. Returns the full frame
        ``[0.0, 0.0, 1.0, 1.0]`` when ``raw_bbox`` does not have exactly
        4 elements.
    """
    if len(raw_bbox) != 4:
        return [0.0, 0.0, 1.0, 1.0]

    if bbox_config.get("order") == "yxyx":
        y1, x1, y2, x2 = raw_bbox
    else:
        x1, y1, x2, y2 = raw_bbox

    is_permille = bbox_config.get("range") == "permille"
    if not is_permille:
        max_coord = max(abs(x1), abs(y1), abs(x2), abs(y2))
        is_permille = 1.0 < max_coord <= 1000
    if is_permille:
        x1, y1, x2, y2 = x1 / 1000.0, y1 / 1000.0, x2 / 1000.0, y2 / 1000.0

    def _clamp01(value):
        return max(0.0, min(1.0, float(value)))

    x1, y1, x2, y2 = _clamp01(x1), _clamp01(y1), _clamp01(x2), _clamp01(y2)

    if x1 > x2:
        x1, x2 = x2, x1
    if y1 > y2:
        y1, y2 = y2, y1

    def _ensure_extent(lo, hi):
        if hi - lo < 0.001:
            hi = min(1.0, lo + 0.01)
            # At the far edge (lo ~ 1.0) growing upward is impossible, so grow
            # downward; previously the box stayed zero-area here.
            if hi - lo < 0.001:
                lo = max(0.0, hi - 0.01)
        return lo, hi

    x1, x2 = _ensure_extent(x1, x2)
    y1, y2 = _ensure_extent(y1, y2)

    return [x1, y1, x2, y2]
|
|
| |
| |
| |
|
|
class RetrieverDownError(Exception):
    """Connecting to the retriever service (IP/port) failed; the whole pipeline should stop immediately."""
|
|
|
|
class ImageSearchFailedError(Exception):
    """An image_search call returned an error, timed out, or found no results.

    Signals that the whole entry should be retried from scratch, so that
    failed serper results never appear in the SFT dataset.
    """
|
|
|
|
class QuotaExhaustedError(Exception):
    """The API quota is genuinely exhausted (not a transient rate limit); stop the whole pipeline."""
|
|
|
|
class ProhibitedContentError(Exception):
    """Content was blocked by the Vertex AI safety policy (PROHIBITED_CONTENT); do not retry."""
|
|
|
|
class ProjectDisabledError(Exception):
    """A Vertex project is unusable and should be disabled in the current project pool."""
|
|
|
|
| def _is_quota_exhausted(status_code: int, error_body: str) -> bool: |
| """判断 API 错误是否为额度真正用尽 (区别于临时 rate limit)。 |
| |
| Vertex AI 的 429 / RESOURCE_EXHAUSTED 通常是瞬时限流(QPM/TPM 打满), |
| 应该重试而非停止。只有 403 + 明确的配额/账单关键词才认为是真正的额度用尽。 |
| |
| 参考: https://docs.cloud.google.com/docs/quotas/troubleshoot |
| - 429 → 临时限流,退避重试 |
| - 403 + QUOTA_EXCEEDED / RATE_LIMIT_EXCEEDED / billing → 真正额度用尽 |
| """ |
| body_lower = error_body.lower() |
|
|
| |
| if status_code == 429: |
| return False |
|
|
| |
| if status_code == 403: |
| quota_keywords = [ |
| "quota_exceeded", |
| "quota exceeded", |
| "rate_limit_exceeded", |
| "billing account", |
| "billing is disabled", |
| "insufficient quota", |
| "out of quota", |
| "daily limit", |
| "per-day limit", |
| ] |
| if any(kw in body_lower for kw in quota_keywords): |
| return True |
|
|
| return False |
|
|
|
|
def is_image_search_failed(result_text: str) -> bool:
    """Return True if a reverse-image-search result text signals failure.

    Empty or whitespace-only results count as failures, as does any text
    containing a known error, timeout, or no-results marker emitted for
    Serper Lens.
    """
    if not result_text or not result_text.strip():
        return True
    failure_markers = (
        "Image search error:",
        "No results found from reverse image search",
        "Error: SERPER_API_KEY not configured",
        "request timed out",
    )
    return any(marker in result_text for marker in failure_markers)
|
|
|
|
def add_search_padding(bbox: List[float], frame_path: str,
                       padding: tuple = (0.5, 0.5),
                       padding_cap_px: int = 600) -> List[float]:
    """Expand a normalized ``[x1, y1, x2, y2]`` bbox before image search.

    Each side grows by a fraction of the BBOX's own width/height, not the
    image's. A tight crop therefore gains useful context, while a large crop
    does not balloon to the whole frame. When the frame can be opened, the
    growth per side is also capped at ``padding_cap_px`` pixels. If opening
    fails, the cap is simply skipped (best effort).

    Args:
        bbox: ``[x1, y1, x2, y2]`` normalized to [0, 1].
        frame_path: source frame, used only to convert the pixel cap.
        padding: ``(pad_x_ratio, pad_y_ratio)`` as fractions of bbox width/height.
        padding_cap_px: maximum growth per side, in pixels.

    Returns:
        The padded bbox, clamped to [0, 1].
    """
    left, top, right, bottom = bbox
    ratio_x, ratio_y = padding

    grow_x = (right - left) * ratio_x
    grow_y = (bottom - top) * ratio_y

    try:
        with Image.open(frame_path) as frame:
            frame_w, frame_h = frame.size
            limit_x = padding_cap_px / frame_w
            limit_y = padding_cap_px / frame_h
            grow_x = min(grow_x, limit_x)
            grow_y = min(grow_y, limit_y)
    except Exception:
        # Deliberate best effort: without the frame size, keep the uncapped padding.
        pass

    return [
        max(0.0, left - grow_x),
        max(0.0, top - grow_y),
        min(1.0, right + grow_x),
        min(1.0, bottom + grow_y),
    ]
|
|
|
|
|
|
def crop_frame(frame_path: str, bbox: List[float], output_path: str,
               upscale_factor: int = 2, jpeg_quality: int = 95) -> str:
    """Crop a frame at bbox, upscale the crop, and save it as a JPEG for image search.

    Bbox format is inferred from the largest absolute coordinate:
      - <= 1.0: normalized coordinates;
      - <= 1000: permille coordinates (some models emit these);
      - > 1000: absolute pixels.
    Pixel boxes whose coordinates are all <= 1000 are indistinguishable from
    permille and are treated as permille.

    Coordinates are rounded, clamped to the image, and forced to at least
    1 px on each axis. Crops whose mode cannot be written as JPEG (e.g. RGBA,
    P, LA) are converted to RGB (alpha is dropped). Previously such crops
    made ``save`` raise, and P-mode crops were also resized with nearest
    neighbour. The crop is then upscaled with LANCZOS, which helps visual
    search on small objects.

    Args:
        frame_path: source image.
        bbox: ``[x1, y1, x2, y2]`` in any of the three formats above.
        output_path: destination JPEG path.
        upscale_factor: integer scale applied to the crop (1 disables it).
        jpeg_quality: JPEG quality passed to Pillow.

    Returns:
        ``output_path``.
    """
    with Image.open(frame_path) as img:
        w, h = img.size

        raw_x1, raw_y1, raw_x2, raw_y2 = bbox
        max_coord = max(abs(raw_x1), abs(raw_y1), abs(raw_x2), abs(raw_y2))

        if max_coord <= 1.0:
            px_x1, px_y1 = raw_x1 * w, raw_y1 * h
            px_x2, px_y2 = raw_x2 * w, raw_y2 * h
        elif max_coord <= 1000:
            px_x1, px_y1 = raw_x1 / 1000.0 * w, raw_y1 / 1000.0 * h
            px_x2, px_y2 = raw_x2 / 1000.0 * w, raw_y2 / 1000.0 * h
        else:
            px_x1, px_y1, px_x2, px_y2 = raw_x1, raw_y1, raw_x2, raw_y2

        x1 = max(0, min(int(round(px_x1)), w - 1))
        y1 = max(0, min(int(round(px_y1)), h - 1))
        x2 = max(0, min(int(round(px_x2)), w))
        y2 = max(0, min(int(round(px_y2)), h))

        # Guarantee a non-empty crop.
        if x2 <= x1:
            x2 = min(x1 + 1, w)
        if y2 <= y1:
            y2 = min(y1 + 1, h)

        cropped_img = img.crop((x1, y1, x2, y2))

        # JPEG cannot store alpha/palette modes; convert before resizing so
        # LANCZOS applies (Pillow forces NEAREST when resizing "P"/"1" images).
        if cropped_img.mode not in ("L", "RGB", "RGBX", "CMYK", "YCbCr"):
            cropped_img = cropped_img.convert("RGB")

        if upscale_factor != 1:
            cropped_img = cropped_img.resize(
                (cropped_img.width * upscale_factor, cropped_img.height * upscale_factor),
                Image.Resampling.LANCZOS,
            )

        cropped_img.save(output_path, 'JPEG', quality=jpeg_quality)

    return output_path
|
|
|
|
|
|
def encode_image_b64(path: str) -> str:
    """Return the raw bytes of the file at *path* as an ASCII base64 string.

    No ``data:`` URI prefix is added; callers build that themselves.
    """
    raw_bytes = Path(path).read_bytes()
    return base64.b64encode(raw_bytes).decode("ascii")
|
|
|
|
| |
| |
| |
|
|
class LLMClient:
    """Async LLM client with round-robin load balancing across API endpoints.

    Two wire formats are supported:
      * OpenAI-compatible chat/completions endpoints (optional static bearer
        ``api_key``), and
      * Vertex AI native ``generateContent`` endpoints, selected when a URL
        contains ``/publishers/google/models`` or ``config.USE_VERTEX_NATIVE_API``
        is truthy.

    When any URL contains ``aiplatform.googleapis.com`` the client authenticates
    with Google OAuth2 (per-URL service-account pool, a single key file, or ADC)
    and paces requests globally. Every URL is treated as a "project" in a
    failover pool. A 429 puts the project into a cooldown and ``call()`` retries
    on another project. Suspended or API-disabled projects are removed from the
    pool.

    History:
        v3: call() now returns (content, reasoning_content, TokenUsage).
        v4: Google Vertex AI authentication support (OAuth2 auto-refresh).
        v5: Vertex AI Native generateContent API support (for global region).
    """

    def __init__(self, api_urls: List[str], model: str,
                 concurrency_per_url: int = 2,
                 temperature: float = DEFAULT_TEMPERATURE,
                 top_p: float = DEFAULT_TOP_P,
                 max_tokens: int = DEFAULT_MAX_TOKENS,
                 timeout: int = DEFAULT_TIMEOUT,
                 max_retries: int = DEFAULT_MAX_RETRIES,
                 api_key: str = API_KEY):
        """Configure endpoints, concurrency, pacing and (for Vertex) OAuth2.

        Args:
            api_urls: endpoint URLs; each one is a "project" in the failover pool.
            model: model name (payload field for OpenAI-compatible calls, URL
                path segment for Vertex native calls).
            concurrency_per_url: max in-flight requests per URL.
            temperature: default sampling temperature.
            top_p: nucleus sampling parameter sent on every request.
            max_tokens: default output-token limit.
            timeout: total per-request timeout in seconds.
            max_retries: attempts per project within one call.
            api_key: static bearer token for non-Vertex endpoints.

        Raises:
            ImportError / RuntimeError: from ``_init_gcp_auth`` when Vertex auth
                is required but google-auth or credentials are unavailable.
        """
        self.api_urls = api_urls
        self.model = model
        self.temperature = temperature
        self.top_p = top_p
        self.max_tokens = max_tokens
        self.timeout = timeout
        self.max_retries = max_retries
        self.api_key = api_key
        self._url_cycle = itertools.cycle(api_urls)
        self._semaphores = {url: asyncio.Semaphore(concurrency_per_url)
                            for url in api_urls}
        self._cycle_lock = asyncio.Lock()

        # Project-pool state: preferred project index, per-URL cooldown
        # deadlines (time.monotonic() seconds) and permanently disabled URLs.
        self._active_url_index = 0
        self._url_switch_lock = asyncio.Lock()
        self._project_cooldown_until = {url: 0.0 for url in api_urls}
        self._disabled_urls: Set[str] = set()

        # Global Vertex pacing knobs; None/0 in config disables each one.
        self._vertex_min_interval = max(0.0, float(VERTEX_MIN_REQUEST_INTERVAL_SECONDS or 0.0))
        self._vertex_rate_limit_cooldown = max(0.0, float(VERTEX_RATE_LIMIT_COOLDOWN_SECONDS or 0.0))
        self._vertex_request_jitter = max(0.0, float(VERTEX_REQUEST_JITTER_SECONDS or 0.0))
        self._vertex_next_request_at = 0.0
        self._vertex_pacing_lock = asyncio.Lock()

        # Auth / API-mode detection is based on the configured URLs.
        self._use_vertex_auth = self._detect_vertex_endpoint()
        self._use_vertex_native = self._detect_vertex_native()
        self._gcp_credentials = None
        self._gcp_credentials_by_url: Dict[str, Any] = {}
        self._gcp_auth_request = None
        if self._use_vertex_auth:
            self._init_gcp_auth()

    def _detect_vertex_endpoint(self) -> bool:
        """Return True if any configured URL is a Vertex AI (aiplatform) endpoint."""
        return any("aiplatform.googleapis.com" in url for url in self.api_urls)

    def _detect_vertex_native(self) -> bool:
        """Return True if the Vertex AI native generateContent API should be used.

        True when a URL contains ``/publishers/google/models`` (native), as
        opposed to ``/endpoints/openapi/chat/completions`` (OpenAI-compatible).
        It is also True when ``config.USE_VERTEX_NATIVE_API`` is truthy.
        """
        for url in self.api_urls:
            if "/publishers/google/models" in url:
                return True

        try:
            use_native = getattr(config, 'USE_VERTEX_NATIVE_API', False)
            if use_native:
                return True
        except Exception:
            # Best effort: an odd config object must not break client construction.
            pass
        return False

    def _init_gcp_auth(self):
        """Initialize Google Cloud OAuth2 credentials.

        Resolution order:
          1. Credential pool: entries of ``VERTEX_CREDENTIALS_POOL`` whose
             ``api_url`` is one of ``self.api_urls`` each get their own
             service-account credential. The credential comes from the
             ``service_account_key`` file path, or else from the inline
             ``service_account_info`` dict. If at least one entry matches, only
             the pool is used.
          2. Otherwise one shared credential is used. It comes from
             ``GCP_SERVICE_ACCOUNT_KEY`` if that file exists, or else from
             Application Default Credentials.

        Each credential is refreshed once here so misconfiguration fails fast.

        NOTE(review): ``VERTEX_CREDENTIALS_POOL`` and ``GCP_SERVICE_ACCOUNT_KEY``
        are not in the visible import list; presumably module-level names
        defined elsewhere in this file — confirm.

        Raises:
            ImportError: google-auth is not installed.
            RuntimeError: a credential could not be loaded or refreshed.
        """
        try:
            import google.auth
            import google.auth.transport.requests
            from google.oauth2 import service_account as sa_module

            scopes = ["https://www.googleapis.com/auth/cloud-platform"]
            self._gcp_auth_request = google.auth.transport.requests.Request()

            # Keep only pool entries that belong to this client's URLs.
            pool_records = []
            for item in VERTEX_CREDENTIALS_POOL:
                if not isinstance(item, dict):
                    continue
                api_url = item.get("api_url")
                if api_url in self.api_urls:
                    pool_records.append(item)

            if pool_records:
                for item in pool_records:
                    credential = None
                    key_path = item.get("service_account_key") or ""
                    key_info = item.get("service_account_info")
                    if key_path and os.path.exists(key_path):
                        credential = sa_module.Credentials.from_service_account_file(
                            key_path, scopes=scopes,
                        )
                        print(
                            f"[GCP AUTH] 使用账号池密钥: account={item.get('account_name', '')} "
                            f"project={item.get('project_id', '')} key={key_path}"
                        )
                    elif isinstance(key_info, dict):
                        credential = sa_module.Credentials.from_service_account_info(
                            key_info, scopes=scopes,
                        )
                        print(
                            f"[GCP AUTH] 使用账号池内嵌密钥: account={item.get('account_name', '')} "
                            f"project={item.get('project_id', '')}"
                        )
                    else:
                        raise RuntimeError(
                            f"Vertex 账号池条目缺少可用凭证: project={item.get('project_id', '')}"
                        )

                    # Fail fast on a bad key instead of at the first request.
                    credential.refresh(self._gcp_auth_request)
                    self._gcp_credentials_by_url[item["api_url"]] = credential

                print(f"[GCP AUTH] 账号池认证成功,可用条目数: {len(self._gcp_credentials_by_url)}")
                return

            if GCP_SERVICE_ACCOUNT_KEY and os.path.exists(GCP_SERVICE_ACCOUNT_KEY):
                # Explicit service-account key file.
                self._gcp_credentials = sa_module.Credentials.from_service_account_file(
                    GCP_SERVICE_ACCOUNT_KEY, scopes=scopes,
                )
                print(f"[GCP AUTH] 使用服务账号密钥: {GCP_SERVICE_ACCOUNT_KEY}")
            else:
                # Fall back to Application Default Credentials
                # (e.g. `gcloud auth application-default login`).
                self._gcp_credentials, project = google.auth.default(scopes=scopes)
                print(f"[GCP AUTH] 使用 ADC (Application Default Credentials), project={project}")

            # Obtain the first access token now so errors surface at startup.
            self._gcp_credentials.refresh(self._gcp_auth_request)
            print(f"[GCP AUTH] 认证成功,token 有效期至 {self._gcp_credentials.expiry}")

        except ImportError as e:
            raise ImportError(
                "使用 Vertex AI 需要安装 google-auth 库:\n"
                " pip install google-auth google-auth-httplib2"
            ) from e
        except Exception as e:
            raise RuntimeError(
                f"Google Cloud 认证失败: {e}\n"
                f"请确保已配置 GCP_SERVICE_ACCOUNT_KEY 或运行 "
                f"`gcloud auth application-default login`"
            ) from e

    def _get_gcp_token(self, api_url: Optional[str] = None) -> str:
        """Return a valid GCP OAuth2 access token, refreshing it if expired.

        Uses the pool credential bound to *api_url* when one exists, otherwise
        the shared credential. ``refresh()`` goes through the synchronous
        ``requests`` transport, so it blocks the event loop while it runs.

        Raises:
            RuntimeError: no credential is available for *api_url*. This happens
                when the pool covers only some of the URLs, in which case no
                shared credential was initialised.
        """
        if api_url and api_url in self._gcp_credentials_by_url:
            credentials = self._gcp_credentials_by_url[api_url]
        else:
            credentials = self._gcp_credentials
        if credentials is None:
            raise RuntimeError(
                f"No GCP credentials configured for {api_url or '<default>'}; "
                f"add it to VERTEX_CREDENTIALS_POOL or configure a default credential."
            )
        if credentials.expired or not credentials.token:
            credentials.refresh(self._gcp_auth_request)
        return credentials.token

    async def _next_url(self) -> str:
        """Return the next URL in plain round-robin order (ignores cooldowns)."""
        async with self._cycle_lock:
            return next(self._url_cycle)

    def _get_active_url(self) -> str:
        """Return the currently preferred project URL (changes after 429 failover)."""
        return self.api_urls[self._active_url_index]

    async def _acquire_vertex_request_slot(self):
        """Globally pace Vertex requests to avoid second-level bursts on a shared quota.

        Each granted slot reserves ``min_interval + U(0, jitter)`` seconds before
        the next request may start. No-op for non-Vertex clients or when the
        minimum interval is 0.
        """
        if not self._use_vertex_auth or self._vertex_min_interval <= 0:
            return

        while True:
            async with self._vertex_pacing_lock:
                now = time.monotonic()
                wait = self._vertex_next_request_at - now
                if wait <= 0:
                    reserve = self._vertex_min_interval + random.uniform(0, self._vertex_request_jitter)
                    self._vertex_next_request_at = now + reserve
                    return
                if wait > 0.5:
                    print(f" [VERTEX PACE] waiting {wait:.1f}s before next request")
            # Sleep outside the lock so other waiters can re-check the deadline.
            await asyncio.sleep(max(wait, 0.05))

    async def _mark_project_cooldown(self, failed_url: str):
        """After a 429, keep *failed_url* out of rotation for cooldown + jitter seconds."""
        if self._vertex_rate_limit_cooldown <= 0:
            return

        cooldown = self._vertex_rate_limit_cooldown + random.uniform(0, self._vertex_request_jitter)
        async with self._url_switch_lock:
            until_ts = time.monotonic() + cooldown
            prev = self._project_cooldown_until.get(failed_url, 0.0)
            # Never shorten an existing, longer cooldown.
            self._project_cooldown_until[failed_url] = max(prev, until_ts)
        print(f" [PROJECT POOL] cooldown {cooldown:.1f}s for {failed_url[:80]}...")

    async def _disable_project_url(self, failed_url: str, reason: str):
        """Permanently drop a suspended / API-disabled / permission-denied project.

        Raises:
            ProjectDisabledError: every configured project is now disabled.
        """
        async with self._url_switch_lock:
            self._disabled_urls.add(failed_url)
            print(f" [PROJECT POOL] disabled project: {failed_url[:80]}... reason={reason[:160]}")
            if len(self._disabled_urls) >= len(self.api_urls):
                raise ProjectDisabledError("All Vertex projects are disabled or suspended.")
            # If the active project was just disabled, move to the first healthy one.
            if self.api_urls[self._active_url_index] == failed_url:
                for idx, url in enumerate(self.api_urls):
                    if url not in self._disabled_urls:
                        self._active_url_index = idx
                        break

    async def _acquire_project_url(self) -> Tuple[str, int]:
        """Pick a project that is not disabled and not cooling down.

        The search starts at the active index. If every live project is cooling
        down, sleep until the earliest cooldown expires and try again.

        Returns:
            (url, index) of the chosen project, which also becomes the active one.

        Raises:
            ProjectDisabledError: every project is disabled.
        """
        while True:
            async with self._url_switch_lock:
                now = time.monotonic()
                count = len(self.api_urls)
                for offset in range(count):
                    idx = (self._active_url_index + offset) % count
                    url = self.api_urls[idx]
                    if url in self._disabled_urls:
                        continue
                    if self._project_cooldown_until.get(url, 0.0) <= now:
                        self._active_url_index = idx
                        return url, idx

                available_urls = [u for u in self.api_urls if u not in self._disabled_urls]
                if not available_urls:
                    raise ProjectDisabledError("All Vertex projects are disabled or suspended.")
                soonest_url = min(available_urls, key=lambda u: self._project_cooldown_until.get(u, 0.0))
                wait = max(0.0, self._project_cooldown_until.get(soonest_url, 0.0) - now)

            print(f" [PROJECT POOL] all projects cooling down, waiting {wait:.1f}s")
            await asyncio.sleep(max(wait, 0.1))

    async def _switch_to_next_project(self, failed_index: int) -> bool:
        """On 429, advance the active index past *failed_index*.

        Returns True if another coroutine already switched away, or if a live,
        non-cooling project was found. Returns False if no such project exists.
        """
        async with self._url_switch_lock:
            # Another coroutine already moved the pool off the failed project.
            if self._active_url_index != failed_index:
                return True
            now = time.monotonic()
            count = len(self.api_urls)
            for offset in range(1, count + 1):
                next_index = (failed_index + offset) % count
                next_url = self.api_urls[next_index]
                if next_url in self._disabled_urls:
                    continue
                if self._project_cooldown_until.get(next_url, 0.0) <= now:
                    self._active_url_index = next_index
                    print(f" [PROJECT POOL] 429 → 切换到 project #{next_index + 1}/{len(self.api_urls)}: {next_url[:80]}...")
                    return True
            return False

    def _all_projects_cooling_down(self) -> bool:
        """True if at least one project is live and every live project is cooling down."""
        now = time.monotonic()
        active_urls = [url for url in self.api_urls if url not in self._disabled_urls]
        return bool(active_urls) and all(self._project_cooldown_until.get(url, 0.0) > now for url in active_urls)

    @staticmethod
    def _is_project_disabled_error(err_str: str) -> bool:
        """Heuristically detect a suspended/disabled-project error from its message."""
        err_lower = err_str.lower()
        return any(marker in err_lower for marker in (
            "consumer_suspended",
            "has been suspended",
            "service_disabled",
            "api has not been used",
            "api is disabled",
        ))

    @staticmethod
    def _is_rate_limit_error(err_str: str) -> bool:
        """Heuristically detect a rate-limit (429 / RESOURCE_EXHAUSTED) error message."""
        return (
            "429" in err_str
            or "RESOURCE_EXHAUSTED" in err_str
            or "rate limit" in err_str.lower()
        )

    def _resolve_temperature(self, temperature: Optional[float]) -> float:
        """Return the per-call temperature override, or the client default if None.

        An explicit ``None`` check is used so that ``temperature=0.0`` (greedy
        decoding) is honoured instead of silently falling back to the default.
        """
        return self.temperature if temperature is None else temperature

    @staticmethod
    def _convert_messages_to_vertex_native(
        messages: List[Dict],
    ) -> Tuple[Optional[Dict], List[Dict]]:
        """Convert OpenAI-format messages into the Vertex AI native format.

        OpenAI format:
            [{"role": "system", "content": "..."},
             {"role": "user", "content": "..." | [{"type":"text",...}, {"type":"image_url",...}]},
             {"role": "assistant", "content": "..."}]

        Vertex native format:
            systemInstruction: {"parts": [{"text": "..."}]}
            contents: [
                {"role": "user", "parts": [{"text": "..."}, {"inlineData": {...}}]},
                {"role": "model", "parts": [{"text": "..."}]}
            ]

        Rules implemented here:
          - a ``system`` message becomes ``systemInstruction``; a later system
            message overwrites an earlier one;
          - ``assistant`` maps to ``model``, and every other role maps to ``user``;
          - in non-system messages, blank text parts are dropped, and messages
            left with no parts are skipped;
          - ``data:`` image URLs become ``inlineData``; malformed ones are skipped;
          - ``gs://`` image URLs become ``fileData`` labelled ``image/jpeg``;
          - any other image URL (e.g. http/https) is dropped.

        Returns: (system_instruction_dict_or_None, contents_list)
        """
        system_instruction = None
        contents = []

        for msg in messages:
            role = msg.get("role", "")
            content = msg.get("content", "")

            if role == "system":
                if isinstance(content, str):
                    system_instruction = {"parts": [{"text": content}]}
                elif isinstance(content, list):
                    parts = []
                    for item in content:
                        if isinstance(item, str):
                            parts.append({"text": item})
                        elif isinstance(item, dict) and item.get("type") == "text":
                            parts.append({"text": item.get("text", "")})
                    system_instruction = {"parts": parts}
                continue

            vertex_role = "model" if role == "assistant" else "user"

            parts = []
            if isinstance(content, str):
                if content.strip():
                    parts.append({"text": content})
            elif isinstance(content, list):
                for item in content:
                    if isinstance(item, str):
                        if item.strip():
                            parts.append({"text": item})
                    elif isinstance(item, dict):
                        item_type = item.get("type", "")
                        if item_type == "text":
                            text_val = item.get("text", "")
                            if text_val.strip():
                                parts.append({"text": text_val})
                        elif item_type == "image_url":
                            image_url = item.get("image_url", {})
                            url = image_url.get("url", "") if isinstance(image_url, dict) else ""
                            if url.startswith("data:"):
                                # data:<mime>;base64,<payload>
                                try:
                                    header, b64_data = url.split(",", 1)
                                    mime_type = header.split(":")[1].split(";")[0]
                                    parts.append({
                                        "inlineData": {
                                            "mimeType": mime_type,
                                            "data": b64_data,
                                        }
                                    })
                                except (ValueError, IndexError):
                                    # Malformed data URI: skip this image.
                                    pass
                            elif url.startswith("gs://"):
                                # Cloud Storage object; MIME type is assumed JPEG.
                                parts.append({
                                    "fileData": {
                                        "fileUri": url,
                                        "mimeType": "image/jpeg",
                                    }
                                })

            if parts:
                contents.append({"role": vertex_role, "parts": parts})

        return system_instruction, contents

    @staticmethod
    def _parse_vertex_native_response(
        result: Dict[str, Any],
    ) -> Tuple[str, str, TokenUsage]:
        """Parse a Vertex AI native generateContent response.

        Vertex response format:
            {
                "candidates": [{
                    "content": {
                        "role": "model",
                        "parts": [{"text": "..."}, ...]
                    },
                    "finishReason": "STOP"
                }],
                "usageMetadata": {
                    "promptTokenCount": 100,
                    "candidatesTokenCount": 50,
                    "totalTokenCount": 150,
                    "thoughtsTokenCount": 20   // optional, thinking tokens
                }
            }

        Only the first candidate is used. Text parts flagged ``thought: true``
        go to the reasoning text; all other text parts go to the content.

        Returns: (content_text, reasoning_text, TokenUsage)

        Raises:
            ProhibitedContentError: promptFeedback.blockReason is PROHIBITED_CONTENT.
            ValueError: the response contains no candidates.
        """
        prompt_feedback = result.get("promptFeedback", {})
        block_reason = prompt_feedback.get("blockReason", "")
        if block_reason == "PROHIBITED_CONTENT":
            raise ProhibitedContentError(
                f"Content blocked by Vertex safety filter: {block_reason}")

        candidates = result.get("candidates", [])
        if not candidates:
            raise ValueError(f"Empty candidates in Vertex response: {json.dumps(result)[:300]}")

        candidate = candidates[0]
        parts = candidate.get("content", {}).get("parts", [])

        content_text = ""
        reasoning_text = ""
        for part in parts:
            if "text" in part:
                if part.get("thought", False):
                    reasoning_text += part["text"]
                else:
                    content_text += part["text"]

        usage_meta = result.get("usageMetadata", {})
        prompt_tokens = usage_meta.get("promptTokenCount", 0) or 0
        completion_tokens = usage_meta.get("candidatesTokenCount", 0) or 0
        total_tokens = usage_meta.get("totalTokenCount", 0) or (prompt_tokens + completion_tokens)
        thinking_tokens = usage_meta.get("thoughtsTokenCount", 0) or 0
        cached_tokens = usage_meta.get("cachedContentTokenCount", 0) or 0

        usage = TokenUsage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
            thinking_tokens=thinking_tokens,
            cached_tokens=cached_tokens,
        )

        return content_text, reasoning_text, usage

    async def call(
        self,
        messages: List[Dict],
        session: aiohttp.ClientSession,
        temperature: float = None,
        max_tokens: int = None,
    ) -> Tuple[str, str, TokenUsage]:
        """Call the LLM with load balancing and project-pool failover.

        Dispatches to the OpenAI-compatible or the Vertex native path depending
        on the detected API mode.
          - On a rate-limit error with several URLs configured, the project is
            put into cooldown and the next project is tried.
          - Disabled or suspended projects are removed from the pool.
          - Any other error propagates.

        Args:
            messages: OpenAI-format chat messages.
            session: shared aiohttp session.
            temperature: per-call override; ``None`` uses the client default
                (``0.0`` is honoured).
            max_tokens: per-call override; ``None``/0 uses the client default.

        Returns:
            (content, reasoning_content, TokenUsage)
        """
        last_err = None
        # First project actually used by this call. Wrapping back to it after
        # 429 failover means every project has been tried.
        start_index: Optional[int] = None

        for _pool_attempt in range(len(self.api_urls)):
            base_url, url_index = await self._acquire_project_url()
            if start_index is None:
                start_index = url_index
            sem = self._semaphores[base_url]

            try:
                if self._use_vertex_native:
                    return await self._call_vertex_native(
                        messages, session, base_url, sem, temperature, max_tokens
                    )
                return await self._call_openai_compat(
                    messages, session, base_url, sem, temperature, max_tokens
                )
            except Exception as e:
                err_str = str(e)
                if self._is_project_disabled_error(err_str):
                    last_err = e
                    await self._disable_project_url(base_url, err_str)
                    continue
                if self._is_rate_limit_error(err_str) and len(self.api_urls) > 1:
                    last_err = e
                    await self._mark_project_cooldown(base_url)
                    await self._switch_to_next_project(url_index)
                    if self._all_projects_cooling_down():
                        print(f" [PROJECT POOL] 所有 {len(self.api_urls)} 个 project 正在 cooldown")
                        raise
                    if self._active_url_index == start_index:
                        print(f" [PROJECT POOL] 所有 {len(self.api_urls)} 个 project 均已 429,上抛异常")
                        raise
                    continue
                raise

        raise last_err or RuntimeError("All projects in pool exhausted (429)")

    async def _call_openai_compat(
        self,
        messages: List[Dict],
        session: aiohttp.ClientSession,
        api_url: str,
        sem: asyncio.Semaphore,
        temperature: float = None,
        max_tokens: int = None,
    ) -> Tuple[str, str, TokenUsage]:
        """Call one OpenAI-compatible chat/completions endpoint, with retries.

        Retries up to ``max_retries`` times. Waits ``2**attempt`` seconds between
        generic failures. On a rate-limit error:
          - with several URLs configured, the error is re-raised immediately so
            ``call()`` can fail over;
          - with a single URL, it backs off ``15 * 2**attempt + U(0, 5)`` seconds.

        Raises:
            QuotaExhaustedError, ProhibitedContentError, ProjectDisabledError:
                propagated without retry.
            Exception: the last error once retries are exhausted.
        """
        payload = {
            "model": self.model,
            "messages": messages,
            "temperature": self._resolve_temperature(temperature),
            "top_p": self.top_p,
            "max_tokens": max_tokens or self.max_tokens,
        }

        for attempt in range(self.max_retries):
            try:
                headers = {"Content-Type": "application/json"}
                if self._use_vertex_auth:
                    token = self._get_gcp_token(api_url)
                    headers["Authorization"] = f"Bearer {token}"
                elif self.api_key:
                    headers["Authorization"] = f"Bearer {self.api_key}"

                async with sem:
                    await self._acquire_vertex_request_slot()
                    async with session.post(
                        api_url, json=payload, headers=headers,
                        timeout=aiohttp.ClientTimeout(total=self.timeout),
                    ) as resp:
                        if resp.status != 200:
                            error_body = await resp.text()
                            if _is_quota_exhausted(resp.status, error_body):
                                raise QuotaExhaustedError(
                                    f"API 额度用尽! HTTP {resp.status}: {error_body[:500]}")
                            raise RuntimeError(
                                f"HTTP {resp.status} from {api_url}: {error_body[:500]}")
                        result = await resp.json()

                choices = result.get("choices", [])
                if not choices:
                    raise ValueError(f"Empty choices: {json.dumps(result)[:300]}")

                message = choices[0].get("message", {})
                content = message.get("content", "") or ""
                reasoning = message.get("reasoning_content", "") or ""

                usage = TokenUsage.from_api_response(result.get("usage", {}))
                return content, reasoning, usage

            except QuotaExhaustedError:
                raise
            except ProhibitedContentError:
                raise
            except ProjectDisabledError:
                raise
            except Exception as e:
                if attempt == self.max_retries - 1:
                    raise
                err_str = str(e)
                if self._is_project_disabled_error(err_str):
                    raise ProjectDisabledError(err_str)
                if self._is_rate_limit_error(err_str):
                    if len(self.api_urls) > 1:
                        # Let call() fail over to another project.
                        raise
                    wait = (15 * (2 ** attempt)) + random.uniform(0, 5)
                    print(f" [RATE LIMIT] 429 detected, backing off {wait:.1f}s (attempt {attempt+1}/{self.max_retries})")
                else:
                    wait = 2 ** attempt
                await asyncio.sleep(wait)

        raise RuntimeError("Unreachable")

    async def _call_vertex_native(
        self,
        messages: List[Dict],
        session: aiohttp.ClientSession,
        base_url: str,
        sem: asyncio.Semaphore,
        temperature: float = None,
        max_tokens: int = None,
    ) -> Tuple[str, str, TokenUsage]:
        """Call the Vertex AI native generateContent API on one project, with retries.

        base_url format: https://aiplatform.googleapis.com/v1/projects/{P}/locations/{L}/publishers/google/models
        Request URL = base_url/{model}:generateContent. A leading ``google/``
        is stripped from the model name.

        Retry and error semantics match ``_call_openai_compat``. In addition,
        each failed attempt is logged.
        """
        model_name = self.model
        if model_name.startswith("google/"):
            model_name = model_name[len("google/"):]

        full_url = f"{base_url}/{model_name}:generateContent"

        system_instruction, contents = self._convert_messages_to_vertex_native(messages)

        payload: Dict[str, Any] = {
            "contents": contents,
            "generationConfig": {
                "temperature": self._resolve_temperature(temperature),
                "topP": self.top_p,
                "maxOutputTokens": max_tokens or self.max_tokens,
            },
        }
        if system_instruction:
            payload["systemInstruction"] = system_instruction

        for attempt in range(self.max_retries):
            try:
                headers = {"Content-Type": "application/json"}
                if self._use_vertex_auth:
                    token = self._get_gcp_token(base_url)
                    headers["Authorization"] = f"Bearer {token}"

                async with sem:
                    await self._acquire_vertex_request_slot()
                    async with session.post(
                        full_url, json=payload, headers=headers,
                        timeout=aiohttp.ClientTimeout(total=self.timeout),
                    ) as resp:
                        if resp.status != 200:
                            error_body = await resp.text()
                            if _is_quota_exhausted(resp.status, error_body):
                                raise QuotaExhaustedError(
                                    f"API 额度用尽! HTTP {resp.status}: {error_body[:500]}")
                            raise RuntimeError(
                                f"HTTP {resp.status} from {full_url}: {error_body[:500]}")
                        result = await resp.json()

                content, reasoning, usage = self._parse_vertex_native_response(result)
                return content, reasoning, usage

            except QuotaExhaustedError:
                raise
            except ProhibitedContentError:
                raise
            except ProjectDisabledError:
                raise
            except Exception as e:
                if attempt == self.max_retries - 1:
                    raise
                err_str = str(e)
                if self._is_project_disabled_error(err_str):
                    raise ProjectDisabledError(err_str)
                if self._is_rate_limit_error(err_str):
                    if len(self.api_urls) > 1:
                        # Let call() fail over to another project.
                        raise
                    wait = (15 * (2 ** attempt)) + random.uniform(0, 5)
                    print(f" [RATE LIMIT] 429 detected, backing off {wait:.1f}s (attempt {attempt+1}/{self.max_retries})")
                else:
                    wait = 2 ** attempt
                    print(f" [WARN] Vertex native call failed (attempt {attempt+1}): {e}")
                await asyncio.sleep(wait)

        raise RuntimeError("Unreachable")
|
|
|
|
|
|
| |
| |
| |
|
|
def build_user_message(text: str, image_paths: List[str] = None) -> Dict:
    """Build an OpenAI-style user message.

    With no images, the content is the plain *text*. Otherwise, the content is
    a list with one ``image_url`` part per path, in order. Each part is a
    base64 data URI labelled ``image/jpeg``. The list ends with a single
    ``text`` part.
    """
    if not image_paths:
        return {"role": "user", "content": text}
    parts: List[Dict] = [
        {
            "type": "image_url",
            "image_url": {"url": f"data:image/jpeg;base64,{encode_image_b64(img_path)}"},
        }
        for img_path in image_paths
    ]
    parts.append({"type": "text", "text": text})
    return {"role": "user", "content": parts}
|
|
|
|
def build_assistant_message(text: str) -> Dict:
    """Wrap *text* as an OpenAI-style assistant message."""
    return dict(role="assistant", content=text)
|
|
|
|
| |
| |
| |
|
|
# First <tool_call>/<answer> block; DOTALL lets the payload span lines.
_TOOL_CALL_RE = re.compile(r'<tool_call>\s*(.*?)\s*</tool_call>', re.DOTALL)
_ANSWER_RE = re.compile(r'<answer>\s*(.*?)\s*</answer>', re.DOTALL)


def _decode_tool_call(raw: str) -> Tuple[str, Any]:
    """Decode a <tool_call> body into ('tool_call', dict), or return ('error', str)."""
    try:
        payload = json.loads(raw)
    except json.JSONDecodeError:
        return ("error", f"Malformed tool_call JSON: {raw[:200]}")
    # Callers treat the payload as a mapping. A JSON list, string or number
    # is a malformed tool call, not a valid one.
    if not isinstance(payload, dict):
        return ("error", f"tool_call JSON must be an object, got "
                         f"{type(payload).__name__}: {raw[:200]}")
    return ("tool_call", payload)


def parse_llm_response(text: str) -> Tuple[str, Any]:
    """Parse an LLM response and extract the FIRST action by position.

    Returns ('tool_call', dict), ('answer', str), or ('error', str).

    Position-based priority: whichever of <tool_call> / <answer> appears first
    in the text wins. This prevents a hallucinated trailing <answer> from
    overriding a valid earlier <tool_call>. A tool call whose body is not a
    JSON object is reported as an error. Empty or ``None`` input yields the
    "not found" error.
    """
    if not text:
        return ("error", "No <tool_call> or <answer> found in response")

    tc_m = _TOOL_CALL_RE.search(text)
    answer_m = _ANSWER_RE.search(text)

    if answer_m and (tc_m is None or answer_m.start() < tc_m.start()):
        return ("answer", answer_m.group(1).strip())
    if tc_m:
        return _decode_tool_call(tc_m.group(1))
    return ("error", "No <tool_call> or <answer> found in response")
|
|
|
|
| |
| |
| |
|
|
| |
|
|
class ImageSearchCache:
    """MD5-keyed persistent cache for reverse image search results.

    Keys are the hex MD5 digest of the raw image bytes; values are result
    strings. The cache is written to ``cache_file`` as JSON every
    ``_FLUSH_EVERY`` new entries, or whenever ``save()`` is called explicitly.
    """

    # Number of unsaved set() calls that triggers an automatic save().
    _FLUSH_EVERY = 10

    def __init__(self, cache_file: str):
        self.cache_file = cache_file
        self.cache = self._load()
        self._dirty_count = 0

    def _load(self) -> Dict[str, str]:
        """Load the cache from disk.

        Returns {} when the file is missing, unreadable, or does not contain a
        JSON object. Best effort: a bad cache file must never break searching.
        """
        if not os.path.exists(self.cache_file):
            return {}
        try:
            with open(self.cache_file, "r", encoding="utf-8") as f:
                data = json.load(f)
        except Exception:
            return {}
        return data if isinstance(data, dict) else {}

    def save(self):
        """Persist the cache atomically and reset the dirty counter.

        Data goes to a per-process temp file that is then renamed over the
        target with ``os.replace``. A crash mid-write therefore cannot truncate
        the existing cache. The parent directory is created if needed.
        """
        directory = os.path.dirname(self.cache_file)
        if directory:
            os.makedirs(directory, exist_ok=True)
        tmp_path = f"{self.cache_file}.{os.getpid()}.tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(self.cache, f, ensure_ascii=False, indent=2)
            os.replace(tmp_path, self.cache_file)
        except BaseException:
            # Don't leave a stray temp file behind; surface the original error.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise
        self._dirty_count = 0

    def get(self, image_bytes: bytes) -> Optional[str]:
        """Return the cached result for *image_bytes*, or None if absent."""
        key = hashlib.md5(image_bytes).hexdigest()
        return self.cache.get(key)

    def set(self, image_bytes: bytes, result: str):
        """Store *result* for *image_bytes*; flush to disk every ``_FLUSH_EVERY`` writes."""
        key = hashlib.md5(image_bytes).hexdigest()
        self.cache[key] = result
        self._dirty_count += 1
        if self._dirty_count >= self._FLUSH_EVERY:
            self.save()
|
|
|
|
| |
|
|
def mock_image_search(entity: str, bbox: List[float]) -> str:
    """Return a canned reverse-image-search result for offline testing.

    ``bbox`` is accepted only for signature parity with the real search and is ignored.
    """
    slug = entity.lower().replace(' ', '_')
    lines = [
        "Reverse Image Search Results:",
        "",
        "Result 1:",
        f" Title: {entity} - Character Profile",
        f" Snippet: {entity} is a well-known character/entity.",
        f" URL: https://example.com/{slug}",
        "",
        "Result 2:",
        f" Title: {entity} | Wiki",
        f" Snippet: Detailed information about {entity}.",
        f" URL: https://wiki.example.com/{slug}",
    ]
    return "\n".join(lines)
|
|
|
|
def mock_web_search(query: str) -> str:
    """Return a canned web-search result for offline testing."""
    sections = (
        f'Web Search Results for "{query}":',
        "Quick Answer: Information related to the query.",
        f"Result 1: {query} - Overview\n Relevant information about {query}.",
        f"Result 2: {query} - Details\n Additional details and facts.",
    )
    return "\n\n".join(sections)
|
|
|
|
| |
|
|
|
|
| def _format_serper_lens_results(organic_results: List[Dict[str, Any]], max_results: int = 5) -> str: |
| """Format raw Serper Lens organic results into readable text.""" |
| parts = [] |
| for i, item in enumerate(organic_results[:max_results], 1): |
| title = item.get('title', '') |
| snippet = item.get('snippet', '') |
| link = item.get('link', '') |
| source = item.get('source', '') or item.get('domain', '') |
|
|
| block = [f"Result {i}:"] |
| if title: |
| block.append(f" Title: {title}") |
| if snippet: |
| block.append(f" Snippet: {snippet}") |
| if source: |
| block.append(f" Source: {source}") |
| if link: |
| block.append(f" URL: {link}") |
| parts.append("\n".join(block)) |
| return "\n\n".join(parts) |
|
|
|
|
| def _build_serper_lens_summary_prompt(organic_results: List[Dict[str, Any]], max_results: int = 5) -> str: |
| """Build English summarizer prompt using only Serper Lens result metadata.""" |
| context_parts = [] |
| for i, item in enumerate(organic_results[:max_results], 1): |
| title = item.get('title', '') |
| snippet = item.get('snippet', '') |
| link = item.get('link', '') |
| source = item.get('source', '') or item.get('domain', '') |
|
|
| block = [f"Result {i}:"] |
| if title: |
| block.append(f"Title: {title}") |
| if snippet: |
| block.append(f"Snippet: {snippet}") |
| if source: |
| block.append(f"Source: {source}") |
| if link: |
| block.append(f"Link: {link}") |
| context_parts.append("\n".join(block)) |
|
|
| context_text = "\n\n".join(context_parts) |
| return ( |
| "You are a helpful assistant. Your task is to summarize the main content of the given " |
| "Serper Lens reverse image search results in no more than five sentences.\n\n" |
| "Your summary should cover the overall key points across the results, not just the parts " |
| "most related to the user's question.\n\n" |
| "If any part of the results is helpful for identifying the entity or answering the user's " |
| "question, include it clearly in the summary. Do not ignore relevant information, but make " |
| "sure the general structure and main ideas of the results are preserved.\n\n" |
| "Your summary should be concise, factual, and informative. If the results are ambiguous, " |
| "conflicting, or insufficient, clearly state that uncertainty.\n\n" |
| "Use only the provided result titles, snippets, and source/link metadata. Do not invent facts " |
| "and do not assume content from the linked pages.\n\n" |
| f"{context_text}" |
| ) |
|
|
|
|
async def summarize_serper_image_results(
    organic_results: List[Dict[str, Any]],
    session: aiohttp.ClientSession,
    summarizer_address: str = "",
    summarizer_model: str = "",
    max_results: int = 5,
    max_tokens: int = 512,
) -> Optional[str]:
    """Summarize Serper Lens hits with a chat-completions model (linked pages are not fetched).

    Falls back to the module defaults ``IMAGE_SEARCH_SUMMARIZER_ADDRESS`` /
    ``IMAGE_SEARCH_SUMMARIZER_MODEL`` when address/model are empty.

    Returns:
        The summary with thinking tags stripped. Returns None when there is
        nothing to summarize, the summarizer is not configured, or the call
        fails (non-200, timeout, any other error). Failures are logged and
        never raised.
    """
    endpoint_host = summarizer_address or IMAGE_SEARCH_SUMMARIZER_ADDRESS
    model_name = summarizer_model or IMAGE_SEARCH_SUMMARIZER_MODEL
    if not (organic_results and endpoint_host and model_name):
        return None

    prompt_text = _build_serper_lens_summary_prompt(organic_results, max_results=max_results)
    request_body = {
        "model": model_name,
        "messages": [{"role": "user", "content": prompt_text}],
        "max_tokens": max_tokens,
        "temperature": 0.3,
        # Ask the serving template to skip reasoning output.
        "chat_template_kwargs": {"enable_thinking": False},
    }

    try:
        async with session.post(
            f"http://{endpoint_host}/v1/chat/completions",
            json=request_body,
            headers={"Content-Type": "application/json"},
            timeout=aiohttp.ClientTimeout(total=120),
        ) as resp:
            if resp.status != 200:
                print(f" [IMAGE_SEARCH] Summarizer returned HTTP {resp.status}, falling back to raw results")
                return None

            body = await resp.json()
            choices = body.get("choices", [])
            if not (choices and isinstance(choices, list)):
                return None
            summary = choices[0].get("message", {}).get("content", "")
            if not (summary and summary.strip()):
                return None
            cleaned = _strip_thinking_tags(summary).strip()
            return cleaned or None
    except asyncio.TimeoutError:
        print(" [IMAGE_SEARCH] Summarizer timeout, falling back to raw results")
        return None
    except Exception as e:
        print(f" [IMAGE_SEARCH] Summarizer error: {e}, falling back to raw results")
        return None
|
|
|
|
async def real_image_search(
    image_b64_or_path: str,
    session: aiohttp.ClientSession,
    api_key: str,
    crop_path: str = None,
) -> str:
    """Reverse image search, dispatched on ``IMAGE_SEARCH_MODE``.

    ``"gateway"`` routes through the internal gateway (``api_key`` unused).
    Any other value calls Serper directly with ``api_key``.
    """
    if IMAGE_SEARCH_MODE != "gateway":
        return await _serper_image_search(image_b64_or_path, session, api_key, crop_path)
    return await _gateway_image_search(image_b64_or_path, session, crop_path)
|
|
|
|
async def _gateway_image_search(
    image_b64_or_path: str,
    session: aiohttp.ClientSession,
    crop_path: str = None,
) -> str:
    """Reverse image search via the internal gateway -> Serper Google Lens.

    Results can optionally be summarized by an LLM (``IMAGE_SEARCH_SUMMARIZE_SERPER``).
    The gateway call is attempted up to twice. Non-200 responses, timeouts and
    any other errors are retried after a ``3 * (attempt + 1)`` second pause.

    Returns:
        Formatted Lens results, optionally prefixed by a summary. Configuration
        or exhausted-retry failures are also returned as human-readable
        strings and are never raised.
    """
    if not GATEWAY_TOKEN:
        return "Error: GATEWAY_TOKEN not configured."

    image_url, prep_error = _prepare_image_search_url(
        image_b64_or_path, crop_path, "IMAGE_SEARCH/GATEWAY"
    )
    if prep_error:
        return prep_error

    request_headers = {
        'Content-Type': 'application/json',
        'User-Agent': 'ifbook-http-client',
    }
    lens_params = {"url": image_url, "type": "lens"}
    gateway_payload = {
        "sec_info": {
            "username": GATEWAY_USERNAME,
            "userid": GATEWAY_USERID,
            "token": GATEWAY_TOKEN,
        },
        "model_type": "openai",
        "model_name": "serper",
        # The gateway expects the Serper request as a JSON-encoded string.
        "params": json.dumps(lens_params),
    }

    max_api_retries = 2
    last_error = None
    for attempt in range(max_api_retries):
        retry_reason = None
        try:
            async with session.post(
                GATEWAY_URL,
                headers=request_headers,
                json=gateway_payload,
                timeout=aiohttp.ClientTimeout(total=60),
            ) as resp:
                if resp.status == 200:
                    gateway_resp = await resp.json()
                    # The Serper response is nested as a JSON string.
                    lens_data = json.loads(gateway_resp.get("model_output", "{}"))
                    organic = lens_data.get('organic', [])
                    if not organic:
                        return "No results found from reverse image search."

                    raw_results = _format_serper_lens_results(
                        organic,
                        max_results=IMAGE_SEARCH_SUMMARIZER_MAX_RESULTS,
                    )
                    if not IMAGE_SEARCH_SUMMARIZE_SERPER:
                        return raw_results

                    summary = await summarize_serper_image_results(
                        organic,
                        session,
                        summarizer_address=IMAGE_SEARCH_SUMMARIZER_ADDRESS,
                        summarizer_model=IMAGE_SEARCH_SUMMARIZER_MODEL,
                        max_results=IMAGE_SEARCH_SUMMARIZER_MAX_RESULTS,
                        max_tokens=IMAGE_SEARCH_SUMMARIZER_MAX_TOKENS,
                    )
                    return (f"Summary: {summary}\n\nTop Lens Results:\n\n{raw_results}"
                            if summary else raw_results)

                error_body = await resp.text()
                last_error = (f"Gateway error: HTTP {resp.status}: "
                              f"{error_body[:300]}")
                retry_reason = f"HTTP {resp.status}"
        except asyncio.TimeoutError:
            last_error = "Image search error: request timed out after 60s"
            retry_reason = "Timeout"
        except Exception as e:
            last_error = f"Image search error: {e}"
            retry_reason = f"Error: {e}"

        # Shared retry tail for every failure kind.
        if attempt >= max_api_retries - 1:
            return last_error
        print(f" [IMAGE_SEARCH/GATEWAY] {retry_reason}, "
              f"retrying ({attempt+1}/{max_api_retries})...")
        await asyncio.sleep(3 * (attempt + 1))

    return last_error or "Image search error: unknown failure"
|
|
|
|
async def _serper_image_search(
    image_b64_or_path: str,
    session: aiohttp.ClientSession,
    api_key: str,
    crop_path: str = None,
) -> str:
    """Run a reverse image search directly against Serper's Google Lens API.

    This is the original direct-connection path (no internal gateway). The
    image is first turned into a URL by ``_prepare_image_search_url``. The
    Lens results are then formatted and, when IMAGE_SEARCH_SUMMARIZE_SERPER
    is enabled, prefixed with an LLM summary.

    Args:
        image_b64_or_path: base64 image data, used only by the data-URI fallback.
        session: shared aiohttp session.
        api_key: Serper API key; an empty key returns an error immediately.
        crop_path: local crop image to upload and search with.

    Returns:
        Result text or an error string. HTTP errors, timeouts and other
        exceptions are retried once with a 3s backoff before being reported.
    """
    if not api_key:
        return "Error: SERPER_API_KEY not configured."

    lens_headers = {'X-API-KEY': api_key, 'Content-Type': 'application/json'}

    image_url, prep_error = _prepare_image_search_url(
        image_b64_or_path, crop_path, "IMAGE_SEARCH"
    )
    if prep_error:
        return prep_error

    total_attempts = 2
    failure_msg = None
    for attempt in range(total_attempts):
        retry_allowed = attempt < total_attempts - 1
        retry_suffix = f"retrying ({attempt+1}/{total_attempts})..."
        backoff_seconds = 3 * (attempt + 1)
        try:
            async with session.post(
                "https://google.serper.dev/lens",
                headers=lens_headers,
                json={"url": image_url},
                timeout=aiohttp.ClientTimeout(total=60),
            ) as resp:
                if resp.status != 200:
                    body_text = await resp.text()
                    failure_msg = (f"Image search error: HTTP {resp.status}: "
                                   f"{body_text[:300]}")
                    if not retry_allowed:
                        return failure_msg
                    print(f" [IMAGE_SEARCH] HTTP {resp.status}, {retry_suffix}")
                    await asyncio.sleep(backoff_seconds)
                    continue

                lens_data = await resp.json()
                hits = lens_data.get('organic', [])
                if not hits:
                    return "No results found from reverse image search."

                formatted = _format_serper_lens_results(
                    hits,
                    max_results=IMAGE_SEARCH_SUMMARIZER_MAX_RESULTS,
                )
                if not IMAGE_SEARCH_SUMMARIZE_SERPER:
                    return formatted

                summary = await summarize_serper_image_results(
                    hits,
                    session,
                    summarizer_address=IMAGE_SEARCH_SUMMARIZER_ADDRESS,
                    summarizer_model=IMAGE_SEARCH_SUMMARIZER_MODEL,
                    max_results=IMAGE_SEARCH_SUMMARIZER_MAX_RESULTS,
                    max_tokens=IMAGE_SEARCH_SUMMARIZER_MAX_TOKENS,
                )
                return (f"Summary: {summary}\n\nTop Lens Results:\n\n{formatted}"
                        if summary else formatted)
        except asyncio.TimeoutError:
            failure_msg = "Image search error: request timed out after 60s"
            if not retry_allowed:
                return failure_msg
            print(f" [IMAGE_SEARCH] Timeout, {retry_suffix}")
            await asyncio.sleep(backoff_seconds)
        except Exception as exc:
            failure_msg = f"Image search error: {exc}"
            if not retry_allowed:
                return failure_msg
            print(f" [IMAGE_SEARCH] Error: {exc}, {retry_suffix}")
            await asyncio.sleep(backoff_seconds)

    return failure_msg or "Image search error: unknown failure"
|
|
|
|
async def real_web_search(
    query: str,
    session: aiohttp.ClientSession,
    address: str = WEB_SEARCH_ADDRESS,
) -> str:
    """Web search through the internal search server (SenseNova pattern).

    Newlines in ``query`` become spaces and surrounding whitespace is
    stripped. Keys from WEB_SEARCH_CONFIG are merged last, so they override
    the built-in ``top_k`` / ``retrieval_mode`` defaults.

    Returns:
        The raw response body on success, otherwise an "Error: ..." string.
    """
    normalized_query = query.strip().replace("\n", " ")
    request_body = {
        **{
            "query": normalized_query,
            "top_k": 3,
            "retrieval_mode": "google_serper",
        },
        **WEB_SEARCH_CONFIG,
    }
    endpoint = f"http://{address}/search"
    try:
        async with session.post(
            endpoint,
            json=request_body,
            timeout=aiohttp.ClientTimeout(total=100),
        ) as response:
            response.raise_for_status()
            return await response.text()
    except asyncio.TimeoutError:
        return f"Error: Web search timeout for query: {query[:100]}"
    except Exception as err:
        return f"Error: Web search failed: {err}"
|
|
|
|
async def serper_web_search(
    query: str,
    session: aiohttp.ClientSession,
    api_key: str,
) -> str:
    """Fallback web search through Serper's Google Search API.

    Builds a plain-text digest from the answer box, the knowledge graph and up
    to five organic results.

    Args:
        query: search query, sent as-is.
        session: shared aiohttp session.
        api_key: Serper API key; an empty key returns an error immediately.

    Returns:
        The digest; "No relevant results found." when none of those sections
        has content; or an error string ("Error: ..." for a missing key,
        "Web search error: ..." for request or parse failures). HTTP and
        network errors are reported, not raised.
    """
    if not api_key:
        return "Error: SERPER_API_KEY not configured."
    headers = {'X-API-KEY': api_key, 'Content-Type': 'application/json'}
    try:
        async with session.post(
            "https://google.serper.dev/search",
            headers=headers,
            json={"q": query},
            timeout=aiohttp.ClientTimeout(total=20),
        ) as resp:
            resp.raise_for_status()
            data = await resp.json()

        parts = []
        answer_box = data.get('answerBox') or {}
        answer = answer_box.get('answer') or answer_box.get('snippet', '')
        if answer:
            parts.append(f"Quick Answer: {answer}")

        knowledge_graph = data.get('knowledgeGraph') or {}
        kg_title = knowledge_graph.get('title', '')
        kg_description = knowledge_graph.get('description', '')
        # Only emit the KG line when it carries text. An empty KG object would
        # otherwise add "Knowledge Graph:  - " and suppress the
        # "No relevant results found." message below.
        if kg_title or kg_description:
            parts.append(f"Knowledge Graph: {kg_title} - {kg_description}")

        # `organic` may be absent or null.
        for i, item in enumerate((data.get('organic') or [])[:5], 1):
            parts.append(
                f"Result {i}: {item.get('title', '')}\n"
                f" {item.get('snippet', '')}")
        return "\n".join(parts) if parts else "No relevant results found."
    except asyncio.TimeoutError:
        # A bare asyncio.TimeoutError has an empty message, so the generic
        # handler would yield just "Web search error: ".
        return f"Web search error: request timed out after 20s for query: {query[:100]}"
    except Exception as e:
        return f"Web search error: {e}"
|
|
| |
# Lazily-created semaphore bounding concurrent requests to the MARS retrieval
# server, and the event loop it was created for (None if created outside a loop).
_mars_retrieval_semaphore: Optional[asyncio.Semaphore] = None
_mars_retrieval_semaphore_loop: Optional[asyncio.AbstractEventLoop] = None


def _get_mars_retrieval_semaphore() -> Optional[asyncio.Semaphore]:
    """Return the shared retrieval semaphore, or None when concurrency is unlimited.

    MARS_RETRIEVAL_CONCURRENCY <= 0 disables throttling, and None is returned.
    Otherwise the semaphore is shared by all callers running in the same event
    loop. It is recreated when called from a different running loop, because
    asyncio primitives become bound to the loop they are first used in.
    Reusing one across loops (e.g. several ``asyncio.run`` calls) raises
    RuntimeError once callers contend for it.
    """
    global _mars_retrieval_semaphore, _mars_retrieval_semaphore_loop
    if MARS_RETRIEVAL_CONCURRENCY <= 0:
        return None
    try:
        current_loop = asyncio.get_running_loop()
    except RuntimeError:
        current_loop = None
    if (_mars_retrieval_semaphore is None
            or _mars_retrieval_semaphore_loop is not current_loop):
        _mars_retrieval_semaphore = asyncio.Semaphore(MARS_RETRIEVAL_CONCURRENCY)
        _mars_retrieval_semaphore_loop = current_loop
    return _mars_retrieval_semaphore
|
|
async def mars_web_search(
    query: str,
    session: aiohttp.ClientSession,
    retrieval_address: str = "",
    summarizer_address: str = "",
    retrieval_topk: int = 3,
    summarizer_model: str = "",
) -> str:
    """SenseNova-MARS style web search: retrieve passages, then summarize them with an LLM.

    Two-step pipeline:
      1. POST ``{"queries": [q], "return_scores": True, "topk": k}`` to the
         Search-R1 style retrieval server (``/retrieve``) and collect passages.
      2. POST those passages to an OpenAI-compatible summarizer
         (``/v1/chat/completions``) and return its concise summary.

    This mirrors the SenseNova-MARS web_search_server architecture but without
    the intermediate FastAPI layer.

    Args:
        query: user query; newlines become spaces and outer whitespace is stripped.
        session: shared aiohttp session.
        retrieval_address: ``host:port`` of the retriever; empty means
            MARS_RETRIEVAL_ADDRESS.
        summarizer_address: ``host:port`` of the summarizer; empty means
            MARS_SUMMARIZER_ADDRESS.
        retrieval_topk: passages to request; a falsy value means
            MARS_RETRIEVAL_TOPK. NOTE(review): with the default of 3, the
            config value is only used when a caller passes 0/None. Confirm
            this is intended.
        summarizer_model: model name for the summarizer; empty means
            MARS_SUMMARIZER_MODEL.

    Returns:
        The summarizer's answer. If the summarizer fails or yields nothing
        usable, the raw passages formatted by ``_format_raw_passages``.
        Otherwise an "Error: ..." / "No relevant passages ..." string.

    Raises:
        RetrieverDownError: when the retrieval server cannot be reached at the
            connection level (instead of returning an error string).
    """
    retrieval_addr = retrieval_address or MARS_RETRIEVAL_ADDRESS
    summarizer_addr = summarizer_address or MARS_SUMMARIZER_ADDRESS
    topk = retrieval_topk or MARS_RETRIEVAL_TOPK
    sum_model = summarizer_model or MARS_SUMMARIZER_MODEL

    if not retrieval_addr:
        return "Error: MARS_RETRIEVAL_ADDRESS not configured."
    if not summarizer_addr:
        return "Error: MARS_SUMMARIZER_ADDRESS not configured."

    clean_query = query.strip().replace("\n", " ")

    # ---- Step 1: retrieval -------------------------------------------------
    retrieval_payload = {
        "queries": [clean_query],
        "return_scores": True,
        "topk": topk,
    }

    retrieved_passages = []
    sem = _get_mars_retrieval_semaphore()
    try:
        if sem:
            await sem.acquire()
        try:
            # proxy="" presumably bypasses any environment proxy for the
            # local service — TODO confirm.
            async with session.post(
                f"http://{retrieval_addr}/retrieve",
                json=retrieval_payload,
                timeout=aiohttp.ClientTimeout(total=MARS_RETRIEVAL_TIMEOUT),
                proxy="",
            ) as resp:
                if resp.status != 200:
                    error_body = await resp.text()
                    return (f"Error: Retrieval server returned HTTP {resp.status}: "
                            f"{error_body[:300]}")
                data = await resp.json()

            # The code expects `result` to be a list with one list of hits per
            # query. Each hit's `document` is either a dict with `contents` or
            # a plain string whose first line is the (quoted) title.
            raw_results = data.get("result") or []
            for query_results in raw_results:
                if not isinstance(query_results, list):
                    continue
                for item in query_results:
                    if not isinstance(item, dict):
                        continue
                    doc = item.get("document", {})
                    if isinstance(doc, dict):
                        contents = doc.get("contents", "")
                    elif isinstance(doc, str):
                        contents = doc
                    else:
                        continue
                    if contents:
                        lines = contents.split("\n", 1)
                        title = lines[0].strip('"') if lines else ""
                        text = lines[1] if len(lines) > 1 else contents
                        # Cap each passage so the summarizer prompt stays bounded.
                        retrieved_passages.append({"title": title, "text": text[:2000]})
        finally:
            if sem:
                sem.release()

    except asyncio.TimeoutError:
        return f"Error: Retrieval server timeout for query: {clean_query[:100]}"
    except (aiohttp.ClientConnectorError,
            aiohttp.ServerDisconnectedError,
            ConnectionRefusedError,
            ConnectionResetError,
            OSError) as e:
        raise RetrieverDownError(
            f"Retriever 服务 {retrieval_addr} 连接失败: {e}"
        ) from e
    except Exception as e:
        # Some connection failures surface as generic exceptions; detect them
        # by message so they are still reported as a retriever outage.
        err_str = str(e).lower()
        if any(kw in err_str for kw in (
            "cannot connect", "connection refused", "connect call failed",
            "server disconnected",
        )):
            raise RetrieverDownError(
                f"Retriever 服务 {retrieval_addr} 连接失败: {e}"
            ) from e
        return f"Error: Retrieval failed: {e}"

    if not retrieved_passages:
        return f"No relevant passages found for query: {clean_query}"

    # ---- Step 2: summarization ---------------------------------------------
    context_parts = []
    for i, p in enumerate(retrieved_passages, 1):
        title_str = f" (Title: {p['title']})" if p['title'] else ""
        context_parts.append(f"Passage {i}{title_str}:\n{p['text']}")
    context_text = "\n\n".join(context_parts)

    summarizer_prompt = (
        f"Based on the following retrieved passages, provide a concise and informative "
        f"summary that answers the query: \"{clean_query}\"\n\n"
        f"{context_text}\n\n"
        f"Please provide a concise summary focusing on the most relevant information. "
        f"If the passages do not contain relevant information, say so."
    )

    summarizer_payload = {
        "model": sum_model,
        "messages": [{"role": "user", "content": summarizer_prompt}],
        "max_tokens": 1024,
        "temperature": 0.3,
        "chat_template_kwargs": {"enable_thinking": False},
    }

    try:
        async with session.post(
            f"http://{summarizer_addr}/v1/chat/completions",
            json=summarizer_payload,
            headers={"Content-Type": "application/json"},
            timeout=aiohttp.ClientTimeout(total=120),
            proxy="",
        ) as resp:
            if resp.status != 200:
                print(f" [MARS_SEARCH] Summarizer returned HTTP {resp.status}, "
                      f"falling back to raw passages")
                return _format_raw_passages(clean_query, retrieved_passages)

            data = await resp.json()
            choices = data.get("choices", [])
            if choices and isinstance(choices, list):
                msg = choices[0].get("message") or {}
                summary = msg.get("content", "")
                if summary and summary.strip():
                    summary = _strip_thinking_tags(summary).strip()
                    # Removing reasoning can leave nothing (e.g. the output was
                    # only a <think> block). Fall back to the raw passages
                    # rather than returning an empty summary.
                    if summary:
                        return (
                            f"Web Search Results for \"{clean_query}\" "
                            f"(MARS retrieve+summarize, {len(retrieved_passages)} passages):\n\n"
                            f"{summary}"
                        )

            return _format_raw_passages(clean_query, retrieved_passages)

    except asyncio.TimeoutError:
        print(f" [MARS_SEARCH] Summarizer timeout, falling back to raw passages")
        return _format_raw_passages(clean_query, retrieved_passages)
    except Exception as e:
        print(f" [MARS_SEARCH] Summarizer error: {e}, falling back to raw passages")
        return _format_raw_passages(clean_query, retrieved_passages)
|
|
|
|
|
|
| def _strip_thinking_tags(text: str) -> str: |
| """Remove thinking content from summarizer output.""" |
| |
| text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip() |
| |
| text = re.sub(r'^Thinking Process:.*?(?=\n\n)', '', text, flags=re.DOTALL).strip() |
| |
| if text.startswith("Thinking"): |
| parts = text.split("\n\n", 1) |
| if len(parts) > 1: |
| text = parts[1].strip() |
| return text |
|
|
|
|
|
|
| def _format_raw_passages(query: str, passages: list) -> str: |
| """Format raw retrieved passages as fallback when summarizer fails.""" |
| parts = [f"Web Search Results for \"{query}\" (raw retrieval, {len(passages)} passages):"] |
| for i, p in enumerate(passages, 1): |
| title_str = f" — {p['title']}" if p['title'] else "" |
| text_preview = p['text'][:500] + ("..." if len(p['text']) > 500 else "") |
| parts.append(f"\nResult {i}{title_str}:\n {text_preview}") |
| return "\n".join(parts) |
|
|
| |
| |
| |
|
|
def make_uid(entry: Dict) -> str:
    """Build a composite unique ID from an entry's ``id`` and ``video_filename``.

    One ``id`` can correspond to several videos, so the video filename stem is
    appended as ``"<id>__<stem>"``. Without a video filename the raw ``id``
    (default "unknown") is returned unchanged.
    """
    base_id = entry.get("id", "unknown")
    video_name = entry.get("video_filename", "")
    if not video_name:
        return base_id
    video_stem, _extension = os.path.splitext(video_name)
    return f"{base_id}__{video_stem}"
|
|
|
|
def load_completed_ids(output_file: str) -> Set[str]:
    """Collect UIDs of successfully processed entries from an output JSONL file.

    Used for resuming. A line counts as completed when it parses as a JSON
    object that has a "uid" key and no "error" key. Blank, malformed,
    non-object and undecodable lines are skipped.

    Args:
        output_file: path to the JSONL results file; a missing file yields an
            empty set.

    Returns:
        The set of completed UIDs.
    """
    completed: Set[str] = set()
    if not os.path.exists(output_file):
        return completed
    # errors="replace": a crash mid-write can leave a truncated multi-byte
    # UTF-8 sequence at the end of the file. Strict decoding would raise
    # UnicodeDecodeError and abort the whole resume. A replacement character
    # instead just makes that line invalid JSON, which is skipped below.
    with open(output_file, "r", encoding="utf-8", errors="replace") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                continue
            # Non-object lines (lists, strings, numbers) would otherwise break
            # the `record["uid"]` lookup.
            if isinstance(record, dict) and "uid" in record and "error" not in record:
                completed.add(record["uid"])
    return completed
|
|
|
|
def get_question(entry: Dict) -> str:
    """Pick the question text to use for an entry.

    Entries whose ``verdict`` is "rewrite" use ``rewritten_question`` when it
    is non-empty. All other cases use ``original_question`` (default "").
    """
    original = entry.get("original_question", "")
    if entry.get("verdict") != "rewrite":
        return original
    return entry.get("rewritten_question") or original
|
|
| |
| |
| |
| |
| |
| |
| |
| |
|
|
# Lazily-initialised OSS state shared by the upload helpers: the Bucket object
# (None until successfully built), a flag so initialisation is attempted only
# once, and the most recent OSS error message ("" when none).
_oss_bucket = None
_oss_init_attempted = False
_oss_last_error = ""


def _get_oss_bucket():
    """Return the shared OSS Bucket, building it on the first call only.

    Later calls return the cached result, which is None if the first attempt
    failed. On failure the reason is stored in ``_oss_last_error`` and printed.
    On success ``_oss_last_error`` is cleared.
    """
    global _oss_bucket, _oss_init_attempted, _oss_last_error
    if _oss_init_attempted:
        return _oss_bucket
    _oss_init_attempted = True

    failure = None
    try:
        import oss2
        if OSS_ACCESS_KEY_ID and OSS_ACCESS_KEY_SECRET:
            _oss_bucket = oss2.Bucket(
                oss2.Auth(OSS_ACCESS_KEY_ID, OSS_ACCESS_KEY_SECRET),
                f"https://{OSS_ENDPOINT}",
                OSS_BUCKET_NAME,
            )
        else:
            failure = "OSS_ACCESS_KEY_ID / OSS_ACCESS_KEY_SECRET 未配置"
    except ImportError:
        failure = "oss2 未安装,请执行: pip install oss2 --break-system-packages"
    except Exception as exc:
        failure = f"OSS Bucket 初始化失败: {exc}"

    if failure is not None:
        _oss_last_error = failure
        print(f"[ERROR] {_oss_last_error}")
        return None

    _oss_last_error = ""
    print(f"[OSS] Bucket initialized: {OSS_BUCKET_NAME} @ {OSS_ENDPOINT}")
    return _oss_bucket
|
|
|
|
def optimize_crop_for_search(crop_path: str, output_path: str = None,
                             max_size: int = SEARCH_CROP_MAX_SIZE,
                             quality: int = SEARCH_CROP_JPEG_QUALITY) -> str:
    """Shrink and JPEG-re-encode a crop image to reduce upload size for search.

    Args:
        crop_path: path of the original crop image.
        output_path: where to save the optimized image. Defaults to the
            original path with an ``_opt`` suffix before the extension.
            NOTE(review): the content is always JPEG, even if the extension
            is not.
        max_size: maximum length in pixels of the longer side; larger images
            are downscaled proportionally.
        quality: JPEG quality passed to Pillow.

    Returns:
        The optimized image path, or ``crop_path`` unchanged if optimization
        fails (the failure is printed, not raised).
    """
    if output_path is None:
        base, ext = os.path.splitext(crop_path)
        output_path = f"{base}_opt{ext}"
    try:
        with Image.open(crop_path) as img:
            # JPEG cannot store alpha/palette modes (RGBA, LA, P, ...). Saving
            # those raises, which previously sent the un-optimized original
            # downstream, so convert to RGB first.
            if img.mode not in ("RGB", "L"):
                img = img.convert("RGB")
            w, h = img.size
            if max(w, h) > max_size:
                ratio = max_size / max(w, h)
                # Clamp to 1px so very elongated images don't resize to 0.
                new_w = max(1, int(w * ratio))
                new_h = max(1, int(h * ratio))
                img = img.resize((new_w, new_h), Image.LANCZOS)
            img.save(output_path, 'JPEG', quality=quality)
        return output_path
    except Exception as e:
        print(f" [WARN] optimize_crop_for_search failed: {e}, using original")
        return crop_path
|
|
|
|
def _upload_to_oss(local_path: str, oss_object_name: str = None) -> Optional[str]:
    """Upload an image to Aliyun OSS and return its public URL.

    The image is first passed through ``optimize_crop_for_search``. Unless an
    explicit object name is given, the object key is
    ``<OSS_UPLOAD_PREFIX>/<md5 of the bytes>.jpg``, so identical content maps
    to the same key. If ``head_object`` succeeds for that key, the existing
    object is reused without re-uploading.

    Args:
        local_path: local image path.
        oss_object_name: explicit OSS object key; defaults to the MD5-based key.

    Returns:
        The public URL. None if the bucket is unavailable or any step fails;
        in the latter case ``_oss_last_error`` records the reason.
    """
    global _oss_last_error
    bucket = _get_oss_bucket()
    if bucket is None:
        return None

    try:
        optimized_path = optimize_crop_for_search(local_path)
        with open(optimized_path, "rb") as fh:
            payload = fh.read()

        object_key = oss_object_name
        if object_key is None:
            digest = hashlib.md5(payload).hexdigest()
            object_key = f"{OSS_UPLOAD_PREFIX}/{digest}.jpg"
        url = f"https://{OSS_BUCKET_NAME}.{OSS_ENDPOINT}/{object_key}"

        # Any head_object failure (not found, or otherwise) is treated as
        # "not cached" and falls through to a normal upload.
        try:
            bucket.head_object(object_key)
        except Exception:
            already_stored = False
        else:
            already_stored = True

        if already_stored:
            print(f" [OSS] Cache hit: {url}")
            return url

        bucket.put_object(object_key, payload, headers={
            'Content-Type': 'image/jpeg',
        })
        print(f" [OSS] Uploaded: {url}")
        return url

    except Exception as e:
        _oss_last_error = f"OSS upload failed: {e}"
        print(f" [WARN] {_oss_last_error}")
        return None
|
|
|
|
def _prepare_image_search_url(
    image_b64_or_path: str,
    crop_path: Optional[str],
    log_prefix: str,
) -> Tuple[Optional[str], Optional[str]]:
    """Produce an image URL for reverse image search, preferring an OSS upload.

    If ``crop_path`` exists it is uploaded to OSS and the public URL is used.
    Otherwise, when IMAGE_SEARCH_ALLOW_BASE64_FALLBACK is enabled,
    ``image_b64_or_path`` is embedded as a ``data:image/jpeg;base64,...`` URI.

    Args:
        image_b64_or_path: base64 image data for the data-URI fallback.
        crop_path: local crop image to upload; may be None or missing.
        log_prefix: tag used in printed log lines.

    Returns:
        ``(url, None)`` on success, or ``(None, error_message)`` when no URL
        could be produced. The message names the actual cause: a missing
        crop is reported as such, not as an earlier unrelated OSS error.
    """
    if crop_path and os.path.exists(crop_path):
        image_url = _upload_to_oss(crop_path)
        if image_url:
            print(f" [{log_prefix}] Using OSS URL: {image_url}")
            return image_url, None
        failure_reason = _oss_last_error or "OSS upload returned None"
    elif crop_path:
        failure_reason = f"crop image not found: {crop_path}"
    else:
        failure_reason = "no crop_path provided for OSS upload"

    if IMAGE_SEARCH_ALLOW_BASE64_FALLBACK:
        image_url = f"data:image/jpeg;base64,{image_b64_or_path}"
        print(f" [{log_prefix}] WARNING: OSS upload failed ({failure_reason}), "
              f"falling back to base64 data URI "
              f"(len={len(image_b64_or_path)}, may get 400)")
        return image_url, None

    return None, (
        "Image search unavailable: OSS upload failed and base64 fallback is disabled. "
        f"Root cause: {failure_reason}"
    )
|
|