#!/usr/bin/env python3 """ Video DeepResearch 公共工具模块。 输入: - `config.py` 导出的模型端点、Vertex 配置、本地检索服务地址与价格参数。 - 视频帧、裁剪图、LLM messages、搜索请求与各阶段中间结果。 处理: - 提供帧采样、图片编码、搜索访问、输出清洗、token 统计等通用能力。 - 在 `LLMClient` 中实现多 endpoint 轮询、429 冷却切换、Vertex 原生调用与 OAuth2 认证。 - 支持“一个 project 对应一个 service account json”的 Vertex 凭证池,并按 URL 选择对应凭证。 输出: - 为 Phase1/Phase2 提供统一的工具函数、搜索函数、token 统计结构和 `LLMClient`。 - 返回规范化的模型响应、搜索结果、图片路径与 token 使用信息。 """ import asyncio import aiohttp import base64 import hashlib import itertools import json import math import os import random import re import subprocess import time import numpy as np from PIL import Image from pathlib import Path from typing import Dict, List, Tuple, Optional, Any, Set from dataclasses import dataclass, field from config import ( API_ENDPOINTS, DEFAULT_MODEL, API_KEY, WEB_SEARCH_ADDRESS, WEB_SEARCH_CONFIG, SERPER_API_KEY, MOCK_SEARCH, IMAGE_SEARCH_CACHE_FILE, DEFAULT_TEMPERATURE, DEFAULT_TOP_P, DEFAULT_MAX_TOKENS, DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES, VERTEX_MIN_REQUEST_INTERVAL_SECONDS, VERTEX_RATE_LIMIT_COOLDOWN_SECONDS, VERTEX_REQUEST_JITTER_SECONDS, TOKEN_PRICING, THINKING_TOKEN_PRICING, DEFAULT_INPUT_PRICE_PER_1M, DEFAULT_OUTPUT_PRICE_PER_1M, OSS_ACCESS_KEY_ID, OSS_ACCESS_KEY_SECRET, OSS_ENDPOINT, OSS_BUCKET_NAME, OSS_UPLOAD_PREFIX, SEARCH_CROP_MAX_SIZE, SEARCH_CROP_JPEG_QUALITY, IMAGE_SEARCH_SUMMARIZE_SERPER, IMAGE_SEARCH_SUMMARIZER_ADDRESS, IMAGE_SEARCH_SUMMARIZER_MODEL, IMAGE_SEARCH_SUMMARIZER_MAX_RESULTS, IMAGE_SEARCH_SUMMARIZER_MAX_TOKENS, IMAGE_SEARCH_MODE, GATEWAY_URL, GATEWAY_USERNAME, GATEWAY_USERID, GATEWAY_TOKEN, IMAGE_SEARCH_ALLOW_BASE64_FALLBACK, ) from config import BBOX_CONFIGS import config # ── MARS-style web search config ── MARS_RETRIEVAL_ADDRESS = getattr(config, 'MARS_RETRIEVAL_ADDRESS', '') MARS_SUMMARIZER_ADDRESS = getattr(config, 'MARS_SUMMARIZER_ADDRESS', '') MARS_RETRIEVAL_TOPK = getattr(config, 'MARS_RETRIEVAL_TOPK', 3) MARS_SUMMARIZER_MODEL = getattr(config, 'MARS_SUMMARIZER_MODEL', '') 
MARS_WEB_SEARCH_MODE = getattr(config, 'MARS_WEB_SEARCH_MODE', 'serper')
MARS_RETRIEVAL_TIMEOUT = getattr(config, 'MARS_RETRIEVAL_TIMEOUT', 120)
MARS_RETRIEVAL_CONCURRENCY = getattr(config, 'MARS_RETRIEVAL_CONCURRENCY', 0)
IMAGE_SEARCH_SUMMARIZE_SERPER = getattr(config, 'IMAGE_SEARCH_SUMMARIZE_SERPER', True)
IMAGE_SEARCH_SUMMARIZER_ADDRESS = getattr(config, 'IMAGE_SEARCH_SUMMARIZER_ADDRESS', MARS_SUMMARIZER_ADDRESS)
IMAGE_SEARCH_SUMMARIZER_MODEL = getattr(config, 'IMAGE_SEARCH_SUMMARIZER_MODEL', MARS_SUMMARIZER_MODEL)
IMAGE_SEARCH_SUMMARIZER_MAX_RESULTS = getattr(config, 'IMAGE_SEARCH_SUMMARIZER_MAX_RESULTS', 5)
IMAGE_SEARCH_SUMMARIZER_MAX_TOKENS = getattr(config, 'IMAGE_SEARCH_SUMMARIZER_MAX_TOKENS', 512)
GCP_PROJECT_ID = getattr(config, 'GCP_PROJECT_ID', '')
GCP_LOCATION = getattr(config, 'GCP_LOCATION', '')
GCP_SERVICE_ACCOUNT_KEY = getattr(config, 'GCP_SERVICE_ACCOUNT_KEY', '')
VERTEX_CREDENTIALS_POOL = getattr(config, 'VERTEX_CREDENTIALS_POOL', [])

# Import phase-specific system prompts from prompts.py
from prompts import PHASE1_SYSTEM_PROMPT, PHASE2_SYSTEM_PROMPT, SYSTEM_PROMPT_BASE

# Legacy alias — kept for backward compatibility if any external code references it
SYSTEM_PROMPT = PHASE1_SYSTEM_PROMPT


# ════════════════════════════════════════════════════════════════════════
# Token Tracking & Cost Estimation
# ════════════════════════════════════════════════════════════════════════

def _longest_pattern_match(model_lower: str, table: Dict[str, Any]) -> Optional[Any]:
    """Return the value whose key is the longest substring of *model_lower*.

    Keys are compared case-insensitively. Returns None when no key matches.
    """
    best_len = 0
    best_value = None
    for pattern, value in table.items():
        if len(pattern) > best_len and pattern.lower() in model_lower:
            best_len = len(pattern)
            best_value = value
    return best_value


def _get_pricing(model: str) -> Tuple[float, float, float]:
    """Get (input_price, output_price, thinking_price) per 1M tokens for a model.

    The model name is matched by substring against the keys of TOKEN_PRICING
    and THINKING_TOKEN_PRICING (e.g. 'gemini-2.5-flash' matches
    'gemini-2.5-flash-preview-05-20'); in both tables the longest matching
    key wins. Without a token-pricing match the DEFAULT_* prices are used;
    without a thinking-pricing match the thinking price equals the output
    price. Prices are USD per 1M tokens.
    """
    model_lower = (model or "").lower()

    input_price = DEFAULT_INPUT_PRICE_PER_1M
    output_price = DEFAULT_OUTPUT_PRICE_PER_1M
    matched = _longest_pattern_match(model_lower, TOKEN_PRICING)
    if matched is not None:
        input_price, output_price = matched

    thinking_price = _longest_pattern_match(model_lower, THINKING_TOKEN_PRICING)
    if thinking_price is None:
        thinking_price = output_price

    return input_price, output_price, thinking_price


def _estimate_cost_usd(model: str, prompt_tokens: int, completion_tokens: int,
                       thinking_tokens: int, ndigits: int) -> Dict[str, float]:
    """Compute the rounded input/output/thinking/total cost breakdown in USD."""
    input_price, output_price, thinking_price = _get_pricing(model)
    input_cost = prompt_tokens * input_price / 1_000_000
    # completion_tokens includes thinking_tokens for some APIs, so the two
    # are separated before pricing to avoid charging thinking tokens twice.
    non_thinking = max(0, completion_tokens - thinking_tokens)
    output_cost = non_thinking * output_price / 1_000_000
    thinking_cost = thinking_tokens * thinking_price / 1_000_000
    return {
        "input_cost_usd": round(input_cost, ndigits),
        "output_cost_usd": round(output_cost, ndigits),
        "thinking_cost_usd": round(thinking_cost, ndigits),
        "total_cost_usd": round(input_cost + output_cost + thinking_cost, ndigits),
    }


@dataclass
class TokenUsage:
    """Token usage for a single LLM call."""
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0
    # Some APIs (e.g., Gemini thinking mode) separate thinking tokens
    thinking_tokens: int = 0
    # Cache-related (some APIs report cached token counts)
    cached_tokens: int = 0

    def to_dict(self) -> Dict[str, int]:
        """Return the five counters as a plain dict."""
        return {
            "prompt_tokens": self.prompt_tokens,
            "completion_tokens": self.completion_tokens,
            "total_tokens": self.total_tokens,
            "thinking_tokens": self.thinking_tokens,
            "cached_tokens": self.cached_tokens,
        }

    @staticmethod
    def from_api_response(usage_data: Dict[str, Any]) -> "TokenUsage":
        """Parse token usage from an OpenAI-compatible API response.

        Handles various API formats:
        - Standard: {prompt_tokens, completion_tokens, total_tokens}
        - Gemini extended: {..., completion_tokens_details: {reasoning_tokens: N}}
        - Some APIs: {input_tokens, output_tokens}

        Missing, None or zero fields fall through to the next candidate key
        and finally to 0. An empty/None payload yields an all-zero TokenUsage.
        """
        if not usage_data:
            return TokenUsage()

        def first_truthy(source: Dict[str, Any], *keys: str) -> int:
            # First non-zero, non-None value among the candidate keys, else 0.
            for key in keys:
                found = source.get(key, 0)
                if found:
                    return found
            return 0

        prompt = first_truthy(usage_data, "prompt_tokens", "input_tokens")
        completion = first_truthy(usage_data, "completion_tokens", "output_tokens")
        total = usage_data.get("total_tokens", 0) or (prompt + completion)

        # Thinking/reasoning tokens: nested details first, then top level.
        thinking = 0
        details = usage_data.get("completion_tokens_details", {})
        if isinstance(details, dict):
            thinking = first_truthy(details, "reasoning_tokens", "thinking_tokens")
        if not thinking:
            thinking = first_truthy(usage_data, "reasoning_tokens", "thinking_tokens")

        # Cached tokens: nested prompt details first, then top level.
        cached = 0
        prompt_details = usage_data.get("prompt_tokens_details", {})
        if isinstance(prompt_details, dict):
            cached = prompt_details.get("cached_tokens", 0) or 0
        if not cached:
            cached = usage_data.get("cached_tokens", 0) or 0

        return TokenUsage(
            prompt_tokens=prompt,
            completion_tokens=completion,
            total_tokens=total,
            thinking_tokens=thinking,
            cached_tokens=cached,
        )


@dataclass
class TokenTracker:
    """Tracks token usage across multiple LLM calls for a single data entry.

    Accumulates prompt_tokens, completion_tokens, thinking_tokens and
    computes estimated cost based on model pricing.
    """
    model: str = ""
    total_prompt_tokens: int = 0
    total_completion_tokens: int = 0
    total_thinking_tokens: int = 0
    total_cached_tokens: int = 0
    num_calls: int = 0
    call_details: List[Dict[str, Any]] = field(default_factory=list)

    def add(self, usage: TokenUsage, call_label: str = ""):
        """Add token usage from one LLM call."""
        self.total_prompt_tokens += usage.prompt_tokens
        self.total_completion_tokens += usage.completion_tokens
        self.total_thinking_tokens += usage.thinking_tokens
        self.total_cached_tokens += usage.cached_tokens
        self.num_calls += 1
        self.call_details.append({
            "label": call_label,
            **usage.to_dict(),
        })

    @property
    def total_tokens(self) -> int:
        """Prompt plus completion tokens accumulated so far."""
        return self.total_prompt_tokens + self.total_completion_tokens

    def estimate_cost(self, model: str = "") -> Dict[str, float]:
        """Estimate cost in USD based on model pricing.

        Uses *model* when given, otherwise the tracker's own model.
        Returns dict with input_cost_usd, output_cost_usd, thinking_cost_usd
        and total_cost_usd, each rounded to 6 decimals.
        """
        return _estimate_cost_usd(
            model or self.model,
            self.total_prompt_tokens,
            self.total_completion_tokens,
            self.total_thinking_tokens,
            6,
        )

    def to_dict(self, model: str = "") -> Dict[str, Any]:
        """Export full tracking info as a dict."""
        cost = self.estimate_cost(model)
        return {
            "num_llm_calls": self.num_calls,
            "total_prompt_tokens": self.total_prompt_tokens,
            "total_completion_tokens": self.total_completion_tokens,
            "total_thinking_tokens": self.total_thinking_tokens,
            "total_cached_tokens": self.total_cached_tokens,
            "total_tokens": self.total_tokens,
            "estimated_cost": cost,
            "call_details": self.call_details,
        }

    def summary_str(self, model: str = "") -> str:
        """Human-readable one-line summary."""
        cost = self.estimate_cost(model)
        return (
            f"calls={self.num_calls} "
            f"prompt={self.total_prompt_tokens:,} "
            f"completion={self.total_completion_tokens:,} "
            f"thinking={self.total_thinking_tokens:,} "
            f"total={self.total_tokens:,} "
            f"cost=${cost['total_cost_usd']:.4f}"
        )


class GlobalTokenStats:
    """Aggregator for token stats across all entries.

    Updates go through an asyncio.Lock, so concurrent coroutines on one event
    loop are serialised. This is not a threading lock: it gives no protection
    across OS threads.
    """

    def __init__(self, model: str = ""):
        self.model = model
        self.total_prompt_tokens = 0
        self.total_completion_tokens = 0
        self.total_thinking_tokens = 0
        self.total_cached_tokens = 0
        self.total_calls = 0
        self.total_entries = 0
        self._lock = asyncio.Lock()

    async def add(self, tracker: TokenTracker):
        """Fold one entry's TokenTracker totals into the global counters."""
        async with self._lock:
            self.total_prompt_tokens += tracker.total_prompt_tokens
            self.total_completion_tokens += tracker.total_completion_tokens
            self.total_thinking_tokens += tracker.total_thinking_tokens
            self.total_cached_tokens += tracker.total_cached_tokens
            self.total_calls += tracker.num_calls
            self.total_entries += 1

    @property
    def total_tokens(self) -> int:
        """Prompt plus completion tokens accumulated across all entries."""
        return self.total_prompt_tokens + self.total_completion_tokens

    def estimate_cost(self) -> Dict[str, float]:
        """Estimate the aggregate cost in USD, each figure rounded to 4 decimals."""
        return _estimate_cost_usd(
            self.model,
            self.total_prompt_tokens,
            self.total_completion_tokens,
            self.total_thinking_tokens,
            4,
        )

    def summary_str(self) -> str:
        """Multi-line, human-readable usage and cost report."""
        cost = self.estimate_cost()
        entries = max(1, self.total_entries)  # avoid division by zero
        avg_per_entry = self.total_tokens / entries
        avg_cost = cost["total_cost_usd"] / entries
        return (
            f"\n{'=' * 72}\n"
            f" Token Usage Summary\n"
            f"{'=' * 72}\n"
            f" Model : {self.model}\n"
            f" Total entries : {self.total_entries}\n"
            f" Total LLM calls : {self.total_calls}\n"
            f" Total prompt tok : {self.total_prompt_tokens:,}\n"
            f" Total completion : {self.total_completion_tokens:,}\n"
            f" Total thinking : {self.total_thinking_tokens:,}\n"
            f" Total cached : {self.total_cached_tokens:,}\n"
            f" Total tokens : {self.total_prompt_tokens + self.total_completion_tokens:,}\n"
            f" ──────────────────────────────────────\n"
            f" Avg tokens/entry : {avg_per_entry:,.0f}\n"
            f" Avg cost/entry : ${avg_cost:.4f}\n"
            f" ──────────────────────────────────────\n"
            f" Input cost : ${cost['input_cost_usd']:.4f}\n"
            f" Output cost : ${cost['output_cost_usd']:.4f}\n"
            f" Thinking cost : ${cost['thinking_cost_usd']:.4f}\n"
            f" TOTAL COST : ${cost['total_cost_usd']:.4f}\n"
            f"{'=' * 72}"
        )

    def to_dict(self) -> Dict[str, Any]:
        """Export the aggregate counters, averages and cost as a dict."""
        cost = self.estimate_cost()
        entries = max(1, self.total_entries)  # avoid division by zero
        return {
            "model": self.model,
            "total_entries": self.total_entries,
            "total_llm_calls": self.total_calls,
            "total_prompt_tokens": self.total_prompt_tokens,
            "total_completion_tokens": self.total_completion_tokens,
            "total_thinking_tokens": self.total_thinking_tokens,
            "total_cached_tokens": self.total_cached_tokens,
            "total_tokens": self.total_tokens,
            "avg_tokens_per_entry": round(self.total_tokens / entries),
            "avg_cost_per_entry_usd": round(cost["total_cost_usd"] / entries, 6),
            "estimated_cost": cost,
        }


# ════════════════════════════════════════════════════════════════════════
# Think Validation Utilities
# ════════════════════════════════════════════════════════════════════════

_THINK_TAG_RE = re.compile(r'</?think>')
_THINK_BLOCK_RE = re.compile(r'<think>(.*?)</think>', re.DOTALL)


def think_is_nonempty(text: str) -> bool:
    """Check if a think block contains basic meaningful content.

    Relaxed version: there are no length/word-count limits. After removing
    <think>/</think> tags the text must contain at least five alphanumeric
    characters, which proves it is not just whitespace or punctuation.
    Alphanumeric is judged with str.isalnum, so non-Latin scripts count too.
    """
    if not text:
        return False
    clean = _THINK_TAG_RE.sub('', text).strip()
    return sum(1 for ch in clean if ch.isalnum()) >= 5


def extract_think_text(text: str) -> str:
    """Extract the stripped content of the first <think>...</think> block.

    Returns '' when the text contains no complete think block.
    """
    m = _THINK_BLOCK_RE.search(text)
    return (m.group(1) or '').strip() if m else ''


def _dedup_think_content(text: str) -> str:
    """Remove paragraph-level repetition inside a <think> block.

    LLMs sometimes repeat entire paragraphs verbatim within a single turn
    (decoding-level repetition). This function removes such duplicates while
    preserving unique content and ordering.

    Strategy: split on blank lines into paragraphs and keep only the first
    occurrence of each paragraph (compared after whitespace normalisation).
    If that removes nothing, also check whether the whole content is one
    chunk duplicated back-to-back with only a single newline in between.
    Returns the input unchanged when no repetition is found.
    """
    if not text or not text.strip():
        return text

    # --- Case 1: paragraph-level dedup (split on blank lines) ---
    paragraphs = re.split(r'\n\s*\n', text.strip())
    if len(paragraphs) >= 2:
        seen = set()
        unique = []
        for p in paragraphs:
            key = ' '.join(p.split())  # normalise whitespace for comparison
            if key and key not in seen:
                seen.add(key)
                unique.append(p)
        if len(unique) < len(paragraphs):
            return '\n\n'.join(unique)

    # --- Case 2: whole-block duplication without blank-line separator ---
    # e.g. "ABC\nABC" where ABC is a multi-sentence chunk
    stripped = text.strip()
    length = len(stripped)
    if length >= 80:  # only bother for non-trivial blocks
        # try splitting at every \n boundary within 40 chars of the midpoint
        mid = length // 2
        for offset in range(0, min(40, mid)):
            for pos in (mid + offset, mid - offset):
                if pos <= 0 or pos >= length:
                    continue
                if stripped[pos] != '\n':
                    continue
                first_half = stripped[:pos].strip()
                second_half = stripped[pos:].strip()
                if first_half == second_half:
                    return first_half

    return text


# ════════════════════════════════════════════════════════════════════════
# Hallucination Detection & Sanitization
# ════════════════════════════════════════════════════════════════════════

_MODERATION_RE = re.compile(r'\n*---+\s*\n*MODERATION:.*', re.DOTALL | re.IGNORECASE)
# Leftmost complete action block of either kind.
_ACTION_BLOCK_RE = re.compile(
    r'<tool_call>.*?</tool_call>|<answer>.*?</answer>', re.DOTALL)


def sanitize_llm_output(text: str) -> str:
    """Truncate LLM output after the FIRST valid action block.

    Detects and removes hallucinated content where the model generates:
    - Multiple tool_calls in one turn
    - Fake <tool_response> blocks
    - "MODERATION:" blocks
    - Fake search results
    - Both <tool_call> AND <answer> in the same turn

    Returns cleaned text containing at most: <think>...</think> + one action.
    Text without any complete action block is returned (stripped) as-is and
    is expected to be handled by normalize_gpt_output.
    """
    if not text or not text.strip():
        return text

    text = text.strip()

    # <tool_response> must only ever come from the system: cut everything
    # from the first one the model hallucinated.
    tr_start = text.find('<tool_response>')
    if tr_start != -1:
        text = text[:tr_start].strip()

    # Remove any "---" separator + MODERATION blocks
    text = _MODERATION_RE.sub('', text).strip()

    first_action = _ACTION_BLOCK_RE.search(text)
    if first_action is None:
        # No action found — return as-is (will be handled by normalize)
        return text

    # Keep everything up to and including the first action
    return text[:first_action.end()].strip()


def is_hallucinated_output(text: str) -> bool:
    """Check if the LLM output contains hallucination markers.

    Returns True if the output contains:
    - Multiple <tool_call> blocks
    - Any <tool_response> block (should only come from system)
    - "MODERATION:" blocks
    - Both <tool_call> and <answer> in same turn
    """
    if not text:
        return False

    tc_count = text.count('<tool_call>')
    has_tool_response = '<tool_response>' in text
    has_moderation = bool(re.search(r'MODERATION:', text, re.IGNORECASE))
    has_both = tc_count > 0 and '<answer>' in text

    return tc_count > 1 or has_tool_response or has_moderation or has_both


# ════════════════════════════════════════════════════════════════════════
# GPT Output Normalizer
# ════════════════════════════════════════════════════════════════════════

def _keep_first_occurrence(text: str, tag: str) -> str:
    """Drop every occurrence of *tag* after the first one."""
    head, found, tail = text.partition(tag)
    if not found:
        return text
    return head + found + tail.replace(tag, '')


def normalize_gpt_output(text: str) -> str:
    """Ensure a gpt turn follows the strict format:
    <think>...</think> followed by <tool_call>...</tool_call> or <answer>...</answer>

    Steps:
    1. Deduplicate <think> / </think> tags (keep the first of each).
    2. If <think> is present: close it if unclosed (before the first action,
       else at the end) and remove repeated paragraphs inside it.
    3. If <think> is absent: wrap the text preceding the action as the think
       block, but only if it has meaningful content.
    4. ZERO FILLER: a missing or empty think is never replaced by invented text.

    This function does not truncate hallucinated trailing content itself;
    run sanitize_llm_output first when that is required.
    """
    if not text or not text.strip():
        return text

    text = text.strip()
    text = _keep_first_occurrence(text, '</think>')
    text = _keep_first_occurrence(text, '<think>')

    # Position of the action that anchors the turn; <tool_call> takes
    # precedence over <answer>. -1 means the turn has no action at all.
    action_pos = text.find('<tool_call>')
    if action_pos == -1:
        action_pos = text.find('<answer>')

    if '<think>' in text:
        if '</think>' not in text:
            # Unclosed think: close it right before the action, else at the end.
            if action_pos != -1:
                text = text[:action_pos] + '</think>\n\n' + text[action_pos:]
            else:
                text = text + '</think>'
        # Dedup paragraph-level repetition inside <think> block
        think_text = extract_think_text(text)
        if think_text:
            deduped = _dedup_think_content(think_text)
            if deduped != think_text:
                text = text.replace(think_text, deduped, 1)
        # No forced completion: even an empty <think> is returned untouched.
        return text

    # No <think> tag from here on. A stray </think> (closing tag without an
    # opener) is removed so that wrapping cannot yield a doubled closing tag.
    if action_pos == -1:
        # No tool_call and no answer — wrap whatever it is as think and let
        # downstream filters drop it if needed.
        body = text.replace('</think>', '').strip()
        return f"<think>{body}</think>"

    pre_text = text[:action_pos].replace('</think>', '').strip()
    action_and_after = text[action_pos:]
    if pre_text and think_is_nonempty(pre_text):
        return f"<think>{pre_text}</think>\n\n{action_and_after}"
    # No meaningful leading text: return the bare action, add no filler.
    return action_and_after


# ════════════════════════════════════════════════════════════════════════
# Training Data Conversation Cleaning
# ════════════════════════════════════════════════════════════════════════

def clean_conversations(conversations: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Remove error recovery turns from training conversations.

    Produces a clean trajectory suitable for SFT training by:
    1. Removing empty gpt turns (model returned nothing)
    2. Removing gpt turns without any action (<tool_call> or <answer>)
    3. Removing human turns that are error tool responses
    4. Re-normalizing every remaining gpt turn

    Every gpt turn is validated, including one that directly follows a
    dropped gpt turn. Turns with any other role are kept untouched.
    Returns a new list; an empty/None input is returned as-is.
    """
    if not conversations:
        return conversations

    cleaned = []
    for turn in conversations:
        role = turn.get("from", "")
        value = turn.get("value", "")

        if role == "gpt":
            stripped = value.strip() if value else ""
            if not stripped:
                continue  # empty gpt turn
            if '<tool_call>' not in stripped and '<answer>' not in stripped:
                continue  # think-only turn with no valid action
            cleaned.append({"from": "gpt", "value": normalize_gpt_output(stripped)})
        elif role == "human":
            if not _is_error_tool_response(value):
                cleaned.append(turn)
        else:
            # system or other — keep as-is
            cleaned.append(turn)

    return cleaned


def _is_error_tool_response(value: str) -> bool:
    """Check if a human turn is an error tool_response that should be removed."""
    if not value:
        return False
    v = value.strip()
    error_markers = (
        "No <tool_call> or <answer> found",
        "Error: Tool",
        "Malformed tool_call JSON",
        "Error: bbox must have exactly",
        "Error: web_search requires a non-empty query",
    )
    return any(marker in v for marker in error_markers)


# ════════════════════════════════════════════════════════════════════════
# Frame
# Utilities
# ════════════════════════════════════════════════════════════════════════

# Exactly the file names written by extract_all_frames (frame_%06d.jpg).
_FRAME_FILE_RE = re.compile(r'frame_(\d+)\.jpg')


def get_video_frame_count(video_path: str) -> int:
    """Get total frame count from a 1fps video (duration ~ frames).

    Runs ffprobe and rounds the reported duration up to whole seconds.
    Returns -1 when ffprobe cannot be run, times out, or prints something
    that is not a number.
    """
    try:
        result = subprocess.run(
            ["ffprobe", "-v", "error", "-show_entries", "format=duration",
             "-of", "csv=p=0", video_path],
            capture_output=True, text=True, timeout=30)
        return int(math.ceil(float(result.stdout.strip())))
    except (OSError, subprocess.SubprocessError, ValueError, OverflowError):
        return -1


def _scan_frame_files(output_dir: str) -> Dict[int, str]:
    """Map frame index -> path for every frame_<digits>.jpg in *output_dir*.

    Other files are ignored, even if they start with "frame_".
    """
    frames = {}
    for name in sorted(os.listdir(output_dir)):
        m = _FRAME_FILE_RE.fullmatch(name)
        if m:
            frames[int(m.group(1))] = os.path.join(output_dir, name)
    return frames


def extract_all_frames(video_path: str, output_dir: str,
                       max_resolution: int = 768,
                       jpeg_quality: int = 85) -> Dict[int, str]:
    """Extract all frames from 1fps video. Returns {frame_index(0-based): path}.

    Reuses existing frames if the directory already has them; in that case
    ffmpeg is not invoked and *video_path* is not touched.

    Args:
        video_path: input video file.
        output_dir: directory for frame_%06d.jpg files (created if missing).
        max_resolution: upper bound for the scaled width and height.
        jpeg_quality: 1-100 quality, mapped onto ffmpeg's -q:v scale (2-31).

    Raises:
        RuntimeError: if ffmpeg exits non-zero or produces no frames.
    """
    os.makedirs(output_dir, exist_ok=True)

    existing = _scan_frame_files(output_dir)
    if existing:
        return existing

    vf = (f"scale='min({max_resolution},iw)':'min({max_resolution},ih)'"
          f":force_original_aspect_ratio=decrease")
    # ffmpeg's -q:v is inverted: 2 is best, 31 is worst.
    q = max(2, min(31, round(2 + (100 - jpeg_quality) * 29 / 99)))
    cmd = [
        "ffmpeg", "-y", "-i", video_path,
        "-vf", vf, "-q:v", str(q),
        "-start_number", "0",
        os.path.join(output_dir, "frame_%06d.jpg"),
    ]
    proc = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
    if proc.returncode != 0:
        raise RuntimeError(f"ffmpeg failed: {proc.stderr[-500:]}")

    frames = _scan_frame_files(output_dir)
    if not frames:
        raise RuntimeError(f"No frames extracted from {video_path}")
    return frames


def uniform_sample_indices(total_frames: int, num_samples: int) -> List[int]:
    """Uniformly sample 0-based frame indices. total_frames is the count.

    Returns every index when there are no more frames than samples, and an
    empty list when either argument is not positive. The result is sorted
    and duplicate-free, so it may hold fewer than num_samples entries.
    """
    if total_frames <= 0 or num_samples <= 0:
        return []
    if total_frames <= num_samples:
        return list(range(total_frames))
    return sorted(set(int(i) for i in np.linspace(0, total_frames - 1, num_samples)))


def sample_interval(all_frames: Dict[int, str], start: int, end: int,
                    num_samples: int = 8) -> List[Tuple[int, str]]:
    """Uniformly sample frames from the inclusive [start, end] interval.

    Returns (index, path) pairs in ascending index order; all available
    frames when there are at most num_samples of them, [] when the interval
    holds no frame or num_samples is not positive.
    """
    available = sorted(k for k in all_frames if start <= k <= end)
    if not available or num_samples <= 0:
        return []
    if len(available) <= num_samples:
        return [(k, all_frames[k]) for k in available]
    positions = np.linspace(0, len(available) - 1, num_samples, dtype=int)
    selected = sorted(set(available[p] for p in positions))
    return [(k, all_frames[k]) for k in selected]


def get_frame(all_frames: Dict[int, str], idx: int) -> Tuple[int, str]:
    """Get exact frame or nearest available.

    Raises:
        ValueError: if all_frames is empty.
    """
    if idx in all_frames:
        return (idx, all_frames[idx])
    if not all_frames:
        raise ValueError("get_frame: no frames available")
    nearest = min(all_frames, key=lambda k: abs(k - idx))
    return (nearest, all_frames[nearest])


# ════════════════════════════════════════════════════════════════════════
# Bbox Format Detection & Normalization
# ════════════════════════════════════════════════════════════════════════

def get_bbox_config(model: str) -> dict:
    """Return the bbox format configuration for a model name.

    Matching rule: among the BBOX_CONFIGS keys contained in the
    (case-insensitive) model name, the longest one wins. Without a match the
    "default" entry is returned, or xyxy/norm if there is none.

    Returns:
        {"order": "xyxy"|"yxyx", "range": "norm"|"permille"}
    """
    model_lower = (model or "").lower()
    best_match = ""
    best_config = BBOX_CONFIGS.get("default", {"order": "xyxy", "range": "norm"})
    for pattern, candidate in BBOX_CONFIGS.items():
        if pattern == "default":
            continue
        if pattern.lower() in model_lower and len(pattern) > len(best_match):
            best_match = pattern
            best_config = candidate
    return best_config


def _ensure_min_extent(lo: float, hi: float) -> Tuple[float, float]:
    """Widen a degenerate [lo, hi] span to 0.01 while staying inside [0, 1]."""
    if hi - lo < 0.001:
        hi = min(1.0, lo + 0.01)
        if hi - lo < 0.001:
            # The span sits against the upper edge: grow it downwards instead.
            lo = max(0.0, hi - 0.01)
    return lo, hi


def normalize_bbox(raw_bbox: list, bbox_config: dict) -> list:
    """Convert a model-emitted bbox to [x1, y1, x2, y2] with values in [0.0, 1.0].

    Supported input formats:
    - Gemini: [y_min, x_min, y_max, x_max] in [0, 1000]
    - Standard: [x1, y1, x2, y2] in [0.0, 1.0]
    - Values in (1, 1000] are treated as permille even when the config says
      "norm" (auto-detection).

    Absolute pixel coordinates (any value > 1000) cannot be normalised here
    because the image size is unknown; they are only clamped, which collapses
    the box onto the frame edge.

    Args:
        raw_bbox: list of 4 numbers.
        bbox_config: configuration dict returned by get_bbox_config().

    Returns:
        [x1, y1, x2, y2], all within [0.0, 1.0], with x1 < x2 and y1 < y2.
        Input that is not four finite numbers yields the full frame
        [0.0, 0.0, 1.0, 1.0].
    """
    full_frame = [0.0, 0.0, 1.0, 1.0]
    try:
        coords = [float(v) for v in raw_bbox]
    except (TypeError, ValueError):
        return full_frame
    if len(coords) != 4 or not all(math.isfinite(c) for c in coords):
        return full_frame

    # Step 1: resolve coordinate order
    if bbox_config.get("order") == "yxyx":
        # Gemini: [y_min, x_min, y_max, x_max]
        y1, x1, y2, x2 = coords
    else:
        # Standard: [x1, y1, x2, y2]
        x1, y1, x2, y2 = coords

    # Step 2: bring the value range to [0.0, 1.0]
    if bbox_config.get("range") == "permille":
        scale = 1000.0
    else:
        # Auto-detect: a maximum in (1, 1000] means permille.
        max_coord = max(abs(x1), abs(y1), abs(x2), abs(y2))
        scale = 1000.0 if 1.0 < max_coord <= 1000 else 1.0
    if scale != 1.0:
        x1, y1, x2, y2 = x1 / scale, y1 / scale, x2 / scale, y2 / scale

    # Step 3: clamp to [0.0, 1.0]
    x1, y1, x2, y2 = (max(0.0, min(1.0, v)) for v in (x1, y1, x2, y2))

    # Step 4: ensure x1 <= x2 and y1 <= y2
    if x1 > x2:
        x1, x2 = x2, x1
    if y1 > y2:
        y1, y2 = y2, y1

    # Step 5: guarantee a minimum area, also for boxes on the far edge
    x1, x2 = _ensure_min_extent(x1, x2)
    y1, y2 = _ensure_min_extent(y1, y2)

    return [x1, y1, x2, y2]


# ════════════════════════════════════════════════════════════════════════
# Image Search Failure Handling & Padding
# ════════════════════════════════════════════════════════════════════════

class RetrieverDownError(Exception):
    """The retriever service (IP/port) is unreachable; the whole pipeline should stop."""


class ImageSearchFailedError(Exception):
    """Raised when image_search returns an error, timeout, or no results.

    Signals the entire entry should be retried from scratch so that failed
    serper results never appear in the SFT dataset.
    """


class QuotaExhaustedError(Exception):
    """API quota is really exhausted (not a transient rate limit); the whole pipeline should stop."""


class ProhibitedContentError(Exception):
    """Content was blocked by the Vertex AI safety policy (PROHIBITED_CONTENT); do not retry."""


class ProjectDisabledError(Exception):
    """A Vertex project is unusable and should be disabled in the project pool for this run."""


def _is_quota_exhausted(status_code: int, error_body: str) -> bool:
    """Decide whether an API error means the quota is really exhausted.

    A Vertex AI 429 / RESOURCE_EXHAUSTED is normally a transient rate limit
    (QPM/TPM saturated) that should be retried, so 429 never counts. Only a
    403 whose body contains an explicit quota or billing keyword counts as
    exhaustion.

    Reference: https://docs.cloud.google.com/docs/quotas/troubleshoot
    - 429 → transient rate limit, back off and retry
    - 403 + QUOTA_EXCEEDED / RATE_LIMIT_EXCEEDED / billing → quota exhausted
    """
    if status_code != 403:
        return False

    body_lower = (error_body or "").lower()
    quota_keywords = (
        "quota_exceeded",
        "quota exceeded",
        "rate_limit_exceeded",
        "billing account",
        "billing is disabled",
        "insufficient quota",
        "out of quota",
        "daily limit",
        "per-day limit",
    )
    return any(kw in body_lower for kw in quota_keywords)


def is_image_search_failed(result_text: str) -> bool:
    """Check if an image search result indicates a failure.

    Returns True for an empty/blank result or one containing any of the
    known error, timeout or empty-result messages.
    """
    if not result_text or not result_text.strip():
        return True
    fail_patterns = (
        "Image search error:",
        "No results found from reverse image search",
        "Error: SERPER_API_KEY not configured",
        "request timed out",
    )
    return any(pattern in result_text for pattern in fail_patterns)


def add_search_padding(bbox: List[float], frame_path: str,
                       padding: tuple = (0.5, 0.5),
                       padding_cap_px: int = 600) -> List[float]:
    """Add padding to a normalized [x1, y1, x2, y2] bbox for image search.

    Padding is proportional to the BBOX size (not image size), so that:
    - A tight face crop gets moderate expansion (shoulders, some background)
    - A large crop doesn't balloon to cover the entire frame

    Args:
        bbox: [x1, y1, x2, y2] normalized to [0.0, 1.0]
        frame_path: path to the source frame (used to compute pixel cap)
        padding: (pad_x_ratio, pad_y_ratio) as fraction of bbox width/height
                 e.g. (0.5, 0.5) means expand each side by 50% of bbox dimension
        padding_cap_px: max padding in pixels on each side

    Returns:
        Padded [x1, y1, x2, y2] clamped to [0.0, 1.0]
    """
    x1, y1, x2, y2 = bbox
    pad_x_ratio, pad_y_ratio = padding

    # Padding proportional to bbox dimensions
    pad_x = (x2 - x1) * pad_x_ratio
    pad_y = (y2 - y1) * pad_y_ratio

    # Cap padding at padding_cap_px pixels (converted to normalized coords).
    # Deliberately best-effort: if the frame cannot be read for any reason,
    # the uncapped bbox-proportional padding is used.
    try:
        with Image.open(frame_path) as img:
            img_w, img_h = img.size
        cap_x = padding_cap_px / img_w
        cap_y = padding_cap_px / img_h
        pad_x = min(pad_x, cap_x)
        pad_y = min(pad_y, cap_y)
    except Exception:
        pass

    return [
        max(0.0, x1 - pad_x),
        max(0.0, y1 - pad_y),
        min(1.0, x2 + pad_x),
        min(1.0, y2 + pad_y),
    ]


def crop_frame(frame_path: str, bbox: List[float], output_path: str) -> str:
    """Crop frame at bbox coordinates with smart format detection, clamping,
    and 2x upscaling for better search quality.

    Bbox format auto-detection (on the largest absolute coordinate):
    - <= 1.0   → normalized relative coordinates (standard)
    - <= 1000  → permille relative coordinates (some models output this)
    - > 1000   → absolute pixel coordinates

    After cropping:
    - convert to RGB when the mode cannot be written as JPEG (e.g. RGBA, P)
    - 2x LANCZOS upscale to help visual search models recognize small objects
    - save as high-quality JPEG (quality=95)

    Args:
        frame_path: path to the source frame image
        bbox: [x1, y1, x2, y2] in any of the three supported formats
        output_path: where to save the cropped image

    Returns:
        output_path
    """
    with Image.open(frame_path) as img:
        w, h = img.size
        raw_x1, raw_y1, raw_x2, raw_y2 = bbox

        # ── Step 1: Auto-detect coordinate format and convert to pixels ──
        max_coord = max(abs(raw_x1), abs(raw_y1), abs(raw_x2), abs(raw_y2))
        if max_coord <= 1.0:
            # Format A: normalized [0.0, 1.0]
            scale_x, scale_y = w, h
        elif max_coord <= 1000:
            # Format B: permille [0, 1000]
            scale_x, scale_y = w / 1000.0, h / 1000.0
        else:
            # Format C: absolute pixel coordinates
            scale_x = scale_y = 1
        px_x1, px_x2 = raw_x1 * scale_x, raw_x2 * scale_x
        px_y1, px_y2 = raw_y1 * scale_y, raw_y2 * scale_y

        # ── Step 2: Clamp to image bounds ──
        x1 = max(0, min(int(round(px_x1)), w - 1))
        y1 = max(0, min(int(round(px_y1)), h - 1))
        x2 = max(0, min(int(round(px_x2)), w))
        y2 = max(0, min(int(round(px_y2)), h))

        # Ensure minimum 1px crop (prevent zero-area)
        if x2 <= x1:
            x2 = min(x1 + 1, w)
        if y2 <= y1:
            y2 = min(y1 + 1, h)

        # ── Step 3: Crop ──
        cropped_img = img.crop((x1, y1, x2, y2))

        # JPEG cannot store alpha or palette images. Converting before the
        # resize also lets palette images be resampled with LANCZOS.
        if cropped_img.mode not in ("RGB", "L", "CMYK"):
            cropped_img = cropped_img.convert("RGB")

        # ── Step 4: 2x upscale with LANCZOS for better search recognition ──
        cropped_img = cropped_img.resize(
            (cropped_img.width * 2, cropped_img.height * 2),
            Image.Resampling.LANCZOS,
        )

        # ── Step 5: Save as high-quality JPEG ──
        cropped_img.save(output_path, 'JPEG', quality=95)
    return output_path


def encode_image_b64(path: str) -> str:
    """Return the file's bytes as an ASCII base64 string."""
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("ascii")


#
# ════════════════════════════════════════════════════════════════════════
# LLM Client with Multi-API Load Balancing (v3 — Token Tracking)
# ════════════════════════════════════════════════════════════════════════

class LLMClient:
    """Async LLM client with project-pool failover across API endpoints.

    v3: call() returns (content, reasoning_content, TokenUsage).
    v4: Google Vertex AI authentication support (OAuth2 auto-refresh).
    v5: Vertex AI native generateContent API support (for the global region).

    Requests go to the currently "active" URL; URLs that are disabled or
    cooling down after a rate-limit error are skipped.
    """

    def __init__(self, api_urls: List[str], model: str,
                 concurrency_per_url: int = 2,
                 temperature: float = DEFAULT_TEMPERATURE,
                 top_p: float = DEFAULT_TOP_P,
                 max_tokens: int = DEFAULT_MAX_TOKENS,
                 timeout: int = DEFAULT_TIMEOUT,
                 max_retries: int = DEFAULT_MAX_RETRIES,
                 api_key: str = API_KEY):
        self.api_urls = api_urls
        self.model = model
        self.temperature = temperature
        self.top_p = top_p
        self.max_tokens = max_tokens
        self.timeout = timeout
        self.max_retries = max_retries
        self.api_key = api_key
        self._url_cycle = itertools.cycle(api_urls)
        self._semaphores = {url: asyncio.Semaphore(concurrency_per_url)
                            for url in api_urls}
        self._cycle_lock = asyncio.Lock()

        # Project pool: switch automatically on 429.
        self._active_url_index = 0  # index of the currently active URL
        self._url_switch_lock = asyncio.Lock()
        self._project_cooldown_until = {url: 0.0 for url in api_urls}
        self._disabled_urls: Set[str] = set()

        # Vertex request shaping: reduce 429s caused by sub-second bursts.
        self._vertex_min_interval = max(0.0, float(VERTEX_MIN_REQUEST_INTERVAL_SECONDS or 0.0))
        self._vertex_rate_limit_cooldown = max(0.0, float(VERTEX_RATE_LIMIT_COOLDOWN_SECONDS or 0.0))
        self._vertex_request_jitter = max(0.0, float(VERTEX_REQUEST_JITTER_SECONDS or 0.0))
        self._vertex_next_request_at = 0.0
        self._vertex_pacing_lock = asyncio.Lock()

        # Detect the API mode from the configured URLs.
        self._use_vertex_auth = self._detect_vertex_endpoint()
        self._use_vertex_native = self._detect_vertex_native()
        self._gcp_credentials = None
        self._gcp_credentials_by_url: Dict[str, Any] = {}
        self._gcp_auth_request = None

        if self._use_vertex_auth:
            self._init_gcp_auth()

    def _detect_vertex_endpoint(self) -> bool:
        """Return True if any configured URL points at Vertex AI."""
        return any("aiplatform.googleapis.com" in url for url in self.api_urls)

    def _detect_vertex_native(self) -> bool:
        """Return True if the native generateContent API should be used.

        True when any URL contains ``/publishers/google/models`` (native,
        as opposed to the OpenAI-compatible chat/completions endpoint), or
        when ``config.USE_VERTEX_NATIVE_API`` is truthy.
        """
        if any("/publishers/google/models" in url for url in self.api_urls):
            return True
        return bool(getattr(config, 'USE_VERTEX_NATIVE_API', False))

    def _init_gcp_auth(self):
        """Load and refresh Google Cloud credentials.

        Pool entries in VERTEX_CREDENTIALS_POOL whose ``api_url`` is one of
        this client's URLs take precedence (one credential per URL).
        Otherwise the GCP_SERVICE_ACCOUNT_KEY file is used, falling back to
        Application Default Credentials.

        Raises:
            ImportError: google-auth is not installed.
            RuntimeError: credentials could not be loaded or refreshed.
        """
        try:
            import google.auth
            import google.auth.transport.requests
            from google.oauth2 import service_account as sa_module

            scopes = ["https://www.googleapis.com/auth/cloud-platform"]
            self._gcp_auth_request = google.auth.transport.requests.Request()

            pool_records = [
                item for item in VERTEX_CREDENTIALS_POOL
                if isinstance(item, dict) and item.get("api_url") in self.api_urls
            ]

            if pool_records:
                for item in pool_records:
                    credential = None
                    key_path = item.get("service_account_key") or ""
                    key_info = item.get("service_account_info")
                    if key_path and os.path.exists(key_path):
                        credential = sa_module.Credentials.from_service_account_file(
                            key_path,
                            scopes=scopes,
                        )
                        print(
                            f"[GCP AUTH] 使用账号池密钥: account={item.get('account_name', '')} "
                            f"project={item.get('project_id', '')} key={key_path}"
                        )
                    elif isinstance(key_info, dict):
                        credential = sa_module.Credentials.from_service_account_info(
                            key_info,
                            scopes=scopes,
                        )
                        print(
                            f"[GCP AUTH] 使用账号池内嵌密钥: account={item.get('account_name', '')} "
                            f"project={item.get('project_id', '')}"
                        )
                    else:
                        raise RuntimeError(
                            f"Vertex 账号池条目缺少可用凭证: project={item.get('project_id', '')}"
                        )
                    credential.refresh(self._gcp_auth_request)
                    self._gcp_credentials_by_url[item["api_url"]] = credential
                print(f"[GCP AUTH] 账号池认证成功,可用条目数: {len(self._gcp_credentials_by_url)}")
                return

            if GCP_SERVICE_ACCOUNT_KEY and os.path.exists(GCP_SERVICE_ACCOUNT_KEY):
                # Service-account key file.
                self._gcp_credentials = sa_module.Credentials.from_service_account_file(
                    GCP_SERVICE_ACCOUNT_KEY,
                    scopes=scopes,
                )
                print(f"[GCP AUTH] 使用服务账号密钥: {GCP_SERVICE_ACCOUNT_KEY}")
            else:
                # Application Default Credentials; requires a prior
                # `gcloud auth application-default login`.
                self._gcp_credentials, project = google.auth.default(scopes=scopes)
                print(f"[GCP AUTH] 使用 ADC (Application Default Credentials), project={project}")

            # Refresh once up front to confirm the credentials are valid.
            self._gcp_credentials.refresh(self._gcp_auth_request)
            print(f"[GCP AUTH] 认证成功,token 有效期至 {self._gcp_credentials.expiry}")

        except ImportError:
            raise ImportError(
                "使用 Vertex AI 需要安装 google-auth 库:\n"
                " pip install google-auth google-auth-httplib2"
            )
        except Exception as e:
            raise RuntimeError(
                f"Google Cloud 认证失败: {e}\n"
                f"请确保已配置 GCP_SERVICE_ACCOUNT_KEY 或运行 "
                f"`gcloud auth application-default login`"
            ) from e

    def _get_gcp_token(self, api_url: Optional[str] = None) -> str:
        """Return a valid OAuth2 access token, refreshing it when expired.

        The per-URL pool credential is used when one exists for ``api_url``;
        otherwise the shared credential is used.

        Raises:
            RuntimeError: no credential is available for ``api_url``.
        """
        credentials = self._gcp_credentials
        if api_url and api_url in self._gcp_credentials_by_url:
            credentials = self._gcp_credentials_by_url[api_url]
        if credentials is None:
            # Pool mode only loads credentials for URLs listed in the pool;
            # fail clearly instead of with an AttributeError on None.
            raise RuntimeError(f"No GCP credentials configured for endpoint: {api_url}")
        if credentials.expired or not credentials.token:
            credentials.refresh(self._gcp_auth_request)
        return credentials.token

    async def _next_url(self) -> str:
        """Return the next URL from the round-robin cycle."""
        async with self._cycle_lock:
            return next(self._url_cycle)

    def _get_active_url(self) -> str:
        """Return the currently active project URL (changes after a 429 switch)."""
        return self.api_urls[self._active_url_index]

    async def _acquire_vertex_request_slot(self):
        """Throttle Vertex requests client-wide to a minimum spacing.

        No-op unless Vertex auth is in use and a positive minimum interval
        is configured.
        """
        if not self._use_vertex_auth or self._vertex_min_interval <= 0:
            return

        while True:
            async with self._vertex_pacing_lock:
                now = time.monotonic()
                wait = self._vertex_next_request_at - now
                if wait <= 0:
                    reserve = self._vertex_min_interval + random.uniform(0, self._vertex_request_jitter)
                    self._vertex_next_request_at = now + reserve
                    return
            if wait > 0.5:
                print(f" [VERTEX PACE] waiting {wait:.1f}s before next request")
            await asyncio.sleep(max(wait, 0.05))

    async def _mark_project_cooldown(self, failed_url: str):
        """Put ``failed_url`` on cooldown after a 429 so it is not hit again at once."""
        if self._vertex_rate_limit_cooldown <= 0:
            return

        cooldown = self._vertex_rate_limit_cooldown + random.uniform(0, self._vertex_request_jitter)
        async with self._url_switch_lock:
            until_ts = time.monotonic() + cooldown
            prev = self._project_cooldown_until.get(failed_url, 0.0)
            # Never shorten an existing cooldown.
            self._project_cooldown_until[failed_url] = max(prev, until_ts)
        print(f" [PROJECT POOL] cooldown {cooldown:.1f}s for {failed_url[:80]}...")

    async def _disable_project_url(self, failed_url: str, reason: str):
        """Permanently disable a suspended / API-disabled project URL.

        Raises:
            ProjectDisabledError: every configured URL is now disabled.
        """
        async with self._url_switch_lock:
            self._disabled_urls.add(failed_url)
            print(f" [PROJECT POOL] disabled project: {failed_url[:80]}... reason={reason[:160]}")
            if len(self._disabled_urls) >= len(self.api_urls):
                raise ProjectDisabledError("All Vertex projects are disabled or suspended.")
            if self.api_urls[self._active_url_index] == failed_url:
                for idx, url in enumerate(self.api_urls):
                    if url not in self._disabled_urls:
                        self._active_url_index = idx
                        break

    async def _acquire_project_url(self) -> Tuple[str, int]:
        """Pick a project that is not cooling down.

        If every enabled project is cooling down, sleep until the earliest
        one recovers and try again.

        Returns:
            (url, index into self.api_urls)

        Raises:
            ProjectDisabledError: every configured URL is disabled.
        """
        while True:
            async with self._url_switch_lock:
                now = time.monotonic()
                count = len(self.api_urls)
                for offset in range(count):
                    idx = (self._active_url_index + offset) % count
                    url = self.api_urls[idx]
                    if url in self._disabled_urls:
                        continue
                    if self._project_cooldown_until.get(url, 0.0) <= now:
                        self._active_url_index = idx
                        return url, idx

                available_urls = [u for u in self.api_urls if u not in self._disabled_urls]
                if not available_urls:
                    raise ProjectDisabledError("All Vertex projects are disabled or suspended.")
                soonest_url = min(available_urls, key=lambda u: self._project_cooldown_until.get(u, 0.0))
                wait = max(0.0, self._project_cooldown_until.get(soonest_url, 0.0) - now)

            print(f" [PROJECT POOL] all projects cooling down, waiting {wait:.1f}s")
            await asyncio.sleep(max(wait, 0.1))

    async def _switch_to_next_project(self, failed_index: int) -> bool:
        """Advance the active project after a 429 on ``failed_index``.

        Returns True when the active index differs from ``failed_index``
        afterwards (including when another coroutine already switched).
        """
        async with self._url_switch_lock:
            # Another coroutine may already have switched.
            if self._active_url_index != failed_index:
                return True

            now = time.monotonic()
            count = len(self.api_urls)
            for offset in range(1, count + 1):
                next_index = (failed_index + offset) % count
                next_url = self.api_urls[next_index]
                if next_url in self._disabled_urls:
                    continue
                if self._project_cooldown_until.get(next_url, 0.0) <= now:
                    self._active_url_index = next_index
                    print(f" [PROJECT POOL] 429 → 切换到 project #{next_index + 1}/{len(self.api_urls)}: {next_url[:80]}...")
                    return True
            return False

    def _all_projects_cooling_down(self) -> bool:
        """Return True if there are enabled URLs and all of them are cooling down."""
        now = time.monotonic()
        active_urls = [url for url in self.api_urls if url not in self._disabled_urls]
        return bool(active_urls) and all(
            self._project_cooldown_until.get(url, 0.0) > now for url in active_urls
        )

    @staticmethod
    def _is_project_disabled_error(err_str: str) -> bool:
        """Return True if the error text indicates a suspended/disabled project."""
        err_lower = err_str.lower()
        return any(marker in err_lower for marker in (
            "consumer_suspended",
            "has been suspended",
            "service_disabled",
            "api has not been used",
            "api is disabled",
        ))

    @staticmethod
    def _is_rate_limit_error(err_str: str) -> bool:
        """Return True if the error text looks like a rate-limit (429) error."""
        return (
            "429" in err_str
            or "RESOURCE_EXHAUSTED" in err_str
            or "rate limit" in err_str.lower()
        )

    # ════════════════════════════════════════════════════════════════
    # OpenAI → Vertex native message conversion
    # ════════════════════════════════════════════════════════════════

    @staticmethod
    def _convert_messages_to_vertex_native(
        messages: List[Dict],
    ) -> Tuple[Optional[Dict], List[Dict]]:
        """Convert OpenAI-format messages to the Vertex AI native format.

        - ``system`` messages become ``systemInstruction`` (a later system
          message replaces an earlier one).
        - ``assistant`` maps to role ``model``; every other role to ``user``.
        - Text items become ``{"text": ...}`` parts; blank text is dropped.
        - ``image_url`` items with a ``data:`` URI become ``inlineData``,
          ``gs://`` URIs become ``fileData`` (mimeType fixed to image/jpeg);
          other URLs are skipped.
        - Messages that end up with no parts are omitted.

        Returns:
            (system_instruction_dict_or_None, contents_list)
        """
        system_instruction = None
        contents = []

        for msg in messages:
            role = msg.get("role", "")
            content = msg.get("content", "")

            if role == "system":
                if isinstance(content, str):
                    system_instruction = {"parts": [{"text": content}]}
                elif isinstance(content, list):
                    parts = []
                    for item in content:
                        if isinstance(item, str):
                            parts.append({"text": item})
                        elif isinstance(item, dict) and item.get("type") == "text":
                            parts.append({"text": item.get("text", "")})
                    system_instruction = {"parts": parts}
                continue

            vertex_role = "model" if role == "assistant" else "user"

            parts = []
            if isinstance(content, str):
                if content.strip():
                    parts.append({"text": content})
            elif isinstance(content, list):
                for item in content:
                    if isinstance(item, str):
                        if item.strip():
                            parts.append({"text": item})
                    elif isinstance(item, dict):
                        item_type = item.get("type", "")
                        if item_type == "text":
                            # A None text value is treated as empty.
                            text_val = item.get("text", "") or ""
                            if text_val.strip():
                                parts.append({"text": text_val})
                        elif item_type == "image_url":
                            # OpenAI: {"type":"image_url","image_url":{"url":"data:image/jpeg;base64,..."}}
                            image_url = item.get("image_url", {})
                            url = image_url.get("url", "") if isinstance(image_url, dict) else ""
                            if url.startswith("data:"):
                                # data:image/jpeg;base64,xxxx → mimeType + base64 payload
                                try:
                                    header, b64_data = url.split(",", 1)
                                    mime_type = header.split(":")[1].split(";")[0]
                                    parts.append({
                                        "inlineData": {
                                            "mimeType": mime_type,
                                            "data": b64_data,
                                        }
                                    })
                                except (ValueError, IndexError):
                                    # Unparseable data URI: skip the image.
                                    pass
                            elif url.startswith("gs://"):
                                parts.append({
                                    "fileData": {
                                        "fileUri": url,
                                        "mimeType": "image/jpeg",
                                    }
                                })

            if parts:
                contents.append({"role": vertex_role, "parts": parts})

        return system_instruction, contents

    @staticmethod
    def _parse_vertex_native_response(
        result: Dict[str, Any],
    ) -> Tuple[str, str, TokenUsage]:
        """Parse a Vertex AI native generateContent response.

        Text parts of the first candidate are concatenated; parts flagged
        with ``thought`` go to the reasoning text instead. Token counts are
        read from ``usageMetadata``.

        Returns:
            (content_text, reasoning_text, TokenUsage)

        Raises:
            ProhibitedContentError: promptFeedback.blockReason is
                PROHIBITED_CONTENT (must not be retried).
            ValueError: the response has no candidates.
        """
        prompt_feedback = result.get("promptFeedback", {})
        block_reason = prompt_feedback.get("blockReason", "")
        if block_reason == "PROHIBITED_CONTENT":
            raise ProhibitedContentError(
                f"Content blocked by Vertex safety filter: {block_reason}")

        candidates = result.get("candidates", [])
        if not candidates:
            raise ValueError(f"Empty candidates in Vertex response: {json.dumps(result)[:300]}")

        candidate = candidates[0]
        parts = candidate.get("content", {}).get("parts", [])

        content_text = ""
        reasoning_text = ""
        for part in parts:
            if "text" in part:
                if part.get("thought", False):
                    reasoning_text += part["text"]
                else:
                    content_text += part["text"]

        usage_meta = result.get("usageMetadata", {})
        prompt_tokens = usage_meta.get("promptTokenCount", 0) or 0
        completion_tokens = usage_meta.get("candidatesTokenCount", 0) or 0
        total_tokens = usage_meta.get("totalTokenCount", 0) or (prompt_tokens + completion_tokens)
        thinking_tokens = usage_meta.get("thoughtsTokenCount", 0) or 0
        cached_tokens = usage_meta.get("cachedContentTokenCount", 0) or 0

        usage = TokenUsage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
            thinking_tokens=thinking_tokens,
            cached_tokens=cached_tokens,
        )
        return content_text, reasoning_text, usage

    # ════════════════════════════════════════════════════════════════
    # Main entry point
    # ════════════════════════════════════════════════════════════════

    async def call(
        self,
        messages: List[Dict],
        session: aiohttp.ClientSession,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
    ) -> Tuple[str, str, TokenUsage]:
        """Call the LLM with project-pool failover.

        Uses the OpenAI-compatible or Vertex native format depending on the
        detected API mode. With several URLs configured, a rate-limit error
        puts the URL on cooldown and the next project is tried; a
        project-disabled error disables the URL and moves on. Other errors
        are re-raised unchanged.

        Returns (content, reasoning_content, TokenUsage).
        """
        last_err = None
        start_index = self._active_url_index

        # Try at most once per configured project.
        for _pool_attempt in range(len(self.api_urls)):
            base_url, url_index = await self._acquire_project_url()
            sem = self._semaphores[base_url]

            try:
                if self._use_vertex_native:
                    return await self._call_vertex_native(
                        messages, session, base_url, sem, temperature, max_tokens
                    )
                else:
                    return await self._call_openai_compat(
                        messages, session, base_url, sem, temperature, max_tokens
                    )
            except Exception as e:
                err_str = str(e)
                if self._is_project_disabled_error(err_str):
                    last_err = e
                    await self._disable_project_url(base_url, err_str)
                    continue
                if self._is_rate_limit_error(err_str) and len(self.api_urls) > 1:
                    last_err = e
                    await self._mark_project_cooldown(base_url)
                    await self._switch_to_next_project(url_index)
                    if self._all_projects_cooling_down():
                        print(f" [PROJECT POOL] 所有 {len(self.api_urls)} 个 project 正在 cooldown")
                        raise
                    # Back at the starting project: every project returned 429.
                    if self._active_url_index == start_index:
                        print(f" [PROJECT POOL] 所有 {len(self.api_urls)} 个 project 均已 429,上抛异常")
                        raise
                    continue  # retry on the new project
                # Not a rate-limit error: propagate.
                raise

        raise last_err or RuntimeError("All projects in pool exhausted (429)")

    async def _call_openai_compat(
        self,
        messages: List[Dict],
        session: aiohttp.ClientSession,
        api_url: str,
        sem: asyncio.Semaphore,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
    ) -> Tuple[str, str, TokenUsage]:
        """Call an OpenAI-compatible chat/completions endpoint with retries.

        Quota-exhausted, prohibited-content and project-disabled errors are
        never retried. With several URLs configured a rate-limit error is
        raised immediately so call() can switch project; with a single URL
        it backs off and retries.
        """
        payload = {
            "model": self.model,
            "messages": messages,
            # `is None` rather than `or`, so an explicit 0 / 0.0 is honoured.
            "temperature": self.temperature if temperature is None else temperature,
            "top_p": self.top_p,
            "max_tokens": self.max_tokens if max_tokens is None else max_tokens,
        }

        for attempt in range(self.max_retries):
            try:
                headers = {"Content-Type": "application/json"}
                if self._use_vertex_auth:
                    token = self._get_gcp_token(api_url)
                    headers["Authorization"] = f"Bearer {token}"
                elif self.api_key:
                    headers["Authorization"] = f"Bearer {self.api_key}"

                async with sem:
                    await self._acquire_vertex_request_slot()
                    async with session.post(
                        api_url, json=payload,
                        headers=headers,
                        timeout=aiohttp.ClientTimeout(total=self.timeout),
                    ) as resp:
                        if resp.status != 200:
                            error_body = await resp.text()
                            if _is_quota_exhausted(resp.status, error_body):
                                raise QuotaExhaustedError(
                                    f"API 额度用尽! HTTP {resp.status}: {error_body[:500]}")
                            raise RuntimeError(
                                f"HTTP {resp.status} from {api_url}: {error_body[:500]}")

                        result = await resp.json()
                        choices = result.get("choices", [])
                        if not choices:
                            raise ValueError(f"Empty choices: {json.dumps(result)[:300]}")
                        message = choices[0].get("message", {})
                        content = message.get("content", "") or ""
                        reasoning = message.get("reasoning_content", "") or ""

                        usage_data = result.get("usage", {})
                        usage = TokenUsage.from_api_response(usage_data)

                        return content, reasoning, usage

            except (QuotaExhaustedError, ProhibitedContentError, ProjectDisabledError):
                raise  # never retried
            except Exception as e:
                if attempt == self.max_retries - 1:
                    raise
                err_str = str(e)
                if self._is_project_disabled_error(err_str):
                    raise ProjectDisabledError(err_str)
                if self._is_rate_limit_error(err_str):
                    if len(self.api_urls) > 1:
                        # Several projects: let call() switch project.
                        raise
                    # Single project: exponential backoff.
                    wait = (15 * (2 ** attempt)) + random.uniform(0, 5)
                    print(f" [RATE LIMIT] 429 detected, backing off {wait:.1f}s (attempt {attempt+1}/{self.max_retries})")
                else:
                    wait = 2 ** attempt
                await asyncio.sleep(wait)

        raise RuntimeError("Unreachable")

    async def _call_vertex_native(
        self,
        messages: List[Dict],
        session: aiohttp.ClientSession,
        base_url: str,
        sem: asyncio.Semaphore,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
    ) -> Tuple[str, str, TokenUsage]:
        """Call the Vertex AI native generateContent endpoint with retries.

        ``base_url`` has the form
        https://aiplatform.googleapis.com/v1/projects/{P}/locations/{L}/publishers/google/models
        and the request goes to ``{base_url}/{model}:generateContent``.
        Retry rules match _call_openai_compat.
        """
        # The model name may carry a "google/" prefix left over from the
        # OpenAI-compatible configuration; the native URL must not have it.
        model_name = self.model
        if model_name.startswith("google/"):
            model_name = model_name[len("google/"):]

        full_url = f"{base_url}/{model_name}:generateContent"

        system_instruction, contents = self._convert_messages_to_vertex_native(messages)

        payload: Dict[str, Any] = {
            "contents": contents,
            "generationConfig": {
                # `is None` rather than `or`, so an explicit 0 / 0.0 is honoured.
                "temperature": self.temperature if temperature is None else temperature,
                "topP": self.top_p,
                "maxOutputTokens": self.max_tokens if max_tokens is None else max_tokens,
            },
        }
        if system_instruction:
            payload["systemInstruction"] = system_instruction

        for attempt in range(self.max_retries):
            try:
                headers = {"Content-Type": "application/json"}
                if self._use_vertex_auth:
                    token = self._get_gcp_token(base_url)
                    headers["Authorization"] = f"Bearer {token}"

                async with sem:
                    await self._acquire_vertex_request_slot()
                    async with session.post(
                        full_url, json=payload,
                        headers=headers,
                        timeout=aiohttp.ClientTimeout(total=self.timeout),
                    ) as resp:
                        if resp.status != 200:
                            error_body = await resp.text()
                            if _is_quota_exhausted(resp.status, error_body):
                                raise QuotaExhaustedError(
                                    f"API 额度用尽! HTTP {resp.status}: {error_body[:500]}")
                            raise RuntimeError(
                                f"HTTP {resp.status} from {full_url}: {error_body[:500]}")

                        result = await resp.json()
                        content, reasoning, usage = self._parse_vertex_native_response(result)
                        return content, reasoning, usage

            except (QuotaExhaustedError, ProhibitedContentError, ProjectDisabledError):
                raise  # never retried
            except Exception as e:
                if attempt == self.max_retries - 1:
                    raise
                err_str = str(e)
                if self._is_project_disabled_error(err_str):
                    raise ProjectDisabledError(err_str)
                if self._is_rate_limit_error(err_str):
                    if len(self.api_urls) > 1:
                        # Several projects: let call() switch project.
                        raise
                    # Single project: exponential backoff.
                    wait = (15 * (2 ** attempt)) + random.uniform(0, 5)
                    print(f" [RATE LIMIT] 429 detected, backing off {wait:.1f}s (attempt {attempt+1}/{self.max_retries})")
                else:
                    wait = 2 ** attempt
                    print(f" [WARN] Vertex native call failed (attempt {attempt+1}): {e}")
                await asyncio.sleep(wait)

        raise RuntimeError("Unreachable")


# ════════════════════════════════════════════════════════════════════════
# Message Builders (OpenAI format)
# ════════════════════════════════════════════════════════════════════════

def build_user_message(text: str, image_paths: Optional[List[str]] = None) -> Dict:
    """Build an OpenAI-format user message, optionally with base64 images.

    Without images the content is the plain text string. With images the
    content is a list of ``image_url`` items (each file embedded as a
    ``data:image/jpeg;base64`` URI) followed by one text item.
    """
    if not image_paths:
        return {"role": "user", "content": text}
    content = [
        {
            "type": "image_url",
            "image_url": {"url": f"data:image/jpeg;base64,{encode_image_b64(path)}"},
        }
        for path in image_paths
    ]
    content.append({"type": "text", "text": text})
    return {"role": "user", "content": content}


def build_assistant_message(text: str) -> Dict:
    """Build an OpenAI-format assistant message."""
    return {"role": "assistant", "content": text}


# ════════════════════════════════════════════════════════════════════════
# Response Parser (FIXED: position-based priority)
# ════════════════════════════════════════════════════════════════════════
def parse_llm_response(text: str) -> Tuple[str, Any]:
    """Parse an LLM response and extract the FIRST action by position.

    Returns ('tool_call', parsed_json), ('answer', str), or ('error', str).

    Priority is position-based: whichever of <tool_call>...</tool_call> or
    <answer>...</answer> starts first in the text wins, so a hallucinated
    tag later in the output cannot override an earlier valid one.
    """
    tc_m = re.search(r'<tool_call>\s*(.*?)\s*</tool_call>', text, re.DOTALL)
    answer_m = re.search(r'<answer>\s*(.*?)\s*</answer>', text, re.DOTALL)

    if tc_m is None and answer_m is None:
        return ("error", "No <tool_call> or <answer> found in response")

    # The tool call wins when it is the only tag or starts before the answer.
    if tc_m is not None and (answer_m is None or tc_m.start() < answer_m.start()):
        raw_call = tc_m.group(1)
        try:
            return ("tool_call", json.loads(raw_call))
        except json.JSONDecodeError:
            return ("error", f"Malformed tool_call JSON: {raw_call[:200]}")

    return ("answer", answer_m.group(1).strip())
# ════════════════════════════════════════════════════════════════════════
# Search Utilities
# ════════════════════════════════════════════════════════════════════════

# ── Image Search Cache ──

class ImageSearchCache:
    """MD5-keyed, JSON-file-backed cache for image search results.

    Keys are the hex MD5 digest of the image bytes. The cache is written
    back to disk after every 10 ``set`` calls, or when ``save`` is called.
    """

    def __init__(self, cache_file: str):
        self.cache_file = cache_file
        self.cache = self._load()
        self._dirty_count = 0  # number of set() calls since the last save

    def _load(self) -> Dict[str, str]:
        """Load the cache file; return {} if it is missing or unusable."""
        if not os.path.exists(self.cache_file):
            return {}
        try:
            with open(self.cache_file, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError):
            # Unreadable or corrupt cache: start empty (best effort).
            return {}
        # Valid JSON that is not an object cannot serve as the cache.
        return data if isinstance(data, dict) else {}

    def save(self):
        """Write the cache to disk and reset the dirty counter.

        The data is written to a sibling ``.tmp`` file and then moved over
        the cache file, so an interrupted write cannot leave truncated JSON
        in place of the previous cache.
        """
        tmp_path = f"{self.cache_file}.tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(self.cache, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, self.cache_file)
        self._dirty_count = 0

    def get(self, image_bytes: bytes) -> Optional[str]:
        """Return the cached result for these image bytes, or None."""
        key = hashlib.md5(image_bytes).hexdigest()
        return self.cache.get(key)

    def set(self, image_bytes: bytes, result: str):
        """Store a result; flush to disk after every 10th unsaved update."""
        key = hashlib.md5(image_bytes).hexdigest()
        self.cache[key] = result
        self._dirty_count += 1
        if self._dirty_count >= 10:
            self.save()


# ── Mock Search Implementations ──

def mock_image_search(entity: str, bbox: List[float]) -> str:
    """Return a canned reverse-image-search result for offline testing.

    ``bbox`` is accepted for signature compatibility and not used.
    """
    slug = entity.lower().replace(' ', '_')
    return (
        f"Reverse Image Search Results:\n\n"
        f"Result 1:\n"
        f" Title: {entity} - Character Profile\n"
        f" Snippet: {entity} is a well-known character/entity.\n"
        f" URL: https://example.com/{slug}\n\n"
        f"Result 2:\n"
        f" Title: {entity} | Wiki\n"
        f" Snippet: Detailed information about {entity}.\n"
        f" URL: https://wiki.example.com/{slug}"
    )


def mock_web_search(query: str) -> str:
    """Return a canned web-search result for offline testing."""
    return (
        f'Web Search Results for "{query}":\n\n'
        f"Quick Answer: Information related to the query.\n\n"
        f"Result 1: {query} - Overview\n"
        f" Relevant information about {query}.\n\n"
        f"Result 2: {query} - Details\n"
        f" Additional details and facts."
    )
# ── Real Search Implementations ──

def _serper_result_block(item: Dict[str, Any], index: int,
                         indent: str = "", link_label: str = "URL") -> str:
    """Render one Serper Lens result as a "Result N:" block, skipping empty fields."""
    fields = (
        ("Title", item.get('title', '')),
        ("Snippet", item.get('snippet', '')),
        ("Source", item.get('source', '') or item.get('domain', '')),
        (link_label, item.get('link', '')),
    )
    lines = [f"Result {index}:"]
    lines.extend(f"{indent}{label}: {value}" for label, value in fields if value)
    return "\n".join(lines)


def _format_serper_lens_results(organic_results: List[Dict[str, Any]], max_results: int = 5) -> str:
    """Format raw Serper Lens organic results into readable text.

    Only the first ``max_results`` items are used. Each becomes a
    "Result N:" block with Title / Snippet / Source / URL lines (empty
    fields omitted); blocks are separated by a blank line.
    """
    return "\n\n".join(
        _serper_result_block(item, i, indent=" ", link_label="URL")
        for i, item in enumerate(organic_results[:max_results], 1)
    )


def _build_serper_lens_summary_prompt(organic_results: List[Dict[str, Any]], max_results: int = 5) -> str:
    """Build the English summarizer prompt from Serper Lens result metadata.

    The fixed instruction text is followed by the first ``max_results``
    results rendered as unindented Title / Snippet / Source / Link blocks.
    """
    context_text = "\n\n".join(
        _serper_result_block(item, i, indent="", link_label="Link")
        for i, item in enumerate(organic_results[:max_results], 1)
    )
    return (
        "You are a helpful assistant. Your task is to summarize the main content of the given "
        "Serper Lens reverse image search results in no more than five sentences.\n\n"
        "Your summary should cover the overall key points across the results, not just the parts "
        "most related to the user's question.\n\n"
        "If any part of the results is helpful for identifying the entity or answering the user's "
        "question, include it clearly in the summary. Do not ignore relevant information, but make "
        "sure the general structure and main ideas of the results are preserved.\n\n"
        "Your summary should be concise, factual, and informative. If the results are ambiguous, "
        "conflicting, or insufficient, clearly state that uncertainty.\n\n"
        "Use only the provided result titles, snippets, and source/link metadata. Do not invent facts "
        "and do not assume content from the linked pages.\n\n"
        f"{context_text}"
    )
async def summarize_serper_image_results(
    organic_results: List[Dict[str, Any]],
    session: aiohttp.ClientSession,
    summarizer_address: str = "",
    summarizer_model: str = "",
    max_results: int = 5,
    max_tokens: int = 512,
) -> Optional[str]:
    """Summarize Serper Lens results without fetching the linked pages.

    Posts a prompt built from the result metadata to
    ``http://{address}/v1/chat/completions``. Empty address/model arguments
    fall back to the IMAGE_SEARCH_SUMMARIZER_* settings.

    Returns:
        The summary text, or None when there is nothing to summarize, no
        summarizer is configured, or the request fails in any way (the
        caller then falls back to the raw results).
    """
    endpoint_host = summarizer_address or IMAGE_SEARCH_SUMMARIZER_ADDRESS
    model_name = summarizer_model or IMAGE_SEARCH_SUMMARIZER_MODEL
    if not (organic_results and endpoint_host and model_name):
        return None

    prompt = _build_serper_lens_summary_prompt(organic_results, max_results=max_results)
    request_body = {
        "model": model_name,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": 0.3,
        "chat_template_kwargs": {"enable_thinking": False},
    }

    try:
        async with session.post(
            f"http://{endpoint_host}/v1/chat/completions",
            json=request_body,
            headers={"Content-Type": "application/json"},
            timeout=aiohttp.ClientTimeout(total=120),
        ) as resp:
            if resp.status != 200:
                print(f" [IMAGE_SEARCH] Summarizer returned HTTP {resp.status}, falling back to raw results")
                return None

            reply = await resp.json()
            choices = reply.get("choices", [])
            if not choices or not isinstance(choices, list):
                return None

            text = choices[0].get("message", {}).get("content", "")
            if not text or not text.strip():
                return None

            cleaned = _strip_thinking_tags(text).strip()
            return cleaned or None
    except asyncio.TimeoutError:
        print(" [IMAGE_SEARCH] Summarizer timeout, falling back to raw results")
        return None
    except Exception as e:
        print(f" [IMAGE_SEARCH] Summarizer error: {e}, falling back to raw results")
        return None
async def real_image_search(
    image_b64_or_path: str,
    session: aiohttp.ClientSession,
    api_key: str,
    crop_path: str = None,
) -> str:
    """Reverse image search, routed by IMAGE_SEARCH_MODE.

    "gateway" goes through the internal gateway; any other value calls
    Serper directly with ``api_key``.
    """
    if IMAGE_SEARCH_MODE != "gateway":
        return await _serper_image_search(image_b64_or_path, session, api_key, crop_path)
    return await _gateway_image_search(image_b64_or_path, session, crop_path)


async def _gateway_image_search(
    image_b64_or_path: str,
    session: aiohttp.ClientSession,
    crop_path: str = None,
) -> str:
    """Reverse image search via the internal gateway (Serper Google Lens).

    Makes up to two attempts, sleeping 3s * attempt number between them.
    On success returns the formatted results, prefixed with an LLM summary
    when summarization is enabled and succeeds. Every failure is returned
    as an error string rather than raised.
    """
    if not GATEWAY_TOKEN:
        return "Error: GATEWAY_TOKEN not configured."

    image_url, prep_error = _prepare_image_search_url(
        image_b64_or_path, crop_path, "IMAGE_SEARCH/GATEWAY"
    )
    if prep_error:
        return prep_error

    request_headers = {
        'Content-Type': 'application/json',
        'User-Agent': 'ifbook-http-client',
    }
    request_body = {
        "sec_info": {
            "username": GATEWAY_USERNAME,
            "userid": GATEWAY_USERID,
            "token": GATEWAY_TOKEN,
        },
        "model_type": "openai",
        "model_name": "serper",
        # The Serper parameters travel as a JSON string inside the envelope.
        "params": json.dumps({"url": image_url, "type": "lens"}),
    }

    total_attempts = 2
    failure = None
    for attempt_no in range(1, total_attempts + 1):
        try:
            async with session.post(
                GATEWAY_URL,
                headers=request_headers,
                json=request_body,
                timeout=aiohttp.ClientTimeout(total=60),
            ) as resp:
                if resp.status == 200:
                    envelope = await resp.json()
                    lens_data = json.loads(envelope.get("model_output", "{}"))
                    hits = lens_data.get('organic', [])
                    if not hits:
                        return "No results found from reverse image search."

                    listing = _format_serper_lens_results(
                        hits,
                        max_results=IMAGE_SEARCH_SUMMARIZER_MAX_RESULTS,
                    )
                    if not IMAGE_SEARCH_SUMMARIZE_SERPER:
                        return listing

                    digest = await summarize_serper_image_results(
                        hits,
                        session,
                        summarizer_address=IMAGE_SEARCH_SUMMARIZER_ADDRESS,
                        summarizer_model=IMAGE_SEARCH_SUMMARIZER_MODEL,
                        max_results=IMAGE_SEARCH_SUMMARIZER_MAX_RESULTS,
                        max_tokens=IMAGE_SEARCH_SUMMARIZER_MAX_TOKENS,
                    )
                    if digest:
                        return f"Summary: {digest}\n\nTop Lens Results:\n\n{listing}"
                    return listing

                body_text = await resp.text()
                failure = (f"Gateway error: HTTP {resp.status}: "
                           f"{body_text[:300]}")
                cause = f"HTTP {resp.status}"
        except asyncio.TimeoutError:
            failure = "Image search error: request timed out after 60s"
            cause = "Timeout"
        except Exception as e:
            failure = f"Image search error: {e}"
            cause = f"Error: {e}"

        # Reached only after a failed attempt.
        if attempt_no == total_attempts:
            return failure
        print(f" [IMAGE_SEARCH/GATEWAY] {cause}, "
              f"retrying ({attempt_no}/{total_attempts})...")
        await asyncio.sleep(3 * attempt_no)

    return failure or "Image search error: unknown failure"
async def _serper_image_search(
    image_b64_or_path: str,
    session: aiohttp.ClientSession,
    api_key: str,
    crop_path: str = None,
) -> str:
    """Reverse image search by calling Serper Google Lens directly.

    Makes up to two attempts, sleeping 3s * attempt number between them.
    On success returns the formatted results, prefixed with an LLM summary
    when summarization is enabled and succeeds. Every failure is returned
    as an error string rather than raised.
    """
    if not api_key:
        return "Error: SERPER_API_KEY not configured."

    request_headers = {'X-API-KEY': api_key, 'Content-Type': 'application/json'}

    image_url, prep_error = _prepare_image_search_url(
        image_b64_or_path, crop_path, "IMAGE_SEARCH"
    )
    if prep_error:
        return prep_error

    total_attempts = 2
    failure = None
    for attempt_no in range(1, total_attempts + 1):
        try:
            async with session.post(
                "https://google.serper.dev/lens",
                headers=request_headers,
                json={"url": image_url},
                timeout=aiohttp.ClientTimeout(total=60),
            ) as resp:
                if resp.status == 200:
                    lens_data = await resp.json()
                    hits = lens_data.get('organic', [])
                    if not hits:
                        return "No results found from reverse image search."

                    listing = _format_serper_lens_results(
                        hits,
                        max_results=IMAGE_SEARCH_SUMMARIZER_MAX_RESULTS,
                    )
                    if not IMAGE_SEARCH_SUMMARIZE_SERPER:
                        return listing

                    digest = await summarize_serper_image_results(
                        hits,
                        session,
                        summarizer_address=IMAGE_SEARCH_SUMMARIZER_ADDRESS,
                        summarizer_model=IMAGE_SEARCH_SUMMARIZER_MODEL,
                        max_results=IMAGE_SEARCH_SUMMARIZER_MAX_RESULTS,
                        max_tokens=IMAGE_SEARCH_SUMMARIZER_MAX_TOKENS,
                    )
                    if digest:
                        return f"Summary: {digest}\n\nTop Lens Results:\n\n{listing}"
                    return listing

                body_text = await resp.text()
                failure = (f"Image search error: HTTP {resp.status}: "
                           f"{body_text[:300]}")
                cause = f"HTTP {resp.status}"
        except asyncio.TimeoutError:
            failure = "Image search error: request timed out after 60s"
            cause = "Timeout"
        except Exception as e:
            failure = f"Image search error: {e}"
            cause = f"Error: {e}"

        # Reached only after a failed attempt.
        if attempt_no == total_attempts:
            return failure
        print(f" [IMAGE_SEARCH] {cause}, "
              f"retrying ({attempt_no}/{total_attempts})...")
        await asyncio.sleep(3 * attempt_no)

    return failure or "Image search error: unknown failure"
"Image search error: unknown failure"


async def real_web_search(
    query: str,
    session: aiohttp.ClientSession,
    address: str = WEB_SEARCH_ADDRESS,
) -> str:
    """Web search via internal search server (SenseNova pattern).

    POSTs the whitespace-normalised query to ``http://{address}/search``.
    Entries in ``WEB_SEARCH_CONFIG`` override the built-in payload defaults.

    Args:
        query: Free-text search query; newlines are replaced by spaces.
        session: Shared aiohttp client session.
        address: ``host:port`` of the search server.

    Returns:
        The raw response body text, or an ``"Error: ..."`` string on timeout
        or any other failure (this function never raises).
    """
    payload = {
        "query": query.strip().replace("\n", " "),
        "top_k": 3,
        "retrieval_mode": "google_serper",
        **WEB_SEARCH_CONFIG,
    }
    try:
        async with session.post(
            f"http://{address}/search",
            json=payload,
            timeout=aiohttp.ClientTimeout(total=100),
        ) as resp:
            resp.raise_for_status()
            return await resp.text()
    except asyncio.TimeoutError:
        return f"Error: Web search timeout for query: {query[:100]}"
    except Exception as e:
        return f"Error: Web search failed: {e}"


async def serper_web_search(
    query: str,
    session: aiohttp.ClientSession,
    api_key: str,
) -> str:
    """Fallback web search via Serper Google Search.

    Args:
        query: Search query sent as the ``q`` field.
        session: Shared aiohttp client session.
        api_key: Serper API key; an empty value short-circuits with an error
            string.

    Returns:
        A newline-joined digest (answer box, knowledge graph, up to five
        organic results), ``"No relevant results found."`` when the response
        has none of those, or an error string on failure.
    """
    if not api_key:
        return "Error: SERPER_API_KEY not configured."

    headers = {'X-API-KEY': api_key, 'Content-Type': 'application/json'}
    try:
        async with session.post(
            "https://google.serper.dev/search",
            headers=headers,
            json={"q": query},
            timeout=aiohttp.ClientTimeout(total=20),
        ) as resp:
            resp.raise_for_status()
            data = await resp.json()

        parts = []

        ab = data.get('answerBox', {})
        if ab:
            answer = ab.get('answer') or ab.get('snippet', '')
            if answer:
                parts.append(f"Quick Answer: {answer}")

        kg = data.get('knowledgeGraph', {})
        if kg:
            parts.append(
                f"Knowledge Graph: {kg.get('title', '')} - "
                f"{kg.get('description', '')}")

        for i, item in enumerate(data.get('organic', [])[:5], 1):
            parts.append(
                f"Result {i}: {item.get('title', '')}\n"
                f" {item.get('snippet', '')}")

        return "\n".join(parts) if parts else "No relevant results found."
    except Exception as e:
        return f"Web search error: {e}"


# ── MARS retrieval concurrency semaphore (lazy init) ──
_mars_retrieval_semaphore: Optional[asyncio.Semaphore] = None


def _get_mars_retrieval_semaphore() -> Optional[asyncio.Semaphore]:
    """Return the shared retrieval semaphore, creating it on first use.

    Returns None when ``MARS_RETRIEVAL_CONCURRENCY`` is not positive, meaning
    retrieval requests are not throttled.
    """
    global _mars_retrieval_semaphore
    if MARS_RETRIEVAL_CONCURRENCY > 0 and _mars_retrieval_semaphore is None:
        _mars_retrieval_semaphore = asyncio.Semaphore(MARS_RETRIEVAL_CONCURRENCY)
    return _mars_retrieval_semaphore if MARS_RETRIEVAL_CONCURRENCY > 0 else None


def _parse_search_r1_passages(data: Any) -> List[Dict[str, str]]:
    """Extract ``{"title", "text"}`` passages from a Search-R1 response body.

    Confirmed Search-R1 response format:
    {"result": [[{"document": {"id": "xx", "contents": "\"title\"\ntext"}, "score": 0.84}, ...]]}
    Entries that do not match this shape are skipped.
    """
    passages = []
    for query_results in data.get("result", []):
        if not isinstance(query_results, list):
            continue
        for item in query_results:
            if not isinstance(item, dict):
                continue
            doc = item.get("document", {})
            if isinstance(doc, dict):
                contents = doc.get("contents", "")
            elif isinstance(doc, str):
                contents = doc
            else:
                continue
            if not contents:
                continue
            # First line is the (quoted) title, the remainder is the passage.
            lines = contents.split("\n", 1)
            title = lines[0].strip('"') if lines else ""
            text = lines[1] if len(lines) > 1 else contents
            # Truncate over-long passages so the summarizer's context window
            # is not exceeded.
            passages.append({"title": title, "text": text[:2000]})
    return passages


async def mars_web_search(
    query: str,
    session: aiohttp.ClientSession,
    retrieval_address: str = "",
    summarizer_address: str = "",
    retrieval_topk: int = 3,
    summarizer_model: str = "",
) -> str:
    """SenseNova-MARS style web search: retrieve from local Wikipedia + summarize via LLM.

    Two-step pipeline:
      1. POST to Search-R1 retrieval server → get top-k document passages
      2. POST to summarizer LLM (OpenAI-compatible /v1/chat/completions)
         → get concise summary

    This mirrors the SenseNova-MARS web_search_server architecture but without
    the intermediate FastAPI layer.

    Args:
        query: Search query; newlines are replaced by spaces.
        session: Shared aiohttp client session.
        retrieval_address: ``host:port`` of the retrieval server; falls back
            to ``MARS_RETRIEVAL_ADDRESS`` when empty.
        summarizer_address: ``host:port`` of the summarizer; falls back to
            ``MARS_SUMMARIZER_ADDRESS`` when empty.
        retrieval_topk: Number of passages to request; a falsy value falls
            back to ``MARS_RETRIEVAL_TOPK``.
        summarizer_model: Model name; falls back to ``MARS_SUMMARIZER_MODEL``
            when empty.

    Returns:
        The summarized result, the raw passages when the summarizer fails or
        produces no usable text, or an ``"Error: ..."`` / "No relevant
        passages" string.

    Raises:
        RetrieverDownError: When the retrieval server cannot be reached.
    """
    retrieval_addr = retrieval_address or MARS_RETRIEVAL_ADDRESS
    summarizer_addr = summarizer_address or MARS_SUMMARIZER_ADDRESS
    topk = retrieval_topk or MARS_RETRIEVAL_TOPK
    sum_model = summarizer_model or MARS_SUMMARIZER_MODEL

    if not retrieval_addr:
        return "Error: MARS_RETRIEVAL_ADDRESS not configured."
    if not summarizer_addr:
        return "Error: MARS_SUMMARIZER_ADDRESS not configured."

    clean_query = query.strip().replace("\n", " ")

    # ════════════════════════════════════════════════════════════════
    # Step 1: Retrieve passages from Search-R1 retrieval server
    # ════════════════════════════════════════════════════════════════
    retrieval_payload = {
        "queries": [clean_query],  # Search-R1 requires queries to be a List[str]
        "return_scores": True,  # must be True, otherwise the server crashes while unpacking and returns 500
        "topk": topk,
    }

    retrieved_passages = []
    sem = _get_mars_retrieval_semaphore()
    try:
        if sem:
            await sem.acquire()
        try:
            async with session.post(
                f"http://{retrieval_addr}/retrieve",
                json=retrieval_payload,
                timeout=aiohttp.ClientTimeout(total=MARS_RETRIEVAL_TIMEOUT),
                proxy="",  # bypass the system proxy; connect directly to the intranet
            ) as resp:
                if resp.status != 200:
                    error_body = await resp.text()
                    return (f"Error: Retrieval server returned HTTP {resp.status}: "
                            f"{error_body[:300]}")
                data = await resp.json()
                retrieved_passages = _parse_search_r1_passages(data)
        finally:
            if sem:
                sem.release()
    except asyncio.TimeoutError:
        # Must precede the OSError clause: on newer Python versions
        # asyncio.TimeoutError is a subclass of OSError.
        return f"Error: Retrieval server timeout for query: {clean_query[:100]}"
    except (aiohttp.ClientConnectorError, aiohttp.ServerDisconnectedError,
            ConnectionRefusedError, ConnectionResetError, OSError) as e:
        raise RetrieverDownError(
            f"Retriever 服务 {retrieval_addr} 连接失败: {e}"
        ) from e
    except Exception as e:
        # Some connection failures surface as generic exceptions; detect them
        # by message so the caller still sees RetrieverDownError.
        err_str = str(e).lower()
        if any(kw in err_str for kw in (
            "cannot connect", "connection refused",
            "connect call failed", "server disconnected",
        )):
            raise RetrieverDownError(
                f"Retriever 服务 {retrieval_addr} 连接失败: {e}"
            ) from e
        return f"Error: Retrieval failed: {e}"

    if not retrieved_passages:
        return f"No relevant passages found for query: {clean_query}"

    # ════════════════════════════════════════════════════════════════
    # Step 2: Summarize via LLM (OpenAI-compatible API)
    # ════════════════════════════════════════════════════════════════
    context_parts = []
    for i, p in enumerate(retrieved_passages, 1):
        title_str = f" (Title: {p['title']})" if p['title'] else ""
        context_parts.append(f"Passage {i}{title_str}:\n{p['text']}")
    context_text = "\n\n".join(context_parts)

    summarizer_prompt = (
        f"Based on the following retrieved passages, provide a concise and informative "
        f"summary that answers the query: \"{clean_query}\"\n\n"
        f"{context_text}\n\n"
        f"Please provide a concise summary focusing on the most relevant information. "
        f"If the passages do not contain relevant information, say so."
    )

    summarizer_payload = {
        "model": sum_model,
        "messages": [{"role": "user", "content": summarizer_prompt}],
        "max_tokens": 1024,
        "temperature": 0.3,
        "chat_template_kwargs": {"enable_thinking": False},
    }

    try:
        async with session.post(
            f"http://{summarizer_addr}/v1/chat/completions",
            json=summarizer_payload,
            headers={"Content-Type": "application/json"},
            timeout=aiohttp.ClientTimeout(total=120),
            proxy="",  # bypass the system proxy; connect directly to the intranet
        ) as resp:
            if resp.status != 200:
                print(f" [MARS_SEARCH] Summarizer returned HTTP {resp.status}, "
                      f"falling back to raw passages")
                return _format_raw_passages(clean_query, retrieved_passages)

            data = await resp.json()
            choices = data.get("choices", [])
            if choices and isinstance(choices, list):
                msg = choices[0].get("message", {})
                summary = msg.get("content", "")
                if isinstance(summary, str) and summary.strip():
                    # Check emptiness AFTER stripping thinking content: a
                    # reply that is only a thinking block must fall back to
                    # the raw passages rather than return an empty summary.
                    summary = _strip_thinking_tags(summary).strip()
                    if summary:
                        return (
                            f"Web Search Results for \"{clean_query}\" "
                            f"(MARS retrieve+summarize, {len(retrieved_passages)} passages):\n\n"
                            f"{summary}"
                        )
            return _format_raw_passages(clean_query, retrieved_passages)
    except asyncio.TimeoutError:
        print(f" [MARS_SEARCH] Summarizer timeout, falling back to raw passages")
        return _format_raw_passages(clean_query, retrieved_passages)
    except Exception as e:
        print(f" [MARS_SEARCH] Summarizer error: {e}, falling back to raw passages")
        return _format_raw_passages(clean_query, retrieved_passages)


def _strip_thinking_tags(text: str) -> str:
    """Remove thinking content from summarizer output.

    Args:
        text: Raw summarizer message content.

    Returns:
        The text with thinking blocks removed and surrounding whitespace
        stripped.
    """
    # 1. Remove XML-style <think>...</think> blocks. The pattern must name
    #    the tags explicitly: a bare '.*?' only ever matches the empty string
    #    and removes nothing.
    text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()
    # 2. Remove a leading "Thinking Process:" block (up to the first blank line).
    text = re.sub(r'^Thinking Process:.*?(?=\n\n)', '', text, flags=re.DOTALL).strip()
    # 3. Edge case: if the text still starts with "Thinking", drop the first
    #    paragraph once more.
    if text.startswith("Thinking"):
        parts = text.split("\n\n", 1)
        if len(parts) > 1:
            text = parts[1].strip()
    return text


def _format_raw_passages(query: str, passages: list) -> str:
    """Format raw retrieved passages as fallback when summarizer fails.

    Each passage dict must provide ``title`` and ``text``; the text is
    previewed at 500 characters with a trailing ``...`` when truncated.
    """
    parts = [f"Web Search Results for \"{query}\" (raw retrieval, {len(passages)} passages):"]
    for i, p in enumerate(passages, 1):
        title_str = f" — {p['title']}" if p['title'] else ""
        text_preview = p['text'][:500] + ("..." if len(p['text']) > 500 else "")
        parts.append(f"\nResult {i}{title_str}:\n {text_preview}")
    return "\n".join(parts)


# ════════════════════════════════════════════════════════════════════════
# I/O Helpers
# ════════════════════════════════════════════════════════════════════════

def make_uid(entry: Dict) -> str:
    """Generate a unique ID from id + video_filename.

    A single 'id' can map to multiple videos, so we need a composite key.
    When ``video_filename`` is absent or empty the bare ``id`` value is
    returned unchanged (``"unknown"`` if the entry has no ``id``).
    """
    eid = entry.get("id", "unknown")
    vf = entry.get("video_filename", "")
    if vf:
        stem = os.path.splitext(vf)[0]
        return f"{eid}__{stem}"
    return eid


def load_completed_ids(output_file: str) -> Set[str]:
    """Load already-completed entry UIDs from output JSONL for resume.

    A record counts as completed when it is a JSON object that has a ``uid``
    key and no ``error`` key. Blank lines, lines that are not valid JSON, and
    JSON values that are not objects are ignored.

    Args:
        output_file: Path to the JSONL output file; may not exist yet.

    Returns:
        The set of completed UIDs (empty when the file does not exist).
    """
    completed = set()
    if not os.path.exists(output_file):
        return completed

    with open(output_file, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                continue
            # A line such as `5` or `null` is valid JSON but not a record;
            # `"uid" in record` would raise TypeError on it.
            if isinstance(record, dict) and "uid" in record and "error" not in record:
                completed.add(record["uid"])
    return completed


def get_question(entry: Dict) -> str:
    """Get the appropriate question field based on verdict.

    For ``verdict == "rewrite"`` the rewritten question is preferred, falling
    back to the original question when it is missing or empty.
    """
    if entry.get("verdict") == "rewrite":
        return entry.get("rewritten_question") or entry.get("original_question", "")
    return entry.get("original_question", "")


# ════════════════════════════════════════════════════════════════════════
# OSS Image Upload for Search (fixes Serper Lens 400 Bad Request)
#
# ════════════════════════════════════════════════════════════════════════
# Same approach as the second implementation: first upload the crop to
# publicly reachable object storage → obtain an https://... URL → pass it to
# Serper Lens → fetch the search results.
#
# Serper Lens handles base64 data URIs unreliably (frequent HTTP 400),
# but works fine with HTTPS URLs.

_oss_bucket = None  # lazily initialised so a missing oss2 install does not fail at import
_oss_init_attempted = False  # initialisation is attempted only once
_oss_last_error = ""


def _get_oss_bucket():
    """Lazily initialise the OSS Bucket object; initialisation runs only once.

    Returns:
        The cached bucket object, or None when oss2 is not installed, the
        access keys are not configured, or construction failed. The failure
        reason is stored in the module-level ``_oss_last_error``.
    """
    global _oss_bucket, _oss_init_attempted, _oss_last_error
    if _oss_init_attempted:
        return _oss_bucket
    _oss_init_attempted = True

    try:
        import oss2
        if not OSS_ACCESS_KEY_ID or not OSS_ACCESS_KEY_SECRET:
            _oss_last_error = "OSS_ACCESS_KEY_ID / OSS_ACCESS_KEY_SECRET 未配置"
            print(f"[ERROR] {_oss_last_error}")
            return None
        auth = oss2.Auth(OSS_ACCESS_KEY_ID, OSS_ACCESS_KEY_SECRET)
        _oss_bucket = oss2.Bucket(auth, f"https://{OSS_ENDPOINT}", OSS_BUCKET_NAME)
        _oss_last_error = ""
        print(f"[OSS] Bucket initialized: {OSS_BUCKET_NAME} @ {OSS_ENDPOINT}")
        return _oss_bucket
    except ImportError:
        _oss_last_error = "oss2 未安装,请执行: pip install oss2 --break-system-packages"
        print(f"[ERROR] {_oss_last_error}")
        return None
    except Exception as e:
        _oss_last_error = f"OSS Bucket 初始化失败: {e}"
        print(f"[ERROR] {_oss_last_error}")
        return None


def optimize_crop_for_search(crop_path: str, output_path: Optional[str] = None,
                             max_size: int = SEARCH_CROP_MAX_SIZE,
                             quality: int = SEARCH_CROP_JPEG_QUALITY) -> str:
    """Optimise a crop for search: shrink it and lower quality to cut upload size.

    Args:
        crop_path: Path of the original crop image.
        output_path: Where to save the optimised image (defaults to the same
            directory with an ``_opt`` suffix before the extension).
        max_size: Maximum edge length in pixels.
        quality: JPEG quality.

    Returns:
        Path of the optimised image, or ``crop_path`` unchanged if
        optimisation failed.
    """
    if output_path is None:
        base, ext = os.path.splitext(crop_path)
        output_path = f"{base}_opt{ext}"

    try:
        with Image.open(crop_path) as img:
            w, h = img.size
            longest = max(w, h)
            if longest > max_size:
                ratio = max_size / longest
                # Clamp to 1px so extreme aspect ratios never yield a
                # zero-sized dimension.
                new_size = (max(1, int(w * ratio)), max(1, int(h * ratio)))
                img = img.resize(new_size, Image.LANCZOS)
            # JPEG cannot store alpha or palette data; without this
            # conversion RGBA/P/LA crops (typical for PNG) fail to save and
            # the optimisation is silently skipped.
            if img.mode not in ("RGB", "L"):
                img = img.convert("RGB")
            img.save(output_path, 'JPEG', quality=quality)
        return output_path
    except Exception as e:
        print(f" [WARN] optimize_crop_for_search failed: {e}, using original")
        return crop_path


def _upload_to_oss(local_path: str, oss_object_name: str = None) -> Optional[str]:
    """Upload an image to Alibaba Cloud OSS and return its public URL.

    Dedup strategy: the MD5 of the file content is used as the object name,
    so identical content is not uploaded twice.

    Args:
        local_path: Local image path.
        oss_object_name: OSS object name; defaults to an MD5-based name.

    Returns:
        The public URL, or None on failure (the reason is stored in the
        module-level ``_oss_last_error``).
    """
    global _oss_last_error
    bucket = _get_oss_bucket()
    if bucket is None:
        return None

    try:
        # Optimise image size and quality first.
        opt_path = optimize_crop_for_search(local_path)

        # Use the MD5 of the content as the file name for deduplication.
        with open(opt_path, "rb") as f:
            file_bytes = f.read()
        file_hash = hashlib.md5(file_bytes).hexdigest()

        if oss_object_name is None:
            oss_object_name = f"{OSS_UPLOAD_PREFIX}/{file_hash}.jpg"

        # Check whether the object already exists (a successful head_object
        # means it was uploaded before).
        try:
            bucket.head_object(oss_object_name)
            # Already present: return the URL directly.
            public_url = f"https://{OSS_BUCKET_NAME}.{OSS_ENDPOINT}/{oss_object_name}"
            print(f" [OSS] Cache hit: {public_url}")
            return public_url
        except Exception:
            # Missing object or any other error: deliberately fall through
            # and attempt the upload.
            pass

        # Upload.
        bucket.put_object(oss_object_name, file_bytes, headers={
            'Content-Type': 'image/jpeg',
        })
        public_url = f"https://{OSS_BUCKET_NAME}.{OSS_ENDPOINT}/{oss_object_name}"
        print(f" [OSS] Uploaded: {public_url}")
        return public_url
    except Exception as e:
        _oss_last_error = f"OSS upload failed: {e}"
        print(f" [WARN] {_oss_last_error}")
        return None


def _prepare_image_search_url(
    image_b64_or_path: str,
    crop_path: Optional[str],
    log_prefix: str,
) -> Tuple[Optional[str], Optional[str]]:
    """Prepare the URL for an image search; by default an OSS upload is required.

    Args:
        image_b64_or_path: Image payload embedded into a base64 data URI when
            the fallback is enabled.
        crop_path: Local crop file to upload to OSS; may be None or missing.
        log_prefix: Tag used in log lines.

    Returns:
        ``(image_url, None)`` on success, or ``(None, error_message)`` when no
        OSS URL could be produced and the base64 fallback is disabled.
    """
    image_url = None
    if crop_path and os.path.exists(crop_path):
        image_url = _upload_to_oss(crop_path)
        failure_reason = _oss_last_error or "OSS upload returned None"
    else:
        # No upload was attempted, so _oss_last_error may be a stale message
        # from an unrelated earlier call; report the real cause instead.
        failure_reason = f"crop image not found: {crop_path!r}"

    if image_url:
        print(f" [{log_prefix}] Using OSS URL: {image_url}")
        return image_url, None

    if IMAGE_SEARCH_ALLOW_BASE64_FALLBACK:
        image_url = f"data:image/jpeg;base64,{image_b64_or_path}"
        print(f" [{log_prefix}] WARNING: OSS upload failed, falling back to "
              f"base64 data URI (len={len(image_b64_or_path)}, may get 400)")
        return image_url, None

    return None, (
        "Image search unavailable: OSS upload failed and base64 fallback is disabled. "
        f"Root cause: {failure_reason}"
    )