""" API工具函数模块 包含SSE生成、流处理、token统计和请求验证等工具函数 """ import asyncio import json import time import datetime from typing import Any, Dict, List, Optional, AsyncGenerator from asyncio import Queue from models import Message # --- SSE生成函数 --- def generate_sse_chunk(delta: str, req_id: str, model: str) -> str: """生成SSE数据块""" chunk_data = { "id": f"chatcmpl-{req_id}", "object": "chat.completion.chunk", "created": int(time.time()), "model": model, "choices": [{"index": 0, "delta": {"content": delta}, "finish_reason": None}] } return f"data: {json.dumps(chunk_data)}\n\n" def generate_sse_stop_chunk(req_id: str, model: str, reason: str = "stop", usage: dict = None) -> str: """生成SSE停止块""" stop_chunk_data = { "id": f"chatcmpl-{req_id}", "object": "chat.completion.chunk", "created": int(time.time()), "model": model, "choices": [{"index": 0, "delta": {}, "finish_reason": reason}] } # 添加usage信息(如果提供) if usage: stop_chunk_data["usage"] = usage return f"data: {json.dumps(stop_chunk_data)}\n\ndata: [DONE]\n\n" def generate_sse_error_chunk(message: str, req_id: str, error_type: str = "server_error") -> str: """生成SSE错误块""" error_chunk = {"error": {"message": message, "type": error_type, "param": None, "code": req_id}} return f"data: {json.dumps(error_chunk)}\n\n" # --- 流处理工具函数 --- async def use_stream_response(req_id: str) -> AsyncGenerator[Any, None]: """使用流响应(从服务器的全局队列获取数据)""" from server import STREAM_QUEUE, logger import queue if STREAM_QUEUE is None: logger.warning(f"[{req_id}] STREAM_QUEUE is None, 无法使用流响应") return logger.info(f"[{req_id}] 开始使用流响应") empty_count = 0 max_empty_retries = 300 # 30秒超时 data_received = False try: while True: try: # 从队列中获取数据 data = STREAM_QUEUE.get_nowait() if data is None: # 结束标志 logger.info(f"[{req_id}] 接收到流结束标志") break # 重置空计数器 empty_count = 0 data_received = True logger.debug(f"[{req_id}] 接收到流数据: {type(data)} - {str(data)[:200]}...") # 检查是否是JSON字符串形式的结束标志 if isinstance(data, str): try: parsed_data = json.loads(data) if parsed_data.get("done") is True: logger.info(f"[{req_id}] 接收到JSON格式的完成标志") yield parsed_data break else: yield parsed_data except json.JSONDecodeError: # 如果不是JSON,直接返回字符串 logger.debug(f"[{req_id}] 返回非JSON字符串数据") yield data else: # 直接返回数据 yield data # 检查字典类型的结束标志 if isinstance(data, dict) and data.get("done") is True: logger.info(f"[{req_id}] 接收到字典格式的完成标志") break except (queue.Empty, asyncio.QueueEmpty): empty_count += 1 if empty_count % 50 == 0: # 每5秒记录一次等待状态 logger.info(f"[{req_id}] 等待流数据... ({empty_count}/{max_empty_retries})") if empty_count >= max_empty_retries: if not data_received: logger.error(f"[{req_id}] 流响应队列空读取次数达到上限且未收到任何数据,可能是辅助流未启动或出错") else: logger.warning(f"[{req_id}] 流响应队列空读取次数达到上限 ({max_empty_retries}),结束读取") # 返回超时完成信号,而不是简单退出 yield {"done": True, "reason": "internal_timeout", "body": "", "function": []} return await asyncio.sleep(0.1) # 100ms等待 continue except Exception as e: logger.error(f"[{req_id}] 使用流响应时出错: {e}") raise finally: logger.info(f"[{req_id}] 流响应使用完成,数据接收状态: {data_received}") async def clear_stream_queue(): """清空流队列(与原始参考文件保持一致)""" from server import STREAM_QUEUE, logger import queue if STREAM_QUEUE is None: logger.info("流队列未初始化或已被禁用,跳过清空操作。") return while True: try: data_chunk = await asyncio.to_thread(STREAM_QUEUE.get_nowait) # logger.info(f"清空流式队列缓存,丢弃数据: {data_chunk}") except queue.Empty: logger.info("流式队列已清空 (捕获到 queue.Empty)。") break except Exception as e: logger.error(f"清空流式队列时发生意外错误: {e}", exc_info=True) break logger.info("流式队列缓存清空完毕。") # --- Helper response generator --- async def use_helper_get_response(helper_endpoint: str, helper_sapisid: str) -> AsyncGenerator[str, None]: """使用Helper服务获取响应的生成器""" from server import logger import aiohttp logger.info(f"正在尝试使用Helper端点: {helper_endpoint}") try: async with aiohttp.ClientSession() as session: headers = { 'Content-Type': 'application/json', 'Cookie': f'SAPISID={helper_sapisid}' if helper_sapisid else '' } async with session.get(helper_endpoint, headers=headers) as response: if response.status == 200: async for chunk in response.content.iter_chunked(1024): if chunk: yield chunk.decode('utf-8', errors='ignore') else: logger.error(f"Helper端点返回错误状态: {response.status}") except Exception as e: logger.error(f"使用Helper端点时出错: {e}") # --- 请求验证函数 --- def validate_chat_request(messages: List[Message], req_id: str) -> Dict[str, Optional[str]]: """验证聊天请求""" from server import logger if not messages: raise ValueError(f"[{req_id}] 无效请求: 'messages' 数组缺失或为空。") if not any(msg.role != 'system' for msg in messages): raise ValueError(f"[{req_id}] 无效请求: 所有消息都是系统消息。至少需要一条用户或助手消息。") # 返回验证结果 return { "error": None, "warning": None } # --- 提示准备函数 --- def prepare_combined_prompt(messages: List[Message], req_id: str) -> str: """准备组合提示""" from server import logger logger.info(f"[{req_id}] (准备提示) 正在从 {len(messages)} 条消息准备组合提示 (包括历史)。") combined_parts = [] system_prompt_content: Optional[str] = None processed_system_message_indices = set() # 处理系统消息 for i, msg in enumerate(messages): if msg.role == 'system': content = msg.content if isinstance(content, str) and content.strip(): system_prompt_content = content.strip() processed_system_message_indices.add(i) logger.info(f"[{req_id}] (准备提示) 在索引 {i} 找到并使用系统提示: '{system_prompt_content[:80]}...'") system_instr_prefix = "系统指令:\n" combined_parts.append(f"{system_instr_prefix}{system_prompt_content}") else: logger.info(f"[{req_id}] (准备提示) 在索引 {i} 忽略非字符串或空的系统消息。") processed_system_message_indices.add(i) break role_map_ui = {"user": "用户", "assistant": "助手", "system": "系统", "tool": "工具"} turn_separator = "\n---\n" # 处理其他消息 for i, msg in enumerate(messages): if i in processed_system_message_indices: continue if msg.role == 'system': logger.info(f"[{req_id}] (准备提示) 跳过在索引 {i} 的后续系统消息。") continue if combined_parts: combined_parts.append(turn_separator) role = msg.role or 'unknown' role_prefix_ui = f"{role_map_ui.get(role, role.capitalize())}:\n" current_turn_parts = [role_prefix_ui] content = msg.content or '' content_str = "" if isinstance(content, str): content_str = content.strip() elif isinstance(content, list): # 处理多模态内容 text_parts = [] for item in content: if hasattr(item, 'type') and item.type == 'text': text_parts.append(item.text or '') elif isinstance(item, dict) and item.get('type') == 'text': text_parts.append(item.get('text', '')) else: logger.warning(f"[{req_id}] (准备提示) 警告: 在索引 {i} 的消息中忽略非文本或未知类型的 content item") content_str = "\n".join(text_parts).strip() else: logger.warning(f"[{req_id}] (准备提示) 警告: 角色 {role} 在索引 {i} 的内容类型意外 ({type(content)}) 或为 None。") content_str = str(content or "").strip() if content_str: current_turn_parts.append(content_str) # 处理工具调用 tool_calls = msg.tool_calls if role == 'assistant' and tool_calls: if content_str: current_turn_parts.append("\n") tool_call_visualizations = [] for tool_call in tool_calls: if hasattr(tool_call, 'type') and tool_call.type == 'function': function_call = tool_call.function func_name = function_call.name if function_call else None func_args_str = function_call.arguments if function_call else None try: parsed_args = json.loads(func_args_str if func_args_str else '{}') formatted_args = json.dumps(parsed_args, indent=2, ensure_ascii=False) except (json.JSONDecodeError, TypeError): formatted_args = func_args_str if func_args_str is not None else "{}" tool_call_visualizations.append( f"请求调用函数: {func_name}\n参数:\n{formatted_args}" ) if tool_call_visualizations: current_turn_parts.append("\n".join(tool_call_visualizations)) if len(current_turn_parts) > 1 or (role == 'assistant' and tool_calls): combined_parts.append("".join(current_turn_parts)) elif not combined_parts and not current_turn_parts: logger.info(f"[{req_id}] (准备提示) 跳过角色 {role} 在索引 {i} 的空消息 (且无工具调用)。") elif len(current_turn_parts) == 1 and not combined_parts: logger.info(f"[{req_id}] (准备提示) 跳过角色 {role} 在索引 {i} 的空消息 (只有前缀)。") final_prompt = "".join(combined_parts) if final_prompt: final_prompt += "\n" preview_text = final_prompt[:300].replace('\n', '\\n') logger.info(f"[{req_id}] (准备提示) 组合提示长度: {len(final_prompt)}。预览: '{preview_text}...'") return final_prompt def estimate_tokens(text: str) -> int: """ 估算文本的token数量 使用简单的字符计数方法: - 英文:大约4个字符 = 1个token - 中文:大约1.5个字符 = 1个token - 混合文本:采用加权平均 """ if not text: return 0 # 统计中文字符数量(包括中文标点) chinese_chars = sum(1 for char in text if '\u4e00' <= char <= '\u9fff' or '\u3000' <= char <= '\u303f' or '\uff00' <= char <= '\uffef') # 统计非中文字符数量 non_chinese_chars = len(text) - chinese_chars # 计算token估算 chinese_tokens = chinese_chars / 1.5 # 中文大约1.5字符/token english_tokens = non_chinese_chars / 4.0 # 英文大约4字符/token return max(1, int(chinese_tokens + english_tokens)) def calculate_usage_stats(messages: List[dict], response_content: str, reasoning_content: str = None) -> dict: """ 计算token使用统计 Args: messages: 请求中的消息列表 response_content: 响应内容 reasoning_content: 推理内容(可选) Returns: 包含token使用统计的字典 """ # 计算输入token(prompt tokens) prompt_text = "" for message in messages: role = message.get("role", "") content = message.get("content", "") prompt_text += f"{role}: {content}\n" prompt_tokens = estimate_tokens(prompt_text) # 计算输出token(completion tokens) completion_text = response_content or "" if reasoning_content: completion_text += reasoning_content completion_tokens = estimate_tokens(completion_text) # 总token数 total_tokens = prompt_tokens + completion_tokens return { "prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, "total_tokens": total_tokens } def generate_sse_stop_chunk_with_usage(req_id: str, model: str, usage_stats: dict, reason: str = "stop") -> str: """生成带usage统计的SSE停止块""" return generate_sse_stop_chunk(req_id, model, reason, usage_stats)