polarbearblue committed on
Commit
e896faf
·
verified ·
1 Parent(s): 29b6bf1

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +386 -0
main.py ADDED
@@ -0,0 +1,386 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import time
3
+ import uuid
4
+ import threading
5
+ from typing import Any, AsyncGenerator, Dict, List, Optional
6
+
7
+ import httpx
8
+ import uvicorn
9
+ from fastapi import FastAPI, HTTPException, Depends, Header
10
+ from fastapi.responses import StreamingResponse
11
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
12
+ from pydantic import BaseModel, Field
13
+
14
# Configuration
CONVERSATION_CACHE_MAX_SIZE = 100  # NOTE(review): defined but never read in this file — confirm before removing
DEFAULT_REQUEST_TIMEOUT = 30.0     # NOTE(review): defined but never read in this file — confirm before removing

# Global variables (module-level mutable state shared across requests)
VALID_CLIENT_KEYS: set = set()        # client API keys accepted by authenticate_client()
JETBRAINS_JWTS: list = []             # upstream JetBrains AI JWTs, rotated round-robin
current_jwt_index: int = 0            # next index into JETBRAINS_JWTS
jwt_rotation_lock = threading.Lock()  # guards current_jwt_index during rotation
models_data: Dict[str, Any] = {}      # model catalogue loaded from models.json at startup
http_client: Optional[httpx.AsyncClient] = None  # shared async HTTP client, created on startup
# Pydantic Models
class ChatMessage(BaseModel):
    """A single chat message in OpenAI request/response format."""
    # Speaker role, e.g. "user" / "assistant" / "system"
    role: str
    # Plain-text message content
    content: str
class ChatCompletionRequest(BaseModel):
    """Incoming OpenAI-style /v1/chat/completions request body."""
    model: str
    messages: List[ChatMessage]
    # When True the endpoint returns an SSE stream instead of a full response.
    stream: bool = False
    # NOTE(review): temperature / max_tokens / top_p are accepted for API
    # compatibility but are not forwarded to the upstream payload in this file.
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    top_p: Optional[float] = None
class ModelInfo(BaseModel):
    """One entry in the OpenAI-style model listing."""
    id: str
    object: str = "model"
    created: int  # Unix timestamp (seconds)
    owned_by: str
class ModelList(BaseModel):
    """OpenAI-style response envelope for GET /v1/models."""
    object: str = "list"
    data: List[ModelInfo]
class ChatCompletionChoice(BaseModel):
    """A single completed choice in a non-streaming response."""
    message: ChatMessage
    index: int = 0
    finish_reason: str = "stop"
class ChatCompletionResponse(BaseModel):
    """OpenAI-style non-streaming chat completion response."""
    # Fresh "chatcmpl-..." id per response instance.
    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex}")
    object: str = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionChoice]
    # Token accounting is not implemented; always reported as zeros.
    usage: Dict[str, int] = Field(default_factory=lambda: {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0})
class StreamChoice(BaseModel):
    """A single choice inside a streaming chunk."""
    # Incremental payload; carries "role" and/or "content" keys.
    delta: Dict[str, Any] = Field(default_factory=dict)
    index: int = 0
    # "stop" on the final chunk, None otherwise.
    finish_reason: Optional[str] = None
class StreamResponse(BaseModel):
    """OpenAI-style streaming chunk (chat.completion.chunk)."""
    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex}")
    object: str = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[StreamChoice]
# FastAPI App
app = FastAPI(title="JetBrains AI OpenAI Compatible API")
# auto_error=False so a missing Authorization header reaches
# authenticate_client(), which raises its own tailored 401.
security = HTTPBearer(auto_error=False)
# Helper functions
def load_models():
    """Read models.json and build the in-memory model catalogue.

    Returns:
        A dict of the form ``{"data": [<model entry>, ...]}``.  Every
        string entry in models.json becomes one catalogue record; any
        read/parse error yields an empty catalogue instead of raising.
    """
    try:
        with open("models.json", "r", encoding="utf-8") as f:
            raw_ids = json.load(f)

        if not isinstance(raw_ids, list):
            return {"data": []}

        # Only plain string entries are treated as model ids.
        entries = [
            {
                "id": raw_id,
                "object": "model",
                "created": int(time.time()),
                "owned_by": "jetbrains-ai",
            }
            for raw_id in raw_ids
            if isinstance(raw_id, str)
        ]
        return {"data": entries}
    except Exception as e:
        print(f"加载 models.json 时出错: {e}")
        return {"data": []}
def load_client_api_keys():
    """Load the set of accepted client API keys from client_api_keys.json.

    Rebinds the module-global ``VALID_CLIENT_KEYS``.  On any failure the
    set is left empty, which makes authenticate_client() answer 503.
    """
    global VALID_CLIENT_KEYS
    try:
        with open("client_api_keys.json", "r", encoding="utf-8") as f:
            loaded = json.load(f)

        # The file must contain a JSON array of key strings.
        if not isinstance(loaded, list):
            print("警告: client_api_keys.json 应包含密钥列表")
            VALID_CLIENT_KEYS = set()
            return

        VALID_CLIENT_KEYS = set(loaded)
        if VALID_CLIENT_KEYS:
            print(f"成功加载 {len(VALID_CLIENT_KEYS)} 个客户端 API 密钥")
        else:
            print("警告: client_api_keys.json 为空")
    except FileNotFoundError:
        print("错误: 未找到 client_api_keys.json")
        VALID_CLIENT_KEYS = set()
    except Exception as e:
        print(f"加载 client_api_keys.json 时出错: {e}")
        VALID_CLIENT_KEYS = set()
def load_jetbrains_jwts():
    """Load upstream JetBrains AI JWTs from jetbrainsai.json.

    The file is expected to hold a list of objects, each with a "jwt"
    key.  Rebinds the module-global ``JETBRAINS_JWTS``; on read/parse
    failure it is reset to an empty list.
    """
    global JETBRAINS_JWTS
    try:
        with open("jetbrainsai.json", "r", encoding="utf-8") as f:
            parsed = json.load(f)

        if isinstance(parsed, list):
            tokens = []
            for entry in parsed:
                # Keep only entries that actually carry a "jwt" key.
                if "jwt" in entry:
                    tokens.append(entry.get("jwt"))
            JETBRAINS_JWTS = tokens

        if JETBRAINS_JWTS:
            print(f"成功加载 {len(JETBRAINS_JWTS)} 个 JetBrains AI JWT")
        else:
            print("警告: jetbrainsai.json 中未找到有效的 JWT")

    except FileNotFoundError:
        print("错误: 未找到 jetbrainsai.json 文件")
        JETBRAINS_JWTS = []
    except Exception as e:
        print(f"加载 jetbrainsai.json 时出错: {e}")
        JETBRAINS_JWTS = []
def get_model_item(model_id: str) -> Optional[Dict]:
    """Return the catalogue entry whose "id" equals *model_id*, or None."""
    matches = (
        entry
        for entry in models_data.get("data", [])
        if entry.get("id") == model_id
    )
    return next(matches, None)
async def authenticate_client(auth: Optional[HTTPAuthorizationCredentials] = Depends(security)):
    """FastAPI dependency that validates the caller's Bearer API key.

    Raises:
        HTTPException: 503 when no keys are configured, 401 when no
        credential was supplied, 403 when the credential is unknown.
    """
    # No keys configured at all -> the service cannot authenticate anyone.
    if not VALID_CLIENT_KEYS:
        raise HTTPException(status_code=503, detail="服务不可用: 未配置客户端 API 密钥")

    supplied = auth.credentials if auth else None
    if not supplied:
        raise HTTPException(
            status_code=401,
            detail="需要在 Authorization header 中提供 API 密钥",
            headers={"WWW-Authenticate": "Bearer"},
        )

    if supplied not in VALID_CLIENT_KEYS:
        raise HTTPException(status_code=403, detail="无效的客户端 API 密钥")
def get_next_jetbrains_jwt() -> str:
    """Pick the next upstream JWT in round-robin order (thread-safe).

    Raises:
        HTTPException: 503 when no JWTs are configured.
    """
    global current_jwt_index

    if not JETBRAINS_JWTS:
        raise HTTPException(status_code=503, detail="服务不可用: 未配置 JetBrains JWT")

    with jwt_rotation_lock:
        # Re-check under the lock: the list may have been swapped concurrently.
        if not JETBRAINS_JWTS:
            raise HTTPException(status_code=503, detail="服务不可用: JetBrains JWT 不可用")
        chosen = JETBRAINS_JWTS[current_jwt_index]
        current_jwt_index = (current_jwt_index + 1) % len(JETBRAINS_JWTS)
    return chosen
# FastAPI lifecycle events
@app.on_event("startup")
async def startup():
    """Load configuration files and create the shared HTTP client."""
    global models_data, http_client
    models_data = load_models()
    load_client_api_keys()
    load_jetbrains_jwts()
    # No client-level timeout; call sites pass their own per-request timeout.
    http_client = httpx.AsyncClient(timeout=None)
    print("JetBrains AI OpenAI Compatible API 服务器已启动")
@app.on_event("shutdown")
async def shutdown():
    """Close the shared HTTP client when the application stops."""
    global http_client
    if http_client:
        await http_client.aclose()
# API endpoints
@app.get("/v1/models", response_model=ModelList)
async def list_models(_: None = Depends(authenticate_client)):
    """Return the loaded model catalogue in OpenAI /v1/models format."""
    entries = [
        ModelInfo(
            id=entry.get("id", ""),
            created=entry.get("created", int(time.time())),
            owned_by=entry.get("owned_by", "jetbrains-ai"),
        )
        for entry in models_data.get("data", [])
    ]
    return ModelList(data=entries)
async def openai_stream_adapter(
    api_stream_generator: AsyncGenerator[str, None],
    model_name: str
) -> AsyncGenerator[str, None]:
    """Adapt the JetBrains AI SSE stream into OpenAI chat.completion.chunk SSE.

    Consumes raw "data: {...}" lines from the upstream generator, re-emits
    "Content" events as OpenAI-format chunks, emits a finish chunk on
    "FinishMetadata", and always terminates with "data: [DONE]".  Any
    unexpected error is surfaced to the client as a final error chunk
    rather than raised.
    """
    stream_id = f"chatcmpl-{uuid.uuid4().hex}"
    # Per OpenAI SSE convention the first delta must carry the assistant role.
    first_chunk_sent = False

    try:
        async for line in api_stream_generator:
            # Skip keep-alive blanks and the upstream end marker.
            if not line or line == "data: end":
                continue

            if line.startswith('data: '):
                try:
                    # Strip the 6-char "data: " prefix before parsing.
                    data = json.loads(line[6:])
                    event_type = data.get("type")

                    if event_type == "Content":
                        content = data.get("content", "")
                        if not content:
                            continue

                        delta_payload = {}
                        if not first_chunk_sent:
                            delta_payload = {"role": "assistant", "content": content}
                            first_chunk_sent = True
                        else:
                            delta_payload = {"content": content}

                        stream_resp = StreamResponse(id=stream_id, model=model_name, choices=[StreamChoice(delta=delta_payload)])
                        yield f"data: {stream_resp.json()}\n\n"

                    elif event_type == "FinishMetadata":
                        # Upstream signalled completion: emit the final chunk
                        # with finish_reason="stop" and stop reading.
                        final_resp = StreamResponse(id=stream_id, model=model_name, choices=[StreamChoice(delta={}, finish_reason="stop")])
                        yield f"data: {final_resp.json()}\n\n"
                        break
                except json.JSONDecodeError:
                    # Malformed upstream lines are logged and skipped.
                    print(f"警告: 无法解析的 JSON 行: {line}")
                    continue

        yield "data: [DONE]\n\n"

    except Exception as e:
        # Deliver the failure as a normal-looking chunk so the SSE stream
        # still terminates cleanly with [DONE] on the client side.
        print(f"流式适配器错误: {e}")
        error_resp = StreamResponse(
            id=stream_id,
            model=model_name,
            choices=[StreamChoice(
                delta={"role": "assistant", "content": f"内部错误: {str(e)}"},
                index=0,
                finish_reason="stop"
            )]
        )
        yield f"data: {error_resp.json()}\n\n"
        yield "data: [DONE]\n\n"
async def aggregate_stream_for_non_stream_response(
    openai_sse_stream: AsyncGenerator[str, None],
    model_name: str
) -> ChatCompletionResponse:
    """Drain an OpenAI-format SSE stream into one complete response.

    Concatenates every delta "content" fragment and wraps the result in a
    single assistant ChatCompletionResponse with finish_reason "stop".

    Args:
        openai_sse_stream: stream of "data: {...}" SSE lines (as produced
            by openai_stream_adapter), terminated by "data: [DONE]".
        model_name: model id echoed back in the response.
    """
    content_parts = []

    async for sse_line in openai_sse_stream:
        if sse_line.startswith("data: ") and sse_line.strip() != "data: [DONE]":
            try:
                data = json.loads(sse_line[6:].strip())
                if data.get("choices") and len(data["choices"]) > 0:
                    delta = data["choices"][0].get("delta", {})
                    if "content" in delta:
                        content_parts.append(delta["content"])
            # Fix: the original bare `except: pass` swallowed everything,
            # including SystemExit/KeyboardInterrupt and genuine bugs.
            # Only ignore malformed or unexpectedly-shaped chunk payloads.
            except (json.JSONDecodeError, AttributeError, KeyError, IndexError, TypeError):
                pass

    full_content = "".join(content_parts)

    return ChatCompletionResponse(
        model=model_name,
        choices=[ChatCompletionChoice(
            message=ChatMessage(role="assistant", content=full_content),
            finish_reason="stop"
        )]
    )
@app.post("/v1/chat/completions")
async def chat_completions(
    request: ChatCompletionRequest,
    _: None = Depends(authenticate_client)
):
    """OpenAI-compatible chat completion endpoint backed by JetBrains AI.

    Validates the requested model against the local catalogue, rotates to
    the next upstream JWT, forwards the conversation to the JetBrains AI
    streaming endpoint, and returns either an SSE stream or an aggregated
    response depending on ``request.stream``.

    NOTE(review): request.temperature / max_tokens / top_p are not put
    into the upstream payload — confirm whether that is intentional.
    """
    model_config = get_model_item(request.model)
    if not model_config:
        raise HTTPException(status_code=404, detail=f"模型 {request.model} 未找到")

    auth_token = get_next_jetbrains_jwt()

    # Convert OpenAI-format messages to the JetBrains format.
    jetbrains_messages = []
    for msg in request.messages:
        # The JetBrains API expects a specific alternating format; this is a
        # simplification — stricter user/assistant alternation logic may be
        # needed in practice.
        jetbrains_messages.append({"type": f"{msg.role}_message", "content": msg.content})

    # Build the upstream request payload.
    payload = {
        "prompt": "ij.chat.request.new-chat-on-start", # or other relevant prompt
        "profile": request.model,
        "chat": {
            "messages": jetbrains_messages
        },
        "parameters": {"data": []},
    }

    headers = {
        "User-Agent": "ktor-client",
        "Accept": "text/event-stream",
        "Content-Type": "application/json",
        "Accept-Charset": "UTF-8",
        "Cache-Control": "no-cache",
        # Mimics the PyCharm AI Assistant client; update as needed.
        "grazie-agent": '{"name":"aia:pycharm","version":"251.26094.80.13:251.26094.141"}',
        "grazie-authenticate-jwt": auth_token,
    }

    async def api_stream_generator():
        """Async generator wrapping the httpx streaming POST to JetBrains AI."""
        async with http_client.stream("POST", "https://api.jetbrains.ai/user/v5/llm/chat/stream/v7",
                                      json=payload, headers=headers, timeout=300) as response:
            response.raise_for_status()
            async for line in response.aiter_lines():
                yield line

    # Wrap the upstream stream in the OpenAI SSE adapter.
    openai_sse_stream = openai_stream_adapter(
        api_stream_generator(),
        request.model
    )

    # Return a streaming or aggregated (non-streaming) response.
    if request.stream:
        return StreamingResponse(
            openai_sse_stream,
            media_type="text/event-stream"
        )
    else:
        return await aggregate_stream_for_non_stream_response(
            openai_sse_stream,
            request.model
        )
# Script entry point
if __name__ == "__main__":
    import os

    # Seed example configuration files on first run so the server starts
    # with a working (placeholder) setup.
    example_files = [
        ("client_api_keys.json", ["sk-your-custom-key-here"], "已创建示例 client_api_keys.json 文件"),
        ("jetbrainsai.json", [{"jwt": "your-jwt-here"}], "已创建示例 jetbrainsai.json 文件"),
        ("models.json", ["anthropic-claude-3.5-sonnet"], "已创建示例 models.json 文件"),
    ]
    for filename, default_content, notice in example_files:
        if not os.path.exists(filename):
            with open(filename, "w", encoding="utf-8") as f:
                json.dump(default_content, f, indent=2)
            print(notice)

    print("正在启动 JetBrains AI OpenAI Compatible API 服务器...")
    print("端点:")
    print(" GET /v1/models")
    print(" POST /v1/chat/completions")
    print("\n在 Authorization header 中使用客户端 API 密钥 (Bearer sk-xxx)")

    uvicorn.run(app, host="0.0.0.0", port=8000)