"""Claude-to-OpenAI proxy server.

Accepts Anthropic-style ``POST /v1/messages`` requests, converts them to the
OpenAI Chat Completions format, forwards them to ``OPENAI_API_ENDPOINT`` and
converts the (streaming or non-streaming) reply back to the Claude format.
"""

import os
import sys
import json
import uuid
import time
import asyncio
import httpx
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request, HTTPException, Depends, Security
from fastapi.security.api_key import APIKeyHeader
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from typing import AsyncGenerator, Set, Optional, Dict, Any, List

# --- Logging Configuration ---
LOG_FORMAT = "{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{function}:{line} - {message}"


def _configure_logging(level: str) -> None:
    """Replace all loguru sinks with a single stderr sink at *level*."""
    logger.remove()
    logger.add(sys.stderr, level=level, format=LOG_FORMAT)


log_level = os.getenv("LOG_LEVEL", "INFO").upper()
_configure_logging(log_level)


# --- Environment Variable Configuration ---
def _parse_api_keys(raw: str) -> Set[str]:
    """Parse a comma-separated key list, ignoring surrounding whitespace and empty entries."""
    return {key.strip() for key in raw.split(",") if key.strip()}


OPENAI_API_ENDPOINT = os.getenv("OPENAI_API_ENDPOINT", "https://api.openai.com/v1/chat/completions")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
PROXY_API_KEYS_STR = os.getenv("PROXY_API_KEYS", "")
VALID_API_KEYS: Set[str] = _parse_api_keys(PROXY_API_KEYS_STR)
CONNECT_TIMEOUT = float(os.getenv("CONNECT_TIMEOUT", 5.0))
READ_TIMEOUT = float(os.getenv("READ_TIMEOUT", 180.0))
WRITE_TIMEOUT = float(os.getenv("WRITE_TIMEOUT", 30.0))
POOL_TIMEOUT = float(os.getenv("POOL_TIMEOUT", 5.0))
MAX_CONNECTIONS = int(os.getenv("MAX_CONNECTIONS", 100))
MAX_KEEPALIVE = int(os.getenv("MAX_KEEPALIVE", 20))
HTTP_PROXY = os.getenv("HTTP_PROXY")

# OpenAI ``finish_reason`` -> Claude ``stop_reason``. Call sites fall back to
# "stop_sequence" for values not listed here (including a missing/None reason).
_STOP_REASON_MAP: Dict[str, str] = {
    "stop": "end_turn",
    "length": "max_tokens",
    "function_call": "tool_use",
    "content_filter": "stop_sequence",
    "null": "stop_sequence",
}
# The streaming converter additionally records its own failure modes.
_STREAM_STOP_REASON_MAP: Dict[str, str] = {
    **_STOP_REASON_MAP,
    "error_timeout": "error",
    "error_exception": "error",
}

# Minimum number of seconds between keep-alive ``ping`` events while relaying a stream.
PING_INTERVAL_SECONDS = 10

# --- Global httpx Client ---
client: httpx.AsyncClient


def _build_async_client(limits: httpx.Limits, timeout_config: httpx.Timeout) -> httpx.AsyncClient:
    """Create the shared upstream client, routed through HTTP_PROXY when it is set.

    httpx deprecated the ``proxies`` argument in 0.26 and removed it in 0.28,
    where passing it at all (even ``proxies=None``) is a TypeError. So a proxy
    argument is only passed when a proxy is configured, preferring ``proxy``.
    """
    options: Dict[str, Any] = dict(limits=limits, timeout=timeout_config, http2=True, follow_redirects=True)
    if not HTTP_PROXY:
        return httpx.AsyncClient(**options)
    try:
        return httpx.AsyncClient(proxy=HTTP_PROXY, **options)
    except TypeError:
        # httpx < 0.26 only understands the per-scheme ``proxies`` mapping.
        return httpx.AsyncClient(proxies={"http://": HTTP_PROXY, "https://": HTTP_PROXY}, **options)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage the lifespan of the shared httpx client.

    The client is created before the application starts serving requests and
    is closed on shutdown.
    """
    global client
    limits = httpx.Limits(max_connections=MAX_CONNECTIONS, max_keepalive_connections=MAX_KEEPALIVE)
    timeout_config = httpx.Timeout(connect=CONNECT_TIMEOUT, read=READ_TIMEOUT, write=WRITE_TIMEOUT, pool=POOL_TIMEOUT)
    logger.info("Initializing httpx client for upstream requests.")
    if HTTP_PROXY:
        logger.info(f"Using outbound proxy: {HTTP_PROXY}")
    if not OPENAI_API_KEY:
        logger.warning("OPENAI_API_KEY is not set. Requests to the target endpoint might fail if it requires authentication.")
    if not VALID_API_KEYS:
        logger.warning("PROXY_API_KEYS is not set. The proxy endpoint will be open to anyone (NOT RECOMMENDED for production).")
    client = _build_async_client(limits, timeout_config)
    try:
        yield
    finally:
        logger.info("Closing httpx client.")
        await client.aclose()


# --- FastAPI Application Setup ---
app = FastAPI(
    title="Claude to OpenAI Proxy",
    description="A proxy server that translates Claude API requests to OpenAI API format and back.",
    version="1.0.0",
    lifespan=lifespan,
)

# --- CORS Middleware ---
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; confirm this is intended for the deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# --- API Key Authentication ---
API_KEY_NAME = "X-API-Key"
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)


async def get_api_key(key: Optional[str] = Security(api_key_header)) -> str:
    """Validate the proxy API key sent in the ``X-API-Key`` header.

    When no PROXY_API_KEYS are configured every request is allowed and a
    placeholder key is returned.

    Raises:
        HTTPException: 401 if the header is missing or the key is not configured.
    """
    if not VALID_API_KEYS:
        logger.warning("No PROXY_API_KEYS configured. Allowing request.")
        return "unsecured_dummy_key"
    if key is None:
        logger.warning("API key missing from request header.")
        raise HTTPException(status_code=401, detail=f"API Key required in header '{API_KEY_NAME}'")
    if key not in VALID_API_KEYS:
        # Only the length is logged so the rejected key never lands in logs.
        logger.warning(f"Invalid API key received (length: {len(key)}).")
        raise HTTPException(status_code=401, detail="Invalid or expired API Key")
    logger.debug(f"Valid API key received (length: {len(key)}).")
    return key


# --- Format Conversion Logic ---
def claude_request_to_openai_payload(claude_request: Dict[str, Any]) -> Dict[str, Any]:
    """Convert a Claude Messages API request body to an OpenAI chat payload.

    Only text is carried over: ``system`` (a string or a list of text blocks)
    becomes a leading system message, and each message's text blocks are
    joined with newlines. Non-text blocks are dropped, and messages left with
    no text (or no role) are skipped. ``max_tokens``, ``temperature``,
    ``top_p`` and ``stop_sequences`` (as ``stop``) are forwarded only when set.

    Malformed input (e.g. a content block that is not a dict) raises
    AttributeError/TypeError; the endpoint turns any exception into a 400.
    """
    messages: List[Dict[str, Any]] = []

    system_prompt = claude_request.get("system")
    if system_prompt:
        system_content = ""
        if isinstance(system_prompt, list):
            system_content = "\n".join(block.get("text", "") for block in system_prompt if block.get("type") == "text")
        elif isinstance(system_prompt, str):
            system_content = system_prompt
        if system_content:
            messages.append({"role": "system", "content": system_content})

    for msg in claude_request.get("messages", []):
        role = msg.get("role")
        content = msg.get("content")
        content_parts: List[str] = []
        if isinstance(content, list):
            content_parts = [block.get("text", "") for block in content if block.get("type") == "text"]
        elif isinstance(content, str):
            content_parts = [content]
        if role and content_parts:
            messages.append({"role": role, "content": "\n".join(content_parts)})

    openai_payload: Dict[str, Any] = {
        "model": claude_request.get("model", "gpt-3.5-turbo"),
        "messages": messages,
        "stream": claude_request.get("stream", False),
    }
    optional_params = (
        ("max_tokens", "max_tokens"),
        ("temperature", "temperature"),
        ("top_p", "top_p"),
        ("stop_sequences", "stop"),
    )
    for claude_key, openai_key in optional_params:
        value = claude_request.get(claude_key)
        if value is not None:
            openai_payload[openai_key] = value
    return openai_payload


def openai_response_to_claude_response(openai_response: Dict[str, Any], claude_request_id: str) -> Dict[str, Any]:
    """Convert a non-streaming OpenAI chat completion to a Claude message body.

    Only the first choice is used; its text becomes a single text content block.

    Raises:
        ValueError: if the OpenAI response does not have the expected structure.
    """
    try:
        choice = openai_response.get("choices", [{}])[0]
        # ``or {}`` tolerates explicit nulls some OpenAI-compatible servers send.
        message = choice.get("message") or {}
        content = message.get("content", "")
        role = message.get("role", "assistant")
        finish_reason = choice.get("finish_reason", "stop")
        claude_stop_reason = _STOP_REASON_MAP.get(finish_reason, "stop_sequence")

        usage = openai_response.get("usage") or {}
        claude_response = {
            "id": openai_response.get("id", claude_request_id),
            "type": "message",
            "role": role,
            "content": [{"type": "text", "text": content or ""}],
            "model": openai_response.get("model", "claude-proxy-model"),
            "stop_reason": claude_stop_reason,
            "stop_sequence": None,
            "usage": {
                "input_tokens": usage.get("prompt_tokens", 0),
                "output_tokens": usage.get("completion_tokens", 0),
            },
        }
        logger.debug(f"[{claude_request_id}] Converted non-streaming OpenAI response to Claude format.")
        return claude_response
    except (KeyError, IndexError, TypeError, AttributeError) as e:
        logger.error(f"[{claude_request_id}] Error converting non-streaming OpenAI response: {e}")
        raise ValueError(f"Failed to parse OpenAI response: {e}") from e


def _sse(event: str, payload: Dict[str, Any]) -> str:
    """Format one server-sent event whose ``data`` line is *payload* as JSON."""
    return f"event: {event}\ndata: {json.dumps(payload)}\n\n"


async def stream_openai_response_to_claude_events(openai_response: httpx.Response, claude_request_id: str, requested_model: str) -> AsyncGenerator[str, None]:
    """Convert an OpenAI chat-completion SSE stream into Claude SSE events.

    Emits ``message_start``, ``content_block_start`` and a ``ping`` first. It
    then emits one ``content_block_delta`` per non-empty OpenAI content chunk,
    plus a ``ping`` at most every PING_INTERVAL_SECONDS while lines arrive.
    It finishes with ``content_block_stop``, ``message_delta`` and
    ``message_stop``.

    An upstream read failure is reported as an ``error`` event before the
    closing events. The upstream response is always closed when the generator
    ends, including when the downstream client disconnects.

    Args:
        openai_response: upstream response opened with ``stream=True``.
        claude_request_id: id used as the Claude message id and log prefix.
        requested_model: model name echoed in ``message_start``.
    """
    message_id = claude_request_id
    accumulated_content_len = 0
    openai_finish_reason: Optional[str] = None
    input_tokens = 0
    output_tokens = 0
    last_ping_time = time.monotonic()
    logger.debug(f"[{message_id}] Starting Claude SSE stream conversion.")

    try:
        # 1. message_start (usage is not known yet, so it is reported as 0)
        yield _sse("message_start", {"type": "message_start", "message": {"id": message_id, "type": "message", "role": "assistant", "content": [], "model": requested_model, "stop_reason": None, "stop_sequence": None, "usage": {"input_tokens": 0, "output_tokens": 0}}})
        # 2. content_block_start for the single text block this proxy produces
        yield _sse("content_block_start", {"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}})
        # 3. initial ping
        yield _sse("ping", {"type": "ping"})

        try:
            async for raw_line in openai_response.aiter_lines():
                line = raw_line.strip()
                if not line:
                    continue
                if line.startswith("data:"):
                    # The space after "data:" is optional in SSE, so strip it instead of slicing it off.
                    data_str = line[len("data:"):].strip()
                    if data_str == "[DONE]":
                        logger.debug(f"[{message_id}] Received [DONE] marker from OpenAI stream.")
                        break
                    try:
                        data = json.loads(data_str)

                        # Read usage before the empty-choices check: the final usage chunk
                        # OpenAI sends with stream_options.include_usage has ``choices: []``.
                        usage_update = data.get("usage") or {}
                        if input_tokens == 0 and usage_update.get("prompt_tokens") is not None:
                            input_tokens = usage_update["prompt_tokens"]
                            logger.debug(f"[{message_id}] Captured input_tokens: {input_tokens}")
                        has_completion_usage = usage_update.get("completion_tokens") is not None
                        if has_completion_usage:
                            output_tokens = usage_update["completion_tokens"]
                            logger.debug(f"[{message_id}] Received completion_tokens update: {output_tokens}")

                        choices = data.get("choices") or []
                        if not choices:
                            continue
                        choice = choices[0]
                        if choice.get("finish_reason"):
                            openai_finish_reason = choice["finish_reason"]
                            logger.debug(f"[{message_id}] Received OpenAI finish_reason: {openai_finish_reason}")

                        content_chunk = (choice.get("delta") or {}).get("content")
                        if content_chunk:
                            accumulated_content_len += len(content_chunk)
                            if not has_completion_usage:
                                output_tokens += 1  # rough fallback estimate: one token per chunk
                            # 4. content_block_delta
                            yield _sse("content_block_delta", {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": content_chunk}})
                    except json.JSONDecodeError:
                        logger.warning(f"[{message_id}] Could not decode JSON from stream line: {data_str}")
                    except Exception as e:
                        logger.error(f"[{message_id}] Error processing stream data chunk: {e}")

                now = time.monotonic()
                if now - last_ping_time >= PING_INTERVAL_SECONDS:
                    yield _sse("ping", {"type": "ping"})
                    last_ping_time = now
        except httpx.ReadTimeout:
            logger.error(f"[{message_id}] Timeout reading from OpenAI stream.")
            openai_finish_reason = "error_timeout"
            yield _sse("error", {"type": "error", "error": {"type": "overloaded_error", "message": "Proxy timed out waiting for OpenAI stream"}})
        except Exception as e:
            logger.exception(f"[{message_id}] Unexpected error during stream processing: {e}")
            openai_finish_reason = "error_exception"
            yield _sse("error", {"type": "error", "error": {"type": "internal_server_error", "message": f"Proxy stream processing error: {e}"}})

        # The closing events are deliberately NOT in ``finally``: yielding while a
        # client disconnect / cancellation is unwinding the generator either
        # swallows the cancellation or raises "async generator ignored GeneratorExit".
        claude_stop_reason = _STREAM_STOP_REASON_MAP.get(openai_finish_reason, "stop_sequence")
        logger.debug(f"[{message_id}] Stream finished. OpenAI finish reason: {openai_finish_reason}, mapped Claude stop reason: {claude_stop_reason}")
        # Fall back to a ~4-characters-per-token estimate when no count is available.
        final_output_tokens = output_tokens if output_tokens > 0 else accumulated_content_len // 4

        # 5. content_block_stop
        yield _sse("content_block_stop", {"type": "content_block_stop", "index": 0})
        # 6. message_delta carries the stop reason and the final output token count.
        yield _sse("message_delta", {"type": "message_delta", "delta": {"stop_reason": claude_stop_reason, "stop_sequence": None}, "usage": {"output_tokens": final_output_tokens}})
        # 7. message_stop; usage is kept here too for clients that read it from this event.
        yield _sse("message_stop", {"type": "message_stop", "usage": {"input_tokens": input_tokens, "output_tokens": final_output_tokens}})
        logger.info(f"[{message_id}] Completed sending Claude SSE stream.")
    finally:
        # Streamed httpx responses must be closed explicitly to release their
        # connection, e.g. after ``break`` on [DONE] or a client disconnect.
        await openai_response.aclose()


def create_error_response(status_code: int, error_type: str, message: str) -> JSONResponse:
    """Create a JSONResponse with a Claude-style ``{"type": "error", ...}`` body."""
    return JSONResponse(
        status_code=status_code,
        content={"type": "error", "error": {"type": error_type, "message": message}},
    )


def _log_upstream_request(request_id: str, headers: Dict[str, str], payload: Dict[str, Any]) -> None:
    """Log the outgoing upstream request; headers and payload only at DEBUG level.

    The Authorization header is redacted and the payload is truncated to 1024 characters.
    """
    if logger.level("DEBUG").no < logger.level(log_level).no:
        logger.debug(f"[{request_id}] Sending request to upstream API...")
        return
    logged_headers = headers.copy()
    if "Authorization" in logged_headers:
        logged_headers["Authorization"] = "Bearer [REDACTED]"
    logger.debug(f"[{request_id}] Sending request to upstream API.")
    logger.debug(f"[{request_id}] Upstream Headers: {json.dumps(logged_headers)}")
    try:
        payload_str = json.dumps(payload, indent=2)
        max_log_len = 1024
        truncated = len(payload_str) > max_log_len
        logger.debug(f"[{request_id}] Upstream Payload {'(truncated)' if truncated else ''}: {payload_str[:max_log_len]}{'...' if truncated else ''}")
    except Exception as log_e:
        logger.warning(f"[{request_id}] Could not serialize or log upstream payload: {log_e}")


# --- Main Proxy Endpoint ---
@app.post("/v1/messages", dependencies=[Depends(get_api_key)])
async def proxy_claude_to_openai(request: Request):
    """
    Receives a Claude-formatted request, proxies it to OpenAI, and returns a Claude-formatted response.
    Requires a valid API key via the X-API-Key header.

    Upstream HTTP errors are mapped to Claude-style error bodies with the
    upstream status code; timeouts map to 504, network errors to 502.
    """
    request_id = f"msg_{uuid.uuid4().hex[:24]}"
    try:
        claude_request_data = await request.json()
    except ValueError:  # json.JSONDecodeError, or UnicodeDecodeError for a non-UTF body
        logger.error(f"[{request_id}] Invalid JSON received in request body.")
        return create_error_response(400, "invalid_request_error", "Invalid JSON data in request body.")
    if not isinstance(claude_request_data, dict):
        # Valid JSON that is not an object (e.g. a list) would otherwise crash on ``.get``.
        logger.error(f"[{request_id}] Request body is not a JSON object.")
        return create_error_response(400, "invalid_request_error", "Request body must be a JSON object.")
    logger.info(f"[{request_id}] Received request. Stream: {claude_request_data.get('stream', False)}. Model: {claude_request_data.get('model')}")

    try:
        openai_payload = claude_request_to_openai_payload(claude_request_data)
    except Exception as e:
        logger.error(f"[{request_id}] Failed to convert Claude request to OpenAI format: {e}")
        return create_error_response(400, "invalid_request_error", f"Failed to process request data: {e}")

    is_streaming = openai_payload.get("stream", False)
    requested_model = openai_payload.get("model", "unknown_model")

    target_headers = {"Content-Type": "application/json"}
    if OPENAI_API_KEY:
        target_headers["Authorization"] = f"Bearer {OPENAI_API_KEY}"
        logger.debug(f"[{request_id}] Added Authorization header to upstream request.")
    else:
        logger.debug(f"[{request_id}] No OPENAI_API_KEY configured for upstream request.")
    _log_upstream_request(request_id, target_headers, openai_payload)

    try:
        target_request = client.build_request("POST", OPENAI_API_ENDPOINT, headers=target_headers, json=openai_payload)
        response = await client.send(target_request, stream=is_streaming)
        if not response.is_success:
            # A streamed body has not been read yet; read it so the
            # HTTPStatusError handler can inspect it (this also releases the connection).
            await response.aread()
        response.raise_for_status()

        if is_streaming:
            logger.info(f"[{request_id}] Upstream response is streaming. Starting SSE conversion.")
            return StreamingResponse(
                stream_openai_response_to_claude_events(response, request_id, requested_model),
                media_type="text/event-stream",
                headers={"X-Content-Type-Options": "nosniff", "Cache-Control": "no-cache", "Connection": "keep-alive"},
            )

        logger.info(f"[{request_id}] Upstream response is non-streaming. Converting.")
        openai_response_data = response.json()
        logger.debug(f"[{request_id}] Received non-streaming response from upstream.")
        try:
            claude_response_data = openai_response_to_claude_response(openai_response_data, request_id)
            return JSONResponse(content=claude_response_data)
        except ValueError as e:
            logger.error(f"[{request_id}] Failed to convert upstream non-streaming response: {e}")
            return create_error_response(500, "api_error", f"Error processing response from upstream API: {e}")
        except Exception as e:
            logger.exception(f"[{request_id}] Unexpected error converting non-streaming response: {e}")
            return create_error_response(500, "internal_server_error", "Unexpected error processing upstream response.")

    except httpx.HTTPStatusError as e:
        status_code = e.response.status_code
        error_detail_text = "[Could not decode error response]"
        try:
            error_detail_text = json.dumps(e.response.json())
        except ValueError:  # body is not JSON (JSONDecodeError) or not decodable text
            try:
                error_detail_text = e.response.text
            except Exception:
                logger.warning(f"[{request_id}] Could not read error response body as text.")
        logger.error(f"[{request_id}] HTTP error from target endpoint ({status_code}). Response snippet: {error_detail_text[:200]}...")

        if status_code == 400:
            err_type, msg = "invalid_request_error", f"Upstream API Bad Request ({status_code})."
        elif status_code == 401:
            err_type, msg = "authentication_error", f"Authentication failed with upstream API ({status_code})."
        elif status_code == 403:
            err_type, msg = "permission_error", f"Forbidden by upstream API ({status_code})."
        elif status_code == 429:
            err_type, msg = "rate_limit_error", f"Rate limit exceeded with upstream API ({status_code})."
        elif status_code >= 500:
            err_type, msg = "api_error", f"Upstream API unavailable or encountered an error ({status_code})."
        else:
            err_type, msg = "api_error", f"Received unexpected error from upstream API ({status_code})."
        return create_error_response(status_code, err_type, msg)
    except httpx.TimeoutException:
        logger.error(f"[{request_id}] Request to target endpoint timed out ({READ_TIMEOUT}s).")
        return create_error_response(504, "api_error", "Gateway Timeout: Request to upstream API timed out.")
    except httpx.RequestError as e:
        logger.error(f"[{request_id}] Network error connecting to target endpoint: {type(e).__name__}")
        return create_error_response(502, "api_error", "Bad Gateway: Network error connecting to upstream API.")
    except Exception as e:
        logger.exception(f"[{request_id}] Unexpected error during proxy operation: {e}")
        return create_error_response(500, "internal_server_error", f"Internal Server Error: {e}")


# --- Health Check Endpoint ---
@app.get("/health", summary="Health Check", tags=["Management"])
async def health_check():
    """Return the operational status of the proxy server."""
    return {"status": "healthy"}


# --- Run with Uvicorn (for local development) ---
if __name__ == "__main__":
    import uvicorn

    try:
        from dotenv import load_dotenv

        load_dotenv()
        logger.info("Loaded environment variables from .env file (if present).")
        # Refresh settings that .env may have provided. These rebind names in
        # ``__main__`` only; the module uvicorn imports below reads the
        # environment again at import time.
        PROXY_API_KEYS_STR = os.getenv("PROXY_API_KEYS", "")
        VALID_API_KEYS = _parse_api_keys(PROXY_API_KEYS_STR)
        OPENAI_API_ENDPOINT = os.getenv("OPENAI_API_ENDPOINT", "https://api.openai.com/v1/chat/completions")
        OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
        HTTP_PROXY = os.getenv("HTTP_PROXY")
        log_level = os.getenv("LOG_LEVEL", "INFO").upper()
        _configure_logging(log_level)
        logger.info(f"Log level set to: {log_level}")
        logger.info(f"Valid Proxy API Keys configured: {len(VALID_API_KEYS)}")
    except ImportError:
        logger.info("python-dotenv not installed, skipping .env file loading.")

    port = int(os.getenv("PORT", 7860))
    host = os.getenv("HOST", "0.0.0.0")
    log_config_level = log_level.lower() if log_level in ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "TRACE"] else "info"
    # reload=True requires an import string rather than the app object; derive
    # the module name from this file so renaming it does not break startup.
    app_module = os.path.splitext(os.path.basename(__file__))[0]
    logger.info(f"Starting Uvicorn server on {host}:{port}")
    uvicorn.run(f"{app_module}:app", host=host, port=port, reload=True, log_level=log_config_level)