import asyncio
import json
import os
import secrets
import sys
import time
import uuid
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Dict, List, Optional, Set

import httpx
from fastapi import FastAPI, Request, HTTPException, Depends, Security
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.security.api_key import APIKeyHeader
from loguru import logger
# --- Logging Configuration ---
# Replace loguru's default handler with one whose level comes from the
# LOG_LEVEL environment variable (defaults to INFO).
logger.remove()
log_level = os.getenv("LOG_LEVEL", "INFO").upper()
logger.add(sys.stderr, level=log_level, format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{function}:{line} - {message}")
# --- Environment Variable Configuration ---
# Upstream endpoint and credential. OPENAI_API_KEY may legitimately be unset
# when the target endpoint needs no auth; a warning is logged at startup.
OPENAI_API_ENDPOINT = os.getenv("OPENAI_API_ENDPOINT", "https://api.openai.com/v1/chat/completions")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
# Comma-separated list of keys clients may present in the X-API-Key header;
# surrounding whitespace is stripped and empty entries are dropped.
PROXY_API_KEYS_STR = os.getenv("PROXY_API_KEYS", "")
VALID_API_KEYS: Set[str] = set(key.strip() for key in PROXY_API_KEYS_STR.split(',') if key.strip())
# httpx tuning: timeouts are in seconds, the MAX_* values are connection counts.
CONNECT_TIMEOUT = float(os.getenv("CONNECT_TIMEOUT", 5.0))
READ_TIMEOUT = float(os.getenv("READ_TIMEOUT", 180.0))
WRITE_TIMEOUT = float(os.getenv("WRITE_TIMEOUT", 30.0))
POOL_TIMEOUT = float(os.getenv("POOL_TIMEOUT", 5.0))
MAX_CONNECTIONS = int(os.getenv("MAX_CONNECTIONS", 100))
MAX_KEEPALIVE = int(os.getenv("MAX_KEEPALIVE", 20))
# Optional outbound proxy URL, applied to both http:// and https:// upstream
# traffic when set.
HTTP_PROXY = os.getenv("HTTP_PROXY")
# --- Global httpx Client ---
# Created and closed by the lifespan() context manager; module-level so the
# request handler reuses a single connection pool across requests.
client: httpx.AsyncClient
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the shared upstream httpx client on startup; close it on shutdown."""
    global client
    outbound_proxies = None
    if HTTP_PROXY:
        # Route both plain and TLS upstream traffic through the same proxy.
        outbound_proxies = {"http://": HTTP_PROXY, "https://": HTTP_PROXY}
    logger.info("Initializing httpx client for upstream requests.")
    if outbound_proxies:
        logger.info(f"Using outbound proxy: {HTTP_PROXY}")
    if not OPENAI_API_KEY:
        logger.warning("OPENAI_API_KEY is not set. Requests to the target endpoint might fail if it requires authentication.")
    if not VALID_API_KEYS:
        logger.warning("PROXY_API_KEYS is not set. The proxy endpoint will be open to anyone (NOT RECOMMENDED for production).")
    # NOTE(review): the `proxies=` keyword was deprecated/removed in newer httpx
    # releases (use `proxy=`/`mounts=`); confirm against the pinned httpx version.
    client = httpx.AsyncClient(
        limits=httpx.Limits(max_connections=MAX_CONNECTIONS, max_keepalive_connections=MAX_KEEPALIVE),
        timeout=httpx.Timeout(connect=CONNECT_TIMEOUT, read=READ_TIMEOUT, write=WRITE_TIMEOUT, pool=POOL_TIMEOUT),
        proxies=outbound_proxies,
        http2=True,
        follow_redirects=True,
    )
    yield
    logger.info("Closing httpx client.")
    await client.aclose()
# --- FastAPI Application Setup ---
app = FastAPI(
    title="Claude to OpenAI Proxy",
    description="A proxy server that translates Claude API requests to OpenAI API format and back.",
    version="1.0.0",
    lifespan=lifespan
)
# --- CORS Middleware ---
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# broader than the CORS spec allows for credentialed requests; confirm this
# wide-open policy is intended for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# --- API Key Authentication ---
# Clients authenticate to the proxy via the X-API-Key header. auto_error=False
# means a missing header reaches get_api_key() as None, so it can raise its
# own 401 with a descriptive message.
API_KEY_NAME = "X-API-Key"
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
async def get_api_key(key: Optional[str] = Security(api_key_header)) -> str:
    """Validate the API key provided in the request header.

    Returns the accepted key (or a placeholder when no keys are configured).

    Raises:
        HTTPException: 401 when the key is missing or not recognized.
    """
    if not VALID_API_KEYS:
        # Unsecured mode: no keys were configured (warned about at startup).
        logger.warning("No PROXY_API_KEYS configured. Allowing request.")
        return "unsecured_dummy_key"
    if key is None:
        logger.warning("API key missing from request header.")
        raise HTTPException(status_code=401, detail=f"API Key required in header '{API_KEY_NAME}'")
    # Compare against each configured key in constant time so response timing
    # does not leak information about key contents (set membership short-circuits
    # on hash/equality and is not timing-safe).
    key_bytes = key.encode("utf-8")
    if not any(secrets.compare_digest(key_bytes, valid.encode("utf-8")) for valid in VALID_API_KEYS):
        logger.warning(f"Invalid API key received (length: {len(key)}).")
        raise HTTPException(status_code=401, detail="Invalid or expired API Key")
    logger.debug(f"Valid API key received (length: {len(key)}).")
    return key
# --- Format Conversion Logic ---
def claude_request_to_openai_payload(claude_request: Dict[str, Any]) -> Dict[str, Any]:
    """Converts a Claude API request body to OpenAI API format.

    Flattens the Claude system prompt (string or list of text blocks) into a
    single system message, collapses each message's text blocks into one
    newline-joined string, and carries over sampling parameters, renaming
    "stop_sequences" to OpenAI's "stop". Messages without a role or without
    any text content are dropped.
    """
    converted_messages: List[Dict[str, Any]] = []

    system_prompt = claude_request.get("system")
    if system_prompt:
        if isinstance(system_prompt, list):
            system_text = "\n".join(
                block.get("text", "")
                for block in system_prompt
                if block.get("type") == "text"
            )
        elif isinstance(system_prompt, str):
            system_text = system_prompt
        else:
            system_text = ""
        if system_text:
            converted_messages.append({"role": "system", "content": system_text})

    for message in claude_request.get("messages", []):
        role = message.get("role")
        content = message.get("content")
        if isinstance(content, list):
            # Keep only text blocks; other block types are silently dropped.
            parts = [block.get("text", "") for block in content if block.get("type") == "text"]
        elif isinstance(content, str):
            parts = [content]
        else:
            parts = []
        if role and parts:
            converted_messages.append({"role": role, "content": "\n".join(parts)})

    payload: Dict[str, Any] = {
        "model": claude_request.get("model", "gpt-3.5-turbo"),
        "messages": converted_messages,
        "stream": claude_request.get("stream", False),
    }
    # Optional sampling parameters: copied only when present (0 is preserved).
    for source_key, target_key in (
        ("max_tokens", "max_tokens"),
        ("temperature", "temperature"),
        ("top_p", "top_p"),
        ("stop_sequences", "stop"),
    ):
        value = claude_request.get(source_key)
        if value is not None:
            payload[target_key] = value
    return payload
def openai_response_to_claude_response(openai_response: Dict[str, Any], claude_request_id: str) -> Dict[str, Any]:
    """Converts a non-streaming OpenAI response to Claude API format.

    Args:
        openai_response: Parsed JSON body of an OpenAI chat-completions response.
        claude_request_id: Fallback message id when the upstream body has no "id".

    Returns:
        A Claude "message"-type dict with a single text content block and a
        usage section mapped from OpenAI's prompt/completion token counts.

    Raises:
        ValueError: If the upstream body lacks the expected structure.
    """
    try:
        choice = openai_response.get("choices", [{}])[0]
        message = choice.get("message", {})
        content = message.get("content", "")
        role = message.get("role", "assistant")
        finish_reason = choice.get("finish_reason", "stop")
        # Map OpenAI finish reasons to Claude stop reasons. "tool_calls" is the
        # modern OpenAI spelling of the legacy "function_call" reason; unknown
        # reasons fall back to "stop_sequence".
        stop_reason_map = {
            "stop": "end_turn", "length": "max_tokens",
            "function_call": "tool_use", "tool_calls": "tool_use",
            "content_filter": "stop_sequence", "null": "stop_sequence",
        }
        claude_stop_reason = stop_reason_map.get(finish_reason, "stop_sequence")
        usage = openai_response.get("usage", {})
        prompt_tokens = usage.get("prompt_tokens", 0)
        completion_tokens = usage.get("completion_tokens", 0)
        claude_response = {
            "id": openai_response.get("id", claude_request_id),
            "type": "message", "role": role,
            # Claude content is a list of blocks; None content becomes "".
            "content": [{"type": "text", "text": content or ""}],
            "model": openai_response.get("model", "claude-proxy-model"),
            "stop_reason": claude_stop_reason, "stop_sequence": None,
            "usage": { "input_tokens": prompt_tokens, "output_tokens": completion_tokens },
        }
        logger.debug(f"[{claude_request_id}] Converted non-streaming OpenAI response to Claude format.")
        return claude_response
    except (KeyError, IndexError, TypeError) as e:
        logger.error(f"[{claude_request_id}] Error converting non-streaming OpenAI response: {e}")
        raise ValueError(f"Failed to parse OpenAI response: {e}")
async def stream_openai_response_to_claude_events(openai_response: httpx.Response, claude_request_id: str, requested_model: str) -> AsyncGenerator[str, None]:
    """Converts an OpenAI SSE stream to Claude API SSE format.

    Emits the Claude event sequence: message_start, content_block_start, ping,
    one content_block_delta per text chunk (with keep-alive pings at least
    every 10s), then content_block_stop, message_delta (stop reason) and
    message_stop (final usage) — the trailing events are emitted from a
    `finally` block so they are sent even after an upstream error.

    Args:
        openai_response: Streaming httpx response carrying OpenAI SSE lines.
        claude_request_id: Id used for the Claude message envelope and logging.
        requested_model: Model name echoed back in the message_start event.
    """
    message_id = claude_request_id
    accumulated_content_len = 0  # total characters streamed; fallback token estimate
    openai_finish_reason = None
    input_tokens = 0  # Try to capture this
    output_tokens = 0
    last_ping_time = time.time()
    logger.debug(f"[{message_id}] Starting Claude SSE stream conversion.")
    # 1. Send message_start event
    yield f"event: message_start\ndata: {json.dumps({'type': 'message_start', 'message': {'id': message_id, 'type': 'message', 'role': 'assistant', 'content': [], 'model': requested_model, 'stop_reason': None, 'stop_sequence': None, 'usage': {'input_tokens': 0, 'output_tokens': 0}}})}\n\n" # Initial usage is 0
    # 2. Send content_block_start event
    yield f"event: content_block_start\ndata: {json.dumps({'type': 'content_block_start', 'index': 0, 'content_block': {'type': 'text', 'text': ''}})}\n\n"
    # 3. Send initial ping
    yield f"event: ping\ndata: {json.dumps({'type': 'ping'})}\n\n"
    try:
        async for line in openai_response.aiter_lines():
            line = line.strip()
            if not line: continue
            if line.startswith("data:"):
                # FIX: strip only the "data:" field name. The space after the
                # colon is optional in SSE, so slicing by len("data: ") would
                # eat the first payload character of a "data:{...}" line and
                # break both JSON decoding and the "[DONE]" check.
                data_str = line[len("data:"):].strip()
                if data_str == "[DONE]":
                    logger.debug(f"[{message_id}] Received [DONE] marker from OpenAI stream.")
                    break
                try:
                    data = json.loads(data_str)
                    choices = data.get("choices", [])
                    if not choices: continue
                    # --- Try to capture input tokens if sent early ---
                    usage_update = data.get("usage")
                    if usage_update and usage_update.get("prompt_tokens") is not None and input_tokens == 0:
                        input_tokens = usage_update.get("prompt_tokens", 0)
                        logger.debug(f"[{message_id}] Captured input_tokens: {input_tokens}")
                    # ---
                    delta = choices[0].get("delta", {})
                    content_chunk = delta.get("content")
                    if choices[0].get("finish_reason"):
                        openai_finish_reason = choices[0].get("finish_reason")
                        logger.debug(f"[{message_id}] Received OpenAI finish_reason: {openai_finish_reason}")
                        # Update output tokens based on usage update if available
                        if usage_update and usage_update.get("completion_tokens") is not None:
                            output_tokens = usage_update.get("completion_tokens", output_tokens)
                            logger.debug(f"[{message_id}] Received completion_tokens update: {output_tokens}")
                    if content_chunk:
                        accumulated_content_len += len(content_chunk)
                        # Estimate output tokens if not provided by usage update
                        if not (usage_update and usage_update.get("completion_tokens") is not None):
                            output_tokens += 1 # Simple increment per chunk as fallback
                        # 4. Send content_block_delta
                        yield f"event: content_block_delta\ndata: {json.dumps({'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': content_chunk}})}\n\n"
                except json.JSONDecodeError:
                    logger.warning(f"[{message_id}] Could not decode JSON from stream line: {data_str}")
                except Exception as e:
                    logger.error(f"[{message_id}] Error processing stream data chunk: {e}")
            # Keep intermediaries from timing out: ping at least every 10s.
            current_time = time.time()
            if current_time - last_ping_time >= 10:
                yield f"event: ping\ndata: {json.dumps({'type': 'ping'})}\n\n"
                last_ping_time = current_time
    except httpx.ReadTimeout:
        logger.error(f"[{message_id}] Timeout reading from OpenAI stream.")
        openai_finish_reason = "error_timeout"
        yield f"event: error\ndata: {json.dumps({'type': 'error', 'error': {'type': 'overloaded_error', 'message': 'Proxy timed out waiting for OpenAI stream'}})}\n\n"
    except Exception as e:
        logger.exception(f"[{message_id}] Unexpected error during stream processing: {e}")
        openai_finish_reason = "error_exception"
        yield f"event: error\ndata: {json.dumps({'type': 'error', 'error': {'type': 'internal_server_error', 'message': f'Proxy stream processing error: {e}'}})}\n\n"
    finally:
        # Map OpenAI finish reasons to Claude stop reasons ("tool_calls" is the
        # modern spelling of "function_call"); unknown -> "stop_sequence".
        stop_reason_map = {
            "stop": "end_turn", "length": "max_tokens",
            "function_call": "tool_use", "tool_calls": "tool_use",
            "content_filter": "stop_sequence", "null": "stop_sequence",
            "error_timeout": "error", "error_exception": "error",
        }
        claude_stop_reason = stop_reason_map.get(openai_finish_reason, "stop_sequence")
        logger.debug(f"[{message_id}] Stream finished. OpenAI finish reason: {openai_finish_reason}, mapped Claude stop reason: {claude_stop_reason}")
        # 5. Send content_block_stop
        yield f"event: content_block_stop\ndata: {json.dumps({'type': 'content_block_stop', 'index': 0})}\n\n"
        # 6. Send message_delta with final stop reason ONLY
        final_delta = {
            'type': 'message_delta',
            'delta': {
                'stop_reason': claude_stop_reason,
                'stop_sequence': None
            }
        }
        yield f"event: message_delta\ndata: {json.dumps(final_delta)}\n\n"
        # 7. Send message_stop (including final usage)
        final_stop_event_data = {
            'type': 'message_stop',
            'usage': {
                'input_tokens': input_tokens,
                'output_tokens': output_tokens if output_tokens > 0 else (accumulated_content_len // 4) # Use estimate if needed
            }
        }
        yield f"event: message_stop\ndata: {json.dumps(final_stop_event_data)}\n\n"
        logger.info(f"[{message_id}] Completed sending Claude SSE stream.")
def create_error_response(status_code: int, error_type: str, message: str) -> JSONResponse:
    """Build a Claude-style error envelope as a JSONResponse.

    The body shape mirrors Anthropic error responses:
    {"type": "error", "error": {"type": ..., "message": ...}}.
    """
    body = {"type": "error", "error": {"type": error_type, "message": message}}
    return JSONResponse(status_code=status_code, content=body)
# --- Main Proxy Endpoint ---
@app.post("/v1/messages", dependencies=[Depends(get_api_key)])
async def proxy_claude_to_openai(request: Request):
    """
    Receives a Claude-formatted request, proxies it to OpenAI,
    and returns a Claude-formatted response (JSON or SSE stream).
    Requires a valid API key via the X-API-Key header.
    """
    request_id = f"msg_{uuid.uuid4().hex[:24]}"
    # 1. Parse and convert the incoming Claude-format request body.
    try:
        claude_request_data = await request.json()
        logger.info(f"[{request_id}] Received request. Stream: {claude_request_data.get('stream', False)}. Model: {claude_request_data.get('model')}")
    except json.JSONDecodeError:
        logger.error(f"[{request_id}] Invalid JSON received in request body.")
        return create_error_response(400, "invalid_request_error", "Invalid JSON data in request body.")
    try:
        openai_payload = claude_request_to_openai_payload(claude_request_data)
    except Exception as e:
        logger.error(f"[{request_id}] Failed to convert Claude request to OpenAI format: {e}")
        return create_error_response(400, "invalid_request_error", f"Failed to process request data: {e}")
    is_streaming = openai_payload.get("stream", False)
    requested_model = openai_payload.get("model", "unknown_model")
    # 2. Build upstream headers; the bearer token is optional.
    target_headers = { "Content-Type": "application/json" }
    if OPENAI_API_KEY:
        target_headers["Authorization"] = f"Bearer {OPENAI_API_KEY}"
        logger.debug(f"[{request_id}] Added Authorization header to upstream request.")
    else:
        logger.debug(f"[{request_id}] No OPENAI_API_KEY configured for upstream request.")
    # 3. Verbose request logging at DEBUG, with the token redacted and the
    # payload truncated to keep log lines bounded.
    if logger.level("DEBUG").no >= logger.level(log_level).no:
        logged_headers = target_headers.copy()
        if "Authorization" in logged_headers: logged_headers["Authorization"] = "Bearer [REDACTED]"
        logger.debug(f"[{request_id}] Sending request to upstream API.")
        logger.debug(f"[{request_id}] Upstream Headers: {json.dumps(logged_headers)}")
        try:
            payload_str = json.dumps(openai_payload, indent=2)
            max_log_len = 1024
            logger.debug(f"[{request_id}] Upstream Payload {'(truncated)' if len(payload_str) > max_log_len else ''}: {payload_str[:max_log_len]}{'...' if len(payload_str) > max_log_len else ''}")
        except Exception as log_e:
            logger.warning(f"[{request_id}] Could not serialize or log upstream payload: {log_e}")
    else:
        logger.debug(f"[{request_id}] Sending request to upstream API...")
    # 4. Send upstream and translate the response.
    try:
        target_request = client.build_request("POST", OPENAI_API_ENDPOINT, headers=target_headers, json=openai_payload)
        response = await client.send(target_request, stream=is_streaming)
        response.raise_for_status()
        if is_streaming:
            logger.info(f"[{request_id}] Upstream response is streaming. Starting SSE conversion.")

            async def _relay_and_close() -> AsyncGenerator[str, None]:
                # FIX: make sure the upstream streaming response (and its pooled
                # connection) is released when the stream ends or the client
                # disconnects; previously it was never closed.
                try:
                    async for event in stream_openai_response_to_claude_events(response, request_id, requested_model):
                        yield event
                finally:
                    await response.aclose()

            return StreamingResponse(
                _relay_and_close(),
                media_type="text/event-stream",
                headers={"X-Content-Type-Options": "nosniff", "Cache-Control": "no-cache", "Connection": "keep-alive"}
            )
        else:
            logger.info(f"[{request_id}] Upstream response is non-streaming. Converting.")
            openai_response_data = response.json()
            logger.debug(f"[{request_id}] Received non-streaming response from upstream.")
            try:
                claude_response_data = openai_response_to_claude_response(openai_response_data, request_id)
                return JSONResponse(content=claude_response_data)
            except ValueError as e:
                logger.error(f"[{request_id}] Failed to convert upstream non-streaming response: {e}")
                return create_error_response(500, "api_error", f"Error processing response from upstream API: {e}")
            except Exception as e:
                logger.exception(f"[{request_id}] Unexpected error converting non-streaming response: {e}")
                return create_error_response(500, "internal_server_error", "Unexpected error processing upstream response.")
    except httpx.HTTPStatusError as e:
        status_code = e.response.status_code
        error_detail_text = "[Could not decode error response]"
        try:
            # FIX: when the upstream request was sent with stream=True the error
            # body has not been read yet, and .json()/.text would raise
            # httpx.ResponseNotRead (losing this mapped error). aread() loads
            # the body and is a no-op for already-read responses.
            await e.response.aread()
            error_detail = e.response.json(); error_detail_text = json.dumps(error_detail)
        except json.JSONDecodeError:
            try: error_detail_text = e.response.text
            except Exception: logger.warning(f"[{request_id}] Could not read error response body as text.")
        logger.error(f"[{request_id}] HTTP error from target endpoint ({status_code}). Response snippet: {error_detail_text[:200]}...")
        # Map upstream HTTP status codes onto Claude-style error types.
        if status_code == 400: err_type, msg = "invalid_request_error", f"Upstream API Bad Request ({status_code})."
        elif status_code == 401: err_type, msg = "authentication_error", f"Authentication failed with upstream API ({status_code})."
        elif status_code == 403: err_type, msg = "permission_error", f"Forbidden by upstream API ({status_code})."
        elif status_code == 429: err_type, msg = "rate_limit_error", f"Rate limit exceeded with upstream API ({status_code})."
        elif status_code >= 500: err_type, msg = "api_error", f"Upstream API unavailable or encountered an error ({status_code})."
        else: err_type, msg = "api_error", f"Received unexpected error from upstream API ({status_code})."
        return create_error_response(status_code, err_type, msg)
    except httpx.TimeoutException:
        logger.error(f"[{request_id}] Request to target endpoint timed out ({READ_TIMEOUT}s).")
        return create_error_response(504, "api_error", "Gateway Timeout: Request to upstream API timed out.")
    except httpx.RequestError as e:
        logger.error(f"[{request_id}] Network error connecting to target endpoint: {type(e).__name__}")
        return create_error_response(502, "api_error", f"Bad Gateway: Network error connecting to upstream API.")
    except Exception as e:
        logger.exception(f"[{request_id}] Unexpected error during proxy operation: {e}")
        return create_error_response(500, "internal_server_error", f"Internal Server Error: {e}")
# --- Health Check Endpoint ---
@app.get("/health", summary="Health Check", tags=["Management"])
async def health_check():
    """Liveness probe: reports that the proxy process is up and serving."""
    return dict(status="healthy")
# --- Run with Uvicorn (for local development) ---
if __name__ == "__main__":
    import uvicorn
    try:
        from dotenv import load_dotenv
        load_dotenv()
        logger.info("Loaded environment variables from .env file (if present).")
        # Re-read settings so values supplied via .env (loaded after the
        # module-level reads) take effect in this process.
        # NOTE(review): the timeout/pool settings (CONNECT_TIMEOUT etc.) are not
        # re-read here; with reload=True uvicorn re-imports the module in a
        # child process, so presumably they pick up .env values there — confirm.
        PROXY_API_KEYS_STR = os.getenv("PROXY_API_KEYS", "")
        VALID_API_KEYS = set(key.strip() for key in PROXY_API_KEYS_STR.split(',') if key.strip())
        OPENAI_API_ENDPOINT = os.getenv("OPENAI_API_ENDPOINT", "https://api.openai.com/v1/chat/completions")
        OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
        HTTP_PROXY = os.getenv("HTTP_PROXY")
        log_level = os.getenv("LOG_LEVEL", "INFO").upper()
        # Reinstall the log sink with the (possibly updated) level.
        logger.remove()
        logger.add(sys.stderr, level=log_level, format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{function}:{line} - {message}")
        logger.info(f"Log level set to: {log_level}")
        logger.info(f"Valid Proxy API Keys configured: {len(VALID_API_KEYS)}")
    except ImportError:
        logger.info("python-dotenv not installed, skipping .env file loading.")
    port = int(os.getenv("PORT", 7860))
    host = os.getenv("HOST", "0.0.0.0")
    # Uvicorn expects a lowercase level name; unknown values fall back to "info".
    log_config_level = log_level.lower() if log_level in ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "TRACE"] else "info"
    logger.info(f"Starting Uvicorn server on {host}:{port}")
    # NOTE(review): "proxy_server:app" assumes this file is named
    # proxy_server.py — confirm against the actual module/file name.
    uvicorn.run("proxy_server:app", host=host, port=port, reload=True, log_level=log_config_level)