api-proxy / app.py
tanbushi's picture
asdf
ab3f32d
# uvicorn app:app --host 0.0.0.0 --port 7860 --reload
from fastapi import FastAPI, Request, HTTPException, Response
from fastapi.middleware.cors import CORSMiddleware
import re, json
import httpx
import uuid
# Application instance that the route decorators below register against.
app = FastAPI()

# NOTE(review): wildcard origins combined with allow_credentials=True is a
# very permissive CORS policy -- confirm this is intended before exposing
# the proxy publicly.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
@app.get("/")
def greet_json():
    """Return the fixed greeting payload served at the root path."""
    greeting = {"Hello": "World!"}
    return greeting
@app.get("/v1/{dest_url:path}")
async def gre_dest_url(dest_url: str):
    """Echo back the path captured after ``/v1/`` without modification."""
    captured_path = dest_url
    return captured_path
# Value sent upstream in the ``x-goog-api-client`` request header.
API_CLIENT: str = "genai-js/0.21.0"
# Model name used when a chat-completions request body carries no ``model``.
DEFAULT_MODEL: str = "gemini-1.5-pro-latest"
async def transform_request(req: dict) -> dict:
    """Return the request payload unchanged.

    Placeholder hook: no transformation is implemented yet, so the very
    same object that was passed in is handed back to the caller.
    """
    payload = req
    return payload
# Gemini ``finishReason`` values translated to the OpenAI ``finish_reason``
# vocabulary; any value not listed here is passed through unchanged.
_FINISH_REASON_MAP = {
    "STOP": "stop",
    "MAX_TOKENS": "length",
    "SAFETY": "content_filter",
    "RECITATION": "content_filter",
}


def _extract_candidate_text(content):
    """Flatten a Gemini ``content`` object into a plain message string."""
    if content is None:
        return ""
    if isinstance(content, dict):
        # Gemini nests the generated text under content.parts[].text.
        parts = content.get("parts") or []
        return "".join(
            (part.get("text") or "") for part in parts if isinstance(part, dict)
        )
    # Already a plain value (e.g. a string): keep it as it is.
    return content


async def process_completions_response(data: dict, model: str, id: str):
    """Convert a Gemini ``generateContent`` response to OpenAI chat format.

    Args:
        data: Parsed JSON body of the upstream response. ``candidates`` and
            ``usageMetadata`` are optional.
        model: Model name to report back to the client.
        id: Completion identifier to report back to the client. (The name
            shadows the builtin but is kept for caller compatibility.)

    Returns:
        A JSON string holding a ``chat.completion`` object. ``usage`` is an
        empty object when the upstream response carries no usage metadata.
    """
    # Function-scope import so the block does not depend on a file-level one.
    import time

    choices = []
    for index, candidate in enumerate(data.get("candidates") or []):
        reason = candidate.get("finishReason", "stop")
        choices.append({
            "finish_reason": _FINISH_REASON_MAP.get(reason, reason),
            "index": index,
            "message": {
                "content": _extract_candidate_text(candidate.get("content", "")),
                "role": "assistant",
            },
        })

    usage = {}
    metadata = data.get("usageMetadata")
    if isinstance(metadata, dict):
        # ``tokenCount`` is only honoured as a fallback for payloads that
        # carry the single counter the previous implementation read.
        fallback_count = metadata.get("tokenCount", 0)
        usage = {
            "completion_tokens": metadata.get("candidatesTokenCount", fallback_count),
            "prompt_tokens": metadata.get("promptTokenCount", 0),
            "total_tokens": metadata.get("totalTokenCount", fallback_count),
        }

    response_data = {
        "id": id,
        "choices": choices,
        # The upstream payload has no creation time, so stamp the conversion time.
        "created": int(time.time()),
        "model": model,
        "object": "chat.completion",
        "usage": usage,
    }
    return json.dumps(response_data, ensure_ascii=False)
def _bearer_token(authorization):
    """Return the credential part of an ``Authorization`` value, or None."""
    if not authorization:
        return None
    pieces = authorization.split()
    # A header with a scheme but no credential yields no key instead of an IndexError.
    return pieces[1] if len(pieces) > 1 else None


def _requested_model(body):
    """Pick the model name out of a chat-completions JSON request body.

    Falls back to DEFAULT_MODEL when the body names no usable model and
    strips a leading ``models/`` prefix. Raises HTTPException(400) when the
    body is not valid UTF-8 JSON.
    """
    try:
        req_body = json.loads(body.decode("utf-8"))
    except ValueError as exc:  # covers UnicodeDecodeError and JSONDecodeError
        raise HTTPException(status_code=400, detail="Request body is not valid JSON") from exc
    model = req_body.get("model") if isinstance(req_body, dict) else None
    if not isinstance(model, str) or not model:
        return DEFAULT_MODEL
    prefix = "models/"
    return model[len(prefix):] if model.startswith(prefix) else model


def _upstream_error_payload(response):
    """Build the JSON-serialisable body describing a non-200 upstream reply."""
    try:
        error_data = response.json()
    except ValueError:
        error_data = None
    if not isinstance(error_data, dict):
        # Non-JSON (or non-object) error bodies are wrapped so an id can be attached.
        error_data = {"status_code": response.status_code, "detail": response.text}
    error_data["id"] = f"chatcmpl-{uuid.uuid4()}"
    return error_data


def _cors_json_reply(payload):
    """Wrap a serialised JSON string in a Response carrying the CORS header."""
    reply = Response(content=payload, media_type="application/json")
    reply.headers["Access-Control-Allow-Origin"] = "*"
    return reply


@app.post("/v1/{dest_url:path}")
async def proxy_url(dest_url: str, request: Request):
    """Forward a POST to the URL encoded in the path and relay the JSON reply.

    The first ``/`` of ``dest_url`` is turned into ``://`` to rebuild the
    target URL. Targets ending in ``/chat/completions`` are redirected to
    ``<parent>/<model>:generateContent`` and the reply is converted with
    ``process_completions_response``; every other target is forwarded as-is.

    Returns:
        A JSON ``Response`` (upstream errors are relayed as a JSON body), or
        a plain dict when a 200 reply is not JSON.

    Raises:
        HTTPException: 400 for an unparsable chat-completions body, 500 when
            the upstream request itself fails.
    """
    body = await request.body()
    headers = dict(request.headers)
    # Drop the client's Content-Length and Host before forwarding.
    headers.pop("content-length", None)
    headers.pop("host", None)

    # Keys of dict(request.headers) are lower-case, so the lookup and the
    # overrides below use lower-case names; a differently-cased key would
    # miss the client's value or add a second entry next to it.
    api_key = _bearer_token(headers.get("authorization"))
    headers["x-goog-api-client"] = API_CLIENT
    if api_key:
        headers["x-goog-api-key"] = api_key
    headers["content-type"] = "application/json"

    # NOTE(review): the destination is entirely client-controlled, which makes
    # this endpoint an open proxy -- consider restricting the allowed hosts.
    dest_url = dest_url.replace("/", "://", 1)

    is_completions = dest_url.endswith("/chat/completions")
    model = None
    target_url = dest_url
    if is_completions:
        model = _requested_model(body)
        target_url = f"{dest_url.rsplit('/', 1)[0]}/{model}:generateContent"
        if api_key:
            # The key already travels in x-goog-api-key; do not also forward
            # it to the generateContent endpoint as a bearer credential.
            headers.pop("authorization", None)
        # NOTE(review): the body is forwarded untouched; transform_request is
        # not applied, so the upstream receives the client's original payload.

    async with httpx.AsyncClient() as client:
        try:
            response = await client.post(target_url, content=body, headers=headers)
        except httpx.RequestError as e:
            print(f"Request error: {e}")
            raise HTTPException(status_code=500, detail=str(e)) from e

        if response.status_code != 200:
            error_data = _upstream_error_payload(response)
            print(f"Error response: {error_data}")
            return _cors_json_reply(json.dumps(error_data, ensure_ascii=False))

        if 'application/json' not in response.headers.get('content-type', ''):
            return {"error": "Response is not in JSON format", "id": f"chatcmpl-{uuid.uuid4()}"}

        json_response = response.json()
        completion_id = f"chatcmpl-{uuid.uuid4()}"
        if is_completions:
            payload = await process_completions_response(json_response, model, completion_id)
        else:
            if isinstance(json_response, dict):
                json_response["id"] = completion_id
            payload = json.dumps(json_response, ensure_ascii=False)
        return _cors_json_reply(payload)
# uvicorn app:app --host 0.0.0.0 --port 7860 --reload
from fastapi import FastAPI, Request, HTTPException, Response
import re, json
import httpx
import uuid
# NOTE(review): this rebinds ``app`` to a fresh FastAPI instance; routes and
# middleware registered on the earlier ``app`` object are not attached to it.
app = FastAPI()
@app.get("/")
def greet_json():
    """Serve a static hello-world JSON object at the site root."""
    payload = {"Hello": "World!"}
    return payload
@app.get("/v1/{dest_url:path}")
async def gre_dest_url(dest_url: str):
    """Return the path segment that follows ``/v1/`` exactly as received."""
    echoed = dest_url
    return echoed
@app.post("/v1/{dest_url:path}")
async def proxy_url(dest_url: str, request: Request):
    """Forward a POST to the URL encoded in the path and relay the JSON reply.

    The first ``/`` of ``dest_url`` is turned into ``://`` to rebuild the
    target URL; the request body and headers are forwarded with a fixed
    User-Agent.

    Returns:
        A JSON ``Response`` with a generated ``id`` added (upstream errors
        are relayed as a JSON body), or a plain dict when a 200 reply is
        not JSON.

    Raises:
        HTTPException: 500 when the upstream request itself fails.
    """
    body = await request.body()
    headers = dict(request.headers)
    # Drop the client's Content-Length and Host before forwarding.
    headers.pop("content-length", None)
    headers.pop("host", None)
    # Keys of dict(request.headers) are lower-case; using the lower-case name
    # replaces the client's value instead of adding a second entry.
    headers["user-agent"] = "PostmanRuntime/7.43.0"

    dest_url = re.sub('/', '://', dest_url, count=1)

    async with httpx.AsyncClient() as client:
        # Keep credentials out of the debug output.
        loggable = {
            name: ("<redacted>" if name == "authorization" else value)
            for name, value in headers.items()
        }
        print(f"Request Headers: {loggable}")
        try:
            response = await client.post(dest_url, content=body, headers=headers)
        except httpx.RequestError as e:
            raise HTTPException(status_code=500, detail=str(e)) from e

        if response.status_code == 200:
            if 'application/json' not in response.headers.get('content-type', ''):
                return {"error": "Response is not in JSON format", "id": f"chatcmpl-{uuid.uuid4()}"}
            json_response = response.json()
            json_response['id'] = f"chatcmpl-{uuid.uuid4()}"
            payload = json.dumps(json_response)
        else:
            # Relay the upstream error as JSON, wrapping non-JSON bodies.
            try:
                error_data = response.json()
                error_data['id'] = f"chatcmpl-{uuid.uuid4()}"
            except ValueError:
                error_data = {"status_code": response.status_code, "detail": response.text, "id": f"chatcmpl-{uuid.uuid4()}"}
            # json.dumps, not str(): the repr of a dict is not valid JSON.
            payload = json.dumps(error_data)

        resp = Response(content=payload, media_type="application/json")
        resp.headers["Access-Control-Allow-Origin"] = "*"
        return resp