File size: 13,362 Bytes
e536852
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
# ===================================================================
# main.py (modified to read its configuration from Hugging Face Secrets)
# ===================================================================

import json
import os
import time
import uuid
import threading
from typing import Any, Dict, List, Optional, TypedDict, Union

import requests
from fastapi import FastAPI, HTTPException, Depends, Query
from fastapi.responses import StreamingResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel, Field

# --- Type definitions and global state (unchanged from the original file) ---
class CodeGeeXToken(TypedDict):
    """Bookkeeping record for one upstream CodeGeeX token kept in CODEGEEX_TOKENS."""
    token: str  # raw token value, sent upstream in the "code-token" request header
    is_valid: bool  # set to False once the upstream answers 401/403 for this token
    last_used: float  # time.time() of the most recent selection; 0 until first selected
    error_count: int  # failures recorded for this token; compared against MAX_ERROR_COUNT

# Bearer keys accepted from API clients; replaced by load_client_api_keys_from_secrets().
VALID_CLIENT_KEYS: set = set()
# Upstream token records; rebuilt by load_codegeex_tokens_from_secrets().
CODEGEEX_TOKENS: List[CodeGeeXToken] = []
# Model ids returned by the model-listing routes and accepted by /v1/chat/completions.
CODEGEEX_MODELS: List[str] = ["claude-3-7-sonnet", "claude-sonnet-4"]
# Held while token records are selected or their counters/flags are modified.
token_rotation_lock = threading.Lock()
# A token whose error_count reaches this value is skipped by get_best_codegeex_token() ...
MAX_ERROR_COUNT = 3
# ... until more than this many seconds have elapsed since its last_used timestamp.
ERROR_COOLDOWN = 300
# log_debug() prints only when the DEBUG_MODE environment variable is "true" (case-insensitive).
DEBUG_MODE = os.environ.get("DEBUG_MODE", "false").lower() == "true"

# --- Pydantic models (unchanged from the original file) ---
class ChatMessage(BaseModel):
    # One chat message; used both in incoming requests and in non-streaming responses.
    role: str  # compared against "user" and "assistant" when the upstream prompt is built
    # Either plain text or a list of dicts. Only plain strings are used when the
    # conversation is converted for CodeGeeX; any other value is treated as empty text.
    content: Union[str, List[Dict[str, Any]]]
    reasoning_content: Optional[str] = None  # optional; never set by the code visible in this file

class ChatCompletionRequest(BaseModel):
    # Request body of POST /v1/chat/completions.
    model: str  # must be listed in CODEGEEX_MODELS, otherwise the route answers 404
    messages: List[ChatMessage]  # an empty list is rejected by the route with 400
    stream: bool = True  # NOTE: streaming is the default when the client omits this field
    # The three sampling fields below are accepted, but the request handler in this
    # file does not include them in the payload it sends upstream.
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    top_p: Optional[float] = None

class ModelInfo(BaseModel):
    # One entry of the model listing.
    id: str  # model identifier, taken from CODEGEEX_MODELS
    object: str = "model"
    created: int  # Unix timestamp in seconds; filled with int(time.time()) by get_models_list_response()
    owned_by: str

class ModelList(BaseModel):
    # Response body of the /v1/models and /models routes.
    object: str = "list"
    data: List[ModelInfo]
# ... (the remaining Pydantic models are unchanged from the original file)
class ChatCompletionChoice(BaseModel):
    # Single choice of a non-streaming completion.
    message: ChatMessage  # the assistant message carrying the generated text
    index: int = 0
    finish_reason: str = "stop"
class ChatCompletionResponse(BaseModel):
    # Body of a non-streaming chat completion response.
    # id defaults to a fresh "chatcmpl-<uuid4 hex>" value per instance.
    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex}")
    object: str = "chat.completion"
    # created defaults to the current Unix time in seconds.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionChoice]
    # usage defaults to all-zero counters; the code visible in this file never passes real counts.
    usage: Dict[str, int] = Field(default_factory=lambda: {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0})
class StreamChoice(BaseModel):
    # Single choice of one streaming chunk.
    # delta holds {"role": ...} in the first chunk, {"content": ...} in text chunks,
    # and stays empty in the closing chunk.
    delta: Dict[str, Any] = Field(default_factory=dict)
    index: int = 0
    finish_reason: Optional[str] = None  # None while streaming; "stop" in the closing chunk
class StreamResponse(BaseModel):
    # One streamed chunk; serialized into a "data: ..." SSE line by the stream generator.
    # id defaults to a fresh "chatcmpl-<uuid4 hex>" value; the stream generator passes
    # one shared id and created value for every chunk of a stream.
    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex}")
    object: str = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[StreamChoice]

# --- FastAPI App ---
# Application object; the routes and the startup hook in this file register on it.
app = FastAPI(title="CodeGeeX OpenAI API Adapter")
# auto_error=False lets a request without bearer credentials reach
# authenticate_client() with auth=None, so that function picks the status code.
security = HTTPBearer(auto_error=False)

def log_debug(message: str):
    """Print ``message`` with a ``[DEBUG]`` prefix, but only when DEBUG_MODE is on.

    Args:
        message: Text to print.
    """
    if not DEBUG_MODE:
        return
    print(f"[DEBUG] {message}")

# --- Configuration loaders (modified to read from Secrets) ---
def load_client_api_keys_from_secrets():
    """Load the accepted client API keys from the ``CLIENT_API_KEYS`` environment variable.

    The variable must hold a JSON array of strings, for example
    ``["sk-one", "sk-two"]``. Entries that are not non-empty strings are
    skipped. The result replaces the module-level ``VALID_CLIENT_KEYS`` set.

    Any problem (variable missing or empty, invalid JSON, JSON that is not an
    array) is reported on stdout and leaves ``VALID_CLIENT_KEYS`` empty; no
    exception propagates to the caller.
    """
    global VALID_CLIENT_KEYS
    try:
        keys_str = os.environ.get("CLIENT_API_KEYS")
        if not keys_str:
            raise ValueError("Secret 'CLIENT_API_KEYS' not found.")
        keys = json.loads(keys_str)
        # Reject non-list JSON explicitly (same rule as the CodeGeeX token
        # loader) instead of silently reporting "0 keys loaded" as a success.
        if not isinstance(keys, list):
            raise TypeError("Secret 'CLIENT_API_KEYS' must be a JSON list of strings.")
        # Only non-empty strings can ever match a bearer credential; filtering
        # also keeps one unhashable entry from discarding every valid key.
        VALID_CLIENT_KEYS = {key for key in keys if isinstance(key, str) and key}
        print(f"Successfully loaded {len(VALID_CLIENT_KEYS)} client API keys from secrets.")
    except Exception as e:
        # Deliberate best-effort boundary: startup continues, and requests are
        # answered with 503 while no keys are configured.
        print(f"FATAL: Error loading client API keys from secrets: {e}")
        VALID_CLIENT_KEYS = set()

def load_codegeex_tokens_from_secrets():
    """Rebuild ``CODEGEEX_TOKENS`` from the ``CODEGEEX_TOKENS`` environment variable.

    The variable is expected to hold a JSON array of token strings. Every
    non-empty string becomes a fresh record (valid, never used, zero errors);
    other entries are skipped. On any failure the problem is printed and the
    token list stays empty; nothing is raised.
    """
    global CODEGEEX_TOKENS
    CODEGEEX_TOKENS = []
    try:
        raw_value = os.environ.get("CODEGEEX_TOKENS")
        if not raw_value:
            raise ValueError("Secret 'CODEGEEX_TOKENS' not found.")

        # The secret is assumed to be a JSON array.
        parsed = json.loads(raw_value)
        if not isinstance(parsed, list):
            raise TypeError("Secret 'CODEGEEX_TOKENS' must be a JSON list of strings.")

        CODEGEEX_TOKENS.extend(
            {"token": value, "is_valid": True, "last_used": 0, "error_count": 0}
            for value in parsed
            if isinstance(value, str) and value
        )
        print(f"Successfully loaded {len(CODEGEEX_TOKENS)} CodeGeeX tokens from secrets.")
    except Exception as load_error:
        print(f"FATAL: Error loading CodeGeeX tokens from secrets: {load_error}")

# --- Core logic (kept identical to the original file) ---
def get_best_codegeex_token() -> Optional[CodeGeeXToken]:
    """Select the least recently used usable token and stamp it as used now.

    A token is usable when it is still flagged valid and either has fewer than
    MAX_ERROR_COUNT recorded errors or was last used more than ERROR_COOLDOWN
    seconds ago; in the second case its error counter is reset to zero. Among
    usable tokens the smallest ``(last_used, error_count)`` pair wins. The
    whole selection runs while holding ``token_rotation_lock``.

    Returns:
        The chosen record (the same dict object stored in CODEGEEX_TOKENS,
        with ``last_used`` updated), or None when no token is usable.
    """
    with token_rotation_lock:
        current_time = time.time()
        candidates = []
        for entry in CODEGEEX_TOKENS:
            if not entry["is_valid"]:
                continue
            if entry["error_count"] >= MAX_ERROR_COUNT:
                cooled_down = current_time - entry["last_used"] > ERROR_COOLDOWN
                if not cooled_down:
                    continue
                # Cooldown elapsed: give the token a clean slate.
                entry["error_count"] = 0
            candidates.append(entry)

        if not candidates:
            return None

        # min() returns the first minimal element, matching a stable sort.
        chosen = min(candidates, key=lambda entry: (entry["last_used"], entry["error_count"]))
        chosen["last_used"] = current_time
        return chosen

def _convert_messages_to_codegeex_format(messages: List[ChatMessage]):
    if not messages: return "", []
    last_user_msg = next((msg for msg in reversed(messages) if msg.role == "user"), None)
    if not last_user_msg: raise HTTPException(status_code=400, detail="No user message found.")
    prompt = last_user_msg.content if isinstance(last_user_msg.content, str) else ""
    history, user_content, assistant_content = [], "", ""
    for msg in messages:
        if msg == last_user_msg: break
        if msg.role == "user":
            if user_content and assistant_content: history.append({"query": user_content, "answer": assistant_content, "id": f"{uuid.uuid4()}"}); user_content, assistant_content = "", ""
            user_content = msg.content if isinstance(msg.content, str) else ""
        elif msg.role == "assistant":
            assistant_content = msg.content if isinstance(msg.content, str) else ""
            if user_content: history.append({"query": user_content, "answer": assistant_content, "id": f"{uuid.uuid4()}"}); user_content, assistant_content = "", ""
    if user_content and not assistant_content: prompt = user_content + "\n" + prompt
    return prompt, history

async def authenticate_client(auth: Optional[HTTPAuthorizationCredentials] = Depends(security)):
    """Reject the request unless it carries one of the configured client API keys.

    Args:
        auth: Bearer credentials resolved by the ``security`` scheme, or None
            when the request carried none.

    Raises:
        HTTPException: 503 when no client keys are configured, 401 when the
            request has no bearer credentials, 403 when the presented key is
            not in VALID_CLIENT_KEYS.
    """
    if not VALID_CLIENT_KEYS:
        raise HTTPException(
            status_code=503,
            detail="Service unavailable: Client API keys not configured.",
        )

    presented_key = auth.credentials if auth else None
    if not presented_key:
        raise HTTPException(
            status_code=401,
            detail="API key required.",
            headers={"WWW-Authenticate": "Bearer"},
        )

    if presented_key not in VALID_CLIENT_KEYS:
        raise HTTPException(status_code=403, detail="Invalid client API key.")

@app.on_event("startup")
async def startup():
    """Load the client API keys and the CodeGeeX tokens when the app starts."""
    print("Starting CodeGeeX OpenAI API Adapter server...")
    for load_secret in (load_client_api_keys_from_secrets, load_codegeex_tokens_from_secrets):
        load_secret()
    print("Server initialization completed.")

def get_models_list_response() -> ModelList:
    """Build the model listing shared by the /v1/models and /models routes.

    Returns:
        A ModelList with one ModelInfo per id in CODEGEEX_MODELS, each stamped
        with the current Unix time and owned by "anthropic".
    """
    entries = []
    for model_id in CODEGEEX_MODELS:
        entries.append(ModelInfo(id=model_id, created=int(time.time()), owned_by="anthropic"))
    return ModelList(data=entries)

@app.get("/v1/models", response_model=ModelList)
async def list_v1_models(_: None = Depends(authenticate_client)):
    # Authenticated listing: authenticate_client must pass before this body runs.
    # The dependency's result is unused, hence the throwaway `_` parameter.
    return get_models_list_response()

@app.get("/models", response_model=ModelList)
async def list_models_no_auth():
    # Same listing as /v1/models, but this route declares no authentication dependency.
    return get_models_list_response()

# ... (all routes and core functions are the same as in the original file)
# --- The rest of the original code follows ---
# This includes _codegeex_stream_generator, _build_codegeex_non_stream_response, chat_completions, etc.
# They are defined in full below rather than omitted.
def _iter_codegeex_sse_events(response):
    """Yield ``(event_type, data)`` pairs parsed from the upstream SSE body.

    The body is read in 1024-byte chunks and split into events on blank
    lines. Within an event, the ``event:`` line gives the type and the
    ``data:`` line is parsed as JSON. Events without a type, or whose data is
    missing, empty, not valid JSON or not a JSON object, are skipped.

    Args:
        response: Streamed ``requests`` response (anything exposing
            ``iter_content(chunk_size=...)`` that yields bytes).

    Yields:
        Tuples of the event type string and the decoded data dict.
    """
    # Local import: codecs is not imported at the top of this file.
    import codecs

    # An incremental decoder carries a multi-byte UTF-8 sequence that was cut
    # at a chunk boundary over to the next chunk; decoding each chunk on its
    # own with errors="ignore" would silently drop that character.
    decoder = codecs.getincrementaldecoder("utf-8")(errors="ignore")
    buffer = ""
    for chunk in response.iter_content(chunk_size=1024):
        if not chunk:
            continue
        buffer += decoder.decode(chunk)
        while "\n\n" in buffer:
            event_data, buffer = buffer.split("\n\n", 1)
            event_data = event_data.strip()
            if not event_data:
                continue
            event_type, data_json = None, None
            for line in event_data.split("\n"):
                if line.startswith("event:"):
                    event_type = line[6:].strip()
                elif line.startswith("data:"):
                    try:
                        data_json = json.loads(line[5:].strip())
                    except ValueError:  # json.JSONDecodeError is a ValueError
                        continue
            # Decide once per event, after all of its lines have been read, so
            # an event with extra lines is never handled more than once.
            if not event_type or not isinstance(data_json, dict) or not data_json:
                continue
            yield event_type, data_json


def _codegeex_stream_generator(response, model: str):
    """Translate the upstream SSE stream into OpenAI-style streaming chunks.

    Emits a role chunk first, then one content chunk per non-empty upstream
    ``add`` event, and always ends with a ``finish_reason="stop"`` chunk
    followed by ``data: [DONE]``. An error while reading the upstream stream
    is sent to the client as a ``{"error": ...}`` data line before the
    closing chunks. The upstream response is closed when the generator
    finishes or is closed early by its consumer.

    Args:
        response: Streamed ``requests`` response from the CodeGeeX API.
        model: Model name echoed in every chunk.

    Yields:
        SSE-formatted ``data: ...`` strings, each terminated by a blank line.
    """
    stream_id = f"chatcmpl-{uuid.uuid4().hex}"
    created_time = int(time.time())

    def chunk_line(delta, finish_reason=None) -> str:
        chunk = StreamResponse(
            id=stream_id,
            created=created_time,
            model=model,
            choices=[StreamChoice(delta=delta, finish_reason=finish_reason)],
        )
        return f"data: {chunk.json()}\n\n"

    try:
        yield chunk_line({'role': 'assistant'})
        try:
            for event_type, data_json in _iter_codegeex_sse_events(response):
                if event_type == "add":
                    delta = data_json.get("text", "")
                    if delta:
                        yield chunk_line({'content': delta})
                elif event_type == "finish":
                    break
        except Exception as e:
            log_debug(f"Stream processing error: {e}")
            yield f"data: {json.dumps({'error': str(e)})}\n\n"
        yield chunk_line({}, finish_reason='stop')
        yield "data: [DONE]\n\n"
    finally:
        # Release the upstream connection even when the client disconnects
        # mid-stream and the generator is closed before it is exhausted.
        response.close()


def _build_codegeex_non_stream_response(response, model: str) -> ChatCompletionResponse:
    """Consume the upstream SSE stream and build one complete chat completion.

    Text from ``add`` events is concatenated. When the ``finish`` event
    carries a non-empty ``text`` of its own, that text replaces the
    concatenation. Reading stops at the first ``finish`` event or at the end
    of the stream, and the upstream response is closed afterwards.

    Args:
        response: Streamed ``requests`` response from the CodeGeeX API.
        model: Model name echoed in the result.

    Returns:
        A ChatCompletionResponse holding a single assistant message.
    """
    parts = []
    final_text = ""
    try:
        for event_type, data_json in _iter_codegeex_sse_events(response):
            if event_type == "add":
                text = data_json.get("text", "")
                if text:
                    parts.append(text)
            elif event_type == "finish":
                final_text = data_json.get("text", "")
                break
    finally:
        response.close()
    full_content = final_text if final_text else "".join(parts)
    return ChatCompletionResponse(model=model, choices=[ChatCompletionChoice(message=ChatMessage(role="assistant", content=full_content))])

@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest, _: None = Depends(authenticate_client)):
    """Create a chat completion (OpenAI-compatible) by proxying the request to CodeGeeX."""
    # Implementation notes:
    # - Answers 404 for an unknown model, 400 for an empty or unusable message
    #   list, and 503 when no token is usable or every attempt failed.
    # - Makes at most one upstream attempt per configured token. A 401/403
    #   invalidates the token used; 429/5xx and non-HTTP failures increase its
    #   error counter.
    # - Blocking `requests` work runs in the thread pool so the event loop is
    #   not stalled while waiting on the upstream API.
    # Local import: fastapi is already a dependency of this module, but this
    # helper is not imported at the top of the file.
    from fastapi.concurrency import run_in_threadpool

    if request.model not in CODEGEEX_MODELS:
        raise HTTPException(status_code=404, detail=f"Model '{request.model}' not found.")
    if not request.messages:
        raise HTTPException(status_code=400, detail="No messages provided.")
    try:
        prompt, history = _convert_messages_to_codegeex_format(request.messages)
    except HTTPException:
        # Already a well-formed client error; re-wrapping would garble its detail.
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to process messages: {e}") from e

    for _attempt in range(len(CODEGEEX_TOKENS)):
        token = get_best_codegeex_token()
        if not token:
            raise HTTPException(status_code=503, detail="No valid CodeGeeX tokens available.")
        response = None
        handed_off = False  # True once the streaming generator owns the response
        try:
            payload = {"user_role": 0, "ide": "VSCode", "prompt": prompt, "model": request.model, "history": history, "talkId": f"{uuid.uuid4()}", "plugin_version": "", "locale": "", "agent": None, "candidates": {"candidate_msg_id": "", "candidate_type": "", "selected_candidate": ""}, "ide_version": "", "machineId": ""}
            headers = {"User-Agent": "Mozilla/5.0", "Accept": "text/event-stream", "Content-Type": "application/json", "code-token": token["token"]}
            response = await run_in_threadpool(
                requests.post,
                "https://codegeex.cn/prod/code/chatCodeSseV3/chat",
                data=json.dumps(payload),
                headers=headers,
                stream=True,
                timeout=300.0,
            )
            response.raise_for_status()
            if request.stream:
                streaming = StreamingResponse(
                    _codegeex_stream_generator(response, request.model),
                    media_type="text/event-stream",
                    headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
                )
                handed_off = True
                return streaming
            # Reading the whole upstream body is blocking work as well.
            return await run_in_threadpool(_build_codegeex_non_stream_response, response, request.model)
        except requests.HTTPError as e:
            status_code = getattr(e.response, "status_code", 500)
            log_debug(f"CodeGeeX API returned HTTP {status_code}: {e}")
            with token_rotation_lock:
                if status_code in [401, 403]:
                    token["is_valid"] = False
                elif status_code in [429, 500, 502, 503, 504]:
                    token["error_count"] += 1
        except Exception as e:
            log_debug(f"CodeGeeX API request failed: {e}")
            with token_rotation_lock:
                token["error_count"] += 1
        finally:
            # Release the upstream connection unless the streaming generator
            # has taken ownership of it.
            if response is not None and not handed_off:
                response.close()

    raise HTTPException(status_code=503, detail="All attempts to contact CodeGeeX API failed.")