nim-proxy / main.py
ZSNonSKY's picture
Update main.py
410f92b verified
import os

import httpx
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import Response, StreamingResponse
app = FastAPI()

# Configuration
# Upstream NVIDIA endpoint. It already ends in "/v1", so forwarded paths must
# not repeat that prefix.
NVIDIA_BASE_URL = "https://integrate.api.nvidia.com/v1"
# Read once at import time; None when the environment variable is unset.
API_KEY = os.environ.get("NVIDIA_API_KEY")
@app.get("/")
def home():
    """Health-check endpoint: report that the proxy process is up."""
    payload = dict(status="Proxy is running.")
    return payload
# Request headers that must not be forwarded verbatim. Hop-by-hop headers
# (connection, keep-alive, te, transfer-encoding, upgrade) only describe the
# client->proxy hop. httpx recomputes host, content-length and
# accept-encoding itself. user-agent is dropped as before.
_DROPPED_REQUEST_HEADERS = frozenset(
    {
        "content-length",
        "connection",
        "host",
        "user-agent",
        "accept-encoding",
        "keep-alive",
        "te",
        "transfer-encoding",
        "upgrade",
    }
)


def _upstream_path(path: str) -> str:
    """Map the path received by this proxy onto a path under NVIDIA_BASE_URL.

    A bare "v1" or "v1/" is treated as a chat-completions call. Otherwise a
    leading "v1/" (or "v1") is stripped, because NVIDIA_BASE_URL already ends
    in "/v1".
    """
    # --- MAGIC FIX FOR "v1" PATH ---
    # If the client sends its request to ".../v1" directly, assume it wants
    # chat completions.
    if path in ("v1", "v1/"):
        print("Auto-correcting path: Redirecting to chat/completions")
        return "chat/completions"
    if path.startswith("v1/"):
        return path[3:]
    if path.startswith("v1"):
        # NOTE(review): this also rewrites e.g. "v1beta/x" -> "beta/x"; kept
        # unchanged for backward compatibility.
        return path[2:]
    return path


def _upstream_headers(request: Request) -> dict:
    """Copy the client's headers for the upstream call.

    Drops the headers in _DROPPED_REQUEST_HEADERS and replaces any
    client-supplied authorization with the server-side NVIDIA key.
    """
    headers = dict(request.headers)
    for name in _DROPPED_REQUEST_HEADERS:
        headers.pop(name, None)
    headers["authorization"] = f"Bearer {API_KEY}"
    return headers


@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request):
    """Forward any request to the NVIDIA API and relay the answer.

    The upstream status code and Content-Type are passed through. An error
    reported by NVIDIA therefore reaches the caller as an error, not as a 200.
    Successful bodies are streamed chunk by chunk (e.g. SSE chat completions).

    Raises:
        HTTPException: 500 if NVIDIA_API_KEY is unset; 502 if the upstream
            request fails at the transport level (connect error, timeout, ...).
    """
    # --- LOGGING FOR DEBUGGING ---
    print(f"Incoming request path: {path}")

    if not API_KEY:
        print("Error: API Key is missing")
        raise HTTPException(status_code=500, detail="NVIDIA_API_KEY secret not set.")

    # Construct the target URL, keeping the client's query string.
    target_url = f"{NVIDIA_BASE_URL}/{_upstream_path(path)}"
    if request.url.query:
        target_url = f"{target_url}?{request.url.query}"
    print(f"Forwarding to: {target_url}")

    body = await request.body()
    headers = _upstream_headers(request)

    # Open the upstream request here, not inside the body generator. That way
    # the upstream status and headers are known before our response starts.
    client = httpx.AsyncClient(timeout=60.0)
    try:
        upstream_request = client.build_request(
            request.method, target_url, headers=headers, content=body
        )
        upstream = await client.send(upstream_request, stream=True)
    except httpx.HTTPError as exc:
        await client.aclose()
        print(f"Server Error: {exc}")
        raise HTTPException(status_code=502, detail=str(exc)) from exc
    except BaseException:
        # Never leak the client, whatever went wrong (including cancellation).
        await client.aclose()
        raise

    print(f"Nvidia status: {upstream.status_code}")
    media_type = upstream.headers.get("content-type", "application/json")

    if upstream.status_code >= 400:
        try:
            error_content = await upstream.aread()
        except httpx.HTTPError as exc:
            print(f"Server Error: {exc}")
            raise HTTPException(status_code=502, detail=str(exc)) from exc
        finally:
            await upstream.aclose()
            await client.aclose()
        # errors="replace": a non-UTF-8 error body must not mask the real error.
        print(f"Nvidia Error: {error_content.decode(errors='replace')}")
        return Response(
            content=error_content,
            status_code=upstream.status_code,
            media_type=media_type,
        )

    async def relay():
        """Yield the upstream body, always closing the response and client."""
        try:
            async for chunk in upstream.aiter_bytes():
                yield chunk
        except httpx.HTTPError as exc:
            # The status line is already sent. Log and end the stream rather
            # than appending error text to a JSON/SSE body.
            print(f"Server Error: {exc}")
        finally:
            await upstream.aclose()
            await client.aclose()

    return StreamingResponse(
        relay(), status_code=upstream.status_code, media_type=media_type
    )