dan92 commited on
Commit
a7d2f0d
·
verified ·
1 Parent(s): e533997

Upload main.py

Browse files
Files changed (1) hide show
  1. main.py +762 -762
main.py CHANGED
@@ -1,762 +1,762 @@
1
- import asyncio
2
- import json
3
- import sys
4
- import uuid
5
- import base64
6
- import re
7
- import os
8
- import argparse
9
- from datetime import datetime, timezone, timedelta
10
- from typing import List, Optional
11
-
12
- import httpx
13
- import uvicorn
14
- from fastapi import (
15
- BackgroundTasks,
16
- FastAPI,
17
- HTTPException,
18
- Request,
19
- Response,
20
- status,
21
- )
22
- from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse
23
- from fastapi.middleware.cors import CORSMiddleware
24
- from fastapi.staticfiles import StaticFiles
25
-
26
- from bearer_token import BearerTokenGenerator
27
-
28
- # 模型列表(根据需求,可自行调整)
29
- MODELS = ["gpt-4o", "gpt-4o-mini", "claude-3-5-sonnet", "claude"]
30
-
31
- # 默认端口
32
- INITIAL_PORT = 3000
33
-
34
- # 外部API的URL
35
- EXTERNAL_API_URL = "https://api.chaton.ai/chats/stream"
36
-
37
- # 定义 images 目录
38
- IMAGES_DIR = "images"
39
-
40
- # 确保 images 目录存在
41
- def ensure_images_dir_exists(directory: str = IMAGES_DIR):
42
- try:
43
- os.makedirs(directory, exist_ok=True)
44
- print(f"Directory '{directory}' is ready.")
45
- except Exception as e:
46
- print(f"Failed to create directory '{directory}': {e}")
47
- sys.exit(1)
48
-
49
- # 在挂载静态文件之前确保 images 目录存在
50
- ensure_images_dir_exists()
51
-
52
- # 初始化FastAPI应用
53
- app = FastAPI()
54
-
55
- # 挂载静态文件路由以提供 images 目录的内容
56
- app.mount("/images", StaticFiles(directory=IMAGES_DIR), name="images")
57
-
58
- # 添加CORS中间件
59
- app.add_middleware(
60
- CORSMiddleware,
61
- allow_origins=["*"], # 允许所有来源
62
- allow_credentials=True,
63
- allow_methods=["GET", "POST", "OPTIONS"], # 允许GET, POST, OPTIONS方法
64
- allow_headers=["Content-Type", "Authorization"], # 允许的头部
65
- )
66
-
67
- # 辅助函数
68
- def send_error_response(message: str, status_code: int = 400):
69
- """构建错误响应,并确保包含CORS头"""
70
- error_json = {"error": message}
71
- headers = {
72
- "Access-Control-Allow-Origin": "*",
73
- "Access-Control-Allow-Methods": "GET, POST, OPTIONS",
74
- "Access-Control-Allow-Headers": "Content-Type, Authorization",
75
- }
76
- return JSONResponse(status_code=status_code, content=error_json, headers=headers)
77
-
78
- def extract_path_from_markdown(markdown: str) -> Optional[str]:
79
- """
80
- 提取 Markdown 图片链接中的路径,匹配以 https://spc.unk/ 开头的 URL
81
- """
82
- pattern = re.compile(r'!\[.*?\]\(https://spc\.unk/(.*?)\)')
83
- match = pattern.search(markdown)
84
- if match:
85
- return match.group(1)
86
- return None
87
-
88
- async def fetch_get_url_from_storage(storage_url: str) -> Optional[str]:
89
- """
90
- 从 storage URL 获取 JSON 并提取 getUrl
91
- """
92
- async with httpx.AsyncClient() as client:
93
- try:
94
- response = await client.get(storage_url)
95
- if response.status_code != 200:
96
- print(f"获取 storage URL 失败,状态码: {response.status_code}")
97
- return None
98
- json_response = response.json()
99
- return json_response.get("getUrl")
100
- except Exception as e:
101
- print(f"Error fetching getUrl from storage: {e}")
102
- return None
103
-
104
- async def download_image(image_url: str) -> Optional[bytes]:
105
- """
106
- 下载图像
107
- """
108
- async with httpx.AsyncClient() as client:
109
- try:
110
- response = await client.get(image_url)
111
- if response.status_code == 200:
112
- return response.content
113
- else:
114
- print(f"下载图像失败,状态码: {response.status_code}")
115
- return None
116
- except Exception as e:
117
- print(f"Error downloading image: {e}")
118
- return None
119
-
120
- def cleanup_images(images_dir: str = IMAGES_DIR, age_seconds: int = 60):
121
- """
122
- 清理 images 目录中创建时间超过指定秒数的图片
123
- """
124
- now = datetime.now(timezone.utc)
125
- cutoff_time = now - timedelta(seconds=age_seconds)
126
-
127
- if not os.path.exists(images_dir):
128
- return
129
-
130
- for filename in os.listdir(images_dir):
131
- file_path = os.path.join(images_dir, filename)
132
- if os.path.isfile(file_path):
133
- try:
134
- file_creation_time = datetime.fromtimestamp(os.path.getctime(file_path), timezone.utc)
135
- if file_creation_time < cutoff_time:
136
- os.remove(file_path)
137
- print(f"已删除旧图片: {filename}")
138
- except Exception as e:
139
- print(f"无法删除文件 {filename}: {e}")
140
-
141
- def save_base64_image(base64_str: str, images_dir: str = IMAGES_DIR) -> str:
142
- """
143
- 将Base64编码的图片保存到images目录,返回文件名
144
- """
145
- # 先清理1分钟前的所有图片
146
- cleanup_images(images_dir, age_seconds=60)
147
-
148
- try:
149
- image_data = base64.b64decode(base64_str)
150
- except base64.binascii.Error as e:
151
- print(f"Base64解码失败: {e}")
152
- raise ValueError("Invalid base64 image data")
153
-
154
- filename = f"{uuid.uuid4()}.png" # 默认保存为png格式
155
- file_path = os.path.join(images_dir, filename)
156
- try:
157
- with open(file_path, "wb") as f:
158
- f.write(image_data)
159
- print(f"保存图片: {filename}")
160
- except Exception as e:
161
- print(f"保存图片失败: {e}")
162
- raise
163
-
164
- return filename
165
-
166
- def is_base64_image(url: str) -> bool:
167
- """
168
- 判断URL是否为Base64编码的图片
169
- """
170
- return url.startswith("data:image/")
171
-
172
- # 根路径GET请求处理
173
- @app.get("/", response_class=HTMLResponse)
174
- async def read_root():
175
- """返回欢迎页面"""
176
- html_content = """
177
- <html>
178
- <head>
179
- <title>Welcome to API</title>
180
- </head>
181
- <body>
182
- <h1>Welcome to API</h1>
183
- <p>You can send messages to the model and receive responses.</p>
184
- </body>
185
- </html>
186
- """
187
- return HTMLResponse(content=html_content, status_code=200)
188
-
189
- # 聊天完成处理(保留原有逻辑,未修改)
190
- @app.post("/ai/v1/chat/completions")
191
- @app.post("/v1/chat/completions")
192
- async def chat_completions(request: Request, background_tasks: BackgroundTasks):
193
- """
194
- 处理聊天完成请求
195
- """
196
- try:
197
- request_body = await request.json()
198
- except json.JSONDecodeError:
199
- raise HTTPException(status_code=400, detail="Invalid JSON")
200
-
201
- # 打印接收到的请求
202
- print("Received Completion JSON:", json.dumps(request_body, ensure_ascii=False))
203
-
204
- # 处理消息内容
205
- messages = request_body.get("messages", [])
206
- temperature = request_body.get("temperature", 1.0)
207
- #top_p = request_body.get("top_p", 1.0)
208
- max_tokens = request_body.get("max_tokens", 8000)
209
- model = request_body.get("model", "gpt-4o")
210
- is_stream = request_body.get("stream", False) # 获取 stream 字段
211
-
212
- # 验证模型
213
- if model not in MODELS:
214
- raise HTTPException(status_code=400, detail=f"无效的 model: {model}. 可用的模型有: {', '.join(MODELS)}")
215
-
216
- has_image = False
217
- has_text = False
218
-
219
- # 清理和提取消息内容
220
- cleaned_messages = []
221
- for message in messages:
222
- content = message.get("content", "")
223
- if isinstance(content, list):
224
- text_parts = []
225
- images = []
226
- for item in content:
227
- if "text" in item:
228
- text_parts.append(item.get("text", ""))
229
- elif "image_url" in item:
230
- has_image = True
231
- image_info = item.get("image_url", {})
232
- url = image_info.get("url", "")
233
- if is_base64_image(url):
234
- # 解码并保存图片
235
- try:
236
- base64_str = url.split(",")[1]
237
- filename = save_base64_image(base64_str)
238
- base_url = app.state.base_url
239
- image_url = f"{base_url}/images/{filename}"
240
- images.append({"data": image_url})
241
- except (IndexError, ValueError) as e:
242
- print(f"处理Base64图片失败: {e}")
243
- continue
244
- else:
245
- images.append({"data": url})
246
- extracted_content = " ".join(text_parts).strip()
247
- if extracted_content:
248
- has_text = True
249
- message["content"] = extracted_content
250
- if images:
251
- message["images"] = images
252
- cleaned_messages.append(message)
253
- print("Extracted:", extracted_content)
254
- else:
255
- if images:
256
- has_image = True
257
- message["content"] = ""
258
- message["images"] = images
259
- cleaned_messages.append(message)
260
- print("Extracted image only.")
261
- else:
262
- print("Deleted message with empty content.")
263
- elif isinstance(content, str):
264
- content_str = content.strip()
265
- if content_str:
266
- has_text = True
267
- message["content"] = content_str
268
- cleaned_messages.append(message)
269
- print("Retained content:", content_str)
270
- else:
271
- print("Deleted message with empty content.")
272
- else:
273
- print("Deleted non-expected type of content message.")
274
-
275
- if not cleaned_messages:
276
- raise HTTPException(status_code=400, detail="所有消息的内容均为空。")
277
-
278
- # 构建新的请求JSON
279
- new_request_json = {
280
- "function_image_gen": False,
281
- "function_web_search": True,
282
- "max_tokens": max_tokens,
283
- "model": model,
284
- "source": "chat/pro",
285
- "temperature": temperature,
286
- #"top_p": top_p,
287
- "messages": cleaned_messages,
288
- }
289
-
290
- modified_request_body = json.dumps(new_request_json, ensure_ascii=False)
291
- print("Modified Request JSON:", modified_request_body)
292
-
293
- # 获取Bearer Token
294
- tmp_token = BearerTokenGenerator.get_bearer(modified_request_body)
295
- if not tmp_token:
296
- raise HTTPException(status_code=500, detail="无法生成 Bearer Token")
297
-
298
- bearer_token, formatted_date = tmp_token
299
-
300
- headers = {
301
- "Date": formatted_date,
302
- "Client-time-zone": "-05:00",
303
- "Authorization": bearer_token,
304
- "User-Agent": "ChatOn_Android/1.55.488",
305
- "Accept-Language": "en-US",
306
- "X-Cl-Options": "hb",
307
- "Content-Type": "application/json; charset=UTF-8",
308
- }
309
-
310
- if is_stream:
311
- import uuid
312
- from datetime import datetime, timezone
313
-
314
- # 定义 should_filter_out 函数
315
- def should_filter_out(json_data):
316
- if 'ping' in json_data:
317
- return True
318
- if 'data' in json_data:
319
- data = json_data['data']
320
- if 'analytics' in data:
321
- return True
322
- if 'operation' in data and 'message' in data:
323
- return True
324
- return False
325
-
326
- # 定义 generate_id 函数
327
- def generate_id():
328
- return uuid.uuid4().hex[:24]
329
-
330
- # 流式响应处理
331
- async def event_generator():
332
- async with httpx.AsyncClient(timeout=None) as client_stream:
333
- try:
334
- async with client_stream.stream("POST", EXTERNAL_API_URL, headers=headers, content=modified_request_body) as streamed_response:
335
- async for line in streamed_response.aiter_lines():
336
- if line.startswith("data: "):
337
- data = line[6:].strip()
338
- if data == "[DONE]":
339
- # 通知客户端流结束
340
- yield "data: [DONE]\n\n"
341
- break
342
- try:
343
- sse_json = json.loads(data)
344
-
345
- # 判断是否需要过滤
346
- if should_filter_out(sse_json):
347
- continue
348
-
349
- # 处理包含 web sources 的消息
350
- if 'data' in sse_json and 'web' in sse_json['data']:
351
- web_data = sse_json['data']['web']
352
- if 'sources' in web_data:
353
- sources = web_data['sources']
354
- urls_list = []
355
- for source in sources:
356
- if 'url' in source:
357
- urls_list.append(source['url'])
358
- urls_content = '\n\n'.join(urls_list)
359
- print(f"从 API 接收到的内容: {urls_content}")
360
- # 构造新的 SSE 消息,填入 content 字段
361
- new_sse_json = {
362
- "id": generate_id(),
363
- "object": "chat.completion.chunk",
364
- "created": int(datetime.now(timezone.utc).timestamp()),
365
- "model": sse_json.get("model", "gpt-4o"),
366
- "choices": [
367
- {
368
- "delta": {"content": "\n" + urls_content + "\n"},
369
- "index": 0,
370
- "finish_reason": None
371
- }
372
- ]
373
- }
374
- new_sse_line = f"data: {json.dumps(new_sse_json, ensure_ascii=False)}\n\n"
375
- yield new_sse_line
376
- else:
377
- # 尝试打印内容
378
- if 'choices' in sse_json:
379
- for choice in sse_json['choices']:
380
- delta = choice.get('delta', {})
381
- content = delta.get('content')
382
- if content:
383
- print(content, end='')
384
- # 直接转发其他消息
385
- yield f"data: {data}\n\n"
386
- except json.JSONDecodeError as e:
387
- print(f"JSON解析错误: {e}")
388
- continue
389
- else:
390
- # 忽略不以 "data: " 开头的行
391
- continue
392
- except httpx.RequestError as exc:
393
- print(f"外部API请求失败: {exc}")
394
- yield f"data: {{\"error\": \"外部API请求失败: {str(exc)}\"}}\n\n"
395
-
396
- return StreamingResponse(
397
- event_generator(),
398
- media_type="text/event-stream",
399
- headers={
400
- "Cache-Control": "no-cache",
401
- "Connection": "keep-alive",
402
- # CORS头已通过中间件处理,无需在这里重复添加
403
- },
404
- )
405
- else:
406
- # 非流式响应处理
407
- async with httpx.AsyncClient(timeout=None) as client:
408
- try:
409
- response = await client.post(
410
- EXTERNAL_API_URL,
411
- headers=headers,
412
- content=modified_request_body,
413
- timeout=None
414
- )
415
-
416
- if response.status_code != 200:
417
- raise HTTPException(
418
- status_code=response.status_code,
419
- detail=f"API 错误: {response.status_code}",
420
- )
421
-
422
- sse_lines = response.text.splitlines()
423
- content_builder = ""
424
- images_urls = []
425
-
426
- for line in sse_lines:
427
- if line.startswith("data: "):
428
- data = line[6:].strip()
429
- if data == "[DONE]":
430
- break
431
- try:
432
- sse_json = json.loads(data)
433
- if "choices" in sse_json:
434
- for choice in sse_json["choices"]:
435
- if "delta" in choice:
436
- delta = choice["delta"]
437
- if "content" in delta:
438
- content_builder += delta["content"]
439
- except json.JSONDecodeError:
440
- print("JSON解析错误")
441
- continue
442
-
443
- openai_response = {
444
- "id": f"chatcmpl-{uuid.uuid4()}",
445
- "object": "chat.completion",
446
- "created": int(datetime.now(timezone.utc).timestamp()),
447
- "model": model, # 使用用户传入的model
448
- "choices": [
449
- {
450
- "index": 0,
451
- "message": {
452
- "role": "assistant",
453
- "content": content_builder,
454
- },
455
- "finish_reason": "stop",
456
- }
457
- ],
458
- }
459
-
460
- # 处理图片(如果有)
461
- if has_image:
462
- images = []
463
- for message in cleaned_messages:
464
- if "images" in message:
465
- for img in message["images"]:
466
- images.append({"data": img["data"]})
467
- openai_response["choices"][0]["message"]["images"] = images
468
-
469
- return JSONResponse(content=openai_response, status_code=200)
470
- except httpx.RequestError as exc:
471
- raise HTTPException(status_code=500, detail=f"请求失败: {str(exc)}")
472
- except Exception as exc:
473
- raise HTTPException(status_code=500, detail=f"内部服务器错误: {str(exc)}")
474
-
475
- # 图像生成处理
476
- @app.post("/ai/v1/images/generations")
477
- @app.post("/v1/images/generations")
478
- async def images_generations(request: Request):
479
- """
480
- 处理图像生成请求
481
- """
482
- try:
483
- request_body = await request.json()
484
- except json.JSONDecodeError:
485
- return send_error_response("Invalid JSON", status_code=400)
486
-
487
- print("Received Image Generations JSON:", json.dumps(request_body, ensure_ascii=False))
488
-
489
- # 验证必需的字段
490
- if "prompt" not in request_body:
491
- return send_error_response("缺少必需的字段: prompt", status_code=400)
492
-
493
- user_prompt = request_body.get("prompt", "").strip()
494
- response_format = request_body.get("response_format", "").strip().lower()
495
- model = request_body.get("model", "gpt-4o")
496
- n = request_body.get("n", 1) # 生成图像数量,默认1
497
- size = request_body.get("size", "1024x1024") # 图片尺寸,默认1024x1024
498
-
499
- is_base64_response = response_format == "b64_json"
500
-
501
- if not user_prompt:
502
- return send_error_response("Prompt 不能为空。", status_code=400)
503
-
504
- print(f"Prompt: {user_prompt}")
505
- print(f"Response Format: {response_format}")
506
- print(f"Number of images to generate (n): {n}")
507
- print(f"Size: {size}")
508
-
509
- # 设置最大尝试次数为 2 * n
510
- max_attempts = 2 * n
511
- print(f"Max Attempts: {max_attempts}")
512
-
513
- # 初始化用于存储多个 URL 的线程安全列表
514
- final_download_urls: List[str] = []
515
-
516
- async def attempt_generate_image(attempt: int) -> Optional[str]:
517
- """
518
- 尝试生成单张图像,带有重试机制。
519
- """
520
- try:
521
- # 构建新的 TextToImage JSON 请求体
522
- text_to_image_json = {
523
- "function_image_gen": True,
524
- "function_web_search": True,
525
- "image_aspect_ratio": "1:1", # 图片比例可选:1:1/9:19/16:9/4:3
526
- "image_style": "photographic", # 固定 image_style,可根据需要调整
527
- "max_tokens": 8000,
528
- "messages": [
529
- {
530
- "content": "You are a helpful artist, please draw a picture. Based on imagination, draw a picture with user message.",
531
- "role": "system"
532
- },
533
- {
534
- "content": "Draw: " + user_prompt,
535
- "role": "user"
536
- }
537
- ],
538
- "model": "gpt-4o",
539
- "source": "chat/free" # 固定 source
540
- }
541
-
542
- modified_request_body = json.dumps(text_to_image_json, ensure_ascii=False)
543
- print(f"Attempt {attempt} - Modified Request JSON: {modified_request_body}")
544
-
545
- # 获取Bearer Token
546
- tmp_token = BearerTokenGenerator.get_bearer(modified_request_body, path="/chats/stream")
547
- if not tmp_token:
548
- print(f"Attempt {attempt} - 无法生成 Bearer Token")
549
- return None
550
-
551
- bearer_token, formatted_date = tmp_token
552
-
553
- headers = {
554
- "Date": formatted_date,
555
- "Client-time-zone": "-05:00",
556
- "Authorization": bearer_token,
557
- "User-Agent": "ChatOn_Android/1.53.502",
558
- "Accept-Language": "en-US",
559
- "X-Cl-Options": "hb",
560
- "Content-Type": "application/json; charset=UTF-8",
561
- }
562
-
563
- async with httpx.AsyncClient(timeout=None) as client:
564
- response = await client.post(
565
- EXTERNAL_API_URL,
566
- headers=headers,
567
- content=modified_request_body,
568
- timeout=None
569
- )
570
-
571
- if response.status_code != 200:
572
- print(f"Attempt {attempt} - API 错误: {response.status_code}")
573
- return None
574
-
575
- # 读取 SSE 流并提取图像URL
576
- sse_lines = response.text.splitlines()
577
- image_markdown = ""
578
-
579
- for line in sse_lines:
580
- if line.startswith("data: "):
581
- data = line[6:].strip()
582
- if data == "[DONE]":
583
- break
584
- try:
585
- sse_json = json.loads(data)
586
- if "choices" in sse_json:
587
- for choice in sse_json["choices"]:
588
- delta = choice.get("delta", {})
589
- content = delta.get("content")
590
- if content:
591
- image_markdown += content
592
- except json.JSONDecodeError:
593
- print(f"Attempt {attempt} - JSON解析错误")
594
- continue
595
-
596
- # 检查Markdown文本是否为空
597
- if not image_markdown:
598
- print(f"Attempt {attempt} - 无法从 SSE 流中构建图像 Markdown。")
599
- return None
600
-
601
- # 从Markdown中提取图像路径
602
- extracted_path = extract_path_from_markdown(image_markdown)
603
- if not extracted_path:
604
- print(f"Attempt {attempt} - 无法从 Markdown 中提取路径。")
605
- return None
606
-
607
- print(f"Attempt {attempt} - 提取的路径: {extracted_path}")
608
-
609
- # 拼接最终的存储URL
610
- storage_url = f"https://api.chaton.ai/storage/{extracted_path}"
611
- print(f"Attempt {attempt} - 存储URL: {storage_url}")
612
-
613
- # 获取最终下载URL
614
- final_download_url = await fetch_get_url_from_storage(storage_url)
615
- if not final_download_url:
616
- print(f"Attempt {attempt} - 无法从 storage URL 获取最终下载链接。")
617
- return None
618
-
619
- print(f"Attempt {attempt} - Final Download URL: {final_download_url}")
620
-
621
- return final_download_url
622
- except Exception as e:
623
- print(f"Attempt {attempt} - 处理响应时发生错误: {e}")
624
- return None
625
-
626
- # 定义一个异步任务池,限制并发数量
627
- semaphore = asyncio.Semaphore(10) # 限制同时进行的任务数���10
628
-
629
- async def generate_with_retries(attempt: int) -> Optional[str]:
630
- async with semaphore:
631
- return await attempt_generate_image(attempt)
632
-
633
- # 开始尝试生成图像
634
- for attempt in range(1, max_attempts + 1):
635
- needed = n - len(final_download_urls)
636
- if needed <= 0:
637
- break
638
-
639
- print(f"Attempt {attempt} - 需要生成的图像数量: {needed}")
640
-
641
- # 创建多个任务同时生成所需数量的图像
642
- tasks = [asyncio.create_task(generate_with_retries(attempt)) for _ in range(needed)]
643
-
644
- # 等待所有任务完成
645
- results = await asyncio.gather(*tasks, return_exceptions=True)
646
-
647
- for result in results:
648
- if isinstance(result, Exception):
649
- print(f"Attempt {attempt} - 任务发生异常: {result}")
650
- continue
651
- if result:
652
- final_download_urls.append(result)
653
- print(f"Attempt {attempt} - 收集到下载链接: {result}")
654
-
655
- # 检查是否已经收集到足够的下载链接
656
- if len(final_download_urls) >= n:
657
- break
658
-
659
- # 检查是否收集到足够的链接
660
- if len(final_download_urls) < n:
661
- print("已达到最大尝试次数,仍未收集到足够数量的下载链接。")
662
- return send_error_response("无法生成足够数量的图像。", status_code=500)
663
-
664
- # 根据 response_format 返回相应的响应
665
- data_array = []
666
-
667
- if is_base64_response:
668
- for download_url in final_download_urls[:n]:
669
- try:
670
- image_bytes = await download_image(download_url)
671
- if not image_bytes:
672
- print(f"无法从 URL 下载图像: {download_url}")
673
- continue
674
-
675
- image_base64 = base64.b64encode(image_bytes).decode('utf-8')
676
- data_array.append({
677
- "b64_json": image_base64
678
- })
679
- except Exception as e:
680
- print(f"处理图像时发生错误: {e}")
681
- continue
682
- else:
683
- for download_url in final_download_urls[:n]:
684
- data_array.append({
685
- "url": download_url
686
- })
687
-
688
- # 如果收集的 URL 数量不足 n,则通过复制现有的 URL 来填充
689
- while len(data_array) < n and len(data_array) > 0:
690
- for item in data_array.copy():
691
- if len(data_array) >= n:
692
- break
693
- data_array.append(item)
694
-
695
- # 构建最终响应
696
- response_json = {
697
- "created": int(datetime.now(timezone.utc).timestamp()),
698
- "data": data_array
699
- }
700
-
701
- # 如果data_array为空,返回错误
702
- if not data_array:
703
- return send_error_response("无法生成图像。", status_code=500)
704
-
705
- # 返回响应
706
- return JSONResponse(content=response_json, status_code=200)
707
-
708
- @app.get("/ai/v1/models", response_class=JSONResponse)
709
- @app.get("/v1/models", response_class=JSONResponse)
710
- async def get_models():
711
- models_data = {
712
- "object": "list",
713
- "data": [
714
- {"id": "gpt-4o", "object": "model"},
715
- {"id": "gpt-4o-mini", "object": "model"},
716
- {"id": "claude-3-5-sonnet", "object": "model"},
717
- {"id": "claude", "object": "model"}
718
- ]
719
- }
720
- return models_data
721
-
722
- async def get_available_port(start_port: int = INITIAL_PORT, end_port: int = 65535) -> int:
723
- """查找可用的端口号"""
724
- for port in range(start_port, end_port + 1):
725
- try:
726
- # 尝试绑定端口
727
- server = await asyncio.start_server(lambda r, w: None, host="0.0.0.0", port=port)
728
- server.close()
729
- await server.wait_closed()
730
- return port
731
- except OSError:
732
- continue
733
- raise RuntimeError(f"No available ports between {start_port} and {end_port}")
734
-
735
- def main():
736
- parser = argparse.ArgumentParser(description="启动ChatOn API服务器")
737
- parser.add_argument('--base_url', type=str, default='http://localhost', help='Base URL for accessing images')
738
- parser.add_argument('--port', type=int, default=INITIAL_PORT, help='服务器监听端口')
739
- args = parser.parse_args()
740
-
741
- base_url = args.base_url
742
- port = args.port
743
-
744
- # 设置 FastAPI 应用的 state
745
- app.state.base_url = base_url
746
-
747
- print(f"Starting server on port {port} with base_url: {base_url}")
748
-
749
- # 检查端口可用性
750
- try:
751
- port = asyncio.run(get_available_port(start_port=port))
752
- except RuntimeError as e:
753
- print(e)
754
- return
755
-
756
- print(f"Server running on available port: {port}")
757
-
758
- # 运行FastAPI应用
759
- uvicorn.run(app, host="0.0.0.0", port=port)
760
-
761
- if __name__ == "__main__":
762
- main()
 
1
+ import asyncio
2
+ import json
3
+ import sys
4
+ import uuid
5
+ import base64
6
+ import re
7
+ import os
8
+ import argparse
9
+ from datetime import datetime, timezone, timedelta
10
+ from typing import List, Optional
11
+
12
+ import httpx
13
+ import uvicorn
14
+ from fastapi import (
15
+ BackgroundTasks,
16
+ FastAPI,
17
+ HTTPException,
18
+ Request,
19
+ Response,
20
+ status,
21
+ )
22
+ from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse
23
+ from fastapi.middleware.cors import CORSMiddleware
24
+ from fastapi.staticfiles import StaticFiles
25
+
26
+ from bearer_token import BearerTokenGenerator
27
+
28
+ # 模型列表(根据需求,可自行调整)
29
+ MODELS = ["gpt-4o", "gpt-4o-mini", "claude-3-5-sonnet", "claude"]
30
+
31
+ # 默认端口
32
+ INITIAL_PORT = 3000
33
+
34
+ # 外部API的URL
35
+ EXTERNAL_API_URL = "https://api.chaton.ai/chats/stream"
36
+
37
+ # 定义 images 目录
38
+ IMAGES_DIR = "images"
39
+
40
+ # 确保 images 目录存在
41
+ def ensure_images_dir_exists(directory: str = IMAGES_DIR):
42
+ try:
43
+ os.makedirs(directory, exist_ok=True)
44
+ print(f"Directory '{directory}' is ready.")
45
+ except Exception as e:
46
+ print(f"Failed to create directory '{directory}': {e}")
47
+ sys.exit(1)
48
+
49
+ # 在挂载静态文件之前确保 images 目录存在
50
+ ensure_images_dir_exists()
51
+
52
+ # 初始化FastAPI应用
53
+ app = FastAPI()
54
+
55
+ # 挂载静态文件路由以提供 images 目录的内容
56
+ app.mount("/images", StaticFiles(directory=IMAGES_DIR), name="images")
57
+
58
+ # 添加CORS中间件
59
+ app.add_middleware(
60
+ CORSMiddleware,
61
+ allow_origins=["*"], # 允许所有来源
62
+ allow_credentials=True,
63
+ allow_methods=["GET", "POST", "OPTIONS"], # 允许GET, POST, OPTIONS方法
64
+ allow_headers=["Content-Type", "Authorization"], # 允许的头部
65
+ )
66
+
67
+ # 辅助函数
68
+ def send_error_response(message: str, status_code: int = 400):
69
+ """构建错误响应,并确保包含CORS头"""
70
+ error_json = {"error": message}
71
+ headers = {
72
+ "Access-Control-Allow-Origin": "*",
73
+ "Access-Control-Allow-Methods": "GET, POST, OPTIONS",
74
+ "Access-Control-Allow-Headers": "Content-Type, Authorization",
75
+ }
76
+ return JSONResponse(status_code=status_code, content=error_json, headers=headers)
77
+
78
+ def extract_path_from_markdown(markdown: str) -> Optional[str]:
79
+ """
80
+ 提取 Markdown 图片链接中的路径,匹配以 https://spc.unk/ 开头的 URL
81
+ """
82
+ pattern = re.compile(r'!\[.*?\]\(https://spc\.unk/(.*?)\)')
83
+ match = pattern.search(markdown)
84
+ if match:
85
+ return match.group(1)
86
+ return None
87
+
88
+ async def fetch_get_url_from_storage(storage_url: str) -> Optional[str]:
89
+ """
90
+ 从 storage URL 获取 JSON 并提取 getUrl
91
+ """
92
+ async with httpx.AsyncClient() as client:
93
+ try:
94
+ response = await client.get(storage_url)
95
+ if response.status_code != 200:
96
+ print(f"获取 storage URL 失败,状态码: {response.status_code}")
97
+ return None
98
+ json_response = response.json()
99
+ return json_response.get("getUrl")
100
+ except Exception as e:
101
+ print(f"Error fetching getUrl from storage: {e}")
102
+ return None
103
+
104
+ async def download_image(image_url: str) -> Optional[bytes]:
105
+ """
106
+ 下载图像
107
+ """
108
+ async with httpx.AsyncClient() as client:
109
+ try:
110
+ response = await client.get(image_url)
111
+ if response.status_code == 200:
112
+ return response.content
113
+ else:
114
+ print(f"下载图像失败,状态码: {response.status_code}")
115
+ return None
116
+ except Exception as e:
117
+ print(f"Error downloading image: {e}")
118
+ return None
119
+
120
+ def cleanup_images(images_dir: str = IMAGES_DIR, age_seconds: int = 60):
121
+ """
122
+ 清理 images 目录中创建时间超过指定秒数的图片
123
+ """
124
+ now = datetime.now(timezone.utc)
125
+ cutoff_time = now - timedelta(seconds=age_seconds)
126
+
127
+ if not os.path.exists(images_dir):
128
+ return
129
+
130
+ for filename in os.listdir(images_dir):
131
+ file_path = os.path.join(images_dir, filename)
132
+ if os.path.isfile(file_path):
133
+ try:
134
+ file_creation_time = datetime.fromtimestamp(os.path.getctime(file_path), timezone.utc)
135
+ if file_creation_time < cutoff_time:
136
+ os.remove(file_path)
137
+ print(f"已删除旧图片: {filename}")
138
+ except Exception as e:
139
+ print(f"无法删除文件 {filename}: {e}")
140
+
141
+ def save_base64_image(base64_str: str, images_dir: str = IMAGES_DIR) -> str:
142
+ """
143
+ 将Base64编码的图片保存到images目录,返回文件名
144
+ """
145
+ # 先清理1分钟前的所有图片
146
+ cleanup_images(images_dir, age_seconds=60)
147
+
148
+ try:
149
+ image_data = base64.b64decode(base64_str)
150
+ except base64.binascii.Error as e:
151
+ print(f"Base64解码失败: {e}")
152
+ raise ValueError("Invalid base64 image data")
153
+
154
+ filename = f"{uuid.uuid4()}.png" # 默认保存为png格式
155
+ file_path = os.path.join(images_dir, filename)
156
+ try:
157
+ with open(file_path, "wb") as f:
158
+ f.write(image_data)
159
+ print(f"保存图片: {filename}")
160
+ except Exception as e:
161
+ print(f"保存图片失败: {e}")
162
+ raise
163
+
164
+ return filename
165
+
166
+ def is_base64_image(url: str) -> bool:
167
+ """
168
+ 判断URL是否为Base64编码的图片
169
+ """
170
+ return url.startswith("data:image/")
171
+
172
+ # 根路径GET请求处理
173
+ @app.get("/", response_class=HTMLResponse)
174
+ async def read_root():
175
+ """返回欢迎页面"""
176
+ html_content = """
177
+ <html>
178
+ <head>
179
+ <title>Welcome to API</title>
180
+ </head>
181
+ <body>
182
+ <h1>Welcome to API</h1>
183
+ <p>You can send messages to the model and receive responses.</p>
184
+ </body>
185
+ </html>
186
+ """
187
+ return HTMLResponse(content=html_content, status_code=200)
188
+
189
+ # 聊天完成处理(保留原有逻辑,未修改)
190
+ @app.post("/ai/v1/chat/completions")
191
+ @app.post("/v1/chat/completions")
192
+ async def chat_completions(request: Request, background_tasks: BackgroundTasks):
193
+ """
194
+ 处理聊天完成请求
195
+ """
196
+ try:
197
+ request_body = await request.json()
198
+ except json.JSONDecodeError:
199
+ raise HTTPException(status_code=400, detail="Invalid JSON")
200
+
201
+ # 打印接收到的请求
202
+ print("Received Completion JSON:", json.dumps(request_body, ensure_ascii=False))
203
+
204
+ # 处理消息内容
205
+ messages = request_body.get("messages", [])
206
+ temperature = request_body.get("temperature", 1.0)
207
+ #top_p = request_body.get("top_p", 1.0)
208
+ max_tokens = request_body.get("max_tokens", 8000)
209
+ model = request_body.get("model", "gpt-4o")
210
+ is_stream = request_body.get("stream", False) # 获取 stream 字段
211
+
212
+ # 验证模型
213
+ if model not in MODELS:
214
+ raise HTTPException(status_code=400, detail=f"无效的 model: {model}. 可用的模型有: {', '.join(MODELS)}")
215
+
216
+ has_image = False
217
+ has_text = False
218
+
219
+ # 清理和提取消息内容
220
+ cleaned_messages = []
221
+ for message in messages:
222
+ content = message.get("content", "")
223
+ if isinstance(content, list):
224
+ text_parts = []
225
+ images = []
226
+ for item in content:
227
+ if "text" in item:
228
+ text_parts.append(item.get("text", ""))
229
+ elif "image_url" in item:
230
+ has_image = True
231
+ image_info = item.get("image_url", {})
232
+ url = image_info.get("url", "")
233
+ if is_base64_image(url):
234
+ # 解码并保存图片
235
+ try:
236
+ base64_str = url.split(",")[1]
237
+ filename = save_base64_image(base64_str)
238
+ base_url = app.state.base_url
239
+ image_url = f"{base_url}/images/{filename}"
240
+ images.append({"data": image_url})
241
+ except (IndexError, ValueError) as e:
242
+ print(f"处理Base64图片失败: {e}")
243
+ continue
244
+ else:
245
+ images.append({"data": url})
246
+ extracted_content = " ".join(text_parts).strip()
247
+ if extracted_content:
248
+ has_text = True
249
+ message["content"] = extracted_content
250
+ if images:
251
+ message["images"] = images
252
+ cleaned_messages.append(message)
253
+ print("Extracted:", extracted_content)
254
+ else:
255
+ if images:
256
+ has_image = True
257
+ message["content"] = ""
258
+ message["images"] = images
259
+ cleaned_messages.append(message)
260
+ print("Extracted image only.")
261
+ else:
262
+ print("Deleted message with empty content.")
263
+ elif isinstance(content, str):
264
+ content_str = content.strip()
265
+ if content_str:
266
+ has_text = True
267
+ new_message = {"role": message.get("role"), "content": content_str}
268
+ cleaned_messages.append(new_message)
269
+ print("Retained content:", content_str)
270
+ else:
271
+ print("Deleted message with empty content.")
272
+ else:
273
+ print("Deleted non-expected type of content message.")
274
+
275
+ if not cleaned_messages:
276
+ raise HTTPException(status_code=400, detail="所有消息的内容均为空。")
277
+
278
+ # 构建新的请求JSON
279
+ new_request_json = {
280
+ "function_image_gen": False,
281
+ "function_web_search": True,
282
+ "max_tokens": max_tokens,
283
+ "model": model,
284
+ "source": "chat/pro",
285
+ "temperature": temperature,
286
+ #"top_p": top_p,
287
+ "messages": cleaned_messages,
288
+ }
289
+
290
+ modified_request_body = json.dumps(new_request_json, ensure_ascii=False)
291
+ print("Modified Request JSON:", modified_request_body)
292
+
293
+ # 获取Bearer Token
294
+ tmp_token = BearerTokenGenerator.get_bearer(modified_request_body)
295
+ if not tmp_token:
296
+ raise HTTPException(status_code=500, detail="无法生成 Bearer Token")
297
+
298
+ bearer_token, formatted_date = tmp_token
299
+
300
+ headers = {
301
+ "Date": formatted_date,
302
+ "Client-time-zone": "-05:00",
303
+ "Authorization": bearer_token,
304
+ "User-Agent": "ChatOn_Android/1.55.488",
305
+ "Accept-Language": "en-US",
306
+ "X-Cl-Options": "hb",
307
+ "Content-Type": "application/json; charset=UTF-8",
308
+ }
309
+
310
+ if is_stream:
311
+ import uuid
312
+ from datetime import datetime, timezone
313
+
314
+ # 定义 should_filter_out 函数
315
+ def should_filter_out(json_data):
316
+ if 'ping' in json_data:
317
+ return True
318
+ if 'data' in json_data:
319
+ data = json_data['data']
320
+ if 'analytics' in data:
321
+ return True
322
+ if 'operation' in data and 'message' in data:
323
+ return True
324
+ return False
325
+
326
+ # 定义 generate_id 函数
327
+ def generate_id():
328
+ return uuid.uuid4().hex[:24]
329
+
330
+ # 流式响应处理
331
+ async def event_generator():
332
+ async with httpx.AsyncClient(timeout=None) as client_stream:
333
+ try:
334
+ async with client_stream.stream("POST", EXTERNAL_API_URL, headers=headers, content=modified_request_body) as streamed_response:
335
+ async for line in streamed_response.aiter_lines():
336
+ if line.startswith("data: "):
337
+ data = line[6:].strip()
338
+ if data == "[DONE]":
339
+ # 通知客户端流结束
340
+ yield "data: [DONE]\n\n"
341
+ break
342
+ try:
343
+ sse_json = json.loads(data)
344
+
345
+ # 判断是否需要过滤
346
+ if should_filter_out(sse_json):
347
+ continue
348
+
349
+ # 处理包含 web sources 的消息
350
+ if 'data' in sse_json and 'web' in sse_json['data']:
351
+ web_data = sse_json['data']['web']
352
+ if 'sources' in web_data:
353
+ sources = web_data['sources']
354
+ urls_list = []
355
+ for source in sources:
356
+ if 'url' in source:
357
+ urls_list.append(source['url'])
358
+ urls_content = '\n\n'.join(urls_list)
359
+ print(f"从 API 接收到的内容: {urls_content}")
360
+ # 构造新的 SSE 消息,填入 content 字段
361
+ new_sse_json = {
362
+ "id": generate_id(),
363
+ "object": "chat.completion.chunk",
364
+ "created": int(datetime.now(timezone.utc).timestamp()),
365
+ "model": sse_json.get("model", "gpt-4o"),
366
+ "choices": [
367
+ {
368
+ "delta": {"content": "\n" + urls_content + "\n"},
369
+ "index": 0,
370
+ "finish_reason": None
371
+ }
372
+ ]
373
+ }
374
+ new_sse_line = f"data: {json.dumps(new_sse_json, ensure_ascii=False)}\n\n"
375
+ yield new_sse_line
376
+ else:
377
+ # 尝试打印内容
378
+ if 'choices' in sse_json:
379
+ for choice in sse_json['choices']:
380
+ delta = choice.get('delta', {})
381
+ content = delta.get('content')
382
+ if content:
383
+ print(content, end='')
384
+ # 直接转发其他消息
385
+ yield f"data: {data}\n\n"
386
+ except json.JSONDecodeError as e:
387
+ print(f"JSON解析错误: {e}")
388
+ continue
389
+ else:
390
+ # 忽略不以 "data: " 开头的行
391
+ continue
392
+ except httpx.RequestError as exc:
393
+ print(f"外部API请求失败: {exc}")
394
+ yield f"data: {{\"error\": \"外部API请求失败: {str(exc)}\"}}\n\n"
395
+
396
+ return StreamingResponse(
397
+ event_generator(),
398
+ media_type="text/event-stream",
399
+ headers={
400
+ "Cache-Control": "no-cache",
401
+ "Connection": "keep-alive",
402
+ # CORS头已通过中间件处理,无需在这里重复添加
403
+ },
404
+ )
405
+ else:
406
+ # 非流式响应处理
407
+ async with httpx.AsyncClient(timeout=None) as client:
408
+ try:
409
+ response = await client.post(
410
+ EXTERNAL_API_URL,
411
+ headers=headers,
412
+ content=modified_request_body,
413
+ timeout=None
414
+ )
415
+
416
+ if response.status_code != 200:
417
+ raise HTTPException(
418
+ status_code=response.status_code,
419
+ detail=f"API 错误: {response.status_code}",
420
+ )
421
+
422
+ sse_lines = response.text.splitlines()
423
+ content_builder = ""
424
+ images_urls = []
425
+
426
+ for line in sse_lines:
427
+ if line.startswith("data: "):
428
+ data = line[6:].strip()
429
+ if data == "[DONE]":
430
+ break
431
+ try:
432
+ sse_json = json.loads(data)
433
+ if "choices" in sse_json:
434
+ for choice in sse_json["choices"]:
435
+ if "delta" in choice:
436
+ delta = choice["delta"]
437
+ if "content" in delta:
438
+ content_builder += delta["content"]
439
+ except json.JSONDecodeError:
440
+ print("JSON解析错误")
441
+ continue
442
+
443
+ openai_response = {
444
+ "id": f"chatcmpl-{uuid.uuid4()}",
445
+ "object": "chat.completion",
446
+ "created": int(datetime.now(timezone.utc).timestamp()),
447
+ "model": model, # 使用用户传入的model
448
+ "choices": [
449
+ {
450
+ "index": 0,
451
+ "message": {
452
+ "role": "assistant",
453
+ "content": content_builder,
454
+ },
455
+ "finish_reason": "stop",
456
+ }
457
+ ],
458
+ }
459
+
460
+ # 处理图片(如果有)
461
+ if has_image:
462
+ images = []
463
+ for message in cleaned_messages:
464
+ if "images" in message:
465
+ for img in message["images"]:
466
+ images.append({"data": img["data"]})
467
+ openai_response["choices"][0]["message"]["images"] = images
468
+
469
+ return JSONResponse(content=openai_response, status_code=200)
470
+ except httpx.RequestError as exc:
471
+ raise HTTPException(status_code=500, detail=f"请求失败: {str(exc)}")
472
+ except Exception as exc:
473
+ raise HTTPException(status_code=500, detail=f"内部服务器错误: {str(exc)}")
474
+
475
+ # 图像生成处理
476
+ @app.post("/ai/v1/images/generations")
477
+ @app.post("/v1/images/generations")
478
+ async def images_generations(request: Request):
479
+ """
480
+ 处理图像生成请求
481
+ """
482
+ try:
483
+ request_body = await request.json()
484
+ except json.JSONDecodeError:
485
+ return send_error_response("Invalid JSON", status_code=400)
486
+
487
+ print("Received Image Generations JSON:", json.dumps(request_body, ensure_ascii=False))
488
+
489
+ # 验证必需的字段
490
+ if "prompt" not in request_body:
491
+ return send_error_response("缺少必需的字段: prompt", status_code=400)
492
+
493
+ user_prompt = request_body.get("prompt", "").strip()
494
+ response_format = request_body.get("response_format", "").strip().lower()
495
+ model = request_body.get("model", "gpt-4o")
496
+ n = request_body.get("n", 1) # 生成图像数量,默认1
497
+ size = request_body.get("size", "1024x1024") # 图片尺寸,默认1024x1024
498
+
499
+ is_base64_response = response_format == "b64_json"
500
+
501
+ if not user_prompt:
502
+ return send_error_response("Prompt 不能为空。", status_code=400)
503
+
504
+ print(f"Prompt: {user_prompt}")
505
+ print(f"Response Format: {response_format}")
506
+ print(f"Number of images to generate (n): {n}")
507
+ print(f"Size: {size}")
508
+
509
+ # 设置最大尝试次数为 2 * n
510
+ max_attempts = 2 * n
511
+ print(f"Max Attempts: {max_attempts}")
512
+
513
+ # 初始化用于存储多个 URL 的线程安全列表
514
+ final_download_urls: List[str] = []
515
+
516
+ async def attempt_generate_image(attempt: int) -> Optional[str]:
517
+ """
518
+ 尝试生成单张图像,带有重试机制。
519
+ """
520
+ try:
521
+ # 构建新的 TextToImage JSON 请求体
522
+ text_to_image_json = {
523
+ "function_image_gen": True,
524
+ "function_web_search": True,
525
+ "image_aspect_ratio": "1:1", # 图片比例可选:1:1/9:19/16:9/4:3
526
+ "image_style": "photographic", # 固定 image_style,可根据需要调整
527
+ "max_tokens": 8000,
528
+ "messages": [
529
+ {
530
+ "content": "You are a helpful artist, please draw a picture. Based on imagination, draw a picture with user message.",
531
+ "role": "system"
532
+ },
533
+ {
534
+ "content": "Draw: " + user_prompt,
535
+ "role": "user"
536
+ }
537
+ ],
538
+ "model": "gpt-4o",
539
+ "source": "chat/free" # 固定 source
540
+ }
541
+
542
+ modified_request_body = json.dumps(text_to_image_json, ensure_ascii=False)
543
+ print(f"Attempt {attempt} - Modified Request JSON: {modified_request_body}")
544
+
545
+ # 获取Bearer Token
546
+ tmp_token = BearerTokenGenerator.get_bearer(modified_request_body, path="/chats/stream")
547
+ if not tmp_token:
548
+ print(f"Attempt {attempt} - 无法生成 Bearer Token")
549
+ return None
550
+
551
+ bearer_token, formatted_date = tmp_token
552
+
553
+ headers = {
554
+ "Date": formatted_date,
555
+ "Client-time-zone": "-05:00",
556
+ "Authorization": bearer_token,
557
+ "User-Agent": "ChatOn_Android/1.53.502",
558
+ "Accept-Language": "en-US",
559
+ "X-Cl-Options": "hb",
560
+ "Content-Type": "application/json; charset=UTF-8",
561
+ }
562
+
563
+ async with httpx.AsyncClient(timeout=None) as client:
564
+ response = await client.post(
565
+ EXTERNAL_API_URL,
566
+ headers=headers,
567
+ content=modified_request_body,
568
+ timeout=None
569
+ )
570
+
571
+ if response.status_code != 200:
572
+ print(f"Attempt {attempt} - API 错误: {response.status_code}")
573
+ return None
574
+
575
+ # 读取 SSE 流并提取图像URL
576
+ sse_lines = response.text.splitlines()
577
+ image_markdown = ""
578
+
579
+ for line in sse_lines:
580
+ if line.startswith("data: "):
581
+ data = line[6:].strip()
582
+ if data == "[DONE]":
583
+ break
584
+ try:
585
+ sse_json = json.loads(data)
586
+ if "choices" in sse_json:
587
+ for choice in sse_json["choices"]:
588
+ delta = choice.get("delta", {})
589
+ content = delta.get("content")
590
+ if content:
591
+ image_markdown += content
592
+ except json.JSONDecodeError:
593
+ print(f"Attempt {attempt} - JSON解析错误")
594
+ continue
595
+
596
+ # 检查Markdown文本是否为空
597
+ if not image_markdown:
598
+ print(f"Attempt {attempt} - 无法从 SSE 流中构建图像 Markdown。")
599
+ return None
600
+
601
+ # 从Markdown中提取图像路径
602
+ extracted_path = extract_path_from_markdown(image_markdown)
603
+ if not extracted_path:
604
+ print(f"Attempt {attempt} - 无法从 Markdown 中提取路径。")
605
+ return None
606
+
607
+ print(f"Attempt {attempt} - 提取的路径: {extracted_path}")
608
+
609
+ # 拼接最终的存储URL
610
+ storage_url = f"https://api.chaton.ai/storage/{extracted_path}"
611
+ print(f"Attempt {attempt} - 存储URL: {storage_url}")
612
+
613
+ # 获取最终下载URL
614
+ final_download_url = await fetch_get_url_from_storage(storage_url)
615
+ if not final_download_url:
616
+ print(f"Attempt {attempt} - 无法从 storage URL 获取最终下载链接。")
617
+ return None
618
+
619
+ print(f"Attempt {attempt} - Final Download URL: {final_download_url}")
620
+
621
+ return final_download_url
622
+ except Exception as e:
623
+ print(f"Attempt {attempt} - 处理响应时发生错误: {e}")
624
+ return None
625
+
626
+ # 定义一个异步任务池,限制并发数量
627
+ semaphore = asyncio.Semaphore(10) # 限制同时进行的任务数为10
628
+
629
+ async def generate_with_retries(attempt: int) -> Optional[str]:
630
+ async with semaphore:
631
+ return await attempt_generate_image(attempt)
632
+
633
+ # 开始尝试生成图像
634
+ for attempt in range(1, max_attempts + 1):
635
+ needed = n - len(final_download_urls)
636
+ if needed <= 0:
637
+ break
638
+
639
+ print(f"Attempt {attempt} - 需要生成的图像数量: {needed}")
640
+
641
+ # 创建多个任务同时生成所需数量的图像
642
+ tasks = [asyncio.create_task(generate_with_retries(attempt)) for _ in range(needed)]
643
+
644
+ # 等待所有任务完成
645
+ results = await asyncio.gather(*tasks, return_exceptions=True)
646
+
647
+ for result in results:
648
+ if isinstance(result, Exception):
649
+ print(f"Attempt {attempt} - 任务发生异常: {result}")
650
+ continue
651
+ if result:
652
+ final_download_urls.append(result)
653
+ print(f"Attempt {attempt} - 收集到下载链接: {result}")
654
+
655
+ # 检查是否已经收集到足够的下载链接
656
+ if len(final_download_urls) >= n:
657
+ break
658
+
659
+ # 检查是否收集到足够的链接
660
+ if len(final_download_urls) < n:
661
+ print("已达到最大尝试次数,仍未收集到足够数量的下载链接。")
662
+ return send_error_response("无法生成足够数量的图像。", status_code=500)
663
+
664
+ # 根据 response_format 返回相应的响应
665
+ data_array = []
666
+
667
+ if is_base64_response:
668
+ for download_url in final_download_urls[:n]:
669
+ try:
670
+ image_bytes = await download_image(download_url)
671
+ if not image_bytes:
672
+ print(f"无法从 URL 下载图像: {download_url}")
673
+ continue
674
+
675
+ image_base64 = base64.b64encode(image_bytes).decode('utf-8')
676
+ data_array.append({
677
+ "b64_json": image_base64
678
+ })
679
+ except Exception as e:
680
+ print(f"处理图像时发生错误: {e}")
681
+ continue
682
+ else:
683
+ for download_url in final_download_urls[:n]:
684
+ data_array.append({
685
+ "url": download_url
686
+ })
687
+
688
+ # 如果收集的 URL 数量不足 n,则通过复制现有的 URL 来填充
689
+ while len(data_array) < n and len(data_array) > 0:
690
+ for item in data_array.copy():
691
+ if len(data_array) >= n:
692
+ break
693
+ data_array.append(item)
694
+
695
+ # 构建最终响应
696
+ response_json = {
697
+ "created": int(datetime.now(timezone.utc).timestamp()),
698
+ "data": data_array
699
+ }
700
+
701
+ # 如果data_array为空,返回错误
702
+ if not data_array:
703
+ return send_error_response("无法生成图像。", status_code=500)
704
+
705
+ # 返回响应
706
+ return JSONResponse(content=response_json, status_code=200)
707
+
708
+ @app.get("/ai/v1/models", response_class=JSONResponse)
709
+ @app.get("/v1/models", response_class=JSONResponse)
710
+ async def get_models():
711
+ models_data = {
712
+ "object": "list",
713
+ "data": [
714
+ {"id": "gpt-4o", "object": "model"},
715
+ {"id": "gpt-4o-mini", "object": "model"},
716
+ {"id": "claude-3-5-sonnet", "object": "model"},
717
+ {"id": "claude", "object": "model"}
718
+ ]
719
+ }
720
+ return models_data
721
+
722
+ async def get_available_port(start_port: int = INITIAL_PORT, end_port: int = 65535) -> int:
723
+ """查找可用的端口号"""
724
+ for port in range(start_port, end_port + 1):
725
+ try:
726
+ # 尝试绑定端口
727
+ server = await asyncio.start_server(lambda r, w: None, host="0.0.0.0", port=port)
728
+ server.close()
729
+ await server.wait_closed()
730
+ return port
731
+ except OSError:
732
+ continue
733
+ raise RuntimeError(f"No available ports between {start_port} and {end_port}")
734
+
735
+ def main():
736
+ parser = argparse.ArgumentParser(description="启动ChatOn API服务器")
737
+ parser.add_argument('--base_url', type=str, default='http://localhost', help='Base URL for accessing images')
738
+ parser.add_argument('--port', type=int, default=INITIAL_PORT, help='服务器监听端口')
739
+ args = parser.parse_args()
740
+
741
+ base_url = args.base_url
742
+ port = args.port
743
+
744
+ # 设置 FastAPI 应用的 state
745
+ app.state.base_url = base_url
746
+
747
+ print(f"Starting server on port {port} with base_url: {base_url}")
748
+
749
+ # 检查端口可用性
750
+ try:
751
+ port = asyncio.run(get_available_port(start_port=port))
752
+ except RuntimeError as e:
753
+ print(e)
754
+ return
755
+
756
+ print(f"Server running on available port: {port}")
757
+
758
+ # 运行FastAPI应用
759
+ uvicorn.run(app, host="0.0.0.0", port=port)
760
+
761
+ if __name__ == "__main__":
762
+ main()