# kr2 / app.py — commit 118abb7 ("Update app.py") by jardan (verified)
import os
import json
import time
import uuid
import httpx
import re
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Any, Union
from dotenv import load_dotenv
# Load environment variables (e.g. from a local .env file) before reading config.
load_dotenv()

# Initialize FastAPI app
app = FastAPI(
    title="Ki2API - Claude Sonnet 4 OpenAI Compatible API",
    description="Simple Docker-ready OpenAI-compatible API for Claude Sonnet 4",
    version="1.0.0",
)

# Configuration
# NOTE(review): nothing in this file reads API_KEY, so the endpoints are not
# protected by it — confirm whether an auth check is intended.
API_KEY = os.getenv("API_KEY", "ki2api-key-2024")
KIRO_ACCESS_TOKEN = os.getenv("KIRO_ACCESS_TOKEN")
KIRO_REFRESH_TOKEN = os.getenv("KIRO_REFRESH_TOKEN")
# Upstream endpoint and CodeWhisperer profile. Both can be overridden from the
# environment so the proxy can target another profile/region without a code
# change. The defaults are the values that used to be hard-coded.
KIRO_BASE_URL = os.getenv(
    "KIRO_BASE_URL",
    "https://codewhisperer.us-east-1.amazonaws.com/generateAssistantResponse",
)
PROFILE_ARN = os.getenv(
    "KIRO_PROFILE_ARN",
    "arn:aws:codewhisperer:us-east-1:699475941385:profile/EHGA3GRVQMUK",
)

# Model mapping: the public model name clients must request, and the
# CodeWhisperer model id sent upstream in its place.
MODEL_NAME = "claude-sonnet-4-20250514"
CODEWHISPERER_MODEL = "CLAUDE_SONNET_4_20250514_V1_0"
# Pydantic models
class ContentPart(BaseModel):
    """One element of an OpenAI list-style message ``content``.

    ``text`` is optional so that non-text parts (for example ``image_url``)
    pass validation instead of failing the whole request with a 422. Parts
    without text simply contribute nothing to the message text.
    """

    type: str = "text"
    text: Optional[str] = None


class ChatMessage(BaseModel):
    """An OpenAI chat message.

    ``content`` may be a string, a list of content parts, or null/absent
    (OpenAI permits that e.g. for assistant messages that only carry tool calls).
    """

    role: str
    content: Union[str, List[ContentPart], None] = None

    def get_content_text(self) -> str:
        """Extract the text of the message.

        Returns the string content as-is, the concatenation (no separator) of
        every text-bearing part for list content, or ``""`` when there is no
        content.
        """
        if self.content is None:
            return ""
        if isinstance(self.content, str):
            return self.content
        if isinstance(self.content, list):
            text_parts = []
            for part in self.content:
                if isinstance(part, dict):
                    # Presumably only reachable when validation was bypassed
                    # (validated parts are ContentPart instances).
                    if part.get("type") == "text" and isinstance(part.get("text"), str):
                        text_parts.append(part["text"])
                elif isinstance(getattr(part, "text", None), str):
                    # Skip parts whose text is None (e.g. image parts) instead
                    # of letting "".join fail on them.
                    text_parts.append(part.text)
            return "".join(text_parts)
        return str(self.content)
# Anthropic Claude format models
def _new_anthropic_message_id() -> str:
    """Return a fresh ``msg_<uuid4>`` identifier for an Anthropic message."""
    return f"msg_{uuid.uuid4()}"


class AnthropicContentBlock(BaseModel):
    """A single content block in the Anthropic Messages format.

    Only text blocks are modelled: ``text`` is required, so a block that has
    no ``text`` field fails validation.
    """

    type: str = Field(default="text")
    text: str


class AnthropicMessage(BaseModel):
    """One conversation turn: a plain string or a list of content blocks."""

    role: str  # "user" or "assistant"
    content: Union[str, List[AnthropicContentBlock]]


class AnthropicMessagesRequest(BaseModel):
    """Body of an Anthropic ``POST /v1/messages`` request.

    ``max_tokens`` and ``temperature`` are accepted for compatibility. Nothing
    in this file forwards them upstream.
    NOTE(review): ``system`` accepts only a string here, while the Anthropic API
    also allows a list of text blocks — confirm whether clients need that form.
    """

    model: str
    max_tokens: int
    messages: List[AnthropicMessage]
    system: Optional[str] = Field(default=None)
    temperature: Optional[float] = Field(default=0.7)
    stream: Optional[bool] = Field(default=False)


class AnthropicMessagesResponse(BaseModel):
    """Non-streaming Anthropic Messages response (defaults to an ``end_turn`` stop)."""

    id: str = Field(default_factory=_new_anthropic_message_id)
    type: str = Field(default="message")
    role: str = Field(default="assistant")
    content: List[AnthropicContentBlock]
    model: str
    stop_reason: str = Field(default="end_turn")
    stop_sequence: Optional[str] = Field(default=None)
    usage: Dict[str, int]


class AnthropicStreamResponse(BaseModel):
    """Generic shape of one Anthropic streaming event.

    Not referenced elsewhere in this file; the streaming code builds plain dicts.
    """

    type: str
    index: Optional[int] = Field(default=None)
    content_block: Optional[AnthropicContentBlock] = Field(default=None)
    delta: Optional[Dict[str, Any]] = Field(default=None)
    message: Optional[Dict[str, Any]] = Field(default=None)
    usage: Optional[Dict[str, int]] = Field(default=None)
def _new_chat_completion_id() -> str:
    """Return a fresh ``chatcmpl-<uuid4>`` identifier."""
    return f"chatcmpl-{uuid.uuid4()}"


def _unix_timestamp() -> int:
    """Return the current time as whole seconds since the epoch."""
    return int(time.time())


class ChatCompletionRequest(BaseModel):
    """Body of an OpenAI ``POST /v1/chat/completions`` request.

    ``temperature`` and ``max_tokens`` are accepted for compatibility. Nothing
    in this file forwards them upstream.
    """

    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = Field(default=0.7)
    max_tokens: Optional[int] = Field(default=4000)
    stream: Optional[bool] = Field(default=False)


class ChatCompletionResponse(BaseModel):
    """Non-streaming OpenAI ``chat.completion`` response."""

    id: str = Field(default_factory=_new_chat_completion_id)
    object: str = Field(default="chat.completion")
    created: int = Field(default_factory=_unix_timestamp)
    model: str
    choices: List[Dict[str, Any]]
    usage: Dict[str, int]


class ChatCompletionStreamResponse(BaseModel):
    """Shape of one OpenAI ``chat.completion.chunk``.

    Not referenced elsewhere in this file; the streaming code builds plain dicts.
    """

    id: str = Field(default_factory=_new_chat_completion_id)
    object: str = Field(default="chat.completion.chunk")
    created: int = Field(default_factory=_unix_timestamp)
    model: str
    choices: List[Dict[str, Any]]
# Token management
class TokenManager:
    """Hold the Kiro access/refresh tokens and refresh the access token on demand.

    Tokens start from KIRO_ACCESS_TOKEN / KIRO_REFRESH_TOKEN. A refresh only
    replaces the stored access token when the refresh endpoint actually returns
    one. A failed or malformed refresh therefore leaves the previous token in
    place instead of wiping it to None.
    """

    def __init__(self):
        self.access_token = KIRO_ACCESS_TOKEN
        self.refresh_token = KIRO_REFRESH_TOKEN
        self.refresh_url = "https://prod.us-east-1.auth.desktop.kiro.dev/refreshToken"

    def _apply_refresh_payload(self, data: Any) -> Optional[str]:
        """Store the tokens from a refresh response body and return the new access token.

        Returns None and changes nothing when *data* is not a dict or carries no
        non-empty ``accessToken``. A non-empty ``refreshToken`` in the body, if
        present, replaces the stored refresh token.
        """
        if not isinstance(data, dict):
            return None
        new_access_token = data.get("accessToken")
        if not new_access_token:
            return None
        self.access_token = new_access_token
        new_refresh_token = data.get("refreshToken")
        if new_refresh_token:
            self.refresh_token = new_refresh_token
        return new_access_token

    async def refresh_tokens(self):
        """POST the refresh token to the Kiro auth endpoint and return the new access token.

        This is best-effort. It returns None (after logging) when:
        - no refresh token is configured,
        - the HTTP request or JSON decoding fails, or
        - the response has no access token.
        The previous access token is kept in every failure case.
        """
        if not self.refresh_token:
            return None
        try:
            async with httpx.AsyncClient() as client:
                response = await client.post(
                    self.refresh_url,
                    json={"refreshToken": self.refresh_token},
                    timeout=30,
                )
                response.raise_for_status()
                data = response.json()
        except (httpx.HTTPError, ValueError) as e:
            # ValueError covers a non-JSON response body.
            print(f"Token refresh failed: {e}")
            return None
        new_token = self._apply_refresh_payload(data)
        if new_token is None:
            print("Token refresh failed: response contained no accessToken")
        return new_token

    def get_token(self):
        """Return the current access token (None if none was ever configured)."""
        return self.access_token
# Process-wide TokenManager shared by every call to call_kiro_api; it starts
# from the tokens read from the environment when this module was imported.
token_manager = TokenManager()
# Build CodeWhisperer request
def build_codewhisperer_request(messages: List[ChatMessage]):
    """Translate chat messages into a CodeWhisperer ``generateAssistantResponse`` body.

    - System messages are joined (blank-line separated) and prepended to the
      text of the final turn.
    - The remaining messages are grouped into turns by their real role.
      ``"assistant"`` messages form assistant turns and every other role forms
      user turns. Consecutive messages with the same role are merged into one
      turn.
    - Every turn except the last becomes ``history``. The last turn becomes
      ``currentMessage`` and is sent as user input whatever its role.

    Grouping by role replaces the earlier index-pairing scheme, which assumed
    strict user/assistant alternation. Under that scheme two consecutive user
    messages were sent as a user/assistant pair, and the second one was
    repeated as the current message.

    Raises:
        HTTPException: status 400 if there is no non-system message.
    """
    conversation_id = str(uuid.uuid4())

    system_parts = []
    turns = []  # [(role, [text, ...]), ...] in conversation order
    for msg in messages:
        text = msg.get_content_text()
        if msg.role == "system":
            system_parts.append(text)
            continue
        role = "assistant" if msg.role == "assistant" else "user"
        if turns and turns[-1][0] == role:
            turns[-1][1].append(text)
        else:
            turns.append((role, [text]))

    if not turns:
        raise HTTPException(status_code=400, detail="No user messages found")

    # NOTE(review): if the conversation opens with an assistant message, history
    # starts with an assistantResponseMessage — confirm upstream accepts that.
    history = []
    for role, texts in turns[:-1]:
        turn_text = "\n\n".join(texts)
        if role == "assistant":
            history.append({
                "assistantResponseMessage": {
                    "content": turn_text,
                    "toolUses": [],
                }
            })
        else:
            history.append({
                "userInputMessage": {
                    "content": turn_text,
                    "modelId": CODEWHISPERER_MODEL,
                    "origin": "AI_EDITOR",
                }
            })

    # Build current message
    content = "\n\n".join(turns[-1][1])
    system_prompt = "\n\n".join(part for part in system_parts if part)
    if system_prompt:
        content = f"{system_prompt}\n\n{content}"

    return {
        "profileArn": PROFILE_ARN,
        "conversationState": {
            "chatTriggerType": "MANUAL",
            "conversationId": conversation_id,
            "currentMessage": {
                "userInputMessage": {
                    "content": content,
                    "modelId": CODEWHISPERER_MODEL,
                    "origin": "AI_EDITOR",
                    "userInputMessageContext": {},
                }
            },
            "history": history,
        },
    }
# Convert Anthropic messages to internal ChatMessage format
def anthropic_to_chat_messages(anthropic_request: AnthropicMessagesRequest) -> List[ChatMessage]:
    """Flatten an Anthropic Messages request into internal ChatMessage objects.

    A non-empty ``system`` string becomes a leading ``"system"`` message. Each
    request message keeps its role. Block-list content is reduced to the
    concatenation (no separator) of its ``"text"`` blocks, and other block
    types are dropped.
    """

    def flatten(content) -> str:
        if isinstance(content, str):
            return content
        return "".join(block.text for block in content if block.type == "text")

    converted: List[ChatMessage] = (
        [ChatMessage(role="system", content=anthropic_request.system)]
        if anthropic_request.system
        else []
    )
    converted.extend(
        ChatMessage(role=message.role, content=flatten(message.content))
        for message in anthropic_request.messages
    )
    return converted
# AWS Event Stream Parser
class AWSStreamParser:
    """Best-effort extraction of assistant text from an upstream response body.

    The body is handled as text. Bytes that are not valid UTF-8 (binary framing)
    are dropped, and the JSON objects embedded between them are decoded with
    ``json.JSONDecoder.raw_decode``. Decoding, rather than matching ``{...}``
    with a brace-free regex, keeps payloads whose text itself contains ``{`` or
    ``}``. Such payloads (i.e. most code) were previously dropped silently.
    """

    @staticmethod
    def find_json_objects(text: str) -> List[Any]:
        """Return, in order, every JSON object that can be decoded in *text*.

        The scan looks for ``{`` and tries to decode a JSON value there. On
        success it resumes right after the decoded value, so braces inside
        already-decoded strings are never re-examined. Positions that do not
        start valid JSON are skipped.
        """
        decoder = json.JSONDecoder()
        found = []
        pos = 0
        while True:
            start = text.find("{", pos)
            if start == -1:
                return found
            try:
                obj, end = decoder.raw_decode(text, start)
            except ValueError:
                pos = start + 1
                continue
            found.append(obj)
            pos = end

    @staticmethod
    def parse_event_stream_to_json(raw_data: bytes) -> Dict[str, Any]:
        """Parse AWS event stream format to JSON.

        Returns ``{"content": text}`` where *text* is a str chosen in this order:

        1. the concatenation of every non-empty string ``content`` field of the
           JSON objects found in the body;
        2. otherwise the ``content`` of the first object that has that key
           (``""`` if that value is not a string);
        3. otherwise all runs of CJK Unified Ideographs in the body, concatenated;
        4. otherwise the printable-ASCII residue of the body, or the literal
           ``"No content found in response"`` if that residue is empty.

        Never raises. Unexpected errors are reported as
        ``"Error parsing response: ..."`` in the ``content`` field.
        """
        try:
            if isinstance(raw_data, bytes):
                # errors='ignore' drops the binary framing bytes that are not
                # valid UTF-8. Valid UTF-8 decodes exactly as a strict decode would.
                raw_str = raw_data.decode('utf-8', errors='ignore')
            else:
                raw_str = str(raw_data)

            objects = [
                obj for obj in AWSStreamParser.find_json_objects(raw_str)
                if isinstance(obj, dict) and 'content' in obj
            ]
            content_parts = [
                obj['content'] for obj in objects
                if isinstance(obj['content'], str) and obj['content']
            ]
            if content_parts:
                return {"content": ''.join(content_parts)}
            if objects:
                first = objects[0]['content']
                return {"content": first if isinstance(first, str) else ""}

            # Final fallback: extract readable text
            readable_text = re.sub(r'[^\x20-\x7E\n\r\t]', '', raw_str)
            readable_text = re.sub(r':event-type[^:]*:[^:]*:[^:]*:', '', readable_text)
            # Look for Chinese characters or meaningful content
            chinese_matches = re.findall(r'[\u4e00-\u9fff]+', raw_str)
            if chinese_matches:
                return {"content": ''.join(chinese_matches)}
            return {"content": readable_text.strip() or "No content found in response"}
        except Exception as e:
            return {"content": f"Error parsing response: {str(e)}"}
# Make API call to Kiro/CodeWhisperer
async def call_kiro_api(messages: List[ChatMessage], stream: bool = False):
    """POST *messages* to CodeWhisperer and return the fully-read httpx response.

    The body is always read completely (no incremental streaming); *stream*
    only selects the ``Accept`` header. On a 401 or 403 the access token is
    refreshed once through ``token_manager`` and the request is retried.

    Raises:
        HTTPException: 401 when no access token is configured; 400 propagated
            from build_codewhisperer_request when there are no non-system
            messages; 503 when the upstream call fails (transport error, or a
            non-2xx status after the optional retry).
    """
    import traceback

    token = token_manager.get_token()
    if not token:
        raise HTTPException(status_code=401, detail="No access token available")

    request_data = build_codewhisperer_request(messages)
    headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
        "Accept": "text/event-stream" if stream else "application/json",
    }

    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                KIRO_BASE_URL,
                headers=headers,
                json=request_data,
                timeout=120,
            )
            # An expired/invalid bearer token may surface as either status;
            # refresh once and retry with the new token.
            if response.status_code in (401, 403):
                new_token = await token_manager.refresh_tokens()
                if new_token:
                    headers["Authorization"] = f"Bearer {new_token}"
                    response = await client.post(
                        KIRO_BASE_URL,
                        headers=headers,
                        json=request_data,
                        timeout=120,
                    )
            response.raise_for_status()
            return response
    except httpx.HTTPStatusError as e:
        # Log the upstream error body to aid debugging; the status line alone
        # rarely says why the request was rejected.
        print(f"API call failed: {str(e)}")
        print(f"Upstream response body: {e.response.text[:1000]}")
        raise HTTPException(status_code=503, detail=f"API call failed: {str(e)}") from e
    except Exception as e:
        print(f"API call failed: {str(e)}")
        print(traceback.format_exc())
        raise HTTPException(status_code=503, detail=f"API call failed: {str(e)}") from e
# API endpoints
@app.get("/v1/models")
async def list_models():
    """List the single model this proxy serves, in OpenAI ``/v1/models`` format."""
    model_card = {
        "id": MODEL_NAME,
        "object": "model",
        "created": int(time.time()),
        "owned_by": "ki2api",
    }
    return {"object": "list", "data": [model_card]}
@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
    """OpenAI-compatible chat completions endpoint.

    Any model other than MODEL_NAME is rejected with a 400. Otherwise the
    request is dispatched to the streaming or non-streaming handler according
    to ``request.stream``.
    """
    if request.model != MODEL_NAME:
        raise HTTPException(status_code=400, detail=f"Only {MODEL_NAME} is supported")
    handler = create_streaming_response if request.stream else create_non_streaming_response
    return await handler(request)
async def create_non_streaming_response(request: ChatCompletionRequest):
    """Call upstream without streaming and convert the reply to a ChatCompletionResponse."""
    upstream = await call_kiro_api(request.messages, stream=False)
    converted = await create_conversion_response(upstream)
    return converted
async def create_conversion_response(response):
    """Build an OpenAI ``chat.completion`` object from a buffered upstream reply.

    When the body is JSON, the text is ``body["content"]`` (or ``str(body)``
    if there is no such key). Otherwise it comes from AWSStreamParser, falling
    back to a lossy UTF-8 decode of the raw bytes when the parser found
    nothing. Failures become an ``Error processing response: ...`` assistant
    message instead of an exception. Token usage is always reported as zero.
    """

    def text_from_event_stream(raw_bytes):
        # Handle event stream format using AWS parser
        parsed_text = AWSStreamParser.parse_event_stream_to_json(raw_bytes).get('content', "")
        print(f"Parsed content length: {len(parsed_text)}")
        if parsed_text and parsed_text != "No content found in response":
            return parsed_text
        # Last resort: try to decode as text
        try:
            decoded = raw_bytes.decode('utf-8', errors='ignore')
            print(f"Fallback text decode length: {len(decoded)}")
            return decoded
        except Exception as decode_error:
            return f"Unable to decode response: {str(decode_error)}"

    try:
        print(f"Response status: {response.status_code}")
        print(f"Response headers: {dict(response.headers)}")
        raw_bytes = response.content
        print(f"Response content type: {type(raw_bytes)}")
        print(f"Response content length: {len(raw_bytes)}")
        try:
            body = response.json()
        except Exception as json_error:
            print(f"JSON parsing failed: {json_error}")
            response_text = text_from_event_stream(raw_bytes)
        else:
            print("Successfully parsed JSON response")
            has_content = isinstance(body, dict) and 'content' in body
            response_text = body['content'] if has_content else str(body)
        print(f"Final response text: {response_text[:200]}...")
    except Exception as conversion_error:
        print(f"Error in conversion: {conversion_error}")
        import traceback
        traceback.print_exc()
        response_text = f"Error processing response: {str(conversion_error)}"

    assistant_choice = {
        "index": 0,
        "message": {
            "role": "assistant",
            "content": response_text,
        },
        "finish_reason": "stop",
    }
    zero_usage = {
        "prompt_tokens": 0,
        "completion_tokens": 0,
        "total_tokens": 0,
    }
    return ChatCompletionResponse(model=MODEL_NAME, choices=[assistant_choice], usage=zero_usage)
async def create_streaming_response(request: ChatCompletionRequest):
    """Call upstream for *request* and re-emit the reply as an OpenAI SSE stream."""
    upstream = await call_kiro_api(request.messages, stream=True)
    streamed = await create_streaming_conversion_response(upstream)
    return streamed
async def create_streaming_conversion_response(response):
    """Re-emit a buffered upstream reply as an OpenAI ``chat.completion.chunk`` SSE stream.

    ``call_kiro_api`` has already read the whole upstream body, so this stream
    is produced by splitting that complete body into deltas:

    1. Each JSON object with a non-empty string ``content`` field becomes one
       delta. Objects are decoded with ``json.JSONDecoder.raw_decode``, so text
       containing ``{`` or ``}`` is kept.
    2. If no JSON object has a ``content`` key, runs of CJK text are streamed
       in roughly ten pieces.
    3. Otherwise the printable residue of the body is sent as one delta.

    All chunks share one ``id`` and ``created`` value, as OpenAI clients expect.
    The stream ends with a ``finish_reason: "stop"`` chunk and ``data: [DONE]``.
    An error while extracting text is sent to the client as an ``Error: ...`` delta.
    """
    import asyncio
    import traceback

    print(f"Starting streaming response, status: {response.status_code}")

    # One id/timestamp for the whole completion (previously every chunk got a new id).
    completion_id = f'chatcmpl-{uuid.uuid4()}'
    created = int(time.time())

    def build_chunk(delta, finish_reason=None):
        return {
            'id': completion_id,
            'object': 'chat.completion.chunk',
            'created': created,
            'model': MODEL_NAME,
            'choices': [{
                'index': 0,
                'delta': delta,
                'finish_reason': finish_reason,
            }],
        }

    def sse(payload):
        return f"data: {json.dumps(payload)}\n\n"

    def find_json_objects(text):
        # Decode every JSON object in *text*; after a successful decode, resume
        # past it so braces inside its strings are never re-scanned.
        decoder = json.JSONDecoder()
        found = []
        pos = 0
        while True:
            start = text.find('{', pos)
            if start == -1:
                return found
            try:
                obj, end = decoder.raw_decode(text, start)
            except ValueError:
                pos = start + 1
                continue
            found.append(obj)
            pos = end

    async def generate():
        initial_chunk = build_chunk({'role': 'assistant'})
        print(f"Sending initial chunk: {initial_chunk}")
        yield sse(initial_chunk)

        content = ""
        chunk_count = 0
        response_bytes = response.content
        print(f"Streaming response bytes length: {len(response_bytes)}")

        try:
            if isinstance(response_bytes, bytes):
                response_str = response_bytes.decode('utf-8', errors='ignore')
            else:
                response_str = str(response_bytes)

            # Method 1: JSON payloads carrying a "content" field
            payloads = [
                obj for obj in find_json_objects(response_str)
                if isinstance(obj, dict) and 'content' in obj
            ]
            if payloads:
                for data in payloads:
                    chunk_text = data['content']
                    if not isinstance(chunk_text, str) or not chunk_text:
                        continue
                    content += chunk_text
                    chunk_count += 1
                    print(f"Streaming JSON chunk {chunk_count}: {chunk_text[:50]}...")
                    yield sse(build_chunk({'content': chunk_text}))
                    # Small delay to simulate streaming
                    await asyncio.sleep(0.01)
            else:
                # Method 2: Chinese text, split into ~10 pieces
                chinese_pattern = r'[\u4e00-\u9fff][\u4e00-\u9fff\s\.,!?]*[\u4e00-\u9fff]'
                chinese_matches = re.findall(chinese_pattern, response_str)
                if chinese_matches:
                    combined_text = ''.join(chinese_matches)
                    chunk_size = max(1, len(combined_text) // 10)
                    for i in range(0, len(combined_text), chunk_size):
                        chunk_text = combined_text[i:i + chunk_size]
                        content += chunk_text
                        chunk_count += 1
                        print(f"Streaming Chinese text chunk {chunk_count}: {chunk_text[:50]}...")
                        yield sse(build_chunk({'content': chunk_text}))
                        await asyncio.sleep(0.05)
                else:
                    # Method 3: Use the entire readable text
                    readable_text = re.sub(r'[^\x20-\x7E\n\r\t\u4e00-\u9fff]', '', response_str)
                    if readable_text.strip():
                        print(f"Streaming fallback text: {readable_text[:100]}...")
                        yield sse(build_chunk({'content': readable_text.strip()}))
                        content = readable_text.strip()
        except Exception as e:
            print(f"Error in streaming generation: {e}")
            traceback.print_exc()
            # Send error as content
            yield sse(build_chunk({'content': f"Error: {str(e)}"}))

        print(f"Streaming complete, total chunks: {chunk_count}, content length: {len(content)}")
        yield sse(build_chunk({}, 'stop'))
        yield "data: [DONE]\n\n"

    return StreamingResponse(generate(), media_type="text/event-stream")
# Anthropic response conversion functions
async def create_anthropic_response(response, model: str):
    """Build a non-streaming Anthropic Messages response from a buffered upstream reply.

    When the body is JSON, the text is ``body["content"]`` (or ``str(body)``
    if there is no such key). Otherwise it comes from AWSStreamParser, falling
    back to a lossy UTF-8 decode of the raw bytes when the parser found
    nothing. Failures become an ``Error processing response: ...`` text block
    instead of an exception. Usage counts are always zero.
    """

    def text_from_event_stream(raw_bytes):
        # Handle event stream format using AWS parser
        parsed_text = AWSStreamParser.parse_event_stream_to_json(raw_bytes).get('content', "")
        print(f"Parsed content length: {len(parsed_text)}")
        if parsed_text and parsed_text != "No content found in response":
            return parsed_text
        # Last resort: try to decode as text
        try:
            decoded = raw_bytes.decode('utf-8', errors='ignore')
            print(f"Fallback text decode length: {len(decoded)}")
            return decoded
        except Exception as decode_error:
            return f"Unable to decode response: {str(decode_error)}"

    try:
        print(f"Response status: {response.status_code}")
        print(f"Response headers: {dict(response.headers)}")
        raw_bytes = response.content
        print(f"Response content type: {type(raw_bytes)}")
        print(f"Response content length: {len(raw_bytes)}")
        try:
            body = response.json()
        except Exception as json_error:
            print(f"JSON parsing failed: {json_error}")
            response_text = text_from_event_stream(raw_bytes)
        else:
            print("Successfully parsed JSON response")
            has_content = isinstance(body, dict) and 'content' in body
            response_text = body['content'] if has_content else str(body)
        print(f"Final response text: {response_text[:200]}...")
    except Exception as conversion_error:
        print(f"Error in conversion: {conversion_error}")
        import traceback
        traceback.print_exc()
        response_text = f"Error processing response: {str(conversion_error)}"

    text_block = AnthropicContentBlock(type="text", text=response_text)
    return AnthropicMessagesResponse(
        model=model,
        content=[text_block],
        usage={"input_tokens": 0, "output_tokens": 0},
    )
async def create_anthropic_streaming_response(response, model: str):
    """Re-emit a buffered upstream reply as an Anthropic Messages SSE stream.

    Events are sent in this order:
    - ``message_start``
    - ``content_block_start`` (index 0)
    - one ``content_block_delta`` per text piece
    - ``content_block_stop``
    - ``message_delta`` (stop_reason ``end_turn``)
    - ``message_stop``

    The ``message_delta`` event was previously missing; clients read the stop
    reason from it. Text pieces are extracted the same way as in the OpenAI
    streaming path:
    - JSON objects with a non-empty string ``content``, decoded with
      ``json.JSONDecoder.raw_decode`` so text containing ``{``/``}`` survives;
    - else runs of CJK text in about ten pieces;
    - else the printable residue of the body.

    Extraction errors are sent as an ``Error: ...`` text delta. Token usage is
    always reported as zero.
    """
    import asyncio
    import traceback

    print(f"Starting Anthropic streaming response, status: {response.status_code}")

    def sse(event_name, payload):
        return f"event: {event_name}\ndata: {json.dumps(payload)}\n\n"

    def text_delta(text):
        return {
            "type": "content_block_delta",
            "index": 0,
            "delta": {
                "type": "text_delta",
                "text": text,
            },
        }

    def find_json_objects(text):
        # Decode every JSON object in *text*; after a successful decode, resume
        # past it so braces inside its strings are never re-scanned.
        decoder = json.JSONDecoder()
        found = []
        pos = 0
        while True:
            start = text.find('{', pos)
            if start == -1:
                return found
            try:
                obj, end = decoder.raw_decode(text, start)
            except ValueError:
                pos = start + 1
                continue
            found.append(obj)
            pos = end

    async def generate():
        message_start = {
            "type": "message_start",
            "message": {
                "id": f"msg_{uuid.uuid4()}",
                "type": "message",
                "role": "assistant",
                "content": [],
                "model": model,
                "stop_reason": None,
                "stop_sequence": None,
                "usage": {"input_tokens": 0, "output_tokens": 0},
            },
        }
        print(f"Sending message_start: {message_start}")
        yield sse("message_start", message_start)

        content_block_start = {
            "type": "content_block_start",
            "index": 0,
            "content_block": {
                "type": "text",
                "text": "",
            },
        }
        yield sse("content_block_start", content_block_start)

        content = ""
        chunk_count = 0
        response_bytes = response.content
        print(f"Anthropic streaming response bytes length: {len(response_bytes)}")

        try:
            if isinstance(response_bytes, bytes):
                response_str = response_bytes.decode('utf-8', errors='ignore')
            else:
                response_str = str(response_bytes)

            # Method 1: JSON payloads carrying a "content" field
            payloads = [
                obj for obj in find_json_objects(response_str)
                if isinstance(obj, dict) and 'content' in obj
            ]
            if payloads:
                for data in payloads:
                    chunk_text = data['content']
                    if not isinstance(chunk_text, str) or not chunk_text:
                        continue
                    content += chunk_text
                    chunk_count += 1
                    print(f"Streaming Anthropic JSON chunk {chunk_count}: {chunk_text[:50]}...")
                    yield sse("content_block_delta", text_delta(chunk_text))
                    # Small delay to simulate streaming
                    await asyncio.sleep(0.01)
            else:
                # Method 2: Chinese text, split into ~10 pieces
                chinese_pattern = r'[\u4e00-\u9fff][\u4e00-\u9fff\s\.,!?]*[\u4e00-\u9fff]'
                chinese_matches = re.findall(chinese_pattern, response_str)
                if chinese_matches:
                    combined_text = ''.join(chinese_matches)
                    chunk_size = max(1, len(combined_text) // 10)
                    for i in range(0, len(combined_text), chunk_size):
                        chunk_text = combined_text[i:i + chunk_size]
                        content += chunk_text
                        chunk_count += 1
                        print(f"Streaming Anthropic Chinese text chunk {chunk_count}: {chunk_text[:50]}...")
                        yield sse("content_block_delta", text_delta(chunk_text))
                        await asyncio.sleep(0.05)
                else:
                    # Method 3: Use the entire readable text
                    readable_text = re.sub(r'[^\x20-\x7E\n\r\t\u4e00-\u9fff]', '', response_str)
                    if readable_text.strip():
                        print(f"Streaming Anthropic fallback text: {readable_text[:100]}...")
                        yield sse("content_block_delta", text_delta(readable_text.strip()))
                        content = readable_text.strip()
        except Exception as e:
            print(f"Error in Anthropic streaming generation: {e}")
            traceback.print_exc()
            # Send error as content
            yield sse("content_block_delta", text_delta(f"Error: {str(e)}"))

        print(f"Anthropic streaming complete, total chunks: {chunk_count}, content length: {len(content)}")

        yield sse("content_block_stop", {"type": "content_block_stop", "index": 0})
        # message_delta carries the stop reason (and final usage) before message_stop.
        message_delta = {
            "type": "message_delta",
            "delta": {"stop_reason": "end_turn", "stop_sequence": None},
            "usage": {"output_tokens": 0},
        }
        yield sse("message_delta", message_delta)
        yield sse("message_stop", {"type": "message_stop"})

    return StreamingResponse(generate(), media_type="text/event-stream")
# API endpoints
@app.post("/v1/messages")
async def create_messages(request: AnthropicMessagesRequest):
    """Anthropic Messages API endpoint.

    Any model other than MODEL_NAME is rejected with a 400. Otherwise the
    request is converted to internal chat messages and sent upstream. The reply
    is returned as an SSE stream or as a single response, per ``request.stream``.
    """
    if request.model != MODEL_NAME:
        raise HTTPException(status_code=400, detail=f"Only {MODEL_NAME} is supported")

    upstream = await call_kiro_api(anthropic_to_chat_messages(request), stream=request.stream)
    converter = (
        create_anthropic_streaming_response if request.stream else create_anthropic_response
    )
    return await converter(upstream, request.model)
# Health check
@app.get("/health")
async def health_check():
    """Liveness probe; always reports OK without contacting upstream."""
    return dict(status="ok", service="ki2api")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)