import json
import logging
import os
import secrets
import time

import httpx
from dotenv import load_dotenv
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from starlette.background import BackgroundTask

# Load environment variables from a .env file, if present
load_dotenv()
# Configuration
REMOTE_CHAT_COMPLETION_URL = "https://us.helicone.ai/api/llm"
REMOTE_MODELS_URL = "https://openrouter.ai/api/v1/models"
EXPECTED_API_KEY = os.getenv("PROXY_API_KEY", "default_insecure_key")  # Read from .env; the fallback is insecure and should not be used in production
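# A minimal .env for local development might look like this (values are
# illustrative placeholders, not real credentials):
#
#   PROXY_API_KEY=change-me
#   PORT=8000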
# --- Logging Setup ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# --- Authentication ---
security = HTTPBearer()

async def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Verify the provided API key using a constant-time comparison."""
    # HTTPBearer already rejects non-Bearer schemes; secrets.compare_digest
    # avoids leaking key contents through response-timing differences.
    if not secrets.compare_digest(credentials.credentials, EXPECTED_API_KEY):
        raise HTTPException(status_code=401, detail="Invalid or missing API key")
    return credentials.credentials
# --- FastAPI App ---
app = FastAPI(
title="OpenAI Format Proxy",
description="A proxy server that translates requests to an OpenAI-compatible format.",
version="1.0.0",
)
# --- CORS Middleware ---
# Allows requests from any origin, with any method and headers. Be aware that
# browsers reject the wildcard "Access-Control-Allow-Origin: *" on requests
# that carry credentials, so restrict allow_origins (or drop allow_credentials)
# if browser clients need to send cookies or auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # All origins
    allow_credentials=True,
    allow_methods=["*"],  # All methods (GET, POST, OPTIONS, etc.)
    allow_headers=["*"],  # All headers
)
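# A locked-down configuration might look like this sketch (domain and header
# list are illustrative):
#
#   app.add_middleware(
#       CORSMiddleware,
#       allow_origins=["https://app.example.com"],
#       allow_credentials=True,
#       allow_methods=["GET", "POST"],
#       allow_headers=["Authorization", "Content-Type"],
#   )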
# --- Helper Functions ---
async def log_stream_chunks(iterator, request_path: str):
    """Async generator wrapper that logs each streamed chunk as it passes through."""
    start_time = time.time()
    chunk_count = 0
    logger.info(f"[{request_path}] log_stream_chunks: Starting iteration at {start_time:.3f}")
try:
async for chunk in iterator:
chunk_count += 1
current_time = time.time()
logger.info(f"[{request_path}] log_stream_chunks: Received chunk {chunk_count} ({len(chunk)} bytes) at {current_time:.3f} ({(current_time - start_time):.3f}s elapsed)")
yield chunk
except Exception as e:
current_time = time.time()
logger.error(f"[{request_path}] log_stream_chunks: Error during iteration at {current_time:.3f} ({(current_time - start_time):.3f}s elapsed): {e}", exc_info=True)
raise # Re-raise after logging
finally:
end_time = time.time()
logger.info(f"[{request_path}] log_stream_chunks: Finished iteration. Total chunks: {chunk_count}. Total time: {(end_time - start_time):.3f}s")
def _forwardable_headers(upstream: httpx.Response) -> dict:
    """Copy upstream headers, dropping ones that no longer describe the body.

    httpx transparently decompresses response bodies, so forwarding the
    original content-length/content-encoding would corrupt the proxied reply.
    """
    excluded = {"content-length", "content-encoding", "transfer-encoding", "connection"}
    return {key: value for key, value in upstream.headers.items() if key.lower() not in excluded}

async def forward_request(request: Request, target_url: str):
    """Forward the request to the target URL with httpx, streaming the reply back when the upstream Content-Type is text/event-stream."""
    body = await request.body()
    # Build headers for upstream, excluding 'host' and any 'accept' header initially
    headers = {key: value for key, value in request.headers.items() if key.lower() not in ("host", "accept")}
    # Set the desired 'Accept' header based on the target URL
    if target_url == REMOTE_CHAT_COMPLETION_URL:
        headers["Accept"] = "text/event-stream"  # Force streaming for chat completions
    elif "accept" in request.headers:
        # For other targets, forward the original 'Accept' header unchanged
        headers["Accept"] = request.headers["accept"]
    logger.info(f"[{request.url.path}] Forwarding {request.method} request to {target_url}")
    logger.info(f"[{request.url.path}] Sending upstream request with headers: {headers}")
    # No context manager here: a streaming response outlives this function,
    # so the client is closed explicitly (or by the background task below).
    client = httpx.AsyncClient(timeout=None)
    try:
        upstream = client.build_request(
            method=request.method,
            url=target_url,
            headers=headers,
            params=request.query_params,
            content=body,
        )
        # stream=True defers reading the body, so SSE chunks are relayed as
        # they arrive instead of being buffered in full first.
        response = await client.send(upstream, stream=True)
        logger.info(f"[{request.url.path}] Received response from {target_url} with status {response.status_code}")
        # Forward upstream errors with their original status code and body
        if response.status_code >= 400:
            error_content = await response.aread()
            detail = error_content.decode(errors="replace")
            logger.warning(f"[{request.url.path}] Upstream server {target_url} returned error {response.status_code}: {detail}")
            await response.aclose()
            await client.aclose()
            raise HTTPException(status_code=response.status_code, detail=detail)
        # Check Content-Type for streaming
        content_type = response.headers.get("content-type", "").lower()
        if "text/event-stream" in content_type:
            logger.info(f"[{request.url.path}] Detected 'text/event-stream' content type. Streaming response back.")

            async def cleanup() -> None:
                await response.aclose()
                await client.aclose()

            return StreamingResponse(
                log_stream_chunks(response.aiter_bytes(), request.url.path),  # Logging wrapper
                status_code=response.status_code,
                headers=_forwardable_headers(response),
                media_type="text/event-stream",  # Ensure the correct media type propagates
                background=BackgroundTask(cleanup),  # Close response and client once the stream ends
            )
        logger.info(f"[{request.url.path}] Non-streaming content type detected ('{content_type}'). Sending full response.")
        response_content = await response.aread()
        await response.aclose()
        await client.aclose()
        if "application/json" in content_type:
            try:
                json_content = json.loads(response_content)
                return JSONResponse(
                    content=json_content,
                    status_code=response.status_code,
                    headers=_forwardable_headers(response),
                )
            except json.JSONDecodeError:
                logger.warning(f"[{request.url.path}] Declared JSON but failed to parse. Sending raw.")
        # Fall back to a raw passthrough for non-JSON (or unparseable) bodies
        return Response(
            content=response_content,
            status_code=response.status_code,
            headers=_forwardable_headers(response),
        )
    except HTTPException:
        raise  # Already shaped for the client; don't let the handlers below rewrap it as a 500
    except httpx.RequestError as e:
        await client.aclose()
        logger.error(f"[{request.url.path}] Error communicating with target server {target_url}: {e}", exc_info=True)
        raise HTTPException(status_code=502, detail=f"Error communicating with target server: {e}")
    except Exception as e:  # Catch other unexpected errors
        await client.aclose()
        logger.error(f"[{request.url.path}] Unexpected error forwarding request to {target_url}: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Internal server error during request forwarding: {e}")
# --- API Endpoints ---
# Authentication is currently disabled; restore the dependency below to require the proxy API key:
# @app.get("/v1/models", dependencies=[Depends(verify_api_key)])
@app.get("/v1/models")
async def get_models(request: Request):
"""Proxies requests to the remote models endpoint."""
    async with httpx.AsyncClient(timeout=30.0) as client:  # Shorter timeout; the models endpoint should respond quickly
try:
# Use specific headers provided by the user for the /v1/models request
model_request_headers = {
"accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7",
"accept-language": "en-US,en;q=0.9,zh-CN;q=0.8,zh;q=0.7,zh-TW;q=0.6,ja;q=0.5",
"priority": "u=0, i",
"sec-ch-ua": "\"Google Chrome\";v=\"135\", \"Not-A.Brand\";v=\"8\", \"Chromium\";v=\"135\"",
"sec-ch-ua-mobile": "?0",
"sec-ch-ua-platform": "\"macOS\"",
"sec-fetch-dest": "document",
"sec-fetch-mode": "navigate",
"sec-fetch-site": "none",
"sec-fetch-user": "?1",
"upgrade-insecure-requests": "1",
# Note: httpx automatically handles user-agent, host, connection, etc.
# We exclude cookies and authorization from the client request by not forwarding them.
}
resp = await client.get(REMOTE_MODELS_URL, headers=model_request_headers)
resp.raise_for_status() # Check for HTTP errors first
            # httpx has already decoded any content-encoding at this point,
            # so the raw bytes can be parsed directly.
            content_bytes = resp.content
try:
original_data = json.loads(content_bytes)
                # Wrap the upstream model list in the standard OpenAI envelope.
                # The entries are forwarded unchanged; a stricter per-model
                # mapping onto the OpenAI shape would look like:
                #   openai_models_data = []
                #   for model_info in original_data["data"]:
                #       openai_models_data.append({
                #           "id": model_info.get("id"),
                #           "object": "model",                     # Standard OpenAI field
                #           "created": model_info.get("created"),  # Original timestamp
                #           "owned_by": "system",                  # Owner isn't specified upstream
                #       })
                final_response_data = {
                    "object": "list",  # Standard OpenAI list envelope
                    "data": original_data.get("data", []),
                }
# Return the transformed successful JSON response
return JSONResponse(
content=final_response_data,
status_code=resp.status_code, # Forward original status
headers={'Content-Type': 'application/json'} # Set correct content type
)
except UnicodeDecodeError:
logger.error(f"[{request.url.path}] Failed to decode upstream models response as UTF-8.")
raise HTTPException(status_code=500, detail="Failed to decode upstream models response as UTF-8.")
except json.JSONDecodeError:
# Log JSON parsing error
logger.error(f"[{request.url.path}] Upstream models response was not valid JSON.", exc_info=True)
raise HTTPException(status_code=500, detail="Upstream models response was not valid JSON after decoding.")
except httpx.HTTPStatusError as e:
error_detail = e.response.text
try:
error_detail = e.response.json()
except json.JSONDecodeError:
pass
raise HTTPException(status_code=e.response.status_code, detail=error_detail)
except httpx.RequestError as e:
raise HTTPException(status_code=502, detail=f"Error communicating with models server: {e}")
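# The models endpoint returns the OpenRouter payload wrapped in the OpenAI
# list envelope; the rough shape (fields illustrative, per the passthrough
# above) is:
#
#   {"object": "list", "data": [{"id": "...", "created": ..., ...}, ...]}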
# Authentication is currently disabled here as well; restore the dependency to require the proxy API key:
# @app.post("/v1/chat/completions", dependencies=[Depends(verify_api_key)])
@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
"""Proxies chat completion requests to the remote server, handling streaming."""
return await forward_request(request, REMOTE_CHAT_COMPLETION_URL)
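# Example streaming request (illustrative; assumes the proxy runs on
# localhost:8000 with authentication left disabled, and the model name is a
# placeholder):
#
#   curl -N http://localhost:8000/v1/chat/completions \
#        -H "Content-Type: application/json" \
#        -d '{"model": "openai/gpt-4o", "messages": [{"role": "user", "content": "Hi"}], "stream": true}'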
# --- Health Check --- (Good practice for deployments)
@app.get("/health")
async def health_check():
"""Simple health check endpoint."""
return {"status": "ok"}
# --- Main Execution --- (For local testing with uvicorn)
if __name__ == "__main__":
import uvicorn
port = int(os.getenv("PORT", 8000)) # Allow port configuration via env
uvicorn.run(app, host="0.0.0.0", port=port)
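# The app can also be launched with the uvicorn CLI directly (standard
# uvicorn options shown):
#
#   uvicorn app:app --host 0.0.0.0 --port 8000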