xai / app.py
forzen's picture
Upload 2 files
8d674ec verified
import asyncio
import json
import logging
import os

import requests
from fastapi import FastAPI, HTTPException, Request, status
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from requests.exceptions import RequestException
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ASGI application instance; the route decorators below register handlers on it.
app = FastAPI()
# Base URL prepended to every upstream xAI endpoint this proxy calls.
XAI_API_BASE = "https://api.x.ai/v1"
class ChatRequest(BaseModel):
    # Request body accepted by POST /v1/chat/completions.
    # `messages` and `model` are required; every other field is optional.

    # Required fields.
    messages: list[dict] = Field(..., description="消息列表")
    model: str = Field(..., description="模型ID")

    # Optional generation controls; None means the client did not supply one.
    max_tokens: int | None = Field(default=None)
    temperature: float | None = Field(default=None)
    top_p: float | None = Field(default=None)

    # When true, the handler relays the upstream reply as an event stream.
    stream: bool = Field(default=False)

    # Accepted from clients alongside the fields above.
    presence_penalty: float | None = Field(default=None)
    frequency_penalty: float | None = Field(default=None)
async def stream_generator(response, stream):
    """Relay an upstream server-sent-event stream to the client.

    Args:
        response: Upstream response object exposing ``iter_lines()`` (yielding
            ``bytes``) and ``close()``.
        stream: Streaming flag of the originating request. Kept for interface
            compatibility; the response is closed regardless of its value.

    Yields:
        One SSE frame per upstream ``data:`` line, formatted as
        ``data: <payload>`` followed by a blank line. Blank lines and SSE
        comment lines (starting with ``:``) are skipped. Any other line is
        reported as a ``{"error": "Invalid chunk format"}`` frame, and a
        ``RequestException`` raised while reading is reported as a
        ``{"error": "<message>"}`` frame.
    """
    end_of_stream = object()
    try:
        line_iter = response.iter_lines()
        while True:
            # iter_lines() blocks on socket reads; fetching each line in a
            # worker thread keeps the event loop free for other requests.
            # The sentinel default avoids StopIteration crossing the thread.
            raw_line = await asyncio.to_thread(next, line_iter, end_of_stream)
            if raw_line is end_of_stream:
                break
            if not raw_line:
                continue
            line = raw_line.decode("utf-8")
            if line.startswith(":"):
                # SSE comment / keep-alive line: not data, not an error.
                continue
            if line.startswith("data:"):
                # The single space after "data:" is optional in SSE.
                payload = line[len("data:"):].removeprefix(" ")
                yield f"data: {payload}\n\n"
            else:
                yield f"data: {json.dumps({'error': 'Invalid chunk format'})}\n\n"
    except RequestException as e:
        yield f"data: {json.dumps({'error': str(e)})}\n\n"
    finally:
        # Always release the upstream connection, including on client
        # disconnect (generator cancellation).
        response.close()
def _describe_upstream_error(exc):
    """Return ``(status_code, detail)`` for a failed upstream request.

    Uses the upstream status and error body when a response was received;
    otherwise reports 504 with the exception text.
    """
    upstream = exc.response
    if upstream is None:
        return 504, str(exc)
    try:
        try:
            body = upstream.json()
        except ValueError:
            # Covers every JSON decode error flavour (all subclass ValueError).
            return upstream.status_code, upstream.text[:500]
        if isinstance(body, dict):
            return upstream.status_code, body.get("error", upstream.text)
        # Valid JSON that is not an object (list, string, ...) has no "error" key.
        return upstream.status_code, upstream.text[:500]
    finally:
        # With stream=True the connection stays open until explicitly closed.
        upstream.close()


@app.post("/v1/chat/completions")
async def chat_completions(req: ChatRequest, request: Request):
    """Forward a chat completion request to the xAI API.

    Args:
        req: Parsed request body.
        request: Incoming HTTP request; its ``Authorization`` header supplies
            the API key that is forwarded upstream.

    Returns:
        A ``StreamingResponse`` relaying the upstream event stream when
        ``req.stream`` is true, otherwise the decoded upstream JSON body.

    Raises:
        HTTPException: 401 when no API key is supplied; the upstream status
            code when xAI answers with an error; 504 when the request fails
            without a response; 502 when the upstream body is not valid JSON.
    """
    logger.info(f"收到请求: {req.dict()}")

    # Accept "Bearer <key>" with any scheme casing, or a bare key. Only a
    # leading scheme is stripped, never a "Bearer " substring inside the key.
    auth_header = request.headers.get("Authorization", "").strip()
    scheme, _, credentials = auth_header.partition(" ")
    api_key = credentials.strip() if scheme.lower() == "bearer" else auth_header
    if not api_key:
        logger.error("缺少API密钥")
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="缺少API密钥")

    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
        "User-Agent": "SillyTavern-Proxy/1.0",
        "Accept": "text/event-stream" if req.stream else "application/json",
    }
    # Only fields the client actually sent are considered, and only this
    # whitelist is forwarded.
    # NOTE(review): presence_penalty / frequency_penalty are accepted by the
    # request model but are not forwarded - confirm this is intentional.
    forwarded_fields = {"messages", "model", "max_tokens", "temperature", "top_p", "stream"}
    payload = req.dict(exclude_unset=True)
    filtered_payload = {k: v for k, v in payload.items() if k in forwarded_fields}
    logger.info(f"转发到xAI的负载: {filtered_payload}")

    try:
        # requests is blocking; run it in a worker thread so the event loop
        # keeps serving other requests while waiting on xAI.
        response = await asyncio.to_thread(
            requests.post,
            f"{XAI_API_BASE}/chat/completions",
            headers=headers,
            json=filtered_payload,
            stream=req.stream,
            timeout=20,
        )
        response.raise_for_status()
    except RequestException as e:
        status_code, error_detail = _describe_upstream_error(e)
        logger.error(f"请求失败: {error_detail}")
        raise HTTPException(
            status_code=status_code,
            detail=f"xAI API错误: {error_detail}",
        ) from e

    if req.stream:
        return StreamingResponse(
            stream_generator(response, req.stream),
            media_type="text/event-stream",
        )

    try:
        return response.json()
    except ValueError as e:
        logger.error(f"无效的JSON响应: {response.text}")
        raise HTTPException(status_code=502, detail="上游服务器返回无效响应") from e
# Keep the original model-list endpoint
@app.get("/v1/models")
async def get_models(request: Request):
    """Return the upstream model list, or a built-in fallback list on failure.

    Args:
        request: Incoming HTTP request; its ``Authorization`` header supplies
            the API key that is forwarded upstream.

    Returns:
        The decoded JSON body of the upstream ``/models`` endpoint. If the
        upstream request fails or its body cannot be decoded, a hard-coded
        list of four model entries is returned instead (best effort).

    Raises:
        HTTPException: 401 when no API key is supplied.
    """
    # Accept "Bearer <key>" with any scheme casing, or a bare key. Only a
    # leading scheme is stripped, never a "Bearer " substring inside the key.
    auth_header = request.headers.get("Authorization", "").strip()
    scheme, _, credentials = auth_header.partition(" ")
    api_key = credentials.strip() if scheme.lower() == "bearer" else auth_header
    if not api_key:
        raise HTTPException(status_code=401, detail="缺少API密钥")

    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    try:
        # requests is blocking; run it in a worker thread so the event loop
        # is not stalled for the duration of the upstream call.
        response = await asyncio.to_thread(
            requests.get,
            f"{XAI_API_BASE}/models",
            headers=headers,
            timeout=10,
        )
        response.raise_for_status()
        return response.json()
    except (RequestException, ValueError) as e:
        # ValueError covers JSON decode errors that are not RequestException
        # subclasses, so an undecodable body also falls back instead of
        # surfacing as a 500.
        logger.warning(f"获取模型失败: {str(e)},返回备用数据")
        return {
            "object": "list",
            "data": [
                {"id": "grok-3-beta", "object": "model", "created": 1744681729, "owned_by": "xAI"},
                {"id": "grok-3-mini-beta", "object": "model", "created": 1744681729, "owned_by": "xAI"},
                {"id": "grok-3-fast-beta", "object": "model", "created": 1744681729, "owned_by": "xAI"},
                {"id": "grok-3-mini-fast-beta", "object": "model", "created": 1744681729, "owned_by": "xAI"},
            ],
        }
@app.get("/health")
async def health_check():
    """Liveness probe: unconditionally report the service as healthy."""
    health_report = {"status": "healthy"}
    return health_report
if __name__ == "__main__":
    # Run directly as a script: serve the app with uvicorn on all
    # interfaces, port 7860.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)