File size: 5,023 Bytes
8d674ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
from fastapi import FastAPI, HTTPException, Request, status
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
import requests
from requests.exceptions import RequestException
import os
import logging
import json

# Configure basic logging for the proxy.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()
# Upstream xAI API root; every proxied request is sent below this base URL.
XAI_API_BASE = "https://api.x.ai/v1"

class ChatRequest(BaseModel):
    """OpenAI-style payload accepted by the /v1/chat/completions proxy endpoint.

    Optional fields default to ``None``/``False`` and are omitted from the
    forwarded payload when unset (the handler uses ``exclude_unset=True``).
    """
    # Chat history; presumably role/content dicts — TODO confirm against callers.
    messages: list[dict] = Field(..., description="消息列表")
    # Upstream model identifier, e.g. "grok-3-beta".
    model: str = Field(..., description="模型ID")
    max_tokens: int | None = None
    temperature: float | None = None
    top_p: float | None = None
    # When True the handler streams the upstream response back as SSE.
    stream: bool = False
    presence_penalty: float | None = None
    frequency_penalty: float | None = None

async def stream_generator(response, stream):
    """Re-emit an upstream SSE response as an OpenAI-style event stream.

    Args:
        response: a ``requests`` response object obtained with ``stream=True``.
        stream: kept for interface compatibility; the response is now always
            closed on exit regardless of this flag.

    Yields:
        SSE-framed strings (``data: <payload>`` followed by a blank line).
        Non-empty upstream lines lacking the ``data: `` prefix are surfaced
        as an error event rather than silently dropped.
    """
    try:
        # NOTE(review): iter_lines() is a blocking (synchronous) iterator;
        # inside an async generator it stalls the event loop between chunks.
        # Consider an async HTTP client for true async streaming.
        for chunk in response.iter_lines():
            if not chunk:
                # Blank keep-alive lines between SSE events carry no payload.
                continue
            decoded_chunk = chunk.decode('utf-8')
            if decoded_chunk.startswith("data: "):
                # Strip the upstream prefix and re-frame as our own SSE event.
                yield f"data: {decoded_chunk[6:]}\n\n"
            else:
                yield f"data: {json.dumps({'error': 'Invalid chunk format'})}\n\n"
    except RequestException as e:
        # Surface mid-stream transport failures to the client as an SSE error.
        yield f"data: {json.dumps({'error': str(e)})}\n\n"
    finally:
        # Bug fix: the original closed the response only when `stream` was
        # truthy, leaking the connection otherwise. Always close.
        response.close()

@app.post("/v1/chat/completions")
async def chat_completions(req: ChatRequest, request: Request):
    logger.info(f"收到请求: {req.dict()}")
    api_key = request.headers.get("Authorization", "").replace("Bearer ", "")
    
    if not api_key:
        logger.error("缺少API密钥")
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="缺少API密钥")

    try:
        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
            "User-Agent": "SillyTavern-Proxy/1.0",
            "Accept": "text/event-stream" if req.stream else "application/json"
        }

        payload = req.dict(exclude_unset=True)
        filtered_payload = {k: v for k, v in payload.items() if k in ["messages", "model", "max_tokens", "temperature", "top_p", "stream"]}

        logger.info(f"转发到xAI的负载: {filtered_payload}")
        
        response = requests.post(
            f"{XAI_API_BASE}/chat/completions",
            headers=headers,
            json=filtered_payload,
            stream=req.stream,
            timeout=20
        )
        response.raise_for_status()

        if req.stream:
            return StreamingResponse(
                stream_generator(response, req.stream),
                media_type="text/event-stream"
            )
        else:
            try:
                return response.json()
            except json.JSONDecodeError:
                logger.error(f"无效的JSON响应: {response.text}")
                raise HTTPException(status_code=502, detail="上游服务器返回无效响应")

    except RequestException as e:
        error_detail = ""
        if e.response is not None:
            try:
                error_detail = e.response.json().get("error", e.response.text)
            except json.JSONDecodeError:
                error_detail = e.response.text[:500]
            status_code = e.response.status_code
        else:
            error_detail = str(e)
            status_code = 504

        logger.error(f"请求失败: {error_detail}")
        raise HTTPException(
            status_code=status_code,
            detail=f"xAI API错误: {error_detail}"
        )

# Model-list endpoint, kept from the original proxy.
@app.get("/v1/models")
async def get_models(request: Request):
    """List available models by proxying xAI, with a static fallback list.

    Any upstream transport/HTTP/JSON failure is logged and answered with a
    hard-coded model list so clients can still populate their model picker.
    """
    api_key = request.headers.get("Authorization", "").replace("Bearer ", "")
    if not api_key:
        raise HTTPException(status_code=401, detail="缺少API密钥")

    auth_headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }
    try:
        upstream = requests.get(f"{XAI_API_BASE}/models", headers=auth_headers, timeout=10)
        upstream.raise_for_status()
        # json() stays inside the try: an invalid body also triggers the fallback.
        return upstream.json()
    except RequestException as e:
        logger.warning(f"获取模型失败: {str(e)},返回备用数据")
        fallback_ids = [
            "grok-3-beta",
            "grok-3-mini-beta",
            "grok-3-fast-beta",
            "grok-3-mini-fast-beta",
        ]
        return {
            "object": "list",
            "data": [
                {"id": model_id, "object": "model", "created": 1744681729, "owned_by": "xAI"}
                for model_id in fallback_ids
            ],
        }

@app.get("/health")
async def health_check():
    return {"status": "healthy"}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)