|
|
import os |
|
|
import sys |
|
|
import json |
|
|
import uuid |
|
|
import time |
|
|
import asyncio |
|
|
import httpx |
|
|
from contextlib import asynccontextmanager |
|
|
from fastapi import FastAPI, Request, HTTPException, Depends, Security |
|
|
from fastapi.security.api_key import APIKeyHeader |
|
|
from fastapi.responses import StreamingResponse, JSONResponse |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from loguru import logger |
|
|
from typing import AsyncGenerator, Set, Optional, Dict, Any, List |
|
|
|
|
|
|
|
|
# Logging: resolve the level from LOG_LEVEL (upper-cased, default INFO), drop
# loguru's default handler and install one coloured stderr sink at that level.
log_level = os.getenv("LOG_LEVEL", "INFO").upper()
logger.remove()
logger.add(
    sys.stderr,
    level=log_level,
    format=(
        "<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
        "<level>{level: <8}</level> | "
        "<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - "
        "<level>{message}</level>"
    ),
)
|
|
|
|
|
|
|
|
# --- Upstream OpenAI-compatible endpoint -----------------------------------
OPENAI_API_ENDPOINT = os.getenv("OPENAI_API_ENDPOINT", "https://api.openai.com/v1/chat/completions")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

# --- Proxy authentication ---------------------------------------------------
# Comma-separated list of accepted X-API-Key values; surrounding whitespace is
# trimmed and empty entries are ignored. An empty set disables authentication.
PROXY_API_KEYS_STR = os.getenv("PROXY_API_KEYS", "")
VALID_API_KEYS: Set[str] = {
    candidate.strip()
    for candidate in PROXY_API_KEYS_STR.split(",")
    if candidate.strip()
}

# --- Upstream HTTP client tuning (seconds / connection counts) -------------
CONNECT_TIMEOUT, READ_TIMEOUT, WRITE_TIMEOUT, POOL_TIMEOUT = (
    float(os.getenv(env_name, fallback))
    for env_name, fallback in (
        ("CONNECT_TIMEOUT", 5.0),
        ("READ_TIMEOUT", 180.0),
        ("WRITE_TIMEOUT", 30.0),
        ("POOL_TIMEOUT", 5.0),
    )
)
MAX_CONNECTIONS, MAX_KEEPALIVE = (
    int(os.getenv(env_name, fallback))
    for env_name, fallback in (("MAX_CONNECTIONS", 100), ("MAX_KEEPALIVE", 20))
)

# Optional outbound proxy used for both http:// and https:// upstream traffic.
HTTP_PROXY = os.getenv("HTTP_PROXY")
|
|
|
|
|
|
|
|
# Shared upstream HTTP client; assigned by `lifespan` at application startup.
client: httpx.AsyncClient


def _redact_url_credentials(url: str) -> str:
    """Return *url* with any ``user:password@`` part replaced by ``***@`` so it is safe to log."""
    from urllib.parse import urlsplit, urlunsplit

    parts = urlsplit(url)
    if "@" not in parts.netloc:
        return url
    host_and_port = parts.netloc.rpartition("@")[2]
    return urlunsplit(parts._replace(netloc=f"***@{host_and_port}"))


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the shared upstream httpx client on startup and close it on shutdown.

    The client uses the configured connection limits and timeouts, follows
    redirects, routes through HTTP_PROXY when set, and enables HTTP/2 only when
    the optional ``h2`` package is importable.

    Args:
        app: The FastAPI application (required by the lifespan protocol; unused).
    """
    global client

    limits = httpx.Limits(max_connections=MAX_CONNECTIONS, max_keepalive_connections=MAX_KEEPALIVE)
    timeout_config = httpx.Timeout(connect=CONNECT_TIMEOUT, read=READ_TIMEOUT, write=WRITE_TIMEOUT, pool=POOL_TIMEOUT)
    client_kwargs: Dict[str, Any] = {
        "limits": limits,
        "timeout": timeout_config,
        "follow_redirects": True,
    }

    logger.info("Initializing httpx client for upstream requests.")
    if HTTP_PROXY:
        logger.info(f"Using outbound proxy: {_redact_url_credentials(HTTP_PROXY)}")
    if not OPENAI_API_KEY:
        logger.warning("OPENAI_API_KEY is not set. Requests to the target endpoint might fail if it requires authentication.")
    if not VALID_API_KEYS:
        logger.warning("PROXY_API_KEYS is not set. The proxy endpoint will be open to anyone (NOT RECOMMENDED for production).")

    # httpx raises ImportError at construction time for http2=True when the
    # optional `h2` package is missing; degrade to HTTP/1.1 instead of failing startup.
    import importlib.util

    if importlib.util.find_spec("h2") is not None:
        client_kwargs["http2"] = True
    else:
        logger.warning("The 'h2' package is not installed; using HTTP/1.1 for upstream requests.")

    if HTTP_PROXY:
        try:
            # httpx >= 0.26: a single proxy applied to all requests.
            client = httpx.AsyncClient(proxy=HTTP_PROXY, **client_kwargs)
        except TypeError:
            # Older httpx only understands the per-scheme `proxies` mapping
            # (which httpx 0.28 removed, hence it is not tried first).
            client = httpx.AsyncClient(
                proxies={"http://": HTTP_PROXY, "https://": HTTP_PROXY},
                **client_kwargs,
            )
    else:
        # Pass no proxy argument at all: `proxies=None` is a TypeError on httpx >= 0.28.
        client = httpx.AsyncClient(**client_kwargs)

    try:
        yield
    finally:
        # Close pooled connections even if shutdown is triggered by an exception.
        logger.info("Closing httpx client.")
        await client.aclose()
|
|
|
|
|
|
|
|
# Application metadata shown in the generated OpenAPI docs.
_app_metadata = {
    "title": "Claude to OpenAI Proxy",
    "description": "A proxy server that translates Claude API requests to OpenAI API format and back.",
    "version": "1.0.0",
}
app = FastAPI(lifespan=lifespan, **_app_metadata)

# Browser access is fully open: every origin, method and request header is
# accepted and credentialed requests are allowed.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
|
|
|
|
|
|
|
|
# Header that carries the proxy's own API key (not the upstream OpenAI key).
API_KEY_NAME = "X-API-Key"
# auto_error=False: a missing header yields None so get_api_key can decide.
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)


async def get_api_key(key: Optional[str] = Security(api_key_header)) -> str:
    """FastAPI dependency that authorises a request by its proxy API key.

    When no PROXY_API_KEYS are configured, the proxy runs unsecured and every
    request is let through with a placeholder key. Otherwise the header must be
    present and match one of VALID_API_KEYS. Only the key's length is logged,
    never the key itself.

    Returns:
        The accepted key, or ``"unsecured_dummy_key"`` when authentication is disabled.

    Raises:
        HTTPException: 401 when the header is missing or the key is not recognised.
    """
    if VALID_API_KEYS:
        if key is None:
            logger.warning("API key missing from request header.")
            raise HTTPException(status_code=401, detail=f"API Key required in header '{API_KEY_NAME}'")
        if key in VALID_API_KEYS:
            logger.debug(f"Valid API key received (length: {len(key)}).")
            return key
        logger.warning(f"Invalid API key received (length: {len(key)}).")
        raise HTTPException(status_code=401, detail="Invalid or expired API Key")

    logger.warning("No PROXY_API_KEYS configured. Allowing request.")
    return "unsecured_dummy_key"
|
|
|
|
|
|
|
|
|
|
|
def claude_request_to_openai_payload(claude_request: Dict[str, Any]) -> Dict[str, Any]:
    """Translate a Claude Messages request body into an OpenAI chat-completions payload.

    Only text content is carried over. A system prompt (string or list of
    blocks) becomes a leading ``system`` message. Each message's text blocks
    are joined with newlines. Messages without a role or without any text are
    dropped. Optional sampling parameters are copied only when present and
    not None; Claude's ``stop_sequences`` maps to OpenAI's ``stop``.

    Args:
        claude_request: Decoded Claude request JSON.

    Returns:
        The OpenAI request payload (``model`` defaults to ``"gpt-3.5-turbo"``).
    """

    def text_pieces(blocks: List[Dict[str, Any]]) -> List[str]:
        # Keep only "text" blocks; other block types (images, tools, ...) are ignored.
        return [blk.get("text", "") for blk in blocks if blk.get("type") == "text"]

    converted: List[Dict[str, Any]] = []

    system = claude_request.get("system")
    if system:
        if isinstance(system, list):
            system_text = "\n".join(text_pieces(system))
        elif isinstance(system, str):
            system_text = system
        else:
            system_text = ""
        if system_text:
            converted.append({"role": "system", "content": system_text})

    for message in claude_request.get("messages", []):
        role = message.get("role")
        body = message.get("content")
        if isinstance(body, list):
            pieces = text_pieces(body)
        elif isinstance(body, str):
            pieces = [body]
        else:
            pieces = []
        if role and pieces:
            converted.append({"role": role, "content": "\n".join(pieces)})

    payload: Dict[str, Any] = {
        "model": claude_request.get("model", "gpt-3.5-turbo"),
        "messages": converted,
        "stream": claude_request.get("stream", False),
    }
    # (Claude field, OpenAI field) pairs copied only when explicitly set.
    optional_fields = (
        ("max_tokens", "max_tokens"),
        ("temperature", "temperature"),
        ("top_p", "top_p"),
        ("stop_sequences", "stop"),
    )
    for claude_field, openai_field in optional_fields:
        value = claude_request.get(claude_field)
        if value is not None:
            payload[openai_field] = value
    return payload
|
|
|
|
|
def openai_response_to_claude_response(openai_response: Dict[str, Any], claude_request_id: str) -> Dict[str, Any]:
    """Convert a non-streaming OpenAI chat-completion response to Claude Messages format.

    Only the first choice's text content is carried over; any other choices
    are ignored.

    Args:
        openai_response: Decoded JSON body returned by the upstream endpoint.
        claude_request_id: Proxy-generated id used for logging and as the
            message id when the upstream body has none.

    Returns:
        A Claude-style ``message`` dict.

    Raises:
        ValueError: If the body lacks the expected structure, e.g. an empty or
            null ``choices`` list or null / non-object choice or message entries.
    """
    # OpenAI finish_reason -> Claude stop_reason; anything unmapped (including a
    # JSON null finish_reason) falls back to "stop_sequence".
    stop_reason_map = {
        "stop": "end_turn", "length": "max_tokens", "function_call": "tool_use",
        "content_filter": "stop_sequence", "null": "stop_sequence",
    }

    try:
        # A missing "choices" key degrades to an empty reply; an empty list raises IndexError.
        choice = openai_response.get("choices", [{}])[0]
        message = choice.get("message", {})
        content = message.get("content") or ""
        role = message.get("role") or "assistant"
        finish_reason = choice.get("finish_reason", "stop")
        claude_stop_reason = stop_reason_map.get(finish_reason, "stop_sequence")

        # Some OpenAI-compatible servers send "usage": null or null token counts.
        usage = openai_response.get("usage") or {}
        prompt_tokens = usage.get("prompt_tokens") or 0
        completion_tokens = usage.get("completion_tokens") or 0

        response_id = openai_response.get("id") or claude_request_id
        response_model = openai_response.get("model", "claude-proxy-model")
    except (AttributeError, KeyError, IndexError, TypeError) as e:
        # AttributeError covers null or non-object entries, e.g. "choices": [null].
        logger.error(f"[{claude_request_id}] Error converting non-streaming OpenAI response: {e}")
        raise ValueError(f"Failed to parse OpenAI response: {e}") from e

    claude_response = {
        "id": response_id,
        "type": "message",
        "role": role,
        "content": [{"type": "text", "text": content}],
        "model": response_model,
        "stop_reason": claude_stop_reason,
        "stop_sequence": None,
        "usage": {"input_tokens": prompt_tokens, "output_tokens": completion_tokens},
    }
    logger.debug(f"[{claude_request_id}] Converted non-streaming OpenAI response to Claude format.")
    return claude_response
|
|
|
|
|
async def stream_openai_response_to_claude_events(openai_response: httpx.Response, claude_request_id: str, requested_model: str) -> AsyncGenerator[str, None]:
    """Translate an OpenAI chat-completion SSE stream into Claude Messages SSE events.

    Emits message_start, content_block_start and ping up front. Each non-empty
    text delta becomes one content_block_delta. The stream closes with
    content_block_stop, message_delta (stop reason and final output_tokens)
    and message_stop. A read timeout or unexpected error while reading
    upstream is reported as an ``error`` event before the closing events.

    The upstream response is closed in every exit path: normal completion, the
    [DONE] marker, upstream errors, and early close by the consumer (e.g. client
    disconnect). This releases its pooled connection.

    Args:
        openai_response: Upstream response opened with ``stream=True``.
        claude_request_id: Proxy-generated id used as the Claude message id.
        requested_model: Model name echoed back in ``message_start``.

    Yields:
        SSE-framed strings ready to be written to the client.
    """

    def sse(event_type: str, payload: Dict[str, Any]) -> str:
        # Claude-style SSE framing: an explicit "event:" line plus a JSON "data:" line.
        return f"event: {event_type}\ndata: {json.dumps(payload)}\n\n"

    # OpenAI finish_reason (plus this proxy's own error markers) -> Claude stop_reason.
    stop_reason_map = {
        "stop": "end_turn", "length": "max_tokens", "function_call": "tool_use",
        "content_filter": "stop_sequence", "null": "stop_sequence",
        "error_timeout": "error", "error_exception": "error",
    }

    message_id = claude_request_id
    accumulated_content_len = 0
    openai_finish_reason: Optional[str] = None
    input_tokens = 0
    output_tokens = 0
    last_ping_time = time.time()

    logger.debug(f"[{message_id}] Starting Claude SSE stream conversion.")

    try:
        yield sse("message_start", {
            "type": "message_start",
            "message": {
                "id": message_id, "type": "message", "role": "assistant", "content": [],
                "model": requested_model, "stop_reason": None, "stop_sequence": None,
                "usage": {"input_tokens": 0, "output_tokens": 0},
            },
        })
        yield sse("content_block_start", {"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}})
        yield sse("ping", {"type": "ping"})

        try:
            async for raw_line in openai_response.aiter_lines():
                line = raw_line.strip()
                if not line:
                    continue

                if line.startswith("data:"):
                    # The space after "data:" is optional in SSE; slicing a fixed
                    # "data: " prefix would drop the first character of "data:{...}".
                    data_str = line[len("data:"):].strip()
                    if data_str == "[DONE]":
                        logger.debug(f"[{message_id}] Received [DONE] marker from OpenAI stream.")
                        break

                    try:
                        data = json.loads(data_str)

                        # Read usage before looking at choices: a usage-only chunk
                        # (OpenAI's final chunk with stream_options.include_usage)
                        # carries an empty "choices" list.
                        usage_update = data.get("usage") or {}
                        prompt_tokens = usage_update.get("prompt_tokens")
                        if prompt_tokens is not None and input_tokens == 0:
                            input_tokens = prompt_tokens
                            logger.debug(f"[{message_id}] Captured input_tokens: {input_tokens}")
                        completion_tokens = usage_update.get("completion_tokens")
                        if completion_tokens is not None:
                            output_tokens = completion_tokens
                            logger.debug(f"[{message_id}] Received completion_tokens update: {output_tokens}")

                        choices = data.get("choices") or []
                        if choices:
                            choice = choices[0]
                            finish_reason = choice.get("finish_reason")
                            if finish_reason:
                                openai_finish_reason = finish_reason
                                logger.debug(f"[{message_id}] Received OpenAI finish_reason: {openai_finish_reason}")

                            content_chunk = (choice.get("delta") or {}).get("content")
                            if content_chunk:
                                accumulated_content_len += len(content_chunk)
                                if completion_tokens is None:
                                    # No authoritative count on this chunk: approximate one token per chunk.
                                    output_tokens += 1
                                yield sse("content_block_delta", {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": content_chunk}})
                    except json.JSONDecodeError:
                        logger.warning(f"[{message_id}] Could not decode JSON from stream line: {data_str}")
                    except Exception as e:
                        logger.error(f"[{message_id}] Error processing stream data chunk: {e}")

                # NOTE: pings are only considered when upstream delivers a line,
                # so they do not cover long silent gaps in the upstream stream.
                current_time = time.time()
                if current_time - last_ping_time >= 10:
                    yield sse("ping", {"type": "ping"})
                    last_ping_time = current_time

        except httpx.ReadTimeout:
            logger.error(f"[{message_id}] Timeout reading from OpenAI stream.")
            openai_finish_reason = "error_timeout"
            yield sse("error", {"type": "error", "error": {"type": "overloaded_error", "message": "Proxy timed out waiting for OpenAI stream"}})
        except Exception as e:
            logger.exception(f"[{message_id}] Unexpected error during stream processing: {e}")
            openai_finish_reason = "error_exception"
            yield sse("error", {"type": "error", "error": {"type": "internal_server_error", "message": f"Proxy stream processing error: {e}"}})

        # Closing events are emitted after (not inside) a `finally`: yielding while
        # the generator is being closed raises "async generator ignored GeneratorExit"
        # and would swallow task cancellation.
        claude_stop_reason = stop_reason_map.get(openai_finish_reason, "stop_sequence")
        logger.debug(f"[{message_id}] Stream finished. OpenAI finish reason: {openai_finish_reason}, mapped Claude stop reason: {claude_stop_reason}")

        # Fall back to a rough chars/4 estimate when no token count was seen at all.
        final_output_tokens = output_tokens if output_tokens > 0 else accumulated_content_len // 4

        yield sse("content_block_stop", {"type": "content_block_stop", "index": 0})
        # The Claude streaming format reports final output usage on message_delta.
        yield sse("message_delta", {
            "type": "message_delta",
            "delta": {"stop_reason": claude_stop_reason, "stop_sequence": None},
            "usage": {"output_tokens": final_output_tokens},
        })
        # Usage is also kept on message_stop for consumers of this proxy that read it there.
        yield sse("message_stop", {
            "type": "message_stop",
            "usage": {"input_tokens": input_tokens, "output_tokens": final_output_tokens},
        })
        logger.info(f"[{message_id}] Completed sending Claude SSE stream.")
    finally:
        # Release the upstream connection on every exit path, including early
        # close by the consumer and the `break` at [DONE].
        await openai_response.aclose()
|
|
|
|
|
|
|
|
def create_error_response(status_code: int, error_type: str, message: str) -> JSONResponse:
    """Build a JSONResponse whose body follows Claude's error envelope.

    Args:
        status_code: HTTP status to return to the client.
        error_type: Claude error type, e.g. ``"invalid_request_error"``.
        message: Human-readable error description.

    Returns:
        A response with body ``{"type": "error", "error": {"type": ..., "message": ...}}``.
    """
    error_body = {"type": error_type, "message": message}
    envelope = {"type": "error", "error": error_body}
    return JSONResponse(content=envelope, status_code=status_code)
|
|
|
|
|
|
|
|
def _log_upstream_request(request_id: str, headers: Dict[str, str], payload: Dict[str, Any]) -> None:
    """Debug-log the outgoing upstream request with Authorization redacted and the payload truncated.

    Headers and payload are only serialized when DEBUG output is enabled for the configured level.
    """
    if logger.level("DEBUG").no < logger.level(log_level).no:
        logger.debug(f"[{request_id}] Sending request to upstream API...")
        return

    logged_headers = dict(headers)
    if "Authorization" in logged_headers:
        logged_headers["Authorization"] = "Bearer [REDACTED]"
    logger.debug(f"[{request_id}] Sending request to upstream API.")
    logger.debug(f"[{request_id}] Upstream Headers: {json.dumps(logged_headers)}")
    try:
        payload_str = json.dumps(payload, indent=2)
        max_log_len = 1024
        truncated = len(payload_str) > max_log_len
        logger.debug(f"[{request_id}] Upstream Payload {'(truncated)' if truncated else ''}: {payload_str[:max_log_len]}{'...' if truncated else ''}")
    except Exception as log_e:
        logger.warning(f"[{request_id}] Could not serialize or log upstream payload: {log_e}")


def _upstream_status_error_response(request_id: str, exc: httpx.HTTPStatusError) -> JSONResponse:
    """Log a non-success upstream reply and translate it into a Claude-style error response.

    The upstream status code is passed through to the client. The upstream body
    is only logged (first 200 characters), never forwarded.
    """
    status_code = exc.response.status_code
    error_detail_text = "[Could not decode error response]"
    try:
        error_detail_text = json.dumps(exc.response.json())
    except Exception:
        # Body is not (decodable) JSON: fall back to raw text for the log line only.
        try:
            error_detail_text = exc.response.text
        except Exception:
            logger.warning(f"[{request_id}] Could not read error response body as text.")
    logger.error(f"[{request_id}] HTTP error from target endpoint ({status_code}). Response snippet: {error_detail_text[:200]}...")

    if status_code == 400:
        err_type, msg = "invalid_request_error", f"Upstream API Bad Request ({status_code})."
    elif status_code == 401:
        err_type, msg = "authentication_error", f"Authentication failed with upstream API ({status_code})."
    elif status_code == 403:
        err_type, msg = "permission_error", f"Forbidden by upstream API ({status_code})."
    elif status_code == 429:
        err_type, msg = "rate_limit_error", f"Rate limit exceeded with upstream API ({status_code})."
    elif status_code >= 500:
        err_type, msg = "api_error", f"Upstream API unavailable or encountered an error ({status_code})."
    else:
        err_type, msg = "api_error", f"Received unexpected error from upstream API ({status_code})."
    return create_error_response(status_code, err_type, msg)


@app.post("/v1/messages", dependencies=[Depends(get_api_key)])
async def proxy_claude_to_openai(request: Request):
    """
    Receives a Claude-formatted request, proxies it to OpenAI,
    and returns a Claude-formatted response.
    Requires a valid API key via the X-API-Key header.

    Streaming requests are answered with a Claude-style SSE stream; others with
    a single Claude ``message`` JSON body. Errors are returned in Claude's
    error envelope (see ``create_error_response``).
    """
    request_id = f"msg_{uuid.uuid4().hex[:24]}"

    try:
        claude_request_data = await request.json()
    except ValueError:  # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError
        logger.error(f"[{request_id}] Invalid JSON received in request body.")
        return create_error_response(400, "invalid_request_error", "Invalid JSON data in request body.")
    if not isinstance(claude_request_data, dict):
        # Valid JSON that is not an object (e.g. a list) cannot be a Claude request.
        logger.error(f"[{request_id}] Request body is not a JSON object.")
        return create_error_response(400, "invalid_request_error", "Request body must be a JSON object.")
    logger.info(f"[{request_id}] Received request. Stream: {claude_request_data.get('stream', False)}. Model: {claude_request_data.get('model')}")

    try:
        openai_payload = claude_request_to_openai_payload(claude_request_data)
    except Exception as e:
        logger.error(f"[{request_id}] Failed to convert Claude request to OpenAI format: {e}")
        return create_error_response(400, "invalid_request_error", f"Failed to process request data: {e}")

    is_streaming = openai_payload.get("stream", False)
    requested_model = openai_payload.get("model", "unknown_model")

    target_headers = {"Content-Type": "application/json"}
    if OPENAI_API_KEY:
        target_headers["Authorization"] = f"Bearer {OPENAI_API_KEY}"
        logger.debug(f"[{request_id}] Added Authorization header to upstream request.")
    else:
        logger.debug(f"[{request_id}] No OPENAI_API_KEY configured for upstream request.")

    _log_upstream_request(request_id, target_headers, openai_payload)

    try:
        target_request = client.build_request("POST", OPENAI_API_ENDPOINT, headers=target_headers, json=openai_payload)
        response = await client.send(target_request, stream=is_streaming)
        try:
            response.raise_for_status()
        except httpx.HTTPStatusError:
            if is_streaming:
                # A streamed error body has not been read yet. Load it so the
                # handler below can log it (otherwise .json()/.text raise), and
                # release the connection because no StreamingResponse will consume it.
                try:
                    await response.aread()
                finally:
                    await response.aclose()
            raise

        if is_streaming:
            logger.info(f"[{request_id}] Upstream response is streaming. Starting SSE conversion.")
            return StreamingResponse(
                stream_openai_response_to_claude_events(response, request_id, requested_model),
                media_type="text/event-stream",
                headers={"X-Content-Type-Options": "nosniff", "Cache-Control": "no-cache", "Connection": "keep-alive"},
            )

        logger.info(f"[{request_id}] Upstream response is non-streaming. Converting.")
        try:
            openai_response_data = response.json()
        except ValueError:
            # A 2xx reply that is not JSON is an upstream fault, not a proxy fault.
            logger.error(f"[{request_id}] Upstream returned a non-JSON response body.")
            return create_error_response(502, "api_error", "Bad Gateway: Upstream API returned an invalid JSON response.")
        logger.debug(f"[{request_id}] Received non-streaming response from upstream.")
        try:
            claude_response_data = openai_response_to_claude_response(openai_response_data, request_id)
            return JSONResponse(content=claude_response_data)
        except ValueError as e:
            logger.error(f"[{request_id}] Failed to convert upstream non-streaming response: {e}")
            return create_error_response(500, "api_error", f"Error processing response from upstream API: {e}")
        except Exception as e:
            logger.exception(f"[{request_id}] Unexpected error converting non-streaming response: {e}")
            return create_error_response(500, "internal_server_error", "Unexpected error processing upstream response.")

    except httpx.HTTPStatusError as e:
        return _upstream_status_error_response(request_id, e)
    except httpx.TimeoutException:
        logger.error(f"[{request_id}] Request to target endpoint timed out ({READ_TIMEOUT}s).")
        return create_error_response(504, "api_error", "Gateway Timeout: Request to upstream API timed out.")
    except httpx.RequestError as e:
        logger.error(f"[{request_id}] Network error connecting to target endpoint: {type(e).__name__}")
        return create_error_response(502, "api_error", "Bad Gateway: Network error connecting to upstream API.")
    except Exception as e:
        logger.exception(f"[{request_id}] Unexpected error during proxy operation: {e}")
        return create_error_response(500, "internal_server_error", f"Internal Server Error: {e}")
|
|
|
|
|
|
|
|
@app.get("/health", summary="Health Check", tags=["Management"])
async def health_check():
    """Liveness probe: always reports a static healthy status (the upstream API is not contacted)."""
    status_payload = {"status": "healthy"}
    return status_payload
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    try:
        from dotenv import load_dotenv
    except ImportError:
        logger.info("python-dotenv not installed, skipping .env file loading.")
    else:
        load_dotenv()
        logger.info("Loaded environment variables from .env file (if present).")
        # The module-level configuration above was read before .env was loaded.
        # The served app is re-imported by module name (see uvicorn.run below), so
        # it reads the updated environment itself. This launcher only needs to
        # refresh its own log sink, the level handed to uvicorn, and the key count it reports.
        log_level = os.getenv("LOG_LEVEL", "INFO").upper()
        logger.remove()
        logger.add(sys.stderr, level=log_level, format="<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>")
        logger.info(f"Log level set to: {log_level}")
        configured_key_count = len({k.strip() for k in os.getenv("PROXY_API_KEYS", "").split(",") if k.strip()})
        logger.info(f"Valid Proxy API Keys configured: {configured_key_count}")

    port = int(os.getenv("PORT", 7860))
    host = os.getenv("HOST", "0.0.0.0")
    log_config_level = log_level.lower() if log_level in ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "TRACE"] else "info"

    # Derive the import string from this file's name rather than hard-coding
    # "proxy_server", so the launcher keeps working if the file is renamed.
    app_import_string = f"{os.path.splitext(os.path.basename(__file__))[0]}:app"
    # Auto-reload is a development convenience; set RELOAD=false to disable it.
    # Enabled by default to match the previous hard-coded behavior.
    reload_enabled = os.getenv("RELOAD", "true").strip().lower() in ("1", "true", "yes", "on")

    logger.info(f"Starting Uvicorn server on {host}:{port}")
    uvicorn.run(app_import_string, host=host, port=port, reload=reload_enabled, log_level=log_config_level)
|
|
|