|
|
import os |
|
|
import re |
|
|
import httpx |
|
|
from fastapi import FastAPI, Request, HTTPException, Security |
|
|
from fastapi.responses import StreamingResponse, Response |
|
|
from fastapi.security import APIKeyHeader, APIKeyQuery |
|
|
from itertools import cycle |
|
|
import asyncio |
|
|
import json |
|
|
|
|
|
|
|
|
# Optional shared secret that clients must present to use the proxy.
# When unset, get_api_key() lets every request through.
PROXY_API_KEY = os.environ.get("PROXY_API_KEY")

# Comma-separated list of upstream API keys; requests rotate through them.
VERTEX_EXPRESS_KEYS_STR = os.environ.get("VERTEX_EXPRESS_KEYS")

# Strip whitespace and drop empty entries so that input such as "a,,b",
# "a," or " , " never puts an empty key into the rotation.
VERTEX_EXPRESS_KEYS = [
    key.strip()
    for key in (VERTEX_EXPRESS_KEYS_STR or "").split(",")
    if key.strip()
]

# Fail fast at import time: the proxy cannot serve anything without a key.
if not VERTEX_EXPRESS_KEYS:
    raise ValueError("VERTEX_EXPRESS_KEYS environment variable not set or empty.")
|
|
|
|
|
|
|
|
# ASGI application; route handlers below register themselves on it.
app = FastAPI()

# Places a client may put its proxy credential. auto_error=False leaves the
# accept/reject decision to get_api_key() instead of the security scheme.
api_key_header = APIKeyHeader(name="x-goog-api-key", auto_error=False)
api_key_query = APIKeyQuery(name="key", auto_error=False)

# Endless iterator over the configured upstream keys, advanced under key_lock.
key_lock = asyncio.Lock()
key_rotator = cycle(VERTEX_EXPRESS_KEYS)

# Maps an upstream key to the project ID resolved for it by get_project_id().
project_id_cache = {}
|
|
|
|
|
async def get_api_key(
    key_query: str = Security(api_key_query),
    key_header: str = Security(api_key_header),
):
    """Authenticate a client against ``PROXY_API_KEY``.

    Args:
        key_query: Value of the ``key`` query parameter, as supplied by the
            ``api_key_query`` security scheme (may be ``None``).
        key_header: Value of the ``x-goog-api-key`` header, as supplied by the
            ``api_key_header`` security scheme (may be ``None``).

    Returns:
        The matching credential, or ``None`` when ``PROXY_API_KEY`` is not
        configured (authentication disabled).

    Raises:
        HTTPException: 401 when ``PROXY_API_KEY`` is configured and neither
            credential matches it.
    """
    import secrets

    # No proxy key configured: the proxy is intentionally open.
    if not PROXY_API_KEY:
        return None

    expected = PROXY_API_KEY.encode("utf-8")
    # The query parameter is checked before the header, as before.
    for candidate in (key_query, key_header):
        if candidate is None:
            continue
        # Constant-time comparison so response timing does not reveal how
        # much of a guessed key is correct. Compare bytes because
        # compare_digest() rejects non-ASCII str arguments.
        if secrets.compare_digest(candidate.encode("utf-8"), expected):
            return candidate

    raise HTTPException(status_code=401, detail="Invalid or missing API Key")
|
|
|
|
|
|
|
|
async def get_project_id(key: str):
    """Resolve and cache the project ID associated with an upstream API key.

    Sends a probe ``generateContent`` request without a project in the path
    and, when the upstream answers 404, extracts the project ID from the
    ``projects/<id>/locations/`` fragment of the error message.

    Args:
        key: Upstream API key to resolve.

    Returns:
        The project ID string; also stored in ``project_id_cache``.

    Raises:
        HTTPException: 500 when the upstream cannot be reached, when the
            error response does not contain a project ID, or when the probe
            unexpectedly succeeds.
    """
    if key in project_id_cache:
        return project_id_cache[key]

    # NOTE(review): the probe presumably relies on this model path being
    # rejected with a 404 whose message names the key's project -- confirm
    # before changing the model name.
    url = f"https://aiplatform.googleapis.com/v1/publishers/google/models/gemini-2.6-pro:generateContent?key={key}"
    headers = {'Content-Type': 'application/json'}

    async with httpx.AsyncClient() as client:
        try:
            # content= is the httpx parameter for a raw request body; data=
            # is meant for form fields and is deprecated for str/bytes.
            response = await client.post(url, headers=headers, content='{}')
            response.raise_for_status()
        except httpx.HTTPStatusError as e:
            if e.response.status_code == 404:
                # The error body is not guaranteed to be JSON or to have the
                # expected shape; treat anything else as "no message".
                try:
                    payload = e.response.json()
                except ValueError:
                    payload = None
                error = payload.get("error") if isinstance(payload, dict) else None
                error_message = error.get("message", "") if isinstance(error, dict) else ""
                if not isinstance(error_message, str):
                    error_message = ""

                match = re.search(r"projects/([^/]+)/locations/", error_message)
                if match:
                    project_id = match.group(1)
                    project_id_cache[key] = project_id
                    return project_id
            raise HTTPException(
                status_code=500,
                detail=f"Failed to extract project ID: {e.response.text}",
            ) from e
        except httpx.RequestError as e:
            # Transport-level failure (DNS, connect, timeout, ...). Report
            # only the exception type so no request details leak to clients.
            raise HTTPException(
                status_code=500,
                detail=f"Failed to extract project ID: {type(e).__name__}",
            ) from e

    # The probe succeeded, so there was no error message to read the ID from.
    raise HTTPException(status_code=500, detail="Could not extract project ID from any key.")
|
|
|
|
|
|
|
|
# Model alias accepted from clients and the model it is rewritten to.
_LEGACY_IMAGE_MODEL = "gemini-2.0-flash-exp-image-generation"
_IMAGE_MODEL = "gemini-2.5-flash-image-preview"

# Client request headers that are never forwarded upstream.
_REQUEST_HEADERS_NOT_FORWARDED = frozenset(
    {"host", "authorization", "x-goog-api-key", "content-length"}
)

# Upstream response headers that describe the wire framing/encoding of the
# upstream body. httpx hands us the already-decoded body, so copying these
# would mislabel the bytes we send to the client.
_RESPONSE_HEADERS_NOT_FORWARDED = frozenset(
    {"content-encoding", "content-length", "transfer-encoding", "connection"}
)


def _rewrite_request_body(raw_body, model_path):
    """Return ``(body_bytes, model_path)`` adjusted for the upstream model.

    A body that is not a JSON object is returned unchanged so the upstream
    can reject it itself instead of the proxy failing with a 500.
    """
    try:
        request_json = json.loads(raw_body)
    except ValueError:
        # Covers json.JSONDecodeError and UnicodeDecodeError (invalid UTF-8).
        return raw_body, model_path

    if _LEGACY_IMAGE_MODEL in model_path:
        model_path = model_path.replace(_LEGACY_IMAGE_MODEL, _IMAGE_MODEL)

    if not isinstance(request_json, dict):
        return raw_body, model_path

    generation_config = request_json.setdefault("generationConfig", {})
    if not isinstance(generation_config, dict):
        return raw_body, model_path

    if _IMAGE_MODEL in model_path:
        # These options are stripped for the image model and the response
        # modalities are forced to text + image.
        if "thinkingConfig" in generation_config:
            del generation_config["thinkingConfig"]
            print(generation_config)
        generation_config.pop("responseMimeType", None)
        generation_config["responseModalities"] = ["TEXT", "IMAGE"]

    return json.dumps(request_json).encode('utf-8'), model_path


def _drop_empty_parts(response_json):
    """Remove falsy entries from each candidate's ``content.parts`` in place."""
    if not isinstance(response_json, dict):
        return
    candidates = response_json.get('candidates')
    if not isinstance(candidates, list):
        return
    for candidate in candidates:
        content = candidate.get('content') if isinstance(candidate, dict) else None
        if isinstance(content, dict) and isinstance(content.get('parts'), list):
            content['parts'] = [part for part in content['parts'] if part]


def _forwardable_response_headers(upstream_headers):
    """Copy upstream headers, minus those invalidated by body decoding."""
    return {
        name: value
        for name, value in upstream_headers.items()
        if name.lower() not in _RESPONSE_HEADERS_NOT_FORWARDED
    }


@app.post("/v1beta/models/{model_path:path}")
async def proxy(request: Request, model_path: str, api_key: str = Security(get_api_key)):
    """Forward a POST under ``/v1beta/models/`` to aiplatform.googleapis.com.

    Picks the next upstream key in rotation, resolves its project ID,
    rewrites the request body where needed and relays the upstream answer.

    Args:
        request: Incoming request; its body and most headers are forwarded.
        model_path: Everything after ``/v1beta/models/`` (model and method).
        api_key: Unused result of ``get_api_key``; declared so that client
            authentication runs before the handler.

    Returns:
        A ``StreamingResponse`` when ``model_path`` contains
        ``streamGenerateContent`` and the upstream answers 200; otherwise a
        ``Response`` carrying the upstream status and body.

    Raises:
        HTTPException: Propagated from ``get_project_id``.
    """
    # The lock serialises access to the shared round-robin iterator.
    async with key_lock:
        express_key = next(key_rotator)

    project_id = await get_project_id(express_key)

    raw_request_body = await request.body()
    request_body_to_send, model_path = _rewrite_request_body(raw_request_body, model_path)

    is_streaming = "streamGenerateContent" in model_path
    target_url = f"https://aiplatform.googleapis.com/v1/projects/{project_id}/locations/global/publishers/google/models/{model_path}?key={express_key}"
    if is_streaming:
        target_url = target_url + "&alt=sse"

    headers_to_proxy = {
        k: v for k, v in request.headers.items()
        if k.lower() not in _REQUEST_HEADERS_NOT_FORWARDED
    }

    # NOTE(review): debug output; this prints the full request payload.
    print(request_body_to_send)

    client = httpx.AsyncClient(timeout=None)
    response = None
    # Set once the streaming generator has taken over closing the resources.
    stream_owns_resources = False
    try:
        req = client.build_request(
            method=request.method,
            url=target_url,
            headers=headers_to_proxy,
            content=request_body_to_send,
        )
        response = await client.send(req, stream=True)

        if response.status_code != 200:
            response_data = await response.aread()
            return Response(
                content=response_data,
                status_code=response.status_code,
                headers=_forwardable_response_headers(response.headers),
            )

        if is_streaming:
            async def stream_generator():
                try:
                    async for line in response.aiter_lines():
                        print(line)
                        yield f"{line}\n"
                finally:
                    await response.aclose()
                    await client.aclose()

            streaming_response = StreamingResponse(
                stream_generator(),
                media_type=response.headers.get("content-type"),
            )
            stream_owns_resources = True
            return streaming_response

        response_data = await response.aread()
        try:
            response_json = json.loads(response_data)
        except ValueError:
            # Not JSON: relay the upstream body untouched.
            modified_response_data = response_data
        else:
            _drop_empty_parts(response_json)
            modified_response_data = json.dumps(response_json).encode('utf-8')

        content_type = response.headers.get("content-type")
        return Response(
            content=modified_response_data,
            status_code=response.status_code,
            headers={"content-type": content_type} if content_type else {},
        )
    finally:
        # Also runs when build_request()/send() raise, so the client is
        # never leaked; skipped only when the stream generator will close.
        if not stream_owns_resources:
            if response is not None:
                await response.aclose()
            await client.aclose()
|
|
|
|
|
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when run as a script.
    from uvicorn import run as serve

    # Listen on every interface, port 7860.
    serve(app, port=7860, host="0.0.0.0")