| |
| """ |
| 请求处理相关的工具函数,包括 Token 估算、上下文截断、速率限制检查和计数更新、上下文保存等。 |
| """ |
| import json |
| import logging |
| import time |
| from typing import List, Dict, Any, Optional, Tuple |
| from collections import Counter, defaultdict |
|
|
| |
| from app.core.database import utils as db_utils |
| from app.core.context import store as context_store |
| |
| from app import config as app_config |
| |
| from app.core.tracking import ( |
| ip_daily_input_token_counts, ip_input_token_counts_lock, |
| usage_data, usage_lock, RPM_WINDOW_SECONDS, TPM_WINDOW_SECONDS |
| ) |
|
|
| |
| logger = logging.getLogger('my_logger') |
|
|
| |
|
|
| def estimate_token_count(contents: List[Dict[str, Any]]) -> int: |
| """ |
| 估算 Gemini contents 列表的 Token 数量。 |
| 使用简单的字符数估算方法 (1 个 token 大约等于 4 个字符)。 |
| 注意:这是一个非常粗略的估算,实际 Token 数可能因模型和内容而异。 |
| |
| Args: |
| contents (List[Dict[str, Any]]): Gemini 格式的内容列表。 |
| |
| Returns: |
| int: 估算的 Token 数量。 |
| """ |
| if not contents: |
| return 0 |
| try: |
| |
| |
| char_count = len(json.dumps(contents, ensure_ascii=False)) |
| |
| return char_count // 4 |
| except TypeError as e: |
| |
| logger.error(f"序列化 contents 进行 Token 估算时出错: {e}", exc_info=True) |
| return 0 |
|
|
| async def truncate_context( |
| contents: List[Dict[str, Any]], |
| model_name: str, |
| dynamic_max_tokens_limit: Optional[int] = None |
| ) -> Tuple[List[Dict[str, Any]], bool]: |
| """ |
| 根据模型限制和可选的动态限制截断对话历史 (contents)。 |
| 采用从开头成对移除消息(通常是 user/model 对)的策略, |
| 直到估算的 Token 数量满足限制要求。 |
| |
| Args: |
| contents (List[Dict[str, Any]]): 完整的对话历史列表 (Gemini 格式)。 |
| model_name (str): 当前请求使用的模型名称,用于查找其 Token 限制。 |
| dynamic_max_tokens_limit (Optional[int]): 可选的动态 Token 限制, |
| 通常基于 API Key 的实时可用容量。如果提供,将使用此限制与模型静态限制中的较小值。 |
| |
| Returns: |
| Tuple[List[Dict[str, Any]], bool]: |
| - 第一个元素是截断后的对话历史列表。 |
| - 第二个元素是一个布尔值,指示截断后是否仍然超限 |
| (True 表示超限,False 表示未超限或无需截断)。 |
| 如果返回 True,调用者通常不应保存此上下文,因为它可能仍然过长。 |
| """ |
| if not contents: |
| return [], False |
|
|
| |
| |
| |
| default_max_tokens = getattr(app_config, 'DEFAULT_MAX_CONTEXT_TOKENS', 30000) |
| safety_margin = getattr(app_config, 'CONTEXT_TOKEN_SAFETY_MARGIN', 200) |
|
|
| |
| model_limits = getattr(app_config, 'MODEL_LIMITS', {}) |
| limit_info = model_limits.get(model_name) |
| static_max_tokens = default_max_tokens |
| if limit_info and isinstance(limit_info, dict) and limit_info.get("input_token_limit"): |
| try: |
| limit_value = limit_info["input_token_limit"] |
| if limit_value is not None: |
| static_max_tokens = int(limit_value) |
| else: |
| |
| logger.warning(f"模型 '{model_name}' 的 input_token_limit 值为 null,使用默认值 {default_max_tokens}") |
| except (ValueError, TypeError): |
| |
| logger.warning(f"模型 '{model_name}' 的 input_token_limit 值无效 ('{limit_info.get('input_token_limit')}'),使用默认值 {default_max_tokens}") |
| else: |
| |
| logger.warning(f"模型 '{model_name}' 或其 input_token_limit 未在 model_limits.json 中定义,使用默认值 {default_max_tokens}") |
|
|
| |
| actual_max_tokens = static_max_tokens |
| if dynamic_max_tokens_limit is not None and dynamic_max_tokens_limit >= 0: |
| |
| actual_max_tokens = min(static_max_tokens, dynamic_max_tokens_limit) |
| |
| logger.debug(f"使用动态限制 {dynamic_max_tokens_limit} 和静态限制 {static_max_tokens},最终最大 Token 限制为 {actual_max_tokens}") |
|
|
| |
| |
| truncation_threshold = max(0, actual_max_tokens - safety_margin) |
|
|
| |
| |
| estimated_tokens = estimate_token_count(contents) |
|
|
| |
| if estimated_tokens > truncation_threshold: |
| logger.info(f"上下文估算 Token ({estimated_tokens}) 超出阈值 ({truncation_threshold} for model {model_name}, actual max tokens {actual_max_tokens}),开始截断...") |
| |
| truncated_contents = list(contents) |
| |
| while estimate_token_count(truncated_contents) > truncation_threshold and len(truncated_contents) >= 2: |
| |
| removed_first = truncated_contents.pop(0) |
| removed_second = truncated_contents.pop(0) |
| |
| logger.debug(f"移除旧消息对: roles={removed_first.get('role')}, {removed_second.get('role')}") |
|
|
| |
| final_estimated_tokens = estimate_token_count(truncated_contents) |
|
|
| |
| if final_estimated_tokens > truncation_threshold: |
| |
| logger.error(f"截断后上下文估算 Token ({final_estimated_tokens}) 仍然超过阈值 ({truncation_threshold})。本次交互的上下文不应被保存。") |
| |
| return truncated_contents, True |
| else: |
| |
| logger.info(f"上下文截断完成,剩余消息数: {len(truncated_contents)}, 最终估算 Token: {final_estimated_tokens}") |
| |
| return truncated_contents, False |
| else: |
| |
| return contents, False |
|
|
| |
|
|
| def check_rate_limits_and_update_counts( |
| api_key: str, |
| model_name: str, |
| limits: Optional[Dict[str, Any]] |
| ) -> bool: |
| """ |
| 检查给定 API Key 和模型的速率限制 (RPD, TPD_Input, RPM, TPM_Input)。 |
| 此函数在选择 Key *之前* 调用,用于预检查 Key 是否已达到已知限制。 |
| 如果未达到限制,则更新 RPM 和 RPD 计数(假设本次请求会发生),并返回 True。 |
| 如果达到任何限制,则记录警告并返回 False,表示不应选择此 Key。 |
| |
| Args: |
| api_key (str): 当前尝试使用的 API Key。 |
| model_name (str): 请求的模型名称。 |
| limits (Optional[Dict[str, Any]]): 从配置中获取的该模型的限制字典。 |
| |
| Returns: |
| bool: 如果根据已知计数判断可以继续进行 API 调用则返回 True,否则返回 False。 |
| """ |
| if not limits: |
| logger.warning(f"模型 '{model_name}' 不在 model_limits.json 中,跳过本地速率限制检查。") |
| return True |
|
|
| now = time.time() |
| perform_api_call = True |
|
|
| with usage_lock: |
| |
| |
| key_usage = usage_data.setdefault(api_key, defaultdict(lambda: defaultdict(int)))[model_name] |
|
|
| |
| rpm_limit = limits.get("rpm") |
| if rpm_limit is not None: |
| current_rpm_count = key_usage.get("rpm_count", 0) |
| rpm_timestamp = key_usage.get("rpm_timestamp", 0) |
|
|
| if now - rpm_timestamp >= RPM_WINDOW_SECONDS: |
| |
| key_usage["rpm_count"] = 1 |
| key_usage["rpm_timestamp"] = now |
| logger.debug(f"RPM 窗口过期,重置计数并增加 (Key: {api_key[:8]}, Model: {model_name}): 新 RPM=1") |
| else: |
| |
| if current_rpm_count + 1 > rpm_limit: |
| logger.warning(f"速率限制预检查失败 (Key: {api_key[:8]}, Model: {model_name}): RPM 达到限制 ({current_rpm_count}/{rpm_limit})。跳过此 Key。") |
| perform_api_call = False |
| else: |
| |
| key_usage["rpm_count"] = current_rpm_count + 1 |
| |
| logger.debug(f"RPM 计数增加 (Key: {api_key[:8]}, Model: {model_name}): 新 RPM={key_usage['rpm_count']}") |
|
|
| |
| |
| if perform_api_call: |
| rpd_limit = limits.get("rpd") |
| if rpd_limit is not None: |
| current_rpd_count = key_usage.get("rpd_count", 0) |
| |
| if current_rpd_count + 1 > rpd_limit: |
| logger.warning(f"速率限制预检查失败 (Key: {api_key[:8]}, Model: {model_name}): RPD 达到限制 ({current_rpd_count}/{rpd_limit})。跳过此 Key。") |
| perform_api_call = False |
| else: |
| |
| key_usage["rpd_count"] = current_rpd_count + 1 |
| logger.debug(f"RPD 计数增加 (Key: {api_key[:8]}, Model: {model_name}): 新 RPD={key_usage['rpd_count']}") |
|
|
| |
| |
| |
| if perform_api_call: |
| tpd_input_limit = limits.get("tpd_input") |
| if tpd_input_limit is not None and key_usage.get("tpd_input_count", 0) >= tpd_input_limit: |
| logger.warning(f"速率限制预检查失败 (Key: {api_key[:8]}, Model: {model_name}): TPD_Input 达到限制 ({key_usage.get('tpd_input_count', 0)}/{tpd_input_limit})。跳过此 Key。") |
| perform_api_call = False |
|
|
| |
| |
| if perform_api_call: |
| tpm_input_limit = limits.get("tpm_input") |
| if tpm_input_limit is not None: |
| |
| if now - key_usage.get("tpm_input_timestamp", 0) < TPM_WINDOW_SECONDS: |
| |
| if key_usage.get("tpm_input_count", 0) >= tpm_input_limit: |
| logger.warning(f"速率限制预检查失败 (Key: {api_key[:8]}, Model: {model_name}): TPM_Input 达到限制 ({key_usage.get('tpm_input_count', 0)}/{tpm_input_limit})。跳过此 Key。") |
| perform_api_call = False |
| |
| |
|
|
| |
| if perform_api_call: |
| key_usage["last_request_timestamp"] = now |
|
|
| return perform_api_call |
|
|
| def update_token_counts( |
| api_key: str, |
| model_name: str, |
| limits: Optional[Dict[str, Any]], |
| prompt_tokens: Optional[int], |
| client_ip: str, |
| today_date_str_pt: str |
| ) -> None: |
| """ |
| 在 API 调用成功 *之后* 更新给定 API Key 和模型的 TPD_Input 和 TPM_Input 计数。 |
| 同时记录基于 IP 的每日输入 Token 消耗。 |
| |
| Args: |
| api_key (str): 当前成功使用的 API Key。 |
| model_name (str): 请求的模型名称。 |
| limits (Optional[Dict[str, Any]]): 从配置中获取的该模型的限制字典。 |
| prompt_tokens (Optional[int]): 从 API 响应中获取的实际输入 Token 数量。 |
| client_ip (str): 客户端 IP 地址。 |
| today_date_str_pt (str): 当前的太平洋时区日期字符串 (YYYY-MM-DD),用于 IP 每日计数。 |
| """ |
| |
| if not limits or not prompt_tokens or prompt_tokens <= 0: |
| if limits and (not prompt_tokens or prompt_tokens <= 0): |
| logger.warning(f"Token 计数更新跳过 (Key: {api_key[:8]}, Model: {model_name}): 无效的 prompt_tokens ({prompt_tokens})。") |
| |
| return |
|
|
| with usage_lock: |
| |
| key_usage = usage_data.setdefault(api_key, defaultdict(lambda: defaultdict(int)))[model_name] |
|
|
| |
| |
| key_usage["tpd_input_count"] = key_usage.get("tpd_input_count", 0) + prompt_tokens |
|
|
| |
| tpm_input_limit = limits.get("tpm_input") |
| if tpm_input_limit is not None: |
| now_tpm = time.time() |
| |
| if now_tpm - key_usage.get("tpm_input_timestamp", 0) >= TPM_WINDOW_SECONDS: |
| |
| key_usage["tpm_input_count"] = prompt_tokens |
| key_usage["tpm_input_timestamp"] = now_tpm |
| else: |
| |
| key_usage["tpm_input_count"] = key_usage.get("tpm_input_count", 0) + prompt_tokens |
| |
| logger.debug(f"输入 Token 计数更新 (Key: {api_key[:8]}, Model: {model_name}): Added TPD_Input={prompt_tokens}, TPM_Input={key_usage['tpm_input_count']}") |
|
|
| |
| |
| with ip_input_token_counts_lock: |
| |
| |
| ip_daily_input_token_counts.setdefault(today_date_str_pt, Counter())[client_ip] += prompt_tokens |
|
|
| |
| async def save_context_after_success( |
| proxy_key: str, |
| contents_to_send: List[Dict[str, Any]], |
| model_reply_content: str, |
| model_name: str, |
| enable_context: bool, |
| final_tool_calls: Optional[List[Dict[str, Any]]] = None |
| ): |
| """ |
| 在 API 调用成功后保存上下文(如果启用)。 |
| |
| Args: |
| proxy_key (str): 用于存储上下文的键 (通常是 user_id)。 |
| contents_to_send (List[Dict[str, Any]]): 发送给模型的最终内容列表 (包含历史)。 |
| model_reply_content (str): 模型返回的文本回复。 |
| model_name (str): 使用的模型名称。 |
| enable_context (bool): 是否启用上下文保存功能。 |
| final_tool_calls (Optional[List[Dict[str, Any]]]): 模型返回的工具调用信息(目前暂未处理)。 |
| """ |
| if not enable_context: |
| logger.debug(f"Key {proxy_key[:8]}... 的上下文补全已禁用,跳过上下文保存。") |
| return |
|
|
| |
| logger.debug(f"准备为 Key '{proxy_key[:8]}...' 保存上下文 (内存模式: {db_utils.IS_MEMORY_DB})") |
|
|
| |
| model_reply_part = {"role": "model", "parts": [{"text": model_reply_content}]} |
| if final_tool_calls: |
| |
| |
| |
| |
| logger.warning("上下文保存:暂未处理工具调用 (tool_calls) 的保存。") |
| pass |
|
|
|
|
| |
| final_contents_to_save = contents_to_send + [model_reply_part] |
|
|
| |
| |
| |
| |
| |
| truncated_contents_to_save, still_over_limit_final = await truncate_context(final_contents_to_save, model_name) |
|
|
| if not still_over_limit_final: |
| try: |
| |
| await context_store.save_context(proxy_key, truncated_contents_to_save) |
| logger.info(f"上下文保存成功 for Key {proxy_key[:8]}...") |
| except Exception as e: |
| |
| logger.error(f"保存上下文失败 (Key: {proxy_key[:8]}...): {str(e)}", exc_info=True) |
| else: |
| |
| logger.error(f"上下文在添加回复并再次截断后仍然超限 (Key: {proxy_key[:8]}...). 上下文未保存。") |
|
|
| |
| def process_tool_calls(gemini_tool_calls: Any) -> Optional[List[Dict[str, Any]]]: |
| """ |
| 将 Gemini 返回的 functionCall 列表转换为 OpenAI 兼容的 tool_calls 格式。 |
| Gemini: [{'functionCall': {'name': 'func_name', 'args': {...}}}] |
| OpenAI: [{'id': 'call_...', 'type': 'function', 'function': {'name': 'func_name', 'arguments': '{...}'}}] |
| """ |
| if not isinstance(gemini_tool_calls, list): |
| logger.warning(f"期望 gemini_tool_calls 是列表,但得到 {type(gemini_tool_calls)}") |
| return None |
|
|
| openai_tool_calls = [] |
| |
| for i, call in enumerate(gemini_tool_calls): |
| |
| if not isinstance(call, dict): |
| logger.warning(f"工具调用列表中的元素不是字典: {call}") |
| continue |
|
|
| |
| function_call_data = call.get('functionCall') |
| if not isinstance(function_call_data, dict): |
| logger.warning(f"工具调用元素缺少有效的 'functionCall' 字典: {call}") |
| continue |
|
|
| |
| func_name = function_call_data.get('name') |
| if not isinstance(func_name, str) or not func_name: |
| logger.warning(f"工具调用元素缺少有效的 'name' 字段: {call}") |
| continue |
|
|
| |
| func_args = function_call_data.get('args') |
| if not isinstance(func_args, dict): |
| logger.warning(f"工具调用元素缺少有效的 'args' 字典: {call}") |
| continue |
|
|
| try: |
| |
| arguments_str = json.dumps(func_args, ensure_ascii=False) |
| except TypeError as e: |
| logger.error(f"序列化工具调用参数失败 (Name: {func_name}): {e}", exc_info=True) |
| continue |
|
|
| |
| openai_tool_calls.append({ |
| "id": f"call_{int(time.time()*1000)}_{i}", |
| "type": "function", |
| "function": { |
| "name": func_name, |
| "arguments": arguments_str, |
| } |
| }) |
|
|
| return openai_tool_calls if openai_tool_calls else None |
|
|