import os
import json
import logging
import time

import httpx
from dotenv import load_dotenv
from fastapi import FastAPI, Request, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse, Response
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
# Load environment variables from a local .env file
load_dotenv()

# --- Configuration ---
REMOTE_CHAT_COMPLETION_URL = "https://us.helicone.ai/api/llm"
REMOTE_MODELS_URL = "https://openrouter.ai/api/v1/models"
# API key expected from clients. The fallback value is intentionally insecure;
# always override it via PROXY_API_KEY in any real deployment.
EXPECTED_API_KEY = os.getenv("PROXY_API_KEY", "default_insecure_key")

# --- Logging Setup ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# --- Authentication ---
security = HTTPBearer()

async def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Verify the provided API key."""
    if credentials.scheme != "Bearer" or credentials.credentials != EXPECTED_API_KEY:
        raise HTTPException(status_code=401, detail="Invalid or missing API key")
    return credentials.credentials
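
# Example client call (sketch; assumes the server runs on localhost:8000 and the
# client sends the same value configured in PROXY_API_KEY):
#   curl -H "Authorization: Bearer $PROXY_API_KEY" http://localhost:8000/v1/models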
# --- FastAPI App ---
app = FastAPI(
    title="OpenAI Format Proxy",
    description="A proxy server that translates requests to an OpenAI-compatible format.",
    version="1.0.0",
)

# --- CORS Middleware ---
# Allows requests from any origin, with any method and headers.
# Adjust origins if you need to restrict access to specific domains.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],      # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],      # Allows all methods (GET, POST, OPTIONS, etc.)
    allow_headers=["*"],      # Allows all headers
)
# --- Helper Functions ---
# Removed stream_wrapper as we'll use a different approach

async def log_stream_chunks(iterator, request_path: str):
    """Async generator wrapper to log incoming stream chunks."""
    start_time = time.time()
    chunk_count = 0
    logger.info(f"[{request_path}] log_stream_chunks: Starting iteration at {start_time:.3f}")
    try:
        async for chunk in iterator:
            chunk_count += 1
            current_time = time.time()
            logger.info(f"[{request_path}] log_stream_chunks: Received chunk {chunk_count} ({len(chunk)} bytes) at {current_time:.3f} ({(current_time - start_time):.3f}s elapsed)")
            yield chunk
    except Exception as e:
        current_time = time.time()
        logger.error(f"[{request_path}] log_stream_chunks: Error during iteration at {current_time:.3f} ({(current_time - start_time):.3f}s elapsed): {e}", exc_info=True)
        raise  # Re-raise after logging
    finally:
        end_time = time.time()
        logger.info(f"[{request_path}] log_stream_chunks: Finished iteration. Total chunks: {chunk_count}. Total time: {(end_time - start_time):.3f}s")
async def forward_request(request: Request, target_url: str):
    """Forwards the request to the target URL using httpx, handling streaming based on Content-Type."""
    body = await request.body()
    # Build headers for upstream, excluding 'host' and any 'accept' header initially.
    # Sensitive headers such as Authorization could also be stripped here if needed.
    headers = {key: value for key, value in request.headers.items() if key.lower() not in ['host', 'accept']}
    # Set the desired 'Accept' header based on the target URL
    if target_url == REMOTE_CHAT_COMPLETION_URL:
        headers['Accept'] = 'text/event-stream'  # Force stream accept for chat
    elif 'accept' in request.headers:
        # If the original request had an 'Accept' header and this is not a chat request, forward it
        headers['Accept'] = request.headers['accept']
    # else: no specific Accept header is needed for other targets

    logger.info(f"[{request.url.path}] Forwarding {request.method} request to {target_url}")
    logger.info(f"[{request.url.path}] Sending upstream request with headers: {headers}")

    async with httpx.AsyncClient(timeout=None) as client:
        try:
            # Note: client.request() (rather than client.stream()) fully buffers the
            # upstream body before it is re-streamed below, so end-to-end streaming
            # latency is not preserved; see the sketch after this function for a
            # pass-through alternative.
            response = await client.request(
                method=request.method,
                url=target_url,
                headers=headers,
                params=request.query_params,
                content=body
            )
            logger.info(f"[{request.url.path}] Received response from {target_url} with status {response.status_code}")

            # Forward upstream errors as-is where possible
            if response.status_code >= 400:
                detail = response.content.decode(errors='replace')
                logger.warning(f"[{request.url.path}] Upstream server {target_url} returned error {response.status_code}: {detail}")
                raise HTTPException(status_code=response.status_code, detail=detail)

            # httpx transparently decompresses the body, so drop headers that describe
            # the original wire payload; forwarding them verbatim would confuse clients.
            response_headers = {
                key: value for key, value in response.headers.items()
                if key.lower() not in ('content-encoding', 'content-length', 'transfer-encoding')
            }

            # Check Content-Type for streaming
            content_type = response.headers.get("content-type", "").lower()
            if "text/event-stream" in content_type:
                logger.info(f"[{request.url.path}] Detected 'text/event-stream' content type. Streaming response back.")
                return StreamingResponse(
                    log_stream_chunks(response.aiter_bytes(), request.url.path),  # Log each chunk as it is replayed
                    status_code=response.status_code,
                    headers=response_headers,
                    media_type="text/event-stream"  # Ensure correct media type propagates
                )
            else:
                logger.info(f"[{request.url.path}] Non-streaming content type detected ('{content_type}'). Sending full response.")
                response_content = response.content
                if "application/json" in content_type:
                    try:
                        json_content = json.loads(response_content)
                        return JSONResponse(
                            content=json_content,
                            status_code=response.status_code,
                            headers=response_headers
                        )
                    except json.JSONDecodeError:
                        logger.warning(f"[{request.url.path}] Declared JSON but failed to parse. Sending raw.")
                        # Fall back to a raw Response if JSON parsing fails
                        return Response(
                            content=response_content,
                            status_code=response.status_code,
                            headers=response_headers
                        )
                else:
                    # Return raw response for other content types
                    return Response(
                        content=response_content,
                        status_code=response.status_code,
                        headers=response_headers
                    )
        except HTTPException:
            raise  # Already shaped for the client above; don't re-wrap as a 500
        except httpx.RequestError as e:
            logger.error(f"[{request.url.path}] Error communicating with target server {target_url}: {e}", exc_info=True)
            raise HTTPException(status_code=502, detail=f"Error communicating with target server: {e}")
        except Exception as e:  # Catch other potential errors
            logger.error(f"[{request.url.path}] Unexpected error forwarding request to {target_url}: {e}", exc_info=True)
            raise HTTPException(status_code=500, detail=f"Internal server error during request forwarding: {e}")
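
# Pass-through streaming alternative (sketch only, not wired into the routes
# below): relaying chunks as they arrive requires the httpx client to outlive
# the handler, e.g. by closing it in a Starlette BackgroundTask. All names in
# this sketch are illustrative, not part of the app above.
#
#   from starlette.background import BackgroundTask
#
#   async def forward_request_streaming(request: Request, target_url: str):
#       headers = {k: v for k, v in request.headers.items() if k.lower() != "host"}
#       client = httpx.AsyncClient(timeout=None)
#       upstream_req = client.build_request(
#           request.method, target_url, headers=headers, content=await request.body()
#       )
#       upstream = await client.send(upstream_req, stream=True)
#
#       async def cleanup():
#           await upstream.aclose()
#           await client.aclose()
#
#       return StreamingResponse(
#           upstream.aiter_bytes(),                 # Chunks relayed as they arrive
#           status_code=upstream.status_code,
#           media_type=upstream.headers.get("content-type"),
#           background=BackgroundTask(cleanup),     # Close upstream after the last chunk
#       )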
# --- API Endpoints ---

@app.get("/v1/models", dependencies=[Depends(verify_api_key)])
async def get_models(request: Request):
    """Proxies requests to the remote models endpoint."""
    async with httpx.AsyncClient(timeout=30.0) as client:  # Shorter timeout for the typically fast models endpoint
        try:
            # Use specific headers provided by the user for the /v1/models request
            model_request_headers = {
                "accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7",
                "accept-language": "en-US,en;q=0.9,zh-CN;q=0.8,zh;q=0.7,zh-TW;q=0.6,ja;q=0.5",
                "priority": "u=0, i",
                "sec-ch-ua": "\"Google Chrome\";v=\"135\", \"Not-A.Brand\";v=\"8\", \"Chromium\";v=\"135\"",
                "sec-ch-ua-mobile": "?0",
                "sec-ch-ua-platform": "\"macOS\"",
                "sec-fetch-dest": "document",
                "sec-fetch-mode": "navigate",
                "sec-fetch-site": "none",
                "sec-fetch-user": "?1",
                "upgrade-insecure-requests": "1",
                # Note: httpx automatically handles user-agent, host, connection, etc.
                # We exclude cookies and authorization from the client request by not forwarding them.
            }
            resp = await client.get(REMOTE_MODELS_URL, headers=model_request_headers)
            resp.raise_for_status()  # Check for HTTP errors first

            # httpx decodes the body (gzip/deflate) automatically, so the raw
            # bytes can be parsed directly.
            content_bytes = resp.content
            try:
                original_data = json.loads(content_bytes)
                # Previous per-model transformation into the minimal OpenAI model
                # shape, kept for reference:
                # openai_models_data = []
                # if isinstance(original_data.get("data"), list):
                #     for model_info in original_data["data"]:
                #         openai_models_data.append({
                #             "id": model_info.get("id"),
                #             "object": "model",  # Standard OpenAI format field
                #             "created": model_info.get("created"),  # Use original timestamp
                #             "owned_by": "system"  # Default value, as owner isn't specified
                #         })
                final_response_data = {
                    "object": "list",  # Standard OpenAI format field
                    "data": original_data.get("data", [])  # Forward the upstream model list verbatim
                }
                # Return the transformed successful JSON response
                return JSONResponse(
                    content=final_response_data,
                    status_code=resp.status_code,               # Forward original status
                    headers={'Content-Type': 'application/json'}  # Set correct content type
                )
            except UnicodeDecodeError:
                logger.error(f"[{request.url.path}] Failed to decode upstream models response as UTF-8.")
                raise HTTPException(status_code=500, detail="Failed to decode upstream models response as UTF-8.")
            except json.JSONDecodeError:
                logger.error(f"[{request.url.path}] Upstream models response was not valid JSON.", exc_info=True)
                raise HTTPException(status_code=500, detail="Upstream models response was not valid JSON after decoding.")
        except httpx.HTTPStatusError as e:
            # Forward the upstream error, preferring its JSON body when available
            error_detail = e.response.text
            try:
                error_detail = e.response.json()
            except json.JSONDecodeError:
                pass
            raise HTTPException(status_code=e.response.status_code, detail=error_detail)
        except httpx.RequestError as e:
            raise HTTPException(status_code=502, detail=f"Error communicating with models server: {e}")
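
# The endpoint above returns the OpenAI-style list envelope, e.g.:
#   {"object": "list", "data": [{"id": "<model-id>", ...}, ...]}
# where each entry in "data" is forwarded verbatim from the upstream response.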
| # @app.post("/v1/chat/completions", dependencies=[Depends(verify_api_key)]) | |
| async def chat_completions(request: Request): | |
| """Proxies chat completion requests to the remote server, handling streaming.""" | |
| return await forward_request(request, REMOTE_CHAT_COMPLETION_URL) | |
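
# Example streaming call (sketch; the payload fields are whatever the upstream
# accepts, since the request body is forwarded verbatim):
#   curl -N -H "Authorization: Bearer $PROXY_API_KEY" \
#        -H "Content-Type: application/json" \
#        -d '{"model": "<model-id>", "messages": [{"role": "user", "content": "Hi"}], "stream": true}' \
#        http://localhost:8000/v1/chat/completions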
# --- Health Check --- (Good practice for deployments)
@app.get("/health")  # Route path is an assumption; adjust to your deployment's probe
async def health_check():
    """Simple health check endpoint."""
    return {"status": "ok"}
# --- Main Execution --- (For local testing with uvicorn)
if __name__ == "__main__":
    import uvicorn
    port = int(os.getenv("PORT", 8000))  # Allow port configuration via env
    uvicorn.run(app, host="0.0.0.0", port=port)
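
# Running directly (sketch; assumes this file is saved as main.py):
#   PROXY_API_KEY=change-me python main.py
# or, with auto-reload during development:
#   uvicorn main:app --reload --port 8000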