| |
| """ |
| 处理流式响应的逻辑。 |
| """ |
| import asyncio |
| import json |
| import logging |
| import time |
| from typing import List, Dict, Any, Optional, AsyncGenerator |
| from collections import defaultdict |
|
|
| |
| from app.api.models import ChatCompletionRequest |
| from app.core.services.gemini import GeminiClient |
| from app.core.keys.manager import APIKeyManager |
| from app.core.cache.manager import CacheManager |
| from sqlalchemy.ext.asyncio import AsyncSession |
| import httpx |
|
|
| |
| from app.core.processing.utils import save_context_after_success, update_token_counts |
| |
| from app import config |
|
|
| |
| from app.core.tracking import usage_data, usage_lock |
|
|
# Module-level logger used by every function in this module.
# NOTE(review): uses the fixed name 'my_logger' instead of __name__, presumably so this
# module logs through a logger configured elsewhere under that name — confirm.
logger = logging.getLogger('my_logger')
|
|
async def handle_stream_end(
    response_id: str,
    assistant_message_yielded: bool,
    actual_finish_reason: str,
    safety_issue_detail_received: Optional[Dict[str, Any]]
) -> AsyncGenerator[str, None]:
    """
    Emit the closing SSE events of a streamed chat completion.

    If no assistant output was streamed and the finish reason is STOP (compared
    case-insensitively), an error event is sent instead of a finish chunk:
    ``safety_block`` when safety details were received, otherwise ``empty_response``.
    In every other case a finish chunk carrying ``actual_finish_reason`` is sent.
    A ``data: [DONE]`` event is always sent last.

    Args:
        response_id: Unique ID of this streamed response, copied into the finish chunk.
        assistant_message_yielded: Whether at least one content or tool_calls chunk
            was already sent to the client.
        actual_finish_reason: Finish reason reported for the stream (e.g. "STOP",
            "MAX_TOKENS", "SAFETY"); callers may also pass a lower-case default.
        safety_issue_detail_received: Safety details, if any were reported; included
            in the error message for the STOP-without-content case.

    Yields:
        str: SSE-formatted events — an error or finish chunk, then ``data: [DONE]``.
    """
    # The caller defaults to lower-case "stop" when no final reason is reported, while
    # upstream reasons are upper-case ("STOP"); treat both as the same STOP condition.
    is_stop = str(actual_finish_reason or "").upper() == "STOP"

    if not assistant_message_yielded and is_stop:
        # STOP with nothing streamed: the client would otherwise get an empty reply.
        if safety_issue_detail_received:
            error_message_detail = f"模型因安全策略停止生成内容。详情: {safety_issue_detail_received}"
            logger.warning(f"流 {response_id}: 结束时未产生助手内容,完成原因是 STOP,但检测到安全问题。向客户端发送安全提示。详情: {safety_issue_detail_received}")
            error_code = "safety_block"
        else:
            error_message_detail = f"模型返回 STOP 但未生成任何内容。可能是由于输入问题或模型内部错误。完成原因: {actual_finish_reason}"
            logger.error(f"流 {response_id}: 结束时未产生助手内容,但完成原因是 STOP。向客户端发送错误。")
            error_code = "empty_response"

        error_payload = {
            "error": {
                "message": error_message_detail,
                "type": "model_error",
                "code": error_code
            }
        }
        yield f"data: {json.dumps(error_payload)}\n\n"
    else:
        if assistant_message_yielded:
            logger.debug(f"流 {response_id}: 正常结束,发送包含 finish_reason '{actual_finish_reason}' 的结束块。")
        else:
            logger.warning(f"流 {response_id}: 结束时未产生助手内容 (完成原因: {actual_finish_reason})。发送包含 finish_reason 的结束块。")

        end_chunk = {
            "id": response_id,
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": "ignored",
            "choices": [{"delta": {}, "index": 0, "finish_reason": actual_finish_reason}]
        }
        yield f"data: {json.dumps(end_chunk)}\n\n"

    # The stream terminator is always sent, whatever happened above.
    yield "data: [DONE]\n\n"
|
|
|
|
# Lifetime (seconds) of a native cache entry created after a successful cache-miss request.
_NATIVE_CACHE_TTL_SECONDS = 3600


def _format_delta_chunk(response_id: str, model_name: str, delta: Dict[str, Any]) -> str:
    """Serialize one streaming ``chat.completion.chunk`` carrying ``delta`` as an SSE event."""
    chunk = {
        "id": response_id,
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": model_name,
        "choices": [{
            "delta": delta,
            "index": 0,
            "finish_reason": None
        }]
    }
    return f"data: {json.dumps(chunk)}\n\n"


def _format_error_events(error_info: Dict[str, Any]) -> List[str]:
    """Return the SSE events that report ``error_info`` and then terminate the stream."""
    return [f"data: {json.dumps({'error': error_info})}\n\n", "data: [DONE]\n\n"]


def _http_error_body(http_err: "httpx.HTTPStatusError") -> str:
    """Return the failed response's body text, or a placeholder when it cannot be read."""
    try:
        return http_err.response.text
    except httpx.StreamError:
        # For streamed requests the status check can fire before the body is read, and
        # ``.text`` then raises ``httpx.ResponseNotRead`` (a ``StreamError`` subclass).
        # Letting that escape the error handler would leave the client without an error event.
        return "<response body not read>"


def _update_key_usage(
    response_id: str,
    selected_key: str,
    model_name: str,
    usage_metadata_received: Optional[Dict[str, Any]],
) -> None:
    """Record that ``selected_key`` served ``model_name`` successfully (best effort, logs on failure)."""
    if usage_metadata_received:
        # NOTE(review): token accounting is still a placeholder — `update_token_counts`
        # is imported by this module but is not called here.
        logger.debug(f"流 {response_id}: 请求成功,更新 Key {selected_key[:8]}... ({model_name}) 的 Token 计数 (占位符)。")
    else:
        logger.warning(f"流 {response_id}: 响应成功但未找到 usage metadata。Token 计数未更新。")

    try:
        with usage_lock:
            key_usage = usage_data.setdefault(selected_key, defaultdict(lambda: defaultdict(int)))[model_name]
            key_usage['last_used_timestamp'] = time.time()
        logger.debug(f"流 {response_id}: 请求成功,更新 Key {selected_key[:8]}... ({model_name}) 的 last_used_timestamp")
    except Exception as usage_err:
        logger.error(f"流 {response_id}: 更新 Key 使用记录失败: {usage_err}", exc_info=True)


async def _update_user_key_association(
    response_id: str,
    key_manager: APIKeyManager,
    db_for_cache: Optional[AsyncSession],
    user_id_for_mapping: Optional[str],
    selected_key: str,
) -> None:
    """Associate the user with the key that served the request (best effort, logs on failure)."""
    if not user_id_for_mapping:
        return
    if not db_for_cache:
        logger.warning(f"流 {response_id}: db session 无效,跳过用户 Key 关联更新。")
        return
    try:
        await key_manager.update_user_key_association(db_for_cache, user_id_for_mapping, selected_key)
        logger.debug(f"流 {response_id}: 请求成功,更新用户 {user_id_for_mapping} 与 Key {selected_key[:8]}... 的关联。")
    except Exception as assoc_err:
        logger.error(f"流 {response_id}: 更新用户 Key 关联失败: {assoc_err}", exc_info=True)


async def _create_native_cache(
    response_id: str,
    key_manager: APIKeyManager,
    cache_manager_instance: CacheManager,
    db_for_cache: Optional[AsyncSession],
    user_id_for_mapping: Optional[str],
    selected_key: str,
    content_to_cache_on_success: Dict[str, Any],
) -> None:
    """Create a native cache entry after a successful cache-miss request (best effort, logs on failure)."""
    logger.debug(f"流 {response_id}: 请求成功且是缓存未命中,尝试创建新缓存 (Key: {selected_key[:8]}...)")
    if not db_for_cache or user_id_for_mapping is None:
        logger.warning(f"流 {response_id}: db session 或 user_id 无效,跳过缓存创建。")
        return
    try:
        api_key_id = await key_manager.get_key_id(selected_key)
        if api_key_id is None:
            logger.warning(f"流 {response_id}: 无法获取 Key {selected_key[:8]}... 的 ID,跳过缓存创建。")
            return
        new_cache_id = await cache_manager_instance.create_cache(
            db=db_for_cache,
            user_id=user_id_for_mapping,
            api_key_id=api_key_id,
            content=content_to_cache_on_success,
            ttl=_NATIVE_CACHE_TTL_SECONDS
        )
        if new_cache_id:
            logger.info(f"流 {response_id}: 新缓存创建成功: {new_cache_id} (Key: {selected_key[:8]}...)")
        else:
            logger.warning(f"流 {response_id}: 创建新缓存失败 (Key: {selected_key[:8]}...)")
    except Exception as cache_create_err:
        logger.error(f"流 {response_id}: 创建缓存时发生异常 (Key: {selected_key[:8]}...): {cache_create_err}", exc_info=True)


async def _save_stream_context(
    response_id: str,
    user_id_for_mapping: Optional[str],
    contents: List[Dict[str, Any]],
    full_reply_content: str,
    model_name: str,
    final_tool_calls: Any,
) -> None:
    """Persist the conversation context when ``config.STREAM_SAVE_REPLY`` is enabled (best effort)."""
    if not config.STREAM_SAVE_REPLY:
        logger.debug(f"流 {response_id}: STREAM_SAVE_REPLY 未启用,跳过流式响应上下文保存。")
        return
    if not (user_id_for_mapping and (full_reply_content or final_tool_calls)):
        logger.debug(f"流 {response_id}: STREAM_SAVE_REPLY 已启用,但无内容 ({'有' if full_reply_content else '无'}文本, {'有' if final_tool_calls else '无'}工具调用) 或无 user_id_for_mapping ({user_id_for_mapping}),跳过上下文保存。")
        return

    logger.info(f"流 {response_id}: STREAM_SAVE_REPLY 已启用,准备保存流式响应上下文。")
    try:
        await save_context_after_success(
            proxy_key=user_id_for_mapping,
            contents_to_send=contents,
            model_reply_content=full_reply_content,
            model_name=model_name,
            enable_context=True,
            final_tool_calls=final_tool_calls
        )
        logger.info(f"流 {response_id}: 流式响应上下文已保存。")
    except Exception as context_save_err:
        logger.error(f"流 {response_id}: 保存流式响应上下文失败: {context_save_err}", exc_info=True)


async def generate_stream_response(
    gemini_client_instance: GeminiClient,
    chat_request: ChatCompletionRequest,
    contents: List[Dict[str, Any]],
    safety_settings: List[Dict[str, Any]],
    system_instruction: Optional[str],
    cached_content_id: Optional[str],
    response_id: str,
    enable_native_caching: bool,
    cache_manager_instance: CacheManager,
    content_to_cache_on_success: Optional[Dict[str, Any]],
    db_for_cache: Optional[AsyncSession],
    user_id_for_mapping: Optional[str],
    key_manager: APIKeyManager,
    selected_key: str,
    model_name: str,
    limits: Optional[Dict[str, Any]],
    client_ip: str,
    today_date_str_pt: str,
) -> AsyncGenerator[str, None]:
    """
    Stream a Gemini chat completion to the client as OpenAI-style SSE events.

    Consumes ``gemini_client_instance.stream_chat(...)``:

    * non-empty string items are forwarded as ``content`` deltas;
    * dict items with ``_tool_calls`` are forwarded as ``tool_calls`` deltas;
    * dict items with ``_usage_metadata``, ``_final_finish_reason`` or
      ``_safety_issue`` are only recorded for end-of-stream handling.

    The stream is closed via ``handle_stream_end``, which sends a finish chunk or an
    error and then ``[DONE]``. If the stream produced assistant output, best-effort
    bookkeeping runs afterwards: key usage timestamp, user/key association, optional
    native cache creation and optional context saving. Failures there are logged and
    never produce events after ``[DONE]``.

    Args:
        gemini_client_instance: Client whose ``stream_chat`` produces the chunks.
        chat_request, contents, safety_settings, system_instruction, cached_content_id:
            Forwarded to ``stream_chat``. ``contents`` is also used when saving context.
        response_id: ID placed in every chunk and used in log messages.
        enable_native_caching: Whether to create a native cache entry on success.
        cache_manager_instance: Used to create that cache entry.
        content_to_cache_on_success: Content to cache. Caching is skipped when falsy.
        db_for_cache: DB session for cache creation and user/key association.
        user_id_for_mapping: User identifier for association, caching and context saving.
        key_manager: Resolves key IDs and updates user/key associations.
        selected_key: API key used for this request (logged truncated to 8 characters).
        model_name: Model name reported in delta chunks and used for usage tracking.
        limits, today_date_str_pt: Not used by this function.
        client_ip: Logged when the stream is cancelled.

    Yields:
        str: SSE events — content/tool-call deltas, then an error or finish chunk, then
        ``data: [DONE]``.

    Raises:
        asyncio.CancelledError: Re-raised after logging when the stream is cancelled.
    """
    assistant_message_yielded = False
    full_reply_content = ""
    usage_metadata_received = None
    # Default when the stream never reports a final reason. handle_stream_end compares
    # finish reasons against Gemini's "STOP".
    actual_finish_reason = "stop"
    safety_issue_detail_received = None
    final_tool_calls = None

    try:
        async for chunk_data in gemini_client_instance.stream_chat(
            request=chat_request,
            contents=contents,
            safety_settings=safety_settings,
            system_instruction=system_instruction,
            cached_content_id=cached_content_id
        ):
            if isinstance(chunk_data, dict):
                # Dict items are control/metadata signals from the client wrapper.
                if '_usage_metadata' in chunk_data:
                    usage_metadata_received = chunk_data['_usage_metadata']
                    logger.debug(f"流 {response_id}: 接收到 usage metadata: {usage_metadata_received}")
                elif '_final_finish_reason' in chunk_data:
                    actual_finish_reason = chunk_data['_final_finish_reason']
                    logger.debug(f"流 {response_id}: 接收到最终完成原因: {actual_finish_reason}")
                elif '_safety_issue' in chunk_data:
                    safety_issue_detail_received = chunk_data['_safety_issue']
                    logger.warning(f"流 {response_id}: 接收到安全问题详情: {safety_issue_detail_received}")
                elif '_tool_calls' in chunk_data:
                    final_tool_calls = chunk_data['_tool_calls']
                    logger.info(f"流 {response_id}: 接收到工具调用: {final_tool_calls}")
                    yield _format_delta_chunk(
                        response_id, model_name, {"role": "assistant", "tool_calls": final_tool_calls}
                    )
                    assistant_message_yielded = True
            elif isinstance(chunk_data, str):
                if chunk_data:
                    yield _format_delta_chunk(
                        response_id, model_name, {"role": "assistant", "content": chunk_data}
                    )
                    assistant_message_yielded = True
                    full_reply_content += chunk_data
            else:
                logger.warning(f"流 {response_id}: 接收到未知类型的块: {type(chunk_data)}")

        async for end_chunk_data in handle_stream_end(
            response_id,
            assistant_message_yielded,
            actual_finish_reason,
            safety_issue_detail_received
        ):
            yield end_chunk_data

    except asyncio.CancelledError:
        logger.info(f"流 {response_id}: 客户端连接已中断 (IP: {client_ip})")
        # Swallowing CancelledError breaks asyncio task cancellation; propagate it.
        raise
    except httpx.HTTPStatusError as http_err:
        status_code = http_err.response.status_code
        logger.error(f"流 {response_id}: API HTTP 错误: {status_code} - {_http_error_body(http_err)}", exc_info=False)
        error_info = {"message": f"API Error: {status_code}", "type": "api_error", "code": status_code}
        for event in _format_error_events(error_info):
            yield event
        return
    except Exception as stream_e:
        logger.error(f"流 {response_id}: 处理中捕获到意外异常: {stream_e}", exc_info=True)
        error_info = {
            "message": f"流处理中发生意外异常: {stream_e}",
            "type": "internal_error",
            "code": 500
        }
        for event in _format_error_events(error_info):
            yield event
        return

    # [DONE] has already been sent at this point. Bookkeeping deliberately runs outside
    # the try above, so its failures are logged instead of emitting events after [DONE].
    if not (assistant_message_yielded or final_tool_calls):
        return

    _update_key_usage(response_id, selected_key, model_name, usage_metadata_received)

    await _update_user_key_association(
        response_id, key_manager, db_for_cache, user_id_for_mapping, selected_key
    )

    if enable_native_caching and content_to_cache_on_success:
        await _create_native_cache(
            response_id,
            key_manager,
            cache_manager_instance,
            db_for_cache,
            user_id_for_mapping,
            selected_key,
            content_to_cache_on_success,
        )

    await _save_stream_context(
        response_id,
        user_id_for_mapping,
        contents,
        full_reply_content,
        model_name,
        final_tool_calls,
    )
|
|