File size: 12,801 Bytes
e536852
ea61576
e536852
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fd88cfe
e536852
 
 
 
 
 
 
 
 
 
 
 
 
 
fd88cfe
e536852
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fd88cfe
e536852
 
 
 
fd88cfe
e536852
 
fd88cfe
e536852
fd88cfe
e536852
 
 
 
 
 
 
fd88cfe
 
 
e536852
 
fd88cfe
 
e536852
fd88cfe
 
ea61576
e536852
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ea61576
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e536852
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ea61576
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
# ===================================================================
# main.py (final fixed version: corrected function definition order)
# ===================================================================

import json
import os
import time
import uuid
import threading
from typing import Any, Dict, List, Optional, TypedDict, Union

import requests
from fastapi import FastAPI, HTTPException, Depends, Query
from fastapi.responses import StreamingResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel, Field

# --- Type definitions and global variables ---
class CodeGeeXToken(TypedDict):
    """Mutable state record for one upstream CodeGeeX API token."""
    token: str        # raw token value sent upstream in the "code-token" header
    is_valid: bool    # set False when the upstream answers 401/403 for this token
    last_used: float  # epoch seconds of last selection (0 = never used)
    error_count: int  # recent upstream failures; reset once ERROR_COOLDOWN has passed

VALID_CLIENT_KEYS: set = set()  # client API keys accepted by authenticate_client
CODEGEEX_TOKENS: List[CodeGeeXToken] = []  # rotating pool of upstream tokens
CODEGEEX_MODELS: List[str] = ["claude-3-7-sonnet", "claude-sonnet-4"]  # models exposed via /v1/models
token_rotation_lock = threading.Lock()  # guards all mutation of CODEGEEX_TOKENS entries
MAX_ERROR_COUNT = 3  # errors before a token is benched until the cooldown elapses
ERROR_COOLDOWN = 300  # seconds a benched token must rest before its error count resets
DEBUG_MODE = os.environ.get("DEBUG_MODE", "false").lower() == "true"  # enables log_debug output

# --- Pydantic models ---
class ChatMessage(BaseModel):
    """One chat turn in the OpenAI wire format."""
    role: str  # e.g. "user"/"assistant"; only those two roles are paired into history
    content: Union[str, List[Dict[str, Any]]]  # plain text, or structured content parts
    reasoning_content: Optional[str] = None  # never read in this file — presumably for reasoning models; confirm
class ChatCompletionRequest(BaseModel):
    """Request body for POST /v1/chat/completions."""
    model: str
    messages: List[ChatMessage]
    stream: bool = True  # streaming is the default
    # The sampling knobs below are accepted for OpenAI compatibility but are
    # not forwarded: the CodeGeeX payload built in chat_completions omits them.
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    top_p: Optional[float] = None
class ModelInfo(BaseModel):
    """One entry of the OpenAI-style model listing."""
    id: str
    object: str = "model"
    created: int
    owned_by: str
class ModelList(BaseModel):
    """OpenAI-style response envelope for the model listing endpoints."""
    object: str = "list"
    data: List[ModelInfo]
class ChatCompletionChoice(BaseModel):
    """One (non-streaming) completion choice."""
    message: ChatMessage
    index: int = 0
    finish_reason: str = "stop"
class ChatCompletionResponse(BaseModel):
    """Non-streaming response body mirroring OpenAI's chat.completion object."""
    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex}")
    object: str = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionChoice]
    # Token accounting is never computed by this adapter, so usage stays all-zero.
    usage: Dict[str, int] = Field(default_factory=lambda: {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0})
class StreamChoice(BaseModel):
    """One choice inside a streaming chunk; delta carries role/content fragments."""
    delta: Dict[str, Any] = Field(default_factory=dict)
    index: int = 0
    finish_reason: Optional[str] = None
class StreamResponse(BaseModel):
    """Streaming chunk body mirroring OpenAI's chat.completion.chunk object."""
    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex}")
    object: str = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[StreamChoice]

# --- FastAPI App ---
app = FastAPI(title="CodeGeeX OpenAI API Adapter")
# auto_error=False so a missing header reaches our own 401 in authenticate_client
# instead of FastAPI's default 403.
security = HTTPBearer(auto_error=False)

def log_debug(message: str):
    """Print *message* with a [DEBUG] prefix — but only when DEBUG_MODE is on."""
    if not DEBUG_MODE:
        return
    print(f"[DEBUG] {message}")

# --- Configuration loaders (read from environment secrets) ---
def load_client_api_keys_from_secrets():
    """Load the set of accepted client API keys from the CLIENT_API_KEYS secret.

    The secret must be a JSON list of strings. On any failure the key set is
    cleared, which makes authenticate_client reject every request with 503.
    """
    global VALID_CLIENT_KEYS
    try:
        keys_str = os.environ.get("CLIENT_API_KEYS")
        if not keys_str:
            raise ValueError("Secret 'CLIENT_API_KEYS' not found.")
        keys = json.loads(keys_str)
        # Consistency fix: mirror load_codegeex_tokens_from_secrets and reject a
        # non-list payload loudly, instead of silently "succeeding" with 0 keys.
        if not isinstance(keys, list):
            raise TypeError("Secret 'CLIENT_API_KEYS' must be a JSON list.")
        VALID_CLIENT_KEYS = set(keys)
        print(f"Successfully loaded {len(VALID_CLIENT_KEYS)} client API keys.")
    except Exception as e:
        print(f"FATAL: Error loading client API keys: {e}")
        VALID_CLIENT_KEYS = set()

def load_codegeex_tokens_from_secrets():
    """Rebuild the CODEGEEX_TOKENS pool from the CODEGEEX_TOKENS env secret.

    The secret must be a JSON list of strings; empty and non-string entries
    are skipped. On failure the pool is left empty and the error is printed.
    """
    global CODEGEEX_TOKENS
    CODEGEEX_TOKENS = []
    try:
        raw = os.environ.get("CODEGEEX_TOKENS")
        if not raw:
            raise ValueError("Secret 'CODEGEEX_TOKENS' not found.")
        parsed = json.loads(raw)
        if not isinstance(parsed, list):
            raise TypeError("Secret 'CODEGEEX_TOKENS' must be a JSON list.")
        CODEGEEX_TOKENS = [
            {"token": item, "is_valid": True, "last_used": 0, "error_count": 0}
            for item in parsed
            if isinstance(item, str) and item
        ]
        print(f"Successfully loaded {len(CODEGEEX_TOKENS)} CodeGeeX tokens.")
    except Exception as e:
        print(f"FATAL: Error loading CodeGeeX tokens: {e}")

# --- Core utilities and authentication (defined before first use) ---
def get_best_codegeex_token() -> Optional[CodeGeeXToken]:
    """Pick the least-recently-used usable token and stamp it as used now.

    A token is usable when it is still valid and either has fewer than
    MAX_ERROR_COUNT errors or has been idle longer than ERROR_COOLDOWN
    (in which case its error counter is reset before selection). Returns
    None when no token qualifies. All state changes happen under the lock.
    """
    with token_rotation_lock:
        now = time.time()

        def usable(entry):
            cooled_down = now - entry["last_used"] > ERROR_COOLDOWN
            return entry["is_valid"] and (entry["error_count"] < MAX_ERROR_COUNT or cooled_down)

        candidates = [entry for entry in CODEGEEX_TOKENS if usable(entry)]
        if not candidates:
            return None
        # Tokens that survived only via the cooldown get a clean slate.
        for entry in candidates:
            if entry["error_count"] >= MAX_ERROR_COUNT and now - entry["last_used"] > ERROR_COOLDOWN:
                entry["error_count"] = 0
        chosen = min(candidates, key=lambda entry: (entry["last_used"], entry["error_count"]))
        chosen["last_used"] = now
        return chosen

def _convert_messages_to_codegeex_format(messages: List[ChatMessage]):
    """Convert an OpenAI message list into a CodeGeeX (prompt, history) pair.

    The newest user message becomes the prompt; earlier user/assistant turns
    are paired into history entries of {"query", "answer", "id"}. A trailing
    user turn without an assistant reply is folded into the prompt. Non-string
    content (structured parts) is treated as the empty string.

    Raises:
        HTTPException(400): if the list contains no user message.
    """
    if not messages:
        return "", []
    last_user_msg = next((msg for msg in reversed(messages) if msg.role == "user"), None)
    if last_user_msg is None:
        raise HTTPException(status_code=400, detail="No user message found.")
    prompt = last_user_msg.content if isinstance(last_user_msg.content, str) else ""
    history, user_content, assistant_content = [], "", ""
    for msg in messages:
        # BUG FIX: identity check, not ==. Pydantic models compare by field
        # values, so `msg == last_user_msg` would stop pairing at the FIRST
        # message equal in content to the final user message, silently
        # truncating the history when the user repeats themselves.
        if msg is last_user_msg:
            break
        if msg.role == "user":
            if user_content and assistant_content:
                history.append({"query": user_content, "answer": assistant_content, "id": f"{uuid.uuid4()}"})
                user_content, assistant_content = "", ""
            user_content = msg.content if isinstance(msg.content, str) else ""
        elif msg.role == "assistant":
            assistant_content = msg.content if isinstance(msg.content, str) else ""
            if user_content:
                history.append({"query": user_content, "answer": assistant_content, "id": f"{uuid.uuid4()}"})
                user_content, assistant_content = "", ""
    # An unanswered user turn is prepended to the prompt rather than dropped.
    if user_content and not assistant_content:
        prompt = user_content + "\n" + prompt
    return prompt, history

async def authenticate_client(auth: Optional[HTTPAuthorizationCredentials] = Depends(security)):
    """Validate the request's Bearer token against the configured client keys.

    Raises 503 when no keys are configured at all, 401 when no credentials
    were supplied, and 403 when the supplied key is not recognized.
    """
    if not VALID_CLIENT_KEYS:
        raise HTTPException(status_code=503, detail="Service unavailable: Client API keys not configured.")
    if auth is None or not auth.credentials:
        raise HTTPException(status_code=401, detail="API key required.", headers={"WWW-Authenticate": "Bearer"})
    if auth.credentials not in VALID_CLIENT_KEYS:
        raise HTTPException(status_code=403, detail="Invalid client API key.")

# --- FastAPI events and routes ---
@app.on_event("startup")
async def startup():
    """Load client keys and upstream tokens from environment secrets at boot."""
    print("Starting CodeGeeX OpenAI API Adapter server...")
    load_client_api_keys_from_secrets()
    load_codegeex_tokens_from_secrets()
    print("Server initialization completed.")

@app.get("/")
def health_check():
    """Unauthenticated liveness probe."""
    status_payload = {"status": "ok", "message": "CodeGeeX API Adapter is running."}
    return status_payload

def get_models_list_response() -> ModelList:
    """Build the OpenAI-style listing for every model in CODEGEEX_MODELS."""
    entries = [
        ModelInfo(id=name, created=int(time.time()), owned_by="anthropic")
        for name in CODEGEEX_MODELS
    ]
    return ModelList(data=entries)

@app.get("/v1/models", response_model=ModelList)
async def list_v1_models(_: None = Depends(authenticate_client)):
    """Authenticated OpenAI-compatible model listing."""
    return get_models_list_response()

@app.get("/models", response_model=ModelList)
async def list_models_no_auth():
    """Model listing without authentication.

    NOTE(review): this route deliberately(?) skips authenticate_client and
    exposes only model ids — confirm the unauthenticated access is intended.
    """
    return get_models_list_response()

def _codegeex_stream_generator(response, model: str):
    """Translate the CodeGeeX SSE response into OpenAI chat.completion.chunk lines.

    Emits an initial role delta, one content delta per upstream "add" event,
    then a stop chunk followed by the "[DONE]" sentinel. Upstream events are
    blank-line-separated blocks of "event:" / "data:" lines; partial blocks
    are buffered across chunks. If iteration raises, the error is surfaced as
    a JSON line and the stream is still closed with stop + [DONE].
    """
    stream_id = f"chatcmpl-{uuid.uuid4().hex}"
    created_time = int(time.time())
    # Announce the assistant role first, per the OpenAI streaming convention.
    yield f"data: {StreamResponse(id=stream_id, created=created_time, model=model, choices=[StreamChoice(delta={'role': 'assistant'})]).json()}\n\n"
    buffer = ""
    try:
        for chunk in response.iter_content(chunk_size=1024):
            if not chunk:
                continue
            buffer += chunk.decode("utf-8", errors='ignore')
            # Consume every complete (blank-line-terminated) event in the buffer.
            while "\n\n" in buffer:
                event_data, buffer = buffer.split("\n\n", 1)
                event_data = event_data.strip()
                if not event_data:
                    continue
                event_type, data_json = None, None
                for line in event_data.split("\n"):
                    if line.startswith("event:"):
                        event_type = line[6:].strip()
                    elif line.startswith("data:"):
                        # Narrowed from a bare `except:`: only malformed JSON is
                        # skipped; other exceptions must not be swallowed here.
                        try:
                            data_json = json.loads(line[5:].strip())
                        except json.JSONDecodeError:
                            continue
                if not event_type or not data_json:
                    continue
                if event_type == "add":
                    delta = data_json.get("text", "")
                    if delta:
                        yield f"data: {StreamResponse(id=stream_id, created=created_time, model=model, choices=[StreamChoice(delta={'content': delta})]).json()}\n\n"
                elif event_type == "finish":
                    yield f"data: {StreamResponse(id=stream_id, created=created_time, model=model, choices=[StreamChoice(delta={}, finish_reason='stop')]).json()}\n\n"
                    yield "data: [DONE]\n\n"
                    return
    except Exception as e:
        log_debug(f"Stream processing error: {e}")
        yield f"data: {json.dumps({'error': str(e)})}\n\n"
    # Upstream ended without a "finish" event: close the stream gracefully.
    yield f"data: {StreamResponse(id=stream_id, created=created_time, model=model, choices=[StreamChoice(delta={}, finish_reason='stop')]).json()}\n\n"
    yield "data: [DONE]\n\n"

def _build_codegeex_non_stream_response(response, model: str) -> ChatCompletionResponse:
    """Drain the CodeGeeX SSE response into a single ChatCompletionResponse.

    Concatenates the text of every "add" event. A "finish" event that carries
    its own text replaces the accumulated content; either way it terminates
    the read. If the stream ends without "finish", whatever accumulated is
    returned.
    """
    full_content = ""
    buffer = ""
    for chunk in response.iter_content(chunk_size=1024):
        if not chunk:
            continue
        buffer += chunk.decode("utf-8", errors='ignore')
        while "\n\n" in buffer:
            event_data, buffer = buffer.split("\n\n", 1)
            event_data = event_data.strip()
            if not event_data:
                continue
            event_type, data_json = None, None
            for line in event_data.split("\n"):
                if line.startswith("event:"):
                    event_type = line[6:].strip()
                elif line.startswith("data:"):
                    # Narrowed from a bare `except:` to the specific parse error.
                    try:
                        data_json = json.loads(line[5:].strip())
                    except json.JSONDecodeError:
                        continue
            # BUG FIX: process each event exactly once, AFTER all of its lines
            # are parsed. The original ran this inside the line loop (unlike
            # _codegeex_stream_generator), so an "add" event with any line
            # following its "data:" line appended the same text repeatedly.
            if not event_type or not data_json:
                continue
            if event_type == "add":
                full_content += data_json.get("text", "")
            elif event_type == "finish":
                finish_text = data_json.get("text", "")
                if finish_text:
                    full_content = finish_text
                return ChatCompletionResponse(model=model, choices=[ChatCompletionChoice(message=ChatMessage(role="assistant", content=full_content))])
    return ChatCompletionResponse(model=model, choices=[ChatCompletionChoice(message=ChatMessage(role="assistant", content=full_content))])

@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest, _: None = Depends(authenticate_client)):
    """OpenAI-compatible chat endpoint backed by the CodeGeeX SSE API.

    Validates model and messages, converts them into a CodeGeeX
    (prompt, history) pair, then tries tokens in rotation until an upstream
    call succeeds. Streaming requests are proxied as SSE; otherwise the
    upstream stream is drained into one ChatCompletionResponse.
    """
    if request.model not in CODEGEEX_MODELS: raise HTTPException(status_code=404, detail=f"Model '{request.model}' not found.")
    if not request.messages: raise HTTPException(status_code=400, detail="No messages provided.")
    try: prompt, history = _convert_messages_to_codegeex_format(request.messages)
    except Exception as e: raise HTTPException(status_code=400, detail=f"Failed to process messages: {e}")
    # One attempt per configured token; the extra final iteration reports failure.
    for attempt in range(len(CODEGEEX_TOKENS) + 1):
        if attempt == len(CODEGEEX_TOKENS): raise HTTPException(status_code=503, detail="All attempts to contact CodeGeeX API failed.")
        token = get_best_codegeex_token()
        if not token: raise HTTPException(status_code=503, detail="No valid CodeGeeX tokens available.")
        try:
            # Request sampling params (temperature/top_p/max_tokens) are not
            # forwarded — the upstream payload has no fields for them here.
            payload = {"user_role": 0, "ide": "VSCode", "prompt": prompt, "model": request.model, "history": history, "talkId": f"{uuid.uuid4()}", "plugin_version": "", "locale": "", "agent": None, "candidates": {"candidate_msg_id": "", "candidate_type": "", "selected_candidate": ""}, "ide_version": "", "machineId": ""}
            headers = {"User-Agent": "Mozilla/5.0", "Accept": "text/event-stream", "Content-Type": "application/json", "code-token": token["token"]}
            response = requests.post("https://codegeex.cn/prod/code/chatCodeSseV3/chat", data=json.dumps(payload), headers=headers, stream=True, timeout=300.0)
            response.raise_for_status()
            if request.stream: return StreamingResponse(_codegeex_stream_generator(response, request.model), media_type="text/event-stream", headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"})
            else: return _build_codegeex_non_stream_response(response, request.model)
        except requests.HTTPError as e:
            status_code = getattr(e.response, "status_code", 500)
            with token_rotation_lock:
                # Auth failures permanently disable the token; transient
                # upstream errors only bump its error counter.
                if status_code in [401, 403]: token["is_valid"] = False
                elif status_code in [429, 500, 502, 503, 504]: token["error_count"] += 1
        except Exception as e:
            # Network error / timeout / etc.: count it against the token and retry.
            with token_rotation_lock: token["error_count"] += 1