dan92 commited on
Commit
4cfabc2
·
verified ·
1 Parent(s): 01fc516

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +400 -271
main.py CHANGED
@@ -6,8 +6,7 @@ import base64
6
  import re
7
  import os
8
  import argparse
9
- import time
10
- from datetime import datetime, timezone
11
  from typing import List, Optional
12
 
13
  import httpx
@@ -26,21 +25,36 @@ from fastapi.staticfiles import StaticFiles
26
 
27
  from bearer_token import BearerTokenGenerator
28
 
29
- from fastapi import Depends, HTTPException, Security
30
- from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
31
-
32
- # 模型列表
33
  MODELS = ["gpt-4o", "gpt-4o-mini", "claude-3-5-sonnet", "claude"]
34
 
35
  # 默认端口
36
- INITIAL_PORT = 3000
37
 
38
  # 外部API的URL
39
  EXTERNAL_API_URL = "https://api.chaton.ai/chats/stream"
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  # 初始化FastAPI应用
42
  app = FastAPI()
43
 
 
 
 
44
  # 添加CORS中间件
45
  app.add_middleware(
46
  CORSMiddleware,
@@ -50,9 +64,6 @@ app.add_middleware(
50
  allow_headers=["Content-Type", "Authorization"], # 允许的头部
51
  )
52
 
53
- # 挂载静态文件路由以提供 images 目录的内容
54
- app.mount("/images", StaticFiles(directory="images"), name="images")
55
-
56
  # 辅助函数
57
  def send_error_response(message: str, status_code: int = 400):
58
  """构建错误响应,并确保包含CORS头"""
@@ -106,17 +117,50 @@ async def download_image(image_url: str) -> Optional[bytes]:
106
  print(f"Error downloading image: {e}")
107
  return None
108
 
109
- def save_base64_image(base64_str: str, images_dir: str = "images") -> str:
110
  """
111
- 将Base64编码的图片保存到images目录,返回文件名
112
  """
 
 
 
113
  if not os.path.exists(images_dir):
114
- os.makedirs(images_dir)
115
- image_data = base64.b64decode(base64_str)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  filename = f"{uuid.uuid4()}.png" # 默认保存为png格式
117
  file_path = os.path.join(images_dir, filename)
118
- with open(file_path, "wb") as f:
119
- f.write(image_data)
 
 
 
 
 
 
120
  return filename
121
 
122
  def is_base64_image(url: str) -> bool:
@@ -125,73 +169,26 @@ def is_base64_image(url: str) -> bool:
125
  """
126
  return url.startswith("data:image/")
127
 
128
- # 添加 HTTPBearer 实例
129
- security = HTTPBearer()
130
-
131
- # 添加 API_KEY 验证函数
132
- def verify_api_key(credentials: HTTPAuthorizationCredentials = Security(security)):
133
- api_key = os.environ.get("API_KEY")
134
- if api_key is None:
135
- raise HTTPException(status_code=500, detail="API_KEY not set in environment variables")
136
- if credentials.credentials != api_key:
137
- raise HTTPException(status_code=401, detail="Invalid API key")
138
- return credentials.credentials
139
-
140
  # 根路径GET请求处理
141
- @app.get("/")
142
- async def root():
143
- return JSONResponse(content={
144
- "service": "AI Chat Completion Proxy",
145
- "usage": {
146
- "endpoint": "/ai/v1/chat/completions",
147
- "method": "POST",
148
- "headers": {
149
- "Content-Type": "application/json",
150
- "Authorization": "Bearer YOUR_API_KEY"
151
- },
152
- "body": {
153
- "model": "One of: " + ", ".join(MODELS),
154
- "messages": [
155
- {"role": "system", "content": "You are a helpful assistant."},
156
- {"role": "user", "content": "Hello, who are you?"}
157
- ],
158
- "stream": False,
159
- "temperature": 0.7,
160
- "max_tokens": 8000
161
- }
162
- },
163
- "availableModels": MODELS,
164
- "endpoints": {
165
- "/ai/v1/chat/completions": "Chat completion endpoint",
166
- "/ai/v1/images/generations": "Image generation endpoint",
167
- "/ai/v1/models": "List available models"
168
- },
169
- "note": "Replace YOUR_API_KEY with your actual API key."
170
- })
171
-
172
- # 返回模型列表
173
- @app.get("/ai/v1/models")
174
- async def list_models(api_key: str = Depends(verify_api_key)):
175
- """返回可用模型列表。"""
176
- models = [
177
- {
178
- "id": model,
179
- "object": "model",
180
- "created": int(time.time()),
181
- "owned_by": "chaton",
182
- "permission": [],
183
- "root": model,
184
- "parent": None,
185
- } for model in MODELS
186
- ]
187
- return JSONResponse(content={
188
- "object": "list",
189
- "data": models
190
- })
191
 
192
- # 聊天完成处理
193
- @app.post("/ai/v1/chat/completions")
194
- async def chat_completions(request: Request, background_tasks: BackgroundTasks, api_key: str = Depends(verify_api_key)):
195
  """
196
  处理聊天完成请求
197
  """
@@ -206,11 +203,15 @@ async def chat_completions(request: Request, background_tasks: BackgroundTasks,
206
  # 处理消息内容
207
  messages = request_body.get("messages", [])
208
  temperature = request_body.get("temperature", 1.0)
209
- top_p = request_body.get("top_p", 1.0)
210
  max_tokens = request_body.get("max_tokens", 8000)
211
  model = request_body.get("model", "gpt-4o")
212
  is_stream = request_body.get("stream", False) # 获取 stream 字段
213
 
 
 
 
 
214
  has_image = False
215
  has_text = False
216
 
@@ -230,11 +231,15 @@ async def chat_completions(request: Request, background_tasks: BackgroundTasks,
230
  url = image_info.get("url", "")
231
  if is_base64_image(url):
232
  # 解码并保存图片
233
- base64_str = url.split(",")[1]
234
- filename = save_base64_image(base64_str)
235
- base_url = app.state.base_url
236
- image_url = f"{base_url}/images/{filename}"
237
- images.append({"data": image_url})
 
 
 
 
238
  else:
239
  images.append({"data": url})
240
  extracted_content = " ".join(text_parts).strip()
@@ -269,19 +274,15 @@ async def chat_completions(request: Request, background_tasks: BackgroundTasks,
269
  if not cleaned_messages:
270
  raise HTTPException(status_code=400, detail="所有消息的内容均为空。")
271
 
272
- # 验证模型
273
- if model not in MODELS:
274
- model = "gpt-4o"
275
-
276
  # 构建新的请求JSON
277
  new_request_json = {
278
  "function_image_gen": False,
279
  "function_web_search": True,
280
  "max_tokens": max_tokens,
281
  "model": model,
282
- "source": "chat/free",
283
  "temperature": temperature,
284
- "top_p": top_p,
285
  "messages": cleaned_messages,
286
  }
287
 
@@ -299,13 +300,32 @@ async def chat_completions(request: Request, background_tasks: BackgroundTasks,
299
  "Date": formatted_date,
300
  "Client-time-zone": "-05:00",
301
  "Authorization": bearer_token,
302
- "User-Agent": "ChatOn_Android/1.53.502",
303
  "Accept-Language": "en-US",
304
  "X-Cl-Options": "hb",
305
  "Content-Type": "application/json; charset=UTF-8",
306
  }
307
 
308
  if is_stream:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
309
  # 流式响应处理
310
  async def event_generator():
311
  async with httpx.AsyncClient(timeout=None) as client_stream:
@@ -320,32 +340,54 @@ async def chat_completions(request: Request, background_tasks: BackgroundTasks,
320
  break
321
  try:
322
  sse_json = json.loads(data)
323
- if "choices" in sse_json:
324
- for choice in sse_json["choices"]:
325
- delta = choice.get("delta", {})
326
- content = delta.get("content")
327
- if content:
328
- new_sse_json = {
329
- "choices": [
330
- {
331
- "index": choice.get("index", 0),
332
- "delta": {"content": content},
333
- }
334
- ],
335
- "created": sse_json.get(
336
- "created", int(datetime.now(timezone.utc).timestamp())
337
- ),
338
- "id": sse_json.get(
339
- "id", str(uuid.uuid4())
340
- ),
341
- "model": sse_json.get("model", "gpt-4o"),
342
- "system_fingerprint": f"fp_{uuid.uuid4().hex[:12]}",
343
- }
344
- new_sse_line = f"data: {json.dumps(new_sse_json, ensure_ascii=False)}\n\n"
345
- yield new_sse_line
346
- except json.JSONDecodeError:
347
- print("JSON解析错误")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
348
  continue
 
 
 
349
  except httpx.RequestError as exc:
350
  print(f"外部API请求失败: {exc}")
351
  yield f"data: {{\"error\": \"外部API请求失败: {str(exc)}\"}}\n\n"
@@ -401,7 +443,7 @@ async def chat_completions(request: Request, background_tasks: BackgroundTasks,
401
  "id": f"chatcmpl-{uuid.uuid4()}",
402
  "object": "chat.completion",
403
  "created": int(datetime.now(timezone.utc).timestamp()),
404
- "model": model,
405
  "choices": [
406
  {
407
  "index": 0,
@@ -430,8 +472,8 @@ async def chat_completions(request: Request, background_tasks: BackgroundTasks,
430
  raise HTTPException(status_code=500, detail=f"内部服务器错误: {str(exc)}")
431
 
432
  # 图像生成处理
433
- @app.post("/ai/v1/images/generations")
434
- async def images_generations(request: Request, api_key: str = Depends(verify_api_key)):
435
  """
436
  处理图像生成请求
437
  """
@@ -447,177 +489,238 @@ async def images_generations(request: Request, api_key: str = Depends(verify_api
447
  return send_error_response("缺少必需的字段: prompt", status_code=400)
448
 
449
  user_prompt = request_body.get("prompt", "").strip()
450
- response_format = request_body.get("response_format", "b64_json").strip()
 
 
 
 
 
451
 
452
  if not user_prompt:
453
  return send_error_response("Prompt 不能为空。", status_code=400)
454
 
455
  print(f"Prompt: {user_prompt}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
456
 
457
- # 构建新的 TextToImage JSON 请求体
458
- text_to_image_json = {
459
- "function_image_gen": True,
460
- "function_web_search": True,
461
- "image_aspect_ratio": "1:1",
462
- "image_style": "photographic", # 暂时固定 image_style
463
- "max_tokens": 8000,
464
- "messages": [
465
- {
466
- "content": "You are a helpful artist, please based on imagination draw a picture.",
467
- "role": "system"
468
- },
469
- {
470
- "content": "Draw: " + user_prompt,
471
- "role": "user"
 
 
 
 
472
  }
473
- ],
474
- "model": "gpt-4o", # 固定 model,只能gpt-4o或gpt-4o-mini
475
- "source": "chat/pro_image" # 固定 source
476
- }
477
 
478
- modified_request_body = json.dumps(text_to_image_json, ensure_ascii=False)
479
- print("Modified Request JSON:", modified_request_body)
 
 
 
 
 
480
 
481
- # 获取Bearer Token
482
- tmp_token = BearerTokenGenerator.get_bearer(modified_request_body, path="/chats/stream")
483
- if not tmp_token:
484
- return send_error_response("无法生成 Bearer Token", status_code=500)
485
 
486
- bearer_token, formatted_date = tmp_token
 
 
487
 
488
- headers = {
489
- "Date": formatted_date,
490
- "Client-time-zone": "-05:00",
491
- "Authorization": bearer_token,
492
- "User-Agent": "ChatOn_Android/1.53.502",
493
- "Accept-Language": "en-US",
494
- "X-Cl-Options": "hb",
495
- "Content-Type": "application/json; charset=UTF-8",
496
- }
 
 
 
 
 
 
 
497
 
498
- async with httpx.AsyncClient(timeout=None) as client:
499
- try:
500
- response = await client.post(
501
- EXTERNAL_API_URL, headers=headers, content=modified_request_body, timeout=None
502
- )
503
- if response.status_code != 200:
504
- return send_error_response(f"API 错误: {response.status_code}", status_code=500)
505
-
506
- # 初始化用于拼接 URL 的字符串
507
- url_builder = ""
508
-
509
- # 读取 SSE 流并拼接 URL
510
- async for line in response.aiter_lines():
511
- if line.startswith("data: "):
512
- data = line[6:].strip()
513
- if data == "[DONE]":
514
- break
515
- try:
516
- sse_json = json.loads(data)
517
- if "choices" in sse_json:
518
- for choice in sse_json["choices"]:
519
- delta = choice.get("delta", {})
520
- content = delta.get("content")
521
- if content:
522
- url_builder += content
523
- except json.JSONDecodeError:
524
- print("JSON解析错误")
525
- continue
526
-
527
- image_markdown = url_builder
528
- # Step 1: 检查Markdown文本是否为空
529
- if not image_markdown:
530
- print("无法从 SSE 流中构建图像 Markdown。")
531
- return send_error_response("无法从 SSE 流中构建图像 Markdown。", status_code=500)
532
-
533
- # Step 2, 3, 4, 5: 处理图像
534
- extracted_path = extract_path_from_markdown(image_markdown)
535
- if not extracted_path:
536
- print("无法从 Markdown 中提取路径。")
537
- return send_error_response("无法从 Markdown 中提取路径。", status_code=500)
538
-
539
- print(f"提��的路径: {extracted_path}")
540
-
541
- # Step 5: 拼接最终的存储URL
542
- storage_url = f"https://api.chaton.ai/storage/{extracted_path}"
543
- print(f"存储URL: {storage_url}")
544
-
545
- # 获取最终下载URL
546
- final_download_url = await fetch_get_url_from_storage(storage_url)
547
- if not final_download_url:
548
- return send_error_response("无法从 storage URL 获取最终下载链接。", status_code=500)
549
-
550
- print(f"Final Download URL: {final_download_url}")
551
-
552
- # 下载图像
553
- image_bytes = await download_image(final_download_url)
554
- if not image_bytes:
555
- return send_error_response("无法从 URL 下载图像。", status_code=500)
556
-
557
- # 转换为 Base64
558
- image_base64 = base64.b64encode(image_bytes).decode('utf-8')
559
-
560
- # 将图片保存到images目录并构建可访问的URL
561
- filename = save_base64_image(image_base64)
562
- base_url = app.state.base_url
563
- accessible_url = f"{base_url}/images/{filename}"
564
-
565
- # 根据 response_format 返回相应的响应
566
- if response_format.lower() == "b64_json":
567
- response_json = {
568
- "data": [
569
- {
570
- "b64_json": image_base64
571
- }
572
- ]
573
- }
574
- return JSONResponse(content=response_json, status_code=200)
575
- else:
576
- # 构建包含可访问URL的响应
577
- response_json = {
578
- "data": [
579
- {
580
- "url": accessible_url
581
- }
582
- ]
583
- }
584
- return JSONResponse(content=response_json, status_code=200)
585
- except httpx.RequestError as exc:
586
- print(f"请求失败: {exc}")
587
- return send_error_response(f"请求失败: {str(exc)}", status_code=500)
588
- except Exception as exc:
589
- print(f"内部服务器错误: {exc}")
590
- return send_error_response(f"内部服务器错误: {str(exc)}", status_code=500)
591
-
592
- # 运行服务器
593
- def main():
594
- parser = argparse.ArgumentParser(description="启动ChatOn API服务器")
595
- parser.add_argument('--base_url', type=str, default='http://localhost', help='Base URL for accessing images')
596
- parser.add_argument('--port', type=int, default=INITIAL_PORT, help='服务器监听端口')
597
- args = parser.parse_args()
598
- base_url = args.base_url
599
- port = args.port
600
 
601
- # 检查 API_KEY 是否设置
602
- if not os.environ.get("API_KEY"):
603
- print("警告: API_KEY 环境变量未设置。客户端验证将无法正常工作。")
 
 
604
 
605
- # 确保 images 目录存在
606
- if not os.path.exists("images"):
607
- os.makedirs("images")
608
 
609
- # 设置 FastAPI 应用的 state
610
- app.state.base_url = base_url
 
611
 
612
- print(f"Server started on port {port} with base_url: {base_url}")
 
 
 
 
613
 
614
- # 运行FastAPI应用
615
- uvicorn.run(app, host="0.0.0.0", port=port)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
616
 
617
  async def get_available_port(start_port: int = INITIAL_PORT, end_port: int = 65535) -> int:
618
  """查找可用的端口号"""
619
  for port in range(start_port, end_port + 1):
620
  try:
 
621
  server = await asyncio.start_server(lambda r, w: None, host="0.0.0.0", port=port)
622
  server.close()
623
  await server.wait_closed()
@@ -626,5 +729,31 @@ async def get_available_port(start_port: int = INITIAL_PORT, end_port: int = 655
626
  continue
627
  raise RuntimeError(f"No available ports between {start_port} and {end_port}")
628
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
629
  if __name__ == "__main__":
630
  main()
 
6
  import re
7
  import os
8
  import argparse
9
+ from datetime import datetime, timezone, timedelta
 
10
  from typing import List, Optional
11
 
12
  import httpx
 
25
 
26
  from bearer_token import BearerTokenGenerator
27
 
28
+ # 模型列表(根据需求,可自行调整)
 
 
 
29
  MODELS = ["gpt-4o", "gpt-4o-mini", "claude-3-5-sonnet", "claude"]
30
 
31
  # 默认端口
32
+ INITIAL_PORT = 8080
33
 
34
  # 外部API的URL
35
  EXTERNAL_API_URL = "https://api.chaton.ai/chats/stream"
36
 
37
+ # 定义 images 目录
38
+ IMAGES_DIR = "images"
39
+
40
+ # 确保 images 目录存在
41
+ def ensure_images_dir_exists(directory: str = IMAGES_DIR):
42
+ try:
43
+ os.makedirs(directory, exist_ok=True)
44
+ print(f"Directory '{directory}' is ready.")
45
+ except Exception as e:
46
+ print(f"Failed to create directory '{directory}': {e}")
47
+ sys.exit(1)
48
+
49
+ # 在挂载静态文件之前确保 images 目录存在
50
+ ensure_images_dir_exists()
51
+
52
  # 初始化FastAPI应用
53
  app = FastAPI()
54
 
55
+ # 挂载静态文件路由以提供 images 目录的内容
56
+ app.mount("/images", StaticFiles(directory=IMAGES_DIR), name="images")
57
+
58
  # 添加CORS中间件
59
  app.add_middleware(
60
  CORSMiddleware,
 
64
  allow_headers=["Content-Type", "Authorization"], # 允许的头部
65
  )
66
 
 
 
 
67
  # 辅助函数
68
  def send_error_response(message: str, status_code: int = 400):
69
  """构建错误响应,并确保包含CORS头"""
 
117
  print(f"Error downloading image: {e}")
118
  return None
119
 
120
+ def cleanup_images(images_dir: str = IMAGES_DIR, age_seconds: int = 60):
121
  """
122
+ 清理 images 目录中创建时间超过指定秒数的图片
123
  """
124
+ now = datetime.now(timezone.utc)
125
+ cutoff_time = now - timedelta(seconds=age_seconds)
126
+
127
  if not os.path.exists(images_dir):
128
+ return
129
+
130
+ for filename in os.listdir(images_dir):
131
+ file_path = os.path.join(images_dir, filename)
132
+ if os.path.isfile(file_path):
133
+ try:
134
+ file_creation_time = datetime.fromtimestamp(os.path.getctime(file_path), timezone.utc)
135
+ if file_creation_time < cutoff_time:
136
+ os.remove(file_path)
137
+ print(f"已删除旧图片: {filename}")
138
+ except Exception as e:
139
+ print(f"无法删除文件 {filename}: {e}")
140
+
141
+ def save_base64_image(base64_str: str, images_dir: str = IMAGES_DIR) -> str:
142
+ """
143
+ 将Base64编码的图片保存到images目录,返回文件名
144
+ """
145
+ # 先清理1分钟前的所有图片
146
+ cleanup_images(images_dir, age_seconds=60)
147
+
148
+ try:
149
+ image_data = base64.b64decode(base64_str)
150
+ except base64.binascii.Error as e:
151
+ print(f"Base64解码失败: {e}")
152
+ raise ValueError("Invalid base64 image data")
153
+
154
  filename = f"{uuid.uuid4()}.png" # 默认保存为png格式
155
  file_path = os.path.join(images_dir, filename)
156
+ try:
157
+ with open(file_path, "wb") as f:
158
+ f.write(image_data)
159
+ print(f"保存图片: {filename}")
160
+ except Exception as e:
161
+ print(f"保存图片失败: {e}")
162
+ raise
163
+
164
  return filename
165
 
166
  def is_base64_image(url: str) -> bool:
 
169
  """
170
  return url.startswith("data:image/")
171
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  # 根路径GET请求处理
173
+ @app.get("/", response_class=HTMLResponse)
174
+ async def read_root():
175
+ """返回欢迎页面"""
176
+ html_content = """
177
+ <html>
178
+ <head>
179
+ <title>Welcome to API</title>
180
+ </head>
181
+ <body>
182
+ <h1>Welcome to API</h1>
183
+ <p>You can send messages to the model and receive responses.</p>
184
+ </body>
185
+ </html>
186
+ """
187
+ return HTMLResponse(content=html_content, status_code=200)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
 
189
+ # 聊天完成处理(保留原有逻辑,未修改)
190
+ @app.post("/v1/chat/completions")
191
+ async def chat_completions(request: Request, background_tasks: BackgroundTasks):
192
  """
193
  处理聊天完成请求
194
  """
 
203
  # 处理消息内容
204
  messages = request_body.get("messages", [])
205
  temperature = request_body.get("temperature", 1.0)
206
+ #top_p = request_body.get("top_p", 1.0)
207
  max_tokens = request_body.get("max_tokens", 8000)
208
  model = request_body.get("model", "gpt-4o")
209
  is_stream = request_body.get("stream", False) # 获取 stream 字段
210
 
211
+ # 验证模型
212
+ if model not in MODELS:
213
+ raise HTTPException(status_code=400, detail=f"无效的 model: {model}. 可用的模型有: {', '.join(MODELS)}")
214
+
215
  has_image = False
216
  has_text = False
217
 
 
231
  url = image_info.get("url", "")
232
  if is_base64_image(url):
233
  # 解码并保存图片
234
+ try:
235
+ base64_str = url.split(",")[1]
236
+ filename = save_base64_image(base64_str)
237
+ base_url = app.state.base_url
238
+ image_url = f"{base_url}/images/{filename}"
239
+ images.append({"data": image_url})
240
+ except (IndexError, ValueError) as e:
241
+ print(f"处理Base64图片失败: {e}")
242
+ continue
243
  else:
244
  images.append({"data": url})
245
  extracted_content = " ".join(text_parts).strip()
 
274
  if not cleaned_messages:
275
  raise HTTPException(status_code=400, detail="所有消息的内容均为空。")
276
 
 
 
 
 
277
  # 构建新的请求JSON
278
  new_request_json = {
279
  "function_image_gen": False,
280
  "function_web_search": True,
281
  "max_tokens": max_tokens,
282
  "model": model,
283
+ "source": "chat/pro",
284
  "temperature": temperature,
285
+ #"top_p": top_p,
286
  "messages": cleaned_messages,
287
  }
288
 
 
300
  "Date": formatted_date,
301
  "Client-time-zone": "-05:00",
302
  "Authorization": bearer_token,
303
+ "User-Agent": "ChatOn_Android/1.55.488",
304
  "Accept-Language": "en-US",
305
  "X-Cl-Options": "hb",
306
  "Content-Type": "application/json; charset=UTF-8",
307
  }
308
 
309
  if is_stream:
310
+ import uuid
311
+ from datetime import datetime, timezone
312
+
313
+ # 定义 should_filter_out 函数
314
+ def should_filter_out(json_data):
315
+ if 'ping' in json_data:
316
+ return True
317
+ if 'data' in json_data:
318
+ data = json_data['data']
319
+ if 'analytics' in data:
320
+ return True
321
+ if 'operation' in data and 'message' in data:
322
+ return True
323
+ return False
324
+
325
+ # 定义 generate_id 函数
326
+ def generate_id():
327
+ return uuid.uuid4().hex[:24]
328
+
329
  # 流式响应处理
330
  async def event_generator():
331
  async with httpx.AsyncClient(timeout=None) as client_stream:
 
340
  break
341
  try:
342
  sse_json = json.loads(data)
343
+
344
+ # 判断是否需要过滤
345
+ if should_filter_out(sse_json):
346
+ continue
347
+
348
+ # 处理包含 web sources 的消息
349
+ if 'data' in sse_json and 'web' in sse_json['data']:
350
+ web_data = sse_json['data']['web']
351
+ if 'sources' in web_data:
352
+ sources = web_data['sources']
353
+ urls_list = []
354
+ for source in sources:
355
+ if 'url' in source:
356
+ urls_list.append(source['url'])
357
+ urls_content = '\n\n'.join(urls_list)
358
+ print(f"从 API 接收到的内容: {urls_content}")
359
+ # 构造新的 SSE 消息,填入 content 字段
360
+ new_sse_json = {
361
+ "id": generate_id(),
362
+ "object": "chat.completion.chunk",
363
+ "created": int(datetime.now(timezone.utc).timestamp()),
364
+ "model": sse_json.get("model", "gpt-4o"),
365
+ "choices": [
366
+ {
367
+ "delta": {"content": "\n" + urls_content + "\n"},
368
+ "index": 0,
369
+ "finish_reason": None
370
+ }
371
+ ]
372
+ }
373
+ new_sse_line = f"data: {json.dumps(new_sse_json, ensure_ascii=False)}\n\n"
374
+ yield new_sse_line
375
+ else:
376
+ # 尝试打印内容
377
+ if 'choices' in sse_json:
378
+ for choice in sse_json['choices']:
379
+ delta = choice.get('delta', {})
380
+ content = delta.get('content')
381
+ if content:
382
+ print(content, end='')
383
+ # 直接转发其他消息
384
+ yield f"data: {data}\n\n"
385
+ except json.JSONDecodeError as e:
386
+ print(f"JSON解析错误: {e}")
387
  continue
388
+ else:
389
+ # 忽略不以 "data: " 开头的行
390
+ continue
391
  except httpx.RequestError as exc:
392
  print(f"外部API请求失败: {exc}")
393
  yield f"data: {{\"error\": \"外部API请求失败: {str(exc)}\"}}\n\n"
 
443
  "id": f"chatcmpl-{uuid.uuid4()}",
444
  "object": "chat.completion",
445
  "created": int(datetime.now(timezone.utc).timestamp()),
446
+ "model": model, # 使用用户传入的model
447
  "choices": [
448
  {
449
  "index": 0,
 
472
  raise HTTPException(status_code=500, detail=f"内部服务器错误: {str(exc)}")
473
 
474
  # 图像生成处理
475
+ @app.post("/v1/images/generations")
476
+ async def images_generations(request: Request):
477
  """
478
  处理图像生成请求
479
  """
 
489
  return send_error_response("缺少必需的字段: prompt", status_code=400)
490
 
491
  user_prompt = request_body.get("prompt", "").strip()
492
+ response_format = request_body.get("response_format", "").strip().lower()
493
+ model = request_body.get("model", "gpt-4o")
494
+ n = request_body.get("n", 1) # 生成图像数量,默认1
495
+ size = request_body.get("size", "1024x1024") # 图片尺寸,默认1024x1024
496
+
497
+ is_base64_response = response_format == "b64_json"
498
 
499
  if not user_prompt:
500
  return send_error_response("Prompt 不能为空。", status_code=400)
501
 
502
  print(f"Prompt: {user_prompt}")
503
+ print(f"Response Format: {response_format}")
504
+ print(f"Number of images to generate (n): {n}")
505
+ print(f"Size: {size}")
506
+
507
+ # 设置最大尝试次数为 2 * n
508
+ max_attempts = 2 * n
509
+ print(f"Max Attempts: {max_attempts}")
510
+
511
+ # 初始化用于存储多个 URL 的线程安全列表
512
+ final_download_urls: List[str] = []
513
+
514
+ async def attempt_generate_image(attempt: int) -> Optional[str]:
515
+ """
516
+ 尝试生成单张图像,带有重试机制。
517
+ """
518
+ try:
519
+ # 构建新的 TextToImage JSON 请求体
520
+ text_to_image_json = {
521
+ "function_image_gen": True,
522
+ "function_web_search": True,
523
+ "image_aspect_ratio": "1:1", # 图片比例可选:1:1/9:19/16:9/4:3
524
+ "image_style": "photographic", # 固定 image_style,可根据需要调整
525
+ "max_tokens": 8000,
526
+ "messages": [
527
+ {
528
+ "content": "You are a helpful artist, please draw a picture. Based on imagination, draw a picture with user message.",
529
+ "role": "system"
530
+ },
531
+ {
532
+ "content": "Draw: " + user_prompt,
533
+ "role": "user"
534
+ }
535
+ ],
536
+ "model": "gpt-4o",
537
+ "source": "chat/free" # 固定 source
538
+ }
539
 
540
+ modified_request_body = json.dumps(text_to_image_json, ensure_ascii=False)
541
+ print(f"Attempt {attempt} - Modified Request JSON: {modified_request_body}")
542
+
543
+ # 获取Bearer Token
544
+ tmp_token = BearerTokenGenerator.get_bearer(modified_request_body, path="/chats/stream")
545
+ if not tmp_token:
546
+ print(f"Attempt {attempt} - 无法生成 Bearer Token")
547
+ return None
548
+
549
+ bearer_token, formatted_date = tmp_token
550
+
551
+ headers = {
552
+ "Date": formatted_date,
553
+ "Client-time-zone": "-05:00",
554
+ "Authorization": bearer_token,
555
+ "User-Agent": "ChatOn_Android/1.53.502",
556
+ "Accept-Language": "en-US",
557
+ "X-Cl-Options": "hb",
558
+ "Content-Type": "application/json; charset=UTF-8",
559
  }
 
 
 
 
560
 
561
+ async with httpx.AsyncClient(timeout=None) as client:
562
+ response = await client.post(
563
+ EXTERNAL_API_URL,
564
+ headers=headers,
565
+ content=modified_request_body,
566
+ timeout=None
567
+ )
568
 
569
+ if response.status_code != 200:
570
+ print(f"Attempt {attempt} - API 错误: {response.status_code}")
571
+ return None
 
572
 
573
+ # 读取 SSE 流并提取图像URL
574
+ sse_lines = response.text.splitlines()
575
+ image_markdown = ""
576
 
577
+ for line in sse_lines:
578
+ if line.startswith("data: "):
579
+ data = line[6:].strip()
580
+ if data == "[DONE]":
581
+ break
582
+ try:
583
+ sse_json = json.loads(data)
584
+ if "choices" in sse_json:
585
+ for choice in sse_json["choices"]:
586
+ delta = choice.get("delta", {})
587
+ content = delta.get("content")
588
+ if content:
589
+ image_markdown += content
590
+ except json.JSONDecodeError:
591
+ print(f"Attempt {attempt} - JSON解析错误")
592
+ continue
593
 
594
+ # 检查Markdown文本是否为空
595
+ if not image_markdown:
596
+ print(f"Attempt {attempt} - 无法从 SSE 流中构建图像 Markdown。")
597
+ return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
598
 
599
+ # 从Markdown中提取图像路径
600
+ extracted_path = extract_path_from_markdown(image_markdown)
601
+ if not extracted_path:
602
+ print(f"Attempt {attempt} - 无法从 Markdown 中提取路径。")
603
+ return None
604
 
605
+ print(f"Attempt {attempt} - 提取的路径: {extracted_path}")
 
 
606
 
607
+ # 拼接最终的存储URL
608
+ storage_url = f"https://api.chaton.ai/storage/{extracted_path}"
609
+ print(f"Attempt {attempt} - 存储URL: {storage_url}")
610
 
611
+ # 获取最终下载URL
612
+ final_download_url = await fetch_get_url_from_storage(storage_url)
613
+ if not final_download_url:
614
+ print(f"Attempt {attempt} - 无法从 storage URL 获取最终下载链接。")
615
+ return None
616
 
617
+ print(f"Attempt {attempt} - Final Download URL: {final_download_url}")
618
+
619
+ return final_download_url
620
+ except Exception as e:
621
+ print(f"Attempt {attempt} - 处理响应时发生错误: {e}")
622
+ return None
623
+
624
+ # 定义一个异步任务池,限制并发数量
625
+ semaphore = asyncio.Semaphore(10) # 限制同时进行的任务数为10
626
+
627
+ async def generate_with_retries(attempt: int) -> Optional[str]:
628
+ async with semaphore:
629
+ return await attempt_generate_image(attempt)
630
+
631
+ # 开始尝试生成图像
632
+ for attempt in range(1, max_attempts + 1):
633
+ needed = n - len(final_download_urls)
634
+ if needed <= 0:
635
+ break
636
+
637
+ print(f"Attempt {attempt} - 需要生成的图像数量: {needed}")
638
+
639
+ # 创建多个任务同时生成所需数量的图��
640
+ tasks = [asyncio.create_task(generate_with_retries(attempt)) for _ in range(needed)]
641
+
642
+ # 等待所有任务完成
643
+ results = await asyncio.gather(*tasks, return_exceptions=True)
644
+
645
+ for result in results:
646
+ if isinstance(result, Exception):
647
+ print(f"Attempt {attempt} - 任务发生异常: {result}")
648
+ continue
649
+ if result:
650
+ final_download_urls.append(result)
651
+ print(f"Attempt {attempt} - 收集到下载链接: {result}")
652
+
653
+ # 检查是否已经收集到足够的下载链接
654
+ if len(final_download_urls) >= n:
655
+ break
656
+
657
+ # 检查是否收集到足够的链接
658
+ if len(final_download_urls) < n:
659
+ print("已达到最大尝试次数,仍未收集到足够数量的下载链接。")
660
+ return send_error_response("无法生成足够数量的图像。", status_code=500)
661
+
662
+ # 根据 response_format 返回相应的响应
663
+ data_array = []
664
+
665
+ if is_base64_response:
666
+ for download_url in final_download_urls[:n]:
667
+ try:
668
+ image_bytes = await download_image(download_url)
669
+ if not image_bytes:
670
+ print(f"无法从 URL 下载图像: {download_url}")
671
+ continue
672
+
673
+ image_base64 = base64.b64encode(image_bytes).decode('utf-8')
674
+ data_array.append({
675
+ "b64_json": image_base64
676
+ })
677
+ except Exception as e:
678
+ print(f"处理图像时发生错误: {e}")
679
+ continue
680
+ else:
681
+ for download_url in final_download_urls[:n]:
682
+ data_array.append({
683
+ "url": download_url
684
+ })
685
+
686
+ # 如果收集的 URL 数量不足 n,则通过复制现有的 URL 来填充
687
+ while len(data_array) < n and len(data_array) > 0:
688
+ for item in data_array.copy():
689
+ if len(data_array) >= n:
690
+ break
691
+ data_array.append(item)
692
+
693
+ # 构建最终响应
694
+ response_json = {
695
+ "created": int(datetime.now(timezone.utc).timestamp()),
696
+ "data": data_array
697
+ }
698
+
699
+ # 如果data_array为空,返回错误
700
+ if not data_array:
701
+ return send_error_response("无法生成图像。", status_code=500)
702
+
703
+ # 返回响应
704
+ return JSONResponse(content=response_json, status_code=200)
705
+
706
+ @app.get("/v1/models", response_class=JSONResponse)
707
+ async def get_models():
708
+ models_data = {
709
+ "object": "list",
710
+ "data": [
711
+ {"id": "gpt-4o", "object": "model"},
712
+ {"id": "gpt-4o-mini", "object": "model"},
713
+ {"id": "claude-3-5-sonnet", "object": "model"},
714
+ {"id": "claude", "object": "model"}
715
+ ]
716
+ }
717
+ return models_data
718
 
719
  async def get_available_port(start_port: int = INITIAL_PORT, end_port: int = 65535) -> int:
720
  """查找可用的端口号"""
721
  for port in range(start_port, end_port + 1):
722
  try:
723
+ # 尝试绑定端口
724
  server = await asyncio.start_server(lambda r, w: None, host="0.0.0.0", port=port)
725
  server.close()
726
  await server.wait_closed()
 
729
  continue
730
  raise RuntimeError(f"No available ports between {start_port} and {end_port}")
731
 
732
+ def main():
733
+ parser = argparse.ArgumentParser(description="启动ChatOn API服务器")
734
+ parser.add_argument('--base_url', type=str, default='http://localhost', help='Base URL for accessing images')
735
+ parser.add_argument('--port', type=int, default=INITIAL_PORT, help='服务器监听端口')
736
+ args = parser.parse_args()
737
+
738
+ base_url = args.base_url
739
+ port = args.port
740
+
741
+ # 设置 FastAPI 应用的 state
742
+ app.state.base_url = base_url
743
+
744
+ print(f"Starting server on port {port} with base_url: {base_url}")
745
+
746
+ # 检查端口可用性
747
+ try:
748
+ port = asyncio.run(get_available_port(start_port=port))
749
+ except RuntimeError as e:
750
+ print(e)
751
+ return
752
+
753
+ print(f"Server running on available port: {port}")
754
+
755
+ # 运行FastAPI应用
756
+ uvicorn.run(app, host="0.0.0.0", port=port)
757
+
758
  if __name__ == "__main__":
759
  main()