# NOTE: The "Spaces: / Paused / Paused" header lines below this comment's
# position were Hugging Face Spaces UI residue captured with the file, not
# source code; they are preserved here as a comment so the module parses.
| import os | |
| import sys | |
| import json | |
| import uuid | |
| import time | |
| import asyncio | |
| import httpx | |
| from contextlib import asynccontextmanager | |
| from fastapi import FastAPI, Request, HTTPException, Depends, Security | |
| from fastapi.security.api_key import APIKeyHeader | |
| from fastapi.responses import StreamingResponse, JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from loguru import logger | |
| from typing import AsyncGenerator, Set, Optional, Dict, Any, List | |
# --- Logging Configuration ---
# Replace loguru's default handler with one whose minimum level is driven by
# the LOG_LEVEL environment variable (default INFO).
logger.remove()
log_level = os.getenv("LOG_LEVEL", "INFO").upper()
logger.add(sys.stderr, level=log_level, format="<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>")
# --- Environment Variable Configuration ---
# Upstream chat-completions endpoint and its credential.
OPENAI_API_ENDPOINT = os.getenv("OPENAI_API_ENDPOINT", "https://api.openai.com/v1/chat/completions")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
# PROXY_API_KEYS is a comma-separated list of keys clients must present in
# the X-API-Key header; blank entries are discarded.
PROXY_API_KEYS_STR = os.getenv("PROXY_API_KEYS", "")
VALID_API_KEYS: Set[str] = set(key.strip() for key in PROXY_API_KEYS_STR.split(',') if key.strip())
# httpx timeout tuning (seconds) and connection-pool sizing, consumed by the
# lifespan handler when the shared client is built.
CONNECT_TIMEOUT = float(os.getenv("CONNECT_TIMEOUT", 5.0))
READ_TIMEOUT = float(os.getenv("READ_TIMEOUT", 180.0))
WRITE_TIMEOUT = float(os.getenv("WRITE_TIMEOUT", 30.0))
POOL_TIMEOUT = float(os.getenv("POOL_TIMEOUT", 5.0))
MAX_CONNECTIONS = int(os.getenv("MAX_CONNECTIONS", 100))
MAX_KEEPALIVE = int(os.getenv("MAX_KEEPALIVE", 20))
# Optional outbound proxy for requests to the upstream API.
HTTP_PROXY = os.getenv("HTTP_PROXY")
# --- Global httpx Client ---
# Single shared AsyncClient for all upstream requests; created on startup and
# closed on shutdown by the lifespan handler below.
client: httpx.AsyncClient

@asynccontextmanager  # FIX: FastAPI's `lifespan=` expects an async context
# manager factory; the decorator (imported above but previously unused) is
# required for this async generator to be usable as one.
async def lifespan(app: FastAPI):
    """Manage the lifespan of the shared httpx client.

    On startup: builds the pooled, HTTP/2-capable client from the env-driven
    timeout/pool settings and logs configuration warnings. On shutdown:
    closes the client.
    """
    global client
    limits = httpx.Limits(max_connections=MAX_CONNECTIONS, max_keepalive_connections=MAX_KEEPALIVE)
    timeout_config = httpx.Timeout(connect=CONNECT_TIMEOUT, read=READ_TIMEOUT, write=WRITE_TIMEOUT, pool=POOL_TIMEOUT)
    # Route both plain and TLS traffic through the same outbound proxy if set.
    proxy_config = {"http://": HTTP_PROXY, "https://": HTTP_PROXY} if HTTP_PROXY else None
    logger.info("Initializing httpx client for upstream requests.")
    if proxy_config:
        logger.info(f"Using outbound proxy: {HTTP_PROXY}")
    if not OPENAI_API_KEY:
        logger.warning("OPENAI_API_KEY is not set. Requests to the target endpoint might fail if it requires authentication.")
    if not VALID_API_KEYS:
        logger.warning("PROXY_API_KEYS is not set. The proxy endpoint will be open to anyone (NOT RECOMMENDED for production).")
    client = httpx.AsyncClient(
        limits=limits,
        timeout=timeout_config,
        proxies=proxy_config,
        http2=True,
        follow_redirects=True
    )
    yield  # Application runs while suspended here.
    logger.info("Closing httpx client.")
    await client.aclose()
# --- FastAPI Application Setup ---
app = FastAPI(
    lifespan=lifespan,
    title="Claude to OpenAI Proxy",
    version="1.0.0",
    description="A proxy server that translates Claude API requests to OpenAI API format and back.",
)

# --- CORS Middleware ---
# Wide-open CORS: any origin, method, and header may reach the proxy.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# --- API Key Authentication ---
API_KEY_NAME = "X-API-Key"
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)

async def get_api_key(key: Optional[str] = Security(api_key_header)) -> str:
    """Validate the API key supplied in the X-API-Key request header.

    Returns the validated key. When no PROXY_API_KEYS are configured the
    proxy runs unauthenticated and a placeholder string is returned instead.
    Raises HTTPException(401) when the key is missing or not recognised.
    """
    # Auth effectively disabled when no keys were configured at startup.
    if not VALID_API_KEYS:
        logger.warning("No PROXY_API_KEYS configured. Allowing request.")
        return "unsecured_dummy_key"
    if key is None:
        logger.warning("API key missing from request header.")
        raise HTTPException(status_code=401, detail=f"API Key required in header '{API_KEY_NAME}'")
    # Success path first; anything else is rejected below.
    if key in VALID_API_KEYS:
        logger.debug(f"Valid API key received (length: {len(key)}).")
        return key
    logger.warning(f"Invalid API key received (length: {len(key)}).")
    raise HTTPException(status_code=401, detail="Invalid or expired API Key")
# --- Format Conversion Logic ---
def claude_request_to_openai_payload(claude_request: Dict[str, Any]) -> Dict[str, Any]:
    """Converts a Claude API request body to OpenAI API format.

    The Claude "system" field (string or list of text blocks) becomes a
    single system message; each conversation turn's text blocks are joined
    with newlines. Only text content is carried over. Optional sampling
    parameters are copied when present ("stop_sequences" maps to "stop").
    """
    converted: List[Dict[str, Any]] = []

    # System prompt: Claude accepts either a plain string or a list of
    # {"type": "text", ...} blocks.
    system_prompt = claude_request.get("system")
    if isinstance(system_prompt, list):
        system_text = "\n".join(block.get("text", "") for block in system_prompt if block.get("type") == "text")
    elif isinstance(system_prompt, str):
        system_text = system_prompt
    else:
        system_text = ""
    if system_text:
        converted.append({"role": "system", "content": system_text})

    # Conversation turns: flatten the text blocks of each message into one
    # string; messages with no role or no text parts are dropped.
    for message in claude_request.get("messages", []):
        content = message.get("content")
        if isinstance(content, list):
            parts = [block.get("text", "") for block in content if block.get("type") == "text"]
        elif isinstance(content, str):
            parts = [content]
        else:
            parts = []
        role = message.get("role")
        if role and parts:
            converted.append({"role": role, "content": "\n".join(parts)})

    payload: Dict[str, Any] = {
        "model": claude_request.get("model", "gpt-3.5-turbo"),
        "messages": converted,
        "stream": claude_request.get("stream", False),
    }
    # Optional parameters: include only the ones the caller actually sent.
    for claude_key, openai_key in (
        ("max_tokens", "max_tokens"),
        ("temperature", "temperature"),
        ("top_p", "top_p"),
        ("stop_sequences", "stop"),
    ):
        value = claude_request.get(claude_key)
        if value is not None:
            payload[openai_key] = value
    return payload
def openai_response_to_claude_response(openai_response: Dict[str, Any], claude_request_id: str) -> Dict[str, Any]:
    """Converts a non-streaming OpenAI response to Claude API format.

    Args:
        openai_response: Parsed JSON body of a chat-completions response.
        claude_request_id: Fallback message id if the upstream omits one.

    Raises:
        ValueError: if the OpenAI payload is missing the expected structure.
    """
    try:
        choice = openai_response.get("choices", [{}])[0]
        message = choice.get("message", {})
        content = message.get("content", "")
        role = message.get("role", "assistant")
        finish_reason = choice.get("finish_reason", "stop")
        # Map OpenAI finish reasons onto Claude stop reasons.
        # FIX: added "tool_calls" — the current OpenAI reason for tool
        # invocation ("function_call" is its deprecated predecessor); it
        # previously fell through to the "stop_sequence" default.
        stop_reason_map = {
            "stop": "end_turn", "length": "max_tokens",
            "tool_calls": "tool_use", "function_call": "tool_use",
            "content_filter": "stop_sequence", "null": "stop_sequence",
        }
        claude_stop_reason = stop_reason_map.get(finish_reason, "stop_sequence")
        usage = openai_response.get("usage", {})
        prompt_tokens = usage.get("prompt_tokens", 0)
        completion_tokens = usage.get("completion_tokens", 0)
        claude_response = {
            "id": openai_response.get("id", claude_request_id),
            "type": "message", "role": role,
            "content": [{"type": "text", "text": content or ""}],
            "model": openai_response.get("model", "claude-proxy-model"),
            "stop_reason": claude_stop_reason, "stop_sequence": None,
            "usage": { "input_tokens": prompt_tokens, "output_tokens": completion_tokens },
        }
        logger.debug(f"[{claude_request_id}] Converted non-streaming OpenAI response to Claude format.")
        return claude_response
    except (KeyError, IndexError, TypeError) as e:
        logger.error(f"[{claude_request_id}] Error converting non-streaming OpenAI response: {e}")
        raise ValueError(f"Failed to parse OpenAI response: {e}")
async def stream_openai_response_to_claude_events(openai_response: httpx.Response, claude_request_id: str, requested_model: str) -> AsyncGenerator[str, None]:
    """Converts an OpenAI SSE stream to Claude API SSE format.

    Emits the Claude event sequence: message_start, content_block_start,
    ping(s), content_block_delta per text chunk, then content_block_stop,
    message_delta (stop reason) and message_stop (usage) — the last three
    are always sent from the finally block, even after an error.
    """
    message_id = claude_request_id
    accumulated_content_len = 0
    openai_finish_reason = None
    input_tokens = 0  # Captured from a usage payload if the upstream sends one
    output_tokens = 0
    last_ping_time = time.time()
    logger.debug(f"[{message_id}] Starting Claude SSE stream conversion.")
    # 1. Send message_start event
    yield f"event: message_start\ndata: {json.dumps({'type': 'message_start', 'message': {'id': message_id, 'type': 'message', 'role': 'assistant', 'content': [], 'model': requested_model, 'stop_reason': None, 'stop_sequence': None, 'usage': {'input_tokens': 0, 'output_tokens': 0}}})}\n\n" # Initial usage is 0
    # 2. Send content_block_start event
    yield f"event: content_block_start\ndata: {json.dumps({'type': 'content_block_start', 'index': 0, 'content_block': {'type': 'text', 'text': ''}})}\n\n"
    # 3. Send initial ping
    yield f"event: ping\ndata: {json.dumps({'type': 'ping'})}\n\n"
    try:
        async for line in openai_response.aiter_lines():
            line = line.strip()
            if not line:
                continue
            if line.startswith("data:"):
                # FIX: strip exactly the "data:" prefix (5 chars). The old code
                # sliced len("data: ") == 6 characters, which dropped the first
                # byte of the JSON payload whenever the upstream omitted the
                # space after the colon (which is valid SSE).
                data_str = line[len("data:"):].strip()
                if data_str == "[DONE]":
                    logger.debug(f"[{message_id}] Received [DONE] marker from OpenAI stream.")
                    break
                try:
                    data = json.loads(data_str)
                    # FIX: inspect usage BEFORE the empty-choices guard. With
                    # stream_options.include_usage, OpenAI sends a final chunk
                    # whose "choices" list is empty but which carries the token
                    # totals; the old order skipped that chunk entirely.
                    usage_update = data.get("usage")
                    if usage_update and usage_update.get("prompt_tokens") is not None and input_tokens == 0:
                        input_tokens = usage_update.get("prompt_tokens", 0)
                        logger.debug(f"[{message_id}] Captured input_tokens: {input_tokens}")
                    if usage_update and usage_update.get("completion_tokens") is not None:
                        output_tokens = usage_update.get("completion_tokens", output_tokens)
                        logger.debug(f"[{message_id}] Received completion_tokens update: {output_tokens}")
                    choices = data.get("choices", [])
                    if not choices:
                        continue
                    delta = choices[0].get("delta", {})
                    content_chunk = delta.get("content")
                    if choices[0].get("finish_reason"):
                        openai_finish_reason = choices[0].get("finish_reason")
                        logger.debug(f"[{message_id}] Received OpenAI finish_reason: {openai_finish_reason}")
                    if content_chunk:
                        accumulated_content_len += len(content_chunk)
                        # Estimate output tokens if not provided by usage update
                        if not (usage_update and usage_update.get("completion_tokens") is not None):
                            output_tokens += 1 # Simple increment per chunk as fallback
                        # 4. Send content_block_delta
                        yield f"event: content_block_delta\ndata: {json.dumps({'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': content_chunk}})}\n\n"
                except json.JSONDecodeError:
                    logger.warning(f"[{message_id}] Could not decode JSON from stream line: {data_str}")
                except Exception as e:
                    logger.error(f"[{message_id}] Error processing stream data chunk: {e}")
            # Keep the client connection alive with periodic pings.
            current_time = time.time()
            if current_time - last_ping_time >= 10:
                yield f"event: ping\ndata: {json.dumps({'type': 'ping'})}\n\n"
                last_ping_time = current_time
    except httpx.ReadTimeout:
        logger.error(f"[{message_id}] Timeout reading from OpenAI stream.")
        openai_finish_reason = "error_timeout"
        yield f"event: error\ndata: {json.dumps({'type': 'error', 'error': {'type': 'overloaded_error', 'message': 'Proxy timed out waiting for OpenAI stream'}})}\n\n"
    except Exception as e:
        logger.exception(f"[{message_id}] Unexpected error during stream processing: {e}")
        openai_finish_reason = "error_exception"
        yield f"event: error\ndata: {json.dumps({'type': 'error', 'error': {'type': 'internal_server_error', 'message': f'Proxy stream processing error: {e}'}})}\n\n"
    finally:
        # FIX: added "tool_calls" (current OpenAI tool-use finish reason;
        # "function_call" is the deprecated form) — it previously fell
        # through to the "stop_sequence" default.
        stop_reason_map = {
            "stop": "end_turn", "length": "max_tokens",
            "tool_calls": "tool_use", "function_call": "tool_use",
            "content_filter": "stop_sequence", "null": "stop_sequence",
            "error_timeout": "error", "error_exception": "error",
        }
        claude_stop_reason = stop_reason_map.get(openai_finish_reason, "stop_sequence")
        logger.debug(f"[{message_id}] Stream finished. OpenAI finish reason: {openai_finish_reason}, mapped Claude stop reason: {claude_stop_reason}")
        # 5. Send content_block_stop
        yield f"event: content_block_stop\ndata: {json.dumps({'type': 'content_block_stop', 'index': 0})}\n\n"
        # 6. Send message_delta with final stop reason ONLY
        final_delta = {
            'type': 'message_delta',
            'delta': {
                'stop_reason': claude_stop_reason,
                'stop_sequence': None
            }
        }
        yield f"event: message_delta\ndata: {json.dumps(final_delta)}\n\n"
        # 7. Send message_stop (including final usage). When the upstream never
        # reported completion_tokens, fall back to a rough chars/4 estimate.
        final_stop_event_data = {
            'type': 'message_stop',
            'usage': {
                'input_tokens': input_tokens,
                'output_tokens': output_tokens if output_tokens > 0 else (accumulated_content_len // 4) # Use estimate if needed
            }
        }
        yield f"event: message_stop\ndata: {json.dumps(final_stop_event_data)}\n\n"
        logger.info(f"[{message_id}] Completed sending Claude SSE stream.")
def create_error_response(status_code: int, error_type: str, message: str) -> JSONResponse:
    """Build a JSONResponse whose body mimics Claude's error envelope."""
    body = {"type": "error", "error": {"type": error_type, "message": message}}
    return JSONResponse(status_code=status_code, content=body)
# --- Main Proxy Endpoint ---
# FIX: the handler was defined but never registered as a route, and the
# get_api_key dependency (and the imported Depends) was never wired in even
# though the docstring promises X-API-Key enforcement.
@app.post("/v1/messages")
async def proxy_claude_to_openai(request: Request, api_key: str = Depends(get_api_key)):
    """
    Receives a Claude-formatted request, proxies it to OpenAI,
    and returns a Claude-formatted response.
    Requires a valid API key via the X-API-Key header (enforced by the
    get_api_key dependency).
    """
    request_id = f"msg_{uuid.uuid4().hex[:24]}"
    # --- Parse the incoming Claude-format JSON body ---
    try:
        claude_request_data = await request.json()
        logger.info(f"[{request_id}] Received request. Stream: {claude_request_data.get('stream', False)}. Model: {claude_request_data.get('model')}")
    except json.JSONDecodeError:
        logger.error(f"[{request_id}] Invalid JSON received in request body.")
        return create_error_response(400, "invalid_request_error", "Invalid JSON data in request body.")
    # --- Translate to OpenAI chat-completions format ---
    try:
        openai_payload = claude_request_to_openai_payload(claude_request_data)
    except Exception as e:
        logger.error(f"[{request_id}] Failed to convert Claude request to OpenAI format: {e}")
        return create_error_response(400, "invalid_request_error", f"Failed to process request data: {e}")
    is_streaming = openai_payload.get("stream", False)
    requested_model = openai_payload.get("model", "unknown_model")
    # --- Build upstream headers ---
    target_headers = { "Content-Type": "application/json" }
    if OPENAI_API_KEY:
        target_headers["Authorization"] = f"Bearer {OPENAI_API_KEY}"
        logger.debug(f"[{request_id}] Added Authorization header to upstream request.")
    else:
        logger.debug(f"[{request_id}] No OPENAI_API_KEY configured for upstream request.")
    # Verbose request logging, only when DEBUG is enabled; the Authorization
    # header is redacted and the payload truncated before logging.
    if logger.level("DEBUG").no >= logger.level(log_level).no:
        logged_headers = target_headers.copy()
        if "Authorization" in logged_headers:
            logged_headers["Authorization"] = "Bearer [REDACTED]"
        logger.debug(f"[{request_id}] Sending request to upstream API.")
        logger.debug(f"[{request_id}] Upstream Headers: {json.dumps(logged_headers)}")
        try:
            payload_str = json.dumps(openai_payload, indent=2)
            max_log_len = 1024
            logger.debug(f"[{request_id}] Upstream Payload {'(truncated)' if len(payload_str) > max_log_len else ''}: {payload_str[:max_log_len]}{'...' if len(payload_str) > max_log_len else ''}")
        except Exception as log_e:
            logger.warning(f"[{request_id}] Could not serialize or log upstream payload: {log_e}")
    else:
        logger.debug(f"[{request_id}] Sending request to upstream API...")
    # --- Forward to the upstream API and translate the response ---
    try:
        target_request = client.build_request("POST", OPENAI_API_ENDPOINT, headers=target_headers, json=openai_payload)
        response = await client.send(target_request, stream=is_streaming)
        response.raise_for_status()
        if is_streaming:
            logger.info(f"[{request_id}] Upstream response is streaming. Starting SSE conversion.")
            return StreamingResponse(
                stream_openai_response_to_claude_events(response, request_id, requested_model),
                media_type="text/event-stream",
                headers={"X-Content-Type-Options": "nosniff", "Cache-Control": "no-cache", "Connection": "keep-alive"}
            )
        else:
            logger.info(f"[{request_id}] Upstream response is non-streaming. Converting.")
            openai_response_data = response.json()
            logger.debug(f"[{request_id}] Received non-streaming response from upstream.")
            try:
                claude_response_data = openai_response_to_claude_response(openai_response_data, request_id)
                return JSONResponse(content=claude_response_data)
            except ValueError as e:
                logger.error(f"[{request_id}] Failed to convert upstream non-streaming response: {e}")
                return create_error_response(500, "api_error", f"Error processing response from upstream API: {e}")
            except Exception as e:
                logger.exception(f"[{request_id}] Unexpected error converting non-streaming response: {e}")
                return create_error_response(500, "internal_server_error", "Unexpected error processing upstream response.")
    except httpx.HTTPStatusError as e:
        # Upstream returned a non-2xx status: map it to a Claude-style error.
        status_code = e.response.status_code
        error_detail_text = "[Could not decode error response]"
        try:
            error_detail = e.response.json()
            error_detail_text = json.dumps(error_detail)
        except json.JSONDecodeError:
            try:
                error_detail_text = e.response.text
            except Exception:
                logger.warning(f"[{request_id}] Could not read error response body as text.")
        logger.error(f"[{request_id}] HTTP error from target endpoint ({status_code}). Response snippet: {error_detail_text[:200]}...")
        if status_code == 400: err_type, msg = "invalid_request_error", f"Upstream API Bad Request ({status_code})."
        elif status_code == 401: err_type, msg = "authentication_error", f"Authentication failed with upstream API ({status_code})."
        elif status_code == 403: err_type, msg = "permission_error", f"Forbidden by upstream API ({status_code})."
        elif status_code == 429: err_type, msg = "rate_limit_error", f"Rate limit exceeded with upstream API ({status_code})."
        elif status_code >= 500: err_type, msg = "api_error", f"Upstream API unavailable or encountered an error ({status_code})."
        else: err_type, msg = "api_error", f"Received unexpected error from upstream API ({status_code})."
        return create_error_response(status_code, err_type, msg)
    except httpx.TimeoutException:
        logger.error(f"[{request_id}] Request to target endpoint timed out ({READ_TIMEOUT}s).")
        return create_error_response(504, "api_error", "Gateway Timeout: Request to upstream API timed out.")
    except httpx.RequestError as e:
        logger.error(f"[{request_id}] Network error connecting to target endpoint: {type(e).__name__}")
        return create_error_response(502, "api_error", f"Bad Gateway: Network error connecting to upstream API.")
    except Exception as e:
        logger.exception(f"[{request_id}] Unexpected error during proxy operation: {e}")
        return create_error_response(500, "internal_server_error", f"Internal Server Error: {e}")
# --- Health Check Endpoint ---
# FIX: the handler was defined but never registered as a route, so the
# health check was unreachable.
@app.get("/health")
async def health_check():
    """Returns the operational status of the proxy server."""
    return {"status": "healthy"}
# --- Run with Uvicorn (for local development) ---
if __name__ == "__main__":
    import uvicorn
    try:
        from dotenv import load_dotenv
        load_dotenv()
        logger.info("Loaded environment variables from .env file (if present).")
        # .env is loaded AFTER the module-level configuration above already
        # ran, so every env-driven setting must be re-read here.
        # FIX: the timeout/connection-pool settings were previously not
        # refreshed, silently ignoring .env overrides for them (they are
        # consumed later, when the lifespan handler builds the client).
        PROXY_API_KEYS_STR = os.getenv("PROXY_API_KEYS", "")
        VALID_API_KEYS = set(key.strip() for key in PROXY_API_KEYS_STR.split(',') if key.strip())
        OPENAI_API_ENDPOINT = os.getenv("OPENAI_API_ENDPOINT", "https://api.openai.com/v1/chat/completions")
        OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
        HTTP_PROXY = os.getenv("HTTP_PROXY")
        CONNECT_TIMEOUT = float(os.getenv("CONNECT_TIMEOUT", 5.0))
        READ_TIMEOUT = float(os.getenv("READ_TIMEOUT", 180.0))
        WRITE_TIMEOUT = float(os.getenv("WRITE_TIMEOUT", 30.0))
        POOL_TIMEOUT = float(os.getenv("POOL_TIMEOUT", 5.0))
        MAX_CONNECTIONS = int(os.getenv("MAX_CONNECTIONS", 100))
        MAX_KEEPALIVE = int(os.getenv("MAX_KEEPALIVE", 20))
        log_level = os.getenv("LOG_LEVEL", "INFO").upper()
        # Re-install the log handler in case LOG_LEVEL came from .env.
        logger.remove()
        logger.add(sys.stderr, level=log_level, format="<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>")
        logger.info(f"Log level set to: {log_level}")
        logger.info(f"Valid Proxy API Keys configured: {len(VALID_API_KEYS)}")
    except ImportError:
        logger.info("python-dotenv not installed, skipping .env file loading.")
    port = int(os.getenv("PORT", 7860))
    host = os.getenv("HOST", "0.0.0.0")
    # uvicorn expects a lowercase level name; fall back to "info" for values
    # it does not recognise.
    log_config_level = log_level.lower() if log_level in ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "TRACE"] else "info"
    logger.info(f"Starting Uvicorn server on {host}:{port}")
    # NOTE(review): assumes this file is saved as proxy_server.py — confirm,
    # otherwise the import string passed to uvicorn.run will fail.
    uvicorn.run("proxy_server:app", host=host, port=port, reload=True, log_level=log_config_level)