import os import httpx from fastapi import FastAPI, Request, HTTPException, Depends from fastapi.responses import StreamingResponse, JSONResponse, Response from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials import logging # <-- Add logging from fastapi.middleware.cors import CORSMiddleware # Add CORS Middleware import from dotenv import load_dotenv import json # Load environment variables from .env file load_dotenv() # Configuration REMOTE_CHAT_COMPLETION_URL = "https://us.helicone.ai/api/llm" REMOTE_MODELS_URL = "https://openrouter.ai/api/v1/models" EXPECTED_API_KEY = os.getenv("PROXY_API_KEY", "default_insecure_key") # Load API key from .env or use a default # --- Logging Setup --- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) # --- Authentication --- security = HTTPBearer() async def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)): """Verify the provided API key.""" if credentials.scheme != "Bearer" or credentials.credentials != EXPECTED_API_KEY: raise HTTPException(status_code=401, detail="Invalid or missing API key") return credentials.credentials # --- FastAPI App --- app = FastAPI( title="OpenAI Format Proxy", description="A proxy server that translates requests to an OpenAI-compatible format.", version="1.0.0", ) # --- CORS Middleware --- # Allows requests from any origin, with any method and headers. # Adjust origins if you need to restrict access to specific domains. app.add_middleware( CORSMiddleware, allow_origins=["*"], # Allows all origins allow_credentials=True, allow_methods=["*"], # Allows all methods (GET, POST, OPTIONS, etc.) allow_headers=["*"], # Allows all headers ) # --- Helper Functions --- # Removed stream_wrapper as we'll use a different approach async def log_stream_chunks(iterator, request_path: str): """Async generator wrapper to log incoming stream chunks.""" import time start_time = time.time() chunk_count = 0 logger.info(f"[{request_path}] log_stream_chunks: Starting iteration at {start_time:.3f}") try: async for chunk in iterator: chunk_count += 1 current_time = time.time() logger.info(f"[{request_path}] log_stream_chunks: Received chunk {chunk_count} ({len(chunk)} bytes) at {current_time:.3f} ({(current_time - start_time):.3f}s elapsed)") yield chunk except Exception as e: current_time = time.time() logger.error(f"[{request_path}] log_stream_chunks: Error during iteration at {current_time:.3f} ({(current_time - start_time):.3f}s elapsed): {e}", exc_info=True) raise # Re-raise after logging finally: end_time = time.time() logger.info(f"[{request_path}] log_stream_chunks: Finished iteration. Total chunks: {chunk_count}. Total time: {(end_time - start_time):.3f}s") async def forward_request(request: Request, target_url: str): """Forwards the request to the target URL using httpx, handling streaming based on Content-Type.""" body = await request.body() # Prepare headers, exclude Host header and potentially sensitive headers like Authorization if needed # Build headers for upstream, excluding 'host' and any 'accept' header initially headers = {key: value for key, value in request.headers.items() if key.lower() not in ['host', 'accept']} # Set the desired 'Accept' header based on the target URL if target_url == REMOTE_CHAT_COMPLETION_URL: headers['Accept'] = 'text/event-stream' # Force stream accept for chat elif 'accept' in request.headers: # If original request had an 'accept' header and it's not chat, forward it headers['Accept'] = request.headers['accept'] # else: No specific Accept header needed/provided for other targets logger.info(f"[{request.url.path}] Forwarding {request.method} request to {target_url}") logger.info(f"[{request.url.path}] Sending upstream request with headers: {headers}") # Log outgoing headers async with httpx.AsyncClient(timeout=None) as client: try: # Make the request without using client.stream() initially response = await client.request( method=request.method, url=target_url, headers=headers, params=request.query_params, content=body ) # Log upstream status logger.info(f"[{request.url.path}] Received response from {target_url} with status {response.status_code}") # Check if the response indicates an error status code if response.status_code >= 400: error_content = await response.aread() detail = error_content.decode(errors='replace') logger.warning(f"[{request.url.path}] Upstream server {target_url} returned error {response.status_code}: {detail}") # Forward the exact error if possible raise HTTPException(status_code=response.status_code, detail=detail) # Check Content-Type for streaming content_type = response.headers.get("content-type", "").lower() if "text/event-stream" in content_type: logger.info(f"[{request.url.path}] Detected 'text/event-stream' content type. Streaming response back.") # Use aiter_bytes() for async streaming return StreamingResponse( log_stream_chunks(response.aiter_bytes(), request.url.path), # Use the logging wrapper status_code=response.status_code, headers=dict(response.headers), media_type="text/event-stream" # Ensure correct media type propagates ) else: logger.info(f"[{request.url.path}] Non-streaming content type detected ('{content_type}'). Sending full response.") # Read the entire response content for non-streaming responses response_content = await response.aread() # Determine the correct response class (JSONResponse or Response) if "application/json" in content_type: try: json_content = json.loads(response_content) return JSONResponse( content=json_content, status_code=response.status_code, headers=dict(response.headers) ) except json.JSONDecodeError: logger.warning(f"[{request.url.path}] Declared JSON but failed to parse. Sending raw.") # Fallback to raw Response if JSON parsing fails return Response( content=response_content, status_code=response.status_code, headers=dict(response.headers) ) else: # Return raw response for other content types return Response( content=response_content, status_code=response.status_code, headers=dict(response.headers) ) except httpx.RequestError as e: logger.error(f"[{request.url.path}] Error communicating with target server {target_url}: {e}", exc_info=True) raise HTTPException(status_code=502, detail=f"Error communicating with target server: {e}") except Exception as e: # Catch other potential errors logger.error(f"[{request.url.path}] Unexpected error forwarding request to {target_url}: {e}", exc_info=True) raise HTTPException(status_code=500, detail=f"Internal server error during request forwarding: {e}") # --- API Endpoints --- # @app.get("/v1/models", dependencies=[Depends(verify_api_key)]) @app.get("/v1/models") async def get_models(request: Request): """Proxies requests to the remote models endpoint.""" async with httpx.AsyncClient(timeout=30.0) as client: # Shorter timeout for potentially faster models endpoint try: # Use specific headers provided by the user for the /v1/models request model_request_headers = { "accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7", "accept-language": "en-US,en;q=0.9,zh-CN;q=0.8,zh;q=0.7,zh-TW;q=0.6,ja;q=0.5", "priority": "u=0, i", "sec-ch-ua": "\"Google Chrome\";v=\"135\", \"Not-A.Brand\";v=\"8\", \"Chromium\";v=\"135\"", "sec-ch-ua-mobile": "?0", "sec-ch-ua-platform": "\"macOS\"", "sec-fetch-dest": "document", "sec-fetch-mode": "navigate", "sec-fetch-site": "none", "sec-fetch-user": "?1", "upgrade-insecure-requests": "1", # Note: httpx automatically handles user-agent, host, connection, etc. # We exclude cookies and authorization from the client request by not forwarding them. } resp = await client.get(REMOTE_MODELS_URL, headers=model_request_headers) resp.raise_for_status() # Check for HTTP errors first # Get raw bytes and check content encoding content_bytes = resp.content content_encoding = resp.headers.get("content-encoding", "").lower() try: original_data = json.loads(content_bytes) # Transform the data into OpenAI format openai_models_data = [] # if isinstance(original_data.get("data"), list): # for model_info in original_data["data"]: # openai_models_data.append({ # "id": model_info.get("id"), # "object": "model", # Standard OpenAI format field # "created": model_info.get("created"), # Use original timestamp # "owned_by": "system" # Default value, as owner isn't specified # }) final_response_data = { "object": "list", # Standard OpenAI format field "data": original_data["data"] } # Return the transformed successful JSON response return JSONResponse( content=final_response_data, status_code=resp.status_code, # Forward original status headers={'Content-Type': 'application/json'} # Set correct content type ) # Removed specific gzip/zlib handling here as httpx handles content decoding by default unless streaming raw except UnicodeDecodeError: logger.error(f"[{request.url.path}] Failed to decode upstream models response as UTF-8.") raise HTTPException(status_code=500, detail="Failed to decode upstream models response as UTF-8.") except json.JSONDecodeError: # Log JSON parsing error logger.error(f"[{request.url.path}] Upstream models response was not valid JSON.", exc_info=True) raise HTTPException(status_code=500, detail="Upstream models response was not valid JSON after decoding.") except httpx.HTTPStatusError as e: error_detail = e.response.text try: error_detail = e.response.json() except json.JSONDecodeError: pass raise HTTPException(status_code=e.response.status_code, detail=error_detail) except httpx.RequestError as e: raise HTTPException(status_code=502, detail=f"Error communicating with models server: {e}") # @app.post("/v1/chat/completions", dependencies=[Depends(verify_api_key)]) @app.post("/v1/chat/completions") async def chat_completions(request: Request): """Proxies chat completion requests to the remote server, handling streaming.""" return await forward_request(request, REMOTE_CHAT_COMPLETION_URL) # --- Health Check --- (Good practice for deployments) @app.get("/health") async def health_check(): """Simple health check endpoint.""" return {"status": "ok"} # --- Main Execution --- (For local testing with uvicorn) if __name__ == "__main__": import uvicorn port = int(os.getenv("PORT", 8000)) # Allow port configuration via env uvicorn.run(app, host="0.0.0.0", port=port)