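# dp / main.py
# BG5's picture
# Update main.py
# 31b1c18 verified
# (Header lines carried over from the hosting page; kept as comments so the module parses.)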
import hmac
import json
import os
from datetime import datetime
from typing import Optional

import httpx
from fastapi import FastAPI, Request, HTTPException, Header, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
app = FastAPI()

# CORS policy: accept cross-origin requests from any origin, with any method
# and any request header.
# NOTE(review): the Starlette CORS docs say wildcard values are not meant to be
# combined with allow_credentials=True. Auth here is a bearer header rather than
# cookies, so credentialed CORS may be unnecessary; confirm before tightening.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# Shared secret for bearer-token auth, read from the TOKEN environment variable.
# Unset or empty means authentication is disabled (see verify_token).
TOKEN = os.environ.get("TOKEN")
# FastAPI dependency that enforces bearer-token authentication.
async def verify_token(authorization: Optional[str] = Header(None)) -> bool:
    """Require ``Authorization: Bearer <TOKEN>`` when a TOKEN is configured.

    If the ``TOKEN`` module constant is unset or empty, authentication is
    disabled and every request is allowed.

    Args:
        authorization: the raw ``Authorization`` request header, or ``None``
            when the header is absent.

    Returns:
        ``True`` when the request is allowed.

    Raises:
        HTTPException: 401 when a token is configured and the header is
            missing or does not exactly equal ``"Bearer <TOKEN>"``.
    """
    if not TOKEN:
        return True
    expected = f"Bearer {TOKEN}"
    # Constant-time comparison, so response timing does not reveal how much of
    # a guessed token was correct. Compare UTF-8 bytes because
    # compare_digest rejects non-ASCII str arguments with TypeError.
    if not authorization or not hmac.compare_digest(
        authorization.encode("utf-8"), expected.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Unauthorized")
    return True
# Static model catalogue served by GET /v1/models, in the OpenAI "list" shape.
# Each row is (model id, "created" integer, owner).
# NOTE(review): the "created" values look like placeholder epoch seconds, not
# real release dates; confirm before relying on them.
_MODEL_ROWS = (
    ("deepseek-ai/DeepSeek-R1", 1624980000, "deepseek-ai"),
    ("Qwen/QwQ-32B", 1640000000, "Qwen-ai"),
    ("deepseek-ai/DeepSeek-V3", 1632000000, "deepseek-ai"),
    ("deepseek-ai/DeepSeek-V3-0324", 1632000000, "deepseek-ai"),
    ("deepseek-ai/DeepSeek-R1-Distill-Llama-70B", 1640000000, "deepseek-ai"),
    ("deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", 1645000000, "deepseek-ai"),
)

MODELS = {
    "object": "list",
    "data": [
        {"id": model_id, "object": "model", "created": created, "owned_by": owner}
        for model_id, created, owner in _MODEL_ROWS
    ],
}
@app.options("/{path:path}")
async def handle_options():
    """Reply to any OPTIONS request with permissive CORS headers and no body.

    Advertises POST/GET/OPTIONS, allows the Content-Type and Authorization
    request headers, and lets clients cache the preflight for 86400 seconds.

    NOTE(review): the CORSMiddleware registered on ``app`` probably answers
    real browser preflights before routing reaches this handler; confirm
    whether this route is still needed.
    """
    preflight = [
        ("Access-Control-Allow-Origin", "*"),
        ("Access-Control-Allow-Methods", "POST, GET, OPTIONS"),
        ("Access-Control-Allow-Headers", "Content-Type, Authorization"),
        ("Access-Control-Max-Age", "86400"),
    ]
    return Response(headers=dict(preflight))
@app.get("/v1/models")
async def get_models(auth: bool = Depends(verify_token)):
    """Return the static ``MODELS`` catalogue as JSON.

    The ``verify_token`` dependency enforces the bearer token when one is
    configured. The response also carries a wildcard
    ``Access-Control-Allow-Origin`` header.
    """
    cors = {"Access-Control-Allow-Origin": "*"}
    payload = MODELS
    return JSONResponse(headers=cors, content=payload)
@app.post("/v1/chat/completions")
async def chat_completions(request: Request, auth: bool = Depends(verify_token)):
    """OpenAI-style chat completion endpoint that proxies to DeepInfra.

    The client's JSON body is forwarded to DeepInfra's OpenAI-compatible
    endpoint with browser-like headers. The upstream request always has
    ``"stream": true``. If the client asked for streaming, or omitted
    ``stream``, the SSE events are relayed through ``chunk``. Otherwise
    ``chat`` collapses them into a single JSON object.

    Raises:
        HTTPException: 400 if the body is not valid JSON or is not a JSON
            object; 500 if the non-streaming upstream call fails.
    """
    headers = {
        "Origin": "https://deepinfra.com",
        "Referer": "https://deepinfra.com/",
        "Sec-Fetch-Dest": "empty",
        "Sec-Fetch-Mode": "cors",
        "Sec-Fetch-Site": "same-site",
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/137.0.0.0 Safari/537.36",
        "X-Deepinfra-Source": "model-embed",
        "accept": "text/event-stream",
        "sec-ch-ua": "\"Google Chrome\";v=\"137\", \"Chromium\";v=\"137\", \"Not/A)Brand\";v=\"24\"",
        "sec-ch-ua-mobile": "?0",
        "sec-ch-ua-platform": "\"Windows\""
    }
    url = "https://api.deepinfra.com/v1/openai/chat/completions"

    # Parse the client body in its own try, so only *client* JSON problems map
    # to 400. JSON decode errors raised while parsing the upstream stream
    # inside chat() must not be reported to the caller as "Invalid JSON".
    try:
        body = await request.json()
    except ValueError as e:  # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError
        raise HTTPException(status_code=400, detail="Invalid JSON") from e
    if not isinstance(body, dict):
        # e.g. a JSON array or string: there is no "stream"/"messages" to read.
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    # NOTE(review): the OpenAI API defaults "stream" to false, but this proxy
    # streams when the field is absent; kept for existing clients.
    stream = body.get("stream", True)
    # The original author notes that the upstream only supports stream=True.
    body["stream"] = True

    if stream:
        # Failures inside the generator happen after the 200 response has
        # started, so they cannot be converted into an HTTP error status here.
        return StreamingResponse(
            chunk(url, headers, body),
            media_type="text/event-stream"
        )

    try:
        return await chat(url, headers, body)
    except HTTPException:
        raise
    except Exception as e:  # includes httpx.HTTPError and upstream decode errors
        raise HTTPException(status_code=500, detail=str(e)) from e
async def chunk(url, headers, body):
    """Relay the upstream SSE stream to the client, one event per ``data:`` line.

    Args:
        url: upstream chat-completions endpoint.
        headers: HTTP headers sent with the upstream request.
        body: JSON-serializable request payload, posted as-is.

    Yields:
        ``"data: ...\\n\\n"`` SSE frames for every upstream line that starts
        with ``"data: "``, including a terminal ``data: [DONE]`` if the
        upstream sends one.
    """
    async with httpx.AsyncClient() as client:
        async with client.stream("POST", url, json=body, headers=headers, timeout=1200) as response:
            # NOTE(review): the upstream status code is not checked. An error
            # reply without "data: " lines results in an empty stream for the client.
            # aiter_lines() splits on line boundaries. aiter_text() yields
            # arbitrary network-sized chunks that can hold several events or
            # half of one, so filtering those by prefix dropped or cut events.
            async for line in response.aiter_lines():
                if line.startswith("data: "):
                    # Restore the blank-line event terminator that aiter_lines() strips.
                    yield f"{line}\n\n"
async def chat(url, headers, body):
    """Consume the upstream SSE stream and assemble one JSON response.

    The ``choices[0].delta.content`` fragments of every ``data: {...}`` event
    are concatenated. The last JSON event is returned with
    ``choices[0].delta.content`` replaced by the full text. The same text is
    also placed in ``choices[0].message`` (``role``/``content``), the field
    OpenAI-style clients read from non-streaming responses.

    Args:
        url: upstream chat-completions endpoint.
        headers: HTTP headers sent with the upstream request.
        body: JSON-serializable request payload, posted as-is.

    Returns:
        dict: the last upstream event, augmented as described above.

    Raises:
        httpx.HTTPStatusError: if the upstream answers with a 4xx/5xx status.
        ValueError: if no JSON ``data:`` event was received, or
            ``json.JSONDecodeError`` (a ``ValueError``) if an event is malformed.
    """
    parts = []
    last_event = None
    async with httpx.AsyncClient() as client:
        async with client.stream("POST", url, json=body, headers=headers, timeout=1200) as response:
            # Fail clearly instead of trying to parse an error body as SSE.
            response.raise_for_status()
            # aiter_lines(), not aiter_text(): json.loads needs exactly one
            # complete event per call, and text chunks are not line-aligned.
            async for line in response.aiter_lines():
                # Skips blank separators and the non-JSON "data: [DONE]" sentinel.
                if not line.startswith("data: {"):
                    continue
                event = json.loads(line[len("data: "):])
                choices = event.get("choices") or []
                if choices:
                    # Role-only and finish chunks may have null or missing content.
                    delta = choices[0].get("delta") or {}
                    parts.append(delta.get("content") or "")
                last_event = event

    if last_event is None:
        raise ValueError("upstream returned no completion events")

    content = "".join(parts)
    # A trailing event may have an empty "choices" list (e.g. a usage-only
    # chunk), so make sure there is a first choice to write into.
    choices = last_event.get("choices") or [{"index": 0}]
    last_event["choices"] = choices
    first = choices[0]
    delta = first.get("delta") or {}
    delta["content"] = content
    first["delta"] = delta  # kept for backward compatibility with existing callers
    first["message"] = {"role": "assistant", "content": content}
    return last_event
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def catch_all(path: str, request: Request):
    """Fallback route for any path that earlier routes did not match.

    OPTIONS receives a permissive CORS reply with no body. GET, POST, PUT and
    DELETE receive 404 Not Found.
    """
    # Guard clause: every method other than OPTIONS is simply not found.
    if request.method != "OPTIONS":
        raise HTTPException(status_code=404, detail="Not Found")

    # NOTE(review): handle_options is registered earlier for OPTIONS on the
    # same path pattern, so this branch is probably unreachable; kept for parity.
    allow = {}
    allow["Access-Control-Allow-Origin"] = "*"
    allow["Access-Control-Allow-Methods"] = "POST, GET, OPTIONS"
    allow["Access-Control-Allow-Headers"] = "Content-Type, Authorization"
    allow["Access-Control-Max-Age"] = "86400"
    return Response(headers=allow)
if __name__ == "__main__":
    # Local entry point: serve the app with uvicorn on all interfaces, port 7860.
    from uvicorn import run as serve_app

    serve_app(app, port=7860, host="0.0.0.0")