import asyncio
import base64
import json
import random
import string
import time
import uuid
from functools import wraps
from typing import Union, Callable, Any, AsyncGenerator, Dict

from curl_cffi.requests.exceptions import RequestException
from sse_starlette import EventSourceResponse
from starlette.responses import JSONResponse

from app.errors import CursorWebError
from app.models import ChatCompletionRequest, Usage, ToolCall, Message


async def safe_stream_wrapper(
    generator_func, *args, **kwargs
) -> Union[EventSourceResponse, JSONResponse]:
    """
    Safe streaming-response wrapper.

    Runs the generator far enough to obtain its first item; only if that
    succeeds is the SSE response created, so failures can still be returned
    as a regular JSON error.
    """
    generator = generator_func(*args, **kwargs)

    # If this raises, no SSE response has been started yet and the exception
    # propagates to the caller.
    first_item = await generator.__anext__()

    async def wrapped_generator():
        # Re-emit the item consumed above, then forward the rest of the stream.
        yield first_item
        async for item in generator:
            yield item

    return EventSourceResponse(
        wrapped_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )

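# Illustrative use (assumed wiring; `cursor_chat` is the upstream chat function
# referenced by the wrappers further down in this module):
#
#   return await safe_stream_wrapper(
#       stream_chat_completion, request, cursor_chat(request)
#   )
#
# If producing the first chunk raises, the exception propagates before any SSE
# headers are sent, so the caller can still return a plain JSON error.
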
async def error_wrapper(func: Callable, *args, **kwargs) -> Any:
    """Retry `func` on upstream errors and convert the final failure into an
    OpenAI-style JSON error response."""
    from .config import MAX_RETRIES

    for attempt in range(MAX_RETRIES + 1):
        try:
            return await func(*args, **kwargs)
        except (CursorWebError, RequestException) as e:
            if attempt == MAX_RETRIES:
                if isinstance(e, CursorWebError):
                    return JSONResponse(
                        e.to_openai_error(),
                        status_code=e.response_status_code
                    )
                elif isinstance(e, RequestException):
                    return JSONResponse(
                        {
                            "error": {
                                "message": str(e),
                                "type": "http_error",
                                "code": "http_error"
                            }
                        },
                        status_code=500
                    )
            if attempt < MAX_RETRIES:
                continue
    return None

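# Illustrative use (placeholder coroutine name): any awaitable call can be
# wrapped so that repeated upstream failures come back as an OpenAI-style JSON
# error response instead of an exception:
#
#   result = await error_wrapper(some_upstream_call, request)
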
def decode_base64url_safe(data: str) -> bytes:
    """Base64url-decode `data`, restoring any missing '=' padding first."""
    missing_padding = len(data) % 4
    if missing_padding:
        data += '=' * (4 - missing_padding)
    return base64.urlsafe_b64decode(data)

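# Example (illustrative): a base64url string stripped of its '=' padding, such
# as a JWT segment, decodes cleanly because the padding is restored first:
#
#   decode_base64url_safe("eyJzdWIiOiIxMjMifQ")  # -> b'{"sub":"123"}'
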
def to_async(sync_func):
    """Decorator: run a blocking function in the default executor so it can be
    awaited (positional arguments only)."""
    @wraps(sync_func)
    async def async_wrapper(*args):
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, sync_func, *args)

    return async_wrapper

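# Example (illustrative, hypothetical wrapped function): offload a blocking
# call to the default thread-pool executor so it can be awaited:
#
#   @to_async
#   def blocking_decode(token: str) -> bytes:
#       return decode_base64url_safe(token)
#
#   payload = await blocking_decode("eyJzdWIiOiIxMjMifQ")
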
def generate_random_string(length: int) -> str:
    """Generate a random string of the given length made of ASCII letters and
    digits."""
    characters = string.ascii_letters + string.digits
    return ''.join(random.choice(characters) for _ in range(length))

def normalize_tool_name(name: str) -> str:
    """Normalize a tool name by replacing every underscore with a hyphen."""
    return name.replace('_', '-')

def match_tool_name(tool_name: str, available_tools: list[str]) -> str:
    """
    Match a tool name against the available tools, falling back to a
    normalized comparison when there is no exact match.

    Args:
        tool_name: the tool name to match
        available_tools: the list of available tool names

    Returns:
        The actual matching tool name, or the original name if nothing matches.
    """
    if tool_name in available_tools:
        return tool_name

    # Fall back to comparing normalized forms (underscores vs. hyphens).
    normalized_input = normalize_tool_name(tool_name)
    for available_tool in available_tools:
        if normalize_tool_name(available_tool) == normalized_input:
            return available_tool

    return tool_name

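# Example (illustrative tool names): underscore and hyphen variants are treated
# as the same tool, while unknown names pass through unchanged:
#
#   match_tool_name("read_file", ["read-file", "write-file"])  # -> "read-file"
#   match_tool_name("search", ["read-file", "write-file"])     # -> "search"
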
async def non_stream_chat_completion(
    request: ChatCompletionRequest,
    generator: AsyncGenerator[Union[str, Usage, ToolCall], None]
) -> Dict[str, Any]:
    """
    Non-streaming response: consume an external async generator and collect
    all of its output into a complete chat.completion payload.
    """
    full_content = ""
    tool_calls = []
    usage = Usage(prompt_tokens=0, completion_tokens=0, total_tokens=0)

    async for chunk in generator:
        if isinstance(chunk, Usage):
            usage = chunk
            continue
        if isinstance(chunk, ToolCall):
            tool_calls.append({
                "id": chunk.toolId,
                "type": "function",
                "function": {
                    "name": chunk.toolName,
                    "arguments": chunk.toolInput,
                }
            })
            continue
        full_content += chunk

    response = {
        "id": f"chatcmpl-{uuid.uuid4().hex[:29]}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": request.model,
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": full_content,
                    "tool_calls": tool_calls
                },
                "finish_reason": "stop"
            }
        ],
        "usage": {
            "prompt_tokens": usage.prompt_tokens,
            "completion_tokens": usage.completion_tokens,
            "total_tokens": usage.total_tokens
        }
    }

    return response

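# Illustrative use (assumed call site; `cursor_chat` is the upstream generator
# referenced by the wrappers below): collect a whole upstream stream into a
# single chat.completion payload and return it as JSON:
#
#   body = await non_stream_chat_completion(request, cursor_chat(request))
#   return JSONResponse(body)
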
async def stream_chat_completion(
    request: ChatCompletionRequest,
    generator: AsyncGenerator[Union[str, Usage, ToolCall], None]
) -> AsyncGenerator[Dict[str, Any], None]:
    """
    Streaming response: consume an external async generator and wrap its
    output in OpenAI-compatible SSE chunks.
    """
    chat_id = f"chatcmpl-{uuid.uuid4().hex[:29]}"
    created_time = int(time.time())

    is_send_init = False

    # The first chunk carries the assistant role, matching the OpenAI stream format.
    initial_response = {
        "id": chat_id,
        "object": "chat.completion.chunk",
        "created": created_time,
        "model": request.model,
        "choices": [
            {
                "index": 0,
                "delta": {"role": "assistant", "content": ""},
                "finish_reason": None
            }
        ]
    }

    usage = None
    tool_call_idx = 0
    async for chunk in generator:
        if not is_send_init:
            yield {"data": json.dumps(initial_response, ensure_ascii=False)}
            is_send_init = True
        if isinstance(chunk, Usage):
            usage = chunk
            continue

        if isinstance(chunk, ToolCall):
            data = {
                "id": chat_id,
                "object": "chat.completion.chunk",
                "created": created_time,
                "model": request.model,
                "choices": [
                    {
                        "index": 0,
                        "delta": {
                            "tool_calls": [
                                {
                                    "index": tool_call_idx,
                                    "id": chunk.toolId,
                                    "type": "function",
                                    "function": {
                                        "name": chunk.toolName,
                                        "arguments": chunk.toolInput,
                                    },
                                }
                            ]
                        },
                        "finish_reason": None,
                    }
                ],
            }
            tool_call_idx += 1
            yield {"data": json.dumps(data, ensure_ascii=False)}
            continue

        chunk_response = {
            "id": chat_id,
            "object": "chat.completion.chunk",
            "created": created_time,
            "model": request.model,
            "choices": [
                {
                    "index": 0,
                    "delta": {"content": chunk},
                    "finish_reason": None
                }
            ]
        }
        yield {"data": json.dumps(chunk_response, ensure_ascii=False)}

    final_response = {
        "id": chat_id,
        "object": "chat.completion.chunk",
        "created": created_time,
        "model": request.model,
        "choices": [
            {
                "index": 0,
                "delta": {},
                "finish_reason": "stop"
            }
        ]
    }
    yield {"data": json.dumps(final_response, ensure_ascii=False)}

    if usage:
        usage_data = {
            "id": chat_id,
            "object": "chat.completion.chunk",
            "created": created_time,
            "model": request.model,
            "choices": [],
            "usage": {
                "prompt_tokens": usage.prompt_tokens,
                "completion_tokens": usage.completion_tokens,
                "total_tokens": usage.total_tokens,
                "prompt_tokens_details": {
                    "cached_tokens": 0,
                    "text_tokens": 0,
                    "audio_tokens": 0,
                    "image_tokens": 0
                },
                "completion_tokens_details": {
                    "text_tokens": 0,
                    "audio_tokens": 0,
                    "reasoning_tokens": 0
                },
                "input_tokens": 0,
                "output_tokens": 0,
                "input_tokens_details": None
            }
        }
        yield {"data": json.dumps(usage_data, ensure_ascii=False)}
    yield {"data": "[DONE]"}

async def empty_retry_wrapper(
    cursor_chat_func: Callable,
    request: ChatCompletionRequest,
    max_retries: int = 3
) -> AsyncGenerator[Union[str, Usage, ToolCall], None]:
    """
    Empty-reply retry wrapper: automatically retries when the upstream
    produces no content at all.

    Args:
        cursor_chat_func: the cursor_chat function
        request: the chat request
        max_retries: maximum number of retries

    Yields:
        str / Usage / ToolCall: streamed output

    Raises:
        CursorWebError: if the reply is still empty after all retries
    """
    for retry_count in range(max_retries + 1):
        generator = cursor_chat_func(request)
        has_content = False

        async for chunk in generator:
            if isinstance(chunk, ToolCall):
                # A tool call counts as content; forward it and stop.
                has_content = True
                yield chunk
                return

            elif isinstance(chunk, Usage):
                yield chunk

            else:
                has_content = True
                yield chunk

        if has_content:
            return

        if retry_count < max_retries:
            continue

    raise CursorWebError(200, f"Reply was still empty after {max_retries} retries")

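# Illustrative composition (assumed wiring; `cursor_chat` is an assumed name):
# feed the retry-wrapped stream into the SSE formatter defined above, e.g.
#
#   generator = empty_retry_wrapper(cursor_chat, request, max_retries=3)
#   return await safe_stream_wrapper(stream_chat_completion, request, generator)
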
async def truncation_continue_wrapper(
    cursor_chat_func: Callable,
    request: ChatCompletionRequest,
    max_retries: int = 10
) -> AsyncGenerator[Union[str, Usage, ToolCall], None]:
    """
    Truncation-continuation wrapper: streams output in real time and, when a
    reply is detected as truncated, automatically asks the model to continue.

    Args:
        cursor_chat_func: the cursor_chat function
        request: the chat request
        max_retries: maximum number of continuation attempts

    Yields:
        str / Usage / ToolCall: streamed output
    """
    full_content = ""
    total_prompt_tokens = 0
    total_completion_tokens = 0
    total_tokens = 0
    current_usage = None

    for retry_count in range(max_retries + 1):
        generator = cursor_chat_func(request)
        current_content = ""
        is_truncated = False
        buffer = ""
        buffer_yielded = False

        async for chunk in generator:
            if isinstance(chunk, Usage):
                current_usage = chunk
                total_prompt_tokens += chunk.prompt_tokens
                total_completion_tokens += chunk.completion_tokens
                total_tokens += chunk.total_tokens

                # A reply of exactly 4096 completion tokens is treated as
                # truncated by the upstream limit.
                is_truncated = chunk.completion_tokens == 4096
                break

            elif isinstance(chunk, ToolCall):
                yield chunk
                return

            else:
                current_content += chunk

                if retry_count == 0:
                    yield chunk
                else:
                    # On continuation rounds, buffer output until the echoed
                    # overlap (the last 10 characters of the previous reply)
                    # has been stripped.
                    buffer += chunk
                    last_10_chars = full_content[-10:] if len(full_content) >= 10 else full_content

                    if not buffer_yielded:
                        if last_10_chars and last_10_chars in buffer:
                            buffer = buffer.replace(last_10_chars, "", 1)
                            if buffer:
                                yield buffer
                            buffer = ""
                            buffer_yielded = True
                        elif len(buffer) > 20:
                            yield buffer
                            buffer = ""
                            buffer_yielded = True
                    else:
                        yield chunk
                        buffer = ""

        # Flush whatever is still buffered at the end of a continuation round.
        if retry_count > 0 and buffer:
            last_10_chars = full_content[-10:] if len(full_content) >= 10 else full_content
            if not buffer_yielded and last_10_chars and last_10_chars in buffer:
                buffer = buffer.replace(last_10_chars, "", 1)
            if buffer:
                yield buffer

        full_content += current_content

        if not is_truncated:
            if current_usage:
                yield current_usage
            return

        # Build the continuation prompt around the last 10 characters of the
        # partial reply so the model can resume exactly where it stopped.
        last_10_chars = full_content[-10:] if len(full_content) >= 10 else full_content
        continue_prompt = f'''Your reply was unexpectedly cut off at "{last_10_chars}".

Continue directly from that point, following these rules:
1. Start with "{last_10_chars}" and append the new content immediately after it
2. If you were inside a code block, continue the code directly; do not repeat the ``` fence or the language tag
3. Keep the original formatting, indentation and context

Bad example: truncated at "document."
❌ ```javascript\nlet a=1;\ndocument.createElement...

Good example:
✅ document.createElement...

Continue now; do not explain or start over.'''

        # Append the partial assistant reply and the continuation request, then
        # retry with the extended conversation.
        new_messages = request.messages.copy()
        new_messages.append(Message(role="assistant", content=full_content, tool_calls=None, tool_call_id=None))
        new_messages.append(Message(role="user", content=continue_prompt, tool_calls=None, tool_call_id=None))

        request = ChatCompletionRequest(
            messages=new_messages,
            stream=request.stream,
            model=request.model,
            tools=request.tools
        )

    if current_usage:
        yield current_usage