Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import time | |
| import uuid | |
| import httpx | |
| import re | |
| import asyncio | |
| import xml.etree.ElementTree as ET | |
| import logging | |
| import struct | |
| import base64 | |
| import copy | |
| from fastapi import FastAPI, HTTPException, Request, Header, Depends | |
| from fastapi.responses import StreamingResponse | |
| from pydantic import BaseModel, Field | |
| from typing import List, Optional, Dict, Any, Union | |
| from dotenv import load_dotenv | |
| from json_repair import repair_json | |
| # Configure logging | |
| # logging.basicConfig(level=logging.INFO) # for dev | |
| logging.basicConfig(level=logging.WARNING) | |
| logger = logging.getLogger(__name__) | |
| # Load environment variables | |
| load_dotenv() | |
| # Initialize FastAPI app | |
| app = FastAPI( | |
| title="Ki2API - Claude Sonnet 4 OpenAI Compatible API", | |
| description="OpenAI-compatible API for Claude Sonnet 4 via AWS CodeWhisperer", | |
| version="3.0.1" | |
| ) | |
| # Configuration | |
| API_KEY = os.getenv("API_KEY") | |
| KIRO_ACCESS_TOKEN = os.getenv("KIRO_ACCESS_TOKEN") | |
| KIRO_REFRESH_TOKEN = os.getenv("KIRO_REFRESH_TOKEN") | |
| KIRO_BASE_URL = "https://codewhisperer.us-east-1.amazonaws.com/generateAssistantResponse" | |
| PROFILE_ARN = "arn:aws:codewhisperer:us-east-1:699475941385:profile/EHGA3GRVQMUK" | |
| # Model mapping | |
| MODEL_MAP = { | |
| "claude-sonnet-4-20250514": "CLAUDE_SONNET_4_20250514_V1_0", | |
| "claude-3-5-haiku-20241022": "CLAUDE_3_7_SONNET_20250219_V1_0", | |
| } | |
| DEFAULT_MODEL = "claude-sonnet-4-20250514" | |
| # Pydantic models for OpenAI compatibility | |
| class ImageUrl(BaseModel): | |
| url: str | |
| detail: Optional[str] = "auto" | |
| class ContentPart(BaseModel): | |
| type: str | |
| text: Optional[str] = None | |
| image_url: Optional[ImageUrl] = None | |
| class ToolCall(BaseModel): | |
| id: str | |
| type: str = "function" | |
| function: Dict[str, Any] | |
| class ChatMessage(BaseModel): | |
| role: str | |
| content: Union[str, List[ContentPart], None] | |
| tool_calls: Optional[List[ToolCall]] = None | |
| tool_call_id: Optional[str] = None # 用于 tool 角色的消息 | |
| def get_content_text(self) -> str: | |
| """Extract text content from either string or content parts""" | |
| # Handle None content | |
| if self.content is None: | |
| logger.warning(f"Message with role '{self.role}' has None content") | |
| return "" | |
| if isinstance(self.content, str): | |
| return self.content | |
| elif isinstance(self.content, list): | |
| text_parts = [] | |
| for part in self.content: | |
| if isinstance(part, dict): | |
| if part.get("type") == "text" and "text" in part: | |
| text_parts.append(part.get("text", "")) | |
| elif part.get("type") == "tool_result" and "content" in part: | |
| text_parts.append(part.get("content", "")) | |
| elif hasattr(part, 'text') and part.text: | |
| text_parts.append(part.text) | |
| return "".join(text_parts) | |
| else: | |
| logger.warning(f"Unexpected content type: {type(self.content)}") | |
| return str(self.content) if self.content else "" | |
| class Function(BaseModel): | |
| name: str | |
| description: Optional[str] = None | |
| parameters: Optional[Dict[str, Any]] = None | |
| class Tool(BaseModel): | |
| type: str = "function" | |
| function: Function | |
| class ChatCompletionRequest(BaseModel): | |
| model: str | |
| messages: List[ChatMessage] | |
| temperature: Optional[float] = 0.7 | |
| max_tokens: Optional[int] = 4000 | |
| stream: Optional[bool] = False | |
| top_p: Optional[float] = 1.0 | |
| frequency_penalty: Optional[float] = 0.0 | |
| presence_penalty: Optional[float] = 0.0 | |
| stop: Optional[Union[str, List[str]]] = None | |
| user: Optional[str] = None | |
| tools: Optional[List[Tool]] = None | |
| tool_choice: Optional[Union[str, Dict[str, Any]]] = "auto" | |
| class Usage(BaseModel): | |
| prompt_tokens: int | |
| completion_tokens: int | |
| total_tokens: int | |
| prompt_tokens_details: Optional[Dict[str, int]] = Field(default_factory=lambda: {"cached_tokens": 0}) | |
| completion_tokens_details: Optional[Dict[str, int]] = Field(default_factory=lambda: {"reasoning_tokens": 0}) | |
| class ResponseMessage(BaseModel): | |
| role: str | |
| content: Optional[str] = None | |
| tool_calls: Optional[List[ToolCall]] = None | |
| class Choice(BaseModel): | |
| index: int | |
| message: ResponseMessage | |
| logprobs: Optional[Any] = None | |
| finish_reason: str | |
| class StreamChoice(BaseModel): | |
| index: int | |
| delta: Dict[str, Any] | |
| logprobs: Optional[Any] = None | |
| finish_reason: Optional[str] = None | |
| class ChatCompletionResponse(BaseModel): | |
| id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4()}") | |
| object: str = "chat.completion" | |
| created: int = Field(default_factory=lambda: int(time.time())) | |
| model: str | |
| system_fingerprint: Optional[str] = "fp_ki2api_v3" | |
| choices: List[Choice] | |
| usage: Usage | |
| class ChatCompletionStreamResponse(BaseModel): | |
| id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4()}") | |
| object: str = "chat.completion.chunk" | |
| created: int = Field(default_factory=lambda: int(time.time())) | |
| model: str | |
| system_fingerprint: Optional[str] = "fp_ki2api_v3" | |
| choices: List[StreamChoice] | |
| usage: Optional[Usage] = None | |
| class ErrorResponse(BaseModel): | |
| error: Dict[str, Any] | |
| # Authentication | |
| async def verify_api_key(authorization: str = Header(None)): | |
| if not authorization: | |
| raise HTTPException( | |
| status_code=401, | |
| detail={ | |
| "error": { | |
| "message": "You didn't provide an API key.", | |
| "type": "invalid_request_error", | |
| "param": None, | |
| "code": "invalid_api_key" | |
| } | |
| } | |
| ) | |
| if not authorization.startswith("Bearer "): | |
| raise HTTPException( | |
| status_code=401, | |
| detail={ | |
| "error": { | |
| "message": "Invalid API key format. Expected 'Bearer <key>'", | |
| "type": "invalid_request_error", | |
| "param": None, | |
| "code": "invalid_api_key" | |
| } | |
| } | |
| ) | |
| api_key = authorization.replace("Bearer ", "") | |
| if api_key != API_KEY: | |
| raise HTTPException( | |
| status_code=401, | |
| detail={ | |
| "error": { | |
| "message": "Invalid API key provided", | |
| "type": "invalid_request_error", | |
| "param": None, | |
| "code": "invalid_api_key" | |
| } | |
| } | |
| ) | |
| return api_key | |
| # Token management | |
| class TokenManager: | |
| def __init__(self): | |
| self.access_token = KIRO_ACCESS_TOKEN | |
| self.refresh_token = KIRO_REFRESH_TOKEN | |
| self.refresh_url = "https://prod.us-east-1.auth.desktop.kiro.dev/refreshToken" | |
| self.last_refresh_time = 0 | |
| self.refresh_lock = asyncio.Lock() | |
| async def refresh_tokens(self): | |
| """刷新token,使用锁防止并发刷新请求""" | |
| if not self.refresh_token: | |
| logger.error("没有刷新token,无法刷新访问token") | |
| return None | |
| async with self.refresh_lock: | |
| # 检查是否在短时间内已经刷新过 | |
| current_time = time.time() | |
| if current_time - self.last_refresh_time < 5: | |
| logger.info("最近已刷新token,使用现有token") | |
| return self.access_token | |
| try: | |
| logger.info("开始刷新token...") | |
| async with httpx.AsyncClient() as client: | |
| response = await client.post( | |
| self.refresh_url, | |
| json={"refreshToken": self.refresh_token}, | |
| timeout=30 | |
| ) | |
| response.raise_for_status() | |
| data = response.json() | |
| if "accessToken" not in data: | |
| logger.error(f"刷新token响应中没有accessToken: {data}") | |
| return None | |
| self.access_token = data.get("accessToken") | |
| self.last_refresh_time = current_time | |
| logger.info("token刷新成功") | |
| # 更新环境变量 | |
| os.environ["KIRO_ACCESS_TOKEN"] = self.access_token | |
| return self.access_token | |
| except Exception as e: | |
| logger.error(f"token刷新失败: {str(e)}") | |
| return None | |
| def get_token(self): | |
| return self.access_token | |
| token_manager = TokenManager() | |
| # XML Tool Call Parser (from version 1) | |
| def parse_xml_tool_calls(response_text: str) -> Optional[List[ToolCall]]: | |
| """解析CodeWhisperer返回的XML格式工具调用,转换为OpenAI格式""" | |
| if not response_text: | |
| return None | |
| tool_calls = [] | |
| logger.info(f"🔍 开始解析XML工具调用,响应文本长度: {len(response_text)}") | |
| # 方法1: 解析 <tool_use> 标签格式 | |
| tool_use_pattern = r'<tool_use>\s*<tool_name>([^<]+)</tool_name>\s*<tool_parameter_name>([^<]+)</tool_parameter_name>\s*<tool_parameter_value>([^<]*)</tool_parameter_value>\s*</tool_use>' | |
| matches = re.finditer(tool_use_pattern, response_text, re.DOTALL | re.IGNORECASE) | |
| for match in matches: | |
| function_name = match.group(1).strip() | |
| param_name = match.group(2).strip() | |
| param_value = match.group(3).strip() | |
| arguments = {param_name: param_value} | |
| tool_call_id = f"call_{uuid.uuid4().hex[:8]}" | |
| tool_call = ToolCall( | |
| id=tool_call_id, | |
| type="function", | |
| function={ | |
| "name": function_name, | |
| "arguments": json.dumps(arguments, ensure_ascii=False) | |
| } | |
| ) | |
| tool_calls.append(tool_call) | |
| logger.info(f"✅ 解析到工具调用: {function_name} with {param_name}={param_value}") | |
| # 方法2: 解析简单的 <tool_name> 格式 | |
| if not tool_calls: | |
| simple_pattern = r'<tool_name>([^<]+)</tool_name>\s*<tool_parameter_name>([^<]+)</tool_parameter_name>\s*<tool_parameter_value>([^<]*)</tool_parameter_value>' | |
| matches = re.finditer(simple_pattern, response_text, re.DOTALL | re.IGNORECASE) | |
| for match in matches: | |
| function_name = match.group(1).strip() | |
| param_name = match.group(2).strip() | |
| param_value = match.group(3).strip() | |
| arguments = {param_name: param_value} | |
| tool_call_id = f"call_{uuid.uuid4().hex[:8]}" | |
| tool_call = ToolCall( | |
| id=tool_call_id, | |
| type="function", | |
| function={ | |
| "name": function_name, | |
| "arguments": json.dumps(arguments, ensure_ascii=False) | |
| } | |
| ) | |
| tool_calls.append(tool_call) | |
| logger.info(f"✅ 解析到简单工具调用: {function_name} with {param_name}={param_value}") | |
| # 方法3: 解析只有工具名的情况 | |
| if not tool_calls: | |
| name_only_pattern = r'<tool_name>([^<]+)</tool_name>' | |
| matches = re.finditer(name_only_pattern, response_text, re.IGNORECASE) | |
| for match in matches: | |
| function_name = match.group(1).strip() | |
| tool_call_id = f"call_{uuid.uuid4().hex[:8]}" | |
| tool_call = ToolCall( | |
| id=tool_call_id, | |
| type="function", | |
| function={ | |
| "name": function_name, | |
| "arguments": "{}" | |
| } | |
| ) | |
| tool_calls.append(tool_call) | |
| logger.info(f"✅ 解析到无参数工具调用: {function_name}") | |
| if tool_calls: | |
| logger.info(f"🎉 总共解析出 {len(tool_calls)} 个工具调用") | |
| return tool_calls | |
| else: | |
| logger.info("❌ 未发现任何XML格式的工具调用") | |
| return None | |
| def find_matching_bracket(text: str, start_pos: int) -> int: | |
| """找到匹配的结束括号位置""" | |
| logger.info(f"🔧 FIND BRACKET: text length={len(text)}, start_pos={start_pos}") | |
| logger.info(f"🔧 FIND BRACKET: First 100 chars: >>>{text[:100]}<<<") | |
| if not text or start_pos >= len(text) or text[start_pos] != '[': | |
| logger.info(f"🔧 FIND BRACKET: Early return -1, text[start_pos]={text[start_pos] if start_pos < len(text) else 'OOB'}") | |
| return -1 | |
| bracket_count = 1 | |
| in_string = False | |
| escape_next = False | |
| logger.info(f"🔧 FIND BRACKET: Starting search from position {start_pos + 1}") | |
| for i in range(start_pos + 1, len(text)): | |
| char = text[i] | |
| if escape_next: | |
| escape_next = False | |
| continue | |
| if char == '\\' and in_string: | |
| escape_next = True | |
| continue | |
| if char == '"' and not escape_next: | |
| in_string = not in_string | |
| logger.info(f"🔧 FIND BRACKET: Toggle string mode at {i}, in_string={in_string}") | |
| continue | |
| if not in_string: | |
| if char == '[': | |
| bracket_count += 1 | |
| logger.info(f"🔧 FIND BRACKET: [ at {i}, bracket_count={bracket_count}") | |
| elif char == ']': | |
| bracket_count -= 1 | |
| logger.info(f"🔧 FIND BRACKET: ] at {i}, bracket_count={bracket_count}") | |
| if bracket_count == 0: # 只检查方括号匹配,不管花括号 | |
| logger.info(f"🔧 FIND BRACKET: Found matching ] at position {i}") | |
| logger.info(f"🔧 FIND BRACKET: Complete match: >>>{text[start_pos:i+1]}<<<") | |
| return i | |
| logger.info(f"🔧 FIND BRACKET: No matching bracket found, returning -1") | |
| logger.info(f"🔧 FIND BRACKET: Final bracket_count={bracket_count}") | |
| return -1 | |
| def parse_single_tool_call_professional(tool_call_text: str) -> Optional[ToolCall]: | |
| """专业的工具调用解析器 - 使用json_repair库""" | |
| logger.info(f"🔧 开始解析工具调用文本 (长度: {len(tool_call_text)})") | |
| # 步骤1: 提取函数名 | |
| name_pattern = r'\[Called\s+(\w+)\s+with\s+args:' | |
| name_match = re.search(name_pattern, tool_call_text, re.IGNORECASE) | |
| if not name_match: | |
| logger.warning("⚠️ 无法从文本中提取函数名") | |
| return None | |
| function_name = name_match.group(1).strip() | |
| logger.info(f"✅ 提取到函数名: {function_name}") | |
| # 步骤2: 提取JSON参数部分 | |
| # 找到 "with args:" 之后的位置 | |
| args_start_marker = "with args:" | |
| args_start_pos = tool_call_text.lower().find(args_start_marker.lower()) | |
| if args_start_pos == -1: | |
| logger.error("❌ 找不到 'with args:' 标记") | |
| return None | |
| # 从 "with args:" 后开始 | |
| args_start = args_start_pos + len(args_start_marker) | |
| # 找到最后的 ']' | |
| args_end = tool_call_text.rfind(']') | |
| if args_end <= args_start: | |
| logger.error("❌ 找不到结束的 ']'") | |
| return None | |
| # 提取可能包含JSON的部分 | |
| json_candidate = tool_call_text[args_start:args_end].strip() | |
| logger.info(f"📝 提取的JSON候选文本长度: {len(json_candidate)}") | |
| # 步骤3: 修复并解析JSON | |
| try: | |
| # 使用json_repair修复可能损坏的JSON | |
| repaired_json = repair_json(json_candidate) | |
| logger.info(f"🔧 JSON修复完成,修复后长度: {len(repaired_json)}") | |
| # 解析修复后的JSON | |
| arguments = json.loads(repaired_json) | |
| # 验证解析结果是字典 | |
| if not isinstance(arguments, dict): | |
| logger.error(f"❌ 解析结果不是字典类型: {type(arguments)}") | |
| return None | |
| # 创建工具调用对象 | |
| tool_call_id = f"call_{uuid.uuid4().hex[:8]}" | |
| tool_call = ToolCall( | |
| id=tool_call_id, | |
| type="function", | |
| function={ | |
| "name": function_name, | |
| "arguments": json.dumps(arguments, ensure_ascii=False) | |
| } | |
| ) | |
| logger.info(f"✅ 成功创建工具调用: {function_name} (参数键: {list(arguments.keys())})") | |
| return tool_call | |
| except Exception as e: | |
| logger.error(f"❌ JSON修复/解析失败: {type(e).__name__}: {str(e)}") | |
| # 备用方案:尝试更激进的修复 | |
| try: | |
| # 查找第一个 { 和最后一个 } | |
| first_brace = json_candidate.find('{') | |
| last_brace = json_candidate.rfind('}') | |
| if first_brace != -1 and last_brace > first_brace: | |
| core_json = json_candidate[first_brace:last_brace + 1] | |
| # 再次尝试修复 | |
| repaired_core = repair_json(core_json) | |
| arguments = json.loads(repaired_core) | |
| if isinstance(arguments, dict): | |
| tool_call_id = f"call_{uuid.uuid4().hex[:8]}" | |
| tool_call = ToolCall( | |
| id=tool_call_id, | |
| type="function", | |
| function={ | |
| "name": function_name, | |
| "arguments": json.dumps(arguments, ensure_ascii=False) | |
| } | |
| ) | |
| logger.info(f"✅ 备用方案成功: {function_name}") | |
| return tool_call | |
| except Exception as backup_error: | |
| logger.error(f"❌ 备用方案也失败了: {backup_error}") | |
| return None | |
| def parse_bracket_tool_calls_professional(response_text: str) -> Optional[List[ToolCall]]: | |
| """专业的批量工具调用解析器""" | |
| if not response_text or "[Called" not in response_text: | |
| logger.info("📭 响应文本中没有工具调用标记") | |
| return None | |
| tool_calls = [] | |
| errors = [] | |
| # 方法1: 使用改进的分割方法 | |
| try: | |
| # 找到所有 [Called 的位置 | |
| call_positions = [] | |
| start = 0 | |
| while True: | |
| pos = response_text.find("[Called", start) | |
| if pos == -1: | |
| break | |
| call_positions.append(pos) | |
| start = pos + 1 | |
| logger.info(f"🔍 找到 {len(call_positions)} 个潜在的工具调用") | |
| for i, start_pos in enumerate(call_positions): | |
| # 确定这个工具调用的结束位置 | |
| # 可能是下一个 [Called 的位置,或者文本结束 | |
| if i + 1 < len(call_positions): | |
| end_search_limit = call_positions[i + 1] | |
| else: | |
| end_search_limit = len(response_text) | |
| # 在限定范围内查找结束的 ] | |
| segment = response_text[start_pos:end_search_limit] | |
| # 查找匹配的结束括号 | |
| bracket_count = 0 | |
| end_pos = -1 | |
| for j, char in enumerate(segment): | |
| if char == '[': | |
| bracket_count += 1 | |
| elif char == ']': | |
| bracket_count -= 1 | |
| if bracket_count == 0: | |
| end_pos = start_pos + j | |
| break | |
| if end_pos == -1: | |
| # 如果没找到匹配的括号,尝试找最后一个 ] | |
| last_bracket = segment.rfind(']') | |
| if last_bracket != -1: | |
| end_pos = start_pos + last_bracket | |
| else: | |
| logger.warning(f"⚠️ 工具调用 {i+1} 没有找到结束括号") | |
| continue | |
| # 提取工具调用文本 | |
| tool_call_text = response_text[start_pos:end_pos + 1] | |
| logger.info(f"📋 提取工具调用 {i+1}, 长度: {len(tool_call_text)}") | |
| # 解析单个工具调用 | |
| parsed_call = parse_single_tool_call_professional(tool_call_text) | |
| if parsed_call: | |
| tool_calls.append(parsed_call) | |
| else: | |
| errors.append(f"工具调用 {i+1} 解析失败") | |
| except Exception as e: | |
| logger.error(f"❌ 批量解析过程出错: {type(e).__name__}: {str(e)}") | |
| import traceback | |
| traceback.print_exc() | |
| # 记录结果 | |
| if tool_calls: | |
| logger.info(f"🎉 成功解析 {len(tool_calls)} 个工具调用") | |
| for tc in tool_calls: | |
| logger.info(f" ✓ {tc.function['name']} (ID: {tc.id})") | |
| if errors: | |
| logger.warning(f"⚠️ 有 {len(errors)} 个解析失败:") | |
| for error in errors: | |
| logger.warning(f" ✗ {error}") | |
| return tool_calls if tool_calls else None | |
| # 为了确保兼容性,也更新原来的函数名 | |
| def parse_bracket_tool_calls(response_text: str) -> Optional[List[ToolCall]]: | |
| """向后兼容的函数名""" | |
| return parse_bracket_tool_calls_professional(response_text) | |
| def parse_single_tool_call(tool_call_text: str) -> Optional[ToolCall]: | |
| """向后兼容的函数名""" | |
| return parse_single_tool_call_professional(tool_call_text) | |
| # Add deduplication function | |
| def deduplicate_tool_calls(tool_calls: List[Union[Dict, ToolCall]]) -> List[ToolCall]: | |
| """Deduplicate tool calls based on function name and arguments""" | |
| seen = set() | |
| unique_tool_calls = [] | |
| for tool_call in tool_calls: | |
| # Convert to ToolCall if it's a dict | |
| if isinstance(tool_call, dict): | |
| tc = ToolCall( | |
| id=tool_call.get("id", f"call_{uuid.uuid4().hex[:8]}"), | |
| type=tool_call.get("type", "function"), | |
| function=tool_call.get("function", {}) | |
| ) | |
| else: | |
| tc = tool_call | |
| # Create unique key based on function name and arguments | |
| key = ( | |
| tc.function.get("name", ""), | |
| tc.function.get("arguments", "") | |
| ) | |
| if key not in seen: | |
| seen.add(key) | |
| unique_tool_calls.append(tc) | |
| else: | |
| logger.info(f"🔄 Skipping duplicate tool call: {tc.function.get('name', 'unknown')}") | |
| return unique_tool_calls | |
| def build_codewhisperer_request(request: ChatCompletionRequest): | |
| codewhisperer_model = MODEL_MAP.get(request.model, MODEL_MAP[DEFAULT_MODEL]) | |
| conversation_id = str(uuid.uuid4()) | |
| # Extract system prompt and user messages | |
| system_prompt = "" | |
| conversation_messages = [] | |
| for msg in request.messages: | |
| if msg.role == "system": | |
| system_prompt = msg.get_content_text() | |
| elif msg.role in ["user", "assistant", "tool"]: | |
| conversation_messages.append(msg) | |
| if not conversation_messages: | |
| raise HTTPException( | |
| status_code=400, | |
| detail={ | |
| "error": { | |
| "message": "No conversation messages found", | |
| "type": "invalid_request_error", | |
| "param": "messages", | |
| "code": "invalid_request" | |
| } | |
| } | |
| ) | |
| # Build history - only include user/assistant pairs | |
| history = [] | |
| # Process history messages (all except the last one) | |
| if len(conversation_messages) > 1: | |
| history_messages = conversation_messages[:-1] | |
| # Build user messages list (combining tool results with user messages) | |
| processed_messages = [] | |
| i = 0 | |
| while i < len(history_messages): | |
| msg = history_messages[i] | |
| if msg.role == "user": | |
| content = msg.get_content_text() or "Continue" | |
| processed_messages.append(("user", content)) | |
| i += 1 | |
| elif msg.role == "assistant": | |
| # Check if this assistant message contains tool calls | |
| if hasattr(msg, 'tool_calls') and msg.tool_calls: | |
| # Build a description of the tool calls | |
| tool_descriptions = [] | |
| for tc in msg.tool_calls: | |
| func_name = tc.function.get("name", "unknown") if isinstance(tc.function, dict) else "unknown" | |
| args = tc.function.get("arguments", "{}") if isinstance(tc.function, dict) else "{}" | |
| tool_descriptions.append(f"[Called {func_name} with args: {args}]") | |
| content = " ".join(tool_descriptions) | |
| logger.info(f"📌 Processing assistant message with tool calls: {content}") | |
| else: | |
| content = msg.get_content_text() or "I understand." | |
| processed_messages.append(("assistant", content)) | |
| i += 1 | |
| elif msg.role == "tool": | |
| # Combine tool results into the next user message | |
| tool_content = msg.get_content_text() or "[Tool executed]" | |
| tool_call_id = getattr(msg, 'tool_call_id', 'unknown') | |
| # Format tool result with ID for tracking | |
| formatted_tool_result = f"[Tool result for {tool_call_id}]: {tool_content}" | |
| # Look ahead to see if there's a user message | |
| if i + 1 < len(history_messages) and history_messages[i + 1].role == "user": | |
| user_content = history_messages[i + 1].get_content_text() or "" | |
| combined_content = f"{formatted_tool_result}\n{user_content}".strip() | |
| processed_messages.append(("user", combined_content)) | |
| i += 2 | |
| else: | |
| # Tool result without following user message - add as user message | |
| processed_messages.append(("user", formatted_tool_result)) | |
| i += 1 | |
| else: | |
| i += 1 | |
| # Build history pairs | |
| i = 0 | |
| while i < len(processed_messages): | |
| role, content = processed_messages[i] | |
| if role == "user": | |
| history.append({ | |
| "userInputMessage": { | |
| "content": content, | |
| "modelId": codewhisperer_model, | |
| "origin": "AI_EDITOR" | |
| } | |
| }) | |
| # Look for assistant response | |
| if i + 1 < len(processed_messages) and processed_messages[i + 1][0] == "assistant": | |
| _, assistant_content = processed_messages[i + 1] | |
| history.append({ | |
| "assistantResponseMessage": { | |
| "content": assistant_content | |
| } | |
| }) | |
| i += 2 | |
| else: | |
| # No assistant response, add a placeholder | |
| history.append({ | |
| "assistantResponseMessage": { | |
| "content": "I understand." | |
| } | |
| }) | |
| i += 1 | |
| elif role == "assistant": | |
| # Orphaned assistant message | |
| history.append({ | |
| "userInputMessage": { | |
| "content": "Continue", | |
| "modelId": codewhisperer_model, | |
| "origin": "AI_EDITOR" | |
| } | |
| }) | |
| history.append({ | |
| "assistantResponseMessage": { | |
| "content": content | |
| } | |
| }) | |
| i += 1 | |
| else: | |
| i += 1 | |
| # Build current message | |
| current_message = conversation_messages[-1] | |
| # Handle images in the last message | |
| images = [] | |
| if isinstance(current_message.content, list): | |
| for part in current_message.content: | |
| if part.type == "image_url" and part.image_url: | |
| try: | |
| # 记录原始 URL 的前 50 个字符,用于调试 | |
| logger.info(f"🔍 处理图片 URL: {part.image_url.url[:50]}...") | |
| # 检查 URL 格式是否正确 | |
| if not part.image_url.url.startswith("data:image/"): | |
| logger.error(f"❌ 图片 URL 格式不正确,应该以 'data:image/' 开头") | |
| continue | |
| # Correctly parse the data URI | |
| # format: data:image/jpeg;base64,{base64_string} | |
| header, encoded_data = part.image_url.url.split(",", 1) | |
| # Correctly parse the image format from the mime type | |
| # "data:image/jpeg;base64" -> "jpeg" | |
| # Use regex to reliably extract image format, e.g., "jpeg" from "data:image/jpeg;base64" | |
| match = re.search(r'image/(\w+)', header) | |
| if match: | |
| image_format = match.group(1) | |
| # 验证 Base64 编码是否有效 | |
| try: | |
| base64.b64decode(encoded_data) | |
| logger.info("✅ Base64 编码验证通过") | |
| except Exception as e: | |
| logger.error(f"❌ Base64 编码无效: {e}") | |
| continue | |
| images.append({ | |
| "format": image_format, | |
| "source": {"bytes": encoded_data} | |
| }) | |
| logger.info(f"🖼️ 成功处理图片,格式: {image_format}, 大小: {len(encoded_data)} 字符") | |
| else: | |
| logger.warning(f"⚠️ 无法从头部确定图片格式: {header}") | |
| except Exception as e: | |
| logger.error(f"❌ 处理图片 URL 失败: {str(e)}") | |
| current_content = current_message.get_content_text() | |
| # Handle different roles for current message | |
| if current_message.role == "tool": | |
| # For tool results, format them properly and mark as completed | |
| tool_result = current_content or '[Tool executed]' | |
| tool_call_id = getattr(current_message, 'tool_call_id', 'unknown') | |
| current_content = f"[Tool execution completed for {tool_call_id}]: {tool_result}" | |
| # Check if this tool result follows a tool call in history | |
| if len(conversation_messages) > 1: | |
| prev_message = conversation_messages[-2] | |
| if prev_message.role == "assistant" and hasattr(prev_message, 'tool_calls') and prev_message.tool_calls: | |
| # Find the corresponding tool call | |
| for tc in prev_message.tool_calls: | |
| if tc.id == tool_call_id: | |
| func_name = tc.function.get("name", "unknown") if isinstance(tc.function, dict) else "unknown" | |
| current_content = f"[Completed execution of {func_name}]: {tool_result}" | |
| break | |
| elif current_message.role == "assistant": | |
| # If last message is from assistant with tool calls, format it appropriately | |
| if hasattr(current_message, 'tool_calls') and current_message.tool_calls: | |
| tool_descriptions = [] | |
| for tc in current_message.tool_calls: | |
| func_name = tc.function.get("name", "unknown") if isinstance(tc.function, dict) else "unknown" | |
| tool_descriptions.append(f"Continue after calling {func_name}") | |
| current_content = "; ".join(tool_descriptions) | |
| else: | |
| current_content = "Continue the conversation" | |
| # Ensure current message has content | |
| if not current_content: | |
| current_content = "Continue" | |
| # Add system prompt to current message | |
| if system_prompt: | |
| current_content = f"{system_prompt}\n\n{current_content}" | |
| # Build request | |
| codewhisperer_request = { | |
| "profileArn": PROFILE_ARN, | |
| "conversationState": { | |
| "chatTriggerType": "MANUAL", | |
| "conversationId": conversation_id, | |
| "currentMessage": { | |
| "userInputMessage": { | |
| "content": current_content, | |
| "modelId": codewhisperer_model, | |
| "origin": "AI_EDITOR" | |
| } | |
| }, | |
| "history": history | |
| } | |
| } | |
| # Add context for tools | |
| user_input_message_context = {} | |
| if request.tools: | |
| user_input_message_context["tools"] = [ | |
| { | |
| "toolSpecification": { | |
| "name": tool.function.name, | |
| "description": tool.function.description or "", | |
| "inputSchema": {"json": tool.function.parameters or {}} | |
| } | |
| } for tool in request.tools | |
| ] | |
| # 根据文档,images 应该是 userInputMessage 的直接子字段,而不是在 userInputMessageContext 中 | |
| if images: | |
| # 直接添加到 userInputMessage 中 | |
| codewhisperer_request["conversationState"]["currentMessage"]["userInputMessage"]["images"] = images | |
| logger.info(f"📊 添加了 {len(images)} 个图片到 userInputMessage 中") | |
| for i, img in enumerate(images): | |
| logger.info(f" - 图片 {i+1}: 格式={img['format']}, 大小={len(img['source']['bytes'])} 字符") | |
| # 记录图片数据的前20个字符,用于调试 | |
| logger.info(f" - 图片数据前20字符: {img['source']['bytes'][:20]}...") | |
| logger.info(f"✅ 成功添加 images 到 userInputMessage 中") | |
| if user_input_message_context: | |
| codewhisperer_request["conversationState"]["currentMessage"]["userInputMessage"]["userInputMessageContext"] = user_input_message_context | |
| logger.info(f"✅ 成功添加 userInputMessageContext 到请求中") | |
| # 创建一个用于日志记录的请求副本,避免记录完整的图片数据 | |
| log_request = copy.deepcopy(codewhisperer_request) | |
| # 检查 images 是否在 userInputMessage 中 | |
| if "images" in log_request.get("conversationState", {}).get("currentMessage", {}).get("userInputMessage", {}): | |
| for img in log_request["conversationState"]["currentMessage"]["userInputMessage"]["images"]: | |
| if "bytes" in img.get("source", {}): | |
| img["source"]["bytes"] = img["source"]["bytes"][:20] + "..." # 只记录前20个字符 | |
| logger.info(f"🔄 COMPLETE CODEWHISPERER REQUEST: {json.dumps(log_request, indent=2)}") | |
| return codewhisperer_request | |
| # AWS Event Stream Parser (from version 2) | |
| class CodeWhispererStreamParser: | |
| def __init__(self): | |
| self.buffer = b'' | |
| self.error_count = 0 | |
| self.max_errors = 5 | |
| def parse(self, chunk: bytes) -> List[Dict[str, Any]]: | |
| """解析AWS事件流格式的数据块""" | |
| self.buffer += chunk | |
| logger.debug(f"Parser received {len(chunk)} bytes. Buffer size: {len(self.buffer)}") | |
| events = [] | |
| if len(self.buffer) < 12: | |
| return [] | |
| while len(self.buffer) >= 12: | |
| try: | |
| header_bytes = self.buffer[0:8] | |
| total_len, header_len = struct.unpack('>II', header_bytes) | |
| # 安全检查 | |
| if total_len > 2000000 or header_len > 2000000: | |
| logger.error(f"Unreasonable header values: total_len={total_len}, header_len={header_len}") | |
| self.buffer = self.buffer[8:] | |
| self.error_count += 1 | |
| if self.error_count > self.max_errors: | |
| logger.error("Too many parsing errors, clearing buffer") | |
| self.buffer = b'' | |
| continue | |
| # 等待完整帧 | |
| if len(self.buffer) < total_len: | |
| break | |
| # 提取完整帧 | |
| frame = self.buffer[:total_len] | |
| self.buffer = self.buffer[total_len:] | |
| # 提取有效载荷 | |
| payload_start = 8 + header_len | |
| payload_end = total_len - 4 # 减去尾部CRC | |
| if payload_start >= payload_end or payload_end > len(frame): | |
| logger.error(f"Invalid payload bounds") | |
| continue | |
| payload = frame[payload_start:payload_end] | |
| # 解码有效载荷 | |
| try: | |
| payload_str = payload.decode('utf-8', errors='ignore') | |
| # 尝试解析JSON | |
| json_start_index = payload_str.find('{') | |
| if json_start_index != -1: | |
| json_payload = payload_str[json_start_index:] | |
| event_data = json.loads(json_payload) | |
| events.append(event_data) | |
| logger.debug(f"Successfully parsed event: {event_data}") | |
| except json.JSONDecodeError as e: | |
| logger.error(f"JSON decode error: {e}") | |
| continue | |
| except struct.error as e: | |
| logger.error(f"Struct unpack error: {e}") | |
| self.buffer = self.buffer[1:] | |
| self.error_count += 1 | |
| if self.error_count > self.max_errors: | |
| logger.error("Too many parsing errors, clearing buffer") | |
| self.buffer = b'' | |
| except Exception as e: | |
| logger.error(f"Unexpected error during parsing: {str(e)}") | |
| self.buffer = self.buffer[1:] | |
| self.error_count += 1 | |
| if self.error_count > self.max_errors: | |
| logger.error("Too many parsing errors, clearing buffer") | |
| self.buffer = b'' | |
| if events: | |
| self.error_count = 0 | |
| return events | |
| # Simple fallback parser for basic responses | |
| class SimpleResponseParser: | |
| def parse_event_stream_to_json(raw_data: bytes) -> Dict[str, Any]: | |
| """Simple parser for fallback (from version 1)""" | |
| try: | |
| if isinstance(raw_data, bytes): | |
| raw_str = raw_data.decode('utf-8', errors='ignore') | |
| else: | |
| raw_str = str(raw_data) | |
| # Method 1: Look for JSON objects with content field | |
| json_pattern = r'\{[^{}]*"content"[^{}]*\}' | |
| matches = re.findall(json_pattern, raw_str, re.DOTALL) | |
| if matches: | |
| content_parts = [] | |
| for match in matches: | |
| try: | |
| data = json.loads(match) | |
| if 'content' in data and data['content']: | |
| content_parts.append(data['content']) | |
| except json.JSONDecodeError: | |
| continue | |
| if content_parts: | |
| full_content = ''.join(content_parts) | |
| return { | |
| "content": full_content, | |
| "tokens": len(full_content.split()) | |
| } | |
| # Method 2: Extract readable text | |
| clean_text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]', '', raw_str) | |
| clean_text = re.sub(r':event-type[^:]*:[^:]*:[^:]*:', '', clean_text) | |
| clean_text = re.sub(r':content-type[^:]*:[^:]*:[^:]*:', '', clean_text) | |
| meaningful_text = re.sub(r'[^\w\s\u4e00-\u9fff\u3040-\u309f\u30a0-\u30ff.,!?;:()"\'-]', '', clean_text) | |
| meaningful_text = re.sub(r'\s+', ' ', meaningful_text).strip() | |
| if meaningful_text and len(meaningful_text) > 5: | |
| return { | |
| "content": meaningful_text, | |
| "tokens": len(meaningful_text.split()) | |
| } | |
| return {"content": "No readable content found", "tokens": 0} | |
| except Exception as e: | |
| return {"content": f"Error parsing response: {str(e)}", "tokens": 0} | |
| # API call to CodeWhisperer | |
| async def call_kiro_api(request: ChatCompletionRequest): | |
| """Make API call to Kiro/CodeWhisperer with token refresh handling""" | |
| token = token_manager.get_token() | |
| if not token: | |
| raise HTTPException( | |
| status_code=401, | |
| detail={ | |
| "error": { | |
| "message": "No access token available", | |
| "type": "authentication_error", | |
| "param": None, | |
| "code": "invalid_api_key" | |
| } | |
| } | |
| ) | |
| request_data = build_codewhisperer_request(request) | |
| headers = { | |
| "Authorization": f"Bearer {token}", | |
| "Content-Type": "application/json", | |
| "Accept": "text/event-stream" if request.stream else "application/json" | |
| } | |
| try: | |
| async with httpx.AsyncClient() as client: | |
| response = await client.post( | |
| KIRO_BASE_URL, | |
| headers=headers, | |
| json=request_data, | |
| timeout=120 | |
| ) | |
| logger.info(f"📤 RESPONSE STATUS: {response.status_code}") | |
| if response.status_code == 403: | |
| logger.info("收到403响应,尝试刷新token...") | |
| new_token = await token_manager.refresh_tokens() | |
| if new_token: | |
| headers["Authorization"] = f"Bearer {new_token}" | |
| response = await client.post( | |
| KIRO_BASE_URL, | |
| headers=headers, | |
| json=request_data, | |
| timeout=120 | |
| ) | |
| logger.info(f"📤 RETRY RESPONSE STATUS: {response.status_code}") | |
| else: | |
| raise HTTPException(status_code=401, detail="Token refresh failed") | |
| if response.status_code == 429: | |
| raise HTTPException( | |
| status_code=429, | |
| detail={ | |
| "error": { | |
| "message": "Rate limit exceeded", | |
| "type": "rate_limit_error", | |
| "param": None, | |
| "code": "rate_limit_exceeded" | |
| } | |
| } | |
| ) | |
| response.raise_for_status() | |
| return response | |
| except httpx.HTTPStatusError as e: | |
| logger.error(f"HTTP ERROR: {e.response.status_code} - {e.response.text}") | |
| raise HTTPException( | |
| status_code=503, | |
| detail={ | |
| "error": { | |
| "message": f"API call failed: {str(e)}", | |
| "type": "api_error", | |
| "param": None, | |
| "code": "api_error" | |
| } | |
| } | |
| ) | |
| except Exception as e: | |
| logger.error(f"API call failed: {str(e)}") | |
| raise HTTPException( | |
| status_code=503, | |
| detail={ | |
| "error": { | |
| "message": f"API call failed: {str(e)}", | |
| "type": "api_error", | |
| "param": None, | |
| "code": "api_error" | |
| } | |
| } | |
| ) | |
| # Utility functions | |
| def estimate_tokens(text: str) -> int: | |
| """Rough token estimation""" | |
| return max(1, len(text) // 4) | |
| def create_usage_stats(prompt_text: str, completion_text: str) -> Usage: | |
| """Create usage statistics""" | |
| prompt_tokens = estimate_tokens(prompt_text) | |
| completion_tokens = estimate_tokens(completion_text) | |
| return Usage( | |
| prompt_tokens=prompt_tokens, | |
| completion_tokens=completion_tokens, | |
| total_tokens=prompt_tokens + completion_tokens | |
| ) | |
| # API endpoints | |
| async def list_models(api_key: str = Depends(verify_api_key)): | |
| """List available models""" | |
| return { | |
| "object": "list", | |
| "data": [ | |
| { | |
| "id": model_id, | |
| "object": "model", | |
| "created": int(time.time()), | |
| "owned_by": "ki2api" | |
| } | |
| for model_id in MODEL_MAP.keys() | |
| ] | |
| } | |
| async def create_chat_completion( | |
| request: ChatCompletionRequest, | |
| api_key: str = Depends(verify_api_key) | |
| ): | |
| """Create a chat completion""" | |
| logger.info(f"📥 COMPLETE REQUEST: {request.model_dump_json(indent=2)}") | |
| # Validate messages have content | |
| for i, msg in enumerate(request.messages): | |
| if msg.content is None and msg.role != "assistant": | |
| logger.warning(f"Message {i} with role '{msg.role}' has None content") | |
| if request.model not in MODEL_MAP: | |
| raise HTTPException( | |
| status_code=400, | |
| detail={ | |
| "error": { | |
| "message": f"The model '{request.model}' does not exist or you do not have access to it.", | |
| "type": "invalid_request_error", | |
| "param": "model", | |
| "code": "model_not_found" | |
| } | |
| } | |
| ) | |
| # 总是使用非流式响应,但根据请求类型返回不同格式 | |
| response = await create_non_streaming_response(request) | |
| if request.stream: | |
| # 将非流式响应转换为流式格式 | |
| return await convert_to_streaming_response(response) | |
| else: | |
| return response | |
| async def convert_to_streaming_response(response: ChatCompletionResponse): | |
| """将非流式响应转换为流式格式返回""" | |
| async def generate_stream(): | |
| # 使用原响应的ID和时间戳 | |
| response_id = response.id | |
| created = response.created | |
| model = response.model | |
| # 发送初始块 - role | |
| initial_chunk = ChatCompletionStreamResponse( | |
| id=response_id, | |
| model=model, | |
| created=created, | |
| choices=[StreamChoice( | |
| index=0, | |
| delta={"role": "assistant"}, | |
| finish_reason=None | |
| )] | |
| ) | |
| yield f"data: {initial_chunk.model_dump_json(exclude_none=True)}\n\n" | |
| # 获取响应消息 | |
| if response.choices and len(response.choices) > 0: | |
| message = response.choices[0].message | |
| # 如果有工具调用,发送工具调用 | |
| if message.tool_calls: | |
| for i, tool_call in enumerate(message.tool_calls): | |
| # 发送完整的工具调用作为一个块 | |
| tool_chunk = ChatCompletionStreamResponse( | |
| id=response_id, | |
| model=model, | |
| created=created, | |
| choices=[StreamChoice( | |
| index=0, | |
| delta={ | |
| "tool_calls": [{ | |
| "index": i, | |
| "id": tool_call.id, | |
| "type": tool_call.type, | |
| "function": tool_call.function | |
| }] | |
| }, | |
| finish_reason=None | |
| )] | |
| ) | |
| yield f"data: {tool_chunk.model_dump_json(exclude_none=True)}\n\n" | |
| # 如果有内容,分块发送内容 | |
| elif message.content: | |
| # 将内容分成较小的块以模拟流式传输 | |
| content = message.content | |
| chunk_size = 50 # 每个块的字符数 | |
| for i in range(0, len(content), chunk_size): | |
| chunk_text = content[i:i + chunk_size] | |
| content_chunk = ChatCompletionStreamResponse( | |
| id=response_id, | |
| model=model, | |
| created=created, | |
| choices=[StreamChoice( | |
| index=0, | |
| delta={"content": chunk_text}, | |
| finish_reason=None | |
| )] | |
| ) | |
| yield f"data: {content_chunk.model_dump_json(exclude_none=True)}\n\n" | |
| # 添加小延迟以模拟真实的流式传输 | |
| await asyncio.sleep(0.01) | |
| # 发送结束块 | |
| finish_reason = response.choices[0].finish_reason | |
| end_chunk = ChatCompletionStreamResponse( | |
| id=response_id, | |
| model=model, | |
| created=created, | |
| choices=[StreamChoice( | |
| index=0, | |
| delta={}, | |
| finish_reason=finish_reason | |
| )] | |
| ) | |
| yield f"data: {end_chunk.model_dump_json(exclude_none=True)}\n\n" | |
| # 发送流结束标记 | |
| yield "data: [DONE]\n\n" | |
| return StreamingResponse( | |
| generate_stream(), | |
| media_type="text/event-stream", | |
| headers={ | |
| "Cache-Control": "no-cache", | |
| "Connection": "keep-alive", | |
| "Content-Type": "text/event-stream" | |
| } | |
| ) | |
| async def create_non_streaming_response(request: ChatCompletionRequest): | |
| """ | |
| Handles non-streaming chat completion requests. | |
| It fetches the complete response from CodeWhisperer, parses it using | |
| CodeWhispererStreamParser, and constructs a single OpenAI-compatible | |
| ChatCompletionResponse. This version correctly handles tool calls by | |
| parsing both structured event data and bracket format in text. | |
| """ | |
| try: | |
| logger.info("🚀 开始非流式响应生成...") | |
| response = await call_kiro_api(request) | |
| # 添加详细的原始响应日志 | |
| logger.info(f"📤 CodeWhisperer响应状态码: {response.status_code}") | |
| logger.info(f"📤 响应头: {dict(response.headers)}") | |
| logger.info(f"📤 原始响应体长度: {len(response.content)} bytes") | |
| # 获取原始响应文本用于工具调用检测 | |
| raw_response_text = "" | |
| try: | |
| raw_response_text = response.content.decode('utf-8', errors='ignore') | |
| logger.info(f"🔍 原始响应文本长度: {len(raw_response_text)}") | |
| logger.info(f"🔍 原始响应预览(前1000字符): {raw_response_text[:1000]}") | |
| # 检查是否包含工具调用标记 | |
| if "[Called" in raw_response_text: | |
| logger.info("✅ 原始响应中发现 [Called 标记") | |
| called_positions = [m.start() for m in re.finditer(r'\[Called', raw_response_text)] | |
| logger.info(f"🎯 [Called 出现位置: {called_positions}") | |
| else: | |
| logger.info("❌ 原始响应中未发现 [Called 标记") | |
| except Exception as e: | |
| logger.error(f"❌ 解码原始响应失败: {e}") | |
| # 使用 CodeWhispererStreamParser 一次性解析整个响应体 | |
| parser = CodeWhispererStreamParser() | |
| events = parser.parse(response.content) | |
| full_response_text = "" | |
| tool_calls = [] | |
| current_tool_call_dict = None | |
| logger.info(f"🔄 解析到 {len(events)} 个事件,开始处理...") | |
| # 记录每个事件的详细信息 | |
| for i, event in enumerate(events): | |
| logger.info(f"📋 事件 {i}: {event}") | |
| for event in events: | |
| # 优先处理结构化工具调用事件 | |
| if "name" in event and "toolUseId" in event: | |
| logger.info(f"🔧 发现结构化工具调用事件: {event}") | |
| # 如果是新的工具调用,则初始化 | |
| if not current_tool_call_dict: | |
| current_tool_call_dict = { | |
| "id": event.get("toolUseId"), | |
| "type": "function", | |
| "function": { | |
| "name": event.get("name"), | |
| "arguments": "" | |
| } | |
| } | |
| logger.info(f"🆕 开始解析工具调用: {current_tool_call_dict['function']['name']}") | |
| # 累积参数 | |
| if "input" in event: | |
| current_tool_call_dict["function"]["arguments"] += event.get("input", "") | |
| logger.info(f"📝 累积参数: {event.get('input', '')}") | |
| # 工具调用结束 | |
| if event.get("stop"): | |
| logger.info(f"✅ 完成工具调用: {current_tool_call_dict['function']['name']}") | |
| # 验证并标准化参数为JSON字符串 | |
| try: | |
| args = json.loads(current_tool_call_dict["function"]["arguments"]) | |
| current_tool_call_dict["function"]["arguments"] = json.dumps(args, ensure_ascii=False) | |
| logger.info(f"✅ 工具调用参数验证成功") | |
| except json.JSONDecodeError as e: | |
| logger.warning(f"⚠️ 工具调用的参数不是有效的JSON: {current_tool_call_dict['function']['arguments']}") | |
| logger.warning(f"⚠️ JSON错误: {e}") | |
| tool_calls.append(ToolCall(**current_tool_call_dict)) | |
| current_tool_call_dict = None # 重置以备下一个 | |
| # 处理普通文本内容事件 | |
| elif "content" in event: | |
| content = event.get("content", "") | |
| full_response_text += content | |
| logger.info(f"📄 添加文本内容: {content[:100]}...") | |
| # 如果流在工具调用中间意外结束,也将其添加 | |
| if current_tool_call_dict: | |
| logger.warning("⚠️ 响应流在工具调用结束前终止,仍尝试添加。") | |
| tool_calls.append(ToolCall(**current_tool_call_dict)) | |
| logger.info(f"📊 事件处理完成 - 文本长度: {len(full_response_text)}, 结构化工具调用: {len(tool_calls)}") | |
| # 检查解析后文本中的 bracket 格式工具调用 | |
| logger.info("🔍 开始检查解析后文本中的bracket格式工具调用...") | |
| bracket_tool_calls = parse_bracket_tool_calls(full_response_text) | |
| if bracket_tool_calls: | |
| logger.info(f"✅ 在解析后文本中发现 {len(bracket_tool_calls)} 个 bracket 格式工具调用") | |
| tool_calls.extend(bracket_tool_calls) | |
| # 从响应文本中移除工具调用文本 | |
| for tc in bracket_tool_calls: | |
| # 构建精确的正则表达式来匹配这个特定的工具调用 | |
| func_name = tc.function.get("name", "unknown") | |
| # 转义函数名中的特殊字符 | |
| escaped_name = re.escape(func_name) | |
| # 匹配 [Called FunctionName with args: {...}] | |
| pattern = r'\[Called\s+' + escaped_name + r'\s+with\s+args:\s*\{[^}]*(?:\{[^}]*\}[^}]*)*\}\s*\]' | |
| full_response_text = re.sub(pattern, '', full_response_text, flags=re.DOTALL) | |
| # 清理多余的空白 | |
| full_response_text = re.sub(r'\s+', ' ', full_response_text).strip() | |
| # 关键修复:检查原始响应中的 bracket 格式工具调用 | |
| logger.info("🔍 开始检查原始响应中的bracket格式工具调用...") | |
| raw_bracket_tool_calls = parse_bracket_tool_calls(raw_response_text) | |
| if raw_bracket_tool_calls and isinstance(raw_bracket_tool_calls, list): | |
| logger.info(f"✅ 在原始响应中发现 {len(raw_bracket_tool_calls)} 个 bracket 格式工具调用") | |
| tool_calls.extend(raw_bracket_tool_calls) | |
| else: | |
| logger.info("❌ 原始响应中未发现bracket格式工具调用") | |
| # 去重工具调用 | |
| logger.info(f"🔄 去重前工具调用数量: {len(tool_calls)}") | |
| unique_tool_calls = deduplicate_tool_calls(tool_calls) | |
| logger.info(f"🔄 去重后工具调用数量: {len(unique_tool_calls)}") | |
| # 根据是否有工具调用来构建响应 | |
| if unique_tool_calls: | |
| logger.info(f"🔧 构建工具调用响应,包含 {len(unique_tool_calls)} 个工具调用") | |
| for i, tc in enumerate(unique_tool_calls): | |
| logger.info(f"🔧 工具调用 {i}: {tc.function.get('name', 'unknown')}") | |
| response_message = ResponseMessage( | |
| role="assistant", | |
| content=None, # OpenAI规范:当有tool_calls时,content必须为None | |
| tool_calls=unique_tool_calls | |
| ) | |
| finish_reason = "tool_calls" | |
| else: | |
| logger.info("📄 构建普通文本响应") | |
| # 如果没有工具调用,使用清理后的文本 | |
| content = full_response_text.strip() if full_response_text.strip() else "I understand." | |
| logger.info(f"📄 最终文本内容: {content[:200]}...") | |
| response_message = ResponseMessage( | |
| role="assistant", | |
| content=content | |
| ) | |
| finish_reason = "stop" | |
| choice = Choice( | |
| index=0, | |
| message=response_message, | |
| finish_reason=finish_reason | |
| ) | |
| usage = create_usage_stats( | |
| prompt_text=" ".join([msg.get_content_text() for msg in request.messages]), | |
| completion_text=full_response_text if not unique_tool_calls else "" | |
| ) | |
| chat_response = ChatCompletionResponse( | |
| model=request.model, | |
| choices=[choice], | |
| usage=usage | |
| ) | |
| logger.info(f"📤 最终非流式响应构建完成") | |
| logger.info(f"📤 响应类型: {'工具调用' if unique_tool_calls else '文本内容'}") | |
| logger.info(f"📤 完整响应: {chat_response.model_dump_json(indent=2, exclude_none=True)}") | |
| return chat_response | |
| except HTTPException: | |
| raise | |
| except Exception as e: | |
| logger.error(f"❌ 非流式响应处理出错: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| raise HTTPException( | |
| status_code=500, | |
| detail={ | |
| "error": { | |
| "message": f"Internal server error: {str(e)}", | |
| "type": "internal_server_error", | |
| "param": None, | |
| "code": "internal_error" | |
| } | |
| } | |
| ) | |
| async def create_streaming_response(request: ChatCompletionRequest): | |
| """ | |
| Handles streaming chat completion requests. | |
| This function iteratively processes the binary event stream from CodeWhisperer, | |
| parsing events on the fly. It maintains state to correctly identify and | |
| stream text content or tool calls in the OpenAI-compatible format. | |
| """ | |
| try: | |
| logger.info("开始流式响应生成...") | |
| response = await call_kiro_api(request) | |
| async def generate_stream(): | |
| response_id = f"chatcmpl-{uuid.uuid4()}" | |
| created = int(time.time()) | |
| parser = CodeWhispererStreamParser() | |
| # --- 状态变量 --- | |
| is_in_tool_call = False | |
| sent_role = False | |
| current_tool_call_index = 0 | |
| streamed_tool_calls_count = 0 | |
| content_buffer = "" # 用于累积文本内容 | |
| incomplete_tool_call = "" # 用于累积不完整的工具调用 | |
| async for chunk in response.aiter_bytes(): | |
| events = parser.parse(chunk) | |
| for event in events: | |
| # --- 处理结构化工具调用事件 --- | |
| if "name" in event and "toolUseId" in event: | |
| logger.info(f"🎯 STREAM: Found structured tool call event: {event}") | |
| # 开始一个新的工具调用 | |
| if not is_in_tool_call: | |
| is_in_tool_call = True | |
| # 发送工具调用开始的 chunk | |
| delta_start = { | |
| "tool_calls": [{ | |
| "index": current_tool_call_index, | |
| "id": event.get("toolUseId"), | |
| "type": "function", | |
| "function": {"name": event.get("name"), "arguments": ""} | |
| }] | |
| } | |
| # 如果是第一个数据块,需要包含 role: assistant | |
| if not sent_role: | |
| delta_start["role"] = "assistant" | |
| sent_role = True | |
| start_chunk = ChatCompletionStreamResponse( | |
| id=response_id, model=request.model, created=created, | |
| choices=[StreamChoice(index=0, delta=delta_start)] | |
| ) | |
| yield f"data: {start_chunk.model_dump_json(exclude_none=True)}\n\n" | |
| # 累积工具调用的参数 | |
| if "input" in event: | |
| arg_chunk_str = event.get("input", "") | |
| if arg_chunk_str: | |
| arg_chunk_delta = { | |
| "tool_calls": [{ | |
| "index": current_tool_call_index, | |
| "function": {"arguments": arg_chunk_str} | |
| }] | |
| } | |
| arg_chunk_resp = ChatCompletionStreamResponse( | |
| id=response_id, model=request.model, created=created, | |
| choices=[StreamChoice(index=0, delta=arg_chunk_delta)] | |
| ) | |
| yield f"data: {arg_chunk_resp.model_dump_json(exclude_none=True)}\n\n" | |
| # 结束一个工具调用 | |
| if event.get("stop"): | |
| is_in_tool_call = False | |
| current_tool_call_index += 1 | |
| streamed_tool_calls_count += 1 | |
| # --- 处理普通文本内容事件 --- | |
| elif "content" in event and not is_in_tool_call: | |
| content_text = event.get("content", "") | |
| if content_text: | |
| content_buffer += content_text | |
| logger.info(f"📝 STREAM DEBUG: Buffer updated. Length: {len(content_buffer)}. Content: >>>{content_buffer}<<<") | |
| logger.info(f"📝 STREAM DEBUG: incomplete_tool_call: >>>{incomplete_tool_call}<<<") | |
| # 处理bracket格式的工具调用 | |
| while True: | |
| # 查找 [Called 的开始位置 | |
| called_start = content_buffer.find("[Called") | |
| logger.info(f"🔍 BRACKET DEBUG: Searching for [Called in buffer (length={len(content_buffer)})") | |
| logger.info(f"🔍 BRACKET DEBUG: called_start={called_start}") | |
| logger.info(f"🔍 BRACKET DEBUG: Full buffer content: >>>{content_buffer}<<<") | |
| if called_start == -1: | |
| # 没有工具调用,发送所有内容 | |
| logger.info(f"🔍 BRACKET DEBUG: No [Called found, sending buffer as content") | |
| logger.info(f"🔍 BRACKET DEBUG: incomplete_tool_call status: {bool(incomplete_tool_call)}") | |
| if content_buffer and not incomplete_tool_call: | |
| delta_content = {"content": content_buffer} | |
| if not sent_role: | |
| delta_content["role"] = "assistant" | |
| sent_role = True | |
| logger.info(f"📤 STREAM: Sending content chunk: {delta_content}") | |
| content_chunk = ChatCompletionStreamResponse( | |
| id=response_id, model=request.model, created=created, | |
| choices=[StreamChoice(index=0, delta=delta_content)] | |
| ) | |
| yield f"data: {content_chunk.model_dump_json(exclude_none=True)}\n\n" | |
| content_buffer = "" | |
| break | |
| logger.info(f"🔍 BRACKET DEBUG: Found [Called at position {called_start}") | |
| # 发送 [Called 之前的文本 | |
| if called_start > 0: | |
| text_before = content_buffer[:called_start] | |
| logger.info(f"🔍 BRACKET DEBUG: Text before [Called: >>>{text_before}<<<") | |
| if text_before.strip(): | |
| delta_content = {"content": text_before} | |
| if not sent_role: | |
| delta_content["role"] = "assistant" | |
| sent_role = True | |
| content_chunk = ChatCompletionStreamResponse( | |
| id=response_id, model=request.model, created=created, | |
| choices=[StreamChoice(index=0, delta=delta_content)] | |
| ) | |
| yield f"data: {content_chunk.model_dump_json(exclude_none=True)}\n\n" | |
| # 查找对应的结束 ] | |
| remaining_text = content_buffer[called_start:] | |
| logger.info(f"🔍 BRACKET DEBUG: Looking for matching ] in: >>>{remaining_text[:100]}...<<<") | |
| bracket_end = find_matching_bracket(remaining_text, 0) | |
| logger.info(f"🔍 BRACKET DEBUG: bracket_end={bracket_end}") | |
| if bracket_end == -1: | |
| # 工具调用不完整,保留在缓冲区 | |
| logger.info(f"🔍 BRACKET DEBUG: Tool call incomplete, saving to incomplete_tool_call") | |
| logger.info(f"🔍 BRACKET DEBUG: Incomplete content: >>>{remaining_text}<<<") | |
| incomplete_tool_call = remaining_text | |
| content_buffer = "" | |
| break | |
| # 提取完整的工具调用 | |
| tool_call_text = remaining_text[:bracket_end + 1] | |
| logger.info(f"🔍 BRACKET DEBUG: Extracting tool call: >>>{tool_call_text}<<<") | |
| parsed_call = parse_single_tool_call(tool_call_text) | |
| logger.info(f"🔍 BRACKET DEBUG: Parsed call result: {parsed_call}") | |
| if parsed_call: | |
| # 发送工具调用 | |
| delta_tool = { | |
| "tool_calls": [{ | |
| "index": current_tool_call_index, | |
| "id": parsed_call.id, | |
| "type": "function", | |
| "function": { | |
| "name": parsed_call.function["name"], | |
| "arguments": parsed_call.function["arguments"] | |
| } | |
| }] | |
| } | |
| if not sent_role: | |
| delta_tool["role"] = "assistant" | |
| sent_role = True | |
| logger.info(f"📤 STREAM: Sending tool call chunk: {delta_tool}") | |
| tool_chunk = ChatCompletionStreamResponse( | |
| id=response_id, model=request.model, created=created, | |
| choices=[StreamChoice(index=0, delta=delta_tool)] | |
| ) | |
| yield f"data: {tool_chunk.model_dump_json(exclude_none=True)}\n\n" | |
| current_tool_call_index += 1 | |
| streamed_tool_calls_count += 1 | |
| else: | |
| logger.error(f"❌ BRACKET DEBUG: Failed to parse tool call") | |
| # 更新缓冲区 | |
| content_buffer = remaining_text[bracket_end + 1:] | |
| incomplete_tool_call = "" | |
| logger.info(f"🔍 BRACKET DEBUG: Updated buffer after tool call: >>>{content_buffer}<<<") | |
| # 处理剩余的内容 | |
| logger.info(f"📊 STREAM END: Processing remaining content") | |
| logger.info(f"📊 STREAM END: incomplete_tool_call: >>>{incomplete_tool_call}<<<") | |
| logger.info(f"📊 STREAM END: content_buffer: >>>{content_buffer}<<<") | |
| if incomplete_tool_call: | |
| # 尝试再次解析不完整的工具调用(可能现在已经完整了) | |
| logger.info(f"🔄 STREAM END: Attempting to parse incomplete tool call") | |
| content_buffer = incomplete_tool_call + content_buffer | |
| incomplete_tool_call = "" | |
| # 重复上面的解析逻辑 | |
| called_start = content_buffer.find("[Called") | |
| if called_start == 0: | |
| bracket_end = find_matching_bracket(content_buffer, 0) | |
| logger.info(f"🔄 STREAM END: bracket_end for incomplete={bracket_end}") | |
| if bracket_end != -1: | |
| tool_call_text = content_buffer[:bracket_end + 1] | |
| parsed_call = parse_single_tool_call(tool_call_text) | |
| if parsed_call: | |
| delta_tool = { | |
| "tool_calls": [{ | |
| "index": current_tool_call_index, | |
| "id": parsed_call.id, | |
| "type": "function", | |
| "function": { | |
| "name": parsed_call.function["name"], | |
| "arguments": parsed_call.function["arguments"] | |
| } | |
| }] | |
| } | |
| if not sent_role: | |
| delta_tool["role"] = "assistant" | |
| sent_role = True | |
| logger.info(f"📤 STREAM END: Sending final tool call: {delta_tool}") | |
| tool_chunk = ChatCompletionStreamResponse( | |
| id=response_id, model=request.model, created=created, | |
| choices=[StreamChoice(index=0, delta=delta_tool)] | |
| ) | |
| yield f"data: {tool_chunk.model_dump_json(exclude_none=True)}\n\n" | |
| current_tool_call_index += 1 | |
| streamed_tool_calls_count += 1 | |
| content_buffer = content_buffer[bracket_end + 1:] | |
| # 发送任何剩余的内容 | |
| if content_buffer.strip(): | |
| logger.info(f"📤 STREAM END: Sending remaining content: >>>{content_buffer}<<<") | |
| delta_content = {"content": content_buffer} | |
| if not sent_role: | |
| delta_content["role"] = "assistant" | |
| sent_role = True | |
| content_chunk = ChatCompletionStreamResponse( | |
| id=response_id, model=request.model, created=created, | |
| choices=[StreamChoice(index=0, delta=delta_content)] | |
| ) | |
| yield f"data: {content_chunk.model_dump_json(exclude_none=True)}\n\n" | |
| # --- 流结束 --- | |
| finish_reason = "tool_calls" if streamed_tool_calls_count > 0 else "stop" | |
| logger.info(f"🏁 STREAM FINISH: streamed_tool_calls_count={streamed_tool_calls_count}, finish_reason={finish_reason}") | |
| end_chunk = ChatCompletionStreamResponse( | |
| id=response_id, model=request.model, created=created, | |
| choices=[StreamChoice(index=0, delta={}, finish_reason=finish_reason)] | |
| ) | |
| yield f"data: {end_chunk.model_dump_json(exclude_none=True)}\n\n" | |
| yield "data: [DONE]\n\n" | |
| return StreamingResponse( | |
| generate_stream(), | |
| media_type="text/event-stream", | |
| headers={ | |
| "Cache-Control": "no-cache", | |
| "Connection": "keep-alive", | |
| "Content-Type": "text/event-stream" | |
| } | |
| ) | |
| except Exception as e: | |
| logger.error(f"❌ 流式响应生成失败: {str(e)}") | |
| import traceback | |
| traceback.print_exc() | |
| raise HTTPException( | |
| status_code=500, | |
| detail={ | |
| "error": { | |
| "message": f"Stream generation failed: {str(e)}", | |
| "type": "internal_server_error", | |
| "param": None, | |
| "code": "stream_error" | |
| } | |
| } | |
| ) | |
| async def health_check(): | |
| """Health check endpoint""" | |
| return {"status": "healthy", "service": "Ki2API", "version": "3.0.1"} | |
| async def root(): | |
| """Root endpoint with service information""" | |
| return { | |
| "service": "Ki2API", | |
| "description": "OpenAI-compatible API for Claude Sonnet 4 via AWS CodeWhisperer", | |
| "version": "3.0.1", | |
| "endpoints": { | |
| "models": "/v1/models", | |
| "chat": "/v1/chat/completions", | |
| "health": "/health" | |
| }, | |
| "features": { | |
| "streaming": True, | |
| "tools": True, | |
| "multiple_models": True, | |
| "xml_tool_parsing": True, | |
| "auto_token_refresh": True, | |
| "null_content_handling": True, | |
| "tool_call_deduplication": True | |
| } | |
| } | |
| if __name__ == "__main__": | |
| import uvicorn | |
| import os | |
| port = int(os.getenv("PORT", 7860)) | |
| uvicorn.run(app, host="0.0.0.0", port=port) |